# Posterior Meta-Replay for Continual Learning

Christian Henning*, Maria R. Cervera*, Francesco D'Angelo, Johannes von Oswald, Regina Traber, Benjamin Ehret, Seijin Kobayashi, Benjamin F. Grewe, João Sacramento
*Equal contribution
Institute of Neuroinformatics, University of Zürich and ETH Zürich, Zürich, Switzerland
{henningc,mariacer}@ethz.ch

Learning a sequence of tasks without access to i.i.d. observations is a widely studied form of continual learning (CL) that remains challenging. In principle, Bayesian learning directly applies to this setting, since recursive and one-off Bayesian updates yield the same result. In practice, however, recursive updating often leads to poor trade-off solutions across tasks because approximate inference is necessary for most models of interest. Here, we describe an alternative Bayesian approach where task-conditioned parameter distributions are continually inferred from data. We offer a practical deep learning implementation of our framework based on probabilistic task-conditioned hypernetworks, an approach we term posterior meta-replay. Experiments on standard benchmarks show that our probabilistic hypernetworks compress sequences of posterior parameter distributions with virtually no forgetting. We obtain considerable performance gains compared to existing Bayesian CL methods, and identify task inference as our major limiting factor. This limitation has several causes that are independent of the considered sequential setting, opening up new avenues for progress in CL.

## 1 Introduction

In recent years, a variety of continual learning (CL) algorithms have been developed to overcome the need to train neural networks with an independent and identically distributed (i.i.d.) sample. Most of the CL literature focuses on the particular scenario of continually learning a sequence of T tasks with datasets D(1), ..., D(T).
Because only access to the current task is granted, successful training of a discriminative model that captures p(Y | X) has to occur without an i.i.d. training sample D(1:T) from the overall joint p(X)p(Y | X). The advantages of a Bayesian approach for solving this problem are numerous and include the ability to drop all i.i.d. assumptions across and within tasks in a mathematically sound way, the ability to revisit tasks whenever new data becomes available, and access to principled uncertainty estimates capturing both data and parameter uncertainty. Up until now, Bayesian approaches to CL have essentially focused on finding a combined posterior distribution via the recursive Bayesian update p(W | D(1:T)) ∝ p(W | D(1:T−1)) p(D(T) | W). Because the posterior of the previous task is used as prior for the next task, these approaches are also known as prior-focused [17]. In theory, the above recursive update can always recover the posterior p(W | D(1:T)), independently of how the data is presented. However, because exact Bayesian inference is intractable, approximations are needed in practice, which lead to errors that are recursively amplified. As a result, whether solutions that are easily found in the i.i.d. setting can be obtained via the approximate recursive update strongly depends on factors such as task ordering, task similarity and the considered family of distributions. These factors limit the effectiveness of the recursive update and have a detrimental effect on the performance of prior-focused methods, especially in task-agnostic CL settings.

35th Conference on Neural Information Processing Systems (NeurIPS 2021).

Figure 1: The proposed posterior meta-replay framework learns task-specific posteriors p(W | D(t)) via a single shared meta-model, with task-specific point estimates (e.g., MAP solutions) being a limit case.
In this view, the modelled solution space is not limited to admissible solutions that lie in the overlap of all task-specific posteriors. By contrast, prior-focused methods learn a single posterior p(W | D(1:T)) recursively and thus require the existence of trade-off solutions between learned and future tasks in the currently modelled solution space. Shaded areas indicate high-density regions.

To overcome these limitations, we propose an alternative Bayesian approach to CL that does not rely on the recursive update to learn distinct tasks and instead aims to learn task-specific posteriors (Fig. 1; refer to SM F.1 for a detailed discussion of the graphical model). In this view, finding trade-off solutions across tasks is not required, and knowledge transfer can be explicitly controlled for each task via the prior, which is no longer prescribed by the recursive update and can thus be set freely. By introducing probabilistic extensions of task-conditioned hypernetworks [91], we show how task-specific posteriors can be learned with a single shared meta-model, an approach we term posterior meta-replay. This approach introduces two challenges: forgetting at the level of the hypernetwork, and the need to know task identity to correctly condition the hypernetwork. We empirically show that forgetting at the meta-level can be prevented by using a simple regularizer that replays parameters of previous posteriors. In task-agnostic inference settings, often referred to as class-incremental learning in the context of classification benchmarks [88], the main hurdle therefore becomes task inference at test time. Here we focus on this task-agnostic setting, arguably the most challenging but also the most natural CL scenario, since the obtained models can be deployed just like those obtained via i.i.d. training (e.g., irrespective of the sequential training, the final model will be a classifier across all classes).
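The premise that recursive and one-off Bayesian updates coincide when inference is exact, which the recursive update above relies on, can be verified in a toy conjugate model. The following sketch (illustrative only, not part of the paper's method) treats two data batches as two "tasks" in a Beta-Bernoulli model:

```python
import numpy as np

def beta_update(alpha, beta, data):
    """Exact Bayesian update of a Beta(alpha, beta) prior over a coin bias
    given a batch of Bernoulli observations."""
    data = np.asarray(data)
    return alpha + data.sum(), beta + len(data) - data.sum()

# Two "tasks" (data batches) observed sequentially.
d1 = [1, 1, 0, 1]
d2 = [0, 0, 1]

# Recursive update: the posterior after d1 becomes the prior for d2.
a, b = beta_update(1.0, 1.0, d1)
a_rec, b_rec = beta_update(a, b, d2)

# One-off update on the pooled data.
a_all, b_all = beta_update(1.0, 1.0, d1 + d2)

assert (a_rec, b_rec) == (a_all, b_all)  # identical posteriors
```

Once the update is approximate, as it must be for deep networks, this equality no longer holds, and the approximation errors compound across tasks; this is the failure mode of prior-focused methods discussed above.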
In order to explicitly infer task identity from unseen inputs without resorting to generative models, we thoroughly study the use of the principled uncertainty that naturally arises in Bayesian models. We show that the results obtained in this task-agnostic setting with our approach constitute a leap in performance compared to prior-focused methods. Furthermore, we show that limitations in task inference via predictive uncertainty are not related to our CL solution, but depend instead on the combination of approximate inference method, architecture, uncertainty measure and prior. Finally, we investigate how task inference can be further improved through several extensions. We summarize our main contributions below:

- We describe a Bayesian CL framework where task-conditioned posterior parameter distributions are continually learned and compressed in a hypernetwork.
- In a series of synthetic and real-world CL benchmarks we show that our task-conditioned hypernetworks exhibit essentially no forgetting, both for explicitly parameterized and implicit posterior distributions, despite using the parameter budget of a single model.
- Compared to prior-focused methods, our approach leads to a leap in performance in task-agnostic inference while maintaining the theoretical benefits of a Bayesian approach.
- Our approach scales to modern architectures such as ResNets, and remaining performance limitations are linked to uncertainty-based out-of-distribution detection but not to our CL solution.
- Finally, we show how prominent existing Bayesian CL methods such as elastic weight consolidation can be dramatically improved in task-agnostic settings by introducing a small set of task-specific parameters and explicitly inferring the task.

## 2 Related Work

Continual learning. CL algorithms attempt to mitigate catastrophic interference while facilitating transfer of skills whenever possible.
They can be coarsely categorized as (1) regularization methods, which put constraints on weight updates, (2) replay methods, which mimic pseudo-i.i.d. training by rehearsing stored or generated data, and (3) dynamic architectures, which can grow to allocate capacity for new knowledge [71].

Most related to our work is the study from von Oswald et al. [91], which introduces task-conditioned hypernetworks for CL and already considers task inference via predictive uncertainty in the deterministic case. Our framework can be seen as a probabilistic extension of their work, which provides task-specific point estimates via a shared meta-model (cf. Sec. 3). Follow-up work also achieves task inference via predictive uncertainty; e.g., Wortsman et al. [94] use it to select a learned binary mask per task that modulates a random base network. Here we complement these studies by thoroughly exploring task inference via several uncertainty measures, disclosing the factors that limit task inference and highlighting the importance of parameter uncertainty.

A variety of methods tackling CL have been derived from a Bayesian perspective. A prominent example is the family of prior-focused methods [17], which incorporate knowledge from past data via the prior and, in contrast to our work, aim to find a shared posterior for all data. Examples include (Online) EWC [38, 78] and VCL [65, 54]. Other methods like CN-DPM [46] use Bayes' rule for task inference on the joint p(X, C), where C is a discrete condition such as task identity. An evident downside of CN-DPM is the need for a separate generative and discriminative model per condition. More generally, such an approach requires meaningful density estimation in the input space, a requirement that is challenging for modern ML problems [64]. Other Bayesian CL approaches instead consider task-specific posterior parameter distributions. Lee et al.
[47] learn separate task-specific Gaussian posterior approximations, which are merged into a single posterior after all tasks have been seen. CBLN [49] also learns a separate Gaussian posterior approximation per task, but later tries to merge similar posteriors in the induced Gaussian mixture model. Task inference is thus required and is achieved via predictive uncertainty, although for a more reliable estimation all experiments consider batches of 200 samples that are assumed to belong to the same task. Tuor et al. [85] also learn a separate approximate posterior per task and use predictive uncertainty for task-boundary detection and task inference. In contrast to these approaches, we learn all task-specific posteriors via a single shared meta-model and remain agnostic to the approximate inference method being used. A conceptually related approach is MERLIN [33], which learns task-specific weight distributions by training an ensemble of models per task that is used as the training set for a task-conditioned variational autoencoder. Importantly, MERLIN requires a fine-tuning stage at inference, such that every drawn model is fine-tuned on stored coresets, i.e., small sets of samples withheld throughout training. By contrast, our approach learns the parameters of an approximate Bayesian posterior p(W | D(t)) per task t, and no fine-tuning of drawn models is required.

Bayesian neural networks. Because neural networks are expressive enough to fit almost any data [98] and are often deployed in an overparametrized regime, it is implausible to expect that any single solution obtained from limited data generalizes to the ground truth, i.e., that p(Y | w, X) ≈ p(Y | X) almost everywhere under p(X). By contrast, Bayesian statistics considers a distribution over models, explicitly handling uncertainty to acknowledge data insufficiencies.
This distribution is called the posterior parameter distribution p(W | D) ∝ p(D | W) p(W), which weights models based on their ability to fit the data (via the likelihood p(D | W)), while considering only plausible models according to the prior p(W). Predictions are made by marginalizing over models (for an introduction see MacKay [57]). Bayesian neural networks (BNNs) apply this formalism to network parameters w, whereas for practical reasons hyperparameters like the architecture are chosen deterministically [56]. While a deterministic discriminative model can only capture aleatoric uncertainty (i.e., uncertainty intrinsic to the data p(Y | X)), a Bayesian treatment additionally allows capturing epistemic uncertainty by being uncertain about the model's parameters (parameter uncertainty). This proper treatment of uncertainty is of utmost importance for safety-critical applications, where intelligent systems are expected to know what they don't know. However, due to the complexity of modelling high-dimensional distributions at the scale of modern deep learning, BNNs still face severe scalability issues [82]. Here, we employ several approximations to the posterior based on variational inference [5] from prior work, ranging from simple and scalable methods with a mean-field variational family like Bayes-by-Backprop (BbB, [6]) to methods with complex but rich variational families like the spectral Stein gradient estimator [79]. For more details see Sec. 3 and SM C.

Figure 2: Posterior meta-replay for CL. (a) The architecture consists of a main network M that processes inputs x and generates predictions ŷ according to a set of weights w generated by a weight generator (WG).
The WG is a deterministic function f_WG(z, θ(t)) that transforms a base distribution p(Z) into a distribution over main-network weights, where θ(t) are the parameters of the approximate posterior q_θ(t)(W). Crucially, θ(t) are task-specific, and are generated by a task-conditioned (TC) hypernetwork, which receives task embeddings e(t) as input. The embeddings and the parameters ψ of the TC are learned continually via a simple meta-replay regularizer (Eq. 1). (b) We refer to the approximate posteriors as explicit if f_WG is predefined. In Bayes-by-Backprop (BbB), for example, the reparametrization trick transforms Gaussian noise into weight samples. (c) More complex, implicit posterior approximations are parametrized by an auxiliary hypernetwork, which receives its task-conditioned parameters from the TC, which now plays the role of a hyper-hypernetwork. The obtained posterior approximations are more flexible and can, for example, capture multi-modality.

## 3 Posterior Meta-Replay

In this section we describe our posterior meta-replay framework (Fig. 2). We start by introducing task-conditioned hypernetworks as a tool to continually learn the parameters of task-specific posteriors, each of which is learned using variational inference (SM C.1). We then explain how the framework can be instantiated for both simple, explicit posterior approximations and complex ones parametrized by an auxiliary network, and describe how forgetting can be mitigated through the use of a meta-regularizer. We next explain how predictive uncertainty, naturally arising from a probabilistic view of learning, can be used to infer task identity both for Posterior-Replay methods and for Prior-Focused methods that use a multihead output. Finally, we outline ways to boost task inference.

Task-conditioned hypernetworks. Traditionally, hypernetworks are seen as neural networks that generate the weights w of a main network M processing inputs as ŷ = f_M(x, w) [22, 77].
Here, we consider instead hypernetworks that learn to generate θ, the parameters of a distribution q_θ(W) over main-network weights. By taking low-dimensional task embeddings e(t) as inputs and computing θ(t) = f_TC(e(t), ψ), task-conditioned (TC) computation is possible. Sampling is realized by transforming a base distribution p(Z) via a weight generator (WG) f_WG(z, θ(t)), whose choice determines the family of distributions considered for the approximation (i.e., the variational family). In our framework, weights w ~ q_θ(t)(W) are directly used for inference without requiring any fine-tuning. Importantly, all learnable parameters are contained in the TC system, which can be designed to have fewer parameters than the main network, i.e., dim(ψ) + Σ_t dim(e(t)) < dim(w). Such a constraint is vital to ensure fairness when comparing different CL methods, and is enforced in all our computer vision experiments. Additional details can be found in SM C.2.

Posterior-replay with explicit distributions. Different families of distributions can be realized within our framework. In the special case of a point estimate q_θ(t)(W) = δ(W − θ(t)), the WG system can be omitted altogether, as it corresponds to the identity θ(t) = f_WG(z, θ(t)). This reduces our solution to the deterministic CL method introduced by von Oswald et al. [91], which we refer to as Posterior Replay-Dirac. However, capturing parameter uncertainty is a key ingredient of Bayesian statistics that is necessary for more robust task inference (cf. Sec. 4.2). We thus turn as a first step to explicit distributions q_θ(t)(W), for which the WG system samples according to a predefined function. We refer to finding a mean-field Gaussian approximation via the BbB algorithm (SM C.3.1, [6]) as Posterior Replay-Exp. In this case, θ(t) corresponds to the set of means and variances that define a Gaussian for each weight, and is directly generated by the TC.
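The explicit (mean-field Gaussian) case above can be sketched in a few lines. The sketch below is a toy illustration under our own assumptions: the TC hypernetwork is reduced to a plain linear map (named `f_tc` here, hypothetical), and the WG is the BbB-style reparametrization trick:

```python
import numpy as np

rng = np.random.default_rng(0)

DIM_W = 200  # number of main-network weights (toy size)
DIM_E = 8    # task-embedding dimensionality

# Hypothetical linear TC "hypernetwork" with parameters psi: maps a task
# embedding e(t) to theta(t) = (mean, log-variance) of a mean-field
# Gaussian over the main-network weights w.
psi = rng.normal(0.0, 0.1, size=(2 * DIM_W, DIM_E))
embeddings = [rng.normal(size=DIM_E) for _ in range(3)]  # e(1), e(2), e(3)

def f_tc(e, psi):
    out = psi @ e
    return out[:DIM_W], out[DIM_W:]  # mu(t), log_var(t)

def f_wg(z, theta):
    # Explicit weight generator: reparametrization trick,
    # w = mu + sigma * z with z ~ N(0, I).
    mu, log_var = theta
    return mu + np.exp(0.5 * log_var) * z

theta_t = f_tc(embeddings[0], psi)         # condition on task 1
w = f_wg(rng.normal(size=DIM_W), theta_t)  # one posterior weight sample
assert w.shape == (DIM_W,)

# Parameter count of the TC system, dim(psi) + sum_t dim(e(t)). Note that
# this naive linear map does not satisfy the budget constraint against
# dim(w); real implementations use techniques such as chunked generation
# to keep the TC smaller than the main network.
n_tc = psi.size + sum(e.size for e in embeddings)
```

The sampled `w` would be plugged directly into the main network for a prediction; averaging predictions over several such samples approximates the posterior predictive.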
In the SM, we also report results for another instance of an explicit distribution (cf. SM C.3.2).

Posterior-replay with implicit distributions. Since the expressivity of explicit distributions is limited, we also explore the more diverse variational family of implicit distributions [19, 31]. These are parametrized by a WG that now takes the form of an auxiliary neural network, making the parameters θ(t) of the approximate posterior dependent on the chosen WG architecture. This setting, referred to as Posterior Replay-Imp, results in a hierarchy of three networks: a TC network generates task-specific parameters θ(t) for the approximate posterior, which is defined through an arbitrary base distribution p(Z) and the WG hypernetwork, which in turn generates weights w for a main network M that processes the actual inputs of the dataset D(t). Interestingly, the TC now plays the role of a hyper-hypernetwork, as it generates the weights of another hypernetwork (Fig. 2a and Fig. 2c). Variational inference commonly resorts to optimizing an objective consisting of a data-dependent term and a prior-matching term KL(q_θ(W) || p(W)). Estimating the prior-matching term when using implicit distributions is not straightforward, since we have analytic access to neither the density nor the entropy of q_θ(t)(W). To overcome this challenge, we resort to the spectral Stein gradient estimator (SSGE, SM C.4.2, [79]). This method is based on the insight that direct access to the log-density is not required, but only to its gradient with respect to W. Noticing that this quantity appears in Stein's identity, the authors consider a spectral decomposition of the term and use the Nyström method to approximate the eigenfunctions. In the SM, we test an alternative method for dealing with implicit distributions that is based on estimating the log-density ratio (SM C.4.1).
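The three-network hierarchy of the implicit case can be made concrete with a small sketch. Here the TC is again reduced to a hypothetical linear map, and the WG is a two-layer MLP whose weights `theta_t` are themselves produced by the TC (all sizes and names are ours, not the paper's):

```python
import numpy as np

rng = np.random.default_rng(1)

DIM_Z, DIM_H, DIM_W, DIM_E = 16, 32, 200, 8

# Number of parameters theta(t) of the auxiliary WG network:
# a two-layer MLP z -> h -> w.
N_THETA = DIM_H * DIM_Z + DIM_H + DIM_W * DIM_H + DIM_W

# Hypothetical linear TC ("hyper-hypernetwork") generating theta(t)
# from the task embedding e(t).
psi = rng.normal(0.0, 0.05, size=(N_THETA, DIM_E))

def wg(z, theta):
    """Implicit weight generator: a deterministic MLP that pushes a base
    sample z ~ p(Z) forward into a weight sample w ~ q_theta(W)."""
    i = 0
    W1 = theta[i:i + DIM_H * DIM_Z].reshape(DIM_H, DIM_Z); i += DIM_H * DIM_Z
    b1 = theta[i:i + DIM_H]; i += DIM_H
    W2 = theta[i:i + DIM_W * DIM_H].reshape(DIM_W, DIM_H); i += DIM_W * DIM_H
    b2 = theta[i:i + DIM_W]
    h = np.tanh(W1 @ z + b1)
    return W2 @ h + b2

e_t = rng.normal(size=DIM_E)                  # task embedding e(t)
theta_t = psi @ e_t                           # theta(t) = f_TC(e(t), psi)
w = wg(rng.normal(size=DIM_Z), theta_t)       # one draw from q_theta(t)(W)
assert w.shape == (DIM_W,)
```

Since DIM_Z < DIM_W here, all samples lie on a 16-dimensional manifold inside the 200-dimensional weight space, which illustrates why such inflating architectures require extra care with the prior-matching term.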
As an additional challenge introduced by the use of implicit distributions, the support of q_θ(t)(W) is limited to a low-dimensional manifold when using an inflating architecture for the WG, causing the prior-matching term to be ill-defined. To overcome this, we investigate the use of small noise perturbations in WG outputs (SM C.4.3). Normalizing flows [70] can also be utilized as WG architectures to gain analytic access to q_θ(t)(W), albeit at the cost of architectural constraints such as invertibility.

Overcoming forgetting via meta-replay. Since all learnable parameters are part of the TC system, forgetting only needs to be addressed at this meta-level. With L_task the task-specific loss (SM Eq. 3) and D(· || ·) a divergence measure between distributions, the loss for task t becomes:

L(t)(ψ, E, D(t)) = L_task(ψ, e(t), D(t)) + β Σ_{t'<t} D( q_{θ*(t')}(W) || q_{θ(t')}(W) ),    (1)

where θ(t') = f_TC(e(t'), ψ) are the posterior parameters generated by the current TC, and θ*(t') = f_TC(e(t'), ψ*) are those generated from a checkpoint ψ* of the TC parameters stored before training on task t.
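A minimal sketch of such a meta-replay regularizer is given below, assuming for concreteness a mean-field Gaussian variational family so that the divergence can be taken as a closed-form KL between the posteriors generated by a frozen checkpoint of the TC and those generated by the current TC (all function names and sizes are ours, for illustration):

```python
import numpy as np

def kl_diag_gauss(mu_q, var_q, mu_p, var_p):
    """KL( N(mu_q, var_q) || N(mu_p, var_p) ) for diagonal Gaussians."""
    return 0.5 * np.sum(np.log(var_p / var_q)
                        + (var_q + (mu_q - mu_p) ** 2) / var_p - 1.0)

def meta_replay_reg(tc, psi, psi_ckpt, embeddings, task_id):
    """Sum over past tasks t' < task_id of the divergence between the
    posterior generated by the frozen checkpoint psi_ckpt and the one
    generated by the current parameters psi (scaled by beta in the loss)."""
    reg = 0.0
    for t in range(task_id):
        mu_old, var_old = tc(embeddings[t], psi_ckpt)
        mu_new, var_new = tc(embeddings[t], psi)
        reg += kl_diag_gauss(mu_old, var_old, mu_new, var_new)
    return reg

# Toy TC: a linear map from embedding to (mean, softplus-variance).
D_W, D_E = 50, 4
rng = np.random.default_rng(0)

def tc(e, psi):
    out = psi @ e
    return out[:D_W], np.log1p(np.exp(out[D_W:]))  # mu, var > 0

psi_ckpt = rng.normal(0.0, 0.1, size=(2 * D_W, D_E))
embeddings = [rng.normal(size=D_E) for _ in range(3)]

# Unchanged TC parameters incur no penalty; drifting ones do.
assert meta_replay_reg(tc, psi_ckpt, psi_ckpt, embeddings, task_id=2) == 0.0
reg = meta_replay_reg(tc, psi_ckpt + 0.01, psi_ckpt, embeddings, task_id=2)
assert reg > 0.0
```

During training on task t, this penalty is added (weighted by β) to the task loss, so that gradient updates to ψ leave the posteriors of earlier tasks essentially untouched.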