# understanding_stochastic_natural_gradient_variational_inference__ad53aa16.pdf

Understanding Stochastic Natural Gradient Variational Inference

Kaiwen Wu¹, Jacob R. Gardner¹

Stochastic natural gradient variational inference (NGVI) is a popular posterior inference method with applications in various probabilistic models. Despite its wide usage, little is known about its non-asymptotic convergence rate in the stochastic setting. We aim to narrow this gap and provide a better understanding. For conjugate likelihoods, we prove the first $O(\frac{1}{T})$ non-asymptotic convergence rate of stochastic NGVI. The complexity is no worse than that of stochastic gradient descent (a.k.a. black-box variational inference), and the rate likely has a better constant dependency that leads to faster convergence in practice. For non-conjugate likelihoods, we show that stochastic NGVI with the canonical parameterization implicitly optimizes a non-convex objective. Thus, a global convergence rate of $O(\frac{1}{T})$ is unlikely without some significant new understanding of optimizing the ELBO using natural gradients.

## 1. Introduction

Given a prior $p(z)$ and a likelihood $p(y \mid z)$, variational inference (VI) approximates the posterior $p(z \mid y)$ by optimizing the evidence lower bound (ELBO) over a family of variational distributions (Blei et al., 2017). Natural gradient variational inference (NGVI), in particular, optimizes the ELBO by natural gradient descent (NGD) (Amari, 1998). Unlike (standard) gradient descent, which follows the steepest descent direction induced by the Euclidean distance, NGD follows the steepest descent direction induced by the KL divergence (Honkela & Valpola, 2004; Hensman et al., 2012; Hoffman et al., 2013). The folk wisdom is that the KL divergence is a better metric for comparing distributions, and thus NGD is believed to be superior to gradient descent, a.k.a.
black-box variational inference (Ranganath et al., 2014). Indeed, NGVI and its variants empirically outperform gradient descent in many cases, and thus enjoy applications in a wide range of probabilistic models. Here, we name a few examples: latent Dirichlet allocation topic models (Hoffman et al., 2013), Bayesian neural networks (Khan et al., 2018; Osawa et al., 2019), and large-scale Gaussian processes (Hensman et al., 2013; 2015; Salimbeni et al., 2018).

Despite its wide usage, a non-asymptotic convergence rate of NGVI in the stochastic setting is absent, even for simple conjugate likelihoods. A few convergence arguments exist in the literature, but none of them applies to practical uses of NGVI. For example, Hoffman et al. (2013) give a convergence argument (see footnote 1) by assuming that the eigenvalues of the Fisher information matrix are bounded below by a positive constant throughout the natural gradient updates. Khan et al. (2016) analyze a variant of NGVI based on Bregman proximal gradient descent by assuming that the (KL) divergence is α-strongly convex, a condition that generally does not hold (at least for the KL divergence). Moreover, Khan et al. (2016) do not obtain a complexity bound in the stochastic setting; they only show convergence to a region around stationary points. Note that these assumptions provably do not hold on the entire domain.

¹Department of Computer and Information Science, University of Pennsylvania, Philadelphia, United States. Correspondence to: Kaiwen Wu, Jacob R. Gardner. Proceedings of the 41st International Conference on Machine Learning, Vienna, Austria. PMLR 235, 2024. Copyright 2024 by the author(s).
Even if they hold on a subset of the domain, the constants in these assumptions are difficult to estimate, and might even be arbitrarily bad as the posterior distribution $p(z \mid y)$ contracts (see footnote 2).

Footnote 1: Hoffman et al. (2013) did not give a convergence proof besides a reference to Bottou (1998).

Footnote 2: For instance, the Fisher information matrix gets increasingly close to singular as the covariance of the posterior $p(z \mid y)$ shrinks.

This work aims to narrow this gap and obtain a clean analysis, with minimal assumptions, that is applicable to some practical uses of stochastic NGVI. For the sake of generality, existing analyses have to use assumptions that do not hold in practice. Therefore, we pursue the opposite direction of generality: the basic setting of conjugate likelihoods, for which we establish the first $O(\frac{1}{T})$ non-asymptotic convergence rate of stochastic natural gradient variational inference. This rate has the same complexity as the convergence rate of stochastic projected (and proximal) gradient descent recently studied by Domke (2020); Domke et al. (2023); Kim et al. (2023). This, along with our experiments, implies that NGVI may ultimately share the same complexity as other first-order methods. The empirical observation that NGVI is faster than stochastic gradient descent is likely due to a better constant in the big O notation. Indeed, as we will see later, our convergence rate of stochastic NGVI is independent of the objective's condition number and of the distance from the initialization to the optimum. Nevertheless, this constant improvement may make a huge difference in practice.

Although our convergence rate for stochastic NGVI assumes conjugate likelihoods, it is already applicable to some practical uses, including large-scale Bayesian linear regression and variational parameter learning in stochastic variational Gaussian processes (Hensman et al., 2013; 2015; Salimbeni et al., 2018).
Indeed, we will show that all assumptions are strictly satisfied in practice and that the constant in the convergence rate can be bounded explicitly using statistics of the training data.

For non-conjugate likelihoods, we show that the canonical implementation of stochastic NGVI implicitly optimizes a non-convex objective, even when the likelihoods are simple log-concave distributions. Hence, the convergence behavior of stochastic NGVI with non-conjugate likelihoods is more nuanced, which might partially explain why a theoretical understanding of stochastic NGVI has been lacking over the years. This lack of convexity implies that proving a global convergence rate of $O(\frac{1}{T})$ for non-conjugate likelihoods may require new properties of the ELBO, e.g., the Polyak-Łojasiewicz inequality (Polyak et al., 1963; Lojasiewicz, 1963), in order to explain the empirical success of stochastic NGVI for non-conjugate likelihoods (e.g., Hoffman et al., 2013; Salimbeni et al., 2018).

## 2. Background

Notation. We use $\|\cdot\|$ to denote the Euclidean norm of a vector. For matrices, the same symbol is overloaded to denote the spectral norm. $\|\cdot\|_F$ denotes the Frobenius norm. $\langle \cdot, \cdot \rangle$ denotes an inner product, whose domain is inferred from its arguments. Let $D_{\mathrm{KL}}(\cdot, \cdot)$ denote the Kullback-Leibler divergence between distributions. $\mathbb{S}^d_{++}$ (and $\mathbb{S}^d_{+}$) represents the collection of all $d \times d$ symmetric positive (semi-)definite matrices. Let $\succ$ (and $\succeq$) be the partial order induced by $\mathbb{S}^d_{++}$ (and $\mathbb{S}^d_{+}$), i.e., $A \succ B$ if and only if $A - B \in \mathbb{S}^d_{++}$.

### 2.1. Variational Inference with Exponential Families

Suppose we have a prior $p(z)$ on latent variables $z$ and a likelihood $p(y \mid z)$ on observations $y$. Variational inference (VI) aims to find the best approximation of the posterior $p(z \mid y)$ inside a variational family $\mathcal{Q}$ by minimizing the Kullback-Leibler (KL) divergence

$$\min_{q \in \mathcal{Q}} D_{\mathrm{KL}}(q(z), p(z \mid y)),$$

where $q$ is the variational distribution.
This is equivalent to minimizing the objective

$$\ell(q) = -\mathbb{E}_{q(z)}[\log p(y \mid z)] + D_{\mathrm{KL}}(q(z), p(z)), \tag{1}$$

which is called the negative evidence lower bound (ELBO). Throughout the paper, we assume that the variational family $\mathcal{Q}$ is an exponential family (defined below) and that the prior $p(z)$ is in $\mathcal{Q}$. The posterior $p(z \mid y)$, however, is not necessarily in $\mathcal{Q}$ unless the likelihood is conjugate: we call the likelihood $p(y \mid z)$ conjugate (with the prior) if and only if $p(z \mid y) \in \mathcal{Q}$. Conjugacy implies that the variational approximation is exact, as long as (1) is minimized globally.

Exponential Family. A (regular and minimal) exponential family is a collection of distributions indexed by a canonical parameter $\eta$ in the form

$$q(z; \eta) = h(z) \exp\big(\langle \phi(z), \eta \rangle - A(\eta)\big), \tag{2}$$

where $h$ is the base measure, $\phi$ is the sufficient statistic, $\eta$ is the natural parameter, and $A$ is the log-partition function. The set of all $\eta$ that make $q(z; \eta)$ integrable forms an open convex set $\mathcal{D}$, called the natural parameter space. The log-partition function $A : \mathcal{D} \to \mathbb{R}$ is differentiable and strictly convex on $\mathcal{D}$. The associated expectation parameter $\omega$ of $q(z; \eta)$ is defined as the expected sufficient statistic:

$$\omega = \mathbb{E}_{q(z;\eta)}[\phi(z)]. \tag{3}$$

The set of all expectation parameters $\omega$ again forms a convex set $\Omega$, called the expectation parameter space. The natural and expectation parameter spaces, $\mathcal{D}$ and $\Omega$, are linked by the gradients of the log-partition function $A$ and its convex conjugate $A^*$, where the differentiable and strictly convex function $A^* : \Omega \to \mathbb{R}$ is defined as $A^*(\omega) = \max_{\eta \in \mathcal{D}} \langle \eta, \omega \rangle - A(\eta)$. Indeed, the gradient maps $\nabla A : \mathcal{D} \to \Omega$ and $\nabla A^* : \Omega \to \mathcal{D}$ are inverses of each other. Namely, if $\eta \in \mathcal{D}$ and $\omega \in \Omega$ satisfy (3), i.e., represent the same distribution, then

$$\nabla A(\eta) = \omega, \qquad \nabla A^*(\omega) = \eta. \tag{4}$$

Example. A $d$-dimensional Gaussian distribution $\mathcal{N}(\mu, \Sigma)$ has its natural parameter $\eta = (\lambda, \Lambda)$ defined on

$$\mathcal{D} = \{\eta = (\lambda, \Lambda) \in \mathbb{R}^d \times \mathbb{R}^{d \times d} : -\Lambda \in \mathbb{S}^d_{++}\} \tag{5}$$

with the parameter conversion identity $\lambda = \Sigma^{-1}\mu$ and $\Lambda = -\frac{1}{2}\Sigma^{-1}$.
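The conversions between the moment parameters $(\mu, \Sigma)$, the natural parameter $\eta = (\lambda, \Lambda)$, and the expectation parameter $\omega = (\xi, \Xi)$ with $\xi = \mu$ and $\Xi = \Sigma + \mu\mu^\top$ can be sketched in NumPy as follows. This is a minimal illustration assuming dense matrices and small $d$; the function names are hypothetical, and the two composite maps play the roles of $\nabla A$ and $\nabla A^*$ in (4).

```python
import numpy as np

def moments_to_natural(mu, Sigma):
    # eta = (lam, Lam) with lam = Sigma^{-1} mu and Lam = -1/2 Sigma^{-1}
    prec = np.linalg.inv(Sigma)
    return prec @ mu, -0.5 * prec

def natural_to_moments(lam, Lam):
    # Sigma = (-2 Lam)^{-1} and mu = Sigma lam
    Sigma = np.linalg.inv(-2.0 * Lam)
    return Sigma @ lam, Sigma

def moments_to_expectation(mu, Sigma):
    # omega = (xi, Xi) with xi = mu and Xi = Sigma + mu mu^T
    return mu, Sigma + np.outer(mu, mu)

def expectation_to_moments(xi, Xi):
    return xi, Xi - np.outer(xi, xi)

def grad_A(lam, Lam):
    # natural -> expectation parameters
    return moments_to_expectation(*natural_to_moments(lam, Lam))

def grad_A_star(xi, Xi):
    # expectation -> natural parameters
    return moments_to_natural(*expectation_to_moments(xi, Xi))

def in_natural_space(lam, Lam):
    # D: the matrix component Lam must be negative definite
    return bool(np.all(np.linalg.eigvalsh(Lam) < 0))

def in_expectation_space(xi, Xi):
    # Omega: Xi - xi xi^T must be positive definite
    return bool(np.all(np.linalg.eigvalsh(Xi - np.outer(xi, xi)) > 0))

mu = np.array([1.0, -2.0])
Sigma = np.array([[2.0, 0.3], [0.3, 1.0]])
lam, Lam = moments_to_natural(mu, Sigma)
xi, Xi = grad_A(lam, Lam)
lam_back, Lam_back = grad_A_star(xi, Xi)
```

Round-tripping through both maps returns the original natural parameter, which is the content of (4) for Gaussians.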
The expectation parameter $\omega$ is defined on

$$\Omega = \{\omega = (\xi, \Xi) \in \mathbb{R}^d \times \mathbb{R}^{d \times d} : \Xi - \xi\xi^\top \in \mathbb{S}^d_{++}\} \tag{6}$$

with the identity $\xi = \mu$ and $\Xi = \Sigma + \mu\mu^\top$. See Appendix A for more details, where we give explicit expressions for $A(\eta)$, $A^*(\omega)$, $\nabla A(\eta)$, and $\nabla A^*(\omega)$. For more background on exponential families, we direct readers to the monograph by Wainwright & Jordan (2008).

Remark 1 (Overloading $\ell$). Let $\eta$ and $\omega$ be the natural and expectation parameters of the same distribution $q$. We have $\ell(q) = \ell^{(n)}(\eta) = \ell^{(e)}(\omega)$, where $\ell^{(n)}$ and $\ell^{(e)}$ are the negative ELBO as functions of the natural and expectation parameters, respectively. Technically, $\ell$, $\ell^{(n)}$, and $\ell^{(e)}$ are different functions with different domains and arguments. For notational simplicity, however, we drop the superscript when the context allows: the superscript is inferred from the argument.

### 2.2. Natural Gradient Descent for Variational Inference

Natural gradient variational inference (NGVI) optimizes the ELBO by natural gradient descent (NGD). It iteratively updates the natural parameter $\eta$ of the variational distribution by taking a steepest descent step induced by the KL divergence. The resulting update rule is gradient descent preconditioned by the Fisher information matrix (FIM).

Definition 1 (FIM). Given a (not necessarily exponential family) distribution $q(z; \eta)$ parameterized by $\eta$, the Fisher information matrix is defined as $F(\eta) = -\mathbb{E}_{q(z)}[\nabla^2_\eta \log q(z; \eta)]$, where the Hessian is taken w.r.t. $\eta$. In particular, for the exponential family (2), it takes the simple form $F(\eta) = \nabla^2 A(\eta)$.

Definition 2 (NGD). Natural gradient descent iterates

$$\eta_{t+1} = \eta_t - \gamma_t F(\eta_t)^{-1} \nabla \ell(\eta_t), \tag{7}$$

where $\gamma_t$ is the step size and $F(\eta)$ is the Fisher information matrix. The FIM-preconditioned gradient $F(\eta_t)^{-1}\nabla\ell(\eta_t)$ is often called the natural gradient. We call (7) the canonical NGD update, as NGD can also be implemented in parameterizations other than the natural parameter $\eta$.
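A small numerical check of Definitions 1 and 2 for a one-dimensional Gaussian $q = \mathcal{N}(\mu, s^2)$, with $\eta = (\mu/s^2, -1/(2s^2))$ and sufficient statistic $\phi(z) = (z, z^2)$, may help make them concrete. The sketch compares the finite-difference Hessian of $A$ with the covariance of $\phi(z)$, and then compares the FIM-preconditioned gradient in $\eta$ with the plain gradient in $\omega$. It assumes a hypothetical conjugate model (prior $\mathcal{N}(0, 1)$, a single observation $y \sim \mathcal{N}(z, \sigma^2)$) with illustrative numbers, and it uses finite differences rather than automatic differentiation.

```python
import numpy as np

# Hypothetical 1-D conjugate model: prior z ~ N(0, 1), one observation y | z ~ N(z, sigma2)
y, sigma2 = 1.3, 0.5

def log_partition(eta):
    # A(eta) for a 1-D Gaussian, eta = (lam, Lam); the constant 0.5*log(2*pi) is dropped
    lam, Lam = eta
    return -lam**2 / (4.0 * Lam) - 0.5 * np.log(-2.0 * Lam)

def grad_A(eta):
    # eta -> omega = (E[z], E[z^2])
    lam, Lam = eta
    s2 = -0.5 / Lam
    mu = s2 * lam
    return np.array([mu, s2 + mu**2])

def neg_elbo_omega(omega):
    # l(omega) = -E_q[log p(y|z)] + KL(q || N(0, 1)), in expectation parameters
    xi, Xi = omega
    nll = (y**2 - 2.0 * y * xi + Xi) / (2.0 * sigma2) + 0.5 * np.log(2.0 * np.pi * sigma2)
    kl = 0.5 * (Xi - 1.0 - np.log(Xi - xi**2))
    return nll + kl

def fd_grad(f, x, h=1e-6):
    # central finite-difference gradient
    g = np.zeros(x.size)
    for i in range(x.size):
        e = np.zeros(x.size)
        e[i] = h
        g[i] = (f(x + e) - f(x - e)) / (2.0 * h)
    return g

def fd_hessian(f, x, h=1e-4):
    # central finite-difference Hessian
    k = x.size
    E = h * np.eye(k)
    H = np.zeros((k, k))
    for i in range(k):
        for j in range(k):
            H[i, j] = (f(x + E[i] + E[j]) - f(x + E[i] - E[j])
                       - f(x - E[i] + E[j]) + f(x - E[i] - E[j])) / (4.0 * h**2)
    return H

eta = np.array([0.4, -0.8])                  # q = N(0.25, 0.625)
mu, Xi = grad_A(eta)
s2 = Xi - mu**2
F_cov = np.array([[s2, 2.0 * mu * s2],       # Cov of (z, z^2) under q
                  [2.0 * mu * s2, 2.0 * s2**2 + 4.0 * mu**2 * s2]])
F_fd = fd_hessian(log_partition, eta)        # Hessian of the log-partition function

# natural gradient F(eta)^{-1} grad_eta l(eta), where l(eta) = l(grad A(eta))
nat_grad = np.linalg.solve(F_cov, fd_grad(lambda e: neg_elbo_omega(grad_A(e)), eta))
grad_omega = fd_grad(neg_elbo_omega, grad_A(eta))   # plain gradient w.r.t. omega
```

The two gradients agree to finite-difference accuracy, which previews why no explicit FIM inversion is needed for exponential families.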
NGD typically converges fastest in the natural parameterization, which is the main focus of this paper; we revisit other parameterizations in Section 5. Explicitly inverting the FIM is inefficient, e.g., it takes $O(d^6)$ time for a Gaussian due to its $d(d+1) \times d(d+1)$ size. Fortunately, for exponential families the NGD update can be implemented without explicit FIM inversion (Raskutti & Mukherjee, 2015). Let $\eta$ and $\omega$ be the natural and expectation parameters of the same distribution, hence having the same ELBO value $\ell^{(n)}(\eta) = \ell^{(e)}(\omega)$. Plugging in the identity $\omega = \nabla A(\eta)$ from (4), we obtain $\ell^{(n)}(\eta) = \ell^{(e)}(\nabla A(\eta))$. Differentiating both sides w.r.t. $\eta$ gives

$$\nabla \ell^{(n)}(\eta) = \nabla^2 A(\eta)\, \nabla \ell^{(e)}(\nabla A(\eta)) = F(\eta)\, \nabla \ell^{(e)}(\omega),$$

which implies $F(\eta)^{-1}\nabla\ell^{(n)}(\eta) = \nabla\ell^{(e)}(\omega)$: the natural gradient (w.r.t. the natural parameter) is simply the gradient of the negative ELBO w.r.t. the expectation parameter. Thus, the NGD update rule (7) reduces to

$$\eta_{t+1} = \eta_t - \gamma_t \nabla\ell(\omega_t), \tag{8}$$

with no explicit FIM inversion.

### 2.3. Natural Gradient Descent as Mirror Descent

This section reviews the connection between NGD and mirror descent (MD). This connection was discovered by Raskutti & Mukherjee (2015) and later applied to variational inference by Khan & Lin (2017); Khan et al. (2018).

Definition 3 (Bregman Divergence). Given a differentiable and strictly convex function $\Phi$, the associated Bregman divergence is defined as $D_\Phi(a, b) = \Phi(a) - \Phi(b) - \langle \nabla\Phi(b), a - b \rangle$, and we call $\Phi$ the distance generating function.

Recall that $A^*$, as the convex conjugate of the log-partition function $A$, is differentiable and strictly convex on $\Omega$. Thus, it is a valid distance generating function and induces a Bregman divergence $D_{A^*}$ on the expectation parameters. This divergence is then used to define mirror descent (Nemirovskij & Yudin, 1983), which iteratively minimizes regularized first-order approximations.

Definition 4 (MD). The mirror descent update is defined as

$$\omega_{t+1} = \arg\min_{\omega \in \Omega}\ \langle \nabla\ell(\omega_t), \omega \rangle + \frac{1}{\gamma_t} D_{A^*}(\omega, \omega_t), \tag{9}$$

where $\gamma_t > 0$ is the step size.
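For Gaussians, the Bregman divergence induced by $A^*$ coincides with the KL divergence, $D_{A^*}(\omega, \omega') = D_{\mathrm{KL}}(q_\omega, q_{\omega'})$; this is a standard exponential-family identity that ties the regularizer in (9) to the KL geometry of NGD. The sketch below checks it numerically in two dimensions, assuming $A^*(\omega) = -\frac{1}{2}\log\det(\Xi - \xi\xi^\top)$ up to an additive constant (the negative entropy, whose gradient is $\eta$). The parameter values and function names are illustrative.

```python
import numpy as np

def to_natural(xi, Xi):
    # grad A*: omega -> eta = (Sigma^{-1} mu, -1/2 Sigma^{-1})
    prec = np.linalg.inv(Xi - np.outer(xi, xi))
    return prec @ xi, -0.5 * prec

def A_star(xi, Xi):
    # negative entropy of N(xi, Xi - xi xi^T), up to an additive constant
    return -0.5 * np.linalg.slogdet(Xi - np.outer(xi, xi))[1]

def bregman_A_star(w1, w2):
    # D_{A*}(w1, w2) = A*(w1) - A*(w2) - <grad A*(w2), w1 - w2>
    lam2, Lam2 = to_natural(*w2)
    inner = lam2 @ (w1[0] - w2[0]) + np.sum(Lam2 * (w1[1] - w2[1]))
    return A_star(*w1) - A_star(*w2) - inner

def kl_gauss(mu1, S1, mu2, S2):
    # closed-form KL(N(mu1, S1) || N(mu2, S2))
    d = mu1.size
    S2inv = np.linalg.inv(S2)
    diff = mu2 - mu1
    return 0.5 * (np.trace(S2inv @ S1) + diff @ S2inv @ diff - d
                  + np.linalg.slogdet(S2)[1] - np.linalg.slogdet(S1)[1])

mu1, S1 = np.array([0.5, -1.0]), np.array([[1.0, 0.2], [0.2, 0.5]])
mu2, S2 = np.array([0.0, 0.3]), np.array([[2.0, -0.4], [-0.4, 1.5]])
w1 = (mu1, S1 + np.outer(mu1, mu1))          # expectation parameters of q1
w2 = (mu2, S2 + np.outer(mu2, mu2))          # expectation parameters of q2
```

The additive constant in $A^*$ cancels inside the Bregman divergence, so dropping it is harmless.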
Mirror descent (MD) is a generalization of gradient descent. If the Bregman divergence $D_{A^*}(\omega, \omega_t)$ in (9) is replaced with the squared Euclidean distance $\frac{1}{2}\|\omega - \omega_t\|^2$, we recover the familiar update rule $\omega_{t+1} = \omega_t - \gamma_t \nabla\ell(\omega_t)$. To implement MD efficiently, the minimization in (9) needs to be solved in closed form. Taking the derivative of the objective in (9) w.r.t. $\omega$ and setting it equal to zero, we obtain

$$\nabla A^*(\omega_{t+1}) = \nabla A^*(\omega_t) - \gamma_t \nabla\ell(\omega_t). \tag{10}$$

Recall that $\nabla A^* : \Omega \to \mathcal{D}$ maps the expectation parameter $\omega$ of an exponential family distribution $q$ to its natural parameter $\eta$, i.e., $\nabla A^*(\omega_t) = \eta_t$ for all $t \geq 0$. Thus, the MD update (10) recovers the NGD update (8) exactly. The discussion in this section so far is summarized in the lemma below. In particular, we will not distinguish between NGD and MD in the rest of the paper.

Lemma 1 (NGD = MD). Suppose the NGD update (7) and the MD update (9) start from the same variational distribution $q_0$, i.e., $\eta_0 = \nabla A^*(\omega_0)$. Then, $\eta_t = \nabla A^*(\omega_t)$ for all $t \geq 0$. Namely, NGD and MD produce exactly the same sequence of variational distributions.

We introduce a few definitions useful for later proofs. Our results in Section 4 are built upon casting NGD as a special case of MD and utilizing recent developments of stochastic MD for relatively smooth and relatively strongly convex functions (Birnbaum et al., 2011; Bauschke et al., 2017; Lu et al., 2018; Hanzely & Richtárik, 2021).

Definition 5. Let $\Phi$ be a differentiable and strictly convex function. A function $f$ is called $\beta$-smooth relative to $\Phi$ if $f(a) \leq f(b) + \langle \nabla f(b), a - b \rangle + \beta D_\Phi(a, b)$ holds for all $a, b$ in the domain. A function $f$ is called $\alpha$-strongly convex relative to $\Phi$ if $f(a) \geq f(b) + \langle \nabla f(b), a - b \rangle + \alpha D_\Phi(a, b)$ holds for all $a, b$ in the domain.

Relative smoothness and relative strong convexity recover the usual definitions of smoothness and strong convexity when $\Phi(\cdot) = \frac{1}{2}\|\cdot\|^2$.
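The closed-form step (10) can be checked directly against the minimization problem (9). The following one-dimensional sketch assumes $A^*(\omega) = -\frac{1}{2}\log(\Xi - \xi^2)$ up to a constant, and uses an arbitrary illustrative vector $g$ in place of $\nabla\ell(\omega_t)$. It computes $\omega_{t+1}$ by stepping in natural parameters and mapping back with $\nabla A$, then confirms that random feasible perturbations of $\omega_{t+1}$ only increase the objective of (9).

```python
import numpy as np

def A_star(omega):
    # convex conjugate of A for a 1-D Gaussian, up to an additive constant
    xi, Xi = omega
    return -0.5 * np.log(Xi - xi**2)

def grad_A_star(omega):
    # omega -> eta = (mu / s2, -1 / (2 s2))
    xi, Xi = omega
    s2 = Xi - xi**2
    return np.array([xi / s2, -0.5 / s2])

def grad_A(eta):
    # eta -> omega = (mu, s2 + mu^2)
    lam, Lam = eta
    s2 = -0.5 / Lam
    return np.array([s2 * lam, s2 + (s2 * lam) ** 2])

def bregman(a, b):
    return A_star(a) - A_star(b) - grad_A_star(b) @ (a - b)

def md_objective(omega, omega_t, g, gamma):
    # objective of the mirror descent subproblem (9)
    return g @ omega + bregman(omega, omega_t) / gamma

omega_t = np.array([0.5, 2.25])     # q_t = N(0.5, 2)
g = np.array([0.3, 0.1])            # illustrative gradient at omega_t
gamma = 0.5

# closed form (10): step in natural parameters, then map back with grad A
omega_next = grad_A(grad_A_star(omega_t) - gamma * g)

rng = np.random.default_rng(0)
best = md_objective(omega_next, omega_t, g, gamma)
perturbed = [md_objective(omega_next + 1e-2 * rng.standard_normal(2), omega_t, g, gamma)
             for _ in range(100)]
```

Because the objective of (9) is strictly convex in $\omega$, its stationary point given by (10) is the unique minimizer, which is what the perturbation test reflects.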
## 3. Stochastic Natural Gradient VI

This section discusses the implementation of natural gradient variational inference in the stochastic setting. Two types of stochasticity may arise in practice: (a) the expected log likelihood $\mathbb{E}_{q(z)}[\log p(y \mid z)]$ in the ELBO (1) does not have a closed form for most non-conjugate likelihoods due to the intractable integral, so one needs to estimate it stochastically (see footnote 3); (b) the expected log likelihood is a finite sum over a large number of training data points, and one needs to employ mini-batch stochastic optimization, e.g., Example 1.

Footnote 3: Note that the expected log likelihood does have a closed form for some non-conjugate likelihoods; see Section 5 for an example.

Care is required when implementing the update rule (8), or equivalently the update rule (10), in the stochastic setting. Recall that the natural parameter $\eta$ is constrained to the domain $\mathcal{D}$. In the stochastic setting, implementing the update rule (8) with a stochastic gradient $\widehat{\nabla}\ell(\omega_t) \approx \nabla\ell(\omega_t)$ does not necessarily keep $\eta_{t+1}$ inside the domain $\mathcal{D}$, in which case NGD breaks down.

### 3.1. A Sufficient Condition for Valid NGD Updates

We give a sufficient condition on the stochastic gradient $\widehat{\nabla}\ell(\omega_t)$ that guarantees that the natural parameter $\eta$ always stays inside the domain $\mathcal{D}$. As shown in Appendix A, the KL divergence has a closed-form gradient w.r.t. the expectation parameter: $\nabla_\omega D_{\mathrm{KL}}(q(z), p(z)) = \eta - \eta_p$, where $\eta$ and $\eta_p$ are the natural parameters of $q(z)$ and $p(z)$, respectively. Thus, the stochasticity comes solely from estimating the expected log likelihood:

$$\widehat{\nabla}\ell(\omega) = -\widehat{\nabla}_\omega \mathbb{E}_{q(z)}[\log p(y \mid z)] + \eta - \eta_p.$$

Plugging it into the NGD update (8), we obtain

$$\eta_{t+1} = (1 - \gamma_t)\,\eta_t + \gamma_t\big(\widehat{\nabla}_\omega \mathbb{E}_{q_t(z)}[\log p(y \mid z)] + \eta_p\big).$$

Recall that the natural parameter space $\mathcal{D}$ is an open convex set. Hence, $\eta_{t+1}$ stays in $\mathcal{D}$ provided that (a) $\gamma_t \in [0, 1]$ and (b) $\widehat{\nabla}_\omega \mathbb{E}_{q_t(z)}[\log p(y \mid z)] + \eta_p \in \mathcal{D}$, since $\eta_{t+1}$ is then a convex combination of two points in $\mathcal{D}$. The first condition is satisfied if the step size $\gamma_t$ is chosen properly.
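A minimal sketch of this update for a Gaussian variational family and a Gaussian prior, where $\mathcal{D}$ only requires the matrix component of $\eta$ to be negative definite, illustrates both conditions. The gradient estimates below are made-up values chosen to show two regimes: with a negative semidefinite second component every $\gamma_t \in [0, 1]$ yields a valid $\eta_{t+1}$, while an indefinite one breaks the update once $\gamma_t$ is large enough.

```python
import numpy as np

def ngd_step(eta, g_hat, eta_p, gamma):
    # eta_{t+1} = (1 - gamma) * eta_t + gamma * (g_hat + eta_p), component-wise;
    # eta, g_hat, eta_p are (vector, matrix) pairs and g_hat estimates grad_omega E_q[log p(y|z)]
    return tuple((1 - gamma) * e + gamma * (g + p) for e, g, p in zip(eta, g_hat, eta_p))

def in_domain(eta):
    # Gaussian natural parameter space: the matrix component must be negative definite
    return bool(np.all(np.linalg.eigvalsh(eta[1]) < 0))

d = 2
eta_p = (np.zeros(d), -0.5 * np.eye(d))         # prior N(0, I)
eta_t = (np.zeros(d), -0.5 * np.eye(d))         # current q_t = N(0, I)
g_nsd = (np.ones(d), -np.diag([1.0, 0.0]))      # second component negative semidefinite
g_indef = (np.ones(d), np.diag([3.0, -1.0]))    # second component indefinite

gammas = np.linspace(0.0, 1.0, 11)
valid_nsd = [in_domain(ngd_step(eta_t, g_nsd, eta_p, gm)) for gm in gammas]
valid_indef = [in_domain(ngd_step(eta_t, g_indef, eta_p, gm)) for gm in gammas]
```

In the indefinite case the first diagonal entry of $\Lambda_{t+1}$ is $-0.5 + 3\gamma_t$, so the update leaves $\mathcal{D}$ for every $\gamma_t > 1/6$.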
The second condition is more complicated, but can still be satisfied by carefully constructed stochastic gradient estimators.

### 3.2. Common Stochastic Gradient Estimators

This section discusses a common special case: (a) the variational family $\mathcal{Q}$ is the collection of all Gaussians; (b) the prior $p(z)$ is Gaussian; and (c) the likelihood $p(y \mid z)$ is log-concave in $z$. In the following, we give two examples of stochastic gradients: one guarantees valid NGD updates, while the other does not. For Gaussians, the only constraint on the natural parameter $\eta = (\lambda, \Lambda)$ is that its second component is negative definite, $\Lambda \prec 0$. The sufficient condition for valid stochastic NGD updates in the previous section then reduces to the following:

Remark 2. Suppose that the variational distribution and the prior are both Gaussian. The NGD update (8) is valid for all $t \geq 0$ in the stochastic setting if (a) the step size $\gamma_t \in [0, 1]$ and (b) the stochastic gradient of the expected log likelihood $(\widehat{\xi}, \widehat{\Xi}) = \widehat{\nabla}_\omega \mathbb{E}_{q(z)}[\log p(y \mid z)]$ has a negative semidefinite second component, $\widehat{\Xi} \preceq 0$ (since the prior's second component is negative definite, their sum is then negative definite).

Automatic Differentiation. For intractable expected log likelihoods, a simple estimator of their gradients uses the reparameterization trick (Kingma & Welling, 2013; Titsias & Lázaro-Gredilla, 2014; Rezende et al., 2014) and automatic differentiation, as shown in Algorithm 1. This stochastic gradient $\widehat{\nabla}_\Xi \mathbb{E}_q[\log p(y \mid z)]$ is unbiased and symmetric (Murray, 2016), but $\widehat{\Xi}$ is not guaranteed to be negative semidefinite. A counterexample is given in Appendix B.

Bonnet's and Price's Gradients. Consider the gradients w.r.t. the mean and covariance of a Gaussian $\mathcal{N}(\mu, \Sigma)$. By the Bonnet and Price theorems (Bonnet, 1964; Price, 1958; Opper & Archambeau, 2009), they are

$$\nabla_\mu \mathbb{E}_{q(z)}[\log p(y \mid z)] = \mathbb{E}_{q(z)}[\nabla_z \log p(y \mid z)], \qquad \nabla_\Sigma \mathbb{E}_{q(z)}[\log p(y \mid z)] = \frac{1}{2}\mathbb{E}_{q(z)}[\nabla_z^2 \log p(y \mid z)].$$
Algorithm 1: Auto Differentiation Stochastic Gradient (written as a PyTorch function; `log_lik` is a user-supplied function returning $\log p(y \mid z)$)

```python
import torch

def autodiff_stochastic_gradient(xi, Xi, log_lik):
    """Input: omega = (xi, Xi), the expectation parameter of q(z).
    Output: (xi_hat, Xi_hat), a one-sample estimate of grad_omega E_q(z)[log p(y | z)]."""
    xi = xi.detach().requires_grad_(True)
    Xi = Xi.detach().requires_grad_(True)
    mu, Sigma = xi, Xi - torch.outer(xi, xi)   # conversion omega -> (mu, Sigma)
    C = torch.linalg.cholesky(Sigma)
    u = torch.randn_like(mu)
    z = mu + C @ u                             # z ~ N(mu, Sigma)
    loss = log_lik(z)                          # forward pass: log p(y | z)
    loss.backward()                            # backward pass fills xi.grad and Xi.grad
    return xi.grad, Xi.grad
```

Applying the chain rule through $\xi = \mu$ and $\Xi = \mu\mu^\top + \Sigma$, and approximating the expectations with samples, we obtain a stochastic gradient $\widehat{\nabla}_\omega \mathbb{E}_{q(z)}[\log p(y \mid z)]$:

$$\widehat{\xi} = \frac{1}{S}\sum_{i=1}^{S}\big[\nabla_z \log p(y \mid z_i) - \nabla_z^2 \log p(y \mid z_i)\,\mu\big], \qquad \widehat{\Xi} = \frac{1}{2S}\sum_{i=1}^{S}\nabla_z^2 \log p(y \mid z_i), \tag{11}$$

where $z_i \sim q$ are $S$ i.i.d. samples from the variational distribution. While the stochastic gradient $\widehat{\xi}$ in (11) coincides with the reparameterization trick, the second component differs from the stochastic gradient computed by automatic differentiation in Algorithm 1: $\widehat{\Xi}$ is negative semidefinite for all log-concave likelihoods (concave in $z$). Hence, (11) guarantees valid stochastic NGD updates provided that $\gamma_t \in [0, 1]$, and it often appears in the natural gradient variational inference literature (e.g., Khan et al., 2015; Khan & Lin, 2017; Zhang et al., 2018; Lin et al., 2020).

Additional Discussion. The main goal of this section is to point out the sufficient condition for valid NGD updates in the stochastic setting, as well as its special case, Remark 2. These observations, though simple, are prerequisites for the convergence of stochastic NGVI in Section 4. Moreover, the Bonnet and Price stochastic gradients will be used in the experiments in Section 6. We mention a few common workarounds that take advantage of automatic differentiation, even though naively applying automatic differentiation may break stochastic NGD. Numerous approximate NGD methods admit valid updates in the stochastic setting (Khan et al., 2018; Osawa et al., 2019; Lin et al., 2020), with some specifically addressing the constraint on the natural parameter (Lin et al., 2020).
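Returning to estimator (11), it can be sketched in NumPy for a log-concave likelihood. The example assumes a hypothetical logistic regression likelihood $\log p(y \mid z) = \sum_j \log\psi(y_j x_j^\top z)$ with labels $y_j \in \{-1, 1\}$ and sigmoid $\psi$, whose gradient and Hessian in $z$ are available in closed form; the data and the number of samples are illustrative. Since every sampled Hessian is negative semidefinite, so is the averaged $\widehat{\Xi}$.

```python
import numpy as np

def sigmoid(u):
    return 1.0 / (1.0 + np.exp(-u))

def log_lik_grad_hess(z, X, y):
    # gradient and Hessian in z of sum_j log sigmoid(y_j x_j^T z), labels y_j in {-1, +1}
    s = y * (X @ z)
    grad = X.T @ (y * sigmoid(-s))        # d/ds log sigmoid(s) = sigmoid(-s)
    w = sigmoid(s) * sigmoid(-s)          # -d^2/ds^2 log sigmoid(s) > 0
    hess = -(X.T * w) @ X                 # -sum_j w_j x_j x_j^T
    return grad, hess

def bonnet_price_grad(mu, Sigma, X, y, num_samples, rng):
    # estimator (11) of grad_omega E_q[log p(y|z)] with q = N(mu, Sigma)
    C = np.linalg.cholesky(Sigma)
    g_xi = np.zeros(mu.size)
    g_Xi = np.zeros((mu.size, mu.size))
    for _ in range(num_samples):
        z = mu + C @ rng.standard_normal(mu.size)     # z ~ N(mu, Sigma)
        g, H = log_lik_grad_hess(z, X, y)
        g_xi += g - H @ mu
        g_Xi += 0.5 * H
    return g_xi / num_samples, g_Xi / num_samples

rng = np.random.default_rng(0)
X = rng.standard_normal((50, 3))
y = np.where(rng.standard_normal(50) > 0, 1.0, -1.0)
g_xi, g_Xi = bonnet_price_grad(np.zeros(3), np.eye(3), X, y, num_samples=10, rng=rng)
```

Combined with $\gamma_t \in [0, 1]$, this second component satisfies condition (b) of Remark 2 by construction.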
Another workaround is to parameterize the variational distribution with an unconstrained parameter, e.g., the mean and a square root of the covariance. Refer to Salimbeni et al. (2018) for more examples of parameterizations. As a side effect, changing the parameterization also changes the ELBO landscape and may slow down convergence.

## 4. Convergence of Stochastic NGVI

Even though NGVI with exact gradients is known to converge in one step for conjugate likelihoods, it generally does not do so in the stochastic setting. This section establishes a convergence rate of stochastic NGVI for conjugate likelihoods. The main techniques we use are recent developments of stochastic mirror descent for relatively smooth and relatively strongly convex functions (Lu et al., 2018; Hanzely & Richtárik, 2021).

Definition 6 (Hanzely & Richtárik, 2021). Given the step sizes $\{\gamma_t\}_{t=0}^{\infty}$ and the iterates $\{\omega_t\}_{t=0}^{\infty}$ generated by the updates (9), we define the gradient variance at step $t$ as

$$\frac{1}{\gamma_t}\mathbb{E}\big[\langle \widehat{\nabla}\ell(\omega_t) - \nabla\ell(\omega_t), \omega_{t+1,*} - \omega_{t+1} \rangle \,\big|\, \omega_t\big], \tag{12}$$

where $\omega_{t+1,*} = \arg\min_{\omega \in \Omega} \langle \nabla\ell(\omega_t), \omega \rangle + \frac{1}{\gamma_t} D_{A^*}(\omega, \omega_t)$ and the conditional expectation is taken over the randomness of the stochastic gradient $\widehat{\nabla}\ell(\omega_t)$.

Note that the gradient variance (12) reduces to the familiar $\mathbb{E}\|\widehat{\nabla}\ell(\omega_t) - \nabla\ell(\omega_t)\|^2$ for gradient descent updates $\omega_{t+1,*} = \omega_t - \gamma_t\nabla\ell(\omega_t)$ and $\omega_{t+1} = \omega_t - \gamma_t\widehat{\nabla}\ell(\omega_t)$. For mirror descent, however, (12) is a generalization that does not depend on a norm. This norm independence is crucial for our setting. Common stochastic mirror descent analyses require the distance generating function $\Phi$ to be strongly convex w.r.t. a norm and then measure the gradient variance in the dual norm (e.g., Bubeck, 2015; Lan, 2020; Liu et al., 2023; Nguyen et al., 2023; Fatkhullin & He, 2024). However, as shown in Appendix A.1, the convex conjugate $A^*$ of the log-partition function is not strongly convex w.r.t. any norm, which prevents us from measuring the gradient variance with a norm.
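Quantity (12) can be estimated by Monte Carlo once both the exact step $\omega_{t+1,*}$ and the stochastic step $\omega_{t+1}$ are available. The sketch below does this for a hypothetical instance of the data sub-sampling setting of Example 1 (synthetic data, prior $\mathcal{N}(0, I)$, $q_t = \mathcal{N}(0, I)$), pairing gradients with expectation parameters through $\langle(a, A), (b, B)\rangle = a^\top b + \mathrm{tr}(A^\top B)$. Every sampled term is nonnegative because $\nabla A$ is a monotone map, which is why (12) behaves like a variance without requiring a norm.

```python
import numpy as np

rng = np.random.default_rng(1)
n, d, m, sigma2, gamma = 200, 2, 10, 1.0, 0.5
X = rng.standard_normal((n, d))                         # synthetic inputs (illustrative)
y = X @ np.array([1.0, -1.0]) + rng.standard_normal(n)
lam_p, Lam_p = np.zeros(d), -0.5 * np.eye(d)            # prior N(0, I)
lam_t, Lam_t = np.zeros(d), -0.5 * np.eye(d)            # current q_t = N(0, I)

def to_expectation(lam, Lam):
    # grad A: natural -> expectation parameters
    Sigma = np.linalg.inv(-2.0 * Lam)
    mu = Sigma @ lam
    return mu, Sigma + np.outer(mu, mu)

def lik_grad(idx, scale):
    # scale * grad_omega sum_{i in idx} E_q[log p(y_i | x_i, z)]; independent of omega here
    Xb, yb = X[idx], y[idx]
    return scale * Xb.T @ yb / sigma2, -scale * Xb.T @ Xb / (2.0 * sigma2)

def md_step(g):
    # eta_t - gamma * grad l = (1 - gamma) eta_t + gamma (g + eta_p), mapped to omega
    return to_expectation((1 - gamma) * lam_t + gamma * (g[0] + lam_p),
                          (1 - gamma) * Lam_t + gamma * (g[1] + Lam_p))

def inner(a, b):
    # <(a1, A1), (b1, B1)> = a1^T b1 + tr(A1^T B1)
    return a[0] @ b[0] + np.sum(a[1] * b[1])

g_full = lik_grad(np.arange(n), 1.0)
omega_star = md_step(g_full)                            # exact step omega_{t+1,*}
terms = []
for _ in range(2000):
    g_hat = lik_grad(rng.integers(0, n, size=m), n / m)
    omega_next = md_step(g_hat)                         # stochastic step omega_{t+1}
    # stochastic minus exact gradient of l equals g_full - g_hat in this model
    diff = (g_full[0] - g_hat[0], g_full[1] - g_hat[1])
    gap = (omega_star[0] - omega_next[0], omega_star[1] - omega_next[1])
    terms.append(inner(diff, gap) / gamma)
variance_12 = float(np.mean(terms))
```

With the Euclidean mirror map in place of $\nabla A$, each term would reduce to $\|\widehat{\nabla}\ell(\omega_t) - \nabla\ell(\omega_t)\|^2$.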
The absence of strong convexity in the distance generating function $A^*$ may partially explain why a precise convergence rate of stochastic natural gradient variational inference has not been developed over the years.

Lemma 2. For conjugate likelihoods, the negative ELBO $\ell(\omega)$ is 1-smooth and 1-strongly convex relative to the convex conjugate $A^*$ of the log-partition function.

Relative 1-smoothness and 1-strong convexity imply that the negative ELBO is a well-conditioned objective. Moreover, the regularized first-order approximation at an arbitrary $\omega_t \in \Omega$ is exact:

$$\ell(\omega) = \ell(\omega_t) + \langle \nabla\ell(\omega_t), \omega - \omega_t \rangle + D_{A^*}(\omega, \omega_t).$$

With the exact gradient $\nabla\ell(\omega)$, the mirror descent update (9), which minimizes this approximation, converges in one step with the step size $\gamma_t = 1$. However, one-step convergence is generally not possible in the stochastic setting.

Next, we present a general convergence rate that holds for all conjugate likelihoods, i.e., when the prior $p(z)$ and the likelihood $p(y \mid z)$ are chosen such that the posterior $p(z \mid y)$ is in the same exponential family as the prior.

Assumption 1. The stochastic gradient $\widehat{\nabla}\ell(\omega_t)$

1. respects the domain: $\eta_{t+1} \in \mathcal{D}$ for all $t \geq 0$ in (8);
2. is unbiased: $\mathbb{E}[\widehat{\nabla}\ell(\omega_t) \mid \omega_t] = \nabla\ell(\omega_t)$;
3. has bounded variance: (12) is bounded by $V > 0$.

Theorem 1. Suppose the likelihood $p(y \mid z)$ is conjugate and the stochastic gradient $\widehat{\nabla}\ell(\omega_t)$ satisfies Assumption 1. Running $T + 1$ iterations of stochastic natural gradient descent with $\gamma_t = \frac{2}{2+t}$ generates a point $\bar{\omega}_{T+1}$ that satisfies

$$\mathbb{E}[\ell(\bar{\omega}_{T+1})] - \min_{\omega \in \Omega} \ell(\omega) \leq \frac{V}{T + 2}, \tag{13}$$

where $\bar{\omega}_{T+1} = \frac{2}{(T+1)(T+2)}\sum_{t=0}^{T}(t+1)\,\omega_{t+1}$. Let $\bar{q}_{T+1}$ be the variational distribution represented by $\bar{\omega}_{T+1}$. Then, the KL divergence to the true posterior $q^*$ is bounded by

$$\mathbb{E}[D_{\mathrm{KL}}(\bar{q}_{T+1}, q^*)] \leq \frac{V}{T + 2}. \tag{14}$$

We make two observations on the rate (13). First, the rate interpolates between the stochastic and deterministic settings. In particular, zero variance $V = 0$ implies convergence in one step.
Second, the convergence rate does not depend on the distance from the initialization $q_0$ to the true posterior $q^*$. This leads to an interesting interpretation: no matter how far the initialization is from the true posterior, after the first iteration $\omega_1$ always lands in a sublevel set whose size depends only on the variance $V$. Both properties are due to the step size schedule $\gamma_t = \frac{2}{2+t}$, in particular $\gamma_0 = 1$. In general, linearly decreasing step sizes also guarantee convergence, but may lose these two properties.

It is not entirely clear whether the conditions in Assumption 1 hold in practice at all. In particular, Assumption 1 requires the gradient variance (12), defined in a non-standard form, to be bounded. The rest of this section addresses this question through a case study of a common conjugate variational inference problem, where we show that all conditions in Assumption 1 indeed hold in practice.

Example 1 (Bayesian Linear Regression). Consider

$$p(z) = \mathcal{N}(0, P), \qquad p(y \mid X, z) = \mathcal{N}(Xz, \sigma^2 I),$$

where the prior $p(z)$ is a zero-mean Gaussian and the label $y$ has independent Gaussian observation noise. The negative ELBO can be written as a finite sum

$$\ell(q) = -\mathbb{E}_{q(z)}[\log p(y \mid X, z)] + D_{\mathrm{KL}}(q(z), p(z)) = -\sum_{i=1}^{n}\mathbb{E}_{q(z)}[\log p(y_i \mid x_i, z)] + D_{\mathrm{KL}}(q(z), p(z)),$$

where $\{x_i\}_{i=1}^n$ are the rows of $X \in \mathbb{R}^{n \times d}$. Without loss of generality, we assume that the variational distribution is initialized as a standard normal distribution $q_0 = \mathcal{N}(0, I)$.

Data Sub-Sampling Stochastic Gradient. Each iteration samples $m$ data points uniformly and independently, $x_{i_1}, x_{i_2}, \ldots, x_{i_m}$, where each index $i_k$ is sampled independently from the uniform distribution $\mathcal{U}[n]$. The stochastic natural gradient $\widehat{\nabla}\ell(\omega)$ is

$$\widehat{\nabla}\ell(\omega) = \nabla_\omega\Big[-\frac{n}{m}\sum_{k=1}^{m}\mathbb{E}_{q}[\log p(y_{i_k} \mid x_{i_k}, z)] + D_{\mathrm{KL}}(q, p)\Big], \tag{15}$$

where $p(y_{i_k} \mid x_{i_k}, z) = \mathcal{N}(z^\top x_{i_k}, \sigma^2)$ and the expectation $\mathbb{E}_{q(z)}[\log p(y_{i_k} \mid x_{i_k}, z)]$ is computed in closed form. Each stochastic NGD update can be computed in $O(d^2 m)$ time, while the closed-form posterior of Bayesian linear regression takes $O(d^2 n + d^3)$ time to compute.
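The data sub-sampling scheme (15) with the step size $\gamma_t = \frac{2}{2+t}$ from Theorem 1 can be sketched as follows on synthetic data; the dimensions, noise level, batch size, and $P = I$ are illustrative assumptions. For this likelihood, $\nabla_\omega \mathbb{E}_q[\log p(y \mid X, z)] = (X^\top y/\sigma^2, -X^\top X/(2\sigma^2))$ does not depend on $\omega$, so the sketch first confirms that one exact step with $\gamma_0 = 1$ lands on the posterior from any starting point, and then runs stochastic NGD while accumulating the weighted average $\bar{\omega}_{T+1}$.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d, m, sigma2, T = 1000, 3, 100, 1.0, 5000
X = rng.standard_normal((n, d))
y = X @ np.array([1.0, -2.0, 0.5]) + np.sqrt(sigma2) * rng.standard_normal(n)
lam_p, Lam_p = np.zeros(d), -0.5 * np.eye(d)           # prior N(0, P) with P = I

# closed-form posterior, used only for comparison
Sigma_post = np.linalg.inv(X.T @ X / sigma2 + np.eye(d))
mu_post = Sigma_post @ X.T @ y / sigma2

# exact gradient with gamma = 1: eta_1 = grad_omega E_q[log p] + eta_p, whatever eta_0 is
lam1 = X.T @ y / sigma2 + lam_p
Lam1 = -X.T @ X / (2.0 * sigma2) + Lam_p
mu1 = np.linalg.solve(-2.0 * Lam1, lam1)

lam, Lam = np.zeros(d), -0.5 * np.eye(d)               # q_0 = N(0, I)
xi_bar, Xi_bar = np.zeros(d), np.zeros((d, d))         # weighted sums for omega_bar
for t in range(T + 1):
    idx = rng.integers(0, n, size=m)                   # i_k ~ U[n], with replacement
    Xb, yb = X[idx], y[idx]
    g_xi = (n / m) * Xb.T @ yb / sigma2                # unbiased for X^T y / sigma2
    g_Xi = -(n / m) * Xb.T @ Xb / (2.0 * sigma2)       # negative semidefinite
    gamma = 2.0 / (2.0 + t)                            # gamma_0 = 1
    lam = (1 - gamma) * lam + gamma * (g_xi + lam_p)
    Lam = (1 - gamma) * Lam + gamma * (g_Xi + Lam_p)
    Sigma = np.linalg.inv(-2.0 * Lam)
    mu = Sigma @ lam
    xi_bar += (t + 1) * mu                             # accumulate (t + 1) * omega_{t+1}
    Xi_bar += (t + 1) * (Sigma + np.outer(mu, mu))
xi_bar *= 2.0 / ((T + 1) * (T + 2))
Xi_bar *= 2.0 / ((T + 1) * (T + 2))
```

Each iteration only touches $m$ rows of $X$, matching the $O(d^2 m)$ per-step cost stated above, while the closed-form posterior is computed here only to measure the error.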
Approximating the posterior via stochastic NGD is more practical for large datasets. Indeed, it is widely used in variational Gaussian processes (e.g., Hensman et al., 2013; Salimbeni et al., 2018), where $n$ might be too large to even fit the data into memory.

Now we verify the conditions in Assumption 1. For each $i \in [n]$, the second component $\Xi$ of the gradient $(\xi, \Xi) = \nabla_\omega \mathbb{E}_q[\log p(y_i \mid x_i, z)]$ is negative semidefinite (see Appendix C.1). By Remark 2, the stochastic gradient (15) indeed respects the domain $\mathcal{D}$ and results in valid NGD updates, as long as $0 \leq \gamma_t \leq 1$. It is clearly unbiased, as each data point $x_{i_k}$ is sampled uniformly. Lastly, its variance is bounded:

Lemma 3. The stochastic gradient (15) satisfies

$$\frac{1}{\gamma_t}\mathbb{E}\big[\langle \widehat{\nabla}\ell(\omega_t) - \nabla\ell(\omega_t), \omega_{t+1,*} - \omega_{t+1} \rangle \,\big|\, \omega_t\big] \leq V_2, \tag{16}$$

where $V_2 = \big(\nu s_1 + \frac{1}{2}\nu^2 s_2 + 2\nu^2 b\sqrt{s_1 s_2}\,n + \nu^3 b^2 s_2 n^2\big)\frac{n^2}{\sigma^4 m}$, with $\nu = \max\{1, \|P\|\}$, $b = \max_{1 \leq i \leq n}\|y_i x_i\|$, and the empirical variances $s_1 = \mathbb{E}_{j \sim \mathcal{U}[n]}\big\|y_j x_j - \frac{1}{n}\sum_{i=1}^n y_i x_i\big\|^2$ and $s_2 = \mathbb{E}_{j \sim \mathcal{U}[n]}\big\|x_j x_j^\top - \frac{1}{n}\sum_{i=1}^n x_i x_i^\top\big\|_F^2$.

The constant in Lemma 3 is not necessarily tight and may be improved. Nevertheless, it serves to show that the gradient variance is bounded by a constant.

Application to Gaussian Process Regression. Our result immediately applies to stochastic variational Gaussian processes (SVGP) (Hensman et al., 2013), a popular large-scale Gaussian process regression model. SVGP training minimizes the negative ELBO of the form

$$-\iint p(f \mid u)\,q(u)\log p(y \mid f)\,df\,du + D_{\mathrm{KL}}(q(u), p(u)),$$

where $q(u)$ is the variational distribution, the likelihood is $p(y \mid f) = \mathcal{N}(y; f, \sigma^2 I)$, the conditional is $p(f \mid u) = \mathcal{N}(f; K_{fu}K_{uu}^{-1}u, K_{ff} - K_{fu}K_{uu}^{-1}K_{uf})$, and the prior is $p(u) = \mathcal{N}(u; 0, K_{uu})$. Simplifying the ELBO by removing terms independent of $q(u)$ gives

$$-\int q(u)\log\mathcal{N}(y; K_{fu}K_{uu}^{-1}u, \sigma^2 I)\,du + D_{\mathrm{KL}}(q(u), p(u)).$$

Hence, finding the optimal variational distribution $q(u)$ is equivalent to Bayesian linear regression in Example 1 with $X = K_{fu}K_{uu}^{-1}$ and the prior covariance $P = K_{uu}$.
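The reduction of SVGP to Example 1 can be checked numerically: the Bayesian linear regression posterior with $X = K_{fu}K_{uu}^{-1}$ and $P = K_{uu}$ should match the optimal $q(u)$ written directly in terms of kernel matrices, $\Sigma_u = K_{uu}(K_{uu} + \sigma^{-2}K_{uf}K_{fu})^{-1}K_{uu}$ and $\mu_u = \sigma^{-2}K_{uu}(K_{uu} + \sigma^{-2}K_{uf}K_{fu})^{-1}K_{uf}y$. The sketch assumes a squared-exponential kernel, one-dimensional synthetic data, eight inducing inputs, and a small jitter on $K_{uu}$; none of these choices come from the text.

```python
import numpy as np

def rbf(a, b, lengthscale=0.5):
    # squared-exponential kernel on 1-D inputs (illustrative choice)
    return np.exp(-0.5 * (a[:, None] - b[None, :]) ** 2 / lengthscale**2)

rng = np.random.default_rng(0)
x = np.sort(rng.uniform(0.0, 5.0, 300))            # training inputs
y = np.sin(x) + 0.1 * rng.standard_normal(300)     # training targets
z = np.linspace(0.0, 5.0, 8)                       # inducing inputs
sigma2 = 0.01

Kuu = rbf(z, z) + 1e-6 * np.eye(z.size)            # jitter for numerical stability
Kfu = rbf(x, z)

# Bayesian linear regression view: X = Kfu Kuu^{-1}, prior covariance P = Kuu
Xd = np.linalg.solve(Kuu, Kfu.T).T
S_blr = np.linalg.inv(Xd.T @ Xd / sigma2 + np.linalg.inv(Kuu))
m_blr = S_blr @ Xd.T @ y / sigma2

# the same optimum expressed with kernel matrices only
M = Kuu + Kfu.T @ Kfu / sigma2
S_svgp = Kuu @ np.linalg.solve(M, Kuu)
m_svgp = Kuu @ np.linalg.solve(M, Kfu.T @ y) / sigma2
```

Both routes touch all $n$ training points, which is exactly the cost that mini-batch stochastic NGD avoids.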
Even though the optimal variational distribution $q^*$ has a closed form, computing it exactly requires access to the entire dataset. Besides, $q^*$ changes after every GP hyperparameter update, and it is expensive to compute $q^*$ exactly at every iteration. Thus, a popular approach is to jointly optimize the variational parameters and the hyperparameters by mini-batch stochastic optimization. Lemma 3 together with Theorem 1 gives a convergence rate of the variational distribution in SVGP training. The convergence rate may also find applications in some collapsed variational inference methods (e.g., Hensman et al., 2012), where NGD is applied to a subset of latent variables in a conjugate exponential family.

## 5. ELBO Landscape

In the last section, we have seen that the negative ELBO $\ell(\omega)$, as a function of the expectation parameter $\omega$, has good properties when the likelihood is conjugate (see Lemma 2). These properties are crucial for the convergence analysis. The natural question is whether the ELBO preserves these properties for non-conjugate likelihoods. This section studies variational inference with a Gaussian prior $p(z)$, a Gaussian variational family $\mathcal{Q}$, and a non-Gaussian (i.e., non-conjugate) likelihood $p(y \mid z)$.

Surprisingly, we show that even when the likelihood is log-concave, the negative ELBO $\ell(\omega)$ is not guaranteed to be convex in the expectation parameter. This is in sharp contrast to the mean-square-root parameterization $(m, C)$, with $m$ and $C$ representing the mean and the Cholesky factor respectively, used in stochastic gradient descent, where the negative ELBO is smooth and strongly convex (Domke, 2020). Below we give two examples (with details in Appendix E) where the negative ELBO $\ell(\omega)$ is non-convex in the expectation parameter $\omega$, even for simple log-concave likelihoods. To show that the objective is non-convex, it suffices to find a dataset on which the negative ELBO is non-convex.

Logistic Regression.
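Consider a one-dimensional Bayesian logistic regression on the dataset $\{(x_i, y_i)\}_{i=1}^n$ with $x_i \in [-1, 1]$ and $y_i \in \{-1, 1\}$. The prior $p(w, b)$ on the weight $w$ and the bias $b$ is a standard Gaussian distribution. The negative ELBO $\ell(\omega)$ is

$$\ell(\omega) = \mathbb{E}_{q(w,b)}\Big[\sum_{i=1}^n \log\big(1 + \exp(-y_i(w x_i + b))\big)\Big] + D_{\mathrm{KL}}(q, p),$$

where $q(w, b)$ is a Gaussian variational distribution. Restrict the expectation parameter $\omega$ of $q$ to the convex subset $\{\omega = (0, \Xi) : \Xi = \mathrm{diag}(s_1, s_2),\ s_1 > 0,\ s_2 > 0\} \subset \Omega$, where the first component of $\omega$ is zero and the second component is a diagonal matrix. If $\ell(\omega)$ were convex in $\omega$, it would at least be convex in $s_2$. Taking the second-order derivative with respect to $s_2$, we have

$$\frac{\partial^2 \ell(\omega)}{\partial s_2^2} = \frac{1}{4}\sum_{i=1}^n \mathbb{E}\big[\psi_i(1 - \psi_i)(6\psi_i^2 - 6\psi_i + 1)\big] + \frac{1}{2s_2^2},$$

where $\psi_i = \psi(w x_i + b)$ with $\psi$ the sigmoid function, and the expectation is taken over $(w, b) \sim q_\omega$. Note that the expectation is negative in the limit:

$$\lim_{s_1, s_2 \to 0}\mathbb{E}_{q(w,b)}\big[\psi_i(1 - \psi_i)(6\psi_i^2 - 6\psi_i + 1)\big] = -\frac{1}{8},$$

since $q$ concentrates at $(w, b) = (0, 0)$, where $\psi_i = \frac{1}{2}$. In particular, there exists an absolute constant $\delta > 0$ such that when $s_1 = s_2 = \delta$ we have $\mathbb{E}[\psi_i(1 - \psi_i)(6\psi_i^2 - 6\psi_i + 1)] < -\frac{1}{16}$ for all $1 \leq i \leq n$. This implies $\frac{\partial^2 \ell(\omega)}{\partial s_2^2} < -\frac{n}{64} + \frac{1}{2\delta^2} < 0$ whenever $n > 32/\delta^2$, for the particular $\omega = (0, \Xi)$ with $\Xi = \mathrm{diag}(\delta, \delta)$.

Poisson Regression. We choose this example because of its analytical ELBO. Bayesian Poisson regression assumes that $y \mid x$ follows a Poisson distribution with mean $\mathbb{E}[y \mid x] = \exp(w^\top x)$. The prior $p(w)$ on the weight $w$ is a standard Gaussian distribution. Let $\omega = (\xi, \Xi)$ be the expectation parameter of the Gaussian variational distribution $q$. The expected log likelihood $\mathbb{E}_{q(w)}[\log p(y \mid X, w)]$ is

$$\sum_{i=1}^n\Big[y_i\,\xi^\top x_i - \exp\Big(\xi^\top x_i + \frac{1}{2}x_i^\top(\Xi - \xi\xi^\top)x_i\Big) - \log(y_i!)\Big],$$

which is not concave in $\xi$. Compute the Hessian of $\ell(\omega)$ w.r.t. $\xi$ and evaluate it on the subset of the domain $\{\omega = (\xi, \Xi) : \Xi = \xi\xi^\top + 2I\} \subset \Omega$. Then, we obtain

$$\nabla^2_\xi\ell(\omega) = \sum_{i=1}^n \kappa_i\,(x_i^\top\xi)(x_i^\top\xi - 2)\,x_i x_i^\top + \nabla^2_\xi A^*(\omega), \qquad \kappa_i = \exp(x_i^\top\xi + x_i^\top x_i) > 0.$$

For a fixed $\omega = (\xi, \Xi)$, there exists a dataset $\{(x_i, y_i)\}_{i=1}^n$ such that $0 < x_i^\top\xi < 2$ for all $1 \leq i \leq n$, so every term in the sum is negative semidefinite. Now, consider the Hessian $\nabla^2_\xi\ell(\omega)$ on the scaled dataset $\{(c x_i, y_i)\}_{i=1}^n$ evaluated at $\frac{1}{c}\xi$. As $c \to \infty$, the negative terms dominate the bounded term $\nabla^2_\xi A^*(\omega)$, and we have found a dataset for which the Hessian has a negative eigenvalue.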
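The Poisson example can be checked numerically in one dimension. The sketch below assumes a hypothetical single data point $(x, y) = (2, 1)$ and a standard Gaussian prior, and fixes $\Xi = \xi_0^2 + 2$ at $\xi_0 = 0.5$, so that $x\xi_0 = 1 \in (0, 2)$. It evaluates the closed-form negative ELBO (dropping the constant $\log y!$) and compares a finite-difference second derivative in $\xi$ with the analytic value obtained by differentiating twice; a negative value certifies non-convexity at that point.

```python
import numpy as np

# hypothetical 1-D Bayesian Poisson regression: one data point, prior w ~ N(0, 1)
x, y = 2.0, 1.0
xi0 = 0.5
Xi = xi0**2 + 2.0                   # Xi = xi^2 + 2 at xi0, held fixed while xi varies

def neg_elbo(xi):
    s2 = Xi - xi**2                                        # variance of q(w)
    nll = np.exp(x * xi + 0.5 * x**2 * s2) - y * x * xi    # -E_q[log p(y|x,w)] up to log(y!)
    kl = 0.5 * (Xi - 1.0 - np.log(s2))                     # KL(N(xi, s2) || N(0, 1))
    return nll + kl

h = 1e-4
second_diff = (neg_elbo(xi0 + h) - 2.0 * neg_elbo(xi0) + neg_elbo(xi0 - h)) / h**2

# analytic second derivative in xi with Xi fixed: likelihood term plus KL term
s = x * xi0
analytic = (np.exp(s + 0.5 * x**2 * (Xi - xi0**2)) * s * (s - 2.0) * x**2
            + (Xi + xi0**2) / (Xi - xi0**2) ** 2)
```

Here the likelihood term contributes $-4e^5 \approx -593.7$ while the KL term contributes only $0.625$, so the curvature in $\xi$ is clearly negative.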
[Figure 1: Mini-batch Bayesian linear regression on the Bike dataset. Left: the KL divergence D_KL(q_t, q*) to the optimum q*. Right: the training negative log predictive density. Both panels are plotted against iterations (10^1 to 10^5, log scale).]

Recent work demonstrates that the negative ELBO is convex with a log-concave likelihood if the variational distribution is a Gaussian with the mean-square-root parameterization (Domke, 2020). However, the two examples above show that this is not the case for the expectation parameter ω. Given that the canonical implementation of NGD is equivalent to mirror descent in the expectation parameter space, NGD may implicitly optimize a non-convex objective when the likelihood is non-conjugate, even for simple log-concave likelihoods.

Nonetheless, the negative ELBO does have some convenient properties for log-concave likelihoods: it is not an arbitrary non-convex objective.

Proposition 1. Suppose the prior and the variational family are both Gaussian. If the likelihood p(y | z) is log-concave in z, then the negative ELBO ℓ(ω), as a function of the expectation parameter, has a unique minimizer ω*. In addition, if the likelihood p(y | z) is differentiable in z, then ω* is the unique stationary point of ℓ(ω).

Proposition 1 is not surprising, since there is a differentiable bijection between the expectation parameterization and the mean-square-root parameterization. The uniqueness of the minimizer and of the stationary point follows from strong convexity in the mean-square-root parameterization. Thus, stochastic NGVI may still converge to the optimum with log-concave likelihoods despite the non-convexity.

We end this section by discussing some implications. With non-conjugate likelihoods, the negative ELBO ℓ(ω) is neither strongly convex nor relatively strongly convex, since it is not even convex. Strong convexity plays a crucial role in stochastic optimization.
Without it, stochastic gradient descent has a convergence rate of O(1/√T) under standard assumptions, which improves to O(1/T) for strongly convex functions. The fact that stochastic NGVI implicitly optimizes a non-convex objective implies that we may need to resort to new properties of the ELBO to prove an O(1/T) convergence rate for non-conjugate likelihoods, if it achieves this rate at all. One possible route is the Polyak-Łojasiewicz inequality (Karimi et al., 2016).

6. Numerical Simulation

This section presents supporting numerical simulations on datasets from the UCI repository (Bike and Mushroom) and MNIST (Kelly et al., 2017; LeCun et al., 1998).

6.1. Bayesian Linear Regression

Figure 1 presents Bayesian linear regression on the Bike dataset (n = 17,389), with a standard normal prior and noise variance σ² = 1. The negative ELBO is optimized by SGD and by stochastic NGD with a mini-batch size of 1000. SGD uses the step size schedule γ_t = 1/(10^5 + t), an O(1/t) decreasing schedule of the same order as in Domke et al. (2023, Theorem 10). Stochastic NGD uses the step size schedule γ_t = 2/(2 + t), as predicted by Theorem 1. The true posterior q* of Bayesian linear regression has a closed form, which allows us to plot the optimality gap in the KL divergence. In addition, we plot the negative log predictive density (NLPD) on the training set.

On the log-log scale, the KL divergences to the optimal posterior decrease at the same rate for both methods, with roughly the same slope in the figure. This suggests that both methods have the same O(1/T) complexity, and that stochastic NGD may be only a constant factor faster than SGD. Nonetheless, stochastic NGD converges very fast in the early stage. It takes SGD thousands of iterations to catch up with the progress that stochastic NGD makes in the first few iterations, implying that stochastic NGD has a much better constant factor in the big-O bound.
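A minimal NumPy sketch of the stochastic NGD update for conjugate Bayesian linear regression is given below. It is not the experimental code: it uses synthetic data instead of the Bike dataset and assumes a N(0, I) prior, noise variance σ² = 1, mini-batches of size 50 drawn with replacement, T = 2000 steps, and the schedule γ_t = 2/(2 + t); the function names are hypothetical. In natural parameters the update reads η_{t+1} = (1 - γ_t)η_t + γ_t(ĝ_t + η_p), where ĝ_t is the mini-batch estimate of ∇_ω E_q[log p(y | X, z)], and the sketch reports the KL divergence to the closed-form posterior q*.

```python
import numpy as np


def gaussian_kl(m1, S1, m2, S2):
    """KL(N(m1, S1) || N(m2, S2)) for full-covariance Gaussians."""
    d = m1.size
    S2_inv = np.linalg.inv(S2)
    diff = m2 - m1
    return 0.5 * (np.trace(S2_inv @ S1) + diff @ S2_inv @ diff - d
                  + np.linalg.slogdet(S2)[1] - np.linalg.slogdet(S1)[1])


def stochastic_ngd_blr(X, y, sigma2=1.0, batch=50, T=2000, seed=0):
    """Stochastic NGD on the natural parameter (lam, Lam) = (Sigma^{-1} mu, -Sigma^{-1} / 2)."""
    rng = np.random.default_rng(seed)
    n, d = X.shape
    lam_p, Lam_p = np.zeros(d), -0.5 * np.eye(d)   # prior N(0, I)
    lam, Lam = lam_p.copy(), Lam_p.copy()          # initialization q_0 = N(0, I)
    for t in range(T):
        idx = rng.integers(0, n, size=batch)       # i_1, ..., i_m ~ U[n]
        Xb, yb = X[idx], y[idx]
        scale = n / (batch * sigma2)
        g_lam = scale * Xb.T @ yb                  # unbiased estimate of X^T y / sigma2
        g_Lam = -0.5 * scale * Xb.T @ Xb           # unbiased estimate of -X^T X / (2 sigma2)
        gamma = 2.0 / (2.0 + t)
        lam = (1.0 - gamma) * lam + gamma * (g_lam + lam_p)
        Lam = (1.0 - gamma) * Lam + gamma * (g_Lam + Lam_p)
    Sigma = np.linalg.inv(-2.0 * Lam)
    return Sigma @ lam, Sigma                      # (mean, covariance) of q_T


rng = np.random.default_rng(1)
n, d = 500, 3
X = rng.normal(size=(n, d))
y = X @ np.array([1.0, -2.0, 0.5]) + rng.normal(size=n)

S_star = np.linalg.inv(np.eye(d) + X.T @ X)        # exact posterior covariance (sigma2 = 1)
m_star = S_star @ (X.T @ y)

kl_init = gaussian_kl(np.zeros(d), np.eye(d), m_star, S_star)
m_T, S_T = stochastic_ngd_blr(X, y)
kl_final = gaussian_kl(m_T, S_T, m_star, S_star)
print(f"KL(q_0, q*) = {kl_init:.2f},  KL(q_T, q*) = {kl_final:.4f}")
```

In this conjugate model the mini-batch gradient does not depend on q, so with γ_t = 2/(2 + t) the iterate is a weighted average of unbiased estimates of the posterior natural parameter (weights proportional to t + 1); this is one way to see the O(1/T) behaviour in this setting.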
Recall that the convergence rate in Theorem 1 depends only on the stochastic gradient variance; it is independent of the objective's condition number and of the distance from the initialization to the optimum (see Section 4). This is consistent with the large constant-factor advantage of stochastic NGD observed above.

6.2. Non-Conjugate Likelihoods

Figure 2 shows Bayesian logistic regression on the Mushroom dataset (n = 8124) and MNIST (the subset of digits 1 and 7, with n ≈ 13,000 images). Again, stochastic NGD is faster than SGD, but the improvement is less drastic than with conjugate likelihoods. This is consistent with previous empirical observations (Salimbeni et al., 2018).

Besides converging faster, stochastic NGD appears to have a step size that is easier to tune in practice. In most cases, stochastic NGD with step size γ = 0.1 converges smoothly. Sometimes γ = 0.1 is so large that stochastic NGD oscillates in the final stage (Figure 2, right panel). Simply decreasing it to γ = 0.01 leads to smooth convergence in most cases. These observations suggest that the negative ELBO ℓ(ω), as a function of the expectation parameter, might have a small smoothness constant. Indeed, the smoothness constant is 1 for conjugate likelihoods (recall Lemma 2). For non-conjugate likelihoods in practice, we hypothesize that the smoothness constant might be close to 1 as well.

[Figure 2: Bayesian logistic regression on Mushroom and MNIST. Legend of the panel spanning 0 to 400 iterations: SGD 1e-03 (p), SGD 1e-05 (p), SGD 1e-05 (r), NGD 1e-01 (p), NGD 1e-02 (p), NGD 1e-03 (p). Legend of the panel spanning 0 to 10,000 iterations: SGD 1e-06 (p), SGD 1e-07 (p), SGD 1e-07 (r), NGD 1e-02 (p), NGD 1e-03 (p). Labels with (p) use the stochastic gradient given by Price's theorem (11); labels with (r) use the stochastic gradient given by the reparameterization trick.]

As a side note, the Price stochastic gradient (11) is a high-quality gradient estimator that appears superior to the reparameterization trick.
For instance, in the special case when the log likelihood log p(y | z) is a quadratic function of z, e.g., p(y | z) = N(z, I), the Price stochastic gradient $\hat\nabla_\Xi$ is exact and has zero variance. For general non-conjugate likelihoods, we expect the Price stochastic gradient to have lower variance. Indeed, SGD improves dramatically just by switching to the Price stochastic gradient.

7. Related Work

Natural gradient descent was initially proposed by Amari (1998) as a learning algorithm for multi-layer perceptrons that is believed to exploit the information geometry. Subsequently, this method has been applied to variational inference (e.g., Honkela & Valpola, 2004; Hensman et al., 2012; Hoffman et al., 2013; Khan & Lin, 2017). Recent developments of natural gradient variational inference include generalization to mixtures of exponential families (Lin et al., 2019; Arenz et al., 2023), handling the positive-definite domain constraint (Lin et al., 2020), supporting structured matrix parameterizations (Lin et al., 2021), implementation via automatic differentiation (Salimbeni et al., 2018), adaptation to online learning (Chérief-Abdellatif et al., 2019), and generalization to the Wasserstein statistical manifold (Chen & Li, 2020; Li & Zhao, 2023).

Outside variational inference, natural gradient descent has been applied to training (non-Bayesian) neural networks in supervised learning (e.g., Bernacchia et al., 2018; Song et al., 2018; Zhang et al., 2019). Interestingly, Zhang et al. (2018) establish a connection between training neural networks with noisy natural gradient and variational inference. For a comprehensive survey of this area, we direct readers to the monograph by Martens (2020). In particular, Martens (2020) hypothesizes an O(1/T) asymptotic convergence rate via an argument based on Fisher efficiency.
Martens (2020) also gives a non-asymptotic convergence rate of O(1/T) for stochastic preconditioned gradient descent with a fixed preconditioning matrix and a quadratic objective. In addition, natural gradient descent has been applied to policy optimization in reinforcement learning, leading to the natural policy gradient (Kakade, 2001). Owing to its connection with mirror descent, there is recent interest in this method, which has led to a series of analyses (e.g., Geist et al., 2019; Shani et al., 2020; Agarwal et al., 2021; Khodadadian et al., 2021; Xiao, 2022; Yuan et al., 2023).

The natural gradient methods applied to the machine learning problems mentioned above share a common feature: the gradient direction is preconditioned with the Fisher information matrix. Despite sharing the name "natural gradient", there is a subtle difference in natural gradient variational inference. The distance generating function A* in NGVI, namely the convex conjugate of the log-partition function, is (up to an additive constant) the negative differential entropy, which is neither strongly convex nor smooth, as shown in Appendix A.1. In contrast, the distance generating function in natural policy gradient is the negative Shannon entropy, which is well known to be strongly convex with respect to a norm (e.g., Bubeck, 2015, Section 4.3). As mentioned in Section 4, this strong convexity is a key condition for mirror descent analyses in the stochastic setting. We hope our work motivates new developments in stochastic mirror descent for non-strongly convex, non-smooth distance generating functions.

8. Conclusion

Over the years, empirical observations have suggested that stochastic natural gradient descent (NGD) is faster than stochastic gradient descent for variational inference. To understand how fast NGD converges, we prove the first O(1/T) non-asymptotic convergence rate for conjugate likelihoods.
The rate appears to be tight based on experiments, suggesting that stochastic natural gradient variational inference (NGVI) may be only a constant factor faster than stochastic gradient descent. Nevertheless, the constant-factor improvement can be dramatic in practice. For non-conjugate likelihoods, we show that canonical stochastic NGVI implicitly optimizes a non-convex objective, which suggests that an O(1/T) rate is unlikely without the discovery of new properties of the ELBO.

Acknowledgements

The authors thank the anonymous reviewers for constructive feedback. KW would like to thank Kyurae Kim for helpful discussions in the early stage of this work. KW and JRG are supported by NSF award IIS-2145644.

Impact Statement

This paper presents work whose goal is to advance the field of machine learning. There might be many potential societal consequences of our work, none of which we feel must be specifically highlighted here.

References

Agarwal, A., Kakade, S. M., Lee, J. D., and Mahajan, G. On the theory of policy gradient methods: Optimality, approximation, and distribution shift. Journal of Machine Learning Research, 22(98):1-76, 2021.
Amari, S.-I. Natural gradient works efficiently in learning. Neural Computation, 10(2):251-276, 1998.
Arenz, O., Dahlinger, P., Ye, Z., Volpp, M., and Neumann, G. A unified perspective on natural gradient variational inference with Gaussian mixture models. Transactions on Machine Learning Research, 2023. ISSN 2835-8856.
Bauschke, H. H., Bolte, J., and Teboulle, M. A descent lemma beyond Lipschitz gradient continuity: First-order methods revisited and applications. Mathematics of Operations Research, 42(2):330-348, 2017.
Bernacchia, A., Lengyel, M., and Hennequin, G. Exact natural gradient in deep linear networks and its application to the nonlinear case. In Advances in Neural Information Processing Systems, volume 31, 2018.
Birnbaum, B., Devanur, N. R., and Xiao, L. Distributed algorithms via gradient descent for Fisher markets. In Proceedings of the 12th ACM Conference on Electronic Commerce, pp. 127-136, 2011.
Blei, D. M., Kucukelbir, A., and McAuliffe, J. D. Variational inference: A review for statisticians. Journal of the American Statistical Association, 112(518):859-877, 2017.
Bonnet, G. Transformations des signaux aléatoires à travers les systèmes non linéaires sans mémoire. In Annales des Télécommunications, volume 19, pp. 203-220. Springer, 1964.
Bottou, L. Online learning and stochastic approximations. Online Learning in Neural Networks, 17(9):142, 1998.
Bubeck, S. Convex optimization: Algorithms and complexity. Foundations and Trends in Machine Learning, 8(3-4):231-357, 2015.
Chen, Y. and Li, W. Optimal transport natural gradient for statistical manifolds with continuous sample space. Information Geometry, 3(1):1-32, 2020.
Chérief-Abdellatif, B.-E., Alquier, P., and Khan, M. E. A generalization bound for online variational inference. In Proceedings of The Eleventh Asian Conference on Machine Learning, volume 101, pp. 662-677. PMLR, 2019.
Domke, J. Provable smoothness guarantees for black-box variational inference. In Proceedings of the 37th International Conference on Machine Learning, volume 119, pp. 2587-2596. PMLR, 2020.
Domke, J., Gower, R. M., and Garrigos, G. Provable convergence guarantees for black-box variational inference. In Thirty-seventh Conference on Neural Information Processing Systems, 2023.
Fatkhullin, I. and He, N. Taming nonconvex stochastic mirror descent with general Bregman divergence. In Proceedings of The 27th International Conference on Artificial Intelligence and Statistics, volume 238, pp. 3493-3501. PMLR, 2024.
Geist, M., Scherrer, B., and Pietquin, O. A theory of regularized Markov decision processes. In Proceedings of the 36th International Conference on Machine Learning, volume 97, pp. 2160-2169. PMLR, 2019.
Hanzely, F. and Richtárik, P. Fastest rates for stochastic mirror descent methods. Computational Optimization and Applications, 79:717-766, 2021.
Hensman, J., Rattray, M., and Lawrence, N. Fast variational inference in the conjugate exponential family. In Advances in Neural Information Processing Systems, volume 25, 2012.
Hensman, J., Fusi, N., and Lawrence, N. D. Gaussian processes for big data. In Proceedings of the Twenty-Ninth Conference on Uncertainty in Artificial Intelligence, pp. 282-290, 2013.
Hensman, J., Matthews, A., and Ghahramani, Z. Scalable variational Gaussian process classification. In Artificial Intelligence and Statistics, pp. 351-360, 2015.
Hoffman, M. D., Blei, D. M., Wang, C., and Paisley, J. Stochastic variational inference. Journal of Machine Learning Research, 14(40):1303-1347, 2013.
Honkela, A. and Valpola, H. Unsupervised variational Bayesian learning of nonlinear models. In Advances in Neural Information Processing Systems, volume 17, 2004.
Kakade, S. M. A natural policy gradient. In Advances in Neural Information Processing Systems, volume 14. MIT Press, 2001.
Karimi, H., Nutini, J., and Schmidt, M. Linear convergence of gradient and proximal-gradient methods under the Polyak-Łojasiewicz condition. In Machine Learning and Knowledge Discovery in Databases, pp. 795-811. Springer International Publishing, 2016.
Kelly, M., Longjohn, R., and Nottingham, K. The UCI machine learning repository, 2017. URL https://archive.ics.uci.edu.
Khan, M. and Lin, W. Conjugate-computation variational inference: Converting variational inference in non-conjugate models to inferences in conjugate models. In Proceedings of the 20th International Conference on Artificial Intelligence and Statistics, volume 54, pp. 878-887, 2017.
Khan, M., Nielsen, D., Tangkaratt, V., Lin, W., Gal, Y., and Srivastava, A. Fast and scalable Bayesian deep learning by weight-perturbation in Adam. In Proceedings of the 35th International Conference on Machine Learning, volume 80, pp. 2611-2620, 2018.
Khan, M. E., Babanezhad, R., Lin, W., Schmidt, M., and Sugiyama, M. Faster stochastic variational inference using proximal-gradient methods with general divergence functions. In Proceedings of the Thirty-Second Conference on Uncertainty in Artificial Intelligence, 2016.
Khan, M. E. E., Baque, P., Fleuret, F., and Fua, P. Kullback-Leibler proximal variational inference. In Advances in Neural Information Processing Systems, volume 28. Curran Associates, Inc., 2015.
Khodadadian, S., Jhunjhunwala, P. R., Varma, S. M., and Maguluri, S. T. On the linear convergence of natural policy gradient algorithm. In 2021 60th IEEE Conference on Decision and Control (CDC), pp. 3794-3799. IEEE Press, 2021.
Kim, K., Oh, J., Wu, K., Ma, Y., and Gardner, J. R. On the convergence of black-box variational inference. In Advances in Neural Information Processing Systems, volume 36, 2023.
Kingma, D. P. and Welling, M. Auto-encoding variational Bayes. In The First International Conference on Learning Representations, 2013.
Lan, G. First-order and stochastic optimization methods for machine learning, volume 1. Springer, 2020.
LeCun, Y., Cortes, C., and Burges, C. J. The MNIST database, 1998. URL https://archive.ics.uci.edu.
Li, W. and Zhao, J. Wasserstein information matrix. Information Geometry, 6(1):203-255, 2023.
Lin, W., Khan, M. E., and Schmidt, M. Fast and simple natural-gradient variational inference with mixture of exponential-family approximations. In Proceedings of the 36th International Conference on Machine Learning, volume 97, pp. 3992-4002, 2019.
Lin, W., Schmidt, M., and Khan, M. E. Handling the positive-definite constraint in the Bayesian learning rule. In Proceedings of the 37th International Conference on Machine Learning, volume 119, pp. 6116-6126, 2020.
Lin, W., Nielsen, F., Emtiyaz, K. M., and Schmidt, M. Tractable structured natural-gradient descent using local parameterizations. In Proceedings of the 38th International Conference on Machine Learning, volume 139, pp. 6680-6691. PMLR, 2021.
Liu, Z., Nguyen, T. D., Nguyen, T. H., Ene, A., and Nguyen, H. High probability convergence of stochastic gradient methods. In Proceedings of the 40th International Conference on Machine Learning, volume 202, pp. 21884-21914. PMLR, 2023.
Lojasiewicz, S. A topological property of real analytic subsets. Coll. du CNRS, Les équations aux dérivées partielles, 117(87-89):2, 1963.
Lu, H., Freund, R. M., and Nesterov, Y. Relatively smooth convex optimization by first-order methods, and applications. SIAM Journal on Optimization, 28(1):333-354, 2018.
Martens, J. New insights and perspectives on the natural gradient method. Journal of Machine Learning Research, 21(146):1-76, 2020.
Murray, I. Differentiation of the Cholesky decomposition. arXiv preprint arXiv:1602.07527, 2016.
Nemirovskij, A. S. and Yudin, D. B. Problem complexity and method efficiency in optimization. 1983.
Nguyen, T. D., Nguyen, T. H., Ene, A., and Nguyen, H. Improved convergence in high probability of clipped gradient methods with heavy tailed noise. In Thirty-seventh Conference on Neural Information Processing Systems, 2023.
Nielsen, F. and Garcia, V. Statistical exponential families: A digest with flash cards. arXiv preprint arXiv:0911.4863, 2009.
Opper, M. and Archambeau, C. The variational Gaussian approximation revisited. Neural Computation, 21(3):786-792, 2009.
Osawa, K., Swaroop, S., Khan, M. E. E., Jain, A., Eschenhagen, R., Turner, R. E., and Yokota, R. Practical deep learning with Bayesian principles. In Advances in Neural Information Processing Systems, volume 32, 2019.
Polyak, B. T. et al. Gradient methods for minimizing functionals. Zhurnal Vychislitel'noi Matematiki i Matematicheskoi Fiziki, 3(4):643-653, 1963.
Price, R. A useful theorem for nonlinear devices having Gaussian inputs. IRE Transactions on Information Theory, 4(2):69-72, 1958.
Ranganath, R., Gerrish, S., and Blei, D. Black box variational inference. In Proceedings of the Seventeenth International Conference on Artificial Intelligence and Statistics, volume 33, pp. 814-822, 2014.
Raskutti, G. and Mukherjee, S. The information geometry of mirror descent. IEEE Transactions on Information Theory, 61(3):1451-1457, 2015.
Rezende, D. J., Mohamed, S., and Wierstra, D. Stochastic backpropagation and approximate inference in deep generative models. In Proceedings of the 31st International Conference on Machine Learning, volume 32, pp. 1278-1286, 2014.
Salimbeni, H., Eleftheriadis, S., and Hensman, J. Natural gradients in practice: Non-conjugate variational inference in Gaussian process models. In International Conference on Artificial Intelligence and Statistics, pp. 689-697, 2018.
Shani, L., Efroni, Y., and Mannor, S. Adaptive trust region policy optimization: Global convergence and faster rates for regularized MDPs. In Proceedings of the AAAI Conference on Artificial Intelligence, volume 34, pp. 5668-5675, 2020.
Song, Y., Song, J., and Ermon, S. Accelerating natural gradient with higher-order invariance. In Proceedings of the 35th International Conference on Machine Learning, volume 80, pp. 4713-4722, 2018.
Titsias, M. and Lázaro-Gredilla, M. Doubly stochastic variational Bayes for non-conjugate inference. In Proceedings of the 31st International Conference on Machine Learning, volume 32, pp. 1971-1979, 2014.
Wainwright, M. J. and Jordan, M. I. Graphical models, exponential families, and variational inference. Foundations and Trends in Machine Learning, 1(1-2):1-305, 2008.
Xiao, L. On the convergence rates of policy gradient methods. Journal of Machine Learning Research, 23(282):1-36, 2022.
Yuan, R., Du, S. S., Gower, R. M., Lazaric, A., and Xiao, L. Linear convergence of natural policy gradient methods with log-linear policies. In The Eleventh International Conference on Learning Representations, 2023.
Zhang, G., Sun, S., Duvenaud, D., and Grosse, R. Noisy natural gradient as variational inference. In Proceedings of the 35th International Conference on Machine Learning, volume 80, pp. 5852-5861. PMLR, 2018.
Zhang, G., Martens, J., and Grosse, R. B. Fast convergence of natural gradient descent for over-parameterized neural networks. In Advances in Neural Information Processing Systems, volume 32, 2019.

A. Exponential Family

The following useful lemma is well known (e.g., Nielsen & Garcia, 2009) and hints at the connection between NGD and mirror descent. For the sake of completeness, we present a proof here, which largely follows Khan & Lin (2017, Lemma 2).

Lemma 4. Let q(z) and q'(z) be distributions in the same exponential family. That is, they share the same base measure h(z), sufficient statistics ϕ(z), and log-partition function A(η). Let η and η' be their natural parameters, and let ω and ω' be their expectation parameters. Then
$$D_{A^*}(\omega, \omega') = D_A(\eta', \eta) = D_{\mathrm{KL}}(q, q'),$$
where D_A and D_{A*} are the Bregman divergences associated with A and its convex conjugate A*, respectively.

Proof. The first equality is a standard property of the Bregman divergence, since η and ω are dual to each other. Only the second equality needs a proof:
$$\begin{aligned}
D_{\mathrm{KL}}(q, q') &= \mathbb{E}_{q(z)}[\log q(z) - \log q'(z)] \\
&= \mathbb{E}_{q(z)}[\log h(z) + \langle\phi(z), \eta\rangle - A(\eta)] - \mathbb{E}_{q(z)}[\log h(z) + \langle\phi(z), \eta'\rangle - A(\eta')] \\
&= \langle\mathbb{E}_{q(z)}[\phi(z)], \eta - \eta'\rangle - A(\eta) + A(\eta') \\
&= A(\eta') - A(\eta) - \langle\omega, \eta' - \eta\rangle \\
&= A(\eta') - A(\eta) - \langle\nabla A(\eta), \eta' - \eta\rangle \\
&= D_A(\eta', \eta),
\end{aligned}$$
where the second line uses the definition of the exponential family; the fourth line uses the definition of the expectation parameter; the fifth line uses the duality between the natural and expectation parameters (4); and the last line uses the definition of the Bregman divergence.
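Lemma 4 is easy to sanity-check numerically for Gaussians. The NumPy sketch below is only an illustration: the dimension, the random Gaussians, and all function names are assumptions. It implements the log-partition function A from (17), the conjugate A* from (18), and the gradient maps ∇A and ∇A* from Appendix A.2, and compares D_A(η', η) and D_{A*}(ω, ω') with the closed-form Gaussian KL divergence. It also checks the Fenchel-Young equality A(η) + A*(ω) = ⟨η, ω⟩ at a dual pair, which pins down the additive constant in A*.

```python
import numpy as np


def A(lam, Lam):
    """Log-partition function (17): -1/4 lam^T Lam^{-1} lam - 1/2 log det(-2 Lam)."""
    return -0.25 * lam @ np.linalg.solve(Lam, lam) - 0.5 * np.linalg.slogdet(-2.0 * Lam)[1]


def A_star(xi, Xi):
    """Convex conjugate (18): -1/2 log det(Xi - xi xi^T) - d/2."""
    return -0.5 * np.linalg.slogdet(Xi - np.outer(xi, xi))[1] - 0.5 * xi.size


def grad_A(lam, Lam):
    """Natural -> expectation parameters: (xi, Xi) = (mu, Sigma + mu mu^T)."""
    Lam_inv = np.linalg.inv(Lam)
    return -0.5 * Lam_inv @ lam, 0.25 * Lam_inv @ np.outer(lam, lam) @ Lam_inv - 0.5 * Lam_inv


def grad_A_star(xi, Xi):
    """Expectation -> natural parameters: (lam, Lam) = (Sigma^{-1} mu, -Sigma^{-1} / 2)."""
    S_inv = np.linalg.inv(Xi - np.outer(xi, xi))
    return S_inv @ xi, -0.5 * S_inv


def inner(p, q):
    """<(a, M), (b, N)> = a^T b + trace(M^T N)."""
    return p[0] @ q[0] + np.sum(p[1] * q[1])


def bregman(f, grad_f, p, q):
    """D_f(p, q) = f(p) - f(q) - <grad f(q), p - q>."""
    diff = (p[0] - q[0], p[1] - q[1])
    return f(*p) - f(*q) - inner(grad_f(*q), diff)


def gaussian_kl(m1, S1, m2, S2):
    """KL(N(m1, S1) || N(m2, S2))."""
    S2_inv = np.linalg.inv(S2)
    diff = m2 - m1
    return 0.5 * (np.trace(S2_inv @ S1) + diff @ S2_inv @ diff - m1.size
                  + np.linalg.slogdet(S2)[1] - np.linalg.slogdet(S1)[1])


rng = np.random.default_rng(0)
d = 3
mu1, mu2 = rng.normal(size=d), rng.normal(size=d)
B1, B2 = rng.normal(size=(d, d)), rng.normal(size=(d, d))
S1, S2 = B1 @ B1.T + np.eye(d), B2 @ B2.T + np.eye(d)

eta1 = (np.linalg.solve(S1, mu1), -0.5 * np.linalg.inv(S1))   # q  = N(mu1, S1)
eta2 = (np.linalg.solve(S2, mu2), -0.5 * np.linalg.inv(S2))   # q' = N(mu2, S2)
omega1, omega2 = grad_A(*eta1), grad_A(*eta2)

kl = gaussian_kl(mu1, S1, mu2, S2)                        # D_KL(q, q')
d_A = bregman(A, grad_A, eta2, eta1)                      # D_A(eta', eta)
d_A_star = bregman(A_star, grad_A_star, omega1, omega2)   # D_{A*}(omega, omega')
fenchel_gap = A(*eta1) + A_star(*omega1) - inner(eta1, omega1)
print(kl, d_A, d_A_star, fenchel_gap)
```

All three divergences should agree up to floating-point error, and the Fenchel-Young gap should vanish.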
Let ω and ω' be the expectation parameters of q and q', respectively. Lemma 4 gives a simple expression for the derivative of D_KL(q, q') with respect to the expectation parameter ω:
$$\begin{aligned}
\frac{d}{d\omega} D_{\mathrm{KL}}(q, q') &= \frac{d}{d\omega} D_{A^*}(\omega, \omega') \\
&= \nabla A^*(\omega) - \nabla A^*(\omega') \\
&= \eta - \eta',
\end{aligned}$$
where the second line is a standard property of the Bregman divergence and the third line holds because ∇A*(·) maps an expectation parameter to its corresponding natural parameter.

A.1. Gaussian Distributions: The Log-Partition Function

This section gives an explicit expression for the log-partition function A(η) of the Gaussian distribution and for its convex conjugate A*. In addition, we show that the convex conjugate A* is neither smooth nor strongly convex.

Let q(z) = N(z; µ, Σ) be a d-dimensional Gaussian with mean µ and covariance Σ. Its density has the form
$$q(z; \eta) \propto \exp\Big(-\frac{1}{2}(z - \mu)^\top\Sigma^{-1}(z - \mu) - \frac{1}{2}\log\det(\Sigma)\Big) = \exp\Big(\langle z, \Sigma^{-1}\mu\rangle + \Big\langle zz^\top, -\frac{1}{2}\Sigma^{-1}\Big\rangle - \frac{1}{2}\mu^\top\Sigma^{-1}\mu - \frac{1}{2}\log\det(\Sigma)\Big).$$
The sufficient statistic is the map ϕ: z ↦ (z, zz^⊤). The natural parameter η = (λ, Λ) satisfies λ = Σ^{-1}µ and Λ = -½Σ^{-1}. The log-partition function A(·), as a function of the mean and the covariance, is
$$A(\mu, \Sigma) = \frac{1}{2}\mu^\top\Sigma^{-1}\mu + \frac{1}{2}\log\det(\Sigma).$$
Plugging in the relation between (µ, Σ) and the natural parameter η = (λ, Λ), we obtain an explicit expression of the log-partition function:
$$A(\lambda, \Lambda) = -\frac{1}{4}\lambda^\top\Lambda^{-1}\lambda - \frac{1}{2}\log\det(-2\Lambda), \tag{17}$$
where λ ∈ R^d and Λ ≺ 0. The convex conjugate A* as a function of the expectation parameter ω = (ξ, Ξ) is
$$\begin{aligned}
A^*(\xi, \Xi) &= \sup_{\lambda \in \mathbb{R}^d,\ \Lambda \prec 0}\Big\{\langle\xi, \lambda\rangle + \langle\Xi, \Lambda\rangle + \frac{1}{4}\lambda^\top\Lambda^{-1}\lambda + \frac{1}{2}\log\det(-2\Lambda)\Big\} \\
&= -\frac{1}{2}\log\det\big(\Xi - \xi\xi^\top\big) - \frac{d}{2},
\end{aligned} \tag{18}$$
where the second line solves the maximization by setting the derivative to zero. The constraint on the expectation parameter (ξ, Ξ) is Ξ - ξξ^⊤ ≻ 0.

Consider the restriction of A* to the convex set {ω = (ξ, Ξ) : ξ = 0, Ξ ≻ 0}. Clearly, A* is already neither strongly convex nor smooth in Ξ: in one dimension, d²/dx²(-log x) = 1/x² is bounded neither from below (by a positive constant) nor from above.
Since the absolute value is the only norm (up to a constant factor) in one dimension, A* is not strongly convex with respect to any norm. Finally, both A and A* are neither smooth nor strongly convex, by the duality between smoothness and strong convexity.

A.2. Gaussian Distributions: Conversion between the Natural and Expectation Parameters

This section gives explicit expressions for ∇A and ∇A* of Gaussian distributions. These maps convert between the natural parameter η = (λ, Λ) and the expectation parameter ω = (ξ, Ξ). Differentiating the log-partition function (17) gives
$$\nabla_\lambda A(\lambda, \Lambda) = -\frac{1}{2}\Lambda^{-1}\lambda, \qquad \nabla_\Lambda A(\lambda, \Lambda) = \frac{1}{4}\Lambda^{-1}\lambda\lambda^\top\Lambda^{-1} - \frac{1}{2}\Lambda^{-1}.$$
The gradient map transforms the natural parameter exactly into the expectation parameter, in that ∇_λA(λ, Λ) = ξ and ∇_ΛA(λ, Λ) = Ξ. Similarly, differentiating the conjugate A* (18) gives
$$\nabla_\xi A^*(\xi, \Xi) = \big(\Xi - \xi\xi^\top\big)^{-1}\xi, \qquad \nabla_\Xi A^*(\xi, \Xi) = -\frac{1}{2}\big(\Xi - \xi\xi^\top\big)^{-1},$$
which transform the expectation parameter back into the natural parameter: ∇_ξA*(ξ, Ξ) = λ and ∇_ΞA*(ξ, Ξ) = Λ. For a Gaussian distribution N(µ, Σ), recall the relation between the mean/covariance and the natural parameter,
$$\lambda = \Sigma^{-1}\mu, \qquad \Lambda = -\frac{1}{2}\Sigma^{-1},$$
and the relation between the mean/covariance and the expectation parameter,
$$\xi = \mu, \qquad \Xi = \Sigma + \mu\mu^\top.$$
One can verify that these relations are indeed consistent with the maps ∇A(η) = ω and ∇A*(ω) = η.

B. Automatic Differentiation Stochastic Gradient Counterexample

This section gives a counterexample where estimating the gradient of the expected log likelihood, $\hat\nabla_\Xi\mathbb{E}_{q(z)}[\log p(y \mid z)]$, using automatic differentiation as shown in Algorithm 1 does not guarantee a negative definite stochastic gradient $\hat\nabla_\Xi$. For simplicity, we use a zero-mean Gaussian q(z) = N(0, Σ), so that Ξ = Σ, and the likelihood p(y | z) = N(z, I). The following code produces a stochastic gradient that is not negative definite approximately 50% of the time.
```python
import torch
from torch.distributions import MultivariateNormal


if __name__ == "__main__":
    # The dimension is not shown in the original listing; d = 3 is assumed here.
    # For odd d, det > 0 certifies that a symmetric matrix is not negative definite.
    d = 3

    # q(z) = N(0, Sigma) with Sigma = I, so that Xi = Sigma.
    sigma = torch.eye(d).requires_grad_()
    chol = torch.linalg.cholesky(sigma)

    # Reparameterized sample z = L u with u ~ N(0, I).
    u = torch.randn(d)
    z = chol @ u

    # Likelihood p(y | z) = N(y; z, I), evaluated at y = 0.
    dist = MultivariateNormal(z, torch.eye(d))
    y = torch.zeros(d)

    # Autodiff estimate of the gradient of E_q[log p(y | z)] with respect to Sigma.
    loss = dist.log_prob(y)
    loss.backward()

    print(loss.item())
    print(sigma.grad)

    # check the diagonal gradient
    # print(-0.5 * u ** 2)

    det = torch.linalg.det(sigma.grad)

    if det > 0.:
        print("not n.d.")

    print("......")
```

C. Stochastic Gradient Variance for Bayesian Linear Regression in Example 1

In this section, we restrict ourselves to Bayesian linear regression in Example 1 and establish bounds on the stochastic gradient variance (12). Recall that the negative ELBO is the sum of the negative expected log likelihood and the KL divergence:
$$\ell(\omega) = -\mathbb{E}_{q(z)}[\log p(y \mid z)] + D_{\mathrm{KL}}(q(z), p(z)).$$
As discussed in Appendix A, the KL divergence has a simple closed-form gradient when the variational distribution q(z) and the prior p(z) are in the same exponential family:
$$\nabla_\omega D_{\mathrm{KL}}(q(z), p(z)) = \eta - \eta_p,$$
where η and η_p are the natural parameters of q(z) and p(z), respectively. Therefore, the stochasticity comes solely from estimating the expected log likelihood, and the stochastic gradient $\hat\nabla\ell(\omega)$ admits the form
$$\hat\nabla\ell(\omega) = -\hat\nabla_\omega\mathbb{E}_{q(z)}[\log p(y \mid z)] + \nabla_\omega D_{\mathrm{KL}}(q(z), p(z)).$$
For now, we assume that the stochastic gradient of the expected log likelihood, $(\hat\nabla_\xi, \hat\nabla_\Xi) = \hat\nabla_\omega\mathbb{E}_{q(z)}[\log p(y \mid z)]$, has a negative definite second component $\hat\nabla_\Xi$, a sufficient condition for valid stochastic NGD updates. We will show why this is the case in the next subsection. The next lemma shows that the natural parameter's second component Λ_t stays bounded away from zero throughout the NGD updates, provided the stochastic gradient $\hat\nabla_\Xi\mathbb{E}_{q(z)}\log p(y \mid z)$ is always negative (semi)definite.

Lemma 5.
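For Bayesian linear regression in Example 1, suppose the stochastic gradient satisfies $\hat\nabla_\Xi\mathbb{E}_{q_t(z)}\log p(y \mid X, z) \preceq 0$ and the step size satisfies $0 \le \gamma_t \le 1$ for all $t \ge 0$. Then $\Lambda_t \preceq -\frac{1}{2\nu}I$, or equivalently $-\frac{1}{2}\Lambda_t^{-1} \preceq \nu I$, throughout the NGD updates for all $t \ge 0$, where $\nu = \max\{1, \|P\|\} > 0$.

Proof. We prove it by induction. The base case t = 0 holds trivially, since the initialization q_0 = N(0, I) gives $\Lambda_0 = -\frac{1}{2}I \preceq -\frac{1}{2\nu}I$ as ν ≥ 1. For the induction step, recall the NGD update on the natural parameter
$$\eta_{t+1} = (1 - \gamma_t)\eta_t + \gamma_t\big(\hat\nabla_\omega\mathbb{E}_{q_t(z)}\log p(y \mid z) + \eta_p\big),$$
where η_p is the natural parameter of the prior. This yields an update on the second component of the natural parameter:
$$\Lambda_{t+1} = (1 - \gamma_t)\Lambda_t + \gamma_t\big(\hat\nabla_\Xi\mathbb{E}_{q_t(z)}\log p(y \mid z) + \Lambda_p\big).$$
By the assumption $\hat\nabla_\Xi\mathbb{E}_{q_t(z)}\log p(y \mid z) \preceq 0$, we have
$$\begin{aligned}
\Lambda_{t+1} &\preceq (1 - \gamma_t)\Lambda_t + \gamma_t\Lambda_p \\
&\preceq -\frac{1}{2\nu}I,
\end{aligned}$$
where the second line uses the induction hypothesis and the fact that $\Lambda_p = -\frac{1}{2}P^{-1} \preceq -\frac{1}{2\nu}I$.

The following lemma shows that the matrix inverse is a Lipschitz function in a region bounded away from zero.

Lemma 6. Suppose $\Lambda_1, \Lambda_2 \preceq -\frac{1}{2\nu}I$. Then $\|\Lambda_1^{-1} - \Lambda_2^{-1}\|_F \le 4\nu^2\|\Lambda_1 - \Lambda_2\|_F$.

Proof. A straightforward calculation gives
$$\begin{aligned}
\|\Lambda_1^{-1} - \Lambda_2^{-1}\|_F &= \|\Lambda_1^{-1}(\Lambda_2 - \Lambda_1)\Lambda_2^{-1}\|_F \\
&\le \|\Lambda_1^{-1}\|\,\|\Lambda_1 - \Lambda_2\|_F\,\|\Lambda_2^{-1}\| \\
&\le 4\nu^2\|\Lambda_1 - \Lambda_2\|_F,
\end{aligned}$$
where the second line uses the inequality $\|AB\|_F \le \|A\|\,\|B\|_F$ (and its transposed version), and the third line uses $\|\Lambda_i^{-1}\| \le 2\nu$, which follows from $\Lambda_i \preceq -\frac{1}{2\nu}I$.

Additional Notations. Let η_t and ω_t be the natural and expectation parameters of the Gaussian variational distribution q_t at step t. Hence, ω = ∇A(η) and η = ∇A*(ω). Define
$$\eta_{t+1,\star} = \eta_t - \gamma_t\nabla\ell(\omega_t)$$
as the natural parameter after an NGD update from η_t using the exact (natural) gradient ∇ℓ(ω_t). Recall that
$$\omega_{t+1,\star} = \mathop{\mathrm{argmin}}_{\omega \in \Omega}\Big\{\langle\nabla\ell(\omega_t), \omega - \omega_t\rangle + \frac{1}{\gamma_t}D_{A^*}(\omega, \omega_t)\Big\}$$
is the expectation parameter after a mirror descent update from ω_t using the exact gradient ∇ℓ(ω_t). Recall also the relation $\omega_{t+1,\star} = \nabla A(\eta_{t+1,\star})$, based on the equivalence of NGD and mirror descent. The components of η_{t+1,⋆} and ω_{t+1,⋆}, i.e., $\eta_{t+1,\star} = (\lambda_{t+1,\star}, \Lambda_{t+1,\star})$ and $\omega_{t+1,\star} = (\xi_{t+1,\star}, \Xi_{t+1,\star})$, are marked with ⋆ in the subscript as well.

C.1. Data Sub-Sampling Stochastic Gradient

The stochastic gradient (15) uses the following estimate of the expected log likelihood:
$$\begin{aligned}
\hat\nabla_\omega\mathbb{E}_{q(z)}[\log p(y \mid z)] &= \nabla_\omega\frac{n}{m}\sum_{k=1}^m\mathbb{E}_{q(z)}\log p(y_{i_k} \mid x_{i_k}, z) \\
&= -\nabla_\omega\frac{n}{m\sigma^2}\sum_{k=1}^m\mathbb{E}_{q(z)}\Big[\frac{1}{2}(y_{i_k} - z^\top x_{i_k})^2\Big] \\
&= -\nabla_\omega\frac{n}{m\sigma^2}\sum_{k=1}^m\mathbb{E}_{q(z)}\Big[\frac{1}{2}(z^\top x_{i_k})^2 - y_{i_k}z^\top x_{i_k}\Big] \\
&= -\nabla_\omega\frac{n}{m\sigma^2}\sum_{k=1}^m\Big(\frac{1}{2}\langle x_{i_k}x_{i_k}^\top, \Xi\rangle - y_{i_k}\langle x_{i_k}, \xi\rangle\Big),
\end{aligned}$$
where the indices i_1, ..., i_m are sampled independently and uniformly from {1, ..., n}, and additive constants independent of ω are dropped. We note that the stochastic gradient's second component $\hat\nabla_\Xi\mathbb{E}_{q(z)}[\log p(y \mid z)] = -\frac{n}{2m\sigma^2}\sum_{k=1}^m x_{i_k}x_{i_k}^\top$ is indeed negative definite whenever the sampled inputs span R^d (and negative semidefinite in general), a requirement for the NGD updates to stay inside the domain (recall Assumption 1). We obtain a concrete expression for the data sub-sampling stochastic gradient $\hat\nabla\ell(\omega_t) = (\hat\nabla_\xi\ell(\omega_t), \hat\nabla_\Xi\ell(\omega_t))$ as follows:
$$\hat\nabla_\xi\ell(\omega_t) = -\frac{n}{m\sigma^2}\sum_{k=1}^m y_{i_k}x_{i_k} + \lambda_t - \lambda_p, \qquad \hat\nabla_\Xi\ell(\omega_t) = \frac{n}{m\sigma^2}\sum_{k=1}^m\frac{1}{2}x_{i_k}x_{i_k}^\top + \Lambda_t - \Lambda_p,$$
where η_t = (λ_t, Λ_t) is the natural parameter of the variational distribution q_t(z) and η_p = (λ_p, Λ_p) is the natural parameter of the prior p(z). Meanwhile, the exact gradient is
$$\nabla_\xi\ell(\omega_t) = -\frac{1}{\sigma^2}\sum_{i=1}^n y_ix_i + \lambda_t - \lambda_p, \qquad \nabla_\Xi\ell(\omega_t) = \frac{1}{\sigma^2}\sum_{i=1}^n\frac{1}{2}x_ix_i^\top + \Lambda_t - \Lambda_p.$$

Roadmap. We give a brief overview before diving into the detailed proof of Lemma 3, which involves a large amount of (somewhat tedious) calculation. Lemmas 7 and 8 bound the gradient variances $\mathbb{E}[\|\hat\nabla_\xi\ell(\omega_t) - \nabla_\xi\ell(\omega_t)\|^2 \mid \omega_t]$ and $\mathbb{E}[\|\hat\nabla_\Xi\ell(\omega_t) - \nabla_\Xi\ell(\omega_t)\|_F^2 \mid \omega_t]$ measured in the Euclidean norm. These two bounds, however, are not quite enough for the convergence proof, as the desired gradient variance (12) does not depend on a specific norm. Lemmas 10 and 11 bound $\|\xi_{t+1,\star} - \xi_{t+1}\|$ and $\|\Xi_{t+1,\star} - \Xi_{t+1}\|_F$ in terms of $\|\hat\nabla_\xi\ell(\omega_t) - \nabla_\xi\ell(\omega_t)\|$ and $\|\hat\nabla_\Xi\ell(\omega_t) - \nabla_\Xi\ell(\omega_t)\|_F$. Lemma 3 then uses Lemmas 10 and 11 to reduce the norm-independent gradient variance (12) to the usual gradient variances measured in the Euclidean norm, which are readily handled by Lemmas 7 and 8.

Lemma 7. The following equality holds:
$$\mathbb{E}\big[\|\hat\nabla_\xi\ell(\omega_t) - \nabla_\xi\ell(\omega_t)\|^2 \mid \omega_t\big] = \frac{n^2}{m\sigma^4}s_1,$$
where we recall that $s_1 = \mathbb{E}_{j \sim U[n]}\big\|y_jx_j - \frac{1}{n}\sum_{i=1}^n y_ix_i\big\|^2$ is the variance of y_jx_j.

Proof.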
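Straightforward calculation gives a proof:
$$\begin{aligned}
\mathbb{E}\big[\|\hat\nabla_\xi\ell(\omega_t) - \nabla_\xi\ell(\omega_t)\|^2 \mid \omega_t\big] &= \frac{1}{\sigma^4}\mathbb{E}\Big\|\frac{n}{m}\sum_{k=1}^m y_{i_k}x_{i_k} - \sum_{i=1}^n y_ix_i\Big\|^2 \\
&= \frac{n^2}{\sigma^4 m^2}\mathbb{E}\Big\|\sum_{k=1}^m\Big(y_{i_k}x_{i_k} - \frac1n\sum_{i=1}^n y_ix_i\Big)\Big\|^2 \\
&= \frac{n^2}{\sigma^4 m}\mathbb{E}_{j\sim U[n]}\Big\|y_jx_j - \frac1n\sum_{i=1}^n y_ix_i\Big\|^2 \\
&= \frac{n^2}{\sigma^4 m}s_1,
\end{aligned}$$
where the third line uses the fact that the $i_k$'s are independently sampled from the uniform distribution $U[n]$, so the cross terms vanish.

Lemma 8. The following identity holds:
$$\mathbb{E}\big[\|\hat\nabla_\Xi\ell(\omega_t) - \nabla_\Xi\ell(\omega_t)\|_F^2 \mid \omega_t\big] = \frac{n^2}{4\sigma^4 m}s_2, \tag{19}$$
where we recall that $s_2 = \mathbb{E}_{j\sim U[n]}\big\|x_jx_j^\top - \frac1n\sum_{i=1}^n x_ix_i^\top\big\|_F^2$ is the variance of $x_jx_j^\top$.

Proof. A straightforward calculation gives a proof:
$$\begin{aligned}
\mathbb{E}\big[\|\hat\nabla_\Xi\ell(\omega_t) - \nabla_\Xi\ell(\omega_t)\|_F^2 \mid \omega_t\big] &= \mathbb{E}\Big\|\frac{n}{2m\sigma^2}\sum_{k=1}^m x_{i_k}x_{i_k}^\top - \frac{1}{2\sigma^2}\sum_{i=1}^n x_ix_i^\top\Big\|_F^2 \\
&= \frac{n^2}{4\sigma^4 m^2}\mathbb{E}\Big\|\sum_{k=1}^m\Big(x_{i_k}x_{i_k}^\top - \frac1n\sum_{i=1}^n x_ix_i^\top\Big)\Big\|_F^2 \\
&= \frac{n^2}{4\sigma^4 m}\mathbb{E}_{j\sim U[n]}\Big\|x_jx_j^\top - \frac1n\sum_{i=1}^n x_ix_i^\top\Big\|_F^2 \\
&= \frac{n^2}{4\sigma^4 m}s_2,
\end{aligned}$$
where the third line uses the fact that the $i_k$'s are sampled independently from the uniform distribution $U[n]$.

Our proof strategy is to relate the desired gradient variance (12) to the gradient variances in Lemmas 7 and 8. To establish the relation, we need to show that the natural parameter's first component $\lambda_t$ and the expectation parameter's first component $\xi_t$ are bounded throughout the NGD updates. The trick is to observe that the natural parameter's first component $\lambda_t$ stays in a particular region:

Lemma 9. Define the convex set $C = \big\{\sum_{i=1}^n \rho_iy_ix_i : \rho_i \ge 0, \sum_{i=1}^n \rho_i \le n\big\}$. Then we have $\lambda_t \in C$ and $\lambda_{t+1,*} \in C$ throughout the NGD updates for all $t \ge 0$.

Proof. We prove $\lambda_t \in C$ by induction. The base case $t = 0$ holds because the initialization $q_0 = \mathcal{N}(0, I)$ has $\lambda_0 = 0$, with coefficients $\rho_1 = \rho_2 = \cdots = \rho_n = 0$. For the inductive step, suppose $\lambda_t \in C$ and recall the update from $t$ to $t+1$:
$$\begin{aligned}
\lambda_{t+1} &= (1-\gamma_t)\lambda_t + \gamma_t\big(\hat\nabla_\xi\mathbb{E}_{q_t(z)}[\log p(y \mid z)] + \lambda_p\big) \\
&= (1-\gamma_t)\lambda_t + \gamma_t\hat\nabla_\xi\mathbb{E}_{q_t(z)}[\log p(y \mid z)],
\end{aligned}$$
where the second line uses $\lambda_p = 0$, since the prior $p(z)$ is a zero-mean Gaussian. Recall that the stochastic gradient of the expected log likelihood at step $t$ is of the form
$$\hat\nabla_\xi\mathbb{E}_{q_t(z)}[\log p(y \mid z)] = \frac{n}{m}\sum_{k=1}^m y_{i_k}x_{i_k},$$
where the $i_k$'s are sampled independently and uniformly from $\{1, 2, \dots, n\}$. This stochastic gradient is in the convex set $C$, since its coefficients are nonnegative and sum to exactly $n$.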
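Observe that $\lambda_{t+1}$ is a convex combination of two points in $C$ and thus stays in $C$ as well. The proof is completed by induction. The argument for $\lambda_{t+1,*} \in C$ follows similarly, because the exact gradient $\nabla_\xi\mathbb{E}_{q_t(z)}[\log p(y \mid z)] = \sum_{i=1}^n y_ix_i$ is in the convex set $C$ as well.

As a result, we immediately obtain a bound on the first component of the natural parameter:

Corollary 1. We have $\|\lambda_t\| \le bn$ and $\|\lambda_{t+1,*}\| \le bn$ for all $t \ge 0$, where $b = \max_{1\le i\le n}\|y_ix_i\|$.

Proof. Writing $\lambda_t = \sum_{i=1}^n \rho_iy_ix_i$ with $\rho_i \ge 0$ and $\sum_{i=1}^n \rho_i \le n$ (Lemma 9), straightforward calculation gives
$$\|\lambda_t\| \le \sum_{i=1}^n \rho_i\|y_ix_i\| \le bn.$$
The proof for $\lambda_{t+1,*}$ follows the same steps.

As a result, we also obtain a bound on the first component of the expectation parameter:

Corollary 2. We have $\|\xi_t\| \le \nu bn$ and $\|\xi_{t+1,*}\| \le \nu bn$ for all $t \ge 0$, where $b = \max_{1\le i\le n}\|y_ix_i\|$.

Proof. Recall the relation between the natural and expectation parameters: $\xi_t = -\frac12\Lambda_t^{-1}\lambda_t$. Recall that $0 \preceq -\frac12\Lambda_t^{-1} \preceq \nu I$ by Lemma 5. Thus, we have $\|\xi_t\| \le \nu\|\lambda_t\| \le \nu bn$.

Lemma 10. We have
$$\|\xi_{t+1,*} - \xi_{t+1}\| \le \gamma_t\nu\|\hat\nabla_\xi\ell(\omega_t) - \nabla_\xi\ell(\omega_t)\| + 2\gamma_t\nu^2bn\|\hat\nabla_\Xi\ell(\omega_t) - \nabla_\Xi\ell(\omega_t)\|_F.$$

Proof. Straightforward calculation gives
$$\begin{aligned}
\|\xi_{t+1,*} - \xi_{t+1}\| &= \frac12\|\Lambda_{t+1}^{-1}\lambda_{t+1} - \Lambda_{t+1,*}^{-1}\lambda_{t+1,*}\| \\
&\le \frac12\|\Lambda_{t+1}^{-1}\lambda_{t+1} - \Lambda_{t+1}^{-1}\lambda_{t+1,*}\| + \frac12\|\Lambda_{t+1}^{-1}\lambda_{t+1,*} - \Lambda_{t+1,*}^{-1}\lambda_{t+1,*}\| \\
&= \frac12\|\Lambda_{t+1}^{-1}(\lambda_{t+1} - \lambda_{t+1,*})\| + \frac12\|(\Lambda_{t+1}^{-1} - \Lambda_{t+1,*}^{-1})\lambda_{t+1,*}\|,
\end{aligned}$$
where the first line uses the relation between the natural and expectation parameters. We cope with the two terms separately. For the first term, we have
$$\frac12\|\Lambda_{t+1}^{-1}(\lambda_{t+1} - \lambda_{t+1,*})\| \le \nu\|\lambda_{t+1} - \lambda_{t+1,*}\| = \gamma_t\nu\|\hat\nabla_\xi\ell(\omega_t) - \nabla_\xi\ell(\omega_t)\|,$$
where the inequality uses $0 \preceq -\frac12\Lambda_{t+1}^{-1} \preceq \nu I$ by Lemma 5 and the equality uses the definition of the NGD update. For the second term, we have
$$\begin{aligned}
\frac12\|(\Lambda_{t+1}^{-1} - \Lambda_{t+1,*}^{-1})\lambda_{t+1,*}\| &\le \frac12\|\Lambda_{t+1}^{-1} - \Lambda_{t+1,*}^{-1}\|_F\,\|\lambda_{t+1,*}\| \\
&\le 2\nu^2\|\Lambda_{t+1} - \Lambda_{t+1,*}\|_F\,\|\lambda_{t+1,*}\| \\
&= 2\gamma_t\nu^2\|\hat\nabla_\Xi\ell(\omega_t) - \nabla_\Xi\ell(\omega_t)\|_F\,\|\lambda_{t+1,*}\| \\
&\le 2\gamma_t\nu^2bn\|\hat\nabla_\Xi\ell(\omega_t) - \nabla_\Xi\ell(\omega_t)\|_F,
\end{aligned}$$
where the second line uses the Lipschitz condition in Lemma 6; the third line uses the definition of the NGD update; the last line uses Corollary 1. Summing the two bounds completes the proof.

Lemma 11. We have
$$\|\Xi_{t+1,*} - \Xi_{t+1}\|_F \le 2\gamma_t\nu^2bn\|\hat\nabla_\xi\ell(\omega_t) - \nabla_\xi\ell(\omega_t)\| + (2\gamma_t\nu^2 + 4\gamma_t\nu^3b^2n^2)\|\hat\nabla_\Xi\ell(\omega_t) - \nabla_\Xi\ell(\omega_t)\|_F.$$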
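Proof. To shorten notation, write $\delta_\xi = \hat\nabla_\xi\ell(\omega_t) - \nabla_\xi\ell(\omega_t)$ and $\delta_\Xi = \hat\nabla_\Xi\ell(\omega_t) - \nabla_\Xi\ell(\omega_t)$. Expanding the norm, we have
$$\begin{aligned}
\|\Xi_{t+1,*} - \Xi_{t+1}\|_F &= \Big\|-\frac12\big(\Lambda_{t+1,*}^{-1} - \Lambda_{t+1}^{-1}\big) + \xi_{t+1,*}\xi_{t+1,*}^\top - \xi_{t+1}\xi_{t+1}^\top\Big\|_F \\
&\le \frac12\|\Lambda_{t+1,*}^{-1} - \Lambda_{t+1}^{-1}\|_F + \|\xi_{t+1,*}\xi_{t+1,*}^\top - \xi_{t+1}\xi_{t+1}^\top\|_F,
\end{aligned}$$
where the first line uses the relation $\Xi = -\frac12\Lambda^{-1} + \xi\xi^\top$ between the natural and expectation parameters. For the first term, we have
$$\begin{aligned}
\frac12\|\Lambda_{t+1,*}^{-1} - \Lambda_{t+1}^{-1}\|_F &\le 2\nu^2\|\Lambda_{t+1,*} - \Lambda_{t+1}\|_F \\
&= 2\gamma_t\nu^2\|\delta_\Xi\|_F,
\end{aligned}$$
where the first line uses Lemma 6; the second line uses the definition of the NGD update. For the second term, we have
$$\begin{aligned}
\|\xi_{t+1,*}\xi_{t+1,*}^\top - \xi_{t+1}\xi_{t+1}^\top\|_F &= \|\xi_{t+1,*}\xi_{t+1,*}^\top - \xi_{t+1,*}\xi_{t+1}^\top + \xi_{t+1,*}\xi_{t+1}^\top - \xi_{t+1}\xi_{t+1}^\top\|_F \\
&\le \|\xi_{t+1,*}(\xi_{t+1,*} - \xi_{t+1})^\top\|_F + \|(\xi_{t+1,*} - \xi_{t+1})\xi_{t+1}^\top\|_F \\
&= \|\xi_{t+1,*}\|\,\|\xi_{t+1,*} - \xi_{t+1}\| + \|\xi_{t+1,*} - \xi_{t+1}\|\,\|\xi_{t+1}\| \\
&\le 2\max\{\|\xi_{t+1,*}\|, \|\xi_{t+1}\|\}\,\|\xi_{t+1,*} - \xi_{t+1}\| \\
&\le 2\nu bn\|\xi_{t+1,*} - \xi_{t+1}\| \\
&\le 2\gamma_t\nu^2bn\|\delta_\xi\| + 4\gamma_t\nu^3b^2n^2\|\delta_\Xi\|_F,
\end{aligned}$$
where the third line is because $\|ab^\top\|_F = \|a\|\,\|b\|$; the fifth line uses Corollary 2; the last line uses Lemma 10. Summing the above two parts finishes the proof.

Now we are ready to prove the main result of this section, the variance bound for the data sub-sampling stochastic gradient.

Lemma 3. The stochastic gradient (15) satisfies
$$\frac{1}{\gamma_t}\mathbb{E}\big[\langle\hat\nabla\ell(\omega_t) - \nabla\ell(\omega_t), \omega_{t+1,*} - \omega_{t+1}\rangle \mid \omega_t\big] \le V_2, \tag{16}$$
where $V_2 = \big(\nu s_1 + \frac12\nu^2s_2 + 2\nu^2b\sqrt{s_1s_2}\,n + \nu^3b^2s_2n^2\big)\frac{n^2}{\sigma^4 m}$, with $\nu = \max\{1, \|\Sigma_p\|\}$, $b = \max_{1\le i\le n}\|y_ix_i\|$, and the empirical variances $s_1 = \mathbb{E}_{j\sim U[n]}\big\|y_jx_j - \frac1n\sum_{i=1}^n y_ix_i\big\|^2$ and $s_2 = \mathbb{E}_{j\sim U[n]}\big\|x_jx_j^\top - \frac1n\sum_{i=1}^n x_ix_i^\top\big\|_F^2$.

Proof. We keep the shorthand $\delta_\xi$ and $\delta_\Xi$ from the proof of Lemma 11. Expanding the inner product inside the expectation, we need to bound the expectation of
$$\langle\delta_\xi, \xi_{t+1,*} - \xi_{t+1}\rangle + \langle\delta_\Xi, \Xi_{t+1,*} - \Xi_{t+1}\rangle.$$
For the first term, applying the Cauchy-Schwarz inequality and Lemma 10 yields
$$\begin{aligned}
\langle\delta_\xi, \xi_{t+1,*} - \xi_{t+1}\rangle &\le \|\delta_\xi\|\,\|\xi_{t+1,*} - \xi_{t+1}\| \\
&\le \gamma_t\nu\|\delta_\xi\|^2 + 2\gamma_t\nu^2bn\|\delta_\xi\|\,\|\delta_\Xi\|_F.
\end{aligned}\tag{20}$$
For the second term, applying the Cauchy-Schwarz inequality and Lemma 11 gives
$$\begin{aligned}
\langle\delta_\Xi, \Xi_{t+1,*} - \Xi_{t+1}\rangle &\le \|\delta_\Xi\|_F\,\|\Xi_{t+1,*} - \Xi_{t+1}\|_F \\
&\le 2\gamma_t\nu^2bn\|\delta_\xi\|\,\|\delta_\Xi\|_F + (2\gamma_t\nu^2 + 4\gamma_t\nu^3b^2n^2)\|\delta_\Xi\|_F^2.
\end{aligned}\tag{21}$$
Summing (20) and (21), and then applying the inequality
$$\mathbb{E}\big[\|\delta_\xi\|\,\|\delta_\Xi\|_F\big] \le \sqrt{\mathbb{E}\big[\|\delta_\xi\|^2\big]\,\mathbb{E}\big[\|\delta_\Xi\|_F^2\big]},$$
where the expectations are conditioned on $\omega_t$, we obtain the bound
$$\begin{aligned}
\mathbb{E}\big[\langle\hat\nabla\ell(\omega_t) - \nabla\ell(\omega_t), \omega_{t+1,*} - \omega_{t+1}\rangle \mid \omega_t\big] &\le \gamma_t\nu\,\mathbb{E}[\|\delta_\xi\|^2 \mid \omega_t] + (2\gamma_t\nu^2 + 4\gamma_t\nu^3b^2n^2)\,\mathbb{E}[\|\delta_\Xi\|_F^2 \mid \omega_t] \\
&\quad + 4\gamma_t\nu^2bn\sqrt{\mathbb{E}[\|\delta_\xi\|^2 \mid \omega_t]\,\mathbb{E}[\|\delta_\Xi\|_F^2 \mid \omega_t]} \\
&= \gamma_t\nu\frac{n^2}{\sigma^4 m}s_1 + (2\gamma_t\nu^2 + 4\gamma_t\nu^3b^2n^2)\frac{n^2}{4\sigma^4 m}s_2 + 4\gamma_t\nu^2bn\sqrt{\frac{n^2}{\sigma^4 m}s_1\cdot\frac{n^2}{4\sigma^4 m}s_2} \\
&= \gamma_t\Big(\nu s_1 + \frac12\nu^2s_2 + 2\nu^2b\sqrt{s_1s_2}\,n + \nu^3b^2s_2n^2\Big)\frac{n^2}{\sigma^4 m},
\end{aligned}$$
where the first equality is due to Lemma 7 and Lemma 8. Dividing both sides by $\gamma_t$ completes the proof.

D. Proof of the Main Theorem

Lemma 2. For conjugate likelihoods, the negative ELBO $\ell(\omega)$ is 1-smooth and 1-strongly convex relative to the convex conjugate $A^*$ of the log-partition function.

Proof. Let $q_*(z) = p(z \mid y)$ be the posterior. By the definition of the negative ELBO, we have $\ell(\omega) = D_{\mathrm{KL}}(q, q_*) + C$, where $q \in \mathcal{Q}$ is the variational distribution inside an exponential family $\mathcal{Q}$ parameterized by the expectation parameter $\omega$, and $C = -\log p(y)$ is a constant (the negative log evidence) that does not depend on $q$ or $\omega$. Thanks to conjugacy, the posterior $q_*$ is of the same form as $q$. By Lemma 4, $\ell(\omega) - C = D_{\mathrm{KL}}(q, q_*) = D_{A^*}(\omega, \omega_*)$. Observe that the Bregman divergence $D_{A^*}(\omega, \omega_*)$ is trivially 1-smooth and 1-strongly convex in $\omega$ relative to $A^*$.

Below we present the main theorem, which adapts the results of Hanzely & Richtárik (2021) to stochastic natural gradient variational inference.

Theorem 1. Suppose the likelihood $p(y \mid z)$ is conjugate and the stochastic gradient $\hat\nabla\ell(\omega_t)$ satisfies Assumption 1.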
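Running $T + 1$ iterations of stochastic natural gradient descent with $\gamma_t = \frac{2}{2+t}$ generates a point $\bar\omega_{T+1}$ that satisfies
$$\mathbb{E}[\ell(\bar\omega_{T+1})] - \min_{\omega\in\Omega}\ell(\omega) \le \frac{V}{T+2}, \tag{13}$$
where $\bar\omega_{T+1} = \frac{2}{(T+1)(T+2)}\sum_{t=0}^T (t+1)\omega_{t+1}$. Let $\bar q_{T+1}$ be the variational distribution represented by $\bar\omega_{T+1}$. Then, the KL divergence to the true posterior $q_*$ is bounded by
$$\mathbb{E}[D_{\mathrm{KL}}(\bar q_{T+1}, q_*)] \le \frac{V}{T+2}. \tag{14}$$

Proof. By the descent lemma of Hanzely & Richtárik (2021, Lemma 5.2), we have
$$\mathbb{E}[\ell(\omega_{t+1})] - \ell(\omega_*) \le \Big(\frac{1}{\gamma_t} - 1\Big)D_{A^*}(\omega_*, \omega_t) - \frac{1}{\gamma_t}\mathbb{E}[D_{A^*}(\omega_*, \omega_{t+1})] + \gamma_tV.$$
Plugging in $\gamma_t = \frac{2}{2+t}$, we obtain
$$\mathbb{E}[\ell(\omega_{t+1})] - \ell(\omega_*) \le \frac{t}{2}D_{A^*}(\omega_*, \omega_t) - \frac{t+2}{2}\mathbb{E}[D_{A^*}(\omega_*, \omega_{t+1})] + \gamma_tV.$$
Multiply the inequality by $t+1$, take the total expectation, and sum from $t = 0$ to $T$. Since $(t+1)\cdot\frac{t+2}{2} = (t+2)\cdot\frac{t+1}{2}$, the Bregman terms telescope; the term at $t = 0$ vanishes and the remaining term at $t = T$ is nonpositive. Then we have
$$\sum_{t=0}^T (t+1)\big(\mathbb{E}[\ell(\omega_{t+1})] - \ell(\omega_*)\big) \le V\sum_{t=0}^T (t+1)\gamma_t.$$
Dividing both sides by $\sum_{t=0}^T (t+1) = \frac12(T+1)(T+2)$ and using the convexity of $\ell$, which bounds $\ell(\bar\omega_{T+1})$ by the weighted average of the $\ell(\omega_{t+1})$, we obtain (13).

To get the convergence rate in terms of the KL divergence, notice that
$$\begin{aligned}
\ell(\bar\omega_{T+1}) - \ell(\omega_*) &= \langle\nabla\ell(\omega_*), \bar\omega_{T+1} - \omega_*\rangle + D_{A^*}(\bar\omega_{T+1}, \omega_*) \\
&= D_{A^*}(\bar\omega_{T+1}, \omega_*) \\
&= D_{\mathrm{KL}}(\bar q_{T+1}, q_*),
\end{aligned}$$
where the first line is due to 1-smoothness and 1-strong convexity relative to $A^*$; the second line is because the optimal parameter $\omega_*$ has zero gradient; the third line is due to Lemma 4. Taking expectations and applying (13) gives (14).

E. Missing Proofs in Section 5

Proposition 1. Suppose the prior and the variational family are both Gaussian. If the likelihood $p(y \mid z)$ is log-concave in $z$, then the negative ELBO $\ell(\omega)$, as a function of the expectation parameter, has a unique minimizer $\omega_*$. In addition, if the likelihood $p(y \mid z)$ is differentiable in $z$, then $\omega_*$ is the unique stationary point of $\ell(\omega)$.

Proof. Consider the set $\Theta = \{\theta = (\mu, C) : \mu \in \mathbb{R}^d, C \in \mathbb{S}^d_{++}\}$, which parameterizes all (non-degenerate) Gaussian distributions. Define $f(\theta) = (\mu, CC^\top + \mu\mu^\top)$. Namely, $f$ maps $\theta$ to the expectation parameter space $\Omega$. Thanks to the uniqueness of the positive definite matrix square root, $f$ is a bijection. Since $\ell^{(\mathrm{mr})}$, the negative ELBO written as a function of $\theta$, is strongly convex in $\theta$, it has a unique minimizer $\theta_* \in \Theta$. Define $\omega_* = f(\theta_*)$.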
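It is clear that $\omega_* \in \Omega$ is the unique minimizer of $\ell^{(\mathrm{e})}$, the negative ELBO as a function of the expectation parameter. Consider the identity
$$\ell^{(\mathrm{mr})}(\theta) = \ell^{(\mathrm{e})}(f(\theta)). \tag{22}$$
Taking the derivative of both sides of (22), we have
$$\nabla_\mu\ell^{(\mathrm{mr})}(\theta) = \nabla_\xi\ell^{(\mathrm{e})}\big|_{\omega=f(\theta)} + 2\nabla_\Xi\ell^{(\mathrm{e})}\big|_{\omega=f(\theta)}\,\mu, \qquad \nabla_C\ell^{(\mathrm{mr})}(\theta) = 2\nabla_\Xi\ell^{(\mathrm{e})}\big|_{\omega=f(\theta)}\,C.$$
Since $C$ is invertible, it is easy to see that $\nabla\ell^{(\mathrm{mr})}(\theta) = 0$ if and only if $\nabla\ell^{(\mathrm{e})}(f(\theta)) = 0$. Namely, $f$ maps stationary points to stationary points. Since there is only one stationary point in $\Theta$ due to strong convexity, there is only one stationary point in $\Omega$ as well.

E.1. Bayesian Logistic Regression

We give a more detailed description of the non-convexity of Bayesian logistic regression. Recall that we focus on the restriction of $\ell(\omega)$ to the convex subset $\{\omega = (0, \Xi) : \Xi = \operatorname{diag}(s_1, s_2), s_1 > 0, s_2 > 0\} \subset \Omega$. Observe that $wx_i + b$ follows the Gaussian distribution $\mathcal{N}(0, x_i^2s_1 + s_2)$. Therefore, we can use Price's theorem, which for $u \sim \mathcal{N}(0, v)$ gives $\frac{\partial}{\partial v}\mathbb{E}[g(u)] = \frac12\mathbb{E}[g''(u)]$, to take derivatives w.r.t. $s_2$. Taking the first-order derivative of $\ell(\omega)$ w.r.t. $s_2$, we have
$$\frac{\partial}{\partial s_2}\ell(\omega) = \frac12\sum_{i=1}^n \mathbb{E}_{q(w,b)}[\psi_i(1-\psi_i)] + \frac12 - \frac{1}{2s_2},$$
where we use $\psi_i$ to denote $\psi(wx_i + b)$ and $\psi$ is the sigmoid function. Using Price's theorem again to take the second-order derivative of $\ell(\omega)$ w.r.t. $s_2$, we have
$$\frac{\partial^2}{\partial s_2^2}\ell(\omega) = \frac14\sum_{i=1}^n \mathbb{E}_{q(w,b)}[\psi_i(1-\psi_i)(6\psi_i^2 - 6\psi_i + 1)] + \frac{1}{2s_2^2}.$$
Note that $\frac{\partial^2}{\partial s_2^2}\ell(\omega)$ is continuous w.r.t. $s_1$ and $s_2$. Moreover, since $\psi_i \to \psi(0) = \frac12$ as $s_1, s_2 \to 0$, we have
$$\lim_{s_1\to0,\, s_2\to0}\mathbb{E}[\psi_i(1-\psi_i)(6\psi_i^2 - 6\psi_i + 1)] = -\frac18.$$
Therefore, there exists a small positive constant $\delta > 0$ such that, for $s_1 = s_2 = \delta$,
$$\mathbb{E}_{w\sim\mathcal{N}(0,s_1),\, b\sim\mathcal{N}(0,s_2)}[\psi_i(1-\psi_i)(6\psi_i^2 - 6\psi_i + 1)] < -\frac{1}{16}.$$
Crucially, $\delta$ is an absolute constant that does not depend on $i$. Because all $-1 \le x_i \le 1$ are bounded, the distribution $wx_i + b \sim \mathcal{N}(0, x_i^2s_1 + s_2)$ shrinks to zero as $s_1 + s_2 \to 0$, regardless of the index $i$. This implies that when $s_1 = s_2 = \delta$, we have
$$\sum_{i=1}^n \mathbb{E}_{w\sim\mathcal{N}(0,s_1),\, b\sim\mathcal{N}(0,s_2)}[\psi_i(1-\psi_i)(6\psi_i^2 - 6\psi_i + 1)] < -\frac{n}{16}.$$
Therefore, when $s_1 = s_2 = \delta$ and $n > \frac{32}{\delta^2}$, the second-order derivative is negative:
$$\frac{\partial^2}{\partial s_2^2}\ell(\omega) < -\frac{n}{64} + \frac{1}{2\delta^2} < 0,$$
which implies that the objective is non-convex in the expectation parameter.

E.2. Bayesian Poisson Regression

Bayesian Poisson regression assumes that $y \mid x$ follows a Poisson distribution with expectation $\mathbb{E}[y \mid x] = \exp(w^\top x)$, which gives the log likelihood $\log p(y \mid x, w) = -\log y! + yw^\top x - \exp(w^\top x)$. We impose a Gaussian prior $p(w) = \mathcal{N}(0, I)$ and approximate the posterior $p(w \mid y)$ using a Gaussian variational distribution $q(w)$. A nice property of Bayesian Poisson regression is that its negative ELBO has a closed-form expression, up to the additive constant $\sum_{i=1}^n \log y_i!$:
$$\begin{aligned}
\ell(\omega) &= \sum_{i=1}^n \mathbb{E}_{q(w)}\big[-y_iw^\top x_i + \exp(w^\top x_i)\big] + D_{\mathrm{KL}}(q, p) \\
&= \sum_{i=1}^n \Big[-y_i\xi^\top x_i + \exp\Big(\xi^\top x_i + \frac12 x_i^\top(\Xi - \xi\xi^\top)x_i\Big)\Big] + D_{A^*}(\omega, \omega_0).
\end{aligned}$$
The Hessian $\nabla_\xi^2\ell(\omega)$ is
$$\nabla_\xi^2\ell(\omega) = \sum_{i=1}^n \exp\Big(\xi^\top x_i + \frac12 x_i^\top(\Xi - \xi\xi^\top)x_i\Big)\big((1 - x_i^\top\xi)^2 - 1\big)x_ix_i^\top + \nabla_\xi^2A^*(\omega).$$
Evaluating the Hessian on the subset $\{\omega = (\xi, \Xi) \in \Omega : \Xi = \xi\xi^\top + 2I\}$ of the domain, we obtain
$$\nabla_\xi^2\ell(\omega) = \sum_{i=1}^n \exp(x_i^\top\xi + x_i^\top x_i)\,x_i^\top\xi\,(x_i^\top\xi - 2)\,x_ix_i^\top + \nabla_\xi^2A^*(\omega).$$
With $0 < x_i^\top\xi < 2$ for all $i$, which can be satisfied by constructing the dataset properly, each coefficient $x_i^\top\xi(x_i^\top\xi - 2)$ is negative. Using $\exp(x_i^\top\xi + x_i^\top x_i) > 1$, we can therefore drop the exponential term, i.e., $\nabla_\xi^2\ell(\omega) \preceq \sum_{i=1}^n x_i^\top\xi(x_i^\top\xi - 2)x_ix_i^\top + \nabla_\xi^2A^*(\omega)$. The rest of the argument follows the main paper.

F. Experimental Details

In all experiments, SGD uses the $(m, C)$ parameterization, where $m$ is the Gaussian mean and $C$ is the Cholesky factor of the Gaussian covariance. We parameterize $C$ as a lower triangular matrix with strictly positive diagonal entries. For SGD, we clamp the diagonal entries of $C$ so that they are no smaller than $10^{-10}$. This is effectively a projection step.

F.1. Bayesian Linear Regression

This is a Bayesian linear regression problem exactly the same as Example 1, with a standard Gaussian prior. Note that the expected log likelihood $\mathbb{E}_{q(z)}\log p(y \mid X, z)$ is integrated in closed form, so the only stochasticity comes from mini-batch data sub-sampling. Domke et al. (2023, Theorem 7 and Theorem 10) have proved convergence for stochastic proximal (projected) gradient descent with a step size schedule $\gamma_t$ defined as a minimum of two terms, one of which decays as $\frac{2t+1}{(t+1)^2}$; the schedule depends on the strong convexity constant $\mu$ and on a constant $a$. It is not easy to come up with a tight estimate of the constant $a$.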
Therefore, we pick the step size schedule $\frac{1}{10^5 + t}$, which decays as $O(1/t)$, for SGD. The reason for the specific constant $10^5$ in the denominator is that $10^{-5}$ is roughly the largest step size for which SGD does not diverge in its initial stage.

F.2. Bayesian Logistic Regression

On Mushroom, the step size of SGD is tuned from $\{10^{-3}, 10^{-4}, 10^{-5}, 10^{-6}\}$, while the step size of NGD is tuned from $\{5\times10^{-1}, 10^{-1}, 10^{-2}, 10^{-3}\}$. On MNIST, the step size of SGD is tuned from $\{10^{-5}, 10^{-6}, 10^{-7}\}$, while the step size of NGD is tuned from $\{10^{-1}, 10^{-2}, 10^{-3}\}$. Divergent curves (due to large step sizes) are not plotted. We use 10 samples from the variational distribution to estimate the stochastic gradient in every iteration. Legends without the label "(p)" use the reparameterization trick to compute the stochastic gradient. For SGD with the label "(p)", we use Price's theorem as follows. First, observe the following relation between $C$ and $\Sigma = CC^\top$:
$$\nabla_C\mathbb{E}_{q(z)}\log p(y \mid z) = 2\nabla_\Sigma\mathbb{E}_{q(z)}[\log p(y \mid z)]\,C = \mathbb{E}_{q(z)}[\nabla_z^2\log p(y \mid z)]\,C.$$
To obtain a stochastic gradient estimate $\hat\nabla_C\mathbb{E}_{q(z)}\log p(y \mid z)$, we replace the expectation with a sample approximation.
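The two stochastic gradient estimators for $C$ described above can be compared with a short sketch. The NumPy code below is a hypothetical illustration, not the implementation used in the experiments. It assumes a logistic likelihood with labels $y_i \in \{0, 1\}$, a small synthetic dataset, and a helper named `grad_C_estimates`; all of these are our own choices. Both estimators target the same full $d \times d$ matrix $\nabla_C\mathbb{E}_{q(z)}\log p(y \mid z)$ for $q = \mathcal{N}(m, CC^\top)$. The Price-theorem version averages Hessians and then multiplies by $C$. The reparameterization version averages $\nabla_z\log p(y \mid z)\,\epsilon^\top$ with $z = m + C\epsilon$.

```python
import numpy as np


def sigmoid(t):
    # Logistic function psi(t) = 1 / (1 + exp(-t)).
    return 1.0 / (1.0 + np.exp(-t))


def grad_C_estimates(m, C, X, y, num_samples=10, seed=0):
    """Monte Carlo estimates of grad_C E_{q(z)}[log p(y | z)] for q = N(m, C C^T).

    Hypothetical helper. Assumes the logistic likelihood
    log p(y | z) = sum_i [y_i log psi(x_i^T z) + (1 - y_i) log(1 - psi(x_i^T z))]
    with labels y_i in {0, 1}. Returns (price, reparam), both full d x d matrices.
    """
    rng = np.random.default_rng(seed)
    eps = rng.standard_normal((num_samples, m.shape[0]))  # eps_s ~ N(0, I)
    Z = m + eps @ C.T                                      # row s is z_s = m + C eps_s
    P = sigmoid(Z @ X.T)                                   # P[s, i] = psi(x_i^T z_s)

    # Price's theorem: grad_C = 2 grad_Sigma E[log p] C = E[Hessian] C, where the
    # Hessian is -sum_i psi_i (1 - psi_i) x_i x_i^T (it does not depend on y).
    w = (P * (1.0 - P)).mean(axis=0)                       # sample mean of psi_i (1 - psi_i)
    H_bar = -(X * w[:, None]).T @ X                        # sample mean of the Hessians
    price = H_bar @ C

    # Reparameterization trick: grad_C = E[grad_z log p(z) eps^T], where
    # grad_z log p(z) = sum_i (y_i - psi_i) x_i.
    G = (y - P) @ X                                        # row s is grad_z log p(z_s)
    reparam = G.T @ eps / num_samples
    return price, reparam


# Usage on a tiny synthetic problem (made-up data, d = 2, n = 5).
X = np.array([[1.0, 0.5], [-0.5, 1.0], [0.3, -0.8], [1.2, 0.2], [-1.0, -0.4]])
y = np.array([1.0, 0.0, 1.0, 1.0, 0.0])
m = np.array([0.2, -0.1])
C = np.array([[1.0, 0.0], [0.3, 0.8]])  # lower triangular with a positive diagonal
price, reparam = grad_C_estimates(m, C, X, y, num_samples=200_000)
print(np.round(price, 3))
print(np.round(reparam, 3))
```

By Stein's lemma the two estimators have the same expectation. With many samples they should therefore agree up to Monte Carlo error, while with the 10 samples per iteration mentioned above they generally differ. If $C$ is restricted to be lower triangular, the gradient with respect to its free entries is the lower-triangular part of this matrix, e.g. `np.tril(price)`.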