# depth_separation_with_multilayer_meanfield_networks__88cff3f8.pdf

Published as a conference paper at ICLR 2023

DEPTH SEPARATION WITH MULTILAYER MEAN-FIELD NETWORKS

Yunwei Ren (Carnegie Mellon University, yunweir@andrew.cmu.edu), Mo Zhou (Duke University, mozhou@cs.duke.edu), Rong Ge (Duke University, rongge@cs.duke.edu)

Depth separation, that is, explaining why a deeper network is more powerful than a shallower one, has been a major problem in deep learning theory. Previous results often focus on representation power. For example, Safran et al. (2019) constructed a function that is easy to approximate using a 3-layer network but not approximable by any 2-layer network. In this paper, we show that this separation is in fact algorithmic: one can efficiently learn the function constructed by Safran et al. (2019) using an overparameterized network with polynomially many neurons. Our result relies on a new way of extending the mean-field limit to multilayer networks, and on a decomposition of the loss that factors out the error introduced by the discretization of infinite-width mean-field networks.

## 1 INTRODUCTION

One of the mysteries in deep learning theory is why we need deeper networks. In early attempts, researchers showed that deeper networks can represent functions that are hard for shallow networks to approximate (Eldan & Shamir, 2016; Telgarsky, 2016; Poole et al., 2016; Daniely, 2017; Yarotsky, 2017; Liang & Srikant, 2017; Safran & Shamir, 2017; Poggio et al., 2017; Safran et al., 2019; Malach & Shalev-Shwartz, 2019; Vardi & Shamir, 2020; Venturi et al., 2022; Malach et al., 2021). In particular, the seminal works of Eldan & Shamir (2016) and Safran et al. (2019) constructed a simple function ($f^*(x) = \mathrm{ReLU}(1-\|x\|)$) that can be computed by a 3-layer neural network but cannot be approximated by a 2-layer network.
However, these results concern only the representation power of neural networks. They do not guarantee that training a deep neural network from a reasonable initialization can actually learn such functions. In this paper, we prove that one can train a neural network that approximates $f^*(x) = \mathrm{ReLU}(1-\|x\|)$ to any desired accuracy; this gives an algorithmic separation between the power of 2-layer and 3-layer networks. To analyze the training dynamics, we develop a new framework that generalizes the mean-field analysis of neural networks (Chizat & Bach, 2018; Mei et al., 2018) to multiple layers. As a result, all layer weights can change significantly during training (unlike many previous works on the neural tangent kernel or on fixed lower-layer representations). Our analysis also gives a decomposition of the loss that allows us to decouple the training of multiple layers.

In the remainder of the paper, we first introduce our new framework for multilayer mean-field analysis, then give our main result and techniques. We discuss related works on the algorithmic aspect of depth separation in Section 1.3. As in standard mean-field analysis, we first consider the infinite-width dynamics in Section 3. We then discuss our new ideas for discretizing the result to a polynomial-size network in Section 4.

### 1.1 MULTI-LAYER MEAN-FIELD FRAMEWORK

We propose a new way to extend the mean-field analysis to multiple layers. For simplicity, we state it for 3-layer networks here; see Appendix A for the general framework. In short, we break the middle layer into two linear layers and restrict the size of the layer in between. More precisely, we define

$$f(x) = \frac{1}{m_2}a_2^\top\sigma(W_2F(x)), \qquad F(x) = \frac{1}{m_1}A_1\sigma(W_1x),$$

where $W_1\in\mathbb{R}^{m_1\times d}$, $A_1\in\mathbb{R}^{D\times m_1}$, $W_2\in\mathbb{R}^{m_2\times D}$, and $a_2\in\mathbb{R}^{m_2}$ are the parameters, and $F(x)\in\mathbb{R}^D$ represents the hidden feature. See Figure 1 for an illustration.

Figure 1: Difference between previous Nguyen & Pham (2020) (Left) and our framework (Right).
Later we will refer to the step $x\mapsto F(x)$ as the first layer and $F(x)\mapsto f(x)$ as the second layer, even though both of them are actually two-layer networks. In the infinite-width limit, we fix the hidden feature dimension $D$ and let the numbers of neurons $m_1, m_2$ go to infinity. This gives the infinite-width network

$$f(x) = \mathbb{E}_{(a_2,w_2)\sim\mu_2}\,a_2\sigma(w_2^\top F(x)), \qquad F_i(x) = \mathbb{E}_{(a_1,w_1)\sim\mu_{1,i}}\,a_1\sigma(w_1^\top x), \quad i\in[D],$$

where $(\mu_{1,i})_{i\in[D]}$ are distributions over $\mathbb{R}^{1+d}$ with a shared marginal distribution over $w_1$, and $\mu_2$ is a distribution over $\mathbb{R}^{1+D}$. Unlike the formulation in Nguyen & Pham (2020), here the hidden layers are described by distributions of neurons. They are therefore automatically invariant under permutations of neurons, which is one of the most important properties of mean-field networks. One can choose $\mu_1,\mu_2$ to be empirical distributions over finitely many neurons to recover a finite-width network. In fact, we do so in most parts of the paper, so that our results apply to finite-width networks with polynomially many neurons. The network can be viewed as a 3-layer network with intermediate layer $W_2A_1$, which has rank at most $D$. This is reminiscent of the bottleneck structure used in ResNet (He et al., 2016). It has also been used for other purposes in previous theoretical analyses such as Allen-Zhu & Li (2020).

Learner network. We now introduce the specific network that we use to learn the target function. We set $D = 1$ and couple $a_1$ with $w_1$:

$$F(x) = F(x;\mu_1) := \mathbb{E}_{w\sim\mu_1}\{\|w\|\sigma(w^\top x)\}, \qquad f(x) = f(x;\mu_2,\mu_1) := \mathbb{E}_{(w_2,b_2)\sim\mu_2}\,\sigma(w_2F(x;\mu_1) + b_2). \tag{1}$$

Here, $\sigma$ is the ReLU activation, and $\mu_1\in\mathcal{P}(\mathbb{R}^d)$ and $\mu_2\in\mathcal{P}(\mathbb{R}^2)$ are distributions encoding the weights of the first and second hidden layers, respectively. We multiply each first-layer neuron by $\|w\|$ to make $F$ more regular. This 2-homogeneous parameterization is also used in Li et al. (2020) and Wang et al. (2020). In most parts of the paper, $\mu_1$ and $\mu_2$ are empirical distributions over polynomially many neurons.
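The following minimal NumPy sketch makes (1) concrete for the case where $\mu_1$ and $\mu_2$ are empirical distributions, so that each expectation becomes an average over neurons. The array layout (`W1` holds one first-layer neuron per row, while `w2` and `b2` hold the second-layer pairs), the function names, and the sizes are illustrative assumptions, not the authors' code.

```python
import numpy as np

def relu(z):
    return np.maximum(z, 0.0)

def first_layer(X, W1):
    """F(x; mu1) = E_{w~mu1} ||w|| sigma(w^T x), with mu1 the empirical distribution of the rows of W1."""
    norms = np.linalg.norm(W1, axis=1)                # ||w|| for each first-layer neuron, shape (m1,)
    return np.mean(norms * relu(X @ W1.T), axis=1)    # average over neurons, shape (n,)

def learner(X, W1, w2, b2):
    """f(x; mu2, mu1) = E_{(w2,b2)~mu2} sigma(w2 F(x) + b2); D = 1, so F(x) is a scalar per input."""
    F = first_layer(X, W1)
    return np.mean(relu(F[:, None] * w2[None, :] + b2[None, :]), axis=1)

rng = np.random.default_rng(0)
d, m1, m2, n = 8, 256, 16, 5       # hypothetical sizes
W1 = rng.standard_normal((m1, d))
w2 = rng.standard_normal(m2)
b2 = np.full(m2, 0.5)
X = rng.standard_normal((n, d))
out = learner(X, W1, w2, b2)       # one output per input row
```

Since $\|w\|\sigma(w^\top x)$ is positively homogeneous of degree one in $x$, this first layer satisfies $F(cx) = cF(x)$ for $c>0$. This is why no bias is needed when the intended first-layer output is a multiple of $\|x\|$.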
We use $\mu_1,\mu_2$ to unify the notation in discussions of infinite- and finite-width networks. Restricting the intermediate layer to a single dimension ($D = 1$) is sufficient. The first layer $F(x)$ can learn $x\mapsto\alpha\|x\|$ for some $\alpha\in\mathbb{R}$, and the second layer can learn $\alpha\|x\|\mapsto\sigma(1-\|x\|)$. The network that computes $F(x)$ needs no bias term, because the intended function is homogeneous in $x$. The first layer only outputs nonnegative values, but this does not restrict the representation power of the network, since the second-layer weights can be either positive or negative. For the second layer, even though a single neuron is sufficient, we follow the framework and overparameterize the network.

### 1.2 MAIN RESULT AND OUR TECHNIQUES

Our main result applies the framework of the previous section to the function constructed in Safran et al. (2019) (see details in Section 2). Informally, we prove the following.¹

Theorem 1.1 (Main result, Informal). Consider the learner network defined in (1) with input dimension $d$. For any $\varepsilon>0$, we can choose layer widths $m_1 = \mathrm{poly}(d,1/\varepsilon)$ and $m_2 = \Theta(1)$ such that the following holds with probability at least $1-1/\mathrm{poly}(d,1/\varepsilon)$ over the random initialization: running a simple variant of gradient flow² reduces the loss $L := \mathbb{E}_x(f^*(x)-f(x))^2/2$ to $\varepsilon$ within time $T = \mathrm{poly}(d,1/\varepsilon)$.

This result shows that one can train a multilayer neural network to learn the function $\mathrm{ReLU}(1-\|x\|)$, which cannot be approximated by any 2-layer network. The choice of a heavy-tailed input distribution in Safran et al. (2019) causes some technical details, which we discuss in Section 2.

To prove such a result, we first characterize the infinite-width dynamics (see Section 3). In particular, we show that in the infinite-width dynamics, the first layer always computes a multiple of $\|x\|$, while the second layer behaves like a single neuron. However, it is often difficult to discretize such an infinite-width analysis to a polynomial-width network.
The main difficulty is the potential amplification of error in the network. Suppose that at the beginning the first layer is $\delta$-close to computing a multiple of $\|x\|$. This $\delta$ can then increase exponentially during training (Mei et al., 2018). Given the large polynomial training time of our dynamics, such an exponential increase would not be acceptable. To fix this issue, we partition the analysis into two phases. For the time-consuming second phase, we rely on a decomposition of the loss function:

$$\frac{1}{2}\mathbb{E}_{x\sim\mathcal{D}}(f^*(x)-f(x))^2\approx\frac{1}{2}\mathbb{E}_x\left\{(f^*(x)-\bar f(x))^2\right\} + \frac{\bar w_2^2}{2}\mathbb{E}_x\left\{(\bar F(x)-F(x))^2\right\}. \tag{2}$$

Here $\bar F(x)$ is a multiple of $\|x\|$ that is close to the actual first-layer output $F(x)$, and $\bar w_2$ is the average second-layer weight. The function $\bar f(x)$ is the output of the network when the first layer is replaced by $\bar F(x)$, that is, when the first layer computes exactly a multiple of $\|x\|$ (see (5) for the precise definition). The first term therefore characterizes the loss conditioned on a perfect first layer. The second term characterizes the difference between the first-layer output and a multiple of $\|x\|$. We show that, at least approximately, the gradients of these two terms do not affect each other. We can therefore view the training process as doing two things simultaneously:

- minimizing the loss given a good first-layer representation (reducing the first term);
- making the first-layer output closer to a multiple of $\|x\|$ (reducing the second term).

We believe such a decomposition highlights how the lower layer of the neural network receives useful gradient information to learn a good representation for this particular objective.

### 1.3 RELATED WORKS

Algorithmic aspect of depth separation. Other works have also added algorithmic insights into depth separation. Allen-Zhu & Li (2020) showed that multi-layer quadratic networks can learn certain target functions in a hierarchical way, which cannot be learned by any kernel method or shallow neural network. Our work deals with more standard neural network architectures and target functions.
A concurrent work, Safran & Lee (2021), considers a problem similar to ours. They show that gradient descent on a certain three-layer network can learn the ball indicator, which is not approximable by any two-layer network. Conceptually, the main difference between our results lies in the training dynamics: the first layer of Safran & Lee (2021) is fixed, while we train both layers. This leads to very different training dynamics and proof techniques.

Overparameterized neural networks. One line of work studies the optimization of overparameterized neural networks by coupling the training dynamics to kernel regression with the neural tangent kernel (NTK) (e.g., Jacot et al., 2018; Allen-Zhu et al., 2018b; Du et al., 2018). However, it has been shown that neural networks behave like kernel methods in the NTK regime, and several lower bounds have been developed (Yehudai & Shamir, 2019; Wei et al., 2019; Ghorbani et al., 2019; 2020). Our training dynamics are not in the NTK regime, as all the weights change significantly. Another line of work studies the optimization of overparameterized neural networks in the mean-field limit (Mei et al., 2018; Chizat & Bach, 2018; Nitanda & Suzuki, 2017; Wei et al., 2019; Rotskoff & Vanden-Eijnden, 2018; Sirignano & Spiliopoulos, 2020). Chizat et al. (2019) showed that, in the mean-field regime, the parameters can move away from their initialization and learn useful features, which differs from the NTK regime. However, most of the existing works require an exponential or infinite number of neurons and do not provide a polynomial convergence rate.

¹We say some quantity $a$ is $\mathrm{poly}(d,1/\varepsilon)$ if it is bounded by $C(d/\varepsilon)^C$ for some universal constant $C>0$ that may change across lines.

²Strictly speaking, gradient flow is not a proper algorithm, but it is commonly used as a surrogate for gradient descent in theoretical analysis. See Appendix E for a discussion of how to convert the argument to a gradient descent one.
See more discussions in Appendix A.

Multi-layer mean-field. Mean-field analysis has been successful for the optimization of two-layer overparameterized networks. It is not easy to extend it to multiple-layer networks, however, because the width of the intermediate layers goes to infinity. Many works have tried to address this issue and generalize mean-field analysis to deep networks; see, e.g., Nguyen & Pham (2020); Pham & Nguyen (2021); Araújo et al. (2019); Sirignano & Spiliopoulos (2021); Fang et al. (2021); Lu et al. (2020); Ding et al. (2021) and references therein. Unlike most of the existing works, our multi-layer mean-field framework keeps the hidden feature dimension finite, while the number of neurons can go to infinity to become a distribution of neurons. See Section 1.1 and Appendix A for more discussions.

Mildly overparameterized neural networks. Recently, many works have considered the problem of learning certain target functions with mildly overparameterized (polynomial-size) networks (Allen-Zhu et al., 2018a; Allen-Zhu & Li, 2019; Bai & Lee, 2019; Dyer & Gur-Ari, 2019; Woodworth et al., 2020; Bai et al., 2020; Huang & Yau, 2020; Chen et al., 2020; Li et al., 2020; Wang et al., 2020; Zhou et al., 2021). These works differ from the typical mean-field analysis, which usually considers the infinite-width network. They also differ from the typical NTK analysis, in which the neural network behaves like a kernel method. Our work is in a similar direction, but we need new insights to extend the discretization to our new multilayer framework.

## 2 PRELIMINARIES

In this section, we discuss the additional technical conditions on the input distribution in Safran et al. (2019), and how we deal with them in the training process.

Notations. For a vector $x$, we let $\|x\|$ denote its Euclidean norm. We use $a = b\pm c$ as shorthand for the condition $a\in[b-|c|,\, b+|c|]$. For a distribution $\mu$, we write $v\in\mu$ for the condition that $v$ is in the support of $\mu$, i.e., $v\in\mathrm{supp}(\mu)$. Other notations we use are mostly standard.
We usually use $v_1$ and $w_1$ to denote a first-layer neuron, and $(v_2,r_2)$ and $(w_2,b_2)$ to denote a second-layer neuron. Keeping two sets of notation for neurons is intentional. When taking expectations over neurons, we use $w_1$ and $(w_2,b_2)$. When considering a single neuron, we use $v_1$ and $(v_2,r_2)$. For vectors, we write $\bar v := v/\|v\|$. We use $\mathbb{E}_x$ as shorthand for $\mathbb{E}_{x\sim\mathcal{D}}$ when it is clear from the context.

Target function and input distribution. The target function we consider is $f^*(x) = \sigma(1-\|x\|)$, where $\sigma:\mathbb{R}\to\mathbb{R}$ is the ReLU activation. To describe the input distribution, we first define

$$\varphi(x) := \left(\frac{R_d}{\|x\|}\right)^{d/2}J_{d/2}(2\pi R_d\|x\|), \qquad R_d = \frac{1}{\sqrt{\pi}}\left(\Gamma(d/2+1)\right)^{1/d},$$

where $J_\nu$ is the Bessel function of the first kind of order $\nu$. Let $\alpha,\beta>0$ be the universal constants from Safran et al. (2019) (cf. the proof of their Theorem 5). We assume the inputs $x\in\mathbb{R}^d$ are sampled from the distribution $\mathcal{D}$ whose density is a rescaling of $\varphi^2$ whose scale depends on $d$, $\alpha$, and $\beta$ (see Safran et al. (2019) for the exact expression). Eldan & Shamir (2016) and Safran et al. (2019) have verified that this is indeed a valid probability distribution. Note also that $\mathcal{D}$ is spherically symmetric. For more properties of $\mathcal{D}$, see Appendix B.2. By Theorem 5 of Safran et al. (2019), no two-layer network of width $\mathrm{poly}(d,1/\varepsilon)$ can approximate $f^*$ to accuracy $\varepsilon$ in $L^2(\mathcal{D})$.³ This distribution is heavy-tailed in the sense that $\mathbb{E}_{x\sim\mathcal{D}}[\|x\|^2]$ is not finite. Such a heavy-tailed distribution is mostly needed to prove the lower bound. Our training result holds for most reasonable spherically symmetric distributions.

³Strictly speaking, the result in Safran et al. (2019) requires $\varepsilon = O(1/d^6)$. Even in that regime, our algorithm learns the function using $\mathrm{poly}(d)$ neurons, which no two-layer network can achieve, so this is still a valid separation.

Training algorithm and main result. We use gradient flow with clipping on the MSE loss to train a polynomial-size network.
We write the loss as

$$L = L(\mu_1,\mu_2) = \frac{1}{2}\mathbb{E}_{x\sim\mathcal{D}}(f^*(x)-f(x))^2 =: \mathbb{E}_x L(x). \tag{3}$$

Define $S(x) := (f^*(x)-f(x))\,\mathbb{E}_{w_2,b_2}\{\sigma'(w_2F(x)+b_2)\,w_2\}$. One can verify that the dynamics of the neurons are given by

$$\begin{aligned}
\dot v_1 &= \mathbb{E}_{x\sim\mathcal{D}}\,\Pi_{R_{v_1}}\!\left[S(x)\left(\bar v_1\sigma(v_1^\top x) + \|v_1\|\sigma'(v_1^\top x)\,x\right)\right],\\
\dot v_2 &= \mathbb{E}_{x\sim\mathcal{D}}\,\Pi_{R_{v_2}}\!\left[(f^*(x)-f(x))\,\sigma'(v_2F(x)+r_2)\,F(x)\right],\\
\dot r_2 &= \mathbb{E}_{x\sim\mathcal{D}}\,\Pi_{R_{r_2}}\!\left[(f^*(x)-f(x))\,\sigma'(v_2F(x)+r_2)\right],
\end{aligned}\tag{4}$$

where $\Pi_R$ denotes the projection onto the ball of radius $R$, and $R_{v_1} = \Theta(d)$, $R_{v_2} = \Theta(d^3)$, $R_{r_2} = \Theta(1)$ are the projection thresholds. We add this gradient clipping because, without it, the gradients are not well-defined due to the heavy tail of $\mathcal{D}$. Gradient clipping is also widely used in practice to avoid exploding gradients (Pascanu et al., 2013; Zhang et al., 2020). In fact, we believe our optimization result would still hold without gradient clipping for a general spherically symmetric distribution $\mathcal{D}$, as long as it is more regular.

We initialize the learner network as follows:

- first-layer weights: $w_1\sim\mathrm{Unif}(\sigma_1\mathbb{S}^{d-1})$;
- second-layer weights: $w_2\sim\mathcal{N}(0,\sigma_2^2)$;
- second-layer biases: $b_2 = \sigma_r$ for every neuron.

Here $\sigma_1,\sigma_2,\sigma_r$ are small positive real numbers. We initialize $w_1$ on the sphere rather than from a Gaussian only for technical convenience. We initialize the biases to a small positive value so that all second-layer neurons are active at initialization, which avoids zero gradients.

We are now ready to state our main result. It shows that gradient flow on a polynomial-size learner network (1), defined in our mean-field framework, can efficiently learn $f^*(x) = \sigma(1-\|x\|)$, which is not approximable by any two-layer network (Safran et al., 2019).

Theorem 2.1 (Main result). Consider the learner network defined in (1) with the initialization described above, and suppose we run gradient flow (assuming it exists) on this finite-width network with clipping (4) on the loss (3).
Then, for any $\varepsilon>0$, we can choose $m_1 = \mathrm{poly}_{m_1}(d,1/\varepsilon)$, $m_2 = \Theta(1)$, $\sigma_1 = 1/\mathrm{poly}_{\sigma_1}(d,1/\varepsilon)$, $\sigma_2 = 1/\mathrm{poly}_{\sigma_2}(d,1/\varepsilon)$, $\sigma_r = \Theta(1)$, $R_{v_1} = \Theta(d)$, $R_{v_2} = \Theta(d^3)$, and $R_{r_2} = \Theta(1)$ such that, with probability at least $1-1/\mathrm{poly}(d,1/\varepsilon)$ over the random initialization, the loss satisfies $L\le\varepsilon$ within time $T = \mathrm{poly}(d,1/\varepsilon)$.

## 3 THE INFINITE-WIDTH DYNAMICS

Our proof analyzes the dynamics of the infinite-width mean-field network and then controls the discretization error. In this section, we characterize the infinite-width dynamics. For ease of presentation, we pretend in this section that there is no projection and that the gradients are well-defined. We defer the discussion of how to handle the projections to Section 4.

First, note that both the input distribution $\mathcal{D}$ and the infinite-width network are spherically symmetric. That is, for any $x,x'\in\mathbb{R}^d$ with $\|x\| = \|x'\|$, the density values are the same and so are the function values. Any spherically symmetric $g:\mathbb{R}^d\to\mathbb{R}$ can be characterized by a function $h:[0,\infty)\to\mathbb{R}$ satisfying $h(\|x\|) = g(x)$. For convenience, we abuse notation and also write $g$ for this function $h$.

Assume that the distribution $\mu_1$ of the first-layer neurons is spherically symmetric, which holds at least at initialization. We can then approximate the first layer by a simple function using the following lemma, whose proof is in Appendix B.3.

Lemma 3.1. Let $\mu$ be a spherically symmetric distribution. Then

$$\mathbb{E}_{w\sim\mu}\|w\|\sigma(w^\top x) = \frac{C_\Gamma\,\mathbb{E}_{w\sim\mu}\|w\|^2}{\sqrt{d}}\,\|x\|, \qquad\text{where } C_\Gamma := \frac{\Gamma(d/2)\sqrt{d}}{2\sqrt{\pi}\,\Gamma((d+1)/2)}.$$

Note that $C_\Gamma\to 1/\sqrt{2\pi}$ as $d\to\infty$, so $C_\Gamma$ is bounded by a universal constant for all $d$.

This lemma implies that, in the infinite-width limit, we have $F(x) = \alpha\|x\|$ for some real $\alpha>0$, at least at initialization. This suggests defining the infinite-width approximation as

$$\alpha := \frac{C_\Gamma}{\sqrt{d}}\,\mathbb{E}_{w_1\sim\mu_1}\|w_1\|^2, \qquad \bar F(x) := \alpha\|x\|, \qquad \bar f(x) := \mathbb{E}_{(w_2,b_2)\sim\mu_2}\,\sigma(w_2\bar F(x) + b_2). \tag{5}$$

Note that (5) is well-defined whether or not $\mu_1$ is infinite-width, though only in the infinite-width case does one have $F = \bar F$.
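Lemma 3.1 and the limit of $C_\Gamma$ can be checked numerically. The NumPy sketch below draws neurons from $\mu = \mathrm{Unif}(r\,\mathbb{S}^{d-1})$, which is spherically symmetric. It compares a Monte Carlo estimate of $\mathbb{E}_{w\sim\mu}\|w\|\sigma(w^\top x)$ with the closed form $C_\Gamma\,\mathbb{E}\|w\|^2\|x\|/\sqrt{d}$. The helper names (`c_gamma`, `lemma_3_1_mc`) and the dimension, radius, and sample size are hypothetical choices made for illustration.

```python
import math
import numpy as np

def c_gamma(d):
    # C_Gamma = Gamma(d/2) sqrt(d) / (2 sqrt(pi) Gamma((d+1)/2)), computed with lgamma to avoid overflow
    return math.sqrt(d) / (2.0 * math.sqrt(math.pi)) * math.exp(
        math.lgamma(d / 2) - math.lgamma((d + 1) / 2))

def lemma_3_1_mc(d, x, radius, n_neurons, rng):
    """Monte Carlo estimate of E_{w~mu} ||w|| sigma(w^T x) for mu = Unif(radius * S^{d-1})."""
    w = rng.standard_normal((n_neurons, d))
    w *= radius / np.linalg.norm(w, axis=1, keepdims=True)     # project Gaussian samples onto the sphere
    return np.mean(np.linalg.norm(w, axis=1) * np.maximum(w @ x, 0.0))

rng = np.random.default_rng(0)
d, radius = 20, 0.3
x = rng.standard_normal(d)
mc = lemma_3_1_mc(d, x, radius, 200_000, rng)
closed_form = c_gamma(d) * radius**2 / math.sqrt(d) * np.linalg.norm(x)   # E||w||^2 = radius^2
```

For $d = 1$ the formula gives $C_\Gamma = 1/2$, which matches $\mathbb{E}[\max(u,0)] = 1/2$ for $u$ uniform on $\{-1,1\}$. For large $d$ the value approaches $1/\sqrt{2\pi}\approx 0.399$.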
Later, in Section 4, we will show that $F\approx\bar F$ throughout the entire process in the discretization part of the proof.

For the infinite-width network, symmetry gives the following picture. As long as $\mu_1$ is spherically symmetric at time $t$, no first-layer neuron changes its direction, and the change in norm is uniform, i.e., it does not depend on the direction $\bar v_1$ (see Appendix B.4 for the proof). As a result, $\mu_1$ remains spherically symmetric. Formally, one can show that, for any spherically symmetric $g:\mathbb{R}^d\to\mathbb{R}$,

$$\mathbb{E}_x\{g(x)\sigma(v^\top x)\} = \frac{C_\Gamma}{\sqrt{d}}\,\mathbb{E}_x\{g(x)\|x\|\}\,\|v\| \quad\text{and}\quad \mathbb{E}_x\{g(x)\sigma'(v^\top x)\,x\} = \frac{C_\Gamma}{\sqrt{d}}\,\mathbb{E}_x\{g(x)\|x\|\}\,\bar v,$$

where $\bar v = v/\|v\|$. Again, the proofs of these two identities are in Appendix B.3. We apply them to $\dot v_1$ with $g = S$, which is spherically symmetric in the infinite-width network. The two terms of $\dot v_1$ then contribute equally:

$$\dot v_1 = \bar v_1\cdot\frac{C_\Gamma}{\sqrt{d}}\mathbb{E}_x\{S(x)\|x\|\}\,\|v_1\| + \|v_1\|\cdot\frac{C_\Gamma}{\sqrt{d}}\mathbb{E}_x\{S(x)\|x\|\}\,\bar v_1 = \frac{2C_\Gamma}{\sqrt{d}}\,\mathbb{E}_x\{S(x)\|x\|\}\,v_1.$$

As a result, $\mu_1$ is always a uniform distribution over some sphere. Moreover, we have⁴

$$\dot\alpha = \mathbb{E}_{w_1}\left\langle\nabla_{w_1}\alpha,\ \frac{dw_1}{dt}\right\rangle = \frac{4C_\Gamma^2}{d}\,\mathbb{E}_x\{S(x)\|x\|\}\,\mathbb{E}_{w_1}\|w_1\|^2 = \frac{4C_\Gamma}{\sqrt{d}}\,\mathbb{E}_x\{S(x)\|x\|\}\,\alpha.$$

This reduces the dynamics of the first layer to a single real number $\alpha$. The output of the first layer depends only on $\alpha$ and $\|x\|$, and the dynamics of $\alpha$ depend on $\alpha$ rather than on every individual neuron $w_1$. In other words, in the infinite-width case we do not need to track the actual dynamics of $w_1$. We will later show that the spread of the second layer is always small. The second layer can therefore be approximated by $\alpha\|x\|\mapsto\sigma(\bar w_2\alpha\|x\| + \bar b_2)$, where $(\bar w_2,\bar b_2) = \mathbb{E}(w_2,b_2)$. Combining these observations, one can characterize the dynamics of the entire network using three quantities: $\alpha$, $\bar w_2$, and $\bar b_2$.

We close this section with another interpretation of $\bar F$, which will be useful in Section 4.2. In the idealized case, $F$ should be spherically symmetric. It therefore makes sense to define the idealized $\bar F$ as the average over the sphere, that is, $\bar F(x) = \mathbb{E}_{x'\sim\mathrm{Unif}(\|x\|\mathbb{S}^{d-1})}F(x')$.
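This symmetrization view can also be checked numerically. The NumPy sketch below draws a small, deliberately non-symmetric set of first-layer neurons. It estimates $\bar F(x)$ by averaging $F$ over points drawn uniformly from the sphere of radius $\|x\|$, and compares the result with $\alpha\|x\|$, where $\alpha$ is defined in (5). The neuron count, dimension, radius, and sample size are hypothetical. Individual values of $F$ on the sphere vary, but their average should match.

```python
import math
import numpy as np

def c_gamma(d):
    # C_Gamma = Gamma(d/2) sqrt(d) / (2 sqrt(pi) Gamma((d+1)/2))
    return math.sqrt(d) / (2.0 * math.sqrt(math.pi)) * math.exp(
        math.lgamma(d / 2) - math.lgamma((d + 1) / 2))

def first_layer(X, W):
    # finite-width F(x) = (1/m1) sum_j ||w_j|| sigma(w_j^T x)
    return np.mean(np.linalg.norm(W, axis=1) * np.maximum(X @ W.T, 0.0), axis=1)

rng = np.random.default_rng(1)
d, m1, radius = 10, 7, 1.7
# a few neurons with random directions and random norms: this mu_1 is NOT spherically symmetric
W = rng.standard_normal((m1, d)) * rng.uniform(0.5, 2.0, size=(m1, 1))
alpha = c_gamma(d) * np.mean(np.linalg.norm(W, axis=1) ** 2) / math.sqrt(d)

U = rng.standard_normal((200_000, d))
U /= np.linalg.norm(U, axis=1, keepdims=True)      # uniform points on S^{d-1}
values = first_layer(radius * U, W)                # F evaluated on the sphere of radius ||x||
F_bar = values.mean()                              # Monte Carlo estimate of F_bar(x)
```

The design choice is to keep $m_1$ tiny so that $F$ itself is visibly far from radial. The match of `F_bar` with `alpha * radius` therefore comes from the averaging over inputs, not from symmetry of the neurons.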
Note that in Lemma 3.1 the expectation is taken over the neurons, while here it is taken over the inputs. However, as in the proof of Lemma 3.1, one can still show that

$$\mathbb{E}_{x'\sim\mathrm{Unif}(\|x\|\mathbb{S}^{d-1})}F(x') = \mathbb{E}_{w\sim\mu_1}\mathbb{E}_{x'\sim\mathrm{Unif}(\|x\|\mathbb{S}^{d-1})}\|w\|\sigma(w^\top x') = \frac{C_\Gamma\,\mathbb{E}_{w\sim\mu_1}\|w\|^2}{\sqrt{d}}\,\|x\| = \alpha\|x\|.$$

In other words, these two derivations are equivalent. In some sense, this means that the infinite-width network can be interpreted as a symmetrization of the actual finite-width network.

## 4 DISCRETIZING THE DYNAMICS WITH POLYNOMIAL-SIZE NETWORK

In this section, we show how to discretize the infinite-width dynamics to obtain our main result. Figure 2 shows simulation results. Even though the network has finite width, at every time step the function $f(x)$ is close to a function of the form $x\mapsto\sigma(\bar b_2 + \bar w_2\alpha\|x\|)$. Throughout training, the second-layer weights stay well concentrated.

Let $\delta_2 := \max_{(v_2,r_2),(v_2',r_2')\in\mu_2}\|(v_2,r_2)-(v_2',r_2')\|$ be the spread of the second layer, and recall that $(\bar w_2,\bar b_2) := \mathbb{E}_{(w_2,b_2)\sim\mu_2}(w_2,b_2)$. We split the training procedure into two stages. In Stage 1, $\bar w_2$ decreases to $-\mathrm{poly}(d)\,\delta_2$. We show that once this holds, the projection operators in (4) can be ignored, that is, the corresponding terms never exceed the thresholds (see Lemma 4.1). In Stage 2, we show that the network can fit the target function in polynomial time.

⁴As in standard mean-field arguments, we rescale the gradients by $m$ so that they do not vanish as $m\to\infty$. In most gradient calculations, this is equivalent to using the formal rule $\nabla_v\mathbb{E}_w\,g(w) = \nabla g(v)$.

Figure 2: Simulation results. The left figure shows the loss during training; each vertical dashed line corresponds to a time point plotted in the other two figures. The center figure depicts the shape of $f$ at certain steps. The right figure shows the values of the second-layer neurons at certain steps.
One can observe that $f\approx\bar f$ indeed holds and that the second-layer neurons are concentrated around $(\bar w_2,\bar b_2)$, which matches our theoretical analysis. The simulation uses a finite-width network with widths $m_1 = 512$, $m_2 = 128$ and input dimension $d = 100$.

### 4.1 STAGE 1: REMOVING THE PROJECTIONS

Our first step shows that, after a short amount of training time, it is safe to ignore the projection operators in (4). To see why the projections can be ignored in certain circumstances, suppose the following hold:

- $f\approx\bar f$;
- the second-layer neurons concentrate around their mean;
- $\bar b_2 = \Theta(1)$;
- $\bar w_2<0$.

Then $f\approx\sigma(\bar w_2\alpha\|x\| + \bar b_2)$ vanishes outside $\{\|x\|\le\Theta(1/|\bar w_2\alpha|)\}$, since $\bar w_2\alpha\|x\| + \bar b_2\le 0$ exactly when $\|x\|\ge\bar b_2/|\bar w_2\alpha|$. Hence the gradients also vanish for such large $\|x\|$. Meanwhile, by upper bounding the norm of the gradients, one can show that the projections can only be triggered when $\|x\|$ is large. As a result, when $f$ decreases sufficiently fast, $f(x)$ reaches $0$ before $\|x\|$ becomes large enough to trigger them. Formally, we have the following lemma, whose proof is in Appendix C.

Lemma 4.1. Choose the projection thresholds $R_{v_1} = \Theta(d)$, $R_{v_2} = \Theta(d^3)$, and $R_{r_2} = \Theta(1)$ in (4). Suppose that $\alpha = \Theta(1/\sqrt{d})$. Then:

- the projection in $\dot r_2$ is no longer activated once all second-layer weights are nonpositive;
- the projection in $\dot v_1$ is no longer activated once $|\bar w_2|>\Theta(1)\,\delta_2$ for some large constant;
- the projection in $\dot v_2$ is no longer activated once $|\bar w_2|\ge\Theta(1)/R_{v_2}$ for some large constant.

Based on this lemma, we further split Stage 1 into three substages. We define $T_{1.1}$ as the first time all second-layer weights become negative. We define $T_{1.2}$ and $T_{1.3}$ as the first times $|\bar w_2|$ reaches $\Theta(d)\,\delta_2$ and $\Theta(1/R_{v_2})$, respectively. These are the end times of Stages 1.1, 1.2, and 1.3. We require $|\bar w_2|$ to be $\Theta(d)\,\delta_2$ instead of $\Theta(1)\,\delta_2$ at the end of Stage 1.2 so that the starting state of Stage 1.3 is more regular. By definition and Lemma 4.1, one more projection can be ignored after each substage, and all of them can be ignored after Stage 1. The main lemma of Stage 1 is as follows. Recall that $R_{v_1}, R_{v_2}, R_{r_2}$ are the clipping thresholds.
Lemma 4.2 (Stage 1, informal). Define the end time of Stage 1 as $T_1 := \inf\{t\ge 0: \bar w_2(t) = -C_1/R_{v_2}\}$ for some large constant $C_1$. Under the assumptions of Theorem 2.1, we have $T_1\le\mathrm{poly}(d,1/\varepsilon)$, and the following conditions hold throughout Stage 1.

(a) Approximation error of the first layer. For each $v_1\in\mu_1$, both the tangent movement and the radial spread can be controlled: $\|\bar v_1(t)-\bar v_1(0)\|\le\delta^{(1)}_{1,T}(t)$ and $\|v_1\|^2 = (1\pm\delta^{(1)}_{1,R}(t))\,\mathbb{E}\|w_1\|^2$, where $\delta^{(1)}_{1,T}$ and $\delta^{(1)}_{1,R}$ are two processes that always remain small.

(b) Spread of the second layer. For any $(v_2,r_2),(v_2',r_2')\in\mu_2$, $\|(v_2,r_2)-(v_2',r_2')\|$ is small.

(c) Regularity conditions. $r_2 = \Theta(1)$ for all $(v_2,r_2)\in\mu_2$, $|\bar w_2| = O(1/R_{v_2}) = O(1/d^3)$, and $\alpha = \Theta(\sqrt{d}/R_{v_1}) = \Theta(1/d^{1.5})$.

The first two conditions mean that the approximation $f(x)\approx\sigma(\bar w_2\alpha\|x\| + \bar b_2)$ is valid throughout Stage 1. The third condition describes the shape of $f$ in Stage 1. To maintain these conditions, we use the so-called continuity argument, which can be viewed as a continuous version of mathematical induction; see Appendix B.1 for an explanation of this technique.

We use the approximation $F(x)\approx\alpha\|x\|$ and the fact that $f(x)\,\sigma'(v_2F(x)+r_2) = f(x)$ for most $x$. With these, we can rewrite the dynamics of $v_2$ as

$$\dot v_2\approx\mathbb{E}_x\,\Pi_{R_{v_2}}\!\left[(f^*(x)-f(x))\,\alpha\|x\|\right].$$

Since $f$ is much flatter than $f^*$, $f$ is still $\Omega(1)$ where $f^*$ vanishes because $\|x\|\ge 1$. As a result, the right-hand side is always negative; in fact, we show that its magnitude is $\Theta(\alpha\log d)$. Recall that $T_{1.2}$ is the time at which $|\bar w_2|$ reaches $\Theta(d\,\delta_2)$. If $\delta_2$ remains roughly constant, the time needed for Stages 1.1 and 1.2 is therefore proportional to the initial $\delta_2$. We can make the initial $\delta_2$ small by choosing a small enough $\sigma_2$. This also helps control the movement of $v_1$ and $r_2$ in Stages 1.1 and 1.2, since their dynamics depend on $|w_2|$. To maintain the approximation $f(x)\approx\sigma(\bar w_2F(x) + \bar b_2)$, one also needs to show that $\delta_2$ cannot increase too much during Stages 1.1 and 1.2.
Intuitively, this holds for two reasons. For inputs with small $\|x\|$, the gradient $\nabla_{v_2}L(x)$ does not depend on $(v_2,r_2)$ itself. Inputs with large norm cannot contribute much to the gradient because of gradient clipping. As a result, the dynamics of $v_2$ are approximately uniform across neurons in Stages 1.1 and 1.2, so the distance between any two neurons $(v_2,r_2)$ and $(v_2',r_2')$ stays small.

The same method does not work in Stage 1.3, because the target value of $\bar w_2$ no longer depends on $\delta_2$. This stage needs a finer analysis of the first layer. Recall that after Stage 1.2 the projection in $\dot v_1$ can be ignored. We can therefore decompose $\dot v_1$ along the radial and tangent directions as

$$\dot v_1 = \mathrm{Rad}(\dot v_1) + \mathrm{Tan}(\dot v_1) = \langle\dot v_1,\bar v_1\rangle\,\bar v_1 + (I-\bar v_1\bar v_1^\top)\,\dot v_1 = 2\,\mathbb{E}_x\{S(x)\sigma(v_1^\top x)\}\,\bar v_1 + \|v_1\|\,\mathbb{E}_x\left\{S(x)\sigma'(v_1^\top x)(I-\bar v_1\bar v_1^\top)x\right\},$$

where the radial coefficient uses $\sigma'(z)z = \sigma(z)$. We then write $S(x)\approx(f^*(x)-f(x))\,\bar w_2 = (f^*(x)-\bar f(x))\,\bar w_2 + (\bar f(x)-f(x))\,\bar w_2$. The terms involving $f^*-\bar f$ are essentially what one should expect in the infinite-width dynamics: for them, the radial movement is uniform and the tangent movement is $0$. We bound the terms involving $\bar f-f$ using the radial spread and tangent movement of the first layer, which gives

$$\frac{d}{dt}\left(\delta^{(1)}_{1,R} + \delta^{(1)}_{1,T}\right)\le\frac{O(1)}{d^{2.5}}\left(\delta^{(1)}_{1,R} + \delta^{(1)}_{1,T}\right)$$

(cf. Lemma C.16). With this bound the error can grow exponentially fast, as $\exp(O(t/d^{2.5}))$. This is still sufficient because Stage 1.3 takes only $O(d^{1.5})$ time. By Grönwall's inequality, the error grows during Stage 1.3 by at most a factor $\exp(O(d^{1.5}/d^{2.5})) = \exp(O(1/d)) = 1 + O(1/d)$.

### 4.2 STAGE 2: FITTING THE TARGET FUNCTION

The goal of Stage 2 is for the gradient flow to reach a point with loss at most $\varepsilon$ in polynomial time. The main difficulty in this stage is that we need to bound the approximation error of the first layer more carefully. Stage 2 is potentially long, and the brute-force estimates used in Stage 1 become too loose towards the end of training. We write $\hat F := F/\alpha$ and measure the approximation error using $\|\hat F|_{\mathbb{S}^{d-1}} - 1\|_{L^\infty}$ and $\|\hat F - \|\cdot\|\,\|_{L^2}^2$. Strictly speaking, for the $L^2$ error we only consider $x$ with $\|x\|\le\Theta(1/|\bar w_2\alpha|) = \mathrm{poly}(d)$, since otherwise the error can be ill-defined.
This restriction is harmless because, as discussed earlier, $f$ vanishes for large $\|x\|$. In Stage 2, $\mathbb{E}_x$ always means $\mathbb{E}_{\|x\|\le\Theta(1/|\bar w_2\alpha|)}$; for simplicity of presentation, we usually do not state this explicitly. The main result of Stage 2 is as follows.

Lemma 4.3 (Stage 2, informal). Define the end time of Stage 2 as $T_2 := \inf\{t\ge T_1: L = \varepsilon\}$. Under the assumptions of Theorem 2.1, we have $T_2 - T_1\le\mathrm{poly}(d,1/\varepsilon)$, and the following conditions hold throughout Stage 2:

(a) Approximation error of the first layer. Both $\|\hat F - \|\cdot\|\,\|_{L^2}$ and $\|\hat F|_{\mathbb{S}^{d-1}} - 1\|_{L^\infty}$ are small.

(b) Spread of the second layer. $\max_{(v_2,r_2),(v_2',r_2')\in\mu_2}\|(v_2,r_2)-(v_2',r_2')\|$ does not grow.

(c) Regularity conditions. The shape of $f$ is similar to the one shown in Figure 2.

As mentioned, the main technical challenge is to bound the approximation error of the first layer. The overall strategy has two steps. First, we show that the $L^2$ error barely grows in Stage 2. Second, we show that the $L^\infty$ error can be controlled as long as the $L^2$ error is small. Unlike in Stage 1, $|\bar w_2\alpha|$ is fairly large in Stage 2, so the first layer can receive some signal from the loss function. Intuitively, this signal should push the first layer closer to a multiple of $\|x\|$, as that is what the global optimal solution would do. Formally, we first show the approximation

$$L\approx\frac{1}{2}\mathbb{E}_x\left\{(f^*(x)-\bar f(x))^2\right\} + \frac{\bar w_2^2}{2}\mathbb{E}_x\left\{(\bar F(x)-F(x))^2\right\}, \tag{6}$$

in the sense that the gradients $\nabla_{v_1}$ of both sides are approximately the same. Here $\bar f(x)$ is defined as $\mathbb{E}_{(w_2,b_2)\sim\mu_2}\sigma(w_2\bar F(x)+b_2)$. The first term of (6) measures the distance between the target function and the infinite-width network. The second term measures the approximation error of the first layer. In some sense, one can view this formula as a bias-variance decomposition for discretizing mean-field networks.
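With this approximation in hand, we show that, thanks to the 2-homogeneity of $F$, the first term (after a suitable normalization) does not affect the approximation error of the first layer. Meanwhile, since we are following the gradient flow, the second term can only decrease the approximation error.

To establish (6), we first decompose the loss function as

$$L = \frac{1}{2}\mathbb{E}_x\left\{(f^*(x)-\bar f(x))^2\right\} + \frac{1}{2}\mathbb{E}_x\left\{(\bar f(x)-f(x))^2\right\} + \mathbb{E}_x\left\{(f^*(x)-\bar f(x))(\bar f(x)-f(x))\right\} =: L_1 + L_2 + L_3.$$

We claim that $L_2$ is approximately the second term of (6) and that $L_3$ is approximately $0$.⁵ Let $X_1$ be the largest spherically symmetric set on which $v_2F(x)+r_2>0$ for all $(v_2,r_2)\in\mu_2$. We show that the inputs outside $X_1$ contribute little. Therefore, we can rewrite $L_2$ as

$$L_2\approx\frac{1}{2}\mathbb{E}_{X_1}\left\{\left(\mathbb{E}_{w_2,b_2}(w_2\bar F(x)+b_2) - \mathbb{E}_{w_2,b_2}(w_2F(x)+b_2)\right)^2\right\} = \frac{\bar w_2^2}{2}\mathbb{E}_{X_1}\left\{(\bar F(x)-F(x))^2\right\}\approx\frac{\bar w_2^2}{2}\mathbb{E}_x\left\{(\bar F(x)-F(x))^2\right\}.$$

Similarly, we can rewrite $L_3$ as $L_3\approx\bar w_2\,\mathbb{E}_x\left\{(f^*(x)-\bar f(x))(\bar F(x)-F(x))\right\}$. Recall from Section 3 that $\bar F(x) = \mathbb{E}_{x'\sim\mathrm{Unif}(\|x\|\mathbb{S}^{d-1})}F(x')$. With this in mind, one can verify that $\mathbb{E}_x\{g(x)F(x)\} = \mathbb{E}_x\{g(x)\bar F(x)\}$ for any spherically symmetric function $g:\mathbb{R}^d\to\mathbb{R}$. Indeed, since the input distribution is spherically symmetric, we can write $x = \|x\|u$ with $u$ uniform on $\mathbb{S}^{d-1}$ and independent of $\|x\|$, so $\mathbb{E}_x\{g(x)F(x)\} = \mathbb{E}_{\|x\|}\{g(x)\,\mathbb{E}_uF(\|x\|u)\} = \mathbb{E}_x\{g(x)\bar F(x)\}$. Setting $g = f^*-\bar f$, which is spherically symmetric, gives $L_3\approx 0$. Combining these two estimates, we obtain (6).

Provided that the $L^2$ error is always small, we show that, up to higher-order terms,

$$\left|\frac{d}{dt}\hat F(\bar x)\right|\le O(d^3)\,\|\hat F - \|\cdot\|\,\|_{L^2}^2, \qquad\forall\,\bar x\in\mathbb{S}^{d-1}.$$

In words, the rate of change of $\hat F$ on the unit sphere is bounded by the $L^2$ error. Hence $\|\hat F|_{\mathbb{S}^{d-1}} - 1\|_{L^\infty}$ stays small, provided we choose $m_1$ large enough that $\hat F(x)|_{x\in\mathbb{S}^{d-1}}$ is close to $1$ at initialization. This should not be surprising, since in the infinite-width dynamics $\hat F(x)|_{x\in\mathbb{S}^{d-1}}\equiv 1$. The formal proof of this argument is in Section D.2. Once the approximation error is controlled, one can derive a convergence rate using the infinite-width dynamics; see Section D.3 for details.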
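The following minimal NumPy sketch shows the pieces of the clipped dynamics (4) working together. It uses an explicit Euler discretization with the initialization of Section 2. It rests on several assumptions that differ from the paper's setting:

- Gaussian inputs $x\sim\mathcal{N}(0, I_d/d)$ replace the heavy-tailed $\mathcal{D}$.
- $\mathbb{E}_x$ is an average over a fixed sample.
- All constants hidden in $\Theta(\cdot)$ are set to 1.
- The sizes, step size, and initialization scales are hypothetical and far smaller than the $m_1 = 512$, $m_2 = 128$, $d = 100$ used for Figure 2.

Per-neuron velocities follow the mean-field rescaling of footnote 4, and $\Pi_R$ is applied to each input's term before averaging, as in (4). The sketch also reports the second-layer spread $\delta_2$.

```python
import numpy as np

def relu(z):
    return np.maximum(z, 0.0)

def project(g, R):
    # Pi_R: project each vector (last axis) onto the Euclidean ball of radius R
    norms = np.linalg.norm(g, axis=-1, keepdims=True)
    return g * np.minimum(1.0, R / np.maximum(norms, 1e-12))

def forward(X, W1, w2, b2):
    norms = np.linalg.norm(W1, axis=1)
    pre1 = X @ W1.T                                   # (n, m1)
    F = np.mean(norms * relu(pre1), axis=1)           # first layer F(x), shape (n,)
    pre2 = F[:, None] * w2 + b2                       # (n, m2)
    return F, pre1, pre2, np.mean(relu(pre2), axis=1)

def loss(X, fstar, W1, w2, b2):
    return 0.5 * np.mean((fstar - forward(X, W1, w2, b2)[3]) ** 2)

def euler_step(X, fstar, W1, w2, b2, lr, R1, R2, Rr):
    F, pre1, pre2, f = forward(X, W1, w2, b2)
    res = fstar - f                                   # f*(x) - f(x)
    act2 = (pre2 > 0).astype(float)                   # sigma'(w2 F(x) + b2)
    S = res * np.mean(act2 * w2, axis=1)              # S(x)
    norms = np.linalg.norm(W1, axis=1)
    unit = W1 / norms[:, None]                        # v1 / ||v1||
    act1 = (pre1 > 0).astype(float)                   # sigma'(v1^T x)
    # per-input, per-neuron velocity of the first layer, shape (n, m1, d)
    g1 = S[:, None, None] * (unit[None, :, :] * relu(pre1)[:, :, None]
                             + norms[None, :, None] * act1[:, :, None] * X[:, None, :])
    dW1 = project(g1, R1).mean(axis=0)
    # second-layer parameters are scalars, so Pi_R reduces to clipping to [-R, R]
    dw2 = np.clip(res[:, None] * act2 * F[:, None], -R2, R2).mean(axis=0)
    db2 = np.clip(res[:, None] * act2, -Rr, Rr).mean(axis=0)
    return W1 + lr * dW1, w2 + lr * dw2, b2 + lr * db2

rng = np.random.default_rng(0)
d, m1, m2, n = 5, 64, 8, 2000
X = rng.standard_normal((n, d)) / np.sqrt(d)          # assumption: Gaussian inputs instead of D
fstar = relu(1.0 - np.linalg.norm(X, axis=1))          # target sigma(1 - ||x||)
sigma1, sigma2, sigma_r = 0.1, 0.01, 0.5              # hypothetical initialization scales
W1 = rng.standard_normal((m1, d))
W1 *= sigma1 / np.linalg.norm(W1, axis=1, keepdims=True)   # w1 ~ Unif(sigma1 S^{d-1})
w2 = sigma2 * rng.standard_normal(m2)                 # w2 ~ N(0, sigma2^2)
b2 = np.full(m2, sigma_r)                             # every bias starts at sigma_r
R1, R2, Rr = float(d), float(d) ** 3, 1.0             # Theta(d), Theta(d^3), Theta(1) with constants 1

losses = [loss(X, fstar, W1, w2, b2)]
for _ in range(50):
    W1, w2, b2 = euler_step(X, fstar, W1, w2, b2, 0.1, R1, R2, Rr)
    losses.append(loss(X, fstar, W1, w2, b2))

pairs = np.stack([w2, b2], axis=1)
delta2 = np.max(np.linalg.norm(pairs[:, None, :] - pairs[None, :, :], axis=-1))  # spread of layer 2
```

In this small setting, most of the early loss decrease should come from the biases $b_2$ moving from $\sigma_r$ toward the mean of $f^*$, while the tiny first layer barely moves. The sketch therefore illustrates the mechanics of (4) and the bookkeeping of $\delta_2$. It does not reproduce the staged behavior analyzed above or provide evidence for Theorem 2.1.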
5 CONCLUSION
In this paper, we give a new framework for extending the mean-field limit to multilayer networks, and use this framework to show that three-layer networks can learn a function that is not approximable by two-layer networks. There are still many open problems. For the current objective, the loss is spherically symmetric, so the first-layer neurons do not move much tangentially; what if the target is instead $\sigma(1 - \|P_S x\|)$, where $P_S$ is the projection onto some unknown subspace $S$? What about functions that require an intermediate layer of size more than 1? Can the saddle-point analysis be generalized to deeper networks? We hope this work will be a starting point for understanding how deep neural networks learn useful features.

^5 For ease of presentation, here we talk about function values instead of gradients. Strictly speaking, this is not accurate, since a small function value does not necessarily imply a small gradient. The ideas, however, are essentially the same. See Section D.2 for the actual proof.

ACKNOWLEDGEMENT
This work is supported by NSF Awards DMS-2031849, CCF-1845171 (CAREER), and CCF-1934964 (Tripods), and a Sloan Research Fellowship.

REFERENCES
Zeyuan Allen-Zhu and Yuanzhi Li. What can ResNet learn efficiently, going beyond kernels? arXiv preprint arXiv:1905.10337, 2019.
Zeyuan Allen-Zhu and Yuanzhi Li. Backward feature correction: How deep learning performs deep learning. arXiv preprint arXiv:2001.04413, 2020.
Zeyuan Allen-Zhu, Yuanzhi Li, and Yingyu Liang. Learning and generalization in overparameterized neural networks, going beyond two layers. arXiv preprint arXiv:1811.04918, 2018a.
Zeyuan Allen-Zhu, Yuanzhi Li, and Zhao Song. A convergence theory for deep learning via overparameterization. arXiv preprint arXiv:1811.03962, 2018b.
Dyego Araújo, Roberto I Oliveira, and Daniel Yukimura. A mean-field limit for certain deep neural networks. arXiv preprint arXiv:1906.00193, 2019.
Yu Bai and Jason D Lee. Beyond linearization: On quadratic and higher-order approximation of wide neural networks. arXiv preprint arXiv:1910.01619, 2019.
Yu Bai, Ben Krause, Huan Wang, Caiming Xiong, and Richard Socher. Taylorized training: Towards better approximation of neural network training at finite width. arXiv preprint arXiv:2002.04010, 2020.
Minshuo Chen, Yu Bai, Jason D Lee, Tuo Zhao, Huan Wang, Caiming Xiong, and Richard Socher. Towards understanding hierarchical learning: Benefits of neural representations. arXiv preprint arXiv:2006.13436, 2020.
Lenaic Chizat and Francis Bach. On the global convergence of gradient descent for overparameterized models using optimal transport. In Advances in Neural Information Processing Systems, pp. 3036–3046, 2018.
Lenaic Chizat, Edouard Oyallon, and Francis Bach. On lazy training in differentiable programming. In Advances in Neural Information Processing Systems, pp. 2933–2943, 2019.
Amit Daniely. Depth separation for neural networks. In Conference on Learning Theory, pp. 690–696. PMLR, 2017.
Zhiyan Ding, Shi Chen, Qin Li, and Stephen Wright. Overparameterization of deep ResNet: Zero loss and mean-field analysis. arXiv preprint arXiv:2105.14417, 2021.
Simon S Du, Xiyu Zhai, Barnabas Poczos, and Aarti Singh. Gradient descent provably optimizes over-parameterized neural networks. arXiv preprint arXiv:1810.02054, 2018.
Ethan Dyer and Guy Gur-Ari. Asymptotics of wide networks from Feynman diagrams. arXiv preprint arXiv:1909.11304, 2019.
Ronen Eldan and Ohad Shamir. The power of depth for feedforward neural networks. In Vitaly Feldman, Alexander Rakhlin, and Ohad Shamir (eds.), 29th Annual Conference on Learning Theory, volume 49 of Proceedings of Machine Learning Research, pp. 907–940, Columbia University, New York, New York, USA, June 2016. PMLR. URL http://proceedings.mlr.press/v49/eldan16.html.
Cong Fang, Jason Lee, Pengkun Yang, and Tong Zhang. Modeling from features: A mean-field framework for over-parameterized deep neural networks. In Conference on Learning Theory, pp. 1887–1936. PMLR, 2021.
Behrooz Ghorbani, Song Mei, Theodor Misiakiewicz, and Andrea Montanari. Limitations of lazy training of two-layers neural network. In NeurIPS, 2019.
Behrooz Ghorbani, Song Mei, Theodor Misiakiewicz, and Andrea Montanari. When do neural networks outperform kernel methods? arXiv preprint arXiv:2006.13409, 2020.
Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. Deep residual learning for image recognition. In 2016 IEEE Conference on Computer Vision and Pattern Recognition (CVPR), pp. 770–778, 2016. doi: 10.1109/CVPR.2016.90.
Jiaoyang Huang and Horng-Tzer Yau. Dynamics of deep neural networks and neural tangent hierarchy. In International Conference on Machine Learning, pp. 4542–4551. PMLR, 2020.
Arthur Jacot, Franck Gabriel, and Clément Hongler. Neural tangent kernel: Convergence and generalization in neural networks. In Advances in Neural Information Processing Systems, pp. 8571–8580, 2018.
I. Krasikov. Uniform bounds for Bessel functions. Journal of Applied Analysis, 12(1):83–91, 2006. doi: 10.1515/JAA.2006.83. URL https://doi.org/10.1515/JAA.2006.83.
Yuanzhi Li, Tengyu Ma, and Hongyang R Zhang. Learning over-parametrized two-layer neural networks beyond NTK. In Conference on Learning Theory, pp. 2613–2682. PMLR, 2020.
Shiyu Liang and R Srikant. Why deep neural networks for function approximation? In 5th International Conference on Learning Representations, ICLR 2017, 2017.
Yiping Lu, Chao Ma, Yulong Lu, Jianfeng Lu, and Lexing Ying. A mean field analysis of deep ResNet and beyond: Towards provably optimization via overparameterization from depth. In International Conference on Machine Learning, pp. 6426–6436. PMLR, 2020.
Eran Malach and Shai Shalev-Shwartz. Is deeper better only when shallow is good? Advances in Neural Information Processing Systems, 32, 2019.
Eran Malach, Gilad Yehudai, Shai Shalev-Shwartz, and Ohad Shamir. The connection between approximation, depth separation and learnability in neural networks. In Conference on Learning Theory, pp. 3265–3295. PMLR, 2021.
Song Mei, Andrea Montanari, and Phan-Minh Nguyen. A mean field view of the landscape of two-layer neural networks. Proceedings of the National Academy of Sciences, 115(33):E7665–E7671, 2018.
Phan-Minh Nguyen and Huy Tuan Pham. A rigorous framework for the mean field limit of multilayer neural networks. arXiv preprint arXiv:2001.11443, 2020.
Atsushi Nitanda and Taiji Suzuki. Stochastic particle gradient descent for infinite ensembles. arXiv preprint arXiv:1712.05438, 2017.
Razvan Pascanu, Tomas Mikolov, and Yoshua Bengio. On the difficulty of training recurrent neural networks. In Proceedings of the 30th International Conference on Machine Learning, Volume 28, ICML'13, pp. III-1310–III-1318. JMLR.org, 2013.
Huy Tuan Pham and Phan-Minh Nguyen. Global convergence of three-layer neural networks in the mean field regime. In International Conference on Learning Representations, 2021. URL https://openreview.net/forum?id=KvyxFqZS_D.
Tomaso Poggio, Hrushikesh Mhaskar, Lorenzo Rosasco, Brando Miranda, and Qianli Liao. Why and when can deep-but not shallow-networks avoid the curse of dimensionality: A review. International Journal of Automation and Computing, 14(5):503–519, 2017.
Ben Poole, Subhaneil Lahiri, Maithra Raghu, Jascha Sohl-Dickstein, and Surya Ganguli. Exponential expressivity in deep neural networks through transient chaos. Advances in Neural Information Processing Systems, 29, 2016.
Grant M Rotskoff and Eric Vanden-Eijnden. Trainability and accuracy of neural networks: An interacting particle system approach. arXiv preprint arXiv:1805.00915, 2018.
Itay Safran and Jason D Lee. Optimization-based separations for neural networks. arXiv preprint arXiv:2112.02393, 2021.
Itay Safran and Ohad Shamir. Depth-width tradeoffs in approximating natural functions with neural networks. In Doina Precup and Yee Whye Teh (eds.), Proceedings of the 34th International Conference on Machine Learning, volume 70 of Proceedings of Machine Learning Research, pp. 2979–2987. PMLR, August 2017. URL https://proceedings.mlr.press/v70/safran17a.html.
Itay Safran, Ronen Eldan, and Ohad Shamir. Depth separations in neural networks: What is actually being separated? In Alina Beygelzimer and Daniel Hsu (eds.), Proceedings of the Thirty-Second Conference on Learning Theory, volume 99 of Proceedings of Machine Learning Research, pp. 2664–2666, Phoenix, USA, June 2019. PMLR. URL http://proceedings.mlr.press/v99/safran19a.html.
Justin Sirignano and Konstantinos Spiliopoulos. Mean field analysis of neural networks: A central limit theorem. Stochastic Processes and their Applications, 130(3):1820–1852, 2020.
Justin Sirignano and Konstantinos Spiliopoulos. Mean field analysis of deep neural networks. Mathematics of Operations Research, 2021.
Terence Tao. Nonlinear Dispersive Equations: Local and Global Analysis. Number 106 in Conference Board of the Mathematical Sciences Regional Conference Series in Mathematics. American Mathematical Society, 2006. ISBN 978-0-8218-4143-3. OCLC: ocm65165502.
Matus Telgarsky. Benefits of depth in neural networks. In Conference on Learning Theory, pp. 1517–1539. PMLR, 2016.
Gal Vardi and Ohad Shamir. Neural networks with small weights and depth-separation barriers. Advances in Neural Information Processing Systems, 33:19433–19442, 2020.
Luca Venturi, Samy Jelassi, Tristan Ozuch, and Joan Bruna. Depth separation beyond radial functions. Journal of Machine Learning Research, 23(122):1–56, 2022.
Xiang Wang, Chenwei Wu, Jason D Lee, Tengyu Ma, and Rong Ge. Beyond lazy training for over-parameterized tensor decomposition. arXiv preprint arXiv:2010.11356, 2020.
Colin Wei, Jason D Lee, Qiang Liu, and Tengyu Ma. Regularization matters: Generalization and optimization of neural nets vs their induced kernel. In Advances in Neural Information Processing Systems, pp. 9712–9724, 2019.
Blake Woodworth, Suriya Gunasekar, Jason D Lee, Edward Moroshko, Pedro Savarese, Itay Golan, Daniel Soudry, and Nathan Srebro. Kernel and rich regimes in overparametrized models. In Conference on Learning Theory, pp. 3635–3673. PMLR, 2020.
Dmitry Yarotsky. Error bounds for approximations with deep ReLU networks. Neural Networks, 94:103–114, 2017.
Gilad Yehudai and Ohad Shamir. On the power and limitations of random features for understanding neural networks. arXiv preprint arXiv:1904.00687, 2019.
Jingzhao Zhang, Tianxing He, Suvrit Sra, and Ali Jadbabaie. Why gradient clipping accelerates training: A theoretical justification for adaptivity. In International Conference on Learning Representations, 2020. URL https://openreview.net/forum?id=BJgnXpVYwS.
Mo Zhou, Rong Ge, and Chi Jin. A local convergence theory for mildly over-parameterized two-layer neural network. In Conference on Learning Theory, pp. 4577–4632. PMLR, 2021.

A MULTI-LAYER MEAN-FIELD NETWORKS
In this section, we first briefly review existing theories of two-layer mean-field networks, and then introduce our framework for multi-layer mean-field networks.

A.1 TWO-LAYER NETWORKS AND PERMUTATION INVARIANCE
A two-layer network $f$ of width $m$ can usually be represented by^6
$$f(x; W, a) = \frac{1}{m}a^\top\sigma(Wx) = \frac{1}{m}\sum_{i=1}^m a_i\sigma(w_i \cdot x), \tag{7}$$
where $W \in \mathbb{R}^{m \times d}$ is the weight matrix of the hidden layer and $a \in \mathbb{R}^m$ the output weights. Let $\mu$ be the empirical distribution of $\{(a_i, w_i)\}_{i=1}^m \subset \mathbb{R}^{d+1}$. Then, we can write
$$f(x; \mu) = \mathbb{E}_{(a, w) \sim \mu}\{a\sigma(w \cdot x)\}. \tag{8}$$
By allowing $\mu$ to be an arbitrary sufficiently regular distribution over $\mathbb{R}^{d+1}$, we obtain a neural network, represented by (8), that can contain infinitely many neurons.
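A minimal sketch of (7) and (8), assuming ReLU activations and plain Python lists (all names are illustrative, not from the paper): the matrix form and the expectation over the empirical distribution $\mu$ of neurons give the same output, and the expectation form is unchanged when the neurons are reordered, since it only sees $\mu$.

```python
import random


def relu(z):
    return z if z > 0.0 else 0.0


def dot(u, v):
    return sum(p * q for p, q in zip(u, v))


def f_matrix(x, W, a):
    """Equation (7): (1/m) a^T relu(W x), with W given as a list of m rows."""
    m = len(W)
    return dot(a, [relu(dot(row, x)) for row in W]) / m


def f_mean_field(x, neurons):
    """Equation (8): E_{(a, w) ~ mu} a * relu(w . x) for the empirical measure mu."""
    return sum(a * relu(dot(w, x)) for a, w in neurons) / len(neurons)


rng = random.Random(0)
d, m = 4, 10
W = [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(m)]
a = [rng.gauss(0.0, 1.0) for _ in range(m)]
x = [rng.gauss(0.0, 1.0) for _ in range(d)]

neurons = list(zip(a, W))  # support points of the empirical distribution mu
shuffled = neurons[:]
rng.shuffle(shuffled)      # the same mu, listed in a different order
```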
To describe the gradient flow of this infinite-width network, it suffices to assign to $\mathbb{R}^{d+1}$ a vector field that describes how each neuron $(a, w) \in \mathbb{R}^{d+1}$ should move at time $t$. One simple heuristic way to do so is to first compute the gradient in the finite-width case, then replace all summations with expectations as in (8), and treat the gradient as a vector field. We now illustrate the idea in the realizable setting with the MSE loss $\mathcal{L} = \frac{1}{2}\mathbb{E}_x(f^*(x) - f(x))^2$. The theory extends to much more general settings and can be formally justified using the theory of Wasserstein gradient flows. Readers can refer to, for example, Chizat & Bach (2018) and Mei et al. (2018) for details. For a finite-width network (7), the gradient of $\mathcal{L}$ with respect to a neuron $(a_k, w_k)$ is
$$m\nabla_{a_k}\mathcal{L} = -\mathbb{E}_x\{(f^*(x) - f(x; W, a))\,\sigma(w_k \cdot x)\}, \qquad m\nabla_{w_k}\mathcal{L} = -\mathbb{E}_x\{(f^*(x) - f(x; W, a))\,a_k\sigma'(w_k \cdot x)\,x\}.$$
Replacing $f(x; W, a)$ with $f(x; \mu)$ and treating $(a_k, w_k)$ as a generic neuron, we obtain a vector field $\mathcal{V}: \mathbb{R}^{d+1} \to \mathbb{R}^{d+1}$,
$$\mathcal{V}(a, w) := \mathbb{E}_x\left\{(f^*(x) - f(x; \mu))\begin{pmatrix}\sigma(w \cdot x)\\ a\sigma'(w \cdot x)\,x\end{pmatrix}\right\}.$$
At each time $t$, we update the neurons in $\mu$ according to $\mathcal{V}$.

One of the most important properties of this mean-field formulation is that it factors out the permutation invariance of neurons. That is, we can permute $(a_1, w_1), \dots, (a_m, w_m)$ without changing the output of the network. However, when we treat training as an optimization problem over the space of $(a, W)$, i.e., $\mathbb{R}^m \times \mathbb{R}^{m \times d}$, permuting the $(a_i, w_i)$ changes the point $(a, W)$ entirely. On the other hand, if we describe the network using a distribution $\mu$ over $\mathbb{R}^{d+1}$, it is automatically permutation invariant. Note that this is not restricted to infinite-width networks: when we choose $\mu$ to be an empirical distribution over finitely many neurons, we recover a finite-width network without breaking the permutation invariance.

^6 Here, $w_i \in \mathbb{R}^d$ denotes the $i$-th row of $W$. Later we will use notations such as $v_i, a_i$ to denote the $i$-th row or column of the corresponding matrix. Whether it is a row or a column can be easily inferred from the dimensions.
The general rule is that if $V \in \mathbb{R}^{D \times m}$, where $m$ is the number of neurons, then $v_i \in \mathbb{R}^D$ is the $i$-th column, and if $W \in \mathbb{R}^{m \times D}$, then $w_i \in \mathbb{R}^D$ is the $i$-th row.

A.2 MULTI-LAYER MEAN-FIELD NETWORKS
Unfortunately, the above strategy cannot be directly generalized to multi-layer networks. Consider the three-layer network
$$f(x; a, W_2, W_1) = \frac{1}{m_2}a^\top\sigma(W_2 h(x; W_1)), \qquad h(x; W_1) = \frac{1}{m_1}\sigma(W_1 x),$$
where $a \in \mathbb{R}^{m_2}$, $W_2 \in \mathbb{R}^{m_2 \times m_1}$, $W_1 \in \mathbb{R}^{m_1 \times d}$. One can still write
$$f(x; a, W_2, W_1) = \frac{1}{m_2}\sum_{i=1}^{m_2}a_i\sigma(w_{2,i} \cdot h(x; W_1)) = \mathbb{E}_{(a, w_2) \sim \mu_2}\{a\sigma(w_2 \cdot h(x; W_1))\}.$$
However, $\mu_2$ is now a distribution over $\mathbb{R} \times \mathbb{R}^{m_1}$, and as $m_1 \to \infty$, it would become a distribution over $\mathbb{R}^\infty$, which is not readily defined. One way to resolve this issue is to view $W_2$ as a function from $[m_2] \times [m_1]$ to $\mathbb{R}$ and then handle the infinite-width case by replacing the index sets $[m_2], [m_1]$ with two general index sets $I_2, I_1$ that can potentially be uncountable. For example, we can choose $I_1 = I_2 = \mathbb{R}$. This is the strategy employed by Nguyen & Pham (2020) (see Pham & Nguyen (2021) for a more accessible version of this paper). The drawback of this formulation is that, with the introduction of index sets, the permutation invariance is no longer factored out. Although this formulation still makes it possible to obtain global convergence results for infinite-width networks, it becomes less useful for analyzing a finite-width network, as it becomes essentially the same as the usual matrix formulation.

We now present a formulation that does factor out the permutation invariance of neurons. It is built upon composing a sequence of vector-valued two-layer networks. As a first step, we consider a two-layer network with $D$-dimensional outputs:
$$f(x; A, W) = \frac{1}{m}A\sigma(Wx), \tag{9}$$
where $A \in \mathbb{R}^{D \times m}$ and $W \in \mathbb{R}^{m \times d}$. For each index $i \in [D]$, we still have
$$f_i(x; A, W) = \frac{1}{m}\sum_{j=1}^m a_{i,j}\sigma(w_j \cdot x) = \mathbb{E}_{(a, w) \sim \mu_i}\{a\sigma(w \cdot x)\},$$
where $\mu_i$ is the empirical distribution of $\{(a_{i,j}, w_j)\}_{j \in [m]} \subset \mathbb{R}^{d+1}$. Ranging over $i$, we obtain the output vector of this network.
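To make (9) concrete, here is a small sketch under illustrative assumptions (ReLU activation, random Gaussian weights, Python lists; all names are hypothetical): each output coordinate $f_i$ is an expectation over its own $\mu_i$, but every $\mu_i$ is built from the same first-layer rows $w_j$, which is exactly why the $w$-parts of the $\mu_i$ are coupled.

```python
import random


def relu(z):
    return z if z > 0.0 else 0.0


def dot(u, v):
    return sum(p * q for p, q in zip(u, v))


def f_vector_matrix(x, A, W):
    """Equation (9): (1/m) A relu(W x), with A of shape D x m and W of shape m x d."""
    m = len(W)
    hidden = [relu(dot(row, x)) for row in W]
    return [dot(a_row, hidden) / m for a_row in A]


def mu_i(A, W, i):
    """Support of the empirical measure mu_i = {(a_{i,j}, w_j)}_j; every mu_i reuses the same w_j."""
    return [(A[i][j], W[j]) for j in range(len(W))]


def f_coordinate(x, neurons):
    """f_i(x) = E_{(a, w) ~ mu_i} a * relu(w . x)."""
    return sum(a * relu(dot(w, x)) for a, w in neurons) / len(neurons)


rng = random.Random(1)
d, m, D = 3, 12, 2
W = [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(m)]
A = [[rng.gauss(0.0, 1.0) for _ in range(m)] for _ in range(D)]
x = [rng.gauss(0.0, 1.0) for _ in range(d)]
out = f_vector_matrix(x, A, W)
```

Because the second components of all the `mu_i(A, W, i)` coincide, the collection shares one marginal over $w$; sampling a finite network from such a configuration therefore draws the shared $w$ first and only then the $D$ output weights.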
For two-layer networks with scalar outputs, to obtain the mean-field counterpart, it suffices to allow $\mu$ to be a general distribution over $\mathbb{R} \times \mathbb{R}^d$. This, however, is not the case for networks with vector outputs, as the $W$ parts of the $\mu_i$ are coupled. Hence, we need to additionally impose the constraint that all $(\mu_i)_{i \in [D]}$ share the same second marginal, that is, $\pi_{2\#}\mu_i = \mu_W$ for some distribution $\mu_W$ over $\mathbb{R}^d$ and all $i \in [D]$, where $\pi_2: \mathbb{R} \times \mathbb{R}^d \to \mathbb{R}^d$ is the projection $(a, w) \mapsto w$. Intuitively, this condition says that the $\mu_i$ share the same first-layer weights $W$. We formalize this idea in the following definition.

Definition A.1. Let $(\mu_i)_{i \in [D]}$ be $D$ sufficiently regular^7 distributions over $\mathbb{R} \times \mathbb{R}^d$. We call $(\mu_i)_{i=1}^D$ an admissible configuration of dimension $(D, d)$ if there exists a measure $\mu_W$ over $\mathbb{R}^d$ such that $\pi_{2\#}\mu_i = \mu_W$ for all $i \in [D]$.

Remark. Here, by a neuron, we mean a $(D + d)$-dimensional vector $(a_1, \dots, a_D, w)$. In the finite-width network (9), this corresponds to a row of $W$ together with the corresponding column of $A$. This point of view is important when deriving the infinite-width gradient flow since, as in the two-layer case, the vector field at the position of a given neuron can depend on the other neurons only as a whole.

To complete the discussion, we now consider how to obtain a finite-width network with $m$ neurons from a given admissible infinite-width configuration $(\mu_i)_{i \in [D]}$. For a scalar-valued mean-field network characterized by $\mu$, it suffices to draw $m$ samples from $\mu$. For a vector-valued network, the procedure is slightly different.

^7 Our focus is on factoring out the permutation invariance and, in this paper, essentially all distributions are empirical distributions over finitely many neurons, with respect to which the integral is just a summation and is always well defined. We leave the work of identifying specific regularity conditions to future work.
We first sample a weight vector $w$ from the shared marginal $\mu_W$. Then, for each $i \in [D]$, we sample a real number $a_i$ from $\mu_i$ conditioned on $w$. This gives us a neuron $(a_1, \dots, a_D, w) \in \mathbb{R}^D \times \mathbb{R}^d$. Repeating this procedure $m$ times, we obtain a finite-width network with $m$ neurons. We formally define two-layer vector-valued mean-field networks as follows.

Definition A.2. Given an admissible $(\mu_i)_{i \in [D]}$, the two-layer vector-valued network it defines is
$$F(x; \mu_1, \dots, \mu_D) = (F_1(x; \mu_1), \dots, F_D(x; \mu_D)), \tag{10}$$
where $F_i(x; \mu_i) = \mathbb{E}_{(a, w) \sim \mu_i}\{a\sigma(w \cdot x)\}$, $i \in [D]$.

Now we are ready to define a multi-layer mean-field network. Basically, a multi-layer mean-field network is a composition of a sequence of two-layer vector-valued networks (10).

Definition A.3. Let $L \ge 1$ be an integer. Let $D^{(1)}, \dots, D^{(L)}$ be a sequence of positive integers and put $D^{(0)} = d$. For each $l \in [L]$, let $(\mu^{(l)}_i)_{i \in [D^{(l)}]}$ be an admissible configuration of dimension $(D^{(l)}, D^{(l-1)})$. The multi-layer mean-field network $f$ defined by the configuration $\Theta := ((\mu^{(l)}_i)_{i \in [D^{(l)}]})_{l \in [L]}$ is defined recursively as
$$f(x; \Theta) = F^{(L)}(x; \Theta), \qquad F^{(l)}(x; \Theta) := F\big(F^{(l-1)}(x; \Theta); \mu^{(l)}_1, \dots, \mu^{(l)}_{D^{(l)}}\big), \qquad F^{(0)}(x; \Theta) := x, \tag{11}$$
where $F$ is the two-layer mean-field network given by (10).

Example. As an example, consider a three-layer network, which corresponds to $L = 2$ in (11). In this case, the corresponding finite-width network is
$$f(x; A_2, W_2, A_1, W_1) = \frac{1}{m_2}A_2\sigma\Big(W_2\frac{1}{m_1}A_1\sigma(W_1 x)\Big),$$
which is exactly the usual multi-layer network used in practice, except for the normalizing factors $1/m_2, 1/m_1$ and an additional matrix $A_1 \in \mathbb{R}^{D_1 \times m_1}$. This matrix compresses an $m_1$-dimensional feature vector to a $D_1$-dimensional one, where $D_1$ is an integer that does not go to $\infty$. It is reminiscent of the bottleneck structure used in ResNet (He et al., 2016).

Remark. This formulation is indeed invariant under permutations of each layer's neurons. However, it does not factor out all permutation invariance of a deep network.
For example, one can permute the rows of $A_1$ and the columns of $W_2$ simultaneously without changing the output of the network. In some sense, this corresponds to permuting the entries of the hidden feature $F^{(1)}$. We believe it is neither necessary nor useful to factor out this symmetry since, after all, even in the two-layer case, we do not permute the entries of the input $x$.

Finally, we consider how to formulate the mean-field gradient flow so that it matches the usual gradient flow. The idea is simple: we compute the gradient in the finite-width setting and then replace summations with integrals. For ease of presentation, we consider a three-layer network and the MSE loss. Again, this framework can be easily generalized to deeper networks and other loss functions. We write
$$f(x) = f(x; a, W_2, V_1, W_1) = \frac{1}{m_2}a^\top\sigma(W_2 F(x)), \qquad F(x) = F(x; V_1, W_1) = \frac{1}{m_1}V_1\sigma(W_1 x),$$
$$\mathcal{L} = \mathcal{L}(a, W_2, V_1, W_1) = \frac{1}{2}\mathbb{E}_x\big(f^*(x) - f(x; a, W_2, V_1, W_1)\big)^2,$$
where $a \in \mathbb{R}^{m_2}$, $W_2 \in \mathbb{R}^{m_2 \times D}$, $V_1 \in \mathbb{R}^{D \times m_1}$, $W_1 \in \mathbb{R}^{m_1 \times d}$. We have
$$m_2\nabla_{a_i}\mathcal{L} = -\mathbb{E}_x\{(f^*(x) - f(x))\,\sigma(w_{2,i} \cdot F(x))\}, \quad i \in [m_2],$$
$$m_2\nabla_{w_{2,i}}\mathcal{L} = -\mathbb{E}_x\{(f^*(x) - f(x))\,a_i\sigma'(w_{2,i} \cdot F(x))\,F(x)\}, \quad i \in [m_2],$$
$$m_1\nabla_{v_{1,i}}\mathcal{L} = -\mathbb{E}_x\Big\{(f^*(x) - f(x))\frac{1}{m_2}\sum_{j=1}^{m_2}a_j\sigma'(w_{2,j} \cdot F(x))\,w_{2,j}\,\sigma(w_{1,i} \cdot x)\Big\}, \quad i \in [m_1],$$
$$m_1\nabla_{w_{1,i}}\mathcal{L} = -\mathbb{E}_x\Big\{(f^*(x) - f(x))\frac{1}{m_2}\sum_{j=1}^{m_2}a_j\sigma'(w_{2,j} \cdot F(x))\,\langle w_{2,j}, v_{1,i}\rangle\,\sigma'(w_{1,i} \cdot x)\,x\Big\}, \quad i \in [m_1].$$
Replacing summations with integrals, we obtain
$$\mathcal{V}_{(a, w_2)} = \mathbb{E}_x\left\{(f^*(x) - f(x))\begin{pmatrix}\sigma(w_2 \cdot F(x))\\ a\sigma'(w_2 \cdot F(x))\,F(x)\end{pmatrix}\right\},$$
$$\mathcal{V}_{(v_1, w_1)} = \mathbb{E}_x\left\{(f^*(x) - f(x))\,\mathbb{E}_{(a, w_2) \sim \mu_2}\left[a\sigma'(w_2 \cdot F(x))\begin{pmatrix}\sigma(w_1 \cdot x)\,w_2\\ \langle w_2, v_1\rangle\,\sigma'(w_1 \cdot x)\,x\end{pmatrix}\right]\right\}. \tag{12}$$
Namely, at each time $t$, we update the second-layer neurons $(a, w_2)$ with $\mathcal{V}_{(a, w_2)}$ and the first-layer neurons $(v_1, w_1)$ with $\mathcal{V}_{(v_1, w_1)}$. Note that, unlike many other multi-layer mean-field frameworks, we do not introduce any notion of paths. The dynamics of each first-layer neuron depends on the second layer only as a whole, since we take the expectation over $\mu_2$ in (12). The same is true for second-layer neurons. In some sense, the additional matrix $V_1$ decouples the dynamics of the first- and second-layer neurons.
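The closed-form gradients above can be checked numerically. The sketch below is a hedged illustration rather than code from the paper: it assumes ReLU activations, a small random network in the parameterization $f = \frac{1}{m_2}a^\top\sigma(W_2 F)$, $F = \frac{1}{m_1}V_1\sigma(W_1 x)$, a finite sample standing in for $\mathbb{E}_x$, and the target $f^*(x) = \mathrm{ReLU}(1 - \|x\|)$. It compares $m_1\nabla_{v_{1,i}}\mathcal{L}$ from the formula with a central finite difference.

```python
import math
import random


def relu(z):
    return z if z > 0.0 else 0.0


def relu_grad(z):
    return 1.0 if z > 0.0 else 0.0


def dot(u, v):
    return sum(p * q for p, q in zip(u, v))


def forward(x, a, W2, V1, W1):
    """Return f(x), F(x) and the first-layer activations relu(W1 x)."""
    m1, m2, D = len(W1), len(W2), len(V1)
    act1 = [relu(dot(w, x)) for w in W1]
    F = [sum(V1[k][i] * act1[i] for i in range(m1)) / m1 for k in range(D)]
    f = sum(a[j] * relu(dot(W2[j], F)) for j in range(m2)) / m2
    return f, F, act1


def loss(data, target, params):
    return 0.5 * sum((target(x) - forward(x, *params)[0]) ** 2 for x in data) / len(data)


def m1_grad_v1_formula(data, target, params, i):
    """-E_x (f* - f) (1/m2) sum_j a_j relu'(w_{2,j} . F) w_{2,j} relu(w_{1,i} . x)."""
    a, W2, V1, W1 = params
    m2, D = len(W2), len(V1)
    g = [0.0] * D
    for x in data:
        f, F, act1 = forward(x, *params)
        res = target(x) - f
        for j in range(m2):
            s = a[j] * relu_grad(dot(W2[j], F)) / m2
            for k in range(D):
                g[k] -= res * s * W2[j][k] * act1[i]
    return [gk / len(data) for gk in g]


def m1_grad_v1_fd(data, target, params, i, h=1e-6):
    """Central finite difference in the i-th column of V1, scaled by m1."""
    V1, m1 = params[2], len(params[3])
    out = []
    for k in range(len(V1)):
        old = V1[k][i]
        V1[k][i] = old + h
        lp = loss(data, target, params)
        V1[k][i] = old - h
        lm = loss(data, target, params)
        V1[k][i] = old  # restore the parameter
        out.append(m1 * (lp - lm) / (2 * h))
    return out


def target(x):
    return relu(1.0 - math.sqrt(dot(x, x)))


rng = random.Random(2)
d, m1, D, m2 = 3, 6, 2, 5
W1 = [[rng.gauss(0, 1) for _ in range(d)] for _ in range(m1)]
V1 = [[rng.gauss(0, 1) for _ in range(m1)] for _ in range(D)]
W2 = [[rng.gauss(0, 1) for _ in range(D)] for _ in range(m2)]
a = [rng.gauss(0, 1) for _ in range(m2)]
params = (a, W2, V1, W1)
data = [[rng.gauss(0, 0.7) for _ in range(d)] for _ in range(20)]
```

The same pattern (perturb one parameter, difference the loss, rescale by the layer width) can be reused to check the $w_{1,i}$, $a_i$ and $w_{2,i}$ formulas.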
B PRELIMINARIES

B.1 INDUCTION HYPOTHESIS AND CONTINUITY ARGUMENT
We extensively use the continuous-time version of mathematical induction in our proof, which is also called the continuity argument. We briefly discuss this technique in this subsection and explain some conventions we employ in writing the proof. One may refer to, for example, Chapter 1.3 of Tao (2006) for details.

Similar to discrete-time induction, the goal is to maintain a collection of conditions, which we call the Induction Hypothesis, throughout a period of time (cf. Induction Hypothesis C.2 and Induction Hypothesis D.1). There are mainly two types of conditions.

The first type has the form "a certain process $A_t$ is bounded by another process $B_t$". In the proof, $A_t$ is usually the error we want to control and $B_t$ a non-decreasing process representing the corresponding upper bound. To maintain this type of condition, it suffices to show that $A_t \le B_t$ at initialization and $\dot A_t \le \dot B_t$ as long as the Induction Hypothesis is true. For this type of condition, we usually also have an upper bound for $B_t$, say, $B_t \le B^*$. The most rigorous way to maintain these bounds is to argue by contradiction. Let $T$ be the minimum of the time $T_1$ at which the process ends and the time $T_2$ at which this bound is first violated. By definition, the Induction Hypothesis holds for all $t \le T$. Using the Induction Hypothesis, one can then derive an upper bound $\bar T$ on $T_1$, and hence on $T$. Then, since $B_t$ is non-decreasing, all we need to show is that $B_{\bar T}$ is smaller than $B^*$, so that $T$ is attained by $T_1$ instead of $T_2$. For ease of presentation, for this type of condition, instead of arguing by contradiction explicitly, we will simply show that, provided the Induction Hypothesis is true over $[0, T_1]$, $B_{T_1} < B^*$ holds.

The second type has the form "a certain process $C_t$ is bounded by some value $D$". Here, $C_t$ is usually some quantity related to the shape of the learner function, such as $\bar w_2$ and $\alpha$.
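In order to maintain, say, $C_t \le D$, we show that $\dot C_t < 0$ whenever $C_t \in [D - \varepsilon, D]$. Since $C_t$ is continuous, this implies that $C_t$ can never reach $D$ (provided $C_0 < D$).

B.2 PROPERTIES OF THE INPUT DISTRIBUTION
In this subsection, we derive some basic properties of the input distribution that will be useful in later analysis. The following lemma gives the distribution of $\|x\|$ and its tail bound.

Lemma B.1. Let $x \sim \mathcal{D}$ and let $\mathcal{D}_{\|x\|}$ denote the distribution of $\|x\|$. We have
$$\mathcal{D}_{\|x\|}(r) = \frac{d}{r}J^2_{d/2}\big(2\pi R_d\beta\alpha\sqrt{d}\,r\big).$$
As a result, we have the tail bound: for all $R > 0$, $\mathbb{P}[\|x\| \ge R] \le O(1/R)$.

We now give some regularity conditions on the input distribution that will be used in our proof. Roughly speaking, they show that the distribution is heavy-tailed and still has large enough mass for $\|x\| \in [0, 1]$. Here, $\mathbb{E}_{\|x\| \le R}$ denotes the expectation restricted to the event $\{\|x\| \le R\}$, i.e., $\mathbb{E}_x[\,\cdot\,\mathbb{1}\{\|x\| \le R\}]$.

Lemma B.2 (Regularity conditions on input distribution). For the input distribution $\mathcal{D}$, we have
(a) $\mathbb{E}_{\|x\| \le 0.99}\|x\| = \Theta(1)$.
(b) $\mathbb{E}_{x \sim \mathcal{D}}f^*(x) = \Omega(1)$.
(c) $\mathbb{E}_{\|x\| \le \Theta(d)}\|x\| = \Omega(\log d)$ and $\mathbb{E}_{\|x\| \le \mathrm{poly}(d)}\|x\| = O(\log d)$.

Proof of Lemma B.1. Recall that the input distribution is
$$\mathcal{D}(x) = (\beta\alpha\sqrt{d})^d\,\varphi^2(\beta\alpha\sqrt{d}\,x),$$
where $\alpha, \beta > 0$ are the universal constants from Safran et al. (2019) (cf. the proof of Theorem 5),
$$\varphi(x) = \Big(\frac{R_d}{\|x\|}\Big)^{d/2}J_{d/2}(2\pi R_d\|x\|), \quad x \in \mathbb{R}^d, \qquad R_d = \frac{1}{\sqrt{\pi}}\big(\Gamma(d/2 + 1)\big)^{1/d} = \Theta(\sqrt{d})$$
(Lemma 5 in Eldan & Shamir (2016)), and $J_\nu$ is the Bessel function of the first kind of order $\nu$. Since $\varphi$ only depends on $\|x\|$, we abuse notation and write $\varphi(r)$ for $\varphi(x)$ with $\|x\| = r$. For any test function $g: \mathbb{R} \to \mathbb{R}$, we have
$$\mathbb{E}_{x \sim \mathcal{D}}[g(\|x\|)] = \int_{\mathbb{R}^d}g(\|x\|)(\beta\alpha\sqrt{d})^d\varphi^2(\beta\alpha\sqrt{d}\,x)\,dx = |\mathbb{S}^{d-1}|\int_0^\infty g(r)(\beta\alpha\sqrt{d})^d\varphi^2(\beta\alpha\sqrt{d}\,r)\,r^{d-1}\,dr,$$
where $|\mathbb{S}^{d-1}| = 2\pi^{d/2}/\Gamma(d/2)$ is the surface area of the unit sphere $\mathbb{S}^{d-1}$. Therefore, the density of $\|x\|$ at $\|x\| = r$ is
$$|\mathbb{S}^{d-1}|(\beta\alpha\sqrt{d})^d\varphi^2(\beta\alpha\sqrt{d}\,r)\,r^{d-1} = \frac{2\pi^{d/2}}{\Gamma(d/2)}R_d^d\,\frac{1}{r}J^2_{d/2}(2\pi R_d\beta\alpha\sqrt{d}\,r) = \frac{d}{r}J^2_{d/2}(2\pi R_d\beta\alpha\sqrt{d}\,r) \le O\Big(\frac{1}{r^2}\Big),$$
where in the last step we use the fact that $J_\nu(z) = O(1/\sqrt{z})$ (Krasikov, 2006) and $R_d\sqrt{d} = \Theta(d)$. Then, it is easy to see that $\mathbb{P}(\|x\| \ge R) = O(1/R)$.

Proof of Lemma B.2. (a) It is easy to see the upper bound $\mathbb{E}_{\|x\| \le 0.99}\|x\| \le 0.99$. For the lower bound, note that $\mathbb{E}_{\|x\| \le 0.99}\|x\| \ge 0.1\,\mathbb{P}(0.1 \le \|x\| \le 0.99)$.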
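Hence, it suffices to lower bound $\mathbb{P}(0.1 \le \|x\| \le 0.99)$. Substituting $z = 2\pi R_d\beta\alpha\sqrt{d}\,r$, we have
$$\mathbb{P}(0.1 \le \|x\| \le 0.99) = \int_{0.1}^{0.99}\frac{d}{r}J^2_{d/2}(2\pi R_d\beta\alpha\sqrt{d}\,r)\,dr = \int_{0.2\pi R_d\beta\alpha\sqrt{d}}^{1.98\pi R_d\beta\alpha\sqrt{d}}\frac{d}{z}J^2_{d/2}(z)\,dz \ge \Omega(1)\int_{0.2\pi R_d\beta\alpha\sqrt{d}}^{1.98\pi R_d\beta\alpha\sqrt{d}}J^2_{d/2}(z)\,dz = \Omega(1),$$
where the inequality uses $z \le 1.98\pi R_d\beta\alpha\sqrt{d} = O(d)$, and in the last step we use Lemma 23 in Eldan & Shamir (2016). This implies that $\mathbb{E}_{\|x\| \le 0.99}\|x\| = \Omega(1)$. Together with the upper bound, we have $\mathbb{E}_{\|x\| \le 0.99}\|x\| = \Theta(1)$.

(b) We have
$$\mathbb{E}_{x \sim \mathcal{D}}f^*(x) = \mathbb{E}_x[1 - \|x\|]_+ \ge \mathbb{E}_{\|x\| \le 0.99}[1 - \|x\|] \ge 0.01\,\mathbb{P}(\|x\| \le 0.99) \ge 0.01\,\mathbb{P}(0.1 \le \|x\| \le 0.99) = \Omega(1),$$
where in the last step we use the calculation in (a).

(c) The upper bound follows directly from the density bound $\mathcal{D}_{\|x\|}(r) \le O(1/r^2)$. For the lower bound, recall that the density of $\|x\|$ at $\|x\| = r$ is $\frac{d}{r}J^2_{d/2}(2\pi R_d\beta\alpha\sqrt{d}\,r)$. For notational simplicity, put $R_D = \Theta(d)$. We have
$$\mathbb{E}_{\|x\| \le R_D}\|x\| = \int_0^{R_D}d\,J^2_{d/2}(2\pi R_d\beta\alpha\sqrt{d}\,r)\,dr = \frac{d}{2\pi R_d\beta\alpha\sqrt{d}}\int_0^{2\pi R_d R_D\beta\alpha\sqrt{d}}J^2_{d/2}(z)\,dz \ge \Omega(1)\int_{cd}^{cd^2}J^2_{d/2}(z)\,dz,$$
where $c$ is a large enough constant. To lower bound $\mathbb{E}_{\|x\| \le R_D}\|x\|$, it therefore suffices to lower bound $\int_{cd}^{cd^2}J^2_{d/2}(z)\,dz$. In the following, we follow a calculation similar to the one in Lemma 23 of Eldan & Shamir (2016). From the proof of that lemma, for $x \ge d/2$ we have
$$J^2_{d/2}(x) \ge \frac{2}{\pi x}\cos^2\Big(f_{d,x}x - \frac{(d+1)\pi}{4}\Big) - \frac{3}{x^2},$$
where $f_{d,x}$ is a quantity that depends on $d$ and $x$ and satisfies $0.85 \le f_{d,x} \le 1.3$. Then, we have
$$\int_{cd}^{cd^2}J^2_{d/2}(x)\,dx \ge \frac{2}{\pi}\int_{cd}^{cd^2}\frac{1}{x}\cos^2\Big(f_{d,x}x - \frac{(d+1)\pi}{4}\Big)dx - \frac{3(d-1)}{cd^2}.$$
Using the properties of $f_{d,x}$ established in the proof of Lemma 23 in Eldan & Shamir (2016) together with $0.85 \le f_{d,x} \le 1.3$, the change of variables $z = f_{d,x}x$ gives
$$\int_{cd}^{cd^2}\frac{1}{x}\cos^2\Big(f_{d,x}x - \frac{(d+1)\pi}{4}\Big)dx \ge \Omega(1)\int_{1.3cd}^{0.85cd^2}\frac{1}{z}\cos^2\Big(z - \frac{(d+1)\pi}{4}\Big)dz.$$
Then, using integration by parts and the fact that $\cos^2(z - (d+1)\pi/4) = \partial_z\big(z/2 + \sin(2z - (d+1)\pi/2)/4\big)$, we have
$$\int_{1.3cd}^{0.85cd^2}\frac{1}{z}\cos^2\Big(z - \frac{(d+1)\pi}{4}\Big)dz = \Big[\frac{1}{2} + \frac{\sin(2z - (d+1)\pi/2)}{4z}\Big]_{1.3cd}^{0.85cd^2} + \int_{1.3cd}^{0.85cd^2}\Big(\frac{1}{2z} + \frac{\sin(2z - (d+1)\pi/2)}{4z^2}\Big)dz$$
$$\ge -\frac{1}{4}\Big(\frac{1}{0.85cd^2} + \frac{1}{1.3cd}\Big) - \frac{1}{4 \cdot 1.3cd} + \frac{1}{2}\ln\frac{0.85cd^2}{1.3cd} = \Omega(\log d).$$
Therefore, since the subtracted term $3(d-1)/(cd^2)$ is $O(1/d)$, we have $\int_{cd}^{cd^2}J^2_{d/2}(x)\,dx = \Omega(\log d)$, which implies $\mathbb{E}_{\|x\| \le \Theta(d)}\|x\| = \Omega(\log d)$.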
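For readability, here is the constant bookkeeping behind the density formula in Lemma B.1 and the substitution used in (a) and (c). It uses only the definition of $R_d$ and assumes the scaling $(\beta\alpha\sqrt{d})^d\varphi^2(\beta\alpha\sqrt{d}\,x)$ of the input density recalled in the proof of Lemma B.1:
$$R_d^d = \pi^{-d/2}\,\Gamma\Big(\frac{d}{2} + 1\Big) = \pi^{-d/2}\,\frac{d}{2}\,\Gamma\Big(\frac{d}{2}\Big) \quad\Longrightarrow\quad \frac{2\pi^{d/2}}{\Gamma(d/2)}\,R_d^d = d,$$
$$\frac{2\pi^{d/2}}{\Gamma(d/2)}(\beta\alpha\sqrt{d})^d\,\varphi^2(\beta\alpha\sqrt{d}\,r)\,r^{d-1} = \frac{2\pi^{d/2}}{\Gamma(d/2)}\,\frac{R_d^d}{r^d}\,J^2_{d/2}(2\pi R_d\beta\alpha\sqrt{d}\,r)\,r^{d-1} = \frac{d}{r}\,J^2_{d/2}(2\pi R_d\beta\alpha\sqrt{d}\,r),$$
$$z = 2\pi R_d\beta\alpha\sqrt{d}\,r \quad\Longrightarrow\quad \frac{d}{r}\,dr = \frac{d}{z}\,dz, \qquad d\,dr = \frac{d}{2\pi R_d\beta\alpha\sqrt{d}}\,dz = \Theta(1)\,dz.$$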
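B.3 PROPERTIES OF SPHERICALLY SYMMETRIC FUNCTIONS AND DISTRIBUTIONS
In this subsection, we give some useful properties of spherically symmetric functions and distributions, which will serve as tools in our later analysis. Basically, these lemmas allow us to disentangle the input $x$ and the neuron $v$ when integrating against a spherically symmetric function.

Lemma 3.1. Let $\mu$ be a spherically symmetric distribution. We have
$$\mathbb{E}_{w \sim \mu}\|w\|\sigma(w \cdot x) = C_\Gamma\,\frac{\mathbb{E}_{w \sim \mu}\|w\|^2}{\sqrt{d}}\,\|x\|, \qquad\text{where } C_\Gamma := \frac{\Gamma(d/2)\sqrt{d}}{2\sqrt{\pi}\,\Gamma((d+1)/2)}.$$
Note that $C_\Gamma \to 1/\sqrt{2\pi}$ as $d \to \infty$, so $C_\Gamma$ is bounded by a universal constant for all $d$.

Lemma B.3. For any spherically symmetric $g: \mathbb{R}^d \to \mathbb{R}$ and $v \in \mathbb{R}^d$, we have
$$\mathbb{E}_x\{g(x)\sigma(v \cdot x)\} = \frac{C_\Gamma}{\sqrt{d}}\,\mathbb{E}_x\{g(x)\|x\|\}\,\|v\|.$$

Corollary B.4. Let $g: \mathbb{R}^d \to \mathbb{R}$ be a spherically symmetric function. We have $\mathbb{E}_x\{g(x)F(x)\} = \alpha\,\mathbb{E}_x\{g(x)\|x\|\}$.

Lemma B.5. Let $g: \mathbb{R}^d \to \mathbb{R}$ be a spherically symmetric function. Then, for any $v \in \mathbb{R}^d$, we have
$$\mathbb{E}_{x \sim \mathcal{D}}\{g(x)\sigma'(v \cdot x)\,x\} = \mathbb{E}_{x \sim \mathcal{D}}\{g(x)\|x\|\}\,\frac{C_\Gamma}{\sqrt{d}}\,\bar v, \qquad\text{where } \bar v := v/\|v\|.$$

Proof of Lemma 3.1. For simplicity, put $g(x) = \mathbb{E}_{w \sim \mu}\|w\|\sigma(w \cdot x)$. Since $\sigma$ is 1-homogeneous and $\mu$ is spherically symmetric, writing $w = r\bar w$ with $\bar w \in \mathbb{S}^{d-1}$, we have
$$g(x) = \int_{\mathbb{R}^d}\|w\|^2\sigma(\bar w \cdot x)\,\mu(w)\,dw = \int_0^\infty\int_{\mathbb{S}^{d-1}}r^2\sigma(\bar w \cdot x)\,\mu(r\bar w)\,r^{d-1}\,d\sigma_{d-1}(\bar w)\,dr = \int_0^\infty r^{d+1}\mu(r)\,dr\int_{\mathbb{S}^{d-1}}\sigma(\bar w \cdot x)\,d\sigma_{d-1}(\bar w).$$
For the first factor, note that^8
$$\mathbb{E}_{w \sim \mu}\|w\|^2 = \int_{\mathbb{R}^d}\|w\|^2\mu(w)\,dw = \int_0^\infty\int_{\mathbb{S}^{d-1}}r^2\mu(r\bar w)\,r^{d-1}\,d\sigma_{d-1}(\bar w)\,dr = \frac{2\pi^{d/2}}{\Gamma(d/2)}\int_0^\infty r^{d+1}\mu(r)\,dr.$$
Hence,
$$\int_0^\infty r^{d+1}\mu(r)\,dr = \frac{\Gamma(d/2)}{2\pi^{d/2}}\,\mathbb{E}_{w \sim \mu}\|w\|^2.$$
We compute the second factor as follows. Since it is also spherically symmetric in $x$, we have
$$\int_{\mathbb{S}^{d-1}}\sigma(\bar w \cdot x)\,d\sigma_{d-1}(\bar w) = \|x\|\int_{\mathbb{S}^{d-1}}\sigma(\bar w_1)\,d\sigma_{d-1}(\bar w) = \frac{\|x\|}{2}\int_{\mathbb{S}^{d-1}}|\bar w_1|\,d\sigma_{d-1}(\bar w).$$
Define $I = \int_{\mathbb{R}^d}|w_1|e^{-\|w\|^2}\,dw$. On the one hand,
$$I = \int_{\mathbb{R}}|w_1|e^{-w_1^2}\,dw_1\prod_{i=2}^d\int_{\mathbb{R}}e^{-w_i^2}\,dw_i = \pi^{(d-1)/2}.$$
On the other hand,
$$I = \int_0^\infty\int_{\mathbb{S}^{d-1}}r|\bar w_1|e^{-r^2}r^{d-1}\,d\sigma_{d-1}(\bar w)\,dr = \int_0^\infty e^{-r^2}r^d\,dr\int_{\mathbb{S}^{d-1}}|\bar w_1|\,d\sigma_{d-1}(\bar w) = \frac{\Gamma((d+1)/2)}{2}\int_{\mathbb{S}^{d-1}}|\bar w_1|\,d\sigma_{d-1}(\bar w).$$
Therefore,
$$\int_{\mathbb{S}^{d-1}}\sigma(\bar w \cdot x)\,d\sigma_{d-1}(\bar w) = \frac{\|x\|}{2}\int_{\mathbb{S}^{d-1}}|\bar w_1|\,d\sigma_{d-1}(\bar w) = \frac{\pi^{(d-1)/2}}{\Gamma((d+1)/2)}\,\|x\|. \tag{13}$$
Combining the two factors, we obtain
$$g(x) = \frac{\Gamma(d/2)}{2\pi^{d/2}}\,\mathbb{E}_{w \sim \mu}\|w\|^2\cdot\frac{\pi^{(d-1)/2}}{\Gamma((d+1)/2)}\,\|x\| = C_\Gamma\,\frac{\mathbb{E}_{w \sim \mu}\|w\|^2}{\sqrt{d}}\,\|x\|.$$

Proof of Lemma B.3.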
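We compute
$$\mathbb{E}_{x \sim \mathcal{D}}\{g(x)\sigma(v \cdot x)\} = \int_{\mathbb{R}^d}g(x)\sigma(v \cdot x)\,\mathcal{D}(x)\,dx = \int_0^\infty\int_{\mathbb{S}^{d-1}}g(r\bar x)\,\sigma(v \cdot (r\bar x))\,\mathcal{D}(r\bar x)\,r^{d-1}\,d\sigma_{d-1}(\bar x)\,dr$$
$$= \int_0^\infty\int_{\mathbb{S}^{d-1}}g(r)\,\sigma(v \cdot \bar x)\,\mathcal{D}(r)\,r^d\,d\sigma_{d-1}(\bar x)\,dr = \int_0^\infty g(r)\mathcal{D}(r)r^d\,dr\int_{\mathbb{S}^{d-1}}\sigma(v \cdot \bar x)\,d\sigma_{d-1}(\bar x) = \int_0^\infty g(r)\mathcal{D}(r)r^d\,dr\cdot\frac{\pi^{(d-1)/2}}{\Gamma((d+1)/2)}\,\|v\|,$$
where the last line comes from (13) (note that here the integral is taken w.r.t. $\bar x$ instead of $\bar w$). For the first factor, note that
$$\mathbb{E}_{x \sim \mathcal{D}}\{g(x)\|x\|\} = \int_{\mathbb{R}^d}g(x)\|x\|\,\mathcal{D}(x)\,dx = \int_0^\infty\int_{\mathbb{S}^{d-1}}g(r)\mathcal{D}(r)\,r^d\,d\sigma_{d-1}(\bar x)\,dr = \frac{2\pi^{d/2}}{\Gamma(d/2)}\int_0^\infty g(r)\mathcal{D}(r)\,r^d\,dr.$$
Hence,
$$\mathbb{E}_{x \sim \mathcal{D}}\{g(x)\sigma(v \cdot x)\} = \mathbb{E}_{x \sim \mathcal{D}}\{g(x)\|x\|\}\,\frac{\Gamma(d/2)}{2\pi^{d/2}}\,\frac{\pi^{(d-1)/2}}{\Gamma((d+1)/2)}\,\|v\| = \mathbb{E}_{x \sim \mathcal{D}}\{g(x)\|x\|\}\,\frac{C_\Gamma}{\sqrt{d}}\,\|v\|.$$

^8 Recall that the surface area of the unit sphere $\mathbb{S}^{d-1} \subset \mathbb{R}^d$ is $\int d\sigma_{d-1} = 2\pi^{d/2}/\Gamma(d/2)$.

Proof of Corollary B.4. By Lemma B.3, we have
$$\mathbb{E}_x\{g(x)F(x)\} = \mathbb{E}_x\Big\{g(x)\,\mathbb{E}_{w \sim \mu_1}\{\|w\|\sigma(w \cdot x)\}\Big\} = \mathbb{E}_{w \sim \mu_1}\Big\{\|w\|\,\mathbb{E}_x\{g(x)\sigma(w \cdot x)\}\Big\} = C_\Gamma\,\frac{\mathbb{E}_{w \sim \mu_1}\|w\|^2}{\sqrt{d}}\,\mathbb{E}_x\{g(x)\|x\|\} = \alpha\,\mathbb{E}_x\{g(x)\|x\|\}.$$

Proof of Lemma B.5. Define $R = \bar v\bar v^\top - (I_d - \bar v\bar v^\top) = 2\bar v\bar v^\top - I_d$, that is, the reflection matrix associated with $v$. Since $\mathcal{D}$ is spherically symmetric, we have $R_\#\mathcal{D} = \mathcal{D}$. For the same reason, $g \circ R = g$. Moreover, by construction, $Rv = v$ and $R = R^\top$. Hence,
$$\mathbb{E}_{x \sim \mathcal{D}}\{g(x)\sigma'(v \cdot x)x\} = \frac{1}{2}\Big(\mathbb{E}_{x \sim \mathcal{D}}\{g(x)\sigma'(v \cdot x)x\} + \mathbb{E}_{x \sim R_\#\mathcal{D}}\{g(x)\sigma'(v \cdot x)x\}\Big) = \frac{1}{2}\mathbb{E}_{x \sim \mathcal{D}}\{g(x)\sigma'(v \cdot x)x + g(Rx)\sigma'(v \cdot Rx)Rx\}$$
$$= \frac{1}{2}\mathbb{E}_{x \sim \mathcal{D}}\{g(x)\sigma'(v \cdot x)x + g(x)\sigma'(Rv \cdot x)Rx\} = \frac{1}{2}\mathbb{E}_{x \sim \mathcal{D}}\{g(x)\sigma'(v \cdot x)(x + Rx)\}.$$
Note that $x + Rx = 2\bar v\bar v^\top x = 2\langle\bar v, x\rangle\bar v$. Hence, using $\sigma'(v \cdot x)\langle\bar v, x\rangle = \sigma(\bar v \cdot x)$ for the ReLU,
$$\mathbb{E}_{x \sim \mathcal{D}}\{g(x)\sigma'(v \cdot x)x\} = \mathbb{E}_{x \sim \mathcal{D}}\{g(x)\sigma(\bar v \cdot x)\}\,\bar v = \mathbb{E}_{x \sim \mathcal{D}}\{g(x)\|x\|\}\,\frac{C_\Gamma}{\sqrt{d}}\,\bar v,$$
where the second identity comes from Lemma B.3.

B.4 THE INFINITE-WIDTH NETWORK REMAINS SPHERICALLY SYMMETRIC
In this subsection, we show that the infinite-width network remains spherically symmetric throughout the whole process. Clearly, $\mu_1$ is spherically symmetric at initialization. Now, assume that it is spherically symmetric at time $t$. We claim that $v_1$ does not move tangentially, and its radial speed does not depend on its direction $\bar v_1$. That is, $\dot v_1 = h(\|v_1\|)\,\bar v_1$ for some function $h$.

By our induction hypothesis, $S$ is also spherically symmetric at time $t$. Let $T := 2\bar v_1\bar v_1^\top - I_d$ be the reflection w.r.t. $\bar v_1$.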
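Clearly, $T\bar v_1 = \bar v_1$. Moreover, $T$ preserves norms and, as a result, $S(Tx) = S(x)$, $T_\#\mathcal{D} = \mathcal{D}$ and $\Pi T = T\Pi$. For brevity, write $G(x) := S(x)\big(\bar v_1\sigma(v_1 \cdot x) + \|v_1\|\sigma'(v_1 \cdot x)\,x\big)$. Hence, we have
$$\dot v_1 = \mathbb{E}_{x \sim \mathcal{D}}\Pi_{R_{v_1}}[G(x)] = \frac{1}{2}\mathbb{E}_{x \sim \mathcal{D}}\Pi_{R_{v_1}}[G(x)] + \frac{1}{2}\mathbb{E}_{x \sim T_\#\mathcal{D}}\Pi_{R_{v_1}}[G(x)].$$
For the second term, we have
$$\mathbb{E}_{x \sim T_\#\mathcal{D}}\Pi_{R_{v_1}}[G(x)] = \mathbb{E}_{x \sim \mathcal{D}}\Pi_{R_{v_1}}\big[S(Tx)\big(\bar v_1\sigma(v_1 \cdot Tx) + \|v_1\|\sigma'(v_1 \cdot Tx)\,Tx\big)\big] = \mathbb{E}_{x \sim \mathcal{D}}\Pi_{R_{v_1}}\big[S(x)\big(\bar v_1\sigma(v_1 \cdot x) + \|v_1\|\sigma'(v_1 \cdot x)\,Tx\big)\big]$$
$$= \mathbb{E}_{x \sim \mathcal{D}}\Pi_{R_{v_1}}\big[S(x)\,T\big(\sigma(v_1 \cdot x)\,\bar v_1 + \|v_1\|\sigma'(v_1 \cdot x)\,x\big)\big] = \mathbb{E}_{x \sim \mathcal{D}}T\,\Pi_{R_{v_1}}[G(x)].$$
Thus, since $\frac{1}{2}(I_d + T) = \bar v_1\bar v_1^\top$,
$$\dot v_1 = \frac{1}{2}(I_d + T)\,\mathbb{E}_{x \sim \mathcal{D}}\Pi_{R_{v_1}}[G(x)] = \big\langle\bar v_1, \mathbb{E}_{x \sim \mathcal{D}}\Pi_{R_{v_1}}[G(x)]\big\rangle\,\bar v_1.$$
Namely, $\dot v_1 = h(v_1)\,\bar v_1$, where $h(v_1) = \big\langle\bar v_1, \mathbb{E}_{x \sim \mathcal{D}}\Pi_{R_{v_1}}[G(x)]\big\rangle$.

Now, we show that $h$ is spherically symmetric to complete the proof. Let $R$ be an arbitrary rotation matrix. Using $\|Rv_1\| = \|v_1\|$, $S(Rx) = S(x)$, $R^\top_\#\mathcal{D} = \mathcal{D}$ and $\Pi R = R\Pi$, we have
$$h(Rv_1) = \big\langle R\bar v_1, \mathbb{E}_{x \sim \mathcal{D}}\Pi_{R_{v_1}}\big[S(x)\big(R\bar v_1\sigma(Rv_1 \cdot x) + \|v_1\|\sigma'(Rv_1 \cdot x)\,x\big)\big]\big\rangle$$
$$= \big\langle R\bar v_1, \mathbb{E}_{x \sim \mathcal{D}}\Pi_{R_{v_1}}\big[S(x)\big(R\bar v_1\sigma(v_1 \cdot R^\top x) + \|v_1\|\sigma'(v_1 \cdot R^\top x)\,RR^\top x\big)\big]\big\rangle$$
$$= \big\langle R\bar v_1, \mathbb{E}_{x \sim \mathcal{D}}\Pi_{R_{v_1}}\big[S(x)\big(R\bar v_1\sigma(v_1 \cdot x) + \|v_1\|\sigma'(v_1 \cdot x)\,Rx\big)\big]\big\rangle = \big\langle R\bar v_1, R\,\mathbb{E}_{x \sim \mathcal{D}}\Pi_{R_{v_1}}[G(x)]\big\rangle = h(v_1).$$
Thus, $h$ is spherically symmetric.

The goal of Stage 1 is for all $v_2$ to decrease to $-\Theta(1/R_{v_2})$, so that we can ignore all projection operators in the dynamics of $r_2$, $v_1$ and $v_2$. We split Stage 1 into three substages, in which $v_2$ decreases to $0$, $-\mathrm{poly}(d)\delta_2$ and $-\Theta(1/R_{v_2})$, respectively. By Lemma C.3, at the end of each substage, one more projection operator can be ignored. We also show that, in Stage 1, the approximation error of the first layer and the spread of the second layer cannot grow too much.

First, for the initialization, a standard concentration argument gives the following lemma.

Lemma C.1 (Initialization). We choose $m_1 = \mathrm{poly}(d, 1/\varepsilon)$, $m_2 = \Theta(1)$, $\sigma_1 = 1/\sqrt{d}$, $\sigma_2 = 1/\mathrm{poly}(d, 1/\varepsilon)$, and $\sigma_r$ to be a small constant.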
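We initialize $w_1\sim\mathrm{Unif}(\sigma_1\mathbb S^{d-1})$ for $\mu_1$, and $w_2\sim\mathcal N(0,\sigma_2^2)$ and $b_2 = \sigma_r$ for $\mu_2$. Given $\delta_{1,I} = 1/\mathrm{poly}_1(d,1/\varepsilon)$, we choose a sufficiently large $m_1$ so that, at initialization, with probability at least $1 - 1/\mathrm{poly}(d)$, $\|\tilde F|_{\mathbb S^{d-1}} - 1\|_{L^\infty}\le\delta_{1,I}$. We also choose $\sigma_2 = \delta_{1,I}/d^7$. With probability at least $1 - 1/\mathrm{poly}(d)$, we have $\max_{w_2}|w_2|\le O(\log d)\sigma_2$.

Then, we formally state the induction hypothesis we maintain for Stage 1.

Induction Hypothesis C.2 (Stage 1). We define $T_1 := \inf\{t\ge0:\bar w_2(t) = -\Theta(1)/R_{v_2}\}$ for some large constant. In Stage 1.1 and Stage 1.2, define $\delta^{(1)}_{1,T}$, $\delta^{(1)}_{1,R}$ and $\delta^{(1)}_2$ as (footnote 9)
$$\delta^{(1)}_{1,T} = \max\Big\{\delta^{(1)}_{1,T}(0),\ \max_{v_1\in\mu_1}\|\bar v_1(t)-\bar v_1(0)\|\Big\},\qquad
\delta^{(1)}_{1,R} = \max\Big\{\delta^{(1)}_{1,R}(0),\ \max_{v_1\in\mu_1}\Big|\frac{\|v_1\|^2}{\mathbb E_{w_1}\|w_1\|^2}-1\Big|\Big\},$$
$$\delta^{(1)}_2 = \max\Big\{\delta^{(1)}_2(0),\ \max_{(v_2,r_2),(v_2',r_2')\in\mu_2}\|(v_2,r_2)-(v_2',r_2')\|\Big\},$$
and, in Stage 1.3, define them through
$$\frac{d}{dt}\delta^{(1)}_{1,T} = \mathrm{ReLU}\Big(\frac{d}{dt}\max_{v_1\in\mu_1}\|\bar v_1(t)-\bar v_1(0)\|\Big),\quad
\frac{d}{dt}\delta^{(1)}_{1,R} = \mathrm{ReLU}\Big(\frac{d}{dt}\max_{v_1\in\mu_1}\Big|\frac{\|v_1\|^2}{\mathbb E_{w_1}\|w_1\|^2}-1\Big|\Big),\quad
\frac{d}{dt}\delta^{(1)}_2 = \mathrm{ReLU}\Big(\frac{d}{dt}\max\|(v_2,r_2)-(v_2',r_2')\|\Big),$$
with initial values $\delta^{(1)}_{1,T}(0) = \delta^{(1)}_{1,R}(0) = 0$ and $\delta^{(1)}_2(0) = \Theta(\sigma_2\log d)$. We say that this induction hypothesis is true at time $t\in[0,T_1]$ if the following hold (footnote 10).
(a) Approximation error of the first layer. For each $v_1\in\mu_1$, $\|\bar v_1(t)-\bar v_1(0)\|\le\delta^{(1)}_{1,T}$ and $\|v_1\|^2 = (1\pm\delta^{(1)}_{1,R})\,\mathbb E_{w_1\sim\mu_1}\|w_1\|^2$.
(b) Spread of the second layer. For any $(v_2,r_2),(v_2',r_2')\in\mu_2$, $\|(v_2,r_2)-(v_2',r_2')\|\le\delta^{(1)}_2$.
(c) The bias term. For any $(v_2,r_2)\in\mu_2$, $r_2 = \Theta(1)$.
(d) Size of $f$. $|\bar w_2| = O(1/R_{v_2}) = O(1/d^3)$ and $\alpha = \Theta(1/d^{1.5})$.
(e) Bounds for the errors. $\delta^{(1)}_2\le O(d^{1.5}(\log d)\sigma_2)$ and $\delta^{(1)}_{1,R}+\delta^{(1)}_{1,T}\le O(d^7(\log d)\sigma_2+\delta_{1,I})$.

The next lemma describes when the projection operators can be ignored. Roughly speaking, we first bound the gradients to show that, for a projection operator to be triggered, $\|x\|$ must exceed a certain threshold. Meanwhile, $f$, and hence the gradients, vanishes for those $x$ with $\|x\|\ge\Theta(1/|\bar w_2\alpha|)$. Hence, as long as $\Theta(1/|\bar w_2\alpha|)$ is smaller than that threshold, we can ignore the projection.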
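The initialization in Lemma C.1 and the normalization $\tilde F = F/\alpha$ can be checked numerically. The block below is a minimal sketch, not the authors' code. It assumes $F(x) = \frac1{m_1}\sum_i\|w_i\|\sigma(w_i^\top x)$ and $\alpha = \frac{C_\Gamma}{\sqrt d}\cdot\frac1{m_1}\sum_i\|w_i\|^2$. It also assumes the constant $C_\Gamma = \sqrt d\,\Gamma(d/2)/(2\sqrt\pi\,\Gamma((d+1)/2))$ implied by the computation in Lemma B.3. The helper names, the sizes $d = 10$ and $m_1 = 200{,}000$, and the values of $\sigma_2$ and $\sigma_r$ are illustrative assumptions.

```python
import math

import numpy as np


def c_gamma(d: int) -> float:
    """Assumed constant with E_{u ~ Unif(S^{d-1})} relu(u . v) = (C_Gamma / sqrt(d)) * ||v||."""
    log_ratio = math.lgamma(d / 2) - math.lgamma((d + 1) / 2)
    return math.sqrt(d) * math.exp(log_ratio) / (2.0 * math.sqrt(math.pi))


def init_first_layer(m1: int, d: int, rng: np.random.Generator) -> np.ndarray:
    """Hypothetical helper: w1 ~ Unif(sigma1 * S^{d-1}) with sigma1 = 1/sqrt(d)."""
    g = rng.standard_normal((m1, d))
    return g / np.linalg.norm(g, axis=1, keepdims=True) / math.sqrt(d)


def init_second_layer(m2: int, sigma2: float, sigma_r: float, rng: np.random.Generator):
    """Hypothetical helper: w2 ~ N(0, sigma2^2) and b2 = sigma_r for every neuron."""
    return sigma2 * rng.standard_normal(m2), np.full(m2, sigma_r)


def normalized_feature(W: np.ndarray, X: np.ndarray) -> np.ndarray:
    """F~(x) = F(x) / alpha for each row x of X, using the finite-width average over W."""
    d = W.shape[1]
    norms = np.linalg.norm(W, axis=1)
    F = (norms[:, None] * np.maximum(W @ X.T, 0.0)).mean(axis=0)
    alpha = c_gamma(d) / math.sqrt(d) * np.mean(norms ** 2)
    return F / alpha


rng = np.random.default_rng(0)
d, m1 = 10, 200_000
W = init_first_layer(m1, d, rng)
w2, b2 = init_second_layer(m2=8, sigma2=1e-6, sigma_r=0.1, rng=rng)

# Evaluate F~ on random unit directions; it should be close to 1 (the role of delta_{1,I}).
U = rng.standard_normal((20, d))
U /= np.linalg.norm(U, axis=1, keepdims=True)
err = float(np.max(np.abs(normalized_feature(W, U) - 1.0)))
print(f"max |F~ - 1| over 20 unit directions: {err:.4f}")
```

Increasing `m1` shrinks the printed deviation roughly like $1/\sqrt{m_1}$. This is the Monte Carlo concentration behind choosing $m_1 = \mathrm{poly}(d,1/\varepsilon)$.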
Footnote 9. Note that we define these δ's to be upper bounds on the corresponding quantities instead of the quantities themselves. The only reason for this somewhat twisted definition is to make the proof easier to write rigorously. See the footnote to Induction Hypothesis D.1, where this type of definition plays a more technically important role, for further discussion.

Footnote 10. The first two conditions actually follow directly from the definitions of the δ's; we repeat them here only for easier reference. The actual result we need to prove for these δ's is condition (e), which says that they always remain small.

Lemma C.3. Suppose that Induction Hypothesis C.2 is true. The projection operators in $\dot r_2$, $\dot v_1$ and $\dot v_2$ will no longer be activated if, respectively, all second-layer weights are nonpositive, $-\bar w_2>\Theta(1)\delta^{(1)}_2$ for some large constant, and $-\bar w_2\ge\Theta(1)/R_{v_2}$ for some large constant.

Remark. Though we only need $-\bar w_2$ to be $\Theta(1)\delta^{(1)}_2$ to ignore the projection operator in $\dot v_1$, we actually define the end of Stage 1.2 to be the time at which $-\bar w_2$ becomes $\mathrm{poly}(d)\delta^{(1)}_2$, to get a more regular start for Stage 1.3.

Now, we present the main lemma of Stage 1. By properly choosing the parameters, the errors can be made arbitrarily small without affecting the final values of $\alpha$ and $\bar w_2$. To prove the main lemma, it suffices to combine Lemma C.6, Lemma C.9 and Lemma C.10.

Lemma C.4 (Main lemma of Stage 1). Induction Hypothesis C.2 is true throughout Stage 1. Stage 1 takes at most $O(d^4\sigma_2+1/d^{1.5})$ time. At the end of Stage 1, we have $\alpha = \Theta(1/d^{1.5})$ and $\bar w_2 = -\Theta(1/d^3)$. For the errors, we have $\delta^{(1)}_2\le O(d^{1.5}\log d\,\sigma_2)$ and $\delta^{(1)}_{1,R}+\delta^{(1)}_{1,T}\le O(\delta_{1,I})$.

Proof of Lemma C.3. First, note that when all $v_2$ are nonpositive, we have $f = O(1)$. Since we choose $R_{r_2}$ to be a large constant, this implies that the projection operator in $\dot r_2$ will not be activated. When $-\bar w_2>\Theta(1)\delta^{(1)}_2$, we have $f(x)\le\sigma(c\,\bar w_2\alpha\|x\|+O(1))$ for some small constant $c>0$.
As a result, $f$ vanishes on $\{\|x\|\ge O(1)/(-c\bar w_2\alpha)\}$. Then, for those $x$ with $\|x\|\le O(1)/(-c\bar w_2\alpha)$, the gradient w.r.t. $v_1$ can be bounded as
$$\|\nabla_{v_1}L(x)\|\le O(1)|\bar w_2|\,\|x\|\,\|v_1\|\le O(1)|\bar w_2|\,\|v_1\|\,\frac{1}{|\bar w_2|\alpha}\le O(d).$$
Since we choose $R_{v_1} = \Theta(d)$ with a large constant, this implies that the projection operator in $\dot v_1$ will not be triggered. Finally, for $\dot v_2$ and those $x$ with $\|x\|\le O(1)/(-c\bar w_2\alpha)$, we have
$$|\nabla_{v_2}L(x)|\le O(1)\,\alpha\|x\|\le\frac{O(1)}{|\bar w_2|}.$$
By assumption, $|\bar w_2| = \Theta(1)/R_{v_2}$ for some large constant. Hence, this inequality implies that the projection operator in $\dot v_2$ will not be triggered.

C.1 STAGE 1.1

The goal of Stage 1.1 is to make all second-layer weights $v_2$ non-positive, that is, we define $T_{1.1} := \inf\{t\ge0:\forall(v_2,r_2)\in\mu_2,\ v_2\le0\}$. As a result, at the end of Stage 1.1, $f$ is $O(1)$ and, by Lemma C.3, the projection operator in $\dot r_2$ can be ignored. Since this stage takes only a very small amount of time, we control the first-layer error by directly bounding the movement of $v_1$. For the second layer, we bound the movement of the bias term in the same brute-force way. For the second-layer weights, we show that positive $v_2$'s decrease faster than negative ones, so the spread does not increase.

Lemma C.5. Suppose that Induction Hypothesis C.2 is true at time $t$ and $t\le T_{1.1}$. Then the following hold.
(a) $\|\dot v_1\|\le R_{v_1}$ and $|\dot r_2|\le R_{r_2}$.
(b) $\max w_2-\min w_2$ is non-increasing.
(c) For any positive second-layer weight $v_2$, we have $\dot v_2\le-\Theta(\log d/d^{1.5})$.

Remark. In fact, (c) holds whenever $\alpha = \Omega(1/d^{1.5})$ and $v_2F(x)+r_2\ge\Theta(1)$ for all $(v_2,r_2)\in\mu_2$ and $x\in\{\|x\|\le d^{1.5}\}$, which is always true throughout Stage 1. This estimate will also be used in Stage 1.2 and Stage 1.3.

Lemma C.6 (Main lemma of Stage 1.1). Stage 1.1 takes at most $O(d^{1.5}\delta^{(1)}_2(0))$ time. At the end of Stage 1.1, all second-layer weights $v_2$ are non-positive. Hence, $f = O(1)$ and, by Lemma C.3, the projection operator in $\dot r_2$ can no longer be activated.
For the errors, we have $\delta^{(1)}_2(T_{1.1})\le O(d^{1.5}\delta^{(1)}_2(0))$, and both $\delta^{(1)}_{1,R}(T_{1.1})$ and $\delta^{(1)}_{1,T}(T_{1.1})$ can be bounded by $O(d^3\delta^{(1)}_2(0))$.

Proof of Lemma C.5. (a) This is immediate from the projection operators.
(b) First, we decompose $\dot v_2$ as
$$\dot v_2 = \mathbb E_{\|x\|\le1}\{(f_*(x)-f(x))F(x)\}-\mathbb E_{\|x\|>1}\Pi_{R_{v_2}}\big[f(x)\sigma'(v_2F(x)+r_2)F(x)\big].$$
The first term does not depend on $v_2$, and, in the second term, $\sigma'(v_2F(x)+r_2) = 1$ whenever $v_2\ge0$. As a result, all positive $v_2$ move at the same speed, which is at least as negative as that of any $v_2<0$. Thus, $\max w_2-\min w_2$ is non-increasing.
(c) Clearly, $\mathbb E_{\|x\|\le1}\{(f_*(x)-f(x))F(x)\} = O(\alpha)$. For the second term, first note that for any $x$ with $\|x\|\le d^{1.5}$, we have
$$f(x)F(x)\le O\big(1+\max w_2\,\alpha\|x\|\big)\alpha\|x\|\le R_{v_2}\qquad\text{and}\qquad f(x)\ge\Theta(1)-\max|w_2|\,\alpha\|x\| = \Theta(1).$$
As a result,
$$\mathbb E_{\|x\|>1}\Pi_{R_{v_2}}\big[f(x)\sigma'(v_2F(x)+r_2)F(x)\big]\ge\Theta(\alpha)\,\mathbb E\big\{\mathbb 1\{1<\|x\|\le d^{1.5}\}\|x\|\big\} = \Theta((\log d)\alpha).$$
Thus, $\dot v_2\le-\Theta(\log d/d^{1.5})$.

Proof of Lemma C.6. By Lemma C.5, it takes at most $O(d^{1.5}\delta^{(1)}_2(0))$ time for all $v_2$ to become nonpositive. Within this time, each $r_2$ changes by at most $O(d^{1.5}\delta^{(1)}_2(0))$. Since the spread of $w_2$ does not increase, this implies $\delta^{(1)}_2(T_{1.1})\le O(d^{1.5}\delta^{(1)}_2(0))$. Finally, the change of $v_1$ can be bounded by $O(d^{2.5}\delta^{(1)}_2(0))$. As a result, both $\delta^{(1)}_{1,R}(T_{1.1})$ and $\delta^{(1)}_{1,T}(T_{1.1})$ can be bounded by $O(d^3\delta^{(1)}_2(0))$.

C.2 STAGE 1.2

The goal of Stage 1.2 is to make $-\bar w_2\ge d\,\delta^{(1)}_2(T_{1.1})$. Namely, $T_{1.2} := \inf\{t\ge T_{1.1}:-\bar w_2 = d\,\delta^{(1)}_2(T_{1.1})\}$. We will also show that $\delta^{(1)}_2(T_{1.2}) = O(\delta^{(1)}_2(T_{1.1}))$, so $\delta^{(1)}_2(T_{1.2})/|\bar w_2| = O(1/d)$ at the end of Stage 1.2. Moreover, by Lemma C.3, at the end of Stage 1.2 the projection operator in $\dot v_1$ will no longer be activated. In this subsection, we also show that $r_2$ remains $\Theta(1)$ throughout Stage 1. The first-layer error is again controlled in a brute-force way. For the second-layer spread, we show that since $|v_2|$ is small, $\sigma'(v_2F(x)+r_2) = 1$ for most $x$ and, as a result, the change of $(v_2,r_2)$ is approximately uniform.

Lemma C.7. Suppose that Induction Hypothesis C.2 is true at time $t$.
Then, for any $(v_2,r_2)\in\mu_2$, $\dot r_2>0$ when $r_2\le\mathbb E f_*/2$ and $\dot r_2<0$ when $r_2\ge2\,\mathbb E f_*$. As a result, $r_2 = \Theta(1)$ throughout Stage 1.

Lemma C.8 (Spread of the second layer). Suppose that Induction Hypothesis C.2 is true at time $t$ and $t\le T_{1.2}$. Then, for any $(v_2,r_2),(v_2',r_2')\in\mu_2$, we have
$$\frac{d}{dt}\|(v_2,r_2)-(v_2',r_2')\|^2\le O(d^{2.5})\big(\delta^{(1)}_2\big)^2.$$
Although, by this lemma, the error $\delta^{(1)}_2$ can grow exponentially fast with a large rate, it does not blow up: since $\dot v_2\le-\Theta(\log d/d^{1.5})$, Stage 1.2 is much shorter than $1/d^{2.5}$.

Lemma C.9 (Main lemma of Stage 1.2). Stage 1.2 takes at most $O(d^{2.5}\delta^{(1)}_2(T_{1.1}))$ time. At the end of Stage 1.2, we have $v_2\le-\Theta(d)\,\delta^{(1)}_2(T_{1.1})$ for any $(v_2,r_2)\in\mu_2$. For the errors, the spread of the second layer is $(1+o(1))\delta^{(1)}_2(T_{1.1})$, and both $\delta^{(1)}_{1,R}(T_{1.2})$ and $\delta^{(1)}_{1,T}(T_{1.2})$ can be bounded by $O(d^4\delta^{(1)}_2(T_{1.1}))$.

Proof of Lemma C.7. We write
$$\dot r_2 = \mathbb E_x\{(f_*(x)-f(x))\sigma'(v_2F(x)+r_2)\} = \mathbb E_xf_*(x)-\mathbb E_x\{f(x)\sigma'(v_2F(x)+r_2)\}.$$
Since the spread of $b_2$ is $o(1)$, when $r_2\le\mathbb E_xf_*(x)/2 = \Theta(1)$, the right-hand side is a positive constant; in other words, $r_2$ keeps growing. Meanwhile, since the second term can be bounded as $\mathbb E_x\{f(x)\sigma'(v_2F(x)+r_2)\}\ge\mathbb E_{\|x\|\le d^2}\{f(x)\}\ge(1-o(1))\,b_2$, when $r_2\ge2\,\mathbb E f_*(x)$, $\dot r_2$ becomes a negative constant and $r_2$ decreases. Combining these two cases completes the proof.

Proof of Lemma C.8. Since $|v_2|\le d\,\delta^{(1)}_2(T_{1.1})$, $F(x) = \Theta(\alpha)\|x\|$ and $r_2 = \Theta(1)$, we have $v_2F(x)+r_2>0$ for all $x$ with $\|x\|\le\rho := \Theta(\sqrt d/\delta^{(1)}_2(T_{1.1}))$. Hence, we can rewrite $\dot v_2$ as
$$\dot v_2 = \mathbb E_{\|x\|\le\rho}\Pi_{R_{v_2}}\big[(f_*(x)-f(x))F(x)\big]-\mathbb E_{\|x\|>\rho}\Pi_{R_{v_2}}\big[f(x)\sigma'(v_2F(x)+r_2)F(x)\big].$$
The first term does not depend on $v_2$ and, by the tail bound, the second term can be bounded by $O(R_{v_2}\delta^{(1)}_2(T_{1.1})/\sqrt d)$. Similarly, for $r_2$, we have
$$\dot r_2 = \mathbb E_{\|x\|\le\rho}\{f_*(x)-f(x)\}\pm O\big(\delta^{(1)}_2(T_{1.1})/\sqrt d\big).$$
Hence, for any $(v_2,r_2),(v_2',r_2')\in\mu_2$, we have
$$\frac{d}{dt}\|(v_2,r_2)-(v_2',r_2')\|^2\le|v_2-v_2'|\,O\Big(\frac{R_{v_2}\delta^{(1)}_2(T_{1.1})}{\sqrt d}\Big)+|r_2-r_2'|\,O\Big(\frac{\delta^{(1)}_2(T_{1.1})}{\sqrt d}\Big)\le O(d^{2.5})\big(\delta^{(1)}_2\big)^2,$$
where the last step uses $R_{v_2} = \Theta(d^3)$.

Proof of Lemma C.9.
Recall from Lemma C.5 that $\dot v_2\le-\Theta(\log d/d^{1.5})$, whence Stage 1.2 takes at most $O(d^{2.5}\delta^{(1)}_2(T_{1.1}))$ time. By Lemma C.8, we have
$$\big(\delta^{(1)}_2(T_{1.2})\big)^2\le\big(\delta^{(1)}_2(T_{1.1})\big)^2\exp\big(O(d^5)\,\delta^{(1)}_2(T_{1.1})\big)\le(1+o(1))\big(\delta^{(1)}_2(T_{1.1})\big)^2.$$
For $v_1$, similarly to the proof of Lemma C.6, both $\delta^{(1)}_{1,R}(T_{1.2})$ and $\delta^{(1)}_{1,T}(T_{1.2})$ can be bounded by $O(d^4\delta^{(1)}_2(T_{1.1}))$.

C.3 STAGE 1.3

The goal of Stage 1.3 is to make $-\bar w_2 = \Theta(1/R_{v_2})$ for some large constant so that, by Lemma C.3, the projection operator in $\dot v_2$ can be ignored. That is, we define $T_{1.3} := \inf\{t\ge T_{1.2}:\bar w_2(t) = -\Theta(1/R_{v_2})\}$. This stage takes longer than the previous ones, so we need finer, less brute-force ways to control the errors. For the first layer, we show that the tangent movement is almost zero and the radial movement is approximately uniform. For the second layer, we show that the spread $\delta^{(1)}_2$ cannot grow too fast.

Lemma C.10 (Main lemma of Stage 1.3). Stage 1.3 takes at most $O(1/d^{1.5})$ time. At the end of Stage 1.3, we have $\bar w_2 = -\Theta(1/R_{v_2})$ and $\alpha = \Theta(1/d^{1.5})$. For the errors, the spread of the second layer is $O(\delta^{(1)}_2(T_{1.2}))$ and the first-layer errors are $O\big(\delta^{(1)}_{1,R}(T_{1.2})+\delta^{(1)}_{1,T}(T_{1.2})+\delta_{1,I}+\log(d)\,\delta^{(1)}_2(T_{1.2})\big)$.

Proof. Since $-\dot v_2 = \Omega(\log d/d^{1.5})$ and $R_{v_2} = \Theta(d^3)$, Stage 1.3 takes at most $O(1/d^{1.5})$ time. Within this time, by Lemma C.18, we have
$$\big(\delta^{(1)}_2(T_{1.3})\big)^2\le\big(\delta^{(1)}_2(T_{1.2})\big)^2\exp\Big(\frac{O(1)}{d^{2.5}}\cdot\frac{1}{d^{1.5}}\Big) = (1+o(1))\big(\delta^{(1)}_2(T_{1.2})\big)^2.$$
For the first layer, by Lemma C.16, we have
$$\delta^{(1)}_{1,R}(T_{1.3})+\delta^{(1)}_{1,T}(T_{1.3})\le\Big(\delta^{(1)}_{1,R}(T_{1.2})+\delta^{(1)}_{1,T}(T_{1.2})+\frac{O(1)}{d^3}\delta_{1,I}+O\big(\log(d)\,\delta^{(1)}_2\big)\Big)\exp\Big(\frac{O(1)}{d^{2.5}}\cdot\frac{1}{d^{1.5}}\Big) = O\Big(\delta^{(1)}_{1,R}(T_{1.2})+\delta^{(1)}_{1,T}(T_{1.2})+\frac{\delta_{1,I}}{d^3}+\log(d)\,\delta^{(1)}_2(T_{1.2})\Big).$$
Finally, by Lemma C.17, we have $\alpha(T_{1.3}) = (1+o(1))\,\alpha(T_{1.2})$.

C.3.1 ESTIMATIONS RELATED TO $\sigma'(v_2F(x)+r_2)$

First, we need some helper results to handle $\sigma'(v_2F(x)+r_2)$.
The conditions for them to hold are mild and always true throughout the entire training procedure, and we will use these results in later stages too. First, we show that when the value of $\sigma'(v_2F(x)+r_2)$ differs across $(v_2,r_2)$, the function value must be small. Note that the error here depends on the ratio $\delta_2/|\bar w_2|$; this is why we need $|\bar w_2|$ to be $\Theta(d)\delta_2$, instead of merely $\Theta(1)\delta_2$, at the end of Stage 1.2.

Lemma C.11. Suppose that $r_2 = \Theta(1)$ and $-v_2\ge\Omega(\delta_2)$ for any $(v_2,r_2)\in\mu_2$, where $\delta_2$ is the spread of the second layer. If $v_2F(x)+r_2 = 0$ for some $(v_2,r_2)\in\mu_2$, then $|v_2'F(x)+r_2'|\le O\big((|\bar w_2|^{-1}+1)\delta_2\big)$ for all $(v_2',r_2')\in\mu_2$.

Remark. There need not actually exist a $(v_2,r_2)\in\mu_2$ with $v_2F(x)+r_2 = 0$. As long as $v_2'F(x)+r_2'\ge0$ and $v_2''F(x)+r_2''\le0$ for some $(v_2',r_2'),(v_2'',r_2'')\in\mu_2$, by continuity there always exists some point $(v_2,r_2)$ between $(v_2',r_2')$ and $(v_2'',r_2'')$ such that $v_2F(x)+r_2 = 0$. Moreover, this point lies within the spread of the second layer, so the lemma still applies.

Then, we show that we can absorb $\sigma'$ into $f_*$ and $f$.

Lemma C.12. Suppose that the hypothesis of Lemma C.11 holds and that all second-layer neurons are activated on $\{\|x\|\le1\}$. Then, for any $(v_2,r_2)\in\mu_2$ and $x\in\mathbb R^d$, we have
$$f_*(x)\,\sigma'(v_2F(x)+r_2) = f_*(x)\qquad\text{and}\qquad f(x)\,\sigma'(v_2F(x)+r_2) = f(x)\pm O\big((|\bar w_2|^{-1}+1)\delta_2\big).$$
As a corollary, we have
$$f(x) = \sigma(v_2F(x)+r_2)\pm O\big((|\bar w_2|^{-1}+1)\delta_2\big),\qquad f(x) = \sigma(\bar w_2F(x)+\bar b_2)\pm O\big((|\bar w_2|^{-1}+1)\delta_2\big).$$
As a corollary of Lemma C.11, the measure of the set on which $\sigma'(v_2F(x)+r_2)$ can differ across $(v_2,r_2)$ is also small. Here we also use the facts that those $x$ satisfy $\|x\| = \Theta(1/|\bar w_2\alpha|)$ and that $\mathcal D$ satisfies the tail bound $\mathcal D(r)\le O(1/r^2)$.

Lemma C.13. Suppose that Induction Hypothesis C.2 is true at time $t$. For any $(v_2,r_2),(v_2',r_2')\in\mu_2$, we have
$$\mathbb E_x\big\{|\sigma'(v_2F(x)+r_2)-\sigma'(v_2'F(x)+r_2')|\big\}\le O\big(\alpha\,\delta^{(1)}_2\big).$$

Proof of Lemma C.11.
For any $(v_2',r_2')\in\mu_2$, we can write
$$v_2'F(x)+r_2' = \underbrace{v_2F(x)+r_2}_{=0}+(v_2'-v_2)F(x)+(r_2'-r_2) = \frac{v_2'-v_2}{v_2}\big(\underbrace{v_2F(x)+r_2}_{=0}-r_2\big)+(r_2'-r_2) = -r_2\,\frac{v_2'-v_2}{v_2}+(r_2'-r_2).$$
The last expression can be bounded as $O\big((|\bar w_2|^{-1}+1)\delta_2\big)$.

Proof of Lemma C.12. Since all second-layer neurons are activated on $\{\|x\|\le1\}$, we always have $f_*(x)\sigma'(v_2F(x)+r_2) = f_*(x)$. Now consider $f(x)\sigma'(v_2F(x)+r_2)$. If $v_2F(x)+r_2>0$, then we are done. If $v_2'F(x)+r_2'<0$ for all $(v_2',r_2')\in\mu_2$, then both $f(x)\sigma'(v_2F(x)+r_2)$ and $f(x)$ are $0$. Therefore, it suffices to consider the case where $v_2F(x)+r_2\le0$ while $f(x)>0$. By Lemma C.11 and the remark following it, in this case we have $f(x)\le O\big((|\bar w_2|^{-1}+1)\delta_2\big)$.

Proof of Lemma C.13. Since the norm and direction of $x$ are independent, it suffices to fix a direction $\bar x$ and consider
$$\mathbb E_{r\sim\mathcal D}\big\{|\sigma'(v_2\,r\,F(\bar x)+r_2)-\sigma'(v_2'\,r\,F(\bar x)+r_2')|\big\}.$$
For notational simplicity, define $h(v_2,r_2,r) = v_2\,r\,F(\bar x)+r_2$. The integrand is nonzero iff the signs of $h(v_2,r_2,r)$ and $h(v_2',r_2',r)$ differ. To bound the length of the interval on which the signs can differ, we write
$$h(v_2,r_2,r) = \bar w_2\,r\,F(\bar x)+\bar b_2+(v_2-\bar w_2)\,r\,F(\bar x)+(r_2-\bar b_2) = \big(\bar w_2\pm O(\delta^{(1)}_2)\big)\,r\,F(\bar x)+\bar b_2\pm O(\delta^{(1)}_2).$$
Therefore, the length of this interval can be bounded by $O\big(\delta^{(1)}_2/(\bar w_2^2\alpha)\big)$. Moreover, this interval is located at $r = \Theta(1/|\bar w_2\alpha|)$, whence the density on it is $O(\bar w_2^2\alpha^2)$. Thus, the measure of this interval is $O(\alpha\,\delta^{(1)}_2)$.

C.3.2 ESTIMATIONS FOR THE FIRST LAYER

Before we control the error growth, we need a lemma that relates the approximation error to the tangent movement and radial spread of the first layer.

Lemma C.14. Suppose that the tangent movement and radial spread of the first-layer neurons can be bounded as $\|\bar v_1(t)-\bar v_1(0)\|\le\delta_{1,T}$ and $\|v_1\|^2 = (1\pm\delta_{1,R})\,\mathbb E_{w_1}\|w_1\|^2$. Then
$$F(x;\mu_1) = \big(1\pm\delta_{1,I}\pm O(\sqrt d\,\delta_{1,R})\pm O(\sqrt d\,\delta_{1,T})\big)\,\alpha\|x\|.$$
As a simple corollary, we have the following.

Corollary C.15. Suppose that Induction Hypothesis C.2 is true at time $t$. Then
$$|f(x)-\bar f(x)|\le\big(\delta_{1,I}+O(\sqrt d\,\delta^{(1)}_{1,R})+O(\sqrt d\,\delta^{(1)}_{1,T})\big)\,|\bar w_2|\,\alpha\|x\|.$$
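As a result, since $f$ and $\bar f$ vanish for $\|x\|\ge\Theta(1/|\bar w_2\alpha|)$, the tail bound on $\mathcal D$ gives
$$\big|\mathbb E_x\{(f(x)-\bar f(x))\|x\|\}\big|\le\big(\delta_{1,I}+O(\sqrt d\,\delta^{(1)}_{1,R})+O(\sqrt d\,\delta^{(1)}_{1,T})\big)\,|\bar w_2|\alpha\;\mathbb E_x\big\{\|x\|^2\,\mathbb 1\{\|x\|\le O(1/|\bar w_2\alpha|)\}\big\}\le O\big(\delta_{1,I}+\sqrt d\,\delta^{(1)}_{1,R}+\sqrt d\,\delta^{(1)}_{1,T}\big).$$
Now, we are ready to control the error of the first layer.

Lemma C.16. Suppose that Induction Hypothesis C.2 is true at time $t$ and $t\in[T_{1.2},T_{1.3}]$. Then we have
$$\frac{d}{dt}\big(\delta^{(1)}_{1,R}+\delta^{(1)}_{1,T}\big)\le O\Big(\big(\delta_{1,I}+\sqrt d\,\delta^{(1)}_{1,R}+\sqrt d\,\delta^{(1)}_{1,T}\big)|\bar w_2|\Big)+O\big(\log(d)\,\delta^{(1)}_2\big)\le\frac{O(1)}{d^{2.5}}\big(\delta^{(1)}_{1,R}+\delta^{(1)}_{1,T}\big)+\frac{O(1)}{d^3}\delta_{1,I}+O\big(\log(d)\,\delta^{(1)}_2\big).$$
Finally, we estimate the radial speed of $v_1$ to control the magnitude of $\alpha$ at the end of Stage 1.

Lemma C.17. Suppose that Induction Hypothesis C.2 is true at time $t$ and $t\in[T_{1.2},T_{1.3}]$. Then we have
$$\frac{d}{dt}\|v_1\|^2 = \Theta\Big(\frac{\log d}{\sqrt d}\Big)|\bar w_2|\,\|v_1\|^2.$$

Proof of Lemma C.14. Define $N^2 = \mathbb E_{w_1}\|w_1\|^2$. Let $\mu_1'$ be the distribution obtained by setting the norm of every neuron in $\mu_1$ to $N$. We have
$$F(x;\mu_1) = \mathbb E_{w_1\sim\mu_1}(1\pm\delta_{1,R})N^2\sigma(\bar w_1^\top x) = F(x;\mu_1')\pm O(\delta_{1,R}N^2\|x\|).$$
Let $\mu_1''$ be the distribution obtained from $\mu_1'$ by moving each $\bar v_1(t)$ back to $\bar v_1(0)$. Then we have
$$F(x;\mu_1') = N^2\,\mathbb E_{w_1\sim\mu_1(0)}\{\sigma(\bar w_1^\top x)\}\pm O(\delta_{1,T}N^2\|x\|) = F(x;\mu_1'')\pm O(\delta_{1,T}N^2\|x\|).$$
Finally, note that
$$F(x;\mu_1'') = \frac{N_t^2}{N_0^2}F(x;\mu_1(0)) = \frac{N_t^2}{N_0^2}(1\pm\delta_{1,I})\,\alpha_0\|x\| = (1\pm\delta_{1,I})\,\alpha_t\|x\|.$$
Combining these, and using $N^2 = \sqrt d\,\alpha/C_\Gamma$, completes the proof.

Proof of Lemma C.16. First, we decompose $\dot v_1$ along the tangent and radial directions as follows:
$$\mathrm{Rad}(\dot v_1) := \langle\dot v_1,\bar v_1\rangle\bar v_1 = 2\,\mathbb E_x\{S(x)\sigma(\bar v_1^\top x)\}\,v_1,\qquad\mathrm{Tan}(\dot v_1) := (I-\bar v_1\bar v_1^\top)\dot v_1 = \|v_1\|\,\mathbb E_x\{S(x)\sigma'(v_1^\top x)(I-\bar v_1\bar v_1^\top)x\}.$$
Note that $\dot v_1 = \mathrm{Rad}(\dot v_1)+\mathrm{Tan}(\dot v_1)$. By Lemma C.12, we have
$$\mathrm{Rad}(\dot v_1) = 2\bar w_2\,\mathbb E_x\{(f_*(x)-f(x))\sigma(\bar v_1^\top x)\}\,v_1\pm O\big(\log(d)\,\delta^{(1)}_2\|v_1\|\big),$$
$$\mathrm{Tan}(\dot v_1) = \|v_1\|\,\bar w_2\,\mathbb E_x\{(f_*(x)-f(x))\sigma'(v_1^\top x)(I-\bar v_1\bar v_1^\top)x\}\pm O\big(\log(d)\,\delta^{(1)}_2\|v_1\|\big).$$
For the radial term, by Lemma B.3 and Corollary C.15, we have
$$\begin{aligned}\mathrm{Rad}(\dot v_1) &= 2\bar w_2\,\mathbb E_x\{(f_*(x)-\bar f(x))\sigma(\bar v_1^\top x)\}\,v_1+2\bar w_2\,\mathbb E_x\{(\bar f(x)-f(x))\sigma(\bar v_1^\top x)\}\,v_1\pm O\big(\log(d)\,\delta^{(1)}_2\|v_1\|\big)\\
&= \frac{2C_\Gamma}{\sqrt d}\,\bar w_2\,\mathbb E_x\{(f_*(x)-\bar f(x))\|x\|\}\,v_1\pm O\Big(\big(\delta_{1,I}+\sqrt d\,\delta^{(1)}_{1,R}+\sqrt d\,\delta^{(1)}_{1,T}\big)|\bar w_2|\Big)\|v_1\|\pm O\big(\log(d)\,\delta^{(1)}_2\|v_1\|\big).\end{aligned}$$
Therefore,
$$\frac{d}{dt}\|v_1\|^2 = 2\langle v_1,\mathrm{Rad}(\dot v_1)\rangle = \frac{4C_\Gamma}{\sqrt d}\,\bar w_2\,\mathbb E_x\{(f_*(x)-\bar f(x))\|x\|\}\,\|v_1\|^2\pm O\Big(\big(\delta_{1,I}+\sqrt d\,\delta^{(1)}_{1,R}+\sqrt d\,\delta^{(1)}_{1,T}\big)|\bar w_2|\Big)\|v_1\|^2\pm O\big(\log(d)\,\delta^{(1)}_2\|v_1\|^2\big).$$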
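For any $v_1,v_1'\in\mu_1$ with $\|v_1\|\ge\|v_1'\|$, write $E_* := \frac{4C_\Gamma}{\sqrt d}\bar w_2\,\mathbb E_x\{(f_*(x)-\bar f(x))\|x\|\}$ and $E_{\mathrm{err}} := O\big((\delta_{1,I}+\sqrt d\,\delta^{(1)}_{1,R}+\sqrt d\,\delta^{(1)}_{1,T})|\bar w_2|\big)+O\big(\log(d)\,\delta^{(1)}_2\big)$. We have
$$\begin{aligned}\frac{d}{dt}\frac{\|v_1\|^2}{\|v_1'\|^2} &= \frac{1}{\|v_1'\|^2}\frac{d}{dt}\|v_1\|^2-\frac{\|v_1\|^2}{\|v_1'\|^4}\frac{d}{dt}\|v_1'\|^2\\
&\le\Big(E_*\frac{\|v_1\|^2}{\|v_1'\|^2}+E_{\mathrm{err}}\frac{\|v_1\|^2}{\|v_1'\|^2}\Big)-\frac{\|v_1\|^2}{\|v_1'\|^2}\big(E_*-E_{\mathrm{err}}\big)
= 2E_{\mathrm{err}}\frac{\|v_1\|^2}{\|v_1'\|^2}\\
&\le O\Big(\big(\delta_{1,I}+\sqrt d\,\delta^{(1)}_{1,R}+\sqrt d\,\delta^{(1)}_{1,T}\big)|\bar w_2|\Big)+O\big(\log(d)\,\delta^{(1)}_2\big),\end{aligned}$$
where the leading terms cancel and the last step uses $\|v_1\|^2/\|v_1'\|^2 = O(1)$.

Now we consider the tangent movement. By Lemma B.5 and Corollary C.15, we have
$$\begin{aligned}\mathrm{Tan}(\dot v_1) &= \|v_1\|\,\bar w_2\,\mathbb E_x\{(f_*(x)-\bar f(x))\sigma'(v_1^\top x)(I-\bar v_1\bar v_1^\top)x\}+\|v_1\|\,\bar w_2\,\mathbb E_x\{(\bar f(x)-f(x))\sigma'(v_1^\top x)(I-\bar v_1\bar v_1^\top)x\}\pm O\big(\log(d)\,\delta^{(1)}_2\|v_1\|\big)\\
&= \pm O\Big(\big(\delta_{1,I}+\sqrt d\,\delta^{(1)}_{1,R}+\sqrt d\,\delta^{(1)}_{1,T}\big)|\bar w_2|\Big)\|v_1\|\pm O\big(\log(d)\,\delta^{(1)}_2\|v_1\|\big),\end{aligned}$$
where the first term vanishes because, by Lemma B.5, $\mathbb E_x\{(f_*(x)-\bar f(x))\sigma'(v_1^\top x)x\}$ is parallel to $\bar v_1$. As a result,
$$\Big\|\frac{d}{dt}\bar v_1\Big\| = \frac{\|\mathrm{Tan}(\dot v_1)\|}{\|v_1\|} = O\Big(\big(\delta_{1,I}+\sqrt d\,\delta^{(1)}_{1,R}+\sqrt d\,\delta^{(1)}_{1,T}\big)|\bar w_2|\Big)+O\big(\log(d)\,\delta^{(1)}_2\big).$$
Combining these two bounds completes the proof.

Proof of Lemma C.17. By the proof of Lemma C.16, and since $\bar w_2\,\mathbb E_x\{(f_*(x)-\bar f(x))\|x\|\} = \Theta(\log d)|\bar w_2|$, we have
$$\mathrm{Rad}(\dot v_1) = \Theta\Big(\frac{\log d}{\sqrt d}\Big)|\bar w_2|\,v_1\pm O\Big(\big(\delta_{1,I}+\sqrt d\,\delta^{(1)}_{1,R}+\sqrt d\,\delta^{(1)}_{1,T}\big)|\bar w_2|\Big)\|v_1\|\pm O\big(\log(d)\,\delta^{(1)}_2\|v_1\|\big).$$
Recall that $\delta^{(1)}_2\le|\bar w_2|/d$. Hence,
$$\frac{d}{dt}\|v_1\|^2 = \Theta\Big(\frac{\log d}{\sqrt d}\Big)|\bar w_2|\,\|v_1\|^2\pm O\Big(\big(\delta_{1,I}+\sqrt d\,\delta^{(1)}_{1,R}+\sqrt d\,\delta^{(1)}_{1,T}\big)|\bar w_2|\Big)\|v_1\|^2\pm O\big(\log(d)\,\delta^{(1)}_2\|v_1\|^2\big) = \Theta\Big(\frac{\log d}{\sqrt d}\Big)|\bar w_2|\,\|v_1\|^2.$$

C.3.3 ESTIMATIONS FOR THE SECOND LAYER

Now, we bound the growth of the spread of the second layer. Readers may first check the proof of Lemma D.14, which is essentially a simpler case of this result in which we do not need to deal with the projections; there, we show that the spread never grows. Here, the extra error comes from the projection.

Lemma C.18. Suppose that Induction Hypothesis C.2 is true at time $t$. Then we have $\frac{d}{dt}\big(\delta^{(1)}_2\big)^2\le\frac{O(1)}{d^{2.5}}\big(\delta^{(1)}_2\big)^2$.

Proof. Let $(v_2,r_2),(v_2',r_2')\in\mu_2$ and define $h_2(x) = v_2F(x)+r_2$ and $h_2'(x) = v_2'F(x)+r_2'$. We write
$$\dot v_2 = \mathbb E_{\|x\|\le1}\{(f_*(x)-f(x))F(x)\}-\mathbb E_{\|x\|>1}\Pi_{R_{v_2}}\big[f(x)\sigma'(h_2(x))F(x)\big] =: T_1(\dot v_2)+T_2(\dot v_2).$$
$T_1$ does not depend on $v_2$. For $T_2$, note that $\Pi_{R_{v_2}}\big[f(x)\sigma'(h_2(x))F(x)\big] = \Pi_{R_{v_2}/F(x)}[f(x)]\,\sigma'(h_2(x))F(x)$.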
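Similarly, for $r_2$, we have
$$\begin{aligned}\frac{d}{dt}(r_2-r_2')^2 &= -2\,\mathbb E_{\|x\|>1}\big\{f(x)\big(\sigma'(h_2(x))-\sigma'(h_2'(x))\big)(r_2-r_2')\big\}\\
&= -2\,\mathbb E_{\|x\|>1}\big\{\Pi_{R_{v_2}/F(x)}[f(x)]\big(\sigma'(h_2(x))-\sigma'(h_2'(x))\big)(r_2-r_2')\big\}\\
&\quad-2\,\mathbb E_{\|x\|>1}\big\{\big(f(x)-\Pi_{R_{v_2}/F(x)}[f(x)]\big)\big(\sigma'(h_2(x))-\sigma'(h_2'(x))\big)(r_2-r_2')\big\}.\end{aligned}$$
Combining these two equations, we obtain
$$\begin{aligned}\frac{d}{dt}\big((v_2-v_2')^2+(r_2-r_2')^2\big) &= -2\,\mathbb E_{\|x\|>1}\big\{\Pi_{R_{v_2}/F(x)}[f(x)]\big(\sigma'(h_2(x))-\sigma'(h_2'(x))\big)\big(h_2(x)-h_2'(x)\big)\big\}\\
&\quad-2\,\mathbb E_{\|x\|>1}\big\{\big(f(x)-\Pi_{R_{v_2}/F(x)}[f(x)]\big)\big(\sigma'(h_2(x))-\sigma'(h_2'(x))\big)(r_2-r_2')\big\}.\end{aligned}$$
Since $\sigma'$ is non-decreasing, the first term is nonpositive. For the second term, by Lemma C.11 and Lemma C.13, it can be bounded by
$$2\max_{x:\,\mathrm{sgn}(h_2(x))\ne\mathrm{sgn}(h_2'(x))}f(x)\cdot\mathbb E_x\{|\sigma'(h_2(x))-\sigma'(h_2'(x))|\}\cdot|r_2-r_2'|\le O\Big(\frac{\alpha\,(\delta^{(1)}_2)^3}{|\bar w_2|}\Big)\le\frac{O(1)}{d^{2.5}}\big(\delta^{(1)}_2\big)^2,$$
where the last step uses $\delta^{(1)}_2/|\bar w_2| = O(1/d)$ after Stage 1.2 and $\alpha = \Theta(1/d^{1.5})$.

The goal of Stage 2 is for gradient flow to converge to a point with loss $\varepsilon$. Similar to Stage 1, we maintain a set of induction hypotheses.

Induction Hypothesis D.1. Define $T_2 := \inf\{t\ge T_1:L = \varepsilon\}$. Define $\delta^{(2)}_{1,L^2}$, $\delta^{(2)}_{1,L^\infty}$ and $\delta^{(2)}_2$ through
$$\frac{d}{dt}\delta^{(2)}_{1,L^2} = \mathrm{ReLU}\Big(\frac{d}{dt}\big\|\tilde F-\|\cdot\|\big\|_{L^2}\Big),\qquad\frac{d}{dt}\delta^{(2)}_{1,L^\infty} = \mathrm{ReLU}\Big(\frac{d}{dt}\big\|\tilde F|_{\mathbb S^{d-1}}-1\big\|_{L^\infty}\Big),\qquad\frac{d}{dt}\delta^{(2)}_2 = 0,$$
with initial values satisfying (footnote 11)
$$\varepsilon\big(\delta^{(2)}_{1,L^\infty}\big)^2\le\delta^{(2)}_{1,L^2}\le\Theta\Big(\frac{\varepsilon}{d^6}\Big)\delta^{(2)}_{1,L^\infty},\qquad\delta^{(2)}_{1,L^2}\le O(\varepsilon^2),\qquad\delta^{(2)}_{1,L^\infty}(T_1)\le O(\varepsilon),\qquad\delta^{(2)}_2\le O(\varepsilon^2).$$
For any $t\in[T_1,T_2]$, we say that this induction hypothesis is true if the following hold.

Footnote 11. As mentioned in the footnote to Induction Hypothesis C.2, these δ's are defined as upper bounds for the corresponding errors, which gives some freedom in choosing their initial values. By Lemma C.4, we can choose the parameters so that the errors at the beginning of Stage 2 are arbitrarily small, so these conditions can indeed be satisfied. The first condition, which requires the $L^2$ error to be controlled from both sides by the $L^\infty$ error, may seem strange at first sight. The only reason we need it is to merge some second-order error terms into first-order ones.

(a) Error of the first layer. $\|\tilde F-\|\cdot\|\|_{L^2}\le\delta^{(2)}_{1,L^2}$ and $\|\tilde F|_{\mathbb S^{d-1}}-1\|_{L^\infty}\le\delta^{(2)}_{1,L^\infty}$.
(b) Spread of the second layer. $\|(v_2,r_2)-(v_2',r_2')\|\le\delta^{(2)}_2$ for all $(v_2,r_2),(v_2',r_2')\in\mu_2$.
(c) Regularity conditions. $\bar b_2\le1-\Theta(\sqrt\varepsilon)$. $-\bar w_2\alpha\le1+\Theta(\sqrt\varepsilon)$. $|\bar w_2|\le d$.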
$|\bar w_2|\ge\Theta(1/d^3)$. $\alpha\ge\Theta(1/d^{1.5})$.
(d) Bounds for the errors. $\delta^{(2)}_{1,L^\infty} = O(\delta^{(2)}_{1,L^\infty}(T_1))$ and $\delta^{(2)}_{1,L^2} = O(\delta^{(2)}_{1,L^2}(T_1))$.

The main lemma for Stage 2 is as follows.

Lemma D.2 (Stage 2). Induction Hypothesis D.1 is true throughout Stage 2, and Stage 2 takes at most $O(d^3/\varepsilon)$ time.

The rest of this section is organized as follows. In Section D.1, we collect some auxiliary results that will be used later. In Section D.2, we show that Induction Hypothesis D.1 is always true throughout Stage 2. (Also see Section B.1 for a discussion of the techniques used and some conventions.) Then, we derive a lower bound on the convergence rate in Section D.3. Finally, we prove Lemma D.2 in Section D.4.

D.1 AUXILIARY LEMMAS

D.1.1 THE DYNAMICS OF F, f AND L

Recall that, in Stage 2, we can ignore the projection operators, whence the dynamics of the neurons is given by
$$\dot v_1 = \mathbb E_x\big\{S(x)\big(\bar v_1\sigma(v_1^\top x)+\|v_1\|\sigma'(v_1^\top x)x\big)\big\},\qquad\dot v_2 = \mathbb E_x\{(f_*(x)-f(x))\sigma'(v_2F(x)+r_2)F(x)\},\qquad\dot r_2 = \mathbb E_x\{(f_*(x)-f(x))\sigma'(v_2F(x)+r_2)\}.$$
Now, we derive the equations that describe the dynamics of $\alpha$, $F$, and the loss $L$.

Lemma D.3 (Dynamics of α). In Stage 2, we have $\dot\alpha = \frac{4C_\Gamma}{\sqrt d}\,\mathbb E_{x'}\{S(x')F(x')\}$.

Lemma D.4 (Dynamics of F). In Stage 2, for each fixed $x$, we have
$$\frac{d}{dt}F(x) = 4\,\mathbb E_{x'}\Big\{S(x')\,\mathbb E_{w_1}\{\sigma(w_1^\top x')\sigma(w_1^\top x)\}\Big\}+\mathbb E_{x'}\Big\{S(x')\,\mathbb E_{w_1}\big\{\|w_1\|^2\sigma'(w_1^\top x')\sigma'(w_1^\top x)\langle(I-\bar w_1\bar w_1^\top)x',x\rangle\big\}\Big\}.$$
Note that in the above lemma, we decompose $\frac{d}{dt}F(x)$ into two terms, where the first term corresponds to the radial movement of $v_1$ and the second to the tangent movement.

Lemma D.5 (Dynamics of L). Define $W_2(x) = \mathbb E_{w_2,b_2}\{\sigma'(w_2F(x)+b_2)\,w_2\}$. In Stage 2, we have
$$\frac{d}{dt}L = -\mathbb E_{w_2,b_2,w_1}\|\Delta_{w_2,b_2,w_1}\|^2,\qquad\Delta_{w_2,b_2,w_1} := \mathbb E_x\left\{(f_*(x)-f(x))\begin{pmatrix}\sigma'(w_2F(x)+b_2)F(x)\\ \sigma'(w_2F(x)+b_2)\\ 2W_2(x)\sigma(w_1^\top x)\\ \|w_1\|W_2(x)\sigma'(w_1^\top x)(I-\bar w_1\bar w_1^\top)x\end{pmatrix}\right\}.$$
The entries of $\Delta_{w_2,b_2,w_1}$ correspond to the movements of $v_2$ and $r_2$, the radial movement of $v_1$, and the tangent movement of $v_1$, respectively. The proofs of these three lemmas are as follows.

Proof of Lemma D.3. Recall that $\alpha := \frac{C_\Gamma}{\sqrt d}\,\mathbb E_{w_1}\|w_1\|^2$.
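Hence, $\dot\alpha = \frac{2C_\Gamma}{\sqrt d}\,\mathbb E_{w_1}\langle w_1,\dot w_1\rangle$. We compute
$$\langle v_1,\dot v_1\rangle = \mathbb E_x\big\{S(x)\big(\sigma(v_1^\top x)\langle\bar v_1,v_1\rangle+\|v_1\|\sigma'(v_1^\top x)\langle x,v_1\rangle\big)\big\} = 2\,\mathbb E_x\{S(x)\|v_1\|\sigma(v_1^\top x)\}.$$
Thus,
$$\dot\alpha = \frac{2C_\Gamma}{\sqrt d}\,\mathbb E_{w_1}\Big\{2\,\mathbb E_x\{S(x)\|w_1\|\sigma(w_1^\top x)\}\Big\} = \frac{4C_\Gamma}{\sqrt d}\,\mathbb E_x\Big\{S(x)\,\mathbb E_{w_1}\{\|w_1\|\sigma(w_1^\top x)\}\Big\} = \frac{4C_\Gamma}{\sqrt d}\,\mathbb E_x\{S(x)F(x)\}.$$

Proof of Lemma D.4. First, we write
$$\frac{d}{dt}F(x) = \frac{d}{dt}\mathbb E_{w_1}\{\|w_1\|^2\sigma(\bar w_1^\top x)\} = \mathbb E_{w_1}\Big\{\frac{d}{dt}\|w_1\|^2\,\sigma(\bar w_1^\top x)\Big\}+\mathbb E_{w_1}\Big\{\|w_1\|^2\,\frac{d}{dt}\sigma(\bar w_1^\top x)\Big\}.$$
By the proof of Lemma D.3, the first term is $4\,\mathbb E_{x'}\{S(x')\,\mathbb E_{w_1}\{\sigma(w_1^\top x')\sigma(w_1^\top x)\}\}$. For the second term, we compute
$$\frac{d}{dt}\sigma(\bar v_1^\top x) = \sigma'(v_1^\top x)\Big\langle(I-\bar v_1\bar v_1^\top)\frac{\dot v_1}{\|v_1\|},x\Big\rangle = \sigma'(v_1^\top x)\Big\langle\mathbb E_{x'}\{S(x')\sigma'(v_1^\top x')(I-\bar v_1\bar v_1^\top)x'\},x\Big\rangle = \mathbb E_{x'}\big\{S(x')\sigma'(v_1^\top x')\sigma'(v_1^\top x)\langle(I-\bar v_1\bar v_1^\top)x',x\rangle\big\}.$$
Hence, the second term is
$$\mathbb E_{w_1}\Big\{\|w_1\|^2\frac{d}{dt}\sigma(\bar w_1^\top x)\Big\} = \mathbb E_{x'}\Big\{S(x')\,\mathbb E_{w_1}\big\{\|w_1\|^2\sigma'(w_1^\top x')\sigma'(w_1^\top x)\langle(I-\bar w_1\bar w_1^\top)x',x\rangle\big\}\Big\}.$$
Combining these completes the proof.

Proof of Lemma D.5. First, we write
$$\frac{d}{dt}f(x) = \mathbb E_{w_2,b_2}\{\sigma'(w_2F(x)+b_2)\,\dot w_2F(x)\}+\mathbb E_{w_2,b_2}\{\sigma'(w_2F(x)+b_2)\,\dot b_2\}+W_2(x)\frac{d}{dt}F(x) =: T_1\Big(\frac{d}{dt}f(x)\Big)+T_2\Big(\frac{d}{dt}f(x)\Big)+T_3\Big(\frac{d}{dt}f(x)\Big).$$
Note that $\frac{d}{dt}L = -\sum_{i=1}^3\mathbb E_x\{(f_*(x)-f(x))\,T_i(\frac{d}{dt}f(x))\}$. Now we compute each of these three terms separately. We have
$$\mathbb E_x\Big\{(f_*(x)-f(x))\,T_1\Big(\frac{d}{dt}f(x)\Big)\Big\} = \mathbb E_{w_2,b_2}\Big\{\mathbb E_x\{(f_*(x)-f(x))\sigma'(w_2F(x)+b_2)F(x)\}\,\dot w_2\Big\} = \mathbb E_{w_2,b_2}\Big(\mathbb E_x\{(f_*(x)-f(x))\sigma'(w_2F(x)+b_2)F(x)\}\Big)^2,$$
$$\mathbb E_x\Big\{(f_*(x)-f(x))\,T_2\Big(\frac{d}{dt}f(x)\Big)\Big\} = \mathbb E_{w_2,b_2}\Big\{\mathbb E_x\{(f_*(x)-f(x))\sigma'(w_2F(x)+b_2)\}\,\dot b_2\Big\} = \mathbb E_{w_2,b_2}\Big(\mathbb E_x\{(f_*(x)-f(x))\sigma'(w_2F(x)+b_2)\}\Big)^2.$$
Meanwhile, for $T_3$, by Lemma D.4, we have
$$\begin{aligned}\mathbb E_x\Big\{(f_*(x)-f(x))\,T_3\Big(\frac{d}{dt}f(x)\Big)\Big\} &= \mathbb E_x\Big\{(f_*(x)-f(x))\,W_2(x)\frac{d}{dt}F(x)\Big\}\\
&= 4\,\mathbb E_x\Big\{(f_*(x)-f(x))\,W_2(x)\,\mathbb E_{x'}\big\{S(x')\,\mathbb E_{w_1}\{\sigma(w_1^\top x')\sigma(w_1^\top x)\}\big\}\Big\}\\
&\quad+\mathbb E_x\Big\{(f_*(x)-f(x))\,W_2(x)\,\mathbb E_{x'}\big\{S(x')\,\mathbb E_{w_1}\{\|w_1\|^2\sigma'(w_1^\top x')\sigma'(w_1^\top x)\langle(I-\bar w_1\bar w_1^\top)x',x\rangle\}\big\}\Big\}\\
&= 4\,\mathbb E_{w_1}\big(\mathbb E_x\{S(x)\sigma(w_1^\top x)\}\big)^2+\mathbb E_{w_1}\Big\|\mathbb E_x\{S(x)\|w_1\|\sigma'(w_1^\top x)(I-\bar w_1\bar w_1^\top)x\}\Big\|^2.\end{aligned}$$
Combining these completes the proof.

D.1.2 ERROR-RELATED ESTIMATIONS

We collect some error-related estimates here. Most of them have been proved in Stage 1, except that here we use $|\bar w_2|\ge\Theta(1/d^3)$ to replace $(|\bar w_2|^{-1}+1)$ with $O(d^3)$. We repeat the statements here for easier reference.

Lemma D.6. Suppose that Induction Hypothesis D.1 is true at time $t$. If $v_2F(x)+r_2 = 0$ for some $(v_2,r_2)\in\mu_2$, then $|v_2'F(x)+r_2'|\le O(d^3\delta^{(2)}_2)$ for all $(v_2',r_2')\in\mu_2$.

Proof.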
See Lemma C.11.

Lemma D.7. Suppose that Induction Hypothesis D.1 is true at time $t$. Then, for any $(v_2,r_2)\in\mu_2$ and $x\in\mathbb R^d$, we have
$$f_*(x)\,\sigma'(v_2F(x)+r_2) = f_*(x)\qquad\text{and}\qquad f(x)\,\sigma'(v_2F(x)+r_2) = f(x)\pm O(d^3\delta^{(2)}_2).$$
As a corollary, we have $f(x) = \sigma(v_2F(x)+r_2)\pm O(d^3\delta^{(2)}_2)$ and $f(x) = \sigma(\bar w_2F(x)+\bar b_2)\pm O(d^3\delta^{(2)}_2)$.

Proof. See Lemma C.12.

Lemma D.8. Suppose that Induction Hypothesis D.1 is true at time $t$. Then we have $\|f-\bar f\|_{L^2}\le O\big(|\bar w_2\alpha|\,\delta^{(2)}_{1,L^2}\big)$.

Proof. Since $\sigma$ is 1-Lipschitz, we have
$$|f(x)-\bar f(x)| = \Big|\mathbb E_{w_2,b_2}\big\{\sigma(w_2F(x)+b_2)-\sigma(w_2\bar F(x)+b_2)\big\}\Big|\le O(|\bar w_2|)\,|F(x)-\bar F(x)|.$$
Thus, $\|f-\bar f\|_{L^2}^2\le O(\bar w_2^2\alpha^2)\,\|\tilde F-\|\cdot\|\|_{L^2}^2\le O\big(\bar w_2^2\alpha^2(\delta^{(2)}_{1,L^2})^2\big)$.

D.2 MAINTAINING THE INDUCTION HYPOTHESIS

In this section, we show that Induction Hypothesis D.1 is true throughout Stage 2. See Section B.1 for a discussion of and conventions for the techniques used here.

D.2.1 ERROR OF THE FIRST LAYER

Recall that we can decompose the loss as
$$L = \frac12\,\mathbb E_x\{(f_*(x)-\bar f(x))^2\}+\frac12\,\mathbb E_x\{(\bar f(x)-f(x))^2\}+\mathbb E_x\{(f_*(x)-\bar f(x))(\bar f(x)-f(x))\} =: L_1+L_2+L_3.$$
As discussed in the main text, the goal is to show that $L_2\approx\frac{\bar w_2^2}{2}\,\mathbb E\{(\bar F(x)-F(x))^2\}$ and $L_3\approx0$, so that $L$ decomposes into two terms: the first captures the difference between the target function $f_*$ and the infinite-width network $\bar f$, and the second measures the approximation error between $F$ and $\bar F$. We will show in Lemma D.11 that, as one may expect, $L_1$ does not affect $\tilde F$. Estimating the gradients of $L_2$ and $L_3$ is slightly more complicated. First, we need to introduce the following partition of the input space.

Lemma D.9. Define
$$R_1 := \sup\big\{R>0:\forall(v_2,r_2)\in\mu_2,\ \forall x\in R\,\mathbb S^{d-1},\ v_2F(x)+r_2>0\big\},\qquad R_2 := \sup\big\{R>0:\exists(v_2,r_2)\in\mu_2,\ \exists x\in R\,\mathbb S^{d-1},\ v_2F(x)+r_2>0\big\}.$$
Then, we partition the input space into $X_1 := \{\|x\|\le R_1\}$, $X_2 := \{R_1<\|x\|\le R_2\}$ and $X_3 := \{R_2<\|x\|\}$. In words, $X_1$ is the largest spherically symmetric set on which all second-layer neurons are activated, and $X_1\cup X_2$ is the largest spherically symmetric set on which at least one second-layer neuron is activated.
Suppose that Induction Hypothesis D.1 is true at time $t$. Then the following hold.
(a) $f_*$ vanishes on $X_2\cup X_3$, i.e., $R_1\ge1$; $f$ vanishes on $X_3$; and $R_2\le O(1/(|\bar w_2|\alpha))$.
(b) $R_2-R_1\le O\Big(\big(\delta^{(2)}_{1,L^\infty}+d^3\delta^{(2)}_2\big)\dfrac{1}{|\bar w_2|\alpha}\Big) =: \delta^{(2)}_{X_2}$. As a corollary, we have $\mathbb P[X_2]\le O(\delta^{(2)}_{X_2})$.
(c) $f\le O(\delta^{(2)}_{X_2})$ on $X_2$.

The above lemma implies that
$$L_2\approx\frac12\,\mathbb E_{X_1}\{(\bar f(x)-f(x))^2\} = \frac{\bar w_2^2}{2}\,\mathbb E_{X_1}\{(\bar F(x)-F(x))^2\},\qquad L_3\approx\mathbb E_{X_1}\{(f_*(x)-\bar f(x))(\bar f(x)-f(x))\} = \bar w_2\,\mathbb E_{X_1}\{(f_*(x)-\bar f(x))(\bar F(x)-F(x))\} = 0.$$
We formally establish this approximation in the following lemma.

Lemma D.10 (Gradient of L2 and L3). Suppose that Induction Hypothesis D.1 is true at time $t$. Then, for each $v_1\in\mu_1$, we have
$$\nabla_{v_1}L_2 = \nabla_{v_1}\Big[\frac{\bar w_2^2}{2}\,\mathbb E_{X_1}\{(\bar F(x)-F(x))^2\}\Big]\pm O\Big(\frac{(\delta^{(2)}_{X_2})^2\|v_1\|}{\alpha}\Big),\qquad\nabla_{v_1}L_3 = \pm O\Big(\frac{(\delta^{(2)}_{X_2})^2\|v_1\|}{\alpha}\Big).$$
Now, we are ready to derive the equation that governs the dynamics of $\tilde F$. This lemma implies that, at least approximately, the dynamics of $\tilde F$ depends only on $L_2$.

Lemma D.11 (Dynamics of F̃). Suppose that Induction Hypothesis D.1 is true at time $t$. Then, for each fixed $x$, we have
$$\frac{d}{dt}\tilde F(x) = -\frac{\bar w_2^2}{2}\,\mathbb E_{w_1}\Big\langle\nabla_{w_1}\tilde F(x),\ \nabla_{w_1}\mathbb E_{x'\in X_1}\{(\bar F(x')-F(x'))^2\}\Big\rangle\pm O(\cdots).$$
Then, we show that the signal term in $\frac{d}{dt}\tilde F(x)$ can only decrease the $L^2$ error, which is intuitively true since $L_2$ is, after all, the (rescaled) $L^2$ error. As a result, the $L^2$ error barely grows.

Lemma D.12 ($L^2$ approximation error). Suppose that Induction Hypothesis D.1 is true at time $t$. Then we have
$$\frac{d}{dt}\big\|\tilde F-\|\cdot\|\big\|_{L^2}^2\le O\big(d^5\,\delta^{(2)}_{1,L^2}(\delta^{(2)}_{X_2})^2\big).$$
Finally, we show that the change of $\tilde F|_{\mathbb S^{d-1}}$ depends on the $L^2$ error. As a result, as long as the $L^2$ error is small, the $L^\infty$ error cannot grow too fast.

Lemma D.13 ($L^\infty$ approximation error). Suppose that Induction Hypothesis D.1 is true at time $t$. Then, for any $\bar x\in\mathbb S^{d-1}$, we have
$$\Big|\frac{d}{dt}\tilde F(\bar x)\Big|\le O\big(d^3\delta^{(2)}_{1,L^2}+d^2(\delta^{(2)}_{X_2})^2\big).$$
The proofs of these lemmas are as follows.

Proof of Lemma D.9. (a) This follows directly from the construction of the partition and Induction Hypothesis D.1.
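(b) First, we write
$$F(x) = \alpha\|x\|+\alpha\|x\|\big(\tilde F(\bar x)-1\big) = \alpha\|x\|\pm\alpha\|x\|\,\delta^{(2)}_{1,L^\infty} = \alpha\|x\|\pm O\Big(\frac{\delta^{(2)}_{1,L^\infty}}{|\bar w_2|}\Big),$$
where the last equality comes from the fact that $f$ vanishes on $\{\|x\|\ge\Omega(\bar b_2/(\alpha|\bar w_2|))\}$, so only $\|x\|\le O(1/(\alpha|\bar w_2|))$ matters. Similarly, for any $(v_2,r_2)\in\mu_2$, since $|F(x)|\le O(1/|\bar w_2|)\le O(d^3)$ on this region, we have
$$v_2F(x)+r_2 = \Big\langle\begin{pmatrix}v_2\\ r_2\end{pmatrix},\begin{pmatrix}F(x)\\ 1\end{pmatrix}\Big\rangle = \bar w_2F(x)+\bar b_2\pm O\big(\delta^{(2)}_2(1+|F(x)|)\big) = \bar w_2F(x)+\bar b_2\pm O(d^3\delta^{(2)}_2).$$
Hence, for any $R>0$ and $x\in R\,\mathbb S^{d-1}$, we have
$$v_2F(x)+r_2 = \bar w_2\alpha\|x\|+\bar b_2\pm\underbrace{\big(O(\delta^{(2)}_{1,L^\infty})+O(d^3\delta^{(2)}_2)\big)}_{=:\ \delta_{\mathrm{Tmp}}}.$$
Writing $R_* := \bar b_2/(-\bar w_2\alpha)$, this gives
$$v_2F(x)+r_2>0\ \text{ if }\ \|x\|<\frac{\bar b_2-\delta_{\mathrm{Tmp}}}{-\bar w_2\alpha} = R_*-\frac{\delta_{\mathrm{Tmp}}}{-\bar w_2\alpha},\qquad v_2F(x)+r_2<0\ \text{ if }\ \|x\|>\frac{\bar b_2+\delta_{\mathrm{Tmp}}}{-\bar w_2\alpha} = R_*+\frac{\delta_{\mathrm{Tmp}}}{-\bar w_2\alpha}.$$
In other words, $R_1\ge R_*-\delta_{\mathrm{Tmp}}/(-\bar w_2\alpha)$ and $R_2\le R_*+\delta_{\mathrm{Tmp}}/(-\bar w_2\alpha)$. Thus,
$$R_2-R_1\le\frac{2\delta_{\mathrm{Tmp}}}{-\bar w_2\alpha} = O\Big(\frac{\delta^{(2)}_{1,L^\infty}+d^3\delta^{(2)}_2}{|\bar w_2|\alpha}\Big) = \delta^{(2)}_{X_2}\le O\big(\delta^{(2)}_{1,L^\infty}+d^3\delta^{(2)}_2\big)\,d^{4.5},$$
where the last step uses $|\bar w_2|\ge\Theta(1/d^3)$ and $\alpha\ge\Theta(1/d^{1.5})$. To complete the proof, it suffices to invoke Lemma B.1.
(c) By the definition of $R_2$, for any $x_0\in R_2\,\mathbb S^{d-1}$, we have $f(x_0) = 0$. Hence, for any $x\in X_2$, there exists some $x_0$ with $f(x_0) = 0$ and $\|x-x_0\|\le R_2-R_1\le\delta^{(2)}_{X_2}$. Since $f$ is $O(1)$-Lipschitz, we have, for any $x\in X_2$, $f(x) = f(x)-f(x_0)\le O(\delta^{(2)}_{X_2})$.

Proof of Lemma D.10. Since both $\bar f$ and $f$ vanish on $X_3$, it suffices to consider $X_1$ and $X_2$. Recall that all second-layer neurons are activated on $X_1$. Hence,
$$\frac12\,\mathbb E_{X_1}\{(\bar f(x)-f(x))^2\} = \frac{\bar w_2^2}{2}\,\mathbb E_{X_1}\{(\bar F(x)-F(x))^2\},\qquad\mathbb E_{X_1}\{(f_*(x)-\bar f(x))(\bar f(x)-f(x))\} = \bar w_2\,\mathbb E_{X_1}\{(f_*(x)-\bar f(x))(\bar F(x)-F(x))\} = 0,$$
where the last equality comes from Corollary B.4. Now we bound the influence of $X_2$. Note that both $\|\nabla_{v_1}f(x)\|$ and $\|\nabla_{v_1}\bar f(x)\|$ are bounded by $O(|\bar w_2|\,\|v_1\|\,\|x\|)$. Recall from Lemma D.9 that $f\le O(\delta^{(2)}_{X_2})$ on $X_2$ and $\mathbb P[X_2]\le O(\delta^{(2)}_{X_2})$. Therefore,
$$\big\|\nabla_{v_1}L_2|_{X_2}\big\|\le O(\delta^{(2)}_{X_2})\cdot O(\delta^{(2)}_{X_2})\cdot O\Big(|\bar w_2|\,\|v_1\|\,\frac{1}{|\bar w_2|\alpha}\Big) = O\Big(\frac{(\delta^{(2)}_{X_2})^2\|v_1\|}{\alpha}\Big).$$
The proof for $\nabla_{v_1}L_3|_{X_2}$ is the same.

Proof of Lemma D.11. For fixed $x\in\mathbb R^d$, we write
$$\frac{d}{dt}\tilde F(x) = -\frac1\alpha\,\mathbb E_{w_1}\langle\nabla_{w_1}F(x),\nabla_{w_1}L\rangle+\frac{\tilde F(x)}{\alpha}\,\mathbb E_{w_1}\langle\nabla_{w_1}\alpha,\nabla_{w_1}L\rangle.$$
First, we consider $L_1$. For each $v_1\in\mu_1$, since $\nabla_{v_1}\alpha$ is a multiple of $v_1$, we have
$$\nabla_{v_1}L_1 = -\mathbb E_x\{(f_*(x)-\bar f(x))\nabla_{v_1}\bar f(x)\} = -\mathbb E_x\Big\{(f_*(x)-\bar f(x))\,\mathbb E_{w_2,b_2}\{\sigma'(w_2\alpha\|x\|+b_2)w_2\}\|x\|\Big\}\nabla_{v_1}\alpha =: C_{\mathrm{Tmp},1}\,v_1.$$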
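Meanwhile, note that
$$\langle\nabla_{v_1}F(x),v_1\rangle = \big\langle\nabla_{v_1}\big(\|v_1\|^2\sigma(\bar v_1^\top x)\big),v_1\big\rangle = \big\langle\nabla_{v_1}(\|v_1\|^2)\,\sigma(\bar v_1^\top x),v_1\big\rangle = 2\|v_1\|^2\sigma(\bar v_1^\top x),\qquad\langle\nabla_{v_1}\alpha,v_1\rangle = \frac{C_\Gamma}{\sqrt d}\big\langle\nabla_{v_1}\|v_1\|^2,v_1\big\rangle = \frac{2C_\Gamma}{\sqrt d}\|v_1\|^2,$$
where the second equality in the first identity holds because $\nabla_{v_1}\sigma(\bar v_1^\top x)$ is orthogonal to $v_1$. Hence,
$$\begin{aligned}\frac{d}{dt}\tilde F(x)\Big|_{L_1} &:= -\frac1\alpha\,\mathbb E_{w_1}\langle\nabla_{w_1}F(x),\nabla_{w_1}L_1\rangle+\frac{\tilde F(x)}{\alpha}\,\mathbb E_{w_1}\langle\nabla_{w_1}\alpha,\nabla_{w_1}L_1\rangle\\
&= -\frac{2C_{\mathrm{Tmp},1}}{\alpha}\,\mathbb E_{w_1}\{\|w_1\|^2\sigma(\bar w_1^\top x)\}+C_{\mathrm{Tmp},1}\,\frac{\tilde F(x)}{\alpha}\,\frac{2C_\Gamma}{\sqrt d}\,\mathbb E_{w_1}\|w_1\|^2 = -\frac{2C_{\mathrm{Tmp},1}}{\alpha}F(x)+2C_{\mathrm{Tmp},1}\tilde F(x) = 0.\end{aligned}$$
Namely, $L_1$ does not affect $\tilde F$.

Now we consider $L_2$. By Lemma D.10, we have
$$\begin{aligned}\frac{d}{dt}\tilde F(x)\Big|_{L_2} &:= -\frac1\alpha\,\mathbb E_{w_1}\langle\nabla_{w_1}F(x),\nabla_{w_1}L_2\rangle+\frac{\tilde F(x)}{\alpha}\,\mathbb E_{w_1}\langle\nabla_{w_1}\alpha,\nabla_{w_1}L_2\rangle\\
&= -\frac{\bar w_2^2}{2\alpha}\,\mathbb E_{w_1}\Big\langle\nabla_{w_1}F(x),\nabla_{w_1}\mathbb E_{x'\in X_1}\{(\bar F(x')-F(x'))^2\}\Big\rangle+\frac{\bar w_2^2}{2\alpha}\,\tilde F(x)\,\mathbb E_{w_1}\Big\langle\nabla_{w_1}\alpha,\nabla_{w_1}\mathbb E_{x'\in X_1}\{(\bar F(x')-F(x'))^2\}\Big\rangle\pm O(\cdots).\end{aligned}$$
Note that we can rewrite $\nabla_{w_1}F(x)$ in the first term as $(\nabla_{w_1}\alpha)\tilde F(x)+\alpha\nabla_{w_1}\tilde F(x)$, so that part of it cancels with the second term. Then we get
$$\frac{d}{dt}\tilde F(x)\Big|_{L_2} = -\frac{\bar w_2^2}{2}\,\mathbb E_{w_1}\Big\langle\nabla_{w_1}\tilde F(x),\nabla_{w_1}\mathbb E_{x'\in X_1}\{(\bar F(x')-F(x'))^2\}\Big\rangle\pm O(\cdots).$$
For $L_3$, we can simply merge it into the error term of $\frac{d}{dt}\tilde F(x)|_{L_2}$.

Proof of Lemma D.12. By Lemma D.11, we have
$$\begin{aligned}\frac12\frac{d}{dt}\big\|\tilde F-\|\cdot\|\big\|_{L^2}^2 &= \mathbb E_x\Big\{(\tilde F(x)-\|x\|)\frac{d}{dt}\tilde F(x)\Big\}\\
&= -\frac{\bar w_2^2}{2}\,\mathbb E_x\Big\{(\tilde F(x)-\|x\|)\,\mathbb E_{w_1}\Big\langle\nabla_{w_1}\tilde F(x),\nabla_{w_1}\mathbb E_{x'\in X_1}\{(\bar F(x')-F(x'))^2\}\Big\rangle\Big\}\pm\mathbb E_x\big\{|\tilde F(x)-\|x\||\cdot O(\cdots)\big\}.\end{aligned}$$
The second term can be bounded by $O\big(\delta^{(2)}_{1,L^2}(\delta^{(2)}_{X_2})^2d^5\big)$. The first term is equal to
$$\mathrm{Tmp} := -\frac{\bar w_2^2}{4}\,\mathbb E_{w_1}\Big\langle\nabla_{w_1}\mathbb E_x(\tilde F(x)-\|x\|)^2,\ \nabla_{w_1}\mathbb E_{x'\in X_1}(\alpha\|x'\|-F(x'))^2\Big\rangle.$$
To complete the proof, it suffices to show that this is nonpositive. For each $w_1$, we have
$$\nabla_{w_1}\mathbb E_{x'\in X_1}(\alpha\|x'\|-F(x'))^2 = \mathbb E_{x'\in X_1}(\tilde F(x')-\|x'\|)^2\,\nabla_{w_1}\alpha^2+\alpha^2\,\mathbb E_{x'\in X_1}\nabla_{w_1}(\tilde F(x')-\|x'\|)^2.$$
Since the distribution of $x$ is spherically symmetric, $\mathbb E_{x'\in X_1}\nabla_{w_1}(\tilde F(x')-\|x'\|)^2$ and $\mathbb E_x\nabla_{w_1}(\tilde F(x)-\|x\|)^2$ have the same direction. Hence, using $\nabla_{w_1}\alpha^2 = \frac{4C_\Gamma\alpha}{\sqrt d}w_1$,
$$\mathrm{Tmp}\le-\frac{\bar w_2^2}{4}\,\mathbb E_{w_1}\Big\langle\nabla_{w_1}\mathbb E_x(\tilde F(x)-\|x\|)^2,\nabla_{w_1}\alpha^2\Big\rangle\,\mathbb E_{x'\in X_1}(\tilde F(x')-\|x'\|)^2 = -\frac{C_\Gamma}{\sqrt d}\,\bar w_2^2\alpha\;\mathbb E_{x'\in X_1}(\tilde F(x')-\|x'\|)^2\;\mathbb E_{w_1}\mathbb E_x\big\langle\nabla_{w_1}(\tilde F(x)-\|x\|)^2,w_1\big\rangle.$$
Then, we compute
$$\big\langle\nabla_{w_1}(\tilde F(x)-\|x\|)^2,w_1\big\rangle = 2(\tilde F(x)-\|x\|)\,\langle\nabla_{w_1}\tilde F(x),w_1\rangle = 2(\tilde F(x)-\|x\|)\Big(\frac{2\|w_1\|^2\sigma(\bar w_1^\top x)}{\alpha}-\tilde F(x)\,\frac{2C_\Gamma\|w_1\|^2}{\sqrt d\,\alpha}\Big).$$
Taking the expectation over $w_1$, one can see that this is $0$. Thus, $\mathrm{Tmp}\le0$.

Proof of Lemma D.13.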
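Recall from Lemma D.11 that
$$\frac{d}{dt}\bar F(x) = -\frac{\bar w_2^2}{2}\mathbb{E}_{w_1}\Big\langle\nabla_{w_1}\bar F(x), \nabla_{w_1}\mathbb{E}_{x'\in X_1}\big\{(\alpha\|x'\| - F(x'))^2\big\}\Big\rangle \pm O(\cdots).$$
For the first term, by the Cauchy-Schwarz inequality, we have
$$\Big|\Big\langle\nabla_{w_1}\bar F(x), \nabla_{w_1}\mathbb{E}_{x'\in X_1}\big\{(\alpha\|x'\| - F(x'))^2\big\}\Big\rangle\Big| \le 2\|\nabla_{w_1}\bar F(x)\|\,\mathbb{E}_{x'\in X_1}\Big\{|\alpha\|x'\| - F(x')|\,\big(\|\nabla_{w_1}(\alpha\|x'\|)\| + \|\nabla_{w_1}F(x')\|\big)\Big\} \le O(1)\,\|\nabla_{w_1}\bar F(x)\|\,\mathbb{E}_{x'\in X_1}\big\{|\alpha\|x'\| - F(x')|\,\|x'\|\big\}\,\|w_1\|,$$
and the last expectation is controlled by $\delta^{(2)}_{1,L^2}$. Thus,
$$\Big|\frac{d}{dt}\bar F(x)\Big| \le O\big(d^3\delta^{(2)}_{1,L^2} + d^2(\delta^{(2)}_{X_2})^2\big).$$

D.2.2 SPREAD OF THE SECOND LAYER

Lemma D.14. Suppose that Induction Hypothesis D.1 is true at time $t$. Then for any $(v_2, r_2), (v_2', r_2') \in \operatorname{supp}(\mu_2)$,
$$\frac{d}{dt}\big\|(v_2, r_2) - (v_2', r_2')\big\|^2 \le 0.$$
In words, the spread of the second layer never grows.

Proof. Let $(v_2, r_2), (v_2', r_2') \in \operatorname{supp}(\mu_2)$ be two second-layer neurons. For notational convenience, we define $h_2(x) = v_2F(x) + r_2$ and $h_2'(x) = v_2'F(x) + r_2'$. We have
$$\frac12\frac{d}{dt}\Big((v_2 - v_2')^2 + (r_2 - r_2')^2\Big) = (v_2 - v_2')\,\mathbb{E}_x\big\{(f_*(x) - f(x))F(x)\big(\sigma'(h_2(x)) - \sigma'(h_2'(x))\big)\big\} + (r_2 - r_2')\,\mathbb{E}_x\big\{(f_*(x) - f(x))\big(\sigma'(h_2(x)) - \sigma'(h_2'(x))\big)\big\} = \mathbb{E}_x\big\{(f_*(x) - f(x))\,(h_2(x) - h_2'(x))\,\big(\sigma'(h_2(x)) - \sigma'(h_2'(x))\big)\big\}.$$
By Lemma D.9, $\sigma'(h_2(x)) - \sigma'(h_2'(x)) = 0$ for all $x$ with $\|x\| \le 1$. Hence, since $f_*(x) = 0$ for $\|x\| > 1$,
$$\frac12\frac{d}{dt}\Big((v_2 - v_2')^2 + (r_2 - r_2')^2\Big) = \mathbb{E}_{x:\|x\|>1}\big\{(f_*(x) - f(x))\,(h_2(x) - h_2'(x))\,\big(\sigma'(h_2(x)) - \sigma'(h_2'(x))\big)\big\} = -\mathbb{E}_{x:\|x\|>1}\big\{f(x)\,(h_2(x) - h_2'(x))\,\big(\sigma'(h_2(x)) - \sigma'(h_2'(x))\big)\big\}.$$
Note that $f \ge 0$ and, since $\sigma'$ is non-decreasing, $(h_2(x) - h_2'(x))(\sigma'(h_2(x)) - \sigma'(h_2'(x))) \ge 0$. Thus,
$$\frac12\frac{d}{dt}\Big((v_2 - v_2')^2 + (r_2 - r_2')^2\Big) \le 0.$$

D.2.3 REGULARITY CONDITIONS

As we have mentioned earlier, we will mainly use the continuity argument to maintain the regularity conditions, so the problem reduces to estimating the derivative on the boundary. As an example, suppose that $\bar b_2 = 1 - \delta$ for some small $\delta > 0$. Lemma D.15 upper bounds the loss using $1 - \bar b_2$ and $1 + \bar w_2\alpha$. By that lemma, $|1 + \bar w_2\alpha|$ must be large; otherwise we would have $L < \varepsilon$. We can then use the fact that $|1 + \bar w_2\alpha|$ is large to estimate the derivative. The proofs for the other regularity conditions are similar. The exception is the one for $|\bar w_2|$, which is in the same spirit as the ones for the first-layer errors.

Lemma D.15.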
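Suppose that Induction Hypothesis D.1 is true at time $t$. Then we have
$$L \le O\left((1 - \bar b_2)^2 + \frac{(1 + \bar w_2\alpha)^2}{\bar w_2^2\alpha^2} + \big(\delta^{(2)}_{1,L^\infty} + d^3\delta^{(2)}_2\big)^2\right).$$

Lemma D.16. Suppose that Induction Hypothesis D.1 is true at time $t$ and $\bar b_2 = 1 - \Theta(\sqrt\varepsilon)$. Then $\frac{d}{dt}\bar b_2 < 0$.

Lemma D.17. Suppose that Induction Hypothesis D.1 is true at time $t$ and $\bar w_2\alpha = -1 + \Theta(\sqrt\varepsilon)$. Then $\frac{d}{dt}(\bar w_2\alpha) > 0$.

Lemma D.18. Suppose that Induction Hypothesis D.1 is true throughout Stage 2. Then $|\bar w_2| \le d$.

Lemma D.19. Suppose that Induction Hypothesis D.1 is true throughout Stage 2. Then $|\bar w_2| \ge \Theta(1/d^3)$ and $\alpha \ge \Theta(1/d^{1.5})$.

The proofs of this subsubsection are gathered below.

Proof of Lemma D.15. For any $x \in X_1 \cup X_2$, by Lemma D.7 and the Lipschitzness of $\sigma$, we have
$$f(x) = \sigma\big(\bar w_2\alpha\bar F(x) + \bar b_2\big) \pm O\big(d^3\delta^{(2)}_2\big) = \sigma(1 - \|x\|) \pm |1 - \bar b_2| \pm \big|{-\|x\|} - \bar w_2\alpha\bar F(x)\big| \pm O\big(d^3\delta^{(2)}_2\big).$$
By Induction Hypothesis D.1, and since $\|x\| \le O(1/(|\bar w_2|\alpha))$ on $X_1 \cup X_2$, we have, for any $x \in X_1 \cup X_2$,
$$\big|{-\|x\|} - \bar w_2\alpha\bar F(x)\big| = \big|{-1} - \bar w_2\alpha\bar F(\bar x)\big|\,\|x\| \le |1 + \bar w_2\alpha|\,\|x\| + \big|1 - \bar F(\bar x)\big|\,|\bar w_2|\alpha\|x\| \le O\!\left(\frac{|1 + \bar w_2\alpha|}{|\bar w_2|\alpha}\right) + O\big(\delta^{(2)}_{1,L^\infty}\big).$$
Therefore,
$$f(x) = f_*(x) \pm |1 - \bar b_2| \pm O\!\left(\frac{|1 + \bar w_2\alpha|}{|\bar w_2|\alpha}\right) \pm O\big(\delta^{(2)}_{1,L^\infty} + d^3\delta^{(2)}_2\big),$$
and hence
$$L = \frac12\mathbb{E}_x(f_*(x) - f(x))^2 \le \frac12\left(|1 - \bar b_2| + O\!\left(\frac{|1 + \bar w_2\alpha|}{|\bar w_2|\alpha}\right) + O\big(\delta^{(2)}_{1,L^\infty} + d^3\delta^{(2)}_2\big)\right)^2 \le O\left((1 - \bar b_2)^2 + \frac{(1 + \bar w_2\alpha)^2}{\bar w_2^2\alpha^2} + \big(\delta^{(2)}_{1,L^\infty} + d^3\delta^{(2)}_2\big)^2\right).$$

Proof of Lemma D.16. By Lemma D.7, for any $(v_2, r_2) \in \operatorname{supp}(\mu_2)$, we have
$$\frac{d}{dt}r_2 = \mathbb{E}_x\big\{f_*(x) - \sigma(\bar w_2F(x) + \bar b_2)\big\} \pm O\big(d^3\delta^{(2)}_2\big).$$
Then, by Induction Hypothesis D.1 and the Lipschitzness of $\sigma$, we have
$$\sigma(\bar w_2F(x) + \bar b_2) = \sigma\big(\bar w_2\alpha\|x\|\bar F(\bar x) + \bar b_2\big) = \sigma(\bar w_2\alpha\|x\| + \bar b_2) \pm O\big(\delta^{(2)}_{1,L^\infty}\big).$$
Therefore,
$$\frac{d}{dt}\bar b_2 = \mathbb{E}_x\big\{f_*(x) - \sigma(\bar w_2\alpha\|x\| + \bar b_2)\big\} \pm O\big(\delta^{(2)}_{1,L^\infty} + d^3\delta^{(2)}_2\big).$$
Since $L \ge \varepsilon$ and $\bar b_2 = 1 - \delta$, Lemma D.15 gives
$$\frac{(1 + \bar w_2\alpha)^2}{\bar w_2^2\alpha^2} \ge \Omega(\varepsilon) - O(\delta^2) - O\Big(\big(\delta^{(2)}_{1,L^\infty} + d^3\delta^{(2)}_2\big)^2\Big) \ge \Omega(\varepsilon).$$
Since $\bar w_2\alpha \ge -1$, this implies $\bar w_2\alpha \ge -1 + \Omega(|\bar w_2|\alpha\sqrt\varepsilon)$. In fact, this implies $\bar w_2\alpha \ge -1 + \Omega(\sqrt\varepsilon)$ even when $|\bar w_2|\alpha$ is $o(1)$, because in that case $\bar w_2\alpha \ge -1 + \Omega(\sqrt\varepsilon)$ holds directly. Hence,
$$\sigma(\bar w_2\alpha\|x\| + \bar b_2) \ge \sigma\big((-1 + \Omega(\sqrt\varepsilon))\|x\| + 1 - \delta\big) = \sigma\big(1 - \|x\| + \Omega(\sqrt\varepsilon)\|x\| - \delta\big).$$
Thus,
$$\frac{d}{dt}\bar b_2 \le \mathbb{E}_x\big\{\sigma(1 - \|x\|) - \sigma\big(1 - \|x\| + \Omega(\sqrt\varepsilon)\|x\| - \delta\big)\big\} + O\big(\delta^{(2)}_{1,L^\infty} + d^3\delta^{(2)}_2\big) \le -\Omega(\sqrt\varepsilon) + \delta + O\big(\delta^{(2)}_{1,L^\infty} + d^3\delta^{(2)}_2\big).$$
As long as the constant in $\delta = \Theta(\sqrt\varepsilon)$ is sufficiently small, this implies $\frac{d}{dt}\bar b_2 < 0$ when $\bar b_2 = 1 - \delta$.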
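The shape of the bound in Lemma D.15 can be sanity-checked in one dimension. Take the idealized case where all the $\delta$ errors vanish and $|\bar w_2|\alpha \approx 1$. Write $\bar w_2\alpha = -1 + \eta$ and $\bar b_2 = 1 - \beta$, so that the network reduces to $r \mapsto \sigma((-1 + \eta)r + 1 - \beta)$ with $r = \|x\|$.

Since $\sigma$ is 1-Lipschitz, the pointwise gap to $f_*$ is at most $|\eta| r + |\beta|$. The loss is therefore $O(\eta^2 + \beta^2)$. In particular, $L \ge \varepsilon$ forces $\max(|\eta|, |\beta|) = \Omega(\sqrt\varepsilon)$, which is the way the continuity argument above uses the lemma.

The sketch below assumes a hypothetical radial law $r \sim \mathrm{Uniform}[0, 2]$ purely for illustration; it is not the paper's input distribution. The function names are also illustrative.

```python
def relu(z: float) -> float:
    """ReLU activation sigma(z) = max(z, 0)."""
    return z if z > 0.0 else 0.0


def idealized_loss(eta: float, beta: float, n: int = 20_000) -> float:
    """Midpoint-rule value of 0.5 * E_r[(f(r) - f_star(r))^2] for r ~ Uniform[0, 2].

    f_star(r) = relu(1 - r) is the target, and f(r) = relu((-1 + eta) r + 1 - beta) is the
    idealized network with w2 * alpha = -1 + eta and b2 = 1 - beta (all delta errors set to 0).
    The uniform radial law is a hypothetical choice for illustration only.
    """
    h = 2.0 / n
    total = 0.0
    for i in range(n):
        r = (i + 0.5) * h  # midpoint of the i-th cell of [0, 2]
        gap = relu((-1.0 + eta) * r + 1.0 - beta) - relu(1.0 - r)
        total += gap * gap
    return 0.5 * total / n  # (1/n) * sum approximates the expectation over Uniform[0, 2]


def lipschitz_bound(eta: float, beta: float) -> float:
    """Upper bound 0.5 * (2|eta| + |beta|)^2, from |gap| <= |eta| r + |beta| and r <= 2."""
    return 0.5 * (2.0 * abs(eta) + abs(beta)) ** 2
```

For $\beta = 0$ and $0 < \eta < 1$, a direct integration under this assumed law gives $\eta^2 / (12(1 - \eta))$. For example, `idealized_loss(0.1, 0.0)` should be close to $0.01/10.8 \approx 9.26\times 10^{-4}$, which is quadratic in $\eta$ as Lemma D.15 predicts.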
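Proof of Lemma D.17. By Lemma D.3 and Lemma D.7, we have
$$\frac{d}{dt}(\bar w_2\alpha) = \mathbb{E}_x\Big\{\big(f_*(x) - \sigma(\bar w_2\alpha\|x\|\bar F(\bar x) + \bar b_2)\big)\bar F(x)\Big\}\,\alpha\left(\alpha + \frac{4C_\Gamma}{\sqrt d}\bar w_2^2\right) \pm O\big(d^3\log(d)\,\delta^{(2)}_2\big)\,\alpha\left(\alpha + \frac{4C_\Gamma}{\sqrt d}\bar w_2^2\right).$$
Now we estimate the coefficient of the first term. Suppose that $\bar w_2\alpha = -1 + \delta$ for some $\delta = \Theta(\sqrt\varepsilon)$ with a sufficiently small constant. Then, by Lemma D.15, we have $(1 - \bar b_2)^2 \ge \Omega(\varepsilon) - O(\delta^2) = \Omega(\varepsilon)$. Hence, $\bar b_2 \le 1 - \Theta(\sqrt\varepsilon)$. Also note that $\bar w_2\alpha = \Theta(1)$ implies that it suffices to consider $x$ with $\|x\| = \Theta(1)$. As a result, we have
$$\sigma\big(\bar w_2\alpha\|x\|\bar F(\bar x) + \bar b_2\big) = \sigma(\bar w_2\alpha\|x\| + \bar b_2) \pm O\big(\delta^{(2)}_{1,L^\infty}\big) \le \sigma\big(1 - \|x\| - \Theta(\sqrt\varepsilon)\big) + O\big(\delta^{(2)}_{1,L^\infty}\big).$$
Then, we bound the coefficient as
$$\mathbb{E}_x\Big\{\big(f_*(x) - \sigma(\bar w_2\alpha\|x\|\bar F(\bar x) + \bar b_2)\big)\bar F(x)\Big\} \ge \mathbb{E}_x\Big\{\big(\Theta(\sqrt\varepsilon) - O\big(\delta^{(2)}_{1,L^\infty}\big)\big)\bar F(x)\Big\}.$$
Thus,
$$\frac{d}{dt}(\bar w_2\alpha) \ge \Big(\Omega(\sqrt\varepsilon) - O\big(d^3\delta^{(2)}_2\log(d)\big)\Big)\,\alpha\left(\alpha + \frac{4C_\Gamma}{\sqrt d}\bar w_2^2\right).$$

Proof of Lemma D.18. By Lemma D.3 and Lemma D.7, we have
$$\frac{d}{dt}\bar w_2 = \mathbb{E}_x\{(f_*(x) - f(x))F(x)\} \pm O\big(d^3\log(d)\,\delta^{(2)}_2\big), \qquad \frac{d}{dt}\alpha = \frac{4C_\Gamma}{\sqrt d}\,\mathbb{E}_x\{(f_*(x) - f(x))F(x)\}\,\bar w_2 \pm O\big(d^{2.5}\log(d)\,\delta^{(2)}_2\big).$$
As a result, $\big|\frac{d}{dt}\big(\alpha - \frac{2C_\Gamma}{\sqrt d}\bar w_2^2\big)\big| \le O\big(d^4\delta^{(2)}_2\big)$. Also recall that $\alpha \approx \frac{2C_\Gamma}{\sqrt d}\bar w_2^2$ at $T_1$. Thus, throughout Stage 2, we always have $\big|\alpha - \frac{2C_\Gamma}{\sqrt d}\bar w_2^2\big| \le 1/d$. Since $|\bar w_2\alpha| \le 1$, this implies $|\bar w_2| \le O(d^{1/6}) \le d$.

Proof of Lemma D.19. Recall from the proof of Lemma D.18 that $\big|\alpha - \frac{2C_\Gamma}{\sqrt d}\bar w_2^2\big| \le 1/d$. Hence, when $\alpha = \Theta(1/d^{1.5})$, we have $|\bar w_2| \le O(1/d)$. The estimates in Stage 1, mutatis mutandis, show that both $\alpha$ and $|\bar w_2|$ will grow in this case.

D.3 CONVERGENCE RATE

Recall from Lemma D.5 that $\frac{d}{dt}L = -\mathbb{E}_{w_2,b_2,w_1}\|\nabla_{w_2,b_2,w_1}\|^2$, where
$$\nabla_{w_2,b_2,w_1} := \mathbb{E}_x\left\{(f_*(x) - f(x))\begin{bmatrix}\sigma'(w_2F(x) + b_2)F(x)\\ \sigma'(w_2F(x) + b_2)\\ 2W_2(x)\sigma(w_1\cdot x)\,\bar w_1\\ W_2(x)\sigma'(w_1\cdot x)(I - \bar w_1\bar w_1^\top)x\end{bmatrix}\right\}.$$

Lemma D.20. Suppose that Induction Hypothesis D.1 is true at time $t$. Then we have
$$\frac{d}{dt}L \le -\|\hat\nabla\|^2 + O\big(\delta^{(2)}_{1,L^2} + d^3\delta^{(2)}_2\big)d^4, \qquad\text{where}\quad \hat\nabla := \mathbb{E}_x\left\{(f_*(x) - f(x))\begin{bmatrix}\sqrt{\alpha^2 + \frac{4C_\Gamma}{\sqrt d}\bar w_2^2\alpha}\,\|x\|\\ 1\end{bmatrix}\right\}.$$

Lemma D.21. Suppose that Induction Hypothesis D.1 is true at time $t$. Then we have $\|\hat\nabla\| \ge \Omega(\alpha L) - O\big(\delta^{(2)}_{1,L^2} + d^3\delta^{(2)}_2\big)$.

Lemma D.22 (Stage 2). Suppose that Induction Hypothesis D.1 is true throughout Stage 2. Then $T_2 - T_1 \le O(d^3/\varepsilon)$.

Proof of Lemma D.20. Since $\|\nabla_{w_2,b_2,w_1}\|^2$ is at least the squared norm of its first three entries, we can safely ignore the last entry and only consider the first three entries.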
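By Lemma D.7, we have
$$[\nabla_{w_2,b_2,w_1}]_{1:3} = \mathbb{E}_x\left\{(f_*(x) - f(x))\begin{bmatrix}F(x)\\ 1\\ 2\bar w_2\sigma(w_1\cdot x)\end{bmatrix}\right\} \pm O\big(d^3\delta^{(2)}_2\big)\begin{bmatrix}\alpha\log(d)\\ 1\\ |\bar w_2|\,\|w_1\|\log(d)\end{bmatrix}.$$
Furthermore, we have
$$\mathbb{E}_x\{(f_*(x) - f(x))F(x)\} = \mathbb{E}_x\{(f_*(x) - f(x))\alpha\|x\|\} + \mathbb{E}_x\big\{(f_*(x) - f(x))\alpha(\bar F(x) - \|x\|)\big\} = \mathbb{E}_x\{(f_*(x) - f(x))\alpha\|x\|\} \pm O\big(\alpha\delta^{(2)}_{1,L^2}\big).$$
Meanwhile, for $[\nabla_{w_2,b_2,w_1}]_3$, by Lemma B.3 and Lemma D.8, we have
$$2\bar w_2\,\mathbb{E}_x\{(f_*(x) - f(x))\sigma(w_1\cdot x)\} = 2\bar w_2\,\mathbb{E}_x\big\{(f_*(x) - \tilde f(x))\sigma(w_1\cdot x)\big\} + 2\bar w_2\,\mathbb{E}_x\big\{(\tilde f(x) - f(x))\sigma(w_1\cdot x)\big\} = \frac{2C_\Gamma}{\sqrt d}\bar w_2\,\mathbb{E}_x\big\{(f_*(x) - \tilde f(x))\|x\|\big\}\|w_1\| \pm 2|\bar w_2|\,\|w_1\|\,\|\tilde f - f\|_{L^2}\,O(\cdots) = \frac{2C_\Gamma}{\sqrt d}\bar w_2\,\mathbb{E}_x\big\{(f_*(x) - \tilde f(x))\|x\|\big\}\|w_1\| \pm O\big(|\bar w_2|^{1.5}\alpha^{0.5}\|w_1\|\,\delta^{(2)}_{1,L^2}\big).$$
Repeating the above procedure, we can replace $\tilde f$ in the first term with $f$. Therefore,
$$[\nabla_{w_2,b_2,w_1}]_{1:3} = \mathbb{E}_x\left\{(f_*(x) - f(x))\begin{bmatrix}\alpha\|x\|\\ 1\\ \frac{2C_\Gamma}{\sqrt d}\bar w_2\|x\|\,\|w_1\|\end{bmatrix}\right\} \pm O\big(\delta^{(2)}_{1,L^2}\big)\begin{bmatrix}\alpha\\ 0\\ |\bar w_2|^{1.5}\alpha^{0.5}\|w_1\|\end{bmatrix} \pm O\big(d^3\delta^{(2)}_2\big)\begin{bmatrix}\alpha\log(d)\\ 1\\ |\bar w_2|\,\|w_1\|\log(d)\end{bmatrix}$$
$$= \mathbb{E}_x\left\{(f_*(x) - f(x))\begin{bmatrix}\alpha\|x\|\\ 1\\ \frac{2C_\Gamma}{\sqrt d}\bar w_2\|x\|\,\|w_1\|\end{bmatrix}\right\} \pm O\big(\delta^{(2)}_{1,L^2} + d^3\delta^{(2)}_2\big)\begin{bmatrix}\alpha\log(d)\\ 1\\ |\bar w_2|\,\|w_1\|\log(d)\end{bmatrix}.$$
Now we estimate the expected norm of $[\nabla_{w_2,b_2,w_1}]_{1:3}$. First, we have
$$[\nabla_{w_2,b_2,w_1}]_1^2 = \big(\mathbb{E}_x\{(f_*(x) - f(x))\|x\|\}\big)^2\alpha^2 \pm O\big(\delta^{(2)}_{1,L^2} + d^3\delta^{(2)}_2\big)\alpha^2\log(d),$$
$$[\nabla_{w_2,b_2,w_1}]_2^2 = \big(\mathbb{E}_x\{f_*(x) - f(x)\}\big)^2 \pm O\big(\delta^{(2)}_{1,L^2} + d^3\delta^{(2)}_2\big).$$
For $[\nabla_{w_2,b_2,w_1}]_3$, we have
$$\mathbb{E}_{w_1}[\nabla_{w_2,b_2,w_1}]_3^2 = \frac{4C_\Gamma^2\bar w_2^2}{d}\big(\mathbb{E}_x\{(f_*(x) - f(x))\|x\|\}\big)^2\,\mathbb{E}_{w_1}\|w_1\|^2 \pm O\big(\delta^{(2)}_{1,L^2} + d^3\delta^{(2)}_2\big)\,\bar w_2^2\,\mathbb{E}_{w_1}\|w_1\|^2 = \big(\mathbb{E}_x\{(f_*(x) - f(x))\|x\|\}\big)^2\,\frac{4C_\Gamma}{\sqrt d}\bar w_2^2\alpha \pm O\big(\delta^{(2)}_{1,L^2} + d^3\delta^{(2)}_2\big)\,\bar w_2^2\alpha\log(d).$$
Therefore,
$$\mathbb{E}_{w_1}\big\|[\nabla_{w_2,b_2,w_1}]_{1:3}\big\|^2 = \big(\mathbb{E}_x\{(f_*(x) - f(x))\|x\|\}\big)^2\left(\alpha^2 + \frac{4C_\Gamma}{\sqrt d}\bar w_2^2\alpha\right) + \big(\mathbb{E}_x\{f_*(x) - f(x)\}\big)^2 \pm O\big(\delta^{(2)}_{1,L^2} + d^3\delta^{(2)}_2\big)d^4 = \|\hat\nabla\|^2 \pm O\big(\delta^{(2)}_{1,L^2} + d^3\delta^{(2)}_2\big)d^4.$$

Proof of Lemma D.21. For notational simplicity, put $A := \sqrt{\alpha^2 + \frac{4C_\Gamma}{\sqrt d}\bar w_2^2\alpha}$. Then we can write
$$\hat\nabla = \mathbb{E}_x\left\{(f_*(x) - f(x))\begin{bmatrix}A\|x\|\\ 1\end{bmatrix}\right\}.$$
Define
$$\hat\Delta := \begin{bmatrix}-1 - \alpha\bar w_2\\ A(1 - \bar b_2)\end{bmatrix}.$$
By Induction Hypothesis D.1, $\|\hat\Delta\| \le O(1)$. Hence, in order to lower bound $\|\hat\nabla\|$, it suffices to lower bound $\langle\hat\nabla, \hat\Delta\rangle$. We have
$$\langle\hat\nabla, \hat\Delta\rangle = A\,\mathbb{E}_x\Big\{(f_*(x) - f(x))\big(-\|x\| + 1 - (\alpha\bar w_2\|x\| + \bar b_2)\big)\Big\}.$$
First, for those $x \in \{\|x\| \le 1\}$, we have $f_*(x) = -\|x\| + 1$ and $f(x) = \bar w_2F(x) + \bar b_2 = \bar w_2\alpha\|x\| + \bar b_2 + \bar w_2\alpha(\bar F(x) - \|x\|)$. Hence,
$$\mathbb{E}_{\|x\|\le1}\Big\{(f_*(x) - f(x))\big(-\|x\| + 1 - (\alpha\bar w_2\|x\| + \bar b_2)\big)\Big\} = \mathbb{E}_{\|x\|\le1}(f_*(x) - f(x))^2 + \mathbb{E}_{\|x\|\le1}\big\{(f_*(x) - f(x))\,\bar w_2\alpha(\bar F(x) - \|x\|)\big\} \ge \mathbb{E}_{\|x\|\le1}(f_*(x) - f(x))^2 - O\big(|\bar w_2\alpha|\,\delta^{(2)}_{1,L^2}\big).$$
Then, for $x \in \{\|x\| \ge 1\}$, note that $-\|x\| + 1 \le 0$ and $f_*(x) = 0$.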
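Therefore, since $f \ge 0$, we have
$$\mathbb{E}_{\|x\|\ge1}\Big\{(f_*(x) - f(x))\big(-\|x\| + 1 - (\alpha\bar w_2\|x\| + \bar b_2)\big)\Big\} = -\mathbb{E}_{\|x\|\ge1}\Big\{f(x)\big(-\|x\| + 1 - (\alpha\bar w_2\|x\| + \bar b_2)\big)\Big\} \ge \mathbb{E}_{\|x\|\ge1}\big\{f(x)(\alpha\bar w_2\|x\| + \bar b_2)\big\}.$$
Then, we compute
$$\mathbb{E}_{\|x\|\ge1}\big\{f(x)(\alpha\bar w_2\|x\| + \bar b_2)\big\} = \mathbb{E}_{\|x\|\ge1}\big\{f(x)(\bar w_2F(x) + \bar b_2)\big\} + \mathbb{E}_{\|x\|\ge1}\big\{f(x)\,\alpha\bar w_2(\|x\| - \bar F(x))\big\} \ge \mathbb{E}_{\|x\|\ge1}f^2(x) - O\big(d^3\delta^{(2)}_2\big) - O\big(\alpha|\bar w_2|\,\delta^{(2)}_{1,L^2}\big),$$
where the last step comes from Lemma D.7 and Induction Hypothesis D.1. Combining these two cases, we obtain
$$\langle\hat\nabla, \hat\Delta\rangle \ge A\,\mathbb{E}_x(f_*(x) - f(x))^2 - O\big(|\bar w_2\alpha|\,\delta^{(2)}_{1,L^2}\big) - O\big(d^3\delta^{(2)}_2\big).$$
Finally, note that $A \ge \alpha$. Thus, $\|\hat\nabla\| \ge \Omega(\alpha L) - O\big(\delta^{(2)}_{1,L^2} + d^3\delta^{(2)}_2\big)$.

Proof of Lemma D.22. By Lemma D.20, Lemma D.21, and $\alpha \ge \Theta(1/d^{1.5})$ from Lemma D.19, we have
$$\frac{d}{dt}L \le -\Omega(\alpha^2L^2) + O\big(\delta^{(2)}_{1,L^2} + d^3\delta^{(2)}_2\big)d^4 \le -\Omega\!\left(\frac{L^2}{d^3}\right).$$
Thus, for any $T \in [T_1, T_2]$,
$$L(T) \le \left(\Omega(d^{-3})\,(T - T_1) + \frac{1}{L(T_1)}\right)^{-1}.$$
Thus, it takes at most $O(d^3/\varepsilon)$ time for $L$ to reach $\varepsilon$.

D.4 PROOF OF THE MAIN LEMMA

Proof of Lemma D.2. The Induction Hypothesis is maintained in Section D.2, and by Lemma D.22 we have $T_2 - T_1 \le O(d^3/\varepsilon)$. Now we consider the first-layer errors. Recall that
$$\frac{d}{dt}\big(\delta^{(2)}_{1,L^2}\big)^2 = O\big(d^5\delta^{(2)}_{1,L^2}(\delta^{(2)}_{X_2})^2\big), \qquad \frac{d}{dt}\delta^{(2)}_{1,L^\infty} = O\big(d^3\delta^{(2)}_{1,L^2} + d^2(\delta^{(2)}_{X_2})^2\big).$$
Recall also that $\delta^{(2)}_{X_2} := O(1)d^{4.5}\big(\delta^{(2)}_{1,L^\infty} + d^3\delta^{(2)}_2\big)$. For simplicity, we choose $\delta^{(2)}_{1,L^\infty} \ge d^3\delta^{(2)}_2$, so that $\delta^{(2)}_{X_2} = O\big(d^{4.5}\delta^{(2)}_{1,L^\infty}\big)$. Then we have
$$\frac{d}{dt}\big(\delta^{(2)}_{1,L^2}\big)^2 = O\big(d^{14}\delta^{(2)}_{1,L^2}(\delta^{(2)}_{1,L^\infty})^2\big), \qquad \frac{d}{dt}\delta^{(2)}_{1,L^\infty} = O\big(d^3\delta^{(2)}_{1,L^2} + d^{11}(\delta^{(2)}_{1,L^\infty})^2\big).$$
We choose $\delta^{(2)}_{1,L^2}(T_1)$ and $\delta^{(2)}_{1,L^\infty}(T_1)$ such that
$$\Theta\!\left(\frac{d^{17}}{\varepsilon}\right)\big(\delta^{(2)}_{1,L^\infty}\big)^2 \le \delta^{(2)}_{1,L^2} \le \Theta\!\left(\frac{\varepsilon}{d^6}\right)\delta^{(2)}_{1,L^\infty} \quad\text{and}\quad \delta^{(2)}_{1,L^\infty}(T_1) \le \Theta\!\left(\frac{\varepsilon}{d^{14}}\right).$$
Note that this is possible because $\delta^{(2)}_{1,L^2}(T_1)$ and $\delta^{(2)}_{1,L^\infty}(T_1)$ can be chosen to be arbitrarily polynomially small. When this is true, we have
$$\frac{d}{dt}\big(\delta^{(2)}_{1,L^2}\big)^2 \le O\!\left(\frac{\varepsilon}{d^3}\right)\big(\delta^{(2)}_{1,L^2}\big)^2 \quad\text{and}\quad \frac{d}{dt}\delta^{(2)}_{1,L^\infty} \le O\!\left(\frac{\varepsilon}{d^3}\right)\delta^{(2)}_{1,L^\infty}.$$
Thus, by induction, within $O(d^3/\varepsilon)$ time these two errors grow to at most $O\big(\delta^{(2)}_{1,L^2}(T_1)\big)$ and $O\big(\delta^{(2)}_{1,L^\infty}(T_1)\big)$, respectively.

E FROM GRADIENT FLOW TO GRADIENT DESCENT

Converting the above gradient flow argument into a gradient descent one can be done in a standard way, provided that we can generate fresh samples at each iteration.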
First, by choosing a sufficiently small step size, one can make sure that within each step, the difference between gradient descent and gradient flow is inverse polynomially small. Note that our argument is built upon the induction hypotheses, so we do not need to worry about the accumulation of errors. Moreover, our estimates can tolerate errors of inverse polynomial size. Then, at each step of gradient descent, we generate sufficiently (but still polynomially) many samples to ensure that, with high probability, the difference between the population gradient and the finite-sample gradient is sufficiently small. Since it only takes polynomially many iterations to finish the process, the total number of samples needed is polynomial.
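The fresh-sample step is a concentration argument. An empirical gradient averaged over $n$ independent samples deviates from the population gradient by $O(1/\sqrt n)$. Hence polynomially many fresh samples per step suffice for any inverse-polynomial accuracy.

The minimal sketch below illustrates this on a hypothetical single-neuron radial model $g(r) = \sigma(ar + b)$ fitted to $f_*(r) = \sigma(1 - r)$, with $r = \|x\| \sim \mathrm{Uniform}[0, 2]$. The model, the radial law, the evaluation point, and all names are assumptions for illustration; they are not the paper's network or data distribution.

At $(a, b) = (-1/2, 1/2)$ only $r < 1$ is active, and the residual there is $(r - 1)/2$. The population gradient of $\frac12\mathbb{E}(g - f_*)^2$ is therefore exactly $(-1/24, -1/8)$.

```python
import math
import random


def relu(z: float) -> float:
    """ReLU activation sigma(z) = max(z, 0)."""
    return z if z > 0.0 else 0.0


# Evaluation point and its exact population gradient of 0.5 * E[(g(r) - f_star(r))^2]
# under the hypothetical law r ~ Uniform[0, 2] (computed by hand in the lead-in).
A0, B0 = -0.5, 0.5
POP_GRAD = (-1.0 / 24.0, -1.0 / 8.0)


def fresh_sample_gradient(a: float, b: float, n: int, rng: random.Random) -> tuple:
    """Empirical gradient over n fresh samples for g(r) = relu(a r + b) against f_star(r) = relu(1 - r)."""
    grad_a = grad_b = 0.0
    for _ in range(n):
        r = rng.uniform(0.0, 2.0)  # a fresh draw, never reused across calls
        pre = a * r + b
        if pre > 0.0:  # relu'(pre) is 1 on the active side and 0 otherwise
            resid = pre - relu(1.0 - r)
            grad_a += resid * r
            grad_b += resid
    return grad_a / n, grad_b / n


def mean_gradient_error(n: int, trials: int = 20, seed: int = 0) -> float:
    """Average Euclidean distance between the fresh-sample gradient at (A0, B0) and POP_GRAD."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(trials):
        ga, gb = fresh_sample_gradient(A0, B0, n, rng)
        total += math.hypot(ga - POP_GRAD[0], gb - POP_GRAD[1])
    return total / trials
```

With these settings, the mean error should drop by roughly a factor of ten when `n` grows from 100 to 10,000, consistent with the $1/\sqrt n$ rate. Halving the gradient error costs about four times as many fresh samples per step. Combined with a polynomial number of steps, this keeps the total sample count polynomial.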