# Scaling Laws for Associative Memories

Published as a conference paper at ICLR 2024

Vivien Cabannes (FAIR, Meta), Elvis Dohmatob (FAIR, Meta), Alberto Bietti (Flatiron Institute)

ABSTRACT

Learning arguably involves the discovery and memorization of abstract rules. The aim of this paper is to study associative memory mechanisms. Our model is based on high-dimensional matrices consisting of outer products of embeddings, which relates to the inner layers of transformer language models. We derive precise scaling laws with respect to sample size and parameter size, and discuss the statistical efficiency of different estimators, including optimization-based algorithms. We provide extensive numerical experiments to validate and interpret theoretical results, including fine-grained visualizations of the stored memory associations.

1 INTRODUCTION

As the scale of large language models (LLMs) keeps increasing, scaling laws have become a crucial tool to empirically assess and predict the behavior of these models when varying the number of parameters and the amount of training data (Kaplan et al., 2020; Hoffmann et al., 2022). Despite their practical impact, the underlying phenomena leading to such scaling laws remain poorly understood. A better understanding of such phenomena could guide researchers towards improved models, algorithms, and datasets, which may in turn lead to improved scaling laws.

Our study focuses on a simple model that aims to be representative of LLMs in two ways. First, we focus on heavy-tailed data distributions over discrete tokens, a natural assumption for text data (Piantadosi, 2014). Second, we consider associative memory models that store input-output pairs through outer products of finite-dimensional embeddings, and can be seen as a proxy for the intermediate layers of transformers.
Indeed, some transformer layers have been found to behave as key-value memories (Geva et al., 2021; Meng et al., 2022). More generally, outer-product associative memory matrices arise naturally from training dynamics on intermediate weights (Bietti et al., 2023). Beyond simple associative recall, combining several such associative rules at different layers may lead to circuits with rich, context-dependent reasoning behaviors (Elhage et al., 2021; Bietti et al., 2023; Michaud et al., 2023). For example, an input token at an intermediate layer may encode the topic "linux". This can lead to an output token that triggers a specific behavior in the transformer's following layers when they process the token "terminal".

Our contributions are as follows:

- We provide precise statistical rates for outer-product memories with random embeddings, and compare different memory storage schemes in the context of Zipf-distributed data.
- We compare theoretical schemes to the weights learned by various optimization algorithms used in practice, and illustrate the role of different design choices with numerical experiments.

Related work. Associative memory models have a long history in the literature on neural computation (Steinbuch, 1961; Willshaw et al., 1969; Longuet-Higgins et al., 1970; Kohonen, 1972; Amari, 1972; Little, 1974; Hopfield, 1982; Smolensky, 1990; Schlag et al., 2021; Valle-Lisboa et al., 2023). To the best of our knowledge, however, the statistical insights we provide for continuous-valued random embeddings and heavy-tailed token distributions are new. Memorization behaviors have drawn a lot of attention recently, and are believed to be important for understanding the learning that happens in deep neural networks (e.g., Sukhbaatar et al., 2019; Feldman, 2020; Feldman & Zhang, 2020; Geva et al., 2021; Wu et al., 2022). Building on memorization and heavy-tailed discrete data, our model bears similarities to those of Hutter (2021), Michaud et al.
(2023) or Debowski (2023), although we focus on practical models with finite capacity. The discrete nature of tokens contrasts with other recent works on scaling laws that have focused on continuous Gaussian inputs (e.g., Bahri et al., 2021; Maloney et al., 2022; Sorscher et al., 2022).

Figure 1: Scaling laws with respect to the model capacity $d$ (left), respectively the number of data seen $T$ (right), for various dataset sizes $T$, respectively various model capacities $d$ (right-panel legend: $d = 50, 100, 200, 1000$). These plots empirically validate the theory developed in the paper. That theory proves scaling laws in $\mathcal{E}(f_q) \asymp d^{-\alpha+1} + T^{-1+1/\alpha}$ (dashed lines) under our setting (1), (2), (5) with $\alpha = 2$, and the association scheme (12) with $\rho = 0$ and $P = d/8$. The experiments are averaged over 100 runs, and standard deviations are shown with solid color.

2 MODEL FOR ASSOCIATIVE MEMORY

The data. In the following, we consider a joint distribution $p \in \Delta_{[N]\times[M]}$ on inputs $x \in [N]$ and outputs $y \in [M]$. The inputs and outputs are assumed to take only $N$ and $M$ discrete values, respectively. For example, $N$ could be the number of potential sequences of fixed word length in the English language, while $M$ would be the number of potential words that complete the sequence. Abstractly, $x$ and $y$ will be referred to as tokens. To simplify the study, we assume for now that $y$ is a deterministic function of $x$, i.e., there is no noise in the labels. Consistent with language modeling, we also assume that $p(x)$ follows a Zipf law. Formally, there exist a parameter $\alpha > 0$, a normalizing constant $C_\alpha$, a permutation $\sigma \in \mathfrak{S}_N$ and a function $f_*: [N] \to [M]$ such that
$$\forall\, x, y \in [N]\times[M], \qquad p(\sigma(x)) = C_\alpha x^{-\alpha}, \qquad p(y \mid x) = \mathbf{1}_{y = f_*(x)}. \tag{1}$$
The distribution $p$ is not known, but has generated $T$ known independent samples $(x_t, y_t)_{t\in[T]} \sim p$. For readability's sake, we assume without loss of generality that $\sigma$ is the identity, so that $p$ is decreasing.

The model, and the loss.
The input tokens are embedded into a space $\mathbb{R}^d$ of dimension $d$ through an embedding map $e: [N] \to \mathbb{R}^d$. This space is used for computation purposes. In particular, we focus on the linear transformation parameterized by a matrix $W \in \mathbb{R}^{d\times d}$ mapping $x$ to $We(x)$. This latter vector is mapped back to the output space through an unembedding map $u: [M] \to \mathbb{R}^d$ and the decoding rule
$$f_W(x) = \arg\max_{y\in[M]} u_y^\top W e_x, \qquad W \in \mathbb{R}^{d\times d}, \tag{2}$$
where $e_x$ and $u_y$ are abbreviations for $e(x)$ and $u(y)$. The model (2) can be seen as analogous to an attention layer, where keys $e_x$ are tested against queries $u_y$ through a matrix $W$ before going through a softmax layer. When the attention is peaky, this softmax reduces to an argmax. The model also resembles next-token prediction from an intermediate representation $We_x$, which may itself be the output of an attention block that attends to a token $x$. The matrices $W$ will be expressed as associative memories. Memory of an observed pair $(x, y)$ is represented as an outer product $u_y e_x^\top$. Remembering those with respect to a probability $q \in \Delta_{[N]\times[M]}$ leads to the matrix
$$W_q = \sum_{(x,y)\in[N]\times[M]} q(x,y)\, u_y e_x^\top, \qquad q \in \Delta_{[N]\times[M]}. \tag{3}$$
The representation (3) is justified because the predictions (2) are insensitive to modifications of $W$ outside the span of $(u_y e_x^\top)_{x,y}$. In our deterministic setting (1), where one only observes pairs $(x, f_*(x))$, we shall consider the simpler model¹
$$W_q = \sum_{x\in[N]} q(x)\, u_{f_*(x)} e_x^\top, \qquad q \in \Delta_{[N]}. \tag{4}$$

¹It should be noted that the proof techniques behind Theorem 1 do not break when considering $q = q(x, y)$. Both models would lead to similar results, with the case $q = q(x)$ being simpler to comprehend.

Table 1: Summary of key elements in the study. We are given discrete tokens $x, y$ with deterministic relation $y = f_*(x)$. We embed tokens in $\mathbb{R}^d$, where $d$ acts as a model capacity parameter. We store associations $x \to y$ in the matrix $W$ through a scheme $q$ and recall them through the decoding $f_q$.
We will first study the scaling law of the generalization error $\mathcal{E}$ as a function of the number of data $T$ and the model capacity $d$ for different schemes $q$. We will later study the schemes $q$ found by optimization-based algorithms.

| Tokens | Embeddings | Model | Scaling |
|---|---|---|---|
| $y_{i_t} = f_*(x_{i_t})$ | $e_x, u_y \in \mathbb{R}^d$ | $W = \sum_x q(x)\, u_{f_*(x)} e_x^\top$ | $\mathcal{E}(q) = \mathbb{E}[\mathbf{1}_{f_q(x) \neq f_*(x)}]$ |
| $t \in \{1, 2, \dots, T\}$ | $e_x \sim \mathcal{N}(0, I)$ | $f_q(x) = \arg\max_y u_y^\top W e_x$ | $\mathcal{E}(q) = F(d, T; q)$ |

Table 2: Some insightful provable scaling laws with respect to the model capacity $d$ and the number of data $T$, for two schemes that store associations as in (4) with random embeddings.

| Model | Error scaling | Comment |
|---|---|---|
| $q(x) = p(x)$ | $d^{-(\alpha-1)/2\alpha} + T^{-1+1/\alpha}$ | Found with large batches in one step |
| $q(x) = \mathbf{1}_{x \leq d}$ | $d^{-\alpha+1} + T^{-1+1/\alpha}$ | Optimal scaling with random embeddings |

To simplify notations, we will write $f_q$ for $f_{W_q}$ (2). The model $f_q$ is seen as superposing memories, since all associations are mixed together in a single matrix. The quality of a mapping $f$ is quantified through the generalization error
$$\mathcal{E}(f) = \mathbb{E}_{(X,Y)\sim p}\big[\mathbf{1}_{f(X) \neq Y}\big], \qquad f: [N] \to [M]. \tag{5}$$

Which questions are we interested in? Several questions naturally arise from our model. The first ones relate to scaling laws. How does the error depend on $T$, the number of data? How does it scale with $d$, which encodes model capacity? The second ones relate to the model itself. How does the error behave for different $q$? What about optimization-based algorithms?

Arguably, the model (2) provides a simple framework for studying memorization. It could easily be extended to more intricate memorization and training behaviors inside a transformer language model. Indeed, memories of the form (4) were found to accurately model the behavior of weight matrices in multi-layer transformers trained by gradient methods on certain tasks (Bietti et al., 2023). Hence, we expect our study to generalize to more complex mechanisms in transformers, where rich token interactions are used to predict the next token in a sequence.
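The data model (1), the memory (4) and the error (5) can be combined in a short simulation. The sketch below is a minimal illustration under our own assumptions, not the authors' code:

- Tokens are indexed $1, \dots, N$.
- Labels use the illustrative rule $f_*(x) = x \bmod M$ from the figures, so they live in $\{0, \dots, M-1\}$.
- The helper names `zipf_probs`, `sample_pairs` and `memory_error` are hypothetical.

The code never forms the $d \times d$ matrix $W_q$. It relies instead on the equivalent identity $u_y^\top W_q e_x = \sum_{x'} q(x') \langle u_y, u_{f_*(x')}\rangle \langle e_{x'}, e_x\rangle$, which keeps plain Python lists cheap.

```python
import math
import random


def zipf_probs(N, alpha):
    """Zipf law (1) with sigma = identity: p(x) = C_alpha * x**(-alpha), x = 1..N."""
    weights = [x ** (-alpha) for x in range(1, N + 1)]
    c_alpha = 1.0 / sum(weights)  # normalizing constant C_alpha
    return [c_alpha * w for w in weights]  # entry x - 1 holds p(x)


def sample_pairs(T, N, M, alpha, seed=0):
    """Draw T i.i.d. pairs (x_t, f_*(x_t)) with the illustrative labeling f_*(x) = x mod M."""
    rng = random.Random(seed)
    xs = rng.choices(range(1, N + 1), weights=zipf_probs(N, alpha), k=T)
    return [(x, x % M) for x in xs]


def dot(a, b):
    return sum(s * t for s, t in zip(a, b))


def memory_error(q, p, labels, M, d, seed=0):
    """Generalization error (5) of f_q for a single draw of random embeddings.

    q, p, labels are lists where entry x - 1 refers to token x.
    e_x ~ N(0, I_d); u_y is uniform on the unit sphere (a normalized Gaussian).
    """
    rng = random.Random(seed)
    N = len(q)
    e = [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(N)]
    u = []
    for _ in range(M):
        g = [rng.gauss(0.0, 1.0) for _ in range(d)]
        norm = math.sqrt(dot(g, g))
        u.append([c / norm for c in g])
    uu = [[dot(u[y], u[z]) for z in range(M)] for y in range(M)]  # <u_y, u_z>
    error = 0.0
    for x in range(N):
        ee = [dot(e[xp], e[x]) for xp in range(N)]  # <e_x', e_x>
        # scores[y] = u_y^T W_q e_x with W_q = sum_x' q(x') u_{f_*(x')} e_x'^T, as in (4)
        scores = [sum(q[xp] * uu[y][labels[xp]] * ee[xp] for xp in range(N)) for y in range(M)]
        prediction = max(range(M), key=lambda y: scores[y])
        if prediction != labels[x]:
            error += p[x]  # E(f_q) sums p(x) over misclassified tokens
    return error


N, M, alpha = 20, 5, 2.0
p = zipf_probs(N, alpha)
labels = [x % M for x in range(1, N + 1)]
q0 = [1.0] * N  # scheme q_0 = 1 from (10)
print(memory_error(q0, p, labels, M, d=4), memory_error(q0, p, labels, M, d=400))
```

The two regimes can be compared directly. With $d = 4 \ll N$, interferences of the kind shown in Figure 2 are expected to cause errors. With $d = 400$, all twenty associations are typically recovered. Averaging `memory_error` over several seeds gives a Monte Carlo estimate of $\mathbb{E}_{e,u}[\mathcal{E}(f_q)]$.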
3 SCALING LAWS WITH RANDOM EMBEDDINGS

Why do we make errors? With a simple deterministic model, one may wonder why we cannot learn the mapping $f_*$ perfectly. There are two sources of error. The first is not having enough data to see all the potential associations $(x, f_*(x))$, which has already been studied by Hutter (2021). The second is the limited memory capacity of our model, which we illustrate in Figure 2.

Proposition 1 (Finite data, infinite memory). Consider an infinite memory model $\hat f$ that, at time $T$, correctly predicts every $x$ seen during training, i.e., every $x \in \{X_t\}_{t\in[T]}$. Here the $(X_t, Y_t)$ were drawn independently at random from a distribution $p \in \Delta_{[N]\times[M]}$. Under the data model (1), the generalization error with respect to the random dataset $\mathcal{D}_T = (X_t, Y_t)_{t\in[T]}$ reads
$$\mathbb{E}_{\mathcal{D}_T}\big[\mathcal{E}(\hat f)\big] \asymp T^{-1+1/\alpha}. \tag{6}$$
Here, the notation $a \asymp b$ means that there exist two constants $c_1$ and $c_2$ such that $c_1 b \leq a \leq c_2 b$.

3.1 TIGHT ERROR CHARACTERIZATION

The case where one has infinite data but finite memory is intrinsically a deterministic problem. However, characterizing interferences between embeddings and the corresponding generalization error is combinatorial in nature. It is hard to study without specific assumptions on the embeddings $e$ and $u$. A natural choice is to take them random, as is the case at initialization.

Theorem 1 (Infinite data, finite memory). Let $M \geq 4$ and $d > 8\log(M)$.
For any memory weight scheme $q: [N] \to \mathbb{R}$, when the embeddings $e_x$ are independent random variables $e_x \sim \mathcal{N}(0, I)$ and the unembeddings are taken uniformly at random on the sphere,
$$\mathbb{E}_{e,u}\big[\mathcal{E}(f_q)\big] \leq \inf_{\gamma}\; 2d^{-\gamma} + p\Big(\Big\{x \in [N] \;\Big|\; d\,q(x)^2 \leq 16 c_\gamma \Big(Q_\infty + \frac{8 c_\gamma \|q\|_2^2}{d}\Big)\Big\}\Big), \tag{7}$$
where $Q_\infty := \max_y \sum_{x;\, f_*(x) = y} q(x)^2$ and $c_\gamma = \log(M) + \gamma\log(d)$. Here $p(\mathcal{X}) = \sum_{x\in\mathcal{X}} p(x)$ denotes the probability that $x$ belongs to $\mathcal{X} \subseteq [N]$. In terms of lower bound,
$$\mathbb{E}_{e,u}\big[\mathcal{E}(f_q)\big] \geq \frac{1}{20}\, p\big(\big\{x \in [N] \;\big|\; 3(d+1)\,q(x)^2 \leq Q_\infty\big\}\big). \tag{8}$$

Figure 2: Error due to finite memory capacity. Stacking associative memories in a matrix $W = \sum_x u_{f_*(x)} e_x^\top$ may produce a pattern where three inputs, mapped to three different outputs, interact so that $u_2^\top W e_1 = e_2^\top e_1 + u_2^\top u_3\, e_3^\top e_1 > 1 + u_1^\top u_3\, e_3^\top e_1 = u_1^\top W e_1$. Hence $f_W(x = 1) = 2 \neq 1 = f_*(x = 1)$. In other words, memory interference may lead to wrong predictions, which illustrates the finite capacity of the model $f_W$ (2) to store all data associations.

Theorem 1 shows how the error made by a scheme $q$ at the input $x$ depends on the ratio between two quantities. The first is the signal $d\,q(x)$ provided by the associative memory $u_{f_*(x)} e_x^\top$. The second is the noise $Q_\infty$, which corresponds to the signal provided by the most competitive class $y \in [M]$. This holds up to a higher-order term in $\|q\|_2^2/d$. That term corresponds to a class $y = f_*(x)$ competing against itself when the random embeddings $e_{x'}$, for $x'$ such that $f_*(x') = y$, point in the direction opposite to $e_x$. When $d$ is large and $p$ is regular, $c_\gamma \|q\|_2^2/d$ is dominated by $Q_\infty$. The cut-off of $q(x)^2/Q_\infty$ at $32c_\gamma/d$ then behaves like a cut-off at $1/d$ up to logarithmic terms. Moreover, when $q$ is chosen independently of $p(y|x)$,² one can expect $Q_\infty \approx p_* \|q\|_2^2$, where $p_* = \max_{y\in[M]} p(y)$. As a consequence, up to constants and logarithmic terms, we get
$$\mathbb{E}\big[\mathcal{E}(f_q)\big] \approx p\big(\big\{x \in [N] \;\big|\; d\,q(x)^2 \leq p_* \|q\|_2^2\big\}\big). \tag{9}$$

3.2 MEMORY SCHEMES

Let us now discuss several natural choices for $q$ and compare their performance. The first naive choice consists in storing all the data seen at time $T$ in memory.
It reads
$$\hat q_0(x) = \mathbf{1}_{x \in \{X_t\}_{t\in[T]}}, \qquad q_0(x) = 1. \tag{10}$$
Here, $\hat q_0$ is the learned weighting scheme based on the $T$ data, while $q_0$ denotes its idealized limit when one has infinite data. In the idealized setting, $Q_\infty(q_0) = N p_*$, where $p_* := \max_{y\in[M]} p(y)$. From Theorem 1, we deduce that $\mathcal{E}(f_{W_{q_0}})$ follows two regimes:

- An overflow regime where $3(d+1) \leq N p_*$. Here the memory $W_{q_0}$ is in essence too full to recover any signal from it, and $\mathbb{E}_{e,u}\mathcal{E}(f_{W_{q_0}}) > 1/20$ by (8).
- An infinite memory regime where $d \gg N$. Here all associations $u_{f_*(x)} e_x^\top$ can be stored orthogonally to one another, and the error $\mathbb{E}_{e,u}\mathcal{E}(f_{W_{q_0}})$ measures the tiny probability that some random input embeddings are too correlated.

Since our associative memory model (2) has finite capacity, one may instead weight memories according to their frequencies. This leads to the following scheme, for $\rho \geq 0$:
$$\hat q_\rho(x) = \Big(\sum_{t\in[T]} \mathbf{1}_{x = X_t}\Big)^\rho, \qquad q_\rho(x) = p(x)^\rho. \tag{11}$$
A better option is to explicitly limit the storage of our model with a simple thresholding algorithm:
$$\hat q_{\rho,[P]}(x) = \hat p(x)^\rho\, \mathbf{1}_{x \in \mathrm{top}_P((x_t)_{t\in[T]})}, \qquad q_{\rho,[P]}(x) = p(x)^\rho\, \mathbf{1}_{x\in[P]}, \tag{12}$$
where $\mathrm{top}_P((x_t))$ denotes the set of the $P$ most frequent inputs in the data $(x_t)$.

²To be more precise, one should choose $q(x)$ to be class dependent, so as to cram into memory as many $x$ as possible for each class $y = f_*(x)$. This amounts to ensuring that $y \mapsto \sum_{x;\, f_*(x) = y} q(x)^2$ is constant in $y$. For simplicity, we do not discuss this refinement, which does not change the big picture of our exposition.

Figure 3 (labels: $q = 1$, $q = p$, $q = \mathbf{1}_{x \leq d/8}$): Generalization error (5) as a function of $d$ and $T$ for the model (4), averaged over 100 runs. The data follows a Zipf law with $\alpha = 0.5$, $N = 100$, $M = 5$ and $f_*(x) = x \bmod M$. Left: error for $q_0$ (10). Either $d$ is too small and memory overflow leads to a large error (red area), or $d$ is big enough and, with enough data, the error is null (blue area).
Middle: error for $q_1$ (11). For small $d$ and big $T$, it avoids memory overflow, which allows a smaller error than $q_0$. For big $d$, however, it does not allocate enough memory to rare associations, which leads to a bigger error. These results can be interpreted mechanistically by looking at the corresponding memory matrices (see Figure 10). Right: generalization error when $T = +\infty$, $N = 100$ and $\alpha = 2$. The scheme $q_0$ (in blue) gives a zero-one type of plot: the error is high if $d < N$ and decreases fast to zero if $d > N$. The scheme $q_1$ (in orange) gives an error decreasing in $d^{-(\alpha-1)/2\alpha} = d^{-1/4}$, as predicted by theory. The scheme $q_{0,[P]}$ (12) with $P = d/8$ (in green) decreases in $d^{-(\alpha-1)} = d^{-1}$ until it reaches the tipping point $d/8 > N$.

Proposition 2 (Without thresholding). Let $p$ be an $\alpha$-Zipf distribution (1). For $\rho > 0$, the performance of $f_\rho := f_{q_\rho}$ (11) is, up to poly-logarithmic factors and constants that depend on both $\rho$ and $\alpha$,
$$\mathbb{E}_{e,u}\mathcal{E}(f_\rho) \overset{(\log)}{\asymp} \Big(\frac{d}{\varphi(N)}\Big)^{-(\alpha-1)/2\rho\alpha}, \qquad \text{where } \varphi(N) = \begin{cases} 1 & \text{if } 2\rho\alpha > 1, \\ \log(N) & \text{if } 2\rho\alpha = 1, \\ N^{1-2\rho\alpha} & \text{if } 2\rho\alpha < 1. \end{cases} \tag{13}$$
In particular, when $\rho = 1$, $\mathbb{E}_{e,u}\mathcal{E}(f_1)$ scales in $d^{-(\alpha-1)/2\alpha}$. In the limit $\rho = 0$, $\mathbb{E}_{e,u}\mathcal{E}(f_0)$ can be understood as $(d/N)^{-\infty}$, which goes to zero if and only if $d$ is bigger than $N$.

Proposition 3 (With thresholding). Assume that $p(x)$ follows an $\alpha$-Zipf law (1) with $N = +\infty$. For $\rho \geq 0$, setting $P \asymp d^{1/(2\alpha\rho+1)}$, the error made by the memory scheme (12) scales as
$$\mathbb{E}_{e,u}\mathcal{E}(f_\rho) \overset{(\log)}{\asymp} d^{-(\alpha-1)/(2\rho\alpha+1)}. \tag{14}$$
In particular, when $\rho = 0$ and $P \asymp d$, one gets a scaling in $d^{-\alpha+1}$, which is actually optimal. The fact that this optimum is reached for $P \asymp d$ is reminiscent of Hopfield networks (Hopfield, 1982), which can only store $d/\log(d)$ patterns with a $d$ by $d$ matrix. Similarly, our model stores at most $d$ associations. Under a Zipf law, this leads to an error scaling in $d^{-(\alpha-1)}$.

Theorem 2 (Minimax performance). Assume that $p(x)$ follows an $\alpha$-Zipf law (1) with $N = +\infty$.
For any weighting scheme $q$ and any $p_* \in (0, 1)$, there exists a conditional distribution $p(y|x)$ with $p_* = \max_y p(y)$ such that the error made for the distribution $p$ is lower bounded by
$$\mathbb{E}_{e,u}\mathcal{E}(f_q) \geq c_\alpha (d+1)^{-\alpha+1}, \qquad \text{where } c_\alpha = \frac{C_\alpha\, p_*^{\alpha-1}}{20(\alpha+1)\cdot 3^{\alpha-1}}.$$
Moreover, this performance is reached, up to logarithmic factors, by the thresholding algorithm (12) with $P \asymp d/\log(d)$ and $\rho = 0$.

Finally, we prove that the scaling laws obtained for $d$ when $T = +\infty$ and for $T$ when $d = +\infty$ appear jointly when both $d$ and $T$ are finite.

Proposition 4 (Finite data and finite memory). Consider the previous bounds with respect to $d$ (Proposition 2 and Proposition 3). Finite data simply adds a term $T^{-1+1/\alpha}$ to them, up to constants and logarithmic terms, matching the optimal bound of Proposition 1. In particular, (12) with $\rho = 0$ and $P \asymp d/\log(d)$ reaches the optimal scaling
$$\mathbb{E}_{e,u,(x_t,y_t)_{t\in[T]}}\mathcal{E}(f_{\hat q}) \asymp T^{-1+1/\alpha} + d^{-\alpha+1}. \tag{15}$$

The optimal scaling (15) recovers the law of Hutter (2021) with respect to $T$, and that of Michaud et al. (2023) with respect to $d$. This is intuitive. Hutter (2021) assumes that all previously seen data are memorized exactly. In Michaud et al. (2023), each memory can be seen as a quantum of knowledge, and $d^{-\alpha+1}$ is the risk (5) of storing only the $d$ most frequent tokens. However, associative memories can be understood at different levels of granularity. One may argue that a transformer acts as one big associative memory machine and derive LLM scaling-law approximations as corollaries. We instead prefer to view a transformer as a combination of hidden associative memories, as suggested by Sukhbaatar et al. (2019), Geva et al. (2021), Wu et al. (2022) and Bietti et al. (2023), among others.

Figure 4 (three panels of error versus epochs, from 0 to 1000, each comparing "real" and "approx"): Comparison between the error found by optimizing $W$ (2) with SGD on the cross-entropy loss and its approximation with $q(x)$ (4) and the approximate update rule (20). We consider $N = 100$, $M = 5$, $f_*(x) = x \bmod M$, $\alpha = 2$, and batch size equal to one. Left: one run with $d = N = 100$ and $\gamma = 10$. Middle: average over 100 runs with $d = N = 100$ and $\gamma = 1$. Right: average when $d = N/10 = 10$ with $\gamma = 1$, a setting in which our approximation is no longer valid. The same results can be obtained for bigger batch sizes, as shown in Figure 13.

Figure 5: Theoretical approximation of the association scheme found with stochastic gradient descent, with batch size equal to one and fixed learning rates. Left: plot of $f^n(0)$ as a function of $n$, where $f$ is the effect of one gradient update on $q(x)$ (20). Right: plot of the resulting $q_\gamma(x)$ when $n_x \propto p(x) \propto (x+3)^{-\alpha}$, with $\alpha = 2$ and $n_N = 1$. The dashed lines show $q_\rho$ (11) for $\rho = 0.05$, $\rho = 0.35$ and $\rho = 1$. These curves match $q_\gamma$ well for $\gamma = 10$, $\gamma = 10^{-1}$ and $\gamma = 10^{-3}$, respectively.

4 OPTIMIZATION-BASED MEMORIZATION

This section studies the memory schemes favored by optimization-based algorithms, by digging into the training dynamics behind memorization. We argue that our model (2) is a proxy for the inner layers of a transformer, which memorize patterns before matching them against new data at inference time. We therefore want to understand how key elements in the training of a transformer influence storage in our memory model.

Gradient updates. We consider the cross-entropy loss as a surrogate objective to minimize, and study the form of gradient updates on batches of data. Formally, the matrix $W \in \mathbb{R}^{d\times d}$ in (2) is optimized to minimize the loss
$$\mathcal{L}(W) = \mathbb{E}_{(X,Y)\sim p}\big[\ell(X, Y; W)\big], \qquad \ell(x, y; W) = -u_y^\top W e_x + \log\Big(\sum_{z\in[M]} \exp(u_z^\top W e_x)\Big). \tag{16}$$
The gradient of this loss with respect to $W$ takes the following form, as detailed in Appendix A.10:
$$\nabla_W \ell(x, y; W) = -\big(1 - p_W(y|x)\big)(u_y - \varepsilon)\, e_x^\top, \qquad \text{with } \varepsilon = \sum_{z\in[M]} p_W(z \mid x, z \neq y)\, u_z, \tag{17}$$
where $p_W(y|x) \propto \exp(u_y^\top W e_x)$ are the model predictions for the current $W$. For a batch of $n$ data $B = [x_1, \dots, x_n]$, a gradient update with step size $\gamma_t$ updates $W_t$ as
$$W_{t+1} = W_t - \gamma_t \sum_{x\in B} \nabla_W \ell(x, f_*(x); W_t). \tag{18}$$

Figure 6: Gradient descent dynamics from the perspective of the matrix $(u_y^\top W_t e_x)_{y,x} \in \mathbb{R}^{M\times N}$, with $N = 10$, $M = 5$, $\alpha = 1.5$, $f_*(x) = x \bmod 5$, and $d = 5 < N$. A lighter color in the square $(y, x)$ means a higher value of $u_y^\top W e_x$. The optimal $W$ corresponds to two diagonal strips of yellow boxes (see Figure 15). The matrix $W_t$ is updated with stochastic gradient descent with batch size equal to one. From time to time, stochastic gradient descent hits an association that is not yet properly stored in memory (the red boxes). It then updates the weight matrix $W_t \to W_{t+1}$ (side-by-side pairs) to store it (18). Left pair: update with a big learning rate $\gamma = 10$, which risks erasing previous memories (the light-colored boxes), similarly to $q_0$ (10). Right pair: update with a small learning rate $\gamma = 10^{-1}$, which does not store rare memories, similarly to $q_\rho$ (11) with large $\rho$.

Approximation of the updates. Suppose $p_W(z|x)$ does not change much across $z \neq f_*(x)$. Since the $u_z$ were sampled at random in $\mathbb{S}^d$, we then expect $\varepsilon$ (17) to concentrate around zero with $\|\varepsilon\|^2 \approx 1/M$, hence to be negligible compared to $u_{f_*(x)}$. As a consequence,
$$\nabla_W \ell(x, f_*(x); W) \approx -\big(1 - p_W(f_*(x)|x)\big)\, u_{f_*(x)} e_x^\top. \tag{19}$$
This is notably the case for $W = 0$, for random $W$, or if $W$ only stores pairs $(x, f_*(x))$ with $d \gg N$. Under the update model (19), $T$ steps of SGD with batch size one lead to an association scheme of the form (4) with (see Appendix A.11)
$$q_\gamma(x) \approx f^{Tp(x)}(0) = \underbrace{f \circ f \circ \cdots \circ f}_{Tp(x)\ \text{times}}(0), \qquad \text{where } f: x \mapsto x + \frac{\gamma}{1 + M^{-1}\exp(x)}. \tag{20}$$
This equation tells us what form to expect for $q$ under optimization schemes with different hyperparameters. The approximation is shown in Figure 5 and validated empirically in Figure 4.

Step size effect.
When $d > N$, the update approximation (20) and the resulting $q_\gamma$ show that a large learning rate $\gamma$ benefits our problem, in particular when using SGD with batch size one. Interestingly, the same behavior holds with limited capacity, i.e., $d < N$, although interferences between embeddings (Figure 2) then break our approximation (19). In those settings, we resort to numerical simulation to study how optimization rearranges memories. Figure 6 showcases two types of behavior depending on the size of $\gamma$. (i) When the learning rate $\gamma$ is large, associations are stored easily in memory but tend to overwrite previous storage. (ii) When the learning rate $\gamma$ is small, associations need to be seen often to build up in the matrix $W$ (4). This takes more time but does not erase memory. It gives another intuitive explanation of why a bigger step size leads to better results on the left of Figure 7. The same considerations explain the usefulness of scheduling in our simple model, which we illustrate in Figure 11. A large learning rate stores associations while there is still memory space. Reducing it later in training avoids overwriting previous storage unless an association is highly frequent.

Batch size effect. Table 2 shows that, under the model (4), storing associations with $q = 1$ is better than storing them with $q = p$. This suggests that, when processing a finite number of data $T$, a smaller batch size is preferable. Intuitively, processing an input $x$ inside a batch reweights it by its frequency $p(x)$. Processing it on its own instead updates $W$ much like setting $q_\gamma(x) = 1$, provided $x$ has not already been seen. Indeed, in the large batch limit $|B| \to +\infty$, one batch update corresponds to a population gradient update. When $p_W \ll 1$, that update reduces to $\nabla_W \mathcal{L}(W) \approx -\sum_x p(x)\, u_{f_*(x)} e_x^\top$. By contrast, many small batch updates lead to an association scheme akin to (4) with $q = 1$.
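The scalar recursion (20) makes these step-size and batch-size effects easy to explore numerically. The sketch below is ours. It relies on the same simplifications as (19) and (20), in particular $\varepsilon \approx 0$, and its helper names are hypothetical.

It compares the weight given to a token seen once with the weight given to a token seen 100 times:

- A ratio close to 1 means near-uniform storage, as with $q_0 = 1$.
- A ratio close to $1/100$ means frequency-proportional storage, as with $q = p$.

The last function models a single large-batch step from $W = 0$. There $p_W(f_*(x)|x) = 1/M$ for every $x$, so the weight of $x$ is proportional to its number of occurrences in the batch.

```python
import math


def sgd_update(q, gamma, M):
    """One batch-size-one step on q(x) under (19): f(q) = q + gamma / (1 + exp(q) / M).

    Rewritten as gamma * z / (1 + z) with z = M * exp(-q), which cannot overflow for q >= 0.
    """
    z = M * math.exp(-q)
    return q + gamma * z / (1.0 + z)


def q_after_visits(n, gamma, M=5):
    """Approximation (20): q_gamma(x) ~ f^n(0) when x was sampled n ~ T p(x) times."""
    q = 0.0
    for _ in range(n):
        q = sgd_update(q, gamma, M)
    return q


def rare_to_frequent_ratio(gamma, M=5, n_rare=1, n_frequent=100):
    """q(rare) / q(frequent) after n_rare and n_frequent single-sample updates."""
    return q_after_visits(n_rare, gamma, M) / q_after_visits(n_frequent, gamma, M)


def large_batch_ratio(gamma, M=5, n_rare=1, n_frequent=100):
    """One step from W = 0 with one big batch: q(x) = gamma * n_x * (1 - 1/M)."""
    q_rare = gamma * n_rare * (1.0 - 1.0 / M)
    q_frequent = gamma * n_frequent * (1.0 - 1.0 / M)
    return q_rare / q_frequent


for gamma in (10.0, 1.0, 1e-3):
    print(f"gamma={gamma}: batch size one {rare_to_frequent_ratio(gamma):.3f}, "
          f"one large batch {large_batch_ratio(gamma):.3f}")
```

In this toy setting, the batch-size-one ratio moves from near-uniform for large $\gamma$ toward frequency-proportional for small $\gamma$. The single large-batch step stays frequency-proportional whatever the value of $\gamma$. This is qualitatively the trend described in the text.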
In support of this batch-size argument, Figure 7 (middle) illustrates the benefits of splitting the descent into many steps, with a small batch size and a large step size.

4.1 PRACTICAL CONSIDERATIONS

We have seen that a large step size and a small batch size optimize our simple model fastest. However, such design choices are impractical for large transformers. First, large step sizes may lead to instability in realistic models (Gilmer et al., 2021). Second, reducing training time and improving hardware efficiency require processing large batches (Smith et al., 2018).

Figure 7 (panels, left to right: "SGD, |B|=64, T=10240" with $\gamma \in \{0.1, 1, 10, 100\}$; "SGD, T=1024" with $|B| = 16, \gamma = 1$ and $|B| = 1024, \gamma = 10$; "γ=1.0, |B|=1024, T=10240" comparing SGD, Adam, SGD+LN and Adam+LN): Effect of step size, batch size, layer norm and Adam (with $\beta_1 = \beta_2 = 0$, which corresponds to SignGD). All experiments use $N = 100$, $M = 5$, $\alpha = 2$ and $f_*(x) = x \bmod M$, and are averaged over ten runs. We initialize parameters and rescale learning rates to ensure maximal feature updates, as explained in Appendix B.1. To avoid confounders, we scale $\gamma$ in the middle plot so that the variance of the gradient updates does not depend on the batch size.

Adam. We have seen that the update of SGD with a large batch can be approximated as
$$\gamma_t^{-1}(W_{t+1} - W_t) = \sum_{x\in B} \big(1 - p_W(f_*(x)|x)\big)\, u_{f_*(x)} e_x^\top \approx \sum_{x\in[N]} |B|\,\big(1 - p_W(f_*(x)|x)\big)\, p(x)\, u_{f_*(x)} e_x^\top.$$
These naive updates would lead to a model that resembles (4) with $q = p^\rho$ for $\rho \approx 1$ (11). In agreement with previous research on the matter (Zhang et al., 2020; Kunstner et al., 2023), we found Adam to be helpful in our setup as well (see Figure 7, right). To first order, Adam can be approximated by sign SGD (Balles & Hennig, 2018).
Arguably, this normalizes the gradient and helps reach the saturation phase of $n \mapsto f^n$ (20) shown in Figure 5. The resulting matrix $W$ is homogenized to behave like $q_0 = 1$, which optimizes memory capacity. Experiments supporting this intuition are reported in Figures 15 and 16 in Appendix B.

Layer normalization. Minimizing the cross-entropy loss requires setting $p_W(y|x) = 1$, which makes $W$ diverge to infinity and the loss gradients unstable. To ensure numerical stability, it is natural to rescale the vector $We_x \in \mathbb{R}^d$, especially since only its direction matters for the final prediction $f_W$. This is precisely what layer norm does. It introduces the logit score
$$g^{\mathrm{LN}}_y(x) = \Big\langle u_y, \frac{We_x}{\|We_x\|}\Big\rangle, \qquad \text{instead of} \qquad g_y(x) = u_y^\top W e_x.$$
This adds a projection to the gradients in (17), as detailed in Appendix A.12. Denoting $\bar W = W/\|We_x\|$,
$$\nabla_W \ell^{\mathrm{LN}}(x, y; W) = \nabla_W\big[\ell(x, y; \bar W)\big] = \frac{1}{\|We_x\|}\Big(I - (\bar W e_x)(\bar W e_x)^\top\Big)\nabla_{\bar W}\ell(x, y; \bar W). \tag{21}$$
We recognize a projection that removes the part of the signal already aligned with $\bar W e_x$. We conjecture that this introduces a clipping effect on the corresponding $q(x)$, which optimizes memory storage and explains the good performance observed on the right of Figure 7.

4.2 THE BENEFITS OF LEARNING THE EMBEDDINGS

Taking a step back, Theorem 1 implies that our model with $d^2$ parameters, the matrix $W \in \mathbb{R}^{d\times d}$ (4), only memorizes about $d/\log(d)$ associations $(e_x, u_y) \in (\mathbb{R}^d)^2$ of size $2d$. Intriguingly, Lemma 1 below states that an exponential number of quasi-orthogonal elements can be put in $\mathbb{R}^d$. This event actually holds with high probability when embeddings are random, which shows intrinsic limitations of our linear model (2).

Definition 1 (Quasi-orthogonality). The family $(u_z)_{z\in[P]}$ with $u_z \in \mathbb{R}^d$ is $\eta$-quasi orthogonal if
$$\forall\, z \neq z' \in [P], \quad |\langle u_z, u_{z'}\rangle| \leq \eta, \qquad \text{and} \qquad \|u_z\| = 1. \tag{22}$$

Lemma 1. For any $d \in \mathbb{N}$ and $P \geq 3$, there exists an embedding $u: [P] \to \mathbb{R}^d$ such that the family $(u_z)_{z\in[P]}$ is $\eta$-quasi orthogonal with $\eta = 2\sqrt{d^{-1}\log(P)}$.
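Definition 1 and Lemma 1 can be checked numerically on random unit vectors. The sketch below uses helper names of our own. It measures the coherence $\max_{z \neq z'} |\langle u_z, u_{z'}\rangle|$, which is the smallest $\eta$ for which a unit-norm family is $\eta$-quasi orthogonal in the sense of (22). It then compares that coherence with the value $\eta = 2\sqrt{\log(P)/d}$ from Lemma 1.

Random sampling is only one way to obtain such families. It matches the remark above that the event holds with high probability for random embeddings. A single random draw is not guaranteed to meet the lemma's $\eta$ exactly.

```python
import math
import random


def random_unit_vectors(P, d, seed=0):
    """P vectors drawn uniformly on the unit sphere of R^d (normalized Gaussians)."""
    rng = random.Random(seed)
    family = []
    for _ in range(P):
        g = [rng.gauss(0.0, 1.0) for _ in range(d)]
        norm = math.sqrt(sum(c * c for c in g))
        family.append([c / norm for c in g])
    return family


def coherence(family):
    """max_{z != z'} |<u_z, u_z'>|: the smallest eta satisfying (22) for a unit-norm family."""
    worst = 0.0
    for i in range(len(family)):
        for j in range(i + 1, len(family)):
            inner = sum(a * b for a, b in zip(family[i], family[j]))
            worst = max(worst, abs(inner))
    return worst


def lemma1_eta(d, P):
    """eta = 2 * sqrt(log(P) / d), the quasi-orthogonality level of Lemma 1."""
    return 2.0 * math.sqrt(math.log(P) / d)


d, P = 512, 64
family = random_unit_vectors(P, d)
print(f"coherence={coherence(family):.3f}, Lemma 1 eta={lemma1_eta(d, P):.3f}")
```

Inverting the formula of Lemma 1 gives $P = \exp(\eta^2 d/4)$. For a fixed tolerance $\eta$, the number of nearly orthogonal directions therefore grows exponentially with $d$.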
As a consequence of Lemma 1, the following model
$$f_1(x) = \arg\max_y u_y^\top \sum_{x'\in[P]} u_{f_*(x')}\, \sigma\big(e_{x'}^\top e_x - \eta\big), \tag{23}$$
where $\sigma(x) = x_+$ is the ReLU function, can fit $P = \exp(\eta^2 d/4)$ elements in memory. When $p(x)$ follows an $\alpha$-Zipf law, this leads to a scaling in $\mathcal{E}(f_1) \asymp \exp(-(\alpha-1)\eta^2 d/4)$.³ Similarly, one could consider higher moments of $e_{x'}^\top e_x$, which have been the basis for modern Hopfield networks (Krotov & Hopfield, 2016; Ramsauer et al., 2021). However, implementing the model (23) requires keeping track of each of the $P$ vectors $e_x \in \mathbb{R}^d$. This means $Pd$ parameters to store only $P$ associations of size $d$, and inference compute that scales with $Pd$ rather than just $d^2$.

Figure 8: Experiments with learned embeddings when $\alpha = 2$, $N = 100$, $M = 5$, $y = f_*(x) = x \bmod M$ and $d = 2$. Left: level lines of the function $\mathbb{R}^2 \to [5];\ u \mapsto \arg\max_{y\in[5]} u_y^\top u$, with $u_y$ the learned unembeddings. Middle: scatter plot of the learned input embeddings $e_x \in \mathbb{R}^2$ for $x \in [N]$, colored according to $f_*(x)$. It shows how the input embeddings match the output ones, similarly to (24) and Proposition 5. Right: learned input embeddings obtained with $M = 10$, which again allow a zero generalization error. Reaching a zero error with $d = 2$ contrasts sharply with the condition $d \gg N$ needed to reach a zero generalization error when the embeddings are random.

We also note that when embeddings are learned, it is actually possible to store as many memories as desired, as the following construction shows:
$$W = I, \quad \forall\, y \in [M],\ u_y \in \mathbb{S}^d, \quad e_x = u_{f_*(x)} \quad\Rightarrow\quad f_*(x) = \arg\max_y u_y^\top W e_x. \tag{24}$$
In particular, Figure 8 illustrates the solution found by optimization-based algorithms when $d = 2$. That solution reaches a zero generalization error on the task of Figure 3, where $M = 5$.
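Both constructions can be instantiated in a few lines. The sketch below is ours, with hypothetical helper names.

- ReLU readout (23). Input embeddings are random unit vectors. As a demo assumption, $\eta$ is set to the measured coherence of those embeddings. Every pair is then $\eta$-quasi orthogonal, and only the $x' = x$ term survives the ReLU.
- Learned-embedding solution (24) in $d = 2$. The unembeddings are spread evenly on the unit circle. This is one possible choice of distinct unit vectors $u_y$, not the embeddings learned in Figure 8.

```python
import math
import random


def dot(a, b):
    return sum(s * t for s, t in zip(a, b))


def unit_vectors(count, d, rng):
    """count normalized Gaussian vectors in R^d."""
    family = []
    for _ in range(count):
        g = [rng.gauss(0.0, 1.0) for _ in range(d)]
        norm = math.sqrt(dot(g, g))
        family.append([c / norm for c in g])
    return family


def coherence(family):
    """max |<v_i, v_j>| over i != j."""
    return max(abs(dot(family[i], family[j]))
               for i in range(len(family)) for j in range(i + 1, len(family)))


def relu_memory_predict(x, e, u, labels, eta):
    """Model (23): argmax_y u_y^T sum_x' u_{f_*(x')} relu(e_x'^T e_x - eta)."""
    readout = [0.0] * len(u[0])
    for xp, e_xp in enumerate(e):
        weight = max(dot(e_xp, e[x]) - eta, 0.0)  # ReLU
        if weight > 0.0:
            readout = [r + weight * c for r, c in zip(readout, u[labels[xp]])]
    scores = [dot(u_y, readout) for u_y in u]
    return max(range(len(u)), key=lambda y: scores[y])


def relu_capacity(d, eta):
    """Number of storable elements P = exp(eta**2 * d / 4) for model (23)."""
    return math.exp(eta ** 2 * d / 4.0)


def learned_embedding_predict(x, M):
    """Construction (24) in d = 2: W = I, u_y on the unit circle, e_x = u_{f_*(x)} with f_*(x) = x mod M."""
    u = [(math.cos(2 * math.pi * y / M), math.sin(2 * math.pi * y / M)) for y in range(M)]
    e_x = u[x % M]  # W = I, so W e_x = e_x
    return max(range(M), key=lambda y: dot(u[y], e_x))


rng = random.Random(0)
P, d, M = 50, 64, 5
e = unit_vectors(P, d, rng)
u = unit_vectors(M, d, rng)
labels = [x % M for x in range(P)]
eta = coherence(e)  # demo choice: the smallest eta for which e is eta-quasi orthogonal
print(eta, all(relu_memory_predict(x, e, u, labels, eta) == labels[x] for x in range(P)))
print(all(learned_embedding_predict(x, M) == x % M for x in range(100)))
```

The second check holds for any number of inputs, since every $e_x$ equals one of only $M$ distinct unit vectors. This mirrors the observation that learned embeddings remove the capacity limit of random ones.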
Optimizing token embeddings is probably an important element for increasing memorization capacity in transformers. However, enforcing $e_x = u_{f_*(x)}$ is unrealistic when embeddings are shared across heads and the input/output relationships to be learned differ from one head to another.

5 CONCLUSION

This work considers a simple model to study memorization in transformers. Here, memorization is seen as a valuable behavior, in which the network memorizes useful patterns and association rules. We derive precise scaling laws with respect to both the number of data and the model size, which plays the role of a model capacity. We quantify the effect of different memorization schemes and illustrate the benefits of uniformly weighted outer products. We then use these theoretical results to study how optimization algorithms commonly used for transformers may lead to more efficient memorization. In particular, we show the efficacy of small batches and large learning rates. Under the design constraints imposed by efficient hardware utilization and training stability, we also show the usefulness of Adam and layer normalization.

While our study focuses on simple memorization schemes, it opens up many new directions. One is extending our study to richer models that are closer to transformers, in which embeddings, attention and feed-forward layers are all trained. This could lead to models of scaling laws that capture interactions between tokens, as well as hierarchical behaviors that require multiple layers. We would also like to use our framework to assess memorization and generalization through clear metrics, and eventually to adapt the learning rates automatically to the free memory capacity left in a layer.

Acknowledgements. The authors would like to thank Léon Bottou as well as Hervé Jégou for many fruitful discussions on memorization mechanisms in transformer language models.

³This result follows directly from two facts.
When input embeddings are chosen at random, the probability that they are not $\eta$-quasi orthogonal is bounded by $P^2 \exp(-d\eta^2/2)$. When input embeddings are $\eta$-quasi orthogonal, $f_1(x) = f_*(x)$ for any $x \in [P]$.

REFERENCES

Shun-Ichi Amari. Learning patterns and pattern sequences by self-organizing nets of threshold elements. IEEE Transactions on Computers, 1972.

Yasaman Bahri, Ethan Dyer, Jared Kaplan, Jaehoon Lee, and Utkarsh Sharma. Explaining neural scaling laws. arXiv preprint arXiv:2102.06701, 2021.

Lukas Balles and Philipp Hennig. Dissecting Adam: The sign, magnitude and variance of stochastic gradients. In ICML, 2018.

Alberto Bietti, Vivien Cabannes, Diane Bouchacourt, Herve Jegou, and Leon Bottou. Birth of a transformer: A memory viewpoint. In NeurIPS, 2023.

Thomas Cover and Joy Thomas. Elements of Information Theory. Wiley, 1991.

Lukasz Debowski. A simplistic model of neural scaling laws: Multiperiodic Santa Fe processes. arXiv preprint arXiv:2302.09049, 2023.

Ian Dinwoodie. Mesures dominantes et théorème de Sanov. Annales de l'Institut Henri Poincaré, 1992.

Nelson Elhage, Neel Nanda, Catherine Olsson, Tom Henighan, Nicholas Joseph, Ben Mann, Amanda Askell, Yuntao Bai, Anna Chen, Tom Conerly, Nova DasSarma, Dawn Drain, Deep Ganguli, Zac Hatfield-Dodds, Danny Hernandez, Andy Jones, Jackson Kernion, Liane Lovitt, Kamal Ndousse, Dario Amodei, Tom Brown, Jack Clark, Jared Kaplan, Sam McCandlish, and Chris Olah. A mathematical framework for transformer circuits. Technical report, Anthropic, 2021.

Nelson Elhage, Tristan Hume, Catherine Olsson, Nicholas Schiefer, Tom Henighan, Shauna Kravec, Zac Hatfield-Dodds, Robert Lasenby, Dawn Drain, Carol Chen, Roger Grosse, Sam McCandlish, Jared Kaplan, Dario Amodei, Martin Wattenberg, and Christopher Olah. Toy models of superposition. Technical report, Anthropic, 2022.

Vitaly Feldman. Does learning require memorization? A short tale about a long tail. In STOC, 2020.
Vitaly Feldman and Chiyuan Zhang. What neural networks memorize and why: Discovering the long tail via influence estimation. In NeurIPS, 2020.

Mor Geva, Roei Schuster, Jonathan Berant, and Omer Levy. Transformer feed-forward layers are key-value memories. In EMNLP, 2021.

Justin Gilmer, Behrooz Ghorbani, Ankush Garg, Sneha Kudugunta, Behnam Neyshabur, David Cardoze, George Edward Dahl, Zachary Nado, and Orhan Firat. A loss curvature perspective on training instabilities of deep learning models. In ICLR, 2021.

Wassily Hoeffding. Probability inequalities for sums of bounded random variables. Journal of the American Statistical Association, 1963.

Jordan Hoffmann, Sebastian Borgeaud, Arthur Mensch, Elena Buchatskaya, Trevor Cai, Eliza Rutherford, Diego de Las Casas, Lisa Anne Hendricks, Johannes Welbl, Aidan Clark, et al. Training compute-optimal large language models. In NeurIPS, 2022.

John Hopfield. Neural networks and physical systems with emergent collective computational abilities. Proceedings of the National Academy of Sciences of the United States of America, 1982.

Marcus Hutter. Learning curve theory. arXiv preprint arXiv:2102.04074, 2021.

Jared Kaplan, Sam McCandlish, Tom Henighan, Tom B Brown, Benjamin Chess, Rewon Child, Scott Gray, Alec Radford, Jeffrey Wu, and Dario Amodei. Scaling laws for neural language models. arXiv preprint arXiv:2001.08361, 2020.

Teuvo Kohonen. Correlation matrix memories. IEEE Transactions on Computers, 1972.

Dmitry Krotov and John Hopfield. Dense associative memory for pattern recognition. In NeurIPS, 2016.

Frederik Kunstner, Jacques Chen, Jonathan Wilder Lavington, and Mark Schmidt. Noise is not the main factor behind the gap between SGD and Adam on transformers, but sign descent might be. In ICLR, 2023.

William Little. The existence of persistent states in the brain. Mathematical Biosciences, 1974.
Christopher Longuet-Higgins, David Willshaw, and Peter Buneman. Theories of associative recall. Quarterly Reviews of Biophysics, 1970.

Alexander Maloney, Daniel Roberts, and James Sully. A solvable model of neural scaling laws. arXiv preprint arXiv:2210.16859, 2022.

Kevin Meng, David Bau, Alex Andonian, and Yonatan Belinkov. Locating and editing factual associations in GPT. In NeurIPS, 2022.

Eric Michaud, Ziming Liu, Uzay Girit, and Max Tegmark. The quantization model of neural scaling. arXiv preprint arXiv:2303.13506, 2023.

Adam Paszke, Sam Gross, Francisco Massa, Adam Lerer, James Bradbury, Gregory Chanan, Trevor Killeen, Zeming Lin, Natalia Gimelshein, Luca Antiga, Alban Desmaison, Andreas Köpf, Edward Yang, Zach DeVito, Martin Raison, Alykhan Tejani, Sasank Chilamkurthy, Benoit Steiner, Lu Fang, Junjie Bai, and Soumith Chintala. PyTorch: An imperative style, high-performance deep learning library. In NeurIPS, 2019.

Steven Piantadosi. Zipf's word frequency law in natural language: A critical review and future directions. Psychonomic Bulletin and Review, 2014.

Hubert Ramsauer, Bernhard Schäfl, Johannes Lehner, Philipp Seidl, Michael Widrich, Thomas Adler, Lukas Gruber, Markus Holzleitner, Milena Pavlović, Geir Kjetil Sandve, Victor Greiff, David Kreil, Michael Kopp, Günter Klambauer, Johannes Brandstetter, and Sepp Hochreiter. Hopfield networks is all you need. In ICLR, 2021.

Imanol Schlag, Kazuki Irie, and Jürgen Schmidhuber. Linear transformers are secretly fast weight programmers. In ICML, 2021.

Samuel Smith, Pieter-Jan Kindermans, Chris Ying, and Quoc V. Le. Don't decay the learning rate, increase the batch size. In ICLR, 2018.

Paul Smolensky. Tensor product variable binding and the representation of symbolic structures in connectionist systems. Artificial Intelligence, 1990.

Ben Sorscher, Robert Geirhos, Shashank Shekhar, Surya Ganguli, and Ari Morcos. Beyond neural scaling laws: beating power law scaling via data pruning. In NeurIPS, 2022.
Karl Steinbuch. Die Lernmatrix. Kybernetik, 1961.

Sainbayar Sukhbaatar, Edouard Grave, Guillaume Lample, Herve Jegou, and Armand Joulin. Augmenting self-attention with persistent memory. arXiv preprint arXiv:1907.01470, 2019.

Juan Valle-Lisboa, Andrés Pomi, and Eduardo Mizraji. Multiplicative processing in the modeling of cognitive activities in large neural networks. Biophysical Reviews, 2023.

David Willshaw, Peter Buneman, and Christopher Longuet-Higgins. Non-holographic associative memory. Nature, 1969.

Yuhuai Wu, Markus Rabe, DeLesley Hutchins, and Christian Szegedy. Memorizing transformers. In ICLR, 2022.

Greg Yang and Etai Littwin. Tensor programs IVb: Adaptive optimization in the infinite-width limit. arXiv preprint arXiv:2308.01814, 2023.

Greg Yang, Edward Hu, Igor Babuschkin, Szymon Sidor, Xiaodong Liu, David Farhi, Nick Ryder, Jakub Pachocki, Weizhu Chen, and Jianfeng Gao. Tensor programs V: Tuning large neural networks via zero-shot hyperparameter transfer. In NeurIPS, 2021.

Jingzhao Zhang, Sai Praneeth Karimireddy, Andreas Veit, Seungyeon Kim, Sashank Reddi, Sanjiv Kumar, and Suvrit Sra. Why are adaptive methods good for attention models? In NeurIPS, 2020.

A.1 FINITE DATA - PROOF OF PROPOSITION 1

Let us consider the infinite memory model, where an LLM can store in memory all previously seen associations $(x, y)$. At each time $t$, a random positive integer $x$ is drawn from some fixed probability distribution. At time $T$, the LLM would have seen $x_1, \dots, x_T$ and the associated $f_*(x_t)$, where each $x_t$ is a random positive integer drawn independently from $p$. As such, the LLM would have learned a map $\hat f$ that only errs on the inputs $x$ that differ from all the $x_t$ for $t \in [T]$. The generalization error reads, with respect to the random dataset $D_T = (X_t, Y_t)_{t\in[T]}$,

$$\mathbb{E}_{D_T}[\mathcal{E}(\hat f)] = \mathbb{P}_{X, D_T}\big(X \notin \{X_t\}_{t\in[T]}\big) = \sum_{x\in[N]} p(x)\, \mathbb{P}_{D_T}\big(x \notin \{X_t\}_{t\in[T]}\big) = \sum_{x\in[N]} p(x)(1-p(x))^T.$$
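The identity above can be evaluated directly. As an illustration (not taken from the paper's experiments), the sketch below computes $\sum_x p(x)(1-p(x))^T$ for a truncated Zipf law $p(x) \propto x^{-\alpha}$ with a large $N$, and estimates the log-log slope in $T$, which the derivation that follows predicts to be close to $-1 + 1/\alpha$. The helper names `zipf` and `expected_error` and the values of $\alpha$, $N$ and $T$ are assumptions made for the example.

```python
import numpy as np

def zipf(N, alpha):
    # Truncated Zipf law p(x) proportional to x^{-alpha} on {1, ..., N}.
    x = np.arange(1, N + 1, dtype=float)
    p = x ** (-alpha)
    return p / p.sum()

def expected_error(p, T):
    # E_{D_T}[E(f_hat)] = sum_x p(x) (1 - p(x))^T for the infinite-memory model:
    # an input is misclassified iff it never appeared among the T samples.
    return float(np.sum(p * np.exp(T * np.log1p(-p))))

alpha, N = 2.0, 10**6
p = zipf(N, alpha)
Ts = np.array([10**3, 10**4, 10**5])
errs = np.array([expected_error(p, T) for T in Ts])
# Empirical log-log slopes between consecutive T values; theory predicts -1 + 1/alpha.
slopes = np.diff(np.log(errs)) / np.diff(np.log(Ts))
print(errs, slopes)
```

With $\alpha = 2$ the slopes should be close to $-1/2$; the finite $N$ only matters once $T$ becomes comparable to $N^{\alpha}$.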
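Using that $(1-a)^T = \exp(T\log(1-a))$ and that $-2\log(2)\,a \leq \log(1-a) \leq -a$ for any $a \in [0, 1/2]$, together with the fact that $p(x) \leq 1/2$ for any $x \geq 2$ when $p$ is decreasing, we get

$$\sum_{x=2}^{N} p(x)\, e^{-2\log(2)p(x)T} \leq \sum_{x\in[N]} \mathbf{1}_{p(x)\leq 1/2}\, p(x)\, e^{-2\log(2)p(x)T} \leq \mathbb{E}_{D_T}[\mathcal{E}(\hat f)] \leq \sum_{x\in[N]} p(x)\, e^{-p(x)T}.$$

Relating these series to the corresponding integrals, we have

$$\begin{aligned}\int_{1}^{N} p(x)e^{-2\log(2)p(x)T}\,\mathrm{d}x - \frac{1}{T} &\leq \int_{2}^{p^{-1}(1/T)} p(x-1)e^{-2\log(2)p(x-1)T}\,\mathrm{d}x + \int_{p^{-1}(1/T)}^{N} p(x)e^{-2\log(2)p(x)T}\,\mathrm{d}x \\ &\leq \sum_{x=2}^{N} p(x)e^{-2\log(2)p(x)T} \leq \mathbb{E}_{D_T}[\mathcal{E}(\hat f)] \leq \sum_{x\in[N]} p(x)e^{-p(x)T} \leq \int_{1}^{N} p(x)e^{-p(x)T}\,\mathrm{d}x + \frac{1}{T}.\end{aligned}$$

Letting $N$ go to infinity, we get the scaling

$$\mathbb{E}_{D_T}[\mathcal{E}(\hat f)] \asymp \int_{1}^{\infty} p(x)e^{-Tp(x)}\,\mathrm{d}x + O(1/T). \qquad (25)$$

Assuming that $p(x) = Cf(x)$ for some constant $C$ and a smooth, strictly decreasing function $f: \mathbb{R}_+ \to \mathbb{R}_+$ such that $\lim_{x\to 0} f(x) = +\infty$, one may consider the change of variable $u = f(x)$, i.e., $x = f^{-1}(u)$. If so, $\mathrm{d}x = (f^{-1})'(u)\,\mathrm{d}u = \mathrm{d}u / f'(f^{-1}(u))$. Hence it holds, up to constants, that

$$\mathbb{E}_{D_T}[\mathcal{E}(\hat f)] \asymp \int \frac{u}{|f'(f^{-1}(u))|}\, e^{-uT}\,\mathrm{d}u. \qquad (26)$$

This relates to the Laplace transform of the function $u \mapsto u/|f'(f^{-1}(u))|$. In particular, when $p(x) = C_\alpha x^{-\alpha}$, one has $f^{-1}(u) = u^{-1/\alpha}$, from which one can deduce that

$$\int_{1}^{\infty} x^{-\alpha}\exp(-Tx^{-\alpha})\,\mathrm{d}x \asymp \frac{1}{\alpha}\,\Gamma\Big(\frac{\alpha-1}{\alpha}\Big)\, T^{-1+1/\alpha},$$

which recovers a result of Hutter (2021).

A.2 MEMORY CAPACITY - PROOF OF LEMMA 1

The proof of Lemma 1 concerning quasi-orthogonal embeddings can be done through a reasoning on random embeddings. Let $(X_i)_{i\in[P]}$ be $P$ independent identically distributed random variables. We are interested in the event where the normalized $(X_i)$ are $\eta$-quasi orthogonal:

$$\mathbb{P}\Big(\bigcap_{\{i,j\}\subset[P]} \{|\langle X_i, X_j\rangle| \leq \eta\|X_i\|\|X_j\|\}\Big) = 1 - \mathbb{P}\Big(\bigcup_{\{i,j\}\subset[P]} \{|\langle X_i, X_j\rangle| > \eta\|X_i\|\|X_j\|\}\Big) \geq 1 - \binom{P}{2}\,\mathbb{P}\big(|\langle X_1, X_2\rangle| > \eta\|X_1\|\|X_2\|\big).$$

If this lower bound is positive, the event has positive probability, which means that there exist $\eta$-quasi orthogonal samples. As a consequence, we are looking for $\eta$ such that

$$\mathbb{P}\big(|\langle X_1, X_2\rangle| \geq \eta\|X_1\|\|X_2\|\big) < \frac{2}{P(P-1)}. \qquad (27)$$

Let us consider $(X_i)$ to be distributed according to a rotation-invariant probability.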
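By symmetry, we have, with $f_1$ denoting the first vector of the canonical basis in $\mathbb{R}^d$,

$$\mathbb{P}\big(|\langle X_1, X_2\rangle| \geq \eta\|X_1\|\|X_2\|\big) = \mathbb{P}\big(|\langle X, f_1\rangle| \geq \eta\|X\|\big) = \mathbb{P}\Big(\Big|\Big\langle \frac{X}{\|X\|}, f_1\Big\rangle\Big| \geq \eta\Big). \qquad (28)$$

By symmetry, the vector $X/\|X\|$ is uniform on the sphere; below, $X$ denotes such a uniform vector. Using that $\mathbb{P}(|\langle X, f_1\rangle| > \eta) = 2\,\mathbb{P}(\langle X, f_1\rangle > \eta)$ and

$$\mathbb{P}(|\langle X, f_1\rangle| \geq \eta) = \frac{2}{\mathrm{Vol}(\mathbb{S}^{d-1})}\int_{\mathbb{S}^{d-1}} \mathbf{1}_{x_1\geq\eta}\,\mathrm{d}x = \frac{2}{\mathrm{Vol}(\mathbb{S}^{d-1})}\int_{\eta}^{1} \mathrm{Vol}\Big(\sqrt{1-x_1^2}\,\mathbb{S}^{d-2}\Big)\,\mathrm{d}x_1 = \frac{2\,\mathrm{Vol}(\mathbb{S}^{d-2})}{\mathrm{Vol}(\mathbb{S}^{d-1})}\int_{\eta}^{1}(1-t^2)^{\frac{d-1}{2}}\,\mathrm{d}t = \frac{2\,\Gamma(\frac{d}{2}+1)}{\sqrt{\pi}\,\Gamma(\frac{d+1}{2})}\int_{\eta}^{1}(1-t^2)^{\frac{d-1}{2}}\,\mathrm{d}t.$$

To upper bound this probability, we proceed with, using $\Gamma(\frac{d}{2}+1)/\Gamma(\frac{d+1}{2}) \leq \sqrt{(d+2)/2}$ (Gautschi's inequality),

$$\begin{aligned}\mathbb{P}(|\langle X, f_1\rangle| \geq \eta) &= \frac{2\,\Gamma(\frac{d}{2}+1)}{\sqrt{\pi}\,\Gamma(\frac{d+1}{2})}\int_{\eta}^{1}(1-t^2)^{\frac{d-1}{2}}\,\mathrm{d}t \leq \frac{2\,\Gamma(\frac{d}{2}+1)}{\sqrt{\pi}\,\Gamma(\frac{d+1}{2})}\int_{t\geq\eta}\frac{t}{\eta}(1-t^2)^{\frac{d-1}{2}}\,\mathrm{d}t = \frac{2\,\Gamma(\frac{d}{2}+1)}{\sqrt{\pi}\,\Gamma(\frac{d+1}{2})}\cdot\frac{(1-\eta^2)^{\frac{d+1}{2}}}{\eta(d+1)} \\ &\leq \frac{2}{\sqrt{\pi}}\cdot\frac{\sqrt{(d+2)/2}}{\eta(d+1)}\,(1-\eta^2)^{\frac{d+1}{2}} \leq \sqrt{\frac{2}{\pi}}\,\frac{1}{\sqrt{\eta^2 d}}\,\exp\Big(-\frac{\eta^2 d}{2}\Big).\end{aligned}$$

The last inequality follows from the fact that

$$\frac{d+2}{(d+1)^2} = \frac{d+1+1}{d+1}\cdot\frac{1}{d+1} = \Big(1+\frac{1}{d+1}\Big)\frac{1}{d+1} \leq \Big(1+\frac{1}{d}\Big)\frac{1}{d+1} = \frac{1}{d},$$

and that for any $x \in (-1, 1)$, the concavity of the logarithm means that $\log(1+x) \leq x$, hence that $(1+x)^n = \exp(n\log(1+x)) \leq \exp(nx)$. This leads to the following series of implications:

$$\frac{P(P-1)}{2}\sqrt{\frac{2}{\pi}}\,\frac{e^{-\eta^2 d/2}}{\sqrt{\eta^2 d}} < 1 \ \Longleftarrow\ \Big(\eta^2 d \geq 1 \ \text{and}\ \frac{P^2}{\sqrt{2\pi}}\,e^{-\eta^2 d/2} < 1\Big) \ \Longleftarrow\ \Big(\eta^2 d \geq 1 \ \text{and}\ \eta^2 d > 4\log(P) - \log(2\pi)\Big) \ \Longleftarrow\ \eta^2 d \geq 4\log(P) \geq 1,$$

and, by (27) and the union bound above, the leftmost condition guarantees that $(X_i)$ is $\eta$-quasi orthogonal with positive probability. Finally, we have proven the existence of an $\eta$-quasi orthogonal family of $P$ vectors in $\mathbb{R}^d$ for

$$\eta \geq \sqrt{4\log(P)\,d^{-1}}, \quad\text{as long as } P \geq 3. \qquad (29)$$

A.3 GENERIC ERROR DECOMPOSITION

The error made by $f_W$ relates to the ordering between the signals $u_{f_*(x)}^\top W e_x$ and the noises $\max_{y\neq f_*(x)} u_y^\top W e_x$. Let $f_q$ be defined as in the main text. We have the following sequence of equivalences, assuming uniqueness of the argument of the maximum for simplicity:

$$\begin{aligned} f_q(x_0) \neq f_*(x_0) &\iff \arg\max_{y\in[M]} \sum_{x\in[N]} q(x)\, e_x^\top e_{x_0}\, u_{f_*(x)}^\top u_y \neq f_*(x_0) \\ &\iff \exists y \neq f_*(x_0),\ \sum_{x\in[N]} q(x)\, e_x^\top e_{x_0}\, u_{f_*(x)}^\top u_y > \sum_{x\in[N]} q(x)\, e_x^\top e_{x_0}\, u_{f_*(x)}^\top u_{f_*(x_0)} \\ &\iff \max_{y\neq f_*(x_0)} \sum_{x\in[N]} q(x)\, e_x^\top e_{x_0}\, u_{f_*(x)}^\top \big(u_y - u_{f_*(x_0)}\big) > 0. \end{aligned}$$

As a consequence,

$$\mathcal{E}(f_q) = \sum_{x_0\in[N]} p(x_0)\,\mathbf{1}_{f_q(x_0)\neq f_*(x_0)} = \sum_{x_0\in[N]} p(x_0)\,\mathbf{1}_{\max_{y\neq f_*(x_0)} \sum_{x\in[N]} q(x)\, e_x^\top e_{x_0}\, u_{f_*(x)}^\top(u_y - u_{f_*(x_0)}) > 0}.$$

In other terms, we have proven the following characterization, which holds for any $q$, even if derived from a finite number of data:

$$\mathcal{E}(f_q) = p\Big(\Big\{x\in[N] \,\Big|\, \max_{y\neq f_*(x)} \sum_{x'\in[N]} q(x')\, e_{x'}^\top e_x\, \big\langle u_{f_*(x')}, u_y - u_{f_*(x)}\big\rangle > 0\Big\}\Big). \qquad (30)$$

A.4 RANDOM EMBEDDINGS - PROOF OF THEOREM 1

Let us introduce randomness in the model.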
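The characterization (30) can be checked by simulation in the Gaussian setting studied in this section. The sketch below is illustrative only: it assumes a Zipf law on $[N]$, uniform weights $q(x) = 1$, $f_*(x) = x \bmod M$, $e_x \sim \mathcal{N}(0, I_d)$ and $u_y$ uniform on the sphere, and the function name `simulate_error` is hypothetical. For one draw of embeddings, it computes the predictions both by a direct argmax and through (30), checks that they agree, and reports the error for a few widths $d$.

```python
import numpy as np

def simulate_error(d, N=200, M=5, alpha=2.0, seed=0):
    """Error E(f_q) for one draw of Gaussian input embeddings (illustrative setup)."""
    rng = np.random.default_rng(seed)
    x = np.arange(1, N + 1)
    p = x ** (-alpha) / np.sum(x ** (-alpha))       # Zipf law on [N]
    f_star = x % M                                  # target association f_*(x)
    q = np.ones(N)                                  # associative weights q(x)
    E = rng.standard_normal((N, d))                 # rows are e_x ~ N(0, I_d)
    U = rng.standard_normal((M, d))
    U /= np.linalg.norm(U, axis=1, keepdims=True)   # rows are u_y on the sphere
    W = (U[f_star] * q[:, None]).T @ E              # W = sum_x q(x) u_{f_*(x)} e_x^T
    scores = U @ W @ E.T                            # scores[y, x0] = u_y^T W e_{x0}
    pred = scores.argmax(axis=0)
    # Characterization (30): error iff max_{y != f_*(x0)} (score_y - score_{f_*(x0)}) > 0.
    cols = np.arange(N)
    margin = scores - scores[f_star, cols]
    margin[f_star, cols] = -np.inf
    wrong_30 = margin.max(axis=0) > 0
    assert np.array_equal(wrong_30, pred != f_star)  # ties have probability zero
    return float(np.sum(p * wrong_30))

for d in (8, 32, 128, 512):
    print(d, simulate_error(d))
```

Since $q$ is uniform here, every input has the same chance of being missed, so the error decreases with $d$ as the signal term (32) starts to dominate the noise whose covariance is given in (33).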
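If each $e_x \sim \mathcal{N}(0, I)$ is actually an independent random Gaussian vector in $\mathbb{R}^d$, we continue our derivation with

$$\begin{aligned}\mathbb{E}_e[\mathcal{E}(f_q)] &= \sum_{x_0\in[N]} p(x_0)\,\mathbb{E}_{e_{x_0}}\Big[\mathbb{P}_{(e_x)_{x\neq x_0}}\big(f_q(x_0)\neq f_*(x_0) \,\big|\, e_{x_0}\big)\Big] \\ &= \sum_{x_0\in[N]} p(x_0)\,\mathbb{E}_{e_{x_0}}\Big[\mathbb{P}_{(e_x)_{x\neq x_0}}\Big(\max_{y\neq f_*(x_0)} \sum_{x\in[N]} q(x)\,e_x^\top e_{x_0}\,u_{f_*(x)}^\top(u_y - u_{f_*(x_0)}) > 0 \,\Big|\, e_{x_0}\Big)\Big] \\ &= \sum_{x_0\in[N]} p(x_0)\,\mathbb{E}_{e_{x_0}}\Big[\mathbb{P}_{(e_x)_{x\neq x_0}}\Big(\max_{y\neq f_*(x_0)} Z_y > 0 \,\Big|\, e_{x_0}\Big)\Big].\end{aligned}$$

Here, we have introduced the random variables $Z_y$ for $y \neq f_*(x_0)$, inheriting their randomness from $(e \mid e_{x_0})$, and defined by

$$Z_y = \sum_{x\in[N]} q(x)\, e_x^\top e_{x_0}\, u_{f_*(x)}^\top (u_y - u_{f_*(x_0)}). \qquad (31)$$

Conditionally on $e_{x_0}$, those are linear combinations of Gaussian variables, hence are Gaussian. Using the fact that $\mathbb{E}[e_x] = 0$, their mean is

$$\mu_y := \mathbb{E}[Z_y] = q(x_0)\,\|e_{x_0}\|^2\, u_{f_*(x_0)}^\top (u_y - u_{f_*(x_0)}). \qquad (32)$$

Those variables are correlated. Using the characterization of the mean, we deduce that their covariance reads

$$\begin{aligned}\Sigma_{y_1,y_2} &:= \mathbb{E}\big[(Z_{y_1} - \mathbb{E}[Z_{y_1}])(Z_{y_2} - \mathbb{E}[Z_{y_2}])\big] \\ &= \sum_{x,x'\neq x_0} q(x)q(x')\,\mathbb{E}\big[e_x^\top e_{x_0}\, e_{x'}^\top e_{x_0}\big]\, u_{f_*(x)}^\top(u_{y_1} - u_{f_*(x_0)})\, u_{f_*(x')}^\top(u_{y_2} - u_{f_*(x_0)}) \\ &= (u_{y_1} - u_{f_*(x_0)})^\top\Big(\sum_{x\neq x_0} q(x)^2\, e_{x_0}^\top \mathbb{E}[e_x e_x^\top]\, e_{x_0}\, u_{f_*(x)} u_{f_*(x)}^\top\Big)(u_{y_2} - u_{f_*(x_0)}) \\ &= (u_{y_1} - u_{f_*(x_0)})^\top\Big(\sum_{x\neq x_0} q(x)^2\, \|e_{x_0}\|^2\, u_{f_*(x)} u_{f_*(x)}^\top\Big)(u_{y_2} - u_{f_*(x_0)}).\end{aligned}$$

Finally, we obtain the following covariance:

$$\Sigma_{y,y'} = \|e_{x_0}\|^2\, (u_y - u_{f_*(x_0)})^\top\Big(\sum_{x\neq x_0} q(x)^2\, u_{f_*(x)} u_{f_*(x)}^\top\Big)(u_{y'} - u_{f_*(x_0)}). \qquad (33)$$

We are left with the computation of the probability that the maximum of the $M-1$ correlated, non-centered, exchangeable Gaussian variables $(Z_y)$ is bigger than zero.

Generic upper bound. Since we do not care about the scaling with respect to $M$, we proceed with

$$\max_{y\neq f_*(x_0)} \mathbb{P}(Z_y \geq 0) \leq \mathbb{P}\Big(\max_{y\neq f_*(x_0)} Z_y \geq 0\Big) \leq \sum_{y\neq f_*(x_0)} \mathbb{P}(Z_y \geq 0) \leq M \max_{y\neq f_*(x_0)} \mathbb{P}(Z_y \geq 0), \qquad (34)$$

which leads, using the Gaussian tail bound $\mathbb{P}(Z_y \geq 0) \leq \exp(-\mu_y^2/(2\Sigma_{y,y}))$ when $\mu_y < 0$, to

$$\begin{aligned}\mathbb{P}_{(e_x)_{x\neq x_0}}\Big(\max_{y\neq f_*(x_0)} \sum_{x\in[N]} q(x)\,e_x^\top e_{x_0}\,u_{f_*(x)}^\top(u_y - u_{f_*(x_0)}) > 0 \,\Big|\, e_{x_0}\Big) &\leq \sum_{y\neq f_*(x_0)} \exp\Big(-\mathbf{1}_{\mu_y<0}\,\frac{\mu_y^2}{2\Sigma_{y,y}}\Big) \\ &= \sum_{y\neq f_*(x_0)} \exp\Big(-\mathbf{1}_{\langle u_{f_*(x_0)}, u_y - u_{f_*(x_0)}\rangle<0}\,\frac{\|e_{x_0}\|^2}{2}\cdot\frac{q(x_0)^2\langle u_{f_*(x_0)}, u_y - u_{f_*(x_0)}\rangle^2}{\sum_{x\neq x_0} q(x)^2\langle u_{f_*(x)}, u_y - u_{f_*(x_0)}\rangle^2}\Big).\end{aligned}$$

Finally, recognizing a $\chi^2$-variable with $d$ degrees of freedom, for any $a > 0$,

$$\mathbb{E}\big[\exp(-a\|e_{x_0}\|^2)\big] = (1+2a)^{-d/2} = \exp\Big(-\frac{d}{2}\log(1+2a)\Big).$$

This leads to the final bound, with $\chi_{u,x} = \min_{y\neq f_*(x)} \mathbf{1}_{\langle u_{f_*(x)}, u_y - u_{f_*(x)}\rangle < 0}$: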
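$$\mathbb{E}_e[\mathcal{E}(f_q)] \leq \sum_{x\in[N]} p(x)\min\Big\{1, \sum_{y\neq f_*(x)}\Big(1 + \frac{q(x)^2\langle u_{f_*(x)}, u_y - u_{f_*(x)}\rangle^2}{\sum_{x'\neq x} q(x')^2\langle u_{f_*(x')}, u_y - u_{f_*(x)}\rangle^2}\Big)^{-\frac{d}{2}\chi_{u,x}}\Big\}. \qquad (35)$$

This holds for any unembedding $u$ and associative weight scheme $q$. In the following, we will assume that the unembeddings $u$ are such that $\chi_{u,x} = 1$, which is notably the case when the $u_y$ are normalized (i.e., $u_y \in \mathbb{S}^{d-1}$) and pairwise distinct.

Matching lower bound. Going back to (34), one can get a matching lower bound:

$$\begin{aligned}\mathbb{E}_e[\mathcal{E}(f_q)] &\geq \sum_{x\in[N]} p(x)\,\mathbb{E}_{e_x}\Big[\max_{y\neq f_*(x)} \mathbb{P}(Z_y \geq 0 \mid e_x)\Big] \geq \sum_{x\in[N]} p(x)\max_{y\neq f_*(x)} \mathbb{E}_{e_x}\big[\mathbb{P}(Z_y\geq 0\mid e_x)\big] \\ &\geq \sum_{x\in[N]} p(x)\max_{y\neq f_*(x)}\frac{1}{2}\Big(1 - \mathbb{E}_{e_x}\Big[\operatorname{erf}\Big(\frac{|\mu_y|}{\sqrt{2\Sigma_{y,y}}}\Big)\Big]\Big).\end{aligned}$$

To conclude, we need an anti-concentration inequality for Gaussian variables. In essence, we should distinguish two types of inputs $x\in[N]$: those where $\mu_y/\sqrt{\Sigma_{y,y}}$ is large enough to store the association $u_{f_*(x)}e_x^\top$, which leads to an error decreasing exponentially fast, and those where the same ratio is too small, which we should count in the lower bound.

Following this split, one can go for the simple survival lower bound

$$\begin{aligned}\mathbb{E}_e[\mathcal{E}(f_q)] &\geq \sup_{t>0}\frac{1-\operatorname{erf}(t)}{2}\sum_{x_0\in[N]} p(x_0)\max_{y\neq f_*(x_0)} \mathbb{E}_{e_{x_0}}\big[\mathbf{1}_{\mu_y^2 \leq 2\Sigma_{y,y}t^2}\big] \\ &= \sup_{t>0}\frac{1-\operatorname{erf}(t)}{2}\sum_{x_0\in[N]} p(x_0)\max_{y\neq f_*(x_0)} \mathbb{P}_{e_{x_0}}\Big(\|e_{x_0}\|^2 q(x_0)^2\langle u_{f_*(x_0)}, u_y - u_{f_*(x_0)}\rangle^2 \leq 2t^2\sum_{x\neq x_0} q(x)^2\langle u_{f_*(x)}, u_y - u_{f_*(x_0)}\rangle^2\Big) \\ &\geq \sup_{t,s>0}\frac{1-\operatorname{erf}(t)}{2}\sum_{x_0\in[N]} p(x_0)\,\mathbb{P}_{e_{x_0}}\big(\|e_{x_0}\|^2 \leq s\big)\max_{y\neq f_*(x_0)}\mathbf{1}_{s\,q(x_0)^2\langle u_{f_*(x_0)}, u_y - u_{f_*(x_0)}\rangle^2 \leq 2t^2\sum_{x\neq x_0} q(x)^2\langle u_{f_*(x)}, u_y - u_{f_*(x_0)}\rangle^2}.\end{aligned}$$

Without optimizing for constants, taking $t = 1/\sqrt{2}$ and $s = d$, we get the simple survival bound that there exists a constant $c$ such that

$$\mathbb{E}_e[\mathcal{E}(f_q)] \geq c\sum_{x\in[N]} p(x)\max_{y\neq f_*(x)}\mathbf{1}_{d\,q(x)^2\langle u_{f_*(x)}, u_y - u_{f_*(x)}\rangle^2 \leq \sum_{x'\neq x} q(x')^2\langle u_{f_*(x')}, u_y - u_{f_*(x)}\rangle^2}. \qquad (36)$$

The constant can be computed explicitly as

$$c = \frac{1-\operatorname{erf}(1/\sqrt{2})}{2}\,\mathbb{P}\big(\|e_{x_0}\|^2 \leq d\big) > 0.158 \cdot \frac{1}{2} = 0.079,$$

where we have used that $\|e_{x_0}\|^2$ is a $\chi^2$-variable with mean $d$, hence smaller median, which implies that $\mathbb{P}(\|e_{x_0}\|^2 < d) > 1/2$.

Quasi-orthogonal output embeddings. Let us consider $u: [M] \to \mathbb{R}^d$ such that $(u_y)_{y\in[M]}$ is $\eta$-quasi orthogonal.

Upper bound.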
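Going back to (35), we can work out a lower bound on the ratio appearing in the exponent:

$$\begin{aligned}\frac{q(x_0)^2\langle u_{f_*(x_0)}, u_y - u_{f_*(x_0)}\rangle^2}{\sum_{x\neq x_0} q(x)^2\langle u_{f_*(x)}, u_y - u_{f_*(x_0)}\rangle^2} &\geq \frac{q(x_0)^2(1-\eta)^2}{\sum_{x\neq x_0} q(x)^2\big(\mathbf{1}_{f_*(x)=y}(1+\eta)^2 + \mathbf{1}_{f_*(x)=f_*(x_0)}(1-\eta)^2 + \mathbf{1}_{f_*(x)\notin\{y,f_*(x_0)\}}\,4\eta^2\big)} \\ &\geq \frac{1}{4}\cdot\frac{q(x_0)^2(1-\eta)^2}{\sum_{x\neq x_0} q(x)^2\big(\mathbf{1}_{f_*(x)=y} + \mathbf{1}_{f_*(x)=f_*(x_0)} + \mathbf{1}_{f_*(x)\notin\{y,f_*(x_0)\}}\,\eta^2\big)} \\ &= \frac{1}{4}\cdot\frac{q(x_0)^2(1-\eta)^2}{\sum_{x} q(x)^2\big((1-\eta^2)\mathbf{1}_{f_*(x)\in\{y,f_*(x_0)\}} + \eta^2\big) - q(x_0)^2} \\ &= \frac{1}{4}\cdot\frac{q(x_0)^2(1-\eta)^2}{\eta^2\|q\|^2 + (1-\eta^2)\sum_{x:\,f_*(x)\in\{y,f_*(x_0)\}} q(x)^2 - q(x_0)^2} \\ &= \frac{1}{4}\cdot\frac{q(x_0)^2(1-\eta)^2}{\eta^2\|q\|^2 + (1-\eta^2)(Q_y + Q_{f_*(x_0)}) - q(x_0)^2}.\end{aligned}$$

Here, we have used that for the numerator $\langle u_{f_*(x_0)}, u_y - u_{f_*(x_0)}\rangle^2 = (\langle u_{f_*(x_0)}, u_y\rangle - 1)^2 \geq (1-\eta)^2$, and the same for the corresponding term in the denominator (since their ratio cancels out), as well as $\langle u_y, u_y - u_{f_*(x_0)}\rangle^2 \leq (1+\eta)^2$ and $\langle u_{f_*(x)}, u_y - u_{f_*(x_0)}\rangle^2 \leq (2\eta)^2$ when $f_*(x)\notin\{y, f_*(x_0)\}$. Moreover, we have introduced

$$Q_y = \sum_{x':\,f_*(x')=y} q(x')^2. \qquad (37)$$

Using the fact that $(1+x)^d = \exp(d\log(1+x)) \leq \exp(dx)$, an upper bound directly follows from those derivations:

$$\mathbb{E}_e[\mathcal{E}(f_q)] \leq \sum_{x_0\in[N]} p(x_0)\min\Big\{1, M\exp\Big(-\frac{d(1-\eta)^2}{4}\cdot\frac{q(x_0)^2}{\eta^2\|q\|_2^2 + 2Q_\infty}\Big)\Big\}, \qquad (38)$$

where

$$Q_\infty = \max_{y\in[M]} Q_y = \max_{y\in[M]}\sum_{x:\,f_*(x)=y} q(x)^2. \qquad (39)$$

Matching lower bound. Similarly, one can work out an upper bound on the same ratio:

$$\frac{q(x_0)^2\langle u_{f_*(x_0)}, u_y - u_{f_*(x_0)}\rangle^2}{\sum_{x\neq x_0} q(x)^2\langle u_{f_*(x)}, u_y - u_{f_*(x_0)}\rangle^2} \leq \frac{q(x_0)^2(1+\eta)^2}{\sum_{x\neq x_0} q(x)^2\big(\mathbf{1}_{f_*(x)=y}(1-\eta)^2 + \mathbf{1}_{f_*(x)=f_*(x_0)}(1+\eta)^2\big)} = \frac{q(x_0)^2}{\big(\frac{1-\eta}{1+\eta}\big)^2 Q_y + Q_{f_*(x_0)} - q(x_0)^2}.$$

Combining this with (36), we get the lower bound, with $c = 0.079$,

$$\mathbb{E}_e[\mathcal{E}(f_q)] \geq c\sum_{x\in[N]} p(x)\,\mathbf{1}_{(d+1)q(x)^2 \leq (\frac{1-\eta}{1+\eta})^2 Q_\infty}. \qquad (40)$$

Remark that in the previous lower bound, we have dropped the factor $\eta^2\|q\|^2$ that appears in the upper bound. We expect this term to actually be present in a tighter error characterization. In essence, we expect the embeddings to fill the full space $\mathbb{S}^{d-1}$, so that most of the differences $\langle u_{f_*(x)}, u_y - u_{f_*(x_0)}\rangle^2$ typically behave as $\eta^2$. However, quantifying this precisely is beyond the scope of this paper.

Random output embeddings. In the case where the output embeddings are random, we can distinguish two cases.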
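The first is the case where the embeddings are $\eta$-quasi orthogonal, where one can retake the previous derivations; the second is the case where they are not, which has small probability if $\eta$ is large enough. Consider $u$ to be random embeddings taken uniformly on the unit sphere, and introduce the event $E_\eta = \{u \text{ is } \eta\text{-quasi orthogonal}\}$. We have seen in the proof of Lemma 1 that

$$1 - \mathbb{P}(E_\eta) \leq \frac{M^2}{\sqrt{2\pi\eta^2 d}}\exp\Big(-\frac{\eta^2 d}{2}\Big). \qquad (41)$$

For any random variable $Z$ taking values in $[0, 1]$ and any event $E$, we have the bounds

$$\mathbb{P}(E)\,\mathbb{E}[Z\mid E] \leq \mathbb{E}[Z] = (1-\mathbb{P}(E))\,\mathbb{E}[Z\mid \neg E] + \mathbb{P}(E)\,\mathbb{E}[Z\mid E] \leq (1-\mathbb{P}(E)) + \mathbb{E}[Z\mid E]. \qquad (42)$$

The upper bound of Theorem 1 directly follows from plugging (38) and (41) into this last equation:

$$\mathbb{E}_{e,u}[\mathcal{E}(f_q)] \leq \frac{M^2}{\sqrt{2\pi\eta^2 d}}\exp\Big(-\frac{\eta^2 d}{2}\Big) + \sum_{x_0\in[N]} p(x_0)\min\Big\{1, M\exp\Big(-\frac{d(1-\eta)^2}{4}\cdot\frac{q(x_0)^2}{\eta^2\|q\|_2^2 + 2Q_\infty}\Big)\Big\}. \qquad (43)$$

Since this is true for any $\eta$, one can consider the infimum in the upper bound. In terms of lower bound, retaking (40),

$$\mathbb{E}_{e,u}[\mathcal{E}(f_q)] \geq \sup_{\eta\geq 0}\ c\Big(1 - \frac{M^2}{\sqrt{2\pi\eta^2 d}}\exp\Big(-\frac{\eta^2 d}{2}\Big)\Big)\sum_{x\in[N]} p(x)\,\mathbf{1}_{(d+1)q(x)^2\leq(\frac{1-\eta}{1+\eta})^2 Q_\infty}. \qquad (44)$$

In particular, when $d > 8\log(M)$ one can consider $\eta < 1/2$ such that $\eta^2 d > 4\log(M)$, which leads to $(1-\eta)/(1+\eta) > 1/3$, and, if $M \geq 4$,

$$1 - \frac{M^2}{\sqrt{2\pi\eta^2 d}}\exp\Big(-\frac{\eta^2 d}{2}\Big) \geq 1 - \frac{1}{2\sqrt{\pi}}\cdot\frac{1}{\sqrt{2\log(M)}} > \frac{2}{3}.$$

All together, we have proven that, as long as $M \geq 4$ and $d \geq 8\log(M)$, with $c_1 > 0.079 \cdot 2/3 > 0.052$ and $c_2 = ((1-\eta)/(1+\eta))^2 > 1/9$,

$$\mathbb{E}_{e,u}[\mathcal{E}(f_q)] \geq c_1\sum_{x\in[N]} p(x)\,\mathbf{1}_{(d+1)q(x)^2\leq c_2 Q_\infty}. \qquad (45)$$

Writing upper bounds as survival bounds. Until now, we have written the upper bounds as sums of exponentials (38) and the lower bounds as sums of missed associations (45), which we called survival bounds. In order to best read how tight our characterization is, one can rewrite the upper bounds as survival bounds. In particular, as we did in the lower bound, we will separate the $x$ corresponding to a small exponential from the other ones. Using the fact that the $p(x)$ sum to one, we get, when the output embeddings are $\eta$-quasi orthogonal,

$$\begin{aligned}\mathbb{E}_e[\mathcal{E}(f_q)] &\leq \sum_{x_0\in[N]} p(x_0)\min\Big\{1, M\exp\Big(-\frac{d(1-\eta)^2}{4}\cdot\frac{q(x_0)^2}{\eta^2\|q\|_2^2+2Q_\infty}\Big)\Big\} \\ &\leq \sum_{x_0\in[N]} p(x_0)\inf_{t>0}\Big(M\exp\Big(-\frac{t(1-\eta)^2}{4}\Big) + \mathbf{1}_{d\,q(x_0)^2 \leq t(2\eta^2\|q\|_2^2 + Q_\infty)}\Big) \\ &\leq \inf_{t>0}\ \exp\Big(-\frac{t(1-\eta)^2}{4} + \log(M)\Big) + \sum_{x\in[N]} p(x)\,\mathbf{1}_{d\,q(x)^2\leq t(2\eta^2\|q\|_2^2+Q_\infty)}.\end{aligned}$$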
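To simplify the bound, consider the constraints

$$\eta^2 \leq Q_\infty/\|q\|_2^2 \quad\text{and}\quad \eta < 1/2; \qquad (46)$$

we get, using $t = 16(\log(M) + \gamma\log(d))$ for $\gamma > 0$,

$$\begin{aligned}\mathbb{E}_e[\mathcal{E}(f_q)] &\leq \inf_{t>0}\ \exp\Big(-\frac{t(1-\eta)^2}{4}+\log(M)\Big) + \sum_{x\in[N]} p(x)\,\mathbf{1}_{d\,q(x)^2\leq t(2\eta^2\|q\|_2^2+Q_\infty)} \\ &\leq \inf_{t>0}\ \exp\Big(\frac{-t + 16\log(M)}{16}\Big) + \sum_{x\in[N]} p(x)\,\mathbf{1}_{d\,q(x)^2\leq 3tQ_\infty} \\ &\leq \exp(-\gamma\log(d)) + \sum_{x\in[N]} p(x)\,\mathbf{1}_{d\,q(x)^2\leq 48(\log(M)+\gamma\log(d))Q_\infty}.\end{aligned}$$

Finally, when the output embeddings are $\eta$-quasi orthogonal with $\eta$ satisfying (46), we get

$$\mathbb{E}_e[\mathcal{E}(f_q)] \leq \inf_{\gamma>0}\ d^{-\gamma} + \sum_{x\in[N]} p(x)\,\mathbf{1}_{d\,q(x)^2\leq 48(\log(M)+\gamma\log(d))Q_\infty}. \qquad (47)$$

When the unembeddings are chosen at random and $d > 8\log(M)$, one can choose $\eta < 1/2$, and (43) is cast as, choosing $d\eta^2 = 4\log(M) + 2\gamma\log(d)$,

$$\begin{aligned}\mathbb{E}_{e,u}[\mathcal{E}(f_q)] &\leq \inf_{\eta,\gamma}\ \frac{M^2}{\sqrt{2\pi\eta^2 d}}\exp\Big(-\frac{\eta^2 d}{2}\Big) + d^{-\gamma} + \sum_{x\in[N]} p(x)\,\mathbf{1}_{d\,q(x)^2\leq 16(\log(M)+\gamma\log(d))(2\eta^2\|q\|_2^2+Q_\infty)} \\ &\leq \inf_{\gamma}\ \frac{d^{-\gamma}}{2\sqrt{\pi(2\log(M)+\gamma\log(d))}} + d^{-\gamma} + \sum_{x\in[N]} p(x)\,\mathbf{1}_{d\,q(x)^2\leq 16(\log(M)+\gamma\log(d))(\frac{8\log(M)+4\gamma\log(d)}{d}\|q\|_2^2+Q_\infty)} \\ &\leq \inf_{\gamma}\ 2d^{-\gamma} + \sum_{x\in[N]} p(x)\,\mathbf{1}_{d\,q(x)^2\leq 16(\log(M)+\gamma\log(d))(\frac{8\log(M)+4\gamma\log(d)}{d}\|q\|_2^2+Q_\infty)}.\end{aligned}$$

Finally, we have shown that when the embeddings are taken at random,

$$\mathbb{E}_{e,u}[\mathcal{E}(f_q)] \leq \inf_{\gamma}\ 2d^{-\gamma} + \sum_{x\in[N]} p(x)\,\mathbf{1}_{d\,q(x)^2\leq 16(\log(M)+\gamma\log(d))(\frac{8\log(M)+4\gamma\log(d)}{d}\|q\|_2^2+Q_\infty)}. \qquad (48)$$

A.5 PROOF OF PROPOSITION 2

When $p(x) \propto x^{-\alpha}$ and $q(x) = p(x)^\rho \propto x^{-\rho\alpha}$, we have

$$p(\{x\in[N] \mid d\,q(x)^2 \leq \|q\|^2\}) \asymp p\big(\{x\in[N] \mid x \geq (d/\|q\|^2)^{1/(2\rho\alpha)}\}\big) \asymp \big(d/\|q\|^2\big)^{-(\alpha-1)/(2\rho\alpha)}.$$

We are left with the computation of

$$\varphi(N) := \|q\|^2 = \sum_{x\in[N]} q(x)^2 \asymp \int_1^N x^{-2\rho\alpha}\,\mathrm{d}x.$$

When $2\rho\alpha > 1$, this integral reads $\frac{1-N^{-2\alpha\rho+1}}{2\alpha\rho-1}$, which is bounded by $1/(2\alpha\rho - 1)$, independently of $N$.

A.6 PROOF OF PROPOSITION 3

When $p(x) \propto x^{-\alpha}$ and $q(x) = \mathbf{1}_{x\in[P]}\,p(x)^\rho \propto \mathbf{1}_{x\in[P]}\,x^{-\rho\alpha}$, we get

$$p(\{x\in[N] \mid d\,q(x)^2 \leq \|q\|^2\}) = p(\{x\in[P] \mid d\,q(x)^2\leq\|q\|^2\}) + p(\{x > P\}) \asymp \Big(\frac{d}{\varphi(P)}\Big)^{-(\alpha-1)/(2\rho\alpha)} + P^{-\alpha+1}.$$

The optimal threshold $P$ is set by equalizing the two terms. Using $\varphi(P) \leq P$ (since $q \leq 1$), we compute

$$\Big(\frac{d}{P}\Big)^{-\frac{\alpha-1}{2\rho\alpha}} = P^{-\alpha+1} \iff -\frac{\alpha-1}{2\rho\alpha}\log(d) + \frac{\alpha-1}{2\rho\alpha}\log(P) = -(\alpha-1)\log(P) \iff \log(d) - \log(P) = 2\rho\alpha\log(P) \iff P = d^{1/(2\rho\alpha+1)}.$$

This choice of $P$ leads to a scaling in, with $f_{\rho,[P]} = f_{q_{\rho,[P]}}$ and up to logarithmic factors,

$$\mathbb{E}_{e,u}[\mathcal{E}(f_{\rho,[P]})] \lesssim p(\{x\in[N] \mid d\,q(x)^2\leq\|q\|^2\}) \asymp P^{-(\alpha-1)} = d^{-(\alpha-1)/(2\rho\alpha+1)}.$$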
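The balancing step of Proposition 3 can be checked numerically: for illustrative values of $\alpha$, $\rho > 0$ and $d$ (our choice), the threshold $P = d^{1/(2\rho\alpha+1)}$ makes the two terms $(d/P)^{-(\alpha-1)/(2\rho\alpha)}$ and $P^{-\alpha+1}$ coincide, and both equal $d^{-(\alpha-1)/(2\rho\alpha+1)}$. The snippet only verifies this arithmetic, with $\varphi(P)$ replaced by its bound $P$; it is not a simulation of the memory model, and the function names are hypothetical.

```python
import math

def optimal_threshold(d, alpha, rho):
    # Threshold equalizing the two error terms (assumes rho > 0 and phi(P) replaced by P).
    return d ** (1.0 / (2 * rho * alpha + 1))

def error_terms(d, P, alpha, rho):
    # First term: stored inputs drowned by interference; second: Zipf mass beyond P.
    crowded = (d / P) ** (-(alpha - 1) / (2 * rho * alpha))
    tail = P ** (-alpha + 1)
    return crowded, tail

for alpha, rho, d in [(2.0, 0.5, 1e4), (1.5, 0.25, 1e6)]:
    P = optimal_threshold(d, alpha, rho)
    crowded, tail = error_terms(d, P, alpha, rho)
    predicted = d ** (-(alpha - 1) / (2 * rho * alpha + 1))
    print(f"alpha={alpha}, rho={rho}, d={d:.0e}: P={P:.1f}, "
          f"terms=({crowded:.3e}, {tail:.3e}), predicted={predicted:.3e}")
```

For instance, with $\alpha = 2$, $\rho = 1/2$ and $d = 10^4$, one gets $P = d^{1/3} \approx 21.5$ and both terms equal $d^{-1/3} \approx 0.046$.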
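A.7 PROOF OF THEOREM 2

The lower bound directly follows from (8), together with $Q_\infty \leq \|q\|^2$ and the fact that the model $f_q$ is invariant to a rescaling of $q$, so the best we can do is fit as many memories $P$ as we can until reaching $3(d+1) = P$, leading to a scaling in $\int_P^\infty p(x)\,\mathrm{d}x = C_\alpha P^{-\alpha+1}/(\alpha-1)$.

A.8 PROOF OF PROPOSITION 4

In order to get scaling with both finite data and finite memory simultaneously, we use a simple strategy: (i) with high probability $1 - cT^{-1+1/\alpha}$ for some constant $c$, $\hat q$ is similar to $q$; (ii) when $\hat q$ is similar to $q$, the scaling with $d$ derived from Theorem 1 is left unchanged by substituting $q$ by $\hat q$.

Rather than using a uniform concentration inequality on the full $\hat q$, we proceed individually on each $\hat q(x)$. Denoting by $D_T$ the random dataset of $T$ data, for any sequence of sets $(E_x)_{x\in[N]}$ (typically, we will choose $E_x = \{\hat q(x) > q(x)/2\}$),

$$\mathbb{E}_{u,e,D_T}[\mathcal{E}(f_{\hat q})] = \sum_x p(x)\,\mathbb{P}_{u,e,D_T}(f(x)\neq f_*(x)) \leq \sum_x p(x)\,\mathbb{P}_{u,e,D_T}(\hat q\notin E_x) + \sum_x p(x)\,\mathbb{P}_{u,e,D_T}(f(x)\neq f_*(x) \mid \hat q\in E_x).$$

The second term has been worked out before: using that $Q_\infty \leq \|q\|_2^2$ and writing $c_\gamma = \log(M) + \gamma\log(d)$,

$$\begin{aligned}\mathbb{P}_{u,e,D_T}(f(x)\neq f_*(x)\mid\hat q\in E_x) &\leq \inf_\gamma\ 2d^{-\gamma} + \mathbb{P}_{D_T}\Big(d\,\hat q(x)^2 \leq 16c_\gamma\Big(\hat Q_\infty + \frac{8c_\gamma\|\hat q\|_2^2}{d}\Big)\,\Big|\,\hat q\in E_x\Big) \\ &\leq \inf_\gamma\ 2d^{-\gamma} + \mathbb{P}_{D_T}\big(d\,\hat q(x)^2 \leq c'_\gamma\|\hat q\|_2^2 \,\big|\, \hat q\in E_x\big),\end{aligned}$$

where $c'_\gamma = 16c_\gamma(1 + 8c_\gamma/d)$.

Without thresholding. Let us first start with the scheme (11), with $\rho > 0$:

$$\hat q(x) = \Big(\frac{1}{T}\sum_{t\in[T]}\mathbf{1}_{x=X_t}\Big)^\rho, \qquad q(x) = p(x)^\rho.$$

Using a simplification of the Chernoff bound for Bernoulli variables (see, e.g., Hoeffding, 1963), we get the probability bound (the randomness being due to the data)

$$\mathbb{P}_{D_T}\Big(\hat q(x) < \frac{q(x)}{2^{1/\rho}}\Big) = \mathbb{P}_{D_T}\Big(\hat p(x) < \frac{p(x)}{2}\Big) \leq \exp(-Tp(x)/8).$$

As a consequence, reusing the proof of Proposition 1, when $p$ follows a Zipf law (1),

$$\begin{aligned}\mathbb{E}[\mathcal{E}(f_{\hat q})] &= \sum_x p(x)\,\mathbb{P}(f(x)\neq f_*(x)) \leq \sum_x p(x)\exp(-Tp(x)/8) + \sum_x p(x)\,\mathbb{P}\big(f(x)\neq f_*(x)\mid \hat q(x) > q(x)/2^{1/\rho}\big) \\ &\lesssim T^{-1+1/\alpha} + \sum_x p(x)\,\mathbb{P}\big(f(x)\neq f_*(x)\mid\hat q(x) > q(x)/2^{1/\rho}\big).\end{aligned}$$

We are left with the computation of the second term. Denoting $c_\rho = 2^{-1/\rho}$, we have

$$\mathbb{E}_{u,e}\,\mathbb{P}_{D_T}\big(f(x)\neq f_*(x)\mid \hat q(x) > c_\rho q(x)\big) \leq \inf_\gamma\ 2d^{-\gamma} + \mathbb{P}_{D_T}\big(d\,\hat q(x)^2\leq c'_\gamma\|\hat q\|_2^2 \,\big|\, \hat q(x) > c_\rho q(x)\big).$$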
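The Chernoff-type bound $\mathbb{P}(\hat p(x) < p(x)/2) \leq \exp(-Tp(x)/8)$ used above can be compared with the exact probability, since $T\hat p(x)$ follows a binomial law $\mathrm{Bin}(T, p(x))$. The sketch below computes the exact lower tail with the Python standard library (log-space terms via `math.lgamma`); the grid of $(T, p)$ values and the helper names are arbitrary choices for illustration.

```python
import math

def binom_lower_tail(T, p, k_max):
    # P(Bin(T, p) <= k_max), summed term by term in log-space for numerical stability.
    total = 0.0
    for k in range(0, k_max + 1):
        log_term = (math.lgamma(T + 1) - math.lgamma(k + 1) - math.lgamma(T - k + 1)
                    + k * math.log(p) + (T - k) * math.log1p(-p))
        total += math.exp(log_term)
    return total

def halving_probability(T, p):
    # P(p_hat < p / 2) with p_hat = Bin(T, p) / T, i.e. a count strictly below T p / 2.
    k_max = math.ceil(T * p / 2) - 1
    return binom_lower_tail(T, p, k_max) if k_max >= 0 else 0.0

for T, p in [(100, 0.1), (1000, 0.01), (1000, 0.05), (10000, 0.003)]:
    exact = halving_probability(T, p)
    bound = math.exp(-T * p / 8)
    print(f"T={T}, p={p}: exact={exact:.3e} <= bound={bound:.3e}")
```

The bound depends on $T$ and $p(x)$ only through $Tp(x)$, the expected number of occurrences of $x$, which is why the same $\exp(-Tp(x)/8)$ tail can be summed against $p(x)$ as in the proof of Proposition 1.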
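By definition of $\hat q$, together with Jensen's inequality when $\rho\leq 1/2$,

$$1 = \sum_{x\in[N]}\big(\hat q(x)^2\big)^{1/(2\rho)} \geq N\Big(\frac{1}{N}\|\hat q\|_2^2\Big)^{1/(2\rho)}, \quad\text{hence}\quad \|\hat q\|_2^2 \leq N^{1-2\rho}.$$

When $\rho > 1/2$, the worst value of $\hat q$ is when all the mass is concentrated on one $\hat q(x)$, in which case $\|\hat q\|_2^2 \leq 1$. With the corresponding $\psi(N)$, namely $\psi(N) = N^{1-2\rho}$ if $\rho\leq 1/2$ and $\psi(N) = 1$ otherwise, we get

$$\mathbb{E}_{u,e}\,\mathbb{P}_{D_T}\big(f(x)\neq f_*(x)\mid\hat q(x) > c_\rho q(x)\big) \leq \inf_\gamma\ 2d^{-\gamma} + \mathbf{1}_{d\,c_\rho^2 q(x)^2 \leq c'_\gamma\psi(N)}.$$

Finally, reusing the proof of Proposition 2, and hiding logarithmic factors,

$$\mathbb{E}[\mathcal{E}(f_{\hat q})] = \sum_x p(x)\,\mathbb{P}(f(x)\neq f_*(x)) \lesssim T^{-1+1/\alpha} + \inf_\gamma\ 2d^{-\gamma} + p\big(\{x \mid d\,c_\rho^2q(x)^2 \leq c'_\gamma\psi(N)\}\big) \lesssim T^{-1+1/\alpha} + \Big(\frac{d}{\psi(N)}\Big)^{-(\alpha-1)/(2\rho\alpha)}.$$

The case $\rho = 0$ can be easily treated by considering an error if and only if the number of seen elements $|\{x_t \mid t\in[T]\}|$ is smaller than $d$.

With thresholding. Let us now consider the thresholding scheme (12), with $P\leq N$ and $\rho\geq 0$:

$$\hat q(x) = \hat p(x)^\rho\,\mathbf{1}_{x\in\mathrm{top}_P((x_t)_{t\in[T]})}, \qquad q(x) = p(x)^\rho\,\mathbf{1}_{x\in[P]}.$$

We basically proceed with the same technique, but with $E_x$ the event that $x$ belongs to the top $P$ of the empirical frequencies. When dealing with a multinomial distribution, one can enumerate all possible outcomes for the empirical frequencies. For a template $a$ (a probability vector on $[N]$), we say that a sequence $(x_t)$ is of type $a$ if its empirical frequency is equal to $a$:

$$\mathcal{T}(a) = \Big\{(x_t)\in[N]^T \,\Big|\, \forall x\in[N],\ \sum_{t\in[T]}\mathbf{1}_{x_t=x} = Ta(x)\Big\}.$$

Some enumeration arguments that can be found in Cover & Thomas (1991, Chapter 11) lead to

$$\mathbb{P}_{D_T}\big((x_t)\in\mathcal{T}(a)\big) = |\mathcal{T}(a)|\exp\big(-T(H(a) + D_{\mathrm{KL}}(a\|p))\big) \leq \exp\big(-T\,D_{\mathrm{KL}}(a\|p)\big).$$

Hence, the probability that $x$ does not belong to the top $P$ of the empirical frequencies of $(x_t)$ is bounded by

$$\mathbb{P}_{D_T}\big(x\notin\mathrm{top}_P(x_t)\big) \leq \sum_{a\in A}\exp\big(-T\,D_{\mathrm{KL}}(a\|p)\big),$$

where $A$ is the set of all templates $a$ for which $x$ is not in the top $P$ of $(a(x'))_{x'\in[N]}$. With $T$ samples over $N$ elements, there are at most $(T+1)^N$ different type templates, hence

$$\sum_{a\in A}\exp(c_aT) \leq (T+1)^N\sup_{a\in A}\exp(c_aT) = \sup_{a\in A}\exp\big(c_aT + N\log(T+1)\big).$$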
As a consequence,

$$\mathbb{P}_{D_T}\big(x\notin\mathrm{top}_P(x_t)\big) \leq \sup_{a\in A}\exp\big(-T\,D_{\mathrm{KL}}(a\|p) + N\log(T+1)\big).$$

Now, it is actually possible to remove the $N\log(T+1)$ term in the exponential and to extend this type of result to generic Polish spaces (see, e.g., Dinwoodie, 1992):

$$\mathbb{P}_{D_T}\big(x\notin\mathrm{top}_P(x_t)\big) \lesssim \sup_{a\in A}\exp\big(-T\,D_{\mathrm{KL}}(a\|p)\big).$$

We are left with the computation of the information projection distance between $p$ and the set of distributions where $x$ does not belong to the top $P$. In order to get $x$ out of the top $P$ of $p$, one should switch $p(x)$ with $p(P)$, which leads to (without caring for exact constants), with $p'$ denoting the resulting distribution,

$$D_{\mathrm{KL}}(p'\|p) \asymp p(x)\log\big(p(x)/p(P)\big) + p(P)\log\big(p(P)/p(x)\big) = \big(p(x)-p(P)\big)\log\big(p(x)/p(P)\big).$$

When considering $x < P/2$ and $p$ following a Zipf law $p(x) = c_\alpha x^{-\alpha}$, we get

$$D_{\mathrm{KL}}(p'\|p) \geq \big(p(x) - p(2x)\big)\log\big(p(P/2)/p(P)\big) = c_\alpha x^{-\alpha}(1-2^{-\alpha})\,\alpha\log(2) = c'_\alpha\,p(x),$$

where $c'_\alpha = (1-2^{-\alpha})\,\alpha\log(2)$. As a consequence, for any $P\leq N$,

$$\begin{aligned}\mathbb{E}_{D_T}[\mathcal{E}(f_{\hat q})] &\leq c_0P^{-\alpha+1} + \sum_{x\in[P/2]} p(x)\,\mathbb{P}(f(x)\neq f_*(x)) \\ &\leq c_0P^{-\alpha+1} + \sum_{x\in[P/2]} p(x)\Big(\exp\big(-Tc'_\alpha p(x)\big) + \mathbb{P}\big(f(x)\neq f_*(x)\mid x\in\mathrm{top}_P((x_t))\big)\Big) \\ &\leq c_0P^{-\alpha+1} + \exp\big(-2^\alpha Tc'_\alpha P^{-\alpha}\big) + \sum_{x\in[P/2]} p(x)\,\mathbb{P}\big(f(x)\neq f_*(x)\mid x\in\mathrm{top}_P((x_t))\big).\end{aligned}$$

When $\rho = 0$, setting $P = \min(c_1 d, T^{1/\alpha}/\log(T))$, with $c_1$ chosen so that all $x$ stored in memory lead to $f_*(x) = f(x)$, gives the right scaling with both $T$ and $d$: up to logarithmic factors,

$$\mathbb{E}[\mathcal{E}(f_{\hat q})] \lesssim d^{-\alpha+1} + T^{-1+1/\alpha} + \exp\big(-c_3\log(T)^\alpha\big).$$

Because $\alpha > 1$, the last term decreases faster than any polynomial power of $T$, hence ends up being negligible in front of $T^{-1+1/\alpha}$. For the case $\rho\in(0,1]$, one can dissociate two events, namely the event where $x$ belongs to the top $P/2$ empirical frequencies and the event where $\hat p(x) > p(x)/2$, and conclude with derivations similar to the previous ones:

$$\mathbb{E}[\mathcal{E}(f_{\hat q})] \leq c_0P^{-\alpha+1} + \exp\big(-2^\alpha Tc'_\alpha P^{-\alpha}\big) + c_4T^{-1+1/\alpha} + \sum_{x\in[P/2]} p(x)\,\mathbb{P}\big(f(x)\neq f_*(x)\mid x\in\mathrm{top}_P((x_t)),\ \hat p(x) > p(x)/2\big).$$

Retaking previous arguments leads to the same scalings as those of Proposition 3 with respect to $d$, and to a scaling in $T^{-1+1/\alpha}$ with respect to $T$.
This ends the proof of the mixed scaling with both finite data and finite memory capacity.

A.9 LEARNING THE INPUT EMBEDDINGS

In instances where the embeddings are learned within the linear model (2), one may optimize them by merging all input token embeddings that are associated with the same output, which is what we actually observed in practice in Figure 8. Proposition 5 captures the resulting theoretical performance.

Proposition 5 (Improvement for learned input embeddings). Let the input embeddings be set to $e_x = u_{f_*(x)}$. Assume without loss of generality that $p(y)$ is decreasing with $y$. Consider unembeddings where $(u_y)_{y\in[P]}$ are $\eta$-quasi orthogonal, and $u_y = 0$ if $y$ is not among the $P$ most frequent classes. Let $q_0\in\mathbb{R}^N$, and set $q\in\mathbb{R}^M$ as $q(y) = \sum_{x:\,f_*(x)=y} q_0(x)$. Then

$$\mathcal{E}(f_{W_{q_0}}) \leq p\big(\{x \mid f_*(x)\notin[P] \ \text{or}\ q(f_*(x)) \leq 2\eta\|q\|_\infty + 2\eta^2\|q\|_1\}\big). \qquad (49)$$

In particular, it is possible to consider a thresholding associative scheme $q$ such that, if $y$ follows a Zipf law $p(y) = C_\beta y^{-\beta}$, $\mathcal{E}(f_{W_q}) = O\big((d/\log(d))^{-\beta+1}\big)$.

Proposition 5 shows that when learning the input embeddings, one can expect to replace the scaling in $d^{-\alpha+1}$, which depends on the law of $x$, by a scaling that depends on the law of $y$. It illustrates the usefulness of learning embeddings when the law of $x$ is well factored by the law of $y$. This is typically the case when $x$ are news articles associated with a few topics $y$.

Proof. When $e$ can be optimized, it is natural to set $e_x$ to be a constant for all $x$ that are associated with the same output. Let $q_0\in\mathbb{R}^N$ be an associative scheme; then

$$\mathbf{1}_{f_{W_{q_0}}(x_0)\neq f_*(x_0)} = \mathbf{1}_{\max_{y\neq f_*(x_0)}\sum_{x\in[N]} q_0(x)\,e_{x_0}^\top e_x\,u_{f_*(x)}^\top(u_y - u_{f_*(x_0)}) > 0}.$$

In order to lower the probability of error, one wants to minimize the expression inside the indicator, which leads one to maximize $e_{x_0}^\top e_x\,u_{f_*(x)}^\top u_{f_*(x_0)}$. This can be done by setting

$$\forall x, x'\in[N],\quad e_x^\top e_{x'} = u_{f_*(x)}^\top u_{f_*(x')}. \qquad (50)$$

Such an isometry can be built by setting $e_x = u_{f_*(x)}$, leading to the new characterization

$$\mathbf{1}_{f_{W_{q_0}}(x_0)\neq f_*(x_0)} = \mathbf{1}_{\max_{y\neq y_0}\sum_{z\in[M]} q(z)\,u_{y_0}^\top u_z\,u_z^\top(u_y - u_{y_0}) > 0}, \quad\text{where } y_0 = f_*(x_0) \text{ and } q(z) = \sum_{x:\,f_*(x)=z} q_0(x). \qquad (51)$$

When $u$ is $\eta$-quasi orthogonal on its first $P$ values and set to zero otherwise, we have

$$\begin{aligned}\sum_{z\in[M]} q(z)\,u_{y_0}^\top u_z\,u_z^\top(u_y - u_{y_0}) &= q(y_0)\big(u_{y_0}^\top u_y - 1\big) + q(y)\big(u_{y_0}^\top u_y - (u_{y_0}^\top u_y)^2\big) + \sum_{z\in[M]\setminus\{y,y_0\}} q(z)\,u_{y_0}^\top u_z\big(u_z^\top u_y - u_z^\top u_{y_0}\big) \\ &\leq -q(y_0) + |q(y_0)|\eta + |q(y)|(\eta + \eta^2) + \sum_{z\in[M]\setminus\{y,y_0\}}|q(z)|\,\eta(\eta+\eta) \\ &\leq -q(y_0) + 2\eta\sup_{z\in[M]}|q(z)| + 2\eta^2\sum_{z\in[M]}|q(z)| = -q(y_0) + 2\eta\|q\|_\infty + 2\eta^2\|q\|_1.\end{aligned}$$

As a consequence, we get

$$\mathcal{E}(f_{W_{q_0}}) \leq \sum_{x\in[N]} p(x)\,\mathbf{1}_{q(f_*(x))\leq 2\eta\|q\|_\infty + 2\eta^2\|q\|_1}.$$

Using that, for any $A: [M]\to\mathbb{R}^d$, $\sum_x p(x)A(f_*(x)) = \sum_{x,y} p(x,y)A(y) = \sum_y p(y)A(y)$, we get

$$\mathcal{E}(f_{W_{q_0}}) \leq \sum_{y\in[M]} p(y)\,\mathbf{1}_{q(y)\leq 2\eta\|q\|_\infty + 2\eta^2\|q\|_1}. \qquad (52)$$

Note that when the embeddings $u$ are chosen uniformly at random on the sphere and $d > 4\log(M)$, a similar bound holds up to an extra higher-order term, as seen in the proof of Theorem 1. When $u$ is defined to be zero on $[M]\setminus[P]$ and only $\eta$-quasi orthogonal for $(u_y)_{y\in[P]}$, the same characterization holds with

$$\mathcal{E}(f_{W_{q_0}}) \leq \sum_{y\in[M]\setminus[P]} p(y) + \sum_{y\in[P]} p(y)\,\mathbf{1}_{q(y)\leq 2\eta\|q\|_\infty + 2\eta^2\|q\|_1}. \qquad (53)$$

Finally, if $\eta^2$ is set to $1/(4P)$ and $q = \mathbf{1}_{y\in[P]}$, then $2\eta\|q\|_\infty + 2\eta^2\|q\|_1 = 1/\sqrt{P} + 1/2 < 1 = q(y)$ for $y\in[P]$ as soon as $P > 4$, and we get the upper bound $\mathcal{E}(f_{W_{q_0}}) \leq p(\{x \mid f_*(x) > P\})$. The best $P$ that one can consider is such that $d/(4P) = \eta^2 d = 4\log(P)$. Setting $P = d/(16\log(d))$ and bounding $\sum_{y>P} y^{-\beta} \leq \int_P^\infty t^{-\beta}\,\mathrm{d}t$ ends the proof.

A.9.1 DISCUSSION ON COMPENSATION MECHANISMS

When optimizing the embeddings, one may turn the negative interference mechanisms illustrated in Figure 2 into positive ones. Assume that $e_x = u_{f_*(x)}$; our model (4) becomes, denoting $u_{f_*(x)} = u_0$ for simplicity,

$$f(x) = \arg\max_{y\in[M]} u_y^\top W u_0; \qquad W = \sum_{y'\in[M]} q(y')\,u_{y'}u_{y'}^\top. \qquad (54)$$

Similarly as before, an error is made when

$$\exists y\in[M],\quad \sum_{y'\in[M]} q(y')\,(u_y - u_0)^\top u_{y'}u_{y'}^\top u_0 > 0. \qquad (55)$$

When the output embeddings are learned, one can optimize them to induce compensation mechanisms.
For example, when $M = 3$ and $y_1$ is competing with $y_0 = f_*(x)$ as the argmax of (54) due to a large storage $q(y_1)$ compared to $q(y_0)$, one could benefit from $q(y_2)$ to ensure that

$$q(y_0)\big(u_1^\top u_0 - 1\big) + q(y_1)\big(1 - u_1^\top u_0\big)u_1^\top u_0 + q(y_2)(u_1 - u_0)^\top u_2u_2^\top u_0 < 0 < q(y_0)\big(u_1^\top u_0 - 1\big) + q(y_1)\big(1 - u_1^\top u_0\big)u_1^\top u_0.$$

In this situation, the score $u_0^\top We_x$ of $y_0$ would be higher than $u_{y_1}^\top We_x$, ensuring that we do not make an error when predicting $f(x)$ in (54). We refer the interested reader to Elhage et al. (2022) for related investigations.

A.10 LOSS GRADIENT

The cross-entropy loss is written as

$$\ell((x,y), W) = -\log\Big(\frac{\exp(u_y^\top We_x)}{\sum_{z\in[M]}\exp(u_z^\top We_x)}\Big) = -u_y^\top We_x + \log\Big(\sum_{z\in[M]}\exp(u_z^\top We_x)\Big).$$

Hence stochastic gradient descent will update the matrix $W$ by adding terms of the form $-\nabla_W\ell((x,y),W)$, where

$$\begin{aligned}\nabla_W\ell((x,y),W) &= -u_ye_x^\top + \frac{\sum_{z\in[M]}\exp(u_z^\top We_x)\,u_ze_x^\top}{\sum_{y'\in[M]}\exp(u_{y'}^\top We_x)} = -u_ye_x^\top + \sum_{z\in[M]} p_W(z|x)\,u_ze_x^\top \\ &= -\big(1-p_W(y|x)\big)u_ye_x^\top + \sum_{z\neq y} p_W(z|x)\,u_ze_x^\top = -\big(1-p_W(y|x)\big)\Big(u_ye_x^\top - \sum_{z\neq y}\frac{p_W(z|x)}{1-p_W(y|x)}\,u_ze_x^\top\Big).\end{aligned}$$

Note that $p_W(z|x)/(1-p_W(y|x))$ corresponds to the probability of $z$ conditioned on $x$ under the event that $z$ is not $y$; formally,

$$\frac{p_W(z|x)}{1-p_W(y|x)} = p_W(z \mid x, z\neq y).$$

Hence

$$\nabla_W\ell((x,y),W) = -\big(1-p_W(y|x)\big)\Big(u_ye_x^\top - \sum_{z\neq y} p_W(z\mid x, z\neq y)\,u_ze_x^\top\Big) = -\big(1-p_W(y|x)\big)\Big(u_ye_x^\top - \mathbb{E}_{z\sim p_W}\big[u_z \mid x, z\neq y\big]\,e_x^\top\Big).$$

While it is clear that the model (4) does not describe the solution found by cross-entropy, one might hope that the terms $\mathbb{E}[u_z]e_x^\top$ will somewhat cancel each other out and be an order of magnitude smaller than the leading term $u_ye_x^\top$.

A.11 APPROXIMATE UPDATES

The formula (20) is justified by the fact that a matrix $W_t = W_{q_t}$ will lead to an update (18) at time $t$ according to the rule (19), assuming $\exp(u_z^\top We_x) \approx 1$ for any $z\neq f_*(x)$:

$$q_{t+1}(x) - q_t(x) = \mathbf{1}_{x_t=x}\,\gamma\big(1 - p_{W_{q_t}}(f_*(x)|x)\big) \approx \mathbf{1}_{x_t=x}\,\frac{\gamma}{1 + (M-1)^{-1}\exp(q_t(x))},$$

together with the fact that $x$ will be seen $Tp(x)$ times on average in $T$ samples.
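The per-sample rule above can be iterated directly. The sketch below assumes, as in the approximation, that only the number of occurrences of $x$ matters, and applies $q \mapsto q + \gamma/(1 + (M-1)^{-1}e^{q})$ once per occurrence, using the rounded average count $\lfloor Tp(x)\rfloor$ (the rounding is our simplification). The function name `approx_q` and the parameter values are hypothetical.

```python
import numpy as np

def approx_q(counts, gamma, M):
    """Iterate q <- q + gamma / (1 + exp(q) / (M - 1)) once per occurrence of each token.

    counts[x] is the number of times token x is seen (e.g. about T * p(x)).
    Vectorized over tokens: a token stops being updated once its count is exhausted.
    """
    counts = np.asarray(counts)
    q = np.zeros(len(counts))
    for step in range(int(counts.max(initial=0))):
        active = counts > step                      # tokens seen at least step + 1 times
        q[active] += gamma / (1.0 + np.exp(q[active]) / (M - 1))
    return q

alpha, N, M, T, gamma = 2.0, 50, 10, 10_000, 1.0
x = np.arange(1, N + 1)
p = x ** (-alpha) / np.sum(x ** (-alpha))
q = approx_q(np.floor(T * p), gamma, M)
print(np.round(q[:5], 2), np.round(q[-5:], 2))
```

The increments shrink as $q_t(x)$ grows, so $q$ only increases roughly logarithmically with the number of occurrences: frequent tokens end up with larger weights, but much less than proportionally to $p(x)$.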
Similarly, with a very large batch size $b = |B|$ and $T/b$ update steps, each $x$ will appear in each batch about $b\,p(x)$ times, which leads to the rough approximation
$$q_{\gamma,b}(x) = f^{T/b}(0) = \underbrace{f \circ f \circ \cdots \circ f}_{T/b \text{ times}}(0), \qquad \text{where } f\colon s \mapsto s + \frac{\gamma\, b\, p(x)}{1 + (M-1)^{-1}\exp(s)}. \tag{56}$$
In practice, we can approximate the effect of a batch by counting how many times $x$ appears in this batch and setting $b\,p(x)$ to be the exact count, which leads to a tighter approximation. This is the approximation that we plot in Figure 13.

A.12 GRADIENT FOR LAYER NORM

Let $x \in [N]$, $y \in [M]$ and $W \in \mathbb{R}^{d\times d}$. When processing the input $x$, layer norm adds a normalization layer $f\colon W \mapsto W' = W / \|W e_x\|$. Using the chain rule, with $D$ denoting the Jacobian operator,
$$\nabla_W \ell(x, y; f(W)) = (D_W f(W))^\top\, \nabla_{f(W)} \ell(x, y; f(W)) = (D_W f(W))^\top\, \nabla_{W'} \ell(x, y; W').$$
We are left with the computation of the Jacobian. We proceed with the chain rule, writing $f(W) = f_1(f_2(f_3(W)))\, W$ with
$$f_1\colon t \in \mathbb{R} \mapsto t^{-1}, \qquad f_2\colon e \in \mathbb{R}^d \mapsto \|e\|, \qquad f_3\colon W \in \mathbb{R}^{d\times d} \mapsto W e_x.$$
Then
$$\begin{aligned}
D_W f(W) &= \nabla_W (f_1\circ f_2\circ f_3)(W) \otimes W + f_1(f_2(f_3(W)))\, I
= -\frac{\nabla_W (f_2\circ f_3)(W)}{f_2(f_3(W))^2} \otimes W + \frac{1}{\|W e_x\|}\, I \\
&= -\frac{f_3(W)^\top D_W f_3(W)}{\|f_3(W)\|\, \|W e_x\|^2} \otimes W + \frac{1}{\|W e_x\|}\, I
= -\frac{W e_x e_x^\top}{\|W e_x\|^3} \otimes W + \frac{1}{\|W e_x\|}\, I \\
&= \frac{1}{\|W e_x\|}\big(I - W' e_x e_x^\top \otimes W'\big).
\end{aligned}$$
This proves the formula written in the main text.

B EXPERIMENTAL DETAILS

B.1 MAXIMAL PARAMETERS UPDATES

In order to carefully choose step-sizes that scale well with the width $d$ in optimization algorithms, we follow Yang et al. (2021) and consider learning rates consistent with maximal feature learning updates. Here we consider the following initializations:
• $W$ is initialized as a Gaussian random matrix with $\mathcal{N}(0, 1/d)$ entries.
• Input embeddings $e_x$ and output embeddings $u_y$ are initialized either uniformly at random on the unit sphere in $d$ dimensions, or with Gaussian $\mathcal{N}(0, 1/d)$ entries. In both cases, every embedding has norm $\approx 1$.

Updates to $W$. The updates to the matrix $W$ look as follows:
• SGD with step-size $\eta_W$: $W' = W + \eta_W\, \delta W$, with $\delta W = \sum_j \alpha_j u_{y_j} e_{x_j}^\top$, where $\alpha_j = \Theta_d(1)$ and the number of elements in the sum is independent of the dimension.
Choosing $\eta_W = \Theta(1)$ then ensures that for any input embedding $e_x$, we have $\|W e_x\| = \Theta(1)$, as desired.
• Adam (idealized here as sign SGD) with step-size $\eta_W$: $W' = W + \eta_W \operatorname{sign}(\delta W)$, with $\operatorname{sign}(\delta W)_{ij} = \delta W_{ij}/|\delta W_{ij}|$. The coordinates of $\operatorname{sign}(\delta W)$ are now $\Theta(1)$ instead of $\Theta(1/d)$, thus the step-size needs to be taken as $\eta_W = \Theta(1/d)$ in order to satisfy $\|W e_x\| = \Theta(1)$ (see Yang et al. (2021); Yang & Littwin (2023) for more details).

Updates to embeddings. The updates to the embeddings look as follows:
• SGD updates: $u_y' = u_y + \eta_u\, \delta u_y$ with $\delta u_y = \sum_j \alpha_j W e_{x_j}$, and $e_x' = e_x + \eta_e\, \delta e_x$ with $\delta e_x = \sum_j \alpha_j' W^\top u_{y_j}$, where $\alpha_j, \alpha_j' = \Theta(1)$ and the number of $j$'s is independent of the dimension. Since the algorithm ensures $\|W e_{x_j}\| = \Theta(1)$ and $\|W^\top u_{y_j}\| = \Theta(1)$ throughout training, choosing $\eta_u, \eta_e = \Theta(1)$ ensures that these conditions continue to hold after each update.
• Adam/sign SGD updates: $u_y' = u_y + \eta_u \operatorname{sign}(\delta u_y)$ with $(\operatorname{sign}(\delta u_y))_i = (\delta u_y)_i/|(\delta u_y)_i|$, and $e_x' = e_x + \eta_e \operatorname{sign}(\delta e_x)$ with $(\operatorname{sign}(\delta e_x))_i = (\delta e_x)_i/|(\delta e_x)_i|$. Since the updates have coordinates of order $\Theta(1)$, in order to ensure that the embeddings remain of norm $\Theta(1)$ after each update, we thus need $\eta_u, \eta_e = \Theta(1/\sqrt{d})$.

B.2 ADDITIONAL FIGURES

Our theory predicted optimal scaling laws in $d^{-\alpha+1}$. However, there are some catches behind the proof:
• The lower bound is true when $d \ll N = 100$; otherwise the error can actually reach zero when $d$ becomes larger than a tipping number $d_t$ comparable to $N$. This fact was illustrated in Figure 3. Increasing $N$ increases the tipping point $d_t$, rectifying the learning curve, as illustrated in Figure 9.
• This was proven for models where $q(x, y) = q(x)$, and where $q(x)$ is not optimized with respect to $f(x)$. As such, it is not clear whether those lower bounds hold for optimization-based algorithms, although we do not expect different mechanisms to be at play in the proofs. We illustrate this empirically in the left panel of Figure 12.
• Similarly, the unreasonable effectiveness of learning the embeddings would be highly disappointing if the embeddings were hard to optimize in practice.
The right panel of Figure 12 illustrates how, within a few steps, one can achieve zero generalization error when learning the embeddings.

In order to better understand gradient updates, Figure 14 shows the dynamics of the associative memory $W$ updated with SGD and a large step size. To validate the approximation (20), Figure 4 plots the generalization error associated with SGD together with its theoretical approximation, while Figure 5 illustrates the idealized association scheme $q_\gamma$ associated with a step size $\gamma$, batch size one, and a Zipf law on $x \in [N]$.

[Figure 9 plot: generalization error with respect to $d$; legend: $q = 1$, $q = p$, $q = \mathbf{1}_{x \le d/8}$]
Figure 9: Same as the right panel of Figure 3, yet with a bigger $N$, here $N = 1000$. The dashed curves represent $\mathcal{E} = 0.35\, d^{-1/4}$ (orange) and $\mathcal{E} = 3.5\, d^{-1}$ (green). They validate the scaling predicted by theory, where we used $N = +\infty$ to get tight polynomial scalings of $\mathcal{E}$ (5) with respect to $d$.

Figure 10: Representation of the weight matrix $(u_y^\top W e_x)_{y,x} \in \mathbb{R}^{M\times N}$ for $N = 10$, $M = 5$, $f(x) = x \bmod M$. The data $x$ follows a Zipf law with $\alpha = 1$ and $T = 10^3$. The matrix $W$ is obtained according to (4) together with the scheme (11). Left: $\rho = 0$ (10), $d = 10$: there is not enough memory capacity, and the model does not succeed in storing memories, leading to a large generalization error. Middle left: $\rho = 0$ (10), $d = 50$: there is enough memory capacity, and we learn the right association $y = x \bmod M$. Middle right: $\rho = 1$ (11), $d = 10$: the weighting $q$ allows storing the most important memories despite a small memory capacity. Right: $\rho = 1$ (11), $d = 50$: the weighting $q$ is too strong, which prevents storing the memories of rare associations (bottom of the matrix).

In order to understand the effect of Adam, we compare it with plain SGD and SGD with rescaled variance on population data. That is, we consider gradient descent with $\nabla_W \mathcal{L}(W)$ (16). The rescaled-variance SGD consists in dividing the gradient by the variance of $\nabla_W \ell(X, f(X); W)$ (17) when $X \sim p$ (1).
For simplicity, we consider Adam with $\beta_1 = \beta_2 = 0$, in which case it is equivalent to sign SGD, i.e., SGD using the sign of each entry of $\nabla_W \mathcal{L}(W)$ in the updates $W_t \to W_{t+1}$. Figures 15 and 16 support our intuition that the usefulness of Adam lies in its ability to rescale gradient updates, an effect that could equally be obtained by tuning the learning rate.

[Figure 11 plot: left and right panels; legend: $\gamma = 10$, $\gamma = 0.1$, Step LR]
Figure 11: Learning curve of the generalization error $\mathcal{E}$ (5) with respect to the number of data processed by stochastic gradient descent in the setting of Figure 6. Left: comparison on a single run. A big step size allows storing more memories at the risk of overwriting past associations, which explains the higher variance of the blue curve but also its overall better performance. A small step size avoids loss spikes due to memory overwriting, but takes more time to store rare associations, leading to worse performance. By decreasing the learning rate along training, e.g., with the StepLR scheduler (Paszke et al., 2019), one can get the best of both worlds, i.e., store memories fast at the beginning of training when storage capacity is underused, while being more cautious at the end of training when there is no more free memory space. Right: similar plot with $N = 30$, averaged over one hundred runs.

[Figure 12 plot: three panels with respect to $d$; left legend: SGD $\gamma = 10^3$, SGD $\gamma = 10$, Adam $\gamma = 0.1$, $q = \mathbf{1}_{x \le d/8}$; middle legend: $W$, $e, u$, ReLU; right panel ($\gamma = 1.0$, $|B| = 16$, $T = 10240$): Learning $W$, Learning $e$ and $u$]
Figure 12: Scalings with respect to $d$ for optimization-based algorithms, in the setting of Figure 3. Left: optimization-based algorithms beat the best algorithm designed by hand with $q(x, y) = q(x)$. Note how the curves seem to have the same optimal exponent $\mathcal{E} \propto d^{-\alpha+1}$ (the left part of the figure shows similar slopes for all curves), yet with a smaller constant in front, leading to an earlier tipping point before reaching zero generalization error due to full storage of all the associations.
Middle: comparison of learning the sole matrix $W$ (blue), or learning the embeddings $e$ and $u$ (orange), together with the possibility to use the non-linear model $u_y^\top \operatorname{ReLU}(e_x)$ with $e$ and $u$ learned (green). All curves are obtained after $10^3$ updates with batch size $10^3$. Right: comparison in the same setting as Figure 7. Learning the embeddings or going non-linear considerably improves memory storage, leading to a better exponent with respect to $d$ and an earlier tipping point for a fixed number of updates.

[Figure 13 plot: three panels of generalization error with respect to epochs; curves: real, approx]
Figure 13: Same as Figure 4, yet with batch size equal to one thousand, $|B| = 10^3$.

Figure 14: Gradient descent dynamics similar to Figure 6 with $d = 10$ and a fixed step size $\gamma = 10$. From time to time (we represent here $t \in \{0, 4, 5, 6, 8, 9, 11, 30, 32, 37, 49, 62, 75, 90\}$), stochastic gradient descent hits an association that is not yet properly stored in memory (the red boxes). It consequently updates the weight matrix $W_t \to W_{t+1}$ (side-by-side pairs) to store it. When $d$ is big enough, here $d = 10$, $W$ ends up storing all associations correctly, leading to perfect generalization on future examples.

Figure 15: Comparison between SGD, sign SGD and SGD with normalized variance on the population gradient, seen from the association matrix $W_t$ at different times in the setting of Figure 14. The different rows correspond to the matrices $W_t$ at times $t \in \{1, 2, 3, 7, 100\}$. Left: plain SGD. Middle: Adam with $\beta_1 = \beta_2 = 0$, i.e., sign SGD. Right: SGD with normalized variance.

[Figure 16 plot: left, generalization error with respect to epoch for SGD, sign SGD and norm SGD; right, gradient variance with respect to epoch]
Figure 16: Left: generalization error in the setting of Figure 15. Observe how SGD with rescaled variance (in green), an effect that can be obtained with SGD after adapting the learning rate, actually performs better than sign SGD (i.e., Adam with $\beta_1 = \beta_2 = 0$). Right: variance of SGD along training. As training progresses, SGD loses momentum due to smaller gradient variances, hence smaller updates.
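The statement in B.1 and B.2 that Adam with $\beta_1 = \beta_2 = 0$ reduces to sign SGD, and the reason why sign updates call for a step size $\eta_W = \Theta(1/d)$, can be illustrated with a minimal pure-Python sketch. The Adam step below is the standard bias-corrected update written from scratch (not a specific library's optimizer), and the helper `update_norms` is hypothetical. With $\beta_1 = \beta_2 = 0$ the Adam step is $-\eta\, g/(|g| + \varepsilon)$ coordinate-wise, which matches sign SGD up to the small constant $\varepsilon$.

```python
import math
import random


def sign(a):
    """Sign of a real number, with sign(0) = 0."""
    return (a > 0) - (a < 0)


def adam_step(param, grad, m, v, t, lr, beta1, beta2, eps=1e-8):
    """One Adam step (Kingma & Ba, with bias correction) on flat lists.
    Returns the updated parameters and moment estimates; t starts at 1."""
    m = [beta1 * mi + (1 - beta1) * g for mi, g in zip(m, grad)]
    v = [beta2 * vi + (1 - beta2) * g * g for vi, g in zip(v, grad)]
    m_hat = [mi / (1 - beta1 ** t) for mi in m]
    v_hat = [vi / (1 - beta2 ** t) for vi in v]
    new_param = [p - lr * mh / (math.sqrt(vh) + eps) for p, mh, vh in zip(param, m_hat, v_hat)]
    return new_param, m, v


def sign_sgd_step(param, grad, lr):
    """Sign SGD: move each coordinate by lr against the sign of its gradient."""
    return [p - lr * sign(g) for p, g in zip(param, grad)]


def unit_vector(d, rng):
    """Normalized Gaussian vector, i.e. a uniform sample on the unit sphere of R^d."""
    v = [rng.gauss(0.0, 1.0) for _ in range(d)]
    norm = math.sqrt(sum(a * a for a in v))
    return [a / norm for a in v]


def update_norms(d, seed=0):
    """Hypothetical helper: for a rank-one SGD direction dW = u e^T with random
    unit u and e, return ||dW e|| and ||sign(dW) e||, computed entry by entry."""
    rng = random.Random(seed)
    u, e = unit_vector(d, rng), unit_vector(d, rng)
    dW_e = [sum(ui * ej * ej for ej in e) for ui in u]  # (u e^T) e = u ||e||^2
    sign_e = [sum(sign(ui * ej) * ej for ej in e) for ui in u]  # sign(u e^T) e
    norm = lambda w: math.sqrt(sum(a * a for a in w))
    return norm(dW_e), norm(sign_e)
```

For growing `d`, the first norm returned by `update_norms(d)` should stay at 1 while the second grows roughly like $d\sqrt{2/\pi}$ (it equals $\sqrt{d}\,\|e\|_1 \le d$), which is consistent with the choice $\eta_W = \Theta(1/d)$ for sign updates.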