# Learning threshold neurons via the edge of stability

Kwangjun Ahn (MIT EECS, Cambridge, MA; kjahn@mit.edu), Sébastien Bubeck (Microsoft Research, Redmond, WA; sebubeck@microsoft.com), Sinho Chewi (Institute for Advanced Study, Princeton, NJ; schewi@ias.edu), Yin Tat Lee (Microsoft Research, Redmond, WA; yintat@uw.edu), Felipe Suarez (Carnegie Mellon University, Pittsburgh, PA; felipesc@mit.edu), Yi Zhang (Microsoft Research, Redmond, WA; zhayi@microsoft.com)

Existing analyses of neural network training often operate under the unrealistic assumption of an extremely small learning rate. This lies in stark contrast to practical wisdom and empirical studies, such as the work of J. Cohen et al. (ICLR 2021), which exhibit startling new phenomena (the "edge of stability" or "unstable convergence") and potential benefits for generalization in the large learning rate regime. Despite a flurry of recent works on this topic, however, the latter effect is still poorly understood. In this paper, we take a step towards understanding genuinely non-convex training dynamics with large learning rates by performing a detailed analysis of gradient descent for simplified models of two-layer neural networks. For these models, we provably establish the edge of stability phenomenon and discover a sharp phase transition for the step size below which the neural network fails to learn "threshold-like" neurons (i.e., neurons with a non-zero first-layer bias). This elucidates one possible mechanism by which the edge of stability can in fact lead to better generalization, as threshold neurons are basic building blocks with useful inductive bias for many tasks.
[Figure 1: two panels plotted against the learning rate $\eta$ (0 to 0.012), each with a zoomed-in range from 0 to 0.0015; the second panel shows the final test accuracy (%).]

Figure 1: Large step sizes are necessary to learn the threshold neuron of a ReLU network (2) for a simple binary classification task (1). We choose $d = 200$, $n = 300$, $\lambda = 3$, and run gradient descent with the logistic loss. The weights are initialized as $a^-, a^+ \sim \mathcal{N}(0, 1/(2d))$ and $b = 0$. For each learning rate $\eta$, we set the iteration number such that the total time elapsed (iteration $\times\, \eta$) is 10. The vertical dashed lines indicate our theoretical prediction of the phase transition phenomenon (precise threshold at $\eta = 8\pi/d^2$).

## 1 Introduction

How much do we understand about the training dynamics of neural networks? We begin with a simple and canonical learning task which indicates that the answer is still "far too little".

37th Conference on Neural Information Processing Systems (NeurIPS 2023).

Motivating example: Consider a binary classification task of labeled pairs $(x^{(i)}, y^{(i)}) \in \mathbb{R}^d \times \{\pm 1\}$ where each covariate $x^{(i)}$ consists of a 1-sparse vector (in an unknown basis) corrupted by additive Gaussian noise, and the label $y^{(i)}$ is the sign of the non-zero coordinate of the 1-sparse vector. Due to rotational symmetry, we can take the unknown basis to be the standard one and write

$$x^{(i)} = \lambda\, y^{(i)} e_{j(i)} + \xi^{(i)} \in \mathbb{R}^d, \tag{1}$$

where $y^{(i)} \in \{\pm 1\}$ is a random label, $j(i) \in [d]$ is a random index, $\xi^{(i)}$ is Gaussian noise, and $\lambda > 1$ is the unknown signal strength. In fact, (1) is a special case of the well-studied sparse coding model (Olshausen and Field, 1997; Vinje and Gallant, 2000; Olshausen and Field, 2004; Yang et al., 2009; Koehler and Risteski, 2018; Allen-Zhu and Li, 2022). We ask the following fundamental question:

How do neural networks learn to solve the sparse coding problem (1)?
In spite of the simplicity of the setting, a full resolution to this question requires a thorough understanding of surprisingly rich dynamics which lie out of reach of existing theory. To illustrate this point, consider an extreme simplification in which the basis $e_1, \dots, e_d$ is known in advance, for which it is natural to parametrize a two-layer ReLU network as

$$f(x; a^-, a^+, b) = a^- \sum_{i=1}^d \mathrm{ReLU}\big(-x[i] + b\big) + a^+ \sum_{i=1}^d \mathrm{ReLU}\big(+x[i] + b\big). \tag{2}$$

The parametrization (2) respects the latent data structure (1) well: a good network has a negative bias $b$ to threshold out the noise, and has $a^- < 0$ and $a^+ > 0$ to output correct labels. We are particularly interested in understanding the mechanism by which the bias $b$ becomes negative, thereby allowing the non-linear ReLU activation to act as a threshold function; we refer to this as the problem of learning "threshold neurons". More broadly, such threshold neurons are of interest as they constitute basic building blocks for producing neural networks with useful inductive bias.

[Figure 2: three panels, one per step size $\eta = 2.5 \times 10^{-5}$, $2.5 \times 10^{-4}$, and $2.5 \times 10^{-3}$, plotted against iteration (0 to 100); one axis is labeled "training loss".]

Figure 2: Large learning rates lead to unexpected phenomena: non-monotonic loss and wild oscillations of weights. We choose the same setting as Figure 1. With a small learning rate ($\eta = 2.5 \times 10^{-5}$), the bias does not decrease noticeably, and the same is true even when we increase the learning rate by ten times ($\eta = 2.5 \times 10^{-4}$). When we increase the learning rate by another ten times ($\eta = 2.5 \times 10^{-3}$), we finally see a noticeable decrease in the bias, but with this we observe unexpected behavior: the loss decreases non-monotonically and the sum of second-layer weights $d\,(a^- + a^+)$ oscillates wildly.

We train the parameters $a^-, a^+, b$ using gradient descent with step size $\eta > 0$ on the logistic loss $\sum_{i=1}^n \ell_{\mathrm{logi}}\big(y^{(i)} f(x^{(i)}; a^-, a^+, b)\big)$, where $\ell_{\mathrm{logi}}(z) := \log(1 + \exp(-z))$, and we report the results in Figures 1 and 2.
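To make the data model (1) and the parametrization (2) concrete, here is a minimal pure-Python sketch. The text leaves the distributions of the label, index, and noise unspecified; we assume $y^{(i)}$ uniform on $\{\pm 1\}$, $j(i)$ uniform on $[d]$, and $\xi^{(i)} \sim \mathcal{N}(0, I_d)$. The helper names (`sample_sparse_coding`, `relu_net`, `logistic_objective`, `accuracy`) are hypothetical and our own. With the hand-picked weights $a^- = -1$, $a^+ = 1$, the demo compares $b = 0$ with a negative bias, illustrating why thresholding out the noise helps.

```python
import math
import random


def sample_sparse_coding(n, d, lam, noise_std=1.0, seed=0):
    """Draw n pairs from model (1): x = lam * y * e_j + xi.

    Assumed (not fixed by the text): y uniform on {-1, +1}, j uniform on
    {0, ..., d-1}, and xi ~ N(0, noise_std^2 * I_d).
    """
    rng = random.Random(seed)
    xs, ys = [], []
    for _ in range(n):
        y = rng.choice((-1, 1))
        j = rng.randrange(d)
        x = [rng.gauss(0.0, noise_std) for _ in range(d)]
        x[j] += lam * y  # plant the 1-sparse signal in coordinate j
        xs.append(x)
        ys.append(y)
    return xs, ys


def relu_net(x, a_minus, a_plus, b):
    """ReLU network (2) with known basis: one shared bias b for all 2d neurons."""
    neg = sum(max(-xi + b, 0.0) for xi in x)
    pos = sum(max(xi + b, 0.0) for xi in x)
    return a_minus * neg + a_plus * pos


def logistic_objective(xs, ys, a_minus, a_plus, b):
    """Sum over samples of l_logi(y * f(x)) with l_logi(z) = log(1 + exp(-z))."""
    total = 0.0
    for x, y in zip(xs, ys):
        z = y * relu_net(x, a_minus, a_plus, b)
        # numerically stable evaluation of log(1 + exp(-z))
        total += max(-z, 0.0) + math.log1p(math.exp(-abs(z)))
    return total


def accuracy(xs, ys, a_minus, a_plus, b):
    """Fraction of samples with y * f(x) > 0 (ties count as errors)."""
    hits = sum(1 for x, y in zip(xs, ys) if y * relu_net(x, a_minus, a_plus, b) > 0)
    return hits / len(xs)


xs, ys = sample_sparse_coding(n=2000, d=200, lam=3.0, seed=1)
for b in (0.0, -2.5):
    print(f"b = {b:+.1f}: accuracy = {accuracy(xs, ys, -1.0, 1.0, b):.3f}")
```

With $b = 0$ and these weights the network outputs $\sum_i x[i]$, so all $d$ noise coordinates enter the prediction; a negative bias zeroes out most of them, which is the "threshold" behaviour described above.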
The experiments reveal a compelling picture of the optimization dynamics.

Large learning rates are necessary, both for generalization and for learning threshold neurons. Figure 1 shows that the bias decreases and the test accuracy increases as we increase $\eta$; note that we plot the results after a fixed time (iteration $\times\, \eta$), so the observed results are not simply because larger learning rates track the continuous-time gradient flow for a longer time.

Large learning rates lead to unexpected phenomena: non-monotonic loss and wild oscillations of $a^- + a^+$. Figure 2 shows that large learning rates also induce stark phenomena, such as non-monotonic loss and large weight fluctuations, which lie firmly outside the explanatory power of existing analytic techniques based on principles from convex optimization.

There is a phase transition between small and large learning rates. In Figure 1, we zoom in on learning rates around $\eta \approx 0.0006$ and observe sharp phase transition phenomena (for $d = 200$, the predicted threshold $8\pi/d^2$ is approximately $6.3 \times 10^{-4}$).

We have presented these observations in the context of the simple ReLU network (2), but we emphasize that these findings are indicative of behaviors observed in practical neural network training settings. In Figure 3, we display results for a two-layer ReLU network trained on the full sparse coding model (1) with unknown basis, as well as a deep neural network trained on CIFAR-10. In each case, we again observe non-monotonic loss coupled with steadily decreasing bias parameters. For these richer models, the transition from small to large learning rates is oddly reminiscent of well-known separations between the lazy training or NTK regime (Jacot et al., 2018) and the more expressive feature learning regime. For further experimental results, see Appendix B.
[Figure 3: top row, two panels plotted against step (0 to 1400) for lr ∈ {0.03, 0.04, 0.05, 0.07, 0.1, 0.2}; bottom row, two panels plotted against step (0 to 50), one of them showing the median bias, for lr ∈ {0.05, 0.1, 0.2}.]

Figure 3: (Top) Results for training an over-parametrized two-layer neural network $f(x; a, W, b) = \sum_{i=1}^m a_i\, \mathrm{ReLU}(w_i^\top x + b)$ with $m \gg d$ for the full sparse coding model (1); in this setting, the basis vectors are unknown, and the neural network learns them through additional parameters $W = (w_i)_{i=1}^m$. Also, we use $m$ different weights $a = (a_i)_{i=1}^m$ for the second layer. (Bottom) Full-batch gradient descent dynamics of ResNet-18 on (binary) CIFAR-10 with various learning rates. Details are deferred to Appendix B.

We currently do not have the right tools to understand these phenomena. First of all, a drastic change in behavior between the small and the large learning rates cannot be captured through well-studied regimes, such as the neural tangent kernel (NTK) regime (Jacot et al., 2018; Allen-Zhu et al., 2019; Arora et al., 2019; Chizat et al., 2019; Du et al., 2019; Oymak and Soltanolkotabi, 2020) or the mean-field regime (Chizat and Bach, 2018; Mei et al., 2019; Chizat, 2022; Nitanda et al., 2022; Rotskoff and Vanden-Eijnden, 2022). In addition, understanding why a large learning rate is required to learn the bias is beyond the scope of prior theoretical works on the sparse coding model (Arora et al., 2015; Karp et al., 2021). Our inability to explain these findings points to a serious gap in our grasp of neural network training dynamics and calls for a detailed theoretical study.

### 1.1 Main scope of this work

In this work, we do not aim to understand the sparse coding problem (1) in its full generality. Instead, we pursue the more modest goal of shedding light on the following question.

Q. What is the role of a large step size in learning the bias for the ReLU network (2)?
As discussed above, the dynamics of the simple ReLU network (2) is a microcosm of emergent phenomena beyond the convex optimization regime. In fact, there is a growing body of recent work (Cohen et al., 2021; Arora et al., 2022; Ahn et al., 2022; Lyu et al., 2022; Ma et al., 2022; Wang et al., 2022b; Chen and Bruna, 2023; Damian et al., 2023; Zhu et al., 2023) on training with large learning rates, which largely aims at explaining a striking empirical observation called the edge of stability (EoS) phenomenon.

The edge of stability (EoS) phenomenon is a set of distinctive behaviors observed recently by Cohen et al. (2021) when training neural networks with gradient descent (GD). Here we briefly summarize the salient features of the EoS and defer a discussion of prior work to Subsection 1.3. Recall that if we use GD to optimize an $L$-smooth loss function with step size $\eta$, then the well-known descent lemma from convex optimization ensures monotonic decrease in the loss so long as $L < 2/\eta$. In contrast, when $L > 2/\eta$, it is easy to see on simple convex quadratic examples that GD can be unstable (or divergent): for instance, on $f(x) = \frac{L}{2} x^2$ GD reads $x_{t+1} = (1 - \eta L)\, x_t$, so $|x_t|$ grows geometrically as soon as $L > 2/\eta$. The main observation of Cohen et al. (2021) is that when training neural networks¹ with constant step size $\eta > 0$, the largest eigenvalue of the Hessian at the current iterate (dubbed the "sharpness") initially increases during training ("progressive sharpening") and saturates near or above $2/\eta$ ("EoS").

A surprising message of the present work is that the answer to our main question is intimately related to the EoS. Indeed, Figure 4 shows that the GD iterates of our motivating example exhibit the EoS during the initial phase of training when the bias decreases rapidly.

[Figure 4: three panels plotted against iteration (0 to 500).]

Figure 4: Understanding our main question is surprisingly related to the EoS.
Under the same setting as Figure 1, we report the largest eigenvalue of the Hessian ("sharpness"), and observe that GD iterates lie in the EoS during the initial phase of training when there is a fast drop in the bias.

Consequently, we first set out to thoroughly understand the workings of the EoS phenomena through a simple example. Specifically, we consider a single-neuron linear neural network in dimension 1, corresponding to the loss

$$\mathbb{R}^2 \ni (x, y) \mapsto \ell(xy), \quad \text{where } \ell \text{ is convex, even, and Lipschitz}. \tag{3}$$

Although toy models have appeared in works on the EoS (see Subsection 1.3), our example is simpler than all prior models, and we provably establish the EoS for (3) with transparent proofs. We then use the newfound insights gleaned from the analysis of (3) to answer our main question. To the best of our knowledge, we provide the first explanation of the mechanism by which a large learning rate can be necessary for learning threshold neurons.

### 1.2 Our contributions

Figure 5: Illustration of two different regimes (the "gradient flow" regime and the "EoS" regime) of the GD dynamics.

Explaining the EoS with a single-neuron example. Although the EoS has been studied in various settings (see Subsection 1.3 for a discussion), these works either do not rigorously establish the EoS phenomenon, or they operate under complex settings with opaque assumptions. Here, we study a simple two-dimensional loss function, $(x, y) \mapsto \ell(xy)$, where $\ell$ is convex, even, and Lipschitz. Some examples include² $\ell(s) = \frac12 \log(1 + \exp(-s)) + \frac12 \log(1 + \exp(+s))$ and $\ell(s) = \sqrt{1 + s^2}$. Surprisingly, GD on this loss already exhibits rich behavior (Figure 5). En route to this result, we rigorously establish the quasi-static dynamics formulated in Ma et al. (2022).

¹The phenomenon in Cohen et al. (2021) is most clearly observed for tanh activations, although the appendix of Cohen et al. (2021) contains thorough experimental results for various neural network architectures.
²Suppose that we have a single-layer linear neural network $f(x; a, b) = abx$, and that the data is drawn according to $x = 1$, $y \sim \mathrm{unif}(\{\pm 1\})$. Then, the population loss under the logistic loss is $(a, b) \mapsto \ell_{\mathrm{sym}}(ab)$ with $\ell_{\mathrm{sym}}(s) = \frac12 \log(1 + \exp(-s)) + \frac12 \log(1 + \exp(+s))$.

The elementary nature of our example leads to transparent arguments, and consequently our analysis isolates generalizable principles for "bouncing" dynamics. To demonstrate this, we use our insights to study our main question of learning threshold neurons.

Learning threshold neurons with the mean model. The connection between the single-neuron example and the ReLU network (2) can already be anticipated via a comparison of the dynamics: (i) for the single-neuron example, $x$ oscillates wildly while $y$ decreases (Figure 5); (ii) for the ReLU network (2), the sum of weights $a^- + a^+$ oscillates while $b$ decreases (Figure 2).

We study this example in Section 2 and delineate a transition from the "gradient flow" regime to the "EoS" regime, depending on the step size $\eta$ and the initialization. Moreover, in the EoS regime, we rigorously establish asymptotics for the limiting sharpness which depend on the higher-order behavior of $\ell$. In particular, for the two losses mentioned above, the limiting sharpness is $2/\eta + O(\eta)$, whereas for losses $\ell$ which are exactly quadratic near the origin the limiting sharpness is $2/\eta + O(1)$.

Figure 6: Illustration of GD dynamics on the ReLU network (2). The sum of weights $a^- + a^+$ oscillates while $b$ decreases.

In fact, this connection can be made formal by considering an approximation for the GD dynamics for the ReLU network (2).
It turns out (see Subsection 3.1 for details) that during the initial phase of training, the dynamics of $A_t := d\,(a_t^- + a_t^+)$ and $b_t$ due to the ReLU network are well-approximated by the "rescaled" GD dynamics on the loss $(A, b) \mapsto \ell_{\mathrm{sym}}(A\, g(b))$, where the step size for the $A$-dynamics is multiplied by $2d^2$, $g(b) := \mathbb{E}_{z \sim \mathcal{N}(0,1)}\, \mathrm{ReLU}(z + b)$ is the "smoothed" ReLU, and $\ell_{\mathrm{sym}}$ is the symmetrized logistic loss; see Subsection 3.1 and Figure 8. We refer to these dynamics as the mean model. The mean model bears a great resemblance to the single-neuron example $(x, y) \mapsto \ell(xy)$, and hence we can leverage the techniques developed for the latter in order to study the former. Our main result for the mean model precisely explains the phase transition in Figure 1. For any $\delta > 0$:

- if $\eta \le (8 - \delta)\pi/d^2$, then the mean model fails to learn threshold neurons: the limiting bias satisfies $|b_\infty| = O_\delta(1/d^2)$;
- if $\eta \ge (8 + \delta)\pi/d^2$, then the mean model enters the EoS and learns threshold neurons: the limiting bias satisfies $b_\infty \le -\Omega_\delta(1)$.

### 1.3 Related work

Edge of stability. Our work is motivated by the extensive empirical study of Cohen et al. (2021), which identified the EoS phenomenon. Subsequently, there has been a flurry of works aiming at developing a theoretical understanding of the EoS, which we briefly summarize here.

Properties of the loss landscape. The works (Ahn et al., 2022; Ma et al., 2022) study the properties of the loss landscape that lead to the EoS. Namely, Ahn et al. (2022) argue that the existence of forward-invariant subsets near the minimizers allows GD to converge even in the unstable regime. They also explore various characteristics of the EoS in terms of loss and iterates. Also, Ma et al. (2022) empirically show that the loss landscape of neural networks exhibits subquadratic growth locally around the minimizers. They prove that for a one-dimensional loss, subquadratic growth implies that GD finds a 2-periodic trajectory.

Limiting dynamics.
Other works characterize the limiting dynamics of the EoS in various regimes. Arora et al. (2022) and Lyu et al. (2022) show that (normalized) GD tracks a "sharpness reduction flow" near the manifold of minimizers. The recent work of Damian et al. (2023) obtains a different predicted dynamics based on self-stabilization of the GD trajectory. Also, Ma et al. (2022) describe a quasi-static heuristic for the overall trajectory of GD when one component of the iterate is oscillating.

Simple models and beyond. Closely related to our own approach, there are prior works which carefully study simple models. Chen and Bruna (2023) prove global convergence of GD for the two-dimensional function $(x, y) \mapsto (xy - 1)^2$ and a single-neuron student-teacher setting; note that unlike our results, they do not study the limiting sharpness. Wang et al. (2022b) study progressive sharpening for a neural network model. Also, the recent and concurrent work of Zhu et al. (2023) studies the two-dimensional loss $(x, y) \mapsto (x^2 y^2 - 1)^2$; to our knowledge, their work is the first to asymptotically and rigorously show that the limiting sharpness of GD is $2/\eta$ in a simple setting, at least when initialized locally. In comparison, in Section 2, we perform a global analysis of the limiting sharpness of GD for $(x, y) \mapsto \ell(xy)$ for a class of convex, even, and Lipschitz losses $\ell$, and in doing so we clearly delineate the "gradient flow" regime from the "EoS" regime.

Effect of learning rate on learning. Recently, several works have sought to understand how the choice of learning rate affects the learning process, in terms of the properties of the resulting minima (Jastrzebski et al., 2018; Wu et al., 2018; Mulayoff et al., 2021; Nacson et al., 2022) and the behavior of optimization dynamics (Xing et al., 2018; Jastrzebski et al., 2019, 2020; Lewkowycz et al., 2020; Jastrzebski et al., 2021). Li et al.
(2019) demonstrate for a synthetic data distribution and a two-layer ReLU network model that choosing a larger step size for SGD helps with generalization. Subsequent works have shown similar phenomena for regression (Nakkiran, 2020; Wu et al., 2021; Ba et al., 2022), kernel ridge regression (Beugnot et al., 2022), and linear diagonal networks (Nacson et al., 2022). However, the large step sizes considered in these works still fall under the scope of the descent lemma, and most prior works do not theoretically investigate the effect of large step sizes in the EoS regime. A notable exception is the work of Wang et al. (2022a), which studies the impact of learning rates greater than 2/smoothness for a matrix factorization problem. Also, the recent work of Andriushchenko et al. (2023) seeks to explain the generalization benefit of SGD in the large step size regime by relying on a heuristic SDE model for the case of linear diagonal networks. Despite this similarity, their main scope is quite different from ours, as we (i) focus on GD instead of SGD and (ii) establish a direct and detailed analysis of the GD dynamics for a model of the motivating sparse coding example.

## 2 Single-neuron linear network

In this section, we analyze the single-neuron linear network model $(x, y) \mapsto f(x, y) := \ell(xy)$.

### 2.1 Basic properties and assumptions

Basic properties. If $\ell$ is minimized at 0, then the global minimizers of $f$ are the $x$- and $y$-axes. The GD iterates $x_t, y_t$, for step size $\eta > 0$ and iteration $t \ge 0$, can be written as

$$x_{t+1} = x_t - \eta\, \ell'(x_t y_t)\, y_t, \qquad y_{t+1} = y_t - \eta\, \ell'(x_t y_t)\, x_t.$$

Assumptions. From here onward, we assume $\eta < 1$ and the following conditions on $\ell : \mathbb{R} \to \mathbb{R}$.

- (A1) $\ell$ is convex, even, 1-Lipschitz, and of class $C^2$ near the origin with $\ell''(0) = 1$.
- (A2) There exist constants $\beta > 1$ and $c > 0$ with the following property: for all $s \ne 0$, $\ell'(s)/s \le 1 - c\,|s|^\beta\, \mathbb{1}\{|s| \le c\}$. We allow $\beta = +\infty$, in which case we simply require that $\ell'(s)/s \le 1$ for all $s \ne 0$.
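Two elementary facts about $f(x, y) = \ell(xy)$ are used repeatedly below, so we record a short derivation. This is our own sanity check and is not claimed to be the exact form of expression (8) or of Lemma 6 in the appendix; it uses only that $\ell$ is even (hence $\ell'(0) = 0$), $\ell''(0) = 1$, and $|\ell'| \le 1$, all from (A1):

```latex
\nabla^2 f(x, y) =
\begin{pmatrix}
  \ell''(xy)\, y^2 & \ell''(xy)\, xy + \ell'(xy) \\
  \ell''(xy)\, xy + \ell'(xy) & \ell''(xy)\, x^2
\end{pmatrix},
\qquad
\nabla^2 f(0, y) =
\begin{pmatrix} y^2 & 0 \\ 0 & 0 \end{pmatrix},
\quad\text{so}\quad
\lambda_{\max}\big(\nabla^2 f(0, y)\big) = y^2 .

% one GD step, abbreviating g_t := \eta\, \ell'(x_t y_t):
y_{t+1}^2 - x_{t+1}^2
  = (y_t - g_t x_t)^2 - (x_t - g_t y_t)^2
  = \big(1 - \eta^2 \ell'(x_t y_t)^2\big)\big(y_t^2 - x_t^2\big).
```

Since $0 \le \eta^2 \ell'(x_t y_t)^2 \le \eta^2$, every GD step multiplies $y^2 - x^2$ by a factor in $[1 - \eta^2, 1]$. So the gradient flow invariant of Lemma 1 below is preserved only approximately by GD, and the sharpness at a limit point $(0, y_\infty)$ equals $y_\infty^2$.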
Assumption (A2) imposes decay of $s \mapsto \ell'(s)/s$ locally away from the origin in order to obtain more fine-grained results on the limiting sharpness in Theorem 2. As we show in Lemma 5 below, when $\ell$ is smooth and has a strictly negative fourth derivative at the origin, then Assumption (A2) holds with $\beta = 2$. See Example 1 for some simple examples of losses satisfying our assumptions.

### 2.2 Two different regimes for GD depending on the step size

Before stating rigorous results, in this section we begin by giving an intuitive understanding of the GD dynamics. It turns out that for a given initialization $(x_0, y_0)$, there are two different regimes for the GD dynamics depending on the step size $\eta$. Namely, there exists a threshold on the step size such that (i) below the threshold, GD remains close to the gradient flow for all time, and (ii) above the threshold, GD enters the edge of stability and diverges away from the gradient flow. See Figure 9.

First, recall that the GD dynamics are symmetric in $x, y$ and that the lines $y = \pm x$ are invariant. Hence, we may assume without loss of generality that

$$y_0 > x_0 > 0, \qquad y_t > |x_t| \text{ for all } t \ge 1, \qquad \text{and GD converges to } (0, y_\infty) \text{ for some } y_\infty > 0.$$

From the expression (8) for the Hessian of $f$ and our normalization $\ell''(0) = 1$, it follows that the sharpness (the largest eigenvalue of the loss Hessian) reached by GD in this example is precisely $y_\infty^2$.

Initially, in both regimes, the GD dynamics track the continuous-time gradient flow. Our first observation is that the gradient flow admits a conserved quantity, thereby allowing us to predict the dynamics in this initial phase.

Lemma 1 (conserved quantity). Along the gradient flow for $f$, the quantity $y^2 - x^2$ is conserved.

Proof. Differentiating $y_t^2 - x_t^2$ with respect to $t$ gives $2y_t\,\big(-\ell'(x_t y_t)\, x_t\big) - 2x_t\,\big(-\ell'(x_t y_t)\, y_t\big) = 0$.

Lemma 1 implies that the gradient flow converges to $(0, y_\infty^{\mathrm{GF}}) = \big(0, \sqrt{y_0^2 - x_0^2}\big)$.
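The two regimes are easy to reproduce numerically. The sketch below runs GD on $f(x, y) = \ell(xy)$ for two losses that satisfy (A1): $\ell(s) = \sqrt{1 + s^2}$, for which $\ell'(s)/s = (1 + s^2)^{-1/2} \approx 1 - s^2/2$ (so (A2) holds with $\beta = 2$), and a Huber loss normalized so that $\ell''(0) = 1$, namely $\ell(s) = s^2/2$ for $|s| \le 1$ and $|s| - 1/2$ otherwise (so $\ell'(s)/s \le 1$, i.e., $\beta = +\infty$). Following the scaling of Theorems 1 and 2 below, it initializes at $\sqrt{c/\eta}\,(\bar x, \bar y)$ with the arbitrary choice $(\bar x, \bar y) = (0.75, 1.25)$, so that $y_0^2 - x_0^2 = c/\eta$, and compares the final sharpness $y_\infty^2$ with $\min\{y_0^2 - x_0^2, 2/\eta\}$. The step size, iteration budget, and function names are hypothetical choices of ours.

```python
import math


def sqrt_loss_grad(s):
    # l(s) = sqrt(1 + s^2): convex, even, 1-Lipschitz, l''(0) = 1
    return s / math.sqrt(1.0 + s * s)


def huber_loss_grad(s):
    # Huber loss l(s) = s^2/2 for |s| <= 1, |s| - 1/2 otherwise, so l'(s) = clip(s, -1, 1)
    return max(-1.0, min(1.0, s))


def gd_single_neuron(loss_grad, x0, y0, eta, num_steps=200_000):
    """Run GD on f(x, y) = l(x * y) and return the last iterate (x, y)."""
    x, y = x0, y0
    for _ in range(num_steps):
        g = loss_grad(x * y)
        # simultaneous update: both right-hand sides use the old (x, y)
        x, y = x - eta * g * y, y - eta * g * x
    return x, y


def initial_point(c, eta, xbar=0.75, ybar=1.25):
    """(x0, y0) = sqrt(c / eta) * (xbar, ybar); ybar^2 - xbar^2 = 1 gives y0^2 - x0^2 = c / eta."""
    r = math.sqrt(c / eta)
    return r * xbar, r * ybar


eta = 0.01
for name, grad in (("sqrt", sqrt_loss_grad), ("huber", huber_loss_grad)):
    for c in (1.0, 3.0):  # c = 2 - delta (gradient flow regime), c = 2 + delta (EoS), delta = 1
        x, y = gd_single_neuron(grad, *initial_point(c, eta), eta)
        predicted = min(c / eta, 2.0 / eta)
        print(f"{name:5s} y0^2-x0^2 = {c / eta:5.0f}  final sharpness y^2 = {y * y:8.3f}  "
              f"min(GF, 2/eta) = {predicted:5.0f}")
```

In the gradient flow regime both losses end near $y_0^2 - x_0^2$; in the EoS regime both end near $2/\eta$, and the Huber loss should land visibly further below $2/\eta$ than $\sqrt{1 + s^2}$, in line with the $O(1)$ versus $O(\eta)$ gaps discussed after Theorem 2.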
For GD with step size $\eta > 0$, the quantity $y^2 - x^2$ is no longer conserved, but we show in Lemma 6 that it is approximately conserved until the GD iterate lies close to the $y$-axis. Hence, GD initialized at $(x_0, y_0)$ also reaches the $y$-axis approximately at the point $(x_{t_0}, y_{t_0}) \approx \big(0, \sqrt{y_0^2 - x_0^2}\big)$. At this point, GD either approximately converges to the gradient flow solution $\big(0, \sqrt{y_0^2 - x_0^2}\big)$ or diverges away from it, depending on whether or not $y_{t_0}^2 > 2/\eta$. To see this, for $|x_{t_0} y_{t_0}| \ll 1$, we can Taylor expand $\ell'$ near zero to obtain the approximate dynamics for $x$ (recalling $\ell''(0) = 1$):

$$x_{t_0+1} \approx x_{t_0} - \eta\, x_{t_0} y_{t_0}^2 = (1 - \eta y_{t_0}^2)\, x_{t_0}. \tag{4}$$

From (4), we deduce the following conclusions.

(i) If $y_{t_0}^2 < 2/\eta$, then $|1 - \eta y_{t_0}^2| < 1$. Since $y_t$ is decreasing, this implies that $|1 - \eta y_t^2| < 1$ for all $t \ge t_0$, and so $|x_t|$ converges to zero exponentially fast.

(ii) On the other hand, if $y_{t_0}^2 > 2/\eta$, then $|1 - \eta y_{t_0}^2| > 1$, i.e., the magnitude of $x_{t_0}$ increases in the next iteration, and hence GD cannot stabilize. In fact, in the approximate dynamics, $x_{t_0+1}$ has the opposite sign as $x_{t_0}$, i.e., $x_{t_0}$ "jumps" across the $y$-axis. One can show that the "bouncing" of the $x$ variable continues until $y_t^2$ has decreased past $2/\eta$, at which point we are in the previous case and GD approximately converges to $\big(0, \sqrt{2/\eta}\big)$.

This reasoning, combined with the expression for the Hessian of $f$, shows that

$$\mathrm{sharpness}(0, y_\infty) := \lambda_{\max}\big(\nabla^2 f(0, y_\infty)\big) \approx \min\big\{y_0^2 - x_0^2,\ 2/\eta\big\} = \min\{\text{gradient flow sharpness},\ \text{EoS prediction}\}.$$

Accordingly, we refer to the case $y_0^2 - x_0^2 < 2/\eta$ as the gradient flow regime, and the case $y_0^2 - x_0^2 > 2/\eta$ as the EoS regime. See Figure 5 and Figure 9 for illustrations of these two regimes; see also Figure 10 for detailed illustrations of the EoS regime.

In the subsequent sections, we aim to make the above reasoning rigorous. For example, instead of the approximate dynamics (4), we consider the original GD dynamics and justify the Taylor approximation.
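Heuristic (4) can also be checked directly. The sketch below records $y_{t_0}$ at the first iteration with $|x_{t_0} y_{t_0}| \le 1$ (our stand-in for $|x_{t_0} y_{t_0}| \ll 1$), reports the multiplier $1 - \eta y_{t_0}^2$ from (4), and counts how often $x$ changes sign while $y_t^2 > 2/\eta$. It assumes $\ell(s) = \sqrt{1 + s^2}$; the step size, initializations, cutoff, and the helper name `bounce_statistics` are our own illustrative choices.

```python
import math


def bounce_statistics(x0, y0, eta, num_steps=100_000):
    """GD on l(xy) with l(s) = sqrt(1 + s^2).

    Returns (multiplier, flips): multiplier = 1 - eta * y_t0^2 at the first
    iteration t0 with |x y| <= 1 (the factor in heuristic (4)), and flips = the
    number of sign changes of x that occur while eta * y^2 > 2.
    """
    x, y = x0, y0
    multiplier = None
    flips = 0
    for _ in range(num_steps):
        if multiplier is None and abs(x * y) <= 1.0:
            multiplier = 1.0 - eta * y * y
        g = x * y / math.sqrt(1.0 + (x * y) ** 2)  # l'(xy)
        x_new, y_new = x - eta * g * y, y - eta * g * x
        if x_new * x < 0 and eta * y * y > 2.0:
            flips += 1
        x, y = x_new, y_new
    return multiplier, flips


eta = 0.01
for c in (1.0, 3.0):  # y0^2 - x0^2 = c / eta, below and above 2 / eta
    r = math.sqrt(c / eta)
    multiplier, flips = bounce_statistics(0.75 * r, 1.25 * r, eta)
    print(f"y0^2 - x0^2 = {c / eta:4.0f}: 1 - eta*y_t0^2 = {multiplier:+.3f}, "
          f"sign flips while eta*y^2 > 2: {flips}")
```

Below the threshold the multiplier lies in $(-1, 1)$ and $x$ never changes sign; above it the multiplier is below $-1$ and $x$ keeps jumping across the $y$-axis until $y_t^2$ drops to about $2/\eta$, matching cases (i) and (ii).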
Also, in the EoS regime, rather than loosely asserting that $|x_t| \to 0$ exponentially fast and hence the dynamics stabilizes quickly once $y_t^2 < 2/\eta$, we track precisely how long this convergence takes so that we can bound the gap between the limiting sharpness and the prediction $2/\eta$.

### 2.3 Results

Gradient flow regime. Our first rigorous result is that when $y_0^2 - x_0^2 = (2 - \delta)/\eta$ for some constant $\delta \in (0, 2)$, then the limiting sharpness of GD with step size $\eta$ is $y_0^2 - x_0^2 + O(1) = (2 - \delta)/\eta + O(1)$, which is precisely the sharpness attained by the gradient flow up to a controlled error term. In fact, our theorem is slightly more general, as it covers initializations in which $\delta$ can scale mildly with $\eta$. The precise statement is as follows.

Theorem 1 (gradient flow regime; see Subsection C.2). Suppose we run GD with step size $\eta > 0$ on the objective $f$, where $f(x, y) := \ell(xy)$ and $\ell$ satisfies Assumptions (A1) and (A2). Let $(\bar x, \bar y) \in \mathbb{R}^2$ satisfy $\bar y > \bar x > 0$ with $\bar y^2 - \bar x^2 = 1$. Suppose we initialize GD at $(x_0, y_0) := \big(\frac{2 - \delta}{\eta}\big)^{1/2} (\bar x, \bar y)$, where $\delta \in (0, 2)$ and $\eta \lesssim \delta^{1/2}(2 - \delta)$. Then, GD converges to $(0, y_\infty)$ satisfying

$$\frac{2 - \delta}{\eta} - O(2 - \delta) - O\Big(\frac{\eta}{\min\{\delta,\, 2 - \delta\}}\Big) \le \lambda_{\max}\big(\nabla^2 f(0, y_\infty)\big) \le \frac{2 - \delta}{\eta} + O\Big(\frac{\eta}{2 - \delta}\Big),$$

where the implied constants depend on $\bar x$, $\bar y$, and $\ell$, but not on $\delta$, $\eta$.

The proof of Theorem 1 is based on a two-stage analysis. In the first stage, we use Lemma 6 on the approximate conservation of $y^2 - x^2$ along GD in order to show that GD lands near the $y$-axis with $y_{t_0}^2 \approx (2 - \delta)/\eta$. In the second stage, we use the assumptions on $\ell$ in order to control the rate of convergence of $|x_t|$ to 0, which is subsequently used to control the final deviation of $y_\infty^2$ from $(2 - \delta)/\eta$.

EoS regime. Our next result states that when $y_0^2 - x_0^2 > 2/\eta$, then the limiting sharpness of GD is close to the EoS prediction of $2/\eta$, up to an error term which depends on the exponent $\beta$ in (A2).

Theorem 2 (EoS; see Subsection C.4). Suppose we run GD on $f$ with step size $\eta > 0$, where $f(x, y) := \ell(xy)$ and $\ell$ satisfies (A1) and (A2). Let $(\bar x, \bar y) \in \mathbb{R}^2$ satisfy $\bar y > \bar x > 0$ with $\bar y^2 - \bar x^2 = 1$.
Suppose we initialize GD at $(x_0, y_0) := \sqrt{(2 + \delta)/\eta}\, (\bar x, \bar y)$, where $\delta > 0$ is a constant. Also, assume that for all $t \ge 1$ such that $y_t^2 > 2/\eta$, we have $x_t \ne 0$. Then, GD converges to $(0, y_\infty)$ satisfying

$$\frac{2}{\eta} - O\big(\eta^{1/(\beta - 1)}\big) \le \lambda_{\max}\big(\nabla^2 f(0, y_\infty)\big) \le \frac{2}{\eta},$$

where the implied constants depend on $\bar x$, $\bar y$, $\delta^{-1}$, and $\ell$, but not on $\eta$.

Remarks on the assumptions. The initialization in our results is such that both $y_0$ and $y_0 - x_0$ are on the same scale, i.e., $y_0,\, y_0 - x_0 = \Theta(1/\sqrt{\eta})$. This rules out extreme initializations such as $y_0 \approx x_0$, which are problematic because they lie too close to the invariant line $y = x$. Since our aim in this work is not to explore every edge case, we focus on this setting for simplicity.

Moreover, we imposed the assumption that the iterates of GD do not exactly hit the $y$-axis before crossing $y^2 = 2/\eta$. This is necessary because if $x_t = 0$ for some iteration $t$, then $(x_{t'}, y_{t'}) = (x_t, y_t)$ for all $t' > t$, and hence the limiting sharpness may not be close to $2/\eta$. This assumption holds generically, e.g., if we perturb each iterate of GD with a vanishing amount of noise from a continuous distribution, and we conjecture that for any $\eta > 0$, the assumption holds for all but a measure-zero set of initializations.

When $\beta = +\infty$, which is the case for the Huber loss in Example 1, the limiting sharpness is $2/\eta + O(1)$. When $\beta = 2$, which is the case for the logistic and square root losses in Example 1, the limiting sharpness is $2/\eta + O(\eta)$. Numerical experiments show that our error bound of $O(\eta^{1/(\beta - 1)})$ is sharp; see Figure 11 below.

We make a few remarks about the proof. As we outline in Subsection C.3, it turns out that in order to bound the gap $2/\eta - y_\infty^2$, the proof requires control of the size of $|x_t y_t|$, where $t$ is the first iteration such that $y_t^2$ crosses $2/\eta$. However, controlling the size of $|x_t y_t|$ is surprisingly delicate as it requires a fine-grained understanding of the bouncing phase.
The insight that guides the proof is the observation that during the bouncing phase, the GD iterates lie close to a certain envelope (Figure 9). As a by-product of our analysis, we obtain a rigorous version of the quasi-static principle, from which we can more accurately track the sharpness gap and convergence rate (see Subsection C.5). The results of Theorem 1, Theorem 2, and Theorem 5 are displayed pictorially in Figure 9.

## 3 Understanding the bias evolution of the ReLU network

In this section, we use the insights from Section 2 to answer our main question, namely understanding the role of a large step size in learning threshold neurons for the ReLU network (2). Based on the observed dynamics (Figure 2), we can make our question more concrete as follows.

Q. What is the role of a large step size during the initial phase of training in which (i) the bias $b$ rapidly decreases and (ii) the sum of weights $a^- + a^+$ oscillates?

### 3.1 Approximating the initial phase of GD with the mean model

Deferring details to Appendix D, the GD dynamics for the ReLU network (2) in the initial phase are well-approximated by GD dynamics on

$$(a^-, a^+, b) \mapsto \ell_{\mathrm{sym}}\big(d\,(a^- + a^+)\, g(b)\big),$$

where $\ell_{\mathrm{sym}}(s) := \frac12\big(\log(1 + \exp(-s)) + \log(1 + \exp(+s))\big)$ and $g(b) := \mathbb{E}_{z \sim \mathcal{N}(0,1)}\, \mathrm{ReLU}(z + b)$ is the smoothed ReLU. The GD dynamics can be compactly written in terms of the parameter $A_t := d\,(a_t^- + a_t^+)$:

$$A_{t+1} = A_t - 2d^2\eta\, \ell_{\mathrm{sym}}'\big(A_t g(b_t)\big)\, g(b_t), \qquad b_{t+1} = b_t - \eta\, \ell_{\mathrm{sym}}'\big(A_t g(b_t)\big)\, A_t\, g'(b_t). \tag{5}$$

Figure 7: The smoothed ReLU $g(b)$.

We call these dynamics the mean model. Figure 8 shows that the mean model closely captures the GD dynamics for the ReLU network (2), and we henceforth focus on analyzing the mean model. The main advantage of the representation (5) is that it makes apparent the connection to the single-neuron example that we studied in Section 2. More specifically, (5) can be interpreted as the rescaled GD dynamics on the objective $(A, b) \mapsto \ell_{\mathrm{sym}}(A\, g(b))$, where the step size for the $A$-dynamics is multiplied by $2d^2$.
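The mean model (5) is a two-variable recursion and is cheap to simulate. The sketch below uses $\ell_{\mathrm{sym}}'(s) = \tfrac12 \tanh(s/2)$, which follows from differentiating $\ell_{\mathrm{sym}}$ (in particular $\ell_{\mathrm{sym}}''(0) = 1/4$), together with the closed forms $g(b) = \varphi(b) + b\,\Phi(b)$ and $g' = \Phi$ from Lemma 2 in Subsection 3.2. The choices $d = 200$ (as in Figure 1), $A_0 = 1$, $b_0 = 0$, the step sizes $6\pi/d^2$ and $10\pi/d^2$ on either side of $8\pi/d^2$, the iteration count, and the function names are our own.

```python
import math


def l_sym_prime(s):
    # d/ds [ (log(1 + e^{-s}) + log(1 + e^{s})) / 2 ] = tanh(s / 2) / 2
    return 0.5 * math.tanh(0.5 * s)


def g(b):
    # smoothed ReLU E[ReLU(z + b)], z ~ N(0, 1), in closed form phi(b) + b * Phi(b)
    phi = math.exp(-0.5 * b * b) / math.sqrt(2.0 * math.pi)
    Phi = 0.5 * (1.0 + math.erf(b / math.sqrt(2.0)))
    return phi + b * Phi


def g_prime(b):
    # g' = Phi, the standard Gaussian CDF
    return 0.5 * (1.0 + math.erf(b / math.sqrt(2.0)))


def mean_model(A0, d, eta, num_steps=20_000):
    """Iterate the mean model (5) from (A0, b0 = 0) and return the final (A, b)."""
    A, b = A0, 0.0
    for _ in range(num_steps):
        common = l_sym_prime(A * g(b))
        # simultaneous update: both right-hand sides use the old (A, b)
        A, b = A - 2 * d**2 * eta * common * g(b), b - eta * common * A * g_prime(b)
    return A, b


d = 200
for factor in (6.0, 10.0):  # eta = factor * pi / d^2, below / above 8 * pi / d^2
    eta = factor * math.pi / d**2
    A, b = mean_model(A0=1.0, d=d, eta=eta)
    print(f"eta = {factor:.0f} pi / d^2: final A = {A:+.2e}, final b = {b:+.4f}")
```

Below the threshold, $A$ contracts immediately and $b$ barely moves; above it, $A$ bounces between roughly opposite values while $b$ drifts down until $\frac12 d^2 g(b)^2$ falls to about $2/\eta$.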
Due to this resemblance, we can apply the techniques from Section 2.

[Figure 8: two panels plotted against iteration, over 0 to 50 and 0 to 500.]

Figure 8: Under the same setting as Figure 1, we compare the mean model with the GD dynamics of the ReLU network. The mean model is plotted with a black dashed line. Note that the mean model tracks the GD dynamics quite well during the initial phase of training.

### 3.2 Two different regimes for the mean model

Throughout the section, we use the shorthand $\ell := \ell_{\mathrm{sym}}$, and focus on initializing with $a_0^\pm = \Theta(1/d)$, $a_0^- + a_0^+ \ne 0$, and $b_0 = 0$. This implies $A_0 = \Theta(1)$. We also note the following fact for later use.

Lemma 2 (formula for the smoothed ReLU; see Subsection E.1). The smoothed ReLU function $g$ can be expressed in terms of the PDF $\varphi$ and the CDF $\Phi$ of the standard Gaussian distribution as $g(b) = \varphi(b) + b\,\Phi(b)$. In particular, $g' = \Phi$.

Note also that $b_t$ is monotonically decreasing. This is because $\ell'(A_t g(b_t))\, A_t\, g'(b_t) \ge 0$, since $\ell'$ is an odd function that is non-decreasing ($\ell$ is convex), so $\ell'(s)$ has the sign of $s$, and $g(b), g'(b) > 0$ for any $b \in \mathbb{R}$.

Following Subsection 2.2, we begin with the continuous-time dynamics of the mean model:

$$\dot A = -2d^2\, \ell'\big(A\, g(b)\big)\, g(b), \qquad \dot b = -\ell'\big(A\, g(b)\big)\, A\, g'(b). \tag{6}$$

Lemma 3 (conserved quantity; see Subsection E.1). Let $\kappa : \mathbb{R} \to \mathbb{R}$ be defined as $\kappa(b) := \int_0^b g/g'$. Along the gradient flow (6), the quantity $\frac12 A^2 - 2d^2 \kappa(b)$ is conserved.

Based on Lemma 3, if we initialize the continuous-time dynamics (6) at $(A_0, 0)$ and if $A_t \to 0$, then the limiting value of the bias $b_\infty^{\mathrm{GF}}$ satisfies $\kappa(b_\infty^{\mathrm{GF}}) = -\frac{1}{4d^2} A_0^2$, which implies that $b_\infty^{\mathrm{GF}} = -\Theta(1/d^2)$; indeed, this holds since $\kappa'(0) = g(0)/g'(0) > 0$, so there exist constants $c_0, c_1 > 0$ such that $c_0 |b| \le |\kappa(b)| \le c_1 |b|$ for all $-1 \le b \le 0$. Since the mean model (5) tracks the continuous-time dynamics (6) until it reaches the $b$-axis, the mean model initialized at $(A_0, 0)$ also approximately reaches $(A_{t_0}, b_{t_0}) \approx (0, -\Theta(1/d^2)) \approx (0, 0)$ in high dimension $d \gg 1$. In other words, the continuous-time dynamics (6) fails to learn threshold neurons.
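The closed form of Lemma 2, the consequence of Lemma 3, and the two numbers used in the next subsection (the threshold $8\pi/d^2$ and the bound $b_\infty^{\mathrm{MM}} < -0.087$ at $\eta = 10\pi/d^2$) can all be checked numerically. The sketch below compares $\varphi(b) + b\,\Phi(b)$ with a trapezoidal-rule evaluation of $\mathbb{E}\,\mathrm{ReLU}(z + b)$, computes $\kappa(b) = \int_0^b g/g'$ by quadrature, and solves $\kappa(b_\infty^{\mathrm{GF}}) = -A_0^2/(4d^2)$ and $g(b) = 2/\sqrt{\eta d^2}$ by bisection. The quadrature grids, bisection brackets, $d = 200$, $A_0 = 1$, and helper names are our own choices, not taken from the paper.

```python
import math


def phi(z):
    return math.exp(-0.5 * z * z) / math.sqrt(2.0 * math.pi)


def Phi(z):
    return 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))


def g(b):
    # Lemma 2: E[ReLU(z + b)] = phi(b) + b * Phi(b) for z ~ N(0, 1)
    return phi(b) + b * Phi(b)


def g_by_quadrature(b, lo=-12.0, hi=12.0, n=20_000):
    # trapezoidal rule for the integral of max(z + b, 0) * phi(z) over [lo, hi]
    h = (hi - lo) / n
    total = 0.0
    for k in range(n + 1):
        z = lo + k * h
        weight = 0.5 if k in (0, n) else 1.0
        total += weight * max(z + b, 0.0) * phi(z)
    return total * h


def kappa(b, n=200):
    # kappa(b) = integral from 0 to b of g / g', with g' = Phi (trapezoidal rule)
    h = b / n
    total = 0.5 * (g(0.0) / Phi(0.0) + g(b) / Phi(b))
    for k in range(1, n):
        total += g(k * h) / Phi(k * h)
    return total * h


def bisect(fun, lo, hi, iters=100):
    # root of an increasing function fun on [lo, hi] with fun(lo) < 0 < fun(hi)
    for _ in range(iters):
        mid = 0.5 * (lo + hi)
        if fun(mid) < 0.0:
            lo = mid
        else:
            hi = mid
    return 0.5 * (lo + hi)


d, A0 = 200, 1.0
# Lemma 3 with A_t -> 0: kappa(b_GF) = -A0^2 / (4 d^2); kappa is increasing since g, g' > 0
b_gf = bisect(lambda b: kappa(b) + A0**2 / (4 * d**2), -1.0, 0.0)
eta_threshold = 8 * math.pi / d**2
eta = 10 * math.pi / d**2
# g^{-1}(2 / sqrt(eta d^2)); g is increasing because g' = Phi > 0
b_bound = bisect(lambda b: g(b) - 2.0 / math.sqrt(eta * d**2), -5.0, 5.0)
print(f"b_GF = {b_gf:.3e}, 8 pi / d^2 = {eta_threshold:.3e}, g^-1 bound = {b_bound:.4f}")
```

For $d = 200$ the gradient flow bias is of order $10^{-5}$, while the EoS bound is about $-0.0873$, independent of $d$.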
Once the mean model reaches the $b$-axis, we again identify two different regimes depending on the step size. A Taylor expansion of $\ell'$ around the origin yields the following approximate dynamics (here $\ell''(0) = 1/4$):
$$A_{t_0+1} \approx A_{t_0} - \frac{\eta d^2}{2}\,A_{t_0}\,g(b_{t_0})^2 = A_{t_0}\Big(1 - \frac{\eta d^2}{2}\,g(b_{t_0})^2\Big)\,.$$
We conclude that the condition which now dictates whether we have bouncing or convergence is $\frac12 d^2 g(b_{t_0})^2 > 2/\eta$.

(i) Gradient flow regime: If $2/\eta > d^2 g(0)^2/2 = d^2/(4\pi)$ (since $g(0)^2 = 1/(2\pi)$), i.e., the step size $\eta$ is below the threshold $8\pi/d^2$, then the final bias of the mean model $b_\infty^{\mathrm{MM}}$ satisfies $b_\infty^{\mathrm{MM}} \approx b_\infty^{\mathrm{GF}} \approx 0$. In other words, the mean model fails to learn threshold neurons.

(ii) EoS regime: If $2/\eta < d^2/(4\pi)$, i.e., the step size $\eta$ is above the threshold $8\pi/d^2$, then $\frac12 d^2 g(b_\infty^{\mathrm{MM}})^2 < 2/\eta$, i.e., $b_\infty^{\mathrm{MM}} < g^{-1}\big(2/\sqrt{\eta d^2}\big)$. For instance, if $\eta = 10\pi/d^2$, then $b_\infty^{\mathrm{MM}} < -0.087$. In other words, the mean model successfully learns threshold neurons.

3.3 Results for the mean model

Theorem 3 (mean model, gradient flow regime; see Appendix E). Consider the mean model (5) initialized at $(A_0, 0)$, with step size $\eta = (8 - \delta)\,\pi/d^2$ for some $\delta > 0$. Let $\gamma := \frac{1}{200}\min\{\delta,\ 8 - \delta,\ (8 - \delta)/|A_0|\}$. Then, as long as $\eta \le \gamma/|A_0|$, the limiting bias $b_\infty^{\mathrm{MM}}$ satisfies
$$0 \ge b_\infty^{\mathrm{MM}} \ge -(\eta/\gamma)\,|A_0| = -O_{A_0,\delta}(1/d^2)\,.$$
In other words, the mean model fails to learn threshold neurons.

Theorem 4 (mean model, EoS regime; see Appendix E). Consider the mean model (5) initialized at $(A_0, 0)$, with step size $\eta = (8 + \delta)\,\pi/d^2$ for some $\delta > 0$. Furthermore, assume that for all $t \ge 1$ such that $\frac12 d^2 g(b_t)^2 > 2/\eta$, we have $A_t \neq 0$. Then, the limiting bias $b_\infty^{\mathrm{MM}}$ satisfies
$$b_\infty^{\mathrm{MM}} \le g^{-1}\Big(\frac{2}{\sqrt{(8 + \delta)\,\pi}}\Big) = -\Omega_\delta(1)\,.$$
For instance, if $\eta = 10\pi/d^2$, then $b_\infty^{\mathrm{MM}} < -0.087$. In other words, the mean model successfully learns threshold neurons.

4 Conclusion

In this paper, we present the first explanation for the emergence of threshold neurons (i.e., ReLU neurons with negative bias) in models such as the sparse coding model (1) through a novel connection with the edge of stability (EoS) phenomenon.
Along the way, we obtain a detailed and rigorous understanding of the dynamics of GD in the EoS regime for a simple class of loss functions, thereby shedding light on the impact of large learning rates in non-convex optimization.

Our approach is largely inspired by the recent paradigm of physics-style approaches to understanding deep learning based on simplified models and controlled experiments (see, e.g., Zhang et al., 2022; von Oswald et al., 2023; Abernethy et al., 2023; Allen-Zhu and Li, 2023; Li et al., 2023; Ahn et al., 2023a,b). We found such a physics-style approach quite effective for understanding deep learning, especially given the complexity of modern deep neural networks. We hope that our work inspires further research on understanding the working mechanisms of deep learning. Many interesting questions remain, and we conclude with some directions for future research.

Extending the analysis of EoS to richer models. Although the analysis we present in this work is restricted to simple models, the underlying principles can potentially be applied to more general settings. In this direction, it would be interesting to study models which capture the impact of the depth of the neural network on the EoS phenomenon. Notably, a follow-up work by Song and Yun (2023) uses bifurcation theory to extend our results to more complex models.

The interplay between the EoS and the choice of optimization algorithm. As discussed in Subsection 2.3, the bouncing phase of the EoS substantially slows down the convergence of GD (see Figure 11). Investigating how different optimization algorithms (e.g., SGD, or GD with momentum) interact with the EoS phenomenon could potentially lead to practical speed-ups or improved generalization. Notably, a follow-up work by Dai et al. (2023) studies the working mechanisms of a popular modern optimization technique called sharpness-aware minimization (Foret et al., 2021) based on our sparse coding problem.
An end-to-end analysis of the sparse coding model. Finally, we have left open the motivating question of analyzing how two-layer ReLU networks learn to solve the sparse coding model (1). Despite the apparent simplicity of the problem, its analysis has thus far remained out of reach, and we believe that a resolution to this question would constitute compelling and substantial progress towards understanding neural network learning. We are hopeful that the insights in this paper provide the first step towards this goal.

Acknowledgments

We thank Ronan Eldan, Suriya Gunasekar, Yuanzhi Li, Jonathan Niles-Weed, and Adil Salim for initial discussions on this project. KA was supported by the ONR grant (N00014-20-1-2394) and MIT-IBM Watson as well as a Vannevar Bush fellowship from the Office of the Secretary of Defense. SC was supported by the NSF TRIPODS program (award DMS-2022448).

References

Jacob Abernethy, Alekh Agarwal, Teodor V. Marinov, and Manfred K. Warmuth. A mechanism for sample-efficient in-context learning for sparse retrieval tasks. arXiv preprint arXiv:2305.17040, 2023.

Kwangjun Ahn, Jingzhao Zhang, and Suvrit Sra. Understanding the unstable convergence of gradient descent. In Kamalika Chaudhuri, Stefanie Jegelka, Le Song, Csaba Szepesvari, Gang Niu, and Sivan Sabato, editors, Proceedings of the 39th International Conference on Machine Learning, volume 162 of Proceedings of Machine Learning Research, pages 247-257. PMLR, July 2022.

Kwangjun Ahn, Xiang Cheng, Hadi Daneshmand, and Suvrit Sra. Transformers learn to implement preconditioned gradient descent for in-context learning. NeurIPS 2023 (arXiv:2306.00297), 2023a.

Kwangjun Ahn, Xiang Cheng, Minhak Song, Chulhee Yun, Ali Jadbabaie, and Suvrit Sra. Linear attention is (maybe) all you need (to understand transformer optimization). arXiv:2310.01082, 2023b.

Zeyuan Allen-Zhu and Yuanzhi Li. Feature purification: how adversarial training performs robust deep learning. In 2021 IEEE 62nd Annual Symposium on Foundations of Computer Science (FOCS), pages 977-988. IEEE, 2022.

Zeyuan Allen-Zhu and Yuanzhi Li. Physics of language models: part 1, context-free grammar. arXiv preprint arXiv:2305.13673, 2023.

Zeyuan Allen-Zhu, Yuanzhi Li, and Zhao Song. A convergence theory for deep learning via over-parameterization. In Kamalika Chaudhuri and Ruslan Salakhutdinov, editors, Proceedings of the 36th International Conference on Machine Learning, volume 97 of Proceedings of Machine Learning Research, pages 242-252. PMLR, June 2019.

Maksym Andriushchenko, Aditya V. Varre, Loucas Pillaud-Vivien, and Nicolas Flammarion. SGD with large step sizes learns sparse features. In Andreas Krause, Emma Brunskill, Kyunghyun Cho, Barbara Engelhardt, Sivan Sabato, and Jonathan Scarlett, editors, Proceedings of the 40th International Conference on Machine Learning, volume 202 of Proceedings of Machine Learning Research, pages 903-925. PMLR, July 2023.

Sanjeev Arora, Rong Ge, Tengyu Ma, and Ankur Moitra. Simple, efficient, and neural algorithms for sparse coding. In Peter Grünwald, Elad Hazan, and Satyen Kale, editors, Proceedings of the 28th Conference on Learning Theory, volume 40 of Proceedings of Machine Learning Research, pages 113-149, Paris, France, July 2015. PMLR.

Sanjeev Arora, Simon Du, Wei Hu, Zhiyuan Li, and Ruosong Wang. Fine-grained analysis of optimization and generalization for overparameterized two-layer neural networks. In Kamalika Chaudhuri and Ruslan Salakhutdinov, editors, Proceedings of the 36th International Conference on Machine Learning, volume 97 of Proceedings of Machine Learning Research, pages 322-332. PMLR, June 2019.

Sanjeev Arora, Zhiyuan Li, and Abhishek Panigrahi. Understanding gradient descent on the edge of stability in deep learning. In Kamalika Chaudhuri, Stefanie Jegelka, Le Song, Csaba Szepesvari, Gang Niu, and Sivan Sabato, editors, Proceedings of the 39th International Conference on Machine Learning, volume 162 of Proceedings of Machine Learning Research, pages 948-1024. PMLR, July 2022.

Jimmy Ba, Murat A. Erdogdu, Taiji Suzuki, Zhichao Wang, Denny Wu, and Greg Yang. High-dimensional asymptotics of feature learning: how one gradient step improves the representation. In Alice H. Oh, Alekh Agarwal, Danielle Belgrave, and Kyunghyun Cho, editors, Advances in Neural Information Processing Systems, 2022.

Gaspard Beugnot, Julien Mairal, and Alessandro Rudi. On the benefits of large learning rates for kernel methods. In Po-Ling Loh and Maxim Raginsky, editors, Proceedings of Thirty Fifth Conference on Learning Theory, volume 178 of Proceedings of Machine Learning Research, pages 254-282. PMLR, July 2022.

Lei Chen and Joan Bruna. Beyond the edge of stability via two-step gradient updates. In Andreas Krause, Emma Brunskill, Kyunghyun Cho, Barbara Engelhardt, Sivan Sabato, and Jonathan Scarlett, editors, Proceedings of the 40th International Conference on Machine Learning, volume 202 of Proceedings of Machine Learning Research, pages 4330-4391. PMLR, July 2023.

Lénaïc Chizat. Mean-field Langevin dynamics: exponential convergence and annealing. Transactions on Machine Learning Research, 2022.

Lénaïc Chizat and Francis Bach. On the global convergence of gradient descent for over-parameterized models using optimal transport. In S. Bengio, H. Wallach, H. Larochelle, K. Grauman, N. Cesa-Bianchi, and R. Garnett, editors, Advances in Neural Information Processing Systems, volume 31. Curran Associates, Inc., 2018.

Lénaïc Chizat, Edouard Oyallon, and Francis Bach. On lazy training in differentiable programming. In H. Wallach, H. Larochelle, A. Beygelzimer, F. d'Alché-Buc, E. Fox, and R. Garnett, editors, Advances in Neural Information Processing Systems, volume 32. Curran Associates, Inc., 2019.

Jeremy Cohen, Simran Kaur, Yuanzhi Li, J. Zico Kolter, and Ameet Talwalkar. Gradient descent on neural networks typically occurs at the edge of stability. In International Conference on Learning Representations, 2021.

Yan Dai, Kwangjun Ahn, and Suvrit Sra. The crucial role of normalization in sharpness-aware minimization. NeurIPS 2023 (arXiv:2305.15287), 2023.

Alex Damian, Eshaan Nichani, and Jason D. Lee. Self-stabilization: the implicit bias of gradient descent at the edge of stability. In The Eleventh International Conference on Learning Representations, 2023.

Simon S. Du, Xiyu Zhai, Barnabas Poczos, and Aarti Singh. Gradient descent provably optimizes over-parameterized neural networks. In International Conference on Learning Representations, 2019.

Pierre Foret, Ariel Kleiner, Hossein Mobahi, and Behnam Neyshabur. Sharpness-aware minimization for efficiently improving generalization. In International Conference on Learning Representations, 2021.

Arthur Jacot, Franck Gabriel, and Clément Hongler. Neural tangent kernel: convergence and generalization in neural networks. In Proceedings of the 32nd Advances in Neural Information Processing Systems, pages 8580-8589, 2018.

Stanislaw Jastrzebski, Zachary Kenton, Devansh Arpit, Nicolas Ballas, Asja Fischer, Yoshua Bengio, and Amos Storkey. Width of minima reached by stochastic gradient descent is influenced by learning rate to batch size ratio. In Věra Kůrková, Yannis Manolopoulos, Barbara Hammer, Lazaros Iliadis, and Ilias Maglogiannis, editors, Artificial Neural Networks and Machine Learning - ICANN 2018, pages 392-402, Cham, 2018. Springer International Publishing.

Stanislaw Jastrzebski, Maciej Szymczak, Stanislav Fort, Devansh Arpit, Jacek Tabor, Kyunghyun Cho, and Krzysztof Geras. The break-even point on optimization trajectories of deep neural networks. In International Conference on Learning Representations, 2020.

Stanislaw Jastrzebski, Devansh Arpit, Oliver Astrand, Giancarlo B. Kerg, Huan Wang, Caiming Xiong, Richard Socher, Kyunghyun Cho, and Krzysztof J. Geras. Catastrophic Fisher explosion: early phase Fisher matrix impacts generalization. In International Conference on Machine Learning, pages 4772-4784. PMLR, 2021.

Stanisław Jastrzebski, Zachary Kenton, Nicolas Ballas, Asja Fischer, Yoshua Bengio, and Amos Storkey. On the relation between the sharpest directions of DNN loss and the SGD step length. In International Conference on Learning Representations, 2019.

Stefani Karp, Ezra Winston, Yuanzhi Li, and Aarti Singh. Local signal adaptivity: provable feature learning in neural networks beyond kernels. In M. Ranzato, A. Beygelzimer, Y. Dauphin, P.S. Liang, and J. Wortman Vaughan, editors, Advances in Neural Information Processing Systems, volume 34, pages 24883-24897. Curran Associates, Inc., 2021.

Frederic Koehler and Andrej Risteski. The comparative power of ReLU networks and polynomial kernels in the presence of sparse latent structure. In International Conference on Learning Representations, 2018.

Aitor Lewkowycz, Yasaman Bahri, Ethan Dyer, Jascha Sohl-Dickstein, and Guy Gur-Ari. The large learning rate phase of deep learning: the catapult mechanism. arXiv preprint arXiv:2003.02218, 2020.

Yuanzhi Li, Colin Wei, and Tengyu Ma. Towards explaining the regularization effect of initial large learning rate in training neural networks. Advances in Neural Information Processing Systems, 32, 2019.

Yuchen Li, Yuanzhi Li, and Andrej Risteski. How do transformers learn topic structure: towards a mechanistic understanding. In Andreas Krause, Emma Brunskill, Kyunghyun Cho, Barbara Engelhardt, Sivan Sabato, and Jonathan Scarlett, editors, Proceedings of the 40th International Conference on Machine Learning, volume 202 of Proceedings of Machine Learning Research, pages 19689-19729. PMLR, July 2023.

Kaifeng Lyu, Zhiyuan Li, and Sanjeev Arora. Understanding the generalization benefit of normalization layers: sharpness reduction. In S. Koyejo, S. Mohamed, A. Agarwal, D. Belgrave, K. Cho, and A. Oh, editors, Advances in Neural Information Processing Systems, volume 35, pages 34689-34708. Curran Associates, Inc., 2022.

Chao Ma, Daniel Kunin, Lei Wu, and Lexing Ying. Beyond the quadratic approximation: the multiscale structure of neural network loss landscapes. Journal of Machine Learning, 1(3):247-267, 2022.

Song Mei, Theodor Misiakiewicz, and Andrea Montanari. Mean-field theory of two-layers neural networks: dimension-free bounds and kernel limit. In Alina Beygelzimer and Daniel Hsu, editors, Proceedings of the Thirty-Second Conference on Learning Theory, volume 99 of Proceedings of Machine Learning Research, pages 2388-2464. PMLR, June 2019.

Rotem Mulayoff, Tomer Michaeli, and Daniel Soudry. The implicit bias of minima stability: a view from function space. Advances in Neural Information Processing Systems, 34:17749-17761, 2021.

Mor Shpigel Nacson, Kavya Ravichandran, Nathan Srebro, and Daniel Soudry. Implicit bias of the step size in linear diagonal neural networks. In International Conference on Machine Learning, pages 16270-16295. PMLR, 2022.

Preetum Nakkiran. Learning rate annealing can provably help generalization, even for convex problems. arXiv preprint arXiv:2005.07360, 2020.

Atsushi Nitanda, Denny Wu, and Taiji Suzuki. Convex analysis of the mean field Langevin dynamics. In Gustau Camps-Valls, Francisco J. R. Ruiz, and Isabel Valera, editors, Proceedings of the 25th International Conference on Artificial Intelligence and Statistics, volume 151 of Proceedings of Machine Learning Research, pages 9741-9757. PMLR, March 2022.

Bruno A. Olshausen and David J. Field. Sparse coding with an overcomplete basis set: a strategy employed by V1? Vision Research, 37(23):3311-3325, 1997.

Bruno A. Olshausen and David J. Field. Sparse coding of sensory inputs. Current Opinion in Neurobiology, 14(4):481-487, 2004.

Samet Oymak and Mahdi Soltanolkotabi. Toward moderate overparameterization: global convergence guarantees for training shallow neural networks. IEEE Journal on Selected Areas in Information Theory, 1(1):84-105, 2020.

Grant M. Rotskoff and Eric Vanden-Eijnden. Trainability and accuracy of artificial neural networks: an interacting particle system approach. Comm. Pure Appl. Math., 75(9):1889-1935, 2022.

Minhak Song and Chulhee Yun. Trajectory alignment: understanding the edge of stability phenomenon via bifurcation theory. NeurIPS 2023 (arXiv:2307.04204), 2023.

William E. Vinje and Jack L. Gallant. Sparse coding and decorrelation in primary visual cortex during natural vision. Science, 287(5456):1273-1276, 2000.

Johannes von Oswald, Eyvind Niklasson, Ettore Randazzo, João Sacramento, Alexander Mordvintsev, Andrey Zhmoginov, and Max Vladymyrov. Transformers learn in-context by gradient descent. In International Conference on Machine Learning, pages 35151-35174. PMLR, 2023.

Yuqing Wang, Minshuo Chen, Tuo Zhao, and Molei Tao. Large learning rate tames homogeneity: convergence and balancing effect. In International Conference on Learning Representations, 2022a.

Zixuan Wang, Zhouzi Li, and Jian Li. Analyzing sharpness along GD trajectory: progressive sharpening and edge of stability. In S. Koyejo, S. Mohamed, A. Agarwal, D. Belgrave, K. Cho, and A. Oh, editors, Advances in Neural Information Processing Systems, volume 35, pages 9983-9994. Curran Associates, Inc., 2022b.

Jingfeng Wu, Difan Zou, Vladimir Braverman, and Quanquan Gu. Direction matters: on the implicit bias of stochastic gradient descent with moderate learning rate. In International Conference on Learning Representations, 2021.

Lei Wu, Chao Ma, and Weinan E. How SGD selects the global minima in over-parameterized learning: a dynamical stability perspective. Advances in Neural Information Processing Systems, 31:8279-8288, 2018.

Chen Xing, Devansh Arpit, Christos Tsirigotis, and Yoshua Bengio. A walk with SGD. arXiv preprint arXiv:1802.08770, 2018.

Jianchao Yang, Kai Yu, Yihong Gong, and Thomas Huang. Linear spatial pyramid matching using sparse coding for image classification. In 2009 IEEE Conference on Computer Vision and Pattern Recognition, pages 1794-1801. IEEE, 2009.

Yi Zhang, Arturs Backurs, Sébastien Bubeck, Ronen Eldan, Suriya Gunasekar, and Tal Wagner. Unveiling transformers with LEGO: a synthetic reasoning task. arXiv preprint arXiv:2206.04301, 2022.

Xingyu Zhu, Zixuan Wang, Xiang Wang, Mo Zhou, and Rong Ge. Understanding edge-of-stability training dynamics with a minimalist example. In The Eleventh International Conference on Learning Representations, 2023.

A Additional illustrations for Section 2
B Further experimental results
  B.1 Experiments for the full sparse coding model
  B.2 Experiments on the CIFAR-10 dataset
C Proofs for the single-neuron linear network
  C.1 Approximate conservation along GD
  C.2 Gradient flow regime
  C.3 EoS regime: proof outline
  C.4 EoS regime: crossing the threshold and the convergence phase
  C.5 EoS regime: quasi-static analysis
D Deferred derivations of mean model
E Proofs for the mean model
  E.1 Deferred proofs
  E.2 Gradient flow regime
  E.3 EoS regime

A Additional illustrations for Section 2

In this section, we provide some illustrations of the results presented in Section 2.
We first illustrate the two different regimes of GD presented in Subsection 2.2.

[Figure 9 panels: GD trajectories in the $(x, y)$-plane, with reference levels $\sqrt{2/\eta}$ and $\sqrt{y_0^2 - x_0^2}$.]

Figure 9: Two regimes for GD. We run GD on the square root loss with step size $1/4$. The gradient flow regime is illustrated on the left for $(x_0, y_0) = (3, 4)$: GD (blue) tracks the gradient flow (green) when $\eta < 2/(y_0^2 - x_0^2)$. Otherwise, as illustrated on the right for $(x_0, y_0) = (3, 6)$, GD is in the EoS regime and goes through a gradient flow phase (blue), an intermediate bouncing phase (orange) that tracks the quasi-static envelope (purple), and a converging phase (red).

Next, we present detailed illustrations of the edge-of-stability regime depending on the choice of step size. Compare this plot with our theoretical results characterized in Theorem 2.

[Figure 10 panels: GD trajectories and sharpness against time ($\eta \times$ iteration, from 0 to 15), with reference levels $\sqrt{2/\eta_i}$ for $i = 1, \dots, 4$.]

Figure 10: We plot the GD trajectory and the sharpness for $\ell(s) = \sqrt{1 + s^2}$ with step sizes $2/\eta_1 = 9$, $2/\eta_2 = 7$, $2/\eta_3 = 5$, and $2/\eta_4 = 3$. In the EoS regime, the final sharpness is close to $2/(\text{step size})$.

Lastly, we present more detailed illustrations of Theorem 2 in terms of its dependence on $\beta$.

[Figure 11 panels: left, $\log(2/\eta - y_\infty^2)$ against $\log(\eta)$; right, the iteration count of the bouncing phase against $\log(\eta)$.]

Figure 11: (Left) Log-log plot of the sharpness gap as a function of $\eta$, for $\ell_\beta$ in Example 1 and $\beta = \frac32, 2, 3, 10$. (Right) Log-log plot of the iteration count for the bouncing region (in which $y_t^2 > 2/\eta$) as a function of $\eta$, for $\ell_\beta$ in Example 1 and $\beta = \frac43, \frac32, 2, 4$. The dashed lines show the predicted sharpness gap and iteration count with an offset computed via linear regression of the data for $\eta < e^{-2}$.
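The two regimes of Figure 9 can be reproduced in a few lines. The sketch below is ours, not the authors' plotting code; the iteration budget is an arbitrary assumption. It runs GD on $f(x, y) = \ell_{\mathrm{sqrt}}(xy) = \sqrt{1 + x^2y^2}$ with $\eta = 1/4$ from the two initializations of Figure 9 and records $D_t = y_t^2 - x_t^2$, the approximately conserved quantity of Lemma 6. At a point $(0, y)$ on the $y$-axis the Hessian of $f$ is $\mathrm{diag}(\ell''(0)\,y^2, 0) = \mathrm{diag}(y^2, 0)$, so $y_\infty^2$ is the final sharpness.

```python
import math


def ell_sqrt_prime(s: float) -> float:
    """Derivative of the square root loss l(s) = sqrt(1 + s^2)."""
    return s / math.sqrt(1.0 + s * s)


def gd_square_root_loss(x0: float, y0: float, eta: float, steps: int = 50_000):
    """GD on f(x, y) = sqrt(1 + (x y)^2).

    Returns the final iterate and the history of D_t = y_t^2 - x_t^2.
    """
    x, y = x0, y0
    history = [y * y - x * x]
    for _ in range(steps):
        lp = ell_sqrt_prime(x * y)
        x, y = x - eta * lp * y, y - eta * lp * x  # simultaneous GD update
        history.append(y * y - x * x)
    return x, y, history


if __name__ == "__main__":
    eta = 0.25  # stability threshold 2 / eta = 8
    for x0, y0 in [(3.0, 4.0), (3.0, 6.0)]:
        x, y, hist = gd_square_root_loss(x0, y0, eta)
        regime = "gradient flow" if hist[0] < 2 / eta else "EoS"
        # On the y-axis the Hessian is diag(y^2, 0), so y^2 is the sharpness.
        print(f"(x0, y0) = ({x0}, {y0}) [{regime}]: final sharpness {y * y:.3f}, "
              f"2/eta = {2 / eta:.0f}, |x| = {abs(x):.1e}")
```

The first run should converge quickly with sharpness around 5.6, noticeably below $y_0^2 - x_0^2 = 7$ because $\eta = 1/4$ is not small. The second should pass through a bouncing phase and settle at a sharpness just below $2/\eta = 8$. In both runs, $D_t$ is nonincreasing, since $D_{t+1} = (1 - \eta^2\ell'(x_ty_t)^2)\,D_t$.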
B Further experimental results

In this section, we report further experimental results which demonstrate that our theory, while limited to the specific models we study (namely, the single-neuron example and the mean model), is in fact indicative of behaviors commonly observed in more realistic instances of neural network training. In particular, we show that threshold neurons often emerge in the presence of oscillations in the other weight parameters of the network.

B.1 Experiments for the full sparse coding model

We provide the details for the top plot of Figure 3. We consider the sparse coding model in the form (1). Compared to (2), we assume that the basis vectors are unknown, and the neural network learns them through additional parameters $W = (w_i)_{i=1}^m$ together with $m$ different weights $a = (a_i)_{i=1}^m$ for the second layer as follows:
$$f(x; a, W, b) = \sum_{i=1}^m a_i\,\mathrm{ReLU}\big(\langle w_i, x\rangle + b\big)\,. \tag{7}$$
We show results for $d = 100$ and $m = 2000$. We generate $n = 20000$ data points according to the aforementioned sparse coding model with $\lambda = 5$. We use the He initialization, i.e., $a \sim \mathcal{N}(0, I_m/m)$, $w_i \sim \mathcal{N}(0, I_d/d)$, and $b = 0$. As shown in the top plot of Figure 3, the bias decreases more with the large learning rate. Further, we report the behavior of the average of the second-layer weights in Figure 12 (left), and confirm that it oscillates.

[Figure 12 panels: left, avg. 2nd layer weight against step for lr = 0.03, 0.04, 0.05, 0.07, 0.1, 0.2; right, against step for lr = 0.2, 0.1, 0.05.]

Figure 12: (Left) The average of the second-layer weights of the ReLU network (7). Note that the average value oscillates, similarly to our findings for the mean model. (Right) Oscillation of the logit of a ResNet-18 model averaged over the (binary) CIFAR-10 training set. Since the dataset is binary, the logit is simply a scalar.

B.2 Experiments on the CIFAR-10 dataset

Next, we provide the details for the bottom plot of Figure 3.
We train ResNet-18 on a binarized version of the CIFAR-10 dataset formed by taking only the first two classes; this is done for the purpose of monitoring the average logit of the network. The average logit is measured over the entire training set. The median bias is measured at the last convolutional layer, right before the pooling. For the optimizer, we use full-batch GD with no momentum or weight decay, plus a cosine learning rate scheduler; the learning rates shown in the plots are the initial values.

Oscillation of the expected output (logit) of the network. Bearing a striking resemblance to our two-layer models, Figure 12 (right) shows that the expected output (logit) of the deep network also oscillates under the GD dynamics. As we have argued in the previous sections, this occurs as the bias parameters are driven towards negative values.

Results for SGD. In Figure 13, we report qualitatively similar phenomena when we instead train ResNet-18 with stochastic gradient descent (SGD), where we use all ten classes of CIFAR-10. Again, the median bias is measured at the last convolutional layer. We further report the average activation, which is the output of the ReLU activation at the last convolutional layer, averaged over the neurons and the entire training set. The average activation statistics represent the hidden representations before the linear classifier part, and lower values represent sparser representations. Interestingly, the threshold neuron also emerges with larger step sizes, similarly to the case of gradient descent.

[Figure 13 panels: avg. activation and median bias (horizontal axis from 0 to 400) for lr = 0.2, 0.1, 0.05, 0.01.]

Figure 13: SGD dynamics of ResNet-18 on (multiclass) CIFAR-10 with various learning rates and batch sizes.
(Top) batch size 100; (Bottom) batch size 1000. The results are consistent across different batch sizes.

C Proofs for the single-neuron linear network

We start by describing basic and relevant properties of the model and the assumptions on $\ell$.

Basic properties. If $\ell$ is minimized at $0$, then the global minimizers of $f$ are the $x$- and $y$-axes. The gradient and Hessian of $f$ are given by
$$\nabla f(x, y) = \ell'(xy)\begin{pmatrix} y \\ x \end{pmatrix}, \qquad \nabla^2 f(x, y) = \ell''(xy)\begin{pmatrix} y \\ x \end{pmatrix}^{\otimes 2} + \ell'(xy)\begin{pmatrix} 0 & 1 \\ 1 & 0 \end{pmatrix}.$$
This results in the GD iterates $x_t, y_t$, for step size $\eta > 0$ and iteration $t \ge 0$:
$$x_{t+1} = x_t - \eta\,\ell'(x_t y_t)\,y_t\,, \qquad y_{t+1} = y_t - \eta\,\ell'(x_t y_t)\,x_t\,.$$

Lemma 4 (invariant lines). Assume that $\ell$ is even, so that $\ell'$ is odd. Then, the lines $y = \pm x$ are invariant for gradient descent on $f$.

Proof. If $y_t = x_t$, then
$$y_{t+1} = y_t - \eta\,\ell'(x_t y_t)\,x_t = x_t - \eta\,\ell'(x_t^2)\,x_t\,, \qquad x_{t+1} = x_t - \eta\,\ell'(x_t y_t)\,y_t = x_t - \eta\,\ell'(x_t^2)\,x_t\,,$$
and hence $y_{t+1} = x_{t+1}$. The case $y_t = -x_t$ is analogous, using that $\ell'$ is odd. Note that the iterates $(x_t)_{t \ge 0}$ are the iterates of GD with step size $\eta$ on the one-dimensional loss function $x \mapsto \frac12\,\ell(x^2)$.

We focus instead on initializing away from these two lines. We now state our assumptions on $\ell$.

We gather together some elementary properties of $\ell$.

Lemma 5 (properties of $\ell$). Suppose that Assumption (A1) holds.
1. $\ell$ is minimized at the origin and $\ell'(0) = 0$.
2. Suppose that $\ell$ is four times continuously differentiable near the origin. If Assumption (A2) holds, then $\ell^{(4)}(0) \le 0$. Conversely, if $\ell^{(4)}(0) < 0$, then Assumption (A2) holds for $\beta = 2$.

Proof. The first statement is straightforward. The second statement follows from Taylor expansion: for $s \neq 0$ near the origin,
$$\frac{\ell'(s)}{s} = \frac{\ell'(0) + \ell''(0)\,s + \int_0^s (s - r)\,\ell'''(r)\,dr}{s} = 1 + \int_0^s \frac{s - r}{s}\,\ell'''(r)\,dr\,. \tag{9}$$
Since $\ell'''$ is odd, Assumption (A2) and (9) imply that $\ell'''$ is non-positive on $(0, \varepsilon)$ for some $\varepsilon > 0$, which in turn implies $\ell^{(4)}(0) \le 0$. Conversely, if $\ell^{(4)}(0) < 0$, then there exists $\varepsilon > 0$ such that $\ell'''(s) \le -\varepsilon s$ for $s \in (0, \varepsilon)$. From (9), we see that
$$\frac{\ell'(s)}{s} \le 1 - \frac{\varepsilon}{s}\int_0^s (s - r)\,r\,dr = 1 - \frac{\varepsilon s^2}{6}\,.$$
By symmetry, we conclude that Assumption (A2) holds with $\beta = 2$ and some $c > 0$.
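The gradient and Hessian formulas, Lemma 4, and the $\beta = 2$ conclusion of Lemma 5 can be sanity-checked numerically. The sketch below (standard-library Python) takes $\ell$ to be the rescaled symmetrized logistic loss of Example 1, for which $\ell(s) = \log(2\cosh s)$, $\ell'(s) = \tanh s$, and $\ell^{(4)}(0) = \tanh'''(0) = -2 < 0$. This choice of loss, the test point $(0.7, -1.3)$, the step size $0.3$, and the grid are illustrative assumptions. The code compares the stated Hessian with finite differences of the gradient, checks that GD started on $y = x$ or $y = -x$ stays on that line, and checks $\ell'(s)/s \le 1 - s^2/4$ for $0 < |s| < 1/4$.

```python
import math


# Illustrative choice: l(s) = log(2 cosh s) (rescaled symmetrized logistic loss),
# so that l'(s) = tanh(s) and l''(s) = 1 - tanh(s)^2.
def ell_prime(s: float) -> float:
    return math.tanh(s)


def ell_second(s: float) -> float:
    return 1.0 - math.tanh(s) ** 2


def grad_f(x: float, y: float) -> tuple:
    """Gradient of f(x, y) = l(x y): l'(x y) * (y, x)."""
    lp = ell_prime(x * y)
    return lp * y, lp * x


def hess_f(x: float, y: float) -> list:
    """Hessian l''(x y) (y, x)(y, x)^T + l'(x y) [[0, 1], [1, 0]]."""
    lpp, lp = ell_second(x * y), ell_prime(x * y)
    return [[lpp * y * y, lpp * x * y + lp],
            [lpp * x * y + lp, lpp * x * x]]


def hess_fd(x: float, y: float, h: float = 1e-5) -> list:
    """Central finite differences of grad_f; entry [i][j] = d(grad_i)/d(z_j)."""
    dx = [(p - m) / (2 * h) for p, m in zip(grad_f(x + h, y), grad_f(x - h, y))]
    dy = [(p - m) / (2 * h) for p, m in zip(grad_f(x, y + h), grad_f(x, y - h))]
    return [[dx[0], dy[0]], [dx[1], dy[1]]]


def gd(x: float, y: float, eta: float, steps: int) -> list:
    """GD iterates x_{t+1} = x_t - eta l'(x_t y_t) y_t, y_{t+1} = y_t - eta l'(x_t y_t) x_t."""
    traj = [(x, y)]
    for _ in range(steps):
        gx, gy = grad_f(x, y)
        x, y = x - eta * gx, y - eta * gy
        traj.append((x, y))
    return traj


if __name__ == "__main__":
    print(hess_f(0.7, -1.3))
    print(hess_fd(0.7, -1.3))
    # Lemma 4: GD started on y = x (or y = -x) stays on that line.
    print(gd(1.5, 1.5, 0.3, 5)[-1], gd(1.5, -1.5, 0.3, 5)[-1])
    # Example 1 / Lemma 5 with beta = 2: l'(s)/s <= 1 - s^2/4 on 0 < s < 1/4.
    print(all(ell_prime(k / 1000) / (k / 1000) <= 1 - (k / 1000) ** 2 / 4 for k in range(1, 250)))
```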
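We give some simple examples of losses satisfying our assumptions.

Example 1. The examples below showcase several functions $\ell$ that satisfy Assumptions (A1) and (A2) with different values of $\beta$.

Rescaled and symmetrized logistic loss. $\ell_{\mathrm{rsym}}(s) := \frac12\,\ell_{\mathrm{logi}}(-2s) + \frac12\,\ell_{\mathrm{logi}}(+2s)$. Note $\ell_{\mathrm{rsym}}'(s) = \tanh(s)$; thus $\ell_{\mathrm{rsym}}'(s)/s \le 1$, and $\ell_{\mathrm{rsym}}'(s)/s \le 1 - \frac14|s|^2$ for $|s| < \frac14$.

Square root loss. $\ell_{\mathrm{sqrt}}(s) := \sqrt{1 + s^2}$. Note $\ell_{\mathrm{sqrt}}'(s) = s/\sqrt{1 + s^2}$; thus $\ell_{\mathrm{sqrt}}'(s)/s \le 1$, and $\ell_{\mathrm{sqrt}}'(s)/s \le 1 - \frac25|s|^2$ for $|s| \le \frac12$.

Huber loss. $\ell_{\mathrm{Hub}}(s) := \frac{s^2}{2}\,\mathbb{1}\{s \in [-1, 1]\} + \big(|s| - \frac12\big)\,\mathbb{1}\{s \notin [-1, 1]\}$. Note $\ell_{\mathrm{Hub}}'(s) = s\,\mathbb{1}\{s \in [-1, 1]\} + \mathrm{sgn}(s)\,\mathbb{1}\{s \notin [-1, 1]\}$; thus $\ell_{\mathrm{Hub}}'(s)/s \le 1$, i.e., we have Assumption (A2) with $\beta = +\infty$.

Higher-order. For $\beta > 1$, let $c_\beta := \frac{1}{\beta + 1}\big(\frac{\beta}{\beta + 1}\big)^\beta$ and $r_\beta := \frac{\beta + 1}{\beta}$. We define $\ell_\beta$ implicitly via its derivative
$$\ell_\beta'(s) := s\,\big(1 - c_\beta\,|s|^\beta\big)\,\mathbb{1}\{s^2 < r_\beta^2\} + \mathrm{sgn}(s)\,\mathbb{1}\{s^2 \ge r_\beta^2\}\,.$$
By definition, $\ell_\beta'(s)/s \le 1$, and $\ell_\beta'(s)/s = 1 - c_\beta\,|s|^\beta$ for $|s| < r_\beta$. (The constants are chosen so that $\ell_\beta'$ is continuous, reaching the value $1$ at $|s| = r_\beta$.)

We now prove our main results from Subsection 2.3 in order.

C.1 Approximate conservation along GD

We begin by stating and proving the approximate conservation of $y^2 - x^2$ for the GD dynamics.

Lemma 6 (approximately conserved quantity). Let $(\bar x, \bar y) \in \mathbb{R}^2$ be such that $\bar y > \bar x > 0$ with $\bar y^2 - \bar x^2 = 1$. Suppose that we run GD on $f$ with step size $\eta$ with initial point $(x_0, y_0) := \sqrt{\gamma/\eta}\,(\bar x, \bar y)$, for some $\gamma > 0$. Then, there exists $t_0 = O(1/\eta)$ such that $\sup_{t \ge t_0}|x_t| \le O\big(\sqrt{(\gamma \vee \gamma^{-1})\,\eta}\big)$ and
$$y_{t_0}^2 - x_{t_0}^2 = \big(1 - O(\eta)\big)\,(y_0^2 - x_0^2)\,,$$
where the implied constant depends on $\bar x$, $\bar y$, and $\ell$.

Proof. Let $D_t := y_t^2 - x_t^2$ and note that
$$D_{t+1} = \big(y_t - \eta\,\ell'(x_t y_t)\,x_t\big)^2 - \big(x_t - \eta\,\ell'(x_t y_t)\,y_t\big)^2 = \big(1 - \eta^2\,\ell'(x_t y_t)^2\big)\,D_t\,.$$
Since $\ell$ is 1-Lipschitz, $D_{t+1} = (1 - O(\eta^2))\,D_t$. This shows that for $t \lesssim 1/\eta^2$, we have $y_t^2 - x_t^2 = D_t \asymp D_0 = y_0^2 - x_0^2 \asymp \gamma/\eta$. Since $\ell''(0) = 1$, there exist constants $c_0, c_1 > 0$ such that $\ell'(|xy|) \ge \ell'(c_0) \ge c_1$ whenever $|xy| \ge c_0$. Hence, for all $t \ge 1$ such that $t \lesssim 1/\eta^2$, $x_t > 0$, and $|x_t y_t| \ge c_0$, we have $y_t^2 \gtrsim \gamma/\eta$ and
$$x_{t+1} = x_t - \eta\,\ell'(x_t y_t)\,y_t = x_t - \Theta(\eta\,y_t) = x_t - \Theta\big(\sqrt{\gamma\eta}\big)\,. \tag{10}$$
Since $x_0 \asymp \sqrt{\gamma/\eta}$, this shows that after at most $O(1/\eta)$ iterations, we must have either $x_t < 0$ or $|x_t y_t| < c_0$ for the first time. In the first case, (10) shows that $|x_t| \lesssim \sqrt{\gamma\eta}$. In the second case, since $y_t^2 \gtrsim \gamma/\eta$, we have $|x_t| \lesssim \sqrt{\eta/\gamma}$. Let $t_0$ denote the iteration at which this occurs.

Next, for iterations $t \ge t_0$, we use the dynamics (10) for $x$ and the fact that $\ell'(x_t y_t)$ has the same sign as $x_t$ to conclude that there are two possibilities: either $x_{t+1}$ has the same sign as $x_t$, in which case $|x_{t+1}| \le |x_t|$, or $x_{t+1}$ has the opposite sign to $x_t$, in which case $|x_{t+1}| \le \eta\,|\ell'(x_t y_t)|\,y_t \le \eta\,y_t \le O(\sqrt{\gamma\eta})$. This implies $\sup_{t \ge t_0}|x_t| \le O\big(\sqrt{(\gamma \vee \gamma^{-1})\,\eta}\big)$, as asserted.

C.2 Gradient flow regime

In this section, we prove Theorem 1. From Lemma 6, there exists an iteration $t_0$ such that $|x_{t_0}| \lesssim \sqrt{\eta/(2 - \delta)}$ and
$$\frac{2 - \delta}{\eta} - O(2 - \delta) \le y_{t_0}^2 \le \frac{2 - \delta}{\eta} + x_{t_0}^2 \le \frac{2 - \delta}{\eta} + O\Big(\frac{\eta}{2 - \delta}\Big)\,.$$
In particular, $C := |x_{t_0} y_{t_0}| \lesssim 1$. We prove by induction the following facts: for $t \ge t_0$,
1. $|x_t y_t| \le C$.
2. $|x_t| \le |x_{t_0}|\,\exp(-\Omega(\alpha\,(t - t_0)))$, where $\alpha := \min\{\delta, 2 - \delta\}$.

Suppose that these conditions hold up to iteration $t \ge t_0$. By Assumption (A2), we have $|\ell'(s)| \le |s|$ for all $s \neq 0$. Therefore,
$$y_{t+1} = y_t - \eta\,\ell'(x_t y_t)\,x_t \ge (1 - \eta x_t^2)\,y_t \ge \Big(1 - \frac{\eta^2}{2 - \delta}\,e^{-\Omega(\alpha(t - t_0))}\Big)\,y_t \ge \prod_{s = t_0}^{t}\Big(1 - \frac{\eta^2}{2 - \delta}\,e^{-\Omega(\alpha(s - t_0))}\Big)\,y_{t_0} \ge \exp\Big(-O\Big(\frac{\eta^2}{\alpha\,(2 - \delta)}\Big)\Big)\,y_{t_0}\,,$$
and hence
$$y_{t+1}^2 \ge \frac{2 - \delta}{\eta} - O(2 - \delta) - O\Big(\frac{\eta}{\alpha}\Big)\,. \tag{11}$$
In particular, $1/\eta \lesssim y_t^2 \le (2 - \delta/2)/\eta$ throughout. In order for these assertions to hold, we require $\eta^2 \lesssim \alpha\,(2 - \delta)$, i.e., $\eta \lesssim \sqrt{\alpha\,(2 - \delta)}$.

Next, we would like to show that $t \mapsto |x_t|$ is decaying exponentially fast. Since $|x_{t+1}| = |x_t - \eta\,\ell'(x_t y_t)\,y_t| = \big||x_t| - \eta\,\ell'(|x_t|\,y_t)\,y_t\big|$, it suffices to consider the case when $x_t > 0$. Assumption (A2) implies that
$$x_{t+1} \ge (1 - \eta y_t^2)\,x_t \ge -\Big(1 - \frac{\delta}{2}\Big)\,x_t\,.$$
For the upper bound, we split into two cases. We begin by observing that since $\ell$ is twice continuously differentiable near the origin with $\ell''(0) = 1$, there is a constant $\varepsilon_0$ such that $|s| < \varepsilon_0$ implies $|\ell'(s)| \ge \frac12|s|$. If $s_t := x_t y_t \le \varepsilon_0$, then
$$x_{t+1} \le \Big(1 - \frac{\eta}{2}\,y_t^2\Big)\,x_t \le \Big(1 - \frac{2 - \delta}{4}\Big)\,x_t\,.$$
Otherwise, if $s_t \ge \varepsilon_0$, then
$$x_{t+1} \le x_t - \eta\,\ell'(\varepsilon_0)\,y_t = x_t - \eta\,\ell'(\varepsilon_0)\,\frac{y_t^2}{s_t}\,x_t \le x_t - \eta\,\ell'(\varepsilon_0)\,\frac{2 - \delta}{2C\eta}\,x_t \le \big(1 - \Omega(2 - \delta)\big)\,x_t\,.$$
Combining these inequalities, we obtain $|x_{t+1}| \le |x_t|\,\exp(-\Omega(\alpha))$.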
This verifies the second statement in the induction. The first statement follows because both t 7 |xt| and t 7 yt are decreasing. This shows in particular that |xt| 0, i.e., we have global convergence. To conclude the proof, observe that (11) gives a bound on the final sharpness. Remark 1. The proof also gives us estimates on the convergence rate. Namely, from Lemma 6, the initial phase in which we approach the y-axis takes O( 1 η) iterations. For the convergence phase, in order to achieve ε error, we need |xt| εη 2 δ; hence, the convergence phase needs only O( 1 ε) iterations. Note that the rate of convergence in the latter phase does not depend on the step size η. C.3 Eo S regime: proof outline We give a brief outline of the proof of Theorem 2: As before, Lemma 6 shows that GD reaches the y-axis approximately at (0, p y2 0 x2 0). At this point, x starts bouncing while y steadily decreases, and we argue that unless xt = 0 or y2 t 2/η, the GD dynamics cannot stabilize (see Lemma 7). To bound the gap 2/η y2 , we look at the first iteration t such that y2t crosses 2/η. By making use of Assumption (A2), we simultaneously control both the convergence rate of |xt| to zero and the decrease in y2 t in order to prove that η O(|xtyt|) , (12) see Proposition 1. Therefore, to establish Theorem 2, we must bound |xtyt| at iteration t. Controlling the size of |xtyt|, however, is surprisingly delicate as it requires a fine-grained understanding of the bouncing phase. The insight that guides the proof is the observation that during the bouncing phase, the GD iterates lie close to a certain envelope (Figure 9). This envelope is predicted by the quasi-static heuristic as described in Ma et al. (2022). Namely, suppose that after one iteration of GD, we have perfect bouncing: xt+1 = xt. Substituting this into the GD dynamics, we obtain the equation η ℓ (xtyt) yt = 2xt . 
According to Assumption (A2), we have $\ell'(x_t y_t) = x_t y_t\,(1 - \Omega(|x_t y_t|^\beta))$. Together with (13), if $y_t^2 = (2+\delta_t)/\eta \ge 2/\eta$ with $\delta_t$ sufficiently small, this suggests that
$$|x_t y_t| \asymp \delta_t^{1/\beta}. \tag{14}$$
The quasi-static prediction (14) fails when $\delta_t$ is too small. Nevertheless, we show that it remains accurate as long as $\delta_t \gtrsim \eta^{\beta/(\beta-1)}$. Consequently, we obtain $|x_{\bar t}\,y_{\bar t}| \lesssim \eta^{1/(\beta-1)}$, and combined with (12), this yields Theorem 2.

C.4 EoS regime: crossing the threshold and the convergence phase

In this section, we prove Theorem 2. We first show that $y_t^2$ must cross $2/\eta$ in order for GD to converge, and we bound the size of the jump across $2/\eta$ once this happens. Throughout this section and the next, we use the notation
$$s_t := x_t y_t, \qquad r_t := \ell'(s_t)/s_t.$$
In this notation, the GD equations read
$$x_{t+1} = (1 - \eta r_t y_t^2)\,x_t, \qquad y_{t+1} = (1 - \eta r_t x_t^2)\,y_t.$$
We also make a remark regarding Assumption (A2). If $\beta < +\infty$, then Assumption (A2) is equivalent to the following seemingly stronger assumption: for all $r > 0$, there exists a constant $c(r) > 0$ such that
$$\frac{\ell'(s)}{s} \le 1 - c(r)\,|s|^\beta \quad \text{for all } 0 < |s| \le r. \tag{A2+}$$
Indeed, Assumption (A2) states that (A2+) holds for some $r > 0$. To verify that (A2+) holds for some larger $r' > r$, we split into two cases. If $|s| \le r$, then $\ell'(s)/s \le 1 - c\,|s|^\beta$. Otherwise, if $|s| > r$, then $\ell'(r)/r < 1$ and the 1-Lipschitzness of $\ell'$ imply that $\ell'(s)/s < 1$ for $r \le |s| \le r'$. Hence $\ell'(s)/s \le 1 - c'\,|s|^\beta$ for a sufficiently small constant $c' > 0$, and we can take $c(r') = \min\{c, c'\}$. Later, we will invoke (A2+) with $r$ chosen to be a universal constant, so $c(r)$ can also be thought of as universal.

We begin with the following result about the limiting value of $y_t$.

Lemma 7 (threshold crossing). Let $(\bar x, \bar y) \in \mathbb{R}^2$ satisfy $\bar y > \bar x > 0$ with $\bar y^2 - \bar x^2 = 1$. Suppose we run GD with step size $\eta$ from the initial point $(x_0, y_0) := \sqrt{(2+\delta)/\eta}\,(\bar x, \bar y)$, where $\delta > 0$ is a constant. Then either $x_t = 0$ for some $t$, or $\lim_{t\to\infty} y_t^2 \le 2/\eta$.

Proof. Assume throughout that $x_t \ne 0$ for all $t$. Recall the dynamics for $y$:
$$y_{t+1} = y_t - \eta\,\ell'(x_t y_t)\,x_t.$$
By assumption, $\ell'(s)/s \to 1$ as $s \to 0$, and $\ell'$ is increasing. So this equation implies that if $\liminf_{t\to\infty} |x_t| > 0$, then $y_t^2$ must eventually cross $2/\eta$. Suppose, for the sake of contradiction, that there exists $\varepsilon > 0$ with $y_t^2 > (2+\varepsilon)/\eta$ for all $t$. Let $\varepsilon' > 0$ be such that $1 - (2+\varepsilon)(1-\varepsilon') < -1$, i.e., $\varepsilon' < \varepsilon/(2+\varepsilon)$. Then there exists $\delta' > 0$ such that $|x_t| \le \delta'$ implies $r_t > 1 - \varepsilon'$, and hence
$$\frac{|x_{t+1}|}{|x_t|} = |1 - \eta r_t y_t^2| > (2+\varepsilon)(1-\varepsilon') - 1 > 1.$$
This means that $|x_t|$ increases until it exceeds $\delta'$, i.e., $\liminf_{t\to\infty} |x_t| \ge \delta'$. This is the desired contradiction, and it implies that $\lim_{t\to\infty} y_t^2 \le 2/\eta$.

Lemma 8 (initial gap). Suppose that at some iteration $\bar t$, we have $y_{\bar t+1}^2 < 2/\eta \le y_{\bar t}^2$. Then
$$y_{\bar t+1}^2 \ge \frac{2}{\eta} - 2\eta\,|s_{\bar t}|^2.$$

Proof. We can bound
$$y_{\bar t+1}^2 = y_{\bar t}^2 - 2\eta\,\ell'(x_{\bar t} y_{\bar t})\,x_{\bar t} y_{\bar t} + \eta^2\,\ell'(x_{\bar t} y_{\bar t})^2\,x_{\bar t}^2 \ge y_{\bar t}^2 - 2\eta\,|x_{\bar t} y_{\bar t}|^2.$$
Here we used $|\ell'(s)| \le |s|$ for all $s \in \mathbb{R}$, together with $y_{\bar t}^2 \ge 2/\eta$.

The above lemma shows that the size of the jump across $2/\eta$ is controlled by $|s_{\bar t}|$ at the time of the crossing. From Lemma 6, we know that $|s_t| \lesssim 1$, where the implied constant depends on $\delta$. Hence, the size of the jump is always $O(\eta)$.

We now analyze the convergence phase, i.e., the phase after $y_t^2$ crosses $2/\eta$.

Proposition 1 (convergence phase). Suppose that $y_{\bar t}^2 < 2/\eta \le y_{\bar t-1}^2$. Then GD converges to $(0, y_\infty)$ satisfying
$$\frac{2}{\eta} - O(|s_{\bar t}|) \le y_\infty^2 \le \frac{2}{\eta}.$$

Proof. Write $y_t^2 = (2 - \rho_t)/\eta$, so that $\rho_t = 2 - \eta y_t^2$. We write down the update equations for $x$ and for $\rho$. First, by the same argument as in the proof of Theorem 1,
$$|x_{t+1}| \le |x_t|\exp(-\Omega(\rho_t)). \tag{15}$$
Next, using $r_t \le 1$,
$$y_{t+1} = (1 - \eta r_t x_t^2)\,y_t \ge (1 - \eta x_t^2)\,y_t, \qquad y_{t+1}^2 \ge (1 - 2\eta x_t^2)\,y_t^2,$$
which translates into
$$\rho_{t+1} \le \rho_t + 2\eta^2 x_t^2 y_t^2 \le \rho_t + 4\eta x_t^2. \tag{16}$$
We now combine these two inequalities. Let $q > 0$ be a parameter chosen later, and let $t'$ be the first iteration for which $\rho_{t'} \ge q$. If no such iteration exists, then $\rho_t \le q$ for all $t$. Note that $\rho_{t'} \le q + O(\eta\,|x_{\bar t}|)$ due to (15) and (16). By (15), for all $t \ge t'$,
$$|x_t| \le |x_{t'}|\exp\big(-\Omega(q\,(t - t'))\big) \le |x_{\bar t}|\exp\big(-\Omega(q\,(t - t'))\big).$$
Substituting this into (16),
$$\rho_t \le \rho_{t'} + 4\eta\sum_{k=t'}^{t-1} x_k^2 \le q + O(\eta\,|x_{\bar t}|) + O(\eta\,|x_{\bar t}|^2)\sum_{k=t'}^{\infty}\exp\big(-\Omega(q\,(k - t'))\big) \le q + O(\eta\,|x_{\bar t}|) + O\Big(\frac{\eta\,|x_{\bar t}|^2}{q}\Big).$$
Optimizing this bound over $q$ gives, for all $t$, $\rho_t \lesssim \sqrt{\eta}\,|x_{\bar t}| \lesssim \eta\,|s_{\bar t}|$. Translating this back into $y_t^2$ yields the result.

Let us take stock of what we have established thus far. According to Lemma 6, $|s_t|$ is bounded by a constant for all $t$. From Lemma 7 and Lemma 8, either $y_t^2 \to 2/\eta$, or $2/\eta - O(\eta) \le y_{\bar t}^2 < 2/\eta$ for some iteration $\bar t$. In the latter case, Proposition 1 shows that the limiting sharpness is $2/\eta - O(1)$. Note also that the analyses thus far have not used Assumption (A2), i.e., we have established the $\beta = +\infty$ case of Theorem 2. Moreover, for all $\beta > 1$, the asymptotic $2/\eta - O(1)$ still shows that the limiting sharpness is close to $2/\eta$, albeit with a suboptimal rate. Readers satisfied with this result can skip ahead to subsequent sections. The remainder of this section and the next section are devoted to substantial refinements of the analysis.

To see where improvements are possible, note that both Lemma 8 and Proposition 1 rely on the size of $|s_{\bar t}|$ at the crossing. The crude bound $|s_{\bar t}| \lesssim 1$ does not capture the behavior observed in experiments, where $|s_{\bar t}| \asymp \eta^{1/(\beta-1)}$. Substituting this improved bound into Lemma 8 would show that the gap at the crossing is $O(\eta^{1 + 2/(\beta-1)})$. Proposition 1 would then imply that the limiting sharpness is $2/\eta - O(\eta^{1/(\beta-1)})$. Another weakness of our proof is that it gives almost no information about the dynamics during the bouncing phase, so our understanding of the EoS phenomenon remains incomplete. In particular, experiments show that during the bouncing phase, the iterates lie very close to the quasi-static envelope (Figure 9). In the next section, we rigorously prove all of these observations. Before doing so, however, we show that Proposition 1 can be refined using Assumption (A2), which may be of independent interest.
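Below is a minimal numerical sketch of both regimes. It uses the illustrative loss $\ell(s) = \log\cosh(s)$, which is our choice and not necessarily one used in the paper. This loss is even, convex, and 1-Lipschitz, with $\ell''(0) = 1$. It also satisfies $\ell'(s)/s = \tanh(s)/s \le 1$ with $\tanh(s)/s = 1 - s^2/3 + O(s^4)$, i.e., $\beta = 2$. The helper `run_gd`, the point $(\bar x, \bar y) = (0.75, 1.25)$, and the step sizes are hypothetical.

The sharpness at $(0, y)$ is $\ell''(0)\,y^2 = y^2$. The gradient flow run should therefore stop with $\eta y_\infty^2$ slightly below $2 - \delta$. The EoS run should stop with $\eta y_\infty^2$ at or just below $2$, as Lemma 7 and Proposition 1 predict. The sketch also records the largest violation of the first inequality in (16) along each trajectory, which should be zero up to rounding.

```python
import math


def run_gd(x, y, eta, steps):
    """GD on f(x, y) = loss(x * y) with the illustrative loss = log cosh (loss' = tanh).

    Returns the final iterate and the largest violation of the first inequality in (16),
    rho_{t+1} <= rho_t + 2 * eta^2 * x_t^2 * y_t^2, where rho_t = 2 - eta * y_t^2.
    """
    worst = -math.inf
    for _ in range(steps):
        g = math.tanh(x * y)
        bound = (2 - eta * y * y) + 2 * eta**2 * (x * y) ** 2
        # Simultaneous update: both coordinates use the current (x, y).
        x, y = x - eta * g * y, y - eta * g * x
        worst = max(worst, (2 - eta * y * y) - bound)
    return x, y, worst


x_bar, y_bar = 0.75, 1.25  # hypothetical: y_bar > x_bar > 0 and y_bar^2 - x_bar^2 == 1
results = {}
# c = 2 - delta with delta = 1 (gradient flow), or c = 2 + delta with delta = 0.2 (EoS)
for name, eta, c in (("gradient flow", 0.01, 1.0), ("EoS", 0.1, 2.2)):
    scale = math.sqrt(c / eta)
    results[name] = run_gd(scale * x_bar, scale * y_bar, eta, steps=20_000)
    x_T, y_T, worst = results[name]
    print(f"{name}: |x_T| = {abs(x_T):.1e}, eta * y_T^2 = {eta * y_T**2:.4f}, "
          f"max violation of (16) = {worst:.1e}")
```

Decreasing `eta` in the EoS run should lengthen the bouncing phase, in line with the convergence-rate estimates discussed later in this appendix.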
The refinement, Proposition 2 below, shows that even if the convergence phase begins with a large value of $|s_{\bar t}|$, the limiting sharpness can be much closer to $2/\eta$ than Proposition 1 suggests. Combined with Lemma 6, this proposition implies Theorem 2 for all $\beta > 2$, but it is insufficient for the case $1 < \beta \le 2$. From now on, we assume $\beta < +\infty$.

Proposition 2 (convergence phase; refined). Suppose that $y_{\bar t}^2 < 2/\eta \le y_{\bar t-1}^2$. Then GD converges to $(0, y_\infty)$ satisfying
$$\frac{2}{\eta} - O(\eta\,|s_{\bar t}|^2) - \begin{cases} O(\eta^{1/(\beta-1)}), & \beta > 2, \\ O(\eta\log(|s_{\bar t}|/\eta)), & \beta = 2, \\ O(\eta\,|s_{\bar t}|^{2-\beta}), & \beta < 2, \end{cases} \;\le\; y_\infty^2 \le \frac{2}{\eta}.$$

Proof. Let $y_t^2 = (2 - \rho_t)/\eta$ as before. We quantify the decrease of $|x_t|$ in terms of $\rho_t$, and conversely the increase of $\rho_t$ in terms of $|x_t|$. To do so, we track the half-life of $|x_t|$, i.e., the number of iterations it takes $|x_t|$ to halve. We call these epochs: during the $i$-th epoch,
$$2^{-(i+1)}\sqrt{\eta} < |x_t| \le 2^{-i}\sqrt{\eta}.$$
Let $i_0$ be the index of the first epoch, i.e., $i_0 = \lfloor \log_2(\sqrt{\eta}/|x_{\bar t}|) \rfloor$. By Lemma 6, $i_0 \ge -O(1)$. From (15), $|x_t|$ is monotonically decreasing, and consequently $|s_t|$ is decreasing as well. Also, our bound on the limiting sharpness implies that $y_t^2 > 1/\eta$ for all $t$, provided that $\eta$ is sufficiently small.

We now compute the dynamics of $\rho_t$ and $|x_t|$. During epoch $i$, $|x_t| > 2^{-(i+1)}\sqrt{\eta}$, hence $|s_t| > 2^{-(i+1)}$. Assumption (A2+) with $r = |s_{\bar t}| \vee 1$ implies that
$$r_t = \frac{\ell'(s_t)}{s_t} \le 1 - c\,2^{-\beta(i+1)}, \tag{17}$$
where $c = c(|s_{\bar t}| \vee 1)$. This refines (15) on the decrease of $|x_t|$ to
$$\frac{|x_{t+1}|}{|x_t|} = \eta r_t y_t^2 - 1 \le (2 - \rho_t)\big(1 - c\,2^{-\beta(i+1)}\big) - 1 \le 1 - \rho_t - c\,2^{-\beta(i+1)},$$
where the first inequality follows from (17) and the second from $\rho_t = 2 - \eta y_t^2 < 1$. In turn, this inequality shows that the $i$-th epoch requires only $O(2^{\beta i})$ iterations. Hence, if $t(i)$ denotes the start of the $i$-th epoch, then (16) shows that
$$\rho_{t(i+1)} \le \rho_{t(i)} + 4\eta^2\,2^{-2i}\,O(2^{\beta i}) \le \rho_{t(i)} + O(\eta^2\,2^{(\beta-2)i}).$$
Summing this up gives
$$\rho_{t(i)} \le \rho_{\bar t} + \eta^2 \times \begin{cases} O(2^{(\beta-2)i}), & \beta > 2, \\ O(i - i_0), & \beta = 2, \\ O(2^{(\beta-2)i_0}) = O(|s_{\bar t}|^{2-\beta}), & \beta < 2. \end{cases}$$
In the case $\beta < 2$, the final sharpness satisfies $2/\eta - O(\rho_{\bar t}/\eta) - O(\eta\,|s_{\bar t}|^{2-\beta}) \le y_\infty^2 \le 2/\eta$.
In the other two cases, suppose that we use this argument until epoch $i^*$ such that $2^{-i^*} \asymp \eta^\gamma$. Then $|x_{t(i^*)}| \lesssim \eta^{\gamma + 1/2}$ and $|s_{t(i^*)}| \lesssim \eta^\gamma$. Applying the argument of Proposition 1 from iteration $t(i^*)$ onward, we obtain
$$\rho_\infty = \rho_{t(i^*)} + (\rho_\infty - \rho_{t(i^*)}) \le \rho_{\bar t} + O(\eta^{\gamma+1}) + \eta^2 \times \begin{cases} O(2^{(\beta-2)i^*}) = O(\eta^{-\gamma(\beta-2)}), & \beta > 2, \\ O(i^* - i_0), & \beta = 2. \end{cases}$$
Optimizing over $\gamma$ gives $\gamma = 1/(\beta-1)$, and thus
$$\rho_\infty \le \rho_{\bar t} + \begin{cases} O(\eta^{1 + 1/(\beta-1)}), & \beta > 2, \\ O(\eta^2\log(|s_{\bar t}|/\eta)), & \beta = 2. \end{cases}$$
Collecting the three cases and using Lemma 8 to bound $\rho_{\bar t}$ finishes the proof. With the crude bound $|s_{\bar t}| \lesssim 1$ from Lemma 6, this yields
$$\frac{2}{\eta} - y_\infty^2 \le \begin{cases} O(\eta^{1/(\beta-1)}), & \beta > 2, \\ O(\eta\log(1/\eta)), & \beta = 2, \\ O(\eta), & \beta < 2, \end{cases}$$
which is optimal for $\beta > 2$.

C.5 EoS regime: quasi-static analysis

We supplement Assumption (A2) with a corresponding lower bound on $\ell'(s)/s$: there exists $C > 0$ such that
$$\frac{\ell'(s)}{s} \ge 1 - C\,|s|^\beta \quad \text{for all } s \ne 0. \tag{A3}$$
Under these assumptions, we prove the following result. It is of independent interest because it gives detailed information about the bouncing phase of the EoS.

Theorem 5 (quasi-static principle). Suppose we run GD on $f$ with step size $\eta > 0$, where $f(x, y) := \ell(xy)$ and $\ell$ satisfies Assumptions (A1), (A2), and (A3). Write $y_t^2 := (2+\delta_t)/\eta$, and suppose that at some iteration $t_0$, we have $|x_{t_0} y_{t_0}| \asymp \delta_{t_0}^{1/\beta}$ and $\delta_{t_0} \lesssim 1$. Then, for all $t \ge t_0$ with $\delta_t \gtrsim \eta^{\beta/(\beta-1)}$, we have $|x_t y_t| \asymp \delta_t^{1/\beta}$. All implied constants depend on $\ell$ but not on $\eta$.

In this section, we show that the GD iterates lie close to the quasi-static trajectory and give the full proof of Theorem 2. Recall from (13) that the quasi-static analysis predicts
$$\eta\,r_t\,y_t^2 \approx 2, \tag{18}$$
and that during the bouncing phase, this closely agrees with experimental observations (Figure 9). We consider the phase where $y_t^2$ has not yet crossed the threshold $2/\eta$, and we write $y_t^2 := (2+\delta_t)/\eta$, thinking of $\delta_t$ as small. Then (18) can be written as $(2+\delta_t)\,r_t \approx 2$. If $\ell'(s)/s = 1 - \Theta(|s|^\beta)$ near the origin, then $r_t \approx 1 - \Theta(\delta_t)$ implies that
$$|s_t|^\beta \asymp \delta_t. \tag{19}$$
Our goal is to rigorously establish (19).
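To see what (19) predicts in a concrete case, the sketch below solves the quasi-static relation $(2+\delta)\,r(s) = 2$ for $s > 0$ by bisection. It assumes the illustrative loss $\ell = \log\cosh$, which is not necessarily one used in the paper. For this loss, $r(s) = \tanh(s)/s = 1 - s^2/3 + O(s^4)$. Hence $\beta = 2$ and (19) predicts $s^2 \asymp \delta$. Keeping the $s^2/3$ term gives the sharper estimate $s^2 \approx 3\delta/(2+\delta)$. The function names are hypothetical.

```python
import math


def ratio(s):
    """Illustrative r(s) = loss'(s) / s for loss = log cosh, i.e. tanh(s) / s, with r(0) = 1."""
    return 1.0 if s == 0.0 else math.tanh(s) / s


def quasi_static_s(delta, hi=10.0, iters=200):
    """Solve (2 + delta) * r(s) = 2 for s > 0 by bisection (tanh(s)/s is decreasing on s > 0)."""
    target = 2.0 / (2.0 + delta)
    lo = 0.0
    for _ in range(iters):
        mid = 0.5 * (lo + hi)
        if ratio(mid) > target:
            lo = mid  # r(mid) is still above the target, so the root lies to the right
        else:
            hi = mid
    return 0.5 * (lo + hi)


for delta in (1e-2, 1e-3, 1e-4):
    s = quasi_static_s(delta)
    predicted = math.sqrt(3 * delta / (2 + delta))  # second-order estimate for beta = 2
    print(f"delta = {delta:.0e}: s = {s:.6f}, s / predicted = {s / predicted:.6f}")
```

Dividing $\delta$ by $4$ should roughly halve $s$, which is the $|s| \asymp \delta^{1/\beta}$ scaling of (19) with $\beta = 2$.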
However, we first make two observations. First, to establish Theorem 2, we only need an upper bound on $|s_t|$, which requires only Assumption (A2). Proving a lower bound on $|s_t|$ would require a corresponding lower bound on $\ell'(s)/s$. Second, even the relaxed form $|s_t|^\beta \lesssim \delta_t$ of (19) fails when $\delta_t$ is too small, because the error terms (the deviation of the dynamics from the quasi-static trajectory) begin to dominate. With this in mind, we instead prove $|s_t|^\beta \lesssim \delta_t + C'\eta^\gamma$. The added term $C'\eta^\gamma$ handles the error terms, and the exponent $\gamma > 0$ emerges from the proof.

Proposition 3 (quasi-static analysis; upper bound). For all $t$ such that $0 \le \delta_t \lesssim \min\{1, 1/(\beta-1)\}$ (with a sufficiently small implied constant),
$$|s_t|^\beta \le C\,\big(\delta_t + C'\,\eta^{\beta/(\beta-1)}\big),$$
where $C, C' > 0$ are constants that may depend on the problem parameters but not on $\eta$.

We first show that Theorem 2 now follows.

Proof of Theorem 2. As noted previously, the $\beta = +\infty$ case is handled by the arguments of the previous section, so we focus on $\beta < +\infty$. By Lemma 7, there are two possibilities. Either $y_t^2 \to 2/\eta$ and $|x_t| \to 0$, in which case we are done. Or there is an iteration $\bar t$ such that $y_{\bar t}^2 < 2/\eta \le y_{\bar t-1}^2$. In the latter case, $\delta_{\bar t-1} \ge 0$ and $\delta_{\bar t} < 0$, so Proposition 3 gives $|s_{\bar t}| \lesssim \eta^{1/(\beta-1)}$. The theorem now follows, either from Proposition 1 or from the refined Proposition 2.

We now prove Proposition 3. In the proof, we use asymptotic notation $O(\cdot)$, $\Omega(\cdot)$, etc. to hide constants that depend on $\ell$ (including $\beta$) but not on $\delta_t$ and $\eta$. The proof also involves choosing parameters $C, C' > 0$, and for clarity we keep the dependence on these parameters explicit.

Proof of Proposition 3. The proof goes by induction. Suppose that $|s_t|^\beta \le C\,(\delta_t + C'\eta^\gamma)$ and $\delta_t \ge 0$ at some iteration $t$. We prove that the same holds one iteration later. The constants $C, C' > 0$ and the exponent $\gamma > 0$ are chosen later in the proof.
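For the base case, the approximate conservation lemma (Lemma 6) gives $|s_t| \lesssim 1$, and $\delta_t \asymp \min\{1, 1/(\beta-1)\}$ at the beginning of the induction. So the bound holds initially if we choose $C$ sufficiently large. Throughout, we write $\hat\delta_t := \delta_t + C'\eta^\gamma$ as a convenient shorthand. The strategy is to prove the following two statements:
1. If $|s_t|^\beta = C_t\,\hat\delta_t$ for some $C_t > C/2$, then $|s_{t+1}|^\beta \le C_{t+1}\,\hat\delta_{t+1}$ for some $C_{t+1} \le C_t$.
2. If $|s_t|^\beta = C_t\,\hat\delta_t$ for some $C_t \le C/2$, then $|s_{t+1}|^\beta \le C\,\hat\delta_{t+1}$.

Proof of 1. The dynamics for $x$ give $|x_{t+1}| = |1 - \eta y_t^2 r_t|\,|x_t|$. By Assumption (A2+) and $|s_t| \lesssim 1$,
$$r_t \le 1 - \Omega(|s_t|^\beta) = 1 - \Omega(C\hat\delta_t), \qquad \eta y_t^2 r_t \le (2+\delta_t)\big(1 - \Omega(C\hat\delta_t)\big) = 2 - \Omega(C\hat\delta_t)$$
for large $C$. Also, $\ell''(0) = 1$, and an argument similar to the one in the proof of Theorem 1 yields the reverse inequality $\eta y_t^2 r_t \gtrsim 1$. We conclude that
$$|x_{t+1}| \le \big(1 - \Omega(C\hat\delta_t)\big)\,|x_t|, \qquad |s_{t+1}|^\beta \le \big(1 - \Omega(C\hat\delta_t)\big)\,|s_t|^\beta = C_t\,\big(1 - \Omega(C\hat\delta_t)\big)\,\hat\delta_t.$$
Since we need a bound in terms of $\hat\delta_{t+1}$, we use the dynamics of $y$:
$$y_{t+1} = (1 - \eta x_t^2 r_t)\,y_t \ge (1 - \eta x_t^2)\,y_t, \qquad y_{t+1}^2 \ge (1 - 2\eta x_t^2)\,y_t^2,$$
$$\delta_{t+1} = \eta y_{t+1}^2 - 2 \ge \delta_t - 2\eta^2 s_t^2 \ge \delta_t - 2\eta^2\,(C\hat\delta_t)^{2/\beta}. \tag{20}$$
Substituting this in,
$$|s_{t+1}|^\beta \le C_t\,\big(1 - \Omega(C\hat\delta_t)\big)\big(\hat\delta_{t+1} + 2\eta^2\,(C\hat\delta_t)^{2/\beta}\big) \le C_t\,\hat\delta_{t+1} - \Omega(C^2\,\hat\delta_t\,\hat\delta_{t+1}) + 2C\eta^2\,(C\hat\delta_t)^{2/\beta}. \tag{21}$$
Let us show that
$$\hat\delta_{t+1} \ge \frac34\,\hat\delta_t. \tag{22}$$
From (20), $\hat\delta_{t+1} \ge \hat\delta_t - 2\eta^2\,(C\hat\delta_t)^{2/\beta}$, so it suffices to prove that $\eta^2\,(C\hat\delta_t)^{2/\beta} \le \hat\delta_t/8$. If $\beta \le 2$, this is obvious by taking $\eta$ small. If $\beta > 2$, it is equivalent to $C^{2/\beta}\eta^2 \le \hat\delta_t^{1-2/\beta}/8$. For this it suffices to have $C^{2/\beta}\eta^2 \lesssim (C')^{1-2/\beta}\,\eta^{\gamma(1-2/\beta)}$. This is achieved by taking $C'$ large relative to $C$ and by taking $\gamma \le 2/(1 - 2/\beta)$. Our eventual choice $\gamma = \beta/(\beta-1)$ satisfies this constraint.

Returning to (21), and in light of (22), it remains to show that $C^2\,\hat\delta_t^2 \gtrsim C^{1+2/\beta}\eta^2\,\hat\delta_t^{2/\beta}$. Rearranging, it suffices to have $\hat\delta_t^{2-2/\beta} \gtrsim C^{2/\beta-1}\eta^2$, or equivalently $\hat\delta_t^{1-1/\beta} \gtrsim C^{1/\beta-1/2}\,\eta$. Since $\hat\delta_t \ge C'\eta^\gamma$ by definition, choosing $C'$ large, it suffices to have $\gamma \le 1/(1 - 1/\beta) = \beta/(\beta-1)$. This leads to our choice of $\gamma$.

Proof of 2.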
Using the simple bound $\eta y_t^2 r_t \le 2 + \delta_t$, we have
$$|s_{t+1}| \le (1 + \delta_t)\,|s_t|, \qquad |s_{t+1}|^\beta \le \exp(\beta\delta_t)\,|s_t|^\beta = C_t\exp(\beta\delta_t)\,\hat\delta_t \le \frac43\,C_t\exp(\beta\delta_t)\,\hat\delta_{t+1},$$
where we used (22). If $\exp(\beta\delta_t) \le 4/3$, which holds if $\delta_t \lesssim 1/\beta$, then $C_t \le C/2$ gives $|s_{t+1}|^\beta \le C\,\hat\delta_{t+1}$, as desired.

Following the same proof outline with the inequalities reversed, we can also prove a corresponding lower bound on $|s_t|^\beta$, as long as $\delta_t \gtrsim \eta^{\beta/(\beta-1)}$. Although this is not needed for Theorem 2, it is of independent interest. Together with Proposition 3, it shows that the GD iterates do in fact track the quasi-static trajectory.

Proposition 4 (quasi-static analysis; lower bound). Suppose additionally that (A3) holds and that $\beta < +\infty$. Also suppose that at some iteration $t_0$, we have $\delta_{t_0} \lesssim 1$ and that
$$|s_t| \ge c\,\delta_t^{1/\beta} \tag{23}$$
holds at iteration $t = t_0$. Here $c$ is a sufficiently small constant, depending on the problem parameters but not on $\eta$. Then (23) also holds for all iterations $t \ge t_0$ such that $\delta_t \gtrsim \eta^{\beta/(\beta-1)}$.

Proof. The proof mirrors that of Proposition 3. Let $\delta_t \gtrsim \eta^{\beta/(\beta-1)}$ with a sufficiently large implied constant. We prove the following two statements:
1. If $|s_t| = c_t\,\delta_t^{1/\beta}$ for some $c_t < 2c$, then $|s_{t+1}| \ge c_{t+1}\,\delta_{t+1}^{1/\beta}$ for some $c_{t+1} \ge c_t$.
2. If $|s_t| = c_t\,\delta_t^{1/\beta}$ for some $c_t \ge 2c$, then $|s_{t+1}| \ge c\,\delta_{t+1}^{1/\beta}$.

Throughout the proof, Proposition 3 also gives $|s_t| \lesssim \delta_t^{1/\beta}$.

Proof of 1. The dynamics for $x$ give $|x_{t+1}| = |1 - \eta y_t^2 r_t|\,|x_t|$. By Assumption (A3),
$$r_t \ge 1 - O(|s_t|^\beta) \ge 1 - O(c^\beta\,\delta_t).$$
If $c$ is sufficiently small, then
$$\eta y_t^2 r_t \ge (2+\delta_t)\big(1 - O(c^\beta\,\delta_t)\big) \ge 2 + \Omega(\delta_t).$$
Therefore, $|x_{t+1}| \ge (1 + \Omega(\delta_t))\,|x_t|$. On the other hand,
$$y_{t+1} \ge (1 - \eta x_t^2)\,y_t \ge \big(1 - O(\eta^2 s_t^2)\big)\,y_t \ge \big(1 - O(\eta^2\delta_t^{2/\beta})\big)\,y_t, \tag{24}$$
$$|s_{t+1}| \ge \big(1 + \Omega(\delta_t)\big)\big(1 - O(\eta^2\delta_t^{2/\beta})\big)\,|s_t| \ge c_t\,\big(1 + \Omega(\delta_t) - O(\eta^2\delta_t^{2/\beta})\big)\,\delta_t^{1/\beta} \ge c_t\,\big(1 + \Omega(\delta_t) - O(\eta^2\delta_t^{2/\beta})\big)\,\delta_{t+1}^{1/\beta}.$$
To conclude, we must show that $\eta^2\delta_t^{2/\beta} \ll \delta_t$. This holds because $\delta_t \gtrsim \eta^{\beta/(\beta-1)}$ with a sufficiently large implied constant, as checked in the proof of Proposition 3.

Proof of 2. Using Assumption (A3),
$$1 - O(\delta_t) \le 1 - O(|s_t|^\beta) \le r_t \le 1.$$
Hence
$$2 - O(\delta_t) \le (2+\delta_t)\big(1 - O(\delta_t)\big) \le \eta y_t^2 r_t \le 2 + \delta_t, \qquad -1 - \delta_t \le 1 - \eta y_t^2 r_t \le -1 + O(\delta_t).$$
Together with the dynamics for $x$ and (24),
$$|s_{t+1}| \ge \big(1 - O(\delta_t)\big)\big(1 - O(\eta^2\delta_t^{2/\beta})\big)\,|s_t| \ge c_t\,\big(1 - O(\delta_t)\big)\big(1 - O(\eta^2\delta_t^{2/\beta})\big)\,\delta_{t+1}^{1/\beta}.$$
Since $c_t \ge 2c$, if $\delta_t$ and $\eta$ are sufficiently small, this implies $|s_{t+1}| \ge c\,\delta_{t+1}^{1/\beta}$.

Convergence rate estimates. Our analysis also gives estimates of the convergence rate of GD in both regimes. In the gradient flow regime, GD converges in $O(1/\eta)$ iterations. In the EoS regime, GD typically spends $\Omega(1/\eta^{\max\{\beta/(\beta-1),\,2\}})$ iterations in the bouncing phase, or $\Omega(\log(1/\eta)/\eta^2)$ iterations when $\beta = 2$ (Figure 11). Hence, the bouncing phase dramatically slows down the convergence of GD.

Remark 2. Suppose that at iteration $t_0$, we have $\delta_{t_0} \asymp 1$. Then the assumption of Proposition 4 is that $|s_{t_0}| \gtrsim 1$. If this does not hold, i.e., $|s_{t_0}| \ll 1$, then the first claim in the proof of Proposition 4 shows that $|s_{t_0+1}| \ge (1 + \Omega(\delta_{t_0}))\,|s_{t_0}| = (1 + \Omega(1))\,|s_{t_0}|$. Therefore, after $t = O(\log(1/|s_{t_0}|))$ iterations, we obtain $|s_{t_0+t}| \gtrsim 1$, and Proposition 4 applies thereafter.

Remark 3. The quasi-static analysis also gives bounds on the length of the bouncing phase. Suppose that $t_0$ is such that $\delta_{t_0} \lesssim 1$ and $|s_t| \asymp \delta_t^{1/\beta}$ for all $t \ge t_0$. If $\delta_{t_0}$ is small enough that $r_t \asymp 1$ for all $t \ge t_0$, then the equation for $y$ yields
$$\delta_{t+1} = \delta_t - \Theta(\eta^2 s_t^2) = \delta_t - \Theta(\eta^2\delta_t^{2/\beta}).$$
We define the $k$-th phase as the set of iterations $t$ with $2^{-k} \le \delta_t \le 2^{-(k-1)}$. During this phase, $\delta_{t+1} = \delta_t - \Theta(\eta^2\,2^{-2k/\beta})$, so phase $k$ lasts $\asymp 2^{k(2/\beta - 1)}/\eta^2$ iterations. We sum over the phases until $\delta_t \asymp \eta^{\beta/(\beta-1)}$. Beyond this point the quasi-static analysis fails, and $y_t^2$ crosses $2/\eta$ shortly afterwards. This yields
$$\sum_{k \in \mathbb{Z}:\ \eta^{\beta/(\beta-1)} \lesssim 2^{-k} \lesssim 1} \frac{2^{k(2/\beta - 1)}}{\eta^2} \asymp \begin{cases} 1/\eta^2, & \beta > 2, \\ \log(1/\eta)/\eta^2, & \beta = 2, \\ 1/\eta^{\beta/(\beta-1)}, & \beta < 2. \end{cases}$$
The time spent in the bouncing phase increases dramatically as $\beta \searrow 1$.

D Deferred derivations of mean model

In this section, we provide the details of the derivation of the mean model in Subsection 3.1.
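Recall
$$f(x; a^-, a^+, b) = a^-\sum_{i=1}^d \mathrm{ReLU}(-x[i] + b) + a^+\sum_{i=1}^d \mathrm{ReLU}(+x[i] + b),$$
where $x = \lambda y e_j + \xi$. We first approximate
$$\sum_{i=1}^d \mathrm{ReLU}(-x[i] + b) \approx \sum_{i=1}^d \mathrm{ReLU}(-\xi[i] + b).$$
In other words, we ignore the contribution of the signal $\lambda y e_j$. This approximation holds for two reasons. (i) Initially, the bias $b$ is not yet negative enough to threshold out the noise, so the summation $\sum_{i=1}^d \mathrm{ReLU}(-\xi[i] + b)$ is of order $d$. (ii) The difference between the left- and right-hand sides above is simply $\mathrm{ReLU}(-\lambda y - \xi[j] + b) - \mathrm{ReLU}(-\xi[j] + b)$. This difference is of size $O(1)$ and hence negligible compared with the full summation.

Next, let $g(b) := \mathbb{E}_{z \sim N(0,1)}\,\mathrm{ReLU}(z + b)$ be the smoothed ReLU (see Figure 7). Concentration of measure implies the following two facts:
$$\sum_{i=1}^d \mathrm{ReLU}(\pm\xi[i] + b) \approx d\,\mathbb{E}_{\xi \sim N(0,1)}\,\mathrm{ReLU}(\xi + b) =: d\,g(b), \qquad \sum_{i=1}^d \mathbb{1}\{\pm x[i] + b \ge 0\} \approx d\,\mathbb{E}_{\xi \sim N(0,1)}\,\mathbb{1}\{\xi + b \ge 0\} = d\,g'(b).$$
Indeed, the summations above are sums of $d$ i.i.d. non-negative random variables. Their mean is therefore $\Omega(d)$ (as long as $b \ge -O(1)$), and their standard deviation is $O(\sqrt d)$.

Using these approximations, the output of the ReLU network (2) can be written as
$$f(x; a^-, a^+, b) \approx d\,(a^- + a^+)\,g(b).$$
This leads to an approximation of the GD dynamics on the population loss $(a^-, a^+, b) \mapsto \mathbb{E}[\ell_{\mathrm{logi}}(y f(x; a^-, a^+, b))]$. Using $f(x; a^-_t, a^+_t, b_t) \approx d\,(a^-_t + a^+_t)\,g(b_t)$, $\sum_i \mathrm{ReLU}(-x[i] + b_t) \approx d\,g(b_t)$, and $\sum_i \mathbb{1}\{\pm x[i] + b_t \ge 0\} \approx d\,g'(b_t)$, we obtain
$$\begin{aligned} a^-_{t+1} &= a^-_t - \eta\,\mathbb{E}\Big[\ell_{\mathrm{logi}}'\big(y f(x; a^-_t, a^+_t, b_t)\big)\,y\sum_{i=1}^d \mathrm{ReLU}(-x[i] + b_t)\Big] \approx a^-_t - \eta\,\ell_{\mathrm{sym}}'\big(d\,(a^-_t + a^+_t)\,g(b_t)\big)\,d\,g(b_t), \\ b_{t+1} &= b_t - \eta\,\mathbb{E}\Big[\ell_{\mathrm{logi}}'\big(y f(x; a^-_t, a^+_t, b_t)\big)\,y\Big(a^-_t\sum_{i=1}^d \mathbb{1}\{-x[i] + b_t \ge 0\} + a^+_t\sum_{i=1}^d \mathbb{1}\{+x[i] + b_t \ge 0\}\Big)\Big] \\ &\approx b_t - \eta\,\ell_{\mathrm{sym}}'\big(d\,(a^-_t + a^+_t)\,g(b_t)\big)\,d\,(a^-_t + a^+_t)\,g'(b_t). \end{aligned}$$
Here $\ell_{\mathrm{sym}}(s) := \frac12\big(\log(1 + \exp(-s)) + \log(1 + \exp(+s))\big)$ is the symmetrized logistic loss. Averaging over a uniform label $y \in \{\pm1\}$ turns $\ell_{\mathrm{logi}}'$ into $\ell_{\mathrm{sym}}'(s) = \frac12\big(\ell_{\mathrm{logi}}'(s) - \ell_{\mathrm{logi}}'(-s)\big)$.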
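The two concentration facts can be checked by simulation. The sketch below assumes i.i.d. standard Gaussian noise coordinates and draws only the noise $\xi$, i.e., it checks the signal-free side of the approximation. It uses the closed form $g(b) = \varphi(b) + b\,\Phi(b)$ from Lemma 2. It also compares a central finite difference of $g$ against $\Phi(b)$, the derivative claimed there. The values $d = 20{,}000$ and $b = -0.3$ and the seed are hypothetical. A large $d$ is chosen so that the $O(\sqrt d)$ fluctuations are small relative to the $\Omega(d)$ mean.

```python
import math
import random


def std_normal_pdf(z):
    return math.exp(-0.5 * z * z) / math.sqrt(2 * math.pi)


def std_normal_cdf(z):
    return 0.5 * (1 + math.erf(z / math.sqrt(2)))


def smoothed_relu(b):
    """g(b) = E[ReLU(z + b)], z ~ N(0, 1), via the closed form phi(b) + b * Phi(b)."""
    return std_normal_pdf(b) + b * std_normal_cdf(b)


rng = random.Random(0)
d, b = 20_000, -0.3
xi = [rng.gauss(0.0, 1.0) for _ in range(d)]

relu_sum = sum(max(-v + b, 0.0) for v in xi)   # sum_i ReLU(-xi[i] + b)
active = sum(1 for v in xi if -v + b >= 0)     # sum_i 1{-xi[i] + b >= 0}

relu_ratio = relu_sum / (d * smoothed_relu(b))    # concentration: should be close to 1
active_ratio = active / (d * std_normal_cdf(b))   # uses g'(b) = Phi(b); close to 1

h = 1e-5
fd_slope = (smoothed_relu(b + h) - smoothed_relu(b - h)) / (2 * h)  # numerical g'(b)
print(f"{relu_ratio:.4f} {active_ratio:.4f} {fd_slope:.8f} {std_normal_cdf(b):.8f}")
```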
Hence we arrive at the following dynamics on $a^-$ and $b$, which we call the mean model (the update for $a^+$ is identical to that for $a^-$):
$$a^-_{t+1} = a^-_t - \eta\,\ell_{\mathrm{sym}}'\big(d\,(a^-_t + a^+_t)\,g(b_t)\big)\,d\,g(b_t), \qquad b_{t+1} = b_t - \eta\,\ell_{\mathrm{sym}}'\big(d\,(a^-_t + a^+_t)\,g(b_t)\big)\,d\,(a^+_t + a^-_t)\,g'(b_t).$$
We can write these dynamics more compactly in terms of the parameter $A_t := d\,(a^-_t + a^+_t)$:
$$A_{t+1} = A_t - 2d^2\eta\,\ell_{\mathrm{sym}}'(A_t g(b_t))\,g(b_t), \qquad b_{t+1} = b_t - \eta\,\ell_{\mathrm{sym}}'(A_t g(b_t))\,A_t\,g'(b_t).$$

E Proofs for the mean model

In this section, we prove the main theorems for the mean model. For the reader's convenience, we first recall the mean model, where here and below $\ell := \ell_{\mathrm{sym}}$:
$$A_{t+1} = A_t - 2d^2\eta\,\ell'(A_t g(b_t))\,g(b_t), \qquad b_{t+1} = b_t - \eta\,\ell'(A_t g(b_t))\,A_t\,g'(b_t).$$

E.1 Deferred proofs

In this section, we collect the deferred proofs from Subsection 3.2.

Proof of Lemma 2. By definition,
$$g(b) = \int_{-b}^{\infty} (\xi + b)\,\varphi(\xi)\,d\xi = \int_{-b}^{\infty} \xi\,\varphi(\xi)\,d\xi + b\,\Phi(b).$$
Recalling that $\varphi'(\xi) = -\xi\,\varphi(\xi)$, the first term equals $\varphi(b)$. Moreover, $g'(b) = -b\,\varphi(b) + \Phi(b) + b\,\varphi(b) = \Phi(b)$.

Proof of Lemma 3. Note that
$$\partial_t\Big(\frac12 A^2\Big) = A\dot A = -2d^2\,\ell'(Ag(b))\,Ag(b),$$
and that
$$\partial_t\,\kappa(b) = -\ell'(Ag(b))\,\kappa'(b)\,Ag'(b) = -\ell'(Ag(b))\,Ag(b),$$
since $\kappa' = g/g'$. Hence $\partial_t\big(\frac12 A^2 - 2d^2\kappa(b)\big) = 0$, which completes the proof.

E.2 Gradient flow regime

Proof of Theorem 3. The proof is analogous to the proof of Theorem 1. We first list several facts used in the proof:
(i) $|g'(b)| = |\Phi(b)| \le 1$ for all $b \in \mathbb{R}$.
(ii) $\ell'(s) = \frac12\,\frac{\exp(s) - 1}{\exp(s) + 1}$. Hence $|\ell'(s)| \le \frac12$ for all $s \in \mathbb{R}$, and
$$\frac{\ell'(s)}{s} \ge \frac18 \times \begin{cases} 1, & |s| \le 2, \\ 2/|s|, & |s| > 2. \end{cases}$$
(iii) $\ell''(0) = 1/4$.
(iv) $\ell'''(s) = -\frac{\exp(s)\,(\exp(s) - 1)}{(\exp(s) + 1)^3}$. Hence $\ell'''(s) < 0$ for $s > 0$ and $\ell'''(s) > 0$ for $s < 0$. In particular, $|\ell'(s)| \le \frac14\,|s|$ for all $s \in \mathbb{R}$.

Throughout the proof, we assume without loss of generality that $A_0 > 0$. Let
$$\gamma := \frac{1}{200}\min\Big\{\delta,\; 8 - \delta,\; \frac{8 - \delta}{A_0}\Big\}.$$
We prove by induction that $|A_t| \le A_0\exp(-\gamma t)$ for all $t \ge 0$. This clearly holds at initialization. Suppose that the claim holds up to iteration $t$. Using the bounds on $|g'|$ and $|\ell'|$,
$$b_{t+1} \ge b_t - \eta\,|\ell'(A_t g(b_t))|\,|A_t|\,g'(b_t) \ge b_t - \frac12\,\eta A_0\exp(-\gamma t) \ge b_0 - \frac12\,\eta A_0\sum_{s=0}^{\infty}\exp(-\gamma s) \ge -\frac{\eta A_0}{\gamma}.$$
In particular, $b_t \ge -1$ and $g(b_t) > 0.08$, since $\eta \le \gamma/A_0$.
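Also, this bound shows that if the claim holds for all $t$, then we obtain the desired conclusion. It remains to establish the inductive claim, so assume that it holds up to iteration $t$. For the dynamics of $A$, by symmetry we may suppose that $A_t > 0$. We have $\ell'(A_t g(b_t)) \le A_t g(b_t)/4$ and $g(b_t) \le g(0) = 1/\sqrt{2\pi}$. Using also $\eta d^2 = (8-\delta)\pi$, it follows that
$$A_{t+1} = A_t - 2\eta d^2\,\ell'(A_t g(b_t))\,g(b_t) \ge \Big(1 - \frac{\eta d^2}{2}\,g(b_t)^2\Big)\,A_t \ge \Big(1 - \frac{\eta d^2}{2}\,g(0)^2\Big)\,A_t = -\Big(1 - \frac{\delta}{4}\Big)\,A_t.$$
This shows that $A_{t+1} \ge -(1-\gamma)\,A_t$. Next, we show that $A_{t+1} \le (1-\gamma)\,A_t$. First, if $A_t g(b_t) \le 2$, then
$$A_{t+1} = A_t - 2\eta d^2\,\ell'(A_t g(b_t))\,g(b_t) \le \Big(1 - \frac14\,\eta d^2\,g(b_t)^2\Big)\,A_t = \Big(1 - \frac{(8-\delta)\pi}{4}\,g(b_t)^2\Big)\,A_t \le \Big(1 - \frac{(8-\delta)\pi}{4}\cdot 0.08^2\Big)\,A_t \le (1-\gamma)\,A_t,$$
since $g(b_t) > 0.08$. Next, if $A_t g(b_t) > 2$, then
$$A_{t+1} = A_t - 2\eta d^2\,\ell'(A_t g(b_t))\,g(b_t) \le A_t - \frac12\,\eta d^2\,g(b_t) = \Big(1 - \frac{(8-\delta)\pi\,g(b_t)}{2A_t}\Big)\,A_t \le (1-\gamma)\,A_t.$$
This shows that $|A_{t+1}| \le (1-\gamma)\,|A_t|$ when $A_t > 0$. The case $A_t < 0$ is similar, which completes the induction.

E.3 EoS regime

Proof of Theorem 4. The proof is analogous to the proof of Lemma 7. Assume throughout that $A_t \ne 0$ for all $t$. Recall the dynamics for $b$:
$$b_{t+1} = b_t - \eta\,\ell'(A_t g(b_t))\,A_t\,g'(b_t).$$
Since $\ell'(s)/s \to 1/4$ as $s \to 0$ and $\ell'$ is increasing, this equation implies that if $\liminf_{t\to\infty} |A_t| > 0$, then $b_t$ must keep decreasing until $\frac12 d^2 g(b_t)^2 < 2/\eta$. Suppose, for the sake of contradiction, that there exists $\varepsilon > 0$ with $\frac12 d^2 g(b_t)^2 > (2+\varepsilon)/\eta$ for all $t$. Let $\varepsilon' > 0$ be such that $1 - (2+\varepsilon)(1-\varepsilon') < -1$, i.e., $\varepsilon' < \varepsilon/(2+\varepsilon)$. Then there exists $\delta' > 0$ such that $|A_t| \le \delta'$ implies $\ell'(A_t g(b_t))/(A_t g(b_t)) > \frac14\,(1-\varepsilon')$, and hence
$$\frac{|A_{t+1}|}{|A_t|} = \Big|1 - 4\,\frac{\ell'(A_t g(b_t))}{A_t g(b_t)}\cdot\frac12\,\eta d^2\,g(b_t)^2\Big| > (2+\varepsilon)(1-\varepsilon') - 1 > 1.$$
This means that $|A_t|$ increases until it exceeds $\delta'$, i.e., $\liminf_{t\to\infty} |A_t| \ge \delta'$. This is the desired contradiction, and it implies that $\lim_{t\to\infty} \frac12 d^2 g(b_t)^2 \le 2/\eta$.

Remark 4. A straightforward calculation shows that when $(a^-_\star, a^+_\star, b_\star)$ is a global minimizer (i.e., $a^-_\star + a^+_\star = 0$), then $\lambda_{\max}(\nabla^2 f(a^-_\star, a^+_\star, b_\star)) = \frac12\,d^2 g(b_\star)^2$. The mean model initialized at $(A_0, 0)$ approximately reaches $(0, 0)$, whose sharpness is $d^2 g(0)^2/2 = d^2/(4\pi)$.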
Hence, the bias learning regime $2/\eta < d^2/(4\pi)$, equivalently $\eta > 8\pi/d^2$, corresponds precisely to the EoS regime $2/\eta < \lambda_{\max}(\nabla^2 f(a^-_\star, a^+_\star, b_\star))$.
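The threshold $\eta = 8\pi/d^2$ can be illustrated with a minimal simulation of the mean model. The sketch uses $\ell = \ell_{\mathrm{sym}}$, whose derivative is $\frac12\tanh(s/2)$, together with $g(b) = \varphi(b) + b\,\Phi(b)$ and $g'(b) = \Phi(b)$ from Lemma 2. We take $d = 200$ as in Figure 1 and $b_0 = 0$. The initial value $A_0 = 14$ is a hypothetical choice, roughly the typical size $\sqrt d$ under the initialization of Figure 1. The function names are ours.

Below the threshold, the bias should barely move. Above it, $b_t$ should keep decreasing until $\frac12 d^2 g(b_t)^2$ falls below $2/\eta$, as in the proof of Theorem 4.

```python
import math


def std_normal_pdf(z):
    return math.exp(-0.5 * z * z) / math.sqrt(2 * math.pi)


def std_normal_cdf(z):
    return 0.5 * (1 + math.erf(z / math.sqrt(2)))


def smoothed_relu(b):
    return std_normal_pdf(b) + b * std_normal_cdf(b)  # g(b), Lemma 2


def dloss_sym(s):
    return 0.5 * math.tanh(0.5 * s)  # derivative of the symmetrized logistic loss


def mean_model(A, b, eta, d, steps):
    """Iterate A <- A - 2 d^2 eta l'(A g(b)) g(b) and b <- b - eta l'(A g(b)) A g'(b)."""
    for _ in range(steps):
        g_b, dg_b = smoothed_relu(b), std_normal_cdf(b)  # g'(b) = Phi(b)
        grad = dloss_sym(A * g_b)
        A, b = A - 2 * d * d * eta * grad * g_b, b - eta * grad * A * dg_b
    return A, b


d, A0 = 200, 14.0
eta_star = 8 * math.pi / d**2  # predicted phase-transition step size (Remark 4)
A_gf, b_gf = mean_model(A0, 0.0, 0.5 * eta_star, d, steps=50_000)
eta_eos = 1.5 * eta_star
A_eos, b_eos = mean_model(A0, 0.0, eta_eos, d, steps=50_000)

print(f"eta = 0.5 eta*: b = {b_gf:.4f};  eta = 1.5 eta*: b = {b_eos:.4f}")
print("eta * sharpness at the end of the EoS run:",
      eta_eos * 0.5 * d * d * smoothed_relu(b_eos) ** 2)
```

The printed $\eta$-scaled sharpness of the EoS run should end up at or just below $2$.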