# Gradient-Based Feature Learning under Structured Data

Alireza Mousavi-Hosseini¹, Denny Wu², Taiji Suzuki³, Murat A. Erdogdu¹

¹University of Toronto and Vector Institute, ²New York University and Flatiron Institute, ³University of Tokyo and RIKEN AIP

{mousavi,erdogdu}@cs.toronto.edu, dennywu@nyu.edu, taiji@mist.i.u-tokyo.ac.jp

Recent works have demonstrated that the sample complexity of gradient-based learning of single index models, i.e. functions that depend on a 1-dimensional projection of the input data, is governed by their information exponent. However, these results are only concerned with isotropic data, while in practice the input often contains additional structure which can implicitly guide the algorithm. In this work, we investigate the effect of a spiked covariance structure and reveal several interesting phenomena. First, we show that in the anisotropic setting, the commonly used spherical gradient dynamics may fail to recover the true direction, even when the spike is perfectly aligned with the target direction. Next, we show that appropriate weight normalization that is reminiscent of batch normalization can alleviate this issue. Further, by exploiting the alignment between the (spiked) input covariance and the target, we obtain improved sample complexity compared to the isotropic case. In particular, under the spiked model with a suitably large spike, the sample complexity of gradient-based training can be made independent of the information exponent while also outperforming lower bounds for rotationally invariant kernel methods.

## 1 Introduction

A fundamental feature of neural networks is their adaptivity to learn unknown statistical models.
For instance, when the learning problem exhibits certain low-dimensional structure or sparsity, it is expected that neural networks optimized by gradient-based algorithms can efficiently adapt to such structure via feature/representation learning. A considerable amount of research has been dedicated to understanding this phenomenon under various assumptions and to demonstrating the superiority of neural networks over non-adaptive methods such as kernel models [GMMM19, WLLM19, BES+19, LMZ20, AAM22, BES+22, DLS22, Tel23, MHPG+23]. A particularly relevant problem setting for feature learning is the estimation of single index models, where the response $y \in \mathbb{R}$ depends on the input $x \in \mathbb{R}^d$ via $y = g(\langle u, x\rangle) + \epsilon$, where $g : \mathbb{R} \to \mathbb{R}$ is the nonlinear link function and $u$ is the unit target direction. Here, learning corresponds to recovering the unknowns $u$ and $g$, which requires the model to extract and adapt to the low-dimensional target direction. Recent works have shown that the sample complexity is determined by certain properties of the link function $g$. In particular, the complexity of gradient-based optimization is captured by the information exponent of $g$ introduced by [BAGJ21]. Intuitively, a larger information exponent $s$ corresponds to a more complex $g$ (for gradient-based learning), and it has been proven that when the input is isotropic, $x \sim \mathcal{N}(0, I_d)$, gradient flow can learn the single index model with $O(d^s)$ sample complexity [BBSS22].

37th Conference on Neural Information Processing Systems (NeurIPS 2023).

In practice, however, real data always exhibits certain structures such as low intrinsic dimensionality, and isotropic data assumptions fail to capture this fact. In statistical methodology, it is known that the directions along which the input $x$ has high variance are often good predictors of the target $y$ [HTFF09]; indeed, this is the main reason principal component analysis is used in pretraining [JWHT13].
A fundamental model that captures such a structure is the spiked matrix model, in which $x \sim \mathcal{N}(0, I_d + \kappa\theta\theta^\top)$ for some unit direction $\theta \in \mathbb{R}^d$ and $\kappa > 0$ [Joh01]. Along the direction $\theta$, the data has higher variability and predictive power. In single index models, such predictive power translates to a non-trivial alignment between the vectors $u$ and $\theta$; our focus is to investigate the effect of such alignment on the sample complexity of gradient-based training.

### 1.1 Contributions: learning single index models under spiked covariance

[Figure 1 (phase diagram): horizontal axis $r_1$ (spike-target alignment), vertical axis $r_2$ (spike magnitude); the lines $r_2 = r_1$ and $r_2 = 2r_1$ separate a Hard region ($n \asymp d^s$), an Intermediate region ($n \asymp d^{1+(s-1)(1-2(r_2-r_1))}$) and an Easy region ($n \asymp d^{1+(s-1)(1-r_2)}$); the color scale runs from high to low sample complexity.]

Figure 1: Sample complexity to learn $u$ and $g$ under the spiked model. Smaller $r_1$ denotes a better spike-target alignment, while larger $r_2$ denotes a larger spike magnitude. The sample complexities are based on Corollary 8.

In this paper, we study the sample complexity of learning a single index model using a two-layer neural network and show that it is determined by an interplay between the spike-target alignment, $\langle u, \theta\rangle \asymp d^{-r_1}$ with $r_1 \in [0, 1/2]$, and the spike magnitude, $\kappa \asymp d^{r_2}$ with $r_2 \in [0, 1]$. Our contributions can be summarized as follows.

1. We show that even in the case of perfect spike-target alignment ($r_1 = 0$), the spherical gradient flow commonly employed in recent literature (see e.g. [BAGJ21, BBSS22]) cannot recover the target direction for moderate spike magnitudes in the population limit. The failure of this covariance-agnostic procedure under anisotropic structure suggests the necessity of an appropriate covariance-aware normalization to effectively learn the single index model.
2. We show that a covariance-aware normalization that resembles batch normalization resolves this issue. Indeed, the resulting gradient flow can successfully recover the target direction $u$ in this case, and depending on the amount of spike-target alignment, the sample complexity can significantly improve compared to the isotropic case.
3. Under the spiked covariance model, we prove a three-stage phase transition for the sample complexity depending on the quantities $r_1$ and $r_2$. For a suitable direction and magnitude of the spike, the sample complexity can be made $O(d^{3+\nu})$ for any $\nu > 0$, which is independent of the information exponent $s$. This should be compared against the known complexity of $O(d^s)$ under isotropic data.
4. We finally show that preconditioning the training dynamics with the inverse covariance improves the sample complexity. This is particularly significant for the spiked covariance model, where $O(d^{3+\nu})$ samples can be reduced to $O(d^{1+\nu})$ for any $\nu > 0$, i.e. almost linear in $d$. The three-stage phase transition also emerges, as illustrated in Figure 1: in the hard regime, the complexity remains $O(d^s)$ regardless of the magnitude and direction of the spike, while in the easy regime the complexity only depends on the spike magnitude and not its direction. The intermediate regime interpolates between these two; smaller $r_1$ and larger $r_2$ improve the sample complexity.

The rest of the paper is organized as follows. We discuss the notation and the related work in the remainder of this section. We provide preliminaries on the statistical model and the training procedure in Section 2, and provide a negative result on the covariance-agnostic gradient flow in Section 2.1. Our main sample complexity result on a single neuron is presented in Section 3.2. We provide our results on multi-neuron neural networks in Section 4 and also discuss extensions such as preconditioning and its implications. We provide a technical summary in Section 5 and conclude in Section 6.

Notation.
We use $\langle\cdot,\cdot\rangle$ and $\|\cdot\|$ to denote the Euclidean inner product and norm. For matrices, $\|\cdot\|$ denotes the usual operator norm, and $\lambda_{\max}(\cdot)$ and $\lambda_{\min}(\cdot)$ denote the largest and smallest eigenvalues, respectively. We reserve $\gamma$ for the standard Gaussian distribution on $\mathbb{R}$, and let $\|\cdot\|_\gamma$ denote the $L^2(\gamma)$ norm. $\mathbb{S}^{d-1}$ is the unit sphere in $\mathbb{R}^d$. For quantities $a$ and $b$, we use $a \lesssim b$ to convey that there exists a constant $C$ (a universal constant unless stated otherwise, in which case it may depend on polylogarithmic factors of $d$) such that $a \le Cb$, and $a \asymp b$ signifies that $a \lesssim b$ and $b \lesssim a$.

### 1.2 Further related work

Non-Linear Feature Learning with Neural Networks. Recently, two popular scaling regimes of neural networks have emerged for theoretical studies. A large initialization variance leads to the lazy training regime, where the weights do not move significantly and the training dynamics is captured by the neural tangent kernel (NTK) [JGH18, COB19]. However, there are many instances of function classes that are efficiently learnable by neural networks and not efficiently learnable by the NTK [YS19, GMMM19]. Under a smaller initialization scaling, gradient descent on infinite-width neural networks becomes equivalent to Wasserstein gradient flow on the space of measures, known as the mean-field limit [CB18, RVE18, MMN18, MMM19, NWS22, Chi22], which can learn certain low-dimensional target functions efficiently [WLLM19, AAM22, HC22, ASKL23]. As for neural networks with smaller width, recent works showed that a two-stage feature learning procedure can outperform the NTK when the data is sampled uniformly from the hypercube [BEG+22] or from an isotropic Gaussian [DLS22, BES+22, BBSS22, MHPG+23, ABAM23]. However, these results do not take into account the additional structure that might be present in the covariance matrix of the input data. Two notable exceptions are [GMMM20, RGKZ21], where the authors analyzed a spiked covariance and Gaussian mixture data, respectively.
Our setting is closer to [GMMM20]; however, they do not provide optimization guarantees for gradient-based training. Furthermore, in a companion work [BES+23], we zoom into the setting where the spike and target are perfectly aligned ($r_1 = 0$), and prove learnability in the $n \asymp d$ regime for both kernel regression and two-layer neural networks. Finally, we go over some results concurrent to our work in Appendix A.

Learning Single Index Models. The problem of estimating the relevant direction in a single index model is classical in statistics [LD89], with efficient dedicated algorithms ([KKSK11, CM20] among others). However, these algorithms are non-standard; instead, we are concerned with standard iterative algorithms like training neural networks with gradient descent. Recently, [DH18] considered an iterative optimization procedure for learning such models with a polynomial sample complexity that is controlled by the smoothness of the link function. [BES+22] considered the effect of taking a single gradient step on the ability of a two-layer neural network to learn a single index model, and [BBSS22, MHPG+23] considered training a special two-layer neural network architecture, in which all neurons share the same weight, with gradient flow or online SGD. However, these works only consider isotropic Gaussian input, and the effect of anisotropy in the covariance matrix when training a neural network to learn a single index model has remained unclear.

Training a Single Neuron with Gradient Descent. When training the first layer, we consider a setting where there is only one effective neuron. A large body of work exists on training a single neuron using variants of gradient descent. In the realizable setting (i.e. identical link and activation), the typical assumptions on the activation correspond to information exponent 1, as the activations are required to be monotone or have similar properties; see e.g. [Sol17, YO20, DKTZ22].
In the agnostic setting, [FCG20] considered initializing from the origin, which is a saddle point for information exponent larger than 1. [ATV22] also considered the agnostic learning of a ReLU activation, albeit their sample complexity is not explicit other than being polynomial in dimension.

## 2 Preliminaries: Statistical Model and Training Procedure

For a $d$-dimensional input $x$ and a link function $g \in L^2(\gamma)$, consider the single index model

$$y = g\left(\frac{\langle u, x\rangle}{\|\Sigma^{1/2}u\|}\right) + \epsilon \quad\text{with}\quad x \sim \mathcal{N}(0, \Sigma), \tag{2.1}$$

where $\epsilon$ is a zero-mean noise with $O(1)$ sub-Gaussian norm and $u \in \mathbb{S}^{d-1}$. Learning the model (2.1) corresponds to approximately recovering the unknown link $g$ and the unknown direction $u$. Note that a normalization is needed to make this problem well-defined; without loss of generality, we write $\langle u, x\rangle/\|\Sigma^{1/2}u\|$ to ensure that the input variance and the scaling of $g$ both remain independent of the conditioning of $\Sigma$. For this learning task, we will use a two-layer neural network of the form

$$\hat y(x; W, a, b) := \sum_{i=1}^m a_i\,\phi(\langle w_i, x\rangle + b_i), \tag{2.2}$$

where $W = \{w_i\}_{i=1}^m$ is the $m \times d$ matrix whose rows correspond to the first-layer weights $w_i$, $a = \{a_i\}_{i=1}^m$ denotes the second-layer weights, $b = \{b_i\}_{i=1}^m$ denotes the biases, and $\phi$ is the nonlinear activation function. We assume $g$ and $\phi$ are weakly differentiable with weak derivatives $g'$ and $\phi'$, respectively, and $g, g', \phi, \phi' \in L^2(\gamma)$. We are interested in the high-dimensional regime; thus, $d$ is assumed to be sufficiently large throughout the paper. Our ultimate goal is to learn both unknowns $g$ and $u$ by minimizing the population risk

$$R(W, a, b) := \frac{1}{2}\mathbb{E}\big[(\hat y(x; W, a, b) - y)^2\big], \tag{2.3}$$

using a gradient-based training method such as gradient flow. We follow the two-step training procedure employed in recent works [BES+22, MHPG+23, BBSS22, DLS22]: First, we train the first-layer weights $W$ to learn the unknown direction $u$; at the end of this stage, the neurons $w_i$ align with $u$. Here, the goal is to recover only the direction.
Next, using random biases and training the second-layer weights, we obtain a good approximation for the unknown link function $g$. In the majority of this work, we focus on the first part of this two-stage procedure, as the alignment between the $w_i$'s and $u$ essentially determines the sample complexity of the overall procedure. This problem is somewhat equivalent to the simplified problem of minimizing (2.3) with $m = 1$, $a_1 = 1$, $b_1 = 0$, i.e., $\hat y(x; W, a, b)$ is replaced with $\hat y(x; w) := \phi(\langle w, x\rangle)$, and we write $R(w) := R(W, a, b)$ for simplicity. We emphasize that unless $\phi = g$ (i.e. the link function is known), the first stage of training only recovers the relevant direction $u$ and is not able to approximate $g$. Indeed, $m > 1$ is often needed to learn the nonlinear link function; this is the focus of Section 4.2, where we derive a complete learnability result for a two-layer neural network with $m > 1$.

Characteristics of the link function play an important role in the complexity of learning the model. As such, a central part of our analysis relies on a particular property based on the Hermite expansion of functions in the basis of normalized Hermite polynomials $\{h_j\}_{j \ge 0}$, given as

$$h_j(z) = \frac{(-1)^j}{\sqrt{j!}}\,e^{z^2/2}\,\frac{d^j}{dz^j}e^{-z^2/2}. \tag{2.4}$$

These polynomials form an orthonormal basis of the space $L^2(\gamma)$, and the resulting expansion yields the following measure of complexity for $g$, termed the information exponent.

Definition 1 (Information exponent). Let $g = \sum_{j \ge 0}\alpha_j h_j$ be the Hermite expansion of $g$. The information exponent of $g$ is defined to be $s := \inf\{j > 0 : \alpha_j \neq 0\}$.

This concept was introduced in [BAGJ21] in a more general framework, and our definition is more in line with the setting in [BBSS22]. We remark that the definition of [BAGJ21] can be modified to handle anisotropy, in which case one arrives at Definition 1. We provide a detailed discussion of this concept together with some properties of the Hermite expansion in Appendix B.
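To make Definition 1 concrete, the sketch below estimates the normalized Hermite coefficients $\alpha_j = \mathbb{E}_{z\sim\gamma}[g(z)h_j(z)]$ with Gauss–Hermite quadrature and reads off the information exponent. It assumes NumPy; the helper names `hermite_coeffs` and `information_exponent`, the number of quadrature nodes, and the tolerance are illustrative choices, not part of the paper. The quadrature is exact for polynomial links of moderate degree; for non-smooth functions such as ReLU the coefficients are only approximate.

```python
import math

import numpy as np
from numpy.polynomial import hermite_e as He


def hermite_coeffs(g, max_degree=10, n_nodes=60):
    """Approximate alpha_j = E[g(z) h_j(z)], z ~ N(0, 1), for j = 0..max_degree.

    h_j = He_j / sqrt(j!) are the normalized (probabilists') Hermite
    polynomials of (2.4), which are orthonormal in L^2(gamma).
    """
    # Nodes/weights for the weight function exp(-z^2 / 2); the weights sum to sqrt(2*pi).
    nodes, weights = He.hermegauss(n_nodes)
    weights = weights / math.sqrt(2.0 * math.pi)  # turn the weighted sum into a Gaussian expectation
    g_vals = g(nodes)
    coeffs = []
    for j in range(max_degree + 1):
        basis = np.zeros(j + 1)
        basis[j] = 1.0  # coefficient vector selecting He_j
        h_j = He.hermeval(nodes, basis) / math.sqrt(math.factorial(j))
        coeffs.append(float(np.sum(weights * g_vals * h_j)))
    return np.array(coeffs)


def information_exponent(g, max_degree=10, tol=1e-8):
    """Smallest j > 0 with alpha_j != 0 (Definition 1), up to a numerical tolerance."""
    alpha = hermite_coeffs(g, max_degree)
    for j in range(1, max_degree + 1):
        if abs(alpha[j]) > tol:
            return j
    return None  # no non-zero coefficient found up to max_degree


# h_3(z) = (z^3 - 3z) / sqrt(6) has information exponent 3.
print(information_exponent(lambda z: (z**3 - 3 * z) / math.sqrt(6)))
# z^2 = h_0 + sqrt(2) h_2 has information exponent 2 (alpha_0 does not count).
print(information_exponent(lambda z: z**2))
```

The second example shows why the infimum in Definition 1 runs over $j > 0$: the constant part $\alpha_0$ does not contribute to the gradient signal.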
Throughout the paper, we assume that the information exponent does not grow with dimension. In the case where the $d$-dimensional input data is isotropic, [BBSS22] showed that learning a single index target with full-batch gradient flow requires a sample complexity of $O(d^s)$ for $s \ge 3$, where $s$ is the information exponent of $g$. We will show that this sample complexity can be improved under anisotropy. More specifically, if the input covariance $\Sigma$ has non-trivial alignment with the unknown direction $u$, we prove in Section 3 that the resulting sample complexity can even be made independent of the information exponent if we use a certain normalization in training. In what follows, we show that such a normalization in the training procedure is indeed necessary.

### 2.1 The spiked model and limitations of covariance-agnostic training

In practice, data often exhibit a certain structure which may have a profound impact on the statistical procedure. A well-known model that captures such a structure is the spiked model [Joh01], for which one or several large eigenvalues of the input covariance matrix $\Sigma$ are separated from the bulk of the spectrum (see also [BBAP05, BS06]). Although our results hold for generic covariance matrices, they reveal interesting phenomena under the following spiked model assumption.

Assumption 1. The covariance $\Sigma$ follows the $(\kappa, \theta)$-spiked model if $\Sigma = \frac{I_d + \kappa\theta\theta^\top}{1+\kappa}$, where $\|\theta\| = 1$.

In pursuit of the target (unit) direction $u$, the magnitude of the neuron $w$ is immaterial; thus, recent works take advantage of this and simplify the optimization trajectory by projecting $w$ onto the unit sphere $\mathbb{S}^{d-1}$ throughout the training process [BAGJ21, BBSS22]. In the sequel, we study the same dynamics, which is agnostic to the input covariance, in order to motivate our investigation of the normalized gradient flow in Section 3. More specifically, we consider the spherical population gradient flow

$$\frac{dw_t}{dt} = -\nabla_S R(w_t), \quad\text{where}\quad \nabla_S R(w) := \nabla R(w) - \langle\nabla R(w), w\rangle\, w \tag{2.5}$$

is the spherical gradient at the current iterate. It is straightforward to see that when the initialization $w_0$ is on the unit sphere, the entire flow remains on the unit sphere, i.e. $w_t \in \mathbb{S}^{d-1}$ for all $t \ge 0$. The flow (2.5) has been proven useful for learning the direction $u$ [BBSS22] in the isotropic case $\Sigma = I_d$ when the activation $\phi$ is ReLU. In contrast, when $\Sigma$ follows a spiked model, we show that it can get stuck at stationary points that are almost orthogonal to $u$. Indeed, when the input covariance $\Sigma$ has a spike in the target direction $u$, i.e. $\theta = u$, one expects the training procedure to benefit from this, as the input $x$ contains information about the sought unknown $u$ without even querying the response $y$. The following result proves the contrary: for moderate spike magnitudes, the alignment $\langle w_t, u\rangle$ between the first-layer weights and the target remains insignificant for all $t$.

Theorem 2. Let $s > 2$ be the information exponent of $g$ with $\mathbb{E}[g] = 0$, and assume $\Sigma$ follows the $(\kappa, u)$-spiked model with $\Omega(1) \le \kappa \le O(d^{\frac{s-2}{s-1}})$. For ReLU activation, let $w_t$ denote the solution to (2.5) initialized uniformly at random over $\mathbb{S}^{d-1}$. Then, with probability at least 0.99,

$$\sup_{t \ge 0}|\langle w_t, u\rangle| = O(1/\sqrt{d}). \tag{2.6}$$

A non-trivial alignment between the first-layer weights $w_t$ and the target direction $u$ is required to learn the single index model (2.1). However, the above result implies that in high dimensions, when $d \gg 1$, the alignment is negligible in the population limit (when the number of samples goes to infinity). We remark that when the spike magnitude is large, i.e. $\kappa \ge \Omega(d)$, the flow (2.5) can achieve alignment, as the problem essentially becomes one-dimensional; we demonstrate this in Appendix C. To see why the flow (2.5) gets stuck at saddle points and fails to recover the true direction, notice that

$$R(w) = \frac{1}{2}\mathbb{E}\big[(\phi(\langle w, x\rangle) - y)^2\big] = \frac{1}{2}\mathbb{E}\big[\phi(\langle w, x\rangle)^2\big] - \mathbb{E}[\phi(\langle w, x\rangle)y] + \frac{1}{2}\mathbb{E}[y^2]. \tag{2.7}$$

If the input were isotropic, i.e. $x \sim \mathcal{N}(0, I_d)$, the expectation $\mathbb{E}[\phi(\langle w, x\rangle)^2]$ in the first term of (2.7) would equal $\|\phi\|_\gamma^2$ for every $w \in \mathbb{S}^{d-1}$, which is independent of $w$.
Thus, minimizing $R(w)$ in this case is equivalent to maximizing the correlation term $\mathbb{E}[\phi(\langle w, x\rangle)y]$. However, under the spiked model, the alignment between $w$ and $u$ breaks the symmetry; consequently, the first term in the decomposition grows with $\langle w, u\rangle$, creating a repulsive force that traps the dynamics around the equator, where $w$ is almost orthogonal to $u$.

## 3 Main Results: Alignment via Normalized Dynamics

Having established that the covariance-agnostic training dynamics (2.5) is likely to fail, in this section we consider a covariance-aware normalized flow and show that it can achieve alignment with the unknown target and enjoy better sample complexity compared to the existing results [BAGJ21, BBSS22] in the isotropic case. We start with the population dynamics.

### 3.1 Warm-up: Population dynamics

To simplify the exposition, we define $z := \Sigma^{-1/2}x$ and $\bar w := \Sigma^{1/2}w/\|\Sigma^{1/2}w\|$, similarly define $\bar u := \Sigma^{1/2}u/\|\Sigma^{1/2}u\|$, and consider the prediction function $\hat y(x; w) := \phi(\langle\bar w, z\rangle)$. Since $z \sim \mathcal{N}(0, I_d)$ and $\|\bar w\| = 1$, the second moment of the prediction is $\mathbb{E}[\hat y(x; w)^2] = \|\phi\|_\gamma^2$, which is independent of $w$; thus, the population risk reads

$$R(w) = \frac{1}{2}\mathbb{E}\big[(\hat y(x; w) - y)^2\big] = \frac{1}{2}\|\phi\|_\gamma^2 + \frac{1}{2}\mathbb{E}[y^2] - \mathbb{E}[\phi(\langle\bar w, z\rangle)y]. \tag{3.1}$$

In (3.1), the only term that depends on the weights $w$ is the correlation term, and the source of the repulsive force in (2.7) is eliminated; we have $\nabla_w R(w) = -\nabla_w\mathbb{E}[\phi(\langle\bar w, z\rangle)y]$. Based on this, we use the following normalized gradient flow for training:

$$\frac{dw_t}{dt} = -\eta(w_t)\nabla_w R(w_t), \quad\text{where}\quad \eta(w) = \|\Sigma^{1/2}w\|^2. \tag{3.2}$$

We remark that, though not identical, this normalization is closely related to batch normalization, which is commonly employed in practice [IS15]: since $\langle\bar w, z\rangle = \langle w, x\rangle/\|\Sigma^{1/2}w\|$, the pre-activation is divided by its standard deviation. Under the invariance provided by the current normalization, minimizing $R(w)$ corresponds to maximizing $\mathbb{E}[\phi(\langle\bar w, z\rangle)y]$. Thus, instead of $w$, it will be more useful to track the dynamics of its normalized counterpart $\bar w$, which is made possible by the following intermediary result that follows from Stein's lemma; see also e.g. [EDB16, MHPG+23].

Lemma 3. Suppose we train $w_t$ using the gradient flow (3.2).
Then $\bar w_t$ solves the following ODE:

$$\frac{d\bar w_t}{dt} = \zeta_{\phi,g}(\langle\bar w_t, \bar u\rangle)\,(I_d - \bar w_t\bar w_t^\top)\,\Sigma\,(I_d - \bar w_t\bar w_t^\top)\,\bar u, \tag{3.3}$$

where $\zeta_{\phi,g}(\langle\bar w, \bar u\rangle) := \mathbb{E}[\phi'(\langle\bar w, z\rangle)g'(\langle\bar u, z\rangle)]$.

We will investigate whether the modified flow (3.3) achieves alignment; in this context, alignment corresponds to $\langle\bar w_t, \bar u\rangle \approx 1$. Towards that end, we make the following assumption.

Assumption 2. Let $g = \sum_{j\ge0}\alpha_j h_j$ and $\phi = \sum_{j\ge0}\beta_j h_j$ be the Hermite decompositions of $g$ and $\phi$, respectively. Let $s$ be the information exponent of $g$. For some universal constant $c > 0$, we assume $\zeta_{\phi,g}(\omega) = \sum_{j>0} j\alpha_j\beta_j\,\omega^{j-1} \ge c\,\omega^{s-1}$ for all $\omega \in (0, 1)$.

There are several important examples that readily satisfy Assumption 2. The obvious example is when the link function is known, as in [BAGJ21], i.e. $\phi = g$. A more interesting example is when $\phi$ is an activation with a non-zero degree-$s$ Hermite coefficient (e.g. ReLU when $s$ is even, see [GKK19, Claim 1]) and $g$ is a degree-$s$ Hermite polynomial, which for $s = 2$ corresponds to the phase retrieval problem. In this case, the assumption is satisfied if $\alpha_s$ and $\beta_s$ have the same sign, which occurs with probability 0.5 if we randomly choose the sign of the second layer. Under this condition, the following result shows that the population flow (3.3) can achieve alignment.

Proposition 4. Suppose Assumption 2 holds and consider the gradient flow given by (3.3) with initialization satisfying $\langle\bar w_0, \bar u\rangle > 0$. Then, we have $\langle\bar w_T, \bar u\rangle \ge 1 - \varepsilon$ as soon as

$$T \gtrsim \frac{\tau_s(\langle\bar w_0, \bar u\rangle) + \ln(1/\varepsilon)}{\lambda_{\min}(\Sigma)}, \quad\text{where}\quad \tau_s(z) := \begin{cases} 1 & s = 1,\\ \ln(1/z) & s = 2,\\ (1/z)^{s-2} & s > 2. \end{cases} \tag{3.4}$$

We remark that the information exponent enters the rate in (3.4) through the function $\tau_s$, and the time needed to achieve $\varepsilon$-alignment gets worse with larger information exponent. Indeed, it is understood that this quantity serves as a measure of complexity for the target function being learned.

### 3.2 Empirical dynamics and sample complexity

Given $n$ i.i.d.
samples $\{(x^{(i)}, y^{(i)})\}_{i=1}^n$ from the single index model (2.1), we consider the flow

$$\frac{dw_t}{dt} = -\eta(w_t)\nabla\hat R(w_t) \quad\text{with}\quad \nabla\hat R(w) := -\nabla_w\frac{1}{n}\sum_{i=1}^n\phi\left(\frac{\langle w, x^{(i)}\rangle}{\|\hat\Sigma^{1/2}w\|}\right)y^{(i)}, \tag{3.5}$$

where we estimate the covariance matrix $\Sigma$ using the sample mean $\hat\Sigma := \frac{1}{n'}\sum_{i=1}^{n'}x^{(i)}x^{(i)\top}$ over $n'$ i.i.d. samples; the above dynamics defines an empirical gradient flow. Notice that we ignored the gradient associated with the term $\phi^2$, since the population dynamics ensures that its gradient concentrates around zero; thus, it is redundant to estimate this term. Below, we use $n' = n$ for smooth activations, i.e. the same dataset can be used for covariance estimation; for ReLU, we require a more accurate covariance estimator, and thus use $n' \asymp n^2$ by assuming access to an additional $n' - n$ unlabeled data points. Similar to the previous section, we track the dynamics of the normalized $w$ by defining $\bar w := \hat\Sigma^{1/2}w/\|\hat\Sigma^{1/2}w\|$ (and leave $\bar u$ unchanged from Section 3.1). The same arguments as in Lemma 3 allow us to track the evolution of $\bar w$, which ultimately yields the following alignment result under general covariance structure.

Theorem 5. Let $s$ be the information exponent of $g$, and assume it satisfies $|g(\cdot)| \lesssim 1 + |\cdot|^p$ for some $p > 0$. For $\phi$ denoting either the ReLU activation or a smooth activation satisfying $|\phi'|, |\phi''| \lesssim 1$, suppose Assumption 2 holds. For any $\varepsilon > 0$, suppose we run the finite sample gradient flow (3.5) with $\eta(w) = \|\hat\Sigma^{1/2}w\|^2$, initialized such that $\langle\bar w_0, \bar u\rangle > 0$, and with number of samples

$$n \gtrsim d\,\kappa(\Sigma)^2\left(\langle\bar w_0, \bar u\rangle^{2(1-s)} \vee \varepsilon^{-2}\right),$$

where $\kappa(\Sigma)$ is the condition number of $\Sigma$. Then, for $T \gtrsim \frac{\tau_s(\langle\bar w_0, \bar u\rangle) + \ln(1/\varepsilon)}{\lambda_{\min}(\Sigma)}$, we have

$$\langle\bar w_T, \bar u\rangle \ge 1 - \varepsilon, \tag{3.6}$$

with probability at least $1 - c_1 d^{-c_2}$ for some universal constants $c_1, c_2 > 0$ over the randomness of the dataset. Here, $\tau_s$ is defined in (3.4) and $\gtrsim$ hides poly-logarithmic factors.

Remark. We make the following remarks on the above theorem. The initial condition $\langle\bar w_0, \bar u\rangle > 0$ is required when we have odd information exponent.
When $w_0$ is initialized uniformly over $\mathbb{S}^{d-1}$, the condition holds with probability 0.5 over the initialization. See [BAGJ21, Remark 1.8] for further discussion on this condition. Although $\bar w$ is defined using the empirical covariance, unlike $\bar u$, which is defined by the population covariance, this definition is the suitable choice to approximate the target function $g$ (cf. Theorem 9), since it ensures the arguments of $\phi$ and $g$ are sufficiently close when $\bar w$ recovers $\bar u$.

The intuition behind the proof of Theorem 5 is presented in Section 5, with the complete proof in the appendix. We highlight that the improvement in the sample complexity compared to the isotropic setting occurs whenever the covariance structure induces a stronger initial alignment $\langle\bar w_0, \bar u\rangle$ and consequently a stronger signal. The following corollary demonstrates a concrete example of such improvement by specializing Theorem 5 to a spiked covariance model.

Corollary 6. Consider the setting of Theorem 5 with $\Sigma$ following the $(\kappa, \theta)$-spiked model, where $\langle u, \theta\rangle \asymp d^{-r_1}$ and $\kappa \asymp d^{r_2}$ with $r_1 \in [0, 1/2]$ and $r_2 \in [0, 1]$. Suppose $w_0$ is sampled uniformly from $\mathbb{S}^{d-1}$. Then, when conditioned on $\langle\bar w_0, \bar u\rangle > 0$, the sample complexity in Theorem 5 reads

$$n \gtrsim \begin{cases} d^{1+2r_2}\left(d^{s-1} \vee \varepsilon^{-2}\right) & 0 < r_2 < r_1,\\ d^{1+2r_2}\left(d^{(s-1)(1-2(r_2-r_1))} \vee \varepsilon^{-2}\right) & r_1 < r_2 < 2r_1,\\ d^{1+2r_2}\left(d^{(s-1)(1-r_2)} \vee \varepsilon^{-2}\right) & 2r_1 < r_2 < 1, \end{cases} \tag{3.7}$$

where $\gtrsim$ hides poly-logarithmic factors of $d$.

Remark. We have the following observations on the above sample complexity. Corollary 6 demonstrates that structured data can lead to better sample complexity when the right normalization is used during training. This complements Theorem 2, where we recall that the spherical training dynamics ignores the structure in the data and the target direction cannot be recovered. When $g$ is a polynomial of degree $p$, the lower bound for rotationally invariant kernels (including the neural tangent kernel at initialization) implies a complexity of at least $d^{\Omega((1-r_2)p)}$ [DWY21].
Thus, the sample complexity of Corollary 6 can always outperform the kernel lower bound when $p$ is sufficiently large and $s$ remains constant. Three-stage phase transition. Recall that in the isotropic setting $\Sigma = I_d$, the sample complexity of learning $g$ with information exponent $s$ using full-batch gradient flow is $O(d^s)$ for $s \ge 3$ [BBSS22]. The sample complexity in Corollary 6 is strictly smaller than $O(d^s)$ as soon as $(s-1)r_1/(s-2) < r_2$. Furthermore, for any $\nu > 0$ it is at most $O(d^{3+\nu})$ as soon as $r_2 \ge 1 - \nu/(s-3)$ and $2r_1 < r_2$, in which case the sample complexity becomes independent of the information exponent. Interestingly, the complexity becomes independent of $r_1$ when $r_2 > 2r_1$ or $r_2 < r_1$, i.e. the direction of the spike becomes irrelevant when the spike magnitude is sufficiently large or small.

The three-stage phase transition of Corollary 6 is due to the different behaviour of the inner product $\langle\bar w_0, \bar u\rangle$ in different regimes of $r_1$ and $r_2$. When $r_2 < r_1$, we have $\langle\bar w_0, \bar u\rangle \asymp \langle w_0, u\rangle$; thus the initial alignment is just as uninformative as in the isotropic case, providing no improvement. Moreover, a potentially large condition number may hurt the sample complexity in this case. On the other hand, when $r_1 < r_2 < 2r_1$ we have $\langle\bar w_0, \bar u\rangle \asymp \kappa\langle u, \theta\rangle\langle w_0, \theta\rangle$, and $r_2 > 2r_1$ leads to $\langle\bar w_0, \bar u\rangle \asymp \sqrt{\kappa}\,\langle w_0, \theta\rangle$; thus, a large $\kappa$ or $\langle u, \theta\rangle$ in these regimes may improve the sample complexity.

## 4 Implications to Neural Networks and Further Improvements

### 4.1 Improving Sample Complexity via Preconditioning

We now demonstrate that preconditioning the training dynamics with $\hat\Sigma^{-1}$ can remove the dependency on $\kappa(\Sigma)$, ultimately improving the sample complexity. Consider the preconditioned gradient flow

$$\frac{dw_t}{dt} = -\eta(w_t)\,\hat\Sigma^{-1}\nabla\hat R(w_t) \quad\text{with}\quad \eta(w) = \|\hat\Sigma^{1/2}w\|^2. \tag{4.1}$$

We have the following alignment result.

Theorem 7. Consider the same setting as Theorem 5, and assume we run the preconditioned empirical gradient flow (4.1) with number of samples $n \gtrsim d\left(\langle\bar w_0, \bar u\rangle^{2(1-s)} \vee \varepsilon^{-2}\right)$, where $\gtrsim$ hides poly-logarithmic factors of $d$.
Then, for $T \gtrsim \tau_s(\langle\bar w_0, \bar u\rangle) + \ln(1/\varepsilon)$, we have $\langle\bar w_T, \bar u\rangle \ge 1 - \varepsilon$, with probability at least $1 - c_1 d^{-c_2}$ for some universal constants $c_1, c_2 > 0$.

Preconditioning removes the condition number dependence, which is particularly important in the spiked model case, where this quantity can be large.

Corollary 8. Consider the setting of Theorem 7, and assume we run the preconditioned empirical gradient flow (4.1) for the $(\kappa, \theta)$-spiked model, where $\langle u, \theta\rangle \asymp d^{-r_1}$ and $\kappa \asymp d^{r_2}$ with $r_1 \in [0, 1/2]$ and $r_2 \in [0, 1]$. Suppose $w_0$ is sampled uniformly from $\mathbb{S}^{d-1}$. Then, when conditioned on $\langle\bar w_0, \bar u\rangle > 0$, the sample complexity of Theorem 7 reads

$$n \gtrsim \begin{cases} d\left(d^{s-1} \vee \varepsilon^{-2}\right) & 0 < r_2 < r_1,\\ d\left(d^{(s-1)(1-2(r_2-r_1))} \vee \varepsilon^{-2}\right) & r_1 < r_2 < 2r_1,\\ d\left(d^{(s-1)(1-r_2)} \vee \varepsilon^{-2}\right) & 2r_1 < r_2 < 1, \end{cases} \tag{4.2}$$

where $\gtrsim$ hides poly-logarithmic factors of $d$.

The above result improves upon Corollary 6, thus making a case for preconditioning in practice. The complexity results also strictly improve upon the $O(d^s)$ complexity in the isotropic case [BBSS22] when $r_2 > r_1$. Further, for any $\nu > 0$, we can obtain the complexity of $O(d^{1+\nu})$ (nearly linear in dimension) when $r_2 > 1 - \nu/(s-1)$ and $r_2 > 2r_1$, or when $r_1 + \frac{1}{2}(1 - \nu/(s-1)) < r_2 < 2r_1$. In addition to the remarks on Corollary 6, we note that the complexity is independent of both $r_1$ and $r_2$ when $r_2 < r_1$ (cf. the hard regime in Figure 1), i.e. the spike magnitude and the spike-target alignment have no effect on the complexity unless $r_2 > r_1$.

Under the spiked covariance model, one could improve the above results by instead using spectral initialization, i.e. initializing at $\theta$, which can be estimated from unlabeled data. Assuming perfect access to $\theta$, the statements of Theorems 5 and 7 imply that this initialization yields a sample complexity of $O(d^{1+2r_2+((s-1)(2r_1-r_2) \vee 0)})$ without and $O(d^{1+((s-1)(2r_1-r_2) \vee 0)})$ with preconditioning.

### 4.2 Two-layer neural networks and learning the link function

Our main focus so far was learning the target direction $u$.
Next, we consider learning the unknown link function with a neural network, providing a complete learnability result for single index models. We use Algorithm 1 and train the first layer of the neural network with either the empirical gradient flow (3.5) or the preconditioned version (4.1). Then, we randomly choose the bias units and minimize over the second-layer weights using another gradient flow. Our goal is to track the sample complexity $n$ needed to learn the single index target, which we compare against the results of [BBSS22]. We highlight that layer-wise training as in Algorithm 1 is frequently employed in the literature [BES+22, BBSS22, DLS22, MHPG+23], and in particular [BBSS22] also used gradient flow for training.

Algorithm 1: Layer-wise training of a two-layer ReLU network with gradient flow (GF).

Input: $w_0 \in \mathbb{R}^d$; $T, T', \Delta, \lambda \in \mathbb{R}_+$; data $\{(x^{(i)}, y^{(i)})\}_{i=1}^n$.

1: Train the first-layer weights $W^T_j$ using the GF (3.5) or the preconditioned GF (4.1).

2: Normalize the weights $W^T_j := W^T_j/\|\hat\Sigma^{1/2}W^T_j\|$ for every $1 \le j \le m$.

3: Let $b_j \overset{\text{i.i.d.}}{\sim} \mathrm{Unif}(-\Delta, \Delta)$ and $a^0_j = 1/m$ for $1 \le j \le m$.

4: Train the second-layer weights $a^{T'}$ via the gradient flow $\frac{da^t}{dt} = -\nabla_a\Big(\frac{1}{n}\sum_{i=1}^n\big(\hat y(x^{(i)}; W^T, a^t, b) - y^{(i)}\big)^2 + \lambda\|a^t\|^2\Big)$ for $t \in [0, T']$.

5: return $(W^T, a^{T'}, b)$.

Theorem 9. Let $g$ be twice weakly differentiable with information exponent $s$, and assume $g$ has at most polynomial growth. Suppose $\phi$ is the ReLU activation, Assumption 2 holds, and we run Algorithm 1 with $w_0$ initialized uniformly over $\mathbb{S}^{d-1}$. For any $\varepsilon > 0$, let $n$ and $T$ be chosen according to Theorem 5 when we run the gradient flow (3.5), and according to Theorem 7 when we run the preconditioned gradient flow (4.1). Then, for $\Delta \asymp \sqrt{\ln(nd)}$, some regime of $\lambda$ given by (E.3), and sufficiently large $T'$ given by (E.4), we have

$$\mathbb{E}\big[(\hat y(x; W^T, a^{T'}, b) - y)^2\big] \le C_1\,\mathbb{E}[\epsilon^2] + C_2\,(\varepsilon + 1/m), \tag{4.3}$$

conditioned on $\langle\bar w_0, \bar u\rangle > 0$, with probability at least 0.99 over the randomness of the dataset, biases, and initialization, where $C_1$ is a universal constant and $C_2$ hides $\mathrm{polylog}(m, n, d)$ factors.
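The following is a minimal end-to-end sketch of Algorithm 1 on synthetic data, assuming NumPy. All concrete choices are illustrative assumptions rather than the paper's: a $(\kappa, \theta)$-spiked covariance with $\theta = u$, the link $g = h_2$ (information exponent $s = 2$; together with ReLU it satisfies Assumption 2, since both have a positive degree-2 Hermite coefficient), a forward-Euler discretization of the flows (3.5) and (4.1) with a fixed step, reuse of the labeled data to form $\hat\Sigma$ (Section 3.2 asks for extra unlabeled data for ReLU), and replacing the gradient flow of step 4 by its limit, the ridge solution of the strongly convex second-layer objective. The helper names `sample`, `first_layer`, `second_layer` and `alignment` are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic (kappa, theta)-spiked data with theta = u (Assumption 1); sizes are illustrative.
d, n, kappa, noise_std = 10, 20_000, 5.0, 0.1
u = np.zeros(d)
u[0] = 1.0
K = np.sqrt(1.0 + kappa)
Sigma = (np.eye(d) + kappa * np.outer(u, u)) / (1.0 + kappa)
Sigma_sqrt = (np.eye(d) + (K - 1.0) * np.outer(u, u)) / K  # symmetric square root of Sigma


def h2(t):
    # Normalized degree-2 Hermite polynomial, used as the link g (s = 2).
    return (t**2 - 1.0) / np.sqrt(2.0)


def relu(t):
    return np.maximum(t, 0.0)


def sample(num):
    X = rng.standard_normal((num, d)) @ Sigma_sqrt  # rows ~ N(0, Sigma)
    index = X @ u / np.sqrt(u @ Sigma @ u)  # <u, x> / ||Sigma^{1/2} u||, as in (2.1)
    return X, h2(index) + noise_std * rng.standard_normal(num)


def first_layer(X, y, Sigma_hat, w0, precondition=False, step=0.2, n_steps=300):
    """Euler steps of the normalized flow (3.5), or of (4.1) if precondition=True."""
    w = w0.copy()
    for _ in range(n_steps):
        s = X @ w
        N = np.sqrt(w @ Sigma_hat @ w)  # ||Sigma_hat^{1/2} w||
        coef = (s > 0) * y  # phi'(<w, x_i> / N) * y_i for ReLU
        # -eta(w) * grad R_hat(w), with eta(w) = N^2 and R_hat the empirical correlation loss.
        direction = (N * (coef @ X) - (coef @ s) * (Sigma_hat @ w) / N) / len(y)
        if precondition:
            direction = np.linalg.solve(Sigma_hat, direction)
        w = w + step * direction
    return w


def alignment(w):
    # <w_bar, u_bar>, computed with the population Sigma (known in this synthetic setup).
    return (w @ Sigma @ u) / np.sqrt((w @ Sigma @ w) * (u @ Sigma @ u))


def second_layer(X, y, w, Sigma_hat, m=100, lam=1e-3):
    """Steps 2-4 of Algorithm 1; returns the ridge solution that the GF of step 4 converges to."""
    w = w / np.sqrt(w @ Sigma_hat @ w)  # step 2: normalize the first-layer weight
    delta = np.sqrt(np.log(len(y) * d))
    b = rng.uniform(-delta, delta, size=m)  # step 3: random biases
    F = relu((X @ w)[:, None] + b)  # all m neurons share the direction w
    a = np.linalg.solve(F.T @ F / len(y) + lam * np.eye(m), F.T @ y / len(y))
    return w, a, b


X, y = sample(n)
Sigma_hat = X.T @ X / n  # same data reused for covariance estimation (a simplification)
w0 = rng.standard_normal(d)
w0 /= np.linalg.norm(w0)
if alignment(w0) < 0:  # condition on <w_bar_0, u_bar> > 0
    w0 = -w0

results = {}
for pre in (False, True):
    w_T = first_layer(X, y, Sigma_hat, w0, precondition=pre)
    w, a, b = second_layer(X, y, w_T, Sigma_hat)
    X_test, y_test = sample(5_000)
    test_mse = np.mean((relu((X_test @ w)[:, None] + b) @ a - y_test) ** 2)
    results[pre] = (alignment(w0), alignment(w_T), test_mse)
    print(f"precondition={pre}: alignment {results[pre][0]:.3f} -> {results[pre][1]:.3f}, "
          f"test MSE {test_mse:.3f}")
```

In this sketch the unpreconditioned run contracts slowly along the bulk eigendirections (eigenvalue $1/(1+\kappa)$), which is the $\lambda_{\min}(\Sigma)$ factor in Theorem 5; the preconditioned run does not pay this factor.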
The next result immediately follows from the previous theorem together with Corollaries 6 & 8. Corollary 10. In the setting of Theorem 9, if Σ follows the (κ, θ)-spiked model, the sample complexity n is given by (3.7) if we use the empirical gradient flow and (4.2) if we use the preconditioned version. We remark that for fixed ε, the sample complexity to learn g in the isotropic case is O(ds) [BBSS22]. Under the spiked model, if we assume that r2 is sufficiently large and r1 is sufficiently small as discussed in the previous section, Corollary 10 improves this rate to either (3.7) when the empirical gradient flow is used without preconditioning or to (4.2) with preconditioning. 5 Technical Overview In this section, we briefly discuss the key intuitions that lead to the proof of our main results. We first review the case Σ = Id, where we have the following decomposition for population loss 2 E (ϕ( w, x ) y)2 = 1 2 ϕ 2 γ + 1 2 E y2 E[ϕ( w, x )g( u, x )]. (5.1) Notice that the only term contributing to the population gradient is the last term which measures the correlation between ϕ and g. Following the gradient flow and applying Stein s lemma yields dt = E ϕ ( wt, x )g ( u, x ) (1 wt, u 2) = (1 wt, u 2) X j s jαjβj wt, u j 1, where the second identity follows from the Hermite expansion; see also [EDB16, EBD19]. Assume αsβs > 0 to ensure that the population dynamics will move towards u at least near initialization. When replacing the population gradient with a full-batch gradient, we need the estimation noise to be smaller than the signal existing in the gradient. When w0, u 1, this signal is roughly of the order w0, u s 1. As the uniform concentration error over Sd 1 scales with p d/n, we need n d w0, u 2(s 1) to ensure the signal remains dominant and wt moves towards u. When w0 is initialized uniformly over Sd 1 this translates to a sample complexity of n ds, which is indeed obtained by [BBSS22] via similar arguments. 
However, the behavior of the spherical dynamics entirely changes when we move to the anisotropic case. Suppose Σ follows a (κ, u)-spiked model and ϕ is Re LU. Using Lemma 12, it is easy to show that with the spherical gradient flow, the alignment obeys the following ODE wt, z )g ( u, z ) κ ψϕ,g(wt) 1 + κ wt, u ) (1 wt, u 2), where ψϕ,g(wt) is introduced in Lemma 12. The additional ψϕ,g(wt) term creates a repulsive force towards the equator wt, u = 0. The presence of this term is due to the fact that unlike (5.1), the term E ϕ( wt, u )2 is no longer independent of w and cannot be replaced by ϕ 2 γ. When w0 is initialized uniformly over Sd 1 and Ω(1) κ O(d), we have w0, u κ w0, u . Furthermore, at this initialization ψϕ,g(w0) 1/2. Therefore, sαsβs( κ w0, u )s 1 Hence the dynamics is trapped at | wt, u | = O(1/ d) for all t > 0 as long as κ = O(d1 1/(s 1)). To remove the repulsive force in the spherical dynamics, we can directly normalize the input of ϕ. As demonstrated by (3.1), once again the only term that varies with w would be the correlation loss. Specifically, using the result of Lemma 3, in the population limit we can track dt = E[ϕ ( w, z )g ( u, z )] ut , Σut , (5.2) where ut := u wt. Thus, the strength of the signal at initialization is of order w0, u s 1/κ(Σ), which after controlling the error in the estimate of ˆΣ and in the estimate of popula- tion gradient using finitely many samples, leads to the sample complexity n dκ(Σ)2 w0, u 2(1 s). Importantly, Σ can incude a much stronger initial alignment w0, u than the isotropic case w0, u , which is emphasized in Corollary 6. Using preconditioning will further remove the dependency on the condition number of Σ. 6 Conclusion We studied the dynamics of gradient flow to learn single index models when the input data covariance may contain additional structure. 
Under a spiked model for the covariance matrix, we showed that using spherical gradient flow, as an example of a covariance-agnostic training mechanism employed in the recent literature, is unable to learn the target direction of the single index model even when the spike and the target directions are identical. In contrast, we showed that an appropriate weight normalization removes this problem and successfully recovers the target direction. Moreover, depending on the alignment between the covariance structure and the target direction, the sample complexity can improve upon the isotropic setting, while also outperforming lower bounds for rotationally-invariant kernels. This phenomenon is due to the additional information about the target direction contained in the covariance matrix which improves the effective alignment at initialization. Additionally, we showed that a simple preconditioning of the gradient flow using the inverse empirical covariance can improve the sample complexity, achieving almost linear rate in certain settings. We outline a few limitations of our current work and discuss directions for future research. While studying single index models provides a pathway to a general understanding of feature learning with structured covariance, considering multi-index models can provide a more complete picture [PSE22], e.g. by establishing incremental learning dynamics [ABAM23]. We leave the problem of learning multi-index models under structured input as an interesting future direction. Gradient flow under squared loss can be seen as an example of a Correlational Statistical Query (CSQ) algorithm [BF02, Rey20], i.e. an algorithm that only accesses noisy estimates of expected correlation queries from the model. Understanding the limitations of learning single index models under a structured input through a CSQ lower bound perspective is another important direction that would complement our results in this paper. 
When training the first layer, we considered a somewhat unconventional initialization and relied on the symmetry it induces. It is interesting to consider cases where we train a network with multiple neurons starting from a more standard initialization which can help relax Assumption 2. Acknowledgments The authors thank Alberto Bietti and Zhichao Wang for discussions and feedback on the manuscript. TS was partially supported by JSPS KAKENHI (20H00576) and JST CREST (JPMJCR2015). MAE was partially supported by NSERC Grant [2019-06167], CIFAR AI Chairs program, CIFAR AI Catalyst grant. [AAM22] Emmanuel Abbe, Enric Boix Adsera, and Theodor Misiakiewicz, The merged-staircase property: a necessary and nearly sufficient condition for sgd learning of sparse functions on two-layer neural networks, Conference on Learning Theory, 2022. [ABAM23] Emmanuel Abbe, Enric Boix-Adsera, and Theodor Misiakiewicz, Sgd learning on neural networks: leap complexity and saddle-to-saddle dynamics, ar Xiv preprint ar Xiv:2302.11055 (2023). [ASKL23] Luca Arnaboldi, Ludovic Stephan, Florent Krzakala, and Bruno Loureiro, From highdimensional & mean-field dynamics to dimensionless odes: A unifying approach to sgd in two-layers networks, ar Xiv preprint ar Xiv:2302.05882 (2023). [ATV22] Pranjal Awasthi, Alex Tang, and Aravindan Vijayaraghavan, Agnostic learning of general relu activation using gradient descent, ar Xiv preprint ar Xiv:2208.02711 (2022). [BAGJ21] Gerard Ben Arous, Reza Gheissari, and Aukosh Jagannath, Online stochastic gradient descent on non-convex losses from high-dimensional inference., J. Mach. Learn. Res. 22 (2021), 106 1. [BBAP05] Jinho Baik, G erard Ben Arous, and Sandrine P ech e, Phase transition of the largest eigenvalue for nonnull complex sample covariance matrices. [BBSS22] Alberto Bietti, Joan Bruna, Clayton Sanford, and Min Jae Song, Learning single-index models with shallow neural networks, Advances in Neural Information Processing Systems, 2022. 
[BEG+22] Boaz Barak, Benjamin L Edelman, Surbhi Goel, Sham Kakade, Eran Malach, and Cyril Zhang, Hidden Progress in Deep Learning: SGD Learns Parities Near the Computational Limit, ar Xiv preprint ar Xiv:2207.08799 (2022). [BES+19] Jimmy Ba, Murat Erdogdu, Taiji Suzuki, Denny Wu, and Tianzong Zhang, Generalization of two-layer neural networks: An asymptotic viewpoint, International Conference on Learning Representations, 2019. [BES+22] Jimmy Ba, Murat A Erdogdu, Taiji Suzuki, Zhichao Wang, Denny Wu, and Greg Yang, High-dimensional Asymptotics of Feature Learning: How One Gradient Step Improves the Representation, ar Xiv preprint ar Xiv:2205.01445 (2022). [BES+23] Jimmy Ba, Murat A. Erdogdu, Taiji Suzuki, Zhichao Wang, and Denny Wu, Learning in the presence of low-dimensional structure: a spiked random matrix perspective, Thirty-seventh Conference on Neural Information Processing Systems (Neur IPS 2023), 2023. [BF02] Nader H Bshouty and Vitaly Feldman, On using extended statistical queries to avoid membership queries, Journal of Machine Learning Research 2 (2002), no. Feb, 359 395. [BMZ23] Rapha el Berthier, Andrea Montanari, and Kangjie Zhou, Learning time-scales in two-layers neural networks, ar Xiv preprint ar Xiv:2303.00055 (2023). [BPVZ23] Joan Bruna, Loucas Pillaud-Vivien, and Aaron Zweig, On single index models beyond gaussian data, ar Xiv preprint ar Xiv:2307.15804 (2023). [BS06] Jinho Baik and Jack W Silverstein, Eigenvalues of large sample covariance matrices of spiked population models, Journal of multivariate analysis 97 (2006), no. 6, 1382 1408. [CB18] Lenaic Chizat and Francis Bach, On the Global Convergence of Gradient Descent for Over-parameterized Models using Optimal Transport, Advances in Neural Information Processing Systems, 2018. [Chi22] L ena ıc Chizat, Mean-field langevin dynamics: Exponential convergence and annealing, ar Xiv preprint ar Xiv:2202.01009 (2022). 
[CM20] Sitan Chen and Raghu Meka, Learning polynomials in few relevant dimensions, Conference on Learning Theory, 2020. [COB19] Lenaic Chizat, Edouard Oyallon, and Francis Bach, On Lazy Training in Differentiable Programming, Advances in Neural Information Processing Systems, 2019. [CWPPS23] Elizabeth Collins-Woodfin, Courtney Paquette, Elliot Paquette, and Inbar Seroussi, Hitting the high-dimensional notes: An ode for sgd learning dynamics on glms and multi-index models, ar Xiv preprint ar Xiv:2308.08977 (2023). [DH18] Rishabh Dudeja and Daniel Hsu, Learning single-index models in gaussian space, Conference On Learning Theory, PMLR, 2018, pp. 1887 1930. [DKL+23] Yatin Dandi, Florent Krzakala, Bruno Loureiro, Luca Pesce, and Ludovic Stephan, Learning two-layer neural networks, one (giant) step at a time, ar Xiv preprint ar Xiv:2305.18270 (2023). [DKTZ22] Ilias Diakonikolas, Vasilis Kontonis, Christos Tzamos, and Nikos Zarifis, Learning a single neuron with adversarial label noise via gradient descent, Conference on Learning Theory, PMLR, 2022, pp. 4313 4361. [DLS22] Alexandru Damian, Jason Lee, and Mahdi Soltanolkotabi, Neural Networks can Learn Representations with Gradient Descent, Conference on Learning Theory, 2022. [DNGL23] Alex Damian, Eshaan Nichani, Rong Ge, and Jason D Lee, Smoothing the landscape boosts the signal for sgd: Optimal sample complexity for learning single index models, ar Xiv preprint ar Xiv:2305.10633 (2023). [DWY21] Konstantin Donhauser, Mingqi Wu, and Fanny Yang, How rotational invariance of common kernels prevents generalization in high dimensions, International Conference on Machine Learning, 2021. [EBD19] Murat A. Erdogdu, Mohsen Bayati, and Lee H. Dicker, Scalable Approximations to Generalized Linear Problems, Journal of Machine Learning Research (2019). [EDB16] Murat A Erdogdu, Lee H Dicker, and Mohsen Bayati, Scaled least squares estimator for glms in large-scale problems, Advances in Neural Information Processing Systems 29 (2016). 
[Erd15] Murat A Erdogdu, Newton-stein method: a second order method for glms via stein s lemma, Proceedings of Advances in Neural Information Processing Systems, 2015, pp. 1216 1224. [FCG20] Spencer Frei, Yuan Cao, and Quanquan Gu, Agnostic learning of a single neuron with gradient descent, Advances in Neural Information Processing Systems, vol. 33, Curran Associates, Inc., 2020, pp. 5417 5428. [GKK19] Surbhi Goel, Sushrut Karmalkar, and Adam Klivans, Time/accuracy tradeoffs for learning a relu with respect to gaussian marginals, Advances in neural information processing systems 32 (2019). [GMMM19] B. Ghorbani, Song Mei, Theodor Misiakiewicz, and Andrea Montanari, Limitations of Lazy Training of Two-layers Neural Networks, Advances in Neural Information Processing Systems, 2019. [GMMM20] Behrooz Ghorbani, Song Mei, Theodor Misiakiewicz, and Andrea Montanari, When Do Neural Networks Outperform Kernel Methods?, Advances in Neural Information Processing Systems, 2020. [HC22] Karl Hajjar and Lenaic Chizat, On the symmetries in the dynamics of wide two-layer neural networks, ar Xiv preprint ar Xiv:2211.08771 (2022). [HTFF09] Trevor Hastie, Robert Tibshirani, Jerome H Friedman, and Jerome H Friedman, The elements of statistical learning: data mining, inference, and prediction, vol. 2, Springer, 2009. [IS15] Sergey Ioffe and Christian Szegedy, Batch normalization: Accelerating deep network training by reducing internal covariate shift, International conference on machine learning, pmlr, 2015, pp. 448 456. [JGH18] Arthur Jacot, Franck Gabriel, and Clement Hongler, Neural Tangent Kernel: Convergence and Generalization in Neural Networks, Advances in Neural Information Processing Systems, 2018. [Joh01] Iain M Johnstone, On the distribution of the largest eigenvalue in principal components analysis, The Annals of statistics 29 (2001), no. 2, 295 327. [JWHT13] Gareth James, Daniela Witten, Trevor Hastie, and Robert Tibshirani, An introduction to statistical learning, vol. 
112, Springer, 2013. [KKSK11] Sham M Kakade, Varun Kanade, Ohad Shamir, and Adam Kalai, Efficient learning of generalized linear and single index models with isotonic regression, Advances in Neural Information Processing Systems 24 (2011). [LD89] Ker-Chau Li and Naihua Duan, Regression Analysis Under Link Violation, The Annals of Statistics (1989). [LMZ20] Yuanzhi Li, Tengyu Ma, and Hongyang R Zhang, Learning over-parametrized twolayer neural networks beyond NTK, Conference on Learning Theory, 2020. [MHD+23] Arvind Mahankali, Jeff Z Haochen, Kefan Dong, Margalit Glasgow, and Tengyu Ma, Beyond ntk with vanilla gradient descent: A mean-field analysis of neural networks with polynomial width, samples, and time, ar Xiv preprint ar Xiv:2306.16361 (2023). [MHPG+23] Alireza Mousavi-Hosseini, Sejun Park, Manuela Girotti, Ioannis Mitliagkas, and Murat A Erdogdu, Neural networks efficiently learn low-dimensional representations with SGD, The Eleventh International Conference on Learning Representations, 2023. [MMM19] Song Mei, Theodor Misiakiewicz, and Andrea Montanari, Mean-field theory of twolayers neural networks: dimension-free bounds and kernel limit, Conference on Learning Theory, 2019. [MMN18] Song Mei, Andrea Montanari, and Phan-Minh Nguyen, A mean field view of the landscape of two-layer neural networks, Proceedings of the National Academy of Sciences 115 (2018), no. 33, E7665 E7671. [NWS22] Atsushi Nitanda, Denny Wu, and Taiji Suzuki, Convex analysis of the mean field langevin dynamics, International Conference on Artificial Intelligence and Statistics, PMLR, 2022, pp. 9741 9757. [O D14] Ryan O Donnell, Analysis of boolean functions, Cambridge University Press, 2014. [PSE22] Sejun Park, Umut Simsekli, and Murat A. Erdogdu, Generalization Bounds for Stochastic Gradient Descent via Localized ε-Covers, ar Xiv preprint ar Xiv:2209.08951 (2022). 
[Rey20] Lev Reyzin, Statistical queries and statistical algorithms: Foundations and applications, ar Xiv preprint ar Xiv:2004.00557 (2020). [RGKZ21] Maria Refinetti, Sebastian Goldt, Florent Krzakala, and Lenka Zdeborov a, Classifying high-dimensional Gaussian mixtures: Where kernel methods fail and neural networks succeed, International Conference on Machine Learning, 2021. [RVE18] Grant M Rotskoff and Eric Vanden-Eijnden, Neural networks as Interacting Particle Systems: Asymptotic convexity of the Loss Landscape and Universal Scaling of the Approximation Error, ar Xiv preprint ar Xiv:1805.00915 (2018). [Sol17] Mahdi Soltanolkotabi, Learning relus via gradient descent, Advances in neural information processing systems 30 (2017). [SWON23] Taiji Suzuki, Denny Wu, Kazusato Oko, and Atsushi Nitanda, Feature learning via mean-field langevin dynamics: classifying sparse parities and beyond, Thirty-seventh Conference on Neural Information Processing Systems (Neur IPS 2023), 2023. [Tel23] Matus Telgarsky, Feature selection and low test error in shallow low-rotation relu networks, The Eleventh International Conference on Learning Representations, 2023. [Ver18] Roman Vershynin, High-dimensional probability: An introduction with applications in data science, Cambridge University Press, 2018. [VH16] Ramon Van Handel, Probability in high dimension, 2016. [Wai19] Martin J. Wainwright, High-dimensional statistics: A non-asymptotic viewpoint, Cambridge University Press, 2019. [WLLM19] Colin Wei, Jason D Lee, Qiang Liu, and Tengyu Ma, Regularization matters: Generalization and optimization of neural nets vs their induced kernel, Advances in Neural Information Processing Systems 32 (2019). [YO20] Gilad Yehudai and Shamir Ohad, Learning a single neuron with gradient methods, Proceedings of Thirty Third Conference on Learning Theory, Proceedings of Machine Learning Research, vol. 125, PMLR, 2020, pp. 3756 3786. 
[YS19] Gilad Yehudai and Ohad Shamir, On the Power and Limitations of Random Features for Understanding Neural Networks, Advances in Neural Information Processing Systems, 2019. A Concurrent Works In this paragraph we briefly summarize a few relevant results concurrent to our submission. [BMZ23] provided a precise analysis of the two-timescale dynamics in learning a single index model with information exponent s = 1. [MHD+23] considered the learning of a single index target with s = 2 using a neural network in the mean-field regime. [BPVZ23] extended the information exponentbased characterization of online SGD to input data beyond Gaussian. [DNGL23] showed that a gradient-smoothed dynamics can improve the sample complexity and match the CSQ lower bound. Finally, beyond the single-index setting, [DKL+23, SWON23, CWPPS23] considered learning low-dimensional target functions supported on k > 1 dimensions via gradient-based feature learning. B Background on Hermite Expansion The normalized Hermite polynomials {hj}j 0 given by (2.4) provide an orthonormal basis for L2(γ), thus for every f L2(γ) we have j=0 f, hj γhj, where f, hj γ := Ez N(0,1)[f(z)hj(z)]. We will commonly invoke the following well-known properties of Hermite polynomials. If j 1, then h j = jhj 1, where h j stands for the derivative of hj. Furthermore, if z1 and z2 are two standard Gaussian random variables with E[z1z2] = ρ, then E[hi(z1)hj(z2)] = δijρj where δij is the Kronecker delta. We refer the interested reader to [O D14, Chapter 11.2] for additional discussions and properties of these polynomials. We will now discuss how our Definition 1 relates to the original definition of information exponent of [BAGJ21]. In their setting, they assume the true data distribution Pu is parameterized by some unit vector u Sd 1, and we know the parametric family {Pw}w Sd 1; thus the problem is to estimate the direction u. 
Furthermore, they assume the population loss, which is the expectation of some per-sample loss, has spherical symmetry, i.e. the population loss R(w) can be written as R(w) = R( w, u ). Then, [BAGJ21, Definition 1.2] defines the information exponent to be the degree of the first non-zero coefficient of R in its Taylor expansion around the origin. In other words, we say R has information exponent s if dk R dzk (0) = 0 1 k < s dk R dzk (0) = c < 0 k = s dk R dzk (z) C k > s, z [ 1, 1] , where C, c > 0 are universal constants. To specialize the above abstract definition to the Gaussian case, consider the setting where the input data is standard Gaussian x N(0, Id) and the problem is to estimate u Sd 1 given a response variable y = f( u, x ) with known f. Via the Hermite expansion of f, one can write R( w, u ) := 1 2 E (f( w, x ) f( u, x ))2 = X j 1 f, hj 2 γ w, u j + const. Thus, the information exponent of R is indeed the degree of the first non-zero term in the Hermite expansion of f. Now consider the general case where x N(0, Σ). The spherical symmetry assumed in [BAGJ21] no longer holds. However, after proper normalization of weights, if we consider the population loss then R(w) = R w,Σu Σ1/2w Σ1/2u . Indeed, a close examination of the arguments of [BAGJ21] reveals that for their results to hold, the proper symmetry to consider is the ellipsoidal symmetry, and the proper definition of information exponent is the degree of the first non-zero term in the Hermite expansion of R, which reads j 1 f, hj 2zj + const. Once again, we can consistently define the information exponent to be the degree of the first non-zero term in the Hermite expansion of f, as long as the input is Gaussian (potentially anisotropic). C Proofs of Section 2.1 Before beginning our main discussions, we state the following lemma which is a generalization of Stein s lemma (Gaussian integration by parts), and will help obtain a closed-form expression for the population gradient. 
We refer to [Erd15, MHPG+23] for similar statements. Lemma 11. Let f, g : R R with g weakly differentiable. Suppose z N(0, Id). Then, for any w, u Sd 1, we have E[f( w, z )g( u, z )z] = E[f( w, z )g ( u, z )]u+E[f( w, z ){g( u, z ) w, z g ( u, z ) u, w }]w. Proof. Consider the conditional distribution z| w, z N(µ, Σ), where µ = w, z w and Σ = Id w w . Recall that Stein s lemma (Gaussian integration by parts) states that when z N(µ, Σ), then E[g(z)z] = E[g(z)]µ + Σ E[ g(z)]. E[g( u, z )z | w, z ] = E[g( u, z ) | w, z ] w, z w + (Id ww ) E[g ( u, z ) | w, z ]u. Applying the tower property of conditional expectation and rearranging the terms yields the desired result. We are now ready to state and prove the expression for the population gradient when using the Re LU activation. Lemma 12. Suppose ϕ is the Re LU activation and g is weakly differentiable. Let z N(0, Id). Define w := Σ1/2w Σ1/2w and similarly define u. Then R(w) = Σ n ψϕ,g(w)w + ζϕ,g(w)u o where ψϕ,g(w) := E[ ϕ( w, z )g( u, z ) + ϕ ( w, z )g ( u, z ) u, w ] and ζϕ,g(w) := E[ϕ ( w, z )g ( u, z )] Proof. Notice that the population risk is given by ϕ( w, x ) g 2 E ϕ( w, x )2 E[ϕ( w, x )g]+E y2 Notice that x = Σ1/2z. By the homogeneity of Re LU, we can rewrite the first term as E ϕ( w, x )2 = w Σw ϕ 2 γ = w Σw Then we have, ϕ( w, x ) g | {z } =:υ(w) Note that ϕ ( w, x ) = ϕ ( w, z ). Then, υ(w) = Σ1/2 E[ϕ ( w, z ) g( u, z )z] = Σ1/2{E[ϕ ( w, z )g ( u, z )]u + E[ϕ( w, z g( u, z ) ϕ ( w, z )g ( u, z ) u, w ]w} ( E[ϕ ( w, z )g ( u, z )] Σ1/2u u + E[ϕ( w, z g( u, z ) ϕ ( w, z )g ( u, z ) u, w ] where we used Lemma 11 and the fact that ϕ (z)z = ϕ(z). Therefore, R(w) = Σ n ψϕ,g(w)w + ζϕ,g(w)u o which concludes the proof. Particularly, the above lemma yields the following corollary for the spherical dynamics in the population limit. Corollary 13. Suppose {wt}t 0 is a solution to the population spherical gradient flow (2.5), ϕ is the Re LU activation, and Σ follows the (κ, θ)-spiked model. 
Then, dt = ζϕ,g(wt)(1 wt, u 2) κ n ψϕ,g(wt) wt, θ + ζϕ,g(wt) u, θ o 1 + κ ( θ, u wt, θ wt, u ). (C.1) C.1 Proof of Theorem 2 Plugging in θ = u in Corollary 13, we obtain ( ζϕ,g(wt) + κ ψϕ,g(wt) wt, u (1 u, wt 2). (C.2) To prove the statement of the theorem, we will show that whenever we will have d wt,u 2 dt < 0, thus when initialized from O(1/ d), wt, u can never escape the saddle point near the equator w, u = 0. Recall from the properties of the Hermite expansion in Appendix B that jαjhj 1 and ϕ = X Since we additionally assume E[g] = α0 = 0, by the definition of ψϕ,g in Lemma 12 and the properties of Hermite expansion discussed in Appendix B, we have j s(j 1)αjβj and similarly ζϕ,g(wt) := X Thus we obtain d wt, u 2 dt = 2F(wt)(1 wt, u 2), wt, u j 1 wt, u κ 1 + κ j s (j 1)αjβj wt, u j+1 wt, u κ wt, u 2 We proceed by upper bounding F. To do so, first note that 1 + κ w, u 2 To bound the first term of F, we have X wt, u j 1 wt, u wt, u wt, u s 1 X j s j|αjβj| wt, u s 1 wt, u ϕ γ g γ(1 + κ)(s 1)/2 wt, u s. j s (j 1)αjβj wt, u j+1 wt, u ϕ γ g γ(1 + κ)(s+1)/2 wt, u s+2. Hence, for κ 1, F(wt) ϕ γ g γ(1 + κ)(s 1)/2 wt, u s 1 + (1 + κ) wt, u 2 wt, u 2 Suppose κ < d/C2 1, then F(wt) wt, u 2 2 ϕ γ g γ(1 + κ)(s 1)/2 wt, u s 2 1/4 2 ϕ γ g γCs 2 r Thus, for any κ such that (8Cs 2 ϕ γ g γ) 2 s 1 d and any wt such that | wt, u | C/ d, we have d wt,u 2 dt 0, hence supt 0| wt, u | C/ d, as long as the above holds true at initialization. Finally, we will show w0, u C/ d with probability at least 0.99 for a suitable choice of constant C. Indeed, this is an elementary concentration of measure result on the unit sphere. For simplicity, we avoid performing sharp probability of failure analysis and only remark that E h w0, u 2i = 1/d, thus by the Markov inequality P w0, u 2 C2/d 1/C2, hence a choice of C 10 suffices, and the proof is complete. 
C.2 Extremely Large Spike In this section, we will show that under extremely large spike, the spherical gradient flow (2.5) can potentially recover the true direction. Namely, we will prove the following proposition. Proposition 14. Suppose we initialize the spherical population gradient flow (2.5) from w0. Let ϕ be the Re LU activation and assume ϕ, g γ := Ez N(0,1)[ϕ(z)g(z)] = α > 1/2, and κ C w0,u 2 for a sufficiently large constant C > 0 depending only on g. Then, the gradient flow on the sphere satisfies w T , u 1 ε if w0, u > 0 w T , u 1 + ε if w0, u < 0 (C.3) T 1 α 1/2 ln(2/ε). (C.4) Before proceeding to the proof, we notice that if we uniformly initialize w0 over Sd 1, then the typical value for w0, u is of order d 1/2, meaning that the above proposition asks for κ = Ω(d). This is a regime where lower bounds for the sample complexity of kernel methods are Ω(1) [DWY21], thus no meaningful separation in terms of dimension dependency of the sample complexity between neural networks and kernel methods is possible, as the problem becomes effectively one-dimensional. Proof. The cases where w0, u > 0 and w0, u < 0 are symmetric, thus we only present the proof for the former. Using (C.2), we can write the dynamics on the sphere more explicitly as wt, z )g ( u, z ) 1 + κ wt, u 2 | {z } =:B1 + κ wt, u E ϕ( wt, z )g( u, z ) 1 + κ wt, u 2 | {z } =:B2 (1 wt, u 2). Our goal is to study the regime of large κ, therefore we will bound how much B1 and B2 can deviate from their corresponding κ = values. In particular, we have 1 + κ wt, u 2 = g γ/2 1 + κ wt, u 2 . Furthermore, assuming wt, u > 0 and ϕ, g γ > 0, let cκ(wt) := κ wt,u 1+κ 1+κ wt,u 2 . Then, by the Lipschitzness of ϕ, B2 = cκ(wt) ϕ, g γ + cκ(wt) E ϕ( wt, z ) ϕ( u, z ) g( u, z ) cκ(wt) ϕ, g γ cκ(wt) g γ wt u cκ(wt) ϕ, g γ cκ(wt) g γ q cκ(wt) ϕ, g γ 2κ g γ| wt, u | 1 + κ wt, u 2 . where we used wt = Σ1/2wt Σ1/2wt in the last step. 
Suppose wt, u > 0 (which holds at least on a neighborhood around initialization, and as we will see below holds for all t > 0), then, cκ(wt) κ wt, u 2 1 + κ wt, u 2 . As a result, we obtain B1 + B2 κ wt, u 2(1 + κ) ϕ, g γ κ wt, u 2 1 Consequently, the lower bound of the time derivative of wt, u becomes larger as wt, u increases. Therefore, assuming w0, u > 0, we only need to control this lower bound at initialization. Assume w0, u 2(α 1/2) From this, we conclude that when wt, u > 0 and ϕ, g γ = α > 1/2, we have 2 (1 wt, u 2), integration yields the desired result. D Proofs of Section 3 We begin by stating the closed-form expression for the population gradient, i.e. the counterpart of Lemma 12 in the normalized setting. Lemma 15. Consider the population risk R(w) defined by (3.1), recall that 2 ϕ 2 γ + 1 R(w) = Σ1/2(Id w w )ζϕ,g( w, u )u Σ1/2w , (D.1) where ζϕ,g( w, u ) := E[ϕ ( w, z )g ( u, z )] = X j s jαjβj w, u j 1. (D.2) Proof. Recall from (3.1) that Σ1/2 E[ϕ ( w, z )g( u, z )z] = Σ1/2 Id w w Σ1/2w {ζϕ,g( w, u )u + ψϕ,g( w, u )w} (by Lemma 11) = Σ1/2(Id w w )ζϕ,g( w, u )u ψϕ,g( w, u ) := E[ϕ ( w, z )g( u, z ) w, z ϕ ( w, z )g ( u, z ) w, u ] (D.3) (the above is only a function of w, u due to the Hermite expansion). Given the closed form of the population gradient, the proof of Lemma 3 is immediate by noticing that dt = (Id wtwt ) Σ1/2w Σ1/2 dwt Next, we move on to prove Proposition 4. D.1 Proof of Proposition 4 From Lemma 3, we have wt, u ) Σ1/2ut 2 wt, u s 1(1 where ut := u wt. The above inequality and the fact that w0, u > 0 imply that wt, u is non-decreasing in time. Let T1 := sup{t > 0 : wt, u < 1/2}. Then, on t [0, T1], we have dt 3cλmin(Σ) and integration yields T1 0 4 3cλmin(Σ) w0, u s = 1 ln(1/(2 w0, u )) s = 2 1 s 2 (1/ w0, u )s 2 2s 2 s > 2 . Therefore, T1 τs( w0, u )/λmin(Σ). For t > T1, we have d dt cλmin(Σ) Let T2 = sup t > 0 : wt, u < 1 ε . Once again, integration implies T2 T1 + 2s 2 cλmin(Σ) ln(2/(3ε)), which completes the proof. 
D.2 Preliminary Lemmas for proving Theorem 5 We first introduce a number of concentration (and anti-concentration) lemmas that will be useful for proving Theorem 5. Lemma 16. Suppose {z(i)}n i=1 i.i.d. N(0, Id), and g : R R satisfies |g( )| C(1 + | |p) for some C > 0 and p 1. Additionally, suppose {ϵ(i)}i=1 are i.i.d. σ-sub-Gaussian zero-mean noise independent of {z(i)}n i=1. Let y(i) := g( u, z(i) ) + ϵ(i) for some u Sd 1. Then, for any q > 0, with probability at least 1 4d q, we have y(i) C + C(2 ln(ndq))p/2 + σ p 2 ln(ndq) ln(ndq)p/2, Proof. Notice that u, z(i) N(0, 1), thus 2 ln(ndq) 2n 1d q. Similarly, by the sub-Gaussian and zero-mean property of ϵ(i), 2 ln(ndq) 2n 1d q. Thus, by a union bound, we have D 2 ln(ndq) and ϵ(i) σ p 2 ln(ndq), for all 1 i n, with probability at least 1 4d q. Using the upper bound on |g| finishes the proof. Lemma 17. Suppose {z(i)}n i=1 are i.i.d. samples drawn uniformly from Sd 1. Then, 3d 2 + ln(8n/ Proof. Fix some ϵ (0, 1). Let Nϵ be a minimal ϵ-covering of Sd 1. Let ˆw be the projection of w onto Nϵ. Notice that by the triangle inequality and the union bound w, z(i)E ϵ α i=1 1 D ˆw, z(i)E 2ϵ α i=1 1 D ˆw, z(i)E 2ϵ α i=1 1 D ˆw, z(i)E 2ϵ α Moreover, due to [BBSS22, Lemma A.7], E[1(| ˆw, z | 2ϵ)] = P( ˆw, z 2ϵ) 8 Choose ϵ = 3 d 8n . By Lemma 25 i=1 1 D ˆw, z(i)E 2ϵ 3d 2 + ln(8n/ which completes the proof. We summarize the above statements into a good event , as characterized by the following lemma. Lemma 18. Let {z(i)}n i=1 i.i.d. N(0, Id), and z(i) := z(i)/ z(i) for every i. We say event G occurs whenever: 1. |y|(i) ln(ndq)p/2 for all 1 i n. 2. supw Sd 1 Pn i=1 1 D d/n d ln(n/ 3. λmax 1 n Pn i=1 z(i)z(i) 1 p 4. 1 λmin 1 n Pn i=1 z(i)z(i) p For n d, event G occurs with probability at least 1 O(d q). Proof. The first and second statements of the lemma follow from Lemmas 16 and 17 respectively. The third and fourth statements are standard Gaussian covariance concentration bounds (see e.g. 
[Wai19, Theorem 6.1] where both statements hold with probability at least 1 2e d). D.3 Proof of Theorem 5 We begin by recalling the definition w := ˆΣ 1/2w ˆΣ 1/2w and the finite-samples dynamics (3.5), which we copy here for the reader s convenience, dt = η(wt) wt where η(wt) = ˆΣ 1/2wt 2. Moreover, via chain rule, we obtain dt = (Id wtwt ) ˆΣ 1/2 Let z(i) := ˆΣ 1/2x(i) = ˆΣ 1/2Σ1/2z(i). Then, dt = (Id wtwt ) ˆΣ(Id wtwt ) i=1 ϕ wt, x(i) y(i) z(i) ) To simplify the notation, define ν(w) := (Id w w ) ˆΣ(Id w w )u. wt, z(i)E y(i) z(i) + We can decompose the above dynamics into a population term and three different error terms in the following manner: dt = ν(wt), Ez,y ϕ wt, z(i)E y(i)z(i) Ez,y ϕ | {z } =:E1 wt, z(i)E ϕ D wt, z(i)E o y(i)D z(i), ν(wt) E | {z } =:E2 w, z(i)E y(i)D z(i) z(i), ν(wt) E | {z } =:E3 We will proceed in three steps. In the first, we bound E1, the concentration error. In the second, we bound E2 and E3, the errors due to estimating Σ with ˆΣ (i.e. replacing z(i) with z(i)). Finally, we will analyze the convergence time similar to that of Proposition 4. Throughout the proof, we will assume that the event G of Lemma 18 occurs. Step 1. Controlling the concentration error E1. Let K ln(ndq)p/2, and notice that on event G we have y(i) K for all i. Let y K := y1(|y| K). On the event G, we have y(i) K = y(i) for all i, and E1 = ν(wt), n n ν(wt) , where w, z(i)E y K (i)z(i) Ez,y[ϕ ( w, z )yz]. Thus, our objective is to bound n uniformly for all w Sd 1. To that end, we first modify the expectation in the above definition so that the empirical average and expected value match in terms of their random variables. Specifically, sup w,v Sd 1 n, v = sup w,v Sd 1 1 n w, z(i)E y(i) K D z(i), v E Ez,y[ϕ ( w, z )y K z, v ] Ez,y[ϕ ( w, z )y z, v 1(|y| > K)]. By the Cauchy-Schwartz inequality, |Ez,y[ϕ ( w, z )y z, v 1(|y| > K)]| Ez,y h ϕ ( w, z )2y2 z, v 2i1/2 E[1(|y| > K)]1/2 E y4 1/4 Ez h z, v 4i1/4 P(|y| > K)1/2 where the last inequality follows from Lemma 16. 
Hence, sup w,v Sd 1 n, v sup w,v Sd 1 1 n w, z(i)E y(i) K D z(i), v E E[ϕ ( w, z )y K z, v ] + O(d q/2). Next, we need to establish high-probability bounds via a covering argument. To simplify the exposition, define the stochastic process indexed by w Sd 1 and v Sd 1 via w, z(i)E y(i) K D z(i), v E . Fix some ϵw, ϵv > 0. Let Θw and Θv be ϵw and ϵv coverings of Sd 1, and let ˆw and ˆv denote the projection of w onto Θw and of v onto Θv respectively, then sup w,v Sd 1 1 n w,v E[Xw,v] = sup w,v Sd 1 1 n w,ˆv X(i) ˆw,ˆv + Ez,y[Xw,ˆv Xw,v] + Ez,y[X ˆw,ˆv Xw,ˆv] i=1 X(i) ˆw,ˆv Ez,y[X ˆw,ˆv]. We bound each of the terms using Cauchy-Schwartz. Specifically, w, z(i) )2y(i) K 2 v u u t 1 z(i), v ˆv 2 Kϵv, where we used the upper bound on the operator norm of 1 n Pn i=1 z(i)z(i) from Lemma 18 together with the fact that n d. Similarly, Ez,y[Xw,ˆv Xw,v] Ez,y ϕ ( w, z )2y2 K 1/2 Ez h z, v ˆv 2i1/2 Kϵv. To bound the differences when we replace w with ˆw, we need to make a distinction between Re LU and smooth activations as the respective arguments are to some extent different. When ϕ is Lipschitz, w,ˆv X(i) ˆw,ˆv i=1 (y(i) K )2 ϕ ( w, z(i) ) ϕ ( ˆw, z(i) ) 2 v u u t 1 z(i), ˆv2 Kϵw, E[X ˆw,ˆv Xw,ˆv] E h y2 K(ϕ ( w, z ) ϕ ( ˆw, z ))2i1/2 E h z, ˆv 2i1/2 Kϵw. Therefore, for a smooth activation ϕ we choose ϵv = ϵw = p d/n, and obtain sup w,v Sd 1 1 n w,v Ez,y[Xw,v] sup ˆw,ˆv i=1 X(i) ˆw,ˆv Ez,y[X ˆw,ˆv] + O( p When ϕ is the Re LU activation, we need to show that the sign of the preactivation changes only for a small number of samples when we change the weight w to ˆw. Notice that w, z(i)E = sign D ˆw, z(i)E = D w, z(i)E D ˆw w, z(i)E w, z(i)E ϵw. Recall that z(i) := z(i)/ z(i) . Choose ϵw d/n. On event G, we know from Lemma 18 that at most O(d ln(n/ d)) samples can satisfy the above condition. 
Therefore, w,ˆv X(i) ˆw,ˆv i=1 (y(i) K )2 ϕ ( w, z(i) ) ϕ ( ˆw, z(i) ) 2 v u u t 1 E[X ˆw,ˆv Xw,ˆv] E h y2 K(ϕ ( w, z ) ϕ ( ˆw, z ))2i1/2 E h z, ˆv 2i1/2 KP(sign( w, z ) = sign( ˆw, z ))1/2 KP(| w, z | ϵw)1/2 where the last inequality follows from the anti-concentration on the sphere [BBSS22, Lemma A.7]. Thus, for Re LU we choose ϵv p d/n, and once again obtain sup w,v Sd 1 1 n w,v Ez,y[Xw,v] sup ˆw,ˆv i=1 X(i) ˆw,ˆv Ez,y[X ˆw,ˆv] + O( p It remains to bound the term i=1 X(i) ˆw,ˆv E[X ˆw,ˆv]. Notice that for fixed ˆw, ˆv, X ˆw,ˆv is sub-Gaussian with sub-Gaussian norm O(K). Thus, via the sub-Gaussian maximal inequality [VH16, Lemma 5.2], i=1 X(i) ˆw,ˆv E[X ˆw,ˆv] p K2d/n ln(1/(ϵwϵv)), with probability at least 1 e d. Consequently, we have sup w Sd 1 n O( p d/n + d q/2), with probability at least 1 O(d q). Assuming that n grows at most polynomially in dimension and choosing a sufficiently large q, we have supw Sd 1 n O( p d/n) with probability at least 1 O(d q). Finally, by Lemma 23, ν(wt) λmax ˆΣ λmax(Σ), (D.6) with probability at least 1 e n /2. Combining the above with the bound on n , we have E1 λmax(Σ) O( p d/n) with probability at least 1 O(d q), which concludes the first step of the proof. Step 2. Bounding the error due to the estimation of Σ, i.e. E2 and E3. Recall that we are considering the event G, thus y(i) = y(i) K . We can control each of the error terms separately. We begin by E2, where by Cauchy-Schwartz wt, z(i)E ϕ D wt, z(i)E o y(i) K D z(i), ν(wt) E wt, z(i) ϕ D wt, z(i)E o2 v u u t 1 2 z(i), ν(wt) 2 wt, z(i) ϕ D wt, z(i)E o2 , where the last line follows from Lemma 18 and the fact that n d. When ϕ is additionally Lipschitz, we have wt, z(i) z(i)E2 . Moreover, for any w Sd 1, w, z(i) z(i)E2 1 i=1 z(i)z(i) 2 (Id ˆΣ 1/2Σ1/2)w 2 i=1 z(i)z(i) 2 Id ˆΣ 1/2Σ1/2 2 where the last inequality holds with probability at least 1 2e d on the event of Lemma 24. 
Hence for smooth activations we conclude E2 K ν(wt) p When ϕ is the Re LU activation, we need a more involved argument to control E2. In particular, we will show that for any w, at most only O(d) datapoints can have sign w, z(i) = sign D w, z(i)E . Notice that w, z(i)E = sign D w, z(i)E = D w, z(i) z(i)E w, z(i)E Id ˆΣ 1/2Σ1/2 (D.7) where z(i) := z(i) z(i) is distributed uniformly over Sd 1. From Lemma 24 we have Id ˆΣ 1/2Σ1/2 q d n with probability at least 1 2e d. On the other hand, from Lemma 17 we know with probability at least 1 e d, for any w Sd 1 at most O(d) of the labeled samples have d/n. Recall that n n2 when using the Re LU activation. This is precisely why we make this choice for the Re LU activation, as we need to balance the RHS of (D.7) which is of order p d/n with the LHS of (D.7) which should at most be of order d/n if we want to ensure only O(d) samples satisfy the bound. When n = n2 we can balance these two terms, thus with probability at least 1 3e d the sign change can occur for at most O(d) many samples, and wt, z(i)E ϕ D wt, z(i)E o2 O d In this case, we end up with E2 K ν(wt) O( p Bounding E3 for Re LU and Lipschitz ϕ is identical. In both cases, by Cauchy-Schwartz, w, z(i) )2y(i) K 2 v u u t 1 D z(i) z(i), ν(wt) E2 i=1 z(i)z(i) Id ˆΣ 1/2Σ1/2 ν(wt) which holds on the intersection of event G and of Lemma 23. At last, using the bound on ν(wt) from (D.6), we obtain E2 E3 λmax(Σ) O( p d/n), with probability at least 1 O(d q). Step 3. Analyzing the Convergence. As a result of the previous steps, we have established dt ν(wt), E ϕ ( wt, z )yz λmax(Σ) O( p Thanks to Lemma 11, we can write E[ϕ ( w, z )yz] = E[ϕ ( w, z )g( u, z )z] = ζϕ,g( w, u )u ψϕ,g( where ζϕ,g and ψϕ,g were introduced in (D.2) and (D.3) respectively. Recall the definition of ν(wt), ν(wt) := (Id wtwt ) ˆΣ(Id wtwt )u. u (Id wtwt ) ˆΣ(Id wtwt T )u λmax(Σ) O( p ut , ˆΣut E λmax(Σ) O( p d/n) (By Assumption 2) wt, u s 1(1 wt, u 2) λmax(Σ) O( p where ut := u wt. 
Moreover, from Lemma 23, we have λmin ˆΣ λmin(Σ) tr(Σ) n λmin(Σ) with probability at least 1 e n /8. Hence, for n dκ(Σ) we have λmin ˆΣ λmin(Σ), and consequently, dt c λmin(Σ) wt, u s 1(1 wt, u 2) λmax(Σ) O( p where c is a universal constant. Notice that the first term in the RHS above denotes the signal, while the second term denotes the noise. We want to ensure the noise remains smaller than the signal throughout the trajectory, which leads to the convergence of wt to u. Notice that the signal term is first increasing, then decreasing for wt, u [0, 1]. Thus, it suffices to ensure the noise is smaller than the signal on the two ends of the interval, i.e. at time t = 0 and at time t = T where w T , u = 1 ε. At initialization, this condition leads to w0, u 2(1 s), and at time t = T, leads to where C hides constant depending only on s and at most polylogarithmic factors of d. Thus, we have established the sample complexity as presented by Theorem 5. It remains to obtain the convergence time. With the above sample complexity, we have dt c λmin(Σ) wt, u s 1(1 where c is a universal constant. The rest of the proof follows by integration and is identical to the proof of Proposition 4 in Appendix D.1. D.4 Proof of Corollary 6 The proof follows immediately from Theorem 5 and the following lemma which describes how w0, u behaves under different regimes of r1 and r2. Lemma 19. Suppose Σ follows the (κ, θ)-spiked model, w0 is sampled uniformly from Sd 1, n d, and there exist universal constants C2, C 2, C3, C 3 > 0 such that C2dr2 κ C 2dr2 and C3d r1 u, θ C 3d r1 for r1 [0, 1/2] and r2 [0, 1]. Then, conditioned on w0, u > 0, with any arbitrarily large constant probability 1 δ, for sufficiently large d (that depends on δ) we have d 1/2 0 r2 < r1 dr2 r1 1/2 r1 < r2 < 2r1 d(r2 1)/2 2r1 < r2 < 1 . (D.8) Proof. By definition, D ˆΣ 1/2w0, Σ1/2u E ˆΣ 1/2w0 Σ1/2u . Recall that we are conditioning our arguments on w0, u > 0, hence the numerator of the above fraction is positive. 
To translate the sample complexities of Theorems 5 and 7 to the spiked model, our goal is to lower bound w0, u in terms of d, r1, and r2. We begin by observing that ˆΣ 1/2w ˆΣ 1/2Σ 1/2 Σ1/2w Σ1/2w , where the last inequality holds on the event of Lemma 24, which happens with probability at least 1 2e d. Consequently, D ˆΣ 1/2w0, Σ1/2u E Σ1/2w0 Σ1/2u = w0, Σu + D w0, ( ˆΣ 1/2 Σ1/2)Σ1/2u E Σ1/2w0 Σ1/2u . Furthermore, due to the Markov inequality, P w0, θ 2 C1 Similarly (by conditioning on ˆΣ) D w0, ( ˆΣ 1/2 Σ1/2)Σ1/2u E2 C1 ( ˆΣ 1/2 Σ1/2)Σ1/2u 2 Additionally, on the event of Lemma 24, ( ˆΣ 1/2 Σ1/2)Σ1/2u ˆΣ 1/2Σ 1/2 Id Σu p Therefore, on the above events, for some absolute constant C > 0 w0, Σu C Σu p 1+C 2C1 1+κ Σu w0, u + κ w0, θ u, θ C(1 + κ| u, θ |) p 1 + C 2C1 q 1 + κ u, θ 2 . Recall that C2dr2 κ C 2dr2 and C3d r1 u, θ C 3d r1 (notice that changing θ to θ does not change the spiked model of Assumption 1, thus we can assume u, θ 0 without loss of generality). Then, w0, u + κ w0, θ u, θ C(1 + κ u, θ ) p 1 + C 2C1 q 1 + C 2C 3 2dr2 2r1 . (D.9) The last term in the numerator can be made arbitrarily small by sufficiently large n , hence we focus on other terms for now. Intuitively, when r2 < 2r1, the denominator is of constant order. If additionally r2 < r1, the dominant term in the numerator is w0, u and d, otherwise the dominant term is κ w0, θ u, θ and w0, u dr2 r1 1/2. On the other hand, when r2 > 2r1, the denominator is of order dr2/2 r1, and once again the dominant term of the numerator is κ w0, θ u, θ , therefore w0, u d(r2 1)/2. Using this intuition, we analyze each of the following regimes separately. Case 1. 0 < r2 < r1: In this case, by [BBSS22, Lemma A.7] we have w0, u c/ d with probability at least 1 4c. On the intersection of all considered events with w0, u > 0, and for sufficiently large d and n d, we must have w0, u > 0 (otherwise w0, u < 0). 
Thus by plugging the values in (D.9), w0, u c C1C 2C 3dr2 r1 C(1 + C 2C 3dr2 r1) p 1 + C 2C1 q 1 + C 2C 3 2dr2 2r1 1 The intersection of all desired events and w0, u > 0 happens with probability at least 1 2 4c 2/C1 2e d, thus conditioned on w0, u the probability is at least 1 8c 4/C1 4e d. Choosing sufficiently small c, large C1, and respectively large d and n d with sufficiently large absolute constant, we can arbitrarily increase the (constant) probability of success. Thus the analysis of this regime is complete. Case 2. r1 < r2 < 2r1: This time we use the fact that w0, u p C1/d with probability at least 1 1/C1, and w0, θ c/ d with probability at least 1 4c. By an argument similar to the previous case, for sufficiently large d and n d, w0, u > 0 implies w0, θ > 0, hence by (D.9) w0, u C1 + c C2C3dr2 r1 C(1 + C 2C 3dr2 r1) p 1 + C 2C1 q 1 + C 2C 3 2dr2 2r1 dr2 r1 1/2, (D.11) with probability at least 1 8c 4/C1 4e d when conditioned on Case 3. 2r1 < r2 < 1: Once again recall (D.9). To bound the numerator, we repeat the exact same argument as in the previous case, thus w0, u C1 + c C2C3dr2 r1 C(1 + C 2C 3dr2 1) p (1 + C 2C1)C 2C 3 2d 1+r2 2r1 2 d(r2 1)/2, (D.12) which finishes the proof of the lemma. E Proofs of Section 4 E.1 Proof of Theorem 7 We recall from (D.5) that dt = (Id wtwt ) ˆΣ 1/2 Furthermore, the preconditioned dynamics of wt given by (4.1) reads dt = η(wt) ˆΣ 1/2(Id wtwt ) wt, z(i)E y(i) z(i) ) where we recall z(i) := ˆΣ 1/2x(i). Plugging in η(wt) = ˆΣ 1/2w 2 yields dt = (Id wtwt )2 ( 1 n wt, z(i)E y(i) z(i) ) = (Id wtwt ) wt, z(i)E y(i) z(i) ) The rest of the analysis is identical to that of the proof of Theorem 5 in Appendix D.3. 
Specifically, by defining ut , Ez,y ϕ ( wt, z(i)E y(i)z(i) Ez,y ϕ | {z } =:E1 wt, z(i)E ϕ D wt, z(i)E o y(i)D z(i), ut E | {z } =:E2 w, z(i)E y(i)D z(i) z(i), ut E | {z } =:E3 As long as n d, n = n for the smooth case, and n n2 for the Re LU case, the first two steps of the proof of Theorem 5 in Appendix D.3 implies that E1 E2 E3 ut O( p Once again, we apply Lemma 11 to obtain E[ϕ ( w, z )yz] = ζϕ,g( w, u )u ψϕ,g( with ζϕ,g and ψϕ,g given in (D.2) and (D.3) respectively. As a result, wt, u ) ut 2 O( p wt, u s 1(1 wt, u 2) O( p d/n) (By Assumption 2). We need to ensure the noise term, i.e. the second term on the RHS remains smaller than the signal, i.e. the first term. The signal term attains its minimum at either initialization t = 0 or at the end of the trajectory t = T where w T , u = 1 ε, which imposes the following sufficient conditions on n. Namely, at initialization we require w0, u 2(1 s), while at t = T we require n Cd/ε2, where C hides constant that only depend on s and polylogarithmic factors of d. Hence, we obtain wt, u s 1(1 for some universal constant c > 0. Via integration (similar to the proof of Proposition 4 in Appendix D.1), for T1 := sup{t > 0 : wt, u < 1/2} we obtain T1 τs( and for T2 := sup{t > 0 : wt, u < 1 ε} we obtain T2 T1 ln(1/ε), which completes the proof. We conclude by remarking that the proof of Corollary 8 is immediate given Theorem 7 and Lemma 19. E.2 Preliminary Lemmas for Proving Theorem 9 We will adapt the following lemma from [MHPG+23], which provides a non-parametric approximation of g via random biases. Lemma 20. [MHPG+23, Lemma 22] For any smooth g : R R and > 0, let g : R R be a smooth function such that g(z) = g(z) for |z| and g( 2 ) = g ( 2 ) = 0. Suppose {bj}m j=1 i.i.d. Unif( 2 , 2 ), and let := sup|z| 2 | g (z)|. 
Then, there exist second layer weights {aj(bj)}m j=1 with a / m, such that for any fixed z [ , ] and any δ > 0, with probability at least 1 δ over the random biases, j=1 a(bj)ϕ(z + bj) g(bj) where ϕ is the Re LU activation. We use the above lemma to show the existence of a second layer with O(1/ m) norm with training error of order O(1/m). Lemma 21. For any ε < 1, suppose w, u 1 ε. Then for any q > 0, sufficiently large d, n d/ε2, with probability at least 1 O(d q) over the random biases and the dataset, there exists a second layer a with a O(1/ m) described by Lemma 20 such that w, z(i)E + bj) y(i) E ϵ2 + O(1/m + ε), where ϕ is the Re LU activation. Proof. We begin by replacing w and z(i) with u and z(i). Specifically, via Jensen s inequality, w, z(i)E + bj y(i) u, z(i)E + bj g D | {z } =:E1 u, z(i)E o2 | {z } =:E2 j=1 aj n ϕ D w, z(i)E + bj ϕ D w, z(i)E + bj o | {z } =:E3 j=1 aj n ϕ D w, z(i)E + bj ϕ D u, z(i)E + bj o | {z } =:E4 We bound each term separately. For E1, we can invoke Lemma 20 which implies that each term in the sum can be bounded by O(1/m) with probability at least 1 1/(ndq), thus by a union bound, with probability at least 1 d q over the random biases, By sub-Guassianity of ϵ(i) (hence sub-exponentiality of ϵ(i)2), for n d (with a sufficiently large constant) we have E2 E ϵ2 + p with probability at least 1 e d. For E3, via the Lipschitzness of Re LU and the Cauchy-Schwartz inequality we can write w, z(i) z(i)E2 O(1) I ˆΣ 1/2Σ1/2 2 O(d/n ), where we used the event of Lemma 24 which happens with probability at least 1 2e d, and O(1) represents a constant that depends at most polylogarithmically on d. Finally, we bound the last term. Once again via the Lipschitzness of the Re LU activation and the Cauchy-Schwartz inequality w u, z(i)E2 O( w u 2) O(ε), where once again we used the event of Lemma 24. On the intersection of all desired events, we have w, z(i)E + bj) y(i) E ϵ2 + O(1/m + p d/n + d/n + ε). 
We conclude the proof by noticing that n n2 and n dε 2. Additionally, we will use the following standard Lemma on the Rademacher complexity of two-layer neural networks, which in particular is a restatement of [MHPG+23, Lemma 18] in a way suitable for our analysis. Lemma 22. Let F be a class of real-valued functions on (z, y). Given n samples {z(i), y}n i=1, define the empirical Rademacher complexity of F as ˆRn(F) := E(ςi)n i=1 i=1 ςif(z(i), y(i)) where (ςi) are i.i.d. Rademacher random variables (i.e. 1 with equal probability). Suppose F is given by j=1 ajϕ( u, z + bj) y C : a ra/ m, |bj| rb, 1 j m for some fixed u Sd 1. Suppose {z(i)}n i=1 i.i.d. N(0, Id), and suppose |ϕ | 1. Then, E(z(i),y(i))n i=1 h ˆRn(F) i 2 2C(1 + rb)ra n . Proof. See the proof of [MHPG+23, Lemma 18]. E.3 Proof of Theorem 9 Throughout the proof, we will assume w, u 1 ε where we recall w := ˆΣ 1/2w ˆΣ 1/2w and u := Σ1/2u Σ1/2u . From either Theorem 5 or Theorem 7, we can assume w, u 1 ε with probability at least 1 O(d q) for any fixed q > 0. For simplicity, let ˆy( z; w) = j=1 ajϕ( w, z + bj), and similarly define ˆy(z; u). We define the following quantities, R(w) := Ez,y h (ˆy( z; w) y)2i and R(u) := Ez,y h (ˆy(z; u) y)2i , (E.1) and similarly define their empirical counterparts, ˆy( z(i); w) y(i) 2 and ˆR(u) := 1 ˆy(z(i); u) y(i) 2 . (E.2) Notice that ultimately, we are interested in bounding R(w). We break down the proof into three steps. In the first step, we show that R(w) can be upper bounded by R(u). Then, via a generalization bound, we show that the R(u) can be upper bounded by ˆR(u). Finally, we show that ˆR(u) can be upper bounded by the training error, i.e. ˆR(w), and convex optimization of the last layer can attain the near-optimal value of this training error which is bounded by Lemma 21. Step 1. Bounding R(w) via R(u). By Jensen s inequality, Ez,y h (ˆy( z; w) y)2i 3 Ez h (ˆy( z; w) ˆy(z; w))2i +3 Ez h (ˆy(z; w) ˆy(z; u))2i +3 Ez,y h (ˆy(z; u) y)2i . Suppose a ra/ m. 
For the first term, by Lipschitzness of ϕ and the Cauchy-Schwartz inequality Ez h (ˆy( z; w) ˆy(z; w))2i = Ez j=1 aj n ϕ D w, z(i)E + bj ϕ D w, z(i)E + bj o r2 a Ez h w, z z 2i r2 a Id ˆΣ 1/2Σ1/2 2 r2 ad/n , where the last inequality holds with probability at least 1 2e d on the event of Lemma 24. For the middle term, via a similar argument, Ez h (ˆy(z; w) ˆy(z; u))2i r2 a Ez h w u, z 2i 2raε. In what follows, we will restrict the analysis to the case where ra = O(1). Therefore, we have Ez,y h (ˆy( z; w) y)2i 3 Ez,y h (ˆy(z; u) y)2i + O(d/n + ε). Step 2. Generalization: Bounding R(u) via ˆR(u). Define the event E := n | u, z | |ϵ| p 2 ln(ndq) o . and similarly define E(i) by replacing z and ϵ with z(i) and ϵ(i) respectively. Via the Cauchy-Schwartz and Jensen inequalities Ez,y h (ˆy(z; u) y)2i = Ez,y h (ˆy(z; u) y)21(E) i + Ez,y h (ˆy(z; u) y)21(EC) i Ez,y h (ˆy(z; u) y)21(E) i + 8 E ˆy(z; u)4 + E y4 1/2P EC 1/2 Moreover, E y4 1, Ez,y ˆy(z; u)4 O(1), and P EC 4/(ndq) (via a standard sub-Gaussian tail bound). Consequently, Ez,y h (ˆy(z; u) y)2i Ez,y h (ˆy(z; u) y)21(E) i + O(n 1d q), ℓ(z(i), y(i); a, b) := u, z(i)E + bj y(i) Notice that u is fixed. Then, by a standard symmetrization argument (see e.g. [VH16, Lemma 7.4]) and Lemma 22 sup a ra/ m,|bj| rb Ez,y[ℓ(z, y; a, b)] 1 i=1 ℓ(z(i), y(i); a, b) 2 E h ˆRn(F) i where ˆRn(F). As the loss is bounded, we can apply Mc Diarmid s inequality to turn the above bound in expectation into a bound in probability, in particular sup a ra/ m,|bj| rb Ez,y[ℓ(z, y; a, b)] 1 i=1 ℓ(z(i), y(i); a, b) O( p with probability at least 1 2e d. Therefore, we conclude this step by noticing that Ez,y h (ˆy(z; u) y)2i 1 ˆy(z(i); u) y(i) 2 1(E(i)) + O( p d/n + n 1d q) ˆy(z(i); u) y(i) 2 + O( p with probability at least 1 2e d. Step 3. Bounding the training error and finishing the proof. This step is similar to the proof of [MHPG+23, Theorem 4]. For conciseness, define ˆy(x(i); W , a, b) y(i) 2 . and ˆRλ(a) := ˆR(a) + λ a 2/2. 
Our goal is to choose suitable λ such that the minimizer a := arg min a Rm ˆR(a) + λ a 2/2, satisfies a ra/ m while the value of the above minimization problem which we denote with ˆR λ does not significantly exceed min a ra/ m ˆR(a). We argue that the suitable choice for λ is λ m E ϵ2 + mε + 1 r2a = Θ m E ϵ2 + mε + 1 . (E.3) Let ˆR denote the minimizer of the regularized problem and a := arg min a ra/ m ˆR(a). From Lemma 21, with a proper choice of ra = Θ(1), we have ˆR( a) E ϵ2 + O(1/m + ε). with probability at least 1 O(d q) over the biases and the dataset. Note that as a is the minimizer of ˆRλ, we have ˆR(a ) + λ a 2 2 ˆR( a) + λ a 2 and in particular λ a 2 2 ˆR( a) + λ a 2 2 = a O(1/ m). and ˆR(a ) ˆR( a) + λ a 2 2 E[ϵ]2 + O(1/m + ε). Let {at}t 0 be the solution to the gradient flow of a. Then, dt = 2 D at a , ˆRλ(at) E , and by the first-order condition of strong convexity D at a , ˆRλ(at) E λ at a 2, therefore a T a 2 e 2λT a0 a 2. As the training error (of the regularized problem) is λ-strongly convex in a, by applying the standard Polyak-Łojasiewicz condition, gradient flow for training a obtains ˆRλ(a T ) ˆR λ ˆRλ(a0) ˆR λ e 2λT . Furthermore, since a 2 a T 2 2 a a T a a T a 2 2 a T a a , we have ˆR(a T ) ˆR 2 a0 a a e λT + ˆRλ(a0) ˆR λ e 2λT . Consequently, choosing λ ln 4 a0 a a λ ln 2( ˆ Rλ(a0) ˆ R λ) ε implies ˆR(a T ) ˆR + ε and a T 2 a O(1/ m). Therefore ˆR(a T ) E ϵ2 + O(1/m + ε). Recall that ˆR(a T ) = 1 ˆy( z(i); w) y(i) 2 , is the final training error which we also denoted by ˆR(w) earlier in this section when were not focusing on the second layer. From the previous two steps, we know how to bound R(w) via ˆR(u). Thus the last step is to upper bound ˆR(u) via ˆR(w). To that end, via Jensen s inequality ˆy(z(i); u) y(i) 2 3 ˆy( z(i); w) y(i) 2 ˆy(z(i); w) ˆy( z(i); w) 2 ˆy(z(i); w) ˆy(z(i); u) 2 . The first term on the RHS is ˆR(w) for which we developed a bound earlier in this step. 
Bounding the latter two terms can be performed similarly to the arguments in the previous sections. In particular, ˆy(z(i); w) ˆy( z(i); w) 2 r2 a Id ˆΣ 1/2Σ1/2 2 1 i=1 z(i)z(i) 2 O(d/n ), where the last inequality holds with probability at least 1 2e d (over the event of Lemma 24). Similarly, 1 n ˆy(z(i); w) ˆy(z(i); u) 2 ra w u 2 O(ε). Putting the bounds back together (recall n n dε 2), we arrive at ˆR(u) ˆR(w) + O(ε + d/n ) ˆR(w) + O(ε). Combining the result of this step with the two previous steps implies R(w) E ϵ2 + O(1/m + ε), with probability at least 1 O(d q) (when conditioned on w0, u > 0) which completes the proof of Theorem 9. F Auxiliary Lemmas In this section, we recall a number of standard lemmas which we employ in various parts of our proofs. Lemma 23. [Wai19, Theorem 6.1]. Suppose {x(i)}n i=1 i.i.d. N(0, Σ). Let ˆΣ := 1 n Pn i=1 x(i)x(i) . Then, for n tr(Σ)/λmax(Σ), λmax ˆΣ λmax(Σ) tr(Σ) n λmax(Σ) with probability at least 1 e n /2. Furthermore, for n d, λmin ˆΣ λmin(Σ) tr(Σ) nλmin(Σ) with probability at least 1 e n /8. Lemma 24. Suppose z(i) n i=1 i.i.d. N(0, Id), let x(i) := Σ1/2z(i) for some invertible Σ, and define i=1 x(i)x(i) . Id ˆΣ 1/2Σ 1/2 Id ˆΣ 1/2Σ1/2 with probability at least 1 2e d. Proof. We have Id ˆΣ 1/2Σ1/2 = n λmax ˆΣ 1/2Σ1/2 1 o n 1 λmin ˆΣ 1/2Σ1/2 o = λmax Σ1/2 ˆΣ 1Σ1/2 1/2 1 n 1 λmin Σ1/2 ˆΣ 1Σ1/2 o = λmin Σ 1/2 ˆΣΣ 1/2 1/2 1 1 λmax Σ 1/2 ˆΣΣ 1/2 1/2 i=1 z(i)z(i) i=1 z(i)z(i) Id ˆΣ 1/2Σ 1/2 = i=1 z(i)z(i) i=1 z(i)z(i) Moreover, by [Wai19, Example 6.2], we have with probability at least 1 2e d, i=1 z(i)z(i) d n and λmin i=1 z(i)z(i) Thus, for n d (with a sufficiently large absolute constant), we have Id ˆΣ 1/2 ˆΣ 1/2 Id ˆΣ 1/2Σ1/2 with probability at least 1 2e d. Lemma 25 (Chernoff s Inequality). Suppose X1, . . . , Xn are i.i.d. Bernoulli random variables, and further assume that E[P i Xi] µ. Then, for any δ 1, i=1 Xi µ(1 + δ) e µδ/3. (F.1) Proof. The proof follows from a standard Chernoff bound. 
From [Ver18, Theorem 2.3.1] i Xi µ(1 + δ) eµ(δ (1+δ) ln(1+δ)), (notice that the statement of [Ver18, Theorem 2.3.1] holds true even when E[P i Xi] = µ is replaced with E[P i Xi] µ). We conclude by remarking that δ (1 + δ) ln(1 + δ) δ/3 for δ 1.