Published as a conference paper at ICLR 2019

EFFICIENT LIFELONG LEARNING WITH A-GEM

Arslan Chaudhry1, Marc'Aurelio Ranzato2, Marcus Rohrbach2, Mohamed Elhoseiny2
1University of Oxford, 2Facebook AI Research
arslan.chaudhry@eng.ox.ac.uk, {ranzato,mrf,elhoseiny}@fb.com

ABSTRACT

In lifelong learning, the learner is presented with a sequence of tasks, incrementally building a data-driven prior which may be leveraged to speed up learning of a new task. In this work, we investigate the efficiency of current lifelong approaches in terms of sample complexity, computational and memory cost. Towards this end, we first introduce a new and more realistic evaluation protocol, whereby learners observe each example only once and hyper-parameter selection is done on a small and disjoint set of tasks, which is not used for the actual learning experience and evaluation. Second, we introduce a new metric measuring how quickly a learner acquires a new skill. Third, we propose an improved version of GEM (Lopez-Paz & Ranzato, 2017), dubbed Averaged GEM (A-GEM), which enjoys the same or even better performance than GEM, while being almost as computationally and memory efficient as EWC (Kirkpatrick et al., 2016) and other regularization-based methods. Finally, we show that all algorithms, including A-GEM, can learn even more quickly if they are provided with task descriptors specifying the classification tasks under consideration. Our experiments on several standard lifelong learning benchmarks demonstrate that A-GEM has the best trade-off between accuracy and efficiency.1

1 The code is available at https://github.com/facebookresearch/agem.

1 INTRODUCTION

Intelligent systems, whether natural or artificial, must be able to quickly adapt to changes in the environment and to quickly learn new skills by leveraging past experiences. While current learning algorithms can achieve excellent performance on a variety of tasks, they strongly rely on copious amounts of supervision in the form of labeled data. The lifelong learning (LLL) setting attempts to address this shortcoming, bringing machine learning closer to human learning: acquiring new skills quickly with a small amount of training data, given the experience accumulated in the past. In this setting, the learner is presented with a stream of tasks whose relatedness is not known a priori. The learner then has the potential to learn a new task more quickly if it can remember how to combine and re-use knowledge acquired while learning related tasks in the past. Of course, for this learning setting to be useful, the model needs to be constrained in terms of the amount of compute and memory required. Usually this means that the learner should not be allowed to merely store all examples seen in the past (which would reduce the lifelong learning problem to a multi-task problem), nor should the learner engage in computations that would not be feasible in real time, since the goal is to quickly learn from a stream of data. Unfortunately, the established training and evaluation protocol as well as current algorithms for lifelong learning do not satisfy all of the above desiderata, namely learning from a stream of data using a limited number of samples, limited memory and limited compute.
In the most popular training paradigm, the learner does several passes over the data (Kirkpatrick et al., 2016; Aljundi et al., 2018; Rusu et al., 2016; Schwarz et al., 2018), while ideally the model should need only a handful of samples, provided one by one in a single pass (Lopez-Paz & Ranzato, 2017). Moreover, when the learner has several hyper-parameters to tune, the current practice is to go over the sequence of tasks several times, each time with a different hyper-parameter value, again ignoring the requirement of learning from a stream of data and, strictly speaking, violating the assumptions of the LLL scenario. While some algorithms may work well in a single-pass setting, they unfortunately require a lot of computation (Lopez-Paz & Ranzato, 2017) or their memory scales with the number of tasks (Rusu et al., 2016), which greatly impedes their deployment in practical applications.

In this work, we propose an evaluation methodology and an algorithm that better match our desiderata, namely learning efficiently, in terms of training samples, time and memory, from a stream of tasks. First, we propose a new learning paradigm, whereby the learner performs cross-validation on a set of tasks which is disjoint from the set of tasks actually used for evaluation (Sec. 2). In this setting, the learner has to learn, and is tested on, an entirely new sequence of tasks, performing just a single pass over this data stream. Second, we build upon GEM (Lopez-Paz & Ranzato, 2017), an algorithm which leverages a small episodic memory to perform well in the single-pass setting, and propose a small change to the loss function which makes GEM orders of magnitude faster at training time while maintaining similar performance; we dub this variant of GEM, A-GEM (Sec. 4). Third, we explore the use of compositional task descriptors in order to improve the few-shot learning performance within LLL, showing that with this additional information the learner can pick up new skills more quickly (Sec. 5). Fourth, we introduce a new metric to measure the speed of learning, which is useful to quantify the ability of a learning algorithm to learn a new task (Sec. 3). And finally, using our new learning paradigm and metric, we demonstrate A-GEM on a variety of benchmarks and against several representative baselines (Sec. 6). Our experiments show that A-GEM has a better trade-off between average accuracy and computational/memory cost. Moreover, all algorithms improve their ability to quickly learn a new task when provided with compositional task descriptors, and they do so better and better as they progress through the learning experience.

2 LEARNING PROTOCOL

Currently, most works on lifelong learning (Kirkpatrick et al., 2016; Rusu et al., 2016; Shin et al., 2017; Nguyen et al., 2018) adopt a learning protocol which is directly borrowed from supervised learning. There are T tasks, and each task comes with a training, validation and test set. During training, the learner does as many passes over the data of each task as desired. Moreover, hyper-parameters are tuned on the validation sets by sweeping over the whole sequence of tasks as many times as required by the cross-validation grid search. Finally, metrics of interest are reported on the test set of each task using the model selected by this cross-validation procedure.
Since the current protocol violates our stricter definition of LLL, in which the learner can only make a single pass over the data (we want to emphasize the importance of learning quickly from a stream), we now introduce a new learning protocol. We consider two streams of tasks, described by the following ordered sequences of datasets $\mathcal{D}^{CV} = \{\mathcal{D}_1, \ldots, \mathcal{D}_{T^{CV}}\}$ and $\mathcal{D}^{EV} = \{\mathcal{D}_{T^{CV}+1}, \ldots, \mathcal{D}_T\}$, where $\mathcal{D}_k = \{(x_i^k, t_i^k, y_i^k)\}_{i=1}^{n_k}$ is the dataset of the $k$-th task, $T^{CV} < T$ (in all our experiments $T^{CV} = 3$ while $T = 20$), and we assume that all datasets are drawn from the same distribution over tasks. To avoid cluttering the notation, we let the context specify whether $\mathcal{D}_k$ refers to the training or test set of the $k$-th dataset. $\mathcal{D}^{CV}$ is the stream of datasets used during cross-validation; $\mathcal{D}^{CV}$ allows the learner to replay all samples multiple times for the purpose of model hyper-parameter selection. Instead, $\mathcal{D}^{EV}$ is the actual dataset used for final training and evaluation on the test set; the learner observes training examples from $\mathcal{D}^{EV}$ once and only once, and all metrics are reported on the test sets of $\mathcal{D}^{EV}$. Since regularization-based approaches for lifelong learning (Kirkpatrick et al., 2016; Zenke et al., 2017) are rather sensitive to the choice of the regularization hyper-parameter, we introduced the set $\mathcal{D}^{CV}$, as it seems reasonable in practical applications to have similar tasks that can be used for tuning the system. However, the actual training and testing are then performed on $\mathcal{D}^{EV}$ using a single pass over the data. See Algorithm 1 for a summary of the training and evaluation protocol; a toy code sketch follows below.

Each example in any of these datasets consists of a triplet defined by an input ($x^k \in \mathcal{X}$), a task descriptor ($t^k \in \mathcal{T}$, see Sec. 5 for examples) and a target vector ($y^k \in \mathcal{Y}^k$), where $\mathcal{Y}^k$ is the set of labels specific to task $k$ and $\mathcal{Y}^k \subset \mathcal{Y}$. While observing the data, the goal is to learn a predictor $f_\theta : \mathcal{X} \times \mathcal{T} \to \mathcal{Y}$, parameterized by $\theta \in \mathbb{R}^P$ (a neural network in our case), that can map any test pair $(x, t)$ to a target $y$.

Algorithm 1 Learning and Evaluation Protocols
1:  for h in hyper-parameter list do            ▷ Cross-validation loop, executing multiple passes over D^CV
2:      for k = 1 to T^CV do                    ▷ Learn over data stream D^CV using h
3:          for i = 1 to n_k do                 ▷ Single pass over D_k
4:              Update f_θ using (x_i^k, t_i^k, y_i^k) and hyper-parameter h
5:              Update metrics on test set of D^CV
6:          end for
7:      end for
8:  end for
9:  Select the best hyper-parameter setting, h*, based on the average accuracy on the test set of D^CV (see Eq. 1).
10: Reset f_θ.
11: Reset all metrics.
12: for k = T^CV + 1 to T do                    ▷ Actual learning over data stream D^EV
13:     for i = 1 to n_k do                     ▷ Single pass over D_k
14:         Update f_θ using (x_i^k, t_i^k, y_i^k) and hyper-parameter h*
15:         Update metrics on test set of D^EV
16:     end for
17: end for
18: Report metrics on test set of D^EV.
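To make the protocol concrete, here is a minimal, runnable Python sketch that mirrors the control flow of Algorithm 1. The `Learner` class, its dummy update rule, and `make_stream` are placeholders of our own, not the authors' implementation; only the structure of the two phases (multiple passes over D^CV for hyper-parameter selection, then a reset and a single pass over D^EV) follows the algorithm.

```python
import random

class Learner:
    """Toy stand-in for the predictor f_theta: one dummy weight per task."""
    def __init__(self, lr):
        self.lr = lr
        self.weights = {}

    def update(self, x, t, y):
        # Single online update on one (x, t, y) triplet (placeholder arithmetic).
        self.weights[t] = self.weights.get(t, 0.0) + self.lr * (y - x)

    def evaluate(self, test_stream):
        # Placeholder "accuracy": fraction of test triplets whose task was seen.
        seen = sum(1 for (_, t, _) in test_stream if t in self.weights)
        return seen / max(len(test_stream), 1)

def make_stream(task_ids, n_per_task=5):
    # Each example is an (x, t, y) triplet, as in the protocol above.
    return [(random.random(), t, random.random())
            for t in task_ids for _ in range(n_per_task)]

T_CV, T = 3, 20
cv_train, cv_test = make_stream(range(1, T_CV + 1)), make_stream(range(1, T_CV + 1))
ev_train, ev_test = make_stream(range(T_CV + 1, T + 1)), make_stream(range(T_CV + 1, T + 1))

# Cross-validation loop (lines 1-8): multiple passes over D^CV are allowed.
best_h, best_acc = None, -1.0
for h in [0.1, 0.01, 0.001]:
    model = Learner(lr=h)
    for x, t, y in cv_train:
        model.update(x, t, y)
    acc = model.evaluate(cv_test)
    if acc > best_acc:
        best_h, best_acc = h, acc

# Final run (lines 10-18): reset the model, then make a single pass over D^EV.
model = Learner(lr=best_h)
for x, t, y in ev_train:
    model.update(x, t, y)  # every training example is seen once and only once
print("final average accuracy:", model.evaluate(ev_test))
```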
3 METRICS

Below we describe the metrics used to evaluate the LLL methods studied in this work. In addition to Average Accuracy (A) and Forgetting Measure (F) (Chaudhry et al., 2018), we define a new measure, the Learning Curve Area (LCA), that captures how quickly a model learns. The training dataset of each task, $\mathcal{D}_k$, consists of a total of $B_k$ mini-batches. After each presentation of a mini-batch of task $k$, we evaluate the performance of the learner on all tasks using the corresponding test sets. Let $a_{k,i,j} \in [0, 1]$ be the accuracy evaluated on the test set of task $j$ after the model has been trained on the $i$-th mini-batch of task $k$. Assuming the first learning task in the continuum is indexed by 1 (it will be $T^{CV}+1$ for $\mathcal{D}^{EV}$) and the last one by $T$ (it will be $T^{CV}$ for $\mathcal{D}^{CV}$), we define the following metrics:

Average Accuracy ($A \in [0, 1]$). The average accuracy after the model has been trained continually on all mini-batches up to task $k$ is defined as:

$$A_k = \frac{1}{k} \sum_{j=1}^{k} a_{k,B_k,j} \qquad (1)$$

In particular, $A_T$ is the average accuracy on all tasks after the last task has been learned; this is the metric most commonly used in LLL.

Forgetting Measure ($F \in [-1, 1]$) (Chaudhry et al., 2018). The average forgetting after the model has been trained continually on all mini-batches up to task $k$ is defined as:

$$F_k = \frac{1}{k-1} \sum_{j=1}^{k-1} f_j^k \qquad (2)$$

where $f_j^k$ is the forgetting on task $j$ after the model has been trained on all mini-batches up to task $k$, computed as:

$$f_j^k = \max_{l \in \{1, \ldots, k-1\}} a_{l,B_l,j} - a_{k,B_k,j} \qquad (3)$$

Measuring forgetting after all tasks have been learned is important for a two-fold reason: it quantifies the accuracy drop on past tasks, and it gives an indirect notion of how quickly a model may learn a new task, since a forgetful model will have little knowledge left to transfer, particularly so if the new task relates closely to one of the very first tasks encountered during the learning experience.

Learning Curve Area (LCA $\in [0, 1]$). Let us first define the average $b$-shot performance (where $b$ is the mini-batch number) after the model has been trained on all $T$ tasks:

$$Z_b = \frac{1}{T} \sum_{k=1}^{T} a_{k,b,k} \qquad (4)$$

LCA at $\beta$ is the area under the convergence curve $Z_b$ as a function of $b \in [0, \beta]$:

$$\mathrm{LCA}_\beta = \frac{1}{\beta+1} \int_0^\beta Z_b \, db = \frac{1}{\beta+1} \sum_{b=0}^{\beta} Z_b \qquad (5)$$

LCA has an intuitive interpretation. $\mathrm{LCA}_0$ is the average 0-shot performance, the same as forward transfer in Lopez-Paz & Ranzato (2017). $\mathrm{LCA}_\beta$ is the area under the $Z_b$ curve, which is high if the 0-shot performance is good and if the learner learns quickly. In particular, two models could have the same $Z_\beta$ or $A_T$ but very different $\mathrm{LCA}_\beta$, because one learns much faster than the other even though both eventually reach the same final accuracy. This metric aims at discriminating between these two cases, and it makes sense for relatively small values of $\beta$, since we are interested in models that learn from few examples.
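All three metrics can be computed directly from the array of accuracies $a_{k,b,j}$ defined above. Below is a minimal numpy sketch under the simplifying assumptions that every task has the same number of mini-batches $B$ and that indices are 0-based; the array name `acc` and the function names are ours, not from the authors' code.

```python
import numpy as np

def average_accuracy(acc):
    # Eq. 1 at k = T: mean over tasks j of a_{T, B_T, j}.
    T = acc.shape[0]
    return acc[T - 1, -1, :].mean()

def forgetting(acc):
    # Eqs. 2-3 at k = T: best past accuracy on task j minus final accuracy.
    T = acc.shape[0]
    f = [acc[:T - 1, -1, j].max() - acc[T - 1, -1, j] for j in range(T - 1)]
    return float(np.mean(f))

def lca(acc, beta):
    # Eqs. 4-5: Z_b is the b-shot accuracy a_{k,b,k} averaged over tasks k;
    # LCA_beta averages Z_b over the first beta + 1 mini-batches.
    T = acc.shape[0]
    Z = np.array([acc[np.arange(T), b, np.arange(T)].mean()
                  for b in range(beta + 1)])
    return Z.mean()

# Toy usage: T = 4 tasks, B = 10 mini-batches per task, random accuracies.
acc = np.random.rand(4, 10, 4)  # acc[k, b, j] holds a_{k,b,j} (0-indexed)
print(average_accuracy(acc), forgetting(acc), lca(acc, beta=9))
```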
4 AVERAGED GRADIENT EPISODIC MEMORY (A-GEM)

So far we have discussed a better training and evaluation protocol for LLL and a new metric to measure the speed of learning. Next, we review GEM (Lopez-Paz & Ranzato, 2017), an algorithm that has been shown to work well in the single-epoch setting. Unfortunately, GEM is very intensive in terms of computational and memory cost, which motivates our efficient variant, dubbed A-GEM. In Sec. 5, we will describe how compositional task descriptors can be leveraged to further speed up learning in the few-shot regime.

GEM avoids catastrophic forgetting by storing an episodic memory $\mathcal{M}_k$ for each task $k$. While minimizing the loss on the current task $t$, GEM treats the losses on the episodic memories of tasks $k < t$, given by $\ell(f_\theta, \mathcal{M}_k) = \frac{1}{|\mathcal{M}_k|} \sum_{(x_i, k, y_i) \in \mathcal{M}_k} \ell(f_\theta(x_i, k), y_i)$, as inequality constraints, avoiding their increase but allowing their decrease. This effectively allows GEM to achieve positive backward transfer, which other LLL methods do not support. Formally, at task $t$, GEM solves the following objective:

$$\mathrm{minimize}_\theta \;\; \ell(f_\theta, \mathcal{D}_t) \quad \mathrm{s.t.} \quad \ell(f_\theta, \mathcal{M}_k) \le \ell(f_\theta^{t-1}, \mathcal{M}_k) \;\; \forall k < t \qquad (6)$$

where $f_\theta^{t-1}$ is the network trained up to task $t-1$.

To inspect whether the loss increases, GEM computes the angle between the loss gradient vectors of the previous tasks, $g_k$, and the proposed gradient update on the current task, $g$. Whenever the angle is greater than 90° with any of the $g_k$'s, it projects the proposed gradient onto the closest gradient $\tilde{g}$ in $\ell_2$ norm that keeps the angle within the bounds. Formally, the optimization problem GEM solves is:

$$\mathrm{minimize}_{\tilde{g}} \;\; \frac{1}{2} \|g - \tilde{g}\|_2^2 \quad \mathrm{s.t.} \quad \langle \tilde{g}, g_k \rangle \ge 0 \;\; \forall k < t \qquad (7)$$

Eq. 7 is a quadratic program (QP) in $P$ variables (the number of parameters in the network), which for neural networks could be in the millions. To solve this efficiently, GEM works in the dual space, which results in a much smaller QP with only $t-1$ variables:

$$\mathrm{minimize}_v \;\; \frac{1}{2} v^\top G G^\top v + g^\top G^\top v \quad \mathrm{s.t.} \quad v \ge 0 \qquad (8)$$

where $G = (g_1, \ldots, g_{t-1}) \in \mathbb{R}^{(t-1) \times P}$ is computed at each gradient step of training. Once the solution $v^*$ to Eq. 8 is found, the projected gradient update can be computed as $\tilde{g} = G^\top v^* + g$.

While GEM has proven very effective in the single-epoch setting (Lopez-Paz & Ranzato, 2017), these performance gains come at a large computational burden at training time. At each training step, GEM computes the matrix $G$ using all samples from the episodic memory, and it also needs to solve the QP of Eq. 8. Unfortunately, this inner-loop optimization becomes prohibitive when the size of $\mathcal{M}$ and the number of tasks are large; see Tab. 7 in the Appendix for an empirical analysis. To alleviate the computational burden of GEM, we next propose a much more efficient version, called Averaged GEM (A-GEM). Whereas GEM ensures that at every training step the loss on each individual previous task, approximated by the samples in episodic memory, does not increase, A-GEM tries to ensure that at every training step the average episodic memory loss over the previous tasks does not increase. Formally, while learning task $t$, the objective of A-GEM is:

$$\mathrm{minimize}_\theta \;\; \ell(f_\theta, \mathcal{D}_t) \quad \mathrm{s.t.} \quad \ell(f_\theta, \mathcal{M}) \le \ell(f_\theta^{t-1}, \mathcal{M}) \quad \text{where} \;\; \mathcal{M} = \bigcup_{k < t} \mathcal{M}_k$$
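This single averaged constraint is what makes A-GEM cheap: instead of a QP over $t-1$ gradients, the update only needs the gradient $g$ on the current batch and a reference gradient $g_{ref}$ computed on a batch sampled from $\mathcal{M}$; when the two conflict, $g$ is projected so that it no longer points against $g_{ref}$. The following minimal numpy sketch illustrates this projection (the function and variable names are ours, not the authors' released code):

```python
import numpy as np

def a_gem_project(g, g_ref, eps=1e-12):
    """Return the A-GEM update direction for flattened gradients g and g_ref."""
    dot = g @ g_ref
    if dot >= 0.0:
        return g  # no conflict with the average memory gradient: plain SGD step
    # Conflict: remove the component of g that points against g_ref, so that
    # <g_tilde, g_ref> = 0 and the average memory loss does not increase
    # (to first order).
    return g - (dot / (g_ref @ g_ref + eps)) * g_ref

# Toy usage with a conflicting pair of gradients.
g = np.array([1.0, -2.0, 0.5])      # gradient on the current task's batch
g_ref = np.array([0.5, 1.0, 0.0])   # gradient on a batch sampled from M
g_tilde = a_gem_project(g, g_ref)
print(g_tilde, g_tilde @ g_ref)     # the inner product is ~0: constraint active
```

Because the projection involves only two inner products over the flattened gradient, the per-step overhead is O(P), as opposed to computing $t-1$ reference gradients and solving a QP at every step as GEM does.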