# Max-Margin Token Selection in Attention Mechanism

Davoud Ataee Tarzanagh, University of Pennsylvania, tarzanaq@upenn.edu

Yingcong Li, Xuechen Zhang, University of California, Riverside, {yli692,xzhan394}@ucr.edu

Samet Oymak, University of Michigan; UC Riverside, oymak@umich.edu

Attention mechanism is a central component of the transformer architecture which led to the phenomenal success of large language models. However, the theoretical principles underlying the attention mechanism are poorly understood, especially its nonconvex optimization dynamics. In this work, we explore the seminal softmax-attention model $f(X) = \langle Xv, \mathrm{softmax}(XWp) \rangle$, where $X$ is the token sequence and $(v, W, p)$ are trainable parameters. We prove that running gradient descent on $p$, or equivalently $W$, converges in direction to a max-margin solution that separates locally-optimal tokens from non-optimal ones. This clearly formalizes attention as an optimal token selection mechanism. Remarkably, our results are applicable to general data and precisely characterize optimality of tokens in terms of the value embeddings $Xv$ and problem geometry. We also provide a broader regularization path analysis that establishes the margin maximizing nature of attention even for nonlinear prediction heads. When optimizing $v$ and $p$ simultaneously with logistic loss, we identify conditions under which the regularization paths directionally converge to their respective hard-margin SVM solutions where $v$ separates the input features based on their labels. Interestingly, the SVM formulation of $p$ is influenced by the support vector geometry of $v$. Finally, we verify our theoretical findings via numerical experiments and provide insights.

## 1 Introduction

Since its introduction in the seminal work [1], attention mechanism has played an influential role in advancing natural language processing, and more recently, large language models [2, 3, 4, 5].
Initially introduced for encoder-decoder RNN architectures, attention allows the decoder to focus on the most relevant parts of the input sequence, instead of relying solely on a fixed-length hidden state. Attention mechanism has taken center stage in transformers [6], where the self-attention layer, which calculates softmax similarities between input tokens, serves as the backbone of the architecture. Since their inception, transformers have revolutionized natural language processing, from models like BERT [7] to ChatGPT [8], and have also become the architecture of choice for foundation models [9] addressing diverse challenges in generative modeling [3, 10], computer vision [11, 12], and reinforcement learning [13, 14, 15].

The prominence of the attention mechanism motivates a fundamental theoretical understanding of its role in optimization and learning. While it is well-known that attention enables the model to focus on the relevant parts of the input sequence, the precise mechanism by which this is achieved is far from clear. To this end, we ask

> Q: What are the optimization dynamics and inductive biases of the attention mechanism?

We study this question using the fundamental attention model $f(X) = \langle Xv, \mathbb{S}(XWp) \rangle$. Here, $X$ is the sequence of input tokens, $v$ is the prediction head, $W$ is the trainable key-query weights, and $\mathbb{S}(\cdot)$ denotes the softmax nonlinearity. For transformers, $p$ corresponds to the [CLS] token or tunable prompt [16, 17, 18], whereas for RNN architectures [1], $p$ corresponds to the hidden state. Given training data $(Y_i, X_i)_{i=1}^n$ with labels $Y_i \in \{-1, 1\}$ and inputs $X_i \in \mathbb{R}^{T \times d}$, we consider the empirical risk minimization with a decreasing loss function $\ell(\cdot): \mathbb{R} \to \mathbb{R}$,

$$\mathcal{L}(v, p, W) = \frac{1}{n} \sum_{i=1}^n \ell\big(Y_i \cdot f(X_i)\big), \quad \text{where } f(X_i) = v^\top X_i^\top \mathbb{S}(X_i W p). \tag{1}$$

[Figure 1 panels: (a) Global convergence, with tokens (−0.1, 1) optimal (γ = 1), (1, 0) non-optimal (γ = 0.9), (0, 0) non-optimal (γ = 0.9); (b) Global and local optimal directions, with tokens (−0.1, 1) optimal (γ = 1), (1, 0) non-optimal (γ = 0.9), (0, 0) non-optimal (γ = 0); (c) Multiple inputs, with optimal tokens (γ = 1) and non-optimal tokens (γ = 0.5).]

Figure 1: The convergence behavior of gradient descent on the attention weights $p$ using the logistic loss in (ERM). The arrows represent trajectories from different initializations. The two dashed lines denote the globally- and locally-optimal max-margin directions (GMM, LMM). $\gamma$ denotes the score of a token per Definition 1. Discussion is provided under Theorems 2 and 3.

37th Conference on Neural Information Processing Systems (NeurIPS 2023).

At a high level, this work establishes fundamental equivalences between the optimization trajectories of (1) and hard-margin SVM problems. Our main contributions are as follows:

- **Optimization geometry of attention (Sec 2):** We first show that gradient iterations of $p$ and $W$ admit a one-to-one mapping; thus we focus on optimizing $p$ without loss of generality. In Theorem 3, we prove that, under proper initialization, gradient descent on $p$ converges in direction to a max-margin solution, namely (ATT-SVM), that separates locally-optimal tokens from non-optimal ones. We call these Locally-optimal Max-Margin (LMM) directions and show that they thoroughly characterize the viable convergence directions of attention when the norm of its weights grows to infinity. We also identify conditions under which the (algorithm-independent) regularization path and the gradient descent path converge to the Globally-optimal Max-Margin (GMM) direction in Theorems 1 and 2, respectively. A central feature of our results is precisely quantifying optimality in terms of token scores $\gamma_t = Y \cdot v^\top x_t$, where $x_t$ is the $t$-th token of the input sequence $X$. Locally-optimal tokens are those with higher scores than their nearest neighbors determined by the SVM solution. These are illustrated in Figure 1.
- **Optimize attention $p$ and prediction-head $v$ jointly (Sec 3):** We study the joint problem under the logistic loss function.
We use regularization path analysis, where (ERM) is solved under ridge constraints, and we study the solution trajectory as the constraints are relaxed. Since the problem is linear in $v$, if the attention features $x^{att}_i = X_i^\top \mathbb{S}(X_i W p)$ are separable based on their labels $Y_i$, $v$ would implement a max-margin classifier. Building on this, we prove that $p$ and $v$ converge to their respective max-margin solutions under proper geometric conditions (Theorem 5). Relaxing these conditions, we obtain a more general solution where margin constraints on $p$ are relaxed on the inputs whose attention features are not support vectors of $v$ (Theorem 6). Figure 3 illustrates these outcomes.

The next section introduces the preliminary concepts, Section 4 presents numerical experiments (the code for experiments can be found at https://github.com/ucr-optml/max_margin_attention), Section 5 discusses related literature, and Section 6 highlights limitations and future work.

### 1.1 Preliminaries

**Notations.** For any integer $N \ge 1$, let $[N] := \{1, \dots, N\}$. We use lower-case and upper-case bold letters (e.g., $a$ and $A$) to represent vectors and matrices, respectively. The entries of $a$ are denoted as $a_i$. We use $\sigma(A)$ to denote the maximum singular value of $A$. We denote the minimum of two numbers $a, b$ as $a \wedge b$, and the maximum as $a \vee b$. Big-O notation $O(\cdot)$ hides the universal constants. Throughout, we will use $\mathcal{L}(p)$ and $\mathcal{L}(v, p)$ to denote Objective (1) with fixed $(v, W)$ and fixed $W$, respectively.

**Optimization.** Given an objective function $\mathcal{L}: \mathbb{R}^d \to \mathbb{R}$ and an $\ell_2$-norm bound $R$, define the regularized solution as

$$\bar{p}(R) := \arg\min_{\|p\| \le R} \mathcal{L}(p). \tag{2}$$

The regularization path, that is, the evolution of $\bar{p}(R)$ as $R$ grows, is known to capture the spirit of gradient descent, as the ridge constraint $R$ provides a proxy for the number of gradient descent iterations. For instance, [19, 20, 21] study the implicit bias of logistic regression and rigorously connect the directional convergence of the regularization path (i.e., $\lim_{R \to \infty} \bar{p}(R)/R$) and gradient descent.
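A minimal sketch of the regularized solution (2), assuming the single-input toy instance behind Figure 1(a): keys $(-0.1, 1)$, $(1, 0)$, $(0, 0)$ with scores $1, 0.9, 0.9$, label $Y = 1$, and the logistic loss. It approximates $\bar{p}(R)$ by a brute-force angle scan and reports the alignment of $\bar{p}(R)/R$ with the max-margin direction $(-0.1, 1)/1.01$, which we derived by hand for this instance (it is the GMM direction $p^{mm\star}$ of Definition 1). Scanning only the sphere $\|p\| = R$ is an assumption that holds for this instance alone: here $\mathcal{L}(p)$ is an increasing function of the convex function $\log(1 + e^{-p^\top a_1} + e^{-p^\top a_2})$, where $a_1, a_2$ are the differences between the optimal key and the two other keys, and that function has no stationary point. All function names are hypothetical.

```python
import math

# Hypothetical single-input instance mirroring the Figure 1(a) setup (Y = 1):
# key positions k_t (first two coordinates) and scores gamma_t = v^T x_t.
KEYS = [(-0.1, 1.0), (1.0, 0.0), (0.0, 0.0)]  # token 0 is the optimal token
SCORES = [1.0, 0.9, 0.9]


def risk(p):
    """L(p) = log(1 + exp(-f)) with f = sum_t S(Kp)_t * gamma_t (logistic loss)."""
    z = [k[0] * p[0] + k[1] * p[1] for k in KEYS]
    m = max(z)  # subtract the max logit for a stable softmax
    e = [math.exp(zi - m) for zi in z]
    f = sum(ei * gi for ei, gi in zip(e, SCORES)) / sum(e)
    return math.log1p(math.exp(-f))


def reg_path_point(R, grid=3600):
    """Approximate p_bar(R) in (2) by scanning angles on the sphere ||p|| = R.

    Scanning only the boundary is valid for this toy instance, whose minimizer
    over the ball lies on the boundary; it is not a general-purpose solver.
    """
    best_val, best_p = float("inf"), None
    for j in range(grid):
        phi = 2.0 * math.pi * j / grid
        p = (R * math.cos(phi), R * math.sin(phi))
        val = risk(p)
        if val < best_val:
            best_val, best_p = val, p
    return best_p


def alignment(p, q):
    """Cosine similarity between two 2D vectors."""
    return (p[0] * q[0] + p[1] * q[1]) / (math.hypot(*p) * math.hypot(*q))


P_MM = (-0.1 / 1.01, 1.0 / 1.01)  # GMM direction of this instance, solved by hand

for R in (1.0, 5.0, 20.0):
    print(R, alignment(reg_path_point(R), P_MM))
```

In this sketch the printed alignment grows with $R$, which is the behavior Theorem 1 describes; the approach to 1 is slow, so no finite $R$ should be read as exact convergence.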
For gradient descent, we assume the objective $\mathcal{L}(p)$ is smooth and describe the gradient descent process as

$$p(t+1) = p(t) - \eta(t) \nabla \mathcal{L}(p(t)), \tag{3}$$

where $\eta(t)$ is the stepsize at time $t$ and $\nabla \mathcal{L}(p(t))$ is the gradient of $\mathcal{L}$ at $p(t)$.

**Attention in Transformers.** Next, we discuss the connection between our model and the attention mechanism used in transformers. Our exposition borrows from [17], where the authors analyze the same attention model using gradient-based techniques on specific contextual datasets. Self-attention is the core building block of transformers [6]. Given an input consisting of $T$ tokens $X = [x_1, \dots, x_T]^\top \in \mathbb{R}^{T \times d}$, self-attention with key-query matrix $W \in \mathbb{R}^{d \times d}$ and value matrix $V \in \mathbb{R}^{d \times v}$ is defined as follows:

$$f_{sa}(X) = \mathbb{S}(XWX^\top) X V. \tag{4}$$

Here, $\mathbb{S}(\cdot)$ is the softmax nonlinearity that applies row-wise on the similarity matrix $XWX^\top$.

**Tunable tokens: [CLS] and prompt-tuning.** In practice, we append additional tokens to the raw input features $X$: for instance, a [CLS] token is used for classification purposes [7] and prompt vectors can be appended for adapting a pretrained model to new tasks [16, 18]. Let $p \in \mathbb{R}^d$ be the tunable token ([CLS] or prompt vector) and concatenate it to $X$ to obtain $X_p := [p \;\; X^\top]^\top \in \mathbb{R}^{(T+1) \times d}$. Consider the cross-attention features obtained from $X_p$ and $X$, given by

$$\begin{bmatrix} f_{cls}(X)^\top \\ f_{sa}(X) \end{bmatrix} = \mathbb{S}(X_p W X^\top) X V = \begin{bmatrix} \mathbb{S}(p^\top W X^\top) \\ \mathbb{S}(X W X^\top) \end{bmatrix} X V.$$

The beauty of cross-attention is that it isolates the contribution of $p$ under the upper term $f_{cls}(X) = V^\top X^\top \mathbb{S}(X W^\top p) \in \mathbb{R}^v$. In this work, we use the value weights for classification; thus we set the value dimension to $v = 1$ and denote the single value column by $v = V \in \mathbb{R}^d$. This brings us to our attention model of interest:

$$f(X) = v^\top X^\top \mathbb{S}(Kp), \quad \text{where } K = X W^\top. \tag{5}$$

Here, $(v, W, p)$ are the tunable model parameters and $K$ is the key embeddings. Note that $W$ and $p$ play the same role within the softmax; thus, it is intuitive that they exhibit similar optimization dynamics.
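The claim that $W$ and $p$ play the same role can be checked numerically. The following minimal pure-Python sketch evaluates model (5) and confirms that with $W = u p^\top / \|u\|^2$ the logits $X W^\top u$ coincide with $X p$, so the model with keys $K = XW^\top$ queried by $u$ reproduces the output of the identity-key model with attention weights $p$. The helper names (`keys_from_W`, `W_from_p`, and so on) are hypothetical, and the small values of $X$, $v$, $p$, $u$ are arbitrary.

```python
import math


def softmax(z):
    """Numerically stable softmax of a list of logits."""
    m = max(z)
    e = [math.exp(zi - m) for zi in z]
    total = sum(e)
    return [ei / total for ei in e]


def dot(a, b):
    return sum(ai * bi for ai, bi in zip(a, b))


def attention_output(X, v, K, p):
    """Model (5): f(X) = v^T X^T S(K p), with X and K given as lists of token rows."""
    s = softmax([dot(k, p) for k in K])
    return sum(st * dot(v, x) for st, x in zip(s, X))


def keys_from_W(X, W):
    """K = X W^T: row t of K is W x_t."""
    return [[dot(w_row, x) for w_row in W] for x in X]


def W_from_p(u, p):
    """W = u p^T / ||u||^2, chosen so that W^T u = p."""
    nu2 = dot(u, u)
    return [[ui * pj / nu2 for pj in p] for ui in u]


X = [[1.0, 0.5, -0.2], [0.3, -1.0, 0.8], [-0.5, 0.2, 0.1]]  # T = d = 3
v = [0.5, -1.0, 2.0]
p = [0.7, -0.3, 1.2]
u = [1.0, 2.0, -0.5]

f_p = attention_output(X, v, X, p)                               # identity keys, weights p
f_W = attention_output(X, v, keys_from_W(X, W_from_p(u, p)), u)  # keys X W^T, query u
print(f_p, f_W)  # equal up to floating-point rounding
```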
Confirming this intuition, the next lemma shows that the gradient iterations of $p$ (after setting $W \leftarrow$ Identity) and of $W$ admit a one-to-one mapping.

**Lemma 1** Fix $u \in \mathbb{R}^d \setminus \{0\}$. Let $\psi: \mathbb{R}^d \to \mathbb{R}$ and $\ell: \mathbb{R} \to \mathbb{R}$ be differentiable functions. On the same training data $(Y_i, X_i)_{i=1}^n$, define $\mathcal{L}(p) := \frac{1}{n} \sum_{i=1}^n \ell\big(Y_i \cdot \psi(X_i^\top \mathbb{S}(X_i p))\big)$ and $\mathcal{L}(W) := \frac{1}{n} \sum_{i=1}^n \ell\big(Y_i \cdot \psi(X_i^\top \mathbb{S}(X_i W^\top u))\big)$. Consider the gradient descent iterations on $p$ and $W$ with initial values $p(0)$ and $W(0) = u\, p(0)^\top / \|u\|^2$ and stepsizes $\eta$ and $\eta / \|u\|^2$, respectively:

$$p(t+1) = p(t) - \eta \nabla \mathcal{L}(p(t)), \qquad W(t+1) = W(t) - \frac{\eta}{\|u\|^2} \nabla \mathcal{L}(W(t)).$$

We have that $W(t) = u\, p(t)^\top / \|u\|^2$ for all $t \ge 0$.

This lemma directly characterizes the optimization dynamics of $W$ through the dynamics of $p$, allowing us to reconstruct $W$ from $p$ using their gradient iterations. Therefore, we will fix $W$ and concentrate on optimizing $p$ in Section 2 and on the joint optimization of $(v, p)$ in Section 3.

**Problem definition:** Throughout, $(Y_i, X_i)_{i=1}^n$ denotes the training dataset, where $Y_i \in \{-1, 1\}$ and $X_i \in \mathbb{R}^{T \times d}$. We denote the key embeddings of $X_i$ via $K_i = X_i W^\top$ and explore the training risk

$$\mathcal{L}(v, p) = \frac{1}{n} \sum_{i=1}^n \ell\big(Y_i \cdot v^\top X_i^\top \mathbb{S}(K_i p)\big). \tag{ERM}$$

Importantly, our results apply to general tuples $(Y_i, X_i, K_i)$ and do not assume that $(X_i, K_i)$ are tied via $W$. Finally, the $t$-th tokens of $X_i$ and $K_i$ are denoted by $x_{it}, k_{it} \in \mathbb{R}^d$, respectively, for $t \in [T]$.

The highly nonlinear and nonconvex nature of the softmax operation makes the training problem in (ERM) a challenging nonconvex optimization problem for $p$, even with a fixed $v$. In the next section, we introduce a set of assumptions to demonstrate the global and local convergence of gradient descent for margin maximization in the attention mechanism.

## 2 Global and Local Margin Maximization with Attention

In this section, we present the main results of this paper (Theorems 2 and 3) by examining the implicit bias of gradient descent on learning $p \in \mathbb{R}^d$ given a fixed choice of $v \in \mathbb{R}^d$. Notably, our results apply to general decreasing loss functions without requiring convexity.
This generality is attributed to margin maximization arising from the exponentially-tailed nature of softmax within attention, rather than from $\ell$. We maintain the following assumption on the loss function throughout this section.

**Assumption A (Well-behaved Loss)** Over any bounded interval: (1) $\ell: \mathbb{R} \to \mathbb{R}$ is strictly decreasing; (2) the derivative $\ell'$ is $M_0$-Lipschitz continuous and $|\ell'(u)| \le M_1$.

Assumption A includes many common loss functions, including the logistic loss $\ell(u) = \log(1 + e^{-u})$, the exponential loss $\ell(u) = e^{-u}$, and the correlation loss $\ell(u) = -u$. Assumption A implies that $\mathcal{L}(p)$ is $L_p$-smooth (see Lemma 6 in Supplementary), where

$$L_p := \frac{1}{n} \sum_{i=1}^n \Big( M_0 \|v\|^2 \|W\|^2 \|X_i\|^4 + 3 M_1 \|v\| \|W\|^2 \|X_i\|^3 \Big). \tag{6}$$

We now introduce a convex hard-margin SVM problem that separates one token of the input sequence from the rest, jointly solved over all inputs. We will show that this problem captures the optimization properties of softmax-attention. Fix indices $\alpha = (\alpha_i)_{i=1}^n$ and consider

$$p^{mm}(\alpha) = \arg\min_p \|p\| \quad \text{subject to} \quad \min_{t \ne \alpha_i} p^\top (k_{i\alpha_i} - k_{it}) \ge 1 \quad \text{for all } 1 \le i \le n. \tag{ATT-SVM}$$

Note that existence of $p^{mm}(\alpha)$ implies the separability of tokens $\alpha$ from the others. Specifically, choosing direction $p^{mm}(\alpha)$ will exactly select tokens $(x_{i\alpha_i})_{i=1}^n$ at the attention output for each input sequence, that is, $\lim_{R \to \infty} X_i^\top \mathbb{S}(R \cdot K_i p^{mm}(\alpha)) = x_{i\alpha_i}$. We are now ready to introduce our main results that characterize the global and local convergence of the attention weights $p$ via (ATT-SVM).

### 2.1 Global convergence of the attention weights $p$

We first identify the conditions that guarantee the global convergence of gradient descent for $p$. The intuition is that, in order for attention to exhibit implicit bias, the softmax nonlinearity should be forced to select the optimal token within each input sequence. Fortunately, the optimal tokens that achieve the smallest training objective under a decreasing loss function $\ell(\cdot)$ have a clear definition.

**Definition 1 (Token Scores, Optimality & GMM)** The score of token $x_{it}$ of input $X_i$ is defined as $\gamma_{it} := Y_i \cdot v^\top x_{it}$.
The optimal tokens for input $X_i$ are those tokens with the highest scores, given by $\mathrm{opt}_i \in \arg\max_{t \in [T]} \gamma_{it}$. The globally-optimal max-margin (GMM) direction is defined as the solution of (ATT-SVM) with optimal indices $(\mathrm{opt}_i)_{i=1}^n$ and is denoted by $p^{mm\star}$.

It is worth noting that the score definition simply uses the value embeddings $v^\top x_{it}$ of the tokens. Note that multiple tokens within an input might attain the same score; thus $\mathrm{opt}_i$ or $p^{mm\star}$ may not be unique. The theorem below provides our regularization path guarantee on the global convergence of attention.

**Theorem 1 (Regularization Path)** Suppose Assumption A on the loss function holds, and for all $i \in [n]$ and $t \ne \mathrm{opt}_i$, the scores obey $\gamma_{it} < \gamma_{i\mathrm{opt}_i}$. Then, the regularization path $\bar{p}(R) = \arg\min_{\|p\| \le R} \mathcal{L}(p)$ converges to the GMM direction, i.e., $\lim_{R \to \infty} \bar{p}(R)/R = p^{mm\star}/\|p^{mm\star}\|$.

Theorem 1 shows that as the norm bound $R$ grows, so that the constrained problem approaches the ridgeless problem $\min_p \mathcal{L}(p)$, the direction of $\bar{p}(R)$ aligns more closely with the max-margin solution $p^{mm\star}$. Since this theorem allows for arbitrary token scores, it demonstrates that max-margin token separation is an essential feature of the attention mechanism. In fact, it is a corollary of Theorem 8, which applies to the generalized model $f(X) = \psi(X^\top \mathbb{S}(X W^\top p))$ and accommodates multiple optimal tokens per input. However, while regularization path analysis captures the global behavior, gradient descent lacks general global convergence guarantees. In Section 2.2, we show that due to the nonconvex landscape and softmax nonlinearity, gradient descent often converges to local optima. We first establish that when (ERM) is trained with gradient descent, the norm of the parameters will diverge. For the restrictive setting of $n = 1$, gradient descent also exhibits a global convergence guarantee.

**Assumption B** For all $i \in [n]$ and $t, \tau \ne \mathrm{opt}_i$, the scores per Definition 1 obey $\gamma_{it} = \gamma_{i\tau} < \gamma_{i\mathrm{opt}_i}$.
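To make Definition 1 and Assumption B concrete, here is a small pure-Python sketch that computes the scores $\gamma_{it}$, picks $\mathrm{opt}_i$, and checks Assumption B. The function names are hypothetical, and the token coordinates mirror the Figure 1 setup described under Theorem 2, where the third coordinate stores the score because $v = [0, 0, 1]$.

```python
def token_scores(Y, X, v):
    """gamma[i][t] = Y_i * v^T x_it (Definition 1)."""
    return [[Yi * sum(vj * xj for vj, xj in zip(v, x)) for x in Xi] for Yi, Xi in zip(Y, X)]


def optimal_indices(gamma):
    """One maximizer opt_i of gamma_it per input (first index on ties)."""
    return [max(range(len(g)), key=g.__getitem__) for g in gamma]


def satisfies_assumption_b(gamma, tol=1e-12):
    """Assumption B: non-optimal scores of each input are equal and strictly below the optimum."""
    for g, o in zip(gamma, optimal_indices(gamma)):
        rest = [g[t] for t in range(len(g)) if t != o]
        if any(abs(r - rest[0]) > tol for r in rest):
            return False  # non-optimal scores differ
        if any(r >= g[o] for r in rest):
            return False  # the optimum is not strict
    return True


# Tokens x = [x1, x2, x3]: (x1, x2) are key positions, x3 stores the score since v = [0, 0, 1].
v = [0.0, 0.0, 1.0]
fig1a = [[-0.1, 1.0, 1.0], [1.0, 0.0, 0.9], [0.0, 0.0, 0.9]]
fig1b = [[-0.1, 1.0, 1.0], [1.0, 0.0, 0.9], [0.0, 0.0, 0.0]]
for name, Xi in (("Figure 1(a)", fig1a), ("Figure 1(b)", fig1b)):
    gamma = token_scores([1], [Xi], v)
    print(name, gamma, optimal_indices(gamma), satisfies_assumption_b(gamma))
```

The Figure 1(a) tokens satisfy Assumption B, while the Figure 1(b) tokens do not (their non-optimal scores differ), which is why Figure 1(b) is discussed under the local result, Theorem 3.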
**Theorem 2 (Global Convergence of Gradient Descent)** Suppose Assumption A on the loss function $\ell$ and Assumption B on the token scores hold. Then, the gradient descent iterates $p(t+1) = p(t) - \eta \nabla \mathcal{L}(p(t))$ on (ERM), with stepsize $\eta \le 1/L_p$ and any starting point $p(0)$, satisfy $\lim_{t \to \infty} \|p(t)\| = \infty$. If $n = 1$, we also have $\lim_{t \to \infty} p(t)/\|p(t)\| = p^{mm\star}/\|p^{mm\star}\|$.

Theorem 2 shows that gradient descent will diverge in norm, and when $n = 1$, the normalized predictor $p(t)/\|p(t)\|$ converges towards $p^{mm\star}$, the separator of the globally optimal token. While $n = 1$ is a stringent condition, this requirement is in fact tight, as discussed in Appendix E.

To illustrate this theorem, we have conducted synthetic experiments. Let us first explain the setup used in Figure 1. We set $d = 3$ as the dimension, with each token having three entries $x = [x_1, x_2, x_3]$. We reserve the first two coordinates as key embeddings $k = [x_1, x_2, 0]$ by setting $W = \mathrm{diag}([1, 1, 0])$; this is what we display in our figures as token positions. Finally, to assign scores to the tokens, we use the last coordinate by setting $v = [0, 0, 1]$. This way, the score becomes $Y \cdot v^\top x = Y \cdot x_3$, allowing us to assign any score regardless of the key embedding. In Figure 1(a), the gray paths represent gradient descent trajectories from different initializations. The points $(0, 0)$ and $(1, 0)$ correspond to non-optimal tokens, while the point $(-0.1, 1)$ represents the optimal token. Notably, gradient descent iterates with various starting points converge towards the direction of the max-margin solution $p^{mm\star}$ (depicted by the dashed line). Moreover, as the iteration count $t$ increases, the inner product $\langle p(t)/\|p(t)\|, p^{mm\star}/\|p^{mm\star}\| \rangle$ consistently increases. Figure 1(c) also depicts the directional convergence of gradient descent from various initializations on multiple inputs, with the gray dotted line representing the separating hyperplane. These emphasize the gradual alignment between the evolving predictor and the max-margin solution throughout the optimization.
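The Figure 1(a) experiment can be approximated with the following pure-Python sketch. It assumes the same toy instance ($n = 1$, keys $(-0.1, 1)$, $(1, 0)$, $(0, 0)$, scores $1, 0.9, 0.9$, logistic loss) and, because $W = \mathrm{diag}([1, 1, 0])$, works directly in the two key coordinates. Note one deliberate substitution: Theorem 2 concerns plain gradient descent with $\eta \le 1/L_p$, whereas the sketch uses normalized gradient descent with $\eta = 0.1$ (the variant used in the synthetic experiments of Section 4), because plain gradient descent slows down as the gradients vanish. The GMM direction is derived by hand for this instance, and all helper names are hypothetical.

```python
import math

# Figure 1(a) toy instance (n = 1, Y = 1), restricted to the two key coordinates.
KEYS = [(-0.1, 1.0), (1.0, 0.0), (0.0, 0.0)]  # token 0 is the optimal token
SCORES = [1.0, 0.9, 0.9]                       # gamma_t = v^T x_t; satisfies Assumption B


def softmax(z):
    m = max(z)
    e = [math.exp(zi - m) for zi in z]
    total = sum(e)
    return [ei / total for ei in e]


def loss_grad(p):
    """Gradient of L(p) = ell(sum_t S(Kp)_t * gamma_t) for ell(u) = log(1 + e^{-u})."""
    s = softmax([k[0] * p[0] + k[1] * p[1] for k in KEYS])
    u = sum(st * gt for st, gt in zip(s, SCORES))
    dl = -1.0 / (1.0 + math.exp(u))  # ell'(u) for the logistic loss
    # K^T (diag(s) - s s^T) gamma written as a sum over token pairs; this form avoids
    # the cancellation that otherwise appears once one softmax entry is close to 1.
    g = [0.0, 0.0]
    for a in range(len(KEYS)):
        for b in range(a + 1, len(KEYS)):
            w = dl * s[a] * s[b] * (SCORES[a] - SCORES[b])
            g[0] += w * (KEYS[a][0] - KEYS[b][0])
            g[1] += w * (KEYS[a][1] - KEYS[b][1])
    return g


def normalized_gd(p, eta=0.1, steps=500):
    """p(t+1) = p(t) - eta * grad / ||grad||, as in the Section 4 synthetic experiments."""
    for _ in range(steps):
        g = loss_grad(p)
        norm = math.hypot(g[0], g[1])
        p = [p[0] - eta * g[0] / norm, p[1] - eta * g[1] / norm]
    return p


def cosine(p, q):
    return (p[0] * q[0] + p[1] * q[1]) / (math.hypot(*p) * math.hypot(*q))


# GMM direction for this instance, derived by hand: only the constraint against the
# (0, 0) token is active, so p_mm = d / ||d||^2 with d = k_opt - k_(0,0).
diff_00 = (KEYS[0][0] - KEYS[2][0], KEYS[0][1] - KEYS[2][1])
diff_10 = (KEYS[0][0] - KEYS[1][0], KEYS[0][1] - KEYS[1][1])
sq = diff_00[0] ** 2 + diff_00[1] ** 2
p_mm = (diff_00[0] / sq, diff_00[1] / sq)
assert p_mm[0] * diff_10[0] + p_mm[1] * diff_10[1] >= 1  # the other ATT-SVM constraint holds

p_final = normalized_gd([0.0, 0.0])
print(math.hypot(*p_final), cosine(p_final, p_mm))
```

Starting from $p(0) = 0$, the norm grows roughly linearly in the number of normalized steps and the cosine with $p^{mm\star}$ increases towards 1. The gradient uses the identity $K^\top(\mathrm{diag}(s) - ss^\top)\gamma = \sum_{t < \tau} s_t s_\tau (\gamma_t - \gamma_\tau)(k_t - k_\tau)$, which stays accurate when one softmax entry rounds to 1.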
**Lemma 2** Suppose for all $i \in [n]$ and $t \ne \mathrm{opt}_i$, $Y_i = 1$ and $\gamma_{it} < \gamma_{i\mathrm{opt}_i}$. Also assume $W \in \mathbb{R}^{d \times d}$ is full-rank. Then $p^{mm\star}$ exists, i.e., (ATT-SVM) is feasible for optimal indices $\alpha_i \leftarrow \mathrm{opt}_i$.

### 2.2 Local convergence of the attention weights $p$

Theorem 2 on the global convergence of gradient descent serves as a prelude to the general behavior of the optimization. Once we relax Assumption B by allowing for arbitrary token scores, we will show that $p$ can converge (in direction) to a locally-optimal solution. However, this locally-optimal solution is still characterized in terms of (ATT-SVM), which separates locally-optimal tokens from the rest. Our theory builds on two new concepts: locally-optimal tokens and neighbors of these tokens.

**Definition 2 (SVM-Neighbor and Locally-Optimal Tokens)** Fix token indices $\alpha = (\alpha_i)_{i=1}^n$ for which (ATT-SVM) is feasible to obtain $p^{mm} = p^{mm}(\alpha)$. Consider tokens $\mathcal{T}_i \subset [T]$ such that $(k_{i\alpha_i} - k_{it})^\top p^{mm} = 1$ for all $t \in \mathcal{T}_i$. We refer to $\mathcal{T}_i$ as the SVM-neighbors of $k_{i\alpha_i}$. Additionally, tokens with indices $\alpha = (\alpha_i)_{i=1}^n$ are called locally-optimal if, for all $i \in [n]$ and $t \in \mathcal{T}_i$, the scores per Definition 1 obey $\gamma_{i\alpha_i} > \gamma_{it}$. The associated $p^{mm}$ is called a locally-optimal max-margin (LMM) direction.

[Figure 2 legend: Optimal Token, Locally-Optimal Token, Non-Optimal Token.]

Figure 2: Gradient descent initialization $p(0)$ inside the cone containing the locally-optimal solution $p^{mm}$.

To provide a basis for discussing local convergence, we provide some preliminary definitions regarding cones. For a given $q \in \mathbb{R}^d$ and a scalar $\mu > 0$, we define $\mathrm{cone}_\mu(q)$ as the set of vectors $p \in \mathbb{R}^d$ such that the correlation coefficient between $p$ and $q$ is at least $1 - \mu$:

$$\mathrm{cone}_\mu(q) := \left\{ p \in \mathbb{R}^d \;\middle|\; \left\langle \frac{p}{\|p\|}, \frac{q}{\|q\|} \right\rangle \ge 1 - \mu \right\}. \tag{7}$$

Given $R > 0$, the intersection of $\mathrm{cone}_\mu(q)$ and the set $\{p \in \mathbb{R}^d \mid \|p\| \ge R\}$ is denoted as $\mathcal{C}_{\mu,R}(q)$:

$$\mathcal{C}_{\mu,R}(q) := \mathrm{cone}_\mu(q) \cap \left\{ p \in \mathbb{R}^d \;\middle|\; \|p\| \ge R \right\}. \tag{8}$$

Next, we demonstrate the existence of parameters $\mu = \mu(\alpha) > 0$ and $R > 0$ such that when $R$ is sufficiently large, there are no stationary points within $\mathcal{C}_{\mu,R}(p^{mm})$.
Further, gradient descent initialized within $\mathcal{C}_{\mu,R}(p^{mm})$ converges in direction to $p^{mm}/\|p^{mm}\|$; refer to Figure 2 for a visualization.

**Theorem 3 (Local Convergence of Gradient Descent)** Suppose Assumption A on the loss function $\ell$ holds and assume $\alpha = (\alpha_i)_{i=1}^n$ are indices of locally-optimal tokens per Definition 2. Then, there is a constant $\mu = \mu(\alpha) \in (0, 1)$ and $R > 0$ such that $\mathcal{C}_{\mu,R}(p^{mm})$ does not contain any stationary points. Further, for any starting point $p(0) \in \mathcal{C}_{\mu,R}(p^{mm})$, the gradient descent iterates $p(t+1) = p(t) - \eta \nabla \mathcal{L}(p(t))$ on (ERM) with stepsize $\eta \le 1/L_p$ satisfy $\lim_{t \to \infty} \|p(t)\| = \infty$ and $\lim_{t \to \infty} p(t)/\|p(t)\| = p^{mm}/\|p^{mm}\|$.

To further illustrate Theorem 3, consider Figure 1(b), where $n = 1$ and $T = 3$. In this figure, the point $(0, 0)$ represents the non-optimal token, while $(1, 0)$ represents the locally-optimal token. Additionally, the gray paths represent the trajectories of gradient descent initiated from different points. Gradient descent, when properly initialized, converges towards the direction of $p^{mm}$ (depicted by the dashed line). This direction of convergence separates the locally-optimal token $(1, 0)$ from the non-optimal token $(0, 0)$.

### 2.3 Regularization paths can only converge to locally-optimal max-margin directions

An important question is whether our definition of LMM (Definition 2) encompasses all possible convergence directions of the attention mechanism when $\|p\| \to \infty$. To address this, we introduce the set of LMM directions as follows:

$$\mathcal{P}^{mm} := \left\{ p^{mm}(\alpha) \;\middle|\; \alpha \text{ is locally-optimal per Definition 2} \right\}.$$

The following theorem establishes the tightness of these directions: it demonstrates that for any candidate $q \notin \mathcal{P}^{mm}$, its local regularization path within an arbitrarily small neighborhood will provably not converge in the direction of $q$.

**Theorem 4** Fix $q \notin \mathcal{P}^{mm}$ with unit $\ell_2$ norm. Assume that token scores are distinct (namely $\gamma_{it} \ne \gamma_{i\tau}$ for $t \ne \tau$) and key embeddings $k_{it}$ are in general position (see Theorem 7). Fix arbitrary $\epsilon > 0$, $R_0 > 0$.
Define the local regularization path of $q$ as its $(\epsilon, R_0)$-conic neighborhood:

$$\bar{p}(R) = \arg\min_{p \in \mathcal{C}_{\epsilon,R_0}(q),\ \|p\| \le R} \mathcal{L}(p), \quad \text{where } \mathcal{C}_{\epsilon,R_0}(q) = \mathrm{cone}_\epsilon(q) \cap \left\{ p \in \mathbb{R}^d \;\middle|\; \|p\| \ge R_0 \right\}. \tag{9}$$

Then, either $\lim_{R \to \infty} \|\bar{p}(R)\| < \infty$ or $\lim_{R \to \infty} \bar{p}(R)/\|\bar{p}(R)\| \ne q$. In both scenarios, $\lim_{R \to \infty} \bar{p}(R)/R \ne q$.

The result above nicely complements Theorem 3, which states that when gradient descent is initialized above a threshold ($\|p(0)\| \ge R_0$) in an LMM direction, $\|p(t)\|$ diverges but the direction converges to the LMM. In contrast, Theorem 4 shows that regardless of how small the cone is (in terms of angle and norm lower bound $\|p\| \ge R_0$), the optimal solution path will not converge along $q \notin \mathcal{P}^{mm}$.

## 3 Joint Convergence of Head $v$ and Attention Weights $p$

In this section, we extend the preceding results to the general case of joint optimization of the head $v$ and attention weights $p$ using a logistic loss function. To this aim, we focus on regularization path analysis, which involves solving (ERM) under ridge constraints and examining the solution trajectory as the constraints are relaxed.

**High-level intuition.** Since the prediction is linear as a function of $v$, logistic regression in $v$ can exhibit its own implicit bias to a max-margin solution. Concretely, define the attention features $x^p_i = X_i^\top \mathbb{S}(K_i p)$ and define the dataset $\mathcal{D}_p = (Y_i, x^p_i)_{i=1}^n$. If this dataset $\mathcal{D}_p$ is linearly separable, then fixing $p$ and optimizing only $v$ will converge in the direction of the standard max-margin classifier

$$v^{mm} = \arg\min_{v \in \mathbb{R}^d} \|v\| \quad \text{subject to} \quad Y_i \cdot v^\top r_i \ge 1 \quad \text{for all } 1 \le i \le n, \tag{SVM}$$

after setting the inputs to the attention features $r_i \leftarrow x^p_i$ [22]. This motivates a clear question: under what conditions will jointly optimizing $v$ and $p$ converge to their respective max-margin solutions? We study this question in two steps.
Loosely speaking: (1) We will first assume that, at the optimal tokens $x_{i\alpha_i}$, $i \in [n]$, selected by $p$, when solving (SVM) with $r_i \leftarrow x_{i\alpha_i}$, all of these tokens become support vectors of (SVM). (2) We will then relax this condition to uncover a more general implicit bias for $p$ that distinguishes support vectors from non-support vectors. Throughout, we assume that the joint problem is separable and that there exists $(v, p)$ asymptotically achieving zero training loss.

[Figure 3 panels: (a) All inputs are support vectors, with inputs {(0, 0), (1, 1)}, Y = 1; {(0, 0), (1, −1)}, Y = −1; {(0, 0), (0.5, 1)}, Y = 1. (b) (0.5, 1.5) is not a support vector, with the third input replaced by {(0, 0), (0.5, 1.5)}, Y = 1. (c) Probability evolutions in (a): softmax and logistic probabilities versus iterations.]

Figure 3: (a) and (b) Joint convergence of attention weights $p$ and classifier head $v$ to max-margin directions. (c) Averaged softmax probability evolution of optimal tokens and logistic probability evolution of output in (a).

### 3.1 When all attention features are support vectors

In (SVM), define the label margin to be $1/\|v^{mm}\|$. Our first insight in quantifying the joint implicit bias is that optimal tokens admit a natural definition: those that maximize the downstream label margin when selected. This is formalized below, where we assume that: (1) selecting the token indices $\alpha = (\alpha_i)_{i=1}^n$ from each input achieves the largest label margin; and (2) the optimality of the $\alpha$ choice is strict in the sense that mixing other tokens will shrink the label margin in (SVM).

**Assumption C (Optimal Tokens)** Let $\Gamma > 0$ be the label margin when solving (SVM) with $r_i \leftarrow x_{i\alpha_i}$. There exists $\nu > 0$ such that for all $p$, solving (SVM) with $r_i \leftarrow x^p_i$ results in a label margin of at most $\Gamma - \nu \cdot \max_{i \in [n]} (1 - s_{i\alpha_i})$, where $s_i = \mathbb{S}(K_i p)$.

**Example:** To gain intuition, let us fix $a \in \mathbb{R}^d$ and consider the dataset obeying $x_{i1} = Y_i \cdot a$ and $\|x_{it}\| < \|a\|$ for all $t \ge 2$ and all $i \in [n]$.
For this dataset, we can choose $\alpha_i = 1$, $v^{mm} = a/\|a\|^2$, $\Gamma = 1/\|v^{mm}\| = \|a\|$, and $\nu = \|a\| - \sup_{i \in [n], t \ge 2} \|x_{it}\|$.

**Theorem 5** Consider the ridge-constrained solutions $(v_r, p_R)$ of (ERM) defined as $(v_r, p_R) = \arg\min_{\|v\| \le r,\ \|p\| \le R} \mathcal{L}(v, p)$. Suppose there are token indices $\alpha = (\alpha_i)_{i=1}^n$ for which $p^{mm}(\alpha)$ exists ((ATT-SVM) is feasible) and Assumption C holds for some $\Gamma, \nu > 0$. Then, $\lim_{R \to \infty} p_R/R = p^{mm}/\|p^{mm}\|$, where $p^{mm}$ is the solution of (ATT-SVM); and $\lim_{r \to \infty} v_r/r = v^{mm}/\|v^{mm}\|$, where $v^{mm}$ is the solution of (SVM) with $r_i = x_{i\alpha_i}$.

As further discussion, consider Figure 3(a), where we set $n = 3$, $T = d = 2$, and $W = $ Identity. All three inputs share the point $(0, 0)$, which corresponds to their non-optimal tokens. The optimal tokens (denoted by ⋆) are all support vectors of (SVM), since $v^{mm} = [0, 1]$ is the optimal classifier direction (depicted by the dashed line). Because of this, $p^{mm}$ will separate the optimal tokens from the tokens at the $(0, 0)$ coordinate via (ATT-SVM), and its direction is dictated by the yellow and teal ⋆s, which are the support vectors.

### 3.2 General solution when selecting one token per input

Can we relax Assumption C, and if so, what is the resulting behavior? Consider the scenario where the optimal $p$ diverges to $\infty$ and ends up selecting one token per input. Suppose this $p$ selects some coordinates $\alpha = (\alpha_i)_{i=1}^n$. Let $\mathcal{S} \subseteq [n]$ be the set of indices where the associated token $x_{i\alpha_i}$ is a support vector when solving (SVM). Set $\bar{\mathcal{S}} = [n] - \mathcal{S}$. Our intuition is as follows: even if we slightly perturb this $p$ choice and mix other tokens $t \ne \alpha_i$ over the input set $\bar{\mathcal{S}} \subset [n]$, since the inputs in $\bar{\mathcal{S}}$ are not support vectors of (SVM), we can preserve the label margin (by only preserving the support vectors $\mathcal{S}$).
This means that $p$ may not have to enforce the max-margin constraint over inputs $i \in \bar{\mathcal{S}}$; instead, it suffices to just select these tokens (asymptotically). This results in the following relaxed SVM problem:

$$p^{relax} = \arg\min_p \|p\| \quad \text{such that} \quad p^\top (k_{i\alpha_i} - k_{it}) \ge \begin{cases} 1 & \text{for all } t \ne \alpha_i,\ i \in \mathcal{S}, \\ 0 & \text{for all } t \ne \alpha_i,\ i \in \bar{\mathcal{S}}. \end{cases} \tag{10}$$

Here, $p^\top (k_{i\alpha_i} - k_{it}) \ge 0$ corresponds to the selection idea. Building on this intuition, the following theorem captures the generalized behavior of the joint regularization path.

[Figure 4 panels: (a) Evolution of softmax probability versus iterations; (b) Evolution of attention weights, showing the attention norm $\|p\|$ versus iterations. Both panels compare normalized GD with GD using a constant $\eta$.]

Figure 4: Evolution of softmax probability and attention weights when training with normalized gradient descent or constant step size $\eta$, respectively.

[Figure 5 legend: C = 1, C = 2~9, C = 10, C = 100; $\ell(x) = \log(1 + e^{-x})$; first input: (0, 0) non-optimal token ($\gamma = 0$), (1, 1) optimal token ($\gamma = C$); second input: (0, 0) non-optimal token ($\gamma = 0$), (1, −1) optimal token ($\gamma = 1$).]

Figure 5: Trajectories of $p$ with different loss functions and scores in Theorem 2.

**Theorem 6** Consider the same (ERM) problem as discussed in Theorem 5. Suppose $\mathbb{S}(K_i p_R)_{\alpha_i} \to 1$, i.e., the tokens $(\alpha_i)_{i=1}^n$ are asymptotically selected. Let $v^{mm}$ be the solution of (SVM) with $r_i = x_{i\alpha_i}$ and $\mathcal{S}$ be its set of support vector indices. Suppose Assumption C holds over $\mathcal{S}$, i.e., having $s_{i\alpha_i} < 1$ shrinks the margin when (SVM) is only solved over $\mathcal{S} \subseteq [n]$. Then, $\lim_{r \to \infty} v_r/r = v^{mm}/\|v^{mm}\|$ and $\lim_{R \to \infty} p_R/R = p^{relax}/\|p^{relax}\|$, where $p^{relax}$ is the solution of (10) with $(\alpha_i)_{i=1}^n$ choices.

To illustrate this numerically, consider Figure 3(b), which modifies Figure 3(a) by pushing the yellow ⋆ to the northern position $(0.5, 1.5)$. We still have $v^{mm} = [0, 1]$; however, the yellow ⋆ is no longer a support vector of (SVM). Thus, $p$ solves the relaxed problem (10), which separates the green and teal ⋆s by enforcing the max-margin constraint on $p$ (which is the red direction).
Instead, the yellow ⋆ only needs to achieve positive correlation with $p$ (unlike Figure 3(a), where it dictates the direction). We also display the direction of $p^{mm}$ using a gray dashed line.

We further investigate the evolution of the softmax and logistic output probabilities throughout the training process of Figure 3(a); the results are illustrated in Figure 3(c). The averaged softmax probability of optimal tokens is represented by the red curve and is calculated as $\frac{1}{n} \sum_{i=1}^n \max_{t \in [T]} \mathbb{S}(K_i p)_t$. A value of 1 for this probability indicates that the attention mechanism successfully selects the optimal tokens. On the other hand, the logistic probability of the output is represented by the blue curve and is determined by $\frac{1}{n} \sum_{i=1}^n 1/(1 + e^{-Y_i f(X_i)})$. This probability also reaches a value of 1, suggesting that the inputs are correctly classified.

## 4 Experiments

**Sparsity of softmax and evolution of attention weights.** It is well known that, in practice, attention maps often exhibit sparsity and highlight salient tokens that aid inference. Our results provide a formal explanation of this when tokens are separable: since attention selects a locally-optimal token within the input sequence and suppresses the rest, the associated attention map $\mathbb{S}(Xp)$ will (eventually) be a sparse vector. Additionally, the sparsity should arise in tandem with the increasing norm of attention weights. We provide empirical evidence to support these findings.

**Synthetic experiments.** Figures 4(a) and 4(b) show the evolution of the largest softmax probability and attention weights over time when using either normalized gradient descent or a fixed stepsize $\eta$ for training. The dataset model follows Figure 1(c). The softmax probability shown in Figure 4(a) is defined as $\frac{1}{n} \sum_{i=1}^n \max_{t \in [T]} \mathbb{S}(K_i p)_t$. When this average probability reaches the value of 1, attention selects only a single token per input. The attention norm in Figure 4(b) is simply equal to $\|p\|$.
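The two quantities plotted in Figure 4 are simple to compute. The sketch below (pure Python; the key vectors are a hypothetical toy example, not the Figure 1(c) data, and the function names are our own) evaluates the averaged largest softmax probability $\frac{1}{n} \sum_i \max_t \mathbb{S}(K_i p)_t$ and the attention norm $\|p\|$ as $p$ is scaled up along a fixed direction that favors one token per input.

```python
import math


def softmax(z):
    m = max(z)
    e = [math.exp(zi - m) for zi in z]
    total = sum(e)
    return [ei / total for ei in e]


def avg_max_softmax_prob(keys, p):
    """(1/n) * sum_i max_t S(K_i p)_t, where keys[i] holds the rows k_it of K_i."""
    total = 0.0
    for K in keys:
        s = softmax([sum(kj * pj for kj, pj in zip(k, p)) for k in K])
        total += max(s)
    return total / len(keys)


def attention_norm(p):
    """||p||, the quantity tracked in Figure 4(b)."""
    return math.sqrt(sum(pj * pj for pj in p))


# Hypothetical toy keys (n = 2 inputs, T = 3 tokens), not the Figure 1(c) data.
KEYS = [
    [(1.0, 0.0), (0.0, 0.0), (0.0, 1.0)],
    [(0.0, 1.0), (0.0, 0.0), (1.0, 0.0)],
]
for scale in (0.0, 1.0, 10.0):
    p = (scale, 0.5 * scale)  # growing along a fixed direction
    print(attention_norm(p), avg_max_softmax_prob(KEYS, p))
```

As the scale grows, the averaged probability moves from $1/T$ towards 1 while $\|p\|$ grows, mirroring the coupling between sparsity and weight norm discussed above.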
The red curves in both figures represent the normalized gradient method, which updates the model parameters $p$ using $p(t+1) = p(t) - \eta \nabla \mathcal{L}(p(t)) / \|\nabla \mathcal{L}(p(t))\|$ with $\eta = 0.1$. The blue curves correspond to gradient descent with a constant learning rate, given by $p(t+1) = p(t) - \eta \nabla \mathcal{L}(p(t))$ with $\eta = 1$. Observe that the normalized gradient method achieves a softmax probability of 1 faster, since vanilla GD suffers from vanishing gradients. This is visible in Figure 4(b), where the blue norm curve levels off.

[Figure 6 panels: (a) Input image; (b) Epoch 0; (c) Epoch 100; (d) Epoch 200; (e) Epoch 300; (f) Epoch 400.]

Figure 6: Illustration of the progressive change in attention weights of the [CLS] token during training in the transformer model, using a specific input image shown in Figure 6(a).

Figure 7: The red curve is the sparsity level $\widehat{\mathrm{nnz}}(s)/T$ of the average attention map, which takes values on $[0, 1]$. A sparser vector implies that a few key tokens receive significantly higher attention, while the majority of the tokens receive minimal attention. The blue curve is the Frobenius norm $\|W\|_F$ of the attention weights of the final layer. We display their evolutions over epochs (0 to 400).

**Real experiments.** To study softmax sparsity and the evolution of attention weights throughout training, we train a vision transformer (ViT-base) model [23] from scratch, utilizing the CIFAR10 dataset [24] for 400 epochs with fixed learning rate $3 \times 10^{-3}$. ViT tokenizes an image into a $16 \times 16$ grid of patches; thus, its softmax attention maps can be easily visualized. We examine the average attention map associated with the [CLS] token computed from all 12 attention heads within the model. Figure 6 provides a visual representation of the resulting attention weights ($16 \times 16$ grids) corresponding to the original patch locations within the image. During the initial epochs of training, the attention weights are randomly distributed and exhibit a dense pattern.
However, as training progresses, the attention map gradually becomes sparser and the attention mechanism begins to concentrate on fewer salient patches within the image, namely those with distinct features that aid classification. This illustrates the evolution of attention from a random initial state to a more focused and sparse representation. The salient patches highlighted by attention conceptually correspond to the optimal tokens in our theory. We quantify the sparsity of the attention map via a soft-sparsity measure, denoted by $\widehat{\mathrm{nnz}}(s)$, where $s$ is the softmax probability vector. The soft-sparsity is the ratio of the $\ell_1$ norm to the squared $\ell_2$ norm, i.e., $\widehat{\mathrm{nnz}}(s) = \|s\|_1/\|s\|_2^2$. It takes values between 1 and $T = 256$, and a smaller value indicates a sparser vector. Also note that $\|s\|_1 = \sum_{t=1}^{T}s_t = 1$. Together with sparsity, Figure 7 displays the Frobenius norm of the combined key-query matrix $W$ of the last attention layer over epochs. The theory suggests that the increase in sparsity is associated with the growth of the attention weights, which converge directionally. The results in Figure 7 align with the theory, demonstrating the progressive sparsification of the attention map as $\|W\|_F$ grows.

Transient optimization dynamics and the influence of the loss function. Theorem 2 shows that the asymptotic direction of gradient descent is determined by $p^{\mathrm{mm}\star}$. However, the transient dynamics can be biased toward certain input examples and their associated optimal tokens. We illustrate this in Figure 5(a), which displays gradient descent trajectories for different scores and loss functions. We consider two optimal tokens with scores $\gamma_1 = 1$ and $\gamma_2 = C$, where $C$ varies. For our analysis, we examine the correlation loss $\ell(x) = -x$ and the logistic loss $\ell(x) = \log(1+e^{-x})$.
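The heuristic for this loss-dependent bias is that, once attention approximately selects the optimal tokens, the gradient contributed by input $i$ scales like $|\ell'(\gamma_i)|\,\gamma_i$. A minimal sketch of that scaling for the two losses, assuming two inputs with optimal-token scores $\gamma_1 = 1$ and $\gamma_2 = C$ (the function names are illustrative, and this evaluates only the scaling factor, not a full gradient):

```python
import math


def correlation_weight(gamma):
    # Correlation loss l(x) = -x: |l'(x)| = 1, so the gradient scale is |l'(gamma)| * gamma = gamma
    return gamma


def logistic_weight(gamma):
    # Logistic loss l(x) = log(1 + e^{-x}): |l'(x)| = e^{-x} / (1 + e^{-x}) = 1 / (1 + e^{x})
    return gamma / (1.0 + math.exp(gamma))


# Two inputs whose optimal tokens have scores gamma_1 = 1 and gamma_2 = C.
# Each ratio compares input 2's gradient scale with input 1's.
for C in [1.0, 2.0, 4.0, 8.0]:
    corr = correlation_weight(C) / correlation_weight(1.0)
    logi = logistic_weight(C) / logistic_weight(1.0)
    print(f"C = {C:3.1f}: correlation ratio = {corr:.3f}, logistic ratio = {logi:.4f}")
```

The correlation ratio grows linearly in $C$, while the logistic ratio falls below 1 and decays roughly like $e^{-C}$. For $C$ only slightly above 1 the logistic factor $\gamma/(1+e^{\gamma})$ is still increasing (it peaks near $\gamma \approx 1.28$), so the preference for the low-score input appears once $C$ is sufficiently large.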
In essence, as $C$ increases, the correlation loss $\ell(x) = -x$ exhibits a bias toward the token with the higher score, while the logistic loss is biased toward the token with the lower score. The reason for this behavior can be seen from the gradients of the individual inputs: $\nabla L_i(p) = \ell'_i\,Y_i\,K_i^\top\mathbb{S}'(K_ip)X_iv$, where $\mathbb{S}'(\cdot)$ denotes the derivative (Jacobian) of the softmax function and $\ell'_i := \ell'(Y_iv^\top X_i^\top\mathbb{S}(K_ip))$. Assuming that $p$ (approximately) selects the optimal tokens, this simplifies to $\ell'_i \approx \ell'(\gamma_i)$ and $\|\nabla L_i(p)\| \propto |\ell'(\gamma_i)|\,\gamma_i$, where $\gamma_i$ is the score of the optimal token of input $i$. With the correlation loss, $|\ell'| = 1$, so $\|\nabla L_i(p)\| \propto \gamma_i$: a larger score induces a larger gradient. On the other hand, the logistic loss behaves similarly to the exponential loss under separable data, i.e., $|\ell'(x)| = e^{-x}/(1+e^{-x}) \approx e^{-x}$. Consequently, $\|\nabla L_i(p)\| \propto \gamma_ie^{-\gamma_i}$, which is dominated by the factor $e^{-\gamma_i}$, so a smaller score leads to a larger gradient. These observations explain the empirical behavior described above.

5 Related Work

Implicit Regularization. The implicit bias of gradient descent in classification tasks with separable data has been extensively examined by [22, 25, 26, 27, 28, 29]. These works typically use the logistic loss or, more generally, exponentially-tailed losses to make connections to margin maximization. These results have also been extended to non-separable data by [30, 31, 21]. Furthermore, there have been notable investigations into implicit bias in regression problems/losses using techniques such as mirror descent [32, 25, 33, 34, 35, 36]. In addition, several papers have explored the implicit bias of stochastic gradient descent [37, 38, 39, 40, 41, 42], as well as of adaptive and momentum-based methods [43, 44, 45, 46]. Although there are similarities between our optimization approach for $v$ and existing works, the optimization of $p$ is significantly different. Firstly, our optimization problem is nonconvex, which introduces new challenges and complexities.
Secondly, it necessitates novel concepts such as locally-optimal tokens and requires a fresh analysis specifically tailored to the cones surrounding them.

Attention Mechanism. Transformers, introduced by [6], revolutionized NLP and machine translation, building on earlier work on self-attention [47, 48, 49, 50]. Unlike traditional models such as MLPs and CNNs, self-attention leverages global interactions to build feature representations and shows exceptional empirical performance. However, the underlying mechanisms and learning processes of the attention layer remain poorly understood. Recent studies such as [51, 52, 53, 54, 23] have focused on specific aspects such as representing sparse functions, convex relaxations, and expressive power. In contrast to our nonconvex (ERM), [52] studies self-attention with linear activation instead of softmax, while [53] approximates softmax using a linear operation with unit-simplex constraints. Their main objective is to derive convex reformulations of the ERM-based training problem. [55, 56] have developed initial results characterizing the optimization and generalization dynamics of attention. [17] is another closely related work, whose authors analyze the same attention model (ERM) as ours. Specifically, they jointly optimize $v, p$ for three gradient iterations on a contextual dataset model. However, all of these works make stringent assumptions on the data, namely that tokens are tightly clusterable or can be clearly split into relevant and irrelevant sets. Additionally, [56] requires assumptions on initialization, and [55] considers a simplified attention structure in which the attention matrix is not directly parameterized with respect to the input. Our work links attention models to hard-margin SVM problems and pioneers the study of gradient descent's implicit bias in these models.
6 Discussion

We have provided a thorough optimization-theoretic characterization of the fundamental attention model $f(X) = v^\top X^\top\mathbb{S}(XW^\top p)$ by formally connecting it to max-margin problems. We first established the convergence of gradient descent on $p$ (or equivalently $W$) in isolation. We also explored the joint convergence of $(v, p)$ via the regularization path, which revealed surprising implicit biases such as (10). These findings motivate several exciting avenues for future research. An immediate open problem is characterizing the (local) convergence of gradient descent for the joint optimization of $(v, p)$. Another major direction is to extend a similar analysis to the self-attention layer (4) or to allow for multiple tunable tokens (where $p$ becomes a matrix). Either setting will enrich the problem by allowing the attention to discover multiple hyperplanes that separate tokens. While our convergence guarantees apply when tokens are separable, it would be interesting to characterize the non-separable geometry by leveraging results developed for logistic regression [31, 22]. Ideas from such earlier results may also be useful for characterizing the non-asymptotic/transient dynamics of how gradient descent aligns with the max-margin direction. Overall, we believe that max-margin token selection is a fundamental characteristic of the attention mechanism, and the theory developed in this work lays the groundwork for these future extensions.

Acknowledgements. This work was supported by the NSF grants CCF-2046816 and CCF-2212426, a Google Research Scholar award, and Army Research Office grant W911NF2110312. The authors thank the anonymous reviewers and Christos Thrampoulidis for their valuable feedback, which has significantly improved this paper.

References

[1] Dzmitry Bahdanau, Kyunghyun Cho, and Yoshua Bengio. Neural machine translation by jointly learning to align and translate. In International Conference on Learning Representations, 2015.
[2] Tom Brown, Benjamin Mann, Nick Ryder, Melanie Subbiah, Jared D Kaplan, Prafulla Dhariwal, Arvind Neelakantan, Pranav Shyam, Girish Sastry, Amanda Askell, et al. Language models are few-shot learners. In Advances in Neural Information Processing Systems, volume 33, pages 1877–1901, 2020.
[3] Mark Chen, Jerry Tworek, Heewoo Jun, Qiming Yuan, Henrique Ponde de Oliveira Pinto, Jared Kaplan, Harri Edwards, Yuri Burda, Nicholas Joseph, Greg Brockman, et al. Evaluating large language models trained on code. arXiv preprint arXiv:2107.03374, 2021.
[4] Alec Radford, Jeffrey Wu, Rewon Child, David Luan, Dario Amodei, Ilya Sutskever, et al. Language models are unsupervised multitask learners. OpenAI blog, 1(8):9, 2019.
[5] Aakanksha Chowdhery, Sharan Narang, Jacob Devlin, Maarten Bosma, Gaurav Mishra, Adam Roberts, Paul Barham, Hyung Won Chung, Charles Sutton, Sebastian Gehrmann, et al. PaLM: Scaling language modeling with pathways. arXiv preprint arXiv:2204.02311, 2022.
[6] Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Łukasz Kaiser, and Illia Polosukhin. Attention is all you need. In Advances in Neural Information Processing Systems, volume 30, 2017.
[7] Jacob Devlin, Ming-Wei Chang, Kenton Lee, and Kristina Toutanova. BERT: Pre-training of deep bidirectional transformers for language understanding. In Proceedings of the 2019 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies, Volume 1 (Long and Short Papers), pages 4171–4186, Minneapolis, Minnesota, June 2019. Association for Computational Linguistics.
[8] OpenAI. GPT-4 technical report. arXiv preprint arXiv:2303.08774, 2023.
[9] Rishi Bommasani, Drew A Hudson, Ehsan Adeli, Russ Altman, Simran Arora, Sydney von Arx, Michael S Bernstein, Jeannette Bohg, Antoine Bosselut, Emma Brunskill, et al. On the opportunities and risks of foundation models. arXiv preprint arXiv:2108.07258, 2021.
[10] Aditya Ramesh, Mikhail Pavlov, Gabriel Goh, Scott Gray, Chelsea Voss, Alec Radford, Mark Chen, and Ilya Sutskever. Zero-shot text-to-image generation. In International Conference on Machine Learning, pages 8821–8831. PMLR, 2021.
[11] Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, and Neil Houlsby. An image is worth 16x16 words: Transformers for image recognition at scale. In International Conference on Learning Representations, 2021.
[12] Alec Radford, Jong Wook Kim, Chris Hallacy, Aditya Ramesh, Gabriel Goh, Sandhini Agarwal, Girish Sastry, Amanda Askell, Pamela Mishkin, Jack Clark, et al. Learning transferable visual models from natural language supervision. In International Conference on Machine Learning, pages 8748–8763. PMLR, 2021.
[13] Danny Driess, Fei Xia, Mehdi SM Sajjadi, Corey Lynch, Aakanksha Chowdhery, Brian Ichter, Ayzaan Wahid, Jonathan Tompson, Quan Vuong, Tianhe Yu, et al. PaLM-E: An embodied multimodal language model. arXiv preprint arXiv:2303.03378, 2023.
[14] Lili Chen, Kevin Lu, Aravind Rajeswaran, Kimin Lee, Aditya Grover, Misha Laskin, Pieter Abbeel, Aravind Srinivas, and Igor Mordatch. Decision transformer: Reinforcement learning via sequence modeling. In Advances in Neural Information Processing Systems, volume 34, pages 15084–15097, 2021.
[15] Scott Reed, Konrad Zolna, Emilio Parisotto, Sergio Gomez Colmenarejo, Alexander Novikov, Gabriel Barth-Maron, Mai Gimenez, Yury Sulsky, Jackie Kay, Jost Tobias Springenberg, et al. A generalist agent. arXiv preprint arXiv:2205.06175, 2022.
[16] Brian Lester, Rami Al-Rfou, and Noah Constant. The power of scale for parameter-efficient prompt tuning. In Proceedings of the 2021 Conference on Empirical Methods in Natural Language Processing, pages 3045–3059, 2021.
[17] Samet Oymak, Ankit Singh Rawat, Mahdi Soltanolkotabi, and Christos Thrampoulidis.
On the role of attention in prompt-tuning. In International Conference on Machine Learning, 2023.
[18] Xiang Lisa Li and Percy Liang. Prefix-tuning: Optimizing continuous prompts for generation. arXiv preprint arXiv:2101.00190, 2021.
[19] Saharon Rosset, Ji Zhu, and Trevor Hastie. Margin maximizing loss functions. Advances in Neural Information Processing Systems, 16, 2003.
[20] Arun Suggala, Adarsh Prasad, and Pradeep K Ravikumar. Connecting optimization and regularization paths. Advances in Neural Information Processing Systems, 31, 2018.
[21] Ziwei Ji, Miroslav Dudík, Robert E Schapire, and Matus Telgarsky. Gradient descent follows the regularization path for general losses. In Conference on Learning Theory, pages 2109–2136. PMLR, 2020.
[22] Daniel Soudry, Elad Hoffer, Mor Shpigel Nacson, Suriya Gunasekar, and Nathan Srebro. The implicit bias of gradient descent on separable data. The Journal of Machine Learning Research, 19(1):2822–2878, 2018.
[23] Yihe Dong, Jean-Baptiste Cordonnier, and Andreas Loukas. Attention is not all you need: Pure attention loses rank doubly exponentially with depth. In International Conference on Machine Learning, pages 2793–2803. PMLR, 2021.
[24] Alex Krizhevsky, Vinod Nair, and Geoffrey Hinton. The CIFAR-10 dataset. Online: http://www.cs.toronto.edu/~kriz/cifar.html, 55(5), 2014.
[25] Suriya Gunasekar, Jason Lee, Daniel Soudry, and Nathan Srebro. Characterizing implicit bias in terms of optimization geometry. In International Conference on Machine Learning, pages 1832–1841. PMLR, 2018.
[26] Mor Shpigel Nacson, Jason Lee, Suriya Gunasekar, Pedro Henrique Pamplona Savarese, Nathan Srebro, and Daniel Soudry. Convergence of gradient descent on separable data. In The 22nd International Conference on Artificial Intelligence and Statistics, pages 3420–3428. PMLR, 2019.
[27] Ziwei Ji and Matus Telgarsky. Characterizing the implicit bias via a primal-dual analysis. In Algorithmic Learning Theory, pages 772–804. PMLR, 2021.
[28] Edward Moroshko, Blake E Woodworth, Suriya Gunasekar, Jason D Lee, Nati Srebro, and Daniel Soudry. Implicit bias in deep linear classification: Initialization scale vs training accuracy. Advances in Neural Information Processing Systems, 33:22182–22193, 2020.
[29] Ziwei Ji and Matus Telgarsky. Directional convergence and alignment in deep learning. In H. Larochelle, M. Ranzato, R. Hadsell, M. F. Balcan, and H. Lin, editors, Advances in Neural Information Processing Systems, volume 33, pages 17176–17186. Curran Associates, Inc., 2020.
[30] Ziwei Ji and Matus Telgarsky. Risk and parameter convergence of logistic regression. arXiv preprint arXiv:1803.07300, 2018.
[31] Ziwei Ji and Matus Telgarsky. The implicit bias of gradient descent on nonseparable data. In Conference on Learning Theory, pages 1772–1798. PMLR, 2019.
[32] Blake Woodworth, Suriya Gunasekar, Jason D Lee, Edward Moroshko, Pedro Savarese, Itay Golan, Daniel Soudry, and Nathan Srebro. Kernel and rich regimes in overparametrized models. In Conference on Learning Theory, pages 3635–3673. PMLR, 2020.
[33] Chulhee Yun, Shankar Krishnan, and Hossein Mobahi. A unifying view on implicit bias in training linear neural networks. arXiv preprint arXiv:2010.02501, 2020.
[34] Tomas Vaskevicius, Varun Kanade, and Patrick Rebeschini. Implicit regularization for optimal sparse recovery. Advances in Neural Information Processing Systems, 32:2972–2983, 2019.
[35] Ehsan Amid and Manfred K Warmuth. Winnowing with gradient descent. In Conference on Learning Theory, pages 163–182. PMLR, 2020.
[36] Ehsan Amid and Manfred KK Warmuth. Reparameterizing mirror descent as gradient descent. Advances in Neural Information Processing Systems, 33:8430–8439, 2020.
[37] Yuanzhi Li, Colin Wei, and Tengyu Ma. Towards explaining the regularization effect of initial large learning rate in training neural networks. arXiv preprint arXiv:1907.04595, 2019.
[38] Guy Blanc, Neha Gupta, Gregory Valiant, and Paul Valiant.
Implicit regularization for deep neural networks driven by an Ornstein-Uhlenbeck like process. In Conference on Learning Theory, pages 483–513. PMLR, 2020.
[39] Jeff Z. HaoChen, Colin Wei, Jason D Lee, and Tengyu Ma. Shape matters: Understanding the implicit bias of the noise covariance. arXiv preprint arXiv:2006.08680, 2020.
[40] Zhiyuan Li, Tianhao Wang, and Sanjeev Arora. What happens after SGD reaches zero loss? A mathematical framework. In International Conference on Learning Representations, 2022.
[41] Alex Damian, Tengyu Ma, and Jason Lee. Label noise SGD provably prefers flat global minimizers. arXiv preprint arXiv:2106.06530, 2021.
[42] Difan Zou, Jingfeng Wu, Vladimir Braverman, Quanquan Gu, Dean P Foster, and Sham Kakade. The benefits of implicit regularization from SGD in least squares problems. Advances in Neural Information Processing Systems, 34:5456–5468, 2021.
[43] Qian Qian and Xiaoyuan Qian. The implicit bias of AdaGrad on separable data. Advances in Neural Information Processing Systems, 32, 2019.
[44] Bohan Wang, Qi Meng, Huishuai Zhang, Ruoyu Sun, Wei Chen, and Zhi-Ming Ma. Momentum doesn't change the implicit bias. arXiv preprint arXiv:2110.03891, 2021.
[45] Bohan Wang, Qi Meng, Wei Chen, and Tie-Yan Liu. The implicit bias for adaptive optimization algorithms on homogeneous neural networks. In International Conference on Machine Learning, pages 10849–10858. PMLR, 2021.
[46] Ziwei Ji, Nathan Srebro, and Matus Telgarsky. Fast margin maximization via dual acceleration. In International Conference on Machine Learning, pages 4860–4869. PMLR, 2021.
[47] Jianpeng Cheng, Li Dong, and Mirella Lapata. Long short-term memory-networks for machine reading. In Proceedings of the 2016 Conference on Empirical Methods in Natural Language Processing, pages 551–561, Austin, Texas, November 2016. Association for Computational Linguistics.
[48] Ankur Parikh, Oscar Täckström, Dipanjan Das, and Jakob Uszkoreit.
A decomposable attention model for natural language inference. In Proceedings of the 2016 Conference on Empirical Methods in Natural Language Processing, pages 2249–2255, Austin, Texas, November 2016. Association for Computational Linguistics.
[49] Romain Paulus, Caiming Xiong, and Richard Socher. A deep reinforced model for abstractive summarization. In International Conference on Learning Representations, 2018.
[50] Zhouhan Lin, Minwei Feng, Cicero Nogueira dos Santos, Mo Yu, Bing Xiang, Bowen Zhou, and Yoshua Bengio. A structured self-attentive sentence embedding. In International Conference on Learning Representations, 2017.
[51] Benjamin L Edelman, Surbhi Goel, Sham Kakade, and Cyril Zhang. Inductive biases and variable creation in self-attention mechanisms. In International Conference on Machine Learning, pages 5793–5831. PMLR, 2022.
[52] Arda Sahiner, Tolga Ergen, Batu Ozturkler, John Pauly, Morteza Mardani, and Mert Pilanci. Unraveling attention via convex duality: Analysis and interpretations of vision transformers. In International Conference on Machine Learning, pages 19050–19088. PMLR, 2022.
[53] Tolga Ergen, Behnam Neyshabur, and Harsh Mehta. Convexifying transformers: Improving optimization and understanding of transformer networks. arXiv:2211.11052, 2022.
[54] Pierre Baldi and Roman Vershynin. The quarks of attention. arXiv:2202.08371, 2022.
[55] Samy Jelassi, Michael Eli Sander, and Yuanzhi Li. Vision transformers provably learn spatial structure. In Alice H. Oh, Alekh Agarwal, Danielle Belgrave, and Kyunghyun Cho, editors, Advances in Neural Information Processing Systems, 2022.
[56] Hongkang Li, Meng Wang, Sijia Liu, and Pin-Yu Chen. A theoretical understanding of shallow vision transformers: Learning, generalization, and sample complexity. arXiv preprint arXiv:2302.06015, 2023.
[57] Simon Du, Jason Lee, Haochuan Li, Liwei Wang, and Xiyu Zhai. Gradient descent finds global minima of deep neural networks.
In International Conference on Machine Learning, pages 1675–1685. PMLR, 2019.
[58] Arthur Jacot, Franck Gabriel, and Clément Hongler. Neural tangent kernel: Convergence and generalization in neural networks. arXiv preprint arXiv:1806.07572, 2018.
[59] Samet Oymak and Mahdi Soltanolkotabi. Toward moderate overparameterization: Global convergence guarantees for training shallow neural networks. IEEE Journal on Selected Areas in Information Theory, 1(1):84–105, 2020.
[60] Zeyuan Allen-Zhu, Yuanzhi Li, and Zhao Song. A convergence theory for deep learning via over-parameterization. In International Conference on Machine Learning, pages 242–252. PMLR, 2019.
[61] Vladimir Vapnik. Estimation of dependences based on empirical data. Springer Science & Business Media, 2006.
[62] Peter Bartlett. For valid generalization the size of the weights is more important than the size of the network. Advances in Neural Information Processing Systems, 9, 1996.
[63] Albert B Novikoff. On convergence proofs for perceptrons. Technical report, Stanford Research Institute, Menlo Park, CA, 1963.
[64] Peter Bartlett, Yoav Freund, Wee Sun Lee, and Robert E Schapire. Boosting the margin: A new explanation for the effectiveness of voting methods. The Annals of Statistics, 26(5):1651–1686, 1998.
[65] Tong Zhang and Bin Yu. Boosting with early stopping: Convergence and consistency. Annals of Statistics, page 1538, 2005.
[66] Matus Telgarsky. Margins, shrinkage, and boosting. In International Conference on Machine Learning, pages 307–315. PMLR, 2013.
[67] Ganesh Ramachandra Kini, Orestis Paraskevas, Samet Oymak, and Christos Thrampoulidis. Label-imbalanced and group-sensitive classification under overparameterization. Advances in Neural Information Processing Systems, 34:18970–18983, 2021.
[68] Mahdi Soltanolkotabi, Dominik Stöger, and Changzhi Xie. Implicit balancing and regularization: Generalization and convergence guarantees for overparameterized asymmetric matrix sensing. arXiv:2303.14244, 2023.
[69] Hossein Taheri and Christos Thrampoulidis. On generalization of decentralized learning with separable data. In International Conference on Artificial Intelligence and Statistics, pages 4917–4945. PMLR, 2023.
[70] Samet Oymak and Mahdi Soltanolkotabi. Overparameterized nonlinear learning: Gradient descent takes the shortest path? In International Conference on Machine Learning, pages 4951–4960. PMLR, 2019.
[71] Ziwei Ji and Matus Telgarsky. Gradient descent aligns the layers of deep linear networks. arXiv preprint arXiv:1810.02032, 2018.
[72] Sanjeev Arora, Nadav Cohen, Wei Hu, and Yuping Luo. Implicit regularization in deep matrix factorization. Advances in Neural Information Processing Systems, 32, 2019.
[73] Kaifeng Lyu and Jian Li. Gradient descent maximizes the margin of homogeneous neural networks. arXiv preprint arXiv:1906.05890, 2019.
[74] Lenaic Chizat and Francis Bach. Implicit bias of gradient descent for wide two-layer neural networks trained with the logistic loss. In Conference on Learning Theory, pages 1305–1338. PMLR, 2020.
[75] Spencer Frei, Gal Vardi, Peter L Bartlett, and Nathan Srebro. Benign overfitting in linear classifiers and leaky ReLU networks from KKT conditions for margin maximization. arXiv e-prints, pages arXiv-2303, 2023.
[76] Gal Vardi, Ohad Shamir, and Nati Srebro. On margin maximization in linear and ReLU networks. Advances in Neural Information Processing Systems, 35:37024–37036, 2022.
[77] Navid Azizan, Sahin Lale, and Babak Hassibi. Stochastic mirror descent on overparameterized nonlinear models. IEEE Transactions on Neural Networks and Learning Systems, 33(12):7717–7727, 2021.
[78] Navid Azizan and Babak Hassibi. Stochastic gradient/mirror descent: Minimax optimality and implicit regularization. In International Conference on Learning Representations.
[79] Guorui Zhou, Xiaoqiang Zhu, Chenru Song, Ying Fan, Han Zhu, Xiao Ma, Yanghui Yan, Junqi Jin, Han Li, and Kun Gai.
Deep interest network for click-through rate prediction. In Proceedings of the 24th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining, pages 1059–1068, 2018.
[80] Qiwei Chen, Huan Zhao, Wei Li, Pipei Huang, and Wenwu Ou. Behavior sequence transformer for e-commerce recommendation in Alibaba. In Proceedings of the 1st International Workshop on Deep Learning Practice for High-Dimensional Sparse Data, pages 1–4, 2019.
[81] Fei Sun, Jun Liu, Jian Wu, Changhua Pei, Xiao Lin, Wenwu Ou, and Peng Jiang. BERT4Rec: Sequential recommendation with bidirectional encoder representations from transformer. In Proceedings of the 28th ACM International Conference on Information and Knowledge Management, pages 1441–1450, 2019.
[82] Mia Xu Chen, Orhan Firat, Ankur Bapna, Melvin Johnson, Wolfgang Macherey, George Foster, Llion Jones, Mike Schuster, Noam Shazeer, Niki Parmar, Ashish Vaswani, Jakob Uszkoreit, Lukasz Kaiser, Zhifeng Chen, Yonghui Wu, and Macduff Hughes. The best of both worlds: Combining recent advances in neural machine translation. In Proceedings of the 56th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers), pages 76–86, Melbourne, Australia, July 2018. Association for Computational Linguistics.
[83] Michael Janner, Qiyang Li, and Sergey Levine. Reinforcement learning as one big sequence modeling problem. In ICML 2021 Workshop on Unsupervised Reinforcement Learning, 2021.
[84] Qinqing Zheng, Amy Zhang, and Aditya Grover. Online decision transformer. In Proceedings of the 39th International Conference on Machine Learning, 2022.
[85] Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, et al. An image is worth 16x16 words: Transformers for image recognition at scale. In International Conference on Learning Representations, 2020.
[86] Hugo Touvron, Matthieu Cord, Matthijs Douze, Francisco Massa, Alexandre Sablayrolles, and Hervé Jégou. Training data-efficient image transformers & distillation through attention. In International Conference on Machine Learning, pages 10347–10357. PMLR, 2021.
[87] Zi-Hang Jiang, Qibin Hou, Li Yuan, Daquan Zhou, Yujun Shi, Xiaojie Jin, Anran Wang, and Jiashi Feng. All tokens matter: Token labeling for training better vision transformers. In Advances in Neural Information Processing Systems, volume 34, pages 18590–18602, 2021.
[88] Hyunjik Kim, George Papamakarios, and Andriy Mnih. The Lipschitz constant of self-attention. In International Conference on Machine Learning, pages 5562–5571. PMLR, 2021.
[89] Jiri Hron, Yasaman Bahri, Jascha Sohl-Dickstein, and Roman Novak. Infinite attention: NNGP and NTK for deep attention networks. In International Conference on Machine Learning, pages 4376–4386. PMLR, 2020.
[90] Greg Yang. Tensor programs II: Neural tangent kernel for any architecture. arXiv preprint arXiv:2006.14548, 2020.
[91] Mostafa Dehghani, Stephan Gouws, Oriol Vinyals, Jakob Uszkoreit, and Lukasz Kaiser. Universal transformers. In International Conference on Learning Representations, 2018.
[92] Chulhee Yun, Srinadh Bhojanapalli, Ankit Singh Rawat, Sashank Reddi, and Sanjiv Kumar. Are transformers universal approximators of sequence-to-sequence functions? In International Conference on Learning Representations, 2019.
[93] Angeliki Giannou, Shashank Rajput, Jy-yong Sohn, Kangwook Lee, Jason D Lee, and Dimitris Papailiopoulos. Looped transformers as programmable computers. arXiv:2301.13196, 2023.
[94] Yoav Levine, Noam Wies, Or Sharir, Hofit Bata, and Amnon Shashua. Limits to depth efficiencies of self-attention. In Advances in Neural Information Processing Systems, volume 33, pages 22640–22651, 2020.
[95] Charlie Snell, Ruiqi Zhong, Dan Klein, and Jacob Steinhardt. Approximating how single head attention learns. arXiv preprint arXiv:2103.07601, 2021.
[96] Jason Wei, Maarten Bosma, Vincent Y Zhao, Kelvin Guu, Adams Wei Yu, Brian Lester, Nan Du, Andrew M Dai, and Quoc V Le. Finetuned language models are zero-shot learners. arXiv preprint arXiv:2109.01652, 2021.
[97] Ekin Akyürek, Dale Schuurmans, Jacob Andreas, Tengyu Ma, and Denny Zhou. What learning algorithm is in-context learning? Investigations with linear models. arXiv:2211.15661, 2022.
[98] Shivam Garg, Dimitris Tsipras, Percy S Liang, and Gregory Valiant. What can transformers learn in-context? A case study of simple function classes. Advances in Neural Information Processing Systems, 35:30583–30598, 2022.
[99] Yingcong Li, M Emrullah Ildiz, Dimitris Papailiopoulos, and Samet Oymak. Transformers as algorithms: Generalization and stability in in-context learning. In International Conference on Machine Learning, 2023.
[100] Jason Wei, Xuezhi Wang, Dale Schuurmans, Maarten Bosma, Ed Chi, Quoc Le, and Denny Zhou. Chain of thought prompting elicits reasoning in large language models. arXiv preprint arXiv:2201.11903, 2022.
[101] Guhao Feng, Yuntian Gu, Bohang Zhang, Haotian Ye, Di He, and Liwei Wang. Towards revealing the mystery behind chain of thought: A theoretical perspective. arXiv preprint arXiv:2305.15408, 2023.
[102] Yingcong Li, Kartik Sreenivasan, Angeliki Giannou, Dimitris Papailiopoulos, and Samet Oymak. Dissecting chain-of-thought: A study on compositional in-context learning of MLPs. arXiv preprint arXiv:2305.18869, 2023.
[103] Yuandong Tian, Yiping Wang, Beidi Chen, and Simon Du. Scan and snap: Understanding training dynamics and token composition in 1-layer transformer. arXiv:2305.16380, 2023.
[104] Tan Minh Nguyen, Tam Minh Nguyen, Nhat Ho, Andrea L Bertozzi, Richard Baraniuk, and Stanley Osher. A primal-dual framework for transformers and neural networks. In The Eleventh International Conference on Learning Representations, 2023.
[105] Davoud Ataee Tarzanagh, Yingcong Li, Christos Thrampoulidis, and Samet Oymak.
Transformers as support vector machines. arXiv preprint arXiv:2308.16898, 2023.

Roadmap. The appendix is organized as follows: Section A provides basic facts about the training risk. Section B presents the proofs of the local and global gradient descent results and of the regularization path analysis for learning $p\in\mathbb{R}^d$ with a fixed choice of $v\in\mathbb{R}^d$. Section C provides the proof of the regularization path result for the general case of jointly optimizing the head $v$ and the attention weights $p$ with a logistic loss. Section D presents the regularization path analysis for the more general model $f(X) = \psi(X^\top\mathbb{S}(XW^\top p))$ with a nonlinear head $\psi$. Section E provides implementation details. Finally, Section F discusses additional related work on implicit bias and self-attention.

Table of Contents
A Addendum to Section 1
A.1 Preliminaries on the Training Risk
A.2 Proof of Lemma 1
B Addendum to Section 2
B.1 Descent and Gradient Correlation Conditions
B.2 Proof of Theorem 1
B.3 Proof of Theorem 2
B.4 Proof of Theorem 3
B.5 Proof of Theorem 4
B.6 Proof of Lemma 2
C Addendum to Section 3
C.1 Proof of Theorem 5
C.2 Proof of Theorem 6
D Regularization Path of Attention with Nonlinear Head
D.1 Proof of Theorem 8
D.2 Application to Linearly-mixed Labels
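E Implementation Details and Additional Experiments
F Addendum to Section 5
F.1 Related Work on Implicit Regularization
F.2 Related Work on Attention Mechanism

A Addendum to Section 1

A.1 Preliminaries on the Training Risk

By our assumption, $\psi:\mathbb{R}^d\to\mathbb{R}$ and $\ell:\mathbb{R}\to\mathbb{R}$ are differentiable functions. Recall the objective
$$L(p,W) = \frac{1}{n}\sum_{i=1}^{n}\ell\big(Y_i\,\psi(X_i^\top\mathbb{S}(K_ip))\big) \tag{11}$$
with the generic prediction model $\psi(X^\top\mathbb{S}(Kp))$ and $K = XW^\top$. Here, we write down the gradients with respect to $W$ and $p$ in (11) to highlight the connection. Set $q := W^\top p$, $z\{X\} := X^\top\mathbb{S}(Kp)$, and $a\{X\} := Kp$. Given $X$ and using $K = XW^\top$, we have that
$$\nabla_q\psi(p,W) = X^\top\mathbb{S}'(a\{X\})X\,\psi'(z\{X\}), \tag{12a}$$
$$\nabla_p\psi(p,W) = W\,\nabla_q\psi(p,W), \tag{12b}$$
$$\nabla_W\psi(p,W) = p\,\nabla_q\psi(p,W)^\top, \tag{12c}$$
where $\mathbb{S}'(a\{X\}) = \mathrm{diag}(\mathbb{S}(a\{X\})) - \mathbb{S}(a\{X\})\mathbb{S}(a\{X\})^\top\in\mathbb{R}^{T\times T}$. Setting $\psi(z) = v^\top z$ for a linear head and $\gamma = Xv$, we obtain
$$\nabla_q\psi(p,W) = X^\top\mathbb{S}'(a\{X\})\gamma, \tag{13a}$$
$$\nabla_p\psi(p,W) = W\,\nabla_q\psi(p,W) = K^\top\mathbb{S}'(a\{X\})\gamma, \tag{13b}$$
$$\nabla_W\psi(p,W) = p\,\nabla_q\psi(p,W)^\top = p\,v^\top X^\top\mathbb{S}'(a\{X\})X. \tag{13c}$$
Recalling (12b) and (12c), and defining $\ell'_i := \ell'(Y_i\,\psi(z\{X_i\}))\in\mathbb{R}$, we have that
$$\nabla_pL(p,W) = \frac{1}{n}\sum_{i=1}^{n}\ell'_i\,Y_i\,W\,\nabla_q\psi(p,W), \tag{14a}$$
$$\nabla_WL(p,W) = \frac{1}{n}\sum_{i=1}^{n}\ell'_i\,Y_i\,p\,\nabla_q\psi(p,W)^\top, \tag{14b}$$
where, in the $i$-th summand, $\nabla_q\psi$ is evaluated at $X = X_i$. Setting $\psi(z) = v^\top z$ for a linear head and $\gamma_i = Y_iX_iv$, we obtain
$$\nabla_pL(p,W) = \frac{1}{n}\sum_{i=1}^{n}\ell'_i\,K_i^\top\mathbb{S}'(a\{X_i\})\gamma_i, \tag{15a}$$
$$\nabla_WL(p,W) = \frac{1}{n}\sum_{i=1}^{n}\ell'_i\,p\,\gamma_i^\top\mathbb{S}'(a\{X_i\})X_i. \tag{15b}$$

Lemma 3 (Key Lemma). For any $p, q\in\mathbb{R}^d$, let $a = Kq$, $s = \mathbb{S}(Kp)$, and $\gamma = Xv$. Set $\Gamma = \sup_{t,\tau\in[T]}|\gamma_t-\gamma_\tau|$ and $A = \sup_{t\in[T]}\|k_t\|\,\|q\|$. We have that
$$\Big|a^\top\mathrm{diag}(s)\gamma - a^\top ss^\top\gamma - \sum_{t\ge2}(a_1-a_t)s_t(\gamma_1-\gamma_t)\Big| \le 2\Gamma A(1-s_1)^2.$$

Proof. Set $\bar\gamma = \sum_{t=1}^{T}\gamma_ts_t$. We have $\gamma_1-\bar\gamma = \sum_{t\ge2}(\gamma_1-\gamma_t)s_t$ and $|\gamma_1-\bar\gamma|\le\Gamma(1-s_1)$. Then
$$a^\top\mathrm{diag}(s)\gamma - a^\top ss^\top\gamma = \sum_{t=1}^{T}a_ts_t(\gamma_t-\bar\gamma) = a_1s_1(\gamma_1-\bar\gamma) - \sum_{t\ge2}a_ts_t(\bar\gamma-\gamma_t). \tag{16}$$
Since $\bar\gamma-\gamma_t = (\gamma_1-\gamma_t)-(\gamma_1-\bar\gamma)$ and $|a_t|\le A$, we have
$$\sum_{t\ge2}a_ts_t(\bar\gamma-\gamma_t) = \sum_{t\ge2}a_ts_t(\gamma_1-\gamma_t) \pm A\Gamma(1-s_1)^2.$$
Plugging this into (16) gives
$$\begin{aligned} a^\top\mathrm{diag}(s)\gamma - a^\top ss^\top\gamma &= a_1s_1(\gamma_1-\bar\gamma) - \sum_{t\ge2}a_ts_t(\gamma_1-\gamma_t) \pm A\Gamma(1-s_1)^2\\ &= a_1s_1\sum_{t\ge2}(\gamma_1-\gamma_t)s_t - \sum_{t\ge2}a_ts_t(\gamma_1-\gamma_t) \pm A\Gamma(1-s_1)^2\\ &= \sum_{t\ge2}(a_1s_1-a_t)s_t(\gamma_1-\gamma_t) \pm A\Gamma(1-s_1)^2\\ &= \sum_{t\ge2}(a_1-a_t)s_t(\gamma_1-\gamma_t) \pm 2A\Gamma(1-s_1)^2. \end{aligned}$$
Here, the last step uses the fact that $\big|\sum_{t\ge2}(a_1s_1-a_1)s_t(\gamma_1-\gamma_t)\big| \le A\Gamma(1-s_1)\sum_{t\ge2}s_t = (1-s_1)^2A\Gamma$.

A.2 Proof of Lemma 1

Proof. Let us prove the result for a general step size sequence $(\eta_t)_{t\ge0}$.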
On the same training data (Yi, Xi)n i=1, recall the objectives L(p) = 1 n Pn i=1 ℓ(Yi ψ(X i S(Xip))) and L(W) = 1 n Pn i=1 ℓ(Yi ψ(X i S(Xi W u))). Suppose claim is true till iteration t. For iteration t + 1, using W(t) u = p(t), define and observe that Si = S (Xi W(t) u) = S (Xi p(t)), si = S(Xi W(t) u) = S(Xip(t)), z{Xi} = X i S(Xip(t)) = X i S(Xi W(t) u), for all i [n]. Thus, using (14), we have that L(p(t)) = 1 i=1 ℓ i Yi X i Si Xi ψ (z{Xi}), L(W(t)) = u i=1 ℓ i Yi X i Si Xi ψ (z{Xi}) Consequently, we found that gradient is rank-1 with left singular space equal to u, i.e., L(W(t)) = u L(p(t)). Since W(t) s left singular space is guaranteed to be in u (including W(0) by initialization), we only need to study the right singular vector. Using the induction till t, this yields W(t + 1) u = W(t) u ηt u 2 L(W(t))u = p(t) ηt u 2u u L(p(t)) = p(t + 1). This concludes the induction. 2For simplicity, we use on the right hand side to denote the upper and lower bounds. B Addendum to Section 2 B.1 Descent and Gradient Correlation Conditions The lemma below identifies conditions under which pmm is a global descent direction for L(p). Lemma 4 Suppose ℓ( ) is a strictly decreasing differentiate loss function and Assumption B holds. Then, for all p Rd, the training loss (ERM) obeys L(p), pmm < 0. Proof. Set γi = Yi Xiv, ai = Kip, ai = Kipmm , and ℓ i = ℓ γ i S(Kip) . (17) Let us recall the gradient evaluated at p which is given by i=1 ℓ i K i S (ai)γi. (18) This implies that D L(p), pmm E = 1 i=1 ℓ i ai, S (ai)γi . (19) To proceed, we will prove that individual summands are all strictly negative. To show that, without losing generality, let us focus on the first input and drop the subscript i for cleaner notation. This yields a, S (a)γ = a diag(S(a))γ a S(a)S(a) γ. (20) Without losing generality, assume optimal token is the first one and γt is a constant for all t 2. To proceed, we will prove the following: Suppose γ = γt 2 is constant, γ1, a1 are the largest indices of γ, a. 
Then, for any s obeying P t [T] st = 1, st 0, we have that a diag(s)γ a ss γ > 0. To see this, we write a diag(s)γ a ss γ = γ1s1 + γ(1 s1) a1s1 + = a1(γ1 γ)s1(1 s1) + γ (γ1s1 + γ(1 s1)) T X = a1(γ1 γ)s1(1 s1) (γ1 γ)s1 = (γ1 γ)(1 s1)s1 a1 PT t 2 atst PT t 2 st To proceed, let γgap = γ1 γ and agap = a1 maxt 2 at. With these, we obtain a diag(s)γ a ss γ agapγgaps1(1 s1). (22) ai gap inf t,opti(kiopti kit) pmm 1, γi gap = inf t,opti γiopti γit > 0, si1(1 si1) > 0. On the other hand, by our assumption ℓ i < 0. Hence, infimum ing (22) over all inputs, multiplying by ℓ i and using (19) give the desired result. Lemma 5 (Gradient Correlation Conditions) Consider n = 1 and let pmm = pmm be (ATT-SVM) solution separating α = opt from remaining tokens of input X. Suppose ℓ( ) is a strictly decreasing differentiate loss function and Assumption B holds. For any choice of π > 0, there exists R := Rπ such that, for any p with p R, we have * L(p), p + (1 + π) * L(p), pmm Above, observe that as R , we eventually get to set π = 0. Proof. The proof is similar to Lemma 4 at a high-level. However, we also need to account for the impact of p besides pmm in the gradient correlation. The main goal is showing that pmm is the near-optimal descent direction, thus, p cannot significantly outperform it. To proceed, let p = pmm p/ p , M = supt kt , Θ = 1/ pmm , s = S(Kp), a = K p, a = Kpmm. Without losing generality assume opt = 1. Set γ = γt 2. Repeating the proof of Lemma 4 yields L(p), pmm = ℓ (γ1 γ)(1 s1)s1 a1 PT t 2 atst PT t 2 st L(p), p = ℓ (γ1 γ)(1 s1)s1 a1 PT t 2 atst PT t 2 st Given π, for sufficiently large R, we wish to show that a1 PT t 2 atst PT t 2 st (1 + π) a1 PT t 2 atst PT t 2 st We consider two scenarios. Scenario 1: p pmm ϵ := π/(2M). In this scenario, for any token, we find that |at at| = |k t ( p pmm)| M p pmm Mϵ. Consequently, we obtain a1 PT t 2 atst PT t 2 st a1 PT t 2 atst PT t 2 st 2Mϵ = a1 PT t 2 atst PT t 2 st π. 
Also noticing a1 PT t 2 atst PT t 2 st 1 (thanks to pmm satisfying 1 margin), this implies (23). Scenario 2: p pmm ϵ := π/(2M). In this scenario, for some ν = ν(ϵ) and τ 2, we have that p (k1 kτ) = a1 aτ 1 2ν. Here τ = arg maxt 2 p kt denotes the nearest point to k1. Recall that s = S( Ra) where R = p / pmm . To proceed, split the tokens into two groups: Let N be the group of tokens obeying p (k1 kt) 1 ν for t N and [T] N be the rest. Observe that P t [T] N st PT t 2 st P t [T] N st e2ν R = Te Rν. Set M = M/Θ and note that at pmm kt M. Using p (k1 kt) 1 ν over t N and plugging in the above bound, we obtain PT t 2(a1 at)st PT t 2 st = P t N(a1 at)st PT t 2 st + P t [T] N(a1 at)st PT t 2 st 1 ν + 2 MTe Rν. Using the fact that a1 PT t 2 atst PT t 2 st 1, the above implies (23) with π = 2 MTe Rν ν. To proceed, choose Rπ = ν 1Θ 1 log(2 MT/π) to ensure π π. The following lemma states the descent property of gradient descent for L(p) under Assumption A. It is important to note that although the infimum of the optimization problem is L , it is not achieved at any finite p. Additionally, there are no finite critical points p. Lemma 6 Under Assumption A, the function L(p) is Lp-smooth, where M0 v 2 W 2 Xi 4 + 3M1 v W 2 Xi 3 . (24) Furthermore, if η 1/Lp, then, for any initialization p(0), with the GD sequence p(t + 1) = p(t) η L(p(t)), we have L(p(t + 1)) L(p(t)) η 2 L(p(t)) 2 , (25) for all t 0. This implies that t=0 L (p(t)) 2 < , and lim t L (p(t)) 2 = 0. (26) Proof. Recall that we defined γi = Yi Xiv and ai = Ki p. The gradient of L(p) is given by i=1 ℓ γ i S(Kip) K i S (ai)γi. Note that for any p Rd, the Jacobian of S(Kip) is given by S(Ki p) p = S (Kip)Ki = diag(S(Kip)) S(Kip)S(Kip) Ki. (27) The Jacobian (27) together with the definition of the softmax function S( ) implies that S(Kip)/ p Ki . Hence, for any p, p Rd, we have S(Kip) S(Ki p) Ki p p , (28a) and S (Kip) S (Ki p) diag(S(Ki p)) diag(S(Ki p)) + S(Kip)S(Kip) S(Ki p)S(Ki p) 3 Ki p p . 
(28b) Here, the last inequality uses the fact that |ab cd| |d||a c| + |a||b d|. Next, for any p, p Rd, we have L(p) L( p) 1 ℓ γ i S(Kip) K i S (Ki p)γi ℓ γ i S(Ki p) K i S (Ki p)γi K i S (Ki p)γi ℓ γ i S(Kip) ℓ γ i S(Ki p) ℓ (γ i S(Kip)) K i S (Ki p)γi K i S (Ki p)γi i=1 M0 γi 2 Ki S(Kip) S(Ki p) i=1 M1 γi Ki S (Kip) S (Ki p) , (29) where the second inequality follows from the fact that |ab cd| |d||a c| + |a||b d| and the third inequality uses Assumption A. Substituting (28a) and (28b) into (29), we get L(p) L( p) 1 M0 γi 2 Ki 2 + 3M1 Ki 2 γi p p M0 v 2 W 2 Xi 4 + 3M1 v |W 2 Xi 3 p p where Lp is defined in (24). The remaining proof follows standard gradient descent analysis (see e.g. [22, Lemma 10]). Since L (p) is Lp-smooth, we get L (p(t + 1)) L (p(t)) + L (p(t)) (p(t + 1) p(t)) + Lp 2 p(t + 1) p(t) 2 = L (p(t)) η L (p(t)) 2 + Lpη2 2 L (p(t)) 2 = L (p(t)) η 1 Lpη ! L (p(t)) 2 2 L (p(t)) 2 , where the last inequality follows from our assumption on the stepsize. The above inequality implies that t=0 L (p(t)) 2 2 η (L (p(0)) L ) , (30) where the right hand side is upper bounded by a finite constant. This is because, by Assumption A, L (p(0)) < and L L (p(t)), where L denotes the minimum objective. Finally, (30) yields the expression (26). In the following lemma, we demonstrate the existence of parameters µ = µ(α) > 0 and Rµ > 0 such that when Rµ is sufficiently large, there are no stationary points within Cµ,Rµ(pmm). Additionally, we provide the local gradient correlation condition. Lemma 7 (Local Gradient Condition) Suppose Assumption A on the loss function ℓholds. Let α = (αi)n i=1 be indices of locally-optimal tokens per Definition 2. L1. There exists a positive scalar µ = µ(α) > 0 such that for sufficiently large Rµ, no stationary point exists within Cµ, Rµ(pmm), where Cµ, Rµ is defined in (8). L2. 
For all q, p coneµ(pmm) with q = pmm and p Rµ with same Rµ choice as (L1.), there exist dataset dependent constants C, c > 0 such that i [n] {1 S(Kip)αi} L(p), q c 1 i [n] {1 S(Ki p)αi} > 0, (31a) i [n] {1 S(Kip)αi} ACTe RµΘ/2. (31b) C A > 0, (31c) Here, A = maxi [n],t,τ [T] kit kiτ and Θ = 1/ pmm . L3. For any π > 0, there exists Rπ such that Rπ Rµ and all p Cµ,Rπ(pmm) obeys + (1 + π) * L(p), pmm Proof. Let pmm = pmm(α) be the solution of (ATT-SVM). Recall Cµ, Rµ(pmm) = coneµ(pmm) \ n p p Rµ o . Let (Ti)n i=1 be the sets of all SVM-neighbors per Definition 2. Let Ti = [T] Ti {αi} be the set of non-SVM-neighbor tokens, i [n]. Let Θ = 1/ pmm , δ = 0.5 min i [n] min t Ti,τ Ti (kit kiτ) pmm, A = max i [n],t [T] kit /Θ, min(0.5, δ) When Ti = for all i [n] (i.e. globally-optimal indices), we set δ = as all non-neighbor related terms will disappear. Since pmm is the max-margin model ensuring (kiαi kit) pmm 1 for all i [n], the following inequalities hold for all q coneµ(pmm), q = pmm and all i [n], t Ti, τ Ti: (kit kiτ) q δ > 0, (kiαi kiτ) q 1 + δ, 3/2 (kiαi kit) q 1/2. Here, we used q pmm 2/ pmm 2 2µ which implies q pmm p L1. and L2.. Now that the choice of local cone is determined, we need to prove the main claims. We will lower bound q L(p) and establish its strict positivity for p R, where R = Rµ. This will show that there is no stationary point as a by product. Consider any q Rd satisfying q = pmm . To proceed, we write the gradient correlation following (18) and (21) L(p), q = 1 i=1 ℓ i ai, S (a i)γi , (34) where we denoted ℓ i = ℓ (Yi v X i S(Kip)), ai = Kiq, a i = Kip, si = S(Kip). Using (33), for all t Ti, τ Ti, for all p Cµ,R(pmm), we have that a iαi a iτ RΘ(1 + δ), and a it a iτ RΘδ. Consequently, we can bound the softmax probabilities si = S(Kip) as follows: For all i [n], τ Ti siτ 1 siαi = X τ,αi siτ Te RΘ/2siαi Te RΘ/2, τ Ti siτ Te RΘδsiti Te RΘδS i, ti Ti. (35) Recall scores γit = Yi v xit. 
Define the score gaps over neighbors: γgap i = γiαi max t Ti γit, and γgap i = γiαi min t Ti γit. It follows from (32) that A = max i [n],t [T] kit /Θ max i [n],t [T] ait = kitq . Define the α-dependent global scalar Γ = supi [n],t,τ [T] |γit γiτ|. Let us focus on a fixed datapoint i [n], assume (without losing generality) αi = 1, and drop subscripts i, that is, α := αi = 1, X := Xi, Y := Yi, K := Ki, a = Kp, a = Kq, s = S(Kp), γ = Y Xv, γgap := γgap i , γgap := γgap i , Q := Qi, and S := S i. Directly applying Lemma 3, we obtain a diag(s)γ a ss γ t 2 (a1 at)st(γ1 γt) 2ΓA(1 s1)2. To proceed, let us decouple the non-neighbors within PT t 2(a1 at)st(γ1 γt) via X t T (a1 at)st(γ1 γt) 2QΓA. Aggregating these, we found a diag(s)γ a ss γ X t Ti (a1 at)st(γ1 γt) 2ΓA((1 s1)2 + Q). (36) To proceed, let us upper/lower bound the gradient correlation. We use two bounds depending on q coneµ(pmm) (Case 1) or general q Rd (Case 2). Case 1: q coneµ(pmm). Since 1.5 a1 at 0.5 following (33), we find 1.5 S γgap X t Ti (a1 at)st(γ1 γt) 0.5 S γgap. Next we claim that S dominates ((1 s1)2 + Q) for large R. Specifically, we wish for S γgap/4 4ΓA max((1 s1)2, Q) S 16 ΓA γgap max((1 s1)2, Q). (37) Now choose R δ 1 log(T)/Θ to ensure Q S since Q Te RΘδS . Consequently (1 s1)2 = (Q + S )2 4S 2 4S Te RΘ/2. Combining these, what we wish is ensured by guaranteeing γgap max(4S Te RΘ/2, Te RΘδS ). (38) This in turn is ensured for all inputs i [n] by choosing R = max(2, δ 1) Θ log 64TΓA where γgap min = mini [n] γgap i is the global scalar which is the worst case score gap over all inputs. With the above choice of R we guaranteed 2(1 s1) γgap 2 S γgap a diag(s)γ a ss γ S γgap 4 (1 s1)γgap Since this holds over all inputs, going back to the gradient correlation (34) and averaging above over all inputs i [n] and plugging back the indices i, we obtain the advertised bound by setting qi = 1 siαi (where we set αi = 1 above without losing generality) i [n] ℓ i qi γgap i L(p), q 1 i [n] ℓ i qi γgap i . 
(40) Let ℓ min / max be the min/max values negative loss derivative admits over the ball [ B, B] for B = v maxi,t xit and note that maxi [n] γgap i > 0 and mini [n] γgap i > 0 are dataset dependent constants. Then, we declare the constants C = 2ℓ max maxi [n] γgap i > 0, c = (1/8)ℓ min mini [n] γgap i > 0 to obtain the bound i [n] qi L(p), q c i [n] qi, (41) which is the desired statement in (31a). Case 2: q Rd and q = pmm . Define A = maxi [n],t,τ [T] kit kiτ . For any q = pmm , we use the fact that a1 at k1 kt q A Θ. Note that by definition A Θ 1. To proceed, we can upper bound A Θ S γgap X t T (a1 at)st(γ1 γt). (42) By choosing the same R as in (39) to ensure S dominates ((1 s1)2 + Q) and since A Θ 1, we guaranteed 2 A Θ S γgap a diag(s)γ a ss γ. Going back to the gradient correlation (34) and averaging above over all inputs i [n], with the same definition of C > 0, we obtain AC Θn i [n] qi L(p), q . (43) To proceed, since (43) holds for any q Rd and q = pmm , we observe that when choosing q = pmm L(p) L(p), this implies that L(p), q = L(p) pmm AC Θn Simplifying Θ = 1/ pmm on both sides yields (31b). Incorporating (35) in the bound above provides the exponential upper bound that decay with R. Combining this with (41), we obtain that for all q, p coneµ(pmm) and q Rµ This gives the desired result in (31c). L3.: Establishing gradient correlation. Our final goal is establishing gradient comparison between p, pmm for the same choice of µ > 0 provided in (32). Define p = pmm p/ p to be the normalized vector. Set notations ai = Ki p, ai = Ki pmm, and γi = Yi Xiv. To establish the result, using (34), we will prove that, for any π > 0, there is sufficiently large R = Rπ such that for any p Cµ,R(pmm): * L(p), p i=1 ℓ i ai, S (Ki p)γi i=1 ℓ i ai, S (Kip)γi = (1 + π) * L(p), pmm Following (36), for all i [n], for all q coneµ(pmm) with q = pmm , a = Kq and s = S(Kp), we have found a i diag(si)γ a i sis i γi X t Ti (a i1 a it)sit(γi1 γit) 2ΓA((1 si1)2 + Qi). 
(45) Plugging in ai, ai in the bound above and assuming π 1 (w.l.o.g.), (44) is implied by the following stronger inequality 6ΓA((1 si1)2 + Qi) + X t Ti (ai1 ait)sit(γi1 γit) t Ti ( ai1 ait)sit(γi1 γit) t Ti sit(γi1 γit). First, we claim that 0.5π P t Ti sit(γi1 γit) 6ΓA((1 si1)2 + Qi) for all i [n]. The proof of this claim directly follows the earlier argument, namely, following (37), (38) and (39) which leads to the choice R max(2, δ 1) for some constant C0 > 0. Here, we choose sufficiently large C0 64π to ensure R = Rπ Rµ. Following this control over the perturbation term 6ΓA((1 si1)2 + Qi), to conclude with the result, what remains is proving the comparison t Ti (ai1 ait)sit(γi1 γit) 1 + 0.5π t Ti sit(γi1 γit). (47) To proceed, we split the problem into two scenarios. Scenario 1: p pmm ϵ = π 4AΘ for some ϵ > 0. In this scenario, for any token, we find that |ait at| = |k it ( p pmm)| AΘϵ = π/4. Consequently, we obtain ai1 ait ai1 ait + 2AΘϵ = 1 + 0.5π. Similarly, ai1 ait 1 0.5π 0.5. Since all terms ai1 ait, sit, γi1 γit in (47) are nonnegative and (ai1 ait)sit(γi1 γit) (1 + 0.5π)sit(γi1 γit), above implies the desired result (47). Scenario 2: p pmm ϵ = π 4AΘ. Since p is not (locally) max-margin, in this scenario, for some i [n], ν = ν(ϵ) > 0, and τ Ti, we have that p (ki1 kiτ) = ai1 aiτ 1 2ν. Here τ = arg maxt Ti p kit denotes the nearest point to ki1 (along the p direction). Note that a nonneighbor t Ti cannot be nearest because p coneµ(pmm) and (33) holds. Recall that si = S( Rai) where R = p Θ RΘ. To proceed, let ai := mint Ti ai1 ait, I := n i [n] : ai 1 2ν o , [n] I := n i [n] : 1 2ν < ai o . For all i [n] I,X t Ti (ai1 ait)sit(γi1 γit) (1 + 0.5π) X t Ti sit(γi1 γit) (2A (1 + 0.5π)) Γ X t Ti, ai1 ait 1+ π (2A (1 + 0.5π)) ΓTe R(1+ π 2AΓTe R(1+ π For all i I, split the tokens into two groups: Let Ni be the group of tokens obeying ai1 ait 1 ν and Ti Ni be the rest of the neighbors. Observe that P t Ti Ni sit P t Ti sit T eν R e2ν R = Te Rν. 
Thus, using |ai1 ait| 2A and recalling the definition of γgap min, observe that X t Ti Ni (ai1 ait)sit(γi1 γit) 2ΓATe Rν t Ti sit(γi1 γit). t Ti (ai1 ait)sit(γi1 γit) = X t Ni (ai1 ait)sit(γi1 γit) + X t Ti Ni (ai1 ait)sit(γi1 γit) t Ni (1 ν)sit(γi1 γit) + 2ΓATe Rν t Ti sit(γi1 γit) 1 ν + 2ΓATe Rν t Ti sit(γi1 γit) 1 + 2ΓATe Rν t Ti sit(γi1 γit). Hence, choosing νΘ log 8ΓAT results in that X t Ti (ai1 ait)sit(γi1 γit) (1 + π t Ti sit(γi1 γit) t Ti sit(γi1 γit) t Ti sit(γi1 γit) 4T γgap mine R(1 2ν). Here, the last inequality follows from the fact that P t Ti sit maxt Ti sit e R(1 2ν) PT t=1 e R(ai1 ait) e R(1 2ν)/T. From Assumption A, we have cmin ℓ cmax for some positive constants cmin and cmax. It follows from (48) and (50) that t Ti (ai1 ait)sit(γi1 γit) X t Ti (1 + 0.5π)sit(γi1 γit) cmax2AΓTΓe R(1+ π n T πγgap min 4 e R(1 2ν) Combing with (49), this is guaranteed by choosing 1 νΘ log 8ΓAT , 1 (2ν + π/2)Θ log 8nΓAT 2cmax cminγgap minπ where ν = ν( π 4AΘ) depends only on π and global problem variables. Combining this with the prior R choice (46) (by taking maximum), we conclude with the statement. B.2 Proof of Theorem 1 Proof. This proof is a direct corollary of Lemma 14 which itself is a special case of the nonlinear head Theorem 8. Let us verify that f(X) = v X S(Xp) satisfies the assumptions of Lemma 14 where we replace the nonlinear head with linear v. To see this, set the optimal sets to be the singletons Oi = {opti}. Given (Xi, Yi), let si = S(Kip) and qi = qp i = P t,opti sit. Recalling score definition γi = Yi Xiv and setting νi := γiopti and Zi := P t,opti γitsit, a particular prediction can be written as Yi v X i S(Xi p) = γ i si = γiopti(1 qi) + X t,opti γitsit = νi(1 qi) + Zi. To proceed, we demonstrate the choices for C, ϵ > 0. Let C := mini [n],t [T] γit 0 and qmax = maxi [n] qi. Note that Zi P t,opti γitsit qiγmin Cqmax. 
Now, using strict score optimality of opti s for all i [n], we set ϵ := 1 sup i [n] P t,opti γitsit νiqi 1 sup i [n] supt,opti γit γiopti > 0. We conclude by observing Zi νiqi P t,opti γitsit νiqi νiqiϵ as desired. B.3 Proof of Theorem 2 Proof. We first show that limt p(t) = . From Lemma 4, we have D L(p), pmm E = 1 i=1 ℓ (Yi v X i S(Kip)) D Kipmm , S (ai)γi E , where γi = Yi Xiv and ai = Kip. It follows from Lemma 4 that L(p), pmm < 0 for all p Rd. Hence, for any finite p, L(p), pmm cannot be equal to zero, as a sum of negative terms. Therefore, there are no finite critical points p, for which L(p) = 0. On the other hand, Lemma 6 states L(p(t)) 0 which implies that p(t) . Next, we provide the directional convergence for the setting n = 1. Let us consider an arbitrary value of ϵ (0, 1) and set π = ϵ/(1 ϵ). As limt p(t) = , we can select a specific tϵ such that for all t tϵ, it holds that p(t) Rϵ 1/2 for any choice of Rϵ. To proceed, we choose Rϵ based on Lemma 5 so that for any t tϵ, we have that * L(p(t)), pmm + (1 ϵ) * L(p(t)), p(t) Multiplying both sides by the stepsize η and using the gradient descent update, we get * p(t + 1) p(t), pmm + (1 ϵ) * p(t + 1) p(t), p(t) p(t + 1) 2 p(t) 2 p(t + 1) p(t) 2 (1 ϵ) 1 2 p(t) p(t + 1) 2 p(t) 2 p(t + 1) p(t) 2 ! (1 ϵ) p(t + 1) p(t) p(t + 1) p(t) 2 (1 ϵ) p(t + 1) p(t) 2η (L(p(t)) L(p(t + 1))) . Here, the second inequality is obtained from p(t) 1/2; the third inequality follows since for any a, b > 0, we have (a2 b2)/(2b) (a b) 0; and the last inequality uses Lemma 6. Summing the above inequality over t tϵ gives * p(t) + 1 ϵ + C(ϵ, η) for some finite constant C(ϵ, η) defined as C(ϵ, η) := * p(tϵ), pmm + (1 ϵ) p(tϵ) 2η(1 ϵ) (L(p(tϵ)) L ) , (52) where L denotes the minimum objective. Since p(t) , we get Given that we can choose any value of ϵ (0, 1), we have p(t)/ p(t) pmm / pmm . B.4 Proof of Theorem 3 Proof. Following the proof of Lemma 7, let (Ti)n i=1 denote the sets of SVM-neighbors as defined in Definition 2. 
We define Ti = [T] Ti {αi} as the tokens that are non-SVM neighbors. Additionally, let µ be defined as in (32). Let us denote the initialization lower bound as R0 µ := R, where R is given in the Theorem 3 s statement. Consider an arbitrary value of ϵ (0, µ/2) and let 1/(1 + π) = 1 ϵ. We additionally denote Rϵ Rπ 1/2 where Rπ was defined in Lemma 7(L3.). At initialization p(0), we set ϵ = µ/2 to obtain R0 µ = Rµ/2 and provide the proof in four steps: Step 1: There are no stationary points within Cµ,R0µ(pmm). We begin by proving that there are no stationary points within Cµ,R0µ(pmm). Then, since R0 µ Rµ per Lemma 7, we can apply (L2.) to find that: For all q, p coneµ(pmm) with q , 0 and p R0 µ, we have that q L(p) is strictly positive. Step 2: It follows from Lemma 7(L3.) that, for all ϵ (0, µ/2), all p Cµ,Rϵ(pmm) satisfy * L(p), pmm + (1 ϵ) * L(p), p The argument above applies to a general ϵ (0, µ/2). However, at initialization p(0), we set ϵ = µ/2 to obtain our earlier R0 µ choice. To proceed, for any ϵ (0, µ/2), we will show that after gradient descent enters the conic set Cµ,Rϵ(pmm) for the first time, it will never leave the set. Let tϵ be the first time gradient descent enters Cµ,Rϵ(pmm). In Step 4, we will prove that such tϵ is guaranteed to exist. Additionally, for ϵ µ/2, note that tϵ = 0 i.e. the point of initialization. Step 3: Updates remain inside the cone Cµ,Rϵ(pmm). By leveraging the results from Step 1 and Step 2, we demonstrate that the gradient iterates, with an appropriate constant step size, starting from p(tϵ) Cµ,Rϵ(pmm), remain within this cone. We proceed by induction. Suppose that the claim holds up to iteration t tϵ. This implies that p(t) Cµ,Rϵ(pmm). Hence, recalling cone definition, for µ > 0 and Rϵ, we have D p(t) pmm E 1 µ and p(t) Rϵ. Let ρ(t) := 1 1 ϵ * L(p(t)), pmm Note that ρ(t) > 0 due to Step 1. 
This together with the gradient descent update rule gives * p(t + 1) p(t) η p(t) L(p(t)), pmm * L(p(t)), pmm = 1 µ + ηρ(t)(1 ϵ) Note that from Lemma 7, we have L(p(t)), p(t) < 0 which implies that p(t + 1) p(t) . This together with Rϵ definition and p(t) 1/2 implies that p(t + 1) 1 2 p(t) p(t + 1) 2 + p(t) 2 2 p(t) 2 2η L(p(t)), p(t) + η2 L(p(t)) 2 p(t) η p(t) L(p(t)), p(t) + η2 L(p(t)) 2. Hence, using (53) p(t) 1 η p(t) * L(p(t)), p(t) + + η2 L(p(t)) 2 1 η (1 ϵ) p(t) * L(p(t)), pmm + + η2 L(p(t)) 2 = 1 + ηρ(t) p(t) + η2 L(p(t)) 2 p(t) =: C1(ρ(t), η). Here, the second inequality follows from (53). Now, it follows from (54a) and (54c) that * p(t + 1) p(t + 1) , pmm + 1 C1(ρ(t), η) 1 µ + ηρ(t)(1 ϵ) = 1 µ + 1 C1(ρ(t), η) (1 µ)(1 C1(ρ(t), η)) + ηρ(t)(1 ϵ) = 1 µ + η C1(ρ(t), η) (µ 1)( ρ(t) p(t) + η L(p(t)) 2 p(t) ) + ρ(t)(1 ϵ) = 1 µ + η C1(ρ(t), η) p(t) η(1 µ) L(p(t)) 2 where the last inequality uses our choice of stepsize η 1/Lp in Theorem 3 s statement. Specifically, we need η to be small to ensure the last inequality. We will guarantee this by choosing a proper Rϵ in Lemma 7. Specifically, Lemma 7 leaves the choice of C0 in Rϵ lower bound of (46) open (it can always be chosen larger). Here, by choosing C0 1/Lp will ensure η 1/Lp works well. To proceed, we have that (µ ϵ) 1 µ ρ(t) L(p(t)) 2 µ ϵ 1 µ 1 1 ϵ c C Θ A 1 ACT e R0 µΘ/2 µ 2(1 µ)(1 µ A 1 ACT e R0 µΘ/2 η. (56) Here, the second inequality uses our choice of ϵ (0, µ/2) (see Step 2), and the first inequality is obtained from Lemma 7 since ρ(t) L(p(t)) = 1 1 ϵ L(p(t)) , pmm + 1 1 ϵ c C Θ 1 L(p(t)) 1 AC 1 n Pn i=1 1 siαi 1 ACTe R0µΘ/2 for some data dependent constants c and C, A = maxi [n],t,τ [T] kit kiτ , and Θ = 1/ pmm . Next, we will demonstrate that the choice of η in (56) does indeed meet our step size condition as stated in the theorem, i.e., η 1/Lp. Recall that 1/(1 + π) = 1 ϵ, which implies that π = ϵ/(1 ϵ). 
Combining this with (46), we obtain: Rπ max(2, δ 1) , where C0 64π, Rϵ max(2, δ 1) Θ log (1 ϵ)C0TΓA , where C0 64 ϵ 1 ϵ . On the other hand, at the initialization, we have ϵ = µ/2 which implies that R0 µ max(2, δ 1) Θ log (2 µ)C0TΓA , where C0 64 µ (2 µ). (57) In the following, we will determine a lower bound on C0 such that our step size condition in Theorem 3 s statement, i.e., η 1/Lp, is satisfied. Note that for the choice of η in (56) to meet the condition η 1/Lp, the following condition must hold: 1 Lp µ (2 µ) 1 C2T e R0 µΘ/2 R0 µ 2 µ C2T ! , (58) where C2 = (1 µ) A2C2 This together with (57) implies that for sufficiently large R0 µ max(2, δ 1) Θ log (2 µ)C3T ! , where C3 = C0ΓA γgap min C2 the step size bound in (56) ensures that η 1/Lp guarantees (55). Hence, p(t + 1) remains within the cone, i.e., p(t + 1) Cµ,Rϵ(pmm). Step 4: The correlation of p(t) and pmm increases over t. The remainder is similar to the proof of Theorem 2. From Step 3, we have that all iterates remain within the initial conic set i.e. p(t) Cµ,R0µ(pmm) for all t 0. Note that it follows from Lemma 7 that L(p), pmm/ pmm < 0, for any finite p Cµ,R0µ(pmm). Hence, there are no finite critical points p Cµ,R0µ(pmm), for which L(p) = 0. Now, based on Lemma 6, which guarantees that L(p(t)) 0, this implies that p(t) . Consequently, for any choice of ϵ (0, µ/2) there is a time tϵ such that, for all t tϵ, p(t) Cµ,Rϵ(pmm). Once within Cµ,Rϵ(pmm), following similar steps in (51) and (52), for any t tϵ, * p(t) + 1 ϵ + C2(ϵ, η) p(t) , p(t) Cµ,Rϵ(pmm), for some finite constant C2(ϵ, η). Consequently, + 1 ϵ, where p(t) Cµ,Rϵ(pmm). Since the choice of ϵ (0, µ/2) is arbitrary, we obtain p(t)/ p(t) pmm/ pmm . B.5 Proof of Theorem 4 B.5.1 Supporting Lemma We present a lemma that will aid in simplifying our analysis. We begin with a definition. Definition 3 (Selected-tokens, Neighbors, Margins, and Neighbor-optimality of a direction) Let q Rd {0} and (Yi, Ki, Xi)n i=1 be our dataset. 
We define the (possibly non-unique) selected-tokens of q as follows:3 αi arg max t [T] k it q. (59) Next, we define the margin and directional-neighbors for q as the minimum margin tokens to the selected-tokens, i.e., Γq = min i [n],t,αi(kiαi kit) q, (60) Mq = n (i, t) (kiαi kit) q = Γq o . (61) Finally, we say that q is neighbor-optimal if the scores of its directional-neighbors are strictly less than the corresponding selected-token. Concretely, for all (i, t) Mq, we require that γit = Yi x it v < γiαi = Yi x iαiv. Lemma 8 (When does one direction dominate another?) Suppose q, p Rd be two unit Euclidean norm vectors with identical selected tokens. Specifically, for each i [n], there exists unique αi [T] such that αi = arg maxt [T] k it q = arg maxt [T] k it p. Suppose directional margins obey Γq < Γp and set δΓ = Γp Γq. Suppose q and p are both neighbor-optimal. Then, for some R(δΓ) and all R > R(δΓ), we have that L(R p) < L(R q). Suppose q has a unique directional-neighbor and is not neighbor-optimal (i.e. this neighbor has higher score). Let δq be the margin difference between unique directional-neighbor and the second-most minimum-margin neighbor (i.e. the one after the unique one, see (65)) of q. Then, for some R(δΓ δq) and all R > R(δΓ δq), we have that L(R q) < L(R p). Proof. We prove these two statements in order. First define the directional risk baseline induced by letting R and purely selecting the tokens α = (αi)n i=1. This is given by i=1 ℓ Yi v xiαi . 3If αi is unique for all i [n], let us call it, unique selected tokens. We evaluate q, p with respect to L . To proceed, let si = S(RKiq). Define Γit q = k iαi q k it q. Note that, the smallest value for t , αi is achieved for Γq. For sufficiently large R O(log(T)/Γq), observe that, for t , αi e RΓit q sit = e Rk it q P t [T] e Rk it q 0.5e RΓit q. (62) Recalling the score definition and let M+, M be the upper and lower bounds on ℓ over its bounded domain that scores fall on, respectively. 
Note that, for some intermediate M+ Mi M values, we have L(Rq) L = 1 t [T] sitγit) ℓ(γiαi) t,αi sit(γiαi γit). Now, using (62) for a refreshed M+ Mit 0.5M values, we can write L(Rq) L = 1 t,αi Mite RΓit q(γiαi γit). (63) The same bound also applies to p with some M it, multipliers L(Rp) L = 1 t,αi M ite RΓit p(γiαi γit). (64) We can now proceed with the proof. Case 1: q and p are both neighbor-optimal. This means that γiαi γit > 0 for all i [n], t , αi. Let K+ > K > 0 be upper and lower bounds on γiαi γit values. We can now upper bound the right hand side of (63) via M+K+Te RΓq L(Rq) L 1 2n M K e RΓq. Consequently, L(Rq) > L(Rp) as soon as 1 2n M K e RΓq > M+K+Te RΓp. Since M+, K+, n, T are global constants, this happens under the stated condition on the margin gap Γp Γq. Case 2: q has a unique directional-neighbor and is not neighbor-optimal. In this scenario, L(Rq) L is actually negative for large R. To proceed, define the maximum score difference K+ = supi,t,αi |γiαi γit|. Also let (j, β) be the unique directional neighbor achieving the minimum margin Γq. Then, δq the margin difference between unique directional-neighbor and the second minimum-margin neighbor (i.e. the one after the unique one) of q is defined as δq = min i [n], t,αi, (i,t),(j,β) Γit q Γq. (65) To proceed, we can write L(Rp) L M+K+Te RΓp. On the other hand, setting κ = γjβ γjα j > 0, we can bound L(Rq) L = 1 n M jβe RΓqκ + 1 i [n], t,αi, (i,t),(j,β) Mite RΓit q(γiαi γit) n M e RΓqκ + M+K+Te R(Γq+δq). Consequently, we have found that L(Rp) > L(Rq) as soon as 1 n M e RΓqκ M+K+T(e R(Γq+δq) + e RΓp) This happens when R 1 δq (Γp Γq) (up to logarithmic terms) establishing the desired statement. B.5.2 Proof of Theorem 4 Define the locally-optimal unit directions Pmm = ( pmm(α) α is a locally-optimal set of indices ) . The theorem below shows that cone-restricted regularization paths can only directionally converge to an element of this set. 
Theorem 7 (Non-LOMM Regularization Paths Fail) Fix a unit Euclidean norm vector q Rd such that q < Pmm. Assume that the token scores are distinct (i.e., γit , γiτ for t , τ) and the key embeddings kit are in general position. Specifically, we require the following conditions to hold 4: When m = d, all matrices K Rm d where each row of K has the form kit kiαi for a unique (i, αi, t , αi) tuple, are full-rank. When m = d + 1, the vector of all ones is not in the range space of any such K matrix. Fix arbitrary ϵ > 0, R0 > 0. Define the local regularization path of q as its (ϵ, R0)-conic neighborhood: p(R) = arg min p Cϵ,R0(q), p R L(p), where Cϵ,R0(q) = coneϵ(q) n p Rd p R0 o . (66) Then, either lim R p(R) < or lim R p(R)/ p(R) , q. In both scenarios lim R p(R)/R , q. Proof. We will prove the result by dividing the problem into distinct cases. In each case, we will construct an alternative direction that achieves a strictly better objective than some δ = δ(ϵ) > 0 neighborhood of q, thereby demonstrating the suboptimality of the q direction. Let s define the δ neighborhood as follows: Nδ = ( p p p q δ and p R0 Now, let s recall a few more definitions based on Definition 3. First, the tokens selected by q are given by (59). To proceed, let s initially consider the scenario where αi is unique for all i [n], meaning that as we let c , c q will choose a single token per input. Later, we will revisit the setting when arg max is not a singleton, and q is allowed to select multiple tokens. Additionally, it s important to note that p(R) is non-decreasing by definition. Suppose it has a finite upper bound p(R) M for all R < . In that scenario, we have lim R p(R) (A) q selects a single token per input: Given that the indices α = (αi)n i=1 defined in (59) are uniquely determined, we can conclude that the q direction eventually selects tokens kiαi. 
Recall the definition of the margin Γq from (60) and the set of directional neighbors, which is defined as the indices that achieve Γq, as shown in (61). Let us refer to q as neighbor-optimal if γit < γiαi for all (i, t) Mq. We will consider two cases for this scenario: when q is neighbor-optimal and when q is not neighboroptimal. (A1) q is neighbor-optimal. In this case, we will argue that max-margin direction pmm := pmm(α)/ pmm(α) can be used to construct a strictly better objective than q. Note that pmm(α) exists because q is already a viable separating direction for tokens α. Specifically, consider the direction q = q+ϵ pmm q+ϵ pmm . Observe that, q lies within cone2ϵ(q),5 1 + ϵ 1 2ϵ. We now argue that, there exists δ = δϵ > 0 such that for all R > Rϵ min Rϵ r R L(r q ) < min p Nδ,Rϵ p R L(p). (68) 4This requirement holds for general data because it is guaranteed by adding arbitrarily small independent gaussian perturbations to keys kit. 5As a result, let us prove the result for ϵ 2ϵ without losing generality. To prove this, we study the margin Γq induced by q and the maximum margin Γδ induced within p Nδ. Concretely, we will show that Γq > Γδ and directly apply the first statement of Lemma 8 to conclude with (68). Let Γ = 1/ pmm(α) be the margin induced by pmm. Note that Γ > Γq by the optimality of pmm(α) and the fact that q , pmm(α). Consequently, we can lower and upper bound the margins via Γq = min i [n] min t,αi (kiαi kit) q Γq + ϵΓ 1 + ϵ Γq + ϵ Γδ = max p Nδ min i [n] min t,αi (kiαi kit) p/ p max r 1 min i [n] min t,αi (kiαi kit) (q + δr) where M = maxi,t,τ kit kiτ . Consequently, setting δ = ϵ 4M(Γ Γq), we find that Equipped with this inequality, we apply the first statement of Lemma 8 which concludes that6 for some Rϵ = R( ϵ 4(Γ Γq)) and all R > Rϵ, (68) holds. This in turn implies that, within Cϵ, the optimal solution is either upper bounded by Rϵ in ℓ2 norm (i.e. lim R p(R) < ) or at least δ = δ(ϵ) > 0 away from q after ℓ2-normalization i.e. 
$$\Big\|\frac{\bar p(R)}{\|\bar p(R)\|}-q\Big\|\ge\delta.$$
In either scenario, we have proven that $\lim_{R\to\infty}\bar p(R)/R\neq q$.

(A2) q is not neighbor-optimal. In this scenario, we prove that $\|\bar p(R)\|$ is finite, which yields $\lim_{R\to\infty}\bar p(R)/R=0\neq q$. To start, assume that the conic neighborhood $\epsilon$ of $q$ is small enough that the selected tokens $\alpha$ remain unchanged within $C_\epsilon$. This is without loss of generality, because if directional convergence fails in a small neighborhood of $q$, it also fails in a larger one. Secondly, if $\|\bar p(R)\|\to\infty$ and $\bar p(R)\in C_\epsilon$, then, since softmax eventually selects $\alpha$ perfectly (i.e., assigns probability 1 to the token indices $(i,\alpha_i)$), we would have
$$\lim_{R\to\infty}\mathcal L(\bar p(R))=\mathcal L_\star=\frac1n\sum_{i=1}^n\ell(\gamma_{i\alpha_i}).$$
Note that this follows simply from the selection of $\alpha$, regardless of whether $\bar p(R)$ directionally converges to $q$. Hence, if there exists a finite $p\in C_\epsilon$ such that $\mathcal L(p)<\mathcal L_\star$ (i.e., outperforming the limiting training loss of $\bar p(R)$), then $\|\bar p(R)\|<\infty$, which would conclude the proof. Thus, it suffices to find such a $p$ obeying $\mathcal L(p)<\mathcal L_\star$. To this aim, we first prove the following lemma.

Lemma 9 Given a fixed unit Euclidean norm vector $p$, if all directional neighbors of $p$ consistently have higher scores than their associated selected tokens, i.e., $\gamma_{i\alpha_i}<\gamma_{i\beta}$ for all $(i,\alpha_i)$ and directional neighbors $(i,\beta)$, then there exists $\bar R$ such that for all $R>\bar R$,
$$\mathcal L(Rp)<\mathcal L_\star=\lim_{R\to\infty}\mathcal L(Rp).\qquad(69)$$
Proof. Define the maximum score difference $K_+=\sup_{i\in[n],\,t\neq\alpha_i}|\gamma_{i\alpha_i}-\gamma_{it}|$. Also let $\mathcal M_p$ be the set of directional neighbors achieving the minimum margin $\Gamma_p$; see (61). Define $\Gamma_{it}=k_{i\alpha_i}^\top p-k_{it}^\top p$. Define $\delta_p$ to be the margin difference between the directional neighbors and the neighbors with the second-smallest margin, namely $\delta_p=\min_{i\in[n],\,t\neq\alpha_i,\,(i,t)\notin\mathcal M_p}\Gamma_{it}-\Gamma_p$ […] $>0$ and using (63), we can bound […] $>d$, the equality $Dq=\Gamma_q\mathbf 1_M$ cannot be satisfied, because $\mathbf 1_M$ is not in the range space of $D$ by our assumption that the key embeddings are in general position. To proceed, let $|\mathcal M_q|\le d$ and let $D$ be as defined in the lemma above. $D$ is also full-rank by our assumption of general key positions.
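The perturbations used in cases (A2) and (B) rely on pseudo-inverting a full-row-rank matrix $D$ (for example, $q'=D^\dagger\mathbf 1$ in case (B)). When $D$ has linearly independent rows, $D^\dagger b=D^\top(DD^\top)^{-1}b$ is the minimum-norm vector solving $Dx=b$. The dependency-free Python sketch below computes it with Gaussian elimination. The toy matrix `D` and right-hand side `b` are hypothetical stand-ins for the key-difference rows and for an indicator vector such as $\mathbf 1^+$.

```python
from typing import List

Matrix = List[List[float]]

def solve(A: Matrix, b: List[float]) -> List[float]:
    """Solve A y = b for a square nonsingular A (Gaussian elimination, partial pivoting)."""
    n = len(A)
    M = [list(row) + [bi] for row, bi in zip(A, b)]  # augmented matrix [A | b]
    for col in range(n):
        piv = max(range(col, n), key=lambda r: abs(M[r][col]))
        if abs(M[piv][col]) < 1e-12:
            raise ValueError("singular system: rows of D are not linearly independent")
        M[col], M[piv] = M[piv], M[col]
        for r in range(col + 1, n):
            f = M[r][col] / M[col][col]
            for c in range(col, n + 1):
                M[r][c] -= f * M[col][c]
    y = [0.0] * n
    for r in range(n - 1, -1, -1):  # back substitution
        y[r] = (M[r][n] - sum(M[r][c] * y[c] for c in range(r + 1, n))) / M[r][r]
    return y

def min_norm_solution(D: Matrix, b: List[float]) -> List[float]:
    """D^+ b = D^T (D D^T)^{-1} b, the least-norm x with D x = b (D must have full row rank)."""
    gram = [[sum(x * y for x, y in zip(ri, rj)) for rj in D] for ri in D]
    y = solve(gram, b)
    return [sum(D[r][c] * y[r] for r in range(len(D))) for c in range(len(D[0]))]

# Hypothetical key-difference rows in d = 3 and an indicator-style right-hand side.
D = [[1.0, -1.0, 0.0], [0.0, 1.0, -1.0]]
b = [1.0, 0.0]
q_perp = min_norm_solution(D, b)  # approximately [2/3, -1/3, -1/3]
print(q_perp)
```

Because the result lies in the row space of $D$, it is orthogonal to every null-space direction (here $(1,1,1)$), which is what makes it the minimum-norm choice.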
We use $D$ to construct a perturbation as follows. Let $\mathcal M_q^+\subseteq\mathcal M_q$ be the set of directional neighbors with strictly higher scores than their associated selected tokens. In other words, every $(j,\beta)\in\mathcal M_q^+$ obeys $\gamma_{j\beta}>\gamma_{j\alpha_j}$. Define the score difference $\kappa=\min_{(j,\beta)\in\mathcal M_q^+}\gamma_{j\beta}-\gamma_{j\alpha_j}$. We know $\kappa>0$ because $\alpha$ is not neighbor-optimal. Next, define the indicator vector $\mathbf 1^+$ with dimension equal to the cardinality $|\mathcal M_q|$: $\mathbf 1^+$ equals 1 on the rows of $D$ corresponding to $\mathcal M_q^+$ and 0 otherwise. Finally, set the perturbation $q'$ as […] where we used the full-rankness of $D$ during pseudo-inversion.

To proceed, for a small $\epsilon_0>0$, consider the candidate direction $q_0=q+\epsilon_0q'$. We pick $\epsilon_0=O(\epsilon)$ sufficiently small to ensure $q_0\in C_\epsilon$. To finalize, let us consider the margins of the tokens within $q_0$. Similar to Lemma 9, set $\delta_q=\min_{i\in[n],\,t\neq\alpha_i,\,(i,t)\notin\mathcal M_q}\Gamma_{it}-\Gamma_q>0$. Let $\epsilon_0'=\|q_0\|-1=\|q+\epsilon_0q'\|-1$. Using the definition of $q'$, we have that
• For $(i,t)\in\mathcal M_q^+$, we achieve a margin of $(k_{i\alpha_i}-k_{it})^\top(q+\epsilon_0q')/(1+\epsilon_0')=\frac{\Gamma_q-\epsilon_0}{1+\epsilon_0'}$.
• For $(i,t)\in\mathcal M_q\setminus\mathcal M_q^+$, we achieve a margin of $\frac{\Gamma_q}{1+\epsilon_0'}$.
• For $(i,t)\notin\mathcal M_q$, setting $K=\|q'\|\sup_{i,t}\|k_{i\alpha_i}-k_{it}\|$, we achieve a margin of at least
$$(k_{i\alpha_i}-k_{it})^\top(q+\epsilon_0q')/(1+\epsilon_0')\ge\frac{\Gamma_q+\delta_q-\epsilon_0K}{1+\epsilon_0'}.$$
In short, since $\epsilon_0'=O(\epsilon_0)$, choosing $\epsilon_0$ sufficiently small guarantees that $\mathcal M_q^+$ is the set of directional neighbors of $q_0$. Since the tokens in $\mathcal M_q^+$ have strictly higher scores than their associated selected tokens, applying Lemma 9 to $q_0/\|q_0\|$ shows that $\mathcal L(Rq_0/\|q_0\|)<\mathcal L_\star$ for sufficiently large $R$, implying $\|\bar p(R)\|<\infty$.

(B) q selects multiple tokens for some inputs $i\in[n]$: In this setting, we again construct a perturbation so that $q_0=q+\epsilon q'$ selects a single token for each input $i\in[n]$. We then employ a margin analysis (the first statement of Lemma 8) to conclude that $q_0$ outperforms a $\delta\le\epsilon$ neighborhood of $q$. Let $\mathcal I\subseteq[n]$ be the set of inputs for which $q$ selects multiple tokens. Specifically, for each $i\in\mathcal I$, there is $\mathcal T_i\subseteq[T]$ such that $|\mathcal T_i|\ge2$ and, for any $i\in\mathcal I$ and $\theta\in\mathcal T_i$, $k_{i\theta}^\top q=\max_{t\in[T]}k_{it}^\top q$.
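From these multiply-selected token indices, let us select the one with the highest score, namely $\beta_i=\arg\max_{\theta\in\mathcal T_i}\gamma_{i\theta}$ for $i\in\mathcal I$. Now, define the unique optimal tokens for each input as $\alpha\in\mathbb R^n$, where $\alpha_i:=\beta_i$ for $i\in\mathcal I$ and $\alpha_i=\arg\max_{t\in[T]}k_{it}^\top q$ for $i\notin\mathcal I$. Define $\mathcal L_\star=\frac1n\sum_{i=1}^n\ell(\gamma_{i\alpha_i})$ as earlier.

Secondly, we construct a perturbation $q'$ to show that $q_0=q+\epsilon_0q'$ selects the tokens $\alpha$ asymptotically. To see this, define the matrix $D$ whose (unique) rows are given by $k_{i\alpha_i}-k_{i\theta}$ for $\theta\in\mathcal T_i$, $\theta\neq\alpha_i$, $i\in\mathcal I$. Note that $(k_{i\alpha_i}-k_{i\theta})^\top q=0$ for all such rows. Since the keys are in general position, this implies that $D$ has at most $d-1$ rows and, thus, its rows are linearly independent. Consequently, choose $q'=D^\dagger\mathbf 1$, where $\dagger$ denotes the pseudo-inverse and $\mathbf 1$ is the all-ones vector. Also let $\Gamma_q$ be the directional margin of $q$, that is, $\Gamma_q=\min_{i\in[n],\,t$ […] $\mathcal L_\star$ because $\ell\big(\frac{1}{|\mathcal T_i|}\sum_{\theta\in\mathcal T_i}\gamma_{i\theta}\big)>\ell(\gamma_{i\beta_i})=\ell(\gamma_{i\alpha_i})$, where $\alpha_i$ has the highest score, i.e., $\gamma_{i\alpha_i}>\frac{1}{|\mathcal T_i|}\sum_{\theta\in\mathcal T_i}\gamma_{i\theta}$.

Set $\bar p(R)=\arg\min_{p\in\mathcal N_\delta,\,\|p\|\le R}\mathcal L(p)$. There are two scenarios:
• $\lim_{R\to\infty}\|\bar p(R)\|$ is finite. This already proves the statement of the theorem, as $\bar p(R)/R\to0$ within the $\delta<\epsilon$ neighborhood of $q$.
• For sufficiently large $R$, the selected tokens of $\bar p(R)$ are $\alpha=(\alpha_i)_{i=1}^n$.
Proceeding with the second (remaining) scenario, we study the directional margin of $\bar p(R)$. More broadly, for any $p\in\mathcal N_\delta$ and $\tilde p=p/\|p\|$ with selected tokens $\alpha$, using the fact that $\|\tilde p-q\|\le\delta\le\epsilon_0$, we can bound the directional margin as follows:
• For $\theta\in\mathcal T_i$, $\theta\neq\alpha_i$: $(k_{i\alpha_i}-k_{i\theta})^\top\tilde p=(k_{i\alpha_i}-k_{i\theta})^\top(\tilde p-q)\le K\delta$.
• For all other $(i,t)$ with $t\neq\alpha_i$: $(k_{i\alpha_i}-k_{it})^\top\tilde p\ge\Gamma_q-K\delta$.
This means that any such $p\in\mathcal N_\delta$ achieves a directional margin of at most $K\delta$. Applying Lemma 8 and setting $\delta=O(\epsilon_0)$, this implies that for
$$R\gtrsim\frac{1}{\Gamma_{q_0}-\Gamma_{\tilde p}}=\frac{1}{\frac{\epsilon_0}{1+\epsilon_0\|q'\|}-K\delta}=O\Big(\frac{1}{\epsilon_0}\Big),$$
we have $\mathcal L(Rq_0)<\min_{\|p\|=R,\,p\in\mathcal N_\delta}\mathcal L(p)$. Since this holds for all such $R$, (68) holds (similar to Case (A1)), and we conclude that whenever $\|\bar p(R)\|\to\infty$, it does not directionally converge within $\mathcal N_\delta$ (i.e., the $\delta>0$ neighborhood of $q$), proving the advertised result.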
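Case (A1) above uses the fact that $q'=(q+\epsilon\bar p^{mm})/\|q+\epsilon\bar p^{mm}\|$ stays inside $\mathrm{cone}_{2\epsilon}(q)$. For unit vectors $q,\bar p^{mm}$ and $0<\epsilon<1$, one way to see this is that $\langle q+\epsilon\bar p^{mm},q\rangle\ge1-\epsilon$ and $\|q+\epsilon\bar p^{mm}\|\le1+\epsilon$, so $\langle q',q\rangle\ge\frac{1-\epsilon}{1+\epsilon}\ge1-2\epsilon$. The Python sketch below checks both inequalities on random unit vectors. It assumes that $\mathrm{cone}_\epsilon(q)$ means "correlation with $q$ at least $1-\epsilon$", and the random vector `p` only stands in for $\bar p^{mm}$.

```python
import math
import random
from typing import List, Tuple

def normalize(x: List[float]) -> List[float]:
    n = math.sqrt(sum(v * v for v in x))
    return [v / n for v in x]

def cone_slacks(d: int = 5, trials: int = 2000, seed: int = 0) -> Tuple[float, float]:
    """Smallest observed slack in <q', q> >= (1-eps)/(1+eps) and in (1-eps)/(1+eps) >= 1-2eps."""
    rng = random.Random(seed)
    worst_corr = worst_lin = math.inf
    for _ in range(trials):
        q = normalize([rng.gauss(0.0, 1.0) for _ in range(d)])
        p = normalize([rng.gauss(0.0, 1.0) for _ in range(d)])  # stands in for p_mm / ||p_mm||
        eps = rng.uniform(0.0, 0.99)                              # the argument needs eps < 1
        q_new = normalize([a + eps * b for a, b in zip(q, p)])
        corr = sum(a * b for a, b in zip(q_new, q))
        lower = (1.0 - eps) / (1.0 + eps)
        worst_corr = min(worst_corr, corr - lower)
        worst_lin = min(worst_lin, lower - (1.0 - 2.0 * eps))
    return worst_corr, worst_lin

print(cone_slacks())  # both slacks should be nonnegative up to rounding
```

The restriction $\epsilon<1$ matters: for larger $\epsilon$, the numerator $1+\epsilon\langle\bar p^{mm},q\rangle$ can turn negative, and dividing by the norm bound no longer gives a lower bound.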
B.6 Proof of Lemma 2

We prove a slightly more general restatement, in which we require $v\in\mathrm{range}(W^\top)$ instead of a full-rank $W$.

Lemma 11 Suppose that for all $i\in[n]$ and $t\neq\mathrm{opt}_i$, $Y_i=1$ and $\gamma_{it}<\gamma_{i\,\mathrm{opt}_i}$. Also suppose $v\in\mathrm{range}(W^\top)$. Then $p^{mm}$ exists, i.e., (ATT-SVM) is feasible for the optimal indices $\alpha_i\leftarrow\mathrm{opt}_i$.

Proof. To establish the existence of $p^{mm}$, we only need to find a direction that demonstrates the feasibility of (ATT-SVM), i.e., a $p$ that satisfies the margin constraints. To begin, define the minimum score difference
$$\gamma=\min_{i\in[n],\,t\neq\mathrm{opt}_i}\gamma_{i\,\mathrm{opt}_i}-\gamma_{it}.$$
We then set $p=\gamma^{-1}(W^\top)^\dagger v$, where $\dagger$ denotes the pseudo-inverse. Since $v\in\mathrm{range}(W^\top)$, we have $W^\top p=\gamma^{-1}v$. To conclude, observe that $p$ is feasible since $k_{it}=Wx_{it}$ and, for all $i\in[n]$ and $t\neq\mathrm{opt}_i$,
$$(k_{i\,\mathrm{opt}_i}-k_{it})^\top p=(x_{i\,\mathrm{opt}_i}-x_{it})^\top W^\top p=\gamma^{-1}(x_{i\,\mathrm{opt}_i}-x_{it})^\top W^\top(W^\top)^\dagger v=\gamma^{-1}(x_{i\,\mathrm{opt}_i}-x_{it})^\top v\ge1,$$
where the final inequality follows from $Y_i=1$ and the definition of $\gamma$. Hence $p$ satisfies the constraints of (ATT-SVM), which completes the proof.

C Addendum to Section 3

C.1 Proof of Theorem 5

Proof. Suppose the claim is incorrect and either $p_R/R$ or $v_r/r$ fails to converge as $R,r$ grow. Set $\Xi=1/\|p^{mm}\|$, $\Gamma=1/\|v^{mm}\|$, $\bar p^{mm}=R\Xi p^{mm}$, and $\bar v^{mm}=r\Gamma v^{mm}$. The proof strategy is to obtain a contradiction by proving that $(\bar v^{mm},\bar p^{mm})$ is a strictly better solution than $(v_r,p_R)$ for large $R,r$. Without loss of generality, we set $\alpha_i=1$ for all $i\in[n]$, as the problem is invariant to token permutations. Define $q_i^p=1-s_{i1}^p$ to be the amount of non-optimality (the cumulative probability of the non-first tokens), where $s_i^p=\mathbb S(K_ip)$ are the softmax probabilities.

Case 1: $p_R/R$ does not converge. Under this scenario, there exist $\delta,\gamma=\gamma(\delta)>0$ such that we can find arbitrarily large $R$ with $\|p_R/R-\Xi p^{mm}\|\ge\delta$, and the margin induced by $p_R/R$ is at most $\Xi(1-\gamma)$ (from the strong convexity of (ATT-SVM)). Following the definition of $q_i^p$ above, set $\hat q_{\max}=\sup_{i\in[n]}q_i^{p_R}$ to be the worst non-optimality in $p_R$, and $q^\star_{\max}=\sup_{i\in[n]}q_i^{\bar p^{mm}}$ to be the same for $\bar p^{mm}$.
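The bounds (71) and (72), and their counterparts in (84), sandwich the non-optimality of a single input between $T^{-1}e^{-m}$ and $Te^{-m}$, where $m=\min_{t\neq\alpha}(k_\alpha-k_t)^\top p\ge0$ is the margin of the selected token. The small pure-Python sketch below evaluates both sides as $p=R\,u$ grows. The toy keys and the direction `u` are illustrative choices, not data from the paper.

```python
import math
from typing import List, Sequence

def softmax(z: Sequence[float]) -> List[float]:
    m = max(z)                               # shift logits for numerical stability
    e = [math.exp(v - m) for v in z]
    total = sum(e)
    return [v / total for v in e]

def dot(a: Sequence[float], b: Sequence[float]) -> float:
    return sum(x * y for x, y in zip(a, b))

def non_optimality(keys, p, alpha):
    """q = sum_{t != alpha} S(K p)_t, summed directly to avoid cancellation in 1 - s_alpha."""
    s = softmax([dot(k, p) for k in keys])
    return sum(st for t, st in enumerate(s) if t != alpha)

def margin(keys, p, alpha):
    """m = min_{t != alpha} (k_alpha - k_t)^T p."""
    z = [dot(k, p) for k in keys]
    return min(z[alpha] - zt for t, zt in enumerate(z) if t != alpha)

def sandwich(keys, p, alpha):
    """Return (T^{-1} e^{-m}, q, T e^{-m}); the lower bound assumes m >= 0."""
    T, m = len(keys), margin(keys, p, alpha)
    return math.exp(-m) / T, non_optimality(keys, p, alpha), T * math.exp(-m)

keys = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]]   # token 0 is selected along u = (1, 0)
u = [1.0, 0.0]
for R in (1.0, 5.0, 10.0, 20.0):
    lo, q, hi = sandwich(keys, [R * c for c in u], 0)
    print(f"R={R:5.1f}  {lo:.3e} <= q={q:.3e} <= {hi:.3e}")
```

Summing the non-selected probabilities directly, rather than computing $1-s_\alpha$, keeps $q$ accurate even when it is on the order of $e^{-20}$.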
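Repeating the identical argument as in Theorem 8 (specifically (84)), we can bound the non-optimality $q_i^{\bar p^{mm}}$ of $\bar p^{mm}$ as
$$q_i^{\bar p^{mm}}=\frac{\sum_{t\neq\alpha_i}\exp(k_{it}^\top\bar p^{mm})}{\sum_{t\in[T]}\exp(k_{it}^\top\bar p^{mm})}\le\frac{\sum_{t\neq\alpha_i}\exp(k_{it}^\top\bar p^{mm})}{\exp(k_{i\alpha_i}^\top\bar p^{mm})}\le T\exp(-R\Xi).\qquad(71)$$
Thus, $q^\star_{\max}=\max_{i\in[n]}q_i^{\bar p^{mm}}\le T\exp(-R\Xi)$. Next, without loss of generality, assume that the first margin constraint is $\gamma$-violated by $p_R$, i.e., $\min_{t\neq\alpha_1}(k_{1\alpha_1}-k_{1t})^\top p_R\le\Xi R(1-\gamma)$. Denoting the non-optimality of the first input by $q_1^{p_R}$, we find
$$q_1^{p_R}=\frac{\sum_{t\neq\alpha_1}\exp(k_{1t}^\top p_R)}{\sum_{t\in[T]}\exp(k_{1t}^\top p_R)}\ge\frac1T\,\frac{\sum_{t\neq\alpha_1}\exp(k_{1t}^\top p_R)}{\exp(k_{1\alpha_1}^\top p_R)}\ge T^{-1}\exp(-(1-\gamma)R\Xi).\qquad(72)$$
We similarly have $q^\star_{\max}\ge T^{-1}\exp(-R\Xi)$, and thus
$$\log(\hat q_{\max})\ge-(1-\gamma)\Xi R-\log T,\qquad-\Xi R-\log T\le\log(q^\star_{\max})\le-\Xi R+\log T.\qquad(73)$$
In words, $\bar p^{mm}$ contains exponentially less non-optimality than $p_R$ as $R$ grows. The remainder of the proof differs from Theorem 8, as we need to upper bound the logistic loss of $(\bar v^{mm},\bar p^{mm})$ and lower bound that of $(v_r,p_R)$ to reach the contradiction.

First, let us upper bound the logistic loss of $(\bar v^{mm},\bar p^{mm})$. Set $r_i=X_i^\top\mathbb S(K_i\bar p^{mm})$. Observe that if $\|r_i-x_{i1}\|\le\epsilon_i$, then $v^{mm}$ satisfies the SVM constraints on $r_i$ with $Y_ir_i^\top v^{mm}\ge1-\epsilon_i/\Gamma$. Consequently, setting $\epsilon_{\max}=\sup_{i\in[n]}\epsilon_i$, the direction of $\bar v^{mm}$ achieves a label margin of $\Gamma-\epsilon_{\max}$ on the dataset $(Y_i,r_i)_{i\in[n]}$. With this, we upper bound the logistic loss of $(\bar v^{mm},\bar p^{mm})$ as follows. Let $M=\sup_{i\in[n],\,t,\tau\in[T]}\|x_{it}-x_{i\tau}\|$. Recall from (73) that the worst-case perturbation satisfies $\epsilon_{\max}\le M\exp(-\Xi R+\log T)=MT\exp(-\Xi R)$. This implies that
$$\mathcal L(\bar v^{mm},\bar p^{mm})\le\max_{i\in[n]}\log(1+\exp(-Y_ir_i^\top\bar v^{mm}))\le\max_{i\in[n]}\exp(-Y_ir_i^\top\bar v^{mm})\le\exp(-r\Gamma+r\epsilon_{\max})\le e^{rMT\exp(-\Xi R)}e^{-r\Gamma}.\qquad(74)$$
Conversely, we obtain a lower bound for $(v_r,p_R)$. Set $r_i=X_i^\top\mathbb S(K_ip_R)$. Using Assumption C, we find that solving (SVM) on $(Y_i,r_i)_{i\in[n]}$ achieves a margin of at most $\Gamma-\nu e^{-(1-\gamma)\Xi R}/T$. Consequently, we have
$$\mathcal L(v_r,p_R)\ge\frac1n\max_{i\in[n]}\log(1+\exp(-Y_ir_i^\top v_r))\ge\frac1{2n}\min\Big\{\max_{i\in[n]}\exp(-Y_ir_i^\top v_r),\log2\Big\}\ge\frac1{2n}\min\Big\{\exp\big(-r(\Gamma-\nu e^{-(1-\gamma)\Xi R}/T)\big),\log2\Big\}=\frac1{2n}\min\Big\{e^{r(\nu/T)\exp(-(1-\gamma)\Xi R)}e^{-r\Gamma},\log2\Big\}.\qquad(75)$$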
Observe that this lower bound dominates the previous upper bound when $R$ is large, namely, when (ignoring the multiplier $1/2n$ for brevity)
$$(\nu/T)e^{-(1-\gamma)\Xi R}\ge MTe^{-\Xi R}\iff R\ge R_0:=\frac{1}{\gamma\Xi}\log\frac{MT^2}{\nu}.$$
Thus, we indeed obtain the desired contradiction, since such large $R$ is guaranteed to exist when $p_R/R\not\to\Xi p^{mm}$.

Case 2: $v_r/r$ does not converge. This is the simpler scenario: there exists $\delta>0$ such that we can find arbitrarily large $r$ obeying $\|v_r/r-v^{mm}/\|v^{mm}\|\|\ge\delta$. If $\|p_R/R-\Xi p^{mm}\|\not\to0$, then Case 1 applies. Otherwise, $\|p_R/R-\Xi p^{mm}\|\to0$, so we can assume $\|p_R/R-\Xi p^{mm}\|\le\epsilon$ for an arbitrary choice of $\epsilon>0$.

On the other hand, due to the strong convexity of (SVM), for some $\gamma:=\gamma(\delta)>0$, $v_r$ achieves a margin of at most $(1-\gamma)\Gamma r$ on the dataset $(Y_i,x_{i1})_{i\in[n]}$. Additionally, since $\|p_R/R-\Xi p^{mm}\|\le\epsilon$, $p_R$ strictly separates all optimal tokens (for small enough $\epsilon>0$), and $\hat q_{\max}:=f(\epsilon)\to0$ as $R\to\infty$. Consequently, setting $r_i=X_i^\top\mathbb S(K_ip_R)$ and $M=\sup_{i\in[n],\,t\in[T]}\|x_{it}\|$, for sufficiently large $R>0$ we have
$$\min_{i\in[n]}Y_iv_r^\top r_i\le\min_{i\in[n]}Y_iv_r^\top x_{i1}+\sup_{i\in[n]}|v_r^\top(r_i-x_{i1})|\le(1-\gamma)\Gamma r+Mf(\epsilon)r\le(1-\gamma/2)\Gamma r.\qquad(76)$$
This in turn implies that the logistic loss is lower bounded (following (75)) by $\mathcal L(v_r,p_R)\ge\frac1{2n}\min\{e^{\gamma\Gamma r/2}e^{-\Gamma r},\log2\}$. Going back to (74), this exponentially dominates the upper bound for $(\bar v^{mm},\bar p^{mm})$ whenever $rMT\exp(-\Xi R)<r\gamma\Gamma/2$ (that is, whenever $R,r$ are sufficiently large), again concluding the proof.

C.2 Proof of Theorem 6

We prove this result in two steps. Our first claim restricts the optimization to the particular quadrant induced by $\min_{t\neq\alpha_i}(k_{i\alpha_i}-k_{it})^\top p_R\ge0$ under the theorem's condition $\mathbb S(K_ip_R)_{\alpha_i}\to1$.

Lemma 12 Suppose $\mathbb S(K_ip_R)_{\alpha_i}\to1$. Then there exists $R_0$ such that for all $R\ge R_0$,
$$\min_{t\neq\alpha_i}(k_{i\alpha_i}-k_{it})^\top p_R\ge0\quad\text{for all }i\in[n].\qquad(77)$$
Proof. Suppose the claim does not hold. Set $s_i^R=\mathbb S(K_ip_R)$. Fix $R_0$ such that $s_{i\alpha_i}^R\ge0.9$ for all $R\ge R_0$. On the other hand, there exist arbitrarily large $R$ for which $(k_{i\alpha_i}-k_{it})^\top p_R<0$ for some $t\neq\alpha_i\in[T]$ and $i\in[n]$. For these choices of $(R,i,t)$, we have $s_{it}^R>s_{i\alpha_i}^R$.
Since $s_{it}^R+s_{i\alpha_i}^R\le1$, we find $s_{i\alpha_i}^R<0.5$, which contradicts $s_{i\alpha_i}^R\ge0.9$.

Let $\mathcal Q$ be the set of $p$ satisfying the quadrant constraint (77), i.e., under which the indices $(\alpha_i)_{i=1}^n$ are selected. Let $h_R$ be the solution of the regularization path of $(v,p)$ subject to the constraint $p\in\mathcal Q$. From Lemma 12, we know that, for some $R_0$ and all $R\ge R_0$, $h_R=p_R$. Thus, if the limit exists, $\lim_{R\to\infty}h_R/R=\lim_{R\to\infty}p_R/R$. To proceed, we prove that $\lim_{R\to\infty}h_R/R$ exists and equals $p^{relax}/\|p^{relax}\|$, and simultaneously establish $v_r/r\to v^{mm}/\|v^{mm}\|$.

Lemma 13 $\lim_{R\to\infty}h_R/R=p^{relax}/\|p^{relax}\|$ and $\lim_{r\to\infty}v_r/r=v^{mm}/\|v^{mm}\|$.

Proof. The proof is similar to that of Theorem 5. As usual, we aim to show that the SVM solutions constitute the most competitive directions. Set $\Xi=1/\|p^{relax}\|$.

Case 1: $h_R/R$ does not converge. Under this scenario, there exist $\delta,\gamma=\gamma(\delta)>0$ such that we can find arbitrarily large $R$ with $\|h_R/R-\Xi p^{relax}\|\ge\delta$. This implies that the margin induced by $h_R/R$ over the support vectors $\mathcal S$ is at most $\Xi(1-\gamma)$ (from the strong convexity of (10)). The reason is that $h_R$ satisfies $h_R^\top(k_{i\alpha_i}-k_{it})\ge0$ for all $t\neq\alpha_i$ by construction, as $h_R\in\mathcal Q$. Thus, a constraint over the support vectors has to be violated (when normalized to the same $\ell_2$ norm as $p^{relax}$, namely $\|p^{relax}\|=1/\Xi$). As usual, we construct a solution strictly superior to $h_R$, contradicting its optimality.

Construction of competitor: Rather than using the $p^{relax}$ direction, we choose a slightly deviating direction that ensures the selection of the correct tokens over the non-supports $\bar{\mathcal S}$. Specifically, consider the solution of (10) in which we tighten the non-support constraints by an arbitrarily small $\epsilon>0$:
$$p^{\epsilon\text{-rlx}}=\arg\min_p\|p\|\quad\text{such that}\quad p^\top(k_{i\alpha_i}-k_{it})\ge\begin{cases}1&\text{for all }t\neq\alpha_i,\ i\in\mathcal S,\\\epsilon&\text{for all }t\neq\alpha_i,\ i\in\bar{\mathcal S}.\end{cases}\qquad(78)$$
Let $p^{mm}$ be the solution of (ATT-SVM) with $\alpha=(\alpha_i)_{i=1}^n$ (which was assumed to be separable). Observe that $p^{mm}_\epsilon=\epsilon p^{mm}+(1-\epsilon)p^{relax}$ satisfies the constraints of (78). Additionally, over the support vectors, $p^{mm}_\epsilon$ achieves a margin of at least $\frac{1}{(1-\epsilon)/\Xi+\epsilon/\Xi'}=\Xi-O(\epsilon)$, where $\Xi'=1/\|p^{mm}\|$.
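Using the optimality of $p^{\epsilon\text{-rlx}}$, this implies that the reduced margin $\Xi_\epsilon=1/\|p^{\epsilon\text{-rlx}}\|$ over the support vectors (obtained by enforcing $\epsilon$ over the non-supports) is a Lipschitz function of $\epsilon$. That is, $\Xi_\epsilon\ge\Xi-\epsilon M$ for some $M\ge0$. To proceed, choose $\epsilon>0$ such that this margin is strictly superior to the margin induced by $h_R$, that is, $\Xi_\epsilon\ge\Xi(1-\gamma/2)>\Xi(1-\gamma)$.

Next, set $\bar p^{\epsilon\text{-rlx}}=R\Xi_\epsilon p^{\epsilon\text{-rlx}}$. Recall the following notation from the proof of Theorem 5: $s_i^p=\mathbb S(K_ip)$ and $q_i^p=1-s_{i\alpha_i}^p$. Set $\hat q_{\max}=\max_{i\in\mathcal S}q_i^{h_R}$ to be the worst non-optimality of $h_R$ over the support set. Similarly, define $q^\star_{\max}=\max_{i\in\mathcal S}q_i^{\bar p^{\epsilon\text{-rlx}}}$ to be the same for $\bar p^{\epsilon\text{-rlx}}$. Repeating the identical arguments of (71), (72), and (73), and using the fact that $\bar p^{\epsilon\text{-rlx}}$ achieves a margin $\Xi(1-\gamma/2)\le\Xi_\epsilon\le\Xi$, we obtain
$$\log(\hat q_{\max})\ge-(1-\gamma)\Xi R-\log T,\qquad(79a)$$
$$-\Xi R-\log T\le\log(q^\star_{\max})\le-\Xi(1-0.5\gamma)R+\log T.\qquad(79b)$$
In what follows, we prove that $\bar p^{\epsilon\text{-rlx}}$ achieves a strictly smaller logistic loss, contradicting the optimality of $p_R$ (whenever $\|h_R/R-\Xi p^{relax}\|\ge\delta$).

Upper bounding the logistic loss. Let us now upper bound the logistic loss of $(\bar v^{mm},\bar p^{\epsilon\text{-rlx}})$, where $\bar v^{mm}=r\Gamma v^{mm}$, with $v^{mm}$ the solution of (SVM) with $r_i\leftarrow x_{i\alpha_i}$ and $\Gamma=1/\|v^{mm}\|$. Set $r_i=X_i^\top\mathbb S(K_i\bar p^{\epsilon\text{-rlx}})$. Set $\upsilon=\min_{i\in\bar{\mathcal S}}Y_ix_{i\alpha_i}^\top v^{mm}-1$ to be the additional margin buffer available to the non-support vectors. Also set $M=\sup_{i\in[n],\,t,\tau\in[T]}\|x_{it}-x_{i\tau}\|$. Observe that
$$\|x_{i\alpha_i}-r_i\|=\Big\|\sum_{t\neq\alpha_i}s_{it}(x_{i\alpha_i}-x_{it})\Big\|\le q_iM.$$

Non-supports achieve a strong label margin: Using the above and (78), for all $i\in\bar{\mathcal S}$ and $t\neq\alpha_i$, we have $s_{it}\le e^{-\epsilon\Xi_\epsilon R}s_{i\alpha_i}\le e^{-\epsilon\Xi(1-\gamma/2)R}s_{i\alpha_i}$. Consequently, whenever $R\ge R_0:=(\epsilon\Xi(1-\gamma/2))^{-1}\log\big(\frac{TM}{\Gamma\upsilon}\big)$, we have
$$q_i=\sum_{t\neq\alpha_i}s_{it}\le\sum_{t\neq\alpha_i}\frac{s_{it}}{s_{i\alpha_i}}\le Te^{-\epsilon\Xi(1-\gamma/2)R}\le\frac{\Gamma\upsilon}{M}.$$
This implies that, for $i\in\bar{\mathcal S}$,
$$Y_ir_i^\top v^{mm}\ge1+\upsilon+Y_i(r_i-x_{i\alpha_i})^\top v^{mm}\ge1+\upsilon-q_iM\|v^{mm}\|\ge1.\qquad(80)$$
In words: above a fixed $R_0$ that only depends on $\gamma=\gamma(\delta)$, the features $r_i$ induced by all non-support indices $i\in\bar{\mathcal S}$ achieve a margin of at least 1. What remains is analyzing the margin shrinkage over the support vectors, as in Theorem 5.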
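Several thresholds in the proofs of Theorems 5 and 6 share the same form. These are $R_0$ in Case 1 of Theorem 5, the analogous threshold following (82), and the non-support threshold above. Each is the smallest $R$ with $Ae^{-aR}\ge Be^{-bR}$ for $A,B>0$ and $b>a$, which after taking logarithms is $R\ge\log(B/A)/(b-a)$. The Python sketch below encodes this single formula and recovers the three stated thresholds. The numeric constants are hypothetical placeholders, not values from the paper.

```python
import math

def crossover_radius(A: float, a: float, B: float, b: float) -> float:
    """Smallest R with A*exp(-a*R) >= B*exp(-b*R), assuming A, B > 0 and b > a.

    Taking logs: log A - a R >= log B - b R  <=>  R >= log(B / A) / (b - a).
    """
    if not (A > 0 and B > 0 and b > a):
        raise ValueError("need A, B > 0 and b > a")
    return math.log(B / A) / (b - a)

# Illustrative constants only (hypothetical, not taken from the paper).
nu, T, M, Xi, gamma, eps, Gamma, upsilon = 0.5, 6, 2.0, 0.8, 0.3, 0.1, 1.2, 0.4

# Theorem 5, Case 1: (nu/T) e^{-(1-gamma) Xi R} >= M T e^{-Xi R}
R_thm5 = crossover_radius(nu / T, (1 - gamma) * Xi, M * T, Xi)
# Lemma 13, Case 1 (bound (82) vs (81)): right-hand exponent is Xi (1 - gamma/2)
R_lem13 = crossover_radius(nu / T, (1 - gamma) * Xi, M * T, (1 - gamma / 2) * Xi)
# Non-support threshold: T e^{-eps Xi (1 - gamma/2) R} <= Gamma * upsilon / M
R_nonsupp = crossover_radius(Gamma * upsilon / M, 0.0, T, eps * Xi * (1 - gamma / 2))

print(round(R_thm5, 3), round(R_lem13, 3), round(R_nonsupp, 3))
```

Written this way, the factor $2$ in the threshold following (82) comes directly from the exponent gap $b-a=\gamma\Xi/2$, compared with $\gamma\Xi$ in Theorem 5.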
Controlling the support margin and combining bounds: Over $\mathcal S$, suppose $v^{mm}$ satisfies the SVM constraints on $r_i$ with $Y_ir_i^\top v^{mm}\ge1-\epsilon_i/\Gamma$. Consequently, setting $\epsilon_{\max}=\sup_{i\in[n]}\epsilon_i$, the direction of $\bar v^{mm}$ achieves a label margin of $\Gamma-\epsilon_{\max}$ on the dataset $(Y_i,r_i)_{i\in[n]}$. Next, recall from (79b) that the worst-case perturbation satisfies $\epsilon_{\max}\le M\exp(-\Xi(1-0.5\gamma)R+\log T)=MT\exp(-\Xi(1-0.5\gamma)R)$. With this and (80), we upper bound the logistic loss of $(\bar v^{mm},\bar p^{\epsilon\text{-rlx}})$ as follows:
$$\mathcal L(\bar v^{mm},\bar p^{\epsilon\text{-rlx}})\le\max_{i\in[n]}\log(1+\exp(-Y_ir_i^\top\bar v^{mm}))\le\max_{i\in[n]}\exp(-Y_ir_i^\top\bar v^{mm})\le\exp(-r\Gamma+r\epsilon_{\max})\le e^{rMT\exp(-\Xi(1-0.5\gamma)R)}e^{-r\Gamma}.\qquad(81)$$
Conversely, we obtain a lower bound for $(v_r,h_R)$. Set $r_i=X_i^\top\mathbb S(K_ih_R)$. Recall the lower bound (79a) over the support vector set $\mathcal S$. Combining this with Assumption C over the support vectors of (SVM) implies that solving (SVM) on $(Y_i,r_i)_{i\in[n]}$ achieves a margin of at most $\Gamma-\nu e^{-(1-\gamma)\Xi R}/T$. Consequently, we have
$$\mathcal L(v_r,h_R)\ge\frac1n\max_{i\in[n]}\log(1+\exp(-Y_ir_i^\top v_r))\ge\frac1{2n}\min\Big\{\max_{i\in[n]}\exp(-Y_ir_i^\top v_r),\log2\Big\}\ge\frac1{2n}\min\Big\{\exp\big(-r(\Gamma-\nu e^{-(1-\gamma)\Xi R}/T)\big),\log2\Big\}=\frac1{2n}\min\Big\{e^{r(\nu/T)\exp(-(1-\gamma)\Xi R)}e^{-r\Gamma},\log2\Big\}.\qquad(82)$$
Observe that this lower bound dominates the previous upper bound when $R$ is large, namely, when (ignoring the multiplier $1/2n$ for brevity)
$$(\nu/T)e^{-(1-\gamma)\Xi R}\ge MTe^{-\Xi(1-0.5\gamma)R}\iff R\ge R_0:=\frac{2}{\gamma\Xi}\log\frac{MT^2}{\nu}.$$
Thus, we obtain the desired contradiction, since $\bar p^{\epsilon\text{-rlx}}$ is a strictly better solution than $p_R=h_R$ (once $R$ is sufficiently large).

Case 2: $v_r/r$ does not converge. This is the simpler scenario: there exists $\delta>0$ such that we can find arbitrarily large $r$ obeying $\|v_r/r-v^{mm}/\|v^{mm}\|\|\ge\delta$. First, note that, due to the strong convexity of (SVM), for some $\gamma:=\gamma(\delta)>0$, $v_r$ achieves a margin of at most $(\Gamma-\gamma)r$ on the dataset $(Y_i,x_{i\alpha_i})_{i\in[n]}$. By the theorem's condition, $\mathbb S(K_ip_R)_{\alpha_i}\to1$. This immediately implies that, for the choice $\epsilon=\gamma/3>0$, above some sufficiently large $(r_0,R_0)$ we have $\|r_i-x_{i\alpha_i}\|\le\epsilon$, where $r_i=X_i^\top\mathbb S(K_ip_R)$. Following (81), this implies that choosing $\bar v^{mm}=rv^{mm}/\|v^{mm}\|$ achieves a logistic loss of at most $e^{r\gamma/3}e^{-r\Gamma}$.
Again using $\|r_i-x_{i\alpha_i}\|\le\epsilon$, for sufficiently large $(r,R)$ we have
$$\min_{i\in[n]}Y_iv_r^\top r_i\le\min_{i\in[n]}Y_iv_r^\top x_{i\alpha_i}+\sup_{i\in[n]}|v_r^\top(r_i-x_{i\alpha_i})|\le(\Gamma-\gamma)r+\epsilon r=(\Gamma-2\gamma/3)r.$$
This in turn implies that the logistic loss is lower bounded (following (82)) by $\mathcal L(v_r,p_R)\ge\frac1{2n}\min\{e^{2\gamma r/3}e^{-r\Gamma},\log2\}$. This dominates the above upper bound $e^{r\gamma/3}e^{-r\Gamma}$ of $\bar v^{mm}$ whenever
$$\frac1{2n}e^{\gamma r/3}>1\iff r>\frac3\gamma\log(2n)$$
(that is, when $r$ is sufficiently large), again concluding the proof.

D Regularization Path of Attention with Nonlinear Head

So far, our discussion has focused on the attention model with a linear head. However, the conceptual ideas of optimal token selection via margin maximization also extend to a general nonlinear model under mild assumptions. The aim of this section is to showcase this generalization. Specifically, we consider the prediction model $f(X)=\psi(X^\top\mathbb S(Kp))$, where $\psi(\cdot):\mathbb R^d\to\mathbb R$ generalizes the linear head $v$ of our attention model. For instance, following the exposition in Section 1.1, $\psi(\cdot)$ can represent a multilayer transformer with $p$ being a tunable prompt at the input layer. Recall that $(X_i,K_i,Y_i)_{i=1}^n$ is the dataset of input-key-label tuples. We consider the training risk
$$\mathcal L(p)=\frac1n\sum_{i=1}^n\ell(Y_i,\psi(X_i^\top s_i^p)),\quad\text{where }s_i^p=\mathbb S(K_ip)\in\mathbb R^T.\qquad(83)$$
The challenge with a nonlinear $\psi(\cdot)$ is that, unlike in the previous sections, we lack a clear score function (Def. 1). The assumption below introduces a generic condition that splits the tokens of each $X_i$ into an optimal set $\mathcal O_i$ and a non-optimal set $\bar{\mathcal O}_i=[T]-\mathcal O_i$. In words, non-optimal tokens are those that strictly increase the training risk $\mathcal L(p)$ if they are not fully suppressed by the attention probabilities $s_i^p$.

Assumption D (Mixing non-optimal tokens hurts) There exist sets $(\mathcal O_i)_{i=1}^n\subseteq[T]$ as follows. Let $q_i^p=\sum_{t\in\bar{\mathcal O}_i}s_{it}^p$ be the sum of softmax similarities over the non-optimal set for $p$, and set $q^p_{\max}=\max_{i\in[n]}q_i^p$. For any $\Delta>0$, there exists $\rho<0$ such that, for all $p,p'\in\mathbb R^d$, if $\log(q^p_{\max})\le\min\{(1+\Delta)\log(q^{p'}_{\max}),\rho\}$, then $\mathcal L(p)<\mathcal L(p')$.
This assumption is titled "mixing hurts" because the attention output $X_i^\top s_i^p$ mixes the tokens of $X_i$, and our condition is that, to achieve optimal risk, this mixture should not contain any non-optimal tokens. In particular, we require that a model $p$ containing exponentially less non-optimality (quantified via $\log(q_{\max})$) than $p'$ is strictly preferable. As we discuss in the supplementary material, Theorem 1 is in fact a concrete instance (with linear head $v$) satisfying this condition.

Before stating our generic theorem, we introduce the max-margin separator towards which the regularization path of attention converges. This is a slightly more general version of the (ATT-SVM) problem of Section 2, in which we allow a set of optimal tokens $\mathcal O_i$ for each input:
$$p^{mm}=\arg\min_p\|p\|\quad\text{subject to}\quad\max_{\alpha\in\mathcal O_i}\min_{\beta\in\bar{\mathcal O}_i}p^\top(k_{i\alpha}-k_{i\beta})\ge1\quad\text{for all }i\in[n].\qquad\text{(ATT-SVM′)}$$
Unlike (ATT-SVM), this problem is not necessarily convex when the optimal set $\mathcal O_i$ is not a singleton. To see this, let $n=d=1$ and $T=3$: set the two optimal tokens as $k_1=1$ and $k_2=-1$ and the non-optimal token as $k_3=0$. The solution set of (ATT-SVM′) is $p^{mm}\in\{-1,1\}$, whereas their convex combination $p=0$ violates the constraints.

To proceed, our final result establishes the convergence of the regularization path to the solution set of (ATT-SVM′) under Assumption D.

Theorem 8 Let $\mathcal G^{mm}$ be the set of global minima of (ATT-SVM′). Suppose its margin $\Xi:=1/\|p^{mm}\|>0$ and Assumption D holds. Let $\mathrm{dist}(\cdot,\cdot)$ denote the $\ell_2$-distance between a vector and a set. Following (83), define $p_R=\arg\min_{\|p\|\le R}\mathcal L(p)$. Then
$$\lim_{R\to\infty}\mathrm{dist}\Big(\frac{p_R}{\Xi R},\mathcal G^{mm}\Big)=0.$$
We note that Theorem 1 is a corollary of this result in which the $\mathcal O_i$'s and $\mathcal G^{mm}$ are singletons. Based on this result, with multiple optimal tokens, Theorem 1 would gracefully generalize to solving (ATT-SVM′).

D.1 Proof of Theorem 8

Proof.
The key idea is to show that, thanks to the exponential tail of softmax attention, the (harmful) contribution of the non-optimal token with the minimum margin dominates the contribution of all other tokens as $R\to\infty$. This high-level approach is similar to earlier works on the implicit bias of gradient descent with the logistic loss [31, 22]. Pick $p^{mm}\in\mathcal G^{mm}$ and set $\bar p_R=Rp^{mm}/\|p^{mm}\|$. This is the baseline model that $p_R$ has to compete against. Also let $\tilde p_R=p_R/(\Xi R)$. Now suppose $\mathrm{dist}(\tilde p_R,\mathcal G^{mm})\not\to0$ as $R\to\infty$. Then there exists $\delta>0$ such that we can always find arbitrarily large $R$ obeying $\mathrm{dist}(\tilde p_R,\mathcal G^{mm})\ge\delta$. Since $\tilde p_R$ is bounded away from $\mathcal G^{mm}$ by $\delta>0$ and $\|\tilde p_R\|\le\|p^{mm}\|$, $\tilde p_R$ strictly violates at least one of the inequality constraints in (ATT-SVM′); otherwise, we would have $\tilde p_R\in\mathcal G^{mm}$. Without loss of generality, suppose $\tilde p_R$ violates the first margin constraint, that is, for some $\gamma:=\gamma(\delta)>0$, $\max_{\alpha\in\mathcal O_1}\min_{\beta\in\bar{\mathcal O}_1}\tilde p_R^\top(k_{1\alpha}-k_{1\beta})\le1-\gamma$. We now argue that this leads to a contradiction as $R\to\infty$, by showing that $\mathcal L(\bar p_R)<\mathcal L(p_R)$ for sufficiently large $R$.

First, let us control $\mathcal L(\bar p_R)$. We study $\bar s_i=\mathbb S(K_i\bar p_R)$ and let $\alpha_i\in\mathcal O_i$ be the index $\alpha$ in (ATT-SVM′) at which $\mathrm{margin}_i=\max_{\alpha\in\mathcal O_i}\min_{\beta\in\bar{\mathcal O}_i}(k_{i\alpha}-k_{i\beta})^\top p^{mm}\ge1$ is attained. Then, we bound the non-optimality $q^\star_i$ of $\bar p_R$ as
$$q^\star_i=\frac{\sum_{t\in\bar{\mathcal O}_i}\exp(k_{it}^\top\bar p_R)}{\sum_{t\in[T]}\exp(k_{it}^\top\bar p_R)}\le\frac{\sum_{t\in\bar{\mathcal O}_i}\exp(k_{it}^\top\bar p_R)}{\exp(k_{i\alpha_i}^\top\bar p_R)}\le T\exp(-\Xi R).$$
Thus, $q^\star_{\max}=\max_{i\in[n]}q^\star_i\le T\exp(-\Xi R)$.

Secondly, we control $\mathcal L(p_R)$ by lower bounding the non-optimality in $p_R$. Focusing on the first margin constraint, let $\alpha\in\mathcal O_1$ be the index in (ATT-SVM′) at which $\mathrm{margin}_1\le1-\gamma$ is attained. Denoting the non-optimality of the first input by $\hat q_1$, we find$^7$
$$\hat q_1=\frac{\sum_{t\in\bar{\mathcal O}_1}\exp(k_{1t}^\top p_R)}{\sum_{t\in[T]}\exp(k_{1t}^\top p_R)}\ge\frac1T\,\frac{\sum_{t\in\bar{\mathcal O}_1}\exp(k_{1t}^\top p_R)}{\exp(k_{1\alpha}^\top p_R)}\ge T^{-1}\exp(-\Xi R(1-\gamma)).$$
$^7$Here, we assumed the margin is non-negative, i.e., $k_{1\alpha}^\top p_R\ge\sup_{t\in\bar{\mathcal O}_1}k_{1t}^\top p_R$. Otherwise, $\sup_{t\in[T]}k_{1t}^\top p_R$ is attained in $\bar{\mathcal O}_1$, which implies $\hat q_1\ge T^{-1}$. Thus, we can still use the identical inequality (84) with the choice $\gamma=1$.
We similarly have $q^\star_{\max}\ge T^{-1}\exp(-\Xi R)$. In conclusion, for $p_R$ and $\bar p_R$, denoting the maximum non-optimalities by $\hat q_{\max}\ge\hat q_1$ and $q^\star_{\max}$, respectively, we obtained
$$\log(\hat q_{\max})\ge-(1-\gamma)(\Xi R)-\log T,\qquad-(\Xi R)-\log T\le\log(q^\star_{\max})\le-(\Xi R)+\log T.\qquad(84)$$
The above inequalities satisfy Assumption D as follows, with $p\leftarrow\bar p_R$ and $p'\leftarrow p_R$. Set $R_0=3\gamma^{-1}\Xi^{-1}\log T$ so that $\log T=\gamma\Xi R_0/3$. Secondly, set $\rho_0=-\Xi R_0-\log T$. This way, $\log(q^\star_{\max})\le\rho_0$ implies $R\ge R_0$ and hence $\log T\le\gamma\Xi R/3$. Using the latter inequality to bound the $\log T$ terms, we obtain $\log(\hat q_{\max})\ge-(1-2\gamma/3)(\Xi R)$ and $\log(q^\star_{\max})\le-(1-\gamma/3)(\Xi R)$. To proceed, we pick $1+\Delta=\frac{1-\gamma/3}{1-2\gamma/3}$, implying $\Delta:=\frac{\gamma}{3-2\gamma}$, so that $\log(q^\star_{\max})\le-(1+\Delta)(1-2\gamma/3)(\Xi R)\le(1+\Delta)\log(\hat q_{\max})$. Finally, for this $\Delta$, Assumption D provides a threshold $\rho(\Delta)$, and we need to ensure $\log(q^\star_{\max})\le\rho(\Delta)$. This can be guaranteed by picking $R$ sufficiently large so that $\log(q^\star_{\max})\le-(1-\gamma/3)(\Xi R)\le\rho(\Delta)$, which satisfies all conditions of Assumption D. Since such large $R$ exists by the initial assumption $\mathrm{dist}(\tilde p_R,\mathcal G^{mm})\not\to0$, Assumption D in turn implies $\mathcal L(\bar p_R)<\mathcal L(p_R)$, contradicting the optimality of $p_R$ in (83).

D.2 Application to Linearly-mixed Labels

The following example shows that if non-optimal tokens result in a reduced score (in terms of the alignment of prediction and label), Assumption D holds. The high-level idea behind this lemma is that, if the optimal risk is achieved by setting $q^p_{\max}=0$, then Assumption D holds.

Lemma 14 (Linear label mixing) Recall $q_i^p=\sum_{t\in\bar{\mathcal O}_i}s_{it}^p$ from Assumption D. Suppose $Y_i\in\{-1,1\}$ and $Y_i\psi(X_i^\top s_i^p)=\nu_i(1-q_i^p)+Z_i$ for some $(\nu_i)_{i=1}^n>0$. Here, $Z_i=Z_i(p)$ is the contribution of the non-optimal tokens to the prediction. For some $C,\epsilon>0$ and for all $p\in\mathbb R^d$, assume
$$-Cq^p_{\max}\le Z_i\le(1-\epsilon)\nu_iq_i^p.\qquad(85)$$
Then, Assumption D holds for $\mathcal L(p)=\frac1n\sum_{i=1}^n\ell(Y_i\psi(X_i^\top s_i^p))$ when $\ell(\cdot)$ is a strictly decreasing loss function with a continuous derivative.

Proof. Recall the assumption $Y_i\psi(X_i^\top s_i^p)=\nu_i(1-q_i^p)+Z_i$ with $Z_i$ obeying (85). Let us also write the loss function as $\mathcal L(p)=\frac1n\sum_{i=1}^n\ell(\nu_i(1-q_i^p)+Z_i)$. Define $q^p_{\max}=\sup_{i\in[n]}q_i^p$. Let $M$ be the maximum absolute value of the score over tokens.
Let $B=\max_{|x|\le M}|\ell'(x)|$ and $A=\min_{|x|\le M}|\ell'(x)|>0$. By Taylor's theorem (with integral remainder), we have
$$B(q_i^p\nu_i-Z_i)\ge\ell(\nu_i(1-q_i^p)+Z_i)-\ell(\nu_i)\ge A(q_i^p\nu_i-Z_i)\ge\epsilon A\nu_iq_i^p.$$
Set $\mathcal L_\star=\frac1n\sum_{i=1}^n\ell(\nu_i)$, $C_+=B(C+\max_{i\in[n]}\nu_i)$, and $C_-=n^{-1}A\epsilon\min_{i\in[n]}\nu_i$. This also implies
$$C_+q^p_{\max}\ge\max_{i\in[n]}B(q_i^p\nu_i-Z_i)\ge\mathcal L(p)-\mathcal L_\star\ge\frac1n\sum_{i\in[n]}A(q_i^p\nu_i-Z_i)\ge\frac1n\sum_{i\in[n]}\epsilon A\nu_iq_i^p\ge C_-q^p_{\max}.$$
Thus, to prove $\mathcal L(p')>\mathcal L(p)$, it suffices to establish the stronger statement $C_-q^{p'}_{\max}>C_+q^p_{\max}$. Going back to the condition of Assumption D, any pair with $\log(q^p_{\max})\le(1+\Delta)\log(q^{p'}_{\max})$ obeys $q^p_{\max}\le(q^{p'}_{\max})^{1+\Delta}$, i.e., $q^{p'}_{\max}\ge(q^p_{\max})^{(1+\Delta)^{-1}}$. Following the above, we wish to ensure $q^{p'}_{\max}>\Theta q^p_{\max}$ for such $(p,p')$ pairs, where $\Theta=C_+/C_->1$. This is guaranteed by
$$(q^p_{\max})^{(1+\Delta)^{-1}-1}>\Theta\iff\frac{\log(q^p_{\max})}{1+\Delta^{-1}}<-\log(\Theta).$$
The above is satisfied by choosing $\rho(\Delta):=-2(1+\Delta^{-1})\log(\Theta)$ in Assumption D. Thus, all $p,p'$ with $\log(q^p_{\max})\le\rho=\rho(\Delta)$ satisfy the condition of Assumption D, finishing the proof.

E Implementation Details and Additional Experiments

In this section, we provide implementation details and additional experiments. We build one attention layer using PyTorch. During training, we use the SGD optimizer with learning rate 0.1 and train the model for 1000 iterations. To better visualize the convergence path, we normalize the gradient of $p$ (and $v$) at each iteration. Next, given the gradient descent solution $p$, we determine the locally-optimal indices to be those with the highest softmax scores. Using these optimal indices, we use the Python package cvxopt to build and solve (ATT-SVM) and obtain the solution $p^{mm}$. After obtaining $p^{mm}$, we also verify that these indices satisfy our local-optimality definition. The examples we use in the paper are all trivial to verify (by construction). In Figures 3(a) and 3(b), $v^{mm}$ (blue dashed) is solved via (SVM) using the Python package sklearn.svm based on the given label information, and the red dashed line represents the $p^{relax}$ direction instead, which is the solution of (10). Note that in both figures, $v^{mm}/\|v^{mm}\|=[0,1]$.
Therefore, in Figure 3(a), all optimal tokens are support vectors and $p^{relax}=p^{mm}$, whereas in Figure 3(b), the yellow token is not a support vector and only needs to satisfy a positive correlation with $p$. The gray dashed line displays the $p^{mm}$ direction.

Failure of gradient descent's global convergence when $n\ge2$ (refer to Theorem 2). Figure 8 provides a counter-example demonstrating that the $n=1$ restriction is indeed necessary and tight to guarantee global convergence of the gradient descent iterates $p(t+1)=p(t)-\eta\nabla\mathcal L(p(t))$ on (ERM).

[Figure 8 plot; legend: optimal tokens, non-optimal tokens.]
Figure 8: The convergence behavior of gradient descent on the attention weights $p$ using the logistic loss in (ERM) with $n=T=d=2$.

For this example, we use the logistic loss in (ERM). We set $n=T=d=2$, implying that there is only one non-optimal token per input; thus, Assumption B is satisfied. The red and blue lines represent the GMM and LMM solutions, respectively. The green star and teal square indicate the locally-optimal tokens. Specifically, referring to the local optimality definition (Definition 2), for the LMM solution ($p^{mm}$) represented by the blue line, the teal square token does not have any SVM-neighbors. The arrows indicate two trajectories originating from different initializations. This demonstrates that the gradient descent iterates $p(t+1)=p(t)-\eta\nabla\mathcal L(p(t))$ on (ERM) with two different initializations converge to two different SVM solutions (GMM and LMM). These results validate the necessity of $n=1$ in Theorem 2 to guarantee gradient descent convergence to $p^{mm}$ from any initialization.

The convergence behavior of gradient descent under over-parameterization. To illustrate Theorems 3 & 4, we investigated the convergence behavior of $p(t)$ generated by gradient descent in Figure 9(a), using $n=4$, $T=6$, and 1,000 random trials for each $d\in\{2,5,10,100,300,500\}$. These experiments use normalized gradient descent with learning rate 1 for 1000 iterations.
Inputs $x_{it}$ and the linear head $v$ are sampled uniformly from the unit sphere, $Y_i$ is uniformly $\pm1$, and $W$ is set to $I$. The bar plot in Figure 9(a) distinguishes between non-saturated softmax (red bars) and saturated softmax (other bars). Saturation means that the average softmax probability over the tokens selected by gradient descent is at least $1-10^{-5}$, which implies that attention selects one token per input. Note that whenever the norm of the gradient descent solution is finite, softmax is non-saturated.

[Figure 9 plots: (a) percentage of the different convergence scenarios for $p(t)$ versus $d\in\{2,5,10,100,300,500\}$, with red bars marking $\mathbb S(Kp(t))_\alpha<1$; (b) probability of the minimum score gap to neighbors $\gamma$ for $d=5,10,100$.]
Figure 9: Convergence analysis of $p(t)$ trained with random data using gradient descent. (a) shows three scenarios: (1) attention failing to select one token per input (i.e., softmax is not saturated); (2) $p$ converging towards $p^{mm}$; and (3) $p^{mm}$ coinciding with the globally-optimal (GMM) solution, shown with red, blue, and green bars, respectively. Considering saturated softmax instances where $p(t)$ selects one token $\alpha_i$ per input, (b) presents a histogram of the minimal score gap $\gamma:=\min_{i\in[n],\,t\in\mathcal T_i}\gamma_{i\alpha_i}-\gamma_{it}$ between $\alpha_i$ and its corresponding neighbors $\mathcal T_i$.

For small $d$ (e.g., $d=2$), the problem has few degrees of freedom to separate the optimal tokens from the rest (i.e., there is no SVM solution for LMM directions), especially due to label randomness. This results in a tall red bar capturing the finite-norm solutions. However, for larger $d$, we observe that softmax saturates (i.e., $\|p(t)\|\to\infty$) and $p(t)$ almost always converges to the LMM direction of the selected tokens $\alpha$ (blue bar); this is in line with Theorems 3 & 4. We also study convergence to the globally-optimal GMM, represented by the green bar: GMM is a strict subset of LMM; however, as $d$ increases, we observe that the probability of GMM convergence increases as well.
This behavior is in line with what one would expect from over-parameterized deep learning theory [57, 58, 59, 60] and motivates future research. The average correlation coefficient between $p(t)$ and its associated LMM/GMM direction is 0.997, suggesting that, whenever softmax saturates, gradient descent indeed directionally converges to an LMM solution $p\in\mathcal P^{mm}$, confirming Theorem 4. Furthermore, we found that there exist problem instances with saturated softmax and $\|p(t)\|\to\infty$ that converge to neither an LMM nor the GMM. We analyzed this phenomenon using the minimum score gap $\gamma:=\min_{i\in[n],\,t\in\mathcal T_i}\gamma_{i\alpha_i}-\gamma_{it}$, where $\mathcal T_i$, $i\in[n]$, denotes the set of SVM-neighbor tokens. Figure 9(b) shows the probability distribution of $\gamma$ (with bins of width $<0.01$) and demonstrates the rarity of such cases. Specifically, this happens in less than 1% of the problems, that is, $\mathrm{Prob}(\gamma<0)<0.01$. Figure 9(b) also reveals that, in these scenarios, $\gamma$ is typically close to zero even when negative, i.e., even if there exists an SVM-neighbor with a higher score, it is only slightly higher. This is not surprising, since when token scores are close, a large number of gradient iterations is needed to distinguish them. For all practical purposes, the optimization treats both tokens equally and, rather than (ATT-SVM), the more refined formulation (ATT-SVM′) developed in Section D is a better proxy. Confirming this intuition, we verified that, over the instances with $\gamma<0$, the gradient descent solution still has a correlation above 0.99 with the max-margin solution on average.

[Figure 10 plot: cumulative probability versus $\gamma$ for $d=5,10,100$.]
Figure 10: Cumulative probability of the gap $\gamma:=\min_{i\in[n],\,t\in\mathcal T_i}\gamma_{i\alpha_i}-\gamma_{it}$ in Figure 9(b).

In Figure 9, we again applied normalized gradient descent with a learning rate equal to 1. Each trial involved randomly generated data and training for 1000 iterations, as discussed in Theorem 4. The tokens selected by $p$ are denoted by $(\alpha_i)_{i=1}^n$, where $\alpha_i=\arg\max_{t\in[T]}\mathbb S(X_ip)_t$.
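The experimental protocol described in this section can be sketched in dependency-free Python. It uses random unit-norm tokens and head, random $\pm1$ labels, $W=I$, normalized gradient descent with learning rate 1 for 1000 iterations, then $\alpha_i=\arg\max_t\mathbb S(X_ip)_t$ and a saturation threshold of $1-10^{-5}$. This is not the authors' PyTorch code. The gradient is derived by hand for $f(X)=v^\top X^\top\mathbb S(Xp)$ with the logistic loss, the helper names are hypothetical, and the sizes $n=4$, $T=6$, $d=10$ mirror one setting of Figure 9.

```python
import math
import random

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def softmax(z):
    m = max(z)                                   # stabilise the exponentials
    e = [math.exp(v - m) for v in z]
    s = sum(e)
    return [v / s for v in e]

def logistic(z):
    """Numerically stable log(1 + exp(-z))."""
    return max(0.0, -z) + math.log1p(math.exp(-abs(z)))

def risk_and_grad(p, X, Y, v):
    """Logistic risk of f(X) = v^T X^T softmax(X p) (i.e. W = I) and its gradient in p."""
    d = len(p)
    risk, grad = 0.0, [0.0] * d
    for Xi, Yi in zip(X, Y):
        s = softmax([dot(x, p) for x in Xi])     # attention probabilities over the T tokens
        h = [dot(x, v) for x in Xi]              # head outputs v^T x_it
        f = dot(s, h)
        risk += logistic(Yi * f)
        dl_df = -Yi / (1.0 + math.exp(min(Yi * f, 700.0)))
        for st, ht, x in zip(s, h, Xi):          # df/dp = sum_t s_t (h_t - f) x_t
            w = dl_df * st * (ht - f)
            for j in range(d):
                grad[j] += w * x[j]
    n = len(X)
    return risk / n, [g / n for g in grad]

def random_unit(rng, d):
    x = [rng.gauss(0.0, 1.0) for _ in range(d)]  # normalised Gaussian: uniform on the sphere
    r = math.sqrt(dot(x, x))
    return [c / r for c in x]

def normalized_gd(X, Y, v, iters=1000, lr=1.0):
    p = [0.0] * len(v)
    for _ in range(iters):
        _, g = risk_and_grad(p, X, Y, v)
        gnorm = math.sqrt(dot(g, g))
        if gnorm == 0.0:
            break
        p = [pj - lr * gj / gnorm for pj, gj in zip(p, g)]  # fixed-length step along -grad
    return p

def selection_summary(p, X, tol=1e-5):
    """alpha_i = argmax_t S(X_i p)_t, averaged probability s_bar, and the saturation flag."""
    alphas, tops = [], []
    for Xi in X:
        s = softmax([dot(x, p) for x in Xi])
        a = max(range(len(s)), key=s.__getitem__)
        alphas.append(a)
        tops.append(s[a])
    s_bar = sum(tops) / len(tops)
    return alphas, s_bar, s_bar >= 1.0 - tol

rng = random.Random(0)
n, T, d = 4, 6, 10
X = [[random_unit(rng, d) for _ in range(T)] for _ in range(n)]
Y = [rng.choice((-1, 1)) for _ in range(n)]
v = random_unit(rng, d)
p = normalized_gd(X, Y, v)
alphas, s_bar, saturated = selection_summary(p, X)
print(alphas, round(s_bar, 6), saturated, round(math.sqrt(dot(p, p)), 2))
```

Repeating this over many seeds and recording the saturation flag per trial gives the kind of tally shown in Figure 9(a). Identifying LMM or GMM convergence additionally requires solving (ATT-SVM), which the paper does with cvxopt and this sketch omits.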
The averaged softmax probability is computed as $\bar{s} := \frac{1}{n}\sum_{i=1}^{n} S(X_i p)_{\alpha_i}$ (same as Figure 3(c)). The red bars in Figure 9(a) represent the values of $P(\bar{s} < 1 - 10^{-5})$ for each choice of d. Figure 10 displays the cumulative probability distribution of γ from Figure 9(b), with the gray dashed line indicating γ = 0. From this figure, we observe that the minimal score gap exhibits a sharp transition at zero (<1% of the instances have γ < 0), demonstrating that in most random problem instances with $\|p\| \to \infty$ ($\bar{s} \to 1$), the problem directionally converges to an LMM direction, i.e., $p(t)/\|p(t)\| \to p^{\mathrm{mm}}/\|p^{\mathrm{mm}}\|$. We believe the rare occurrence of a negative score gap is due to small score differences (so that the optimal token is not clearly distinguished) and the finite number of gradient iterations we run. Interestingly, even in the negative score gap scenarios, gradient descent is aligned with $p^{\mathrm{mm}}(\alpha)$ (even if $p^{\mathrm{mm}}(\alpha)$ is not an LMM direction), as predicted by Section D, which handles the scenario where there are multiple optimal tokens per input.

F Addendum to Section 5

We provide an overview of the current literature on implicit regularization and the attention mechanism.

F.1 Related Work on Implicit Regularization

The introduction of Support Vector Machines (SVM), which use explicit regularization to choose maximum-margin classifiers, is one of the earliest relevant works in this field [61]. The concept of maximizing the margin was later connected to generalization performance [62]. From a practical perspective, exponential losses with decaying regularization exhibit asymptotic behavior similar to SVMs, as demonstrated in [22]. While the analysis of the perceptron [63] originally introduced the concept of margins, the method itself has no inherent margin-maximizing bias: it terminates as soon as it reaches zero classification error, so no meaningful lower bound on the attained margin can be established.
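To make the perceptron remark concrete, the sketch below runs the classic perceptron on a hypothetical two-point separable dataset, written as label-multiplied points $z_i = y_i x_i$. It stops at the first zero-error iterate, whose normalized margin is well below the maximum margin; the hard-margin SVM solution $w^{\mathrm{mm}} = (1, 0.5)$ for this data was solved by hand (both constraints active).

```python
import math

def perceptron(Z, max_epochs=1000):
    """Classic perceptron on z_i = y_i x_i: add z_i on every mistake (w . z_i <= 0),
    stop after the first full pass without mistakes."""
    w = [0.0] * len(Z[0])
    for _ in range(max_epochs):
        mistakes = 0
        for z in Z:
            if sum(a * b for a, b in zip(w, z)) <= 0:
                w = [a + b for a, b in zip(w, z)]
                mistakes += 1
        if mistakes == 0:
            break
    return w

def normalized_margin(w, Z):
    """min_i w . z_i / ||w||."""
    norm = math.sqrt(sum(a * a for a in w))
    return min(sum(a * b for a, b in zip(w, z)) for z in Z) / norm

Z = [[1.0, 0.0], [0.5, 1.0]]
w = perceptron(Z)
print(w, normalized_margin(w, Z))   # [1.0, 0.0] 0.5
print(1.0 / math.sqrt(1.25))        # max margin 1 / ||(1, 0.5)||, about 0.894
```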
Early empirical investigations highlighting the implicit bias of descent methods focused on ℓ1-regularization, revealing that coordinate descent combined with the exponential loss exhibits an inherent inclination towards ℓ1-regularized solutions [64]. This work draws extensively on the literature on implicit bias and regularization, which has provided valuable techniques and inspiration. A common observation in these studies is convergence to a specific optimal solution over the training set. This phenomenon has been observed for various approaches, including coordinate descent [65, 66], gradient descent [30, 67, 25, 68, 69, 22, 70], deep linear networks [71, 72], ReLU networks [73, 74, 29, 75, 76], mirror descent [77, 78, 33, 36], and many others. The implicit bias of gradient descent in classification tasks with separable data has been extensively examined by [22, 25, 26, 27, 28, 29]. These works typically use the logistic loss or exponentially-tailed losses to establish connections to margin maximization. The results have also been extended to non-separable data by [30, 31, 21]. Additionally, several papers have explored the implicit bias of stochastic gradient descent [37, 38, 41, 42], as well as adaptive and momentum-based methods [43, 44, 45, 46]. While our optimization approach for v shares some similarities with existing works, the optimization of p presents notable differences. First, our optimization problem is nonconvex and involves a composition of the loss and the softmax, which introduces new challenges. The softmax adds a nonlinearity to the problem, requiring specialized techniques for analysis and optimization. Second, our analysis introduces the concept of locally-optimal tokens, i.e., tokens that achieve locally optimal solutions within their respective cones.
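The separable-data phenomenon studied in [22, 25, 26, 27, 28, 29] can be illustrated on a small hypothetical dataset (the same label-multiplied points $z_1 = (1, 0)$, $z_2 = (0.5, 1)$, whose hard-margin SVM solution is $w^{\mathrm{mm}} = (1, 0.5)$, solved by hand). This is a sketch of the linear logistic-regression case only, not of attention: plain gradient descent on the unregularized logistic loss lets $\|w\|$ grow without bound while $w/\|w\|$ drifts toward the max-margin direction. The step count and learning rate are arbitrary choices.

```python
import math

def logistic_gd(Z, lr=1.0, steps=20000):
    """Plain GD on L(w) = sum_i log(1 + exp(-w . z_i)) from w = 0,
    where z_i = y_i x_i are linearly separable."""
    w = [0.0] * len(Z[0])
    for _ in range(steps):
        grad = [0.0] * len(w)
        for z in Z:
            u = sum(a * b for a, b in zip(w, z))
            # derivative of log(1 + e^{-u}) is -1 / (1 + e^u), written without overflow
            c = -math.exp(-u) / (1.0 + math.exp(-u)) if u >= 0 else -1.0 / (1.0 + math.exp(u))
            grad = [g + c * zk for g, zk in zip(grad, z)]
        w = [wk - lr * g for wk, g in zip(w, grad)]
    return w

def cosine(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    return dot / math.sqrt(sum(x * x for x in a) * sum(y * y for y in b))

Z = [[1.0, 0.0], [0.5, 1.0]]
w_mm = [1.0, 0.5]                 # hard-margin SVM solution for Z (both constraints active)
w = logistic_gd(Z)
print(cosine(w, w_mm))            # close to 1; the remaining gap shrinks only logarithmically
```

Directional convergence here is only logarithmic in the number of steps, which is why even a two-point problem is run for thousands of iterations; in contrast with the perceptron, the loss never stops pushing the margin upward.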
This concept is crucial for understanding the behavior of the attention mechanism and its convergence properties. By focusing on the cones surrounding locally-optimal tokens, we provide a tailored analysis that captures the unique characteristics of the attention model. Overall, our work offers novel insights into the optimization of attention-based models and sheds light on the behavior of the attention mechanism during training.

F.2 Related Work on Attention Mechanism

As the backbone of Transformers [6], the self-attention mechanism [47, 48, 49, 50] plays a crucial role in computing feature representations by globally modeling long-range interactions within the input. Transformers have achieved remarkable empirical success in various domains, including natural language processing [4, 2], recommendation systems [79, 80, 81], and reinforcement learning [82, 83, 84]. With the introduction of the Vision Transformer (ViT) [85], Transformer-based models [86, 87] have become a strong alternative to convolutional neural networks (CNNs) and are now prevalent in vision tasks. However, the theoretical foundation of Transformers and self-attention mechanisms has remained largely unexplored. Some studies have established important results, including the Lipschitz constant of self-attention [88], properties of the neural tangent kernel [89, 90], and the expressive power and Turing-completeness of Transformers [91, 92, 93, 51, 23, 94] with statistical guarantees [95, 96]. There is also a growing effort towards a theoretical understanding of emergent abilities of language models, such as in-context learning [97, 98, 99] and chain-of-thought [100, 101, 102], which are inherently related to the model's ability to attend to the relevant information within the input sequence. Focusing on the self-attention component, Edelman et al.
[51] theoretically show that a single self-attention head can represent a sparse function of the input, with a sample-complexity bound on the generalization gap between the training and test losses. However, they do not address the algorithmic aspects of training Transformers to achieve a desirable loss. Sahiner et al. [52] and Ergen et al. [53] further analyzed convex relaxations of self-attention, investigating potential optimization techniques and properties. The former work applies to self-attention with linear activation (rather than softmax), whereas the latter approximates softmax via a linear operation with unit simplex constraints. In contrast, we directly study softmax and characterize its nonconvex geometry. In terms of expressive ability, Baldi and Vershynin [54] investigated the capacity of attention layers to capture complex patterns and information, while Dong et al. [23] illustrate the propensity of attention networks to degenerate during the training process, often yielding an output that is approximately a rank-1 matrix. Recent works have made progress in characterizing the optimization and generalization dynamics of attention [55, 56, 103, 17, 104]. Jelassi et al. [55] studied gradient-based methods from random initialization and provided a theoretical analysis of the empirical finding that Vision Transformers learn position embeddings that recapitulate the spatial structure of the training data, even though this spatial structure is no longer explicitly represented after the image is split into patches. Li et al. [56] provided theoretical results on training three-layer ViTs for classification tasks. They quantified the importance of self-attention in terms of the sample complexity for achieving zero generalization error, as well as the sparsity of attention maps when trained by stochastic gradient descent (SGD). In another related work, Nguyen et al.
[104] proposed a primal-dual optimization framework that derives attention as the dual expansion of a primal neural network layer. By solving a support vector regression problem, they gained a deeper understanding and explanation of various attention mechanisms. This framework also enables the creation of novel attention mechanisms, offering flexibility and customization in designing attention-based models. In another closely related work, Oymak et al. [17] analyzed the same attention model as ours, denoted by (ERM). Specifically, they jointly optimize v and p for three gradient iterations on a contextual dataset model. This is in contrast to our emphasis on the infinite-iteration behavior of p-only optimization. However, it is important to note that all of these works make certain assumptions about the data. Specifically, they assume that tokens are tightly clusterable or can be clearly split into relevant and irrelevant sets. Additionally, Li et al. [56] require specific assumptions on the initialization of the model, while Jelassi et al. [55] consider a simplified attention structure in which the attention matrix is not directly parameterized with respect to the input. In contrast, our work offers a comprehensive optimization-theoretic analysis of the attention model, establishing a formal connection to max-margin problems. While comparable works make assumptions on the dataset model, our results apply under minimal assumptions to general data and realistic conditions. Our analysis based on max-margin equivalence provides a deeper understanding of the optimization geometry of attention and its behavior during training. As our experiments illustrate, our results lead to novel insights even for n = 1, 2 samples, T = 2, 3 tokens, and d = 2, 3 dimensions (in contrast to [55, 56, 103, 17]).
Notably, our work also presents the first theoretical understanding of the implicit bias exhibited by gradient descent methods in the context of the attention model. We remark that recent work [105] extends the theory presented in this work to 1-layer transformers. By uncovering the underlying optimization principles and thoroughly characterizing the directional convergence of attention, we provide valuable insights into the dynamics and generalization properties of attention-based models, opening the path for future research.
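For reference, the max-margin problem behind this directional convergence can be restated as below. This is our reading of (ATT-SVM) in the notation of this section, not a verbatim copy: $k_{it}$ denotes the t-th row of $K_i$ (so $k_{it} = x_{it}$ when W = I, as in the experiments above), and identifying the SVM-neighbors $\mathcal{T}_i$ with the tokens whose constraints are active is an assumption to check against the main text. The last line works the toy instance n = 1, T = 2 by hand.

$$
p^{\mathrm{mm}}(\alpha) = \arg\min_{p}\ \|p\| \quad \text{subject to} \quad p^\top\big(k_{i\alpha_i} - k_{it}\big) \ge 1 \quad \text{for all } t \ne \alpha_i,\ i \in [n],
$$

$$
\mathcal{T}_i = \big\{\, t \ne \alpha_i \;:\; p^{\mathrm{mm}}(\alpha)^\top \big(k_{i\alpha_i} - k_{it}\big) = 1 \,\big\},
$$

$$
n = 1,\ T = 2,\ k_{11} = (1, 0),\ k_{12} = (0, 1),\ \alpha_1 = 1:\quad \min_{p} \|p\| \ \text{s.t.}\ p_1 - p_2 \ge 1 \ \Longrightarrow\ p^{\mathrm{mm}} = \big(\tfrac{1}{2}, -\tfrac{1}{2}\big).
$$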