# Deep Momentum Multi-Marginal Schrödinger Bridge

Tianrong Chen, Guan-Horng Liu, Molei Tao, Evangelos A. Theodorou
Georgia Institute of Technology, USA
{tianrong.chen, ghliu, mtao, evangelos.theodorou}@gatech.edu

It is a crucial challenge to reconstruct population dynamics using unlabeled samples from distributions at coarse time intervals. Recent approaches such as flow-based models or Schrödinger Bridge (SB) models have demonstrated appealing performance, yet the inferred sample trajectories either fail to account for the underlying stochasticity or are unnecessarily rigid. In this article, we extend the approach in [1] to operate in continuous space and propose Deep Momentum Multi-Marginal Schrödinger Bridge (DMSB), a novel computational framework that learns the smooth measure-valued spline for stochastic systems satisfying position marginal constraints across time. By tailoring the celebrated Bregman Iteration and extending Iterative Proportional Fitting to phase space, we manage to handle high-dimensional multi-marginal trajectory inference tasks efficiently. Our algorithm outperforms baselines significantly, as evidenced by experiments on synthetic datasets and a real-world single-cell RNA-sequencing dataset. Additionally, the proposed approach can reasonably reconstruct the evolution of the velocity distribution, from position snapshots only, when there is a ground-truth velocity that is nevertheless inaccessible.

1 Introduction

We consider the multi-marginal trajectory inference problem, which pertains to elucidating the dynamics and reactions of indiscernible individuals, given static snapshots of them taken at sporadic time points. Due to the inability to track each individual, one considers the evolution of the statistical distribution of the population instead. This problem has received considerable attention, and associated applications appear in various scientific areas such as estimating cell dynamics [2, 3], predicting meteorological evolution [4], and medical healthcare statistics tracking [5]. [6, 7] constructed an energy landscape that best aligns with empirical observations using neural networks. [8, 9] learn regularized Neural ODEs [10] to encode such potential landscapes. Notably, in the aforementioned work, the trajectory of samples is represented in a deterministic way. In contrast, [11, 12] employ the Schrödinger Bridge (SB) to determine the most likely evolution of samples between marginal distributions when individual sample trajectories are also affected by environmental stochasticity. Yet, these approaches scale poorly w.r.t. the state dimension due to specialized neural network architectures and computational frameworks.

SB can be viewed as a solution to the entropy-regularized optimal transport problem: SB seeks a nonlinear SDE that yields a straight path measure between two arbitrary distributions. The straightness is implied by achieving the optimality of minimizing transportation cost (i.e., the 2-Wasserstein distance, W2). We note SB is often related to the Score-based Generative Model (SGM); both can be used for generative modeling by constructing a certain Stochastic Differential Equation (SDE) that links the data distribution and a tractable prior distribution (i.e., 2 marginals). SGM accomplishes the generative task by first diffusing data to the prior through a pre-specified linear SDE, during which a neural network is also learned to approximate the score function.
Then this score approximator is used to reverse the diffusion process and consequently establish generation. Critically-damped Langevin Diffusion (CLD) [13] extends the SGM SDE to phase space by introducing an auxiliary velocity variable with a tractable Gaussian distribution at both the initial and terminal time. The resulting trajectory in position space becomes smoother, as stochasticity is only injected into the velocity space, and the empirical performance and sample efficiency are enhanced thanks to the structure of the critically damped SDE. The connection between SGM and SB has been elaborated in [14, 15], and a scalable mean-matching Iterative Proportional Fitting (IPF) algorithm was proposed to estimate SB efficiently in high-dimensional cases. Applications of SB, such as image-to-image translation [16, 17], RNA trajectory inference [11], solving Mean Field Games [18], and Riemannian interpolation [19], demonstrate the effectiveness of SB in various domains.

Table 1: Comparison between different models in terms of optimality and boundary distributions p0 and p1. Our DMSB extends standard SB, which generalizes SGM beyond Gaussian priors, to phase space, similar to CLD. However, unlike CLD, DMSB jointly learns the phase-space distributions, i.e., pθ(x, v) = pA(x)qθ(v|x) and pϕ(x, v) = pB(x)qϕ(v|x). In other words, DMSB infers the underlying phase-space dynamics given only position distributions.

| Models | Optimality | p0(·) | p1(·) |
|---|---|---|---|
| SGM [20] | — | pA(x) | N(0, Σ) |
| CLD [13] | — | pA(x) ⊗ N(0, Σ) | N(0, Σ) ⊗ N(0, Σ) |
| SB [14] | W2 (kinks) | pA(x) | pB(x) |
| DMSB (ours) | W2 (smooth) | pA(x)qθ(v|x) | pB(x)qϕ(v|x) |

In this work, we start with SB in phase space (termed momentum SB, mSB in short), and then further investigate mSB with multiple empirical marginal constraints in the position space, which was formulated as multi-marginal mSB (mmmSB) in [1]. This circumvents the need for expensive space discretization, which does not scale well to high dimensions. We also address the challenge of intricate geometric averaging in the continuous-space setup by strategically partitioning and reorganizing the constraint sets. Furthermore, we enhance the algorithm's computational efficiency by incorporating the half-bridge IPF method. The optimality of transportation cost in SB leads to straight trajectories, so if one solves N two-marginal SB problems and connects the resulting trajectories to match N+1 marginals, the connected trajectories will have kinks at all connection points. On the contrary, in mmmSB, the optimality of transportation cost leads to a smooth measure-valued spline over the state space that also interpolates the empirical marginals. Therefore, this approach is highly suitable for problems originating from physical systems and/or those that should have smooth trajectories, such as trajectory inference in single-cell RNA sequencing. Our research emphasizes solving mmmSB efficiently in high dimensions (thus the approach differs from that in the seminal work [1]; see Sec.4). The differences between our algorithm and prior work are demonstrated in Table.1, and the main contributions of our work are fourfold:

- We extend mean-matching IPF to phase space, allowing for scalable mSB computation.
- We introduce and tailor the Bregman Iteration [21] for mmmSB, making it compatible with the phase-space mean-matching objective and thereby enabling efficient computation for high-dimensional mmmSB.
- We show how to overcome the challenge of sampling the velocity variable when it is not available in training data, which enhances the applicability of our model.
- We show the performance of the proposed algorithm DMSB on toy datasets that contain intricate bifurcations and merges. On realistic high-dimensional (100-D) single-cell RNA-seq (scRNA-seq) datasets, DMSB outperforms baselines by a significant margin in terms of the quality of the generated trajectories, both visually and quantitatively. We show that DMSB is able to capture a reasonable velocity distribution compared with the ground truth while other baselines fail.

2 Preliminary

2.1 Dynamical Schrödinger Bridge problem

The dynamical Schrödinger Bridge problem has been extensively studied in the past few decades. The objective of the SB problem is to solve the following optimization problem:

$$\min_{\pi \in \Pi(\rho_0, \rho_T)} D_{KL}(\pi \| \xi), \quad (1)$$

where π ∈ Π(ρ0, ρT) belongs to the set of path measures with marginal densities ρ0 and ρT at t = 0 and T, and ξ is the reference path measure (e.g., [14] sets ξ to a Wiener process started from ρ0). The optimality of problem (1) is characterized by a set of PDEs (3).

Theorem 2.1 ([22]). The optimal path measure π of problem (1) is represented by the forward and backward stochastic processes

$$dx_t = [2\nabla_x \log \Psi_t]\,dt + \sqrt{2}\,dw_t, \quad x_0 \sim \rho_0, \quad (2a)$$
$$dx_t = [-2\nabla_x \log \hat\Psi_t]\,dt + \sqrt{2}\,d\hat w_t, \quad x_T \sim \rho_T, \quad (2b)$$

in which Ψ, Ψ̂ ∈ C^{1,2} are the solutions to the coupled PDEs

$$\frac{\partial \Psi_t}{\partial t} = -\Delta \Psi_t, \qquad \frac{\partial \hat\Psi_t}{\partial t} = \Delta \hat\Psi_t, \quad \text{s.t. } \Psi(0,\cdot)\hat\Psi(0,\cdot) = \rho_0(\cdot),\ \Psi(T,\cdot)\hat\Psi(T,\cdot) = \rho_T(\cdot). \quad (3)$$

The stochastic processes of SB in (2a) and (2b) are equivalent in the sense that ∀t ∈ [0, T], p_t^{(2a)} ≡ p_t^{(2b)} ≡ p_t^{SB}. Here p_t^{SB} stands for the marginal distribution of SB at time t, which is the marginal density of the stochastic process induced by either equation of Eq. (2). After an exponential transform [14], the potentials Ψ_t and Ψ̂_t explicitly represent solutions of the Fokker–Planck equation (FPE) and the Hamilton–Jacobi–Bellman (HJB) equation, where the FPE describes the evolution of the sample density and the HJB encodes the optimality of Eq. (1). Furthermore, the marginal density obeys the factorization p_t^{SB} = Ψ_t Ψ̂_t. This rich structure of SB will later be used to construct the log-likelihood objective (Thm.B.1) and the Langevin sampler for velocity (Sec.4.4).

To solve SB, prior work has primarily used the half-bridge optimization technique, also known as Iterative Proportional Fitting (IPF), in which one iteratively solves the optimization problem with one of the two boundary conditions [14, 15, 23]:

$$\pi^{(d+1)} := \arg\min_{\pi \in \Pi(\cdot,\, \rho_T)} D_{KL}(\pi \| \pi^{(d)}), \qquad \pi^{(d+2)} := \arg\min_{\pi \in \Pi(\rho_0,\, \cdot)} D_{KL}(\pi \| \pi^{(d+1)}), \quad (4)$$

with initial path measure π^{(0)} := ξ. By repeatedly iterating the above optimizations until convergence, the SB solution is attained as π^{SB} = lim_{d→∞} π^{(d)} [24]. In addition, [25] shows that the drift term in the SB problem can also be interpreted as the solution of a Stochastic Optimal Control (SOC) problem with optimal control policy z* = 2∇_x log Ψ(t, x_t):

$$z^*(x) \in \arg\min_{z \in \mathcal{Z}} \ \mathbb{E}\Big[\int_0^T \tfrac{1}{2}\|z_t\|^2\,dt\Big] \quad \text{s.t. } dx_t = z_t\,dt + \sqrt{2}\,dw_t, \ x_0 \sim \rho_0, \ x_T \sim \rho_T.$$

This formulation will be used later for constructing the phase-space likelihood objective function in Sec.3. Regarding solving the half-bridge problem, abundant results exist in the literature for the vanilla SB described above [14, 15, 23], but we will be solving a different SB problem; see Prop.4.1 for the formulation and Sec.4 for a solution.
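For intuition, the half-bridge recursion in Eq. (4) has a closed form in the discrete, static setting, where each half bridge becomes a Sinkhorn-style scaling update. The snippet below is a minimal sketch of that discrete analogue (all names are ours, not from the paper): each update is a KL projection of the coupling onto the set matching one marginal, starting from the reference kernel.

```python
import numpy as np

def ipf_discrete(C, rho0, rhoT, eps=0.1, n_iter=200):
    """Discrete, static analogue of the half-bridge IPF in Eq. (4):
    alternately KL-project the coupling onto the set matching rho0,
    then the set matching rhoT, starting from the reference kernel."""
    K = np.exp(-C / eps)                      # reference measure, analogue of xi
    u, v = np.ones_like(rho0), np.ones_like(rhoT)
    for _ in range(n_iter):
        u = rho0 / (K @ v)                    # project onto Pi(rho0, .)
        v = rhoT / (K.T @ u)                  # project onto Pi(., rhoT)
    return u[:, None] * K * v[None, :]        # entropic-OT coupling

# toy usage: couple two 1-D point clouds
x, y = np.linspace(-1, 1, 50), np.linspace(0, 2, 60)
C = (x[:, None] - y[None, :]) ** 2
pi = ipf_discrete(C, np.full(50, 1 / 50), np.full(60, 1 / 60))
print(np.abs(pi.sum(1) - 1 / 50).max())  # ~0: rho0 marginal matched
print(np.abs(pi.sum(0) - 1 / 60).max())  # exactly 0 after the last projection
```

Each scaling step satisfies one boundary constraint exactly, mirroring how each half bridge in Eq. (4) pins down one endpoint of the path measure.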
2.2 Bregman Iterations for Multiple Constraints

Bregman Iteration [21] can be viewed as a multi-marginal generalization of IPF, and it is widely used to solve entropy-regularized optimal transport problems with multiple constraints [1]. The algorithm efficiently solves problems of the form

$$\inf_{\pi \in \mathcal{K}} KL(\pi \,|\, \xi),$$

where K is the intersection of multiple closed convex constraint sets K_l: K = ∩_{l=1}^{L} K_l. The Bregman Projection (BP) is defined as the optimization w.r.t. one of the constraints K_l,

$$P^{KL}_{K_l}(\xi) := \arg\min_{\pi \in K_l} KL(\pi \,|\, \xi),$$

and the d-th Bregman Iteration (BI) recursively computes BPs over all the constraints in K:

$$\forall\, 0 < n \le L, \quad \pi^{(d,n)} := P^{KL}_{K_n}\big(\pi^{(d,n-1)}\big).$$

The initial condition for the (d+1)-th BI is π^{(d+1,0)} = π^{(d,L)}. Under certain conditions (see e.g. [24]), π^{(d,L)} converges to the unique solution: π^{(d,L)} → P^{KL}_{K}(ξ) as d → +∞.

Remark 2.2. One BI traverses all constraints via multiple BPs, and each BP solves an optimization problem with one constraint. One can notice that BI reduces to the aforementioned IPF procedure for the SB problem (1) by setting L = 2, K1 = Π(ρ0, ·), K2 = Π(·, ρT).

Figure 1: A summary of various SB problems and corresponding algorithms. The toy example in the 3rd row illustrates that vanilla SB determines straight paths (modulo fluctuations due to noise) between pairwise empirical marginals, while our multi-marginal momentum SB approach establishes a smooth measure-spline between marginals in the position space (albeit still stochastic, the path is smooth between any pair of adjacent marginals, because noise is only added to velocity, and the path is also smooth across different pairs of adjacent marginals per design).

Table 2: Mathematical notation.

| Notation | Definition |
|---|---|
| x | position variable |
| v | velocity variable |
| m | concatenation [x, v]ᵀ |
| ρ | position distribution ρ(x) |
| γ | velocity distribution γ(v) |
| µ | joint distribution µ(x, v) |

3 Momentum Schrödinger Bridge

We first describe how to conduct half-bridge IPF training in phase space, which can be used to solve the momentum SB (mSB) problem with two marginal constraints. This scalable phase-space half-bridge technique will then be applied to the multi-marginal case (Sec.4). Fig.1 illustrates how we develop an algorithm based on [14]. Notation used in the following sections is listed in Table.2.

mSB extends the SB problem to phase space, which consists of both position and velocity. We first consider boundary distributions that depend on both x and v, although eventually we will use this as a module to find transport maps between two distributions that depend only on position x, as velocity v is an auxiliary variable artificially introduced to obtain smooth transport. Conceptually, as an entropy-regularized optimal transport problem, SB tries to obtain the straightest path between empirical marginals of positions x with additive noise, whereas mSB aims at finding a smooth interpolation between empirical marginals of x [26] conditioned on boundary velocity distributions (see Fig.1). Such smooth measure-valued splines in position space are obtained by the following optimization problem in phase space [1]:

$$\min_{\pi \in \Pi(\mu_0, \mu_T)} KL(\pi \,|\, \xi) \quad \text{s.t. } \pi = \mathrm{Law}(x_t, v_t):\quad
d\begin{pmatrix} x_t \\ v_t \end{pmatrix} = \underbrace{\begin{pmatrix} v_t \\ 0 \end{pmatrix}}_{f(v,t)} dt \;+\; \underbrace{\begin{pmatrix} 0 & 0 \\ 0 & g_t \end{pmatrix}}_{g(t)} Z(t)\, dt \;+\; \underbrace{\begin{pmatrix} 0 & 0 \\ 0 & g_t \end{pmatrix}}_{g(t)} dw_t.$$

Similar to Theorem 2.1, one can derive a set of PDEs using the potential functions Ψ(t, x, v) and Ψ̂(t, x, v), and subsequently apply the IPF procedure to solve the problem. The formulation of the phase-space PDEs can be found in Appendix.B.2.
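To make these dynamics concrete, the following minimal sketch (our own illustration, not the paper's code) simulates the phase-space SDE above with Euler–Maruyama. Noise and control enter only the velocity channel, so position paths come out smooth; `z` is a stand-in for the control policy Z_t.

```python
import numpy as np

def simulate_msb(x0, v0, z, g=0.2, T=1.0, dt=0.01, seed=0):
    """Euler-Maruyama rollout of dx = v dt, dv = g*z(t,x,v) dt + g dW.
    `z` stands in for the control policy Z_t; with z = 0 this is the
    reference (uncontrolled) phase-space dynamics."""
    rng = np.random.default_rng(seed)
    x, v = np.array(x0, float), np.array(v0, float)
    path = [x.copy()]
    for k in range(int(T / dt)):
        t = k * dt
        x = x + v * dt                                   # position: no noise
        v = v + g * z(t, x, v) * dt \
              + g * np.sqrt(dt) * rng.standard_normal(v.shape)
        path.append(x.copy())
    return np.stack(path)

# usage: uncontrolled rollout of 256 particles in 1-D
paths = simulate_msb(np.zeros(256), np.random.randn(256), lambda t, x, v: 0.0)
```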
This PDE representation of mSB admits a straightforward yet innovative log-likelihood training that enables efficient optimization of the IPF.

Proposition 3.1 (likelihood bound). The half-bridge IPF in phase space,

$$\pi^{(d+1)} := \arg\min_{\pi \in \Pi(\mu_0,\, \cdot)} D_{KL}(\pi \| \pi^{(d)}), \qquad \pi^{(d+2)} := \arg\min_{\pi \in \Pi(\cdot,\, \mu_T)} D_{KL}(\pi \| \pi^{(d+1)}),$$

bounds the likelihood and yields approximate likelihood training:

$$Z_t := \arg\min_{Z_t} \big[ -\log p(m_0, 0) \big], \qquad \hat Z_t := \arg\min_{\hat Z_t} \big[ -\log p(m_T, T) \big],$$

where the likelihood bound on log p(m_0, 0) features the mean-matching integrand

$$\int_0^T \tfrac{1}{2}\, \big\| \hat z_t + z_t - g \nabla_v \log \hat p_t \big\|^2\, dt$$

(the full expression is given in Prop.B.4), and m̂_t is sampled from

$$d\hat m_t = \big[ f - g \hat Z_t \big]\, dt + g(t)\, dw_t, \qquad \hat m_T \sim \mu_T, \quad (5)$$

with Ẑ_t = [0, ẑ_t]ᵀ and p̂_t the density of the path measure induced by Eq. (5) at time t. A similar result for log p(m_T, T) is obtained by an analogous derivation. Proof: see Appendix B.1.

Remark 3.2. After optimizing Ẑ_t, the reference path measure becomes Eq. (5), which implies π ∈ Π(·, µT), i.e., the constraint in half-bridge IPF is satisfied. A path measure π is induced by either Z_t or Ẑ_t. As mentioned in Remark.2.2, one half-bridge IPF step is one BP and one full IPF pass is one BI. Prop.3.1 provides a convenient way to perform one BP of the form π := arg min_{π ∈ K_l} D_KL(π ‖ π̄) by maximizing the log-likelihood given the constraint K_l and the reference path measure π̄. Prop.3.1 thus provides an alternative way to conduct BI, which will be heavily used for mmmSB (Sec.4), and it is computationally efficient after parameterization and discretization (Sec.4.4).

4 Deep Momentum Multi-Marginal Schrödinger Bridge

We first state the problem formulation of the momentum multi-marginal Schrödinger Bridge (mmmSB). Different from the previous two-marginal case, we consider the scenario where N + 1 probability measures µ_{t_i} lie at times t_i. In addition, the velocity distributions are not necessarily known.

Proposition 4.1 ([1]). The dynamical mmmSB with multiple marginal constraints reads:

$$\min_\pi \mathcal{J}(\pi) := \sum_{i=0}^{N-1} KL\big(\pi_{t_i:t_{i+1}} \,|\, \xi_{t_i:t_{i+1}}\big), \quad \text{s.t. } \pi \in \mathcal{K} := \cap_{i=0}^{N} K_{t_i}, \quad (6)$$

where

$$K_{t_0} = \Big\{ \int \pi_{t_0:t_1}\, dm_{t_1} = \mu_{t_0},\ \int \mu_{t_0}\, dv_{t_0} = \rho_{t_0} \Big\},$$
$$K_{t_N} = \Big\{ \int \pi_{t_{N-1}:t_N}\, dm_{t_{N-1}} = \mu_{t_N},\ \int \mu_{t_N}\, dv_{t_N} = \rho_{t_N} \Big\},$$
$$K_{t_i} = \Big\{ \int \pi_{t_i:t_{i+1}}\, dm_{t_{i+1}} = \mu_{t_i},\ \int \pi_{t_{i-1}:t_i}\, dm_{t_{i-1}} = \mu_{t_i},\ \int \mu_{t_i}\, dv_{t_i} = \rho_{t_i} \Big\}, \quad (7)$$

and K is the intersection of the closed convex sets K_{t_i}.

Figure 2: The procedure details the Bregman Iteration (BI) employed in DMSB. The gray and blue blocks represent BP steps performed under the K_boundary constraint for the forward and backward policies, respectively. The red block signifies the BP step executed under the K_bridge constraint. Algorithms for training and sampling can be found in Appendix.D.

The problem described in Prop.4.1 can be solved by the classical BI algorithm integrated with the Sinkhorn method [1]. However, due to the curse of dimensionality and the unfavorable explicit geometric solution, BP cannot be applied directly in high-dimensional and continuous state spaces. To tackle these difficulties, we parameterize the forward and backward policies z_t and ẑ_t by a pair of neural networks. We further decouple and reassemble the constraints, which enables the scalable likelihood IPF and avoids the geometric-averaging issue in the mmmSB context.

4.1 Decoupling and Reassembling Constraints

We decompose the constraint set (7) as K_{t_i} = ∩_{r=0}^{2} K^r_{t_i}, where

$$K^0_{t_i} = \Big\{ \int \pi_{t_i:t_{i+1}}\, dm_{t_{i+1}} = \hat\mu_{t_i},\ \int \hat\mu_{t_i}\, dv_{t_i} = \rho_{t_i} \Big\},$$
$$K^1_{t_i} = \Big\{ \int \pi_{t_{i-1}:t_i}\, dm_{t_{i-1}} = \mu_{t_i},\ \int \mu_{t_i}\, dv_{t_i} = \rho_{t_i} \Big\},$$
$$K^2_{t_i} = \Big\{ \int \pi_{t_i:t_{i+1}}\, dm_{t_{i+1}} = \int \pi_{t_{i-1}:t_i}\, dm_{t_{i-1}} \Big\}. \quad (8)$$

One can notice that K⁰_{t_i} and K¹_{t_i} share a similar structure with the simpler boundary marginal conditions K_{t_0} and K_{t_N}; hence we can get rid of the notorious geometric averaging (see Sec.4 in [1]).
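Prop.3.1 turns each such boundary BP into a simple regression once time is discretized (the objective L_MM given later in Sec.4.4). The sketch below is a hypothetical PyTorch rendering of that regression, not the released code: the trainable forward policy is fit against targets built from the frozen backward policy along cached reference trajectories, cf. Eq. (37) in Appendix B.5.

```python
import torch

def lift(z_net, t, m):
    """Pad a velocity-space policy z to phase space, Z = [0, z]^T."""
    d = m.shape[-1] // 2
    return torch.cat([torch.zeros_like(m[..., :d]), z_net(t, m)], dim=-1)

def mean_matching_loss(z_fwd, z_bwd, m_t, m_next, t, dt):
    """Phase-space mean matching, cf. Eq. (37): the trainable forward
    step m_t + dt*Z(m_t) must match a target assembled from the frozen
    backward policy on consecutive cached states (m_t, m_next)."""
    pred = m_t + dt * lift(z_fwd, t, m_t)
    with torch.no_grad():                      # reference target, per Prop. 3.1
        target = m_next + dt * lift(z_bwd, t + dt, m_next) \
                        - dt * lift(z_bwd, t + dt, m_t)
    return (pred - target).pow(2).sum(-1).mean()
```

Swapping the roles of the two policies gives the symmetric regression for the backward half bridge.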
Notably, this type of constraint provides an opportunity to utilize Proposition 3.1 for optimization, but the joint distribution of x and v is still absent. We classify the constraints into two categories:

$$\mathcal{K}_{\text{boundary}} = \Big\{ \cap_{i=1}^{N-1} K^r_{t_i} \cap K_{t_0} \cap K_{t_N} \,\Big|\, r \in \{0, 1\} \Big\}, \qquad \mathcal{K}_{\text{bridge}} = \cap_{i=1}^{N-1} K^2_{t_i}.$$

Following BI (Sec.2.2), we execute the optimization w.r.t. (6) while iteratively projecting the solution onto a subset of K_boundary or K_bridge. A sketch can be found in Fig.2. The next sections detail how to obtain the joint distribution µ and how to optimize within each constraint set. Hereafter, we only demonstrate the optimization of the forward policy z_t given the reference path measure π̄ driven by a fixed backward policy ẑ_t; the same procedure applies to ẑ_t and vice versa.

4.2 Optimization in the set K_boundary

We first show how to optimize the forward policy z_t w.r.t. the objective (6) given the reference path measure π̄ driven by the fixed backward policy ẑ_t under one subset of K_boundary.

Proposition 4.2 (Optimality w.r.t. K_boundary). Given the reference path measure π̄ driven by the backward policy ẑ_t from boundary µ_{t_{i+1}} in the reverse time direction, the optimal path measure in the forward time direction of the problem

$$\min_\pi \mathcal{J}(\pi) := \sum_{i} KL\big(\pi_{t_i:t_{i+1}} \,|\, \bar\pi_{t_i:t_{i+1}}\big), \quad \text{s.t. } \pi \in \Big\{ \int \pi_{t_i:t_{i+1}}\, dm_{t_{i+1}} = \mu_{t_i},\ \int \mu_{t_i}\, dv_{t_i} = \rho_{t_i} \Big\}$$

is

$$\pi^*_{t_i:t_{i+1}} = \frac{\rho_{t_i}\, \bar\pi_{t_i:t_{i+1}}}{\int \bar\pi_{t_i:t_{i+1}}\, dm_{t_{i+1}}\, dv_{t_i}}.$$

When π_{t_i:t_{i+1}} ≡ π*_{t_i:t_{i+1}}, the following equations need to hold ∀t ∈ [t_i, t_{i+1}]:

$$\big\| z_t + \hat z_t - g \nabla_v \log \hat p_t \big\|_2^2 = 0, \quad (9a)$$
$$p_{t_i}(v_{t_i} \,|\, x_{t_i}) \propto \hat q_{t_i}(v_{t_i} \,|\, x_{t_i}), \quad (9b)$$

where p̂_t and q̂_{t_i} denote the marginal density and the conditional velocity distribution of the reference path measure at times t and t_i, respectively. Proof: see Appendix.B.5.

Remark 4.3. When the ground-truth velocity distributions γ_{t_i} are available, one can simply sample from γ_{t_i}, since the joint distribution µ_{t_i} is available in this case.

To match the reference path measure in the KL-divergence sense, one needs to match both the intermediate path measure, Eq. (9a), and the boundary condition, Eq. (9b). In the traditional two-boundary SB case, matching the boundary condition is often disregarded, owing to either a predefined data distribution or a tractable prior. However, in our specific case, as the velocity is not predefined, it becomes imperative to address this issue, which we do by applying Langevin dynamics.

4.3 Optimization in the set K_bridge

The optimization under K_bridge is formulated similarly to the previous section but differs in the boundary condition (Eq. (10b)):

Proposition 4.4 (Optimality w.r.t. K_bridge). Given the reference path measure π̄ driven by the backward policy ẑ_t from boundary µ_{t_N} in the reverse time direction, the optimal path measure in the forward time direction of the problem

$$\min_\pi \mathcal{J}(\pi) := \sum_{i} KL\big(\pi_{t_i:t_{i+1}} \,|\, \bar\pi_{t_i:t_{i+1}}\big), \quad \text{s.t. } \pi \in \mathcal{K}_{\text{bridge}} = \cap_{i=1}^{N-1} K^2_{t_i}$$

is

$$\pi^*_{t_0:t_N} = \frac{q_{t_0}\, \bar\pi_{t_0:t_N}}{\int \bar\pi_{t_0:t_N}\, dm_{t_N}\, dv_{t_0}}.$$

When π_{t_0:t_N} ≡ π*_{t_0:t_N}, the following equations need to hold ∀t ∈ [t_0, t_N]:

$$\big\| z_t + \hat z_t - g \nabla_v \log \hat p_t \big\|_2^2 = 0, \quad (10a)$$
$$p_{t_0}(v_{t_0}, x_{t_0}) \propto \hat q_{t_0}(v_{t_0}, x_{t_0}). \quad (10b)$$

Proof: see Appendix.B.6.

Conceptually, the above optimization with the K_bridge constraint aims at finding a continuous path measure close to the reference path measure π̄ while disregarding all intermediate marginal constraints. The boundary condition of the reference path measure for the next iteration, p_{t_0}(v_{t_0}, x_{t_0}), is determined by Eq. (10b). Fortunately, empirical samples from this distribution are available, even though the analytic representation of q̂_{t_0}(v_{t_0}, x_{t_0}) is unknown.
Hence we can utilize these samples as empirical sources of the boundary distribution q̂_{t_0}(v_{t_0}, x_{t_0}) for the next BP. Further explanation and intuition can be found in Appendix.G.

4.4 Parameterization and Training Objective Function

Inspired by the success of prior work [14], we parameterize the path measure π by a forward policy z^θ_t or a backward policy ẑ^ϕ_t combined with one of the constraints in K_boundary or K_bridge (see Fig.8 in the Appendix for a visualization). We adopt Euler–Maruyama discretization and denote the timestep by δt. Notably, Eq. (9b) and Eq. (10b) can be implied by minimizing the phase-space NLL in Prop.3.1. This leads to the following objective function, termed the phase-space mean-matching objective, which is used to train the neural networks representing z^θ_t and ẑ^ϕ_t after time discretization:

$$\mathcal{L}_{MM} = \mathbb{E}\Big[ \big\| m_t + \delta t\, Z^\theta_t(m_t) - \big( m_{t+\delta t} + \delta t\, \hat Z^\phi_{t+\delta t}(m_{t+\delta t}) - \delta t\, \hat Z^\phi_{t+\delta t}(m_t) \big) \big\|^2 \Big]$$

(cf. Eqs. (37)–(41) in Appendix B.5).

The velocity boundary condition for the reference path measure in the succeeding BP is encoded in Eq. (9b) or Eq. (10b), but the representation of the conditional distribution in Eq. (9b) is not given explicitly. We leverage a favorable property of SB to parameterize and sample from this distribution.

Proposition 4.5 ([27, 28]). If π^θ and π^ϕ share the same path measure, then p^{θ,ϕ}_{t_i}(v_{t_i}, x_{t_i}) ≡ q^ϕ_{t_i}(v_{t_i}, x_{t_i}), and hence p^{θ,ϕ}_{t_i}(v_{t_i}|x_{t_i}) ≡ q^ϕ_{t_i}(v_{t_i}|x_{t_i}), where

$$\nabla_v \log p^{\theta,\phi}_t = \big( z^\theta_t + \hat z^\phi_t \big) / g. \quad (11)$$

Prop.4.5 suggests that one can use p_{t_i}(v_{t_i}|x_{t_i}) := p^{θ,ϕ}_{t_i} to satisfy condition (9b) and obtain samples from this distribution by simulating Langevin dynamics. Namely, we first sample positions from the ground truth, x_{t_i} ∼ ρ_{t_i}, and then sample v_{t_i} ∼ p^{θ,ϕ}_{t_i} using Eq. (11). One can further adopt the regularization of [29] to enforce the condition of Prop.4.5.

Figure 4: Validation of our DMSB model on the complex GMM synthetic dataset; panels show t = 0.00 (constraint ρ0), 0.66, 1.32 (constraint ρ1), 1.99, 2.66 (constraint ρ2), 3.33, and 4.00 (constraint ρ3). The velocity and position of the same sample share the same shade level. Upper: sample evolution in position space. Bottom: learned sample evolution in velocity space.

4.5 Training Scheme

Here we introduce the scheme for traversing BIs (see Fig.2). In one BI, all constraints must be iterated over once. For the sake of L_MM, the reference path measure should be induced by the opposite-direction policy. A single BI schedule cannot be repeated recursively due to a conflict of reference path-measure directions. For example (see Fig.2), at the end of the d-th BI, π̄ is induced by the forward policy, while the first BP of the d-th BI also optimizes the forward policy, which violates L_MM. Instead, we reschedule the optimization order: in the (d+1)-th BI, we optimize the backward policy in the first and the last BP.

5 Experiments

Setups: We test DMSB on 2D synthetic datasets and a real-world scRNA-seq dataset [30]. We choose the state-of-the-art algorithms MIOFlow [9] and NLSB [11] as our baselines and tune both models to the best of our hardware capacity. We choose the Sliced Wasserstein Distance (SWD) [31] and Maximum Mean Discrepancy (MMD) [32], together with visualizations, as our criteria. The detailed training and evaluation setup can be found in Appendix.C.

Synthetic Datasets: The Petal [9] and Gaussian Mixture Model (GMM) datasets are simple yet challenging, as they mimic natural dynamics arising in cellular differentiation, including bifurcations and merges. We compare our algorithm with MIOFlow in Fig.3. DMSB infers trajectories aligned with the ground-truth distribution more faithfully at the timesteps when snapshots are taken.
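The training loop behind these results is the reordered BI of Sec.4.5. The sketch below is our simplification of Fig.2 and Algorithm 3 (Appendix D), with hypothetical labels: boundary BPs sweep adjacent marginal pairs, one bridge BP spans [t_0, t_N], and the optimized direction always opposes the direction of the reference path measure, as L_MM requires.

```python
def bregman_iteration_schedule(N):
    """First half of one reordered BI over N+1 marginals (cf. Fig. 2 /
    Alg. 3): a backward boundary sweep, a forward boundary sweep, then
    one bridge pass over the whole horizon."""
    steps = []
    for k in range(N, 0, -1):                      # backward boundary sweep
        steps.append(("boundary", "backward", (k, k - 1)))
    for k in range(N):                             # forward boundary sweep
        steps.append(("boundary", "forward", (k, k + 1)))
    steps.append(("bridge", "backward", (N, 0)))   # continuity pass
    return steps

for step in bregman_iteration_schedule(3):
    print(step)   # e.g. ('boundary', 'backward', (3, 2)) ...
```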
Figure 3: Comparison with MIOFlow and the ground truth on the challenging Petal dataset (columns: t0–t4; rows: DMSB (ours), MIOFlow [9], ground truth). DMSB generates trajectories whose time marginals match the ground truth faithfully and outperforms prior work. Time is indicated by color.

In the GMM experiment (see Fig.4), we place a standard Gaussian at the initial and terminal time steps, while a four-modal GMM and an eight-modal GMM are placed at the intermediate time steps. Besides good position trajectories, it is almost serendipitous that DMSB also learns reasonable velocity trajectories without any access to ground-truth velocity information. This paves the way for our later velocity estimation on the scRNA-seq dataset.

scRNA-seq Dataset: The emergence of single-cell profiling technologies has facilitated the acquisition of high-resolution single-cell data, enabling the characterization of individual cells at distinct developmental states [7]. However, because the cell population is eliminated by the measurement, one may only gather statistics of individual samples at particular timesteps, which neither preserves any correlations over time nor provides access to ground-truth trajectories. The scRNA-seq dataset demonstrates the diversification of embryonic stem cells developing as embryoid bodies into mesoderm, endoderm, neuroectoderm, and neural crest over 27 days. Snapshots of cells are collected in the windows (t0: day 0–3, t1: day 6–9, t2: day 12–15, t3: day 18–21, t4: day 24–27). Snapshot data are preprocessed by quality control [30] and then projected to a feature space by principal component analysis (PCA). We inherit the processed data from [8]. We validate DMSB on the 5-dim and 100-dim PCA spaces to show superior performance on high-dimensional problems compared with baselines. We further show that DMSB estimates a better velocity distribution than baselines when the ground truth is absent during training and testing.

Table 3: Numerical results of MMD and SWD on the 100-dimensional single-cell RNA-seq dataset, with results for leaving out (LO) marginals at different observations. DMSB outperforms prior work by a large margin for both metrics and all leave-out cases. See Table.4 in the Appendix for results over 3 seeds.

| Algorithm | MMD w/o LO | MMD LO-t1 | MMD LO-t2 | MMD LO-t3 | SWD w/o LO | SWD LO-t1 | SWD LO-t2 | SWD LO-t3 |
|---|---|---|---|---|---|---|---|---|
| NLSB [11] | 0.66 | 0.38 | 0.37 | 0.37 | 0.54 | 0.55 | 0.54 | 0.55 |
| MIOFlow [9] | 0.23 | 0.23 | 0.90 | 0.23 | 0.35 | 0.49 | 0.72 | 0.50 |
| DMSB (ours) | 0.03 | 0.04 | 0.04 | 0.04 | 0.20 | 0.20 | 0.19 | 0.18 |

We assess the performance of our model by computing MMD and SWD with full snapshots and when one of the snapshots is left out (LO). We postpone the comparison of all models on the 5-D RNA space to the appendix (see Fig.9 and Table.6) because that problem is relatively simple and all models infer accurate trajectories. Table.3 summarizes the average MMD and SWD between the estimated marginals and the ground truth over different snapshot timesteps. DMSB outperforms prior work by a large margin in the high-dimensional (100-D) scenario. The visualization in PCA space (Fig.5) further corroborates the numerical results and highlights the variety and quality of the samples produced by DMSB.

Figure 5: Comparison of population-level dynamics in 100-dimensional PCA space at the moments of observation (t0–t4) for scRNA-seq data using MIOFlow, NLSB, and DMSB. We display plots of the first 6 principal components (PC).
Baselines can only learn the fundamental trend of the trajectory, whereas DMSB matches the target marginals along the trajectory across different dimensions. The right panel shows Kernel Density Estimates [33] of samples generated by DMSB and of the ground truth at t3 and t4. Generated samples for all timesteps and comparisons with baselines are in Appendix.F.

Interestingly, Fig.4 demonstrates that DMSB can reconstruct a reasonable evolution of the velocity distribution that was never accessible to the algorithm. We further validate this property on the 100-D scRNA-seq dataset. During training and testing, none of the models has access to the ground-truth velocity. We run the experiments on the 100-D and 5-D scRNA-seq datasets and average the discrepancy between the ground-truth and estimated velocities over snapshot times. The numerical values are listed in Table.7 and Table.6. Plots of velocity and position can be found in Fig.9 and Fig.10. The plots illustrate that while all models are capable of learning reasonable trajectories, only DMSB estimates a plausible velocity distribution. This property holds even for the 100-D RNA dataset (see Fig.5, 11, 12). This is notable even though the velocity estimated by DMSB does not perfectly match the ground truth: the proposed phase-space SDE and the OT optimality are artificial and need not represent the actual RNA evolution. Moreover, as individual evolutions cannot be tracked, possibilities such as {A→A, B→B} versus {A→B, B→A} cannot be discerned, which renders exact velocity recovery nearly impossible.

6 Conclusion and Limitations

In this paper, we propose DMSB, a scalable algorithm that learns trajectories fitting marginal distributions given at different times. We extend the mean-matching objective to phase space, which enables efficient mSB computation. We propose a novel training scheme that fits the mean-matching objective without violating BI, which is the key to solving the mmmSB problem. We demonstrate superior results of DMSB compared with existing algorithms. A main limitation of this work is that the rate of convergence to the actual mmmSB has not been quantified once neural-network approximations are introduced. Even though [15] theoretically analyzed the convergence of the mean-matching iteration, supporting its outstanding performance [14], the iteration still fails to converge precisely to the actual SB [34] due to practical neural-network estimation errors accumulating over BIs. However, recent work [35] shows convergence of SB when training error exists. In addition, DMSB cannot simulate processes with death and birth of cells, which can potentially be described by unbalanced optimal transport [36].

7 Acknowledgement

This research was supported by the ARO Award W911NF2010151 and the DoD Basic Research Office Award HQ00342110002.

References

[1] Yongxin Chen, Giovanni Conforti, Tryphon T. Georgiou, and Luigia Ripani. Multi-marginal Schrödinger bridges. In International Conference on Geometric Science of Information, pages 725–732. Springer, 2019.

[2] Geoffrey Schiebinger, Jian Shu, Marcin Tabaka, Brian Cleary, Vidya Subramanian, Aryeh Solomon, Joshua Gould, Siyan Liu, Stacie Lin, Peter Berube, et al. Optimal-transport analysis of single-cell gene expression identifies developmental trajectories in reprogramming. Cell, 176(4):928–943, 2019.

[3] Karren D. Yang and Caroline Uhler. Scalable unbalanced optimal transport using generative adversarial networks. arXiv preprint arXiv:1810.11447, 2018.
[4] Mike Fisher, Jorge Nocedal, Yannick Trémolet, and Stephen J. Wright. Data assimilation in weather forecasting: a case study in PDE-constrained optimization. Optimization and Engineering, 10(3):409–426, 2009.

[5] Kenneth G. Manton, XiLiang Gu, and Gene R. Lowrimore. Cohort changes in active life expectancy in the US elderly population: Experience from the 1982–2004 National Long-Term Care Survey. The Journals of Gerontology Series B: Psychological Sciences and Social Sciences, 63(5):S269–S281, 2008.

[6] Tatsunori Hashimoto, David Gifford, and Tommi Jaakkola. Learning population-level diffusions with generative RNNs. In International Conference on Machine Learning, pages 2417–2426. PMLR, 2016.

[7] Charlotte Bunne, Laetitia Papaxanthos, Andreas Krause, and Marco Cuturi. Proximal optimal transport modeling of population dynamics. In International Conference on Artificial Intelligence and Statistics, pages 6511–6528. PMLR, 2022.

[8] Alexander Tong, Jessie Huang, Guy Wolf, David Van Dijk, and Smita Krishnaswamy. TrajectoryNet: A dynamic optimal transport network for modeling cellular dynamics. In International Conference on Machine Learning, pages 9526–9536. PMLR, 2020.

[9] Guillaume Huguet, Daniel Sumner Magruder, Oluwadamilola Fasina, Alexander Tong, Manik Kuchroo, Guy Wolf, and Smita Krishnaswamy. Manifold interpolating optimal-transport flows for trajectory inference. arXiv preprint arXiv:2206.14928, 2022.

[10] Tian Qi Chen, Yulia Rubanova, Jesse Bettencourt, and David K. Duvenaud. Neural ordinary differential equations. In Advances in Neural Information Processing Systems, pages 6572–6583, 2018.

[11] Takeshi Koshizuka and Issei Sato. Neural Lagrangian Schrödinger bridge. arXiv preprint arXiv:2204.04853, 2022.

[12] Lénaïc Chizat, Stephen Zhang, Matthieu Heitz, and Geoffrey Schiebinger. Trajectory inference via mean-field Langevin in path space. arXiv preprint arXiv:2205.07146, 2022.

[13] Tim Dockhorn, Arash Vahdat, and Karsten Kreis. Score-based generative modeling with critically-damped Langevin diffusion. arXiv preprint arXiv:2112.07068, 2021.

[14] Tianrong Chen*, Guan-Horng Liu*, and Evangelos A. Theodorou. Likelihood training of Schrödinger bridge using forward-backward SDEs theory. arXiv preprint arXiv:2110.11291, 2021.

[15] Valentin De Bortoli, James Thornton, Jeremy Heng, and Arnaud Doucet. Diffusion Schrödinger bridge with applications to score-based generative modeling. arXiv preprint arXiv:2106.01357, 2021.

[16] Yuyang Shi, Valentin De Bortoli, George Deligiannidis, and Arnaud Doucet. Conditional simulation using diffusion Schrödinger bridges. In Uncertainty in Artificial Intelligence, pages 1792–1802. PMLR, 2022.

[17] Guan-Horng Liu, Arash Vahdat, De-An Huang, Evangelos A. Theodorou, Weili Nie, and Anima Anandkumar. I²SB: Image-to-image Schrödinger bridge. arXiv preprint arXiv:2302.05872, 2023.

[18] Guan-Horng Liu, Tianrong Chen, Oswin So, and Evangelos A. Theodorou. Deep generalized Schrödinger bridge. arXiv preprint arXiv:2209.09893, 2022.

[19] James Thornton, Michael Hutchinson, Emile Mathieu, Valentin De Bortoli, Yee Whye Teh, and Arnaud Doucet. Riemannian diffusion Schrödinger bridge. arXiv preprint arXiv:2207.03024, 2022.

[20] Yang Song, Jascha Sohl-Dickstein, Diederik P. Kingma, Abhishek Kumar, Stefano Ermon, and Ben Poole. Score-based generative modeling through stochastic differential equations. arXiv preprint arXiv:2011.13456, 2020.
[21] Lev M. Bregman. The relaxation method of finding the common point of convex sets and its application to the solution of problems in convex programming. USSR Computational Mathematics and Mathematical Physics, 7(3):200–217, 1967.

[22] Michele Pavon and Anton Wakolbinger. On free energy, stochastic control, and Schrödinger processes. In Modeling, Estimation and Control of Systems with Uncertainty, pages 334–348. Springer, 1991.

[23] Francisco Vargas. Machine-learning approaches for the empirical Schrödinger bridge problem. Technical report, University of Cambridge, Computer Laboratory, 2021.

[24] Jean-David Benamou, Guillaume Carlier, Marco Cuturi, Luca Nenna, and Gabriel Peyré. Iterative Bregman projections for regularized transportation problems. SIAM Journal on Scientific Computing, 37(2):A1111–A1138, 2015.

[25] Paolo Dai Pra. A stochastic control approach to reciprocal diffusion processes. Applied Mathematics and Optimization, 23(1):313–329, 1991.

[26] Jean-David Benamou, Thomas O. Gallouët, and François-Xavier Vialard. Second-order models for optimal transport and cubic splines on the Wasserstein space. Foundations of Computational Mathematics, 19(5):1113–1143, 2019.

[27] Brian D. O. Anderson. Reverse-time diffusion equation models. Stochastic Processes and their Applications, 12(3):313–326, 1982.

[28] Edward Nelson. Dynamical Theories of Brownian Motion, volume 106. Princeton University Press, 2020.

[29] Hung-Yu Tseng, Lu Jiang, Ce Liu, Ming-Hsuan Yang, and Weilong Yang. Regularizing generative adversarial networks under limited data. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pages 7921–7931, 2021.

[30] Kevin R. Moon, David van Dijk, Zheng Wang, Scott Gigante, Daniel B. Burkhardt, William S. Chen, Kristina Yim, Antonia van den Elzen, Matthew J. Hirn, Ronald R. Coifman, et al. Visualizing structure and transitions in high-dimensional biological data. Nature Biotechnology, 37(12):1482–1492, 2019.

[31] Nicolas Bonneel, Julien Rabin, Gabriel Peyré, and Hanspeter Pfister. Sliced and Radon Wasserstein barycenters of measures. Journal of Mathematical Imaging and Vision, 51:22–45, 2015.

[32] Arthur Gretton, Karsten M. Borgwardt, Malte J. Rasch, Bernhard Schölkopf, and Alexander Smola. A kernel two-sample test. The Journal of Machine Learning Research, 13(1):723–773, 2012.

[33] Murray Rosenblatt. Remarks on some nonparametric estimates of a density function. The Annals of Mathematical Statistics, pages 832–837, 1956.

[34] David Lopes Fernandes, Francisco Vargas, Carl Henrik Ek, and Neill D. F. Campbell. Shooting Schrödinger's cat. In Fourth Symposium on Advances in Approximate Bayesian Inference, 2021.

[35] Yu Chen, Wei Deng, Shikai Fang, Fengpei Li, Nicole Tianjiao Yang, Yikai Zhang, Kashif Rasul, Shandian Zhe, Anderson Schneider, and Yuriy Nevmyvaka. Provably convergent Schrödinger bridge with applications to probabilistic time series imputation. arXiv preprint arXiv:2305.07247, 2023.

[36] Yongxin Chen, Tryphon T. Georgiou, and Michele Pavon. The most likely evolution of diffusing and vanishing particles: Schrödinger bridges with unbalanced marginals. SIAM Journal on Control and Optimization, 60(4):2016–2039, 2022.

[37] Yang Song, Conor Durkan, Iain Murray, and Stefano Ermon. Maximum likelihood training of score-based diffusion models. arXiv preprint arXiv:2101.09258, 2021.

[38] Jiongmin Yong and Xun Yu Zhou. Stochastic Controls: Hamiltonian Systems and HJB Equations, volume 43. Springer Science & Business Media, 1999.

[39] Kenneth Caluya and Abhishek Halder. Wasserstein proximal algorithms for the Schrödinger bridge problem: Density control with nonlinear drift. IEEE Transactions on Automatic Control, 2021.
[40] Ioannis Exarchos and Evangelos A. Theodorou. Stochastic optimal control via forward and backward stochastic differential equations and importance sampling. Automatica, 87:159–165, 2018.

[41] Tianrong Chen, Ziyi O. Wang, Ioannis Exarchos, and Evangelos Theodorou. Large-scale multi-agent deep FBSDEs. In International Conference on Machine Learning, pages 1740–1748. PMLR, 2021.

[42] Tianrong Chen, Ziyi Wang, and Evangelos A. Theodorou. Deep graphic FBSDEs for opinion dynamics stochastic control. In 2022 IEEE 61st Conference on Decision and Control (CDC), pages 4652–4659. IEEE, 2022.

[43] Eberhard Hopf. The partial differential equation u_t + uu_x = µu_xx. Communications on Pure and Applied Mathematics, 3(3):201–230, 1950.

[44] Julian D. Cole. On a quasi-linear parabolic equation occurring in aerodynamics. Quarterly of Applied Mathematics, 9(3):225–236, 1951.

[45] Ilya Loshchilov and Frank Hutter. Decoupled weight decay regularization. arXiv preprint arXiv:1711.05101, 2017.

[46] Francisco Vargas, Pierre Thodoroff, Neil D. Lawrence, and Austen Lamacraft. Solving Schrödinger bridges via maximum likelihood. arXiv preprint arXiv:2106.02081, 2021.

B Proofs for Sec.3 and Sec.4

Before stating our proofs, we provide the assumptions used throughout the paper. These assumptions are adopted from stochastic analysis for SGM [27, 37, 38], SB [39], and FBSDEs [40–42].

(i) µ_{t_i} has finite second-order moment for all t_i.
(ii) f and g are continuous functions, and |g(t)|² > 0 is uniformly lower-bounded w.r.t. t.
(iii) ∀t ∈ [0, T], ∇_v log p_t(m_t, t), ∇_v log Ψ(·, ·, ·), ∇_v log Ψ̂(·, ·, ·), Z(·, ·, ·; θ), and Ẑ(·, ·, ·; ϕ) are Lipschitz and of at most linear growth w.r.t. x and v.
(iv) Ψ, Ψ̂ ∈ C^{1,2}.
(v) ∃k > 0: p^SB_t(m) = O(exp(−‖m‖²_k)) as ‖m‖ → ∞.

Assumptions (i)–(iii) are standard conditions in stochastic analysis to ensure the existence-uniqueness of the SDEs and hence also appear in SGM analysis [37]. Assumption (iv) allows applications of the Itô formula and properly defines the backward SDE in FBSDE theory. Finally, assumption (v) assures the exponential limiting behavior needed when performing integration by parts. W.l.o.g., we denote f = [v, 0]ᵀ.

B.1 Proof of Proposition 3.1

The result of Prop.3.1 is part of Prop.B.4, which gives the results for both the forward and backward likelihood objectives.

Theorem B.1. The optimization problem

$$\min_{\mu,\, a} \int_0^T \!\!\int \tfrac{1}{2} \|a\|_2^2\, \mu\, dx\, dv\, dt, \quad (12)$$
$$\text{s.t. } \frac{\partial \mu}{\partial t} = -\nabla_m \cdot \big\{ (f + g a)\, \mu \big\} + \tfrac{1}{2} g^2 \Delta_v \mu, \qquad \mu_0 = p(0, x, v),\ \mu_T = p(T, x, v), \quad (13)$$

induces the coupled PDEs

$$\frac{\partial \mu}{\partial t} = -\nabla_m \cdot \big\{ (f + g^2 \nabla_v \varphi)\, \mu \big\} + \tfrac{1}{2} g^2 \Delta_v \mu, \quad (14)$$
$$\frac{\partial \varphi}{\partial t} = -\tfrac{1}{2} \| g \nabla_v \varphi \|_2^2 - v^{\mathsf T} \nabla_x \varphi - \tfrac{1}{2} g^2 \Delta_v \varphi, \quad (15)$$

and the optimal control of the problem is a* = g∇_v φ.

Proof. Introducing the Lagrange multiplier φ for the constraint (13), one writes the Lagrangian

$$\mathcal{L}(\mu, a, \varphi) = \int_0^T \!\!\int \tfrac{1}{2}\|a\|_2^2\, \mu\, dx\, dv\, dt + \int_0^T \!\!\int \varphi \Big( \frac{\partial \mu}{\partial t} + \nabla_m \cdot [(f + g a)\mu] - \tfrac{1}{2} g^2 \Delta_v \mu \Big) dx\, dv\, dt.$$

Integrating by parts (boundary terms vanish by assumption (v)) gives

$$\mathcal{L} = \int_0^T \!\!\int \Big( \tfrac{1}{2}\|a\|_2^2 - \frac{\partial \varphi}{\partial t} - \nabla_m \varphi^{\mathsf T}(f + g a) - \tfrac{1}{2} g^2 \Delta_v \varphi \Big)\, \mu\, dx\, dv\, dt.$$

Minimizing the bracket over a yields a* = g∇_v φ, and plugging a* back gives the optimality conditions (14)–(15). ∎

Theorem B.2. The optimal forward and backward processes are represented as

$$dm_t = \big[ f + g\, u^f_t \big]\, dt + g(t)\, dw_t \quad \text{(forward)}, \quad (16)$$
$$dm_s = \big[ f + g\, u^b_s \big]\, ds + g(s)\, d\bar w_s \quad \text{(backward)}, \quad (17)$$

in which f = [v, 0]ᵀ.
The optimal controls are expressed as

$$u^f_t := Z_t = \begin{pmatrix} 0 \\ z_t \end{pmatrix} = \begin{pmatrix} 0 \\ g \nabla_v \log \Psi_t \end{pmatrix}, \qquad u^b_t := \hat Z_t = \begin{pmatrix} 0 \\ \hat z_t \end{pmatrix} = \begin{pmatrix} 0 \\ g \nabla_v \log \hat\Psi_t \end{pmatrix},$$

where Ψ and Ψ̂ are the solutions of the coupled PDEs

$$\frac{\partial \Psi_t}{\partial t} = -\tfrac{1}{2} g^2 \Delta_v \Psi_t - \nabla_x \Psi_t^{\mathsf T} v, \qquad \frac{\partial \hat\Psi_t}{\partial t} = \tfrac{1}{2} g^2 \Delta_v \hat\Psi_t - \nabla_x \hat\Psi_t^{\mathsf T} v, \quad (20)$$

s.t. Ψ(x, v, 0)Ψ̂(x, v, 0) = p(x, v, 0), Ψ(x, v, T)Ψ̂(x, v, T) = p(x, v, T).

Proof. By Theorem B.1, the optimal control is a* = g∇_v φ. Applying the Hopf–Cole transform [43, 44], define

$$\Psi = \exp(\varphi), \qquad \hat\Psi = \mu \exp(-\varphi).$$

A direct computation, expanding ∂Ψ/∂t = exp(φ) ∂φ/∂t and ∂Ψ̂/∂t = exp(−φ)(∂µ/∂t − µ ∂φ/∂t) and substituting the coupled PDEs (14)–(15), collapses the nonlinear terms and yields the linear PDEs (20). The solution of this mSB is hence characterized by the forward SDE

$$dm_t = \big[ f + g\, u^f_t \big]\, dt + g(t)\, dw_t. \quad (23)$$

Due to the structure of the Hopf–Cole transform, one has

$$p^{SB}_t = p^{(23)}_t = \Psi_t \hat\Psi_t. \quad (24)$$

According to [27, 28], the reverse drift u^b_t of the SDE (23) must satisfy

$$u^f_t + u^b_t = \begin{pmatrix} 0 \\ g \nabla_v \log p^{SB}_t \end{pmatrix} = \begin{pmatrix} 0 \\ g \nabla_v \log \Psi_t + g \nabla_v \log \hat\Psi_t \end{pmatrix}, \quad (25)$$

which yields the backward optimal control

$$u^b_t := \hat Z_t = \begin{pmatrix} 0 \\ \hat z_t \end{pmatrix} = \begin{pmatrix} 0 \\ g \nabla_v \log \hat\Psi_t \end{pmatrix}.$$

Thus the optimal forward and backward processes are

$$dm_t = \big[ f + g\, u^f_t \big]\, dt + g(t)\, dw_t, \quad (29)$$
$$dm_s = \big[ f + g\, u^b_s \big]\, ds + g(s)\, d\bar w_s, \quad (30)$$

with Ψ and Ψ̂ satisfying the PDEs (20). ∎

Lemma B.3. By specifying f = [v, 0]ᵀ, the PDEs shown in (20) can be represented by the following FBSDE system:

$$d\begin{pmatrix} x_t \\ v_t \end{pmatrix} = \begin{pmatrix} v_t \\ g^2 \nabla_v \log \Psi_t \end{pmatrix} dt + \begin{pmatrix} 0 & 0 \\ 0 & g \end{pmatrix} dw_t,$$
$$dy_t = \tfrac{1}{2} \| z_t \|^2\, dt + z_t^{\mathsf T} dw_t,$$
$$d\hat y_t = \Big( \tfrac{1}{2} \| \hat z_t \|^2 + z_t^{\mathsf T} \hat z_t + \nabla_v \cdot (g \hat z_t) \Big)\, dt + \hat z_t^{\mathsf T} dw_t,$$
$$\text{s.t. } \exp(y_0 + \hat y_0) = p(x, v, 0), \quad \exp(y_T + \hat y_T) = p(x, v, T),$$

where

y ≡ y(x_t, v_t, t) = log Ψ(x_t, v_t, t), z ≡ z(x_t, v_t, t) = g∇_v log Ψ(x_t, v_t, t),
ŷ ≡ ŷ(x_t, v_t, t) = log Ψ̂(x_t, v_t, t), ẑ ≡ ẑ(x_t, v_t, t) = g∇_v log Ψ̂(x_t, v_t, t).

Proof. Applying Itô's lemma to log Ψ along the forward dynamics and substituting the PDE for Ψ in (20), the drift terms collapse to ½‖g∇_v log Ψ‖², which gives

$$d \log \Psi = \tfrac{1}{2} \| z_t \|^2\, dt + z_t^{\mathsf T} dw_t.$$

Applying Itô's lemma to log Ψ̂ along the same dynamics, substituting the PDE for Ψ̂ in (20), and using the identity (1/Ψ̂)Δ_vΨ̂ = ‖∇_v log Ψ̂‖² + Δ_v log Ψ̂, gives

$$d \log \hat\Psi = \Big( \tfrac{1}{2} \| \hat z_t \|^2 + z_t^{\mathsf T} \hat z_t + \nabla_v \cdot (g \hat z_t) \Big)\, dt + \hat z_t^{\mathsf T} dw_t.$$

Defining y, z, ŷ, ẑ as above, one can conclude the result. ∎
Proposition B.4. The log-likelihood at a data point m_0 can be expressed as

$$\log p(m_0, 0) = \mathbb{E}_{m_t \sim (17)}\big[\log p(m_T, T)\big] - \int_0^T \mathbb{E}\Big[ \tfrac{1}{2}\|z_t\|^2 + \tfrac{1}{2}\|\hat z_t\|^2 + z_t^{\mathsf T} \hat z_t + \nabla_v \cdot (g \hat z_t) \Big]\, dt$$
$$= \mathbb{E}_{m_t \sim (17)}\big[\log p(m_T, T)\big] - \int_0^T \mathbb{E}\Big[ \tfrac{1}{2}\|z_t\|^2 - \tfrac{1}{2}\big\| g\nabla_v \log p^{(17)}_t - z_t \big\|^2 + \underbrace{\tfrac{1}{2}\big\| \hat z_t - g\nabla_v \log p^{(17)}_t + z_t \big\|^2}_{\text{mean-matching objective}} \Big]\, dt,$$

and symmetrically

$$\log p(m_T, T) = \mathbb{E}_{m_t \sim (16)}\big[\log p(m_0, 0)\big] - \int_0^T \mathbb{E}\Big[ \tfrac{1}{2}\|z_t\|^2 + \tfrac{1}{2}\|\hat z_t\|^2 + z_t^{\mathsf T} \hat z_t + \nabla_v \cdot (g z_t) \Big]\, dt$$
$$= \mathbb{E}_{m_t \sim (16)}\big[\log p(m_0, 0)\big] - \int_0^T \mathbb{E}\Big[ \tfrac{1}{2}\|\hat z_t\|^2 - \tfrac{1}{2}\big\| g\nabla_v \log p^{(16)}_t - \hat z_t \big\|^2 + \underbrace{\tfrac{1}{2}\big\| \hat z_t - g\nabla_v \log p^{(16)}_t + z_t \big\|^2}_{\text{mean-matching objective}} \Big]\, dt.$$

By maximizing the log-likelihood at time t = 0 and then at t = T iteratively, (z_t, ẑ_t) converges to the solution of the phase-space SB.

Proof. From Lemma.B.3, one has

$$\log p(m_0, 0) = \mathbb{E}[y_0 + \hat y_0] = \mathbb{E}[y_T + \hat y_T] - \int_0^T \mathbb{E}\Big[ \tfrac{1}{2}\|z_t\|^2 + \tfrac{1}{2}\|\hat z_t\|^2 + z_t^{\mathsf T} \hat z_t + \nabla_v \cdot (g \hat z_t) \Big]\, dt = \mathbb{E}\big[\log p(m_T, T)\big] - \int_0^T \mathbb{E}[\cdots]\, dt.$$

Using integration by parts, E[∇_v·(g ẑ_t)] = −E[ẑ_t^T g∇_v log p^SB_t], and completing the square yields the mean-matching form above. A similar result can be obtained for log p(m_T, T). ∎

One can notice that the likelihood objective is a continuous-time analog of the mean-matching objective proposed in [15], and the iterative optimization between log p(m_0, 0) and log p(m_T, T) is the continuous analog of IPF. Hence, the convergence proof remains valid (see Proposition 4 in [15]). The equivalence between KL-divergence optimization in IPF and likelihood optimization is analyzed extensively in [14, 15, 18]; the objective eventually boils down to the mean-matching objective shown in Prop.B.4.

Proposition B.5 (Optimality w.r.t. K_boundary). Given the reference path measure π̄ driven by the policy ẑ_t from boundary µ_{t_{i+1}} in the reverse time direction, the optimal path measure in the forward time direction of the problem

$$\min_\pi \mathcal{J}(\pi) := \sum_{i} KL\big(\pi_{t_i:t_{i+1}} \,|\, \bar\pi_{t_i:t_{i+1}}\big), \quad \text{s.t. } \pi \in \Big\{ \int \pi_{t_i:t_{i+1}}\, dm_{t_{i+1}} = \mu_{t_i},\ \int \mu_{t_i}\, dv_{t_i} = \rho_{t_i} \Big\}$$

is

$$\pi^*_{t_i:t_{i+1}} = \frac{\rho_{t_i}\, \bar\pi_{t_i:t_{i+1}}}{\int \bar\pi_{t_i:t_{i+1}}\, dm_{t_{i+1}}\, dv_{t_i}}.$$

When π_{t_i:t_{i+1}} ≡ π*_{t_i:t_{i+1}}, the following equations need to hold ∀t ∈ [t_i, t_{i+1}]:

$$\big\| z_t + \hat z_t - g \nabla_v \log \hat p_t \big\|_2^2 = 0, \quad (35a)$$
$$p_{t_i}(v_{t_i} \,|\, x_{t_i}) \propto \hat q_{t_i}(v_{t_i} \,|\, x_{t_i}), \quad (35b)$$

where p̂_t and q̂_{t_i} denote the marginal density and the conditional velocity distribution of the reference path measure at times t and t_i, respectively.

Proof. Owing to the similarity of the optimization to the boundary case, the closed-form solution of the next path measure is (see Sec.4 in [1] for details)

$$\pi^*_{t_i:t_{i+1}} = \frac{\rho_{t_i}\, \bar\pi_{t_i:t_{i+1}}}{\int \bar\pi_{t_i:t_{i+1}}\, dm_{t_{i+1}}\, dv_{t_i}}.$$

Denote the transition kernel of the parameterized SDE driven by the backward policy ẑ_t by q(·|·), and discretize the time range between t_i and t_{i+1} into S intervals by EM discretization.
Then one can write

$$\pi^*_{t_i:t_{i+1}} = \frac{\rho_{t_i}\, \bar\pi_{t_i:t_{i+1}}}{\int \bar\pi_{t_i:t_{i+1}}\, dm_{t_{i+1}}\, dv_{t_i}} = p_{t_i}(x_{t_i})\, q(v_{t_i} | x_{t_i}) \prod_{s=0}^{S-1} q_s\big( m_{t_i + (s+1)\delta t} \,\big|\, m_{t_i + s\delta t} \big), \quad (36)$$

obtained by expanding π̄ into its (backward) Markov transition kernels and applying the factorization recursively over the S discretization steps. According to [15], given the policy ẑ_t, the transition kernel q_s(m_{t_i+(s+1)δt} | m_{t_i+sδt}) can be estimated from ẑ_t (see Proposition 3 in [15]) and treated as the regression label for the forward policy z_t for all s. Thus, if π_{t_i:t_{i+1}} is to align with π*_{t_i:t_{i+1}}, one can construct the following objective for the policy z_t:

$$\Big\| \underbrace{m_t + \delta t\, Z_t(m_t)}_{①} - \Big( m_t + \underbrace{m_{t+\delta t} + \delta t\, \hat Z_{t+\delta t}(m_{t+\delta t})}_{②} - \underbrace{\big(m_t + \delta t\, \hat Z_{t+\delta t}(m_t)\big)}_{③} \Big) \Big\|_2^2 \quad (37)$$
$$= \big\| \delta t\, Z_t(m_t) + \delta t\, \hat Z_{t+\delta t}(m_t) - \big( m_{t+\delta t} - m_t + \delta t\, \hat Z_{t+\delta t}(m_{t+\delta t}) \big) \big\|_2^2 \quad (38)$$
$$\propto \big\| Z_t(m_t) + \hat Z_{t+\delta t}(m_t) - \nabla \log p^{(17)}_t \big\|_2^2 \quad (39)$$

and, due to the special structure of Z_t and Ẑ_t (they act only on the velocity coordinates), (40)

$$\equiv \big\| z_t(m_t) + \hat z_{t+\delta t}(m_t) - \nabla_v \log p^{(17)}_t \big\|_2^2, \quad (41)$$

where ①, ②, ③ correspond to F_k, B_k, and B_{k+1} in [15], respectively. Furthermore, we need to find a density function p_{t_i}(v_{t_i}|x_{t_i}) that satisfies p_{t_i}(x_{t_i}) p_{t_i}(v_{t_i}|x_{t_i}) ∝ p_{t_i}(x_{t_i}) q̂(v_{t_i}|x_{t_i}), i.e., p_{t_i}(v_{t_i}|x_{t_i}) ∝ q̂_{t_i}(v_{t_i}|x_{t_i}), to serve as the new boundary condition. ∎

Proposition B.6 (Optimality w.r.t. K_bridge). Given the reference path measure π̄ driven by the policy ẑ_t from boundary µ_{t_N} in the reverse time direction, the optimal path measure in the forward time direction of the problem

$$\min_\pi \mathcal{J}(\pi) := \sum_{i} KL\big(\pi_{t_i:t_{i+1}} \,|\, \bar\pi_{t_i:t_{i+1}}\big), \quad \text{s.t. } \pi \in \mathcal{K}_{\text{bridge}} = \cap_{i=1}^{N-1} K^2_{t_i}$$

is

$$\pi^*_{t_0:t_N} = \frac{q_{t_0}\, \bar\pi_{t_0:t_N}}{\int \bar\pi_{t_0:t_N}\, dm_{t_N}\, dv_{t_0}}.$$

When π_{t_0:t_N} ≡ π*_{t_0:t_N}, the following equations need to hold ∀t ∈ [t_0, t_N]:

$$\big\| z_t + \hat z_t - g \nabla_v \log \hat p_t \big\|_2^2 = 0, \quad (42a)$$
$$p_{t_0}(v_{t_0}, x_{t_0}) \propto \hat q_{t_0}(v_{t_0}, x_{t_0}). \quad (42b)$$

Proof. Same proof as B.5. ∎

Remark B.7. The optimizer of this problem can be represented as

$$\pi^* = \mu_{t_N}\, \bar\pi_{\cdot | t_N}, \quad (43)$$

which can also be written as

$$\pi^* = \int d\mu_{t_N} \big( \bar\pi_{\cdot | t_N} \big)^{R}, \quad (44)$$

where the notation R represents time reversal. Proposition.4.4 essentially uses the neural network Z^θ_t to approximate Eq. (44).

C Experiment Details

We test DMSB on 2D synthetic datasets and the real-world scRNA-seq dataset [30]. We parameterize z(t, m; θ) and ẑ(t, m; ϕ) with residual-based networks for all datasets (see Fig.6). The network adopts positional encoding and is trained with AdamW [45] on one Nvidia 3090 Ti GPU. We use a constant g(t) for simplicity, though the framework can adopt a time-varying function g(t). We set the time horizon T = t_N and the interval δt = 0.01, and we use EM discretization throughout the whole paper. For the scRNA-seq dataset, we split the data into train and test subsets (85% and 15%). All experiment results are simulated by all-step push-forward from initial data points at time t = t_0.

MIOFlow and NLSB setup: We use the official implementations of NLSB and MIOFlow. For MIOFlow, we report the best performance for all experiments w/ GAE (or AE) and w/o GAE (or AE) embedding.
For NLSB, we enlarge the size of the neural network to the best of our GPU capacity for the 100-dimensional scRNA-seq dataset and report the best performance obtained during training. We evaluate the velocity of NLSB, as an SDE model, by its estimated drift term at time steps t = {1, 2, 3, 4, 5}. Because MIOFlow w/ GAE simulates trajectories in the latent space, we estimate its velocity using the forward finite-difference technique with discretization 1e-3 sec after mapping from the latent code back to the original space. We do not want to underestimate any prior work and tried our best to tune it; feel free to communicate with the first author if one can reproduce better results in the experiment section, and we are willing to update it.

Metrics and Evaluations: The 1-Wasserstein distance (W1) suffers severely from the curse of dimensionality. In the main paper we use the Sliced Wasserstein Distance (SWD) and Maximum Mean Discrepancy as our criteria for the 100-dim RNA dataset. An example is given in the listings below: W1 suffers from the curse of dimensionality so badly that the distance between two sets of Gaussian samples is even larger than the distance between Gaussian samples and zeros (see Listing 2). Hence such a metric is not suitable for evaluating high-dimensional (~100-D) datasets even though some papers report W1. To better evaluate our model against baselines, we use W1, Energy Distance, Max-Sliced Wasserstein Distance, Sliced Wasserstein Distance, and MMD. Our metrics are adapted from GeomLoss (W1 and Energy), POT (Sliced Wasserstein and Max-Sliced Wasserstein), and this repo (MMD).

Trajectory Cache: Similar to prior work [14, 15], we also need to cache trajectories for training purposes. We cache 4096 trajectories for each Bregman Projection.

Special Clarification for NLSB: We evaluate the velocity of NLSB, as an SDE model, by its estimated drift term at time steps t = {1, 2, 3, 4, 5}. It may not be reasonable to consider the drift term as the real velocity, but the drift term can certainly depict a trend of the SDE, so we still provide the result here.

```python
import torch
from ot.sliced import sliced_wasserstein_distance

a = torch.randn(1000, 100)  # 1000 Gaussian samples with dimension 100
b = torch.zeros(1000, 100)  # 1000 zero samples with dimension 100
c = torch.randn(1000, 100)  # 1000 Gaussian samples with dimension 100
Loss = sliced_wasserstein_distance
print('SWD distance between a and b is: {}'.format(Loss(a, b)))
print('SWD distance between a and c is: {}'.format(Loss(a, c)))
# SWD distance between a and b is: 1.0433608293533325
# SWD distance between a and c is: 0.11096614599227905
```
Listing 1: Distance computed by SWD with 1000 samples and 100 dimensions.
```python
import torch
from geomloss import SamplesLoss

a = torch.randn(1000, 100)  # 1000 Gaussian samples with dimension 100
b = torch.zeros(1000, 100)  # 1000 zero samples with dimension 100
c = torch.randn(1000, 100)  # 1000 Gaussian samples with dimension 100
Loss = SamplesLoss('sinkhorn', p=1)
print('W1 distance between a and b is: {}'.format(Loss(a, b)))
print('W1 distance between a and c is: {}'.format(Loss(a, c)))
# W1 distance between a and b is: 9.781818389892578
# W1 distance between a and c is: 11.734640121459961
```
Listing 2: Distance computed by W1 with 1000 samples and 100 dimensions.

Training: We use an Exponential Moving Average (EMA) with a decay rate of 0.999. Fig.7 details the hyperparameters used for each dataset. The learning rate for all datasets is set to 2e-4 and the training batch size is 256. For computational efficiency, we cache a large batch of empirical samples from the reference trajectory and draw training batches from the cached data.

Figure 6: Neural network architecture for all experiments. The network size (number of parameters) varies between tasks.

Figure 7: Training hyper-parameters.

| Dataset | time steps | # BI | g(t) | # Parameters | T | SNR | # v_t Langevin |
|---|---|---|---|---|---|---|---|
| Semicircle | 15 | 2000 | 0.2 | 1.21M | 3 | 0.15 | 1 |
| Petal | 30 | 2000 | 0.2 | 1.21M | 2 | 0.15 | 1 |
| GMM | 15 | 2000 | 0.2 | 1.21M | 4 | 0.15 | 1 |
| scRNA (100 dim) | 15 | 4000 | 0.4 | 1.34M | 4 | 0.15 | 1 |

Langevin sampling: The Langevin sampling procedure for the velocity is summarized in Algorithm 2. Given a pre-defined signal-to-noise ratio r (we set snr = 0.15 for all experiments), the Langevin noise scale σ at each time step t and each corrector step is computed by

$$\sigma_t = 2 r^2 g^2\, \frac{\| \epsilon \|_2^2}{\| z(t, m_t) + \hat z(t, m_t) \|_2^2}, \qquad \epsilon \sim \mathcal{N}(0, I). \quad (45)$$

D Algorithms

Algorithm 1 Sampling Procedure of DMSB
Input: policies z(·, ·; θ) and ẑ(·, ·; ϕ); total sampling steps S = t_N / δt; data distributions ρ_{t_i}. Initialize velocity distributions γ_{t_i} = N(0, I) if they are not available.
for s = 0 to S − 1 do
  if s == 0 then
    Sample position data x_{t_0} from ρ_{t_0}.
    if the ground-truth velocity distribution γ_{t_0} is available then
      Sample velocity data v_{t_0} from γ_{t_0}.
    else
      Sample velocity data v_{t_0} by Langevin simulation conditioned on x_{t_0} (Algorithm 2).
    end if
    m_{t_0} = [x_{t_0}, v_{t_0}]ᵀ
  end if
  Simulate the dynamics dm_t = [f(m_t, t) + g(t)Z_t] dt + g(t) dw_t (Eq. (16)).
end for
return m_{t ∈ [t_0, t_N]}

Algorithm 2 Langevin Sampler at the t_i marginal constraint
Input: policies z(·, ·; θ) and ẑ(·, ·; ϕ); previous-timestep predicted velocity v_{t_i}. Sample positions from the ground truth, x_{t_i} ∼ ρ_{t_i}.
for step = 0 to # Langevin steps do
  Sample ϵ ∼ N(0, I).
  Construct m_{t_i} = [x_{t_i}, v_{t_i}]ᵀ.
  Compute ∇_v log p^{θ,ϕ}_{t_i} = [z(t_i, m_{t_i}) + ẑ(t_i, m_{t_i})]/g.
  Compute σ_{t_i} with Eq. (45).
  Langevin update: v_{t_i} ← v_{t_i} + σ_{t_i} ∇_v log p^{θ,ϕ}_{t_i} + √(2σ_{t_i}) ϵ.
end for
return m_{t_i} = [x_{t_i}, v_{t_i}]ᵀ

Algorithm 3 DMSB Training
Input: N + 1 marginal position distributions ρ_{t_i}, i ∈ [0, N]; parameterized policies z(·; θ) and ẑ(·; ϕ); the number of Bregman Iterations B. Initialize position and velocity at time step t_i: m_{t_i} := None for the first iteration.
Algorithm 3 DMSB Training
Input: N + 1 marginal position distributions ρ_{t_i}, i ∈ [0, N]; parametrized policies z(·;θ) and ẑ(·;ϕ); number of Bregman Iterations B.
Initialize position and velocity at time step t_i: m̂_{t_i} := None for the first iteration.
if the ground-truth velocity is available then
  set the prior velocity γ_{t_i} to the ground truth
else
  set the initial velocity γ_{t_i} = N(0, I)
end if
for b = 0 to B - 1 do
  for k = N to 1 do
    z_ϕ, _ = OptSubSet(t_k, t_{k-1}, z_ref = z_θ, z_opt = z_ϕ, η = ϕ, m̂ = None)  [Optimize K_boundary]
  end for
  for k = 0 to N - 1 do
    z_θ, _ = OptSubSet(t_k, t_{k+1}, z_ref = z_ϕ, z_opt = z_θ, η = θ, m̂ = None)  [Optimize K_boundary]
  end for
  z_ϕ, m̂ = OptSubSet(t_N, t_0, z_ref = z_θ, z_opt = z_ϕ, η = ϕ, m̂ = m̂)  [Optimize K_bridge]
  for k = 0 to N - 1 do
    z_θ, _ = OptSubSet(t_k, t_{k+1}, z_ref = z_ϕ, z_opt = z_θ, η = θ, m̂ = m̂)  [Optimize K_boundary]
  end for
  for k = N to 1 do
    z_ϕ, _ = OptSubSet(t_k, t_{k-1}, z_ref = z_θ, z_opt = z_ϕ, η = ϕ, m̂ = None)  [Optimize K_boundary]
  end for
  z_θ, m̂ = OptSubSet(t_0, t_N, z_ref = z_ϕ, z_opt = z_θ, η = θ, m̂ = m̂)  [Optimize K_bridge]
end for

Algorithm 4 Function OptSubSet (optimization for subsets)
Input: initial time t_i and terminal time t_j; reference path measure boundary condition ρ_{t_i}; reference path measure driver z_ref; policy being optimized z_opt with corresponding parameter η; empirical samples m̂ from the last iteration.
Output: z_opt; samples m̂_{t_j} from the reference path measure.
if m̂ is None then
  Sample position data x_{t_i} from ρ_{t_i}.
  if the velocity distribution γ_{t_i} is available then
    Sample conditional velocity data v_{t_i} from γ_{t_i}
  else
    Sample velocity data v_{t_i} by Langevin simulation conditioned on x_{t_i} (see Algorithm.2)
  end if
  m_{t_i} = [x_{t_i}, v_{t_i}]^T
else
  m_{t_i} = m̂
end if
Sample the trajectory m_{t∈[t_i, t_j]} from m_{t_i} using z_ref.
Compute L = αL_MM + (1 - α)L_reg (the SB regularization L_reg [29] is optional).
Update η.
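To make the order of Bregman Projections in Algorithm 3 easier to follow, here is a schematic Python rendering of its outer loop; opt_sub_set abstracts Algorithm 4, integer indices stand for the marginal times t_k, and all names are illustrative rather than part of the released code.

def dmsb_training_loop(z_theta, z_phi, N, B, opt_sub_set):
    # Schematic transcription of Algorithm 3: each call optimizes one policy
    # against the path measure induced by the other (a half-bridge IPF step).
    m_hat = None
    for _ in range(B):
        for k in range(N, 0, -1):   # backward sweeps over K_boundary
            z_phi, _ = opt_sub_set(k, k - 1, z_ref=z_theta, z_opt=z_phi, m=None)
        for k in range(N):          # forward sweeps over K_boundary
            z_theta, _ = opt_sub_set(k, k + 1, z_ref=z_phi, z_opt=z_theta, m=None)
        z_phi, m_hat = opt_sub_set(N, 0, z_ref=z_theta, z_opt=z_phi, m=m_hat)    # K_bridge
        for k in range(N):
            z_theta, _ = opt_sub_set(k, k + 1, z_ref=z_phi, z_opt=z_theta, m=m_hat)
        for k in range(N, 0, -1):
            z_phi, _ = opt_sub_set(k, k - 1, z_ref=z_theta, z_opt=z_phi, m=None)
        z_theta, m_hat = opt_sub_set(0, N, z_ref=z_phi, z_opt=z_theta, m=m_hat)  # K_bridge
    return z_theta, z_phi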
E Additional Diagram

Figure 8: Detailed example diagram of Fig.2 for the 3-marginal case; the training scheme extends easily to general N marginals. The figure consists of two BIs that differ in their training order. Given the reference path measure, we first run the Bregman Projections (BP) within the subset K_boundary sequentially and end with the constraint K_bridge.

F Additional Experiment

Table 4: Results of our algorithm over 3 seeds: MMD and SWD on the 100-dimensional single-cell RNA-seq dataset, without leave-out (LO) and with the marginal at different observation times left out. DMSB outperforms prior work by a large margin for both metrics in all leave-out cases.

LO        Metric  t1            t2            t3            t4            Avg
w/o LO    MMD     0.021 ± 1E-3  0.029 ± 5E-3  0.038 ± 2E-3  0.034 ± 2E-3  0.032 ± 3E-3
          SWD     0.114 ± 5E-2  0.155 ± 2E-2  0.19 ± 3E-2   0.155 ± 1E-2  0.16 ± 2E-2
w/ LO-t1  MMD     0.09 ± 1E-3   0.019 ± 1E-2  0.032 ± 2E-2  0.029 ± 2E-2  0.042 ± 2E-2
          SWD     0.140 ± 2E-2  0.155 ± 1E-2  0.19 ± 2E-2   0.155 ± 1E-2  0.153 ± 3E-2
w/ LO-t2  MMD     0.021 ± 1E-3  0.065 ± 5E-3  0.032 ± 2E-3  0.02 ± 2E-3   0.033 ± 3E-3
          SWD     0.100 ± 5E-2  0.202 ± 2E-2  0.13 ± 3E-2   0.191 ± 1E-2  0.155 ± 2E-2
w/ LO-t3  MMD     0.025 ± 2E-3  0.026 ± 2E-2  0.075 ± 1E-2  0.029 ± 2E-2  0.040 ± 2E-2
          SWD     0.124 ± 2E-2  0.14 ± 1E-2   0.27 ± 2E-2   0.18 ± 1E-2   0.179 ± 3E-2

Table 5: Wasserstein-1 (W1), MMD, energy distance, SWD, and Max-Sliced Wasserstein distance (MWD) on the positions of the 5-dimensional single-cell RNA-seq dataset, using 500 generated samples and 500 ground-truth data points.

Dim=5        Energy  MMD   W1    SWD    MWD
NLSB         0.04    0.10  0.74  0.24   0.48
MIOFLOW      0.09    0.28  0.79  0.388  0.66
DMSB (ours)  0.03    0.06  0.67  0.22   0.41

Table 6: Wasserstein-1 (W1), MMD, energy distance, SWD, and MWD on the velocity of the 5-dimensional single-cell RNA-seq dataset, using 500 generated samples and 500 ground-truth data points.

Dim=5        Energy  MMD   W1    SWD   MWD
NLSB(1)      0.44    1.37  1.75  0.83  1.40
MIOFLOW      0.68    2.11  1.88  0.94  1.54
DMSB (ours)  0.40    0.85  1.67  0.74  1.43

(1) See the special clarification (Appendix.C) for the velocity generated by NLSB.

Table 7: Energy distance, MMD, SWD, and MWD on the velocity of the 100-dimensional single-cell RNA-seq dataset, using 500 generated samples and 500 ground-truth data points.

Dim=100      Energy  MMD   SWD   MWD
NLSB(2)      2.12    1.6   0.94  1.27
MIOFLOW      9.18    2.41  1.89  5.66
DMSB (ours)  0.36    0.18  0.39  0.78

(2) See the special clarification (Appendix.C) for the velocity generated by NLSB.

Figure 9: Comparison of population-level dynamics in 5-dimensional PCA space at the observation times for the scRNA-seq data, using MIOFlow, NLSB, and DMSB. We plot the first 4 principal components (PC). All methods perform well under this experimental setup.

Figure 10: Comparison of estimated velocity in 5-dimensional PCA space at the observation times for the scRNA-seq data, using MIOFlow, NLSB, and DMSB. We plot the first 4 principal components (PC). For the NLSB results, see the special clarification in Appendix.C.

Figure 11: Comparison of estimated velocity in 100-dimensional PCA space at the observation times for the scRNA-seq data, using MIOFlow, NLSB, and DMSB. We plot the first 6 principal components (PC). For the NLSB results, see the special clarification in Appendix.C.

Figure 12: Comparison of DMSB's estimated velocity with the ground truth in 100-dimensional PCA space at the observation times for the scRNA-seq data. We plot the first 6 principal components (PC).

G Intuitions of Propositions and Theorems

1. Remark for Proposition 3.1: Within each half-bridge IPF, the variable Z_t (or Ẑ_t) essentially learns the reverse-time stochastic process induced by Ẑ_t (resp. Z_t). This can also be viewed as minimizing an approximate parametrized negative log-likelihood.

2. Remark for Proposition 4.2: To match the reference path measure in the KL-divergence sense, one needs to match both the intermediate path measure (eq. (9a)) and the boundary condition (eq. (9b)). In the traditional two-boundary SB case, matching the boundary condition is often disregarded, because the data distribution is predefined or the prior is tractable. In our case, however, the velocity is not predefined, so it becomes imperative to address this boundary condition and optimize it via Langevin dynamics.

3. Remark for Proposition 4.3: By the same argument as in the remark for Proposition 4.2, in this scenario there is no need to account for the data distribution, since no position constraints are present when optimizing with K. Consequently, the optimal solution inherently aligns with the reverse diffusion and adapts to the boundary conditions imposed by the reference path measure.

4. Remark for Proposition 4.5: One important and nontrivial fact deserves emphasis: the unique structure of SB implies that the score is proportional to the sum of the forward and backward drift terms, which facilitates the sampling of velocity. Specifically, the score function can be obtained using eq. (11), as supported by the findings in eq. (24). It can also be understood as one realization of Nelson's duality (see Lemma 1 in [46] and [28]).
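Written out in the notation of Algorithm.2, the score identity referenced in the remark for Proposition 4.5 reads as follows (our reconstruction, consistent with the score computation in Algorithm.2 and with eq.(45), rather than a quotation of eq. (11)):

\nabla_v \log p_t^{\theta,\phi}(m_t) = \frac{z(t, m_t) + \hat{z}(t, m_t)}{g(t)},

i.e., the velocity score is proportional to the sum of the forward and backward drifts, which is what makes Langevin sampling of the velocity tractable without a separately trained score network.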
H Complexity

Here we report the wall-clock training and sampling times of our algorithm.

Table 8: Time complexity w.r.t. dimensionality (# marginals = 5).
# dimensions  5      10      50      100
Train         24min  25min   33min   44min
Sampling      1sec   1.6sec  2.0sec  2.02sec

Table 9: Time complexity w.r.t. the number of marginals (Dim = 100).
# marginals   2        3       4       5
Train         32min    25min   33min   44min
Sampling      2.02sec  1.6sec  2.0sec  2.02sec