# Continuous-Time Meta-Learning with Forward Mode Differentiation

Published as a conference paper at ICLR 2022

Tristan Deleu, David Kanaa, Leo Feng, Giancarlo Kerg, Yoshua Bengio¹,², Guillaume Lajoie², Pierre-Luc Bacon²
Mila, Université de Montréal

ABSTRACT

Drawing inspiration from gradient-based meta-learning methods with infinitely small gradient steps, we introduce Continuous-Time Meta-Learning (COMLN), a meta-learning algorithm where adaptation follows the dynamics of a gradient vector field. Specifically, representations of the inputs are meta-learned such that a task-specific linear classifier is obtained as a solution of an ordinary differential equation (ODE). Treating the learning process as an ODE offers the notable advantage that the length of the trajectory is now continuous, as opposed to a fixed and discrete number of gradient steps. As a consequence, we can optimize the amount of adaptation necessary to solve a new task using stochastic gradient descent, in addition to learning the initial conditions as is standard practice in gradient-based meta-learning. Importantly, in order to compute the exact meta-gradients required for the outer-loop updates, we devise an efficient algorithm based on forward mode differentiation, whose memory requirements do not scale with the length of the learning trajectory, thus allowing longer adaptation in constant memory. We provide analytical guarantees for the stability of COMLN, we show empirically its efficiency in terms of runtime and memory usage, and we illustrate its effectiveness on a range of few-shot image classification problems.

1 INTRODUCTION

Among the existing meta-learning algorithms, gradient-based methods as popularized by Model-Agnostic Meta-Learning (MAML; Finn et al., 2017) have received a lot of attention over the past few years. They formulate the problem of learning a new task as an inner optimization problem, typically based on a few steps of gradient descent. An outer meta-optimization problem is then responsible for updating the meta-parameters of this learning process, such as the initialization of the gradient descent procedure. However, since the updates at the outer level typically require backpropagating through the learning process, this class of methods has often been limited to only a few gradient steps of adaptation, due to memory constraints. Although solutions have been proposed to alleviate the memory requirements of these algorithms, including checkpointing (Baranchuk, 2019), using implicit differentiation (Rajeswaran et al., 2019), or reformulating the meta-learning objective (Flennerhag et al., 2018), they are generally either more computationally demanding, or only approximate the gradients of the meta-learning objective (Nichol et al., 2018; Flennerhag et al., 2020).

In this work, we propose a continuous-time formulation of gradient-based meta-learning, called Continuous-Time Meta-Learning (COMLN), where the adaptation is the solution of a differential equation (see Figure 1). Moving to continuous time allows us to devise a novel algorithm, based on forward mode differentiation, to efficiently compute the exact gradients for meta-optimization, no matter how long the adaptation to a new task might be.
We show that using forward mode differentiation leads to a stable algorithm, unlike the counterpart of backpropagation in continuous time, the adjoint method (frequently used in the Neural ODE literature; Chen et al., 2018), which tends to be unstable in conjunction with gradient vector fields. Moreover, since the length of the adaptation trajectory is a continuous quantity, as opposed to a discrete number of gradient steps fixed ahead of time, we can treat the amount of adaptation in COMLN as a meta-parameter on par with the initialization, which we can meta-optimize using stochastic gradient descent. We verify empirically that our method is both computationally and memory efficient, and we show that COMLN outperforms other standard meta-learning algorithms on few-shot image classification datasets.

Correspondence to: Tristan Deleu. ¹CIFAR Senior Fellow, ²CIFAR AI Chair. Code is available at: https://github.com/tristandeleu/jax-comln

Figure 1: Illustration of the adaptation process in (a) a gradient-based meta-learning algorithm, such as ANIL (Raghu et al., 2019), where the adapted parameters W_T are given after T steps of gradient descent, W_{t+1} = W_t - α∇L(W_t), and in (b) Continuous-Time Meta-Learning (COMLN), where the adapted parameters W(T) are the result of following the dynamics of a gradient vector field, dW/dt = -∇L(W(t)), up to time T.

2 BACKGROUND

In this work, we consider the problem of few-shot classification, that is, the problem of learning a classification model with only a small number of training examples. More precisely, for a classification task τ, we assume that we have access to a (small) training dataset D^train_τ = {(x_m, y_m)}_{m=1}^{M} to fit a model on task τ, and a distinct test dataset D^test_τ to evaluate how well this adapted model generalizes on that task. In the few-shot learning literature, it is standard to consider the problem of k-shot N-way classification, meaning that the model has to classify among N possible classes, and there are only k examples of each class in D^train_τ, so that overall the number of training examples is M = kN. We use the convention that the target labels y_m ∈ {0, 1}^N are one-hot vectors.

2.1 GRADIENT-BASED META-LEARNING

Gradient-based meta-learning methods aim to learn an initialization such that the model is able to adapt to a new task via gradient descent. Such methods are often cast as a bi-level optimization process: adapting the task-specific parameters θ in the inner loop, and training the (task-agnostic) meta-parameters Φ and initialization θ_0 in the outer loop. The meta-learning objective is:

$$\min_{\theta_0,\,\Phi}\ \mathbb{E}_{\tau}\Big[\mathcal{L}\big(\theta_T^{\tau}, \Phi;\, \mathcal{D}_{\tau}^{\text{test}}\big)\Big] \qquad (1)$$

$$\text{s.t.}\quad \theta_{t+1}^{\tau} = \theta_t^{\tau} - \alpha\, \nabla_{\theta}\mathcal{L}\big(\theta_t^{\tau}, \Phi;\, \mathcal{D}_{\tau}^{\text{train}}\big), \qquad \theta_0^{\tau} = \theta_0, \qquad \tau \sim p(\tau), \qquad (2)$$

where T is the number of inner-loop updates. For example, in the case of MAML (Finn et al., 2017), there is no additional meta-parameter other than the initialization (Φ = ∅); in ANIL (Raghu et al., 2019), θ are the parameters of the last layer, and Φ are the parameters of the shared embedding network; in CAVIA (Zintgraf et al., 2019), θ are referred to as context parameters. During meta-training, the model is trained over many tasks τ. The task-specific parameters θ are learned via gradient descent on D^train_τ. The meta-parameters are then updated by evaluating the error of the trained model on the test dataset D^test_τ. At meta-testing time, the meta-trained model is adapted on D^train_τ, i.e. by applying (2) with the learned meta-parameters θ_0 and Φ.
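To make the discrete-time inner loop in (2) concrete, the following is a minimal JAX sketch of ANIL-style adaptation, where only the last-layer weights W are updated on the embedded support set. The function and variable names (cross_entropy, adapt, inner_lr, num_steps) are illustrative assumptions, not taken from the paper's code.

```python
# A minimal sketch of the discrete inner loop in Eq. (2), in the ANIL setting where only
# the last-layer weights W are adapted on the embedded support set.
import jax
import jax.numpy as jnp

def cross_entropy(W, features, labels):
    # features: (M, d) embeddings phi_m, labels: (M, N) one-hot targets
    logits = features @ W.T                       # (M, N)
    log_p = jax.nn.log_softmax(logits, axis=-1)
    return -jnp.mean(jnp.sum(labels * log_p, axis=-1))

def adapt(W0, features, labels, inner_lr=0.01, num_steps=10):
    """T steps of gradient descent starting from the meta-learned initialization W0."""
    W = W0
    for _ in range(num_steps):
        W = W - inner_lr * jax.grad(cross_entropy)(W, features, labels)
    return W

def outer_loss(W0, train_feats, train_labels, test_feats, test_labels):
    W_T = adapt(W0, train_feats, train_labels)
    return cross_entropy(W_T, test_feats, test_labels)

# Meta-gradients via backpropagation through the unrolled loop; the memory cost of this
# approach grows with num_steps, which is the limitation addressed in this paper.
meta_grad_W0 = jax.grad(outer_loss)
```

Backpropagating through the unrolled loop is exactly what makes the memory cost grow with the number of inner steps, which motivates the continuous-time treatment developed in the rest of the paper.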
2.2 LOCAL SENSITIVITY ANALYSIS OF ORDINARY DIFFERENTIAL EQUATIONS

Consider the following (autonomous) Ordinary Differential Equation (ODE):

$$\frac{dz}{dt} = g\big(z(t); \theta\big), \qquad z(0) = z_0(\theta), \qquad (3)$$

where the dynamics g and the initial value z_0 may depend on some external parameters θ, and integration is carried out from 0 to some time T. Local sensitivity analysis is the study of how the solution of this dynamical system responds to local changes in θ; this effectively corresponds to calculating the derivative dz(t)/dθ. We present here two methods to compute this derivative, with a special focus on their memory efficiency.

Adjoint sensitivity method. Based on the adjoint state (Pontryagin, 2018), and with its roots in control theory (Lions & Magenes, 2012), the adjoint sensitivity method (Bryson & Ho, 1969; Chavent et al., 1974) provides an efficient approach for evaluating derivatives of L(z(T); θ), a function of the solution z(T) of the ODE in (3). This method, popularized lately by the literature on Neural ODEs (Chen et al., 2018), requires the integration of the adjoint equation

$$\frac{da}{dt} = -\,a(t)^{\top}\frac{\partial g\big(z(t); \theta\big)}{\partial z(t)}, \qquad a(T) = \frac{d\mathcal{L}\big(z(T); \theta\big)}{dz(T)}, \qquad (4)$$

backward in time. The adjoint sensitivity method can be viewed as a continuous-time counterpart to backpropagation, where the forward pass corresponds to integrating (3) forward in time from 0 to T, and the backward pass to integrating (4) backward in time from T to 0. One possible implementation, reminiscent of backpropagation through time (BPTT), is to store the intermediate values of z(t) during the forward pass, and reuse them to compute the adjoint state during the backward pass. While several sophisticated checkpointing schemes have been proposed (Serban & Hindmarsh, 2003; Gholami et al., 2019), with different compute/memory trade-offs, the memory requirements of this approach typically grow with T; this is similar to the memory limitations that standard gradient-based meta-learning methods suffer from as the number of gradient steps increases. An alternative is to augment the adjoint state a(t) with the original state z(t), and to solve this augmented dynamical system backward in time (Chen et al., 2018). This has the notable advantage that the memory requirements are now independent of T, since the values z(t) are no longer stored during the forward pass, but are recomputed on the fly during the backward pass.

Forward sensitivity method. While the adjoint method is related to reverse-mode automatic differentiation (backpropagation), the forward sensitivity method (Feehery et al., 1997; Leis & Kramer, 1988; Maly & Petzold, 1996; Caracotsios & Stewart, 1985) can be viewed as the continuous-time counterpart to forward (tangent-linear) mode differentiation (Griewank & Walther, 2008). This method is based on the fact that the derivative S(t) ≜ dz(t)/dθ is the solution of the so-called forward sensitivity equation

$$\frac{dS}{dt} = \frac{\partial g\big(z(t); \theta\big)}{\partial z(t)}\,S(t) + \frac{\partial g\big(z(t); \theta\big)}{\partial \theta}, \qquad S(0) = \frac{\partial z_0}{\partial \theta}. \qquad (5)$$

This equation can be found throughout the literature on optimal control and system identification (Betts, 2010; Biegler, 2010). Unlike the adjoint method, which requires an explicit forward and backward pass, the forward sensitivity method only requires the integration forward in time of the original ODE in (3), augmented by the sensitivity state S(t) with the dynamics above.
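As a concrete illustration of the forward sensitivity equation (5), the sketch below augments a toy ODE with its sensitivity state S(t), integrates both forward in time with the same solver used in the paper's code (jax.experimental.ode.odeint), and checks the result against the closed-form Jacobian. The dynamics, horizon, and all names are illustrative assumptions, not the paper's implementation.

```python
# Forward sensitivity on the toy ODE dz/dt = -theta * z: integrate (z(t), S(t)) jointly
# forward in time, where S(t) = dz(t)/dtheta follows Eq. (5).
import jax
import jax.numpy as jnp
from jax.experimental import ode

def g(z, theta):
    return -theta * z

def augmented_dynamics(state, t, theta):
    z, S = state
    dz = g(z, theta)
    # dS/dt = (dg/dz) S + dg/dtheta  (the forward sensitivity equation)
    dS = jax.jacfwd(g, argnums=0)(z, theta) @ S + jax.jacfwd(g, argnums=1)(z, theta)
    return dz, dS

z0 = jnp.array([1.0, 2.0])
theta = jnp.array([0.5, 1.5])
S0 = jnp.zeros((2, 2))            # S(0) = dz0/dtheta = 0, since z0 is constant here
ts = jnp.array([0.0, 3.0])

zs, Ss = ode.odeint(augmented_dynamics, (z0, S0), ts, theta)
S_T = Ss[-1]                       # dz(T)/dtheta obtained via forward sensitivity

# Closed form for this toy ODE: z(T) = z0 * exp(-theta * T)
T = ts[-1]
S_closed = jnp.diag(-T * z0 * jnp.exp(-theta * T))
print(jnp.allclose(S_T, S_closed, atol=1e-4))   # should agree up to solver tolerance
```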
The memory requirements of the forward sensitivity method do not scale with T either, but it now requires storing S(t), which may be very large; we will come back to this problem in Section 3.2. We simply note here that, in discrete time, this is the same issue afflicting forward-mode training of RNNs with real-time recurrent learning (RTRL; Williams & Zipser, 1989), or other meta-learning algorithms (Sutton, 1992; Franceschi et al., 2017; Xu et al., 2018).

Figure 2: Numerical instability of the adjoint method applied to the gradient vector field of a quadratic loss function. The trajectory in green starting at W(0) corresponds to the integration of the dynamical system in (8) forward in time up to T, and the trajectory in red starting at W(T) corresponds to its integration backward in time. Note that T was chosen so that W(T) does not reach the equilibrium/minimum of the loss W*.

3 CONTINUOUS-TIME ADAPTATION

In the limit of infinitely small steps, some optimization algorithms can be viewed as the solution trajectory of a differential equation. This point of view has often been taken to analyze their behavior (Platt & Barr, 1988; Wilson et al., 2016; Su et al., 2014; Orvieto & Lucchi, 2019). In fact, some optimization algorithms such as gradient descent with momentum have even been introduced initially from the perspective of dynamical systems (Polyak, 1964). As the simplest example, gradient descent with a constant step size α, in the limit α → 0⁺ (i.e. as α tends to 0 from positive values), corresponds to following the dynamics of an autonomous ODE called a gradient vector field:

$$z_{t+1} = z_t - \alpha \nabla f(z_t) \quad \xrightarrow{\;\alpha \to 0^{+}\;} \quad \frac{dz}{dt} = -\nabla f\big(z(t)\big), \qquad (6)$$

where the iterate z(t) is now a continuous function of time t. The solution of this dynamical system is uniquely defined by the choice of the initial condition z(0) = z_0.

3.1 CONTINUOUS-TIME META-LEARNING

In gradient-based meta-learning, the task-specific adaptation with gradient descent may also be replaced by a gradient vector field in the limit of infinitely small steps. Inspired by prior work in meta-learning (Raghu et al., 2019; Javed & White, 2019), we assume that an embedding network f_Φ with meta-parameters Φ is shared across tasks, and only the parameters W of a task-specific linear classifier are adapted, starting at some initialization W_0. Instead of being the result of a few steps of gradient descent, the final parameters W(T) now correspond to integrating an ODE similar to (6) up to a certain horizon T, with the initial conditions W(0) = W_0. We call this new meta-learning algorithm Continuous-Time Meta-Learning¹ (COMLN).

¹COMLN is pronounced chameleon.

Treating the learning algorithm as a continuous-time process has the notable advantage that the adapted parameter W(T) is now differentiable wrt. the time horizon T (Wiggins, 2003, Chap. 7), in addition to being differentiable wrt. the initial conditions W_0, which plays a central role in gradient-based meta-learning, as described in Section 2.1. This allows us to view T as a meta-parameter on par with Φ and W_0, and to effectively optimize the amount of adaptation using stochastic gradient descent (SGD). The meta-learning objective of COMLN can be written as

$$\min_{\Phi,\,W_0,\,T}\ \mathbb{E}_{\tau}\Big[\mathcal{L}\big(W_{\tau}(T);\, f_{\Phi}(\mathcal{D}_{\tau}^{\text{test}})\big)\Big] \qquad (7)$$

$$\text{s.t.}\quad \frac{dW_{\tau}}{dt} = -\nabla\mathcal{L}\big(W_{\tau}(t);\, f_{\Phi}(\mathcal{D}_{\tau}^{\text{train}})\big), \qquad W_{\tau}(0) = W_0, \qquad \tau \sim p(\tau), \qquad (8)$$

where f_Φ(D^train_τ) = {(f_Φ(x_m), y_m) | (x_m, y_m) ∈ D^train_τ} is the embedded training dataset, and f_Φ(D^test_τ) is defined similarly for D^test_τ.
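For concreteness, a minimal sketch of the adaptation ODE in (8) is shown below: the task-specific classifier W(T) is obtained by integrating the gradient vector field of the inner loss with an off-the-shelf solver. This direct version is only meant to illustrate (8); the shapes and names are illustrative, and the memory-efficient implementation actually used by COMLN integrates the smaller quantities introduced in Section 4 instead of W itself.

```python
# The adaptation ODE of Eq. (8): W(T) is the solution of dW/dt = -grad L(W; train set),
# starting from the meta-learned initialization W(0) = W0.
import jax
import jax.numpy as jnp
from jax.experimental import ode

def inner_loss(W, features, labels):
    log_p = jax.nn.log_softmax(features @ W.T, axis=-1)
    return -jnp.mean(jnp.sum(labels * log_p, axis=-1))

def adapt_continuous(W0, features, labels, T):
    def vector_field(W, t):
        return -jax.grad(inner_loss)(W, features, labels)
    Ws = ode.odeint(vector_field, W0, jnp.array([0.0, T]))
    return Ws[-1]                                 # W(T)

# Illustrative usage on a random embedded task (M = 25 examples, d = 64, N = 5 classes)
key = jax.random.PRNGKey(0)
features = jax.random.normal(key, (25, 64))
labels = jax.nn.one_hot(jnp.arange(25) % 5, 5)
W0 = jnp.zeros((5, 64))
W_T = adapt_continuous(W0, features, labels, T=10.0)
```

Because W(T) is a differentiable function of W0 and of the horizon T, both can in principle be meta-learned with SGD, which is the property exploited in the rest of the paper.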
In practice, adaptation is implemented using a numerical integration scheme based on an iterative discretization of the problem, such as Runge-Kutta methods. Although a complete discussion of numerical solvers is outside the scope of this paper, we recommend Butcher (2008) for a comprehensive overview of numerical methods for solving ODEs.

3.2 THE CHALLENGES OF OPTIMIZING THE META-LEARNING OBJECTIVE

In order to minimize the meta-learning objective of COMLN, it is common practice to use (stochastic) gradient methods; this requires computing its derivatives wrt. the meta-parameters, which we call meta-gradients. Our primary goal is to devise an algorithm whose memory requirements do not scale with the amount of adaptation T; this contrasts with standard gradient-based meta-learning methods that backpropagate through a sequence of gradient steps (similar to BPTT), where the intermediate parameters are stored during adaptation (i.e. θ^τ_t for all t in (2)). Since this objective involves the solution W(T) of an ODE, we can use either the adjoint method or the forward sensitivity method in order to compute the derivatives wrt. Φ and W_0 (see Section 2.2).

Although the adjoint method has proven to be an effective strategy for learning Neural ODEs, in practice computing the state W(t) backward in time is numerically unstable when applied to a gradient vector field like the one in (8), even for convex loss functions. Figure 2 shows an example where the trajectory of W(t) recomputed backward in time (in red) diverges significantly from the original trajectory (in green) on a quadratic loss function, even though the two should match exactly in theory, since they follow the same dynamics. Intuitively, recomputing W(t) backward in time for a gradient vector field requires doing gradient ascent on the loss function, which is prone to compounding numerical errors; this is closely related to the loss of entropy observed by Maclaurin et al. (2015). This divergence makes the backward trajectory of W(t) unreliable for recovering the adjoint state, ruling out the adjoint sensitivity method for computing the meta-gradients in COMLN.

The forward sensitivity method addresses this shortcoming by avoiding the backward pass altogether. However, it can also be particularly expensive here in terms of memory requirements, since the sensitivity state S(t) in Section 2.2 now corresponds to Jacobian matrices, such as dW(t)/dW_0. As the size d of the feature vectors returned by f_Φ may be very large, this Nd × Nd Jacobian matrix would be almost impossible to store in practice; for example, in our experiments it can be as large as d = 16,000 for a ResNet-12 backbone. In Section 4.1, we show how to apply forward sensitivity in a memory-efficient way, by leveraging the structure of the loss function. This is achieved by carefully decomposing the Jacobian matrices into smaller pieces that follow specific dynamics. We show in Appendix D that, unlike the adjoint method, this process is stable.
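The instability discussed in this section can be reproduced on a small toy example. The sketch below integrates the gradient flow of an ill-conditioned quadratic forward up to time T, then integrates the same dynamics backward from W(T), as the augmented adjoint approach would do to recover the trajectory. The quadratic and the horizon are arbitrary illustrative choices; the point is that the fast-decaying direction is annihilated during the forward pass and cannot be recovered by the backward pass.

```python
# Forward-then-backward integration of a gradient flow on a quadratic loss 0.5 * w^T A w:
# in exact arithmetic the backward pass would recover w(0), but numerically it amounts to
# gradient ascent and amplifies rounding errors (cf. Figure 2).
import jax.numpy as jnp
from jax.experimental import ode

A = jnp.diag(jnp.array([10.0, 0.1]))           # ill-conditioned quadratic (illustrative)

def forward_field(w, t):
    return -A @ w                               # gradient flow dw/dt = -grad L(w)

def backward_field(w, t):
    return A @ w                                # same dynamics, integrated backward in time

w0 = jnp.array([1.0, 1.0])
T = 20.0
wT = ode.odeint(forward_field, w0, jnp.array([0.0, T]))[-1]
w0_rec = ode.odeint(backward_field, wT, jnp.array([0.0, T]))[-1]

print("forward endpoint :", wT)       # partway towards the minimum w* = 0
print("recovered w(0)   :", w0_rec)   # far from the true w0 = (1, 1): the fast direction was
                                      # flushed to (near) zero going forward, and integrating
                                      # backward only amplifies whatever rounding error remains
```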
3.3 CONNECTION WITH ALMOST NO INNER-LOOP (ANIL)

Similarly to ANIL (Raghu et al., 2019), COMLN only adapts the parameters W of the last linear layer of the neural network. There is a deeper connection between the two algorithms though: while our description of the adaptation in COMLN (Eq. 8) is independent of the choice of the ODE solver used to find the solution W(T) in practice, if we choose an explicit Euler scheme (Euler, 1913), roughly speaking discretizing (6) from right to left, then the adaptation of COMLN becomes functionally equivalent to ANIL. This equivalence means that ANIL can also benefit from the memory-efficient algorithm for computing the meta-gradients described in Section 4, based on the forward sensitivity method: using the methods devised here for COMLN, we can effectively compute the meta-gradients of ANIL with a constant memory cost wrt. the number of gradient steps of adaptation, instead of relying on backpropagation (see also Section 4.2).

4 MEMORY-EFFICIENT META-GRADIENTS

For some fixed task τ and (x_m, y_m) ∈ D^train_τ, let φ_m = f_Φ(x_m) ∈ R^d be the embedding of the input x_m through the feature extractor f_Φ. Since we are confronted with a classification problem, the loss function of choice L(W) is typically the cross-entropy loss. Böhning (1992) showed that the gradient of the cross-entropy loss wrt. W can be written as

$$\nabla\mathcal{L}\big(W;\, f_{\Phi}(\mathcal{D}_{\tau}^{\text{train}})\big) = \frac{1}{M}\sum_{m=1}^{M}\,(p_m - y_m)\,\phi_m^{\top}, \qquad (9)$$

where p_m = softmax(Wφ_m) is the vector of probabilities returned by the neural network. The key observation here is that the gradient can be decomposed as a sum of M rank-one matrices, where the feature vectors φ_m are independent of W. Therefore, we can fully characterize the gradient of the cross-entropy loss with M vectors p_m - y_m ∈ R^N, as opposed to the full N × d matrix. This is particularly useful in the context of few-shot classification, where the number of training examples M is small, and typically significantly smaller than the embedding size d.

4.1 DECOMPOSITION OF THE META-GRADIENTS

We saw in Section 3.2 that the forward sensitivity method was the only stable option to compute the meta-gradients of COMLN. However, naively applying the forward sensitivity equation would involve quantities that typically scale with d², which can be too expensive in practice. Using the structure of (9), the Jacobian matrices appearing in the computation of the meta-gradients for COMLN can be decomposed in such a way that only small quantities depend on time.

Meta-gradients wrt. W_0. By the chain rule of derivatives, it is sufficient to compute the Jacobian matrix dW(T)/dW_0 in order to obtain the meta-gradient wrt. W_0. We show in Appendix B.2 that the sensitivity state dW(t)/dW_0 can be decomposed as

$$\frac{dW(t)}{dW_0} = I - \sum_{i=1}^{M}\sum_{j=1}^{M} B_t[i, j] \otimes \phi_i\phi_j^{\top}, \qquad (10)$$

where ⊗ is the Kronecker product, and each B_t[i, j] is an N × N matrix, solution of the following system of ODEs²:

$$\frac{dB_t[i, j]}{dt} = \mathbb{1}(i = j)\,A_i(t) - A_i(t)\sum_{m=1}^{M} (\phi_i^{\top}\phi_m)\, B_t[m, j], \qquad B_0[i, j] = 0, \qquad (11)$$

and A_i(t), defined in Appendix B.2, is also an N × N matrix that only depends on W(t) and φ_i. The main consequence of this decomposition is that we can simply integrate the augmented ODE in (W(t), B_t[i, j]) up to T to obtain the desired Jacobian matrix, along with the adapted parameters W(T). Furthermore, in contrast to naively applying the forward sensitivity method (see Section 3.2), the M² matrices B_t[i, j] are significantly smaller than the full Jacobian matrix. In fact, we show in Appendix C that we can compute the vector-Jacobian products required for the chain rule using only these smaller matrices, without ever having to explicitly construct the full Nd × Nd Jacobian matrix dW(t)/dW_0 from (10).

²Here we use the notation B_t[i, j] to make the dependence on t explicit without overloading the notation. A more precise notation would be B[i, j](t).
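As a quick numerical sanity check of the rank-one structure in (9), which all of the decompositions in this section rely on, the snippet below compares the sum of rank-one matrices (1/M) Σ_m (p_m - y_m) φ_m^T against the gradient returned by automatic differentiation. Dimensions and names are illustrative.

```python
# Eq. (9): the cross-entropy gradient wrt. W is fully characterized by the M vectors
# (p_m - y_m) of size N, together with the (W-independent) feature vectors phi_m.
import jax
import jax.numpy as jnp

M, N, d = 25, 5, 64
phi = jax.random.normal(jax.random.PRNGKey(0), (M, d))     # embedded inputs phi_m
y = jax.nn.one_hot(jnp.arange(M) % N, N)                    # one-hot targets y_m
W = jax.random.normal(jax.random.PRNGKey(1), (N, d))

def loss(W):
    log_p = jax.nn.log_softmax(phi @ W.T, axis=-1)
    return -jnp.mean(jnp.sum(y * log_p, axis=-1))

p = jax.nn.softmax(phi @ W.T, axis=-1)                      # (M, N) predicted probabilities
grad_rank_one = (p - y).T @ phi / M                         # sum of M rank-one terms, Eq. (9)

print(jnp.allclose(jax.grad(loss)(W), grad_rank_one, atol=1e-5))
```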
Meta-gradients wrt. Φ. To backpropagate the error through the embedding network f_Φ, we first need to compute the gradients of the outer-loss wrt. the feature vectors φ_m. Again, by the chain rule, we can get these gradients from the Jacobian matrices dW(T)/dφ_m. Similar to (10), we can show that these Jacobian matrices can be decomposed as

$$\frac{dW(t)}{d\phi_m} = -\,s_m(t)\otimes I \;-\; \sum_{i=1}^{M}\Big(B_t[i, m]\,W_0 + \sum_{j=1}^{M} z_t[i, j, m]\,\phi_j^{\top}\Big)\otimes \phi_i, \qquad (12)$$

where s_m(t) and z_t[i, j, m] are vectors of size N that follow their own dynamics; the exact form of this system of ODEs, as well as the proof of this decomposition, are given in Appendix B.3. Crucially, the only quantities that depend on time are small objects independent of the embedding size d. Following the same strategy as above, we can incorporate these vectors into the augmented ODE, and integrate it to get the necessary Jacobians. Once all the dW(T)/dφ_m are known, for all the training datapoints, we can apply standard backpropagation through f_Φ to obtain the meta-gradients wrt. Φ.

Meta-gradient wrt. T. One of the major novelties of COMLN is the capacity to meta-learn the amount of adaptation using stochastic gradient descent. To compute the meta-gradient wrt. the time horizon T, we can directly borrow the result derived by Chen et al. (2018) in the context of Neural ODEs, and apply it to the gradient vector field in (8) responsible for adaptation:

$$\frac{d\mathcal{L}\big(W(T);\, f_{\Phi}(\mathcal{D}_{\tau}^{\text{test}})\big)}{dT} = -\,\bigg\langle \frac{\partial \mathcal{L}\big(W(T);\, f_{\Phi}(\mathcal{D}_{\tau}^{\text{test}})\big)}{\partial W(T)},\ \frac{\partial \mathcal{L}\big(W(T);\, f_{\Phi}(\mathcal{D}_{\tau}^{\text{train}})\big)}{\partial W(T)} \bigg\rangle. \qquad (13)$$

The proof is available in Appendix B.4. Interestingly, we find that this involves the alignment between the vectors of partial derivatives of the inner-loss and the outer-loss at W(T), which has appeared in various contexts in the meta-learning and multi-task learning literature (Li et al., 2018; Rothfuss et al., 2019; Yu et al., 2020; Von Oswald et al., 2021).
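The meta-gradient wrt. T in (13) is cheap to compute once W(T) is known, since it only requires an inner product between two gradients. The sketch below evaluates (13) on a toy task and compares it against a centered finite-difference estimate of the outer loss as a function of T; the task, horizon, and all names are illustrative assumptions.

```python
# Checking Eq. (13): dL_test(W(T))/dT = -< grad L_test(W(T)), grad L_train(W(T)) >.
import jax
import jax.numpy as jnp
from jax.experimental import ode

jax.config.update("jax_enable_x64", True)   # double precision for a tight comparison

N, d, M = 5, 16, 25
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
phi_tr, phi_te = jax.random.normal(k1, (M, d)), jax.random.normal(k2, (M, d))
y_tr = jax.nn.one_hot(jnp.arange(M) % N, N)
y_te = jax.nn.one_hot(jnp.arange(M) % N, N)
W0 = 0.01 * jax.random.normal(k3, (N, d))

def loss(W, phi, y):
    log_p = jax.nn.log_softmax(phi @ W.T, axis=-1)
    return -jnp.mean(jnp.sum(y * log_p, axis=-1))

def W_at(T):
    field = lambda W, t: -jax.grad(loss)(W, phi_tr, y_tr)
    return ode.odeint(field, W0, jnp.array([0.0, T]))[-1]

T = 2.0
W_T = W_at(T)
grad_T = -jnp.vdot(jax.grad(loss)(W_T, phi_te, y_te),
                   jax.grad(loss)(W_T, phi_tr, y_tr))       # Eq. (13)

eps = 1e-3
fd = (loss(W_at(T + eps), phi_te, y_te) - loss(W_at(T - eps), phi_te, y_te)) / (2 * eps)
print(grad_T, fd)   # the two estimates should agree closely
```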
Table 1: Memory required to compute the meta-gradients for different algorithms. Exact: the method returns the exact meta-gradients. Full net.: the whole network is adapted, with a number of meta-parameters |θ|. The requirements for checkpointing are taken from Shaban et al. (2019). Note that typically M ≪ d in few-shot learning.

Model                              Exact   Full net.   Memory
MAML (Finn et al., 2017)           ✓       ✓           O(|θ| · T)
ANIL (Raghu et al., 2019)          ✓                   O(Nd · T)
Checkpointing (every √T steps)     ✓       ✓           O(|θ| · √T)
iMAML (Rajeswaran et al., 2019)            ✓           O(|θ|)
Forward sensitivity (naive)        ✓                   O(N²d² + MNd²)
COMLN                              ✓                   O(M²N² + M³N)

4.2 MEMORY EFFICIENCY

Although naively applying the forward sensitivity method would be memory intensive, we have shown in Section 4.1 that the Jacobians can be carefully decomposed into smaller pieces. It turns out that even the parameters W(t) can be expressed using the vectors s_m(t) from the decomposition in (12); see Appendix B.1 for details. As a consequence, to compute the adapted parameters W(T) as well as all the necessary meta-gradients, it is sufficient to integrate a dynamical system in (B_t[i, j], s_m(t), z_t[i, j, m]) (see Algorithms 1 & 2 in App. A.1), involving exclusively quantities that are independent of the embedding size d. Instead, the size of that system scales with M, the total number of training examples, which is typically much smaller than d for few-shot classification.

Table 1 shows a comparison of the memory cost for different algorithms. It is important to note that, contrary to other standard gradient-based meta-learning methods, the memory requirements of COMLN do not scale with the amount of adaptation T (i.e. the number of gradient steps in MAML & ANIL), while still returning the exact meta-gradients, unlike iMAML (Rajeswaran et al., 2019), which only returns an approximation of the meta-gradients. We verify this efficiency empirically, both in terms of memory and computation costs, in Section 5.2.

5 EXPERIMENTS

For our embedding network f_Φ, we consider two commonly used architectures in meta-learning: Conv-4, a convolutional neural network with 4 convolutional blocks, and ResNet-12, a 12-layer residual network (He et al., 2016). Note that, following Lee et al. (2019), ResNet-12 does not include a global pooling layer at the end of the network, leading to feature vectors with embedding dimension d = 16,000. Additional details about these architectures are given in Appendix E. To compute the adapted parameters and the meta-gradients in COMLN, we integrate the dynamical system described in Section 4.2 with a 4th-order Runge-Kutta method with a Dormand-Prince adaptive step size (Runge, 1895; Dormand & Prince, 1980); we will come back to the choice of this numerical solver in Section 5.2. Furthermore, to ensure that T > 0, we parametrize it with an exponential activation.

5.1 FEW-SHOT IMAGE CLASSIFICATION

We evaluate COMLN on two standard few-shot image classification benchmarks: the miniImageNet (Vinyals et al., 2016) and tieredImageNet (Ren et al., 2018) datasets, both derived from ILSVRC-2012 (Russakovsky et al., 2015). The process for creating tasks follows the standard procedure from the few-shot classification literature (Santoro et al., 2016), with distinct classes between the different splits. miniImageNet consists of 100 classes, split into 64 training classes, 16 validation classes, and 20 test classes. tieredImageNet consists of 608 classes grouped into 34 high-level categories from ILSVRC-2012, split into 20 training, 6 validation, and 8 testing categories, corresponding to 351/97/160 classes respectively; Ren et al. (2018) argue that separating data according to high-level categories results in a more difficult and more realistic regime.

Table 2 shows the average accuracies of COMLN compared to various meta-learning methods, gradient-based or not. For both backbones, COMLN decisively outperforms all other gradient-based meta-learning methods.
Table 2: Few-shot classification on miniImageNet & tieredImageNet. The average accuracy (%) on 1,000 held-out meta-test tasks is reported with its 95% confidence interval. The comparison includes both gradient-based and non-gradient-based methods; baseline results that we executed ourselves using the official implementations are detailed in the Reproducibility Statement.

Model                              Backbone    miniImageNet 1-shot   miniImageNet 5-shot   tieredImageNet 1-shot   tieredImageNet 5-shot
MAML (Finn et al., 2017)           Conv-4      48.70 ± 1.84          63.11 ± 0.92          51.67 ± 1.81            70.30 ± 1.75
ANIL (Raghu et al., 2019)          Conv-4      46.30 ± 0.40          61.00 ± 0.60          49.35 ± 0.26            65.82 ± 0.12
Meta-SGD (Li et al., 2017)         Conv-4      50.47 ± 1.87          64.03 ± 0.94          52.80 ± 0.44            62.35 ± 0.26
CAVIA (Zintgraf et al., 2019)      Conv-4      51.82 ± 0.65          65.85 ± 0.55          52.41 ± 2.64            67.55 ± 2.05
iMAML (Rajeswaran et al., 2019)    Conv-4      49.30 ± 1.88          59.77 ± 0.73          38.54 ± 1.37            60.24 ± 0.76
MetaOptNet-RR (Lee et al., 2019)   Conv-4      53.23 ± 0.59          69.51 ± 0.48          54.63 ± 0.67            72.11 ± 0.59
MetaOptNet-SVM (Lee et al., 2019)  Conv-4      52.87 ± 0.57          68.76 ± 0.48          54.71 ± 0.67            71.79 ± 0.59
COMLN (Ours)                       Conv-4      53.01 ± 0.62          70.54 ± 0.54          54.30 ± 0.69            71.35 ± 0.57
MAML (Finn et al., 2017)           ResNet-12   49.92 ± 0.65          63.93 ± 0.59          55.37 ± 0.74            72.93 ± 0.60
ANIL (Raghu et al., 2019)          ResNet-12   49.65 ± 0.65          59.51 ± 0.56          54.77 ± 0.76            69.28 ± 0.67
MetaOptNet-RR (Lee et al., 2019)   ResNet-12   61.41 ± 0.61          77.88 ± 0.46          65.36 ± 0.71            81.34 ± 0.52
MetaOptNet-SVM (Lee et al., 2019)  ResNet-12   62.64 ± 0.61          78.63 ± 0.46          65.99 ± 0.72            81.56 ± 0.53
COMLN (Ours)                       ResNet-12   59.26 ± 0.65          77.26 ± 0.49          62.93 ± 0.71            81.13 ± 0.53

Compared to methods that explicitly backpropagate through the learning process, such as MAML or ANIL, the performance gain shown by COMLN could be credited to the longer adaptation T it learns, as opposed to a small number of gradient steps (usually about 10); this is fully enabled by our memory-efficient method for computing meta-gradients, which no longer scales with the length of adaptation (see Section 4.2). We analyse the evolution of T during meta-training for these different settings in Appendix E.3. In almost all settings, COMLN even closes the gap with a strong non-gradient-based method like MetaOptNet; the remainder may be explained in part by the training choices made by Lee et al. (2019) (see Appendix E for details).

5.2 EMPIRICAL EFFICIENCY OF COMLN

In Section 4.2, we showed that our algorithm for computing the meta-gradients, based on forward differentiation, has a memory cost independent of the length of adaptation T. We verify this empirically in Figure 3, where we compare the memory required by COMLN and other methods to compute the meta-gradients on a single task, with a Conv-4 backbone (Figure 4 in Appendix E.2 shows similar results for ResNet-12). To ensure an aligned comparison between discrete and continuous time, we use a conversion corresponding to a learning rate α = 0.01 in (2); see Appendix E.2 for a justification. As expected, the memory cost increases for both MAML and ANIL as the number of gradient steps increases, while it remains constant for iMAML and COMLN. Interestingly, we observe that the cost of COMLN is equivalent to the cost of running ANIL for a small number of steps, showing that the additional cost of integrating the augmented ODE in Section 4.2 to compute the meta-gradients is minimal.

Increasing the length of adaptation also has an impact on the time it takes to compute the adapted parameters and the meta-gradients. Figure 3 (right) shows how the runtime increases with the amount of adaptation for different algorithms. We see that the efficiency of COMLN depends on the numerical solver used. When we use a simple explicit-Euler scheme, the time taken to compute the meta-gradients matches that of ANIL; this behavior empirically confirms our observation in Section 3.3.
When we use an adaptive numerical solver, such as Runge-Kutta (RK) with a Dormand-Prince step size, this computation can be significantly accelerated, thanks to the smaller number of function evaluations. In practice, we show in Appendix E.1 that the choice of the ODE solver has very minimal impact on accuracy.

Figure 3: Empirical efficiency of COMLN on a single 5-shot 5-way task, with a Conv-4 backbone; a secondary axis shows the adaptation horizon T corresponding to the number of gradient steps. (Left) Memory usage (in GB) for computing the meta-gradients as a function of the number of inner gradient steps. The extrapolated dashed lines correspond to the method reaching the memory capacity of a Tesla V100 GPU with 32GB of memory. (Right) Average time taken (in ms) to compute the exact meta-gradients. The extrapolated dashed lines correspond to the method taking over 2 seconds.

6 RELATED WORK

We are interested in meta-learning (Bengio et al., 1991; Schmidhuber, 1987; Thrun & Pratt, 2012), and in particular we focus on gradient-based meta-learning methods (Finn, 2018), where the learning rule is based on gradient descent. While in MAML (Finn et al., 2017) the whole network is updated during this process, follow-up works have shown that it is generally sufficient to share most parts of the neural network, and to only adapt a few layers (Raghu et al., 2019; Chen et al., 2020b; Tian et al., 2020). Even though this hypothesis has been challenged recently (Arnold & Sha, 2021), COMLN also updates only the last layer of a neural network, and therefore can be viewed as a continuous-time extension of ANIL (Raghu et al., 2019); see also Section 3.3. With its shared embedding network across tasks, COMLN is also connected to metric-based meta-learning methods (Vinyals et al., 2016; Snell et al., 2017; Sung et al., 2018; Bertinetto et al., 2018; Lee et al., 2019). Zhang et al. (2021) introduced a formulation where the adaptation of prototypes follows a gradient vector field, but finally opted for modeling it as a Neural ODE (Chen et al., 2018). Concurrent to our work, Li et al. (2021) also propose a formulation of adaptation based on a gradient vector field, and use the adjoint method to compute the meta-gradients, despite the challenges we identified in Section 3.2; Li et al. (2021) acknowledge these challenges, and limit their analysis to relatively small values of T (in comparison to the ones learned by COMLN), hence further from the equilibrium, to circumvent this issue altogether. Zhou et al. (2021) also use a gradient vector field to motivate a novel method with a closed-form adaptation; COMLN still explicitly updates the parameters following the gradient vector field, since there is no closed-form solution of (8). As mentioned in Section 3, treating optimization as a continuous-time process has been used to analyze the convergence of different optimization algorithms, including the meta-optimization of MAML (Xu et al., 2021), or to introduce new meta-optimizers based on different integration schemes (Im et al., 2019). Guo et al. (2021) also use meta-learning to learn new integration schemes for ODEs.
Although this is a growing literature at the intersection of meta-learning and dynamical systems, our work is the first algorithm that uses a gradient vector field for adaptation in meta-learning (see also Li et al. (2021)). Beyond the memory efficiency of our method, one of the main benefits of the continuous-time perspective is that COMLN is capable of learning when to stop the adaptation, as opposed to taking a number of gradient steps fixed ahead of time. However unlike Chen et al. (2020a), where the number of gradient steps are optimized (up to a maximal number) with variational methods, we incorporate the amount of adaptation as a (continuous) meta-parameter that can be learned using SGD. To compute the meta-gradients, which is known to be challenging for long sequences in gradient-based meta-learning, we use forward-mode differentiation as an alternative to backpropagation through the learning process, similar to prior work in meta-learning (Franceschi et al., 2017; Jiwoong Im et al., 2021) and hyperparameter optimization over long horizons (Micaelli & Storkey, 2021). This yields the exact meta-gradients in constant memory, without any assumption on the optimality of the inner optimization problem, which is necessary when using the normal equations (Bertinetto et al., 2018), or to apply implicit differentiation (Rajeswaran et al., 2019). Published as a conference paper at ICLR 2022 7 CONCLUSION AND FUTURE WORK In this paper, we have introduced Continuous-Time Meta-Learning (COMLN), a novel algorithm that treats the adaptation in meta-learning as a continuous-time process, by following the dynamics of a gradient vector field up to a certain time horizon T. One of the major novelties of treating adaptation in continuous time is that the amount of adaptation T is now a continuous quantity, that can be viewed as a meta-parameter and can be learned using SGD, alongside the initial conditions and the parameters of the embedding network. In order to learn these meta-parameters, we have also introduced a novel practical algorithm based on forward mode automatic differentiation, capable of efficiently computing the exact meta-gradients using an augmented dynamical system. We have verified empirically that this algorithm was able to compute the meta-gradients in constant memory, making it the first gradient-based meta-learning approach capable of computing the exact metagradients with long sequences of adaptation using gradient methods. In practice, we have shown that COMLN significantly outperforms other standard gradient-based meta-learning algorithms. In addition to having a single meta-parameter T that drives the adaptation of all possible tasks, the fact that the time horizon can be learned with SGD opens up new possibilities for gradient-based methods. For example, we could imagine treating T not as a shared meta-parameters, but as a task-specific parameter. This would allow the learning process to be more adaptive, possibly with different behaviors depending on the difficulty of the task. This is left as a future direction of research. ACKNOWLEDGEMENTS The authors are grateful to Samsung Electronics Co., Ldt., CIFAR, and IVADO for their funding and Calcul Québec and Compute Canada for providing us with the computing resources. 
Published as a conference paper at ICLR 2022 REPRODUCIBILITY STATEMENT We provide in Appendix A.1 a full description in pseudo-code of the meta-training procedure (Algorithm 1), along with the exact dynamics of the ODE (Algorithm 2) and the projection operations (Algorithms 3 & 4) to avoid explicitly building the Jacobian matrices to compute Jacobian-vector products (see Section 4.1). We also provide in Appendix A.2 a snippet of code in JAX (Bradbury et al., 2018) to compute the adapted parameters W (T), as well as all the necessary objects Bt[i, j], sm(t), zt[i, j, m] to compute all the meta-gradients (see Section 4.2). We also give in Code Snippet 2 the code to compute the meta-gradients wrt. the initialization W0 and the integration time T. Computing the meta-gradients wrt. Φ involves non-minimal dependencies on Haiku (Hennigan et al., 2020), and therefore is omitted here. The full code is available at https://github.com/tristandeleu/jax-comln. Data generation & hyperparameters We used the mini Image Net and tiered Image Net datasets provided by Lee et al. (2019) in order to create the 1-shot 5-way and 5-shot 5-way tasks for both datasets. During evaluation, a fixed set of 1,000 tasks was sampled for each setting; this means that both architectures for COMLN have been evaluated using exactly the same data, to ensure direct comparison across backbones. A full description of all the hyperparameters used in COMLN is given in Appendix E. Reproducibility of baseline results To the best of our ability, we have tried to report baseline results from existing work, to limit as much as possible the bias induced by running our own baseline experiments. The references of those works are given in Table 3. We still had to run CAVIA and i MAML on the remaining settings, since these results have not been reported in the literature. For both methods, we used the data generation described above. CAVIA: We used the official implementation3. We used the hyperparameters reported in (Zintgraf et al., 2019) for mini Image Net, and an architecture with 64 filters. i MAML: We used the official implementation4. We used the hyperparameters reported in (Rajeswaran et al., 2019) for mini Image Net 1-shot 5-way. Table 3: References for the results provided in Table 2: (Liu et al., 2019), (Oh et al., 2021), (Aimen et al., 2021), (Arnold et al., 2021), and are reported in their respective references (under Model). Recall that denotes baseline results we executed using the official implementations. 
Model Backbone mini Image Net 5-way tiered Image Net 5-way 1-shot 5-shot 1-shot 5-shot MAML (Finn et al., 2017) Conv-4 48.70 1.84 63.11 0.92 51.67 1.81 70.30 1.75 ANIL (Raghu et al., 2019) Conv-4 46.30 0.40 61.00 0.60 49.35 0.26 65.82 0.12 Meta-SGD (Li et al., 2017) Conv-4 50.47 1.87 64.03 0.94 52.80 0.44 62.35 0.26 CAVIA (Zintgraf et al., 2019) Conv-4 51.82 0.65 65.85 0.55 52.41 2.64 67.55 2.05 i MAML (Rajeswaran et al., 2019) Conv-4 49.30 1.88 59.77 0.73 38.54 1.37 60.24 0.76 Meta Opt Net-RR (Lee et al., 2019) Conv-4 53.23 0.59 69.51 0.48 54.63 0.67 72.11 0.59 Meta Opt Net-SVM (Lee et al., 2019) Conv-4 52.87 0.57 68.76 0.48 54.71 0.67 71.79 0.59 COMLN (Ours) Conv-4 53.01 0.62 70.54 0.54 54.30 0.69 71.35 0.57 MAML (Finn et al., 2017) Res Net-12 49.92 0.65 63.93 0.59 55.37 0.74 72.93 0.60 ANIL (Raghu et al., 2019) Res Net-12 49.65 0.65 59.51 0.56 54.77 0.76 69.28 0.67 Meta Opt Net-RR (Lee et al., 2019) Res Net-12 61.41 0.61 77.88 0.46 65.36 0.71 81.34 0.52 Meta Opt Net-SVM (Lee et al., 2019) Res Net-12 62.64 0.61 78.63 0.46 65.99 0.72 81.56 0.53 COMLN (Ours) Res Net-12 59.26 0.65 77.26 0.49 62.93 0.71 81.13 0.53 3https://github.com/lmzintgraf/cavia/ 4https://github.com/aravindr93/imaml_dev Published as a conference paper at ICLR 2022 Aroof Aimen, Sahil Sidheekh, and Narayanan C Krishnan. Task Attended Meta-Learning for Few-Shot Learning. ar Xiv preprint, 2021. Sébastien MR Arnold and Fei Sha. Embedding Adaptation is Still Needed for Few-Shot Learning. ar Xiv preprint, 2021. Sébastien MR Arnold, Guneet S Dhillon, Avinash Ravichandran, and Stefano Soatto. Uniform Sampling over Episode Difficulty. ar Xiv preprint, 2021. Dmitry Baranchuk. Memory Efficient MAML, 2019. URL https://github.com/ dbaranchuk/memory-efficient-maml. Yoshua Bengio, Samy Bengio, Jocelyn Cloutier, and Jan Gecsei. Learning a Synaptic Learning Rule. International Joint Conference on Neural Networks, 1991. Luca Bertinetto, Joao F Henriques, Philip HS Torr, and Andrea Vedaldi. Meta-learning with Differentiable Closed-Form Solvers. ar Xiv preprint, 2018. John T Betts. Practical Methods for Optimal Control and Estimation Using Nonlinear Programming. SIAM, 2010. Lorenz T. Biegler. Nonlinear Programming. Society for Industrial and Applied Mathematics, January 2010. Dankmar Böhning. Multinomial Logistic Regression Algorithm. Annals of the Institute of Statistical Mathematics, 1992. James Bradbury, Roy Frostig, Peter Hawkins, Matthew James Johnson, Chris Leary, Dougal Maclaurin, and Skye Wanderman-Milne. JAX: composable transformations of Python+Num Py programs, 2018. URL http://github.com/google/jax. A. E. Bryson and Y. C. Ho. Applied Optimal Control. Blaisdell, New York, 1969. John Charles Butcher. Numerical Methods for Ordinary Differential Equations. Wiley, 2008. Makis Caracotsios and Warren E Stewart. Sensitivity analysis of initial value problems with mixed odes and algebraic equations. Computers & Chemical Engineering, 9(4):359 365, 1985. G Chavent, RE Goodson, and M Polis. Identification of parameter distributed systems. Identification of function parameters in partial differential equations, pp. 31 48, 1974. Ricky T. Q. Chen, Yulia Rubanova, Jesse Bettencourt, and David Duvenaud. Neural Ordinary Differential Equations. Advances in Neural Information Processing Systems, 2018. Xinshi Chen, Hanjun Dai, Yu Li, Xin Gao, and Le Song. Learning To Stop While Learning To Predict. In International Conference on Machine Learning, 2020a. 
Yutian Chen, Abram L Friesen, Feryal Behbahani, Arnaud Doucet, David Budden, Matthew W Hoffman, and Nando de Freitas. Modular Meta-Learning with Shrinkage. Neural Information Processing Systems, 2020b. John R Dormand and Peter J Prince. A family of embedded Runge-Kutta formulae. Journal of computational and applied mathematics, 1980. Leonhard Euler. De integratione aequationum differentialium per approximationem. Opera Omnia, 1913. William F Feehery, John E Tolsma, and Paul I Barton. Efficient sensitivity analysis of large-scale differential-algebraic systems. Applied Numerical Mathematics, 25(1):41 54, 1997. Chelsea Finn. Learning to Learn with Gradients. Ph D thesis, UC Berkeley, 2018. Chelsea Finn, Pieter Abbeel, and Sergey Levine. Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks. International Conference on Machine Learning (ICML), 2017. Published as a conference paper at ICLR 2022 Sebastian Flennerhag, Pablo G Moreno, Neil D Lawrence, and Andreas Damianou. Transferring knowledge across learning processes. ar Xiv preprint, 2018. Sebastian Flennerhag, Andrei A Rusu, Razvan Pascanu, Francesco Visin, Hujun Yin, and Raia Hadsell. Meta-Learning with Warped Gradient Descent. International Conference on Learning Representations, 2020. Luca Franceschi, Michele Donini, Paolo Frasconi, and Massimiliano Pontil. Forward and reverse gradient-based hyperparameter optimization. In International Conference on Machine Learning, 2017. Golnaz Ghiasi, Tsung-Yi Lin, and Quoc V Le. Dropblock: A regularization method for convolutional networks. In Neural Information Processing Systems, 2018. Amir Gholami, Kurt Keutzer, and George Biros. ANODE: unconditionally accurate memory-efficient gradients for neural odes. Co RR, abs/1902.10298, 2019. URL http://arxiv.org/abs/ 1902.10298. Andreas Griewank and Andrea Walther. Evaluating Derivatives. Society for Industrial and Applied Mathematics, January 2008. Yue Guo, Felix Dietrich, Tom Bertalan, Danimir T Doncevic, Manuel Dahmen, Ioannis G Kevrekidis, and Qianxiao Li. Personalized Algorithm Generation: A Case Study in Meta-Learning ODE Integrators. ar Xiv preprint, 2021. Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. Deep Residual Learning for Image Recognition. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, 2016. Tom Hennigan, Trevor Cai, Tamara Norman, and Igor Babuschkin. Haiku: Sonnet for JAX, 2020. URL http://github.com/deepmind/dm-haiku. Daniel Jiwoong Im, Yibo Jiang, and Nakul Verma. Model-Agnostic Meta-Learning using Runge Kutta Methods. ar Xiv preprint, 2019. Khurram Javed and Martha White. Meta-Learning Representations for Continual Learning. In Advances in Neural Information Processing Systems, 2019. Daniel Jiwoong Im, Cristina Savin, and Kyunghyun Cho. Online hyperparameter optimization by Real-Time Recurrent Learning. ar Xiv preprint, 2021. Kwonjoon Lee, Subhransu Maji, Avinash Ravichandran, and Stefano Soatto. Meta-learning with Differentiable Convex Optimization. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, 2019. Jorge R Leis and Mark A Kramer. The simultaneous solution and sensitivity analysis of systems described by ordinary differential equations. ACM Transactions on Mathematical Software (TOMS), 14(1):45 60, 1988. Da Li, Yongxin Yang, Yi-Zhe Song, and Timothy Hospedales. Learning to Generalize: Meta-Learning for Domain Generalization. In Proceedings of the AAAI Conference on Artificial Intelligence, 2018. 
Shibo Li, Zheng Wang, Akil Narayan, Robert Kirby, and Shandian Zhe. Meta-Learning with Adjoint Methods. ar Xiv preprint, 2021. Zhenguo Li, Fengwei Zhou, Fei Chen, and Hang Li. Meta-SGD: Learning to learn quickly for few-shot learning. ar Xiv preprint, 2017. Jacques Louis Lions and Enrico Magenes. Non-homogeneous boundary value problems and applications: Vol. 1, volume 181. Springer Science & Business Media, 2012. Yanbin Liu, Juho Lee, Minseop Park, Saehoon Kim, Eunho Yang, Sung Ju Hwang, and Yi Yang. Learning to Propagate Labels: Transductive Propagation Network for Few-Shot Learning. International Conference on Learning Representations, 2019. Published as a conference paper at ICLR 2022 Dougal Maclaurin, David Duvenaud, and Ryan Adams. Gradient-based hyperparameter optimization through reversible learning. In International conference on machine learning, pp. 2113 2122. PMLR, 2015. Timothy Maly and Linda R Petzold. Numerical methods and software for sensitivity analysis of differential-algebraic systems. Applied Numerical Mathematics, 20(1-2):57 79, 1996. Paul Micaelli and Amos Storkey. Gradient-based Hyperparameter Optimization Over Long Horizons. In Neural Information Processing Systems, 2021. Alex Nichol, Joshua Achiam, and John Schulman. On First-Order Meta-Learning Algorithms. ar Xiv preprint, 2018. Jaehoon Oh, Hyungjun Yoo, Chang Hwan Kim, and Se-Young Yun. BOIL: Towards Representation Change for Few-Shot Learning. International Conference on Learning Representations, 2021. Antonio Orvieto and Aurelien Lucchi. Shadowing Properties of Optimization Algorithms. In Advances in Neural Information Processing Systems, 2019. John Platt and Alan Barr. Constrained Differential Optimization. In Neural Information Processing Systems, 1988. Boris T Polyak. Some methods of speeding up the convergence of iteration methods. USSR computational mathematics and mathematical physics, 1964. Lev Semenovich Pontryagin. Mathematical theory of optimal processes. Routledge, 2018. Aniruddh Raghu, Maithra Raghu, Samy Bengio, and Oriol Vinyals. Rapid learning or feature reuse? towards understanding the effectiveness of maml. ar Xiv preprint, 2019. Aravind Rajeswaran, Chelsea Finn, Sham M Kakade, and Sergey Levine. Meta-Learning with Implicit Gradients. In Advances in Neural Information Processing Systems, 2019. Mengye Ren, Eleni Triantafillou, Sachin Ravi, Jake Snell, Kevin Swersky, Joshua B Tenenbaum, Hugo Larochelle, and Richard S Zemel. Meta-learning for semi-supervised few-shot classification. In International Conference on Learning Representations, 2018. Jonas Rothfuss, Dennis Lee, Ignasi Clavera, Tamim Asfour, and Pieter Abbeel. Pro MP: Proximal Meta-Policy Search. International Conference on Learning Representations, 2019. Carl Runge. Über die numerische auflösung von differentialgleichungen. Mathematische Annalen, 1895. Olga Russakovsky, Jia Deng, Hao Su, Jonathan Krause, Sanjeev Satheesh, Sean Ma, Zhiheng Huang, Andrej Karpathy, Aditya Khosla, Michael Bernstein, et al. Imagenet large scale visual recognition challenge. International journal of computer vision, 115(3):211 252, 2015. Andrei A Rusu, Dushyant Rao, Jakub Sygnowski, Oriol Vinyals, Razvan Pascanu, Simon Osindero, and Raia Hadsell. Meta-learning with Latent Embedding Optimization. ar Xiv preprint, 2018. Adam Santoro, Sergey Bartunov, Matthew Botvinick, Daan Wierstra, and Timothy Lillicrap. One-shot learning with memory-augmented neural networks. ar Xiv preprint, 2016. Jürgen Schmidhuber. 
Evolutionary principles in self-referential learning, or on learning how to learn: the meta-meta-... hook. Ph D thesis, Technische Universität München, 1987. Radu Serban and Alan C Hindmarsh. Cvodes: An ode solver with sensitivity analysis capabilities. Technical report, Technical Report UCRL-JP-200039, Lawrence Livermore National Laboratory, 2003. Amirreza Shaban, Ching-An Cheng, Nathan Hatch, and Byron Boots. Truncated Back-propagation for Bilevel Optimization. In International Conference on Artificial Intelligence and Statistics, 2019. Published as a conference paper at ICLR 2022 Jake Snell, Kevin Swersky, and Richard Zemel. Prototypical Networks for Few-shot Learning. In Advances in Neural Information Processing Systems, 2017. Weijie Su, Stephen Boyd, and Emmanuel Candes. A Differential Equation for Modeling Nesterov s Accelerated Gradient Method: Theory and Insights. Advances in Neural Information Processing Systems, 2014. Flood Sung, Yongxin Yang, Li Zhang, Tao Xiang, Philip HS Torr, and Timothy M Hospedales. Learning to compare: Relation network for few-shot learning. In Proceedings of the IEEE conference on computer vision and pattern recognition, 2018. Richard S. Sutton. Adapting Bias by Gradient Descent: An Incremental Version of Delta-Bar-Delta. In Proceedings of the 10th National Conference on Artificial Intelligence, 1992. Sebastian Thrun and Lorien Pratt. Learning to learn. Springer Science & Business Media, 2012. Yonglong Tian, Yue Wang, Dilip Krishnan, Joshua B Tenenbaum, and Phillip Isola. Rethinking Few-Shot Image Classification: a Good Embedding Is All You Need? 2020. Oriol Vinyals, Charles Blundell, Timothy Lillicrap, Daan Wierstra, et al. Matching networks for one shot learning. Neural Information Processing Systems, 29:3630 3638, 2016. Johannes Von Oswald, Dominic Zhao, Seijin Kobayashi, Simon Schug, Massimo Caccia, Nicolas Zucchet, and João Sacramento. Learning where to learn: Gradient sparsity in meta and continual learning. Advances in Neural Information Processing Systems, 2021. Stephen Wiggins. Introduction to Applied Nonlinear Dynamical Systems and Chaos, volume 2. Springer, 2003. Ronald J. Williams and David Zipser. A Learning Algorithm for Continually Running Fully Recurrent Neural Networks. Neural Computation, 1989. Ashia C Wilson, Benjamin Recht, and Michael I Jordan. A Lyapunov Analysis of Momentum Methods in Optimization. ar Xiv preprint, 2016. Ruitu Xu, Lin Chen, and Amin Karbasi. Meta Learning in the Continuous Time Limit . In Proceedings of The 24th International Conference on Artificial Intelligence and Statistics. PMLR, 2021. Zhongwen Xu, Hado van Hasselt, and David Silver. Meta-Gradient Reinforcement Learning. In Advances in Neural Information Processing Systems, 2018. Tianhe Yu, Saurabh Kumar, Abhishek Gupta, Sergey Levine, Karol Hausman, and Chelsea Finn. Gradient Surgery for Multi-Task Learning. Neural Information Processing Systems, 2020. Baoquan Zhang, Xutao Li, Yunming Ye, Shanshan Feng, and Rui Ye. Meta NODE: Prototype Optimization as a Neural ODE for Few-Shot Learning. ar Xiv preprint, 2021. Yufan Zhou, Zhenyi Wang, Jiayi Xian, Changyou Chen, and Jinhui Xu. Meta-Learning with Neural Tangent Kernels. International Conference on Learning Representations, 2021. Luisa Zintgraf, Kyriacos Shiarli, Vitaly Kurin, Katja Hofmann, and Shimon Whiteson. Fast context adaptation via meta-learning. In International Conference on Machine Learning, pp. 7693 7702. PMLR, 2019. 
Published as a conference paper at ICLR 2022 The Appendix is organized as follows: in Appendix A we give the pseudo-code for meta-training and meta-testing COMLN, along with minimal code in JAX (Bradbury et al., 2018). In Appendix B we prove the decomposition of the Jacobians introduced in Section 4, and we give the exact dynamics of sm(t) and zt[i, j, m] in the decomposition of d W (t)/dφm, omitted from the main text for concision. In Appendix C we show how the total derivatives may be computed from the Jacobian matrices without ever having to form them explicitly, hence maintaining the memory-efficiency described in Section 4.2. In Appendix D, we show that unlike the adjoint method described in Section 2.2, the algorithm used in COMLN to compute the meta-gradients based on the forward sensitivity method is stable and guaranteed not to diverge. Finally in Appendix E we give additional details regarding the experiments reported in Section 5, together with additional analyses and results on the dataset introduced in Rusu et al. (2018). A ALGORITHMIC DETAILS A.1 META-TRAINING PSEUDO-CODE We give in Algorithm 1 the pseudo-code for meta-training COMLN, based on a distribution of tasks p(τ), with references to the relevant propositions developed in Appendices B and C. Note that to simplify the presentation in Algorithm 1, we introduce the notations τ and Lτ to denote the total derivative of the outer-loss, and the partial derivative of the loss L respectively, both computed at the adapted parameters Wτ(T). For example: τW0 = d L Wτ(T); fΦ(Dtest τ ) d W0 W (T )Ltrain τ = L Wτ(T); fΦ(Dtrain τ ) Algorithm 1 COMLN Meta-training Require: A task distribution p(τ) Initialize randomly Φ and W0. Initialize T to a small value ε > 0. loop Sample a batch of tasks B p(τ) for all τ B do Embed the training set: fΦ(Dtrain τ ) = {(φm, ym)}M m=1 s(T), BT , z T ODESolve([0, 0, 0], DYNAMICS, 0, T) Algorithm 2 Wτ(T) W0 P m sm(T)φ m Proposition 1 Embed the test set: fΦ(Dtest τ ) Compute the partial derivatives: W (T )Ltrain τ & W (T )Ltest τ τW0 PROJECTW0 W (T )Ltest τ , BT , φ Proposition 5 τφm PROJECTφm W (T )Ltest τ , s(T), BT , z T , φ + φm Ltest τ Proposition 6 τΦ Backpropagation through fΦ, starting with τφm for all m τT W (T )Ltest τ W (T )Ltrain τ Proposition 4 end for Update the meta-parameters: W0 W0 α |B| P τ B τW0 Φ Φ α |B| P τ B τΦ T T α |B| P τ B τT end loop The dynamical system followed during adaptation is given in Propositions 1 to 3 and is summarized in Algorithm 2. The solution of this dynamical system is used not only to find the adapted parameters Wτ(T), but also to get the quantities necessary to compute the meta-gradients for the updates of Published as a conference paper at ICLR 2022 the meta-parameters W0 & Φ. The integration of this dynamical system, named ODESOLVE in Algorithm 1, is done in practice using a numerical solver such as Runge-Kutta methods. Algorithm 2 Dynamical system function DYNAMICS(W0, fΦ(Dtrain τ ), s(t)) Wτ(t) W0 P m sm(t)φ m Proposition 1 pm(t) softmax(Wτ(t)φm) Am(t) diag(pm(t)) pm(t)pm(t) /M M pm(t) ym Proposition 1 dt 1(i = j)Ai(t) Ai(t) φ i φm Bt[m, j] Propositions 2 & 3 ( -) dzt[i, j, m] dt Ai(t) 1(i = j)sm(t) + 1(i = m)sj(t) + φ i φk zt[k, j, m] return ds/dt, d Bt/dt, dzt/dt end function In Algorithms 3 & 4, we give the procedures responsible for the projection of the partial derivatives W (T )Ltest τ onto the Jacobian matrices d W (T)/d W0 and d W (T)/dφm respectively, based on Propositions 5 & 6. 
Note that Algorithm 4 also depends on the initial conditions W0, which is implicitly assumed here for clarity of presentation. It is interesting to see that both functions use the same matrix C, and therefore this can be computed only once and reused for both projections. Algorithm 3 Projection onto d W (T)/d W0 function PROJECTW0(V (T), BT , φ) i=1 φ i V (T) BT [i, j] return V (T) C φ end function Algorithm 4 Projection onto d W (T)/dφm function PROJECTφm(V (T), s(T), BT , z T , φ) i=1 φ i V (T) BT [i, j] i=1 z T [i, j, m] V (T)φi return sm(T) V (T) + Cm W0 + Dmφ end function Finally in Algorithms 5 & 6, we show how to use COMLN at meta-test time on a novel task, based on the learned meta-parameters W0, Φ and T. This simply corresponds to integrating a dynamical system in W (t) (equivalently, in sm(t), see Proposition 1). Note that for adaptation during metatesting, there is no need to compute either Bt[i, j] or zt[i, j, m], since these are only necessary to compute the meta-gradients during meta-training. Algorithm 5 COMLN Meta-test Require: A task τ with a dataset Dtrain τ Require: Meta-parameters Φ, W0 & T Embed the training dataset: fΦ(Dtrain τ ) s(T) ODESolve(0, DYNAMICS, 0, T) Wτ(T) W0 P m sm(T)φ m return Wτ(T) fΦ Algorithm 6 Dynamical system for adaptation function DYNAMICS(W0, fΦ(Dtrain τ ), s(t)) Wτ(t) W0 P m sm(t)φ m pm(t) softmax(Wτ(t)φm) dsm/dt pm(t) ym /M return ds/dt end function Published as a conference paper at ICLR 2022 A.2 SOURCE CODE We provide a snippet of code written in JAX (Bradbury et al., 2018) in order to compute the adapted parameters, based on Algorithms 1 & 2. The dynamics function not only computes the vectors sm(t) necessary to compute W (t) (see Appendix B.1), but also Bt[i, j] and zt[i, j, m] jointly in order to compute the meta-gradients. This snippet shows that various necessary quantities can be precomputed ahead of adaptation, such as the Gram matrix gram. 1 import jax.numpy as jnp 3 from jax import vmap, grad, nn, ops, tree_util 4 from jax.experimental import ode 5 from collections import namedtuple 7 State = namedtuple('State', ['s', 'B', 'z']) 9 M, N = train_inputs.shape[0], train_labels.shape[1] 10 gram = jnp.matmul(train_inputs, train_inputs.T) 11 logits_0 = jnp.matmul(train_inputs, W_0.T) 12 diag = jnp.diag_indices(M) 14 def dynamics(state, _): 15 preds = nn.softmax(logits_0 - jnp.matmul(gram, state.s), axis=1) 16 A = (vmap(jnp.diag)(preds) - vmap(jnp.outer)(preds, preds)) / M 18 # Update of s 19 ds = (preds - train_labels) / M 21 # Update of B 22 cross_prod = jnp.einsum('ikn,im,mjnl->ijkl', A, gram, state.B) 23 d B = ops.index_add(-cross_prod, diag, A) 25 # Update of z 26 cross_prod = jnp.einsum('iln,ik,jmn->ijml', A, gram, state.z) 27 A_s = jnp.einsum('ikl,jl->ijk', A, state.s) 28 dz = ops.index_add(cross_prod, diag, A_s) 29 dz = ops.index_add(dz, (diag[0], None, diag[1]), A_s) 31 return State(s=ds, B=d B, z=-dz) 33 state_0 = State( 34 s=jnp.zeros((M, N)), 35 B=jnp.zeros((M, M, N, N)), 36 z=jnp.zeros((M, M, M, N)) 38 solution = ode.odeint(dynamics, state_0, jnp.array([0., T])) 39 state_T = tree_util.tree_map(lambda x: x[-1], solution) 40 W_T = W_0 - jnp.matmul(state_T.s.T, train_inputs) Code Snippet 1: Snippet of code in JAX to compute the adapted parameter W_T for a given task specified by the embedded training set train_inputs & train_labels (note that here train_inputs are the embedding vectors returned by fΦ) and an initialization W_0. 
We want to emphasize that all the information required to compute the meta-gradients is available in state_T, found after integrating the dynamics function forward in time, and no additional backward pass is required (apart from backpropagation through the embedding network f_Φ). In particular, the meta-gradients wrt. the initial conditions W_0 and wrt. the time horizon T can be computed using the code in Code Snippet 2.

# cross_entropy_loss(W, inputs, labels): average cross-entropy of the predictions softmax(inputs @ W.T)
grads_test = grad(cross_entropy_loss)(W_T, test_inputs, test_labels)
grads_train = grad(cross_entropy_loss)(W_T, train_inputs, train_labels)

# Matrix C from Proposition 5, computed without ever forming the Jacobian dW(T)/dW_0
C = jnp.einsum('in,kn,ijkl->jl', train_inputs, grads_test, state_T.B)

grads_W_0 = grads_test - jnp.matmul(C.T, train_inputs)
grads_T = -jnp.vdot(grads_train, grads_test)

Code Snippet 2: Snippet of code in JAX to compute the meta-gradients wrt. the meta-parameters W_0 and T, based on state_T found above. See Appendix B.4 & Appendix C for details.

The code to compute the meta-gradients wrt. the meta-parameters of the embedding network Φ is not included here for clarity, as it involves non-minimal dependencies on Haiku (Hennigan et al., 2020) in order to backpropagate the error through the backbone network. The full code is available at https://github.com/tristandeleu/jax-comln.

B  PROOFS OF MEMORY-EFFICIENT META-GRADIENTS

In this section, we prove the results presented in Section 4 in a slightly more general case, where the loss function L is the cross-entropy loss regularized with a proximal term around the initialization (Rajeswaran et al., 2019):

    L(W; f_Φ(D_τ^train)) = −(1/M) Σ_{m=1}^M y_m^⊤ log p_m + (λ/2) ‖W − W_0‖_F²,        (14)

where p_m = softmax(Wφ_m), ‖·‖_F is the Frobenius norm, and λ ≥ 0 is a regularization constant. Note that we can recover the setting presented in the main paper by setting λ = 0.

The core idea behind computing the meta-gradients efficiently is the decomposition of the gradient (as well as the Hessian matrix) of the cross-entropy loss (Böhning, 1992). We recall this decomposition in the following lemma, extended to include the regularization term:

Lemma 1 (Böhning, 1992). Let f_Φ(D_τ^train) = {(φ_m, y_m)}_{m=1}^M be the embedded training set through the embedding network f_Φ, and L the regularized cross-entropy loss defined in (14). The gradient ∇L and the Hessian matrix ∇²L of the regularized cross-entropy loss can be written as

    ∇L(W; f_Φ(D_τ^train)) = (1/M) Σ_{m=1}^M (p_m − y_m) φ_m^⊤ + λ(W − W_0)
    ∇²L(W; f_Φ(D_τ^train)) = Σ_{m=1}^M A_m ⊗ φ_m φ_m^⊤ + λI,

where A_m = (diag(p_m) − p_m p_m^⊤)/M are N × N matrices, and ⊗ is the Kronecker product.

This lemma is particularly useful in the context of few-shot learning, since it reduces the characterization of the gradient and the Hessian from quantities of size N × d and Nd × Nd respectively (where d is the dimension of the embedding vectors φ) to M objects whose size is independent of d, the embedding vectors φ being independent of W. Typically in few-shot learning, M ≪ d.

In order to avoid higher-order tensors when defining the different Jacobians and Hessians, we always consider them as matrices, with possibly an implicit flattening operation. For example, even though W is an N × d matrix, we treat the Hessian ∇²L as an Nd × Nd matrix, as opposed to a 4D tensor. When the context requires it, we will make this transformation from a higher-order tensor to a matrix explicit with an encoding of indices.
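As a sanity check of Lemma 1 (and of the row-major index encoding used throughout), the decomposition can be compared against jax.grad and jax.hessian on a small random problem. The following is a minimal sketch, not from the paper, with arbitrary small dimensions and λ = 0.1:

import jax
import jax.numpy as jnp
from jax import nn, random

M, N, d, lam = 6, 3, 5, 0.1
k1, k2, k3 = random.split(random.PRNGKey(0), 3)
phi = random.normal(k1, (M, d))                      # embedded inputs
y = nn.one_hot(random.randint(k2, (M,), 0, N), N)    # one-hot targets
W0 = random.normal(k3, (N, d))

def loss(w_flat):
    W = w_flat.reshape(N, d)
    ce = -jnp.mean(jnp.sum(y * nn.log_softmax(phi @ W.T, axis=1), axis=1))
    return ce + 0.5 * lam * jnp.sum((W - W0) ** 2)

W = random.normal(random.PRNGKey(1), (N, d))
p = nn.softmax(phi @ W.T, axis=1)                    # p_m = softmax(W phi_m)

# Gradient: (1/M) sum_m (p_m - y_m) phi_m^T + lam (W - W0)
grad_lemma = (p - y).T @ phi / M + lam * (W - W0)
print(jnp.allclose(jax.grad(loss)(W.reshape(-1)).reshape(N, d), grad_lemma, atol=1e-4))

# Hessian: sum_m A_m kron phi_m phi_m^T + lam I, with A_m = (diag(p_m) - p_m p_m^T) / M
A = (jax.vmap(jnp.diag)(p) - jax.vmap(jnp.outer)(p, p)) / M
hess_lemma = sum(jnp.kron(A[m], jnp.outer(phi[m], phi[m])) for m in range(M)) + lam * jnp.eye(N * d)
print(jnp.allclose(jax.hessian(loss)(W.reshape(-1)), hess_lemma, atol=1e-4))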
Moreover throughout this section, we will only consider the computation of the meta-gradients for a single task τ, and therefore we will often drop the explicit dependence of the different objects on τ (e.g. we will write W (t) instead of Wτ(t)); the meta-gradients are eventually averaged over a batch of tasks for the update of the outer-loop, see Appendix A for details in the pseudo-code. Finally, since we only consider adaptation of the last layer of the neural network, the presentation here is always made in the context of an embedded training set fΦ(Dtrain τ ) = {(φm, ym)}M m=1 that went through the backbone fΦ. Published as a conference paper at ICLR 2022 B.1 DECOMPOSITION OF W (t) FOR PARAMETER ADAPTATION As a direct application of Lemma 1, we first decompose the parameters W (t) into smaller quantities sm(t) that follow the dynamics defined in Prop. 1. Although this decomposition is equivalent to W (t), in practice solving a smaller dynamical system in sm(t) improves the efficiency of our method. Proposition 1. Let fΦ(Dtrain τ ) = {(φm, ym)}M m=1 be the embedded training set through the embedding network fΦ, and W (t) be the solution of the following dynamical system dt = L W (t); fΦ(Dtrain τ ) W (0) = W0, where the loss function is the regularized cross-entropy loss defined in (14). The solution W (t) of this dynamical system can be written as m=1 sm(t)φ m, where for all m, sm(t) is the solution of the following dynamical system: M pm(t) ym λsm(t) sm(0) = 0, and pm(t) = softmax(W (t)φm) is the vector of predictions returned by the network at time t. Proof. Using Lemma 1, the function W (t) is the solution of the following differential equation: dt = L W (t); fΦ(Dtrain τ ) pm(t) ym φ m λ W (t) W0 W (0) = W0. The proof relies on the unicity of the solution of a given autonomous differential equation given a particular choice of initial conditions (Wiggins, 2003, Prop. 7.4.2). In other words, if we can find another function f W (t) that also satisfies the differential equation above with the initial conditions f W (0) = W0, then it means that for all t we have f W (t) = W (t). Suppose that this function f W (t) can be written as f W (t) = W0 m=1 sm(t)φ m, where sm satisfies the dynamical system defined in the statement of Proposition 1. Then we have L f W (t); fΦ(Dtrain τ ) = 1 pm(t) ym φ m λ f W (t) W0 m=1 sm(t)φ m λ f W (t) W0 dt φ m = df W dt We have shown that f W (t) follows the same dynamics as W (t). Moreover, using the initial conditions sm(0) = 0, it is clear that f W (0) = W0, which are the same initial conditions as the equation satisfied by W (t). Therefore we have f W (t) = W (t) for all t, showing the expected decomposition of W (t). Published as a conference paper at ICLR 2022 B.2 DECOMPOSITION OF THE JACOBIAN MATRIX d W (t)/d W0 The core objective of gradient-based meta-learning methods is the capacity to compute the metagradients wrt. the initial conditions W0. In order to compute this meta-gradient using forward-mode automatic differentiation, we want to first compute the Jacobian matrix d W (t)/d W0. We show that this Jacobian can be decomposed into smaller quantities that follow the dynamics in Proposition 2. Proposition 2. Let fΦ(Dtrain τ ) = {(φm, ym)}M m=1 be the embedded training set through the embedding network fΦ. 
The Jacobian matrix d W (t)/d W0 can be written as j=1 Bt[i, j] φiφ j , where is the Kronecker product, and for all i, j, Bt[i, j] is a N N matrix which the solution of the following dynamical system dt = 1(i = j)Ai(t) λBt[i, j] Ai(t) φ i φm Bt[m, j] B0[i, j] = 0, and Ai(t) = diag(pi(t)) pi(t)pi(t) /M are defined using the vectors of predictions at time t pi(t) = softmax W (t)φi . Proof. We will use the forward sensitivity equation from Section 2.2 in order to derive this new dynamical system over Bt[i, j]. Recall that to simplify the notations we can write the gradient vector field followed during adaptation as dt = g W (t); fΦ(Dtrain τ ) L W (t); fΦ(Dtrain τ ) . Introducing the matrix-valued function S(t) = d W (t)/d W0 as a sensitivity state, we can use the forward sensitivity equations and see that S(t) satisfies the following equation dt = g W (t) W (t) S(t) + g W (t) W0 S(0) = I = 2L W (t); fΦ(Dtrain τ ) S(t) + λI. The rest of the proof is based on the unicity of the solution of a given autonomous5 differential equation given a particular choice of initial conditions (Wiggins, 2003, Prop. 7.4.2). In other words, if we can find another function e S(t) that also satisfies the above differential equation with the initial conditions e S(0) = I, then it means that for all t we have e S(t) = S(t). Suppose that this function can be written as j=1 Bt[i, j] φiφ j , where Bt[i, j] satisfies the dynamical system defined in the statement of Proposition 2. Then we have, using Lemma 1: 2L W (t); fΦ(Dtrain τ ) e S(t) + λI i=1 Ai(t) φiφ i + λI I X m,j Bt[m, j] φmφ j i=1 Ai(t) φiφ i + X Ai(t)Bt[m, j] φi φ i φm | {z } R i,j Bt[i, j] φiφ j 5While the dynamical system defined here for the sensitivity state S(t) alone is not exactly autonomous due to the dependence of the Hessian matrix on W (t), we could augment S(t) with W (t) to obtain an autonomous system on the augmented state. We come back to this distinction in Appendix D. The unicity argument still holds here (the augmented solution would be unique). Published as a conference paper at ICLR 2022 1(i = j)Ai(t) λBt[i, j] Ai(t) φ i φm Bt[m, j] φiφ j dt φiφ j = d e S We have shown that e S(t) follows the same dynamics as S(t). Moreover, using the initial conditions B0[i, j] = 0, it is clear that e S(0) = I, which are the same initial conditions as the equation satisfied by S(t). Therefore, we have e S(t) = S(t) for all t, showing the expected decomposition of S(t) = d W (t)/d W0. B.3 DECOMPOSITION OF THE JACOBIAN MATRIX d W (t)/dφm Similar to Proposition 2, there exists a decomposition of the Jacobian matrix d W (T)/dφm. Recall that this Jacobian matrix appears in the computation of the gradient of the outer-loss wrt. the embedding vectors φm, which is necessary in order to compute the meta-gradients wrt. the metaparameters of the embedding network fΦ using backpropagation from the last layer of fΦ. Proposition 3. Let fΦ(Dtrain τ ) = {(φm, ym)}M m=1 be the embedded training set through the embedding network fΦ. For all m, the Jacobian matrix d W (t)/dφm can be written as dφm = sm(t) Id + i=1 Bt[i, m]W0 φi + j=1 zt[i, j, m]φ j φi where sm(t) and Bt[i, j] are the solutions of the ODEs defined in Propositions 1 & 2, and for all i, j, m, zt[i, j, m] is a vector of length N solution of the following dynamical system: dzt[i, j, m] dt = Ai(t) 1(i = j)sm(t) + 1(i = m)sj(t) + λzt[i, j, m] + φ i φk zt[k, j, m] , with the initial conditions z0[i, j, m] = 0. Proof. 
The outline of this proof follows the proof of Proposition 2: we first use the forward sensitivity equations to get the dynamical system followed by the Jacobian matrix d W (t)/dφm, and then use a unicity argument given that the new decomposition satisfies the same ODE. Recall that we use the following notation to write the gradient vector field followed during adaptation: dt = g W (t); fΦ(Dtrain τ ) L W (t); fΦ(Dtrain τ ) . For a fixed m, we introduce the matrix valued function S(t) = d W (t)/dφm as a sensitivity state. We can use the forward sensitivity equations, and see that S(t) satisfies the following equation dt = g W (t), φm W (t) S(t) + g W (t), φm φm S(0) = d W (0) = 2L W (t); fΦ(Dtrain τ ) S(t) Am(t)W0 φm i=1 si(t)φ i φm 1 M pm(t) ym Id where we make the direct dependence of g on φm explicit, and we used the decomposition of W (t) from Proposition 1. Suppose that we define the function e S(t) as e S(t) = sm(t) Id + i=1 Bt[i, m]W0 φi + X i,j zt[i, j, m]φ j φi Published as a conference paper at ICLR 2022 where zt[i, j, m] satisfies the dynamical system defined in the statement of Proposition 3. Then we have, using Lemma 1: 2L W (t); fΦ(Dtrain τ ) e S(t) Am(t) W0 i=1 si(t)φ i M pm(t) ym Id i=1 Ai(t) φiφ i + λI sm(t) Id + i=1 Bt[i, m]W0 φi + X i,j zt[i, j, m]φ j φi Am(t)W0 φm + Am(t) i=1 si(t)φ i φm 1 M pm(t) ym Id M pm(t) ym λsm(t) | {z } = dsm/dt 1(i = m)Ai(t) λBt[i, m] Ai(t) φ i φk Bt[k, m] | {z } = d Bt[i,m]/dt i,j Ai(t) 1(i = j)sm(t) + 1(i = m)sj(t) + λzt[i, j, m] + φ i φk zt[k, j, m] | {z } = dzt[i,j,m]/dt We have shown that e S(t) follows the same dynamics as S(t). Moreover using the initial conditions sm(0) = 0, B0[i, j] = 0, and z0[i, j, m] = 0, we have e S(0) = 0, which are the same initial conditions as the equation satisfied by S(t). Therefore we have e S(t) = S(t) for all t, showing the expected decomposition of the Jacobian matrix S(t) = d W (t)/dφm. B.4 PROOF OF THE META-GRADIENT WRT. T The novelty of COMLN over prior work on gradient-based meta-learning is the ability to meta-learn the amount of adaptation T using SGD. To compute the meta-gradient wrt. the time horizon T, we can apply the result from Chen et al. (2018). Proposition 4. Let fΦ(Dtrain τ ) and fΦ(Dtest τ ) be the embedding through the network fΦ of the training and test set respectively. The gradient of the outer-loss wrt. the time horizon T is given by: d L W (T); fΦ(Dtest τ ) d T = L W (T); fΦ(Dtest τ ) L W (T); fΦ(Dtrain τ ) Proof. We can directly apply the result from (Chen et al., 2018, App. B.2), which we recall here for completeness. Given the ODE dz/dt = g(z(t)), the gradient of the loss e L z(T) wrt. T is given by d e L d T = a(T) g(z(T)) where a(T) = e L z(T) is the initial adjoint state (see Section 2.2). In our case we have g(z(T)) L W (T); fΦ(Dtrain τ ) a(T) L W (T); fΦ(Dtest τ ) Published as a conference paper at ICLR 2022 C PROJECTION ONTO THE JACOBIAN MATRICES Once we have computed the Jacobian d W (T)/dθ wrt. some meta-parameter θ (in our case, either the initial conditions W0 or the embedding vectors φm), we only have to project the vector of partial derivatives onto it to obtain the meta-gradients (vector-Jacobian product), using the chain rule: d L W (T); fΦ(Dtest τ ) dθ = L W (T); fΦ(Dtest τ ) W (T) d W (T) dθ + L W (T); fΦ(Dtest τ ) While we have shown how to decompose the Jacobian matrices in such a way that it only involves quantities that are independent of d the dimension of the embedding vectors φ, this final projection a priori requires us to form the Jacobian explicitly. 
Even though this operation is done only once at the end of the integration, this may be overly expensive since the Jacobian matrices scale quadratically with d, which can be as high as d = 16,000 in our experiments. Fortunately, we can perform this projection as a function of sm(T), BT [i, j], and z T [i, j, m], without having to explicitly form the full Jacobian matrix, by exchanging the order of operations. Proposition 5. Let fΦ(Dtrain τ ) = {(φm, ym)}M m=1 and fΦ(Dtest τ ) be the embedded training set and test set through the embedding network fΦ respectively. Let BT [i, j] be the solution at time T of the differential equation defined in Proposition 2, and V (T) RN d the partial derivative of the outer-loss wrt. the adapted parameters W (T): V (T) = L W (T); fΦ(Dtest τ ) W (T) so that d L W (T); fΦ(Dtest τ ) d W0 = Vec V (T) T d W (T) Let φ = [φ1, . . . , φM] RM d be the design matrix. The gradient of the outer-loss wrt. the initial conditions W0 can be computed as d L W (T); fΦ(Dtest τ ) d W0 = V (T) C φ, where the rows of C RM N are defined by i=1 φ i V (T) BT [i, j]. Proof. In this proof, we will make the encoding of the indices for the Jacobian d W (T)/d W0 more explicit, as mentioned in Appendix B. Recall from Proposition 2 that this Jacobian matrix can be decomposed as d W (T) i,j BT [i, j] φiφ j . Introducing the following notations i,j BT [i, j] φiφ j RNd Nd G Vec V (T) F RNd, for all l {0, . . . , N 1} and y {0, . . . , d 1}: G[dl + y] = x=0 VT [k, x]F [dk + x, dl + y] x=0 VT [k, x]φi[x] | {z } = [V (T )φi]k BT [i, j, k, l]φj[y] k=0 [V (T)φi]k BT [i, j, k, l] = [φ i V (T ) BT [i,j]]l Published as a conference paper at ICLR 2022 φ i V (T) BT [i, j] l | {z } = Cj,l φj[y] = C φ This shows that G is equal to C φ up to reshaping. Using the full form of the Jacobian, including I, concludes the proof. An interesting observation is that if the term C φ is ignored in the computation of the meta-gradient, we recover an equivalent of the first-order approximation introduced in Finn et al. (2017). Similarly, we can show that we can perform the projection onto the Jacobian matrix d W (T)/dφm without having to form the explicit Nd d matrix. Proposition 6. Let fΦ(Dtrain τ ) = {(φm, ym)}M m=1 and fΦ(Dtest τ ) be the embedded training set and test set through the embedding network fΦ respectively. Let sm(T) be the solution at time T of the differential equation defined in Proposition 1, BT [i, j] the solution of the one defined in Proposition 2, and z T [i, j, m] the solution of the one defined in Proposition 3. Let V (T) RN d be the partial derivative of the outer-loss wrt. the adapted parameters W (T): V (T) = L W (T); fΦ(Dtest τ ) Let φ = [φ1, . . . , φM] RM d be the design matrix. The projection of V (T) onto the Jacobian matrix d W (T)/dφm can be computed as Vec V (T) d W (T) dφm = sm(T) V (T) + Cm W0 + Dmφ , where Cm RN and Dm RM are defined by i=1 φ i V (T) BT [i, m] Dm,j = i=1 z T [i, j, m] V (T)φi Note that the definition of C in Proposition 6 matches exactly the definition of C in Proposition 5, and therefore we only need to compute this matrix once to perform the projections for both meta-gradients. Proof. Recall from Proposition 3 that the Jacobian d W (T)/dφm can be decomposed as dφm = sm(T) Id + i=1 BT [i, m]W0 φi + X i,j z T [i, j, m]φ j φi We will consider each of these 3 terms separately. For the first term, let F1 sm(T) Id RNd d G1 Vec V (T) F1 Rd, and for y {0, . . . 
, d 1}: x=0 VT [k, x]F1[dk + x, y] = x=0 VT [k, x]1(x = y) sm(T) k VT [k, y] = sm(T) V (T) Therefore, G1 the projection of V (T) onto the first term of the Jacobian matrix is equal to sm(T) V (T). Published as a conference paper at ICLR 2022 For the second term, let i=1 BT [i, m]W0 φi RNd d G2 Vec V (T) F2 Rd, and for y {0, . . . , d 1}: x=0 VT [k, x]F2[dk + x, y] x=0 VT [k, x]φi[x] | {z } = [V (T )φi]k BT [i, m, k, l]W0[l, y] k=0 [V (T)φi]k BT [i, m, k, l] = φ i V (T ) BT [i,m] φ i V (T) BT [i, m] l | {z } = Cm,l W0[l, y] = Cm W0 G2 the projection of V (T) onto the second term of the Jacobian matrix is equal to Cm W0. Finally for the third term, let i,j z T [i, j, m]φ j φi RNd d G3 Vec V (T) F3 Rd and for y {0, . . . , d 1}: x=0 VT [k, x]F3[dk + x, y] x=0 VT [k, x]φi[x] | {z } = [V (T )φi]k z T [i, j, m, k]φj[y] k=0 [V (T)φi]kz T [i, j, m, k] | {z } = z T [i,j,m] V (T )φi R i=1 z T [i, j, m] V (T)φi | {z } = Dm,j φj[y] = Dmφ G3 the projection of V (T) onto the third term of the Jacobian is equal to Dmφ. Note that the projection of V (T) as defined in Proposition 6 is not equal to the meta-gradient itself, as it is missing the term coming from the partial derivative of the outer-loss wrt. the embedding vector L W (T); fΦ(Dtest τ ) / φm, which is non-zero unlike the counterpart for W0. However, this projection is the only expensive operation required to compute the total gradient. Published as a conference paper at ICLR 2022 D PROOF OF STABILITY In this section, we would like to show that unlike the adjoint method (see Sections 2.2 and 3.2), our method to compute the meta-gradients of COMLN is guaranteed to be stable. In other words, even under some small perturbations due to the ODE solver, the solution found by numerical integration is going to stay close to the true solution of the dynamical system. To do so, we use the concept of Lyapunov stability of the solution of an ODE, which we recall here for completeness: Definition 1 (Lyapunov stability; Wiggins, 2003). The solution x(t) of a dynamical system is said to be Lyapunov stable if, given ε > 0, there exists δ such that for any other solution x(t) such that x(0) x(0) < δ, we have x(t) x(t) < ε for all t > 0. We would like to understand the conditions needed for a function f : Rn Rn, such that a trajectory x(t) satisfying the following autonomous differential equation is (Lyapunov) stable: dt = f x(t) If we assume for a moment that f(x) = L(x) where L is a convex function, then for any two trajectories x1(t) & x2(t) of the dynamical system above, we can see that the function F defined by 2 x1(t) x2(t) 2 satisfies F(t) 0, where we use the notation F(t) d F/dt. Indeed, since 2L is positive semi-definite and by defining h(t) = x2(t) x1(t), we get = x2(t) x1(t) x2(t) x1(t) = x2(t) x1(t) f(x2(t)) f(x1(t)) 0 Df x1(t) + sh(t) ds h(t) (15) 0 h(t) 2L x1(t) + sh(t) h(t) | {z } 0 where (15) follows from the fundamental theorem of calculus. Now since F(t) 0, the distance between two trajectories decreases as time t increases, hence we have Lyapunov stability for all trajectories or solutions of the above autonomous differential equation. Now note that we didn t actually need f to be of the specific form f(x) = L(x), but having the Jacobian matrix Df negative semi-definite everywhere would have been sufficient. Hence we get the following statement Proposition 7. 
If a continuously differentiable function f : Rn Rn is such that the Jacobian matrix Df is negative semi-definite everywhere, then any solution x(t) of the autonomous differential equation dx dt = f x(t) with initial condition x(0) = x0 is (Lyapunov) stable. We can extend the above result to the non-autonomous case. For this purpose, let us define dt = f t, x(t) (16) which induces the parametrized function ft : x 7 f(t, x). With this notation we can state the following result. Proposition 8. If the Jacobian matrix Dft is negative semi-definite everywhere for all t 0, then any trajectory x(t), solution of (16) with initial condition x(0) = x0, is (Lyapunov) stable. Published as a conference paper at ICLR 2022 Proof. Let us start with two trajectories x1(t) and x2(t), which are solutions to the above (16). By defining h(t) = x2(t) x1(t) and by applying the fundamental theorem of calculus, we get f t, x1(t) f t, x2(t) = ft x1(t) ft x2(t) 0 Dft x1(t) + sh(t) ds h(t) Following the same idea as already previously outlined, let us define F(t) = 1 2 x1(t) x2(t) 2 and show that F(t) 0: F(t) = (x2(t) x1(t)) f(t, x2(t)) f(t, x1(t)) 0 Dft x1(t) + sh(t) ds h(t) 0 h(t)T Dft x1(t) + sh(t) h(t) | {z } 0 Hence the distance between two trajectories decreases as time t increases, and we have Lyapunov stability for all solutions of the non-autonomous differential equation above. D.1 STABILITY OF THE FORWARD SENSITIVITY EQUATIONS In this subsection let us start by considering the general case of solving the initial value problem to the following autonomous system of differential equations: dt = g(W (t), θ) W (0) = W0 dt = g W (t), θ W (t) S(t) + g W (t), θ θ S(0) = W (0) Proposition 9. There exists a solution of the autonomous system of differential equations in (17), which is also unique. Proof. By Theorem 7.1.1 in Wiggins (2003), we know that there exists a solution of the system of differential equations in Eq. 17. Then, by Theorem 7.4.1 in Wiggins (2003), we can conclude that this solution is unique upon the choice of initial conditions W (0) = W0 and S(0) = W (0)/ θ. Alternatively, one could also apply Theorems 7.1.1 & 7.4.1 in Wiggins (2003) to conclude that the initial value problem d W /dt = g(W (t), θ) with initial condition W (0) = W0 has a unique solution. This existence and uniqueness then gives rise to existence and uniqueness of a solution of the entire system of differential equations via S(t) = d W /dt by applying Clairnaut s theorem. In our case, we have g(W (t), θ) = L W (T), θ , where L is a convex function in W . Hence the autonomous system of differential equations can be rewritten as dt = L W (T), θ W (0) = W0 dt = 2L W (T), θ S(t) + g W (t), θ θ S(0) = W (0) Proposition 10. Let W (t) be the solution of dt = L W (T), θ W (0) = W0 Then the trajectory W (t) is (Lyapunov) stable. Proof. Note that Dg W (t) = 2L here. Since L is convex, 2L is positive semi-definite, and hence Dg(W (t)) is negative semi-definite. Hence the result directly follows from Proposition 7. Published as a conference paper at ICLR 2022 Table 4: The effect of the numerical solver on the performance of COMLN. The average accuracy (%) on 1, 000 held-out meta-test tasks is reported with 95% confidence interval. Note that for a given setting, the same 1, 000 tasks are used for evaluation, making both methods directly comparable. RK: 4th-order Runge-Kutta with Dormand Prince adaptive step size. Euler: explicit Euler scheme. 
Method            miniImageNet 5-way 1-shot    miniImageNet 5-way 5-shot
COMLN (RK)        53.01 ± 0.62                 70.54 ± 0.54
COMLN (Euler)     53.00 ± 0.83                 70.50 ± 0.72

In addition, we can now also conclude that the solution of (18) is (Lyapunov) stable.

Corollary 1. The solution (W(t), S(t)) of (18) is (Lyapunov) stable.

This result holds in general for the application of the forward sensitivity equations to a gradient vector field derived from a convex loss function. We can also show that when the gradient dW*/dθ exists and is finite (recall that W* is the minimizer of the loss function, and the equilibrium of the gradient vector field), the solution of the system is also bounded, which guarantees that our solution will not diverge (unlike the adjoint method applied to a gradient vector field).

Proposition 11. Assuming that dW*/dθ is finite, the solution (W(t), S(t)) of (18) is bounded.

Proof. Let us define

    V(t) = (1/2) ‖W(t)‖² + (1/2) ‖S(t)‖².

Since W(t) and S(t) are continuous functions of t, with W(t) → W* and S(t) → dW*/dθ as t → ∞, we have

    V(t) → (1/2) ‖W*‖² + (1/2) ‖dW*/dθ‖² < ∞   as t → ∞,

where the last point follows from our assumption that ‖dW*/dθ‖ < ∞. Since V is continuous, as a composition of continuous functions, and converges to a finite limit, we conclude that V is bounded, and hence W(t) and S(t) are bounded as well.

E  EXPERIMENTAL DETAILS

E.1  CHOICE OF THE NUMERICAL SOLVER

In Section 5.2, we showed that the numerical solver had a significant impact on the time to compute the meta-gradients (but not on the memory requirements). In Table 4, we show that using explicit Euler instead of an adaptive scheme like Runge-Kutta leads to very similar results. In all our experiments we therefore chose Runge-Kutta for its faster execution; conveniently, it is also the default numerical solver available in JAX (Bradbury et al., 2018).

E.2  DETAILS FOR MEASURING THE EFFICIENCY

In Figures 3 & 4, we show the efficiency, both in terms of memory and runtime, of COMLN compared to other meta-learning algorithms. The computational efficiency (runtime) was measured as the average time taken to compute the meta-gradients on a single 5-shot 5-way task from the miniImageNet dataset, over 100 runs.

[Figure 4: two log-log plots of memory usage (in Gb) and runtime (in ms) as functions of the number of gradient steps / the horizon T, comparing MAML, iMAML, ANIL and COMLN (Euler).]
Figure 4: Empirical efficiency of COMLN on a single 5-shot 5-way task, with a ResNet-12 backbone; this figure is similar to Figure 3. (Left) Memory usage for computing the meta-gradients as a function of the number of inner gradient steps. The extrapolated dashed lines correspond to the method reaching the memory capacity of a Tesla V100 GPU with 32Gb of memory. (Right) Average time taken (in ms) to compute the exact meta-gradients. The extrapolated dashed lines correspond to the method taking over 3 seconds.

[Figure 5: (left) evolution of T during meta-training, for miniImageNet and tieredImageNet, 5-way 1-shot and 5-way 5-shot; (right) effect of learning T.]
Figure 5: (Left) Evolution of the meta-parameter T controlling the amount of adaptation necessary for all tasks during meta-training. Here, the backbone is a ResNet-12. We normalized the duration of meta-training to [0, 1] to account for early stopping; typically, the model for miniImageNet requires an order of magnitude fewer iterations. (Right) Comparison of COMLN (where T is learned, in red) to meta-learning with a fixed length of adaptation T (in blue), on a 5-shot 5-way classification problem on the miniImageNet dataset (with a Conv-4 backbone).

In order to ensure a fair comparison between methods that rely on a discrete number of gradient steps (MAML, ANIL, and iMAML) and COMLN, which relies on a continuous integration time T, we added a conversion between the number of gradient steps and T. This corresponds to taking a learning rate of α = 0.01 in (2) (which is standard practice for MAML and ANIL on miniImageNet), meaning that the number of gradient steps is 100× larger than T. This can be formally justified by considering an explicit Euler scheme for COMLN (see Section 3.3) with a constant step size α = 0.01. The memory requirements are independent of the choice of the numerical solver. For the computation time, this correspondence between T and the number of gradient steps is no longer exact for COMLN (RK); the comparison with COMLN (Euler) is still valid, though.
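To make this conversion concrete: k gradient steps with learning rate α correspond exactly to an explicit Euler discretization of the adaptation ODE with step size α, i.e. to integrating up to T = kα, and an adaptive solver reaches a nearby solution. The following minimal sketch (not from the paper) compares k = 100 gradient steps with odeint up to T = kα on a small random softmax-regression problem:

import jax
import jax.numpy as jnp
from jax import nn, random
from jax.experimental import ode

M, N, d, alpha, k = 10, 5, 8, 0.01, 100
k1, k2, k3 = random.split(random.PRNGKey(0), 3)
phi = random.normal(k1, (M, d))
y = nn.one_hot(random.randint(k2, (M,), 0, N), N)
W0 = 0.01 * random.normal(k3, (N, d))

def loss(W):
    return -jnp.mean(jnp.sum(y * nn.log_softmax(phi @ W.T, axis=1), axis=1))

# k discrete gradient steps with learning rate alpha (as in MAML / ANIL)
W_gd = W0
for _ in range(k):
    W_gd = W_gd - alpha * jax.grad(loss)(W_gd)

# Continuous-time adaptation up to T = k * alpha (as in COMLN)
W_ode = ode.odeint(lambda W, t: -jax.grad(loss)(W), W0, jnp.array([0., k * alpha]))[-1]

# The two adapted solutions are close, which justifies the conversion used in the figures
print(jnp.max(jnp.abs(W_gd - W_ode)))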
E.3  ANALYSIS OF THE LEARNED HORIZON

Since one of the advantages of COMLN over other gradient-based methods is its capacity to learn the amount of adaptation through the time horizon T, Figure 5 (left) shows the evolution of this meta-parameter during meta-training. We can observe that for the more complex dataset, tieredImageNet, COMLN appropriately learns to use a longer sequence of adaptation. Similarly, within each dataset, it also learns to use shorter sequences of adaptation on 1-shot problems, possibly to allow for better generalization and to reduce overfitting.

Besides adapting the amount of adaptation to the problem at hand, learning T also has the advantage of saving computation while reaching high levels of performance. If we were to fix T ahead of meta-training (as is typically the case in gradient-based meta-learning, where the number of gradient steps used for adaptation is a hyperparameter) to a large value in order to reach high accuracy, as shown in Figure 5 (right), this would induce larger computational costs early in meta-training compared to COMLN, which achieves equal performance while tuning the value of T; in COMLN, the value of T is relatively small at the beginning of meta-training.

E.4  ADDITIONAL EXPERIMENT: PREPROCESSED miniIMAGENET DATASET

In addition to our experiments on the miniImageNet and tieredImageNet datasets in Section 5, we want to evaluate the performance of continuous-time adaptation in isolation from learning the embedding network.
We evaluate this using a preprocessed version of miniImageNet introduced by Rusu et al. (2018), where the embeddings were trained with a Wide Residual Network (WRN) via supervised classification on the meta-training set. For our purposes, we consider these embeddings φ ∈ R^640 as fixed, and only meta-learn the initial conditions W_0 as well as T. The classification accuracies for COMLN and other baselines using pretrained embeddings are shown in Table 5.

Table 5: miniImageNet results using LEO embeddings and a single linear classifier layer. The average accuracy (%) on 1,000 held-out meta-test tasks is reported with 95% confidence interval. * Results reported in Rusu et al. (2018). ** Note that LEO uses more than a single linear classifier layer, but we add the numbers for completeness.

Model                                miniImageNet 5-way 1-shot    miniImageNet 5-way 5-shot
MAML (Finn et al., 2017)             50.35 ± 0.63                 65.28 ± 0.54
Meta-SGD* (Li et al., 2017)          54.24 ± 0.03                 70.86 ± 0.04
Meta-SGD (Li et al., 2017)           50.57 ± 0.64                 69.09 ± 0.53
iMAML (Rajeswaran et al., 2019)      50.26 ± 0.61                 69.52 ± 0.51
R2D2 (Bertinetto et al., 2018)       50.33 ± 0.62                 70.38 ± 0.52
LRD2 (Bertinetto et al., 2018)       50.41 ± 0.62                 70.29 ± 0.52
LEAP (Flennerhag et al., 2018)       50.95 ± 0.62                 66.72 ± 0.55
MetaOptNet (Lee et al., 2019)        40.60 ± 0.60                 50.94 ± 0.62
LEO** (Rusu et al., 2018)            61.76 ± 0.08                 77.59 ± 0.12
COMLN (Ours)                         50.39 ± 0.63                 70.06 ± 0.52

COMLN achieves performance comparable to or better than other meta-learning methods using a single linear classifier layer, on both the 5-way 1-shot and 5-way 5-shot classification tasks. The single exception is the 5-way 1-shot result of Meta-SGD reported in Rusu et al. (2018), which exceeds the performance of COMLN; however, our own implementation of Meta-SGD achieved performance comparable to COMLN. This gap in performance is likely due to the additional data used in Rusu et al. (2018) during meta-training (the meta-train and meta-validation splits), as opposed to using only the meta-training set for all other baselines. These results show that, isolated from representation learning, all these meta-learning algorithms (gradient-based or not) perform similarly, and COMLN is no exception. One notable exception, though, is MetaOptNet (Lee et al., 2019), whose performance is not as high as the other baselines once the backbone network is no longer learned, despite often being the best-performing model in Table 2. Our hypothesis is that this discrepancy is due to the accuracy of the QP solver used in MetaOptNet, since learning individual SVMs on 1,000 held-out meta-test tasks leads to performance matching all other methods (about 50% for 5-way 1-shot and about 70% for 5-way 5-shot).

E.5  EXPERIMENTAL DETAILS

For all methods and all datasets, we used SGD with momentum 0.9 and Nesterov acceleration, with a decreasing learning rate starting at 0.1 and following the schedule provided by Lee et al. (2019). For meta-training, we followed the standard procedure in gradient-based meta-learning and meta-trained with a fixed number of shots: for example, on miniImageNet 5-shot 5-way, we only used tasks with k = 5 training examples for each of the N = 5 classes. This contrasts with Lee et al. (2019), which uses a larger number of shots during meta-training than the one used for evaluation (e.g. meta-training with k = 15 and evaluating with k = 1). This may explain the gap in performance between COMLN and MetaOptNet, especially in 1-shot settings. We opted not to follow this choice made by Lee et al. (2019), to ensure a fair comparison with the other gradient-based methods, which all used the procedure described above.

Conv-4 backbone. We used a standard convolutional neural network with 4 convolutional blocks (Finn et al., 2017). Each block consists of a convolutional layer with a 3 × 3 kernel and 64 channels, followed by a batch normalization layer and a max-pooling layer with window size and stride 2 × 2. The activation function is a ReLU.
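For reference, a minimal sketch of such a Conv-4 backbone in Haiku (Hennigan et al., 2020); this is not taken from the released code, the layer hyperparameters follow the description above, and everything else (e.g. the batch-norm decay rate, the conv–BN–ReLU–pool ordering) is an arbitrary choice:

import haiku as hk
import jax
import jax.numpy as jnp

def conv4_backbone(x, is_training):
    # 4 identical blocks: 3x3 conv (64 channels) -> batch norm -> ReLU -> 2x2 max-pool
    for _ in range(4):
        x = hk.Conv2D(output_channels=64, kernel_shape=3, padding='SAME')(x)
        x = hk.BatchNorm(create_scale=True, create_offset=True, decay_rate=0.9)(x, is_training)
        x = jax.nn.relu(x)
        x = hk.max_pool(x, window_shape=(1, 2, 2, 1), strides=(1, 2, 2, 1), padding='VALID')
    return hk.Flatten()(x)   # embedding vectors phi

f_phi = hk.transform_with_state(conv4_backbone)
params, state = f_phi.init(jax.random.PRNGKey(0), jnp.zeros((1, 84, 84, 3)), True)
phi, state = f_phi.apply(params, state, None, jnp.zeros((5, 84, 84, 3)), True)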
ResNet-12 backbone. We largely followed the architecture from Lee et al. (2019), which consists of a 12-layer residual network. The network is composed of 4 blocks with residual connections, each containing 3 convolutional layers with a 3 × 3 kernel. The convolutional layers in the four residual blocks have k = 64, 160, 320 and 640 channels respectively. The non-linearity is a LeakyReLU(0.1), and a max-pooling layer with window size and stride 2 × 2 is applied at the end of each block. No global pooling is performed at the end of the embedding network, meaning that the embedding dimension is d = 16,000. The only notable difference with the architecture used by Lee et al. (2019) is the absence of DropBlock (Ghiasi et al., 2018) regularization.
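The outer-loop optimizer described above can also be expressed with optax. The following is a minimal sketch, not from the released code, where the step-decay boundaries and factors are illustrative placeholders standing in for the exact schedule of Lee et al. (2019):

import optax

# Decreasing learning rate starting at 0.1; boundaries and decay factors are placeholders.
schedule = optax.piecewise_constant_schedule(
    init_value=0.1,
    boundaries_and_scales={20_000: 0.1, 40_000: 0.1},
)
optimizer = optax.sgd(learning_rate=schedule, momentum=0.9, nesterov=True)

# Typical usage on the meta-parameters (W_0, Phi, T):
# opt_state = optimizer.init(meta_params)
# updates, opt_state = optimizer.update(meta_grads, opt_state)
# meta_params = optax.apply_updates(meta_params, updates)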