# Transformers learn through gradual rank increase

Enric Boix-Adserà (Apple, MIT), Etai Littwin (Apple), Emmanuel Abbe (Apple, EPFL), Samy Bengio (Apple), Joshua Susskind (Apple)

eboix@mit.edu, emmanuel.abbe@epfl.ch, {elittwin,bengio,jsusskind}@apple.com

## Abstract

We identify incremental learning dynamics in transformers, where the difference between trained and initial weights progressively increases in rank. We rigorously prove that this occurs under the simplifying assumptions of diagonal weight matrices and small initialization. Our experiments support the theory and also show that the phenomenon can occur in practice without the simplifying assumptions.

## 1 Introduction

The transformer architecture achieves state-of-the-art performance in various domains, yet we still lack a solid theoretical understanding of its training dynamics [VSP+17, DCLT19, LOG+19, DBK+20]. Nevertheless, the theoretical toolbox has matured over the last years, and there are promising new approaches. One important line of work examines the role that the initialization scale plays in the trajectory taken by gradient descent [JGH18, COB18, GSJW19, MGW+20, JGS+21, SS21, KC22]. When the weights are initialized small, it has been shown for simple networks that an incremental learning behaviour occurs, where functions of increasing complexity are learned in stages. This regime is known to be richer than the large-initialization regime¹, but the incremental learning dynamics are difficult to analyze and are so far understood only for extremely simple architectures. Can we apply this analysis to transformers? Namely:

> Are there incremental learning dynamics when training a transformer architecture?

An obstacle is that past work on incremental learning has mainly studied linear networks [Ber22, ACHL19, MKAA21, LLL20, WGL+19, JGS+21, GSSD19, SKZ+23, PF23], with one paper studying nonlinear 2-layer fully-connected networks [BPVF22].
In contrast, transformers have nonlinear attention heads that do not fall under previous analyses: given $X \in \mathbb{R}^{n \times d}$, an attention head computes

$$\mathrm{attention}(X; W_K, W_Q, W_V, W_O) = \mathrm{smax}(X W_K W_Q^\top X^\top)\, X W_V W_O^\top, \tag{1}$$

where $W_K, W_Q, W_V, W_O \in \mathbb{R}^{d \times d}$ are trainable matrices and the softmax is applied row-wise. A transformer is even more complex, since it is formed by stacking alternating layers of attention heads and feedforward networks, along with residual connections.

**Main finding.** Our main finding is that transformers exhibit incremental learning dynamics, where the difference between the trained and initial weights incrementally increases in rank. Our results have a theoretical component and an experimental component.

¹ In the large-initialization regime, deep learning behaves as a kernel method [JGH18, COB18]. Various separations with kernels are known for smaller initialization: e.g., [GMMM19, ABM22, MKAS21].

37th Conference on Neural Information Processing Systems (NeurIPS 2023).

Figure 1: For an attention head in ViT trained on (a) CIFAR-10 and (b) ImageNet, we plot the normalized spectra of $W_K W_Q^\top$ at initialization (in red) and of the learned perturbations to $W_K W_Q^\top$ at different iterations (in green).

**Theoretical contributions.** For our theory, we study a simplification of the transformer architecture in which the attention head weights are diagonal matrices: i.e., in each attention head we have $W_K = \mathrm{diag}(w_K)$, where $w_K \in \mathbb{R}^d$ are trainable weights, and similarly for $W_Q$, $W_V$ and $W_O$. We rigorously establish the training dynamics of this architecture under gradient flow when the initialization is small. We prove that the dynamics occur in discrete stages: (1) during most of each stage, the loss plateaus because the weights remain close to a saddle point, and (2) at the end of the stage, the saddle point is quickly escaped and the rank of the weights increases by at most one.
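For concreteness, the attention head in (1) can be written in a few lines of NumPy. This is a minimal sketch for illustration, not the authors' code: the helper names `softmax_rows` and `attention` are our own, and, following (1) literally, it omits the $1/\sqrt{d}$ scaling of common implementations (that factor can be absorbed into $W_K$).

```python
import numpy as np

def softmax_rows(A):
    """Row-wise softmax; subtracting the row max avoids overflow."""
    A = A - A.max(axis=1, keepdims=True)
    E = np.exp(A)
    return E / E.sum(axis=1, keepdims=True)

def attention(X, W_K, W_Q, W_V, W_O):
    """One attention head as in (1): smax(X W_K W_Q^T X^T) X W_V W_O^T.

    X has shape (n, d); all four weight matrices have shape (d, d).
    """
    scores = X @ W_K @ W_Q.T @ X.T                 # (n, n) attention logits
    return softmax_rows(scores) @ X @ W_V @ W_O.T  # (n, d) head output

rng = np.random.default_rng(0)
n, d = 4, 3
X = rng.standard_normal((n, d))
W_K, W_Q, W_V, W_O = (rng.standard_normal((d, d)) for _ in range(4))
out = attention(X, W_K, W_Q, W_V, W_O)
print(out.shape)  # (4, 3)
```

The weights enter the output only through the products $W_K W_Q^\top$ and $W_V W_O^\top$; this is what lets the diagonal parametrization of Section 3 be written in the form (2).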
This theoretical result on transformers follows from a general theorem characterizing the learning dynamics of networks $f_{NN}$ that depend on the product of parameters $u, v \in \mathbb{R}^p$ as

$$f_{NN}(x; u, v) = h(x; u \odot v), \tag{2}$$

where $x$ is the input, $\odot$ denotes the elementwise product, and $h$ is a smooth function.

**Theorem 1.1** (Informal statement of incremental learning dynamics). Let $f_{NN}$ be a network of the form (2), and suppose that the weights are initialized very small: i.e., the entries of $u, v$ are initialized on the order $\Theta(\alpha)$ for some small $\alpha > 0$. Then the dynamics of gradient flow training effectively proceed in discrete stages, each one lasting time $\Theta(\log(1/\alpha))$. In each stage, the number of nonnegligible entries of $u \odot v$ increases by at most one.

A transformer with diagonal weight matrices falls under this result when we only train the attention head weights. For example, if the transformer has one attention head, then we can take $u = [w_K, w_V] \in \mathbb{R}^{2d}$ and $v = [w_Q, w_O] \in \mathbb{R}^{2d}$ to be concatenations of the diagonal entries of the weights of the head; see Example 3.2 for more details and the extension to transformers with many heads. Then, by Theorem 1.1, in each stage either $W_K W_Q^\top = \mathrm{diag}(w_K)\mathrm{diag}(w_Q)$ or $W_V W_O^\top = \mathrm{diag}(w_V)\mathrm{diag}(w_O)$ increases in effective rank by at most one.²

**Experimental contributions.** In our experiments, we first validate our theoretical results, which require the simplifying assumptions of small initialization and diagonal weight matrices. Then, we conduct experiments on vision and language transformers in settings closer to practice, without any of the assumptions required by our theoretical analysis. Perhaps surprisingly, we again observe incremental learning dynamics, even though the assumptions of the theory are not met. The difference between trained and initial weights has low rank, and the rank of this difference grows gradually during training; see Figure 1.
The incremental nature of the dynamics is easier to see for ImageNet, since for CIFAR-10 the rank of the weight difference does not grow as much.

### 1.1 Related work

**Relation to LoRA.** We note an intriguing connection to the LoRA algorithm, where a pretrained base model is cheaply fine-tuned by training a low-rank perturbation of the weights [LFLY18, AZG20, HSW+21]. The method is surprisingly powerful, and LoRA has recently been fundamental in allowing the open-source community to inexpensively fine-tune language models [PA23, TGZ+23]. In our work, by contrast, we observe that the trained weights are a low-rank perturbation of the initial weights as a consequence of the training dynamics, without having to apply an explicit rank constraint as in LoRA. This raises an exciting open question for future work: can we explain and improve algorithms like LoRA by better understanding and quantifying the incremental dynamics of large transformers?

² We also remark that Theorem 1.1 is interesting in its own right and may have other applications beyond transformers. It qualitatively recovers the incremental dynamics result of [Ber22, PF23] when specialized to linear diagonal networks, i.e., when $f_{NN}(x; u, v) = \sum_{i=1}^p u_i v_i x_i$.

**Low-rank bias in nonlinear networks.** For 2-layer networks, it is known that low-rank bias in the weights emerges if the target function depends on a low-dimensional subspace of the input [ABM22, ABM23, DLS22, BBSS22, MHPG+22]. The results of [ABM22, ABM23] are especially relevant, since they show that the rank of the weights increases in a sequential manner, determined by the leap complexity of the target function, which is reminiscent of our empirical observations on transformers. See also [FVB+22, TVS23] for more investigations of low-rank bias in 2-layer networks under different assumptions. For transformers, [YW23] report that empirically the trained weights (using default initialization) are not low-rank.
This is consistent with our claim that the difference between initial and trained weights is low-rank, since the initial weights might not be low-rank.

**Incremental learning dynamics.** Several works prove incremental learning behaviour in deep linear networks when the initialization is small. [GBLJ19] show that gradient descent dynamics on a 2-layer linear network with $L_2$ loss effectively solve a reduced-rank regression problem with gradually increasing rank. [GSSD19] prove a dynamical depth separation result, allowing for milder assumptions on the initialization scale. [ACHL19, MKAA21] show implicit bias towards low rank in deep matrix and tensor factorization. [LLL20] show that deep matrix factorization dynamics with small initialization are equivalent to a greedy low-rank learning (GLRL) algorithm. [JGS+21] independently provide a similar description of the dynamics, but without requiring balanced initialization. Finally, [Ber22, JLL+23, PF23] overcome a technical hurdle from previous analyses by proving incremental learning for the entire training trajectory, rather than just the first stage. In contrast to our result, these prior works apply only to linear networks with certain convex losses, whereas our result applies to nonlinear networks. To make the extension to nonlinear networks possible, we must make stronger assumptions on the training trajectory, which we verify empirically. As far as we are aware, only one other work on incremental learning handles nonlinear networks: [BPVF22] prove that a 2-layer network learns with a two-stage incremental dynamic, but that result needs the stylized assumption that all data points are orthogonal.

### 1.2 Paper organization

Sections 2, 3, and 4 contain theoretical preliminaries, definitions of the models to which our theory applies, and our main theoretical result on incremental dynamics. Section 5 provides experiments that verify and extend the theory. Section 6 discusses limitations and future directions.
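Theorem 1.1 can be observed numerically in its simplest instance, the linear diagonal network $f_{NN}(x; u, v) = \sum_i u_i v_i x_i$ of footnote 2. The sketch below is our own illustration, not the authors' experimental code. It assumes isotropic Gaussian inputs and noiseless targets $y = \langle \beta^*, x \rangle$, so the population square loss is $\tfrac12 \|u \odot v - \beta^*\|^2$ and gradient flow reads $du/dt = v \odot (\beta^* - u \odot v)$, $dv/dt = u \odot (\beta^* - u \odot v)$; it discretizes this flow with small Euler steps. The function name `simulate_diagonal_linear` and the target $\beta^* = (3, 2, 1)$ are hypothetical choices.

```python
import numpy as np

def simulate_diagonal_linear(beta_star, alpha, dt=1e-3, t_max_rescaled=1.5):
    """Euler-discretized gradient flow on L(u, v) = 0.5 * ||u*v - beta_star||^2.

    Assumes isotropic inputs, so this is the population square loss of the
    linear diagonal network f(x; u, v) = sum_i u_i v_i x_i.
    Returns, for each coordinate, the rescaled time t / log(1/alpha) at which
    u_i * v_i first exceeds half of beta_star_i (inf if it never does).
    """
    u = np.full_like(beta_star, alpha)      # small initialization of scale alpha
    v = np.full_like(beta_star, alpha)
    scale = np.log(1.0 / alpha)
    n_steps = int(t_max_rescaled * scale / dt)
    crossing = np.full(beta_star.shape, np.inf)
    for step in range(n_steps):
        g = beta_star - u * v               # negative gradient w.r.t. the product u*v
        u, v = u + dt * v * g, v + dt * u * g   # du/dt = v*g, dv/dt = u*g
        t_rescaled = (step + 1) * dt / scale
        newly = (u * v > 0.5 * beta_star) & np.isinf(crossing)
        crossing[newly] = t_rescaled
    return crossing

beta_star = np.array([3.0, 2.0, 1.0])
times = simulate_diagonal_linear(beta_star, alpha=1e-8)
print(np.round(times, 3))  # roughly 1 / beta_star
```

With $\alpha = 10^{-8}$, the entries of $u \odot v$ switch on one at a time, in order of decreasing $\beta^*_i$, at rescaled times close to $1/\beta^*_i$ and separated by plateaus. Because the inputs are isotropic, the coordinates decouple in this toy case, so it illustrates the $\Theta(\log(1/\alpha))$ stage lengths rather than the coupled nonlinear dynamics that Theorem 1.1 also covers.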
## 2 Preliminaries

We consider training a network $f_{NN}(\cdot\,; \theta)$ parametrized by a vector of weights $\theta$ to minimize a loss

$$L(\theta) = \mathbb{E}_{x,y}[\ell(y, f_{NN}(x; \theta))],$$

where the expectation is over samples $(x, y) \in \mathbb{R}^{d_x} \times \mathbb{R}^{d_y}$ from a training data distribution, and $\ell : \mathbb{R}^{d_y} \times \mathbb{R}^{d_{out}} \to \mathbb{R}$. Consider a solution $\theta(t)$ to the gradient flow³

$$\theta(0) = \alpha \theta_0, \qquad \frac{d\theta}{dt} = -\nabla_\theta L(\theta), \tag{3}$$

where $\alpha > 0$ is a parameter governing the initialization scale, which we will take to be small. For our theory, we henceforth require the following mild regularity assumption on the loss and data.

**Assumption 2.1** (Regularity of data distribution and loss). The function $\ell(y, \zeta)$ is continuously twice-differentiable in the arguments $[y, \zeta] \in \mathbb{R}^{d_y + d_{out}}$. There exists $C > 0$ such that almost surely the data is bounded by $\|x\|, \|y\| \le C$.

The assumption on $\ell$ is satisfied in typical cases such as the square and the cross-entropy losses. The data boundedness is often satisfied in practice (e.g., if the data is normalized). We also use the notation $\mathrm{supp}(a) := \{i : a_i \neq 0\}$ to denote the support of a vector $a$.

³ Gradient flow training can be obtained as a limit of SGD or GD training as the learning rate tends to 0 (see, e.g., [Bac20]). It is a popular testbed for studying learning dynamics (see, e.g., [SMG13, ACH18, RC20]), since it is generally simpler to analyze than SGD.

## 3 Neural networks with diagonal weights

Our theory analyzes the training dynamics of networks that depend on products of diagonal weight matrices. We use $\odot$ to denote the elementwise vector product.

**Definition 3.1.** A network $f_{NN}$ is *smooth with diagonal weights* $\theta = (u, v) \in \mathbb{R}^{2p}$ if it is of the form

$$f_{NN}(x; \theta) = h(x; u \odot v),$$

where $h : \mathbb{R}^{d_x} \times \mathbb{R}^p \to \mathbb{R}^{d_{out}}$ is continuously twice-differentiable in its arguments in $\mathbb{R}^{d_x + p}$.

The assumption on $h$ precludes the use of the ReLU function, since it is not continuously differentiable.
Otherwise the assumption is fairly mild, since any $h$ can be used to express an architecture of any depth as long as the nonlinearities are twice-differentiable, which includes, for example, GeLUs (as used in ViT). We describe how to express a transformer with diagonal weights.

**Example 3.2** (Transformer with diagonal weights). A transformer with $L$ layers and $H$ attention heads on each layer is defined inductively by $Z_0 = X \in \mathbb{R}^{n \times d}$ and

$$\text{(Attention layer)} \quad \tilde{Z}_\ell = Z_{\ell-1} + \sum_{i=1}^{H} \mathrm{attention}(Z_{\ell-1}; W_K^{\ell,i}, W_Q^{\ell,i}, W_V^{\ell,i}, W_O^{\ell,i}),$$

$$\text{(Feedforward layer)} \quad Z_\ell = \tilde{Z}_\ell + \sigma(\tilde{Z}_\ell W_A^\ell)(W_B^\ell)^\top,$$

where $W_K^{\ell,i}, W_Q^{\ell,i}, W_V^{\ell,i}, W_O^{\ell,i} \in \mathbb{R}^{d \times d}$ are the attention parameters, $W_A^\ell, W_B^\ell \in \mathbb{R}^{d \times d}$ are the feedforward parameters, and $\sigma$ is a continuously twice-differentiable activation. Suppose that the attention parameters are diagonal matrices: i.e., $W_K^{\ell,i} = \mathrm{diag}(w_K^{\ell,i}) \in \mathbb{R}^{d \times d}$, and similarly for the $W_Q^{\ell,i}$, $W_V^{\ell,i}$, $W_O^{\ell,i}$ matrices. Then, by the definition of the attention layer (1), the final output $Z_L$ of the transformer depends on the attention parameters only through the elementwise products $w_K^{\ell,i} \odot w_Q^{\ell,i}$ and $w_V^{\ell,i} \odot w_O^{\ell,i}$. In other words, we can write

$$Z_L = h(X; u \odot v)$$

for vectors $u = [w_K^{\ell,i}, w_V^{\ell,i}]_{(\ell,i) \in [L] \times [H]} \in \mathbb{R}^{2dHL}$ and $v = [w_Q^{\ell,i}, w_O^{\ell,i}]_{(\ell,i) \in [L] \times [H]} \in \mathbb{R}^{2dHL}$, and some smooth model $h$. Thus, if only the attention layers are trained, the diagonal transformer fits under Definition 3.1.

## 4 Incremental learning in networks with diagonal weights

We prove that if the initialization scale $\alpha$ is small, then learning proceeds in incremental stages, where in each stage the effective sparsity of the weights increases by at most one. These stages are implicitly defined by Algorithm 1 below, which constructs a sequence of times $0 = T_0 < T_1 < \cdots < T_k < \cdots$ and weight vectors $\theta^0, \theta^1, \ldots, \theta^k, \ldots \in \mathbb{R}^{2p}$ that define the stages. We prove the following:

**Theorem 4.1** (Incremental dynamics at small initialization). Let $f_{NN}$ be a model with diagonal weights as in Definition 3.1.
For any stage $k$ and time $t \in (T_k, T_{k+1})$, the following holds under Assumptions 2.1, 4.3, 4.4 and 4.5. There is $\alpha_0(t) > 0$ such that for all $\alpha < \alpha_0$, there exists a unique solution $\theta : [0, t \log(1/\alpha)] \to \mathbb{R}^{2p}$ to the gradient flow (3) and

$$\lim_{\alpha \to 0} \theta(t \log(1/\alpha)) = \theta^k,$$

and at each stage the sparsity increases by at most one: $\mathrm{supp}(\theta^{k+1}) \setminus \mathrm{supp}(\theta^k) \subseteq \{i_k\}$.⁴

**Application: transformer with diagonal weights.** Before giving the intuition for this theorem and stating the assumptions formally, let us discuss its application to the diagonal transformer model from Example 3.2. As a corollary of Theorem 4.1, gradient flow on a diagonal transformer with small initialization learns in stages. In each stage there is at most one head $i \in [H]$ in one layer $\ell \in [L]$ for which either the rank of $W_K^{\ell,i}(W_Q^{\ell,i})^\top = \mathrm{diag}(w_K^{\ell,i})\mathrm{diag}(w_Q^{\ell,i})$ or the rank of $W_V^{\ell,i}(W_O^{\ell,i})^\top = \mathrm{diag}(w_V^{\ell,i})\mathrm{diag}(w_O^{\ell,i})$ increases, and it increases by at most one. In Figure 2, we illustrate these dynamics in the toy case of a single attention head trained in a student-teacher setup.

⁴ Abusing notation, for $\theta = (u, v) \in \mathbb{R}^p \times \mathbb{R}^p$, we write $\mathrm{supp}(\theta) = \mathrm{supp}(u) \cup \mathrm{supp}(v)$.

**Algorithm 1** Incremental learning in networks with diagonal weights

1. $b^0 \leftarrow \mathbf{1} \in \mathbb{R}^p$ (the all-ones vector), $\theta^0 \leftarrow 0 \in \mathbb{R}^{2p}$, $T_0 \leftarrow 0$
2. **for** stage number $k = 0, 1, 2, \ldots$ **do**
3. \# (A) Pick a new coordinate $i_k \in [p]$ to activate.
4. For each $i$, define the time $\Delta_k(i)$ until active using (10).
5. Pick the winning coordinate $i_k$ using (11).
6. Calculate the time $T_{k+1}$ using (11), and break if $T_{k+1} = \infty$.
7. Update the logarithmic weight approximation $b^{k+1}$ using (12).
8. \# (B) Train the activated coordinates to stationarity.
9. $\theta^{k+1} \leftarrow$ limiting dynamics point from (13).
10. **end for**

[Figure 2 plots: (a) "Training loss vs. rescaled time, for various alpha": training loss against time / log(1/α) for α = 0.1, 0.01, 10⁻⁴, 10⁻⁸, 10⁻¹⁶, 10⁻³²; (b) two panels of diagonal entries against time / log(1/α), for α = 10⁻⁴.]

Figure 2: (a) Loss versus rescaled time in the toy task of learning an attention head with diagonal weights, for various initialization scales $\alpha$. The loss curves converge as $\alpha \to 0$ to a curve with stagewise loss plateaus and sharp decreases, as predicted by the theory; some stagewise learning behavior is already clear with $\alpha = 0.01$. (b) Each line shows the evolution of one of the entries of $\mathrm{diag}(w_Q)\mathrm{diag}(w_K)$ and $\mathrm{diag}(w_V)\mathrm{diag}(w_O)$ over rescaled time, demonstrating that the rank of these matrices increases incrementally; see Appendix A for experimental details and further results.

### 4.1 Intuition for incremental learning dynamics

We develop an informal intuition for Theorem 4.1 and fill out the definition of Algorithm 1. A model $f_{NN}$ with diagonal weights $\theta = (u, v)$ as in Definition 3.1 evolves under the gradient flow (3) as

$$\frac{du}{dt} = v \odot g(\theta), \qquad \frac{dv}{dt} = u \odot g(\theta), \qquad \text{where} \quad g(\theta) = -\mathbb{E}_{x,y}\big[D\ell(y, h(x; u \odot v))\, Dh(x; u \odot v)\big]^\top. \tag{4}$$

Here $D\ell(y, \cdot) \in \mathbb{R}^{1 \times d_{out}}$ is the derivative of $\ell$ in its second argument and $Dh(x, \cdot) \in \mathbb{R}^{d_{out} \times p}$ is the derivative of $h$ in its second argument.

The first key observation is a conservation law that simplifies the dynamics. It can be viewed as the balancedness property for networks with linear activations [ACH18, DHL18], specialized to the case of diagonal layers.

**Lemma 4.2** (Conservation law). For any $i \in [p]$ and any time $t$, we have

$$u_i^2(t) - v_i^2(t) = u_i^2(0) - v_i^2(0). \tag{5}$$

*Proof.* This follows from $\frac{d}{dt}(u_i^2 - v_i^2) = 2u_i v_i g_i(\theta) - 2u_i v_i g_i(\theta) = 0$.

The conservation law reduces the degrees of freedom and means that we need only keep track of $p$ parameters in total.
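Lemma 4.2 holds for any smooth $h$, which is easy to check numerically. The sketch below is an illustration under our own assumptions: a hypothetical nonlinear model $h(x; u \odot v) = \tanh(\langle u \odot v, x \rangle)$ with the square loss on random data. It computes $g(\theta)$ as in (4) and takes explicit Euler steps of $du/dt = v \odot g(\theta)$, $dv/dt = u \odot g(\theta)$. One Euler step multiplies $u_i^2 - v_i^2$ by exactly $1 - (\mathrm{d}t\, g_i)^2$, so in discrete time the quantity drifts by $O(\mathrm{d}t^2)$ per step instead of being exactly conserved.

```python
import numpy as np

rng = np.random.default_rng(0)
N, p = 64, 5
X = rng.standard_normal((N, p))
y = np.tanh(X @ rng.standard_normal(p))   # hypothetical teacher targets

def g_of(u, v):
    """g(theta) from (4) for h(x; u*v) = tanh(<u*v, x>) and loss 0.5 * (h - y)^2.

    Uses d(loss)/dh = h - y and dh/d(u*v) = (1 - h^2) x, averaged over samples.
    """
    f = np.tanh(X @ (u * v))
    return -((f - y) * (1 - f**2)) @ X / N

u = 0.5 * rng.standard_normal(p)
v = 0.5 * rng.standard_normal(p)
c0 = u**2 - v**2                          # the conserved quantity in (5)
dt = 1e-3
for _ in range(3000):
    g = g_of(u, v)
    u, v = u + dt * v * g, v + dt * u * g  # Euler step of du/dt = v*g, dv/dt = u*g
print(np.max(np.abs((u**2 - v**2) - c0)))  # small: only O(dt^2) drift per step
```

Shrinking `dt` shrinks the drift quadratically, consistent with exact conservation in the continuous-time limit.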
Specifically, if we define $w_i(t) := u_i(t) + v_i(t)$, then the vector $w(t) = u(t) + v(t)$ evolves by

$$\frac{dw}{dt} = w \odot g(\theta). \tag{6}$$

Using the conservation law (5), we can keep track of the weights in terms of the initialization and $w(t)$:

$$u(t) = \tfrac{1}{2}\Big(w(t) + \frac{u^{\odot 2}(0) - v^{\odot 2}(0)}{w(t)}\Big), \qquad v(t) = \tfrac{1}{2}\Big(w(t) - \frac{u^{\odot 2}(0) - v^{\odot 2}(0)}{w(t)}\Big), \tag{7}$$

where squares and division are taken entrywise. Therefore it suffices to analyze the dynamics of $w(t)$.

#### 4.1.1 Stage 1 of dynamics

**Stage 1A of dynamics: loss plateau for time $\Theta(\log(1/\alpha))$.** At initialization, $\theta(0) \approx 0$ because the weights are initialized on the order of $\alpha$, which is small. This motivates the approximation $g(\theta(t)) \approx g(0)$, under which the dynamics solve to

$$w(t) \approx w(0) \odot e^{g(0) t}. \tag{8}$$

Of course, this approximation is valid only while the weights are still close to the small initialization. The approximation breaks once one of the entries of $\theta(t)$ reaches constant size. By combining (7) and (8), this happens at time $t \approx T_1 \log(1/\alpha)$ for

$$T_1 = \min_{i \in [p]} 1/|g_i(0)|.$$

(If $g_i(0) > 0$, then $w_i(t)$ itself grows from $\Theta(\alpha)$ to $\Theta(1)$ after time $\log(1/\alpha)/g_i(0)$; if $g_i(0) < 0$, then $w_i(t)$ shrinks, and the term $(u_i^2(0) - v_i^2(0))/w_i(t)$ in (7), which starts at $\Theta(\alpha)$, reaches $\Theta(1)$ after time $\log(1/\alpha)/|g_i(0)|$.) Until this time, the network remains close to its initialization, and so we observe a loss plateau.

**Stage 1B of dynamics: nonlinear dynamics for time $O(1)$.** Subsequently, the loss decreases nonlinearly during an $O(1)$ time-scale, which is vanishingly short relative to the time-scale of the loss plateau. To prove this, we make the non-degeneracy assumption that there is a unique coordinate $i_0$ such that $1/|g_{i_0}(0)| = T_1$. Under this assumption, in stage 1A all weights except for those at coordinate $i_0$ remain vanishingly small, on the order of $o_\alpha(1)$. Concretely, for any small $\epsilon > 0$, there is a time $t_1(\epsilon) \approx T_1 \log(1/\alpha)$ and a sign $s \in \{+1, -1\}$ such that⁵

$$u_{i_0}(t_1) \approx \epsilon, \qquad v_{i_0}(t_1) \approx s\epsilon, \qquad |u_i(t_1)|, |v_i(t_1)| = o_\alpha(1) \ \text{ for all } i \neq i_0.$$

Because all coordinates except for $i_0$ have vanishingly small $o_\alpha(1)$ weights after stage 1A, we may perform the following approximation of the dynamics. Zero out the weights at all coordinates except $i_0$, and consider the training dynamics starting at $\theta = (\epsilon e_{i_0}, s\epsilon e_{i_0})$. After $O(1)$ time, we should expect these dynamics to approach a stationary point.
Although the evolution is nonlinear, all entries remain zero except for the $i_0$ entries, so the stationary point is also sparse. Mathematically, there is a time $\bar{t}_1 = t_1 + O(1) \approx T_1 \log(1/\alpha)$ such that

$$\theta(\bar{t}_1) \approx (a e_{i_0}, s a e_{i_0}) := \theta^1,$$

for some $a \in \mathbb{R}_{>0}$, where $\theta^1$ is a stationary point of the loss.⁶ Despite the nonlinearity of the dynamics, the approximation can be proved using Grönwall's inequality, since $\bar{t}_1 - t_1 = O(1)$ is a constant time-scale. To summarize, we have argued that the network approximately reaches a stationary point that is 1-sparse, where only the weights at coordinate $i_0$ are nonzero.

⁵ Without loss of generality, we can ensure that at initialization $u(0)$ and $u(0) + v(0)$ are nonnegative. This implies that $u(t)$ is nonnegative. The fact that $u_{i_0}$ and $v_{i_0}$ are roughly equal in magnitude but might differ in sign is due to the conservation law (5). See Appendix C.3 for details.

⁶ The entries of $u$ and $v$ are close in magnitude (but may differ in sign) because of the conservation law (5).

#### 4.1.2 Later stages

We inductively extend the argument to any number of stages $k$, where each stage has a $\Theta(\log(1/\alpha))$-time plateau and then an $O(1)$-time nonlinear evolution, with the sparsity of the weights increasing by at most one. The argument for multiple stages is analogous, but we must also keep track of the magnitude of the weights on the logarithmic scale, since these determine how much longer each inactive coordinate takes to become active.

Inductively on $k$, suppose that there are some $T_k \in \mathbb{R}$, $b^k \in \mathbb{R}^p$, $\theta^k \in \mathbb{R}^{2p}$, and a time $\bar{t}_k \approx T_k \log(1/\alpha)$ such that $\log_\alpha(w(\bar{t}_k)) \approx b^k$ and $\theta(\bar{t}_k) \approx \theta^k$, where $\theta^k$ is a stationary point of the loss. Our inductive step shows that there is $T_{k+1} \in \mathbb{R}$ such that during times $t \in (\bar{t}_k, T_{k+1}\log(1/\alpha) - \Omega(1))$ the weights remain close to the stationary point from the previous stage, i.e., $\theta(t) \approx \theta^k$. And at a time $\bar{t}_{k+1} \approx T_{k+1}\log(1/\alpha)$ we have $\log_\alpha(w(\bar{t}_{k+1})) \approx b^{k+1}$ and $\theta(\bar{t}_{k+1}) \approx \theta^{k+1}$, where $\theta^{k+1}$ and $b^{k+1}$ are defined below (summarized in Algorithm 1).
Most notably, $\theta^{k+1}$ is a stationary point of the loss whose support grows by at most one compared to $\theta^k$.

**Stage (k+1)A: loss plateau for time $\Theta(\log(1/\alpha))$.** At the beginning of stage $k+1$, the weights are close to the stationary point $\theta^k$, and so, similarly to stage 1A, linear dynamics are valid:

$$w(t) \approx w(\bar{t}_k) \odot e^{g(\theta^k)(t - \bar{t}_k)}. \tag{9}$$

Using (7), which follows from the conservation law, we derive a *time until active* for each coordinate $i \in [p]$, which corresponds to the rescaled time (in units of $\log(1/\alpha)$) for the weight at that coordinate to grow from $o_\alpha(1)$ to $\Theta(1)$ magnitude:

$$\Delta_k(i) = \begin{cases} \big(b_i^k - 1 + \mathrm{sgn}(g_i(\theta^k))\big)/g_i(\theta^k), & \text{if } g_i(\theta^k) \neq 0, \\ \infty, & \text{if } g_i(\theta^k) = 0. \end{cases} \tag{10}$$

The linear dynamics approximation (9) breaks down at a time $t \approx T_{k+1}\log(1/\alpha)$, where

$$T_{k+1} = T_k + \Delta_k(i_k), \qquad i_k = \arg\min_{i \in [p]} \Delta_k(i), \tag{11}$$

which corresponds to the first time at which the weights at some coordinate grow from $o_\alpha(1)$ to $\Theta(1)$ magnitude. At times $t \approx T_{k+1}\log(1/\alpha)$, on the logarithmic scale $w$ is given by

$$\log_\alpha(w(t)) \approx b^{k+1} := b^k - g(\theta^k)\,\Delta_k(i_k). \tag{12}$$

**Stage (k+1)B: nonlinear dynamics for time $O(1)$.** Subsequently, the weights evolve nonlinearly during $O(1)$ time. In a similar way to the analysis of Stage 1B, we show that at a time $\bar{t}_{k+1} = t_{k+1} + O(1) \approx T_{k+1}\log(1/\alpha)$, we have

$$\theta(\bar{t}_{k+1}) \approx \theta^{k+1} := \lim_{\epsilon \to 0} \lim_{t \to \infty} \psi_k(t, \epsilon), \tag{13}$$

where the dynamics $\psi_k(t, \epsilon) \in \mathbb{R}^{2p}$ are initialized at $\psi_k(0, \epsilon) = \theta^k + (\epsilon e_{i_k}, \mathrm{sgn}(g_{i_k}(\theta^k))\,\epsilon e_{i_k})$ and evolve according to the gradient flow $\frac{d\psi_k(t, \epsilon)}{dt} = -\nabla_\theta L(\psi_k)$. This concludes the inductive step.

### 4.2 Assumptions for incremental dynamics

To make this intuition rigorous, we formalize below the assumptions required for Theorem 4.1. In Figure 3 and Appendix A, we provide experiments validating these assumptions on the toy model.

The first assumption is that the dynamics are non-degenerate, in the sense that no two coordinates have weights that grow from $o_\alpha(1)$ to $\Theta(1)$ size at the same rescaled time. We also impose a technical condition to handle the corner case when a coordinate leaves the support of the current stage's stationary point.
**Assumption 4.3** (Nondegeneracy of dynamics in part (A)). The initialization satisfies $|u_i(0)| \neq |v_i(0)|$ for all $i$. For stage $k$, either $T_{k+1} = \infty$ or there is a unique minimizer $i_k$ to $\min_i \Delta_k(i)$ in (11). Finally, for all $i \in \mathrm{supp}(\theta^{k-1}) \setminus \mathrm{supp}(\theta^k)$ we have $g_i(\theta^k) \neq 0$.

[Figure 3 plots: (a) "Entries of W, perturbed at time / log(1/α) = 778.0": coordinate values against time / log(1/α); (b) "Entries of W, perturbed at time / log(1/α) = 321.4": coordinate values against time / log(1/α).]

Figure 3: Validation of assumptions on the toy model of learning a single attention head. (a) Assumption 4.4: weights perturbed at a random time during training (solid lines) tend back to the near-stationary point (dashed lines). (b) Assumption 4.5: weights perturbed at the beginning of a stage (solid lines) have the same nonlinear evolution as without perturbation (dashed lines). Details of these experiments and further validations are provided in Appendix A.

Next, we require that very small perturbations of the coordinates outside of $\mathrm{supp}(\theta^k)$ do not change the dynamics. For this, it suffices that $\theta^k$ be a strict local minimum.

**Assumption 4.4** (Stationary points are strict local minima). For stage $k$, there exist $\delta_k > 0$ and $c_k > 0$ such that for $\tilde{u} \in B(u^k, \delta_k)$ supported on $\mathrm{supp}(u^k)$, we have

$$L(\tilde{u}, s^k \odot \tilde{u}) - L(\theta^k) \ge c_k \|u^k - \tilde{u}\|^2.$$

Finally, we require a robust version of the assumption that the limit (13) exists, asking for convergence to a neighborhood of $\theta^{k+1}$ even when the initialization is slightly noisy.

**Assumption 4.5** (Noise-robustness of dynamics in part (B)). For any stage $k$ with $T_{k+1} < \infty$ and any $\epsilon > 0$, there are $\delta > 0$ and $\tau : \mathbb{R}_{>0} \to \mathbb{R}$ such that the following holds. For any $\tilde{u} \in B(u^k, \delta) \cap \mathbb{R}^p_{\ge 0}$ supported on $\mathrm{supp}(\tilde{u}) \subseteq \mathrm{supp}(u^k) \cup \{i_k\}$, there exists a unique solution $\psi : [0, \infty) \to \mathbb{R}^{2p}$ of the gradient flow $\frac{d\psi}{dt} = -\nabla_\theta L(\psi)$ initialized at $\psi(0) = (\tilde{u}, s^{k+1} \odot \tilde{u})$, and at times $t \ge \tau(\tilde{u}_{i_k})$,

$$\|\psi(t) - \theta^{k+1}\| < \epsilon.$$
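Part (A) of Algorithm 1 is plain arithmetic once $b^k$ and $g(\theta^k)$ are known. The sketch below implements (10) to (12) in NumPy; the function name `stage_a` is hypothetical, and part (B), which requires running the gradient flow in (13) to convergence, is not implemented. For the usage example we assume, purely as an illustration, the decoupled case $g(\theta) = \beta^* - u \odot v$ (a linear diagonal network with isotropic inputs), where the stationary point $\theta^k$ fits $u_i v_i = \beta^*_i$ on the activated coordinates, so $g(\theta^k)$ is $\beta^*$ with those entries set to zero. We start from $b^0 = \mathbf{1}$, since $w(0) = \Theta(\alpha)$ gives $\log_\alpha w(0) \to 1$.

```python
import numpy as np

def stage_a(b, g, T):
    """Part (A) of Algorithm 1, following equations (10)-(12).

    b : array (p,), log-scale magnitudes b^k = log_alpha(w)
    g : array (p,), g(theta^k) at the current stationary point
    T : float, rescaled start time T_k of the stage
    Returns (i_k, T_{k+1}, b^{k+1}); i_k is None when T_{k+1} is infinite.
    """
    delta = np.full(b.shape, np.inf)    # Delta_k(i) = inf where g_i(theta^k) = 0
    nz = g != 0
    delta[nz] = (b[nz] - 1 + np.sign(g[nz])) / g[nz]  # equation (10)
    if not np.isfinite(delta).any():
        return None, np.inf, b          # no coordinate ever activates: stop
    i_k = int(np.argmin(delta))         # equation (11); unique under Assumption 4.3
    b_next = b - g * delta[i_k]         # equation (12)
    return i_k, T + delta[i_k], b_next

# Hypothetical decoupled example: g(theta^k) = beta_star with activated entries zeroed.
beta_star = np.array([3.0, 2.0, 1.0])
b, T, active, times = np.ones(3), 0.0, [], []
while True:
    g = np.where(np.isin(np.arange(3), active), 0.0, beta_star)
    i_k, T, b = stage_a(b, g, T)
    if i_k is None:
        break
    active.append(i_k)
    times.append(T)
    print(f"stage {len(active) - 1}: coordinate {i_k} activates at T = {T:.3f}")
```

In this example the stages activate coordinates 0, 1 and 2 at rescaled times $1/3$, $1/2$ and $1$, i.e., at $1/\beta^*_i$, which matches the direct estimate that an isolated coordinate with $g_i = \beta^*_i$ needs time about $\log(1/\alpha)/\beta^*_i$ to grow from $\Theta(\alpha)$ to $\Theta(1)$.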
## 5 Experimental results

We run experiments that go beyond the toy diagonal attention head model (see Figures 2 and 3) to test the extent to which low-rank incremental learning occurs in popular models used in practice. We conduct experiments with vision transformers (ViT) [DBK+20] trained on the CIFAR-10/100 and ImageNet datasets, and with the GPT-2 language transformer [BMR+20] trained on the Wikitext-103 dataset. Full experiments are deferred to Appendix B.

**Gradual rank increase in vision and language models.** We train practical transformer architectures on vision and language tasks using Adam and the cross-entropy loss. We train all layers (including the feedforward layers). To capture the low-rank bias with a non-vanishing initialization scale, we study the spectra of the differences $\Delta W_K W_Q^\top$ and $\Delta W_V W_O^\top$ between the products $W_K W_Q^\top$ and $W_V W_O^\top$ during training and their values at initialization. Specifically, in Figure 4, we plot the stable rank of the differences $\Delta W_K W_Q^\top$ and $\Delta W_V W_O^\top$. The weight perturbation learned during the training process gradually increases in stable rank during training, and is ultimately low-rank when compared to the initial spectrum. Finally, for CIFAR-10, we plot the spectrum of $\Delta W_K W_Q^\top$ against that of $W_K W_Q^\top$ at initialization in Figure 5 for different self-attention heads, illustrating that the weight perturbation learned during the training process is extremely low-rank when compared to the initial spectrum. In Appendix B, we also study optimization with SGD, which shows similar gradual rank increase behavior.

**Effect of initialization scale.** We probe the effect of the initialization scale on gradual rank increase dynamics for a ViT trained on CIFAR-10. We use a ViT of depth 6, with 8 self-attention heads per layer (with layer normalization). We use an embedding and MLP dimension of $d_{emb} = 512$ and a head dimension of $d_h = 128$ (i.e., $W_K, W_Q, W_V, W_O \in \mathbb{R}^{d_{emb} \times d_h}$).
We train the transformer using Adam with the cross-entropy loss. We train all layers (including the feedforward layers) while varying the initialization scale of all layers by multiplying their initial values by a scale factor (we fix the scale of the initial token mapper). Figure 6 shows the evolution of the principal components of $\Delta W_K W_Q^\top$ and $\Delta W_V W_O^\top$ for a randomly chosen self-attention head and layer throughout training, exhibiting incremental learning dynamics and a low-rank bias. Note that incremental learning and low-rank bias are increasingly evident with smaller initialization scales, as further demonstrated in Figure 7.

[Figure 4 plots: stable rank against training iteration for (a) ViT, CIFAR-10; (b) ViT, CIFAR-100; (c) ViT, ImageNet; (d) GPT-2, Wikitext-103.]

Figure 4: Stable rank of $\Delta W_K W_Q^\top$ (blue) and $\Delta W_V W_O^\top$ (orange) on an arbitrarily chosen layer throughout training, for four different pairs of networks and tasks. The stable rank of a matrix $W$ is defined as $\|W\|_F^2 / \|W\|_2^2$ and gives a smooth approximation of the rank. Mean and standard deviation (shaded area) are computed across all heads in each attention layer. Full details and results are in Appendix B.

Figure 5: Spectrum of the weight perturbation $\Delta W_K W_Q^\top$ vs. initialization in a vision transformer trained on CIFAR-10, using Adam and the default initialization scale, in random self-attention heads in different layers. The learned perturbation exhibits extreme low-rank bias post-training even at the default initialization scale. Analogous plots for CIFAR-100 and ImageNet are in Appendix B.

## 6 Discussion

We have identified incremental learning dynamics in transformers, proved them rigorously in a simplified setting, and shown them experimentally in networks trained with practical hyperparameters.

**Limitations.** There are clear limitations to our theory: the diagonal weights and small initialization assumptions.
More subtly, the theory does not apply to losses with exponential-like tails, because the weights may not converge to a finite value and so Assumption 4.4 is not met (this could possibly be addressed by adding regularization). Also, the architecture must be smooth, which precludes ReLUs but allows smoothed ReLUs such as the GeLUs used in ViT [DBK+20]. Finally, the theory is for training with gradient flow, while other optimizers such as Adam are used in practice instead [KB14]. Nevertheless, our experiments on ViTs indicate that the incremental learning dynamics occur even when training with Adam.

Figure 6: Training a vision transformer on CIFAR-10 using Adam while varying the initialization scale (unit scale indicates default initialization). Plotted is the evolution of the eigenvalues of $\Delta W_K W_Q^\top$ (a)-(c) and $\Delta W_V W_O^\top$ (d)-(f) in a random self-attention head in the second layer throughout training. Incremental learning dynamics and a low-rank bias are evident for all scales, albeit more pronounced at smaller initialization scales.

Figure 7: Stable rank of $\Delta W_K W_Q^\top$ per initialization scale (unit scale refers to the default initialization) in different self-attention heads post-training, at layers 1, 3, and 5. At each layer, the stable rank mean and standard deviation are computed across the 8 heads of that layer, for each initialization scale. All models were trained on CIFAR-10 using the Adam optimizer. Smaller initialization scales lead to lower-rank attention heads.

**Future directions.** An interesting avenue of future research is to develop a theoretical understanding of the implicit bias in function space of transformers whose weights are a low-rank perturbation of randomly initialized weights. Another promising direction is to examine the connection between our results on incremental dynamics and the LoRA method [HSW+21], with the goal of explaining and improving on this algorithm; see also the discussion in Section 1.1.
Along this vein, a concurrent work [ZZC+23] independently observes gradual rank increase dynamics during transformer training, and this inspires a low-rank training algorithm that obtains runtime and memory improvements over regular training. The results of [ZZC+23] are complementary to ours, since they study the feedforward layers of the transformer, and their theory applies to linear networks at the standard initialization scale; in contrast, we study the attention layers, and our theory applies to nonlinear networks with small initialization scale.

## Acknowledgments

We would like to thank Vimal Thilak for his help in setting up the infrastructure for conducting experiments, and the anonymous reviewers for their helpful feedback.

## References

[ABM22] Emmanuel Abbe, Enric Boix-Adsera, and Theodor Misiakiewicz. The merged-staircase property: a necessary and nearly sufficient condition for SGD learning of sparse functions on two-layer neural networks. COLT, 2022.

[ABM23] Emmanuel Abbe, Enric Boix-Adsera, and Theodor Misiakiewicz. SGD learning on neural networks: leap complexity and saddle-to-saddle dynamics. arXiv preprint arXiv:2302.11055, 2023.

[ACH18] Sanjeev Arora, Nadav Cohen, and Elad Hazan. On the optimization of deep networks: Implicit acceleration by overparameterization. In International Conference on Machine Learning, pages 244–253. PMLR, 2018.

[ACHL19] Sanjeev Arora, Nadav Cohen, Wei Hu, and Yuping Luo. Implicit regularization in deep matrix factorization. Advances in Neural Information Processing Systems, 32, 2019.

[AZG20] Armen Aghajanyan, Luke Zettlemoyer, and Sonal Gupta. Intrinsic dimensionality explains the effectiveness of language model fine-tuning. arXiv preprint arXiv:2012.13255, 2020.

[Bac20] Francis Bach. Effortless optimization through gradient flows. Machine Learning Research Blog. https://francisbach.com/gradient-flows, 2020.

[BBSS22] Alberto Bietti, Joan Bruna, Clayton Sanford, and Min Jae Song.
Learning single-index models with shallow neural networks. ar Xiv preprint ar Xiv:2210.15651, 2022. [Ber22] Rapha el Berthier. Incremental learning in diagonal linear networks. ar Xiv preprint ar Xiv:2208.14673, 2022. [BMR+20] Tom Brown, Benjamin Mann, Nick Ryder, Melanie Subbiah, Jared D Kaplan, Prafulla Dhariwal, Arvind Neelakantan, Pranav Shyam, Girish Sastry, Amanda Askell, et al. Language models are few-shot learners. Advances in neural information processing systems, 33:1877 1901, 2020. [BPVF22] Etienne Boursier, Loucas Pillaud-Vivien, and Nicolas Flammarion. Gradient flow dynamics of shallow relu networks for square loss and orthogonal inputs. ar Xiv preprint ar Xiv:2206.00939, 2022. [COB18] L ena ıc Chizat, Edouard Oyallon, and Francis R. Bach. On lazy training in differentiable programming. In Neural Information Processing Systems, 2018. [DBK+20] Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, and Neil Houlsby. An image is worth 16x16 words: Transformers for image recognition at scale. Ar Xiv, abs/2010.11929, 2020. [DCLT19] Jacob Devlin, Ming-Wei Chang, Kenton Lee, and Kristina Toutanova. Bert: Pretraining of deep bidirectional transformers for language understanding. Ar Xiv, abs/1810.04805, 2019. [DHL18] Simon S Du, Wei Hu, and Jason D Lee. Algorithmic regularization in learning deep homogeneous models: Layers are automatically balanced. Advances in Neural Information Processing Systems, 31, 2018. [DLS22] Alexandru Damian, Jason Lee, and Mahdi Soltanolkotabi. Neural networks can learn representations with gradient descent. In Conference on Learning Theory, pages 5413 5452. PMLR, 2022. [FVB+22] Spencer Frei, Gal Vardi, Peter L Bartlett, Nathan Srebro, and Wei Hu. Implicit bias in leaky relu networks trained on high-dimensional data. ar Xiv preprint ar Xiv:2210.07082, 2022. 
[GBLJ19] Gauthier Gidel, Francis Bach, and Simon Lacoste-Julien. Implicit regularization of discrete gradient dynamics in linear neural networks. Advances in Neural Information Processing Systems, 32, 2019.
[GMMM19] Behrooz Ghorbani, Song Mei, Theodor Misiakiewicz, and Andrea Montanari. Limitations of lazy training of two-layers neural network. Advances in Neural Information Processing Systems, 32, 2019.
[GSJW19] Mario Geiger, Stefano Spigler, Arthur Jacot, and Matthieu Wyart. Disentangling feature and lazy learning in deep neural networks: an empirical study. arXiv, abs/1906.08034, 2019.
[GSSD19] Daniel Gissin, Shai Shalev-Shwartz, and Amit Daniely. The implicit bias of depth: How incremental learning drives generalization. arXiv preprint arXiv:1909.12051, 2019.
[HSW+21] Edward J Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, and Weizhu Chen. LoRA: Low-rank adaptation of large language models. arXiv preprint arXiv:2106.09685, 2021.
[JGH18] Arthur Jacot, Franck Gabriel, and Clément Hongler. Neural tangent kernel: Convergence and generalization in neural networks. Advances in Neural Information Processing Systems, 31, 2018.
[JGS+21] Arthur Jacot, Francois Gaston Ged, Berfin Simsek, Clément Hongler, and Franck Gabriel. Saddle-to-saddle dynamics in deep linear networks: Small initialization training, symmetry, and sparsity. 2021.
[JLL+23] Jikai Jin, Zhiyuan Li, Kaifeng Lyu, Simon S Du, and Jason D Lee. Understanding incremental learning of gradient descent: A fine-grained analysis of matrix sensing. arXiv preprint arXiv:2301.11500, 2023.
[KB14] Diederik P Kingma and Jimmy Ba. Adam: A method for stochastic optimization. arXiv preprint arXiv:1412.6980, 2014.
[KC22] Daesung Kim and Hye Won Chung. Rank-1 matrix completion with gradient descent and small random initialization. arXiv, abs/2212.09396, 2022.
[LFLY18] Chunyuan Li, Heerad Farkhoor, Rosanne Liu, and Jason Yosinski. Measuring the intrinsic dimension of objective landscapes. arXiv preprint arXiv:1804.08838, 2018.
[LLL20] Zhiyuan Li, Yuping Luo, and Kaifeng Lyu. Towards resolving the implicit bias of gradient descent for matrix factorization: Greedy low-rank learning. arXiv, abs/2012.09839, 2020.
[LOG+19] Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, and Veselin Stoyanov. RoBERTa: A robustly optimized BERT pretraining approach. arXiv, abs/1907.11692, 2019.
[MGW+20] Edward Moroshko, Suriya Gunasekar, Blake E. Woodworth, J. Lee, Nathan Srebro, and Daniel Soudry. Implicit bias in deep linear classification: Initialization scale vs training accuracy. arXiv, abs/2007.06738, 2020.
[MHPG+22] Alireza Mousavi-Hosseini, Sejun Park, Manuela Girotti, Ioannis Mitliagkas, and Murat A Erdogdu. Neural networks efficiently learn low-dimensional representations with SGD. arXiv preprint arXiv:2209.14863, 2022.
[MKAA21] Paolo Milanesi, Hachem Kadri, S. Ayache, and Thierry Artières. Implicit regularization in deep tensor factorization. 2021 International Joint Conference on Neural Networks (IJCNN), pages 1-8, 2021.
[MKAS21] Eran Malach, Pritish Kamath, Emmanuel Abbe, and Nathan Srebro. Quantifying the benefit of using differentiable learning over tangent kernels. In Marina Meila and Tong Zhang, editors, Proceedings of the 38th International Conference on Machine Learning, volume 139 of Proceedings of Machine Learning Research, pages 7379-7389. PMLR, 18-24 Jul 2021.
[PA23] Dylan Patel and Afzal Ahmad. Google "We have no moat, and neither does OpenAI", May 2023.
[PF23] Scott Pesme and Nicolas Flammarion. Saddle-to-saddle dynamics in diagonal linear networks. arXiv preprint arXiv:2304.00488, 2023.
[RC20] Noam Razin and Nadav Cohen. Implicit regularization in deep learning may not be explainable by norms. Advances in Neural Information Processing Systems, 33:21174-21187, 2020.
[SKZ+23] James B Simon, Maksis Knutins, Liu Ziyin, Daniel Geisz, Abraham J Fetterman, and Joshua Albrecht. On the stepwise nature of self-supervised learning. arXiv preprint arXiv:2303.15438, 2023.
[SMG13] Andrew M Saxe, James L McClelland, and Surya Ganguli. Exact solutions to the nonlinear dynamics of learning in deep linear neural networks. arXiv preprint arXiv:1312.6120, 2013.
[SS21] Dominik Stöger and Mahdi Soltanolkotabi. Small random initialization is akin to spectral learning: Optimization and generalization guarantees for overparameterized low-rank matrix reconstruction. In Neural Information Processing Systems, 2021.
[TGZ+23] Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li, Carlos Guestrin, Percy Liang, and Tatsunori B Hashimoto. Alpaca: A strong, replicable instruction-following model. Stanford Center for Research on Foundation Models. https://crfm.stanford.edu/2023/03/13/alpaca.html, 2023.
[TVS23] Nadav Timor, Gal Vardi, and Ohad Shamir. Implicit regularization towards rank minimization in ReLU networks. In International Conference on Algorithmic Learning Theory, pages 1429-1459. PMLR, 2023.
[VSP+17] Ashish Vaswani, Noam M. Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser, and Illia Polosukhin. Attention is all you need. arXiv, abs/1706.03762, 2017.
[WGL+19] Blake E. Woodworth, Suriya Gunasekar, J. Lee, Edward Moroshko, Pedro H. P. Savarese, Itay Golan, Daniel Soudry, and Nathan Srebro. Kernel and rich regimes in overparametrized models. arXiv, abs/2002.09277, 2019.
[YW23] Hao Yu and Jianxin Wu. Compressing transformers: Features are low-rank, but weights are not! 2023.
[ZZC+23] Jiawei Zhao, Yifei Zhang, Beidi Chen, Florian Schäfer, and Anima Anandkumar. InRank: Incremental low-rank learning. arXiv preprint arXiv:2306.11250, 2023.

1 Introduction
1.1 Related work
1.2 Paper organization
2 Preliminaries
3 Neural networks with diagonal weights
4 Incremental learning in networks with diagonal weights
4.1 Intuition for incremental learning dynamics
4.1.1 Stage 1 of dynamics
4.1.2 Later stages
4.2 Assumptions for incremental dynamics
5 Experimental results
6 Discussion
A Experimental validation of the assumptions in Theorem 4.1
B Further experiments on vision and language transformers
B.1 SGD-trained transformers
B.2 Adam-trained transformers
C Proof for dynamics of networks with diagonal parametrization (Theorem 4.1)
C.1 Assumptions
C.2 Rescaling time for notational convenience
C.3 Simplifying problem without loss of generality
C.4 Tracking the sum of the weights
C.5 Claimed invariants in proof of Theorem C.4
C.6 Dynamics from time $\underline{t}_k$ to time $\bar t_{k+1}$ (Linear dynamics for $O(\log(1/\alpha))$ unrescaled time)
C.6.1 Analysis in case where $T_{k+1} < \infty$
C.6.2 Analysis in case where $T_{k+1} = \infty$
C.7 Dynamics from time $\bar t_k$ to time $\underline{t}_k$ (Nonlinear evolution for $O(1)$ unrescaled time)
C.8 Concluding the proof of Theorem C.4
D Technical lemmas
D.1 Relating the sum of the weights to the original weights using the conservation law
D.2 Sign of gradients on coordinates that leave support
D.3 Local Lipschitzness and smoothness

A Experimental validation of the assumptions in Theorem 4.1
In Figures 2, 8, and 9, we plot the evolution of the losses, of the entries of $W_KW_Q^\top = \operatorname{diag}(w_K)\operatorname{diag}(w_Q)$, and of the entries of $W_VW_O^\top = \operatorname{diag}(w_V)\operatorname{diag}(w_O)$ in the toy task of training an attention head (1) with diagonal weights. The model is trained with SGD on the mean-squared error loss on 1000 random samples $(X, y)$. Each random sample has $X \in \mathbb{R}^{10 \times 50}$, which is a sequence of 10 tokens, each of dimension 50, distributed as an isotropic Gaussian. The label $y$ is given by a randomly-generated teacher model that is also an attention head (1) with diagonal weights. In Figures 2, 8, and 9, for $\alpha \in \{0.1, 0.01, 0.0001, 10^{-8}, 10^{-16}, 10^{-32}\}$ we plot the evolution of the loss and of the weights when initialized at $\theta(0) = \alpha\theta_0$, for some random Gaussian $\theta_0$. Qualitatively, as $\alpha \to 0$, the loss curve and the trajectories of the weights appear to converge to limiting stagewise dynamics, with plateaus followed by movement on short time-scales, as predicted by the theory.

Validation of Assumption 4.3 (non-degeneracy of dynamics) As $\alpha \to 0$, notice that the stages appear to separate and happen at distinct times. Furthermore, the extra technical condition on coordinates $i \in \operatorname{supp}(\theta^{k-1}) \setminus \operatorname{supp}(\theta^k)$ in Assumption 4.3 is satisfied, since no coordinates ever leave the support of $\theta^k$.

Validation of Assumption 4.4 (stationary points are strict local minima) In Figure 10 we consider the $\alpha = 10^{-32}$ trajectory, since this is closest to the dynamics in the $\alpha \to 0$ limit. We randomly select several epochs.
Since the transitions between stages are a vanishing fraction of the total training time, each of these randomly-selected epochs is likely during a plateau, as we see in the figure. For each epoch we perform the following experiment. For each nonnegligible coordinate of the weights (those where the weight has magnitude greater than the threshold $\tau = 10^{-5}$), we perturb the weights by adding noise of standard deviation 0.05. We then run the training dynamics starting at this perturbed initialization for 1000 epochs. We observe that the training dynamics quickly converge back to the original unperturbed weights, indicating that the weights were close to a strict local minimum of the loss.

Validation of Assumption 4.5 (noise-robustness of dynamics) In Figure 11 we perform the same experiment as in Figure 10, except that the epochs we select to perturb the weights are those where there is a newly-nonnegligible coordinate (less than $10^{-5}$ in magnitude in the previous epoch, and more than $10^{-5}$ in magnitude in this epoch). We find that the nonlinear dynamics are robust and tend to the limiting endpoint even under a random Gaussian perturbation of standard deviation $10^{-2}$ on each of the nonnegligible coordinates, supporting Assumption 4.5.

[Figure 8 panels: diagonal entries versus time / log(1/alpha), one panel for each alpha in {0.1, 0.01, 0.0001, 1e-08, 1e-16, 1e-32}]
Figure 8: Evolution of $\operatorname{diag}(w_Q)\operatorname{diag}(w_K)$ entries over rescaled time, initializing at various scalings $\alpha$.
Notice that as $\alpha \to 0$, the training trajectories tend to a limiting trajectory. Each line corresponds to a diagonal entry of $\operatorname{diag}(w_Q)\operatorname{diag}(w_K)$.

[Figure 9 panels: diagonal entries versus time / log(1/alpha), one panel for each alpha in {0.1, 0.01, 0.0001, 1e-08, 1e-16, 1e-32}]
Figure 9: Evolution of $\operatorname{diag}(w_V)\operatorname{diag}(w_O)$ entries in the toy task of learning an attention head with diagonal weights. Each line corresponds to the evolution of an entry of $\operatorname{diag}(w_V)\operatorname{diag}(w_O)$ over rescaled time. Each plot corresponds to a different initialization magnitude $\alpha$. Notice that as $\alpha \to 0$, the training trajectories tend to a limiting trajectory.
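The toy task above can be sketched in a few lines of NumPy. This is a minimal illustration under stated assumptions, not the code used for the figures: the teacher's diagonal weights are assumed to be i.i.d. standard Gaussian (the text only says "randomly-generated"), the helper names `softmax_rows` and `diag_attention` are hypothetical, and the SGD training loop is omitted. The sketch builds the 1000 samples, labels them with a diagonal-weight teacher head (1), and evaluates a student initialized at $\theta(0) = \alpha\theta_0$.

```python
import numpy as np


def softmax_rows(z):
    """Row-wise softmax, shifted by each row's max for numerical stability."""
    z = z - z.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


def diag_attention(X, w_K, w_Q, w_V, w_O):
    """Attention head (1) with W_K = diag(w_K), W_Q = diag(w_Q), etc.

    With diagonal weights, W_K W_Q^T = diag(w_K * w_Q) and
    W_V W_O^T = diag(w_V * w_O), so the head depends on the weights only
    through these two elementwise products.
    """
    scores = (X * (w_K * w_Q)) @ X.T                 # X W_K W_Q^T X^T, shape (n, n)
    return softmax_rows(scores) @ (X * (w_V * w_O))  # shape (n, d)


rng = np.random.default_rng(0)
n_samples, n_tokens, dim = 1000, 10, 50

# Hypothetical teacher: i.i.d. standard Gaussian diagonal weights.
teacher = [rng.standard_normal(dim) for _ in range(4)]

# Inputs: sequences of 10 isotropic Gaussian tokens of dimension 50.
X = rng.standard_normal((n_samples, n_tokens, dim))
y = np.stack([diag_attention(x, *teacher) for x in X])

# Student at small initialization theta(0) = alpha * theta_0.
alpha = 1e-8
theta_0 = [rng.standard_normal(dim) for _ in range(4)]
student = [alpha * w for w in theta_0]
pred = np.stack([diag_attention(x, *student) for x in X])
init_mse = np.mean((pred - y) ** 2)
```

Because $W_VW_O^\top$ is a product of two weight vectors of scale $\alpha$, the student's outputs are of order $\alpha^2$ at initialization, so the initial mean-squared error is essentially the mean of $y^2$, consistent with a long initial loss plateau at small $\alpha$.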
[Figure 10 panels: coordinate value of the weight entries versus time / log(1/alpha), two panels for each perturbation time, at time / log(1/alpha) = 373.4, 541.2, 778.0, 862.4, 942.2]
Figure 10: Evolution of weights of the toy attention model under perturbation, validating Assumption 4.4. At 5 different random times during training, we perturb the nonnegligible weight coordinates and continue to train with SGD. The evolution of each of the weights under the initial perturbation (solid line) is compared to the original evolution without perturbation (dashed line). Observe that the training dynamics quickly bring each weight back to the unperturbed weight trajectory, indicating that the weights are originally close to a strict local minimum.
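The perturbation experiments behind Figures 10 and 11 reduce to a few array operations: select the nonnegligible coordinates (magnitude above $10^{-5}$), detect newly-nonnegligible ones, and add Gaussian noise only on the nonnegligible coordinates. The sketch below is illustrative only: the function names are hypothetical, and restarting training from the perturbed weights is not shown.

```python
import numpy as np

TAU = 1e-5  # nonnegligibility threshold used in Appendix A


def nonnegligible(w, tau=TAU):
    """Mask of coordinates whose magnitude exceeds tau."""
    return np.abs(w) > tau


def newly_nonnegligible(w_prev, w_curr, tau=TAU):
    """Coordinates below tau at the previous epoch and above tau now."""
    return (np.abs(w_prev) < tau) & (np.abs(w_curr) > tau)


def perturb_nonnegligible(w, std, rng, tau=TAU):
    """Add N(0, std^2) noise to the nonnegligible coordinates only."""
    noise = rng.normal(0.0, std, size=w.shape)
    return np.where(nonnegligible(w, tau), w + noise, w)


rng = np.random.default_rng(0)
w = np.array([0.0, 3e-7, 0.4, -1.1])
w_fig10 = perturb_nonnegligible(w, std=0.05, rng=rng)  # noise level of Figure 10
w_fig11 = perturb_nonnegligible(w, std=1e-2, rng=rng)  # noise level of Figure 11
```

Negligible coordinates are left exactly in place, so the perturbation probes the loss only along the current support, which is what Assumptions 4.4 and 4.5 are stated for.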
[Figure 11 panels: coordinate value of the weight entries versus time / log(1/alpha), two panels for each perturbation time, at time / log(1/alpha) = 202.5, 321.4, 356.8, 405.4, 486.5]
Figure 11: Validating Assumption 4.5 with the same experiment as in Figure 10, except that the epochs chosen for the perturbation are those where there is a newly nonnegligible coordinate. Perturbed dynamics (solid lines) are again robust to perturbation and track the original dynamics (dashed lines).

B Further experiments on vision and language transformers
The practice of training transformer models often deviates substantially from the assumptions made in our theoretical analysis, and it is a priori unclear to what extent gradual rank increase behaviour and a low-rank bias are manifested in setups more common in practical applications.
To gauge the relevance of our findings, we conduct experiments on popular vision and language benchmarks, using algorithms and hyperparameters common in the literature. We use the stable rank of a matrix $W$, given by $\|W\|_F^2 / \|W\|_2^2$, as a smooth approximation of rank, and we track the stable rank of the different attention matrices throughout training. Although we do not expect our theoretical results to hold precisely in practice, we find evidence of a gradual increase in stable rank, leading to a low-rank bias, in Figures 12, 13, 15, 17 and 19. In these experiments we use off-the-shelf vision transformers (ViT) [DBK+20] trained on popular vision benchmarks, as well as off-the-shelf GPT-2 trained on a popular language benchmark. We use no weight decay or dropout in our experiments. All models were initialized using the default initialization scale.

B.1 SGD-trained transformers
CIFAR-10/100 We trained a 6-layer ViT with 8 heads per layer, embedding dimension 512, head dimension 128, MLP dimension 512 and patch size 4 for 500 epochs on CIFAR-10/CIFAR-100 with SGD, learning rate 3e-1 and warmup. See Figures 12 and 13. Each run took 2 hours on one A100 GPU.

[Figure 12 panels: stable rank versus iteration (0 to 50000), six panels]
Figure 12: CIFAR-10, ViT trained with SGD: Stable rank of $W_KW_Q^\top$ (blue) and $W_VW_O^\top$ (orange) throughout training. Mean and standard deviation (shaded area) are computed across 8 heads per attention layer.
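As a concrete reading of the stable-rank definition above, here is a NumPy sketch of the stable rank together with the per-head mean and standard deviation reported in the figures. It assumes each head's $W_K$ and $W_Q$ are stored as $d \times d_{\text{head}}$ matrices, and it uses random Gaussian weights with the CIFAR ViT sizes (embedding dimension 512, head dimension 128, 8 heads) as a stand-in for a trained model; the function names are hypothetical.

```python
import numpy as np


def stable_rank(W):
    """Stable rank ||W||_F^2 / ||W||_2^2, where ||W||_2 is the largest singular value."""
    fro_sq = np.linalg.norm(W, "fro") ** 2
    spec_sq = np.linalg.norm(W, 2) ** 2
    return fro_sq / spec_sq


def head_stable_ranks(W_K_heads, W_Q_heads):
    """Stable rank of W_K W_Q^T for each head (lists of d x d_head matrices)."""
    return np.array([stable_rank(W_K @ W_Q.T)
                     for W_K, W_Q in zip(W_K_heads, W_Q_heads)])


# Illustrative stand-in for one layer: 8 heads, embedding dimension 512,
# head dimension 128, with random Gaussian weights.
rng = np.random.default_rng(0)
d, d_head, n_heads = 512, 128, 8
W_K_heads = [rng.standard_normal((d, d_head)) / np.sqrt(d) for _ in range(n_heads)]
W_Q_heads = [rng.standard_normal((d, d_head)) / np.sqrt(d) for _ in range(n_heads)]
ranks = head_stable_ranks(W_K_heads, W_Q_heads)
layer_mean, layer_std = ranks.mean(), ranks.std()
```

The stable rank is invariant to rescaling $W$ and never exceeds $\operatorname{rank}(W)$, since $\|W\|_F^2 = \sum_i \sigma_i^2 \le \operatorname{rank}(W)\,\sigma_1^2$; this is why it serves as a smooth proxy for rank.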
[Figure 13 panels: stable rank versus iteration (0 to 50000), six panels]
Figure 13: CIFAR-100, ViT trained with SGD: Stable rank of $W_KW_Q^\top$ (blue) and $W_VW_O^\top$ (orange) throughout training. Mean and standard deviation (shaded area) are computed across 8 heads per attention layer.

B.2 Adam-trained transformers
CIFAR-10/100 For the CIFAR-10/100 datasets we use a ViT with 6 layers, patch size of 4, 8 heads per self-attention layer, an embedding and MLP dimension of 512, and a head dimension of 128. We train the model using the Adam optimizer for 500 epochs with a base learning rate of 1e-4, a cyclic learning rate decay with a linear warmup schedule for 15 epochs, and a batch size of 512. Our results are summarized in Figures 14 and 15 for CIFAR-10, and Figures 16 and 17 for CIFAR-100.

Figure 14: CIFAR-10, ViT trained with Adam: normalized spectrum at different stages of training. (a)-(c) Normalized spectrum of $W_KW_Q^\top$ at initialization and during training, for different attention heads at different layers. (d)-(e) Equivalent figures for $W_VW_O^\top$.

ImageNet For ImageNet, we use the ViT-Base/16 from [DBK+20] trained with Adam for 360 epochs with a base learning rate of 3e-3, a cyclic learning rate decay with a linear warmup schedule for 15 epochs, and a batch size of 4096. Our results are summarized in Figures 18 and 19.

Figure 15: CIFAR-10, ViT trained with Adam: Stable rank of $W_KW_Q^\top$ (blue) and $W_VW_O^\top$ (red) throughout training. Mean and standard deviation (shaded area) are computed across 8 heads per attention layer.

Figure 16: CIFAR-100, ViT trained with Adam: normalized spectrum at different stages of training.
(a)-(c) Normalized spectrum of $W_KW_Q^\top$ at initialization and during training, for different attention heads at different layers. (d)-(e) Equivalent figures for $W_VW_O^\top$.

Figure 17: CIFAR-100, ViT trained with Adam: Stable rank of $W_KW_Q^\top$ (blue) and $W_VW_O^\top$ (red) throughout training. Mean and standard deviation (shaded area) are computed across 8 heads per attention layer.

Figure 18: ImageNet, ViT trained with Adam: normalized spectrum at different stages of training. (a)-(c) Normalized spectrum of $W_KW_Q^\top$ at initialization and during training, for different attention heads at different layers. (d)-(e) Equivalent figures for $W_VW_O^\top$.

Figure 19: ImageNet, ViT trained with Adam: Stable rank of $W_KW_Q^\top$ (blue) and $W_VW_O^\top$ (red) throughout training. Mean and standard deviation (shaded area) are computed across 12 heads per attention layer.

Wikitext-103 The gradual rank increase phenomenon also occurs in the NLP setting with language transformers. We trained GPT-2 on Wikitext-103 using the Hugging Face training script with Adam, learning rate 3e-4, per-GPU batch size 8, and block length 256. We trained for 3 epochs on 2 A100 GPUs, which took 12 hours. See Figure 20.

[Figure 20 panels: stable rank versus iteration (0 to 80000), twelve panels]
Figure 20: Wikitext-103, GPT-2 trained with Adam: Stable rank of $W_VW_O^\top$ and $W_QW_K^\top$ versus training iteration. The stable rank of the perturbation increases gradually, but remains small throughout training.
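The normalized-spectrum plots compare the spectrum of a weight product at initialization with what is learned during training (Figure 1 describes this as the spectrum of the learned perturbation). A NumPy sketch, assuming that "normalized" means singular values divided by the largest one (the exact normalization is not stated here) and using a synthetic rank-2 update in place of actual training; the function names are hypothetical.

```python
import numpy as np


def normalized_spectrum(M):
    """Singular values of M in decreasing order, divided by the largest one."""
    s = np.linalg.svd(M, compute_uv=False)  # returned sorted in descending order
    return s / s[0]


def perturbation_spectrum(W_init, W_t):
    """Normalized spectrum of the learned perturbation W_t - W_init."""
    return normalized_spectrum(W_t - W_init)


rng = np.random.default_rng(0)
d = 64
W_init = rng.standard_normal((d, d))
# Hypothetical "trained" matrix: initialization plus a rank-2 update.
U = rng.standard_normal((d, 2))
V = rng.standard_normal((d, 2))
W_t = W_init + U @ V.T

spec_init = normalized_spectrum(W_init)
spec_delta = perturbation_spectrum(W_init, W_t)
```

The perturbation's normalized spectrum drops to (numerically) zero after its rank, whereas the random initialization has a spread-out spectrum; a low-rank learned perturbation therefore shows up as a few dominant normalized singular values. Note that `normalized_spectrum` divides by zero if the perturbation is exactly zero.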
C Proof for dynamics of networks with diagonal parametrization (Theorem 4.1)

C.1 Assumptions
Recall we have defined $\theta^0, \ldots, \theta^k, \ldots \in \mathbb{R}^{2p}$ as the sequence of weights such that $\theta^0 = 0$ and $\theta^{k+1}$ is defined inductively as follows. Consider the dynamics of $\psi_k(t, \epsilon) \in \mathbb{R}^{2p}$ initialized at $\psi_k(0, \epsilon) = \theta^k + (\epsilon e_{i_k}, \operatorname{sgn}(g_{i_k}(\theta^k))\,\epsilon e_{i_k})$ and evolving according to the gradient flow $\frac{d\psi_k(t,\epsilon)}{dt} = -\nabla_\theta L(\psi_k)$. We assume that there is a limiting point $\theta^{k+1}$ of these dynamics as $\epsilon$ is taken small and the time is taken large:
$$\lim_{\epsilon \to 0} \lim_{t \to \infty} \psi_k(t, \epsilon) = \theta^{k+1}.$$
Under the above assumption that this sequence $\theta^0, \ldots, \theta^k, \ldots$ is well-defined, we can derive a useful property of it for free. Namely, the conservation law (5) implies that $u \odot u - v \odot v$ is preserved. It follows that for each $k$, $\theta^k = (u^k, v^k)$ satisfies $|u^k| = |v^k|$ entrywise. In other words, there is $s^k \in \{+1, -1\}^p$ satisfying $\theta^k = (u^k, s^k \odot u^k) \in \mathbb{R}^{2p}$. We also abuse notation and write $\operatorname{supp}(\theta^k) := \operatorname{supp}(u^k) \subseteq [p]$, since the support of $\theta^k$ on the first $p$ coordinates matches its support on the last $p$ coordinates. Having fixed this notation, we now recall the main assumptions of the theorem.

Assumption C.1 (Nondegeneracy of dynamics in part (A); Assumption 4.3). The initialization satisfies $|u_i(0)| \neq |v_i(0)|$ for all $i$. For stage $k$, either $T_{k+1} = \infty$ or there is a unique minimizer $i_k$ to $\min_i \Delta_k(i)$ in (11). Finally, for all $i \in \operatorname{supp}(\theta^{k-1}) \setminus \operatorname{supp}(\theta^k)$ we have $g_i(\theta^k) \neq 0$.

Assumption C.2 (Stationary points are strict local minima; Assumption 4.4). For stage $k$, there exist $\delta_k > 0$ and $c_k > 0$ such that for $\tilde u \in B(u^k, \delta_k)$ supported on $\operatorname{supp}(u^k)$, we have
$$L(\tilde u, s^k \odot \tilde u) - L(\theta^k) \geq c_k \|u^k - \tilde u\|^2.$$

Assumption C.3 (Noise-robustness of dynamics in part (B); Assumption 4.5). For stage $k$, either $T_{k+1} = \infty$ or the following holds. For any $\epsilon > 0$, there are $\delta > 0$ and $\tau : \mathbb{R}_{>0} \to \mathbb{R}$ such that the following holds.
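For any $\tilde u \in B(u^k, \delta) \cap \mathbb{R}^p_{\geq 0}$ supported on $\operatorname{supp}(\tilde u) \subseteq \operatorname{supp}(u^k) \cup \{i_k\}$, there exists a unique solution $\psi : [0, \infty) \to \mathbb{R}^{2p}$ of the gradient flow $\frac{d\psi}{dt} = -\nabla_\theta L(\psi)$ initialized at $\psi(0) = (\tilde u, s^{k+1} \odot \tilde u)$, and at times $t \geq \tau(\tilde u_{i_k})$,
$$\|\psi(t) - \theta^{k+1}\| < \epsilon.$$

C.2 Rescaling time for notational convenience
For ease of notation, we rescale time:
$$u_\alpha(0) = \alpha u(0), \quad v_\alpha(0) = \alpha v(0), \quad \frac{du_\alpha}{dt} = \log(1/\alpha)\, v_\alpha \odot g(u_\alpha, v_\alpha), \quad \frac{dv_\alpha}{dt} = \log(1/\alpha)\, u_\alpha \odot g(u_\alpha, v_\alpha). \tag{14}$$
We also define $\theta_\alpha(t) = (u_\alpha(t), v_\alpha(t)) \in \mathbb{R}^{2p}$. Because of this time rescaling, we equivalently state Theorem 4.1 as:

Theorem C.4 (Restatement of Theorem 4.1). Let $K \in \mathbb{Z}_{\geq 0}$ be such that Assumptions 4.3 and 4.4 hold for all $k \leq K$ and Assumption 4.5 holds for all $k < K$. Then for any $k \leq K$ and time $t \in (T_k, T_{k+1})$ the following holds. There is $\alpha_0(t) > 0$ such that for all $\alpha < \alpha_0$, there exists a unique solution $\theta_\alpha : [0, t] \to \mathbb{R}^{2p}$ to the gradient flow (14) and
$$\lim_{\alpha \to 0} \theta_\alpha(t) = \theta^k,$$
where at each stage $|\operatorname{supp}(u^k) \setminus \operatorname{supp}(u^{k-1})| \leq 1$.

For shorthand, we also write $S_k = \operatorname{supp}(u^k)$ and $S_k^c = [p] \setminus \operatorname{supp}(u^k)$.

C.3 Simplifying problem without loss of generality
For each coordinate $i \in [p]$ we have $|u_{\alpha,i}(0)| \neq |v_{\alpha,i}(0)|$ by the non-degeneracy Assumption 4.3. So we can assume $|u_{\alpha,i}(0)| > |v_{\alpha,i}(0)|$ without loss of generality. Furthermore, we can assume the entrywise inequality $u_\alpha(0) > 0$, by otherwise training weights $\hat u_\alpha(t), \hat v_\alpha(t)$ initialized at $\hat u_\alpha(0) = \operatorname{sgn}(u_\alpha(0)) \odot u_\alpha(0)$ and $\hat v_\alpha(0) = \operatorname{sgn}(u_\alpha(0)) \odot v_\alpha(0)$, as $\hat u_\alpha(t) \odot \hat v_\alpha(t) = u_\alpha(t) \odot v_\alpha(t)$ at all times. Since $u_{\alpha,i}^2(t) - v_{\alpha,i}^2(t) = u_{\alpha,i}^2(0) - v_{\alpha,i}^2(0)$ by the conservation law (5), it holds that $|u_{\alpha,i}(t)| > |v_{\alpha,i}(t)|$ throughout. So by continuity $u_\alpha(t) > 0$ throughout training.

C.4 Tracking the sum of the weights
We define $w_\alpha(t) = u_\alpha(t) + v_\alpha(t)$. The reason for this definition is that during training we have
$$\frac{dw_\alpha}{dt} = \log(1/\alpha)\, w_\alpha \odot g(\theta_\alpha). \tag{15}$$
Notice that since we have assumed $u_{\alpha,i}(0) > |v_{\alpha,i}(0)|$ for each $i \in [p]$, we have $w_\alpha(0) > 0$ entrywise. So, by (15), $w_\alpha(t) > 0$ for all $t > 0$.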
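To make the step from (14) to (15) explicit, add the two equations of (14). As an idealization (assuming that $g(\theta_\alpha)$ stays at a fixed vector $g$ on an interval $[t_0, t]$, which holds only approximately during a plateau), (15) can then be integrated coordinatewise:

$$
\frac{dw_\alpha}{dt} = \frac{du_\alpha}{dt} + \frac{dv_\alpha}{dt}
= \log(1/\alpha)\,(v_\alpha + u_\alpha) \odot g(\theta_\alpha)
= \log(1/\alpha)\, w_\alpha \odot g(\theta_\alpha),
$$

$$
w_\alpha(t) = w_\alpha(t_0) \odot \exp\!\big((t - t_0)\log(1/\alpha)\, g\big) = w_\alpha(t_0) \odot \alpha^{-(t - t_0)\, g},
\qquad
\log_\alpha w_\alpha(t) = \log_\alpha w_\alpha(t_0) - (t - t_0)\, g .
$$

So on the logarithmic scale each coordinate of $w_\alpha$ moves linearly in rescaled time at rate $-g$; Lemma C.6 quantifies this with error terms when $g(\theta_\alpha)$ only stays close to $g(\theta^k)$.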
It suffices to track $w_\alpha(t)$ because we can relate the log-scale magnitude of $w_\alpha(t)$ to the magnitudes of the corresponding coordinates in $u_\alpha(t)$ and $v_\alpha(t)$; see technical Lemmas D.1, D.2 and D.3.

C.5 Claimed invariants in proof of Theorem C.4
In order to prove Theorem C.4, we consider any gradient flow $\theta_\alpha : [0, T^*] \to \mathbb{R}^{2p}$ solving (14) where $T^* \in (T_K, T_{K+1})$. For now, we focus only on proving properties of this gradient flow, and defer its existence and uniqueness to Section C.8. We show the following invariants inductively on the stage $k$. For any $\epsilon > 0$ and any stage $k \leq K$, there is $\alpha_k := \alpha_k(\epsilon) > 0$ such that for all $\alpha < \alpha_k$ the following holds. There are times $\underline{t}_k := \underline{t}_k(\alpha, \epsilon)$ and $\bar t_{k+1} := \bar t_{k+1}(\alpha, \epsilon)$ such that
$$\underline{t}_k \in [T_k - \epsilon, T_k + \epsilon], \tag{16}$$
$$\bar t_{k+1} \in \begin{cases} [T_{k+1} - \epsilon, T_{k+1} + \epsilon], & \text{if } T_{k+1} < \infty, \\ \{T^*\}, & \text{if } T_{k+1} = \infty, \end{cases} \tag{17}$$
and the weights approximate the greedy limit for all times $t \in [\underline{t}_k, \bar t_{k+1}]$:
$$\|\theta_\alpha(t) - \theta^k\| < \epsilon, \tag{18}$$
and the weights at times $\underline{t}_k$ and $\bar t_{k+1}$ are correctly estimated by the incremental learning dynamics on the logarithmic scale:
$$\|\log_\alpha(w_\alpha(\underline{t}_k)) - b^k\| < \epsilon \tag{19}$$
and, if $T_{k+1} < \infty$,
$$\|\log_\alpha(w_\alpha(\bar t_{k+1})) - b^{k+1}\| < \epsilon. \tag{20}$$

Base case $k = 0$: Take $\underline{t}_0(\alpha, \epsilon) = 0$. Then statement (16) holds since $T_0 = 0$. Notice that as $\alpha \to 0$ we have $u_\alpha(0), v_\alpha(0) \to 0 = u^0$, and also $\log_\alpha w_\alpha(0) \to \mathbf{1} = b^0$. So statement (19) follows if we take $\alpha_0$ small enough. In Section C.6 we show how to construct time $\bar t_1$ such that (18) and (20) hold.

Inductive step: Suppose that (16), (18), (19) and (20) hold for some iteration $k < K$. We prove them for iteration $k + 1$. In Section C.7 we construct time $\underline{t}_k$. In Section C.6 we construct time $\bar t_{k+1}$.

C.6 Dynamics from time $\underline{t}_k$ to time $\bar t_{k+1}$ (Linear dynamics for $O(\log(1/\alpha))$ unrescaled time)
Let $k \leq K$, and suppose that we know that for any $\epsilon_k > 0$, there is $\alpha_k(\epsilon_k) > 0$ such that for all $0 < \alpha < \alpha_k$, there is a time $\underline{t}_k = \underline{t}_k(\alpha, \epsilon_k)$ satisfying
$$|T_k - \underline{t}_k| < \epsilon_k, \qquad \|\theta_\alpha(\underline{t}_k) - \theta^k\| < \epsilon_k, \qquad \|\log_\alpha(w_\alpha(\underline{t}_k)) - b^k\| < \epsilon_k.$$

C.6.1 Analysis in case where $T_{k+1} < \infty$
Consider first the case where $T_{k+1} < \infty$.
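We show that, for any $\epsilon_{k+1} > 0$, there is $\rho_{k+1}(\epsilon_{k+1}) > 0$ such that for all $0 < \rho < \rho_{k+1}(\epsilon_{k+1})$ there is $\alpha_{k+1}(\rho, \epsilon_{k+1}) > 0$ such that for all $\alpha < \alpha_{k+1}$, there is a time $\bar t_{k+1} = \bar t_{k+1}(\alpha, \rho, \epsilon_{k+1})$ satisfying
$$|T_{k+1} - \bar t_{k+1}| < \epsilon_{k+1}, \tag{21}$$
$$\|\theta_\alpha(t) - \theta^k\| < \epsilon_{k+1} \text{ for all } t \in [\underline{t}_k, \bar t_{k+1}], \tag{22}$$
$$\|\log_\alpha(w_\alpha(\bar t_{k+1})) - b^{k+1}\| < \epsilon_{k+1}, \tag{23}$$
$$u_{\alpha,i_k}(\bar t_{k+1}) \in [\rho, 3\rho], \tag{24}$$
$$\operatorname{sgn}(v_{\alpha,i_k}(\bar t_{k+1})) = s^{k+1}_{i_k}. \tag{25}$$
For any $\rho, \alpha$, let $\epsilon_k = \rho\epsilon_{k+1}/(4p)$ and choose $\underline{t}_k = \underline{t}_k(\alpha, \epsilon_k)$. Then define
$$\bar t_{k+1} = \bar t_{k+1}(\alpha, \rho, \epsilon_{k+1}) = \inf\{t \in [\underline{t}_k, \infty) : \|u_{\alpha,S_k^c}(t) - u_{\alpha,S_k^c}(\underline{t}_k)\| + \|v_{\alpha,S_k^c}(t) - v_{\alpha,S_k^c}(\underline{t}_k)\| > 4\rho\}. \tag{26}$$
Now we show that the weights $\theta_\alpha(t)$ cannot move much from time $\underline{t}_k$ to $\bar t_{k+1}$. The argument uses the local Lipschitzness of the loss $L$ (from technical Lemma D.7) and the strictness of $\theta^k$ as a stationary point (from Assumption 4.4).

Lemma C.5 (Stability of active variables during part (A) of dynamics). There is $\rho_{k+1}$ small enough and $\alpha_{k+1}(\rho)$ small enough depending on $\rho$, such that for all $\rho < \rho_{k+1}$ and $\alpha < \alpha_{k+1}$ and all $t \in [\underline{t}_k, \bar t_{k+1})$,
$$\|\theta_\alpha(t) - \theta^k\| < \rho' := \max\left(24\rho,\ 18\sqrt{\rho K_{R_k}/c_k}\right), \tag{27}$$
where $c_k$ is the strict-minimum constant from Assumption 4.4 and $K_{R_k}$ is the Lipschitz constant from Lemma D.7 for the ball of radius $R_k = \|\theta^k\| + 1$.

Proof. Assume by contradiction that (27) is violated at some time $t < \bar t_{k+1}$. Let us choose the first such time
$$t' = \inf\{t \in [\underline{t}_k, \bar t_{k+1}) : \|u_\alpha(t) - u^k\| + \|v_\alpha(t) - s^k \odot u^k\| \geq \rho'\}.$$
Define $\tilde\theta = (\tilde u, \tilde v)$ by
$$\tilde u_i = \begin{cases} u_{\alpha,i}(t'), & i \in S_k, \\ 0, & i \notin S_k, \end{cases} \qquad \tilde v_i = \begin{cases} v_{\alpha,i}(t'), & i \in S_k, \\ 0, & i \notin S_k. \end{cases}$$
By the definition of $\bar t_{k+1}$, this satisfies
$$\|\tilde u - u_\alpha(t')\| = \|u_{\alpha,S_k^c}(t')\| \leq 4\rho + \|u_{\alpha,S_k^c}(\underline{t}_k)\| \leq 4\rho + \epsilon_k < 5\rho,$$
$$\|\tilde v - v_\alpha(t')\| = \|v_{\alpha,S_k^c}(t')\| \leq 4\rho + \|v_{\alpha,S_k^c}(\underline{t}_k)\| \leq 4\rho + \epsilon_k < 5\rho.$$
Hence
$$\|\tilde u - u^k\| + \|\tilde v - s^k \odot u^k\| = \|u_{\alpha,S_k}(t') - z^k_{S_k}\| + \|v_{\alpha,S_k}(t') - s^k_{S_k} \odot z^k_{S_k}\| \geq \rho' - 10\rho \geq \rho'/2.$$
Using (a) the strict minimum Assumption 4.4 with constant $c_k$, since $\|\tilde\theta - \theta^k\| \leq \rho'$ and we take $\rho$ small enough,
$$L(\theta_\alpha(t')) \geq L(\tilde\theta) - 4\rho K_{R_k} \overset{(a)}{\geq} L(\theta^k) - 4\rho K_{R_k} + \frac{c_k(\rho')^2}{16} \geq L(\theta_\alpha(\underline{t}_k)) - (4\rho + \epsilon_k)K_{R_k} + \frac{c_k(\rho')^2}{16} > L(\theta_\alpha(\underline{t}_k)).$$
This is a contradiction because $L$ is nonincreasing along the gradient flow.

Lemma C.6 (Log-scale approximation is correct during part (A)).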
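There are functions $\rho_{k+1}(\epsilon_{k+1}) > 0$ and $\alpha_{k+1}(\rho, \epsilon_{k+1}) > 0$ such that for all $\rho < \rho_{k+1}$ and $\alpha < \alpha_{k+1}$, and for all $t \in (\underline{t}_k, \bar t_{k+1})$, we have, for a constant $C$ depending on $k$,
$$\|\log_\alpha(w_\alpha(t)) - b^k + (t - \underline{t}_k)\, g(\theta^k)\| < \rho\epsilon_{k+1} + C\rho'(t - \underline{t}_k). \tag{28}$$
Furthermore, for all $i \in S_k^c$ and $t \in (\underline{t}_k, \bar t_{k+1})$ we have
$$\operatorname{sgn}(g_i(\theta_\alpha(t))) = \operatorname{sgn}(g_i(\theta^k)). \tag{29}$$

Proof. By Lemma C.5 and Lemma D.7, there is a constant $C$ depending on $\theta^k$ such that for all $t \in (\underline{t}_k, \bar t_{k+1})$, $\|g(\theta_\alpha(t)) - g(\theta^k)\| \leq C\rho'$. For shorthand, write $\bar g(\theta^k) = g(\theta^k) + C\rho'\mathbf{1}$ and $\underline{g}(\theta^k) = g(\theta^k) - C\rho'\mathbf{1}$. Since $w_\alpha(t) > 0$ entrywise, as we have assumed without loss of generality (see Section C.3), we have the following entrywise inequalities:
$$\underline{g}(\theta^k) \odot w_\alpha(t) \leq g(\theta_\alpha(t)) \odot w_\alpha(t) \leq \bar g(\theta^k) \odot w_\alpha(t). \tag{30}$$
Since the dynamics are given by $\frac{dw_\alpha}{dt} = \log(1/\alpha)\, g(\theta_\alpha) \odot w_\alpha$,
$$w_\alpha(\underline{t}_k) \odot e^{(t - \underline{t}_k)\log(1/\alpha)\,\underline{g}(\theta^k)} \leq w_\alpha(t) \leq w_\alpha(\underline{t}_k) \odot e^{(t - \underline{t}_k)\log(1/\alpha)\,\bar g(\theta^k)}.$$
Taking logarithms with base $\alpha \in (0, 1)$,
$$(t - \underline{t}_k)\,\underline{g}(\theta^k) \leq \log_\alpha(w_\alpha(\underline{t}_k)) - \log_\alpha(w_\alpha(t)) \leq (t - \underline{t}_k)\,\bar g(\theta^k).$$
The bound (28) follows since $\|\log_\alpha(w_\alpha(\underline{t}_k)) - b^k\| < \epsilon_k < \rho\epsilon_{k+1}$. Finally, the claim (29) follows from (30) since $\operatorname{sgn}(\underline{g}_i(\theta^k)) = \operatorname{sgn}(g_i(\theta^k)) = \operatorname{sgn}(\bar g_i(\theta^k))$ if we take $\rho$ small enough.

First, we show that the weights must move significantly by time roughly $T_{k+1}$. This is because of the contribution of coordinate $i_k$.

Lemma C.7 ($\bar t_{k+1}$ is not much larger than $T_{k+1}$). Suppose that $T_{k+1} < \infty$. Then there are $\rho_{k+1}(\epsilon_{k+1}) > 0$ and $\alpha_{k+1}(\rho, \epsilon_{k+1}) > 0$ such that for all $\rho < \rho_{k+1}$ and $\alpha < \alpha_{k+1}$, the following holds:
$$\bar t_{k+1} < T_{k+1} + \epsilon_{k+1}.$$

Proof. Assume by contradiction that $\bar t_{k+1} \geq T_{k+1} + \epsilon_{k+1}$. For all times $t \in [\underline{t}_k, \min(\bar t_{k+1}, T_{k+1} + \epsilon_{k+1})]$, by Lemma C.6,
$$|\log_\alpha(w_{\alpha,i_k}(t)) - b^k_{i_k} + (t - \underline{t}_k)\, g_{i_k}(\theta^k)| < O(\rho').$$
Since we know $|\Delta_k(i_k) - (T_{k+1} - \underline{t}_k)| < \epsilon_k$ and $b^k_{i_k} - \Delta_k(i_k)\, g_{i_k}(\theta^k) \in \{0, 2\}$, it follows that
$$\log_\alpha(w_{\alpha,i_k}(T_{k+1} + \epsilon_{k+1})) \notin \left(-|g_{i_k}(\theta^k)|(\epsilon_{k+1} - \epsilon_k),\ 2 + |g_{i_k}(\theta^k)|(\epsilon_{k+1} - \epsilon_k)\right) + O(\rho').$$
By taking $\rho$ small enough, we see that $|g_{i_k}(\theta^k)|(\epsilon_{k+1} - \epsilon_k) + O(\rho') > \delta > 0$ for some $\delta > 0$ that is independent of $\alpha$, so
$$\log_\alpha(w_{\alpha,i_k}(T_{k+1} + \epsilon_{k+1})) \notin (-\delta, 2 + \delta).$$
So $|u_{\alpha,i_k}(T_{k+1} + \epsilon_{k+1})| > 1$ by Lemma D.2.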
But by the construction of $t_{k+1}$, this means that $t_{k+1} < T_{k+1} + \epsilon_{k+1}$, a contradiction.

Next, we show that until time $t_{k+1}$, none of the coordinates in $S_k^c$ move significantly, with the possible exception of coordinate $i_k$.

Lemma C.8 (No coordinates in $S_k^c \setminus \{i_k\}$ move significantly during part (A)). Suppose $T_{k+1} < \infty$. Then there are $\rho_{k+1}(\epsilon_{k+1}) > 0$ and $\alpha_{k+1}(\rho, \epsilon_{k+1}) > 0$ such that for all $\rho < \rho_{k+1}$ and $\alpha < \alpha_{k+1}$, the following holds. There is a constant $c > 0$ depending on $k$ such that for all $i \in S_k^c \setminus \{i_k\}$ and $t \in [\bar t_k, t_{k+1}]$,

$$
|u_{\alpha,i}(t) - u_{\alpha,i}(\bar t_k)|,\ |v_{\alpha,i}(t) - v_{\alpha,i}(\bar t_k)| < \alpha^c + \bar\epsilon_k.
$$

Proof. The previous lemma combined with the inductive hypothesis gives $t_{k+1} - \bar t_k < \Delta_k(i_k) + 2\epsilon_{k+1}$. We analyze the movement of each coordinate $i \in S_k^c \setminus \{i_k\}$ by breaking into cases.

Coordinate $i \neq i_k$ such that $b^k_i \in (0, 2)$. By Assumption 4.3, there is a unique winning coordinate, so $b^k_i - \tau g_i(\theta^k) \in (c, 2 - c)$ for some constant $c > 0$, for all $\tau \in [0, t_{k+1} - \bar t_k] \subseteq [0, \Delta_k(i_k) + 2\epsilon_{k+1}]$. By Lemma C.6, $\log_\alpha(w_{\alpha,i}(t)) \in (c/2, 2 - c/2)$ for all times $t \in [\bar t_k, t_{k+1}]$. So by Lemma D.1, $|u_{\alpha,i}(t)|, |v_{\alpha,i}(t)| \leq \alpha^{c/4}$.

Coordinate $i \neq i_k$ such that $b^k_i = 0$. By Lemma D.4, we must be in the corner case where $i \in S_{k-1} \cap S_k^c$ (i.e., the coordinate was active in the previous stage but was dropped from the support in this stage). By Lemma D.4, since $b^k_i = 0$ we have $g_i(\theta^k) < 0$. By Lemma C.6, this means $\operatorname{sgn}(g_i(\theta_\alpha(t))) = \operatorname{sgn}(g_i(\theta^k)) < 0$ for all $t \in (\bar t_k, t_{k+1})$. We break the analysis into two parts. Since $b^k_i = 0$, the sign is $s^k_i = +1$. The inductive hypothesis $\|\theta_\alpha(\bar t_k) - \theta^k\| < \bar\epsilon_k$ implies that $|u_{\alpha,i}(\bar t_k) - z^k_i| < \bar\epsilon_k$ and $|v_{\alpha,i}(\bar t_k) - z^k_i| < \bar\epsilon_k$. For small enough $\bar\epsilon_k$, this means that $\operatorname{sgn}(u_{\alpha,i}(\bar t_k)) = \operatorname{sgn}(v_{\alpha,i}(\bar t_k)) = +1$. Now let $t^* = \min(t_{k+1}, \inf\{t > \bar t_k : v_{\alpha,i}(t) = 0\})$. Since $u_{\alpha,i}(t) > v_{\alpha,i}(t)$ without loss of generality (see Section C.3), we have $\operatorname{sgn}(u_{\alpha,i}(t)) = \operatorname{sgn}(v_{\alpha,i}(t)) = +1$ for all $t \in [\bar t_k, t^*]$. So $\frac{du_{\alpha,i}(t)}{dt}, \frac{dv_{\alpha,i}(t)}{dt} < 0$ for all $t \in [\bar t_k, t^*]$.
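So, for any $t \in [\bar t_k, t^*]$,

$$
|u_{\alpha,i}(t) - u_{\alpha,i}(\bar t_k)|,\ |v_{\alpha,i}(t) - v_{\alpha,i}(\bar t_k)| < \bar\epsilon_k.
$$

Also, since $\log_\alpha(w_{\alpha,i}(t^*)) \geq 1$, by Lemma C.6 we have $t^* - \bar t_k > c' > 0$ for some constant $c'$ independent of $\alpha$. So for all $t \in [t^*, t_{k+1}]$ we have $b^k_i - \tau g_i(\theta^k) \in (c, 2 - c)$, with $\tau = t - \bar t_k$, for some constant $c > 0$. So $|u_{\alpha,i}(t)|, |v_{\alpha,i}(t)| \leq \alpha^{c/4}$ for all $t \in [t^*, t_{k+1}]$. The conclusion follows by the triangle inequality.

Coordinate $i \neq i_k$ such that $b^k_i = 2$. The analysis is analogous to the case $b^k_i = 0$, except that we have $s^k_i = -1$ instead and $g_i(\theta^k) > 0$ by Lemma D.4.

Finally, we use this to conclude that $t_{k+1} \approx T_{k+1}$ and that the weights at coordinate $i_k$ are the only weights that change significantly, and by an amount approximately $\rho$.

Lemma C.9 (Coordinate $i_k$ wins the part (A) race at time $t_{k+1} \approx T_{k+1}$). Suppose that $T_{k+1} < \infty$. Then there are $\rho_{k+1}(\epsilon_{k+1}) > 0$ and $\alpha_{k+1}(\rho, \epsilon_{k+1}) > 0$ such that for all $\rho < \rho_{k+1}$ and $\alpha < \alpha_{k+1}$, the following holds:

$$
|t_{k+1} - T_{k+1}| < \epsilon_{k+1}, \qquad u_{\alpha,i_k}(t_{k+1}) \in [\rho, 3\rho], \qquad \operatorname{sgn}(v_{\alpha,i_k}(t_{k+1})) = s^{k+1}_{i_k}.
$$

Proof. Let us analyze the case that $b^k_{i_k} \in (0, 2)$. Notice that $b^{k+1}_{i_k} = b^k_{i_k} - \Delta_k(i_k)\, g_{i_k}(\theta^k) \in \{0, 2\}$, and that if $b^{k+1}_{i_k} = 0$ then $g_{i_k}(\theta^k) > 0$, while if it is $2$ then $g_{i_k}(\theta^k) < 0$. So by Lemma C.6, for all times $t \in [\bar t_k, \min(t_{k+1}, T_{k+1} - \epsilon_{k+1})]$, we have $\log_\alpha(w_{\alpha,i_k}(t)) \in (c, 2 - c)$ for some $c > 0$. So for small enough $\alpha$, by Lemma D.1, $|u_{\alpha,i_k}(t)|, |v_{\alpha,i_k}(t)| \leq \alpha^{c/2}$. Combining this with Lemma C.8, we see that for $t \in [\bar t_k, \min(t_{k+1}, T_{k+1} - \epsilon_{k+1})]$ we have

$$
\|u_\alpha(t) - u_\alpha(\bar t_k)\| + \|v_\alpha(t) - v_\alpha(\bar t_k)\| < 2(\alpha^c + \bar\epsilon_k)p < \rho
$$

for small enough $\alpha$. So by the definition of $t_{k+1}$ we must have $t_{k+1} > T_{k+1} - \epsilon_{k+1}$. Combined with Lemma C.7, we conclude that $|T_{k+1} - t_{k+1}| < \epsilon_{k+1}$, which is the first claim of the lemma. Furthermore, by Lemma C.8,

$$
\sum_{i \in S_k^c \setminus \{i_k\}} |u_{\alpha,i}(t_{k+1}) - u_{\alpha,i}(\bar t_k)| + |v_{\alpha,i}(t_{k+1}) - v_{\alpha,i}(\bar t_k)| \leq 2p(\alpha^c + \bar\epsilon_k) < \rho/2,
$$

so by the definition of $t_{k+1}$ and the triangle inequality we have $|u_{\alpha,i_k}(t_{k+1})| + |v_{\alpha,i_k}(t_{k+1})| \geq 4\rho - \rho/2 = 7\rho/2$. Also, since $u^2_{\alpha,i_k}(t_{k+1}) - v^2_{\alpha,i_k}(t_{k+1}) = \Theta(\alpha^2)$, we have $u_{\alpha,i_k}(t_{k+1}) \in [\rho, 3\rho]$.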
Finally, if $b^{k+1}_{i_k} = 2$, then $s^{k+1}_{i_k} = -1$ and $\log_\alpha(w_{\alpha,i_k}(t_{k+1})) > 1.5$, so $\operatorname{sgn}(v_{\alpha,i_k}(t_{k+1})) < 0$ by Lemma D.3; analogously, if $b^{k+1}_{i_k} = 0$, we have $s^{k+1}_{i_k} = +1$ and $\log_\alpha(w_{\alpha,i_k}(t_{k+1})) < 0.5$, so $\operatorname{sgn}(v_{\alpha,i_k}(t_{k+1})) > 0$.

The case $b^k_{i_k} \in \{0, 2\}$ can be proved similarly to the analysis in Lemma C.8, where one shows that during the first period of time the magnitudes $|u_{i_k}(t)|$ and $|v_{i_k}(t)|$ decrease, until the sign of $v_{i_k}$ flips and they once again increase.

We have shown the claims (21), (22), (23), (24), and (25) for the time $t_{k+1}$. In fact, if we let $t'_{k+1} \in [\bar t_k, \infty)$ be the first time $t$ such that $u_{\alpha,i_k}(t) = \rho$, we still have (21), (22), (23), and (25) by the same analysis as above, and (24) can be replaced with the slightly more convenient

$$
u_{\alpha,i_k}(t'_{k+1}) = \rho.
$$

C.6.2 Analysis in case where $T_{k+1} = \infty$

In the case $T_{k+1} = \infty$, we just have to show that the weights remain close to $\theta^k$. We show that for any $\epsilon_{k+1} > 0$, there is $\alpha_{k+1}(\epsilon_{k+1}) > 0$ such that for all $\alpha < \alpha_{k+1}$ and times $t \in [T_k + \epsilon_{k+1}, T^*]$, $\|\theta_\alpha(t) - \theta^k\| < \epsilon_{k+1}$. We can use Lemmas C.5 and C.6, which were developed for the case $T_{k+1} < \infty$ but still hold for $T_{k+1} = \infty$. Lemma C.5 guarantees that the weights do not move much until time $t_{k+1}$, so we only need to show that $t_{k+1} \geq T^*$ when we take $\rho$ small enough. For this, observe that $g_i(\theta^k) = 0$ for all $i \in S_k^c$, because otherwise $T_{k+1} < \infty$. Therefore Lemma C.6 guarantees that until time $\min(T^*, t_{k+1})$ all weights are close to their original values on the logarithmic scale. Namely,

$$
\|\log_\alpha(w_\alpha(t)) - b^k\| < \rho\epsilon_{k+1} + C\rho'(T^* - \bar t_k).
$$

Furthermore, by the non-degeneracy Assumption 4.3 and Lemma D.4, we know that $b^k_i \in (0, 2)$ for all $i \in S_k^c$. So if we take $\rho$ small enough and $\alpha_{k+1}$ small enough, we must have $t_{k+1} > T^*$.
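The bookkeeping in Lemmas C.7 to C.9 and Section C.6.2 amounts to a race on the logarithmic scale. Each inactive coordinate $i \in S_k^c$ moves linearly as $b^k_i - \tau g_i(\theta^k)$, and the first one to leave $(0, 2)$ determines $i_k$ and $T_{k+1} - T_k$. The Python sketch below only illustrates that arithmetic, under stated assumptions. It takes $\Delta_k(i)$ to be the first $\tau > 0$ with $b^k_i - \tau g_i(\theta^k) \notin (0, 2)$ and treats $g(\theta^k)$ as given. The helper names `exit_time` and `race` are hypothetical. A coordinate with $g_i(\theta^k) = 0$ never exits, and when this holds for every coordinate we are in the $T_{k+1} = \infty$ case.

```python
import math


def exit_time(b_i, g_i):
    """First tau > 0 with b_i - tau * g_i outside (0, 2) (hypothetical helper).

    Assumes b_i lies in [0, 2] and, on the boundary, that g_i points inward,
    as Lemma D.4 guarantees for inactive coordinates.
    """
    if g_i > 0:
        return b_i / g_i  # drifts down and exits through 0
    if g_i < 0:
        return (2.0 - b_i) / (-g_i)  # drifts up and exits through 2
    return math.inf  # g_i == 0: the coordinate never leaves (0, 2)


def race(b, g):
    """Return (winner, delta, b_next) for the inactive coordinates b, g.

    winner is the index that exits first (None if no coordinate ever exits),
    delta is its exit time, and b_next = b - delta * g. Ties are assumed away
    (unique winner, as in Assumption 4.3); list.index would pick the first.
    """
    times = [exit_time(b_i, g_i) for b_i, g_i in zip(b, g)]
    delta = min(times)
    if math.isinf(delta):
        return None, math.inf, list(b)  # the T_{k+1} = infinity case
    winner = times.index(delta)
    b_next = [b_i - delta * g_i for b_i, g_i in zip(b, g)]
    return winner, delta, b_next


winner, delta, b_next = race([1.0, 0.5, 1.5], [0.5, -1.0, 2.0])
print(winner, delta, b_next)
```

With these illustrative inputs, coordinate 2 exits through 0 at $\Delta = 0.75$, while the other two coordinates remain strictly inside $(0, 2)$, so the winner is unique.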
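C.7 Dynamics from time $t_k$ to time $\bar t_k$ (Nonlinear evolution for $O(1)$ unrescaled time)

Suppose that we know for some $k \leq K$ that for any $\epsilon_k > 0$, there is $\rho_k(\epsilon_k) > 0$ such that for all $\rho < \rho_k$ there is $\alpha_k(\rho, \epsilon_k) > 0$ such that for all $\alpha < \alpha_k$, there is a time $t_k = t_k(\alpha, \rho, \epsilon_k)$ satisfying

$$
\begin{aligned}
&|T_k - t_k| < \epsilon_k &&(31)\\
&\|\theta_\alpha(t_k) - \theta^{k-1}\| < \epsilon_k &&(32)\\
&\|\log_\alpha(w_\alpha(t_k)) - b^k\| < \epsilon_k &&(33)\\
&u_{\alpha,i_{k-1}}(t_k) = \rho &&(34)\\
&\operatorname{sgn}(v_{\alpha,i_{k-1}}(t_k)) = s^k_{i_{k-1}}. &&(35)
\end{aligned}
$$

Now we will show that for any $\bar\epsilon_k > 0$, there is $\bar\alpha_k = \bar\alpha_k(\bar\epsilon_k) > 0$ such that for all $0 < \alpha < \bar\alpha_k$, there is a time $\bar t_k = \bar t_k(\alpha, \bar\epsilon_k)$ satisfying

$$
\begin{aligned}
&|T_k - \bar t_k| < \bar\epsilon_k &&(36)\\
&\|\theta_\alpha(\bar t_k) - \theta^k\| < \bar\epsilon_k &&(37)\\
&\|\log_\alpha(w_\alpha(\bar t_k)) - b^k\| < \bar\epsilon_k. &&(38)
\end{aligned}
$$

We give the construction for $\bar t_k$. For any desired accuracy $\bar\epsilon_k > 0$ in this stage, we construct an accuracy $\epsilon_k = \epsilon_k(\bar\epsilon_k) = \bar\epsilon_k/3 > 0$. We also construct a sufficiently small $\rho = \rho(\bar\epsilon_k) > 0$, and a cutoff for $\alpha$ equal to $\bar\alpha_k = \bar\alpha_k(\bar\epsilon_k) > 0$ which satisfies $\bar\alpha_k < \alpha_k(\rho, \epsilon_k)$. The values of the parameters $\epsilon_k$, $\rho$, and $\bar\alpha_k$ are chosen in the following lemma, and depend only on $\bar\epsilon_k$.

Lemma C.10 (New local minimum reached in time $O(1/\log(1/\alpha))$). For any $\bar\epsilon_k > 0$, we can choose $\bar\alpha_k = \bar\alpha_k(\bar\epsilon_k) > 0$ small enough so that, for any $0 < \alpha < \bar\alpha_k$, there is $\bar t_k = \bar t_k(\alpha, \bar\epsilon_k)$ for which conditions (36) to (38) hold. Furthermore, there is a constant $C'$ independent of $\alpha$ such that $|\theta_\alpha(t)|/|\theta_\alpha(t_k)| \in [1/C', C']^{2p}$ (entrywise) at all times $t \in [t_k, \bar t_k]$.

Proof. Let $t_k = t_k(\alpha, \rho, \epsilon_k)$ be given by the induction. Let us compare the dynamics starting at $\theta_\alpha(t_k)$ with the dynamics starting at $\tilde\theta(t_k) = (\tilde u(t_k), \tilde v(t_k))$, which is given by

$$
\tilde u_i(t_k) = \begin{cases} u_{\alpha,i}(t_k), & i \in S_{k-1} \cup \{i_{k-1}\} \\ 0, & \text{otherwise} \end{cases}
\qquad \text{and} \qquad
\tilde v_i(t_k) = \begin{cases} v_{\alpha,i}(t_k), & i \in S_{k-1} \cup \{i_{k-1}\} \\ 0, & \text{otherwise,} \end{cases}
$$

and run with $\frac{d\tilde\theta}{dt} = -\log(1/\alpha)\,\nabla_\theta L(\tilde\theta)$. By Assumption 4.5 we know there exists a unique solution $\tilde\theta : [t_k, \infty) \to \mathbb{R}^{2p}$ as long as we take $\epsilon_k$ small enough, because $\operatorname{supp}(\tilde\theta(t_k)) = S_{k-1} \cup \{i_{k-1}\}$ and $\|\tilde\theta(t_k) - \theta^{k-1}\| < \epsilon_k$. Furthermore, by Assumption 4.5, if we take $\epsilon_k$ small enough there must be a time $\tau := \tau(\epsilon_k, \rho) < \infty$ such that

$$
\|\tilde\theta(t) - \theta^k\| < \bar\epsilon_k/2 \quad \text{for } t \geq t_k + \tau/\log(1/\alpha). \tag{39}
$$

Define $\bar t_k = t_k + \tau/\log(1/\alpha)$.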
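So for $\alpha$ small enough, $|T_k - \bar t_k| < 2\epsilon_k < \bar\epsilon_k$, proving (36).

We now compare $\theta_\alpha(\bar t_k)$ with $\tilde\theta(\bar t_k)$, and show that if we take $\alpha$ small enough, then the dynamics of $\tilde\theta$ closely match the dynamics of $\theta_\alpha(t)$ for times $t_k + O(1/\log(1/\alpha))$. The argument uses Grönwall's inequality. Let $t^* = \inf\{t > t_k : \|\tilde\theta(t) - \theta_\alpha(t)\| > 1/3\}$. For times $t \in [t_k, t^*)$, by Lemma D.7 we have

$$
\left\|\frac{d}{dt}\big(\tilde\theta(t) - \theta_\alpha(t)\big)\right\| = \log(1/\alpha)\,\|\nabla_\theta L(\tilde\theta(t)) - \nabla_\theta L(\theta_\alpha(t))\| \leq K_{\tilde\theta(t)} \log(1/\alpha)\, \|\tilde\theta(t) - \theta_\alpha(t)\|,
$$

where $K_{\tilde\theta(t)}$ is the smoothness constant from Lemma D.7. Note that since $\|\tilde\theta(t)\| < \infty$ for large enough $t$ by (39), the trajectory of $\tilde\theta$ must lie in a compact set. Therefore, there must be a finite set of times $s_1, \ldots, s_m \in [t_k, t^*)$ such that $\bigcup_{t \in [t_k, t^*)} B(\tilde\theta(t), 1/2) \subseteq \bigcup_{i=1}^m B(\tilde\theta(s_i), 3/4)$. So letting $C = \max_{i=1}^m K_{\tilde\theta(s_i)} < \infty$, for all times $t \in [t_k, t^*)$ we have

$$
\frac{d}{dt}\|\tilde\theta(t) - \theta_\alpha(t)\| \leq C\log(1/\alpha)\,\|\tilde\theta(t) - \theta_\alpha(t)\|.
$$

By Grönwall's inequality, for all times $t \in [t_k, t^*)$,

$$
\|\tilde\theta(t) - \theta_\alpha(t)\| \leq \|\tilde\theta(t_k) - \theta_\alpha(t_k)\| \exp\big(C\log(1/\alpha)(t - t_k)\big).
$$

We know from Lemma C.8 that there is a constant $c > 0$ such that, for any small enough $0 < \alpha < \bar\alpha_k$, $\|\tilde\theta(t_k) - \theta_\alpha(t_k)\| < \alpha^c$. If we take $\alpha$ small enough that $\alpha^c \exp(C\tau) < \bar\epsilon_k/2 < 1/3$, we must have $t^* > t_k + \tau/\log(1/\alpha)$, and so we prove (37):

$$
\|\theta^k - \theta_\alpha(\bar t_k)\| \leq \bar\epsilon_k/2 + \|\tilde\theta(\bar t_k) - \theta_\alpha(\bar t_k)\| < \bar\epsilon_k.
$$

It remains to show that (38) is satisfied. Since $\|\tilde\theta(t) - \theta_\alpha(t)\| < 1/3$ for all $t \in [t_k, \bar t_k]$, the trajectory of $\theta_\alpha(t)$ lies in a compact set. So by Lemma D.7 we have $\|g(\theta_\alpha(t))\| < C'$ for some constant $C'$ at all times $t \in [t_k, \bar t_k]$. Since

$$
\frac{1}{\log(1/\alpha)}\left|\frac{dw_{\alpha,i}}{dt}\right| = |w_{\alpha,i}(t)|\,|g_i(\theta_\alpha(t))| < C'|w_{\alpha,i}(t)|,
$$

and the interval $[t_k, \bar t_k]$ has length $\tau/\log(1/\alpha)$, we must have $|w_{\alpha,i}(t)|/|w_{\alpha,i}(t_k)| \in [1/C'', C'']$ with $C'' = e^{C'\tau}$, a constant independent of $\alpha$, for all $t \in [t_k, \bar t_k]$. Therefore, (38) follows from (33). A similar argument shows that $|\theta_\alpha(t)/\theta_\alpha(t_k)| \in [1/C', C']^{2p}$.

C.8 Concluding the proof of Theorem C.4

We have shown that Theorem 4.1 is true for solutions $\theta_\alpha : [0, T^*] \to \mathbb{R}^{2p}$ to the gradient flow, where $T^* \in (T_K, T_{K+1})$. To establish Theorem C.4, it remains only to show that for any $T^* \in (T_K, T_{K+1})$ and small enough $\alpha$, such a solution to the gradient flow exists and is unique.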
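To see this, note that in the inductive proof of the invariants we construct a sequence of times $0 = \bar t_0 \leq t_1 \leq \bar t_1 \leq \cdots \leq \bar t_K \leq t_{K+1}$ with $t_{K+1} > T^*$, where we guarantee that any gradient flow solution $\theta_\alpha : [0, t_{K+1}] \to \mathbb{R}^{2p}$ satisfies $\theta_\alpha(t) \in \bigcup_{k \in \{0,\ldots,K\}} B(\theta^k, 1)$ for all $t \in \bigcup_{k \in \{0,\ldots,K\}} [\bar t_k, t_{k+1}]$. Also, for $t \in \bigcup_{k \in \{0,\ldots,K-1\}} [t_{k+1}, \bar t_{k+1}]$, we have $\theta_\alpha(t) \in B(0, C'_k\|\theta^k\|)$ for some constant $C'_k$ independent of $\alpha$, by Lemma C.10. So $\theta_\alpha(t) \in B(0, C_K)$ for some constant $C_K$ at all times $t \in [0, T^*]$. By Lemma D.7, the loss gradient $\nabla_\theta L(\theta) = -(v \odot g(\theta), u \odot g(\theta))$ is Lipschitz-continuous on the compact set $B(0, C_K)$. So $\theta_\alpha : [0, T^*] \to \mathbb{R}^{2p}$ exists and is unique by the Cauchy-Lipschitz theorem.

D Technical lemmas

D.1 Relating the sum of the weights to the original weights using the conservation law

Lemma D.1. If for some constant $0 < c < 1$ we have $\log_\alpha(w_{\alpha,i}(t)) \in (c, 2 - c)$, then for small enough $\alpha$,

$$
\max(|u_{\alpha,i}(t)|, |v_{\alpha,i}(t)|) \leq \alpha^{c/2}.
$$

Proof. Let $\tilde w_\alpha(t) = u_\alpha(t) - v_\alpha(t)$. By the conservation law (5),

$$
w_{\alpha,i}(t)\,\tilde w_{\alpha,i}(t) = w_{\alpha,i}(0)\,\tilde w_{\alpha,i}(0) = u_{\alpha,i}(0)^2 - v_{\alpha,i}(0)^2.
$$

By the non-degeneracy of initialization (Assumption 4.3), the right-hand side is $\Theta(\alpha^2)$. So if $\log_\alpha(w_{\alpha,i}(t)) \in (c, 2 - c)$, then for small enough $\alpha$ we have $\log_\alpha(|\tilde w_{\alpha,i}(t)|) \in (3c/4, 2 - 3c/4)$. So $|u_{\alpha,i}(t)| \leq |w_{\alpha,i}(t) + \tilde w_{\alpha,i}(t)| \leq \alpha^{c/2}$ and $|v_{\alpha,i}(t)| \leq |w_{\alpha,i}(t) - \tilde w_{\alpha,i}(t)| \leq \alpha^{c/2}$.

Lemma D.2. If for some constant $c > 0$ we have $\log_\alpha(w_{\alpha,i}(t)) \notin [-c, 2 + c]$, then for small enough $\alpha$,

$$
|u_{\alpha,i}(t)| > 1.
$$

Proof. Define $\tilde w_\alpha = u_\alpha - v_\alpha$ as in the proof of Lemma D.1. If $\log_\alpha(w_{\alpha,i}(t)) < -c$, then $\log_\alpha(|\tilde w_{\alpha,i}(t)|) > 2 + c/2$ for small enough $\alpha$, so $|u_{\alpha,i}(t)| \geq (\alpha^{-c} - \alpha^{2+c/2})/2 > 1$. Similarly, if $\log_\alpha(w_{\alpha,i}(t)) > 2 + c$, then $\log_\alpha(|\tilde w_{\alpha,i}(t)|) < -c/2$, so $|u_{\alpha,i}(t)| \geq (\alpha^{-c/2} - \alpha^{2+c})/2 > 1$.

Lemma D.3. For any constant $c > 0$ and small enough $\alpha$, if $\log_\alpha(w_{\alpha,i}(t)) > 1 + c$ then $\operatorname{sgn}(v_{\alpha,i}(t)) < 0$. Otherwise, if $\log_\alpha(w_{\alpha,i}(t)) < 1 - c$ then $\operatorname{sgn}(v_{\alpha,i}(t)) > 0$.

Proof. This follows from $v_\alpha = \frac{1}{2}(w_\alpha - \tilde w_\alpha)$. Recall that $w_\alpha(t) > 0$, and notice that $\tilde w_\alpha(t) > 0$. In the first case, $w_{\alpha,i}(t) < \alpha^{1+c}$ and $\tilde w_{\alpha,i}(t) > \alpha^{1-c/2}$. In the latter case, $w_{\alpha,i}(t) > \alpha^{1-c}$ and $\tilde w_{\alpha,i}(t) < \alpha^{1+c/2}$.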
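The conservation law behind Lemmas D.1 to D.3 can be checked numerically on a single coordinate. The Python sketch below integrates $du/dt = \log(1/\alpha)\, g\, v$ and $dv/dt = \log(1/\alpha)\, g\, u$ with classical RK4. We assume this coordinatewise form of the rescaled gradient flow because it is consistent with $dw/dt = \log(1/\alpha)\, g \odot w$ from the proof of Lemma C.6. As a simplifying assumption that is not part of the proofs, $g$ is frozen at a constant value, which is what the log-scale approximation of Lemma C.6 does to first order. The values of $\alpha$, $g$, and the initialization are illustrative choices, and `simulate_coordinate` and `log_alpha` are hypothetical helper names. The script reports three quantities:

- whether $u^2 - v^2 = w\tilde w$ stays constant,
- whether $\log_\alpha w(t)$ moves linearly as $\log_\alpha w(0) - t g$,
- whether $v > 0$ once $\log_\alpha w < 1$, as Lemma D.3 predicts.

```python
import math


def simulate_coordinate(alpha, g, u0, v0, t_end, dt=1e-3):
    """Integrate du/dt = lam*g*v, dv/dt = lam*g*u with lam = log(1/alpha), via RK4.

    g is held fixed (a frozen-gradient simplifying assumption).
    """
    lam = math.log(1.0 / alpha)

    def rhs(u, v):
        return lam * g * v, lam * g * u

    u, v = u0, v0
    for _ in range(int(round(t_end / dt))):
        k1u, k1v = rhs(u, v)
        k2u, k2v = rhs(u + 0.5 * dt * k1u, v + 0.5 * dt * k1v)
        k3u, k3v = rhs(u + 0.5 * dt * k2u, v + 0.5 * dt * k2v)
        k4u, k4v = rhs(u + dt * k3u, v + dt * k3v)
        u += dt / 6.0 * (k1u + 2.0 * k2u + 2.0 * k3u + k4u)
        v += dt / 6.0 * (k1v + 2.0 * k2v + 2.0 * k3v + k4v)
    return u, v


def log_alpha(x, alpha):
    """Logarithm with base alpha in (0, 1)."""
    return math.log(x) / math.log(alpha)


# Illustrative values: small initialization with u0 > |v0| and u0^2 - v0^2 = Theta(alpha^2).
alpha, g = 1e-3, 1.0
u0, v0 = alpha, 0.5 * alpha
t_end = 0.9

u, v = simulate_coordinate(alpha, g, u0, v0, t_end)
w0, w = u0 + v0, u + v
print("u^2 - v^2 at start and end:", u0**2 - v0**2, u**2 - v**2)
print("log_alpha w(t_end):", log_alpha(w, alpha))
print("linear prediction :", log_alpha(w0, alpha) - t_end * g)
print("v(t_end) > 0:", v > 0)
```

Because $\tilde w = u - v$ shrinks exactly as fast as $w$ grows, the weights stay of order $\alpha$ on the linear scale for most of the run. The time-rescaled clock makes the change visible on the $\log_\alpha$ scale.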
D.2 Sign of gradients on coordinates that leave support

Lemma D.4. For any $k \geq 1$ and $i \in S_k^c$, if $b^k_i \in \{0, 2\}$ then we must have $i \in \operatorname{supp}(u^{k-1}) \setminus \operatorname{supp}(u^k)$, and we must have $g_i(\theta^k) < 0$ if $b^k_i = 0$ and $g_i(\theta^k) > 0$ if $b^k_i = 2$. In particular, $\Delta_k(i_k) > 0$ for all $k$.

Proof. This is by induction on $k$, using the non-degeneracy Assumption 4.3.

D.3 Local Lipschitzness and smoothness

We provide several technical lemmas on the local Lipschitzness and smoothness of $\ell$, $h$, and $g$.

Lemma D.5. The function $\ell(y, \cdot)$ is locally Lipschitz and smooth in its second argument: for any $R > 0$, there exists $K_R$ such that for any $\zeta, \zeta' \in B(0, R)$,

$$
|\ell(y, \zeta) - \ell(y, \zeta')| \leq K_R\|\zeta - \zeta'\|, \qquad \|D\ell(y, \zeta) - D\ell(y, \zeta')\| \leq K_R\|\zeta - \zeta'\|,
$$

almost surely over $y$. Here $D\ell(y, \cdot) \in \mathbb{R}^{d_{\mathrm{out}}}$ is the derivative in the second argument.

Proof. Since $\ell$ is continuously twice-differentiable, for each $y \in \mathbb{R}^{d_y}$ and $\zeta \in \mathbb{R}^{d_{\mathrm{out}}}$ there is $K_{y,\zeta} < \infty$ such that for all $y' \in B(y, 1/K_{y,\zeta})$ and $\zeta' \in B(\zeta, 1/K_{y,\zeta})$ we have $\|D\ell(y', \zeta')\| \leq K_{y,\zeta}$ and $\|D^2\ell(y', \zeta')\| \leq K_{y,\zeta}$, where $D\ell$ and $D^2\ell$ denote the first and second derivatives in the second argument. So for all such $y' \in B(y, 1/K_{y,\zeta})$ and all $\zeta', \zeta'' \in B(\zeta, 1/K_{y,\zeta})$, we have $|\ell(y', \zeta') - \ell(y', \zeta'')| \leq K_{y,\zeta}\|\zeta' - \zeta''\|$ and $\|D\ell(y', \zeta') - D\ell(y', \zeta'')\| \leq K_{y,\zeta}\|\zeta' - \zeta''\|$. Cover the set $\{(y, \zeta) : \|y\| \leq C, \|\zeta\| \leq R\}$ with the balls $B(y, 1/K_{y,\zeta}) \times B(\zeta, 1/K_{y,\zeta})$. By compactness, there is a finite subcover with centers $(y_1, \zeta_1), \ldots, (y_r, \zeta_r)$, so we can take $K_R = \max_{i \in [r]} K_{y_i,\zeta_i} < \infty$, and the lemma holds since $\|y\| \leq C$ almost surely by Assumption 2.1.

Lemma D.6. The function $h(x; \cdot)$ is locally bounded, Lipschitz, and smooth in its second argument: for any $R > 0$ there exists $K_R$ such that for any $\psi, \psi' \in B(0, R)$,

$$
\|h(x; \psi)\| \leq K_R, \qquad \|h(x; \psi) - h(x; \psi')\| \leq K_R\|\psi - \psi'\|, \qquad \|Dh(x; \psi) - Dh(x; \psi')\| \leq K_R\|\psi - \psi'\|,
$$

almost surely over $x$. Here $Dh(x; \cdot) \in \mathbb{R}^{d_{\mathrm{out}} \times p}$ is the derivative in the second argument.

Proof. Analogous to the proof of Lemma D.5, using the continuous twice-differentiability of $h$ and the boundedness of $\|x\|$.

Lemma D.7 (Local Lipschitzness of loss and loss derivative). When $\theta = (u, v) \in \mathbb{R}^{2p}$ and $f_{\mathrm{NN}}(x; \theta) = h(x; u \odot v)$, the following holds for $g(\theta)$ defined in (4).
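For any $R > 0$, there exists $K_R < \infty$ such that for any $\theta, \theta' \in B(0, R)$,

$$
\begin{aligned}
\|g(\theta) - g(\theta')\| &\leq K_R\|\theta - \theta'\|, \\
\|\nabla_\theta L(\theta) - \nabla_\theta L(\theta')\| &\leq K_R\|\theta - \theta'\|, \\
|L(\theta) - L(\theta')| &\leq K_R\|\theta - \theta'\|.
\end{aligned}
$$

Proof. Let $\theta = (u, v)$ and $\theta' = (u', v')$. This follows immediately from the local Lipschitzness and smoothness of $h$ and $\ell$ in Lemmas D.5 and D.6, as well as

$$
\|g(\theta) - g(\theta')\| = \left\|\mathbb{E}_{x,y}\left[Dh(x; u \odot v)^\top D\ell(y, h(x; u \odot v)) - Dh(x; u' \odot v')^\top D\ell(y, h(x; u' \odot v'))\right]\right\|.
$$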
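The Python sketch below makes the objects in Lemma D.7 concrete with a hypothetical toy model, under these assumptions:

- a linear $h(x; \beta) = \langle x, \beta \rangle$ with scalar output,
- the square loss $\ell(y, \zeta) = \frac{1}{2}(\zeta - y)^2$,
- an empirical average over a small random sample in place of $\mathbb{E}_{x,y}$.

It also assumes the sign convention $g(\theta) = -\mathbb{E}_{x,y}[Dh(x; u \odot v)^\top D\ell(y, h(x; u \odot v))]$. We take this convention to be consistent with $dw/dt = \log(1/\alpha)\, g \odot w$ and with $\nabla_\theta L(\theta) = -(v \odot g(\theta), u \odot g(\theta))$. Definition (4) is not reproduced in this section, so the sign is an assumption. The script compares this formula for $\nabla_\theta L$ with central finite differences.

```python
import random


def predict(x, beta):
    """Toy linear model h(x; beta) = <x, beta> (an illustrative assumption)."""
    return sum(x_j * b_j for x_j, b_j in zip(x, beta))


def loss(u, v, data):
    """Empirical square loss L(theta) with beta = u * v (Hadamard product)."""
    beta = [u_j * v_j for u_j, v_j in zip(u, v)]
    return sum(0.5 * (predict(x, beta) - y) ** 2 for x, y in data) / len(data)


def g(u, v, data):
    """g(theta) = -E[Dh^T Dl]; here Dh^T = x and Dl(y, z) = z - y (assumed sign)."""
    beta = [u_j * v_j for u_j, v_j in zip(u, v)]
    out = [0.0] * len(u)
    for x, y in data:
        residual = predict(x, beta) - y
        for j, x_j in enumerate(x):
            out[j] -= x_j * residual / len(data)
    return out


def grad_loss(u, v, data):
    """Gradient of L in (u, v), computed as -(v * g(theta), u * g(theta))."""
    g_val = g(u, v, data)
    return ([-v_j * g_j for v_j, g_j in zip(v, g_val)],
            [-u_j * g_j for u_j, g_j in zip(u, g_val)])


def finite_diff_grad(u, v, data, eps=1e-6):
    """Central finite differences of L with respect to each entry of u and v."""
    grads = []
    for which in range(2):  # 0 perturbs u, 1 perturbs v
        grad = []
        for j in range(len(u)):
            plus = [list(u), list(v)]
            minus = [list(u), list(v)]
            plus[which][j] += eps
            minus[which][j] -= eps
            grad.append((loss(*plus, data) - loss(*minus, data)) / (2 * eps))
        grads.append(grad)
    return grads[0], grads[1]


rng = random.Random(0)
p, n = 3, 5
data = [([rng.gauss(0, 1) for _ in range(p)], rng.gauss(0, 1)) for _ in range(n)]
u = [rng.uniform(-1, 1) for _ in range(p)]
v = [rng.uniform(-1, 1) for _ in range(p)]

analytic = grad_loss(u, v, data)
numeric = finite_diff_grad(u, v, data)
max_err = max(abs(a - b) for ga, gb in zip(analytic, numeric) for a, b in zip(ga, gb))
print("max |analytic - finite difference| =", max_err)
```

In this toy instance, $L$ is a quartic polynomial in $(u, v)$. The local Lipschitz constants of Lemma D.7 therefore grow with the radius $R$, which is why the lemma is stated only on bounded balls.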