Understanding Benign Overfitting in Gradient-based Meta Learning

Lisha Chen, Rensselaer Polytechnic Institute, Troy, NY, USA, chenl21@rpi.edu
Songtao Lu, IBM Research, Yorktown Heights, NY, USA, songtao@ibm.com
Tianyi Chen, Rensselaer Polytechnic Institute, Troy, NY, USA, chentianyi19@gmail.com

Meta learning has demonstrated tremendous success in few-shot learning with limited supervised data. In those settings, the meta model is usually overparameterized. While conventional statistical learning theory suggests that overparameterized models tend to overfit, empirical evidence reveals that overparameterized meta learning methods still work well, a phenomenon often called benign overfitting. To understand this phenomenon, we focus on the meta learning settings with a challenging bilevel structure that we term gradient-based meta learning, and analyze its generalization performance under an overparameterized meta linear regression model. While our analysis uses relatively tractable linear models, our theory contributes to understanding the delicate interplay among data heterogeneity, model adaptation, and benign overfitting in gradient-based meta learning tasks. We corroborate our theoretical claims through numerical simulations.

1 Introduction

Meta learning, also referred to as learning to learn, usually learns a prior model from multiple tasks so that the learned model is able to quickly adapt to unseen tasks [43, 26]. Meta learning has been successfully applied to few-shot learning [2, 13], image recognition [52], federated learning [29], reinforcement learning [21] and communication systems [10]. While there are many exciting meta learning methods today, in this paper we study a representative meta learning setting where the goal is to learn a shared initial model that can quickly adapt to task-specific models.
This adaptation may take an explicit form, such as the output of one gradient descent step, which is referred to as the model-agnostic meta learning (MAML) method [21]. Alternatively, the adaptation step may take an implicit form, such as the solution of another optimization problem, which is referred to as the implicit MAML (iMAML) method [39]. Since both MAML and iMAML solve a bilevel optimization problem, we refer to them as gradient-based meta learning hereafter. In many cases, overparameterized models are used as the initial models in meta learning for quick adaptation. For example, ResNet-based MAML models typically have around 6 million parameters, but are trained on 1-3 million meta-training data [12]. Training such initial models is often difficult in meta learning because the number of training data is much smaller than the dimension of the model parameter. Previous works on meta learning mainly focus on addressing the optimization challenges or analyzing the generalization performance with sufficient data [18, 19, 14]. Different from these works, we are particularly interested in the generalization performance of the sought initial model in practical scenarios where the total number of data from all tasks is smaller than the dimension of the initial model, which we term overparameterized meta learning. In those overparameterized regimes, the generalization error of meta learning models is not fully understood. Motivated by this, we ask:

If and when would overparameterized meta learning models lead to overfitting, provably?

36th Conference on Neural Information Processing Systems (NeurIPS 2022).

Empirical studies have demonstrated that MAML with an overparameterized model generally performs better than MAML with an underparameterized model [12], a phenomenon often called benign overfitting. To show this, we plot in Figure 1 the empirical results from Table A5 in [12].
Complementing this, we take an initial step by answering this theoretical question in the meta linear regression setting.

1.1 Prior art

Figure 1: Accuracy vs. networks with increasing dimensions (Conv-4, Conv-6, ResNet-10, ResNet-18, ResNet-34) for MAML on few-shot image classification with different datasets (CUB 1-shot/5-shot, miniImageNet 1-shot/5-shot) [12].

We review prior art, which we group into the following three categories.

Benign overfitting analysis. The empirical success of overparameterized deep neural networks has inspired theoretical studies of overparameterized learning. The closest line of work is benign overfitting in linear regression [5], which bounds the excess risk, i.e., the difference between the expected population risk of the empirical solution and the optimal population risk, for the minimum-norm solution of an overparameterized linear regression model. It concludes that certain data covariance matrices lead to benign overfitting, explaining why overparameterized models that perfectly fit the noisy training data can work well during testing. The analysis has been extended to ridge regression [46], multi-class classification [50], and adversarial learning with linear models [8]. While previous theoretical efforts on benign overfitting largely focused on linear models, most recently the analysis has been extended to two-layer neural networks [7, 33, 22]. However, existing works mainly study benign overfitting for empirical risk minimization problems, rather than bilevel problems such as gradient-based meta learning, which is the focus of this work.

Meta learning. Early works of meta learning build black-box recurrent models that can make predictions based on a few examples from new tasks [43, 26, 2, 13], or learn shared feature representations among multiple tasks [44, 48].
More recently, meta learning approaches aim to find an initialization of model parameters that can quickly adapt to new tasks within a few optimization steps, such as MAML [21, 38, 41]. The empirical success of meta learning has also stimulated recent interest in building the theoretical foundation of meta learning methods.

Generalization of meta learning. The excess risk, as a metric of the generalization ability of gradient-based meta learning, has been analyzed recently [15, 3, 9, 49, 4, 19]. The generalization of meta learning has been studied in [32] in the context of mixed linear regression, where the focus is on investigating when abundant tasks with small data can compensate for the lack of tasks with big data. Generalization performance has also been studied in a relevant but different setting, representation-based meta learning [14, 17]. Information-theoretic bounds have been proposed in [30, 11], which bound the generalization error in terms of the mutual information between the input training data and the output of the meta learning algorithms. The PAC-Bayes framework has been extended to meta learning to provide a PAC-Bayes meta-population risk bound [1, 40, 16, 20]. These works mostly focus on the case where the meta learning model is underparameterized; that is, the total number of meta training data from all tasks is larger than the dimension of the model parameter. Recently, overparameterized meta learning has attracted much attention. Bernacchia [6] suggests that in overparameterized MAML, a negative learning rate in the inner loop is optimal during meta training for linear models with Gaussian data. Sun et al. [45] show that the optimal representation in representation-based meta learning is overparameterized, and provide sample complexity for the method-of-moments estimator.
Besides our work, a concurrent work [27] also studies a common setting where the meta learning models incur overparameterization at the meta level, and both works cover the nested MAML method. However, the two studies differ in how the empirical solution of the meta parameter is obtained. In our case, we consider the minimum ℓ2-norm solution, while [27] considers the solution trained with T-step stochastic gradient descent (SGD). Furthermore, our analysis covers both MAML and iMAML, while [27] only considers MAML.

Table 1: A comparison with closely related prior work on meta learning with linear models. Reps. and Gradient refer to representation-based methods and gradient-based methods; Per-task refers to per-task level overparameterization and Meta refers to meta level overparameterization.

| Prior work | Methods | Focus of analysis |
| --- | --- | --- |
| Bai et al. [3] | iMAML | Train-validation split |
| Bernacchia [6] | MAML | Optimal step size |
| Chen et al. [9] | MAML, BMAML | Test risk comparison |
| Huang et al. [28] | MAML | SGD solution |
| Kong et al. [32] | - | Effect of small data tasks |
| Saunshi et al. [42] | - | Train-validation split |
| Sun et al. [45] | - | Optimal representation |
| Ours | MAML, iMAML | Benign overfitting |

Compared to the most relevant works, our work differs in the following aspects. Compared to the works that also analyze generalization error or sample complexity in linear meta learning models, such as [15, 3, 9], we focus on the overparameterized case where the total number of training data is smaller than the dimension of the model parameter. Compared to the work that focuses on representation-based meta learning with a bilinear structure [45], we consider initialization-based meta learning methods with a bilevel structure such as MAML and iMAML. Furthermore, we provide a tight analysis of the excess risk with explicit consideration of the benign overfitting condition.
A summary of key differences compared to prior art is provided in Table 1. We distinguish two different overparameterization settings: i) per-task level overparameterization, where the dimension of the model parameter is larger than the number of training data per task, but smaller than the total number of data across all tasks; and ii) meta level overparameterization, where the dimension of the model parameter is larger than the total number of training data from all tasks.

1.2 This work

This paper provides a unifying analysis of the generalization performance of meta learning problems with overparameterized meta linear models. To the best of our knowledge, this is the first work that provides the condition for benign overfitting in gradient-based meta learning, including MAML and iMAML.

Technical challenges. Before we introduce the key result of our paper, we first highlight the challenges of analyzing the generalization of gradient-based meta learning and characterizing its benign overfitting condition, compared to non-bilevel settings such as in [5, 46, 45].

T1) Due to the bilevel structure of gradient-based meta learning, the solution to the meta training objective involves high-order terms of the data covariance. As a result, the dominating term in the excess risk propagated from the label noise contains higher-order terms, which is harder to quantify and can potentially lead to an orders-of-magnitude higher excess risk than in the linear regression case [5, 46, 45].

T2) The existing analyses of benign overfitting in single-level problems [5, 46] deal with a solution that is directly related to the data covariance matrix. However, due to the nested structure of gradient-based meta learning, the solution matrix is a function of both the data covariance matrix and the hyperparameters such as the step size. Therefore, it cannot be directly inferred which data matrices satisfy the benign overfitting condition.
T3) Due to the multi-task learning nature of meta learning, the excess risk of MAML depends on the heterogeneity across different tasks in terms of both the task data covariance and the ground truth task parameter. As a result, the data covariance matrices from different tasks have different eigenvectors. This is in contrast to the linear regression case where all the data follow the same distribution.

Contributions. In view of these challenges, our contributions can be summarized as follows.

C1) Focusing on the relatively tractable linear models, we derive the excess risk for the minimum-norm solution to overparameterized gradient-based meta learning, including MAML and iMAML. Specifically, the excess risk upper bound adopts the following form

Cross-task variance + Per-task variance + Bias

where the cross-task variance quantifies the error caused by the finite task number and the variation of the ground truth task-specific parameters, which is a unique term compared to single-task learning; the per-task variance quantifies the error caused by noise in the training data; and the bias quantifies the bias resulting from the minimum-norm solution.

C2) We compare the benign overfitting condition for the overparameterized gradient-based meta learning models with that for empirical risk minimization (ERM), which learns a single shared parameter for all tasks. We show that overfitting is more likely to happen in MAML and its variants, such as implicit MAML, than in ERM. In addition, larger data heterogeneity across tasks makes overfitting more likely to happen.

C3) We discuss the choice of hyperparameters, e.g., the step size in MAML and the weight of the regularizer in iMAML, such that if the data leads to benign overfitting in ERM, it also leads to benign overfitting in MAML and iMAML. We show that a negative step size can preserve benign overfitting in MAML.
This is complementary to the recent discovery that the optimal step size of overparameterized MAML during training is negative [6].

2 Problem Formulation and Methods

In this section, we introduce the problem setup and the considered meta learning methods.

Problem setup. In the meta learning setting, assume task $m$ is drawn from a task distribution, i.e., $m \sim \mathcal{M}$. For each task $m$, we observe $N$ samples with input feature $x_m \in \mathcal{X}_m \subseteq \mathbb{R}^d$ and target label $y_m \in \mathcal{Y}_m \subseteq \mathbb{R}$, drawn i.i.d. from a task-specific data distribution $P_m$. These samples are collected in the dataset $\mathcal{D}_m = \{(x_{m,n}, y_{m,n})\}_{n=1}^{N}$, which is divided into the train and validation datasets, denoted as $\mathcal{D}^{\mathrm{tr}}_m$ and $\mathcal{D}^{\mathrm{va}}_m$, with $|\mathcal{D}^{\mathrm{tr}}_m| = N_{\mathrm{tr}}$, $|\mathcal{D}^{\mathrm{va}}_m| = N_{\mathrm{va}}$ and $N = N_{\mathrm{tr}} + N_{\mathrm{va}}$. We use the empirical loss $\ell_m(\theta_m, \mathcal{D}_m)$ of the per-task parameter $\theta_m \in \Theta_m$ as a measure of performance. In this paper, we consider regression problems, where $\ell_m$ is defined as the mean squared error. The goal of gradient-based meta learning methods, such as MAML [21] and iMAML [39], is to learn an initial parameter $\theta_0 \in \Theta_0$, which, with an adaptation method $\mathcal{A} : \Theta_0 \times (\mathcal{X}_m \times \mathcal{Y}_m)^{N_{\mathrm{tr}}} \to \Theta_m$, can generate a per-task parameter $\theta_m$ that performs well on the validation data for task $m$. Given $M$ tasks, our meta learning objective is computed as the average of the per-task objectives, given by

Meta training objective: $L^{\mathcal{A}}(\theta_0, \mathcal{D}) := \frac{1}{M} \sum_{m=1}^{M} \ell_m\big(\mathcal{A}(\theta_0, \mathcal{D}^{\mathrm{tr}}_m), \mathcal{D}^{\mathrm{va}}_m\big).$  (1)

Obtaining the empirical solution $\hat\theta^{\mathcal{A}}_0$ by minimizing (1) under a meta learning method $\mathcal{A}$, in the meta testing stage we evaluate $\hat\theta^{\mathcal{A}}_0$ on the population risk, given by

Meta testing objective: $R^{\mathcal{A}}(\hat\theta^{\mathcal{A}}_0) := \mathbb{E}_m\big[\mathbb{E}_{\mathcal{D}_m}\, \ell_m\big(\mathcal{A}(\hat\theta^{\mathcal{A}}_0, \mathcal{D}^{\mathrm{tr}}_m), \mathcal{D}^{\mathrm{va}}_m\big)\big].$  (2)

Figure 2: Two types of meta learning.

Methods. We focus on understanding the generalization of two representative gradient-based meta learning methods, MAML [21] and iMAML [39], in the overparameterized regime.
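To ground the notation, the meta training objective (1) in the linear regression case can be sketched as follows. This is a minimal sketch with our own function and variable names, not the paper's code; the adaptation rule is passed in as a callable so the same evaluator serves different methods.

```python
import numpy as np

def meta_training_loss(theta0, tasks, adapt):
    """Meta training objective (1): average validation MSE of adapted parameters.

    tasks: list of (X_tr, y_tr, X_va, y_va) arrays, one tuple per task m.
    adapt: callable (theta0, X_tr, y_tr) -> per-task parameter theta_m,
           playing the role of the adaptation method A(theta0, D_tr).
    """
    losses = []
    for X_tr, y_tr, X_va, y_va in tasks:
        theta_m = adapt(theta0, X_tr, y_tr)                # task-specific parameter
        losses.append(np.mean((X_va @ theta_m - y_va) ** 2))  # MSE on validation split
    return np.mean(losses)                                 # average over the M tasks
```

With the identity adaptation `lambda t, X, y: t`, this reduces to the ERM objective with a single shared parameter; plugging in a one-gradient-step rule recovers the MAML training loss.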
MAML obtains the task-specific parameter $\hat\theta_m(\theta_0)$ by taking one gradient descent step with step size $\alpha$ on the per-task loss function $\ell_m$ from the initial parameter $\theta_0$, that is,

$\mathcal{A}(\theta_0, \mathcal{D}^{\mathrm{tr}}_m) = \theta_0 - \alpha \nabla_{\theta_0} \ell_m(\theta_0, \mathcal{D}^{\mathrm{tr}}_m).$  (3)

On the other hand, iMAML obtains the task-specific parameter $\hat\theta_m$ from the initial parameter $\theta_0$ by optimizing the task-specific loss regularized by the distance between $\hat\theta_m$ and $\theta_0$, that is,

$\mathcal{A}(\theta_0, \mathcal{D}^{\mathrm{tr}}_m) = \arg\min_{\theta}\; \ell_m(\theta, \mathcal{D}^{\mathrm{tr}}_m) + \frac{\gamma}{2} \|\theta - \theta_0\|^2$  (4)

where $\gamma > 0$ is the weight of the regularizer. As summarized in Figure 2, MAML has smaller computational complexity than iMAML since iMAML requires solving an inner problem during adaptation, while iMAML may achieve a smaller test error since it explicitly minimizes the loss.

3 Main Results: Benign Overfitting for Gradient-based Meta Learning

In this section, we introduce the meta linear regression model and some necessary assumptions for the analysis. We present the main results, highlight the key steps of the proof, and conduct simulations to verify our results. Due to space limitations, we defer the proofs to the supplementary document.

3.1 Meta linear regression setting

To enable a precise analysis, we assume the following linear data model. Denoting the ground truth parameter of task $m$ as $\theta^*_m \in \mathbb{R}^d$ and the noise as $\epsilon_m$, we assume the data model for task $m$ is

$y_m = \theta^{*\top}_m x_m + \epsilon_m.$  (5)

Given the linear model (5), the meta training problem (1) with adaptation method (3) or (4) generally has a unique solution when $d \le NM$. However, when the meta model $\theta_0$, and thus the per-task models $\theta_m$, are overparameterized, i.e., $d > NM$, the training problem (1) may have multiple solutions. In the subsequent analysis, we analyze the performance of the minimum-norm solution because recent advances in training overparameterized models reveal that gradient descent-based methods converge to the minimum-norm solution [24, 35]. We provide a formal definition below.

Definition 1 (Minimum norm solution). Denote $X^{\mathrm{va}}_m := [x_{m,1}, \ldots, x_{m,N_{\mathrm{va}}}]^\top \in \mathbb{R}^{N_{\mathrm{va}} \times d}$ and $y^{\mathrm{va}}_m := [y_{m,1}, \ldots, y_{m,N_{\mathrm{va}}}]^\top \in \mathbb{R}^{N_{\mathrm{va}}}$. With $\mathcal{A}(\theta, \mathcal{D}^{\mathrm{tr}}_m)$ being either (3) or (4), the minimum norm solution to the meta training problem (1) under the linear regression loss is expressed by

$\min_{\theta_0} \|\theta_0\|^2 \quad \text{s.t.} \quad \theta_0 \in \arg\min_{\theta} \sum_{m=1}^{M} \big\| X^{\mathrm{va}}_m \mathcal{A}(\theta, \mathcal{D}^{\mathrm{tr}}_m) - y^{\mathrm{va}}_m \big\|^2.$  (6)

In our analysis, we make the following basic assumptions.

Assumption 1 (Overparameterized model). The total number of meta training data is smaller than the dimension of the model parameter, i.e., $NM < d$.

Assumption 2 (Sub-Gaussian data). The noise $\epsilon_m$ is sub-Gaussian with $\mathbb{E}[\epsilon_m] = 0$ and $\mathbb{E}[\epsilon_m^2] = \sigma^2$. For the $m$-th task, the data satisfy $x_m = V_m \Lambda_m^{\frac{1}{2}} z_m$, where $z_m$ has centered, independent, $\sigma_x$-sub-Gaussian entries; $\mathbb{E}[z_m] = 0$, $\mathbb{E}[z_m z_m^\top] = I_d$, with $I_d$ being the $d \times d$ identity matrix.

Assumption 3 (Data covariance matrix). 1) Assume for all $m \in [M]$, $i \in [d]$, $\lambda_{m,i} > 0$, and $\mathrm{Tr}(\Lambda_m)$, $\mathrm{Tr}(\Lambda)$ are bounded, i.e., for all $m \in [M]$, $\mathrm{Tr}(\Lambda_m) \le c_\lambda$. 2) The cross-task data heterogeneity $V(\{Q_m\}_{m=1}^{M}) := \max_{i,m} |(\lambda_i - \lambda_{m,i})/\lambda_i|$ is bounded above and below.

Assumption 4 (Task parameter). The ground truth parameter $\theta^*_m$ is independent of $X_m$ and satisfies $\mathrm{Cov}[\theta^*_m] = (R^2/d) I_d$, where $R$ is a constant, and the entries of $\theta^*_m$ are i.i.d. $O(R/\sqrt{d})$-sub-Gaussian.

Assumption 1 defines the setting in which the meta level is overparameterized, which has also been used in [45]. Note that Assumptions 2-4 are common in the analysis of meta learning [15, 3, 9, 23]. With the linear data model (5), the (minimum norm) solutions to the meta training objective (1) and the meta testing objective (2) can be computed analytically, which we summarize next.

Proposition 1 (Empirical and population level solutions). Under the meta linear regression model (5), the meta testing objective of method $\mathcal{A}$ in (2) can be equivalently written as

$R^{\mathcal{A}}(\theta_0) = \mathbb{E}_m\big[ \|\theta_0 - \theta^*_m\|^2_{W^{\mathcal{A}}_m} \big]$  (7)

where the matrix $W^{\mathcal{A}}_m$ and its empirical version $\hat W^{\mathcal{A}}_m$ are given in Table 2, with $\hat Q^{\mathrm{al}}_m := \frac{1}{N} X^{\mathrm{al}\top}_m X^{\mathrm{al}}_m$ for $\mathrm{al} \in \{\mathrm{tr}, \mathrm{va}\}$.
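The weight matrices in Table 2 are plain matrix functions of each task covariance. The sketch below makes this concrete; the helper name and default hyperparameters are ours, not the paper's.

```python
import numpy as np

def weight_matrix(Q, method, alpha=0.1, gamma=1e3):
    """Population weight matrix W_m^A from Table 2 for a task covariance Q."""
    d = Q.shape[0]
    I = np.eye(d)
    if method == "erm":
        return Q
    if method == "maml":      # (I - alpha Q) Q (I - alpha Q)
        G = I - alpha * Q
        return G @ Q @ G
    if method == "imaml":     # (Q/gamma + I)^{-1} Q (Q/gamma + I)^{-1}
        G = np.linalg.inv(Q / gamma + I)
        return G @ Q @ G
    raise ValueError(f"unknown method: {method}")
```

For a diagonal $Q_m$, the eigenvalues transform as $\lambda \mapsto (1 - \alpha\lambda)^2 \lambda$ for MAML and $\lambda \mapsto \lambda/(1 + \lambda/\gamma)^2$ for iMAML, which is how the hyperparameters reshape the spectrum entering the excess risk analysis.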
The optimal solution to the meta-test risk and the minimum-norm solution to the empirical meta training loss are given below, respectively:

$\theta^{\mathcal{A}*}_0 := \arg\min_{\theta_0} R^{\mathcal{A}}(\theta_0) = \big(\mathbb{E}_m[W^{\mathcal{A}}_m]\big)^{-1} \mathbb{E}_m\big[W^{\mathcal{A}}_m \theta^*_m\big]$  (8a)

$\hat\theta^{\mathcal{A}}_0 := \arg\min_{\theta_0} L^{\mathcal{A}}(\theta_0, \mathcal{D}) = \Big(\sum_{m=1}^{M} \hat W^{\mathcal{A}}_m\Big)^{\dagger} \Big(\sum_{m=1}^{M} \hat W^{\mathcal{A}}_m \theta^*_m + \Delta^{\mathcal{A}}_M\Big)$  (8b)

where $\dagger$ denotes the Moore-Penrose pseudo-inverse, and $\Delta^{\mathcal{A}}_M$ is an error term that depends on $X_m$, $\epsilon_m$ and is specified in the supplementary document.

Table 2: Weight matrices under different methods $\mathcal{A}$.

| Method | $W^{\mathcal{A}}_m$ | $\hat W^{\mathcal{A}}_m$ |
| --- | --- | --- |
| ERM | $Q_m$ | $\hat Q_m$ |
| MAML | $(I - \alpha Q_m) Q_m (I - \alpha Q_m)$ | $(I - \alpha \hat Q^{\mathrm{tr}}_m) \hat Q^{\mathrm{va}}_m (I - \alpha \hat Q^{\mathrm{tr}}_m)$ |
| iMAML | $(\gamma^{-1} Q_m + I)^{-1} Q_m (\gamma^{-1} Q_m + I)^{-1}$ | $(\gamma^{-1} \hat Q^{\mathrm{tr}}_m + I)^{-1} \hat Q^{\mathrm{va}}_m (\gamma^{-1} \hat Q^{\mathrm{tr}}_m + I)^{-1}$ |

To study overfitting in the meta learning model, we quantify its generalization ability via the widely used metric of excess risk. The excess risk of method $\mathcal{A}$ (which can be ma for MAML and im for iMAML), with an empirical solution $\hat\theta^{\mathcal{A}}_0$ and population solution $\theta^{\mathcal{A}*}_0$, is defined as

$\mathcal{E}^{\mathcal{A}}(\hat\theta^{\mathcal{A}}_0) := R^{\mathcal{A}}(\hat\theta^{\mathcal{A}}_0) - R^{\mathcal{A}}(\theta^{\mathcal{A}*}_0).$  (9)

In (9), the excess risk measures the difference between the population risk of the empirical solution $\hat\theta^{\mathcal{A}}_0$ and the optimal population risk. Given the total number of training samples $MN$, as $d$ grows, classic learning theory implies that the excess risk $\mathcal{E}^{\mathcal{A}}(\hat\theta^{\mathcal{A}}_0)$ also grows, which leads to overfitting [25]. The larger the excess risk, the further the empirical solution $\hat\theta^{\mathcal{A}}_0$ is from the optimal population solution $\theta^{\mathcal{A}*}_0$, indicating more severe overfitting.

3.2 Main results

With the closed-form solutions given in Proposition 1, we are ready to bound the excess risk of MAML and iMAML in the overparameterized linear regime. For notational brevity, we first introduce some universal constants $c_0, c_1, c_2, \ldots$, and only present the dominating terms in the subsequent results. The precise presentation of the remaining terms is deferred to the supplementary document. We first decompose the excess risk into three terms in Proposition 2.

Proposition 2. Define $W^{\mathcal{A}} := \mathbb{E}_m[W^{\mathcal{A}}_m]$.
The excess risk of a meta learning method $\mathcal{A}$ can be bounded by

$\mathcal{E}^{\mathcal{A}}(\hat\theta^{\mathcal{A}}_0) \le \mathcal{E}_{\theta^*_m} + \mathcal{E}_{\epsilon_m} + \mathcal{E}_{b}$  (10)

where the first term $\mathcal{E}_{\theta^*_m}$ is a function of $\theta^*_m$, $\theta^{\mathcal{A}*}_0$, $W^{\mathcal{A}}$, $\hat W^{\mathcal{A}}_m$, which quantifies the weighted variance of the ground truth task-specific parameters $\theta^*_m$; the second term $\mathcal{E}_{\epsilon_m}$, as a function of $\epsilon_m$, is the weighted noise variance; and the third term $\mathcal{E}_{b}$, as a function of $\theta^{\mathcal{A}*}_0$, $W^{\mathcal{A}}$, $\hat W^{\mathcal{A}}_m$, is the bias of the minimum-norm solution in overparameterized MAML or iMAML.

Based on this decomposition, as we will show in Section 4, the bound on the excess risk can be derived from bounds on these three terms $\mathcal{E}_{\theta^*_m}$, $\mathcal{E}_{\epsilon_m}$, $\mathcal{E}_{b}$, respectively, which gives Theorem 1.

Theorem 1 (Excess risk bound). Suppose Assumptions 1-4 hold. Let $\mu_1(\cdot) \ge \mu_2(\cdot) \ge \ldots$ denote the eigenvalues of a matrix in descending order. For the meta linear regression problem with the minimum-norm solution (6), for $0 \le k \le d$, define the effective ranks as

$r_k(W^{\mathcal{A}}) := \frac{\sum_{i>k} \mu_i(W^{\mathcal{A}})}{\mu_{k+1}(W^{\mathcal{A}})}; \qquad R_k(W^{\mathcal{A}}) := \frac{\big(\sum_{i>k} \mu_i(W^{\mathcal{A}})\big)^2}{\sum_{i>k} \mu_i^2(W^{\mathcal{A}})}.$  (11)

With the cross-task data heterogeneity $V$ defined in Assumption 3, if there exist universal constants $c_1, c_2, c_3 > 1$ such that the effective dimension $k^* = \min\{k \ge 0 : r_k(W^{\mathcal{A}}) \ge c_1 NM\}$, $c_2 \log(1/\delta) < NM$ and $k^* < NM/c_3$, then with probability at least $1 - \delta$, the excess risk satisfies

$\mathcal{E}^{\mathcal{A}}(\hat\theta^{\mathcal{A}}_0) \lesssim \big\|\mathbb{E}[\theta^*_m]\big\|^2 \sqrt{\frac{\lambda\, r_0(W^{\mathcal{A}})}{MN}} + \Big(\frac{k^*}{MN} + \frac{MN}{R_{k^*}(W^{\mathcal{A}})}\Big)\Big(1 + V(\{W^{\mathcal{A}}_m\}_{m=1}^{M})\Big).$  (12)

Theorem 1 provides the excess risk bound via the effective ranks. In (11), the effective ranks $r_k$ and $R_k$ of a matrix capture the distribution of the eigenvalues of this matrix, and the effective dimension $k^*$ determines the above upper bound by accounting for the asymmetry of the eigenvalues of the solution matrix. The idea is to choose $k^*$ such that $R_{k^*}$ is large enough while $k^*$ stays small enough compared to $MN$, so that the variance term of the excess risk is controlled. For example, $r_0$ is the trace normalized by the largest eigenvalue, which is bounded above by $R_0$. Both $r_0$ and $R_0$ are no larger than the rank of the matrix, and they equal the rank only when all non-zero eigenvalues are equal.
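The effective ranks in (11) depend only on the spectrum of $W^{\mathcal{A}}$ and are easy to compute numerically; a short sketch (our own helper name), assuming $W$ is symmetric positive semi-definite:

```python
import numpy as np

def effective_ranks(W, k):
    """r_k and R_k from (11) for a symmetric PSD matrix W."""
    mu = np.sort(np.linalg.eigvalsh(W))[::-1]  # eigenvalues, descending
    tail = mu[k:]                              # mu_{k+1}, mu_{k+2}, ...
    r_k = tail.sum() / tail[0]                 # tail sum over mu_{k+1}
    R_k = tail.sum() ** 2 / (tail ** 2).sum()  # squared tail sum over tail of squares
    return r_k, R_k
```

For instance, both ranks equal the rank of the matrix when all non-zero eigenvalues are equal, e.g. $r_0(I_d) = R_0(I_d) = d$, consistent with the discussion above.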
If the eigenvalues are distributed more uniformly, the effective rank is larger; otherwise it is smaller.

Figure 3: Excess risk vs. number of samples (N) with different hyperparameters (M = 10, d = 200). (a) MAML with different α ∈ {0.01, 0.1, 0.2, 0.5}; (b) iMAML with different γ ∈ {0.01, 0.1, 10, 1000}; ERM is shown for comparison.

Remark 1. 1) The definition of effective rank has also been given in [5], but only for the data matrix $Q$. Our setting reduces to single-task ERM learning, or the linear regression case in [5], when $M = 1$, $\theta^*_m = \theta_0$, $W^{\mathcal{A}}_m = Q$, which implies that the cross-task variance in (10), as well as the data heterogeneity $V(\cdot)$, reduces to zero. Accordingly, Theorem 1 reduces to Theorem 4 in [5]. 2) Given Theorem 1, in order to control the excess risk of the solution $\hat\theta^{\mathcal{A}}_0$, we want $r_0(W^{\mathcal{A}})$ to be small compared to the total number of training samples $MN$, but $r_{k^*}(W^{\mathcal{A}})$ and $R_{k^*}(W^{\mathcal{A}})$ to be large compared to $MN$. In addition, the cross-task heterogeneity $V$ should be small. Since for a matrix $W$, $r_k(W) \le R_k(W) \le d$, this suggests the model benefits from overparameterization.

Building upon Theorem 1, we now discuss the conditions for benign overfitting, which refers to the situation where overparameterization does not harm the excess risk, i.e., the excess risk still vanishes when $d > MN$ and $N, M, d$ increase.

Definition 2 (Condition for benign overfitting in meta learning). The weight matrices $W^{\mathcal{A}}$ for method $\mathcal{A}$ satisfy the benign overfitting condition in gradient-based meta learning if and only if

$\lim_{NM, d \to \infty} \frac{r_0(W^{\mathcal{A}})}{NM} = \lim_{NM, d \to \infty} \frac{k^*}{NM} = \lim_{NM, d \to \infty} \frac{NM}{R_{k^*}(W^{\mathcal{A}})} = 0.$  (13)

This guarantees that the excess risk (12) goes to zero in overparameterized meta learning models with sufficient training data from all tasks. To provide an intuitive explanation, Figure 3 plots the population risk versus the number of training data, which demonstrates the double descent curve.
Namely, as $N$ increases, $\mathcal{E}(\hat\theta_0)$ first decreases, then increases, and then decreases again, as discovered in overparameterized neural networks [36]. The trend in Figure 3 is similar to the trend observed in [37]. When $d/(NM) > 1$, the model is overparameterized and can overfit the training data, leading to a larger excess risk as $N$ decreases. However, Figure 3 shows that the excess risk does not become too large as $N$ decreases, indicating that overfitting does not severely harm the population risk in this case.

3.3 Examples and discussion

In this section, we discuss how the benign overfitting condition (13) in gradient-based meta learning reduces to that in single-task linear regression, e.g., in [5, 46]. We also provide examples to show Q1) how certain properties of the meta training data affect the excess risk; and Q2) how to choose the hyperparameters that preserve benign overfitting.

Data covariance and cross-task heterogeneity. Theorem 1 reveals that the excess risk depends on both the eigenvalues of the data covariance matrix $Q_m$ and the cross-task data heterogeneity, measured by $V(\{Q_m\}_{m=1}^{M})$. We give an example below to better demonstrate how these two properties of the gradient-based meta training data affect the excess risk.

Example 1 (Data covariance). Suppose $Q_m = \mathrm{diag}(I_{d_1}, \beta I_{d - d_1})$ for all $m$. Set $M = 10$, $d = 200$, $d_1 = 20$, $\alpha = 0.1$ for MAML and $\gamma = 10^3$ for iMAML. Then the benign overfitting condition (13) is satisfied by MAML and iMAML. We plot the excess risk under different $\beta$ in Figure 4.

Figure 4: Excess risks vs. number of samples (N) for $Q_m = \mathrm{diag}(I_{d_1}, \beta I_{d - d_1})$ with different $\beta \in \{1, 0.3, 0.1, 0.03, 0.01\}$.
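As a small numerical companion to Example 1 (a sketch with our own helper name, not the paper's simulation code): for $Q_m = \mathrm{diag}(I_{d_1}, \beta I_{d-d_1})$, the MAML weight matrix in Table 2 is diagonal, so its effective ranks can be read off the transformed spectrum $(1 - \alpha\lambda)^2 \lambda$ directly.

```python
import numpy as np

def maml_effective_ranks(beta, d=200, d1=20, alpha=0.1):
    """r_0 and R_0 of W^ma for Q_m = diag(I_{d1}, beta * I_{d-d1})."""
    lam = np.concatenate([np.ones(d1), beta * np.ones(d - d1)])  # spectrum of Q_m
    mu = np.sort((1 - alpha * lam) ** 2 * lam)[::-1]  # MAML-transformed spectrum, descending
    return mu.sum() / mu[0], mu.sum() ** 2 / (mu ** 2).sum()

for beta in [1.0, 0.3, 0.1, 0.03, 0.01]:
    r0, R0 = maml_effective_ranks(beta)
    print(f"beta={beta:5.2f}  r0={r0:7.1f}  R0={R0:7.1f}")
```

With $\beta = 1$ the spectrum is flat and $r_0 = R_0 = d = 200$; as $\beta$ decreases, the tail eigenvalues shrink and both effective ranks drop, illustrating how the data spectrum controls the quantities that enter Theorem 1.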