{"title": "Stein Variational Gradient Descent: A General Purpose Bayesian Inference Algorithm", "book": "Advances in Neural Information Processing Systems", "page_first": 2378, "page_last": 2386, "abstract": "We propose a general purpose variational inference algorithm that forms a natural counterpart of gradient descent for optimization. Our method iteratively transports a set of particles to match the target distribution, by applying a form of functional gradient descent that minimizes the KL divergence. Empirical studies are performed on various real world models and datasets, on which our method is competitive with existing state-of-the-art methods. The derivation of our method is based on a new theoretical result that connects the derivative of KL divergence under smooth transforms with Stein\u2019s identity and a recently proposed kernelized Stein discrepancy, which is of independent interest.", "full_text": "Stein Variational Gradient Descent: A General\n\nPurpose Bayesian Inference Algorithm\n\nQiang Liu\n\nDilin Wang\nDepartment of Computer Science\n\nDartmouth College\nHanover, NH 03755\n\n{qiang.liu, dilin.wang.gr}@dartmouth.edu\n\nAbstract\n\nWe propose a general purpose variational inference algorithm that forms a natural\ncounterpart of gradient descent for optimization. Our method iteratively trans-\nports a set of particles to match the target distribution, by applying a form of\nfunctional gradient descent that minimizes the KL divergence. Empirical studies\nare performed on various real world models and datasets, on which our method is\ncompetitive with existing state-of-the-art methods. 
The derivation of our method is based on a new theoretical result that connects the derivative of KL divergence under smooth transforms with Stein's identity and a recently proposed kernelized Stein discrepancy, which is of independent interest.

1 Introduction

Bayesian inference provides a powerful tool for modeling complex data and reasoning under uncertainty, but poses a long-standing challenge of computing intractable posterior distributions. Markov chain Monte Carlo (MCMC) has been widely used to draw approximate posterior samples, but is often slow and has difficulty assessing convergence. Variational inference instead frames Bayesian inference as a deterministic optimization problem that approximates the target distribution with a simpler distribution by minimizing their KL divergence. This makes variational methods efficiently solvable using off-the-shelf optimization techniques, and easily applicable to large datasets (i.e., "big data") using the stochastic gradient descent trick [e.g., 1]. In contrast, it is much more challenging to scale up MCMC to big data settings [see e.g., 2, 3].
Meanwhile, both the accuracy and computational cost of variational inference critically depend on the set of distributions in which the approximation is defined.
Simple approximation sets, such as those used in traditional mean field methods, are too restrictive to resemble the true posterior distributions, while more advanced choices cast more difficulties on the subsequent optimization tasks. For this reason, efficient variational methods often need to be derived on a model-by-model basis, which is a major barrier to developing general purpose, user-friendly variational tools applicable to different kinds of models and accessible to non-ML experts in application domains.
This is in contrast with the maximum a posteriori (MAP) optimization task of finding the posterior mode (sometimes known as the poor man's Bayesian estimator, in contrast with full Bayesian inference, which approximates the full posterior distribution), for which variants of (stochastic) gradient descent serve as a simple, generic, yet extremely powerful toolbox. There has been a recent growth of interest in creating user-friendly variational inference tools [e.g., 4-7], but more effort is still needed to develop more efficient general purpose algorithms.
In this work, we propose a new general purpose variational inference algorithm which can be treated as a natural counterpart of gradient descent for full Bayesian inference (see Algorithm 1).

30th Conference on Neural Information Processing Systems (NIPS 2016), Barcelona, Spain.

Our algorithm uses a set of particles for approximation, on which a form of (functional) gradient descent is performed to minimize the KL divergence and drive the particles to fit the true posterior distribution. Our algorithm has a simple form, and can be applied whenever gradient descent can be applied.
In fact, it reduces to gradient descent for MAP when using only a single particle, while it automatically turns into a full Bayesian sampling approach with more particles.
Underlying our algorithm is a new theoretical result that connects the derivative of KL divergence w.r.t. smooth variable transforms with a recently introduced kernelized Stein discrepancy [8-10], which allows us to derive a closed form solution for the optimal smooth perturbation direction that gives the steepest descent on the KL divergence within the unit ball of a reproducing kernel Hilbert space (RKHS). This new result is of independent interest, and can find wide application in machine learning and statistics beyond variational inference.

2 Background

Preliminary   Let x be a continuous random variable or parameter of interest taking values in X ⊂ R^d, and let {D_k} be a set of i.i.d. observations. With prior p0(x), Bayesian inference of x involves reasoning with the posterior distribution p(x) := p̄(x)/Z, where p̄(x) := p0(x) ∏_{k=1}^N p(D_k | x) and Z = ∫ p̄(x) dx is the troublesome normalization constant. We have dropped the conditioning on the data {D_k} in p(x) for convenience.
Let k(x, x′) : X × X → R be a positive definite kernel. The reproducing kernel Hilbert space (RKHS) H of k(x, x′) is the closure of the linear span {f : f(x) = Σ_{i=1}^m a_i k(x, x_i), a_i ∈ R, m ∈ N, x_i ∈ X}, equipped with inner product ⟨f, g⟩_H = Σ_{ij} a_i b_j k(x_i, x_j) for g(x) = Σ_i b_i k(x, x_i). Denote by H^d the space of vector functions f = [f_1, . . . , f_d]^⊤ with f_i ∈ H, equipped with inner product ⟨f, g⟩_{H^d} = Σ_{i=1}^d ⟨f_i, g_i⟩_H. We assume all vectors are column vectors, and let ∇_x f = [∇_x f_1, . . . , ∇_x f_d].

Stein's Identity and Kernelized Stein Discrepancy   Stein's identity plays a fundamental role in our framework.
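As a concrete illustration of this setup, the sketch below (not from the paper; it assumes a toy conjugate model with prior p0 = N(0, 1) and likelihoods p(D_k | x) = N(D_k; x, 1)) computes the log of the unnormalized posterior p̄ and its score, neither of which ever involves Z:

```python
import numpy as np

# Toy conjugate model (illustrative, not from the paper):
# prior p0 = N(0, 1), likelihoods p(D_k | x) = N(D_k; x, 1).
# pbar(x) = p0(x) * prod_k p(D_k | x); Z = int pbar(x) dx is never needed.
D = np.array([0.5, 1.0, 1.5])

def log_pbar(x):
    # log p0(x) + sum_k log p(D_k | x), up to additive constants
    return -0.5 * x ** 2 - 0.5 * np.sum((D - x) ** 2)

def score(x):
    # d/dx log p(x) = d/dx log pbar(x) = -x + sum_k (D_k - x)
    return -x + np.sum(D - x)

# For this model the posterior is N(sum(D)/(N+1), 1/(N+1));
# the score vanishes exactly at the posterior mode.
print(score(np.sum(D) / (len(D) + 1)))  # 0.0
```

The same pattern (write down log p̄, differentiate, ignore Z) is all that the algorithms below require of the target distribution.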
Let p(x) be a continuously differentiable (also called smooth) density supported on X ⊆ R^d, and φ(x) = [φ_1(x), · · · , φ_d(x)]^⊤ a smooth vector function. Stein's identity states that for sufficiently regular φ, we have

E_{x∼p}[A_p φ(x)] = 0,   where   A_p φ(x) = ∇_x log p(x) φ(x)^⊤ + ∇_x φ(x),   (1)

where A_p is called the Stein operator, which acts on the function φ and yields a zero-mean function A_p φ(x) under x ∼ p. This identity can be easily checked using integration by parts, assuming mild zero boundary conditions on φ: either p(x)φ(x) = 0, ∀x ∈ ∂X when X is compact, or lim_{r→∞} ∮_{B_r} p(x) φ(x)^⊤ n(x) dS = 0 when X = R^d, where ∮_{B_r} is the surface integral on the sphere B_r of radius r centered at the origin and n(x) is the unit normal to B_r. We say that φ is in the Stein class of p if Stein's identity (1) holds.
Now let q(x) be a different smooth density also supported in X, and consider the expectation of A_p φ(x) under x ∼ q; then E_{x∼q}[A_p φ(x)] no longer equals zero for general φ. Instead, the magnitude of E_{x∼q}[A_p φ(x)] relates to how different p and q are, and can be leveraged to define a discrepancy measure, known as Stein discrepancy, by considering the "maximum violation of Stein's identity" for φ in some proper function set F:

D(q, p) = max_{φ∈F} { E_{x∼q}[trace(A_p φ(x))] }.

Here the choice of the function set F is critical, and decides the discriminative power and computational tractability of the Stein discrepancy.
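Stein's identity (1) is easy to check numerically; a minimal Monte Carlo sketch (illustrative, with p = N(0, 1), so ∇_x log p(x) = −x, and the test function φ(x) = sin x):

```python
import numpy as np

rng = np.random.default_rng(0)

# p = N(0, 1), so the score is d/dx log p(x) = -x; take phi(x) = sin(x).
x = rng.standard_normal(200_000)

# Stein operator in 1-D: (A_p phi)(x) = score(x) * phi(x) + phi'(x)
stein = (-x) * np.sin(x) + np.cos(x)

# Stein's identity (1): the expectation under p is zero (up to MC error).
print(abs(stein.mean()))  # close to 0
```

Replacing the samples from p with samples from a different density q makes this mean drift away from zero, which is exactly what the Stein discrepancy measures.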
Traditionally, F is taken to be a set of functions with bounded Lipschitz norms, which unfortunately casts a challenging functional optimization problem that is computationally intractable or requires special considerations (see Gorham and Mackey [11] and references therein).
Kernelized Stein discrepancy (KSD) bypasses this difficulty by maximizing φ in the unit ball of a reproducing kernel Hilbert space (RKHS), for which the optimization has a closed form solution. KSD is defined as

D(q, p) = max_{φ∈H^d} { E_{x∼q}[trace(A_p φ(x))],   s.t.   ||φ||_{H^d} ≤ 1 },   (2)

where we assume the kernel k(x, x′) of RKHS H is in the Stein class of p as a function of x for any fixed x′ ∈ X. The optimal solution of (2) has been shown to be φ(x) = φ*_{q,p}(x)/||φ*_{q,p}||_{H^d} [8-10], where

φ*_{q,p}(·) = E_{x∼q}[A_p k(x, ·)],   for which we have   D(q, p) = ||φ*_{q,p}||_{H^d}.   (3)

One can further show that D(q, p) equals zero (and equivalently φ*_{q,p}(x) ≡ 0) if and only if p = q, once k(x, x′) is strictly positive definite in a proper sense [see 8, 10], which is satisfied by commonly used kernels such as the RBF kernel k(x, x′) = exp(−(1/h)||x − x′||²₂). Note that the RBF kernel is also in the Stein class of smooth densities supported in X = R^d because of its decaying property.
Both the Stein operator and KSD depend on p only through the score function ∇_x log p(x), which can be calculated without knowing the normalization constant of p, because we have ∇_x log p(x) = ∇_x log p̄(x) when p(x) = p̄(x)/Z.
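The closed form (3) makes φ*_{q,p} straightforward to estimate from samples of q; a small sketch with illustrative choices (p = N(0, 1), q = N(1, 1), and a 1-D RBF kernel with h = 2):

```python
import numpy as np

rng = np.random.default_rng(0)

def rbf(x, z, h=2.0):
    # k(x, z) = exp(-(x - z)^2 / h), 1-D RBF kernel
    return np.exp(-(x - z) ** 2 / h)

def phi_star_hat(z, xs, score, h=2.0):
    # Empirical version of (3): phi*_{q,p}(z) = E_{x~q}[ score(x) k(x, z) + d/dx k(x, z) ]
    k = rbf(xs, z, h)
    dk = -2.0 * (xs - z) / h * k        # derivative of k in its first argument x
    return np.mean(score(xs) * k + dk)

score = lambda x: -x                    # score of p = N(0, 1)
xs = rng.normal(1.0, 1.0, 100_000)      # samples from q = N(1, 1)

# q sits to the right of p, so the steepest-descent direction evaluated
# at the center of q points left, toward p: the estimate is negative.
print(phi_star_hat(1.0, xs, score))
```

This is exactly the quantity that drives the variational updates derived in the next section.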
This property makes Stein's identity a powerful tool for handling unnormalized distributions that appear widely in machine learning and statistics.

3 Variational Inference Using Smooth Transforms

Variational inference approximates the target distribution p(x) using a simpler distribution q*(x) found in a predefined set Q = {q(x)} of distributions by minimizing the KL divergence, that is,

q* = arg min_{q∈Q} { KL(q || p) ≡ E_q[log q(x)] − E_q[log p̄(x)] + log Z },   (4)

where we do not need to calculate the constant log Z to solve the optimization. The choice of the set Q is critical and defines different types of variational inference methods. The best set Q should strike a balance between i) accuracy, being broad enough to closely approximate a large class of target distributions, ii) tractability, consisting of simple distributions that are easy for inference, and iii) solvability, so that the subsequent KL minimization problem can be efficiently solved.
In this work, we focus on sets Q consisting of distributions obtained by smooth transforms of a tractable reference distribution; that is, we take Q to be the set of distributions of random variables of the form z = T(x), where T : X → X is a smooth one-to-one transform and x is drawn from a tractable reference distribution q0(x). By the change of variables formula, the density of z is

q_[T](z) = q(T⁻¹(z)) · |det(∇_z T⁻¹(z))|,

where T⁻¹ denotes the inverse map of T and ∇_z T⁻¹ the Jacobian matrix of T⁻¹. Such distributions are computationally tractable, in the sense that expectations under q_[T] can be easily evaluated by averaging over {z_i} with z_i = T(x_i) and x_i ∼ q0. Such a Q can also in principle closely approximate almost arbitrary distributions: it can be shown that there always exists a measurable transform T between any two distributions without atoms (i.e.
no single point carries a positive mass); in addition, for Lipschitz continuous densities p and q, there always exist transforms between them that are at least as smooth as both p and q. We refer the reader to Villani [12] for an in-depth discussion of this topic.
In practice, however, we need to restrict the set of transforms T properly to make the corresponding variational optimization in (4) practically solvable. One approach is to consider T of a certain parametric form and optimize the corresponding parameters [e.g., 13, 14]. However, this introduces a difficult problem of selecting the proper parametric family to balance accuracy, tractability and solvability, especially considering that T has to be a one-to-one map and has to have an efficiently computable Jacobian matrix.
Instead, we propose a new algorithm that iteratively constructs incremental transforms that effectively perform steepest descent on T in RKHS. Our algorithm does not require explicitly specifying parametric forms or calculating the Jacobian matrix, and has a particularly simple form that mimics the typical gradient descent algorithm, making it easily implementable even for non-experts in variational inference.

3.1 Stein Operator as the Derivative of KL Divergence

To explain how we minimize the KL divergence in (4), we consider an incremental transform formed by a small perturbation of the identity map: T(x) = x + εφ(x), where φ(x) is a smooth function that characterizes the perturbation direction and the scalar ε represents the perturbation magnitude. When |ε| is sufficiently small, the Jacobian of T is full rank (close to the identity matrix), and hence T is guaranteed to be a one-to-one map by the inverse function theorem.
The following result, which forms the foundation of our method, draws an insightful connection between the Stein operator and the derivative of KL divergence w.r.t.
the perturbation magnitude ε.

Theorem 3.1. Let T(x) = x + εφ(x) and let q_[T](z) be the density of z = T(x) when x ∼ q(x). We have

∇_ε KL(q_[T] || p) |_{ε=0} = −E_{x∼q}[trace(A_p φ(x))],   (5)

where A_p φ(x) = ∇_x log p(x) φ(x)^⊤ + ∇_x φ(x) is the Stein operator.

Relating this to the definition of KSD in (2), we can identify φ*_{q,p} in (3) as the optimal perturbation direction that gives the steepest descent on the KL divergence in zero-centered balls of H^d.

Lemma 3.2. Assume the conditions in Theorem 3.1. Among all perturbation directions φ in the ball B = {φ ∈ H^d : ||φ||_{H^d} ≤ D(q, p)} of the vector-valued RKHS H^d, the direction of steepest descent that maximizes the negative gradient in (5) is φ*_{q,p} in (3), i.e.,

φ*_{q,p}(·) = E_{x∼q}[∇_x log p(x) k(x, ·) + ∇_x k(x, ·)],   (6)

for which the negative gradient in (5) equals the square of the KSD, that is, ∇_ε KL(q_[T] || p) |_{ε=0} = −D²(q, p).

The result in Lemma 3.2 suggests an iterative procedure that transforms an initial reference distribution q0 to the target distribution p: we start by applying the transform T*_0(x) = x + ε_0 · φ*_{q0,p}(x) on q0, which decreases the KL divergence by an amount of ε_0 · D²(q0, p), where ε_0 is a small step size; this gives a new distribution q1(x) = q0_[T*_0](x), on which a further transform T*_1(x) = x + ε_1 · φ*_{q1,p}(x) can further decrease the KL divergence by ε_1 · D²(q1, p).
Repeating this process, one constructs a path of distributions {q_ℓ}^n_{ℓ=1} between q0 and p via

q_{ℓ+1} = q_ℓ[T*_ℓ],   where   T*_ℓ(x) = x + ε_ℓ · φ*_{q_ℓ,p}(x).   (7)

This eventually converges to the target p with sufficiently small step sizes {ε_ℓ}, under which φ*_{q∞,p}(x) ≡ 0 and T*_∞ reduces to the identity map. Recall that q∞ = p if and only if φ*_{q∞,p}(x) ≡ 0.

Functional Gradient   To gain further intuition on this process, we now reinterpret (6) as a functional gradient in RKHS. For any functional F[f] of f ∈ H^d, its (functional) gradient ∇_f F[f] is a function in H^d such that F[f + εg(x)] = F[f] + ε ⟨∇_f F[f], g⟩_{H^d} + O(ε²) for any g ∈ H^d and ε ∈ R.

Theorem 3.3. Let T(x) = x + f(x), where f ∈ H^d, and let q_[T] be the density of z = T(x) when x ∼ q. We have

∇_f KL(q_[T] || p) |_{f=0} = −φ*_{q,p}(x),

whose RKHS norm is ||φ*_{q,p}||_{H^d} = D(q, p).

This suggests that T*(x) = x + ε · φ*_{q,p}(x) is equivalent to a step of functional gradient descent in RKHS. However, what is critical in the iterative procedure (7) is that we also iteratively apply the variable transform, so that we only need to evaluate the functional gradient at zero perturbation f = 0 on the identity map T(x) = x. This brings a critical advantage, since the gradient at f ≠ 0 is more complex and would require calculating the inverse Jacobian matrix [∇_x T(x)]⁻¹, which casts computational and implementation hurdles.

3.2 Stein Variational Gradient Descent

To implement the iterative procedure (7) in practice, one needs to approximate the expectation for calculating φ*_{q,p}(x) in (6). To do this, we first draw a set of particles {x_i^0}^n_{i=1} from the initial distribution q0, and then iteratively update the particles with an empirical version of the transform in (7), in which the expectation under q_ℓ in φ*_{q_ℓ,p} is approximated by the empirical mean over the particles {x_i^ℓ}^n_{i=1} at the ℓ-th iteration. This procedure is summarized in Algorithm 1, which allows us to (deterministically) transport a set of points to match our target distribution p(x), effectively providing a sampling method for p(x).

Algorithm 1 Bayesian Inference via Variational Gradient Descent
  Input: A target distribution with density function p(x) and a set of initial particles {x_i^0}^n_{i=1}.
  Output: A set of particles {x_i}^n_{i=1} that approximates the target distribution p(x).
  for iteration ℓ do
      x_i^{ℓ+1} ← x_i^ℓ + ε_ℓ φ̂*(x_i^ℓ),   where   φ̂*(x) = (1/n) Σ^n_{j=1} [ k(x_j^ℓ, x) ∇_{x_j^ℓ} log p(x_j^ℓ) + ∇_{x_j^ℓ} k(x_j^ℓ, x) ],   (8)
      and ε_ℓ is the step size at the ℓ-th iteration.
  end for

We can see that the implementation of this procedure does not depend on the initial distribution q0 at all, and in practice we can start with a set of arbitrary points {x_i}^n_{i=1}, possibly generated by a complex (random or deterministic) black-box procedure.
We can expect that {x_i^ℓ}^n_{i=1} forms an increasingly better approximation of q_ℓ as n increases. To see this, denote by Φ_ℓ the nonlinear map that takes the measure of q_ℓ and outputs that of q_{ℓ+1} in (7), that is, q_{ℓ+1} = Φ_ℓ(q_ℓ), where q_ℓ enters the map through both q_ℓ[T*_ℓ] and φ*_{q_ℓ,p}. The updates in Algorithm 1 can then be seen as applying the same map Φ_ℓ on the empirical measure q̂_ℓ of the particles {x_i^ℓ} to get the empirical measure q̂_{ℓ+1} of the particles {x_i^{ℓ+1}} at the next iteration, that is, q̂_{ℓ+1} = Φ_ℓ(q̂_ℓ). Since q̂_0 converges to q0 as n increases, q̂_ℓ should also converge to q_ℓ when the map Φ_ℓ is "continuous" in a proper sense. Rigorous theoretical results on such convergence have been established in the mean field theory of interacting particle systems [e.g., 15], which in general guarantee that Σ^n_{i=1} h(x_i^ℓ)/n − E_{q_ℓ}[h(x)] = O(1/√n) for bounded test functions h. In addition, the distribution of each particle x_{i0}^ℓ, for any fixed i0, also tends to q_ℓ, and becomes independent of any other finite subset of particles as n → ∞, a phenomenon called propagation of chaos [16]. We leave concrete theoretical analysis for future work.
Algorithm 1 mimics gradient dynamics at the particle level, where the two terms of φ̂*(x) in (8) play different roles: the first term drives the particles towards the high probability regions of p(x) by following a smoothed gradient direction, which is the weighted sum of the gradients of all the points, weighted by the kernel function.
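Update (8) takes only a few lines of code; below is a minimal 1-D sketch (not the authors' released implementation: it uses a fixed bandwidth h = 1 and a fixed step size rather than the adaptive choices used in the experiments, and a standard normal target):

```python
import numpy as np

def svgd_step(x, score, eps=0.05, h=1.0):
    # Empirical update (8): x_i <- x_i + eps * phi_hat(x_i), with
    # phi_hat(x_i) = (1/n) sum_j [ k(x_j, x_i) score(x_j) + grad_{x_j} k(x_j, x_i) ]
    diff = x[:, None] - x[None, :]        # diff[j, i] = x_j - x_i
    K = np.exp(-diff ** 2 / h)            # RBF kernel matrix k(x_j, x_i)
    grad_K = -2.0 * diff / h * K          # gradient of k in its first argument x_j
    phi = (K * score(x)[:, None] + grad_K).mean(axis=0)
    return x + eps * phi

rng = np.random.default_rng(0)
x = rng.normal(-10.0, 1.0, 100)           # particles start far from the target
score = lambda x: -x                      # target p = N(0, 1)

for _ in range(1000):
    x = svgd_step(x, score)

print(x.mean(), x.std())  # roughly 0 and 1: the particles match N(0, 1)
```

Even though the particles start far from the target, the first (driving) term carries them toward high-probability regions while the second (repulsive) term keeps them spread out instead of collapsing to the mode.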
The second term acts as a repulsive force that prevents all the points from collapsing together into local modes of p(x); to see this, consider the RBF kernel k(x, x′) = exp(−(1/h)||x − x′||²): the second term reduces to Σ_j (2/h)(x − x_j) k(x_j, x), which drives x away from its neighboring points x_j that have large k(x_j, x). If we let the bandwidth h → 0, the repulsive term vanishes, and update (8) reduces to a set of independent chains of typical gradient ascent for maximizing log p(x) (i.e., MAP), and all the particles collapse into the local modes.
Another interesting case is when we use only a single particle (n = 1), in which case Algorithm 1 reduces to a single chain of typical gradient ascent for MAP for any kernel that satisfies ∇_x k(x, x) = 0 (which holds for the RBF kernel). This suggests that our algorithm can generalize well for supervised learning tasks even with a very small number n of particles, since gradient ascent for MAP (n = 1) has been shown to be very successful in practice. This property distinguishes our particle method from typical Monte Carlo methods, which require averaging over many points. The key difference here is that we use a deterministic repulsive force, rather than Monte Carlo randomness, to obtain diverse points for distributional approximation.

Complexity and Efficient Implementation   The major computational bottleneck in (8) lies in calculating the gradient ∇_x log p(x) for all the points {x_i}^n_{i=1}; this is especially the case in big data settings when p(x) ∝ p0(x) ∏^N_{k=1} p(D_k|x) with a very large N. We can conveniently address this problem by approximating ∇_x log p(x) with subsampled mini-batches Ω ⊂ {1, . . .
, N} of the data:

∇_x log p(x) ≈ ∇_x log p0(x) + (N/|Ω|) Σ_{k∈Ω} ∇_x log p(D_k | x).   (9)

Additional speedup can be obtained by parallelizing the gradient evaluation over the n particles.
The update (8) also requires computing the kernel matrix {k(x_i, x_j)}, which costs O(n²); in practice, this cost can be relatively small compared with the cost of gradient evaluation, since it can be sufficient to use a relatively small n (e.g., several hundred). If there is a need for very large n, one can approximate the summation Σ^n_{i=1} in (8) by subsampling the particles, or by using a random feature expansion of the kernel k(x, x′) [17].

4 Related Works

Our work is most closely related to Rezende and Mohamed [13], which also considers variational inference over a set of transformed random variables, but focuses on transforms of the parametric form T(x) = f_ℓ(· · · (f_1(f_0(x)))), where f_i(·) is a predefined simple parametric transform and ℓ a predefined length; this essentially creates a feedforward neural network with ℓ layers, whose invertibility requires further conditions on the parameters and needs to be established case by case. A similar idea is discussed in Marzouk et al. [14], which also considers transforms parameterized in special ways to ensure invertibility and the computational tractability of the Jacobian matrix. Recently, Tran et al. [18] constructed a variational family that achieves universal approximation based on a Gaussian process (equivalent to a single-layer, infinitely-wide neural network), which does not have a Jacobian matrix but needs to calculate the inverse of the kernel matrix of the Gaussian process. Our algorithm has a simpler form, and does not require calculating any matrix determinant or inversion.
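The stochastic approximation (9) can be sketched on the same kind of toy conjugate model as before (assumed for illustration: p0 = N(0, 1) and p(D_k | x) = N(D_k; x, 1)); the mini-batch estimator is an unbiased estimate of the full-data score:

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy conjugate model (illustrative): p0 = N(0, 1), p(D_k | x) = N(D_k; x, 1).
D = rng.normal(1.0, 1.0, 10_000)         # full dataset, N = 10,000
N = len(D)

def score_full(x):
    return -x + np.sum(D - x)            # exact d/dx log pbar(x)

def score_minibatch(x, batch=50):
    # Stochastic estimate (9):
    # grad log p0(x) + (N/|Omega|) * sum_{k in Omega} grad log p(D_k | x)
    idx = rng.choice(N, size=batch, replace=False)
    return -x + (N / batch) * np.sum(D[idx] - x)

# Unbiasedness: averaging a few mini-batch draws recovers the full score.
est = np.mean([score_minibatch(0.3) for _ in range(200)])
print(est, score_full(0.3))  # the two roughly agree
```

Each particle's gradient in (8) can use an independent mini-batch, so the per-iteration cost scales with |Ω| rather than N.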
Several other works also leverage variable transforms in variational inference, but with more limited forms; examples include affine transforms [19, 20], and recently the copula models that correspond to element-wise transforms over the individual variables [21, 22].
Our algorithm maintains and updates a set of particles, and is similar in style to the Gaussian mixture variational inference methods whose mean parameters can be treated as a set of particles [23-26, 5]. Optimizing such mixture KL objectives often requires certain approximations; this was done most recently in Gershman et al. [5] by approximating the entropy using Jensen's inequality and the expectation term using a Taylor approximation. There is also a large set of particle-based Monte Carlo methods, including variants of sequential Monte Carlo [e.g., 27, 28], as well as a recent particle mirror descent for optimizing the variational objective function [7]; compared with these methods, our method does not have the weight degeneration problem, and is much more "particle-efficient" in that we reduce to MAP with only one single particle.

5 Experiments

We test our algorithm on both toy and real world examples, on which we find our method tends to outperform a variety of baseline methods. Our code is available at https://github.com/DartML/Stein-Variational-Gradient-Descent.
For all our experiments, we use the RBF kernel k(x, x′) = exp(−(1/h)||x − x′||²₂), and take the bandwidth to be h = med²/log n, where med is the median of the pairwise distances between the current points {x_i}^n_{i=1}; this is based on the intuition that we would then have Σ_j k(x_i, x_j) ≈ n exp(−(1/h) med²) = 1, so that for each x_i the contribution from its own gradient and the influence from the other points balance with each other. Note that in this way the bandwidth h actually changes adaptively across the iterations.
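The median-distance bandwidth choice above can be sketched as follows (1-D, illustrative):

```python
import numpy as np

def median_bandwidth(x):
    # h = med^2 / log n: then exp(-(1/h) med^2) = exp(-log n) = 1/n, so for
    # each particle sum_j k(x_i, x_j) ~ n * (1/n) = 1, balancing its own
    # gradient against the combined influence of the other particles.
    n = len(x)
    dists = np.abs(x[:, None] - x[None, :])          # pairwise distances (1-D)
    med = np.median(dists[np.triu_indices(n, k=1)])  # median over distinct pairs
    return med ** 2 / np.log(n)

rng = np.random.default_rng(0)
x = rng.normal(0.0, 1.0, 100)
h = median_bandwidth(x)
print(h)  # a positive bandwidth, recomputed adaptively at every iteration
```

Because the particles move at every iteration, recomputing h this way automatically shrinks or widens the kernel as the particle cloud contracts or spreads.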
We use AdaGrad for the step size, and initialize the particles using the prior distribution unless otherwise specified.

Toy Example on 1D Gaussian Mixture   We set our target distribution to be p(x) = (1/3) N(x; −2, 1) + (2/3) N(x; 2, 1), and initialize the particles using q0(x) = N(x; −10, 1). This creates a challenging situation, since the probability mass of p(x) and that of q0(x) are far away from each other (with almost zero overlap). Figure 1 shows how the distribution of the particles (n = 100) of our method evolves across iterations. We see that despite the small overlap between q0(x) and p(x), our method can push the particles towards the target distribution, and even recover the mode that is further away from the initial point. We found that other particle based algorithms, such as Dai et al. [7], tend to experience weight degeneracy on this toy example due to the ill choice of q0(x).
Figure 2 compares our method with Monte Carlo sampling when using the obtained particles to estimate expectations E_p(h(x)) with different test functions h(·). We see that the MSE of our method tends to be similar to or better than that of exact Monte Carlo sampling. This may be because our particles are more spread out than i.i.d. samples due to the repulsive force, and hence give higher estimation accuracy. It remains an open question to formally establish the error rate of our method.

Figure 1: Toy example with 1D Gaussian mixture. The red dashed lines are the target density function and the solid green lines are the densities of the particles at different iterations of our algorithm (estimated using a kernel density estimator). Note that the initial distribution is set to have almost zero overlap with the target distribution, and our method demonstrates the ability of escaping the local mode on the left to recover the mode on the right that is further away.
We use n = 100 particles.

(a) Estimating E(x)   (b) Estimating E(x²)   (c) Estimating E(cos(ωx + b))

Figure 2: We use the same setting as Figure 1, except varying the number n of particles. (a)-(c) show the mean square errors when using the obtained particles to estimate the expectation E_p(h(x)) for h(x) = x, x², and cos(ωx + b); for cos(ωx + b), we draw ω ∼ N(0, 1) and b ∼ Uniform([0, 2π]) and report the average MSE over 20 random draws of ω and b.

Bayesian Logistic Regression   We consider Bayesian logistic regression for binary classification, using the same setting as Gershman et al. [5], which assigns the regression weights w a Gaussian prior p0(w|α) = N(w; 0, α⁻¹) with p0(α) = Gamma(α; 1, 0.01). The inference is applied on the posterior p(x|D) with x = [w, log α]. We compared our algorithm with the no-U-turn sampler (NUTS)1 [29] and non-parametric variational inference (NPV)2 [5] on the 8 datasets (N > 500) used in Gershman et al. [5], and find that they tend to give very similar results on these (relatively simple) datasets; see the Appendix for more details.
We further test the binary Covertype dataset3 with 581,012 data points and 54 features. This dataset is too large, and stochastic gradients are needed for speed. Because NUTS and NPV do not have a mini-batch option in their code, we instead compare with the stochastic gradient Langevin dynamics (SGLD) of Welling and Teh [2], the particle mirror descent (PMD) of Dai et al. [7], and the doubly stochastic variational inference (DSVI) of Titsias and Lázaro-Gredilla [19].4 We also compare with a parallel version of SGLD that runs n parallel chains and takes the last point of each chain as the result.
This parallel SGLD is similar to our method, and for fair comparison we use the same step size ε_ℓ = a/(t + 1)^0.55 for both, as suggested by Welling and Teh [2];5 we select a using a validation set within the training set. For PMD, we use a step size of a√N/(100 + √t) and the RBF kernel k(x, x′) = exp(−||x − x′||²/h) with bandwidth h = 0.002 × med², based on the guidance of Dai et al. [7], which we find works most efficiently for PMD. Figure 3(a)-(b) shows the results when we initialize our method and both versions of SGLD using the prior p0(α)p0(w|α); we find that PMD tends to be unstable with this initialization because it generates weights w with large magnitudes, so we divided the initialized weights by 10 for PMD; as shown in Figure 3(a), this gives some advantage to PMD in the initial stage. We find our method generally performs the best, followed by the parallel SGLD, which is much better than its sequential counterpart; this comparison is of course in favor of parallel SGLD, since each iteration of it requires n = 100 times as many likelihood evaluations as sequential SGLD.
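For reference, the SGLD baseline update of Welling and Teh [2], with the annealed step size above, can be sketched as follows (illustrative 1-D target N(0, 1), not the Covertype experiment):

```python
import numpy as np

rng = np.random.default_rng(0)

# SGLD update: x <- x + (eps_t / 2) * grad log p(x) + N(0, eps_t),
# with annealed step size eps_t = a / (t + 1)**0.55.
score = lambda x: -x               # target p = N(0, 1)
x = rng.normal(-5.0, 1.0, 100)     # 100 parallel chains, as in parallel SGLD
a = 0.5

for t in range(2000):
    eps = a / (t + 1) ** 0.55
    x = x + 0.5 * eps * score(x) + np.sqrt(eps) * rng.standard_normal(x.shape)

print(x.mean(), x.std())           # last iterates roughly follow N(0, 1)
```

Unlike Algorithm 1, the chains interact only through shared hyperparameters: diversity among the final points comes from the injected Gaussian noise rather than a deterministic repulsive force.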
However, by leveraging matrix operations in MATLAB, we find that each iteration of parallel SGLD is only 3 times more expensive than sequential SGLD.

1 code: http://www.cs.princeton.edu/ mdhoffma/
2 code: http://gershmanlab.webfactional.com/pubs/npv.v1.zip
3 https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary.html
4 code: http://www.aueb.gr/users/mtitsias/code/dsvi_matlabv1.zip.
5 We scale the gradient of SGLD by a factor of 1/n to make it match the scale of our gradient in (8).

[Plot data for Figures 1 and 2 omitted: Figure 1 shows particle densities at iterations 0, 50, 75, 100, 150 and 500; Figure 2 plots log10 MSE against sample size n ∈ {10, 50, 250} for Monte Carlo and Stein Variational Gradient Descent.]

(a) Particle size n = 100   (b) Results at 3000 iterations (≈ 0.32 epochs)

Figure 3: Results of Bayesian logistic regression on the Covertype dataset w.r.t. epochs and the particle size n. We use n = 100 particles for our method, parallel SGLD and PMD, and average the last 100 points for the sequential SGLD. The "particle-based" methods (solid lines) in principle require 100 times as many likelihood evaluations as DSVI and sequential SGLD (dashed lines) per iteration, but are implemented efficiently using MATLAB matrix operations (e.g., each iteration of parallel SGLD is about 3 times slower than sequential SGLD). We partition the data into 80% for training and 20% for testing and average over 50 random trials. A mini-batch size of 50 is used for all the algorithms.

Bayesian Neural Network   We compare our algorithm with the probabilistic back-propagation (PBP) algorithm of Hernández-Lobato and Adams [30] on Bayesian neural networks.
Our experiment settings are almost identical, except that we use a Gamma(1, 0.1) prior for the inverse covariances and do not use the trick of scaling the input of the output layer. We use neural networks with one hidden layer, and take 50 hidden units for most datasets, except that we take 100 units for Protein and Year, which are relatively large; all the datasets are randomly partitioned into 90% for training and 10% for testing, and the results are averaged over 20 random trials, except for Protein and Year, on which 5 and 1 trials are repeated, respectively. We use ReLU(x) = max(0, x) as the activation function, whose weak derivative is I[x > 0] (Stein's identity also holds for weak derivatives; see Stein et al. [31]). PBP is run using the default settings of the authors' code.6 For our algorithm, we use only 20 particles, and use AdaGrad with momentum, as is standard in deep learning. The mini-batch size is 100, except for Year, on which we use 1000.

We find that our algorithm consistently improves over PBP both in terms of accuracy and speed; this is encouraging since PBP was specifically designed for Bayesian neural networks. We also find that our results are comparable with the more recent results reported on the same datasets [e.g., 32–34], which leverage advanced techniques that we could also benefit from.

Dataset    Avg. Test RMSE                   Avg. Test LL                       Avg. Time (Secs)
           PBP            Our Method        PBP             Our Method         PBP    Ours
Boston     2.977 ± 0.093  2.957 ± 0.099     −2.579 ± 0.052  −2.504 ± 0.029     18     16
Concrete   5.506 ± 0.103  5.324 ± 0.104     −3.137 ± 0.021  −3.082 ± 0.018     33     24
Energy     1.734 ± 0.051  1.374 ± 0.045     −1.981 ± 0.028  −1.767 ± 0.024     25     21
Kin8nm     0.098 ± 0.001  0.090 ± 0.001      0.901 ± 0.010   0.984 ± 0.008     118    41
Naval      0.006 ± 0.000  0.004 ± 0.000      3.735 ± 0.004   4.089 ± 0.012     173    49
Combined   4.052 ± 0.031  4.033 ± 0.033     −2.819 ± 0.008  −2.815 ± 0.008     136    51
Protein    4.623 ± 0.009  4.606 ± 0.013     −2.950 ± 0.002  −2.947 ± 0.003     682    68
Wine       0.614 ± 0.008  0.609 ± 0.010     −0.931 ± 0.014  −0.925 ± 0.014     26     22
Yacht      0.778 ± 0.042  0.864 ± 0.052     −1.225 ± 0.042  −1.211 ± 0.044     25     25
Year       8.733 ± NA     8.684 ± NA        −3.586 ± NA     −3.580 ± NA        7777   684

6 Conclusion

We propose a simple general purpose variational inference algorithm for fast and scalable Bayesian inference. Future directions include more theoretical understanding of our method, more practical applications in deep learning models, and other potential applications of our basic theorem in Section 3.1.

Acknowledgement This work is supported in part by NSF CRII 1565796.

6 https://github.com/HIPS/Probabilistic-Backpropagation

References
[1] M. D. Hoffman, D. M. Blei, C. Wang, and J. Paisley. Stochastic variational inference. JMLR, 2013.
[2] M. Welling and Y. W. Teh. Bayesian learning via stochastic gradient Langevin dynamics. In ICML, 2011.
[3] D. Maclaurin and R. P. Adams. Firefly Monte Carlo: Exact MCMC with subsets of data. In UAI, 2014.
[4] R. Ranganath, S. Gerrish, and D. M. Blei. Black box variational inference. In AISTATS, 2014.
[5] S. Gershman, M. Hoffman, and D. Blei. Nonparametric variational inference. In ICML, 2012.
[6] A. Kucukelbir, R. Ranganath, A. Gelman, and D. Blei. Automatic variational inference in STAN. In NIPS, 2015.
[7] B. Dai, N. He, H. Dai, and L. Song. Provable Bayesian inference via particle mirror descent. In AISTATS, 2016.
[8] Q. Liu, J. D. Lee, and M. I. Jordan. A kernelized Stein discrepancy for goodness-of-fit tests and model evaluation. arXiv preprint arXiv:1602.03253, 2016.
[9] C. J. Oates, M. Girolami, and N. Chopin. Control functionals for Monte Carlo integration. Journal of the Royal Statistical Society, Series B, 2017.
[10] K.
Chwialkowski, H. Strathmann, and A. Gretton. A kernel test of goodness-of-fit. arXiv preprint arXiv:1602.02964, 2016.
[11] J. Gorham and L. Mackey. Measuring sample quality with Stein's method. In NIPS, pages 226–234, 2015.
[12] C. Villani. Optimal transport: old and new, volume 338. Springer Science & Business Media, 2008.
[13] D. J. Rezende and S. Mohamed. Variational inference with normalizing flows. In ICML, 2015.
[14] Y. Marzouk, T. Moselhy, M. Parno, and A. Spantini. An introduction to sampling via measure transport. arXiv preprint arXiv:1602.05023, 2016.
[15] P. Del Moral. Mean field simulation for Monte Carlo integration. CRC Press, 2013.
[16] M. Kac. Probability and related topics in physical sciences, volume 1. American Mathematical Soc., 1959.
[17] A. Rahimi and B. Recht. Random features for large-scale kernel machines. In NIPS, pages 1177–1184, 2007.
[18] D. Tran, R. Ranganath, and D. M. Blei. Variational Gaussian process. In ICLR, 2016.
[19] M. Titsias and M. Lázaro-Gredilla. Doubly stochastic variational Bayes for non-conjugate inference. In ICML, pages 1971–1979, 2014.
[20] E. Challis and D. Barber. Affine independent variational inference. In NIPS, 2012.
[21] S. Han, X. Liao, D. B. Dunson, and L. Carin. Variational Gaussian copula inference. In AISTATS, 2016.
[22] D. Tran, D. M. Blei, and E. M. Airoldi. Copula variational inference. In NIPS, 2015.
[23] C. M. Bishop, N. Lawrence, T. Jaakkola, and M. I. Jordan. Approximating posterior distributions in belief networks using mixtures. In NIPS, 1998.
[24] T. S. Jaakkola and M. I. Jordan. Improving the mean field approximation via the use of mixture distributions. In Learning in graphical models, pages 163–173. MIT Press, 1999.
[25] N. D. Lawrence. Variational inference in probabilistic models. PhD thesis, University of Cambridge, 2001.
[26] T. D. Kulkarni, A. Saeedi, and S. Gershman. Variational particle approximations.
arXiv preprint arXiv:1402.5715, 2014.
[27] C. Robert and G. Casella. Monte Carlo statistical methods. Springer Science & Business Media, 2013.
[28] A. Smith, A. Doucet, N. de Freitas, and N. Gordon. Sequential Monte Carlo methods in practice. Springer Science & Business Media, 2013.
[29] M. D. Hoffman and A. Gelman. The No-U-Turn sampler: Adaptively setting path lengths in Hamiltonian Monte Carlo. The Journal of Machine Learning Research, 15(1):1593–1623, 2014.
[30] J. M. Hernández-Lobato and R. P. Adams. Probabilistic backpropagation for scalable learning of Bayesian neural networks. In ICML, 2015.
[31] C. Stein, P. Diaconis, S. Holmes, G. Reinert, et al. Use of exchangeable pairs in the analysis of simulations. In Stein's Method, pages 1–25. Institute of Mathematical Statistics, 2004.
[32] Y. Li, J. M. Hernández-Lobato, and R. E. Turner. Stochastic expectation propagation. In NIPS, 2015.
[33] Y. Li and R. E. Turner. Variational inference with Rényi divergence. arXiv preprint arXiv:1602.02311, 2016.
[34] Y. Gal and Z. Ghahramani. Dropout as a Bayesian approximation: Representing model uncertainty in deep learning. arXiv preprint arXiv:1506.02142, 2015.