# Large-Scale Wasserstein Gradient Flows

Petr Mokrov, Skolkovo Institute of Science and Technology; Moscow Institute of Physics and Technology, Moscow, Russia, petr.mokrov@skoltech.ru

Alexander Korotin*, Skolkovo Institute of Science and Technology; Artificial Intelligence Research Institute, Moscow, Russia, a.korotin@skoltech.ru

Lingxiao Li, Massachusetts Institute of Technology, Cambridge, Massachusetts, USA, lingxiao@mit.edu

Aude Genevay, Massachusetts Institute of Technology, Cambridge, Massachusetts, USA, aude.genevay@gmail.com

Justin Solomon, Massachusetts Institute of Technology, Cambridge, Massachusetts, USA, jsolomon@mit.edu

Evgeny Burnaev, Skolkovo Institute of Science and Technology; Artificial Intelligence Research Institute, Moscow, Russia, e.burnaev@skoltech.ru

Wasserstein gradient flows provide a powerful means of understanding and solving many diffusion equations. Specifically, Fokker-Planck equations, which model the diffusion of probability measures, can be understood as gradient descent over entropy functionals in Wasserstein space. This equivalence, introduced by Jordan, Kinderlehrer and Otto, inspired the so-called JKO scheme to approximate these diffusion processes via an implicit discretization of the gradient flow in Wasserstein space. Solving the optimization problem associated to each JKO step, however, presents serious computational challenges. We introduce a scalable method to approximate Wasserstein gradient flows, targeted to machine learning applications. Our approach relies on input-convex neural networks (ICNNs) to discretize the JKO steps, which can be optimized by stochastic gradient descent. Unlike previous work, our method does not require domain discretization or particle simulation. As a result, we can sample from the measure at each time step of the diffusion and compute its probability density.
We demonstrate our algorithm's performance by computing diffusions following the Fokker-Planck equation and apply it to unnormalized density sampling as well as nonlinear filtering.

*Equal contribution. 35th Conference on Neural Information Processing Systems (NeurIPS 2021).

## 1 Introduction

Stochastic differential equations (SDEs) are used to model the evolution of random diffusion processes across time, with applications in physics [63], finance [22, 52], and population dynamics [35]. In machine learning, diffusion processes also arise in applications such as filtering [34, 21] and unnormalized posterior sampling via a discretization of the Langevin diffusion [70].

The time-evolving probability density $\rho_t$ of these diffusion processes is governed by the Fokker-Planck equation. Jordan, Kinderlehrer, and Otto [32] showed that the Fokker-Planck equation is equivalent to following the gradient flow of an entropy functional in Wasserstein space, i.e., the space of probability measures with finite second-order moment endowed with the Wasserstein distance. This inspired a simple minimization scheme called the JKO scheme, which consists of an implicit Euler discretization of the Wasserstein gradient flow. However, each step of the JKO scheme is costly, as it requires solving a minimization problem involving the Wasserstein distance.

One way to compute the diffusion is to use a fixed discretization of the domain and apply standard numerical integration methods [18, 49, 15, 17, 40] to get $\rho_t$. For example, [50] proposes a method to approximate the diffusion based on JKO stepping and entropy-regularized optimal transport. However, these methods are limited to small dimensions, since the size of the spatial discretization grows exponentially with the dimension.

An alternative to domain discretization is particle simulation. It involves drawing random samples (particles) from the initial distribution and simulating their evolution via standard methods such as the Euler-Maruyama scheme [36, §9.2].
After convergence, the particles are approximately distributed according to the stationary distribution, but no density estimate is readily available.

Another way to avoid discretization is to parameterize the density of $\rho_t$. Most methods approximate only the first and second moments of $\rho_t$, e.g., via Gaussian approximation. Kalman filtering approaches can then compute the dynamics [34, 39, 33, 61]. More advanced Gaussian mixture approximations [65, 1] or more general parametric families have also been studied [64, 69]. In [48], variational methods are used to minimize the divergence between the predictive and the true density.

Recently, [24] introduced a parametric method to compute JKO steps via entropy-regularized optimal transport. The authors regularize the Wasserstein distance in the JKO step to ensure strict convexity and solve the unconstrained dual problem via stochastic programming on a finite linear subset of basis functions. The method yields an unnormalized probability density without direct sample access.

Recent works propose scalable continuous optimal transport solvers, parametrizing the solutions by reproducing kernels [10], fully-connected neural networks [62], or Input Convex Neural Networks (ICNNs) [37, 44, 38]. In particular, ICNNs gained attention for Wasserstein-2 transport since their gradients $\nabla\psi_\theta : \mathbb{R}^D \to \mathbb{R}^D$ can represent OT maps for the quadratic cost. These continuous solvers scale better to high dimension without discretizing the input measures, but they are too computationally expensive to be applied directly to JKO steps.

Contributions. We propose a scalable parametric method to approximate Wasserstein gradient flows via JKO stepping using input-convex neural networks (ICNNs) [6]. Specifically, we leverage Brenier's theorem to bypass the costly computation of the Wasserstein distance, and parametrize the optimal transport map as the gradient of an ICNN.
Given sample access to the initial measure $\rho^0$, we use stochastic gradient descent (SGD) to sequentially learn time-discretized JKO dynamics of $\rho_t$. The trained model can sample from a continuous approximation of $\rho_t$ and compute its density $\frac{d\rho_t}{dx}(x)$. We compute gradient flows for the Fokker-Planck free energy functional $\mathcal{F}_{FP}$ given by (5), but our method generalizes to other cases. We demonstrate performance by computing diffusion following the Fokker-Planck equation and applying it to unnormalized density sampling as well as nonlinear filtering.

Notation. $\mathcal{P}_2(\mathbb{R}^D)$ denotes the set of Borel probability measures on $\mathbb{R}^D$ with finite second moment. $\mathcal{P}_{2,ac}(\mathbb{R}^D)$ denotes its subset of probability measures absolutely continuous with respect to the Lebesgue measure. For $\rho \in \mathcal{P}_{2,ac}(\mathbb{R}^D)$, we denote by $\frac{d\rho}{dx}(x)$ its density with respect to the Lebesgue measure. $\Pi(\mu, \nu)$ denotes the set of probability measures on $\mathbb{R}^D \times \mathbb{R}^D$ with marginals $\mu$ and $\nu$. For measurable $T : \mathbb{R}^D \to \mathbb{R}^D$, we denote by $T\sharp$ the associated push-forward operator between measures.

## 2 Background on Wasserstein Gradient Flows

We consider gradient flows in Wasserstein space $(\mathcal{P}_2(\mathbb{R}^D), W_2)$, the space of probability measures with finite second moment on $\mathbb{R}^D$ endowed with the Wasserstein-2 metric $W_2$.

Wasserstein-2 distance. The (squared) Wasserstein-2 metric $W_2$ between $\mu, \nu \in \mathcal{P}_2(\mathbb{R}^D)$ is

$$W_2^2(\mu, \nu) \stackrel{\text{def}}{=} \min_{\pi \in \Pi(\mu, \nu)} \int_{\mathbb{R}^D \times \mathbb{R}^D} \|x - y\|_2^2 \, d\pi(x, y), \tag{1}$$

where the minimum is over measures $\pi$ on $\mathbb{R}^D \times \mathbb{R}^D$ with marginals $\mu$ and $\nu$ respectively [68]. For $\mu \in \mathcal{P}_{2,ac}(\mathbb{R}^D)$, there exists a $\mu$-unique map $\nabla\psi^* : \mathbb{R}^D \to \mathbb{R}^D$ that is the gradient of a convex function $\psi^* : \mathbb{R}^D \to \mathbb{R} \cup \{\infty\}$ satisfying $\nabla\psi^*\sharp\mu = \nu$ [46]. From Brenier's theorem [13], it follows that $\pi^* = [\mathrm{id}_{\mathbb{R}^D}, \nabla\psi^*]\sharp\mu$ is the unique minimizer of (1), i.e.,

$$W_2^2(\mu, \nu) = \int_{\mathbb{R}^D} \|x - \nabla\psi^*(x)\|_2^2 \, d\mu(x).$$

Wasserstein Gradient Flows. In the Euclidean case, gradient flows along a function $f : \mathbb{R}^D \to \mathbb{R}$ follow the steepest descent direction and are defined through the ODE $\frac{dx_t}{dt} = -\nabla f(x_t)$. Discretization of this flow leads to the gradient descent minimization algorithm.
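To make the Euclidean picture concrete, the following minimal Python sketch compares two time discretizations of the flow $\frac{dx_t}{dt} = -\nabla f(x_t)$ on an illustrative one-dimensional quadratic $f(x) = a x^2 / 2$. The objective, the step sizes, and all function names are assumptions chosen for this example and do not come from the paper. The explicit Euler update is ordinary gradient descent, while the implicit Euler update solves a small proximal problem at every step; the latter is the Euclidean counterpart of the implicit discretization that the introduction attributes to the JKO scheme.

```python
import math


def grad_f(x, a=4.0):
    """Gradient of the illustrative objective f(x) = a * x^2 / 2."""
    return a * x


def explicit_euler(x0, h, steps, a=4.0):
    """Gradient descent: x_k = x_{k-1} - h * f'(x_{k-1})."""
    x = x0
    for _ in range(steps):
        x = x - h * grad_f(x, a)
    return x


def implicit_euler(x0, h, steps, a=4.0):
    """Proximal step: x_k = argmin_x f(x) + (x - x_{k-1})^2 / (2h).

    For the quadratic f above, setting the derivative to zero gives
    x_k = x_{k-1} / (1 + h * a), so no inner solver is needed here.
    """
    x = x0
    for _ in range(steps):
        x = x / (1.0 + h * a)
    return x


def exact_flow(x0, t, a=4.0):
    """Closed-form solution x_t = x0 * exp(-a * t) of the continuous flow."""
    return x0 * math.exp(-a * t)


if __name__ == "__main__":
    print("exact   :", exact_flow(1.0, 1.0))
    print("explicit:", explicit_euler(1.0, 0.01, 100))
    print("implicit:", implicit_euler(1.0, 0.01, 100))
    # With a large step the explicit scheme blows up, the implicit one does not.
    print("explicit, h=0.6:", explicit_euler(1.0, 0.6, 20))
    print("implicit, h=0.6:", implicit_euler(1.0, 0.6, 20))
```

For a small step both schemes track the exact solution. For $h = 0.6$ the explicit iteration diverges because $|1 - ha| > 1$, whereas the implicit one still contracts, which is one reason implicit schemes are attractive despite the cost of the inner minimization.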
When functionals are defined over the space of measures equipped with the Wasserstein-2 metric, the equivalent flow is called the Wasserstein gradient flow. The idea is similar: the flow follows the steepest descent direction, but this time the notion of gradient is more complex. We refer the reader to [4] for an exposition of gradient flows in metric spaces, or [59, Chapter 8] for an accessible introduction.

A curve of measures $\{\rho_t\}_{t \in \mathbb{R}_+}$ following the Wasserstein gradient flow of a functional $\mathcal{F}$ solves the continuity equation

$$\frac{\partial \rho_t}{\partial t} = \operatorname{div}\big(\rho_t \nabla_x \mathcal{F}'(\rho_t)\big), \quad \text{s.t. } \rho_0 = \rho^0, \tag{2}$$

where $\mathcal{F}'(\cdot)$ is the first variation of $\mathcal{F}$ [4, Theorem 8.3.1]. The term on the right can be understood as the gradient of $\mathcal{F}$ in Wasserstein space, a vector field perturbatively rearranging the mass in $\rho_t$ to yield the steepest possible local change of $\mathcal{F}$.

Wasserstein gradient flows are used in various applied tasks. For example, gradient flows are applied in training [8, 43, 25] or refinement [7] of implicit generative models. In reinforcement learning, gradient flows facilitate policy optimization [55, 72]. Other tasks include crowd motion modelling [45, 58, 50], dataset optimization [2], and in-between animation [26].

Many applications come from the connection between Wasserstein gradient flows and SDEs. Consider an $\mathbb{R}^D$-valued stochastic process $\{X_t\}_{t \in \mathbb{R}_+}$ governed by the following Itô SDE:

$$dX_t = -\nabla\Phi(X_t)\,dt + \sqrt{2\beta^{-1}}\,dW_t, \quad \text{s.t. } X_0 \sim \rho^0, \tag{3}$$

where $\Phi : \mathbb{R}^D \to \mathbb{R}$ is the potential function, $W_t$ is the standard Wiener process, and $\beta > 0$ sets the magnitude of the noise. The solution of (3) is called an advection-diffusion process. The marginal measure $\rho_t$ of $X_t$ at each time satisfies the Fokker-Planck equation with fixed diffusion coefficient:

$$\frac{\partial \rho_t}{\partial t} = \operatorname{div}\big(\nabla\Phi(x)\rho_t\big) + \beta^{-1}\Delta\rho_t, \quad \text{s.t. } \rho_0 = \rho^0. \tag{4}$$

Equation (4) is the Wasserstein gradient flow (2) for $\mathcal{F}$ given by the Fokker-Planck free energy functional [32]

$$\mathcal{F}_{FP}(\rho) = \mathcal{U}(\rho) - \beta^{-1}\mathcal{E}(\rho), \tag{5}$$

where $\mathcal{U}(\rho) = \int_{\mathbb{R}^D} \Phi(x)\,d\rho(x)$ is the potential energy and $\mathcal{E}(\rho) = -\int_{\mathbb{R}^D} \log\frac{d\rho}{dx}(x)\,d\rho(x)$ is the entropy.
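As a small worked example of (5), consider a one-dimensional Gaussian $\rho = \mathcal{N}(m, v)$ and the quadratic potential $\Phi(x) = \frac{a}{2}(x - b)^2$. Then $\mathcal{U}(\rho) = \frac{a}{2}\big(v + (m - b)^2\big)$ and $\mathcal{E}(\rho) = \frac{1}{2}\log(2\pi e v)$, so $\mathcal{F}_{FP}$ is minimized at $m = b$, $v = 1/(\beta a)$, which is the Gibbs measure with density proportional to $\exp(-\beta\Phi)$. The sketch below checks this numerically; the parameter values and function names are illustrative assumptions, not part of the paper.

```python
import math


def free_energy_gaussian(m, v, a=2.0, b=1.0, beta=4.0):
    """F_FP(N(m, v)) = U - E / beta for Phi(x) = a * (x - b)^2 / 2 in 1D."""
    potential_energy = 0.5 * a * (v + (m - b) ** 2)        # U(rho)
    entropy = 0.5 * math.log(2.0 * math.pi * math.e * v)   # E(rho)
    return potential_energy - entropy / beta


def argmin_variance(a=2.0, b=1.0, beta=4.0, grid=20000, v_max=2.0):
    """Brute-force search over the variance with the mean fixed at b."""
    candidates = [v_max * (i + 1) / grid for i in range(grid)]
    return min(candidates, key=lambda v: free_energy_gaussian(b, v, a, b, beta))


# Grid minimizer of the free energy versus the closed-form value 1 / (beta * a).
v_star = argmin_variance()
v_closed_form = 1.0 / (4.0 * 2.0)

if __name__ == "__main__":
    print("grid minimizer :", v_star)
    print("closed form    :", v_closed_form)
```

The two values agree up to the grid resolution, which illustrates why following the gradient flow of $\mathcal{F}_{FP}$ drives a measure towards the stationary distribution of the diffusion.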
As a result, to solve the SDE (3), one may compute the Wasserstein gradient flow of the free energy functional $\mathcal{F}_{FP}$ given by (5).

JKO Scheme. Computing Wasserstein gradient flows is challenging. The closed-form solution is typically unknown, necessitating numerical approximation techniques. Jordan, Kinderlehrer, and Otto proposed a method, later abbreviated as JKO integration, to approximate the dynamics of $\rho_t$ in (2) [32]. It consists of a time-discretization update of the continuous flow given by:

$$\rho^{(k)} \leftarrow \arg\min_{\rho \in \mathcal{P}_2(\mathbb{R}^D)} \Big[ \mathcal{F}(\rho) + \frac{1}{2h} W_2^2(\rho^{(k-1)}, \rho) \Big], \tag{6}$$

where $\rho^{(0)} = \rho^0$ is the initial condition and $h > 0$ is the time-discretization step size. The discrete time gradient flow converges to the continuous one as $h \to 0$, i.e., $\rho^{(k)} \approx \rho_{kh}$. The method was further developed in [4, 60], but performing JKO iterations remains challenging because of the minimization with respect to $W_2$.

A common approach to perform JKO steps is to discretize the spatial domain. For support size $\lesssim 10^6$, (6) can be solved by standard optimal transport algorithms [51]. In dimensions $D \geq 3$, discrete supports can hardly approximate continuous distributions and hence the dynamics of gradient flows. To tackle this issue, [24] propose a stochastic parametric method to approximate the density of $\rho_t$. Their method uses entropy-regularized optimal transport (OT), which is biased.

## 3 Computing Wasserstein Gradient Flows with ICNNs

We now describe our approach to compute Wasserstein gradient flows via JKO stepping with ICNNs.

### 3.1 JKO Reformulation via Optimal Push-forward Maps

Our key idea is to replace the optimization (6) over probability measures by an optimization over convex functions, an idea inspired by [11]. Thanks to Brenier's theorem, for any $\rho \in \mathcal{P}_{2,ac}$ there exists a unique $\rho^{(k-1)}$-measurable gradient $\nabla\psi : \mathbb{R}^D \to \mathbb{R}^D$ of a convex function $\psi$ satisfying $\rho = \nabla\psi\sharp\rho^{(k-1)}$. We set $\rho = \nabla\psi\sharp\rho^{(k-1)}$ and rewrite (6) as an optimization over convex $\psi$:

$$\psi^{(k)} \leftarrow \arg\min_{\text{Convex } \psi} \Big[ \mathcal{F}(\nabla\psi\sharp\rho^{(k-1)}) + \frac{1}{2h} W_2^2(\rho^{(k-1)}, \nabla\psi\sharp\rho^{(k-1)}) \Big]. \tag{7}$$

To proceed to the next step of the JKO scheme, we define $\rho^{(k)} \stackrel{\text{def}}{=} \nabla\psi^{(k)}\sharp\rho^{(k-1)}$.

Since $\rho$ is the pushforward of $\rho^{(k-1)}$ by the gradient of a convex function $\psi$, the $W_2^2$ term in (7) can be evaluated explicitly:

$$\psi^{(k)} \leftarrow \arg\min_{\text{Convex } \psi} \Big[ \mathcal{F}(\nabla\psi\sharp\rho^{(k-1)}) + \frac{1}{2h} \int_{\mathbb{R}^D} \|x - \nabla\psi(x)\|_2^2 \, d\rho^{(k-1)}(x) \Big]. \tag{8}$$

This formulation avoids the difficulty of computing Wasserstein-2 distances. An additional advantage is that we can sample from $\rho^{(k)}$. Since $\rho^{(k)} = [\nabla\psi^{(k)} \circ \dots \circ \nabla\psi^{(1)}]\sharp\rho^0$, one may sample $x_0 \sim \rho^{(0)}$, and then $\nabla\psi^{(k)} \circ \dots \circ \nabla\psi^{(1)}(x_0)$ gives a sample from $\rho^{(k)}$. Moreover, if the functions $\psi^{(\cdot)}$ are strictly convex, then the gradients $\nabla\psi^{(\cdot)}$ are invertible. In this case, the density $\frac{d\rho^{(k)}}{dx}$ of $\rho^{(k)} = [\nabla\psi^{(k)} \circ \dots \circ \nabla\psi^{(1)}]\sharp\rho^0$ is computable by the change of variables formula (assuming the $\psi^{(\cdot)}$ are twice differentiable)

$$\frac{d\rho^{(k)}}{dx}(x_k) = [\det \nabla^2\psi^{(k)}(x_{k-1})]^{-1} \cdots [\det \nabla^2\psi^{(1)}(x_0)]^{-1} \cdot \frac{d\rho^{(0)}}{dx}(x_0), \tag{9}$$

where $x_i = \nabla\psi^{(i)}(x_{i-1})$ for $i = 1, \dots, k$ and $\frac{d\rho^{(0)}}{dx}$ is the density of $\rho^{(0)}$.

### 3.2 Stochastic Optimization for JKO via ICNNs

In general, the solution $\psi^{(k)}$ of (8) is intractable since it requires optimization over all convex functions. To tackle this issue, [11] discretizes the space of convex functions. The approach also requires discretization of the measures $\rho^{(k)}$, limiting this method to small dimensions.

We propose to parametrize the search space using input convex neural networks (ICNNs) [6], which satisfy a universal approximation property among convex functions [20]. ICNNs are parametric models of the form $\psi_\theta : \mathbb{R}^D \to \mathbb{R}$ with $\psi_\theta$ convex w.r.t. the input. ICNNs are constructed from neural network layers, with restrictions on the weights and activation functions to preserve the input-convexity, see [6, §3.1] or [37, §B.2]. The parameters are optimized via deep learning optimization techniques such as SGD.

The JKO step then becomes finding the optimal parameters $\theta^*$ for $\psi_\theta$:

$$\theta^* \leftarrow \arg\min_{\theta} \Big[ \mathcal{F}(\nabla\psi_\theta\sharp\rho^{(k-1)}) + \frac{1}{2h} \int_{\mathbb{R}^D} \|x - \nabla\psi_\theta(x)\|_2^2 \, d\rho^{(k-1)}(x) \Big]. \tag{10}$$

If the functional $\mathcal{F}$ can be estimated stochastically using random batches from $\rho^{(k-1)}$, then SGD can be used to optimize $\theta$. $\mathcal{F}_{FP}$ given by (5) is an example of such a functional:

Theorem 1 (Estimator of $\mathcal{F}_{FP}$). Let $\rho \in \mathcal{P}_{2,ac}(\mathbb{R}^D)$ and $T : \mathbb{R}^D \to \mathbb{R}^D$ be a diffeomorphism. For a random batch $x_1, \dots, x_N \sim \rho$, the expression $[\widehat{\mathcal{U}}_T(x_1, \dots, x_N) - \beta^{-1}\widehat{\Delta\mathcal{E}}_T(x_1, \dots, x_N)]$, where

$$\widehat{\mathcal{U}}_T(x_1, \dots, x_N) \stackrel{\text{def}}{=} \frac{1}{N}\sum_{n=1}^{N} \Phi\big(T(x_n)\big) \quad \text{and} \quad \widehat{\Delta\mathcal{E}}_T(x_1, \dots, x_N) \stackrel{\text{def}}{=} \frac{1}{N}\sum_{n=1}^{N} \log|\det \nabla T(x_n)|,$$

is an estimator of $\mathcal{F}_{FP}(T\sharp\rho)$ up to a constant (w.r.t. $T$) shift given by $\beta^{-1}\mathcal{E}(\rho)$.

Proof. $\widehat{\mathcal{U}}_T$ is a straightforward unbiased estimator for $\mathcal{U}(T\sharp\rho)$. Let $p$ and $p_T$ be the densities of $\rho$ and $T\sharp\rho$. Since $T$ is a diffeomorphism, we have $p_T(y) = p(x) \cdot |\det \nabla T(x)|^{-1}$ where $x = T^{-1}(y)$. Using the change of variables formula, we write

$$\begin{aligned}
\mathcal{E}(T\sharp\rho) &= -\int_{\mathbb{R}^D} p_T(y) \log p_T(y)\,dy \\
&= -\int_{\mathbb{R}^D} p(x)\,|\det \nabla T(x)|^{-1} \log\big[p(x)\,|\det \nabla T(x)|^{-1}\big] \cdot |\det \nabla T(x)|\,dx \\
&= -\int_{\mathbb{R}^D} p(x) \log p(x)\,dx + \int_{\mathbb{R}^D} p(x) \log|\det \nabla T(x)|\,dx \\
&= \mathcal{E}(\rho) + \int_{\mathbb{R}^D} p(x) \log|\det \nabla T(x)|\,dx,
\end{aligned}$$

so that

$$\Delta\mathcal{E}_T(\rho) \stackrel{\text{def}}{=} \mathcal{E}(T\sharp\rho) - \mathcal{E}(\rho) = \int_{\mathbb{R}^D} \log|\det \nabla T(x)|\,d\rho(x),$$

which shows that $\widehat{\Delta\mathcal{E}}_T$ is an unbiased estimator of $\Delta\mathcal{E}_T(\rho)$. As a result, $\widehat{\mathcal{U}}_T - \beta^{-1}\widehat{\Delta\mathcal{E}}_T$ is an estimator for $\mathcal{F}_{FP}(T\sharp\rho) = \mathcal{U}(T\sharp\rho) - \beta^{-1}\mathcal{E}(T\sharp\rho)$ up to a shift of $\beta^{-1}\mathcal{E}(\rho)$.

To apply Theorem 1 to our case, we take $T \leftarrow \nabla\psi_\theta$ and $\rho \leftarrow \rho^{(k-1)}$ to obtain a stochastic estimator for $\mathcal{F}_{FP}(\nabla\psi_\theta\sharp\rho^{(k-1)})$ in (10). Here, $\beta^{-1}\mathcal{E}(\rho^{(k-1)})$ is $\theta$-independent and constant since $\rho^{(k-1)}$ is fixed, so the offset of the estimator plays no role in the optimization w.r.t. $\theta$.

Algorithm 1 details our stochastic JKO method for $\mathcal{F}_{FP}$. The training is done solely based on random samples from the initial measure $\rho^0$: its density is not needed.

Algorithm 1: Fokker-Planck JKO via ICNNs

Input: Initial measure $\rho^0$ accessible by samples; JKO discretization step $h > 0$, number of JKO steps $K > 0$; target potential $\Phi(x)$, diffusion process temperature $\beta^{-1}$; batch size $N$.

Output: trained ICNN models $\{\psi^{(k)}\}_{k=1}^{K}$ representing JKO steps.

for $k = 1, 2, \dots, K$ do
    $\psi_\theta \leftarrow$ basic ICNN model;
    for $i = 1, 2, \dots$ do
        Sample batch $Z \sim \rho^0$ of size $N$;
        $X \leftarrow \nabla\psi^{(k-1)} \circ \dots \circ \nabla\psi^{(1)}(Z)$;
        $\widehat{W_2^2} \leftarrow \frac{1}{N}\sum_{x \in X} \|\nabla\psi_\theta(x) - x\|_2^2$;
        $\widehat{\mathcal{U}} \leftarrow \frac{1}{N}\sum_{x \in X} \Phi\big(\nabla\psi_\theta(x)\big)$;
        $\widehat{\Delta\mathcal{E}} \leftarrow \frac{1}{N}\sum_{x \in X} \log\det\nabla^2\psi_\theta(x)$;
        $\widehat{\mathcal{L}} \leftarrow \frac{1}{2h}\widehat{W_2^2} + \widehat{\mathcal{U}} - \beta^{-1}\widehat{\Delta\mathcal{E}}$;
        Perform a gradient step over $\theta$ by using $\frac{\partial\widehat{\mathcal{L}}}{\partial\theta}$;
    $\psi^{(k)} \leftarrow \psi_\theta$;

This algorithm assumes $\mathcal{F}$ is the Fokker-Planck diffusion energy functional. However, our method admits a straightforward generalization to any $\mathcal{F}$ that can be stochastically estimated; studying such functionals is a promising avenue for future work.

### 3.3 Computing the Density of the Diffusion Process

Our algorithm provides a computable density for $\rho^{(k)}$. As discussed in §3.1, it is possible to sample from $\rho^{(k)}$ while simultaneously computing the density of the samples. However, this approach does not provide a direct way to evaluate $\frac{d\rho^{(k)}}{dx}(x_k)$ for arbitrary $x_k \in \mathbb{R}^D$. We resolve this issue below.

If a convex function is strongly convex, then its gradient is bijective on $\mathbb{R}^D$. By the change of variables formula for $x_k \in \mathbb{R}^D$, it holds that $\frac{d\rho^{(k)}}{dx}(x_k) = \frac{d\rho^{(k-1)}}{dx}(x_{k-1}) \cdot [\det \nabla^2\psi^{(k)}(x_{k-1})]^{-1}$, where $x_k = \nabla\psi^{(k)}(x_{k-1})$. To compute $x_{k-1}$, one needs to solve the convex optimization problem:

$$x_k = \nabla\psi^{(k)}(x_{k-1}) \iff x_{k-1} = \arg\max_{x \in \mathbb{R}^D} \big[ \langle x, x_k \rangle - \psi^{(k)}(x) \big]. \tag{11}$$

If we know the density of $\rho^0$, to compute the density of $\rho^{(k)}$ at $x_k$ we solve $k$ convex problems

$$x_{k-1} = \arg\max_{x \in \mathbb{R}^D} \big[ \langle x, x_k \rangle - \psi^{(k)}(x) \big], \quad \dots, \quad x_0 = \arg\max_{x \in \mathbb{R}^D} \big[ \langle x, x_1 \rangle - \psi^{(1)}(x) \big]$$

to obtain $x_{k-1}, \dots, x_0$ and then evaluate the density as

$$\frac{d\rho^{(k)}}{dx}(x_k) = \frac{d\rho^0}{dx}(x_0) \cdot \prod_{i=1}^{k} \big[\det \nabla^2\psi^{(i)}(x_{i-1})\big]^{-1}.$$

Note that the steps above provide a general method for tracing back the position of a particle along the flow, and density computation is simply a byproduct.

## 4 Experiments

In this section, we evaluate our method on toy and real-world applications. Our code is written in PyTorch and is publicly available at https://github.com/PetrMokrov/Large-Scale-Wasserstein-Gradient-Flows

The experiments are conducted on a GTX 1080Ti. In most cases, we performed several random restarts to obtain the mean and variation of the considered metric. As a result, the experiments require about 100-150 hours of computation.
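Before the individual experiments, the structure of Algorithm 1 and the density formula of §3.3 can be illustrated on a one-dimensional toy problem. The sketch below is not the authors' PyTorch implementation and makes several loud simplifications: the ICNN is replaced by the hypothetical two-parameter convex family $\psi(x) = \frac{s}{2}x^2 + cx$ with $s > 0$, the potential is assumed to be $\Phi(x) = \frac{a}{2}(x - b)^2$, minibatch SGD is replaced by full-batch gradient descent with hand-derived gradients, and the particles are cached and pushed forward after each step instead of being resampled from $\rho^0$. All names and hyperparameters are assumptions made for the example.

```python
import math
import random


def jko_step(xs, h, a, b, beta, lr=0.004, iters=2000):
    """One JKO step restricted to psi(x) = s * x^2 / 2 + c * x with s > 0.

    The map is T(x) = psi'(x) = s * x + c and psi''(x) = s, so the objective
    L(s, c) = mean((T(x) - x)^2) / (2h) + mean(Phi(T(x))) - log(s) / beta
    depends on the particles only through their first two moments.
    """
    n = len(xs)
    m1 = sum(xs) / n
    m2 = sum(x * x for x in xs) / n
    s, c = 1.0, 0.0  # start from the identity map
    for _ in range(iters):
        grad_s = (((s - 1.0) * m2 + c * m1) / h
                  + a * (s * m2 + (c - b) * m1)
                  - 1.0 / (beta * s))
        grad_c = ((s - 1.0) * m1 + c) / h + a * (s * m1 + c - b)
        s = max(s - lr * grad_s, 1e-6)  # keep psi strictly convex
        c = c - lr * grad_c
    return s, c


def run_flow(n_particles=2000, steps=40, h=0.1, a=1.0, b=1.0, beta=1.0, seed=0):
    """Run K JKO steps starting from rho^0 = N(0, 16); return particles and maps."""
    rng = random.Random(seed)
    xs = [rng.gauss(0.0, 4.0) for _ in range(n_particles)]
    maps = []
    for _ in range(steps):
        s, c = jko_step(xs, h, a, b, beta)
        maps.append((s, c))
        xs = [s * x + c for x in xs]  # push the particles forward
    return xs, maps


def log_density0(x):
    """Log-density of the assumed initial measure N(0, 16)."""
    return -0.5 * x * x / 16.0 - 0.5 * math.log(2.0 * math.pi * 16.0)


def log_density(x_k, maps, log_rho0):
    """Trace x_k back through the maps, then apply the change of variables."""
    log_det = 0.0
    x = x_k
    for s, c in reversed(maps):
        x = (x - c) / s          # invert T(x) = s * x + c
        log_det += math.log(s)   # log psi''(x) for this step
    return log_rho0(x) - log_det


particles, maps = run_flow()
```

In this restricted setting the particles should approach the stationary Gaussian $\mathcal{N}(b, 1/(\beta a))$, and `log_density` mirrors the backward pass of §3.3, with the inversion of each map available in closed form instead of through the convex problem (11).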
Details of the experimental setup are given in Appendix A.

Neural network architectures. In all experiments, we use the DenseICNN [37, Appendix B.2] architecture for $\psi_\theta$ in Algorithm 1 with SoftPlus activations. The network $\psi_\theta$ is twice differentiable w.r.t. the input $x$ and has bijective gradient $\nabla\psi_\theta : \mathbb{R}^D \to \mathbb{R}^D$ with positive semi-definite Hessian $\nabla^2\psi_\theta(x) \succeq 0$ at each $x$. We use automatic differentiation to compute $\nabla\psi_\theta$ and $\nabla^2\psi_\theta$.

Metric. To quantitatively compare measures, we use the symmetric Kullback-Leibler divergence

$$\text{SymKL}(\rho_1, \rho_2) \stackrel{\text{def}}{=} \text{KL}(\rho_1 \| \rho_2) + \text{KL}(\rho_2 \| \rho_1), \tag{12}$$

where $\text{KL}(\rho_1 \| \rho_2) \stackrel{\text{def}}{=} \int_{\mathbb{R}^D} \log\frac{d\rho_1}{d\rho_2}(x)\,d\rho_1(x)$ is the Kullback-Leibler divergence. For particle-based methods, we obtain an approximation of the distribution by kernel density estimation.

### 4.1 Convergence to Stationary Solution

Starting from an arbitrary initial measure $\rho^0$, an advection-diffusion process (4) converges to the unique stationary solution $\rho^*$ [56] with density

$$\frac{d\rho^*}{dx}(x) = Z^{-1}\exp(-\beta\Phi(x)), \tag{13}$$

where $Z = \int_{\mathbb{R}^D} \exp(-\beta\Phi(x))\,dx$ is the normalization constant. This property makes it possible to compute the symmetric KL between the distribution to which our method converges and the ground truth, provided $Z$ is known.

Figure 1: SymKL between the computed and the stationary measure in dimensions $D = 2, 4, \dots, 12$ (x-axis: dimension $D$; y-axis: $\log_{10}$ SymKL; curves: [EM] 1K, [EM] 10K, [EM] 50K, Ours).

We use $\mathcal{N}(0, 16 I_D)$ as the initial measure $\rho^0$ and a random Gaussian mixture as the stationary measure $\rho^*$. In our method, we perform $K = 40$ JKO steps with step size $h = 0.1$. We compare with a particle simulation method (with $10^3$, $10^4$, $10^5$ particles) based on the Euler-Maruyama [EM] approximation [36, §9.2]. We repeat the experiment 5 times and report the averaged results in Figure 1. In Figure 2, we present qualitative results of our method converging to the ground truth in $D = 13, 32$.
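The convergence statement behind (13) can be checked on a small example with the particle baseline. The sketch below simulates the SDE (3) with the Euler-Maruyama scheme in one dimension and compares the second moment of the particles with the second moment of the stationary density (13), obtained by numerical quadrature. The double-well potential $\Phi(x) = x^4/4 - x^2/2$, the step size, the particle count, and the function names are illustrative assumptions, not the settings used in the paper.

```python
import math
import random


def grad_potential(x):
    """Phi'(x) for the illustrative double-well Phi(x) = x^4 / 4 - x^2 / 2."""
    return x ** 3 - x


def euler_maruyama(n_particles=1500, n_steps=400, dt=0.01, beta=1.0, seed=0):
    """Simulate dX = -Phi'(X) dt + sqrt(2 / beta) dW with Euler-Maruyama."""
    rng = random.Random(seed)
    noise_scale = math.sqrt(2.0 * dt / beta)
    xs = [rng.gauss(0.0, 2.0) for _ in range(n_particles)]  # X_0 ~ N(0, 4)
    for _ in range(n_steps):
        xs = [x - grad_potential(x) * dt + noise_scale * rng.gauss(0.0, 1.0)
              for x in xs]
    return xs


def stationary_second_moment(beta=1.0, lo=-6.0, hi=6.0, n=4001):
    """E[x^2] under the density Z^{-1} exp(-beta * Phi(x)), trapezoidal rule."""
    step = (hi - lo) / (n - 1)
    z = 0.0
    second = 0.0
    for i in range(n):
        x = lo + i * step
        w = math.exp(-beta * (x ** 4 / 4.0 - x ** 2 / 2.0))
        if i in (0, n - 1):
            w *= 0.5  # end-point weights of the trapezoidal rule
        z += w * step
        second += x * x * w * step
    return second / z


if __name__ == "__main__":
    particles = euler_maruyama()
    empirical = sum(x * x for x in particles) / len(particles)
    print("Euler-Maruyama E[x^2]:", empirical)
    print("stationary     E[x^2]:", stationary_second_moment())
```

The two moments agree up to Monte Carlo and time-discretization error. Note that the particles alone give no density value; a kernel density estimate, as mentioned above, would be needed to evaluate SymKL.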
Figure 2: Projections onto the first 2 PCA components of the true stationary measure and the measure approximated by our method in dimensions (a) $D = 13$ (left) and (b) $D = 32$ (right). Each panel shows the stationary measure next to the fitted measure (ours).

### 4.2 Modeling Ornstein-Uhlenbeck Processes

Ornstein-Uhlenbeck processes are advection-diffusion processes (4) with $\Phi(x) = \frac{1}{2}(x - b)^T A (x - b)$ for symmetric positive definite $A \in \mathbb{R}^{D \times D}$ and $b \in \mathbb{R}^D$. They are among the few examples where we know $\rho_t$ for any $t \in \mathbb{R}_+$ in closed form, when the initial measure $\rho^0$ is Gaussian [67]. This allows us to quantitatively evaluate the computed dynamics of the process, not just the stationary measure.

We choose $A$, $b$ at random and set $\rho^0$ to be the standard Gaussian measure $\mathcal{N}(0, I_D)$. We approximate the dynamics of the process by our method with JKO step $h = 0.05$ and compute SymKL between the true $\rho_t$ and the approximate one at times $t = 0.5$ and $t = 0.9$. We repeat the experiment 15 times in dimensions $D = 1, 2, \dots, 12$ and report the performance in Figure 3. The baselines are [EM] with $10^3$, $10^4$, $5 \times 10^4$ particles, EM particle simulation endowed with the Proximal Recursion operator [EM PR] with $10^4$ particles [16], and the parametric dual inference method [24] for JKO steps [Dual JKO]. The detailed comparison for times $t = 0.1, 0.2, \dots, 1$ is given in Appendix C.

Figure 3: SymKL values between the computed measure and the true measure $\rho_t$ at (a) $t = 0.5$ (left) and (b) $t = 0.9$ (right) in dimensions $D = 1, 2, \dots, 12$ (x-axis: dimension $D$; y-axis: $\log_{10}$ SymKL; curves: [Dual JKO], [EM] 1K, [EM] 10K, [EM] 50K, [EM PR] 10K, Ours). Best viewed in color.
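For intuition about the ground truth used in §4.2, the one-dimensional case can be written out by hand. With $\Phi(x) = \frac{a}{2}(x - b)^2$ and $\rho^0 = \mathcal{N}(m_0, v_0)$, the SDE (3) gives $\rho_t = \mathcal{N}(m_t, v_t)$ with $m_t = b + (m_0 - b)e^{-at}$ and $v_t = v_0 e^{-2at} + \frac{1 - e^{-2at}}{\beta a}$, and the symmetric KL (12) between two Gaussians has a closed form as well. The sketch below implements both; it is a scalar illustration with assumed parameter values and function names, not the multivariate evaluation code of the paper.

```python
import math


def ou_marginal(t, m0, v0, a, b, beta):
    """Mean and variance of X_t for dX = -a (X - b) dt + sqrt(2 / beta) dW."""
    decay = math.exp(-a * t)
    mean = b + (m0 - b) * decay
    var = v0 * decay ** 2 + (1.0 - decay ** 2) / (beta * a)
    return mean, var


def kl_gauss(m1, v1, m2, v2):
    """KL(N(m1, v1) || N(m2, v2)) for scalar Gaussians with variances v1, v2."""
    return 0.5 * (math.log(v2 / v1) + (v1 + (m1 - m2) ** 2) / v2 - 1.0)


def sym_kl_gauss(m1, v1, m2, v2):
    """Symmetric KL divergence, i.e. the sum of both directed divergences."""
    return kl_gauss(m1, v1, m2, v2) + kl_gauss(m2, v2, m1, v1)


if __name__ == "__main__":
    a, b, beta = 2.0, 1.0, 1.0
    stationary = (b, 1.0 / (beta * a))
    for t in (0.0, 0.5, 0.9, 5.0):
        m_t, v_t = ou_marginal(t, 0.0, 1.0, a, b, beta)
        print(t, m_t, v_t, sym_kl_gauss(m_t, v_t, *stationary))
```

The printed SymKL to the stationary measure decreases with $t$, which matches the convergence behaviour discussed in §4.1; replacing `stationary` by the moments of an approximate $\rho_t$ would give the kind of error value plotted in Figure 3, here only for $D = 1$.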
### 4.3 Unnormalized Posterior Sampling in Bayesian Logistic Regression

An important task in Bayesian machine learning to which our algorithm can be applied is sampling from an unnormalized posterior distribution. Given the model parameters $x \in \mathbb{R}^D$ with the prior distribution $p_0(x)$ as well as the conditional density $p(S|x) = \prod_{m=1}^{M} p(s_m|x)$ of the data $S = \{s_1, \dots, s_M\}$, the posterior distribution is given by

$$p(x|S) = \frac{p(S|x)\,p_0(x)}{p(S)} \propto p(S|x)\,p_0(x) = p_0(x) \cdot \prod_{m=1}^{M} p(s_m|x).$$

Computing the normalization constant $p(S)$ is in general intractable, underscoring the need for estimation methods that sample from $p(x|S)$ given the density only up to a normalizing constant.

| Dataset | Accuracy (Ours) | Accuracy (SVGD) | Log-Likelihood (Ours) | Log-Likelihood (SVGD) |
|---|---|---|---|---|
| covtype | 0.75 | 0.75 | -0.515 | -0.515 |
| german | 0.67 | 0.65 | -0.6 | -0.6 |
| diabetis | 0.775 | 0.78 | -0.45 | -0.46 |
| twonorm | 0.98 | 0.98 | -0.059 | -0.062 |
| ringnorm | 0.74 | 0.74 | -0.5 | -0.5 |
| banana | 0.55 | 0.54 | -0.69 | -0.69 |
| splice | 0.845 | 0.85 | -0.36 | -0.355 |
| waveform | 0.78 | 0.765 | -0.485 | -0.465 |
| image | 0.82 | 0.815 | -0.43 | -0.44 |

Table 1: Comparison of our method with SVGD [42] for Bayesian logistic regression.

In our context, sampling from $p(x|S)$ can be solved similarly to the task in §4.1. From (13), it follows that the advection-diffusion process with temperature $\beta > 0$ and $\Phi(x) = -\frac{1}{\beta}\log\big[p_0(x) \cdot p(S|x)\big]$ has $\frac{d\rho^*}{dx}(x) = p(x|S)$ as the stationary distribution. Thus, we can use our method to approximate the diffusion process and obtain a sampler for $p(x|S)$ as a result.

The potential energy $\mathcal{U}(\rho) = \int_{\mathbb{R}^D} \Phi(x)\,d\rho(x)$ can be estimated efficiently by using a trick similar to the ones in stochastic gradient Langevin dynamics [70], which consists of resampling samples in $S$ uniformly.

For evaluation, we consider the Bayesian logistic regression setup of [42]. We use the 8 datasets from [47]. The number of features ranges from 2 to 60 and the dataset size from 700 to 7400 data points. We also use the Covertype dataset² with 500K data points and 54 features.
The prior on regression weights $w$ is given by $p_0(w|\alpha) = \mathcal{N}(w|0, \alpha^{-1})$ with $p_0(\alpha) = \text{Gamma}(\alpha|1, 0.01)$, so the prior on parameters $x = [w, \alpha]$ of the model is given by $p_0(x) = p_0(w, \alpha) = p_0(w|\alpha) \cdot p_0(\alpha)$. We randomly split each dataset into train $S_{\text{train}}$ and test $S_{\text{test}}$ parts with ratio 4:1 and apply the inference on the posterior $p(x|S_{\text{train}})$. In Table 1, we report accuracy and log-likelihood of the predictive distribution on $S_{\text{test}}$. As the baseline, we use particle-based Stein Variational Gradient Descent [42]. We use the authors' implementation with the default hyper-parameters.

### 4.4 Nonlinear Filtering

We demonstrate the application of our method to filtering a nonlinear diffusion. In this task, we consider a diffusion process $X_t$ governed by the Fokker-Planck equation (4). At times $t_1 < t_2 < \dots < t_K$ we obtain noisy observations of the process $Y_k = X_{t_k} + v_k$ with $v_k \sim \mathcal{N}(0, \sigma)$. The goal is to compute the predictive distribution $p_{t,X}(x|Y_{1:K})$ for $t \geq t_K$ given observations $Y_{1:K}$.

For each $k$ and $t \geq t_k$, the predictive distribution $p_{t,X}(x|Y_{1:k})$ follows the diffusion process on the time interval $[t_k, t]$ with initial distribution $p_{t_k,X}(x|Y_{1:k})$. If $t_k = t$ then

$$p_{t_k,X}(x|Y_{1:k}) \propto p(Y_k|X_{t_k} = x) \cdot p_{t_k,X}(x|Y_{1:k-1}). \tag{14}$$

For $k = 1, \dots, K$, we sequentially obtain the predictive distribution $p_{t_k,X}(x|Y_{1:k})$ by using the previous predictive distribution $p_{t_{k-1},X}(x|Y_{1:k-1})$. First, given access to $p_{t_{k-1},X}(x|Y_{1:k-1})$, we approximate the diffusion on the interval $[t_{k-1}, t_k]$ with initial distribution $p_{t_{k-1},X}(x|Y_{1:k-1})$ by our Algorithm 1 to get access to $p_{t_k,X}(x|Y_{1:k-1})$. Next, we use (14) to get the unnormalized density and the Metropolis-Hastings algorithm [57] to sample from $p_{t_k,X}(x|Y_{1:k})$. We give details in Appendix B.

For evaluation, we consider the experimental setup of [24, §6.3]. We assume that the 1-dimensional diffusion process $X_t$ has potential function $\Phi(x) = \frac{1}{\pi}\sin(2\pi x) + \frac{1}{4}x^2$, which makes the process highly nonlinear. We simulate nonlinear filtering on the time interval $t_{\text{start}} = 0$ sec., $t_{\text{fin}} = 5$ sec. and take noisy observations every 0.5 sec. The noise variance is $\sigma^2 = 1$ and $p(X_0) = \mathcal{N}(X_0|0, 1)$. We predict the conditional density $p_{t_{\text{fin}},X}(x|Y_{1:9})$ and compare the prediction with the ground truth obtained with the numerical integration method by Chang and Cooper [19], who use a fine discrete grid. As the baselines, we use [Dual JKO] [24] as well as the Bayesian Bootstrap filter [BBF] [27], which combines particle simulation with bootstrap resampling at observation times.

We repeat the experiment 15 times. In Figure 4a, we report the SymKL between the predicted density and the true $p(X_{t_{\text{fin}}}|Y_{1:9})$. We visually compare the fitted and true conditional distributions in Figure 4b.

Figure 4: Comparison of the predicted conditional density and the true $p(X_{t_{\text{fin}}}|Y_{1:9})$ at $t = 5$ sec. (a) SymKL values ($\log_{10}$ SymKL for Ours, [Dual JKO], [BBF]). (b) Visualized probability density functions $p_{t,X}(x|Y_{1:K})$ (Ours, [Dual JKO], [BBF] 50000, ground truth).

² https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary.html

## 5 Discussion

Complexity of training and sampling. Let $T$ be the number of operations required to evaluate the ICNN $\psi_\theta(x)$, and assume that the evaluation of $\Phi(x)$ in the potential energy $\mathcal{U}$ takes $O(1)$ time.

| Operation | Time Complexity |
|---|---|
| Eval. $\psi_\theta$, $\nabla\psi_\theta$, $\nabla^2\psi_\theta$ | $T$, $O(T)$, $O(DT)$ |
| Eval. $\log\det\nabla^2\psi_\theta$ | $O(DT + D^3)$ |
| Sample $x \sim \rho^{(k)}$ | $O((k-1)T)$ |
| Eval. $\widehat{\mathcal{L}}$ on $x \sim \rho^{(k)}$ | $O(DT + D^3)$ |
| Eval. $\frac{\partial\widehat{\mathcal{L}}}{\partial\theta}$ on $x \sim \rho^{(k)}$ | $O(DT + D^3)$ |
| Sample $x \sim \rho^{(k)}$ and Eval. $\frac{d\rho^{(k)}}{dx}(x)$ | $O\big((k-1)(TD + D^3)\big)$ |

Table 2: Complexity of operations in our method for computing JKO steps via ICNNs.

Recall that computing the gradient is a small constant factor harder than computing the function itself [41]. Thus, evaluation of $\nabla\psi_\theta(x) : \mathbb{R}^D \to \mathbb{R}^D$ requires $O(T)$ operations and evaluating the Hessian $\nabla^2\psi_\theta(x) : \mathbb{R}^D \to \mathbb{R}^{D \times D}$ takes $O(DT)$ time. To compute $\log\det\nabla^2\psi_\theta(x)$, we need $O(D^3)$ extra operations.
Sampling from $\rho^{(k-1)} = [\nabla\psi^{(k-1)} \circ \dots \circ \nabla\psi^{(1)}]\sharp\rho^0$ involves pushing $x_0 \sim \rho^0$ forward by a sequence of ICNNs $\psi^{(\cdot)}$ of length $k - 1$, requiring $O((k-1)T)$ operations. The forward pass to evaluate the JKO step objective $\widehat{\mathcal{L}}$ in Algorithm 1 requires $O(DT + D^3)$ operations, as does the backward pass to compute the gradient $\frac{\partial\widehat{\mathcal{L}}}{\partial\theta}$ w.r.t. $\theta$.

The memory complexity is more difficult to characterize, since it depends on the autodiff implementation. It does not exceed the time complexity and is linear in the number of JKO steps $k$.

Wall-clock times. All particle-based methods considered in §4 and [Dual JKO] require from several seconds to several minutes of CPU computation time. Our method requires from several minutes to a few hours on a GPU; this time is explained by the need to train a new network at each step.

Advantages. Due to using a continuous approximation, our method scales well to high dimensions, as we show in §4.1 and §4.2. After training, we can produce infinitely many samples $x_k \sim \rho^{(k)}$, together with their trajectories $x_{k-1}, x_{k-2}, \dots, x_0$ along the gradient flow. Moreover, the densities of samples in the flow $\frac{d\rho^{(k)}}{dx}(x_k), \frac{d\rho^{(k-1)}}{dx}(x_{k-1}), \dots, \frac{d\rho^{(0)}}{dx}(x_0)$ can be evaluated immediately. In contrast, particle-based and domain discretization methods do not scale well with the dimension (Figure 3) and provide no density. Interestingly, despite its parametric approximation, [Dual JKO] performs comparably to particle simulation and worse than ours (see additionally [24, Figure 3]).

Limitations. To train $k$ JKO steps, our method requires time proportional to $k^2$ due to the increased complexity of sampling $x \sim \rho^{(k)}$. This may be disadvantageous for training long diffusions. In addition, for very high dimensions $D$, exact evaluation of $\log\det\nabla^2\psi_\theta(x)$ is time-consuming.

Future work. To reduce the computational complexity of sampling from $\rho^{(k)}$, at step $k$ one may regress an invertible network $H : \mathbb{R}^D \to \mathbb{R}^D$ [9, 31] to satisfy $H(x_0) \approx \nabla\psi^{(k)} \circ \dots \circ \nabla\psi^{(1)}(x_0)$ and use $H\sharp\rho^0 \approx \rho^{(k)}$ to simplify sampling.
An alternative is to use variational inference [12, 54, 71] to approximate $\rho^{(k)}$. To mitigate the computational complexity of computing $\log\det\nabla^2\psi_\theta(x)$, fast approximations can be used [66, 28]. More broadly, developing ICNNs with easily-computable exact Hessians is a critical avenue for further research as ICNNs continue to gain attention in machine learning [44, 37, 38, 30, 23, 5].

Potential impact. Diffusion processes appear in numerous scientific and industrial applications, including machine learning, finance, physics, and population dynamics. Our method will improve models in these areas, providing better scalability. Performance, however, might depend on the expressiveness of the ICNNs, pointing to theoretical convergence analysis as a key topic for future study to reinforce confidence in our model.

In summary, we develop an efficient method to model diffusion processes arising in many practical tasks. We apply our method to common Bayesian tasks such as unnormalized posterior sampling (§4.3) and nonlinear filtering (§4.4). Below we mention several other potential applications:

Population dynamics. In this task, one needs to recover the potential energy $\Phi(x)$ included in the Fokker-Planck free energy functional $\mathcal{F}_{FP}$ based on samples from the diffusion obtained at timesteps $t_1, \dots, t_n$, see [29]. This setting can be found in computational biology, see §6.3 of [29]. A recent paper [14] utilizes ICNN-powered JKO to model population dynamics.

Reinforcement learning. Wasserstein gradient flows provide a theoretically-grounded way to optimize an agent policy in reinforcement learning, see [55, 72]. The idea of the method is to maximize the expected total reward (see (10) in [72]) using the gradient flow associated with the Fokker-Planck functional (see (12) in [72]). The authors of the original paper proposed a discrete particle approximation method to solve the underlying JKO scheme.
Substituting their approach with our ICNN-based JKO can potentially improve the results.

Refining Generative Adversarial Networks. In the GAN setting, given a trained generator $G$ and discriminator $D$, one can improve the samples from $G$ by $D$ via considering a gradient flow w.r.t. an entropy-regularized $f$-divergence between the real and generated data distributions (see [7], in particular, formula (4) for reference). Using the KL-divergence makes the gradient flow consistent with our method: the functional $\mathcal{F}$ defining the flow has only entropic and potential energy terms. Using our method instead of particle simulation may improve the generator model.

Molecular Discovery. In [3], in parallel to our work, the JKO-ICNN scheme is proposed. The authors consider molecular discovery as an application. The task is to increase the drug-likeness of a given distribution $\rho$ of molecules while staying close to the original distribution $\rho_0$. The task reduces to optimizing the functional $\mathcal{F}(\rho) = \mathbb{E}_{x \sim \rho}\Phi(x) + \mathcal{D}(\rho, \rho_0)$ for a certain potential $\Phi$ (denoted $V$ in the notation of [3]) and a discrepancy $\mathcal{D}$. The authors applied the JKO-ICNN method to minimize $\mathcal{F}$ on the MOSES [53] molecular dataset and obtained promising results.

ACKNOWLEDGEMENTS. Skoltech acknowledges the support of the Ministry of Science and Higher Education grant No. 075-10-2021-068. The MIT Geometric Data Processing group acknowledges the generous support of Army Research Office grants W911NF2010168 and W911NF2110293, of Air Force Office of Scientific Research award FA9550-19-1-031, of National Science Foundation grants IIS-1838071 and CHS-1955697, from the CSAIL Systems that Learn program, from the MIT-IBM Watson AI Laboratory, from the Toyota-CSAIL Joint Research Center, from a gift from Adobe Systems, from an MIT.nano Immersion Lab/NCSOFT Gaming Program seed grant, and from the Skoltech-MIT Next Generation Program.

## References

[1] Juha Ala-Luhtala, Simo Särkkä, and Robert Piché.
Gaussian filtering and variational approximations for Bayesian smoothing in continuous-discrete stochastic dynamic systems. Signal Processing, 111:124–136, 2015.
[2] David Alvarez-Melis and Nicolò Fusi. Gradient flows in dataset space. arXiv preprint arXiv:2010.12760, 2020.
[3] David Alvarez-Melis, Yair Schiff, and Youssef Mroueh. Optimizing functionals on the space of probabilities with input convex neural networks. arXiv preprint arXiv:2106.00774, 2021.
[4] Luigi Ambrosio, Nicola Gigli, and Giuseppe Savaré. Gradient flows: in metric spaces and in the space of probability measures. Springer Science & Business Media, 2008.
[5] Brandon Amos and J Zico Kolter. OptNet: Differentiable optimization as a layer in neural networks. In International Conference on Machine Learning, pages 136–145. PMLR, 2017.
[6] Brandon Amos, Lei Xu, and J Zico Kolter. Input convex neural networks. In International Conference on Machine Learning, pages 146–155. PMLR, 2017.
[7] Abdul Fatir Ansari, Ming Liang Ang, and Harold Soh. Refining deep generative models via Wasserstein gradient flows. arXiv preprint arXiv:2012.00780, 2020.
[8] Michael Arbel, Anna Korba, Adil Salim, and Arthur Gretton. Maximum mean discrepancy gradient flow. arXiv preprint arXiv:1906.04370, 2019.
[9] Lynton Ardizzone, Jakob Kruse, Sebastian Wirkert, Daniel Rahner, Eric W Pellegrini, Ralf S Klessen, Lena Maier-Hein, Carsten Rother, and Ullrich Köthe. Analyzing inverse problems with invertible neural networks. arXiv preprint arXiv:1808.04730, 2018.
[10] Aude Genevay, Marco Cuturi, Gabriel Peyré, and Francis Bach. Stochastic optimization for large-scale optimal transport. arXiv preprint arXiv:1605.08527, 2016.
[11] Jean-David Benamou, Guillaume Carlier, Quentin Mérigot, and Edouard Oudet. Discretization of functionals involving the Monge-Ampère operator. Numerische Mathematik, 134(3):611–636, 2016.
[12] David M Blei, Alp Kucukelbir, and Jon D McAuliffe. Variational inference: A review for statisticians.
Journal of the American Statistical Association, 112(518):859–877, 2017.
[13] Yann Brenier. Polar factorization and monotone rearrangement of vector-valued functions. Communications on Pure and Applied Mathematics, 44(4):375–417, 1991.
[14] Charlotte Bunne, Laetitia Meng-Papaxanthos, Andreas Krause, and Marco Cuturi. JKOnet: Proximal optimal transport modeling of population dynamics. arXiv preprint arXiv:2106.06345, 2021.
[15] Martin Burger, José A Carrillo, and Marie-Therese Wolfram. A mixed finite element method for nonlinear diffusion equations. Kinetic & Related Models, 3(1):59, 2010.
[16] Kenneth F. Caluya and Abhishek Halder. Proximal recursion for solving the Fokker-Planck equation, 2019.
[17] José A Carrillo, Alina Chertock, and Yanghong Huang. A finite-volume method for nonlinear nonlocal equations with a gradient flow structure. Communications in Computational Physics, 17(1):233–258, 2015.
[18] JS Chang and G Cooper. A practical difference scheme for Fokker-Planck equations. Journal of Computational Physics, 6(1):1–16, 1970.
[19] J.S Chang and G Cooper. A practical difference scheme for Fokker-Planck equations. Journal of Computational Physics, 6(1):1–16, 1970.
[20] Yize Chen, Yuanyuan Shi, and Baosen Zhang. Optimal control via neural networks: A convex approach. arXiv preprint arXiv:1805.11835, 2018.
[21] Arnaud Doucet and Adam M Johansen. A tutorial on particle filtering and smoothing: Fifteen years later. Handbook of nonlinear filtering, 12(656-704):3, 2009.
[22] Nicole El Karoui, Shige Peng, and Marie Claire Quenez. Backward stochastic differential equations in finance. Mathematical Finance, 7(1):1–71, 1997.
[23] Jiaojiao Fan, Amirhossein Taghvaei, and Yongxin Chen. Scalable computations of Wasserstein barycenter via input convex neural networks. arXiv preprint arXiv:2007.04462, 2020.
[24] Charlie Frogner and Tomaso Poggio. Approximate inference with Wasserstein gradient flows.
In International Conference on Artificial Intelligence and Statistics, pages 2581–2590. PMLR, 2020.
[25] Yuan Gao, Yuling Jiao, Yang Wang, Yao Wang, Can Yang, and Shunkang Zhang. Deep generative learning via variational gradient flow. In International Conference on Machine Learning, pages 2093–2101. PMLR, 2019.
[26] Yuan Gao, Guangzhen Jin, and Jian-Guo Liu. Inbetweening auto-animation via Fokker-Planck dynamics and thresholding. arXiv preprint arXiv:2005.08858, 2020.
[27] N. Gordon, D. Salmond, and A. Smith. Novel approach to nonlinear/non-Gaussian Bayesian state estimation. 1993.
[28] Insu Han, Dmitry Malioutov, and Jinwoo Shin. Large-scale log-determinant computation through stochastic Chebyshev expansions. In International Conference on Machine Learning, pages 908–917. PMLR, 2015.
[29] Tatsunori Hashimoto, David Gifford, and Tommi Jaakkola. Learning population-level diffusions with generative RNNs. In International Conference on Machine Learning, pages 2417–2426. PMLR, 2016.
[30] Chin-Wei Huang, Ricky TQ Chen, Christos Tsirigotis, and Aaron Courville. Convex potential flows: Universal probability distributions with optimal transport and convex optimization. arXiv preprint arXiv:2012.05942, 2020.
[31] Jörn-Henrik Jacobsen, Arnold Smeulders, and Edouard Oyallon. i-RevNet: Deep invertible networks. arXiv preprint arXiv:1802.07088, 2018.
[32] Richard Jordan, David Kinderlehrer, and Felix Otto. The variational formulation of the Fokker-Planck equation. SIAM Journal on Mathematical Analysis, 29(1):1–17, 1998.
[33] Simon J Julier, Jeffrey K Uhlmann, and Hugh F Durrant-Whyte. A new approach for filtering nonlinear systems. In Proceedings of 1995 American Control Conference-ACC'95, volume 3, pages 1628–1632. IEEE, 1995.
[34] Rudolph E Kalman and Richard S Bucy. New results in linear filtering and prediction theory. 1961.
[35] Søren Klim, Stig Bousgaard Mortensen, Niels Rode Kristensen, Rune Viig Overgaard, and Henrik Madsen.
Population stochastic modelling (PSM): an R package for mixed-effects models based on stochastic differential equations. Computer Methods and Programs in Biomedicine, 94(3):279–289, 2009.
[36] Peter E. Kloeden and Eckhard Platen. Numerical solution of stochastic differential equations. Applications of Mathematics, v. 23. Springer, Berlin, 1992.
[37] Alexander Korotin, Vage Egiazarian, Arip Asadulaev, Alexander Safin, and Evgeny Burnaev. Wasserstein-2 generative networks. In International Conference on Learning Representations, 2021.
[38] Alexander Korotin, Lingxiao Li, Justin Solomon, and Evgeny Burnaev. Continuous Wasserstein-2 barycenter estimation without minimax optimization. In International Conference on Learning Representations, 2021.
[39] Harold Kushner. Approximations to optimal nonlinear filters. IEEE Transactions on Automatic Control, 12(5):546–556, 1967.
[40] Hugo Lavenant, Sebastian Claici, Edward Chien, and Justin Solomon. Dynamical optimal transport on discrete surfaces. ACM Transactions on Graphics (TOG), 37(6):1–16, 2018.
[41] Seppo Linnainmaa. The representation of the cumulative rounding error of an algorithm as a Taylor expansion of the local rounding errors. Master's Thesis (in Finnish), Univ. Helsinki, pages 6–7, 1970.
[42] Qiang Liu and Dilin Wang. Stein variational gradient descent: A general purpose Bayesian inference algorithm. arXiv preprint arXiv:1608.04471, 2016.
[43] Antoine Liutkus, Umut Simsekli, Szymon Majewski, Alain Durmus, and Fabian-Robert Stöter. Sliced-Wasserstein flows: Nonparametric generative modeling via optimal transport and diffusions. In International Conference on Machine Learning, pages 4104–4113. PMLR, 2019.
[44] Ashok Makkuva, Amirhossein Taghvaei, Sewoong Oh, and Jason Lee. Optimal transport mapping via input convex neural networks. In International Conference on Machine Learning, pages 6672–6681. PMLR, 2020.
[45] Bertrand Maury, Aude Roudneff-Chupin, and Filippo Santambrogio.
A macroscopic crowd motion model of gradient flow type. Mathematical Models and Methods in Applied Sciences, 20(10):1787–1821, 2010.
[46] Robert J McCann et al. Existence and uniqueness of monotone measure-preserving maps. Duke Mathematical Journal, 80(2):309–324, 1995.
[47] Sebastian Mika, Gunnar Rätsch, Jason Weston, Bernhard Schölkopf, and Klaus-Robert Müller. Fisher discriminant analysis with kernels. In Neural Networks for Signal Processing IX: Proceedings of the 1999 IEEE Signal Processing Society Workshop (Cat. No. 98TH8468), pages 41–48. IEEE, 1999.
[48] Manfred Opper. Variational inference for stochastic differential equations. Annalen der Physik, 531(3):1800233, 2019.
[49] Lorenzo Pareschi and Mattia Zanella. Structure preserving schemes for nonlinear Fokker-Planck equations and applications. Journal of Scientific Computing, 74(3):1575–1600, 2018.
[50] Gabriel Peyré. Entropic approximation of Wasserstein gradient flows. SIAM Journal on Imaging Sciences, 8(4):2323–2351, 2015.
[51] Gabriel Peyré, Marco Cuturi, et al. Computational optimal transport: With applications to data science. Foundations and Trends in Machine Learning, 11(5-6):355–607, 2019.
[52] Eckhard Platen and Nicola Bruti-Liberati. Numerical solution of stochastic differential equations with jumps in finance, volume 64. Springer Science & Business Media, 2010.
[53] Daniil Polykovskiy, Alexander Zhebrak, Benjamin Sanchez-Lengeling, Sergey Golovanov, Oktai Tatanov, Stanislav Belyaev, Rauf Kurbanov, Aleksey Artamonov, Vladimir Aladinskiy, Mark Veselov, Artur Kadurin, Simon Johansson, Hongming Chen, Sergey Nikolenko, Alán Aspuru-Guzik, and Alex Zhavoronkov. Molecular sets (MOSES): A benchmarking platform for molecular generation models. Frontiers in Pharmacology, 11:1931, 2020.
[54] Danilo Rezende and Shakir Mohamed. Variational inference with normalizing flows. In International Conference on Machine Learning, pages 1530–1538. PMLR, 2015.
[55] Pierre H Richemond and Brendan Maginnis.
On Wasserstein reinforcement learning and the Fokker-Planck equation. arXiv preprint arXiv:1712.07185, 2017.
[56] Hannes Risken. The Fokker-Planck Equation: Methods of Solution and Applications (Springer Series in Synergetics). Springer, 1996.
[57] Christian P Robert and George Casella. The Metropolis-Hastings algorithm. In Monte Carlo Statistical Methods, pages 231–283. Springer, 1999.
[58] Filippo Santambrogio. Gradient flows in Wasserstein spaces and applications to crowd movement. Séminaire Équations aux dérivées partielles (Polytechnique), pages 1–16, 2010.
[59] Filippo Santambrogio. Optimal transport for applied mathematicians. Birkhäuser, NY, 55(5863):94, 2015.
[60] Filippo Santambrogio. Euclidean, Metric, and Wasserstein gradient flows: an overview, 2016.
[61] Simo Särkkä. On unscented Kalman filtering for state estimation of continuous-time nonlinear systems. IEEE Transactions on Automatic Control, 52(9):1631–1641, 2007.
[62] Vivien Seguy, Bharath Bhushan Damodaran, Rémi Flamary, Nicolas Courty, Antoine Rolet, and Mathieu Blondel. Large-scale optimal transport and mapping estimation. arXiv preprint arXiv:1711.02283, 2017.
[63] Kazimierz Sobczyk. Stochastic differential equations: with applications to physics and engineering, volume 40. Springer Science & Business Media, 2013.
[64] Tobias Sutter, Arnab Ganguly, and Heinz Koeppl. A variational approach to path estimation and parameter inference of hidden diffusion processes. The Journal of Machine Learning Research, 17(1):6544–6580, 2016.
[65] Gabriel Terejanu, Puneet Singla, Tarunraj Singh, and Peter D Scott. A novel Gaussian sum filter method for accurate solution to the nonlinear filtering problem. In 2008 11th International Conference on Information Fusion, pages 1–8. IEEE, 2008.
[66] Shashanka Ubaru, Jie Chen, and Yousef Saad. Fast estimation of tr(f(A)) via stochastic Lanczos quadrature. SIAM Journal on Matrix Analysis and Applications, 38(4):1075–1099, 2017.
[67] P Vatiwutipong and N Phewchean.
Alternative way to derive the distribution of the multivariate Ornstein-Uhlenbeck process. Advances in Difference Equations, 2019(1):1–7, 2019.
[68] Cédric Villani. Optimal transport: old and new, volume 338. Springer Science & Business Media, 2008.
[69] Michail D Vrettas, Manfred Opper, and Dan Cornford. Variational mean-field algorithm for efficient inference in large systems of stochastic differential equations. Physical Review E, 91(1):012148, 2015.
[70] Max Welling and Yee W Teh. Bayesian learning via stochastic gradient Langevin dynamics. In Proceedings of the 28th International Conference on Machine Learning (ICML-11), pages 681–688. Citeseer, 2011.
[71] Cheng Zhang, Judith Bütepage, Hedvig Kjellström, and Stephan Mandt. Advances in variational inference. IEEE Transactions on Pattern Analysis and Machine Intelligence, 41(8):2008–2026, 2018.
[72] Ruiyi Zhang, Changyou Chen, Chunyuan Li, and Lawrence Carin. Policy optimization as Wasserstein gradient flows. In International Conference on Machine Learning, pages 5737–5746. PMLR, 2018.