# Attention Layers Provably Solve Single-Location Regression

Published as a conference paper at ICLR 2025

Pierre Marion, Institute of Mathematics, EPFL, Lausanne, Switzerland (pierre.marion@epfl.ch)

Raphaël Berthier, Sorbonne Université, Inria, Centre Inria de Sorbonne Université, Paris, France (raphael.berthier@inria.fr)

Gérard Biau, Sorbonne Université, Institut Universitaire de France, Paris, France (gerard.biau@upmc.fr)

Claire Boyer, Université Paris-Saclay, Institut Universitaire de France, Orsay, France

## Abstract

Attention-based models, such as Transformer, excel across various tasks but lack a comprehensive theoretical understanding, especially regarding token-wise sparsity and internal linear representations. To address this gap, we introduce the single-location regression task, where only one token in a sequence determines the output, and its position is a latent random variable, retrievable via a linear projection of the input. To solve this task, we propose a dedicated predictor, which turns out to be a simplified version of a non-linear self-attention layer. We study its theoretical properties by showing its asymptotic Bayes optimality and analyzing its training dynamics. In particular, despite the non-convex nature of the problem, the predictor effectively learns the underlying structure. This work highlights the capacity of attention mechanisms to handle sparse token information and internal linear structures.

## 1 Introduction

Attention-based models (Bahdanau et al., 2015), such as Transformer (Vaswani et al., 2017), have achieved unprecedented performance in various learning tasks, including natural language processing (NLP), e.g., text generation (Bubeck et al., 2023), translation (Luong et al., 2015), and sentiment analysis (Song et al., 2019; Sun et al., 2019; Xu et al., 2019), as well as audio/speech analysis (Bahdanau et al., 2016).
These developments have led to many architectural and algorithmic variants of attention-based models (see the review by Lin et al., 2022). At a high level, the success of attention has been linked to its ability to handle long-range dependencies in input sequences (Bahdanau et al., 2015; Vaswani et al., 2017), since it computes pairwise dependencies between input tokens from their projections onto learned directions, independently of their locations in the sequence.

On the theoretical front, however, our understanding of attention-based neural networks is still in its infancy. This limited progress is due both to the complexity of the architectures and to the wide diversity of relevant tasks. A common approach to tackle these challenges is to introduce a simplified task that models certain features of real-world tasks, and then to show that a simplified version of the attention mechanism can solve it. Prominent examples of this pattern include studying in-context learning with linearized attention (Ahn et al., 2023; von Oswald et al., 2023; Zhang et al., 2024), topic understanding with single-layer attention and an alternating minimization scheme (Li et al., 2023b), learning spatial structure with positional attention (Jelassi et al., 2022), next-token prediction with latent bigram (Bietti et al., 2023; Tian et al., 2023) or causal graph (Nichani et al., 2024) structures, and sparse token selection (Wang et al., 2024). We refer to Appendix F for additional discussion of some of these related works. While these works shed light on some abilities of Transformer, they do not encompass all the characteristics of tasks where Transformer performs well, in particular in NLP.
Two features of particular interest, which to our knowledge have not been addressed in previous theoretical studies of Transformer, are *token-wise sparsity*, where the relevant information is contained in a limited number of tokens, and *internal linear representations*, which are interpretable representations of the input constructed by the model.

**Contributions.** To understand why attention is a suitable architecture for addressing these features, we introduce *single-location regression*, a novel statistical task where attention-based predictors excel (Section 2). In a nutshell, this task is a regression problem with a sequence of tokens as input. The key novelty is that only one token determines the output, and the location of this token is a latent random variable that changes with the input sequence. Consequently, solving the task requires first identifying the location of the relevant token, which can be done by learning a latent linear projection, and then performing regression on that token. To tackle this problem, we propose a dedicated predictor, which turns out to be a simplified version of a non-linear self-attention layer. We show that this attention-based predictor is asymptotically Bayes optimal, whereas standard linear regressors fail to perform better than the null predictor. We then analyze the training dynamics of the proposed predictor when it is trained to minimize the theoretical risk by projected gradient descent. Despite the non-convexity of the problem and the non-linearity of this Transformer-based method, we show that the learned predictor successfully retrieves the underlying structure of the task and thus solves single-location regression.

**Organization.** Section 2 presents the mathematical framework of single-location regression, together with motivations from language processing. Section 3 defines our predictor and explains its connection with attention.
We then move on to the mathematical study, from both the statistical (Section 4) and optimization (Section 5) points of view. Section 6 concludes the paper.

## 2 Single-Location Regression Task

In this section, we describe our statistical task and connect it to motivations from language processing.

### 2.1 Statistical Setting

We consider a regression scenario where the inputs are sequences of $L$ random tokens¹ $(X_1, \dots, X_L)$ taking values in $\mathbb{R}^d$. The output $Y \in \mathbb{R}$ is assumed to be given by

$$Y = X_{J_0}^\top v^\star + \xi, \tag{$\mathrm{P}_{\mathrm{learn}}$}$$

where $J_0$ is a latent discrete random variable on $\{1, \dots, L\}$ and, conditionally on $J_0$,

$$X_{J_0} \sim \mathcal{N}\Big(\sqrt{\tfrac{d}{2}}\, k^\star,\ \gamma^2 I_d\Big), \qquad X_\ell \sim \mathcal{N}(0, I_d) \quad \text{for } \ell \neq J_0.$$

In the above formulation, $\mathcal{N}(\mu, \Sigma)$ denotes the normal distribution with expectation $\mu$ and covariance matrix $\Sigma$, and $I_d$ is the $d \times d$ identity matrix. All vectors are column vectors, and the noise term $\xi$ is a centered random variable independent of $X$ and $J_0$, with finite second moment $\varepsilon^2$. Conditionally on $J_0$, the tokens $(X_j)_{1 \le j \le L}$ are assumed to be independent. The parameters of the regression problem $(\mathrm{P}_{\mathrm{learn}})$ are the unknown vectors $k^\star$ and $v^\star$, both assumed to lie on the unit sphere $\mathbb{S}^{d-1}$ of $\mathbb{R}^d$, i.e., $\|k^\star\|_2 = \|v^\star\|_2 = 1$.

The output is determined by a specific token in the sentence, indexed by the discrete random variable $J_0$ on $\{1, \dots, L\}$. This token can be detected via its mean, which is proportional to $k^\star$, unlike the other tokens, which have zero mean. Once $X_{J_0}$ is identified, the prediction is formed as a linear projection in the direction $v^\star$. Therefore, the originality and difficulty of this task lie in the fact that the response $Y$ is linearly related to a single informative token $X_{J_0}$ whose location varies from sequence to sequence: in this sense, the problem is sparse, but with a random support. A knee-jerk reaction would be to fit a linear model to the pair $\big((X_1^\top, \dots, X_L^\top), Y\big)$.
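To make the generative model concrete, here is a minimal sketch of a sampler for $(\mathrm{P}_{\mathrm{learn}})$ in plain Python. It assumes, for illustration only, that $J_0$ is uniform over the $L$ positions (0-based in the code) and that $\xi$ is Gaussian with standard deviation `eps`; the model itself only requires $\xi$ to be centered with variance $\varepsilon^2$. The function name `sample_plearn` is hypothetical.

```python
import math
import random


def sample_plearn(d, L, k_star, v_star, gamma, eps, rng):
    """Draw one (X, y, j0) triple from the single-location regression model.

    Sketch assumptions: J0 is uniform on {0, ..., L-1} (0-based index) and
    the noise xi is Gaussian with standard deviation eps.
    """
    j0 = rng.randrange(L)
    scale = math.sqrt(d / 2)
    X = []
    for ell in range(L):
        if ell == j0:
            # Informative token: mean sqrt(d/2) * k_star, covariance gamma^2 * I_d.
            x = [scale * k_star[i] + gamma * rng.gauss(0.0, 1.0) for i in range(d)]
        else:
            # Non-informative token: zero mean, identity covariance.
            x = [rng.gauss(0.0, 1.0) for _ in range(d)]
        X.append(x)
    xi = eps * rng.gauss(0.0, 1.0)
    # The output only reads the informative token, in the direction v_star.
    y = sum(X[j0][i] * v_star[i] for i in range(d)) + xi
    return X, y, j0
```

With $\gamma^2 = 1/2$, averaging squared token norms over many draws gives roughly $d$ for both informative and non-informative tokens, which is the norm-matching property discussed in the text.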
One might also consider tackling the problem with classical statistical approaches dedicated to sparsity, such as a Lasso estimator or a group-Lasso technique (Hastie et al., 2009). However, as we will see in Section 4, all linear predictors fail because of the unknown and changing location $J_0$. Note in addition that $\mathbb{E}[\|X_\ell\|_2^2] = d$ when $\ell \neq J_0$, while $\mathbb{E}[\|X_{J_0}\|_2^2] = d/2 + \gamma^2 d$. Therefore, choosing $\gamma^2 = 1/2$ implies that tokens have the same squared norm in expectation, whether they are discriminatory or not. Hence, any approach based on comparing the magnitudes of the tokens cannot yield meaningful results. Ultimately, a more sophisticated approach is needed, one that takes the specific structure of the problem into account.

¹ For the sake of simplicity, we use the terms *token* and *embedding* interchangeably, although they have slightly different meanings in the NLP community.

### 2.2 Language Processing Motivation

The structure of the task $(\mathrm{P}_{\mathrm{learn}})$ is motivated by natural language processing (NLP), and more specifically by two features, token-wise sparsity and internal linear representations, as we detail next.

> Birds flying high, you know how I feel? And, I'm feeling awful.
>
> What have I become my sweetest friend? Everyone I know goes ill in the end.
>
> Birds flying high, you know how I feel? And, I'm feeling good.
>
> What have I become my sweetest friend? Everyone I know goes well in the end.

(a) Examples of input-output pairs. The input is a text containing two sentences (e.g., a question and an answer), and the task is to perform sentiment analysis on the second sentence only. The output $Y$ is symbolized by a color code, where green (resp. red) corresponds to positive (resp. negative) feelings. The relevant information is sparse, typically concentrated in a single token: changing the grey token flips the output.
*Figure 1(b) plot legend: initial embeddings; train set; test set; test OOD (structure + tokens).*

(b) Accuracy of logistic regression on the embeddings of the [CLS] token in the hidden layers of a pretrained Transformer model. The initial embeddings of [CLS] (at layer 0) are not context-aware, so they yield chance-level accuracy (50%). In the hidden layers, the [CLS] token contains a representation of the sentence that achieves high accuracy and is robust to out-of-distribution changes in the token distribution and the sentence structure.

Figure 1: A simple sentiment analysis task with synthetic data, which exemplifies (a) token-wise sparsity and (b) internal linear representations. We refer to Appendix E for details on the experiment.

**Token-wise sparsity.** In language tasks, the relevant information is often contained in a few tokens, where we recall that tokens correspond to small text units (typically words or subwords) that are embedded in $\mathbb{R}^d$ using a learned dictionary. This sparsity is supported by the success of sparse attention (Martins & Astudillo, 2016; Niculae & Blondel, 2017; Correia et al., 2019; Child et al., 2019; Jaszczur et al., 2021; Kim et al., 2022; Farina et al., 2024), which is competitive with full attention while attending to fewer tokens. As an illustration, we consider a simple sentiment analysis task in Figure 1a and observe that changing one token flips the output. This is modeled in $(\mathrm{P}_{\mathrm{learn}})$ by having the output $Y$ depend on a single token $J_0$, whose location furthermore varies with the input.

**Internal linear representations.** Linear projections of the internal representations of Transformer (a.k.a. linear probing) contain interpretable information (Bolukbasi et al., 2021; Burns et al., 2023; Li et al., 2023a). Such a linear structure is also present in the learned token embeddings that are fed as input to language models (Mikolov et al., 2013a;b; Bolukbasi et al., 2016; Nanda et al., 2023; Wen-Yi & Mimno, 2023).
In our task $(\mathrm{P}_{\mathrm{learn}})$, the two directions $k^\star$ and $v^\star$ have to be learned by the model in order to solve the task. Figure 2 gives an example of such directions for the toy task described above. While this illustration relies on initial embeddings, similar structures also appear in the intermediate representations of Transformer. This is shown in Figure 1b, where we observe that pretrained Transformer architectures indeed build internal representations that suffice to solve the task with a linear classifier.

*Figure 2 diagram: the sentence "How are you doing sweetheart? To say the least, devastated." shown as word embeddings plus positional encodings, giving the token embeddings, together with their alignment with $k^\star$ and with $v^\star$.*

Figure 2: Modeling of an NLP task within our statistical setting $(\mathrm{P}_{\mathrm{learn}})$. The token embeddings $X_1, \dots, X_L$ are constructed by adding the embedding of each word and a positional encoding. For illustration purposes, we assume that each token corresponds to a word, and that the positional encoding depends only on the part of the sentence (before or after the question mark), which differs from usual practice. Then, let the direction $k^\star$ encode both the notion of sentiment and the position in the second part of the sentence. Thus, only the last token of the sentence is (positively) aligned with $k^\star$, and we have $J_0 = L$. As for $v^\star$, it encodes whether the word is associated with a positive or a negative sentiment. Note that several tokens are positively or negatively aligned with $v^\star$, but the output $Y$ only depends on the token $J_0$. This illustrates the benefit of having two latent directions $k^\star$ and $v^\star$: one that filters the informative token, and one that aligns with the output $Y$.

We acknowledge that our statistical task has limitations, such as a fixed sequence length, independent tokens, and an output that depends on a single token. More complex models could be considered, but at a significant technical cost.
Moreover, as argued above, our problem $(\mathrm{P}_{\mathrm{learn}})$ preserves interesting aspects of NLP tasks, which makes it relevant for the theoretical study of Transformer. Furthermore, it is an original statistical task that calls for a customized estimation strategy. It is precisely in this context that attention models prove their effectiveness, as we show next.

## 3 An Attention-Based Predictor to Solve the Regression Task

In this section, we propose a predictor adapted to problem $(\mathrm{P}_{\mathrm{learn}})$ and discuss its connection with attention. To make our point as clear as possible, the construction is divided into three steps. We represent the input sequence in matrix form $X \in \mathbb{R}^{L \times d}$, where $X = (X_1 | X_2 | \cdots | X_L)^\top$.

**Step 1: An oracle non-differentiable predictor.** If the vectors $(k^\star, v^\star) \in (\mathbb{S}^{d-1})^2$ were known, a natural procedure to solve the task $(\mathrm{P}_{\mathrm{learn}})$ would be to predict $Y$ from $X$ via

$$T^\star(X) = (Xv^\star)_{j_0(X)} = X_{j_0(X)}^\top v^\star, \quad \text{where } j_0(X) = \arg\max_{1 \le \ell \le L} (Xk^\star)_\ell. \tag{1}$$

The arg max part detects the location $J_0$ by exploiting the fact that all $X_\ell$ have zero mean except $X_{J_0}$, while the $Xv^\star$ part exploits the linear relationship $Y = X_{J_0}^\top v^\star + \xi$. In a more compact form, this ideal predictor can be rewritten as $T^\star(X) = \sum_{\ell=1}^L \mathbf{1}_{\arg\max(Xk^\star) = \ell}\,(Xv^\star)_\ell$, which is a linear regression in the direction $v^\star$ with non-differentiable weights depending on $k^\star$.

**Step 2: A trainable predictor.** In practice, the vectors $k^\star$ and $v^\star$ are unknown and must be estimated from the data. In addition, the non-differentiability of the arg max function poses significant optimization challenges. To solve this problem, the most common approach in machine learning is to replace arg max with a softmax function with inverse temperature $\lambda > 0$, i.e., for $z = (z_1, \dots, z_L) \in \mathbb{R}^L$, $[\mathrm{softmax}(\lambda z)]_j = e^{\lambda z_j} / \sum_{\ell=1}^L e^{\lambda z_\ell}$.
This leads us to the model

$$T^{(\mathrm{soft},k,v)}_\lambda(X) = \sum_{\ell=1}^L [\mathrm{softmax}(\lambda Xk)]_\ell\,(Xv)_\ell = \mathrm{softmax}\big(\lambda \underbrace{Xk}_{L \times 1}\big)^\top \underbrace{Xv}_{L \times 1}, \tag{2}$$

where $k, v \in \mathbb{S}^{d-1}$, and the superscript "soft" indicates the presence of the softmax function.

**Step 3: The final predictor.** The softmax nonlinearity, by coupling all tokens, significantly complicates the mathematical analysis. To alleviate this difficulty, we replace it by the component-wise nonlinear function $\mathrm{erf}(z) = \frac{2}{\sqrt{\pi}} \int_0^z e^{-t^2}\,dt$, which is differentiable, increasing on $\mathbb{R}$, and such that $\mathrm{erf}(-\infty) = -1$ and $\mathrm{erf}(+\infty) = 1$. We are therefore led to our operational model

$$T^{(k,v)}_\lambda(X) = \mathrm{erf}(\lambda Xk)^\top Xv = \sum_{\ell=1}^L \mathrm{erf}(\lambda X_\ell^\top k)\, X_\ell^\top v, \tag{3}$$

where the erf function is applied component-wise. This choice of activation function yields closed-form expectations for functions of Gaussian random variables (see, e.g., Lemma 18).

Note that the role of softmax in attention is an open question in the community. Several empirical papers investigate replacing softmax with a component-wise nonlinearity (Qin et al., 2022; Shen et al., 2023; Wortsman et al., 2023; Ramapuram et al., 2024) and observe similar performance. These works emphasize the importance of the normalization $\lambda$ when replacing softmax, which we also find to play an important role (see Corollary 2 and Section 5).

**Connection to attention.** It turns out that our estimation method has a natural interpretation in terms of attention models. To see this, consider a model consisting of a single attention layer with a single head (Vaswani et al., 2017)

$$T^{(Q,K,V,O)}_\lambda(X) = \mathrm{softmax}\big(\lambda \underbrace{XQ}_{L \times p} \underbrace{K^\top X^\top}_{p \times L}\big) \underbrace{XV}_{L \times p} \underbrace{O^\top}_{p \times o}, \tag{4}$$

where the dimensions $p, o \in \mathbb{N}$ are hyperparameters of the model, the softmax function is applied row by row, $Q, K, V \in \mathbb{R}^{d \times p}$ and $O \in \mathbb{R}^{o \times p}$ are the usual query, key, value, and output matrices, and $\lambda$ is usually taken to be $1/\sqrt{p}$. In practice, the attention head is added to $X$ via a skip connection, which enforces $o = d$.
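Before turning to the role of each attention matrix, the three predictors of Steps 1 to 3 can be compared directly. The following is a minimal plain-Python sketch of (1), (2), and (3), with tokens stored as a list of rows of $X$; the function names are hypothetical, and the softmax is shifted by its maximum for numerical stability, which leaves (2) unchanged.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def oracle_predictor(X, k, v):
    """Eq. (1): read X_j^T v at the token j maximizing X_j^T k."""
    scores = [dot(x, k) for x in X]
    j = max(range(len(X)), key=lambda ell: scores[ell])
    return dot(X[j], v)


def softmax_predictor(X, k, v, lam):
    """Eq. (2): softmax(lam * X k)^T X v."""
    scores = [lam * dot(x, k) for x in X]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]  # shift by the max to avoid overflow
    total = sum(weights)
    return sum(w / total * dot(x, v) for w, x in zip(weights, X))


def erf_predictor(X, k, v, lam):
    """Eq. (3): sum over tokens of erf(lam * X_l^T k) * X_l^T v (no coupling)."""
    return sum(math.erf(lam * dot(x, k)) * dot(x, v) for x in X)


# Toy input: the first token is strongly aligned with k, the others are not.
X = [[10.0, 2.0], [0.0, -1.0], [0.0, 3.0]]
k, v = [1.0, 0.0], [0.0, 1.0]
print(oracle_predictor(X, k, v), softmax_predictor(X, k, v, 1.0), erf_predictor(X, k, v, 1.0))
```

On this toy input all three predictors return values close to $2$, the value read off the informative token. Unlike the softmax, the erf version treats each token separately, so any token with $X_\ell^\top k \neq 0$ adds its own term $\mathrm{erf}(\lambda X_\ell^\top k)\,X_\ell^\top v$; this is why the scaling of $\lambda$ matters in the analysis of Section 4.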
In a nutshell, $K$ detects which tokens are relevant in the sentence, $V$ encodes the regression coefficient, and $Q$ encodes where to store the information. In a supervised context, it is common practice to prepend an additional [CLS] token to the tokenized sentence $X$ (see, e.g., Devlin et al., 2019). In this context, only the first coordinate of the output is used for the prediction task. Thus, we focus on the first row of (4), which corresponds to the embedding of [CLS], namely

$$T^{(Q,K,V,O)}_\lambda(X)_1 = \mathrm{softmax}\big(\lambda\, a K^\top X^\top\big)\, XVO^\top, \tag{5}$$

with $a = X_{[\mathrm{CLS}]}^\top Q \in \mathbb{R}^{1 \times p}$, where $X_{[\mathrm{CLS}]} \in \mathbb{R}^d$ denotes the embedding of the [CLS] token. Considering only the first output coordinate is a mathematically valid simplification for a single attention layer, but not when multiple layers are stacked, since all coordinates of the attention output then contribute. Nevertheless, even in this more realistic case, the [CLS] token, as well as the related concepts of attention sinks and registers, have been empirically shown to play a crucial role (Clark et al., 2019; Darcet et al., 2024; Xiao et al., 2024). This is also confirmed by our experiment in Figure 1b, where we show that the [CLS] token in pretrained Transformer architectures stores an internal representation of the sentence that suffices to solve simple NLP tasks with a linear classifier. This further motivates the need to understand how information is stored in this token.

It turns out that there is a direct connection between the model $T^{(\mathrm{soft},k,v)}_\lambda(X)$ defined in (2) and the attention model $T^{(Q,K,V,O)}_\lambda(X)_1$ described in (5). To see this, take $o = 1$ to adapt model (5) to univariate regression, and set $p = 1$, a reasonable assumption given both empirical and theoretical evidence suggesting that Transformer parameter matrices are low-rank (Aghajanyan et al., 2021; Kajitsuka & Sato, 2024).
Then, let $Q \in \mathbb{R}^{d \times 1}$ be any vector positively correlated with $X_{[\mathrm{CLS}]}$ (for instance, it suffices to take $Q = X_{[\mathrm{CLS}]}$), and let $O = 1$. We then deduce that

$$T^{(Q,K,V,O)}_\lambda(X)_1 = T^{(\mathrm{soft},K,V)}_{\lambda X_{[\mathrm{CLS}]}^\top Q}(X).$$

In other words, the attention layer (5) matches the predictor (2) with a softmax inverse temperature proportional to the scalar product between $X_{[\mathrm{CLS}]}$ and $Q$. Thus, our results, in particular the study of the training dynamics in Section 5, can be seen as a model of how Transformer builds internal representations of the input during training. This is also supported by numerical experiments showing that Transformer layers behave similarly to our predictor (see Appendix E).

## 4 Risk of the Oracle and of the Linear Predictors

Now that we have constructed our predictor $T^{(k,v)}_\lambda$ (see Eq. (3)), a first key question is to assess its statistical performance. Recall that $k, v \in \mathbb{S}^{d-1}$ are the two parameters of the model, and that their purpose is to approximate their theoretical counterparts $k^\star$ and $v^\star$ introduced in $(\mathrm{P}_{\mathrm{learn}})$ and used in (1). This raises in particular the question of the performance of the oracle predictor $T^{(k^\star,v^\star)}_\lambda$. To answer these questions, we introduce the risk of the predictor, measured by the mean squared error

$$R_\lambda(k, v) = \mathbb{E}\Big[\big(Y - T^{(k,v)}_\lambda(X)\big)^2\Big]. \tag{6}$$

To proceed with the analysis, we make the following assumption.

**Assumption 1.** The vectors $k^\star, v^\star \in \mathbb{S}^{d-1}$ are orthogonal, i.e., $k^{\star\top} v^\star = 0$.

This assumption is made throughout the remainder of the paper, even where it is not restated explicitly. It is relatively mild in high dimension, where any two independent vectors uniformly distributed on the sphere are close to orthogonal.

**Oracle predictor.** Our first result characterizes the risk of the proposed Transformer-type model (3) with the oracle parameters $(k^\star, v^\star)$. All proofs are deferred to the Appendix.

**Theorem 1.**
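There exists a function $\mathcal{R}_\lambda : \mathbb{R}^5 \to \mathbb{R}$ such that, for any $(k, v) \in (\mathbb{S}^{d-1})^2$, $R_\lambda(k, v) = \mathcal{R}_\lambda(\kappa, \nu, \theta, \eta, \rho)$, where $\kappa := k^\top k^\star$, $\nu := v^\top v^\star$, $\theta := v^\top k^\star$, $\eta := k^\top v^\star$, and $\rho := k^\top v$. A closed-form expression of $\mathcal{R}_\lambda$ is given in Appendix C. In particular,

$$R_\lambda(k^\star, v^\star) = \mathcal{R}_\lambda(1, 1, 0, 0, 0) = \gamma^2 - 2\gamma^2\, \mathrm{erf}\Big(\frac{\lambda\sqrt{d}}{\sqrt{2(1 + 2\lambda^2\gamma^2)}}\Big) + \gamma^2 \zeta\Big(\lambda\sqrt{\tfrac{d}{2}},\ \lambda^2\gamma^2\Big) + (L - 1)\,\zeta(0, \lambda^2) + \varepsilon^2,$$

where, for $t, \gamma \in \mathbb{R}$,

$$\zeta(t, \gamma^2) := \mathbb{E}\big[\mathrm{erf}^2(t + G)\big], \qquad G \sim \mathcal{N}(0, \gamma^2). \tag{7}$$

This result is fundamental for the analysis of gradient descent in the next section, since it reduces the dimension of the dynamical system defined by the optimization dynamics. Before delving into the optimization analysis, we study below the statistical optimality of the oracle risk $R_\lambda(k^\star, v^\star)$ and compare it with linear regression.

**Asymptotic Bayes optimality.** Let us start by observing that the Bayes risk associated with problem $(\mathrm{P}_{\mathrm{learn}})$ is at least $\varepsilon^2$, which follows from elementary properties of the conditional expectation (Le Gall, 2022, Chapter 11). Indeed, conditioning on more information can only decrease the mean squared error, and $\mathbb{E}[Y | X, J_0] = X_{J_0}^\top v^\star$, so that, by the Pythagorean theorem,

$$\mathbb{E}\big[(Y - \mathbb{E}[Y|X])^2\big] \ \ge\ \mathbb{E}\big[(Y - \mathbb{E}[Y|X, J_0])^2\big] = \mathbb{E}[\xi^2] = \varepsilon^2. \tag{8}$$

Then, the following corollary to Theorem 1 shows that the oracle predictor achieves the Bayes-optimal risk in the asymptotic scaling $L \ll 1/\lambda^2 \ll d$.

**Corollary 2.** Assume a joint asymptotic scaling where $d \to \infty$ and $L = o(d)$. Taking $\lambda$ such that $\lambda\sqrt{L} \to 0$ and $\lambda\sqrt{d} \to \infty$, we have $R_\lambda(k^\star, v^\star) \to \varepsilon^2$. Thus, in this asymptotic regime, the oracle predictor $T^{(k^\star,v^\star)}_\lambda$ is asymptotically Bayes optimal.

Note that Corollary 2 holds for any fixed $L \in \mathbb{N}_{>0}$, but $L$ may also tend to infinity, as long as $L = o(d)$. Let us give an intuition for why this result holds and where the scalings of $L$ and $\lambda$ come in. The oracle predictor can be decomposed as

$$T^{(k^\star,v^\star)}_\lambda(X) = \underbrace{X_{J_0}^\top v^\star}_{= \mathbb{E}[Y|X, J_0]}\, \mathrm{erf}\big(\lambda X_{J_0}^\top k^\star\big) + \sum_{j \neq J_0} \underbrace{X_j^\top v^\star}_{= \Theta(1)}\, \mathrm{erf}\big(\underbrace{\lambda X_j^\top k^\star}_{= \Theta(\lambda)}\big). \tag{9}$$

With the scaling $\lambda \gg 1/\sqrt{d}$, the argument of the first erf nonlinearity diverges to infinity with $d$.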
Thus, it reaches the saturating part of erf, so the first term in (9) converges to $\mathbb{E}[Y | X, J_0]$. On the other hand, the arguments of the erf nonlinearities inside the sum are of order $\lambda = o(1)$, so they lie in the linear part of erf. Therefore, the sum consists of $L - 1$ independent, centered terms, each of magnitude $\lambda$. As a consequence, by the central limit theorem, the whole sum is of order $\Theta(\lambda\sqrt{L})$, and we get

$$T^{(k^\star,v^\star)}_\lambda(X) \approx \mathbb{E}[Y | X, J_0] + \Theta(\lambda\sqrt{L}).$$

Due to the scaling $\lambda\sqrt{L} \to 0$, the second term vanishes, and the oracle predictor implements the conditional expectation of $Y$ given $X$ and $J_0$. This is the best we can hope for: the predictor succeeds in inferring the latent variable $J_0$, then gives the best possible prediction of $Y$ given $X$ and $J_0$. We also see the crucial role played by the nonlinearity of erf, whose linear part acts for $j \neq J_0$ and whose saturating part acts for $j = J_0$. In particular, the reasoning would not hold for a linear activation.

**Linear model.** The asymptotic optimality of our oracle predictor is particularly striking in comparison to the risk of the optimal linear predictor. More precisely, let

$$\beta^\star \in \arg\min_{\beta \in \mathbb{R}^{dL}} \mathbb{E}\Big[\big(Y - (X_1^\top, \dots, X_L^\top)\beta\big)^2\Big]$$

be the optimal linear predictor for the regression task $(\mathrm{P}_{\mathrm{learn}})$. Its associated risk is $R(\beta^\star) = \mathbb{E}\big[(Y - (X_1^\top, \dots, X_L^\top)\beta^\star)^2\big]$. Both the optimal predictor and its risk can be made explicit as follows.

**Proposition 3.** Let $p_j = \mathbb{P}(J_0 = j)$ for $j \in \{1, \dots, L\}$. Then the optimal linear predictor is parameterized by $\beta^\star = (b_1 v^{\star\top}, \dots, b_L v^{\star\top})^\top$, with

$$b_j = \frac{\gamma^2 p_j}{1 + p_j(\gamma^2 - 1)},$$

and its risk is

$$R(\beta^\star) = \varepsilon^2 + \gamma^2 - \gamma^4 \sum_{j=1}^L \frac{p_j^2}{1 + p_j(\gamma^2 - 1)}.$$

In particular, $R(\beta^\star) \ge \varepsilon^2 + \gamma^2 - \gamma^2(\gamma^2 + 1) \max_{j=1,\dots,L} p_j$.

This result calls for a few comments. If the number of tokens is $L = 1$, or if $J_0$ is a constant location (meaning that one $p_j$ equals 1 while the others equal 0), then the learning problem $(\mathrm{P}_{\mathrm{learn}})$ reduces to standard linear regression. In this case, $R(\beta^\star) = \varepsilon^2$, and the linear predictor $(X_1, \dots, X_L) \mapsto (X_1^\top, \dots, X_L^\top)\beta^\star$ achieves the Bayes risk. At the other end of the spectrum, when $J_0$ is uniform over $\{1, \dots, L\}$, the risk of the linear predictor simplifies to $R(\beta^\star) = \varepsilon^2 + \gamma^2 - \frac{\gamma^4}{\gamma^2 + L - 1}$. When $L \to \infty$, this risk tends to $\varepsilon^2 + \gamma^2$, that is, the performance of the null predictor. In other words, the optimal linear predictor performs no better than always predicting zero. More generally, by the lower bound above, this conclusion holds in any limit where $L \to \infty$ and $\max_j p_j \to 0$. This can be explained by the fact that the location of the token relevant for prediction is random, varying from sentence to sentence. Unable to leverage this latent information, the linear regressor spreads its weight over all tokens, resulting in poor prediction performance. This stands in sharp contrast to Corollary 2, which shows that the oracle predictor $T^{(k^\star,v^\star)}_\lambda$ is able to account for the complexity of the task, at least asymptotically.

*Figure 3 plot: risk as a function of $d$ (from 0 to 1000) for $L = 5$, $L = 20$, and $L = 50$.*

Figure 3: Risk of the oracle predictor (Theorem 1, solid lines) and of the best linear predictor (Proposition 3, dashed lines), depending on the dimensions $d$ and $L$. The oracle predictor outperforms the linear predictor as $d$ grows. We take $\varepsilon^2 = 0$, $\gamma = 1/\sqrt{2}$, $\lambda = 1/d^{0.4}$, and all $p_j$ equal to $1/L$.

This is also illustrated by Figure 3, which compares the risks given by Theorem 1 and Proposition 3. Naturally, implementing the attention-based oracle predictor $T^{(k^\star,v^\star)}_\lambda$ requires knowledge of the parameters $k^\star$ and $v^\star$. Our goal in the next section is therefore to show that gradient descent is able to recover these parameters.

## 5 Gradient Descent Provably Recovers the Oracle Predictor

This section is devoted to the analysis of the optimization dynamics in $(k, v) \in (\mathbb{S}^{d-1})^2$ of the risk

$$R_\lambda(k, v) = \mathbb{E}\Big[\big(Y - T^{(k,v)}_\lambda(X)\big)^2\Big] = \mathbb{E}\Big[\big(Y - \mathrm{erf}(\lambda Xk)^\top Xv\big)^2\Big].$$

We emphasize that $R_\lambda(k, v)$ is a theoretical risk, which depends on the distribution of the pair $(X, Y)$ (defined in Section 2).
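Because $R_\lambda$ is a theoretical risk, it can be evaluated exactly at the oracle parameters through Theorem 1, and compared with the linear baseline of Proposition 3 in the spirit of Figure 3. The sketch below does this in plain Python. The only numerical ingredient is $\zeta$ from (7), which we compute with composite Simpson's rule over $\pm 10$ standard deviations; this quadrature is our own choice, not necessarily the one used for the paper's figures, and all function names are hypothetical.

```python
import math


def gaussian_expectation(f, mean, var, n=4000, width=10.0):
    """E[f(mean + G)] with G ~ N(0, var), by composite Simpson's rule
    on [mean - width*sd, mean + width*sd]; n must be even."""
    if var == 0.0:
        return f(mean)
    sd = math.sqrt(var)
    a, b = mean - width * sd, mean + width * sd
    h = (b - a) / n
    total = 0.0
    for i in range(n + 1):
        x = a + i * h
        w = 1 if i in (0, n) else (4 if i % 2 else 2)  # Simpson weights 1, 4, 2, ..., 4, 1
        density = math.exp(-((x - mean) ** 2) / (2 * var)) / (sd * math.sqrt(2 * math.pi))
        total += w * f(x) * density
    return total * h / 3


def zeta(t, var):
    """Eq. (7): zeta(t, gamma^2) = E[erf(t + G)^2] with G ~ N(0, gamma^2)."""
    return gaussian_expectation(lambda x: math.erf(x) ** 2, t, var)


def oracle_risk(d, L, lam, gamma2, eps2):
    """R_lambda(k*, v*) as given in Theorem 1."""
    mu = lam * math.sqrt(d / 2)
    mean_erf = math.erf(lam * math.sqrt(d) / math.sqrt(2 * (1 + 2 * lam**2 * gamma2)))
    return (gamma2 - 2 * gamma2 * mean_erf + gamma2 * zeta(mu, lam**2 * gamma2)
            + (L - 1) * zeta(0.0, lam**2) + eps2)


def linear_risk(p, gamma2, eps2):
    """R(beta*) of the best linear predictor (Proposition 3); p lists P(J0 = j)."""
    return eps2 + gamma2 - gamma2**2 * sum(pj**2 / (1 + pj * (gamma2 - 1)) for pj in p)


# Figure 3 setting: eps^2 = 0, gamma^2 = 1/2, lambda = d^(-0.4), uniform p_j = 1/L.
for d in (100, 1000):
    print(d, oracle_risk(d, 5, d ** -0.4, 0.5, 0.0), linear_risk([0.2] * 5, 0.5, 0.0))
```

For $L = 5$, our hand calculation with these formulas gives an oracle risk above $0.12$ at $d = 100$ and about $0.02$ at $d = 1000$, while the linear risk stays at $\gamma^2 - \gamma^4/(\gamma^2 + L - 1) \approx 0.44$ whatever $d$.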
In practice, an empirical version of this risk is minimized. As we show experimentally (see Figure 5), the stochastic dynamics induced by the empirical risk are qualitatively similar to the deterministic dynamics of the theoretical risk. In the remainder of the article, we focus on the theoretical risk for simplicity and leave the empirical risk for future research. Our optimization method is Projected (Riemannian) Gradient Descent (PGD), described below.

**Definition 1 (PGD).** Given an initialization $(k_0, v_0) \in (\mathbb{S}^{d-1})^2$, a step size $\alpha > 0$, and an inverse temperature sequence $(\lambda_t)_{t \ge 0}$, the sequence $(k_t, v_t)_{t \ge 0} \in (\mathbb{S}^{d-1})^2$ is recursively defined by

$$\begin{aligned} k_{t+1} &= \mathrm{Proj}_{\mathbb{S}^{d-1}}\big(k_t - \alpha (I_d - k_t k_t^\top) \nabla_k R_{\lambda_t}(k_t, v_t)\big), \\ v_{t+1} &= \mathrm{Proj}_{\mathbb{S}^{d-1}}\big(v_t - \alpha (I_d - v_t v_t^\top) \nabla_v R_{\lambda_t}(k_t, v_t)\big), \end{aligned} \tag{10}$$

where $\mathrm{Proj}_{\mathbb{S}^{d-1}} : x \mapsto x / \|x\|_2$ denotes the Euclidean projection onto the unit sphere of $\mathbb{R}^d$.

The operators $(I_d - k_t k_t^\top)$ and $(I_d - v_t v_t^\top)$ correspond to Riemannian gradient descent (Boumal, 2023, Section 4.3), meaning that we compute the gradient of the risk on the Riemannian manifold $(\mathbb{S}^{d-1})^2$. In other words, the gradient step is performed in the tangent space to the sphere at the current iterate. We take this precaution because, in the analysis of the dynamics, we rely on an expression of the risk (6) that is valid only on this manifold. In addition, this ensures that the subsequent projection onto $\mathbb{S}^{d-1}$ is always well defined, even though the sphere is a non-convex set: since the step $g$ is orthogonal to the current iterate, $\|k_t - \alpha g\|_2^2 = 1 + \alpha^2\|g\|_2^2 \ge 1$ (and similarly for $v_t$), so the iterates never hit the pathological cases $k = 0$ or $v = 0$.

Experimentally, we observe in Figure 4a that PGD recovers the oracle parameters $(k^\star, v^\star)$. Note that running the PGD iterates (10) requires computing the gradients $\nabla_k R_{\lambda_t}(k_t, v_t)$ and $\nabla_v R_{\lambda_t}(k_t, v_t)$, which is non-trivial a priori. A direct approach using Monte Carlo simulations would require a large number of samples to reduce the variance, which is computationally intractable, in particular in high dimension, and in any case gives only an approximate result.
Instead, we leverage our closed-form formula for $\mathcal{R}_\lambda$ from Theorem 1 to obtain exact values of the gradients (up to numerical errors). Interestingly, we also observe in Figure 4a that $v$ aligns with $v^\star$ much faster than $k$ aligns with $k^\star$. This is typical of two-timescale dynamics, a common framework in the analysis of non-convex learning dynamics (Heusel et al., 2017; Dagréou et al., 2022; Hong et al., 2023; Marion & Berthier, 2023; Berthier et al., 2024; Marion et al., 2024).

*Figure 4 plots: excess risk, alignment with the oracle parameters, trajectories of $(\kappa, \nu)$, and distance to the manifold $\mathcal{M}$, as functions of the PGD step (up to 100,000 steps in (a) and 20,000 steps in (b)).*

(a) From a random initialization on $(\mathbb{S}^{d-1})^2$. (b) From a random initialization on $\mathcal{M}$ (see Eq. (11)).

Figure 4: Convergence of PGD to the oracle parameters. Left: excess risk as a function of the number of steps. Middle left: alignment $|\kappa| = |k^\top k^\star|$ and $|\nu| = |v^\top v^\star|$ with the oracle parameters. Middle right: trajectories of $\kappa$ and $\nu$ in two repetitions of the experiment; each repetition corresponds to a color, and the trajectory starts in the middle and ends at a corner of the plot. Right: distance to the invariant manifold $\mathcal{M}$. In all plots except the middle right ones, the experiment is repeated 30 times with independent random initializations, and 95% percentile intervals are plotted (they are not visible when the variance is too small). Parameters are $d = 400$, $L = 10$, $\gamma = \sqrt{1/2}$, and (a) $\lambda_t = 1/(1 + 10^{-4} t)$, (b) $\lambda_t = 0.1$. More details are given in Appendix E.

Moving on to the mathematical study, even with the formula for $\mathcal{R}_\lambda$, a full analysis of the dynamics (10) is difficult. For instance, the dynamics (10) can be formulated in terms of the five variables of $\mathcal{R}_\lambda$, but one then needs to study a 5-dimensional, highly nonlinear dynamical system.
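The five variables of $\mathcal{R}_\lambda$ in Theorem 1 are plain inner products, so following a trajectory in these coordinates is cheap compared with storing the $d$-dimensional iterates. The following sketch computes them. The helper `distance_proxy_to_M` is hypothetical: it measures the size of the three cross terms $\theta$, $\eta$, $\rho$, which vanish exactly on the manifold $\mathcal{M}$ of (11), and it need not coincide with the distance plotted in Figure 4.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def reduced_coordinates(k, v, k_star, v_star):
    """The five scalars of Theorem 1: (kappa, nu, theta, eta, rho)."""
    kappa = dot(k, k_star)  # alignment of k with k*
    nu = dot(v, v_star)     # alignment of v with v*
    theta = dot(v, k_star)  # cross term v^T k*
    eta = dot(k, v_star)    # cross term k^T v*
    rho = dot(k, v)         # correlation between the two parameters
    return kappa, nu, theta, eta, rho


def distance_proxy_to_M(k, v, k_star, v_star):
    """Hypothetical proxy: Euclidean norm of the cross terms (theta, eta, rho)."""
    _, _, theta, eta, rho = reduced_coordinates(k, v, k_star, v_star)
    return math.sqrt(theta**2 + eta**2 + rho**2)
```

At the oracle $(k^\star, v^\star)$ the coordinates are $(1, 1, 0, 0, 0)$, the point at which Theorem 1 evaluates $\mathcal{R}_\lambda$, and the proxy is zero.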
In the following, we consider the case where the parameters are initialized on the submanifold of $(\mathbb{S}^{d-1})^2$

$$\mathcal{M} = \big\{(k, v) \in \mathbb{S}^{d-1} \times \mathbb{S}^{d-1} :\ k^\top v^\star = 0,\ v^\top k^\star = 0,\ k^\top v = 0\big\}. \tag{11}$$

We introduce this manifold, on the one hand, because of the observation in Figure 4a (right) that the dynamics converge to it even when initialized elsewhere on the sphere, and, on the other hand, because it allows us to reduce the problem to a lower-dimensional subspace and to simplify the expression of the risk. Due to Assumption 1, the oracle parameters $(k^\star, v^\star)$ clearly belong to $\mathcal{M}$. A first key property of this manifold is its invariance under the PGD dynamics.

**Lemma 4.** The manifold $\mathcal{M}$ is invariant under the PGD dynamics (10), in the sense that if $(k_t, v_t) \in \mathcal{M}$, then $(k_{t+1}, v_{t+1}) \in \mathcal{M}$.

This lemma shows that, if the initialization is taken on the manifold, it is enough to understand the dynamics on the manifold to conclude, and such an analysis on the manifold is tractable. This yields Theorem 5, our main result, which shows that the sequence $(k_t, v_t)_{t \ge 0}$ converges to the oracle values $(k^\star, v^\star)$ (up to a sign) as $t \to \infty$, for any small enough step size and a constant inverse temperature.

**Theorem 5.** Take a constant inverse temperature $\lambda_t \equiv \lambda > 0$. Then there exists $\bar\alpha > 0$ such that, for any step size $\alpha \le \bar\alpha$ and for a generic initialization $(k_0, v_0) \in \mathcal{M}$, $(k_t, v_t) \to \pm(k^\star, v^\star)$ as $t \to \infty$.

This result shows that, despite the non-convexity of the risk, the attention layer trained by PGD recovers the underlying structure of the problem. Convergence to $(-k^\star, -v^\star)$ rather than $(k^\star, v^\star)$ is not at all problematic, since $T^{(k^\star,v^\star)}_\lambda = T^{(-k^\star,-v^\star)}_\lambda$ because erf is odd. Furthermore, recovery is guaranteed for a generic initialization on $\mathcal{M}$, in the sense that the pathological pairs $(k_0, v_0) \in \mathcal{M}$ for which PGD fails to recover the oracle parameters form a set of Lebesgue measure zero. The results of Theorem 5 are illustrated in Figure 4b. We observe that, due to roundoff errors, the dynamics are not exactly on the manifold but stay very close to it.
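For readers who want to experiment, here is a minimal sketch of one step of the online stochastic variant of (10) used in Figure 5. It assumes that the theoretical risk is replaced by the average squared loss on a minibatch of $(X, y)$ pairs, whose Euclidean gradients are computed by hand using $\mathrm{erf}'(z) = \frac{2}{\sqrt{\pi}} e^{-z^2}$. The exact gradients of $R_\lambda$ used for Figure 4 rely on the closed form of Appendix C, which is not reproduced here. All function names are hypothetical.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def normalize(x):
    """Proj onto the unit sphere: x / ||x||_2."""
    n = math.sqrt(dot(x, x))
    return [xi / n for xi in x]


def tangent(g, p):
    """Riemannian gradient on the sphere: (I_d - p p^T) g, for a unit vector p."""
    c = dot(p, g)
    return [gi - c * pi for gi, pi in zip(g, p)]


def minibatch_gradients(batch, k, v, lam):
    """Euclidean gradients in k and v of mean (y - erf(lam X k)^T X v)^2 over the batch."""
    d = len(k)
    gk, gv = [0.0] * d, [0.0] * d
    for X, y in batch:
        a = [dot(x, k) for x in X]
        b = [dot(x, v) for x in X]
        resid = y - sum(math.erf(lam * al) * bl for al, bl in zip(a, b))
        for x, al, bl in zip(X, a, b):
            # d/dk of erf(lam a) b is lam * erf'(lam a) * b * x; d/dv is erf(lam a) * x.
            ck = -2 * resid * lam * (2 / math.sqrt(math.pi)) * math.exp(-(lam * al) ** 2) * bl
            cv = -2 * resid * math.erf(lam * al)
            for i in range(d):
                gk[i] += ck * x[i] / len(batch)
                gv[i] += cv * x[i] / len(batch)
    return gk, gv


def pgd_step(batch, k, v, lam, alpha):
    """One stochastic version of update (10): tangent-space step, then projection."""
    gk, gv = minibatch_gradients(batch, k, v, lam)
    k_new = [ki - alpha * ti for ki, ti in zip(k, tangent(gk, k))]
    v_new = [vi - alpha * ti for vi, ti in zip(v, tangent(gv, v))]
    return normalize(k_new), normalize(v_new)
```

Since each step is taken in the tangent space, the vector before normalization has norm at least 1, so `normalize` never divides by zero. Iterating `pgd_step` on fresh batches of size 5 drawn from $(\mathrm{P}_{\mathrm{learn}})$, with a schedule such as $\lambda_t = 2/(1 + 10^{-4} t)$, mirrors the setting of Figure 5; we make no claim that this sketch reproduces those curves.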
We emphasize that the manifold $\mathcal{M}$ depends on the unknown parameters $k^\star$ and $v^\star$, which makes it impractical to initialize directly on the manifold. If the initialization is not on $\mathcal{M}$, more diverse phenomena are possible. As already pointed out in Figure 4a, it is possible to obtain recovery of $(k^\star, v^\star)$ and convergence to the manifold $\mathcal{M}$ from a general initialization on the sphere. This suggests that our analysis on the manifold is relevant; completing the analysis for a general initialization is left for future work. However, we note that using a decreasing inverse temperature sequence $\lambda_t$ is crucial for the recovery of $(k^\star, v^\star)$ when initialized outside of $\mathcal{M}$. Indeed, in our experiments, an iteration-independent choice of $\lambda$ does not consistently lead to the recovery of $k^\star$ and $v^\star$ in this case (see Appendix E). This contrasts with the dynamics on the manifold, for which Theorem 5 holds with a constant $\lambda$. To investigate these behaviors, a fruitful direction would be to study the (local) stability of the manifold $\mathcal{M}$ for the PGD dynamics. If the manifold is indeed stable, one can hope to transfer the analysis on the manifold to dynamics initialized close to the manifold. Furthermore, recall that, in high dimension, random vectors on the sphere are close to being orthogonal. Thus, with high probability, a uniform initialization in $(\mathbb{S}^{d-1})^2$ falls in a neighborhood of the manifold $\mathcal{M}$, so that the local analysis should allow one to conclude.

The proof of the theorem relies on a detailed analysis of the dynamics of the PGD algorithm on the invariant manifold $\mathcal{M}$, in particular of the properties of its stationary points. These arguments, which lie at the intersection of dynamical systems and topology, are of independent interest. A key idea is to reduce the problem to a two-dimensional system depending only on $\kappa = k^\top k^\star$ and $\nu = v^\top v^\star$.

Finally, numerical experiments show that a full Transformer layer is able to solve the single-location regression task.
Similarly to our simplified predictor, the weights align with the oracle parameters $k^\star$ and $v^\star$. This supports the connection drawn in Section 3 between our predictor and attention layers. We refer to Appendix E for details and plots, as well as for experiments on multiple-location regression, a variant of single-location regression where the output depends on several tokens.

[Figure 5 plots (steps 0 to 200,000). Panels: excess risk, alignment with oracle parameters, and distance to the manifold $\mathcal{M}$, each as a function of the step.]

Figure 5: Convergence of online stochastic PGD to the oracle parameters from a random initialization on $(\mathbb{S}^{d-1})^2$. Left: Excess risk as a function of the number of steps. Middle left: Alignment $|\kappa| = |k^\top k^\star|$ and $|\nu| = |v^\top v^\star|$ with the oracle parameters. Middle right: Trajectories of $\kappa$ and $\nu$ in two repetitions of the experiment. Each repetition corresponds to a color; the trajectory starts in the middle and ends at a corner of the plot. Right: Distance to the invariant manifold $\mathcal{M}$. In all plots except the middle right one, the experiment is repeated 30 times with independent random initializations, and 95% percentile intervals are plotted. Parameters are $d = 80$, $L = 10$, $\gamma = \sqrt{1/2}$, $\lambda_t = 2/(1 + 10^{-4} t)$, and a batch size of 5. More details are given in Appendix E.

6 CONCLUSION

This paper introduced single-location regression, a novel statistical task where the relevant information in the input sequence is supported by a single token. We analyzed the statistical properties and optimization dynamics of a natural estimator for this task, which can be seen as a basic attention layer. We hope this work encourages further research into how Transformer architectures address sparsity and long-range dependencies while simultaneously constructing internal linear representations of their input, an aspect with significant implications for interpretability. Beyond NLP, potential applications include problems connected to sparse sequential modeling, such as anomaly detection in time series.
A natural extension of our framework is the case where relevant information is spread across a few input tokens rather than just one, which relates to multi-head attention. Future mathematical analyses should also consider general initialization schemes and stochastic dynamics. Our experiments (Figures 4a and 5, and Appendix E) yield encouraging results in all these directions.

ACKNOWLEDGMENTS

The authors thank Peter Bartlett, Linus Bleistein, Alex Damian, Spencer Frei, and Clément Mantoux for fruitful discussions and feedback. P.M. is supported by a Google PhD Fellowship.

REFERENCES

Armen Aghajanyan, Sonal Gupta, and Luke Zettlemoyer. Intrinsic dimensionality explains the effectiveness of language model fine-tuning. In C. Zong, F. Xia, W. Li, and R. Navigli (eds.), Proceedings of the 59th Annual Meeting of the Association for Computational Linguistics and the 11th International Joint Conference on Natural Language Processing (Volume 1: Long Papers), pp. 7319–7328. Association for Computational Linguistics, 2021.

Kwangjun Ahn, Xiang Cheng, Hadi Daneshmand, and Suvrit Sra. Transformers learn to implement preconditioned gradient descent for in-context learning. In A. Oh, T. Naumann, A. Globerson, K. Saenko, M. Hardt, and S. Levine (eds.), Advances in Neural Information Processing Systems, volume 36, pp. 45614–45650. Curran Associates, Inc., 2023.

Dzmitry Bahdanau, Kyunghyun Cho, and Yoshua Bengio. Neural machine translation by jointly learning to align and translate. In Y. Bengio and Y. LeCun (eds.), 3rd International Conference on Learning Representations, 2015.

Dzmitry Bahdanau, Jan Chorowski, Dmitriy Serdyuk, Philémon Brakel, and Yoshua Bengio. End-to-end attention-based large vocabulary speech recognition. In 2016 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP), pp. 4945–4949, 2016.

Raphaël Berthier, Andrea Montanari, and Kangjie Zhou. Learning time-scales in two-layers neural networks.
Foundations of Computational Mathematics, pp. 1–84, 2024.

Alberto Bietti, Vivien Cabannes, Diane Bouchacourt, Hervé Jégou, and Léon Bottou. Birth of a transformer: A memory viewpoint. In A. Oh, T. Naumann, A. Globerson, K. Saenko, M. Hardt, and S. Levine (eds.), Advances in Neural Information Processing Systems, volume 36, pp. 1560–1588. Curran Associates, Inc., 2023.

Tolga Bolukbasi, Kai-Wei Chang, James Zou, Venkatesh Saligrama, and Adam Kalai. Man is to computer programmer as woman is to homemaker? Debiasing word embeddings. In D. Lee, M. Sugiyama, U. von Luxburg, I. Guyon, and R. Garnett (eds.), Advances in Neural Information Processing Systems, volume 29, pp. 4356–4364. Curran Associates, Inc., 2016.

Tolga Bolukbasi, Adam Pearce, Ann Yuan, Andy Coenen, Emily Reif, Fernanda Viégas, and Martin Wattenberg. An interpretability illusion for BERT. arXiv:2104.07143, 2021.

Nicolas Boumal. An Introduction to Optimization on Smooth Manifolds. Cambridge University Press, Cambridge, 2023.

James Bradbury, Roy Frostig, Peter Hawkins, Matthew James Johnson, Chris Leary, Dougal Maclaurin, George Necula, Adam Paszke, Jake VanderPlas, Skye Wanderman-Milne, and Qiao Zhang. JAX: composable transformations of Python+NumPy programs, 2018. URL http://github.com/jax-ml/jax.

Sébastien Bubeck, Varun Chandrasekaran, Ronen Eldan, Johannes Gehrke, Eric Horvitz, Ece Kamar, Peter Lee, Yin Tat Lee, Yuanzhi Li, Scott Lundberg, et al. Sparks of artificial general intelligence: Early experiments with GPT-4. arXiv:2303.12712, 2023.

Collin Burns, Haotian Ye, Dan Klein, and Jacob Steinhardt. Discovering latent knowledge in language models without supervision. In The Eleventh International Conference on Learning Representations, 2023.

Rewon Child, Scott Gray, Alec Radford, and Ilya Sutskever. Generating long sequences with sparse transformers. arXiv:1904.10509, 2019.

Lénaïc Chizat, Edouard Oyallon, and Francis Bach. On lazy training in differentiable programming. In H. Wallach, H.
Larochelle, A. Beygelzimer, F. d'Alché-Buc, E. Fox, and R. Garnett (eds.), Advances in Neural Information Processing Systems, volume 32. Curran Associates, Inc., 2019.

Kevin Clark, Urvashi Khandelwal, Omer Levy, and Christopher D. Manning. What does BERT look at? An analysis of BERT's attention. In T. Linzen, G. Chrupała, Y. Belinkov, and D. Hupkes (eds.), Proceedings of the 2019 ACL Workshop BlackboxNLP: Analyzing and Interpreting Neural Networks for NLP, pp. 276–286. Association for Computational Linguistics, 2019.

Gonçalo M. Correia, Vlad Niculae, and André F.T. Martins. Adaptively sparse transformers. In K. Inui, J. Jiang, V. Ng, and X. Wan (eds.), Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing and the 9th International Joint Conference on Natural Language Processing (EMNLP-IJCNLP), pp. 2174–2184. Association for Computational Linguistics, 2019.

Mathieu Dagréou, Pierre Ablin, Samuel Vaiter, and Thomas Moreau. A framework for bilevel optimization that enables stochastic and global variance reduction algorithms. In S. Koyejo, S. Mohamed, A. Agarwal, D. Belgrave, K. Cho, and A. Oh (eds.), Advances in Neural Information Processing Systems, volume 35, pp. 26698–26710. Curran Associates, Inc., 2022.

Timothée Darcet, Maxime Oquab, Julien Mairal, and Piotr Bojanowski. Vision transformers need registers. In The Twelfth International Conference on Learning Representations, 2024.

Richard D. De Veaux. Mixtures of linear regressions. Computational Statistics & Data Analysis, 8:227–245, 1989.

Jacob Devlin, Ming-Wei Chang, Kenton Lee, and Kristina Toutanova. BERT: Pre-training of deep bidirectional transformers for language understanding. In Proceedings of the 2019 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies, Volume 1 (Long and Short Papers), pp. 4171–4186. Association for Computational Linguistics, 2019.

Mirko Farina, Usman Ahmad, Ahmad Taha, Hussein Younes, Yusuf Mesbah, Xiao Yu, and Witold Pedrycz. Sparsity in transformers: A systematic literature review. Neurocomputing, 582:127468, 2024.

Trevor Hastie, Robert Tibshirani, and Jerome Friedman. The Elements of Statistical Learning: Data Mining, Inference, and Prediction. Springer, New York, 2nd edition, 2009.

Bobby He and Thomas Hofmann. Simplifying transformer blocks. In The Twelfth International Conference on Learning Representations, 2024.

Bobby He, James Martens, Guodong Zhang, Aleksandar Botev, Andrew Brock, Samuel L Smith, and Yee Whye Teh. Deep transformers without shortcuts: Modifying self-attention for faithful signal propagation. In The Eleventh International Conference on Learning Representations, 2023.

Martin Heusel, Hubert Ramsauer, Thomas Unterthiner, Bernhard Nessler, and Sepp Hochreiter. GANs trained by a two time-scale update rule converge to a local Nash equilibrium. In I. Guyon, U. von Luxburg, S. Bengio, H. Wallach, R. Fergus, S. Vishwanathan, and R. Garnett (eds.), Advances in Neural Information Processing Systems, volume 30. Curran Associates, Inc., 2017.

Mingyi Hong, Hoi-To Wai, Zhaoran Wang, and Zhuoran Yang. A two-timescale stochastic algorithm framework for bilevel optimization: Complexity analysis and application to actor-critic. SIAM Journal on Optimization, 33:147–180, 2023.

Sebastian Jaszczur, Aakanksha Chowdhery, Afroz Mohiuddin, Lukasz Kaiser, Wojciech Gajewski, Henryk Michalewski, and Jonni Kanerva. Sparse is enough in scaling transformers. In M. Ranzato, A. Beygelzimer, Y. Dauphin, P.S. Liang, and J. Wortman Vaughan (eds.), Advances in Neural Information Processing Systems, volume 34, pp. 9895–9907. Curran Associates, Inc., 2021.

Samy Jelassi, Michael Sander, and Yuanzhi Li. Vision transformers provably learn spatial structure. In S. Koyejo, S. Mohamed, A. Agarwal, D. Belgrave, K. Cho, and A. Oh (eds.), Advances in Neural Information Processing Systems, volume 35, pp. 37822–37836.
Curran Associates, Inc., 2022.

Tokio Kajitsuka and Issei Sato. Are transformers with one layer self-attention using low-rank weight matrices universal approximators? In The Twelfth International Conference on Learning Representations, 2024.

Sehoon Kim, Sheng Shen, David Thorsley, Amir Gholami, Woosuk Kwon, Joseph Hassoun, and Kurt Keutzer. Learned token pruning for transformers. In Proceedings of the 28th ACM SIGKDD Conference on Knowledge Discovery and Data Mining, pp. 784–794. Association for Computing Machinery, 2022.

Kenneth Lange. Optimization. Springer, New York, 2nd edition, 2013.

Jean-François Le Gall. Measure Theory, Probability, and Stochastic Processes. Springer Cham, 2022.

Kenneth Li, Oam Patel, Fernanda Viégas, Hanspeter Pfister, and Martin Wattenberg. Inference-time intervention: Eliciting truthful answers from a language model. In A. Oh, T. Naumann, A. Globerson, K. Saenko, M. Hardt, and S. Levine (eds.), Advances in Neural Information Processing Systems, volume 36, pp. 41451–41530. Curran Associates, Inc., 2023a.

Yuchen Li, Yuanzhi Li, and Andrej Risteski. How do transformers learn topic structure: Towards a mechanistic understanding. In A. Krause, E. Brunskill, K. Cho, B. Engelhardt, S. Sabato, and J. Scarlett (eds.), Proceedings of the 40th International Conference on Machine Learning, volume 202 of Proceedings of Machine Learning Research, pp. 19689–19729. PMLR, 2023b.

Tianyang Lin, Yuxin Wang, Xiangyang Liu, and Xipeng Qiu. A survey of transformers. AI Open, 3:111–132, 2022.

Thang Luong, Hieu Pham, and Christopher D. Manning. Effective approaches to attention-based neural machine translation. In L. Màrquez, C. Callison-Burch, and J. Su (eds.), Proceedings of the 2015 Conference on Empirical Methods in Natural Language Processing, pp. 1412–1421. Association for Computational Linguistics, 2015.

Pierre Marion and Raphaël Berthier.
Leveraging the two timescale regime to demonstrate convergence of neural networks. In A. Oh, T. Naumann, A. Globerson, K. Saenko, M. Hardt, and S. Levine (eds.), Advances in Neural Information Processing Systems, volume 36, pp. 64996–65029. Curran Associates, Inc., 2023.

Pierre Marion, Anna Korba, Peter Bartlett, Mathieu Blondel, Valentin De Bortoli, Arnaud Doucet, Felipe Llinares-López, Courtney Paquette, and Quentin Berthet. Implicit diffusion: Efficient optimization through stochastic sampling. arXiv:2402.05468, 2024.

Andre Martins and Ramon Astudillo. From softmax to sparsemax: A sparse model of attention and multi-label classification. In M.F. Balcan and K.Q. Weinberger (eds.), Proceedings of The 33rd International Conference on Machine Learning, volume 48 of Proceedings of Machine Learning Research, pp. 1614–1623. PMLR, 2016.

Peter McCullagh and John A. Nelder. Generalized Linear Models. Chapman & Hall, London, 2nd edition, 1983.

Tomas Mikolov, Quoc V. Le, and Ilya Sutskever. Exploiting similarities among languages for machine translation. arXiv:1309.4168, 2013a.

Tomas Mikolov, Wen-tau Yih, and Geoffrey Zweig. Linguistic regularities in continuous space word representations. In L. Vanderwende, H. Daumé III, and K. Kirchhoff (eds.), Proceedings of the 2013 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies, pp. 746–751. Association for Computational Linguistics, 2013b.

Neel Nanda, Andrew Lee, and Martin Wattenberg. Emergent linear representations in world models of self-supervised sequence models. In Y. Belinkov, S. Hao, J. Jumelet, N. Kim, A. McCarthy, and H. Mohebbi (eds.), Proceedings of the 6th BlackboxNLP Workshop: Analyzing and Interpreting Neural Networks for NLP, pp. 16–30. Association for Computational Linguistics, 2023.

Eshaan Nichani, Alex Damian, and Jason D. Lee. How transformers learn causal structure with gradient descent. In International Conference on Machine Learning.
PMLR, 2024.

Vlad Niculae and Mathieu Blondel. A regularized framework for sparse and structured neural attention. In I. Guyon, U. von Luxburg, S. Bengio, H. Wallach, R. Fergus, S. Vishwanathan, and R. Garnett (eds.), Advances in Neural Information Processing Systems, volume 30. Curran Associates, Inc., 2017.

F. Pedregosa, G. Varoquaux, A. Gramfort, V. Michel, B. Thirion, O. Grisel, M. Blondel, P. Prettenhofer, R. Weiss, V. Dubourg, J. Vanderplas, A. Passos, D. Cournapeau, M. Brucher, M. Perrot, and E. Duchesnay. Scikit-learn: Machine learning in Python. Journal of Machine Learning Research, 12(85):2825–2830, 2011.

Mary Phuong and Marcus Hutter. Formal algorithms for transformers. arXiv:2207.09238, 2022.

William Press, Saul Teukolsky, William Vetterling, and Brian Flannery. Numerical Recipes: The Art of Scientific Computing. Cambridge University Press, Cambridge, 3rd edition, 2007.

Zhen Qin, Weixuan Sun, Hui Deng, Dongxu Li, Yunshen Wei, Baohong Lv, Junjie Yan, Lingpeng Kong, and Yiran Zhong. cosformer: Rethinking softmax in attention. arXiv:2202.08791, 2022.

Jason Ramapuram, Federico Danieli, Eeshan Dhekane, Floris Weers, Dan Busbridge, Pierre Ablin, Tatiana Likhomanenko, Jagrit Digani, Zijin Gu, Amitis Shidani, et al. Theory, analysis, and best practices for sigmoid self-attention. arXiv:2409.04431, 2024.

Kai Shen, Junliang Guo, Xu Tan, Siliang Tang, Rui Wang, and Jiang Bian. A study on ReLU and Softmax in Transformer. arXiv:2302.06461, 2023.

Michael Shub. Global Stability of Dynamical Systems. Springer, New York, 1987.

Youwei Song, Jiahai Wang, Tao Jiang, Zhiyue Liu, and Yanghui Rao. Attentional encoder network for targeted sentiment classification. arXiv:1902.09314, 2019.

Charles M. Stein. Estimation of the mean of a multivariate normal distribution. The Annals of Statistics, 9:1135–1151, 1981.

Chi Sun, Luyao Huang, and Xipeng Qiu.
Utilizing BERT for aspect-based sentiment analysis via constructing auxiliary sentence. In J. Burstein, C. Doran, and T. Solorio (eds.), Proceedings of the 2019 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies, Volume 1 (Long and Short Papers), pp. 380–385. Association for Computational Linguistics, 2019.

Yuandong Tian, Yiping Wang, Beidi Chen, and Simon S. Du. Scan and snap: Understanding training dynamics and token composition in 1-layer transformer. In A. Oh, T. Naumann, A. Globerson, K. Saenko, M. Hardt, and S. Levine (eds.), Advances in Neural Information Processing Systems, volume 36, pp. 71911–71947. Curran Associates, Inc., 2023.

Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Łukasz Kaiser, and Illia Polosukhin. Attention is all you need. In I. Guyon, U. von Luxburg, S. Bengio, H. Wallach, R. Fergus, S. Vishwanathan, and R. Garnett (eds.), Advances in Neural Information Processing Systems, volume 30, pp. 6000–6010. Curran Associates, Inc., 2017.

Johannes von Oswald, Eyvind Niklasson, Ettore Randazzo, Joao Sacramento, Alexander Mordvintsev, Andrey Zhmoginov, and Max Vladymyrov. Transformers learn in-context by gradient descent. In A. Krause, E. Brunskill, K. Cho, B. Engelhardt, S. Sabato, and J. Scarlett (eds.), Proceedings of the 40th International Conference on Machine Learning, volume 202 of Proceedings of Machine Learning Research, pp. 35151–35174. PMLR, 2023.

Zixuan Wang, Stanley Wei, Daniel Hsu, and Jason D. Lee. Transformers provably learn sparse token selection while fully-connected nets cannot. In International Conference on Machine Learning. PMLR, 2024.

Andrea W Wen-Yi and David Mimno. Hyperpolyglot LLMs: Cross-lingual interpretability in token embeddings. In H. Bouamor, J. Pino, and K. Bali (eds.), Proceedings of the 2023 Conference on Empirical Methods in Natural Language Processing, pp. 1124–1131.
Association for Computational Linguistics, 2023.

Thomas Wolf, Lysandre Debut, Victor Sanh, Julien Chaumond, Clement Delangue, Anthony Moi, Pierric Cistac, Tim Rault, Rémi Louf, Morgan Funtowicz, Joe Davison, Sam Shleifer, Patrick von Platen, Clara Ma, Yacine Jernite, Julien Plu, Canwen Xu, Teven Le Scao, Sylvain Gugger, Mariama Drame, Quentin Lhoest, and Alexander M. Rush. Transformers: State-of-the-art natural language processing. In Proceedings of the 2020 Conference on Empirical Methods in Natural Language Processing: System Demonstrations, pp. 38–45. Association for Computational Linguistics, 2020.

Mitchell Wortsman, Jaehoon Lee, Justin Gilmer, and Simon Kornblith. Replacing softmax with ReLU in vision transformers. arXiv:2309.08586, 2023.

Guangxuan Xiao, Yuandong Tian, Beidi Chen, Song Han, and Mike Lewis. Efficient streaming language models with attention sinks. In The Twelfth International Conference on Learning Representations, 2024.

Hu Xu, Bing Liu, Lei Shu, and Philip Yu. BERT post-training for review reading comprehension and aspect-based sentiment analysis. In J. Burstein, C. Doran, and T. Solorio (eds.), Proceedings of the 2019 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies, Volume 1 (Long and Short Papers), pp. 2324–2335. Association for Computational Linguistics, 2019.

Ge Yang, Edward Hu, Igor Babuschkin, Szymon Sidor, Xiaodong Liu, David Farhi, Nick Ryder, Jakub Pachocki, Weizhu Chen, and Jianfeng Gao. Tuning large neural networks via zero-shot hyperparameter transfer. Advances in Neural Information Processing Systems, 34:17084–17097, 2021.

Ruiqi Zhang, Spencer Frei, and Peter L. Bartlett. Trained transformers learn linear models in-context. Journal of Machine Learning Research, 25(49):1–55, 2024.

Organization of the Appendix. Section A presents the main steps of the proof of Theorem 5.
The intermediate results of this proof, as well as the other statements of the main text, are proven in Section B. Section C provides an expression of the risk $R$ beyond the manifold $\mathcal{M}$ that extends the one provided in Lemma 6 on $\mathcal{M}$. Section D gives some useful technical lemmas. Experimental details and additional results are in Section E. Finally, Section F discusses additional related models.

Notation. Throughout the Appendix, we consider a constant inverse temperature schedule $\lambda_t \equiv \lambda > 0$, as in Theorem 5. For this reason, it is not necessary to make explicit the dependence of $R_\lambda$ and $R^{<}_\lambda$ on $\lambda$, and we use the lighter notations $R$ and $R^{<}$ instead.

A OUTLINE OF THE PROOF OF THEOREM 5

This section outlines the essential steps of the proof of Theorem 5. For clarity, the proofs are deferred to Appendix B, except for the proof of Proposition 10.

Step 1: Invariant manifold & reparameterization. We first show that the risk $R(k, v)$ has a simpler expression when restricted to the manifold $\mathcal{M}$.

Lemma 6. The risk $R(k, v)$ restricted to $\mathcal{M}$ has the form
$$R(k, v) = \gamma^2 - 2\gamma^2\, v^\top v^\star\, \mathrm{erf}\Big(\lambda\sqrt{\tfrac{d}{2(1 + 2\lambda^2\gamma^2)}}\; k^\top k^\star\Big) + \gamma^2\,\zeta\Big(\lambda\sqrt{\tfrac{d}{2}}\; k^\top k^\star,\ \lambda^2\gamma^2\Big) + (L - 1)\,\zeta(0, \lambda^2) + \varepsilon^2,$$
where, for $t, \gamma \in \mathbb{R}$, $\zeta(t, \gamma^2) := \mathbb{E}[\mathrm{erf}^2(t + G)]$ with $G \sim \mathcal{N}(0, \gamma^2)$.

This expression has two main consequences. First, we use it to prove that the manifold $\mathcal{M}$ is invariant under PGD, as stated in Lemma 4. Second, we observe that the risk on the manifold depends on the variables $(k, v) \in \mathbb{S}^{d-1} \times \mathbb{S}^{d-1}$ only through the two scalar quantities $\kappa = k^\top k^\star$ and $\nu = v^\top v^\star$. This suggests studying the dynamics in terms of the reduced variables $(\kappa, \nu) \in [-1, 1]^2$. More precisely, in the following, we denote by $R^{<}$ the risk function $R$ reparameterized as a function of $(\kappa, \nu)$, i.e., we let
$$R^{<}(\kappa, \nu) = \gamma^2 - 2\gamma^2\nu\,\mathrm{erf}\Big(\lambda\sqrt{\tfrac{d}{2(1 + 2\lambda^2\gamma^2)}}\;\kappa\Big) + \gamma^2\,\zeta\Big(\lambda\sqrt{\tfrac{d}{2}}\;\kappa,\ \lambda^2\gamma^2\Big) + (L - 1)\,\zeta(0, \lambda^2) + \varepsilon^2.$$
Note that, with a slight abuse of notation, we use $R^{<}$ to denote both the function of five variables $(\kappa, \nu, \theta, \rho, \eta)$ (as in Theorem 1) and the function of only the first two variables $(\kappa, \nu)$.
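The closed form of Lemma 6 only requires the one-dimensional Gaussian expectation $\zeta$. The sketch below evaluates $R^{<}(\kappa, \nu)$ with a trapezoidal rule on $[-8, 8]$ (an arbitrary quadrature choice) and cross-checks it against a Monte Carlo estimate of the Gaussian representation of the risk on $\mathcal{M}$ obtained in Appendix B.1 (Eq. (16) with $\theta = \eta = \rho = 0$). It is not the paper's code; the function names and the parameter values ($d = 50$, $L = 10$, $\gamma^2 = 1/2$, $\lambda = 0.1$, $\varepsilon = 0.1$) are hypothetical.

```python
import math
import random


def gauss_expect(f, n=400, lim=8.0):
    """E[f(G)] for G ~ N(0, 1), by the trapezoidal rule on [-lim, lim]."""
    h = 2.0 * lim / n
    total = 0.0
    for i in range(n + 1):
        x = -lim + i * h
        weight = 0.5 if i in (0, n) else 1.0
        total += weight * f(x) * math.exp(-0.5 * x * x)
    return total * h / math.sqrt(2.0 * math.pi)


def zeta(t, s2):
    """zeta(t, s^2) = E[erf^2(t + s G)], G ~ N(0, 1); same as E[erf^2(t + G')], G' ~ N(0, s^2)."""
    s = math.sqrt(s2)
    return gauss_expect(lambda g: math.erf(t + s * g) ** 2)


def reduced_risk(kappa, nu, d, L, gamma, lam, eps):
    """Lemma 6: the risk on M as a function of (kappa, nu)."""
    a = lam * math.sqrt(d / (2.0 * (1.0 + 2.0 * lam ** 2 * gamma ** 2)))
    b = lam * math.sqrt(d / 2.0)
    return (gamma ** 2
            - 2.0 * gamma ** 2 * nu * math.erf(a * kappa)
            + gamma ** 2 * zeta(b * kappa, lam ** 2 * gamma ** 2)
            + (L - 1) * zeta(0.0, lam ** 2)
            + eps ** 2)


def mc_risk(kappa, nu, d, L, gamma, lam, eps, n=50_000, seed=0):
    """Monte Carlo estimate of the Gaussian representation (16) on M."""
    rng = random.Random(seed)
    nu_perp = math.sqrt(1.0 - nu ** 2)
    shift = math.sqrt(d / 2.0) * kappa
    acc = 0.0
    for _ in range(n):
        g_vstar = rng.gauss(0.0, 1.0)                        # G_1^{v*}
        g_v = nu * g_vstar + nu_perp * rng.gauss(0.0, 1.0)   # G_1^v, correlation nu
        g_k = rng.gauss(0.0, 1.0)                            # G_1^k, independent on M
        resid = gamma * g_vstar + rng.gauss(0.0, eps)        # signal + noise xi
        resid -= gamma * g_v * math.erf(lam * (shift + gamma * g_k))
        for _token in range(L - 1):                          # tokens l = 2, ..., L
            resid -= rng.gauss(0.0, 1.0) * math.erf(lam * rng.gauss(0.0, 1.0))
        acc += resid * resid
    return acc / n


if __name__ == "__main__":
    p = dict(d=50, L=10, gamma=math.sqrt(0.5), lam=0.1, eps=0.1)  # hypothetical values
    for point in [(1.0, 1.0), (0.0, 0.0), (1.0, -1.0)]:
        print(point, reduced_risk(*point, **p))
    print("Monte Carlo at (1, 1):", mc_risk(1.0, 1.0, **p))
```

For these values, the sketch orders the risks as $R^{<}(1,1) = R^{<}(-1,-1) < R^{<}(0,0) < R^{<}(1,-1)$, consistent with the classification of fixed points stated in Proposition 10, and the Monte Carlo estimate agrees with the closed form up to sampling error.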
There should be no confusion, as both functions coincide on the manifold $\mathcal{M}$, where $\theta = \rho = \eta = 0$. We also denote the corresponding PGD iterates under this reparameterization by $(\kappa_t, \nu_t) := (k_t^\top k^\star, v_t^\top v^\star)$. With this notation, the following lemma reformulates the PGD iterations as an autonomous discrete dynamical system in $(\kappa_t, \nu_t)$.

Lemma 7. When initialized on the manifold $\mathcal{M}$, the PGD iterations (10) can be reformulated as the autonomous discrete dynamical system
$$(\kappa_{t+1}, \nu_{t+1}) = g(\kappa_t, \nu_t), \qquad (12)$$
where the mapping $g : [-1, 1]^2 \to [-1, 1]^2$ is given by
$$g(\kappa, \nu) = \left( \frac{\kappa - \alpha\,(\partial_\kappa R^{<}(\kappa, \nu))(1 - \kappa^2)}{\sqrt{1 + \alpha^2\,(\partial_\kappa R^{<}(\kappa, \nu))^2\,(1 - \kappa^2)}},\ \frac{\nu - \alpha\,(\partial_\nu R^{<}(\kappa, \nu))(1 - \nu^2)}{\sqrt{1 + \alpha^2\,(\partial_\nu R^{<}(\kappa, \nu))^2\,(1 - \nu^2)}} \right). \qquad (13)$$

[Figure 6 plots: panels (a) and (b), both over $(\kappa, \nu) \in [-1, 1]^2$.]

Figure 6: Dynamics in $(\kappa, \nu)$ on the manifold $\mathcal{M}$. In (a), the fixed points of the dynamics are represented; the minimizers, saddle point, and maximizers are depicted in yellow, blue, and red, respectively. In (b), the vector field $(\kappa, \nu) \mapsto (\partial_\kappa R^{<}(\kappa, \nu)(1 - \kappa^2), \partial_\nu R^{<}(\kappa, \nu)(1 - \nu^2))$ is displayed (the colormap corresponds to the magnitude of the vector field).

Step 2: Analysis of the stationary points. Regarding the dynamics restricted to the invariant manifold $\mathcal{M}$, we can characterize the limit points of the PGD iterates as follows.

Proposition 8. For a sufficiently small step size $\alpha$ and for any $(k_0, v_0) \in \mathcal{M}$, the risk $R^{<}$ is decreasing along the PGD iterates. Furthermore, the distance between successive PGD iterates tends to zero, and, if $(\kappa, \nu)$ is an accumulation point of the sequence of iterates $(\kappa_t, \nu_t)_{t \geq 0}$, then
$$(1 - \kappa^2)\,\partial_\kappa R^{<}(\kappa, \nu) = 0 \quad \text{and} \quad (1 - \nu^2)\,\partial_\nu R^{<}(\kappa, \nu) = 0. \qquad (14)$$

We stress that the system of equations (14) characterizes the fixed points of the dynamics (12)–(13). We next solve this system.

Proposition 9. The points $(\kappa, \nu) \in [-1, 1]^2$ satisfying (14) are $(\kappa, \nu) = (\pm 1, \pm 1)$² and $(\kappa, \nu) = (0, 0)$.

The points $(\kappa, \nu) = \pm(1, 1)$ correspond to the situation where the variables $(k, v)$ are aligned (up to a common sign) with the targets $(k^\star, v^\star)$.
As the next proposition shows, these are the only global minima of $R^{<}$.

Proposition 10. The fixed points of the dynamics can be classified as follows:
(i) The points $(\kappa, \nu) = (-1, 1)$ and $(1, -1)$ are global maxima of $R^{<}$ on $[-1, 1]^2$.
(ii) The points $(\kappa, \nu) = (1, 1)$ and $(-1, -1)$ are global minima of $R^{<}$ on $[-1, 1]^2$.
(iii) The point $(\kappa, \nu) = (0, 0)$ is a saddle point of $R^{<}$ on $[-1, 1]^2$.

The fixed points of the dynamics, as well as the vector field $(\kappa, \nu) \mapsto (\partial_\kappa R^{<}(\kappa, \nu)(1 - \kappa^2), \partial_\nu R^{<}(\kappa, \nu)(1 - \nu^2))$, are displayed in Figure 6.

Step 3: Convergence to global minima. The convergence of the sequence of iterates $(\kappa_t, \nu_t)_{t \geq 0}$ to a global minimum is shown in two stages. First, we show that the iterates converge to one of the five fixed points described in Proposition 10.

² This notation is used to designate any extreme point of the square $[-1, 1]^2$, i.e., $(\kappa, \nu) = (1, 1)$, $(1, -1)$, $(-1, 1)$, and $(-1, -1)$.

Proposition 11. For a sufficiently small step size $\alpha$, the sequence of iterates $(\kappa_t, \nu_t)_{t \geq 0}$ converges to one of the five fixed points $\{(\pm 1, \pm 1), (0, 0)\}$.

Proof. According to Proposition 8, the distance between successive iterates $(\kappa_t, \nu_t)$ tends to zero. Therefore, the set of accumulation points of the sequence $(\kappa_t, \nu_t)_{t \geq 0}$ is connected (Lange, 2013, Proposition 12.4.1). Since there is a finite number of possible accumulation points (by Propositions 8 and 9), this connected set reduces to a single point, so the sequence has a unique accumulation point. Furthermore, the sequence lies in the compact set $[-1, 1]^2$. Thus, it converges, and its limit is one of the five fixed points.

It remains to precisely characterize the limit of the sequence $(\kappa_t, \nu_t)_{t \geq 0}$. To this aim, we begin by showing key properties of the gradient mapping $g$.

Proposition 12. For a sufficiently small step size $\alpha$, the mapping $g$ is a local diffeomorphism around $(0, 0)$, whose Jacobian matrix has one eigenvalue in $(0, 1)$ and one eigenvalue in $(1, \infty)$. Furthermore, it is injective on $[-1, 1]^2$, differentiable, and its Jacobian is non-degenerate.
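The reduced dynamics (12)–(13) are easy to iterate numerically. The sketch below implements $g$ using partial derivatives of $R^{<}$ obtained by differentiating the Lemma 6 formula by hand (a computation made here, not displayed in the paper), in particular $\partial_t \zeta(t, s^2) = \mathbb{E}[2\,\mathrm{erf}(t + sG)\,\mathrm{erf}'(t + sG)]$ with $G \sim \mathcal{N}(0, 1)$. It then estimates the Jacobian of $g$ at $(0, 0)$ by central finite differences and runs the iteration from a point close to the saddle. The model parameters and the step size $\alpha = 0.5$ are hypothetical choices: the propositions only cover a sufficiently small $\alpha$, and this value is not taken from the paper.

```python
import math

SQRT_PI = math.sqrt(math.pi)


def gauss_expect(f, n=400, lim=8.0):
    """E[f(G)] for G ~ N(0, 1), by the trapezoidal rule on [-lim, lim]."""
    h = 2.0 * lim / n
    total = 0.0
    for i in range(n + 1):
        x = -lim + i * h
        weight = 0.5 if i in (0, n) else 1.0
        total += weight * f(x) * math.exp(-0.5 * x * x)
    return total * h / math.sqrt(2.0 * math.pi)


def grad_reduced_risk(kappa, nu, d, gamma, lam):
    """(dR/dkappa, dR/dnu) for the Lemma 6 risk; L and eps only add constants."""
    a = lam * math.sqrt(d / (2.0 * (1.0 + 2.0 * lam ** 2 * gamma ** 2)))
    b = lam * math.sqrt(d / 2.0)
    s = lam * gamma
    t = b * kappa
    # d/dt zeta(t, s^2) = E[2 erf(t + sG) erf'(t + sG)], with erf'(u) = 2 exp(-u^2) / sqrt(pi)
    dzeta = gauss_expect(lambda g: 4.0 / SQRT_PI * math.erf(t + s * g) * math.exp(-(t + s * g) ** 2))
    d_kappa = (-2.0 * gamma ** 2 * nu * a * (2.0 / SQRT_PI) * math.exp(-(a * kappa) ** 2)
               + gamma ** 2 * b * dzeta)
    d_nu = -2.0 * gamma ** 2 * math.erf(a * kappa)
    return d_kappa, d_nu


def g_map(kappa, nu, alpha, **params):
    """The map g of Lemma 7, Eq. (13)."""
    dk, dn = grad_reduced_risk(kappa, nu, **params)
    new_kappa = (kappa - alpha * dk * (1 - kappa ** 2)) / math.sqrt(1 + alpha ** 2 * dk ** 2 * (1 - kappa ** 2))
    new_nu = (nu - alpha * dn * (1 - nu ** 2)) / math.sqrt(1 + alpha ** 2 * dn ** 2 * (1 - nu ** 2))
    return new_kappa, new_nu


def jacobian(f, x, y, h=1e-5):
    """Central finite-difference Jacobian of f: R^2 -> R^2 at (x, y)."""
    fxp, fxm = f(x + h, y), f(x - h, y)
    fyp, fym = f(x, y + h), f(x, y - h)
    return [[(fxp[0] - fxm[0]) / (2 * h), (fyp[0] - fym[0]) / (2 * h)],
            [(fxp[1] - fxm[1]) / (2 * h), (fyp[1] - fym[1]) / (2 * h)]]


def eigvals_2x2(m):
    """Eigenvalues of a 2x2 matrix, assuming they are real (here Dg(0, 0) = I - alpha * Hessian)."""
    tr = m[0][0] + m[1][1]
    det = m[0][0] * m[1][1] - m[0][1] * m[1][0]
    disc = math.sqrt(max(tr * tr / 4.0 - det, 0.0))
    return tr / 2.0 - disc, tr / 2.0 + disc


if __name__ == "__main__":
    params = dict(d=50, gamma=math.sqrt(0.5), lam=0.1)  # hypothetical values
    alpha = 0.5
    J = jacobian(lambda x, y: g_map(x, y, alpha, **params), 0.0, 0.0)
    print("eigenvalues of Dg(0, 0):", eigvals_2x2(J))
    kappa, nu = 0.1, 0.05
    for _ in range(300):
        kappa, nu = g_map(kappa, nu, alpha, **params)
    print("after 300 steps:", kappa, nu)
```

With these values, the estimated Jacobian at the saddle $(0, 0)$ has one eigenvalue in $(0, 1)$ and one larger than $1$, matching the property stated in Proposition 12, and the iterates leave the saddle and approach the minimizer $(1, 1)$.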
These properties enable us to apply the Center-Stable Manifold theorem (Shub, 1987, Theorem III.7), a tool from dynamical systems theory, to deduce the next proposition.

Proposition 13. For a sufficiently small step size $\alpha$, the set of initializations such that the sequence $(\kappa_t, \nu_t)_{t \geq 0}$ converges to $(-1, 1)$, $(1, -1)$, or $(0, 0)$ has Lebesgue measure zero (with respect to the Lebesgue measure on the manifold $\mathcal{M}$).

Combining Propositions 11 and 13, we conclude that, provided the step size $\alpha$ is chosen small enough, the sequence $(\kappa_t, \nu_t)_{t \geq 0}$ almost surely converges to one of the minimizers, $(1, 1)$ or $(-1, -1)$. This convergence is almost sure with respect to the Lebesgue measure on the manifold $\mathcal{M}$. Indeed, Proposition 13 ensures that the pathological initializations converging towards a maximizer or a saddle point form a set of Lebesgue measure zero. This concludes the proof of Theorem 5.

The use of the Center-Stable Manifold theorem is crucial to our proof. Unfortunately, this tool does not provide quantitative rates of convergence. Obtaining a rate is challenging, as it would require quantifying the distance of the iterates to the saddle points of the risk (the dynamics are indeed slower near saddle points), which in turn requires other tools of analysis and potentially additional assumptions.

B PROOFS OF THE MAIN RESULTS

B.1 PROOF OF LEMMA 6 AND THEOREM 1

We recall the formula for the risk
$$R(k, v) = \mathbb{E}\Big[\Big(Y - \sum_{\ell=1}^{L} \mathrm{erf}(\lambda X_\ell^\top k)\, X_\ell^\top v\Big)^2\Big]$$
and the data model
$$Y = X_{J_0}^\top v^\star + \xi, \qquad X_{J_0} \sim \mathcal{N}\Big(\sqrt{\tfrac{d}{2}}\, k^\star, \gamma^2 I_d\Big), \qquad X_\ell \sim \mathcal{N}(0, I_d) \ \text{ for } \ell \neq J_0,$$
where $J_0$ is a random position in $\{1, \dots, L\}$ and $\xi \sim \mathcal{N}(0, \varepsilon^2)$. In the above expression for the risk, we can condition on the value of $J_0$. The conditional risk does not depend on $J_0$. Thus, in this section, we assume without loss of generality that $J_0 = 1$ almost surely:
$$R(k, v) = \mathbb{E}\Big[\Big(X_1^\top v^\star + \xi - \mathrm{erf}(\lambda X_1^\top k)\, X_1^\top v - \sum_{\ell=2}^{L} \mathrm{erf}(\lambda X_\ell^\top k)\, X_\ell^\top v\Big)^2\Big], \qquad (15)$$
where $X_1 \sim \mathcal{N}(\sqrt{d/2}\, k^\star, \gamma^2 I_d)$ and $X_\ell \sim \mathcal{N}(0, I_d)$ for $\ell \geq 2$.

We rewrite this quantity in terms of multivariate standard Gaussian random variables.
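Writing $X_1 = \sqrt{d/2}\, k^\star + \gamma \widetilde{X}_1$ and using Assumption 1 (which gives $(k^\star)^\top v^\star = 0$), we get
$$R(k, v) = \mathbb{E}\Big[\Big(\gamma \widetilde{X}_1^\top v^\star + \xi - \mathrm{erf}\Big(\lambda\Big(\sqrt{\tfrac{d}{2}}\, k^\top k^\star + \gamma \widetilde{X}_1^\top k\Big)\Big)\Big(\sqrt{\tfrac{d}{2}}\, (k^\star)^\top v + \gamma \widetilde{X}_1^\top v\Big) - \sum_{\ell=2}^{L} \mathrm{erf}(\lambda X_\ell^\top k)\, X_\ell^\top v\Big)^2\Big],$$
where $\widetilde{X}_1, X_2, \dots, X_L \sim \mathcal{N}(0, I_d)$. This can be formulated in terms of the five scalar quantities $\kappa = k^\top k^\star$, $\nu = v^\top v^\star$, $\theta = v^\top k^\star$, $\eta = k^\top v^\star$, and $\rho = k^\top v$. Indeed, we have
$$R(k, v) = R^{<}(\kappa, \nu, \theta, \eta, \rho) := \mathbb{E}\Big[\Big(\gamma G_1^{v^\star} + \xi - \Big(\sqrt{\tfrac{d}{2}}\,\theta + \gamma G_1^{v}\Big)\mathrm{erf}\Big(\lambda\Big(\sqrt{\tfrac{d}{2}}\,\kappa + \gamma G_1^{k}\Big)\Big) - \sum_{\ell=2}^{L} G_\ell^{v}\,\mathrm{erf}(\lambda G_\ell^{k})\Big)^2\Big], \qquad (16)$$
where
$$\begin{pmatrix} G_1^{v^\star} \\ G_1^{v} \\ G_1^{k} \end{pmatrix}, \dots, \begin{pmatrix} G_L^{v^\star} \\ G_L^{v} \\ G_L^{k} \end{pmatrix} \overset{\text{i.i.d.}}{\sim} \mathcal{N}\left(0, \begin{pmatrix} 1 & \nu & \eta \\ \nu & 1 & \rho \\ \eta & \rho & 1 \end{pmatrix}\right), \qquad \xi \sim \mathcal{N}(0, \varepsilon^2). \qquad (17)$$
This last expression only involves the five parameters $\kappa, \nu, \theta, \eta, \rho$, which appear either explicitly in the function or as parameters of the covariance of the random variables. This proves the first statement of Theorem 1. A closed-form formula for this expectation is given in Appendix C.

On the manifold $\mathcal{M}$ defined by $\theta = \eta = \rho = 0$, we can simplify the expressions (16)–(17):
$$R^{<}(\kappa, \nu, 0, 0, 0) = \mathbb{E}\Big[\Big(\gamma G_1^{v^\star} + \xi - \gamma G_1^{v}\,\mathrm{erf}\Big(\lambda\Big(\sqrt{\tfrac{d}{2}}\,\kappa + \gamma G_1^{k}\Big)\Big) - \sum_{\ell=2}^{L} G_\ell^{v}\,\mathrm{erf}(\lambda G_\ell^{k})\Big)^2\Big],$$
where $(G_1^{v^\star}, G_1^{v}), \dots, (G_L^{v^\star}, G_L^{v}) \overset{\text{i.i.d.}}{\sim} \mathcal{N}\big(0, \big(\begin{smallmatrix} 1 & \nu \\ \nu & 1 \end{smallmatrix}\big)\big)$, $G_1^{k}, \dots, G_L^{k} \overset{\text{i.i.d.}}{\sim} \mathcal{N}(0, 1)$, and $\xi \sim \mathcal{N}(0, \varepsilon^2)$ are independent. We first expand in $\xi$ and obtain
$$R^{<}(\kappa, \nu, 0, 0, 0) = \varepsilon^2 + \mathbb{E}\Big[\Big(\gamma G_1^{v^\star} - \gamma G_1^{v}\,\mathrm{erf}\Big(\lambda\Big(\sqrt{\tfrac{d}{2}}\,\kappa + \gamma G_1^{k}\Big)\Big) - \sum_{\ell=2}^{L} G_\ell^{v}\,\mathrm{erf}(\lambda G_\ell^{k})\Big)^2\Big].$$
We now expand the square, as follows:
$$\begin{aligned} R^{<}(\kappa, \nu, 0, 0, 0) = \varepsilon^2 &+ \gamma^2\,\mathbb{E}\big[(G_1^{v^\star})^2\big] - 2\gamma^2\,\mathbb{E}\Big[G_1^{v^\star} G_1^{v}\,\mathrm{erf}\Big(\lambda\Big(\sqrt{\tfrac{d}{2}}\,\kappa + \gamma G_1^{k}\Big)\Big)\Big] + \gamma^2\,\mathbb{E}\Big[(G_1^{v})^2\,\mathrm{erf}^2\Big(\lambda\Big(\sqrt{\tfrac{d}{2}}\,\kappa + \gamma G_1^{k}\Big)\Big)\Big] \\ &- 2\sum_{\ell=2}^{L} \gamma\,\mathbb{E}\Big[\Big(G_1^{v^\star} - G_1^{v}\,\mathrm{erf}\Big(\lambda\Big(\sqrt{\tfrac{d}{2}}\,\kappa + \gamma G_1^{k}\Big)\Big)\Big) G_\ell^{v}\,\mathrm{erf}(\lambda G_\ell^{k})\Big] + \sum_{\ell, m=2}^{L} \mathbb{E}\big[G_\ell^{v}\,\mathrm{erf}(\lambda G_\ell^{k})\, G_m^{v}\,\mathrm{erf}(\lambda G_m^{k})\big]. \end{aligned}$$
We address each term in this sum separately. Since $G_1^{v^\star} \sim \mathcal{N}(0, 1)$, $\gamma^2\,\mathbb{E}[(G_1^{v^\star})^2] = \gamma^2$. Since $(G_1^{v^\star}, G_1^{v}) \sim \mathcal{N}\big(0, \big(\begin{smallmatrix} 1 & \nu \\ \nu & 1 \end{smallmatrix}\big)\big)$ is independent of $G_1^{k} \sim \mathcal{N}(0, 1)$, we have
$$-2\gamma^2\,\mathbb{E}\Big[G_1^{v^\star} G_1^{v}\,\mathrm{erf}\Big(\lambda\Big(\sqrt{\tfrac{d}{2}}\,\kappa + \gamma G_1^{k}\Big)\Big)\Big] = -2\gamma^2\,\mathbb{E}\big[G_1^{v^\star} G_1^{v}\big]\,\mathbb{E}\Big[\mathrm{erf}\Big(\lambda\Big(\sqrt{\tfrac{d}{2}}\,\kappa + \gamma G_1^{k}\Big)\Big)\Big] = -2\gamma^2\nu\,\mathbb{E}\Big[\mathrm{erf}\Big(\lambda\Big(\sqrt{\tfrac{d}{2}}\,\kappa + \gamma G_1^{k}\Big)\Big)\Big].$$
Finally, using Lemma 18(ii), we obtain
$$-2\gamma^2\,\mathbb{E}\Big[G_1^{v^\star} G_1^{v}\,\mathrm{erf}\Big(\lambda\Big(\sqrt{\tfrac{d}{2}}\,\kappa + \gamma G_1^{k}\Big)\Big)\Big] = -2\gamma^2\nu\,\mathrm{erf}\Big(\lambda\sqrt{\tfrac{d}{2(1 + 2\lambda^2\gamma^2)}}\;\kappa\Big).$$
Since $G_1^{v}, G_1^{k} \overset{\text{i.i.d.}}{\sim} \mathcal{N}(0, 1)$, we have
$$\gamma^2\,\mathbb{E}\Big[(G_1^{v})^2\,\mathrm{erf}^2\Big(\lambda\Big(\sqrt{\tfrac{d}{2}}\,\kappa + \gamma G_1^{k}\Big)\Big)\Big] = \gamma^2\,\mathbb{E}\big[(G_1^{v})^2\big]\,\mathbb{E}\Big[\mathrm{erf}^2\Big(\lambda\Big(\sqrt{\tfrac{d}{2}}\,\kappa + \gamma G_1^{k}\Big)\Big)\Big] = \gamma^2\,\mathbb{E}\Big[\mathrm{erf}^2\Big(\lambda\Big(\sqrt{\tfrac{d}{2}}\,\kappa + \gamma G_1^{k}\Big)\Big)\Big].$$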
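Using the definition of $\zeta$ in Eq. (7), we therefore have
$$\gamma^2\,\mathbb{E}\Big[(G_1^{v})^2\,\mathrm{erf}^2\Big(\lambda\Big(\sqrt{\tfrac{d}{2}}\,\kappa + \gamma G_1^{k}\Big)\Big)\Big] = \gamma^2\,\zeta\Big(\lambda\sqrt{\tfrac{d}{2}}\;\kappa,\ \lambda^2\gamma^2\Big).$$
For $\ell = 2, \dots, L$, the variables $(G_1^{v^\star}, G_1^{v}, G_1^{k})$, $G_\ell^{v}$, and $G_\ell^{k}$ are independent. Thus
$$\mathbb{E}\Big[\Big(G_1^{v^\star} - G_1^{v}\,\mathrm{erf}\Big(\lambda\Big(\sqrt{\tfrac{d}{2}}\,\kappa + \gamma G_1^{k}\Big)\Big)\Big) G_\ell^{v}\,\mathrm{erf}(\lambda G_\ell^{k})\Big] = \mathbb{E}\Big[G_1^{v^\star} - G_1^{v}\,\mathrm{erf}\Big(\lambda\Big(\sqrt{\tfrac{d}{2}}\,\kappa + \gamma G_1^{k}\Big)\Big)\Big]\,\mathbb{E}\big[G_\ell^{v}\big]\,\mathbb{E}\big[\mathrm{erf}(\lambda G_\ell^{k})\big] = 0,$$
where in the last step we use $\mathbb{E}[G_\ell^{v}] = 0$. Finally, to tackle the last term, we address the cases $\ell \neq m$ and $\ell = m$ separately. If $\ell \neq m$, as $G_\ell^{v}$, $G_\ell^{k}$, $G_m^{v}$, and $G_m^{k}$ are independent, we have
$$\mathbb{E}\big[G_\ell^{v}\,\mathrm{erf}(\lambda G_\ell^{k})\, G_m^{v}\,\mathrm{erf}(\lambda G_m^{k})\big] = \mathbb{E}[G_\ell^{v}]\,\mathbb{E}[\mathrm{erf}(\lambda G_\ell^{k})]\,\mathbb{E}[G_m^{v}]\,\mathbb{E}[\mathrm{erf}(\lambda G_m^{k})] = 0.$$
If $\ell = m$, as $G_\ell^{v}, G_\ell^{k} \overset{\text{i.i.d.}}{\sim} \mathcal{N}(0, 1)$, we have
$$\mathbb{E}\big[(G_\ell^{v})^2\,\mathrm{erf}^2(\lambda G_\ell^{k})\big] = \mathbb{E}\big[(G_\ell^{v})^2\big]\,\mathbb{E}\big[\mathrm{erf}^2(\lambda G_\ell^{k})\big] = \zeta(0, \lambda^2).$$
Putting these computations together, we obtain
$$R^{<}(\kappa, \nu, 0, 0, 0) = \varepsilon^2 + \gamma^2 - 2\gamma^2\nu\,\mathrm{erf}\Big(\lambda\sqrt{\tfrac{d}{2(1 + 2\lambda^2\gamma^2)}}\;\kappa\Big) + \gamma^2\,\zeta\Big(\lambda\sqrt{\tfrac{d}{2}}\;\kappa,\ \lambda^2\gamma^2\Big) + (L - 1)\,\zeta(0, \lambda^2).$$
This proves Lemma 6. Taking $\kappa = \nu = 1$ proves Theorem 1.

B.2 PROOF OF COROLLARY 2

Recall that, according to Theorem 1,
$$R_\lambda(k^\star, v^\star) = \gamma^2 - 2\gamma^2\,\mathrm{erf}\Big(\lambda\sqrt{\tfrac{d}{2(1 + 2\lambda^2\gamma^2)}}\Big) + \gamma^2\,\zeta\Big(\lambda\sqrt{\tfrac{d}{2}},\ \lambda^2\gamma^2\Big) + (L - 1)\,\zeta(0, \lambda^2) + \varepsilon^2,$$
where, for $t, \gamma \in \mathbb{R}$, $\zeta(t, \gamma^2) := \mathbb{E}[\mathrm{erf}^2(t + \gamma G)]$ with $G \sim \mathcal{N}(0, 1)$. We compute the limit of each term separately. First, we have
$$\lambda\sqrt{\tfrac{d}{2(1 + 2\lambda^2\gamma^2)}} \to \infty. \qquad (18)$$
Therefore, the second term of $R_\lambda(k^\star, v^\star)$ tends to $-2\gamma^2$. To handle the third term, note by Jensen's inequality that
$$1 \geq \zeta\Big(\lambda\sqrt{\tfrac{d}{2}},\ \lambda^2\gamma^2\Big) = \mathbb{E}\Big[\mathrm{erf}^2\Big(\lambda\sqrt{\tfrac{d}{2}} + \lambda\gamma G\Big)\Big] \geq \mathbb{E}\Big[\mathrm{erf}\Big(\lambda\sqrt{\tfrac{d}{2}} + \lambda\gamma G\Big)\Big]^2.$$
Thus, by Lemma 18(ii),
$$1 \geq \zeta\Big(\lambda\sqrt{\tfrac{d}{2}},\ \lambda^2\gamma^2\Big) \geq \mathrm{erf}^2\Big(\lambda\sqrt{\tfrac{d}{2(1 + 2\lambda^2\gamma^2)}}\Big) \to 1,$$
where we used (18). Thus the third term of $R_\lambda(k^\star, v^\star)$ converges to $\gamma^2$. As for the fourth term, observe by Lemma 17 that $\mathrm{erf}^2(x) \leq \frac{4}{\pi} x^2$, hence
$$0 \leq \zeta(0, \lambda^2) \leq \frac{4}{\pi}\lambda^2\,\mathbb{E}[G^2] = \frac{4}{\pi}\lambda^2.$$
Since $\lambda^2 L \to 0$, we get $(L - 1)\,\zeta(0, \lambda^2) = O(\lambda^2 L) = o(1)$. Putting everything together, we obtain
$$R_\lambda(k^\star, v^\star) \xrightarrow[d \to \infty]{} \gamma^2 - 2\gamma^2 + \gamma^2 + 0 + \varepsilon^2 = \varepsilon^2.$$
Since we already know by (8) that the Bayes risk is lower-bounded by $\varepsilon^2$, this proves that the Bayes risk is asymptotically equal to $\varepsilon^2$, and that the oracle predictor is asymptotically Bayes optimal.

B.3 PROOF OF PROPOSITION 3

Let us first introduce a useful notation for the proof.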
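If $M$ is a block matrix, we denote by $M_{[ij]}$ its $(i,j)$-th block, and likewise, if $u$ is a block vector, we denote by $u_{[j]}$ its $j$-th block. Next, note that
$$\mathbb E[Y^2] = \varepsilon^2 + \mathbb E\big[((v^\star)^\top X_{J_0})^2\big] = \varepsilon^2 + \gamma^2\|v^\star\|_2^2 = \varepsilon^2 + \gamma^2,$$
since $\|v^\star\|_2 = 1$. Recall that
$$\beta^\star \in \operatorname*{arg\,min}_{\beta\in\mathbb R^{dL}} \mathbb E\big[(Y - (X_1^\top, \dots, X_L^\top)\beta)^2\big]$$
is the optimal linear predictor. The classical formula for linear regression shows that
$$\beta^\star = \mathbb E\left[\begin{pmatrix}X_1\\ \vdots\\ X_L\end{pmatrix}(X_1^\top, \dots, X_L^\top)\right]^{-1}\mathbb E\left[\begin{pmatrix}X_1\\ \vdots\\ X_L\end{pmatrix}\big((v^\star)^\top X_{J_0} + \xi\big)\right].$$
On the one hand, let $M = (X_1^\top, \dots, X_L^\top)^\top(X_1^\top, \dots, X_L^\top)$. Then $\mathbb E[M] = \mathbb E[\mathbb E[M\,|\,J_0]]$, and $\mathbb E[M\,|\,J_0]$ is a block-diagonal matrix, where, for $j, j' \in \{1, \dots, L\}$,
$$\mathbb E[M\,|\,J_0 = j]_{[j'j']} = \delta_{j\neq j'}\,I_d + \delta_{j=j'}\Big(\gamma^2I_d + \frac d2\,k^\star(k^\star)^\top\Big).$$
Hence $\mathbb E[M]$ is block-diagonal, with
$$\mathbb E[M]_{[j'j']} = (1-p_{j'})\,I_d + p_{j'}\Big(\gamma^2I_d + \frac d2\,k^\star(k^\star)^\top\Big) = I_d + p_{j'}(\gamma^2-1)\,I_d + p_{j'}\,\frac d2\,k^\star(k^\star)^\top.$$
On the other hand, let $u = (X_1^\top, \dots, X_L^\top)^\top\big((v^\star)^\top X_{J_0} + \xi\big)$. Then
$$\mathbb E[u] = \begin{pmatrix}p_1\big(\gamma^2I_d + \frac d2\,k^\star(k^\star)^\top\big)v^\star\\ \vdots\\ p_L\big(\gamma^2I_d + \frac d2\,k^\star(k^\star)^\top\big)v^\star\end{pmatrix} = \gamma^2\begin{pmatrix}p_1v^\star\\ \vdots\\ p_Lv^\star\end{pmatrix},$$
since, by Assumption 1, $(k^\star)^\top v^\star = 0$. Since $\mathbb E[M]$ is a block-diagonal matrix and $\mathbb E[u]$ is a block vector, we get by standard computation rules for block matrices
$$\beta^\star_{[j]} = \big(\mathbb E[M]^{-1}\mathbb E[u]\big)_{[j]} = \mathbb E[M]^{-1}_{[jj]}\,\mathbb E[u]_{[j]} = \Big(I_d + p_j(\gamma^2-1)\,I_d + p_j\,\frac d2\,k^\star(k^\star)^\top\Big)^{-1}\gamma^2p_j\,v^\star.$$
Recall the Sherman–Morrison formula (Press et al., 2007, Section 2.7.1), which states that for any vectors $u, w \in \mathbb R^d$, $(I_d + uu^\top)^{-1}w = \big(I_d - uu^\top/(1 + u^\top u)\big)w$. Applying this formula, after factoring out the scalar $1 + p_j(\gamma^2-1)$, with $u$ collinear to $k^\star$ and $w = v^\star$ orthogonal to it, we obtain
$$\beta^\star_{[j]} = \big(1 + p_j(\gamma^2-1)\big)^{-1}\gamma^2p_j\,v^\star = \frac{\gamma^2p_j}{1 + p_j(\gamma^2-1)}\,v^\star,$$
which shows the first formula of the proposition. Finally, the risk associated with the optimal linear predictor $(X_1, \dots, X_L) \mapsto (X_1^\top, \dots, X_L^\top)\beta^\star$ is given by
$$R(\beta^\star) = \mathbb E[Y^2] - \mathbb E\big[Y(X_1^\top, \dots, X_L^\top)\beta^\star\big] = \varepsilon^2 + \gamma^2 - \gamma^2\big(p_1(v^\star)^\top, \dots, p_L(v^\star)^\top\big)\begin{pmatrix}\beta^\star_{[1]}\\ \vdots\\ \beta^\star_{[L]}\end{pmatrix} = \varepsilon^2 + \gamma^2 - \gamma^4\sum_{j=1}^L\frac{p_j^2}{1 + p_j(\gamma^2-1)}. \qquad (19)$$
This shows the formula for the risk given in the Proposition. To obtain the last bound, observe that, if $\gamma^2 \ge 1$, we have $1 + p_j(\gamma^2-1) \ge 1$. If $\gamma^2 \le 1$, since $p_j \le 1$, we have $1 + p_j(\gamma^2-1) \ge 1 + (\gamma^2-1) = \gamma^2$. Thus we obtain $1 + p_j(\gamma^2-1) \ge \min(1, \gamma^2)$.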
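Therefore,
$$R(\beta^\star) \ge \varepsilon^2 + \gamma^2 - \max(\gamma^4,\gamma^2)\sum_{j=1}^Lp_j^2 \ge \varepsilon^2 + \gamma^2 - \max(\gamma^4,\gamma^2)\Big(\sum_{j=1}^Lp_j\Big)\max_{j=1,\dots,L}p_j \ge \varepsilon^2 + \gamma^2 - (\gamma^4 + \gamma^2)\max_{j=1,\dots,L}p_j = \varepsilon^2 + \gamma^2 - \gamma^2(\gamma^2+1)\max_{j=1,\dots,L}p_j.$$
When all $p_j$ are equal to $1/L$, all terms in the sum are equal, and Eq. (19) simplifies to
$$R(\beta^\star) = \varepsilon^2 + \gamma^2 - L\gamma^4\,\frac{1/L^2}{1 + \frac1L(\gamma^2-1)} = \varepsilon^2 + \gamma^2 - \frac{\gamma^4}{L + \gamma^2 - 1}.$$

B.4 PROOF OF LEMMA 4

As a first step in the proof, we prove the next lemma, which is the key towards the invariance property we are aiming at: it shows that, for a point on the manifold $\mathcal M$ (defined by $\theta = \eta = \rho = 0$), the gradient of the risk does not push the point outside of the manifold. Its proof leverages the expression of the risk as a function of five parameters derived in the previous section.

Lemma 14. At any point $(\kappa,\nu,\theta,\eta,\rho)$ such that $\theta = \eta = \rho = 0$, we have $\partial_\theta\widetilde R = \partial_\eta\widetilde R = \partial_\rho\widetilde R = 0$.

Proof. We use Eq. (16)–(17) and change signs in the square function:
$$\begin{aligned}\widetilde R(\kappa,\nu,\theta,\eta,\rho) &= \mathbb E\Big[\Big(\gamma G^{v^\star}_1 + \xi - \Big(\sqrt{\tfrac d2}\,\theta + \gamma G^v_1\Big)\mathrm{erf}\Big(\lambda\Big(\sqrt{\tfrac d2}\,\kappa + \gamma G^k_1\Big)\Big) - \sum_{\ell=2}^L G^v_\ell\,\mathrm{erf}(\lambda G^k_\ell)\Big)^2\Big]\\ &= \mathbb E\Big[\Big(\gamma(-G^{v^\star}_1) - \xi - \Big(\sqrt{\tfrac d2}\,(-\theta) + \gamma(-G^v_1)\Big)\mathrm{erf}\Big(\lambda\Big(\sqrt{\tfrac d2}\,\kappa + \gamma G^k_1\Big)\Big) - \sum_{\ell=2}^L(-G^v_\ell)\,\mathrm{erf}(\lambda G^k_\ell)\Big)^2\Big],\end{aligned}$$
where $(G^{v^\star}_\ell, G^v_\ell, G^k_\ell)_{1\le\ell\le L}$ are i.i.d. $\mathcal N\big(0,\big(\begin{smallmatrix}1&\nu&\eta\\ \nu&1&\rho\\ \eta&\rho&1\end{smallmatrix}\big)\big)$ and $\xi\sim\mathcal N(0,\varepsilon^2)$. Now, $(-G^{v^\star}_\ell, -G^v_\ell, G^k_\ell)_{1\le\ell\le L}$ are i.i.d. $\mathcal N\big(0,\big(\begin{smallmatrix}1&\nu&-\eta\\ \nu&1&-\rho\\ -\eta&-\rho&1\end{smallmatrix}\big)\big)$ and $-\xi\sim\mathcal N(0,\varepsilon^2)$. As a consequence,
$$\widetilde R(\kappa,\nu,\theta,\eta,\rho) = \widetilde R(\kappa,\nu,-\theta,-\eta,-\rho).$$
Taking the partial derivative in $\theta$, we are led to $\partial_\theta\widetilde R(\kappa,\nu,\theta,\eta,\rho) = -\partial_\theta\widetilde R(\kappa,\nu,-\theta,-\eta,-\rho)$. At a point such that $\theta = \eta = \rho = 0$, this gives $\partial_\theta\widetilde R(\kappa,\nu,0,0,0) = -\partial_\theta\widetilde R(\kappa,\nu,0,0,0)$, and thus $\partial_\theta\widetilde R(\kappa,\nu,0,0,0) = 0$. The proof for the other two derivatives $\partial_\eta\widetilde R$, $\partial_\rho\widetilde R$ is identical.

We now complete the proof of Lemma 4. By the chain rule for total derivatives applied to $R(k,v) = \widetilde R(\kappa,\nu,\theta,\eta,\rho)$, and then by Lemma 14, on the manifold $\mathcal M$, we have
$$\nabla_kR = (\partial_\kappa\widetilde R)\,k^\star + (\partial_\eta\widetilde R)\,v^\star + (\partial_\rho\widetilde R)\,v = (\partial_\kappa\widetilde R)\,k^\star, \qquad (20)$$
and, similarly,
$$\nabla_vR = (\partial_\nu\widetilde R)\,v^\star + (\partial_\theta\widetilde R)\,k^\star + (\partial_\rho\widetilde R)\,k = (\partial_\nu\widetilde R)\,v^\star. \qquad (21)$$
Recall the formulas for the PGD updates
$$k_{t+1} = \mathrm{Proj}_{S^{d-1}}\big(k_t - \alpha(I - k_tk_t^\top)\nabla_kR(k_t,v_t)\big) = \frac{k_t - \alpha(I - k_tk_t^\top)\nabla_kR(k_t,v_t)}{\|k_t - \alpha(I - k_tk_t^\top)\nabla_kR(k_t,v_t)\|_2},$$
$$v_{t+1} = \mathrm{Proj}_{S^{d-1}}\big(v_t - \alpha(I - v_tv_t^\top)\nabla_vR(k_t,v_t)\big) = \frac{v_t - \alpha(I - v_tv_t^\top)\nabla_vR(k_t,v_t)}{\|v_t - \alpha(I - v_tv_t^\top)\nabla_vR(k_t,v_t)\|_2}.$$
Let $c_k = \|k_t - \alpha(I - k_tk_t^\top)\nabla_kR(k_t,v_t)\|_2$ and $c_v = \|v_t - \alpha(I - v_tv_t^\top)\nabla_vR(k_t,v_t)\|_2$. Then, if $(k_t, v_t) \in \mathcal M$, using (20)–(21), $(v^\star)^\top k^\star = 0$, and $\eta_t = \theta_t = \rho_t = 0$,
$$(v^\star)^\top k_{t+1} = \frac{(v^\star)^\top k_t - \alpha(v^\star)^\top(I - k_tk_t^\top)\big(\partial_\kappa\widetilde R(\kappa_t,\nu_t)\big)k^\star}{c_k} = 0,$$
$$(k^\star)^\top v_{t+1} = \frac{(k^\star)^\top v_t - \alpha(k^\star)^\top(I - v_tv_t^\top)\big(\partial_\nu\widetilde R(\kappa_t,\nu_t)\big)v^\star}{c_v} = 0,$$
$$v_{t+1}^\top k_{t+1} = \frac{v_t^\top k_t - \alpha(\partial_\nu\widetilde R)\big((I - v_tv_t^\top)v^\star\big)^\top k_t - \alpha(\partial_\kappa\widetilde R)\big((I - k_tk_t^\top)k^\star\big)^\top v_t + \alpha^2(\partial_\kappa\widetilde R)(\partial_\nu\widetilde R)\big((I - k_tk_t^\top)k^\star\big)^\top(I - v_tv_t^\top)v^\star}{c_vc_k},$$
where we have omitted the dependence of $\partial_\kappa\widetilde R$ and $\partial_\nu\widetilde R$ in $(\kappa_t,\nu_t)$ in the last expression for the ease of readability. The first three terms of the numerator vanish, since $v_t^\top k_t = \rho_t = 0$, $((I - v_tv_t^\top)v^\star)^\top k_t = \eta_t - \nu_t\rho_t = 0$, and $((I - k_tk_t^\top)k^\star)^\top v_t = \theta_t - \kappa_t\rho_t = 0$. The last term is also equal to zero since
$$\big((I - k_tk_t^\top)k^\star\big)^\top(I - v_tv_t^\top)v^\star = (k^\star - \kappa_tk_t)^\top(v^\star - \nu_tv_t) = 0.$$
This shows that $(k_{t+1}, v_{t+1}) \in \mathcal M$.

B.5 PROOF OF LEMMA 7

By definition of the PGD iterates and by (20)–(21), one has
$$\kappa_{t+1} = k_{t+1}^\top k^\star = \frac{\kappa_t - \alpha\,\partial_\kappa\widetilde R(\kappa_t,\nu_t)\,(k^\star)^\top(I - k_tk_t^\top)k^\star}{\sqrt{1 + \alpha^2(\partial_\kappa\widetilde R)^2\,\|(I - k_tk_t^\top)k^\star\|_2^2}} = \frac{\kappa_t - \alpha(\partial_\kappa\widetilde R)(1-\kappa_t^2)}{\sqrt{1 + \alpha^2(\partial_\kappa\widetilde R)^2(1-\kappa_t^2)}},$$
$$\nu_{t+1} = v_{t+1}^\top v^\star = \frac{\nu_t - \alpha\,\partial_\nu\widetilde R(\kappa_t,\nu_t)\,(v^\star)^\top(I - v_tv_t^\top)v^\star}{\sqrt{1 + \alpha^2(\partial_\nu\widetilde R)^2\,\|(I - v_tv_t^\top)v^\star\|_2^2}} = \frac{\nu_t - \alpha(\partial_\nu\widetilde R)(1-\nu_t^2)}{\sqrt{1 + \alpha^2(\partial_\nu\widetilde R)^2(1-\nu_t^2)}},$$
where we have used the Pythagorean theorem and the idempotence of projection matrices for the denominator.

B.6 PROOF OF PROPOSITION 8

In this proof, $C$ denotes a constant that does not depend on the step $t$ nor on the step size $\alpha$, and which may vary from line to line. First note that the risk $\widetilde R$ is $C^\infty$ on the compact set $[-1,1]^2$. In particular, it is a $\Lambda$-smooth function for some $\Lambda > 0$, in the sense that its gradient is $\Lambda$-Lipschitz continuous. Thus
$$\widetilde R(\kappa_{t+1},\nu_{t+1}) \le \widetilde R(\kappa_t,\nu_t) + \nabla\widetilde R(\kappa_t,\nu_t)^\top\begin{pmatrix}\kappa_{t+1}-\kappa_t\\ \nu_{t+1}-\nu_t\end{pmatrix} + \frac\Lambda2\left\|\begin{pmatrix}\kappa_{t+1}-\kappa_t\\ \nu_{t+1}-\nu_t\end{pmatrix}\right\|_2^2,$$
that is,
$$\widetilde R(\kappa_{t+1},\nu_{t+1}) - \widetilde R(\kappa_t,\nu_t) \le (\partial_\kappa\widetilde R)(\kappa_{t+1}-\kappa_t) + (\partial_\nu\widetilde R)(\nu_{t+1}-\nu_t) + \frac\Lambda2\big((\kappa_{t+1}-\kappa_t)^2 + (\nu_{t+1}-\nu_t)^2\big). \qquad (22)$$
Our goal in the following computations is to derive an inequality of the form
$$\widetilde R(\kappa_{t+1},\nu_{t+1}) - \widetilde R(\kappa_t,\nu_t) \le -\alpha(\partial_\kappa\widetilde R)^2(1-\kappa_t^2) - \alpha(\partial_\nu\widetilde R)^2(1-\nu_t^2) + C\alpha^2(\partial_\kappa\widetilde R)^2(1-\kappa_t^2) + C\alpha^2(\partial_\nu\widetilde R)^2(1-\nu_t^2),$$
which shall give us a descent lemma for $\alpha$ small enough. To this aim, observe that, by definition of the iterates $(\kappa_t,\nu_t)$ given by (12)–(13), one has
$$\kappa_{t+1} - \kappa_t = \frac{\kappa_t - \alpha(\partial_\kappa\widetilde R)(1-\kappa_t^2)}{\sqrt{1 + \alpha^2(\partial_\kappa\widetilde R)^2(1-\kappa_t^2)}} - \kappa_t = -\alpha(\partial_\kappa\widetilde R)(1-\kappa_t^2) + \Big(\frac{1}{\sqrt{1 + \alpha^2(\partial_\kappa\widetilde R)^2(1-\kappa_t^2)}} - 1\Big)\big(\kappa_t - \alpha(\partial_\kappa\widetilde R)(1-\kappa_t^2)\big). \qquad (23)$$
As a consequence,
$$\begin{aligned}\big|\kappa_{t+1} - \kappa_t + \alpha(\partial_\kappa\widetilde R)(1-\kappa_t^2)\big| &= \Big|\frac{1}{\sqrt{1 + \alpha^2(\partial_\kappa\widetilde R)^2(1-\kappa_t^2)}} - 1\Big|\,\big|\kappa_t - \alpha(\partial_\kappa\widetilde R)(1-\kappa_t^2)\big|\\ &\le \alpha^2(\partial_\kappa\widetilde R)^2(1-\kappa_t^2)\,\big|\kappa_t - \alpha(\partial_\kappa\widetilde R)(1-\kappa_t^2)\big| \le C\alpha^2(\partial_\kappa\widetilde R)^2(1-\kappa_t^2),\end{aligned} \qquad (24)$$
where the second inequality holds by Lemma 16 and the last bound holds since the function $(\kappa,\nu) \mapsto |\kappa - \alpha\,\partial_\kappa\widetilde R(\kappa,\nu)(1-\kappa^2)|$ is uniformly bounded for all $\alpha \le 1$. This bound has two implications. First,
$$\begin{aligned}(\partial_\kappa\widetilde R)(\kappa_{t+1}-\kappa_t) + \alpha(\partial_\kappa\widetilde R)^2(1-\kappa_t^2) &= (\partial_\kappa\widetilde R)\big((\kappa_{t+1}-\kappa_t) + \alpha(\partial_\kappa\widetilde R)(1-\kappa_t^2)\big)\\ &\le |\partial_\kappa\widetilde R|\,\big|\kappa_{t+1} - \kappa_t + \alpha(\partial_\kappa\widetilde R)(1-\kappa_t^2)\big| \le C\alpha^2(\partial_\kappa\widetilde R)^2(1-\kappa_t^2),\end{aligned} \qquad (25)$$
where we use the fact that $|\partial_\kappa\widetilde R|$ is bounded, and the bound (24). Second, since the square function is Lipschitz on compact sets, we have
$$\big|(\kappa_{t+1}-\kappa_t)^2 - \big(\alpha(\partial_\kappa\widetilde R)(1-\kappa_t^2)\big)^2\big| \le C\alpha^2(\partial_\kappa\widetilde R)^2(1-\kappa_t^2).$$
Hence
$$(\kappa_{t+1}-\kappa_t)^2 \le \alpha^2(\partial_\kappa\widetilde R)^2(1-\kappa_t^2)^2 + C\alpha^2(\partial_\kappa\widetilde R)^2(1-\kappa_t^2) \le C\alpha^2(\partial_\kappa\widetilde R)^2(1-\kappa_t^2). \qquad (26)$$
We also obtain analogous bounds to (25)–(26) for $\nu$, namely
$$(\partial_\nu\widetilde R)(\nu_{t+1}-\nu_t) + \alpha(\partial_\nu\widetilde R)^2(1-\nu_t^2) \le C\alpha^2(\partial_\nu\widetilde R)^2(1-\nu_t^2), \qquad (27)$$
and
$$(\nu_{t+1}-\nu_t)^2 \le C\alpha^2(\partial_\nu\widetilde R)^2(1-\nu_t^2). \qquad (28)$$
Plugging the bounds (25)–(28) into Eq. (22), we obtain the desired inequality
$$\widetilde R(\kappa_{t+1},\nu_{t+1}) - \widetilde R(\kappa_t,\nu_t) \le -\alpha(\partial_\kappa\widetilde R)^2(1-\kappa_t^2) - \alpha(\partial_\nu\widetilde R)^2(1-\nu_t^2) + C\alpha^2(\partial_\kappa\widetilde R)^2(1-\kappa_t^2) + C\alpha^2(\partial_\nu\widetilde R)^2(1-\nu_t^2).$$
By choosing the step size $\alpha \le \frac{1}{2C}$, this ensures that
$$\widetilde R(\kappa_{t+1},\nu_{t+1}) - \widetilde R(\kappa_t,\nu_t) \le -\frac\alpha2(\partial_\kappa\widetilde R)^2(1-\kappa_t^2) - \frac\alpha2(\partial_\nu\widetilde R)^2(1-\nu_t^2).$$
This shows that the risk is decreasing along the PGD iterates. Next, introducing $\widetilde R_{\min} = \min_{(\kappa,\nu)\in[-1,1]^2}\widetilde R(\kappa,\nu)$ and using a telescoping sum, we have, for all $T \ge 0$,
$$\widetilde R(\kappa_0,\nu_0) - \widetilde R_{\min} \ge \widetilde R(\kappa_0,\nu_0) - \widetilde R(\kappa_T,\nu_T) \ge \frac\alpha2\sum_{t=0}^{T-1}\Big((\partial_\kappa\widetilde R)^2(1-\kappa_t^2) + (\partial_\nu\widetilde R)^2(1-\nu_t^2)\Big).$$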
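Since the left-hand side is finite, and the terms of the sum are nonnegative, we conclude that the series converges as $T \to \infty$. In particular, the generic term $(\partial_\kappa\widetilde R)^2(1-\kappa_t^2) + (\partial_\nu\widetilde R)^2(1-\nu_t^2)$ of the series converges to $0$ as $t \to \infty$. Therefore, the accumulation points $(\kappa_\infty, \nu_\infty)$ satisfy
$$\begin{cases}\partial_\kappa\widetilde R(\kappa_\infty,\nu_\infty) = 0 \ \text{ or } \ \kappa_\infty^2 = 1,\\ \partial_\nu\widetilde R(\kappa_\infty,\nu_\infty) = 0 \ \text{ or } \ \nu_\infty^2 = 1.\end{cases}$$
Inspecting identity (23), we observe that the convergence of the general term also implies $\kappa_{t+1} - \kappa_t \to 0$. We obtain similarly that $\nu_{t+1} - \nu_t \to 0$.

B.7 PROOF OF PROPOSITION 9

Recall that the risk in terms of $(\kappa,\nu)$ is given by
$$\widetilde R(\kappa,\nu) = \gamma^2 - 2\gamma^2\nu\,\mathrm{erf}\Big(\lambda\kappa\sqrt{\tfrac{d}{2(1+2\lambda^2\gamma^2)}}\Big) + \gamma^2\,\zeta\Big(\lambda\sqrt{\tfrac d2}\,\kappa,\ \lambda^2\gamma^2\Big) + (L-1)\,\zeta(0,\lambda^2) + \varepsilon^2.$$
Then, differentiating $\zeta$ with Lemma 18(vi), the gradients of $\widetilde R$ are given by
$$\partial_\kappa\widetilde R(\kappa,\nu) = 2\gamma^2\lambda\sqrt{\tfrac{d}{2(1+2\lambda^2\gamma^2)}}\;\mathrm{erf}'\Big(\lambda\kappa\sqrt{\tfrac{d}{2(1+2\lambda^2\gamma^2)}}\Big)\Big(-\nu + \mathrm{erf}\Big(\lambda\kappa\sqrt{\tfrac{d}{2(1+2\lambda^2\gamma^2)(1+4\lambda^2\gamma^2)}}\Big)\Big),$$
$$\partial_\nu\widetilde R(\kappa,\nu) = -2\gamma^2\,\mathrm{erf}\Big(\lambda\kappa\sqrt{\tfrac{d}{2(1+2\lambda^2\gamma^2)}}\Big).$$
Therefore, since $\mathrm{erf}' > 0$ and $\mathrm{erf}$ vanishes only at $0$, the solutions of the system (14) satisfy
$$\begin{cases}-\nu + \mathrm{erf}(c_1\kappa) = 0 \ \text{ or } \ \kappa = \pm1,\\ \kappa = 0 \ \text{ or } \ \nu = \pm1,\end{cases}$$
with $c_1 = \lambda\sqrt{\frac{d}{2(1+2\lambda^2\gamma^2)(1+4\lambda^2\gamma^2)}}$. The solutions of this system are $(\kappa,\nu) = (0,0)$ or $(\kappa,\nu) = (\pm1,\pm1)$.

B.8 PROOF OF PROPOSITION 10

Since $\widetilde R$ is a smooth function, the extrema of this function on $[-1,1]^2$ are either critical points (admitting null derivatives) or points on the boundary of the square $[-1,1]^2$. Starting with critical points, the only critical point in the interior is $(0,0)$, and it is a saddle point. Indeed, the Hessian of $\widetilde R$ at $(0,0)$ is
$$H_{\widetilde R}(0,0) = -\frac{4}{\sqrt\pi}\,\gamma^2\lambda\sqrt{\tfrac{d}{2(1+2\lambda^2\gamma^2)}}\;M, \qquad M = \begin{pmatrix}c&1\\1&0\end{pmatrix},$$
where $c = -\frac{2\lambda}{\sqrt\pi}\sqrt{\frac{d}{2(1+2\lambda^2\gamma^2)(1+4\lambda^2\gamma^2)}} < 0$. Then, as $\det(M) = -1$, the two eigenvalues of $H_{\widetilde R}(0,0)$ have opposite signs, and $(0,0)$ is thus a saddle point. The extrema of $\widetilde R$ must therefore be on the boundary of the square, which we examine next. For any $(\kappa,\nu) \in (-1,1)^2$, one has, by inspecting the signs of the gradients given in the proof of Proposition 9,
$$\widetilde R(1,1) < \widetilde R(\kappa,1) < \widetilde R(-1,1) \qquad\text{and}\qquad \widetilde R(1,1) < \widetilde R(1,\nu) < \widetilde R(1,-1).$$
This shows that the minimum of $\widetilde R$ on $\{(\kappa,1),\ \kappa\in[-1,1]\} \cup \{(1,\nu),\ \nu\in[-1,1]\}$ is reached at $(1,1)$, and the maximum is reached both at $(1,-1)$ and $(-1,1)$, since $\widetilde R$ is even.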
Using again the evenness of $\widetilde R$, we conclude that the extrema of $\widetilde R$ on the whole boundary of the square, and thus on the whole square, are the minimizers $(1,1)$ and $(-1,-1)$, and the maximizers $(1,-1)$ and $(-1,1)$.

B.9 PROOF OF PROPOSITION 12

We prove the statements of the proposition one by one.

The mapping $g$ is a local diffeomorphism around $(0,0)$, whose Jacobian matrix has one eigenvalue in $(0,1)$ and one eigenvalue in $(1,\infty)$. Consider the Taylor expansion of the first component $g(\kappa,\nu)_1$ of $g(\kappa,\nu)$. Since $\partial_\kappa\widetilde R(0,0) = 0$ and $\widetilde R$ is smooth, letting $x = (\kappa,\nu)$, we have $(\partial_\kappa\widetilde R(\kappa,\nu))^2 = O(\|x\|^2)$. Thus,
$$g(\kappa,\nu)_1 = \frac{\kappa - \alpha\,\partial_\kappa\widetilde R(\kappa,\nu)(1-\kappa^2)}{\sqrt{1 + \alpha^2(\partial_\kappa\widetilde R(\kappa,\nu))^2(1-\kappa^2)}} = \frac{\kappa - \alpha\,\partial_\kappa\widetilde R(\kappa,\nu)(1-\kappa^2)}{\sqrt{1 + O(\|x\|^2)}} = \big(\kappa - \alpha\,\partial_\kappa\widetilde R(\kappa,\nu)(1-\kappa^2)\big)\big(1 + O(\|x\|^2)\big) = \kappa - \alpha\,\partial_\kappa\widetilde R(\kappa,\nu) + O(\|x\|^2).$$
Proceeding similarly with the second component of $g$, we obtain that the Jacobian of $g$ at $(0,0)$ is given by
$$J_g(0,0) = I_2 - \alpha H_{\widetilde R}(0,0) = I_2 + \alpha\,\frac{4}{\sqrt\pi}\,\gamma^2\lambda\sqrt{\tfrac{d}{2(1+2\lambda^2\gamma^2)}}\;M, \qquad M = \begin{pmatrix}c&1\\1&0\end{pmatrix},$$
where $c = -\frac{2\lambda}{\sqrt\pi}\sqrt{\frac{d}{2(1+2\lambda^2\gamma^2)(1+4\lambda^2\gamma^2)}} < 0$. Since $\det(M) = -1$, $M$ has one positive and one negative eigenvalue, so one can choose $\alpha$ small enough so that one eigenvalue of $J_g(0,0)$ is strictly between $0$ and $1$ and the other one is strictly larger than $1$. Therefore, $J_g(0,0)$ is invertible, showing that $g$ is a local diffeomorphism around $(0,0)$.

The mapping $g$ is differentiable on $[-1,1]^2$, and its Jacobian is not degenerate. The mapping $g$ is clearly differentiable as a composition of differentiable functions. The more delicate part is to show that its Jacobian cannot be degenerate. To show this statement, observe first that, for $x \in [-1,1]^2$, we may write $g(x) = x + \alpha h(x)$, where the first component of $h$ is given by
$$h(\kappa,\nu)_1 = \frac1\alpha\big(g(\kappa,\nu)_1 - \kappa\big) = \underbrace{\frac\kappa\alpha\Big(\frac{1}{\sqrt{1 + \alpha^2(\partial_\kappa\widetilde R(\kappa,\nu))^2(1-\kappa^2)}} - 1\Big)}_{=:f^{(1)}_\alpha(\kappa,\nu)} - \underbrace{\frac{\partial_\kappa\widetilde R(\kappa,\nu)(1-\kappa^2)}{\sqrt{1 + \alpha^2(\partial_\kappa\widetilde R(\kappa,\nu))^2(1-\kappa^2)}}}_{=:f^{(2)}_\alpha(\kappa,\nu)}.$$
Let us prove that the gradient of $h(\kappa,\nu)_1$ is bounded uniformly over $\alpha \le 1$.
The uniform boundedness is clear for the gradient of $f^{(2)}_\alpha$, which writes as a composition of functions with uniformly bounded gradients for $\alpha \le 1$. Moving on to $f^{(1)}_\alpha$, and letting
$$\tilde g : [-1,1]\times[0,B] \to \mathbb R,\quad (a,b) \mapsto \frac a\alpha\Big(\frac{1}{\sqrt{1+\alpha^2b}} - 1\Big), \qquad B = \sup_{(\kappa,\nu)\in[-1,1]^2}(\partial_\kappa\widetilde R(\kappa,\nu))^2(1-\kappa^2),$$
we observe that $f^{(1)}_\alpha$ is the composition of $\tilde g$ with a smooth function independent of $\alpha$. In particular, it suffices to show the uniform boundedness of the gradient of $\tilde g$ to deduce that of $f^{(1)}_\alpha$. We further have, by Lemma 16, and for $\alpha \le 1$,
$$|\partial_a\tilde g(a,b)| = \frac1\alpha\Big|\frac{1}{\sqrt{1+\alpha^2b}} - 1\Big| \le \alpha b \le B \qquad\text{and}\qquad |\partial_b\tilde g(a,b)| = \frac{\alpha|a|}{2(1+\alpha^2b)^{3/2}} \le \frac12.$$
Therefore, the gradient of $h(\kappa,\nu)_1$ is bounded uniformly over $\alpha \le 1$. Proceeding similarly with the gradient of $h(\kappa,\nu)_2$, we obtain that the Jacobian of $h(\kappa,\nu)$ is uniformly bounded over $\alpha \le 1$. Recall now that $J_g(\kappa,\nu) = I_2 + \alpha J_h(\kappa,\nu)$. Therefore, taking $\alpha$ small enough, we obtain that the eigenvalues of $J_g$ have to be bounded away from zero.

The mapping $g$ is injective. The computation above shows that $h$ is $\beta$-Lipschitz continuous with $\beta$ independent of $\alpha$ (for $\alpha$ small enough). In particular, we can choose $\alpha$ such that $\alpha < 1/\beta$. Now, let $x \neq y \in [-1,1]^2$ be such that $g(x) = g(y)$. Then $x - y = -\alpha(h(x) - h(y))$, so that
$$\|x - y\| = \alpha\|h(x) - h(y)\| \le \alpha\beta\|x - y\| < \|x - y\|.$$
This is a contradiction, showing that $g$ is injective.

B.10 PROOF OF PROPOSITION 13

Recall that $(1,-1)$ and $(-1,1)$ are maxima of the risk $\widetilde R$ on $[-1,1]^2$ by Proposition 10, and that the value of the risk decreases along the iterates of PGD by Proposition 8. Thus the only possible way to converge to these points is to start the dynamics from them. The case of the point $(0,0)$ is more delicate. We apply the Center-Stable Manifold theorem (Shub, 1987, Theorem III.7) to $g$, which is a local diffeomorphism around $(0,0)$ by Proposition 12. This guarantees the existence of a local center-stable manifold $W^{cs}_{\mathrm{loc}}$, which verifies the following properties. First, its codimension is equal to the number of eigenvalues of $J_g(0,0)$ of magnitude larger than $1$, that is, $1$, by Proposition 12.
Hence it has Lebesgue measure zero. Second, there exists a neighborhood $B$ of $0$ such that $\bigcap_{t=0}^\infty g^{-t}(B) \subset W^{cs}_{\mathrm{loc}}$. Then, let $W^s$ be the set of all $x$ which converge to $(0,0)$ under the gradient map $g$, and take $x \in W^s$. Then there exists a $T$ such that $g^t(x) \in B$ for all $t \ge T$. This means that $g^T(x) \in \bigcap_{s=0}^\infty g^{-s}(B)$, and thus $g^T(x) \in W^{cs}_{\mathrm{loc}}$. So, $x \in g^{-T}(W^{cs}_{\mathrm{loc}})$. We have just shown that
$$W^s \subset \bigcup_{T\ge0} g^{-T}\big(W^{cs}_{\mathrm{loc}}\big).$$
Finally, we prove that the pre-image of sets of measure zero by $g^T$ has measure zero for any $T \ge 0$. This shall conclude the proof of the result since countable unions of sets of measure zero have measure zero. To show this, note that $g$ is injective by Proposition 12, and therefore $g^T$ is injective too. This allows us to define an inverse $g^{-T}$ of $g^T$ on the image of $g^T$, and the pre-image by $g^T$ of $W^{cs}_{\mathrm{loc}}$ is exactly the image by $g^{-T}$ of $W^{cs}_{\mathrm{loc}}$ (intersected with the domain of definition of $g^{-T}$). Furthermore, by Proposition 12, the Jacobian of $g^T$ is invertible. This guarantees that $g^{-T}$ is differentiable by the inverse function theorem. The conclusion follows by recalling that differentiable functions map sets of measure zero to sets of measure zero.

C EXPRESSION OF THE RISK BEYOND THE INVARIANT MANIFOLD

In this appendix, we provide an expression of the risk $R$ beyond the manifold $\mathcal M$ that extends the one provided in Lemma 6. This result is not needed to prove Theorem 5, and its proof is more involved than that of Lemma 6. However, we provide it since it might be relevant to follow-up works that would study the dynamics when not initialized on the invariant manifold $\mathcal M$. It is also useful for the numerical simulations (see Appendix E).

Proposition 15.
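We have the closed-form expression
$$\begin{aligned}\widetilde R_\lambda(\kappa,\nu,\theta,\eta,\rho) = {}&\varepsilon^2 + \gamma^2 - 2\gamma^2\nu\,\mathrm{erf}\Big(\lambda\kappa\sqrt{\tfrac{d}{2(1+2\lambda^2\gamma^2)}}\Big) - \frac{2\sqrt{d/2}\,\lambda\gamma^2\eta\theta}{\sqrt{1+2\lambda^2\gamma^2}}\,\mathrm{erf}'\Big(\lambda\kappa\sqrt{\tfrac{d}{2(1+2\lambda^2\gamma^2)}}\Big) - \frac{2\lambda^2\gamma^4\eta\rho}{1+2\lambda^2\gamma^2}\,\mathrm{erf}''\Big(\lambda\kappa\sqrt{\tfrac{d}{2(1+2\lambda^2\gamma^2)}}\Big)\\ &+ \Big(\frac d2\theta^2 + \gamma^2\Big)\zeta\Big(\lambda\sqrt{\tfrac d2}\,\kappa,\ \lambda^2\gamma^2\Big) + \frac{4\sqrt{d/2}\,\lambda\gamma^2\rho}{\sqrt{1+2\lambda^2\gamma^2}}\Big(\theta - \frac{\lambda^2\gamma^2\kappa\rho}{1+2\lambda^2\gamma^2}\Big)\mathrm{erf}'\Big(\lambda\kappa\sqrt{\tfrac{d}{2(1+2\lambda^2\gamma^2)}}\Big)\,\mathrm{erf}\Big(\lambda\kappa\sqrt{\tfrac{d}{2(1+2\lambda^2\gamma^2)(1+4\lambda^2\gamma^2)}}\Big)\\ &+ \frac{4\lambda^2\gamma^4\rho^2}{\sqrt\pi\,\sqrt{1+4\lambda^2\gamma^2}\,(1+2\lambda^2\gamma^2)}\,\mathrm{erf}'\Big(\lambda\kappa\sqrt{\tfrac{d}{1+4\lambda^2\gamma^2}}\Big) + (L-1)\Big(\zeta(0,\lambda^2) + \frac{8\lambda^2}{\pi\sqrt{1+4\lambda^2}\,(1+2\lambda^2)}\rho^2\Big)\\ &+ \frac{4\lambda^2}{(1+2\lambda^2)\pi}(L-1)(L-2)\rho^2 + \frac{4\lambda(L-1)\rho}{\sqrt{(1+2\lambda^2)\pi}}\Big(\sqrt{\tfrac d2}\,\theta\,\mathrm{erf}\Big(\lambda\kappa\sqrt{\tfrac{d}{2(1+2\lambda^2\gamma^2)}}\Big) + \frac{\lambda\gamma^2\rho}{\sqrt{1+2\lambda^2\gamma^2}}\,\mathrm{erf}'\Big(\lambda\kappa\sqrt{\tfrac{d}{2(1+2\lambda^2\gamma^2)}}\Big)\Big).\end{aligned}$$

Proof. We first recall the notations for the five scalar products that are used throughout this proof:
$$\nu = v^\top v^\star, \quad \kappa = k^\top k^\star, \quad \theta = v^\top k^\star, \quad \eta = k^\top v^\star, \quad \rho = k^\top v.$$
A first decomposition. We start back from the expression (15) obtained for the risk. By expanding in $\xi$, then expanding the square, we obtain
$$\begin{aligned}R(k,v) &= \mathbb E\Big[\Big(X_1^\top v^\star - \sum_{\ell=1}^LX_\ell^\top v\,\mathrm{erf}(\lambda X_\ell^\top k)\Big)^2\Big] + \varepsilon^2 = \mathbb E\Big[\Big(X_1^\top v^\star - X_1^\top v\,\mathrm{erf}(\lambda X_1^\top k) - \sum_{\ell=2}^LX_\ell^\top v\,\mathrm{erf}(\lambda X_\ell^\top k)\Big)^2\Big] + \varepsilon^2\\ &= \underbrace{\mathbb E\big[\big(X_1^\top v^\star - X_1^\top v\,\mathrm{erf}(\lambda X_1^\top k)\big)^2\big]}_{=:R_1} + \underbrace{\sum_{\ell=2}^L\mathbb E\big[\big(X_\ell^\top v\,\mathrm{erf}(\lambda X_\ell^\top k)\big)^2\big]}_{=:R_2} + \underbrace{\sum_{\ell\neq j\ge2}\mathbb E\big[X_\ell^\top v\,\mathrm{erf}(\lambda X_\ell^\top k)\,X_j^\top v\,\mathrm{erf}(\lambda X_j^\top k)\big]}_{=:R_3}\\ &\quad\underbrace{-\,2\sum_{\ell=2}^L\mathbb E\big[\big(X_1^\top v^\star - X_1^\top v\,\mathrm{erf}(\lambda X_1^\top k)\big)X_\ell^\top v\,\mathrm{erf}(\lambda X_\ell^\top k)\big]}_{=:R_4} + \varepsilon^2.\end{aligned}$$
Computation of $R_1$. By expanding the square,
$$\mathbb E\big[\big(X_1^\top v^\star - X_1^\top v\,\mathrm{erf}(\lambda X_1^\top k)\big)^2\big] = \mathbb E\big[(X_1^\top v^\star)^2\big] - 2\,\mathbb E\big[X_1^\top v^\star\,X_1^\top v\,\mathrm{erf}(\lambda X_1^\top k)\big] + \mathbb E\big[\big(X_1^\top v\,\mathrm{erf}(\lambda X_1^\top k)\big)^2\big].$$
These three terms are computed hereafter. First, we have
$$\mathbb E\big[(X_1^\top v^\star)^2\big] = \mathbb E\big[X_1^\top v^\star\big]^2 + \mathrm{Var}\big(X_1^\top v^\star\big) = \Big(\sqrt{\tfrac d2}\,(k^\star)^\top v^\star\Big)^2 + \gamma^2 = \gamma^2.$$
Next,
$$\mathbb E\big[X_1^\top v^\star\,X_1^\top v\,\mathrm{erf}(\lambda X_1^\top k)\big] = \mathbb E\Big[\Big(\sqrt{\tfrac d2}\,(k^\star)^\top v^\star + Z_1\Big)\Big(\sqrt{\tfrac d2}\,(k^\star)^\top v + Z_2\Big)\mathrm{erf}\Big(\lambda\sqrt{\tfrac d2}\,(k^\star)^\top k + \lambda Z_3\Big)\Big],$$
with
$$\begin{pmatrix}Z_1\\ Z_2\\ Z_3\end{pmatrix} \sim \mathcal N\left(0,\ \gamma^2\begin{pmatrix}(v^\star)^\top v^\star & (v^\star)^\top v & (v^\star)^\top k\\ v^\top v^\star & v^\top v & v^\top k\\ k^\top v^\star & k^\top v & k^\top k\end{pmatrix}\right) = \mathcal N\left(0,\ \gamma^2\begin{pmatrix}1&\nu&\eta\\ \nu&1&\rho\\ \eta&\rho&1\end{pmatrix}\right).$$
Recall the multivariate version of Stein's lemma (Stein, 1981), which states that, when $Z, G_1, \dots, G_p$ are centered and jointly Gaussian, and $\sigma : \mathbb R^p \to \mathbb R$,
$$\mathbb E\big[Z\sigma(G_1, \dots, G_p)\big] = \sum_{i=1}^p\mathrm{Cov}(Z, G_i)\,\mathbb E\big[\partial_i\sigma(G_1, \dots, G_p)\big].$$
Applying it with $Z = Z_1$ (recall that $(k^\star)^\top v^\star = 0$) and $(G_1, G_2) = (Z_2, Z_3)$, we get
$$\mathbb E\big[X_1^\top v^\star\,X_1^\top v\,\mathrm{erf}(\lambda X_1^\top k)\big] = \gamma^2\nu\,\mathrm{erf}\Big(\lambda\kappa\sqrt{\tfrac{d}{2(1+2\lambda^2\gamma^2)}}\Big) + \frac{\sqrt{d/2}\,\lambda\gamma^2\eta\theta}{\sqrt{1+2\lambda^2\gamma^2}}\,\mathrm{erf}'\Big(\lambda\kappa\sqrt{\tfrac{d}{2(1+2\lambda^2\gamma^2)}}\Big) + \frac{\lambda^2\gamma^4\eta\rho}{1+2\lambda^2\gamma^2}\,\mathrm{erf}''\Big(\lambda\kappa\sqrt{\tfrac{d}{2(1+2\lambda^2\gamma^2)}}\Big),$$
by using Lemma 18(i)–(iii).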
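The Gaussian identities of Lemma 18 drive all of these computations, so a quick numerical cross-check can catch transcription slips in the closed forms. The plain-Python sketch below compares items (i)–(iv) and (vi) with a trapezoidal approximation of the expectation over $G \sim \mathcal N(0, \gamma^2)$. The helper names, the truncation of the integration range, and the test values of $t$ and $\gamma^2$ are illustrative choices, not part of the paper.

```python
import math

SQRT_PI = math.sqrt(math.pi)


def erf_p(u):
    """First derivative of erf: erf'(u) = 2/sqrt(pi) * exp(-u^2) (Lemma 17)."""
    return 2.0 / SQRT_PI * math.exp(-u * u)


def erf_pp(u):
    """Second derivative of erf: erf''(u) = -2u erf'(u) (Lemma 17)."""
    return -2.0 * u * erf_p(u)


def gauss_expect(f, t, s2, half_width=12.0, n=4001):
    """Approximate E[f(t + G)], G ~ N(0, s2), by the trapezoidal rule.

    The integration variable is z = G / sqrt(s2), truncated to [-half_width, half_width].
    """
    s = math.sqrt(s2)
    h = 2.0 * half_width / (n - 1)
    total = 0.0
    for i in range(n):
        z = -half_width + i * h
        weight = 0.5 if i in (0, n - 1) else 1.0
        total += weight * f(t + s * z) * math.exp(-0.5 * z * z)
    return total * h / math.sqrt(2.0 * math.pi)


def lemma18_closed_forms(t, s2):
    """Right-hand sides of Lemma 18 (i)-(iv) and (vi), with gamma^2 = s2."""
    a, b = 1.0 + 2.0 * s2, 1.0 + 4.0 * s2
    return {
        "i": erf_p(t / math.sqrt(a)) / math.sqrt(a),
        "ii": math.erf(t / math.sqrt(a)),
        "iii": erf_pp(t / math.sqrt(a)) / a,
        "iv": 4.0 / (math.pi * math.sqrt(b)) * math.exp(-2.0 * t * t / b),
        "vi": erf_p(t / math.sqrt(a)) * math.erf(t / math.sqrt(a * b)) / math.sqrt(a),
    }


def lemma18_quadrature(t, s2):
    """Left-hand sides of the same identities, evaluated numerically."""
    return {
        "i": gauss_expect(erf_p, t, s2),
        "ii": gauss_expect(math.erf, t, s2),
        "iii": gauss_expect(erf_pp, t, s2),
        "iv": gauss_expect(lambda u: erf_p(u) ** 2, t, s2),
        "vi": gauss_expect(lambda u: math.erf(u) * erf_p(u), t, s2),
    }


if __name__ == "__main__":
    for t, s2 in [(0.0, 0.5), (0.7, 0.25), (-1.3, 2.0)]:
        exact, approx = lemma18_closed_forms(t, s2), lemma18_quadrature(t, s2)
        worst = max(abs(exact[key] - approx[key]) for key in exact)
        print(f"t={t:+.1f}, gamma^2={s2}: largest deviation {worst:.1e}")
```

Identity (v) is not checked separately, since (vi) is obtained from it. If the closed forms are transcribed correctly, the deviations should be at the level of floating-point round-off.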
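Finally, using again Stein's lemma and Lemma 18(iv)–(vi), the last term can be computed as follows. Write $t := \lambda\sqrt{d/2}\,\kappa$, so that $\lambda X_1^\top k = t + \lambda Z_3$ with $\lambda Z_3 \sim \mathcal N(0, \lambda^2\gamma^2)$, and $X_1^\top v = \sqrt{d/2}\,\theta + Z_2$. Then
$$\mathbb E\big[\big(X_1^\top v\,\mathrm{erf}(\lambda X_1^\top k)\big)^2\big] = \frac d2\theta^2\,\mathbb E\big[\mathrm{erf}^2(t + \lambda Z_3)\big] + 2\sqrt{\tfrac d2}\,\theta\,\mathbb E\big[Z_2\,\mathrm{erf}^2(t + \lambda Z_3)\big] + \mathbb E\big[Z_2^2\,\mathrm{erf}^2(t + \lambda Z_3)\big].$$
By Stein's lemma, $\mathbb E[Z_2\,\mathrm{erf}^2(t+\lambda Z_3)] = 2\lambda\gamma^2\rho\,\mathbb E[(\mathrm{erf}\,\mathrm{erf}')(t+\lambda Z_3)]$ and
$$\mathbb E\big[Z_2^2\,\mathrm{erf}^2(t + \lambda Z_3)\big] = \gamma^2\,\mathbb E\big[\mathrm{erf}^2(t + \lambda Z_3)\big] + 2\lambda^2\gamma^4\rho^2\Big(\mathbb E\big[(\mathrm{erf}')^2(t + \lambda Z_3)\big] + \mathbb E\big[(\mathrm{erf}\,\mathrm{erf}'')(t + \lambda Z_3)\big]\Big).$$
By Lemma 18(v) (with variance $\lambda^2\gamma^2$),
$$\mathbb E\big[(\mathrm{erf}')^2(t + \lambda Z_3)\big] + \mathbb E\big[(\mathrm{erf}\,\mathrm{erf}'')(t + \lambda Z_3)\big] = \frac{1}{1+2\lambda^2\gamma^2}\,\mathbb E\big[(\mathrm{erf}')^2(t + \lambda Z_3)\big] - \frac{2t}{1+2\lambda^2\gamma^2}\,\mathbb E\big[(\mathrm{erf}\,\mathrm{erf}')(t + \lambda Z_3)\big].$$
Collecting terms, and using $\mathbb E[\mathrm{erf}^2(t + \lambda Z_3)] = \zeta(t, \lambda^2\gamma^2)$ together with Lemma 18(iv) and (vi), we obtain
$$\begin{aligned}\mathbb E\big[\big(X_1^\top v\,\mathrm{erf}(\lambda X_1^\top k)\big)^2\big] = {}&\Big(\frac d2\theta^2 + \gamma^2\Big)\zeta\Big(\lambda\sqrt{\tfrac d2}\,\kappa,\ \lambda^2\gamma^2\Big) + \frac{4\sqrt{d/2}\,\lambda\gamma^2\rho}{\sqrt{1+2\lambda^2\gamma^2}}\Big(\theta - \frac{\lambda^2\gamma^2\kappa\rho}{1+2\lambda^2\gamma^2}\Big)\mathrm{erf}'\Big(\lambda\kappa\sqrt{\tfrac{d}{2(1+2\lambda^2\gamma^2)}}\Big)\,\mathrm{erf}\Big(\lambda\kappa\sqrt{\tfrac{d}{2(1+2\lambda^2\gamma^2)(1+4\lambda^2\gamma^2)}}\Big)\\ &+ \frac{4\lambda^2\gamma^4\rho^2}{\sqrt\pi\,\sqrt{1+4\lambda^2\gamma^2}\,(1+2\lambda^2\gamma^2)}\,\mathrm{erf}'\Big(\lambda\kappa\sqrt{\tfrac{d}{1+4\lambda^2\gamma^2}}\Big).\end{aligned}$$
Computation of $R_2$. We have
$$R_2 = \sum_{\ell=2}^L\mathbb E\big[\big(X_\ell^\top v\,\mathrm{erf}(\lambda X_\ell^\top k)\big)^2\big] = (L-1)\,\mathbb E\big[\big(X_2^\top v\,\mathrm{erf}(\lambda X_2^\top k)\big)^2\big].$$
Thus, using the previous calculations with $\gamma^2 = 1$, $\theta = 0$, and $\kappa = 0$, we obtain
$$R_2 = (L-1)\Big(\zeta(0,\lambda^2) + \frac{4\lambda^2}{\sqrt\pi\,\sqrt{4\lambda^2+1}\,(1+2\lambda^2)}\rho^2\,\mathrm{erf}'(0)\Big) = (L-1)\Big(\zeta(0,\lambda^2) + \frac{8\lambda^2}{\pi\sqrt{4\lambda^2+1}\,(1+2\lambda^2)}\rho^2\Big).$$
Computation of $R_3$. Regarding the cross-product terms, by independence of the $X_\ell$'s and Stein's lemma, one gets, for $\ell \neq j$,
$$\mathbb E\big[X_\ell^\top v\,\mathrm{erf}(\lambda X_\ell^\top k)\,X_j^\top v\,\mathrm{erf}(\lambda X_j^\top k)\big] = \mathbb E\big[X_\ell^\top v\,\mathrm{erf}(\lambda X_\ell^\top k)\big]\,\mathbb E\big[X_j^\top v\,\mathrm{erf}(\lambda X_j^\top k)\big] = C^2\rho^2,$$
with $C := \lambda\,\mathbb E[\mathrm{erf}'(\lambda X_\ell^\top k)] = 2\lambda/\sqrt{(1+2\lambda^2)\pi}$ by Lemma 18(i). This leads to
$$R_3 = \frac{4\lambda^2}{(1+2\lambda^2)\pi}(L-1)(L-2)\rho^2.$$
Computation of $R_4$. We have, again by independence and Stein's lemma,
$$\begin{aligned}\mathbb E\big[\big(X_1^\top v^\star - X_1^\top v\,\mathrm{erf}(\lambda X_1^\top k)\big)X_\ell^\top v\,\mathrm{erf}(\lambda X_\ell^\top k)\big] &= \mathbb E\big[X_1^\top v^\star - X_1^\top v\,\mathrm{erf}(\lambda X_1^\top k)\big]\,\mathbb E\big[X_\ell^\top v\,\mathrm{erf}(\lambda X_\ell^\top k)\big]\\ &= \Big(\sqrt{\tfrac d2}\,(k^\star)^\top v^\star - \mathbb E\big[X_1^\top v\,\mathrm{erf}(\lambda X_1^\top k)\big]\Big)C\rho = -\frac{2\lambda\rho}{\sqrt{(1+2\lambda^2)\pi}}\,\mathbb E\big[X_1^\top v\,\mathrm{erf}(\lambda X_1^\top k)\big].\end{aligned}$$
Note that, still using Stein's lemma,
$$\begin{aligned}\mathbb E\big[X_1^\top v\,\mathrm{erf}(\lambda X_1^\top k)\big] &= \sqrt{\tfrac d2}\,(k^\star)^\top v\;\mathbb E\big[\mathrm{erf}(\lambda X_1^\top k)\big] + \mathbb E\Big[\Big(X_1^\top v - \sqrt{\tfrac d2}\,(k^\star)^\top v\Big)\mathrm{erf}(\lambda X_1^\top k)\Big]\\ &= \sqrt{\tfrac d2}\,\theta\,\mathbb E\big[\mathrm{erf}(\lambda X_1^\top k)\big] + \lambda\,\mathrm{Cov}\big(X_1^\top v, X_1^\top k\big)\,\mathbb E\big[\mathrm{erf}'(\lambda X_1^\top k)\big]\\ &= \sqrt{\tfrac d2}\,\theta\,\mathrm{erf}\Big(\lambda\kappa\sqrt{\tfrac{d}{2(1+2\lambda^2\gamma^2)}}\Big) + \frac{\lambda\gamma^2\rho}{\sqrt{1+2\gamma^2\lambda^2}}\,\mathrm{erf}'\Big(\lambda\kappa\sqrt{\tfrac{d}{2(1+2\lambda^2\gamma^2)}}\Big),\end{aligned}$$
where we used that $\lambda X_1^\top k$ has the same law as $\lambda\sqrt{d/2}\,\kappa + G$ with $G \sim \mathcal N(0, \lambda^2\gamma^2)$, in combination with Lemma 18(i)–(ii). Thus
$$R_4 = \frac{4\lambda(L-1)\rho}{\sqrt{(1+2\lambda^2)\pi}}\Big(\sqrt{\tfrac d2}\,\theta\,\mathrm{erf}\Big(\lambda\kappa\sqrt{\tfrac{d}{2(1+2\lambda^2\gamma^2)}}\Big) + \frac{\lambda\gamma^2\rho}{\sqrt{1+2\gamma^2\lambda^2}}\,\mathrm{erf}'\Big(\lambda\kappa\sqrt{\tfrac{d}{2(1+2\lambda^2\gamma^2)}}\Big)\Big).$$
All in all.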
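Putting everything together, we obtain
$$\begin{aligned}R(k,v) = {}&\varepsilon^2 + \gamma^2 - 2\gamma^2\nu\,\mathrm{erf}\Big(\lambda\kappa\sqrt{\tfrac{d}{2(1+2\lambda^2\gamma^2)}}\Big) - \frac{2\sqrt{d/2}\,\lambda\gamma^2\eta\theta}{\sqrt{1+2\lambda^2\gamma^2}}\,\mathrm{erf}'\Big(\lambda\kappa\sqrt{\tfrac{d}{2(1+2\lambda^2\gamma^2)}}\Big) - \frac{2\lambda^2\gamma^4\eta\rho}{1+2\lambda^2\gamma^2}\,\mathrm{erf}''\Big(\lambda\kappa\sqrt{\tfrac{d}{2(1+2\lambda^2\gamma^2)}}\Big)\\ &+ \Big(\frac d2\theta^2 + \gamma^2\Big)\zeta\Big(\lambda\sqrt{\tfrac d2}\,\kappa,\ \lambda^2\gamma^2\Big) + \frac{4\sqrt{d/2}\,\lambda\gamma^2\rho}{\sqrt{1+2\lambda^2\gamma^2}}\Big(\theta - \frac{\lambda^2\gamma^2\kappa\rho}{1+2\lambda^2\gamma^2}\Big)\mathrm{erf}'\Big(\lambda\kappa\sqrt{\tfrac{d}{2(1+2\lambda^2\gamma^2)}}\Big)\,\mathrm{erf}\Big(\lambda\kappa\sqrt{\tfrac{d}{2(1+2\lambda^2\gamma^2)(1+4\lambda^2\gamma^2)}}\Big)\\ &+ \frac{4\lambda^2\gamma^4\rho^2}{\sqrt\pi\,\sqrt{1+4\lambda^2\gamma^2}\,(1+2\lambda^2\gamma^2)}\,\mathrm{erf}'\Big(\lambda\kappa\sqrt{\tfrac{d}{1+4\lambda^2\gamma^2}}\Big) + (L-1)\Big(\zeta(0,\lambda^2) + \frac{8\lambda^2}{\pi\sqrt{1+4\lambda^2}\,(1+2\lambda^2)}\rho^2\Big)\\ &+ \frac{4\lambda^2}{(1+2\lambda^2)\pi}(L-1)(L-2)\rho^2 + \frac{4\lambda(L-1)\rho}{\sqrt{(1+2\lambda^2)\pi}}\Big(\sqrt{\tfrac d2}\,\theta\,\mathrm{erf}\Big(\lambda\kappa\sqrt{\tfrac{d}{2(1+2\lambda^2\gamma^2)}}\Big) + \frac{\lambda\gamma^2\rho}{\sqrt{1+2\lambda^2\gamma^2}}\,\mathrm{erf}'\Big(\lambda\kappa\sqrt{\tfrac{d}{2(1+2\lambda^2\gamma^2)}}\Big)\Big).\end{aligned}$$
This concludes the proof.

D TECHNICAL RESULTS

This section gathers formulas that are useful in the proofs, in particular regarding expectations of functions of Gaussian random variables involving erf.

Lemma 16. For $u \ge 0$,
$$\Big|\frac{1}{\sqrt{1+u}} - 1\Big| \le u.$$
Proof. The argument of the absolute value is non-positive for $u \ge 0$, hence we need to show that $f(u) := 1 - \frac{1}{\sqrt{1+u}} - u$ is non-positive for $u \ge 0$. Just note that $f(0) = 0$ and
$$f'(u) = \frac{1}{2(1+u)^{3/2}} - 1 \le 0.$$

Recall that the erf function is defined on $\mathbb R$ as $\mathrm{erf}(u) = \frac{2}{\sqrt\pi}\int_0^u e^{-s^2}\,ds$.

Lemma 17 (Properties of the erf function). We have
$$\mathrm{erf}'(u) = \frac{2}{\sqrt\pi}e^{-u^2}, \qquad \mathrm{erf}''(u) = -\frac{4}{\sqrt\pi}u\,e^{-u^2} = -2u\,\mathrm{erf}'(u), \qquad |\mathrm{erf}(u)| \le \frac{2}{\sqrt\pi}|u|.$$
Proof. The first two statements are clear by usual differentiation rules. Regarding the last statement, since erf is an odd function, it is sufficient to prove the statement for $u \ge 0$. Moreover, erf is concave on $[0,\infty)$, so we get, for $u \ge 0$,
$$|\mathrm{erf}(u)| = |\mathrm{erf}(u) - \mathrm{erf}(0)| \le \mathrm{erf}'(0)\,u = \frac{2}{\sqrt\pi}u,$$
which concludes the proof.

Lemma 18. Let $G \sim \mathcal N(0,\gamma^2)$. For $t \in \mathbb R$,

(i) $\mathbb E[\mathrm{erf}'(t+G)] = \frac{1}{\sqrt{1+2\gamma^2}}\,\mathrm{erf}'\big(\frac{t}{\sqrt{1+2\gamma^2}}\big)$;

(ii) $\mathbb E[\mathrm{erf}(t+G)] = \mathrm{erf}\big(\frac{t}{\sqrt{1+2\gamma^2}}\big)$;

(iii) $\mathbb E[\mathrm{erf}''(t+G)] = \frac{1}{1+2\gamma^2}\,\mathrm{erf}''\big(\frac{t}{\sqrt{1+2\gamma^2}}\big)$;

(iv) $\mathbb E[(\mathrm{erf}')^2(t+G)] = \frac{4}{\pi\sqrt{1+4\gamma^2}}\exp\big(-\frac{2t^2}{1+4\gamma^2}\big)$;

(v) $(1+2\gamma^2)\,\mathbb E[\mathrm{erf}(t+G)\,\mathrm{erf}''(t+G)] = -2t\,\mathbb E[\mathrm{erf}(t+G)\,\mathrm{erf}'(t+G)] - 2\gamma^2\,\mathbb E[(\mathrm{erf}'(t+G))^2]$;

(vi) $\mathbb E[\mathrm{erf}(t+G)\,\mathrm{erf}'(t+G)] = \frac{1}{\sqrt{1+2\gamma^2}}\,\mathrm{erf}'\big(\frac{t}{\sqrt{1+2\gamma^2}}\big)\,\mathrm{erf}\big(\frac{t}{\sqrt{(1+4\gamma^2)(1+2\gamma^2)}}\big)$.

This lemma reveals the importance of choosing the erf function as the component-wise nonlinearity: there are closed-form formulas for the expectation of erf and its derivatives applied to Gaussian random variables. Extending the results to any nonlinear, bounded, increasing, differentiable activation function equal to 0 at 0 is an interesting next step.

Proof.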
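(i) By Lemma 17, letting $c := \frac{2\gamma^2}{1+2\gamma^2}$,
$$\begin{aligned}\mathbb E[\mathrm{erf}'(t+G)] &= \frac{2}{\sqrt\pi}\,\frac{1}{\sqrt{2\pi}\,\gamma}\int e^{-(t+g)^2}e^{-\frac{g^2}{2\gamma^2}}\,dg = \frac{2}{\sqrt\pi}\,\frac{1}{\sqrt{2\pi}\,\gamma}\int e^{-\frac{g^2}{c}}e^{-2gt}e^{-t^2}\,dg\\ &= \frac{2}{\sqrt\pi}\,\frac{1}{\sqrt{2\pi}\,\gamma}\int e^{-\frac{(g+ct)^2}{c} + ct^2 - t^2}\,dg = \frac{2}{\sqrt\pi}\,\frac{1}{\sqrt{2\pi}\,\gamma}\,e^{-t^2(1-c)}\underbrace{\int e^{-\frac{(g+ct)^2}{c}}\,dg}_{=\sqrt{\pi c}}\\ &= \frac{2}{\sqrt{\pi(1+2\gamma^2)}}\exp\Big(-\frac{t^2}{1+2\gamma^2}\Big) = \frac{1}{\sqrt{1+2\gamma^2}}\,\mathrm{erf}'\Big(\frac{t}{\sqrt{1+2\gamma^2}}\Big).\end{aligned}$$
(ii) By (i), and since $\mathbb E[\mathrm{erf}(G)] = 0$ ($G$ is symmetric and erf is odd),
$$\mathbb E[\mathrm{erf}(t+G)] = \int_0^t\mathbb E[\mathrm{erf}'(s+G)]\,ds = \int_0^t\frac{2}{\sqrt{\pi(1+2\gamma^2)}}\exp\Big(-\frac{s^2}{1+2\gamma^2}\Big)ds = \int_0^{t/\sqrt{1+2\gamma^2}}\frac{2}{\sqrt\pi}\exp(-u^2)\,du = \mathrm{erf}\Big(\frac{t}{\sqrt{1+2\gamma^2}}\Big).$$
(iii) By Lemma 17, and following the same steps as in (i),
$$\begin{aligned}\mathbb E[\mathrm{erf}''(t+G)] &= -\frac{4}{\sqrt\pi}\,\frac{1}{\sqrt{2\pi}\,\gamma}\int(t+g)\,e^{-(t+g)^2}e^{-\frac{g^2}{2\gamma^2}}\,dg = -\frac{4}{\sqrt\pi}\,\frac{1}{\sqrt{2\pi}\,\gamma}\,e^{-t^2(1-c)}\int(t+g)\,e^{-\frac{(g+ct)^2}{c}}\,dg\\ &= -\frac{4}{\sqrt\pi}\,\frac{1}{\sqrt{2\pi}\,\gamma}\,e^{-t^2(1-c)}\sqrt{\pi c}\;\mathbb E\Big[t + \mathcal N\Big(-ct, \frac c2\Big)\Big] = -\frac{4}{\sqrt{\pi(1+2\gamma^2)}}\,e^{-t^2(1-c)}(t - ct)\\ &= -\frac{4t}{\sqrt\pi\,(1+2\gamma^2)^{3/2}}\exp\Big(-\frac{t^2}{1+2\gamma^2}\Big) = \frac{1}{1+2\gamma^2}\,\mathrm{erf}''\Big(\frac{t}{\sqrt{1+2\gamma^2}}\Big).\end{aligned}$$
(iv) By Lemma 17, letting $\Gamma^2 := \gamma^2/(1+4\gamma^2)$,
$$\begin{aligned}\mathbb E[(\mathrm{erf}')^2(t+G)] &= \frac4\pi\,\frac{1}{\sqrt{2\pi}\,\gamma}\int e^{-2(t+g)^2}e^{-\frac{g^2}{2\gamma^2}}\,dg = \frac4\pi\,\frac{1}{\sqrt{2\pi}\,\gamma}\int e^{-\frac{g^2}{2\Gamma^2}}e^{-4gt}e^{-2t^2}\,dg\\ &= \frac4\pi\,\frac{1}{\sqrt{2\pi}\,\gamma}\int e^{-\frac{(g+4\Gamma^2t)^2}{2\Gamma^2}}e^{8\Gamma^2t^2}e^{-2t^2}\,dg = \frac4\pi\,\frac{1}{\sqrt{2\pi}\,\gamma}\,e^{-2t^2(1-4\Gamma^2)}\sqrt{2\pi}\,\Gamma = \frac{4}{\pi\sqrt{1+4\gamma^2}}\exp\Big(-\frac{2t^2}{1+4\gamma^2}\Big).\end{aligned}$$
(v) We use Lemma 17 and then Stein's lemma:
$$\begin{aligned}\mathbb E[\mathrm{erf}(t+G)\,\mathrm{erf}''(t+G)] &= -2\,\mathbb E\big[(t+G)\,\mathrm{erf}(t+G)\,\mathrm{erf}'(t+G)\big] = -2t\,\mathbb E\big[\mathrm{erf}(t+G)\,\mathrm{erf}'(t+G)\big] - 2\,\mathbb E\big[G\,\mathrm{erf}(t+G)\,\mathrm{erf}'(t+G)\big]\\ &= -2t\,\mathbb E\big[\mathrm{erf}(t+G)\,\mathrm{erf}'(t+G)\big] - 2\gamma^2\Big(\mathbb E\big[\mathrm{erf}'(t+G)^2\big] + \mathbb E\big[\mathrm{erf}(t+G)\,\mathrm{erf}''(t+G)\big]\Big).\end{aligned}$$
Reordering terms, this gives the desired equation.

(vi) We define the function $f(t) = \mathbb E[\mathrm{erf}(t+G)\,\mathrm{erf}'(t+G)]$. Then, using Lemma 18(v), we have
$$\begin{aligned}f'(t) &= \mathbb E\big[\mathrm{erf}'(t+G)^2\big] + \mathbb E\big[\mathrm{erf}(t+G)\,\mathrm{erf}''(t+G)\big] = \mathbb E\big[\mathrm{erf}'(t+G)^2\big] - \frac{2t}{1+2\gamma^2}\,\mathbb E\big[\mathrm{erf}(t+G)\,\mathrm{erf}'(t+G)\big] - \frac{2\gamma^2}{1+2\gamma^2}\,\mathbb E\big[(\mathrm{erf}'(t+G))^2\big]\\ &= \frac{1}{1+2\gamma^2}\,\mathbb E\big[(\mathrm{erf}'(t+G))^2\big] - \frac{2t}{1+2\gamma^2}\,f(t).\end{aligned}$$
We solve this differential equation by the method of variation of parameters: we have
$$\frac{d}{dt}\Big(f(t)\,e^{t^2/(1+2\gamma^2)}\Big) = \frac{1}{1+2\gamma^2}\,\mathbb E\big[(\mathrm{erf}'(t+G))^2\big]\,e^{t^2/(1+2\gamma^2)}.$$
We use Lemmas 17 and 18(iv):
$$\begin{aligned}\frac{d}{dt}\Big(f(t)\,e^{t^2/(1+2\gamma^2)}\Big) &= \frac{4}{\pi(1+2\gamma^2)\sqrt{1+4\gamma^2}}\,e^{-2t^2/(1+4\gamma^2)}\,e^{t^2/(1+2\gamma^2)} = \frac{4}{\pi(1+2\gamma^2)\sqrt{1+4\gamma^2}}\exp\Big(-\frac{t^2}{(1+2\gamma^2)(1+4\gamma^2)}\Big)\\ &= \frac{2}{\sqrt\pi}\,\frac{1}{(1+2\gamma^2)\sqrt{1+4\gamma^2}}\,\mathrm{erf}'\Big(\frac{t}{\sqrt{(1+2\gamma^2)(1+4\gamma^2)}}\Big).\end{aligned}$$
As the distribution of $G$ is symmetric and erf is an odd function, we have that $f(0) = \mathbb E[\mathrm{erf}(G)\,\mathrm{erf}'(G)] = 0$.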
Thus integrating the above derivative, we obtain
$$f(t)\,e^{t^2/(1+2\gamma^2)} = \frac{2}{\sqrt\pi}\,\frac{1}{(1+2\gamma^2)\sqrt{1+4\gamma^2}}\int_0^t\mathrm{erf}'\Big(\frac{s}{\sqrt{(1+2\gamma^2)(1+4\gamma^2)}}\Big)ds = \frac{2}{\sqrt\pi}\,\frac{1}{\sqrt{1+2\gamma^2}}\,\mathrm{erf}\Big(\frac{t}{\sqrt{(1+2\gamma^2)(1+4\gamma^2)}}\Big).$$
Using again Lemma 17, we obtain the claimed result:
$$f(t) = \frac{1}{\sqrt{1+2\gamma^2}}\,\mathrm{erf}'\Big(\frac{t}{\sqrt{1+2\gamma^2}}\Big)\,\mathrm{erf}\Big(\frac{t}{\sqrt{(1+2\gamma^2)(1+4\gamma^2)}}\Big).$$

E EXPERIMENTAL DETAILS AND ADDITIONAL RESULTS

Our code is available at https://github.com/PierreMarion23/single-location-regression. We use the Transformers (Wolf et al., 2020) and scikit-learn (Pedregosa et al., 2011) libraries for the experiment of Section 2, and JAX (Bradbury et al., 2018) for the experiment of Section 5. All experiments run in a short time (less than one hour) on a standard laptop.

E.1 EXPERIMENT OF SECTION 2 (NLP MOTIVATIONS)

Data generation. We use synthetically-generated data for this experiment. To create our train set, we generate sentences according to the patterns

The city is [SENTIMENT ADJ]. [PRONOUN] [COLOR ADJ] [ANIMAL] is [ADV] [SENTIMENT ADJ].
The city is [SENTIMENT ADJ]. [PRONOUN] [SENTIMENT ADJ] [ANIMAL] is [ADV] [COLOR ADJ].

where ADJ stands for adjective and ADV for adverb. Note that the difference between the two patterns is that the locations of the sentiment and of the color adjectives are swapped. Each element between brackets corresponds to a word, which can take a few different values that are chosen manually. For instance, some possible sentiment adjectives are nice, clean, cute, delightful, mean, dirty, or nasty. A possible value for some words is ∅, meaning that we remove the word from the sentence, which creates more variety in sentence length. By taking the Cartesian product over the possible values of each word in brackets, we generate in this way a large number of examples. Then, the label associated to each example depends solely on the sentiment adjective appearing in the second sentence.
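The Cartesian-product construction can be sketched in a few lines of Python. This is an illustration under stated assumptions, not the generator used for the paper: all word lists are hypothetical placeholders (the sentiment adjectives are a balanced subset of those quoted above), and the empty string plays the role of ∅. The labeling rule follows the description: only the sentiment adjective of the second sentence matters.

```python
from itertools import product

# Hypothetical word lists; "" plays the role of the empty word (the slot is dropped).
SENTIMENT = {"nice": +1, "clean": +1, "cute": +1,
             "mean": -1, "dirty": -1, "nasty": -1}
PRONOUNS = ["The", "My"]
COLORS = ["red", "blue"]
ANIMALS = ["cat", "dog"]
ADVERBS = ["very", "really", ""]


def render(*words):
    """Join the non-empty words with single spaces."""
    return " ".join(w for w in words if w)


def generate_examples():
    """Yield (sentence, label) pairs for the two training patterns.

    The first sentence carries a distractor sentiment adjective; the label is
    determined only by the sentiment adjective of the second sentence.
    """
    for city_adj, pron, color, animal, adv, adj in product(
            SENTIMENT, PRONOUNS, COLORS, ANIMALS, ADVERBS, SENTIMENT):
        first = f"The city is {city_adj}."
        # Pattern 1: color adjective before the noun, sentiment adjective last.
        yield f"{first} {render(pron, color, animal, 'is', adv, adj)}.", SENTIMENT[adj]
        # Pattern 2: the locations of the two adjectives are swapped.
        yield f"{first} {render(pron, adj, animal, 'is', adv, color)}.", SENTIMENT[adj]


if __name__ == "__main__":
    examples = list(generate_examples())
    print(len(examples))  # 1728 = 6 * 2 * 2 * 2 * 3 * 6 combinations, times 2 patterns
    print(examples[0])
```

With three positive and three negative adjectives, the generated set has as many +1 as −1 labels, as in the datasets described below.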
For instance, the words nice, clean, cute, or delightful are associated to a label +1, while the words mean, dirty, and nasty are associated to a label −1.

We now explain how the test sets are generated. We generate four test sets in order to assess the robustness of the model to various out-of-distribution changes. The baseline test set uses the same sentence patterns and the same sentiment adjectives as in the training set, but other words in the example (e.g., animals, adverbs) are different. In particular, a given sentence cannot appear both in the train set and in the test set. Then, we generate another test set by using sentiment adjectives that are not present in the training set. We emphasize that the sentiment adjective fully determines the label, so using unseen adjectives at test time makes the task significantly harder. The third test set uses the same adjectives as in the train set, but another sentence pattern, namely

Hello, how are you? Good evening, [PRONOUN] [COLOR ADJ] [ANIMAL] is [ADV] [SENTIMENT ADJ].

Finally, the fourth test set combines a different sentence pattern and unseen adjectives. The size of the datasets is given in Table 1 below. All datasets have the same number of +1 and −1 labels.

| Name | Number of examples |
|---|---|
| Train set | 15552 |
| Test set | 4608 |
| Test w. OOD tokens | 3072 |
| Test w. OOD structure | 144 |
| Test w. OOD structure+tokens | 96 |

Table 1: Size of the generated datasets.

Model. We recall that there exist several families of Transformer architectures, which are not all best suited for sequence classification. An appropriate family is called encoder-only Transformer, and a foremost example is BERT (Devlin et al., 2019). We refer to Phuong & Hutter (2022) for an introductory discussion of Transformer architectures and associated algorithms. Here, we use a pretrained BERT model from the Hugging Face Transformers library (Wolf et al., 2020), with the default configuration, namely bert-base-uncased.
The model has 110M parameters and 12 layers, the tokens have dimension $d = 768$, and each attention layer has 12 heads. It was pretrained by masked language modeling, namely some tokens in the input are hidden, and the model learns to predict the missing tokens. We refer to Devlin et al. (2019) for details on the architecture and pretraining procedure. We do not perform any fine-tuning on the model.

Experiment design. Our experiment consists in performing logistic regression on embeddings of [CLS] tokens in the hidden layers of the pretrained BERT model, where we recall that the [CLS] token is a special token added to the beginning of each input sequence. This is a particular case of the so-called linear probing, which is a common technique in the field of LLM interpretability. More precisely, let $\ell$ denote a layer index between 0 and 12, where the index 0 corresponds to the input to the model (after tokenization and embedding in $\mathbb R^d$). Then, for each value of $\ell \in \{0, \dots, 12\}$, we train a logistic regression classifier, where, for each example, the input to the classifier is the embedding of the [CLS] token at layer $\ell$ (that is, a $d$-dimensional vector), and the label is simply the label of the sentence as described above.

Results. For $\ell = 0$ (blue bar in Figure 1b), the embedding of [CLS] is a fixed vector that does not depend on the rest of the sequence, so the classifier has a pure-chance accuracy of 50%. However, as soon as $\ell > 0$, thanks to the attention mechanism, the [CLS] token contains information about the sequence. We report in Figure 1b the average accuracy over $\ell \in \{1, \dots, 12\}$ for the train set (in orange) and the test sets (in green). We observe that the information contained in the [CLS] token is actually very rich, since logistic regression achieves a perfect accuracy of 100% on the train set. In other words, the data fed to the classifier is linearly separable.
We emphasize that the size of the train set is significantly larger than the ambient dimension $d$, so it is far from trivial that this procedure would yield a linearly-separable dataset. Therefore, obtaining linearly-separable data demonstrates that the model constructs a linear representation of the input inside the [CLS] token. Moving on to the test sets, the accuracy on the baseline test set is very good (95%), which suggests some generalization abilities of the model. The accuracy on the out-of-distribution test sets degrades (between 64% and 75%), but remains largely superior to pure-chance performance. This suggests that the internal representation built by the Transformer model is to some extent universal, in the sense that it is robust to the specifics of the sentence structure and of the word choice.

E.2 EXPERIMENT OF SECTION 5 (GRADIENT DESCENT RECOVERS THE ORACLE PREDICTOR)

We begin by providing additional results before giving experimental details.

PGD with an initialization on the sphere and constant inverse temperature schedule. As emphasized in Section 5, the dynamics of PGD with a general initialization on $(S^{d-1})^2$ depend on the choice of the inverse temperature schedule $\lambda_t$. The experiment presented in the main text in Figure 4a is for a decreasing schedule $\lambda_t = 1/(1 + 10^{-4}t)$. We report in Figure 7 results when taking a constant inverse temperature. We observe distinct patterns depending on the value of this parameter. With a large inverse temperature (Figure 7a), we observe that the dynamics in $(\kappa,\nu)$ always escape the neighborhood of 0. Furthermore, the direction $v^\star$ is almost perfectly recovered, i.e., $\nu \approx 1$. However, $k^\star$ is only partially recovered: the dynamics stabilize around $\kappa \approx 0.3$. Moreover, the excess risk plateaus at a high value, while the dynamics stay far away from the manifold $\mathcal M$. In the case of a smaller inverse temperature (Figure 7b), the situation is different.
We observe that some initializations lead to a convergence to the point $(\kappa,\nu) = (0,0)$, in which case the dynamics stay far from the manifold $\mathcal M$. In other words, there is no recovery of $k^\star$ and $v^\star$. Other initializations lead to perfect recovery of $k^\star$ and $v^\star$. In all cases, the final excess risk is low. Theoretical study of these observations is left for future work.

Implementation details. The implementation of the PGD algorithm (10) requires computing the gradient of the risk. To this aim, we use the formula for the risk given by Proposition 15. Note that all quantities appearing in this expression have explicit derivatives. The only quantity for which this is not directly clear is the function $\zeta$, which needs to be differentiated with respect to its first variable to compute the derivative of the risk with respect to $\kappa$. However, recall that $\zeta(t,\gamma^2) := \mathbb E[\mathrm{erf}^2(t + \gamma G)]$, $G \sim \mathcal N(0,1)$. Then, by Lemma 18(vi),
$$\partial_t\zeta(t,\gamma^2) = 2\,\mathbb E\big[(\mathrm{erf}\,\mathrm{erf}')(t + \gamma G)\big] = \frac{2}{\sqrt{1+2\gamma^2}}\,\mathrm{erf}'\Big(\frac{t}{\sqrt{1+2\gamma^2}}\Big)\,\mathrm{erf}\Big(\frac{t}{\sqrt{(1+4\gamma^2)(1+2\gamma^2)}}\Big).$$
Evaluating $\zeta$ itself (and not its derivative) is not required to simulate the dynamics, but is useful for reporting the value of the risk. For this, we also use the formula above, and use numerical quadrature to compute the value of
$$\zeta(t,\gamma^2) = 1 + \int_{-\infty}^t\partial_s\zeta(s,\gamma^2)\,ds,$$
which holds since $\zeta(s,\gamma^2) \to 1$ as $s \to -\infty$.

[Figure 7: panels (a) for $\lambda_t = 0.9$ and (b) for $\lambda_t = 0.1$; each panel shows, versus the step, the excess risk, the alignment with the oracle parameters, the trajectories of $(\kappa,\nu)$, and the distance to the manifold $\mathcal M$.]

Figure 7: Dynamics of PGD from a random initialization on $(S^{d-1})^2$, for two iteration-independent values of $\lambda_t$: (a) $\lambda_t = 0.9$, (b) $\lambda_t = 0.1$. Left: Excess risk as a function of the number of steps. Middle left: Alignment $|\kappa| = |k^\top k^\star|$ and $|\nu| = |v^\top v^\star|$ with the oracle parameters. Middle right: Trajectories of $\kappa$ and $\nu$ in a few repetitions of the experiments. Each repetition corresponds to a color, and the end point of each trajectory is in blue. Right: Distance to the invariant manifold $\mathcal M$. In all plots except the middle right ones, the experiment is repeated 30 times with independent random initializations, and either 95% percentile intervals are plotted or all the curves are plotted. Parameters are $d = 400$, $L = 10$, and $\gamma = 1/\sqrt2$.

We report in the figures the value of the excess risk, i.e., the risk $R_\lambda(k,v) - \varepsilon^2$. To compute the distance to the manifold $\mathcal M$, recall that it is defined by
$$\mathcal M = \big\{(k,v) \in S^{d-1}\times S^{d-1} :\ k^\top v^\star = 0,\ v^\top k^\star = 0,\ k^\top v = 0\big\}.$$
For a point $(k,v) \in S^{d-1}\times S^{d-1}$, its distance to $\mathcal M$ is therefore computed as
$$d_{\mathcal M}((k,v)) = \sqrt{(k^\top v^\star)^2 + (v^\top k^\star)^2 + (k^\top v)^2}.$$

Parameter values. Table 2 summarizes the values of the parameters in our experiments.

| Name | Figure 4a | Figure 4b | Figure 5 | Figure 7a | Figure 7b |
|---|---|---|---|---|---|
| $d$ | 400 | 400 | 80 | 400 | 400 |
| $L$ | 10 | 10 | 10 | 10 | 10 |
| $\gamma$ | $1/\sqrt2$ | $1/\sqrt2$ | $1/\sqrt2$ | $1/\sqrt2$ | $1/\sqrt2$ |
| $\lambda_t$ | $1/(1+10^{-4}t)$ | 0.1 | $2/(1+10^{-4}t)$ | 0.9 | 0.1 |
| $\alpha$ | $4\cdot10^{-3}$ | $4\cdot10^{-3}$ | $10^{-3}$ | $10^{-3}$ | $4\cdot10^{-3}$ |
| Number of steps | 120k | 20k | 200k | 120k | 20k |
| N. of repetitions | 30 | 30 | 30 | 30 | 30 |
| Batch size | - | - | 5 | - | - |
| $\varepsilon$ | 0 | 0 | 0.1 | 0 | 0 |

Table 2: Parameter values for the experiments on recovery of the oracle predictor by gradient descent.

E.3 ADDITIONAL EXPERIMENTS

Transformer layer. The most general formulation of the Transformer layer we consider writes, for $X \in \mathbb R^{L\times d}$,
$$\tilde X = \mathrm{concat}(r, X), \qquad \hat X = \tilde X + \sum_{h=1}^H\mathrm{softmax}\Big(\frac{1}{\sqrt p}\underbrace{\mathrm{LN}(\tilde X)Q_h}_{(L+1)\times p}\underbrace{K_h^\top\mathrm{LN}(\tilde X)^\top}_{p\times(L+1)}\Big)\underbrace{\mathrm{LN}(\tilde X)V_h}_{(L+1)\times p}\underbrace{O_h^\top}_{p\times d},$$
$$T(X) = \hat X + \mathrm{ReLU}\big(\hat XW_1 + \mathbf 1b_1^\top\big)W_2 + \mathbf 1b_2^\top, \qquad (29)$$
where: $\mathrm{concat}(r, X) \in \mathbb R^{(L+1)\times d}$ adds a new token at the beginning of the sequence by concatenating $r \in \mathbb R^d$ to $X \in \mathbb R^{L\times d}$, this token corresponding to the [CLS] or register token (see Section 3 for discussion and references), and, in all our experiments, $r \in \mathbb R^d$ is a vector with i.i.d. Gaussian entries of variance $1/d$, which is not trained; LN denotes layer normalization, softmax denotes row-wise softmax, and $\mathbf 1 \in \mathbb R^{L+1}$ is the vector filled with ones; the parameters are $Q_h, K_h, V_h, O_h \in \mathbb R^{d\times p}$, $W_1 \in \mathbb R^{d\times m}$, $b_1 \in \mathbb R^m$, $W_2 \in \mathbb R^{m\times d}$, and $b_2 \in \mathbb R^d$.
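A direct NumPy transcription of Eq. (29) can help pin down the shapes. This is a minimal sketch, not the code used for the experiments. It makes three assumptions: layer normalization has no learnable scale or shift, the softmax scaling is $1/\sqrt p$, and the heads are passed as Python lists. The initialization in the usage block mirrors the description in this appendix (register token of variance $1/d$, zero query matrix, Gaussian entries of variance $2/(d_{\mathrm{in}} + d_{\mathrm{out}})$), with $O$ set to the identity since $p = d$.

```python
import numpy as np


def layer_norm(x, eps=1e-5):
    """Row-wise layer normalization without learnable affine parameters (assumption)."""
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)


def softmax_rows(z):
    """Numerically stable row-wise softmax."""
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def transformer_layer(X, r, Qs, Ks, Vs, Os, W1, b1, W2, b2):
    """Sketch of Eq. (29): X is (L, d), r is (d,), each Q/K/V/O is (d, p)."""
    Xt = np.vstack([r[None, :], X])                  # concat(r, X): (L+1, d)
    N = layer_norm(Xt)
    p = Qs[0].shape[1]
    Xh = Xt.copy()
    for Q, K, V, O in zip(Qs, Ks, Vs, Os):
        A = softmax_rows((N @ Q) @ (K.T @ N.T) / np.sqrt(p))   # (L+1, L+1)
        Xh = Xh + A @ (N @ V) @ O.T                            # (L+1, d)
    hidden = np.maximum(Xh @ W1 + b1[None, :], 0.0)            # ReLU, (L+1, m)
    return Xh + hidden @ W2 + b2[None, :]                      # (L+1, d)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    L, d, p, m, H = 10, 80, 80, 200, 1

    def glorot(a, b):
        """Gaussian entries of variance 2 / (a + b)."""
        return rng.standard_normal((a, b)) * np.sqrt(2.0 / (a + b))

    X = rng.standard_normal((L, d))
    r = rng.standard_normal(d) / np.sqrt(d)          # register token, variance 1/d
    Qs = [np.zeros((d, p)) for _ in range(H)]        # query matrix initialized at 0
    Ks = [glorot(d, p) for _ in range(H)]
    Vs = [glorot(d, p) for _ in range(H)]
    Os = [np.eye(d, p) for _ in range(H)]            # O = identity since p = d
    W1, W2 = glorot(d, m), glorot(m, d)
    b1, b2 = np.zeros(m), np.zeros(d)
    out = transformer_layer(X, r, Qs, Ks, Vs, Os, W1, b1, W2, b2)
    print(out.shape)  # (11, 80)
```

With $Q = 0$, every attention score is zero, so each row of the softmax is uniform and every token, including the register token, initially receives the average of the value vectors.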
Experiment with single-head Transformer layer on single-location regression. We first consider the case of single-head attention, where $H = 1$ and $p = d$. For ease of notation, we drop the subscripts $h$ in the parameters of the attention layer. We also set $O$ to be the identity matrix. We train the Transformer layer on the single-location regression task to check that our simplified model is a good description of the Transformer layer. First note that the output of the Transformer layer (29) is a matrix in $\mathbb{R}^{(L+1) \times d}$, while the target of single-location regression is a scalar. Thus, we consider only the first row $T(X)_1$ of $T(X)$, corresponding to the register token, and learn a linear projection of this row to $\mathbb{R}$. In other words, the Transformer layer should learn to store global information about the sequence in the register token, as described in Sections 2 and 3. Overall, letting $\theta \in \mathbb{R}^d$, our risk writes
$$R(Q, K, V, W_1, b_1, W_2, b_2, \theta) = \mathbb{E}\big[(Y - T(X)_1 \theta)^2\big],$$
where $(X, Y)$ is distributed according to the single-location task as described in Section 2. We train using single-pass stochastic gradient descent (meaning that fresh samples are used at each step), for 8,000 steps with a batch size of 128 and a learning rate of 0.01. The experiment is repeated 20 times with independent random initializations, and 95% percentile intervals are plotted (they are not visible when the variance is too small). Parameters $K$, $V$, $W_1$, $W_2$ are initialized with Gaussian entries of variance $2/(d_{\mathrm{in}} + d_{\mathrm{out}})$. The bias terms are initialized to 0, as is the query matrix $Q$, following a standard recommendation in the literature on signal propagation in Transformers (Yang et al., 2021; He et al., 2023; He & Hofmann, 2024). The output weights $\theta$ are initialized with Gaussian entries of variance $1/d^2$, following the mean-field regime (Chizat et al., 2019). Parameters are $L = 10$, $d = p = 80$, $m = 200$, $\varepsilon^2 = 0.01$, $\gamma^2 = 0.5$. Results are given in Figure 8.
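The initialization and objective described above can be summarized by the following minimal sketch. It is hypothetical (NumPy, with function names of our choosing), encodes only the variances and hyperparameters stated in this paragraph, and omits the layer itself and the SGD loop.

```python
import numpy as np

def glorot_normal(rng, d_in, d_out):
    """Gaussian matrix with i.i.d. entries of variance 2 / (d_in + d_out)."""
    return rng.normal(0.0, np.sqrt(2.0 / (d_in + d_out)), size=(d_in, d_out))

def init_single_head(d, m, seed=0):
    """Initial parameters for the single-head experiment (H = 1, p = d, O = I)."""
    rng = np.random.default_rng(seed)
    return {
        "Q": np.zeros((d, d)),            # zero query: uniform attention at initialization
        "K": glorot_normal(rng, d, d),
        "V": glorot_normal(rng, d, d),
        "O": np.eye(d),                   # fixed to the identity, not trained
        "W1": glorot_normal(rng, d, m),
        "b1": np.zeros(m),                # biases initialized at zero
        "W2": glorot_normal(rng, m, d),
        "b2": np.zeros(d),
        "theta": rng.normal(0.0, 1.0 / d, size=d),  # std 1/d, i.e. variance 1/d^2
    }

def empirical_risk(T1, Y, theta):
    """Batch estimate of E[(Y - T(X)_1 theta)^2]; T1 stacks register rows, shape (n, d)."""
    return float(np.mean((Y - T1 @ theta) ** 2))

# Hyperparameters quoted in the text for this experiment.
CONFIG = {"L": 10, "d": 80, "m": 200, "steps": 8000, "batch_size": 128,
          "lr": 0.01, "eps2": 0.01, "gamma2": 0.5}
params = init_single_head(CONFIG["d"], CONFIG["m"])
```

One consequence of initializing $Q$ at zero can be read directly from (29): all attention scores vanish, so every token, including the register token, initially attends uniformly to the $L + 1$ tokens, and selecting the relevant token has to be learned.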
[Figure 8, two panels plotted against the step (0 to 8,000): (a) excess risk; (b) alignment with oracle parameters, with curves "Alignment with $k^\star$" and "Alignment with $v^\star$".]

Figure 8: Training a full Transformer layer on single-location regression. The Transformer layer solves the task, and encodes the structure of the problem in its weights. (a) Excess test risk as a function of the number of steps. (b) Alignment between Transformer parameters and oracle parameters $k^\star$ and $v^\star$: we plot $|\hat k^\top k^\star|$ and $\hat v^\top v^\star$ as a function of the number of steps, where $\hat k$ is the first left singular vector of $K$, and $\hat v := V(I + W_1 W_2)\theta / \|V(I + W_1 W_2)\theta\|$.

We observe in Figure 8a that the Transformer layer is able to solve single-location regression. Furthermore, as shown by Figure 8b, it does so by encoding in its weights the underlying structure of the problem, namely the oracle parameters $k^\star$ and $v^\star$, as in our simplified model (see Section 5). More precisely, in the case of our model, we showed that the two parameters $(k, v) \in (\mathbb{R}^d)^2$ converge to $(k^\star, v^\star)$. To identify the counterparts of $k$ and $v$ in the more complex parametrization (29), we let $\hat k$ be the first left singular vector of $K$, and $\hat v = V(I + W_1 W_2)\theta / \|V(I + W_1 W_2)\theta\|$. We check numerically that the weight matrix $QK^\top$ is nearly rank-one after training (the ratio between its first and second singular values is of the order of $10^6$ at the end of training), which validates taking $\hat k$ as the first singular vector of $K$ in the present experiment. It also validates considering vector-valued parameters in our simplified model. The role of the vector $\hat k$ is to select the relevant token among all input tokens, while the vector $\hat v$ describes how successive transformations (the value matrix of the attention layer, the MLP with skip connection, and the final linear projection) map this token to the output of the model. We observe that these two vectors align perfectly with $k^\star$ and $v^\star$.
This confirms that our simplified model is a good description of how the Transformer layer solves single-location regression.

Multiple-location regression. A natural extension of single-location regression is when the output depends on $s > 1$ tokens instead of just one. This task, which we name multiple-location regression, can be written as
$$Y = \sum_{h=1}^{s} X_{J(h)}^\top v^\star_h + \xi, \qquad (30)$$
where $J(1), \dots, J(s)$ are latent discrete random variables on $\{1, \dots, L\}$, all different, and such that, conditionally on $J(1), \dots, J(s)$,
$$X_{J(h)} \sim \mathcal{N}\Big(\sqrt{\tfrac{d}{2}}\, k^\star_h, \gamma^2 I_d\Big) \ \text{for } h \in \{1, \dots, s\}, \qquad X_\ell \sim \mathcal{N}(0, I_d) \ \text{for } \ell \notin \{J(1), \dots, J(s)\}.$$

Experiment with simplified predictor on multiple-location regression. In accordance with the above, a natural extension of the model presented in the main text is the multi-head predictor
$$T_\lambda^{(k_1, v_1, \dots, k_s, v_s)}(X) = \sum_{h=1}^{s} \mathrm{erf}(\lambda X k_h)^\top X v_h. \qquad (31)$$
The hope is that each head $(k_h, v_h)$ aligns with one of the oracle directions $(k^\star_h, v^\star_h)$. As a first step towards investigating this question, we run stochastic PGD in a setup similar to the one presented in Figure 5. We take $s = 2$, and the pair $(J(1), J(2))$ is uniformly distributed among pairs of distinct indices in $\{1, \dots, L\}$. The directions $(k^\star_1, v^\star_1)$ and $(k^\star_2, v^\star_2)$ are sampled independently and uniformly on the sphere, such that $(k^\star_i)^\top v^\star_i = 0$. Parameter values are the same as in Figure 5, except that the number of steps is set to $10^5$, the number of repetitions is set to 20, and the inverse temperature $\lambda_t$ is constant after $2.5 \times 10^4$ steps. Results are given in Figure 9. We observe (Figure 9a) that our predictor is able to solve the task. However, the recovery of oracle parameters is only partial, as shown in Figures 9b and 9c: each head partially aligns with the oracle parameters, but the alignment is not perfect. In other words, the model is not able to fully separate the signals coming from the different $X_{J(h)}$.
This calls for additional research in understanding how attention heads differentiate from each other in order to attend to various signals, and why in our setup the heads are not well separated.

[Figure 9, three panels plotted against the step (0 to 100,000): (a) excess risk; (b) alignment with $k^\star_h$; (c) alignment with $v^\star_h$.]

Figure 9: Training the multi-head predictor (31) on the multiple-location regression task (30). (a) Excess test risk as a function of the number of steps. (b) Alignment $|k^\top k^\star_h|$ with the oracle parameters. (c) Alignment $|v^\top v^\star_h|$ with the oracle parameters. The predictor is able to reach a low-risk region. The recovery of oracle parameters by the predictor is partial. In the middle plot, for each repetition and each oracle parameter $k^\star_h$, we determine at the end of training which head among $k_1$ and $k_2$ is closer to $k^\star_h$, and report the alignment between $k^\star_h$ and that head along training. If the alignment were perfect, this quantity would be close to 1. The same holds for the right plot.

Experiment with multi-head Transformer layer on multiple-location regression. We train a multi-head Transformer layer on the multiple-location regression task (30), taking $H = s = 2$. The data are generated as in the previous experiment. Parameters are as in the experiment for the single-head Transformer, except for the dimension $p = d/H = 40$, the number of repetitions (set to 10), and the learning rate (set to 0.02). Mimicking the single-head experiment, we let $\hat k_h$ be the first left singular vector of $K_h$, and $\hat v_h = V_h O_h^\top (I + W_1 W_2)\theta / \|V_h O_h^\top (I + W_1 W_2)\theta\|$. We also check numerically that all weight matrices $Q_h K_h^\top$ are nearly rank-one after training. Results are reported in Figure 10.
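The data-generating process (30), the predictor (31), the extraction of $(\hat k_h, \hat v_h)$ from a trained head, and the head-matching alignment used in Figures 9 and 10 can be sketched as below. This is a hypothetical NumPy illustration rather than the authors' code, with several assumptions of ours: orthogonal oracle pairs are drawn with a Gram-Schmidt step (one way to obtain uniform directions with $(k^\star_h)^\top v^\star_h = 0$), "closer" is measured by absolute cosine similarity, and erf is applied elementwise through Python's `math.erf`.

```python
import math
import numpy as np

_erf = np.vectorize(math.erf)  # elementwise erf without a SciPy dependency

def sample_oracle(d, s, rng):
    """s pairs (k*_h, v*_h), each uniform on the unit sphere with (k*_h)^T v*_h = 0."""
    ks, vs = [], []
    for _ in range(s):
        k = rng.standard_normal(d)
        k /= np.linalg.norm(k)
        v = rng.standard_normal(d)
        v -= (v @ k) * k                 # Gram-Schmidt: make v orthogonal to k
        v /= np.linalg.norm(v)
        ks.append(k)
        vs.append(v)
    return np.array(ks), np.array(vs)

def sample_task(n, L, k_star, v_star, gamma, eps, rng):
    """n draws of (X, Y) from multiple-location regression (30)."""
    s, d = k_star.shape
    X = rng.standard_normal((n, L, d))                                      # X_l ~ N(0, I_d)
    J = np.array([rng.choice(L, size=s, replace=False) for _ in range(n)])  # distinct positions
    Y = eps * rng.standard_normal(n)                                        # noise xi
    rows = np.arange(n)
    for h in range(s):
        X[rows, J[:, h]] = np.sqrt(d / 2) * k_star[h] + gamma * rng.standard_normal((n, d))
        Y += X[rows, J[:, h]] @ v_star[h]
    return X, Y, J

def predictor(X, ks, vs, lam):
    """Multi-head predictor (31), batched over X of shape (n, L, d)."""
    return sum(np.sum(_erf(lam * (X @ k)) * (X @ v), axis=-1) for k, v in zip(ks, vs))

def extract_head(K, V, O, W1, W2, theta):
    """Directions (k_hat, v_hat) read off one Transformer head, as in the text."""
    k_hat = np.linalg.svd(K)[0][:, 0]    # first left singular vector of K
    v = V @ O.T @ (np.eye(len(theta)) + W1 @ W2) @ theta
    return k_hat, v / np.linalg.norm(v)

def matched_alignment(traj, oracle):
    """traj: (T, H, d) unit head directions along training; oracle: (s, d).

    For each oracle direction, pick the head with the largest |cosine| at the
    end of training and return that head's alignment along training, shape (T, s).
    """
    align = np.abs(np.einsum("thd,sd->tsh", traj, oracle))
    best = align[-1].argmax(axis=1)
    return align[:, np.arange(oracle.shape[0]), best]
```

For single-pass SGD, `sample_task` would simply be called with a fresh generator state at every step, so that no sample is reused.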
The conclusions are similar to the previous experiment: the excess risk is low at the end of training, but we observe partial recovery of the oracle parameters (although the recovery is somewhat better than with the simplified predictor, especially for $k^\star_h$). This suggests that our simplified predictor might be a good first testbed for understanding the training dynamics of multi-head Transformers on this task.

[Figure 10, three panels plotted against the step (0 to 10,000): (a) excess risk; (b) alignment with $k^\star_h$; (c) alignment with $v^\star_h$.]

Figure 10: Training the multi-head Transformer layer (29) on the multiple-location regression task (30). (a) Excess test risk as a function of the number of steps. (b) Alignment $|\hat k^\top k^\star_h|$ with the oracle parameters. (c) Alignment $\hat v^\top v^\star_h$ with the oracle parameters. The predictor is able to reach a low-risk region. The recovery of oracle parameters by the predictor is partial. For each $h \in \{1, 2\}$, we let $\hat k_h$ be the first left singular vector of $K_h$, and $\hat v_h = V_h O_h^\top (I + W_1 W_2)\theta / \|V_h O_h^\top (I + W_1 W_2)\theta\|$. In the middle plot, for each repetition and each oracle parameter $k^\star_h$, we determine at the end of training which head among $\hat k_1$ and $\hat k_2$ is closer to $k^\star_h$, and report the alignment between $k^\star_h$ and that head along training. If the alignment were perfect, this quantity would be close to 1. The same holds for the right plot.

F FURTHER DISCUSSION OF RELATED MODELS

We begin by discussing some related works on training dynamics of Transformers (Jelassi et al., 2022; Nichani et al., 2024; Wang et al., 2024), to illustrate the originality of our task and predictor. Jelassi et al. (2022) study how (vision) Transformers learn spatial patterns in the data by relying on positional encodings. This differs significantly from our task, which is invariant under token permutations. Further, in their model, the argument of the softmax (i.e., a matrix $A \in \mathbb{R}^{L \times L}$) is directly a parameter of the model.
This is a radically different structure from the usual attention, and from our setup, where the data appear in the nonlinearity $\sigma(X_\ell^\top k)$. Next, Nichani et al. (2024) explore a task involving a fixed latent causal graph over the positions of the tokens. Here again, positional encodings play a critical role in their analysis, whereas our task is invariant under permutations of the tokens. Moreover, in Nichani et al. (2024), the output is expressed as a function of the last token, with the previous tokens providing the necessary context for this computation. In our setup, however, the output depends on a token whose position varies and must be identified within the context. Closer to our approach is the recent paper by Wang et al. (2024), which also incorporates a notion of token-wise sparsity: the output is computed as the average of a small subset of tokens, where the subset is identified by comparing the positional encodings of each token with that of a reference token. We outline two key differences with our setting. First, we do not make use of a reference token, but instead learn the latent direction $k^\star$ to identify the informative token. Second, in our setting, the tokens also encode an output projection direction $v^\star$ on top of $k^\star$. In other words, our task involves learning a linear regression in addition to identifying the relevant token, which is not the case in Wang et al. (2024).

Besides, we note that our task shares similarities with multi-index models (McCullagh & Nelder, 1983) and mixtures of linear regressions (De Veaux, 1989). However, our task $(\mathrm{P}_{\mathrm{learn}})$ has a more structured nature, involving sequence-valued inputs and incorporating a single-location pattern.

Finally, one could imagine a multi-layer perceptron (MLP) designed specifically for single-location regression, where the weights have a diagonal structure with respect to the sequence index, namely
$$\mathrm{MLP}(X_1, \dots, X_L) = \sum_{\ell=1}^{L} W_2\, \sigma(W_1 X_\ell + b_1) + b_2.$$
In such a setup, the first layer could learn the projections along $k^\star$ and $v^\star$, while the subsequent layer could learn to map these projections to the output $Y$ (in a somewhat similar spirit to multi-index models). However, this architecture is far from resembling those used in practice. If we do not assume a diagonal structure and instead use traditional MLPs, the number of parameters must scale at least linearly with the sequence length, which is highly suboptimal and may lead to very slow training. This highlights the efficiency of attention layers, which perform single-location regression with a fixed number of learnable parameters, independent of the input length. We leave a rigorous study of the learning abilities of MLPs in single-location regression for future work.
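As a rough illustration of the parameter-count argument, the diagonal MLP above and a dense MLP acting on the flattened sequence can be compared with the hypothetical sketch below. The activation $\sigma$ is not specified in the text, so `tanh` is an arbitrary choice. The dense baseline, with one hidden layer of width $m$ acting on the flattened input in $\mathbb{R}^{Ld}$, is also our own assumption.

```python
import numpy as np

def diagonal_mlp(X, W1, b1, W2, b2, sigma=np.tanh):
    """sum_l W2 sigma(W1 X_l + b1) + b2, with weights shared across positions.

    X: (L, d); W1: (m, d); b1: (m,); W2: (m,); b2: scalar. Returns a scalar.
    """
    return float(np.sum(sigma(X @ W1.T + b1) @ W2) + b2)

def n_params_diagonal(d, m):
    # W1, b1, W2, b2: independent of the sequence length L
    return m * d + m + m + 1

def n_params_dense(L, d, m):
    # The first layer acts on the flattened sequence in R^{L d}.
    return m * L * d + m + m + 1

# Parameter counts for d = 400, m = 200 at two sequence lengths.
counts = {seq_len: (n_params_diagonal(400, 200), n_params_dense(seq_len, 400, 200))
          for seq_len in (10, 1000)}
```

The diagonal model is invariant under permutations of the tokens, and its size does not depend on $L$. By contrast, the first layer of the dense baseline alone has $mLd$ weights.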