# Wasserstein Auto-Encoders

Published as a conference paper at ICLR 2018

Ilya Tolstikhin (MPI for Intelligent Systems, Tübingen, Germany, ilya@tue.mpg.de) · Olivier Bousquet (Google Brain, Zürich, Switzerland, obousquet@google.com) · Sylvain Gelly (Google Brain, Zürich, Switzerland, sylvaingelly@google.com) · Bernhard Schölkopf (MPI for Intelligent Systems, Tübingen, Germany, bs@tue.mpg.de)

## Abstract

We propose the Wasserstein Auto-Encoder (WAE), a new algorithm for building a generative model of the data distribution. WAE minimizes a penalized form of the Wasserstein distance between the model distribution and the target distribution, which leads to a different regularizer than the one used by the Variational Auto-Encoder (VAE) (Kingma & Welling, 2014). This regularizer encourages the encoded training distribution to match the prior. We compare our algorithm with several other techniques and show that it is a generalization of adversarial auto-encoders (AAE) (Makhzani et al., 2016). Our experiments show that WAE shares many of the properties of VAEs (stable training, encoder-decoder architecture, nice latent manifold structure) while generating samples of better quality, as measured by the FID score.

## 1 Introduction

The field of representation learning was initially driven by supervised approaches, with impressive results using large labelled datasets. Unsupervised generative modeling, in contrast, used to be a domain governed by probabilistic approaches focusing on low-dimensional data. Recent years have seen a convergence of those two approaches. In the new field that formed at the intersection, variational auto-encoders (VAEs) (Kingma & Welling, 2014) constitute one well-established approach, theoretically elegant yet with the drawback that they tend to generate blurry samples when applied to natural images. In contrast, generative adversarial networks (GANs) (Goodfellow et al., 2014) turned out to be more impressive in terms of the visual quality of images sampled from the model, but come without an encoder, have been reported to be harder to train, and suffer from the mode collapse problem, where the resulting model is unable to capture all the variability in the true data distribution. There has been a flurry of activity in assaying numerous configurations of GANs as well as combinations of VAEs and GANs. A unifying framework combining the best of GANs and VAEs in a principled way is yet to be discovered.

This work builds up on the theoretical analysis presented in Bousquet et al. (2017). Following Arjovsky et al. (2017) and Bousquet et al. (2017), we approach generative modeling from the optimal transport (OT) point of view. The OT cost (Villani, 2003) is a way to measure a distance between probability distributions and provides a much weaker topology than many others, including f-divergences associated with the original GAN algorithms (Nowozin et al., 2016). This is particularly important in applications, where data is usually supported on low-dimensional manifolds in the input space X. As a result, stronger notions of distance (such as f-divergences, which capture the density ratio between distributions) often max out, providing no useful gradients for training. In contrast, OT was claimed to have nicer behaviour (Arjovsky et al., 2017; Gulrajani et al., 2017), although it requires, in its GAN-like implementation, the addition of a constraint or a regularization term to the objective.
Figure 1: Both VAE and WAE minimize two terms: the reconstruction cost and the regularizer penalizing the discrepancy between P_Z and the distribution induced by the encoder Q. VAE forces Q(Z|X = x) to match P_Z for all the different input examples x drawn from P_X. This is illustrated in picture (a), where every single red ball is forced to match P_Z, depicted as the white shape. Red balls start intersecting, which leads to problems with reconstruction. In contrast, WAE forces the continuous mixture $Q_Z := \int Q(Z|X)\, dP_X$ to match P_Z, as depicted with the green ball in picture (b). As a result, latent codes of different examples get a chance to stay far away from each other, promoting better reconstruction.

In this work we aim at minimizing the OT cost W_c(P_X, P_G) between the true (but unknown) data distribution P_X and a latent variable model P_G specified by the prior distribution P_Z of latent codes Z ∈ Z and the generative model P_G(X|Z) of the data points X ∈ X given Z. Our main contributions are listed below (cf. also Figure 1):

- A new family of regularized auto-encoders (Algorithms 1, 2 and Eq. 4), which we call Wasserstein Auto-Encoders (WAE), that minimize the optimal transport W_c(P_X, P_G) for any cost function c. Similarly to VAE, the objective of WAE is composed of two terms: the c-reconstruction cost and a regularizer D_Z(P_Z, Q_Z) penalizing a discrepancy between two distributions in Z: P_Z and the distribution of encoded data points, i.e. $Q_Z := \mathbb{E}_{P_X}[Q(Z|X)]$. When c is the squared cost and D_Z is the GAN objective, WAE coincides with the adversarial auto-encoders of Makhzani et al. (2016).
- Empirical evaluation of WAE on MNIST and CelebA datasets with squared cost $c(x, y) = \|x - y\|_2^2$. Our experiments show that WAE keeps the good properties of VAEs (stable training, encoder-decoder architecture, and a nice latent manifold structure) while generating samples of better quality, approaching those of GANs.
- We propose and examine two different regularizers D_Z(P_Z, Q_Z). One is based on GANs and adversarial training in the latent space Z. The other uses the maximum mean discrepancy, which is known to perform well when matching high-dimensional standard normal distributions P_Z (Gretton et al., 2012). Importantly, the second option leads to a fully adversary-free min-min optimization problem.
- Finally, the theoretical considerations presented in Bousquet et al. (2017) and used here to derive the WAE objective might be interesting in their own right. In particular, Theorem 1 shows that in the case of generative models, the primal form of W_c(P_X, P_G) is equivalent to a problem involving the optimization of a probabilistic encoder Q(Z|X).

The paper is structured as follows. In Section 2 we review a novel auto-encoder formulation for OT between P_X and the latent variable model P_G derived in Bousquet et al. (2017). Relaxing the resulting constrained optimization problem we arrive at an objective of Wasserstein auto-encoders. We propose two different regularizers, leading to the WAE-GAN and WAE-MMD algorithms. Section 3 discusses the related work. We present the experimental results in Section 4 and conclude by pointing out some promising directions for future work.

## 2 Proposed Method

Our new method minimizes the optimal transport cost W_c(P_X, P_G) based on the novel auto-encoder formulation (see Theorem 1 below).
In the resulting optimization problem the decoder tries to accurately reconstruct the encoded training examples as measured by the cost function c. The encoder tries to simultaneously achieve two conflicting goals: it tries to match the encoded distribution of training examples $Q_Z := \mathbb{E}_{P_X}[Q(Z|X)]$ to the prior P_Z as measured by any specified divergence D_Z(Q_Z, P_Z), while making sure that the latent codes provided to the decoder are informative enough to reconstruct the encoded training examples. This is schematically depicted in Fig. 1.

### 2.1 Preliminaries and Notations

We use calligraphic letters (i.e. X) for sets, capital letters (i.e. X) for random variables, and lower case letters (i.e. x) for their values. We denote probability distributions with capital letters (i.e. P(X)) and corresponding densities with lower case letters (i.e. p(x)). In this work we will consider several measures of discrepancy between probability distributions P_X and P_G. The class of f-divergences (Liese & Miescke, 2008) is defined by

$$D_f(P_X \| P_G) := \int f\!\left(\frac{p_X(x)}{p_G(x)}\right) p_G(x)\, dx,$$

where $f: (0, \infty) \to \mathbb{R}$ is any convex function satisfying f(1) = 0. Classical examples include the Kullback-Leibler D_KL and Jensen-Shannon D_JS divergences.

### 2.2 Optimal Transport and Its Dual Formulations

A rich class of divergences between probability distributions is induced by the optimal transport (OT) problem (Villani, 2003). Kantorovich's formulation of the problem is given by

$$W_c(P_X, P_G) := \inf_{\Gamma \in \mathcal{P}(X\sim P_X,\, Y\sim P_G)} \mathbb{E}_{(X,Y)\sim\Gamma}\big[c(X, Y)\big], \qquad (1)$$

where $c(x, y)\colon \mathcal{X} \times \mathcal{X} \to \mathbb{R}_+$ is any measurable cost function and $\mathcal{P}(X\sim P_X, Y\sim P_G)$ is the set of all joint distributions of (X, Y) with marginals P_X and P_G respectively. A particularly interesting case is when (X, d) is a metric space and $c(x, y) = d^p(x, y)$ for p ≥ 1. In this case W_p, the p-th root of W_c, is called the p-Wasserstein distance. When c(x, y) = d(x, y) the following Kantorovich-Rubinstein duality holds:

$$W_1(P_X, P_G) = \sup_{f \in \mathcal{F}_L} \mathbb{E}_{X\sim P_X}[f(X)] - \mathbb{E}_{Y\sim P_G}[f(Y)], \qquad (2)$$

where F_L is the class of all bounded 1-Lipschitz functions on (X, d). (Note that the same symbol is used for W_p and W_c; since only p is a number, the W_1 above refers to the 1-Wasserstein distance.)

### 2.3 Application to Generative Models: Wasserstein Auto-Encoders

One way to look at modern generative models like VAEs and GANs is to postulate that they are trying to minimize certain discrepancy measures between the data distribution P_X and the model P_G. Unfortunately, most of the standard divergences known in the literature, including those listed above, are hard or even impossible to compute, especially when P_X is unknown and P_G is parametrized by deep neural networks. Previous research provides several tricks to address this issue.

In the case of minimizing the KL-divergence D_KL(P_X, P_G), or equivalently maximizing the marginal log-likelihood $\mathbb{E}_{P_X}[\log p_G(X)]$, the famous variational lower bound provides a theoretically grounded framework successfully employed by VAEs (Kingma & Welling, 2014; Mescheder et al., 2017). More generally, if the goal is to minimize the f-divergence D_f(P_X, P_G) (with one example being D_KL), one can resort to its dual formulation and make use of f-GANs and adversarial training (Nowozin et al., 2016). Finally, the OT cost W_c(P_X, P_G) is yet another option, which can be, thanks to the celebrated Kantorovich-Rubinstein duality (2), expressed as an adversarial objective as implemented by the Wasserstein-GAN (Arjovsky et al., 2017). We include an extended review of all these methods in Supplementary A.
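As a small numerical aside (not part of the original paper): in one dimension with cost c(x, y) = |x − y| and two empirical distributions over the same number of samples, the infimum in (1) is attained by matching sorted samples, so W_1 can be computed exactly without any optimization. A minimal NumPy sketch under these assumptions:

```python
import numpy as np

def empirical_w1_1d(x, y):
    """1-Wasserstein distance between two 1-D empirical distributions with
    equally many samples: the optimal coupling matches sorted samples, so
    W_1 reduces to the mean absolute difference of the sorted values."""
    x, y = np.sort(np.asarray(x)), np.sort(np.asarray(y))
    assert x.shape == y.shape
    return np.mean(np.abs(x - y))

rng = np.random.default_rng(0)
p_samples = rng.normal(loc=0.0, scale=1.0, size=1000)  # stand-in for samples from P_X
q_samples = rng.normal(loc=2.0, scale=1.0, size=1000)  # stand-in for samples from P_G
print(empirical_w1_1d(p_samples, q_samples))  # close to 2: W_1 between N(0,1) and N(2,1)
```

In higher dimensions no such shortcut exists, which is exactly why the dual (2) and the primal reformulation developed below become relevant.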
In this work we will focus on latent variable models P_G defined by a two-step procedure, where first a code Z is sampled from a fixed distribution P_Z on a latent space Z and then Z is mapped to the image X ∈ X = R^d with a (possibly random) transformation. This results in a density of the form

$$p_G(x) := \int_{\mathcal{Z}} p_G(x|z)\, p_Z(z)\, dz, \qquad \forall x \in \mathcal{X}, \qquad (3)$$

assuming all involved densities are properly defined. For simplicity we will focus on non-random decoders, i.e. generative models P_G(X|Z) deterministically mapping Z to X = G(Z) for a given map G: Z → X. Similar results for random decoders can be found in Supplementary B.1.

It turns out that under this model, the OT cost takes a simpler form as the transportation plan factors through the map G: instead of finding a coupling Γ in (1) between two random variables living in the X space, one distributed according to P_X and the other one according to P_G, it is sufficient to find a conditional distribution Q(Z|X) such that its Z marginal $Q_Z(Z) := \mathbb{E}_{X\sim P_X}[Q(Z|X)]$ is identical to the prior distribution P_Z. This is the content of the theorem below, proved in Bousquet et al. (2017). To make this paper self-contained we repeat the proof in Supplementary B.

Theorem 1. For P_G as defined above with deterministic P_G(X|Z) and any function G: Z → X,

$$\inf_{\Gamma \in \mathcal{P}(X\sim P_X,\, Y\sim P_G)} \mathbb{E}_{(X,Y)\sim\Gamma}\big[c(X, Y)\big] \;=\; \inf_{Q\colon Q_Z = P_Z} \mathbb{E}_{P_X}\,\mathbb{E}_{Q(Z|X)}\big[c\big(X, G(Z)\big)\big],$$

where Q_Z is the marginal distribution of Z when X ∼ P_X and Z ∼ Q(Z|X).

This result allows us to optimize over random encoders Q(Z|X) instead of optimizing over all couplings between X and Y. Of course, both problems are still constrained. In order to implement a numerical solution we relax the constraints on Q_Z by adding a penalty to the objective. This finally leads us to the WAE objective:

$$D_{\mathrm{WAE}}(P_X, P_G) := \inf_{Q(Z|X)\in\mathcal{Q}} \mathbb{E}_{P_X}\,\mathbb{E}_{Q(Z|X)}\big[c\big(X, G(Z)\big)\big] + \lambda\, D_Z(Q_Z, P_Z), \qquad (4)$$

where Q is any nonparametric set of probabilistic encoders, D_Z is an arbitrary divergence between Q_Z and P_Z, and λ > 0 is a hyperparameter. Similarly to VAE, we propose to use deep neural networks to parametrize both encoders Q and decoders G. Note that, as opposed to VAEs, the WAE formulation allows for non-random encoders deterministically mapping inputs to their latent codes. We propose two different penalties D_Z(Q_Z, P_Z):

GAN-based D_Z. The first option is to choose D_Z(Q_Z, P_Z) = D_JS(Q_Z, P_Z) and use adversarial training to estimate it. Specifically, we introduce an adversary (discriminator) in the latent space Z trying to separate[2] "true" points sampled from P_Z and "fake" ones sampled from Q_Z (Goodfellow et al., 2014). This results in the WAE-GAN described in Algorithm 1. Even though WAE-GAN falls back to the min-max problem, we move the adversary from the input (pixel) space X to the latent space Z. On top of that, P_Z may have a nice shape with a single mode (for a Gaussian prior), in which case the task should be easier than matching an unknown, complex, and possibly multi-modal distribution as usually done in GANs. This is also a reason for our second penalty:

MMD-based D_Z. For a positive-definite reproducing kernel k: Z × Z → R the following expression is called the maximum mean discrepancy (MMD):

$$\mathrm{MMD}_k(P_Z, Q_Z) = \Big\| \int_{\mathcal{Z}} k(z, \cdot)\, dP_Z(z) - \int_{\mathcal{Z}} k(z, \cdot)\, dQ_Z(z) \Big\|_{\mathcal{H}_k},$$

where H_k is the RKHS of real-valued functions mapping Z to R. If k is characteristic then MMD_k defines a metric and can be used as a divergence measure. We propose to use D_Z(P_Z, Q_Z) = MMD_k(P_Z, Q_Z).
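To make this penalty concrete, the following sketch (not taken from the authors' code) estimates the squared MMD between a sample from P_Z and a batch of encoded points with the unbiased U-statistic that also appears in Algorithm 2 below; the RBF kernel is only a placeholder here, since the kernel actually used in the experiments is discussed in Section 4.

```python
import numpy as np

def rbf_kernel(a, b, sigma2=1.0):
    """k(a, b) = exp(-||a - b||^2 / sigma2), evaluated for all pairs of rows."""
    sq_dists = np.sum((a[:, None, :] - b[None, :, :]) ** 2, axis=-1)
    return np.exp(-sq_dists / sigma2)

def mmd2_unbiased(z_prior, z_encoded, kernel=rbf_kernel):
    """Unbiased U-statistic estimate of MMD_k^2(P_Z, Q_Z) from two samples of
    equal size n: off-diagonal within-sample averages plus a full cross term."""
    n = z_prior.shape[0]
    k_pp = kernel(z_prior, z_prior)
    k_qq = kernel(z_encoded, z_encoded)
    k_pq = kernel(z_prior, z_encoded)
    # drop the diagonal in the within-sample sums, hence the 1/(n(n-1)) factor
    term_pp = (k_pp.sum() - np.trace(k_pp)) / (n * (n - 1))
    term_qq = (k_qq.sum() - np.trace(k_qq)) / (n * (n - 1))
    term_pq = k_pq.sum() / (n * n)
    return term_pp + term_qq - 2.0 * term_pq

rng = np.random.default_rng(0)
z_prior = rng.normal(size=(200, 8))          # samples from P_Z = N(0, I)
z_encoded = rng.normal(size=(200, 8)) + 0.5  # hypothetical encoded points from Q_Z
print(mmd2_unbiased(z_prior, z_encoded))     # positive when Q_Z deviates from P_Z
```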
Fortunately, MMD has an unbiased U-statistic estimator, which can be used in conjunction with stochastic gradient descent (SGD) methods. This results in the WAE-MMD described in Algorithm 2. It is well known that the maximum mean discrepancy performs well when matching high-dimensional standard normal distributions (Gretton et al., 2012), so we expect this penalty to work especially well with a Gaussian prior P_Z.

[2] We noticed that the famous "log trick" (also called the "non-saturating loss") proposed by Goodfellow et al. (2014) leads to better results.

Algorithm 1: Wasserstein Auto-Encoder with GAN-based penalty (WAE-GAN).

Require: regularization coefficient λ > 0.
Initialize the parameters of the encoder Q_φ, decoder G_θ, and latent discriminator D_γ.
While (φ, θ) not converged:
Sample {x_1, ..., x_n} from the training set.
Sample {z_1, ..., z_n} from the prior P_Z.
Sample z̃_i from Q_φ(Z|x_i) for i = 1, ..., n.
Update D_γ by ascending
$$\frac{\lambda}{n}\sum_{i=1}^{n} \log D_\gamma(z_i) + \log\big(1 - D_\gamma(\tilde z_i)\big).$$
Update Q_φ and G_θ by descending
$$\frac{1}{n}\sum_{i=1}^{n} c\big(x_i, G_\theta(\tilde z_i)\big) - \lambda \log D_\gamma(\tilde z_i).$$

Algorithm 2: Wasserstein Auto-Encoder with MMD-based penalty (WAE-MMD).

Require: regularization coefficient λ > 0, characteristic positive-definite kernel k.
Initialize the parameters of the encoder Q_φ and decoder G_θ.
While (φ, θ) not converged:
Sample {x_1, ..., x_n} from the training set.
Sample {z_1, ..., z_n} from the prior P_Z.
Sample z̃_i from Q_φ(Z|x_i) for i = 1, ..., n.
Update Q_φ and G_θ by descending
$$\frac{1}{n}\sum_{i=1}^{n} c\big(x_i, G_\theta(\tilde z_i)\big) + \frac{\lambda}{n(n-1)}\sum_{\ell \neq j} k(z_\ell, z_j) + \frac{\lambda}{n(n-1)}\sum_{\ell \neq j} k(\tilde z_\ell, \tilde z_j) - \frac{2\lambda}{n^2}\sum_{\ell, j} k(z_\ell, \tilde z_j).$$

We point out once again that the encoders Q_φ(Z|x) in Algorithms 1 and 2 can be non-random, i.e. deterministically mapping input points to the latent codes. In this case $Q_\phi(Z|x) = \delta_{\mu_\phi(x)}$ for a function µ_φ: X → Z, and in order to sample z̃_i from Q_φ(Z|x_i) we just need to return µ_φ(x_i).

## 3 Related Work

Literature on auto-encoders. Classical unregularized auto-encoders minimize only the reconstruction cost. This results in different training points being encoded into non-overlapping zones chaotically scattered all across the Z space, with holes in between where the decoder mapping P_G(X|Z) has never been trained. Overall, the encoder Q(Z|X) trained in this way does not provide a useful representation, and sampling from the latent space Z becomes hard (Bengio et al., 2013).

Variational auto-encoders (Kingma & Welling, 2014) minimize a variational bound on the KL-divergence D_KL(P_X, P_G), which is composed of the reconstruction cost plus the regularizer $\mathbb{E}_{P_X}[D_{\mathrm{KL}}(Q(Z|X), P_Z)]$. The regularizer captures how distinct the image by the encoder of each training example is from the prior P_Z, which does not guarantee that the overall encoded distribution $\mathbb{E}_{P_X}[Q(Z|X)]$ matches P_Z like WAE does. Also, VAEs require non-degenerate (i.e. non-deterministic) Gaussian encoders and random decoders for which the term log p_G(x|z) can be computed and differentiated with respect to the parameters. Later, Mescheder et al. (2017) proposed a way to use VAE with non-Gaussian encoders. WAE minimizes the optimal transport W_c(P_X, P_G) and allows both probabilistic and deterministic encoder-decoder pairs of any kind. The VAE regularizer can also be equivalently written (Hoffman & Johnson, 2016) as a sum of D_KL(Q_Z, P_Z) and a mutual information I_Q(X, Z) between the images X and latent codes Z jointly distributed according to P_X × Q(Z|X).
This observation provides another intuitive way to explain the difference between our algorithm and VAEs: WAEs simply drop the mutual information term I_Q(X, Z) in the VAE regularizer.

When used with $c(x, y) = \|x - y\|_2^2$, WAE-GAN is equivalent to the adversarial auto-encoders (AAE) proposed by Makhzani et al. (2016). The theory of Bousquet et al. (2017) (and in particular Theorem 1) thus suggests that AAEs minimize the 2-Wasserstein distance between P_X and P_G. This provides the first theoretical justification for AAEs known to the authors. WAE generalizes AAE in two ways: first, it can use any cost function c in the input space X; second, it can use any discrepancy measure D_Z in the latent space Z (for instance MMD), not necessarily the adversarial one of WAE-GAN.

Finally, Zhao et al. (2017b) independently proposed a regularized auto-encoder objective similar to Bousquet et al. (2017) and our (4), based on very different motivations and arguments. Following VAEs, their objective (called InfoVAE) defines the reconstruction cost in the image space implicitly through the negative log-likelihood term −log p_G(x|z), which should be properly normalized for all z ∈ Z. In theory VAE and InfoVAE can both induce arbitrary cost functions; however, in practice this may require an estimation of the normalizing constant (partition function), which can[3] be different for different values of z. WAEs specify the cost c(x, y) explicitly and don't constrain it in any way.

Literature on OT. Genevay et al. (2016) address computing the OT cost in large scale using SGD and sampling. They approach this task either through the dual formulation, or via a regularized version of the primal. They do not discuss any implications for generative modeling. Our approach is based on the primal form of OT, we arrive at regularizers which are very different, and our main focus is on generative modeling.

The WGAN (Arjovsky et al., 2017) minimizes the 1-Wasserstein distance W_1(P_X, P_G) for generative modeling. The authors approach this task from the dual form. Their algorithm comes without an encoder and cannot be readily applied to any other cost W_c, because the neat form of the Kantorovich-Rubinstein duality (2) holds only for W_1. WAE approaches the same problem from the primal form, can be applied to any cost function c, and comes naturally with an encoder.

In order to compute the values (1) or (2) of OT we need to handle non-trivial constraints, either on the coupling distribution Γ or on the function f being considered. Various approaches have been proposed in the literature to circumvent this difficulty. For W_1, Arjovsky et al. (2017) tried to implement the constraint in the dual formulation (2) by clipping the weights of the neural network f. Later, Gulrajani et al. (2017) proposed to relax the same constraint by penalizing the objective of (2) with the term $\lambda\,\mathbb{E}\big(\|\nabla f(X)\| - 1\big)^2$, since the gradient norm should not be greater than 1 if f ∈ F_L. In the more general OT setting of W_c, Cuturi (2013) proposed to penalize the objective of (1) with the KL-divergence $\lambda\, D_{\mathrm{KL}}(\Gamma, P \otimes Q)$ between the coupling distribution and the product of marginals. Genevay et al. (2016) showed that this entropic regularization drops the constraints on functions in the dual formulation, as opposed to (2). Finally, in the context of unbalanced optimal transport it has been proposed to relax the constraint in (1) by regularizing the objective with $\lambda\big(D_f(\Gamma_X, P) + D_f(\Gamma_Y, Q)\big)$ (Chizat et al., 2015; Liero et al., 2015), where Γ_X and Γ_Y are the marginals of Γ.
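As a didactic illustration of the entropic relaxation mentioned above (Cuturi, 2013) — not part of the WAE method — the following NumPy sketch runs plain Sinkhorn iterations for two small discrete distributions; log-domain stabilization is omitted for brevity.

```python
import numpy as np

def sinkhorn(a, b, M, eps=0.1, n_iters=500):
    """Entropy-regularized OT: approximately minimize <Gamma, M> minus eps times
    the entropy of Gamma, subject to row marginals a and column marginals b,
    via alternating Sinkhorn scalings of the Gibbs kernel."""
    K = np.exp(-M / eps)                 # Gibbs kernel
    u = np.ones_like(a)
    for _ in range(n_iters):
        v = b / (K.T @ u)                # rescale to match column marginals
        u = a / (K @ v)                  # rescale to match row marginals
    gamma = u[:, None] * K * v[None, :]  # approximate optimal coupling
    return gamma, np.sum(gamma * M)      # coupling and its transport cost

# toy example: two discrete distributions on [0, 1] with squared cost
x = np.linspace(0.0, 1.0, 50)
a = np.ones(50) / 50                               # uniform source weights
b = np.exp(-(x - 0.7) ** 2 / 0.02); b /= b.sum()   # peaked target weights
M = (x[:, None] - x[None, :]) ** 2                 # cost matrix c(x, y) = (x - y)^2
gamma, cost = sinkhorn(a, b, M)
print(cost)
```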
In this paper we propose to relax OT in a way similar to the unbalanced optimal transport, i.e. by adding additional divergences to the objective. However, we show that in the particular context of generative modeling, only one extra divergence is necessary.

Literature on GANs. Many of the GAN variations (including f-GAN and WGAN) come without an encoder. Often it may be desirable to reconstruct the latent codes and use the learned manifold, in which cases these models are not applicable.

There have been many other approaches trying to blend the adversarial training of GANs with auto-encoder architectures (Zhao et al., 2017a; Dumoulin et al., 2017; Ulyanov et al., 2017; Berthelot et al., 2017). The approach proposed by Ulyanov et al. (2017) is perhaps the most relevant to our work. The authors use the discrepancy between Q_Z and the distribution $\mathbb{E}_{Z'\sim P_Z}\big[Q\big(Z \mid G(Z')\big)\big]$ of auto-encoded noise vectors as the objective for the max-min game between the encoder and decoder respectively. While the authors showed that the saddle points correspond to P_X = P_G, they admit that encoders and decoders trained in this way have no incentive to be reciprocal. As a workaround they propose to include an additional reconstruction term in the objective. WAE does not necessarily lead to a min-max game, uses a different penalty, and has a clear theoretical foundation.

Several works used reproducing kernels in the context of GANs. Li et al. (2015) and Dziugaite et al. (2015) use MMD with a fixed kernel k to match P_X and P_G directly in the input space X. These methods have been criticised for requiring larger mini-batches during training: estimating MMD_k(P_X, P_G) requires a number of samples roughly proportional to the dimensionality of the input space X (Reddi et al., 2015), which is typically larger than 10^3. Li et al. (2017) take a similar approach but further train k adversarially so as to arrive at a meaningful loss function. WAE-MMD uses MMD to match Q_Z to the prior P_Z in the latent space Z. Typically Z has no more than 100 dimensions and P_Z is Gaussian, which allows us to use regular mini-batch sizes to accurately estimate MMD.

[3] Two popular choices are Gaussian and Bernoulli decoders P_G(X|Z), leading to pixel-wise squared and cross-entropy losses respectively. In both cases the normalizing constants can be computed in closed form and don't depend on z.

Figure 2: VAE (left column), WAE-MMD (middle column), and WAE-GAN (right column) trained on the MNIST dataset. Panels show test interpolations, test reconstructions, and random samples; in test reconstructions, odd rows correspond to the real test points.

## 4 Experiments

In this section we empirically evaluate[4] the proposed WAE model. We would like to test if WAE can simultaneously achieve (i) accurate reconstructions of data points, (ii) reasonable geometry of the latent manifold, and (iii) random samples of good (visual) quality. Importantly, the model should generalize well: requirements (i) and (ii) should be met on both training and test data. We trained WAE-GAN and WAE-MMD (Algorithms 1 and 2) on two real-world datasets: MNIST (LeCun et al., 1998), consisting of 70k images, and CelebA (Liu et al., 2015), containing roughly 203k images.

Experimental setup. In all reported experiments we used Euclidean latent spaces Z = R^{d_z} for various d_z depending on the complexity of the dataset, isotropic Gaussian prior distributions $P_Z(Z) = \mathcal{N}(Z; 0, \sigma_z^2 \cdot I_d)$ over Z, and a squared cost function $c(x, y) = \|x - y\|_2^2$ for data points x, y ∈ X = R^{d_x}.
We used deterministic encoder-decoder pairs, Adam (Kingma & Ba, 2014) with β_1 = 0.5, β_2 = 0.999, and convolutional deep neural network architectures for the encoder mapping µ_φ: X → Z and decoder mapping G_θ: Z → X similar to the DCGAN ones reported by Radford et al. (2016), with batch normalization (Ioffe & Szegedy, 2015). We tried various values of λ and noticed that λ = 10 seems to work well across all datasets we considered.

Since we are using deterministic encoders, choosing d_z larger than the intrinsic dimensionality of the dataset would force the encoded distribution Q_Z to live on a manifold in Z. This would make matching Q_Z to P_Z impossible if P_Z is Gaussian and may lead to numerical instabilities. We use d_z = 8 for MNIST and d_z = 64 for CelebA, which seems to work reasonably well.

[4] The code is available at github.com/tolstikhin/wae.

Figure 3: VAE (left column), WAE-MMD (middle column), and WAE-GAN (right column) trained on the CelebA dataset. Panels show test interpolations, test reconstructions, and random samples; in test reconstructions, odd rows correspond to the real test points.

We also report results of VAEs. VAEs used the same latent spaces as discussed above and standard Gaussian priors P_Z = N(0, I_d). We used Gaussian encoders $Q(Z|X) = \mathcal{N}\big(Z; \mu_\phi(X), \Sigma(X)\big)$ with mean µ_φ and diagonal covariance Σ. For both MNIST and CelebA we used Bernoulli decoders parametrized by G_θ. The functions µ_φ, Σ, and G_θ were parametrized by deep nets of the same architectures as used in WAE.

WAE-GAN and WAE-MMD specifics. In WAE-GAN we used a discriminator D composed of several fully connected layers with ReLU. We tried WAE-MMD with the RBF kernel but observed that it fails to penalize the outliers of Q_Z because of the quick tail decay. If the codes z̃ = µ_φ(x) for some of the training points x ∈ X end up far away from the support of P_Z (which may happen in the early stages of training), the corresponding terms in the U-statistic $k(z, \tilde z) = e^{-\|z - \tilde z\|_2^2 / \sigma_k^2}$ will quickly approach zero and provide no gradient for those outliers. This could be avoided by choosing the kernel bandwidth σ²_k in a data-dependent manner; however, in this case the per-minibatch U-statistic would not provide an unbiased estimate of the gradient. Instead, we used the inverse multiquadratics kernel $k(x, y) = C / (C + \|x - y\|_2^2)$, which is also characteristic and has much heavier tails. In all experiments we used C = 2 d_z σ²_z, which is the expected squared distance between two multivariate Gaussian vectors drawn from P_Z. This significantly improved the performance compared to the RBF kernel (even the one with σ²_k = 2 d_z σ²_z). Trained models are presented in Figures 2 and 3. Further details are presented in Supplementary C.

Random samples are generated by sampling P_Z and decoding the resulting noise vectors z into G_θ(z). As expected, in our experiments we observed that for both WAE-GAN and WAE-MMD the quality of samples strongly depends on how accurately Q_Z matches P_Z. To see this, notice that during training the decoder function G_θ is presented only with encoded versions µ_φ(X) of the data points X ∼ P_X. Indeed, the decoder is trained on samples from Q_Z and thus there is no reason to expect good results when feeding it with samples from P_Z. In our experiments we noticed that even slight differences between Q_Z and P_Z may affect the quality of samples. In some cases WAE-GAN seems to lead to a better matching and generates better samples than WAE-MMD.
However, due to adversarial training, WAE-GAN is highly unstable, while WAE-MMD has very stable training, much like VAE.

Table 1: FID scores for samples on CelebA (smaller is better).

| Algorithm | FID |
| --- | --- |
| VAE | 82 |
| WAE-MMD | 55 |
| WAE-GAN | 42 |

In order to quantitatively assess the quality of the generated images, we use the Fréchet Inception Distance introduced by Heusel et al. (2017) and report the results on CelebA in Table 1. These results confirm that the sampled images from WAE are of better quality than those from VAE, and WAE-GAN gets a slightly better score than WAE-MMD, which correlates with visual inspection of the images.

Test reconstructions and interpolations. We take random points x from the held-out test set and report their auto-encoded versions G_θ(µ_φ(x)). Next, pairs (x, y) of different data points are sampled randomly from the held-out test set and encoded: z_x = µ_φ(x), z_y = µ_φ(y). We linearly interpolate between z_x and z_y with equally-sized steps in the latent space and show decoded images.

## 5 Conclusion

Using the optimal transport cost, we have derived Wasserstein auto-encoders, a new family of algorithms for building generative models. We discussed their relations to other probabilistic modeling techniques. We conducted experiments using two particular implementations of the proposed method, showing that in comparison to VAEs, the images sampled from the trained WAE models are of better quality, without compromising the stability of training and the quality of reconstruction. Future work will include further exploration of the criteria for matching the encoded distribution Q_Z to the prior distribution P_Z, assaying the possibility of adversarially training the cost function c in the input space X, and a theoretical analysis of the dual formulations for WAE-GAN and WAE-MMD.

## Acknowledgments

The authors are thankful to Carl Johann Simon-Gabriel, Mateo Rojas-Carulla, Arthur Gretton, Paul Rubenstein, and Fei Sha for stimulating discussions.

## References

M. Arjovsky, S. Chintala, and L. Bottou. Wasserstein GAN, 2017.

Y. Bengio, A. Courville, and P. Vincent. Representation learning: A review and new perspectives. Pattern Analysis and Machine Intelligence, 35, 2013.

D. Berthelot, T. Schumm, and L. Metz. BEGAN: Boundary equilibrium generative adversarial networks, 2017.

O. Bousquet, S. Gelly, I. Tolstikhin, C. J. Simon-Gabriel, and B. Schölkopf. From optimal transport to generative modeling: the VEGAN cookbook, 2017.

L. Chizat, G. Peyré, B. Schmitzer, and F.-X. Vialard. Unbalanced optimal transport: geometry and Kantorovich formulation. arXiv preprint arXiv:1508.05216, 2015.

M. Cuturi. Sinkhorn distances: Lightspeed computation of optimal transport. In Advances in Neural Information Processing Systems, pp. 2292–2300, 2013.

V. Dumoulin, I. Belghazi, B. Poole, A. Lamb, M. Arjovsky, O. Mastropietro, and A. Courville. Adversarially learned inference. In ICLR, 2017.

G. K. Dziugaite, D. M. Roy, and Z. Ghahramani. Training generative neural networks via maximum mean discrepancy optimization. In UAI, 2015.

A. Genevay, M. Cuturi, G. Peyré, and F. R. Bach. Stochastic optimization for large-scale optimal transport. In Advances in Neural Information Processing Systems, pp. 3432–3440, 2016.

I. Goodfellow, J. Pouget-Abadie, M. Mirza, B. Xu, D. Warde-Farley, S. Ozair, A. Courville, and Y. Bengio. Generative adversarial nets. In NIPS, pp. 2672–2680, 2014.

A. Gretton, K. M. Borgwardt, M. J. Rasch, B. Schölkopf, and A. J. Smola. A kernel two-sample test. Journal of Machine Learning Research, 13:723–773, 2012.
I. Gulrajani, F. Ahmed, M. Arjovsky, V. Dumoulin, and A. Courville. Improved training of Wasserstein GANs, 2017.

M. Heusel, H. Ramsauer, T. Unterthiner, B. Nessler, G. Klambauer, and S. Hochreiter. GANs trained by a two time-scale update rule converge to a Nash equilibrium. arXiv preprint arXiv:1706.08500, 2017.

M. D. Hoffman and M. Johnson. ELBO surgery: yet another way to carve up the variational evidence lower bound. In NIPS Workshop on Advances in Approximate Bayesian Inference, 2016.

S. Ioffe and C. Szegedy. Batch normalization: Accelerating deep network training by reducing internal covariate shift, 2015.

D. P. Kingma and J. Ba. Adam: A method for stochastic optimization, 2014.

D. P. Kingma and M. Welling. Auto-encoding variational Bayes. In ICLR, 2014.

Y. LeCun, L. Bottou, Y. Bengio, and P. Haffner. Gradient-based learning applied to document recognition. In Proceedings of the IEEE, volume 86(11), pp. 2278–2324, 1998.

C. L. Li, W. C. Chang, Y. Cheng, Y. Yang, and B. Poczos. MMD GAN: Towards deeper understanding of moment matching network, 2017.

Y. Li, K. Swersky, and R. Zemel. Generative moment matching networks. In ICML, 2015.

M. Liero, A. Mielke, and G. Savaré. Optimal entropy-transport problems and a new Hellinger-Kantorovich distance between positive measures. arXiv preprint arXiv:1508.07941, 2015.

F. Liese and K.-J. Miescke. Statistical Decision Theory. Springer, 2008.

Z. Liu, P. Luo, X. Wang, and X. Tang. Deep learning face attributes in the wild. In Proceedings of the International Conference on Computer Vision (ICCV), 2015.

A. Makhzani, J. Shlens, N. Jaitly, and I. Goodfellow. Adversarial autoencoders. In ICLR, 2016.

L. Mescheder, S. Nowozin, and A. Geiger. Adversarial variational Bayes: Unifying variational autoencoders and generative adversarial networks, 2017.

S. Nowozin, B. Cseke, and R. Tomioka. f-GAN: Training generative neural samplers using variational divergence minimization. In NIPS, 2016.

B. Poole, A. Alemi, J. Sohl-Dickstein, and A. Angelova. Improved generator objectives for GANs, 2016.

A. Radford, L. Metz, and S. Chintala. Unsupervised representation learning with deep convolutional generative adversarial networks. In ICLR, 2016.

R. Reddi, A. Ramdas, A. Singh, B. Poczos, and L. Wasserman. On the high-dimensional power of a linear-time two sample test under mean-shift alternatives. In AISTATS, 2015.

D. Ulyanov, A. Vedaldi, and V. Lempitsky. It takes (only) two: Adversarial generator-encoder networks, 2017.

C. Villani. Topics in Optimal Transportation. AMS Graduate Studies in Mathematics, 2003.

J. Zhao, M. Mathieu, and Y. LeCun. Energy-based generative adversarial network. In ICLR, 2017a.

S. Zhao, J. Song, and S. Ermon. InfoVAE: Information maximizing variational autoencoders, 2017b.

## A Implicit Generative Models: A Short Tour of GANs and VAEs

Even though GANs and VAEs are quite different, both in terms of the conceptual frameworks and empirical performance, they share important features: (a) both can be trained by sampling from the model P_G without knowing an analytical form of its density, and (b) both can be scaled up with SGD.
As a result, it becomes possible to use highly flexible implicit models P_G defined by a two-step procedure, where first a code Z is sampled from a fixed distribution P_Z on a latent space Z and then Z is mapped to the image G(Z) ∈ X = R^d with a (possibly random) transformation G: Z → X. This results in latent variable models P_G of the form (3). These models are indeed easy to sample and, provided G can be differentiated analytically with respect to its parameters, P_G can be trained with SGD. The field is growing rapidly and numerous variations of VAEs and GANs are available in the literature. Next we introduce and compare several of them.

The original generative adversarial network (GAN) approach (Goodfellow et al., 2014) minimizes

$$D_{\mathrm{GAN}}(P_X, P_G) = \sup_{T \in \mathcal{T}} \mathbb{E}_{X\sim P_X}[\log T(X)] + \mathbb{E}_{Z\sim P_Z}\big[\log\big(1 - T(G(Z))\big)\big] \qquad (5)$$

with respect to a deterministic decoder G: Z → X, where T is any non-parametric class of choice. It is known that $D_{\mathrm{GAN}}(P_X, P_G) \le 2\, D_{\mathrm{JS}}(P_X, P_G) - \log 4$ and the inequality turns into an identity in the nonparametric limit, that is, when the class T becomes rich enough to represent all functions mapping X to (0, 1). Hence, GANs are minimizing a lower bound on the JS-divergence. However, GANs are not only linked to the JS-divergence: the f-GAN approach (Nowozin et al., 2016) showed that a slight modification D_{f,GAN} of the objective (5) allows one to lower bound any desired f-divergence in a similar way. In practice, both the decoder G and the discriminator T are trained in alternating SGD steps. Stopping criteria as well as adequate evaluation of the trained GAN models remain open questions.

Recently, the authors of Arjovsky et al. (2017) argued that the 1-Wasserstein distance W_1, which is known to induce a much weaker topology than D_JS, may be better suited for generative modeling. When P_X and P_G are supported on largely disjoint low-dimensional manifolds (which may be the case in applications), D_KL, D_JS, and other strong distances between P_X and P_G max out and no longer provide useful gradients for P_G. This vanishing gradient problem necessitates complicated scheduling between the G/T updates. In contrast, W_1 is still sensible in these cases and provides stable gradients. The Wasserstein GAN (WGAN) minimizes

$$D_{\mathrm{WGAN}}(P_X, P_G) = \sup_{T \in \mathcal{W}} \mathbb{E}_{X\sim P_X}[T(X)] - \mathbb{E}_{Z\sim P_Z}\big[T(G(Z))\big],$$

where W is any subset of 1-Lipschitz functions on X. It follows from (2) that $D_{\mathrm{WGAN}}(P_X, P_G) \le W_1(P_X, P_G)$ and thus WGAN is minimizing a lower bound on the 1-Wasserstein distance.

Variational auto-encoders (VAE) (Kingma & Welling, 2014) utilize models P_G of the form (3) and minimize

$$D_{\mathrm{VAE}}(P_X, P_G) = \inf_{Q(Z|X)\in\mathcal{Q}} \mathbb{E}_{P_X}\Big[D_{\mathrm{KL}}\big(Q(Z|X), P_Z\big) - \mathbb{E}_{Q(Z|X)}[\log p_G(X|Z)]\Big] \qquad (6)$$

with respect to a random decoder mapping P_G(X|Z). The conditional distribution P_G(X|Z) is often parametrized by a deep net G and can have any form as long as its density p_G(x|z) can be computed and differentiated with respect to the parameters of G. A typical choice is to use Gaussians $P_G(X|Z) = \mathcal{N}(X; G(Z), \sigma^2 \cdot I)$. If Q is the set of all conditional probability distributions Q(Z|X), the objective of VAE coincides with the negative marginal log-likelihood, $D_{\mathrm{VAE}}(P_X, P_G) = -\mathbb{E}_{P_X}[\log p_G(X)]$. However, in order to make the D_KL term of (6) tractable in closed form, the original implementation of VAE uses a standard normal P_Z and restricts Q to a class of Gaussian distributions $Q(Z|X) = \mathcal{N}\big(Z; \mu(X), \Sigma(X)\big)$ with mean µ and diagonal covariance Σ parametrized by deep nets. As a consequence, VAE is minimizing an upper bound on the negative log-likelihood or, equivalently, on the KL-divergence D_KL(P_X, P_G).
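For reference (this snippet is not from the original text), the closed form of the D_KL term in (6) for a diagonal Gaussian encoder and a standard normal prior is what makes this restricted class convenient; a minimal NumPy sketch:

```python
import numpy as np

def kl_diag_gauss_to_std_normal(mu, log_var):
    """KL( N(mu, diag(exp(log_var))) || N(0, I) ) for a batch of encoder outputs.
    Closed form: 0.5 * sum(exp(log_var) + mu^2 - 1 - log_var) per example."""
    return 0.5 * np.sum(np.exp(log_var) + mu ** 2 - 1.0 - log_var, axis=-1)

# batch of 4 hypothetical encoder outputs for an 8-dimensional latent space
rng = np.random.default_rng(0)
mu = rng.normal(size=(4, 8))
log_var = rng.normal(scale=0.1, size=(4, 8))
print(kl_diag_gauss_to_std_normal(mu, log_var))  # one non-negative value per example
```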
One possible way to reduce the gap between the true negative log-likelihood and the upper bound provided by D_VAE is to enlarge the class Q. Adversarial variational Bayes (AVB) (Mescheder et al., 2017) follows this argument by employing the idea of GANs. Given any point x ∈ X, a noise ε ∼ N(0, 1), and any fixed transformation e: X × R → Z, the random variable e(x, ε) implicitly defines one particular conditional distribution Q_e(Z|X = x). AVB allows Q to contain all such distributions for different choices of e, replaces the intractable term $D_{\mathrm{KL}}\big(Q_e(Z|X), P_Z\big)$ in (6) by the adversarial approximation D_{f,GAN} corresponding to the KL-divergence, and proposes to minimize[5]

$$D_{\mathrm{AVB}}(P_X, P_G) = \inf_{Q_e(Z|X)\in\mathcal{Q}} \mathbb{E}_{P_X}\Big[D_{f,\mathrm{GAN}}\big(Q_e(Z|X), P_Z\big) - \mathbb{E}_{Q_e(Z|X)}[\log p_G(X|Z)]\Big]. \qquad (7)$$

[5] The authors of AVB (Mescheder et al., 2017) note that using f-GAN as described above actually results in "unstable training". Instead, following the approach of Poole et al. (2016), they use a trained discriminator T resulting from the D_GAN objective (5) to approximate the ratio of densities and then directly estimate the KL divergence $\int f\big(p(x)/q(x)\big)\, q(x)\, dx$.

The D_KL term in (6) may be viewed as a regularizer. Indeed, VAE reduces to the classical unregularized auto-encoder if this term is dropped, minimizing the reconstruction cost of the encoder-decoder pair Q(Z|X), P_G(X|Z). This often results in different training points being encoded into non-overlapping zones chaotically scattered all across the Z space, with holes in between where the decoder mapping P_G(X|Z) has never been trained. Overall, the encoder Q(Z|X) trained in this way does not provide a useful representation and sampling from the latent space Z becomes hard (Bengio et al., 2013).

Adversarial auto-encoders (AAE) (Makhzani et al., 2016) replace the D_KL term in (6) with another regularizer:

$$D_{\mathrm{AAE}}(P_X, P_G) = \inf_{Q(Z|X)\in\mathcal{Q}} D_{\mathrm{GAN}}(Q_Z, P_Z) - \mathbb{E}_{P_X}\,\mathbb{E}_{Q(Z|X)}[\log p_G(X|Z)], \qquad (8)$$

where Q_Z is the marginal distribution of Z when first X is sampled from P_X and then Z is sampled from Q(Z|X), also known as the aggregated posterior (Makhzani et al., 2016). Similarly to AVB, there is no clear link to log-likelihood, as $D_{\mathrm{AAE}} \le D_{\mathrm{AVB}}$. The authors of Makhzani et al. (2016) argue that matching Q_Z to P_Z in this way ensures that there are no holes left in the latent space Z and P_G(X|Z) generates reasonable samples whenever Z ∼ P_Z. They also report an equally good performance of different types of conditional distributions Q(Z|X), including Gaussians as used in VAEs, implicit models Q_e as used in AVB, and deterministic encoder mappings, i.e. $Q(Z|X) = \delta_{\mu(X)}$ with µ: X → Z.

## B Proof of Theorem 1 and Further Details

We will consider certain sets of joint probability distributions of three random variables (X, Y, Z) ∈ X × X × Z. The reader may wish to think of X as true images, Y as images sampled from the model, and Z as latent codes. We denote by P_{G,Z}(Y, Z) the joint distribution of a variable pair (Y, Z), where Z is first sampled from P_Z and next Y from P_G(Y|Z). Note that P_G defined in (3) and used throughout this work is the marginal distribution of Y when (Y, Z) ∼ P_{G,Z}.

In the optimal transport problem (1), we consider joint distributions Γ(X, Y) which are called couplings between values of X and Y. Because of the marginal constraint, we can write Γ(X, Y) = Γ(Y|X) P_X(X) and we can consider Γ(Y|X) as a non-deterministic mapping from X to Y. Theorem 1 shows how to factor this mapping through Z, i.e., decompose it into an encoding distribution Q(Z|X) and the generating distribution P_G(Y|Z).

As in Section 2.2, P(X ∼ P_X, Y ∼ P_G) denotes the set of all joint distributions of (X, Y) with marginals P_X, P_G, and likewise for P(X ∼ P_X, Z ∼ P_Z). The set of all joint distributions of (X, Y, Z) such that X ∼ P_X, (Y, Z) ∼ P_{G,Z}, and (Y ⫫ X) | Z will be denoted by P_{X,Y,Z}.
Finally, we denote by P_{X,Y} and P_{X,Z} the sets of marginals on (X, Y) and (X, Z) (respectively) induced by distributions in P_{X,Y,Z}. Note that P(P_X, P_G), P_{X,Y,Z}, and P_{X,Y} depend on the choice of conditional distributions P_G(Y|Z), while P_{X,Z} does not. In fact, it is easy to check that P_{X,Z} = P(X ∼ P_X, Z ∼ P_Z). From the definitions it is clear that P_{X,Y} ⊆ P(P_X, P_G) and we immediately get the following upper bound:

$$W_c(P_X, P_G) \le W_c^\dagger(P_X, P_G) := \inf_{P \in \mathcal{P}_{X,Y}} \mathbb{E}_{(X,Y)\sim P}\,[c(X, Y)]. \qquad (9)$$

If P_G(Y|Z) are Dirac measures (i.e., Y = G(Z)), it turns out that P_{X,Y} = P(P_X, P_G):

Lemma 1. P_{X,Y} ⊆ P(P_X, P_G), with identity if[6] P_G(Y|Z = z) are Dirac for all z ∈ Z.

[6] We conjecture that this is also a necessary condition. The necessity is not used in the paper.

Proof. The first assertion is obvious. To prove the identity, note that when Y is a deterministic function of Z, for any A in the sigma-algebra induced by Y we have $\mathbb{E}\big[\mathbf{1}[Y \in A]\mid X, Z\big] = \mathbb{E}\big[\mathbf{1}[Y \in A]\mid Z\big]$. This implies (Y ⫫ X) | Z and concludes the proof.

We are now in place to prove Theorem 1. Lemma 1 obviously leads to $W_c(P_X, P_G) = W_c^\dagger(P_X, P_G)$. The tower rule of expectation and the conditional independence property of P_{X,Y,Z} imply

$$
\begin{aligned}
W_c^\dagger(P_X, P_G) &= \inf_{P \in \mathcal{P}_{X,Y,Z}} \mathbb{E}_{(X,Y,Z)\sim P}\,[c(X, Y)] \\
&= \inf_{P \in \mathcal{P}_{X,Y,Z}} \mathbb{E}_{P_Z}\,\mathbb{E}_{X\sim P(X|Z)}\,\mathbb{E}_{Y\sim P(Y|Z)}\,[c(X, Y)] \\
&= \inf_{P \in \mathcal{P}_{X,Y,Z}} \mathbb{E}_{P_Z}\,\mathbb{E}_{X\sim P(X|Z)}\,\big[c\big(X, G(Z)\big)\big] \\
&= \inf_{P \in \mathcal{P}_{X,Z}} \mathbb{E}_{(X,Z)\sim P}\,\big[c\big(X, G(Z)\big)\big].
\end{aligned}
$$

It remains to notice that P_{X,Z} = P(X ∼ P_X, Z ∼ P_Z), as stated earlier.

### B.1 Random Decoders P_G(Y|Z)

If the decoders are non-deterministic, Lemma 1 provides only the inclusion of sets P_{X,Y} ⊆ P(P_X, P_G) and we get the following upper bound on the OT:

Corollary 1. Let X = R^d and assume the conditional distributions P_G(Y|Z = z) have mean values G(z) ∈ R^d and marginal variances σ²_1, ..., σ²_d ≥ 0 for all z ∈ Z, where G: Z → X. Take $c(x, y) = \|x - y\|_2^2$. Then

$$W_c(P_X, P_G) \le W_c^\dagger(P_X, P_G) = \sum_{i=1}^{d} \sigma_i^2 + \inf_{P \in \mathcal{P}(X\sim P_X,\, Z\sim P_Z)} \mathbb{E}_{(X,Z)\sim P}\,\big[\|X - G(Z)\|^2\big]. \qquad (10)$$

Proof. The first inequality follows from (9). For the identity we proceed similarly to the proof of Theorem 1 and write

$$W_c^\dagger(P_X, P_G) = \inf_{P \in \mathcal{P}_{X,Y,Z}} \mathbb{E}_{P_Z}\,\mathbb{E}_{X\sim P(X|Z)}\,\mathbb{E}_{Y\sim P(Y|Z)}\,\big[\|X - Y\|^2\big]. \qquad (11)$$

Furthermore,

$$
\begin{aligned}
\mathbb{E}_{Y\sim P(Y|Z)}\big[\|X - Y\|^2\big]
&= \mathbb{E}_{Y\sim P(Y|Z)}\big[\|X - G(Z) + G(Z) - Y\|^2\big] \\
&= \|X - G(Z)\|^2 + 2\,\mathbb{E}_{Y\sim P(Y|Z)}\big[\langle X - G(Z),\, G(Z) - Y\rangle\big] + \mathbb{E}_{Y\sim P(Y|Z)}\big[\|G(Z) - Y\|^2\big] \\
&= \|X - G(Z)\|^2 + \sum_{i=1}^{d} \sigma_i^2,
\end{aligned}
$$

where the cross term vanishes because G(z) is the mean of P_G(Y|Z = z). Together with (11) and the fact that P_{X,Z} = P(X ∼ P_X, Z ∼ P_Z) this concludes the proof.

## C Further Details on Experiments

We use mini-batches of size 100 and trained the models for 100 epochs. We used λ = 10 and σ²_z = 1. For the encoder-decoder pair we set α = 10⁻³ for Adam in the beginning, and for the adversary in WAE-GAN α = 5·10⁻⁴. After 30 epochs we decreased both by a factor of 2, and after the first 50 epochs further by a factor of 5. Both encoder and decoder used fully convolutional architectures with 4×4 convolutional filters.
Encoder architecture: x ∈ R^{28×28} → Conv_128 → BN → ReLU → Conv_256 → BN → ReLU → Conv_512 → BN → ReLU → Conv_1024 → BN → ReLU → FC_8.

Decoder architecture: z ∈ R^8 → FC_{7×7×1024} → FSConv_512 → BN → ReLU → FSConv_256 → BN → ReLU → FSConv_1.

Adversary architecture for WAE-GAN: z ∈ R^8 → FC_512 → ReLU → FC_512 → ReLU → FC_512 → ReLU → FC_512 → ReLU → FC_1.

Here Conv_k stands for a convolution with k filters, FSConv_k for the fractionally strided convolution with k filters (the first two of them doubled the resolution, the third one kept it constant), BN for batch normalization, ReLU for the rectified linear units, and FC_k for the fully connected layer mapping to R^k. All the convolutions in the encoder used vertical and horizontal strides 2 and SAME padding.

Finally, we used two heuristics. First, we always pretrained the encoder separately for several mini-batch steps before the main training stage, so that the sample mean and covariance of Q_Z would try to match those of P_Z. Second, while training we were adding pixel-wise Gaussian noise truncated at 0.01 to all the images before feeding them to the encoder, which was meant to make the encoders random. We played with all possible ways of combining these two heuristics and noticed that together they result in slightly (almost negligibly) better results compared to using only one or none of them.

Our VAE model used a cross-entropy loss (Bernoulli decoder) and otherwise the same architectures and hyperparameters as listed above.

We pre-processed CelebA images by first taking a 140×140 center crop and then resizing to the 64×64 resolution. We used mini-batches of size 100 and trained the models for various numbers of epochs (up to 250). All reported WAE models were trained for 55 epochs and VAE for 68 epochs. For WAE-MMD we used λ = 100 and for WAE-GAN λ = 1. Both used σ²_z = 2. For WAE-MMD the learning rate of Adam was initially set to α = 10⁻³. For WAE-GAN the learning rate of Adam for the encoder-decoder pair was initially set to α = 3·10⁻⁴ and for the adversary to 10⁻³. All learning rates were decreased by a factor of 2 after 30 epochs, further by a factor of 5 after the first 50 epochs, and finally by an additional factor of 10 after the first 100 epochs. Both encoder and decoder used fully convolutional architectures with 5×5 convolutional filters.

Encoder architecture: x ∈ R^{64×64×3} → Conv_128 → BN → ReLU → Conv_256 → BN → ReLU → Conv_512 → BN → ReLU → Conv_1024 → BN → ReLU → FC_64.

Decoder architecture: z ∈ R^64 → FC_{8×8×1024} → FSConv_512 → BN → ReLU → FSConv_256 → BN → ReLU → FSConv_128 → BN → ReLU → FSConv_1.

Adversary architecture for WAE-GAN: z ∈ R^64 → FC_512 → ReLU → FC_512 → ReLU → FC_512 → ReLU → FC_512 → ReLU → FC_1.

For WAE-GAN we used a heuristic proposed in Supplementary IV of Mescheder et al. (2017). Notice that the theoretically optimal discriminator would result in $D^*(z) = \log p_Z(z) - \log q_Z(z)$, where p_Z and q_Z are the densities of P_Z and Q_Z respectively. In our experiments we added the log prior log p_Z(z) explicitly to the adversary output, as we know it analytically. This should hopefully make it easier for the adversary to learn the remaining Q_Z density term.

Our VAE model used a cross-entropy reconstruction loss (Bernoulli decoder), α = 10⁻⁴ as the initial Adam learning rate, and the same decay schedule as explained above. Otherwise all the architectures and hyperparameters were as explained above.
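As an illustration of the MNIST encoder/decoder listed above, here is a hedged PyTorch sketch written independently of the authors' released code. The kernel sizes and paddings are assumptions chosen so that the stated layer widths produce consistent feature-map sizes; they do not exactly reproduce TensorFlow-style SAME padding, and the final layer is a plain convolution standing in for the resolution-preserving FSConv_1.

```python
import torch
from torch import nn

# Hypothetical re-implementation of the MNIST architecture above.
# Kernel 4, stride 2, padding 1 shrinks 28 -> 14 -> 7 -> 3 -> 1.
encoder = nn.Sequential(
    nn.Conv2d(1, 128, 4, stride=2, padding=1), nn.BatchNorm2d(128), nn.ReLU(),
    nn.Conv2d(128, 256, 4, stride=2, padding=1), nn.BatchNorm2d(256), nn.ReLU(),
    nn.Conv2d(256, 512, 4, stride=2, padding=1), nn.BatchNorm2d(512), nn.ReLU(),
    nn.Conv2d(512, 1024, 4, stride=2, padding=1), nn.BatchNorm2d(1024), nn.ReLU(),
    nn.Flatten(),
    nn.Linear(1024, 8),          # FC_8: deterministic code mu_phi(x) in R^8
)

decoder = nn.Sequential(
    nn.Linear(8, 7 * 7 * 1024),  # FC_{7x7x1024}
    nn.Unflatten(1, (1024, 7, 7)),
    nn.ConvTranspose2d(1024, 512, 4, stride=2, padding=1), nn.BatchNorm2d(512), nn.ReLU(),  # 7 -> 14
    nn.ConvTranspose2d(512, 256, 4, stride=2, padding=1), nn.BatchNorm2d(256), nn.ReLU(),   # 14 -> 28
    nn.Conv2d(256, 1, 3, padding=1),  # keeps the 28x28 resolution, single output channel
)

x = torch.randn(16, 1, 28, 28)   # dummy mini-batch of MNIST-shaped inputs
z = encoder(x)                   # codes of shape (16, 8)
x_rec = decoder(z)               # reconstructions of shape (16, 1, 28, 28)
print(z.shape, x_rec.shape)
```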