# Infinite Limits of Multi-head Transformer Dynamics

Blake Bordelon, Hamza Chaudhry, Cengiz Pehlevan
John A. Paulson School of Engineering and Applied Sciences
Center for Brain Science
Kempner Institute for the Study of Natural and Artificial Intelligence
Harvard University, Cambridge, MA 02138
blake_bordelon@g.harvard.edu, hchaudhry@g.harvard.edu, cpehlevan@seas.harvard.edu

In this work, we analyze various scaling limits of the training dynamics of transformer models in the feature learning regime. We identify the set of parameterizations that admit well-defined infinite width and depth limits while allowing the attention layers to update throughout training, a relevant notion of feature learning in these models. We then use tools from dynamical mean field theory (DMFT) to analyze various infinite limits (infinite key/query dimension, infinite heads, and infinite depth) which have different statistical descriptions depending on which infinite limit is taken and how attention layers are scaled. We provide numerical evidence of convergence to these limits and discuss how the parameterization qualitatively influences the learned features.

1 Introduction

Increasing the scale of transformer models has continued to improve the performance of deep learning systems across many settings, including computer vision [1, 2, 3, 4] and language modeling [5, 6, 7, 8, 9]. However, understanding the optimization stability and limiting behavior of these models under increases in model scale remains a core challenge. One approach to scaling up systems in a stable and predictable way is to identify parameterizations of neural networks that give approximately scale-independent feature updates during training [10, 11, 12]. The mean field parameterization, commonly referred to as µP, is a well-known example that satisfies this property [13, 14, 15]. When such parameterizations are adopted, the learned internal representations in hidden layers of the network are very similar across model scales [16, 17], but performance tends to improve with model scale [10, 11, 12]. Further, theoretical results about their limits can often be obtained using Tensor Programs [14] or dynamical mean field theory (DMFT) techniques [15, 17].

In this work, we develop a theoretical treatment of randomly initialized transformers. We study various scaling limits of the training dynamics of these models, including the infinite key/query dimension limit, the infinite head limit, and the infinite depth limit. Concretely, our contributions are the following:

1. We derive a DMFT for feature learning in randomly initialized transformers with key/query dimension $N$, attention head count $H$ and depth $L$. From the derived DMFT action, we identify large $N$, large $H$ and large $L$ limits of the training dynamics.
2. We analytically show that the large key/query dimension $N \to \infty$ limit requires the µP scaling of the key/query inner product with $1/N$, even if keys/queries are reparameterized to decrease the size of their updates from gradient descent.
3. From the limiting equations, we show that this $N \to \infty$ limit causes multi-head self-attention trained with stochastic gradient descent (SGD) to effectively collapse to single-head self-attention, since all heads follow identical dynamics.
4. To overcome this limitation, we analyze the infinite head $H \to \infty$ limit while fixing $N$.
We show there is a limiting distribution of attention variables across heads at each layer throughout training. Despite $N$ being finite, the infinite-head $H \to \infty$ limit leads to concentration of the network's output logits and learned residual stream feature kernels, giving deterministic training dynamics.
5. Finally, we examine large depth limits of transformers with residual branch scaling. We illustrate and discuss the tension between parameterizing a model so that it has a non-trivial kernel at initialization and maintaining feature learning within the multi-head self-attention (MHSA) and multi-layer perceptron (MLP) blocks.

1.1 Related Works

Hron et al. [18] studied the Neural Network Gaussian Process limit of multi-head self-attention in the infinite-head $H \to \infty$ limit. They showed that, at initialization, there is a limiting distribution over attention matrices and that the outputs of the multi-head attention block follow a Gaussian process, establishing a connection to kernel methods. Dinan et al. [19] develop a similar theory of transformers at initialization and compute the Neural Tangent Kernel associated with this architecture as the dimension per head $N \to \infty$, using a $1/N$ scaling of the key/query inner product within each attention layer. One of our key theoretical results is showing that this picture of a distribution over learned attention heads persists throughout training in the feature-learning regime as $H \to \infty$ (though the distribution of residual stream variables generally becomes non-Gaussian).

Several works have analyzed the signal propagation properties of transformers at initialization at large key/query dimension $N$ and large depth $L$ [20, 21, 22, 23], including providing modifications to the standard transformer architecture [22, 24]. In this work, we pursue large depth limits of transformers by scaling the residual branches by $L^{-\alpha_L}$ with $\alpha_L \in [\frac{1}{2}, 1]$, which has been shown to converge to a limit not only at initialization [25, 26, 27], but also throughout training in the feature learning regime [11, 12, 27]. However, we argue that in transformers $\alpha_L = 1$ is preferable, as it enables the attention layers to update non-negligibly as $L \to \infty$.

Yang et al. [10] introduced the µP scaling for attention layers, which multiplies the key/query inner product by $1/N$ rather than the more commonly used $1/\sqrt{N}$ [5]. They show empirically that this change improves stability of training and transfer of optimal hyperparameters across different values of $N$. Vyas et al. [16] empirically found that such µP transformers learn attention matrices that become approximately consistent across different heads and model sizes, suggesting that models parameterized in µP learn similar representations across scales.

In addition to work on infinite width and depth limits of deep networks, there is also a non-asymptotic approach to optimizer design and scaling based on controlling the norm of weight updates [28]. This approach coincides with µP width-scaling when the spectral norm of the weights is used as the measure of distance [29], and can achieve hyperparameter transfer for a wide array of optimizers and initialization schemes [30].

2 Parameterizations with Feature Learning Limits

We consider a transformer architecture with $L$ layers, $H$ heads per layer, and $N$-dimensional keys/queries per head. Transformers are often defined in terms of $d_{\text{model}} = H d_{\text{head}} = HN$, which can be increased by scaling the number of heads or the dimension of each head, where $N$ is often written $d_{\text{head}}$.
Our goal is to determine the set of parameterizations that allow the attention layers to undergo non-trivial feature learning in the various $N, H, L \to \infty$ limits. The analysis of these limits is performed with the batch size and the number of training steps $t$ held fixed while the architectural parameters are taken to infinity.

2.1 Model Scalings

The network's output is computed by a depth-$L$ recursion through hidden layers $\ell \in [L]$, starting with the first layer $h^1_s(x) = \frac{1}{\sqrt{D}} W^0 x_s \in \mathbb{R}^{NH}$, where $x_s \in \mathbb{R}^D$ is the input at spatial/token position $s$. Preactivations in subsequent layers $h^\ell$ are determined by a forward pass through the residual stream, which contains an attention layer and an MLP layer
$$h^{\ell+1}_s = \bar h^{\ell}_s + \frac{\beta_0}{L^{\alpha_L}} \text{MLP}\left(\bar h^{\ell}_s\right), \qquad \bar h^{\ell}_s = h^{\ell}_s + \frac{\beta_0}{L^{\alpha_L}} \text{MHSA}\left(h^{\ell}\right)_s. \tag{1}$$
The constants $\gamma_0$ and $\beta_0$ control the rate of feature learning and the effective depth respectively.¹ We will consider $\alpha_L \in [\frac{1}{2}, 1]$.² The multi-head self-attention layer (MHSA) with pre-layer-norm³ is
$$\text{MHSA}(h^\ell)_s = \frac{1}{\sqrt{NH}} \sum_{h \in [H]} W^\ell_{Oh}\, v^{\ell\sigma}_{hs}, \qquad v^{\ell\sigma}_{hs} = \sum_{s' \in [S]} \sigma^\ell_{h\,ss'}\, v^\ell_{hs'}, \qquad v^\ell_{hs} = \frac{1}{\sqrt{NH}} W^\ell_{Vh}\, \hat h^\ell_s, \qquad \hat h^\ell_s = \text{LN}(h^\ell_s), \tag{2}$$
where $\sigma^\ell_h \in \mathbb{R}^{S\times S}$ are the attention matrices obtained by passing the pre-attention matrices through a matrix-valued nonlinearity $\sigma(A^\ell_h)$.⁴ For a sequence of length $S$, the pre-attention matrix $A^\ell_h \in \mathbb{R}^{S \times S}$ is defined as
$$A^\ell_{h\,ss'} = \frac{1}{N^{\alpha_A}}\, k^\ell_{hs} \cdot q^\ell_{hs'}, \qquad k^\ell_{hs} = \frac{1}{\sqrt{H}\, N^{\frac{3}{2}-\alpha_A}} W^\ell_{Kh}\, \hat h^\ell_s, \qquad q^\ell_{hs} = \frac{1}{\sqrt{H}\, N^{\frac{3}{2}-\alpha_A}} W^\ell_{Qh}\, \hat h^\ell_s. \tag{3}$$
The exponent $\alpha_A$ alters the scale of the pre-attention variables $A^\ell_h$ at initialization. The input matrices have shape $W^\ell_{Vh}, W^\ell_{Kh}, W^\ell_{Qh} \in \mathbb{R}^{N \times NH}$, while the output matrices have shape $W^\ell_{Oh} \in \mathbb{R}^{NH \times N}$. The weights $W^\ell_{Oh}, W^\ell_{Vh}$ are initialized with $\Theta(1)$ entries, while $W^\ell_{Kh}, W^\ell_{Qh}$ have entries of size $\Theta(N^{1-\alpha_A})$, which ensures that all key and query vectors $k, q$ are $\Theta(1)$ at initialization. The pre-attention variables $A^\ell_h \in \mathbb{R}^{S \times S}$ at each head $h$ are determined by key $k^\ell_{hs}$ and query $q^\ell_{hs'}$ inner products. The MLP layer consists of two linear layers with an element-wise nonlinearity $\phi$ applied in between, where $W^{\ell,2}, W^{\ell,1} \in \mathbb{R}^{NH \times NH}$ are initialized with $\Theta(1)$ entries:
$$\text{MLP}(\hat h^\ell_s) = \frac{1}{\sqrt{NH}} W^{\ell,2}\, \phi\!\left(h^{\ell,1}_s\right), \qquad h^{\ell,1}_s = \frac{1}{\sqrt{NH}} W^{\ell,1}\, \hat h^\ell_s, \qquad \hat h^\ell_s = \text{LN}(h^\ell_s). \tag{4}$$
µP scaling [13, 31, 14, 15] downscales the readout of the last layer compared to the standard and NTK parameterizations [32]. Thus, we define the output of the model as $f = \frac{1}{\gamma_0 N H}\, w^L \cdot \frac{1}{S}\sum_{s \in [S]} h^L_s$,⁵ where $h^L_s \in \mathbb{R}^{NH}$ are the final-layer preactivations at spatial/token position $s \in [S]$. The parameter $\gamma_0$ is an additional scalar that controls the rate of change of the internal features of the network relative to the network output [33].

2.2 Learning Rate Scalings

In order to approximately preserve the size of internal feature updates, we must scale the learning rate $\eta$ appropriately with $(N, H, L)$. However, this scaling depends on the optimizer. In Table 1, we provide the appropriate scaling of learning rates for SGD and Adam for any $\alpha_L \in [\frac{1}{2}, 1]$ and $\alpha_A \in [\frac{1}{2}, 1]$. In what follows, we focus on the SGD scaling and theoretically analyze the $N \to \infty$, $H \to \infty$, and $L \to \infty$ limits of the training dynamics. We also provide in Table 1 details about what prefactor the first layer should be multiplied by and what the initial weights should be divided by to ensure convergence to the $L \to \infty$ limit. Example FLAX implementations of these parameterizations for vision and language modeling transformers are provided in Appendix B.

¹ The scale of the update to the residual stream due to each layer will be $O(\gamma_0 \beta_0^2 / L)$.
² If $\alpha_L < \frac{1}{2}$ or $\alpha_A < \frac{1}{2}$, some of the forward-pass variables would diverge at initialization as $L \to \infty$ or $N \to \infty$ respectively. If $\alpha_L > 1$ then updates to internal residual blocks will diverge as $L \to \infty$.
³ LN denotes layer norm, which we define in terms of each vector's instantaneous mean and variance: $\text{LN}(h) = \frac{1}{\sqrt{\sigma^2+\epsilon}}\left(h - \mu \mathbf{1}\right)$, where $\mu = \frac{1}{NH}\mathbf{1}\cdot h$ and $\sigma^2 = \frac{1}{NH}\left|h-\mu\mathbf{1}\right|^2$.
⁴ The nonlinearity is generally softmax. For autoregressive tasks, it is taken to be causal.
⁵ Instead of pooling over space, the outputs $f$ could also carry a spatial index $s$ (such as in next-word prediction).

Figure 1: Schematic representations of the transformer architecture we model. (a) Scaled residual stream: the forward pass through the residual stream is an alternation of MHSA and MLP blocks scaled by $\beta_0 L^{-\alpha_L}$. (b) MHSA block: the block computes keys, queries, values, and attention variables to produce a concatenated output of dimension $d_{\text{model}} = NH$.

| Optimizer | Global LR | First/Last Layer Rescale Multiplier | First/Last Layer Std. Dev. |
| --- | --- | --- | --- |
| SGD | $\eta_0 N H L^{2\alpha_L - 1}$ | $L^{\frac{1}{2}-\alpha_L}$ | $\Theta(1)$ |
| Adam | $\eta_0 N^{-\frac{1}{2}} H^{-\frac{1}{2}} L^{-1+\alpha_L}$ | $L^{1-\alpha_L}$ |  |

Table 1: The learning rates which should be applied to obtain the correct scale of updates for SGD or Adam optimizers. In addition, the weight variance and multiplier for the first layer may need to be rescaled (relative to Eq. (5)) with width/depth depending on the parameterization and optimizer.

Our analysis assumes that at each step $t$ of SGD or Adam a mini-batch $\mathcal{B}_t$ of size $\Theta(1)$ is used to estimate the loss gradient. We assume that the mini-batches are fixed. Further, the number of total training steps is assumed not to be scaled jointly with the model size. Therefore the analysis provided here can cover either online SGD for a fixed number of steps or full-batch GD (repeating data) with a $\Theta(1)$-sized dataset.

3 Infinite Limits of Learning Dynamics

In this section, we first analyze the infinite dimension-per-head $N \to \infty$ limit of training. We find that for this limit the µP rule $\alpha_A = 1$ is necessary, and we show that all heads collapse to the same dynamics. To counteract this effect, we next analyze the infinite head $H \to \infty$ limit of the training dynamics at fixed $N, L$, where we find a limiting distribution over attention heads. We conclude by analyzing the statistical descriptions of various infinite depth $L \to \infty$ limits.⁶

3.1 Mean Field Theory Treatment of the Learning Dynamics

To obtain the exact infinite limits of interest when scaling the dimension-per-head $N$, the number of heads $H$, or the depth $L$ to infinity, we work with a tool from statistical physics known as dynamical mean field theory (DMFT). Classically, this method has been used to analyze high dimensional disordered systems such as spin glasses, random recurrent neural networks, or learning algorithms with high dimensional random data [34, 35, 36, 37, 38, 39]. Following [15, 11], we use this method to reason about the limiting dynamics of randomly initialized neural networks by tracking a set of deterministic correlation functions (feature and gradient kernels) as well as additional linear-response functions (see Appendix D). The core conceptual idea of this method is that in the infinite limit and throughout training, all neurons remain statistically independent and only interact through collective variables (feature kernels, neural network outputs, etc.). Further, the collective variables can be computed as averages over the distribution of neurons in each hidden layer or along the residual stream. This DMFT description can be computed using a path integral method that tracks the moment generating function of the preactivations, or with a dynamical cavity method (see Appendix D).
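To make these collective variables concrete, the short sketch below shows how the feature kernels that the DMFT tracks can be estimated from a finite network's stored activations. The kernel definitions follow Result 1 and Eq. (3); the array layouts and function names are our own assumptions for illustration, not code from Appendix B.

```python
import jax.numpy as jnp

def residual_stream_kernel(h):
    """Empirical residual-stream kernel H^l_{ss'}(x, x') = (1/(N H)) h^l_s(x) . h^l_{s'}(x').

    h: residual stream preactivations at one layer, shape (batch, seq, N*H).
    Returns an array of shape (batch, seq, batch, seq)."""
    d_model = h.shape[-1]
    return jnp.einsum("asi,bti->asbt", h, h) / d_model

def pre_attention_variables(k, q, alpha_A):
    """Per-head pre-attention variables A^l_{h,ss'} = (1/N^alpha_A) k^l_{hs} . q^l_{hs'}.

    k, q: keys and queries, shape (batch, heads, seq, N)."""
    N = k.shape[-1]
    return jnp.einsum("bhsn,bhtn->bhst", k, q) / N ** alpha_A
```

The results below predict that, in the appropriate limits, these empirical averages concentrate around deterministic trajectories, so comparing them across model sizes (as in Figures 3 and 6) is a direct numerical test of the theory.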
3.2 Scaling Dimension-Per-Head $N$

One way of obtaining a well-defined infinite-parameter limit of transformers is to take the $N \to \infty$ limit, where $N$ is the dimension of each head. A priori, it is unclear if there are multiple admissible ways of scaling the key/query inner product. Concretely, it is unknown what values of the exponent $\alpha_A$ are admissible for the pre-attention $A = \frac{1}{N^{\alpha_A}}\, k \cdot q$. The keys and queries are uncorrelated at initialization, which motivated the original choice of $\alpha_A = \frac{1}{2}$ [5, 18]. Yang et al. [10] assume the entries of the key and query vectors move by $\Theta(1)$, implying $\alpha_A = 1$ is necessary since the update to $k$ is correlated with $q$ and vice versa. However, it is possible to obtain $\Theta(1)$ updates to the attention variable for alternative values of $\alpha_A$ if we choose the change to key (and query) entries after gradient descent to be $\delta k_i \sim \Theta(N^{-1+\alpha_A})$. We show that this scaling can approximately preserve optimal hyperparameters across $N$ in Figure 2 (a) and gives similar dynamics under SGD in Appendix C. However, as we show in Appendix E.1.2, any well-defined $N \to \infty$ limit of SGD requires $\alpha_A = 1$. The reason is not that keys and queries become correlated, but rather that the scale of the backward pass must be controlled to ensure the dynamics remain stable (non-divergent) under SGD training. After performing two or more gradient descent steps, we demonstrate that the backpropagation signals will diverge as $N \to \infty$ unless the initial key and query weight matrices are downscaled to have variance of order $\Theta_N(1)$.

⁶ Training time, sample size, and sequence length/spatial dimension are all treated as fixed $\Theta(1)$ quantities.

Figure 2: Increasing dimension-per-head $N$ with heads fixed for $\alpha_A \in \{1, \frac{1}{2}\}$. (a) Hyperparameter transfer for various $\alpha_A$: both $\alpha_A = 1$ and $\alpha_A = \frac{1}{2}$ exhibit similar hyperparameter transfer for vision transformers trained on CIFAR-5M over finite $N$ at $H = 16$. (b) Attention variance across heads: the variance of attention variables across the different heads of a vision transformer after training for 2500 steps on CIFAR-5M. For $\alpha_A = 1$ the variance of attention variables decays at rate $O(N^{-2})$, while for $\alpha_A = \frac{1}{2}$ the variance does not decay with $N$.

In Appendix E, we provide a DMFT analysis of the $N \to \infty$ limit of the transformer training dynamics. We summarize the result of that analysis informally below.

Result 1 (Infinite Dimension-Per-Head $N \to \infty$) (Informal) A stable feature learning $N \to \infty$ limit of transformer SGD training requires taking $\alpha_A = 1$ (µP scaling), even if key/query updates are allowed to be rescaled to account for their correlation. The limiting dynamics of training are governed by the residual stream kernel $H^\ell_{ss'}(x, x', t, t') = \frac{1}{NH}\, h^\ell_s(x,t) \cdot h^\ell_{s'}(x',t')$ and a collection of inner-product kernels in each head $h$ that concentrate as $N \to \infty$:
$$V^\ell_{h\,ss'}(x,x',t,t') = \frac{1}{N}\, v^\ell_{hs}(x,t)\cdot v^\ell_{hs'}(x',t'), \qquad Q^\ell_{h\,ss'}(x,x',t,t') = \frac{1}{N}\, q^\ell_{hs}(x,t)\cdot q^\ell_{hs'}(x',t'), \tag{6}$$
$$K^\ell_{h\,ss'}(x,x',t,t') = \frac{1}{N}\, k^\ell_{hs}(x,t)\cdot k^\ell_{hs'}(x',t'), \qquad A^\ell_{h\,ss'}(x,t) = \frac{1}{N}\, k^\ell_{hs}(x,t)\cdot q^\ell_{hs'}(x,t), \tag{7}$$
alongside residual-stream gradient kernels and response functions in the sense of [15, 11]. The NN output logits $f(x,t)$ evolve deterministically according to the above kernels as well as kernels for the gradient vectors $g^\ell \equiv \gamma_0 N H\, \frac{\partial f}{\partial h^\ell}$ which appear in the backward pass. These variables become identical across heads, so that for any $h, h' \in [H]$, $A^\ell_{h\,ss'}(x,t) = A^\ell_{h'\,ss'}(x,t)$.
All preactivations on the residual stream and the key/query/value variables within a MHSA block are statistically independent across neurons and can be described by a single-site scalar stochastic process
$$h^{\ell+1}_s(x,t) = h^\ell_s(x,t) + \beta_0 L^{-\alpha_L}\, u^\ell_s(x,t) + \beta_0 L^{-\alpha_L}\, \bar u^{\ell+1}_s(x,t) + \eta_0\gamma_0\beta_0^2 L^{-1} \sum_{t' \le t} \sum_{s'} \int dx' \left[ C^\ell_{ss'}(x,x',t,t')\, g^\ell_{s'}(x',t') + \bar C^\ell_{ss'}(x,x',t,t')\, \bar g^\ell_{s'}(x',t') \right],$$
$$k^\ell_{hs}(x,t) = u^\ell_{Khs}(x,t) + \sum_{t' \le t} \sum_{s'} \int dx'\; C^{k\ell}_{ss'}(x,x',t,t')\, q^\ell_{hs'}(x',t'), \tag{8}$$
where $u^\ell, \bar u^\ell, u^\ell_{Kh}$ are Gaussian processes with covariances $\Phi^{\ell,1}, V^{\ell\sigma}, H^\ell$ respectively. Analogous equations hold for the queries and values. In the limit, the kernels $H^\ell_{ss'}(x,x',t,t') = \langle h^\ell_s(x,t)\, h^\ell_{s'}(x',t')\rangle$, $A^\ell_{h\,ss'}(x,t) = \langle k^\ell_{hs}(x,t)\, q^\ell_{hs'}(x,t)\rangle$, etc. are computed as averages over these random variables. The deterministic kernels $C^\ell, \bar C^\ell$ can also be expressed in terms of single-site averages of residual variables and head averages of MHSA variables. The kernels $C^\ell, \bar C^\ell, C^{k\ell}$ depend on the precise mini-batches of data $\mathcal{B}_t$ presented at each step $t$, which we assume are known.

We derive this result using a Martin-Siggia-Rose path integral formalism [40] for DMFT in Appendix E. The full DMFT equations can be found in Appendix E.2. Following prior works on DMFT for infinite-width feature learning, the large-$N$ limit can be straightforwardly obtained from a saddle point of the DMFT action [15, 11, 41, 17].

Collapse of Attention Heads As $N \to \infty$, multi-head self-attention will effectively compute the same outputs as a single-head self-attention block. We theoretically derive this effect in Appendix E.2.1 and demonstrate it empirically in Figure 2 (b). This experiment shows that in $\alpha_A = 1$ (µP) transformers trained for 2500 steps on CIFAR-5M [42], the variance of attention matrices across heads decreases with $N$. However, we note that if the scaling exponent is chosen instead as $\alpha_A = \frac{1}{2}$, there is non-decreasing diversity of attention variables across heads. This result is consistent with recent empirical findings that attention variables in µP transformers converge to the same limiting quantities at large $N$ with $H$ fixed, for different initializations and also across model sizes [16]. This aspect of transformers in the large-$N$ limit is potentially undesirable, as some tasks could require learning multiple attention mechanisms. Furthermore, it suggests that scaling the model in this limit could increase computational cost without improving performance. To circumvent this, we explore whether there exist well-defined limits at finite $N$ with a diversity of attention variables across heads.

3.3 Scaling Number of Heads

In this section, we take $H \to \infty$ with the inner dimension $N$ fixed. Rather than having all kernels concentrate, the kernel of each head of the MHSA block follows a statistically independent stochastic process. This picture was shown to hold at initialization by Hron et al. [18]. Here, using a DMFT analysis, we show that it continues to hold throughout training in the feature learning regime.

Result 2 (Infinite Head Limit) (Informal) The $H \to \infty$ limit of SGD training dynamics in a randomly initialized transformer at any key/query dimension $N$, scaling exponents $\alpha_A, \alpha_L$, and any depth $L$ is governed by head-averaged kernels for pairs of input data $x, x'$ at training times $t, t'$ and spatial/token positions $s, s'$, such as
$$V^{\ell,\sigma}_{ss'}(x,x',t,t') = \frac{1}{NH}\sum_{h=1}^H v^{\ell\sigma}_{hs}(x,t)\cdot v^{\ell\sigma}_{hs'}(x',t'), \tag{9}$$
which converge to deterministic values as $H \to \infty$.
The attention variables $\{k^\ell_h(x,t), q^\ell_h(x,t), v^\ell_h(x,t), A^\ell_h(x,t)\}$ within each head become statistically independent across heads and decouple in their dynamics (but not across dimensions within a head). Each residual stream neuron becomes independent and obeys a single-site stochastic process analogous to Result 1, but with different kernels.

We derive this and the full DMFT in Appendix E.3, showing that the joint distribution of head-averaged dynamical quantities satisfies a large deviation principle and that the limit can be derived as a saddle point of a DMFT action. To gain intuition for this result, we first examine the variables $H^\ell$ and $A^\ell_h$ at initialization. In Figure 3, we plot the convergence of a $N = 4$, $L = 8$ vision transformer's residual stream kernel $H^\ell$ to its $H \to \infty$ limit at rate $O(H^{-1})$ in square error, consistent with perturbative analysis near the limit [17]. Next, we plot the distribution (over heads) of $A_h$ at a fixed pair of spatial/token positions for a fixed sample. This is a non-Gaussian random variable for finite $N$, but as $N \to \infty$ the distribution of $A$ will approach a Gaussian with variance $\Theta(N^{1-2\alpha_A})$.

We then investigate training dynamics as we approach the $H \to \infty$ limit. In Figure 4 (a) we show the test loss on CIFAR-5M as a function of the number of training iterations. The performance tends to improve as $H$ increases and the model approaches its limit. In Figure 4 (b) we show that all of the models are converging in function space by measuring the squared error between finite-$H$ models and a proxy for the infinite-$H$ model. Since the $H \to \infty$ limit is essentially uncomputable, we approximate it as the ensemble-averaged predictor of the widest possible models, a technique used in prior works [16, 11]. We again see that at early time, the logits of $H$-head models converge to the limit proxy at a rate $O(H^{-1})$, but after continued training the convergence rate weakens.

Figure 3: The initial kernels converge as $H \to \infty$ and are determined by (possibly non-Gaussian) distributions of $A^\ell_h$ over heads in each layer. (a) Initial $H^\ell$ kernels, $N = 4$: convergence of $H^\ell_{ss'}(x,x') = \frac{1}{HN}\, h^\ell_s(x)\cdot h^\ell_{s'}(x')$ in a $L = 8$, $N = 4$ vision transformer at initialization at rate $O(H^{-1})$. (b) $A_h$ distribution, $N = 1$: the density of $A^\ell_h$ entries over heads at a fixed spatial location converges as $H \to \infty$ but is non-Gaussian for small $N$. (c) $A_h$ distribution, $N = 16$: as $N \to \infty$ the initial density of $A$ approaches a Gaussian with variance of order $O(N^{1-2\alpha_A})$.

Figure 4: Approaching the large head limit $H \to \infty$ in the early portion of SGD dynamics for a vision transformer trained on CIFAR-5M with $(L, N) = (2, 4)$ and $(\gamma_0, \beta, \alpha_A) = (0.05, 4, \frac{1}{2})$, with losses averaged over 10 random initializations (colored error bars are standard deviations). (a) Training dynamics varying $H$: as $H$ increases, the loss and the variability over random initial seeds decrease. (b) Convergence to the $H \to \infty$ limit: the mean square difference between output logits of $H$-head models and a proxy for the infinite-head model on a held-out batch of test examples. Following prior works, our proxy for the limit is the ensemble-averaged outputs of the widest models [16, 11].
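As a concrete recipe for the head statistics shown in Figures 2b and 3, the sketch below estimates the across-head variance of the pre-attention entries and the head-averaged kernel of Eq. (9) from stored activations. The array layouts are assumptions for illustration; Result 2 predicts the first quantity retains a nontrivial limit as $H \to \infty$ at fixed $N$, while the second becomes deterministic.

```python
import jax.numpy as jnp

def attention_variance_across_heads(A):
    """Variance of pre-attention entries over heads (cf. Figure 2b and Figure 3).

    A: pre-attention variables of shape (batch, heads, seq, seq).
    Returns the across-head variance for every (sample, s, s') entry."""
    return A.var(axis=1)

def head_averaged_value_kernel(v_sigma):
    """Head-averaged kernel of Eq. (9):
    V^{l,sigma}_{ss'}(x, x') = (1/(N H)) sum_h v^{l,sigma}_{hs}(x) . v^{l,sigma}_{hs'}(x').

    v_sigma: per-head attention-block outputs of shape (batch, heads, seq, N)."""
    _, H, _, N = v_sigma.shape
    return jnp.einsum("ahsn,bhtn->asbt", v_sigma, v_sigma) / (N * H)
```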
This effect has been observed in µP networks in many settings [16], and a theoretical model was provided in recent work which argues it arises from low-rank effects in the finite-$H$ kernels [39].

3.4 Infinite Depth Limits

We next describe the infinite depth limits, which depend on the choice of $\alpha_L$. Below we informally describe the main finding, which again uses a DMFT formalism and is based on analyses of infinite depth networks in recent works by Bordelon et al. [11] and Yang et al. [12].

Result 3 (Infinite Depth Limit) (Informal) The training dynamics for $H, L \to \infty$ with $L^{-\alpha_L}$ branch scaling and $\alpha_L \in [\frac{1}{2}, 1]$ are described by a differential equation for the residual stream variables $h_s(\tau, x, t)$ in layer time $\tau = \lim_{L\to\infty} \ell/L$,
$$h_s(\tau, x, t) = \beta_0\, \delta_{\alpha_L, \frac{1}{2}} \int_0^\tau du_s(\tau', x, t) + \eta_0\gamma_0\beta_0^2 \sum_{t' \le t} \sum_{s'} \int dx' \int_0^\tau d\tau'\; C_{ss'}(\tau', x, x', t, t')\, g_{s'}(\tau', x', t'), \tag{10}$$
where the Brownian motion term $du_s(\tau, x, t)$ survives in the limit only if $\alpha_L = \frac{1}{2}$ and has covariance
$$\langle du_s(\tau, x, t)\, du_{s'}(\tau', x', t')\rangle = \delta(\tau - \tau')\, d\tau\, d\tau' \left[\Phi_{ss'}(\tau, x, x', t, t') + V^\sigma_{ss'}(\tau, x, x', t, t')\right], \tag{11}$$
and the deterministic kernel $C_{ss'}(\tau, x, x', t, t')$ can be expressed in terms of head-averaged kernels and response functions. The weights inside each hidden MHSA layer and each MLP layer are frozen in the $L \to \infty$ limit unless $\alpha_L = 1$. All response functions are suppressed at $L \to \infty$ unless $\alpha_L = 1$.

Below we provide a few short comments about this result. The proof and full DMFT are provided in Appendix E.4.

1. At initialization $t = 0$, the only term which contributes to the residual stream layer dynamics is the integrated Brownian motion $\int_0^\tau du(\tau')$, which survives at infinite depth for $\alpha_L = \frac{1}{2}$. For $\alpha_L = 1$ this term disappears in the limit. The structure of $C(\tau)$ is also modified by additional response functions at $\alpha_L = \frac{1}{2}$ [11], which we show disappear for $\alpha_L = 1$.
2. The weights within residual blocks (including the MHSA block) can be treated as completely frozen for $\alpha_L < 1$ in the $L \to \infty$ limit, which leads to the simplified statistical description of the preactivations in those layers. However, the residual stream variables $h(\tau)$ do still obtain $\Theta_L(1)$ updates. At $\alpha_L = 1$ the weights in the MHSA blocks evolve by $\Theta_L(1)$, causing additional feature evolution in the model.
3. A consequence of our large $N$ and large $L$ results is that the $N, L \to \infty$ limit with $\alpha_L = \frac{1}{2}$ (the parameterization studied by [11, 12]) would lead to $A^\ell_{h\,ss'}(x, t) = 0$ for all time $t$. Thus the MHSA blocks would only involve average-pooling operations over the spatial indices, despite the residual stream kernels $H^\ell$ updating from feature learning.

Figure 5: Depth scaling in a vision transformer on CIFAR-5M with $\alpha_L \in \{\frac{1}{2}, 1\}$. (a) Key/query weight changes $|W_{q,k}(t) - W_{q,k}(0)|$: the key and query weights move by $1/\sqrt{L}$ for $\alpha_L = \frac{1}{2}$. (b) Depth scaling for $\alpha_L \in \{\frac{1}{2}, 1\}$: compute scaling laws for models at fixed width $N, H$ and varying depth $L$. At large $L$, the $\alpha_L = 1$ (dashed) models perform better at fixed compute.

First, we note in Figure 5 that the weights within each attention block freeze as $L \to \infty$ in the $\alpha_L = \frac{1}{2}$ case but move at a constant scale for $\alpha_L = 1$. As a consequence, the loss at large $L$ can be lower in the $\alpha_L = 1$ parameterization. We can see some numerical evidence of the first of these effects in Figure 6 (a)-(b), where initial training at large $L$ is slower than for the base model and the initial kernel appears quite different for $L = 4$ and $L = 64$.
The initial kernel will decrease in scale as $L \to \infty$ for $\alpha_L = 1$, since the preactivation vectors lose variance as we discuss in Appendix E.4, resulting in slower initial training. However, we note that the final learned feature kernels are quite similar after enough training. In summary, our results indicate that the $\alpha_L = 1$ parameterization is the one that allows attention layers to actually be learned in the limit $L \to \infty$, but that this parameterization leads to a less structured kernel at initialization.

4 Experiments in Realistic Settings

In practice, large-scale neural networks do not generally operate close to their limit. Given the costs of training large networks, one would ideally operate in a regime where there is a guarantee of consistent improvements with respect to model scale. In pursuit of this goal, we apply the theoretical findings of this paper to training language models on a larger natural language dataset: a transformer with causal attention blocks trained on the C4 dataset [43] with the Adam optimizer. As mentioned in Section 2.2, while our exact theoretical description of these infinite limits focuses on SGD, we can implement an appropriate scaling for Adam which preserves the scale of internal feature updates. This allows us to investigate realistic training dynamics of our LLM as we take the $N, L, H \to \infty$ limits. Training details are provided in Appendix F.

Figure 6: Initial and final representations converge as model scale increases after one pass of training on the full CIFAR-5M with SGD+momentum. The base model is $(N, H, L) = (16, 16, 4)$ with $(\alpha_A, \alpha_L, \beta_0, \gamma_0) = (1, 1, 4, 0.1)$. (a) Training dynamics: the test loss dynamics for one pass through CIFAR-5M. The dynamics are very similar across different head counts $H$, but the early dynamics change at large depth $L$, consistent with our theory. (b) Initial and final residual stream pooled kernels: the initial and final feature kernels after spatial pooling at the last layer of the residual stream. The initial kernel at large $L$ is quite different for $\alpha_L = 1$ due to suppression of Brownian motion on the forward pass, which we explain in Section 3.4. (c) Spatial kernels for a single sample: the residual stream kernel across pairs of spatial positions for a single randomly chosen input sample. (d) Attention distributions before and after training: the distribution of attention entries across heads at a fixed pair of spatial locations and data point. The initial variance of $A$ decreases for $\alpha_A = 1$ but the update is roughly consistent across $N$. For $\alpha_A = \frac{1}{2}$ both the initial and final distributions of $A_h$ are consistent across $N$.

In Figure 7 (a), we sweep over each of the model dimensions independently for each parameterization $\alpha_A \in \{1, \frac{1}{2}\}$ on the left and right respectively. For fixed $N$ and $L$, scaling $H$ provides a similar increase in performance in both parameterizations and appears to start converging to a final loss around 5, with a slight benefit to $\alpha_A = \frac{1}{2}$. For fixed $H$ and $L$, scaling $N$ provides an increase in performance similar to scaling heads when $\alpha_A = 1$, but a substantial increase when $\alpha_A = \frac{1}{2}$. This is in line with our predictions in Section 3.2 about the benefits of diversity across attention heads. Next, for fixed $N$ and $H$, scaling $L$ provides little to no benefit in either parameterization, as predicted in Section 3.4.
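The Adam scaling referenced above (Section 2.2 and Table 1) amounts to a simple rule for the global learning rate as a function of $(N, H, L, \alpha_L)$. The helper below is a minimal sketch of that rule, not the training script used for these runs; the base rate $\eta_0$ and the example values are hypothetical.

```python
def scaled_global_lr(eta0, N, H, L, alpha_L, optimizer):
    """Global learning rate scalings from Table 1.

    SGD:  eta0 * N * H * L**(2*alpha_L - 1)
    Adam: eta0 / (sqrt(N*H) * L**(1 - alpha_L))
    """
    if optimizer == "sgd":
        return eta0 * N * H * L ** (2 * alpha_L - 1)
    if optimizer == "adam":
        return eta0 * (N * H) ** -0.5 * L ** (alpha_L - 1)
    raise ValueError(f"unknown optimizer: {optimizer}")

# Hypothetical example: base C4 model (N, H, L) = (8, 8, 4) vs. a deeper variant.
# At alpha_L = 1 the Adam rate is depth-independent, so the two rates coincide.
lr_base = scaled_global_lr(1e-2, N=8, H=8, L=4, alpha_L=1.0, optimizer="adam")
lr_deep = scaled_global_lr(1e-2, N=8, H=8, L=64, alpha_L=1.0, optimizer="adam")
```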
Finally, we inspect the sample and spatial residual stream kernels of these models before and after training and find that the kernels are nearly identical for both $\alpha_A$, except for a slight difference at large $N$. Furthermore, they are extremely similar for large $N$ and large $H$. Taken together, these results suggest that scaling different model dimensions does indeed have different effects on training dynamics and final performance. This provides groundwork for future large-scale experiments systematically investigating their trade-offs, thereby identifying compute-optimal scalings of realistic architectures in parameterizations with well-defined limits.

Figure 7: Training dynamics and initial/final representations of decoder-only language models trained on C4 converge with increasing model scale. The base model has $(N, H, L) = (8, 8, 4)$ and $(\alpha_L, \beta_0, \gamma_0) = (1, 4, 0.25)$, with $\alpha_A \in \{1, \frac{1}{2}\}$. (a) Training dynamics of the LLM on C4: train loss dynamics after 10000 steps on C4 using the Adam optimizer. The dynamics improve consistently when scaling $H$ for both values of $\alpha_A$, with a slight benefit to $\alpha_A = \frac{1}{2}$. Scaling $N$ reveals a significant advantage to setting $\alpha_A = \frac{1}{2}$. Scaling $L$ provides little improvement for either parameterization of $\alpha_A$. (b) Kernels for the final token across samples: initial and final residual stream kernels for the Base, $H = 128$, $N = 128$, and $L = 64$ models. The first row is at initialization; the second and third rows are after training with $\alpha_A \in \{1, \frac{1}{2}\}$ respectively. (c) Kernels for tokens within a single sample: initial and final feature kernels across pairs of tokens for a single randomly chosen input sample. Note that both types of kernels are nearly identical across $\alpha_A$, except for a slight difference at large $N$.

5 Discussion

This paper provided an analysis of the infinite head, depth, and key/query dimension limits of transformer training in the feature learning regime. We showed that feature learning in µP multi-head transformers in the limit of $N \to \infty$ collapses to single-head self-attention. At finite $N$ and infinite heads $H \to \infty$, we showed that there is an alternative limit which maintains a distribution over attention heads. We discussed two different large depth limits of transformer training that reduce to differential equations in the residual layer time $\tau$. The depth scaling that maintains feature learning within all MHSA blocks ($\alpha_L = 1$) causes the initial kernel to lose structure from the initialization as $L \to \infty$ but allows learning of the self-attention variables, whereas the depth scaling that preserves structure from initialization ($\alpha_L = \frac{1}{2}$) leads to static layers.

Limitations and Future Directions Currently, exact theoretical analysis of the limit is focused on SGD (and can be easily extended to SGD+momentum [15]), while Adam is currently only reasoned about with rough scaling arguments rather than an exact theoretical description of the limit.
Since Adam is most commonly used to train transformers, a theory of the limiting dynamics of Adam in Transformers would be an important future extension. In addition, while we provide an exact asymptotic description of network training, the limiting equations are compute intensive for realistic settings which is why we focus our empirical investigations on training large width networks in the appropriate parameterizations. Lastly our techniques assume that the number of training steps is fixed as the scaling parameters of interest (N, H, L) are taken to infinity. However, it would be important to understand learning dynamics in the regime where model size and training times are chosen to balance a compute optimal tradeoff (or perhaps even training longer than compute optimal) [8, 39, 44]. In this regime, harmful finite model-size effects become significant and comparable to the finite training horizon [10, 16, 17, 39]. Thus stress testing the ideas in this work at larger scales and longer training runs would be an important future direction of research into scaling transformer models. Acknowledgements and Disclosure of Funding BB would like to thank Alex Atanasov, Jacob Zavatone-Veth, Lorenzo Noci, Mufan Bill Li, Boris Hanin, Alex Damian, Eshaan Nichani for inspiring conversations. We would also like to thank Alex Atanasov and Jacob Zavatone-Veth for useful comments on an earlier version of this manuscript. BB is supported by a Google Ph D fellowship. HC was supported by the GFSD Fellowship, Harvard GSAS Prize Fellowship, and Harvard James Mills Peirce Fellowship. CP was supported by NSF Award DMS2134157 and NSF CAREER Award IIS2239780. CP is further supported by a Sloan Research Fellowship. This work has been made possible in part by a gift from the Chan Zuckerberg Initiative Foundation to establish the Kempner Institute for the Study of Natural and Artificial Intelligence. The computations in this paper were run on the FASRC Cannon cluster supported by the FAS Division of Science Research Computing Group at Harvard University. [1] Ze Liu, Yutong Lin, Yue Cao, Han Hu, Yixuan Wei, Zheng Zhang, Stephen Lin, and Baining Guo. Swin transformer: Hierarchical vision transformer using shifted windows. In Proceedings of the IEEE/CVF international conference on computer vision, pages 10012 10022, 2021. [2] Kai Han, Yunhe Wang, Hanting Chen, Xinghao Chen, Jianyuan Guo, Zhenhua Liu, Yehui Tang, An Xiao, Chunjing Xu, Yixing Xu, et al. A survey on vision transformer. IEEE transactions on pattern analysis and machine intelligence, 45(1):87 110, 2022. [3] Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, et al. An image is worth 16x16 words: Transformers for image recognition at scale. ar Xiv preprint ar Xiv:2010.11929, 2020. [4] Mostafa Dehghani, Alexey Gritsenko, Anurag Arnab, Matthias Minderer, and Yi Tay. Scenic: A jax library for computer vision research and beyond. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), pages 21393 21398, 2022. [5] Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Łukasz Kaiser, and Illia Polosukhin. Attention is all you need. Advances in neural information processing systems, 30, 2017. [6] Jared Kaplan, Sam Mc Candlish, Tom Henighan, Tom B Brown, Benjamin Chess, Rewon Child, Scott Gray, Alec Radford, Jeffrey Wu, and Dario Amodei. 
Scaling laws for neural language models. ar Xiv preprint ar Xiv:2001.08361, 2020. [7] Tom Brown, Benjamin Mann, Nick Ryder, Melanie Subbiah, Jared D Kaplan, Prafulla Dhariwal, Arvind Neelakantan, Pranav Shyam, Girish Sastry, Amanda Askell, et al. Language models are few-shot learners. Advances in neural information processing systems, 33:1877 1901, 2020. [8] Jordan Hoffmann, Sebastian Borgeaud, Arthur Mensch, Elena Buchatskaya, Trevor Cai, Eliza Rutherford, Diego de Las Casas, Lisa Anne Hendricks, Johannes Welbl, Aidan Clark, et al. Training compute-optimal large language models. ar Xiv preprint ar Xiv:2203.15556, 2022. [9] Josh Achiam, Steven Adler, Sandhini Agarwal, Lama Ahmad, Ilge Akkaya, Florencia Leoni Aleman, Diogo Almeida, Janko Altenschmidt, Sam Altman, Shyamal Anadkat, et al. Gpt-4 technical report. ar Xiv preprint ar Xiv:2303.08774, 2023. [10] Ge Yang, Edward Hu, Igor Babuschkin, Szymon Sidor, Xiaodong Liu, David Farhi, Nick Ryder, Jakub Pachocki, Weizhu Chen, and Jianfeng Gao. Tuning large neural networks via zero-shot hyperparameter transfer. Advances in Neural Information Processing Systems, 34:17084 17097, 2021. [11] Blake Bordelon, Lorenzo Noci, Mufan Bill Li, Boris Hanin, and Cengiz Pehlevan. Depthwise hyperparameter transfer in residual networks: Dynamics and scaling limit. In The Twelfth International Conference on Learning Representations, 2024. URL https://openreview. net/forum?id=KZJehv RKGD. [12] Greg Yang, Dingli Yu, Chen Zhu, and Soufiane Hayou. Feature learning in infinite depth neural networks. In The Twelfth International Conference on Learning Representations, 2023. [13] Song Mei, Theodor Misiakiewicz, and Andrea Montanari. Mean-field theory of two-layers neural networks: dimension-free bounds and kernel limit. In Conference on Learning Theory, pages 2388 2464. PMLR, 2019. [14] Greg Yang and Edward J Hu. Tensor programs iv: Feature learning in infinite-width neural networks. In International Conference on Machine Learning, pages 11727 11737. PMLR, 2021. [15] Blake Bordelon and Cengiz Pehlevan. Self-consistent dynamical field theory of kernel evolution in wide neural networks. Advances in Neural Information Processing Systems, 35:32240 32256, 2022. [16] Nikhil Vyas, Alexander Atanasov, Blake Bordelon, Depen Morwani, Sabarish Sainathan, and Cengiz Pehlevan. Feature-learning networks are consistent across widths at realistic scales, 2023. [17] Blake Bordelon and Cengiz Pehlevan. Dynamics of finite width kernel and prediction fluctuations in mean field neural networks. ar Xiv preprint ar Xiv:2304.03408, 2023. [18] Jiri Hron, Yasaman Bahri, Jascha Sohl-Dickstein, and Roman Novak. Infinite attention: Nngp and ntk for deep attention networks. In International Conference on Machine Learning, pages 4376 4386. PMLR, 2020. [19] Emily Dinan, Sho Yaida, and Susan Zhang. Effective theory of transformers at initialization, 2023. [20] Yihe Dong, Jean-Baptiste Cordonnier, and Andreas Loukas. Attention is not all you need: Pure attention loses rank doubly exponentially with depth. In International Conference on Machine Learning, pages 2793 2803. PMLR, 2021. [21] Lorenzo Noci, Sotiris Anagnostidis, Luca Biggio, Antonio Orvieto, Sidak Pal Singh, and Aurelien Lucchi. Signal propagation in transformers: Theoretical perspectives and the role of rank collapse. Advances in Neural Information Processing Systems, 35:27198 27211, 2022. [22] Bobby He and Thomas Hofmann. Simplifying transformer blocks. ar Xiv preprint ar Xiv:2311.01906, 2023. 
[23] Aditya Cowsik, Tamra Nebabu, Xiao-Liang Qi, and Surya Ganguli. Geometric dynamics of signal propagation predict trainability of transformers, 2024. [24] Lorenzo Noci, Chuning Li, Mufan Li, Bobby He, Thomas Hofmann, Chris J Maddison, and Dan Roy. The shaped transformer: Attention models in the infinite depth-and-width limit. Advances in Neural Information Processing Systems, 36, 2024. [25] Soufiane Hayou. On the infinite-depth limit of finite-width neural networks. Transactions on Machine Learning Research, 2023. ISSN 2835-8856. URL https://openreview.net/ forum?id=Rb Ls Yz1Az9. [26] Nicola Muca Cirone, Maud Lemercier, and Cristopher Salvi. Neural signature kernels as infinite-width-depth-limits of controlled resnets. ar Xiv preprint ar Xiv:2303.17671, 2023. [27] Lénaïc Chizat and Praneeth Netrapalli. The feature speed formula: a flexible approach to scale hyper-parameters of deep neural networks, 2024. URL https://arxiv.org/abs/2311. 18718. [28] Jeremy Bernstein, Arash Vahdat, Yisong Yue, and Ming-Yu Liu. On the distance between two neural networks and the stability of learning. Advances in Neural Information Processing Systems, 33:21370 21381, 2020. [29] Greg Yang, James B Simon, and Jeremy Bernstein. A spectral condition for feature learning. ar Xiv preprint ar Xiv:2310.17813, 2023. [30] Jeremy Bernstein, Chris Mingard, Kevin Huang, Navid Azizan, and Yisong Yue. Automatic gradient descent: Deep learning without hyperparameters. ar Xiv preprint ar Xiv:2304.05187, 2023. [31] Lenaic Chizat and Francis Bach. On the global convergence of gradient descent for overparameterized models using optimal transport. Advances in neural information processing systems, 31, 2018. [32] Arthur Jacot, Franck Gabriel, and Clément Hongler. Neural tangent kernel: Convergence and generalization in neural networks. Advances in neural information processing systems, 31, 2018. [33] Lenaic Chizat, Edouard Oyallon, and Francis Bach. On lazy training in differentiable programming. Advances in neural information processing systems, 32, 2019. [34] Haim Sompolinsky and Annette Zippelius. Dynamic theory of the spin-glass phase. Physical Review Letters, 47(5):359, 1981. [35] Moritz Helias and David Dahmen. Statistical field theory for neural networks, volume 970. Springer, 2020. [36] Stefano Sarao Mannelli, Florent Krzakala, Pierfrancesco Urbani, and Lenka Zdeborova. Passed & spurious: Descent algorithms and local minima in spiked matrix-tensor models. In international conference on machine learning, pages 4333 4342. PMLR, 2019. [37] Francesca Mignacco, Florent Krzakala, Pierfrancesco Urbani, and Lenka Zdeborová. Dynamical mean-field theory for stochastic gradient descent in gaussian mixture classification. Advances in Neural Information Processing Systems, 33:9540 9550, 2020. [38] Cedric Gerbelot, Emanuele Troiani, Francesca Mignacco, Florent Krzakala, and Lenka Zdeborova. Rigorous dynamical mean field theory for stochastic gradient descent methods. ar Xiv preprint ar Xiv:2210.06591, 2022. [39] Blake Bordelon, Alexander Atanasov, and Cengiz Pehlevan. A dynamical model of neural scaling laws, 2024. [40] Paul Cecil Martin, ED Siggia, and HA Rose. Statistical dynamics of classical systems. Physical Review A, 8(1):423, 1973. [41] Blake Bordelon and Cengiz Pehlevan. The influence of learning rule on representation dynamics in wide neural networks. ar Xiv preprint ar Xiv:2210.02157, 2022. [42] Preetum Nakkiran, Behnam Neyshabur, and Hanie Sedghi. 
The deep bootstrap framework: Good online learners are good offline generalizers. arXiv preprint arXiv:2010.08127, 2020.
[43] Colin Raffel, Noam Shazeer, Adam Roberts, Katherine Lee, Sharan Narang, Michael Matena, Yanqi Zhou, Wei Li, and Peter J Liu. Exploring the limits of transfer learning with a unified text-to-text transformer. Journal of Machine Learning Research, 21(140):1-67, 2020.
[44] Ibrahim M Alabdulmohsin, Xiaohua Zhai, Alexander Kolesnikov, and Lucas Beyer. Getting ViT in shape: Scaling laws for compute-optimal model design. Advances in Neural Information Processing Systems, 36, 2024.
[45] Etai Littwin and Greg Yang. Adaptive optimization in the ∞-width limit. In The Eleventh International Conference on Learning Representations, 2022.
[46] Mufan Li, Mihai Nica, and Dan Roy. The neural covariance SDE: Shaped infinite depth-and-width networks at initialization. Advances in Neural Information Processing Systems, 35:10795-10808, 2022.

A Additional Figures

Figure 8: One-pass training on CIFAR-5M with vision transformers in the setting of Figure 6.

Figure 9: Examples of initial and learned kernels in the final residual stream layer with various extrapolations of a base vision transformer model with $(H, N, L) = (16, 16, 4)$ trained on CIFAR-5M.

Figure 10: Spatial kernels for a single test point before and after training across $H, N, L$ values.

Figure 11: Early training dynamics on CIFAR-5M in a vision transformer with different dimension-per-head $N$, with heads fixed at $H = 4$, for $\alpha_A \in \{1, \frac{1}{2}\}$. (a) $N$ scaling with $\alpha_A = 1$. (b) $N$ scaling with $\alpha_A = \frac{1}{2}$.

Figure 12: Performance of the language models trained on C4 in main text Figure 7(a) as a function of compute, estimated as FLOPs = 6 × Params. The base model has size $(N, H, L) = (8, 8, 4)$ and we examine scaling up $N$, $H$, $L$ with either $\alpha_A = \frac{1}{2}$ or $\alpha_A = 1$. The $\alpha_A = 1$ models perform better at fixed compute for either $N$ or $H$ scaling. Increasing $L$ does not significantly increase compute in this regime since the embedding and decoding layers contribute most of the parameters.

B Implementations for Vision and Causal Language Modeling Transformers

We provide example FLAX implementations of the vision transformer and the causal language model. We start by defining a fixed layer-norm operation.

```python
from flax import linen as nn
import jax.numpy as jnp


class LN_Fixed(nn.Module):
    """Layer norm without trainable scale or offset."""
    eps: jnp.float32 = 1.0e-6

    @nn.compact
    def __call__(self, x):
        mean = jnp.mean(x, axis=-1)   # mean of x over features
        var = jnp.var(x, axis=-1)     # variance of x over features
        out = (x - mean[:, :, jnp.newaxis]) / jnp.sqrt(var[:, :, jnp.newaxis] + self.eps)
        return out
```

The MHSA layer is implemented as follows, where `scale_exp` represents $\alpha_A$.
1 # MHSA attention layer 2 from einops import rearrange 4 class Attention(nn.Module): 5 """ Multi -head Self -Attention Layer """ 6 scale_exp: jnp.float32 8 heads: int 10 def setup(self): 12 self.c = 1.5 - self.scale_exp # exponent for the scale factor 13 kif_qk = nn.initializers.normal(stddev = self.dim **( self.c - 0.5) ) # possible scaling with N 14 kif_v = nn.initializers .normal(stddev = 1.0 ) # O_N (1) entries 15 # computes key , query , value 16 self.qk_layer = nn.Dense(features = 2 * self.heads * self.dim , kernel_init = kif_qk , use_bias = False) 17 self.v_layer = nn.Dense(features = self.heads * self.dim , kernel_init = kif_v , use_bias = False) 18 self.out_layer = nn.Dense(features = self.heads * self.dim , kernel_init = kif_v , use_bias = False) 21 def __call__(self ,inputs): 23 qk = self.qk_layer(inputs) / self.heads **(0.5) / self.dim **( self.c) 24 qk = rearrange( qk , b l (h d) -> b h l d , h = self.heads) # (batch , heads , loc , d ) 25 q,k = jnp.split(qk , 2, axis = -1) # gives q, k each of shape ( batch , heads , loc , d ) 27 v = self.v_layer(inputs) / jnp.sqrt( inputs.shape [-1] ) 28 v = rearrange(v, b l (h d) -> b h l d , h = self.heads) 29 A = self.dim**(- self.scale_exp) * jnp.einsum( ijkl ,ijml ->ijkm , q, k) # batch x heads x loc x loc 30 sigma_A = softmax( A, axis=-1 ) 31 out = jnp.einsum( ijkl ,ijlm ->ijkm , sigma_A , v) # (batch , head , loc , d) 32 out = rearrange(out , b h l d -> b l (h d) ) 33 out = self.out_layer(out) / jnp.sqrt( out.shape [-1] ) 34 return out The two layer MLP block is implemented as the following with ϕ = gelu nonlinearity. 1 class MLP_Block(nn.Module): 2 """ Two Layer MLP Block """ 3 features: int 5 @nn.compact 6 def __call__(self ,x): 7 N = self.features 8 kif = nn.initializers .normal(stddev = 1.0) # O_N (1) entries 9 h = nn.Dense(features = N, kernel_init = kif , use_bias = False )(x) / jnp.sqrt(N) 10 h = nn.gelu(h) 11 h = nn.Dense(features = N, kernel_init = kif , use_bias = False )(h) / jnp.sqrt(N) 12 return h We also allow for a trainable positional encoding matrix. 2 class Positional Encoding (nn.Module): 3 """ Trainable Positional Encoding """ 4 d_model : int # Hidden dimensionality of the input. 5 max_len : int # Maximum length of a sequence to expect. 6 scale: jnp.float32 # scale parameter for initialization 8 def setup(self): 9 # Create matrix of [Seq Len , Hidden Dim] representing the positional encoding for max_len inputs 10 self.pos_embedding = self.param( pos_embedding , 11 nn. initializers .normal(stddev = self.scale), 12 (1, 1+ self.max_len , self. d_model)) 14 def __call__(self , x, train=True): 15 B,T,_ = x.shape 16 x = x + self.pos_embedding [:,:T] / self.scale 17 return x Each residual block is implemented as the following. Below we show the αL = 1 implementation. 1 # Residual Block 2 class Resid Block(nn.Module): 5 heads: int 6 features: int 8 scale_exp: jnp.float32 = 1.0 9 beta: jnp.float32 = 4.0 11 @nn.compact 12 def __call__(self ,x): 13 h = LN_Fixed ()(x) 14 h = Attention(dim = self.dim , scale_exp = self.scale_exp , heads = self.heads)( h ) 15 x = x + self.beta / self.L * h 16 h = LN_Fixed ()(x) 17 h = MLP_Block(features = self.features)(h) 18 x = x + self.beta / self.L * h 19 return x Our vision transformer model consists of an embedding layer which is applied to each patch, a positional encoding layer, L residual layers each containing a MHSA and MLP block, a spatial pooling operation, and a readout. 
2 class VIT(nn.Module): 4 "simple VIT model with " 6 heads: int 7 depth: int 8 patch_size: int 9 scale_exp: jnp.float32 = 1.0 10 adam_scale: int = 0.0 11 beta: jnp.float32 = 4.0 13 @nn.compact 14 def __call__(self , x): 15 d_model = self.heads * self.dim 16 L = self.depth 19 # patchify images 20 x = rearrange(x, b (w p1) (h p2) c -> b (w h) (p1 p2 c) , p1 = self.patch_size , p2 = self.patch_size) # (batch , loc , patch_ch_dim ) 22 kif_first= nn.initializers .normal(stddev = d_model **( -0.5* self .adam_scale) * (L/self.beta)**(0.5 * (1.0 - self.adam_scale)) ) # O_N (1) entries 23 kif = nn.initializers .normal( stddev = 1.0 ) # O_N (1) entries 24 kif_last = nn.initializers .normal(stddev = (L/self.beta)**(0.5 * (1-self.adam_scale) ) ) 26 # read -in weights 27 x = (L/self.beta)**( -0.5 * (1.0 - self.adam_scale))*d_model **(0.5 * self.adam_scale) * nn.Dense(features = N, kernel_init = kif_first , use_bias = False)(x) / jnp.sqrt( D * self.patch_size **2 29 # positional encoding 30 x = Positional Encoding (d_model = d_model , max_len = (32// self. patch_size)**2, scale = d_model **( -0.5* self.adam_scale)*(L/self. beta)**(0.5 * (1.0self.adam_scale)))(x) 32 # residual stream with pre -LN 33 for l in range(self.depth): 34 x = Resid Block(dim = self.dim , heads = self.heads , scale_exp=self.scale_exp , features = d_model , beta=self.beta , L = L)(x) 36 # last norm layer 37 x = LN_Fixed ()(x) 38 # pool over spatial dimension 39 x = x.mean(axis = 1) # (batch , d_model) 40 x = (L/self.beta)**( -0.5*(1 - self.adam_scale)) * nn.Dense( features = 10, use_bias = False , kernel_init = kif_last)(x) / d_model **(1.0 -0.5* self.adam_scale) # for mean field scaling 41 return x For the causal decoder only model, we need to modify the Attention layer and also prevent pooling over spatial indices before the readout. 2 class Causal_Attention (nn.Module): 4 scale_exp: jnp.float32 6 heads: int 7 qk_ln: bool = True 9 def setup(self): 11 self.c = 1.5 - self.scale_exp # exponent for the scale factor 12 kif_qk = nn.initializers.normal(stddev = self.dim **( self.c - 0.5) ) # possibly needs to be scaled with N 13 kif_v = nn.initializers .normal(stddev = 1.0 ) # O_N (1) entries 14 # computes key , query , value 15 self.qk_layer = nn.Dense(features = 2 * self.heads * self.dim , kernel_init = kif_qk , use_bias = False) 16 self.v_layer = nn.Dense(features = self.heads * self.dim , kernel_init = kif_v , use_bias = False) 17 self.out_layer = nn.Dense(features = self.heads * self.dim , kernel_init = kif_v , use_bias = False) 20 def __call__(self ,inputs): 22 qk = self.qk_layer(inputs) / self.heads **(0.5) / self.dim **( self.c) # (batch , loc , 3*h*d) 23 qk = rearrange( qk , b l (h d) -> b h l d , h = self.heads) # (batch , heads , loc , d ) 24 q,k = jnp.split(qk , 2, axis = -1) # gives q, k each of shape ( batch , heads , loc , d ) 26 v = self.v_layer(inputs) / jnp.sqrt( inputs.shape [-1] ) 27 v = rearrange(v, b l (h d) -> b h l d , h = self.heads) 29 A = 1.0/ self.dim **( self.scale_exp) * jnp.einsum( ijkl ,ijml -> ijkm , q, k) # batch x heads x loc x loc 30 exp_A = jnp.einsum( ijkl ,kl ->ijkl , jnp.exp(A), jnp.tril(jnp. 
ones ((v.shape [2], v.shape [2])))) 31 phi_A = exp_A / exp_A.sum(axis = -1)[:,:,:,jnp.newaxis] 33 out = jnp.einsum( ijkl ,ijlm ->ijkm , phi_A , v) # (batch , head , loc , d) 34 out = rearrange(out , b h l d -> b l (h d) ) 35 out = self.out_layer(out) / jnp.sqrt( out.shape [-1] ) 36 return out 39 class LM_Transformer(nn.Module): 40 """A simple Decoder only transformer """ 42 dim: int 43 heads: int 44 depth: int 45 scale_exp: jnp.float32 46 adam_scale: int 47 beta: jnp.float32 48 VOCAB_SIZE: int 50 @nn.compact 51 def __call__(self , x, train = True): 52 d_model = self.heads * self.dim 53 L = self.depth 54 kif_first = nn.initializers.normal(stddev = d_model **( -0.5* self.adam_scale) * (L/self.beta)**(0.5 * (1-self.adam_scale) ) ) # O(1) entries 55 kif0 = nn.initializers.normal(stddev = 0.0 ) 56 kif = nn.initializers .normal(stddev = 1.0) # O(1) entries 57 kif_last = nn.initializers .normal(stddev = (L/self.beta)**(0.5 * (1-self.adam_scale)) * d_model **( -0.5* self.adam_scale) ) 59 # embed the batch x sequence integers to 60 x = (L/self.beta)**( -0.5 * (1-self.adam_scale) )* d_model **(0.5 * self.adam_scale) * nn.Embed(self.VOCAB_SIZE , d_model , embedding_init = kif_first)(x) # batch x seq len x N 62 x = Positional Encoding (d_model = d_model , scale = d_model **( -0.5* self.adam_scale) * (L/self.beta)**(0.5 *(1self.adam_scale )) )(x) 64 for l in range(self.depth): 65 h = LN_Fixed ()(x) 66 x = x + self.beta/L * Causal_Attention (dim = self.dim , scale_exp = self.scale_exp , heads = self.heads)(h) 67 h = LN_Fixed ()(x) 68 x = x + self.beta/L * MLP_Block(features = d_model)(h) 70 x = LN_Fixed ()(x) 71 x = (L/self.beta)**( -0.5 * (1 - self.adam_scale ) ) * nn.Dense (features = self.VOCAB_SIZE , use_bias = True , kernel_init = kif0)( x) / d_model **(1.0 -0.5* self.adam_scale) # for mean field scaling 72 return x C Simple Heuristic Scaling Analysis In this section, we heuristically work out the simple scaling analysis to justify the set of parameterizations and learning rates we consider. More detailed theoretical analysis for the limit of SGD training is provided in Appendix E where we exactly characterize the N , H and L limits. We consider taking heads H, inner dimension N and depth L to infinity separately and attempt to control the scale of gradients and updates. C.1 Learning Rate Scalings We show that the correct learning rate scaling for SGD is η = η0NHL2αL 1. For Adam, the learning rate should be scaled as η = η0N 1/2H 1/2L 1+αL. Optimizer Bulk Parameters LR First Layer Rescale Factor SGD η0NHL2αL 1 L 1 Adam η0N 1/2H 1/2L 1+αL L1 αL Table 2: The learning rates which should be applied to obtain the correct scale of updates for SGD or Adam optimizers. In addition, the weight variance and multiplier for the first layer may need to be rescaled with depth depending on the parameterization and optimizer. C.2 Heuristic Analysis of Feature Changes Under SGD In this section we consider performing a single update on a single example to all weight matrices. δW ℓ Oh 1 L1 αL NH gℓ+1vℓ h (12) where gℓ+1 RNH and vℓ h RN have Θ(1) entries. Thus, computing a perturbation to the forward pass we find δhℓ+1 = δhℓ+ 1 LαL NH gℓ+1vℓ h h vℓ h vℓ h The term in the brackets is Θ(1) and we see that the perturbation from each layer contributes Θ(L 1). As there are L layers, this will give a total change to the final layer h L that is Θ(1). 
For the attention variables, we note that the key/query weight updates take the form

$\delta W^{\ell}_{K_h} \propto \frac{1}{L^{1-\alpha_L} N^{1-\alpha_A}\, NH}\, q^{\ell}_h\, h^{\ell\top}, \qquad \delta W^{\ell}_{Q_h} \propto \frac{1}{L^{1-\alpha_L} N^{1-\alpha_A}\, NH}\, k^{\ell}_h\, h^{\ell\top},$   (14)

where $q^{\ell}_h, k^{\ell}_h \in \mathbb{R}^{N}$ are the query and key for head h and $h^{\ell} \in \mathbb{R}^{NH}$ is the residual-stream preactivation. We can thus compute the changes to the keys and queries due to the changes in their associated weights,

$\delta k^{\ell}_h = \delta W^{\ell}_{K_h} h^{\ell} \propto \frac{1}{L^{1-\alpha_L} N^{1-\alpha_A}\, NH}\, q^{\ell}_h\, (h^{\ell} \cdot h^{\ell}) = \frac{1}{L^{1-\alpha_L} N^{1-\alpha_A}}\, q^{\ell}_h\, H^{\ell}, \qquad \delta q^{\ell}_h = \delta W^{\ell}_{Q_h} h^{\ell} \propto \frac{1}{L^{1-\alpha_L} N^{1-\alpha_A}}\, k^{\ell}_h\, H^{\ell},$   (15)

where $H^{\ell} = \frac{1}{NH} h^{\ell} \cdot h^{\ell} = \Theta(1)$. Combining these changes, we find the following update to the pre-attention variables $A^{\ell}_h = \frac{1}{N^{\alpha_A}}\, k^{\ell}_h \cdot q^{\ell}_h$:

$\delta A^{\ell}_h = \frac{1}{L^{1-\alpha_L} N}\, (q^{\ell}_h \cdot q^{\ell}_h)\, H^{\ell} + \frac{1}{L^{1-\alpha_L} N}\, (k^{\ell}_h \cdot k^{\ell}_h)\, H^{\ell} + \frac{1}{L^{2-2\alpha_L} N^{2-2\alpha_A}}\, A^{\ell}_h\, (H^{\ell})^2 = \Theta(L^{-1+\alpha_L}),$   (16)

since $\frac{1}{N}\, q^{\ell}_h \cdot q^{\ell}_h = \Theta(1)$. This update to the attention variables due to the changes in $W^{\ell}_K, W^{\ell}_Q$ will clearly die out as $L \to \infty$ unless $\alpha_L = 1$.

C.3 Heuristic Analysis of Feature Changes Under Adam

For Adam, the gradient of each individual parameter entry is approximately normalized by its scale [45]. Thus, the learning rate $\eta$ sets the size of the updates. This is why we scale the learning rate as $\eta = \frac{\eta_0}{L^{1-\alpha_L}\sqrt{NH}}$, which gives the same scale of updates to the weights as SGD:

$\delta W^{\ell}_{O_h} \approx \eta\, g^{\ell+1}\, v^{\ell\top}_h = \frac{1}{L^{1-\alpha_L}\sqrt{NH}}\, g^{\ell+1}\, v^{\ell\top}_h.$   (17)

Again computing the correction to the forward pass, we find

$\delta h^{\ell+1} = \delta h^{\ell} + \frac{1}{L^{\alpha_L}}\frac{1}{L^{1-\alpha_L}}\, g^{\ell+1}\left[\frac{1}{NH}\, v^{\ell}_h \cdot v^{\ell}_h\right], \qquad g^{\ell+1} = \Theta(1),$   (18)

so each layer again contributes a $\Theta(L^{-1})$ perturbation. Similarly, Adam generates the same scale of weight updates to the key and query weight matrices,

$\delta W^{\ell}_{K_h} \propto \frac{1}{L^{1-\alpha_L} N^{1-\alpha_A}\, NH}\, q^{\ell}_h\, h^{\ell\top}, \qquad \delta W^{\ell}_{Q_h} \propto \frac{1}{L^{1-\alpha_L} N^{1-\alpha_A}\, NH}\, k^{\ell}_h\, h^{\ell\top}.$   (19)

We can therefore follow an identical argument to identify the scale of the change to the pre-attention variables, $\delta A^{\ell}_h = \Theta(L^{-1+\alpha_L})$.

C.4 What Counts as Feature Learning for Attention Layers?

Any parameterization with $\alpha_A \in [\frac{1}{2}, 1]$ will cause all updates to $A^{\ell}_h$ and to the entries of $h^{\ell+1}$ to be $\Theta_{N,H,L}(1)$ at any finite N. The entries of q and k only move by $\Theta_N(1)$ if $\alpha_A = 1$ (µP scaling). However, we argue that this criterion is not strictly necessary. Rather, feature learning could alternatively be defined in terms of the evolution of macroscopic variables (H, A, f, etc.) rather than of the preactivation or key/query vector entries themselves. Table 3 summarizes two example values of $\alpha_A$ which are of special interest for their $N \to \infty$ limits.

                            Variance of $A(0)$    Update to $A$    Update to $k, q$ Entries
$\alpha_A = 1$ (µP)         $\Theta(N^{-1})$      $\Theta(1)$      $\Theta(1)$
$\alpha_A = \frac{1}{2}$    $\Theta(1)$           $\Theta(1)$      $\Theta(N^{-1/2})$

Table 3: Two interesting choices of scaling for the attention-layer exponent $\alpha_A$ which give approximately constant updates to the attention matrices $A_h$. The µP scaling $\alpha_A = 1$ causes the entries of the key/query vectors to move non-negligibly, but it causes all heads to be identical (and all A = 0) at initialization. Scaling instead with $\alpha_A = \frac{1}{2}$ causes the A variables to be random at initialization while still being updated non-negligibly during training. The choice $\alpha_A = \frac{1}{2}$ thus allows the variance of $A^{\ell}_h$ to remain constant as a function of N while also enabling learning of these variables. We verify these scalings in Figure 13.

[Figure 13, panels (a)-(d): change in k entries and in A entries plotted against N (from $10^2$ to $10^4$) at t = 1 and t = 200, for SGD and Adam.]

Figure 13: The update to (a) key $k_h$ entries and (b) pre-attention variables $A_h$ after t steps of gradient descent for scaling exponents $\alpha_A \in \{1, \frac{1}{2}\}$. At the first step of SGD, the updates to the keys and attention variables are suppressed due to a lack of correlation between $W_O$ and the gradient $\partial f / \partial h$. After training for multiple steps, this correlation increases and non-negligible updates to the attention variables occur. (c)-(d) The same but for the Adam optimizer with our proposed parameterization.
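To make the initialization column of Table 3 concrete, the following short numerical check (our own sketch, not the paper's code) draws i.i.d. Gaussian keys and queries with $\Theta(1)$ entries and confirms that $A = N^{-\alpha_A}\, k \cdot q$ has variance $\Theta(N^{-1})$ for $\alpha_A = 1$ (µP) and variance $\Theta(1)$ for $\alpha_A = \frac{1}{2}$.

import numpy as np

rng = np.random.default_rng(0)
for N in [128, 1024, 8192]:
    k = rng.normal(size=(10_000, N))    # 10k independent key draws with Theta(1) entries
    q = rng.normal(size=(10_000, N))    # 10k independent query draws with Theta(1) entries
    kq = np.einsum("sn,sn->s", k, q)    # k . q for each draw
    # variance of A at initialization for alpha_A = 1 (muP) vs alpha_A = 1/2
    print(N, np.var(kq / N), np.var(kq / np.sqrt(N)))

The first variance shrinks like $1/N$ while the second stays near one, matching the first column of Table 3.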
D DMFT Primer and Simple Examples

D.1 Main Conceptual Idea of the Approach

Dynamical mean field theory is a method that was developed in the physics of spin glasses for dealing with dynamical systems that depend on a fixed source of disorder. The disorder could be random couplings between sites in a spin-glass model [34], random connections between neurons in a random recurrent neural network [35], random data drawn from a distribution [37, 39], or the random initial weights in a deep neural network [15, 11]. In our case, we are interested in the last example, where the feature-learning dynamics of a randomly initialized transformer are a function of the initial weights in each layer. In what follows, we give a primer on the main objects which typically appear in a DMFT analysis (the correlation and response functions) to illustrate the main ideas of the approach.

D.2 Example 1: Linear Dynamics with a GOE Matrix

In this section, we discuss and derive the DMFT equations for the simplest possible example: a linear dynamical system with a Gaussian symmetric coupling matrix. In this example, the DMFT path integral computes something non-trivial about the dynamics induced by a random matrix; in this linear case, it encodes the spectral properties of the matrix.

Consider the dynamics $\frac{d}{dt} h_i(t) = \frac{1}{\sqrt{N}} \sum_{j=1}^{N} W_{ij}\, h_j(t)$, where $W_{ij} = W_{ji}$ is a Gaussian symmetric (GOE) matrix. This matrix is fixed while the state $h(t) \in \mathbb{R}^{N}$ evolves. The path-integral approach tells us that in the $N \to \infty$ limit, every neuron i has identical statistics, given by the stochastic integro-differential equation

$\partial_t h(t) = u(t) + \int_0^t ds\, R(t, s)\, h(s), \qquad u(t) \sim \mathcal{GP}\big(0, C(t, s)\big),$   (20)
$C(t, s) = \langle h(t)\, h(s) \rangle, \qquad R(t, s) = \left\langle \frac{\delta h(t)}{\delta u(s)} \right\rangle,$

where $\langle \cdot \rangle$ denotes an average over the random variable $u(t)$. In this picture, the averages over the noise can also be interpreted as averages over all N neurons in the system, each of which is statistically independent. This stochastic equation can be used to close the evolution equations for the correlation function $C(t, s)$ and the linear response function $R(t, s)$. Two generic features of this path-integral DMFT picture are:

1. All neurons (all variables $h_i$) decouple statistically. The presence of all other neurons only enters through the "macroscopic" quantities $C(t, s)$ and $R(t, s)$, known as the correlation and response functions. The distribution of these functions over random realizations satisfies a large-deviations principle where the distribution over C, R has the form $p(C, R) \asymp e^{N S(C, R)}$, with S the DMFT action obtained from the path-integral method.

2. Extra memory terms like $\int_0^t ds\, R(t, s)\, h(s)$ appear, which depend on the state at earlier times $s < t$. The Markovian (deterministic) system for $p(h \mid W)$ becomes stochastic and non-Markovian after marginalizing, $p(h) = \int dW\, p(h \mid W)\, p(W)$. These memory terms are not obvious a priori, but they are systematic to compute in this framework.

Since this toy example is a linear dynamical system, one can also identify a connection between the DMFT correlation and response functions and the spectral properties of the random matrix W. We note that the response function has the form

$R(t, s) = \frac{1}{N}\, \operatorname{Tr} \exp\!\left(\frac{1}{\sqrt{N}} W\, (t - s)\right) = \int d\lambda\, \rho(\lambda)\, e^{\lambda (t - s)},$   (21)

where $\rho(\lambda)$ is the eigenvalue density of $\frac{1}{\sqrt{N}} W$. In fact, a Fourier transform of the DMFT equation recovers the semicircle law $\rho(\lambda) = \frac{1}{\pi} \operatorname{Im}\, \hat{R}(i\lambda) = \frac{1}{2\pi} \sqrt{[4 - \lambda^2]_+}$ for the eigenvalues, where $\hat{R}$ denotes the transform of the response.
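The identity in Eq. (21) is easy to check numerically. The sketch below (our own plain-numpy example, not the paper's code) builds a GOE matrix, computes the empirical response $\frac{1}{N}\operatorname{Tr}\exp\!\big(\frac{1}{\sqrt{N}} W \tau\big)$, and compares it with the semicircle-law prediction $\int d\lambda\, \rho(\lambda)\, e^{\lambda \tau}$.

import numpy as np

N = 2000
rng = np.random.default_rng(0)
G = rng.normal(size=(N, N))
W = (G + G.T) / np.sqrt(2.0)                  # GOE: symmetric matrix with O(1) entries
evals = np.linalg.eigvalsh(W / np.sqrt(N))    # spectrum concentrates on the semicircle [-2, 2]

lam = np.linspace(-2.0, 2.0, 4001)
rho_sc = np.sqrt(np.clip(4.0 - lam**2, 0.0, None)) / (2.0 * np.pi)  # semicircle density
dlam = lam[1] - lam[0]

for tau in [0.5, 1.0, 2.0]:
    R_emp = np.mean(np.exp(evals * tau))                 # (1/N) Tr exp(W tau / sqrt(N))
    R_sc = np.sum(rho_sc * np.exp(lam * tau)) * dlam     # semicircle-law prediction
    print(f"tau = {tau}: empirical = {R_emp:.4f}, semicircle = {R_sc:.4f}")

At this size the two agree to a few decimal places, illustrating that the DMFT response function encodes the limiting spectral density of the coupling matrix.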
In general, one can think of DMFT as a more powerful version of this type of argument, one that can also handle nonlinearities.

D.3 Example 2: Deep Linear Network Updates

In this section, we show how the DMFT approach can give useful insight into reasoning about learning updates that are not obvious a priori. While our paper advocates taking the depth $L \to \infty$ in a residual network, we first consider simply scaling depth in a standard MLP. Below we show how the proliferation of response terms gives a different predicted scaling with L than if we naively disregarded the response terms.

Consider a non-residual linear MLP with µP/mean-field scaling, L hidden layers, and $N \to \infty$. Train the model for a single step of gradient descent with learning rate $\eta$ on a data point (x, y) with $|x|^2 = 1$ and $y = 1$, and output multiplier $1/\gamma_0$. The forward-pass variables $h^{\ell}(t)$ and the backward-pass variables $g^{\ell}(t)$ are defined recursively as

$h^{\ell+1}(t) = \frac{1}{\sqrt{N}} W^{\ell}(t)\, h^{\ell}(t) = \frac{1}{\sqrt{N}} W^{\ell}(0)\, h^{\ell}(t) + \eta \gamma_0 \sum_s$