# ON THE FEATURE LEARNING IN DIFFUSION MODELS

Published as a conference paper at ICLR 2025

Andi Han, Wei Huang, Yuan Cao, Difan Zou

RIKEN AIP (andi.han@riken.jp, wei.huang.vr@riken.jp). Equal contribution.
Department of Statistics and Actuarial Science, University of Hong Kong (yuancao@hku.hk)
Department of Computer Science and Institute of Data Science, University of Hong Kong (dzou@cs.hku.hk)

The predominant success of diffusion models in generative modeling has spurred significant interest in understanding their theoretical foundations. In this work, we propose a feature learning framework aimed at analyzing and comparing the training dynamics of diffusion models with those of traditional classification models. Our theoretical analysis demonstrates that diffusion models, due to the denoising objective, are encouraged to learn more balanced and comprehensive representations of the data. In contrast, neural networks with a similar architecture trained for classification tend to prioritize learning specific patterns in the data, often focusing on easy-to-learn components. To support these theoretical insights, we conduct several experiments on both synthetic and real-world datasets, which empirically validate our findings and highlight the distinct feature learning dynamics in diffusion models compared to classification.

## 1 INTRODUCTION

Diffusion models (Ho et al., 2020; Song et al., 2021) have emerged as a powerful class of generative models for content synthesis and have demonstrated state-of-the-art generative performance in a variety of domains, such as computer vision (Dhariwal & Nichol, 2021; Peebles & Xie, 2023), acoustics (Kong et al., 2021; Chen et al., 2021) and biochemistry (Hoogeboom et al., 2022; Watson et al., 2023).
Recently, many works have employed (pre-trained) diffusion models to extract useful representations for tasks other than generative modelling, and demonstrated surprising capabilities in classical tasks such as image classification with little-to-no tuning (Mukhopadhyay et al., 2023; Xiang et al., 2023; Li et al., 2023a; Clark & Jaini, 2024; Yang & Wang, 2023; Jaini et al., 2024). Compared to discriminative models trained with supervised learning, diffusion models not only achieve comparable recognition performance (Li et al., 2023a), but also demonstrate exceptional out-of-distribution transferability (Li et al., 2023a; Jaini et al., 2024) and improved classification robustness (Chen et al., 2024c). This representation learning power suggests that diffusion models are able to extract meaningful features from training data. Indeed, the core of diffusion models is to estimate the data distribution by progressively denoising noisy inputs over several iterative steps. This inherently views the data distribution as a composition of multiple latent features, and therefore learning the data distribution corresponds to learning the underlying features. Nevertheless, it remains unclear how feature learning emerges during the training of diffusion models and whether the feature learning process differs from that of supervised learning.

Despite the ground-breaking success of diffusion models, their theoretical understanding is still in its infancy. Existing analysis of diffusion models has mostly focused on theoretical guarantees in terms of distribution estimation and sampling convergence. Several works have derived statistical estimation errors between the distribution generated by diffusion models and the ground-truth distribution (Oko et al., 2023; Zhang et al., 2024; Chen et al., 2023a), showing that diffusion models achieve a minimax optimal rate under certain assumptions on the true density (Oko et al., 2023; Zhang et al., 2024). Algorithmically, Li et al.
(2023c); Han et al. (2024) studied the estimation error of diffusion models trained with gradient descent using kernel methods. Shah et al. (2023); Gatmiry et al. (2024); Chen et al. (2024d) introduced algorithms based on diffusion models for learning Gaussian mixture models. In addition, given access to sufficiently accurate score estimation, Lee et al. (2022; 2023); Chen et al. (2023b); Li et al. (2023b) proved convergence guarantees for sampling in (score-based) diffusion models. Despite showing provable guarantees for diffusion models, existing theories are limited to the generative aspects of diffusion models, namely distribution learning and sampling. To the best of our knowledge, no theoretical analysis has been performed to elucidate the feature learning process in diffusion models.

Notations. We make use of the following notations throughout the paper. We use $\|\cdot\|$ to denote the L2 norm for vectors and the Frobenius norm for matrices, unless mentioned otherwise. We use $O(\cdot)$, $\Omega(\cdot)$, $\Theta(\cdot)$, $o(\cdot)$, $\omega(\cdot)$ for the big-O, big-Omega, big-Theta, small-o, small-omega notations. We write $\tilde{O}(\cdot)$ to hide (poly)logarithmic factors, and similar notations hold for $\tilde{\Omega}(\cdot)$ and $\tilde{\Theta}(\cdot)$. For a binary condition $C$, we let $\mathbb{1}(C) = 1$ if $C$ is true and $\mathbb{1}(C) = 0$ otherwise.

### 1.1 OUR MAIN RESULTS

Figure 1: Illustration of the ratio of signal learning to noise learning when varying $n \cdot \mathrm{SNR}^2$, where $\mathrm{SNR} := \|\mu\|/(\sigma_\xi \sqrt{d})$. We show that the diffusion model tends to learn signal and noise in a more balanced way, while classification has a sharp phase transition and tends to focus on learning either signal or noise.

In this work, we develop a theoretical framework that studies the feature learning dynamics of diffusion models and compares them with classification. Inspired by the image data structure, we employ a multi-patch data distribution $x = [\mu_y^\top, \xi^\top]^\top$ for both classification and diffusion model training.
We consider a binary-class data setup with $y = \pm 1$ as the data label, where $\mu_1, \mu_{-1} \in \mathbb{R}^d$ are two fixed orthogonal vectors, i.e., $\mu_1 \perp \mu_{-1}$, representing the signal. On the other hand, $\xi$ is the label-independent noise, which is randomly sampled from a Gaussian distribution with standard deviation $\sigma_\xi$. In order to elucidate the difference in feature learning dynamics between the two tasks, we adopt a two-layer convolutional neural network with quadratic activation. For the diffusion model, we consider a weight-sharing setting for the first and second layer, which is commonly considered for analyzing autoencoders (Nguyen, 2021; Cui & Zdeborová, 2024). For classification, we fix the second-layer weights to be $\pm 1$, following Cao et al. (2022); Kou et al. (2023). In other words, the classifier can be viewed as attaching a fixed linear head to the intermediate layer of the diffusion model. Given a training dataset of $n$ samples from the multi-patch data distribution, we use gradient descent to minimize the empirical logistic loss for classification and the DDPM loss (Ho et al., 2020) with expectation over the diffusion noise.

Under the above settings, we investigate the differences in feature learning dynamics (Allen-Zhu & Li, 2023; Cao et al., 2022; Zou et al., 2023; Huang et al., 2023b; Jiang et al., 2024; Huang et al., 2024a; Lu et al., 2024; Meng et al., 2024) between the diffusion model and classification. We quantify feature learning in terms of signal learning and noise learning, measured through the alignment between the network weights $w$ and the directions of signal and noise respectively, i.e., $|\langle w, \mu_y \rangle|$ and $|\langle w, \xi \rangle|$. We present the following (informal) results for the two learning paradigms.

Theorem 1.1 (Informal). Let $\mathrm{SNR} := \|\mu\|/(\sigma_\xi \sqrt{d})$ be the signal-to-noise ratio. We can show:

- For the diffusion model, $|\langle w, \mu_y \rangle|$ and $|\langle w, \xi \rangle|$ exhibit linear growth initially, and there exists a stationary point along the path of the training dynamics that satisfies $|\langle w, \mu_y \rangle| / |\langle w, \xi \rangle| = \Theta(n \cdot \mathrm{SNR}^2)$.
- For classification, $|\langle w, \mu_y \rangle|$ and $|\langle w, \xi \rangle|$ exhibit exponential growth initially. When $n \cdot \mathrm{SNR}^2 \ge \beta$ for some constant $\beta > 1$, $|\langle w, \mu_y \rangle| / |\langle w, \xi \rangle| = \omega(1)$, and when $n^{-1} \cdot \mathrm{SNR}^{-2} \ge \beta$, $|\langle w, \mu_y \rangle| / |\langle w, \xi \rangle| = o(1)$.

Theorem 1.1 highlights differences in the feature learning process between diffusion models and classification. Especially in the regime where $n \cdot \mathrm{SNR}^2 = \Theta(1)$, classification is sensitive to changes in SNR and tends to learn either the signal $\mu_y$ or the noise $\xi$. In contrast, the diffusion model learns both signal and noise to the same order. This claim is visualized in Figure 1. We believe our framework represents the first attempt to systematically investigate feature learning within diffusion models, potentially uncovering novel insights into the intriguing properties of diffusion models, including but not limited to critical windows (Sclocchi et al., 2024; Li & Chen, 2024), shape bias (Jaini et al., 2024), classification robustness (Chen et al., 2024c), and feature composition and dependence (Okawa et al., 2024; Yang et al., 2025; Han et al., 2025).

### 1.2 RELATED WORK

Theoretical analysis of diffusion model. Existing theoretical guarantees for diffusion models focus on distribution estimation and sampling. For distribution estimation, Oko et al. (2023) proved that diffusion models achieve a nearly minimax optimal estimation error where the true density is defined over a bounded Besov space. Zhang et al. (2024) extended the minimax optimality to more general sub-Gaussian densities with sufficient smoothness. When the density is supported on a low-dimensional subspace, diffusion models avoid the curse of dimensionality with an estimation rate depending only on the intrinsic dimension (Oko et al., 2023; Chen et al., 2023a). Furthermore, Shah et al. (2023); Gatmiry et al. (2024); Chen et al. (2024d) introduced algorithms based on diffusion models for learning a mixture of Gaussians.
Other works provided guarantees for diffusion models trained by gradient descent (Li et al., 2023c; Han et al., 2024; Wang et al., 2024). For sampling, Lee et al. (2022; 2023); Chen et al. (2023b); Li et al. (2023b) have shown that (score-based) diffusion models converge polynomially under sufficiently accurate score estimation. Recent studies also aimed to accelerate the convergence via strategies such as consistency training (Song et al., 2023; Li et al., 2024b), advanced design of the reverse transition kernel (Huang et al., 2024b), higher-order approximation (Li et al., 2024a) and parallelization (Chen et al., 2024a; Gupta et al., 2024). In addition, Li & Chen (2024) theoretically verified critical windows of feature emergence during the sampling process, provided accurate score estimation.

Theoretical analysis on (denoising) autoencoders. Diffusion models can be viewed as multi-level denoising autoencoders (Xiang et al., 2023). While there is extensive research on the theoretical guarantees of autoencoders without denoising, most studies focus on linear autoencoders (Kunin et al., 2019; Oftadeh et al., 2020; Steck, 2020; Bao et al., 2020). In contrast, only a limited number of works analyze non-linear autoencoders, primarily in the lazy training regime (Nguyen et al., 2021) or the mean-field regime (Nguyen, 2021). Additionally, the training dynamics of non-linear autoencoders have been investigated under population gradient descent (Shevchenko et al., 2023; Kögler et al., 2024) and online gradient descent (Refinetti & Goldt, 2022). On the other hand, the training dynamics of denoising autoencoders have been studied in the context of linear networks (Pretorius et al., 2018) and in the high-dimensional asymptotic limit (Cui & Zdeborová, 2024).

Diffusion model for representation learning.
Pre-trained diffusion models have been shown to learn powerful representations, which are useful for downstream tasks such as classification (Mukhopadhyay et al., 2023; Xiang et al., 2023; Li et al., 2023a; Clark & Jaini, 2024; Yang & Wang, 2023) and semantic segmentation (Baranchuk et al., 2022; Zhao et al., 2023; Yang & Wang, 2023). Moreover, many works have found intriguing properties of diffusion models used as classifiers, including their ability to understand shape bias (Jaini et al., 2024) and improved adversarial robustness (Chen et al., 2024c). For a more detailed exposition, we refer to the recent survey on this matter (Fuest et al., 2024).

## 2 PROBLEM SETTING

This section introduces the problem settings for both the diffusion model and classification, including the data model, the neural network functions, and the training objectives and algorithm.

Definition 2.1 (Data distribution). Each data sample consists of two patches, $x = [x^{(1)\top}, x^{(2)\top}]^\top$, where each patch is generated as follows:

- Sample $y \in \{-1, 1\}$ uniformly with $P(y = 1) = P(y = -1) = 1/2$.
- Given two orthogonal signal vectors $\mu_1, \mu_{-1}$ with $\mu_1 \perp \mu_{-1}$, we set $x^{(1)} = \mu_y$, i.e., $x^{(1)} = \mu_1$ if $y = 1$ and $x^{(1)} = \mu_{-1}$ if $y = -1$. For simplicity, we assume $\|\mu_1\| = \|\mu_{-1}\| = \|\mu\|$.
- Set $x^{(2)} = \xi$ where $\xi \sim N\big(0, \sigma_\xi^2 (I - \mu_1 \mu_1^\top \|\mu_1\|^{-2} - \mu_{-1} \mu_{-1}^\top \|\mu_{-1}\|^{-2})\big)$.

This multi-patch data model reflects the structure of image data, where each image consists of multiple patches, and only a subset of the patches are relevant to the class label, while the rest contribute as background noise. This data model has been employed in several existing studies (Allen-Zhu & Li, 2023; Cao et al., 2022; Kou et al., 2023; Meng et al., 2024; Zou et al., 2023). A difference in our model is the use of two orthogonal signal vectors, in contrast to previous works that employ a single signal vector of the form $y\mu$. Additionally, while our analysis focuses on a two-patch setting for simplicity, it can be readily extended to multi-patch data.
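
As an illustration of Definition 2.1, the following is a minimal NumPy sketch of a sampler for the two-patch data. The function name `sample_two_patch_data` and its arguments are hypothetical, and the coordinate-aligned choice of signal vectors is an assumption made for simplicity (any orthogonal pair with equal norm would fit the definition).

```python
import numpy as np

def sample_two_patch_data(n, d, mu_norm, sigma_xi, seed=0):
    """Draw n samples x = [x1, x2] in the spirit of Definition 2.1 (sketch)."""
    rng = np.random.default_rng(seed)

    # Assumption: signals aligned with the first two coordinate axes.
    mu_pos = np.zeros(d)
    mu_pos[0] = mu_norm
    mu_neg = np.zeros(d)
    mu_neg[1] = mu_norm

    # Labels are uniform on {-1, +1}; the first patch is the matching signal.
    y = rng.choice([-1, 1], size=n)
    signal = np.where(y[:, None] == 1, mu_pos, mu_neg)

    # Projector onto the orthogonal complement of span{mu_pos, mu_neg}.
    u_pos = mu_pos / np.linalg.norm(mu_pos)
    u_neg = mu_neg / np.linalg.norm(mu_neg)
    proj = np.eye(d) - np.outer(u_pos, u_pos) - np.outer(u_neg, u_neg)

    # xi = sigma_xi * g @ proj with g standard normal has covariance
    # sigma_xi^2 * proj, because proj is a symmetric projection.
    xi = sigma_xi * rng.standard_normal((n, d)) @ proj

    x = np.concatenate([signal, xi], axis=1)  # shape (n, 2d)
    return x, y, mu_pos, mu_neg, xi
```

By construction, every noise patch returned by this sketch is orthogonal to both signal vectors, which mirrors the projected covariance in the definition.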
We let $\mathrm{SNR} := \|\mu\|/(\sigma_\xi \sqrt{d})$ denote the signal-to-noise ratio.

Neural network functions. We study two-layer convolutional-type neural networks for both the diffusion model and classification. For the diffusion model, we consider a neural network with quadratic activation and shared first-layer and second-layer weights:

$$f(W, x) = \big[ f_1(W, x^{(1)})^\top, f_2(W, x^{(2)})^\top \big]^\top \in \mathbb{R}^{2d}, \quad \text{where } f_p(W, x^{(p)}) = \frac{1}{\sqrt{m}} \sum_{r=1}^{m} \langle w_r, x^{(p)} \rangle^2 w_r, \quad p = 1, 2,$$

where $m$ denotes the network width and $r$ represents the neuron index. For classification, we consider a similar neural network with quadratic activation where the second-layer weights are fixed to be $\pm 1$ (instead of $w_r$):

$$f(W, x) = F_1(W_1, x) - F_{-1}(W_{-1}, x), \quad \text{where } F_j(W, x) = \frac{1}{2m} \sum_{r=1}^{m} \langle w_{j,r}, x^{(1)} \rangle^2 + \frac{1}{2m} \sum_{r=1}^{m} \langle w_{j,r}, x^{(2)} \rangle^2.$$

We remark that the use of polynomial activations, such as quadratic, cubic and ReLU with polynomial smoothing, is not uncommon in existing theoretical works (Cao et al., 2022; Jelassi & Li, 2022; Zou et al., 2023; Huang et al., 2023a; Meng et al., 2023). The aim is to better elucidate the separation between signal and noise learning dynamics.

Training objectives and algorithm. For the diffusion model, we employ the objective of the denoising diffusion probabilistic model (DDPM) (Ho et al., 2020). We let $x_0 = [x^{(1)\top}, x^{(2)\top}]^\top \in \mathbb{R}^{2d}$ denote the input image. For a given diffusion time step $t \in [0, T]$, we sample $x_t = \alpha_t x_0 + \beta_t \epsilon_t$ for $\epsilon_t \sim N(0, I)$ and pre-determined noise schedule coefficients $\{\alpha_t, \beta_t\}_{t=0}^{T}$. The aim of diffusion models is to estimate the mean of the posterior distribution of the noise $\epsilon_t$ conditioned on $x_t$. This is achieved by training a neural network $f$ to predict the noise added at each step $t$. The DDPM loss is given by $E_{x_0, \epsilon_t, t} \| f(x_t) - \epsilon_t \|^2$ up to some re-scaling (Ho et al., 2020). We consider a finite-sample setup given by the training images $\{x_i\}_{i=1}^{n}$ sampled according to Definition 2.1, and thus the empirical DDPM loss at time step $t$ becomes

$$L_F(W_t) = \frac{1}{2n} \sum_{i=1}^{n} E_{\epsilon_{t,i}} \| f(W_t, x_{t,i}) - \epsilon_{t,i} \|^2 = \frac{1}{2n} \sum_{i=1}^{n} E_{\epsilon_{t,i}} \| f(W_t, \alpha_t x_{0,i} + \beta_t \epsilon_{t,i}) - \epsilon_{t,i} \|^2,$$

where we let $x_{0,i} = x_i$ and $x_{t,i} = \alpha_t x_{0,i} + \beta_t \epsilon_{t,i}$.
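
The weight-shared network and the empirical DDPM loss described above can be sketched in NumPy as follows. This is only an illustration: the function names are hypothetical, the expectation over the diffusion noise is replaced by a Monte Carlo average over `n_eps` draws, and the `1/sqrt(m)` output scaling and the `1/(2n)` loss normalisation are assumptions of this sketch.

```python
import numpy as np

def diffusion_net(W, X):
    """Weight-shared two-layer network with quadratic activation.

    W: (m, d) weights, X: (b, 2d) batch of two-patch inputs.
    Returns a (b, 2d) array holding [f_1, f_2] for each input.
    """
    m, d = W.shape
    patches = X.reshape(-1, 2, d)         # (b, 2, d): the two patches
    acts = (patches @ W.T) ** 2           # (b, 2, m): <w_r, x^(p)>^2
    out = acts @ W / np.sqrt(m)           # (b, 2, d): sum_r acts_r * w_r
    return out.reshape(-1, 2 * d)

def ddpm_loss_mc(W, X0, alpha_t, beta_t, n_eps=256, seed=0):
    """Monte Carlo estimate of the empirical DDPM loss at one time step."""
    rng = np.random.default_rng(seed)
    n, two_d = X0.shape
    total = 0.0
    for x0 in X0:
        eps = rng.standard_normal((n_eps, two_d))   # diffusion noise draws
        xt = alpha_t * x0 + beta_t * eps            # noised inputs x_t
        resid = diffusion_net(W, xt) - eps          # noise-prediction error
        total += np.mean(np.sum(resid ** 2, axis=1))
    return total / (2 * n)                          # assumed 1/(2n) scaling
```

With all weights set to zero the network outputs zero, so under the assumed normalisation the estimate should be close to $d$, the expected squared norm of the noise divided by two.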
Here, we decouple the training of the neural network at each diffusion time step with separate weight parameters, a strategy also adopted in (Shah et al., 2023) for simplicity of analysis. Unlike (Han et al., 2024), where each sample $i$ is associated with a single noise $\epsilon_{t,i} \sim N(0, I)$, we here consider taking the expectation over the noise distribution, which aligns with the practical setting where multiple noises are sampled for each input data point. We use gradient descent to train the diffusion model starting from random Gaussian initialization $w^0_{r,t} \sim N(0, \sigma_0^2 I)$ as

$$w^{k+1}_{r,t} = w^{k}_{r,t} - \eta \nabla_{w_{r,t}} L_F(W^k_t),$$

where the superscript $k$ is the iteration index.

For classification, we minimize the empirical logistic loss over the training data $\{x_i, y_i\}_{i=1}^{n}$,

$$L_S(W) = \frac{1}{n} \sum_{i=1}^{n} \ell\big( y_i f(W, x_i) \big), \quad \ell(z) = \log\big(1 + \exp(-z)\big).$$

As for the diffusion model, we use gradient descent to train the neural network starting from random Gaussian initialization $w^0_{j,r} \sim N(0, \sigma_0^2 I)$.

## 3 MAIN RESULTS

Our main results are based on the following conditions.

Condition 3.1. Suppose the following holds.

1. The dimension $d$ is sufficiently large with $d = \tilde{\Omega}(n^7 m^5)$.
2. The sample size $n$ satisfies $n = \tilde{\Omega}(1)$.
3. The standard deviation of initialization $\sigma_0$ is chosen such that $\tilde{O}(n^2 m \sigma_\xi^{-1} d^{-1}) \le \sigma_0 \le \tilde{O}\big(\min\{ m^{-1/6} d^{-1/6} \sigma_\xi^{1/3} n^{-1/3},\ m^{-1/6} d^{-7/12} \sigma_\xi^{-1/3} n^{1/3},\ d^{-3/4} \sigma_\xi^{-1} n \}\big)$.
4. The learning rate $\eta$ satisfies $\eta \le \tilde{O}\big(\min\{ n m \sigma_0 \sigma_\xi^{-1} d^{-1/2},\ n m \sigma_\xi^{-2} d^{-1} \}\big)$.
5. The signal strength satisfies $\|\mu\| = \Theta(1)$ and the noise variation $\sigma_\xi$ satisfies $\tilde{O}\big(\max\{ n^{5/2} m^{7/4} d^{-5/8},\ n m^{1/6} d^{-1} \}\big) \le \sigma_\xi \le \tilde{O}(d^{-1/4})$.
6. The noise coefficients for the diffusion model satisfy $\alpha_t, \beta_t = \Theta(1)$.

Condition 3.1 requires $d$ to be sufficiently large to ensure learning in an over-parameterized setting. Furthermore, we require the sample size to be lower bounded by a constant up to logarithmic factors. The upper bound on the initialization $\sigma_0$ ensures that random initialization does not significantly affect the signal and noise learning dynamics.
The lower bound on $\sigma_0$ is required to bound the noise inner product at initialization for properly minimizing the training loss of classification. The learning rate $\eta$ is chosen sufficiently small for the convergence analysis of classification. Lastly, for the diffusion model, we take $\|\mu\|$ to be of constant order and further restrict the range of $\sigma_\xi$. Despite these conditions, our setting covers a broad range of $n \cdot \mathrm{SNR}^2$, i.e., $\tilde{O}(n d^{-1/2}) \le n \cdot \mathrm{SNR}^2 \le \tilde{O}\big(\min\{ n^{-4} m^{-7/2} d^{1/4},\ n^{-1} m^{-1/3} d \}\big)$. We also consider constant order of $\alpha_t, \beta_t$ to avoid degeneracy in the learning dynamics.

We present the main results for the diffusion model (Theorem 3.1) and classification (Theorem 3.2).

Theorem 3.1 (Diffusion model). Under Condition 3.1, suppose $m = \Theta(1)$. With probability at least $1 - \delta$ (for any $\delta > 0$), there exists a stationary point $W^*_t$ along the training trajectory of the diffusion model, i.e., $\nabla_{w_{r,t}} L_F(W^*_t) = 0$, that satisfies (1) $\langle w^*_{r,t}, \mu_j \rangle = \Theta(\langle w^*_{r',t}, \mu_{j'} \rangle)$, (2) $\langle w^*_{r,t}, \xi_i \rangle = \Theta(\langle w^*_{r',t}, \xi_{i'} \rangle)$, and (3) for all $j = \pm 1$, $r \in [m]$, $i \in [n]$,

$$|\langle w^*_{r,t}, \mu_j \rangle| / |\langle w^*_{r,t}, \xi_i \rangle| = \Theta(n \cdot \mathrm{SNR}^2),$$

with $\langle w^*_{r,t}, \mu_j \rangle = \Theta(1)$ if $n \cdot \mathrm{SNR}^2 = \Omega(1)$, and $\langle w^*_{r,t}, \xi_i \rangle = \Theta(1)$ if $n^{-1} \cdot \mathrm{SNR}^{-2} = \Omega(1)$.

Theorem 3.1 states that diffusion model training encourages balanced signal and noise learning, i.e., the neurons share the same order in the directions of signals and noise. Notably, the ratio between signal and noise learning is governed by the SNR, with a stationary ratio of order $n \cdot \mathrm{SNR}^2$.

Theorem 3.2 (Classification). Let $T_\mu = \tilde{\Theta}(\eta^{-1} m \|\mu\|^{-2})$, $T_\xi = \tilde{\Theta}(\eta^{-1} n m \sigma_\xi^{-2} d^{-1})$ and $\delta > 0$. Under Condition 3.1, suppose $m = \Omega(\log(n/\delta))$. There exist two absolute constants $\overline{C} > \underline{C} > 0$ such that with probability at least $1 - \delta$, the following holds:

- When $n \cdot \mathrm{SNR}^2 \ge \overline{C}$, there exists $0 \le k \le T_\mu$ such that $L_S(W^k) \le 0.1$ and $\max_r |\langle w^k_{j,r}, \mu_j \rangle| \ge 2$ for $j = \pm 1$, while $\max_{j,r,i} |\langle w^k_{j,r}, \xi_i \rangle| = o(1)$.
- When $n \cdot \mathrm{SNR}^2 \le \underline{C}$, there exists $0 \le k \le T_\xi$ such that $L_S(W^k) \le 0.1$ and $\max_r |\langle w^k_{y_i,r}, \xi_i \rangle| \ge 1$ for $i \in [n]$, while $\max_{j,r,y} |\langle w^k_{j,r}, \mu_y \rangle| = o(1)$.
Theorem 3.2 establishes a sharp phase transition between signal and noise learning under classification training. The transition is precisely determined by $n \cdot \mathrm{SNR}^2$. That is, when $n \cdot \mathrm{SNR}^2 \ge \overline{C}$ for some constant $\overline{C} > 0$, the neural network learns the signal to achieve a small training loss. On the contrary, when $n \cdot \mathrm{SNR}^2 \le \underline{C}$ for some constant $\underline{C} \in (0, \overline{C})$, the neural network overfits the noise to achieve convergence. With standard techniques, such as in (Cao et al., 2022), we can show that signal and noise learning correspond to the regimes of benign and harmful overfitting respectively. To the best of our knowledge, this is the first result that shows such a separation when $n \cdot \mathrm{SNR}^2$ is of constant order.

Diffusion model learns balanced features while classification learns dominant features. Comparing the learning outcomes of the diffusion model and classification, we reveal a critical difference: diffusion models learn more balanced features depending on the SNR conditions, while classification is prone to learning either signal or noise predominantly. This can be best understood in the case of $n \cdot \mathrm{SNR}^2 = \Theta(1)$. By Theorem 3.2, either signal learning or noise learning dominates the learning process in classification, while Theorem 3.1 suggests that signal and noise learning are of the same order in diffusion models. The theoretical findings corroborate the empirical observation that a neural network trained for classification is prone to overly rely on a specific pattern that is easier to learn, a process known as shortcut learning (Geirhos et al., 2020). Meanwhile, diffusion models tend to learn low-frequency, global patterns (Jaini et al., 2024), which helps to improve classification robustness (Chen et al., 2024b;c).

## 4 PROOF OVERVIEW

This section outlines the proof roadmap for the main results. For the diffusion model, the mean-squared loss, the joint training of two layers, and learning in the direction of initialization pose significant challenges for the analysis.
We adopt a two-stage analysis and characterize the stationary points based on the derived results at the end of the second stage. For classification, the two-stage analysis is similar to that in (Cao et al., 2022; Kou et al., 2023): the first stage learns the signal or noise vector sufficiently fast, and the second stage shows convergence of the training loss while the scale difference learned in the first stage is maintained. However, for the classification analysis, we highlight two critical differences compared to existing works (Cao et al., 2022; Kou et al., 2023; Meng et al., 2024), i.e., a constant $n \cdot \mathrm{SNR}^2$ condition and quadratic activation.

### 4.1 DIFFUSION MODEL

We first simplify the DDPM loss by taking the expectation with respect to the added diffusion noise:

$$L_F(W_t) = d + \frac{1}{2n} \sum_{i=1}^{n} \sum_{p=1}^{2} \Big( \underbrace{\frac{1}{m} E_{\epsilon_{t,i}} \Big\| \sum_{r=1}^{m} \langle w_{r,t}, x^{(p)}_{t,i} \rangle^2 w_{r,t} \Big\|^2}_{I_1} - \underbrace{\frac{4 \alpha_t \beta_t}{\sqrt{m}} \sum_{r=1}^{m} \| w_{r,t} \|^2 \langle w_{r,t}, x^{(p)}_{0,i} \rangle}_{I_2} \Big),$$

where for $p = 1, 2$, $x^{(p)}_{t,i} = \alpha_t x^{(p)}_{0,i} + \beta_t \epsilon^{(p)}_{t,i}$, with $x^{(1)}_{0,i} = \mu_{y_i}$, $x^{(2)}_{0,i} = \xi_i$ and $\epsilon^{(1)}_{t,i}, \epsilon^{(2)}_{t,i} \sim N(0, I)$. We further simplify $I_1$ in Lemma D.2 (in Appendix). We remark that $I_1$ corresponds to a regularization term that regulates the magnitude and alignment of neurons, while $I_2$ corresponds to the main learning term. We highlight that apart from the signal and noise directions, the learning term $I_2$ also includes the initialization direction $w^0_{r,t}$, which further complicates the analysis.

First stage. In the first stage, where all the key quantities, including signal and noise inner products, weight norms and cross-neuron inner products, remain close to their respective initialization, we can show that the growth of the signal and noise inner products is approximately linear:

$$\langle w^{k+1}_{r,t}, \mu_j \rangle = \langle w^{k}_{r,t}, \mu_j \rangle + \Theta\big( \eta \| w^{k}_{r,t} \|^2 \| \mu \|^2 \big), \qquad \langle w^{k+1}_{r,t}, \xi_i \rangle = \langle w^{k}_{r,t}, \xi_i \rangle + \Theta\big( \eta n^{-1} \| w^{k}_{r,t} \|^2 \| \xi_i \|^2 \big). \tag{1}$$

In addition, the change of $w^k_{r,t}$ along the direction $w^0_{r',t}$ can be properly controlled such that the scale of the key quantities remains unaffected and the simplification in (1) is valid throughout the first stage.
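
To make the consequence of the linear updates in (1) explicit, here is a short worked calculation. It is a sketch under simplifying assumptions: the weight norm is held at its initialization scale $\|w^k_{r,t}\|^2 \approx \sigma_0^2 d$, the constants hidden in $\Theta(\cdot)$ are written as placeholder constants $c_\mu, c_\xi > 0$ (introduced here only for illustration), and $\|\xi_i\|^2 = \Theta(\sigma_\xi^2 d)$ is used, which holds with high probability for Gaussian noise.

```latex
\begin{align*}
\langle w^{k}_{r,t}, \mu_j \rangle
  &\approx \langle w^{0}_{r,t}, \mu_j \rangle
   + c_\mu \, k \, \eta \, \sigma_0^2 d \, \|\mu\|^2, \\
\langle w^{k}_{r,t}, \xi_i \rangle
  &\approx \langle w^{0}_{r,t}, \xi_i \rangle
   + c_\xi \, k \, \eta \, n^{-1} \sigma_0^2 d \, \|\xi_i\|^2, \\
\frac{|\langle w^{k}_{r,t}, \mu_j \rangle|}{|\langle w^{k}_{r,t}, \xi_i \rangle|}
  &\approx \frac{c_\mu \, \|\mu\|^2}{c_\xi \, n^{-1} \|\xi_i\|^2}
   = \Theta\!\left( \frac{n \, \|\mu\|^2}{\sigma_\xi^2 \, d} \right)
   = \Theta\!\left( n \cdot \mathrm{SNR}^2 \right)
  \quad \text{once the growth terms dominate the initial values.}
\end{align*}
```

The last equality uses $\mathrm{SNR} = \|\mu\|/(\sigma_\xi \sqrt{d})$; the common factors $k$, $\eta$ and $\sigma_0^2 d$ cancel in the ratio.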
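The updates in (1) immediately suggest that once the growth terms of the inner products dominate their initialization, we obtain $|\langle w^k_{r,t}, \mu_j \rangle| / |\langle w^k_{r,t}, \xi_i \rangle| = \Theta(n \cdot \mathrm{SNR}^2)$. This marks the end of the first stage, as described in the following lemma.

Lemma 4.1. Under Condition 3.1, there exists an iteration $T_1 = \max\{T_\mu, T_\xi\}$, where $T_\mu = \tilde{\Theta}(\sqrt{m}\, \sigma_0^{-1} d^{-1} \|\mu\|^{-1} \eta^{-1})$ and $T_\xi = \tilde{\Theta}(n \sqrt{m}\, \sigma_0^{-1} \sigma_\xi^{-1} d^{-3/2} \eta^{-1})$, such that for all $k \le T_1$, $\|w^k_{r,t}\|^2 = \Theta(\sigma_0^2 d)$ and $\langle w^k_{r,t}, w^0_{r,t} \rangle = \Theta(\sigma_0^2 d)$ for all $r \in [m]$, $j = \pm 1$, $i \in [n]$. Furthermore, we can show that for all $j, j' = \pm 1$, $r, r' \in [m]$, $i, i' \in [n]$,

$$\langle w^{T_1}_{r,t}, \mu_j \rangle = \Theta\big( \langle w^{T_1}_{r',t}, \mu_{j'} \rangle \big), \quad \langle w^{T_1}_{r,t}, \xi_i \rangle = \Theta\big( \langle w^{T_1}_{r',t}, \xi_{i'} \rangle \big), \quad \text{and} \quad |\langle w^{T_1}_{r,t}, \mu_j \rangle| / |\langle w^{T_1}_{r,t}, \xi_i \rangle| = \Theta(n \cdot \mathrm{SNR}^2).$$

Lemma 4.1 verifies that at the end of the first stage, all the neurons are concentrated and the ratio is precisely determined by $n \cdot \mathrm{SNR}^2$. This is critically different from classification, where signal and noise learning exhibit exponential growth, as we show later, and thus display a clear scale difference at the end of the first stage, even when $n \cdot \mathrm{SNR}^2 = \Theta(1)$.

Second stage. The second stage aims to characterize when the dominant terms in the gradients along the key directions are no longer dominant. To this end, we decompose the gradient into

$$\langle \nabla_{w_{r,t}} L(W^k_t), \mu_j \rangle = -\Theta\big( \|w^k_{r,t}\|^2 \|\mu\|^2 \big) + E^k_{r,t,\mu_j},$$
$$\langle \nabla_{w_{r,t}} L(W^k_t), \xi_i \rangle = -\Theta\big( n^{-1} \|w^k_{r,t}\|^2 \|\xi_i\|^2 \big) + E^k_{r,t,\xi_i},$$
$$\langle \nabla_{w_{r,t}} L(W^k_t), w^0_{r,t} \rangle = -\Theta\Big( \big( \langle w^k_{r,t}, \mu_j + \bar{\xi} \rangle - \|w^k_{r,t}\|^4 \big) \langle w^k_{r,t}, w^0_{r,t} \rangle + \|w^k_{r,t}\|^2 \langle w^0_{r,t}, \mu_j + \bar{\xi} \rangle \Big) + E^k_{r,t,w^0},$$

where $E^k_{r,t,\mu_j}$, $E^k_{r,t,\xi_i}$, $E^k_{r,t,w^0}$ are the residual terms of the gradients and we let $\bar{\xi} = \frac{1}{n} \sum_{i=1}^{n} \xi_i$. The following lemma shows that before $E^k_{r,t,\mu_j}$, $E^k_{r,t,\xi_i}$, $E^k_{r,t,w^0}$ reach the same order as the dominant terms, the ratio of the signal and noise inner products is preserved.

Lemma 4.2.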
There exists an iteration $T_2 > T_1$ with $T_2 = \Theta\big(\max\{ \eta^{-1} \sigma_0^{-2} d^{-1},\ \eta^{-1} n \sigma_0^{-2} \sigma_\xi^{2} \}\big)$ such that for all $j = \pm 1$, $r \in [m]$, $i \in [n]$: (1) if $n \cdot \mathrm{SNR}^2 = \Omega(1)$, $\langle w^{T_2}_{r,t}, \mu_j \rangle = \Theta(1)$, and if $n^{-1} \cdot \mathrm{SNR}^{-2} = \Omega(1)$, $\langle w^{T_2}_{r,t}, \xi_i \rangle = \Theta(1)$; (2) $E^{T_2}_{r,t,\mu_j} = \Theta(\|w^{T_2}_{r,t}\|^2 \|\mu\|^2)$, $E^{T_2}_{r,t,\xi_i} = \Theta(n^{-1} \|w^{T_2}_{r,t}\|^2 \|\xi_i\|^2)$, $E^{T_2}_{r,t,w^0} = \Theta\big( ( \langle w^{T_2}_{r,t}, \mu_j + \bar{\xi} \rangle - \|w^{T_2}_{r,t}\|^4 ) \langle w^{T_2}_{r,t}, w^0_{r,t} \rangle + \|w^{T_2}_{r,t}\|^2 \langle w^0_{r,t}, \mu_j + \bar{\xi} \rangle \big)$; and (3) for all $T_1 \le k \le T_2$, we have

$$\langle w^k_{r,t}, \mu_j \rangle = \Theta\big( \langle w^k_{r',t}, \mu_{j'} \rangle \big), \quad \langle w^k_{r,t}, \xi_i \rangle = \Theta\big( \langle w^k_{r',t}, \xi_{i'} \rangle \big), \quad |\langle w^k_{r,t}, \mu_j \rangle| / |\langle w^k_{r,t}, \xi_i \rangle| = \Theta(n \cdot \mathrm{SNR}^2).$$

Lemma 4.2 characterizes $T_2$ as the point where the dominant terms of the gradients in the first stage become comparable to the residual terms. Meanwhile, we show that the scale of the signal and noise inner products escapes from initialization and reaches constant order. Throughout the second stage, the concentration of neurons is preserved and the ratio of signal to noise learning is dictated by $n \cdot \mathrm{SNR}^2$. Finally, we identify a stationary point that satisfies the conditions at the end of the second stage (Lemma 4.2).

Theorem 4.1 (Informal). There exists a stationary point $W^*_t$, i.e., $\nabla_{w_{r,t}} L(W^*_t) = 0$, such that the conditions at $T_2$ (in Lemma 4.2) are satisfied, and in particular $|\langle w^*_{r,t}, \mu_j \rangle| / |\langle w^*_{r,t}, \xi_i \rangle| = \Theta(n \cdot \mathrm{SNR}^2)$ for all $j = \pm 1$, $r \in [m]$, $i \in [n]$.

### 4.2 CLASSIFICATION

For classification, we first let $S_y := \{ i \in [n] : y_i = y \}$ for $y = \pm 1$ and $\ell'^k_i = \ell'\big( y_i f(W^k, x_i) \big)$. We can rewrite the gradient descent updates in terms of the signal and noise inner products:

$$\langle w^{k+1}_{j,r}, \mu_y \rangle = \langle w^{k}_{j,r}, \mu_y \rangle - \frac{\eta |S_y|}{nm} \ell'^k_i \langle w^{k}_{j,r}, \mu_y \rangle\, jy\, \|\mu\|^2 = \Big( 1 - \frac{\eta |S_y| \|\mu\|^2}{nm} \ell'^k_i\, jy \Big) \langle w^{k}_{j,r}, \mu_y \rangle, \tag{2}$$

$$\langle w^{k+1}_{j,r}, \xi_i \rangle = \langle w^{k}_{j,r}, \xi_i \rangle - \frac{\eta}{nm} \ell'^k_i \langle w^{k}_{j,r}, \xi_i \rangle \|\xi_i\|^2\, j y_i - \frac{\eta}{nm} \sum_{i' \ne i} \ell'^k_{i'} \langle w^{k}_{j,r}, \xi_{i'} \rangle\, j y_{i'} \langle \xi_{i'}, \xi_i \rangle, \tag{3}$$

for all $j, y = \pm 1$, $r \in [m]$, $i \in [n]$.

Figure 2: Experiments on the synthetic dataset with both low SNR ($n \cdot \mathrm{SNR}^2 = 0.75$, left panel) and high SNR ($n \cdot \mathrm{SNR}^2 = 6.75$, right panel). In the low SNR setting, we see that noise learning quickly dominates signal learning for the classification task, and in the high SNR setting, signal learning quickly dominates noise learning. Meanwhile, the diffusion model converges to a stationary point whose signal-to-noise learning ratio respects the order of $n \cdot \mathrm{SNR}^2$. More experimental results on additional SNR values are in Appendix A.2.

For the signal inner product, the update suggests that for any $j = \pm 1$, $w_{j,r}$ specializes in learning $\mu_j$: because $\ell'^k_i$ is negative, $|\langle w^{k+1}_{j,r}, \mu_y \rangle| = \big( 1 - \frac{\eta |S_y| \|\mu\|^2}{nm} \ell'^k_i\, jy \big) |\langle w^{k}_{j,r}, \mu_y \rangle| > |\langle w^{k}_{j,r}, \mu_y \rangle|$ only when $j = y$. For the noise inner product, the growth is dominated by the second term in (3), since the cross terms, where $|\langle \xi_i, \xi_{i'} \rangle| = \tilde{O}(d^{-1/2}) \|\xi_i\|^2$, are negligible. Thus, we show that $|\langle w^{k+1}_{j,r}, \xi_i \rangle|$ grows only for $j = y_i$. In contrast, for $j \ne y_i$, its magnitude cannot increase relative to the scale of initialization. Next, we decompose the analysis into two stages.

First stage. In the first stage, before the maximum of the signal and noise inner products reaches constant order, the loss derivatives can be lower bounded by an absolute constant, i.e., $|\ell'^k_i| \ge C_\ell$ for all $k \le T_1$. As a result, both the signal and noise inner products can grow exponentially and the relative growth rates are determined by $n \cdot \mathrm{SNR}^2$. A constant-order difference in the growth rates is sufficient to ensure a scale separation between signal and noise learning at the end of the first stage, as shown in the following lemma.

Lemma 4.3. Under Condition 3.1: (1) When $n \cdot \mathrm{SNR}^2 = \Omega(1)$, there exists $T_1 = \tilde{\Theta}(\eta^{-1} m \|\mu\|^{-2})$ such that $\frac{1}{m} \sum_{r=1}^{m} |\langle w^{T_1}_{j,r}, \mu_j \rangle| \ge 2$ for all $j = \pm 1$ and $\max_{j,r,i} |\langle w^{T_1}_{j,r}, \xi_i \rangle| = o(1)$. (2) When $n^{-1} \cdot \mathrm{SNR}^{-2} = \Omega(1)$, there exists $T_1 = \tilde{\Theta}(\eta^{-1} n m \sigma_\xi^{-2} d^{-1})$ such that $\frac{1}{m} \sum_{r=1}^{m} |\langle w^{T_1}_{y_i,r}, \xi_i \rangle| \ge 4$ for all $i \in [n]$ and $\max_{j,r,y} |\langle w^{T_1}_{j,r}, \mu_y \rangle| = o(1)$.

Remark 4.1.
Unlike existing analyses that only show that the maximum inner product reaches constant order (Cao et al., 2022; Huang et al., 2023a), we also show that the average inner product reaches constant order at the same time. Such a stronger result is required for the analysis under constant order of $n \cdot \mathrm{SNR}^2$, and it reduces the required iteration number in the second stage by an order of $m$.

Second stage. In the second stage, we show that the loss converges while the scale separation established in Lemma 4.3 is maintained. Because $n \cdot \mathrm{SNR}^2$ can be a constant, we need to carefully bound the loss derivatives in the second stage, particularly for establishing the upper bound on $|\langle w^k_{j,r}, \xi_i \rangle|$ when $n \cdot \mathrm{SNR}^2 = \Omega(1)$. The naïve bound $\max_i |\ell'^k_i| \le \max_i |\ell^k_i| \le n L_S(W^k)$ used in (Cao et al., 2022) no longer works as it introduces an additional factor of $n$. To provide a tighter bound, we bound the ratio of loss derivatives in the case of $n \cdot \mathrm{SNR}^2 = \Omega(1)$, i.e., $|\ell'^k_i| / |\ell'^k_{i'}| \le C_1$ for all $i, i' \in [n]$ with $y_i = y_{i'}$ and $k \ge T_1$, where $C_1 > 0$ is a constant. This is possible because the network output is dominated by the signal, which is shared across samples with the same label. This allows us to bound $\max_i |\ell'^k_i| = \Theta\big( |S_{y_i}|^{-1} \sum_{i' \in S_{y_i}} |\ell'^k_{i'}| \big) \le \Theta(L_S(W^k))$.

## 5 EXPERIMENTS

We conduct both synthetic and real-world experiments to verify our theoretical claims.

Figure 3: Experiments on Noisy-MNIST with $\widetilde{\mathrm{SNR}} = 0.1$. (First row): Test Noisy-MNIST images. (Second row): Illustration of the input gradient, i.e., $\nabla_x F_{+1}(W, x)$ when $y = 1$ and $\nabla_x F_{-1}(W, x)$ when $y = 0$. (Third row): Denoised images from the diffusion model. In this low-SNR case, we see that classification tends to predominantly learn noise while diffusion learns both signal and noise.

Figure 4: Experiments on Noisy-MNIST with $\widetilde{\mathrm{SNR}} = 0.1$. (a) Train loss for classification. (b) Train loss for diffusion model. (c) Feature learning dynamics.

### 5.1 SYNTHETIC EXPERIMENT

Setup.
We follow Definition 2.1 to generate a synthetic dataset for both the diffusion model and classification. Specifically, we set the data dimension to $d = 1000$ and let $\mu_1 = [\mu, 0, \ldots, 0]^\top \in \mathbb{R}^d$ and $\mu_{-1} = [0, \mu, 0, \ldots, 0]^\top \in \mathbb{R}^d$. We sample the noise patch $\xi_i \sim N(0, I_d)$, $i \in [n]$ (i.e., $\sigma_\xi = 1$). We set the sample size and network width to $n = 30$ and $m = 20$, and initialize the weights to be Gaussian with standard deviation $\sigma_0 = 0.001$. We vary the choice of $\mu$ to create two problem settings: (1) low SNR with $\mu = 5$, which leads to $n \cdot \mathrm{SNR}^2 = 0.75$, and (2) high SNR with $\mu = 15$, which leads to $n \cdot \mathrm{SNR}^2 = 6.75$. We use the same two-layer networks introduced in Section 2. For classification, we set a learning rate of $\eta = 0.1$ and train for 500 iterations. For the diffusion model, we minimize the DDPM loss by averaging over the diffusion noise, following the standard training of diffusion models. In particular, for each sample, we draw $n_\epsilon = 2000$ noise vectors at each iteration and the loss is calculated by taking an average over the noise. For the noise coefficients, we consider a time $t = 0.2$ and set $\alpha_t = \exp(-t) = 0.82$ and $\beta_t = \sqrt{1 - \exp(-2t)} = 0.57$. For the diffusion model, we set $\eta = 0.5$ and train for 40000 iterations.

Results. In Figure 2, we compare the signal and noise learning dynamics, visualized through the maximum signal and noise inner products, between classification and the diffusion model. In Appendix A.1, we also include training loss convergence for both tasks as well as training and test accuracy for classification. For classification, the training loss converges, while the diffusion model recovers only a stationary point. In terms of feature learning, in the low-SNR setting, noise learning in classification quickly dominates signal learning by exhibiting significantly larger growth in the first stage (up to around 20 iterations). This ensures that noise learning stabilizes at a constant order while signal learning remains relatively small. In the second stage, the training loss converges and signal and noise learning exhibit logarithmic growth.
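
The constants quoted in the setup can be checked with a few lines of arithmetic. The helper names below are hypothetical; the formulas are the ones stated in the text, namely $\mathrm{SNR} = \|\mu\|/(\sigma_\xi \sqrt{d})$, $\alpha_t = \exp(-t)$ and $\beta_t = \sqrt{1 - \exp(-2t)}$.

```python
import math

def n_snr_squared(n, mu_norm, sigma_xi, d):
    """Return n * SNR^2 with SNR = ||mu|| / (sigma_xi * sqrt(d))."""
    snr = mu_norm / (sigma_xi * math.sqrt(d))
    return n * snr ** 2

def noise_coefficients(t):
    """Return (alpha_t, beta_t) for the schedule quoted in the setup."""
    alpha_t = math.exp(-t)
    beta_t = math.sqrt(1.0 - math.exp(-2.0 * t))
    return alpha_t, beta_t

# Values from the synthetic setup: n = 30, d = 1000, sigma_xi = 1.
low = n_snr_squared(30, 5.0, 1.0, 1000)     # low-SNR setting, mu = 5
high = n_snr_squared(30, 15.0, 1.0, 1000)   # high-SNR setting, mu = 15
alpha_t, beta_t = noise_coefficients(0.2)   # time t = 0.2
```

Note that this schedule satisfies $\alpha_t^2 + \beta_t^2 = 1$ for every $t$, so the noised input keeps unit variance per coordinate when the clean input does.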
For the diffusion model, in the first stage, where the training loss does not materially change, both signal and noise learning increase linearly. In the second stage, where the loss decreases significantly, signal and noise learning start to grow exponentially. In the final stage, due to the weight regularization terms, noise and signal reach a stationary point that preserves the scale of $n \cdot \mathrm{SNR}^2$.

5.2 REAL-WORLD EXPERIMENT

Setup. We also conduct experiments on the MNIST dataset (Lecun et al., 1998) to support our theory. To better control the signal-to-noise ratio, we create a Noisy-MNIST dataset, where we treat each original MNIST image as a clean signal patch and concatenate a standard Gaussian noise patch of the same size, i.e., $28 \times 28$. In addition, we scale the signal patch by a constant denoted as $\widetilde{\mathrm{SNR}}$. Because the noise scale is fixed, a higher $\widetilde{\mathrm{SNR}}$ corresponds to a higher SNR. Some sample images with $\widetilde{\mathrm{SNR}} = 0.1$ are shown in the first row of Figure 3. We select 50 samples each from digits 0 and 1 (i.e., $n = 100$). We consider the same neural networks as in the synthetic example, where we set $m = 100$ and initialize the weights with $\sigma_0 = 0.01$. For the diffusion model, we choose the same $\alpha_t, \beta_t$ as in the synthetic experiment. In the main paper, we present the results for $\widetilde{\mathrm{SNR}} = 0.1$, which corresponds to a low-SNR setting.

Results. Figure 4(a,b) shows that both classification and the diffusion model converge in loss. Additionally, Figure 4(c) plots the signal and noise learning dynamics. Because each image is composed of a unique signal patch $\mu_i$ and noise patch $\xi_i$ for $i \in [n]$, we measure signal and noise learning by computing $\frac{1}{n} \sum_{i=1}^n \max_r |\langle w_r, \mu_i \rangle|$ and $\frac{1}{n} \sum_{i=1}^n \max_r |\langle w_r, \xi_i \rangle|$, respectively. We notice that, due to the low SNR, noise learning in classification dominates signal learning at convergence, while the diffusion model learns more balanced features. This corroborates our theoretical findings.
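The Noisy-MNIST construction and the two feature learning metrics above can be sketched as follows. This is an illustrative sketch under stated assumptions, not the authors' pipeline: `make_noisy_patches` and `feature_learning_metrics` are hypothetical names, random arrays stand in for the MNIST images so that the block runs without a download, and `W` is assumed to stack all first-layer filters as rows.

```python
import numpy as np

def make_noisy_patches(images, snr_scale, rng):
    """Return flattened signal patches (scaled images) and Gaussian noise patches.

    images: array of shape (n, 28, 28); snr_scale plays the role of the
    scaling constant applied to the signal patch.
    """
    n = images.shape[0]
    signal = snr_scale * images.reshape(n, -1)   # (n, 784) signal patches mu_i
    noise = rng.standard_normal(signal.shape)    # (n, 784) noise patches xi_i
    return signal, noise

def feature_learning_metrics(W, signal, noise):
    """Compute (1/n) sum_i max_r |<w_r, mu_i>| and (1/n) sum_i max_r |<w_r, xi_i>|.

    W: array of shape (m, p), one filter w_r per row.
    """
    signal_learning = np.abs(signal @ W.T).max(axis=1).mean()
    noise_learning = np.abs(noise @ W.T).max(axis=1).mean()
    return signal_learning, noise_learning

rng = np.random.default_rng(0)
images = rng.random((100, 28, 28))               # stand-in for 100 MNIST digits
signal, noise = make_noisy_patches(images, snr_scale=0.1, rng=rng)
W = 0.01 * rng.standard_normal((100, 784))       # m = 100 filters, sigma_0 = 0.01
print(feature_learning_metrics(W, signal, noise))
```

Evaluating these two quantities at each training iteration yields curves of the kind plotted in Figure 4(c).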
To visualize the patterns learned by the neural networks, for classification we adopt an approach similar to Grad-CAM (Selvaraju et al., 2020) and analyze the gradient of the output with respect to the input. Specifically, for samples of digit 0, we plot the gradient of the negative function output, $\nabla_x F_{-1}(W, x)$, while for digit 1, we plot $\nabla_x F_{+1}(W, x)$. As shown in the second row of Figure 3, the gradients of six test images indicate that classification primarily learns the noise rather than the signal patch. For the diffusion model, we first add diffusion noise to the input images and use the network to predict the added noise. We then reconstruct the input with the formula $\hat{x}_0 = (x_t - \beta_t \hat{\epsilon}(x_t)) / \alpha_t$, where $\hat{\epsilon}(x_t)$ denotes the predicted diffusion noise. The third row of Figure 3 shows that the diffusion model learns both the signal and noise. In Appendix A.3, we present results for a high-SNR setting with $\widetilde{\mathrm{SNR}} = 0.5$, where we observe the reverse pattern: classification predominantly captures signal rather than noise, while the diffusion model continues to balance the learning of both signal and noise. Additionally, Appendix A.5 presents experiments on all 10 digits of the MNIST dataset, verifying the observed distinctions in feature learning between diffusion models and classification.

6 CONCLUSION

This work introduces a theoretical framework for analyzing the feature learning dynamics in diffusion models, taking an initial step toward understanding representation learning in diffusion models. Our findings demonstrate that diffusion models inherently promote more balanced feature learning, in contrast to classification models, which tend to prioritize certain features over others. This suggests that classification models may be more sensitive to variations in the signal-to-noise ratio compared to diffusion models.
Consequently, this may provide an explanation for the inherent adversarial robustness of diffusion models in downstream applications, such as classification (Li et al., 2023a; Chen et al., 2024c;b), as perturbations are less likely to significantly affect the learned representations compared to classification models. Although our study focuses on a two-patch data setup, the proposed framework can be extended to accommodate more complex data settings. For example, our analysis can be extended to multi-feature data distributions, where certain features appear more frequently (Zou et al., 2023) or have larger norms than others (Lu et al., 2024). Such extensions could provide deeper insights into the mechanisms of feature learning in more realistic scenarios. We hypothesize that, despite the infrequent occurrence or smaller norm of certain features, diffusion models can effectively learn them due to the nature of the denoising objective. This insight has significant implications for downstream tasks, such as out-of-distribution classification, where rare or weak features may be the primary distinguishing factors. We believe our framework holds broader potential beyond the scope of this work and can be adapted to analyze conditional and latent diffusion models, elucidate the mechanisms of various training objectives and optimizers, and examine other generative paradigms, such as flow matching.

ACKNOWLEDGEMENTS

We would like to thank the anonymous reviewers and area chairs for their helpful comments. Wei Huang is supported by JSPS KAKENHI Grant Number 24K20848. Yuan Cao is supported in part by NSFC 12301657 and Hong Kong ECS award 27308624. Difan Zou is supported in part by NSFC 62306252, Hong Kong ECS award 27309624, Guangdong NSF 2024A1515012444, and the central fund from HKU IDS.

REFERENCES

Zeyuan Allen-Zhu and Yuanzhi Li. Towards understanding ensemble, knowledge distillation and self-distillation in deep learning.
In International Conference on Learning Representations, 2023. URL https://openreview.net/forum?id=Uuf2q9TfXGA.
Xuchan Bao, James Lucas, Sushant Sachdeva, and Roger B Grosse. Regularized linear autoencoders recover the principal components, eventually. Advances in Neural Information Processing Systems, 33:6971–6981, 2020.
Dmitry Baranchuk, Andrey Voynov, Ivan Rubachev, Valentin Khrulkov, and Artem Babenko. Label-efficient semantic segmentation with diffusion models. In International Conference on Learning Representations, 2022.
Yuan Cao, Zixiang Chen, Misha Belkin, and Quanquan Gu. Benign overfitting in two-layer convolutional neural networks. Advances in Neural Information Processing Systems, 35:25237–25250, 2022.
Haoxuan Chen, Yinuo Ren, Lexing Ying, and Grant M Rotskoff. Accelerating diffusion models with parallel sampling: Inference at sub-linear time complexity. arXiv:2405.15986, 2024a.
Huanran Chen, Yinpeng Dong, Shitong Shao, Zhongkai Hao, Xiao Yang, Hang Su, and Jun Zhu. Your diffusion model is secretly a certifiably robust classifier. arXiv:2402.02316, 2024b.
Huanran Chen, Yinpeng Dong, Zhengyi Wang, Xiao Yang, Chengqi Duan, Hang Su, and Jun Zhu. Robust classification via a single diffusion model. In International Conference on Machine Learning, 2024c.
Minshuo Chen, Kaixuan Huang, Tuo Zhao, and Mengdi Wang. Score approximation, estimation and distribution recovery of diffusion models on low-dimensional data. In International Conference on Machine Learning, pp. 4672–4712. PMLR, 2023a.
Nanxin Chen, Yu Zhang, Heiga Zen, Ron J Weiss, Mohammad Norouzi, and William Chan. Wavegrad: Estimating gradients for waveform generation. In International Conference on Learning Representations, 2021. URL https://openreview.net/forum?id=NsMLjcFaO8O.
Sitan Chen, Sinho Chewi, Jerry Li, Yuanzhi Li, Adil Salim, and Anru R Zhang. Sampling is as easy as learning the score: theory for diffusion models with minimal data assumptions.
In International Conference on Learning Representations, 2023b.
Sitan Chen, Vasilis Kontonis, and Kulin Shah. Learning general gaussian mixtures with efficient score matching. arXiv:2404.18893, 2024d.
Kevin Clark and Priyank Jaini. Text-to-image diffusion models are zero shot classifiers. Advances in Neural Information Processing Systems, 36, 2024.
Hugo Cui and Lenka Zdeborová. High-dimensional asymptotics of denoising autoencoders. Advances in Neural Information Processing Systems, 36, 2024.
Prafulla Dhariwal and Alexander Nichol. Diffusion models beat gans on image synthesis. Advances in Neural Information Processing Systems, 34:8780–8794, 2021.
Michael Fuest, Pingchuan Ma, Ming Gui, Johannes S Fischer, Vincent Tao Hu, and Bjorn Ommer. Diffusion models and representation learning: A survey. arXiv:2407.00783, 2024.
Khashayar Gatmiry, Jonathan Kelner, and Holden Lee. Learning mixtures of gaussians using diffusion models. arXiv:2404.18869, 2024.
Robert Geirhos, Jörn-Henrik Jacobsen, Claudio Michaelis, Richard Zemel, Wieland Brendel, Matthias Bethge, and Felix A Wichmann. Shortcut learning in deep neural networks. Nature Machine Intelligence, 2(11):665–673, 2020.
Shivam Gupta, Linda Cai, and Sitan Chen. Faster diffusion-based sampling with randomized midpoints: Sequential and parallel. arXiv:2406.00924, 2024.
Yinbin Han, Meisam Razaviyayn, and Renyuan Xu. Neural network-based score estimation in diffusion models: Optimization and generalization. In International Conference on Learning Representations, 2024. URL https://openreview.net/forum?id=h8GeqOxtd4.
Yujin Han, Andi Han, Wei Huang, Chaochao Lu, and Difan Zou. Can diffusion models learn hidden inter-feature rules behind images? arXiv:2502.04725, 2025.
Jonathan Ho, Ajay Jain, and Pieter Abbeel. Denoising diffusion probabilistic models. Advances in Neural Information Processing Systems, 33:6840–6851, 2020.
Emiel Hoogeboom, Víctor Garcia Satorras, Clément Vignac, and Max Welling. Equivariant diffusion for molecule generation in 3d. In International Conference on Machine Learning, pp. 8867–8887. PMLR, 2022.
Wei Huang, Yuan Cao, Haonan Wang, Xin Cao, and Taiji Suzuki. Graph neural networks provably benefit from structural information: A feature learning perspective. arXiv:2306.13926, 2023a.
Wei Huang, Ye Shi, Zhongyi Cai, and Taiji Suzuki. Understanding convergence and generalization in federated learning through feature learning theory. In The Twelfth International Conference on Learning Representations, 2023b.
Wei Huang, Andi Han, Yongqiang Chen, Yuan Cao, Zhiqiang Xu, and Taiji Suzuki. On the comparison between multi-modal and single-modal contrastive learning. arXiv preprint arXiv:2411.02837, 2024a.
Xunpeng Huang, Difan Zou, Hanze Dong, Yi Zhang, Yi-An Ma, and Tong Zhang. Reverse transition kernel: A flexible framework to accelerate diffusion inference. arXiv preprint arXiv:2405.16387, 2024b.
Priyank Jaini, Kevin Clark, and Robert Geirhos. Intriguing properties of generative classifiers. In International Conference on Learning Representations, 2024. URL https://openreview.net/forum?id=rmg0qMKYRQ.
Samy Jelassi and Yuanzhi Li. Towards understanding how momentum improves generalization in deep learning. In International Conference on Machine Learning, pp. 9965–10040. PMLR, 2022.
Jiarui Jiang, Wei Huang, Miao Zhang, Taiji Suzuki, and Liqiang Nie. Unveil benign overfitting for transformer in vision: Training dynamics, convergence, and generalization. arXiv preprint arXiv:2409.19345, 2024.
Kevin Kögler, Alexander Shevchenko, Hamed Hassani, and Marco Mondelli. Compression of structured data with autoencoders: Provable benefit of nonlinearities and depth. arXiv:2402.05013, 2024.
Zhifeng Kong, Wei Ping, Jiaji Huang, Kexin Zhao, and Bryan Catanzaro. Diffwave: A versatile diffusion model for audio synthesis.
In International Conference on Learning Representations, 2021. URL https://openreview.net/forum?id=a-xFK8Ymz5J.
Yiwen Kou, Zixiang Chen, Yuanzhou Chen, and Quanquan Gu. Benign overfitting in two-layer ReLU convolutional neural networks. In International Conference on Machine Learning, pp. 17615–17659. PMLR, 2023.
Daniel Kunin, Jonathan Bloom, Aleksandrina Goeva, and Cotton Seed. Loss landscapes of regularized linear autoencoders. In International Conference on Machine Learning, pp. 3560–3569. PMLR, 2019.
Y Lecun, L Bottou, Y Bengio, and P Haffner. Gradient-based learning applied to document recognition. Proceedings of the IEEE, 86(11):2278–2324, 1998.
Holden Lee, Jianfeng Lu, and Yixin Tan. Convergence for score-based generative modeling with polynomial complexity. Advances in Neural Information Processing Systems, 35:22870–22882, 2022.
Holden Lee, Jianfeng Lu, and Yixin Tan. Convergence of score-based generative modeling for general data distributions. In International Conference on Algorithmic Learning Theory, pp. 946–985. PMLR, 2023.
Alexander C Li, Mihir Prabhudesai, Shivam Duggal, Ellis Brown, and Deepak Pathak. Your diffusion model is secretly a zero-shot classifier. In International Conference on Computer Vision, pp. 2206–2217, 2023a.
Gen Li, Yuting Wei, Yuxin Chen, and Yuejie Chi. Towards faster non-asymptotic convergence for diffusion-based generative models. arXiv:2306.09251, 2023b.
Gen Li, Yu Huang, Timofey Efimov, Yuting Wei, Yuejie Chi, and Yuxin Chen. Accelerating convergence of score-based diffusion models, provably. In International Conference on Machine Learning, volume 235 of Proceedings of Machine Learning Research, pp. 27942–27954. PMLR, 21–27 Jul 2024a.
Gen Li, Zhihan Huang, and Yuting Wei. Towards a mathematical theory for consistency training in diffusion models. arXiv:2402.07802, 2024b.
Marvin Li and Sitan Chen. Critical windows: non-asymptotic theory for feature emergence in diffusion models.
In Proceedings of the 41st International Conference on Machine Learning, volume 235 of Proceedings of Machine Learning Research, pp. 27474–27498. PMLR, 2024.
Puheng Li, Zhong Li, Huishuai Zhang, and Jiang Bian. On the generalization properties of diffusion models. In Advances in Neural Information Processing Systems, volume 36, pp. 2097–2127, 2023c.
Miao Lu, Beining Wu, Xiaodong Yang, and Difan Zou. Benign oscillation of stochastic gradient descent with large learning rate. In International Conference on Learning Representations, 2024. URL https://openreview.net/forum?id=wYmvN3sQpG.
Xuran Meng, Yuan Cao, and Difan Zou. Per-example gradient regularization improves learning signals from noisy data. arXiv:2303.17940, 2023.
Xuran Meng, Difan Zou, and Yuan Cao. Benign Overfitting in Two-Layer ReLU Convolutional Neural Networks for XOR Data. In International Conference on Machine Learning. PMLR, 2024.
Soumik Mukhopadhyay, Matthew Gwilliam, Vatsal Agarwal, Namitha Padmanabhan, Archana Swaminathan, Srinidhi Hegde, Tianyi Zhou, and Abhinav Shrivastava. Diffusion models beat gans on image classification. arXiv:2307.08702, 2023.
Phan-Minh Nguyen. Analysis of feature learning in weight-tied autoencoders via the mean field lens. arXiv:2102.08373, 2021.
Thanh V Nguyen, Raymond KW Wong, and Chinmay Hegde. Benefits of jointly training autoencoders: An improved neural tangent kernel analysis. IEEE Transactions on Information Theory, 67(7):4669–4692, 2021.
Reza Oftadeh, Jiayi Shen, Zhangyang Wang, and Dylan Shell. Eliminating the invariance on the loss landscape of linear autoencoders. In International Conference on Machine Learning, pp. 7405–7413. PMLR, 2020.
Maya Okawa, Ekdeep S Lubana, Robert Dick, and Hidenori Tanaka. Compositional abilities emerge multiplicatively: Exploring diffusion models on a synthetic task. Advances in Neural Information Processing Systems, 36, 2024.
Kazusato Oko, Shunta Akiyama, and Taiji Suzuki.
Diffusion models are minimax optimal distribution estimators. In International Conference on Machine Learning, pp. 26517–26582. PMLR, 2023.
William Peebles and Saining Xie. Scalable diffusion models with transformers. In International Conference on Computer Vision, pp. 4195–4205, 2023.
Arnu Pretorius, Steve Kroon, and Herman Kamper. Learning dynamics of linear denoising autoencoders. In International Conference on Machine Learning, pp. 4141–4150. PMLR, 2018.
Maria Refinetti and Sebastian Goldt. The dynamics of representation learning in shallow, non-linear autoencoders. In International Conference on Machine Learning, pp. 18499–18519. PMLR, 2022.
Antonio Sclocchi, Alessandro Favero, and Matthieu Wyart. A phase transition in diffusion models reveals the hierarchical nature of data. arXiv:2402.16991, 2024.
Ramprasaath R Selvaraju, Michael Cogswell, Abhishek Das, Ramakrishna Vedantam, Devi Parikh, and Dhruv Batra. Grad-CAM: visual explanations from deep networks via gradient-based localization. International Journal of Computer Vision, 128:336–359, 2020.
Kulin Shah, Sitan Chen, and Adam Klivans. Learning mixtures of gaussians using the DDPM objective. Advances in Neural Information Processing Systems, 36:19636–19649, 2023.
Aleksandr Shevchenko, Kevin Kögler, Hamed Hassani, and Marco Mondelli. Fundamental limits of two-layer autoencoders, and achieving them with gradient methods. In International Conference on Machine Learning, volume 202 of Proceedings of Machine Learning Research, pp. 31151–31209. PMLR, 2023.
Yang Song, Jascha Sohl-Dickstein, Diederik P Kingma, Abhishek Kumar, Stefano Ermon, and Ben Poole. Score-based generative modeling through stochastic differential equations. In International Conference on Learning Representations, 2021. URL https://openreview.net/forum?id=PxTIG12RRHS.
Yang Song, Prafulla Dhariwal, Mark Chen, and Ilya Sutskever. Consistency models. In International Conference on Machine Learning, pp. 32211–32252. PMLR, 2023.
Harald Steck.
Autoencoders that don't overfit towards the identity. In H. Larochelle, M. Ranzato, R. Hadsell, M.F. Balcan, and H. Lin (eds.), Advances in Neural Information Processing Systems, volume 33, pp. 19598–19608. Curran Associates, Inc., 2020.
Yuqing Wang, Ye He, and Molei Tao. Evaluating the design space of diffusion-based generative models. arXiv:2406.12839, 2024.
Joseph L Watson, David Juergens, Nathaniel R Bennett, Brian L Trippe, Jason Yim, Helen E Eisenach, Woody Ahern, Andrew J Borst, Robert J Ragotte, Lukas F Milles, et al. De novo design of protein structure and function with rfdiffusion. Nature, 620(7976):1089–1100, 2023.
Weilai Xiang, Hongyu Yang, Di Huang, and Yunhong Wang. Denoising diffusion autoencoders are unified self-supervised learners. In International Conference on Computer Vision, pp. 15802–15812, 2023.
Xingyi Yang and Xinchao Wang. Diffusion model as representation learner. In International Conference on Computer Vision, pp. 18938–18949, 2023.
Yongyi Yang, Core Francisco Park, Ekdeep Singh Lubana, Maya Okawa, Wei Hu, and Hidenori Tanaka. Dynamics of concept learning and compositional generalization. In International Conference on Learning Representations, 2025.
Kaihong Zhang, Heqi Yin, Feng Liang, and Jingbo Liu. Minimax optimality of score-based diffusion models: Beyond the density lower bound assumptions. In International Conference on Machine Learning, volume 235 of Proceedings of Machine Learning Research, pp. 60134–60178. PMLR, 2024.
Wenliang Zhao, Yongming Rao, Zuyan Liu, Benlin Liu, Jie Zhou, and Jiwen Lu. Unleashing text-to-image diffusion models for visual perception. In International Conference on Computer Vision, pp. 5729–5739, 2023.
Difan Zou, Yuan Cao, Yuanzhi Li, and Quanquan Gu. The benefits of mixup for feature learning. In International Conference on Machine Learning, pp. 43423–43479. PMLR, 2023.
APPENDIX CONTENTS

A Additional experimental results
    A.1 Supplementary results for synthetic experiment
    A.2 Feature learning comparison under varying SNRs
    A.3 High SNR setting on Noisy-MNIST
    A.4 Experiments with additional diffusion time step
    A.5 On the feature learning with 10-class MNIST
B Preliminary lemmas
C Classification
    C.1 Useful lemmas
    C.2 Scale of inner products
    C.3 Signal learning
        C.3.1 First stage
        C.3.2 Second stage
    C.4 Noise memorization
        C.4.1 First stage
        C.4.2 Second stage
D Diffusion model
    D.1 Useful lemmas
    D.2 Derivation of loss function and gradient
    D.3 First stage
    D.4 Second stage
    D.5 Stationary point

A ADDITIONAL EXPERIMENTAL RESULTS

This section includes additional experimental results.
A.1 SUPPLEMENTARY RESULTS FOR SYNTHETIC EXPERIMENT

We first include the loss convergence plots as well as the classification accuracy under the two SNR conditions considered in the main experiment. The (in-distribution) test accuracy is computed with 3000 test samples. We see that both classification and the diffusion model converge in loss, although the diffusion model only finds a stationary point. In the low-SNR setting, classification perfectly fits the training samples, with 100% training accuracy. However, because it primarily focuses on learning noise, generalization is poor, with a test accuracy of around 50%. In the high-SNR case, both training and test sets are perfectly classified due to signal learning.

Figure 5: Experiments on the synthetic dataset with both low SNR ($n \cdot \mathrm{SNR}^2 = 0.75$) and high SNR ($n \cdot \mathrm{SNR}^2 = 6.75$).

A.2 FEATURE LEARNING COMPARISON UNDER VARYING SNRS

In this section, we compare the feature learning dynamics of classification and diffusion models under additional SNR settings. Apart from $n \cdot \mathrm{SNR}^2 = 0.75$ and $n \cdot \mathrm{SNR}^2 = 6.75$ shown in the main text, we additionally test (1) $n \cdot \mathrm{SNR}^2 = 1.92$, (2) $n \cdot \mathrm{SNR}^2 = 3$, and (3) $n \cdot \mathrm{SNR}^2 = 4.32$. The feature learning dynamics under the corresponding SNR settings are shown in Figure 6. From the figures, we see that classification is indeed more sensitive to the SNR scale: it easily overfits to either signal or noise (except for the case $n \cdot \mathrm{SNR}^2 = 3$, where classification learns signal and noise to approximately the same scale). On the other hand, we can verify that, at stationarity, the diffusion model learns signal and noise on a more balanced scale.

A.3 HIGH SNR SETTING ON NOISY-MNIST

Here we include experimental results for $\widetilde{\mathrm{SNR}} = 0.5$, which corresponds to the high-SNR setting. The experimental settings are exactly the same as in the main experiment. Figure 8 shows that both classification and the diffusion model converge in terms of the objective.
In addition, we see that the high SNR encourages classification to learn primarily the signal while ignoring the noise. In contrast, the diffusion model still learns both signal and noise to roughly the same order. Figure 7 suggests that classification learns more signal than noise, while the diffusion model still learns signal and noise in a more balanced way. We also plot classification accuracy for both the low- and high-SNR cases. In the low-SNR case, because classification predominantly learns noise, generalization is poor, with test accuracy around 50%. Conversely, in the high-SNR case, where the model is able to learn signals, classification generalizes effectively, with nearly 100% test accuracy.

Figure 6: Experiments on the synthetic dataset with varying SNRs (panels: $n \cdot \mathrm{SNR}^2 = 1.92$, $n \cdot \mathrm{SNR}^2 = 3$, $n \cdot \mathrm{SNR}^2 = 4.32$).

Figure 7: Experiments on Noisy-MNIST with $\widetilde{\mathrm{SNR}} = 0.5$. (First row): Test Noisy-MNIST images; (Second row): Illustration of the input gradient, i.e., $\nabla_x F_{+1}(W, x)$ when $y = 1$ and $\nabla_x F_{-1}(W, x)$ when $y = 0$. (Third row): Denoised image from the diffusion model. In this high-SNR case, we see that classification tends to predominantly learn signal while diffusion learns both signal and noise.

Figure 8: Experiments on Noisy-MNIST with $\widetilde{\mathrm{SNR}} = 0.5$. (a) Train loss for classification. (b) Train loss for diffusion model. (c) Feature learning dynamics.

A.4 EXPERIMENTS WITH ADDITIONAL DIFFUSION TIME STEP

Here we also test an additional diffusion time step for learning on the Noisy-MNIST dataset. In particular, we consider $t = 0.8$, which gives $\alpha_t = \exp(-t) = 0.45$ and $\beta_t = \sqrt{1 - \exp(-2t)} = 0.89$. We include illustrations of denoised images as well as loss convergence and feature learning dynamics in Figures 10, 11, 12, and 13. We see that, despite the larger scale of added diffusion noise, the diffusion model still learns both signals and noise, unlike classification.
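The noise coefficients and the reconstruction formula $\hat{x}_0 = (x_t - \beta_t \hat{\epsilon}(x_t)) / \alpha_t$ from Section 5.2 can be checked with a short NumPy sketch. The function names are hypothetical, and the forward noising step $x_t = \alpha_t x_0 + \beta_t \epsilon$ is an assumption implied by the reconstruction formula rather than something stated in this section; the true noise is used in place of a trained network's prediction.

```python
import numpy as np

def noise_coefficients(t):
    """alpha_t = exp(-t), beta_t = sqrt(1 - exp(-2t))."""
    alpha_t = np.exp(-t)
    beta_t = np.sqrt(1.0 - np.exp(-2.0 * t))
    return alpha_t, beta_t

def add_diffusion_noise(x0, t, rng):
    """Assumed forward step: x_t = alpha_t * x0 + beta_t * eps, eps ~ N(0, I)."""
    alpha_t, beta_t = noise_coefficients(t)
    eps = rng.standard_normal(x0.shape)
    return alpha_t * x0 + beta_t * eps, eps

def reconstruct(xt, eps_hat, t):
    """x0_hat = (x_t - beta_t * eps_hat) / alpha_t."""
    alpha_t, beta_t = noise_coefficients(t)
    return (xt - beta_t * eps_hat) / alpha_t

for t in (0.2, 0.8):
    alpha_t, beta_t = noise_coefficients(t)
    print(t, round(float(alpha_t), 2), round(float(beta_t), 2))

rng = np.random.default_rng(0)
x0 = rng.standard_normal((4, 784))
xt, eps = add_diffusion_noise(x0, 0.8, rng)
# With a perfect noise prediction, the reconstruction recovers x0.
print(np.allclose(reconstruct(xt, eps, 0.8), x0))
```

Since $\alpha_t^2 + \beta_t^2 = 1$ for every $t$, increasing $t$ from 0.2 to 0.8 shifts weight from the input toward the added diffusion noise.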
Figure 9: Classification accuracy on (a) low-SNR ($\widetilde{\mathrm{SNR}} = 0.1$) and (b) high-SNR ($\widetilde{\mathrm{SNR}} = 0.5$) Noisy-MNIST datasets. This demonstrates that when classification focuses on learning noise (as in the low-SNR case), the test accuracy hovers around 50%, suggesting a failure to generalize. In contrast, when classification focuses on learning signals (as in the high-SNR case), classification generalizes effectively, achieving near-perfect accuracy.

Figure 10: Additional experiments on Noisy-MNIST with $\widetilde{\mathrm{SNR}} = 0.1$ and diffusion $t = 0.8$. (First row): Test Noisy-MNIST images; (Second row): Denoised image from the diffusion model. We see that diffusion still learns both signals and noise even with a large diffusion time step.

Figure 11: Additional experiments on Noisy-MNIST with $\widetilde{\mathrm{SNR}} = 0.1$ and diffusion $t = 0.8$. (a) Train loss for diffusion model. (c) Feature learning dynamics.

Figure 12: Additional experiments on Noisy-MNIST with $\widetilde{\mathrm{SNR}} = 0.5$ and diffusion $t = 0.8$. (First row): Test Noisy-MNIST images; (Second row): Denoised image from the diffusion model. We see that diffusion still learns both signals and noise even with a large diffusion time step.

Figure 13: Additional experiments on Noisy-MNIST with $\widetilde{\mathrm{SNR}} = 0.5$ and diffusion $t = 0.8$. (a) Train loss for diffusion model. (c) Feature learning dynamics.

Figure 14: Experiments on 10-class Noisy-MNIST with $\widetilde{\mathrm{SNR}} = 0.1$. (First row): Test Noisy-MNIST images; (Second row): Illustration of the gradient of the output (for the true class) with respect to the input. (Third row): Denoised image from the diffusion model. In this low-SNR case, we see that classification tends to predominantly learn noise while diffusion learns both signals and noise.

Figure 15: Experiments on 10-class Noisy-MNIST with $\widetilde{\mathrm{SNR}} = 0.1$. (a) Train loss for classification. (b) Train loss for diffusion model. (c) Feature learning dynamics.
A.5 ON THE FEATURE LEARNING WITH 10-CLASS MNIST

In the main paper, we only conduct experiments on Noisy-MNIST restricted to two classes. In this section, we experiment on the 10-class MNIST dataset, which contains more features and is more challenging for both the diffusion model and classification. We adopt the same data processing pipeline as in Section 5.2, except that we select 10 images for each class. We set the scaled SNR to $\widetilde{\mathrm{SNR}} = 0.1$, consistent with the main paper. While the diffusion model remains unchanged, the classification model requires modification. Specifically, the second layer's weight matrix has dimensions $m \times 10$, with entries fixed uniformly to values in $\{-1, 1\}$. Furthermore, we employ the cross-entropy loss for training the classification model. We visualize feature learning in Figure 14. We observe that, even with additional features and labels, similar learning patterns emerge: the diffusion model learns both signals and noise in order to reconstruct the input distribution, while the classification model learns primarily noise for loss minimization. From Figure 15(c), we notice that the diffusion model learns features to roughly the same scale, while for classification the growth of feature learning is dominated by noise learning.

B PRELIMINARY LEMMAS

Recall that we define $S_1 = \{i \in [n] : y_i = 1\}$ and $S_{-1} = \{i \in [n] : y_i = -1\}$.

Lemma B.1. Given arbitrary $\delta > 0$, with probability at least $1 - \delta$, we have
$$\frac{n}{2}\big(1 - \tilde{O}(n^{-1/2})\big) \le |S_1|, |S_{-1}| \le \frac{n}{2}\big(1 + \tilde{O}(n^{-1/2})\big).$$

Proof of Lemma B.1. The proof is the same as in (Cao et al., 2022; Kou et al., 2023) and we include it here for completeness. Because $|S_1| = \sum_{i=1}^n \mathbb{1}(y_i = 1)$ and $|S_{-1}| = \sum_{i=1}^n \mathbb{1}(y_i = -1)$, and $P(y_i = 1) = P(y_i = -1) = 1/2$ for all $i \in [n]$, we have $\mathbb{E}|S_1| = \mathbb{E}|S_{-1}| = n/2$. By Hoeffding's inequality, for arbitrary $a > 0$, $P\big(\big||S_{\pm 1}| - n/2\big| \ge a\big) \le 2\exp(-2a^2 n^{-1})$.
Setting $a = \sqrt{n \log(4/\delta)/2}$ and taking a union bound, we have with probability at least $1 - \delta$,
$$\big||S_1| - n/2\big| \le \sqrt{n \log(4/\delta)/2}, \qquad \big||S_{-1}| - n/2\big| \le \sqrt{n \log(4/\delta)/2}.$$
Hence the proof is complete.

Lemma B.2. Given arbitrary $\delta > 0$, with probability at least $1 - \delta$,
$$\sigma_\xi^2 d\big(1 - \tilde{O}(d^{-1/2})\big) \le \|\xi_i\|^2 \le \sigma_\xi^2 d\big(1 + \tilde{O}(d^{-1/2})\big),$$
$$|\langle \xi_i, \xi_{i'} \rangle| \le 2\sigma_\xi^2 \sqrt{d \log(4n^2/\delta)}$$
for all $i, i' \in [n]$ with $i \neq i'$.

Proof of Lemma B.2. The proof is the same as in (Cao et al., 2022; Kou et al., 2023) and we include it here for completeness. By Bernstein's inequality, with probability at least $1 - \delta/(2n)$, we have
$$\big|\|\xi_i\|^2 - \sigma_\xi^2 d\big| = O\big(\sigma_\xi^2 \sqrt{d \log(4n/\delta)}\big),$$
which shows the first result. For the second claim, Bernstein's inequality shows that, with probability at least $1 - \delta/(2n^2)$, for any $i \neq i'$,
$$|\langle \xi_i, \xi_{i'} \rangle| \le 2\sigma_\xi^2 \sqrt{d \log(4n^2/\delta)}.$$
We then apply a union bound to show that the results hold for all $i, i' \in [n]$.

C CLASSIFICATION

We track the inner product dynamics during the training of supervised classification to elucidate signal learning and noise learning. We first write the gradient descent dynamics as follows:
$$w^{k+1}_{j,r} = w^k_{j,r} - \eta \nabla_{w_{j,r}} L_S(W^k) = w^k_{j,r} - \frac{\eta}{nm} \sum_{i=1}^n \ell'^k_i \langle w^k_{j,r}, x^{(1)}_i \rangle j y_i x^{(1)}_i - \frac{\eta}{nm} \sum_{i=1}^n \ell'^k_i \langle w^k_{j,r}, \xi_i \rangle j y_i \xi_i$$
$$= w^k_{j,r} - \frac{\eta}{nm} \sum_{i \in S_1} \ell'^k_i \langle w^k_{j,r}, \mu_1 \rangle j \mu_1 + \frac{\eta}{nm} \sum_{i \in S_{-1}} \ell'^k_i \langle w^k_{j,r}, \mu_{-1} \rangle j \mu_{-1} - \frac{\eta}{nm} \sum_{i=1}^n \ell'^k_i \langle w^k_{j,r}, \xi_i \rangle j y_i \xi_i.$$

Here we restate Condition 3.1 specialized to the case of supervised classification.

Condition C.1. Suppose that
1. The dimension $d$ satisfies $d = \tilde{\Omega}(\max\{n^2 m \sigma_\xi^{-1} \|\mu\|, n^4 m\})$.
2. The training sample size and network width satisfy $m = \Omega(\log(n/\delta))$, $n = \Omega(\log(m/\delta))$.
3. The initialization scale $\sigma_0$ satisfies $\tilde{O}(n^2 m \sigma_\xi^{-1} d^{-1}) \le \sigma_0 \le \tilde{O}(\min\{\|\mu\|^{-1}, \sigma_\xi^{-1} d^{-1/2}\})$.
4. The learning rate satisfies $\eta \le \tilde{O}(\min\{m \|\mu\|^{-2}, n m \sigma_0 \sigma_\xi^{-1} d^{-1/2}, n m \sigma_\xi^{-2} d^{-1}\})$.

We make the following remarks. The lower bound $m = \tilde{\Omega}(1)$ ensures that the initialization is concentrated and thus provides a lower bound on the maximum and average inner products. The lower bound $n = \tilde{\Omega}(1)$ is required so that $|S_1|, |S_{-1}| = \Theta(n)$ and $\tilde{O}(n^{-1/2})$ remains small.
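The lower bound on $\sigma_0$ is required for the noise memorization setting, where we need to control the lower bound for the noise inner product at initialization. Thus, to ensure that the lower bound on $\sigma_0$ is valid, we require further conditions on the dimension $d$ apart from $d = \tilde{\Omega}(n^2)$.

C.1 USEFUL LEMMAS

We first provide a lemma that bounds the inner products at initialization.

Lemma C.1 (Cao et al. (2022)). Suppose $\delta > 0$ and that $d = \Omega(\log(mn/\delta))$, $m = \Omega(\log(1/\delta))$. Then with probability at least $1 - \delta$,
$$|\langle w^0_{j,r}, \mu_{j'} \rangle| \le \sqrt{2\log(8m/\delta)}\, \sigma_0 \|\mu\|, \qquad |\langle w^0_{j,r}, \xi_i \rangle| \le 2\sqrt{\log(8mn/\delta)}\, \sigma_0 \sigma_\xi \sqrt{d}$$
for all $j, j' \in \{\pm 1\}$, $r \in [m]$, $i \in [n]$. In addition,
$$\max_{r \in [m]} |\langle w^0_{j,r}, \mu_{j'} \rangle| \ge \sigma_0 \|\mu\|/2, \qquad \max_{r \in [m]} |\langle w^0_{j,r}, \xi_i \rangle| \ge \sigma_0 \sigma_\xi \sqrt{d}/4$$
for all $j, j' \in \{\pm 1\}$, $i \in [n]$.

We decompose the weights into their signal components and noise components.

Lemma C.2. The weight can be decomposed as
$$w^k_{j,r} = w^0_{j,r} + \zeta^k_1 \mu_1 + \zeta^k_{-1} \mu_{-1} + \sum_{i=1}^n \rho^k_{j,r,i} \|\xi_i\|^{-2} \xi_i,$$
where the noise coefficients $\rho^k_{j,r,i}$ satisfy $\rho^0_{j,r,i} = 0$ and
$$\rho^{k+1}_{j,r,i} = \rho^k_{j,r,i} - \frac{\eta}{nm} \ell'^k_i \langle w^k_{j,r}, \xi_i \rangle j y_i \|\xi_i\|^2$$
for all $j = \pm 1$, $r \in [m]$ and $i \in [n]$.

Proof of Lemma C.2. The proof follows from (Cao et al., 2022; Kou et al., 2023). First, we recall the gradient descent update as
$$w^{k+1}_{j,r} = w^k_{j,r} - \frac{\eta}{nm} \sum_{i \in S_1} \ell'^k_i \langle w^k_{j,r}, \mu_1 \rangle j \mu_1 + \frac{\eta}{nm} \sum_{i \in S_{-1}} \ell'^k_i \langle w^k_{j,r}, \mu_{-1} \rangle j \mu_{-1} - \frac{\eta}{nm} \sum_{i=1}^n \ell'^k_i \langle w^k_{j,r}, \xi_i \rangle j y_i \xi_i$$
$$= w^0_{j,r} - \frac{\eta}{nm} \sum_{s=0}^{k} \sum_{i \in S_1} \ell'^s_i \langle w^s_{j,r}, \mu_1 \rangle j \mu_1 + \frac{\eta}{nm} \sum_{s=0}^{k} \sum_{i \in S_{-1}} \ell'^s_i \langle w^s_{j,r}, \mu_{-1} \rangle j \mu_{-1} - \frac{\eta}{nm} \sum_{s=0}^{k} \sum_{i=1}^n \ell'^s_i \langle w^s_{j,r}, \xi_i \rangle j y_i \xi_i.$$
By the data model, with probability 1 the vectors are linearly independent, and thus the decomposition is unique with
$$\rho^k_{j,r,i} = -\frac{\eta}{nm} \sum_{s=0}^{k-1} \ell'^s_i \langle w^s_{j,r}, \xi_i \rangle j y_i \|\xi_i\|^2.$$
Writing out the iterative update for $\rho^k_{j,r,i}$ then completes the proof.

Lemma C.3. Let $x \sim N(0, \sigma^2)$. Then
$$P(|x| \le c) \le \operatorname{erf}\Big(\frac{c}{\sqrt{2}\sigma}\Big) \le \sqrt{1 - \exp\Big(-\frac{2c^2}{\sigma^2 \pi}\Big)}.$$

Proof of Lemma C.3. The probability density function of $x$ is given by
$$f(x) = \frac{1}{\sqrt{2\pi}\sigma} \exp\Big(-\frac{x^2}{2\sigma^2}\Big).$$
Then we know that
$$P(|x| \le c) = \frac{1}{\sqrt{2\pi}\sigma} \int_{-c}^{c} \exp\Big(-\frac{x^2}{2\sigma^2}\Big)\, dx.$$
By the definition of the erf function, $\operatorname{erf}(c) = \frac{2}{\sqrt{\pi}} \int_0^c \exp(-x^2)\, dx$, and variable substitution yields
$$P(|x| \le c) = \frac{2}{\sqrt{\pi}} \int_0^{c/(\sqrt{2}\sigma)} \exp(-u^2)\, du.$$
Therefore, we first conclude $P(|x| \le c) = \operatorname{erf}\big(\frac{c}{\sqrt{2}\sigma}\big)$.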
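Next, by the inequality $\operatorname{erf}(x) \le \sqrt{1 - \exp(-4x^2/\pi)}$, we obtain the desired result.

C.2 SCALE OF INNER PRODUCTS

We first derive a global bound for the growth of inner products until convergence. To this end, we let $T^* = \eta^{-1}\mathrm{poly}(\|\boldsymbol{\mu}\|^{-1}, \sigma_\xi^{-2}d^{-1}, \sigma_0^{-1}, n, m, d)$ be the maximum number of iterations considered and let $\alpha = 2\log(T^*)$. We also denote $\beta := 3\max_{j,r,i,y}\{|\langle \mathbf{w}_{j,r}^{0}, \boldsymbol{\mu}_y\rangle|, |\langle \mathbf{w}_{j,r}^{0}, \boldsymbol{\xi}_i\rangle|\}$. Then from Lemma C.1 and from Condition C.1, we can bound

$$3\max\{\sigma_0\|\boldsymbol{\mu}\|/2,\ \sigma_0\sigma_\xi\sqrt{d}/4\} \le \beta \le 1/C \tag{4}$$

for some sufficiently large constant $C > 0$.

Proposition C.1. Under Condition C.1, for all $0 \le k \le T^*$, we can bound

$$|\langle \mathbf{w}_{j,r}^{k}, \boldsymbol{\mu}_j\rangle|,\ |\langle \mathbf{w}_{y_i,r}^{k}, \boldsymbol{\xi}_i\rangle|,\ |\rho_{y_i,r,i}^{k}| \le \alpha, \tag{5}$$

$$|\langle \mathbf{w}_{j,r}^{k}, \boldsymbol{\mu}_{-j}\rangle| \le \beta, \tag{6}$$

$$|\langle \mathbf{w}_{-y_i,r}^{k}, \boldsymbol{\xi}_i\rangle|,\ |\rho_{-y_i,r,i}^{k}| \le \beta + 12\sqrt{\frac{\log(4n^2/\delta)}{d}}\,n\alpha \tag{7}$$

for all $i \in [n]$, $r \in [m]$ and $j = \pm 1$.

We will prove the bound by induction, and we first derive several intermediate lemmas as follows.

Lemma C.4. Suppose the results in Proposition C.1 hold at iteration $k$. Then we have $F_j(\mathbf{W}_j^k, \mathbf{x}_i) \le 0.5$ for all $i \in [n]$, $j \neq y_i$.

Proof of Lemma C.4. Recall that

$$F_j(\mathbf{W}_j^k, \mathbf{x}_i) = \frac{1}{m}\sum_{r=1}^{m}\Big(\langle \mathbf{w}_{j,r}^{k}, \mathbf{x}_i^{(1)}\rangle^2 + \langle \mathbf{w}_{j,r}^{k}, \mathbf{x}_i^{(2)}\rangle^2\Big) = \frac{1}{m}\sum_{r=1}^{m}\Big(\langle \mathbf{w}_{j,r}^{k}, \boldsymbol{\mu}_{y_i}\rangle^2 + \langle \mathbf{w}_{j,r}^{k}, \boldsymbol{\xi}_i\rangle^2\Big) \le \beta^2 + \Big(\beta + 12\sqrt{\frac{\log(4n^2/\delta)}{d}}\,n\alpha\Big)^2 \le 0.5,$$

where the second-to-last inequality is by (6) and (7). The last inequality is by Condition C.1 such that $\beta \le 1/C \le 0.25$ and $d \ge 144n^2\alpha^2\log(4n^2/\delta)$.

Lemma C.5. Suppose the results in Proposition C.1 hold at iteration $k$. Then we have

$$\big|\langle \mathbf{w}_{j,r}^{k} - \mathbf{w}_{j,r}^{0}, \boldsymbol{\xi}_i\rangle - \rho_{j,r,i}^{k}\big| \le 4\sqrt{\frac{\log(4n^2/\delta)}{d}}\,n\alpha$$

for all $j = \pm 1$, $r \in [m]$, $i \in [n]$.

Proof. By Lemma C.2, we recall the decomposition

$$\mathbf{w}_{j,r}^{k} = \mathbf{w}_{j,r}^{0} + \zeta_{1}^{k}\boldsymbol{\mu}_1 + \zeta_{-1}^{k}\boldsymbol{\mu}_{-1} + \sum_{i=1}^{n} \rho_{j,r,i}^{k}\,\|\boldsymbol{\xi}_i\|^{-2}\boldsymbol{\xi}_i.$$

By the orthogonality, we can show

$$\langle \mathbf{w}_{j,r}^{k}, \boldsymbol{\xi}_i\rangle = \langle \mathbf{w}_{j,r}^{0}, \boldsymbol{\xi}_i\rangle + \rho_{j,r,i}^{k} + \sum_{i'\neq i} \rho_{j,r,i'}^{k}\,\|\boldsymbol{\xi}_{i'}\|^{-2}\langle \boldsymbol{\xi}_{i'}, \boldsymbol{\xi}_i\rangle.$$

By Lemma B.2, and supposing $d = \Omega(\log(n/\delta))$, we have $|\langle \boldsymbol{\xi}_i, \boldsymbol{\xi}_{i'}\rangle|\,\|\boldsymbol{\xi}_{i'}\|^{-2} \le 4\sqrt{\log(4n^2/\delta)\,d^{-1}}$. Thus we have

$$\big|\langle \mathbf{w}_{j,r}^{k} - \mathbf{w}_{j,r}^{0}, \boldsymbol{\xi}_i\rangle - \rho_{j,r,i}^{k}\big| \le 4\sqrt{\frac{\log(4n^2/\delta)}{d}}\,n\alpha,$$

where we use the upper bound $|\rho_{j,r,i}^{k}| \le \alpha$.

Lemma C.6. For any $r \in [m]$, $j, y = \pm 1$, we have $\operatorname{sign}(\langle \mathbf{w}_{j,r}^{0}, \boldsymbol{\mu}_y\rangle) = \operatorname{sign}(\langle \mathbf{w}_{j,r}^{k}, \boldsymbol{\mu}_y\rangle)$ for all $0 \le k \le T^*$.

Proof of Lemma C.6. We prove the results by induction. First, it is clear that at $k = 0$ the results are satisfied.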
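Then suppose there exists an iteration $\tilde{k}$ such that $\operatorname{sign}(\langle \mathbf{w}_{j,r}^{k}, \boldsymbol{\mu}_y\rangle) = \operatorname{sign}(\langle \mathbf{w}_{j,r}^{0}, \boldsymbol{\mu}_y\rangle)$ holds for all $k \le \tilde{k}-1$. We show the sign invariance also holds at $\tilde{k}$. Recall the gradient descent update

$$\mathbf{w}_{j,r}^{k+1} = \mathbf{w}_{j,r}^{k} - \frac{\eta}{nm}\sum_{i\in S_1} \ell_i'^{k}\langle \mathbf{w}_{j,r}^{k}, \boldsymbol{\mu}_1\rangle j\boldsymbol{\mu}_1 + \frac{\eta}{nm}\sum_{i\in S_{-1}} \ell_i'^{k}\langle \mathbf{w}_{j,r}^{k}, \boldsymbol{\mu}_{-1}\rangle j\boldsymbol{\mu}_{-1} - \frac{\eta}{nm}\sum_{i=1}^{n} \ell_i'^{k}\langle \mathbf{w}_{j,r}^{k}, \boldsymbol{\xi}_i\rangle j y_i\boldsymbol{\xi}_i.$$

Then the update of the inner product is

$$\langle \mathbf{w}_{j,r}^{\tilde{k}}, \boldsymbol{\mu}_y\rangle = \langle \mathbf{w}_{j,r}^{\tilde{k}-1}, \boldsymbol{\mu}_y\rangle - \frac{\eta}{nm}\sum_{i\in S_y} \ell_i'^{\tilde{k}-1}\langle \mathbf{w}_{j,r}^{\tilde{k}-1}, \boldsymbol{\mu}_y\rangle\, jy\,\|\boldsymbol{\mu}\|^2 = \Big(1 - \frac{\eta}{nm}\,jy\sum_{i\in S_y} \ell_i'^{\tilde{k}-1}\|\boldsymbol{\mu}\|^2\Big)\langle \mathbf{w}_{j,r}^{\tilde{k}-1}, \boldsymbol{\mu}_y\rangle.$$

By the condition that $\eta \le C^{-1}m\|\boldsymbol{\mu}\|^{-2}$ for a sufficiently large constant $C$, we have $\big|\frac{\eta}{nm}\,jy\sum_{i\in S_y} \ell_i'^{\tilde{k}-1}\|\boldsymbol{\mu}\|^2\big| < 1$. Thus we can guarantee that $\operatorname{sign}(\langle \mathbf{w}_{j,r}^{\tilde{k}}, \boldsymbol{\mu}_y\rangle) = \operatorname{sign}(\langle \mathbf{w}_{j,r}^{\tilde{k}-1}, \boldsymbol{\mu}_y\rangle) = \operatorname{sign}(\langle \mathbf{w}_{j,r}^{0}, \boldsymbol{\mu}_y\rangle)$.

Proof of Proposition C.1. We prove the results by induction. For $\rho_{j,r,i}^{k}$, we prove a stronger result that $|\rho_{y_i,r,i}^{k}| \le 0.9\alpha \le \alpha$ and $|\rho_{-y_i,r,i}^{k}| \le 0.6\beta + 8\sqrt{\frac{\log(4n^2/\delta)}{d}}\,n\alpha$.

First, it is clear that at $k = 0$ the results are satisfied, based on the definition of $\beta$ and $\alpha \ge \beta$. Now suppose that there exists $\tilde{T} \le T^*$ such that the results hold for all $0 \le k \le \tilde{T}-1$. We wish to show the results also hold for $k = \tilde{T}$. First recall the gradient descent update

$$\mathbf{w}_{j,r}^{k+1} = \mathbf{w}_{j,r}^{k} - \frac{\eta}{nm}\sum_{i\in S_1} \ell_i'^{k}\langle \mathbf{w}_{j,r}^{k}, \boldsymbol{\mu}_1\rangle j\boldsymbol{\mu}_1 + \frac{\eta}{nm}\sum_{i\in S_{-1}} \ell_i'^{k}\langle \mathbf{w}_{j,r}^{k}, \boldsymbol{\mu}_{-1}\rangle j\boldsymbol{\mu}_{-1} - \frac{\eta}{nm}\sum_{i=1}^{n} \ell_i'^{k}\langle \mathbf{w}_{j,r}^{k}, \boldsymbol{\xi}_i\rangle j y_i\boldsymbol{\xi}_i.$$

Then based on the orthogonal data modelling assumption, we have for $y \neq j$, i.e., $y = -j$,

$$\langle \mathbf{w}_{j,r}^{k+1}, \boldsymbol{\mu}_{-j}\rangle = \langle \mathbf{w}_{j,r}^{k}, \boldsymbol{\mu}_{-j}\rangle + \frac{\eta}{nm}\sum_{i\in S_{-j}} \ell_i'^{k}\langle \mathbf{w}_{j,r}^{k}, \boldsymbol{\mu}_{-j}\rangle\|\boldsymbol{\mu}\|^2 = \Big(1 - \frac{\eta\|\boldsymbol{\mu}\|^2}{nm}\sum_{i\in S_{-j}} |\ell_i'^{k}|\Big)\langle \mathbf{w}_{j,r}^{k}, \boldsymbol{\mu}_{-j}\rangle,$$

where the second equality is by $\ell_i'^{k} < 0$ for all $i, k$.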
From Lemma C.6, we have sign( wk+1 j,r , µ j ) = sign( wk j,r, µ j ) and thus | w e T j,r, µ j | i S j |ℓ e T 1 i | w e T 1 j,r , µ j w e T 1 j,r , µ j β On the other hand, for y = j, we have wk+1 j,r , µj = wk j,r, µj η nm i Sj ℓ k i wk j,r, µj µ 2 = wk j,r, µj + η µ 2 i Sj |ℓ k i | wk j,r, µj Next, we notice that |ℓ k i | = 1 1 + exp Fyi(Wkyi, xi) F yi(Wk yi, xi) exp Fyi(Wk yi, xi) + F yi(Wk yi, xi) exp Fyi(Wk yi, xi) + 0.5 wk yi,r, µyi 2 + wk yi,r, ξi 2 + 0.5 (8) where the last inequality is by Lemma C.4. Let kj,r be the last time k T that | wk j,r, µj | 0.5α. Then we have w e T j,r, µj = wkj,r j,r , µj + η µ 2 nm |ℓ kj,r i | wkj,r j,r , µj | {z } A1 kj,r 0, then 1 3ρ e T 1 yi,r,i w e T 1 yi,r, ξi 4 3ρ e T 1 yi,r,i Then (9) suggests ρ e T yi,r,i 1 η ξi 2 3nm |ℓ k i | ρ e T 1 yi,r,i ρ e T 1 yi,r,i 0.6β + 8 If ρ e T 1 yi,r,i < 0, then 4 3ρ e T 1 yi,r,i w e T 1 yi,r, ξi 1 3ρ e T 1 yi,r,i Then (9) suggests ρ e T yi,r,i 1 η ξi 2 3nm |ℓ k i | ρ e T 1 yi,r,i ρ e T 1 yi,r,i 0.6β 8 Thus this completes the proof that |ρ e T yi,r,i| 0.6β + 8 q Finally, by Lemma C.5 we have for all k 0 | wk yi,r, ξi | | w0 yi,r, ξi | + |ρk yi,r,i| + 4 d nα 0.9β + 12 which proves the upper bound for | w e T yi,r, ξi | and |ρ e T yi,r,i|. Next, from Lemma C.2, we have for yi = j, ρk+1 yi,r,i = ρk yi,r,i η nmℓ k i wk yi,r, ξi ξi 2. (10) Let kr,i be the last time k < T that |ρk yi,r,i| 0.6α. Then it can be verified that for k kr,i, | wk yi,r, ξi | |ρk yi,r,i| | w0 yi,r, ξi | 4 where the first inequality is by Lemma C.5 and the last inequality is by | w0 yi,r, ξi | + d nα 1 0.1α. Published as a conference paper at ICLR 2025 We now expand (10) as ρ e T yi,r,i = ρ kr,i yi,r,i + η nm|ℓ kr,i i | w kr,i yi,r, ξi ξi 2 kr,i 0 such that |ℓ k i | Cℓfor all i [n]. Proof of Lemma C.7. If maxr,i,y{ wk j,r, µy , wk j,r, ξi } = O(1), we can bound for all j = 1 Fj(Wk j , xi) = 1 wk j,r, µyi 2 + wk j,r, ξi 2 O(1) Therefore, we can bound |ℓ k i | = (1 + exp(Fyi(Wk yi, xi) F yi(Wk yi, xi))) 1 Ω(1). 
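The mechanism behind the constant lower bound on the loss derivative can be illustrated numerically. The sketch below assumes the logistic loss, so that $|\ell'(z)| = 1/(1+\exp(z))$ at margin $z = F_{y_i} - F_{-y_i}$, and assumes the quadratic form $F_j = \frac{1}{m}\sum_r(\langle \mathbf{w}_{j,r}, \boldsymbol{\mu}_{y_i}\rangle^2 + \langle \mathbf{w}_{j,r}, \boldsymbol{\xi}_i\rangle^2)$ with $1/m$ normalization. If all inner products are at most $c$ in magnitude, then the margin is at most $2c^2$ and $|\ell'| \ge 1/(1+\exp(2c^2))$. The function names and the example inner products are hypothetical.

```python
import math

def quadratic_output(signal_ips, noise_ips):
    """F_j = (1/m) * sum_r (<w_r, mu>^2 + <w_r, xi>^2); the 1/m normalization is an assumption."""
    m = len(signal_ips)
    return sum(s * s + z * z for s, z in zip(signal_ips, noise_ips)) / m

def abs_loss_derivative(margin):
    """|l'(z)| for the logistic loss l(z) = log(1 + exp(-z))."""
    return 1.0 / (1.0 + math.exp(margin))

def derivative_floor(c):
    """Lower bound on |l'| when every inner product has magnitude at most c."""
    return 1.0 / (1.0 + math.exp(2.0 * c * c))

# Hypothetical inner products for m = 3 neurons, all bounded by c = 1 in magnitude.
c = 1.0
F_pos = quadratic_output([0.9, -0.4, 1.0], [0.2, -1.0, 0.5])   # output for the true label
F_neg = quadratic_output([0.1, 0.05, -0.2], [0.3, 0.1, -0.1])  # output for the other label
margin = F_pos - F_neg
print(abs_loss_derivative(margin), derivative_floor(c))
```

Because $F_{-y_i} \ge 0$, dropping it only increases the margin bound, which is why the floor depends on $c$ alone.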
We also prove the following upper bound on the gradient norm. Published as a conference paper at ICLR 2025 Lemma C.8 (Proof of Lemma C.8). Under Condition C.1, for 0 k T , we can bound LS(Wk) 2 = O(max{ µ 2, σ2 ξd})LS(Wk) Proof of Lemma C.8. The proof adopts a similar argument as in (Cao et al., 2022, Lemma C.7) and we include here for completeness. We first bound f(Wk, xi) 2 wk j,r, µyi µyi + wk j,r, ξi ξi r | wk yi,r, µyi | µ + 2 r | wk yi,r, ξi | ξi r | wk yi,r, µyi | µ + 2 r | wk yi,r, ξi | ξi | wk yi,r, µyi | + | wk yi,r, ξi | max{ µ , 2σξ | wk yi,r, µyi | + | wk yi,r, ξi | max{ µ , 2σξ Fyi(Wkyi, xi) + q F yi(Wk yi, xi) max{ µ , 2σξ Fyi(Wkyi, xi) + 1 max{ µ , 2σξ where the third inequality is by Lemma B.2 and the fourth inequality is by Jensen s inequality and the last inequality is by Lemma C.4 that F yi(Wk yi, xi) for all i [n]. Then we have ℓ (yif(Wk, xi)) f(Wk, xi) 2 ℓ Fyi(Wk yi, xi) 0.5 2 q Fyi(Wkyi, xi) + 1 max{ µ , 2σξ = 4ℓ Fyi(Wk yi, xi) 0.5 q Fyi(Wkyi, xi) + 1 2 max{ µ 2, 4σ2 ξd} max z>0 { 4ℓ (z 0.5)( z + 1)2} max{ µ 2, 4σ2 ξd} = O(max{ µ 2, σ2 ξd}) where the last equality is by maxz>0{ 4ℓ (z 0.5)( z +1)2} < because ℓ has an exponentially decaying tail. Then we can bound i=1 ℓ (yif(Wk, xi)) f(Wk, xi) 2 O(max{ µ 2, σ2 ξd})ℓ (yif(Wk, xi)) 2 O(max{ µ 2, σ2 ξd}) 1 i=1 ℓ (yif(Wk, xi)) O(max{ µ 2, σ2 ξd})LS(Wk) where the third inequality is by Cauchy-Schwartz inequality and the last inequality is by ℓ ℓfor cross-entropy loss. C.3 SIGNAL LEARNING We first analyze the setting, where n SNR2 C for some constant C > 0, which allows signal learning to dominate noise memorization, thus reaching benign overfitting. For the purpose of signal learning, we derive an anti-concentration result that provides a lower bound for signal inner product at initialization. Published as a conference paper at ICLR 2025 Lemma C.9. Suppose δ > 0 and m = Ω(log(1/δ)). Then with probability at least 1 δ, we have for all j, y = 1 r=1 | w0 j,r, µy | σ0 µ Proof of Lemma C.16. 
First notice that for any j = 1, w0 j,r, µy N(0, σ2 0 µ 2) and thus we have E[| w0 j,r, µy |] = σ0 µ p 2/π. By sub-Gaussian tail bound, with probability at least 1 δ/8, for any j, y = 1 1 m r=1 | w0 j,r, µy | σ0 µ p 2σ2 0 µ 2 log(8/δ) Choosing m = Ω(log(1/δ)), we have r=1 | w0 j,r, µy | σ0 µ p Then we have σ0 µ /2 1 m Pn r=1 | w0 j,r, ξi | σ0 µ . Finally taking the union bound for all j, y = 1 completes the proof. We have established several preliminary lemmas that hold with high probability, including Lemma B.1, Lemma B.2, Lemma C.1, Lemma C.9. We let Eprelim be the event such that all the results in these lemmas hold for a given δ. Then by applying union bound, we have P(Eprelim) 1 4δ. The subsequent analysis are conditioned on the event Eprelim. C.3.1 FIRST STAGE In the first stage where maxr,i,y{ wk j,r, µy , wk j,r, ξi } = O(1), we show in Lemma C.7 that we can lower bound the loss derivatives by a constant Cℓ, i.e., |ℓ k i | Cℓ, for all i [n], k T1. Theorem C.1. Under Condition C.1, suppose n SNR2 C for some C 0. Then there exists a time T1 = eΘ(η 1m µ 2), such that (1) maxr | w T1 j,r, µj | 2, for all j = 1, (2) 1 m Pm r=1 | w T1 j,r, µj | 2, for all j = 1 (3) maxr,i | w T1 yi,r, ξi | = e O(n 1/2). Proof of Theorem C.1. We first upper bound the growth of noise by analyzing inner product dynamics wk yi,r, ξi = wk 1 yi,r , ξi η nm i =1 ℓ k 1 i wk 1 j,r , ξi ξi , ξi = wk 1 yi,r , ξi η nmℓ k i wk 1 yi,r , ξi ξi 2 η nm i =i ℓ k 1 i wk 1 yi,r , ξi ξi , ξi This suggests | wk yi,r, ξi | | wk 1 yi,r , ξi | + η nm|ℓ k i || wk 1 yi,r , ξi | ξi 2 + η nm i =i |ℓ k i || wk 1 yi,r , ξi || ξi , ξi | Next, from Lemma C.7 and Lemma B.2, we have for any i = i [n] and k T1, |ℓ k i | | ξi, ξi | |ℓ k i | ξi 2 2σ2 ξ p d log(4n2/δ) Cℓ0.99σ2 ξd = 2.1C 1 ℓ where we use the lower and upper bound on loss derivatives during the first stage, as well as Lemma B.2. 
Then taking the maximum of (11) over the neurons and samples, we let Bk := maxr,i | wk yi,r, ξi | and obtain Bk Bk 1 + η nm 1 + 2.1C 1 ℓ n |ℓ k i | ξi 2Bk 1 Published as a conference paper at ICLR 2025 1 + 1.01η ξi 2 1 + 1.02ησ2 ξd nm where the second inequality is by d = eΩ(n2) sufficiently large and |ℓ k i | 1. The third inequality is by Lemma B.2. We then consider the propagation of wk j,r, µy . From the gradient update we can show for j = y, wk+1 j,r , µj = wk j,r, µj η nm i Sj ℓ k i wk j,r, µj µ 2 wk j,r, µj + ηCℓ|S1| µ 2 nm wk j,r, µj 1 + 0.49ηCℓ µ 2 m wk j,r, µj where the first inequality is by loss derivative lower bound and the the second inequality is by Lemma B.1 and n = eΩ(1) sufficiently large. This implies that | wk j,r, µj | 1 + 0.49ηCℓ µ 2 m | wk 1 j,r , µj | 1 + 0.49ηCℓ µ 2 m k| w0 j,r, µj | Applying Lemma C.1 and Lemma C.9, we have for all j = 1, max r | wk j,r, µj | 1 + 0.49ηCℓ µ 2 r=1 | wk j,r, µj | 1 + 0.49ηCℓ µ 2 T1 = log(4mσ 1 0 µ 1)/ log 1 + 0.49ηCℓ µ 2 m = Θ(η 1m µ 2 log(4mσ 1 0 µ 1)) for η sufficiently small. Then we can verify that for j = 1, we have max r | w T1 j,r, µj | 2, 1 m r=1 | w T1 j,r, µj | 2, Now under the SNR condition, we can bound the growth of noise as BT1 1 + 1.01 ησ2 ξd nm log(1 + 1.02 ησ2 ξd nm ) log(1 + 0.49 ηCℓ µ 2 m ) log(4σ 1 0 µ 1) exp 2.1/Cℓn 1SNR 2 + e O(n SNR2η) log(4σ 1 0 µ 1) 2σ0σξ exp 2.1/Cℓn 1SNR 2 + 0.01 log(4σ 1 0 µ 1) 2σ0σξ = e O(n 1/2) where the first inequality is by Lemma C.1 and the second inequality is by Taylor expansion around η = 0. The third inequality is by choosing η sufficiently small and the fourth inequality is by the SNR condition that n SNR2 C 2.5C 1 ℓ . Published as a conference paper at ICLR 2025 C.3.2 SECOND STAGE First, at the end of first stage, we have maxr | w T1 j,r, µj | 2 for all j = 1. 1 m Pm r=1 | w T1 j,r, µj | 2 for all j = 1. 
maxr,i | w T1 yi,r, ξi | = e O(n 1/2) maxr,i | w T1 yi,r, ξi | β + 12 q Next we define w j,r = w0 j,r + 2 log(4/ϵ)sign( w0 j,r, µj )µj + µ j We first show the monotonicity of signal inner product in the second stage. Lemma C.10. Under the same conditions as in Theorem C.1, we have for all j = 1, r [m], T1 k T, | wk j,r, µj | | w T1 j,r, µj | 2. Proof of Lemma C.10. From the update of signal inner product, we have for all j = 1, r [m], T1 k T wk+1 j,r , µj = wk j,r, µj η nm i Sj ℓ k i wk j,r, µj µ 2 i Sj ℓ k i wk j,r, µj . Thus | wk j,r, µj | | wk 1 j,r , µj | | w T1 j,r, µj | 2 for all j = 1, r [m], T1 k T. We then bound the distance between WT1 to W . Lemma C.11. Under Condition C.1, we can bound WT1 W = O( m log(1/ϵ) µ 1). Proof of Lemma C.11. Let Pξ be the projection matrix to the direction of ξ, i.e., Pξ = ξξ ξ 2 . Then we can represent wk j,r w0 j,r = Pµ1(wk j,r w0 j,r) + Pµ 1(wk j,r w0 j,r) + i=1 Pξi(wk j,r w0 j,r) + I Pµ1 Pµ 1 i=1 Pξi (wk j,r w0 j,r). By the scale difference at T1 and the fact that gradient descent only updates in the direction of µj, j = 1 and ξi, we can bound w T1 j,r w0 j,r, µ1 2 µ 2 + w T1 j,r w0 j,r, µ 1 2 w T1 j,r w0 j,r, ξi 2 i=1 Pξi (w T1 j,r w0 j,r) 2m 2 maxr w T1 j,r, µj 2 µ 2 + 2 w T1 j,r, µ j 2 + 2 w0 j,r, µ j 2 + 2 w0 j,r, µj 2 Published as a conference paper at ICLR 2025 2 w T1 j,r, ξi 2 + 2 w0 j,r, ξi 2 i=1 Pξi (w T1 j,r w0 j,r) where we have use the scale difference at T1. Therefore, WT1 W WT1 W0 + W0 W O( m µ 1) + O( m log(1/ϵ) µ 1) O( m log(1/ϵ) µ 1) where we use the definition of W . Lemma C.12. Under Condition C.1, we have for all T1 k T Wk W 2 Wk+1 W 2 2ηLS(Wt) ηϵ Proof of Lemma C.12. The proof is similar as in Cao et al. (2022). We first show a lower bound on yi f(Wt, xi), W for any i [n] for all T1 k T . 
yi f(Wk, xi), W = 1 j,r jyi wk j,r, µyi µyi, w j,r + 1 j,r jyi wk j,r, ξi ξi, w j,r r=1 wk yi,r, µyi w yi,r, µyi 1 r=1 wk yi,r, µyi w yi,r, µyi j,r jyi wk j,r, ξi ξi, w0 j,r r=1 | wk yi,r, µyi |2 log(4/ϵ) r=1 wk yi,r, µyi w0 yi,r, µyi r=1 wk yi,r, µyi w yi,r, µyi j,r jyi wk j,r, ξi ξi, w0 j,r | {z } A8 where the second equality is by definition of W . The third equality is by Lemma C.6. We next bound |A6| σ0 µ p 2 log(8m/δ)α = e O(σ0 µ ) r=1 |wk yi,r, µyi| | w0 yi,r, µyi | + 2 log(2/ϵ) = e O(σ0 µ ) |A8| e O(σ0σξ where we use the global bound on the inner product by e O(1). Next, by Theorem C.1 and Lemma C.10, we can show 1 m Pm r=1 | wk yi,r, µyi | 2 for all i [n] and we can lower bound A5 4 log(4/ϵ) and thus yi f(Wk, xi), W 4 log(4/ϵ) 2 log(4/ϵ) = 2 log(4/ϵ) (12) where we bound |A6| + |A7| + |A8| 2 log(4/ϵ) under Condition C.1. Further, we derive Wk W 2 Wk+1 W 2 = 2η LS(Wk), Wk W η2 LS(Wk) 2 i=1 ℓ k i yi 2f(Wk, xi) f(Wk, xi), W η2 LS(Wk) 2 Published as a conference paper at ICLR 2025 i=1 ℓ k i 2yif(Wk, xi) 2 log(2/ϵ) η2 LS(Wk) 2 ℓ(yif(Wk, xi)) ϵ/4 η2 LS(Wk) 2 2ηLS(Wk) ηϵ where the first inequality is by (12) and the second inequality is by convexity of cross-entropy function and the last inequality is by Lemma C.8. Before proving the second stage convergence, we require the following lemma in order to bound the ratio of loss derivatives among different samples. Lemma C.13 (Kou et al. (2023)). Let g(z) = ℓ (z) = (1 + exp(z)) 1. Then for any z2 c z1 1 where c 0, we have g(z1)/g(z2) exp(c). Theorem C.2. Under the same settings as in Theorem C.1, let T = T1 + WT1 W 2 ηϵ = T1 + O(η 1ϵ 1m µ 2). Then we have there exists T1 k T such that LS(Wk) 0.1. maxj,r,i | wk j,r, ξi | = o(1) for all T1 k T. maxr | wk j,r, µj | 2 for all j = 1, T1 k T. Proof of Theorem C.2. By Lemma C.12, for any T1 k T, we have Wk W 2 Wk+1 W 2 2ηLS(Wk) ηϵ for all s k. 
Then summing over the inequality gives k=T1 LS(Wk) WT1 W 2 2η(T T1 + 1) + ϵ where the last inequality is by the choice T = T1 + WT1 W 2 ηϵ = T1 + Ω(η 1ϵ 1m log(1/ϵ) µ 2). Then we can claim that there exists a k [T1, T] such that LS(Wk) ϵ. Setting ϵ = 0.1 shows the desired convergence. Next, we show the upper bound on maxj,r,i | wk j,r, ξi | for all k [T1, T]. Notice that by Proposition C.1, we already have maxj,r | wk yi,r, ξi | ϑ, where we let ϑ := 3 max{maxr,i | w T1 yi,r, ξi |, β, 4 q d nα}. Then we only focus on bounding maxyi,i | wk j,r, ξi |. From the scale difference at T1, we know that ϑ = e O(max{n 1/2, σ0σξ d, σ0 µ , nd 1/2}) = o(1). Next, we can bound k=T1 LS(Wk) WT1 W 2 η = O(η 1m log(1/ϵ) µ 2) (13) where we use Lemma C.11 for the last equality. Then, we first prove maxr,i |ρk yi,r,i| 2ϑ for all T1 k T. First it is easy to see that at T1, we have max r,i |ρT1 yi,r,i| max r,i | w T1 yi,r, ξi | + max r,i | w0 yi,r, ξi | + 4 Then suppose there e T [T1, T] such that maxr,i |ρT1 yi,r,i| 2ϑ for all k [T1, e T 1]. Now we let ϕk := maxr,i |ρk yi,r,i| and thus by the update of noise coefficient ϕk+1 ϕk + η nm|ℓ k i | ϕk + β/3 + 4 Published as a conference paper at ICLR 2025 ϕk + η nm max i |ℓ k i | ϕk + β/3 + 4 d nα O(σ2 ξd). where we use Lemma C.5 in the first inequality. Then taking the summation from T1 to e T gives ϕ e T ϕT1 + η nm k=T1 max i |ℓ k i |O(σ2 ξd)ϑ (14) where the first inequality is by the induction condition. Next, the aim is bound P e T 1 k=T1 maxi |ℓ k i |. 
First, for any i, i [n] such that yi = yi , we can bound for all T1 k e T 1 yif(Wk, xi) yi f(Wk, xi ) = Fyi(Wk yi, xi) F yi(Wk yi, xi)) Fyi (Wk yi , xi ) + F yi (Wk yi , xi )) wk yi,r, µyi 2 + wk yi,r, ξi 2 1 wk yi,r, µyi 2 + wk yi,r, ξi 2 + 1/C1 wk yi,r, ξi 2 wk yi,r, ξi 2 + 1/C1 max r,i wk yi,r, ξi 2 + 1/C1 max r,i |ρk yi,r,i| + max r,i | w0 yi,r, ξi | + 4 where in the first inequality we notice that F yi(Wk yi, xi)) 0, yi = yi and we recall that F yi(Wk j , xi) β2 + β + 12 q d nα 2 = 1/C1 for some sufficiently large constant C1 > 0. The second last inequality is by induction condition and the last inequality is by choosing ϑ 1/6. Then we can bound the ratio of loss derivatives (based on Lemma C.13) that |ℓ k i |/|ℓ k i | exp yif(Wk, xi) yi f(Wk, xi )) exp(ϑ) This suggests 1 O(ϑ) |ℓ k i |/|ℓ k i | 1 + O(ϑ) for all i, i [n], T1 k e T 1. Then let i = arg maxi |ℓ k i |, we have T1 max i |ℓ k i | = T1 Θ( 1 |Syi | i Syi |ℓ k i |) T1 Θ( 1 |Syi | i Syi ℓk i ) T1 Θ( n |Syi |LS(Wk)) = e O(η 1m log(1/ϵ) µ 2) (15) where the first inequality is by |ℓ | ℓand the last equality is from (13) and |Syi | 0.49n (based on Lemma B.1). This allows to bound (14) as ϕ e T ϕT1 + η nm s=T1 max i |ℓ k i |O(σ2 ξd)ϑ ϕT1 + O(n 1σ2 ξd log(1/ϵ) µ 2) ϑ ϑ + O(n 1SNR 2 log(1/ϵ)) ϑ 2ϑ and the second inequality is by (15) and the last inequality is by setting ϵ = 0.1 and n SNR2 C for sufficiently large constant C . Thus, we have maxr,i | w e T yi,r, ξi | maxr,i |ρ e T yi,r,i| + β + d nα 3ϑ = o(1). The lower bound on signal inner product is directly from Lemma C.10. Published as a conference paper at ICLR 2025 C.4 NOISE MEMORIZATION We also analyze the setting where n 1SNR 2 C for some constant C > 0, which allows the noise memorization to dominate signal learning, thus reaching harmful overfitting. We first require the following anti-concentration result for the noise inner product, which is required to ensure the sign invariance of the inner product along training. Lemma C.14. 
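Suppose $\delta > 0$ and $\sigma_0 \ge \Omega(\log(n^2/\delta)\,n^2 m\alpha d^{-1}\sigma_\xi^{-1})$. Then we have, for all $j = \pm 1$, $r \in [m]$, $i \in [n]$,

$$|\langle \mathbf{w}_{j,r}^{0}, \boldsymbol{\xi}_i\rangle| \ge 8\sqrt{\frac{\log(4n^2/\delta)}{d}}\,n\alpha.$$

Proof of Lemma C.14. For any $j = \pm 1$, $r \in [m]$, $i \in [n]$, we have $\langle \mathbf{w}_{j,r}^{0}, \boldsymbol{\xi}_i\rangle \sim \mathcal{N}(0, \sigma_0^2\|\boldsymbol{\xi}_i\|^2)$. Then applying Lemma C.3 by setting the RHS to $\delta/(2mn)$ and $c = 8\sqrt{\frac{\log(4n^2/\delta)}{d}}\,n\alpha$, we require

$$d^2 \ge 42\log(4n^2/\delta)\,n^2\alpha^2\sigma_0^{-2}\sigma_\xi^{-2}\Big/\log\Big(\frac{4m^2n^2}{4m^2n^2 - \delta^2}\Big),$$

where we use Lemma B.2 that $\|\boldsymbol{\xi}_i\|^2 \ge 0.99\sigma_\xi^2 d$. Finally, noticing that $1/\log\big(4m^2n^2/(4m^2n^2-\delta^2)\big) \le \Theta(m^2n^2)$ and taking the union bound completes the proof.

An immediate consequence of Lemma C.14 is the following result, which allows us to derive the sign invariance of $\langle \mathbf{w}_{y_i,r}^{k}, \boldsymbol{\xi}_i\rangle$ for all iterations.

Lemma C.15. Under Condition C.1, for any $i \in [n]$, $r \in [m]$, we have $\operatorname{sign}(\langle \mathbf{w}_{y_i,r}^{k}, \boldsymbol{\xi}_i\rangle) = \operatorname{sign}(\rho_{y_i,r,i}^{k}) = \operatorname{sign}(\langle \mathbf{w}_{y_i,r}^{0}, \boldsymbol{\xi}_i\rangle)$ for all $0 \le k \le T^*$.

Proof of Lemma C.15. First, by Lemma C.14 and Lemma C.5, we can bound, if $\langle \mathbf{w}_{y_i,r}^{0}, \boldsymbol{\xi}_i\rangle \ge 0$,

$$\rho_{y_i,r,i}^{k} + \tfrac{1}{2}\langle \mathbf{w}_{y_i,r}^{0}, \boldsymbol{\xi}_i\rangle \le \langle \mathbf{w}_{y_i,r}^{k}, \boldsymbol{\xi}_i\rangle \le \rho_{y_i,r,i}^{k} + \tfrac{3}{2}\langle \mathbf{w}_{y_i,r}^{0}, \boldsymbol{\xi}_i\rangle,$$

and if $\langle \mathbf{w}_{y_i,r}^{0}, \boldsymbol{\xi}_i\rangle \le 0$,

$$\rho_{y_i,r,i}^{k} + \tfrac{3}{2}\langle \mathbf{w}_{y_i,r}^{0}, \boldsymbol{\xi}_i\rangle \le \langle \mathbf{w}_{y_i,r}^{k}, \boldsymbol{\xi}_i\rangle \le \rho_{y_i,r,i}^{k} + \tfrac{1}{2}\langle \mathbf{w}_{y_i,r}^{0}, \boldsymbol{\xi}_i\rangle.$$

Next we use induction to show the sign invariance. First, it is clear that when $k = 0$ the sign invariance is trivially satisfied. At $k = 1$, we have by the iterative update of the coefficients,

$$\rho_{y_i,r,i}^{1} = \rho_{y_i,r,i}^{0} + \frac{\eta}{nm}|\ell_i'^{0}|\,\langle \mathbf{w}_{y_i,r}^{0}, \boldsymbol{\xi}_i\rangle\|\boldsymbol{\xi}_i\|^2 = \frac{\eta}{nm}|\ell_i'^{0}|\,\langle \mathbf{w}_{y_i,r}^{0}, \boldsymbol{\xi}_i\rangle\|\boldsymbol{\xi}_i\|^2,$$

and thus $\operatorname{sign}(\rho_{y_i,r,i}^{1}) = \operatorname{sign}(\langle \mathbf{w}_{y_i,r}^{0}, \boldsymbol{\xi}_i\rangle)$. Further, by Lemma C.5, and assuming without loss of generality that $\langle \mathbf{w}_{y_i,r}^{0}, \boldsymbol{\xi}_i\rangle \ge 0$, we have

$$\langle \mathbf{w}_{y_i,r}^{1}, \boldsymbol{\xi}_i\rangle \ge \rho_{y_i,r,i}^{1} + \langle \mathbf{w}_{y_i,r}^{0}, \boldsymbol{\xi}_i\rangle - 4\sqrt{\frac{\log(4n^2/\delta)}{d}}\,n\alpha \ge \rho_{y_i,r,i}^{1} + \tfrac{1}{2}\langle \mathbf{w}_{y_i,r}^{0}, \boldsymbol{\xi}_i\rangle \ge 0.$$

A similar argument also holds for $\langle \mathbf{w}_{y_i,r}^{0}, \boldsymbol{\xi}_i\rangle < 0$. This shows that at $k = 1$, $\operatorname{sign}(\rho_{y_i,r,i}^{1}) = \operatorname{sign}(\langle \mathbf{w}_{y_i,r}^{1}, \boldsymbol{\xi}_i\rangle) = \operatorname{sign}(\langle \mathbf{w}_{y_i,r}^{0}, \boldsymbol{\xi}_i\rangle)$. Suppose there exists a time $\tilde{T}$ such that for all $k \le \tilde{T}-1$, the sign invariance holds.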
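Then for $k = \tilde{T}$, suppose $\operatorname{sign}(\langle \mathbf{w}_{y_i,r}^{\tilde{T}-1}, \boldsymbol{\xi}_i\rangle) = \operatorname{sign}(\rho_{y_i,r,i}^{\tilde{T}-1}) = \operatorname{sign}(\langle \mathbf{w}_{y_i,r}^{0}, \boldsymbol{\xi}_i\rangle) = +1$. Then

$$\begin{aligned}
\rho_{y_i,r,i}^{\tilde{T}} &= \rho_{y_i,r,i}^{\tilde{T}-1} + \frac{\eta}{nm}|\ell_i'^{\tilde{T}-1}|\,\langle \mathbf{w}_{y_i,r}^{\tilde{T}-1}, \boldsymbol{\xi}_i\rangle\|\boldsymbol{\xi}_i\|^2 \\
&\ge \rho_{y_i,r,i}^{\tilde{T}-1} + \frac{\eta}{nm}|\ell_i'^{\tilde{T}-1}|\Big(\rho_{y_i,r,i}^{\tilde{T}-1} + \langle \mathbf{w}_{y_i,r}^{0}, \boldsymbol{\xi}_i\rangle - 4\sqrt{\frac{\log(4n^2/\delta)}{d}}\,n\alpha\Big)\|\boldsymbol{\xi}_i\|^2 \\
&\ge \rho_{y_i,r,i}^{\tilde{T}-1} + \frac{\eta}{nm}|\ell_i'^{\tilde{T}-1}|\Big(\rho_{y_i,r,i}^{\tilde{T}-1} + \tfrac{1}{2}\langle \mathbf{w}_{y_i,r}^{0}, \boldsymbol{\xi}_i\rangle\Big)\|\boldsymbol{\xi}_i\|^2
\end{aligned}$$

and

$$\langle \mathbf{w}_{y_i,r}^{\tilde{T}}, \boldsymbol{\xi}_i\rangle \ge \rho_{y_i,r,i}^{\tilde{T}} + \langle \mathbf{w}_{y_i,r}^{0}, \boldsymbol{\xi}_i\rangle - 4\sqrt{\frac{\log(4n^2/\delta)}{d}}\,n\alpha \ge \rho_{y_i,r,i}^{\tilde{T}} + \tfrac{1}{2}\langle \mathbf{w}_{y_i,r}^{0}, \boldsymbol{\xi}_i\rangle \ge 0,$$

which completes the induction that $\operatorname{sign}(\langle \mathbf{w}_{y_i,r}^{\tilde{T}}, \boldsymbol{\xi}_i\rangle) = \operatorname{sign}(\rho_{y_i,r,i}^{\tilde{T}}) = \operatorname{sign}(\langle \mathbf{w}_{y_i,r}^{0}, \boldsymbol{\xi}_i\rangle)$. A similar argument holds when $\operatorname{sign}(\langle \mathbf{w}_{y_i,r}^{0}, \boldsymbol{\xi}_i\rangle) = -1$.

We also derive the following concentration result for the average noise inner product at initialization.

Lemma C.16. Suppose $\delta > 0$ and $m = \Omega(\log(n/\delta))$. Then with probability at least $1-\delta$, we have for all $j = \pm 1$, $i \in [n]$,

$$\sigma_0\sigma_\xi\sqrt{d}/2 \le \frac{1}{m}\sum_{r=1}^{m} |\langle \mathbf{w}_{j,r}^{0}, \boldsymbol{\xi}_i\rangle| \le \sigma_0\sigma_\xi\sqrt{d}.$$

Proof of Lemma C.16. First notice that for any $i \in [n]$, $\langle \mathbf{w}_{j,r}^{0}, \boldsymbol{\xi}_i\rangle \sim \mathcal{N}(0, \sigma_0^2\|\boldsymbol{\xi}_i\|^2)$ and thus we have $\mathbb{E}[|\langle \mathbf{w}_{j,r}^{0}, \boldsymbol{\xi}_i\rangle|] = \sigma_0\|\boldsymbol{\xi}_i\|\sqrt{2/\pi}$. By the sub-Gaussian tail bound, with probability at least $1-\delta/(2n)$, for any $i \in [n]$,

$$\Big|\frac{1}{m}\sum_{r=1}^{m} |\langle \mathbf{w}_{j,r}^{0}, \boldsymbol{\xi}_i\rangle| - \sigma_0\|\boldsymbol{\xi}_i\|\sqrt{2/\pi}\Big| \le \sqrt{\frac{2\sigma_0^2\|\boldsymbol{\xi}_i\|^2\log(4n/\delta)}{m}}.$$

Choosing $m = \Omega(\log(n/\delta))$ makes the right-hand side at most a small constant multiple of $\sigma_0\|\boldsymbol{\xi}_i\|$. From Lemma B.2, we have $0.99\sigma_\xi\sqrt{d} \le \|\boldsymbol{\xi}_i\| \le 1.01\sigma_\xi\sqrt{d}$ by choosing $d = \widetilde{\Omega}(1)$ sufficiently large. Then we have $\sigma_0\sigma_\xi\sqrt{d}/2 \le \frac{1}{m}\sum_{r=1}^{m} |\langle \mathbf{w}_{j,r}^{0}, \boldsymbol{\xi}_i\rangle| \le \sigma_0\sigma_\xi\sqrt{d}$. Finally, taking the union bound for all $j = \pm 1$, $i \in [n]$ completes the proof.

We have established several preliminary lemmas that hold with high probability, including Lemma B.1, Lemma B.2, Lemma C.1, Lemma C.14 and Lemma C.16. We let $E_{\mathrm{prelim}}$ be the event that all the results in these lemmas hold for a given $\delta$. Then by applying a union bound, we have $\mathbb{P}(E_{\mathrm{prelim}}) \ge 1 - 5\delta$. The subsequent analysis is conditioned on the event $E_{\mathrm{prelim}}$.

C.4.1 FIRST STAGE

Theorem C.3. Under Condition C.1, suppose $n^{-1}\cdot\mathrm{SNR}^{-2} \ge C'$ for some constant $C' > 0$.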
Then there exists a time T1 = eΘ(η 1nmσ 2 ξ d 1), such that (1) maxr | w T1 yi,r, ξi | 2 for all i [n], (2) 1 m Pm r=1 | w T1 yi,r, ξi | 4 for all i [n] and (3) maxj,r,y | w T1 j,r, µy | = e O(n 1/2). Proof of Theorem C.3. We first bound the growth of signal as follows. From the gradient descent update, we have | wk j,r, µj | = | wk 1 j,r , µj | + η|Sj| nm | wk 1 j,r , µj | µ 2 1 + 0.51η µ 2 | wk 1 j,r , µj | 1 + 0.51η µ 2 k | w0 j,r, µj | (16) where the first inequality is by |ℓ k i | 1 and the second inequality is by Lemma B.1 with n = eΩ(1) sufficiently large. Published as a conference paper at ICLR 2025 On the other hand, for the growth of noise, we have from the inner product update, for any i [n] wk yi,r, ξi = wk 1 yi,r , ξi η nm i =1 ℓ k 1 i wk 1 j,r , ξi ξi , ξi = 1 η nmℓ k i ξi 2 wk 1 yi,r , ξi η nm i =i ℓ k 1 i wk 1 yi,r , ξi ξi , ξi Then this suggests | wk yi,r, ξi | 1 η nmℓ k i ξi 2 | wk 1 yi,r , ξi | η nm i =i |ℓ k 1 i | | wk 1 yi,r , ξi | | ξi , ξi | (17) We first prove for any i [n], maxr | wk+1 yi,r , ξi | maxr | wk yi,r, ξi | maxr | w0 yi,r, ξi | for all k T1. We prove such a result by induction. It is clear that at k = 0, the result is satisfied. Now suppose there exists an iteration k such that max r | wk yi,r, ξi | max r | w0 yi,r, ξi | σ0σξ for all k k 1, where the last inequality is by Lemma C.1. Then we can bound based on Lemma C.7 and Lemma B.2, we have for any i = i [n] and n|ℓ k 1 i | | ξi, ξi | | w k 1 yi,r , ξi | |ℓ k 1 i | ξi 2 maxr | w k 1 yi,r , ξi | 2σ2 ξ p d log(4n2/δ) Cℓ0.99σ2 ξd nασ 1 0 σ 1 ξ d 1/2 = 8.4C 1 ℓ nα dσ0σξ 0.01 (18) where we use the lower and upper bound on loss derivatives during the first stage, as well as Lemma B.2 and Lemma C.1. The last inequality is by σ0 840n C 1 ℓ d 1σ 1 ξ α p log(4n2/δ). 
Then we have max r | w k yi,r, ξi | 1 η nmℓ k 1 i ξi 2 max r | w k 1 yi,r , ξi | η nm i =i |ℓ k 1 i | | w k 1 yi,r , ξi | | ξi , ξi | 1 + η nm0.99|ℓ k 1 i | ξi 2 max r w k 1 yi,r , ξi w k 1 yi,r , ξi max r | w0 yi,r, ξi | Let Bk i := maxr | wk yi,r, ξi | and we obtain for any k T1, Bk i 1 + η nm0.99|ℓ k 1 i | ξi 2 Bk 1 i 1 + ησ2 ξd nm 0.98Cℓ Bk 1 i 1 + ησ2 ξd nm 0.98Cℓ k B0 i 1 + ησ2 ξd nm 0.98Cℓ k σ0σξ where we use (18), which holds for iteration k and Lemma C.1. Consider T1 = log(8σ 1 0 σ 1 ξ d 1/2)/ log 1 + ησ2 ξd nm 0.98Cℓ = Θ(η 1nmσ 2 ξ d 1 log(8σ 1 0 σ 1 ξ d 1/2)) for η sufficiently small. Then it can be shown that BT1 i = max r | w T1 yi,r, ξi | 2 Published as a conference paper at ICLR 2025 In addition, we show the average also grows to a constant order with a similar argument. In particular, from (17), we have r=1 | wk yi,r, ξi | 1 η nmℓ k i ξi 2 1 r=1 | wk 1 yi,r , ξi | i =i |ℓ k 1 i | 1 r=1 | wk 1 yi,r , ξi | | ξi , ξi | Using a similar induction argument, we can show r=1 | wk yi,r, ξi | 1 r=1 | wk 1 yi,r , ξi | 1 r=1 | w0 yi,r, ξi | σ0σξ for all k T1, where the last inequality follows from Lemma C.16. Then we can show at T1, r=1 | w T1 yi,r, ξi | 1 + ησ2 ξd nm 0.98Cℓ In the meantime, (16) allows to bound the growth of signal learning as for any j = 1, max r | w T1 j,r, µj | 1 + 0.51η µ 2 2 log(8m/δ)σ0 µ = exp log(1 + 0.51 η µ 2 log(1 + 0.98 ησ2 ξd Cℓ nm ) log 8σ 1 0 σ 1 ξ d 1/2 p 2 log(8m/δ)σ0 µ exp 0.53C 1 ℓ n SNR2 + e O(n 1SNR 2η) log 8σ 1 0 σ 1 ξ d 1/2 p 2 log(8m/δ)σ0 µ 2 log(8m/δ)SNR = e O(n 1/2) where the first inequality is by Lemma C.1 and the second inequality is by Taylor expansion around η = 0. The third inequality is by choosing η sufficiently small and based on the condition that n 1SNR 2 0.55C 1 ℓ . The last equality is by the SNR condition. C.4.2 SECOND STAGE We choose W to be w j,r = w0 j,r + 2 log(4/ϵ) i=1 1(yi = j)sign( w0 j,r, ξi ) ξi ξi 2 First we show the invariance of sign of noise inner product after the first stage. Lemma C.17. 
Under the same settings as in Theorem C.3, we have maxr | wk yi,r, ξi | 1 and 1 m Pm r=1 | wk yi,r, ξi | 2 for all T1 k T and any i [n]. Proof of Lemma C.17. In addition to the two results, we also prove maxr |ρk yi,r,i| 1.5 and 1 m Pm r=1 |ρk yi,r,i| 3. We prove these results by induction. First, it is clear that at k = T1, the bound regarding inner products are trivially satisfied by Theorem C.3. Then by Lemma C.5, we have max r |ρT1 yi,r,i| max r | w T1 yi,r, ξi | β 4 d nα 2 0.5 = 1.5 r=1 |ρT1 yi,r,i| 1 r=1 | w T1 yi,r, ξi | β 4 d nα 4 1 = 3 Published as a conference paper at ICLR 2025 where the last inequalities are by Condition C.1 for sufficiently large constant C. Now suppose there exists a time T1 e T T such that the results hold for all k e T 1. Then at k = e T, recall the coefficient update as ρ e T yi,r,i = ρ e T 1 yi,r,i + η nm|ℓ e T 1 i | w e T 1 yi,r , ξi ξi 2 (19) If w0 yi,r, ξi > 0, by Lemma C.15 we have w e T 1 yi,r , ξi , ρ e T 1 yi,r,i > 0. Then ρ e T yi,r,i = ρ e T 1 yi,r,i + η nm|ℓ e T 1 i | w e T 1 yi,r , ξi ξi 2 ρ e T 1 yi,r,i + η nm|ℓ e T 1 i | ρ e T 1 yi,r,i + w0 yi,r, ξi 4 ρ e T 1 yi,r,i + η nm|ℓ e T 1 i | ρ e T 1 yi,r,i + 1 2 w0 yi,r, ξi ξi 2. Then taking maximum over r, max r |ρ e T yi,r,i| max r |ρ e T 1 yi,r,i| + η ξi 2 2nm |ℓ e T 1 i | max r |ρ e T 1 yi,r,i| max r |ρ e T 1 yi,r,i| 1.5 where the first inequality follows from w0 yi,r, ξi /2 0.5 maxr |ρ e T 1 yi,r,i|/2 based on Condition C.1. Similarly, when w0 yi,r, ξi < 0, we can obtain the same result. Then, we have max r | w e T yi,r, ξi | max r |ρ e T yi,r,i| β 4 d nα 1.5 0.5 = 1. Furthermore, we prove the results for the average quantities in a similar manner. 
First, from the coefficient update, and by Lemma C.15, sign(ρ e T 1 yi,r,i) = sign( w e T 1 yi,r , ξi ) and thus taking the average of absolute value on both sides of (19), we get r=1 |ρ e T yi,r,i| = 1 r=1 |ρ e T 1 yi,r,i| + η nm|ℓ e T 1 i | 1 r=1 | w e T 1 yi,r , ξi | ξi 2 r=1 |ρ e T 1 yi,r,i| + η nm|ℓ e T 1 i | 1 r=1 |ρ e T 1 yi,r,i| β 4 r=1 |ρ e T 1 yi,r,i| + η 2nm|ℓ e T 1 i | 1 r=1 |ρ e T 1 yi,r,i| ξi 2 r=1 |ρ e T 1 yi,r,i| 3 where we use |a + b| = |a| + |b| when sign(a) = sign(b). Then, we have r=1 | w e T yi,r, ξi | 1 r=1 |ρ e T yi,r,i| β 4 d nα 3 1 = 2. where the inequality is by Condition C.1. Lemma C.18. Under Condition C.1, we have WT1 W = O( nm log(1/ϵ)σ 1 ξ d 1/2). Proof of Lemma C.18. The proof follows similarly as in Lemma C.11. Let Pξ be the projection matrix to the direction of ξ, i.e., Pξ = ξξ ξ 2 . Then we can represent wk j,r w0 j,r = Pµ1(wk j,r w0 j,r) + Pµ 1(wk j,r w0 j,r) + i=1 Pξi(wk j,r w0 j,r) Published as a conference paper at ICLR 2025 + I Pµ1 Pµ 1 i=1 Pξi (wk j,r w0 j,r). By the scale difference at T1 and the fact that gradient descent only updates in the direction of µj, j = 1 and ξi, we can bound w T1 j,r w0 j,r, µ1 2 µ 2 + w T1 j,r w0 j,r, µ 1 2 w T1 j,r w0 j,r, ξi 2 i=1 Pξi (w T1 j,r w0 j,r) 2m 2 w T1 j,r, µj 2 + 2 w T1 j,r, µ j 2 + 2 w0 j,r, µ j 2 + 2 w0 j,r, µj 2 + n max j,r 2 w T1 j,r, ξi 2 + 2 w0 j,r, ξi 2 i=1 Pξi (w T1 j,r w0 j,r) O(mnσ 2 ξ d 1) where we use the scale difference at T1. Therefore, WT1 W WT1 W0 + W0 W O( mnσ 1 ξ d 1/2) + O( nm log(1/ϵ)σ 1 ξ d 1/2) O( nm log(1/ϵ)σ 1 ξ d 1/2) where we use the definition of W . Lemma C.19. Under Condition C.1, we have for all T1 k T Wk W 2 Wk+1 W 2 2ηLS(Wt) ηϵ Proof of Lemma C.19. The proof follows from similar arguments as for Lemma C.12. We first obtain a lower bound on yi f(Wt, xi), W for any i [n] for all T1 k T . 
yi f(Wk, xi), W = 1 j,r jyi wk j,r, µyi µyi, w j,r + 1 j,r jyi wk j,r, ξi ξi, w j,r j,r jyi wk j,r, µyi µyi, w0 j,r + 1 j,r jyi wk j,r, ξi ξi, w0 j,r i =1 jyi wk j,r, ξi 1(j = yi ) ξi, ξi ξi 2 2 log(4/ϵ) r=1 | wk yi,r, ξi |2 log(4/ϵ) i =i wk yi,r, ξi 2 log(4/ϵ) ξi, ξi ξi 2 | {z } A10 j,r jyi wk j,r, µyi µyi, w0 j,r j,r jyi wk j,r, ξi ξi, w0 j,r where the second equality is by definition of W . The third equality is by Lemma C.17 and Lemma C.15 on the sign invariance. We next bound based on the scale difference and Lemma B.2, |A10| = e O(nd 1/2), |A11| = e O(σ0 µ ), |A12| e O(σ0σξ Published as a conference paper at ICLR 2025 where we use the global bound on the inner product by e O(1). Next, by Theorem C.3 and Lemma C.17, we can show 1 m Pm r=1 | wk yi,r, µyi | 2 for all i [n], k T1 and we can bound A9 4 log(4/ϵ) Combining the bound for A9, A10, A11, A12, we have yi f(Wk, xi), W 2 log(4/ϵ) (20) where we bound |A10| + |A11| + |A12| 2 log(4/ϵ) under Condition C.1. Further, we derive Wk W 2 Wk+1 W 2 = 2η LS(Wk), Wk W η2 LS(Wk) 2 i=1 ℓ k i yi 2f(Wk, xi) f(Wk, xi), W η2 LS(Wk) 2 i=1 ℓ k i 2yif(Wk, xi) 2 log(2/ϵ) η2 LS(Wk) 2 ℓ(yif(Wk, xi)) ϵ/4 η2 LS(Wk) 2 2ηLS(Wk) ηϵ where the first inequality is by (20) and the second inequality is by convexity of cross-entropy function and the last inequality is by Lemma C.8. Theorem C.4. Under the same settings as in Theorem C.3, let T = T1 + WT1 W 2 ηϵ = T1 + O(η 1ϵ 1mnσ 2 ξ d 1). Then we have there exists T1 k T such that LS(Wk) 0.1. maxj,r,y | wk j,r, µy | = o(1) for all T1 k T. maxr | wk yi,r, ξi | 1 for all i [n], T1 k T. Proof of Theorem C.4. The proof is similar as in Theorem C.2. By Lemma C.19, for any T1 k T, we have Wk W 2 Wk+1 W 2 2ηLS(Wk) ηϵ for all s k. Then summing over the inequality gives k=T1 LS(Wk) WT1 W 2 2η(T T1 + 1) + ϵ where the last inequality is by the choice T = T1 + WT1 W 2 ηϵ = T1 + Ω(η 1ϵ 1nm3 log(1/ϵ)σ 2 ξ d 1). Then we can claim that there exists a k [T1, T] such that LS(Wk) ϵ. 
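Setting $\epsilon = 0.1$ shows the desired convergence.

Next, we show the upper bound on $\max_{j,y,r} |\langle \mathbf{w}_{j,r}^{k}, \boldsymbol{\mu}_y\rangle|$ for all $k \in [T_1, T]$. Notice that by Proposition C.1, we already have $\max_{j,r} |\langle \mathbf{w}_{j,r}^{k}, \boldsymbol{\mu}_{-j}\rangle| \le \vartheta$, where we let

$$\vartheta := 3\max\Big\{\max_{j,r} |\langle \mathbf{w}_{j,r}^{T_1}, \boldsymbol{\mu}_j\rangle|,\ \beta,\ 4\sqrt{\frac{\log(4n^2/\delta)}{d}}\,n\alpha\Big\} = \widetilde{O}\big(\max\{n^{-1/2},\ \sigma_0\sigma_\xi\sqrt{d},\ \sigma_0\|\boldsymbol{\mu}\|,\ nd^{-1/2}\}\big).$$

Subsequently, we use induction to prove $\max_{j,r} |\langle \mathbf{w}_{j,r}^{k}, \boldsymbol{\mu}_j\rangle| \le 2\vartheta$. First we notice that

$$\sum_{k=T_1}^{T} L_S(\mathbf{W}^k) \le \frac{\|\mathbf{W}^{T_1} - \mathbf{W}^*\|^2}{\eta} = O(\eta^{-1}nm\sigma_\xi^{-2}d^{-1}), \tag{21}$$

where the equality is by Lemma C.18 with the choice $\epsilon = 0.1$. At $k = T_1$, we have $\max_{j,r} |\langle \mathbf{w}_{j,r}^{T_1}, \boldsymbol{\mu}_j\rangle| \le \vartheta \le 2\vartheta$. Suppose there exists $\tilde{T} \in [T_1, T]$ such that $\max_{j,r} |\langle \mathbf{w}_{j,r}^{k}, \boldsymbol{\mu}_j\rangle| \le 2\vartheta$ for all $k \in [T_1, \tilde{T}-1]$. Now we let $\Psi^k := \max_{j,r} |\langle \mathbf{w}_{j,r}^{k}, \boldsymbol{\mu}_j\rangle|$, and thus by the update of the inner product,

$$\Psi^{k+1} \le \Psi^k + \frac{\eta}{nm}\sum_{i\in S_j} |\ell_i'^{k}|\,\Psi^k\|\boldsymbol{\mu}\|^2 \le \Psi^k + \frac{2\eta}{nm}\sum_{i\in[n]} \ell_i^{k}\,\Psi^k\|\boldsymbol{\mu}\|^2 = \Psi^k + \frac{2\eta\|\boldsymbol{\mu}\|^2}{m}L_S(\mathbf{W}^k)\Psi^k,$$

where we use $|\ell'| \le \ell$ in the second inequality. Taking the summation from $T_1$ to $\tilde{T}$ gives

$$\Psi^{\tilde{T}} \le \Psi^{T_1} + \frac{2\eta\|\boldsymbol{\mu}\|^2}{m}\sum_{k=T_1}^{\tilde{T}-1} L_S(\mathbf{W}^k)\cdot 2\vartheta \le \Psi^{T_1} + O(n\cdot\mathrm{SNR}^2)\cdot 2\vartheta \le 2\vartheta,$$

where the second inequality is by (21) and the last inequality is by $n^{-1}\cdot\mathrm{SNR}^{-2} \ge C'$ for a sufficiently large constant $C' > 0$.

The lower bound for the noise inner product follows directly from Lemma C.17.

D DIFFUSION MODEL

For the analysis of the diffusion model, we restate Condition 3.1 specifically for the case of the diffusion model.

Condition D.1. Suppose $\delta > 0$ and the following conditions hold.

1. The dimension $d$ satisfies $d = \widetilde{\Omega}(\max\{n^4,\ n\sigma_\xi^{-1}\})$.
2. The training sample size satisfies $n = \Omega(\log(m/\delta))$ and the network width satisfies $m = \Theta(1)$.
3. The initialization $\sigma_0$ satisfies $\widetilde{O}(n\sigma_\xi^{-1}d^{-5/4}) \le \sigma_0 \le \widetilde{O}(\min\{m^{-1/6}d^{-1/6}\sigma_\xi^{1/3}n^{-1/3},\ m^{-1/6}d^{-7/12}\sigma_\xi^{-1/3}n^{1/3},\ d^{-3/4}\sigma_\xi^{-1}n\})$.
4. The signal strength satisfies $\|\boldsymbol{\mu}\| = \Theta(1)$.
5. $\mathrm{SNR}^{-1} = \widetilde{O}(d^{1/4})$.
6. The noise coefficients $\alpha_t, \beta_t$ satisfy $\alpha_t, \beta_t = \Theta(1)$.

We make the following remarks on the conditions. Compared to the conditions required by classification, the diffusion model requires $m = \Theta(1)$ for the analysis of stationary points.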
The lower bound on the sample size $n$ is required for the concentration of $|S_1|$ and $|S_{-1}|$. The lower bound on $\sigma_0$ is required to ensure that the inner products of $\xi_i$ across samples remain small relative to the initialization. The constant order of the signal strength $\|\mu\|$ and the bound on $n \cdot \mathrm{SNR}^2$ are used to simplify the analysis. It is also worth mentioning that the diffusion model does not require a small learning rate for convergence.

D.1 USEFUL LEMMAS

Lemma D.1. Suppose $\delta > 0$. Then with probability at least $1 - \delta$, for any $t$,
$$\sigma_0^2 d\,\big(1 - \tilde{O}(d^{-1/2})\big) \le \|w^0_{r,t}\|^2 \le \sigma_0^2 d\,\big(1 + \tilde{O}(d^{-1/2})\big),$$
$$|\langle w^0_{r,t}, \mu_j\rangle| \le \sqrt{2\log(16m/\delta)}\,\sigma_0\|\mu_j\|,$$
$$|\langle w^0_{r,t}, \xi_i\rangle| \le 2\sqrt{\log(16mn/\delta)}\,\sigma_0\sigma_\xi\sqrt{d},$$
$$|\langle w^0_{r,t}, w^0_{r',t}\rangle| \le 2\sqrt{\log(16m^2/\delta)}\,\sigma_0^2\sqrt{d} \quad (r \neq r'),$$
for all $r, r' \in [m]$, $i \in [n]$, and $j = \pm 1$.

Proof of Lemma D.1. The proof is the same as in (Kou et al., 2023) and we include it here for completeness. Because at initialization $w^0_{r,t} \sim \mathcal{N}(0, \sigma_0^2 I)$, by Bernstein's inequality, with probability at least $1 - \delta/(8m)$, we have
$$\big|\|w^0_{r,t}\|_2^2 - \sigma_0^2 d\big| = O\big(\sigma_0^2\sqrt{d\log(16m/\delta)}\big).$$
Taking the union bound over all $r \in [m]$ then yields, with probability at least $1 - \delta/4$,
$$\sigma_0^2 d\,\big(1 - \tilde{O}(d^{-1/2})\big) \le \|w^0_{r,t}\|_2^2 \le \sigma_0^2 d\,\big(1 + \tilde{O}(d^{-1/2})\big).$$
Further, because $\langle w^0_{r,t}, \mu_j\rangle \sim \mathcal{N}(0, \sigma_0^2\|\mu_j\|_2^2)$ for $j = \pm 1$, the Gaussian tail bound and the union bound give, with probability at least $1 - \delta/4$, for all $j = \pm 1$ and $r \in [m]$,
$$|\langle w^0_{r,t}, \mu_j\rangle| \le \sqrt{2\log(16m/\delta)}\,\sigma_0\|\mu\|_2.$$
Finally, following a similar argument and noticing that $\|\xi_i\|_2^2 = \Theta(\sigma_\xi^2 d)$ and $\|w^0_{r,t}\|_2^2 = \Theta(\sigma_0^2 d)$, we have with probability at least $1 - \delta/4$ that for all $i \in [n]$,
$$|\langle w^0_{r,t}, \xi_i\rangle| \le 2\sqrt{\log(16mn/\delta)}\,\sigma_0\sigma_\xi\sqrt{d} \quad\text{and}\quad |\langle w^0_{r,t}, w^0_{r',t}\rangle| \le 2\sqrt{\log(16m^2/\delta)}\,\sigma_0^2\sqrt{d}.$$

D.2 DERIVATION OF LOSS FUNCTION AND GRADIENT

We first simplify the objective by taking the expectation over the added diffusion noise.

Lemma D.2.
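The DDPM loss can be simplified under expectation as
$$\mathbb{E}\big\|f_p(W_t, x^{(p)}_{t,i}) - \epsilon^{(p)}_{t,i}\big\|^2 = d + L^{(p)}_{1,i}(W_t) + L^{(p)}_{2,i}(W_t),$$
where
$$L^{(p)}_{1,i}(W_t) = \frac{1}{m}\sum_{r=1}^{m} \|w_{r,t}\|^2\Big(\alpha_t^4\langle w_{r,t}, x^{(p)}_{0,i}\rangle^4 + 6\alpha_t^2\beta_t^2\langle w_{r,t}, x^{(p)}_{0,i}\rangle^2\|w_{r,t}\|^2 + 3\beta_t^4\|w_{r,t}\|^4 - 4\sqrt{m}\,\alpha_t\beta_t\langle w_{r,t}, x^{(p)}_{0,i}\rangle\Big),$$
$$L^{(p)}_{2,i}(W_t) = \frac{1}{m}\sum_{r=1}^{m}\sum_{r' \neq r} \langle w_{r,t}, w_{r',t}\rangle\Big(\big(\alpha_t^2\langle w_{r,t}, x^{(p)}_{0,i}\rangle^2 + \beta_t^2\|w_{r,t}\|^2\big)\big(\alpha_t^2\langle w_{r',t}, x^{(p)}_{0,i}\rangle^2 + \beta_t^2\|w_{r',t}\|^2\big) + 2\beta_t^4\langle w_{r,t}, w_{r',t}\rangle^2 + 4\alpha_t^2\beta_t^2\langle w_{r,t}, x^{(p)}_{0,i}\rangle\langle w_{r',t}, x^{(p)}_{0,i}\rangle\langle w_{r,t}, w_{r',t}\rangle\Big),$$
corresponding to the learning of the $r$-th neuron and the alignment of the $r$-th neuron with other neurons, respectively. The double sum in $L^{(p)}_{2,i}$ runs over ordered pairs $(r, r')$ with $r' \neq r$.

Proof of Lemma D.2. Without loss of generality, we consider a single sample $x_{t,i}$. We first write the objective as
$$\mathbb{E}\big\|f_p(W_t, x^{(p)}_{t,i}) - \epsilon^{(p)}_{t,i}\big\|^2 = \underbrace{\mathbb{E}\big\|\epsilon^{(p)}_{t,i}\big\|^2}_{I_1} + \underbrace{\mathbb{E}\Big\|\frac{1}{\sqrt{m}}\sum_{r=1}^{m}\sigma\big(\langle w_{r,t}, x^{(p)}_{t,i}\rangle\big)w_{r,t}\Big\|^2}_{I_2} - 2\underbrace{\mathbb{E}\Big[\frac{1}{\sqrt{m}}\sum_{r=1}^{m}\sigma\big(\langle w_{r,t}, x^{(p)}_{t,i}\rangle\big)\langle w_{r,t}, \epsilon_{t,i}\rangle\Big]}_{I_3},$$
where we omit the subscript of the expectation for clarity. First, we can see $I_1 = d$. Then, writing $w_{r,t}[i']$ and $\epsilon_{t,i}[i']$ for the $i'$-th coordinates,
$$\begin{aligned}
I_3 &= \frac{1}{\sqrt{m}}\sum_{r=1}^{m}\mathbb{E}\Big[\langle w_{r,t}, x^{(p)}_{t,i}\rangle^2\langle w_{r,t}, \epsilon_{t,i}\rangle\Big] \\
&= \frac{1}{\sqrt{m}}\sum_{r=1}^{m}\sum_{i'=1}^{d}\mathbb{E}\Big[\langle w_{r,t}, x^{(p)}_{t,i}\rangle^2\, w_{r,t}[i']\,\epsilon_{t,i}[i']\Big] \\
&= \frac{2\beta_t}{\sqrt{m}}\sum_{r=1}^{m}\sum_{i'=1}^{d}\mathbb{E}\Big[\langle w_{r,t}, x^{(p)}_{t,i}\rangle\, w_{r,t}[i']^2\Big] \\
&= \frac{2\beta_t}{\sqrt{m}}\sum_{r=1}^{m}\|w_{r,t}\|^2\,\mathbb{E}\Big[\langle w_{r,t}, x^{(p)}_{t,i}\rangle\Big] \\
&= \frac{2\alpha_t\beta_t}{\sqrt{m}}\sum_{r=1}^{m}\|w_{r,t}\|^2\langle w_{r,t}, x^{(p)}_{0,i}\rangle,
\end{aligned}$$
where the third equality uses Stein's Lemma together with $x^{(p)}_{t,i} = \alpha_t x^{(p)}_{0,i} + \beta_t\epsilon_{t,i}$. Next, we consider $I_2$ by writing
$$I_2 = \frac{1}{m}\sum_{r=1}^{m}\mathbb{E}\Big[\langle w_{r,t}, x^{(p)}_{t,i}\rangle^4\Big]\|w_{r,t}\|^2 + \frac{1}{m}\sum_{r=1}^{m}\sum_{r' \neq r}\mathbb{E}\Big[\langle w_{r,t}, x^{(p)}_{t,i}\rangle^2\langle w_{r',t}, x^{(p)}_{t,i}\rangle^2\Big]\langle w_{r,t}, w_{r',t}\rangle.$$
Next, we compute the two terms $\mathbb{E}\big[\langle w_{r,t}, x^{(p)}_{t,i}\rangle^4\big]$ and $\mathbb{E}\big[\langle w_{r,t}, x^{(p)}_{t,i}\rangle^2\langle w_{r',t}, x^{(p)}_{t,i}\rangle^2\big]$ respectively. For notational simplicity, we let $a_r := \alpha_t\langle w_{r,t}, x^{(p)}_{0,i}\rangle$, $b_r := \beta_t\|w_{r,t}\|$ and $z_r := \beta_t\langle w_{r,t}, \epsilon_{t,i}\rangle$, so that $\langle w_{r,t}, x^{(p)}_{t,i}\rangle = a_r + z_r$. We first compute $\mathbb{E}[z_r] = 0$, $\mathbb{E}[z_r^2] = \beta_t^2\|w_{r,t}\|^2$ and $\mathbb{E}[z_r^4] = 3\beta_t^4\|w_{r,t}\|^4$.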
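The Stein's Lemma step for the cross term and the Gaussian moments $\mathbb{E}[z_r^2]$, $\mathbb{E}[z_r^4]$ above can be sanity-checked by simulation. The following is a minimal Monte Carlo sketch, not part of the paper's argument: the dimension, the vectors `w` and `x0`, and the values of `alpha` and `beta` are hypothetical toy choices, and the quadratic activation $\sigma(z) = z^2$ is assumed, as the derivation above suggests.

```python
import random


def dot(u, v):
    """Euclidean inner product of two equal-length lists."""
    return sum(a * b for a, b in zip(u, v))


def moment_check(num_samples=100_000, seed=0):
    """Monte Carlo estimates of E[z^2], E[z^4] and E[<w, x_t>^2 <w, eps>].

    Returns (estimates, closed_forms), each a 3-tuple in that order.
    """
    rng = random.Random(seed)
    # Hypothetical toy values, chosen only for illustration.
    alpha, beta = 0.8, 0.6
    w = [0.5, -0.3, 0.2, 0.4]
    x0 = [1.0, 0.5, -0.5, 0.25]

    norm2 = dot(w, w)            # ||w||^2
    a = alpha * dot(w, x0)       # a_r = alpha_t <w, x_0>

    sum_z2 = sum_z4 = sum_cross = 0.0
    for _ in range(num_samples):
        eps = [rng.gauss(0.0, 1.0) for _ in w]
        g = dot(w, eps)          # <w, eps>
        z = beta * g             # z_r = beta_t <w, eps>
        sum_z2 += z ** 2
        sum_z4 += z ** 4
        # <w, x_t> = a + z, so this is <w, x_t>^2 <w, eps>.
        sum_cross += (a + z) ** 2 * g

    estimates = (sum_z2 / num_samples,
                 sum_z4 / num_samples,
                 sum_cross / num_samples)
    closed_forms = (beta ** 2 * norm2,           # E[z^2]
                    3 * beta ** 4 * norm2 ** 2,  # E[z^4]
                    2 * beta * norm2 * a)        # 2 alpha beta ||w||^2 <w, x_0>
    return estimates, closed_forms


if __name__ == "__main__":
    est, exact = moment_check()
    for name, e, x in zip(("E[z^2]", "E[z^4]", "Stein term"), est, exact):
        print(f"{name}: estimate={e:.5f}, closed form={x:.5f}")
```

The fourth-moment estimate is the noisiest of the three, so a larger `num_samples` may be needed for tight agreement.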
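For the first term,
$$\begin{aligned}
\mathbb{E}\Big[\langle w_{r,t}, x^{(p)}_{t,i}\rangle^4\Big] &= \mathbb{E}\big[(a_r + z_r)^4\big] = \mathbb{E}\big[a_r^4 + 4a_r^3 z_r + 6a_r^2 z_r^2 + 4a_r z_r^3 + z_r^4\big] \\
&= a_r^4 + 6a_r^2\,\mathbb{E}[z_r^2] + \mathbb{E}[z_r^4] = a_r^4 + 6a_r^2 b_r^2 + 3b_r^4 \\
&= \alpha_t^4\langle w_{r,t}, x^{(p)}_{0,i}\rangle^4 + 6\alpha_t^2\beta_t^2\langle w_{r,t}, x^{(p)}_{0,i}\rangle^2\|w_{r,t}\|^2 + 3\beta_t^4\|w_{r,t}\|^4.
\end{aligned}$$
Next, for $\mathbb{E}_{\epsilon_{t,i}\sim\mathcal{N}(0,I)}\big[\langle w_{r,t}, \alpha_t x_{0,i} + \beta_t\epsilon_{t,i}\rangle^2\langle w_{r',t}, \alpha_t x_{0,i} + \beta_t\epsilon_{t,i}\rangle^2\big]$, we note that
$$\mathbb{E}[z_r z_{r'}] = \beta_t^2\,\mathbb{E}\big[\epsilon_{t,i}^\top w_{r,t} w_{r',t}^\top \epsilon_{t,i}\big] = \beta_t^2\langle w_{r,t}, w_{r',t}\rangle, \qquad \mathbb{E}[z_r z_{r'}^2] = 0,$$
$$\mathbb{E}[z_r^2 z_{r'}^2] = \mathbb{E}[z_r^2]\,\mathbb{E}[z_{r'}^2] + 2\,\mathbb{E}[z_r z_{r'}]^2 = \beta_t^4\|w_{r,t}\|^2\|w_{r',t}\|^2 + 2\beta_t^4\langle w_{r,t}, w_{r',t}\rangle^2,$$
where the second and third results follow from Isserlis' Theorem. Then we can simplify
$$\begin{aligned}
&\mathbb{E}\big[\langle w_{r,t}, \alpha_t x_{0,i} + \beta_t\epsilon_{t,i}\rangle^2\langle w_{r',t}, \alpha_t x_{0,i} + \beta_t\epsilon_{t,i}\rangle^2\big] = \mathbb{E}\big[(a_r + z_r)^2(a_{r'} + z_{r'})^2\big] \\
&\quad= a_r^2 a_{r'}^2 + a_r^2\,\mathbb{E}[z_{r'}^2] + 4a_r a_{r'}\,\mathbb{E}[z_r z_{r'}] + a_{r'}^2\,\mathbb{E}[z_r^2] + \mathbb{E}[z_r^2 z_{r'}^2] \\
&\quad= \alpha_t^4\langle w_{r,t}, x_{0,i}\rangle^2\langle w_{r',t}, x_{0,i}\rangle^2 + \alpha_t^2\beta_t^2\langle w_{r,t}, x_{0,i}\rangle^2\|w_{r',t}\|^2 + 4\alpha_t^2\beta_t^2\langle w_{r,t}, x_{0,i}\rangle\langle w_{r',t}, x_{0,i}\rangle\langle w_{r,t}, w_{r',t}\rangle \\
&\qquad+ \alpha_t^2\beta_t^2\langle w_{r',t}, x_{0,i}\rangle^2\|w_{r,t}\|^2 + \beta_t^4\|w_{r,t}\|^2\|w_{r',t}\|^2 + 2\beta_t^4\langle w_{r,t}, w_{r',t}\rangle^2.
\end{aligned}$$
The first, second, fourth and fifth terms factor as $\big(\alpha_t^2\langle w_{r,t}, x_{0,i}\rangle^2 + \beta_t^2\|w_{r,t}\|^2\big)\big(\alpha_t^2\langle w_{r',t}, x_{0,i}\rangle^2 + \beta_t^2\|w_{r',t}\|^2\big)$, which gives the product form used below. Combining $I_1$, $I_2$, $I_3$ gives
$$\begin{aligned}
\mathbb{E}\big\|s_t(x^{(p)}_{t,i}) - \epsilon_{t,i}\big\|^2 = d &+ \frac{1}{m}\sum_{r=1}^{m}\underbrace{\|w_{r,t}\|^2\Big(\alpha_t^4\langle w_{r,t}, x^{(p)}_{0,i}\rangle^4 + 6\alpha_t^2\beta_t^2\langle w_{r,t}, x^{(p)}_{0,i}\rangle^2\|w_{r,t}\|^2 + 3\beta_t^4\|w_{r,t}\|^4 - 4\sqrt{m}\,\alpha_t\beta_t\langle w_{r,t}, x^{(p)}_{0,i}\rangle\Big)}_{L^{(p)}_{1,i}(w_{r,t})} \\
&+ \frac{1}{m}\sum_{r=1}^{m}\underbrace{\sum_{r' \neq r}\langle w_{r,t}, w_{r',t}\rangle\Big(\big(\alpha_t^2\langle w_{r,t}, x^{(p)}_{0,i}\rangle^2 + \beta_t^2\|w_{r,t}\|^2\big)\big(\alpha_t^2\langle w_{r',t}, x^{(p)}_{0,i}\rangle^2 + \beta_t^2\|w_{r',t}\|^2\big) + 2\beta_t^4\langle w_{r,t}, w_{r',t}\rangle^2 + 4\alpha_t^2\beta_t^2\langle w_{r,t}, x_{0,i}\rangle\langle w_{r',t}, x_{0,i}\rangle\langle w_{r,t}, w_{r',t}\rangle\Big)}_{L^{(p)}_{2,i}(w_{r,t})},
\end{aligned}$$
where we respectively denote the two composing loss terms as $L^{(p)}_{1,i}$ (corresponding to the learning of the $r$-th neuron) and $L^{(p)}_{2,i}$ (alignment with other neurons).

We next compute the gradient of the DDPM loss in expectation.

Lemma D.3.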
The gradient of expected DDPM loss in Lemma D.2 can be computed as L(p) 1,i (Wt) + L(p) 2,i (Wt) L(p) 1,i (wr,t) α4 t wr,t, x(p) 0,i 4 + 12α2 tβ2 t wr,t, x(p) 0,i 2 wr,t 2 + 9β4 t wr,t 4 4 mαtβt wr,t, x(p) 0,i wr,t Published as a conference paper at ICLR 2025 2α4 t wr,t, x(p) 0,i 3 wr,t 2 + 6α2 tβ2 t wr,t 4 wr,t, x(p) 0,i 2 mαtβt wr,t 2 x(p) 0,i L(p) 2,i (wr,t) α2 t wr,t, x(p) 0,i 2 + β2 t wr,t 2 α2 t wr ,t, x(p) 0,i 2 + β2 t wr ,t 2 + 2β4 t wr,t, wr ,t 2 + 4α2 tβ2 t wr,t, x0,i wr ,t, x0,i wr,t, wr ,t wr ,t α2 t wr ,t, x(p) 0,i 2 + β2 t wr ,t 2 wr,t, wr ,t 2α2 t wr,t, x(p) 0,i x(p) 0,i + 2β2 t wr,t r =r wr,t, wr ,t 2 4β2 t wr ,t + 8α2 tβ2 t wr,t, x0,i x0,i Proof of Lemma D.3. The proof is straightforward and thus omitted for clarity. D.3 FIRST STAGE Before deriving the results for the first stage, we derive the following lemma that decomposes the weight norm given concentration of neurons. Lemma D.4. For any k and r [m], such that wk r,t, µj = Θ( wk r,t, µj ) eΘ(σ0 µ ), wk r,t, ξi = Θ( wk r,t, ξi ) eΘ(σ0σξ d) and wk r,t, µj , wk r,t, ξi = e O(1), wk r,t, w0 r,t = Θ(σ2 0d) for any j, j = 1, i, i [n], r [m]. Then we can show wk r,t 2 = Θ wk r,t, µj 2 µ 2 + n SNR2 wk r,t, ξi 2 µ 2 + w0 r,t 2 . and for r = r , we have wk r,t, wk r ,t = Θ wk r,t, µj wk r ,t, µj µ 2 + n SNR2 wk r,t, ξi wk r ,t, ξi µ 2 + w0 r,t, w0 r ,t Proof of Lemma D.4. We decompose the weight wk r,t as wk r,t = ϕk rw0 r,t + γk 1µ1 µ1 2 + γk 1µ 1 µ 1 2 + i=1 ρk r,iξi ξi 2, (22) based on the gradient descent updates of wk r,t starting from small initialization w0 r,t and the direction of update only involves wk r,t and µ 1, ξi, where γ0 1 = γ0 1 = ρ0 r,i = 0 and ϕk r = 1. Then given the assumption that wk r,t, w0 r,t = Θ(σ2 0d), we have ϕk r = Θ(1) because wk r,t, w0 r,t = ϕk r w0 r,t 2 + γk 1µ1 µ1 2 + γk 1µ 1 µ 1 2 + i=1 ρk r,iξi ξi 2, w0 r,t = ϕk rΘ(σ2 0d), where the second equality uses Lemma D.1. This suggests that ϕk r = Θ(1). 
Then we can see wk r,t, µj = ϕk r w0 r,t, µj + γk j wk r,t, ξi = ϕk r w0 r,t, ξi + ρk r,i + X i =i ρk r,i ξi, ξi ξi 2 = ϕk r w0 r,t, ξi + ρk r,i + e O(nd 1/2), where the second equality for wk r,t, ξi is by Lemma B.2 and wk r,t, ξi = e O(1), thus ρk r,i = e O(1). Then based on the assumptions that | wk r,t, µj | eΘ(σ0 µ ) and | wk r,t, ξi | eΘ(σ0σξ d) and ϕk r = Θ(1), we can simplify (22) as wk r,t = Θ(w0 r,t) + Θ( wk r,t, µj (µ1 + µ 1) µ 2) + Θ wk r,t, ξi + e O(nd 1/2) n X i=1 ξi ξi 2, Published as a conference paper at ICLR 2025 where we use wk r,t, µj = Θ( wk r,t, µj ) ϕk r| w0 r,t, µj | and wk r,t, ξi = Θ( wk r,t, ξi ) ϕk r| w0 r,t, ξi | in the first equality. For the second equality, we use the assumption that wk r,t, ξi eΘ(σ0σξ Then we can show = Θ( w0 r,t 2) + Θ( wk r,t, µj 2) µ 2 + Θ( wk r,t, ξi 2 + e O(nd 1/2)) i=1 ξi ξi 2 2 + Θ( wk r,t, µj w0 r,t, µj µ 2) + Θ(( wk r,t, ξi + e O(nd 1/2)) w0 r,t, ξi = Θ(σ2 0d) + Θ( wk r,t, µj 2) µ 2 + Θ( wk r,t, ξi 2 + e O(nd 1/2)) Θ(nσ 2 ξ d 1) + e O(n2σ 2 ξ d 3/2) + Θ(( wk r,t, ξi + e O(nd 1/2)) w0 r,t, ξi nσ 2 ξ d 1) = Θ(σ2 0d) + Θ( wk r,t, µj 2) µ 2 + Θ(nσ 2 ξ d 1 wk r,t, ξi 2) + e O(n2σ 2 ξ d 3/2) + e O(n2σ0σ 1 ξ d 1) = Θ wk r,t, µj 2 µ 2 + n SNR2 wk r,t, ξi 2 µ 2 + σ2 0d , where the second equality uses Lemma D.1, Lemma B.2 and wk r,t, µj = Θ( wk r,t, µj ) ϕk r| w0 r,t, µj |. The third equality is by the condition on d = eΩ(n2) and wk r,t, ξi = Θ( wk r,t, ξi ) ϕk r| w0 r,t, ξi | and Lemma D.1. The last equality is by the condition σ0 eΩ(max{nσ 1 ξ d 5/4, n2σ 1 ξ d 2}) = eΩ(nσ 1 ξ d 5/4) given d = eΩ(n2). 
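The norm computation above rests on the cross terms in decomposition (22) being negligible in high dimension. A small numerical sketch of that fact follows. It is only an illustration under assumed toy settings: the scales `sigma0` and `sigma_xi`, the coefficients `phi`, `gamma`, `rho`, the coordinate-aligned unit signals, and noise vectors made exactly orthogonal to the signals are all hypothetical choices, not values from the paper.

```python
import random


def dot(u, v):
    """Euclidean inner product of two equal-length lists."""
    return sum(a * b for a, b in zip(u, v))


def norm_decomposition(d=5000, n=10, seed=0):
    """Build w as in (22) and compare ||w||^2 with the sum of squared components.

    Returns (exact, approx), where approx ignores all cross terms.
    """
    rng = random.Random(seed)
    sigma0, sigma_xi = 0.01, 0.1       # hypothetical scales
    phi, gamma, rho = 1.0, 0.5, 0.5    # hypothetical coefficients in (22)

    w0 = [rng.gauss(0.0, sigma0) for _ in range(d)]
    # Toy orthogonal unit signals placed on the first two coordinates.
    mu_pos = [1.0 if k == 0 else 0.0 for k in range(d)]
    mu_neg = [1.0 if k == 1 else 0.0 for k in range(d)]
    # Noise vectors with zeroed signal coordinates (orthogonal to both signals).
    xis = [[0.0, 0.0] + [rng.gauss(0.0, sigma_xi) for _ in range(d - 2)]
           for _ in range(n)]

    # w = phi*w0 + sum_j gamma*mu_j/||mu_j||^2 + sum_i rho*xi_i/||xi_i||^2
    w = [phi * v for v in w0]
    directions = [(mu_pos, gamma), (mu_neg, gamma)] + [(xi, rho) for xi in xis]
    for vec, coef in directions:
        scale = coef / dot(vec, vec)
        w = [wk + scale * vk for wk, vk in zip(w, vec)]

    exact = dot(w, w)
    approx = (phi ** 2 * dot(w0, w0)
              + sum(gamma ** 2 / dot(mu, mu) for mu in (mu_pos, mu_neg))
              + sum(rho ** 2 / dot(xi, xi) for xi in xis))
    return exact, approx


if __name__ == "__main__":
    exact, approx = norm_decomposition()
    print(f"||w||^2 = {exact:.4f}, sum of squared components = {approx:.4f}")
```

In this toy setting the noise contribution `n * rho**2 / ||xi||^2` plays the role of the $n \cdot \mathrm{SNR}^2 \langle w, \xi_i\rangle^2 \|\mu\|^{-2}$ term, since $\|\mu\| = 1$ here.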
In addition, we can deduce from (23) that wk r,t, wk r ,t = Θ( w0 r,t, w0 r ,t ) + Θ( wk r ,t, µj w0 r,t, µj µ 2) + Θ( wk r ,t, ξi i=1 w0 r,t, ξi ξi 2) + Θ( wk r,t, µj w0 r ,t, µj µ 2) + Θ wk r,t, ξi i=1 w0 r ,t, ξi ξi 2 + Θ( wk r,t, µj wk r ,t, µj µ 2) + Θ(n wk r,t, ξi wk r ,t, ξi ξi 2) = Θ( w0 r,t, w0 r ,t ) + Θ wk r,t, µj wk r ,t, µj µ 2 + Θ(n wk r,t, ξi wk r ,t, ξi ξi 2) = Θ( w0 r,t, w0 r ,t ) + Θ wk r,t, µj wk r ,t, µj µ 2 + Θ(n SNR2 wk r,t, ξi wk r ,t, ξi µ 2) where we use ϕk r = Θ(1) and the wk r,t, µj = Θ( wk r,t, µj ) ϕk r| w0 r,t, µj | and wk r,t, ξi = Θ( wk r,t, ξi ) ϕk r| w0 r,t, ξi | for the equalities. Lemma D.5 (Restatement of Lemma 4.1). Under Condition D.1, there exists an iteration T1 = max{Tµ, Tξ}, where Tµ = eΘ( mσ 1 0 d 1 µ 1η 1) and Tξ = eΘ(n mσ 1 0 σ 1 ξ d 3/2η 1) such that for all 0 k T1, (1) wk r,t 2 = Θ(σ2 0d) for all r [m], j = 1, i [n], (2) wk r,t, w0 r,t = Θ(σ2 0d), and (3) the signal and noise learning dynamics can be simplified to wk+1 r,t , µj = wk r,t, µj + Θ ηαtβt|Sj| n m wk r,t 2 µj 2 wk+1 r,t , ξi = wk r,t, ξi + Θ ηαtβt n m wk r,t 2 ξi 2 for all j = 1, r [m], i [n]. Furthermore, we can show w T1 r,t, µj = Θ( w T1 r ,t, µj )> 0, w T1 r,t, ξi = Θ( w T1 r ,t, ξi )> 0, w T1 r,t 2 = Θ( w T1 r ,t 2), Published as a conference paper at ICLR 2025 w T1 r,t, w T1 r ,t = Θ( w T1 r,t 2), for r = r , wr,t L(WT1 t ), w0 r,t = 1 mΘ w T1 r,t, µj + ξ m w T1 r,t 4 w T1 r,t, w0 r,t + w T1 r,t 2 w0 r,t, µj + ξ , where we denote ξ = 1 n Pn i=1 ξi. w T1 r,t, µj / w T1 r ,t, ξi = Θ(n SNR2), for all j, j = 1, r, r [m], i, i [n]. Proof of Lemma D.5. We prove the results by induction. To this end, we first compute the scale of the gradients projected to the space of µ1, µ 1 and ξi, for i [n] under the initialization scale. For notation clarity, we omit the index k. As long as wr,t 2 = Θ(σ2 0d) and suppose wr ,t, µj = O( wr,t, µj ), wr ,t, ξi = O( wr,t, ξi ), we can identify the dominant terms as follows. Signal. 
First for µj, and for any i [n], we compute i=1 L(1) 1,i (wr,t), µj mΘ( wr,t, µj 5 + wr,t, µj 3σ2 0d + σ4 0d2 wr,t, µj ) 1 mΘ( wr,t, µj 2) mΘ(σ2 0d wr,t, µj 3 + σ4 0d2 wr,t, µj ) 1 mΘ(σ2 0d) m O(σ4 0d2 wr,t, µj ) 1 mΘ(σ2 0d) where the second equality is by wr,t, µj 2 wr,t 2 µ 2 = Θ(σ2 0d). It is clear the dominant term is 1 m4αtβt wr,t 2 µ 2. The second dominant term comes from Θ( 1 m wr,t 4 wr,t, µj ). Further, we have due to the orthogonality between signal and noise vectors, i=1 L(2) 1,i (wr,t), µj = 1 m O wr,t, ξi 4 wr,t, µj + wr,t, ξi 2σ2 0d wr,t, µj + σ4 0d2 wr,t, µj m wr,t, ξi wr,t, µj In addition, we have i=1 L(1) 2,i (wr,t), µj = m 1 m O wr,t, µj 5 + wr,t, µj 3σ2 0d + σ4 0d2 wr,t, µj m O σ4 0d2 wr,t, µj i=1 L(2) 2,i (wr,t), µj = m 1 m O( wr,t, ξi 4 wr,t, µj + wr,t, ξi 2σ2 0d wr,t, µj + σ4 0d2 wr,t, µj ) Then according to the definition of |S 1| and µ 1, we can simplify the gradient into the dominant terms as L(Wt), µj = 1 mΘ( wr,t 2 + wr,t, ξi wr,t, µj ) + O σ4 0d2 wr,t, µj + wr,t, ξi 4 wr,t, µj + wr,t, ξi 2σ2 0d wr,t, µj Published as a conference paper at ICLR 2025 Noise. Similarly, we can also show for the noise learning i =1 L(1) 1,i (wr,t), ξi = 1 m O wr,t, µj 4 wr,t, ξi + wr,t, µj 2σ2 0d wr,t, ξi + σ4 0d2 wr,t, ξi 1 mΘ wr,t, ξi wr,t, µj m O(σ4 0d2 wr,t, ξi ) 1 mΘ( wr,t, ξi wr,t, µj ) where the dominating term is 4 mαtβt wr,t, µj wr,t, ξi . In addition, i =1 L(2) 1,i (wr,t), ξi m O wr,t, ξi 5 + wr,t, ξi 3σ2 0d + σ4 0d2 wr,t, ξi 1 m O( wr,t, ξi 2) m O( wr,t, ξi 3σ2 0d + σ4 0d2 wr,t, ξi ) + Θ(σ2 0d) Θ(σ2 ξdn 1) m e O(σ2 ξ m O wr,t, ξi 5 + wr,t, ξi 3σ2 0d + σ4 0d2 wr,t, ξi 1 m O( wr,t, ξi 2) m O( wr,t, ξi 3σ2 0σ2 ξd2n 1 + wr,t, ξi σ4 0σ2 ξd3n 1) 1 mΘ(σ2 0σ2 ξd2n 1) where we use Lemma B.2 in the first equality and the second equality is by d = eΩ(n2). 
Further we can show i =1 L(1) 2,i (wr,t), ξi m O wr,t, µj 4 wr,t, ξi + wr,t, µj 2σ2 0d wr,t, ξi + σ4 0d wr,t, ξi m O σ4 0d wr,t, ξi i =1 L(2) 2,i (wr,t), ξi = m 1 m O wr,t, ξi 5 + wr,t, ξi 3σ2 0d + σ4 0d2 wr,t, ξi m O( wr,t, ξi 3σ2 0σ2 ξd2n 1 + wr,t, ξi σ4 0σ2 ξd3n 1) This suggests we can simplify the gradient along noise direction as L(Wt), ξi = 1 mΘ(σ2 0σ2 ξd2n 1 + wr,t, ξi 2 + wr,t, ξi wr,t, µj ) + O σ4 0d2 wr,t, ξi + wr,t, ξi 5 + wr,t, ξi 3σ2 0d + O wr,t, ξi 3σ2 0σ2 ξd2n 1 + wr,t, ξi σ4 0σ2 ξd3n 1 In summary, as long as wr,t 2 = Θ(σ2 0d) and suppose wr ,t, µj = O( wr,t, µj ), wr ,t, ξi = O( wr,t, ξi ), we can simplify the gradient as wr,t L(Wt), µj = 1 mΘ(σ2 0d) 1 m O( wr,t, ξi wr,t, µj ) + O σ4 0d2 wr,t, µj + wr,t, ξi 4 wr,t, µj + wr,t, ξi 2σ2 0d wr,t, µj Published as a conference paper at ICLR 2025 wr,t L(Wt), ξi = 1 mΘ(σ2 0σ2 ξd2n 1) 1 m O( wr,t, ξi 2 + wr,t, ξi wr,t, µj ) + O σ4 0d2 wr,t, ξi + wr,t, ξi 5 + wr,t, ξi 3σ2 0d + O wr,t, ξi 3σ2 0σ2 ξd2n 1 + wr,t, ξi σ4 0σ2 ξd3n 1 (25) In the initial phase where wk r,t 2 = Θ(σ2 0d), | wk r,t, µj | = e O(σ0 µ ) and | wk r,t, ξi | = e O(σ0σξ d) (by Lemma D.1), we can show (24) reduces to = 1 mΘ(σ2 0d) + O σ4 0d2 wr,t, µj + wr,t, ξi 4 wr,t, µj + wr,t, ξi 2σ2 0d wr,t, µj = 1 mΘ(σ2 0d) where we use the condition that SNR 1 = e O(d1/4), i.e., σξ = e O(d 1/4) = o(1) in the first equality. The second equality is by condition σ0 e O(m 1/6d 1/3) and σ 1 ξ = eΩ(d1/4). Further, we can show (25) reduces to L(Wt), ξi = 1 mΘ(σ2 0σ2 ξd2n 1) + O σ4 0d2 wr,t, ξi + wr,t, ξi 5 + wr,t, ξi 3σ2 0d + O wr,t, ξi 3σ2 0σ2 ξd2n 1 + wr,t, ξi σ4 0σ2 ξd3n 1 = 1 mΘ(σ2 0σ2 ξd2n 1) + O wr,t, ξi 3σ2 0σ2 ξd2n 1 + wr,t, ξi σ4 0σ2 ξd3n 1 = 1 mΘ(σ2 0σ2 ξd2n 1) where the first equality is by d = eΩ(n) and d eΩ(n2/3σ 2/3 ξ ). The second equality is by σ3 0 e O(m 1/2d 1/2σξn 1). The third equality is by σ3 0 e O(m 1/2d 3/2σ 1 ξ ). 
In summary, we can show as long as wk r,t 2 = Θ(σ2 0d), | wk r,t, µj | = e O(σ0 µ ) and | wk r,t, ξi | = e O(σ0σξ wk+1 r,t , µj = wk r,t, µj + ηαtβt m Θ(σ2 0d) (26) wk+1 r,t , ξi = wk r,t, ξi + ηαtβt n m Θ(σ2 0d) ξi 2 (27) In addition, we similarly show that as long as wk r,t 2 = Θ(σ2 0d), | wk r,t, µj |, | wk r,t, ξi | = o(1), wk+1 r,t , w0 r,t = wk r,t, w0 r,t + ηO σ4 0d2 + wk r,t, µj + ξi wk r,t, w0 r,t + σ2 0d w0 r,t, µj + ξi , (28) Next, let Tµ = Θ( m log(16m/δ) σ0d µ ηαtβt ) and Tξ = Θ( n m log(16mn/δ) σ0σξd3/2ηαtβt ) and T1 = max{Tµ, Tξ}. We prove the results hold for all 0 k T1 via induction. We partition the proof into two stages, namely when 0 k min{Tµ, Tξ} and when min{Tµ, Tξ} k T1. (1) We first show for all 0 k min{Tµ, Tξ} that wk r,t 2 = Θ(σ2 0d), wk r,t, w0 r,t = Θ(σ2 0d), | wk r,t, µj | = e O(σ0 µ ) and | wk r,t, ξi | = e O(σ0σξ d) hold and thus (26), (27), (28) are directly satisfied. We prove the claims by induction as follows. It is clear that at k = 0, we have from Lemma D.1 that w0 r,t 2 = Θ(σ2 0d), wk r,t, w0 r,t = w0 r,t 2 = Θ(σ2 0d) and | w0 r,t, µj | p 2 log(16m/δ)σ0 µ = e O(σ0 µ ) Published as a conference paper at ICLR 2025 | w0 r,t, ξi | 2 p log(16mn/δ)σ0σξ d = e O(σ0σξ Suppose there exists an iteration e T min{Tµ, Tξ} such that wk r,t 2 = Θ(σ2 0d), wk r,t, w0 r,t = Θ(σ2 0d), | w0 r,t, µj | = e O(σ0 µ ) and | wk r,t, ξi | = e O(σ0σξ d) for all 0 k e T 1. Then we have from (26) that w e T r,t, µj = w e T 1 r,t , µj + ηαtβt m Θ(σ2 0d) µ 2 = w0 r,t, µj + ηαtβt m Θ(σ2 0d) µ 2 e T w0 r,t, µj + ηαtβt m Θ(σ2 0d) µ 2Tµ = w0 r,t, µj + e O(σ0 µ ) = e O(σ0 µ ) (29) where we use the Lemma B.1 that |Sj| = Θ(n). In addition, we have from (27) that w e T r,t, ξi = w e T 1 r,t , ξi + ηαtβt n m Θ(σ2 0σ2 ξd2) w0 r,t, ξi + ηαtβt n m Θ(σ2 0σ2 ξd2)Tξ where we use Lemma D.1 that ξi 2 = Θ(σ2 ξd) for all i [n]. 
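The induction above repeatedly uses the initialization scales from Lemma D.1 and the noise norm $\|\xi_i\|^2 = \Theta(\sigma_\xi^2 d)$. These can be checked empirically with a short sketch. The setup below is an assumption made for illustration: isotropic Gaussian noise $\xi_i \sim \mathcal{N}(0, \sigma_\xi^2 I)$ and hypothetical values for `d`, `m`, `n`, `sigma0`, `sigma_xi` and `delta`.

```python
import math
import random


def dot(u, v):
    """Euclidean inner product of two equal-length lists."""
    return sum(a * b for a, b in zip(u, v))


def init_concentration(d=4000, m=8, n=20, delta=0.1, seed=0):
    """Sample w^0_r and xi_i and compare them with the scales in Lemma D.1."""
    rng = random.Random(seed)
    sigma0, sigma_xi = 0.01, 0.05  # hypothetical scales, not from the paper

    w0s = [[rng.gauss(0.0, sigma0) for _ in range(d)] for _ in range(m)]
    xis = [[rng.gauss(0.0, sigma_xi) for _ in range(d)] for _ in range(n)]

    # ||w^0_r||^2 / (sigma0^2 d) and ||xi_i||^2 / (sigma_xi^2 d) should be near 1.
    w_ratios = [dot(w, w) / (sigma0 ** 2 * d) for w in w0s]
    xi_ratios = [dot(x, x) / (sigma_xi ** 2 * d) for x in xis]

    # Lemma D.1 bound on |<w^0_r, xi_i>|.
    bound = (2.0 * math.sqrt(math.log(16 * m * n / delta))
             * sigma0 * sigma_xi * math.sqrt(d))
    worst = max(abs(dot(w, x)) for w in w0s for x in xis)

    return {
        "w_ratio_range": (min(w_ratios), max(w_ratios)),
        "xi_ratio_range": (min(xi_ratios), max(xi_ratios)),
        "noise_inner_over_bound": worst / bound,
    }


if __name__ == "__main__":
    for key, value in init_concentration().items():
        print(key, value)
```

A value of `noise_inner_over_bound` below 1 means every sampled pair satisfies the stated bound on $|\langle w^0_{r,t}, \xi_i\rangle|$.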
Finally, we deduce from (28) that w e T r,t, w0 r,t = w e T 1 r,t , w0 r,t + ηO σ4 0d2 + wk r,t, µj + ξi wk r,t, w0 r,t + σ2 0d w0 r,t, µj + ξi = w e T 1 r,t , w0 r,t + ηO σ3 0d3/2( µ + σξ d) + σ6 0d3 = Θ(σ2 0d) + ηO σ3 0d3/2( µ + σξ d) + σ6 0d3 e T (31) where we use Cauchy-Schwarz inequality in the first inequality. When e T = Tµ, ηO σ3 0d3/2( µ + σξ d) + σ6 0d3 e T = e O ( µ + σξ d + σ5 0d2 Θ(σ2 0d) (32) where the last inequality is by the condition on σ 1 ξ = eΩ(d1/4) 1 and σ0 e O(d 1/3). When e T = Tξ, ηO σ3 0d3/2( µ + σξ d) + σ6 0d3 e T = e O nσ 1 ξ σ2 0( µ + σξ d) + nσ5 0σ 1 ξ d3/2 Θ(σ2 0d) where the last inequality is by the condition on d that d = eΩ(nσ 1 ξ µ ) and d = eΩ(n2) and σ0 e O(σ1/3 ξ n 1/3d 1/6). Hence we have proved the induction on wk r,t, w0 r,t and in fact proved a stronger result that wk r,t, w0 r,t = Θ(σ2 0d) for all k max{Tµ, Tξ} as long as wk r,t 2 = Θ(σ2 0d), | wk r,t, µj |, | wk r,t, ξi | = o(1). Next, we let Pξ = ξξ ξ 2 be the projection matrix onto the direction of ξ and we express w e T r,t = Pµ1w e T r,t + Pµ 1w e T r,t + Pn i=1 Pξiw e T r,t + I Pµ1 Pµ 1 Pn i=1 Pξi w e T r,t and due to the orthogonality of the decomposition, we have w e T r,t 2 = w e T r,t, µ1 2 µ 2 + w e T r,t, µ 1 2 w e T r,t, ξi ξ 2 i=1 Pξi w e T r,t = e O(σ2 0) + e O(nσ2 0) + w e T r,t, w0 r,t w0 r,t 2 w0 r,t 2 Published as a conference paper at ICLR 2025 where we use the induction results that | w e T r,t, µj | = e O(σ0 µ ) and | w e T r,t, ξi | = e O(σ0σξ and the I Pµ1 Pµ 1 Pn i=1 Pξi w e T r,t 2 is dominated by its projection to w0 r,t. This completes the induction that for all k min{Tµ, Tξ}, we have wk r,t 2 = Θ(σ2 0d), wk r,t, w0 r,t = Θ(σ2 0d), | wk r,t, µj | = e O(σ0 µ ) and | wk r,t, ξi | = e O(σ0σξ (2) Next, we examine the iteration min{Tµ, Tξ} k max{Tµ, Tξ} = T1. The magnitude comparison between Tµ and Tξ depends on the condition on n SNR2. In particular, we can verify that Tµ/Tξ = eΘ(n 1/2 n 1SNR 2) = eΘ(n 1SNR 1). 
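The ratio $T_\mu / T_\xi$ can be made concrete by computing both stage lengths from the simplified updates (26) and (27). The sketch below is a rough illustration, not the paper's procedure: it sets every $\Theta(\cdot)$ constant to 1, defines each stage length as the number of steps needed for the accumulated growth to match the Lemma D.1 initialization bound, and assumes $\mathrm{SNR} = \|\mu\| / (\sigma_\xi\sqrt{d})$. All numeric inputs are hypothetical.

```python
import math


def stage_lengths(m, n, d, sigma0, sigma_xi, mu_norm, eta, alpha, beta, delta):
    """Return (T_mu, T_xi) with all Theta-constants taken to be 1 (an assumption)."""
    # Initialization scales from Lemma D.1.
    init_mu = math.sqrt(2.0 * math.log(16 * m / delta)) * sigma0 * mu_norm
    init_xi = (2.0 * math.sqrt(math.log(16 * m * n / delta))
               * sigma0 * sigma_xi * math.sqrt(d))
    # Per-step growth from (26) and (27), using ||w||^2 ~ sigma0^2 d
    # and ||xi_i||^2 ~ sigma_xi^2 d.
    step_mu = eta * alpha * beta / math.sqrt(m) * sigma0 ** 2 * d * mu_norm ** 2
    step_xi = (eta * alpha * beta / (n * math.sqrt(m))
               * sigma0 ** 2 * d * sigma_xi ** 2 * d)
    return init_mu / step_mu, init_xi / step_xi


if __name__ == "__main__":
    # Hypothetical parameter values for illustration only.
    params = dict(m=4, n=50, d=10_000, sigma0=1e-3, sigma_xi=0.05,
                  mu_norm=1.0, eta=0.1, alpha=0.8, beta=0.6, delta=0.1)
    t_mu, t_xi = stage_lengths(**params)
    snr = params["mu_norm"] / (params["sigma_xi"] * math.sqrt(params["d"]))
    print(f"T_mu = {t_mu:.1f}, T_xi = {t_xi:.1f}")
    print(f"T_mu / T_xi = {t_mu / t_xi:.4f}, 1 / (n * SNR) = {1 / (params['n'] * snr):.4f}")
```

Under these assumptions the two printed quantities differ only by a ratio of logarithmic factors, which matches the $\tilde{\Theta}(n^{-1}\mathrm{SNR}^{-1})$ scaling stated above.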
When Tµ Tξ, i.e., n SNR2 = eΩ(1), we use induction to show for all min{Tµ, Tξ} k T1, wk r,t 2 = Θ(σ2 0d), wk r,t, w0 r,t = Θ(σ2 0d), | wk r,t, µj | = e O(σ0 µ n SNR), | wk r,t, ξi | = e O(σ0σξ d). It can be shown that under the condition σ0 e O(n 1σξd1/2), we have | wk r,t, ξi | = o(1), which suggests wk r,t, w0 r,t = Θ(σ2 0d). Suppose there exists an iteration Tµ < e Tξ Tξ such that the results hold for all Tµ k e Tξ 1. Then we can derive the dominant terms in (24) L(Wt), µj = 1 mΘ(σ2 0d) + O σ4 0d2 wr,t, µj + wr,t, ξi 4 wr,t, µj + wr,t, ξi 2σ2 0d wr,t, µj = 1 mΘ(σ2 0d) where the first equality is by d = eΩ(n) and the second equality is by σ3 0 e O(m 1/2n 1d 1/2σξ). This suggests that we can still leverage (26) to bound w e Tξ r,t, µj = w e Tξ 1 r,t , µj + ηαtβt m Θ(σ2 0d) µ 2 w0 r,t, µj + ηαtβt m Θ(σ2 0d) µ 2Tξ = w0 r,t, µj + e O(σ0 µ n SNR) = e O(σ0 µ n SNR) The bound on | w e Tξ r,t, ξi | is the same as (30). Then by the same arguments in (31), and (32), (33), we can show w e Tξ r,t, w0 r,t = Θ(σ2 0d). Thus, we can compute w e Tξ r,t 2 = e O(σ2 0n2 SNR2) + e O(nσ2 0) + Θ(σ2 0d) = Θ(σ2 0d) where the last equality is by the condition on d that d = eΩ(n µ σ 1 ξ ). This verifies the induction on wk r,t 2 = Θ(σ2 0d). When Tξ < Tµ, i.e., n 1 SNR 2 = eΩ(1), we use induction to show for all min{Tµ, Tξ} k T1, wk r,t 2 = Θ(σ2 0d), | wk r,t, µj | = e O(σ0 µ ), | wk r,t, ξi | = e O(σ0σξ dn 1SNR 1). Under the condition that σ0 e O(σ 1 ξ d 3/4n) e O(σ 2 ξ d 1n), we have | wk r,t, ξi | = o(1), which suggests wk r,t, w0 r,t = Θ(σ2 0d). Suppose there exists an iteration Tξ < e Tµ Tµ such that the results hold for all Tξ k e Tµ 1. 
Thus we can derive the dominant terms in (25) as L(Wt), ξi = 1 mΘ(σ2 0σ2 ξd2n 1) + O σ4 0d2 wr,t, ξi + wr,t, ξi 5 + wr,t, ξi 3σ2 0d + O wr,t, ξi 3σ2 0σ2 ξd2n 1 + wr,t, ξi σ4 0σ2 ξd3n 1 = 1 mΘ(σ2 0σ2 ξd2n 1) + O wr,t, ξi 3σ2 0σ2 ξd2n 1 + wr,t, ξi σ4 0σ2 ξd3n 1 = 1 mΘ(σ2 0σ2 ξd2n 1) Published as a conference paper at ICLR 2025 where the first equality is due to SNR 1 = e O(d1/4) and d eΩ(n2/3σ 2/3 ξ ). The second equality is by σ3 0 e O(d 1m 1/2). The third equality is by σ3 0 e O(σ 1 ξ d 7/4nm 1/2). This suggests that we can still leverage (27) to bound w e Tµ r,t, ξi = w e Tµ 1 r,t , ξi + ηαtβt n m Θ(σ2 0σ2 ξd2) w0 r,t, ξi + ηαtβt n m Θ(σ2 0σ2 ξd2)Tµ The bound on w e Tξ r,t, µj is the same as (29). Then following the same argument, we can decompose w e Tξ r,t 2 = e O(σ2 0) + e O(σ2 0n 1SNR 2) + Θ(σ2 0d) = Θ(σ2 0d) where the last equality is by the condition that SNR 1 = e O(d1/4). This verifies the induction on wk r,t 2 = Θ(σ2 0d). Furthermore, at k = T1, we have for all r [m], j = 1 and i [n], the growth term dominates the initialization term and thus w T1 r,t, µj = Θ(ηαtβtm 1/2σ2 0d µ 2T1) eΘ(σ0 µ ) Θ(| w0 r,t, µj |) w T1 r,t, ξi = Θ(ηαtβtn 1m 1/2σ2 0dσ2 ξd T1) eΘ(σ0σξ d) Θ(| w0 r,t, ξi |) where the inequality is by the definition of T1. Thus, we verify the concentration of inner products, i.e., w T1 r,t, µj = Θ( w T1 r ,t, µj ) and w T1 r,t, ξi = Θ( w T1 r ,t, ξi ), at the end of first stage as well as the ratio w T1 r,t, µj / w T1 r ,t, ξi = Θ(n SNR2) for any r, r [m]. Then, we can see directly w T1 r,t 2 = Θ( w T1 r ,t 2) = Θ(σ2 0d) for all r, r [m]. Next, we verify at T1, we have w T1 r,t, w T1 r ,t = Θ( w T1 r,t 2) for all r, r [m] such that r = r . 
To this end, we first notice that the conditions required by Lemma D.4 are readily satisfied at k = T1 and thus applying Lemma D.4 yields w T1 r,t 2 = Θ w T1 r,t, µj 2 µ 2 + n SNR2 w T1 r,t, ξi 2 µ 2 + w0 r,t 2 w T1 r,t, w T1 r ,t = Θ w T1 r,t, µj w T1 r ,t, µj µ 2 + n SNR2 w T1 r,t, ξi w T1 r ,t, ξi µ 2 + w0 r,t, w0 r ,t = Θ w T1 r,t, µj 2 µ 2 + n SNR2 w T1 r,t, ξi 2 µ 2 + w0 r,t, w0 r ,t = Θ w T1 r,t 2 w0 r,t 2 + w0 r,t, w0 r ,t = Θ w T1 r,t 2 σ2 0d + e O(σ2 0 = Θ( w T1 r,t 2) where the second equality for w T1 r,t, w T1 r ,t is due to w T1 r,t, µj = Θ( w T1 r ,t, µj ) and w T1 r,t, ξi = Θ( w T1 r ,t, ξi ) and the second last equality is by Lemma D.1. Finally we verify that at T1, wr,t L(WT1 t ), w0 r,t = 1 mΘ( w T1 r,t, µj + ξ w T1 r,t, w0 r,t + w T1 r,t 2 w0 r,t, µj + ξ ) + O w T1 r,t, µj 4 + w T1 r,t, ξi 4 + ( w T1 r,t, µj 2 + w T1 r,t, ξi 2) w T1 r,t 2 + w T1 r,t 4 w T1 r,t, w0 r,t + O w T1 r,t, µj 3 w T1 r,t 2 w0 r,t, µj + w T1 r,t, ξi 3 w T1 r,t 2 w0 r,t, ξi + O w T1 r,t 4 w T1 r,t, µj w0 r,t, µj + w T1 r,t 4 w T1 r,t, ξi w0 r,t, ξi = 1 mΘ(( w T1 r,t, µj + ξ m w T1 r,t 4) w T1 r,t, w0 r,t + w T1 r,t 2 w0 r,t, µj + ξ ) where we use the concentration of neurons along directions µj, ξi at T1 and the scale of w T1 r,t, µj , w T1 r,t, ξi . Published as a conference paper at ICLR 2025 D.4 SECOND STAGE For the second stage, we derive an extension of Lemma D.4 given the scale of wk r,t, w0 r,t can escape initialization. We highlight that unlike wk r,t, µj and wk r,t, ξi that increase monotonically, the dominant term of wr,t L(WT1 t ), w0 r,t suggests that wk r,t, w0 r,t can also decrease. Lemma D.6. For any k and r [m], such that wk r,t, µj = Θ( wk r,t, µj ) Θ( wk r,t, w0 r,t w0 r,t 2| w0 r,t, µj |), wk r,t, ξi = Θ( wk r,t, ξi ) Θ( wk r,t, w0 r,t w0 r,t 2| w0 r,t, ξi |) and wk r,t, µj , wk r,t, ξi = e O(1), wk r,t, w0 r,t = Θ( wk r ,t, w0 r ,t ) = Ω(min{σ0σ 1 ξ n1/2m 1/6, σ0 dm 1/6}) for any j, j = 1, i, i [n], r, r [m]. 
Then we can show wk r,t 2 = Θ wk r,t, µj 2 µ 2 + n SNR2 wk r,t, ξi 2 µ 2 + wk r,t, w0 r,t 2 w0 r,t 2 . And for r = r , we have wk r,t, wk r ,t = Θ wk r,t, µj wk r ,t, µj µ 2 + n SNR2 wk r,t, ξi wk r ,t, ξi µ 2 + wk r,t, w0 r,t 2 w0 r,t, w0 r ,t w0 r,t 4 Proof of Lemma D.6. Similar to the proof of Lemma D.4, we can decompose the weight wk r,t as wk r,t = ϕk rw0 r,t + γk 1µ1 µ1 2 + γk 1µ 1 µ 1 2 + i=1 ρk r,iξi ξi 2. First, we show that ϕk r = Θ( wk r,t, w0 r,t w0 r,t 2) as follows. We compute wk r,t, w0 r,t = ϕk r w0 r,t 2 + Θ( wk r,t, µj w0 r,t, µj µ 2 + wk r,t, ξi i=1 w0 r,t, ξi ξi 2) = ϕk r w0 r,t 2 + e O(σ0 + nσ0σ 1 ξ d 1/2) = Θ(ϕk r w0 r,t 2) where the second equality is by the assumption that wk r,t, µj , wk r,t, ξi = e O(1) and the last equality is by the assumption wk r,t, w0 r,t = Ω(min{σ0σ 1 ξ n1/2m 1/6, σ0 dm 1/6}) and the condition that σ 1 ξ = Ω(d1/4), d e O(nm1/3) and d e O(nm1/6σ 1 ξ ). Then based on the assumption, we can still bound wk r,t, µj ϕk r| w0 r,t, µj | = Θ( wk r,t, w0 r,t w0 r,t 2| w0 r,t, µj |), and similarly we can bound wk r,t, ξi ϕk r| w0 r,t, ξi | = Θ( wk r,t, w0 r,t w0 r,t 2| w0 r,t, ξi |). This allows to simplify wk r,t = Θ( wk r,t, w0 r,t w0 r,t 2)w0 r,t + Θ( wk r,t, µj (µ1 + µ 1) µ 2) + Θ wk r,t, ξi n X i=1 ξi ξi 2 (34) Consequently, the assumption that wk r,t, w0 r,t = Θ( wk r ,t, w0 r ,t ), combined with (34), we can derive that ϕk r = Θ(ϕk r ) given wk r,t, µj = Θ( wk r ,t, µj ) and wk r,t, ξi = Θ( wk r ,t, ξi ). Thus, we can compute wk r,t 2 = Θ (ϕk r)2 w0 r,t 2 + wk r,t, µj 2 µ 2 + n SNR2 wk r,t, ξi 2 µ 2 = Θ( wk r ,t 2) In addition, we can derive for r = r wk r,t, wk r ,t = Θ (ϕk r)2 w0 r,t, w0 r ,t + wk r,t, µj wk r ,t, µj µ 2 + n SNR2 wk r,t, ξi wk r ,t, ξi µ 2 which completes the proof. Published as a conference paper at ICLR 2025 Lemma D.7. 
Let T + 1 T1 and suppose for all T1 k < T + 1 , it satisfies that for all j = 1, i [n], r [m], wk+1 r,t , µj , wk+1 r,t , ξi = e O(1), wk r,t, w0 r,t = Ω(min{σ0σ 1 ξ n1/2m 1/6, σ0 dm 1/6}) and wk+1 r,t , µj = wk r,t, µj + Θ η m wk r,t 2 µ 2 (35) wk+1 r,t , ξi = wk r,t, ξi + Θ η n m wk r,t 2 ξi 2 . (36) wk+1 r,t , w0 r,t = wk r,t, w0 r,t + η mΘ wk r,t, µj + ξ m wk r,t 4 wk r,t, w0 r,t + wk r,t 2 w0 r,t, µj + ξ (37) Then we have for all T1 k T + 1 , (1) wk r,t, µj = Θ( wk r ,t, µj ) (2) wk r,t, ξi = Θ( wk r ,t, ξi ) (3) wk r,t, w0 r,t = Θ( wk r ,t, w0 r ,t ) (4) wk r,t, µj Θ( wk r,t, w0 r,t w0 r,t 2| w0 r,t, µj |), wk r,t, ξi Θ( wk r,t, w0 r,t w0 r,t 2| w0 r,t, ξi |), (5) wk r,t 2 = Θ( wk r ,t 2) (6) wk r,t, wk r ,t = Θ( wk r,t 2) for r = r (7) | wk r,t, µj |/| wk r ,t, ξi | = Θ(n SNR2) for all j = 1, r, r [m], i [n]. Proof of Lemma D.7. The proof is by induction. First, when k = T1, claims (1-8) are satisfied by Lemma D.5 with w T1 r,t, w0 r,t = Θ(σ2 0d), w T1 r,t, w0 r,t w0 r,t 2 = Θ(1). Now suppose there exists e T + 1 < T + 1 such that for all T1 k e T + 1 , (1-6) are satisfied. We aim to show for it is also satisfied for k + 1. By the assumption that for any r [m] wk+1 r,t , µj = wk r,t, µj + Θ η m wk r,t 2 µ 2 wk+1 r,t , ξi = wk r,t, ξi + Θ η n m wk r,t 2 ξi 2 , wk+1 r,t , w0 r,t = wk r,t, w0 r,t + η mΘ wk r,t, µj + ξi m wk r,t 4 wk r,t, w0 r,t + wk r,t 2 w0 r,t, µj + ξi we can show wk+1 r,t , µj = wk r,t, µj + Θ η m wk r,t 2 µ 2 = Θ wk r ,t, µj + η m wk r ,t 2 µ 2 = Θ( wk+1 r ,t , µj ) where the second equality is by induction condition, thus verifying the induction for claim (1). Similarly, we can use the same argument for verifying claim (2). 
For the claim (3) wk+1 r,t , w0 r,t = wk r,t, w0 r,t + η mΘ wk r,t, µj + ξ m wk r,t 2 wk r,t, w0 r,t + wk r,t 2 w0 r,t, µj + ξ = Θ( wk r ,t, w0 r ,t ) Published as a conference paper at ICLR 2025 + η mΘ wk r ,t, µj + ξ m wk r ,t 4 wk r ,t, w0 r ,t + wk r ,t 2 w0 r ,t, µj + ξ = Θ( wk+1 r ,t , w0 r ,t ) where the second equality is due to induction claim that wk r,t, w0 r,t = Θ( wk r ,t, w0 r ,t ), wk r,t, µj + ξ = Θ( wk r ,t, µj + ξ ), and we can show that Θ( w0 r,t, v ) = w0 r ,t, v holds for any v, and any r, r [m] with constant probability due to m = Θ(1). Next, we verify claim (7) wk+1 r,t , µj wk+1 r ,t , ξi = wk r,t, µj + Θ η m wk r,t 2 µ 2 wk r ,t, ξi + Θ η n m wk r ,t 2 ξi 2 = Θ(n SNR2) where the last equality follows from the induction condition and µ 2/ ξi 2 = Θ(SNR2) by Lemma B.2 and wk r,t 2 = Θ( wk r ,t 2) by induction condition. Thus the induction for (7) is verified. Next in order to verify (4), we only need to show the growth of wk r,t, µj , wk r,t, ξi is larger than the growth of wk r,t, w0 r,t w0 r,t 2| w0 r,t, µj | and wk r,t, w0 r,t w0 r,t 2| w0 r,t, ξi | respectively. To this end, we consider upper bounding the update of wk r,t, w0 r,t as | wk r,t, µj + ξ wk r,t, w0 r,t | Θ wk r,t 2( µ + σξ | wk r,t 2 w0 r,t, µj + ξ | Θ( wk r,t 2( µ + σξ d) w0 r,t ). Then we consider two cases depending on the magnitude of µ and σξ d, i.e., σξ d = O(1). Then wk r,t 2( µ + σξ d) w0 r,t 1| w0 r,t, µj | = e O( wk r,t 2d 1/2) Θ( wk r,t 2 µ 2) (38) wk r,t 2( µ + σξ d) w0 r,t 1| w0 r,t, ξi | = e O( wk r,t 2σξ) Θ( 1 n wk r,t 2 ξi 2) (39) where we use the condition on d = eΩ(nσ 1 ξ ) and ξi 2 = Θ(σ2 ξd) for (39). d, i.e., we have σξ d = Ω(1). 
Then wk r,t 2( µ + σξ d) w0 r,t 1| w0 r,t, µj | = Θ( wk r,t 2σ 1 0 σξ w0 r,t, µj ) = e O( wk r,t 2σξ) = e O( wk r,t 2nd 1/2) Θ( wk r,t 2 µ 2) (40) wk r,t 2( µ + σξ d) w0 r,t 1| w0 r,t, ξi | = Θ( wk r,t 2σ 1 0 σξ w0 r,t, ξi ) = e O( wk r,t 2σ2 ξ n wk r,t 2σ2 ξd) = Θ( 1 n wk r,t 2 ξi 2) where the second last equality of (40) is by the condition that SNR 1 = e O(n) which implies that σξ = e O(nd 1/2). The second last inequality of (41) is by d = eΩ(n2). This suggests that | wk r,t, µj + ξ wk r,t, w0 r,t | Θ wk r,t 2 µ 2 | wk r,t 2 w0 r,t, µj + ξ | Θ( 1 n wk r,t 2 ξi 2). which verifies the claim (4) by combining with the update (35), (36), (37). Next, in order to verify (5,6), we leverage Lemma D.6. First, it is easy to verify that at k + 1, the conditions for Lemma D.6 are satisfied by the induction claims (1-4) at k + 1. Then we have wk+1 r,t 2 = Θ wk+1 r,t , µj 2 µ 2 + n SNR2 wk+1 r,t , ξi 2 µ 2 + wk r,t, w0 r,t 2 w0 r,t 2 = Θ( wk+1 r ,t , µj 2 µ 2 + n SNR2 wk+1 r ,t , ξi 2 µ 2 + wk r ,t, w0 r ,t 2 w0 r ,t 2) Published as a conference paper at ICLR 2025 = Θ( wk+1 r ,t 2). Finally, to verify (4) for k + 1, we have from Lemma D.6 that wk+1 r,t , wk+1 r ,t = Θ wk+1 r,t , µj wk+1 r ,t , µj µ 2 + n SNR2 wk+1 r,t , ξi wk+1 r ,t , ξi µ 2 + wk+1 r,t , w0 r,t 2 w0 r,t, w0 r ,t w0 r,t 4 = Θ wk+1 r,t , µj 2 µ 2 + n SNR2 wk+1 r,t , ξi 2 µ 2 + wk+1 r,t , w0 r,t 2 w0 r,t, w0 r ,t w0 r,t 4 = Θ( wk+1 r,t 2 wk+1 r,t , w0 r,t 2 w0 r,t 2 + wk+1 r,t , w0 r,t 2 w0 r,t, w0 r ,t w0 r,t 4 ) = Θ( wk+1 r,t 2) where we use the induction claims (1-2) for k + 1 and Lemma D.1. The last equality is by wk+1 r,t , w0 r,t 2 w0 r,t 2 Θ( wk+1 r,t 2). which completes all the induction. 
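The ratio claim (7) can be illustrated by iterating the simplified updates (35) and (36) directly. The sketch below is a toy recursion under several assumptions that are not from the paper: every $\Theta(\cdot)$ constant is set to 1, the weight norm is taken from the decomposition in Lemma D.6 with the initialization component frozen at a fixed value `w0_norm2`, and all numeric inputs (including the deliberately mismatched starting values) are hypothetical.

```python
import math


def run_second_stage(n=8, m=4, eta=0.1, mu_norm2=1.0, xi_norm2=4.0,
                     w0_norm2=0.01, signal0=1e-3, noise0=2e-3,
                     cap=0.5, max_steps=100_000):
    """Iterate toy versions of (35)-(36) until the signal inner product hits `cap`.

    signal ~ <w, mu_j>, noise ~ <w, xi_i>; returns their final values.
    """
    signal, noise = signal0, noise0
    for _ in range(max_steps):
        if signal >= cap:
            break
        # ||w||^2 from Lemma D.6, with n * SNR^2 / ||mu||^2 = n / ||xi||^2
        # and the w^0 component held fixed (a simplifying assumption).
        norm2 = signal ** 2 / mu_norm2 + n * noise ** 2 / xi_norm2 + w0_norm2
        signal += eta / math.sqrt(m) * norm2 * mu_norm2        # update (35)
        noise += eta / (n * math.sqrt(m)) * norm2 * xi_norm2   # update (36)
    return signal, noise


if __name__ == "__main__":
    n, mu_norm2, xi_norm2 = 8, 1.0, 4.0
    signal, noise = run_second_stage(n=n, mu_norm2=mu_norm2, xi_norm2=xi_norm2)
    chi = n * mu_norm2 / xi_norm2  # n * SNR^2 with SNR^2 = ||mu||^2 / ||xi||^2
    print(f"signal / noise = {signal / noise:.4f}, n * SNR^2 = {chi:.4f}")
```

Because both updates share the factor $\|w^k_{r,t}\|^2$, each step adds growth in the fixed proportion $n\|\mu\|^2 / \|\xi_i\|^2$, so the ratio of the two inner products approaches $n \cdot \mathrm{SNR}^2$ once the accumulated growth dominates the starting values.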
From Lemma D.5 and Lemma D.7, we know that for T1 k T + 1 we can decompose the gradient into two parts, the dominant term and the residual term: wr,t L(Wk t ), µj = 1 mΘ wk r,t 2 µ 2 + Ek r,t,µj (42) wr,t L(Wk t ), ξi = 1 n mΘ wk r,t 2 ξi 2 + Ek r,t,ξi (43) wr,t L(Wk t ), w0 r,t = 1 mΘ wk r,t, µj + ξ m wk r,t 4 wk r,t, w0 r,t + wk r,t 2 w0 r,t, µj + ξ + Ek r,t,w0 where we let Ek r,t,µj, Ek r,t,ξi, Ek r,t,w0 denote the residual terms. Therefore, before Ek r,t,µj, Ek r,t,ξi grow to reach Ek r,t,µj = Θ( 1 m wk r,t 2 µ 2), Ek r,t,ξi = Θ( 1 n m wk r,t 2 ξi 2), Ek r,t,w0 = 1 mΘ(( wk r,t, µj + ξ m wk r,t 4) wk r,t, w0 r,t + wk r,t 2 w0 r,t, µj + ξ ), it can be veri- fied that (35), (36), (37) are satisfied respectively. If further, wk r,t, µj , wk r,t, ξi = e O(1), wk r,t, w0 r,t = Ω(min{σ0σ 1 ξ n1/2m 1/6, σ0 dm 1/6}) are satisfied, then we readily have | wk r,t, µj |/| wk r ,t, ξi | = Θ(n SNR2) by Lemma D.7. The next lemma characterizes the end of second stage where the residual term reaches the same order as the dominant term. Lemma D.8 (Restatement of Lemma 4.2). Consider the gradient decomposition defined in (42), (43) and (44). There exists T2 > T1 with T2 = Θ(max{η 1m1/3σ 2 0 d 1, η 1m1/3nσ 2 0 σ2 ξ}) such that for all j = 1, r [m], i [n], (1) If n SNR2 = Ω(1), w T2 r,t, µj = Θ(m 1/6), w T2 r,t, ξi = Θ( w T2 r,t, µj ), w T2 r,t, w0 r,t w0 r,t 2 Θ n SNR2 w T2 r,t, ξi If n 1 SNR 2 = Ω(1), w T2 r,t, µj = Θ(n SNR2 m 1/6), w T2 r,t, ξi = Θ(m 1/6) w T2 r,t, w0 r,t w0 r,t 2 Θ p n SNR2 w T2 r,t, ξi (2) ET2 r,t,µj = Θ( 1 m w T2 r,t 2 µ 2), ET2 r,t,ξi = Θ( 1 n m w T2 r,t 2 ξi 2), ET2 r,t,w0 = 1 mΘ( wk r,t, µj + ξi wk r,t, w0 r,t + wk r,t 2 w0 r,t, µj + ξi ). Published as a conference paper at ICLR 2025 In addition, for any T1 k T2 and for all j = 1, r [m], i [n], (3) wk r,t, µj = Θ( wk r ,t, µj ) and wk r,t, ξi = Θ( wk r ,t, ξi ), (4) wk r,t 2 = Θ( wk r ,t 2) and wk r,t, wk r ,t = Θ( wk r,t 2), (5) wk r,t, µj / wk r,t, ξi = Θ(n SNR2). Proof of Lemma D.8. 
Here we let T2 be the first time such that ET2 r,t,µj = Θ( 1 m w T2 r,t 2 µ 2) or ET2 r,t,ξi = Θ( 1 n m w T2 r,t 2 ξi 2) or ET2 r,t,w0 = 1 mΘ(( wk r,t, µj + ξ m wk r,t 4) wk r,t, w0 r,t + wk r,t 2 w0 r,t, µj + ξ ). In order to prove the results, we use induction k to simultaneously prove the following conditions A (k), B(k), C (k), D(T2), E (T2), for T1 k T2: A (k): wk r,t, µj , wk r,t, ξi = e O(1), wk r,t, w0 r,t = Ω(min{σ0σ 1 ξ n1/2m 1/6, σ0 dm 1/6}) for all j = 1, r [m], i [n]. B(k): wr,t L(Wk t ), µj = Θ 1 m wk r,t 2 µ 2 , wr,t L(Wk t ), ξi = Θ 1 n m wk r,t 2 ξi 2 and wr,t L(Wk t ), w0 r,t = 1 mΘ wk r,t, µj + ξ wk r,t, w0 r,t + wk r,t 2 w0 r,t, µj + ξ for all j = 1, r [m], i [n]. C (k): Claims (3-5), i.e., wk r,t, µj = Θ( wk r ,t, µj ), wk r,t, ξi = Θ( wk r ,t, ξi ), wk r,t 2 = Θ( wk r ,t 2), wk r,t, wk r ,t = Θ( wk r,t 2), and wk r,t, µj / wk r,t, ξi = Θ(n SNR2). D(T2): Claim (1), i.e., If n SNR2 = Ω(1), w T2 r,t, µj = Θ(m 1/6), w T2 r,t, ξi = Θ(n 1 SNR 2 m 1/6), and if n 1 SNR 2 = Ω(1), w T2 r,t, µj = Θ(n SNR2 m 1/6), w T2 r,t, ξi = Θ(m 1/6). E (T2): Claim (2), i.e., ET2 r,t,µj = Θ( 1 m w T2 r,t 2 µ 2), ET2 r,t,ξi = Θ( 1 n m w T2 r,t 2 ξi 2), and ET2 r,t,w0 = 1 mΘ ( w T2 r,t, µj + ξ m w T2 r,t 4) w T2 r,t, w0 r,t + w T2 r,t 2 w0 r,t, µj + ξ . The initial conditions A (T1), B(T1), C (T1) are satisfied by Lemma D.5 at the end of the first stage. In order to show C (k), D(T2), E (T2), we show the following claims respectively. Claim D.1. A (k), B(k) C (k), for any T1 k T2. Claim D.2. C (T1), ..., C (T2) D(T2), E (T2). Claim D.3. D(T2), E (T2) A (T1), ..., A (T2), B(T1), ..., B(T2). Proof of Claim D.1. Claim D.1 directly follows from Lemma D.7. Proof of Claim D.2. First, when C (k) is satisfied, we can simplify wk r,t 2 from Lemma D.6 wk r,t 2 = Θ wk r,t, µj 2 µ 2 + n SNR2 wk r,t, ξi 2 µ 2 + wk r,t, w0 r,t 2 w0 r,t 2 . 
= Θ (n2SNR4 + n SNR2) µ 2 wk r,t, ξi 2 + wk r,t, w0 r,t 2 w0 r,t 2 = Θ (χ2 + χ) µ 2 wk r,t, ξi 2 + ψk r,t where we temporarily denote χ := n SNR2 and ψk r,t := wk r,t, w0 r,t 2 w0 r,t 2 for notation clarity. Then, for the update of wk r,t, µj , we can compute i=1 L(1) 1,i (wk r,t), µj mΘ wk r,t, µj 5 + wk r,t, µj 3 wk r,t 2 + wk r,t, µj wk r,t 4 m wk r,t, µj 2 Published as a conference paper at ICLR 2025 mΘ wk r,t, µj 3 wk r,t 2 µ 2 + wk r,t, µj wk r,t 4 µ 2 m wk r,t 2 µ 2 mΘ(χ5 wk r,t, ξi 5 + (χ5 + χ4) wk r,t, ξi 5 µ 2 + (χ5 + χ3) wk r,t, ξi 5 µ 4 mχ2 wk r,t, ξi 2) mΘ (χ5 + χ4) wk r,t, ξi 5 + (χ5 + χ3) wk r,t, ξi 5 µ 2 m(χ2 + χ) wk r,t, ξi 2 mΘ(ψk r,tχ3 wk r,t, ξi 3 µ 2 + (ψk r,t)2χ wk r,t, ξi µ 2 mψk r,t µ 2) mΘ m(χ2 + χ) wk r,t, ξi 2 + (χ5 + χ4) wk r,t, ξi 5 + (χ5 + χ3) wk r,t, ξi 5 µ 2 mΘ(ψk r,tχ3 wk r,t, ξi 3 µ 2 + (ψk r,t)2χ wk r,t, ξi µ 2 mψk r,t µ 2) Similarly, we obtain i=1 L(2) 1,i (wr,t), µj mΘ wk r,t, ξi 4 wk r,t, µj + wk r,t, ξi 2 wk r,t, µj wk r,t 2 + wk r,t, µj wk r,t 4 m wk r,t, ξi wk r,t, µj mΘ χ wk r,t, ξi 5 + (χ3 + χ2) wk r,t, ξi 5 µ 2 + (χ5 + χ3) wk r,t, ξi 5 µ 4 mχ wk r,t, ξi 2 + 1 mΘ(χψk r,t wk r,t, ξi 3 + χ(ψk r,t)2 wk r,t, ξi ) mΘ mχ wk r,t, ξi 2 + χ wk r,t, ξi 5 + (χ3 + χ2) wk r,t, ξi 5 µ 2 + (χ5 + χ3) wk r,t, ξi 5 µ 4 mΘ(χψk r,t wk r,t, ξi 3 + χ(ψk r,t)2 wk r,t, ξi ) i=1 L(1) 2,i (wr,t), µj m Θ wk r,t, µj 5 + wk r,t, µj 3 wk r,t 2 + wk r,t, µj wk r,t 4 + wk r,t, µj wk r,t 4 µ 2 + wk r,t, µj 3 wk r,t 2 µ 2 m Θ (χ5 + χ4) wk r,t, ξi 5 + (χ5 + χ3) wk r,t, ξi 5 µ 2 m Θ(ψk r,tχ3 wk r,t, ξi 3 µ 2 + (ψk r,t)2χ wk r,t, ξi µ 2 mψk r,t µ 2) i=1 L(2) 2,i (wr,t), µj m Θ wk r,t, µj wk r,t, ξi 4 + wk r,t, µj wk r,t 4 + wk r,t, µj wk r,t, ξi 2 wk r,t 2 m Θ χ wk r,t, ξi 5 + (χ3 + χ2) wk r,t, ξi 5 µ 2 + (χ5 + χ3) wk r,t, ξi 5 µ 4 m Θ(χψk r,t wk r,t, ξi 3 + χ(ψk r,t)2 wk r,t, ξi ) Combining the above results, we have wr,t L(Wk t ), µj = 1 mΘ wk r,t 2 µ 2 + Ek r,t,µj = 1 mΘ (χ2 + χ) wk r,t, ξi 2 + ψk r,t µ 2 Published as a conference paper at ICLR 2025 + Θ 
(χ5 + χ4) wk r,t, ξi 5 + (χ5 + χ3) wk r,t, ξi 5 µ 2 + Θ χ wk r,t, ξi 5 + (χ3 + χ2) wk r,t, ξi 5 µ 2 + (χ5 + χ3) wk r,t, ξi 5 µ 4 + Θ(ψk r,tχ3 wk r,t, ξi 3 µ 2 + (ψk r,t)2χ wk r,t, ξi µ 2 + χψk r,t wk r,t, ξi 3 Similarly, we can derive for the update of wk r,t, ξi as follows: i=1 L(1) 1,i (wk r,t), ξi mΘ wk r,t, µj 4 wk r,t, ξi + wk r,t, µj 2 wk r,t 2 wk r,t, ξi + wk r,t 4 wk r,t, ξi m wk r,t, µj wk r,t, ξi mΘ χ4 wk r,t, ξi 5 + (χ4 + χ3) wk r,t, ξi 5 µ 2 + (χ4 + χ2) wk r,t, ξi 5 µ 4 mχ wk r,t, ξi 2 mΘ(ψk r,tχ2 wk r,t, ξi 3 + (ψk r,t)2 wk r,t, ξi ) mΘ mχ wk r,t, ξi 2 + χ4 wk r,t, ξi 5 + (χ4 + χ3) wk r,t, ξi 5 µ 2 + (χ4 + χ2) wk r,t, ξi 5 µ 4 mΘ(ψk r,tχ2 wk r,t, ξi 3 + (ψk r,t)2 wk r,t, ξi ) i=1 L(2) 1,i (wk r,t), ξi mΘ wk r,t, ξi 5 + wk r,t, ξi 3 wk r,t 2 + wk r,t 4 wk r,t, ξi m wk r,t, ξi 2 + 1 nmΘ wk r,t, ξi 3 wk r,t 2 ξi 2 + wk r,t 4 wk r,t, ξi ξi 2 m wk r,t 2 ξi 2 mΘ wk r,t, ξi 5 + (χ2 + χ) wk r,t, ξi 5 µ 2 + (χ4 + χ2) wk r,t, ξi 5 µ 4 m wk r,t, ξi 2 + 1 χmΘ (χ2 + χ) wk r,t, ξi 5 + (χ4 + χ2) wk r,t, ξi 5 µ 2 m(χ2 + χ) wk r,t, ξi 2 mΘ ψk r,t wk r,t, ξi 3 + (ψk r,t)2 wk r,t, ξi + χ 1ψk r,t wk r,t, ξi 3 µ 2 + χ 1(ψk r,t)2 wk r,t, ξi µ 2 χ 1 mψk r,t µ 2 mΘ m(χ + 1) wk r,t, ξi 2 + (χ2 + χ) wk r,t, ξi 5 µ 2 + (χ4 + χ2) wk r,t, ξi 5 µ 4 + (χ + 1) wk r,t, ξi 5 + (χ3 + χ) wk r,t, ξi 5 µ 2 mΘ ψk r,t wk r,t, ξi 3 + (ψk r,t)2 wk r,t, ξi + χ 1ψk r,t wk r,t, ξi 3 µ 2 + χ 1(ψk r,t)2 wk r,t, ξi µ 2 χ 1 mψk r,t µ 2 where the second equality follows from Pn i =1 ξi , ξi = (1 + e O(nd 1/2)) ξi 2 = Θ( ξi 2) by Lemma B.2 and condition on d. 
Further, i =1 L(1) 2,i (wk r,t), ξi m Θ wk r,t, µj 4 wk r,t, ξi + wk r,t 4 wk r,t, ξi + wk r,t, µj 2 wk r,t 2 wk r,t, ξi m Θ χ4 wk r,t, ξi 5 + (χ4 + χ2) wk r,t, ξi 5 µ 4 + (χ4 + χ3) wk r,t, ξi 5 µ 2 m Θ ψk r,tχ2 wk r,t, ξi 3 + (ψk r,t)2 wk r,t, ξi Published as a conference paper at ICLR 2025 i =1 L(2) 2,i (wk r,t), ξi m Θ wk r,t, ξi 5 + wk r,t 4 wk r,t, ξi + wk r,t, ξi 3 wk r,t 2 nm Θ wk r,t, ξi 3 wk r,t 2 ξi 2 + wk r,t 4 wk r,t, ξi ξi 2 m Θ wk r,t, ξi 5 + (χ4 + χ2) wk r,t, ξi 5 µ 4 + (χ2 + χ) wk r,t, ξi 5 µ 2 χm Θ (χ2 + χ) wk r,t, ξi 5 + (χ4 + χ2) wk r,t, ξi 5 µ 2 m Θ ψk r,t wk r,t, ξi 3 + (ψk r,t)2 wk r,t, ξi + χ 1ψk r,t wk r,t, ξi 3 µ 2 + χ 1(ψk r,t)2 wk r,t, ξi µ 2 χ 1 mψk r,t µ 2 m Θ (χ2 + χ) wk r,t, ξi 5 µ 2 + (χ4 + χ2) wk r,t, ξi 5 µ 4 + (χ + 1) wk r,t, ξi 5 + (χ3 + χ) wk r,t, ξi 5 µ 2 m Θ ψk r,t wk r,t, ξi 3 + (ψk r,t)2 wk r,t, ξi + χ 1ψk r,t wk r,t, ξi 3 µ 2 + χ 1(ψk r,t)2 wk r,t, ξi µ 2 χ 1 mψk r,t µ 2 Combining the above results, we have wr,t L(Wk t ), ξi = 1 n mΘ wk r,t 2 ξi 2 + Ek r,t,ξi = 1 mΘ (χ + 1) wk r,t, ξi 2 + χ 1ψk r,t µ 2 + Θ χ4 wk r,t, ξi 5 + (χ4 + χ3) wk r,t, ξi 5 µ 2 + (χ4 + χ2) wk r,t, ξi 5 µ 4 + Θ (χ2 + χ) wk r,t, ξi 5 µ 2 + (χ4 + χ2) wk r,t, ξi 5 µ 4 + Θ (χ + 1) wk r,t, ξi 5 + (χ3 + χ) wk r,t, ξi 5 µ 2 + Θ ψk r,tχ2 wk r,t, ξi 3 + (ψk r,t)2 wk r,t, ξi + ψk r,t wk r,t, ξi 3 + χ 1ψk r,t wk r,t, ξi 3 µ 2 + χ 1(ψk r,t)2 wk r,t, ξi µ 2 In summary, we finally arrive at wr,t L(Wk t ), µj = 1 mΘ wk r,t 2 µ 2 + Ek r,t,µj = 1 mΘ (χ2 + χ) wk r,t, ξi 2 + ψk r,t µ 2 + Θ (χ5 + χ4) wk r,t, ξi 5 + (χ5 + χ3) wk r,t, ξi 5 µ 2 + Θ χ wk r,t, ξi 5 + (χ3 + χ2) wk r,t, ξi 5 µ 2 + (χ5 + χ3) wk r,t, ξi 5 µ 4 + Θ(ψk r,tχ3 wk r,t, ξi 3 µ 2 + (ψk r,t)2χ wk r,t, ξi µ 2 + χψk r,t wk r,t, ξi 3 (45) wr,t L(Wk t ), ξi = 1 n mΘ wk r,t 2 ξi 2 + Ek r,t,ξi = 1 mΘ (χ + 1) wk r,t, ξi 2 + χ 1ψk r,t µ 2 + Θ χ4 wk r,t, ξi 5 + (χ4 + χ3) wk r,t, ξi 5 µ 2 + (χ4 + χ2) wk r,t, ξi 5 µ 4 + Θ (χ2 + χ) wk r,t, ξi 5 µ 2 + (χ4 + χ2) wk r,t, ξi 5 µ 4 Published as a conference paper at ICLR 
2025 + Θ (χ + 1) wk r,t, ξi 5 + (χ3 + χ) wk r,t, ξi 5 µ 2 + Θ ψk r,tχ2 wk r,t, ξi 3 + (ψk r,t)2 wk r,t, ξi + ψk r,t wk r,t, ξi 3 + χ 1ψk r,t wk r,t, ξi 3 µ 2 + χ 1(ψk r,t)2 wk r,t, ξi µ 2 (46) We also examine wr,t L(Wk t ), w0 r,t as follows. We first upper bound for r = r wk r,t, w0 r ,t = Θ( wk r,t, w0 r,t w0 r,t 2) w0 r,t, w0 r ,t + Θ( wk r,t, µj w0 r ,t, µj µ 2) + Θ wk r,t, ξi n X i=1 w0 r ,t, ξi ξi 2 = e O(σ0) + e O(σ0) + e O(nσ0σ 1 ξ d 1/2) = e O(σ0) + e O(nσ0σ 1 ξ d 1/2) where we use (34) in the first equality and Lemma D.1, Lemma D.7 in the second equality. Next we simplify the gradient as wr,t L(Wk t ), w0 r,t = 1 mΘ wk r,t, µj + ξi wk r,t, w0 r,t + wk r,t 2 w0 r,t, µj + ξi mΘ wk r,t, µj 4 + wk r,t, ξi 4 wk r,t, w0 r,t + wk r,t, µj 2 + wk r,t, ξi 2 wk r,t 2 wk r,t, w0 r,t mΘ wk r,t 4 wk r,t, w0 r,t + wk r,t, µj 3 w0 r,t, µj + wk r,t, ξi 3 w0 r,t, ξi wk r,t 2 mΘ wk r,t, µj w0 r,t, µj + wk r,t, ξi w0 r,t, ξi wk r,t 4 m Θ wk r,t, µj 4 + wk r,t, ξi 4 + wk r,t 2 wk r,t, µj 2 + wk r,t, ξi 2 + wk r,t 4 e O(σ0 + nσ0σ 1 ξ d 1/2) m Θ wk r,t, µj 3 w0 r,t, µj + wk r,t, ξi 3 w0 r,t, ξi wk r,t 2 m Θ wk r,t, µj w0 r,t, µj + wk r,t, ξi w0 r,t, ξi wk r,t 4 m Θ wk r,t, µj 2 + wk r,t, ξi 2 wk r,t 2 wk r,t, w0 r,t m Θ wk r,t 4 wk r,t, w0 r,t = 1 mΘ wk r,t, µj + ξi wk r,t, w0 r,t + wk r,t 2 w0 r,t, µj + ξi mΘ wk r,t, µj 4 + wk r,t, ξi 4 wk r,t, w0 r,t + Θ wk r,t, µj 2 + wk r,t, ξi 2 wk r,t 2 + wk r,t 4 wk r,t, w0 r,t + Θ wk r,t, µj 3 w0 r,t, µj + wk r,t, ξi 3 w0 r,t, ξi wk r,t 2 + Θ wk r,t, µj w0 r,t, µj + wk r,t, ξi w0 r,t, ξi wk r,t 4 (47) where we use that wk r,t, w0 r,t Θ(σ2 0d) e O(σ0 + nσ0σ 1 ξ d 1/2) due to the condition that σ0 e O(max{nσ 1 ξ d 3/2, d 1}). In order to identify the dominant terms of (45) and (46), we separate the analysis for three cases depending on the scale of n SNR2. For each case, we also consider two sub-cases depending on the scale of wk r,t, ξi . 
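Recall that we define $\psi_{r,t}^k = \langle w_{r,t}^k, w_{r,t}^0 \rangle^2 \|w_{r,t}^0\|^{-2}$.

When $\chi = n \cdot \mathrm{SNR}^2 = \Theta(1)$:

If $\langle w_{r,t}^k, \xi_i \rangle^2 \ge \Theta(\psi_{r,t}^k)$, we can identify the dominant terms as

$$\frac{1}{\sqrt{m}} \|w_{r,t}^k\|^2 \|\mu\|^2 = \frac{1}{\sqrt{m}} \Theta\big( \langle w_{r,t}^k, \xi_i \rangle^2 \big), \qquad E_{r,t,\mu_j}^k = \Theta\big( \langle w_{r,t}^k, \xi_i \rangle^5 \big),$$

$$\frac{1}{n\sqrt{m}} \|w_{r,t}^k\|^2 \|\xi_i\|^2 = \frac{1}{\sqrt{m}} \Theta\big( \langle w_{r,t}^k, \xi_i \rangle^2 \big), \qquad E_{r,t,\xi_i}^k = \Theta\big( \langle w_{r,t}^k, \xi_i \rangle^5 \big).$$

Hence, we can see at $k = T_2$, when $\langle w_{r,t}^{T_2}, \xi_i \rangle = \Theta(m^{-1/6})$ and $\langle w_{r,t}^{T_2}, \mu_j \rangle = \Theta(\chi m^{-1/6}) = \Theta(m^{-1/6})$ (by the condition on $\chi$), we have $E_{r,t,\mu_j}^{T_2} = \Theta\big( \frac{1}{\sqrt{m}} \|w_{r,t}^{T_2}\|^2 \|\mu\|^2 \big)$ and $E_{r,t,\xi_i}^{T_2} = \Theta\big( \frac{1}{n\sqrt{m}} \|w_{r,t}^{T_2}\|^2 \|\xi_i\|^2 \big)$. It remains to show that $E_{r,t,w^0}^{T_2} = \frac{1}{\sqrt{m}} \Theta\big( ( \langle w_{r,t}^{T_2}, \mu_j + \xi \rangle - \sqrt{m} \|w_{r,t}^{T_2}\|^4 ) \langle w_{r,t}^{T_2}, w_{r,t}^0 \rangle + \|w_{r,t}^{T_2}\|^2 \langle w_{r,t}^0, \mu_j + \xi \rangle \big)$. We compute that $\|w_{r,t}^{T_2}\|^2 = \Theta\big( \langle w_{r,t}^{T_2}, \xi_i \rangle^2 + \psi_{r,t}^{T_2} \big) = \Theta(m^{-1/3})$, which implies

$$\frac{1}{\sqrt{m}} \Big( \big( \langle w_{r,t}^{T_2}, \mu_j + \xi \rangle - \sqrt{m} \|w_{r,t}^{T_2}\|^4 \big) \langle w_{r,t}^{T_2}, w_{r,t}^0 \rangle + \|w_{r,t}^{T_2}\|^2 \langle w_{r,t}^0, \mu_j + \xi \rangle \Big) = \frac{1}{\sqrt{m}} \Theta\big( m^{-1/6} \langle w_{r,t}^{T_2}, w_{r,t}^0 \rangle + m^{-1/3} \langle w_{r,t}^0, \mu_j + \xi \rangle \big)$$

$$= \Theta\big( m^{-2/3} \langle w_{r,t}^{T_2}, w_{r,t}^0 \rangle + m^{-5/6} \langle w_{r,t}^0, \mu_j + \xi \rangle \big) = E_{r,t,w^0}^{T_2}$$

based on the derivation in (47). In this case, we can show $\psi_{r,t}^k \le \langle w_{r,t}^k, \xi_i \rangle^2 = \Theta(m^{-1/3})$ is feasible given the lower bound $\langle w_{r,t}^k, w_{r,t}^0 \rangle = \Omega(\sigma_0 \sqrt{d}\, m^{-1/6})$ in $\mathcal{A}(k)$. This verifies $\mathcal{D}(T_2)$, $\mathcal{E}(T_2)$ in the case where $\langle w_{r,t}^k, \xi_i \rangle^2 \ge \Theta(\psi_{r,t}^k)$.

If $\langle w_{r,t}^k, \xi_i \rangle^2 = o(\psi_{r,t}^k)$, we show $\langle \nabla_{w_{r,t}} L(W_t^k), w_{r,t}^0 \rangle \neq 0$, so the iterate cannot converge. More specifically, we can identify the dominant terms as

$$\frac{1}{\sqrt{m}} \|w_{r,t}^k\|^2 \|\mu\|^2 = \frac{1}{\sqrt{m}} \Theta\big( \psi_{r,t}^k \big), \qquad E_{r,t,\mu_j}^k = \Theta\big( (\psi_{r,t}^k)^2 \langle w_{r,t}^k, \xi_i \rangle \big),$$

$$\frac{1}{n\sqrt{m}} \|w_{r,t}^k\|^2 \|\xi_i\|^2 = \frac{1}{\sqrt{m}} \Theta\big( \psi_{r,t}^k \big), \qquad E_{r,t,\xi_i}^k = \Theta\big( (\psi_{r,t}^k)^2 \langle w_{r,t}^k, \xi_i \rangle \big).$$

Thus, we see $E_{r,t,\mu_j}^k = \Theta\big( \frac{1}{\sqrt{m}} \|w_{r,t}^k\|^2 \|\mu\|^2 \big)$ and $E_{r,t,\xi_i}^k = \Theta\big( \frac{1}{n\sqrt{m}} \|w_{r,t}^k\|^2 \|\xi_i\|^2 \big)$ only when $\langle w_{r,t}^k, \xi_i \rangle = \Theta\big( m^{-1/2} (\psi_{r,t}^k)^{-1} \big) = \Theta\big( \langle w_{r,t}^k, \mu_j \rangle \big)$. Together with $\langle w_{r,t}^k, \xi_i \rangle^2 = o(\psi_{r,t}^k)$, this implies that $\psi_{r,t}^k \ge \Theta(m^{-1/3})$ and thus $\|w_{r,t}^k\|^2 = \Theta(\psi_{r,t}^k) \ge \Theta(m^{-1/3})$. We show that in this case, the gradient along the direction $w_{r,t}^0$ is negative.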
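Examining $\langle \nabla_{w_{r,t}} L(W_t^k), w_{r,t}^0 \rangle$, we see

$$\frac{1}{\sqrt{m}} \Big( \big( \langle w_{r,t}^k, \mu_j + \xi \rangle - \sqrt{m} \|w_{r,t}^k\|^4 \big) \langle w_{r,t}^k, w_{r,t}^0 \rangle + \|w_{r,t}^k\|^2 \langle w_{r,t}^0, \mu_j + \xi \rangle \Big)$$

$$= \frac{1}{\sqrt{m}} \Big( \Theta\big( m^{-1/2} (\psi_{r,t}^k)^{-1} - \sqrt{m} (\psi_{r,t}^k)^2 \big) \langle w_{r,t}^k, w_{r,t}^0 \rangle + \Theta(\psi_{r,t}^k) \langle w_{r,t}^0, \mu_j + \xi \rangle \Big)$$

$$= \Theta\big( m^{-1} (\psi_{r,t}^k)^{-1} \langle w_{r,t}^k, w_{r,t}^0 \rangle - (\psi_{r,t}^k)^2 \langle w_{r,t}^k, w_{r,t}^0 \rangle + m^{-1/2} \psi_{r,t}^k \langle w_{r,t}^0, \mu_j + \xi \rangle \big),$$

$$E_{r,t,w^0}^k = \Theta\Big( \big( m^{-3} (\psi_{r,t}^k)^{-4} + m^{-1} (\psi_{r,t}^k)^{-1} + (\psi_{r,t}^k)^2 \big) \langle w_{r,t}^k, w_{r,t}^0 \rangle \Big) + \Theta\Big( \big( m^{-3/2} (\psi_{r,t}^k)^{-2} + m^{-1/2} \psi_{r,t}^k \big) \langle w_{r,t}^0, \mu_j + \xi \rangle \Big)$$

$$= \Theta\big( (\psi_{r,t}^k)^2 \langle w_{r,t}^k, w_{r,t}^0 \rangle + m^{-1/2} \psi_{r,t}^k \langle w_{r,t}^0, \mu_j + \xi \rangle \big),$$

where we use that $\psi_{r,t}^k \ge \Theta(m^{-1/3})$. It is clear that, due to the condition $\psi_{r,t}^k \ge \Theta(m^{-1/3})$, we have $\langle \nabla_{w_{r,t}} L(W_t^k), w_{r,t}^0 \rangle < 0$, which then concludes that $\psi_{r,t}^k$ would decrease and cannot reach a stationary point.

When $\chi = n \cdot \mathrm{SNR}^2 = \tilde{\Omega}(1)$:

If $\chi^2 \langle w_{r,t}^k, \xi_i \rangle^2 \ge \Theta(\psi_{r,t}^k)$, we can simplify (45) and (46) to

$$\frac{1}{\sqrt{m}} \|w_{r,t}^k\|^2 \|\mu\|^2 = \frac{1}{\sqrt{m}} \Theta\big( \chi^2 \langle w_{r,t}^k, \xi_i \rangle^2 \big), \qquad E_{r,t,\mu_j}^k = \Theta\big( \chi^5 \langle w_{r,t}^k, \xi_i \rangle^5 \big),$$

$$\frac{1}{n\sqrt{m}} \|w_{r,t}^k\|^2 \|\xi_i\|^2 = \frac{1}{\sqrt{m}} \Theta\big( \chi \langle w_{r,t}^k, \xi_i \rangle^2 \big), \qquad E_{r,t,\xi_i}^k = \Theta\big( \chi^4 \langle w_{r,t}^k, \xi_i \rangle^5 \big).$$

Hence, we can see at $k = T_2$, when $\langle w_{r,t}^{T_2}, \xi_i \rangle = \Theta(\chi^{-1} m^{-1/6})$ and thus $\langle w_{r,t}^{T_2}, \mu_j \rangle = \Theta(m^{-1/6})$, we have $E_{r,t,\mu_j}^{T_2} = \Theta\big( \frac{1}{\sqrt{m}} \|w_{r,t}^{T_2}\|^2 \|\mu\|^2 \big)$ and $E_{r,t,\xi_i}^{T_2} = \Theta\big( \frac{1}{n\sqrt{m}} \|w_{r,t}^{T_2}\|^2 \|\xi_i\|^2 \big)$. Next we show $E_{r,t,w^0}^{T_2} = \frac{1}{\sqrt{m}} \Theta\big( ( \langle w_{r,t}^{T_2}, \mu_j + \xi \rangle - \sqrt{m} \|w_{r,t}^{T_2}\|^4 ) \langle w_{r,t}^{T_2}, w_{r,t}^0 \rangle + \|w_{r,t}^{T_2}\|^2 \langle w_{r,t}^0, \mu_j + \xi \rangle \big)$. We first compute that $\|w_{r,t}^{T_2}\|^2 = \Theta\big( \chi^2 \langle w_{r,t}^{T_2}, \xi_i \rangle^2 + \psi_{r,t}^{T_2} \big) = \Theta(m^{-1/3})$, which implies

$$\frac{1}{\sqrt{m}} \Big( \big( \langle w_{r,t}^{T_2}, \mu_j + \xi \rangle - \sqrt{m} \|w_{r,t}^{T_2}\|^4 \big) \langle w_{r,t}^{T_2}, w_{r,t}^0 \rangle + \|w_{r,t}^{T_2}\|^2 \langle w_{r,t}^0, \mu_j + \xi \rangle \Big) = \frac{1}{\sqrt{m}} \Theta\big( m^{-1/6} \langle w_{r,t}^{T_2}, w_{r,t}^0 \rangle + m^{-1/3} \langle w_{r,t}^0, \mu_j + \xi \rangle \big)$$

$$= \Theta\big( m^{-2/3} \langle w_{r,t}^{T_2}, w_{r,t}^0 \rangle + m^{-5/6} \langle w_{r,t}^0, \mu_j + \xi \rangle \big) = E_{r,t,w^0}^{T_2}$$

based on the derivation in (47). In this case, we can show $\psi_{r,t}^k \le \chi^2 \langle w_{r,t}^k, \xi_i \rangle^2 = \Theta(m^{-1/3})$ is feasible given the lower bound $\langle w_{r,t}^k, w_{r,t}^0 \rangle = \Omega(\sigma_0 \sqrt{d}\, m^{-1/6})$ in $\mathcal{A}(k)$. This verifies $\mathcal{D}(T_2)$, $\mathcal{E}(T_2)$ in the case where $\chi^2 \langle w_{r,t}^k, \xi_i \rangle^2 \ge \Theta(\psi_{r,t}^k)$.

If $\chi^2 \langle w_{r,t}^k, \xi_i \rangle^2 = o(\psi_{r,t}^k)$, we can follow a similar argument to show that $\langle \nabla_{w_{r,t}} L(W_t^k), w_{r,t}^0 \rangle \neq 0$.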
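Specifically, we can simplify (45) and (46) to

$$\frac{1}{\sqrt{m}} \|w_{r,t}^k\|^2 \|\mu\|^2 = \frac{1}{\sqrt{m}} \Theta\big( \psi_{r,t}^k \big), \qquad E_{r,t,\mu_j}^k = \Theta\big( (\psi_{r,t}^k)^2 \chi \langle w_{r,t}^k, \xi_i \rangle \big),$$

$$\frac{1}{n\sqrt{m}} \|w_{r,t}^k\|^2 \|\xi_i\|^2 = \frac{1}{\sqrt{m}} \Theta\big( \chi^{-1} \psi_{r,t}^k \big), \qquad E_{r,t,\xi_i}^k = \Theta\big( (\psi_{r,t}^k)^2 \langle w_{r,t}^k, \xi_i \rangle \big).$$

Thus, we see $E_{r,t,\mu_j}^k = \Theta\big( \frac{1}{\sqrt{m}} \|w_{r,t}^k\|^2 \|\mu\|^2 \big)$ and $E_{r,t,\xi_i}^k = \Theta\big( \frac{1}{n\sqrt{m}} \|w_{r,t}^k\|^2 \|\xi_i\|^2 \big)$ only when $\chi \langle w_{r,t}^k, \xi_i \rangle = \Theta\big( (\psi_{r,t}^k)^{-1} m^{-1/2} \big)$, which implies $\psi_{r,t}^k \ge \Theta(m^{-1/3})$. However, by a similar argument as in the case $n \cdot \mathrm{SNR}^2 = \Theta(1)$, we can show $\langle \nabla_{w_{r,t}} L(W_t^k), w_{r,t}^0 \rangle < 0$, which then concludes that $\psi_{r,t}^k$ would decrease and cannot reach a stationary point.

When $\chi^{-1} = n^{-1} \cdot \mathrm{SNR}^{-2} = \tilde{\Omega}(1)$:

If $\langle w_{r,t}^k, \xi_i \rangle^2 \ge \Theta(\chi^{-1} \psi_{r,t}^k)$, we can simplify (45) and (46) into

$$\frac{1}{\sqrt{m}} \|w_{r,t}^k\|^2 \|\mu\|^2 = \frac{1}{\sqrt{m}} \Theta\big( \chi \langle w_{r,t}^k, \xi_i \rangle^2 \big), \qquad E_{r,t,\mu_j}^k = \Theta\big( \chi \langle w_{r,t}^k, \xi_i \rangle^5 \big),$$

$$\frac{1}{n\sqrt{m}} \|w_{r,t}^k\|^2 \|\xi_i\|^2 = \frac{1}{\sqrt{m}} \Theta\big( \langle w_{r,t}^k, \xi_i \rangle^2 \big), \qquad E_{r,t,\xi_i}^k = \Theta\big( \langle w_{r,t}^k, \xi_i \rangle^5 \big).$$

Hence, we can see at $k = T_2$, when $\langle w_{r,t}^{T_2}, \xi_i \rangle = \Theta(m^{-1/6})$ and $\langle w_{r,t}^{T_2}, \mu_j \rangle = \Theta(\chi m^{-1/6})$, we have $E_{r,t,\mu_j}^{T_2} = \Theta\big( \frac{1}{\sqrt{m}} \|w_{r,t}^{T_2}\|^2 \|\mu\|^2 \big)$ and $E_{r,t,\xi_i}^{T_2} = \Theta\big( \frac{1}{n\sqrt{m}} \|w_{r,t}^{T_2}\|^2 \|\xi_i\|^2 \big)$. In this case, we can show $\psi_{r,t}^k \le \chi \langle w_{r,t}^k, \xi_i \rangle^2$ is feasible given the lower bound $\langle w_{r,t}^k, w_{r,t}^0 \rangle = \Omega(\sigma_0 \sigma_\xi^{-1} n^{1/2} m^{-1/6})$ in $\mathcal{A}(k)$ when $\chi^{-1} = \tilde{\Omega}(1)$. Next we can check that at $T_2$, $\max\{ \langle w_{r,t}^{T_2}, \mu_j \rangle, \langle w_{r,t}^{T_2}, \xi_i \rangle \} = \langle w_{r,t}^{T_2}, \xi_i \rangle = \Theta(m^{-1/6})$. Further,

$$\|w_{r,t}^{T_2}\|^2 = \Theta\big( \langle w_{r,t}^{T_2}, \mu_j \rangle^2 + \langle w_{r,t}^{T_2}, \mu_j \rangle \langle w_{r,t}^{T_2}, \xi_i \rangle + \langle w_{r,t}^{T_2}, w_{r,t}^0 \rangle^2 \|w_{r,t}^0\|^{-2} \big) = \Theta(\chi m^{-1/3}),$$

where the second equality is by $\langle w_{r,t}^{T_2}, w_{r,t}^0 \rangle^2 \|w_{r,t}^0\|^{-2} = \psi_{r,t}^{T_2} \le \Theta(\chi m^{-1/3})$. This leads to

$$\frac{1}{\sqrt{m}} \Big( \big( \langle w_{r,t}^{T_2}, \mu_j + \xi \rangle - \sqrt{m} \|w_{r,t}^{T_2}\|^4 \big) \langle w_{r,t}^{T_2}, w_{r,t}^0 \rangle + \|w_{r,t}^{T_2}\|^2 \langle w_{r,t}^0, \mu_j + \xi \rangle \Big) = \Theta\big( m^{-2/3} \langle w_{r,t}^{T_2}, w_{r,t}^0 \rangle + \chi m^{-5/6} \langle w_{r,t}^0, \mu_j + \xi \rangle \big),$$

$$E_{r,t,w^0}^{T_2} = \Theta\big( (m^{-1} + \chi) m^{-2/3} \langle w_{r,t}^{T_2}, w_{r,t}^0 \rangle + \chi m^{-5/6} \langle w_{r,t}^0, \mu_j + \xi \rangle \big),$$

where $E_{r,t,w^0}^{T_2} = \frac{1}{\sqrt{m}} \Theta\big( ( \langle w_{r,t}^{T_2}, \mu_j + \xi \rangle - \sqrt{m} \|w_{r,t}^{T_2}\|^4 ) \langle w_{r,t}^{T_2}, w_{r,t}^0 \rangle + \|w_{r,t}^{T_2}\|^2 \langle w_{r,t}^0, \mu_j + \xi \rangle \big)$ holds due to $m = \Theta(1)$. This verifies $\mathcal{D}(T_2)$, $\mathcal{E}(T_2)$.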
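If $\langle w_{r,t}^k, \xi_i \rangle^2 < \Theta(\chi^{-1} \psi_{r,t}^k)$, we can simplify (45) and (46) into

$$\frac{1}{\sqrt{m}} \|w_{r,t}^k\|^2 \|\mu\|^2 = \frac{1}{\sqrt{m}} \Theta\big( \psi_{r,t}^k \big), \qquad E_{r,t,\mu_j}^k = \Theta\big( (\psi_{r,t}^k)^2 \langle w_{r,t}^k, \xi_i \rangle \chi + \psi_{r,t}^k \langle w_{r,t}^k, \xi_i \rangle^3 \chi \big),$$

$$\frac{1}{n\sqrt{m}} \|w_{r,t}^k\|^2 \|\xi_i\|^2 = \frac{1}{\sqrt{m}} \Theta\big( \chi^{-1} \psi_{r,t}^k \big), \qquad E_{r,t,\xi_i}^k = \Theta\big( \chi^{-1} \psi_{r,t}^k \langle w_{r,t}^k, \xi_i \rangle^3 + \chi^{-1} (\psi_{r,t}^k)^2 \langle w_{r,t}^k, \xi_i \rangle \big).$$

* If $\langle w_{r,t}^k, \xi_i \rangle^2 \ge \Theta(\psi_{r,t}^k)$, the equalities become

$$\frac{1}{\sqrt{m}} \|w_{r,t}^k\|^2 \|\mu\|^2 = \frac{1}{\sqrt{m}} \Theta\big( \psi_{r,t}^k \big), \qquad E_{r,t,\mu_j}^k = \Theta\big( \psi_{r,t}^k \langle w_{r,t}^k, \xi_i \rangle^3 \chi \big),$$

$$\frac{1}{n\sqrt{m}} \|w_{r,t}^k\|^2 \|\xi_i\|^2 = \frac{1}{\sqrt{m}} \Theta\big( \chi^{-1} \psi_{r,t}^k \big), \qquad E_{r,t,\xi_i}^k = \Theta\big( \chi^{-1} \psi_{r,t}^k \langle w_{r,t}^k, \xi_i \rangle^3 \big),$$

and $E_{r,t,\mu_j}^k = \Theta\big( \frac{1}{\sqrt{m}} \|w_{r,t}^k\|^2 \|\mu\|^2 \big)$ when $\langle w_{r,t}^k, \xi_i \rangle = \Theta(m^{-1/6} \chi^{-1/3})$, while $E_{r,t,\xi_i}^k = \Theta\big( \frac{1}{n\sqrt{m}} \|w_{r,t}^k\|^2 \|\xi_i\|^2 \big)$ when $\langle w_{r,t}^k, \xi_i \rangle = \Theta(m^{-1/6})$. Since $\chi^{-1} = \tilde{\Omega}(1)$, the two equalities cannot hold at the same time.

* If $\langle w_{r,t}^k, \xi_i \rangle^2 < \Theta(\psi_{r,t}^k)$, the equalities become

$$\frac{1}{\sqrt{m}} \|w_{r,t}^k\|^2 \|\mu\|^2 = \frac{1}{\sqrt{m}} \Theta\big( \psi_{r,t}^k \big), \qquad E_{r,t,\mu_j}^k = \Theta\big( (\psi_{r,t}^k)^2 \langle w_{r,t}^k, \xi_i \rangle \chi \big),$$

$$\frac{1}{n\sqrt{m}} \|w_{r,t}^k\|^2 \|\xi_i\|^2 = \frac{1}{\sqrt{m}} \Theta\big( \chi^{-1} \psi_{r,t}^k \big), \qquad E_{r,t,\xi_i}^k = \Theta\big( \chi^{-1} (\psi_{r,t}^k)^2 \langle w_{r,t}^k, \xi_i \rangle \big).$$

Thus $E_{r,t,\mu_j}^k = \Theta\big( \frac{1}{\sqrt{m}} \|w_{r,t}^k\|^2 \|\mu\|^2 \big)$ when $\langle w_{r,t}^k, \xi_i \rangle = \Theta\big( m^{-1/2} \chi^{-1} (\psi_{r,t}^k)^{-1} \big)$, while $E_{r,t,\xi_i}^k = \Theta\big( \frac{1}{n\sqrt{m}} \|w_{r,t}^k\|^2 \|\xi_i\|^2 \big)$ when $\langle w_{r,t}^k, \xi_i \rangle = \Theta\big( m^{-1/2} (\psi_{r,t}^k)^{-1} \big)$, which clearly cannot be satisfied at the same time given $\chi^{-1} = \tilde{\Omega}(1)$.

To conclude, we verify that when $\langle w_{r,t}^k, \xi_i \rangle^2 < \Theta(\chi^{-1} \psi_{r,t}^k)$, $E_{r,t,\mu_j}^k = \Theta\big( \frac{1}{\sqrt{m}} \|w_{r,t}^k\|^2 \|\mu\|^2 \big)$ and $E_{r,t,\xi_i}^k = \Theta\big( \frac{1}{n\sqrt{m}} \|w_{r,t}^k\|^2 \|\xi_i\|^2 \big)$ cannot be satisfied simultaneously.

In summary, we obtain the following results:

When $n \cdot \mathrm{SNR}^2 = \Theta(1)$, we can show $E_{r,t,\mu_j}^k = \Theta\big( \frac{1}{\sqrt{m}} \|w_{r,t}^k\|^2 \|\mu\|^2 \big)$ and $E_{r,t,\xi_i}^k = \Theta\big( \frac{1}{n\sqrt{m}} \|w_{r,t}^k\|^2 \|\xi_i\|^2 \big)$ when $\langle w_{r,t}^k, \xi_i \rangle = \Theta(m^{-1/6})$.

When $n \cdot \mathrm{SNR}^2 = \tilde{\Omega}(1)$, we can show $E_{r,t,\mu_j}^k = \Theta\big( \frac{1}{\sqrt{m}} \|w_{r,t}^k\|^2 \|\mu\|^2 \big)$ and $E_{r,t,\xi_i}^k = \Theta\big( \frac{1}{n\sqrt{m}} \|w_{r,t}^k\|^2 \|\xi_i\|^2 \big)$ when $\langle w_{r,t}^k, \xi_i \rangle = \Theta(\chi^{-1} m^{-1/6})$.

When $n^{-1} \cdot \mathrm{SNR}^{-2} = \tilde{\Omega}(1)$, we can show $E_{r,t,\mu_j}^k = \Theta\big( \frac{1}{\sqrt{m}} \|w_{r,t}^k\|^2 \|\mu\|^2 \big)$ and $E_{r,t,\xi_i}^k = \Theta\big( \frac{1}{n\sqrt{m}} \|w_{r,t}^k\|^2 \|\xi_i\|^2 \big)$ when $\langle w_{r,t}^k, \xi_i \rangle = \Theta(m^{-1/6})$.

Combining the definition of $T_2$, we complete the proof for Claim D.2.

Proof of Claim D.3.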
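By the definition of $T_2$ and Lemma D.5, we know that for all $T_1 \le k \le T_2$, the gradients can be written as

$$\langle \nabla_{w_{r,t}} L(W_t^k), \mu_j \rangle = \Theta\Big( \frac{1}{\sqrt{m}} \|w_{r,t}^k\|^2 \|\mu\|^2 \Big), \qquad \langle \nabla_{w_{r,t}} L(W_t^k), \xi_i \rangle = \Theta\Big( \frac{1}{n\sqrt{m}} \|w_{r,t}^k\|^2 \|\xi_i\|^2 \Big),$$

and thus (35), (36) are satisfied, which verifies $\mathcal{B}(k)$. In addition, this suggests that, for all $T_1 \le k \le T_2$, the increase in $\langle w_{r,t}^k, \xi_i \rangle$ and $\langle w_{r,t}^k, \mu_j \rangle$ is monotonic. Combining with $\mathcal{D}(T_2)$, we have $\langle w_{r,t}^k, \mu_j \rangle, \langle w_{r,t}^k, \xi_i \rangle = O(m^{-1/6}) = \tilde{O}(1)$ for all $T_1 \le k \le T_2$, thus verifying $\mathcal{A}(k)$. Hence, this completes the induction on $k$ and verifies the claims $\mathcal{A}(k)$, $\mathcal{B}(k)$, $\mathcal{C}(k)$, $\mathcal{D}(T_2)$, $\mathcal{E}(T_2)$ for $T_1 \le k \le T_2$.

Finally, we derive an upper bound on $T_2$. For all $T_1 \le k \le T_2$, we can decompose $\|w_{r,t}^k\|^2$ from Lemma D.4 as

$$\|w_{r,t}^k\|^2 = \Theta\big( \langle w_{r,t}^k, \mu_j \rangle^2 \|\mu\|^{-2} + \chi \langle w_{r,t}^k, \xi_i \rangle^2 \|\mu\|^{-2} + \|w_{r,t}^0\|^2 \big) \ge \Theta(\sigma_0^2 d).$$

Therefore, we can lower bound the update in (35), (36) for $T_1 \le k \le T_2$ by a linear growth as

$$\langle w_{r,t}^{k+1}, \mu_j \rangle = \langle w_{r,t}^k, \mu_j \rangle + \Theta\Big( \frac{\eta}{\sqrt{m}} \|w_{r,t}^k\|^2 \|\mu\|^2 \Big) \ge \langle w_{r,t}^k, \mu_j \rangle + \Theta\big( \eta m^{-1/2} \sigma_0^2 d \big),$$

$$\langle w_{r,t}^{k+1}, \xi_i \rangle = \langle w_{r,t}^k, \xi_i \rangle + \Theta\Big( \frac{\eta}{n\sqrt{m}} \|w_{r,t}^k\|^2 \|\xi_i\|^2 \Big) \ge \langle w_{r,t}^k, \xi_i \rangle + \Theta\big( \eta \chi^{-1} m^{-1/2} \sigma_0^2 d \big).$$

Therefore, we can upper bound $T_2 \le \Theta(\max\{\eta^{-1} m^{1/3} \sigma_0^{-2} d^{-1}, \eta^{-1} m^{1/3} n \sigma_0^{-2} \sigma_\xi^{2}\})$.

D.5 STATIONARY POINT

This section analyzes the stationary point under the conditions that hold at the end of the second stage.

Theorem D.1. Under Condition D.1, suppose (1) $\langle w_{r,t}^*, \mu_j \rangle = \Theta(\langle w_{r',t}^*, \mu_j \rangle) = \tilde{O}(1)$, (2) $\langle w_{r,t}^*, \xi_i \rangle = \Theta(\langle w_{r',t}^*, \xi_i \rangle) = \tilde{O}(1)$, (3) $\|w_{r,t}^*\|^2 = \Theta(\|w_{r',t}^*\|^2)$, (4) $\langle w_{r,t}^*, w_{r',t}^* \rangle = \Theta(\|w_{r,t}^*\|^2)$, and (5) $\langle w_{r,t}^*, w_{r,t}^0 \rangle = \Theta(\langle w_{r',t}^*, w_{r',t}^0 \rangle) = \Omega(\min\{\sigma_0 \sigma_\xi^{-1} n^{1/2} m^{-1/6}, \sigma_0 \sqrt{d}\, m^{-1/6}\})$ hold for all $j = \pm 1$, $r, r' \in [m]$, $i \in [n]$. Then there exists a stationary point $W_t^*$, i.e., $\nabla_{w_{r,t}} L(W_t^*) = 0$, that satisfies $|\langle w_{r,t}^*, \mu_j \rangle| / |\langle w_{r,t}^*, \xi_i \rangle| = \Theta(n \cdot \mathrm{SNR}^2)$, with

$\langle w_{r,t}^*, \xi_i \rangle = \Theta(n^{-1} \cdot \mathrm{SNR}^{-2} \cdot m^{-1/6})$ and $\langle w_{r,t}^*, w_{r,t}^0 \rangle \|w_{r,t}^0\|^{-1} \le \Theta(\langle w_{r,t}^*, \mu_j \rangle)$ if $n \cdot \mathrm{SNR}^2 = \Omega(1)$, and

$\langle w_{r,t}^*, \xi_i \rangle = \Theta(m^{-1/6})$ and $\langle w_{r,t}^*, w_{r,t}^0 \rangle \|w_{r,t}^0\|^{-1} \le \Theta\big( \sqrt{n \cdot \mathrm{SNR}^2}\, \langle w_{r,t}^*, \xi_i \rangle \big)$ if $n^{-1} \cdot \mathrm{SNR}^{-2} = \Omega(1)$.

Proof of Theorem D.1. The analysis mostly follows from Lemma D.8.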
Due to the concentration of neurons, we can derive wr,t L(W t ), µj = 1 mΘ w r,t, µj 2 + w r,t 2 µj 2 + w r,t, ξi w r,t, µj + Θ w r,t, µj 5 + w r,t, µj 3 w r,t 2 + w r,t 4 w r,t, µj + w r,t, µj 3 w r,t 2 µj 2 + w r,t 4 w r,t, µj µj 2 + w r,t, ξi 4 w r,t, µj + w r,t, ξi 2 w r,t 2 w r,t, µj wr,t L(W t ), ξi = 1 mΘ w r,t, µj w r,t, ξi + w r,t, ξi 2 + 1 n w r,t 2 ξi 2 + Θ w r,t, ξi 5 + w r,t, ξi 3 w r,t 2 + w r,t, µj 4 w r,t, ξi + w r,t, µj 2 w r,t 2 w r,t, ξi Published as a conference paper at ICLR 2025 + w r,t 4 w r,t, ξi + 1 n w r,t, ξi 3 w r,t 2 ξi 2 + 1 n w r,t 4 w r,t, ξi ξi 2 wr,t L(W t ), w0 r,t = 1 mΘ w r,t, µj + ξi w r,t, w0 r,t + w r,t 2 w0 r,t, µj + ξi mΘ w r,t, µj 4 + w r,t, ξi 4 w r,t, w0 r,t + Θ w r,t, µj 2 + w r,t, ξi 2 w r,t 2 w r,t, w0 r,t + Θ w r,t 4 w r,t, w0 r,t + w r,t, µj 3 w0 r,t, µj + w r,t, ξi 3 w0 r,t, ξi w r,t 2 + Θ w r,t, µj w0 r,t, µj + w r,t, ξi w0 r,t, ξi w r,t 4 And we can verify that W t is a stationary point if and only if for all j = 1, r [m], i [n], wr,t L(W t ), µj = wr,t L(W t ), ξi = wr,t L(W t ), w0 r,t = 0. This leads to the following equation system: mΘ w r,t, µj 5 + w r,t, µj 3 w r,t 2 + w r,t, ξi 4 w r,t, µj + w r,t 4 w r,t, µj + w r,t, ξi 2 w r,t 2 w r,t, µj + w r,t, µj 3 w r,t 2 µj 2 + w r,t 4 w r,t, µj µj 2 = Θ w r,t, ξi w r,t, µj + w r,t, µj 2 + w r,t 2 µj 2 (48) mΘ w r,t, ξi 5 + w r,t, ξi 3 w r,t 2 + w r,t, µj 4 w r,t, ξi + w r,t 4 w r,t, ξi + w r,t, µj 2 w r,t 2 w r,t, ξi + 1 n w r,t, ξi 3 w r,t 2 ξi 2 + 1 n w r,t 4 w r,t, ξi ξi 2 = Θ w r,t, µj w r,t, ξi + w r,t, ξi 2 + 1 n w r,t 2 ξi 2 (49) m w r,t, µj 4 + w r,t, ξi 4 w r,t, w0 r,t + w r,t, µj 2 + w r,t, ξi 2 w r,t 2 w r,t, w0 r,t + w r,t 4 w r,t, w0 r,t + w r,t, µj 3 w0 r,t, µj + w r,t, ξi 3 w0 r,t, ξi w r,t 2 + w r,t, µj w0 r,t, µj + w r,t, ξi w0 r,t, ξi w r,t 4 = Θ w r,t, µj + ξi w r,t, w0 r,t + w r,t 2 w0 r,t, µj + ξi (50) In order to solve the system, we let τi,j := w r,t,µj w r,t,ξi for any i [n], j = 1. We let τ = Θ(τi,j). 
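We first consider solving (48) and (49) and then analyze (50). Furthermore, because the claims (1-4) are assumed, we can leverage Lemma D.4 to decompose

$$\|w_{r,t}^*\|^2 = \Theta\big( \langle w_{r,t}^*, \mu_j \rangle^2 \|\mu\|^{-2} + n \cdot \mathrm{SNR}^2 \langle w_{r,t}^*, \xi_i \rangle^2 \|\mu\|^{-2} + \langle w_{r,t}^*, w_{r,t}^0 \rangle^2 \|w_{r,t}^0\|^{-2} \big)$$

$$= \Theta\big( (\tau^2 + n \cdot \mathrm{SNR}^2) \langle w_{r,t}^*, \xi_i \rangle^2 \|\mu\|^{-2} + \langle w_{r,t}^*, w_{r,t}^0 \rangle^2 \|w_{r,t}^0\|^{-2} \big),$$

where the last equality is by the scale of $\langle w_{r,t}^*, \mu_j \rangle$ and $\langle w_{r,t}^*, \xi_i \rangle$. Next, we separately consider three SNR conditions, namely (1) $n \cdot \mathrm{SNR}^2 = \Theta(1)$; (2) $n \cdot \mathrm{SNR}^2 = \tilde{\Omega}(1)$; and (3) $n^{-1} \cdot \mathrm{SNR}^{-2} = \tilde{\Omega}(1)$.

1. When $n \cdot \mathrm{SNR}^2 = \Theta(1)$: we first can bound $\langle w_{r,t}^*, w_{r,t}^0 \rangle \|w_{r,t}^0\|^{-1} \le \Theta(\langle w_{r,t}^*, \mu_j \rangle)$ and derive

$$\|w_{r,t}^*\|^2 = \max\big\{ \Theta(\langle w_{r,t}^*, \mu_j \rangle^2), \Theta(\langle w_{r,t}^*, \xi_i \rangle^2) \big\} \|\mu\|^{-2}.$$

Next, we can simplify (48) and (49) depending on the scale of $\tau$.

When $\tau = \tilde{\Omega}(1)$, we have $\|w_{r,t}^*\|^2 = \Theta(\langle w_{r,t}^*, \mu_j \rangle^2) \|\mu_j\|^{-2}$ and the equations reduce to

$$\Theta\big( \sqrt{m}\, \tau^5 \langle w_{r,t}^*, \xi_i \rangle^5 \big) = \Theta\big( \tau^2 \langle w_{r,t}^*, \xi_i \rangle^2 \big), \qquad \Theta\big( \sqrt{m}\, \tau^4 \langle w_{r,t}^*, \xi_i \rangle^5 \big) = \Theta\big( \tau^2 \langle w_{r,t}^*, \xi_i \rangle^2 \big).$$

It is clear that for $\tau = \tilde{\Omega}(1)$, the equations cannot be jointly satisfied.

When $\tau^{-1} = \tilde{\Omega}(1)$, we have $\|w_{r,t}^*\|^2 = \Theta(\langle w_{r,t}^*, \xi_i \rangle^2) \|\mu_j\|^{-2}$ and the equations reduce to

$$\Theta\big( \sqrt{m}\, \tau \langle w_{r,t}^*, \xi_i \rangle^5 \big) = \Theta\big( \langle w_{r,t}^*, \xi_i \rangle^2 \big), \qquad \Theta\big( \sqrt{m} \langle w_{r,t}^*, \xi_i \rangle^5 \big) = \Theta\big( \langle w_{r,t}^*, \xi_i \rangle^2 \big),$$

which cannot be satisfied simultaneously for $\tau^{-1} = \tilde{\Omega}(1)$.

When $\tau = \Theta(1)$, $\|w_{r,t}^*\|^2 = \Theta(\langle w_{r,t}^*, \mu_j \rangle^2) \|\mu_j\|^{-2} = \Theta(\langle w_{r,t}^*, \xi_i \rangle^2) \|\mu_j\|^{-2}$ and thus we can simplify the equations to

$$\Theta\big( \sqrt{m} \langle w_{r,t}^*, \xi_i \rangle^5 \big) = \Theta\big( \langle w_{r,t}^*, \xi_i \rangle^2 \big), \qquad \Theta\big( \sqrt{m} \langle w_{r,t}^*, \xi_i \rangle^5 \big) = \Theta\big( \langle w_{r,t}^*, \xi_i \rangle^2 \big),$$

which has a solution with $\langle w_{r,t}^*, \xi_i \rangle = \Theta(m^{-1/6}) = \Theta(\langle w_{r,t}^*, \mu_j \rangle)$, thus verifying the scale and $\tau = \Theta(n \cdot \mathrm{SNR}^2)$. Then we can verify (50) holds under the scale of $\langle w_{r,t}^*, \xi_i \rangle, \langle w_{r,t}^*, \mu_j \rangle = \Theta(m^{-1/6})$ and $\|w_{r,t}^*\|^2 = \Theta(m^{-1/3})$. With the same argument as in Lemma D.8, we can show $\langle \nabla_{w_{r,t}} L(W_t^*), w_{r,t}^0 \rangle = 0$.

2. When $n \cdot \mathrm{SNR}^2 = \tilde{\Omega}(1)$: we first can bound $\langle w_{r,t}^*, w_{r,t}^0 \rangle \|w_{r,t}^0\|^{-1} \le \Theta(\langle w_{r,t}^*, \mu_j \rangle)$ and derive

$$\|w_{r,t}^*\|^2 = \max\big\{ \Theta(\tau^2), \Theta(n \cdot \mathrm{SNR}^2) \big\} \langle w_{r,t}^*, \xi_i \rangle^2 \|\mu\|^{-2}.$$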
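We only consider the scale when $\tau = \Theta(n \cdot \mathrm{SNR}^2)$, where we can simplify $\|w_{r,t}^*\|^2 = \Theta(\tau^2) \langle w_{r,t}^*, \xi_i \rangle^2 \|\mu_j\|^{-2}$ and thus the system becomes

$$\Theta\big( \sqrt{m}\, \tau^5 \langle w_{r,t}^*, \xi_i \rangle^5 \big) = \Theta\big( \tau^2 \langle w_{r,t}^*, \xi_i \rangle^2 \big), \qquad \Theta\big( \sqrt{m}\, \tau^4 \langle w_{r,t}^*, \xi_i \rangle^5 \big) = \Theta\big( \tau \langle w_{r,t}^*, \xi_i \rangle^2 \big).$$

In order to satisfy both equations, we require $\langle w_{r,t}^*, \xi_i \rangle = \Theta(\tau^{-1} m^{-1/6})$ and $\langle w_{r,t}^*, \mu_j \rangle = \Theta(m^{-1/6})$, which verifies the scale. We can then verify (50) holds under such condition. With the same argument as in Lemma D.8, we can show $\langle \nabla_{w_{r,t}} L(W_t^*), w_{r,t}^0 \rangle = 0$.

3. When $n^{-1} \cdot \mathrm{SNR}^{-2} = \tilde{\Omega}(1)$: we first can bound $\langle w_{r,t}^*, w_{r,t}^0 \rangle \|w_{r,t}^0\|^{-1} \le \Theta\big( \sqrt{n \cdot \mathrm{SNR}^2}\, \langle w_{r,t}^*, \xi_i \rangle \big)$ and derive

$$\|w_{r,t}^*\|^2 = \max\big\{ \Theta(\tau^2), \Theta(n \cdot \mathrm{SNR}^2) \big\} \langle w_{r,t}^*, \xi_i \rangle^2 \|\mu_j\|^{-2}.$$

We only consider the scale when $\tau = \Theta(n \cdot \mathrm{SNR}^2)$, where we can simplify $\|w_{r,t}^*\|^2 = \Theta(n \cdot \mathrm{SNR}^2) \langle w_{r,t}^*, \xi_i \rangle^2 \|\mu_j\|^{-2}$ and thus the system becomes

$$\Theta\big( \sqrt{m}\, \tau \langle w_{r,t}^*, \xi_i \rangle^5 \big) = \Theta\big( n \cdot \mathrm{SNR}^2 \langle w_{r,t}^*, \xi_i \rangle^2 \big), \qquad \Theta\big( \sqrt{m} \langle w_{r,t}^*, \xi_i \rangle^5 \big) = \Theta\big( \langle w_{r,t}^*, \xi_i \rangle^2 \big),$$

which can be satisfied when $\langle w_{r,t}^*, \xi_i \rangle = \Theta(m^{-1/6})$ and $\langle w_{r,t}^*, \mu_j \rangle = \Theta(\tau m^{-1/6})$, and thus verifies the scale. In this case, we can also verify that (50) holds under the condition that $m = \Theta(1)$. With the same argument as in Lemma D.8, we can show $\langle \nabla_{w_{r,t}} L(W_t^*), w_{r,t}^0 \rangle = 0$.

This concludes the proof: if the scales and concentration are the same as at the end of the second stage, then there exists a stationary point where $\langle w_{r,t}^*, \mu_j \rangle / \langle w_{r,t}^*, \xi_i \rangle = \Theta(n \cdot \mathrm{SNR}^2)$.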
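The scale matching in cases 2 and 3 can be sanity-checked numerically. The following is a minimal sketch, not part of the original proof: it sets every hidden $\Theta(\cdot)$ constant to 1 (an assumption made purely for illustration), writes `chi` for $n \cdot \mathrm{SNR}^2$, and uses the hypothetical helper names `stationary_scales` and `reduced_residuals`. It plugs the claimed orders of $\langle w_{r,t}^*, \mu_j \rangle$ and $\langle w_{r,t}^*, \xi_i \rangle$ into the two reduced balance equations and reports how far each side is from the other.

```python
import math
from typing import Tuple


def stationary_scales(m: float, chi: float) -> Tuple[float, float]:
    """Return (a_mu, a_xi): the claimed orders of <w*, mu_j> and <w*, xi_i>.

    Assumption for illustration only: all Theta constants equal 1,
    and chi stands for n * SNR^2.
    """
    base = m ** (-1.0 / 6.0)
    if chi >= 1.0:
        # Regime n * SNR^2 = Omega(1): the signal coefficient sits at m^{-1/6}.
        return base, base / chi
    # Regime n^{-1} * SNR^{-2} = Omega(1): the noise coefficient sits at m^{-1/6}.
    return chi * base, base


def reduced_residuals(m: float, chi: float) -> Tuple[float, float]:
    """Relative mismatch (lhs / rhs - 1) of the two reduced balance equations."""
    a_mu, a_xi = stationary_scales(m, chi)
    tau = a_mu / a_xi  # ratio of signal to noise coefficients
    root_m = math.sqrt(m)
    if chi >= 1.0:
        # sqrt(m) tau^5 a^5 = tau^2 a^2   and   sqrt(m) tau^4 a^5 = tau a^2
        eq_mu = (root_m * tau**5 * a_xi**5, tau**2 * a_xi**2)
        eq_xi = (root_m * tau**4 * a_xi**5, tau * a_xi**2)
    else:
        # sqrt(m) tau a^5 = chi a^2   and   sqrt(m) a^5 = a^2
        eq_mu = (root_m * tau * a_xi**5, chi * a_xi**2)
        eq_xi = (root_m * a_xi**5, a_xi**2)
    return eq_mu[0] / eq_mu[1] - 1.0, eq_xi[0] / eq_xi[1] - 1.0


if __name__ == "__main__":
    for m, chi in [(64.0, 1.0), (64.0, 8.0), (4096.0, 0.125)]:
        a_mu, a_xi = stationary_scales(m, chi)
        print(m, chi, a_mu / a_xi, reduced_residuals(m, chi))
```

With unit constants both residuals vanish up to floating-point error and the ratio `a_mu / a_xi` equals `chi`, mirroring $\tau = \Theta(n \cdot \mathrm{SNR}^2)$. The check only confirms that the exponents balance; it says nothing about the actual constants or about equation (50).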