Published as a conference paper at ICLR 2024

BESPOKE SOLVERS FOR GENERATIVE FLOW MODELS

N. Shaul1, J. Perez2, R. T. Q. Chen3, A. Thabet2, A. Pumarola2, Y. Lipman3,1
1Weizmann Institute of Science  2GenAI, Meta  3FAIR, Meta

ABSTRACT

Diffusion or flow-based models are powerful generative paradigms that are notoriously hard to sample from, as samples are defined as solutions to high-dimensional Ordinary or Stochastic Differential Equations (ODEs/SDEs) which require a large Number of Function Evaluations (NFE) to approximate well. Existing methods to alleviate the costly sampling process include model distillation and designing dedicated ODE solvers. However, distillation is costly to train and can sometimes deteriorate quality, while dedicated solvers still require relatively large NFE to produce high quality samples. In this paper we introduce Bespoke solvers, a novel framework for constructing custom ODE solvers tailored to the ODE of a given pre-trained flow model. Our approach optimizes an order-consistent and parameter-efficient solver (e.g., with 80 learnable parameters), is trained for roughly 1% of the GPU time required for training the pre-trained model, and significantly improves approximation and generation quality compared to dedicated solvers. For example, a Bespoke solver for a CIFAR10 model produces samples with Fréchet Inception Distance (FID) of 2.73 with 10 NFE, and gets to within 1% of the Ground Truth (GT) FID (2.59) for this model with only 20 NFE. On the more challenging ImageNet-64×64, Bespoke samples at 2.2 FID with 10 NFE, and gets within 2% of GT FID (1.71) with 20 NFE.

1 INTRODUCTION

Diffusion models (Sohl-Dickstein et al., 2015; Ho et al., 2020), and more generally flow-based models (Song et al., 2020b; Lipman et al., 2022; Albergo & Vanden-Eijnden, 2022), have become prominent in generation of images (Dhariwal & Nichol, 2021; Rombach et al., 2021), audio (Kong et al., 2020; Le et al., 2023), and molecules (Kong et al., 2020). While training flow models is relatively scalable and efficient, sampling from a flow-based model entails solving a Stochastic or Ordinary Differential Equation (SDE/ODE) in high dimensions, tracing a velocity field defined by the trained neural network. Using off-the-shelf solvers to approximate the solution of this ODE to high precision requires a large Number (i.e., 100s) of Function Evaluations (NFE), making sampling one of the main standing challenges in flow models. Improving the sampling complexity of flow models, without degrading sample quality, will open up new applications that require fast sampling, and will help reduce the carbon footprint and deployment cost of these models.

Current approaches for efficient sampling of flow models divide into two main groups: (i) Distillation: where the pre-trained model is fine-tuned to predict either the final sample (Luhman & Luhman, 2021) or some intermediate solution steps (Salimans & Ho, 2022) of the ODE. Distillation does not guarantee sampling from the pre-trained model's distribution, but, when given access to the training data during distillation training, it is shown to empirically generate samples of comparable quality to the original model (Salimans & Ho, 2022; Meng et al., 2023). Unfortunately, the GPU time required to distill a model is comparable to the training time of the original model (Salimans & Ho, 2022), which is often considerable.
(ii) Dedicated solvers: where the specific structure of the ODE is used to design a more efficient solver (Song et al., 2020a; Lu et al., 2022a;b) and/or to employ a suitable solver family from the numerical analysis literature (Zhang & Chen, 2022; Zhang et al., 2023). The main benefit of this approach is two-fold: First, it is consistent, i.e., as the number of steps (NFE) increases, the samples converge to those of the pre-trained model. Second, it does not require further training/fine-tuning of the pre-trained model, consequently avoiding long additional training times and access to training data. Related to our approach, some works have tried to learn an ODE solver within a certain class (Watson et al., 2021; Duan et al., 2023); however, they do not guarantee consistency and usually introduce moderate improvements over generic dedicated solvers.

Figure 1 (panels: Baseline, Bespoke): Using 10 NFE to sample with our Bespoke solver improves fidelity w.r.t. the baseline (RK2) solver. Visualization of paths was done with the 2D PCA plane approximating the noise and end sample points.

In this paper, we introduce Bespoke solvers, a framework for learning consistent ODE solvers custom-tailored to pre-trained flow models. The main motivation for Bespoke solvers is that different models exhibit sampling paths with different characteristics, leading to local truncation errors that are specific to each instance of a trained model. A key observation of this paper is that optimizing a solver for a particular model can significantly improve the quality of samples for low NFE compared to existing dedicated solvers. Furthermore, Bespoke solvers use a very small number of learnable parameters and consequently are efficient to train. For example, we have trained n ∈ {5, 8, 10} step Bespoke solvers for a pre-trained ImageNet-64×64 flow model with {40, 64, 80} learnable parameters (resp.), producing images with Fréchet Inception Distances (FID) of 2.2, 1.79, 1.71 (resp.), where the latter is within 2% of the Ground Truth (GT) FID (1.68) computed with 180 NFE. The Bespoke solvers were trained (using a rather naive implementation) for roughly 1% of the GPU time required for training the original model. Figure 1 compares sampling at 10 NFE from a pre-trained AFHQ-256×256 flow model with order 2 Runge-Kutta (RK2) and its Bespoke version (RK2-Bes), along with the GT sample that requires 180 NFE.

Our work brings the following contributions:
1. A differentiable parametric family of consistent ODE solvers.
2. A tractable loss that bounds the global truncation error while allowing parallel computation.
3. An algorithm for training a Bespoke n-step solver for a specific pre-trained model.
4. Significant improvement over dedicated solvers in generation quality for low NFE.

2 BESPOKE SOLVERS

We consider a pre-trained flow model taking some prior distribution (noise) $p$ to a target (data) distribution $q$ in data space $\mathbb{R}^d$. The flow model (Chen et al., 2018) is represented by a time-dependent Vector Field (VF) $u : [0,1] \times \mathbb{R}^d \to \mathbb{R}^d$ that transforms a noise sample $x_0 \sim p(x_0)$ to a data sample $x_1 \sim q(x_1)$ by solving the ODE
$$\dot{x}(t) = u_t(x(t)), \qquad (1)$$
with the initial condition $x(0) = x_0 \sim p(x_0)$, from time $t = 0$ until time $t = 1$, where $\dot{x}(t) := \frac{d}{dt}x(t)$. The solution at time $t = 1$, i.e., $x(1) \sim q(x(1))$, is the generated target sample.

Algorithm 1 Numerical ODE solver.
Require: $t_0, x_0$
for $i = 0, 1, \ldots, n-1$ do
    $(t_{i+1}, x_{i+1}) = \mathrm{step}(t_i, x_i; u_t)$
end for
return $x_n$

Numerical ODE solvers.
Solving equation 1 is done in practice with numerical ODE solvers. A numerical solver is defined by an update rule:
$$(t_{\text{next}}, x_{\text{next}}) = \mathrm{step}(t, x; u_t). \qquad (2)$$
The update rule takes as input the current time $t$ and the approximate solution $x$, and outputs the next time step $t_{\text{next}}$ and the corresponding approximation $x_{\text{next}}$ to the true solution $x(t_{\text{next}})$ at time $t_{\text{next}}$. To approximate the solution at some desired end time, i.e., $t = 1$, one first initializes the solution at $t = 0$ and repeatedly applies the update rule in equation 2 $n$ times, as presented in Algorithm 1. The steps are designed so that $t_n = 1$. An ODE solver (step) is said to be of order $k$ if its local truncation error is
$$\| x(t_{\text{next}}) - x_{\text{next}} \| = O\big((t_{\text{next}} - t)^{k+1}\big), \qquad (3)$$
asymptotically as $t_{\text{next}} \to t$, where $t \in [0, 1)$ is arbitrary but fixed and $t_{\text{next}}, x_{\text{next}}$ are defined by the solver, equation 2. A popular family of solvers that offers a wide range of orders is the Runge-Kutta (RK) family (Iserles, 2009). Two of the most popular members of the RK family are (set $h = n^{-1}$):
RK1 (Euler, order 1): $\mathrm{step}(t, x; u_t) = \big(t + h,\ x + h\, u_t(x)\big)$, (4)
RK2 (Midpoint, order 2): $\mathrm{step}(t, x; u_t) = \big(t + h,\ x + h\, u_{t + \frac{h}{2}}(x + \tfrac{h}{2} u_t(x))\big)$. (5)

Algorithm 2 Bespoke solver.
Require: $t_0, x_0$, pre-trained $u_t$, $\theta$
for $i = 0, 1, \ldots, n-1$ do
    $(t_{i+1}, x_{i+1}) \leftarrow \mathrm{step}^\theta(t_i, x_i; u_t)$
end for
return $x_n$

Approach outline. Given a pre-trained $u_t$ and a target number of time steps $n$, our goal is to find a custom (Bespoke) solver that is optimal for approximating the samples $x(1)$ defined via equation 1 from initial conditions sampled according to $x(0) = x_0 \sim p(x_0)$. To that end we develop two components: (i) a differentiable parametric family of update rules $\mathrm{step}^\theta$, with parameters $\theta \in \mathbb{R}^p$ (where $p$ is very small), where sampling is done by replacing $\mathrm{step}$ with $\mathrm{step}^\theta$ in Algorithm 1, see Algorithm 2; and (ii) a tractable loss bounding the global truncation error, i.e., the Root Mean Square Error (RMSE) between the approximate sample $x^\theta_n$ and the GT sample $x(1)$,
Global truncation error: $\mathcal{L}_{\text{RMSE}}(\theta) = \mathbb{E}_{x_0 \sim p(x_0)} \| x(1) - x^\theta_n \|$, (6)
where $x^\theta_n$ is the output of Algorithm 2, and $\|x\| = \big(\frac{1}{d}\sum_{j=1}^d [x^{(j)}]^2\big)^{1/2}$.
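To make the sampling loop concrete, the following is a minimal, self-contained sketch (not the authors' code) of Algorithm 1 with the RK1 and RK2 update rules of equations 4-5. The toy velocity field stands in for the pre-trained network; note that with the midpoint rule each of the n steps costs two function evaluations, so n steps correspond to 2n NFE.

```python
# Minimal sketch of Algorithm 1 with the Euler (RK1) and midpoint (RK2)
# update rules of equations 4-5. The toy velocity field below is a
# placeholder for the pre-trained model u_t(x).
import torch

def u_t(t: float, x: torch.Tensor) -> torch.Tensor:
    # Placeholder velocity field; in practice this is the trained network.
    return -x * (1.0 - t)

def rk1_step(t, x, h, u):
    # Euler (order 1): x_next = x + h * u_t(x), t_next = t + h   (equation 4)
    return t + h, x + h * u(t, x)

def rk2_step(t, x, h, u):
    # Midpoint (order 2): evaluate u at the half step t + h/2    (equation 5)
    x_mid = x + 0.5 * h * u(t, x)
    return t + h, x + h * u(t + 0.5 * h, x_mid)

def sample(x0, n, step, u):
    # Algorithm 1: repeatedly apply the update rule n times so that t_n = 1.
    t, x, h = 0.0, x0, 1.0 / n
    for _ in range(n):
        t, x = step(t, x, h, u)
    return x

x0 = torch.randn(4, 3)                       # noise sample x_0 ~ p(x_0)
x1 = sample(x0, n=10, step=rk2_step, u=u_t)  # 10 steps = 20 NFE with RK2
```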
2.1 PARAMETRIC FAMILY OF ODE SOLVERS THROUGH TRANSFORMED SAMPLING PATHS

Our strategy for defining the parametric family of solvers $\mathrm{step}^\theta$ is to use a generic base ODE solver, such as RK1 or RK2, applied to a parametric family of transformed paths.

Figure 2: Transformed paths.

Transformed sampling paths. We transform the sample trajectories $x(t)$ by applying two components: a time reparametrization and an arbitrary invertible transformation. That is,
$$\bar{x}(r) = \phi_r(x(t_r)), \quad r \in [0, 1], \qquad (7)$$
where $t_r, \phi_r(x)$ are arbitrary functions in a family $\mathcal{F}$ defined by the following conditions:
(i) Smoothness: $t_r : [0,1] \to [0,1]$ is a diffeomorphism¹, and $\phi : [0,1] \times \mathbb{R}^d \to \mathbb{R}^d$ is $C^1$ and a diffeomorphism in $x$. We also assume $r_t$ and $\phi^{-1}_r$ are Lipschitz continuous with a constant $L > 0$.
(ii) Boundary conditions: $t_r$ satisfies $t_0 = 0$ and $t_1 = 1$, and $\phi_0(\cdot)$ is the identity function, i.e., $\phi_0(x) = x$ for all $x \in \mathbb{R}^d$.
Figure 2 depicts a transformation of a path $x(t)$. Note that $\bar{x}(0) = x(0)$; however, the end point $\bar{x}(1)$ does not have to coincide with $x(1)$. Furthermore, as $t_r : [0,1] \to [0,1]$ is a diffeomorphism, $t_r$ is strictly monotonically increasing. The motivation behind the definition of the transformed trajectories is that it allows reconstructing $x(t)$ from $\bar{x}(r)$. Indeed, denoting $r = r_t$ the inverse function of $t = t_r$, we have
$$x(t) = \phi^{-1}_{r_t}\big(\bar{x}(r_t)\big). \qquad (8)$$
Our hope is to find a transformation that simplifies sampling paths and allows the base solver to provide better approximations of the GT samples. The transformed trajectory $\bar{x}(r)$ is defined by a VF $\bar{u}_r(x)$ that can be given an explicit form as follows (proof in Appendix A):

Proposition 2.1. Let $x(t)$ be a solution to equation 1. Denote $\dot\phi_r := \frac{d}{dr}\phi_r$ and $\dot t_r := \frac{d}{dr}t_r$. Then $\bar{x}(r)$ defined in equation 7 is a solution to the ODE (equation 1) with the VF
$$\bar{u}_r(x) = \dot\phi_r\big(\phi^{-1}_r(x)\big) + \dot t_r\, \partial_x\phi_r\big(\phi^{-1}_r(x)\big)\, u_{t_r}\big(\phi^{-1}_r(x)\big). \qquad (9)$$

Solvers via transformed paths. We are now ready to define our parametric family of solvers $\mathrm{step}^\theta(t, x; u_t)$: First we transform the input sample $(t, x)$ according to equation 7 to
$$(r, \bar{x}) = (r_t, \phi_{r_t}(x)). \qquad (10)$$
Next, we perform a step with the base solver of choice,
$$(r_{\text{next}}, \bar{x}_{\text{next}}) = \mathrm{step}(r, \bar{x}; \bar{u}_r), \qquad (11)$$
and lastly, transform back using equation 8 to define the parametric solver $\mathrm{step}^\theta$ via
$$(t_{\text{next}}, x_{\text{next}}) = \mathrm{step}^\theta(t, x; u_t) = \big(t_{r_{\text{next}}},\ \phi^{-1}_{r_{\text{next}}}(\bar{x}_{\text{next}})\big). \qquad (12)$$
The parameters $\theta$ denote the parameterized transformations $t_r$ and $\phi_r$ satisfying the properties of $\mathcal{F}$, together with the choice of a base solver step. In Section 2.2 we derive the explicit rules we use in this paper.

¹A diffeomorphism is a continuously differentiable ($C^1$) function with a continuously differentiable ($C^1$) inverse.

Consistency of solvers. An important property of the parametric solver $\mathrm{step}^\theta$ is consistency. Namely, due to the properties of $\mathcal{F}$, regardless of the particular choice of $t_r, \phi_r \in \mathcal{F}$, the solver $\mathrm{step}^\theta$ has the same local truncation error as the base solver.

Theorem 2.2. (Consistency of parametric solvers) Given arbitrary $t_r, \phi_r$ in the family of functions $\mathcal{F}$ and a base ODE solver of order $k$, the corresponding ODE solver $\mathrm{step}^\theta$ is also of order $k$, i.e.,
$$\| x(t_{\text{next}}) - x_{\text{next}} \| = O\big((t_{\text{next}} - t)^{k+1}\big). \qquad (13)$$
The proof is provided in Appendix B. Therefore, as long as $t_r, \phi_r(x)$ are in $\mathcal{F}$, decreasing the base solver's step size $h \to 0$ will result in our approximated sample $x^\theta_n$ converging to the exact sample $x(1)$ of the trained model in the limit, i.e., $x^\theta_n \to x(1)$ as $n \to \infty$.

2.2 TWO USE CASES

We instantiate the Bespoke solver framework for two cases of interest (a full derivation is in Appendix E), and later prove that our choice of transformations in fact covers all noise scheduler configurations used in the standard diffusion model literature. In our use cases, we consider a time-dependent scaling as our invertible transformation $\phi_r$,
$$\phi_r(x) = s_r x, \quad \text{and its inverse} \quad \phi^{-1}_r(x) = x / s_r, \qquad (14)$$
where $s : [0, 1] \to \mathbb{R}_{>0}$ is a strictly positive $C^1$ scaling function such that $s_0 = 1$ (i.e., satisfying the boundary condition of $\phi$). The transformation of trajectories, i.e., equations 7 and 8, takes the form
$$\bar{x}(r) = s_r\, x(t_r), \quad \text{and} \quad x(t) = \bar{x}(r_t)/s_{r_t}, \qquad (15)$$
and we name this transformation: scale-time. The transformed VF $\bar{u}_r$ (equation 9) is thus
$$\bar{u}_r(x) = \frac{\dot s_r}{s_r}\, x + \dot t_r\, s_r\, u_{t_r}\!\Big(\frac{x}{s_r}\Big). \qquad (16)$$

Use case I: RK1-Bespoke. We consider the RK1 (Euler) method (equation 4) as the base solver step and denote $r_i = ih$, $i \in [n]$, where $[n] = \{0, 1, \ldots, n\}$ and $h = n^{-1}$. Substituting equation 4 in equation 11, we get from equation 12 that
$$\mathrm{step}^\theta(t_i, x_i; u_t) := \Big(t_{i+1},\ \frac{s_i + h\,\dot s_i}{s_{i+1}}\, x_i + h\, \dot t_i\, \frac{s_i}{s_{i+1}}\, u_{t_i}(x_i)\Big), \qquad (17)$$
where we denote $t_i = t_{r_i}$, $\dot t_i = \frac{d}{dr}\big|_{r=r_i} t_r$, $s_i = s_{r_i}$, $\dot s_i = \frac{d}{dr}\big|_{r=r_i} s_r$, and $i \in [n-1]$. The learnable parameters $\theta \in \mathbb{R}^p$ and their constraints are derived from the fact that the functions $t_r, \phi_r$ are members of $\mathcal{F}$. There are $p = 4n - 1$ parameters in total: $\theta = (\theta_t, \theta_s)$, where
$$\theta_t : \big\{ 0 = t_0 < t_1 < \cdots < t_{n-1} < t_n = 1,\ \ \dot t_0, \ldots, \dot t_{n-1} > 0 \big\}, \qquad \theta_s : \big\{ s_1, \ldots, s_n > 0,\ s_0 = 1,\ \dot s_0, \ldots, \dot s_{n-1} \big\}. \qquad (18)$$
Note that we ignore the Lipschitz constant constraints in $\mathcal{F}$ when deriving the constraints for $\theta$.
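For reference, the RK1-Bespoke update of equation 17 is simple enough to write in a few lines. The sketch below uses assumed variable names (not the authors' code): the arrays hold the values $t_i, \dot t_i, s_i, \dot s_i$ of equation 18, and with the identity transformation it reduces to the plain Euler step of equation 4.

```python
# Minimal sketch (assumed names): the RK1-Bespoke update of equation 17.
# t, td, s, sd hold the sequences t_i, \dot t_i, s_i, \dot s_i of equation 18,
# u is the pre-trained velocity field, and h = 1/n.
import torch

def rk1_bespoke_step(i, x, t, td, s, sd, h, u):
    # x_{i+1} = [ (s_i + h*sd_i) x_i + h * td_i * s_i * u_{t_i}(x_i) ] / s_{i+1}
    x_next = ((s[i] + h * sd[i]) * x + h * td[i] * s[i] * u(t[i], x)) / s[i + 1]
    return t[i + 1], x_next

# With the identity transformation (t_i = i*h, td_i = 1, s_i = 1, sd_i = 0)
# this recovers the plain Euler step of equation 4.
n, h = 10, 0.1
t  = [i * h for i in range(n + 1)]
td = [1.0] * (n + 1)
s  = [1.0] * (n + 1)
sd = [0.0] * (n + 1)
u  = lambda t, x: -x * (1.0 - t)            # placeholder velocity field
ti, xi = rk1_bespoke_step(0, torch.randn(4, 3), t, td, s, sd, h, u)
```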
Use case II: RK2-Bespoke. Here we choose the RK2 (Midpoint) method (equation 5) as the base solver step. Similarly to the above, substituting equation 5 in equation 11, we get
$$\mathrm{step}^\theta(t_i, x_i; u_t) := \Big(t_{i+1},\ \frac{s_i}{s_{i+1}}\, x_i + \frac{h}{s_{i+1}}\Big[\frac{\dot s_{i+\frac{1}{2}}}{s_{i+\frac{1}{2}}}\, z_i + \dot t_{i+\frac{1}{2}}\, s_{i+\frac{1}{2}}\, u_{t_{i+\frac{1}{2}}}\!\Big(\frac{z_i}{s_{i+\frac{1}{2}}}\Big)\Big]\Big), \qquad (19)$$
where we set $r_{i+\frac{1}{2}} = (i + \frac{1}{2})h$, and accordingly $t_{i+\frac{1}{2}}$, $\dot t_{i+\frac{1}{2}}$, $s_{i+\frac{1}{2}}$, and $\dot s_{i+\frac{1}{2}}$ are defined, and
$$z_i = \Big(s_i + \frac{h}{2}\dot s_i\Big)\, x_i + \frac{h}{2}\, \dot t_i\, s_i\, u_{t_i}(x_i). \qquad (20)$$
In this case there are $p = 8n - 1$ learnable parameters, $\theta = (\theta_t, \theta_s) \in \mathbb{R}^p$, where
$$\theta_t : \big\{ 0 = t_0 < t_{\frac{1}{2}} < t_1 < \cdots < t_{n-\frac{1}{2}} < t_n = 1,\ \ \dot t_0, \dot t_{\frac{1}{2}}, \ldots, \dot t_{n-1}, \dot t_{n-\frac{1}{2}} > 0 \big\}, \qquad \theta_s : \big\{ s_{\frac{1}{2}}, s_1, \ldots, s_n > 0,\ s_0 = 1,\ \dot s_0, \dot s_{\frac{1}{2}}, \ldots, \dot s_{n-\frac{1}{2}} \big\}. \qquad (21)$$
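The RK2-Bespoke update of equations 19-20 can be sketched analogously. In the hedged example below (assumed names and indexing, not the authors' code), the scale-time parameters are stored on the half-grid $r_0, r_{1/2}, \ldots, r_n$, and the midpoint evaluation uses the transformed field of equation 16.

```python
# Minimal sketch (assumed names/indexing): the RK2-Bespoke update of
# equations 19-20. Parameters live on the half-grid r_0, r_{1/2}, ..., r_n,
# so index 2*i corresponds to r_i and 2*i + 1 to r_{i+1/2}; t, td, s, sd hold
# t, \dot t, s, \dot s at those nodes, and h = 1/n.
import torch

def rk2_bespoke_step(i, x, t, td, s, sd, h, u):
    a, m, b = 2 * i, 2 * i + 1, 2 * i + 2        # nodes r_i, r_{i+1/2}, r_{i+1}
    # z_i = (s_i + (h/2) sd_i) x_i + (h/2) td_i s_i u_{t_i}(x_i)       (eq. 20)
    z = (s[a] + 0.5 * h * sd[a]) * x + 0.5 * h * td[a] * s[a] * u(t[a], x)
    # transformed midpoint velocity \bar u_{r_{i+1/2}}(z_i) from equation 16
    u_mid = (sd[m] / s[m]) * z + td[m] * s[m] * u(t[m], z / s[m])
    # x_{i+1} = (s_i x_i + h \bar u_{r_{i+1/2}}(z_i)) / s_{i+1}        (eq. 19)
    return t[b], (s[a] * x + h * u_mid) / s[b]

# The identity transformation on the half-grid recovers the plain midpoint
# step of equation 5.
n, h = 10, 0.1
t  = [k * h / 2 for k in range(2 * n + 1)]
td = [1.0] * (2 * n + 1)
s  = [1.0] * (2 * n + 1)
sd = [0.0] * (2 * n + 1)
u  = lambda t, x: -x * (1.0 - t)                 # placeholder velocity field
_, x1 = rk2_bespoke_step(0, torch.randn(4, 3), t, td, s, sd, h, u)
```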
Equivalence of scale-time transformations and Gaussian Paths. We note that our scale-time transformation covers all possible trajectories used by diffusion and flow models trained with Gaussian distributions. Denote by $p_t(x)$ the probability density function of the random variable $x(t)$, where $x(t)$ is defined by a random initial sample $x(0) = x_0 \sim p(x_0)$ and solving the ODE in equation 1. When training Diffusion or Flow Matching models, $p_t$ has the form $p_t(x) = \int p_t(x|x_1)\, q(x_1)\, dx_1$, where $p_t(x|x_1) = \mathcal{N}(x\,|\,\alpha_t x_1, \sigma_t^2 I)$. A pair of functions $\alpha, \sigma : [0, 1] \to [0, 1]$ satisfying
$$\alpha_0 = 0 = \sigma_1, \quad \alpha_1 = 1 = \sigma_0, \quad \text{and strictly monotonic } \mathrm{snr}(t) = \alpha_t / \sigma_t \qquad (22)$$
is called a scheduler². We use the term Gaussian Paths for the collection of probability paths $p_t(x)$ achieved by different schedulers. The velocity vector field that generates $p_t(x)$ and results from zero Diffusion/Flow Matching training loss is
$$u_t(x) = \int u_t(x|x_1)\, \frac{p_t(x|x_1)\, q(x_1)}{p_t(x)}\, dx_1, \qquad (23)$$
where $u_t(x|x_1) = \frac{\dot\sigma_t}{\sigma_t} x + \big[\dot\alpha_t - \frac{\dot\sigma_t}{\sigma_t}\alpha_t\big] x_1$, as derived in Lipman et al. (2022). Next, we generalize a result by Kingma et al. (2021) and Karras et al. (2022) to consider marginal sampling paths $x(t)$ defined by $u_t(x)$, and show that any two such paths are related by a scale-time transformation:

Theorem 2.3. (Equivalence of Gaussian Paths and scale-time transformations) Consider a Gaussian Path defined by a scheduler $(\alpha_t, \sigma_t)$, and let $x(t)$ denote the solution of equation 1 with $u_t$ defined in equation 23 and initial condition $x(0) = x_0$. Then, (i) For every other Gaussian Path defined by a scheduler $(\bar\alpha_r, \bar\sigma_r)$ with trajectories $\bar{x}(r)$ there exists a scale-time transformation with $s_1 = 1$ such that $\bar{x}(r) = s_r x(t_r)$. (ii) For every scale-time transformation with $s_1 = 1$ there exists a Gaussian Path defined by a scheduler $(\bar\alpha_r, \bar\sigma_r)$ with trajectories $\bar{x}(r)$ such that $s_r x(t_r) = \bar{x}(r)$.

(Proof in Appendix C.) Assuming an ideal velocity field (equation 23), i.e., the pre-trained model is optimal, this theorem implies that searching over the scale-time transformations is equivalent to searching over all possible Gaussian Paths. Note that in practice we allow $s_1 \neq 1$, expanding beyond the standard space of Gaussian Paths. Another interesting consequence of Theorem 2.3 (simply plug in $t = 1$) is that all ideal velocity fields in equation 23 define the same coupling, i.e., joint distribution, of noise $x_0$ and data $x_1$.

²We use the convention of noise at time $t = 0$ and data at time $t = 1$.

2.3 RMSE UPPER BOUND LOSS

Optimizing the RMSE loss (equation 6) directly is theoretically possible but would require keeping a full computational graph of Algorithm 2, i.e., an $n$-fold composition of $u_t$, leading to a large memory footprint. Therefore, we instead derive an upper bound to the RMSE loss that enables parallel computation over the steps of the solver, considerably reducing memory consumption.

To construct the bound, let us fix an initial condition $x_0 \sim p(x_0)$ and denote as before $x(1)$ to be the exact solution of the sample path (equation 1). Furthermore, consider a candidate solver $\mathrm{step}^\theta$, and denote its $t$ and $x$ coordinate updates by $\mathrm{step}^\theta = (\mathrm{step}^\theta_t, \mathrm{step}^\theta_x)$. Applying Algorithm 2 with $t_0 = 0, x_0$ produces a series of approximations $x^\theta_i$, each corresponding to a time step $t_i$, $i \in [n]$. Lastly, we denote by
$$e^\theta_i = \big\| x(t_i) - x^\theta_i \big\|, \qquad d^\theta_i = \big\| x(t_i) - \mathrm{step}^\theta_x(t_{i-1}, x(t_{i-1}); u_t) \big\| \qquad (24)$$
the global and local truncation errors at time $t_i$, respectively. Our goal is to bound the global error at the final time step $t_n = 1$, i.e., $e^\theta_n$. Using the update rule definition (equation 2) and the triangle inequality we can bound
$$e^\theta_{i+1} \le \big\| x(t_{i+1}) - \mathrm{step}^\theta_x(t_i, x(t_i); u_t) \big\| + \big\| \mathrm{step}^\theta_x(t_i, x(t_i); u_t) - \mathrm{step}^\theta_x(t_i, x^\theta_i; u_t) \big\| \le d^\theta_{i+1} + L^\theta_i\, e^\theta_i,$$
where $L^\theta_i$ is defined to be the Lipschitz constant of the function $\mathrm{step}^\theta_x(t_i, \cdot\,; u_t)$. To simplify notation we set by definition $L^\theta_n = 1$ (this is possible since $L^\theta_n$ does not actually participate in the bound). Using the above bound $n$ times and noting that $e^\theta_0 = 0$ we get
$$e^\theta_n \le \sum_{i=1}^n M^\theta_i\, d^\theta_i, \quad \text{where } M^\theta_i = \prod_{j=i}^n L^\theta_j. \qquad (25)$$
Motivated by this bound we define our RMSE-Bound loss:
$$\mathcal{L}_{\text{RMSE-B}}(\theta) = \mathbb{E}_{x_0 \sim p(x_0)} \sum_{i=1}^n M^\theta_i\, d^\theta_i, \qquad (26)$$
where $d^\theta_i$ is defined in equation 24 and $M^\theta_i$ is defined in equation 25. The constants $L^\theta_i$ depend both on the parameters $\theta$ and the Lipschitz constant $L_u$ of the network $u_t$. As $L_u$ is difficult to estimate, we treat it as a hyper-parameter, denoted $L_\tau$ (in all experiments we use $L_\tau = 1$), and compute $L^\theta_i$ in terms of $\theta$ and $L_\tau$ for our two Bespoke solvers, RK1 and RK2, in Appendix D. Assuming that $L_\tau \ge L_u$, an immediate consequence of the bound in equation 25 is that the RMSE-Bound loss bounds the RMSE loss, i.e., the global truncation error defined in equation 6: $\mathcal{L}_{\text{RMSE}}(\theta) \le \mathcal{L}_{\text{RMSE-B}}(\theta)$.

Algorithm 3 Bespoke training.
Require: pre-trained $u_t$, number of steps $n$
initialize $\theta \in \mathbb{R}^p$
while not converged do
    $x_0 \sim p(x_0)$   ▷ sample noise
    $x(t) \leftarrow$ solve ODE 1   ▷ GT path
    $\mathcal{L} \leftarrow 0$   ▷ init loss
    parallel for $i = 0, \ldots, n-1$ do
        $x^\theta_{i+1} \leftarrow \mathrm{step}^\theta_x(t_i, x^{\text{aux}}_i(t_i); u_t)$
        $\mathcal{L} \mathrel{+}= M^\theta_{i+1} \| x^{\text{aux}}_{i+1}(t_{i+1}) - x^\theta_{i+1} \|$
    end for
    $\theta \leftarrow \theta - \gamma \nabla_\theta \mathcal{L}$   ▷ optimization step
end while
return $\theta$

Implementation of the RMSE-Bound loss. We provide pseudocode for Bespoke training in Algorithm 3. During training, we need access to the GT path $x(t)$ at times $t_i$, $i \in [n]$, which we compute with a generic solver. The Bespoke loss is constructed by plugging $\mathrm{step}^\theta$ (equation 17 or 19) into $d_i$ (equation 24). The gradient $\nabla_\theta \mathcal{L}_{\text{RMSE-B}}(\theta)$ requires the derivatives $\partial x(t_i)/\partial t_i$. Computing the derivatives of $x(t_i)$ can be done using the ODE it obeys, i.e., $\dot{x}(t_i) = u_{t_i}(x(t_i))$. Therefore, a simple way to write the loss ensuring correct gradients w.r.t. $t_i$ is to replace $x(t_i)$ with $x^{\text{aux}}_i(t_i)$, where
$$x^{\text{aux}}_i(t) = x(\llbracket t_i \rrbracket) + u_{\llbracket t_i \rrbracket}\big(x(\llbracket t_i \rrbracket)\big)\, (t - \llbracket t_i \rrbracket), \qquad (27)$$
where $\llbracket \cdot \rrbracket$ denotes the stop-gradient operator; i.e., $x^{\text{aux}}_i(t)$ is linear in $t$, and its value and derivative w.r.t. $t$ coincide with those of $x(t)$ at $t = t_i$. Full details are provided in Appendix F. In Appendix K.1 we provide an ablation experiment comparing different Bespoke losses and corresponding algorithms, including the direct RMSE loss (eq. 6) and our RMSE-Bound loss (eq. 26).
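Putting the pieces together, the following is a simplified sketch of one training iteration in the spirit of Algorithm 3 (not the authors' implementation). The helper `solve_gt_path`, the parameter decoding `decode_theta`, and the weights $M_i$ are placeholders for the exact constructions of Appendices D and F; the stop-gradient trick of equation 27 is realized with `.detach()`.

```python
# Simplified sketch of one Bespoke training iteration (Algorithm 3).
# decode_theta, solve_gt_path, rk2_bespoke_step and the weights M_i are
# placeholders standing in for the exact constructions of Appendices D and F.
import torch

def training_iteration(decode_theta, solve_gt_path, rk2_bespoke_step, u, n, opt):
    x0 = torch.randn(8, 3)                    # noise batch x_0 ~ p(x_0)
    t, td, s, sd, M = decode_theta()          # learnable t_i, td_i, s_i, sd_i, M_i (tensors)
    x_of, u_of = solve_gt_path(u, x0)         # callables returning x(t) and u_t(x(t))
    h, loss = 1.0 / n, 0.0
    for i in range(n):                        # parallelizable over i
        ta, tb = t[2 * i], t[2 * i + 2]
        # x_i^aux(t) = x([t_i]) + u_[t_i](x([t_i])) * (t - [t_i])   (equation 27)
        xa = x_of(ta.detach()) + u_of(ta.detach()) * (ta - ta.detach())
        xb = x_of(tb.detach()) + u_of(tb.detach()) * (tb - tb.detach())
        _, x_pred = rk2_bespoke_step(i, xa, t, td, s, sd, h, u)
        # per-sample RMSE over dimensions, averaged over the batch
        d_i = (xb - x_pred).pow(2).mean(dim=-1).sqrt().mean()
        loss = loss + M[i] * d_i
    opt.zero_grad()
    loss.backward()
    opt.step()
    return float(loss)
```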
3 PREVIOUS WORK

Diffusion models (Sohl-Dickstein et al., 2015; Ho et al., 2020) are a powerful paradigm for generative models that, for sampling, require solving a Stochastic Differential Equation (SDE), or its associated ODE describing a (deterministic) flow process (Song et al., 2020a). Diffusion models have been generalized to paradigms directly aiming to learn a deterministic flow (Lipman et al., 2022; Albergo & Vanden-Eijnden, 2022; Liu et al., 2022). Flow-based models are efficient to train but costly to sample. Previous works have tackled the sampling complexity of flow models by building dedicated solver schemes and by distillation.

Dedicated Solvers. This line of work introduced specialized ODE solvers exploiting the structure of the sampling ODE. Lu et al. (2022a); Zhang & Chen (2022) utilize the semi-linear structure of the score/ϵ-based sampling ODE to adopt a method of exponential integrators. Zhang et al. (2023) further introduced refined error conditions to fulfill desired order conditions and achieve better sampling, while Lu et al. (2022b) adapted the method to guided sampling. Karras et al. (2022) suggested transforming the ODE to sample a different Gaussian Path for more efficient sampling, while also suggesting non-uniform time steps. In principle, all of these methods effectively propose, based on intuition and heuristics, to apply a particular scale-time transformation to the sampling trajectories of the pre-trained model for more efficient sampling, while Bespoke solvers search for an optimal transformation within the entire space of scale-time transformations. Other works also aimed at learning the solver: Dockhorn et al. (2022) (GENIE) introduced a higher-order solver and distilled the necessary JVP for their method; Watson et al. (2021) (DDSS) optimized a perceptual loss considering a family of generalized Gaussian diffusion models; Lam et al. (2021) improved the denoising process using bilateral filters, thereby indirectly affecting the efficiency of the ODE solver; Duan et al. (2023) suggested learning a solver for diffusion models by replacing every other function evaluation with a linear subspace projection. Our Bespoke solvers belong to this family of learned solvers; however, they are consistent by construction (Theorem 2.2) and minimize a bound on the solution error (for an appropriate Lipschitz constant parameter).

Distillation. Distillation techniques aim to simplify sampling from a trained model by fine-tuning or training a new model to produce samples with fewer function evaluations. Luhman & Luhman (2021) directly regressed the trained model's samples, while Salimans & Ho (2022); Meng et al. (2023) built a sequence of models, each reducing the sampling complexity by a factor of 2. Song et al. (2023) distilled a consistency map that enables large time steps in the probability flow; Liu et al. (2022) retrained a flow-based method based on samples from a previously trained flow. Yang et al. (2023) used distillation to reduce model size while maintaining the quality of the generated images. The main drawbacks of distillation methods are their long training time (Salimans & Ho, 2022) and their lack of consistency, i.e., they do not sample from the distribution of the pre-trained model.

4 EXPERIMENTS

Figure 3: Bespoke RK1/2, ImageNet-64 FM-OT.
Figure 4: Bespoke solver applied to EDM's (Karras et al., 2022) CIFAR10 published model.

Models and datasets.
Our method works with pre-trained models: we use the pre-trained CIFAR10 (Krizhevsky & Hinton, 2009) model of Song et al. (2020b) with published weights from EDM (Karras et al., 2022). Additionally, we trained diffusion/flow models on the datasets CIFAR10, AFHQ-256 (Choi et al., 2020a), and ImageNet-64/128 (Deng et al., 2009). Specifically, for ImageNet, as recommended by the authors (ima), we used the official face-blurred data (64×64 downsampled using the open source preprocessing scripts from Chrabaszcz et al. (2017)). For diffusion models, we used an ϵ-Variance Preserving (ϵ-VP) parameterization and schedule (Ho et al., 2020; Song et al., 2020b). For flow models, we used Flow Matching (Lipman et al., 2022) with Conditional Optimal Transport (FM-OT), and Flow Matching/v-prediction with Cosine Scheduling (FM/v-CS) (Salimans & Ho, 2022; Albergo & Vanden-Eijnden, 2022). Note that Flow Matching methods directly provide the velocity vector field $u_t(x)$, and we converted ϵ-VP to a velocity field using the identity in Song et al. (2020b). For conditional sampling we apply classifier-free guidance (Ho & Salimans, 2022), so each evaluation uses two forward passes.
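The guidance formula is not spelled out in the paper; the sketch below shows the standard classifier-free guidance combination of Ho & Salimans (2022) applied to a conditional velocity field, which is why each guided evaluation costs two forward passes. The guidance weight w, the null-label convention, and the model signature u(t, x, label) are assumptions, not the authors' interface.

```python
# Standard classifier-free guidance for a conditional velocity field
# (assumed interface; the guidance weight w and null-label convention are
# not specified in the paper).
import torch

def guided_velocity(u, t, x, label, null_label, w):
    u_cond = u(t, x, label)          # forward pass with the class condition
    u_uncond = u(t, x, null_label)   # forward pass with the "empty" condition
    return u_uncond + w * (u_cond - u_uncond)

# Toy usage with a placeholder model.
u = lambda t, x, c: -x if c is None else -x + 0.1 * c
v = guided_velocity(u, 0.5, torch.randn(2, 3), torch.ones(3), None, w=1.5)
```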
Table 1: CIFAR10 sampling.
Method                          | NFE       | FID
Distillation
  Zheng et al. (2023)           | 1         | 3.78
  Luhman & Luhman (2021)        | 1         | 9.36
  Salimans & Ho (2022)          | 1 / 2 / 8 | 9.12 / 4.51 / 2.57
Dedicated solvers
  DDIM (Song et al., 2020a)     | 10 / 20   | 13.36 / 6.84
  DPM (Lu et al., 2022a)        | 10 / 20   | 4.7 / 3.99
  DEIS (Zhang & Chen, 2022)     | 10 / 20   | 4.17 / 2.86
  GENIE (Dockhorn et al., 2022) | 10 / 20   | 5.28 / 3.94
  DDSS (Watson et al., 2021)    | 10 / 20   | 7.86 / 4.72
  RK2-BES ϵ-VP                  | 10 / 20   | 3.31 / 2.75
  RK2-BES FM/v-CS               | 10 / 20   | 2.89 / 2.64
  RK2-BES FM-OT                 | 10 / 20   | 2.73 / 2.59

Bespoke hyper-parameters and optimization. As our base ODE solvers, we tested RK1 (Euler) and RK2 (Midpoint). Furthermore, we have two hyper-parameters: $n$, the number of steps, and $L_\tau$, the Lipschitz constant from Lemmas D.2 and D.3. We train our models with $n \in \{4, 5, 8, 10, 12\}$ steps and fix $L_\tau = 1$. Ground Truth (GT) sample trajectories, $x(t_i)$, are computed with an adaptive RK45 solver (Shampine, 1986). We compute FID (Heusel et al., 2017), and validation RMSE (equation 6) is computed on a set of 10K fresh noise samples $x_0 \sim p(x_0)$; Figure 12 depicts an example of RMSE vs. training iterations for different $n$ values. Unless otherwise stated, below we report results on the best FID iteration and show samples on the best validation RMSE iteration. Figures 21, 22, 23 depict the learned Bespoke solvers' parameters $\theta$ for the experiments presented below; note the differences across the learned schemes for different models and datasets.

Bespoke RK1 vs. RK2. We compared RK1 and RK2 and their Bespoke versions on CIFAR10 and ImageNet-64 models (FM-OT and FM/v-CS). Figures 3, 9, and 10 show best validation RMSE (and corresponding PSNR). Using the same budget of function evaluations, RK2/RK2-Bespoke produce considerably lower validation RMSE compared to RK1/RK1-Bespoke, respectively. We therefore opted for RK2/RK2-Bespoke for the rest of the experiments below.

CIFAR10. We tested our method on the pre-trained CIFAR10 ϵ-VP model (Song et al., 2020b) released by EDM (Karras et al., 2022). Figure 4 compares our RK2-Bespoke solver to the EDM method, which corresponds to a particular choice of scaling, $s_i$, and time step discretization, $t_i$. Euler and EDM curves are computed as originally implemented in EDM, where the latter achieves FID=3.05 at 35 NFE, comparable to the result reported by EDM. Using our RK2-Bespoke solver, we achieved an FID of 2.99 with 20 NFE, providing a 42% reduction in NFE. Additionally, we tested our method on three models we trained ourselves on CIFAR10, namely ϵ-VP, FM/v-CS, and FM-OT. Table 1 compares our best FID for each model with different baselines, demonstrating superior generation quality for low NFE among all dedicated solvers; e.g., for NFE=10 we improve the FID of the runner-up by over 34% (from 4.17 to 2.73) using the RK2-Bespoke FM-OT model. Table 4 lists best FID values for different NFE, along with the GT FID for the model and the fraction of time Bespoke training took compared to the original model's training time; with 20 NFE, our RK2-Bespoke solvers achieved FID within 8%, 1%, 1% (resp.) of the GT solvers' FID. Although close, our Bespoke solver does not match distillation's performance; however, our approach is much faster to train, requiring 1% of the original GPU training time with our naive implementation that re-samples the model at each iteration. Figure 11 shows FID/RMSE/PSNR vs. NFE, where PSNR is computed w.r.t. the GT solver's samples.

Figure 5: Bespoke RK2 solvers vs. RK1/2/4 solvers on CIFAR-10, ImageNet-64, and ImageNet-128: FID vs. NFE (top row), and RMSE vs. NFE (bottom row). PSNR vs. NFE is shown in Figure 13. (Panels: ImageNet-64 ϵ-pred, ImageNet-64 FM/v-CS, ImageNet-64 FM-OT, ImageNet-128 FM-OT; baselines include RK1, RK2, RK4, and DPM-2.)

Table 2: ImageNet Bespoke solvers. For each model we report the best FID per NFE, the GT FID of the model, FID as a % of GT FID, and the fraction of GPU time (in %) required to train the Bespoke solver w.r.t. training the original model.
ImageNet-64
  RK2-BES ϵ-VP (GT FID 1.83)    | NFE 8/10/16/20/24 | FID 3.63/2.96/2.14/1.93/1.84 | % of GT 229/163/120/109/101 | %Time 3.5/3.6/3.6/3.5/3.6
  RK2-BES FM/v-CS (GT FID 1.68) | NFE 8/10/16/20/24 | FID 2.95/2.20/1.79/1.71/1.69 | % of GT 176/131/107/102/101 | %Time 1.4/1.6/1.8/1.5/2.0
  RK2-BES FM-OT (GT FID 1.68)   | NFE 8/10/16/20/24 | FID 3.10/2.26/1.84/1.77/1.71 | % of GT 185/135/110/105/102 | %Time 1.6/1.6/1.7/1.7/1.8
ImageNet-128
  RK2-BES FM-OT (GT FID 2.30)   | NFE 8/10/16/20/24 | FID 5.28/3.58/2.64/2.45/2.31 | % of GT 230/156/115/107/101 | %Time 1.1/1.1/1.2/1.2/1.2

ImageNet 64/128. We further experimented with the more challenging ImageNet-64×64 / 128×128 datasets. For ImageNet-64 we also trained 3 models as described above. For ImageNet-128, due to computational budget constraints, we only trained FM-OT (training requires nearly 2000 GPU days). Figure 5 compares RK2-Bespoke to various baselines, including DPM 2nd order (Lu et al., 2022a). As can be seen in the graphs, the Bespoke solvers improve both FID and RMSE. Interestingly, the Bespoke sampling takes all methods to similar RMSE levels, a fact that can be partially explained by Theorem 2.3. In Table 2, similar to Table 4, we report the best FID per NFE for the Bespoke solvers we trained, the GT FID of the model, the % of GT achieved by the Bespoke solver, and the fraction of GPU time (in %) it took to train each Bespoke solver compared to training the original pre-trained model.
Lastly, Figures 6, 7, 27, 28, 29, 25, 26 depict qualitative sampling examples for RK2-Bespoke and RK2 solvers. Note the significant improvement of fidelity in the Bespoke samples to the ground truth. AFHQ-256. We tested our method on the AFHQ dataset (Choi et al., 2020b) resized to 256 256 where as pre-trained model we used a FM-OT model we trained as described above. Figure 14 depicts PSNR/RMSE curves for the RK2-Bespoke solvers and baselines, and Figures 7 and 24 show qualitative sampling examples for RK2-Bespoke and RK2 solvers. Notice the high fidelity of the Bespoke generation samples. Published as a conference paper at ICLR 2024 GT NFE=20 NFE=10 NFE=8 GT NFE=20 NFE=10 NFE=8 GT NFE=20 NFE=10 NFE=8 RK2 RK2-BES RK2 RK2-BES Figure 6: Comparison of FM-OT and FM/v-CS Image Net-64 samples with RK2 and bespoke-RK2 solvers. Comparison to DPM-2 samples are in Figure 30. More examples are in Figures 27, 28, and 29. The similarity of generated images across models can be explained by their identical noise-to-data coupling (Theorem 2.3). GT NFE=20 NFE=10 NFE=8 GT NFE=20 NFE=10 NFE=8 GT NFE=20 NFE=10 NFE=8 Image Net-128 RK2 RK2-BES RK2 RK2-BES RK2 RK2-BES Figure 7: FM-OT Image Net-128 (top) and AFHQ-256 (bottom) samples with RK2 and bespoke-RK2 solvers. More examples are in Figures 25, 26 and 24. Ablations. We conducted two ablation experiments. First, Figure 16 shows the effect of training only time transform (keeping sr 1) and scale transformation (keeping tr = r). While the time transform is more significant than scale transform, incorporating scale improves RMSE for low NFE (which aligns with Theorem 2.2), and improve FID. Second, Figure 18 shows application of RK2-Bespoke solver trained on Image Net-64 applied to Image Net-128. The transferred solver, while sub-optimal compared to the Bespoke solver, still considerably improves the RK2 baseline in RMSE, and improves FID for higher NFE (16,20). Reusing Bespoke solvers can potentially be a cheap option to improve solvers. 5 CONCLUSIONS, LIMITATIONS AND FUTURE WORK This paper develops an algorithm for finding low-NFE ODE solvers custom-tailored to general pretrained flow models. Through extensive experiments we found that different models can benefit greatly from their own optimized solvers in terms of global truncation error (RMSE) and generation quality (FID). Currently, training a Bespoke solver requires roughly 1% of the original model s training time, which can probably be still be made more efficient, e.g., by using training data or pre-processing sampling paths. A limitation of our framework is that it requires separate training for each target NFE and/or choice of guidance weight. For general NFE solvers one may consider a combined loss and/or continuous representation of φr, tr, while guidance weight or even more general conditions can be used to directly condition φr, tr. More general/expressive models for φr, tr have the potential to further improve fast sampling of pre-trained models. Published as a conference paper at ICLR 2024 ACKNOWLEDGEMENTS NS is supported by a grant from Israel CHE Program for Data Science Research Centers. Imagenet website. https://www.image-net.org/. Michael S. Albergo and Eric Vanden-Eijnden. Building normalizing flows with stochastic interpolants, 2022. Ricky T. Q. Chen. torchdiffeq, 2018. URL https://github.com/rtqichen/ torchdiffeq. Ricky TQ Chen, Yulia Rubanova, Jesse Bettencourt, and David K Duvenaud. Neural ordinary differential equations. Advances in neural information processing systems, 31, 2018. 
Yunjey Choi, Youngjung Uh, Jaejun Yoo, and Jung-Woo Ha. Stargan v2: Diverse image synthesis for multiple domains, 2020a. Yunjey Choi, Youngjung Uh, Jaejun Yoo, and Jung-Woo Ha. Stargan v2: Diverse image synthesis for multiple domains. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, 2020b. Patryk Chrabaszcz, Ilya Loshchilov, and Frank Hutter. A downsampled variant of imagenet as an alternative to the cifar datasets. ar Xiv preprint ar Xiv:1707.08819, 2017. Jia Deng, Wei Dong, Richard Socher, Li-Jia Li, Kai Li, and Fei-Fei Li. Imagenet: A large-scale hierarchical image database. In The IEEE Conference on Computer Vision and Pattern Recognition (CVPR), 2009. Prafulla Dhariwal and Alexander Nichol. Diffusion models beat gans on image synthesis. Advances in neural information processing systems, 34:8780 8794, 2021. Tim Dockhorn, Arash Vahdat, and Karsten Kreis. Genie: Higher-order denoising diffusion solvers. Advances in Neural Information Processing Systems, 35:30150 30166, 2022. Zhongjie Duan, Chengyu Wang, Cen Chen, Jun Huang, and Weining Qian. Optimal linear subspace search: Learning to construct fast and high-quality schedulers for diffusion models. ar Xiv preprint ar Xiv:2305.14677, 2023. Martin Heusel, Hubert Ramsauer, Thomas Unterthiner, Bernhard Nessler, and Sepp Hochreiter. Gans trained by a two time-scale update rule converge to a local nash equilibrium. In Advances in Neural Information Processing Systems (Neur IPS), 2017. Jonathan Ho and Tim Salimans. Classifier-free diffusion guidance. ar Xiv preprint ar Xiv:2207.12598, 2022. Jonathan Ho, Ajay Jain, and Pieter Abbeel. Denoising diffusion probabilistic models. Advances in neural information processing systems, 33:6840 6851, 2020. Arieh Iserles. A first course in the numerical analysis of differential equations. Number 44. Cambridge university press, 2009. Tero Karras, Miika Aittala, Timo Aila, and Samuli Laine. Elucidating the design space of diffusionbased generative models. Advances in Neural Information Processing Systems, 35:26565 26577, 2022. Diederik Kingma, Tim Salimans, Ben Poole, and Jonathan Ho. Variational diffusion models. Advances in neural information processing systems, 34:21696 21707, 2021. Diederik P. Kingma and Jimmy Ba. Adam: A method for stochastic optimization, 2017. Published as a conference paper at ICLR 2024 Zhifeng Kong, Wei Ping, Jiaji Huang, Kexin Zhao, and Bryan Catanzaro. Diffwave: A versatile diffusion model for audio synthesis. ar Xiv preprint ar Xiv:2009.09761, 2020. Alex Krizhevsky and Geoffrey Hinton. Learning multiple layers of features from tiny images. In University of Toronto, Canada, 2009. Max WY Lam, Jun Wang, Rongjie Huang, Dan Su, and Dong Yu. Bilateral denoising diffusion models. ar Xiv preprint ar Xiv:2108.11514, 2021. Matthew Le, Apoorv Vyas, Bowen Shi, Brian Karrer, Leda Sari, Rashel Moritz, Mary Williamson, Vimal Manohar, Yossi Adi, Jay Mahadeokar, et al. Voicebox: Text-guided multilingual universal speech generation at scale. ar Xiv preprint ar Xiv:2306.15687, 2023. Yaron Lipman, Ricky T. Q. Chen, Heli Ben-Hamu, Maximilian Nickel, and Matt Le. Flow matching for generative modeling. ar Xiv preprint ar Xiv:2210.02747, 2022. Xingchao Liu, Chengyue Gong, and Qiang Liu. Flow straight and fast: Learning to generate and transfer data with rectified flow. ar Xiv preprint ar Xiv:2209.03003, 2022. Cheng Lu, Yuhao Zhou, Fan Bao, Jianfei Chen, Chongxuan Li, and Jun Zhu. Dpm-solver: A fast ode solver for diffusion probabilistic model sampling in around 10 steps. 
Advances in Neural Information Processing Systems, 35:5775 5787, 2022a. Cheng Lu, Yuhao Zhou, Fan Bao, Jianfei Chen, Chongxuan Li, and Jun Zhu. Dpm-solver++: Fast solver for guided sampling of diffusion probabilistic models. ar Xiv preprint ar Xiv:2211.01095, 2022b. Eric Luhman and Troy Luhman. Knowledge distillation in iterative generative models for improved sampling speed. ar Xiv preprint ar Xiv:2101.02388, 2021. Chenlin Meng, Robin Rombach, Ruiqi Gao, Diederik Kingma, Stefano Ermon, Jonathan Ho, and Tim Salimans. On distillation of guided diffusion models. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 14297 14306, 2023. Robin Rombach, Andreas Blattmann, Dominik Lorenz, Patrick Esser, and Bj orn Ommer. Highresolution image synthesis with latent diffusion models, 2021. Tim Salimans and Jonathan Ho. Progressive distillation for fast sampling of diffusion models. ar Xiv preprint ar Xiv:2202.00512, 2022. Lawrence F Shampine. Some practical runge-kutta formulas. Mathematics of computation, 46(173): 135 150, 1986. Jascha Sohl-Dickstein, Eric Weiss, Niru Maheswaranathan, and Surya Ganguli. Deep unsupervised learning using nonequilibrium thermodynamics. In International conference on machine learning, pp. 2256 2265. PMLR, 2015. Jiaming Song, Chenlin Meng, and Stefano Ermon. Denoising diffusion implicit models. ar Xiv preprint ar Xiv:2010.02502, 2020a. Yang Song, Jascha Sohl-Dickstein, Diederik P Kingma, Abhishek Kumar, Stefano Ermon, and Ben Poole. Score-based generative modeling through stochastic differential equations. ar Xiv preprint ar Xiv:2011.13456, 2020b. Yang Song, Prafulla Dhariwal, Mark Chen, and Ilya Sutskever. Consistency models. 2023. Daniel Watson, William Chan, Jonathan Ho, and Mohammad Norouzi. Learning fast samplers for diffusion models by differentiating through sample quality. In International Conference on Learning Representations, 2021. Xingyi Yang, Daquan Zhou, Jiashi Feng, and Xinchao Wang. Diffusion probabilistic model made slim. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 22552 22562, 2023. Published as a conference paper at ICLR 2024 Qinsheng Zhang and Yongxin Chen. Fast sampling of diffusion models with exponential integrator. ar Xiv preprint ar Xiv:2204.13902, 2022. Qinsheng Zhang, Jiaming Song, and Yongxin Chen. Improved order analysis and design of exponential integrator for diffusion models sampling. ar Xiv preprint ar Xiv:2308.02157, 2023. Hongkai Zheng, Weili Nie, Arash Vahdat, Kamyar Azizzadenesheli, and Anima Anandkumar. Fast sampling of diffusion models via operator learning. In International Conference on Machine Learning, pp. 42390 42402. PMLR, 2023. Published as a conference paper at ICLR 2024 A TRANSFORMED PATHS (Appendix to Section 2.1.) Proposition A.1. Let x(t) be a solution to equation 1. Denote φr := d drφr and tr := d drtr. Then x(r) defined in equation 7 is a solution to the ODE (equation 1) with the VF ur(x) = φr(φ 1 r (x)) + tr xφr(φ 1 r (x))utr(φ 1 r (x)). (9) Proof. Differentiating x(r) in equation 7, i.e., x(r) = φr(x(tr)) w.r.t. r and using the chain rule gives dr(φr(x(tr))) = φr(x(tr)) + xφr(x(tr)) x(tr) tr = φr(x(tr)) + xφr(x(tr))utr(x(tr)) tr = φr(φ 1 r ( x(r))) + xφr(φ 1 r ( x(r)))utr(φ 1 r ( x(r))) tr where in the third equality we used the fact that x(t) solves the ODE in equation 1 and therefore x(t) = ut(x(t)); and in the last equality we applied φ 1 r to both sides of equation 7, i.e., x(tr) = φ 1 r ( x(r)). 
The above equation shows that x(r) = ur( x(r)), (28) where ur(x) is defined in equation 9, as required. B CONSISTENCY OF SOLVERS (Appendix to Section 2.1.) Figure 8: Proof notations and setup. Theorem 2.2. (Consistency of parametric solvers) Given arbitrary tr, φr in the family of functions F and a base ODE solver of order k, the corresponding ODE solver stepθ is also of order k, i.e., x(tnext) xnext = O((tnext t)k+1). (13) Proof. Here (t, x) is our input sample x Rd at time t [0, 1]. By definition r = rt, rnext = r + h, and tnext = trnext. Furthermore, by definition x = φr(x) is a sample at time r; x(rnext) is the solution to the ODE defined by ur starting from (r, x); xnext is an approximation to x(rnext) as generated from the base ODE solver step. Lastly, xnext = φ 1 rnext( xrnext) and x(tnext) = φ 1 rnext( x(rnext)). See Figure 8 for an illustration visualizing this setup. Now, since step is of order k we have that x(rnext) xnext = x(rnext) step( x, r; ur) = O((rnext r)k+1). (29) Published as a conference paper at ICLR 2024 x(tnext) xnext = x(tnext) φ 1 rnext( xnext) = x(tnext) φ 1 rnext( x(rnext) + O((rnext r)k+1)) = x(tnext) φ 1 rnext( x(rnext)) + O((rnext r)k+1) = O((rnext r)k+1) = O((tnext t)k+1), where in the first equality we used the definition of xnext; in the second equality we used equation 29; in the third equality we used the fact that φ 1 r is Lipschitz with constant L (for all r); in the fourth equality we used the definition of the path transform, x(tnext) = φ 1 rnext( x(rnext)) as mentioned above; and in the last equality we used the fact that rt is also Lipschitz with a constant L and therefore rnext r = rtnext rt = O(tnext t). C EQUIVALENCE OF GAUSSIAN PATHS AND SCALE-TIME TRANSFORMATIONS (Appendix to Section 2.2.) Theorem 2.3. (Equivalence of Gaussian Paths and scale-time transformation) Consider a Gaussian Path defined by a scheduler (αt, σt), and let x(t) denote the solution of equation 1 with ut defined in equation 23 and initial condition x(0) = x0. Then, (i) For every other Gaussian Path defined by a scheduler ( αr, σr) with trajectories x(r) there exists a scale-time transformation with s1 = 1 such that x(r) = srx(tr). (ii) For every scale-time transformation with s1 = 1 there exists a Gaussian Path defined by a scheduler ( αr, σr) with trajectories x(r) such that srx(tr) = x(r). Proof of theorem 2.3. Consider two arbitrary schedulers (αt, σt) and ( αr, σr). We can find sr, tr such that αr = srαtr, σr = srσtr. (30) Indeed, one can check the following are such sr, tr: tr = snr 1(snr(r)), sr = σr where we remember snr is strictly monotonic as defined in equation 22, hence invertible. On the other hand, given an arbitrary scheduler (αt, σt) and an arbitrary scale-time transformation (tr, sr) with s1 = 1, we can define a scheduler ( αr, σr) via equation 30. For case (i), we are given another scheduler αr, σr and define a scale-time transformation sr, tr with equation 31. For case (ii), we are given a scale-time transformation sr, tr and define a scheduler αr, σr by equation 30. Now, the scheduler αr, σr defines sampling paths x(r) given by the solution of the ODE in equation 1 with the marginal VF u(1) r (x) defined in equation 23, i.e., u(1) r (x) = Z ur(x|x1) pr(x|x1)q(x1) pr(x) dx1, (32) where ur(x|x1) = σr σr x + h αr σr αr The scale-time transformation sr, tr gives rise to a second VF u(2) r (x) as in equation 16, u(2) r (x) = sr sr x + trsrutr where ut is the VF defined by the scheduler (αt, σt) and equation 23. 
Published as a conference paper at ICLR 2024 By uniqueness of ODE solutions, the theorem will be proved if we show that u(1) r (x) = u(2) r (x), x Rd, r [0, 1]. (34) For that end, we use the notation of determinants to express ur(x|x1) = 1 0 x x1 σr αr 1 σr αr 0 where x, x1 Rd and αr, σr, αr, σr R as in vector cross product. Differentiating αr, σr w.r.t. r gives αr = srαtr + sr αtr tr, σr = srσtr + sr σtr tr. (36) Using the bi-linearity of determinants shows that: ur(x|x1) = 1 0 x x1 σr αr 1 σr αr 0 0 x x1 srσtr srαtr 1 srσtr + sr σtr tr srαtr + sr αtr tr 0 0 x x1 srσtr srαtr 1 srσtr srαtr 0 0 x x1 srσtr srαtr 1 sr σtr tr sr αtr tr 0 sr x + sr tr 0 x sr x1 σtr αtr 1 σtr αtr 0 sr x + sr trutr where in the second equality we substitute σr, αr as in equation 36, in the third and fourth equality we used the bi-linearity of determinants, and in the last equality we used the definition of ut(x|x1) = σt σt x + h αt σt αt i x1 expressed in determinants notation. Furthermore, since pr(x|x1) = N(x|srαtrx1, s2 rσ2 tr I) N x αtrx1, σ2 tr I = ptr we have that pr(x1|x) = ptr Therefore, Z ur(x|x1) pr(x|x1)q(x1) pr(x) dx1 = E pr(x1|x) ur(x|x1) = Eptr(x1| x sr x + sr trutr sr x + sr tr Eptr(x1| x sr x + sr trutr where in the first equality we used Bayes rule, in the second equality we substitute ur(x|x1) and pr(x1|x) as above, and in the last equality we used the definition of ut as in equation 23. We have proved equation 34 and that concludes the proof. Published as a conference paper at ICLR 2024 D LIPSCHITZ CONSTANTS OF STEPθ. (Appendix to Section 2.3.) We are interested in computing Lθ i , a Lipschitz constant of the bespoke solver step function stepθ x(ti, ; ut). Namely, Lθ i should satisfy stepθ x(ti, x ; ut) stepθ x(ti, y ; ut) Lθ i x y , x, y Rd. (39) We remember that stepθ x(ti, ; ut) is defined using a base solver and the VF uri( ); hence, we begin by computing a Lipschitz constant for uri denoted L u(ri) in an auxiliary lemma: Lemma D.1. Assume that the original velocity field ut has a Lipschitz constant Lu > 0. Then for every ri [0, 1], Lτ Lu, and x, y Rd uri(x) uri(y) L u(ri) x y , (40) L u(ri) = | si| si + ti Lτ (41) Proof of lemma D.1. Since the original velocity field u has a Lipshitz constant Lu > 0, for every t [0, 1] and x, y Rd ut(x) ut(y) Lu x y . (42) uri(x) uri(y) = si si x + tisiuti si y + tisiuti = si si (x y) + tisi si x y + tisi We first apply the auxiliary lemma D.1 to compute a Lipschitz constant of stepθ x(ti, ; ut) with RK1 (Euler method) as the base solver in lemma D.2 and for RK2 (Midpoint method) as the base solver in lemma D.3. Lemma D.2. (RK1 Lipschitz constant) Assume that the original velocity field ut has a Lipschitz constant Lu > 0. Then, for every Lτ Lu, Lθ i = si si+1 (1 + h L u(ri)) , (48) is a Lipschitz constant of RK1-Bespoke update rule, where L u(ri) = | si| si + ti Lτ. (49) Proof of lemma D.2. We begin with writing an explicit expression of stepθ x(ti, x, ; ut) for Euler solver in terms of the transformed velocity field ur. That is, stepθ x(ti, x, ; ut) = 1 si+1 [six + h uri(six)] . (50) Published as a conference paper at ICLR 2024 So that applying the triangle inequality and lemma D.1 gives stepθ x(ti, x; ut) stepθ x(ti, y; ut) = 1 si+1 six + h uri(six) [siy + h uri(siy)] si si+1 x y + h si+1 uri(six) uri(siy) si si+1 x y + h si+1 1 + h | si| Lemma D.3. (RK2 Lipschitz constant) Assume that the original velocity field ut has a Lipschitz constant Lu > 0. 
Then for every Lτ Lu Lθ i = si si+1 1 + h L u(ri+ 1 2 L u(ri) (51) is a Lipschitz constant of RK2-Bespoke update rule, where L u(ri) = | si| si + ti Lτ. (52) Proof of lemma D.3. We begin by writing explicit expression of stepθ x(ti, x; ut) for RK2 (Midpoint) method in terms of the transformed velocity field ur. We set z = six + h 2 uri(six), w = siy + h 2 uri(siy). (53) then stepθ x(ti, x, ; ut) = 1 si+1 h six + h uri+ 1 2 (z) i , (54) and stepθ x(ti, y; ut) = 1 si+1 h siy + h uri+ 1 2 (w) i . (55) So that applying the triangle inequality and lemma D.1 gives stepθ x(ti, x; ut) stepθ x(ti, y; ut) si si+1 x y + h si+1 2 (z) uri+ 1 si si+1 x y + h si+1 L u(ri+ 1 2 ) z w . (56) We apply the triangle inequality and the lemma D.1 again to z w . That is, z w = six + h 2 uri(six) siy + h 2 uri (siy) 2 uri(six) uri (siy) 2 L u(ri)si x y 2 L u(ri) x y . Substitute back in equation 56 gives si si+1 x y + h si+1 L u(ri+ 1 2 ) z w si si+1 1 + h L u(ri+ 1 2 L u(ri) x y . Published as a conference paper at ICLR 2024 E DERIVATION OF PARAMETRIC SOLVER STEPθ (Appendix to Section 2.2.) This section presents a derivation of n-step parametric solver stepθ(t, x ; ut) = stepθ t(t, x ; ut), stepθ x(t, x ; ut) (57) for scale-time transformation (equation 15) with two options for a base solver: (i) RK1 method (Euler) as the base solver; and (ii) RK2 method (Midpoint). We do so by following equation 1012. We begin with RK1 and derive equation 17. Given (ti, xi), equation 10 for the scale time transformation is, xi = sixi. (58) Then according to equation 11, xi+1 = stepx(ri, xi, uri) (59) = xi + h uri( xi) (60) = xi + h si si xi + tisiuti = sixi + h sixi + tisiuti(xi) , (62) where in the second equality we apply an RK1 step (equation 4), in the third equality we substitute uri using equation 16, and in the fourth equality we substitute xi as in equation 58. According to RK1 step (equation 4) we also have ri+1 = ri + h. Finally, equation 12 gives, stepθ t(ti, xi ; ut) = ti+1 (63) stepθ x(ti, xi ; ut) = si + h si si+1 xi + h si+1 tisiuti(xi), (64) as in equation 17. Regarding the second case, equation 11 for the RK2 method (equation 5) is, xi+1 = stepx(ri, xi, uri) (65) = xi + h uri+ 1 2 uri( xi) (67) is the RK1 step from (ri, xi) with step size h/2. Now substituting xi as defined equation 58 and ur as defined in equation 16 we get xi+1 = sixi + h 2 si tiuti(xi). (69) Lastly, according to equation 12 we have stepθ x(ti, x ; ut) = si si+1 xi + h si+1 as in equation 19 where zi = xi+ 1 Published as a conference paper at ICLR 2024 F IMPLEMENTATION DETAILS (Appendix to Section 2.3.) This section presents further implementation details, complementing the main text. Our parametric family of solvers stepθ is defined via a base solver step and a transformation (tr, φr) as defined in equation 12. We consider the RK2 (Midpoint, equation 5) method as the base solver with n steps and (tr, φr) the scale-time transformation (equation 15). That is, φr(x) = srx, where s : [0, 1] R>0, as in equation 14, which is our primary use case. Parameterization of ti. Remember that tr is a strictly monotonic, differentiable, increasing function t : [0, 1] [0, 1]. Hence, ti must satisfy the constraints as in equation 21, i.e., 0 = t0 < t 1 2 < < tn = 1 (71) 2 , . . . , tn 1, tn 1 2 > 0. (72) To satisfy these constrains, we model ti and ti via Pi j=0 |θt j| Pn k=0 |θt k|, ti = |θ t i|, (73) where θt i and θ t i, i = 0, 1 2, ..., n are free learnable parameters. Parameterization of si. 
Since sr is a strictly positive, differentiable function satisfying a boundary condition at r = 0, the sequence si should satisfy the constraints as in equation 21, i.e., 2 , s1, . . . , sn > 0 , s0 = 1, (74) and si are unconstrained. Similar to the above, we model si and si by si = 0 i = 0 exp θs i otherwise , si = θ s i , (75) where θs i and θ s i , i = 0, 1 2, ..., n are free learnable parameters. Bespoke training. The pseudo-code for training a Bespoke solver is provided in Algorithm 3. Here we add some more details on different steps of the training algorithm. We initialize the parameters θ such that the scale-transformation is the Identity transformation. That is, for every i = 0, 1 n, ti = 1, (76) si = 1, si = 0. (77) Explicitly, in terms of the learnable parameters, for every i = 0, 1 θt i = 1, θ t i = 1 (78) θs i = 0, θ s i = 0. (79) To compute the GT path x(t), we solve the ODE in equation 1 with the pre-trained model ut and DOPRI5 method, then use linear interpolation to extract x(ti), i = 0, 1, ..., n (Chen, 2018). Then, apply xaux i (t) (equation 27) to correctly handle the gradients w.r.t. to θt i. See Table 3 for number of trajectories used during training. To compute the loss LRMSE-B (equation 26) we compute xi+1 = stepθ x (xaux i (ti), ti; ut) with equations 19,20, and compute Mi via lemma D.3 with Lτ = 1. Finally, we use Adam optimizer Kingma & Ba (2017) with a learning rate of 2e 3. Efficient sampling. When sampling using a Bespoke solver (Algorithm 2) each step involves applying φ 1 ri and φri consecutively. In case we use scale transformation, equation 14 (as is done in all examples in this paper), this does not introduce any difficulty, however if a more compute intensive φ is used the following sampling pseudo-code (Algorithm 4) provides an equivalent sampling while avoiding this unnecessary step. Published as a conference paper at ICLR 2024 CIFAR10 Image Net-64 Image Net-128 AFHQ 256 Total number of trajectories 72k 48k 48k 4k Batch size 12 8 8 1 Number of iterations 6k 6k 6k 4k Table 3: Hyper-parameters of Bespoke solvers training on CIFAR10/Image Net-64/Image Net-128/AFHQ 256. Algorithm 4 Bespoke sampling (efficient). Require: pre-trained ut, trained θ x0 p(x0) sample noise r0 0, x0 x0 initial conditions for i = 0, 1, . . . , n 1 do (ri+1, xi+1) step(ri, xi; uθ r) end for return φ 1 1 ( xn) G BESPOKE RK1 VERSUS RK2 Image Net 64: FM/v-CS Image Net 64: FM-OT 8 10 12 14 16 18 20 NFE 8 10 12 14 16 18 20 NFE 8 10 12 14 16 18 20 NFE 8 10 12 14 16 18 20 NFE Figure 9: Bespoke RK1, Bespoke RK2, RK1, and RK2 solvers on Image Net-64 models: RMSE vs. NFE (top row), and PSNR vs. NFE (bottom row). Published as a conference paper at ICLR 2024 CIFAR10: ϵ-VP CIFAR10: FM/v-CS CIFAR10: FM-OT 8 10 12 14 16 18 20 NFE 8 10 12 14 16 18 20 NFE 8 10 12 14 16 18 20 NFE 8 10 12 14 16 18 20 NFE 8 10 12 14 16 18 20 NFE 8 10 12 14 16 18 20 NFE Figure 10: Bespoke RK1, Bespoke RK2, RK1, and RK2 solvers on CIFAR10: RMSE vs. NFE (top row), and PSNR vs. NFE (bottom row). CIFAR10 NFE FID GT-FID/% %Time RK2-BES ϵ-VP ϵ-VP ϵ-VP ϵ-VP 4.26 3.31 2.84 2.75 2.54 / 168 130 112 108 1.4 1.5 1.5 1.4 RK2-BES FM/v-CS FM/v-CS FM/v-CS FM/v-CS 3.50 2.89 2.68 2.64 2.61 / 134 111 103 101 0.5 0.6 0.6 0.6 RK2-BES FM-OT FM-OT FM-OT FM-OT 3.13 2.73 2.60 2.59 2.57 / 122 106 101 101 0.5 0.6 0.6 0.6 Table 4: CIFAR10 Bespoke solvers. We report best FID vs. NFE, the ground truth FID (GT-FID) for the model and FID/GT-FID in %, and the fraction of GPU time (in %) required to train the bespoke solver w.r.t. 
training the original model. Published as a conference paper at ICLR 2024 CIFAR10: ϵ-VP CIFAR10: FM/v-CS CIFAR10: FM-OT 8 10 16 20 32 64 128 256 NFE 2.5 3.0 4.0 5.0 RK4 RK1 RK2 RK2-BES 8 10 16 20 32 64 128 256 NFE RK4 RK1 RK2 RK2-BES 8 10 16 20 32 64 128 256 NFE RK4 RK1 RK2 RK2-BES 8 10 12 14 16 18 20 NFE 8 10 12 14 16 18 20 NFE 8 10 12 14 16 18 20 NFE 8 10 12 14 16 18 20 NFE 8 10 12 14 16 18 20 NFE 8 10 12 14 16 18 20 NFE Figure 11: CIFAR10 sampling with Bespoke RK2 solvers vs. RK1,RK2,RK4: FID vs. NFE (top row), RMSE vs. NFE (middle row), and PSNR vs. NFE (bottom row). I IMAGENET-64/128 Image Net-64: ϵ-pred Image Net-64: FM/v-CS Image Net-64: FM-OT Image Net-128: FM-OT 0 1000 2000 3000 4000 5000 Iteration NFE=8 NFE=10 NFE=16 NFE=20 NFE=24 0 1000 2000 3000 4000 5000 Iteration NFE=8 NFE=10 NFE=16 NFE=20 NFE=24 0 1000 2000 3000 4000 5000 Iteration NFE=8 NFE=10 NFE=16 NFE=20 NFE=24 0 1000 2000 3000 4000 5000 Iteration NFE=8 NFE=10 NFE=16 NFE=20 NFE=24 Figure 12: Validation RMSE vs. training iterations of Bespoke RK2 solvers on Image Net-64, and Image Net128. Image Net-64: ϵ-pred Image Net-64: FM/v-CS Image Net-64: FM-OT Image Net-128: FM-OT 4 6 8 10 12 14 16 18 20 22 24 NFE 4 6 8 10 12 14 16 18 20 22 24 NFE 4 6 8 10 12 14 16 18 20 22 24 NFE 4 6 8 10 12 14 16 18 20 22 24 NFE RK4 RK1 RK2 RK2-BES Figure 13: Bespoke RK2, RK1, RK2, and RK4 solvers on Image Net-64, and Image Net-128; PSNR vs. NFE. Published as a conference paper at ICLR 2024 8 10 12 14 16 18 20 NFE RK4 RK1 RK2 RK2-BES 8 10 12 14 16 18 20 NFE Figure 14: Bespoke RK2, RK1, RK2, and RK4 solvers on AFHQ-256; PSNR vs. NFE (left), and RMSE vs. NFE (right). Published as a conference paper at ICLR 2024 K ABLATIONS K.1 LOSS ABLATION We consider here three losses for optimizing the Bespoke solvers: RMSE-Bound (the parallel loss we advocate in the paper), RMSE (optimizing directly equation 6), and a simplified version of the RMSE-Bound: sum of Local Truncation Errors (LTE-parallel). That is, LTE is defined as in equation 26 but taking M θ i = 1 for all i. Algorithm 5 provides the pseudo-codes for all three losses. We have run all three algorithms on the Image Net 64 dataset and compared their FID, and RMSE, where RMSE is computed w.r.t. GT samples (see Section 4 for details). For the non-parallel RMSE loss, we needed to use Activate Checkpointing to reduce memory consumption in order to be able to run this loss. Figure 15 shows the results. As can be seen in the graphs, RMSE loss, as expected, reaches lowest RMSE values per NFE, the second best is the RMSE-Bound loss and worst in terms of RMSE is the LTE. As for FID, RMSE performs worst, while for FM-OT model RMSE-Bound and LTE perform equivalently for NFE>10, and LTE has advantage as far as FID goes otherwise. Since our goal is to reduce RMSE and provide memory-scalable training algorithm we opted to use the memory efficient RMSE-Bound in the paper. Algorithm 5 Bespoke training (parallel). Require: pre-trained ut, number of steps n initialize θ Rp while not converged do x0 p(x0) sample noise x(t) solve ODE 1 GT path if RMSE loss then xθ n Bespoke sampling Alg. 
K.2 SCALE-TIME ABLATIONS

This section presents an ablation experiment on the effect of each component in the scale-time transformation. We train Bespoke-RK2 solvers with three choices of transformation: (i) time-only: train $t_r, \dot{t}_r$ and freeze $s_r \equiv 1$, $\dot{s}_r \equiv 0$; (ii) scale-only: freeze $t_r \equiv r$, $\dot{t}_r \equiv 1$ and train $s_r, \dot{s}_r$; and (iii) scale-time: train both $t_r, \dot{t}_r$ and $s_r, \dot{s}_r$ (a small illustrative sketch of such freezing is given after Figure 18 below). All experiments are performed on the ImageNet-64 FM-OT model. Figure 16 shows the FID and RMSE of the Bespoke-RK2 time-only/scale-only/scale-time solvers and the base RK2 solver. First, all three Bespoke-RK2 solvers improve upon the base RK2 solver. Second, the time component seems more significant, but the scale component improves FID for all NFEs and RMSE for NFE < 20. Third, in RMSE, the significance of the time component increases as NFE increases. In addition, Figure 17 shows the trained time-only (top) and scale-only (bottom) transformations. Interestingly, we see that even seemingly small changes (e.g., scale-only with NFE ∈ {16, 20}) can dramatically affect the FID.

Figure 15: Different Bespoke losses for ImageNet-64 (columns: ϵ-VP, FM/v-CS, FM-OT). We compare the RMSE-Bound loss (the parallel loss advocated in the paper), the direct RMSE loss, and the Local Truncation Error (LTE) loss, which is a slightly simplified version of the RMSE-Bound loss. All three variations are provided in Algorithm 5. In terms of RMSE, the direct RMSE loss is, as expected, the best but has a large memory footprint, while RMSE-Bound is the runner-up and is parallelizable. FID is not perfectly correlated with RMSE and shows a somewhat opposite trend (partially excluding the FM-OT model, where RMSE-Bound and LTE are almost equivalent in FID). We opted for the RMSE-Bound loss in the paper since it is memory-scalable and provides the best RMSE among the parallel loss options considered.

Figure 16: Bespoke ablation I: RK2, Bespoke RK2 with full scale-time optimization, time-only optimization (keeping $s_r \equiv 1$ fixed), and scale-only optimization (keeping $t_r = r$ fixed) on FM-OT ImageNet-64: FID vs. NFE (left), and RMSE vs. NFE (right). Note that most of the improvement is provided by time optimization, while scale improves FID for all NFEs and RMSE for NFE < 20.

K.3 TRANSFERRING BESPOKE SOLVERS

This section presents an ablation experiment demonstrating how trained Bespoke solvers generalize to different models. We train a Bespoke-RK2 solver on an ImageNet-64 FM-OT model and evaluate it on an ImageNet-128 FM-OT model. We compare its FID and RMSE vs. NFE against the base RK2 solver evaluated on ImageNet-128 FM-OT and a Bespoke-RK2 solver trained and evaluated on ImageNet-128 FM-OT. The results are shown in Figure 18.

Figure 17: Trained θ of the scale-time ablation: Bespoke-RK2 time-only optimization (top) and Bespoke-RK2 scale-only optimization (bottom), on ImageNet-64 FM-OT for NFE 8/10/16/20.

Figure 18: Bespoke ablation II: RK2 evaluated on the FM-OT ImageNet-128 model, Bespoke RK2 trained and evaluated on the FM-OT ImageNet-128 model, and Bespoke RK2 trained on FM-OT ImageNet-64 and evaluated on the FM-OT ImageNet-128 model (transferred): FID vs. NFE (left), and RMSE vs. NFE (right). Note that the transferred Bespoke solver is still inferior to the Bespoke solver but improves RMSE considerably compared to the RK2 baseline. In FID, the transferred solver improves over the baseline only for NFE = 16, 20.
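The time-only and scale-only variants above can be obtained by freezing one group of learnable parameters at its identity initialization (equations 76 to 79). The following sketch is our own illustration under assumed parameter names and grid layout (one value per grid point $0, \tfrac{1}{2}, 1, \ldots, n$); it is not the paper's implementation.

```python
# Minimal sketch (our own illustration) of the K.2 ablation: freeze either the
# scale parameters (time-only) or the time parameters (scale-only) so that the
# frozen component stays at its identity initialization.
import torch
import torch.nn as nn


class ScaleTimeParams(nn.Module):
    """Learnable θ of a scale-time transformation on an n-step grid."""

    def __init__(self, n, mode="scale-time"):
        super().__init__()
        m = 2 * n + 1  # grid points 0, 1/2, 1, ..., n (assumed layout)
        self.theta_t = nn.Parameter(torch.ones(m))        # time knots
        self.theta_t_dot = nn.Parameter(torch.ones(m))     # time derivatives
        self.theta_s = nn.Parameter(torch.zeros(m))        # log-scales
        self.theta_s_dot = nn.Parameter(torch.zeros(m))     # scale derivatives
        if mode == "time-only":     # keep s_r ≡ 1, ṡ_r ≡ 0 fixed
            self.theta_s.requires_grad_(False)
            self.theta_s_dot.requires_grad_(False)
        elif mode == "scale-only":  # keep t_r ≡ r, ṫ_r ≡ 1 fixed
            self.theta_t.requires_grad_(False)
            self.theta_t_dot.requires_grad_(False)
```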
K.4 DISTILLATION-TYPE PARAMETRIZATION

This section presents a distillation-type experiment in our framework. Our parametric family of solvers is defined by a composition of a scale-time transformation and the RK2 solver, resulting in $\text{step}^{\theta}$ as in equations 19 and 20, where the weights of the pre-trained model $u_t$ are frozen. A natural comparison to a distillation-like approach is to let the weights of $u_t = u^{\theta}_t$ change during optimization instead of using the scale-time transformation, that is,
$$\mathcal{L}_{\text{dis}}(\theta) = \mathbb{E}_{x_0 \sim p(x_0)} \sum_{i=0}^{n-1} \left\| x_{i+1}(t_{i+1}) - \text{step}_x(x_i, t_i; u^{\theta}_t) \right\|, \quad (80)$$
where $t_i = ih$ is fixed (an illustrative sketch of this objective is given after Figure 20 below). We perform this experiment on ImageNet-64 FM/v-CS. For a fair comparison, we use the same compute budget as used to train our Bespoke-RK2 solvers on these models: a total of 48K generated trajectories and 6k iterations, and we report the best FID. Figure 19 shows the RK2 base solver, the Bespoke-RK2 solver, and Distillation-RK2 on ImageNet-64 FM/v-CS: FID vs. NFE (left), and RMSE vs. NFE (right). Note that while distillation is able to improve upon the baseline solver (RK2), it does not match the performance of the Bespoke solver. Two potential explanations for why distillation is not as performant as Bespoke in this experiment are: first, the number of trajectories/compute budget we use to optimize Bespoke is not sufficient for effective distillation; and second, successful distillation methods require access to training data, and training only on generated data makes distillation less effective.

K.5 TRAINING STOPPING CRITERIA

In this ablation experiment, we qualitatively compare samples when changing the stopping criterion of the Bespoke solver training. In Figure 20, we compare samples from Bespoke-RK2 solvers where the training stopping criterion is best FID versus best RMSE on four models, ImageNet-64 ϵ-VP/FM/v-CS/FM-OT and ImageNet-128 FM-OT, and NFE ∈ {8, 10, 20}. In all of the cases the differences are practically indistinguishable.

Figure 19: Distillation-type experiment: RK2 solver, Bespoke-RK2 solver, and distillation using the same number of trajectories and compute as Bespoke, on ImageNet-64 FM/v-CS: FID vs. NFE (left), and RMSE vs. NFE (right).

Figure 20: Different stopping criteria experiment. We compare samples from Bespoke-RK2 solvers with training stopping criteria at best FID and best RMSE; ImageNet-64 (3 top rows) and ImageNet-128 (2 bottom rows).
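For concreteness, the distillation-type objective of equation 80 (Section K.4) can be sketched as follows. This is our own illustration, not the paper's code: we assume the RK2 step is the explicit midpoint rule and that the trainable network is available as a callable `u_theta(t, x)`; `x_gt` denotes the ground-truth path evaluated on the fixed uniform grid $t_i = ih$.

```python
# Minimal sketch (our own illustration) of the distillation-type objective in
# equation 80: the network weights are trained so that a plain RK2 step on a
# fixed uniform grid matches the ground-truth path.  The explicit midpoint
# rule is assumed for RK2; u_theta(t, x) accepts a float t and a tensor x.
import torch


def rk2_midpoint_step(u_theta, x, t, h):
    """One explicit-midpoint RK2 step of size h for dx/dt = u_theta(t, x)."""
    k = u_theta(t, x)
    return x + h * u_theta(t + 0.5 * h, x + 0.5 * h * k)


def distillation_loss(u_theta, x_gt, n):
    """Sum over i of ||x_gt[i+1] - step(x_gt[i], t_i)|| with t_i = i/n.

    `x_gt` has shape [n+1, ...] and holds the ground-truth path at the
    uniform grid points, obtained by accurately solving the pre-trained ODE.
    """
    h = 1.0 / n
    loss = x_gt.new_zeros(())
    for i in range(n):
        pred = rk2_midpoint_step(u_theta, x_gt[i], i * h, h)
        loss = loss + (x_gt[i + 1] - pred).flatten().norm()
    return loss
```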
L TRAINED BESPOKE SOLVERS

In this section, we present the trained Bespoke solvers by visualizing their respective parameters θ. Figures 21, 22, and 23 show the learned scale-time transformations of Bespoke-RK2 solvers trained on ImageNet-128, ImageNet-64, and CIFAR10 (resp.) for NFE ∈ {8, 10, 16, 20}. First, we note the significant differences between the learned scale-time transformations of ϵ-VP versus FM-OT on ImageNet-64 in Figure 22, top and bottom rows (resp.). Second, we note that the scale-time transformations trained on the same model type but on different datasets seem to have similarities to some degree but are still different from one another; see Figure 21 and the bottom rows of Figures 22 and 23, for FM-OT trained on ImageNet-128, ImageNet-64, and CIFAR10 (resp.). These two observations showcase the advantage of a custom-made solver for each model. We also tested the latter point empirically in the ablation experiment in Section K.3, where we applied a Bespoke-RK2 solver trained on an ImageNet-64 FM-OT model to an ImageNet-128 FM-OT model and observed a drop in performance compared to a Bespoke solver trained directly on the ImageNet-128 FM-OT model. In addition, we note the resemblance in the form of the scale-time transformations trained on the same model type and same dataset across different NFE (i.e., a row in Figures 21, 22, and 23). This phenomenon suggests there may be a well-defined scale-time transformation in the limit of NFE → ∞. Furthermore, for $t_i$ and $s_i$, even and odd parity points on the grid seem to converge to different curves, possibly due to their different roles in the RK2 solver.

Figure 21: Trained θ of Bespoke-RK2 solvers on ImageNet-128 FM-OT for NFE 8/10/16/20.

Figure 22: Trained θ of Bespoke-RK2 solvers on ImageNet-64 for NFE 8/10/16/20; ϵ-VP (top), FM/v-CS (middle), and FM-OT (bottom).

Figure 23: Trained θ of Bespoke-RK2 solvers on CIFAR10 for NFE 8/10/16/20; ϵ-pred (top), FM/v-CS (middle), and FM-OT (bottom).

M PRE-TRAINED MODELS

All our FM-OT and FM/v-CS models were trained with the Conditional Flow Matching (CFM) loss derived in Lipman et al. (2022),
$$\mathcal{L}_{\text{CFM}}(\theta) = \mathbb{E}_{t, p_0(x_0), q(x_1)} \left\| v_t(x_t; \theta) - (\dot{\sigma}_t x_0 + \dot{\alpha}_t x_1) \right\|^2, \quad (81)$$
where $t \sim \mathcal{U}([0, 1])$, $p_0(x_0) = \mathcal{N}(x_0 \,|\, 0, I)$, $q(x_1)$ is the data distribution, $v_t(x_t; \theta)$ is the network, $(\alpha_t, \sigma_t)$ is the scheduler as defined in equation 22, and $x_t = \sigma_t x_0 + \alpha_t x_1$.
For FM-OT the scheduler is
$$\alpha_t = t, \qquad \sigma_t = 1 - t, \quad (82)$$
and for FM/v-CS the scheduler is
$$\alpha_t = \sin\!\left(\tfrac{\pi}{2} t\right), \qquad \sigma_t = \cos\!\left(\tfrac{\pi}{2} t\right). \quad (83)$$
All our ϵ-VP models were trained with the noise prediction loss as derived in Ho et al. (2020) and Song et al. (2020b),
$$\mathcal{L}_{\text{noise}}(\theta) = \mathbb{E}_{t, p_0(x_0), q(x_1)} \left\| \epsilon_t(x_t; \theta) - x_0 \right\|^2, \quad (84)$$
where the VP scheduler is $\alpha_t = \xi_{1-t}$, $\sigma_t = \sqrt{1 - \xi^2_{1-t}}$, $\xi_s = e^{-\frac{1}{4}s^2(B-b) - \frac{1}{2}sb}$, with $B = 20$, $b = 0.1$. All models use the U-Net architecture as in Dhariwal & Nichol (2021), and the hyper-parameters are listed in Table 5. An illustrative code sketch of these training objectives is given after Table 5.

                              CIFAR10     CIFAR10           ImageNet-64              ImageNet-128   AFHQ-256
                              ϵ-VP        FM-OT; FM/v-CS    ϵ-VP; FM-OT; FM/v-CS     FM-OT          FM-OT
Channels                      128         128               196                      256            256
Depth                         4           4                 3                        2              2
Channels multiple             2,2,2       2,2,2             1,2,3,4                  1,1,2,3,4      1,1,2,2,4,4
Heads                         1           1                 -                        -              -
Heads Channels                -           -                 64                       64             64
Attention resolution          16          16                32,16,8                  32,16,8        64,32,16
Dropout                       0.1         0.3               1.0                      0.0            0.0
Effective Batch size          512         512               2048                     2048           256
GPUs                          8           8                 64                       64             64
Epochs                        2000        3000              1600                     1437           862
Iterations                    200k        300k              1M                       900k           50k
Learning Rate                 5e-4        1e-4              1e-4                     1e-4           1e-4
Learning Rate Scheduler       constant    constant          constant                 Poly Decay     Poly Decay
Warmup Steps                  -           -                 -                        5k             5k
P-Unconditional               -           -                 0.2                      0.2            0.2
Guidance weight               -           -                 0.20 (vp,cs), 0.15 (ot)  0.5            0.1
Total parameters count        55M         55M               296M                     421M           537M

Table 5: Pre-trained models hyper-parameters.
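To make the training objectives in this section concrete, here is a minimal sketch of the CFM loss (equation 81) with the FM-OT and FM/v-CS schedulers (equations 82 and 83). It is our own illustrative code with assumed names (`v_theta`, `scheduler`), not the released training code.

```python
# Minimal sketch (our own illustration) of the CFM loss in equation 81 with
# the FM-OT and FM/v-CS schedulers of equations 82-83.  `v_theta(t, x)` is an
# assumed callable for the network v_t(x; θ).
import math
import torch


def scheduler(t, kind="fm-ot"):
    """Return (alpha_t, sigma_t, dalpha_t, dsigma_t) for the chosen path."""
    if kind == "fm-ot":            # α_t = t, σ_t = 1 - t
        return t, 1 - t, torch.ones_like(t), -torch.ones_like(t)
    if kind == "fm/v-cs":          # α_t = sin(π t / 2), σ_t = cos(π t / 2)
        c = math.pi / 2
        a, s = torch.sin(c * t), torch.cos(c * t)
        return a, s, c * torch.cos(c * t), -c * torch.sin(c * t)
    raise ValueError(kind)


def cfm_loss(v_theta, x0, x1, kind="fm-ot"):
    """E || v_t(x_t; θ) - (σ̇_t x0 + α̇_t x1) ||^2 with x_t = σ_t x0 + α_t x1."""
    t = torch.rand(x0.shape[0], device=x0.device)          # t ~ U([0, 1])
    shape = (-1,) + (1,) * (x0.dim() - 1)                  # broadcast over x
    a, s, da, ds = (z.view(shape) for z in scheduler(t, kind))
    xt = s * x0 + a * x1
    target = ds * x0 + da * x1
    return ((v_theta(t, xt) - target) ** 2).mean()
```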
N MORE RESULTS

In this section we present more sampling results using RK2-Bespoke solvers, the RK2 baseline, and Ground Truth samples (with DOPRI5).

Figure 24: Comparison of FM-OT AFHQ-256 GT samples with RK2 and Bespoke-RK2 solvers.

Figure 25: Comparison of FM-OT ImageNet-128 GT samples with RK2 and Bespoke-RK2 solvers.

Figure 26: Comparison of FM-OT ImageNet-128 GT samples with RK2 and Bespoke-RK2 solvers.

Figure 27: Comparison of FM-OT ImageNet-64 GT samples with RK2 and Bespoke-RK2 solvers.

Figure 28: Comparison of FM/v-CS ImageNet-64 GT samples with RK2 and Bespoke-RK2 solvers.

Figure 29: Comparison of ϵ-pred ImageNet-64 GT samples with RK2 and Bespoke-RK2 solvers.

Figure 30: Comparison of ϵ-VP and FM/v-CS ImageNet-64 samples with DPM-2 and Bespoke-RK2 solvers.