# Task-Robust Model-Agnostic Meta-Learning

Liam Collins, ECE Department, University of Texas at Austin, Austin, TX 78712, liamc@utexas.edu

Aryan Mokhtari, ECE Department, University of Texas at Austin, Austin, TX 78712, mokhtari@austin.utexas.edu

Sanjay Shakkottai, ECE Department, University of Texas at Austin, Austin, TX 78712, sanjay.shakkottai@utexas.edu

Meta-learning methods have shown an impressive ability to train models that rapidly learn new tasks. However, these methods only aim to perform well in expectation over tasks drawn from a particular distribution, which is typically the same across meta-training and meta-testing. They do not consider worst-case task performance. In this work we introduce the notion of task-robustness by reformulating the popular Model-Agnostic Meta-Learning (MAML) objective [12] so that the goal is to minimize the maximum loss over the observed meta-training tasks. The solution to this novel formulation is task-robust in the sense that it places equal importance on even the most difficult and/or rare tasks. It therefore performs well over all distributions of the observed tasks, making it robust to shifts in the task distribution between meta-training and meta-testing. We present an algorithm to solve the proposed min-max problem. We show that it converges to an $\epsilon$-accurate point at the optimal rate of $O(1/\epsilon^2)$ in the convex setting, and to an $(\epsilon, \delta)$-stationary point at the rate of $O(\max\{1/\epsilon^5, 1/\delta^5\})$ in nonconvex settings. We also provide an upper bound on the new-task generalization error that captures the advantage of minimizing the worst-case task loss. We demonstrate this advantage in sinusoid regression and image classification experiments.

## 1 Introduction

Despite continual advances in computational power and data collection, many scenarios remain in which machine learning models must rapidly adapt to previously unseen tasks.
Motivated by such scenarios, meta-learning techniques aim to learn how to learn quickly from few samples by leveraging knowledge acquired while learning prior tasks [4, 35]. The recent successes of these techniques in areas such as few-shot learning [12, 31, 33, 37] and reinforcement learning [8, 34, 38] have sparked tremendous interest in meta-learning. Following the setting introduced in [3], most offline meta-learning methods try to minimize the expected loss on new tasks drawn from the same, but unknown, distribution as a finite set of meta-training tasks.

In gradient-based meta-learning, for example, the learning method is typically a small number of stochastic gradient descent (SGD) steps, and the means to learn quickly is a favorable initialization. Standard methods therefore try to find an initialization from which task-specific SGD fine-tuning performs well in expectation over new tasks. Because these methods assume new tasks come from the same unknown distribution as the meta-training tasks, during meta-training they minimize the average empirical loss after one step of SGD [12, 26].

However, by minimizing the average loss, such methods may perform arbitrarily poorly on difficult and/or rare meta-training tasks. In many cases we want a model that performs well across all tasks, including the most difficult and rare ones. Consider, for example, safety-critical applications such as object detection in self-driving cars, where failing to detect rarely seen objects may cause accidents. In this and similar settings, failing to produce accurate results on the worst-case task could cause severe harm.

34th Conference on Neural Information Processing Systems (NeurIPS 2020), Vancouver, Canada.
Moreover, by ignoring worst-case performance, existing methods rely on the often unrealistic assumption that meta-test tasks are drawn from the same distribution as meta-training tasks. Suppose the meta-training dataset overestimates how common certain types of tasks are in the meta-test distribution. Existing methods will then overfit to the popular tasks and fail to generalize to new tasks, both in expectation and in the worst case. Indeed, existing generalization bounds for gradient-based meta-learning depend on how similar the meta-test tasks are to the meta-training solution [41, 2]. They do not exploit the diversity of the meta-training tasks to show generalization to a broad range of new tasks.

To address these issues, we propose a novel meta-learning formulation that minimizes the maximum task loss during meta-training, rather than the average. Our contributions are threefold:

- We modify the standard gradient-based meta-learning framework, Model-Agnostic Meta-Learning (MAML) [12], to find an initialization that minimizes the loss after one SGD step on the worst-case task. Here tasks are broadly defined as distributions over few-shot learning problems. Our new formulation, Task-Robust MAML (TR-MAML), thus yields a "task-robust" solution: it prioritizes performance equally on all observed tasks, including the hardest and rarest ones. Importantly, this also makes it robust to any shift in the distribution over the observed tasks between meta-training and meta-testing.
- We present an algorithm to solve our min-max formulation and prove that it converges efficiently in both convex and nonconvex settings. In the convex case, it achieves the optimal rate of $O(1/\epsilon^2)$ stochastic gradient evaluations. In the nonconvex case, it reaches an $(\epsilon, \delta)$-stationary point after $O(\max\{1/\epsilon^5, 1/\delta^5\})$ stochastic gradient evaluations.
- We capture the generality of our formulation's task-robustness in two ways. We give a Rademacher complexity bound on the generalization error for any new task within the convex hull of the meta-training tasks. We also show improved performance over MAML in few-shot sinusoid regression and image classification experiments.

Related Work. Among the many meta-learning formulations, MAML [12] has become especially popular due to its efficiency and flexibility, inspiring many follow-up works [1, 22, 26, 5, 21]. On the theoretical side, [11] analyzed the convergence of MAML with nonconvex losses, and [30] and [41] presented MAML variants with guarantees in both convex and nonconvex settings. Other works have shown regret bounds for online analogues of MAML [13, 42, 18].

Several recent works have also studied robustness in meta-learning. In [43] and [40], the authors proposed models whose expected performance is robust to perturbations in the task samples. [20] extended MAML to handle imbalances in the number of samples per task instance and out-of-distribution meta-test tasks. However, their model requires a complicated dataset encoding and per-task balancing variables. Additionally, [15] introduced a heuristic that tries to prevent over-performing on certain meta-training tasks by regularizing the inequality among task losses, though only within mini-batches. [6] also considered a task-weighted objective and showed Rademacher complexity-based generalization bounds. Their weights, however, reflect task similarity to a particular target rather than optimizing for worst-case performance. To the best of our knowledge, no other offline meta-learning formulation attempts to minimize the worst-case loss over tasks.

Many works outside meta-learning have considered min-max optimization problems of the finite-sum form discussed here.
In distributionally robust optimization, [32] and [10] argued that minimizing the maximal loss over a set of possible distributions can generalize better than minimizing the average loss. In the convex setting, the stochastic mirror descent-ascent algorithm achieves the asymptotically optimal $O(\epsilon^{-2})$ convergence rate to an $\epsilon$-accurate solution [25]. The literature is less established for nonconvex problems. In [29], the authors proposed a stochastic inexact proximal point method that attains $O(\epsilon^{-6})$ convergence for the outer minimization problem when that problem is nonsmooth and weakly convex. In [28], $O(\epsilon^{-4})$ convergence was shown when the outer problem is smooth and strongly convex. In the deterministic case, the authors of [27] demonstrated an $O(\epsilon^{-3.5})$ convergence rate to an $\epsilon$-first-order Nash equilibrium for a gradient descent-ascent algorithm. Also, [7] and [16] analyzed first-order methods that improve on these rates but rely on an oracle to solve the inner maximization.

## 2 Problem Formulation

Before discussing our min-max objective, we first formalize the meta-learning scenario. Let $x \in \mathcal X$ and $y \in \mathcal Y$ denote inputs and labels, respectively, and let $h_w : \mathcal X \to \mathcal Y$ be the model parameterized by $w$. The performance of $h_w$ on a point $(x, y) \in \mathcal X \times \mathcal Y$ is measured by $\ell(h_w(x), y)$. Here $\ell : \mathcal Y \times \mathcal Y \to \mathbb R_+$ is a loss function, e.g., the mean squared error in regression or the cross-entropy loss in classification.

We define a task $\mathcal T_i$ as a distribution $\mathcal D_i$ over task instances. Task instances are few-shot learning episodes, each composed of two data batches $\mathcal D^{\text{train}}_{i,j}$ and $\mathcal D^{\text{test}}_{i,j}$ of $K$ and $J$ points, respectively, in $\mathcal X \times \mathcal Y$. Within each task instance, the learner's goal is to perform well on the points in $\mathcal D^{\text{test}}_{i,j}$ after learning from the points in $\mathcal D^{\text{train}}_{i,j}$. This is possible because each point in both batches is assumed to be an i.i.d. sample from the same distribution $\mathcal D_{i,j}$ over $\mathcal X \times \mathcal Y$.
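To make the episode structure concrete, here is a minimal Python sketch of a task instance: a pair of batches with $K$ and $J$ points drawn i.i.d. from a single distribution $\mathcal D_{i,j}$. The names `TaskInstance`, `sample_instance`, and `linear_point`, and the linear toy distribution, are hypothetical illustrations. They are not part of the paper or its code.

```python
import random
from dataclasses import dataclass
from typing import Callable, List, Tuple

Point = Tuple[float, float]  # one (x, y) pair in X x Y (scalars here, for illustration)


@dataclass
class TaskInstance:
    """One few-shot episode (D^train_{i,j}, D^test_{i,j})."""
    train: List[Point]  # K points the learner adapts on
    test: List[Point]   # J points used to evaluate the adapted model


def sample_instance(draw_point: Callable[[random.Random], Point],
                    K: int, J: int, rng: random.Random) -> TaskInstance:
    """Draw K + J i.i.d. points from one distribution D_{i,j} and split them."""
    points = [draw_point(rng) for _ in range(K + J)]
    return TaskInstance(train=points[:K], test=points[K:])


def linear_point(rng: random.Random) -> Point:
    """Hypothetical D_{i,j}: x uniform on [-1, 1], label y = 2x."""
    x = rng.uniform(-1.0, 1.0)
    return (x, 2.0 * x)


episode = sample_instance(linear_point, K=5, J=10, rng=random.Random(0))
```

Both batches come from the same `draw_point`, which mirrors the assumption that makes learning from $\mathcal D^{\text{train}}_{i,j}$ useful on $\mathcal D^{\text{test}}_{i,j}$.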
During meta-training, a finite number of task instances are observed. Each is obtained by first sampling a task $\mathcal T_i$ from $P(\mathcal T)$, the meta-training distribution over tasks, and then sampling $(\mathcal D^{\text{train}}_{i,j}, \mathcal D^{\text{test}}_{i,j}) \sim \mathcal D_i$. Let there be $m_i$ instances of the $i$-th task for each of the $n$ tasks observed during meta-training, for a total of $m := \sum_{i=1}^n m_i$ task instances.

In MAML, each task instance uses $\mathcal D^{\text{train}}_{i,j}$ to update a global initialization $w$ via one SGD step. The step is taken on the expected loss of the model on $\mathcal D_{i,j}$, namely $f_{i,j}(w) := \mathbb E_{(x,y)\sim\mathcal D_{i,j}}[\ell(h_w(x), y)]$. The resulting "test" loss is then approximated using $\mathcal D^{\text{test}}_{i,j}$, and this serves as the meta-training loss. The ultimate goal is to learn how to learn new task instances from the same distribution $P(\mathcal T)$. The meta-training objective is therefore to find a $w$ that minimizes the post-update loss on $\mathcal D^{\text{test}}_{i,j}$, averaged over the observed task instances:

$$\min_{w\in\mathcal W}\; \frac{1}{m}\sum_{i=1}^{n}\sum_{j=1}^{m_i} \hat f_{i,j}\big(w - \alpha\nabla\hat f_{i,j}(w, \mathcal D^{\text{train}}_{i,j}), \mathcal D^{\text{test}}_{i,j}\big). \tag{1}$$

Here $\alpha$ is the inner update step size, and $\hat f_{i,j}(w, \mathcal D^{\text{test}}_{i,j}) = \frac{1}{J}\sum_{(x,y)\in\mathcal D^{\text{test}}_{i,j}} \ell(h_w(x), y)$ is the sample-average approximation of $f_{i,j}(w)$ using the $J$ samples in $\mathcal D^{\text{test}}_{i,j}$. The term $\hat f_{i,j}(w, \mathcal D^{\text{train}}_{i,j})$ is defined likewise from the $K$ samples in $\mathcal D^{\text{train}}_{i,j}$.

As noted in the introduction, the solution of (1) may perform arbitrarily poorly on tasks that differ significantly from the average task instance. This is especially problematic if such tasks become more prevalent at meta-test time due to a distributional shift. We therefore propose to treat all $n$ meta-training tasks equally by minimizing the maximum task empirical average meta-loss $\hat F_i(w)$:

$$\min_{w\in\mathcal W}\;\max_{i\in[n]}\; \Big\{\hat F_i(w) := \frac{1}{m_i}\sum_{j=1}^{m_i} \hat f_{i,j}\big(w - \alpha\nabla\hat f_{i,j}(w, \mathcal D^{\text{train}}_{i,j}), \mathcal D^{\text{test}}_{i,j}\big)\Big\}. \tag{2}$$

Problem (2) is equivalent to finding the $w$ that minimizes the worst-case meta-learning performance over all distributions on the $n$ tasks. This holds because the worst-case distributions occur at the extreme points (vertices) of the probability simplex in $n$ dimensions.
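The difference between objectives (1) and (2) shows up even on a toy problem. The sketch below is for illustration only and is not the paper's implementation. It assumes one-dimensional quadratic batch losses $\hat f(w, \mathcal D) = \frac a2 (w - c)^2$, each summarized by a hypothetical pair $(a, c)$. The function names are also hypothetical.

```python
from typing import List, Tuple

Quad = Tuple[float, float]       # (curvature a, minimizer c) of 0.5 * a * (w - c)**2
Instance = Tuple[Quad, Quad]     # (train-batch loss, test-batch loss)


def quad_loss(q: Quad, w: float) -> float:
    a, c = q
    return 0.5 * a * (w - c) ** 2


def quad_grad(q: Quad, w: float) -> float:
    a, c = q
    return a * (w - c)


def meta_loss(inst: Instance, w: float, alpha: float) -> float:
    """f_hat(w - alpha * grad f_hat(w, D_train), D_test): loss after one inner SGD step."""
    train, test = inst
    return quad_loss(test, w - alpha * quad_grad(train, w))


def task_losses(tasks: List[List[Instance]], w: float, alpha: float) -> List[float]:
    """F_hat_i(w): average post-update loss over the m_i instances of task i."""
    return [sum(meta_loss(x, w, alpha) for x in t) / len(t) for t in tasks]


def maml_objective(tasks: List[List[Instance]], w: float, alpha: float) -> float:
    """Objective (1): average over all m instances, so task i has weight m_i / m."""
    m = sum(len(t) for t in tasks)
    return sum(meta_loss(x, w, alpha) for t in tasks for x in t) / m


def tr_maml_objective(tasks: List[List[Instance]], w: float, alpha: float) -> float:
    """Objective (2): the worst-case task loss max_i F_hat_i(w)."""
    return max(task_losses(tasks, w, alpha))


# A frequent task (3 instances) centred at 0 and a rare task (1 instance) centred at 4.
easy = [((1.0, 0.0), (1.0, 0.0))] * 3
rare = [((1.0, 4.0), (1.0, 4.0))]
tasks = [easy, rare]
print(maml_objective(tasks, 0.0, 0.5), tr_maml_objective(tasks, 0.0, 0.5))  # 0.5 2.0
```

At $w = 0$ the average objective looks small because the rare task carries only a quarter of the weight. The worst-case objective exposes that task's post-update loss of 2.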
Relaxing the maximum over $i \in [n]$ in (2) to a maximum over distributions $p$ on the $n$ tasks, we write the problem as

$$\min_{w\in\mathcal W}\;\max_{p\in\Delta_n}\;\Big\{\varphi(w, p) := \sum_{i=1}^n p_i\hat F_i(w)\Big\}. \tag{3}$$

Here $p_i$ is the probability assigned to task $i$, $p = (p_1, \dots, p_n)$ collects these probabilities, and $\Delta_n = \{p \in \mathbb R^n_+ \mid \sum_{i=1}^n p_i = 1\}$. Problem (3) may be hard to solve if $n$ is very large. In many applications $m$ is indeed very large, but $n$ need not be, because a task can be defined to encompass many similar task instances. We provide experiments for this case in Section 6.

By optimizing for worst-case performance, formulation (3) encourages a solution $w$ that performs similarly across all observed tasks. Any algorithm that solves (3) cannot disregard some tasks; it must try to perform reasonably well on all of them. Indeed, as observed in [9], the min-max formulation implicitly regularizes the variance of the losses. This naturally makes the solution robust to distributional shifts between meta-training and meta-testing. We prove its ability to generalize to new tasks in Section 5.

## 3 Algorithm

Taking inspiration from [25], we solve the meta-training problem (3) with a Euclidean version of the robust stochastic mirror-prox algorithm. Our method, termed TR-MAML and outlined in Algorithm 1, requires stochastic estimates of the gradients of $\varphi(w, p)$ from (3) with respect to $w$ and $p$. The full gradients, denoted $g_w(w, p)$ and $g_p(w, p)$, are

$$g_w(w,p) = \sum_{i=1}^n \frac{p_i}{m_i}\sum_{j=1}^{m_i}\big(I - \alpha\nabla^2\hat f_{i,j}(w, \mathcal D^{\text{train}}_{i,j})\big)\nabla\hat f_{i,j}\big(w - \alpha\nabla\hat f_{i,j}(w, \mathcal D^{\text{train}}_{i,j}), \mathcal D^{\text{test}}_{i,j}\big), \tag{4}$$

$$g_p(w,p) = \Big[\frac{1}{m_i}\sum_{j=1}^{m_i}\hat f_{i,j}\big(w - \alpha\nabla\hat f_{i,j}(w, \mathcal D^{\text{train}}_{i,j}), \mathcal D^{\text{test}}_{i,j}\big)\Big]_{1\le i\le n}. \tag{5}$$

Here $\nabla^2\hat f_{i,j}(w, \mathcal D^{\text{train}}_{i,j})$ is the sample-average approximation of $\nabla^2 f_{i,j}(w)$ from the $K$ samples in $\mathcal D^{\text{train}}_{i,j}$, and $[a_i]_{1\le i\le n}$ denotes the vector $[a_1, \dots, a_n] \in \mathbb R^n$.

Since $n$ and the $m_i$'s may be large, TR-MAML estimates the full gradients $g_w$ and $g_p$ on each iteration. It first uniformly and independently samples a set $\mathcal C$ of $C$ indices $\{i_k\}_{k=1}^C$ from $\{1, \dots, n\}$. For each $i_k \in \mathcal C$, it samples one index $j_k$ uniformly from $\{1, \dots, m_{i_k}\}$. It then estimates $g_w(w, p)$ and $g_p(w, p)$ from the data $\{(\mathcal D^{\text{train}}_{i_k,j_k}, \mathcal D^{\text{test}}_{i_k,j_k})\}_{k=1}^C$:

$$\hat g_w(w,p) = \frac{n}{C}\sum_{k=1}^{C} p_{i_k}\big(I - \alpha\nabla^2\hat f_{i_k,j_k}(w, \mathcal D^{\text{train}}_{i_k,j_k})\big)\nabla\hat f_{i_k,j_k}\big(w - \alpha\nabla\hat f_{i_k,j_k}(w, \mathcal D^{\text{train}}_{i_k,j_k}), \mathcal D^{\text{test}}_{i_k,j_k}\big), \tag{6}$$

$$\hat g_p(w,p) = \frac{n}{C}\sum_{k=1}^{C}\hat f_{i_k,j_k}\big(w - \alpha\nabla\hat f_{i_k,j_k}(w, \mathcal D^{\text{train}}_{i_k,j_k}), \mathcal D^{\text{test}}_{i_k,j_k}\big)\,e_{i_k}, \tag{7}$$

where $e_{i_k}$ is the $i_k$-th standard basis vector in $\mathbb R^n$. In Section 4 we show that $\hat g_w(w, p)$ and $\hat g_p(w, p)$ are unbiased and bound their second moments.

To solve (3), TR-MAML initializes $p_0 = [1/n]_{1\le i\le n}$ and $w_0 \in \mathcal W$ and then runs projected stochastic gradient descent-ascent. For iterations $t = 0$ to $T-1$, it computes

$$w_{t+1} = \Pi_{\mathcal W}\big(w_t - \eta^t_w\,\hat g_w(w_t, p_t)\big), \qquad p_{t+1} = \Pi_{\Delta_n}\big(p_t + \eta^t_p\,\hat g_p(w_t, p_t)\big), \tag{8}$$

where $\eta^t_w$ and $\eta^t_p$ are step sizes, $\Pi_{\mathcal W}(u) = \arg\min_{w\in\mathcal W}\|u - w\|_2$, and $\Pi_{\Delta_n}(q) = \arg\min_{p\in\Delta_n}\|p - q\|_2$. The projections are convex programs that standard techniques solve efficiently. In particular, since $\Delta_n$ is the full simplex, $\Pi_{\Delta_n}(q)$ can be computed in $O(n\log n)$ time [39].

As mentioned previously, tasks can be defined to exploit similarity among task instances so that $n$ is small. In that case, the $O(d^2)$ per-iteration cost of the Hessian estimates, shared by MAML and TR-MAML, dwarfs the extra cost of the simplex projection. TR-MAML then has effectively the same computational cost as MAML. Moreover, first-order MAML approximations [12, 26, 11] apply directly to TR-MAML and further reduce the computational burden.

After $T$ iterations, TR-MAML terminates in one of two ways:

- Case T1. If each $\hat F_i(w)$ is convex, TR-MAML outputs $w^c_T := \frac{1}{T}\sum_{t=1}^T w_t$ and $p^c_T := \frac{1}{T}\sum_{t=1}^T p_t$.
- Case T2. Otherwise, TR-MAML samples $\tau$ uniformly from $\{1, \dots, T\}$ and outputs $w^\tau_T := w_\tau$ and $p^\tau_T := p_\tau$.

## 4 Convergence Analysis

We next analyze the convergence of TR-MAML to a solution of (3).
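For a concrete picture of convergence to a solution of (3), the sketch below runs iteration (6) to (8) on a toy problem whose min-max solution is known. It is a minimal illustration under explicit assumptions, not the authors' code. The assumptions are: scalar $w$; $\mathcal W = \mathbb R$, so $\Pi_{\mathcal W}$ is the identity; hypothetical quadratic batch losses $\frac a2(w - c)^2$, for which the Hessian factor in (4) is the scalar $1 - \alpha a_{\text{train}}$; constant step sizes; and the Case T1 averaged output. The simplex projection is the standard sort-and-threshold method. We assume it is the kind of $O(n\log n)$ routine cited as [39], but the code is ours. All function names are hypothetical.

```python
import random
from typing import List, Tuple

Quad = Tuple[float, float]        # (a, c) for a batch loss 0.5 * a * (w - c)**2
Instance = Tuple[Quad, Quad]      # (train-batch loss, test-batch loss)


def meta_loss(inst: Instance, w: float, alpha: float) -> float:
    """Post-update test loss f_hat(w - alpha * grad f_hat(w, D_train), D_test)."""
    (a_tr, c_tr), (a_te, c_te) = inst
    w_adapted = w - alpha * a_tr * (w - c_tr)          # one inner SGD step
    return 0.5 * a_te * (w_adapted - c_te) ** 2


def meta_grad(inst: Instance, w: float, alpha: float) -> float:
    """Scalar case of one summand of (4): (1 - alpha * a_tr) * test gradient at w_adapted."""
    (a_tr, c_tr), (a_te, c_te) = inst
    w_adapted = w - alpha * a_tr * (w - c_tr)
    return (1.0 - alpha * a_tr) * a_te * (w_adapted - c_te)


def project_simplex(q: List[float]) -> List[float]:
    """Euclidean projection onto {p >= 0, sum(p) = 1} by sorting: O(n log n)."""
    u = sorted(q, reverse=True)
    cumsum, theta = 0.0, 0.0
    for k, u_k in enumerate(u, start=1):
        cumsum += u_k
        t = (cumsum - 1.0) / k
        if u_k - t > 0:            # the largest such k determines the threshold
            theta = t
    return [max(q_i - theta, 0.0) for q_i in q]


def tr_maml(tasks: List[List[Instance]], w0: float, alpha: float, eta_w: float,
            eta_p: float, T: int, C: int, seed: int = 0) -> Tuple[float, List[float]]:
    """Constant-step TR-MAML with W = R; returns the Case T1 averaged iterates."""
    rng = random.Random(seed)
    n = len(tasks)
    w, p = w0, [1.0 / n] * n
    w_sum, p_sum = 0.0, [0.0] * n
    for _ in range(T):
        g_w, g_p = 0.0, [0.0] * n
        for i in rng.sample(range(n), C):              # C unique task indices
            inst = rng.choice(tasks[i])                # one instance of task i, uniformly
            g_w += (n / C) * p[i] * meta_grad(inst, w, alpha)   # estimate (6)
            g_p[i] += (n / C) * meta_loss(inst, w, alpha)       # estimate (7)
        # Update (8): descent on w (no projection since W = R), projected ascent on p.
        w, p = w - eta_w * g_w, project_simplex([p_i + eta_p * g for p_i, g in zip(p, g_p)])
        w_sum += w
        p_sum = [s + p_i for s, p_i in zip(p_sum, p)]
    return w_sum / T, [s / T for s in p_sum]


# With alpha = 0.5 the two post-update task losses are w**2 / 8 and (w - 4)**2 / 8,
# so the solution of (3) is w = 2 with p = (1/2, 1/2).
tasks = [[((1.0, 0.0), (1.0, 0.0))], [((1.0, 4.0), (1.0, 4.0))]]
w_bar, p_bar = tr_maml(tasks, w0=0.0, alpha=0.5, eta_w=0.5, eta_p=0.5, T=2000, C=2)
print(round(w_bar, 2), [round(x, 2) for x in p_bar])
```

With $C = n$ and one instance per task, the run is deterministic. The individual iterates cycle around the saddle point, and averaging them (Case T1) is what brings the output close to $w = 2$.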
Convergence results for stochastic gradient-based algorithms typically assume access to unbiased stochastic gradients with bounded second moments [25, 29]. In our case, $\hat g_w$ and $\hat g_p$ are naturally unbiased. Bounding their second moments, however, requires modest assumptions on the functions $\hat f_{i,j}$, because of the nested structure of $\hat F_i$.

Assumption 1. For all $i \in [n]$ and $j \in [m_i]$, $\hat f_{i,j}(\cdot, \mathcal D^{\text{train}}_{i,j})$ and $\hat f_{i,j}(\cdot, \mathcal D^{\text{test}}_{i,j})$ are $\hat B$-bounded and $\hat L$-Lipschitz. Furthermore, $\lambda_{\min}\big(\nabla^2\hat f_{i,j}(w, \mathcal D^{\text{train}}_{i,j})\big) \ge -\hat H$ for all $w \in \mathcal W$.

With this assumption we can bound the second moments. All proofs are given in the appendix.

Lemma 1. Under Assumption 1, for all $w \in \mathcal W$ and $p \in \Delta_n$, the vectors $\hat g_w(w, p)$ and $\hat g_p(w, p)$ satisfy:
(i) unbiasedness: $\mathbb E[\hat g_w(w, p)] = g_w(w, p)$ and $\mathbb E[\hat g_p(w, p)] = g_p(w, p)$;
(ii) bounded second moments: $\mathbb E[\|\hat g_w(w, p)\|_2^2] \le n(1 + \alpha\hat H)^2\hat L^2$ and $\mathbb E[\|\hat g_p(w, p)\|_2^2] \le n(n + C + 1)\hat B^2/C =: \hat G_p^2$.

Algorithm 1: Task-Robust MAML (TR-MAML)

- Input: $m$ task instances of $n$ unique tasks; parameters $\alpha$, $\{\eta^t_w\}_t$, $\{\eta^t_p\}_t$, $T$, $C$.
- Initialize $p_1 = [\frac1n]_{1\le i\le n}$ and $w_1 \in \mathcal W$ arbitrarily.
- For $t = 0$ to $T - 1$:
  - Sample a batch $\mathcal C$ of $C$ unique task indices uniformly from $\{1, \dots, n\}$.
  - For each $i_k \in \mathcal C$, sample one task instance index $j_k$ uniformly from $\{1, \dots, m_{i_k}\}$.
  - Compute $\hat g_w(w_t, p_t)$ and $\hat g_p(w_t, p_t)$ using (6) and (7), respectively.
  - Update $w_{t+1}$ and $p_{t+1}$ as in (8).
- Output: see Cases T1 and T2.

Convex Setting. Our first convergence result holds when each $\hat F_i$ is convex. Convexity of each $f_{i,j}$ does not imply convexity of $\hat F_i$; consider the counterexample $f_{i,j}(w) = 1/w$ for $w \in \mathbb R_+ \setminus \{0\}$. In Lemma 2 we adapt a result from [13]. It shows that strong convexity of each $\hat f_{i,j}(\cdot, \mathcal D^{\text{test}}_{i,j})$ implies strong convexity of $\hat F_i$, under an additional assumption on each $\hat f_{i,j}(\cdot, \mathcal D^{\text{train}}_{i,j})$.

Assumption 2. For all $j \in [m_i]$, $\hat f_{i,j}(\cdot, \mathcal D^{\text{train}}_{i,j})$ is $\hat M$-smooth and $\hat\rho$-Hessian-Lipschitz.

Lemma 2 (adapted from [13], Theorem 1). Suppose $\alpha < 1/\hat M$ and Assumptions 1 and 2 hold.
If $\hat f_{i,j}(\cdot, \mathcal D^{\text{test}}_{i,j})$ is $\hat\mu$-strongly convex for all $j \in [m_i]$, then $\hat F_i$ is $\mu$-strongly convex with $\mu := \hat\mu(1 - \alpha\hat M)^2 - \alpha\hat L\hat\rho$.

The optimal rate for solving convex-concave stochastic min-max problems is $O(1/\epsilon^2)$ [25]. Here the rate is measured as the expected number of stochastic gradient computations needed to reach a duality gap of $\epsilon$. The duality gap of a pair $(\bar w, \bar p)$ is $\max_{p\in\Delta_n}\varphi(\bar w, p) - \min_{w\in\mathcal W}\varphi(w, \bar p)$. By strong duality, $(\bar w, \bar p)$ is optimal if and only if its duality gap is zero. We show that TR-MAML achieves the optimal $O(1/\epsilon^2)$ rate by adapting Theorem 2 from [24], which in turn simplifies Theorem 1 from [17].

Theorem 1 (adapted from [24], Theorem 2). Consider problem (3) when each $\hat F_i$ is convex and Assumption 1 holds. Suppose some ball of radius $R_{\mathcal W}$ contains $\mathcal W$. Use the step sizes $\eta_w = 2R_{\mathcal W}/\big((1 + \alpha\hat H)\hat L\sqrt{nT}\big)$ and $\eta_p = 2/(\hat G_p\sqrt T)$. Then the output of TR-MAML satisfies

$$\mathbb E\Big[\max_{p\in\Delta_n}\varphi(w^c_T, p) - \min_{w\in\mathcal W}\varphi(w, p^c_T)\Big] \le \frac{3\sqrt n\,R_{\mathcal W}(1 + \alpha\hat H)\hat L + 3\hat G_p}{\sqrt T}.$$

Thus, TR-MAML needs $T = O(1/\epsilon^2)$ iterations to reach an expected duality gap of at most $\epsilon$. Each iteration uses a constant number of stochastic oracle evaluations, so TR-MAML reaches an $\epsilon$-accurate solution with the optimal $O(1/\epsilon^2)$ stochastic oracle calls.

Nonconvex Setting. We next study the case where each $\hat F_i$ may be nonconvex, so $\varphi(w, p)$ may be nonconvex in $w$. We must then evaluate the pair $(w^\tau_T, p^\tau_T)$ returned by our algorithm differently in $p$ and in $w$. We still want $p^\tau_T \in \Delta_n$ to globally maximize $\varphi(w^\tau_T, \cdot)$, but we can only hope that $w^\tau_T$ is near a stationary point of $\varphi(\cdot, p^\tau_T)$. Thus, we say that $(\bar w, \bar p)$ is an $(\epsilon, \delta)$-stationary point of $\varphi$ if

$$\|\nabla_w\varphi(\bar w, \bar p)\|_2 \le \epsilon \quad\text{and}\quad \varphi(\bar w, \bar p) \ge \max_{p\in\Delta_n}\varphi(\bar w, p) - \delta, \tag{9}$$

where $\epsilon, \delta > 0$. This definition assumes $\mathcal W = \mathbb R^d$; otherwise we use the projected gradient, discussed later. In either case we will leverage smoothness.
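The two convergence criteria above, the duality gap and the $(\epsilon, \delta)$-stationarity conditions in (9), can both be evaluated in closed form for simple losses. The sketch below assumes hypothetical convex scalar task losses $\hat F_i(w) = \frac{a_i}{2}(w - c_i)^2$ with $a_i > 0$ and takes $\mathcal W = \mathbb R$. The maximization over $\Delta_n$ reduces to $\max_i \hat F_i(\bar w)$, because $\varphi(\bar w, \cdot)$ is linear in $p$. The minimization over $w$ of a weighted sum of quadratics is a weighted mean. Function names are hypothetical.

```python
from typing import List, Tuple


def task_losses(a: List[float], c: List[float], w: float) -> List[float]:
    """Convex scalar stand-ins F_i(w) = 0.5 * a_i * (w - c_i)**2 with a_i > 0."""
    return [0.5 * a_i * (w - c_i) ** 2 for a_i, c_i in zip(a, c)]


def duality_gap(a: List[float], c: List[float], w_bar: float, p_bar: List[float]) -> float:
    """max_p phi(w_bar, p) - min_w phi(w, p_bar), taking W = R."""
    # phi(w_bar, .) is linear in p, so its maximum over the simplex sits at a vertex.
    primal = max(task_losses(a, c, w_bar))
    # phi(., p_bar) is a quadratic minimized at the (p_i * a_i)-weighted mean of the c_i.
    weight = sum(p_i * a_i for p_i, a_i in zip(p_bar, a))
    w_star = sum(p_i * a_i * c_i for p_i, a_i, c_i in zip(p_bar, a, c)) / weight
    dual = sum(p_i * F_i for p_i, F_i in zip(p_bar, task_losses(a, c, w_star)))
    return primal - dual


def stationarity(a: List[float], c: List[float], w_bar: float,
                 p_bar: List[float]) -> Tuple[float, float]:
    """The two quantities that (9) bounds by epsilon and delta, respectively."""
    grad = sum(p_i * a_i * (w_bar - c_i) for p_i, a_i, c_i in zip(p_bar, a, c))
    F = task_losses(a, c, w_bar)
    phi = sum(p_i * F_i for p_i, F_i in zip(p_bar, F))
    return abs(grad), max(F) - phi


a, c = [1.0, 1.0], [0.0, 2.0]
print(duality_gap(a, c, 1.0, [0.5, 0.5]), stationarity(a, c, 1.0, [0.5, 0.5]))  # 0.0 (0.0, 0.0)
print(stationarity(a, c, 0.0, [1.0, 0.0]))                                      # (0.0, 2.0)
```

The second call shows why (9) needs both conditions. At $w = 0$ the gradient with respect to $w$ vanishes for $p = (1, 0)$, yet that $p$ is far from maximizing $\varphi(0, \cdot)$.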
The function we aim to minimize, $\max_{p\in\Delta_n}\varphi(w, p)$, is nonsmooth because of the maximization. We can again adapt a result from [13], however, to show that each $\hat F_i$ is smooth under the previous assumptions on each $\hat f_{i,j}$.

Lemma 3 (adapted from [13], Theorem 1). Under Assumptions 1 and 2, each $\hat F_i$ is $M$-smooth, where $M := \hat M(1 + \alpha\hat M)^2 + \alpha\hat L\hat\rho$.

We must also compute the expected squared deviation of the stochastic gradient $\hat g_w$, denoted by $\sigma_w^2$.

Lemma 4. For all $w \in \mathcal W$ and $p \in \Delta_n$,

$$\sigma_w^2(w, p) := \mathbb E\big[\|\hat g_w(w, p) - g_w(w, p)\|_2^2\big] = \frac{n}{C}\Big(\sigma^2(w, p) + \sum_{i=1}^n\sigma_i^2(w, p)\Big), \tag{10}$$

where
$\sigma^2(w, p) := \sum_{i=1}^n\big\|p_i\nabla\hat F_i(w) - \frac1n\sum_{i'=1}^n p_{i'}\nabla\hat F_{i'}(w)\big\|_2^2$ and
$\sigma_i^2(w, p) := \frac{p_i^2}{m_i}\sum_{j=1}^{m_i}\big\|(I - \alpha\nabla^2\hat f_{i,j}(w, \mathcal D^{\text{train}}_{i,j}))\nabla\hat f_{i,j}(w - \alpha\nabla\hat f_{i,j}(w, \mathcal D^{\text{train}}_{i,j}), \mathcal D^{\text{test}}_{i,j}) - \nabla\hat F_i(w)\big\|_2^2$.

Here $\sigma^2$ is the inter-task variance and each $\sigma_i^2$ is an intra-task variance. With $\sigma_w^2$ defined, the following theorem gives the convergence of TR-MAML when $\mathcal W = \mathbb R^d$.

Theorem 2. Suppose Assumptions 1 and 2 hold and $\mathcal W = \mathbb R^d$. Fix any $\beta \in (0, \frac12)$ such that $T^\beta > M/2$, and set $\eta^t_w = T^{-\beta}$ and $\eta^t_p = 2(T^{2\beta}\hat G_p)^{-1}$ for all $t = 1, \dots, T$. Then the output of Algorithm 1 satisfies

$$\mathbb E\big[\|\nabla_w\varphi(w^\tau_T, p^\tau_T)\|_2^2\big] \le \frac{\varphi(w_1, p_1) + \hat B + 2n\hat B + 2M\sigma_w^2}{T^\beta - M/2},$$

$$\mathbb E\big[\varphi(w^\tau_T, p^\tau_T)\big] \ge \max_{p\in\Delta_n}\big\{\mathbb E[\varphi(w^\tau_T, p)]\big\} - \frac{\hat G_p}{\sqrt2\,T^{\min\{2\beta, 1-2\beta\}}}.$$

Theorem 2 shows that, in the unconstrained setting, Algorithm 1 converges in expectation to an $(\epsilon, \delta)$-stationary point of $\varphi$ within $O(\max\{1/\epsilon^{2/\beta}, 1/\delta^{1/\min\{2\beta, 1-2\beta\}}\})$ stochastic gradient evaluations. The parameter $\beta$ can be tuned to favor convergence with respect to either $w$ or $p$. To treat the two equally, the optimal setting is $\beta = \frac25$, which gives a rate of $O(\max\{1/\epsilon^5, 1/\delta^5\})$. We finally consider the case when $\mathcal W$ is a compact, convex set.
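Equation (10) in Lemma 4 can be checked exactly on a small example. The sketch below enumerates every batch of $C$ draws and compares the exact variance with the right-hand side of (10). It assumes the $C$ task indices are drawn independently with replacement, following the "uniformly and independently" description in Section 3. Algorithm 1's unique-index sampling would add a finite-population correction. The scalar per-instance meta-gradients $g_{i,j}$ are hypothetical numbers that stand in for the vectors $(I - \alpha\nabla^2\hat f_{i,j})\nabla\hat f_{i,j}(\cdot)$ inside $\sigma_i^2$. Exact rational arithmetic avoids rounding.

```python
from fractions import Fraction as Fr
from itertools import product
from typing import List


def exact_variance(g: List[List[Fr]], p: List[Fr], C: int) -> Fr:
    """E[(g_hat_w - g_w)^2] by enumerating every with-replacement batch of C draws."""
    n = len(g)
    grad_F = [sum(g_i, Fr(0)) / len(g_i) for g_i in g]           # grad F_hat_i
    g_w = sum((p_i * gF for p_i, gF in zip(p, grad_F)), Fr(0))   # full gradient (4)
    draws = [(i, j, Fr(1, n * len(g[i]))) for i in range(n) for j in range(len(g[i]))]
    var = Fr(0)
    for batch in product(draws, repeat=C):
        prob, est = Fr(1), Fr(0)
        for i, j, pr in batch:
            prob *= pr
            est += Fr(n, C) * p[i] * g[i][j]                       # estimator (6)
        var += prob * (est - g_w) ** 2
    return var


def lemma4_variance(g: List[List[Fr]], p: List[Fr], C: int) -> Fr:
    """Right-hand side of (10): (n / C) * (sigma^2 + sum_i sigma_i^2)."""
    n = len(g)
    grad_F = [sum(g_i, Fr(0)) / len(g_i) for g_i in g]
    mean = sum((p_i * gF for p_i, gF in zip(p, grad_F)), Fr(0)) / n
    sigma2 = sum(((p_i * gF - mean) ** 2 for p_i, gF in zip(p, grad_F)), Fr(0))
    sigma2_i = [p[i] ** 2 / len(g[i]) * sum(((x - grad_F[i]) ** 2 for x in g[i]), Fr(0))
                for i in range(n)]
    return Fr(n, C) * (sigma2 + sum(sigma2_i, Fr(0)))


g = [[Fr(1), Fr(3)], [Fr(-2)], [Fr(0), Fr(5), Fr(1)]]   # hypothetical g_{i,j}; m = (2, 1, 3)
p = [Fr(1, 2), Fr(1, 3), Fr(1, 6)]
for C in (1, 2, 3):
    assert exact_variance(g, p, C) == lemma4_variance(g, p, C)
```

The $n/C$ factor is why Theorem 3 below can drive the variance term down by growing the task batch size $C$ with $T$.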
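When $\mathcal W$ is a compact, convex set, the notion of an $(\epsilon, \delta)$-stationary point must change so that $\epsilon$ upper bounds the projected gradient $\mathcal G_w$. It is defined as $\mathcal G_w(w_t, p_t) := \frac{1}{\eta^t_w}\big(w_t - \Pi_{\mathcal W}(w_t - \eta^t_w\hat g_w(w_t, p_t))\big)$. This vector reveals how much the solution can be improved by moving within the feasible set. In the following theorem, we choose $C$ as a function of $T$ to show convergence.

Theorem 3. Suppose Assumptions 1 and 2 hold. Let $\bar\sigma_w^2 := C\sigma_w^2$. For any $\beta \in (0, 1)$, set $\eta^t_w = 1/(2M)$ and $\eta^t_p = (T^\beta\hat B\sqrt n)^{-1}$ for $t \in [T]$, and set the task batch size to $C = T^\beta$. Then

$$\mathbb E\big[\|\mathcal G_w(w^\tau_T, p^\tau_T)\|_2^2\big] \le \frac{8M(\varphi(w_1, p_1) + \hat B)}{3T} + \frac{8M\hat B\sqrt n + 4\bar\sigma_w^2}{3T^\beta},$$

$$\mathbb E\big[\varphi(w^\tau_T, p^\tau_T)\big] \ge \max_{p\in\Delta_n}\big\{\mathbb E[\varphi(w^\tau_T, p)]\big\} - \frac{\hat B\sqrt n}{T^{\min\{\beta, 1-\beta\}}}.$$

The number of stochastic gradient evaluations is now $O(CT) = O(T^{1+\beta})$. Theorem 3 therefore shows that, with convex, compact $\mathcal W$ and nonconvex $\hat F_i$, Algorithm 1 reaches an $(\epsilon, \delta)$-stationary point after $O(\max\{1/\epsilon^{(2+2\beta)/\beta}, 1/\delta^{(1+\beta)/\min\{\beta, 1-\beta\}}\})$ evaluations. Setting $\beta = \frac23$ treats convergence with respect to $w$ and $p$ equally and gives a complexity of $O(\max\{1/\epsilon^5, 1/\delta^5\})$ evaluations.

## 5 Generalization Bounds

The meta-learner sees only a finite number of task instances during meta-training, so two types of generalization matter: to new instances of previously seen tasks, and to new tasks. We start by bounding the error on new instances of previously seen tasks.

Each task's distribution $\mathcal D_i$ is over $\mathcal Z := (\mathcal X \times \mathcal Y)^{K+J}$. For some loss $\ell$, define the family of functions $\mathcal F(\mathcal Z) := \mathcal F := \{\hat f(w - \alpha\nabla\hat f(w, \mathcal D^{\text{train}}), \mathcal D^{\text{test}}) : w \in \mathcal W\}$. Here $(\mathcal D^{\text{train}}, \mathcal D^{\text{test}}) \in \mathcal Z$, and $\hat f(w, \mathcal D)$ is the average loss of $w$ on the samples in $\mathcal D$. Let $\mathbf D := \{(\mathcal D^{\text{train}}_j, \mathcal D^{\text{test}}_j)\}_{j=1}^{m_i}$ be $m_i$ samples drawn i.i.d. from $\mathcal D_i$. The Rademacher complexity of $\mathcal F$ on these samples is

$$\mathcal R^i_{m_i}(\mathcal F) = \mathbb E_{\mathbf D\sim(\mathcal D_i)^{m_i}}\,\mathbb E_{\epsilon}\Big[\sup_{w\in\mathcal W}\frac{1}{m_i}\sum_{j=1}^{m_i}\epsilon_j\,\hat f_{i,j}\big(w - \alpha\nabla\hat f_{i,j}(w, \mathcal D^{\text{train}}_{i,j}), \mathcal D^{\text{test}}_{i,j}\big)\Big], \tag{11}$$

where the $\epsilon_j$'s are Rademacher random variables. Recall that the empirical loss of the model $w$ on the $i$-th task is $\hat F_i(w)$, defined in (2).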
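The expectation over signs in (11) can be computed exactly when $\mathcal W$ is replaced by a finite set of candidate initializations and $m_i$ is small. The sketch below assumes a hypothetical precomputed table `losses[w][j]`: the post-update loss of each candidate $w$ on each observed instance $j$. It computes the data-conditional (empirical) Rademacher complexity; averaging this quantity over draws of $\mathbf D$ gives (11). Enumerating all $2^{m_i}$ sign vectors is only feasible for tiny $m_i$. In practice one would sample the signs instead.

```python
from itertools import product
from typing import Sequence


def empirical_rademacher(losses: Sequence[Sequence[float]]) -> float:
    """E_eps[ max_w (1/m) sum_j eps_j * losses[w][j] ], averaged over all 2^m sign vectors.

    losses[w][j] is the post-update loss of candidate initialization w on instance j.
    """
    m = len(losses[0])
    total = 0.0
    for signs in product((-1.0, 1.0), repeat=m):
        total += max(sum(s * l for s, l in zip(signs, row)) / m for row in losses)
    return total / 2 ** m


# A single candidate cannot correlate with random signs; a richer class partly can.
print(empirical_rademacher([[1.0, 2.0, 3.0]]))
print(empirical_rademacher([[1.0, 0.0], [0.0, 1.0]]))   # 0.25
```

Adding candidates can only increase the value. This is the sense in which (11) measures the richness of the class of post-update losses.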
A standard Rademacher complexity bound controls the analogous expected loss $F_i(w)$ with high probability over the choice of task instances.

Figure 1: Meta-training and meta-test task MSE statistics vs. the number of meta-training iterations for $K = 5$, with 95% confidence intervals shaded over 5 trials. The rightmost plot shows, for a sample trial, the number of meta-test tasks whose average MSE falls within particular intervals. TR-MAML outperforms MAML on the worst-case regression task and performs more uniformly across all tasks.

Proposition 1. Suppose Assumption 1 holds. Then for any $\delta \in (0, 1)$, with probability at least $1 - \delta$,

$$F_i(w) := \mathbb E_{(\mathcal D^{\text{train}}_{i,j}, \mathcal D^{\text{test}}_{i,j})\sim\mathcal D_i}\big[\hat f_{i,j}(w - \alpha\nabla\hat f_{i,j}(w, \mathcal D^{\text{train}}_{i,j}), \mathcal D^{\text{test}}_{i,j})\big] \le \hat F_i(w) + 2\mathcal R^i_{m_i}(\mathcal F) + \hat B\sqrt{\frac{\log(1/\delta)}{2m_i}}.$$

Next, let $w^*$ be the optimal solution of the TR-MAML meta-training objective (3). Suppose a new task has distribution $\mathcal D_{n+1} = \sum_{i=1}^n a_i\mathcal D_i$ for some $a \in \Delta_n$. By linearity of expectation, the loss $F_{n+1}(w^*)$ is then a convex combination of the losses on the meta-training tasks. This yields the following result.

Theorem 4. Consider a new task with distribution $\mathcal D_{n+1} = \sum_{i=1}^n a_i\mathcal D_i$, where $a \in \Delta_n$. For any $\delta \in (0, 1)$, with probability at least $1 - \delta$,

$$F_{n+1}(w^*) \le \min_{w\in\mathcal W}\max_{p\in\Delta_n}\sum_{i=1}^n p_i\hat F_i(w) + \sum_{i=1}^n\Big(2a_i\mathcal R^i_{m_i}(\mathcal F) + a_i\hat B\sqrt{\frac{\log(n/\delta)}{2m_i}}\Big).$$

Theorem 4 shows that the min-max meta-training solution uses the diversity of the meta-training tasks to generalize across their full convex hull, not just a local neighborhood of the solution.

## 6 Experimental Results

Experimental Setup: Our experiments test whether minimizing the maximum task loss during meta-training yields a more task-robust solution than MAML, in few-shot sinusoid regression and image classification. Recall that our setting consists of a collection of tasks, each with a number of task instances. In practice, the datasets we consider can define an exceedingly large number of tasks, which would make our experiments computationally infeasible.
For instance, consider few-shot image classification on the Omniglot dataset, which consists of images of characters from various alphabets. Suppose we want to do 5-way classification, so each few-shot classification problem contains images from 5 classes (characters). In this setting, a task is a set of 5 particular classes (characters). Omniglot has 1200 meta-training characters, which would give $\binom{1200}{5}$ (around $2 \times 10^{13}$) meta-training tasks. For computational tractability, we therefore reduce the number of tasks by clustering in a problem-dependent manner, as detailed below. This leaves between a few and a hundred tasks.

Sinusoid Regression. In the popular sinusoid regression experiment [12], each task instance is a sinusoid regression problem. The target is a sine function on $[-5, 5] \subset \mathbb R$ with amplitude $a \in [0.1, 5]$ and phase $b \in [0, 2\pi]$. The learner has $K$ samples $\{(x_i, a\sin(x_i - b))\}_{i=1}^K$, with each $x_i$ sampled uniformly from $[-5, 5]$. It tries to find a function that closely approximates $a\sin(x - b)$ in terms of mean squared error (MSE). Typically the meta-training and meta-testing distributions are identical: amplitudes are drawn uniformly from $[0.1, 5]$ and phases uniformly from $[0, 2\pi]$.

Here we introduce a distributional shift between meta-training and meta-testing. Meta-training has many easy task instances and few hard ones, and the resulting initialization is evaluated on all tasks in the space. In particular, only sine functions with amplitudes in $[0.1, 1.05]$ (easy tasks) or $[4.95, 5]$ (hard tasks), and any phase, are available for meta-training. Sinusoids with larger amplitudes are harder targets because they are less smooth and have larger magnitudes. Poor approximations of them are therefore generally penalized more severely in terms of MSE.
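The shifted meta-training distribution can be sampled as follows. This is a minimal sketch under the setup just described: amplitude uniform on $[0.1, 1.05] \cup [4.95, 5]$, phase uniform on $[0, 2\pi]$, and inputs uniform on $[-5, 5]$. The function names are hypothetical, not taken from the released code. The two amplitude intervals have total length $0.95 + 0.05 = 1$, so a single uniform draw on $[0, 1)$ can be mapped directly onto their union.

```python
import math
import random
from typing import List, Tuple

EASY = (0.1, 1.05)   # amplitudes of the easy meta-training tasks
HARD = (4.95, 5.0)   # amplitudes of the hard meta-training tasks


def sample_amplitude(rng: random.Random) -> float:
    """Uniform over EASY ∪ HARD, whose lengths 0.95 and 0.05 sum to 1."""
    u = rng.random()                                   # u in [0, 1)
    return EASY[0] + u if u < 0.95 else HARD[0] + (u - 0.95)


def sample_sinusoid_instance(rng: random.Random, K: int, J: int):
    """One task instance: amplitude, phase, K train and J test points of y = a * sin(x - b)."""
    a = sample_amplitude(rng)
    b = rng.uniform(0.0, 2.0 * math.pi)

    def batch(size: int) -> List[Tuple[float, float]]:
        xs = [rng.uniform(-5.0, 5.0) for _ in range(size)]
        return [(x, a * math.sin(x - b)) for x in xs]

    return a, b, batch(K), batch(J)


rng = random.Random(0)
amps = [sample_amplitude(rng) for _ in range(20000)]
hard_fraction = sum(x >= HARD[0] for x in amps) / len(amps)   # close to 0.05
```

Only about 5% of sampled instances are hard, which is exactly the imbalance that an average-case objective can neglect.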
Empirically, we find that the phase has little effect on the hardness of a sinusoid target.

Table 1: Sinusoid regression results: MSE statistics across the 490 meta-test tasks, with 95% confidence intervals over 5 random trials.

| K | Algorithm | Mean | Worst | Std. Dev. |
|---|---|---|---|---|
| 5 | MAML | 1.02 ± 0.10 | 3.89 ± 0.83 | 0.88 ± 0.14 |
| 5 | TR-MAML | 1.09 ± 0.08 | 2.82 ± 0.35 | 0.43 ± 0.03 |
| 10 | MAML | 0.66 ± 0.16 | 2.57 ± 0.70 | 0.54 ± 0.13 |
| 10 | TR-MAML | 0.77 ± 0.11 | 1.68 ± 0.43 | 0.25 ± 0.08 |

Table 2: Omniglot N-way, K-shot classification accuracies (%), with 95% confidence intervals over 3 random trials.

| (N, K) | Algorithm | Meta-training alphabets: Weighted Mean | Meta-training alphabets: Worst | Meta-testing alphabets: Weighted Mean | Meta-testing alphabets: Worst | Meta-testing alphabets: Std. Dev. |
|---|---|---|---|---|---|---|
| (5, 1) | MAML | 98.4 ± .2 | 82.4 ± 1.1 | 93.5 ± .2 | 82.5 ± .2 | 3.84 ± .1 |
| (5, 1) | TR-MAML | 97.4 ± .6 | 95.0 ± 0.3 | 93.1 ± 1.1 | 85.3 ± 1.9 | 3.50 ± .3 |
| (20, 1) | MAML | 99.2 ± .1 | 33.9 ± 3.0 | 67.6 ± 2.0 | 49.7 ± 3.5 | 9.10 ± .1 |
| (20, 1) | TR-MAML | 92.2 ± .8 | 82.4 ± 2.1 | 74.3 ± 1.4 | 58.4 ± 1.8 | 8.70 ± .5 |

We partition $[0.1, 5]$ into 490 disjoint subintervals of length 0.01. A task is the uniform distribution over all task instances whose target amplitude lies in a particular subinterval, so there are 95 easy and 5 hard meta-training tasks. We assume each task has the same number of instances available. Both MAML and TR-MAML therefore sample phases uniformly from $[0, 2\pi]$ and amplitudes uniformly from $[0.1, 1.05] \cup [4.95, 5]$. The meta-test distribution is uniform over the full space of amplitudes and phases. Both algorithms use one SGD step as the inner learning algorithm and the same fully connected network architecture as [12] for the learning model.

Figure 1 shows the convergence trajectories of MAML and TR-MAML when $K = 5$. Each plot is produced by estimating the current model's MSE on each task. We sample 5,000 task instances across all 100 meta-training tasks (an average of 50 instances per task), and separately 5,000 across all 490 meta-testing tasks.
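The task definition and the per-task statistics can be sketched together. Amplitudes are bucketed into 490 bins of width 0.01, the MSE of sampled instances is averaged within each task, and the mean, worst, and standard deviation are then taken across tasks. Each task therefore counts once, however many instances it received. This is a hedged sketch with hypothetical helper names. The paper does not say whether it uses the population or the sample standard deviation; we assume the population version. The small epsilon guards against floating-point error at bin edges.

```python
import math
import statistics
from collections import defaultdict
from typing import Dict, Iterable, List, Tuple

LO, WIDTH, N_TASKS = 0.1, 0.01, 490
N_EASY = round((1.05 - LO) / WIDTH)        # 95 easy tasks
HARD_START = round((4.95 - LO) / WIDTH)    # tasks 485..489 are the 5 hard ones


def task_index(amplitude: float) -> int:
    """k such that the amplitude lies in [0.1 + 0.01k, 0.1 + 0.01(k + 1))."""
    k = math.floor((amplitude - LO) / WIDTH + 1e-9)   # epsilon guards bin edges
    return min(k, N_TASKS - 1)                         # amplitude 5.0 joins the last task


def is_meta_training_task(k: int) -> bool:
    return k < N_EASY or k >= HARD_START


def task_statistics(results: Iterable[Tuple[float, float]]) -> Tuple[float, float, float]:
    """(amplitude, mse) per sampled instance -> (mean, worst, std) of per-task MSEs."""
    per_task: Dict[int, List[float]] = defaultdict(list)
    for amplitude, mse in results:
        per_task[task_index(amplitude)].append(mse)
    task_mse = [statistics.fmean(v) for v in per_task.values()]
    return statistics.fmean(task_mse), max(task_mse), statistics.pstdev(task_mse)


train_tasks = [k for k in range(N_TASKS) if is_meta_training_task(k)]
print(len(train_tasks))                                              # 100
print(task_statistics([(0.105, 1.0), (0.107, 3.0), (4.999, 4.0)]))   # (3.0, 4.0, 1.0)
```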
The leftmost plot of Figure 1 shows the average and maximum of the estimated MSEs over the 100 meta-training tasks vs. the number of iterations. The middle-left plot shows the same statistics over the 490 meta-testing tasks. During meta-training, TR-MAML sacrifices average-case for worst-case task performance. Its focus on task-robustness, however, yields more uniform performance across all tasks. After the distribution shift, TR-MAML therefore outperforms MAML on the hardest meta-test tasks while nearly matching MAML's average performance.

Two further panels of Figure 1 capture TR-MAML's more uniform performance for $K = 5$. The middle-right plot shows the standard deviation of the meta-testing task MSEs vs. the number of iterations. The rightmost plot is a histogram of the average MSEs of the 490 meta-test tasks. Table 1 tells a similar story for $K \in \{5, 10\}$. It gives the average, maximum, and standard deviation of the MSEs among the 490 meta-test tasks after full meta-training, again estimated from 5,000 task instances.

Image Classification. In few-shot image classification, task instances are N-way, K-shot classification problems. Here $N$ is the number of classes and $K$ is the number of labeled samples per class available to the learner. After updating the model on these $NK$ samples, the model is evaluated on $J$ samples from each class. As discussed earlier, standard few-shot image classification experiments sometimes treat each individual N-way, K-shot problem as a "task", which would give an intractably large number of tasks in our setting. Instead, we use a more practical definition: a task is a set of N-way, K-shot classification problems (task instances) that share similar properties. For example, in the Omniglot experiment discussed below, all $N$ characters belong to the same alphabet.
Thus, a task instance is an individual N-way, K-shot classification problem, which is what other works call a task. We run experiments in this setting on the Omniglot [19] and mini-ImageNet [37] datasets. For both datasets, we use the corresponding 4-layer CNN from the original MAML paper [12].

Omniglot contains 1623 handwritten characters from 50 alphabets, with 20 examples per character. To obtain a tractable number of tasks, we define each task as the uniform distribution over all task instances composed of characters from one particular alphabet. As a result, we sample the same fine-grained task instances, with all characters from one alphabet, that [36] recommends for evaluating meta-learning models on Omniglot. We use the same (meta-)train/validation/test splits as [36]. There are $n = 25$ alphabets, i.e., tasks, for meta-training and 20 for meta-testing.

Suppose the $i$-th alphabet has $Z_i$ characters. Every character has the same number of samples, so the number of task instances that can be drawn from the $i$-th task is proportional to $\binom{Z_i}{N}$. These proportions define the empirical distribution over the 25 meta-training tasks. During meta-training, MAML therefore first selects the $i$-th alphabet with probability proportional to $\binom{Z_i}{N}$, then uniformly samples an N-way, K-shot classification problem from alphabet $i$. In contrast, TR-MAML first samples an alphabet uniformly, then samples an N-way, K-shot problem uniformly from that alphabet.

Table 3: mini-ImageNet 5-way, 1-shot accuracies (%), with 95% confidence intervals over 3 random trials.

| (N, K) | Algorithm | Eight meta-training tasks: Weighted Mean | Eight meta-training tasks: Worst | Four meta-testing tasks: Weighted Mean | Four meta-testing tasks: Worst |
|---|---|---|---|---|---|
| (5, 1) | MAML | 70.1 ± 2.2 | 48.0 ± 4.5 | 46.6 ± .4 | 44.7 ± .7 |
| (5, 1) | TR-MAML | 63.2 ± 1.3 | 60.7 ± 1.6 | 48.5 ± .6 | 45.9 ± .8 |
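The two alphabet-sampling schemes differ only in how the alphabet is chosen. The minimal sketch below uses hypothetical alphabet sizes $Z_i$, not Omniglot's actual ones. MAML's probability for alphabet $i$ is proportional to $\binom{Z_i}{N}$, and TR-MAML's is uniform. The sketch also shows that MAML's distribution becomes more concentrated on the largest alphabet as $N$ grows.

```python
import math
from typing import List, Tuple


def alphabet_probs(Z: List[int], N: int) -> Tuple[List[float], List[float]]:
    """MAML: P(alphabet i) proportional to C(Z_i, N); TR-MAML: uniform over alphabets."""
    counts = [math.comb(z, N) for z in Z]     # number of N-way instances per alphabet
    total = sum(counts)
    maml = [c / total for c in counts]
    tr_maml = [1.0 / len(Z)] * len(Z)
    return maml, tr_maml


Z = [20, 40, 55]                               # hypothetical character counts
maml5, uniform = alphabet_probs(Z, 5)
maml20, _ = alphabet_probs(Z, 20)
print(round(max(maml5), 3), round(max(maml20), 3))
```

An alphabet with fewer than $N$ characters gets probability zero under MAML's scheme, since `math.comb(z, N)` is 0 when `N > z`.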
After 60,000 meta-training iterations, we evaluate the models yielded by MAML and TR-MAML on 5,000 N-way, K-shot classification problems from the 20 meta-test alphabets, as well as on 5,000 problems from the meta-training alphabets. Table 3 shows statistics taken over the average accuracy on task instances from each alphabet for different values of N and K. First, note that TR-MAML improves on MAML's worst-case task performance in all cases.

Regarding mean performance, Weighted Mean is the uniform average over task instances (i.e., the surrogate for the expected loss over tasks given in Equation 1). Equivalently, it weights the average accuracy on each task (alphabet) by the number of instances that task contains. MAML aims to optimize this metric, and on the meta-training alphabets it always outperforms TR-MAML on it. TR-MAML's improved Weighted Mean performance at meta-test time in the N = 20 case suggests that TR-MAML can generalize better than MAML: it prioritizes performance on all the meta-training tasks, whereas MAML may overfit to the most frequent ones. Observe that the empirical distribution of meta-training alphabets becomes more skewed as N increases, since for $Z_i > Z_j$ the ratio $\binom{Z_i}{N}/\binom{Z_j}{N}$ grows with N. This causes MAML to focus on a smaller subset of the meta-training alphabets and to further disregard worst-case alphabet performance, leading to worse generalization.

For mini-ImageNet, we split the 100 image classes into two subsets: 64 classes for meta-training and the remaining 36 for meta-testing. This follows the standard split [31], except that its 16 meta-validation classes are used for meta-testing alongside its 20 meta-test classes. We create tasks by randomly grouping the 64 meta-training classes into 8 meta-training tasks with {6, 7, 7, 8, 8, 9, 9, 10} classes per task. Likewise, the 36 meta-test classes are randomly split into 4 tasks of 9 classes each. Each task instance is constructed by sampling 1 image each from 5 distinct classes within a task; thus, each task instance is a 5-way, 1-shot problem.
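The mini-ImageNet task construction above can be sketched as follows. It assumes the usual 600 images per mini-ImageNet class. The class indexing is hypothetical (0-63 for meta-training, 64-99 for meta-testing), as are the function names and the query size `j_query`. The sketch only illustrates the random partition into tasks and the 5-way, 1-shot sampling within one task.

```python
import random
from typing import List, Sequence, Tuple


def make_tasks(classes: Sequence[int], sizes: Sequence[int],
               rng: random.Random) -> List[List[int]]:
    """Randomly partition `classes` into tasks with the given numbers of classes."""
    if sum(sizes) != len(classes):
        raise ValueError("task sizes must cover every class exactly once")
    shuffled = list(classes)
    rng.shuffle(shuffled)
    tasks, start = [], 0
    for size in sizes:
        tasks.append(shuffled[start:start + size])
        start += size
    return tasks


def sample_instance(task: Sequence[int], images_per_class: int, j_query: int,
                    rng: random.Random):
    """5-way, 1-shot instance: 5 distinct classes from one task, 1 support image each.

    Returns (support, query) lists of (class, image index, label) triples;
    j_query query images per class is a hypothetical choice of this sketch.
    """
    classes = rng.sample(list(task), 5)
    support: List[Tuple[int, int, int]] = []
    query: List[Tuple[int, int, int]] = []
    for label, c in enumerate(classes):
        imgs = rng.sample(range(images_per_class), 1 + j_query)  # distinct images
        support.append((c, imgs[0], label))
        query += [(c, i, label) for i in imgs[1:]]
    return support, query


rng = random.Random(0)
# Hypothetical indexing: classes 0-63 are meta-training, 64-99 meta-testing.
train_tasks = make_tasks(range(64), [6, 7, 7, 8, 8, 9, 9, 10], rng)
test_tasks = make_tasks(range(64, 100), [9] * 4, rng)
support, query = sample_instance(train_tasks[0], 600, 15, rng)
```

Because every task has at least 6 classes, each one can supply a 5-way instance. A task with fewer than 5 classes would make `rng.sample` raise `ValueError`.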
We meta-train for 60,000 iterations with a batch size of 2 task instances and 5 steps of gradient descent for local adaptation. Table 3 reports the Weighted Mean accuracy (i.e., the average case over task instances) and the worst-case accuracy (i.e., the worst accuracy over the tasks). The first two result columns are obtained by testing on new task instances from the meta-training classes. The last two are obtained by testing on task instances from the previously unseen meta-test classes. Again, TR-MAML improves worst-case performance compared to MAML. Its improved mean performance on the meta-test tasks is likely due to TR-MAML learning a more uniform model across the meta-training tasks.

Concluding Remarks: We propose TR-MAML¹, a MAML variant that optimizes for robustness across tasks through a min-max formulation instead of an average-case formulation. Our setting thus enables the model to provide reasonable performance even on hard and rarely seen tasks. However, shifting the focus to the worst case does not come for free: the model may suffer degraded average-case performance if some tasks are sufficiently outlying. Thus, the model used in practice should be chosen according to the deployment conditions and the desired behavior.

¹The code for TR-MAML is available at: https://github.com/lgcollins/tr-maml.

Broader Impact

Our work presents a formulation for learning how to learn optimally on the worst-case task from some environment. Although this formulation has no immediate societal consequences, it provides a novel framework for developing realizable meta-learning systems that are robust across all tasks. Such systems are necessary for many applications. Few-shot fingerprint recognition in security systems, one-shot imitation learning for assembly line machines, and few-shot fraud detection are just a few examples.
Moreover, some systems treat performance on all tasks equally despite disparities in the amount of data available for each task. Such systems are critical for fairness in settings where tasks are correlated with membership in a protected class such as race or gender. One weakness of our formulation is that it is not robust against adversarial tasks. In settings where some tasks may be adversarial, however, our formulation may be modified to optimize the worst-case loss over the fraction of tasks known to be non-adversarial; we leave the analysis of this variant for future work.

Acknowledgments and Disclosure of Funding

This work was partially supported by ONR Grant N00014-19-1-2566, NSF Grant SATC 1704778, NSF Grant CCF-2007668, and ARO grant W911NF-17-1-0359.

References

[1] A. Antoniou, H. Edwards, and A. J. Storkey. How to train your MAML. In International Conference on Learning Representations, ICLR, 2019.
[2] M.-F. Balcan, M. Khodak, and A. Talwalkar. Provable guarantees for gradient-based meta-learning. In International Conference on Machine Learning, pages 424–433, 2019.
[3] J. Baxter. Theoretical models of learning to learn. In Learning to Learn, pages 71–94. Springer, 1998.
[4] Y. Bengio, S. Bengio, and J. Cloutier. Learning a synaptic learning rule. In IJCNN-91-Seattle International Joint Conference on Neural Networks, volume 2, p. 969. IEEE, 1991.
[5] L. Bertinetto, J. F. Henriques, P. Torr, and A. Vedaldi. Meta-learning with differentiable closed-form solvers. In International Conference on Learning Representations, 2018.
[6] D. Cai, R. Sheth, L. Mackey, and N. Fusi. Weighted meta-learning. arXiv preprint arXiv:2003.09465, 2020.
[7] R. S. Chen, B. Lucier, Y. Singer, and V. Syrgkanis. Robust optimization for non-convex objectives. In Advances in Neural Information Processing Systems, pages 4705–4714, 2017.
[8] Y. Duan, J. Schulman, X. Chen, P. L. Bartlett, I. Sutskever, and P. Abbeel. RL²: Fast reinforcement learning via slow reinforcement learning. arXiv preprint arXiv:1611.02779, 2016.
[9] J. Duchi, P. Glynn, and H. Namkoong. Statistics of robust optimization: A generalized empirical likelihood approach. arXiv preprint arXiv:1610.03425, 2016.
[10] J. Duchi and H. Namkoong. Learning models with uniform performance via distributionally robust optimization. arXiv preprint arXiv:1810.08750, 2018.
[11] A. Fallah, A. Mokhtari, and A. Ozdaglar. On the convergence theory of gradient-based model-agnostic meta-learning algorithms. In International Conference on Artificial Intelligence and Statistics, pages 1082–1092, 2020.
[12] C. Finn, P. Abbeel, and S. Levine. Model-agnostic meta-learning for fast adaptation of deep networks. In International Conference on Machine Learning, pages 1126–1135, 2017.
[13] C. Finn, A. Rajeswaran, S. Kakade, and S. Levine. Online meta-learning. In International Conference on Machine Learning, pages 1920–1930, 2019.
[14] S. Ghadimi, G. Lan, and H. Zhang. Mini-batch stochastic approximation methods for nonconvex stochastic composite optimization. Mathematical Programming, 155(1-2):267–305, 2016.
[15] M. A. Jamal and G.-J. Qi. Task agnostic meta-learning for few-shot learning. In 2019 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), Jun 2019.
[16] C. Jin, P. Netrapalli, and M. I. Jordan. Minmax optimization: Stable limit points of gradient descent ascent are locally optimal. arXiv preprint arXiv:1902.00618, 2019.
[17] A. Juditsky, A. Nemirovski, and C. Tauvel. Solving variational inequalities with stochastic mirror-prox algorithm. Stochastic Systems, 1(1):17–58, 2011.
[18] M. Khodak, M.-F. F. Balcan, and A. S. Talwalkar. Adaptive gradient-based meta-learning methods. In Advances in Neural Information Processing Systems, pages 5915–5926, 2019.
[19] B. M. Lake, R. Salakhutdinov, and J. B. Tenenbaum. Human-level concept learning through probabilistic program induction. Science, 350(6266):1332–1338, 2015.
[20] H. Lee, H. Lee, D. Na, S. Kim, M. Park, E. Yang, and S. J. Hwang. Learning to balance: Bayesian meta-learning for imbalanced and out-of-distribution tasks. In International Conference on Learning Representations, ICLR, 2020.
[21] K. Lee, S. Maji, A. Ravichandran, and S. Soatto. Meta-learning with differentiable convex optimization. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pages 10657–10665, 2019.
[22] Z. Li, F. Zhou, F. Chen, and H. Li. Meta-SGD: Learning to learn quickly for few-shot learning. arXiv preprint arXiv:1707.09835, 2017.
[23] M. Mohri, A. Rostamizadeh, and A. Talwalkar. Foundations of Machine Learning. MIT Press, 2018.
[24] M. Mohri, G. Sivek, and A. T. Suresh. Agnostic federated learning. In International Conference on Machine Learning, pages 4615–4625, 2019.
[25] A. Nemirovski, A. Juditsky, G. Lan, and A. Shapiro. Robust stochastic approximation approach to stochastic programming. SIAM Journal on Optimization, 19(4):1574–1609, 2009.
[26] A. Nichol and J. Schulman. Reptile: a scalable metalearning algorithm. arXiv preprint arXiv:1803.02999, 2018.
[27] M. Nouiehed, M. Sanjabi, T. Huang, J. D. Lee, and M. Razaviyayn. Solving a class of nonconvex min-max games using iterative first order methods. In Advances in Neural Information Processing Systems, pages 14905–14916, 2019.
[28] Q. Qian, S. Zhu, J. Tang, R. Jin, B. Sun, and H. Li. Robust optimization over multiple domains. In Proceedings of the AAAI Conference on Artificial Intelligence, pages 4739–4746, 2019.
[29] H. Rafique, M. Liu, Q. Lin, and T. Yang. Non-convex min-max optimization: Provable algorithms and applications in machine learning. arXiv preprint arXiv:1810.02060, 2018.
[30] A. Rajeswaran, C. Finn, S. M. Kakade, and S. Levine. Meta-learning with implicit gradients. In Advances in Neural Information Processing Systems, pages 113–124, 2019.
[31] S. Ravi and H. Larochelle. Optimization as a model for few-shot learning. In ICLR, 2016.
[32] S. Shalev-Shwartz and Y. Wexler. Minimizing the maximal loss: How and why. In ICML, pages 793–801, 2016.
[33] J. Snell, K. Swersky, and R. Zemel. Prototypical networks for few-shot learning. In Advances in Neural Information Processing Systems, pages 4077–4087, 2017.
[34] X. Song, W. Gao, Y. Yang, K. Choromanski, A. Pacchiano, and Y. Tang. ES-MAML: Simple Hessian-free meta learning. arXiv preprint arXiv:1910.01215, 2019.
[35] S. Thrun and L. Pratt. Learning to Learn. Springer Science and Business Media, 2012.
[36] E. Triantafillou, T. Zhu, V. Dumoulin, P. Lamblin, U. Evci, K. Xu, R. Goroshin, C. Gelada, K. Swersky, P.-A. Manzagol, et al. Meta-Dataset: A dataset of datasets for learning to learn from few examples. In International Conference on Learning Representations, 2019.
[37] O. Vinyals, C. Blundell, T. Lillicrap, D. Wierstra, et al. Matching networks for one shot learning. In Advances in Neural Information Processing Systems, pages 3630–3638, 2016.
[38] J. X. Wang, Z. Kurth-Nelson, D. Tirumala, H. Soyer, J. Z. Leibo, R. Munos, C. Blundell, D. Kumaran, and M. Botvinick. Learning to reinforcement learn. arXiv preprint arXiv:1611.05763, 2016.
[39] W. Wang and M. A. Carreira-Perpiñán. Projection onto the probability simplex: An efficient algorithm with a simple proof, and an application. arXiv preprint arXiv:1309.1541, 2013.
[40] C. Yin, J. Tang, Z. Xu, and Y. Wang. Adversarial meta-learning. arXiv preprint arXiv:1806.03316, 2018.
[41] P. Zhou, X. Yuan, H. Xu, S. Yan, and J. Feng. Efficient meta learning via minibatch proximal update. In Advances in Neural Information Processing Systems, pages 1532–1542, 2019.
[42] Z. Zhuang, Y. Wang, K. Yu, and S. Lu. Online meta-learning on non-convex setting. arXiv preprint arXiv:1910.10196, 2019.
[43] D. Zügner and S. Günnemann. Adversarial attacks on graph neural networks via meta learning. arXiv preprint arXiv:1902.08412, 2019.