Published as a conference paper at ICLR 2023

LEARNING IN TEMPORALLY STRUCTURED ENVIRONMENTS

Matt Jones,1,2 Tyler R. Scott,1 Mengye Ren,1,3 Gamaleldin F. Elsayed,1 Katherine Hermann,1 David Mayo,1,4 Michael C. Mozer1
1Brain Team, Google Research  2University of Colorado  3NYU  4MIT
mcj@colorado.edu  dmayo2@mit.edu  mengye@cs.nyu.edu
{tylersco,gamaleldin,hermannk,mcmozer}@google.com

ABSTRACT

Natural environments have temporal structure at multiple timescales. This property is reflected in biological learning and memory but typically not in machine learning systems. We advance a multiscale learning method in which each weight in a neural network is decomposed as a sum of subweights with different learning and decay rates. Thus knowledge becomes distributed across different timescales, enabling rapid adaptation to task changes while avoiding catastrophic interference. First, we prove that previous models that learn at multiple timescales, but with complex coupling between timescales, are equivalent to multiscale learning via a reparameterization that eliminates this coupling. The same analysis yields a new characterization of momentum learning, as a fast weight with a negative learning rate. Second, we derive a model of Bayesian inference over 1/f noise, a common temporal pattern in many online learning domains that involves long-range (power-law) autocorrelations. The generative side of the model expresses 1/f noise as a sum of diffusion processes at different timescales, and the inferential side tracks these latent processes using a Kalman filter. We then derive a variational approximation to the Bayesian model and show how it is an extension of the multiscale learner. The result is an optimizer that can be used as a drop-in replacement in an arbitrary neural network architecture. Third, we evaluate the ability of these methods to handle nonstationarity by testing them in online prediction tasks characterized by 1/f noise in the latent parameters. We find that the Bayesian model significantly outperforms online stochastic gradient descent and two batch heuristics that rely preferentially or exclusively on more recent data. Moreover, the variational approximation performs nearly as well as the full Bayesian model, with memory requirements that are linear in the size of the network.

1 INTRODUCTION

Many online tasks facing both biological and artificial intelligence systems involve changes in data distribution over time. Natural environments exhibit correlations at a wide range of timescales, a pattern variously referred to as self-similarity, power-law correlations, and 1/f noise (Keshner, 1982). This is in stark contrast with the iid environments assumed by many machine learning (ML) methods, and with diffusion or random-walk environments that exhibit only short-range correlations. Moreover, biological learning systems are well tuned to the temporal statistics of natural environments, as seen in phenomena of human cognition including power laws in learning (Anderson, 1982), power-law forgetting (Wixted & Ebbesen, 1997), long-range sequential effects (Wilder et al., 2013), and spacing effects (Anderson & Schooler, 1991; Cepeda et al., 2008). An important goal is to incorporate similar inductive biases into ML systems for online or continual learning. This paper analyzes a framework for learning in temporally structured environments, multiscale learning, which for neural networks (NNs) can be implemented as a new kind of optimizer.
A common explanation for self-similar temporal structure in nature is that it arises from a mixture of events at various timescales. Indeed, many generative models of 1/f noise involve summing independent stochastic processes with varying time constants (Eliazar & Klafter, 2009). Accordingly, the multiscale optimizer comprises multiple learning processes operating in parallel at different timescales. In a NN, every weight $w_j$ is replaced by a family of subweights $\omega_{ij}$, each with its own learning rate and decay rate, that sum to determine the weight as a whole. Learning at multiple timescales is a key idea in several theories in neuroscience, spanning conditioning (Staddon et al., 2002), learning (Benna & Fusi, 2016), memory (Howard & Kahana, 2002; Mozer et al., 2009), and motor control (Kording et al., 2007), and it has also been exploited in ML (Hinton & Plaut, 1987; Rusch et al., 2022). The multiscale learner isolates and simplifies this idea by assuming that knowledge at different timescales evolves independently and that credit assignment follows gradient descent.

The first part of this paper (Sections 2 and 3) proves that three other models are formally equivalent to instances of the multiscale optimizer: a new variant of fast weights (cf. Ba et al., 2016; Hinton & Plaut, 1987), the model synapse of Benna & Fusi (2016), and momentum learning (Rumelhart et al., 1986; Qian, 1999). The insight behind these proofs is that each of these models can be written in terms of a linear update rule with a diagonalizable transition matrix. Thus the eigenvectors of this matrix correspond to states that evolve independently. By writing the state of the model as a mixture of eigenvectors, we effect a coordinate transformation that exactly yields the multiscale optimizer. These results imply that the complicated coupling among timescales assumed by some models can be superfluous. They also provide a new perspective on momentum learning, with implications for how and when it is beneficial and how it interacts with nonstationarity in the task environment.

In Section 4, we provide a normative grounding for multiscale learning in terms of Bayesian inference over 1/f noise. Our starting point is a generative model of 1/f noise as a sum of diffusion processes at different timescales. Exact Bayesian inference with respect to this generative process is possible using a Kalman filter (KF) that tracks the component processes jointly (Kording et al., 2007). When learning a single environmental parameter θ, such as the mean reward for some action in a bandit task, this amounts to modeling $\theta(t) = \sum_{i=1}^{n} z_i(t)$, where each $z_i$ is a diffusion process with a different characteristic timescale $\tau_i$, and doing joint inference over $Z = (z_1, \ldots, z_n)$. We then generalize this approach to an arbitrary statistical model, $h(x, \theta)$, where x is the input and $\theta \in \mathbb{R}^m$ is a parameter vector to be estimated. For instance, h might be a NN with parameters θ. Our Bayesian model places a 1/f prior on θ (as a stochastic process) by assuming $\theta(t) = \sum_{i=1}^{n} z_i(t)$ for diffusion processes $z_i \in \mathbb{R}^m$ with characteristic timescales $\tau_i$. We then do approximate inference over the joint state $Z = (z_1, \ldots, z_n)$, using an extended Kalman filter (EKF) that linearizes h by calculating its Jacobian at each step (Singhal & Wu, 1989; Puskorius & Feldkamp, 2003). Next, we derive a variational approximation to the EKF that constrains the covariance matrix to be diagonal, and we show how it extends the multiscale optimizer.
Specifically, writing $w_j$ and $\omega_{ij}$ for the current mean estimates of $\theta_j$ and $z_{ij}$ (for weight j and timescale i), the variational update to each $\omega_{ij}$ follows that of the multiscale optimizer, with additional machinery for determining decay rates based on $\tau_i$ and for adapting learning rates based on the current prior variance $s^2_{ij}(t)$.

In Section 5, we test our methods in online prediction and classification tasks with nonstationary distributions. In online learning, nonstationarity often manifests as poorer generalization performance on future data than on held-out data from within the training interval. Common solutions are to train on a window of fixed length (to exclude stale data) or to use stochastic gradient descent (SGD) with a fixed learning rate and weight decay, which leads older observations to have less influence (Ditzler et al., 2015). Here, we demonstrate that performance can be significantly improved by retaining all data and using a learning model that accounts for the temporal structure of the environment. We introduce nonstationarity in our simulations by varying the latent data-generating parameters according to 1/f noise. Thus an important caveat is that the task domains are matched to the Bayesian model. Notwithstanding, we test robustness by using a different set of timescales for task generation versus learning (Section 5.1), a generative process that mismatches the NN architecture (Section 5.2), and a construction of 1/f noise that differs from the sum-of-diffusions process the model assumes (Section 5.3). Results show that the Bayesian methods (KF and EKF) outperform windowing and online SGD, as well as a novel heuristic of training the network on all past data with gradients weighted by recency. We also find that the variational approximation performs nearly as well as the full model (Section 5.1) and scales well to a multilayer NN trained on real data (Section 5.3).

2 MULTISCALE OPTIMIZER

Assume a statistical model $\hat{y}(t) = h(x(t), w(t))$ and loss function $L(y, \hat{y})$, where x(t) is the input on step t, w(t) is the parameter estimate, $\hat{y}(t)$ is the model output, and y(t) is the target output. In a NN, w(t) is the vector of current weights. (Under the Bayesian framing in Section 4, w is the mean estimate of the optimal parameters θ.) For exposition, we assume the weights are updated by SGD,

$$w(t+1) = w(t) - \alpha \nabla_{w(t)} L(y(t), \hat{y}(t)), \tag{1}$$

and we henceforth abbreviate the gradient as $\nabla_{w(t)} L$. However, the following approach can be naturally composed with other optimizers, such as extensions of SGD or Hebbian learning, by replacing $\alpha \nabla_{w(t)} L$ with the appropriate update term.

Figure 1: Toy illustration of fast weights. A single weight w (blue) with constant input ($x \equiv 1$) predicts a target signal T (black) with square loss $L = \frac{1}{2}(T - w)^2$. The weight is a sum of subweights $\omega_{\text{slow}}$ (yellow) and $\omega_{\text{fast}}$ (red). Initial learning is rapid, due to $\omega_{\text{fast}}$. Because of decay and the shared error signal, knowledge is gradually transferred to $\omega_{\text{slow}}$ while $\omega_{\text{fast}}$ returns to zero. When the task switches (trial 151), $\omega_{\text{fast}}$ enables rapid adaptation while long-term knowledge is preserved in $\omega_{\text{slow}}$. Thus the model recovers quickly on the second reversal (compare the blue curve beginning on trials 1 vs. 156). The general multiscale optimizer extends this idea to an array of faster and slower weights.

The multiscale optimizer is motivated by the assumption that, in online learning tasks, the true or optimal parameters change over time, on multiple timescales. Accordingly, it expands each weight into a sum of subweights, $w_j = \sum_i \omega_{ij}$, each with a different learning rate $\alpha_i$ and decay rate $\gamma_i$. Here j indexes weights in w, and i indexes timescales. The subweights evolve according to:

$$\omega_{ij}(t+1) = \gamma_i \, \omega_{ij}(t) - \alpha_i \nabla_{w_j(t)} L. \tag{2}$$

Each $\omega_{ij}$ has characteristic timescale $\tau_i := (-\log \gamma_i)^{-1}$. Note that $\nabla_{w_j(t)} L = \nabla_{\omega_{ij}(t)} L$, so one can think of the gradient for $w_j$ as being apportioned among the subweights (with total learning rate $\alpha = \sum_i \alpha_i$), or equivalently of each subweight as following its own gradient.
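To make Equation 2 concrete, here is a minimal NumPy sketch of the multiscale update, together with a two-timescale run in the spirit of Figure 1. The class name, interface, and specific rates are illustrative assumptions, not the paper's implementation.

```python
import numpy as np

class MultiscaleSGD:
    """Minimal sketch of the multiscale update (Eq. 2). Each weight w_j is
    the sum of subweights omega_ij, one per timescale i, with its own
    learning rate alpha_i and decay rate gamma_i. Illustrative only."""

    def __init__(self, n_weights, alphas, gammas):
        self.alphas = np.asarray(alphas, dtype=float)
        self.gammas = np.asarray(gammas, dtype=float)
        self.omega = np.zeros((len(alphas), n_weights))  # omega[i, j]

    @property
    def w(self):
        return self.omega.sum(axis=0)  # w_j = sum_i omega_ij

    def step(self, grad_w):
        # All subweights of w_j share the same gradient dL/dw_j (Eq. 2).
        self.omega = (self.gammas[:, None] * self.omega
                      - self.alphas[:, None] * np.asarray(grad_w)[None, :])

# Two-timescale (fast-weight) run in the spirit of Figure 1: one weight
# tracks a target T under square loss L = 0.5 * (T - w)^2, with a
# non-decaying slow subweight and a fast, decaying one. Rates are made up.
opt = MultiscaleSGD(n_weights=1, alphas=[0.02, 0.3], gammas=[1.0, 0.9])
for t in range(300):
    T = 1.0 if t < 150 else -1.0   # task reversal, as in Figure 1
    opt.step(-(T - opt.w))         # gradient of the square loss w.r.t. w
```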
2.1 FAST WEIGHTS

A potentially important special case of multiscale learning arises with two timescales, $w = \omega_{\text{slow}} + \omega_{\text{fast}}$. We assume $\gamma_{\text{slow}} = 1$ (no decay) and $\alpha_{\text{fast}} > \alpha_{\text{slow}}$. Thus each $\omega_{\text{slow},j}$ can be thought of as the original weight, which is augmented by $\omega_{\text{fast},j}$, a second channel between the same neurons that both learns and decays rapidly. The fast weight enables the system to adapt quickly to distribution shifts while resisting catastrophic forgetting (Figure 1).

This model is conceptually similar to the fast-weights approach of Ba et al. (2016) and Hinton & Plaut (1987). In that work, the fast weights are updated by a different mechanism (Hebbian learning) than the primary weights, and they act as a memory of recent hidden states in a recurrent network. In the present conception, fast weights optimize the same loss as the primary weights, only with different temporal properties, and they act as a memory for recent learning signals (e.g., loss gradients). Thus they are perhaps better suited for handling distribution shifts of the sort considered here.

3 EQUIVALENCE RESULTS

3.1 BENNA-FUSI SYNAPSE

Benna & Fusi's (2016) model synapse is designed to capture how biochemical mechanisms in real synapses implement a cascading hierarchy of timescales, and it has been adopted in ML for continual reinforcement learning (Kaplanis et al., 2018; 2019). We consider a single weight w in a network, suppressing the index j. The Benna-Fusi model assumes that the information in w is maintained in a 1D hierarchy of variables $u_1, \ldots, u_n$, each dynamically coupled to its immediate neighbors:

$$C_1 \left( u_1(t+1) - u_1(t) \right) = g_1 \left( u_2(t) - u_1(t) \right) - \nabla_{w(t)} L \tag{3}$$
$$C_k \left( u_k(t+1) - u_k(t) \right) = g_{k-1} \left( u_{k-1}(t) - u_k(t) \right) + g_k \left( u_{k+1}(t) - u_k(t) \right) \tag{4}$$

for $2 \le k \le n$, with $g_n = 0$. The external behavior of the synapse comes from $u_1$ alone (i.e., $w = u_1$), while $u_{2:n}$ act as stores with progressively longer timescales. This update rule can be rewritten as

$$u(t+1) = T u(t) - d(t), \tag{5}$$

with transition matrix T determined by the coefficients in Equations 3 and 4, and external signal d(t) defined by $d_1(t) = \frac{1}{C_1} \nabla_{w(t)} L$ and $d_{2:n} \equiv 0$. It can be shown that the transition matrix is diagonalizable, $T = V \Lambda V^{-1}$, with eigenvalues $\Lambda_{ii} = \lambda_i < 1$ (see Appendix A). We can further enforce $V_{1i} = 1$ for all i, for a purpose explained below. We refer to the eigenvectors (columns $V_{\cdot i}$) as modes of the system, because they are preserved over time up to a scalar. That is, if the initial state is proportional to mode i, then in the absence of an external signal ($d \equiv 0$), the system will remain in that mode, decaying exponentially with rate factor $\lambda_i$:

$$u(0) \propto V_{\cdot i} \implies \forall t : u(t) = \lambda_i^t \, u(0). \tag{6}$$

In general, any state can be written uniquely as a linear combination of modes, $u = \sum_i \omega_i V_{\cdot i} = V \omega$. Therefore, reparameterizing the model as $\omega := V^{-1} u$ yields the simplified update equation:

$$\omega(t+1) = \Lambda \omega(t) - V^{-1} d(t), \tag{7}$$

where $V^{-1} d(t) = \frac{1}{C_1} [V^{-1}]_{\cdot 1} \nabla_{w(t)} L$. Because Λ is diagonal, there is no cross-talk between the modes, unlike in the original dynamics. Thus we have derived an instance of the multiscale optimizer, with subweights $\omega_i(t)$, decay rates $\lambda_i$, and learning rates $\frac{1}{C_1} [V^{-1}]_{i1}$. The assumption above, $V_{1i} = 1$, implies $w = u_1 = \sum_i \omega_i$, so the two models agree on the external behavior of the weight as a whole. Figure 2 illustrates the translation between the two models.
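The translation can also be checked numerically. The sketch below builds the transition matrix T implied by Equations 3 and 4, rescales the eigenvectors so that $V_{1i} = 1$, and verifies that freely decaying the coupled dynamics (Equation 5 with $d \equiv 0$) matches decaying each mode independently and recombining, as in Figures 2B-2E. The coefficients $C_k$ and $g_k$ here are placeholder choices, not the paper's defaults from Appendix A.

```python
import numpy as np

n = 8
C = 2.0 ** np.arange(n)          # placeholder coefficients; the paper's
g = 2.0 ** -np.arange(1, n + 1)  # defaults are given in Appendix A
g[-1] = 0.0                      # g_n = 0

# Transition matrix T implied by Eqs. 3-4 (gradient term omitted here).
T = np.eye(n)
for k in range(n):
    if k > 0:
        T[k, k - 1] += g[k - 1] / C[k]
        T[k, k] -= g[k - 1] / C[k]
    T[k, k] -= g[k] / C[k]
    if k < n - 1:
        T[k, k + 1] += g[k] / C[k]

lam, V = np.linalg.eig(T)
lam, V = lam.real, V.real        # eigenstructure is real for this chain
V = V / V[0]                     # rescale each column so that V_1i = 1

u = np.random.default_rng(0).standard_normal(n)  # arbitrary state (Fig. 2B)
omega = np.linalg.solve(V, u)    # modes: omega = V^{-1} u (Fig. 2C)

for _ in range(1000):            # coupled free decay of the Benna-Fusi model
    u = T @ u
omega = omega * lam ** 1000      # decoupled decay, one rate per mode (Fig. 2D)

assert np.allclose(V @ omega, u) # reconstruction matches exactly (Fig. 2E)
```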
Figure 2: Translation between the model of Benna & Fusi (2016) and the multiscale optimizer, by decomposing the state of the former model into modes: eigen-patterns of activation that decay independently and correspond to subweights in the multiscale optimizer. A: All modes for a default Benna-Fusi model with eight variables (n = 8). B: An arbitrary initial state of the model. C: Unique eigendecomposition of the state in Figure 2B. Implied values of the corresponding multiscale optimizer's subweights can be read off as the values of the curves at k = 1. D: Decay of the individual modes or subweights for 1000 steps (with no external input) at rates given by their eigenvalues. E: Reconstruction of the final state exactly matches the result of iterating the Benna-Fusi update (dotted arrow from Figure 2B). F: Decomposition of a unit impulse to $u_1$ (e.g., a loss gradient, shown as a grey bar) as a weighted sum of modes. Learning rates for the corresponding subweights, $\omega_i$, can be read off as the values of the curves at k = 1 (because $V_{1i} = 1$).

3.2 MOMENTUM LEARNING

The standard rationale for momentum learning is to smooth updates over time, so that oscillations along directions of high curvature cancel out while progress can be made in directions with consistent gradients (Rumelhart et al., 1986). To simplify notation, we again focus on a single weight w in the network, suppressing the index j. The momentum g is defined as an exponentially filtered running average of gradients, with the weight update determined by the current momentum:

$$g(t+1) = \beta g(t) + (1-\beta) \nabla_{w(t)} L \tag{8}$$
$$w(t+1) = w(t) - \eta \, g(t+1). \tag{9}$$

This formulation is equivalent to one in which the update $\Delta w(t) = w(t+1) - w(t)$ includes a portion of the previous update, $\Delta w(t) = -\alpha \nabla_{w(t)} L + \beta \, \Delta w(t-1)$, with $\alpha = \eta(1-\beta)$. Paralleling the analysis in Section 3.1, we write the state of the momentum optimizer as $[w, g]^\top$ and use Equations 8 and 9 to obtain the update rule:

$$\begin{bmatrix} w(t+1) \\ g(t+1) \end{bmatrix} = \begin{bmatrix} 1 & -\eta\beta \\ 0 & \beta \end{bmatrix} \begin{bmatrix} w(t) \\ g(t) \end{bmatrix} + \begin{bmatrix} -\eta(1-\beta) \\ 1-\beta \end{bmatrix} \nabla_{w(t)} L. \tag{10}$$

The transition matrix has eigenvectors $[1, 0]^\top$ with eigenvalue 1, and $\left[1, \frac{1-\beta}{\eta\beta}\right]^\top$ with eigenvalue β. Now use this eigenbasis to define a reparameterization:

$$\begin{bmatrix} w \\ g \end{bmatrix} = \begin{bmatrix} 1 & 1 \\ 0 & \frac{1-\beta}{\eta\beta} \end{bmatrix} \begin{bmatrix} \omega_{\text{slow}} \\ \omega_{\text{fast}} \end{bmatrix}. \tag{11}$$

Substitution into Equation 10 yields the reparameterized update rule:

$$\begin{bmatrix} \omega_{\text{slow}}(t+1) \\ \omega_{\text{fast}}(t+1) \end{bmatrix} = \begin{bmatrix} 1 & 0 \\ 0 & \beta \end{bmatrix} \begin{bmatrix} \omega_{\text{slow}}(t) \\ \omega_{\text{fast}}(t) \end{bmatrix} + \begin{bmatrix} -\eta \\ \eta\beta \end{bmatrix} \nabla_{w(t)} L. \tag{12}$$

Thus we recover the fast-weight optimizer, with decay $\gamma_{\text{fast}} = \beta$ and learning rates $\alpha_{\text{slow}} = \eta$ and $\alpha_{\text{fast}} = -\eta\beta$. The negative fast learning rate is perhaps surprising but can be understood as follows: when $\alpha_{\text{fast}} < 0$, the subweights learn in opposite directions, with the latent knowledge in $\omega_{\text{slow}}$ overshooting the observable knowledge in $w = \omega_{\text{slow}} + \omega_{\text{fast}}$. As $\omega_{\text{fast}}$ decays toward 0, w catches up to $\omega_{\text{slow}}$, so that the model appears to continue learning from past input, just as it would with momentum. This analysis highlights the contrasting rationales of the two methods: learning at multiple timescales is motivated by an expectation of positive autocorrelation in the environment, whereas momentum is effective at smoothing out negative autocorrelation in the gradient signal.
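The equivalence is exact and easy to confirm numerically: the sketch below runs standard momentum (Equations 8 and 9) alongside the decoupled form (Equation 12) on the same arbitrary gradient stream and checks that $w = \omega_{\text{slow}} + \omega_{\text{fast}}$ on every step. The hyperparameter values are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(0)
eta, beta = 0.1, 0.9

w, g = 0.0, 0.0             # standard momentum state (Eqs. 8-9)
w_slow, w_fast = 0.0, 0.0   # decoupled form (Eq. 12)

for _ in range(1000):
    grad = rng.standard_normal()        # arbitrary gradient stream
    # Momentum learning:
    g = beta * g + (1 - beta) * grad
    w = w - eta * g
    # Equivalent multiscale form: gamma_slow = 1, alpha_slow = eta;
    # gamma_fast = beta, alpha_fast = -eta * beta (negative learning rate).
    w_slow = w_slow - eta * grad
    w_fast = beta * w_fast + eta * beta * grad
    assert np.isclose(w, w_slow + w_fast)
```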
4 BAYESIAN MULTISCALE OPTIMIZER

We turn now to a normative analysis of learning at multiple timescales, based on Bayesian inference over 1/f noise. The Bayesian model introduced here assumes that the latent parameters θ governing the observed data in some learning task vary over time according to 1/f noise. When the statistical model h(x, θ) is linear in θ, exact Bayesian inference is possible with a KF that maintains a posterior over an expanded representation of θ. When the model is nonlinear, approximate Bayesian inference is achieved by an EKF that uses a linear approximation of h. We then show that a variational approximation of the KF or EKF, in which the posterior covariance matrix is constrained to be diagonal, yields an extension of the multiscale optimizer that adapts its learning rates online by tracking uncertainty.

4.1 GENERATIVE MODEL FOR 1/f NOISE

Let $z_i(t)$ be an Ornstein-Uhlenbeck process (i.e., diffusion with decay), with timescale or inverse decay rate $\tau_i$ and diffusion rate $\sigma_i^2$, defined by the following stochastic differential equation:

$$dz_i = -\tau_i^{-1} z_i \, dt + \sigma_i \, dW. \tag{13}$$

Here W(t) is a standard Wiener process (Brownian motion). As a Gaussian process, $z_i$ has kernel $\mathbb{E}[z_i(t) \, z_i(t+s)] \propto e^{-|s|/\tau_i}$, implying exponentially decaying autocorrelations. However, a superposition of such processes at different timescales can have qualitatively different properties (Eliazar & Klafter, 2009). In particular, consider

$$\xi(t) = \sum_{i=1}^{n} z_i(t), \tag{14}$$

where $\tau_i = \nu^i$ and $\sigma_i = \nu^{-i/2}$ for a chosen ν > 1, and n is an integer such that $\tau_n$ is very large. We show in Appendix B that ξ has power-law (i.e., long-range) autocorrelations, $\mathbb{E}[\xi(t)\xi(t+s)] \propto |s|^{-1}$ for $s \lesssim \tau_n$, and accordingly a power spectrum that is well approximated by 1/f for frequencies $f \gtrsim \tau_n^{-1}$. Moreover, m independent copies of this process constitute m-dimensional 1/f noise, due to the rotational invariance of multidimensional Ornstein-Uhlenbeck processes.

This construction affords a flexible generative model of nonstationarity in a variety of online learning domains, obtained by applying it to the latent parameters governing the relationships among observable variables. Assume we receive observations x(t), y(t) that we wish to model with a statistical model h that is parameterized by $\theta \in \mathbb{R}^m$:

$$y(t) = h(x(t), \theta(t)). \tag{15}$$

For example, h may be a NN with weights θ, input x, and target output y. The generative side of our Bayesian model posits latent variables $z_i$ (i = 1, ..., n) such that each $z_i$ is an Ornstein-Uhlenbeck process in $\mathbb{R}^m$ with timescale $\tau_i$, and these processes sum to determine the original parameters:

$$\theta(t) = \sum_{i=1}^{n} z_i(t). \tag{16}$$

These assumptions imply that θ follows a 1/f process, and they entail an expanded state representation, $Z = (z_1, \ldots, z_n) \in \mathbb{R}^{nm}$, that enables efficient inference as described in Section 4.2.

4.2 INFERENCE OVER 1/f NOISE VIA EXTENDED KALMAN FILTER

We consider Bayesian methods that adopt the construction in Section 4.1 as a generative model to account for nonstationarity. Equations 13 and 14 describe a linear dynamical system with state $Z = (z_1, \ldots, z_n) \in \mathbb{R}^n$. If ξ is directly observed at discrete intervals, then optimal Bayesian online prediction of each ξ(t) based on all preceding observations can be implemented by a KF over Z (Kording et al., 2007) (see Appendix D). We extend this approach to arbitrary statistical models with nonstationarity in their latent parameters, as in Equations 15 and 16.
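For the scalar, directly observed case, both the generative construction (Equations 13 and 14) and KF inference over Z are compact enough to sketch together. The code below uses an exact unit-step discretization of each Ornstein-Uhlenbeck process; the observation-noise level and other constants are illustrative assumptions rather than settings from the paper.

```python
import numpy as np

rng = np.random.default_rng(1)
nu, n = 4.0, 8
tau = nu ** np.arange(1, n + 1)              # tau_i = nu^i
sigma = nu ** (-np.arange(1, n + 1) / 2.0)   # sigma_i = nu^(-i/2)

# Exact unit-step discretization of each OU process (Eq. 13):
# z_i(t+1) = a_i z_i(t) + noise with variance q_i.
a = np.exp(-1.0 / tau)
var_stat = sigma ** 2 * tau / 2              # stationary variance of z_i
q = var_stat * (1 - a ** 2)
R = 0.1                                      # observation noise (illustrative)
steps = 10_000

# Generate xi(t) = sum_i z_i(t) (Eq. 14), observed with noise.
z = rng.normal(0.0, np.sqrt(var_stat))
obs = np.empty(steps)
for t in range(steps):
    z = a * z + rng.normal(0.0, np.sqrt(q))
    obs[t] = z.sum() + rng.normal(0.0, np.sqrt(R))

# Kalman filter over the joint latent state Z = (z_1, ..., z_n).
H = np.ones((1, n))                          # observation model: xi = H @ Z
m, P = np.zeros(n), np.diag(var_stat)        # prior mean and covariance
preds = np.empty(steps)
for t in range(steps):
    m = a * m                                # predict step (A = diag(a))
    P = (a[:, None] * P) * a[None, :] + np.diag(q)
    preds[t] = m.sum()                       # one-step prediction of xi(t)
    S = H @ P @ H.T + R                      # innovation variance
    K = (P @ H.T) / S                        # Kalman gain
    m = m + K[:, 0] * (obs[t] - preds[t])
    P = P - K @ H @ P

# The filter's one-step predictions should beat simply repeating the
# previous observation (a persistence baseline).
print(np.mean((preds - obs) ** 2), np.mean((obs[1:] - obs[:-1]) ** 2))
```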
When h is linear in θ (and hence in Z), as in the regression task and one-layer perceptron model of Section 5.1, exact inference is possible with a standard KF (Appendix D). For a general h, such as a multilayer NN, we use an EKF. The EKF makes a local linear approximation of h based on its Jacobian, the matrix of gradients of the predictions ŷ with respect to θ (Appendix E). We use Ollivier's (2018) generalization of the EKF, which replaces Gaussian observation noise with an arbitrary exponential-family likelihood p(y|ŷ); this is better suited for modeling discrete outcomes, such as in the classification tasks of Sections 5.2 and 5.3.

4.3 VARIATIONAL APPROXIMATION

Finally, we derive a variational approximation of the EKF that extends the multiscale optimizer and affords efficient implementation in large NNs (Appendix F). As is standard, the EKF maintains an iterative prior over the latent state based on all previous observations: $p(Z(t) \mid x(1{:}t-1), y(1{:}t-1))$.
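While the full derivation is given in Appendix F, the overall shape of the variational update previewed in Section 1 can be sketched: each subweight carries a scalar variance $s^2_{ij}$ that decays and diffuses at its own timescale, and the effective learning rate for $\omega_{ij}$ scales with that variance in Kalman-gain fashion. The following is a rough caricature under those assumptions, with a constant stand-in for the EKF's linearized observation precision; it should not be read as the paper's exact algorithm.

```python
import numpy as np

class VariationalMultiscale:
    """Caricature of the diagonal (variational) multiscale update: each
    subweight omega_ij carries a variance s2_ij; decay follows tau_i and
    the learning rate tracks s2_ij, Kalman-gain style. The paper's exact
    recursions (Appendix F) may differ."""

    def __init__(self, n_weights, taus, sigmas, obs_var=1.0):
        taus, sigmas = np.asarray(taus), np.asarray(sigmas)
        decay = np.exp(-1.0 / taus)                     # gamma_i from tau_i
        self.gamma = decay[:, None]
        self.q = ((sigmas ** 2 * taus / 2) * (1 - decay ** 2))[:, None]
        self.omega = np.zeros((len(taus), n_weights))   # mean estimate of z_ij
        self.s2 = np.broadcast_to(                      # prior variance s2_ij
            (sigmas ** 2 * taus / 2)[:, None], (len(taus), n_weights)).copy()
        self.obs_var = obs_var  # stand-in for the linearized likelihood term

    @property
    def w(self):
        return self.omega.sum(axis=0)

    def step(self, grad_w):
        # Predict: subweights decay; uncertainty decays and diffuses.
        self.omega *= self.gamma
        self.s2 = self.gamma ** 2 * self.s2 + self.q
        # Update: per-subweight learning rate = s2_ij / total predictive
        # variance, mirroring a scalar Kalman gain; then shrink variances.
        total = self.s2.sum(axis=0) + self.obs_var      # one total per weight j
        self.omega -= (self.s2 / total) * np.asarray(grad_w)[None, :]
        self.s2 -= self.s2 ** 2 / total
```

In a full treatment, the gain would additionally depend on the model's Jacobian and the curvature of the exponential-family likelihood, as in the EKF of Section 4.2.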