# Second-Order Neural ODE Optimizer

Guan-Horng Liu, Tianrong Chen, Evangelos A. Theodorou
Georgia Institute of Technology, USA
{ghliu, tianrong.chen, evangelos.theodorou}@gatech.edu

**Abstract.** We propose a novel second-order optimization framework for training the emerging deep continuous-time models, specifically the Neural Ordinary Differential Equations (Neural ODEs). Since their training already involves expensive gradient computation by solving a backward ODE, deriving efficient second-order methods becomes highly nontrivial. Nevertheless, inspired by the recent Optimal Control (OC) interpretation of training deep networks, we show that a specific continuous-time OC methodology, called Differential Programming, can be adopted to derive backward ODEs for higher-order derivatives at the same O(1) memory cost. We further explore a low-rank representation of the second-order derivatives and show that it leads to efficient preconditioned updates with the aid of Kronecker-based factorization. The resulting method, named SNOpt, converges much faster than first-order baselines in wall-clock time, and the improvement remains consistent across various applications, e.g. image classification, generative flow, and time-series prediction. Our framework also enables direct architecture optimization, such as the integration time of Neural ODEs, with second-order feedback policies, strengthening the OC perspective as a principled tool for analyzing optimization in deep learning. Our code is available at https://github.com/ghliu/snopt.

1 Introduction

Figure 1: Our second-order method (SNOpt; green solid curves) achieves superior convergence compared to first-order methods (SGD, Adam) on various Neural-ODE applications. [Panels: training curves vs. wall-clock time (sec) for image classification (CIFAR10) and time-series prediction (ArtWR).]

Neural ODEs (Chen et al., 2018) have received tremendous attention over recent years. Inspired by taking the continuous limit of the discrete residual transformation, $x_{k+1} = x_k + \epsilon F(x_k, \theta)$, they propose to directly parameterize the vector field of an ODE as a deep neural network (DNN), i.e.

$$\frac{dx(t)}{dt} = F(t, x(t), \theta), \quad x(t_0) = x_{t_0}, \tag{1}$$

where $x(t) \in \mathbb{R}^m$ and $F(\cdot, \cdot, \theta)$ is a DNN parameterized by $\theta \in \mathbb{R}^n$. This provides a powerful paradigm connecting modern machine learning to classical differential equations (Weinan, 2017), and it has since achieved promising results on time series analysis (Rubanova et al., 2019; Kidger et al., 2020b), reversible generative flow (Grathwohl et al., 2018; Nguyen et al., 2019), image classification (Zhuang et al., 2020, 2021), and manifold learning (Lou et al., 2020; Mathieu & Nickel, 2020).

Due to the continuous-time representation, Neural ODEs feature a distinct optimization process (see Fig. 2) compared to their discrete-time counterparts, which also poses new challenges.

Figure 2: Neural ODEs feature a distinct training process: both forward and backward passes parameterize vector fields, so that any generic ODE solver (which can be non-differentiable) can query time derivatives, e.g. $\frac{dx(t)}{dt}$, to solve the ODEs (1, 5). In this work, we extend this machinery to second-order training. [Diagram: the forward ODE defined in (1) and the backward adjoint ODE defined in (5), each handled by an ODESolve call that queries the corresponding vector field.]
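Concretely, the forward pass amounts to a single ODESolve call. Below is a minimal sketch using the torchdiffeq package (the solver library adopted in Section 4); the toy vector field, dimensions, and tolerances are illustrative assumptions rather than our experimental configuration.

```python
import torch
import torch.nn as nn
from torchdiffeq import odeint  # generic black-box ODESolve

class VectorField(nn.Module):
    """A toy DNN vector field F(t, x(t), theta), as in Eq. (1)."""
    def __init__(self, dim=8, hidden=32):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(dim, hidden), nn.Tanh(),
                                 nn.Linear(hidden, dim))

    def forward(self, t, x):
        # A time-invariant field for simplicity; t can also be fed to the network.
        return self.net(x)

F = VectorField()
x_t0 = torch.randn(128, 8)       # batch of initial conditions x(t0)
t = torch.tensor([0.0, 1.0])     # integration interval [t0, t1]
# x(t1) = ODESolve(x(t0), t0, t1, F); odeint returns the state at each time in t.
x_t1 = odeint(F, x_t0, t, method='dopri5', rtol=1e-3, atol=1e-3)[-1]
```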
First, the forward pass of Neural ODEs involves solving (1) with a black-box ODE solver. Depending on how its numerical integration is set up, the propagation may be refined to arbitrarily small step sizes and become prohibitively expensive without any regularization (Ghosh et al., 2020; Finlay et al., 2020). On the other hand, to avoid back-propagating through the entire ODE solver, the gradients are typically obtained by solving a backward adjoint ODE using the Adjoint Sensitivity Method (ASM; Pontryagin et al. (1962)). While this can be achieved at a favorable O(1) memory cost, it further increases the runtime and can suffer from inaccurate integration (Gholami et al., 2019). For these reasons, Neural ODEs often take notoriously long to train, which limited their applications to relatively small or synthetic datasets (Massaroli et al., 2020) until very recently (Zhuang et al., 2021).

To improve the convergence rate of training, it is natural to consider higher-order optimization. While efficient second-order methods have been proposed for discrete models (Ba et al., 2016; George et al., 2018), it remains unclear how to extend these successes to Neural ODEs, given their distinct computation processes. Indeed, the limited discussions in this regard only note that one may repeat the backward adjoint process recursively to obtain higher-order derivatives (Chen et al., 2018). This is, unfortunately, impractical, as the recursion accumulates the aforementioned integration errors and scales the per-iteration runtime linearly. As such, second-order methods for Neural ODEs are seldom considered in practice, nor have they been rigorously explored from an optimization standpoint.

In this work, we show that efficient second-order optimization is in fact viable for Neural ODEs. Our method is inspired by the emerging Optimal Control perspective (Weinan et al., 2018; Liu & Theodorou, 2019), which treats the parameter $\theta$ as a control variable, so that the training process, i.e. optimizing $\theta$ w.r.t. some objective, can be interpreted as an Optimal Control Problem (OCP). Specifically, we show that a continuous-time OCP methodology, called Differential Programming, provides analytic second-order derivatives by solving a set of coupled matrix ODEs. Interestingly, these matrix ODEs can be augmented to the backward adjoint ODE and solved simultaneously. In other words, a single backward pass is sufficient to compute all derivatives, including the original ASM-based gradient, the newly derived second-order matrices, and even higher-order tensors. Further, these higher-order computations enjoy the same O(1) memory and a runtime comparable to first-order methods by adopting Kronecker factorization (Martens & Grosse, 2015). The resulting method, called SNOpt, admits superior convergence in wall-clock time (Fig. 1), and the improvement remains consistent across image classification, continuous normalizing flow, and time-series prediction.

Our OCP framework also facilitates progressive training of the network architecture. Specifically, we study an example of jointly optimizing the integration time of Neural ODEs, in analogy to the depth of discrete DNNs. While analytic gradients w.r.t. this architectural parameter have been derived under the ASM framework, they were often evaluated only on limited synthetic datasets (Massaroli et al., 2020).
In the context of OCP, however, free-horizon optimization is a well-studied problem for practical applications with a priori unknown terminal time (Sun et al., 2015; De Marchi & Gerdts, 2019). In this work, we show that these principles can be applied to Neural ODEs, yielding a novel second-order feedback policy that adapts the integration time throughout training. On CIFAR10, this further leads to a 20% runtime reduction without hindering test-time accuracy.

In summary, we present the following contributions.
- We propose a novel computational framework for computing higher-order derivatives of deep continuous-time models, with a rigorous analysis using continuous-time Optimal Control theory.
- We propose an efficient second-order method, SNOpt, that achieves superior convergence (in wall-clock time) over first-order methods in training Neural ODEs, while retaining constant memory complexity. These improvements remain consistent across various applications.
- To show that our framework also enables direct architecture optimization, we derive a second-order feedback policy for adapting the integration horizon and show that it further reduces the runtime.

2 Preliminaries

**Notation.** We use roman and italic type to represent a variable $\mathrm{x}(t)$ and its realization $x(t)$ given an ODE. ODESolve denotes a function call that solves an initial value problem given an initial condition, start and end integration times, and a vector field, i.e. $\mathrm{ODESolve}(x(t_0), t_0, t_1, F)$, where $\frac{d\mathrm{x}(t)}{dt} = F$.

**Forward and backward computations of Neural ODEs.** Given an initial condition $x(t_0)$ and integration interval $[t_0, t_1]$, Neural ODEs concern the following optimization over an objective $L$,

$$\min_\theta\; L(x(t_1)), \quad \text{where } x(t_1) = x(t_0) + \int_{t_0}^{t_1} F(t, x(t), \theta)\, dt \tag{2}$$

is the solution of the ODE (1) and can be obtained by calling a black-box ODE solver, i.e. $x(t_1) = \mathrm{ODESolve}(x(t_0), t_0, t_1, F)$. The use of ODESolve allows us to adopt higher-order numerical methods, e.g. adaptive Runge-Kutta (Press et al., 2007), which give more accurate integration compared with, e.g., the vanilla Euler discretization in residual-based discrete models.

To obtain the gradient $\frac{\partial L}{\partial \theta}$ of a Neural ODE, one may naively back-propagate through ODESolve. This, even if it could be made possible, leads to unsatisfactory memory complexity, since the computation graph can grow arbitrarily large for adaptive ODE solvers. Instead, Chen et al. (2018) proposed to apply the Adjoint Sensitivity Method (ASM), which states that the gradient can be obtained through the following integration,

$$\frac{\partial L}{\partial \theta} = -\int_{t_1}^{t_0} a(t)^{\mathsf T}\, \frac{\partial F(t, x(t), \theta)}{\partial \theta}\, dt, \tag{3}$$

where $a(t) \in \mathbb{R}^m$ is referred to as the adjoint state, whose dynamics obey a backward adjoint ODE,

$$\frac{da(t)}{dt} = -a(t)^{\mathsf T}\, \frac{\partial F(t, x(t), \theta)}{\partial x(t)}, \quad a(t_1) = \frac{\partial L}{\partial x(t_1)}. \tag{4}$$

Equations (3, 4) present two coupled ODEs that can be viewed as the continuous-time expression of back-propagation (LeCun et al., 1988). Algorithmically, they can be solved through another call of ODESolve (see Fig. 2) with an augmented dynamics $G$, i.e.

$$\begin{bmatrix} x(t_0) \\ a(t_0) \\ \frac{\partial L}{\partial \theta} \end{bmatrix} = \mathrm{ODESolve}\Bigg( \begin{bmatrix} x(t_1) \\ a(t_1) \\ 0 \end{bmatrix}, t_1, t_0, G \Bigg), \quad \text{where } G := \begin{bmatrix} F(t, x(t), \theta) \\ -a(t)^{\mathsf T}\, \partial F / \partial x \\ -a(t)^{\mathsf T}\, \partial F / \partial \theta \end{bmatrix} \tag{5}$$

augments the original dynamics $F$ in (1) with the adjoint ODEs (3, 4). Notice that this computation (5) depends only on $(x(t_1), a(t_1))$. This differs from naive back-propagation, which requires storing intermediate states along the entire computation graph of the forward ODESolve. While the latter requires $O(\tilde{T})$ memory,¹ the computation in (5) consumes only constant O(1) memory.

¹ $\tilde{T}$ is the number of adaptive steps used to solve (1), serving as an analogue of depth for Neural ODEs.
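For reference, this O(1)-memory backward pass is what torchdiffeq exposes as odeint_adjoint, which solves the augmented ODE (5) internally instead of back-propagating through the solver. A minimal sketch, reusing F, x_t0, and t from the earlier snippet and a placeholder objective:

```python
import torch
from torchdiffeq import odeint_adjoint

# odeint_adjoint expects F to be an nn.Module so it can collect theta for the
# a(t)^T dF/dtheta component of G in Eq. (5).
x_t1 = odeint_adjoint(F, x_t0, t, method='dopri5', rtol=1e-3, atol=1e-3)[-1]
loss = x_t1.pow(2).mean()   # placeholder objective L(x(t1))
loss.backward()             # triggers the backward ODESolve of Eq. (5),
                            # not a replay of the forward computation graph
```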
Chen et al. (2018) noted that if we further encapsulate (5) by $\nabla_\theta L = \mathrm{grad}(L, \theta)$, one may compute higher-order derivatives by recursively calling $\frac{\partial^n L}{\partial \theta^n} = \mathrm{grad}\big(\frac{\partial^{n-1} L}{\partial \theta^{n-1}}, \theta\big)$, starting from $n = 1$. This scales unfavorably due to its recursive dependence and accumulated integration errors. Indeed, Table 1 suggests that the errors of the second-order derivatives $\frac{\partial^2 L}{\partial \theta^2}$ obtained from this recursive adjoint procedure can be 2-6 orders of magnitude larger than those of the first-order adjoint $\frac{\partial L}{\partial \theta}$.

Table 1: Numerical errors between ground-truth and adjoint derivatives using different ODESolve on CIFAR10.

| | rk4 | implicit adams | dopri5 |
| --- | --- | --- | --- |
| $\partial L / \partial \theta$ | $7.63 \times 10^{-5}$ | $2.11 \times 10^{-3}$ | $3.44 \times 10^{-4}$ |
| $\partial^2 L / \partial \theta^2$ | $6.84 \times 10^{-3}$ | $2.50 \times 10^{-1}$ | $41.10$ |
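To make the recursion concrete: it amounts to differentiating the adjoint-based gradient once more, e.g. as a Hessian-vector product, and each extra differentiation triggers another backward ODE solve whose integration error compounds with the first. A sketch, assuming the first backward pass is built with a differentiable graph:

```python
import torch

def recursive_adjoint_hvp(loss, params, v):
    """Hessian-vector product (d^2 L / d theta^2) @ v via recursive differentiation.
    When `loss` comes from an adjoint-based ODE solve, the first grad() call
    solves one backward ODE, and the second differentiates through it, i.e.
    yet another ODE solve whose error compounds with the first (cf. Table 1)."""
    g = torch.autograd.grad(loss, params, create_graph=True)[0]  # 1st-order adjoint
    return torch.autograd.grad(g, params, grad_outputs=v)[0]     # 2nd-order recursion
```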
In the next section, we will present a novel optimization framework that computes these higher-order derivatives without any recursion (Section 3.1), and discuss how it can be implemented efficiently (Section 3.2).

3.1 Dynamics of Higher-order Derivatives using Continuous-time Optimal Control Theory

The OCP perspective is a recently emerging methodology for analyzing the optimization of discrete DNNs. Central to this interpretation is to treat the layer propagation of a DNN as discrete-time dynamics, so that the training process, i.e. finding an optimal parameter of a DNN, can be understood as an OCP, which searches for an optimal control subject to a dynamical constraint. This perspective has provided useful insights for characterizing the optimization process (Hu et al., 2019) and enhancing principled algorithmic design (Liu et al., 2021a). We leave a complete discussion to Appendix A.1.

Lifting this OCP perspective from discrete DNNs to Neural ODEs requires special treatment from continuous-time OCP theory (Todorov, 2016). Nevertheless, we highlight that training Neural ODEs and solving continuous-time OCPs are fundamentally intertwined, since these models, by construction, represent continuous-time dynamical systems. Indeed, the ASM used for deriving (3, 4) originates from the celebrated Pontryagin's principle (Pontryagin et al., 1962), which is an optimality condition for OCPs. Hence, OCP analysis is not only well-motivated but principled from an optimization standpoint.

We begin by transforming (2) into a form that more readily admits continuous-time OCP analysis:

$$\min_\theta\; \Phi(\mathrm{x}_{t_1}) + \int_{t_0}^{t_1} \ell(t, \mathrm{x}_t, \mathrm{u}_t)\, dt \quad \text{subject to} \quad \begin{cases} \dfrac{d\mathrm{x}_t}{dt} = F(t, \mathrm{x}_t, \mathrm{u}_t), & \mathrm{x}_{t_0} = x_{t_0}, \\[4pt] \dfrac{d\mathrm{u}_t}{dt} = 0, & \mathrm{u}_{t_0} = \theta, \end{cases} \tag{6}$$

where $\mathrm{x}(t) \equiv \mathrm{x}_t$, etc. It should be clear that (6) describes (2) without loss of generality by setting $(\Phi, \ell) := (L, 0)$. These functions are known as the terminal and intermediate costs in standard OCP. In training Neural ODEs, $\ell$ can be used to describe either weight decay, i.e. $\ell \propto \lVert \mathrm{u}_t \rVert$, or more complex regularization (Finlay et al., 2020). The time-invariant ODE imposed on $\mathrm{u}_t$ makes the ODE of $\mathrm{x}_t$ equivalent to (1). Problem (6) should be understood as a particular type of OCP that searches for an optimal initial condition $\theta$ of a time-invariant control $\mathrm{u}_t$. Despite seeming superfluous, this is a necessary transformation that enables rigorous OCP analysis of the original training process (2), and it has also appeared in other control-related analyses (Zhong et al., 2020; Chalvidal et al., 2021).

Next, define the accumulated loss from any time $t \in [t_0, t_1]$ to the integration end time $t_1$ as

$$Q(t, \mathrm{x}_t, \mathrm{u}_t) := \Phi(\mathrm{x}_{t_1}) + \int_t^{t_1} \ell(\tau, \mathrm{x}_\tau, \mathrm{u}_\tau)\, d\tau, \tag{7}$$

which is also known in OCP as the cost-to-go function. Recall that our goal is to compute higher-order derivatives w.r.t. the parameter $\theta$ of Neural ODEs. Under the new OCP representation (6), the first-order derivative $\frac{\partial L}{\partial \theta}$ is identical to $\frac{\partial Q(t_0, \mathrm{x}_{t_0}, \mathrm{u}_{t_0})}{\partial \mathrm{u}_{t_0}}$. This is because $Q(t_0, \mathrm{x}_{t_0}, \mathrm{u}_{t_0})$ accumulates all sources of loss over $[t_0, t_1]$ (hence it sufficiently describes $L$) and $\mathrm{u}_{t_0} = \theta$ by construction. Likewise, the second-order derivatives can be captured by the Hessian $\frac{\partial^2 Q(t_0, \mathrm{x}_{t_0}, \mathrm{u}_{t_0})}{\partial \mathrm{u}_{t_0}\, \partial \mathrm{u}_{t_0}} = \frac{\partial^2 L}{\partial \theta\, \partial \theta} \equiv L_{\theta\theta}$. In other words, we are only interested in obtaining the derivatives of $Q$ at the integration start time $t_0$. To obtain these derivatives, notice that we can rewrite (7) as

$$0 = \ell(t, \mathrm{x}_t, \mathrm{u}_t) + \frac{dQ(t, \mathrm{x}_t, \mathrm{u}_t)}{dt}, \quad Q(t_1, \mathrm{x}_{t_1}) = \Phi(\mathrm{x}_{t_1}), \tag{8}$$

since the definition of $Q$ implies that $Q(t, \mathrm{x}_t, \mathrm{u}_t) = \ell(t, \mathrm{x}_t, \mathrm{u}_t)\, dt + Q(t + dt, \mathrm{x}_{t+dt}, \mathrm{u}_{t+dt})$. We now state our main result, which provides a local characterization of (8) with a set of coupled ODEs expanded along a solution path. These ODEs can be used to obtain all second-order derivatives at $t_0$.

**Theorem 1** (Second-order Differential Programming). Consider a solution path $(\bar{x}_t, \bar{u}_t)$ that solves the ODEs in (6). Then the first and second-order derivatives of $Q(t, \mathrm{x}_t, \mathrm{u}_t)$, expanded locally around this solution path, obey the following backward ODEs:

$$-\frac{dQ_x}{dt} = \ell_x + F_x^{\mathsf T} Q_x, \qquad -\frac{dQ_u}{dt} = \ell_u + F_u^{\mathsf T} Q_x, \tag{9a}$$

$$-\frac{dQ_{xx}}{dt} = \ell_{xx} + F_x^{\mathsf T} Q_{xx} + Q_{xx} F_x, \qquad -\frac{dQ_{xu}}{dt} = \ell_{xu} + Q_{xx} F_u + F_x^{\mathsf T} Q_{xu}, \tag{9b}$$

$$-\frac{dQ_{uu}}{dt} = \ell_{uu} + F_u^{\mathsf T} Q_{xu} + Q_{ux} F_u, \qquad -\frac{dQ_{ux}}{dt} = \ell_{ux} + F_u^{\mathsf T} Q_{xx} + Q_{ux} F_x, \tag{9c}$$

where $F_x(t) \equiv \frac{\partial F}{\partial \mathrm{x}_t}\big|_{(\bar{x}_t, \bar{u}_t)}$, $Q_{xx}(t) \equiv \frac{\partial^2 Q}{\partial \mathrm{x}_t\, \partial \mathrm{x}_t}\big|_{(\bar{x}_t, \bar{u}_t)}$, etc. All terms in (9) are time-varying vector-valued or matrix-valued functions expanded at $(\bar{x}_t, \bar{u}_t)$. The terminal conditions are given by $Q_x(t_1) = \Phi_x$, $Q_{xx}(t_1) = \Phi_{xx}$, and $Q_u(t_1) = Q_{uu}(t_1) = Q_{ux}(t_1) = Q_{xu}(t_1) = 0$.

The proof (see Appendix A.2) relies on rewriting (8) with differential states, $\delta x_t := \mathrm{x}_t - \bar{x}_t$, which view the deviation from $\bar{x}_t$ as the optimizing variable (hence the name "Differential Programming"). It can be shown that $\delta x_t$ follows a linear ODE expanded along the solution path. Theorem 1 has several important implications. First, the ODEs in (9a) recover the original ASM computation (3, 4), as one can readily verify that $Q_x(t) \equiv a(t)$ follows the same backward ODE as (4), and the solution of the second ODE in (9a), $Q_u(t_0) = \int_{t_0}^{t_1} F_u^{\mathsf T} Q_x\, dt$, gives the exact gradient in (3). Meanwhile, solving the coupled matrix ODEs presented in (9b, 9c) yields the desired second-order matrix, $Q_{uu}(t_0) \equiv L_{\theta\theta}$, for preconditioning the update. Finally, one can derive the dynamics of other higher-order tensors using the same Differential Programming methodology by simply expanding (8) beyond the second order. We leave some discussion in this regard to Appendix A.2.

3.2 Efficient Second-order Preconditioned Update

Theorem 1 provides an attractive computational framework that does not require recursive computation (as mentioned in Section 2) to obtain higher-order derivatives. It suggests that we can obtain first and second-order derivatives all at once with a single function call of ODESolve:

$$[x_{t_0}, Q_x(t_0), Q_u(t_0), Q_{xx}(t_0), Q_{ux}(t_0), Q_{xu}(t_0), Q_{uu}(t_0)] = \mathrm{ODESolve}([x_{t_1}, \Phi_x, 0, \Phi_{xx}, 0, 0, 0],\, t_1, t_0, G), \tag{10}$$

where $G$ augments the original dynamics $F$ in (1) with all six ODEs presented in (9). While this OCP-theoretic backward pass (10) retains the same O(1) memory complexity as (5), the dimension of the new augmented state, which now carries second-order matrices, can grow to an unfavorable size that dramatically slows down the numerical integration. Hence, we must consider other representations of (9), if any, in order to proceed.
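To make the mechanics of (9) concrete before introducing that representation, the following toy sketch integrates the backward ODEs with explicit Euler for a linear field $F(t, x, u) = Ax + Bu$ (so $F_x = A$ and $F_u = B$ are constant) and $\ell := 0$; real models would instead evaluate the Jacobians along $(\bar{x}_t, \bar{u}_t)$ with automatic differentiation. All dimensions and the step count are illustrative.

```python
import torch

m, n, K = 4, 3, 1000                    # toy state/parameter dims, Euler steps
A, B = 0.1 * torch.randn(m, m), 0.1 * torch.randn(m, n)
dt = 1.0 / K

Qx,  Qu  = torch.randn(m), torch.zeros(n)        # terminal conditions (Theorem 1):
Qxx, Qxu = torch.eye(m),  torch.zeros(m, n)      # Qx(t1) = Phi_x, Qxx(t1) = Phi_xx,
Quu, Qux = torch.zeros(n, n), torch.zeros(n, m)  # all u-blocks start at zero

for _ in range(K):  # backward in time: -dQ/dt = RHS, so Q(t - dt) = Q(t) + RHS * dt
    Qx_, Qxx_, Qxu_, Qux_ = Qx, Qxx, Qxu, Qux
    Qx  = Qx  + (A.T @ Qx_) * dt                   # (9a)
    Qu  = Qu  + (B.T @ Qx_) * dt
    Qxx = Qxx + (A.T @ Qxx_ + Qxx_ @ A) * dt       # (9b)
    Qxu = Qxu + (Qxx_ @ B + A.T @ Qxu_) * dt
    Quu = Quu + (B.T @ Qxu_ + Qux_ @ B) * dt       # (9c)
    Qux = Qux + (B.T @ Qxx_ + Qux_ @ A) * dt

precondition = Quu   # Q_uu(t0), i.e. L_theta_theta for this toy problem
```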
In the following proposition, we present one such representation, which transforms (9) into a set of vector ODEs so that we can compute them much more efficiently.

**Proposition 2** (Low-rank representation of (9)). Suppose $\ell := 0$ in (6) and let $Q_{xx}(t_1) = \sum_{i=1}^{R} y_i \otimes y_i$ be a symmetric matrix of rank $R \ll n$, where $y_i \in \mathbb{R}^m$ and $\otimes$ is the Kronecker product. Then, for all $t \in [t_0, t_1]$, the second-order matrices appearing in (9b, 9c) can be decomposed into

$$Q_{xx}(t) = \sum_{i=1}^{R} q_i(t) \otimes q_i(t), \quad Q_{xu}(t) = \sum_{i=1}^{R} q_i(t) \otimes p_i(t), \quad Q_{uu}(t) = \sum_{i=1}^{R} p_i(t) \otimes p_i(t),$$

where the vectors $q_i(t) \in \mathbb{R}^m$ and $p_i(t) \in \mathbb{R}^n$ obey the following backward ODEs:

$$-\frac{dq_i(t)}{dt} = F_x(t)^{\mathsf T} q_i(t), \qquad -\frac{dp_i(t)}{dt} = F_u(t)^{\mathsf T} q_i(t), \tag{11}$$

with the terminal conditions given by $(q_i(t_1), p_i(t_1)) := (y_i, 0)$.

The proof is left to Appendix A.2. Proposition 2 gives a nontrivial conversion: it indicates that the coupled matrix ODEs presented in (9b, 9c) can be disentangled into a set of independent vector ODEs, each following its own dynamics (11). As the rank $R$ determines the number of these vector ODEs, this conversion is particularly useful when the second-order matrices exhibit low-rank structure. Fortunately, this is indeed the case for many Neural-ODE applications, which often propagate $\mathrm{x}_t$ in a latent space of higher dimension (Chen et al., 2018; Grathwohl et al., 2018; Kidger et al., 2020b). Based on Proposition 2, the second-order precondition matrix $L_{\theta\theta}$ is given by²

$$L_{\theta\theta} \equiv Q_{uu}(t_0) = \sum_{i=1}^{R} \Big( \int_{t_0}^{t_1} F_u^{\mathsf T} q_i\, dt \Big) \otimes \Big( \int_{t_0}^{t_1} F_u^{\mathsf T} q_i\, dt \Big), \tag{12}$$

where $q_i \equiv q_i(t)$ follows (11). Our final step is to facilitate efficient computation of (12) with Kronecker-based factorization, which underlies many popular second-order methods for discrete DNNs (Grosse & Martens, 2016; Martens et al., 2018).

² We drop the dependence on $t$ for brevity, yet all terms inside the integrals of (12, 13) are time-varying.

**Algorithm 1** SNOpt: Second-order Neural ODE Optimizer
1: **Input:** dataset $D$, parameterized vector field $F(\cdot, \cdot, \theta)$, integration time $[t_0, t_1]$, black-box ODE solver ODESolve, learning rate $\eta$, rank $R$, interval of the time grid $\Delta t$
2: **repeat**
3: Solve $x(t_1) = \mathrm{ODESolve}(x(t_0), t_0, t_1, F)$, where $x(t_0) \sim D$. ▷ Forward pass
4: Initialize $(A_n, B_n) := (0, 0)$ for each layer $n$ and set $q_i(t_1) := y_i$.
5: **for** $t$ in $\{t_1, t_1 - \Delta t, \ldots, t_0 + \Delta t, t_0\}$ **do**
6: Set $t' := t + \Delta t$, so that $[t, t']$ is the small integration step, then call
$[x(t), Q_x(t), Q_u(t), \{q_i(t)\}_{i=1}^{R}] = \mathrm{ODESolve}([x(t'), Q_x(t'), Q_u(t'), \{q_i(t')\}_{i=1}^{R}],\, t', t, \hat{G})$, ▷ Backward pass
where $\hat{G}$ augments the ODEs of the state (1) with the first and second-order derivatives (9a, 11).
7: Evaluate $z_n(t)$, $h_n(t)$, $F(t, x_t, \theta)$, then compute $A_n(t)$, $B_n(t)$ in (13).
8: Update $A_n \leftarrow A_n + A_n(t)\,\Delta t$ and $B_n \leftarrow B_n + B_n(t)\,\Delta t$.
9: **end for**
10: $\forall n$, apply $\theta_n \leftarrow \theta_n - \eta\,\mathrm{vec}(B_n^{-1}\, Q_{u_n}(t_0)\, A_n^{-\mathsf T})$. ▷ Second-order parameter update
11: **until** convergence

Figure 4: Our second-order method, SNOpt, solves a new backward ODE, i.e. the $\hat{G}$ appearing in line 6 of Alg. 1, which augments second-order derivatives, while simultaneously collecting the matrices $A_n(t_j)$ and $B_n(t_j)$ on a sampled time grid $\{t_j\}$ for computing the preconditioned update in (14). [Diagram: the backward vector field $\hat{G}$, defined with (1, 9a, 11), is queried by the ODE solver along the solution path from $(x(t_1), Q_x(t_1), Q_u(t_1), \{y_i\}_{i=1}^R)$ to $(x(t_0), Q_x(t_0), Q_u(t_0))$, with $A_n = \sum_j A_n(t_j)\,\Delta t$ and $B_n = \sum_j B_n(t_j)\,\Delta t$ accumulated on the grid $\{t_j\}$.]

Figure 3: The layer propagation inside the vector field $F$, where $f$ and $\sigma$ denote affine and nonlinear activation functions: $h_n(t) = f(z_n(t), u_n(t))$, $z_{n+1}(t) = \sigma(h_n(t))$.
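Note that each right-hand side in (11) is a vector-Jacobian product, so the rank-one factors can be propagated with standard reverse-mode autodiff, without ever materializing $F_x$ or $F_u$. A hedged sketch of this part of the backward vector field $\hat{G}$ in Alg. 1 (the function signature is illustrative):

```python
import torch

def lowrank_backward_field(t, x, qs, F, params):
    """Right-hand sides for the low-rank backward ODEs (11): for each rank-one
    factor, -dq_i/dt = F_x^T q_i and -dp_i/dt = F_u^T q_i, both obtained from a
    single vector-Jacobian product. Sketch only; params = tuple(F.parameters())."""
    with torch.enable_grad():
        x = x.detach().requires_grad_(True)
        f = F(t, x)
        dqs, dps = [], []
        for q in qs:  # R independent vector ODEs, one per factor y_i
            vjp = torch.autograd.grad(f, (x, *params), grad_outputs=q,
                                      retain_graph=True)
            dqs.append(vjp[0])    # F_x(t)^T q_i
            dps.append(vjp[1:])   # F_u(t)^T q_i, split into per-layer blocks
    return f, dqs, dps            # f drives the state ODE (1), solved jointly
```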
Recall that the vector field $F$ is represented by a DNN. Let $z_n(t)$, $h_n(t)$, and $u_n(t)$ denote the activation vector, pre-activation vector, and the parameter of layer $n$ when evaluating $\frac{dx}{dt}$ at time $t$ (see Fig. 3). Then the integration in (12) can be broken down layer by layer,

$$\int_{t_0}^{t_1} F_u^{\mathsf T} q_i\, dt = \Big[\, \cdots,\ \int_{t_0}^{t_1} F_{u_n}^{\mathsf T} q_i\, dt,\ \cdots \Big] = \Big[\, \cdots,\ \int_{t_0}^{t_1} z_n \otimes \big(F_{h_n}^{\mathsf T} q_i\big)\, dt,\ \cdots \Big],$$

where the second equality holds because $F_{u_n}^{\mathsf T} q_i = \big(\frac{\partial F}{\partial u_n}\big)^{\mathsf T} q_i = z_n \otimes (F_{h_n}^{\mathsf T} q_i)$. This is the essential step towards the Kronecker approximation of the layer-wise precondition matrix:

$$L_{\theta_n\theta_n} \equiv Q_{u_n u_n}(t_0) \approx \Big( \int_{t_0}^{t_1} \underbrace{z_n \otimes z_n}_{A_n(t)}\, dt \Big) \otimes \Big( \int_{t_0}^{t_1} \underbrace{\textstyle\sum_{i=1}^{R} (F_{h_n}^{\mathsf T} q_i) \otimes (F_{h_n}^{\mathsf T} q_i)}_{B_n(t)}\, dt \Big). \tag{13}$$

We discuss the approximation behind (13), and also the one behind (14), in Appendix A.2. Note that $A_n(t)$ and $B_n(t)$ are much smaller matrices than the ones in (9), and they can be efficiently computed with automatic differentiation packages (Paszke et al., 2017). Now, let $\{t_j\}$ be a time grid uniformly distributed over $[t_0, t_1]$, so that $A_n = \sum_j A_n(t_j)\,\Delta t$ and $B_n = \sum_j B_n(t_j)\,\Delta t$ approximate the integrals in (13). Our final preconditioned update law is then given by

$$\forall n, \quad L_{\theta_n\theta_n}^{-1} L_{\theta_n} \approx \mathrm{vec}\big( B_n^{-1}\, Q_{u_n}(t_0)\, A_n^{-\mathsf T} \big), \tag{14}$$

where $\mathrm{vec}(\cdot)$ denotes vectorization. Our second-order method, named SNOpt, is summarized in Alg. 1, with the backward computation (lines 4-9 of Alg. 1) illustrated in Fig. 4. In practice, we also adopt eigen-based amortization with Tikhonov regularization (George et al. (2018); see Alg. 2 in Appendix A.4), which stabilizes the updates over stochastic training.

**Remark.** The fact that Proposition 2 holds only for degenerate $\ell$ can be easily circumvented in practice. As $\ell$ typically represents weight decay, i.e. $\ell := \frac{1}{t_1 - t_0}\lVert\theta\rVert^2$, which is time-independent, it can be separated from the backward ODEs (9) and added after solving the backward integration, i.e. $Q_u(t_0) \leftarrow \gamma\theta + Q_u(t_0)$, $Q_{uu}(t_0) \leftarrow \gamma I + Q_{uu}(t_0)$, where $\gamma$ is the regularization factor. Finally, we find that using the scaled Gauss-Newton matrix, i.e. $Q_{xx}(t_1) \approx \frac{1}{t_1 - t_0}\,\Phi_x \otimes \Phi_x$, generally provides a good trade-off between performance and runtime complexity. As such, we adopt this approximation to Proposition 2 for all experiments.

3.3 Memory Complexity Analysis

Table 2: Memory complexity at different stages of our derivation, in terms of $\mathrm{x}_t \in \mathbb{R}^m$, $\theta \in \mathbb{R}^n$, and the rank $R$. Note that all methods are O(1) in terms of depth.

| | Theorem 1, Eqs. (9, 10) | Proposition 2, Eqs. (11, 12) | SNOpt (Alg. 1), Eqs. (13, 14) | first-order adjoint, Eqs. (3, 4) |
| --- | --- | --- | --- | --- |
| backward storage | $O((m+n)^2)$ | $O(Rm + Rn)$ | $O(Rm + 2n)$ | $O(m+n)$ |
| parameter update | $O(n^2)$ | $O(n^2)$ | $O(2n)$ | $O(n)$ |

Table 2 summarizes the memory complexity of the different computational methods that appeared along our derivation in Sections 3.1 and 3.2. While all methods retain O(1) memory, as with the first-order adjoint method, their complexity differs in terms of the state and parameter dimensions. Starting from our encouraging result in Theorem 1, which allows one to compute all derivatives with a single backward pass, we first exploit their low-rank representation in Proposition 2. This reduces the storage to $O(Rm + Rn)$ and paves the way toward adopting Kronecker factorization, which further facilitates efficient preconditioning. With all of these, our SNOpt is capable of performing efficient second-order updates while enjoying similar memory complexity (up to some constant) to first-order adjoint methods. Lastly, for image applications, where Neural ODEs often consist of convolution layers, we adopt convolution-based Kronecker factorization (Grosse & Martens, 2016; Gao et al., 2020), which effectively makes the complexity scale with the number of feature maps (i.e. the number of channels) rather than the full size of the feature maps.
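As a concrete reference for (14) and line 10 of Alg. 1: once $A_n$ and $B_n$ have been accumulated on the time grid, the layer-wise preconditioned direction reduces to two small linear solves. A sketch with Tikhonov damping (the damping value is an illustrative assumption):

```python
import torch

def precondition_layer(Q_un, A_n, B_n, gamma=1e-3):
    """Layer-wise preconditioned direction of Eq. (14):
    vec(B_n^{-1} Q_un(t0) A_n^{-T}), with Tikhonov damping gamma on both
    Kronecker factors. Q_un is the gradient of layer n reshaped to a matrix
    whose rows/columns match B_n and A_n respectively."""
    A_d = A_n + gamma * torch.eye(A_n.shape[0])
    B_d = B_n + gamma * torch.eye(B_n.shape[0])
    step = torch.linalg.solve(B_d, Q_un)        # B_n^{-1} Q_un(t0)
    step = torch.linalg.solve(A_d.T, step.T).T  # ... A_n^{-T}
    return step.reshape(-1)                     # vec(.) for the flat update
```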
3.4 Extension to Architecture Optimization

Figure 5: Training performance on CIFAR10 with Adam when using different $t_1$ (accuracy vs. relative train time), which motivates joint optimization of $t_1$. The experiment setup is left in Appendix A.4.

Let us discuss an intriguing extension of our OCP framework to optimizing the architecture of Neural ODEs, specifically the integration bound $t_1$. In practice, when problems contain no prior information on the integration, $[t_0, t_1]$ is typically set to some trivial values (usually $[0, 1]$) without further justification. However, these values can greatly affect both the performance and the runtime. Take CIFAR10 for instance (see Fig. 5): the required training time decreases linearly as we drop $t_1$ from 1, yet the accuracy remains mostly the same unless $t_1$ becomes too small. Similar results also appear on MNIST (see Fig. 12 in Appendix A.5). In other words, we may interpret the integration bound $t_1$ as an architectural parameter that should be jointly optimized during training.

This interpretation fits naturally into our OCP framework. Specifically, we can consider the following extension of $Q$, which introduces the terminal time $T$ as a new variable:

$$\tilde{Q}(t, \mathrm{x}_t, \mathrm{u}_t, T) := \tilde{\Phi}(T, \mathrm{x}(T)) + \int_t^T \ell(\tau, \mathrm{x}_\tau, \mathrm{u}_\tau)\, d\tau, \tag{15}$$

where $\tilde{\Phi}(T, \mathrm{x}(T))$ explicitly imposes a penalty for longer integration times, e.g. $\tilde{\Phi} := \Phi(\mathrm{x}(T)) + \frac{c}{2} T^2$. Following a procedure similar to the one presented in Section 3.1, we can transform (15) into its ODE form (as in (8)) and then characterize its local behavior (as in (9)) along a solution path $(\bar{x}_t, \bar{u}_t, \bar{T})$. After some tedious derivations, which are left to Appendix A.3, we arrive at the update rule below,

$$T \leftarrow T - \eta\, \delta T(\delta\theta), \quad \text{where } \delta T(\delta\theta) = \big[\tilde{Q}_{TT}(t_0)\big]^{-1} \big( \tilde{Q}_T(t_0) + \tilde{Q}_{Tu}(t_0)\, \delta\theta \big). \tag{16}$$

Similar to the discussion in Section 3.1, one should view $\tilde{Q}_T(t_0) \equiv \frac{\partial L}{\partial T}$ as the first-order derivative w.r.t. the terminal time $T$; likewise, $\tilde{Q}_{TT}(t_0) \equiv \frac{\partial^2 L}{\partial T\, \partial T}$, etc. Equation (16) is a second-order feedback policy that adjusts its updates based on the change of the parameter $\theta$. Intuitively, it moves in the descending direction of the preconditioned gradient (i.e. $\tilde{Q}_{TT}^{-1} \tilde{Q}_T$), while accounting for the fact that $\theta$ also progresses during training (via the feedback term $\tilde{Q}_{Tu}\, \delta\theta$). The latter is a distinct feature arising from the OCP principle. As we will show later, this update (16) exhibits distinct behavior with superior convergence compared to first-order baselines (Massaroli et al., 2020).
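As a sketch of (16): the update combines the scalar first and second-order sensitivities of the loss w.r.t. $T$ with a cross term against the parameter step actually taken at this iteration. All names and the learning rate below are illustrative:

```python
import torch

def update_t1(T, Q_T, Q_TT, Q_Tu, delta_theta, eta=0.1):
    """Second-order feedback update (16) for the integration bound t1 = T.
    Q_T and Q_TT are scalars; Q_Tu is a vector; delta_theta is the parameter
    step taken at this iteration, so Q_Tu @ delta_theta feeds the change of
    theta back into the step on T. Sketch under Eq. (16)'s conventions."""
    dT = (Q_T + Q_Tu @ delta_theta) / Q_TT
    return T - eta * dT
```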
4 Experiments

Figure 6: Hybrid model for time-series prediction: a Neural ODE propagates the hidden state between time-series observations, a GRU cell incorporates each observation, and a linear layer maps to the prediction.

Table 3: Sample sizes of the time-series datasets (input dimension, number of class labels, series length).

| SpoAD | ArtWR | CharT |
| --- | --- | --- |
| (27, 10, 93) | (19, 25, 144) | (7, 20, 187) |

**Datasets.** We select 9 datasets from 3 distinct applications where Neural ODEs have been applied, including image classification, time-series prediction, and continuous normalizing flow (CNF):
- MNIST, SVHN, CIFAR10: MNIST consists of 28×28 gray-scale images, while SVHN and CIFAR10 consist of 3×32×32 colour images. All 3 image datasets have 10 label classes.
- SpoAD, ArtWR, CharT: We consider the UEA time series archive (Bagnall et al., 2018). Spoken Arabic Digits (SpoAD) is a speech dataset, whereas Articulary Word Recognition (ArtWR) and Character Trajectories (CharT) are motion-related datasets. Table 3 details their sample sizes.
- Circle, Gas, Miniboone: Circle is a 2-dimensional synthetic dataset adopted from Chen et al. (2018). Gas and Miniboone are 8- and 43-dimensional tabular datasets commonly used in CNF (Grathwohl et al., 2018; Onken et al., 2020). All 3 datasets transform a multivariate Gaussian to the target distributions.

**Models.** The models for the image datasets and CNF resemble standard feedforward networks, except that they now contain Neural ODEs as continuous transformation layers. Specifically, the models for image classification consist of convolution-based feature extraction, followed by a Neural ODE and a linear mapping. Meanwhile, the CNF models are identical to the ones in Grathwohl et al. (2018), which consist of 1-5 Neural ODEs, depending on the size of the dataset. As for the time-series models, we adopt the hybrid models from Rubanova et al. (2019), which consist of a Neural ODE for hidden state propagation, a standard recurrent cell (e.g. GRU (Cho et al., 2014)) to incorporate incoming time-series observations, and a linear prediction layer. Figure 6 illustrates this process. We detail other configurations in Appendix A.4.

**ODE solver.** We use the standard Runge-Kutta 4(5) adaptive solver (dopri5; Dormand & Prince (1980)) implemented in the torchdiffeq package. The numerical tolerance is set to 1e-6 for CNF and 1e-3 for the rest. We fix the integration time to [0, 1] whenever it appears as a hyper-parameter (e.g. for the image and CNF datasets³); otherwise we adopt the problem-specific setup (e.g. for time series).

**Training setup.** We consider Adam and SGD (with momentum) as the first-order baselines, since they are the default training methods for most Neural-ODE applications. As for our second-order SNOpt, we set up the time grid $\{t_j\}$ such that it collects roughly 100 samples along the backward integration to estimate the precondition matrices (see Fig. 4). The hyper-parameters (e.g. learning rate) are tuned for each method on each dataset, and we detail the tuning process in Appendix A.4. We also employ practical acceleration techniques, including the semi-norm trick (Kidger et al., 2020a) for speeding up ODESolve, and the Jacobian-free estimator (FFJORD; Grathwohl et al. (2018)) for accelerating the CNF models. The batch size is set to 256, 512, and 1000 respectively for ArtWR, CharT, and Gas; the rest of the datasets use a batch size of 128. All experiments are conducted on a TITAN RTX.

³ Except for Circle, where we set $[t_0, t_1] := [0, 10]$ to match the original setup in Chen et al. (2018).
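Before turning to the results, here is a minimal sketch of the hybrid time-series model of Fig. 6 (after Rubanova et al. (2019)); the layer sizes and field architecture are illustrative and not our exact configuration:

```python
import torch
import torch.nn as nn
from torchdiffeq import odeint

class ODERNN(nn.Module):
    """Sketch of the Fig. 6 hybrid: a Neural ODE evolves the hidden state
    between observations, a GRU cell folds in each observation, and a
    linear head produces the class prediction."""
    def __init__(self, obs_dim, hid_dim, n_classes):
        super().__init__()
        self.field = nn.Sequential(nn.Linear(hid_dim, hid_dim), nn.Tanh(),
                                   nn.Linear(hid_dim, hid_dim))
        self.cell = nn.GRUCell(obs_dim, hid_dim)
        self.head = nn.Linear(hid_dim, n_classes)

    def forward(self, obs, times):            # obs: (T, batch, obs_dim)
        h = self.cell(obs[0])                 # initialize from first observation
        for k in range(len(times) - 1):
            h = odeint(lambda t, y: self.field(y), h, times[k:k + 2])[-1]
            h = self.cell(obs[k + 1], h)      # incorporate the next observation
        return self.head(h)
```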
4.1 Results

**Convergence and computational efficiency.** Figures 1 and 7 report the training curves of each method measured in wall-clock time. It is clear that our SNOpt admits a superior convergence rate compared to the first-order baselines, and in many cases it exceeds their performance by a large margin.

Figure 7: Training performance in wall-clock runtime, averaged over 3 trials. [Panels: training loss and accuracy vs. wall-clock time (sec) on SpoAD, CharT, and SVHN.] Our SNOpt achieves faster convergence than the first-order baselines. See Fig. 14 in Appendix A.5 for MNIST and Circle.

Figure 8: Relative runtime and memory of our SNOpt compared to Adam (denoted by the dashed black lines) on all 9 datasets, where Mn is shorthand for MNIST, and so on.

In Fig. 8, we report the computational efficiency of our SNOpt relative to Adam on each dataset, and leave the numerical values to Appendix A.4 (Tables 9 and 10). For the image and time-series datasets (i.e. Mn~CT), our SNOpt runs nearly as fast as the first-order methods. This is made possible by the rigorous OCP analysis in Section 3, where we showed that the second-order matrices can be constructed along the same backward integration used to compute the gradient; hence, only a minimal overhead is introduced. As for CNF, which propagates the probability density in addition to the vanilla state dynamics, our SNOpt is roughly 1.5 to 2.5 times slower, yet it still converges faster in overall wall-clock time (see Fig. 7). On the other hand, the use of second-order matrices increases the memory consumption of SNOpt by 10-40%, depending on the model and dataset. However, the actual increase in memory (less than 1 GB for all datasets; see Table 10) remains affordable on standard GPU machines. More importantly, our SNOpt retains the O(1) memory throughout training.

**Test-time performance and hyper-parameter sensitivity.** Table 4 reports the test-time performance, including the accuracies (%) for image and time-series classification and the negative log-likelihood (NLL) for CNF. On most datasets, our method achieves competitive results against the standard baselines.

Table 4: Test-time performance: accuracies for the image and time-series datasets; NLL for the CNF datasets.

| | MNIST | SVHN | CIFAR10 | SpoAD | ArtWR | CharT | Circle | Gas | Miniboone |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| Adam | 98.83 | 91.92 | 77.41 | 94.64 | 84.14 | 93.29 | 0.90 | -6.42 | 13.10 |
| SGD | 98.68 | 93.34 | 76.42 | 97.70 | 85.82 | 95.93 | 0.94 | -4.58 | 13.75 |
| SNOpt | 98.99 | 95.77 | 79.11 | 97.41 | 90.23 | 96.63 | 0.86 | -7.55 | 12.50 |

In practice, we also find that using the preconditioned updates greatly reduces the sensitivity to hyper-parameters (e.g. learning rate). This is demonstrated in Fig. 9, where we sample distinct learning rates from a proper interval for each method (shown with different color bars) and record their training results after convergence. Our method not only converges to higher accuracies with lower losses, but these values are also more concentrated on the plots. In other words, our method achieves better convergence in a more consistent manner across different hyper-parameters.

Figure 9: Sensitivity analysis, where each sample represents a training result using a different optimizer and learning rate (annotated by different symbols and colors). Our SNOpt achieves higher accuracies and is insensitive to hyper-parameter changes. Note that the x-axes are in log scale.

**Joint optimization of the integration bound $t_1$.** Table 5 and Fig. 10 report the performance of optimizing $t_1$, along with its convergence dynamics. Specifically, we compare our second-order feedback policy (16), derived in Section 3.4, to the first-order ASM baseline proposed in Massaroli et al. (2020). It is clear that our OCP-theoretic method leads to substantially faster convergence, and the optimized $t_1$ stably hovers around 0.5 without deviating (as the baseline does). This drops the training time by nearly 20% compared to vanilla training, where we fix $t_1$ at 1.0, yet without sacrificing the test-time accuracy. A similar experiment on MNIST (see Fig. 13 in Appendix A.5) shows a consistent result. We highlight these improvements as the benefit gained from bringing well-established OCP principles to these emerging deep continuous-time models.

Table 5: Performance of jointly optimizing the integration bound $t_1$ on CIFAR10.

| Method | Train time (%) w.r.t. $t_1 = 1.0$ | Accuracy (%) |
| --- | --- | --- |
| ASM baseline | 96 | 76.61 |
| SNOpt (ours) | 81 | 77.82 |

Figure 10: Dynamics of $t_1$ over CIFAR10 training using different methods ($t_1$ vs. train iteration).

**Comparison with the recursive adjoint.** Finally, Fig. 11 reports the comparison between our SNOpt and the recursive adjoint baseline (see Section 2 and Table 1). Our method outperforms this second-order baseline by a large margin in both runtime efficiency and test-time performance. Note that we omit the comparison on the CNF datasets, since the recursive adjoint simply fails to converge there.

Figure 11: Comparison between SNOpt and the second-order recursive adjoint on the image and time-series datasets. SNOpt is at least 2 times faster and improves the accuracies of the baseline by 5-15%.

**Remark (Implicit regularization).** In some cases (e.g. SVHN in Fig. 8), our method may run slightly faster than first-order methods. This is a distinct phenomenon arising exclusively from training these continuous-time models. Since their forward and backward passes involve solving parameterized ODEs (see Fig. 2), the computation graphs are parameter-dependent, hence adaptive throughout training. In this vein, we conjecture that the preconditioned updates in these cases may have guided the parameter to regions that are numerically stabler (hence faster) for integration.⁴ With this in mind, we report in Table 6 the values of the Jacobian, $\int \lVert \nabla_x F \rVert^2$, and kinetic, $\int \lVert F \rVert^2$, regularization (Finlay et al., 2020) in SVHN training. Interestingly, the parameter found by our SNOpt indeed has a substantially lower value (hence stronger regularization and better-conditioned ODE dynamics) compared to the one found by Adam. This provides a plausible explanation for the reduction in NFE when using our method, yet without hindering test-time performance (see Table 4).

Table 6: Measure of implicit regularization on SVHN.

| | # of function evaluations (NFE) | Regularization ($\int \lVert \nabla_x F \rVert^2 + \int \lVert F \rVert^2$) |
| --- | --- | --- |
| Adam | 42.1 | 323.88 |
| SNOpt | 32.6 | 199.1 |

⁴ In Appendix A.4, we provide some theoretical discussion (see Corollary 9) in this regard.
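For reference, the two quantities reported in Table 6 can be estimated pointwise along the solved trajectory; a sketch using a Hutchinson-style probe for the Jacobian term (the probe and naming are illustrative):

```python
import torch

def jacobian_kinetic_terms(F, t, x):
    """Pointwise estimates of the two regularizers of Finlay et al. (2020):
    ||dF/dx||_F^2 via a Hutchinson probe (E||eps^T dF/dx||^2 equals the squared
    Frobenius norm for eps ~ N(0, I)) and the kinetic term ||F||^2. In practice,
    both are integrated over t along the trajectory."""
    x = x.detach().requires_grad_(True)
    f = F(t, x)
    eps = torch.randn_like(f)                  # Hutchinson probe vector
    vjp = torch.autograd.grad(f, x, grad_outputs=eps)[0]
    return (vjp ** 2).sum(), (f ** 2).sum()
```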
5 Conclusion

We present an efficient higher-order optimization framework for training Neural ODEs. Our method, named SNOpt, differs from existing second-order methods in several aspects. While it leverages a factorization similar to that of Kronecker-based methods (Martens & Grosse, 2015), the two methodologies differ fundamentally in that we construct analytic ODE expressions for the higher-order derivatives (Theorem 1) and compute them through ODESolve. This retains the favorable O(1) memory, as opposed to their O(T), and also enables the flexible rank-based factorization in Proposition 2. Meanwhile, our method extends the recent trend of OCP-inspired methods (Li et al., 2017; Liu et al., 2021b) to deep continuous-time models, using a rather straightforward framework without imposing additional assumptions, such as Markovian structure or game transformations. To summarize, our work advances several methodologies for the emerging deep continuous-time models, achieving strong empirical results and opening up new opportunities for analyzing models such as Neural SDEs/PDEs.
Acknowledgments and Disclosure of Funding

The authors would like to thank Chia-Wen Kuo and Chen-Hsuan Lin for the meticulous proofreading, and Keuntaek Lee for providing additional computational resources. Guan-Horng Liu was supported by CPS NSF Award #1932068, and Tianrong Chen was supported by ARO Award #W911NF2010151.

References

Almubarak, H., Sadegh, N., and Taylor, D. G. Infinite horizon nonlinear quadratic cost regulator. In 2019 American Control Conference (ACC), pp. 5570-5575. IEEE, 2019.

Amari, S.-i. and Nagaoka, H. Methods of Information Geometry, volume 191. American Mathematical Society, 2000.

Ba, J., Grosse, R., and Martens, J. Distributed second-order optimization using Kronecker-factored approximations. 2016.

Bagnall, A., Dau, H. A., Lines, J., Flynn, M., Large, J., Bostrom, A., Southam, P., and Keogh, E. The UEA multivariate time series classification archive, 2018. arXiv preprint arXiv:1811.00075, 2018.

Botev, A., Ritter, H., and Barber, D. Practical Gauss-Newton optimisation for deep learning. In Proceedings of the 34th International Conference on Machine Learning, Volume 70, pp. 557-565. JMLR.org, 2017.

Chalvidal, M., Ricci, M., VanRullen, R., and Serre, T. Go with the flow: Adaptive control for neural ODEs. 2021.

Chen, T. Q., Rubanova, Y., Bettencourt, J., and Duvenaud, D. K. Neural ordinary differential equations. In Advances in Neural Information Processing Systems, pp. 6572-6583, 2018.

Cho, K., Van Merriënboer, B., Gulcehre, C., Bahdanau, D., Bougares, F., Schwenk, H., and Bengio, Y. Learning phrase representations using RNN encoder-decoder for statistical machine translation. arXiv preprint arXiv:1406.1078, 2014.

De Marchi, A. and Gerdts, M. Free finite horizon LQR: a bilevel perspective and its application to model predictive control. Automatica, 100:299-311, 2019.

Desjardins, G., Simonyan, K., Pascanu, R., and Kavukcuoglu, K. Natural neural networks. arXiv preprint arXiv:1507.00210, 2015.

Dormand, J. R. and Prince, P. J. A family of embedded Runge-Kutta formulae. Journal of Computational and Applied Mathematics, 6(1):19-26, 1980.

Finlay, C., Jacobsen, J.-H., Nurbekyan, L., and Oberman, A. How to train your neural ODE: the world of Jacobian and kinetic regularization. In International Conference on Machine Learning, pp. 3154-3164. PMLR, 2020.

Gao, K.-X., Liu, X.-L., Huang, Z.-H., Wang, M., Wang, Z., Xu, D., and Yu, F. A trace-restricted Kronecker-factored approximation to natural gradient. arXiv preprint arXiv:2011.10741, 2020.

George, T., Laurent, C., Bouthillier, X., Ballas, N., and Vincent, P. Fast approximate natural gradient descent in a Kronecker factored eigenbasis. In Advances in Neural Information Processing Systems, pp. 9550-9560, 2018.

Gholami, A., Keutzer, K., and Biros, G. ANODE: Unconditionally accurate memory-efficient gradients for neural ODEs. arXiv preprint arXiv:1902.10298, 2019.

Ghosh, A., Behl, H. S., Dupont, E., Torr, P. H., and Namboodiri, V. STEER: Simple temporal regularization for neural ODEs. arXiv preprint arXiv:2006.10711, 2020.

Grathwohl, W., Chen, R. T., Bettencourt, J., Sutskever, I., and Duvenaud, D. FFJORD: Free-form continuous dynamics for scalable reversible generative models. arXiv preprint arXiv:1810.01367, 2018.

Grosse, R. and Martens, J. A Kronecker-factored approximate Fisher matrix for convolution layers. In International Conference on Machine Learning, pp. 573-582, 2016.

Gupta, V., Koren, T., and Singer, Y. Shampoo: Preconditioned stochastic tensor optimization. In International Conference on Machine Learning, pp. 1842-1850. PMLR, 2018.
Hu, K., Kazeykina, A., and Ren, Z. Mean-field Langevin system, optimal control and deep neural networks. arXiv preprint arXiv:1909.07278, 2019.

Ioffe, S. and Szegedy, C. Batch normalization: Accelerating deep network training by reducing internal covariate shift. In International Conference on Machine Learning, pp. 448-456. PMLR, 2015.

Kelly, J., Bettencourt, J., Johnson, M. J., and Duvenaud, D. Learning differential equations that are easy to solve. arXiv preprint arXiv:2007.04504, 2020.

Keskar, N. S., Mudigere, D., Nocedal, J., Smelyanskiy, M., and Tang, P. T. P. On large-batch training for deep learning: Generalization gap and sharp minima. arXiv preprint arXiv:1609.04836, 2016.

Kidger, P., Chen, R. T., and Lyons, T. "Hey, that's not an ODE": Faster ODE adjoints with 12 lines of code. arXiv preprint arXiv:2009.09457, 2020a.

Kidger, P., Morrill, J., Foster, J., and Lyons, T. Neural controlled differential equations for irregular time series. arXiv preprint arXiv:2005.08926, 2020b.

Laurent, C., George, T., Bouthillier, X., Ballas, N., and Vincent, P. An evaluation of Fisher approximations beyond Kronecker factorization. 2018.

LeCun, Y., Touresky, D., Hinton, G., and Sejnowski, T. A theoretical framework for back-propagation. In Proceedings of the 1988 Connectionist Models Summer School, volume 1, pp. 21-28. CMU, Pittsburgh, PA: Morgan Kaufmann, 1988.

Li, Q., Chen, L., Tai, C., and Weinan, E. Maximum principle based algorithms for deep learning. The Journal of Machine Learning Research, 18(1):5998-6026, 2017.

Liu, G.-H. and Theodorou, E. A. Deep learning theory review: An optimal control and dynamical systems perspective. arXiv preprint arXiv:1908.10920, 2019.

Liu, G.-H., Chen, T., and Theodorou, E. A. DDPNOpt: Differential dynamic programming neural optimizer. In International Conference on Learning Representations, 2021a.

Liu, G.-H., Chen, T., and Theodorou, E. A. Dynamic game theoretic neural optimizer. In International Conference on Machine Learning, 2021b.

Lou, A., Lim, D., Katsman, I., Huang, L., Jiang, Q., Lim, S.-N., and De Sa, C. Neural manifold ordinary differential equations. arXiv preprint arXiv:2006.10254, 2020.

Ma, L., Montague, G., Ye, J., Yao, Z., Gholami, A., Keutzer, K., and Mahoney, M. W. Inefficiency of K-FAC for large batch size training. arXiv preprint arXiv:1903.06237, 2019.

Martens, J. New insights and perspectives on the natural gradient method. arXiv preprint arXiv:1412.1193, 2014.

Martens, J. and Grosse, R. Optimizing neural networks with Kronecker-factored approximate curvature. In International Conference on Machine Learning, pp. 2408-2417, 2015.

Martens, J., Ba, J., and Johnson, M. Kronecker-factored curvature approximations for recurrent neural networks. In International Conference on Learning Representations, 2018.

Massaroli, S., Poli, M., Park, J., Yamashita, A., and Asama, H. Dissecting neural ODEs. arXiv preprint arXiv:2002.08071, 2020.

Mathieu, E. and Nickel, M. Riemannian continuous normalizing flows. arXiv preprint arXiv:2006.10605, 2020.

Nguyen, T. M., Garg, A., Baraniuk, R. G., and Anandkumar, A. InfoCNF: An efficient conditional continuous normalizing flow with adaptive solvers. arXiv preprint arXiv:1912.03978, 2019.

Onken, D., Fung, S. W., Li, X., and Ruthotto, L. OT-Flow: Fast and accurate continuous normalizing flows via optimal transport. arXiv preprint arXiv:2006.00104, 2020.
Paszke, A., Gross, S., Chintala, S., Chanan, G., Yang, E., DeVito, Z., Lin, Z., Desmaison, A., Antiga, L., and Lerer, A. Automatic differentiation in PyTorch. 2017.

Pontryagin, L. S., Mishchenko, E., Boltyanskii, V., and Gamkrelidze, R. The Mathematical Theory of Optimal Processes. 1962.

Press, W. H., Teukolsky, S. A., Vetterling, W. T., and Flannery, B. P. Numerical Recipes 3rd Edition: The Art of Scientific Computing. Cambridge University Press, 2007.

Rubanova, Y., Chen, R. T., and Duvenaud, D. Latent ODEs for irregularly-sampled time series. arXiv preprint arXiv:1907.03907, 2019.

Santurkar, S., Tsipras, D., Ilyas, A., and Madry, A. How does batch normalization help optimization? arXiv preprint arXiv:1805.11604, 2018.

Schacke, K. On the Kronecker product. Master's thesis, University of Waterloo, 2004.

Sun, W., Theodorou, E., and Tsiotras, P. Model based reinforcement learning with final time horizon optimization. arXiv preprint arXiv:1509.01186, 2015.

Tassa, Y., Mansard, N., and Todorov, E. Control-limited differential dynamic programming. In 2014 IEEE International Conference on Robotics and Automation (ICRA), pp. 1168-1175. IEEE, 2014.

Theodorou, E., Tassa, Y., and Todorov, E. Stochastic differential dynamic programming. In Proceedings of the 2010 American Control Conference, pp. 1125-1132. IEEE, 2010.

Todorov, E. Optimal control theory. In Bayesian Brain: Probabilistic Approaches to Neural Coding, pp. 269-298, 2016.

Weinan, E. A proposal on machine learning via dynamical systems. Communications in Mathematics and Statistics, 5(1):1-11, 2017.

Weinan, E., Han, J., and Li, Q. A mean-field optimal control formulation of deep learning. arXiv preprint arXiv:1807.01083, 2018.

Wu, Y., Zhu, X., Wu, C., Wang, A., and Ge, R. Dissecting Hessian: Understanding common structure of Hessian in neural networks. arXiv preprint arXiv:2010.04261, 2020.

Zhang, G., Martens, J., and Grosse, R. Fast convergence of natural gradient descent for overparameterized neural networks. arXiv preprint arXiv:1905.10961, 2019.

Zhong, Y. D., Dey, B., and Chakraborty, A. Symplectic ODE-Net: Learning Hamiltonian dynamics with control. 2020.

Zhuang, J., Dvornek, N., Li, X., Tatikonda, S., Papademetris, X., and Duncan, J. Adaptive checkpoint adjoint method for gradient estimation in neural ODE. In International Conference on Machine Learning, pp. 11639-11649. PMLR, 2020.

Zhuang, J., Dvornek, N. C., Tatikonda, S., and Duncan, J. S. MALI: A memory efficient and reverse accurate integrator for neural ODEs. arXiv preprint arXiv:2102.04668, 2021.