# Stochastic Expectation Maximization with Variance Reduction

Jianfei Chen, Jun Zhu, Yee Whye Teh and Tong Zhang

Dept. of Comp. Sci. & Tech., BNRist Center, State Key Lab for Intell. Tech. & Sys., Institute for AI, THBI Lab, Tsinghua University, Beijing, 100084, China
Department of Statistics, University of Oxford
Tencent AI Lab
{chenjian14@mails,dcszj@}.tsinghua.edu.cn; y.w.teh@stats.ox.ac.uk; tongzhang@tongzhang-ml.org

Abstract

Expectation-Maximization (EM) is a popular tool for learning latent variable models, but the vanilla batch EM does not scale to large data sets because the whole data set is needed at every E-step. Stochastic Expectation Maximization (sEM) reduces the cost of the E-step by stochastic approximation. However, sEM has a slower asymptotic convergence rate than batch EM, and requires a decreasing sequence of step sizes, which is difficult to tune. In this paper, we propose a variance reduced stochastic EM (sEM-vr) algorithm inspired by variance reduced stochastic gradient descent algorithms. We show that sEM-vr has the same exponential asymptotic convergence rate as batch EM. Moreover, sEM-vr only requires a constant step size to achieve this rate, which alleviates the burden of parameter tuning. We compare sEM-vr with batch EM, sEM and other algorithms on Gaussian mixture models and probabilistic latent semantic analysis, and sEM-vr converges significantly faster than these baselines.

1 Introduction

Latent variable models are an important class of models due to their wide applicability across machine learning and statistics. Examples include factor analysis in psychology and the understanding of human cognition [32], hidden Markov models for modelling sequences, e.g., speech and language [29] and DNA [15], document and topic models [17, 4], and mixture models for density estimation and clustering [26].
Expectation Maximization (EM) [12] is a basic tool for maximum likelihood estimation of the parameters in latent variable models. It is an iterative algorithm with two steps: an E-step, which calculates the expectation of sufficient statistics under the latent variable posteriors given the current parameters, and an M-step, which updates the parameters given the expectations. With the phenomenal growth of big data sets in recent years, the basic batch EM (bEM) algorithm in [12] is quickly becoming infeasible because the whole data set is needed at every E-step. Cappé and Moulines [6] proposed a stochastic EM (sEM) algorithm for exponential family models, which reduces the time complexity of the E-step by approximating the full-batch expectation with an exponential moving average over minibatches of data. sEM has been adopted in many applications, including natural language processing [24], topic modeling [16, 14] and hidden Markov models [5].

However, sEM has a slow asymptotic convergence rate due to the high variance of each update. Unlike the original batch EM (bEM), which converges exponentially fast near a local optimum, sEM only decreases the distance towards a local optimum at the rate $O(1/\sqrt{T})$, where $T$ is the number of iterations. Moreover, sEM requires a decreasing sequence of step sizes to converge, and the decay rate of the step sizes is often difficult to tune.

∗ Corresponding author.
32nd Conference on Neural Information Processing Systems (NeurIPS 2018), Montréal, Canada.

Recently, there has been much progress in accelerating stochastic gradient descent (SGD) by reducing the variance of the stochastic gradients, including SAG, SAGA and SVRG [22, 20, 11]. These algorithms achieve better convergence rates by using infrequently computed batch gradients as control variates.
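To make the control-variate idea concrete, here is a minimal sketch of an SVRG-style gradient estimator on a hypothetical one-dimensional least-squares problem. The data `A`, `B` and all function names are illustrative assumptions, not taken from [20]. Because the index $i$ is drawn uniformly from a finite set, the mean and variance of each estimator are computed exactly by enumerating all indices rather than by sampling.

```python
from statistics import fmean

# Hypothetical 1-D least-squares problem: loss_i(w) = 0.5 * (A[i] * w - B[i]) ** 2
A = [1.0, 2.0, -1.5, 0.5, 3.0]
B = [0.3, 1.1, -0.2, 0.9, 2.5]


def grad_i(i, w):
    """Gradient of the i-th loss term."""
    return A[i] * (A[i] * w - B[i])


def full_grad(w):
    """Full-batch gradient, averaged over all terms."""
    return fmean(grad_i(i, w) for i in range(len(A)))


def exact_variance(values):
    """Variance of an estimator whose index i is drawn uniformly at random."""
    m = fmean(values)
    return fmean((g - m) ** 2 for g in values)


def plain_estimates(w):
    """Plain stochastic gradient: one value per possible index i."""
    return [grad_i(i, w) for i in range(len(A))]


def svrg_estimates(w, w_snap):
    """Control variate: subtract grad_i at a snapshot, add back the full gradient there."""
    g_snap = full_grad(w_snap)
    return [grad_i(i, w) - grad_i(i, w_snap) + g_snap for i in range(len(A))]


w, w_snap = 0.8, 0.75
print(exact_variance(plain_estimates(w)), exact_variance(svrg_estimates(w, w_snap)))
```

The control-variate estimator keeps the same mean as the plain one, its variance shrinks as $w$ approaches the snapshot, and it vanishes when they coincide; sEM-vr applies the same mechanism to expected sufficient statistics instead of gradients.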
Such ideas have also been brought into gradient-based Bayesian learning algorithms, including stochastic variational inference [25] and stochastic gradient Markov chain Monte Carlo (SGMCMC) [13, 8, 7].

In this paper, we develop a variance reduced stochastic EM algorithm (sEM-vr). In each epoch, that is, a full pass through the data set, our algorithm computes the full-batch expectation as a control variate, and uses it to reduce the variance of the minibatch updates in that epoch. Let $E$ be the number of epochs and $M$ be the number of minibatch iterations per epoch. We show that near a local optimum, our algorithm, with a constant step size, enjoys a convergence rate of $O((M^{-1}\log M)^{E/2})$ to the optimum. Like bEM, our convergence rate is exponential with respect to the number of epochs, and it is asymptotically faster than sEM. We also show that, under stronger assumptions, our algorithm converges globally with a constant step size. Note that leveraging variance reduction ideas in sEM is not straightforward, since sEM is not a stochastic gradient descent algorithm but rather a stochastic approximation [21] algorithm. In particular, the proof techniques we use differ from those for stochastic gradient descent algorithms. We demonstrate our algorithm on Gaussian mixture models and probabilistic latent semantic analysis [18]. sEM-vr achieves significantly faster convergence than sEM, bEM, and other gradient-based and Bayesian algorithms.

2 Background

We review batch and stochastic EM algorithms in this section. Throughout the paper we focus on exponential family models with tractable E- and M-steps, for which stochastic EM [6] is designed.

2.1 EM Algorithm

The EM algorithm is designed for models with an observed variable $x$ and a hidden variable $h$. We assume an exponential family joint distribution $p(x, h; \theta) = b(x, h)\exp\{\eta(\theta)^\top\phi(x, h) - A(\theta)\}$ parameterized by $\theta$.
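As a concrete instance of this form, the sketch below uses the two-component model of Sec. 4.1 (unit variances, means $\mu_1 = \mu$ and $\mu_2 = -\mu$), with the mixing weights and the Gaussian normalizer folded into $b(x, h)$, so that $A(\theta) = 0$. It checks numerically that $\log\mathcal{N}(x; \mu_h, 1) - \eta(\mu)^\top\phi(x, h)$ does not depend on $\mu$, i.e., that all of the $\mu$-dependence is carried by $\eta(\mu)^\top\phi(x, h)$. The function names are illustrative.

```python
import math


def log_normal(x, mean):
    """log N(x; mean, 1)."""
    return -0.5 * (x - mean) ** 2 - 0.5 * math.log(2.0 * math.pi)


def eta(mu):
    """Natural parameter of the Sec. 4.1 model."""
    return (mu, -mu, -mu * mu / 2.0, -mu * mu / 2.0)


def phi(x, h):
    """Sufficient statistics (x*h1, x*h2, h1, h2) for cluster h in {1, 2}."""
    h1, h2 = (1.0, 0.0) if h == 1 else (0.0, 1.0)
    return (x * h1, x * h2, h1, h2)


def log_b(x, h, mu):
    """log N(x; mu_h, 1) - eta(mu)^T phi(x, h), with mu_1 = mu and mu_2 = -mu."""
    mean = mu if h == 1 else -mu
    return log_normal(x, mean) - sum(e * p for e, p in zip(eta(mu), phi(x, h)))


# The residual is the same for every mu (it equals -x^2/2 - log(2*pi)/2).
for x in (-1.3, 0.0, 2.4):
    print([round(log_b(x, h, mu), 12) for h in (1, 2) for mu in (-1.0, 0.5, 2.0)])
```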
Given a data set of $N$ ($\gg 1$) observations $X = \{x_i\}_{i=1}^N$, we want to obtain a maximum likelihood estimate (MLE) of the parameter $\theta$ by maximizing the log marginal likelihood $L(\theta) := \sum_{i=1}^N \log p(x_i; \theta) = \sum_{i=1}^N \log\int p(x_i, h_i; \theta)\,dh_i$, where the variables $(x_i, h_i)$ are i.i.d. given $\theta$. Denote $H = \{h_i\}_{i=1}^N$. Batch expectation-maximization (bEM) [12] optimizes the log marginal likelihood $L(\theta)$ by constructing a lower bound of it:

$$L(\theta) \ge Q(\theta; \hat\theta) - \mathbb{E}_{p(H|X;\hat\theta)}\left[\log p(H|X; \hat\theta)\right], \tag{1}$$

$$Q(\theta; \hat\theta) := \mathbb{E}_{p(H|X;\hat\theta)}\left[\log p(X, H; \theta)\right] = N\left(\eta(\theta)^\top F(\hat\theta) - A(\theta)\right) + \text{constant}, \tag{2}$$

where we define $F(\hat\theta) := \frac{1}{N}\sum_{i=1}^N f_i(\hat\theta)$ as the full-batch expected sufficient statistics, and $f_i(\hat\theta) := \mathbb{E}_{p(h_i|x_i;\hat\theta)}[\phi(x_i, h_i)]$ is the expected sufficient statistics conditioned on the observed datum $x_i$. Let $\hat\theta_e$ be the estimated parameter at iteration or epoch $e$, where each epoch is a complete pass through the data set. In the E-step, bEM tightens the bound in Eq. (1) by setting $\hat\theta = \hat\theta_e$, and computes the expected sufficient statistics $F(\hat\theta_e)$. In the M-step, bEM finds a maximizer $\hat\theta_{e+1}$ of the lower bound with respect to $\theta$ by solving the optimization problem $\arg\max_\theta\{\eta(\theta)^\top F(\hat\theta) - A(\theta)\}$. The solution is denoted as $R(F(\hat\theta))$ and is assumed to be tractable. In summary, the bEM updates can be written simply as

$$\text{E-step: compute } F(\hat\theta_e), \qquad \text{M-step: let } \hat\theta_{e+1} = R(F(\hat\theta_e)). \tag{3}$$

The algorithm is also applicable to maximum a posteriori (MAP) estimation of parameters, with a conjugate prior $p(\theta; \alpha) = \exp\{\eta(\theta)^\top\alpha - A(\theta)\}$ with hyperparameter $\alpha$. Instead of $L(\theta)$, MAP maximizes $L(\theta) + \log p(\theta; \alpha) \ge N\eta(\theta)^\top\left(\alpha/N + F(\hat\theta)\right) - NA(\theta) + \text{constant}$, and we still apply Eq. (3), but with $f_i(\hat\theta) := \alpha/N + \mathbb{E}_{p(h_i|x_i;\hat\theta)}[\phi(x_i, h_i)]$ instead.

2.2 Stochastic EM Algorithm

When the data set is large, that is, $N$ is large, computing $F(\hat\theta_t)$ in the E-step is too expensive because it needs a full pass through the entire data set. Stochastic EM (sEM) [6] avoids this by maintaining an exponential moving average $\hat s_t$ as an approximation of the full average $F(\hat\theta_t)$.
At iteration $t$, sEM picks a single random datum $i$ and updates

$$\text{E-step: } \hat s_{t+1} = (1-\rho_t)\hat s_t + \rho_t f_i(\hat\theta_t), \qquad \text{M-step: } \hat\theta_{t+1} = R(\hat s_{t+1}),$$

where $(\rho_t)$ is a sequence of step sizes that satisfy $\sum_t \rho_t = \infty$ and $\sum_t \rho_t^2 < \infty$. We deliberately choose different iteration indices $e$ and $t$ for bEM and sEM to emphasize their different time complexity per iteration. In practice, sEM can take a minibatch of data instead of a single datum per iteration, but we stick to a single datum for cleaner presentation. The two sEM updates can be rolled into a single update

$$\hat s_{t+1} = (1-\rho_t)\hat s_t + \rho_t f_i(\hat s_t), \tag{4}$$

where for simplicity we have overloaded the notation with $f_i(s) := f_i(R(s))$. This first maps $s$, which can be interpreted as the estimated mean parameter of the model, into the parameters $\theta = R(s)$, before computing the required expected sufficient statistics $f_i(\theta)$ under the posterior given observation $x_i$. Which of the two definitions is meant should be clear from the type of its argument, and we feel this reduces the notational burden on the reader. We similarly overload $F(s) := F(R(s))$ and $L(s) := L(R(s))$, so we can also write the bEM updates (Eq. 3) simply as $\hat s_{e+1} = F(\hat s_e)$.

Intuitively, we want to find a stationary point $s^*$ under bEM iterations, i.e., $s^* = F(s^*)$. We can view bEM as a fixed-point algorithm, and sEM as a Robbins-Monro [30] algorithm, for solving the equation $s^* = F(s^*)$. Because of its cheap updates, sEM can converge faster than bEM on large data sets in the beginning. However, due to the variance of the estimator $\hat s_t$, sEM has a slower asymptotic convergence rate than bEM for finite data sets. Specifically, let $s^* = F(s^*)$ be a stationary point; Cappé and Moulines [6] showed that $\mathbb{E}\|\hat s_T - s^*\|^2 = O(\rho_T)$ for sEM, which is at best $O(T^{-1})$ since $\sum_t \rho_t = \infty$. In contrast, Dempster et al. [12] showed that bEM converges as $\|\hat s_E - s^*\|^2 \le (1-\lambda)^{2E}\|\hat s_0 - s^*\|^2$, where $1 - \lambda \in [0, 1)$ is a constant defined in Sec. 3.3. As long as the data set is finite, the exponential rate of bEM is faster than that of sEM.²
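The fixed-point view of bEM and the Robbins-Monro view of sEM can be sketched on the toy mixture of Sec. 4.1. This is a minimal pure-Python illustration under our own assumptions: a smaller synthetic data set, the initial statistics $(1, 0, 1, 0)$, and hypothetical helper names such as `make_data`; only the sEM step size $\rho_t = 3/(t + 10)$ is taken from Sec. 4.1. Statistics are stored as the 4-vector $s = (s_1, s_2, s_3, s_4)$ of Sec. 4.1, so one bEM epoch is $s \leftarrow F(s)$ and one sEM iteration is Eq. (4).

```python
import math
import random

PI1, PI2 = 0.2, 0.8  # mixing weights of the Sec. 4.1 toy model


def make_data(n, mu, seed=0):
    """Draw n points from 0.2 N(mu, 1) + 0.8 N(-mu, 1)."""
    rng = random.Random(seed)
    return [rng.gauss(mu if rng.random() < PI1 else -mu, 1.0) for _ in range(n)]


def R(s):
    """M-step: mean parameters s -> mu = (s1 - s2) / (s3 + s4)."""
    return (s[0] - s[1]) / (s[2] + s[3])


def f_i(x, s):
    """Expected sufficient statistics of one datum x under mu = R(s)."""
    mu = R(s)
    g1 = 1.0 / (1.0 + (PI2 / PI1) * math.exp(-2.0 * x * mu))  # p(h = 1 | x, mu)
    g2 = 1.0 - g1
    return (x * g1, x * g2, g1, g2)


def F(data, s):
    """Full-batch average of f_i over the data set."""
    total = [0.0, 0.0, 0.0, 0.0]
    for x in data:
        for k, v in enumerate(f_i(x, s)):
            total[k] += v
    return tuple(t / len(data) for t in total)


def bem(data, s, epochs):
    """Batch EM as the fixed-point iteration s <- F(s) (Eq. 3)."""
    for _ in range(epochs):
        s = F(data, s)
    return s


def sem(data, s, epochs, seed=1):
    """sEM (Eq. 4): one random datum per iteration, rho_t = 3 / (t + 10)."""
    rng = random.Random(seed)
    for t in range(epochs * len(data)):
        x = data[rng.randrange(len(data))]
        rho = 3.0 / (t + 10)
        s = tuple((1.0 - rho) * a + rho * b for a, b in zip(s, f_i(x, s)))
    return s


data = make_data(2000, 0.5)
s_star = bem(data, (1.0, 0.0, 1.0, 0.0), 150)
print("bEM mu*:", R(s_star), " sEM after 3 epochs:", R(sem(data, (1.0, 0.0, 1.0, 0.0), 3)))
```

Each bEM epoch touches all $N$ points, while each sEM iteration touches one, which is why sEM makes fast early progress but keeps fluctuating around $s^*$ unless $\rho_t$ decays.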
Moreover, sEM needs a decreasing sequence of step sizes to converge, whose decay rate is difficult to tune.

² Without affecting the convergence rates, we slightly adjust the convergence theorems in [6, 12] to view them in a uniform way; see Appendix A for details.

3 Variance Reduced Stochastic Expectation Maximization

In this section, we describe a variance reduced stochastic EM algorithm (sEM-vr), and develop the theory for its convergence. sEM-vr enjoys an exponential convergence rate with a constant step size.

3.1 Algorithm Description

We run the algorithm for $E$ epochs and $M$ minibatch iterations per epoch, so that there are $T := ME$ iterations in total. For simplicity we choose $M = N$ and use minibatches of size 1, though our analysis is not limited to this case. Each epoch has the same time complexity as bEM. We index iteration $t$ in epoch $e$ as $(e, t)$. Let $\hat s_{e,t}$ be the estimated sufficient statistics at iteration $(e, t)$. Starting from the initial estimate $\hat s_{0,0}$, sEM-vr performs the following updates in epoch $e$:

Stochastic EM with Variance Reduction
1. Compute $F(\hat s_{e,0})$, and save $F(\hat s_{e,0})$ as well as $\hat s_{e,0}$.
2. For each iteration $t = 0, \ldots, M-1$, randomly sample a datum $i$, and update
$$\hat s_{e,t+1} = (1-\rho)\hat s_{e,t} + \rho\left[f_i(\hat s_{e,t}) - f_i(\hat s_{e,0}) + F(\hat s_{e,0})\right]. \tag{5}$$
3. Let $\hat s_{e+1,0} = \hat s_{e,M}$.

Let $\mathbb{E}_{e,t}$ and $\mathrm{Var}_{e,t}$ be the expectation and variance over the random index $i$ in iteration $(e, t)$. Comparing Eq. (5) with Eq. (4), we observe that the sEM and sEM-vr updates have the same expectation $\mathbb{E}_t[\hat s_{t+1}] = (1-\rho)\hat s_t + \rho F(\hat s_t)$. However, their variances are different: sEM has $\mathrm{Var}_t[\hat s_{t+1}] = \rho_t^2\,\mathrm{Var}_t[f_i(\hat s_t)]$, while sEM-vr has $\mathrm{Var}_{e,t}[\hat s_{e,t+1}] = \rho^2\,\mathrm{Var}_{e,t}[f_i(\hat s_{e,t}) - f_i(\hat s_{e,0})]$. If the algorithm converges, i.e., the sequence $(\hat s_{e,t})$ converges to a point $s^*$, and $f_i(\cdot)$ is continuous, the variance of sEM-vr will converge to zero, while that of sEM will remain positive.
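The three steps of sEM-vr and the variance claim above can be checked on the same toy mixture (Sec. 4.1). The sketch below is our own minimal illustration, not the authors' implementation: it assumes $M = N$ single-datum iterations per epoch, a hypothetical constant step size $\rho = 0.006$ for $N = 1000$ (Sec. 4.1 used $\rho = 0.003$ with $N = 10{,}000$), and helper names of our choosing. `update_variance` enumerates all indices $i$ to compute exactly the variance of the first coordinate of the sEM direction $f_i(s)$ and of the sEM-vr direction $f_i(s) - f_i(s_0) + F(s_0)$.

```python
import math
import random

PI1, PI2 = 0.2, 0.8  # mixing weights of the Sec. 4.1 toy model


def make_data(n, mu, seed=0):
    rng = random.Random(seed)
    return [rng.gauss(mu if rng.random() < PI1 else -mu, 1.0) for _ in range(n)]


def R(s):
    return (s[0] - s[1]) / (s[2] + s[3])


def f_i(x, s):
    """Expected sufficient statistics of datum x under mu = R(s)."""
    g1 = 1.0 / (1.0 + (PI2 / PI1) * math.exp(-2.0 * x * R(s)))
    return (x * g1, x * (1.0 - g1), g1, 1.0 - g1)


def F(data, s):
    stats = [f_i(x, s) for x in data]
    return tuple(sum(col) / len(data) for col in zip(*stats))


def sem_vr(data, s, epochs, rho, seed=1):
    """sEM-vr (Eq. 5) with M = N single-datum iterations per epoch."""
    rng = random.Random(seed)
    n = len(data)
    for _ in range(epochs):
        s0, F0 = s, F(data, s)  # step 1: snapshot and full-batch control variate
        for _ in range(n):      # step 2: M = N stochastic iterations
            x = data[rng.randrange(n)]
            direction = [a - b + c for a, b, c in zip(f_i(x, s), f_i(x, s0), F0)]
            s = tuple((1.0 - rho) * a + rho * d for a, d in zip(s, direction))
        # step 3: the final s becomes the next epoch's snapshot
    return s


def update_variance(data, s, s0, use_vr):
    """Exact variance over uniform i of the first coordinate of the update direction."""
    if use_vr:
        vals = [f_i(x, s)[0] - f_i(x, s0)[0] for x in data]  # F(s0) is a constant shift
    else:
        vals = [f_i(x, s)[0] for x in data]
    m = sum(vals) / len(vals)
    return sum((v - m) ** 2 for v in vals) / len(vals)


data = make_data(1000, 0.5)
s_vr = sem_vr(data, (1.0, 0.0, 1.0, 0.0), epochs=10, rho=0.006)
print("sEM-vr mu:", R(s_vr))
```

With a constant step size the iterates settle on the bEM fixed point, and at $s = s_0$ the sEM-vr direction has exactly zero variance while the plain sEM direction does not.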
Therefore, sEM-vr has asymptotically smaller variance than sEM, and we will see that this leads to better asymptotic convergence rates. The time complexity of sEM-vr per epoch is the same as that of bEM and sEM up to a constant factor of at most 3, for computing $f_i(\hat s_{e,t})$, $f_i(\hat s_{e,0})$ and $F(\hat s_{e,0})$. The space complexity also grows by a constant factor of at most 3, for storing $\hat s_{e,0}$ and $F(\hat s_{e,0})$ along with $\hat s_{e,t}$. In practice, the overhead is less than 3 times because the time and space costs of the other parts of the methods, e.g., data storage, are the same.

3.2 Related Works

A possible alternative to sEM is Titterington's online algorithm [33], which replaces the exact M-step with a gradient ascent step to optimize $Q(\theta; \hat\theta)$, where the gradient is multiplied by the inverse Fisher information of $p(x, h; \theta)$. Titterington's algorithm is locally equivalent to sEM [6]. However, as argued by Cappé and Moulines [6], Titterington's algorithm has several issues: the Fisher information is expensive to compute in high dimensions, explicit matrix inversion is needed, and the updated parameters are not guaranteed to be valid. Moreover, leveraging variance reduced stochastic gradient algorithms [20, 22, 11] for Titterington's algorithm is not straightforward, as the Fisher information matrix changes with $\theta$. Zhu et al. have proposed a variance reduced stochastic gradient EM algorithm [39]. There are also some theoretical analyses of the EM algorithm for high dimensional data [3, 35].

Instead of performing point estimation of parameters, Bayesian inference algorithms, including variational inference (VI) and Markov chain Monte Carlo (MCMC), can also be adopted to infer the posterior distribution of parameters. Variance reduction techniques have also been applied in these settings, including smoothed stochastic variational inference (SSVI) [25] and variance reduced stochastic gradient MCMC (VRSGMCMC) algorithms [13, 8, 7].
However, convergence guarantees for SSVI have not been developed, while VRSGMCMC algorithms are typically much slower than sEM-vr due to the intrinsic randomness of MCMC. For example, the time complexity to converge to $\epsilon$-precision in terms of the 2-Wasserstein distance between the true posterior and the MCMC distribution is $O(N + \kappa^{3/2}\sqrt{d}/\epsilon)$, where $\kappa$ is a condition number and $d$ is the dimensionality of the parameters [7].

3.3 Local Convergence Rate

We analyze the local convergence rate of a sequence $\{\hat s_{e,t}\}$ of sEM-vr iterates to a stationary point $s^*$ with $s^* = F(s^*)$. Let $\theta^* := R(s^*)$ be the natural parameter corresponding to the mean parameter $s^*$.

Theorem 1. If
(a) the Hessian $\nabla^2 L(\theta^*)$ is negative definite, i.e., $\theta^*$ is a strict local maximum of $L(\theta)$;
(b) $\forall i$, $f_i(s)$ is $L_f$-Lipschitz continuous, and $F(s)$ is $\beta_f$-smooth;
(c) $\forall e, t$, $\|\hat s_{e,t} - s^*\| < \lambda/\beta_f$, where $1 - \lambda$ is the maximum eigenvalue of $J := \partial F(s^*)/\partial s$;
then, for any step size $\rho \le \lambda/(32L_f^2)$, we have

$$\mathbb{E}\|\hat s_{E,0} - s^*\|^2 \le \left[\exp(-M\lambda\rho/4) + 32L_f^2\rho/\lambda\right]^E\,\mathbb{E}\|\hat s_{0,0} - s^*\|^2. \tag{6}$$

In particular, if $\rho = \rho^* := 4\log(M/\kappa^2)/(\lambda M)$, where $\kappa^2 := 128L_f^2/\lambda^2$, then we have

$$\mathbb{E}\|\hat s_{E,0} - s^*\|^2 \le \left[\left(1 + \log(M/\kappa^2)\right)\kappa^2/M\right]^E\,\mathbb{E}\|\hat s_{0,0} - s^*\|^2. \tag{7}$$

Remarks. Assumption (a) follows directly from the original EM paper (Theorem 4) [12]. [12] analyzed the convergence only in an infinitesimal neighbourhood of $s^*$, while Assumption (c) gives an explicit radius of convergence. Assumption (b) is new and is required to control the variance and the radius of convergence. Note also that we analyze the convergence of the mean parameters, while [12] analyzed that of the parameters $\theta$; however, the two are equivalent if $R(s)$ is Lipschitz continuous. In Appendix A.1 we show that a negative definite $\nabla^2 L(\theta^*)$ in Assumption (a) implies $\lambda > 0$ in Assumption (c).

Proof. We first analyze the convergence behavior in a specific epoch $e$, and omit the epoch index $e$ for concise notation. We further denote $\Delta_t := \hat s_t - s^*$ for any $t$. By Eq. (5),

$$\begin{aligned}
\mathbb{E}_t\|\Delta_{t+1}\|^2 &= \mathbb{E}_t\left\|(1-\rho)\hat s_t + \rho F(\hat s_t) - s^* + \rho\left[f_i(\hat s_t) - f_i(\hat s_0) - F(\hat s_t) + F(\hat s_0)\right]\right\|^2 \\
&= \|(1-\rho)\hat s_t + \rho F(\hat s_t) - s^*\|^2 + \rho^2\,\mathbb{E}_t\|f_i(\hat s_t) - f_i(\hat s_0) - F(\hat s_t) + F(\hat s_0)\|^2,
\end{aligned} \tag{8}$$

where the second equality holds because $\mathbb{E}_t[f_i(\hat s_{e,t}) - f_i(\hat s_{e,0}) + F(\hat s_{e,0})] = F(\hat s_{e,t})$, so the cross term vanishes. We have

$$\begin{aligned}
\|(1-\rho)\hat s_t + \rho F(\hat s_t) - s^*\|^2 &= \|(1-\rho)\Delta_t + \rho(F(\hat s_t) - s^*) + \rho J\Delta_t - \rho J\Delta_t\|^2 \\
&\le \left(\|(1-\rho)\Delta_t + \rho J\Delta_t\| + \rho\|F(\hat s_t) - s^* - J\Delta_t\|\right)^2 \\
&\le \left[(1-\rho\lambda)\|\Delta_t\| + (\rho/2)\beta_f\|\Delta_t\|^2\right]^2 \\
&= \left[1 - \rho(\lambda - \beta_f\|\Delta_t\|/2)\right]^2\|\Delta_t\|^2 \le (1-\rho\lambda/2)^2\|\Delta_t\|^2 \le (1-\rho\lambda/2)\|\Delta_t\|^2,
\end{aligned} \tag{9}$$

where the second line uses the triangle inequality; the third line uses $\|(1-\rho)I + \rho J\| \le 1 - \rho + \rho(1-\lambda) = 1 - \rho\lambda$, where $\|\cdot\|$ is the $\ell_2$ operator norm, together with the smoothness in (b), which implies $\|F(\hat s_t) - s^* - J(\hat s_t - s^*)\| \le (\beta_f/2)\|\hat s_t - s^*\|^2$ (recall that $s^* = F(s^*)$); and the last line uses (c). By (b), $F$ is $L_f$-Lipschitz and, $\forall i$, $f_i - F$ is $2L_f$-Lipschitz continuous. Therefore

$$\mathbb{E}_t\|f_i(\hat s_t) - f_i(\hat s_0) - F(\hat s_t) + F(\hat s_0)\|^2 \le 4L_f^2\|\hat s_t - \hat s_0\|^2 \le 8L_f^2\left(\|\Delta_t\|^2 + \|\Delta_0\|^2\right). \tag{10}$$

Combining Eqs. (8), (9) and (10), and using our assumption $\rho \le \lambda/(32L_f^2)$, we have

$$\mathbb{E}\|\Delta_{t+1}\|^2 \le \left(1 - \rho\lambda/2 + 8\rho^2L_f^2\right)\|\Delta_t\|^2 + 8\rho^2L_f^2\|\Delta_0\|^2 \le (1 - \rho\lambda/4)\|\Delta_t\|^2 + 8\rho^2L_f^2\|\Delta_0\|^2.$$

We obtain Eqs. (6) and (7) by analyzing the sequence $a_{t+1} \le (1 - \epsilon\rho)a_t + c\rho^2 a_0$, where $a_t = \mathbb{E}\|\Delta_t\|^2$, $\epsilon = \lambda/4$ and $c = 8L_f^2$. The analysis is in Appendix B.

Comparison with bEM: As mentioned in Sec. 2.2, bEM has $\mathbb{E}\|\hat s_E - s^*\|^2 \le (1-\lambda)^{2E}\,\mathbb{E}\|\hat s_0 - s^*\|^2$. The distance decreases exponentially for both bEM and sEM-vr, but at different speeds. If $M$ is large, sEM-vr (Eq. 7) converges much faster than bEM because $(1 + \log(M/\kappa^2))\kappa^2/M \ll (1-\lambda)^2$, thanks to its cheap stochastic updates.

Comparison with sEM: As mentioned in Sec. 2.2, sEM has $\mathbb{E}\|\hat s_T - s^*\|^2 = O(T^{-1})$, which is not exponential and is asymptotically slower than sEM-vr. The key difference is that we can bound the variance term of sEM-vr by $\|\hat s_t - \hat s_0\|^2$ in Eq. (10), so the variance goes to zero as $(\hat s_{e,t})$ converges. The advantage of sEM-vr over sEM is especially significant when $E$ is large. Moreover, sEM requires a decreasing sequence of step sizes to converge [6], which is more difficult to tune than the constant step size of sEM-vr.
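The constants in Theorem 1 are easy to evaluate numerically. The sketch below uses illustrative values $\lambda = 0.5$, $L_f = 1$ and $M = 10^5$ chosen by us. It computes the per-epoch factor of Eq. (6), the step size $\rho^*$ and factor of Eq. (7), and iterates the recurrence $a_{t+1} = (1 - \epsilon\rho)a_t + c\rho^2 a_0$ (with $\epsilon = \lambda/4$, $c = 8L_f^2$) for one epoch to confirm that it stays below the Eq. (6) factor. It also checks that $\rho^*$ satisfies $\rho \le \lambda/(32L_f^2)$; this holds whenever $M > \kappa^2$, because the condition reduces to $\log(x)/x \le 1$ with $x = M/\kappa^2$, and $\log x \le x - 1$.

```python
import math


def eq6_factor(rho, M, lam, L_f):
    """Per-epoch contraction factor of Eq. (6); valid for rho <= lam / (32 L_f^2)."""
    if rho > lam / (32.0 * L_f ** 2):
        raise ValueError("step size violates the condition of Theorem 1")
    return math.exp(-M * lam * rho / 4.0) + 32.0 * L_f ** 2 * rho / lam


def eq7_step_and_factor(M, lam, L_f):
    """rho* = 4 log(M / kappa^2) / (lam M) and the factor (1 + log(M / kappa^2)) kappa^2 / M."""
    kappa2 = 128.0 * L_f ** 2 / lam ** 2
    if M <= kappa2:
        raise ValueError("Eq. (7) needs M > kappa^2")
    rho_star = 4.0 * math.log(M / kappa2) / (lam * M)
    return rho_star, (1.0 + math.log(M / kappa2)) * kappa2 / M


def one_epoch_recurrence(rho, M, lam, L_f, a0=1.0):
    """Iterate a_{t+1} = (1 - eps*rho) a_t + c rho^2 a_0 for M steps (eps = lam/4, c = 8 L_f^2)."""
    eps, c = lam / 4.0, 8.0 * L_f ** 2
    a = a0
    for _ in range(M):
        a = (1.0 - eps * rho) * a + c * rho ** 2 * a0
    return a


lam, L_f, M = 0.5, 1.0, 100_000
rho_star, factor = eq7_step_and_factor(M, lam, L_f)
print(rho_star, factor, eq6_factor(rho_star, M, lam, L_f), (1.0 - lam) ** 2)
```

For these values, the sEM-vr factor per epoch is far below the bEM factor $(1-\lambda)^2$, matching the comparison above.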
3.4 Global Convergence

Theorem 1 only considers the case near a local maximum of the log marginal likelihood. We now show that, under stronger assumptions, there exists a constant step size such that sEM-vr converges globally to a stationary point $s^* = F(s^*)$, i.e., one with $\nabla L(s^*) = 0$ [12].

Theorem 2. Suppose
(a) the natural parameter function $\eta(\theta)$ is $L_\eta$-Lipschitz, and $f_i(s)$ is $L_f$-Lipschitz for all $i$;
(b) for any $x$ and $h$, $\log p(x, h; \theta)$ is $\gamma$-strongly concave w.r.t. $\theta$.
Then for any constant step size $\rho < \gamma/(M(M-1)L_\eta L_f)$, sEM-vr converges to a stationary point, starting from any valid sufficient statistics vector $\hat s_{0,0}$.

A sufficient condition for (b) is that the exponential family is canonical, i.e., $\eta(\theta) = \theta$, and that we seek the MAP estimate instead of the MLE, where the log prior $\log p(\theta)$ is $\gamma$-strongly concave. The proof is in Appendix C. The idea is to first show that sEM-vr is a generalized EM (GEM) algorithm [36], which improves $\mathbb{E}[Q(\theta; \hat\theta)]$ after each epoch, and then apply Wu's convergence theorem for GEM [36].

[Figure 1: Toy Gaussian mixture. Left: $\log_{10}\mathbb{E}\|\hat\mu_t - \mu^*\|^2$ for sEM-vr, sEM and bEM. Right: $\log_{10}\mathrm{Var}_t[\hat\mu_t]/\rho_t^2$ for sEM-vr and sEM. X-axis: number of epochs.]

Table 1: Statistics of datasets for pLSA. k = thousands, m = millions.

| Data set | D | V | $\lvert I\rvert$ |
|---|---|---|---|
| NIPS [1] | 1.5k | 12k | 1.93m |
| NYTimes [1] | 0.3m | 102k | 99m |
| Wiki [38] | 3.6m | 8k | 524m |
| PubMed [1] | 8.1m | 141k | 731m |

4 Applications and Experiments

We demonstrate the application of sEM-vr on a toy Gaussian mixture model and on probabilistic latent semantic analysis.

4.1 Toy Gaussian Mixture

We fit a mixture of two Gaussians, $p(x|\mu) = 0.2\,\mathcal{N}(\mu, 1) + 0.8\,\mathcal{N}(-\mu, 1)$, with a single unknown parameter $\mu$. Let $X = \{x_i\}_{i=1}^N$ be the data set, and $h_i \in \{1, 2\}$ be the cluster assignment of $x_i$. We write $h_{ik} := \mathbb{I}(h_i = k)$ as a shorthand, where $\mathbb{I}(\cdot)$ is the indicator function.
The joint likelihood is $p(X, H|\mu) \propto \exp\{\sum_i\sum_k h_{ik}\log\mathcal{N}(x_i; \mu_k, 1)\} \propto \exp\{\sum_i \eta(\mu)^\top\phi(x_i, h_i)\}$, where $\mu_1 = \mu$ and $\mu_2 = -\mu$, the natural parameter is $\eta(\mu) = (\mu, -\mu, -\mu^2/2, -\mu^2/2)$ and the sufficient statistics are $\phi(x_i, h_i) = (x_ih_{i1}, x_ih_{i2}, h_{i1}, h_{i2})$. Let $\gamma_{ik}(\mu) = p(h_i = k|x_i, \mu) \propto \pi_k\mathcal{N}(x_i; \mu_k, 1)$ for $k \in \{1, 2\}$ be the posterior probabilities, with mixing weights $(\pi_1, \pi_2) = (0.2, 0.8)$. The expected sufficient statistics are $f_i(\mu) = \mathbb{E}_{p(h_i|x_i, \mu)}\phi(x_i, h_i) = (x_i\gamma_{i1}(\mu), x_i\gamma_{i2}(\mu), \gamma_{i1}(\mu), \gamma_{i2}(\mu))$, and $F(\mu) = \frac{1}{N}\sum_i f_i(\mu)$. The mapping from sufficient statistics to parameters is $R(s) = (s_1 - s_2)/(s_3 + s_4)$. The bEM, sEM, and sEM-vr updates are then defined by Eq. (3), Eq. (4), and Eq. (5), respectively.

We construct a data set of $N = 10{,}000$ samples drawn from the model with $\mu = 0.5$, and run bEM until convergence (to double precision) to obtain the MLE $\mu^*$. We then measure the convergence of $\mathbb{E}\|\hat\mu_t - \mu^*\|^2$ as well as the variance term $\mathrm{Var}_t[\hat\mu_t]/\rho_t^2$ for bEM, sEM, and sEM-vr with respect to the number of epochs. $\mathrm{Var}_t[\hat\mu_t]$ is always quadratic in the step size $\rho_t$, so we divide it by $\rho_t^2$ to cancel the effect of the step size and study only the intrinsic variance. We tune the step sizes manually, and set $\rho_t = 3/(t + 10)$ for sEM and $\rho = 0.003$ for sEM-vr. The results are shown in Fig. 1. sEM converges faster than bEM in the first 8 epochs, and is then outperformed by bEM, because sEM is asymptotically slower, as mentioned in Sec. 2.2. The convergence curve of sEM-vr exhibits a staircase pattern. At the beginning of each epoch it converges very fast because $\|\hat s_{e,t} - \hat s_{e,0}\|$ is small, so the variance is small. The variance then grows and the convergence slows down. Then a new epoch starts and a new $F(\hat s_{e,0})$ is computed, so the convergence is fast again. On the other hand, the variance of sEM remains constant.

4.2 Probabilistic Latent Semantic Analysis

4.2.1 Model and Algorithm

Probabilistic Latent Semantic Analysis (pLSA) [18] represents text documents as mixtures of topics.
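pLSA takes a list $I$ of tokens, where each token $i$ is represented by a pair of document and word IDs $(d_i, v_i)$, which indicates the presence of word $v_i$ in document $d_i$. Denoting $[n] = \{1, \ldots, n\}$, we have $d_i \in [D]$ and $v_i \in [V]$. pLSA assigns a latent topic $z_i \in [K]$ to each token, and defines the joint likelihood as $p(I, Z|\theta, \phi) = \prod_{i \in I}\mathrm{Cat}(z_i; \theta_{d_i})\,\mathrm{Cat}(v_i; \phi_{z_i})$, with the parameters $\theta = \{\theta_d\}_{d=1}^D$ and $\phi = \{\phi_k\}_{k=1}^K$. We have priors $p(\theta_d) = \mathrm{Dir}(\theta_d; K, \alpha')$ and $p(\phi_k) = \mathrm{Dir}(\phi_k; V, \beta')$, where $\mathrm{Dir}(K, \alpha)$ is a $K$-dimensional symmetric Dirichlet distribution with concentration parameter $\alpha$, and find the MAP estimate $\arg\max_{\theta,\phi}\log\sum_Z p(W, Z|\theta, \phi) + \log p(\theta) + \log p(\phi)$. Only the updates are presented here; the derivation is in Appendix D.

Let $\gamma_{ik}(\theta, \phi) := p(z_i = k|v_i, \theta, \phi) \propto \theta_{d_i,k}\phi_{k,v_i}$ be the posterior topic assignment of token $i$. In the E-step, bEM computes $\gamma_{dk}(\theta, \phi) = \sum_{i \in I_d}\gamma_{ik}(\theta, \phi)$ and $\gamma_{kv}(\theta, \phi) = \sum_{i \in I_v}\gamma_{ik}(\theta, \phi)$, where $I_d = \{(d_i, v_i) \mid d_i = d\}$ and $I_v = \{(d_i, v_i) \mid v_i = v\}$. The M-step is $\theta_{dk} = (\gamma_{dk} + \alpha)/(\sum_k\gamma_{dk} + K\alpha)$ and $\phi_{kv} = (\gamma_{kv} + \beta)/(\sum_v\gamma_{kv} + V\beta)$, where $\alpha = \alpha' - 1$ and $\beta = \beta' - 1$. For simplicity, we distinguish $(\gamma_{ik}, \gamma_{dk}, \gamma_{kv})$ and $(I, I_d, I_v)$ only by their indices.

sEM approximates the full-batch expected sufficient statistics $\gamma_{dk}$ and $\gamma_{kv}$ with exponential moving averages $\hat s_{t,d,k}$ and $\hat s_{t,k,v}$ at iteration $t$, and updates

$$\hat s_{t+1,d,k} = (1-\rho_t)\hat s_{t,d,k} + \rho_t\frac{|I|}{|\hat I|}\sum_{i \in \hat I_d}\gamma_{ik}(\hat\theta_t, \hat\phi_t), \qquad \hat s_{t+1,k,v} = (1-\rho_t)\hat s_{t,k,v} + \rho_t\frac{|I|}{|\hat I|}\sum_{i \in \hat I_v}\gamma_{ik}(\hat\theta_t, \hat\phi_t),$$

where we sample a minibatch $\hat I \subset I$ of tokens per iteration, and $\hat I_d$, $\hat I_v$ are defined in the same way as $I_d$, $I_v$. $\hat\theta_t$ and $\hat\phi_t$ are computed in the M-step from $\hat s_{t,d,k}$ and $\hat s_{t,k,v}$. This sEM algorithm is known as SCVB0 [16]. sEM-vr updates

$$\hat s_{e,t+1,d,k} = (1-\rho)\hat s_{e,t,d,k} + \rho\frac{|I|}{|\hat I|}\sum_{i \in \hat I_d}\left(\gamma_{ik}(\hat\theta_{e,t}, \hat\phi_{e,t}) - \gamma_{ik}(\hat\theta_{e,0}, \hat\phi_{e,0})\right) + \rho\,\gamma_{dk}(\hat\theta_{e,0}, \hat\phi_{e,0}),$$

$$\hat s_{e,t+1,k,v} = (1-\rho)\hat s_{e,t,k,v} + \rho\frac{|I|}{|\hat I|}\sum_{i \in \hat I_v}\left(\gamma_{ik}(\hat\theta_{e,t}, \hat\phi_{e,t}) - \gamma_{ik}(\hat\theta_{e,0}, \hat\phi_{e,0})\right) + \rho\,\gamma_{kv}(\hat\theta_{e,0}, \hat\phi_{e,0}),$$

where $\gamma_{dk}(\hat\theta_{e,0}, \hat\phi_{e,0})$ and $\gamma_{kv}(\hat\theta_{e,0}, \hat\phi_{e,0})$ are computed by bEM once per epoch.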
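The pLSA E-step, MAP M-step and sEM-vr updates above can be sketched in pure Python on a tiny random corpus. This is a dense, unoptimized illustration under our own assumptions ($\alpha, \beta \ge 0$ so the M-step stays positive, minibatches sampled without replacement, hypothetical function names), not the paper's C++ implementation. `sem_vr_epoch` computes the per-epoch control variate with one full bEM E-step and then applies both sEM-vr updates to every $(d, k)$ and $(k, v)$ entry; entries not touched by the minibatch simply move toward $\gamma_{dk}(\hat\theta_{e,0}, \hat\phi_{e,0})$ and $\gamma_{kv}(\hat\theta_{e,0}, \hat\phi_{e,0})$.

```python
import math
import random


def m_step(s_dk, s_kv, alpha, beta):
    """MAP M-step: theta_dk ∝ s_dk + alpha, phi_kv ∝ s_kv + beta (alpha, beta >= 0 assumed)."""
    theta = [[(x + alpha) / (sum(row) + len(row) * alpha) for x in row] for row in s_dk]
    phi = [[(x + beta) / (sum(row) + len(row) * beta) for x in row] for row in s_kv]
    return theta, phi


def posterior(d, v, theta, phi):
    """gamma_ik ∝ theta[d][k] * phi[k][v] for one token (d, v)."""
    w = [theta[d][k] * phi[k][v] for k in range(len(phi))]
    z = sum(w)
    return [x / z for x in w]


def full_stats(tokens, theta, phi, D, V, K):
    """bEM E-step: gamma_dk sums over I_d, gamma_kv sums over I_v."""
    g_dk = [[0.0] * K for _ in range(D)]
    g_kv = [[0.0] * V for _ in range(K)]
    for d, v in tokens:
        for k, g in enumerate(posterior(d, v, theta, phi)):
            g_dk[d][k] += g
            g_kv[k][v] += g
    return g_dk, g_kv


def log_posterior(tokens, theta, phi, alpha, beta):
    """Training objective up to a constant, with alpha = alpha' - 1 and beta = beta' - 1."""
    obj = sum(math.log(sum(theta[d][k] * phi[k][v] for k in range(len(phi)))) for d, v in tokens)
    obj += alpha * sum(math.log(x) for row in theta for x in row)
    return obj + beta * sum(math.log(x) for row in phi for x in row)


def sem_vr_epoch(tokens, s_dk, s_kv, rho, M, batch, alpha, beta, rng):
    """One sEM-vr epoch for pLSA, using minibatches of `batch` tokens."""
    D, K, V = len(s_dk), len(s_kv), len(s_kv[0])
    theta0, phi0 = m_step(s_dk, s_kv, alpha, beta)
    g0_dk, g0_kv = full_stats(tokens, theta0, phi0, D, V, K)  # bEM once per epoch
    scale = len(tokens) / batch  # |I| / |I_hat|
    for _ in range(M):
        theta, phi = m_step(s_dk, s_kv, alpha, beta)
        c_dk = [[0.0] * K for _ in range(D)]
        c_kv = [[0.0] * V for _ in range(K)]
        for d, v in rng.sample(tokens, batch):
            g_t, g_0 = posterior(d, v, theta, phi), posterior(d, v, theta0, phi0)
            for k in range(K):
                c_dk[d][k] += scale * (g_t[k] - g_0[k])
                c_kv[k][v] += scale * (g_t[k] - g_0[k])
        s_dk = [[(1 - rho) * s_dk[d][k] + rho * (c_dk[d][k] + g0_dk[d][k]) for k in range(K)]
                for d in range(D)]
        s_kv = [[(1 - rho) * s_kv[k][v] + rho * (c_kv[k][v] + g0_kv[k][v]) for v in range(V)]
                for k in range(K)]
    return s_dk, s_kv


rng = random.Random(0)
D, V, K, alpha, beta = 4, 6, 2, 0.1, 0.01
tokens = [(rng.randrange(D), rng.randrange(V)) for _ in range(60)]
theta, phi = m_step([[rng.random() for _ in range(K)] for _ in range(D)],
                    [[rng.random() for _ in range(V)] for _ in range(K)], alpha, beta)
s_dk, s_kv = full_stats(tokens, theta, phi, D, V, K)  # valid initial statistics
s1_dk, s1_kv = sem_vr_epoch(tokens, s_dk, s_kv, 0.05, 20, 10, alpha, beta, rng)
print(sum(map(sum, s1_dk)), sum(map(sum, s1_kv)))  # both equal len(tokens) up to rounding
```

Because the correction term sums to zero over topics, the update preserves the total mass $\sum_{d,k}\hat s_{d,k} = |I|$, and the first iteration of each epoch, where $\hat\theta_{e,t} = \hat\theta_{e,0}$, reduces to the damped bEM step $(1-\rho)\hat s + \rho\,\gamma(\hat\theta_{e,0}, \hat\phi_{e,0})$.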
Pseudocode for sEM and sEM-vr is given in Appendix D. If $\theta$ is integrated out instead of maximized, we recover an MAP estimate [14] of latent Dirichlet allocation (LDA) [4]. Many existing algorithms for LDA actually optimize the pLSA objective as an approximation of the LDA objective, including CVB0 [2, 31, 19], SCVB0 [16], BP-LDA [10], ESCA [37], and WarpLDA [9]. This approximation works well in practice when the number of topics is small [2]. More discussion is in Appendix D.1.

4.2.2 Experimental Settings

We compare sEM-vr with bEM and sEM (SCVB0), which is the state-of-the-art algorithm for pLSA, on the four datasets listed in Table 1. We also compare with two gradient-based algorithms, stochastic mirror descent (SMD) [10] and reparameterized stochastic gradient descent (RSGD), as well as their variants with SVRG-style [20] variance reduction, denoted SMD-vr and RSGD-vr, although their convergence properties are unknown. Both SMD and RSGD replace the M-step with a stochastic gradient step. SMD updates $\theta_{dk} \leftarrow \theta_{dk}\exp\{\rho\nabla_{\theta_{dk}}Q\}$ and $\phi_{kv} \leftarrow \phi_{kv}\exp\{\rho\nabla_{\phi_{kv}}Q\}$, where $Q$ is defined in Eq. (2). RSGD adopts the reparameterization $\theta_{dk} = \exp\lambda_{dk}/\sum_{k'}\exp\lambda_{dk'}$ and $\phi_{kv} = \exp\tau_{kv}/\sum_{v'}\exp\tau_{kv'}$, and directly optimizes $Q$ w.r.t. $\lambda$ and $\tau$ by stochastic gradient descent. Derivations of SMD and RSGD are in Appendix D.6. All the algorithms are implemented in C++, and are highly optimized and parallelized. The testing machine has two 12-core Xeon E5-2692v2 CPUs and 64GB of main memory.

We assess the convergence of the algorithms by the training objective $\log p(W|\theta, \phi) + \log p(\theta|\alpha') + \log p(\phi|\beta')$, i.e., the logarithm of the unnormalized posterior distribution $p(\theta, \phi|W, \alpha', \beta')$. For each dataset and number of topics $K \in \{50, 100\}$, we first select the hyperparameters by a grid search over $K\alpha' \in \{0.1, 1, 10, 100\}$ and $\beta' \in \{0.01, 0.1, 1\}$.³ Then, we do another grid search to choose the step size.
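The SMD and RSGD updates described above can be contrasted with the closed-form M-step on a single hypothetical document-level objective $Q(\theta) = \sum_k c_k\log\theta_k$ over the probability simplex, whose maximizer $\theta_k = c_k/\sum_{k'}c_{k'}$ is exactly what the M-step returns (with $c_k = \gamma_{dk} + \alpha$). The counts, step size and iteration counts are our assumptions, and we assume the SMD step is followed by renormalization onto the simplex. The RSGD gradient uses the chain rule through the softmax, $\partial Q/\partial\lambda_j = \theta_j(g_j - \sum_k\theta_k g_k)$ with $g = \nabla_\theta Q$. Both use exact gradients here, so the sketch only illustrates the update rules, not their stochastic behavior.

```python
import math


def normalize(w):
    z = sum(w)
    return [x / z for x in w]


def grad_q(theta, c):
    """Gradient of Q(theta) = sum_k c_k log theta_k."""
    return [ck / tk for ck, tk in zip(c, theta)]


def smd(c, rho, steps):
    """Mirror (exponentiated-gradient) ascent: theta_k <- theta_k exp(rho g_k), then renormalize."""
    theta = [1.0 / len(c)] * len(c)
    for _ in range(steps):
        g = grad_q(theta, c)
        theta = normalize([t * math.exp(rho * gk) for t, gk in zip(theta, g)])
    return theta


def softmax(lam):
    m = max(lam)
    return normalize([math.exp(x - m) for x in lam])


def rsgd(c, rho, steps):
    """Gradient ascent on unconstrained lambda, with theta = softmax(lambda)."""
    lam = [0.0] * len(c)
    for _ in range(steps):
        theta = softmax(lam)
        g = grad_q(theta, c)
        mean_g = sum(t * gk for t, gk in zip(theta, g))
        lam = [lj + rho * t * (gk - mean_g) for lj, t, gk in zip(lam, theta, g)]
    return softmax(lam)


c = [3.0, 1.0, 6.0]  # hypothetical counts gamma_dk + alpha for one document
print(normalize(c), smd(c, 0.05, 200), rsgd(c, 0.05, 2000))
```

Both iterations approach the M-step solution, but only after many gradient steps, whereas the EM-based methods obtain it in closed form.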
For sEM-vr, we choose $\rho \in \{0.01, 0.02, 0.05, 0.1, 0.2\}$, and for all other stochastic algorithms, we set $\rho_t = a/(t + t_0)^\kappa$, and choose $a \in \{10^{-7}, \ldots, 10^0\}$, $t_0 \in \{10, 100, 1000\}$ and $\kappa \in \{0.5, 0.75, 1\}$.⁴ Finally, we repeat 5 runs with different random seeds for each algorithm with its best step size. $E$ is 20 for NIPS and NYTimes, and 5 for Wiki and PubMed. $M$ is 50 for NIPS and 500 for all the other datasets.

4.2.3 Results for pLSA

We plot the training objective against running time in the first and second rows of Fig. 2. We find that the gradient-based algorithms and bEM are not competitive with sEM and sEM-vr, so we only report their results on NIPS, to make the distinction between sEM and sEM-vr clearer. Full results and more explanations of the slow convergence of the gradient-based algorithms are available in Appendix D.6. Due to the reduced variance, sEM-vr consistently converges faster, and to a better training objective, than sEM and bEM on all the datasets, while the constant step size of sEM-vr is easier to tune than the decreasing sequence of step sizes of sEM.

³ We find that all the algorithms have the same best hyperparameter configuration.
⁴ We have tried constant step sizes for SMD-vr and RSGD-vr but found them worse than decreasing step sizes.

[Figure 2: pLSA and LDA convergence results. Columns: NIPS, NYTimes, Wiki, PubMed. X-axis is running time in seconds. First and second rows: pLSA with K = 50 and K = 100; y-axis is the training objective; methods include sEM-vr, sEM, bEM, RSGD-vr, RSGD, SMD-vr and SMD. Third row: LDA with K = 10; y-axis is the testing perplexity; methods are sEM-vr, sEM, GOEM, SSVI and SVI.]
4.3 Results for LDA

As discussed in Sec. 4.2.1, algorithms for pLSA also work well as approximate training algorithms for LDA if the number of topics is small. Therefore, we also evaluate our sEM-vr algorithm for LDA, with a small number of $K = 10$ topics. The training algorithm is exactly the same, but the evaluation metric is different: we hold out a small testing set and report the testing perplexity, computed by the left-to-right algorithm [34] on the testing set. We compare with a state-of-the-art algorithm, Gibbs online expectation maximization (GOEM) [14], which outperforms a wide range of algorithms including SVI [17], hybrid variational-Gibbs [27], and SGRLD [28]. We also compare with stochastic variational inference (SVI) [17] and its variance reduced variant SSVI [25].

The third row of Fig. 2 shows the results. We observe that sEM-vr converges the fastest on all the datasets except NIPS, where sEM converges faster due to its cheaper iterations. sEM-vr always reaches better results than sEM in the end. GOEM converges more slowly due to its high Monte-Carlo variance. SVI and SSVI converge more slowly due to their inexact mean field assumption and expensive iterations, including an inner loop for inferring the local latent variables and frequent evaluation of the expensive digamma function. For a larger number of topics, such as 100, we find that GOEM performs best, since it does not approximate LDA as pLSA and does not make mean field assumptions as SVI and SSVI do. Extending our algorithm to variational EM and Monte-Carlo EM, when the E-step is not tractable, is an interesting future direction.

5 Conclusions and Discussions

We propose a variance reduced stochastic EM (sEM-vr) algorithm. sEM-vr achieves a $\left[(1 + \log(M/\kappa^2))\kappa^2/M\right]^E$ local convergence rate, which is faster than both the $(1-\lambda)^{2E}$ rate of batch EM and the $O(T^{-1})$ rate of plain stochastic EM (sEM).
Unlike sEM, which requires a decreasing sequence of step sizes to converge, sEM-vr only requires a constant step size to achieve this local convergence rate, as well as global convergence under stronger assumptions. We compare sEM-vr against bEM, sEM and other gradient-based and Bayesian algorithms on GMM and pLSA tasks, and find that sEM-vr converges significantly faster than these alternatives. An interesting future direction is leveraging recent progress on variance reduced stochastic gradient descent for non-convex optimization [23] to relax our strong log-concavity assumption, and extending sEM-vr to stochastic control variates, which work better on very large data sets. Extending our work to variational EM and Monte-Carlo EM is also interesting.

Acknowledgments

We thank Chris Maddison, Adam Foster, and Jin Xu for proofreading. J.C. and J.Z. were supported by the National Key Research and Development Program of China (No. 2017YFA0700904), NSFC projects (Nos. 61620106010, 61621136008, 61332007), the MIIT Grant of Int. Man. Comp. Stan (No. 2016ZXFB00001), Tsinghua Tiangong Institute for Intelligent Computing, the NVIDIA NVAIL Program and a Project from Siemens. YWT was supported by funding from the European Research Council under the European Union's Seventh Framework Programme (FP7/2007-2013) ERC grant agreement no. 617071, and from Tencent AI Lab through the Oxford-Tencent Collaboration on Large Scale Machine Learning.

References

[1] Arthur Asuncion and David Newman. UCI machine learning repository, 2007.

[2] Arthur Asuncion, Max Welling, Padhraic Smyth, and Yee Whye Teh. On smoothing and inference for topic models. In Proceedings of the Twenty-Fifth Conference on Uncertainty in Artificial Intelligence, pages 27-34. AUAI Press, 2009.

[3] Sivaraman Balakrishnan, Martin J Wainwright, Bin Yu, et al. Statistical guarantees for the EM algorithm: From population to sample-based analysis. The Annals of Statistics, 45(1):77-120, 2017.
[4] David M Blei, Andrew Y Ng, and Michael I Jordan. Latent Dirichlet allocation. Journal of Machine Learning Research, 3(Jan):993-1022, 2003.

[5] Olivier Cappé. Online EM algorithm for hidden Markov models. Journal of Computational and Graphical Statistics, 20(3):728-749, 2011.

[6] Olivier Cappé and Eric Moulines. On-line expectation maximization algorithm for latent data models. Journal of the Royal Statistical Society: Series B (Statistical Methodology), 71(3):593-613, 2009.

[7] Niladri S Chatterji, Nicolas Flammarion, Yi-An Ma, Peter L Bartlett, and Michael I Jordan. On the theory of variance reduction for stochastic gradient Monte Carlo. arXiv preprint arXiv:1802.05431, 2018.

[8] Changyou Chen, Wenlin Wang, Yizhe Zhang, Qinliang Su, and Lawrence Carin. A convergence analysis for a class of practical variance-reduction stochastic gradient MCMC. arXiv preprint arXiv:1709.01180, 2017.

[9] Jianfei Chen, Kaiwei Li, Jun Zhu, and Wenguang Chen. WarpLDA: a cache efficient O(1) algorithm for latent Dirichlet allocation. Proceedings of the VLDB Endowment, 9(10):744-755, 2016.

[10] Jianshu Chen, Ji He, Yelong Shen, Lin Xiao, Xiaodong He, Jianfeng Gao, Xinying Song, and Li Deng. End-to-end learning of LDA by mirror-descent back propagation over a deep architecture. In Advances in Neural Information Processing Systems, pages 1765-1773, 2015.

[11] Aaron Defazio, Francis Bach, and Simon Lacoste-Julien. SAGA: A fast incremental gradient method with support for non-strongly convex composite objectives. In Advances in Neural Information Processing Systems, pages 1646-1654, 2014.

[12] Arthur P Dempster, Nan M Laird, and Donald B Rubin. Maximum likelihood from incomplete data via the EM algorithm. Journal of the Royal Statistical Society. Series B (Methodological), pages 1-38, 1977.

[13] Kumar Avinava Dubey, Sashank J Reddi, Sinead A Williamson, Barnabas Poczos, Alexander J Smola, and Eric P Xing. Variance reduction in stochastic gradient Langevin dynamics. In Advances in Neural Information Processing Systems, pages 1154-1162, 2016.

[14] Christophe Dupuy and Francis Bach. Online but accurate inference for latent variable models with local Gibbs sampling. The Journal of Machine Learning Research, 18(1):4581-4625, 2017.

[15] Richard Durbin, Sean R Eddy, Anders Krogh, and Graeme Mitchison. Biological sequence analysis: probabilistic models of proteins and nucleic acids. Cambridge University Press, 1998.

[16] James Foulds, Levi Boyles, Christopher DuBois, Padhraic Smyth, and Max Welling. Stochastic collapsed variational Bayesian inference for latent Dirichlet allocation. In Proceedings of the 19th ACM SIGKDD International Conference on Knowledge Discovery and Data Mining, pages 446-454. ACM, 2013.

[17] Matthew D Hoffman, David M Blei, Chong Wang, and John Paisley. Stochastic variational inference. The Journal of Machine Learning Research, 14(1):1303-1347, 2013.

[18] Thomas Hofmann. Probabilistic latent semantic analysis. In Proceedings of the Fifteenth Conference on Uncertainty in Artificial Intelligence, pages 289-296. Morgan Kaufmann Publishers Inc., 1999.

[19] Katsuhiko Ishiguro, Issei Sato, and Naonori Ueda. Averaged collapsed variational Bayes inference. Journal of Machine Learning Research, 18(1):1-29, 2017.

[20] Rie Johnson and Tong Zhang. Accelerating stochastic gradient descent using predictive variance reduction. In Advances in Neural Information Processing Systems, pages 315-323, 2013.

[21] Harold Kushner and G George Yin. Stochastic approximation and recursive algorithms and applications, volume 35. Springer Science & Business Media, 2003.

[22] Nicolas Le Roux, Mark Schmidt, and Francis Bach. A stochastic gradient method with an exponential convergence rate for finite training sets. In Advances in Neural Information Processing Systems, pages 2663-2671, 2012.

[23] Lihua Lei and Michael Jordan. Less than a single pass: Stochastically controlled stochastic gradient. In Artificial Intelligence and Statistics, pages 148-156, 2017.

[24] Percy Liang and Dan Klein. Online EM for unsupervised models. In Proceedings of Human Language Technologies: The 2009 Annual Conference of the North American Chapter of the Association for Computational Linguistics, pages 611-619. Association for Computational Linguistics, 2009.

[25] Stephan Mandt and David Blei. Smoothed gradients for stochastic variational inference. In Advances in Neural Information Processing Systems, pages 2438-2446, 2014.

[26] Geoffrey McLachlan and David Peel. Finite mixture models. John Wiley & Sons, 2004.

[27] David Mimno, Matt Hoffman, and David Blei. Sparse stochastic inference for latent Dirichlet allocation. arXiv preprint arXiv:1206.6425, 2012.

[28] Sam Patterson and Yee Whye Teh. Stochastic gradient Riemannian Langevin dynamics on the probability simplex. In Advances in Neural Information Processing Systems, pages 3102-3110, 2013.

[29] Lawrence R Rabiner. A tutorial on hidden Markov models and selected applications in speech recognition. Proceedings of the IEEE, 77(2):257-286, 1989.

[30] Herbert Robbins and Sutton Monro. A stochastic approximation method. The Annals of Mathematical Statistics, pages 400-407, 1951.

[31] Issei Sato and Hiroshi Nakagawa. Rethinking collapsed variational Bayes inference for LDA. In ICML, 2012.

[32] Charles Spearman and L. W. Jones. Human Ability. Macmillan, 1950.

[33] D Michael Titterington. Recursive parameter estimation using incomplete data. Journal of the Royal Statistical Society. Series B (Methodological), pages 257-267, 1984.

[34] Hanna M Wallach, Iain Murray, Ruslan Salakhutdinov, and David Mimno. Evaluation methods for topic models. In Proceedings of the 26th Annual International Conference on Machine Learning, pages 1105-1112. ACM, 2009.

[35] Zhaoran Wang, Quanquan Gu, Yang Ning, and Han Liu. High dimensional expectation-maximization algorithm: Statistical optimization and asymptotic normality.
ar Xiv preprint ar Xiv:1412.8729, 2014. [36] CF Jeff Wu. On the convergence properties of the em algorithm. The Annals of statistics, pages 95 103, 1983. [37] Manzil Zaheer, Michael Wick, Jean-Baptiste Tristan, Alex Smola, and Guy Steele. Exponential stochastic cellular automata for massively parallel inference. In Artificial Intelligence and Statistics, pages 966 975, 2016. [38] Aonan Zhang, Jun Zhu, and Bo Zhang. Sparse online topic models. In Proceedings of the 22nd international conference on World Wide Web, pages 1489 1500. ACM, 2013. [39] Rongda Zhu, Lingxiao Wang, Chengxiang Zhai, and Quanquan Gu. High-dimensional variancereduced stochastic gradient expectation-maximization algorithm. In International Conference on Machine Learning, pages 4180 4188, 2017.