# Scalable Bayesian Meta-Learning through Generalized Implicit Gradients

Yilang Zhang, Bingcong Li, Shijian Gao, Georgios B. Giannakis
Dept. of ECE, University of Minnesota, Minneapolis, MN, USA
{zhan7453,lixx5599,gao00379,georgios}@umn.edu

Meta-learning is uniquely effective and swift at tackling emerging tasks with limited data. Its broad applicability is revealed by viewing it as a bi-level optimization problem. The resultant algorithmic viewpoint, however, faces scalability issues when the inner-level optimization relies on gradient-based iterations. Implicit differentiation has been considered to alleviate this challenge, but it is restricted to an isotropic Gaussian prior and only favors deterministic meta-learning approaches. This work markedly mitigates the scalability bottleneck by cross-fertilizing the benefits of implicit differentiation to probabilistic Bayesian meta-learning. The novel implicit Bayesian meta-learning (iBaML) method not only broadens the scope of learnable priors, but also quantifies the associated uncertainty. Furthermore, the ultimate complexity is well controlled regardless of the inner-level optimization trajectory. Analytical error bounds are established to demonstrate the precision and efficiency of the generalized implicit gradient over the explicit one. Extensive numerical tests are also carried out to empirically validate the performance of the proposed method.

## Introduction

Over the past decade, deep learning (DL) has garnered huge attention from theory, algorithms, and application viewpoints. The success of DL is mainly attributed to massive datasets, with which large-scale and highly expressive models can be trained. On the other hand, the stimulus of DL, namely data, can be scarce.
Nevertheless, in several real-world tasks, such as object recognition and concept comprehension, humans can perform exceptionally well even with very few data samples. This prompts the natural question: How can we endow DL with humans' unique intelligence? By doing so, DL's reliance on data can be alleviated and the subsequent model training can be streamlined. Several attempts have emerged in such stimulus-lacking domains, including speech recognition (Miao, Metze, and Rawat 2013), medical imaging (Yang et al. 2016), and robot manipulation (Hansen and Wang 2021).

A systematic framework has been explored in recent years to address the aforementioned question, under the terms learning-to-learn or meta-learning (Thrun 1998). In brief, meta-learning extracts task-invariant prior information from a given family of correlated (and thus informative) tasks. Domain-generic knowledge can therein be acquired as an inductive bias and transferred to new tasks outside the given set (Thrun and Pratt 2012; Grant et al. 2018), making it feasible to learn unknown models/tasks even with minimal training samples. One representative example is an edge extractor, which can act as a common prior owing to its presence across natural images. Using it can thus prune degrees of freedom from a number of image classification models. Prior extraction in conventional meta-learning is more of a hand-crafted art; see e.g., (Schmidhuber 1993; Bengio, Bengio, and Cloutier 1995; Schmidhuber, Zhao, and Wiering 1996). This rather cumbersome art has been gradually replaced by data-driven approaches. For parametric models of the task-learning process (Santoro et al. 2016; Mishra et al. 2018), the task-invariant sub-model can be shared across different tasks, with prior information embedded in the model weights.

Copyright © 2023, Association for the Advancement of Artificial Intelligence (www.aaai.org). All rights reserved.
One typical model is that of recurrent neural networks (RNNs), where task-learning is captured by recurrent cells. However, the resultant black-box learning setup faces interpretability challenges.

As an alternative to model-committed approaches, model-agnostic meta-learning (MAML) transforms task-learning into optimizing the task-specific model parameters, while the prior amounts to the initial parameters of each task-level optimization; these initial parameters are shared across tasks and can be learned through differentiable meta-level optimization (Finn, Abbeel, and Levine 2017). Building upon MAML, optimization-based meta-learning has been advocated to ameliorate its performance; see e.g. (Li et al. 2017; Bertinetto et al. 2019; Flennerhag et al. 2020; Abbas et al. 2022). In addition, performance analyses have been reported to better understand the behavior of these optimization-based algorithms (Franceschi et al. 2018; Fallah, Mokhtari, and Ozdaglar 2020; Wang, Sun, and Li 2020; Chen and Chen 2022).

Interestingly, the learned initialization can be approximately viewed as the mean of an implicit Gaussian prior over the task-specific parameters (Grant et al. 2018). Inspired by this interpretation, Bayesian methods have been advocated for meta-learning to further allow for uncertainty quantification in the model parameters. Different from its deterministic counterpart, Bayesian meta-learning seeks a prior distribution over the model parameters that best explains the data. Exact Bayesian inference, however, is barely tractable as the posterior is often non-Gaussian, which prompts pursuing approximate inference methods; see e.g., (Yoon et al. 2018; Grant et al. 2018; Finn, Xu, and Levine 2018; Ravi and Beatson 2019).

The Thirty-Seventh AAAI Conference on Artificial Intelligence (AAAI-23)

MAML and its variants have appealing empirical performance, but optimizing the meta-learning loss with backpropagation is challenging due to the high-order derivatives involved.
This incurs complexity that grows linearly with the number of task-level optimization steps, which renders the corresponding algorithms barely scalable. For this reason, scalability of meta-learning algorithms is of paramount importance. One remedy is to simply ignore the high-order derivatives and rely on first-order updates only (Finn, Abbeel, and Levine 2017; Nichol, Achiam, and Schulman 2018). Alternatively, the so-termed implicit MAML (iMAML) relies on implicit differentiation to eliminate explicit backpropagation. However, the proximal regularization term in iMAML is confined to a simple isotropic Gaussian prior, which limits model expressiveness (Rajeswaran et al. 2019).

In this paper, we develop a novel implicit Bayesian meta-learning (iBaML) approach that offers the desirable scalability, expressiveness, and performance quantification, and thus broadens the scope and appeal of meta-learning to real application domains. The contribution is threefold. i) iBaML enjoys complexity that is invariant to the number $K$ of gradient steps in task-level optimization. This fundamentally breaks the complexity-accuracy tradeoff, and makes Bayesian meta-learning affordable with more sophisticated task-level optimization algorithms. ii) Rather than an isotropic Gaussian distribution, iBaML allows for learning more expressive priors. As a Bayesian approach, iBaML can quantify the uncertainty of the estimated model parameters. iii) Through both analytical and numerical performance studies, iBaML showcases its complexity and accuracy merits over state-of-the-art Bayesian meta-learning methods. In the large-$K$ regime, the time and space complexity can be reduced by up to an order of magnitude.

## Preliminaries and Problem Statement

This section outlines the meta-learning formulation in the context of supervised few-shot learning, and touches upon the associated scalability issues.
### Meta-Learning Setups

Suppose we are given datasets $\mathcal{D}_t := \{(x_t^n, y_t^n)\}_{n=1}^{N_t}$, each of cardinality $|\mathcal{D}_t| = N_t$ and corresponding to a task indexed by $t \in \{1, \ldots, T\}$, where $x_t^n$ is an input vector and $y_t^n \in \mathbb{R}$ denotes its label. Set $\mathcal{D}_t$ is disjointly partitioned into a training set $\mathcal{D}_t^{\rm tr}$ and a validation set $\mathcal{D}_t^{\rm val}$, with $|\mathcal{D}_t^{\rm tr}| = N_t^{\rm tr}$ and $|\mathcal{D}_t^{\rm val}| = N_t^{\rm val}$ for all $t$. Typically, $N_t$ is limited, and often much smaller than what supervised DL tasks require. However, it is worth stressing that the number of tasks $T$ can be considerably large. Thus, $\sum_{t=1}^{T} N_t$ can be sufficiently large for learning a prior parameter vector shared by all tasks, e.g., using deep neural networks.

A key attribute of meta-learning is to estimate such task-invariant prior information, parameterized by the meta-parameter $\theta$, based on training data across tasks. Subsequently, $\theta$ and $\mathcal{D}_t^{\rm tr}$ are used to perform task- or inner-level optimization to obtain the task-specific parameter $\theta_t \in \mathbb{R}^d$. The estimate of $\theta_t$ is then evaluated on $\mathcal{D}_t^{\rm val}$ (and potentially also on $\mathcal{D}_t^{\rm tr}$) to produce a validation loss. Upon minimizing this loss summed over all the training tasks w.r.t. $\theta$, this meta- or outer-level optimization yields the task-invariant estimate of $\theta$. Note that the dimension of $\theta_t$ is not necessarily identical to that of $\theta$; see e.g. (Li et al. 2017; Bertinetto et al. 2019; Lee et al. 2019). As we will see shortly, this nested structure can be formulated as a bi-level optimization problem. This formulation readily suggests applying meta-learning to settings such as hyperparameter tuning, which relies on a similar bi-level optimization (Franceschi et al. 2018). This bi-level optimization is outlined next for both deterministic and probabilistic Bayesian meta-learning variants.

**Optimization-based meta-learning.** For each task $t$, let $\check{\mathcal{L}}_t^{\rm tr}(\theta_t)$ and $\check{\mathcal{L}}_t^{\rm val}(\theta_t)$ denote the losses over $\mathcal{D}_t^{\rm tr}$ and $\mathcal{D}_t^{\rm val}$, respectively.
Further, let $\hat\theta$ be the meta-parameter estimate, and $\mathcal{R}(\hat\theta, \theta_t)$ the regularizer of the learning cost per task $t$. Optimization-based meta-learning boils down to

$$
\begin{aligned}
\hat\theta = \arg\min_{\theta}\; & \sum_{t=1}^{T} \check{\mathcal{L}}_t^{\rm val}\big(\hat\theta_t(\theta)\big) \\
\text{s.to}\; & \hat\theta_t(\theta) = \arg\min_{\theta_t}\; \check{\mathcal{L}}_t^{\rm tr}(\theta_t) + \mathcal{R}(\theta, \theta_t), \quad t = 1, \ldots, T.
\end{aligned}
\tag{1}
$$

The regularizer $\mathcal{R}$ can be either implicit (as in MAML) or explicit (as in iMAML). Further, the task-invariant meta-parameter is calibrated by $\mathcal{R}$ in order to cope with overfitting. Indeed, an over-parameterized neural network could easily overfit $\mathcal{D}_t^{\rm tr}$, producing a tiny $\check{\mathcal{L}}_t^{\rm tr}$ yet a large $\check{\mathcal{L}}_t^{\rm val}$.

As reaching global minima can be infeasible, especially with highly nonconvex neural networks, a practical alternative is an estimator $\hat\theta_t$ produced by a function $\hat{\mathcal{A}}_t(\theta)$ representing an optimization algorithm, such as gradient descent (GD), with a prefixed number $K$ of iterations. Thus, a tractable version of (1) is

$$
\begin{aligned}
\hat\theta = \arg\min_{\theta}\; & \sum_{t=1}^{T} \check{\mathcal{L}}_t^{\rm val}\big(\hat\theta_t(\theta)\big) \\
\text{s.to}\; & \hat\theta_t(\theta) = \hat{\mathcal{A}}_t(\theta), \quad t = 1, \ldots, T.
\end{aligned}
\tag{2}
$$

As an example, $\hat{\mathcal{A}}_t$ can be one-step gradient descent initialized at $\theta$ with implicit priors ($\mathcal{R}(\hat\theta, \theta_t) = 0$) (Finn, Abbeel, and Levine 2017; Grant et al. 2018), which yields the per-task parameter estimate

$$\hat\theta_t = \hat{\mathcal{A}}_t(\theta) = \theta - \alpha \nabla \check{\mathcal{L}}_t^{\rm tr}(\theta), \quad t = 1, \ldots, T \tag{3}$$

where $\alpha$ is the learning rate of GD, and we use the compact gradient notation $\nabla \check{\mathcal{L}}_t^{\rm tr}(\theta) := \nabla_{\theta_t} \check{\mathcal{L}}_t^{\rm tr}(\theta_t)\big|_{\theta_t = \theta}$ hereafter. For later use, we also define $\mathcal{A}_t^*$ as the (unknown) oracle function that generates the global optimum $\theta_t^*$.

**Bayesian meta-learning.** The probabilistic approach to meta-learning takes a Bayesian view of the (now random) vector $\theta_t$ per task $t$. The task-invariant vector $\theta$ is still deterministic, and parameterizes the prior probability density function (pdf) $p(\theta_t; \theta)$. Task-specific learning seeks the posterior pdf $p(\theta_t \mid y_t^{\rm tr}; X_t^{\rm tr}, \theta)$, where $X_t^{\rm tr} := [x_t^1, \ldots, x_t^{N_t^{\rm tr}}]$ and $y_t^{\rm tr} := [y_t^1, \ldots, y_t^{N_t^{\rm tr}}]^\top$ ($^\top$ denotes transposition), while the objective per task $t$ is to maximize the conditional likelihood $p(y_t^{\rm val} \mid y_t^{\rm tr}; X_t^{\rm val}, X_t^{\rm tr}, \theta) = \int p(y_t^{\rm val} \mid \theta_t; X_t^{\rm val})\, p(\theta_t \mid y_t^{\rm tr}; X_t^{\rm tr}, \theta)\, d\theta_t$. Along similar lines followed by its deterministic optimization-based counterpart, Bayesian meta-learning amounts to

$$
\begin{aligned}
\hat\theta = \arg\max_{\theta}\; & \prod_{t=1}^{T} \int p(y_t^{\rm val} \mid \theta_t; X_t^{\rm val})\, p(\theta_t \mid y_t^{\rm tr}; X_t^{\rm tr}, \theta)\, d\theta_t \\
\text{s.to}\; & p(\theta_t \mid y_t^{\rm tr}; X_t^{\rm tr}, \theta) \propto p(y_t^{\rm tr} \mid \theta_t; X_t^{\rm tr})\, p(\theta_t; \theta), \quad \forall t
\end{aligned}
\tag{4}
$$

where we used that datasets are independent across tasks in the first line, and Bayes' rule in the second line. Through the posterior $p(\theta_t \mid y_t^{\rm tr}; X_t^{\rm tr}, \theta)$, Bayesian meta-learning quantifies the uncertainty of the task-specific parameter estimate $\hat\theta_t$, thus assessing model robustness. When the posterior of $\theta_t$ is replaced by its maximum a posteriori point estimator $\hat\theta_t^{\rm map}$, meaning $p(\theta_t \mid y_t^{\rm tr}; X_t^{\rm tr}, \theta) = \delta_D[\theta_t - \hat\theta_t^{\rm map}]$ with $\delta_D$ denoting Dirac's delta, it turns out that (4) reduces to (1) upon taking negative logarithms, with the losses in (1) being negative log-likelihoods and $\mathcal{R}$ the negative log-prior.

Unfortunately, the posterior in (4) can be intractable with nonlinear models due to the difficulty of finding analytical solutions. To overcome this, we can resort to the widely adopted approximate variational inference (VI); see e.g. (Finn, Xu, and Levine 2018; Ravi and Beatson 2019; Nguyen, Do, and Carneiro 2020). VI searches over a family of tractable distributions for a surrogate that best matches the true posterior $p(\theta_t \mid y_t^{\rm tr}; X_t^{\rm tr}, \theta)$. This can be accomplished by minimizing the KL-divergence between the surrogate pdf $q(\theta_t; v_t)$ and the true one, where $v_t$ determines the variational distribution. Considering that the dimension of $\theta_t$ can be fairly high, both the prior and the surrogate posterior are often set to be Gaussian ($\mathcal{N}$) with diagonal covariance matrices. Specifically, we select the prior as $p(\theta_t; \theta) = \mathcal{N}(m, D)$ with covariance $D = \mathrm{diag}(d)$ and $\theta := [m^\top, d^\top]^\top \in \mathbb{R}^d \times \mathbb{R}^d_{>0}$, and the surrogate posterior as $q(\theta_t; v_t) = \mathcal{N}(m_t, D_t)$ with $D_t = \mathrm{diag}(d_t)$ and $v_t := [m_t^\top, d_t^\top]^\top \in \mathbb{R}^d \times \mathbb{R}^d_{>0}$.
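For the diagonal Gaussian prior and surrogate above, the KL-divergence to the posterior that VI minimizes splits, up to an additive constant independent of $v_t$, into the expected negative log-likelihood on $\mathcal{D}_t^{\rm tr}$ plus $\mathrm{KL}(q(\theta_t; v_t) \,\|\, p(\theta_t; \theta))$, and the latter term has a closed form. The following is a minimal, dependency-free Python sketch of that prior-KL term and its gradients w.r.t. $m_t$ and $d_t$; the function names are illustrative and not taken from the authors' code, and vectors are plain lists rather than tensors.

```python
import math


def kl_diag_gaussians(m_t, d_t, m, d):
    """KL( N(m_t, diag(d_t)) || N(m, diag(d)) ) for diagonal Gaussians.

    All arguments are equal-length lists; d_t and d hold (positive) variances.
    """
    return sum(
        0.5 * (dt_i / d_i + (mt_i - m_i) ** 2 / d_i - 1.0 - math.log(dt_i / d_i))
        for mt_i, dt_i, m_i, d_i in zip(m_t, d_t, m, d)
    )


def grad_kl_diag_gaussians(m_t, d_t, m, d):
    """Gradients of the KL above w.r.t. the variational mean m_t and variances d_t."""
    grad_m = [(mt_i - m_i) / d_i for mt_i, m_i, d_i in zip(m_t, m, d)]
    grad_d = [0.5 * (1.0 / d_i - 1.0 / dt_i) for dt_i, d_i in zip(d_t, d)]
    return grad_m, grad_d
```

Setting the $d_t$-gradient of the full task-level objective to zero gives $\frac{1}{2}(d^{-1} - d_t^{-1}) + \nabla_{d_t}\mathcal{L}_t^{\rm tr} = 0$ (entry-wise reciprocals), i.e., $D_t^{-1} = D^{-1} + 2\,\mathrm{diag}(\nabla_{d_t}\mathcal{L}_t^{\rm tr})$. Since the Hessian of the prior-KL term w.r.t. $d_t$ is $\frac{1}{2}D_t^{-2}$, this is one way to see where the lower-right block $\frac{1}{2}\big(D^{-1} + 2\,\mathrm{diag}(\nabla_{d_t}\mathcal{L}_t^{\rm tr})\big)^2$ of $H_t$ in Lemma 1 comes from.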
To ensure tractable numerical integration over $q(\theta_t; v_t)$, the meta-learning loss is often relaxed to an upper bound of $-\sum_{t=1}^{T} \log p(y_t^{\rm val} \mid y_t^{\rm tr}; X_t^{\rm val}, X_t^{\rm tr}, \theta)$. Common choices include applying Jensen's inequality (Nguyen, Do, and Carneiro 2020) or an extra VI (Finn, Xu, and Levine 2018; Ravi and Beatson 2019) on (4). For notational convenience, here we will denote this upper bound by $\mathcal{L}_t^{\rm val}(v_t, \theta)$. With VI and a relaxed (upper bound) objective, (4) becomes

$$
\begin{aligned}
\hat\theta = \arg\min_{\theta}\; & \sum_{t=1}^{T} \mathcal{L}_t^{\rm val}\big(v_t^*(\theta), \theta\big) \\
\text{s.to}\; & v_t^*(\theta) = \arg\min_{v_t}\; \mathrm{KL}\big(q(\theta_t; v_t) \,\|\, p(\theta_t \mid y_t^{\rm tr}; X_t^{\rm tr}, \theta)\big), \quad \forall t
\end{aligned}
\tag{5}
$$

where $\mathcal{L}_t^{\rm val}$ depends on $\theta$ in two ways: i) via the intermediate variable $v_t^*$; and ii) directly, through its second argument. Note that (5) is general enough to cover the case where $\mathcal{L}_t^{\rm val}$ is constructed using both $\mathcal{D}_t^{\rm val}$ and $\mathcal{D}_t^{\rm tr}$; see e.g., (Ravi and Beatson 2019). Similar to optimization-based meta-learning, the difficulty in reaching global optima prompts one to substitute $v_t^*$ with a sub-optimum $\hat v_t$ obtained through an algorithm $\hat{\mathcal{A}}_t(\theta)$; i.e.,

$$
\begin{aligned}
\hat\theta = \arg\min_{\theta}\; & \sum_{t=1}^{T} \mathcal{L}_t^{\rm val}\big(\hat v_t(\theta), \theta\big) \\
\text{s.to}\; & \hat v_t(\theta) = \hat{\mathcal{A}}_t(\theta), \quad t = 1, \ldots, T.
\end{aligned}
\tag{6}
$$

### Scalability Issues in Meta-Learning

Delay and memory resources required for solving (2) and (6) are arguably the major challenges that meta-learning faces. Here we will elaborate on these challenges in the optimization-based setup, but the same argument carries over to Bayesian meta-learning too.

Consider minimizing the meta-learning loss in (2) using a gradient-based iteration such as Adam (Kingma and Ba 2015). In the $(r+1)$-st iteration, gradients must be computed for a batch $\mathcal{B}_r \subset \{1, \ldots, T\}$ of tasks. Letting $\hat\theta_t^r := \hat{\mathcal{A}}_t(\hat\theta^r)$, where $\hat\theta^r$ denotes the meta-parameter in the $r$-th iteration, the chain rule yields the so-termed meta-gradient

$$\nabla_\theta \check{\mathcal{L}}_t^{\rm val}\big(\hat\theta_t^r(\theta)\big)\Big|_{\theta = \hat\theta^r} = \nabla \hat{\mathcal{A}}_t(\hat\theta^r)\, \nabla \check{\mathcal{L}}_t^{\rm val}(\hat\theta_t^r), \quad \forall t \in \mathcal{B}_r \tag{7}$$

where the Jacobian $\nabla \hat{\mathcal{A}}_t(\hat\theta^r)$ contains high-order derivatives. When $\hat{\mathcal{A}}_t$ is chosen as the one-step GD (cf. (3)), this Jacobian is

$$\nabla \hat{\mathcal{A}}_t(\hat\theta^r) = I_d - \alpha \nabla^2 \check{\mathcal{L}}_t^{\rm tr}(\hat\theta^r), \quad \forall t \in \mathcal{B}_r. \tag{8}$$

Fortunately, in this case the meta-gradient can still be computed through a Hessian-vector product (HVP), which incurs spatio-temporal complexity $\mathcal{O}(d)$. In general, $\hat{\mathcal{A}}_t$ is a $K$-step GD for some $K > 1$, which gives rise to high-order derivatives $\{\nabla^k \check{\mathcal{L}}_t^{\rm tr}(\hat\theta^r)\}_{k=2}^{K+1}$ in the meta-gradient. The most efficient computation of the meta-gradient calls for recursive application of HVP $K$ times, which incurs an overall complexity of $\mathcal{O}(Kd)$ in both time and space. Empirical wisdom, however, favors a large $K$ because it leads to improved accuracy in approximating the true meta-gradient $\nabla_\theta \check{\mathcal{L}}_t^{\rm val}(\mathcal{A}_t^*(\theta))\big|_{\theta = \hat\theta^r}$. Hence, the linear increase of complexity with $K$ will impede the scaling of optimization-based meta-learning algorithms.

When computing the meta-gradient, it should be underscored that the forward implementation of the $K$-step GD function also has complexity $\mathcal{O}(Kd)$. However, the constant hidden in the $\mathcal{O}$ is much smaller than that of the HVP computation in backpropagation. Typically, the constant is 1/5 in terms of time and 1/2 in terms of space; see (Griewank 1993; Rajeswaran et al. 2019). For this reason, we will focus on more efficient means of obtaining the meta-gradient function $\nabla_\theta \mathcal{L}_t^{\rm val}(\hat{\mathcal{A}}_t(\theta), \theta)$ for Bayesian meta-learning. It is also worth stressing that our results in the next section hold for an arbitrary vector $\theta \in \mathbb{R}^d \times \mathbb{R}^d_{>0}$, not solely for the iterate $\hat\theta^r$ of the $r$-th iteration. Thus, we will use the general vector $\theta$ when introducing our approach, and take its value at the point $\theta = \hat\theta^r$ when presenting our meta-learning algorithm.

## Implicit Bayesian Meta-Learning

In this section, we first introduce the proposed implicit Bayesian meta-learning (iBaML) method, which is built on top of implicit differentiation. Then, we provide theoretical analysis to bound and compare the errors of explicit and implicit differentiation.
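Before turning to implicit differentiation, the $\mathcal{O}(Kd)$ cost of the explicit meta-gradient described under Scalability Issues in Meta-Learning can be made concrete. The minimal sketch below backpropagates through $K$ GD steps with plain Python lists. It assumes the caller supplies three hypothetical callables (a gradient oracle for $\check{\mathcal{L}}_t^{\rm tr}$, a Hessian-vector-product oracle, and the gradient of $\check{\mathcal{L}}_t^{\rm val}$), which in practice would come from automatic differentiation. The whole trajectory is stored, and the backward pass applies $I_d - \alpha\nabla^2\check{\mathcal{L}}_t^{\rm tr}$ at each stored iterate, i.e., exactly $K$ HVPs; with $K = 1$ it reduces to (8) multiplied by the validation gradient, as in (7).

```python
def explicit_meta_gradient(theta, grad_tr, hvp_tr, grad_val, alpha, num_steps):
    """Meta-gradient of L_val(A_t(theta)), where A_t is num_steps (K) GD steps on L_tr.

    grad_tr(phi)    -> gradient of the training loss at phi
    hvp_tr(phi, v)  -> Hessian of the training loss at phi, times v
    grad_val(phi)   -> gradient of the validation loss at phi
    """
    # Forward pass: K GD steps; the whole trajectory is kept, so memory is O(K d).
    trajectory = [list(theta)]
    for _ in range(num_steps):
        phi = trajectory[-1]
        trajectory.append([p - alpha * g for p, g in zip(phi, grad_tr(phi))])

    # Backward pass: apply (I - alpha * Hessian(phi_k)) for k = K-1, ..., 0,
    # i.e., one HVP per GD step, so time is also O(K d).
    v = grad_val(trajectory[-1])
    for phi in reversed(trajectory[:-1]):
        hv = hvp_tr(phi, v)
        v = [vi - alpha * hi for vi, hi in zip(v, hv)]
    return v


def matvec(A, v):
    """Dense matrix-vector product for a list-of-rows matrix A."""
    return [sum(a_ij * v_j for a_ij, v_j in zip(row, v)) for row in A]


# Usage on a quadratic training loss 0.5 phi^T A phi - b^T phi (its HVP is A v)
# and a validation loss 0.5 ||phi - c||^2.
A, b, c = [[2.0, 0.5], [0.5, 1.0]], [1.0, -1.0], [0.3, 0.2]
meta_grad = explicit_meta_gradient(
    theta=[0.5, -0.4],
    grad_tr=lambda phi: [ai - bi for ai, bi in zip(matvec(A, phi), b)],
    hvp_tr=lambda phi, v: matvec(A, v),
    grad_val=lambda phi: [p - ci for p, ci in zip(phi, c)],
    alpha=0.1,
    num_steps=5,
)
```

For a general nonquadratic loss, each HVP must be evaluated at its own stored iterate, which is why both the memory and the backward time of this explicit scheme grow linearly with $K$.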
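### Implicit Bayesian Meta-Gradients

We start by decomposing the meta-gradient in Bayesian meta-learning (6) (henceforth referred to as the Bayesian meta-gradient) using the chain rule

$$\nabla_\theta \mathcal{L}_t^{\rm val}\big(\hat v_t(\theta), \theta\big) = \nabla \hat{\mathcal{A}}_t(\theta)\, \nabla_1 \mathcal{L}_t^{\rm val}(\hat v_t, \theta) + \nabla_2 \mathcal{L}_t^{\rm val}(\hat v_t, \theta), \quad t = 1, \ldots, T \tag{9}$$

where $\nabla_1$ and $\nabla_2$ denote the partial derivatives of a function w.r.t. its first and second arguments, respectively. The computational burden in (9) comes from the high-order derivatives present in the Jacobian $\nabla \hat{\mathcal{A}}_t(\theta)$. The key idea behind implicit differentiation is to express $\nabla \hat{\mathcal{A}}_t(\theta)$ as a function of the task-level solution itself, so that it can be numerically obtained without using high-order derivatives. The following lemma formalizes how the implicit Jacobian is obtained in our setup. All proofs can be found in the Appendix.

**Lemma 1.** Consider the Bayesian meta-learning problem in (5), and let $\bar v_t := [\bar m_t^\top, \bar d_t^\top]^\top$ be a local minimum of the task-level KL-divergence generated by $\bar{\mathcal{A}}_t(\theta)$. Also, let $\mathcal{L}_t^{\rm tr}(v_t) := \mathbb{E}_{q(\theta_t; v_t)}\big[-\log p(y_t^{\rm tr} \mid \theta_t; X_t^{\rm tr})\big]$ denote the expected negative log-likelihood (nll) on $\mathcal{D}_t^{\rm tr}$. If

$$H_t(\bar v_t) := \nabla^2 \mathcal{L}_t^{\rm tr}(\bar v_t) + \begin{bmatrix} D^{-1} & 0_d \\ 0_d & \frac{1}{2}\big(D^{-1} + 2\,\mathrm{diag}(\nabla_{d_t} \mathcal{L}_t^{\rm tr}(\bar v_t))\big)^2 \end{bmatrix}$$

(with $0_d$ the $d \times d$ all-zero matrix) is invertible, then it holds for $t \in \{1, \ldots, T\}$ that

$$\nabla \bar{\mathcal{A}}_t(\theta) = \begin{bmatrix} D^{-1} & 0_d \\ -\mathrm{diag}\big(\nabla_{m_t} \mathcal{L}_t^{\rm tr}(\bar v_t)\big) D^{-1} & \frac{1}{2} D^{-2} \end{bmatrix} H_t^{-1}(\bar v_t). \tag{10}$$

Two remarks are now in order regarding the technical assumption, and connections with iMAML. For notational brevity, define the block matrix

$$G_t(\bar v_t) := \begin{bmatrix} D^{-1} & 0_d \\ -\mathrm{diag}\big(\nabla_{m_t} \mathcal{L}_t^{\rm tr}(\bar v_t)\big) D^{-1} & \frac{1}{2} D^{-2} \end{bmatrix}. \tag{11}$$

**Remark 1.** The invertibility of $H_t(\bar v_t)$ in Lemma 1 is assumed to ensure uniqueness of $\nabla \bar{\mathcal{A}}_t(\theta)$. Without this assumption, $\bar v_t$ can be a singular point, belonging to a subspace where every point is also a local minimum. The Bayesian meta-gradients (9) of the points in this subspace form the set

$$\mathcal{G}_t = \Big\{ G_t(\bar v_t)\big(H_t^\dagger(\bar v_t)\, \nabla_1 \mathcal{L}_t^{\rm val}(\bar v_t, \theta) + u\big) + \nabla_2 \mathcal{L}_t^{\rm val}(\bar v_t, \theta) \;\Big|\; u \in \mathrm{Null}\big(H_t(\bar v_t)\big) \Big\} \tag{12}$$

where $^\dagger$ represents the pseudo-inverse, and $\mathrm{Null}(\cdot)$ stands for the null space.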
Upon replacing $H_t^{-1}(\bar v_t)$ with $H_t^\dagger(\bar v_t)$, one can generalize Lemma 1 and forgo the invertibility assumption.

**Algorithm 1:** Implicit Bayesian meta-learning (iBaML)

1. **Inputs:** tasks $\{1, \ldots, T\}$ with their $\mathcal{D}_t^{\rm tr}$ and $\mathcal{D}_t^{\rm val}$, and meta-learning rate $\beta$.
2. **Initialization:** initialize $\hat\theta^0$ randomly, and set the iteration counter $r = 0$.
3. **repeat**
4. Sample a batch $\mathcal{B}_r \subset \{1, \ldots, T\}$ of tasks;
5. **for** $t \in \mathcal{B}_r$ **do**
6. Compute the task-level sub-optimum $\hat v_t^r = \hat{\mathcal{A}}_t(\hat\theta^r)$ using, e.g., $K$-step GD;
7. Approximate $\hat u_t^r \approx H_t^{-1}(\hat v_t^r)\, \nabla_1 \mathcal{L}_t^{\rm val}(\hat v_t^r, \hat\theta^r)$ with $L$-step CG;
8. Compute the meta-level gradient $\hat g_t^r = G_t(\hat v_t^r)\, \hat u_t^r + \nabla_2 \mathcal{L}_t^{\rm val}(\hat v_t^r, \hat\theta^r)$ using (17);
9. **end for**
10. Update $\hat\theta^{r+1} = \hat\theta^r - \beta \frac{1}{|\mathcal{B}_r|} \sum_{t \in \mathcal{B}_r} \hat g_t^r$;
11. $r = r + 1$;
12. **until** convergence
13. **Output:** $\hat\theta^r$.

**Remark 2.** To recognize how Lemma 1 links iBaML with iMAML (Rajeswaran et al. 2019), consider the special case where the covariance matrices of the prior and of the local minimum are fixed as $D = \lambda^{-1} I_d$ and $\bar D_t = 0_d$ for some constant $\lambda$. Since $d = [\lambda^{-1}, \ldots, \lambda^{-1}]^\top \in \mathbb{R}^d$ is a constant vector, Lemma 1 boils down to

$$\nabla_m \bar{\mathcal{A}}_t(\theta) = D^{-1}\big(\nabla_m^2 \mathcal{L}_t^{\rm tr}(\bar v_t) + D^{-1}\big)^{-1} = \big(\lambda^{-1} \nabla_m^2 \mathcal{L}_t^{\rm tr}(\bar v_t) + I_d\big)^{-1} \tag{13}$$

which coincides with Lemma 1 of (Rajeswaran et al. 2019). Hence, iBaML subsumes iMAML, whose expressiveness is confined because $d$ is fixed, while iBaML entails a learnable covariance matrix in the prior $p(\theta_t; \theta)$. In addition, the uncertainty of iMAML's training posterior $p(\theta_t \mid y_t^{\rm tr}; X_t^{\rm tr}, \theta)$ can be more challenging to quantify than that in iBaML.

An immediate consequence of Lemma 1 is the so-called generalized implicit gradient. Suppose that $\hat{\mathcal{A}}_t$ involves a $K$ large enough for the sub-optimal point $\hat v_t$ to be close to a local optimum $\bar v_t$. The Bayesian meta-gradient (9) can then be approximated through

$$\nabla_\theta \mathcal{L}_t^{\rm val}\big(\hat v_t(\theta), \theta\big) \approx G_t(\hat v_t)\, H_t^{-1}(\hat v_t)\, \nabla_1 \mathcal{L}_t^{\rm val}(\hat v_t, \theta) + \nabla_2 \mathcal{L}_t^{\rm val}(\hat v_t, \theta), \quad \forall t. \tag{14}$$

The approximate implicit gradient in (14) is computationally expensive due to the matrix inversion $H_t^{-1}(\hat v_t)$, which incurs complexity $\mathcal{O}(d^3)$.
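Remark 2 can be sanity-checked numerically in the simplest setting where everything is available in closed form. The minimal sketch below assumes a hypothetical quadratic nll $\frac{1}{2}m_t^\top A m_t - b^\top m_t$ with a point-mass surrogate and prior covariance $\lambda^{-1}I_2$, so the task-level problem becomes the iMAML-style proximal problem $\min_{m_t} \frac{1}{2}m_t^\top A m_t - b^\top m_t + \frac{\lambda}{2}\|m_t - m\|^2$, whose solution is $(A + \lambda I)^{-1}(b + \lambda m)$. It compares a finite-difference Jacobian of that solution w.r.t. $m$ with the right-hand side of (13). It only illustrates the $m$-block of Lemma 1 in two dimensions, not the full iBaML Jacobian.

```python
def inv2(M):
    """Inverse of a 2x2 matrix given as [[p, q], [r, s]]."""
    (p, q), (r, s) = M
    det = p * s - q * r
    return [[s / det, -q / det], [-r / det, p / det]]


def implicit_jacobian_13(hess, lam):
    """Right-hand side of (13): (hess / lam + I)^{-1} for a 2x2 Hessian."""
    return inv2([[hess[i][j] / lam + (1.0 if i == j else 0.0) for j in range(2)]
                 for i in range(2)])


def proximal_solution(A, b, m, lam):
    """argmin_x 0.5 x^T A x - b^T x + lam/2 ||x - m||^2 = (A + lam I)^{-1} (b + lam m)."""
    M_inv = inv2([[A[i][j] + (lam if i == j else 0.0) for j in range(2)] for i in range(2)])
    rhs = [b[i] + lam * m[i] for i in range(2)]
    return [sum(M_inv[i][j] * rhs[j] for j in range(2)) for i in range(2)]


def finite_difference_jacobian(A, b, m, lam, h=1e-6):
    """Central-difference Jacobian J[i][j] = d x_i / d m_j of the proximal solution."""
    J = [[0.0, 0.0], [0.0, 0.0]]
    for j in range(2):
        m_plus = list(m)
        m_plus[j] += h
        m_minus = list(m)
        m_minus[j] -= h
        x_plus = proximal_solution(A, b, m_plus, lam)
        x_minus = proximal_solution(A, b, m_minus, lam)
        for i in range(2):
            J[i][j] = (x_plus[i] - x_minus[i]) / (2 * h)
    return J


A, b, m, lam = [[3.0, 1.0], [1.0, 2.0]], [0.5, -1.0], [0.2, 0.4], 2.0
J_fd = finite_difference_jacobian(A, b, m, lam)
J_13 = implicit_jacobian_13(A, lam)
```

Here the Jacobian is symmetric, so the transposition convention in (13) does not matter; the point is that the implicit formula needs only the Hessian at the solution, not the path that produced it.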
To relieve the computational burden, a key observation is that $H_t^{-1}(\hat v_t)\, \nabla_1 \mathcal{L}_t^{\rm val}(\hat v_t, \theta)$ is the solution of the optimization problem

$$\arg\min_{u}\; \frac{1}{2} u^\top H_t(\hat v_t)\, u - u^\top \nabla_1 \mathcal{L}_t^{\rm val}(\hat v_t, \theta). \tag{15}$$

Given that the square matrix $H_t(\hat v_t)$ is by definition symmetric, problem (15) can be efficiently solved using the conjugate gradient (CG) iteration, whose convergence guarantees additionally require $H_t(\hat v_t)$ to be positive definite. Specifically, the complexity of CG is dominated by the matrix-vector product $H_t(\hat v_t)\, p$ (for some vector $p \in \mathbb{R}^{2d}$), given by

$$H_t(\hat v_t)\, p = \nabla^2 \mathcal{L}_t^{\rm tr}(\hat v_t)\, p + \begin{bmatrix} D^{-1} & 0_d \\ 0_d & \frac{1}{2}\big(D^{-1} + 2\,\mathrm{diag}(\nabla_{d_t} \mathcal{L}_t^{\rm tr}(\hat v_t))\big)^2 \end{bmatrix} p. \tag{16}$$

The first term on the right-hand side of (16) is an HVP, and the second is the multiplication of a diagonal matrix with a vector. Because the matrix is diagonal, the latter term boils down to an element-wise product, implying that the complexity of each CG iteration is as low as $\mathcal{O}(d)$. In practice, a small number of CG iterations suffices to produce an accurate estimate of $H_t^{-1}(\hat v_t)\, \nabla_1 \mathcal{L}_t^{\rm val}(\hat v_t, \theta)$ thanks to CG's fast convergence rate (Van der Sluis and van der Vorst 1986; Winther 1980). In order to control the total complexity of iBaML, we set the maximum number of CG iterations to a constant $L$.

Having obtained an approximation of the matrix-inverse-vector product $H_t^{-1}(\hat v_t)\, \nabla_1 \mathcal{L}_t^{\rm val}(\hat v_t, \theta)$, we proceed to estimate the Bayesian meta-gradient. Let $\hat u_t := [\hat u_{t,m}^\top, \hat u_{t,d}^\top]^\top$ be the output of the CG method, with sub-vectors $\hat u_{t,m}, \hat u_{t,d} \in \mathbb{R}^d$. Then, it follows from (14) that

$$\nabla_\theta \mathcal{L}_t^{\rm val}\big(\hat v_t(\theta), \theta\big) \approx G_t(\hat v_t)\, \hat u_t + \nabla_2 \mathcal{L}_t^{\rm val}(\hat v_t, \theta) = \begin{bmatrix} D^{-1} \hat u_{t,m} \\ -\mathrm{diag}\big(\nabla_{m_t} \mathcal{L}_t^{\rm tr}(\hat v_t)\big) D^{-1} \hat u_{t,m} + \frac{1}{2} D^{-2} \hat u_{t,d} \end{bmatrix} + \nabla_2 \mathcal{L}_t^{\rm val}(\hat v_t, \theta) := \hat g_t, \quad t = 1, \ldots, T \tag{17}$$

where we also used the definition (11). Again, the diagonal-matrix-vector products in (17) can be efficiently computed through element-wise products, which incur complexity $\mathcal{O}(d)$.

The step-by-step pseudocode of iBaML is listed under Algorithm 1. In a nutshell, the implicit Bayesian meta-gradient computation consumes $\mathcal{O}(Ld)$ time, regardless of the optimization algorithm $\hat{\mathcal{A}}_t$.
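Steps 7 and 8 of Algorithm 1 can be sketched as below: an $L$-step CG solve of (15) that only ever forms the product (16), followed by the assembly of $\hat g_t$ in (17). This is a minimal sketch, not the authors' implementation. Vectors are plain Python lists, `hvp_tr` is a hypothetical callable returning $\nabla^2\mathcal{L}_t^{\rm tr}(\hat v_t)\,p$ (in PyTorch this would typically come from double backpropagation), and the gradients $\nabla_{m_t}\mathcal{L}_t^{\rm tr}(\hat v_t)$, $\nabla_{d_t}\mathcal{L}_t^{\rm tr}(\hat v_t)$, $\nabla_1\mathcal{L}_t^{\rm val}$ and $\nabla_2\mathcal{L}_t^{\rm val}$ are assumed to be precomputed. The sketch also assumes $H_t(\hat v_t)$ is positive definite, as CG requires.

```python
def conjugate_gradient(matvec, rhs, num_steps, tol=1e-20):
    """At most num_steps (L) CG iterations for H u = rhs, H symmetric positive definite.

    Only products matvec(p) = H p are used; H itself is never stored.
    """
    u = [0.0] * len(rhs)
    r = list(rhs)  # residual rhs - H u at u = 0
    p = list(r)
    rs = sum(ri * ri for ri in r)
    for _ in range(num_steps):
        if rs <= tol:
            break
        Hp = matvec(p)
        step = rs / sum(pi * hi for pi, hi in zip(p, Hp))
        u = [ui + step * pi for ui, pi in zip(u, p)]
        r = [ri - step * hi for ri, hi in zip(r, Hp)]
        rs_new = sum(ri * ri for ri in r)
        p = [ri + (rs_new / rs) * pi for ri, pi in zip(r, p)]
        rs = rs_new
    return u


def make_h_matvec(hvp_tr, d_prior, grad_d_tr):
    """Return p -> H_t(v_hat) p as in (16); p = [p_m; p_d] has length 2d."""
    diag_m = [1.0 / di for di in d_prior]  # diagonal of D^{-1}
    diag_d = [0.5 * (1.0 / di + 2.0 * gi) ** 2  # diagonal of (D^{-1} + 2 diag(grad_d))^2 / 2
              for di, gi in zip(d_prior, grad_d_tr)]
    diag = diag_m + diag_d

    def matvec(p):
        # HVP plus an element-wise product with the diagonal correction.
        return [hv + dk * pk for hv, dk, pk in zip(hvp_tr(p), diag, p)]

    return matvec


def implicit_meta_gradient(u, d_prior, grad_m_tr, grad2_val):
    """g_hat = G_t(v_hat) u + grad_2 L_val as in (17), with u = [u_m; u_d]."""
    dim = len(d_prior)
    u_m, u_d = u[:dim], u[dim:]
    top = [um / di for um, di in zip(u_m, d_prior)]
    bottom = [-gm * um / di + 0.5 * ud / di ** 2
              for gm, um, ud, di in zip(grad_m_tr, u_m, u_d, d_prior)]
    return [g + h for g, h in zip(top + bottom, grad2_val)]
```

A call then mirrors Algorithm 1: `u = conjugate_gradient(make_h_matvec(hvp_tr, d, grad_d_tr), grad1_val, num_steps=L)` followed by `implicit_meta_gradient(u, d, grad_m_tr, grad2_val)`. Every operation is a vector operation of length $2d$, so memory stays $\mathcal{O}(d)$ irrespective of $K$.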
One can even employ more complicated algorithms such as second-order matrix-free optimization (Martens and Grosse 2015; Botev, Ritter, and Barber 2017). In addition, as the time complexity does not depend on $K$, one can increase $K$ to reduce the approximation error in (14). The space complexity of iBaML is only $\mathcal{O}(d)$ thanks to the iterative implementation of CG steps. These considerations explain how iBaML addresses the scalability issue of explicit backpropagation.

### Theoretical Analysis

This section deals with the performance analysis of both explicit and implicit gradients in Bayesian meta-learning, to further understand their differences. Similar to (Rajeswaran et al. 2019), our results rely on the following assumptions.

**Assumption 1.** Vector $\bar v_t = \bar{\mathcal{A}}_t(\theta)$ is a local minimum of the KL-divergence in (5).

**Assumption 2.** The meta-loss function $\mathcal{L}_t^{\rm val}(v_t, \theta)$ is $A_t$-Lipschitz and $B_t$-smooth w.r.t. $v_t$, while its partial gradient $\nabla_2 \mathcal{L}_t^{\rm val}(v_t, \theta)$ is $C_t$-Lipschitz w.r.t. $v_t$.

**Assumption 3.** The expected nll function $\mathcal{L}_t^{\rm tr}(v_t)$ is $D_t$-smooth, and has an $E_t$-Lipschitz Hessian.

**Assumption 4.** Matrices $H_t(\hat v_t)$ and $H_t(\bar v_t)$ are both non-singular; that is, their smallest singular value $\sigma_t := \min\{\sigma_{\min}(H_t(\hat v_t)), \sigma_{\min}(H_t(\bar v_t))\} > 0$.

**Assumption 5.** Prior variances are positive and bounded, meaning $0 < D_{\min} \le [d]_i \le D_{\max}$, $i = 1, \ldots, d$.

Based on these assumptions, we can establish the following result.

**Theorem 1** (Explicit Bayesian meta-gradient error bound). Consider the Bayesian meta-learning problem (6). Let $\epsilon_t := \|\hat v_t - \bar v_t\|_2$ be the task-level optimization error, and $\delta_t := \|\nabla \hat{\mathcal{A}}_t(\theta) - G_t(\hat v_t)\, H_t^{-1}(\hat v_t)\|_2$ the error in the Jacobian. Upon defining $\rho_t := \max\{\|\nabla \mathcal{L}_t^{\rm tr}(\bar v_t)\|, \|\nabla \mathcal{L}_t^{\rm tr}(\hat v_t)\|\}$, and with Assumptions 1-5 in effect, it holds for $t \in \{1, \ldots, T\}$ that

$$\Big\|\nabla_\theta \mathcal{L}_t^{\rm val}\big(\hat v_t(\theta), \theta\big) - \nabla_\theta \mathcal{L}_t^{\rm val}\big(\bar v_t(\theta), \theta\big)\Big\|_2 \le F_t \epsilon_t + A_t \delta_t$$

where $F_t$ is a constant dependent on $\rho_t$.
Theorem 1 asserts that the $\ell_2$ error of the explicit Bayesian meta-gradient relative to the true one depends on the task-level optimization error as well as on the error in the Jacobian: the former captures the Euclidean distance between the local minimum $\bar v_t$ and its approximation $\hat v_t$, while the latter characterizes how the sub-optimal function $\hat{\mathcal{A}}_t$ influences the Jacobian. Both errors can be reduced by increasing $K$ in the task-level optimization, at the cost of time and space complexity for backpropagating through $\nabla \hat{\mathcal{A}}_t(\theta)$. Ideally, one can have $\delta_t = 0$ when $\hat v_t$ is a local optimum, and $\epsilon_t = 0$ when choosing $\bar v_t = \hat v_t$.

Next, we derive an error bound for implicit differentiation.

**Theorem 2** (Implicit Bayesian meta-gradient error bound). Consider the Bayesian meta-learning problem (6). Let $\epsilon_t := \|\hat v_t - \bar v_t\|_2$ be the task-level optimization error, and $\delta_t' := \|\hat u_t - H_t^{-1}(\hat v_t)\, \nabla_1 \mathcal{L}_t^{\rm val}(\hat v_t, \theta)\|$ the CG error. Upon defining $\rho_t := \max\{\|\nabla \mathcal{L}_t^{\rm tr}(\bar v_t)\|, \|\nabla \mathcal{L}_t^{\rm tr}(\hat v_t)\|\}$, and with Assumptions 1-5 in effect, it holds for $t \in \{1, \ldots, T\}$ that

$$\Big\|\hat g_t - \nabla_\theta \mathcal{L}_t^{\rm val}\big(\bar v_t(\theta), \theta\big)\Big\|_2 \le F_t' \epsilon_t + G_t' \delta_t' \tag{18}$$

where $F_t'$ and $G_t'$ are constants dependent on $\rho_t$.

While the bound on the implicit meta-gradient also depends on the task-level optimization error, the difference from Theorem 1 lies in the CG error. The fast convergence of CG leads to a tolerable $\delta_t'$ even with a small $L$. As a result, one can opt for a large $K$ to reduce the task-level optimization error $\epsilon_t$, and a small $L$ to obtain a satisfactory approximation of the meta-gradient. It is worth stressing that $\bar v_t$ in Theorems 1 and 2 can denote any local optimum. It further follows by definition that neither $\delta_t$ nor $\delta_t'$ relies on the choice of local optimum, whereas $\epsilon_t$ does. One final remark is now in order.

**Remark 3.** Theorems 1 and 2 can be further simplified under the additional assumption that $\mathcal{L}_t^{\rm tr}(v_t)$ is $H_t$-Lipschitz. In such a case, we have $\rho_t \le H_t$, and thus the scalars $F_t$, $F_t'$ and $G_t'$ boil down to task-specific constants.
## Numerical Tests

Here we test and showcase the analytical novelties of this contribution on synthetic and real data. Our implementation relies on PyTorch (Paszke et al. 2019), and codes are available at https://github.com/zhangyilang/iBaML.

### Synthetic Data

Here we examine the errors of explicit and implicit gradients on a synthetic dataset. The data are generated using the Bayesian linear regression model

$$y_t^n = \langle \theta_t, x_t^n \rangle + e_t^n, \quad \forall n, \; t = 1, \ldots, T \tag{19}$$

where $\{\theta_t\}_{t=1}^{T}$ are i.i.d. samples drawn from a distribution $p(\theta_t; \hat\theta)$ that is unknown during meta-training, and $e_t^n$ is additive white Gaussian noise (AWGN) with known variance $\sigma^2$. Although the training posterior $p(\theta_t \mid y_t^{\rm tr}; X_t^{\rm tr}, \theta)$ now becomes tractable, we still focus on the VI approximation for uniformity. Within this rudimentary linear case, it can be readily verified that the task-level optimum $v_t^* := [m_t^{*\top}, d_t^{*\top}]^\top$ of (5) is given by

$$m_t^* = \Big(\frac{1}{\sigma^2} X_t^{\rm tr} (X_t^{\rm tr})^\top + D^{-1}\Big)^{-1} \Big(D^{-1} m + \frac{1}{\sigma^2} X_t^{\rm tr} y_t^{\rm tr}\Big), \qquad d_t^* = \Big(\frac{1}{\sigma^2}\, \mathrm{diag}\big(X_t^{\rm tr} (X_t^{\rm tr})^\top\big) + d^{-1}\Big)^{-1}, \quad t = 1, \ldots, T \tag{20}$$

where $\mathrm{diag}(M)$ is a vector collecting the diagonal entries of matrix $M$, and the inverses acting on vectors in the expression for $d_t^*$ are taken entry-wise. The true posterior in the linear case is $p(\theta_t \mid y_t^{\rm tr}; X_t^{\rm tr}, \theta) = \mathcal{N}\Big(m_t^*, \big(\frac{1}{\sigma^2} X_t^{\rm tr} (X_t^{\rm tr})^\top + D^{-1}\big)^{-1}\Big)$, implying that VI approximates the posterior covariance matrix by the diagonal matrix $D_t^* = \mathrm{diag}(d_t^*)$, whose inverse retains only the diagonal of the posterior precision matrix. Lemma 1 and (9) imply that the oracle meta-gradient is

$$\nabla_\theta \mathcal{L}_t^{\rm val}\big(v_t^*(\theta), \theta\big) = G_t(v_t^*)\, H_t^{-1}(v_t^*)\, \nabla_1 \mathcal{L}_t^{\rm val}(v_t^*, \theta) + \nabla_2 \mathcal{L}_t^{\rm val}(v_t^*, \theta), \quad \forall t. \tag{21}$$

As a benchmark meta-learning algorithm, we selected amortized Bayesian meta-learning (ABML) (Ravi and Beatson 2019). The metric used for performance assessment is the normalized root-mean-square error (NRMSE) between the true meta-gradient $\nabla_\theta \mathcal{L}_t^{\rm val}(v_t^*(\theta), \theta)$ and the estimated meta-gradients $\nabla_\theta \mathcal{L}_t^{\rm val}(\hat v_t(\theta), \theta)$ and $\hat g_t$; see also the Appendix for additional details on the numerical test. Figure 1 depicts the NRMSE as a function of $K$ for the first iteration of ABML, that is, at the point $\theta = \hat\theta^0$.
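The closed form (20) can be evaluated directly for a single task. The minimal, dependency-free sketch below uses illustrative names (not from the authors' repository) and assumes $X_t^{\rm tr}$ is stored as a $d \times N_t^{\rm tr}$ list of rows whose columns are the inputs $x_t^n$, matching $X_t^{\rm tr} := [x_t^1, \ldots, x_t^{N_t^{\rm tr}}]$. It solves the linear system for $m_t^*$ by Gaussian elimination instead of forming an inverse, and reads $d_t^*$ off the diagonal of the posterior precision.

```python
def solve(M, rhs):
    """Solve M x = rhs by Gaussian elimination with partial pivoting."""
    n = len(rhs)
    aug = [list(row) + [r] for row, r in zip(M, rhs)]
    for col in range(n):
        piv = max(range(col, n), key=lambda i: abs(aug[i][col]))
        aug[col], aug[piv] = aug[piv], aug[col]
        for i in range(col + 1, n):
            f = aug[i][col] / aug[col][col]
            for j in range(col, n + 1):
                aug[i][j] -= f * aug[col][j]
    x = [0.0] * n
    for i in range(n - 1, -1, -1):
        x[i] = (aug[i][n] - sum(aug[i][j] * x[j] for j in range(i + 1, n))) / aug[i][i]
    return x


def vi_optimum_linear_regression(X, y, m, d, sigma2):
    """Task-level optimum (m_t*, d_t*) of (5) for y = X^T theta_t + AWGN, per (20).

    X: d x N list of rows (columns are inputs), prior N(m, diag(d)), noise variance sigma2.
    """
    dim, N = len(X), len(y)
    # Posterior precision P = X X^T / sigma2 + D^{-1}.
    P = [[sum(X[i][n] * X[j][n] for n in range(N)) / sigma2
          + (1.0 / d[i] if i == j else 0.0) for j in range(dim)] for i in range(dim)]
    rhs = [m[i] / d[i] + sum(X[i][n] * y[n] for n in range(N)) / sigma2 for i in range(dim)]
    m_star = solve(P, rhs)
    d_star = [1.0 / P[i][i] for i in range(dim)]  # entry-wise inverse of diag(P)
    return m_star, d_star
```

Because $d_t^*$ inverts only the diagonal of the posterior precision $P$, each entry never exceeds the exact marginal posterior variance $[P^{-1}]_{ii}$ (a standard inequality for symmetric positive definite matrices), which is one concrete reading of the statement that VI replaces the posterior covariance by a diagonal surrogate.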
For both explicit and implicit gradients, the NRMSE decreases as $K$ increases, while the former outperforms the latter for $K \le 5$, and vice versa for $K > 5$. These observations confirm our analytical results. Intuitively, the terms $F_t\epsilon_t$ and $F_t'\epsilon_t$ caused by imprecise task-level optimization dominate the upper bounds for small $K$, thus resulting in large NRMSE. Besides, implicit gradients are more sensitive to task-level optimization errors. One conjecture is that iBaML is developed based on Lemma 1, where the matrix inversion can be sensitive to variations of $\bar v_t$. Although the condition number $\kappa$ of $X_t^{\rm tr}$ purposely takes a large value so that $\epsilon_t$ decreases slowly with $K$, a small $K$ suffices to capture implicit gradients accurately. The main reason is that the CG error $\delta_t'$ can become sufficiently small even with only $L = 2$ steps, while $\delta_t$ remains large because GD converges slowly.

[Figure 1: Gradient error comparison on synthetic dataset. NRMSE versus the number $K$ of GD steps, for the explicit gradient and for the implicit gradient with $L = 2$ and $L = 5$.]

### Real Data

Next, we conduct tests to assess the performance of iBaML on real datasets. We consider miniImageNet (Vinyals et al. 2016), one of the most widely used few-shot classification datasets. This dataset consists of natural images categorized in 100 classes, with 600 samples per class. All images are cropped to size 84 × 84. We adopt the dataset splitting suggested by (Ravi and Larochelle 2017), where 64, 16 and 20 disjoint classes are used for meta-training, meta-validation and meta-testing, respectively. The setups of the numerical test follow the standard $W$-class $S^{\rm tr}$-shot few-shot learning protocol in (Vinyals et al. 2016). In particular, each task has $W$ randomly selected classes, and each class contains $S^{\rm tr}$ training images and $S^{\rm val}$ validation images. In other words, we have $N^{\rm tr} = S^{\rm tr} W$ and $N^{\rm val} = S^{\rm val} W$. We further adopt the typical choices $W = 5$, $S^{\rm tr} \in \{1, 5\}$, and $S^{\rm val} = 15$.
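The $W$-class $S^{\rm tr}$-shot protocol above can be sketched as a small task sampler. This is an illustrative sketch, not the authors' data loader: the dictionary layout of `images_by_class` and the per-task relabeling of the sampled classes as $0, \ldots, W-1$ are assumptions (the relabeling is a common few-shot convention that the text does not state). Each call returns the training and validation sets of one task, with $N^{\rm tr} = S^{\rm tr}W$ and $N^{\rm val} = S^{\rm val}W$ examples.

```python
import random


def sample_task(images_by_class, num_classes=5, shots_tr=1, shots_val=15, rng=random):
    """Sample one W-class S_tr-shot task.

    images_by_class maps a class name to a list of its images (any objects).
    Returns (train_set, val_set) as lists of (image, label) pairs, with
    shots_tr * num_classes and shots_val * num_classes entries, respectively.
    """
    classes = rng.sample(sorted(images_by_class), num_classes)
    train_set, val_set = [], []
    for label, cls in enumerate(classes):
        # Draw disjoint training and validation images for this class.
        picks = rng.sample(images_by_class[cls], shots_tr + shots_val)
        train_set += [(img, label) for img in picks[:shots_tr]]
        val_set += [(img, label) for img in picks[shots_tr:]]
    return train_set, val_set


# Usage with a toy stand-in for the 64 meta-training classes (600 images each).
toy_pool = {f"class_{c}": [f"img_{c}_{i}" for i in range(600)] for c in range(64)}
train_set, val_set = sample_task(toy_pool, num_classes=5, shots_tr=5, shots_val=15,
                                 rng=random.Random(0))
```

With $W = 5$, $S^{\rm tr} = 5$ and $S^{\rm val} = 15$, this yields 25 training and 75 validation examples per task; passing a seeded `random.Random` makes the sampled tasks reproducible.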
It should be noted that the training and validation sets are also known as support and query sets in the context of few-shot learning.

[Figure 2: Time and space complexity comparisons for meta-gradient computation on the 5-class 1-shot miniImageNet dataset. (a) Normalized time complexity and (b) normalized space complexity versus the number $K$ of GD steps, for MAML, ABML, iMAML ($L = 5$), and iBaML ($L = 2$ and $L = 5$).]

| Method | nll | Accuracy |
|---|---|---|
| MAML, $K = 5$ | 0.967 ± 0.017 | 63.1 ± 0.92% |
| ABML, $K = 5$ | 0.957 ± 0.016 | 62.8 ± 0.74% |
| iBaML, $K = 5$ | 0.965 ± 0.018 | 63.2 ± 0.74% |
| iBaML, $K = 10$ | 0.947 ± 0.017 | 64.0 ± 0.75% |
| iBaML, $K = 15$ | 0.943 ± 0.017 | 64.0 ± 0.74% |

Table 1: Test negative log-likelihood (nll) and accuracy comparison on 5-class 5-shot miniImageNet dataset. The ± sign indicates the 95% confidence interval.

We first empirically compare the computational complexity (time and space) of explicit versus implicit gradients on the 5-class 1-shot miniImageNet dataset. Here we are only interested in backward complexity, so the delay and memory requirements of the forward pass of $\hat{\mathcal{A}}_t$ are excluded. Figure 2(a) plots the time complexity of explicit and implicit gradients against $K$. It is observed that the time complexity of the explicit gradient grows linearly with $K$, while that of the implicit one increases only with $L$, not with $K$. Moreover, the explicit and implicit gradients have comparable time complexity when $K = L$. As far as space complexity is concerned, Figure 2(b) illustrates that memory usage with explicit gradients is proportional to $K$. In contrast, the memory used by the implicit gradient algorithms is nearly invariant across $K$ values. Such a memory-saving property is important when meta-learning is employed with models of growing degrees of freedom. Furthermore, one may also notice from both figures that MAML and iMAML incur about 50% of the time/space complexities of ABML and iBaML.
This is because the non-Bayesian approaches only optimize the mean vector of the Gaussian prior, whose dimension is $d$, while the probabilistic methods handle both the mean and the diagonal covariance matrix of the prior pdf, whose combined dimension is $2d$. Doubling the dimension thus doubles the time and space complexity of the gradient computations. Next, we demonstrate the effectiveness of iBaML in reducing the Bayesian meta-learning loss. The test is conducted on the 5-class 5-shot miniImageNet dataset. The model is a standard 4-layer 32-channel convolutional neural network, and the baseline algorithms are MAML (Finn, Abbeel, and Levine 2017) and ABML (Ravi and Beatson 2019); see also the Appendix for alternative setups. Due to the large number of training tasks, it is impractical to compute the exact meta-training loss. Instead, we adopt the test nll (averaged over 1,000 test tasks) as our metric, and also report the corresponding accuracy. For fairness, we set $L = 5$ when implementing the implicit gradients so that their time complexity is similar to that of the explicit gradient with $K = 5$. The results are listed in Table 1. It is observed that both the nll and the accuracy of iBaML improve with $K$, implying that the meta-learning loss can be effectively reduced while tolerating only a small error in gradient estimation.

Figure 3: Calibration errors on 5-class 1-shot miniImageNet (methods compared: MAML, PLATIPUS, ABML, and iBaML).

To quantify the uncertainty captured by state-of-the-art meta-learning methods, Figure 3 plots the expected and maximum calibration errors (ECE/MCE) (Naeini, Cooper, and Hauskrecht 2015). It can be seen that iBaML is once again the most competitive among the tested approaches.

Conclusions

This paper develops a novel approach, termed iBaML, to enhance the scalability of Bayesian meta-learning. At the core of iBaML is an estimate of meta-gradients obtained via implicit differentiation.
Analysis reveals that the estimation error is upper bounded by the task-level optimization and CG errors, both of which can be significantly reduced with only a slight increase in time complexity. In addition, the required computational complexity is invariant to the task-level optimization trajectory, which allows iBaML to handle complicated task-level optimization. Beyond the analytical results, extensive numerical tests on synthetic and real datasets demonstrate the appealing merits of iBaML over competing alternatives.

Acknowledgments

This work was supported in part by NSF grants 2220292, 2212318, 2126052, and 2128593.

References

Abbas, M.; Xiao, Q.; Chen, L.; Chen, P.-Y.; and Chen, T. 2022. Sharp-MAML: Sharpness-Aware Model-Agnostic Meta Learning. In Proceedings of the 39th International Conference on Machine Learning, volume 162 of Proceedings of Machine Learning Research, 10-32. PMLR.

Bengio, S.; Bengio, Y.; and Cloutier, J. 1995. On the Search for New Learning Rules for ANNs. Neural Processing Letters, 2(4): 26-30.

Bertinetto, L.; Henriques, J. F.; Torr, P.; and Vedaldi, A. 2019. Meta-learning with Differentiable Closed-Form Solvers. In Proceedings of International Conference on Learning Representations.

Botev, A.; Ritter, H.; and Barber, D. 2017. Practical Gauss-Newton Optimisation for Deep Learning. In Proceedings of the 34th International Conference on Machine Learning, volume 70 of Proceedings of Machine Learning Research, 557-565. PMLR.

Chen, L.; and Chen, T. 2022. Is Bayesian Model-Agnostic Meta Learning Better than Model-Agnostic Meta Learning, Provably? In Proceedings of The 25th International Conference on Artificial Intelligence and Statistics, volume 151 of Proceedings of Machine Learning Research, 1733-1774. PMLR.

Fallah, A.; Mokhtari, A.; and Ozdaglar, A. 2020. On the Convergence Theory of Gradient-Based Model-Agnostic Meta-Learning Algorithms. In Proceedings of the Twenty-Third International Conference on Artificial Intelligence and Statistics, volume 108, 1082-1092. PMLR.

Finn, C.; Abbeel, P.; and Levine, S. 2017. Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks. In Proceedings of the 34th International Conference on Machine Learning, volume 70, 1126-1135. PMLR.

Finn, C.; Xu, K.; and Levine, S. 2018. Probabilistic Model-Agnostic Meta-Learning. In Advances in Neural Information Processing Systems, volume 31. Curran Associates, Inc.

Flennerhag, S.; Rusu, A. A.; Pascanu, R.; Visin, F.; Yin, H.; and Hadsell, R. 2020. Meta-Learning with Warped Gradient Descent. In Proceedings of International Conference on Learning Representations.

Franceschi, L.; Frasconi, P.; Salzo, S.; Grazzi, R.; and Pontil, M. 2018. Bilevel Programming for Hyperparameter Optimization and Meta-Learning. In Proceedings of the 35th International Conference on Machine Learning, volume 80, 1568-1577. PMLR.

Grant, E.; Finn, C.; Levine, S.; Darrell, T.; and Griffiths, T. 2018. Recasting Gradient-Based Meta-Learning as Hierarchical Bayes. In Proceedings of International Conference on Learning Representations.

Griewank, A. 1993. Some bounds on the complexity of gradients, Jacobians, and Hessians. In Complexity in Numerical Optimization, 128-162. World Scientific.

Hansen, N.; and Wang, X. 2021. Generalization in Reinforcement Learning by Soft Data Augmentation. In 2021 IEEE International Conference on Robotics and Automation (ICRA), 13611-13617.

Kingma, D. P.; and Ba, J. 2015. Adam: A Method for Stochastic Optimization. In Proceedings of International Conference on Learning Representations.

Lee, K.; Maji, S.; Ravichandran, A.; and Soatto, S. 2019. Meta-Learning With Differentiable Convex Optimization. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR).

Li, Z.; Zhou, F.; Chen, F.; and Li, H. 2017. Meta-SGD: Learning to Learn Quickly for Few-Shot Learning. arXiv preprint arXiv:1707.09835.

Martens, J.; and Grosse, R. 2015. Optimizing Neural Networks with Kronecker-factored Approximate Curvature. In Proceedings of the 32nd International Conference on Machine Learning, volume 37 of Proceedings of Machine Learning Research, 2408-2417. Lille, France: PMLR.

Miao, Y.; Metze, F.; and Rawat, S. 2013. Deep maxout networks for low-resource speech recognition. In 2013 IEEE Workshop on Automatic Speech Recognition and Understanding, 398-403. IEEE.

Mishra, N.; Rohaninejad, M.; Chen, X.; and Abbeel, P. 2018. A Simple Neural Attentive Meta-Learner. In International Conference on Learning Representations.

Naeini, M. P.; Cooper, G.; and Hauskrecht, M. 2015. Obtaining Well Calibrated Probabilities Using Bayesian Binning. In Proceedings of the Twenty-Ninth International Conference on Artificial Intelligence and Statistics, 2901-2907. PMLR.

Nguyen, C.; Do, T.-T.; and Carneiro, G. 2020. Uncertainty in Model-Agnostic Meta-Learning using Variational Inference. In Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision (WACV).

Nichol, A.; Achiam, J.; and Schulman, J. 2018. On First-Order Meta-Learning Algorithms. arXiv preprint arXiv:1803.02999.

Paszke, A.; Gross, S.; Massa, F.; Lerer, A.; Bradbury, J.; Chanan, G.; Killeen, T.; Lin, Z.; Gimelshein, N.; Antiga, L.; Desmaison, A.; Kopf, A.; Yang, E.; DeVito, Z.; Raison, M.; Tejani, A.; Chilamkurthy, S.; Steiner, B.; Fang, L.; Bai, J.; and Chintala, S. 2019. PyTorch: An Imperative Style, High-Performance Deep Learning Library. In Advances in Neural Information Processing Systems, volume 32. Curran Associates, Inc.

Rajeswaran, A.; Finn, C.; Kakade, S. M.; and Levine, S. 2019. Meta-Learning with Implicit Gradients. In Advances in Neural Information Processing Systems, volume 32. Curran Associates, Inc.

Ravi, S.; and Beatson, A. 2019. Amortized Bayesian Meta-Learning. In Proceedings of International Conference on Learning Representations.

Ravi, S.; and Larochelle, H. 2017. Optimization as a Model for Few-Shot Learning. In Proceedings of International Conference on Learning Representations.

Santoro, A.; Bartunov, S.; Botvinick, M.; Wierstra, D.; and Lillicrap, T. 2016. Meta-Learning with Memory-Augmented Neural Networks. In Proceedings of the 33rd International Conference on Machine Learning, volume 48, 1842-1850. New York, New York, USA: PMLR.

Schmidhuber, J. 1993. A Neural Network that Embeds its Own Meta-Levels. In IEEE International Conference on Neural Networks, vol. 1, 407-412.

Schmidhuber, J.; Zhao, J.; and Wiering, M. 1996. Simple Principles of Metalearning. Technical report IDSIA, 69: 1-23.

Thrun, S. 1998. Lifelong Learning Algorithms, 181-209. Boston, MA: Springer US. ISBN 978-1-4615-5529-2.

Thrun, S.; and Pratt, L. 2012. Learning to Learn. Springer Science & Business Media.

Van der Sluis, A.; and van der Vorst, H. A. 1986. The rate of convergence of conjugate gradients. Numerische Mathematik, 48(5): 543-560.

Vinyals, O.; Blundell, C.; Lillicrap, T.; Kavukcuoglu, K.; and Wierstra, D. 2016. Matching Networks for One Shot Learning. In Advances in Neural Information Processing Systems, volume 29. Curran Associates, Inc.

Wang, H.; Sun, R.; and Li, B. 2020. Global Convergence and Generalization Bound of Gradient-Based Meta-Learning with Deep Neural Nets. arXiv preprint arXiv:2006.14606.

Winther, R. 1980. Some Superlinear Convergence Results for the Conjugate Gradient Method. SIAM Journal on Numerical Analysis, 17(1): 14-17.

Yang, Y.; Sun, J.; Li, H.; and Xu, Z. 2016. Deep ADMM-Net for Compressive Sensing MRI. In Advances in Neural Information Processing Systems, volume 29. Curran Associates, Inc.

Yoon, J.; Kim, T.; Dia, O.; Kim, S.; Bengio, Y.; and Ahn, S. 2018. Bayesian Model-Agnostic Meta-Learning. In Advances in Neural Information Processing Systems, volume 31. Curran Associates, Inc.