# Learning Associative Memories with Gradient Descent

Vivien Cabannes¹ Berfin Şimşek² Alberto Bietti²

¹Meta AI ²Flatiron. Correspondence to: . Proceedings of the 41st International Conference on Machine Learning, Vienna, Austria. PMLR 235, 2024. Copyright 2024 by the author(s).

This work focuses on the training dynamics of one associative memory module storing outer products of token embeddings. We reduce this problem to the study of a system of particles, which interact according to properties of the data distribution and correlations between embeddings. Through theory and experiments, we provide several insights. In overparameterized regimes, we obtain logarithmic growth of the classification margins. Yet, we show that imbalance in token frequencies and memory interferences due to correlated embeddings lead to oscillatory transitory regimes. The oscillations are more pronounced with large step sizes, which can create benign loss spikes, although these learning rates speed up the dynamics and accelerate the asymptotic convergence. In underparameterized regimes, we illustrate how the cross-entropy loss can lead to suboptimal memorization schemes. Finally, we assess the validity of our findings on small Transformer models.

## 1. Introduction

Modern machine learning often involves discrete data, whether it is labels in a classification problem, sequences of text tokens in language modeling, or sequences of discrete codes when dealing with other modalities. In such settings it is common to consider cross-entropy objectives, and to embed each input and output token into high-dimensional embedding vectors. Deep learning architectures consist of transforming the embedding vectors by a cascade of linear matrix multiplications together with non-linear operations. This work aims at obtaining a fine-grained understanding of training a single such linear layer with the cross-entropy loss and fixed embeddings. Indeed, one could then see the training of deep models as the joint training of multiple such associative memory models. Although our setup admits a standard convex analysis treatment, we felt the need to provide a finer picture, more in line with behaviors observed when training large neural networks.

We consider $N$ input tokens $x \in [N]$, each associated with some output $y = f(x) \in [M]$ for a deterministic function $f : [N] \to [M]$ and some number $M$ of classes.¹ The input variable $x$ is assumed to be drawn from a data distribution $p(x)$. The goal is to learn the input-output relationship $f$ with a model of the form

$$f_W(x) = \arg\max_{y \in [M]} \langle u_y, W e_x \rangle, \qquad \text{with } W \in \mathbb{R}^{d \times d}, \tag{1}$$

where $e_x, u_y \in \mathbb{R}^d$ are fixed input/output token embeddings with $d \geq 2$, and $W$ is a parameter to be learned. The quality of a model $f_W$ is typically measured through the 0-1 loss

$$\mathcal{L}_{01}(f_W) = \mathbb{E}_{X,Y}\big[\mathbf{1}_{f_W(X) \neq Y}\big] = \sum_x p(x)\,\mathbf{1}_{f_W(x) \neq f(x)},$$

while $W$ is learned by optimizing a surrogate loss with gradient methods. We will focus on the cross-entropy loss

$$\mathcal{L}(W) = \mathbb{E}_{X,Y}\big[\ell(W; X, Y)\big], \tag{2}$$

$$\ell(W; x, y) = \log\Big(\sum_{z \in [M]} e^{\langle u_z, W e_x \rangle}\Big) - \langle u_y, W e_x \rangle. \tag{3}$$

This loss is also known as the softmax, multinomial logistic, or negative log-likelihood loss. The model (1) can be seen as an associative memory, which stores or memorizes pairwise associations $(x, y)$.
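To make the setup concrete, here is a minimal NumPy sketch of the model (1) and the losses (2)-(3). The token counts, embedding dimension, random embeddings, and uniform data distribution below are illustrative assumptions, not the paper's experimental settings.

```python
import numpy as np

rng = np.random.default_rng(0)
N, M, d = 8, 4, 16                              # illustrative sizes (assumption)
E = rng.standard_normal((N, d)) / np.sqrt(d)    # input embeddings e_x (rows)
U = rng.standard_normal((M, d)) / np.sqrt(d)    # output embeddings u_y (rows)
f = rng.integers(0, M, size=N)                  # deterministic labels y = f(x)
p = np.ones(N) / N                              # data distribution p(x)

def scores(W, x):
    """All logits <u_z, W e_x> for a given input token x."""
    return U @ (W @ E[x])

def f_W(W, x):
    """Model (1): predicted class for token x."""
    return int(np.argmax(scores(W, x)))

def cross_entropy(W):
    """Population cross-entropy loss (2)-(3)."""
    total = 0.0
    for x in range(N):
        s = scores(W, x)
        total += p[x] * (np.log(np.sum(np.exp(s))) - s[f[x]])
    return total

def zero_one_loss(W):
    """0-1 loss of the classifier f_W."""
    return sum(p[x] * (f_W(W, x) != f[x]) for x in range(N))

W = np.zeros((d, d))
print(cross_entropy(W), zero_one_loss(W))   # at W = 0 the cross-entropy equals log(M)
```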
Associative memories originate in the neural computation literature, where they were used to model how the brain stores information (see, e.g., Willshaw et al., 1969; Longuet-Higgins et al., 1970), and to solve algorithmic tasks from a machine learning perspective (see, e.g., Hopfield, 1982; Hopfield & Tank, 1985). These models have gained in popularity recently, notably as candidates to explain the inner workings of some deep neural networks (Geva et al., 2021; Schlag et al., 2021; Ramsauer et al., 2021; Bietti et al., 2023; Cabannes et al., 2024). This line of work motivates the thorough study of the training dynamics behind associative memory models. This is the goal of our paper, formalized as Problem 1.

Problem 1. Understand the training dynamics of the associative memory model with the cross-entropy loss.

¹We use the notation $[p] = \{1, 2, \ldots, p\}$.

Training dynamics on the logistic loss have been the subject of a vast line of work (see, e.g., Soudry et al., 2018; Ji & Telgarsky, 2019; 2021; Lyu & Li, 2019; Wu et al., 2023). In contrast to these works, we focus on the specific structure arising from our associative memory setup. In particular, we characterize phenomena observed empirically in deep learning, such as loss spikes and oscillations, which are often necessary for faster optimization. While oscillatory behaviors have been studied in the literature on large learning rates and edge-of-stability (Cohen et al., 2020; Nakkiran, 2020; Beugnot et al., 2022; Agarwala et al., 2023; Bartlett et al., 2023; Chen & Bruna, 2023; Rosenfeld & Risteski, 2023; Wu et al., 2023), our setup provides a new perspective on the matter, involving correlated embeddings and imbalanced token frequencies.

Contributions. With the goal of resolving Problem 1, we make the following contributions:

- We show that all gradient dynamics (i.e. stochastic or deterministic, continuous-time or discrete-time) reduce to a non-linear system of interacting particles.
- We solve the deterministic dynamics for orthogonal embeddings, where the memories do not interfere much.
- We illustrate typical training behaviors in the overparameterized case $d \geq N$ by solving the system with two particles. In particular, we show that competition between memories can lead to benign oscillations and loss spikes, especially when considering large step-sizes, although large learning rates accelerate the asymptotic convergence toward robust solutions of the underlying classification problem.
- In limited-capacity regimes ($d < N$), we illustrate precisely the deterministic dynamics for $d = M = 2$, when $N \geq 2$ particles interact. We showcase how the competition between memories can ultimately erase most of them.
- We complement our analysis with experiments, investigating small multi-layer Transformer models with our associative memory viewpoint and identifying similar behaviors to those pinpointed in the simpler models.

## 2. The Many Faces of Problem 1

This section focuses on the disambiguation of Problem 1.

What type of understanding. Although the training dynamics are the result of deterministic computations on a computer, which could be described exhaustively, the causal factors behind their behavior are often too numerous for us to fully comprehend, even for the simple model (1).
To overcome this issue, one can abstract coarse quantities that play important roles in many training scenarios, such as (i) the dimension and geometry of the embeddings; (ii) data distributional properties, such as imbalanced token frequencies and heavy tails; (iii) the optimization algorithms and their hyperparameters, particularly the learning rate. We aim to highlight the effect of these factors, whose understanding would help predict the outcome of alternative training choices such as bigger learning rates, or different data curricula. Levels of understanding vary from rigorous math on small models, to controlled experiments on more complex models, or insights from the training of large-scale models without many ablation studies. This paper aims for a theoretical study in between these two extremes.

Which dynamics. We classify dynamics into five types, providing coarse and fine approximations of the dynamics used in practice to train neural networks.

Gradient flow. The gradient flow dynamics consists in letting the weight matrix $W$ evolve according to the equation

$$\mathrm{d}W_t = -\nabla \mathcal{L}(W_t)\,\mathrm{d}t. \tag{4}$$

From an initialization $W = W_0$, this deterministic evolution pushes $W_t$ towards lower values of the potential $\mathcal{L}$. In our case, $\mathcal{L}$ is convex, but does not always have a minimum, as it might only be minimized by $sW$ for some direction $W$ and $s$ going to infinity.

Gradient descent. Gradient descent is the discrete-time approximation of the gradient flow dynamics, namely

$$\Delta_t W = -\eta_t \nabla \mathcal{L}(W_t), \tag{5}$$

where $\Delta_t W = W_{t+1} - W_t$ and $\eta_t$ is a learning rate. In the continuous case, the learning rate corresponds to a reparameterization of time with $\mathrm{d}s = \eta_t\,\mathrm{d}t$. However, in the discrete-time regime, the learning rate is an important parameter that does influence the dynamics.

Stochastic gradient flow. In practice, following full-batch dynamics, i.e., dynamics that involve processing all the data at all times to compute the gradient of $\mathcal{L}(W)$, is quite costly and inconvenient. In those cases, one can process a random subset of data instead to get a good estimate of $\nabla\mathcal{L}(W)$. This randomness can be modeled as a perturbation of the dynamics due to some random noise with zero mean. For gradient flow, the stochasticity is naturally modeled with a Brownian motion $E_t$,

$$\mathrm{d}W_t = -\nabla \mathcal{L}(W_t)\,\mathrm{d}t + \sigma_t\,\mathrm{d}E_t, \tag{6}$$

where $\sigma_t$ is the variance of the updates, i.e. of the gradient $\nabla\mathcal{L}(W_t; B) := \sum_{(x,y) \in B} \nabla\ell(W_t; x, y)$, when considering random batches $B$ of data.

Stochastic gradient descent. For discrete-time dynamics, stochastic gradient descent can be written as

$$\Delta_t W = -\eta_t\big(\nabla \mathcal{L}(W_t) + \varepsilon_t\big), \tag{7}$$

where $\varepsilon_t$ is some random variable. When the descent is unbiased, this variable has zero mean. Typically,

$$\nabla\mathcal{L}(W_t) + \varepsilon_t = \sum_{(x,y) \in B} \nabla\ell(W_t; x, y),$$

for some random mini-batch $B$ of data with $y = f(x)$.

Practical descent. Practitioners often use variants of stochastic gradient descent that are known to perform well empirically. These typically involve momentum in the descent, re-conditioning of the gradient (Kingma & Ba, 2015), and the addition of normalization layers in the architecture (see, e.g., Ioffe & Szegedy, 2015; Ba et al., 2016).

This paper focuses on gradient flow and gradient descent, postponing the study of other dynamics to future work. We discuss the consistency of these methods, i.e., whether they reach the best possible performance, their asymptotic convergence behaviors, as well as finite-time behaviors.
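For concreteness, the NumPy sketch below contrasts the full-batch gradient descent update (5) with the mini-batch stochastic update (7) on the cross-entropy loss, using the per-sample softmax gradient (derived as Equation (9) in the next section). The sizes, embeddings, batch size, and mini-batch averaging are illustrative choices, not the paper's setup.

```python
import numpy as np

rng = np.random.default_rng(0)
N, M, d, eta, batch = 8, 4, 16, 1.0, 4          # illustrative sizes (assumptions)
E = rng.standard_normal((N, d)) / np.sqrt(d)    # input embeddings e_x
U = rng.standard_normal((M, d)) / np.sqrt(d)    # output embeddings u_y
f = rng.integers(0, M, size=N)                  # labels y = f(x)
p = np.ones(N) / N                              # data distribution p(x)

def grad_ell(W, x, y):
    """Per-sample gradient of the cross-entropy: sum_z p_W(z|x) (u_z - u_y) ⊗ e_x."""
    s = U @ (W @ E[x])
    q = np.exp(s - s.max()); q /= q.sum()       # p_W(. | x)
    return np.outer(q @ U - U[y], E[x])

def gd_step(W):
    """Full-batch gradient descent (5)."""
    g = sum(p[x] * grad_ell(W, x, f[x]) for x in range(N))
    return W - eta * g

def sgd_step(W):
    """Mini-batch stochastic gradient descent (7),
    averaged over the mini-batch (one common normalization choice)."""
    xs = rng.choice(N, size=batch, p=p)
    g = sum(grad_ell(W, x, f[x]) for x in xs) / batch
    return W - eta * g

W = np.zeros((d, d))
for t in range(100):
    W = gd_step(W)      # or sgd_step(W) for the stochastic dynamics
```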
## 3. Memories as Interacting Particles

This section reduces the training dynamics to a system of interacting particles, where the particles correspond to input-output associations.

To simplify the analysis of our model, a few simple observations can be made. First, one can identify matrices with two-dimensional tensors, highlighting the linearity of our model, $\langle u_j, W e_i \rangle = \langle W, u_j \otimes e_i \rangle$, in the tensor space $\mathbb{R}^d \otimes \mathbb{R}^d \simeq \mathbb{R}^{d \times d}$. Secondly, recall that the loss (3) corresponds to the negative log-likelihood $\ell(W; x, y) = -\log p_W(y\,|\,x)$ of the probability $p_W$ whose conditionals are parameterized as a softmax over the scores $\langle u_y, W e_x \rangle$,

$$p_W(y\,|\,x) \propto \exp\big(\langle W, u_y \otimes e_x \rangle\big). \tag{8}$$

The chain rule leads to the following formula,

$$\nabla \ell(W; x, y) = \sum_{z \in [M]} p_W(z\,|\,x)\,(u_z - u_y) \otimes e_x. \tag{9}$$

The gradient formula shows that the dynamics take place on the span of the $(u_j - u_k) \otimes e_i$ with $i \in [N]$ and $j, k \in [M]$, up to an affine shift due to initialization. The resulting training dynamics can be studied by tracking projections onto the input and output embeddings, or onto another family generating the tensor space $\mathbb{R}^d \otimes \mathbb{R}^d$, which leads to a system of particles with non-linear interactions.

Theorem 1 (Particle system). Define the particles $w_{ij}$,

$$w_{ij} = \langle W, u_j \otimes e_i \rangle = u_j^\top W e_i, \tag{10}$$

as well as the constant correlation parameters

$$\alpha_{ij} = \langle e_i, e_j \rangle, \qquad \beta_{ijk} = \langle u_i, u_j - u_k \rangle. \tag{11}$$

The projected gradient can be rewritten as

$$\big\langle \nabla \ell(W; x, y),\, u_j \otimes e_i \big\rangle = \alpha_{ix} \sum_{z \in [M]} \beta_{jzy}\, \frac{\exp(w_{xz})}{\sum_{k \in [M]} \exp(w_{xk})}.$$

Hence, all variations of gradient dynamics, (4), (5), (6) and (7), can be expressed as a (stochastic) system of interacting particles. For example, the gradient descent dynamics (5) is

$$\Delta_t w_{ij} = \eta_t \sum_x p(x)\,\alpha_{ix} \sum_{z \in [M]} \beta_{jf(x)z}\, \frac{\exp(w_{xz})}{\sum_{k \in [M]} \exp(w_{xk})}. \tag{12}$$

Similarly, the dynamics for stochastic gradient descent consists in replacing $\sum_x p(x)$ by the summation over $x$ in a random mini-batch in (12).

Proof. The proof follows directly from (8) and (9).

There are two reasons for interactions in this particle system: either the input embeddings are not orthogonal and the $\alpha$'s mix the particles, or there are more than two classes and the $\beta$'s mix the particles. Moreover, when the output embeddings are not orthogonal, particles are not independent, since an increase of $w_{ij}$ changes $w_{ik}$ as soon as $u_j^\top u_k \neq 0$. Note that multiple factors could lead to correlated embeddings, such as under-parameterization (viz., embeddings are necessarily correlated in low dimension), or semantic similarity in the case of trained embeddings (e.g., Mikolov et al., 2013).

Interacting particle systems commonly arise in other machine learning settings, e.g., to describe parameters in the mean-field regime of two-layer networks (Chizat & Bach, 2018; Mei et al., 2018; Rotskoff & Vanden-Eijnden, 2018), samples in certain approaches to generative modeling (Liu & Wang, 2016; Arbel et al., 2019), or both (Domingo-Enrich et al., 2021). However, these systems typically involve particles as discretizations of an underlying measure evolution, while we make no such connection here. Our dynamics may also be seen as training the middle layer of a three-layer linear network, and infinite-width dynamics for related models have been studied in (Jacot et al., 2021; Chizat et al., 2022). Yet, our focus is on the finite-width ($d < \infty$) case, and we note that this suffices for optimal storage when $d$ is sufficiently large compared to $N$ and $M$.

The particle $w_{xi}$ corresponds to the score assigned by $W$ to the class $i$ for the token $x$. Another set of sufficient statistics for the problem are the margins, which are defined by

$$m_i(x) = w_{xf(x)} - w_{xi} = (u_{f(x)} - u_i)^\top W e_x. \tag{13}$$

It corresponds to the difference between the scores assigned by the model (1) to the classes $f(x)$ and $i$, for the input $x$. When all the margins $(m_i(x))_i$ are positive, the token $x$ is classified correctly.
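As a sanity check, the sketch below compares the projected gradient $\langle \nabla\ell(W; x, y), u_j \otimes e_i\rangle$ computed directly from (9) with the particle expression of Theorem 1. The embedding sizes and the chosen indices are arbitrary illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(1)
N, M, d = 6, 4, 10                     # illustrative sizes (assumption)
E = rng.standard_normal((N, d))        # input embeddings e_i
U = rng.standard_normal((M, d))        # output embeddings u_j
W = rng.standard_normal((d, d))
x, y, i, j = 2, 1, 3, 0                # arbitrary token, label, and projection indices

# Direct computation from (9): grad = sum_z p_W(z|x) (u_z - u_y) ⊗ e_x
s = U @ (W @ E[x])
pW = np.exp(s - s.max()); pW /= pW.sum()
grad = np.outer(pW @ U - U[y], E[x])
lhs = np.sum(grad * np.outer(U[j], E[i]))     # <grad, u_j ⊗ e_i>

# Particle expression of Theorem 1
alpha_ix = E[i] @ E[x]                        # alpha_{ix} = <e_i, e_x>
w_x = U @ (W @ E[x])                          # particles w_{xz}
softmax = np.exp(w_x - w_x.max()); softmax /= softmax.sum()
beta = U[j] @ (U - U[y]).T                    # beta_{jzy} = <u_j, u_z - u_y>
rhs = alpha_ix * np.sum(beta * softmax)

print(np.isclose(lhs, rhs))                   # True up to numerical precision
```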
## 4. Overparameterized Regimes

This section focuses on the case where $N \leq d$ and the $(e_x)$ form a linearly independent family. In this setting, the optimization of the convex loss (2) will ensure perfect accuracy for our model, i.e. $f_W = f$.

### 4.1. Orthogonal Embeddings

We first solve the case where the embedding families $(e_x)$ and $(u_y)$ are both orthogonal. The orthogonality of the inputs implies $\alpha_{ix} = \mathbf{1}_{i=x}\,\|e_i\|^2$, in which case (12) shows that the gradient dynamics for $W$ decouple over the subspaces $\mathbb{R}^d \otimes e_i$. In other terms, our model is implicitly fitting in parallel $N$ parameters, the $(W e_i)_i$, of $N$ independent exponential families, the $p_W(\cdot\,|\,x)$. As a consequence, we can forget the context variable and fix an $x \in [N]$ for the remainder of the section. For simplicity, we assume $f(x) = 1$.

Binary classification. Let us consider the binary case first. When $M = 2$, the dynamics of $W e_x$ evolves on the line $\mathbb{R}\,(u_1 - u_2)$, and is fully characterized by the margin $m_t = (u_1 - u_2)^\top W_t e_x$. An algebraic manipulation of (12) shows that this scalar quantity evolves according to the dynamics

$$(1 + \exp(m_t))\,\Delta_t m = c_x \eta_t, \qquad c_x = p(x)\,\|e_x\|^2\,\|u_1 - u_2\|^2,$$

and $\Delta_t m = m_{t+1} - m_t$. This discrete-time evolution can be solved recursively. A nice formula can be derived for the continuous-time version, i.e. for the flow (4) instead of the descent (5), where $(1 + \exp(m))\,\mathrm{d}m = c_x\,\mathrm{d}t$. In particular, when $W$ is initialized at zero, $m_t + \exp(m_t) = c_x t + 1$. This equation is inverted with the product logarithm, giving an exact expression for the margin evolution, proven in Appendix B.1.

Theorem 2 (Binary orthogonal). Let $M = 2$, and the input embeddings be orthogonal. The dynamics (4), (5), (6) and (7) lead to

$$W_t = \sum_x m_t(x)\, \frac{u_1 - u_2}{\|u_1 - u_2\|^2} \otimes \frac{e_x}{\|e_x\|^2} + \Pi_\perp W_0, \tag{14}$$

where $\Pi_\perp$ is the projection on the orthogonal of the span of the gradient updates. For gradient flow (4), there exists a $t_0 \in \mathbb{R}$ that depends on the initial condition, e.g. $t_0 = -1/c_x$ when $W_0 = 0$, such that when $t \geq t_0$, the exact evolution is given by

$$m_t(x) = \log\big(c_x (t - t_0)\big) - h\big(c_x (t - t_0)\big), \tag{15}$$

where $c_x = p(x)\,\|e_x\|^2\,\|u_1 - u_2\|^2$ and $h$ is a function such that $0 \leq h(x) \leq 2\log(x)/x$ for all $x \geq 1$. Similarly, for gradient descent (5) with a learning rate $\eta$, $m_t(x) \geq \log(\eta\,c_x (t - t_0))$. In particular, when $W_0 = 0$, it leads to the following bound on the loss,

$$\mathcal{L}(W_t) \leq \sum_x \frac{p(x)}{\eta\,c_x\,t}.$$

The setting of Theorem 2 is a special case of logistic regression with linearly separable data where all margins grow logarithmically without inhibiting each other. The loss monotonically decays in $O(1/\eta t)$, which corresponds to the rate of Wu et al. (2023), with the addition of the explicit dependence on the learning rate.
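The closed-form inversion in Theorem 2 can be checked numerically. The sketch below integrates the flow $(1 + e^m)\,\mathrm{d}m = c\,\mathrm{d}t$ with a small Euler step and compares it with the product-logarithm formula $m_t = ct + 1 - W_0(e^{ct+1})$ from Appendix B.1 for zero initialization (here $W_0$ denotes the principal branch of the Lambert function, not the initial weights). The constant $c$ and horizon are arbitrary illustrative values.

```python
import numpy as np
from scipy.special import lambertw

c = 2.0                 # c_x = p(x) ||e_x||^2 ||u_1 - u_2||^2 (illustrative value)
T, dt = 50.0, 1e-3      # horizon and Euler step for the flow

# Euler integration of the gradient flow (1 + exp(m)) dm = c dt, with m(0) = 0
m, ts, ms = 0.0, [], []
for k in range(int(T / dt)):
    m += dt * c / (1.0 + np.exp(m))
    ts.append((k + 1) * dt)
    ms.append(m)
ts, ms = np.array(ts), np.array(ms)

# Closed form from Appendix B.1: m_t = c t + 1 - W_0(exp(c t + 1))
closed = c * ts + 1 - lambertw(np.exp(c * ts + 1)).real

print(np.max(np.abs(ms - closed)))        # small discretization error
print(ms[-1], np.log(c * T))              # the margin grows like log(c t)
```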
Multi-class. We now attack the multi-class case. Let $w_i = w_{xi} = u_i^\top W e_x$, and denote the partition function $A(w) = \sum_i \exp(w_i)$ with $w = (w_i)$. Let us focus for now on gradient flow for simplicity. When the $(u_i)$ are orthogonal, the dynamics (12) becomes

$$\frac{A(w)}{A(w) - \exp(w_1)}\,\mathrm{d}w_1 = p(x)\,\|e_x\|^2\,\|u_1\|^2\,\mathrm{d}t, \qquad A(w)\exp(-w_i)\,\mathrm{d}w_i = -p(x)\,\|e_x\|^2\,\|u_i\|^2\,\mathrm{d}t, \quad \text{for } i \neq 1.$$

In the multi-class case, the evolutions of the margins $m_i = w_1 - w_i$ do not directly decouple from each other as in the binary case. One can combine these differential equations to find many invariants of this system of interacting particles. In particular, $\exp(-w_i) - \exp(-w_j)$ stays constant over time.² Some algebraic manipulation implies the following evolution of the tightest margin (corresponding to $\arg\max_{i \neq 1} p_W(i\,|\,x)$),

$$(c_i + \exp(m_i))\,\mathrm{d}m_i = p(x)\,\|e_x\|^2\,(c_i + 1)\,\mathrm{d}t,$$

for some bounded function $c_i$ that depends on initial conditions. This leads to the same logarithmic convergence as in the binary case. For simplicity, we only report the asymptotic behavior in Theorem 3, which we prove in Appendix B.2.

Theorem 3 (Multi-class orthogonal). Assume that the input and output embeddings are orthonormal. For any initialization, the gradient flow dynamics (4) converges as

$$\lim_{t \to \infty} \frac{W_t}{\log(t)} \propto \sum_x \Pi(u_{f(x)}) \otimes e_x,$$

where $\Pi$ is the projection on the span of the $(u_i - u_j)$.

²This generates all the invariants of the dynamics but one; the remaining one is a consequence of $\mathrm{d}W \in \mathrm{Span}\{u_i - u_j\}_{i,j} \otimes \mathbb{R}^d$.

Figure 1: Level lines of $\mathcal{L}(W)$ for $N = d = 2$ as a function of $\gamma_i(W) := (u_2 - u_1)^\top W f_i$, where $(f_i)$ is a basis of $\mathbb{R}^2$. Token embeddings have correlation $\alpha$ (16). We equally plot the value of $\mathcal{L}_{01}(W)$, dark blue meaning perfect accuracy, and white meaning null accuracy. (Panels: $\alpha = -0.5$, $p_1 = 0.75$ and $\alpha = 0.95$, $p_1 = 0.75$.)

For gradient descent, one can express the updates for the margins similarly,

$$A(w)\,\Delta_t m_i = \eta_t\,\big(A(w) - e^{w_1} + e^{w_i}\big)\,p(x)\,\|e_x\|^2 > 0.$$

Hence, the margins only increase during training, and the larger the learning rate, the faster the evolution. Indeed, one gradient step is enough to learn all the associations $x \to f(x)$, i.e. $f_{W_1}(x) = f(x)$ for any initialization. Continuing the training will continue to increase the margins, ultimately ensuring the convergence of $W_t$ to the max-margin solution of the classification problem, as characterized by Theorem 3, making the final classifier robust to embedding displacements (Cortes & Vapnik, 1995).

To conclude, when the embeddings are orthogonal, the memories do not interfere much, and one can learn all the associations with one giant gradient step. This case still presents several behaviors of interest. First of all, Equation (15) shows that the association $x \to y$ is learned faster when $x$ is frequent, i.e. $p(x)$ is large. Indeed, early in training, one can envision $W \approx \sum_x p(x)\, u_{f(x)} \otimes e_x$. However, later in training, the training dynamics will start saturating in the direction $u_{f(x)} \otimes e_x$ for the frequent tokens, allowing the less frequent ones to catch up. The catch-up is facilitated by large learning rates. Ultimately, as shown by Theorem 3, the final $W$ does not depend on the token frequencies (see Byrd & Lipton, 2019 for related observations). In other terms, if the model has enough capacity to learn all the data (in our case, orthogonality implies $d \geq N$), then at the end of training, it allocates equal capacity to every token even though some tokens are much rarer. Nonetheless, curating data to make them less redundant can make learning more efficient.

### 4.2. Particles Interfering

Let us now consider the case where $N \leq d$, but where memories interfere between them. We first notice that when the input embeddings are orthogonal, correlated output embeddings introduce limited competition, and this case can largely be understood as a simplified version of the interaction between input embeddings. Let us analyze the simple but instructive case $N = 2$ of two input tokens with $f(x) = x$, and

$$\alpha_{ij} = \langle e_i, e_j \rangle = \mathbf{1}_{i=j} + \alpha\,\mathbf{1}_{i \neq j}, \qquad \alpha \in [-1, 1]. \tag{16}$$

In other terms, the input embeddings are normalized and are $\alpha$-correlated. Two margins are at play:

$$m_i = w_{ii} - w_{ij} = (u_i - u_j)^\top W e_i, \qquad \{i, j\} = \{1, 2\}.$$

Figure 2: Loss spikes. Trajectories of $W_t$ in the setting of Figure 1 ($\alpha = 0.95$, $p_1 = 0.75$) for two learning rates $\eta$, $\eta = 10$ in green, $\eta = 1$ in red, and their traces in terms of losses as a function of the number of epochs, here $t \in [35]$.
The interacting system (12) becomes

$$\frac{\mathrm{d}m_i}{c\,\mathrm{d}t} = \frac{p_i}{1 + \exp(m_i)} - \frac{\alpha\,p_j}{1 + \exp(m_j)}, \tag{17}$$

where $c = \|u_1 - u_2\|^2$ and we denote $p_i = p(i)$ for readability. In the gradient dynamics, $x = 1$ pushes $W$ in the direction $(u_1 - u_2) \otimes e_1$, which, when $\alpha \leq 0$, is positively correlated with the direction $(u_2 - u_1) \otimes e_2$ promoted by $x = 2$. As can be seen in Equation (17), when $\alpha \leq 0$, both margins increase during training, there is no competition between the memories, and a single gradient step is enough to reach perfect accuracy.

To solve Equation (17), let us introduce the orthogonal family $f_1 = e_1 + e_2$, $f_2 = e_1 - e_2$, and project the dynamics on those directions with

$$\gamma_i = \tfrac{1}{2}\,(u_1 - u_2)^\top W f_i. \tag{18}$$

The evolution of the $\gamma_i$ is governed by

$$\frac{\mathrm{d}\gamma_1}{c\,\mathrm{d}t} = \frac{(1+\alpha)\,p_1}{1 + \exp(\gamma_2 + \gamma_1)} - \frac{(1+\alpha)\,p_2}{1 + \exp(\gamma_2 - \gamma_1)}, \qquad \frac{\mathrm{d}\gamma_2}{c\,\mathrm{d}t} = \frac{(1-\alpha)\,p_1}{1 + \exp(\gamma_2 + \gamma_1)} + \frac{(1-\alpha)\,p_2}{1 + \exp(\gamma_2 - \gamma_1)}. \tag{19}$$

From the second differential equation, we see that $\gamma_2$ always increases during the dynamics. The growth of $\gamma_2$ will slow down the growth of $\gamma_1$. These together imply that $W$ grows logarithmically in one direction ($f_2$, which turns out to be the max-margin direction) and stays bounded in the orthogonal direction, which we prove in Appendix C.1 and is the object of the following theorem.

Theorem 4 (Two particles interacting). Let $N = 2$ with $f(x) = x$. Assume without restriction that $p_1 \geq p_2$. When Equation (16) holds, if $W$ is initialized at zero, i.e. $W_0 = 0$, for gradient flow,

$$\gamma_2(t) = \log(c_t\,t + 1) + O\Big(\frac{\log(c_2 t + 1)}{c_2 t + 1}\Big), \qquad \gamma_1(t) = \frac{1}{2}\log(p_1/p_2) + O(1/t),$$

where $2 p_2 \leq c_t / c_0 \leq 8 p_1^3 / p_2^2$ with $c_0 = (1 - \alpha)\,c$. Similarly, for gradient descent with any step-size $\eta \geq 0$,

$$\gamma_2(t) = \log(\eta\,p_2\,c_0\,t + 1) + O(\log(t)), \qquad \gamma_1(t) \leq \frac{1}{2}\log(p_1/p_2) + \eta\,(1 - \alpha)\,p_1 + \frac{p_1}{2 p_2} + O(1/t).$$

These results are consistent with Wu et al. (2023; 2024), although our focus is to obtain a fine-grained dependence on the quantities relevant to our setting ($\alpha$, $p_1$, $p_2$, $\eta$). For any learning rate $\eta$, when $t$ grows large, both margins eventually become positive (since they are equal to $\gamma_2 \pm \gamma_1$ with $\gamma_1$ bounded), leading to perfect accuracy of our model.

In the dynamics analyzed so far, we observe a stationary regime where $W_t \sim \log(t)\,W_\infty$. However, transitory regimes can hide under the big-O in Theorem 4; we characterize the big-O terms precisely in the appendix. When considering discrete-time dynamics such as gradient descent (5), or stochastic dynamics, i.e., (6) or (7), those transitory regimes can showcase weight oscillations and loss spikes. For example, when $N = 2$ and there is strong association imbalance and correlation, viz. $\alpha\,p(1) \geq p(2)$, the dynamics at the beginning of training can be approximated by

$$\Delta_t m_1 = \frac{\eta_t\,p(1)}{1 + \exp(m_1)}, \qquad \Delta_t m_2 = -\frac{\eta_t\,\alpha\,p(1)}{1 + \exp(m_1)}.$$

Hence, in terms of the associations stored in $W$, when the learning rate is large, the token $x = 1$ will erase the token $x = 2$. Since $p_W(f(x)\,|\,x = 2)$ approaching zero implies that $\ell(W; x, f(x))$ goes to infinity, this can lead to arbitrarily big loss spikes, as captured by Proposition 5, proved in Appendix C.2. However, later in training, $p_W(2\,|\,2)$ catches up and $W$ ultimately aligns with the max-margin direction, while $m_1 - m_2$ remains bounded.

Proposition 5 (Loss spikes). Let $N = 2$ with $f(x) = x$. Assume that Equation (16) holds, and $\alpha\,p_1 \geq p_2 > 0$. From a null initialization $W_0 = 0$, one gradient update (5) with learning rate $\eta$ leads to

$$\mathcal{L}(W_1) \geq \eta\,(\alpha\,p_1 - p_2)\,p_2, \tag{20}$$

which can be arbitrarily large.
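The loss-spike mechanism of Proposition 5 is easy to reproduce numerically. The sketch below runs gradient descent in the two-token setting of Equation (16) with correlated inputs ($\alpha = 0.95$), imbalanced frequencies ($p_1 = 0.75$), and two learning rates; the specific construction of the embeddings in $\mathbb{R}^2$ is one convenient choice, not the paper's exact code.

```python
import numpy as np

alpha, p = 0.95, np.array([0.75, 0.25])        # correlation and token frequencies
theta = np.arccos(alpha)                       # unit embeddings with <e_1, e_2> = alpha
E = np.array([[1.0, 0.0], [np.cos(theta), np.sin(theta)]])   # e_1, e_2 (rows)
U = np.eye(2)                                  # orthonormal output embeddings u_1, u_2
f = np.array([0, 1])                           # f(x) = x

def loss_and_grad(W):
    L, G = 0.0, np.zeros_like(W)
    for x in range(2):
        s = U @ (W @ E[x])
        q = np.exp(s - s.max()); q /= q.sum()
        L += p[x] * (np.log(np.sum(np.exp(s))) - s[f[x]])
        G += p[x] * np.outer(q @ U - U[f[x]], E[x])          # gradient (9)
    return L, G

for eta in (1.0, 10.0):
    W = np.zeros((2, 2))
    losses = []
    for t in range(35):
        L, G = loss_and_grad(W)
        losses.append(L)
        W -= eta * G
    print(f"eta={eta}: first losses {np.round(losses[:5], 2)}, last loss {losses[-1]:.4f}")

# With eta = 10, the first update overshoots on the rare token and the loss spikes
# above its initial value log(2) before eventually decreasing; with eta = 1 the
# decrease is slower but without a spike.
```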
Figure 3: Level lines of the (logarithm of the) number of steps needed to reach perfect accuracy in the setting of Theorem 4, as a function of the learning rate $\eta$, the interaction parameter $\alpha$, and the class imbalance $\log(p_1/p_2)$. Red means more steps to reach perfect accuracy.

To conclude, for overparameterized models, the dynamics is initially governed by memory interactions, before settling in a stationary regime similar to the orthogonal case described in Theorem 3. The oscillatory regime is due to the competition between two groups of tokens, where increasing the margins of the high-frequency tokens causes a decrease in the margins of the others, similar to the opposing signals in Rosenfeld & Risteski (2023). The settling down of the dynamics can be understood intuitively. Since the max-margin grows, all the partition functions $A(W e_x)$ of the $p_W(\cdot\,|\,x)$ will grow, which will slow down the dynamics. Hence the oscillations will fade, and the dynamics will enter the stationary logarithmic regime. In the stationary regime, bigger learning rates act as a speed-up of time, ensuring faster convergence. From a learning efficiency point of view, there is a trade-off between large learning rates implying longer oscillatory transitory regimes, and small learning rates implying a slow speed of the dynamics. We illustrate this trade-off in Figure 3. We observe that class imbalance and interference make the problem harder, and that large learning rates are beneficial, although very large learning rates can be detrimental (top left of the left plot).

### 4.3. Graphical Understanding

Now that we have a good understanding of the mechanisms at play, we can verify these phenomena more generally through simulations. Let us first leverage the previous derivations to explain how to read measures of performance that can be obtained from experiments. When $d = M = 2$ and any value of $N \geq 2$, the problem reduces to a two-dimensional one with

$$\gamma_i = (u_2 - u_1)^\top W f_i, \qquad (f_i)_k = \mathbf{1}_{i=k}. \tag{21}$$

In the resulting two-dimensional space, we can plot the level lines of the loss function, the level lines of its Hessian eigenvalues, as well as the trajectories followed by $W_t$ for different optimizers. Figure 1 shows the level lines of the loss for $p(1) = 3/4$ and $\alpha \in \{-1/2, 0.95\}$. The deterministic gradient trajectories can be deduced from this picture: they are always orthogonal to the level lines, and their speed is proportional to the number of lines crossed locally. The fact that there are not many level lines in the region $\{W \,|\, \mathcal{L}_{01}(W) = 0\}$ is due to the logarithmic convergence illustrated by Theorem 3.

Figure 4: Forgetting. Similar plots as in Figures 1 and 2, yet in the limited-capacity case $d < N$ (panels: $N = 3$, $E = 0.27$ and $N = 30$, $E = 0.22$). In those situations, competition between the memories can lead to a sub-optimal minimizer of $\mathcal{L}$, which we illustrate with SGD on the bottom plots. The sub-optimality is reflected in the excess of risk $E = \mathcal{L}_{01}(\arg\min_W \mathcal{L}(W)) - \min_W \mathcal{L}_{01}(W)$.

Figure 5: Sharpness profile. Gradient descent trajectories in the setting of Figures 2 and 4 with learning rates $\eta = 10$ (green) and $\eta = 1$ (red). We plot the level lines of the sharpness, i.e. the operator norm of $\nabla^2 \mathcal{L}(W)$, as well as the trace of the trajectories in terms of sharpness. The left plots are in the overparameterized regime, the right ones in the underparameterized one.
The right of Figure 1 shows that, although gradients are always positively correlated with the max-margin direction (formally $\mathrm{d}\gamma_2 \geq 0$ in (19)), they can point in directions that do not lead to perfect accuracy. Indeed, Figure 2 illustrates how large learning rates are likely to result in spikes of both the loss and the accuracy. This latter figure shows the trajectories of $W_t$ for two different learning rates, and the traces of these trajectories in the training loss and accuracy plots which are usually monitored by practitioners training neural networks.

## 5. Numerical Analysis

This section complements the previous derivations with numerical analysis. It discusses underparameterized regimes, large versions of the model (1), as well as more complicated ones.

### 5.1. Limited capacity

Let us start the numerical analysis with the case where $N > d$. In those cases, one cannot necessarily store all associations in memory, and the model has to favor some of them. It was shown in Cabannes et al. (2024) that the ideal $W$ can usually store about $d$ memories, similarly to Hopfield network scalings. However, this ideal $W$ is not always the one minimizing the cross-entropy loss. We plot our problem in the case $M = d = 2$ thanks to the statistics $\gamma_i$ of (21). Figure 4 reveals a striking fact: the cross-entropy loss is not calibrated for our model, i.e., minimizing $\mathcal{L}(W)$ does not always minimize $\mathcal{L}_{01}(W)$. Indeed, even in the case $N = 3$, one can find examples where competition between the memories leads the minimizer of $\mathcal{L}$ to forget the most frequent association. When $N$ becomes large in front of $d$, these cases become the norm. On these landscapes, one can come up with examples of catastrophic forgetting, where the dynamics is first dominated by frequent tokens that are well memorized, until rare classes come into play, perturbing the minimizer of $\mathcal{L}$, ultimately leading to convergence to a sub-optimal place. We illustrate it on the right of Figure 4.

To further illustrate the differences between the dynamics in over- and under-parameterized regimes, Figure 5 illustrates the sharpness, as defined by the operator norm of the Hessian of $\mathcal{L}(W)$, along two descent trajectories. We compute the Hessian in closed form to show its level lines, illustrating that the sharpness of the logistic loss is mainly high for small values of the norm of $W$. We observe three types of behaviors. In the separable case, e.g. when $d \geq N$, the transitory regime goes through relatively sharp regions, before the stationary regime where the sharpness decreases until reaching zero at infinity. In the non-separable case, which is typical when $d < N$, either the learning rate is small enough and we converge to the minimum of $\mathcal{L}$, presenting a sharpness $H$ no greater than $2/\eta$, or the learning rate is greater than $2/H$ and we oscillate around the minimizer of $\mathcal{L}$.

### 5.2. Larger dimension

When the dimension $d$ is larger, although we cannot plot the weight space, we can plot the evolution of certain statistics, such as the margins, along descent trajectories. In Figure 6, we consider a setup with $N = M = 5$, $f(x) = x$, and $p(x) \propto 1/x$, in different dimensions (with random embeddings). We show the evolution of the margins

$$m_t(x) = \langle u_{f(x)}, W_t e_x \rangle - \max_{j \neq x} \langle u_j, W_t e_x \rangle. \tag{22}$$

Perfect accuracy is achieved when $m_t(x) > 0$ for all $x$. We see the faster increase of margins for more frequent tokens, faster convergence with large step-size $\eta$ at the cost of oscillations, and benefits of larger $d$. The latter are likely due to less interference thanks to more orthogonality between random embeddings in higher dimension.

Figure 6: Margins $m_t(x)$ for $N = 5$ tokens, with varying dimensions $d$ and learning rates $\eta$ (panels: $(d, \eta) \in \{(3,3), (3,10), (5,3), (5,20), (10,3), (10,20)\}$; x-axis: iteration $t$). The embeddings were sampled uniformly at random on the sphere. Large learning rates learn faster, although they lead to more oscillation, especially in low dimension. When $d < N$, the model does not have enough capacity to learn all the associations, and it favors the most frequent ones.
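The following sketch reproduces the flavor of this experiment: gradient descent on the cross-entropy loss with $N = M = 5$, $f(x) = x$, $p(x) \propto 1/x$, embeddings sampled uniformly on the sphere, and margins (22) tracked over iterations. The dimension and step-size below are illustrative picks, and the plotting of Figure 6 is not reproduced.

```python
import numpy as np

rng = np.random.default_rng(0)
N = M = 5
d, eta, steps = 10, 3.0, 20            # illustrative dimension and step-size

def sphere(n, d):
    v = rng.standard_normal((n, d))
    return v / np.linalg.norm(v, axis=1, keepdims=True)

E, U = sphere(N, d), sphere(M, d)      # embeddings uniform on the sphere
p = 1.0 / np.arange(1, N + 1); p /= p.sum()   # p(x) ∝ 1/x
f = np.arange(N)                        # f(x) = x

def grad(W):
    G = np.zeros_like(W)
    for x in range(N):
        s = U @ (W @ E[x])
        q = np.exp(s - s.max()); q /= q.sum()
        G += p[x] * np.outer(q @ U - U[f[x]], E[x])      # gradient (9)
    return G

def margins(W):
    m = np.empty(N)
    for x in range(N):
        s = U @ (W @ E[x])
        m[x] = s[f[x]] - np.max(np.delete(s, f[x]))       # margin (22)
    return m

W = np.zeros((d, d))
for t in range(steps):
    W -= eta * grad(W)
    print(t, np.round(margins(W), 2))   # frequent tokens (small x) tend to gain margin first
```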
### 5.3. Simplified Transformer model

Finally, we empirically study training dynamics on a more complex model involving multiple associative memory mechanisms like the ones above. In particular, we consider a simplified two-layer Transformer architecture trained on an in-context learning task (described in Appendix D) that requires copying a bigram from the context depending on the current token. A two-layer attention-only transformer can solve this by implementing an induction head mechanism (Elhage et al., 2021; Olsson et al., 2022), and Bietti et al. (2023) show that this can be achieved by training only three matrices, $W_O^2$, $W_K^2$, and $W_K^1$. These were found to behave as associative memories with appropriate embeddings, and we may thus empirically assess their margins. We consider full-batch gradient descent on a dataset of 16 384 sequences of length 256 generated from the model described above with $N = 64$ tokens.

Figure 7: Full-batch training of selected transformer layers on the bigram task. (top) Loss and margins when training $W_O^2$ alone. (bottom) Loss and margins when training the three layers ($W_O^2$, $W_K^2$, and $W_K^1$) sufficient for the task, for two widths $d$. Training losses are shown for different step-sizes $\eta$, and margins are shown for 5 different tokens. (Panels: "Training $W_O^2$ only, $d = 64$" with $\eta \in \{20, 50, 100, 200\}$; "$W_O^2$, $d = 64$, $\eta = 200$"; "Train all, $d = 64$"; "$W_O^2$, $d = 128$, $\eta = 20$"; "Train all, $d = 128$" with $\eta \in \{2, 5, 10, 20\}$; "$W_K^2$, $d = 128$, $\eta = 20$"; x-axis: iteration $t$.)

The top of Figure 7 shows the training loss and margins when only training $W_O^2$, which can learn an appropriate associative memory by itself. The objective for this problem is convex and similar to the ones considered in this paper, up to some noise in the input embeddings due to attention. We see that margins tend to increase during training, and that large learning rates lead to faster optimization of the loss, at the cost of some spikes in the loss, and oscillations in the margins, which are due to correlations between embeddings. The bottom of Figure 7 shows loss curves and margin evolutions when training all three matrices. Here we see more frequent spikes in the loss for large learning rates, yet their gains are much more significant later in training, with small final losses that suggest the induction head mechanism is learned. The increasing margins confirm that the desired associative behaviors have indeed been recovered. Compared to the top of Figure 7, the margins for $W_O^2$ display more significant oscillations initially, likely due to additional interactions across different parameter matrices. In later iterations, when the attention heads are in place and inputs to $W_O^2$ are less noisy, the margins increase together to large values, leading to a similar learning speed on all memories.
This uniform convergence behavior was facilitated by the relatively low output tokens imbalance considered in our tasks where the copied tokens were sampled uniformly. Finally, we see the effect of larger embedding dimensions d, accelerating the convergence thanks to more orthogonality. 6. Discussion In this paper, we studied the gradient dynamics of associative memory models trained with cross-entropy loss, by viewing memory associations as interacting particles. This leads to new insights on the role of the data distribution and correlated embeddings on convergence speed as well as training instabilities in large learning rate regimes, such as oscillations and loss spikes. We also showed that some of these insights may transfer to some more realistic scenarios such as training small Transformers. Nonetheless, our simple model is only a first step, and there are many additional factors at play in larger models, which may lead to different behaviors and instabilities. This includes factorized parameterizations, normalization layers, adaptive optimizers, noisy data, and interactions between different layers which may change at different timescales. Studying the impact of these on training dynamics could unlock new improvements to the practice and reliability of training large models. Impact Statement This theoretical work aims to advance our understanding of training dynamics. Its short-time impact is limited. In the long run, our stream of research could help improve the training of large language models, from an energy or alignment standpoint. It indirectly relates to the quest for general artificial intelligence, which is not without consequences, although discussing them is beyond the scope of this paragraph. Acknowledgements The authors would like to thank Léon Bottou and Hervé Jegou for fruitful discussions that led to this line of work. Agarwala, A., Pedregosa, F., and Pennington, J. Secondorder regression models exhibit progressive sharpening to the edge of stability. In International Conference on Machine Learning (ICML), 2023. Arbel, M., Korba, A., Salim, A., and Gretton, A. Maximum mean discrepancy gradient flow. Advances in Neural Information Processing Systems (Neur IPS), 2019. Ba, J. L., Kiros, J. R., and Hinton, G. E. Layer normalization. ar Xiv preprint ar Xiv:1607.06450, 2016. Bartlett, P. L., Long, P. M., and Bousquet, O. The dynamics of sharpness-aware minimization: Bouncing across ravines and drifting towards wide minima. Journal of Machine Learning Research (JMLR), 24(316):1 36, 2023. Beugnot, G., Rudi, A., and Mairal, J. On the benefits of large learning rates for kernel methods. In Conference on Learning Theory (COLT), 2022. Bietti, A., Cabannes, V., Bouchacourt, D., Jegou, H., and Bottou, L. Birth of a transformer: A memory viewpoint. In Advances in Neural Information Processing Systems (Neur IPS), 2023. Byrd, J. and Lipton, Z. What is the effect of importance weighting in deep learning? In International conference on machine learning (ICML), 2019. Cabannes, V., Dohmatob, E., and Bietti, A. Scaling laws for associative memories. In International Conference on Learning Representations, 2024. Chen, L. and Bruna, J. Beyond the edge of stability via two-step gradient updates. In International Conference on Machine Learning (ICML), 2023. Chizat, L. and Bach, F. On the global convergence of gradient descent for over-parameterized models using optimal transport. In Neural Information Processing Systems (Neur IPS), 2018. 
Chizat, L., Colombo, M., Fernández-Real, X., and Figalli, A. Infinite-width limit of deep linear neural networks. ar Xiv preprint ar Xiv:2211.16980, 2022. Cohen, J., Kaur, S., Li, Y., Kolter, J. Z., and Talwalkar, A. Gradient descent on neural networks typically occurs at the edge of stability. In International Conference on Learning Representations (ICLR), 2020. Cortes, C. and Vapnik, V. Support-vector networks. Machine Learning, 1995. Domingo-Enrich, C., Bietti, A., Gabrié, M., Bruna, J., and Vanden-Eijnden, E. Dual training of energy-based models with overparametrized shallow neural networks. ar Xiv preprint ar Xiv:2107.05134, 2021. Elhage, N., Nanda, N., Olsson, C., Henighan, T., Joseph, N., Mann, B., Askell, A., Bai, Y., Chen, A., Conerly, T., Das Sarma, N., Drain, D., Ganguli, D., Hatfield-Dodds, Z., Hernandez, D., Jones, A., Kernion, J., Lovitt, L., Ndousse, K., Amodei, D., Brown, T., Clark, J., Kaplan, J., Mc Candlish, S., and Olah, C. A mathematical framework Learning Associative Memories for transformer circuits. Transformer Circuits Thread, 2021. Geva, M., Schuster, R., Berant, J., and Levy, O. Transformer feed-forward layers are key-value memories. In EMNLP, 2021. Hoorfar, A. and Hassani, M. Inequalities on the lambert w function and hyperpower function. J. Inequal. Pure and Appl. Math, 9(2):5 9, 2008. Hopfield, J. Neural networks and physical systems with emergent collective computational abilities. Proceedings of the National Academy of Sciences of the United States of America, 1982. Hopfield, J. and Tank, D. Neural computation of decisions in optimization problems. Biological Cybernetics, 1985. Ioffe, S. and Szegedy, C. Batch normalization: Accelerating deep network training by reducing internal covariate shift. In International Conference on Machine Learning (ICML), 2015. Jacot, A., Ged, F., Sim sek, B., Hongler, C., and Gabriel, F. Saddle-to-saddle dynamics in deep linear networks: Small initialization training, symmetry, and sparsity. ar Xiv preprint ar Xiv:2106.15933, 2021. Ji, Z. and Telgarsky, M. The implicit bias of gradient descent on nonseparable data. In Conference on Learning Theory (COLT), 2019. Ji, Z. and Telgarsky, M. Characterizing the implicit bias via a primal-dual analysis. In Algorithmic Learning Theory, 2021. Kingma, D. and Ba, J. Adam: A method for stochastic optimization, 2015. Liu, Q. and Wang, D. Stein variational gradient descent: A general purpose bayesian inference algorithm. Advances in neural information processing systems (NIPS), 2016. Longuet-Higgins, C., Willshaw, D., and Buneman, P. Theories of associative recall. Quarterly Reviews of Biophysics, 1970. Lyu, K. and Li, J. Gradient descent maximizes the margin of homogeneous neural networks. In International Conference on Learning Representations (ICLR), 2019. Mei, S., Montanari, A., and Nguyen, P.-M. A mean field view of the landscape of two-layer neural networks. Proceedings of the National Academy of Sciences, 115(33): E7665 E7671, 2018. Mikolov, T., Sutskever, I., Chen, K., Corrado, G. S., and Dean, J. Distributed representations of words and phrases and their compositionality. Advances in neural information processing systems (NIPS), 2013. Nakkiran, P. Learning rate annealing can provably help generalization, even for convex problems. ar Xiv preprint ar Xiv:2005.07360, 2020. 
Olsson, C., Elhage, N., Nanda, N., Joseph, N., Das Sarma, N., Henighan, T., Mann, B., Askell, A., Bai, Y., Chen, A., Conerly, T., Drain, D., Ganguli, D., Hatfield-Dodds, Z., Hernandez, D., Johnston, S., Jones, A., Kernion, J., Lovitt, L., Ndousse, K., Amodei, D., Brown, T., Clark, J., Kaplan, J., Mc Candlish, S., and Olah, C. In-context learning and induction heads. Transformer Circuits Thread, 2022. Ramsauer, H., Schäfl, B., Lehner, J., Seidl, P., Widrich, M., Adler, T., Gruber, L., Holzleitner, M., Pavlovi c, M., Sandve, G. K., Greiff, V., Kreil, D., Kopp, M., Klambauer, G., Brandstetter, J., and Hochreiter, S. Hopfield networks is all you need. In International Conference on Learning Representations (ICLR), 2021. Rosenfeld, E. and Risteski, A. Outliers with opposing signals have an outsized effect on neural network optimization. ar Xiv preprint ar Xiv:2311.04163, 2023. Rotskoff, G. M. and Vanden-Eijnden, E. Trainability and accuracy of neural networks: An interacting particle system approach. ar Xiv preprint ar Xiv:1805.00915, 2018. Schlag, I., Irie, K., and Schmidhuber, J. Linear transformers are secretly fast weight programmers. In International Conference on Machine Learning (ICML), 2021. Soudry, D., Hoffer, E., Nacson, M. S., Gunasekar, S., and Srebro, N. The implicit bias of gradient descent on separable data. The Journal of Machine Learning Research (JMLR), 19(1):2822 2878, 2018. Willshaw, D., Buneman, P., and Longuet-Higgins, C. Nonholographic associative memory. Nature, 1969. Wu, J., Braverman, V., and Lee, J. D. Implicit bias of gradient descent for logistic regression at the edge of stability. In Neural Information Processing Systems (Neur IPS), 2023. Wu, J., Bartlett, P., Telgarsky, M., and Yu, B. Large stepsize gradient descent for logistic loss: Non-monotonicity of the loss improves optimization efficiency, 2024. Learning Associative Memories A. Gradient derivations In the following, to be consistent with pytorch convention, we redefine the model as f W (x, y) = ex, Wuy Recall that the loss can be understood intuitively as the negative log-likelihood L(W) = EX,Y [log p W (Y | X)], where p W (y | x) = exp( ex, Wuy ) P z exp( ex, Wuz ), of the probability p W whose conditional distributions are parameterized as a soft-max over the score ex, Wuy . In information theory, this negative log-likelihood is known as the cross-entropy of the model probability p W relative to the real data distribution p. To analyze the training dynamics, we will monitor quantities related to the gradient and the Hessian of the loss function. The gradient of the loss is easy to compute with simple derivation rules. ℓ(W; x, y) = log z exp ex, Wuz exp ex, Wuz P z exp ex, Wuz exu z exu y . This gradient can be understood with the probabilistic perspective on the loss as ℓ(W; x, y) = X z p W (z | x)exu z exu y . (23) HESSIAN COMPUTATION Notice that when the gradient can be written as ℓ(θ) = g(θ) a = ( iℓ(θ))i, with g(θ) R, a Rd, the Hessian follows as 2ℓ(θ) = ( ijℓ(θ)) = ( i jℓ(θ)) = ( ig(θ) aj) = ( g(θ)) a . In our case, we want to use the Euclidean structure on the matrix space, which leads to 2ℓ(W; x, y) = X z p W (z | x)(ex uz) . (24) To compute p W (z | x), notice that we could equally have expressed the loss gradient as ℓ(W; x, y) = log p W (y | x) = p W (y | x) p W (y | x) , from which we deduce that p W (z | x) = p W (z | x) ℓ(W; x, z) = p W (z | x) z p W (z | x)ex uz Plugging this into Equation (24), we will have to deal with quantities such as3 (ex uy)(ex uz) = ex uy ex uz. 
3Recall that a tensor M = a1 a2 . . . ap (Rd) p can be understood as a p-dimensional matrix M. If (ei) denotes the basis of Rd, then ej1 ej2 . . . ejp is a basis of the tensor space, and M assimilates to M such that M[j1, j2, . . . , jp] = Q i [p] ai, eji . Learning Associative Memories The last operation can be understood from the fact that, when fi is the canonical basis of Rd, and a and b are in Rd, we have the matrix identification (ab )i,j = a, fi b, fj In our case, using ij as the matrix indexation for Rd d and fif j as the canonical basis of Rd d, (ex uy)(ex uz) ij,kl = ex uy, fi fj ex uz, fk fl = (f i ex)(f j uy)(f k ex)(f i uz) = fi fj fk fl , ex uy ex uz . Using Equations (24) and (25), we deduce 2ℓ(W; x, y) = X z,z p W (z | x)(δz,z p W (z | x))ex uz ex uz . (26) We implemented this formula vectorially in our code, and checked its correctness based on automatic differentiation libraries. B. Dynamics without interference To study the dynamic, using Equation (9), with the notation of the main text, i.e. f W (x, y) = uy, Wex , we have ℓ(W; x, y)ei = X z p W (z | x)(uz uy)e x ei. In particular, when the (ei) are orthogonal, summing over x leads to L(W)ex = p(x) ex 2 X z p W (z | x)(uz uf (x)). (27) B.1. Binary classification - Proof of Theorem 2 Let us consider the binary case where y {1, 2}. Assume that f (x) = 1, Equation (27) simplifies as L(W)ex = p(x) ex 2 p W (2 | x)(u2 u1). We can project it on the line where it evolves, reducing this equation to a scalar evolution (u1 u2) L(W)ex = p(x) ex 2 p W (2 | x) u2 u1 2 = p(x) ex 2 u2 u1 2 exp(u 2 Wex) exp(u 1 Wex) + exp(u 2 Wex) = p(x) ex 2 u2 u1 2 1 exp((u1 u2) Wex) + 1. Let us consider the evolution equation, for some learning rate scheduling (ηt)t 1 Wt+1 = Wt ηt L(Wt). This leads to (u1 u2) (Wt+1 Wt)ex = ηtp(x) ex 2 u2 u1 2 1 exp((u1 u2) Wtex) + 1. Let us set mt = (u1 u2) Wtex, c = p(x) ex 2 u2 u1 2 . The evolution equation becomes (exp(mt) + 1)(mt+1 mt) = ηtc. (28) Learning Associative Memories B.1.1. GRADIENT FLOW. Since the updates are in the span of (u1 u2) ex, we have x αt(x)(u1 u2) ex + Π W0, where Π is the projection on the orthogonal of the gradient updates, and αt(x) can be inferred from the margin mt(x) = (u1 u2)T Wtex = αt(x) u1 u2 2 ex 2, In the following, we will solve the ODE for each margin using the product logarithm. The gradient flow limit of the previous derivation in Eq. 28 leads to (exp(m) + 1) dm = c dt. Integrating this differential equation gives exp(mt) + mt = ct + exp(m0) + m0. For x, y R, we can solve exp(x) + x = y (y x) exp(y x) = exp(y) x = y W0(ey), where W0 is the product logarithm. This allows us to solve the previous equation in closed-form mt = ct + em0 + m0 W0(ecteexp(m0)+m0). (29) We can simplify this profile using the following asymptotic development of the product logarithm (Hoorfar & Hassani, 2008) when x e, W0(x) = log(x) log log(x) + h(log(x)) log log(x) log(x) , with h(x) [1/2, 2]. We deduce that as soon as ct + exp(m0) + m0 0, mt = log(ct + exp(m0) + m0) h(ct + exp(m0) + m0) log(ct + exp(m0) + m0) ct + exp(m0) + m0 , W0(ecteexp(m0)+m0) = ct + exp(m0) + m0 log(ct + exp(m0) + m0) + h( ) log(ct + exp(m0) + m0) ct + exp(m0) + m0 . The first part of theorem is found with t0 = exp(m0) + m0 and with the substitution of h by h(x) = h(x) log(x)/x. B.1.2. GRADIENT DESCENT In the case of gradient descent, we will work with the discrete update equation mt+1 mt = cηt exp(mt) + 1. 
Since we expect a logarithmic growth of the margins, we exponentiate this equation and rearrange terms exp(mt+1) exp(mt) = exp(mt) exp cηt exp(mt) + 1 1 . Learning Associative Memories Notably, when we initialize the weights at zero, all margin updates are positive in this case with no interferences. This implies that mt 0 at all times. exp(mt+1) exp(mt) cηt exp(mt) exp(mt) + 1 cηt 1 2 where we used ex x + 1 and ex/(ex + 1) 1/2 for x 0. Telescoping this summation, we get which yields the following logarithmic growth for the fixed learning rate schedule, i.e. ηt = η mt log(ηct + 1). When te weights are not initialized at zero, there will be a moment t0 where the margin will become positive and the same picture will hold. We can then control the growth of the loss given by x log(1 + exp( (ui uj)T Wtex))p(x) which can be expressed in terms of the margins x log(1 + exp( mt(x)))p(x). Since log(1 + e x) is decreasing, we can directly install lower bound on mt(x) and get an upper bound on the loss x log(1 + exp( log(ηcx(t t0(x)))))p(x) = X x log(1 + 1 ηcx(t t0(x)))p(x) X p(x) ηcx(t t0(x)). When W0 = 0, t0(x) 0, which implies the second part of the theorem. B.2. Multi-class - Proof of Theorem 3 For the multi-class, consider x [N] and assume that f (x) = 1. Because the dynamics decouple, we can simplify notation with w = Wex Rd, and forget about the context variable x. Let us denote pw(j) exp(w uj), ℓ(w) = log(pw(1)). Consider the gradient flow dynamics dt = ℓ(w) = X j [M] pw(j)(uj u1). Developing the probabilities, we get X j [M] exp(w uj)dw j [M] exp(w uj)(u1 uj). Let us denote wj = w, uj . X j [M] exp(wj)dw j [M] exp(wj)(u1 uj). When the (uj) are orthonormal, we can project the last equation on the (ui), which leads to the following coupled differential equations j=1 exp(wj) dw1 = j=2 exp(wj) dt, and j=1 exp(wj) dwi = exp(wi) dt i = 1. Learning Associative Memories We can rewrite it with the partition function A(w) = P exp(wj), A(w) A(w) exp(w1) dw1 = dt, and A(w) exp( wi) dwi = dt i = 1. Subtracting any two instances of the coupling equations for i, j = 1, we get the following invariant A(w) (exp( wi) dwi exp( wj) dwj) = 0 exp( wi) dwi exp( wj) dwj = 0 exp( wi) exp( wj) = exp( u i w0) exp( u j w0) =: cij The last invariant is found with A(w) dw1 = (A(w) exp(w1)) dt = X i>1 exp(wi) dt = X i>1 A(w) dwi X i [M] dwi = 0. This is the transcription of the fact that the update of w are in the span of the (ui uj) which is the orthogonal (P ui) when the (ui) are orthonormal. The first invariant allows us to characterize the partition function using only w1 and the logit of the most probable incorrect class. Let k = arg minj {2,..,M} exp( wj), hence cjk 0 for j = 1, and A(w) = exp(w1) + X j =1 exp(wj) = exp(w1) + X 1 cjk + exp( wk) = exp(w1) + X exp(wk) cjk exp(wk) + 1 = exp(w1) + (Mk + θk) exp(wk) where Mk = |{k = 1|cjk = 0}| 1, θk [0, |{k = 1|cjk > 0}|] = [0, M Mk], where we have used that when cjk > 0, exp(wk)/(ck exp(wk) + 1) exp(wk). Note that for any j = 1, we can write the differential equations for the margin as A(w) d(w1 wj) = (A(w) exp(w1) + exp(wj)) dt. In particular, for the tightest margin, we get A(w) A(w) exp(w1) + exp(wk) = A(w) exp( wk) A(w) exp( wk) exp(w1 wk) + 1 = exp(w1 wk) + Mk + θk Mk + θk + 1 . Using the bounds on θk, we get exp(w1 wk) M + 1 + Mk Mk + 1 d(w1 wk) dt exp(w1 wk) Mk + 1 + M M + 1 Let us introduce constants to ease notations, (c1 exp(w1 wk) + c2) d(w1 wk) dt (c3 exp(w1 wk) + c4) d(w1 wk). 
We can integrate these inequalities c1 exp(w1 wk) + c2(w1 wk) + b1 t c3 exp(w1 wk) + c4(w1 wk) + b2. This implies w1 wk = log(t)(1 + o(1)). Learning Associative Memories Let us denote h(t) = exp(w1), using the first invariant, we get exp(w1 wk) = exp(w1)(exp( wj) + ckj) = exp(w1 wj) + ckjh(t) = ctt(1 + o(t)) + ckjh(t), for ct [1/c3, 1/c1] a bounded function. We can characterize h(t) with the last invariant X i [M] wi =: C = o(1). The previous equations are solved with M log(t)(1 + o(1)), w1 = M 1 M log(t)(1 + o(1)), which leads to ckjh(t) = ckjt1 1/M = o(t), and, since the (ui) are orthonormal, i [M] w, ui ui = X i [M] wiui = X We can simplify the last equation by realizing that it is proportional to the projection of u1 on the span of the ui uj, which is also the span of the u1 ui. If we denote Π the projection on this span, we have the existence of (bi) such that i>2 bi(u1 ui), and since (Π(u1) u1) (ui uj) = 0 for all i, j > 1, we deduce bi = bj = b. The value of b can be computed explicitly, the triangle formed by 0, u1 and Π(u1) is both isosceles and rectangular, which leads to Π(u1) u1) = Π(u1) , 1 = ui 2 = Π(u1) u1 2 + Π(u1) 2, hence Π(u1) = 1/ 2. We also have Π(u1) = Π(u1) u1 = (M 1)b, from which we deduce that the proportionality constant in Theorem 3 is (M 1)/ INDICATIONS FOR A PROOF IN THE CASE OF GRADIENT DESCENT For gradient descent, we expect the same theorem to hold for two simple reasons, which, for simplicity, we do not formalize. By convexity, there could only be one directional convergence for gradient descent (regardless of the initialization), and it has to be the same as the one for gradient flow. Because the level lines of loss are exponentially spaced, for any fixed learning rates, gradient descent will become a finer and finer approximation of gradient flow as W grows large. Another way to proceed is to retake the previous arguments in the discrete setting. For example, when initializing gradient descent with W = 0, one can check by recurrence that exp(wi) = exp(wj) for all i, j = 1, this allows reducing the dynamics to a scalar evolution, which can be treated as in Theorem 2. C. Dynamics with two particles interfering In the setting of Theorem 4, Theorem 1 plus a few lines of omitted derivations lead to the couplings L(W), (uj ui) ej = u2 u1 2 p(i)α 1 + exp((ui uj) Wei) p(j) 1 + exp((uj ui) Wej) We remark that if α 0, there is no competition between the memory associations, the dynamics always advances in the cone R+ e1 + R+ e2, reinforcing both associations simultaneously. Let us introduce the margin mj = (uj ui) Wej = W, (uj ui) ej , for {i, j} = {1, 2} . Learning Associative Memories The previous equation can be rewritten as L(W), (uj ui) ej = u2 u1 2 p(i)α 1 + exp(mi) p(j) 1 + exp(mj) For the gradient flow, it leads to the evolution dmj = L(W), (uj ui) ej = u2 u1 2 p(j) 1 + exp(mj) p(i)α 1 + exp(mi) Similarly, if we define the orthogonal vectors f1 = e1 + e2, f2 = e1 e2, as well as the statistics, 2(u1 u2) Wfi, we get m1 = γ1 + γ2 and m2 = γ2 γ1, and γ1 = (m1 m2)/2, γ2 = (m1 + m2)/2. Hence, u2 u1 2 dγ2 u2 u1 2 dm1 + dm2 = p(1) 1 + exp(m1) p(2)α 1 + exp(m2) + p(2) 1 + exp(m2) p(1)α 1 + exp(m1) = p(1)(1 α) 1 + exp(m1) + p(2)(1 α) 1 + exp(m2) = (1 α) p(1) 1 + exp(γ1 + γ2) + p(2) 1 + exp(γ2 γ1) u2 u1 2 dγ1 u2 u1 2 dm1 dm2 = p(1) 1 + exp(m1) p(2)α 1 + exp(m2) p(2) 1 + exp(m2) p(1)α 1 + exp(m1) = p(1)(1 + α) 1 + exp(m1) p(2)(1 + α) 1 + exp(m2) = (1 + α) p(1) 1 + exp(γ1 + γ2) p(2) 1 + exp(γ2 γ1) This explains the evolution in the main text. 
We see that γ2 will grow at least logarithmically, while γ1 will be contained eventually because of the growth of γ2. C.1. Proof of Theorem 4 We start by focusing on the gradient flow dynamics. Recall that the evolution of the max-margin and orthogonal directions γ2 and γ1 is given by the following ODEs: dct = (1 + α)p1 1 + exp (γ2 + γ1) (1 + α)p2 1 + exp (γ2 γ1) (32) dct = (1 α)p1 1 + exp (γ2 + γ1) + (1 α)p2 1 + exp (γ2 γ1). (33) LOWER BOUND IN THE MAX-MARGIN DIRECTION γ2 From the evolution equation (33) of the margin direction γ2 we deduce, using the fact that either eγ1 1 or e γ1 1 for all γ1 R, dγ2 dct (1 α) min(p1, p2) 1 + exp (γ2) = (1 α)p2 1 + exp (γ2), Learning Associative Memories since we have assumed without restriction that p1 p2. We have solved the differential equation in Appendix B.1 (with a different constant). Using Grönwall s inequality, integrating this out leads to, when initialized at W0 = 0, γ2 log(c1t + 1) h(c1t + 1). where c1 = (1 α)cp2 and h as defined in Appendix B.1, i.e. h(x) = h(x) log(x)/x with h [1/2, 2]. UPPER BOUND IN THE ORTHOGONAL DIRECTION γ1 Let us now consider γ1. First, note that whenever γ1 0, then we have dγ1 0, thanks to our assumption p1 p2. In particular, with zero initialization W(0) = 0, we then have γ1(t) γ1(0) = 0 throughout. Let us know look for an upper bound. For γ1 to grow, we need dγ1 0. Denoting γ := log( p p1/p2), this only possible when p1 1 + exp (γ2 + γ1) p2 1 + exp (γ2 γ1) 0 p1 p2 + p1 exp (γ2 γ1) p2 exp (γ2 + γ1) (p1 p2) exp( γ2) p2 exp(γ1) p1 exp( γ1) (p1 p2) exp( γ2) p1p2(exp(γ1 γ) exp( γ1 + γ)) sinh(γ1 γ) (p1 p2) exp( γ2) C(γ2) := (p1 p2) exp( γ2) 2 p1p2 . (34) We thus have sinh(γ1 γ) C(γ2) dγ1 0. This implies that gradient flow will be bounded. In particular, the lower bound on γ2 gives us exp( γ2) exp(h(c1t + 1)) c1t + 1 exp(2/e) We conclude that, when γ1 is initialized at zero, we have that dγ1(0) 0, and γ1 will grow until reaching the point where dγ1 0, which leads to a bound on γ1 characterized by sinh(γ1 γ) C(γ2) p1 p2 This yields γ1(t) γ + sinh 1 p1 p2 p1p2 1.05 c1t + 1 + p1 p2 p1p2 1.05 c1t + 1, using that sinh 1(x) x for x 0. UPPER BOUND IN THE MAX-MARGIN DIRECTION γ2 We can upper bound γ2 based on Equation (33), dct 2(1 α) max(p1, p2) 1 + min(exp(γ1), exp( γ1)) exp (γ2) = 2(1 α)p1 1 + exp( γ1) exp (γ2). We have seen that exp(γ1) rp1 p2 exp(sinh 1(p1 p2 1.05 c1t + 1)) = rp1 1.05 c1t + 1 + 1.05 (c1t + 1)2 + 1 1.05(p1 p2)2 Learning Associative Memories We deduce that dγ2 dct 2(1 α)p1 2 exp (γ2) 8(1 α)p3 1/p2 2 1 + exp (γ2) . This allows us to conclude that γ2 does not grow faster than logarithmically. γ2 log(c2t + 1) h(c2t + 1). with c2 = 8c(1 α)p3 1/p2 2. Using the intermediate value theorem, we deduce the form of γ2 given in the theorem. LOWER BOUND IN THE ORTHOGONAL DIRECTION γ1 If γ1 was initialized such that dγ1 0, we would have that γ1 would decrease until reaching the point found in the upper bound for γ1. The difficulty consists in showing that γ1 increases fast enough toward γ. Retaking the derivations made to characterize the sign of dγ1, we can rewrite the evolution equation as dct = (1 + α)p1 p2 + 2 p1p2 exp(γ2) sinh( γ γ1) (1 + exp (γ2 + γ1))(1 + exp (γ2 γ1)) = (1 + α) p1 p2 + 2 p1p2 ctt sinh( γ γ1) (1 + ctt exp (γ1))(1 + ctt exp ( γ1)), where ct/ct [exp(e/2), 1] is found with the intermediate value theorem, and we have used that exp(γ2) = ctt exp( h(ctt + 1)) ctt [exp(e/2), 1]. We can lower bound the growth of γ1 when γ1 γ, which implies dγ1 0. 
LOWER BOUND IN THE ORTHOGONAL DIRECTION $\gamma_1$

If $\gamma_1$ were initialized such that $d\gamma_1/dt \leq 0$, it would decrease until reaching the point found in the upper bound for $\gamma_1$. The difficulty consists in showing that $\gamma_1$ increases fast enough toward $\gamma$. Retaking the derivations made to characterize the sign of $d\gamma_1/dt$, we can rewrite the evolution equation as
$$\frac{d\gamma_1}{c\,dt} = (1 + \alpha)\,\frac{p_1 - p_2 + 2\sqrt{p_1 p_2}\,\exp(\gamma_2)\sinh(\gamma - \gamma_1)}{(1 + \exp(\gamma_2 + \gamma_1))(1 + \exp(\gamma_2 - \gamma_1))} = (1 + \alpha)\,\frac{p_1 - p_2 + 2\sqrt{p_1 p_2}\,\tilde c_t t\,\sinh(\gamma - \gamma_1)}{(1 + \tilde c_t t \exp(\gamma_1))(1 + \tilde c_t t \exp(-\gamma_1))},$$
where we have used that $\exp(\gamma_2) = c_t t \exp(-h(c_t t + 1)) =: \tilde c_t t$, with $c_t$ found by the intermediate value theorem and $\tilde c_t / c_t \in [\exp(-2/e), 1]$. We can lower bound the growth of $\gamma_1$ when $\gamma_1 \leq \gamma$, which implies $d\gamma_1/dt \geq 0$. Using that $\gamma_1$ is bounded, we get the existence of a constant $c_3$ such that
$$\frac{d\gamma_1}{c\,dt} \geq (1 + \alpha)\,\frac{p_1 - p_2 + 2\sqrt{p_1 p_2}\,\tilde c_t t\,\sinh(\gamma - \gamma_1)}{(1 + c_3 t)^2} \geq (1 + \alpha)\,\frac{p_1 - p_2}{(1 + c_3 t)^2}.$$
This leads to a growth of $\gamma_1$ in $O(1/t)$, from which we deduce that $\gamma_1 = \gamma + O(1/t)$, which ends the characterization of the dynamics for gradient flow.

GRADIENT DESCENT

The dynamics of $\gamma_1$ and $\gamma_2$ for gradient descent with step size $\eta$ are given by
$$\gamma_1(t + 1) = \gamma_1(t) + \eta c\, \Delta\gamma_1, \qquad \gamma_2(t + 1) = \gamma_2(t) + \eta c\, \Delta\gamma_2,$$
$$\Delta\gamma_1 = \frac{(1 + \alpha)\,p_1}{1 + \exp(\gamma_2 + \gamma_1)} - \frac{(1 + \alpha)\,p_2}{1 + \exp(\gamma_2 - \gamma_1)}, \qquad (35)$$
$$\Delta\gamma_2 = \frac{(1 - \alpha)\,p_1}{1 + \exp(\gamma_2 + \gamma_1)} + \frac{(1 - \alpha)\,p_2}{1 + \exp(\gamma_2 - \gamma_1)}. \qquad (36)$$
Similarly to gradient flow, we can lower bound the update of $\gamma_2$, with $c_1 = (1 - \alpha)\,c\,p_2$,
$$\gamma_2(t + 1) \geq \gamma_2(t) + \frac{\eta c_1}{1 + \exp(\gamma_2)}.$$
Since we expect logarithmic growth from the study of the flow, we study $\exp(\gamma_2(t))$. In particular,
$$\exp(\gamma_2(t + 1)) \geq \exp(\gamma_2(t)) \exp\!\left( \frac{\eta c_1}{1 + \exp(\gamma_2(t))} \right).$$
Using $e^x \geq 1 + x$, we furthermore get
$$\exp(\gamma_2(t + 1)) - \exp(\gamma_2(t)) \geq \frac{\eta c_1 \exp(\gamma_2(t))}{1 + \exp(\gamma_2(t))}.$$
Since $\gamma_2$ is always non-negative, we have $\exp(\gamma_2)/(1 + \exp(\gamma_2)) \geq 1/2$, hence the recursion
$$\exp(\gamma_2(t + 1)) - \exp(\gamma_2(t)) \geq \eta c_1 / 2. \qquad (37)$$
Using a telescopic sum, we get the desired lower bound $\gamma_2(t) \geq \log(\eta c_1 t / 2 + 1)$.

Let us now focus on $\gamma_1$. In comparison to gradient flow, $\gamma_1$ can grow large because of potentially large steps taken from values where $\Delta\gamma_1$ is positive. Similarly to the gradient flow case, we have $\Delta\gamma_1 \geq 0$ if and only if $\sinh(\gamma_1 - \gamma) \leq C(\gamma_2)$, where $C(\gamma_2)$ is given in (34). Now consider a time $t$. If $\sinh(\gamma_1(t) - \gamma) \leq C(\gamma_2(t))$, we have
$$\gamma_1(t) \leq \gamma_1(t + 1) = \gamma_1(t) + \eta c\,\Delta\gamma_1 \leq \gamma + \sinh^{-1}(C(\gamma_2(t))) + \eta c (1 + \alpha) p_1 \leq \gamma + C(\gamma_2(0)) + \eta c (1 + \alpha) p_1 \leq \gamma + \frac{p_1}{2 p_2} + \eta c (1 + \alpha) p_1 =: \gamma_{\max}.$$
If $\sinh(\gamma_1(t) - \gamma) \geq C(\gamma_2(t))$, then $\gamma_1(t + 1) \leq \gamma_1(t)$, and
$$\gamma_1(t + 1) \geq \gamma_1(t) + \eta c\,\Delta\gamma_1 \geq \gamma - \eta c (1 + \alpha) p_2 =: \gamma_{\min},$$
where $\gamma_1(t) \geq \gamma$ follows from $\sinh(\gamma_1(t) - \gamma) \geq 0$. By induction, we then have $\gamma_1(t) \in [\min(0, \gamma_{\min}), \gamma_{\max}]$ for all $t$ (assuming $\gamma_1(0) = 0$). For simplicity, we skip the upper bound on $\gamma_2$, as well as the convergence of $\gamma_1$ towards $\gamma$.

INDICATIONS TO PROVE THAT $\gamma_1$ CONVERGES TO $\gamma$

Note that in the case of gradient descent with large learning rates, $\gamma_1$ might oscillate around $\gamma$. This does not happen with gradient flow and requires extra derivations. Using the fact that $\gamma_2(t) \geq \log(\eta c_3 t + 1)$ with $c_3 := c_1/2$, we get
$$C(\gamma_2(t)) \leq \frac{p_1 - p_2}{2\sqrt{p_1 p_2}}\,\frac{1}{\eta c_3 t + 1}.$$
If $\sinh(\gamma_1(t) - \gamma) \leq C(\gamma_2(t))$, we then have
$$\gamma_1(t) \leq \gamma_1(t + 1) = \gamma_1(t) + \eta c\,\Delta\gamma_1 \leq \gamma + \frac{p_1 - p_2}{2\sqrt{p_1 p_2}}\,\frac{1}{\eta c_3 t + 1} + \frac{\eta c (1 + \alpha)\,p_1}{1 + \exp(\gamma_{\min})\,\eta c_3 t} = \gamma + O(1/t).$$
If $\sinh(\gamma_1(t) - \gamma) \geq C(\gamma_2(t))$, then
$$\gamma_1(t) \geq \gamma_1(t + 1) = \gamma_1(t) + \eta c\,\Delta\gamma_1 \geq \gamma - \frac{\eta c (1 + \alpha)\,p_2}{1 + \exp(-\gamma_{\max})\,\eta c_3 t} = \gamma - O(1/t).$$
This ensures that when the dynamics are in an oscillating regime, the bound on $|\gamma_1(t) - \gamma|$ decreases as $O(1/t)$, thus inducing faster progress toward perfect accuracy than guaranteed by the looser bound $[\gamma_{\min}, \gamma_{\max}]$.

C.2. Loss Spike

Proposition 5 follows from
$$L(W_1) \geq p_2\,\ell(W_1; 2, 2) = p_2 \log(1 + \exp(-m_2)) \geq -p_2 m_2.$$
In particular, when initialized at zero, after one gradient update, $m_2 = \eta (p_2 - \alpha p_1)$, which is negative as soon as $\alpha p_1 > p_2$, in which case the loss after one step is at least $\eta\, p_2 (\alpha p_1 - p_2)$ and grows linearly with the step size.
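To make the oscillation and the loss spike of Proposition 5 concrete, here is a minimal sketch of the discrete updates (35)-(36) with a large step size; the parameter values are illustrative and chosen so that $\alpha p_1 > p_2$. The loss printed below is the exact two-token cross-entropy expressed through the margins $m_1 = \gamma_1 + \gamma_2$ and $m_2 = \gamma_2 - \gamma_1$.

```python
import numpy as np

# Gradient descent (35)-(36) with a large step size eta (illustrative values with
# alpha * p1 > p2): the first update sends m2 = gamma2 - gamma1 below zero, which
# spikes the loss before the dynamics recover; gamma1 overshoots
# gamma = log(sqrt(p1/p2)) and then returns toward it.
p1, p2, alpha, c, eta = 0.8, 0.2, 0.6, 1.0, 20.0
gamma = 0.5 * np.log(p1 / p2)

g1, g2 = 0.0, 0.0
for t in range(15):
    m1, m2 = g2 + g1, g2 - g1
    loss = p1 * np.log1p(np.exp(-m1)) + p2 * np.log1p(np.exp(-m2))
    print(f"t={t:2d}  gamma1-gamma={g1 - gamma:+.3f}  m2={m2:+.2f}  loss={loss:.3f}")
    d1 = (1 + alpha) * (p1 / (1 + np.exp(g2 + g1)) - p2 / (1 + np.exp(g2 - g1)))
    d2 = (1 - alpha) * (p1 / (1 + np.exp(g2 + g1)) + p2 / (1 + np.exp(g2 - g1)))
    g1, g2 = g1 + eta * c * d1, g2 + eta * c * d2
```

With these values the loss jumps above its initialization value $\log 2$ after the first step and then quickly recovers, matching the benign spike discussed in the main text.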
D. Transformer experiments

In this section, we provide more details on the setup for the Transformer experiments of Section 5.3. We follow Bietti et al. (2023) and consider a simplified two-layer Transformer architecture trained on a simple in-context learning task. The task consists of sequences of tokens $z_{1:T} \in [N]^T$, where any occurrence of a so-called trigger token $q \in Q$ is followed by the same output token $o_q$, but where $o_q$ is resampled uniformly across different sequences. The tokens following all non-trigger tokens are randomly sampled from a sequence-independent Markov model (namely, a character-level bigram model estimated from Shakespeare text data). We focus on the prediction of the output tokens $o_q$ given a sequence $[z_1, \ldots, q, o_q, \ldots, q]$, where we assume $q$ has appeared at least once before the last token. Correctly predicting the token $o_q$ then requires finding previous occurrences of $q$ in the input sequence and copying the token just after it. Bietti et al. (2023) show that this task can be solved by a two-layer Transformer with no feed-forward blocks and all layers fixed at random except three trained matrices, by implementing an induction head mechanism (Elhage et al., 2021; Olsson et al., 2022). The three trained matrices were found to behave as associative memories, each with a different set of embeddings, as we now detail:
- $W_K^1$ (first-layer key-query matrix), which implements a previous-token lookup, satisfying $\arg\max_j \langle p_j, W_K^1 p_t \rangle = t - 1$, where the $p_t$ are positional embeddings;
- $W_K^2$ (second-layer key-query matrix), which implements a lookup of the previous trigger matching the current token, with $\arg\max_j \langle e_j, W_K^2 \varphi_k(e_i) \rangle = i$, where the $e_i$ are input token embeddings, and $\varphi_k(e_i) = W_O^1 W_V^1 e_i$ is a remapping of the input embeddings by the first attention head;
- $W_O^2$ (second-layer output matrix), which implements a copy of the output token into the unembedding space, with $\arg\max_j \langle u_j, W_O^2 \varphi_o(e_i) \rangle = i$, where the $u_j$ are output embeddings, and $\varphi_o(e_i) = W_V^2 e_i$ is a remapping of the input embeddings by the (random) second-layer value matrix.

We may then define the $W_O^2$ margins (those for $W_K^1$ and $W_K^2$ are defined analogously):
$$m_i = \langle u_i, W_O^2 \varphi_o(e_i) \rangle - \max_{j \neq i} \langle u_j, W_O^2 \varphi_o(e_i) \rangle.$$
As explained in Bietti et al. (2023), the inputs to each matrix are often sums/superpositions of embeddings, some of which are typically noise that gets filtered out during training. For instance, training $W_O^2$ alone may recover the desired associations in high dimension, even though its input at initialization is an average over all tokens in the sequence, due to the initially flat attention pattern.

Our training setup is the following: we consider full-batch gradient descent on a dataset of 16,384 sequences of length 256, generated from the model described above with $N = 64$ tokens. The loss only considers predictions of the tokens $o_q$, ignoring the very first occurrence since it is not predictable from context.
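For concreteness, here is a minimal sketch of this data-generating process. The bigram transition matrix is random here rather than estimated from Shakespeare text, and the number of trigger tokens is an illustrative choice (this appendix does not fix it); the token count and sequence length follow the setup above, but the sketch generates fewer sequences to stay fast.

```python
import numpy as np

# Sketch of the in-context task: trigger tokens are followed by a per-sequence
# output token o_q; all other tokens follow a sequence-independent bigram model.
# The paper's setup uses N = 64 tokens, T = 256, and 16,384 sequences; a random
# bigram matrix and 4 triggers are stand-ins for illustration.
rng = np.random.default_rng(0)
N, T, n_triggers, n_seqs = 64, 256, 4, 256

bigram = rng.dirichlet(np.ones(N), size=N)           # bigram[z] = P(next token | current = z)
triggers = rng.choice(N, size=n_triggers, replace=False)

def sample_sequence():
    out_tok = {int(q): int(rng.integers(N)) for q in triggers}  # o_q, resampled per sequence
    seq = [int(rng.integers(N))]
    for _ in range(T - 1):
        prev = seq[-1]
        if prev in out_tok:                          # trigger -> its output token o_q
            seq.append(out_tok[prev])
        else:                                        # non-trigger -> bigram Markov step
            seq.append(int(rng.choice(N, p=bigram[prev])))
    return seq

data = np.array([sample_sequence() for _ in range(n_seqs)])
print(data.shape)                                    # (n_seqs, T) token sequences z_{1:T}
```

Training then runs full-batch gradient descent on such sequences, with the loss restricted to the positions of the output tokens $o_q$ as described above.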