# Amortised Learning by Wake-Sleep

Li K. Wenliang¹, Theodore Moskovitz¹, Heishiro Kanagawa¹, Maneesh Sahani¹

**Abstract.** Models that employ latent variables to capture structure in observed data lie at the heart of many current unsupervised learning algorithms, but exact maximum-likelihood learning for powerful and flexible latent-variable models is almost always intractable. Thus, state-of-the-art approaches either abandon the maximum-likelihood framework entirely, or else rely on a variety of variational approximations to the posterior distribution over the latents. Here, we propose an alternative approach that we call amortised learning. Rather than computing an approximation to the posterior over latents, we use a wake-sleep Monte-Carlo strategy to learn a function that directly estimates the maximum-likelihood parameter updates. Amortised learning is possible whenever samples of latents and observations can be simulated from the generative model, treating the model as a black box. We demonstrate its effectiveness on a wide range of complex models, including those with latents that are discrete or supported on non-Euclidean spaces.

¹Gatsby Computational Neuroscience Unit. Correspondence to: Li K. Wenliang. Proceedings of the 37th International Conference on Machine Learning, Online, PMLR 119, 2020. Copyright 2020 by the author(s).

## 1. Introduction

Many problems in machine learning, particularly unsupervised learning, can be approached by fitting flexible parametric probabilistic models to data, often based on local latent variables whose number scales with the number of observations. Once the optimal parameters are found, the resulting model may be used to synthesise samples, detect outliers, or relate observations to a latent representation. The quality of all of these operations depends on the appropriateness of the model class chosen and the optimality of the identified parameters. Although many fitting objectives have been explored in the literature, maximum-likelihood (ML) estimation remains prominent and comes with attractive theoretical properties, including consistency and asymptotic efficiency (Newey & McFadden, 1994). A challenge, however, is that analytic evaluation of the likelihoods of rich, flexible latent-variable models is usually intractable. The Expectation-Maximisation (EM) algorithm (Dempster et al., 1977) offers one route to ML estimation in such circumstances, but it in turn requires an explicit calculation of (expected values under) the posterior distribution over latent variables, which also proves to be intractable in most cases of interest. Consequently, state-of-the-art ML-related methods almost always rely on approximations, particularly in large-data settings.

Denote the joint distribution of a generative model as $p_\theta(z, x)$, where $z$ is latent, $x$ is observed, and $\theta$ is the vector of parameters. EM breaks the ML problem into an iteration of two sub-problems. Given parameters $\theta_t$ on the $t$-th iteration, first find the posterior $p_{\theta_t}(z|x)$; then maximise a lower bound to the likelihood that depends on this posterior to obtain $\theta_{t+1}$. This bound is tight when computed using the correct posterior, ensuring convergence to a local mode of the likelihood. The intractability of $p_\theta(z|x)$ forces some combination of Monte-Carlo estimation and the use of a tractable parametric approximating family which we call $q(z|x)$ (Bishop, 2006).
To avoid repeating an expensive optimisation to find $q(z|x)$ for each $x$, amortised inference trains an encoding or recognition model, with parameters $\phi$, to map from any $x$ directly to an approximate posterior $q_\phi(z|x)$. Examples of amortised inference models include the Helmholtz machine (Dayan et al., 1995; Hinton et al., 1995), trained by the wake-sleep algorithm, and the variational auto-encoder (VAE) (Kingma & Welling, 2014; Rezende et al., 2014), trained using reparameterisation-gradient methods. With considerable effort devoted to improving variational inference (reviewed in Zhang et al., 2018), complex and flexible generative models have been trained on large, high-dimensional datasets.

However, approximate variational inference poses at least three challenges. First, the parametric form of the approximate posterior $q(z|x)$, and particularly any factorisations assumed, must be crafted for each model. Second, methods such as reparameterisation require specific transformations tailored to the type of latent variables: whether they are continuous or discrete, and whether or not the support is Euclidean. Third, given a flexible generative model, such as one with conditional dependence modelled using neural networks, the true posteriors may be irregular in ways that are difficult to approximate. We illustrate this latter effect using a standard VAE with two-dimensional $z$ trained on binarised MNIST digits (Figure 1). The exact posterior may be distorted or multi-modal, even though only Gaussian posteriors are ever produced by the encoder.

Figure 1. VAE trained on binarised MNIST digits. Top: mean images generated by decoding points on a grid of 2-D latent variables. The bottom three rows show five samples of real MNIST digits (top), the corresponding true posteriors (middle) found by histogram, and the approximate posteriors computed by the encoder.

When inference is only approximate, the M-step of EM may not increase the likelihood, and so approximate methods usually converge away from the ML parameter values. The dependence of the learnt parameters on the quality of the posterior approximation is not straightforward, and the error may not be reduced by (say) approximations with lower Kullback-Leibler (KL) divergence (Turner & Sahani, 2011); indeed, errors in posterior statistics that enter the objective function may be unbounded (Huggins et al., 2019).

Here, we propose a novel approach to ML learning in flexible latent variable models that avoids the complications of posterior estimation, instead learning to predict the gradient of the likelihood directly, an approach we call amortised learning. The particular realisation we develop here, amortised learning by wake-sleep (ALWS), requires only that sampling from the generative model $p_\theta(z, x)$ be possible, and that the gradient $\nabla_\theta \log p_\theta(z, x)$ be available (possibly by automated methods), but otherwise makes no assumptions about the form or distribution of the latent variables. We test the performance of ALWS on a wide range of tasks and models, including hierarchical models with heterogeneous priors, nonlinear dynamical systems, and deep models of images. All experiments use the same form of gradient model trained by simple least-squares regression. For image generation, we find that models trained with ALWS can produce samples of considerably better quality than those trained using algorithms based on variational inference.
## 2. Background

### 2.1. Model Definition

Consider a probabilistic generative model with parameter vector $\theta$ that defines a prior on latents $p_\theta(z)$ and a conditional on observations $p_\theta(x|z)$. In ML learning, we seek parameters that maximise the log (marginal) likelihood

$$\log p_\theta(x) = \log \int p_\theta(z)\, p_\theta(x|z)\, dz \tag{1}$$

averaged over a set of i.i.d. data $\mathcal{D} = \{x^*_m\}_{m=1}^M$. One approach is to iteratively update $\theta$ by following the gradient

$$\nabla_\theta(x) := \nabla_\theta \log p_\theta(x) \tag{2}$$

at each iteration.¹

### 2.2. Variational Inference for Learning

For many models of interest, the integral in (1) cannot be evaluated analytically, and so direct computation of the gradient is intractable. A popular alternative is to maximise a variational lower bound on the marginal likelihood defined by a distribution $q(z)$:

$$\mathcal{F}(q, \theta) := \mathbb{E}_{q(z)}[\log p_\theta(z, x)] + H[q] \le \log p_\theta(x), \tag{3}$$

where $H[q]$ is the entropy of $q$. Thus, the parameter $\theta$ can be updated by following the gradient of $\mathcal{F}(q, \theta)$ w.r.t. $\theta$:

$$\nabla_\theta \mathcal{F}(q, \theta) = \nabla_\theta \mathbb{E}_{q(z)}[\log p_\theta(z, x)] = \mathbb{E}_{q(z)}[\nabla_\theta \log p_\theta(z, x)]. \tag{4}$$

When $q(z) = p_\theta(z|x)$, the lower bound in (3) is tight, and the gradient in (4) is equal to that of the likelihood (see Appendix A.3). Variational approximations attempt to bring $q$ close to $p_\theta(z|x)$, usually by seeking to minimise $D_{\mathrm{KL}}[q(z) \,\|\, p_\theta(z|x)]$ (which corresponds to maximising the bound $\mathcal{F}$ w.r.t. $q$). However, although minimising $D_{\mathrm{KL}}[q(z) \,\|\, p_\theta(z|x)]$ over $q$ ensures consistent optimisation of a single objective, the resulting gradient in (4) will often be a poor approximation to the likelihood gradient (2).

¹We define the likelihood gradient for a single data point here and throughout; an actual update will typically follow the gradient averaged over i.i.d. data.

### 2.3. Conditional Expectation and LSR

Our approach is to avoid the difficulties introduced by approximating $p_\theta(z|x)$ with $q(z)$ in (4), and instead estimate the conditional expectation directly using least-squares regression (LSR). Let $x$ and $y$ be random vectors with a joint distribution $\rho(x, y)$ on $\mathbb{R}^{d_x} \times \mathbb{R}^{d_y}$. In LSR, we seek a (vector-valued) function $f$ that achieves the lowest mean squared error (MSE) $\mathbb{E}_{\rho(x,y)} \|y - f(x)\|_2^2$. The ideal solution is given by $f_\rho(x) := \mathbb{E}_{\rho(y|x)}[y]$, as the problem can be cast as the minimisation of $\mathbb{E}_{\rho(x)} \|f_\rho(x) - f(x)\|_2^2$, where $\rho(x)$ is the marginal distribution of $x$ (see Appendix A.1). Note that $f_\rho(x)$ takes a similar form to the desired (4). In practice, the distribution $\rho(x, y)$ is known only through a sample $\{(x_n, y_n)\}_{n=1}^N \overset{\text{i.i.d.}}{\sim} \rho(x, y)$; thus, LSR can be understood as seeking a good approximation of $f_\rho$ based on the sample.

### 2.4. Kernel Ridge Regression

In LSR, as the target $f_\rho$ is unknown, it is desirable to construct an estimate without imposing restrictions on its form. Kernel ridge regression (KRR) is a nonlinear regression method that draws the estimated regression function from a flexible class of functions called a reproducing-kernel Hilbert space (RKHS) (Hofmann et al., 2008). The KRR estimator is found by minimising the regularised empirical risk

$$\min_{f \in \mathcal{H}} \frac{1}{N} \sum_{n=1}^N \|y_n - f(x_n)\|_2^2 + \lambda \|f\|_\mathcal{H}^2, \tag{5}$$

where $\lambda > 0$ is a regularisation parameter, and $\mathcal{H}$ is the RKHS corresponding to a matrix-valued kernel $\kappa : \mathbb{R}^{d_x} \times \mathbb{R}^{d_x} \to \mathbb{R}^{d_y \times d_y}$ (Carmeli et al., 2006). The solution can be found conveniently in closed form, which allows a further simplification detailed in Section 3.2. In this paper, we use a kernel of the form $\kappa(x, x') = k(x, x') I_y$, where $I_y$ is the identity matrix and $k$ is a scalar-valued positive-definite kernel; therefore, the matrix-valued kernel $\kappa$ can be identified with its scalar counterpart $k$. In particular, in the scalar-output case $d_y = 1$, this choice of $\kappa$ coincides with KRR with the scalar kernel $k$. Importantly, the closed-form solution $\hat{f}_\lambda$ of KRR in (5) can be expressed as

$$\hat{f}_\lambda(x^*) = Y (K + N\lambda I_N)^{-1} k^*, \tag{6}$$

where $Y$ is the concatenation of the training targets $[y_1, \ldots, y_N] \in \mathbb{R}^{d_y \times N}$, $K \in \mathbb{R}^{N \times N}$ is the Gram matrix with elements $(K)_{ij} = k(x_i, x_j)$, $I_N$ is the identity matrix, and $k^* = (k(x_i, x^*))_{i=1}^N \in \mathbb{R}^N$ for a test point $x^*$. In the limit of $N \to \infty$ and $\lambda \to 0$, the solution $\hat{f}_\lambda$ will achieve the minimum MSE in the RKHS (Caponnetto & De Vito, 2007). In general, the target $f_\rho$ may not be in the RKHS²; nonetheless, if the RKHS is sufficiently rich (or $C_0$-universal (Carmeli et al., 2010)), the error made by the estimator, $\mathbb{E}_{\rho(x)}\big[\|\hat{f}_\lambda(x) - f_\rho(x)\|_2^2\big]$, will converge to zero (Szabó et al., 2016, Theorem 7).

²In this case, $f_\rho$ is only assumed to be square-integrable with respect to $\rho$.
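To make the closed form in (6) concrete, here is a minimal numpy sketch that fits a KRR model and evaluates it at test points. The exponentiated-quadratic kernel, bandwidth, regulariser and toy data are illustrative choices rather than the paper's settings, and the code computes the transposed but equivalent form $(k^*)^\top (K + N\lambda I_N)^{-1} Y$.

```python
import numpy as np

def gaussian_kernel(A, B, bandwidth=1.0):
    # A: (n, d), B: (m, d) -> (n, m) Gram matrix of an exponentiated-quadratic kernel
    sq = ((A[:, None, :] - B[None, :, :]) ** 2).sum(-1)
    return np.exp(-sq / (2.0 * bandwidth ** 2))

def krr_fit_predict(X, Y, X_test, lam=1e-2, bandwidth=1.0):
    """Closed-form KRR prediction as in (6): f_hat(x*) = Y (K + N*lam*I)^{-1} k*."""
    N = X.shape[0]
    K = gaussian_kernel(X, X, bandwidth)                      # (N, N)
    alpha = np.linalg.solve(K + N * lam * np.eye(N), Y)       # (N, d_y)
    k_star = gaussian_kernel(X, X_test, bandwidth)            # (N, M)
    return k_star.T @ alpha                                   # (M, d_y)

# Toy usage: the KRR estimate approaches the conditional mean E[y|x] = sin(x).
rng = np.random.default_rng(0)
X = rng.uniform(-3, 3, size=(200, 1))
Y = np.sin(X) + 0.1 * rng.standard_normal(X.shape)
X_test = np.linspace(-3, 3, 50)[:, None]
Y_hat = krr_fit_predict(X, Y, X_test)
```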
## 3. Amortised Learning by Wake-Sleep

### 3.1. Gradient of Log-Likelihood

As stated above and derived in Appendix A.3, the log-likelihood gradient function evaluated on observation $x$ at iteration $t$ (with current parameters $\theta_t$) can be written

$$\nabla_{\theta_t}(x) = \nabla_\theta \log p_\theta(x)\big|_{\theta_t} = \nabla_\theta \mathcal{F}(p_{\theta_t}(z|x), \theta)\big|_{\theta_t}, \tag{7}$$

where the second gradient is taken w.r.t. the second argument of $\mathcal{F}$; the posterior distribution is held fixed at the current $\theta_t$. We want to estimate this gradient directly, without explicit computation of the posterior. Inserting the definition from (4) into (7), we have

$$\nabla_{\theta_t}(x) = \mathbb{E}_{p_{\theta_t}(z|x)}\big[\nabla_\theta \log p_\theta(z, x)\big|_{\theta_t}\big] \tag{8}$$
$$= \nabla_\theta \mathbb{E}_{p_{\theta_t}(z|x)}[\log p_\theta(z, x)]\big|_{\theta_t} = \nabla_\theta J_\theta(x)\big|_{\theta_t}, \tag{9}$$

where $J_\theta(x) := \mathbb{E}_{p_{\theta_t}(z|x)}[\log p_\theta(z, x)]$. Note that the function $J_\theta(x)$ changes with each iteration due to its dependence on $p_{\theta_t}(z|x)$. It can be regarded as an instantaneous objective for ML learning starting from $\theta_t$. Neither (8) nor (9) can be computed in closed form, and therefore they need to be estimated. We refer to ML learning via the estimation of $\nabla_{\theta_t}(x)$, either through $J_\theta$ by (9) or directly by (8), as amortised learning. The difference between the two equations lies purely in implementation: (8) estimates the high-dimensional $\nabla_{\theta_t}(x)$ directly, whereas (9) implements the same computation by differentiating $J_\theta(x)$. We term an estimator of $J_\theta$ a gradient model, as it retains information about $\theta$ and is used to estimate the gradient $\nabla_{\theta_t}(x)$. In the next section, we develop a concrete instantiation of amortised learning.

### 3.2. Training KRR Gradient Model by Wake-Sleep

As discussed in Section 2.3, LSR allows us to estimate the conditional expectation of an output variable given an input. Thus, although the gradient in (8) (or in (9)) involves an intractable conditional expectation, we can obtain an estimate of the gradient $\nabla_{\theta_t}(x)$ by regressing from $x$ to $\nabla_\theta \log p_\theta(z, x)$ (or to $\log p_\theta(z, x)$). Any reasonable regression model, e.g. a neural network, could serve this purpose, but here we choose to use the KRR introduced in Section 2.4. Other possible forms of gradient model are discussed in Appendix B.1. The expression in (8) leads to the following LSR problem:

$$\min_{f \in \mathcal{H}} \frac{1}{N} \sum_{n=1}^N \big\|\nabla_\theta (y_{\theta,n})\big|_{\theta_t} - f(x_n)\big\|_2^2 + \lambda \|f\|_\mathcal{H}^2, \tag{10}$$

where $y_{\theta,n} = \log p_\theta(z_n, x_n)$, $\mathcal{H}$ is an RKHS, and $\{(z_n, x_n)\}_{n=1}^N \sim p_{\theta_t}$. Brehmer et al. (2020) also noticed that the log-likelihood gradient can be obtained by LSR. However, regressing to the vector-valued $\nabla_\theta \log p_\theta$ can be expensive, and evaluating the gradient target $\nabla_\theta(y_{\theta,n})$ on all $(z_n, x_n)$ is slow. Alternatively, we can use (9) and find an estimator for the scalar-valued $J_\theta$ that keeps the dependence on $\theta$, and then evaluate its gradient by automatic differentiation.
Thus, we construct an estimator by

$$\hat{J}_{\theta,\gamma} = \arg\min_{f \in \mathcal{H}} \frac{1}{N} \sum_{n=1}^N |y_{\theta,n} - f(x_n)|^2 + \lambda \|f\|_\mathcal{H}^2, \tag{11}$$

where $\mathcal{H}$ is the RKHS induced by a kernel $k_\omega(\cdot, \cdot)$ with hyperparameters $\omega$, and $\gamma = \{\omega, \lambda\}$. For each data point $x^* \in \mathcal{D}$, the estimate of $J_\theta(x^*)$ is

$$\hat{J}_{\theta,\gamma}(x^*) = \alpha_{\theta,\gamma}\, k^*_\omega, \qquad \alpha_{\theta,\gamma} = y_\theta^\top (K_\omega + \lambda N I_N)^{-1}, \tag{12}$$

with $(y_\theta)_n = \log p_\theta(z_n, x_n)$, $(K_\omega)_{ij} = k_\omega(x_i, x_j)$ and $(k^*_\omega)_j = k_\omega(x_j, x^*)$, where $I_N$ is the identity matrix of size $N \times N$. Note that the dependence of $\hat{J}_{\theta,\gamma}$ on $\theta$ is only through evaluations of $\log p_\theta(z, x)$ on samples drawn from $p_{\theta_t}$ for fixed $\theta = \theta_t$. The gradient $\nabla_{\theta_t}(x)$ is then estimated as $\hat{\nabla}_{\theta_t,\gamma}(x) := \nabla_\theta \hat{J}_{\theta,\gamma}(x)\big|_{\theta_t}$.

In general, a good estimator of $J_\theta$ may not yield a reliable estimate of its gradient $\nabla_\theta J_\theta$; however, for the KRR estimate, taking the derivative of $\hat{J}_{\theta,\gamma}$ w.r.t. $\theta$ is equivalent to replacing $y_\theta$ in (12) with $\nabla_\theta(y_\theta)|_{\theta_t}$, which is the solution of the optimisation in (10) with $\mathcal{H}$ being a vector-valued RKHS given by a kernel $\kappa_\omega = k_\omega I$ (see Section 2.4). We show in Appendix A.2 that, under mild conditions, the target of the regression $\mathbb{E}_{p_{\theta_t}(z|x)}\big[\nabla_\theta y_{\theta,n}\big|_{\theta_t}\big]$ is square-integrable under $p_{\theta_t}(x)$ for common generative models.

In summary, learning proceeds according to the following wake-sleep procedure: at the $t$-th step, when $\theta = \theta_t$, the gradient model is first trained using sleep samples $(z_n, x_n) \sim p_{\theta_t}$ and evaluations $\log p_\theta(z_n, x_n)$, keeping the dependence on $\theta$; then the gradient model is applied to real data ("wake" samples) $x^* \in \mathcal{D}$ to produce $\hat{\nabla}_{\theta_t,\gamma}(x^*)$ by differentiating $\hat{J}_{\theta,\gamma}$ and evaluating at $\theta_t$. See Algorithm 1. Two points are worth emphasis: (a) the algorithm does not require explicit computation or approximation of the posterior, and (b) we only need samples from the model $p_\theta(z, x)$ and differentiable evaluations of $\log p_\theta(z, x)$.

### 3.3. Exponential Family Conditionals

In many common models, the conditional $p_\theta(x|z)$ lies in the exponential family (e.g. Gaussian, Bernoulli), and we can exploit this structure to simplify the estimation of $J_\theta$. In this case, the log joint can be written as

$$\log p_\theta(z, x) = \log p_\theta(x|z) + \log p_\theta(z) = \eta_\theta(z)^\top s(x) - \log Z_\theta(z) + \log p_\theta(z) = \eta_\theta(z)^\top s(x) - \Psi_\theta(z),$$

where $\eta_\theta(z)$, $s(x)$ and $Z_\theta(z)$ are, respectively, the natural parameter, sufficient statistics and normaliser of the likelihood, and $\Psi_\theta(z) := \log Z_\theta(z) - \log p_\theta(z)$. By taking the posterior expectation, $J_\theta(x)$ in (9) becomes

$$J_\theta(x) = \underbrace{\mathbb{E}_{p_{\theta_t}}[\eta_\theta(z)]}_{h^\eta_\theta(x)}{}^{\!\top} s(x) \;-\; \underbrace{\mathbb{E}_{p_{\theta_t}}[\Psi_\theta(z)]}_{h^\Psi_\theta(x)}, \tag{13}$$

where $p_{\theta_t}$ stands for $p_{\theta_t}(z|x)$. Therefore, for exponential-family likelihoods, the regression to $\log p_\theta(z, x)$ in (11) can be replaced by two separate regressions to $\eta_\theta(z)$ and $\Psi_\theta(z)$, which are functions of $z$ alone. The resulting estimators $\hat{h}^\eta_{\theta,\gamma}$ and $\hat{h}^\Psi_{\theta,\gamma}$ are combined to yield

$$\hat{\nabla}_{\theta_t,\gamma}(x) = \nabla_\theta \big[\hat{h}^\eta_{\theta,\gamma}(x)^\top s(x)\big]\big|_{\theta_t} - \nabla_\theta \hat{h}^\Psi_{\theta,\gamma}(x)\big|_{\theta_t},$$

where the Jacobian-vector product applies to the first term.

### 3.4. Kernel Structure and Learning

The kernel $k_\omega$ used in the gradient model affects how well $\nabla_{\theta_t}(x)$ is estimated. It can be made more flexible by augmenting it with a neural network, as in (Wilson et al., 2016; Wenliang et al., 2019):

$$k_\omega(x, x') = \kappa_\sigma(\psi_v(x), \psi_v(x')),$$

where $\kappa_\sigma$ is a standard kernel (e.g. exponentiated-quadratic) with parameter $\sigma$ (e.g. bandwidth), and $\psi_v$ is a neural network with parameters $v$, so $\omega = \{\sigma, v\}$. Other details of the kernel structure are described in Appendix B.2.
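The estimator in (11)-(12) and its gradient can be written compactly with automatic differentiation. The sketch below is an illustration of the idea rather than the paper's implementation (code at github.com/kevin-w-li/al-ws): it assumes a hypothetical model object exposing `sample(N)` (joint sleep samples, detached from $\theta$) and `log_joint(z, x)` (differentiable in $\theta$); a plain Gaussian kernel stands in for the deep kernel $k_\omega$, and the batch size and $\lambda$ are arbitrary.

```python
import torch

def gaussian_gram(A, B, bandwidth=1.0):
    # Exponentiated-quadratic kernel between rows of A (n, d) and B (m, d).
    return torch.exp(-torch.cdist(A, B) ** 2 / (2 * bandwidth ** 2))

def alws_wake_gradient(model, x_data, n_sleep=512, lam=0.1, bandwidth=1.0):
    """Estimate the ML gradient for a batch of wake data x_data of shape (M, d_x)."""
    # Sleep samples define the regression inputs; they carry no gradient.
    with torch.no_grad():
        z, x = model.sample(n_sleep)
    # Targets y_theta = log p_theta(z_n, x_n): the only place theta enters.
    y = model.log_joint(z, x)                                   # (N,)
    K = gaussian_gram(x, x, bandwidth)                          # (N, N)
    k_star = gaussian_gram(x, x_data, bandwidth)                # (N, M)
    weights = torch.linalg.solve(
        K + n_sleep * lam * torch.eye(n_sleep), k_star)         # (K + N*lam*I)^{-1} k*
    J_hat = (weights * y[:, None]).sum(0).mean()                # mean of J_hat over wake points
    # Differentiating J_hat w.r.t. theta at theta_t gives the gradient estimate.
    return torch.autograd.grad(J_hat, list(model.parameters()))
```

Because the Gram matrices are built from detached sleep samples, differentiating `J_hat` with respect to the model parameters amounts to replacing $y_\theta$ by $\nabla_\theta y_\theta$ in (12), mirroring the equivalence noted above.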
The gradient model parameters $\gamma = \{\omega, \lambda\}$ can be learned to further minimise the MSE in (11) using a scheme of cross-validation by gradient descent (Wenliang et al., 2019). Specifically, we generate two sets of sleep samples from $p_\theta$: we use one set to compute $\alpha_{\theta,\gamma}$ in closed form; then, on the other set $\{(z'_l, x'_l)\}_{l=1}^L$, we compute the MSE between the estimator $\hat{J}_{\theta,\gamma}(x'_l)$ and the ground-truth value $\log p_\theta(z'_l, x'_l)$, and minimise this by gradient descent on $\gamma$. The full ALWS procedure is presented in Algorithm 1.

**Algorithm 1: Amortised learning by wake-sleep**

Input: dataset $\mathcal{D}$; gradient model parameters $\gamma$; generative model $\log p_\theta(z, x)$, or $\eta_\theta$ and $\Psi_\theta$, with parameters $\theta$ initialised such that $p_\theta(x)$ covers/dominates the data distribution; maximum number of epochs and any convergence criteria.

While $\theta$ has not converged within the maximum number of epochs:
1. Sleep phase (train gradient model): sample $\{z_n, x_n\}_{n=1}^N \sim p_\theta$. If $p(x|z)$ is not in the exponential family, find $\hat{J}_{\theta,\gamma}(\cdot)$ by computing $\alpha_{\theta,\gamma}$ in (12); else find $\hat{h}^\eta_{\theta,\gamma}(\cdot)$ and $\hat{h}^\Psi_{\theta,\gamma}(\cdot)$ similarly to (12) and set $\hat{J}_{\theta,\gamma}(\cdot) = \hat{h}^\eta_{\theta,\gamma}(\cdot)^\top s(x) - \hat{h}^\Psi_{\theta,\gamma}(\cdot)$ as in (13).
2. Sleep phase (update $\gamma$): sample $\{z'_l, x'_l\}_{l=1}^L \sim p_\theta$; compute $d_l := \log p_\theta(z'_l, x'_l)$ and $E_\gamma = \frac{1}{L} \sum_{l=1}^L \big(\hat{J}_{\theta,\gamma}(x'_l) - d_l\big)^2$; update $\gamma \leftarrow \gamma - \nabla_\gamma E_\gamma$.
3. Wake phase (update $\theta$): sample $\{x^*_m\}_{m=1}^M \subset \mathcal{D}$; compute $\bar{J}_\theta = \frac{1}{M} \sum_{m=1}^M \hat{J}_{\theta,\gamma}(x^*_m)$; update $\theta \leftarrow \theta + \nabla_\theta \bar{J}_\theta$.

Return: $\theta$.

### 3.5. Dealing with Covariate Shift

The gradient model is to be used to estimate $\nabla_{\theta_t}(x)$ on $x$ drawn from an underlying data distribution $p^*$, but it is trained using sleep samples from $p_{\theta_t}$. This mismatch between the input distributions for training and evaluation is known as covariate shift (Shimodaira, 2000). Here, to ensure that the gradient model performs reasonably well on $p^*$, we initialise $p_\theta(x)$ to be overdispersed relative to $p^*$ by setting a large noise level in $p_\theta(x|z)$. Since ML estimation minimises $D_{\mathrm{KL}}[p^* \| p_\theta]$, which penalises a distribution $p_\theta$ that is narrower than $p^*$, we expect the noise to continue to cover the data before the model is well trained. For image data only, we also apply batch normalisation in the $\psi_v$ of the kernel. We find these simple remedies to be effective, though other more principled methods, such as kernel mean matching (Gretton et al., 2009) and binary classification (Gutmann & Hyvärinen, 2010; Goodfellow et al., 2014), may further improve the results.
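For orientation, a condensed version of the wake-sleep loop in Algorithm 1 might look as follows, reusing the `alws_wake_gradient` sketch above. The second sleep phase (updating $\gamma$ by cross-validated MSE), the exponential-family variant and the overdispersed initialisation are omitted for brevity, and the optimiser and learning rate are illustrative assumptions rather than the paper's settings.

```python
import torch

def train_alws(model, data_loader, epochs=50, lr=1e-3):
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    for _ in range(epochs):
        for x_batch in data_loader:                  # wake samples from the dataset D
            # Sleep phase: fit the KRR gradient model on fresh sleep samples;
            # wake phase: estimate the ML gradient on the data batch.
            grads = alws_wake_gradient(model, x_batch)
            opt.zero_grad()
            for p, g in zip(model.parameters(), grads):
                p.grad = -g                          # negate so the optimiser ascends the likelihood estimate
            opt.step()
    return model
```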
## 4. Experiments

We evaluate ALWS on a wide range of generative models. Details for each experiment can be found in Appendix C.³

³Code is at github.com/kevin-w-li/al-ws

### 4.1. Parameter Gradient Estimation

First, we demonstrate that KRR can estimate $\nabla_{\theta_t}(x)$ well on a simple toy generative model described by $z_1, z_2 \sim \mathcal{N}(0, 1)$ and $x|z \sim \mathcal{N}\big(\mathrm{softplus}(b^\top z) - \|b\|_2^2,\ \sigma_x^2\big)$. The training data are 100 data points from the model with $b = [1, 1]$ and $\sigma_x = 0.1$. We estimate the gradients of the log-likelihood w.r.t. $b$, evaluated on a grid of $b$, by ALWS, and compare them to estimates using importance sampling ("truth") and a factorised Gaussian posterior that minimises the forward KL for each $x$. For ALWS, we used a Gaussian kernel with a bandwidth equal to the median distance between samples generated for each $b$, and set $\lambda = 0.01$. For variational inference, we assumed a factorised Gaussian posterior for each sample of $x$, and optimised the posterior parameters until convergence. ALWS tends to estimate the gradient better, especially for small $b$ (Figure 2). For the smallest $\sigma_x$, the KRR estimates are noisier, whereas variational inference introduces greater bias.

Figure 2. Gradient estimated using amortised learning and variational inference. The true gradients are approximated by importance sampling.

### 4.2. Non-Euclidean Priors

The prior $p(z)$ may capture special topological structure in the data. For instance, a prior over the hypersphere can be used to describe circular features (Davidson et al., 2018; Xu & Durrett, 2018). Training models with such a prior is straightforward using ALWS, while learning by amortised inference requires a special reparameterisation for a posterior on the hypersphere, such as the von Mises-Fisher (vMF) distribution used in the S-VAE (Davidson et al., 2018; Xu & Durrett, 2018). We fit a model with a uniform circular latent and neural-network output,
$$z = [\cos(a), \sin(a)], \qquad p(a) = U(a; (-\pi, \pi)), \qquad p(x|z) = \mathcal{N}(x; \mathrm{NN}_w(z), \sigma_x^2 I)$$
(where $U$ is a uniform distribution), on a dataset of Gabor wavelets with uniformly distributed orientations. As shown in Figure 3, ALWS learns to generate images that closely resemble the training data. A fixed rotation around the latent circle corresponds to an almost fixed rotation of the Gabor wavelet in the image. The VAE with a 2-D Gaussian latent also generates good filters given latents on the circle, but the length of the filter varies with rotation. Surprisingly, the S-VAE is not able to learn on this dataset; the vMF posterior is almost flat for any input image. This hints at potential optimisation issues with the complicated reparameterisation. This advantage also extends to priors over hyperbolic space, which are used to capture tree-like hierarchical structures (Nagano et al., 2019; Mathieu et al., 2019).

Figure 3. Learning to generate Gabor filters given a 1-D circular uniform prior. Top images show samples generated by latents separated by fixed rotations on the circle. For the VAE, a 2-D Gaussian prior was used, and the images are generated by latents on the unit circle. S-VAEs cannot reliably learn the filters. The errors below show the squared distance between generated images and data at each orientation. For each method, an angle offset and direction are chosen to minimise the total error.

### 4.3. Hierarchical Models

Rich hierarchical structures in the data can be captured with multiple layers of latents. Provided that samples can be drawn from the hierarchical model and the joint log-likelihood evaluated, ALWS extends straightforwardly to hierarchies, even with mixed discrete and continuous latents. The pinwheel distribution (Johnson et al., 2016; Lin et al., 2018) has five clusters of distorted Gaussian distributions (Figure 4), and can be described by the following model:
$$p(z_1) = \mathrm{Cat}(z_1; m), \qquad p(z_2|z_1 = k) = \mathcal{N}(z_2; \mu_k, \Sigma_k), \qquad p(x|z_2) = \mathcal{N}(x; \mathrm{NN}_w(z_2), \Sigma_x),$$
where $\mathrm{Cat}$ is the categorical distribution. The parameters are the logits $m$ in 10 dimensions, the means and covariance matrices of the component distributions $\{\mu_k, \Sigma_k\}_{k=1}^{10}$, the weights $w$ in $\mathrm{NN}$, and the diagonal covariance $\Sigma_x$. The logits $m$ are penalised according to a Dirichlet prior, and $\{\mu_k, \Sigma_k\}_{k=1}^{10}$ by a normal-Wishart prior. After training with ALWS, the categorical distribution correctly identifies the five components, and the generated samples match the training data.

Figure 4. Learning a hierarchical model with discrete and continuous latents. From left to right: data sample, component probabilities, samples of the first latent distribution, and samples of generated data. Colours correspond to different components.
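To illustrate point (b) of Section 3.2, that ALWS only needs joint samples and a differentiable log joint, a toy version of such a hierarchical model could be written as below. For brevity this sketch uses diagonal component covariances and omits the Dirichlet and normal-Wishart penalties; the decoder architecture and dimensions are illustrative, not the paper's settings.

```python
import torch
import torch.nn as nn
import torch.distributions as D

class HierarchicalModel(nn.Module):
    def __init__(self, n_comp=10, d_z=2, d_x=2):
        super().__init__()
        self.logits = nn.Parameter(torch.zeros(n_comp))           # m
        self.mu = nn.Parameter(torch.randn(n_comp, d_z))          # component means
        self.log_scale = nn.Parameter(torch.zeros(n_comp, d_z))   # diagonal component scales
        self.decoder = nn.Sequential(nn.Linear(d_z, 32), nn.Tanh(),
                                     nn.Linear(32, d_x))          # NN_w
        self.log_sigma_x = nn.Parameter(torch.zeros(d_x))         # diagonal observation noise

    def sample(self, n):
        # Joint "sleep" samples (z1, z2), x ~ p_theta(z, x), detached from theta.
        with torch.no_grad():
            z1 = D.Categorical(logits=self.logits).sample((n,))
            z2 = self.mu[z1] + self.log_scale[z1].exp() * torch.randn(n, self.mu.shape[1])
            x = self.decoder(z2) + self.log_sigma_x.exp() * torch.randn(n, self.log_sigma_x.shape[0])
        return (z1, z2), x

    def log_joint(self, z, x):
        # Differentiable log p_theta(z1, z2, x) for fixed samples.
        z1, z2 = z
        lp = D.Categorical(logits=self.logits).log_prob(z1)
        lp = lp + D.Normal(self.mu[z1], self.log_scale[z1].exp()).log_prob(z2).sum(-1)
        lp = lp + D.Normal(self.decoder(z2), self.log_sigma_x.exp()).log_prob(x).sum(-1)
        return lp
```

Because the discrete latent only ever appears inside sampled tuples and log-density evaluations, this model plugs into the `alws_wake_gradient` and `train_alws` sketches above without any reparameterisation.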
We compare these samples with those reconstructed from a Bayesian version of the model trained by a structured inference network (SIN) (Lin et al., 2018)⁴. A three-way maximum mean discrepancy (MMD) test (Bounliphone et al., 2016) finds that samples from the two models are equally close to the training data (p = 0.514, N = 1,000 samples). Details are in Appendix C.3.

⁴github.com/emtiyaz/vmp-for-svae

### 4.4. Feature Identification

**Independent components.** Learning informative features from complex data can benefit downstream tasks. We use ALWS to identify features from data generated by $p(z_i) = \mathrm{Lap}(z_i; 0, 1)$ and $p(x|z) = \mathcal{N}(x; Wz, \sigma^2 I)$, where $\mathrm{Lap}$ is the Laplace distribution, $\sigma = 0.1$, and the basis $W$ contains independent components of natural images (Hateren & Schaaf, 1998) found by the FastICA algorithm (Hyvärinen & Oja, 2000). Since this model is identifiable, we perform model recovery from a random initialisation of $W$ using ALWS and compare with a VAE. ALWS clearly finds better features, as shown in Figure 5. On generated samples, a three-way MMD test favours ALWS over the Laplace-VAE ($p < 10^{-5}$) based on 10,000 samples. Details are in Appendix C.4.

Figure 5. Feature identification. Left, true basis used to generate images. Middle, basis recovered by ALWS. Right, basis recovered by VAE. The filters are arranged according to correlations with the true basis.

**Matrix factorisation.** A more accurate data model may improve performance on a downstream task that relies on inference of the associated latent variables. Following Ruiz et al. (2016), we test post-learning inference on a probabilistic non-negative matrix factorisation model:
$$p(z_i) = U(z_i; 0, 1), \qquad p(x_i|z) = \mathrm{Bernoulli}(x_i; \bar{x}_i), \qquad \bar{x}_i = \mathrm{sigmoid}\big(w_i^\top \mathrm{logit}(z) + b_i\big).$$
On each entry of $W$ we place a penalty consistent with a Gamma(w; 0.9, 0.3) prior, and learn $W$ and $b$. We include $b$ in the model trained by ALWS because it prevents the generation of samples with opposite colour polarity, which would create a more severe covariate shift that harms the gradient model. We evaluate the models on reconstructing and denoising handwritten digits from the binarised MNIST dataset. To recover the original image given a clean or noisy $x^*$, we generate $\bar{x}$ from the posterior mode found by maximising $\log p(z, x^*)$ over $z$. We compare with a Bayesian version of the model trained by the generalised reparameterisation gradient (Ruiz et al., 2016) and a VAE-like model in which the decoder has the generative structure above and the posterior is a reparameterised Beta distribution. The results for both tasks are depicted in Figure 6. The leftmost panels show the histograms of MSE on 1,000 test images, and the other panels show examples of 25 test images and their reconstructions by each method. ALWS achieved significantly lower error ($p < 10^{-10}$ for both a two-tailed t-test and a Wilcoxon signed-rank test).

Figure 6. Beta-Gamma Matrix Factorisation. Top, mean squared error across 1,000 test inputs compared to G-Rep and VAE. Bottom, examples of real data, reconstructed and denoised samples.
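The reconstruction and denoising evaluation reduces to a small optimisation over the latent. The sketch below is a minimal illustration, assuming the learnt $W$ and $b$ and an unconstrained parametrisation $z = \mathrm{sigmoid}(u)$ so that $z$ stays in $(0, 1)$, where the uniform prior contributes only a constant; the optimiser, step count and learning rate are arbitrary choices.

```python
import torch
import torch.nn.functional as F

def reconstruct(W, b, x_star, n_steps=500, lr=0.05):
    # W: (d_x, d_z), b: (d_x,), x_star: (d_x,) binary observation as a float tensor.
    u = torch.zeros(W.shape[1], requires_grad=True)    # u = logit(z), unconstrained
    opt = torch.optim.Adam([u], lr=lr)
    for _ in range(n_steps):
        opt.zero_grad()
        logits = W @ u + b                             # since logit(sigmoid(u)) = u
        # The uniform prior on z is constant on its support, so the posterior
        # mode maximises the Bernoulli log-likelihood of x_star.
        loss = F.binary_cross_entropy_with_logits(logits, x_star, reduction="sum")
        loss.backward()
        opt.step()
    with torch.no_grad():
        return torch.sigmoid(W @ u + b)                # reconstructed mean image x_bar
```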
### 4.5. Neural Processes

The neural process (NP) (Garnelo et al., 2018) is a model that learns to infer over functions. Conceptually, the computational goal of NPs is similar to predictive inference in Gaussian processes, but without defining an explicit prior over functions. We review NPs in more detail and illustrate how they can be trained by ML using ALWS in Appendix C.5. We compared ALWS with the original variational learning method on a toy problem. The NP trained by ALWS produces better predictions and uncertainty estimates on test inputs. See Figure 10 in Appendix C.5.

### 4.6. Dynamical Models

In fields such as biology and environmental science, the behaviour of complex systems is often described by simulation-based dynamical models. Estimating the parameters of these models from data is crucial for prediction and policy-making (Lintusaari et al., 2016; Sunnåker et al., 2013; Kypraios et al., 2017). A dynamical model can be expressed, in discrete time, as
$$z_t = l_\theta(z_{1:t-1}, x_{1:t-1}, u_t, \epsilon_t), \qquad x_t = o_\theta(z_t) + e_t,$$
where $l_\theta$ describes a latent process that can depend on a control input $u_t$, a noise source $\epsilon_t$, and the history of latents $z_{1:t-1}$ and measurements $x_{1:t-1}$. The function $o_\theta$ maps the latent $z_t$ to a measurement with noise $e_t$. For ALWS, we require that $p_\theta(z_t, \epsilon_t \mid z_{1:t-1}, x_{1:t-1}, u_t)$ and $p_\theta(x_t, e_t \mid z_t)$ are tractable, so that $\nabla_\theta \log p_\theta(z_{1:T}, x_{1:T})$ can be evaluated, where $T$ is the length of the data. In contrast, learning using approximate inference may be challenging due to complex dependencies between latent variables and across time.

Here, we fit the parameters of two dynamical models: the Hodgkin-Huxley (HH) model (Pospischil et al., 2008) on the membrane potential of a simulated neuron, and an ecological model (ECO) on blowfly data (Wood, 2010). The HH equations describe the membrane potential and three ion-channel state variables of a neuron that follow complicated nonlinear transitions. Details of the experiment are in Appendix C.6. Results in Figure 12 show that the trained model can not only reproduce the training data well but also predict the response given new inputs $u_t$. ECO describes nonlinear and non-Gaussian dynamics and has discrete and continuous latent variables. Fitting ECO to blowfly data has been used to validate approximate Bayesian computation (ABC) methods (Park et al., 2016). The model trained with ALWS can simulate sequences very close to the data (Figure 7), visibly closer than sequences from the model trained with ABC (Park et al., 2016, Figure 2b).

Figure 7. Modelling blowfly population time series. Black, training data. Coloured, samples for an extended time period drawn from the trained model.
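As a sketch of the interface ALWS needs for such models, the placeholder state-space model below exposes a forward simulator and a differentiable sequence log joint. The first-order tanh dynamics and Gaussian noise are stand-ins for the HH or ECO equations, not the models used in the paper, and the parameter layout is purely illustrative.

```python
import torch
import torch.distributions as D

def simulate(theta, u, T, d_z=1):
    """Roll a simple latent process forward given controls u of length T."""
    a, c, log_q, log_r = theta                 # transition and observation parameters
    z, zs, xs = torch.zeros(d_z), [], []
    for t in range(T):
        z = torch.tanh(a * z + c * u[t]) + log_q.exp() * torch.randn(d_z)  # z_t = l_theta(z_{t-1}, u_t) + noise
        x = z + log_r.exp() * torch.randn(d_z)                             # x_t = o_theta(z_t) + e_t
        zs.append(z); xs.append(x)
    return torch.stack(zs), torch.stack(xs)

def log_joint(theta, z, x, u):
    """Sum of transition and observation log-densities over the whole sequence."""
    a, c, log_q, log_r = theta
    z_prev = torch.cat([torch.zeros_like(z[:1]), z[:-1]])      # initial state is zero, as in simulate
    lp = D.Normal(torch.tanh(a * z_prev + c * u[:, None]), log_q.exp()).log_prob(z).sum()
    lp = lp + D.Normal(z, log_r.exp()).log_prob(x).sum()
    return lp
```

With `requires_grad` parameters, `torch.autograd.grad(log_joint(theta, z, x, u), theta)` supplies the differentiable $\log p_\theta(z_{1:T}, x_{1:T})$ evaluations that the gradient model regresses on.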
### 4.7. Sample Quality

Finally, we train deep models of images and test sample quality. We chose six benchmark datasets: the binarised and original MNIST (LeCun et al., 1998) (B-MNIST and MNIST, respectively), fashion MNIST (Fashion) (Xiao et al., 2017), natural images (Natural) (Hateren & Schaaf, 1998), CIFAR-10 (Krizhevsky et al., 2009) and CelebA (Liu et al., 2015). The original un-binarised MNIST is known to be difficult for most VAE-based methods (Loaiza-Ganem & Cunningham, 2019). Natural consists of grey-scale images from natural scenes. All images have size 32 × 32 with colour channels. For ALWS, we test two variants: in ALWS-F, the gradient model parameters $\gamma$ are fixed; in ALWS-A, $\gamma$ is adapted as described in Section 3.4, except for $\lambda$, which is fixed at 0.1. Fixing $\lambda$ improved quality for the higher-dimensional CIFAR-10 and CelebA, but lowered quality for Natural and had little effect on the other datasets. We compare these methods with four other approaches: the vanilla VAE (Kingma & Welling, 2014), a VAE with a Sylvester (orthogonal) flow as the inference network (Syl-VAE) (van den Berg et al., 2018)⁵, semi-implicit variational inference (SIVI) (Yin & Zhou, 2018)⁶, and reweighted wake-sleep (Bornschein & Bengio, 2015).

Each algorithm has the same generative network architecture as in DCGAN⁷ with the last convolutional layer removed. We also run WGAN-GP (Gulrajani et al., 2017)⁸ for reference, although it is not trained by ML methods. Each algorithm is run 10 times for 50 epochs with different initialisations, except for SIVI, which we trained for 1000 epochs with a lower learning rate for stability. To test generative quality, we compute both the Fréchet Inception Distance (FID) (Heusel et al., 2017) and the Kernel Inception Distance (KID) (Binkowski et al., 2018) on 10,000 generated images. The results are shown in Figure 8. According to FID, ALWS-A is the best ML method for binarised MNIST, Fashion, and CIFAR-10. Notably, both ALWS-A and ALWS-F have much smaller FID and KID on MNIST and Fashion than other ML methods. WGAN-GP did not produce a good score on CIFAR-10 within 50 epochs but becomes the best model for all datasets with further training. Samples are shown in Figure 15 to Figure 20 in Appendix C.7, with additional experiments showing the effectiveness of ALWS.

Figure 8. FID and KID scores (lower is better) for different datasets and methods. Each red dot is the score for a single run. Bars are medians of the dots for each method. Short bars on KID dots show the standard error of the estimate. All models are trained for 50 epochs.

⁵github.com/riannevdberg/sylvester-flows ⁶github.com/mingzhang-yin/SIVI ⁷pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html ⁸github.com/caogang/wgan-gp

## 5. Related Work

### 5.1. Amortised Variational Inference

Using $\mathcal{F}(q, \theta)$ as the objective for learning $\theta$, the gradient for $\theta$ is given by an intractable posterior expectation. The large majority of learning algorithms based on amortised variational inference use Monte-Carlo estimators for this gradient. The variational auto-encoder (VAE) (Kingma & Welling, 2014; Rezende et al., 2014) parametrises $q_\phi(z|x)$ by simple distributions, using reparameterised samples to obtain gradients for $\phi$. Approximate posteriors may also be incorporated into tighter bounds on $\log p_\theta(x)$ by reweighting (Burda et al., 2016; Bornschein & Bengio, 2015; Le et al., 2019), although with some loss of gradient signal (Rainforth et al., 2018). More expressive forms of $q_\phi$ can be formed by invertible transformations (normalising flows) (Rezende & Mohamed, 2015; Kingma et al., 2016; van den Berg et al., 2018) that allow $H[q_\phi]$ to be computed easily, or by non-invertible mappings (implicit variational inference), which require estimating $H[q_\phi]$ or its gradient w.r.t. $\phi$ (Shi et al., 2018; Li & Turner, 2018; Yin & Zhou, 2018; Huszár, 2017). Reparametrising posterior samples may require nontrivial methods (Jang et al., 2017; Vahdat et al., 2018; Rolfe, 2017; Ruiz et al., 2016; Figurnov et al., 2018). On the other hand, amortised learning focuses exclusively on estimating the gradient for ML learning, making no assumptions on the type of latent variables.

Our approach is related to at least two other algorithms inspired by the original Helmholtz machine (HM) (Dayan et al., 1995; Hinton et al., 1995). The distributed distributional code HM (DDC-HM) (Vértes & Sahani, 2018) represents posteriors by expectations of pre-defined, finite sets of nonlinear features, which are used to approximate $\nabla_{\theta_t}(x)$ by linearity of expectation. ALWS differs from DDC-HM in two ways. First, our gradient model integrates the inferential model and the linear readout for $\nabla_{\theta_t}(x)$ used in DDC-HM into an adaptive and more flexible KRR. Second, using (9) avoids explicit computation of $\nabla_\theta \log p_\theta$ and makes ALWS easily applicable to more complex generative models.
Reweighted wake-sleep (RWS) (Bornschein & Bengio, 2015) addressed covariate shift by training an inferential model to increase the likelihood not only of sleep $z$ given sleep $x$, as in the HM, but also of weighted posterior samples given data $x^*$. ALWS does not make assumptions about the posterior distributions, and we found that simple strategies mitigated covariate shift in practice, but this is a point that deserves further investigation.

### 5.2. Training Implicit Generative Models

Implicit generative models, including generative adversarial networks (GANs) (Goodfellow et al., 2014) and the simulation-based models considered by approximate Bayesian computation (ABC) (Tavaré et al., 1997; Marin et al., 2012), do not have an explicitly defined likelihood function, but can be trained using simulated data. Amortised learning requires an explicit joint likelihood function $p_\theta(x, z)$, but can also train simulation-based generative models (Section 4.6). In GANs, the generator is improved by a discriminator that is concurrently trained to tell apart real and generated samples. The approach is able to synthesise high-quality samples in high dimensions. However, the competitive setting can be problematic for convergence, and the discriminator needs to be carefully regularised to be less effective at its own task but more informative to the generator (Arjovsky et al., 2017; Gulrajani et al., 2017; Arbel et al., 2018; Mescheder et al., 2018). In amortised learning, a better gradient model always helps when training the generative model. Importantly, amortised learning can directly train real-world simulators for which samples of $x$ are not differentiable w.r.t. $\theta$, such as the Galton board, where GANs are not directly applicable.

Rather than performing maximum-likelihood estimation, ABC estimates a posterior over $\theta$ using simulated data and a chosen prior on $\theta$. Amortised learning can be seen as maximum-likelihood learning based on simulations, since the gradient model is trained using data from the generative model. In particular, ALWS is similar to Kernel-ABC (Nakagome et al., 2013), in which the posterior is found by weighting prior samples using KRR on pre-defined summary statistics. Kernel recursive ABC (Kajihara et al., 2018) iteratively updates the prior over $\theta$ by herding from a kernel embedding (Song et al., 2009) of the posterior, converging to a maximum-likelihood solution. ALWS does not maintain a distribution over $\theta$, but iteratively updates the parameters by gradient methods so that the model distribution approaches the data distribution. Also, ALWS performs well even when the number of parameters is large, a regime in which traditional ABC methods are likely to be expensive.

## 6. Discussion

Direct estimation of the expected log-likelihood and its gradient in a latent variable model circumvents the challenges and issues posed by explicit approximation of posteriors. The KRR gradient model is consistent, easy to implement, and avoids the need for explicit computation of derivatives. However, we observe the following issues with the current instance of amortised learning.
First, its computational complexity limits the number of sleep samples that can be used to train the gradient model, and thus the quality of the approximation. Techniques such as random-feature and Nyström approximations could make KRR more efficient. Second, the KRR prediction is a linear combination of the set $\{\nabla_\theta \log p_\theta(z_n, x_n)\}_{n=1}^N$, but the true gradient function, which can be much higher-dimensional than $N$, may lie outside this span, an issue that might be compounded by covariate shift. Further, hyperparameter learning using the meta-learning method described in Section 3.4 improves the estimation of $J_\theta$ rather than $\nabla_\theta J_\theta$, which might explain why adapting $\lambda$ on some tasks worsens the results. Therefore, alternative amortised learning models may be worth future exploration.

Nonetheless, we have found here that ALWS based on KRR provides accurate parameter estimates in many settings where approximate inference-based approaches appear to struggle. ALWS can be extended to training generative models of other types of data, such as graphs, as long as an appropriate kernel is used. Another useful extension is to train conditional generative models, which we explored briefly in the neural processes experiment. In this case, the gradient model needs to depend on any conditioning variables (or sets). Finally, while we used LSR to approximate the gradient of the model w.r.t. $\theta$, other useful quantities could also be estimated in a similar fashion (Brehmer et al., 2020).

## Acknowledgements

We thank Arthur Gretton, Sebastian Nowozin, Jiaxin Shi and Eszter Vértes for helpful discussions; we thank Ferenc Huszár for discussion and comments on an earlier draft.

## References

Arbel, M., Sutherland, D. J., Binkowski, M., and Gretton, A. On gradient regularizers for MMD GANs. In NeurIPS, pp. 6701-6711, 2018.

Arjovsky, M., Chintala, S., and Bottou, L. Wasserstein generative adversarial networks. In ICML, 2017.

Binkowski, M., Sutherland, D. J., Arbel, M., and Gretton, A. Demystifying MMD GANs. In ICLR, 2018.

Bishop, C. M. Pattern Recognition and Machine Learning. Springer, 2006.

Bornschein, J. and Bengio, Y. Reweighted wake-sleep. In ICLR, 2015.

Boucheron, S., Lugosi, G., and Massart, P. Concentration Inequalities: A Nonasymptotic Theory of Independence. Oxford University Press, February 2013. ISBN 978-0-19-953525-5. doi: 10.1093/acprof:oso/9780199535255.001.0001.

Bounliphone, W., Belilovsky, E., Blaschko, M. B., Antonoglou, I., and Gretton, A. A test of relative similarity for model selection in generative models. In ICLR, 2016.

Brehmer, J., Louppe, G., Pavez, J., and Cranmer, K. Mining gold from implicit models to improve likelihood-free inference. Proceedings of the National Academy of Sciences, 117(10):5242-5249, 2020.

Burda, Y., Grosse, R. B., and Salakhutdinov, R. Importance weighted autoencoders. In ICLR, 2016.

Caponnetto, A. and De Vito, E. Optimal rates for the regularized least-squares algorithm. Foundations of Computational Mathematics, 2007.

Carmeli, C., De Vito, E., and Toigo, A. Vector valued reproducing kernel Hilbert spaces of integrable functions and Mercer theorem. Analysis and Applications, 2006.

Carmeli, C., De Vito, E., Toigo, A., and Umanità, V. Vector valued reproducing kernel Hilbert spaces and universality. Analysis and Applications, 2010.

Chatterjee, S., Diaconis, P., et al. The sample size required in importance sampling. The Annals of Applied Probability, 28(2):1099-1135, 2018.

Davidson, T. R., Falorsi, L., Cao, N. D., Kipf, T., and Tomczak, J. M. Hyperspherical variational auto-encoders. In UAI, 2018.
Dayan, P., Hinton, G. E., Neal, R. M., and Zemel, R. S. The Helmholtz machine. Neural Computation, 1995.

Dempster, A. P., Laird, N. M., and Rubin, D. B. Maximum likelihood from incomplete data via the EM algorithm. Journal of the Royal Statistical Society: Series B (Methodological), 1977.

Dieng, A. B. and Paisley, J. Reweighted expectation maximization. arXiv preprint arXiv:1906.05850, 2019.

Figurnov, M., Mohamed, S., and Mnih, A. Implicit reparameterization gradients. In NeurIPS, 2018.

Garnelo, M., Schwarz, J., Rosenbaum, D., Viola, F., Rezende, D. J., Eslami, S., and Teh, Y. W. Neural processes. arXiv preprint arXiv:1807.01622, 2018.

Goodfellow, I., Pouget-Abadie, J., Mirza, M., Xu, B., Warde-Farley, D., Ozair, S., Courville, A., and Bengio, Y. Generative adversarial nets. In NeurIPS, pp. 2672-2680, 2014.

Gretton, A., Smola, A., Huang, J., Schmittfull, M., Borgwardt, K., and Schölkopf, B. Covariate shift by kernel mean matching. Dataset Shift in Machine Learning, 2009.

Gulrajani, I., Ahmed, F., Arjovsky, M., Dumoulin, V., and Courville, A. C. Improved training of Wasserstein GANs. In NeurIPS, 2017.

Gutmann, M. and Hyvärinen, A. Noise-contrastive estimation: A new estimation principle for unnormalized statistical models. In AISTATS, 2010.

Hateren, J. H. v. and Schaaf, A. v. d. Independent component filters of natural images compared with simple cells in primary visual cortex. Proceedings: Biological Sciences, 1998.

Heusel, M., Ramsauer, H., Unterthiner, T., Nessler, B., and Hochreiter, S. GANs trained by a two time-scale update rule converge to a local Nash equilibrium. In NeurIPS, 2017.

Hinton, G. E., Dayan, P., Frey, B. J., and Neal, R. M. The wake-sleep algorithm for unsupervised neural networks. Science, 1995.

Hofmann, T., Schölkopf, B., and Smola, A. J. Kernel methods in machine learning. The Annals of Statistics, 2008.

Huggins, J. H., Kasprzak, M., Campbell, T., and Broderick, T. Practical posterior error bounds from variational objectives. CoRR, abs/1910.04102, 2019.

Huszár, F. Variational inference using implicit distributions. arXiv preprint arXiv:1702.08235, 2017.

Hyvärinen, A. and Oja, E. Independent component analysis: algorithms and applications. Neural Networks, 2000.

Jang, E., Gu, S., and Poole, B. Categorical reparameterization with Gumbel-Softmax. In ICLR, 2017.

Johnson, M., Duvenaud, D. K., Wiltschko, A., Adams, R. P., and Datta, S. R. Composing graphical models with neural networks for structured representations and fast inference. In NeurIPS, pp. 2946-2954, 2016.

Kajihara, T., Kanagawa, M., Yamazaki, K., and Fukumizu, K. Kernel recursive ABC: Point estimation with intractable likelihood. In International Conference on Machine Learning, pp. 2400-2409, 2018.

Kingma, D. P. and Welling, M. Auto-encoding variational Bayes. In 2nd International Conference on Learning Representations, ICLR 2014, Banff, AB, Canada, April 14-16, 2014, Conference Track Proceedings, 2014.

Kingma, D. P., Salimans, T., Jozefowicz, R., Chen, X., Sutskever, I., and Welling, M. Improved variational inference with inverse autoregressive flow. In NIPS, pp. 4743-4751, 2016.

Krizhevsky, A., Hinton, G., et al. Learning multiple layers of features from tiny images. Technical report, 2009.

Kypraios, T., Neal, P., and Prangle, D. A tutorial introduction to Bayesian inference for stochastic epidemic models using Approximate Bayesian Computation. Mathematical Biosciences, 2017.

Le, T. A., Kosiorek, A. R., Siddharth, N., Teh, Y. W., and Wood, F. Revisiting reweighted wake-sleep for models with stochastic control flow. In UAI, 2019.
LeCun, Y., Bottou, L., Bengio, Y., Haffner, P., et al. Gradient-based learning applied to document recognition. Proceedings of the IEEE, 86(11):2278-2324, 1998.

Li, Y. and Turner, R. E. Gradient estimators for implicit models. In ICLR, 2018.

Lin, W., Hubacher, N., and Khan, M. E. Variational message passing with structured inference networks. In ICLR, 2018.

Lintusaari, J., Gutmann, M. U., Dutta, R., Kaski, S., and Corander, J. Fundamentals and recent developments in Approximate Bayesian Computation. Systematic Biology, 2016.

Liu, Z., Luo, P., Wang, X., and Tang, X. Deep learning face attributes in the wild. In ICCV, 2015.

Loaiza-Ganem, G. and Cunningham, J. P. The continuous Bernoulli: fixing a pervasive error in variational autoencoders. In NeurIPS, 2019.

Marin, J.-M., Pudlo, P., Robert, C. P., and Ryder, R. J. Approximate Bayesian computational methods. Statistics and Computing, 2012.

Mathieu, E., Le Lan, C., Maddison, C. J., Tomioka, R., and Teh, Y. W. Continuous hierarchical representations with Poincaré variational auto-encoders. In NeurIPS, 2019.

Mescheder, L. M., Geiger, A., and Nowozin, S. Which training methods for GANs do actually converge? In ICML, 2018.

Nagano, Y., Yamaguchi, S., Fujita, Y., and Koyama, M. A wrapped normal distribution on hyperbolic space for gradient-based learning. In ICML, 2019.

Nakagome, S., Fukumizu, K., and Mano, S. Kernel approximate Bayesian computation in population genetic inferences. Statistical Applications in Genetics and Molecular Biology, 2013.

Newey, K. and McFadden, D. Large sample estimation and hypothesis. Handbook of Econometrics, IV, edited by R. F. Engle and D. L. McFadden, 1994.

Park, M., Jitkrittum, W., and Sejdinovic, D. K2-ABC: Approximate Bayesian Computation with kernel embeddings. In AISTATS, 2016.

Pospischil, M., Toledo-Rodriguez, M., Monier, C., Piwkowska, Z., Bal, T., Frégnac, Y., Markram, H., and Destexhe, A. Minimal Hodgkin-Huxley type models for different classes of cortical and thalamic neurons. Biological Cybernetics, 2008.

Rainforth, T., Kosiorek, A. R., Le, T. A., Maddison, C. J., Igl, M., Wood, F., and Teh, Y. W. Tighter variational bounds are not necessarily better. In ICML, 2018.

Rezende, D. and Mohamed, S. Variational inference with normalizing flows. In ICML, 2015.

Rezende, D. J., Mohamed, S., and Wierstra, D. Stochastic backpropagation and approximate inference in deep generative models. In ICML, pp. 1278-1286, 2014.

Rolfe, J. T. Discrete variational autoencoders. In ICLR, 2017.

Ruiz, F. J. R., Titsias, M. K., and Blei, D. M. The generalized reparameterization gradient. In NeurIPS, 2016.

Shi, J., Sun, S., and Zhu, J. Kernel implicit variational inference. In ICLR, 2018.

Shimodaira, H. Improving predictive inference under covariate shift by weighting the log-likelihood function. Journal of Statistical Planning and Inference, 2000.

Song, L., Huang, J., Smola, A., and Fukumizu, K. Hilbert space embeddings of conditional distributions with applications to dynamical systems. In ICML, 2009.

Sunnåker, M., Busetto, A. G., Numminen, E., Corander, J., Foll, M., and Dessimoz, C. Approximate Bayesian Computation. PLoS CB, 2013.

Szabó, Z., Sriperumbudur, B. K., Póczos, B., and Gretton, A. Learning theory for distribution regression. Journal of Machine Learning Research, 2016.

Tavaré, S., Balding, D. J., Griffiths, R. C., and Donnelly, P. Inferring coalescence times from DNA sequence data. Genetics, 1997.
Turner, R. and Sahani, M. Two problems with variational expectation maximisation for time-series models. Bayesian Time Series Models, 2011.

Vahdat, A., Macready, W. G., Bian, Z., Khoshaman, A., and Andriyash, E. DVAE++: Discrete variational autoencoders with overlapping transformations. In ICML, 2018.

van den Berg, R., Hasenclever, L., Tomczak, J. M., and Welling, M. Sylvester normalizing flows for variational inference. In Proceedings of the Thirty-Fourth Conference on Uncertainty in Artificial Intelligence, UAI 2018, Monterey, California, USA, August 6-10, 2018, pp. 393-402, 2018.

Vértes, E. and Sahani, M. Flexible and accurate inference and learning for deep generative models. In NeurIPS, pp. 4166-4175, 2018.

Wenliang, L., Sutherland, D. J., Strathmann, H., and Gretton, A. Learning deep kernels for exponential family densities. In ICML, 2019.

Wenliang, L. K. and Sahani, M. A neurally plausible model for online recognition and postdiction in a dynamical environment. In NeurIPS, 2019.

Wilson, A. G., Hu, Z., Salakhutdinov, R., and Xing, E. P. Deep kernel learning. In AISTATS, 2016.

Wood, S. N. Statistical inference for noisy nonlinear ecological dynamic systems. Nature, 2010.

Xiao, H., Rasul, K., and Vollgraf, R. Fashion-MNIST: a novel image dataset for benchmarking machine learning algorithms. arXiv preprint arXiv:1708.07747, 2017.

Xu, J. and Durrett, G. Spherical latent spaces for stable variational autoencoders. In EMNLP, 2018.

Yin, M. and Zhou, M. Semi-implicit variational inference. In ICML, 2018.

Zhang, C., Butepage, J., Kjellstrom, H., and Mandt, S. Advances in variational inference. IEEE Transactions on Pattern Analysis and Machine Intelligence, 2018.