# Variational Wasserstein gradient flow

Jiaojiao Fan 1, Qinsheng Zhang 1, Amirhossein Taghvaei 2, Yongxin Chen 1

Abstract

Wasserstein gradient flow has emerged as a promising approach to solve optimization problems over the space of probability distributions. A recent trend is to use the well-known JKO scheme in combination with input convex neural networks to numerically implement the proximal step. The most challenging step, in this setup, is to evaluate functions involving density explicitly, such as entropy, in terms of samples. This paper builds on the recent works with a slight but crucial difference: we propose to utilize a variational formulation of the objective function formulated as maximization over a parametric class of functions. Theoretically, the proposed variational formulation allows the construction of gradient flows directly for empirical distributions with a well-defined and meaningful objective function. Computationally, this approach replaces the computationally expensive step in existing methods, to handle objective functions involving density, with inner-loop updates that only require a small batch of samples and scale well with the dimension. The performance and scalability of the proposed method are illustrated with the aid of several numerical experiments involving high-dimensional synthetic and real datasets.

1 Introduction

The Wasserstein gradient flow models the gradient dynamics on the space of probability densities with respect to the Wasserstein metric. It was first discovered by Jordan, Kinderlehrer, and Otto (JKO) in their seminal work (Jordan et al., 1998). They pointed out that the Fokker-Planck equation is in fact the Wasserstein gradient flow of the free energy, bringing tremendous physical insight to this type of partial differential equations (PDEs). Since then, the Wasserstein gradient flow has played an important role in optimal transport (Santambrogio, 2017; Carlier et al., 2017), PDEs (Otto, 2001), physics (Carrillo et al., 2021; Adams et al., 2011), machine learning (Bunne et al., 2021; Lin et al., 2021; Alvarez-Melis et al., 2021; Frogner & Poggio, 2020), sampling (Bernton, 2018; Cheng & Bartlett, 2018; Wibisono, 2018) and many other areas (Ambrosio et al., 2008).

1 Georgia Institute of Technology. 2 University of Washington, Seattle. Correspondence to: Jiaojiao Fan. Proceedings of the 39th International Conference on Machine Learning, Baltimore, Maryland, USA, PMLR 162, 2022. Copyright 2022 by the author(s).

Despite the abundant theoretical results on the Wasserstein gradient flow established over the past decades (Ambrosio et al., 2008; Santambrogio, 2017), its computation remains a challenge. Most existing methods are either based on a finite-difference method applied to the underlying PDEs or based on a finite-dimensional optimization; both require discretization of the underlying space (Peyré, 2015; Benamou et al., 2016; Carlier et al., 2017; Li et al., 2020; Carrillo et al., 2021). The computational complexity of these methods scales exponentially with the problem dimension, making them unsuitable for problems with probability densities over a high-dimensional space. This shortcoming motivated a recent line of works that develop scalable algorithms utilizing neural networks (Mokrov et al., 2021; Alvarez-Melis et al., 2021; Yang et al., 2020; Bunne et al., 2021; Bonet et al., 2021).
A central theme in most of these works is the application of the JKO scheme in combination with input convex neural networks (ICNN) (Amos et al., 2017). The JKO scheme, which is essentially a backward Euler method, is used to discretize the continuous flow in time. At each time-step, one needs to find a probability distribution that minimizes a weighted sum of the squared Wasserstein distance, with respect to the distribution at the previous time-step, and the objective function. The probability distribution is then parametrized as the push-forward of the optimal transport map from the previous probability distribution. The optimal transport map is represented by the gradient of an ICNN, utilizing the fact that optimal transport maps are gradients of convex functions when the transportation cost is quadratic. The problem is finally cast as a stochastic optimization problem which only requires samples from the distribution.

Our paper builds on these recent works but with a crucial difference. We propose to use a variational form of the objective function, leveraging f-divergences, which has been employed in multiple machine learning applications, such as generative models (Nowozin et al., 2016) and Bayesian inference (Wan et al., 2020). The variational problem is formulated as a maximization over a parametrized class of functions. The variational form allows the evaluation of the objective in terms of samples, without the need for density estimation or approximating the logarithm of the determinant of the Hessian of ICNNs, which appears in (Mokrov et al., 2021; Alvarez-Melis et al., 2021). Moreover, the variational form, even when restricted to a finite-dimensional class of functions, admits nice geometrical properties of its own, leading to a meaningful objective function to minimize.

At the end of the algorithm, a sequence of transport maps connecting the initial distribution with the terminal distribution along the gradient flow dynamics is obtained. One can then sample from the distributions along the flow by sampling from the initial distribution (often Gaussian) and propagating these samples through the sequence of transport maps. When the transport map is modeled by the gradient of an input convex neural network, one can also evaluate the densities at every point.

Our contributions are summarized as follows. i) We develop a numerical algorithm to implement the Wasserstein gradient flow that is based on a variational representation of the objective functions. The algorithm does not require spatial discretization, density estimation, or approximating the logarithm of the determinant of Hessians. ii) We numerically demonstrate the performance of our algorithm on several representative problems, including sampling from high-dimensional Gaussian mixtures, the porous medium equation, and learning generative models on the MNIST and CIFAR10 datasets. We illustrate the computational advantage of our proposed method in comparison with (Mokrov et al., 2021; Alvarez-Melis et al., 2021), in terms of computational time and scalability with the problem dimension. iii) We establish some preliminary theoretical results regarding the proposed variational objective function. In particular, we provide conditions under which the variational objective satisfies a moment-matching property and an embedding inequality with respect to a certain integral probability metric (see Proposition 4.1).
Related works: Most existing methods to compute the Wasserstein gradient flow are finite-difference based (Peyré, 2015; Benamou et al., 2016; Carlier et al., 2017; Li et al., 2020; Carrillo et al., 2021). These methods require spatial discretization and are thus not scalable to high-dimensional settings. There is a line of research that uses particle-based methods to estimate the Wasserstein gradient flow (Carrillo et al., 2019a; Frogner & Poggio, 2020). In these algorithms, the current density value is often estimated with a kernel method whose complexity scales at least quadratically with the number of particles. More recently, several interesting neural-network-based methods (Mokrov et al., 2021; Alvarez-Melis et al., 2021; Yang et al., 2020; Bunne et al., 2021; Bonet et al., 2021; Hwang et al., 2021) were proposed for the Wasserstein gradient flow. Mokrov et al. (2021) focuses on the special case with the Kullback-Leibler divergence as objective function. Alvarez-Melis et al. (2021) uses a density estimation method to evaluate the objective function by back-propagating to the initial distribution, which can become a computational burden when the number of time discretizations is large. Yang et al. (2020) is based on a forward Euler time discretization of the Wasserstein gradient flow and is more sensitive to the time stepsize. Bunne et al. (2021) utilizes the JKO scheme to approximate a population dynamics given an observed trajectory, which finds application in computational biology. Bonet et al. (2021) replaces the Wasserstein distance in JKO by a sliced alternative, but its connection to the original Wasserstein gradient flow remains unclear.

2 Background

2.1 Optimization problem

We are interested in developing algorithms for
$$ \min_{P \in \mathcal{P}_{ac}(\mathbb{R}^n)} F(P), \qquad (1) $$
where $\mathcal{P}_{ac}(\mathbb{R}^n)$ is the space of probability distributions that admit a density $dP/dx$ with respect to the Lebesgue measure. The objective function $F(P)$ takes different forms depending on the application. Three important examples are:

Example I: The Kullback-Leibler divergence with respect to a given target distribution $Q$,
$$ D(P\|Q) := \int \log\frac{dP}{dQ}\, dP, $$
plays an important role in the sampling problem.

Example II: The generalized entropy
$$ G(P) := \frac{1}{m-1}\int P^m(x)\,dx, \quad m > 1, $$
is important for modeling the porous medium.

Example III: The (twice) Jensen-Shannon divergence
$$ \mathrm{JSD}(P\|Q) := D\Big(P \,\Big\|\, \frac{P+Q}{2}\Big) + D\Big(Q \,\Big\|\, \frac{P+Q}{2}\Big) $$
is important in learning generative models.

2.2 Wasserstein gradient flow

Given a function $F(P)$ over the space of probability densities, the Wasserstein gradient flow describes the dynamics of the probability density when it follows the steepest descent direction of $F(P)$ with respect to the Riemannian metric induced by the 2-Wasserstein distance $W_2$ (Ambrosio et al., 2008). The Wasserstein gradient flow can be explicitly represented by the PDE
$$ \frac{\partial P_t}{\partial t} = \nabla\cdot\Big(P_t\, \nabla \frac{\delta F}{\delta P}(P_t)\Big), $$
where $\delta F/\delta P$ denotes the first variation of $F$ with respect to the standard $L^2$ metric (Villani, 2003, Ch. 8). The Wasserstein gradient flow corresponds to various important PDEs depending on the choice of objective function $F(P)$. For instance, when $F(P)$ is the free energy
$$ F(P) = \int_{\mathbb{R}^n} P(x)\log P(x)\,dx + \int_{\mathbb{R}^n} V(x)P(x)\,dx, \qquad (2) $$
the gradient flow is the Fokker-Planck equation (Jordan et al., 1998)
$$ \frac{\partial P_t}{\partial t} = \nabla\cdot(P_t\,\nabla V) + \Delta P_t. $$
When $F(P)$ is the generalized entropy $F(P) = \frac{1}{m-1}\int_{\mathbb{R}^n} P^m(x)\,dx$ for some $m > 1$, the gradient flow is the porous medium equation (Otto, 2001; Vázquez, 2007)
$$ \frac{\partial P_t}{\partial t} = \Delta P_t^m. $$
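To make the free-energy example concrete, the following minimal sketch (illustrative only, with a quadratic potential $V(x) = \|x\|^2/2$ chosen for simplicity) simulates the Langevin diffusion whose marginal law solves the Fokker-Planck equation above, i.e., a particle-level realization of the Wasserstein gradient flow of the free energy.

```python
# Illustrative only: particle-level view of the Fokker-Planck equation as the
# Wasserstein gradient flow of the free energy F(P) = ∫ P log P + ∫ V P.
# The law of the Langevin diffusion dX = -∇V(X) dt + sqrt(2) dW solves this PDE.
# Here V(x) = |x|^2 / 2, so the stationary (Gibbs) distribution is N(0, I).
import torch

torch.manual_seed(0)
n, num_particles, dt, steps = 2, 5000, 1e-2, 2000

def grad_V(x):            # ∇V for the quadratic potential V(x) = |x|^2 / 2
    return x

X = 4.0 * torch.randn(num_particles, n)   # particles drawn from a wide initial law P_0
for _ in range(steps):
    noise = torch.randn_like(X)
    X = X - dt * grad_V(X) + (2 * dt) ** 0.5 * noise   # Euler-Maruyama step

# Empirical mean/covariance should be close to those of the Gibbs measure e^{-V} (= N(0, I)).
print("mean:", X.mean(0))
print("cov :", torch.cov(X.T))
```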
2.3 JKO scheme and reparametrization

To numerically realize the Wasserstein gradient flow, a discretization over time is needed. One such discretization is the famous JKO scheme (Jordan et al., 1998)
$$ P_{k+1} = \arg\min_{P \in \mathcal{P}_{ac}(\mathbb{R}^n)} \frac{1}{2a} W_2^2(P, P_k) + F(P). \qquad (3) $$
This is essentially a backward Euler discretization, or a proximal point method with respect to the Wasserstein metric. The solution to (3) converges to the continuous-time Wasserstein gradient flow as the step size $a \to 0$. Recall the definition of the Wasserstein-2 distance
$$ W_2^2(P, Q) = \min_{T:\, T_\# P = Q} \int_{\mathbb{R}^n} \|x - T(x)\|_2^2\, dP(x), $$
where the minimization is over all feasible transport maps that transport mass from distribution $P$ to distribution $Q$. Hence, (3) can be recast as an optimization over transport maps $T: \mathbb{R}^n \to \mathbb{R}^n$ from $P_k$ to $P$. Setting $P = T_\# P_k$, the optimal $T$ is the optimal transport map from $P_k$ to $T_\# P_k$ and is thus the gradient of a convex function $\varphi$ by Brenier's theorem (Brenier, 1991). Bunne et al. (2021); Mokrov et al. (2021); Alvarez-Melis et al. (2021) propose to parameterize $T$ as the gradient of an input convex neural network (ICNN) (Amos et al., 2017) and express (3) as
$$ P_{k+1} = \nabla\varphi_{k\#} P_k, \qquad \varphi_k = \arg\min_{\varphi \in \mathrm{CVX}} \frac{1}{2a}\int_{\mathbb{R}^n} \|x - \nabla\varphi(x)\|_2^2\, dP_k(x) + F(\nabla\varphi_\# P_k), \qquad (4) $$
where CVX stands for the space of convex functions. In our method, we extend this idea and propose to reparametrize $T$ alternatively by a residual neural network. With this reparametrization, the JKO step (3) becomes
$$ P_{k+1} = T_{k\#} P_k, \qquad T_k = \arg\min_{T} \frac{1}{2a}\int_{\mathbb{R}^n} \|x - T(x)\|_2^2\, dP_k(x) + F(T_\# P_k). \qquad (5) $$
We use the preceding two schemes (4) and (5) in our numerical method depending on the application.

3 Methods and algorithms

In this section we discuss how to implement the JKO scheme with our approach, and its computational complexity.

3.1 F(P) reformulation with variational formula

The main challenge in implementing the JKO scheme is to evaluate the functional $F(P)$ in terms of samples from $P$. We achieve this goal by using a variational formulation of $F$. To this end, we use the notion of the f-divergence between two distributions $P$ and $Q$:
$$ D_f(P\|Q) = \mathbb{E}_Q\Big[f\Big(\frac{dP}{dQ}(Y)\Big)\Big], \qquad (6) $$
where $P$ admits a density with respect to $Q$ (denoted as $P \ll Q$) and $f : [0, +\infty) \to \mathbb{R}$ is a convex and lower semicontinuous function. Without loss of generality, we assume $f(1) = 0$ so that $D_f$ attains its minimum at $P = Q$.

Proposition 3.1. (Nguyen et al., 2010) For all $P, Q \in \mathcal{P}(\mathbb{R}^n)$ such that $P \ll Q$ and differentiable $f$:
$$ D_f(P\|Q) = \sup_{h \in \mathcal{C}} \mathbb{E}_P[h(X)] - \mathbb{E}_Q[f^*(h(Y))], \qquad (7) $$
where $f^*(y) = \sup_{x\in\mathbb{R}}[xy - f(x)]$ is the convex conjugate of $f$ and $\mathcal{C}$ is the set of all measurable functions $h : \mathbb{R}^n \to \mathbb{R}$. The supremum is attained at $h = f'(dP/dQ)$.

The variational form has the distinguishing feature that it does not involve the densities of $P$ and $Q$ explicitly and can be approximated in terms of samples from $P$ and $Q$. In general, our scheme can be applied to any f-divergence, but we focus on the functionals in Section 2.1. With the help of the f-divergence variational formula, when $F(P) = D(P\|Q)$, $G(P)$ or $\mathrm{JSD}(P\|Q)$, the JKO scheme (5) can be equivalently expressed as
$$ P_{k+1} = T_{k\#} P_k, \qquad T_k = \arg\min_{T}\ \frac{1}{2a}\mathbb{E}_{P_k}\big[\|X - T(X)\|^2\big] + \sup_{h} V(T, h), \qquad (8) $$
where
$$ V(T, h) = \mathbb{E}_{X\sim P_k}\big[A_h(T(X))\big] - \mathbb{E}_{Z\sim\Gamma}\big[B_h(Z)\big], $$
$\Gamma$ is a user-designed distribution which is easy to sample from, and $A$ and $B$ are functionals whose form depends on $F$. The specializations of $A$ and $B$ appear in Table 1.

The following lemma implies that if $F(P)$ can be written as $D_f(P\|Q)$, then $F(P)$ monotonically decreases along its Wasserstein gradient flow, which makes it reasonable to solve (1) using the JKO scheme. It also justifies that the gradient flow eventually converges to $Q$.

Lemma 3.2. (Gao et al., 2019, Lemma 2.2)
$$ \frac{d}{dt} F(P_t) = -\mathbb{E}_{P_t}\big[\|\nabla f'(P_t/Q)\|^2\big]. $$
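As an illustration of how the variational form (7) is used with samples, the following sketch estimates the KL divergence, corresponding to $f(x) = x\log x$ with conjugate $f^*(y) = e^{y-1}$, between two Gaussians using a small neural critic; the architecture, learning rate, and distributions are illustrative choices, not the paper's configuration.

```python
# Sketch: sample-based estimate of D_f(P||Q) via the variational form (7),
# for the KL case f(x) = x log x, whose conjugate is f*(y) = exp(y - 1).
# P = N(mu, I), Q = N(0, I), so the exact KL is |mu|^2 / 2.
import torch
import torch.nn as nn

torch.manual_seed(0)
d, mu = 2, torch.tensor([1.0, -0.5])

critic = nn.Sequential(nn.Linear(d, 64), nn.ReLU(),
                       nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 1))
opt = torch.optim.Adam(critic.parameters(), lr=1e-3)

for step in range(3000):
    x_p = mu + torch.randn(512, d)          # samples from P
    x_q = torch.randn(512, d)               # samples from Q
    # maximize E_P[h] - E_Q[exp(h - 1)]  <=>  minimize its negation
    loss = -(critic(x_p).mean() - torch.exp(critic(x_q) - 1).mean())
    opt.zero_grad(); loss.backward(); opt.step()

with torch.no_grad():
    x_p = mu + torch.randn(20000, d)
    x_q = torch.randn(20000, d)
    est = critic(x_p).mean() - torch.exp(critic(x_q) - 1).mean()
# the estimate approaches the exact value from below (restricted critic class)
print(f"variational estimate {est.item():.3f}  vs exact KL {0.5 * mu.pow(2).sum().item():.3f}")
```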
3.1.1 KL DIVERGENCE

The KL divergence is a special instance of the f-divergence with $f(x) = x\log x$. Using $f(x) = x\log x$ in (7) yields the following expression for the KL divergence as a corollary of Proposition 3.1. The proof appears in Section A.1.

Corollary 3.3. The variational form for $D(P\|Q)$ reads
$$ D(P\|Q) = 1 + \sup_{h}\ \mathbb{E}_P\Big[\log\frac{h(X)\,\mu(X)}{Q(X)}\Big] - \mathbb{E}_\mu[h(Z)], $$
where $\mu$ is a user-designed distribution which is easy to sample from. The optimal function $h$ is equal to $dP/d\mu$.

This variational formula becomes practical when we only have access to the un-normalized density of $Q$, which is the case for the sampling problem. In practice, we choose $\mu = \mu_k$ adaptively, where $\mu_k$ is the Gaussian with the same mean and covariance as $P_k$. We noticed that this choice improves the numerical stability of the algorithm.

3.1.2 GENERALIZED ENTROPY

The generalized entropy can also be represented as an f-divergence. In particular, with $f(x) = \frac{1}{m-1}(x^m - x)$ and $Q$ the uniform distribution on a superset of the support of the density $P(x)$ with volume $\Omega$:
$$ D_f(P\|Q) = \frac{\Omega^{m-1}\int P^m(x)\,dx - 1}{m-1} = \Omega^{m-1} G(P) - \frac{1}{m-1}. \qquad (9) $$
Plugging $f(x) = \frac{1}{m-1}(x^m - x)$ into (7), we get the following expression of the generalized entropy as a corollary of Proposition 3.1. The proof appears in Section A.1.

Corollary 3.4. The variational formulation for $G(P)$ reads
$$ G(P) = \frac{1}{\Omega^{m-1}} \sup_{h}\ \Big\{ \mathbb{E}_P\Big[\frac{m}{m-1}h^{m-1}(X)\Big] - \mathbb{E}_Q\big[h^m(Z)\big] \Big\}. \qquad (10) $$
The optimal function $h$ is equal to $dP/dQ$.

Table 1: Variational formula for $F(P)$.

| $F(P)$ | $A_h$ | $B_h$ | $\Gamma$ |
| --- | --- | --- | --- |
| $D(P\Vert Q)$ | $\log\frac{h\,\mu_k}{Q}$ | $h$ | Gaussian dist. $\mu_k$ |
| $G(P)$ | $\frac{m}{m-1}\,\frac{h^{m-1}}{\Omega_k^{m-1}}$ | $\frac{h^m}{\Omega_k^{m-1}}$ | Uniform dist. $Q_k$ |
| $\mathrm{JSD}(P\Vert Q)$ | $\log(1-h)$ | $-\log h$ | Empirical dist. $Q$ |

In practice, we choose $\Omega = \Omega_k$, the volume of a set that is guaranteed to contain the support of $T_\# P_k$. In view of the connection between the generalized entropy and the f-divergence, it is justified that the solution of the porous medium equation evolves towards a uniform distribution. In particular, when $m = 2$, (9) recovers the Pearson divergence between $P$ and the uniform distribution $Q$.

3.1.3 JENSEN-SHANNON DIVERGENCE

$\mathrm{JSD}(P\|Q)$ corresponds to the f-divergence with $f(x) = x\log x - (x+1)\log\frac{1+x}{2}$. Direct application of (7) yields the following corollary.

Corollary 3.5. The variational form for $\mathrm{JSD}(P\|Q)$ is
$$ \log 4 + \sup_{h}\ \mathbb{E}_P[\log(1 - h(X))] + \mathbb{E}_Q[\log h(Z)]. \qquad (11) $$

In particular, we apply the JSD to learn image generative models, and therefore we assume samples from $Q$ are accessible.

Algorithm 1 Primal-dual gradient flow
  Input: objective functional $F(P)$, initial distribution $P_0$, JKO step size $a$, number of JKO steps $K$.
  Initialization: parameterized $T_\theta$ and $h_\lambda$.
  for $k = 1, 2, \dots, K$ do
    $T_\theta \leftarrow T_{k-1}$ if $k > 1$
    for $j_1 = 1, 2, \dots, J_1$ do
      Sample $X_1, \dots, X_M \sim P_k$ and $Z_1, \dots, Z_M \sim \Gamma$.
      Maximize $\frac{1}{M}\sum_{i=1}^{M} \big[A(T_\theta(X_i), h_\lambda) - B(h_\lambda(Z_i))\big]$ over $\lambda$ for $J_2$ steps.
      Minimize $\frac{1}{M}\sum_{i=1}^{M} \big[\frac{\|X_i - T_\theta(X_i)\|^2}{2a} + A(T_\theta(X_i), h_\lambda)\big]$ over $\theta$ for $J_3$ steps.
    end for
    $T_k \leftarrow T_\theta$
  end for
  Output: $\{T_k\}_{k=1}^{K}$
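Below is a condensed, illustrative sketch of one JKO step of Algorithm 1 for the KL objective of Corollary 3.3, with a residual map $T_\theta$ and an MLP critic $h_\lambda$; the single-Gaussian target, network sizes, learning rates, and inner-loop counts are placeholder choices rather than the configuration used in the experiments.

```python
# Sketch of one JKO step of Algorithm 1 for F(P) = D(P||Q), following (8) and
# Corollary 3.3 with mu = N(0, I). The target Q here is N(eta, I) with a known
# unnormalized log-density; architectures and hyper-parameters are illustrative.
import torch
import torch.nn as nn

torch.manual_seed(0)
n, a, M = 2, 0.1, 256                        # dimension, JKO step size, batch size
eta = torch.tensor([3.0, -2.0])

def log_Q(x):                                # unnormalized log-density of the target
    return -0.5 * (x - eta).pow(2).sum(-1)

def log_mu(x):                               # reference measure mu = N(0, I), up to a constant
    return -0.5 * x.pow(2).sum(-1)

class ResMap(nn.Module):                     # residual parametrization T(x) = x + g(x)
    def __init__(self):
        super().__init__()
        self.g = nn.Sequential(nn.Linear(n, 128), nn.SiLU(), nn.Linear(128, n))
    def forward(self, x):
        return x + self.g(x)

T = ResMap()
h = nn.Sequential(nn.Linear(n, 128), nn.SiLU(), nn.Linear(128, 1), nn.Softplus())  # h > 0
opt_T = torch.optim.Adam(T.parameters(), lr=1e-3)
opt_h = torch.optim.Adam(h.parameters(), lr=1e-3)

sample_Pk = lambda m: torch.randn(m, n)      # stands in for pushing P_0 through earlier maps

for j1 in range(500):                        # J1 inner iterations
    X, Z = sample_Pk(M), torch.randn(M, n)   # X ~ P_k, Z ~ mu
    # maximization over lambda: A_h(T(X)) = log(h mu / Q) at T(X), B_h(Z) = h(Z)
    for _ in range(3):                       # J2 critic steps
        TX = T(X).detach()
        A = torch.log(h(TX).squeeze(-1)) + log_mu(TX) - log_Q(TX)
        loss_h = -(A.mean() - h(Z).squeeze(-1).mean())
        opt_h.zero_grad(); loss_h.backward(); opt_h.step()
    # minimization over theta (J3 = 1 step here): transport cost + A_h(T(X))
    TX = T(X)
    A = torch.log(h(TX).squeeze(-1)) + log_mu(TX) - log_Q(TX)
    loss_T = ((X - TX).pow(2).sum(-1) / (2 * a) + A).mean()
    opt_T.zero_grad(); loss_T.backward(); opt_T.step()

with torch.no_grad():                        # pushed-forward mean should move toward eta
    print("pushed-forward mean:", T(sample_Pk(4000)).mean(0))
```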
3.2 Parametrization of T and h

The two optimization variables $T$ and $h$ in our minimax formulation (8) can both be parameterized by neural networks, denoted by $T_\theta$ and $h_\lambda$. With this neural network parametrization, we can then solve the problem by iteratively updating $T_\theta$ and $h_\lambda$. This primal-dual method to solve (1) is depicted in Algorithm 1; an illustrative sketch of one step is given above.

In this work, we implemented two different architectures for the map $T$. One way is to use a residual neural network to represent $T$ directly, and another way is to parametrize $T$ as the gradient of an ICNN $\varphi$. The latter has been widely used in optimal transport (Makkuva et al., 2020; Fan et al., 2020; Korotin et al., 2021b). However, several recent works (Rout et al., 2021; Korotin et al., 2021a; Fan et al., 2021; Bonet et al., 2021) find that the ICNN architecture has poor expressiveness and propose to replace the gradient of an ICNN by a generic neural network. In our experiments, we find that the first parametrization gives more regular results, which aligns with the findings in Bonet et al. (2021, Figure 4). However, it is then very difficult to calculate the density of the pushforward distribution. Therefore, with the first parametrization, our method becomes a particle-based method, i.e., we cannot query the density directly. As we discuss in Section D, when density evaluation is needed, we adopt the ICNN parametrization since we need to compute $T^{-1}$.

3.3 Computational complexity

Each update $k$ in Algorithm 1 requires $O(J_1 k M H)$ operations, where $J_1$ is the number of iterations per JKO step, $M$ is the batch size, and $H$ is the size of the network. The factor $k$ shows up in the bound because sampling from $P_k$ requires pushing $x_0 \sim P_0$ through $k-1$ maps. In contrast, Mokrov et al. (2021) requires $O\big(J_1((k + n)MH + n^3)\big)$ operations, which has a cubic dependence (Mokrov et al., 2021, Section 5) on the dimension $n$ because they need to query $\log\det\nabla^2\varphi$ in each iteration. There exists a fast approximation (Huang et al., 2020) of $\log\det\nabla^2\varphi$ using the Hutchinson trace estimator (Hutchinson, 1989). Alvarez-Melis et al. (2021) applies this technique, so the cubic dependence on $n$ can be improved to a quadratic dependence. Nonetheless, this comes with an additional cost, namely the number of iterations of the conjugate gradient (CG) method. CG is guaranteed to converge exactly in $n$ steps in this setting. If one wants to obtain $\log\det\nabla^2\varphi$ precisely, the cost is still $O(n^3)$, which is the same as calculating $\log\det\nabla^2\varphi$ directly. If one uses an error-$\epsilon$ stopping condition in CG, the complexity can be improved to $\sqrt{\kappa}\,\log(2/\epsilon)\, n^2$ (Shewchuk et al., 1994), where $\kappa$ is an upper bound on the condition number of $\nabla^2\varphi$, but this sacrifices accuracy. Given a similar neural network size, our method has the advantage that the training time is independent of the dimension. Beyond training time, the complexity of evaluating the density has an unavoidable dependence on $n$ due to the standard density evaluation process (see Section D).
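When $T = \nabla\varphi$ for an ICNN $\varphi$, the density of the pushforward can be evaluated with the change-of-variables formula mentioned above. The following is a minimal sketch with a simplified ICNN (softplus-constrained weights and a quadratic term for strict convexity); it is not the exact architecture of Amos et al. (2017) or the one used in the paper's experiments.

```python
# Sketch: gradient-of-ICNN parametrization T = ∇φ and density evaluation of the
# pushforward via the change of variables log p_Y(∇φ(x)) = log p_X(x) - log det ∇²φ(x).
import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleICNN(nn.Module):
    # φ(x) = w_out^T z_L + 0.5 * alpha * |x|^2, with non-negative weights on the
    # z-path and convex, non-decreasing activations (softplus), so φ is convex in x.
    def __init__(self, n, width=64, depth=3, alpha=0.1):
        super().__init__()
        self.Wx = nn.ModuleList([nn.Linear(n, width) for _ in range(depth)])
        self.Wz = nn.ModuleList([nn.Linear(width, width, bias=False) for _ in range(depth - 1)])
        self.w_out = nn.Parameter(torch.randn(width) * 0.1)
        self.alpha = alpha                       # quadratic term keeps ∇²φ positive definite

    def forward(self, x):                        # x: (batch, n) -> (batch,)
        z = F.softplus(self.Wx[0](x))
        for Wx, Wz in zip(self.Wx[1:], self.Wz):
            z = F.softplus(Wx(x) + F.linear(z, F.softplus(Wz.weight)))
        out = F.linear(z, F.softplus(self.w_out).unsqueeze(0)).squeeze(-1)
        return out + 0.5 * self.alpha * x.pow(2).sum(-1)

def transport(phi, x):                           # T(x) = ∇φ(x)
    x = x.clone().requires_grad_(True)
    return torch.autograd.grad(phi(x).sum(), x, create_graph=True)[0]

def pushforward_logdensity(phi, log_p0, x):      # log density of T#P0 at y = ∇φ(x)
    H = torch.autograd.functional.hessian(lambda v: phi(v.unsqueeze(0)).squeeze(), x)
    return log_p0(x) - torch.logdet(H)

n = 2
phi = SimpleICNN(n)
x = torch.randn(n)
log_p0 = lambda v: -0.5 * v.pow(2).sum() - 0.5 * n * torch.log(torch.tensor(2 * torch.pi))
print("T(x)       :", transport(phi, x.unsqueeze(0)))
print("log p_T#P0 :", pushforward_logdensity(phi, log_p0, x).item())
```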
4 Theoretical results

In this section we introduce the approximate f-divergence and analyze its properties.

4.1 Approximate f-divergence

Given the result in Proposition 3.1, we now consider a restriction of the optimization domain $\mathcal{C}$ to a class of functions $\mathcal{H}$, e.g. parametrized by neural networks, and define the new functional
$$ D^{\mathcal{H}}_f(P\|Q) = \sup_{h\in\mathcal{H}} \int h\, dP - \int f^*(h)\, dQ. $$
This functional forms a surrogate for the exact f-divergence. It is straightforward to see that the new functional is always smaller than the exact f-divergence, i.e. $D^{\mathcal{H}}_f(P\|Q) \le D_f(P\|Q)$, where the inequality becomes an equality when $f'\big(\frac{dP}{dQ}\big)$ belongs to $\mathcal{H}$. In the following proposition, we establish some important theoretical properties of the approximate f-divergence $D^{\mathcal{H}}_f(P\|Q)$. In order to do so, we introduce the integral probability metric (Sriperumbudur et al., 2012; Arora et al., 2017)
$$ d_{\mathcal{H}}(P, Q) = \sup_{h\in\mathcal{H}} \frac{\int h\, dP - \int h\, dQ}{\|h\|_{2,Q}}, $$
where $\|h\|_{2,Q}^2 = \int h^2\, dQ$.

Proposition 4.1. The approximate f-divergence $D^{\mathcal{H}}_f(P\|Q)$ satisfies the following properties:
1. (positivity) If $\mathcal{H}$ contains all constant functions, then $D^{\mathcal{H}}_f(P\|Q) \ge 0$ for all $P, Q$.
2. (moment matching) If $c + \lambda h \in \mathcal{H}$ for all $h \in \mathcal{H}$ and $c, \lambda \in \mathbb{R}$, then
$$ D^{\mathcal{H}}_f(P\|Q) = 0 \iff \int h\, dP = \int h\, dQ, \quad \forall h \in \mathcal{H}. $$
3. (embedding inequalities) Additionally, if $f$ is strongly convex with constant $\alpha$ and smooth with constant $L$, then
$$ \frac{\alpha}{2}\, d_{\mathcal{H}}(P, Q)^2 \le D^{\mathcal{H}}_f(P\|Q) \le \frac{L}{2}\, d_{\mathcal{H}}(P, Q)^2. $$

The proposition has important implications. Part (1) establishes the condition under which the approximate f-divergence is always positive. Part (2) identifies necessary and sufficient conditions under which the approximate divergence is zero for two given probability distributions $P$ and $Q$. In particular, the divergence is zero iff the moments of $P$ and $Q$ are equal for all functions in the function class $\mathcal{H}$. Finally, part (3) provides a lower-bound and an upper-bound for the approximate f-divergence in terms of an integral probability metric defined on the function class $\mathcal{H}$, implying that the two quantities are equivalent when $f$ is both strongly convex and smooth. For example, a sequence $D^{\mathcal{H}}_f(P_d\|Q_d) \to 0$ as $d \to \infty$ iff $d_{\mathcal{H}}(P_d, Q_d) \to 0$ as $d \to \infty$. Or, if we are able to minimize the approximate f-divergence $D^{\mathcal{H}}_f(P\|Q)$ with optimization gap $\epsilon$, then the error in the moments of $P$ and $Q$ for functions in $\mathcal{H}$ is of order $O(\sqrt{\epsilon})$. These results inform us that the proposed objective of minimizing $D^{\mathcal{H}}_f(P\|Q)$ is meaningful and has geometrical significance.

Remark 4.2. The assumption that $c + \lambda h \in \mathcal{H}$ for all $h \in \mathcal{H}$ and $c, \lambda \in \mathbb{R}$ holds for any neural network with a linear activation function at the last layer. The assumption that $f$ is strongly convex and smooth may not hold for a typical $f$ such as $f(x) = x\log x$ over $(0, \infty)$. However, it holds when the domain is restricted, which is the case when either the samples are bounded or $h$ is bounded for all $h \in \mathcal{H}$.

4.2 Computational boundedness

It is also possible to obtain a lower-bound for $D^{\mathcal{H}}_f(P\|Q)$ in terms of the exact f-divergence $D_f(P\|Q)$ when the class $\mathcal{H}$ is rich enough.

Proposition 4.3. If $f$ is $\alpha$-strongly convex and the class of functions is able to approximate any function $h \in \mathcal{C}$ with some $\bar h \in \mathcal{H}$ such that $\|\bar h - h\|_{2,Q} \le \epsilon$, then
$$ D^{\mathcal{H}}_f(P\|Q) \ge D_f(P\|Q) - \frac{\epsilon^2}{2\alpha}. $$

Proposition 4.3 upper-bounds the error between the variational f-divergence and the ground truth in terms of the expressiveness of the function class, which can be verified for neural-net function classes. Assume $\mathcal{H}$ is the class of neural nets with arbitrary depth under mild assumptions on the activation function. Following the proof of Theorem 1 in Korotin et al. (2022), one can verify that for any $\epsilon > 0$, compactly supported $Q$, and function $h$ with $\|h\|_{2,Q} < \infty$, there exists a neural net $\bar h \in \mathcal{H}$ such that $\|\bar h - h\|_{2,Q} \le \epsilon$ (c.f. the discussion in Section A.3). However, Proposition 4.1-(3) and Proposition 4.3 require $f$ to be strongly convex, which might be too strong for some f-divergences, such as the KL divergence.

Unlike the exact form of the f-divergence, the variational formulation is well-defined for empirical distributions when the function class $\mathcal{H}$ is restricted and admits a finite Rademacher complexity.

Proposition 4.4. Let $P^{(N)} = \frac{1}{N}\sum_{i=1}^N \delta_{X_i}$ and $Q^{(M)} = \frac{1}{M}\sum_{i=1}^M \delta_{Y_i}$, where $\{X_i\}_{i=1}^N$, $\{Y_i\}_{i=1}^M$ are i.i.d. samples from $P$ and $Q$ respectively. Then,
$$ \mathbb{E}\big[|D^{\mathcal{H}}_f(P\|Q) - D^{\mathcal{H}}_f(P^{(N)}\|Q^{(M)})|\big] \le 2\mathcal{R}_N(\mathcal{H}, P) + 2\mathcal{R}_M(f^*\circ\mathcal{H}, Q), $$
where the expectation is over the samples and $\mathcal{R}_N(\mathcal{H}, P)$ denotes the Rademacher complexity of the function class $\mathcal{H}$ with respect to $P$ for sample size $N$.

Proposition 4.4 quantifies the generalization error in terms of the Rademacher complexity. We leave the task of evaluating the Rademacher complexity for the function classes employed in this paper for future work.
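As a small numerical illustration of the restricted-class objective (not taken from the paper), consider the Pearson generator $f(x) = (x-1)^2$, which is 2-strongly convex and 2-smooth, so part (3) of Proposition 4.1 holds with equality for a class $\mathcal{H}$ spanned by a finite feature map; both $D^{\mathcal{H}}_f$ and $d_{\mathcal{H}}$ then have closed forms in terms of sample moments. The feature map and distributions below are illustrative.

```python
# Illustration (not from the paper): for the Pearson generator f(x) = (x - 1)^2 we
# have f*(y) = y + y^2/4 (on the relevant range), alpha = L = 2, so Proposition
# 4.1(3) predicts D_f^H(P||Q) = d_H(P, Q)^2 when H = span of a finite feature map.
import numpy as np

rng = np.random.default_rng(0)
Xp = rng.normal(loc=[1.0, -0.5], scale=1.0, size=(200000, 2))   # samples from P
Xq = rng.normal(loc=[0.0, 0.0], scale=1.0, size=(200000, 2))    # samples from Q

def phi(x):   # feature map spanning H (constant + linear features here)
    return np.concatenate([np.ones((x.shape[0], 1)), x], axis=1)

m_p, m_q = phi(Xp).mean(0), phi(Xq).mean(0)
C_q = phi(Xq).T @ phi(Xq) / Xq.shape[0]            # E_Q[phi phi^T]
delta = m_p - m_q

# D_f^H: maximize w.(m_p - m_q) - (1/4) w^T C_q w  ->  w* = 2 C_q^{-1} delta
w_star = 2 * np.linalg.solve(C_q, delta)
D_H = w_star @ delta - 0.25 * w_star @ C_q @ w_star

# d_H^2: sup_w (w.delta)^2 / (w^T C_q w) = delta^T C_q^{-1} delta
d_H_sq = delta @ np.linalg.solve(C_q, delta)

print(f"D_f^H = {D_H:.4f}")        # the two numbers coincide, as part (3) predicts
print(f"d_H^2 = {d_H_sq:.4f}")     # for alpha = L = 2
```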
4.3 Convergence to spherical Gaussian distribution

We assess the efficacy of JKO with variational estimation through a spherical Gaussian example. We consider sampling from the target distribution $Q = N(\eta, I_n)$ by minimizing the functional $F(P) = D(P\|Q)$. We choose $P_0 = \mu = N(0, I_n)$ and parameterize $T$ to be a linear function. Assume we obtain $T_0, \dots, T_{K-1}$ by solving the particle-approximated JKO in (14), and that we can estimate $\mathbb{E}_\mu[h(\cdot)]$ precisely for simplicity. Denote by $P_K$ the $K$-th JKO iterate $T_{K-1\,\#}(\dots(T_{0\,\#} P_0))$ and by $P^*_K$ the ground-truth solution of JKO.

Proposition 4.5. Under the assumptions in the paragraph above, let $P^{(N)}_K = \frac{1}{N}\sum_{i=1}^N \delta_{X_i}$, where $\{X_i\}_{i=1}^N$ are i.i.d. samples from $P_K$. Then,
$$ \mathbb{E}\big[|D^{\mathcal{H}}(P^*_K\|Q) - D^{\mathcal{H}}(P^{(N)}_K\|Q)|\big] \le \Delta_N^2 + \Delta_N\big(\sqrt{\xi_{K,N}} + \sqrt{n/N}\big) + \tfrac{1}{2}\big(\sqrt{\xi_{K,N}} + \sqrt{n/N}\big)^2, $$
where $\Delta_N = \frac{\|\eta\|}{(1+a)^K}$, $\xi_{K,N} = \big(\frac{a}{1+a}\big)^2 \frac{n}{N}\sum_{j=1}^{K} \frac{1}{(1+a)^{2(K-j)}}$, and $\mathcal{H} \subset \{h : h(z) = \exp(\alpha^\top z + \gamma),\ \alpha\in\mathbb{R}^n,\ \gamma\in\mathbb{R}\}$.

This proposition quantifies the sample complexity and convergence rate of JKO with our variational estimation for a spherical Gaussian example. In the future, it would be useful to analyze the stability and convergence of the proposed min-max formulation for more general functionals $F(P)$, both at the level of densities and at the level of samples/particles.

5 Numerical examples

In this section, we present several numerical examples to illustrate our algorithm. We mainly compare with JKO-ICNN-d (Mokrov et al., 2021) and JKO-ICNN-a (Alvarez-Melis et al., 2021). The difference between JKO-ICNN-d and JKO-ICNN-a is that the former computes $\log\det(\nabla^2\varphi)$ directly and the latter adopts a fast approximation. We use the default hyper-parameters in the authors' implementations. Our code is written in PyTorch Lightning and is publicly available at https://github.com/sbyebss/variational_wgf.

5.1 Sampling from Gaussian Mixture Model

We first consider the problem of sampling from a target distribution $Q$. Note that $Q$ does not have to be normalized. To this end, we consider the Wasserstein gradient flow with objective function $F(P) = D(P\|Q)$, that is, the KL divergence between the distributions $P$ and $Q$. When this objective is minimized, $P = Q$. In our experiments, we consider a Gaussian mixture model (GMM) with 10 equally weighted spherical Gaussian components. The means of the Gaussian components are sampled uniformly at random inside a cube. The step size is set to $a = 0.1$ and the initial measure is a spherical Gaussian $N(0, 16 I_n)$.
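For reference, a minimal sketch of the unnormalized GMM log-density that plays the role of $\log Q$ in Corollary 3.3 is given below; the cube half-width and the unit component variance are illustrative values not specified above.

```python
# Sketch: unnormalized log-density of the GMM target used as log Q in Corollary 3.3.
# 10 equally weighted spherical components with means sampled uniformly in a cube;
# the cube half-width (5.0) and component variance (1.0) are illustrative choices.
import torch

torch.manual_seed(0)
n, num_components, half_width = 64, 10, 5.0
means = (torch.rand(num_components, n) * 2 - 1) * half_width   # component means in the cube

def log_Q(x):                       # x: (batch, n) -> (batch,), up to an additive constant
    # log sum_j exp(-|x - mu_j|^2 / 2) - log(num_components)
    sq = ((x.unsqueeze(1) - means.unsqueeze(0)) ** 2).sum(-1)   # (batch, components)
    return torch.logsumexp(-0.5 * sq, dim=1) - torch.log(torch.tensor(float(num_components)))

P0_samples = 4.0 * torch.randn(512, n)      # initial measure N(0, 16 I_n)
print(log_Q(P0_samples).shape)              # one unnormalized log-density per sample
```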
In Figure 1, we show that our generated samples are in concordance with the target measure.

Figure 1: Comparison between the target GMM and the fitted measure of the generated samples by our method; samples are projected onto a 2D plane by PCA. (a) Dimension n = 64. (b) Dimension n = 128.

Figure 2: Averaged training time (in minutes) of 40 JKO steps for sampling from the GMM.

In Figure 2, we plot the averaged training time of 5 runs for all compared methods. Note that we fix the number of conjugate gradient steps to be at most 10 when approximating $\log\det\nabla^2\varphi$ in JKO-ICNN-a. That is why JKO-ICNN-d and JKO-ICNN-a have quite similar training times when $n < 10$. To investigate the performance under the constraint of similar training time, we perform 40 JKO steps with our method, and the same for the JKO-ICNN methods except for $n \ge 15$, where we only let them run for 20, 15, and 12 JKO steps for $n = 15, 24, 32$ respectively. In doing so, one can verify that the training time of our method and JKO-ICNN is roughly consistent. We only report the accuracy results of JKO-ICNN-d for $n < 10$ in Figure 3 since it tends to give higher accuracy than JKO-ICNN-a for nearly the same training time in low dimension. We select the Kernelized Stein Discrepancy (KSD) (Liu et al., 2016) as the error criterion because it only requires samples to estimate the discrepancy, which is useful in the sampling task.

Figure 3: (a) log10 KSD. We perform experiments in n = 2, 4, 8, 15, 24, 32 for all methods and additionally n = 64, 128 for our method. With the constraint of similar training time, our method gives smaller error in high dimension. (b) Objective functional. With the variational formula, we use only samples to estimate the objective functional $D(P_k\|Q)$ in dimension n = 64; it converges to the ideal objective minimum $D(P^*\|Q) = 0$.

5.2 Ornstein-Uhlenbeck Process

Figure 4: (a) log10 SymKL. We repeat the experiments 15 times and compare the SymKL (Mokrov et al., 2021) between the estimated distribution and the ground truth at k = 18 in the OU process. (b) Objective functional. We show the comparison between our estimated $D(P_k\|Q)$ and the ground truth in dimension n = 64; they align with each other very well.

We study the performance of our method in modeling the Ornstein-Uhlenbeck process as the dimension grows. The gradient flow is associated with the free energy (2), where $Q = e^{-(x-b)^\top A (x-b)/2}$ with a positive definite matrix $A \in \mathbb{R}^{n\times n}$ and $b \in \mathbb{R}^n$. Given an initial Gaussian distribution $N(0, I_n)$, the gradient flow at each time $t$ is a Gaussian distribution $P_t$ with mean vector $\mu_t = (I_n - e^{-At})b$ and covariance (Vatiwutipong & Phewchean, 2019)
$$ \Sigma_t = A^{-1}(I_n - e^{-2At}) + e^{-2At}. $$
We choose JKO step size $a = 0.05$. We only present JKO-ICNN-d accuracy results because JKO-ICNN-a has similar or slightly worse performance. There could be several reasons why we achieve better performance: 1) the proposal distribution $\mu$ is Gaussian, which is consistent with $P_t$ for any $t$; this is beneficial for the inner maximization to find a precise $h$; 2) parameterizing $T$ as a neural network instead of the gradient of an ICNN is handier for optimization in this toy example.

We also compare the training time per every two JKO steps with the JKO-ICNN methods. The computation time for JKO-ICNN-d is around 25s when $n = 2$ and increases to 105s when $n = 32$. JKO-ICNN-a has slightly better scalability, increasing from 25s to 95s. Our method's training time remains at 22s ± 5s for all dimensions $n = 2$ to $32$. This is due to the fact that we fix the neural network size for both methods and our method's computational complexity does not depend on the dimension.

5.3 Bayesian Logistic Regression

Table 2: Bayesian logistic regression accuracy and log-likelihood results.

| Dataset | Accuracy (Ours) | Accuracy (JKO-ICNN) | Log-Likelihood (Ours) | Log-Likelihood (JKO-ICNN) |
| --- | --- | --- | --- | --- |
| covtype | 0.753 | 0.75 | -0.528 | -0.515 |
| splice | 0.84 | 0.845 | -0.38 | -0.36 |
| waveform | 0.785 | 0.78 | -0.455 | -0.485 |
| twonorm | 0.982 | 0.98 | -0.056 | -0.059 |
| ringnorm | 0.73 | 0.74 | -0.5 | -0.5 |
| german | 0.67 | 0.67 | -0.59 | -0.6 |
| image | 0.866 | 0.82 | -0.394 | -0.43 |
| diabetis | 0.786 | 0.775 | -0.45 | -0.45 |
| banana | 0.55 | 0.55 | -0.69 | -0.69 |

To evaluate our method on real-world datasets, we consider the Bayesian logistic regression task with the same setting as in Gershman et al. (2012). Given a dataset $\mathcal{L} = \{l_1, \dots, l_S\}$, a model with parameters $x \in \mathbb{R}^n$ and prior distribution $p_0(x)$, our target is to sample from the posterior distribution
$$ p(x|\mathcal{L}) \propto p_0(x)\, p(\mathcal{L}|x) = p_0(x)\prod_{s=1}^{S} p(l_s|x). $$
To this end, we let the target distribution $Q(x) = p_0(x)\, p(\mathcal{L}|x)$ and choose $F(P)$ equal to $D(P\|Q)$.
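A minimal sketch of the unnormalized log-posterior that plays the role of $\log Q$ in this task is given below; for brevity it uses a fixed Gaussian prior on the weights rather than the hierarchical prior described next, and the dataset arrays and precision value are placeholders.

```python
# Sketch: unnormalized log-posterior for Bayesian logistic regression, used as log Q
# in the KL variational formula. For brevity this uses a fixed Gaussian prior
# N(0, lam^-1 I) on the weights instead of the hierarchical Gamma prior of the
# actual experiment; `features`, `labels`, and `lam` are illustrative placeholders.
import torch
import torch.nn.functional as F

S, d, lam = 1000, 20, 0.1
features = torch.randn(S, d)                      # placeholder design matrix
labels = torch.randint(0, 2, (S,)).float()        # placeholder binary labels in {0, 1}

def log_Q(w):                                     # w: (batch, d) -> (batch,)
    logits = w @ features.T                       # (batch, S)
    # log-likelihood: sum_s [ y_s log sigmoid(z_s) + (1 - y_s) log(1 - sigmoid(z_s)) ]
    log_lik = -F.binary_cross_entropy_with_logits(
        logits, labels.expand_as(logits), reduction="none").sum(-1)
    log_prior = -0.5 * lam * w.pow(2).sum(-1)     # Gaussian prior, up to an additive constant
    return log_lik + log_prior

print(log_Q(torch.randn(8, d)).shape)             # one unnormalized log-density per particle
```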
The parameter $x$ takes the form $[\omega, \log\alpha]$, where $\omega \in \mathbb{R}^{n-1}$ are the regression weights with prior $p_0(\omega|\alpha) = \mathcal{N}(\omega; 0, \alpha^{-1})$, and $\alpha$ is a scalar with prior $p_0(\alpha) = \mathrm{Gamma}(\alpha|1, 0.01)$. We test on 8 relatively small datasets ($S \le 7400$) from Mika et al. (1999) and one large Covertype dataset ($S = 0.58$M; https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary.html). Each dataset is randomly split into a training set and a test set with ratio 4:1. The number of features ranges from 2 to 60. From Table 2, we can tell that our method achieves performance comparable to JKO-ICNN. The results of JKO-ICNN-d are adapted from Mokrov et al. (2021, Table 1). We present the dataset properties and a comparison with another popular sampling method, SVGD (Liu & Wang, 2016), in Table 5 in the Appendix.

5.4 Porous media equation

Figure 5: SymKL with respect to the Barenblatt profile ground truth over 50 JKO steps.

Figure 6: We use the variational formula to calculate the objective functional $G(P)$ with samples and compare it with the ground truth. (a) Dimension n = 3. (b) Dimension n = 6.

We next consider the porous media equation with only diffusion: $\partial_t P = \Delta P^m$. This is the Wasserstein gradient flow associated with the energy functional $F(P) = G(P)$. A representative closed-form solution of the porous media equation is the Barenblatt profile (GI, 1952; Vázquez, 2007)
$$ P(t, x) = (t + t_0)^{-\alpha}\Big(C - \beta\,\frac{\|x - x_0\|^2}{(t + t_0)^{2\alpha/n}}\Big)_+^{\frac{1}{m-1}}, $$
where $\alpha = \frac{n}{n(m-1)+2}$, $\beta = \frac{(m-1)\alpha}{2mn}$, $t_0 > 0$ is the starting time, and $C > 0$ is a free parameter. In the experiments, we set $m = 2$, the stepsize for the JKO scheme to $a = 0.0005$, and the initial time to $t_0 = 0.001$. We parametrize the transport map $T$ as the gradient of an ICNN and thus can evaluate the density following Section D. From Figure 5, we observe that our method gives stable simulation results, where the error is kept within a small range as the diffusion time increases.

5.5 Gradient flow on images

Figure 7: With the Wasserstein gradient flow scheme, we visualize (a) trajectories of the generated samples from JKO-Flow and (b) 100 uncurated samples from $P_K$.

In this section, we illustrate the scalability of our algorithm to the high-dimensional setting by applying our scheme to real image datasets, where only samples from $Q$ are accessible. With the variational formula (11), Algorithm 1 can be adapted to model a gradient flow in image space. Specifically, we choose $F(P)$ to be $\mathrm{JSD}(P\|Q)$ and $P_0 = N(0, I_n)$. We name the resulting model JKO-Flow. Note that the JKO-Flow model specializes to a GAN (Goodfellow et al., 2014) when $a \to \infty$ and $K = 1$. Thanks to the additional Wasserstein loss regularization, JKO-Flow enjoys stable training and empirically suffers less from mode collapse. We evaluate JKO-Flow on the popular MNIST (LeCun et al., 1998) and CIFAR10 (Krizhevsky et al., 2009) datasets. Figure 7 shows samples and their trajectories from $P_0$ to $P_K$ and demonstrates that JKO-Flow can approximate the Wasserstein gradient flow in image space empirically. To further quantify the performance of JKO-Flow, we measure the discrepancy between $P_K$ and the real distribution with the popular sample-based metric, the Fréchet Inception Distance (FID) (Heusel et al., 2017), in Table 3. We also compare our method with normalizing flows (NF), which also consist of a sequence of forward mappings.

Table 3: Results of gradient flow (GF) based methods, GAN methods and normalizing flow (NF) methods on the unconditional CIFAR10 dataset.
Method FID score NF GLOW (Kingma & Dhariwal, 2018) 45.99 VGrow (Gao et al., 2019) 28.8 GF JKO-Flow 23.1 WGAN-GP (Arbel et al., 2018) 31.1 GANs SN-GAN (Miyato et al., 2018) 21.7 However, the invertible property of NF either requires heavy calculations (e.g. evaluating matrix determinant or solving Neural ODE) or special network structures that limit the the expressiveness of NNs. We include more comparison and experiments details in Section G. 6 Conclusion In this paper, we presented a numerical procedure to implement the Wasserstein gradient flow for objective functions in the form of f-divergence. Our procedure is based on applying the JKO scheme on a variational formulation of the f-divergence. Each step involves solving a minmax stochastic optimization problem for a transport map and a dual function that are parameterized by neural networks. We demonstrated the scalability of our approach to high-dimensional problems through numerical experiments on Gaussian mixture models and real datasets including MNIST and CIFAR10. We also provided preliminary theoretical results regarding the variational objective function. The results show that minimizing the variational objective is meaningful and serve as starting point for future research. Our method can also be adapted to Crank-Nicolson type scheme, which enjoys a faster convergence (Carrillo et al., 2021) in step size a than the classical JKO scheme (see Section B). One restriction of our method is that it is only applicable to f-divergence, thus a possible direction for future research is to extend the variational formulation beyond f-divergence. Another limitation is that the min-max training is both theoretically and numerically more challenging than a single minimization. Acknowledgement The authors would like to thank the anonymous reviewers for useful comments. JF, QZ, and YC are supported in part by grants NSF CAREER ECCS-1942523, NSF ECCS1901599, and NSF CCF-2008513. Variational Wasserstein gradient flow Adams, S., Dirr, N., Peletier, M. A., and Zimmer, J. From a large-deviations principle to the Wasserstein gradient flow: a new micro-macro passage. Communications in Mathematical Physics, 307(3):791 815, 2011. (Cited on page 1.) Alvarez-Melis, D., Schiff, Y., and Mroueh, Y. Optimizing functionals on the space of probabilities with input convex neural networks. ar Xiv preprint ar Xiv:2106.00774, 2021. (Cited on pages 1, 2, 3, 5, 6, 18, 19, and 21.) Ambrosio, L., Gigli, N., and Savar e, G. Gradient flows: in metric spaces and in the space of probability measures. Springer Science & Business Media, 2008. (Cited on pages 1 and 3.) Amos, B., Xu, L., and Kolter, J. Z. Input convex neural networks. In International Conference on Machine Learning, pp. 146 155. PMLR, 2017. (Cited on pages 1 and 3.) An, D., Guo, Y., Lei, N., Luo, Z., Yau, S.-T., and Gu, X. Ae-ot: a new generative model based on extended semidiscrete optimal transport. ICLR 2020, 2019. (Cited on pages 26 and 27.) An, D., Guo, Y., Zhang, M., Qi, X., Lei, N., and Gu, X. Ae-ot-gan: Training gans from data specific latent distribution. In European Conference on Computer Vision, pp. 548 564. Springer, 2020. (Cited on pages 26 and 27.) Arbel, M., Sutherland, D. J., Bi nkowski, M. a., and Gretton, A. On gradient regularizers for mmd gans. In Bengio, S., Wallach, H., Larochelle, H., Grauman, K., Cesa Bianchi, N., and Garnett, R. (eds.), Advances in Neural Information Processing Systems, volume 31. Curran Associates, Inc., 2018. URL https://proceedings. 
neurips.cc/paper/2018/file/ 07f75d9144912970de5a09f5a305e10c-Paper. pdf. (Cited on page 9.) Arora, S., Ge, R., Liang, Y., Ma, T., and Zhang, Y. Generalization and equilibrium in generative adversarial nets (gans). In International Conference on Machine Learning, pp. 224 232. PMLR, 2017. (Cited on page 5.) Benamou, J.-D., Carlier, G., M erigot, Q., and Oudet, E. Discretization of functionals involving the monge amp ere operator. Numerische mathematik, 134(3):611 636, 2016. (Cited on pages 1 and 2.) Bernton, E. Langevin monte carlo and jko splitting. In Conference On Learning Theory, pp. 1777 1798. PMLR, 2018. (Cited on page 1.) Biewald, L. Experiment tracking with weights and biases, 2020. URL https://www.wandb.com/. Software available from wandb.com. (Cited on page 22.) Bonet, C., Courty, N., Septier, F., and Drumetz, L. Sliced-wasserstein gradient flows. ar Xiv preprint ar Xiv:2110.10972, 2021. (Cited on pages 1, 2, 5, and 22.) Brenier, Y. Polar factorization and monotone rearrangement of vector-valued functions. Communications on pure and applied mathematics, 44(4):375 417, 1991. (Cited on page 3.) Bunne, C., Meng-Papaxanthos, L., Krause, A., and Cuturi, M. Jkonet: Proximal optimal transport modeling of population dynamics. ar Xiv preprint ar Xiv:2106.06345, 2021. (Cited on pages 1, 2, and 3.) Carlier, G., Duval, V., Peyr e, G., and Schmitzer, B. Convergence of entropic schemes for optimal transport and gradient flows. SIAM Journal on Mathematical Analysis, 49(2):1385 1418, 2017. (Cited on pages 1 and 2.) Carrillo, J. A., Craig, K., and Patacchini, F. S. A blob method for diffusion. Calculus of Variations and Partial Differential Equations, 58(2):1 53, 2019a. (Cited on page 2.) Carrillo, J. A., Hittmeir, S., Volzone, B., and Yao, Y. Nonlinear aggregation-diffusion equations: radial symmetry and long time asymptotics. Inventiones mathematicae, 218(3):889 977, 2019b. (Cited on page 20.) Carrillo, J. A., Craig, K., Wang, L., and Wei, C. Primal dual methods for Wasserstein gradient flows. Foundations of Computational Mathematics, pp. 1 55, 2021. (Cited on pages 1, 2, 9, 18, 19, 20, and 21.) Cheng, X. and Bartlett, P. Convergence of langevin mcmc in kl-divergence. In Algorithmic Learning Theory, pp. 186 211. PMLR, 2018. (Cited on page 1.) Eckhardt, R., Ulam, S., and Von Neumann, J. the monte carlo method. Los Alamos Science, 15:131, 1987. (Cited on page 24.) Falcon, W. and Cho, K. A framework for contrastive selfsupervised learning and designing a new approach. ar Xiv preprint ar Xiv:2009.00104, 2020. (Cited on page 22.) Fan, J., Taghvaei, A., and Chen, Y. Scalable computations of Wasserstein barycenter via input convex neural networks. ar Xiv preprint ar Xiv:2007.04462, 2020. (Cited on page 5.) Fan, J., Liu, S., Ma, S., Chen, Y., and Zhou, H. Scalable computation of monge maps with general costs. ar Xiv preprint ar Xiv:2106.03812, 2021. (Cited on page 5.) Folland, G. B. Real analysis: modern techniques and their applications, volume 40. John Wiley & Sons, 1999. (Cited on page 16.) Variational Wasserstein gradient flow Frogner, C. and Poggio, T. Approximate inference with Wasserstein gradient flows. In International Conference on Artificial Intelligence and Statistics, pp. 2581 2590. PMLR, 2020. (Cited on pages 1 and 2.) Gao, Y., Jiao, Y., Wang, Y., Wang, Y., Yang, C., and Zhang, S. Deep generative learning via variational gradient flow. In International Conference on Machine Learning, pp. 2093 2101. PMLR, 2019. (Cited on pages 4 and 9.) Gershman, S., Hoffman, M., and Blei, D. 
Nonparametric variational inference. ar Xiv preprint ar Xiv:1206.4665, 2012. (Cited on page 8.) GI, B. On some unsteady motions of a liquid and gas in a porous medium. Prikl. Mat. Mekh., 16:67 78, 1952. (Cited on page 8.) Goodfellow, I., Pouget-Abadie, J., Mirza, M., Xu, B., Warde-Farley, D., Ozair, S., Courville, A., and Bengio, Y. Generative adversarial nets. Advances in neural information processing systems, 27, 2014. (Cited on page 9.) He, K., Zhang, X., Ren, S., and Sun, J. Deep residual learning for image recognition. In Proceedings of the IEEE conference on computer vision and pattern recognition, pp. 770 778, 2016. (Cited on page 25.) Heusel, M., Ramsauer, H., Unterthiner, T., Nessler, B., and Hochreiter, S. Gans trained by a two time-scale update rule converge to a local nash equilibrium. Advances in neural information processing systems, 30, 2017. (Cited on page 9.) Huang, C.-W., Chen, R. T., Tsirigotis, C., and Courville, A. Convex potential flows: Universal probability distributions with optimal transport and convex optimization. ar Xiv preprint ar Xiv:2012.05942, 2020. (Cited on pages 5 and 22.) Hutchinson, M. F. A stochastic estimator of the trace of the influence matrix for laplacian smoothing splines. Communications in Statistics-Simulation and Computation, 18(3):1059 1076, 1989. (Cited on page 5.) Hwang, H. J., Kim, C., Park, M. S., and Son, H. The deep minimizing movement scheme. ar Xiv preprint ar Xiv:2109.14851, 2021. (Cited on page 2.) Jordan, R., Kinderlehrer, D., and Otto, F. The variational formulation of the Fokker Planck equation. SIAM journal on mathematical analysis, 29(1):1 17, 1998. (Cited on pages 1 and 3.) Kidger, P. and Lyons, T. Universal approximation with deep narrow networks. In Conference on learning theory, pp. 2306 2327. PMLR, 2020. (Cited on page 16.) Kingma, D. P. and Dhariwal, P. Glow: Generative flow with invertible 1x1 convolutions. Advances in neural information processing systems, 31, 2018. (Cited on page 9.) Korotin, A., Egiazarian, V., Asadulaev, A., Safin, A., and Burnaev, E. Wasserstein-2 generative networks. In International Conference on Learning Representations, 2021a. URL https://openreview.net/forum? id=b Eoxz W_EXsa. (Cited on pages 5, 24, and 27.) Korotin, A., Li, L., Solomon, J., and Burnaev, E. Continuous wasserstein-2 barycenter estimation without minimax optimization. In International Conference on Learning Representations, 2021b. URL https://openreview. net/forum?id=3t FAs5E-Pe. (Cited on page 5.) Korotin, A., Selikhanovych, D., and Burnaev, E. Neural optimal transport. Ar Xiv, abs/2201.12220, 2022. (Cited on pages 6 and 16.) Krizhevsky, A., Hinton, G., et al. Learning multiple layers of features from tiny images. 2009. (Cited on page 9.) Le Cun, Y., Bottou, L., Bengio, Y., and Haffner, P. Gradientbased learning applied to document recognition. Proceedings of the IEEE, 86(11):2278 2324, 1998. (Cited on page 9.) Li, W., Lu, J., and Wang, L. Fisher information regularization schemes for Wasserstein gradient flows. Journal of Computational Physics, 416:109449, 2020. (Cited on pages 1 and 2.) Lin, A. T., Li, W., Osher, S., and Mont ufar, G. Wasserstein proximal of gans. In International Conference on Geometric Science of Information, pp. 524 533. Springer, 2021. (Cited on page 1.) Liu, Q. and Wang, D. Stein variational gradient descent: A general purpose bayesian inference algorithm. In Lee, D., Sugiyama, M., Luxburg, U., Guyon, I., and Garnett, R. (eds.), Advances in Neural Information Processing Systems, volume 29. 
Curran Associates, Inc., 2016. URL https://proceedings. neurips.cc/paper/2016/file/ b3ba8f1bee1238a2f37603d90b58898d-Paper. pdf. (Cited on page 8.) Liu, Q., Lee, J., and Jordan, M. A kernelized stein discrepancy for goodness-of-fit tests. In International conference on machine learning, pp. 276 284. PMLR, 2016. (Cited on pages 7 and 23.) Mac Kay, D. J. and Mac Kay, D. J. Information theory, inference and learning algorithms. Cambridge university press, 2003. (Cited on page 24.) Variational Wasserstein gradient flow Makkuva, A., Taghvaei, A., Oh, S., and Lee, J. Optimal transport mapping via input convex neural networks. In International Conference on Machine Learning, pp. 6672 6681. PMLR, 2020. (Cited on pages 5 and 26.) Mika, S., Ratsch, G., Weston, J., Scholkopf, B., and Mullers, K.-R. Fisher discriminant analysis with kernels. In Neural networks for signal processing IX: Proceedings of the 1999 IEEE signal processing society workshop (cat. no. 98th8468), pp. 41 48. Ieee, 1999. (Cited on page 8.) Miyato, T., Kataoka, T., Koyama, M., and Yoshida, Y. Spectral normalization for generative adversarial networks. In International Conference on Learning Representations, 2018. URL https://openreview.net/forum? id=B1QRgzi T-. (Cited on pages 9 and 25.) Mokrov, P., Korotin, A., Li, L., Genevay, A., Solomon, J., and Burnaev, E. Large-scale wasserstein gradient flows. In Thirty-Fifth Conference on Neural Information Processing Systems, 2021. URL https://openreview. net/forum?id=nl Lj Iu Hs MHp. (Cited on pages 1, 2, 3, 5, 6, 7, 8, 18, 21, 22, 23, and 24.) Nguyen, X., Wainwright, M. J., and Jordan, M. I. Estimating divergence functionals and the likelihood ratio by convex risk minimization. IEEE Transactions on Information Theory, 56(11):5847 5861, 2010. (Cited on page 3.) Nowozin, S., Cseke, B., and Tomioka, R. f-gan: Training generative neural samplers using variational divergence minimization. In Proceedings of the 30th International Conference on Neural Information Processing Systems, pp. 271 279, 2016. (Cited on pages 1 and 15.) Otto, F. The geometry of dissipative evolution equations: the porous medium equation. 2001. (Cited on pages 1 and 3.) Peyr e, G. Entropic approximation of Wasserstein gradient flows. SIAM Journal on Imaging Sciences, 8(4):2323 2351, 2015. (Cited on pages 1 and 2.) Ronneberger, O., Fischer, P., and Brox, T. U-net: Convolutional networks for biomedical image segmentation. In International Conference on Medical image computing and computer-assisted intervention, pp. 234 241. Springer, 2015. (Cited on page 25.) Rout, L., Korotin, A., and Burnaev, E. Generative modeling with optimal transport maps. ar Xiv preprint ar Xiv:2110.02999, 2021. (Cited on pages 5 and 27.) Salim, A., Korba, A., and Luise, G. The Wasserstein proximal gradient algorithm. ar Xiv preprint ar Xiv:2002.03035, 2020. (Cited on page 19.) Salimans, T., Karpathy, A., Chen, X., and Kingma, D. P. Pixelcnn++: Improving the pixelcnn with discretized logistic mixture likelihood and other modifications. ar Xiv preprint ar Xiv:1701.05517, 2017. (Cited on page 25.) Santambrogio, F. Euclidean, metric, and wasserstein gradient flows: an overview. Bulletin of Mathematical Sciences, 7(1):87 154, 2017. (Cited on page 1.) Seguy, V., Damodaran, B. B., Flamary, R., Courty, N., Rolet, A., and Blondel, M. Large-scale optimal transport and mapping estimation. ar Xiv preprint ar Xiv:1711.02283, 2017. (Cited on page 26.) Shewchuk, J. R. et al. An introduction to the conjugate gradient method without the agonizing pain, 1994. 
(Cited on page 5.) Sriperumbudur, B. K., Fukumizu, K., Gretton, A., Sch olkopf, B., and Lanckriet, G. R. On the empirical estimation of integral probability metrics. Electronic Journal of Statistics, 6:1550 1599, 2012. (Cited on page 5.) Vatiwutipong, P. and Phewchean, N. Alternative way to derive the distribution of the multivariate ornstein uhlenbeck process. Advances in Difference Equations, 2019(1):1 7, 2019. (Cited on page 8.) V azquez, J. L. The porous medium equation: mathematical theory. Oxford University Press on Demand, 2007. (Cited on pages 3 and 8.) Villani, C. Topics in optimal transportation. Number 58. American Mathematical Soc., 2003. (Cited on page 3.) Wan, N., Li, D., and Hovakimyan, N. f-divergence variational inference. Advances in Neural Information Processing Systems, 33, 2020. (Cited on page 2.) Waskom, M. L. seaborn: statistical data visualization. Journal of Open Source Software, 6(60):3021, 2021. doi: 10.21105/joss.03021. URL https://doi.org/10. 21105/joss.03021. (Cited on page 22.) Wellner, J. A. Empirical processes: Theory and applications. Notes for a course given at Delft University of Technology, 2005. (Cited on page 17.) Wibisono, A. Sampling as optimization in the space of measures: The langevin dynamics as a composite optimization problem. In Conference on Learning Theory, pp. 2093 3027. PMLR, 2018. (Cited on page 1.) Yadan, O. Hydra - a framework for elegantly configuring complex applications. Github, 2019. URL https: //github.com/facebookresearch/hydra. (Cited on page 22.) Variational Wasserstein gradient flow Yang, Z., Zhang, Y., Chen, Y., and Wang, Z. Variational transport: A convergent particle-based algorithm for distributional optimization. ar Xiv preprint ar Xiv:2012.11554, 2020. (Cited on pages 1 and 2.) Zagoruyko, S. and Komodakis, N. Wide residual networks. ar Xiv preprint ar Xiv:1605.07146, 2016. (Cited on page 25.) Variational Wasserstein gradient flow The appendix is structured as follows. In Section A, we provide the proofs of Corollaries in Section 3.1 and the theoretical results in Section 4. In Section B, we give a Crank-Nicolson-typed extension of our method for a faster convergence with respect to the step size a. In Section C, we consider the case where the target functional F(P) involves the interaction energy, and propose to use forward-backward scheme to solve the Wasserstein GF. In Section D, for the sake of completeness, we discuss how to evaluate the probability density of each JKO step Pk. In Section E, we provide additional experimental results and discussions, such as the computational time. In Section F, we provide the training details of experiments other than image generation. In Section 5.5, we provide the training details and discussions of image generation experiment. A.1 Proof of variational formulas in Section 3.1 A.1.1 KL DIVERGENCE The KL divergence is the special instance of the f-divergence obtained by replacing f with f1(x) = x log x in (6) Df1(P Q) = EQ which, according to (7), admits the variational formulation Df1(P Q) = 1 + sup h EP [h(X)] EQ h eh(Z)i (12) where the convex conjugate f 1 (y) = ey 1 and a change of variable h h 1 are used. The variational formulation can be approximated in terms of samples from P and Q. For the case where we have only access to un-normalized density of Q, which is the case for the sampling problem, we use the following change of variable: h log(h) + log(µ) log(Q) where µ is a user designed distribution which is easy to sample from. 
Under such a change of variable, the variational formulation reads Df1(P Q) = 1 + sup h EP log h(X) + log µ(X) Eµ [h(Z)] . Note that the optimal function h is equal to the ratio between the densities of T Pk and µ. Using this variational form in the JKO scheme (5) yields Pk+1 = Tk Pk and Tk = arg min T max h EPk 2a + log h(T(X)) + log µ(T(X)) Eµ [h(Z)] . (13) Based on particle approximation, the implementable JKO is Tk = arg min T max h 1 N " X(k) i T(X(k) i ) 2 2a + log h(T(X(k) i )) + log µ(T(X(k) i )) Q(T(X(k) i )) Eµ [h(Z)] . (14) Remark A.1. The Donsker-Varadhan formula D(P Q) = sup h EP [h(X)] log EQ h eh(Z)i is another variational representation of KL divergence and it s a stronger than (12) because it s a upper bound of (12) for any fixed h. However, we cannot get an unbiased estimation of the objective using samples. A.1.2 GENERALIZED ENTROPY The generalized entropy can be also represented as f-divergence. In particular, let f2(x) = 1 m 1(xm x) and let Q be the uniform distribution on a set which is the superset of the support of density P(x) and has volume Ω. Then Df2(P Q) = Ωm 1 Z P m(x)dx 1 m 1. Variational Wasserstein gradient flow As a result, the generalized entropy can be expressed in terms of f-divergence according to G(P) = 1 m 1 Z P m(x)dx = 1 Ωm 1 Df2(P Q) + 1 Ωm 1(m 1). Upon using the variational representation of the f-divergence with f 2 (y) = (m 1)y + 1 the generalized entropy admits the following variational formulation G(P) = sup h EP [h(X)] EQ " (m 1)h(Z) + 1 + 1 Ωm 1(m 1). In practice, we find it numerically useful to let h = 1 m 1 m ˆh m 1 1 so that G(P) = 1 Ωm 1 sup ˆh m m 1 ˆhm 1(X) EQ h ˆhm(Z) i . With such a change of variable, the optimal function ˆh = T Pk/Q. Using this in the JKO scheme yields Pk+1 = Tk Pk, and Tk = arg min T max h 1 2a EPk X T(X) 2 + 1 Ωm 1 m m 1hm 1(X) EQ [hm(Z)] . A.1.3 JENSEN-SHANNON DIVERGENCE Jensen-Shannon divergence has been widely studied in GAN literature (Nowozin et al., 2016). The variational formula follows that f(x) = (x+1) log((1+x)/2)+x log x and f (y) = log(2 exp(y)). Plugging in the variational formula in the JKO scheme gives Tk = arg min T max h 1 2a EPk X T(X) 2 + EPk [log(1 h(X))] + EQ [log h(Z)] . A.2 Proof of Propostion 4.1 We present the proof of Propostion 4.1. Let us define J(h) := R hd P R f (h)d Q. 1. The proof follows from DH f (P, Q) = sup h H J(h) sup c R J(c) = sup c R {c f (c)} = f(1) = 0 where the last identity follows from the assumption that f(1) = 0. 2. The direction ( ) follows because J(h) Z hd P Z hd Q = 0, h H where f (y) = supx{xy f(x)} y1 f(1) = y is used. As a result, DH f (p Q) = suph H J(h) 0. Using part (1), this is only possible when DH f (P Q) = 0. To show the other direction ( ), for all h H , define g(λ) := J(f (1) + λh) where λ R. The function g(λ) attains its maximum at λ = 0 because g(λ) = J(f (1) + λh) suph H J(h) = DH f (P Q) = 0 and g(0) = J(f (1)) = f (1) f (f (1)) = f(1) = 0 by Fenchel identity. Therefore, the first-order optimality condition g (0) = 0 must hold. The result follows because g (0) = Z hd P Z hf (f (1))d Q = Z hd P Z hd Q Variational Wasserstein gradient flow 3. Let us define gh(λ) := J(f (1) + λh h 2,Q ). The first and the second derivatives of gh(λ) with respect to λ are: g h(λ) = Z h h d P Z h h f (f (1) + λh g h(λ) = Z h2 h 2 f (f (1) + λh By assumption on f, the convex conjugate f is strongly convex with constant 1 L and smooth with constant 1 α. Therefore, 1 α. As a result, 1 α where we used h 2 = R h2d Q. 
Therefore, gh(λ) is strongly concave and smooth and satisfies the inequalities: 2 g h(0)2 sup λ gh(λ) gh(0) L Upon using gh(0) = J(f (1)) = 0 and taking the sup over h H of all sides, 2 sup h H g h(0)2 sup h H sup λ gh(λ) L 2 sup h H g h(0)2. By the assumption that for all h H, c + λh H for c, λ R, sup h H sup λ gh(λ) = sup h H J(h) = DH f (P Q). The result follows by noting that suph H g h(0) = d H(P, Q). A.3 Proof of Proposition 4.3 Proof. For a given P and Q, let h0 = f ( d P d Q) and h H be such that h h0 2,Q ϵ. Similar to the proof of Proposition 4.1, define J(h) = R hd P R f (h)d Q. Then, DH f (P Q) = sup h H J(h) J( h) = J( h) J(h0) + J(h0) = J( h) J(h0) + Df(P Q) where J(h0) = Df(P Q) is used in the last step. The proof follows by showing that J( h) J(h0) 1 2α h h0 2 2,Q. In order to show this, note that f is 1 α smooth because f is α strongly convex. Then, f ( h(x)) f (h0(x)) f (h0(x))( h(x) h0(x)) + 1 2α| h(x) h0(x)|2, x. Taking the expectation over Q and adding R h0d P R hd P yields, J(h0) J( h) Z f (h0)( h h0)d Q + Z (h0 h)d P + 1 2α h h 2 2,Q. Then, the proof follows from f (h0) = f (f ( d P d Q)) = d P d Q to cancel the first two terms. Discussion on neural network function class Consider H is the class of neural nets with an arbitrary depth and mild assumption on the activation function. Following the proof of Theorem 1 in Korotin et al. (2022), we can verify that for any ϵ > 0, compactly supported Q, and function h 2,Q < , there exists a neural net h H such that h h 2,Q ϵ. Indeed, let Q be supported on X Rn, and X be compact, by Folland (1999, Proposition 7.9), the continuous functions C0(X) are dense in L2(Q). Further by Kidger & Lyons (2020, Theorem 3.2), the neural nets in H are dense in C0(X) with respect to L norm, and as such with respect to L2 norm. Putting these two pieces together gives neural nets are dense in L2(Q). Variational Wasserstein gradient flow A.4 Proof of Proposition 4.4 Proof. We first introduce the following notations J(h) = Z hd P Z f (h)d Q JM,N(h) = Z hd P (N) Z f (h)d Q(M), GP (h) = Z hd P Z hd P (N), GQ(h) = Z f (h)d Q Z f (h)d Q(M). Assume the suph H J(h) is attained at h = h and suph H JM,N(h) is attained at h = h M,N. sup h H JM,N(h) sup h H J(h) = JM,N(h M,N) sup h H J(h) JN(h M,N) J(h M,N) sup h H {|GP (h)|} + sup h H {|GQ(h)|}. sup h H J(h) sup h H JM,N(h) = J( h) sup h H JM,N(h) JM,N( h) J( h) sup h H {|GP (h)|} + sup h H {|GQ(h)|}. |DH f (P (N) Q(M)) DH f (P Q)| = | sup h H JM,N(h) sup h H J(h)| sup h H {|GP (h)|} + sup h H {|GQ(h)|}. The result follows by taking the expectation and the symmetrization inequality (Wellner, 2005, Lemma 5.1) to the last two terms E sup h H {|GP (h)|} + E sup h H {|GQ(h)|} 2RN(H, P) + 2RM(f H, Q). It s not difficult to prove the following corollary following the same logic. Corollary A.2. Let P (N) = 1 N PN i=1 δXi, where {Xi}N i=1 are i.i.d samples from P. Then, it follows that E[|DH f (P Q) DH f (P (N) Q)|] 2RN(H, P), where the expectation is over the samples and RN(H, P) denotes the Rademacher complexity of the function class H with respect to P for sample size N. A.5 Proof of Proposition 4.5 Proof. Suppose P0 = µ = N(0, I), Q = N(η, I) and F(P) is the KL divergence D(P Q), we parameterize Tk(x) = x + βk, hk(z) = exp(α k z + γk). Then the closed-form solution of JKO is P k = N(ηk, I) where ηk = η 1 1 (1 + a)k Our method adopts the JKO iteration (14) with the variational formula (2). Since µ is a user-defined Gaussian distribution, it is reasonable to assume Eµ[h(Z)] can be estimated precisely. 
To sample from $P_k$ at the $k$-th JKO step, we draw $N$ fresh particles $\{X^k_i\}_{i=1}^N \sim P_0$ with empirical mean $\bar\eta^k_0 = \frac{1}{N}\sum_{i=1}^N X^k_i$ and push them forward through the maps $T_0, \ldots, T_{k-1}$. We also define $\bar\eta^K_k = \frac{1}{N}\sum_{i=1}^N T_{k-1}\circ\cdots\circ T_0(X^K_i)$. Clearly, $\bar\eta^K_k = \bar\eta^K_0 + \sum_{j=0}^{k-1}\beta_j$ for $1 \le k \le K$. Then the solutions produced by our method are
$$ \beta_k = \frac{a(\eta - \bar\eta^{k+1}_k)}{1+a}, \qquad \alpha_k = \beta_k + \bar\eta^{k+1}_k - \bar\eta^k_k, \qquad \gamma_k = -\alpha_k^\top \bar\eta^k_k - \frac{\|\alpha_k\|^2}{2}. $$
Thus the mean of $P_K$ is $\hat\eta_K = \sum_{j=0}^{K-1}\beta_j$. By a standard calculation, $\hat\eta_K = \eta^\star_K - \varepsilon_N$, where
$$ \varepsilon_N = \frac{a}{1+a}\sum_{j=1}^{K}\frac{\bar\eta^j_0}{(1+a)^{K-j}}. $$
Denote $\Delta := \|\eta\|/(1+a)^K$. By the closed form of the KL divergence between two Gaussians,
$$ D^H(P^\star_K\|Q) = \frac{\|\eta^\star_K - \eta\|^2}{2} = \frac{\Delta^2}{2}. \qquad (15) $$
Denote $\xi_{K,N} := \mathbb{E}[\|\varepsilon_N\|^2] = \big(\tfrac{a}{1+a}\big)^2\tfrac{n}{N}\sum_{j=1}^{K}\tfrac{1}{(1+a)^{2(K-j)}}$. Additionally, by Corollary 3.3, we can derive $D^H(P^{(N)}_K\|Q) = \|\bar\eta^K_K - \eta\|^2/2$, where $\bar\eta^K_K$ is the mean of $P^{(N)}_K$. Thus,
$$ \mathbb{E}[D^H(P^{(N)}_K\|Q)] = \mathbb{E}\big[\|\bar\eta^K_K - \eta\|^2/2\big] = \mathbb{E}\big[\|\bar\eta^K_K - \eta^\star_K\|^2/2 + (\bar\eta^K_K - \eta^\star_K)^\top(\eta^\star_K - \eta) + \|\eta^\star_K - \eta\|^2/2\big] $$
$$ \le \mathbb{E}\big[\|\bar\eta^K_0 - \varepsilon_N\|^2/2\big] + \Delta\sqrt{\mathbb{E}\big[\|\bar\eta^K_0 - \varepsilon_N\|^2\big]} + \Delta^2/2 \le \frac{1}{2}\Big(\sqrt{\tfrac{n}{N}} + \sqrt{\xi_{K,N}}\Big)^2 + \Delta\Big(\sqrt{\tfrac{n}{N}} + \sqrt{\xi_{K,N}}\Big) + \frac{\Delta^2}{2}, \qquad (16) $$
where we used $\bar\eta^K_K - \eta^\star_K = \bar\eta^K_0 - \varepsilon_N$, $\mathbb{E}[\|\bar\eta^K_0\|^2] = n/N$, and $\|\eta^\star_K - \eta\| = \|\eta\|/(1+a)^K = \Delta$. By the triangle inequality and (15), (16), it holds that
$$ \mathbb{E}\big[|D^H(P^\star_K\|Q) - D^H(P^{(N)}_K\|Q)|\big] \le D^H(P^\star_K\|Q) + \mathbb{E}\big[D^H(P^{(N)}_K\|Q)\big] \le \Delta^2 + \Delta\Big(\sqrt{\tfrac{n}{N}} + \sqrt{\xi_{K,N}}\Big) + \frac{1}{2}\Big(\sqrt{\tfrac{n}{N}} + \sqrt{\xi_{K,N}}\Big)^2. $$

B Extension to the Crank-Nicolson scheme

Consider the Crank-Nicolson inspired JKO scheme (Carrillo et al., 2021) below:
$$ P_{k+1} = \arg\min_{P\in\mathcal{P}_{ac}(\mathbb{R}^n)} \; \frac{1}{2a}W_2^2(P, P_k) + \frac{1}{2}\Big(F(P) + \int \frac{\delta F}{\delta P}(P_k)\,dP\Big). $$
The difficulty of implementing this scheme with a neural-network based method is that it requires easy access to the density of $P_k$. The predecessors Mokrov et al. (2021) and Alvarez-Melis et al. (2021) do not have this property, while in our algorithm $P_k \approx h_{k-1}\Gamma_{k-1}$ for $k > 1$. This is because our optimal $h_k$ is equal to, or can be transformed into, the ratio between the densities of $P_{k+1}$ and $\Gamma_k$. Assuming $h$ can learn to approximate $P_{k+1}/\Gamma_k$, our method can be naturally extended to the Crank-Nicolson inspired JKO scheme.

C Extension to the interaction energy functional

In this section, we consider the case where $F(P)$ involves the interaction energy
$$ W(P) := \int\!\!\int W(x - y)P(x)P(y)\,dx\,dy, $$
where $W : \mathbb{R}^n \to \mathbb{R}$ is symmetric, i.e., $W(x) = W(-x)$.

C.1 Forward-Backward (FB) scheme

When $F(P)$ involves the interaction energy $W(P)$, we add an additional forward step to solve the gradient flow:
$$ P_{k+\frac{1}{2}} := \big(I - a\,\nabla_x(W * P_k)\big)\sharp P_k, \qquad (17) $$
$$ P_{k+1} := T_{k+\frac{1}{2}}\sharp P_{k+\frac{1}{2}}, \qquad (18) $$
where $I$ is the identity map and $T_{k+\frac{1}{2}}$ is defined by replacing $k$ with $k+\frac{1}{2}$ in (8). In other words, the first gradient descent step (17) is a forward discretization of the gradient flow and the second JKO step (18) is a backward discretization. The term $\nabla_x(W * P)(x)$ can be written as the expectation $\mathbb{E}_{y\sim P}[\nabla_x W(x - y)]$ and thus can also be approximated with samples. The computational complexity of step (17) is at most $O(N^2)$, where $N$ is the total number of particles to push forward. This scheme has been studied as a discretization of gradient flows and proved to have sublinear convergence to the minimizer of $F(P)$ under some regularity assumptions (Salim et al., 2020). We make this scheme practical by giving a scalable implementation of the JKO step.

Since $W(P)$ can be equivalently written as the expectation $\mathbb{E}_{X,Y\sim P}[W(X - Y)]$, there exists another, non-forward-backward (non-FB), method, i.e., removing the first step and integrating $W(P)$ into a single JKO step:
$$ P_{k+1} = T_k\sharp P_k, \qquad T_k = \arg\min_T \; \Big(\mathbb{E}_{P_k}\|X - T(X)\|^2/2a + \mathbb{E}_{X,Y\sim P_k}[W(T(X) - T(Y))] + \sup_h V(T, h)\Big). $$
In practice, we observe that the FB scheme is more stable and gives more regular results, but converges more slowly than the non-FB scheme. A detailed discussion appears in Appendices C.2 and C.4. A minimal code sketch of one FB iteration is given below.
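For illustration, here is a minimal sketch of one forward-backward iteration (17)-(18). The names `grad_W` (the gradient of the interaction kernel, applied to pairwise differences) and `jko_step` (any particle-based implementation of the backward step, e.g. an inner-max/outer-min loop over $h$ and $T$ in the spirit of (13)-(14)) are placeholders under stated assumptions, not the paper's actual code; the sketch only shows the structure of the scheme.

```python
# Hypothetical sketch of one FB iteration: forward step (17) followed by a JKO step (18).
import torch

def fb_iteration(Xk, grad_W, jko_step, a):
    # Xk: (N, n) particles approximating P_k
    diff = Xk[:, None, :] - Xk[None, :, :]      # pairwise differences x_i - x_j, shape (N, N, n)
    drift = grad_W(diff).mean(dim=1)            # sample estimate of grad_x (W * P_k)(x_i)
    X_half = Xk - a * drift                     # forward step (17): (I - a grad(W * P_k)) # P_k
    return jko_step(X_half)                     # backward step (18) started from P_{k+1/2}
```

The explicit pairwise-difference tensor makes the $O(N^2)$ cost of (17) visible; when $N$ is large, the drift would in practice be computed in blocks.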
Remark C.1. In principle, one can single out the $\log Q$ term from (13) and perform a similar forward step for the potential part, $P_{k+\frac{1}{2}} = \big(I - a\,\nabla_x(-\log Q)\big)\sharp P_k = \big(I + a\,(\nabla_x Q)/Q\big)\sharp P_k$ (Salim et al., 2020), but we do not observe improved performance from doing this in the sampling task.

C.2 Simulated solutions to the aggregation equation

Alvarez-Melis et al. (2021) propose using the neural-network based JKO, i.e. the backward method, to solve (19), parameterizing $T$ as the gradient of an ICNN. In this section, we use two cases to compare the forward and backward methods when $F(P) = W(P)$. This helps explain the performance difference between the FB and non-FB schemes later in Section C.4. We study the gradient flow associated with the aggregation equation
$$ \partial_t P = \nabla\cdot\big(P\,\nabla(W * P)\big), \qquad W : \mathbb{R}^n\to\mathbb{R}. \qquad (19) $$
The forward method is $P_{k+1} := (I - a\,\nabla_x(W * P_k))\sharp P_k$. The backward method, or JKO, is
$$ P_{k+1} := T_k\sharp P_k, \qquad T_k = \arg\min_T \; \frac{1}{2a}\mathbb{E}_{P_k}[\|X - T(X)\|^2] + \mathbb{E}_{X,Y\sim P_k}[W(T(X) - T(Y))]. $$

Example 1. We follow the setting in Carrillo et al. (2021, Section 4.3.1). The interaction kernel is $W(x) = \frac{\|x\|^4}{4} - \frac{\|x\|^2}{2}$, and the initial measure $P_0$ is the Gaussian $N(0, 0.25 I)$. In this case, $\nabla_x(W * P_k)(x)$ becomes $\mathbb{E}_{y\sim P_k}\big[(\|x - y\|^2 - 1)(x - y)\big]$. We use step size $a = 0.05$ for both methods and show the results in Figure 8.

Example 2. We follow the setting in Carrillo et al. (2021, Section 4.2.3). The interaction kernel is $W(x) = \frac{x^2}{2} - \ln|x|$, and the initial measure $P_0$ is $N(0, 1)$. This case admits a unique steady state in closed form; the reader can refer to Alvarez-Melis et al. (2021, Section 5.3) for the backward method performance. As for the forward method, $\nabla_x(W * P_k)(x)$ becomes $\mathbb{E}_{y\sim P_k}\big[(x - y) - \frac{1}{x - y}\big]$. Because the kernel $W$ enforces repulsion near the origin and $P_0$ is concentrated around the origin, $\nabla_x(W * P)$ easily blows up, so the forward method is not suitable for this kind of interaction kernel.

Figure 8: (a) Forward method, $k = 23$, $t = 1.15$; (b) forward method, $k = 200$, $t = 10$; (c) backward method, $k = 23$, $t = 1.15$; (d) backward method, $k = 40$, $t = 2$. The steady state is supported on a ring of radius 0.5. The backward method converges faster to the steady state but is unstable: as $k$ grows large, it cannot keep the regular ring shape and collapses after $k > 50$.

Through the above two examples, we notice that when $\nabla_x(W * P)$ is smooth, the backward method converges faster but is less stable when solving (19). This sheds light on the FB and non-FB scheme performance in Sections C.3 and C.4. However, when $\nabla_x(W * P)$ is badly behaved, as in Example 2, the forward method loses its competitiveness.

C.3 Simulations of the aggregation-diffusion equation with the FB scheme

Figure 9: Histograms of the simulated measures $P_k$ obtained with the FB scheme at different $k$.

We simulate the evolution of solutions to the following aggregation-diffusion equation:
$$ \partial_t P = \nabla\cdot\big(P\,\nabla(W * P)\big) + 0.1\,\Delta P^m, \qquad W(x) = -\frac{e^{-\|x\|^2}}{\pi}. $$
This corresponds to the energy functional $W(P) + 0.1\,G(P)$. There is no explicit closed-form solution for this equation apart from the known singular steady state (Carrillo et al., 2019b), so we only provide qualitative results in Figure 9. We use the same parameters as Carrillo et al. (2021, Section 4.3.3). The initial distribution is the uniform distribution on $[-3, 3]\times[-3, 3]$ and the JKO step size is $a = 0.5$. We utilize the FB scheme to simulate the gradient flow for this equation with $m = 3$ on $\mathbb{R}^2$. With this choice of $W(x)$, $\nabla_x(W * P_k)(x)$ equals $\mathbb{E}_{y\sim P_k}\big[\frac{2}{\pi}e^{-\|x - y\|^2}(x - y)\big]$ in the gradient descent step (17), and we estimate $\nabla_x(W * P_k)$ with $10^4$ samples from $P_k$; the kernel gradients are sketched in code below.
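For concreteness, the kernel gradients used above can be plugged into the forward-step sketch of Section C.1. The helpers below simply transcribe $\nabla W$ for Example 1 and for the Gaussian kernel of this subsection; the argument `d` holds the pairwise differences computed in that sketch, and the function names are illustrative.

```python
# Kernel gradients matching the expectations stated in the text (illustrative helpers).
import math
import torch

def grad_W_example1(d):
    # W(x) = |x|^4/4 - |x|^2/2  =>  grad W(x) = (|x|^2 - 1) x
    return (d.pow(2).sum(-1, keepdim=True) - 1.0) * d

def grad_W_gaussian(d):
    # W(x) = -exp(-|x|^2)/pi  =>  grad W(x) = (2/pi) exp(-|x|^2) x
    return (2.0 / math.pi) * torch.exp(-d.pow(2).sum(-1, keepdim=True)) * d
```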
Throughout the process, the aggregation term $\nabla\cdot(P\,\nabla(W * P))$ and the diffusion term $0.1\,\Delta P^m$ adversarially exert their effects, causing the probability measure to first split into four pulses and eventually converge to a single pulse. Our result aligns well with the finite-difference simulation of Carrillo et al. (2021).

C.4 Simulated solutions to the aggregation-diffusion equation with the non-FB scheme

In Figure 10, we show the non-FB solutions to the aggregation-diffusion equation of Section C.3. The FB scheme is in principle independent of the implementation of the JKO step, but in the following we assume both FB and non-FB are neural-network based methods as discussed in Section 3. The non-FB scheme reads
$$ P_{k+1} = T_k\sharp P_k, \qquad T_k = \arg\min_T \; \frac{1}{2a}\mathbb{E}_{P_k}[\|X - T(X)\|^2] + \mathbb{E}_{X,Y\sim P_k}[W(T(X) - T(Y))] + \sup_h G(T, h), $$
where $G(T, h)$ is represented by the variational formula (10). We use the same step size $a = 0.5$ and the same PDE parameters as in Section C.3.

Figure 10: Histograms of the simulated measures $P_k$ obtained with the non-FB scheme at different $k$.

Comparing the FB scheme results in Figure 9 with the non-FB scheme results in Figure 10, we observe that non-FB converges about 1.5× slower than the finite difference method (Carrillo et al., 2021), while FB converges about 3× slower. This may be because splitting one JKO step into the forward and backward steps removes the effect of the aggregation term from the JKO step, and the diffusion term is too weak to make a difference in the loss. Note that for the first several $k$, both $P_k$ and $Q$ are nearly the same uniform distribution, so $h$ is nearly constant and $T(x)$ exerts little effect in the variational formula of $G(P)$. Another possible reason is that a single forward step for the aggregation term converges more slowly than integrating the aggregation into the backward step, as discussed in Section C.2 and Figure 8. However, FB generates more regular measures: the four pulses given by FB are more symmetric. We speculate this is because the gradient descent step in FB utilizes the geometric structure of $W(x)$ directly, whereas integrating $W(P)$ into the neural-network based JKO loses the geometric meaning of $W(x)$.

D Evaluation of the density

In this section, we assume the solving process does not use the forward-backward scheme, i.e., all probability measures $P_k$ are obtained by performing JKO steps one by one. Otherwise, the map $I - a\,\nabla_x(W * P_k) = I - \mathbb{E}_{y\sim P_k}[\nabla_x W(x - y)]$ includes an expectation term, and it becomes intractable to pull particles back to compute the density. If $T$ is invertible, there exists a standard approach, which we present here for completeness, to evaluate the density of $P_k$ (Alvarez-Melis et al., 2021; Mokrov et al., 2021) through the change-of-variables formula. More specifically, we assume $T$ is parameterized as the gradient of an ICNN $\phi$ that is strictly convex, which guarantees that the gradient $\nabla\phi$ is invertible. To evaluate the density $P_k(x_k)$ at a point $x_k$, we propagate backward through the sequence of maps $T_k = \nabla\phi_k, \ldots, T_1 = \nabla\phi_1$ to get $x_i = T^{-1}_{i+1}\circ T^{-1}_{i+2}\circ\cdots\circ T^{-1}_k(x_k)$. The inverse map $T^{-1}_j = (\nabla\phi_j)^{-1} = \nabla\phi^*_j$ can be obtained by solving the convex optimization
$$ x_{j-1} = \arg\max_{x\in\mathbb{R}^n}\; \langle x, x_j\rangle - \phi_j(x). \qquad (20) $$
Then, by the change-of-variables formula, we obtain
$$ \log[dP_k(x_k)] = \log[dP_0(x_0)] - \sum_{i=1}^{k}\log\big|\nabla^2\phi_i(x_{i-1})\big|, \qquad (21) $$
where $\nabla^2\phi_i(x_{i-1})$ is the Hessian of $\phi_i$ and $|\nabla^2\phi_i(x_{i-1})|$ is its determinant. By iteratively solving (20) and plugging the resulting $x_j$ into (21), we can recover the density $dP_k(x_k)$ at any point.
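A minimal sketch of this density evaluation is given below for a single query point, assuming each potential `phi` maps a point of shape `(n,)` to a scalar (e.g. a strictly convex ICNN) and `phis = [phi_1, ..., phi_k]`; the inner solver, step counts, and learning rate are illustrative choices for (20), not the paper's settings.

```python
# Hypothetical sketch of density evaluation via (20)-(21).
import torch

def invert_grad(phi, y, n_steps=200, lr=0.1):
    # Solve (20): x = argmax <x, y> - phi(x), i.e. minimize the convex phi(x) - <x, y>.
    x = y.clone().requires_grad_(True)
    opt = torch.optim.Adam([x], lr=lr)
    for _ in range(n_steps):
        loss = phi(x) - (x * y).sum()
        opt.zero_grad(); loss.backward(); opt.step()
    return x.detach()

def log_density(xk, phis, log_p0):
    # xk: a single point of shape (n,); log_p0: log-density of P_0.
    x, log_det_sum = xk, 0.0
    for phi in reversed(phis):                              # pull back through T_k^{-1}, ..., T_1^{-1}
        x = invert_grad(phi, x)
        H = torch.autograd.functional.hessian(phi, x)       # Hessian of phi_i at x_{i-1}
        log_det_sum = log_det_sum + torch.logdet(H)
    return log_p0(x) - log_det_sum                          # change-of-variables formula (21)
```

In practice (20) could also be solved with L-BFGS, and, as noted in Appendix F, the paper approximates $\log\det\nabla^2\phi$ with the stochastic estimator of Huang et al. (2020) rather than the exact log-determinant used in this sketch.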
E Additional experiment results and discussions

E.1 Computational time

The forward step (17) takes about 14 seconds to push forward one million points. Outside of learning generative models, assuming each JKO step involves 500 iterations with $J_2 = 3$ and $J_3 = 2$, the training of each JKO step (18) takes around 15 seconds. For learning the image generative model, with $J_2 = 1$ and $J_3 = 5$, the training of each JKO step (18) takes around 20 minutes.

E.2 Learning of the function h

The learning of the function $h$ is crucial because it determines the effectiveness of the variational formula. In our KL divergence and generalized entropy variational formulas, the optimal $h$ is equal to $T\sharp P_k/\Gamma$, which can have a large Lipschitz constant in some high-dimensional applications and thus becomes difficult to approximate. To tackle this issue, we replace $h$ by $\exp(\tilde h) - 1$, so that the optimal $\tilde h$ is $\log(h + 1)$, whose Lipschitz constant is much smaller. We apply this trick in Section 5.1 and observe improved performance. In image tasks, $h$ acts like the discriminator in a GAN. A typical problem in GANs is that the discriminator can become too strong for the generator to keep learning. To avoid this, we add spectral normalization to $h$ so that its Lipschitz constant is bounded by 1.

E.3 Convergence comparison with the same number of JKO steps

In this section, we show the convergence comparison under the constraint of performing the same number of JKO steps for all methods. The result is in Figure 11. We repeat the experiment 5 times with the same global random seeds 1, 2, 3, 4, 5 for all methods. JKO-ICNN shows large variance and instability after longer runs in high dimension. Specifically, we observe that with random seed 2 in dimension 24, JKO-ICNN-d converges for the first 19 JKO steps and then suddenly diverges, causing the occurrence of an extreme point. A similar instability issue is also reported in Bonet et al. (2021, Figure 3). With the same random seeds, over 40 JKO steps, we do not observe this instability issue with our method.

Figure 11: Quantitative comparison of convergence to the GMM under the constraint of performing 40 JKO steps for all methods. We calculate the kernelized Stein discrepancy between the generated distribution and the target distribution.

F Experiment implementation details other than images

Our experiments are conducted on GeForce RTX 3090 or RTX A6000 GPUs. We always make sure that a comparison of training time with other methods is conducted on the same GPU card. Our code is written in PyTorch-Lightning (Falcon & Cho, 2020). We use other python libraries including W&B (Biewald, 2020), hydra (Yadan, 2019), seaborn (Waskom, 2021), etc. We also adopt the code provided by Mokrov et al. (2021) for some experiments. For fast approximation of $\log\det\nabla^2\phi$, we adapt the code provided by Huang et al. (2020) with the default parameters therein. Without further specification, we use the following parameters: the numbers of iterations are $J_1 = 600$, $J_2 = 3$, $J_3 = 1$; the batch size is fixed to $M = 100$; the learning rate is fixed to 0.001; all activation functions are PReLU; $h$ has 3 layers with 16 neurons in each layer; $T$ has 4 layers with 16 neurons in each layer. The transport map $T$ can be parametrized in different ways: we use a residual MLP network in Sections 5.1, 5.2, 5.3, C.2, and C.3, and the gradient of a strongly convex ICNN in Sections 5.4 and C.4.
Except for the image task, the dual test function $h$ is always an MLP with a quadratic or sigmoid activation in the final layer to guarantee that $h$ is positive. The networks $T$ and $h$ in Section 5.5 are chosen to be a U-Net and a standard CNN, respectively.

F.1 Calculation of error criteria

Sampling from a GMM. We estimate the kernelized Stein discrepancy (KSD) following the authors' instructions (Liu et al., 2016). We draw $N$ samples $X_1,\ldots,X_N$ from each method and estimate the KSD as
$$ \mathrm{KSD}(P, Q) = \frac{1}{N(N-1)}\sum_{1\le i\ne j\le N} u_Q(X_i, X_j), $$
$$ u_Q(x, x') = s_Q(x)^\top k(x, x')\,s_Q(x') + s_Q(x)^\top\nabla_{x'}k(x, x') + \nabla_x k(x, x')^\top s_Q(x') + \mathrm{trace}\big(\nabla_{x,x'}k(x, x')\big), $$
$$ s_Q(x) = \nabla_x\log Q(x) = \frac{\nabla_x Q(x)}{Q(x)}. $$
We choose $k$ to be the RBF kernel and use the same bandwidth for all methods. We fix $N = 1\times 10^5$.

OU process. For each method, we draw $5\times 10^5$ samples from $P_t$ and calculate the empirical mean $\tilde\mu_t$ and covariance $\tilde\Sigma_t$. Then we calculate the SymKL between $N(\tilde\mu_t, \tilde\Sigma_t)$ and the exact solution.

Porous media equation. We calculate the density of $P_k$ according to Section D and estimate the SymKL using Monte Carlo according to the instructions in Mokrov et al. (2021).

F.2 Sampling from Gaussian mixture models (Section 5.1)

Two moons. We run $K = 10$ JKO steps with $J_2 = 6$, $J_3 = 1$ inner iterations. $h$ has 5 layers and $T$ has 4 layers.

Table 4: Hyper-parameters in the GMM convergence experiments.
Dimension | ℓ | T width | T depth | h width | h depth | JKO-ICNN width | JKO-ICNN depth
2   | 5   | 8   | 3 | 8   | 3 | 256  | 2
4   | 5   | 32  | 4 | 32  | 3 | 384  | 2
8   | 5   | 32  | 4 | 32  | 4 | 512  | 2
15  | 3   | 64  | 4 | 64  | 4 | 1024 | 2
17  | 3   | 64  | 4 | 64  | 4 | 1024 | 2
24  | 3   | 64  | 5 | 64  | 4 | 1024 | 2
32  | 3   | 64  | 5 | 64  | 4 | 1024 | 2
64  | 2   | 128 | 5 | 128 | 4 | -    | -
128 | 1.5 | 128 | 5 | 128 | 4 | -    | -

Table 5: Bayesian logistic regression accuracy and log-likelihood, full results.
Dataset | # features | dataset size | Accuracy (Ours / JKO-ICNN-d / SVGD) | Log-likelihood (Ours / JKO-ICNN-d / SVGD)
covtype  | 54 | 581012 | 0.753 / 0.75 / 0.75  | -0.528 / -0.515 / -0.515
splice   | 60 | 2991   | 0.84 / 0.845 / 0.85  | -0.38 / -0.36 / -0.355
waveform | 21 | 5000   | 0.785 / 0.78 / 0.765 | -0.455 / -0.485 / -0.465
twonorm  | 20 | 7400   | 0.982 / 0.98 / 0.98  | -0.056 / -0.059 / -0.062
ringnorm | 20 | 7400   | 0.73 / 0.74 / 0.74   | -0.5 / -0.5 / -0.5
german   | 20 | 1000   | 0.67 / 0.67 / 0.65   | -0.59 / -0.6 / -0.6
image    | 18 | 2086   | 0.866 / 0.82 / 0.815 | -0.394 / -0.43 / -0.44
diabetis | 8  | 768    | 0.786 / 0.775 / 0.78 | -0.45 / -0.45 / -0.46
banana   | 2  | 5300   | 0.55 / 0.55 / 0.54   | -0.69 / -0.69 / -0.69

GMM. The means of the Gaussian components are randomly sampled from $\mathrm{Uniform}([-\ell/2, \ell/2]^n)$. $J_3 = 2$. The map $T$ has dropout in each layer with probability 0.04. The learning rate of our method is $1\times 10^{-3}$ for the first 20 JKO steps and $4\times 10^{-4}$ for the last 20 JKO steps. The learning rate of JKO-ICNN is $5\times 10^{-3}$ for the first 20 JKO steps and $2\times 10^{-3}$ for the remaining steps. The batch size is 512 and each JKO step runs 1000 iterations for all methods. The remaining parameters are in Table 4.

F.3 Ornstein-Uhlenbeck process (Section 5.2)

We use nearly all the same hyper-parameters as Mokrov et al. (2021), including the learning rate, hidden layer width, and the number of iterations per JKO step. Specifically, we use a residual feed-forward network, i.e., one without activation functions, as $T$. Both $h$ and $T$ have 2 layers and 64 hidden neurons per layer for all dimensions. We train them for $J_1 = 500$ iterations per JKO step with learning rate 0.005. The batch size is $M = 1000$.

F.4 Bayesian logistic regression (Section 5.3)

As in Mokrov et al. (2021), we use JKO step size $a = 0.1$ and calculate the log-likelihood and accuracy with 4096 random parameter samples. The remaining parameters are in Table 6.
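Returning to the error criteria of Section F.1, the following is a minimal sketch of the U-statistic KSD estimator with the RBF kernel $k(x, x') = \exp(-\|x - x'\|^2/(2\sigma^2))$. The score function `score_q` (returning $s_Q(x) = \nabla_x\log Q(x)$ for a batch of points) and the bandwidth are assumed inputs; the fully pairwise computation shown here would be processed in blocks for $N = 10^5$ samples.

```python
# Hypothetical sketch of the U-statistic KSD estimator with an RBF kernel.
import torch

def ksd_rbf(X, score_q, bandwidth=1.0):
    # X: (N, n) samples from P; score_q(X): (N, n) scores of the target Q.
    N, n = X.shape
    s = score_q(X)
    diff = X[:, None, :] - X[None, :, :]                 # (N, N, n), x - x'
    sq_dist = diff.pow(2).sum(-1)                        # ||x - x'||^2
    h2 = bandwidth ** 2
    K = torch.exp(-sq_dist / (2 * h2))                   # k(x, x')
    grad_x_K = -diff / h2 * K[..., None]                 # grad_x k(x, x')
    grad_xp_K = -grad_x_K                                # grad_{x'} k(x, x')
    trace_term = K * (n / h2 - sq_dist / h2 ** 2)        # trace(grad_x grad_{x'} k)
    u = (s[:, None, :] * s[None, :, :]).sum(-1) * K \
        + (s[:, None, :] * grad_xp_K).sum(-1) \
        + (grad_x_K * s[None, :, :]).sum(-1) \
        + trace_term
    u = u - torch.diag(torch.diag(u))                    # drop the i = j terms
    return u.sum() / (N * (N - 1))
```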
F.5 Porous media equation (Section 5.4)

We use rejection sampling (Eckhardt et al., 1987) to sample from $P_0$ because its computational time is more favorable than that of MCMC methods. However, the rejection sampling acceptance rate is expected to be exponentially small in the dimension (MacKay & MacKay, 2003, Ch. 29.3), and empirically it is intractable when $n > 6$, so we only give results for $n \le 6$. In the experiment, $h$ has 4 layers and 16 neurons in each layer, with CELU activation functions except in the last layer, which is activated by PReLU. To parameterize the map, we adopt the DenseICNN (Korotin et al., 2021a) structure with width 64, depth 2, and rank 1. The batch size is $M = 1024$. Each JKO step runs $J_1 = 1000$ iterations. The learning rate for both $\phi$ and $h$ is $1\times 10^{-3}$. $J_3 = 1$ for dimension 3 and $J_3 = 2$ for dimension 6.

Table 6: Hyper-parameters in the Bayesian logistic regression experiments.
Dataset | K | M | J1 | T width | T depth | h width | h depth | T learning rate | h learning rate
covtype  | 7  | 1024 | 7000 | 128 | 4 | 128 | 3 | 2e-5 | 2e-5
splice   | 50 | 1024 | 400  | 128 | 5 | 128 | 4 | 1e-4 | 1e-4
waveform | 5  | 1024 | 1000 | 32  | 4 | 32  | 4 | 1e-5 | 5e-5
twonorm  | 15 | 512  | 800  | 32  | 4 | 32  | 3 | 1e-3 | 1e-3
ringnorm | 9  | 1024 | 500  | 32  | 4 | 32  | 4 | 1e-5 | 1e-5
german   | 14 | 800  | 640  | 32  | 4 | 32  | 4 | 2e-4 | 2e-4
image    | 12 | 512  | 1000 | 32  | 4 | 32  | 4 | 1e-4 | 1e-4
diabetis | 16 | 614  | 835  | 32  | 4 | 32  | 3 | 1e-4 | 1e-4
banana   | 16 | 512  | 1000 | 16  | 2 | 16  | 2 | 5e-4 | 5e-4

F.6 Aggregation-diffusion equation (Sections C.3 and C.4)

Each JKO step contains $J_1 = 200$ iterations. The batch size is $M = 1000$.

G Image experiment details

G.1 Hyperparameters and network architecture

We use the Adam optimizer with learning rate $2\times 10^{-4}$ and otherwise the default settings of the PyTorch library. We choose $J_2 = 1$, $J_3 = 5$. Our $h$ network follows the architecture of a ResNet classifier (He et al., 2016). More specifically, it uses two downsampling modules, which results in three feature-map resolutions (32×32, 16×16, 8×8). We use two convolutional residual blocks at each resolution and pass the features extracted at the 8×8 resolution into a 2-layer MLP. We use 128 channels for the CNN and 128 hidden neurons for the MLP. Similar to training generative adversarial networks, we found that adding regularizers to the $h$ network helps stabilize training; thus, we apply spectral normalization (Miyato et al., 2018) to the $h$ network. Our framework requires the $T_k$ networks to approximate mappings between data spaces of the same dimension. Their architecture follows the backbone of PixelCNN++ (Salimans et al., 2017), which can be viewed as a modified U-Net (Ronneberger et al., 2015) based on WideResNet (Zagoruyko & Komodakis, 2016). More specifically, we use 3 downsampling and 3 upsampling modules, which results in four feature-map resolutions (32×32, 16×16, 8×8, 4×4), with two convolutional residual blocks at each resolution. We use 64, 128, 256, and 512 channels as the image resolution decreases. Here are more training details: we resize MNIST images to 32×32 resolution so that the $h$ and $T_k$ networks can work on both MNIST and CIFAR10 with only a small modification of the input channels; we use random horizontal flips during training for CIFAR10; we use batch size $M = 128$; on CIFAR10, we use the implementation from torch-fidelity (https://github.com/toshas/torch-fidelity) to calculate FID scores with 50k samples. The JKO step size $a$ controls the divergence between $P_k$ and $P_{k+1}$: we observe that training with a large $a$ suffers from instability and mode collapse, while a small $a$ suffers from slower convergence. We found $a = 5$ works well on both MNIST and CIFAR10.
We use 10 epochs to train each $P_k$, and we notice that $P_{30}$ generates realistic images when $a = 5.0$. The FID score decreases as $k$ increases; we present the change of the FID score of samples from different $P_k$ in Figure 12.

Figure 12: The FID score converges as $k$ increases on the CIFAR10 dataset.

Figure 13: Mode collapse in GANs.

G.2 More comparisons

Comparison with GANs. As we use the Jensen-Shannon divergence in our scheme, JKO-Flow specializes to a Jensen-Shannon GAN when $a \to \infty$ and $K = 1$. However, we found training with $a \to \infty$, $K = 1$ to be very unstable, occasionally suffering from mode collapse. Although training GANs cannot recover the gradient flow from noise to images, it is interesting to compare JKO-Flow and GANs in terms of sampling quality. To make a fair comparison, we instantiate the generator network to be the same as the $T_k$ network and the discriminator to be the same as $h$. We note that this choice is not optimal for GANs, since generators in existing works usually map a lower-dimensional Gaussian noise into images instead of mapping between spaces of the same dimension. We believe the comparison and the JKO-Flow scheme may help future research on modeling mappings between data spaces of the same dimension. As shown in Table 7, JKO-Flow enjoys better sample quality. Empirically, we found that training GANs is more challenging when the latent space is relatively large and the generator network more complex, as mode collapse becomes more common. We find that the additional Wasserstein distance loss in JKO-Flow can be viewed as a regularizer that avoids mode collapse, because $T_k$ receives a large penalty if it maps all inputs to a local minimum. However, one shortcoming of our method is that the JKO-Flow scheme needs to model a sequence of generators instead of a single generator that pushes $P_0$ particles into $Q$, and the small step size controlled by $a$ results in slower convergence and more training time.

Table 7: Comparison between JKO-Flow and various GANs on CIFAR10. The generator and discriminator networks in the GANs follow the same architecture as the $T_k$ and $h$ networks in JKO-Flow.
Method | FID score
GAN (JKO-Flow with $a\to\infty$, $K = 1$) | 80
WGAN-GP | 62.3
SN-GAN | 43.2
JKO-Flow | 23.1

Comparison with more generative models based on gradient flows and optimal transport maps. Most existing works in this line focus on the latent spaces of pre-trained autoencoders (Seguy et al., 2017; An et al., 2019; 2020; Makkuva et al., 2020; Korotin et al., 2021a). This approach reduces the burden of training gradient flows and optimal transport maps, since the tasks of modeling complex image modality and interactions between pixels are partially left to the pre-trained decoders. We note that the recent work of Rout et al. (2021) investigates mappings between distributions located on spaces of the same or unequal dimensionality; however, they only demonstrate an unconditional image generative model based on an embedding from a lower-dimensional Gaussian distribution to image distributions. In contrast, we show that JKO-Flow can learn complex mappings between high-dimensional distributions and achieve encouraging performance when applying such learned mappings to the challenging image generation task without an additional conditioning signal. We include more comparisons in Table 8.

Table 8: More comparisons among generative models on CIFAR10.
Method | FID score | Inception Score
AE-OT (An et al., 2019) | 28.5 | -
AE-OT-GAN (An et al., 2020) | 17.1 | -
OTM (Rout et al., 2021) | 20.69 | 7.41 ± 0.11
JKO-Flow | 23.1 | 7.48 ± 0.12
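As a small illustration of the regularization mentioned in Section G.1, the sketch below shows how spectral normalization can be attached to a convolutional, discriminator-style $h$ network in PyTorch. The architecture is deliberately much smaller than the ResNet classifier used in the paper and is only meant to show where `spectral_norm` enters; all layer sizes are illustrative assumptions.

```python
# Hypothetical, down-scaled h network with spectral normalization on every layer.
import torch.nn as nn
from torch.nn.utils import spectral_norm

def sn_conv(c_in, c_out, stride=1):
    return spectral_norm(nn.Conv2d(c_in, c_out, kernel_size=3, stride=stride, padding=1))

class SmallH(nn.Module):
    def __init__(self, channels=128):
        super().__init__()
        self.features = nn.Sequential(
            sn_conv(3, channels), nn.ReLU(),
            sn_conv(channels, channels, stride=2), nn.ReLU(),   # 32x32 -> 16x16
            sn_conv(channels, channels, stride=2), nn.ReLU(),   # 16x16 -> 8x8
            nn.AdaptiveAvgPool2d(1), nn.Flatten(),
        )
        self.head = nn.Sequential(
            spectral_norm(nn.Linear(channels, channels)), nn.ReLU(),
            spectral_norm(nn.Linear(channels, 1)),
        )

    def forward(self, x):
        return self.head(self.features(x))
```

Spectral normalization caps the spectral norm of each weight matrix, which bounds the Lipschitz constant of $h$ (up to the activations) and plays the discriminator-stabilizing role described above.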
G.3 More generated samples and trajectories

We include more results of JKO-Flow. Figures 14, 15, 16, and 17 show more generated samples from $P_K$ and trajectories from JKO-Flow.

Figure 14: More MNIST samples from JKO-Flow.

Figure 15: More MNIST trajectories from JKO-Flow with $K = 1$ to $K = 30$.

Figure 16: More CIFAR10 samples from JKO-Flow.

Figure 17: More CIFAR10 trajectories from JKO-Flow with $K = 1$ to $K = 30$.