# The Variational Predictive Natural Gradient

Da Tang (Department of Computer Science, Columbia University, New York, New York, USA) and Rajesh Ranganath (The Courant Institute, New York University, New York, New York, USA)

Proceedings of the 36th International Conference on Machine Learning, Long Beach, California, PMLR 97, 2019.

## Abstract

Variational inference transforms posterior inference into parametric optimization, thereby enabling the use of latent variable models where it would otherwise be impractical. However, variational inference can be finicky when different variational parameters control variables that are strongly correlated under the model. Traditional natural gradients based on the variational approximation fail to correct for these correlations when the approximation is not the true posterior. To address this, we construct a new natural gradient called the variational predictive natural gradient (VPNG). Unlike traditional natural gradients for variational inference, this natural gradient accounts for the relationship between model parameters and variational parameters. We demonstrate the insight with a simple example, as well as the empirical value on a classification task, a deep generative model of images, and probabilistic matrix factorization for recommendation.

## 1. Introduction

Variational inference (Jordan et al., 1999) transforms posterior inference in latent variable models into optimization. It posits a parametric approximating family and tries to find the distribution in this family that minimizes the Kullback-Leibler (KL) divergence to the posterior. Variational inference makes posterior computation practical where it would not be otherwise. It has powered many applications, including computational biology (Carbonetto et al., 2012; Stegle et al., 2010), language (Miao et al., 2016), compressive sensing (Shi et al., 2014), neuroscience (Manning et al., 2014; Harrison & Green, 2010), and medicine (Ranganath et al., 2016).

Variational inference requires choosing an approximating family. The variational family plus the model together define the variational objective. The variational objective can be optimized with stochastic gradients for a broad range of models (Kingma & Welling, 2014; Ranganath et al., 2014; Rezende et al., 2014). When the posterior has correlations, dimensions of the optimization problem become tied, i.e., there is curvature. One way to correct for curvature in optimization is to use natural gradients (Amari, 1998; Ollivier et al., 2011; Thomas et al., 2016). Natural gradients for variational inference (Hoffman et al., 2013) adjust for the non-Euclidean nature of probability distributions, but they may not change the gradient direction when the variational approximation is far from the posterior.

To deal with curvature induced by dependent observation dimensions in the variational objective, we define a new type of natural gradient: the variational predictive natural gradient (VPNG). The VPNG rescales the gradient with the inverse of the expected Fisher information matrix of the reparameterized model likelihood. We relate this matrix to the negative Hessian of the expected log-likelihood part of the evidence lower bound (ELBO), thereby showing that it captures the curvature of variational inference. Our new natural gradient captures potential pathological curvature introduced by the log-likelihood that the traditional natural gradient cannot capture.
Further, unlike traditional natural gradients for variational inference, the VPNG corrects for curvature in the objective between model parameters and variational parameters. In Section 3, we design an illustrative example where the VPNG points almost directly to the optimum, while both the vanilla gradient and the natural gradient point in a nearly orthogonal direction. We show our approach outperforms vanilla gradient optimization and traditional natural gradient optimization on several latent variable models, including Bayesian logistic regression on synthetic data, variational autoencoders (Kingma & Welling, 2014; Rezende et al., 2014) on images, and variational probabilistic matrix factorization (Mnih & Salakhutdinov, 2008; Gopalan et al., 2015; Liang et al., 2016) on movie recommendation data.

**Related work.** Variational inference has been transformed by the use of Monte Carlo gradient estimators (Kingma & Ba, 2014; Rezende et al., 2014; Mnih & Gregor, 2014; Ranganath et al., 2014; Titsias & Lázaro-Gredilla, 2014). Though these approaches expand the applicability of variational inference, the underlying optimization problem can still be hard. Some recent work applies second-order optimization to this problem. For example, Fan et al. (2015) derived Hessian-free style optimization for variational inference. Another line of related work is on efficiently computing Fisher information and natural gradients for complex model likelihoods, such as the K-FAC approximation (Martens & Grosse, 2015; Grosse & Martens, 2016; Ba et al., 2016). Finally, the VPNG can be combined with methods for robustly setting step sizes, like using the VPNG curvature matrix to build the quadratic approximation in TrustVI (Regier et al., 2017).

## 2. Background

**Latent variable models.** Latent variable models posit latent structure z to describe data x with parameters θ. The model is p(x, z) = p(z)p(x | z; θ). The model is split into a prior over the hidden structure p(z) and a likelihood that describes the probability of the data.

**Variational inference.** Variational inference (Jordan et al., 1999) approximates the posterior distribution p(z | x; θ) with a distribution q(z | x; λ) over the latent variables indexed by a parameter λ. It works by maximizing the ELBO:

$$\mathcal{L}(\lambda, \theta) = \mathbb{E}_q[\log p(x \mid z; \theta)] - \mathrm{KL}\big(q(z \mid x; \lambda)\,\|\,p(z)\big). \tag{1}$$

Maximizing the ELBO minimizes the KL divergence to the posterior. The model parameters θ and variational parameters λ can be optimized together. The family q is chosen to be amenable to stochastic optimization. One example is the mean-field family, where q(z | x) factorizes over all coordinates of z, as in the variational autoencoder.

**q-Fisher information.** The ELBO can be optimized with gradients. The effectiveness of gradient ascent methods relates to the geometry of the problem. When the loss landscape contains variables that control the objective in a coupled manner, like the means of two correlated latent variables, gradient ascent methods can be slow. One way to adjust for this coupling or curvature is to use natural gradients (Amari, 1998). Natural gradients account for the non-Euclidean geometry of parameters of probability distributions by looking for optimal ascent directions in symmetric KL-divergence balls. The natural gradient relies on the Fisher information of q,

$$F_q = \mathbb{E}_q\big[\nabla_\lambda \log q(z \mid x; \lambda)\,\nabla_\lambda \log q(z \mid x; \lambda)^\top\big]. \tag{2}$$

Figure 1. The VPNG is more effective than the vanilla gradient and the traditional natural gradient (which points in the same direction as the vanilla gradient for this example). The plot shows the vanilla gradient and the VPNG at the current iterate, together with the optimum.
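To make Equations 1 and 2 concrete, here is a minimal NumPy sketch that estimates the ELBO by Monte Carlo for a reparameterized mean-field Gaussian q(z; λ) = N(λ, σ²I) and estimates F_q from score outer products. The toy model (a unit-variance Gaussian likelihood), the value of σ, and the sample counts are illustrative assumptions, not the paper's implementation.

```python
import numpy as np

rng = np.random.default_rng(0)

# Illustrative toy model (assumption): prior z ~ N(0, I), likelihood x_i ~ N(z, I).
n, d, sigma_q = 50, 2, 0.1
x = rng.normal(loc=1.0, scale=1.0, size=(n, d))

def elbo_estimate(lam, num_samples=100):
    """Monte Carlo estimate of Equation 1 with q(z; lam) = N(lam, sigma_q^2 I)."""
    eps = rng.normal(size=(num_samples, d))
    z = lam + sigma_q * eps                      # reparameterized draws from q
    # E_q[log p(x | z)] for a unit-variance Gaussian likelihood (up to constants).
    loglik = np.array([-0.5 * np.sum((x - zs) ** 2) for zs in z]).mean()
    # KL(N(lam, sigma_q^2 I) || N(0, I)) in closed form.
    kl = 0.5 * np.sum(sigma_q ** 2 + lam ** 2 - 1.0 - 2 * np.log(sigma_q))
    return loglik - kl

def q_fisher(lam, num_samples=2000):
    """Equation 2: F_q = E_q[score score^T], estimated by Monte Carlo."""
    z = lam + sigma_q * rng.normal(size=(num_samples, d))
    scores = (z - lam) / sigma_q ** 2            # grad_lam of log N(z | lam, sigma_q^2 I)
    return scores.T @ scores / num_samples       # close to I / sigma_q^2 for this q

lam = np.zeros(d)
print("ELBO estimate:", elbo_estimate(lam))
print("q-Fisher (close to I / sigma_q^2):\n", q_fisher(lam))
```

For this isotropic q, F_q is proportional to the identity, so the q-natural gradient F_q⁻¹∇_λL only rescales the vanilla gradient; this is exactly the failure mode illustrated in the next section.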
We call the matrix in Equation 2 the q-Fisher information matrix. With this Fisher information matrix, the natural gradient is

$$\nabla^{\mathrm{NG}}_\lambda \mathcal{L}(\lambda) = F_q^{-1} \nabla_\lambda \mathcal{L}(\lambda).$$

Natural gradients have been used to optimize the ELBO (Hoffman et al., 2013). The natural gradient works because it approximates the Hessian of the ELBO at the optimum. The negative Hessian matrix of the ELBO is

$$-\nabla^2_\lambda \mathcal{L} = F_q + \int \nabla^2_\lambda q(z \mid x; \lambda)\,\big(\log q(z \mid x; \lambda) - \log p(z \mid x)\big)\,dz. \tag{3}$$

The last integral in the above equation is small when the variational distribution q(z | x; λ) is close to the posterior distribution p(z | x). Hence, the q-Fisher information matrix can be viewed as a positive semi-definite version of the negative Hessian matrix of the ELBO. Thus natural gradients improve optimization efficiency when the variational approximation is close to the posterior.

## 3. The Variational Predictive Natural Gradient

The q-Fisher information is insufficient. Consider the following example with a bivariate Gaussian likelihood that has an unknown mean µ = (µ1, µ2)^⊤, a pathological known covariance Σ = [[1, 1−ε], [1−ε, 1]] for some constant 0 < ε < 1, and an isotropic Gaussian prior:

$$p(x_{1:n}, \mu) = \mathcal{N}(\mu \mid 0, I_2) \prod_{i=1}^{n} \mathcal{N}(x_i \mid \mu, \Sigma). \tag{4}$$

To do variational inference, we choose a mean-field approximation q(µ; λ) = N(µ1 | λ1, σ²) N(µ2 | λ2, σ²) with σ fixed. The posterior distribution for this problem is analytic: p(µ | x) = N(µ*, Σ*), where Σ* = (nΣ⁻¹ + I₂)⁻¹ and µ* = (nI₂ + Σ)⁻¹ Σ_{i=1}^n x_i. The optimal solution for the variational parameter λ is µ*. The gradient of the objective function L(λ) is

$$\nabla_\lambda \mathcal{L}(\lambda) = -\lambda + \Sigma^{-1}\Big(\sum_{i=1}^{n} x_i - n\lambda\Big). \tag{5}$$

The precision matrix Σ⁻¹ is pathological. It has an eigenvector v1 = (1/√2)(1, 1)^⊤ with eigenvalue 1/(2−ε), and an eigenvector v2 = (1/√2)(1, −1)^⊤ with eigenvalue 1/ε. As a result, vanilla gradients will almost always go along the direction of the eigenvector v2, as shown in Figure 1. Further, natural gradients fail to resolve this. The q-Fisher information matrix of this problem is diagonal, so it cannot help resolve the extreme curvature between the parameters λ1 and λ2.

Notice that this pathological curvature does not arise because the mean-field approximating family for q(µ; λ) fails to contain the true posterior p(µ | x). In fact, even if we optimize q(µ; λ) over the family of all bivariate Gaussian distributions N(µ | λ_µ, λ_Σ), the partial gradient ∇_{λ_µ}L over the mean parameter vector λ_µ will still have the same curvature issue. The issue arises because the variational approximation does not approximate the posterior well at initialization. In general, if at some point the current q iterate does not approximate the posterior well, then the corresponding q-Fisher information matrix may not be able to correct the curvature in the parameters.

### 3.1. Negative Hessian of the expected log-likelihood

The pathology of the ELBO for the model in Equation 4 comes from the ill-conditioned covariance matrix Σ. The covariance matrix of the posterior can correct for this pathology, since that covariance matrix is Σ* ≈ (1/n)Σ. The disconnect lies in the fact that variational inference is only close to the posterior at its optimum, which implies that q-natural gradients only correct for the curvature well once the variational approximation is close to the posterior, i.e., once the inference problem is almost solved.
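The following NumPy sketch reproduces this pathology numerically for the model in Equation 4: the vanilla gradient of Equation 5 aligns almost perfectly with the eigenvector v2 rather than with the direction toward the optimum µ*, and the diagonal q-Fisher of the fixed-variance mean-field approximation leaves that direction unchanged. The values of n, ε, σ, and the synthetic data are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(1)

# Toy model from Equation 4 (epsilon, n, and the data-generating mean are illustrative).
n, eps = 100, 0.01
Sigma = np.array([[1.0, 1.0 - eps], [1.0 - eps, 1.0]])
Sigma_inv = np.linalg.inv(Sigma)
x = rng.multivariate_normal(mean=[2.0, 1.0], cov=Sigma, size=n)

# Analytic posterior mean (the optimum for lambda).
mu_star = np.linalg.solve(n * np.eye(2) + Sigma, x.sum(axis=0))

# Vanilla gradient of the ELBO at lambda (Equation 5).
lam = np.zeros(2)
grad = -lam + Sigma_inv @ (x.sum(axis=0) - n * lam)

def cosine(a, b):
    return a @ b / (np.linalg.norm(a) * np.linalg.norm(b))

v2 = np.array([1.0, -1.0]) / np.sqrt(2)  # eigenvector of Sigma_inv with eigenvalue 1/eps
print("cos(grad, v2)           =", cosine(grad, v2))             # close to 1: gradient hugs v2
print("cos(grad, mu* - lambda) =", cosine(grad, mu_star - lam))  # well below 1

# The q-Fisher of the mean-field q(mu; lambda) with fixed sigma is diagonal (I / sigma^2),
# so the q-natural gradient only rescales the vanilla gradient and keeps its direction.
sigma_q = 0.1
F_q = np.eye(2) / sigma_q ** 2
print("cos(F_q^-1 grad, grad)  =", cosine(np.linalg.solve(F_q, grad), grad))  # exactly 1
```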
The problem is that the q-Fisher information matrix measures how parameter perturbations alter the variational approximation, regardless of the current model parameters and the quality of the current variational approximation. We bring the model back into the picture by considering positive semidefinite matrices that resemble the negative Hessian matrix of the expected log-likelihood part L_ll = E_q[log p(x | z; θ)] of the ELBO, over both the variational parameter λ and the model parameter θ. The expected log-likelihood is where the model and variational approximation interact, so its Hessian contains the relevant curvature for optimizing the ELBO. However, since we are maximizing the ELBO, the matrices need to not only resemble the negative Hessian but also be positive semidefinite. The negative Hessian of the expected log-likelihood is not guaranteed to be positive semidefinite. Our goal is to construct a positive semidefinite matrix related to the negative Hessian that accelerates inference by accounting for the curvature in both the variational parameter and the model parameter. In the sequel, we show this new matrix is a type of Fisher information.

To compute gradients and Hessians, we need to compute derivatives of expectations controlled by the variational parameter λ. In general, we can differentiate using score function-style estimators from black box variational inference (Ranganath et al., 2014). For simplicity, consider the case where q is reparameterizable (Kingma & Welling, 2014; Rezende et al., 2014). Then draws for z from q can be written as deterministic transformations g of noise terms ε with parameter-free distributions s. This simplifies the computations:

$$z = g(x, \varepsilon; \lambda) \sim q(z \mid x; \lambda), \qquad \varepsilon \sim s(\varepsilon). \tag{6}$$

The reparameterization trick can be applied to many common distributions (e.g., a Gaussian draw ν ∼ N(µ, σ²) can be reparameterized as ν = µ + σε where ε ∼ N(0, 1)). With this trick, and denoting η = (λ^⊤, θ^⊤)^⊤, the negative Hessian matrix of L_ll becomes

$$-\nabla^2_\eta \mathcal{L}_{\mathrm{ll}} = -\mathbb{E}_{\varepsilon}\big[\nabla^2_\eta \log p(x \mid z = g(x, \varepsilon; \lambda); \theta)\big].$$

Let us first consider the case where the variational distribution q factorizes over data points: q(z | x; λ) = ∏_{i=1}^n q(z_i | x_i; λ). This factorization occurs in many popular models, such as variational autoencoders (VAEs) (Kingma & Welling, 2014; Rezende et al., 2014). Denote Q as the empirical distribution of the observed data x_{1:n}. Also denote p(z_i) and p(x_i | z_i) as the prior and likelihood function for any single data point x_i. Moreover, for any data points x_i and x'_i, we define the function

$$u(x_i, x'_i, \varepsilon_i, \eta) = -\nabla^2_\eta \log p(x'_i \mid z_i = g(x_i, \varepsilon_i; \lambda); \theta).$$

Since we can use z_i = g(x_i, ε_i; λ) to reparameterize z_i, we can assume that the Jacobian matrix ∂z_i/∂ε_i is always invertible, and hence by the inverse function theorem we can also write ε_i as a function of z_i, x_i, and λ. Hence, we can also express the above quantity as

$$-\nabla^2_\eta \log p(x'_i \mid z_i = g(x_i, \varepsilon_i; \lambda); \theta) = v(x_i, x'_i, z_i, \eta).$$

With this notation, we can rewrite the above negative Hessian matrix for L_ll as

$$-\nabla^2_\eta \mathcal{L}_{\mathrm{ll}} = -\sum_{i=1}^{n} \mathbb{E}_{\varepsilon_i}\big[\nabla^2_\eta \log p(x_i \mid z_i = g(x_i, \varepsilon_i; \lambda); \theta)\big] = n\,\mathbb{E}_{Q(x_i)}\big[\mathbb{E}_{\varepsilon_i}[u(x_i, x_i, \varepsilon_i, \eta)]\big] = n\,\mathbb{E}_{Q(x_i)}\big[\mathbb{E}_{q(z_i \mid x_i; \lambda)}[v(x_i, x_i, z_i, \eta)]\big]. \tag{7}$$

Assessing the positive definiteness of Equation 7 is a challenge because of the expectation with respect to the variational approximation. To make the positive definiteness easier to wrangle, we make the assumption that

$$p(z_i)\,p(x_i \mid z_i) \approx Q(x_i)\,q(z_i \mid x_i). \tag{8}$$
When our model is learning a successful parameter vector η, the distribution p(z_i, x_i) = p(z_i)p(x_i | z_i) should be close to the distribution Q(z_i, x_i) = Q(x_i)q(z_i | x_i), since the variational distribution q is trying to learn the posterior distribution p(z_i | x_i) while p(x_i) is trying to learn the empirical data distribution Q. This is the only approximation we use to derive the VPNG. This substitution is similar to the approximation q(z | x) ≈ p(z | x) made when analyzing the q-Fisher information matrix. The two can nevertheless be quite different, for example when the approximating family for q(z | x; λ) is not large enough to accurately approximate the posterior distribution p(z | x), or when the model p(x | z; θ) cannot accurately learn the data distribution Q. With Equation 8 in hand, we have

$$-\nabla^2_\eta \mathcal{L}_{\mathrm{ll}} \approx n\,\mathbb{E}_{p(z_i)}\big[\mathbb{E}_{p(x_i \mid z_i; \theta)}[v(x_i, x_i, z_i, \eta)]\big]. \tag{9}$$

This matrix is computable via Monte Carlo. However, in the next section we show that this matrix may not be positive semidefinite and provide a method to derive a matrix that is.

### 3.2. Predictive Sampling for Positive Semi-definiteness

The inner expectation of Equation 9 is an expectation of v(x_i, x_i, z_i, η) with respect to the distribution p(x_i | z_i; θ) on x_i. This matrix appears to be an average of Fisher information matrices, and thus positive semidefinite. However, v is not the Hessian of a distribution over x_i, since x_i appears on both sides of the conditioning bar. The failure of v to be the Hessian of a distribution for x_i means Equation 9 may not be positive semidefinite. Next, we provide a concrete example where it is not.

**Non-positive-semi-definiteness of the second-order derivative.** Consider a model with data points x_1, . . . , x_n ∈ R and local latent variables z_1, . . . , z_n ∈ R. The prior is p(z) = ∏_{i=1}^n N(z_i | 0, 1²), the model distribution is p(x | z; θ) = ∏_{i=1}^n N(x_i | θz_i, 1²), and the variational distribution is q(z | x; λ) = ∏_{i=1}^n N(z_i | λx_i, σ²), with λ, θ ∈ R and the hyperparameter σ > 0. Then we can reparameterize each z_i = λx_i + ε_i with ε_i ∼ N(0, σ²) drawn i.i.d. Under this model, Equation 9 equals

$$n \begin{pmatrix} \theta^2(\theta^2 + 1) & \theta^2 - 1 \\ \theta^2 - 1 & 1 \end{pmatrix},$$

which is not positive semi-definite when |θ| < 1/√3.¹

¹ This matrix normally depends on both the variational parameter λ and the model parameter θ. Here it is independent of λ since, in this model, ∂z/∂λ can be represented without λ. The variational parameter would appear in this matrix if we instead set z_i ∼ N(λ²x_i, σ²).

**Predictive sampling for positive semidefiniteness.** The failure of the matrix in Equation 9 to be positive semidefinite stems from v not being the Hessian of a probability distribution. To remedy this, we sample the x_i on both sides of the conditioning bar independently. That is, we replace

$$\mathbb{E}_{p(x_i \mid z_i; \theta)}[v(x_i, x_i, z_i, \eta)] \tag{10}$$

with

$$\mathbb{E}_{p(x_i \mid z_i; \theta)}\big[\mathbb{E}_{p(x'_i \mid z_i; \theta)}[v(x_i, x'_i, z_i, \eta)]\big], \tag{11}$$

where x'_i is a newly drawn data point from the same distribution p(· | z_i; θ). This step is required: rescaling the gradient with the inverse of Equation 10 does not guarantee convergence. This step allows the construction of a positive semidefinite matrix that captures the essence of the negative Hessian. With this transformation, we get

$$n\,\mathbb{E}_{p(z_i)}\mathbb{E}_{p(x_i \mid z_i; \theta)}\mathbb{E}_{p(x'_i \mid z_i; \theta)}[v(x_i, x'_i, z_i, \eta)] \approx n\,\mathbb{E}_{Q(x_i)}\mathbb{E}_{q(z_i \mid x_i; \lambda)}\mathbb{E}_{p(x'_i \mid z_i; \theta)}[v(x_i, x'_i, z_i, \eta)] = n\,\mathbb{E}_{Q(x_i)}\mathbb{E}_{\varepsilon_i}\mathbb{E}_{p(x'_i \mid z_i = g(x_i, \varepsilon_i; \lambda); \theta)}[u(x_i, x'_i, \varepsilon_i, \eta)]. \tag{12}$$

The approximation step follows from the earlier assumption that the joints of p and q are close (see Equation 8). The inner expectation of the above equation is the expectation of the negative Hessian of the log density of the distribution p(x'_i | z_i = g(x_i, ε_i; λ); θ) with respect to the parameter η, given the latent variable ε_i and the data point x_i. Therefore, this inner expectation equals the Fisher information matrix of this distribution, which is always positive semi-definite.
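Both claims can be checked numerically for the one-dimensional example above. The sketch below forms Monte Carlo estimates of the inner matrices in Equation 9 (same data point on both sides of the conditioning bar) and Equation 11 (predictively resampled data point) and compares their eigenvalues; the choice of θ and the sample size are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(2)

theta = 0.5          # illustrative value with |theta| < 1/sqrt(3)
S = 200_000          # Monte Carlo samples

# Joint draws from p(z_i) p(x_i | z_i; theta) used on the right-hand side of Equation 9.
z = rng.normal(size=S)                   # z_i ~ N(0, 1)
x = theta * z + rng.normal(size=S)       # x_i | z_i ~ N(theta * z_i, 1)
x_new = theta * z + rng.normal(size=S)   # x'_i ~ p(. | z_i; theta), the predictive resample

def v_matrix(x_cond, x_out, z):
    """Monte Carlo average of v(x_cond, x_out, z, eta): the negative Hessian of
    log p(x_out | z = lam*x_cond + eps; theta) with respect to eta = (lam, theta),
    with eps eliminated via eps = z - lam*x_cond.  As noted in footnote 1, the
    result does not depend on lam (or sigma) for this model."""
    r = x_out - theta * z                   # residual of the outer data point
    h_ll = theta ** 2 * x_cond ** 2         # negative d^2/dlam^2
    h_lt = theta * z * x_cond - r * x_cond  # negative d^2/dlam dtheta
    h_tt = z ** 2                           # negative d^2/dtheta^2
    return np.array([[h_ll.mean(), h_lt.mean()],
                     [h_lt.mean(), h_tt.mean()]])

same = v_matrix(x, x, z)           # Equation 9: same x_i on both sides of the bar
resampled = v_matrix(x, x_new, z)  # Equation 11: x'_i redrawn from the model
print("eigenvalues, Equation 9 :", np.linalg.eigvalsh(same))       # one negative eigenvalue
print("eigenvalues, Equation 11:", np.linalg.eigvalsh(resampled))  # nonnegative up to MC noise
```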
The matrix in Equation 12 meets our desiderata: it maintains structure from the negative Hessian of the expected log-likelihood, is guaranteed to be positive semidefinite for any model and variational approximation so that optimization converges, and is computable via Monte Carlo samples. To see that it is computable, note that the matrix in Equation 12 equals

$$n\,\mathbb{E}_{Q(x_i)}\mathbb{E}_{\varepsilon_i}\mathbb{E}_{p(x'_i \mid z_i = g(x_i, \varepsilon_i; \lambda); \theta)}[u(x_i, x'_i, \varepsilon_i, \eta)] = n\,\mathbb{E}_{Q(x_i)}\Big[\mathbb{E}_{\varepsilon_i}\Big[\mathbb{E}_{p(x'_i \mid z_i = g(x_i, \varepsilon_i; \lambda); \theta)}\big[\nabla_\eta \log p(x'_i \mid z_i = g(x_i, \varepsilon_i; \lambda); \theta)\,\nabla_\eta \log p(x'_i \mid z_i = g(x_i, \varepsilon_i; \lambda); \theta)^\top\big]\Big]\Big]. \tag{13}$$

This equation can be computed by sampling a data point from the observed data, sampling a noise term, and resampling a new data point from the model likelihood.

### 3.3. The variational predictive natural gradient

The matrix in Equation 13 is the expectation of a type of Fisher information. First, define p(x'_i | z_i = g(x_i, ε_i; λ); θ) as the reparameterized predictive model distribution. The Fisher information of this distribution given x_i and ε_i is

$$F_{\mathrm{rep}}(x_i, \varepsilon_i) = \mathbb{E}_{p(x'_i \mid z_i = g(x_i, \varepsilon_i; \lambda); \theta)}\big[\nabla_\eta \log p(x'_i \mid z_i = g(x_i, \varepsilon_i; \lambda); \theta)\,\nabla_\eta \log p(x'_i \mid z_i = g(x_i, \varepsilon_i; \lambda); \theta)^\top\big].$$

Averaging the Fisher information of the reparameterized predictive model distribution over observed data points and draws from the variational approximation, and rescaling by the number of data points, gives

$$n\,\mathbb{E}_{Q(x_i)}\mathbb{E}_{\varepsilon_i}[F_{\mathrm{rep}}(x_i, \varepsilon_i)] = \mathbb{E}_{\varepsilon}\Big[\mathbb{E}_{p(x' \mid z = g(x, \varepsilon; \lambda); \theta)}\big[\nabla_\eta \log p(x' \mid z = g(x, \varepsilon; \lambda); \theta)\,\nabla_\eta \log p(x' \mid z = g(x, \varepsilon; \lambda); \theta)^\top\big]\Big] =: F_r.$$

The positive semidefinite matrix related to the negative Hessian of the ELBO that we derived in the previous section is exactly the expected Fisher information of the reparameterized predictive model distribution p(x' | z = g(x, ε; λ); θ). The expected density of the reparameterized predictive model distribution can be viewed as the variational predictive distribution r(x' | x; λ, θ) of new data:

$$\mathbb{E}_{\varepsilon}[p(x' \mid z = g(x, \varepsilon; \lambda); \theta)] = \mathbb{E}_{q(z \mid x; \lambda)}[p(x' \mid z; \theta)] =: r(x' \mid x; \lambda, \theta).$$

This distribution is the predictive distribution with the posterior replaced by the variational approximation. Hence, we call F_r the variational predictive Fisher information matrix.

This matrix can capture curvature. Though we derive it by assuming q factorizes, the matrix may still capture curvature in the general case. To illustrate that the variational predictive Fisher information matrix can capture curvature, consider the example in Equation 4. We can reparameterize the latent variable µ in the variational distribution q(µ; λ) = N(µ1 | λ1, σ²)N(µ2 | λ2, σ²) as µ = λ + ε, ε ∼ N(0, σ²I₂). Then the reparameterized predictive distribution p(x' | µ = λ + ε) equals ∏_{i=1}^n N(x'_i | λ + ε, Σ), whose Fisher information matrix is just nΣ⁻¹. Hence the variational predictive Fisher information matrix for this model is F_r = nΣ⁻¹, which almost exactly matches the pathological curvature structure in the gradient in Equation 5. Therefore, our variational predictive Fisher information matrix contains the curvature we want to correct.
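Continuing the bivariate Gaussian example, the sketch below rescales the gradient of Equation 5 by F_r⁻¹ with F_r = nΣ⁻¹ and compares directions: the vanilla gradient is nearly orthogonal to the direction toward µ*, while the rescaled gradient points almost exactly at it, which is what Figure 1 depicts. The data, n, and ε are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(3)

n, eps = 100, 0.01                                   # illustrative sizes
Sigma = np.array([[1.0, 1.0 - eps], [1.0 - eps, 1.0]])
Sigma_inv = np.linalg.inv(Sigma)
x = rng.multivariate_normal(mean=[2.0, 1.0], cov=Sigma, size=n)

lam = np.zeros(2)                                    # current variational mean
mu_star = np.linalg.solve(n * np.eye(2) + Sigma, x.sum(axis=0))  # optimum
grad = -lam + Sigma_inv @ (x.sum(axis=0) - n * lam)  # Equation 5
F_r = n * Sigma_inv                                  # variational predictive Fisher for this model
vpng = np.linalg.solve(F_r, grad)                    # F_r^{-1} grad

def cosine(a, b):
    return a @ b / (np.linalg.norm(a) * np.linalg.norm(b))

print("cos(vanilla grad, mu* - lambda):", cosine(grad, mu_star - lam))  # well below 1
print("cos(rescaled,     mu* - lambda):", cosine(vpng, mu_star - lam))  # close to 1
```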
Hence, we apply an update with the new natural gradient, the variational predictive natural gradient (VPNG):

$$\nabla^{\mathrm{VPNG}}_{\lambda, \theta} \mathcal{L} = F_r^{-1} \nabla_{\lambda, \theta} \mathcal{L}(\lambda, \theta). \tag{14}$$

With this new natural gradient, the algorithm can move towards the true mean rather than getting stuck on the line λ1 − λ2 = 0, as shown in Figure 1.

The variational predictive Fisher information matrix F_r is a positive semi-definite matrix related to the negative Hessian of the expected log-likelihood part L_ll of the ELBO. It can capture the curvature of variational inference because the expected log-likelihood part of the ELBO usually plays the more important role in the whole objective, and we can view the KL divergence part KL(q(z | x) || p(z)) as a regularizer for the q distribution. In practice, the KL divergence term gets scaled by a ratio β ∈ (0, 1) to learn better representations (Bowman et al., 2016). With this scaling, the curvature of the expected log-likelihood part L_ll is even more important.

### 3.4. Comparison with the traditional natural gradient

The traditional natural gradient points in the steepest ascent direction of the ELBO in the symmetric KL divergence space of the variational distribution q (Hoffman et al., 2013). The VPNG shares a similar type of geometric structure: it points in the steepest ascent direction of the ELBO in the expected (over the parameter-free distribution s(ε) and the data distribution Q(x)) symmetric KL divergence space of the reparameterized predictive distribution p(x' | z = g(x, ε; λ); θ). Details are shown in the appendix.

The q-Fisher information matrix tries to capture the curvature of the ELBO. However, it strongly relies on the fidelity of the variational approximation to the posterior, q(z | x) ≈ p(z | x). The new Fisher information matrix F_r relies on a similar approximation, p(z)p(x | z) ≈ Q(x)q(z | x), but these approximations are still quite different in many cases, such as when the model does not approximate the true data distribution well (described in the paragraph after Equation 8). Moreover, F_r has the advantage that it considers the curvature from both the variational parameter λ and the model parameter θ, while the q-Fisher information matrix does not consider θ.

**Algorithm 1 (Variational inference with VPNGs).**
- Input: data x_{1:n}, model p(x, z, β).
- Initialize the parameters λ, θ, and µ.
- Repeat until convergence:
  - Draw samples β̂ and ẑ_i (Equation 15).
  - Draw i.i.d. samples x̂_i^(k) (Equation 15).
  - Compute the Fisher information matrix F̂_r (Equation 16).
  - Compute the natural gradient ∇̂^VPNG_{λ,θ} L (Equation 17).
  - Update the parameters λ and θ with the gradient ∇̂^VPNG_{λ,θ} L.
  - (Optional) Adjust the dampening parameter µ.

## 4. Variational Inference with Approximate Curvature

To build an algorithm with the VPNG, we need to compute the reparameterized predictive distribution and take an expectation with respect to its Fisher information. These steps are only tractable for specific choices of models and variational approximations. We address how to compute them with Monte Carlo in a broader setting here. We can generate samples of new data from the distribution p(x | ẑ; θ) for ẑ drawn from q. These samples can be used to estimate the integrals in the definition of F_r. They are generated through the following Monte Carlo sampling process, using k ∈ {1, . . . , M} to index the Monte Carlo samples:

$$\hat{z} \sim q(z \mid x; \lambda), \qquad \hat{x}_i^{(k)} \sim p(x_i \mid \hat{z}; \theta). \tag{15}$$
Reparameterization makes it easy to approximate the needed gradients of log p(x̂_i^(k) | ẑ; θ) with respect to λ:

$$\nabla_\lambda \log p(\hat{x}_i^{(k)} \mid \hat{z}; \theta) = (\nabla_\lambda \hat{z})^\top \nabla_{\hat{z}} \log p(\hat{x}_i^{(k)} \mid \hat{z}; \theta).$$

Denote b̂_{i,k} = ∇_{λ,θ} log p(x̂_i^(k) | ẑ; θ). Using samples from Equation 15, we can estimate the variational predictive Fisher information matrix F_r as

$$\hat{F}_r = \frac{1}{M} \sum_{k=1}^{M} \sum_{i=1}^{n} \hat{b}_{i,k}\,\hat{b}_{i,k}^\top. \tag{16}$$

This is an unbiased estimate of the variational predictive Fisher information matrix F_r. The approximate variational predictive Fisher information matrix F̂_r might be non-invertible. Since rank(F̂_r) ≤ Mn, the matrix is non-invertible if Mn < dim(λ) + dim(θ). We add a small dampening parameter µ to ensure invertibility. This parameter can be fixed or dynamically adjusted. With this dampening parameter, the approximate variational predictive natural gradient is

$$\hat{\nabla}^{\mathrm{VPNG}}_{\lambda, \theta} \mathcal{L} = (\hat{F}_r + \mu I)^{-1} \nabla_{\lambda, \theta} \mathcal{L}. \tag{17}$$

Algorithm 1 summarizes VPNG updates. We set the dampening parameter µ to a constant in our experiments. We show this algorithm works well in Section 5.

## 5. Experiments

We explore the empirical performance of variational inference using the VPNG updates in Algorithm 1 (code is available at https://github.com/datang1992/VPNG). We consider Bayesian logistic regression on a synthetic dataset, the VAE on a real handwritten digit dataset, and variational matrix factorization on a real movie recommendation dataset. We test their performance using different metrics on both train and held-out data. We compare the VPNG with vanilla gradient optimization and traditional natural gradient optimization, using RMSProp (Tieleman & Hinton, 2012) and Adam (Kingma & Ba, 2014) to set the learning rates in all three algorithms. For each algorithm in each task, we show the better result of the two learning-rate adjustment techniques and select the best decay rate (if applicable) and step size. We use ten Monte Carlo samples to estimate the ELBO, its derivatives, and the variational predictive Fisher information matrix F_r.

### 5.1. Bayesian logistic regression

We test Algorithm 1 with a Bayesian logistic regression model on a synthetic dataset. We have the data x_{1:n} and the labels y_{1:n}, where x_i ∈ R⁴ is a vector and y_i ∈ {0, 1} is a binary label. Each pair (x_i, y_i) is generated through the following process:

$$a_i \sim \mathrm{Uniform}[-5, 5], \qquad \varepsilon_i^k \sim \mathrm{Uniform}[-0.005, 0.005],\ k \in \{1, 2, 3, 4\},$$
$$x_i = (a_i, -a_i, -a_i, a_i)^\top + (\varepsilon_i^1, \varepsilon_i^2, \varepsilon_i^3, \varepsilon_i^4)^\top, \qquad y_i = \mathbb{I}\big[\langle (1, 2, 3, 4), x_i \rangle \ge 0\big].$$

The generated data are all very close to the ground truth classification boundary ⟨(1, 2, 3, 4), x⟩ = 0. We use logistic regression with parameter w to model this data. We place an isotropic Gaussian prior distribution p₀(w) = N(w | 0, σ₀² I₅) on the parameter w, where σ₀ = 100. We apply mean-field variational inference to the parameter w: q(w; µ, σ) = ∏_{i=1}^5 N(w_i | µ_i, σ_i²).

Table 1. Bayesian logistic regression AUC.

| Method | Train AUC | Test AUC |
| --- | --- | --- |
| Gradient | 0.734 ± 0.017 | 0.718 ± 0.022 |
| NG | 0.744 ± 0.043 | 0.751 ± 0.047 |
| VPNG | 0.972 ± 0.011 | 0.967 ± 0.011 |

Mean-field variational families are popular primarily for their optimization efficiency. We aim to show that the VPNG improves upon the speed of mean-field approaches. The data generating process and the initial prior parameter σ₀ make the ELBO pathological. Specifically, the covariates are strongly correlated while all data points have small margins with respect to the ground truth boundary. We generate 500 samples and select a fixed set containing 80% of the data for training, using the rest for testing.
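To make Algorithm 1 and Equations 15-17 concrete for this model, here is a minimal NumPy sketch of a single VPNG update for the mean-field Gaussian approximation to Bayesian logistic regression. It uses the generic Monte Carlo estimator of F̂_r; the exact covariate construction, the added bias column, and all hyperparameter values (M, damping, step size) are illustrative assumptions rather than the paper's implementation.

```python
import numpy as np

rng = np.random.default_rng(4)

def sigmoid(t):
    # numerically stable logistic function
    e = np.exp(-np.abs(t))
    return np.where(t >= 0, 1.0 / (1.0 + e), e / (1.0 + e))

# Synthetic data in the spirit of Section 5.1 (strongly correlated covariates, tiny margins).
n = 400
a = rng.uniform(-5, 5, size=n)
noise = rng.uniform(-0.005, 0.005, size=(n, 4))
X = np.column_stack([a, -a, -a, a]) + noise
y = ((X @ np.array([1.0, 2.0, 3.0, 4.0])) >= 0).astype(float)
Xb = np.column_stack([X, np.ones(n)])    # assumption: append a bias so w lives in R^5

sigma0 = 100.0                           # prior scale sigma_0 from the paper
d = Xb.shape[1]
mu, rho = np.zeros(d), np.full(d, -1.0)  # mean-field parameters; sigma = exp(rho)
M, damping, lr = 10, 1.0, 0.05           # illustrative hyperparameters

def vpng_step(mu, rho):
    sigma = np.exp(rho)
    eps = rng.normal(size=d)
    w = mu + sigma * eps                 # reparameterized draw w_hat from q(w; mu, sigma)
    s = sigmoid(Xb @ w)

    # Single-sample reparameterized gradient of the ELBO with respect to eta = (mu, rho).
    g_w = Xb.T @ (y - s)
    g_mu = g_w - mu / sigma0 ** 2
    g_rho = g_w * sigma * eps - (sigma ** 2 / sigma0 ** 2 - 1.0)
    grad = np.concatenate([g_mu, g_rho])

    # Equations 15-16: redraw predictive labels and average per-datapoint score outer products.
    F = np.zeros((2 * d, 2 * d))
    for _ in range(M):
        y_new = rng.binomial(1, s).astype(float)  # predictive resample from p(y | x, w_hat)
        B_mu = (y_new - s)[:, None] * Xb          # scores with respect to mu
        B_rho = B_mu * (sigma * eps)              # chain rule through w = mu + sigma * eps
        B = np.hstack([B_mu, B_rho])              # row i is b_hat_{i,k}
        F += B.T @ B
    F /= M

    # Equation 17: damped VPNG ascent direction.
    step = np.linalg.solve(F + damping * np.eye(2 * d), grad)
    return mu + lr * step[:d], rho + lr * step[d:]

mu, rho = vpng_step(mu, rho)
print("updated variational mean:", np.round(mu, 4))
```

In a full run, this step would be repeated until convergence, with the damping µ kept constant as in the experiments below.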
We test Algorithm 1 and the baseline methods on this data. We do not need Monte Carlo samples of predicted data, as F_r can be computed efficiently given samples from the latent variables in this problem. To compare performance, we allow each algorithm to run 2000 iterations for 10 runs with various step sizes and compare the AUC scores on both the train and test sets. The AUC scores are computed with the mean prediction. The results are shown in Table 1. In the experiments, we calculate the train and test AUC scores every 100 iterations and report the average of the last 5 outputs for each method. Table 1 shows the train and test AUC scores for each method over all 10 runs. Our method outperforms the baselines. We show a test AUC-versus-iteration curve for this experiment in Appendix B. The vanilla gradient and traditional natural gradient do not perform well because of the curvature induced by the correlation in the covariates.

### 5.2. Variational autoencoder

We also study VPNGs for variational autoencoders (VAEs) (Kingma & Welling, 2014; Rezende et al., 2014) on binarized MNIST (LeCun et al., 1998). MNIST contains 70,000 images (60,000 for training and 10,000 for testing) of handwritten digits, each of size 28 × 28. We use a 100-dimensional latent representation z_i. Our variational distribution factorizes, and we use a three-layer inference network to output the mean and variance of the variational distribution given a datapoint. The generative model transforms z using a three-layer neural network to output logits for each pixel. We use 200 hidden units for both the inference and generative networks.

Figure 2. VAE learning curves on binarized MNIST: ELBO against wall-clock time (seconds) and against the number of iterations for the vanilla gradient, NG, and VPNG.

To efficiently compute variational predictive Fisher information matrices, we view the entire VAE structure as a six-layer neural network with a stochastic layer between the third and fourth layers. We then apply the tridiagonal block-wise Kronecker-factored approximate curvature (K-FAC) (Martens & Grosse, 2015). This enables us to compute Fisher information matrices faster in feed-forward neural networks. We further improve efficiency by constructing low-rank approximations of large matrices. Finally, we use exponential moving averages of quantities related to the K-FAC approximations. We show more details in the appendix.

We compare the VPNG method with the vanilla gradient and natural gradient optimizations. Since the traditional natural gradient does not deal with the model parameter θ, we use the vanilla gradient for θ in this setting. We do not compare the VPNG with the traditional natural gradient by fixing the model parameter θ and learning only the variational parameter λ, for two reasons. First, this setting is not common for VAEs. Second, we would need a fixed value for θ, and it is difficult to obtain an optimal value for it before running the algorithms. We select a batch size of 600 since we print the ELBO values every 100 iterations. Hence, we evaluate the performance of each algorithm exactly once per epoch. The test ELBO values are computed over the whole test set, and the train ELBO values are computed over a fixed set of 10,000 randomly chosen (out of the whole 60,000) images.
We allow each method to run for 1,000 seconds (we found similar results at longer runtimes) and select the best step sizes among several reasonable choices. Figure 2 shows the results. Though the VPNG method is the slowest per iteration, it outperforms the baseline optimizations on both the train and test sets, even in terms of running time. We also compare these methods with a second-order optimization method, Hessian-free Stochastic Gaussian Variational Inference (HFSGVI) (Fan et al., 2015). However, it was not fast enough due to the large number of Hessian-vector product computations: its ELBO values are still far below −200 within 1,000 seconds, which is much slower than the methods shown in Figure 2. The intuitive reason for the performance gain stems from the fact that the VAE parameters control pixels that are highly correlated across images. The VPNG corrects for this correlation.

### 5.3. Variational matrix factorization

Our third experiment is on MovieLens 20M (Harper & Konstan, 2016). This is a movie recommendation dataset that contains 20 million movie ratings from n ≈ 135K users on m_total ≈ 27K movies. Each rating R^raw_{u,i} of movie i by user u is a value in the set {0.5, 1.0, 1.5, . . . , 5.0}. We convert the ratings to integer values between 0 and 9 and select all movies with at least 5K ratings, yielding m ≈ 1K movies. We model the zeros as in implicit matrix factorization (Gopalan et al., 2015).

We use Poisson matrix factorization to model this data. Assume there is a latent representation β_u ∈ R^d for each user u and a latent representation θ_i ∈ R^d for each movie i. Here d = 100 is the latent variable dimensionality. Denote softplus(t) = log(1 + e^t). We model the likelihood as

$$p(R \mid \theta, \beta) = \prod_{u=1}^{n} \prod_{i=1}^{m} \mathrm{Poisson}\big(R_{u,i} \mid \mu = \mathrm{softplus}(\beta_u^\top \theta_i)\big).$$

We do variational inference on the user latent variables β and treat the movie variables θ as model parameters. The prior on each user latent variable is a standard normal. We set the variational distribution as q(β | R; λ) = ∏_{u=1}^n q(β_u | R_u; λ), where q(β_u | R_u; λ) uses an inference network that takes as input row u of the rating matrix R. As in the VAE experiment, we use a three-layer feed-forward neural network; we use 300 hidden units for this experiment.

Notice that the above likelihood is exactly a one-layer feedforward neural network (without the bias term) that takes the latent representations drawn from the variational distribution q(β | R; λ) and outputs the rating matrix as a random matrix with a pointwise Poisson likelihood. Hence, we can view the model as a single-layer generative network and treat θ as its parameters. We have thus transformed variational matrix factorization into a task similar to the VAE. Hence, when we apply Algorithm 1 to this model, we can apply the same tricks used in the VAE experiments to accelerate it. We treat the whole model as a four-layer feedforward neural network and again apply the tridiagonal block-wise K-FAC approximation (Martens & Grosse, 2015) and adopt low-rank approximations of large matrices (again, more details are in the appendix). The results are shown in Figure 3.
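As a small illustration of this likelihood, the sketch below evaluates the pointwise Poisson log-likelihood with the softplus link and redraws a ratings matrix from the model, which is the predictive resampling the VPNG estimator relies on. The matrix sizes and random inputs are illustrative stand-ins rather than MovieLens data.

```python
import numpy as np

rng = np.random.default_rng(5)

n_users, n_movies, d = 8, 12, 4        # tiny illustrative sizes (the paper uses d = 100)

def softplus(t):
    return np.log1p(np.exp(-np.abs(t))) + np.maximum(t, 0.0)   # stable log(1 + e^t)

beta = rng.normal(size=(n_users, d))    # user representations (drawn from q in Algorithm 1)
theta = rng.normal(size=(n_movies, d))  # movie representations, treated as model parameters
rate = softplus(beta @ theta.T)         # single-layer "generative network" without bias

R = rng.poisson(rate)                   # stand-in ratings matrix

# Pointwise Poisson log-likelihood log p(R | theta, beta), up to the log(R!) constant.
log_lik = np.sum(R * np.log(rate) - rate)
print("log-likelihood (up to constants):", round(float(log_lik), 2))

# Predictive resampling used when estimating F_r: redraw a ratings matrix from the model.
R_new = rng.poisson(rate)
print("resampled ratings shape:", R_new.shape)
```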
We randomly split the data matrix R into train and test sets, where the train set contains 90% of the rows of R (i.e., ratings from 90% of the users) and the test set contains the remaining rows. The test ELBO values are computed over the randomly sampled test set, and the train ELBO values are computed over a fixed subset (with size equal to the test set size) of the whole train set. Since this dataset is larger, we use a batch size of 3000.

Figure 3. Variational matrix factorization learning curves on MovieLens 20M: ELBO against wall-clock time (seconds) and against the number of iterations for the vanilla gradient, NG, and VPNG.

As can be seen in this figure, the VPNG updates outperform the baseline optimizations on both the train and test learning curves. The curves look slightly different among various train/test splits of the dataset, but Algorithm 1 consistently outperforms the baseline methods. The difference stems from the correlations in the ratings of the movies. The traditional natural gradient performs the worst at the beginning since it is only guaranteed to perform well at the end (when q(z | x) is close to the posterior distribution p(z | x), as Equation 3 explains), but not necessarily at the beginning, because it does not consider potential curvature information in the model distribution. Across both of these experiments, we find that the VPNG dramatically improves estimation and inference at early iterations.

## 6. Conclusion

We introduced the variational predictive natural gradient. It adjusts for parameter dependencies in variational inference induced by correlations in the observations. We show how to approximate the Fisher information without manual model-specific computations. We demonstrate the insight on a bivariate Gaussian model and the empirical value on a classification model on synthetic data, a deep generative model of images, and matrix factorization for movie recommendation. Future work includes extending to general Bayesian networks with multiple stochastic layers.

## Acknowledgements

We want to thank Jaan Altosaar, Bharat Srikishan, Dawen Liang, and Scott Linderman for their helpful comments and suggestions on this paper.

## References

Amari, S.-I. Natural gradient works efficiently in learning. Neural Computation, 10(2):251–276, 1998.

Ba, J., Grosse, R., and Martens, J. Distributed second-order optimization using Kronecker-factored approximations. 2016.

Bowman, S. R., Vilnis, L., Vinyals, O., Dai, A., Jozefowicz, R., and Bengio, S. Generating sentences from a continuous space. In Proceedings of the 20th SIGNLL Conference on Computational Natural Language Learning, pp. 10–21, 2016.

Carbonetto, P., Stephens, M., et al. Scalable variational inference for Bayesian variable selection in regression, and its accuracy in genetic association studies. Bayesian Analysis, 7(1):73–108, 2012.

Fan, K., Wang, Z., Beck, J., Kwok, J., and Heller, K. A. Fast second order stochastic backpropagation for variational inference. In Advances in Neural Information Processing Systems, pp. 1387–1395, 2015.

Gopalan, P., Hofman, J. M., and Blei, D. M. Scalable recommendation with hierarchical Poisson factorization. In UAI, pp. 326–335, 2015.

Grosse, R. and Martens, J. A Kronecker-factored approximate Fisher matrix for convolution layers. In International Conference on Machine Learning, pp. 573–582, 2016.

Harper, F. M. and Konstan, J. A.
The MovieLens datasets: History and context. ACM Transactions on Interactive Intelligent Systems (TiiS), 5(4):19, 2016.

Harrison, L. M. and Green, G. G. A Bayesian spatiotemporal model for very large data sets. NeuroImage, 50(3):1126–1141, 2010.

Hoffman, M., Blei, D., Wang, C., and Paisley, J. Stochastic variational inference. The Journal of Machine Learning Research, 14(1):1303–1347, 2013.

Jordan, M., Ghahramani, Z., Jaakkola, T., and Saul, L. An introduction to variational methods for graphical models. Machine Learning, 37(2):183–233, 1999.

Kingma, D. and Welling, M. Auto-encoding variational Bayes. In International Conference on Learning Representations, 2014.

Kingma, D. P. and Ba, J. Adam: A method for stochastic optimization. arXiv preprint arXiv:1412.6980, 2014.

LeCun, Y., Bottou, L., Bengio, Y., and Haffner, P. Gradient-based learning applied to document recognition. Proceedings of the IEEE, 86(11):2278–2324, 1998.

Liang, D., Charlin, L., McInerney, J., and Blei, D. M. Modeling user exposure in recommendation. In Proceedings of the 25th International Conference on World Wide Web, pp. 951–961. International World Wide Web Conferences Steering Committee, 2016.

Manning, J. R., Ranganath, R., Norman, K. A., and Blei, D. M. Topographic factor analysis: A Bayesian model for inferring brain networks from neural data. PLoS ONE, 9(5):e94914, 2014.

Martens, J. and Grosse, R. Optimizing neural networks with Kronecker-factored approximate curvature. In International Conference on Machine Learning, pp. 2408–2417, 2015.

Miao, Y., Yu, L., and Blunsom, P. Neural variational inference for text processing. In International Conference on Machine Learning, pp. 1727–1736, 2016.

Mnih, A. and Gregor, K. Neural variational inference and learning in belief networks. arXiv preprint arXiv:1402.0030, 2014.

Mnih, A. and Salakhutdinov, R. R. Probabilistic matrix factorization. In Advances in Neural Information Processing Systems, pp. 1257–1264, 2008.

Ollivier, Y., Arnold, L., Auger, A., and Hansen, N. Information-geometric optimization algorithms: A unifying picture via invariance principles. arXiv preprint arXiv:1106.3708, 2011.

Ranganath, R., Gerrish, S., and Blei, D. Black box variational inference. In Artificial Intelligence and Statistics, pp. 814–822, 2014.

Ranganath, R., Perotte, A., Elhadad, N., and Blei, D. Deep survival analysis. In Machine Learning for Healthcare Conference, pp. 101–114, 2016.

Regier, J., Jordan, M. I., and McAuliffe, J. Fast black-box variational inference through stochastic trust-region optimization. In Advances in Neural Information Processing Systems, pp. 2399–2408, 2017.

Rezende, D., Mohamed, S., and Wierstra, D. Stochastic backpropagation and approximate inference in deep generative models. In International Conference on Machine Learning, pp. 1278–1286, 2014.

Shi, T., Tang, D., Xu, L., and Moscibroda, T. Correlated compressive sensing for networked data. In UAI, pp. 722–731, 2014.

Stegle, O., Parts, L., Durbin, R., and Winn, J. A Bayesian framework to account for complex non-genetic factors in gene expression levels greatly increases power in eQTL studies. PLoS Computational Biology, 6(5):e1000770, 2010.

Thomas, P., Silva, B. C., Dann, C., and Brunskill, E. Energetic natural gradient descent. In International Conference on Machine Learning, pp. 2887–2895, 2016.

Tieleman, T. and Hinton, G. Lecture 6.5-rmsprop: Divide the gradient by a running average of its recent magnitude.
COURSERA: Neural Networks for Machine Learning, 4(2):26–31, 2012.

Titsias, M. and Lázaro-Gredilla, M. Doubly stochastic variational Bayes for non-conjugate inference. In International Conference on Machine Learning, pp. 1971–1979, 2014.