# Transformers Learn In-Context by Gradient Descent

Johannes von Oswald¹ ², Eyvind Niklasson², Ettore Randazzo², João Sacramento¹, Alexander Mordvintsev², Andrey Zhmoginov², Max Vladymyrov²

At present, the mechanisms of in-context learning in Transformers are not well understood and remain mostly an intuition. In this paper, we suggest that training Transformers on auto-regressive objectives is closely related to gradient-based meta-learning formulations. We start by providing a simple weight construction that shows the equivalence of data transformations induced by 1) a single linear self-attention layer and by 2) gradient descent (GD) on a regression loss. Motivated by that construction, we show empirically that when training self-attention-only Transformers on simple regression tasks, either the models learned by GD and by Transformers show great similarity or, remarkably, the weights found by optimization match the construction. Thus we show how trained Transformers become mesa-optimizers, i.e. learn models by gradient descent in their forward pass. This allows us, at least in the domain of regression problems, to mechanistically understand the inner workings of in-context learning in optimized Transformers. Building on this insight, we furthermore identify how Transformers surpass the performance of plain gradient descent by learning an iterative curvature correction, and how they learn linear models on deep data representations to solve non-linear regression tasks. Finally, we discuss intriguing parallels to a mechanism identified to be crucial for in-context learning, termed induction heads (Olsson et al., 2022), and show how it could be understood as a specific case of in-context learning by gradient descent within Transformers.

¹Department of Computer Science, ETH Zürich, Zürich, Switzerland. ²Google Research. Correspondence to: Johannes von Oswald.
Proceedings of the 40th International Conference on Machine Learning, Honolulu, Hawaii, USA. PMLR 202, 2023. Copyright 2023 by the author(s).

## 1. Introduction

In recent years, Transformers (TFs; Vaswani et al., 2017) have demonstrated their superiority in numerous benchmarks and various fields of modern machine learning, and have emerged as the de-facto neural network architecture used for modern AI (Dosovitskiy et al., 2021; Yun et al., 2019; Carion et al., 2020; Gulati et al., 2020). It has been hypothesised that their success is due in part to a phenomenon called in-context learning (Brown et al., 2020; Liu et al., 2021): an ability to flexibly adjust their prediction based on additional data given in context (i.e. in the input sequence itself). In-context learning offers a seemingly different approach to few-shot and meta-learning (Brown et al., 2020), but as of today the exact mechanisms of how it works are not fully understood. It is thus of great interest to understand what makes Transformers pay attention to their context, what the mechanisms are, and under which circumstances they come into play (Chan et al., 2022b; Olsson et al., 2022).

In this paper, we aim to bridge the gap between in-context and meta-learning, and show that in-context learning in Transformers can be an emergent property approximating gradient-based few-shot learning within their forward pass, see Figure 1. For this to be realized, we show how Transformers (1) construct a loss function dependent on the data given in sequence and (2) learn based on gradients of that loss. We will first focus on the latter, the more elaborate learning task, in Sections 2 and 3, after which we provide evidence for the former in Section 4.

We summarize our contributions as follows¹:

- We construct explicit weights for a linear self-attention layer that induces an update identical to a single step of gradient descent (GD) on a mean squared error loss. Additionally, we show how several self-attention layers can iteratively perform curvature correction, improving on plain gradient descent.
- When optimized on linear regression datasets, we demonstrate that linear self-attention-only Transformers either converge to our weight construction and therefore implement gradient descent, or generate linear models that closely align with models trained by GD, both in in- and out-of-distribution validation tasks.
- By incorporating multi-layer perceptrons (MLPs) into the Transformer architecture, we enable solving nonlinear regression tasks within Transformers by showing that this is equivalent to learning a linear model on deep representations. We discuss connections to kernel regression as well as nonparametric kernel smoothing methods. Empirically, we compare meta-learned MLPs and a single step of GD on their output layer with trained Transformers and demonstrate striking similarities between the identified solutions.
- We resolve the dependency on the specific token construction by providing evidence that learned Transformers first encode incoming tokens into a format amenable to the in-context gradient descent learning that occurs in the later layers of the Transformer.

¹Main experiments can be reproduced with notebooks provided under the following link: https://github.com/google-research/self-organising-systems/tree/master/transformers_learn_icl_by_gd

[Figure 1, right panel: loss vs. GD steps / Transformer layers (0 to 40) for gradient descent and a trained Transformer.]

Figure 1. Illustration of our hypothesis: gradient-based optimization and attention-based in-context learning are equivalent. Left: Learning a neural network output layer by gradient descent on a dataset $D_{\text{train}}$. The task-shared meta-parameters $\theta$ are obtained by meta-learning with the goal that after adjusting the neural network output layer, the model generalizes well on unseen data. Center: Illustration of a Transformer that adjusts its query prediction on the data given in-context, i.e. $t_\theta(x_{\text{query}}; D_{\text{context}})$. The weights of the Transformer are optimized to predict the next token $y_{\text{query}}$. Right: Our results confirm the hypothesis that learning with $K$ steps of gradient descent on a dataset $D_{\text{train}}$ (green part of the left plot) matches trained Transformers with $K$ linear self-attention layers (central plot) when given $D_{\text{train}}$ as in-context data $D_{\text{context}}$.

These findings allow us to connect learning Transformer weights and the concept of meta-learning a learning algorithm (Schmidhuber, 1987; Hinton & Plaut, 1987; Bengio et al., 1990; Chalmers, 1991; Schmidhuber, 1992; Thrun & Pratt, 1998; Hochreiter et al., 2001; Andrychowicz et al., 2016; Ba et al., 2016; Kirsch & Schmidhuber, 2021). In this extensive research field, meta-learning is typically regarded as learning that takes place on various time scales, namely fast and slow. The slowly changing parameters control and prepare for fast adaptation in response to sudden changes in the incoming data, e.g. a context switch. Notably, we build heavily on the concept of fast weights (Schmidhuber, 1992), which has been shown to be equivalent to linear self-attention (Schlag et al., 2021), and show how optimized Transformers implement interpretable learning algorithms within their weights. Another related meta-learning concept, termed MAML (Finn et al., 2017), aims to meta-learn a deep neural network initialization which allows for fast adaptation on novel tasks. It has been shown that in many circumstances, the solution found can be approximated well when only adapting the output layer, i.e. learning a linear model on meta-learned deep data representations (Finn et al., 2017; Finn & Levine, 2018; Gordon et al., 2019; Lee et al., 2019; Rusu et al., 2019; Raghu et al., 2020; von Oswald et al., 2021). In Section 3, we show the equivalence of this framework to in-context learning implemented in a common Transformer block, i.e.
when combining self-attention layers with a multi-layer perceptron. In light of meta-learning, we show how optimizing Transformer weights can be regarded as learning on two time scales. More concretely, we find that solely through the pressure to predict correctly, Transformers discover learning algorithms inside their forward computations, effectively meta-learning a learning algorithm. Recently, this concept of an emergent optimizer within a learned neural network, such as a Transformer, has been termed mesa-optimization (Hubinger et al., 2019). We find and describe one possible realization of this concept and hypothesize that the in-context learning capabilities of language models emerge through mechanisms similar to the ones we discuss here.

Transformers come in different shapes and sizes, operate on vastly different domains, and exhibit varying forms of phase transitions of in-context learning (Kirsch et al., 2022; Chan et al., 2022a), suggesting variance and significant complexity of the underlying learning mechanisms. As a result, we expect our findings on linear self-attention-only Transformers to explain only a limited part of a complex process, and the mechanism we describe may be one of many giving rise to in-context learning. Nevertheless, our approach provides an intriguing perspective on, and novel evidence for, an in-context learning mechanism that significantly differs from existing mechanisms based on associative memory (Ramsauer et al., 2020) or on the copying mechanism termed induction heads identified by Olsson et al. (2022). We therefore state the following hypothesis.

**Hypothesis 1** (Transformers learn in-context by gradient descent). When training Transformers on auto-regressive tasks, in-context learning in the Transformer forward pass is implemented by gradient-based optimization of an implicit auto-regressive inner loss constructed from its in-context data.
We acknowledge work done in parallel investigating the same hypothesis. Akyürek et al. (2023) put forward a weight construction based on a chain of Transformer layers (including MLPs) that together implement a single step of gradient descent with weight decay. Similar to work done by Garg et al. (2022), they then show that trained Transformers match the performance of models obtained by gradient descent. Nevertheless, it is not clear that optimization finds Transformer weights that coincide with their construction. Here, we present a much simpler construction that builds on Schlag et al. (2021) and only requires a single linear self-attention layer to implement a step of gradient descent. This allows us to (1) show that optimizing self-attention-only Transformers finds weights that match our weight construction (Proposition 1), demonstrating its practical relevance, and (2) explain in-context learning in shallow two-layer Transformers intensively studied by Olsson et al. (2022). Therefore, although related work provides comprehensive empirical evidence that Transformers indeed seem to implement gradient-descent-based learning on the data given in-context, in the following we present a mechanistic verification of this hypothesis and provide compelling evidence that our construction, which implements GD in a Transformer forward pass, is found in practice.

## 2. Linear self-attention can emulate gradient descent on a linear regression task

We start by reviewing a standard multi-head self-attention (SA) layer with parameters $\theta$. A SA layer updates each element $e_j$ of a set of tokens $\{e_1, \dots, e_N\}$ according to

$$e_j \leftarrow e_j + \mathrm{SA}_\theta(j, \{e_1, \dots, e_N\}) = e_j + \sum_h P_h V_h \,\mathrm{softmax}(K_h^\top q_{h,j}) \tag{1}$$

with $P_h$, $V_h$, $K_h$ the projection, value and key matrices, respectively, and $q_{h,j}$ the query, all for the $h$-th head. To simplify the presentation, we omit bias terms here and throughout. The columns of the value matrix $V_h = [v_{h,1}, \dots, v_{h,N}]$ and the key matrix $K_h = [k_{h,1}, \dots, k_{h,N}]$ consist of vectors $v_{h,i} = W_{h,V} e_i$ and $k_{h,i} = W_{h,K} e_i$; likewise, the query is produced by linearly projecting the tokens, $q_{h,j} = W_{h,Q} e_j$. The parameters $\theta = \{P_h, W_{h,V}, W_{h,K}, W_{h,Q}\}_h$ of a SA layer consist of all the projection matrices of all heads.

The self-attention layer described above corresponds to the one used in the standard Transformer model. Following Schlag et al. (2021), we now introduce our first (and only) departure from the standard model and omit the softmax operation in equation 1, leading to the linear self-attention (LSA) layer

$$e_j \leftarrow e_j + \mathrm{LSA}_\theta(j, \{e_1, \dots, e_N\}) = e_j + \sum_h P_h V_h K_h^\top q_{h,j}.$$

We next show that with some simple manipulations we can relate the update performed by an LSA layer to one step of gradient descent on a linear regression loss.

### Data transformations induced by gradient descent

We now introduce a reference linear model $y(x) = Wx$ parameterized by the weight matrix $W \in \mathbb{R}^{N_y \times N_x}$, and a training dataset $D = \{(x_i, y_i)\}_{i=1}^N$ comprising input samples $x_i \in \mathbb{R}^{N_x}$ and respective labels $y_i \in \mathbb{R}^{N_y}$. The goal of learning is to minimize the squared-error loss:

$$L(W) = \frac{1}{2N} \sum_{i=1}^N \lVert W x_i - y_i \rVert^2. \tag{2}$$

One step of gradient descent on $L$ with learning rate $\eta$ yields the weight change

$$\Delta W = -\eta \nabla_W L(W) = -\frac{\eta}{N} \sum_{i=1}^N (W x_i - y_i) x_i^\top. \tag{3}$$

Considering the loss after changing the weights, we obtain

$$L(W + \Delta W) = \frac{1}{2N} \sum_{i=1}^N \lVert (W + \Delta W) x_i - y_i \rVert^2 = \frac{1}{2N} \sum_{i=1}^N \lVert W x_i - (y_i - \Delta y_i) \rVert^2, \tag{4}$$

where we introduced the transformed targets $y_i - \Delta y_i$ with $\Delta y_i = \Delta W x_i$. Thus, we can view the outcome of a gradient descent step as an update to our regression loss (equation 2) in which the data, and not the weights, are updated. Note that this formulation is closely linked to prediction by nonparametric kernel smoothing; see Appendix A.8 for a discussion.

Returning to self-attention mechanisms and Transformers, we consider an in-context learning problem where we are given $N$ context tokens together with an extra query token, indexed by $N + 1$.
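Equations (3) and (4) can be checked numerically. The NumPy sketch below (function names, sizes and the random seed are our own illustrative choices) takes one GD step on a random noiseless task and confirms that moving the weights by $\Delta W$ gives the same loss as keeping $W$ fixed and moving the targets to $y_i - \Delta W x_i$. Samples are stored as columns, matching the $N_x \times N$ data matrix $X$ used later in the text.

```python
import numpy as np


def loss(W, X, Y):
    """Eq. (2): L(W) = 1/(2N) sum_i ||W x_i - y_i||^2, with samples as columns."""
    N = X.shape[1]
    return 0.5 / N * np.sum((W @ X - Y) ** 2)


def gd_delta(W, X, Y, eta):
    """Eq. (3): Delta W = -(eta / N) sum_i (W x_i - y_i) x_i^T."""
    N = X.shape[1]
    return -(eta / N) * (W @ X - Y) @ X.T


rng = np.random.default_rng(0)
n_x, n_y, N, eta = 10, 1, 10, 0.1
X = rng.uniform(-1.0, 1.0, size=(n_x, N))          # columns are inputs x_i
Y = rng.normal(size=(n_y, n_x)) @ X                 # noiseless linear teacher
W = rng.normal(size=(n_y, n_x))                     # current model

delta_W = gd_delta(W, X, Y, eta)
delta_Y = delta_W @ X                               # columns: Delta y_i = Delta W x_i

loss_weights_moved = loss(W + delta_W, X, Y)        # left-hand side of eq. (4)
loss_targets_moved = loss(W, X, Y - delta_Y)        # right-hand side of eq. (4)
print(np.isclose(loss_weights_moved, loss_targets_moved))  # True
```

The equality is exact algebra, $(W + \Delta W)x_i - y_i = Wx_i - (y_i - \Delta y_i)$, so it holds for any step size; the small $\eta$ only ensures that the step also lowers the loss.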
In terms of our linear regression problem, the $N$ context tokens $e_j = (x_j, y_j) \in \mathbb{R}^{N_x + N_y}$ correspond to the $N$ training points in $D$, and the $(N+1)$-th token $e_{N+1} = (x_{N+1}, y_{N+1}) = (x_{\text{test}}, \hat{y}_{\text{test}}) = e_{\text{test}}$ corresponds to the test input $x_{\text{test}}$ and the corresponding prediction $\hat{y}_{\text{test}}$. We use the terms training and in-context data interchangeably, as well as query and test token/data, as we establish their equivalence now.

### Transformations induced by gradient descent and a linear self-attention layer can be equivalent

We have re-cast the task of learning a linear model as directly modifying the data, instead of explicitly computing and returning the weights of the model (equation 4). We proceed to establish a connection between self-attention and gradient descent. We provide a construction where learning takes place simultaneously by directly updating all tokens, including the test token, through a linear self-attention layer. In other words, the token produced in response to a query (test) token is transformed from its initial value $W_0 x_{\text{test}}$, where $W_0$ is the initial value of $W$, to the post-learning prediction $\hat{y} = (W_0 + \Delta W) x_{\text{test}}$ obtained after one gradient descent step.

**Proposition 1.** Given a 1-head linear attention layer and the tokens $e_j = (x_j, y_j)$, for $j = 1, \dots, N$, one can construct key, query and value matrices $W_K$, $W_Q$, $W_V$ as well as the projection matrix $P$ such that a Transformer step on every token $e_j$ is identical to the gradient-induced dynamics $e_j \leftarrow (x_j, y_j) + (0, -\Delta W x_j) = (x_j, y_j) + P V K^\top q_j$, such that $e_j = (x_j, y_j - \Delta y_j)$. For the test data token $(x_{N+1}, y_{N+1})$ the dynamics are identical.

The simple construction can be found in Appendix A.1, and we denote the corresponding self-attention weights by $\theta_{\text{GD}}$. Below, we provide some additional insights on what is needed to implement the provided LSA-layer weight construction, and further details on what it can achieve.

**Full self-attention.**
Our dynamics model training is based on in-context tokens only, i.e., only $e_1, \dots, e_N$ are used for computing key and value matrices; the query token $e_{N+1}$ (containing test data) is excluded. This leads to a linear function in $x_{\text{test}}$ as well as to the correct $\Delta W$, induced by gradient descent on a loss consisting only of the training data. This is a minor deviation from full self-attention. In practice, this modification can be dropped, which corresponds to assuming that the underlying initial weight matrix is zero, $W_0 \approx 0$; this makes $\Delta W$ in equation 8 independent of the test token even if the test token is incorporated in the key and value matrices. In our experiments, we see that these assumptions are met when initializing the attention weights $\theta$ to small values.

**Reading out predictions.** When initializing the $y$-entry of the test-data token with $-W_0 x_{N+1}$, i.e. $e_{\text{test}} = (x_{\text{test}}, -W_0 x_{\text{test}})$, the test-data prediction $\hat{y}$ can easily be read out by multiplying the $y$-entry of the updated token by $-1$, since $-y_{N+1} + \Delta y_{N+1} = -(y_{N+1} - \Delta y_{N+1}) = -y_{N+1} + \Delta W x_{N+1} = (W_0 + \Delta W) x_{N+1}$. This can easily be done by a final projection matrix, which incidentally is usually found in Transformer architectures. Importantly, we see that a single head of self-attention is sufficient to transform our training targets as well as the test prediction simultaneously.

**Uniqueness.** We note that the construction is not unique; in particular, it is only required that the products $P W_V$ as well as $W_K^\top W_Q$ match the construction. Furthermore, since no nonlinearity is present, any rescaling $s$ of the matrix products, i.e., $P W_V s$ and $W_K^\top W_Q / s$, leads to an equivalent result. If we correct for these equivalent formulations, we can experimentally verify that the weights of our learned Transformers indeed match the presented construction.

**Meta-learned task-shared learning rates.**
When training self-attention parameters $\theta$ across a family of in-context learning tasks $\tau$, where the data $(x_{\tau,i}, y_{\tau,i})$ follows a certain distribution, the learning rate can be implicitly (meta-)learned such that an optimal loss reduction (averaged over tasks) is achieved given a fixed number of update steps. In our experiments, we find this to be the case. This kind of meta-learning to improve upon plain gradient descent has been leveraged in numerous previous approaches for deep neural networks (Li et al., 2017; Lee & Choi, 2018; Park & Oliva, 2019; Zhao et al., 2020; Flennerhag et al., 2020).

**Task-specific data transformations.** A self-attention layer is in principle further capable of exploiting statistics in the current training data samples, beyond modeling task-shared curvature information in $\theta$. More concretely, an LSA layer updates an input sample according to a data transformation $x_j \leftarrow x_j + \Delta x_j = (I + P(X) V(X) K(X)^\top W_Q) x_j = H_\theta(X) x_j$, with $X$ the $N_x \times N$ input training data matrix, when neglecting influences by target data $y_i$. Through $H_\theta(X)$, an LSA layer can encode in $\theta$ an algorithm for carrying out data transformations which depend on the actual input training samples in $X$. In our experiments, we see that trained self-attention learners employ a simple form of $H(X)$ and that this leads to substantial speed-ups for GD and TF learning.

## 3. Trained Transformers do mimic gradient descent on linear regression tasks

We now experimentally investigate whether trained attention-based models implement gradient-based in-context learning in their forward passes. We gradually build up from single linear self-attention layers to multi-layer nonlinear models, approaching full Transformers. In this section, we follow the assumption of Proposition 1 tightly and construct our tokens by concatenating input and target data, $e_j = (x_j, y_j)$ for $1 \le j \le N$, and our query token by concatenating the test input and a zero vector, $e_{N+1} = (x_{\text{test}}, 0)$.
We show how to lift this assumption in the last section of the paper.

[Figure 2 panels: loss vs. training steps (GD, trained TF); preds diff and model diff vs. training steps; test on larger inputs, loss vs. $\alpha$ where $x \sim U(-\alpha, \alpha)$ (GD, interpolated, trained TF).]

Figure 2. Comparing one step of GD with a trained single linear self-attention layer. Outer left: Trained single LSA layer performance is identical to that of gradient descent. Center left: Almost perfect alignment of GD and the model generated by the SA layer after training, measured by cosine similarity and the L2 distance between models as well as their predictions. Center right: Identical loss of GD, the LSA layer model as well as the model obtained by interpolating between the construction and the optimized LSA layer weights for different $N = N_x$. Outer right: The trained LSA layer, gradient descent and their interpolation show identical loss (in log-scale) when provided input data different from that seen during training, i.e. with scale $\alpha \neq 1$. We display the mean/std. or the single runs of 5 seeds.

The prediction $\hat{y}_\theta(\{e_{\tau,1}, \dots, e_{\tau,N}\}, e_{\tau,N+1})$ of the attention-based model, which depends on all tokens and on the parameters $\theta$, is read out from the $y$-entry of the updated $(N+1)$-th token as explained in the previous section. The objective of training, visualized in Figure 1, is to minimize the expected squared prediction error, averaged over tasks, $\min_\theta \mathbb{E}_\tau\big[\lVert \hat{y}_\theta(\{e_{\tau,1}, \dots, e_{\tau,N}\}, e_{\tau,N+1}) - y_{\tau,\text{test}} \rVert^2\big]$. We achieve this by minibatch online minimization (by Adam (Kingma & Ba, 2014)): at every optimization step, we construct a batch of $B$ novel training tasks and take a step of stochastic gradient descent on the loss function

$$L(\theta) = \frac{1}{B} \sum_{\tau=1}^{B} \lVert \hat{y}_\theta(\{e_{\tau,i}\}_{i=1}^N, e_{\tau,N+1}) - y_{\tau,\text{test}} \rVert^2, \tag{5}$$

where each task (context) $\tau$ consists of in-context training data $D_\tau = \{(x_{\tau,i}, y_{\tau,i})\}_{i=1}^N$ and a test point $(x_{\tau,N+1}, y_{\tau,N+1})$, which we use to construct our tokens $\{e_{\tau,i}\}_{i=1}^{N+1}$ as described above.
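As a hedged sketch of objective (5) for the simplest predictor, the code below evaluates the task-averaged squared test error of the one-step GD control model (started from $W_0 = 0$) as a function of $\eta$, and picks the best value on a grid, in the spirit of the line search over $\eta$ described for the control model below. Tasks follow the distribution given in the next paragraph ($W_\tau \sim \mathcal{N}(0, I)$, $x \sim U(-1, 1)^{n_I}$, $N = n_I = 10$, $n_O = 1$); the grid, the number of tasks and all function names are our own illustrative choices, not the paper's code.

```python
import numpy as np


def sample_tasks(rng, n_tasks, n_context=10, n_in=10, n_out=1):
    """Noiseless linear tasks: W_tau ~ N(0, I), x ~ U(-1, 1)^n_in, y = W_tau x.

    Returns X of shape (tasks, n_context + 1, n_in) and Y of shape
    (tasks, n_context + 1, n_out); the last row of each task is the test point.
    """
    W = rng.normal(size=(n_tasks, n_out, n_in))
    X = rng.uniform(-1.0, 1.0, size=(n_tasks, n_context + 1, n_in))
    Y = np.einsum("toi,tni->tno", W, X)
    return X, Y


def one_step_gd_prediction(X, Y, eta):
    """Test prediction after one GD step from W0 = 0 on the context points (eq. 3)."""
    X_ctx, Y_ctx, x_test = X[:, :-1], Y[:, :-1], X[:, -1]
    n_context = X_ctx.shape[1]
    # With W0 = 0: Delta W = (eta / N) * sum_i y_i x_i^T for every task.
    delta_W = (eta / n_context) * np.einsum("tno,tni->toi", Y_ctx, X_ctx)
    return np.einsum("toi,ti->to", delta_W, x_test)


def objective(X, Y, eta):
    """Squared test error averaged over tasks, analogous to eq. (5)."""
    err = one_step_gd_prediction(X, Y, eta) - Y[:, -1]
    return np.mean(np.sum(err ** 2, axis=-1))


rng = np.random.default_rng(0)
X, Y = sample_tasks(rng, n_tasks=10_000)

# Grid-based line search over the learning rate of the one-step GD control model.
etas = np.linspace(0.0, 3.0, 301)
losses = np.array([objective(X, Y, eta) for eta in etas])
eta_star = etas[np.argmin(losses)]
print(f"best eta on the grid: {eta_star:.2f}")
```

Because the control prediction is linear in $\eta$, the objective is a quadratic in $\eta$; a grid is used here only for transparency, and a closed-form minimizer would work equally well.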
We denote the optimal parameters found by this optimization process by $\theta^*$. In our setup, finding $\theta^*$ may be thought of as meta-learning, while learning a particular task $\tau$ corresponds to simply evaluating the model $\hat{y}_\theta(\{e_{\tau,1}, \dots, e_{\tau,N}\}, e_{\tau,N+1})$. Note that, since novel tasks are constructed at every step, we never see the exact same training task twice during training. See Appendix A.12, especially Figure 16, for an analysis when using a fixed dataset size which we cycle over during training.

We focus on solvable tasks and, similarly to Garg et al. (2022), generate data for each task using a teacher model with parameters $W_\tau \sim \mathcal{N}(0, I)$. We then sample $x_{\tau,i} \sim U(-1, 1)^{n_I}$ and construct targets using the task-specific teacher model, $y_{\tau,i} = W_\tau x_{\tau,i}$. In the majority of our experiments we set the dimensions to $N = n_I = 10$ and $n_O = 1$. Since we use a noiseless teacher for simplicity, we can expect our regression tasks to be well-posed and analytically solvable. Note that we only compute a loss on the Transformer's last token, which stands in contrast to usual autoregressive training and to the training setup of Garg et al. (2022). Full details and results for training with a fixed training set size may be found in Appendix A.12.

### One step of gradient descent vs. a single trained self-attention layer

Our first goal is to investigate whether a trained single linear self-attention layer can be explained by the provided weight construction that implements GD. To that end, we compare the predictions made by an LSA layer with trained weights $\theta^*$ (which minimize equation 5) and with constructed weights $\theta_{\text{GD}}$ (which satisfy Proposition 1). Recall that an LSA layer yields the prediction $\hat{y}_\theta(x_{\text{test}}) = e_{N+1} + \mathrm{LSA}_\theta(\{e_1, \dots, e_N\}, e_{N+1}) = W_{\theta,D} x_{\text{test}}$, which is linear in $x_{\text{test}}$. We denote by $W_{\theta,D}$ the matrix generated by the LSA layer following the construction provided in Proposition 1, with query token $e_{N+1}$ set such that the initial prediction is zero, $\hat{y}_{\text{test}} = 0$.
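The mechanics of Proposition 1 can be sketched in a few lines of NumPy. The weight matrices below are one choice we made that satisfies the proposition (keys and queries read only the $x$-part, values produce $(0, W_0 x_i - y_i)$, and $P = (\eta/N) I$); they are not necessarily the matrices of Appendix A.1, and all function names are hypothetical. Keys and values are computed from the $N$ context tokens only, and the test token is initialised as $(x_{\text{test}}, -W_0 x_{\text{test}})$. The sketch checks that the negated $y$-entry of the updated test token equals $(W_0 + \Delta W) x_{\text{test}}$ and that the training targets become $y_i - \Delta y_i$.

```python
import numpy as np


def lsa_gd_weights(n_x, n_y, W0, eta, n_context):
    """One (hypothetical) choice of single-head LSA weights realising one GD step.

    Keys and queries read only the x-part of a token, values produce
    (0, W0 x_i - y_i), and P rescales by eta / N.
    """
    d = n_x + n_y
    W_K = np.zeros((d, d))
    W_K[:n_x, :n_x] = np.eye(n_x)
    W_Q = W_K.copy()
    W_V = np.zeros((d, d))
    W_V[n_x:, :n_x] = W0
    W_V[n_x:, n_x:] = -np.eye(n_y)
    P = (eta / n_context) * np.eye(d)
    return W_K, W_Q, W_V, P


def lsa_layer(E, W_K, W_Q, W_V, P, n_context):
    """Linear self-attention without softmax: e_j <- e_j + P V K^T q_j.

    E holds tokens as columns, shape (d, N + 1). Keys and values are built
    from the first n_context (training) tokens only; every token is a query.
    """
    C = E[:, :n_context]
    K, V, Q = W_K @ C, W_V @ C, W_Q @ E
    return E + P @ V @ K.T @ Q


rng = np.random.default_rng(1)
N, n_x, n_y, eta = 10, 10, 1, 0.5
X = rng.uniform(-1.0, 1.0, size=(n_x, N + 1))      # columns x_1..x_N, then x_test
Y = rng.normal(size=(n_y, n_x)) @ X                 # noiseless teacher
W0 = 0.1 * rng.normal(size=(n_y, n_x))              # arbitrary initial linear model

# Tokens e_j = (x_j, y_j); the test token's y-entry is initialised to -W0 x_test.
E = np.vstack([X, Y])
E[n_x:, N] = -W0 @ X[:, N]

E_new = lsa_layer(E, *lsa_gd_weights(n_x, n_y, W0, eta, N), n_context=N)

# Reference: one explicit GD step, eq. (3).
X_ctx, Y_ctx = X[:, :N], Y[:, :N]
delta_W = -(eta / N) * (W0 @ X_ctx - Y_ctx) @ X_ctx.T

y_hat_lsa = -E_new[n_x:, N]                         # read out by multiplying by -1
y_hat_gd = (W0 + delta_W) @ X[:, N]
print(np.allclose(y_hat_lsa, y_hat_gd))             # True
print(np.allclose(E_new[n_x:, :N], Y_ctx - delta_W @ X_ctx))  # True
```

Setting `W0` to zeros gives the control model $\hat{y}_{\theta_{\text{GD}}}$ discussed next, i.e. one GD step from $W_0 = 0$ with the test token $(x_{\text{test}}, 0)$.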
We compare $\hat{y}_{\theta^*}(x_{\text{test}})$ to the prediction of the control LSA $\hat{y}_{\theta_{\text{GD}}}(x_{\text{test}})$, which under our token construction corresponds to a linear model trained by one step of gradient descent starting from $W_0 = 0$. For this control model, we determine the optimal learning rate $\eta$ by minimizing $L(\eta)$ over a training set of $10^4$ tasks through line search, with $L(\eta)$ defined analogously to equation 5.

More concretely, to compare trained and constructed LSA layers, we sample $T_{\text{val}} = 10^4$ validation tasks and record the following quantities, averaged over validation tasks: (1) the difference in predictions measured with the L2 norm, $\lVert \hat{y}_{\theta^*}(x_{\tau,\text{test}}) - \hat{y}_{\theta_{\text{GD}}}(x_{\tau,\text{test}}) \rVert$, (2) the cosine similarity between the sensitivities $\frac{\partial \hat{y}_{\theta_{\text{GD}}}(x_{\tau,\text{test}})}{\partial x_{\text{test}}}$ and $\frac{\partial \hat{y}_{\theta^*}(x_{\tau,\text{test}})}{\partial x_{\text{test}}}$, as well as (3) their difference $\Big\lVert \frac{\partial \hat{y}_{\theta_{\text{GD}}}(x_{\tau,\text{test}})}{\partial x_{\text{test}}} - \frac{\partial \hat{y}_{\theta^*}(x_{\tau,\text{test}})}{\partial x_{\text{test}}} \Big\rVert$, again according to the L2 norm. Since both predictors are linear in $x_{\text{test}}$, these sensitivities are exactly the explicit models computed by the two algorithms. We show the results of these comparisons in Figure 2. We find excellent agreement between the two models over a wide range of hyperparameters. We note that, as we do not have direct access to the initialization of $W$ in the attention-based learners (it is hidden in $\theta$), we cannot expect the models to agree exactly.

Although the above metrics are important to show similarities between the resulting learned models (in-context vs. gradient-based), the underlying algorithms could still be different. We therefore carry out an extended set of analyses:

1. **Interpolation.** We take inspiration from recent work (Benzing et al., 2022; Entezari et al., 2021) that showed approximate equivalence of models found by SGD after permuting weights within the trained neural networks. Since our models are deep linear networks with respect to $x_{\text{test}}$, we only correct for scaling mismatches between the two models, in this case the construction that implements GD and the trained weights.
As shown in Figure 2, we observe (and can actually inspect by eye, see Appendix Figure 9) that a simple scaling correction on the trained weights is enough to recover the weight construction implementing GD. This leads to an identical loss of GD, the trained Transformer and the linearly interpolated weights $\theta_I = (\theta^* + \theta_{\text{GD}})/2$. See Appendix A.3 for details on how our weight correction and interpolation are obtained.
2. **Out-of-distribution validation tasks.** To test if our in-context learner has found a generalizable update rule, we investigate how GD, the trained LSA layer and its interpolation behave when provided with in-context data in regimes different from the ones used during training. We therefore visualize the loss increase when (1) sampling the input data from $U(-\alpha, \alpha)^{N_x}$ or (2) scaling the teacher weights by $\alpha$ as $\alpha W$ when sampling validation tasks. For both cases, we set $\alpha = 1$ during training. We again observe that, when training a single linear self-attention Transformer, for both interventions the Transformer performs equally to gradient descent outside of these training setups, see Figure 2 as well as Appendix Figure 6. Note that the loss obtained through gradient descent also starts degrading quickly outside the training regime. Since we tune the learning rate for the input range $[-1, 1]$ and one gradient step, tasks with a larger input range will have higher curvature, and the learning rate that is optimal for the smaller range will lead to divergence and a drastic increase in loss, also for GD.
3. **Repeating the LSA update.** Since we claim that a single trained LSA layer implements a GD-like learning rule, we further test its behavior when applying it repeatedly, not only once as in training. After we correct the learning rate of both algorithms, i.e. for GD and the trained Transformer, with a dampening parameter $\lambda = 0.75$ (details in Appendix A.6), we see an identical loss decrease for both GD and the Transformer, see Figure 1.
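For the constructed (untrained) weights, repeating a GD-implementing layer provably yields further GD steps, which can be checked directly. The sketch below, assuming $W_0 = 0$ and a fixed learning rate, applies the same LSA layer $K$ times to tokens $(x_j, y_j)$ and a query token $(x_{\text{test}}, 0)$, and compares the negated test-token $y$-entry with $K$ steps of plain GD. It does not model the dampening $\lambda$ of Appendix A.6, which concerns the trained Transformer; function names and sizes are illustrative.

```python
import numpy as np


def lsa_gd_layer(E, eta, n_x, n_context):
    """Constructed LSA layer with W0 = 0: y_j <- y_j - (eta / N) sum_i y_i x_i^T x_j.

    E holds tokens as columns, shape (n_x + n_y, N + 1), test token last.
    Keys and values come from the context tokens; only y-entries change.
    """
    X_ctx, Y_ctx = E[:n_x, :n_context], E[n_x:, :n_context]
    scores = X_ctx.T @ E[:n_x]                      # (N, N + 1), entries x_i^T x_j
    E_out = E.copy()
    E_out[n_x:] -= (eta / n_context) * Y_ctx @ scores
    return E_out


def gradient_descent(X, Y, eta, steps):
    """Plain GD from W = 0 on L(W) = 1/(2N) sum ||W x_i - y_i||^2."""
    W = np.zeros((Y.shape[0], X.shape[0]))
    for _ in range(steps):
        W -= (eta / X.shape[1]) * (W @ X - Y) @ X.T
    return W


rng = np.random.default_rng(2)
N, n_x, n_y, eta, K = 10, 10, 1, 0.5, 5
X = rng.uniform(-1.0, 1.0, size=(n_x, N + 1))
Y = rng.normal(size=(n_y, n_x)) @ X
E = np.vstack([X, Y])
E[n_x:, N] = 0.0                                    # query token (x_test, 0)

for _ in range(K):                                  # the same layer, applied K times
    E = lsa_gd_layer(E, eta, n_x, n_context=N)

W_gd = gradient_descent(X[:, :N], Y[:, :N], eta, K)
print(np.allclose(-E[n_x:, N], W_gd @ X[:, N]))     # True
```

After $k$ applications the context $y$-entries hold the residuals $y_i - W_k x_i$ of the $k$-th GD iterate, which is why the unchanged layer keeps computing the correct next gradient step.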
To conclude, we present evidence that optimizing a single LSA layer to solve linear regression tasks finds weights that (approximately) coincide with the LSA-layer weight construction of Proposition 1, hence implementing a step of gradient descent and leading to the same learning capabilities on in- and out-of-distribution tasks. We comment on the random-seed-dependent phase transition of the loss during training in Appendix A.11.

### Multiple steps of gradient descent vs. multiple layers of self-attention

We now turn to deep linear self-attention-only Transformers. The construction we put forth in Proposition 1 can be immediately stacked up over $K$ layers; in this case, the final prediction can be read out from the last layer as before by negating the $y$-entry of the last test token: $-y_{N+1} + \sum_{k=1}^K \Delta y_{k,N+1} = -\big(y_{N+1} - \sum_{k=1}^K \Delta y_{k,N+1}\big) = -y_{N+1} + \sum_{k=1}^K \Delta W_k x_{N+1}$, where $y_{k,N+1}$ are the test token values at layer $k$, $\Delta y_{k,N+1}$ is the change in the $y$-entry of the test token after applying the $k$-th step of self-attention, and $\Delta W_k$ is the $k$-th implicit change in the underlying linear model parameters $W$. When optimizing such Transformers with $K$ layers, we observe that these models generally outperform $K$ steps of plain gradient descent, see Figure 3. Their behavior is, however, well described by a variant of gradient descent for which we tune a single parameter $\gamma$ defined through the transformation function $H(X)$, which transforms the input data according to $x_j \leftarrow H(X) x_j$ with $H(X) = I - \gamma X X^\top$. We term this gradient descent variant GD++, which we explain and analyze in Appendix A.10.

To analyze the effect of adding more layers to the architecture, we first turn to the arguably simplest extension of a single SA layer and analyze a recurrent or looped 2-layer LSA model. Here, we simply apply the same layer (with the same weights) repeatedly, drawing the analogy to learning an iterative algorithm that applies the same logic multiple times.
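The text fixes only the form $H(X) = I - \gamma X X^\top$; how GD++ interleaves it with the GD step is given in Appendix A.10 and is not reproduced here. As a minimal illustration of why such an input transformation acts as a curvature correction: $H(X)$ shares eigenvectors with $X X^\top$, so each eigenvalue $\lambda$ of $X X^\top$ becomes $\lambda (1 - \gamma\lambda)^2$ after $X \leftarrow H(X) X$. For $\gamma \lambda_{\max} < 1/3$ this map is increasing, and it shrinks large eigenvalues relatively more than small ones, so the condition number drops. The value of $\gamma$ and the data sizes below are our own assumptions.

```python
import numpy as np


def gdpp_input_transform(X, gamma):
    """Apply H(X) = I - gamma X X^T to every input column of X (shape N_x x N)."""
    H = np.eye(X.shape[0]) - gamma * X @ X.T
    return H @ X


def condition_number(X):
    """Ratio of the largest to the smallest eigenvalue of X X^T."""
    eigvals = np.linalg.eigvalsh(X @ X.T)
    return eigvals.max() / eigvals.min()


rng = np.random.default_rng(3)
n_x, N = 10, 40                                     # illustrative sizes
X = rng.uniform(-1.0, 1.0, size=(n_x, N))

# Assumed choice: gamma * lambda_max = 1/4 < 1/3 keeps lambda (1 - gamma lambda)^2
# increasing in lambda, so the eigenvalue ordering is preserved.
lam = np.linalg.eigvalsh(X @ X.T)
gamma = 0.25 / lam.max()
X_pp = gdpp_input_transform(X, gamma)
print(condition_number(X), condition_number(X_pp))
```

A smaller condition number of the input covariance means that a single shared learning rate suits all directions better, which is the sense in which GD++ improves on plain GD.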
Somewhat surprisingly, we find that the trained model surpasses plain gradient descent, which also results in decreasing alignment between the two models (see center left column), and the recurrent Transformer realigns perfectly with GD++ while matching its performance on inand out-of distribution tasks. Again, we can interpolate between the Transformer weights found by optimization and the LSAweight construction with learned η, γ, see Figure 3 & 6. We next consider deeper, non-recurrent 5-layer LSA-only Transformers, with different parameters per layer (i.e. no weight tying). We see that a different GD learning rate as well as γ per step (layer) need to be tuned to match the Transformer performance. This slight modification leads again to almost perfect alignment between the trained TF and GD++ with in this case 10 additional parameters and Transformers Learn In-Context by Gradient Descent (a) Comparing two steps of gradient descent with trained recurrent two-layer Transformers. 0 1000 2000 3000 Training steps GD vs trained TF 0 1000 2000 3000 Training steps Preds diff Model diff GD+ + vs trained TF 0 1000 2000 3000 Training steps Preds diff Model diff 0.5 1.0 1.5 2.0 where x U( , ) Test on larger inputs Interpolated Trained TF (b) Comparing five steps of gradient descent with trained five-layer Transformers. 0 20000 40000 Training steps GD+ + 5 steps GD vs trained TF 0 20000 40000 Training steps Preds diff Model diff GD+ + vs trained TF 0 20000 40000 Training steps Preds diff Model diff 0.5 1.0 1.5 2.0 where x U( , ) Test on larger inputs Figure 3. Far left column: The trained TF performance surpasses standard GD but matches GD++, our GD variant with simple iterative data transformation. On both cases, we tuned the gradient descent learning rates as well as the scalar γ which governs the data transformation H(X). Center left & center right columns: We measure the alignment between the GD as well as the GD++ models and the trained TF. 
In both cases the TF aligns well with GD in the beginning of training but aligns much better with GD++ after training. Far right column: TF performance (in log-scale) mimics the one of GD++ well when testing on OOD tasks (α = 1). loss close to 0, see Figure 3. Nevertheless, we see that the naive correction necessary for model interpolation used in the aforementioned experiments is not enough to interpolate without a loss increase. We leave a search for better weight corrections to future work. We further study Transformers with different depths for recurrent as well as non-recurrent architectures with multiple heads and equipped with MLPs, and find qualitatively equivalent results, see Appendix Figure 7 and Figure 8. Additionally, in Appendix A.9, we provide results obtained when using softmax SA layers as well as Layer Norm, thus essentially retrieving the standard Transformer architecture. We again observe and are able to explain (after slight architectural modifications) good learning performance and as well as alignment with the construction of Proposition 1, though worse than when using linear self-attention. These findings suggest that the incontext learning abilities of the standard Transformer with these common architecture choices can be explained by the gradient-based learning hypothesis explored here. Our findings also question the ubiquitous use of softmax attention, and suggest further investigation is warranted into the performance of linear vs. softmax SA layers in real-world learning tasks, as initiated by Schlag et al. (2021). Transformers solve nonlinear regression tasks by gradient descent on deep data representations It is unreasonable to assume that the astonishing in-context learning flexibility observed in large Transformers is ex- plained by gradient descent on linear models. 
We now show that this limitation can be resolved by incorporating one additional element of fully-fledged Transformers: preceding self-attention layers with MLPs enables learning linear models by gradient descent on deep representations, which motivates our illustration in Figure 1. Empirically, we demonstrate this by solving non-linear sine-wave regression tasks, see Figure 4. Experimental details can be found in Appendix A.7. We state the following.

Proposition 2. Given a Transformer block, i.e. an MLP $m(e)$ which transforms the tokens $e_j = (x_j, y_j)$, followed by an attention layer, we can construct weights that lead to gradient descent dynamics descending $\frac{1}{2N}\sum_{i=1}^{N} \|W m(x_i) - y_i\|^2$. Iteratively applying Transformer blocks can therefore solve kernelized least-squares regression problems with kernel function $k(x, y) = m(x)^\top m(y)$ induced by the MLP $m(\cdot)$.

A detailed discussion of this form of kernel regression, as well as of kernel smoothing with or without softmax nonlinearity through gradient descent on the data, can be found in Appendix A.8. The way MLPs transform data in Transformers diverges from the standard meta-learning approach, where a task-shared input embedding network is optimized by backpropagation-through-training to improve the learning performance of a task-specific readout (e.g., Raghu et al., 2020; Lee et al., 2019; Bertinetto et al., 2019). In contrast, given our token construction in Proposition 1, MLPs in Transformers intriguingly process both inputs and targets. The output of this transformation is then processed by a single linear self-attention layer, which, according to our theory, is capable of implementing gradient descent learning. We compare the performance of this Transformer model, where all weights are learned, to a control Transformer where the final LSA weights are set to the construction $\theta_{\mathrm{GD}}$, which is therefore identical to training an MLP by backpropagation through a GD-updated output layer.
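The equivalence behind Proposition 2 can be checked numerically. The following is a minimal NumPy sketch under our own assumptions, not the authors' code: a fixed random tanh layer `m` stands in for the learned MLP, the test token starts as $(m(x_{\text{test}}), 0)$, i.e. $W_0 = 0$ in the Proposition 1 construction, and the data, η = 0.5 and the helper names `lsa_layer` and `prop1_weights` are hypothetical. It compares the negated target entry of the test token after one LSA layer with (i) one GD step from $W = 0$ on the loss above and (ii) the kernel form $\frac{\eta}{N}\sum_i y_i\, k(x_i, x_{\text{test}})$.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical stand-in for the MLP m(.): one fixed random tanh layer (not trained here).
A_mlp = rng.normal(size=(8, 1))
b_mlp = rng.normal(size=8)


def m(x):
    """Feature map m: inputs of shape (n, 1) -> features of shape (n, 8)."""
    return np.tanh(x @ A_mlp.T + b_mlp)


def lsa_layer(E, WK, WQ, WV, P, n_train):
    """e_j <- e_j + P sum_{i<=N} (W_V e_i)(W_K e_i)^T (W_Q e_j); rows of E are tokens.
    Only the first n_train (training) tokens act as keys and values."""
    K, V = E[:n_train] @ WK.T, E[:n_train] @ WV.T
    Q = E @ WQ.T
    return E + (Q @ K.T) @ V @ P.T


def prop1_weights(dx, dy, eta, n_train):
    """Prop. 1 weights with W_0 = 0: W_K = W_Q = diag(I, 0), W_V = [[0, 0], [0, -I]], P = eta/N I."""
    d = dx + dy
    WK = np.zeros((d, d))
    WK[:dx, :dx] = np.eye(dx)
    WV = np.zeros((d, d))
    WV[dx:, dx:] = -np.eye(dy)
    return WK, WK.copy(), WV, (eta / n_train) * np.eye(d)


# Illustrative sine-wave regression task.
x_train = rng.uniform(-4, 4, size=(20, 1))
y_train = np.sin(x_train)
x_test = np.linspace(-4, 4, 5)[:, None]
eta, n, f = 0.5, len(x_train), 8

# Tokens (m(x_i), y_i) for the training data and (m(x_test), 0) for the queries.
E = np.vstack([np.hstack([m(x_train), y_train]),
               np.hstack([m(x_test), np.zeros((len(x_test), 1))])])
E_out = lsa_layer(E, *prop1_weights(f, 1, eta, n), n_train=n)
pred_lsa = -E_out[n:, f:]  # read out the negated target entry of each test token

# Reference 1: one GD step from W = 0 on 1/(2N) sum_i ||W m(x_i) - y_i||^2.
W = np.zeros((1, f))
W = W - eta * (m(x_train) @ W.T - y_train).T @ m(x_train) / n
pred_gd = m(x_test) @ W.T

# Reference 2: the same prediction written with the kernel k(x, x') = m(x)^T m(x').
kernel = m(x_train) @ m(x_test).T  # shape (N, number of test points)
pred_kernel = (eta / n) * kernel.T @ y_train

print(np.allclose(pred_lsa, pred_gd), np.allclose(pred_lsa, pred_kernel))  # True True
```

Because the attention scores are the inner products $m(x_i)^\top m(x_{\text{test}})$, swapping in a different feature map only changes the kernel, not the update rule.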
Intriguingly, both obtained functions again show surprising similarity in (1) the initial (meta-learned) prediction, read out after the MLP, and (2) the final prediction, after altering the output of the MLP through GD or the self-attention layer. This is again reflected in our alignment measures, which, since the obtained models are now nonlinear w.r.t. $x_{\text{test}}$, only capture the first two terms of the Taylor expansion of the obtained functions. Our results serve as a first demonstration of how MLPs and self-attention layers can interplay to support nonlinear in-context learning, allowing deep data representations to be fine-tuned by gradient descent. Investigating the interplay between MLPs and SA layers in deep TFs is left for future work.

4. Do self-attention layers build regression tasks?

The construction provided in Proposition 1 and the previous experimental section relied on a token structure where both input and output data are concatenated into a single token. This design differs from the way tokens are typically built in most of the related work dealing with simple few-shot learning problems, as well as in, e.g., language modeling. We therefore ask: can we overcome the assumption required in Proposition 1 and allow a Transformer to build the required token construction on its own? This motivates the following proposition.

Proposition 3. Given a 1-head linear or softmax attention layer and the token construction $e_{2j} = (x_j)$, $e_{2j+1} = (0, y_j)$ with a zero vector $0$ of dimension $N_x - N_y$ and concatenated positional encodings, one can construct key, query and value matrices $W_K, W_Q, W_V$ as well as the projection matrix $P$ such that all tokens $e_j$ are transformed into tokens equivalent to the ones required in Proposition 1.

The construction and its discussion can be found in Appendix A.5. To provide evidence that copying is performed in trained Transformers, we optimize a two-layer self-attention circuit on in-context data where alternating tokens include input or output data, i.e.
$e_{2j} = (x_j)$ and $e_{2j+1} = (0, y_j)$. We again measure the loss as well as the mean norm of the partial derivative of the first layer's output w.r.t. the input tokens during training, see Figure 5. First, training speed varies strongly across training seeds, as also reported in Garg et al. (2022). Nevertheless, the Transformer is able to match the performance of a single step (not two steps) of gradient descent. Interestingly, before the Transformer performance jumps to that of GD, token $e_j$ transformed by the first self-attention layer becomes notably dependent on the neighboring token $e_{j+1}$ while staying independent of the others, which we denote as $e_{\text{other}}$ in Figure 5.

[Figure 5 panels. Left: loss vs. training steps (0 to 40,000) for GD (1 step) and the TF (2 layers). Right: norm of the partial derivatives $\partial t(e_j)/\partial e_{j+1}$ and $\partial t(e_j)/\partial e_{\text{other}}$ vs. training steps.]

Figure 5. Training a two-layer SA-only Transformer using the standard token construction. Left: The loss of trained TFs matches one step of GD, not two, and takes an order of magnitude longer to train. Right: Norm of the partial derivatives of the output of the first self-attention layer w.r.t. input tokens. Before the Transformer performance jumps to that of GD, the first layer becomes highly sensitive to the next token.

We interpret this as evidence for a copying mechanism in the Transformer's first layer that merges input and output data into single tokens, as required by Proposition 1. Then, in the second layer, the Transformer performs a single step of GD. Notably, we were not able to train the Transformer with linear self-attention layers, but had to incorporate the softmax operation in the first layer. These preliminary findings support the study of Olsson et al. (2022) showing that softmax self-attention layers easily learn to copy; we confirm this claim, and further show that such copying allows the Transformer to proceed by emulating gradient-based learning in the second or deeper attention layers.
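To make the copying hypothesis concrete, here is a simplified NumPy sketch of a softmax attention layer that merges each input token with the target token that follows it. It is not the exact construction of Appendix A.5: as assumptions for readability, we keep $x$, $y$ and a one-hot positional code in separate slots, let every token attend to its cyclic successor through a shifted positional query, and use a hypothetical inverse temperature `beta` large enough that the softmax is effectively one-hot. The helper names `copy_layer` and `softmax` are ours.

```python
import numpy as np


def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)  # subtract the max for numerical stability
    w = np.exp(z)
    return w / w.sum(axis=axis, keepdims=True)


def copy_layer(E, dx, dy, beta=50.0):
    """One softmax attention head in which token t attends to token t+1 (cyclically)
    and adds that token's y-slot to its own. Token layout: (x-slot, y-slot, one-hot position)."""
    T, d = E.shape
    pos = slice(dx + dy, d)
    shift = np.roll(np.eye(T), 1, axis=0)  # shift @ p_t = p_{(t+1) mod T}
    WQ = np.zeros((d, d))
    WQ[pos, pos] = beta * shift              # query: scaled positional code of the next token
    WK = np.zeros((d, d))
    WK[pos, pos] = np.eye(T)                 # key: own positional code
    WV = np.zeros((d, d))
    WV[dx:dx + dy, dx:dx + dy] = np.eye(dy)  # value: own y-slot only
    Q, K, V = E @ WQ.T, E @ WK.T, E @ WV.T
    A = softmax(Q @ K.T, axis=-1)            # A[t, i] is close to 1 for i = t+1, else close to 0
    return E + A @ V                         # projection P = I


rng = np.random.default_rng(1)
dx, dy, N = 3, 1, 4
xs = rng.normal(size=(N, dx))
ys = xs @ rng.normal(size=(dy, dx)).T
T = 2 * N
E = np.zeros((T, dx + dy + T))
E[0::2, :dx] = xs           # even tokens e_{2j} = (x_j, 0)
E[1::2, dx:dx + dy] = ys    # odd tokens e_{2j+1} = (0, y_j)
E[:, dx + dy:] = np.eye(T)  # concatenated one-hot positional encodings

E_out = copy_layer(E, dx, dy)
merged = E_out[0::2, :dx + dy]  # even tokens now hold (x_j, y_j)
print(np.allclose(merged, np.hstack([xs, ys])))  # True
```

After this layer every even token holds $(x_j, y_j)$, the format Proposition 1 requires; the odd tokens are left essentially unchanged, which is why Appendix A.5 discusses zeroing them out.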
We conclude that copying through (softmax) attention layers is the second crucial mechanism for in-context learning in Transformers. This operation enables Transformers to merge data from different tokens and then to compute dot products of input and target data downstream, allowing in-context learning by gradient descent to emerge.

5. Discussion

Transformers show remarkable in-context learning behavior. Mechanisms based on attention, associative memory and copying by induction heads are currently the leading explanations for this remarkable feature of learning within the Transformer forward pass. In this paper, we put forward the hypothesis, similar to Garg et al. (2022) and Akyürek et al. (2023), that Transformers' in-context learning is driven by gradient descent; in short, Transformers learn to learn by gradient descent based on their context. Viewed through the lens of meta-learning, learning Transformer weights corresponds to the outer loop, which then enables the forward pass to transform tokens by gradient-based optimization.

[Figure 4 panels. Left: learned initial functions and functions after one update (GD init, GD step 1, trained TF init, trained TF step 1) over x from −4 to 4. Center and right: training curves (0 to 40,000 steps) for GD and the trained TF, with preds diff, partial diff and partial cosine.]

Figure 4. Sine wave regression: comparing trained Transformers with meta-learned MLPs for which we adjust the output layer with one step of gradient descent. Left: Plots of the learned initial functions as well as the functions adjusted through either a layer of self-attention or a step of GD. We observe similar initial functions as well as solutions for the trained TF compared to fine-tuning a meta-learned MLP. Center: The performance of the trained Transformer is matched by meta-learned MLPs. Right: We observe strong alignment when comparing the predictions as well as the partial derivatives of the meta-learned MLP and the trained Transformer.

To provide evidence for this hypothesis, we build on Schlag et al.
(2021), who already provide a linear self-attention layer variant with (fast) inner-loop learning by the error-correcting delta rule (Widrow & Hoff, 1960). We diverge from their setting and focus on (in-context) learning, where we specifically construct a dataset by considering neighboring elements in the input sequence as input and target training pairs, see the assumptions of Proposition 1. This construction could be realized, for example, by the model learning to implement a copying layer, see Section 4 and Proposition 3. It allows us to provide a simple construction, different from that of Schlag et al. (2021), that is built solely on the standard linear, and approximately softmax, self-attention layer but still implements gradient-descent-based learning dynamics. We are therefore able to explain gradient-descent-based learning in these standard architectures. Furthermore, we extend this construction based on a single self-attention layer and provide an explanation of how deeper K-layer Transformer models implement principled K-step gradient descent learning, which again deviates from Schlag et al. and allows us to identify that deep Transformers implement GD++, an accelerated version of gradient descent. We highlight that our construction of gradient descent and GD++ is not merely suggestive: when training multi-layer self-attention-only Transformers on simple regression tasks, we provide strong evidence that the construction is actually found. This allows us, at least in our restricted problem settings, to mechanistically explain in-context learning in trained Transformers and its close resemblance to GD observed in related work. Further work is needed to incorporate regression problems with noisy data and weight regularization into our hypothesis. We speculate that some aspects of learning in these settings, e.g. the weight magnitudes, are meta-learned and encoded in the self-attention weights.
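The claim that a K-layer Transformer can implement K steps of gradient descent can be checked with a small NumPy sketch, assuming weight tying (a recurrent Transformer), $W_0 = 0$, and keys/values restricted to the N training tokens; all helper names are hypothetical. Stacking the Proposition 1 layer K times turns each training target into the residual $y_i - W_k x_i$, so the test token's negated target equals the K-step GD prediction. The optional `gamma` argument adds an input update $x_j \leftarrow x_j - \gamma \sum_i x_i x_i^\top x_j$; this is our assumed form of the GD++ data transformation $H(X) = I - \gamma X X^\top$, included only to show where such a correction enters the weights, not a reproduction of the tuned GD++ baselines.

```python
import numpy as np


def lsa_layer(E, WK, WQ, WV, P, n_train):
    """e_j <- e_j + P sum_{i<=N} (W_V e_i)(W_K e_i)^T (W_Q e_j); rows of E are tokens.
    Only the first n_train (training) tokens act as keys and values."""
    K, V = E[:n_train] @ WK.T, E[:n_train] @ WV.T
    Q = E @ WQ.T
    return E + (Q @ K.T) @ V @ P.T


def layer_weights(dx, dy, eta, n_train, gamma=0.0):
    """Prop. 1 weights with W_0 = 0. With gamma > 0 the x-block additionally applies
    x_j <- x_j - gamma * sum_i x_i x_i^T x_j (assumed GD++-style input transformation)."""
    d = dx + dy
    WK = np.zeros((d, d))
    WK[:dx, :dx] = np.eye(dx)
    WV = np.eye(d)
    WV[dx:, dx:] = -np.eye(dy)                  # values (x_i, -y_i)
    P = np.zeros((d, d))
    P[:dx, :dx] = -gamma * np.eye(dx)           # input update; switched off when gamma = 0
    P[dx:, dx:] = (eta / n_train) * np.eye(dy)  # gradient step on the targets
    return WK, WK.copy(), WV, P


def lsa_predict(X, Y, x_test, eta, layers, gamma=0.0):
    """Apply the same layer `layers` times (weight tying) and read out -y_test."""
    n, dx = X.shape
    dy = Y.shape[1]
    # Test tokens start as (x_test, -W_0 x_test) = (x_test, 0).
    E = np.vstack([np.hstack([X, Y]),
                   np.hstack([x_test, np.zeros((len(x_test), dy))])])
    weights = layer_weights(dx, dy, eta, n, gamma)
    for _ in range(layers):
        E = lsa_layer(E, *weights, n_train=n)
    return -E[n:, dx:]


def gd_predict(X, Y, x_test, eta, steps):
    """Plain GD from W = 0 on L(W) = 1/(2N) sum_i ||W x_i - y_i||^2."""
    W = np.zeros((Y.shape[1], X.shape[1]))
    for _ in range(steps):
        W = W - eta * (X @ W.T - Y).T @ X / len(X)
    return x_test @ W.T


rng = np.random.default_rng(0)
X = rng.uniform(-1, 1, size=(10, 3))
Y = X @ rng.normal(size=(1, 3)).T
x_test = rng.uniform(-1, 1, size=(4, 3))
for k in (1, 2, 5):
    # Prints True for each k: K tied layers reproduce K GD steps.
    print(k, np.allclose(lsa_predict(X, Y, x_test, 0.3, k), gd_predict(X, Y, x_test, 0.3, k)))
```

A non-recurrent variant would simply call `layer_weights` with a different η (and γ) per layer, which corresponds to the per-step tuning described for the five-layer Transformers.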
Additionally, we did not analyze logistic regression, for which one possible weight construction is already presented in Zhmoginov et al. (2022). Our refined understanding of in-context learning based on gradient descent motivates us to investigate how to improve it. We are excited about several avenues of future research. First, to go beyond a single step of gradient descent in every self-attention layer, it could be advantageous to incorporate so-called declarative nodes (Amos & Kolter, 2017; Bai et al., 2019; Gould et al., 2021; Zucchet & Sacramento, 2022) into Transformer architectures. This way, we would treat a single self-attention layer as the solution of a fully optimized regression loss, possibly leading to more efficient architectures. Second, our findings are restricted to small Transformers and simple regression problems. We are excited to delve deeper into research trying to understand whether, and to what extent, further mechanistic understanding of Transformers and in-context learning in larger models is possible. Third, we are excited about targeted modifications to Transformer architectures, or their training protocols, that lead to improved gradient-descent-based learning algorithms or allow alternative in-context learners to be implemented within Transformer weights, augmenting their functionality, as e.g. in Dai et al. (2023). Finally, it would be interesting to analyze in-context learning in HyperTransformers (Zhmoginov et al., 2022), which produce weights for target networks and already offer a different perspective on merging Transformers and meta-learning. There, Transformers transform weights instead of data and could potentially allow for gradient computations of weights deep inside the target network, lifting the limitation of GD on linear models analyzed here.

Acknowledgments

João Sacramento and Johannes von Oswald deeply thank Angelika Steger for her support and guidance.
The authors also thank Seijin Kobayashi, Marc Kaufmann, Nicolas Zucchet, Yassir Akram, Guillaume Obozinski and Mark Sandler for many valuable insights throughout the project, and Dale Schuurmans and Timothy Nguyen for their valuable comments on the manuscript. João Sacramento was supported by an Ambizione grant (PZ00P3 186027) from the Swiss National Science Foundation and an ETH Research Grant (ETH-23 21-1).

References

Akyürek, E., Schuurmans, D., Andreas, J., Ma, T., and Zhou, D. What learning algorithm is in-context learning? Investigations with linear models. In The Eleventh International Conference on Learning Representations, 2023. URL https://openreview.net/forum?id=0g0X4H8yN4I.
Amos, B. and Kolter, J. Z. OptNet: Differentiable optimization as a layer in neural networks. In International Conference on Machine Learning, 2017.
Andrychowicz, M., Denil, M., Gomez, S., Hoffman, M. W., Pfau, D., Schaul, T., Shillingford, B., and de Freitas, N. Learning to learn by gradient descent by gradient descent. In Advances in Neural Information Processing Systems, 2016.
Ba, J., Hinton, G. E., Mnih, V., Leibo, J. Z., and Ionescu, C. Using fast weights to attend to the recent past. In Advances in Neural Information Processing Systems 29, 2016.
Bai, S., Kolter, J. Z., and Koltun, V. Deep equilibrium models. Advances in Neural Information Processing Systems, 2019.
Bengio, Y., Bengio, S., and Cloutier, J. Learning a synaptic learning rule. Technical report, Université de Montréal, Département d'Informatique et de Recherche opérationnelle, 1990.
Benzing, F., Schug, S., Meier, R., von Oswald, J., Akram, Y., Zucchet, N., Aitchison, L., and Steger, A. Random initialisations performing above chance and how to find them. OPT2022: 14th Annual Workshop on Optimization for Machine Learning, 2022.
Bertinetto, L., Henriques, J. F., Torr, P. H. S., and Vedaldi, A. Meta-learning with differentiable closed-form solvers. In International Conference on Learning Representations, 2019.
Brown, T. B., Mann, B., Ryder, N., Subbiah, M., Kaplan, J., Dhariwal, P., Neelakantan, A., Shyam, P., Sastry, G., Askell, A., Agarwal, S., Herbert-Voss, A., Krueger, G., Henighan, T., Child, R., Ramesh, A., Ziegler, D. M., Wu, J., Winter, C., Hesse, C., Chen, M., Sigler, E., Litwin, M., Gray, S., Chess, B., Clark, J., Berner, C., McCandlish, S., Radford, A., Sutskever, I., and Amodei, D. Language models are few-shot learners. arXiv preprint arXiv:2005.14165, 2020.
Carion, N., Massa, F., Synnaeve, G., Usunier, N., Kirillov, A., and Zagoruyko, S. End-to-end object detection with transformers. In Computer Vision – ECCV 2020. Springer International Publishing, 2020.
Chalmers, D. J. The evolution of learning: an experiment in genetic connectionism. In Touretzky, D. S., Elman, J. L., Sejnowski, T. J., and Hinton, G. E. (eds.), Connectionist Models, pp. 81–90. Morgan Kaufmann, 1991.
Chan, S. C. Y., Dasgupta, I., Kim, J., Kumaran, D., Lampinen, A. K., and Hill, F. Transformers generalize differently from information stored in context vs in weights. arXiv preprint arXiv:2210.05675, 2022a.
Chan, S. C. Y., Santoro, A., Lampinen, A. K., Wang, J. X., Singh, A., Richemond, P. H., McClelland, J., and Hill, F. Data distributional properties drive emergent in-context learning in transformers. Advances in Neural Information Processing Systems, 2022b.
Choromanski, K. M., Likhosherstov, V., Dohan, D., Song, X., Gane, A., Sarlos, T., Hawkins, P., Davis, J. Q., Mohiuddin, A., Kaiser, L., Belanger, D. B., Colwell, L. J., and Weller, A. Rethinking attention with performers. In International Conference on Learning Representations, 2021. URL https://openreview.net/forum?id=Ua6zuk0WRH.
Dai, D., Sun, Y., Dong, L., Hao, Y., Ma, S., Sui, Z., and Wei, F. Why can GPT learn in-context? Language models implicitly perform gradient descent as meta-optimizers. In ICLR 2023 Workshop on Mathematical and Empirical Understanding of Foundation Models, 2023. URL https://openreview.net/forum?id=fzbHRjAd8U.
Dosovitskiy, A., Beyer, L., Kolesnikov, A., Weissenborn, D., Zhai, X., Unterthiner, T., Dehghani, M., Minderer, M., Heigold, G., Gelly, S., Uszkoreit, J., and Houlsby, N. An image is worth 16x16 words: Transformers for image recognition at scale. In International Conference on Learning Representations, 2021. URL https://openreview.net/forum?id=YicbFdNTTy.
Entezari, R., Sedghi, H., Saukh, O., and Neyshabur, B. The role of permutation invariance in linear mode connectivity of neural networks. arXiv preprint arXiv:2110.06296, 2021.
Finn, C. and Levine, S. Meta-learning and universality: Deep representations and gradient descent can approximate any learning algorithm. In International Conference on Learning Representations, 2018. URL https://openreview.net/forum?id=HyjC5yWCW.
Finn, C., Abbeel, P., and Levine, S. Model-agnostic meta-learning for fast adaptation of deep networks. In International Conference on Machine Learning, 2017.
Flennerhag, S., Rusu, A. A., Pascanu, R., Visin, F., Yin, H., and Hadsell, R. Meta-learning with warped gradient descent. In International Conference on Learning Representations, 2020.
Garg, S., Tsipras, D., Liang, P., and Valiant, G. What can transformers learn in-context? A case study of simple function classes. In Oh, A. H., Agarwal, A., Belgrave, D., and Cho, K. (eds.), Advances in Neural Information Processing Systems, 2022. URL https://openreview.net/forum?id=flNZJ2eOet.
Gordon, J., Bronskill, J., Bauer, M., Nowozin, S., and Turner, R. Meta-learning probabilistic inference for prediction. In International Conference on Learning Representations, 2019. URL https://openreview.net/forum?id=HkxStoC5F7.
Gould, S., Hartley, R., and Campbell, D. J. Deep declarative networks. IEEE Transactions on Pattern Analysis and Machine Intelligence, 2021.
Gulati, A., Qin, J., Chiu, C.-C., Parmar, N., Zhang, Y., Yu, J., Han, W., Wang, S., Zhang, Z., Wu, Y., and Pang, R. Conformer: Convolution-augmented transformer for speech recognition. arXiv preprint arXiv:2005.08100, 2020.
Hendrycks, D. and Gimpel, K. Gaussian error linear units (GELUs). arXiv preprint arXiv:1606.08415, 2016.
Hinton, G. E. and Plaut, D. C. Using fast weights to deblur old memories. 1987.
Hochreiter, S., Younger, A. S., and Conwell, P. R. Learning to learn using gradient descent. In Dorffner, G., Bischof, H., and Hornik, K. (eds.), Artificial Neural Networks – ICANN 2001, pp. 87–94, Berlin, Heidelberg, 2001. Springer Berlin Heidelberg. ISBN 978-3-540-44668-2.
Hubinger, E., van Merwijk, C., Mikulik, V., Skalse, J., and Garrabrant, S. Risks from learned optimization in advanced machine learning systems. arXiv [cs.AI], Jun 2019. URL http://arxiv.org/abs/1906.01820.
Irie, K., Schlag, I., Csordás, R., and Schmidhuber, J. Going beyond linear transformers with recurrent fast weight programmers. CoRR, abs/2106.06295, 2021. URL https://arxiv.org/abs/2106.06295.
Kingma, D. P. and Ba, J. Adam: A method for stochastic optimization, 2014.
Kirsch, L. and Schmidhuber, J. Meta learning backpropagation and improving it. In Beygelzimer, A., Dauphin, Y., Liang, P., and Vaughan, J. W. (eds.), Advances in Neural Information Processing Systems, 2021. URL https://openreview.net/forum?id=hhU9TEvB6AF.
Kirsch, L., Harrison, J., Sohl-Dickstein, J., and Metz, L. General-purpose in-context learning by meta-learning transformers. In Sixth Workshop on Meta-Learning at the Conference on Neural Information Processing Systems, 2022. URL https://openreview.net/forum?id=t6tA-KB4dO.
Lee, K., Maji, S., Ravichandran, A., and Soatto, S. Meta-learning with differentiable convex optimization. In IEEE/CVF Conference on Computer Vision and Pattern Recognition, 2019.
Lee, Y. and Choi, S. Gradient-based meta-learning with learned layerwise metric and subspace. In International Conference on Machine Learning, 2018.
Li, Z., Zhou, F., Chen, F., and Li, H. Meta-SGD: Learning to learn quickly for few shot learning. arXiv preprint arXiv:1707.09835, 2017.
Liu, P., Yuan, W., Fu, J., Jiang, Z., Hayashi, H., and Neubig, G. Pre-train, prompt, and predict: A systematic survey of prompting methods in natural language processing. arXiv preprint arXiv:2107.13586, 2021.
Nadaraya, E. A. On estimating regression. Theory of Probability & its Applications, 9(1):141–142, 1964.
Olsson, C., Elhage, N., Nanda, N., Joseph, N., DasSarma, N., Henighan, T., Mann, B., Askell, A., Bai, Y., Chen, A., Conerly, T., Drain, D., Ganguli, D., Hatfield-Dodds, Z., Hernandez, D., Johnston, S., Jones, A., Kernion, J., Lovitt, L., Ndousse, K., Amodei, D., Brown, T., Clark, J., Kaplan, J., McCandlish, S., and Olah, C. In-context learning and induction heads. arXiv preprint arXiv:2209.11895, 2022.
Park, E. and Oliva, J. B. Meta-curvature. In Advances in Neural Information Processing Systems, 2019.
Power, A., Burda, Y., Edwards, H., Babuschkin, I., and Misra, V. Grokking: Generalization beyond overfitting on small algorithmic datasets. abs/2201.02177, 2022.
Raghu, A., Raghu, M., Bengio, S., and Vinyals, O. Rapid learning or feature reuse? Towards understanding the effectiveness of MAML. In International Conference on Learning Representations, 2020.
Ramsauer, H., Schäfl, B., Lehner, J., Seidl, P., Widrich, M., Adler, T., Gruber, L., Holzleitner, M., Pavlović, M., Sandve, G. K., Greiff, V., Kreil, D., Kopp, M., Klambauer, G., Brandstetter, J., and Hochreiter, S. Hopfield networks is all you need. arXiv preprint arXiv:2008.02217, 2020.
Rusu, A. A., Rao, D., Sygnowski, J., Vinyals, O., Pascanu, R., Osindero, S., and Hadsell, R. Meta-learning with latent embedding optimization. In International Conference on Learning Representations, 2019.
Schlag, I., Irie, K., and Schmidhuber, J. Linear transformers are secretly fast weight programmers. In ICML, 2021.
Schmidhuber, J. Evolutionary principles in self-referential learning, or on learning how to learn: the meta-meta-... hook. Diploma thesis, Institut für Informatik, Technische Universität München, 1987.
Schmidhuber, J. Learning to control fast-weight memories: An alternative to dynamic recurrent networks. Neural Computation, 4(1):131–139, 1992. doi: 10.1162/neco.1992.4.1.131.
Thrun, S. and Pratt, L. Learning to learn. Springer US, 1998.
Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., Kaiser, L., and Polosukhin, I. Attention is all you need, 2017.
von Oswald, J., Zhao, D., Kobayashi, S., Schug, S., Caccia, M., Zucchet, N., and Sacramento, J. Learning where to learn: Gradient sparsity in meta and continual learning. In Advances in Neural Information Processing Systems, 2021.
Watson, G. S. Smooth regression analysis. Sankhyā: The Indian Journal of Statistics, Series A, pp. 359–372, 1964.
Widrow, B. and Hoff, M. E. Adaptive switching circuits. In 1960 IRE WESCON Convention Record, Part 4, pp. 96–104, New York, 1960. IRE.
Yun, S., Jeong, M., Kim, R., Kang, J., and Kim, H. J. Graph transformer networks. In Wallach, H., Larochelle, H., Beygelzimer, A., d'Alché-Buc, F., Fox, E., and Garnett, R. (eds.), Advances in Neural Information Processing Systems, 2019.
Zhang, A., Lipton, Z. C., Li, M., and Smola, A. J. Dive into deep learning. arXiv preprint arXiv:2106.11342, 2021.
Zhao, D., Kobayashi, S., Sacramento, J., and von Oswald, J. Meta-learning via hypernetworks. In NeurIPS Workshop on Meta-Learning, 2020.
Zhmoginov, A., Sandler, M., and Vladymyrov, M. HyperTransformer: Model generation for supervised and semi-supervised few-shot learning. In Chaudhuri, K., Jegelka, S., Song, L., Szepesvari, C., Niu, G., and Sabato, S.
(eds.), Proceedings of the 39th International Conference on Machine Learning, volume 162 of Proceedings of Machine Learning Research, pp. 27075–27098. PMLR, 17–23 Jul 2022. URL https://proceedings.mlr.press/v162/zhmoginov22a.html.
Zucchet, N. and Sacramento, J. Beyond backpropagation: bilevel optimization through implicit differentiation and equilibrium propagation. Neural Computation, 34(12), December 2022.

A. Appendix

A.1. Proposition 1

First, we highlight the dependency on the tokens $e_i$ of the linear self-attention operation
$$e_j \leftarrow e_j + \mathrm{LSA}_\theta(\{e_1, \ldots, e_N\}) = e_j + \sum_h P_h V_h K_h^\top q_{h,j} = e_j + \sum_h P_h \sum_i v_{h,i} \otimes k_{h,i}\, q_{h,j} = e_j + \sum_h P_h W_{h,V} \sum_i e_i \otimes e_i\, W_{h,K}^\top W_{h,Q} e_j \tag{6}$$
with $\otimes$ the outer product between two vectors. With this we can now easily draw connections to one step of gradient descent on $L(W) = \frac{1}{2N}\sum_{i=1}^{N} \|W x_i - y_i\|^2$ with learning rate η, which yields the weight change
$$\Delta W = -\eta \nabla_W L(W) = -\frac{\eta}{N}\sum_{i=1}^{N} (W x_i - y_i)\, x_i^\top. \tag{7}$$
We first restate Proposition 1. Given a 1-head linear attention layer and the tokens $e_j = (x_j, y_j)$, for $j = 1, \ldots, N$, one can construct key, query and value matrices $W_K, W_Q, W_V$ as well as the projection matrix $P$ such that a Transformer step on every token $e_j$ is identical to the gradient-induced dynamics $e_j \leftarrow (x_j, y_j) + (0, -\Delta W x_j) = (x_j, y_j) + P V K^\top q_j$, such that $e_j = (x_j, y_j - \Delta y_j)$ with $\Delta y_j = \Delta W x_j$. For the test data token $(x_{N+1}, y_{N+1})$ the dynamics are identical.

We provide the weight matrices in block form:
$$W_K = W_Q = \begin{pmatrix} I_x & 0 \\ 0 & 0 \end{pmatrix}$$
with $I_x$ and $I_y$ the identity matrices of size $N_x$ and $N_y$, respectively. Furthermore, we set
$$W_V = \begin{pmatrix} 0 & 0 \\ W_0 & -I_y \end{pmatrix}$$
with $W_0 \in \mathbb{R}^{N_y \times N_x}$ the weight matrix of the linear model we wish to train, and $P = \frac{\eta}{N} I$ with $I$ the identity matrix of size $N_x + N_y$. With this simple construction we obtain the following dynamics for every token $e_j = (x_j, y_j)$, including the query token $e_{N+1} = e_{\text{test}} = (x_{\text{test}}, -W_0 x_{\text{test}})$, which gives us the desired result:
$$e_j \leftarrow \begin{pmatrix} x_j \\ y_j \end{pmatrix} + \frac{\eta}{N} I \sum_{i=1}^{N} \begin{pmatrix} 0 & 0 \\ W_0 & -I_y \end{pmatrix} \begin{pmatrix} x_i \\ y_i \end{pmatrix} \otimes \begin{pmatrix} x_i \\ y_i \end{pmatrix} \begin{pmatrix} I_x & 0 \\ 0 & 0 \end{pmatrix} \begin{pmatrix} x_j \\ y_j \end{pmatrix} = \begin{pmatrix} x_j \\ y_j \end{pmatrix} + \begin{pmatrix} 0 \\ -\Delta W x_j \end{pmatrix}. \tag{8}$$

A.2.
Comparing the out-of-distribution behavior of trained Transformers and GD

We provide more experimental results comparing GD, with tuned learning rate η and data transformation scalar γ, and the trained Transformer on data distributions other than those provided during training, see Figure 6. We do so by changing the in-context data distribution and measuring the loss of both methods, averaged over 10,000 tasks, when changing α, which either 1) affects the input data range $x \sim U(-\alpha, \alpha)^{N_x}$ or 2) scales the teacher as $\alpha W$ with $W \sim \mathcal{N}(0, I)$. This setup leads to the results shown in the main text, in the first two columns of Figure 6 and in the corresponding plots of Figure 7. Although the match for deeper architectures starts to become worse, overall the trained Transformers behave remarkably similarly to GD and GD++ for layer depths greater than 1. Furthermore, we test GD and the trained Transformer on input distributions never seen during training. Here, we choose, each with probability 1/3, a normal, exponential or Laplace distribution (with JAX default parameters) and depict the average loss value over 10,000 tasks, where α now simply scales the input values $\alpha x$ sampled from one of these distributions. The teacher scaling is identical to the one described above. See the two right columns of Figure 6 for results, where we see almost identical behavior for recurrent architectures, with a less good match for deeper non-recurrent architectures far away from the training range of α = 1. Note that for deeper Transformers (K > 2) and the corresponding GD and GD++ versions (see Appendix A.12 for more experimental details), we harshly clip the token values to [−10, 10] after every transformation step (for the trained TF and GD) to improve training stability. Therefore, the loss increase is bounded and plateaus.

[Figure 6 panels. (a) Comparing one step of gradient descent with trained one-layer Transformers on OOD data. (b) Comparing two steps of gradient descent with trained recurrent two-layer Transformers on OOD data. (c) Comparing five steps of gradient descent with trained five-layer Transformers on OOD data. Columns: test on larger inputs, x ∼ U(−α, α) with α from 0.5 to 2.0; test on larger targets, αW with W ∼ N(0, I) and α from 1 to 5; test on larger inputs αx from unseen input distributions; test on larger targets αW. Curves: GD (or GD++), interpolated, trained TF.]

Figure 6. Left & center left columns: Comparing Transformers, GD and their weight interpolation on rescaled training distributions. In all setups, the trained Transformer behaves remarkably similarly to GD or GD++. Right & center right columns: Comparing Transformers, GD and their weight interpolation on data distributions never seen during training. Again, in all setups, the trained Transformer behaves remarkably similarly to GD or GD++, with a less good match for deep non-recurrent Transformers far away from the training regime.

A.3. Linear mode connectivity between the weight construction of Prop. 1 and trained Transformers

In order to interpolate between the construction $\theta_{\mathrm{GD}}$ and the trained weights θ of the Transformer, we need to correct for a scaling ambiguity. For clarity, we restate the linear self-attention operation for a single head:
$$e_j \leftarrow e_j + P W_V \sum_i e_i \otimes e_i\, W_K^\top W_Q e_j \tag{9}$$
$$= e_j + W_{PV} \sum_i e_i \otimes e_i\, W_{KQ} e_j \tag{10}$$
Now, to match the weight construction of Prop. 1, we aim for the matrix product $W_{KQ}$ to match an identity matrix (except for the last diagonal entry) after rescaling. Therefore, we compute the mean of the diagonal of the matrix product $W_{KQ}$ of the trained Transformer weights, which we denote by β. After rescaling both operations, i.e. $W_{KQ} \leftarrow W_{KQ}/\beta$ and $W_{PV} \leftarrow W_{PV}\,\beta$, we interpolate linearly between the matrix products of GD and these rescaled trained matrix products, i.e. $W_{I,KQ} = (W_{GD,KQ} + W_{TF,KQ})/2$ and $W_{I,PV} = (W_{GD,PV} + W_{TF,PV})/2$. We use these parameters to obtain the results denoted "Interpolated" throughout the paper. We do so for GD as well as for GD++ when comparing to recurrent Transformers. Note that for non-recurrent Transformers we face more ambiguity that we have to correct for since, e.g., scalings influence each other across layers. We also see this in practice: with our simple correction from above, we are able to interpolate between weights only for some seeds. We leave the search for more elaborate corrections to future work.

[Figure 7 panels: loss vs. training steps (0 to 15,000) for GD and the trained TF; preds diff and model diff vs. training steps for GD vs. trained TF and GD++ vs. trained TF; OOD tests on larger inputs (x ∼ U(−α, α) with α from 0.5 to 2.0, and αx) and larger targets (αW with W ∼ N(0, I), α from 1 to 5) for GD, interpolated and trained TF.]

Figure 7. Comparing ten steps of gradient descent with trained recurrent ten-layer Transformers. Results are comparable to the recurrent two-layer Transformer (see Figure 3), but now with 10 repeated layers. We again observe for deeper recurrent linear self-attention-only Transformers that, overall, GD++ and the trained Transformer align very well with one another and are again interpolatable, leading to very similar behavior inside as well as outside the training regime. Note the inferior performance compared to the non-recurrent five-layer Transformer, which highlights the importance of a specific learning rate as well as γ parameter per layer/step.

Figure 8. Comparing twelve steps of GD++ with a trained twelve-layer Transformer with MLPs and 4-headed linear self-attention layers. Results are comparable to the deep recurrent Transformer (see Figure 7), but now with 12 independent Transformer blocks including MLPs and 4-head linear self-attention. We omit Layer Norm. We again observe a close resemblance between the trained Transformers and GD++. We hypothesize that even when equipped with multiple heads and MLPs, Transformers approximate GD++.

[Figure 9 panels: 11 × 11 heatmaps of $W_K^\top W_Q$ and $P W_V$ for a trained single-layer and a trained 3-layer recurrent linear self-attention Transformer.]

Figure 9. Visualizing the weight matrices of trained Transformers. Left & outer left: Weight matrix products of a trained single linear self-attention layer. We see (after scalar correction) a perfect resemblance to our construction. Right & outer right: Weight matrix products of a trained 3-layer recurrent linear self-attention Transformer. Again, we see (after scalar correction) a perfect resemblance to our construction, plus an additional curvature correction, i.e. diagonal values in $P W_V$ of the same magnitude, except for the last entry, which functions as the learning rate.

A.4. Visualizing the trained Transformer weights

The simplicity of our construction enables us to visually compare trained Transformers and the construction put forward in Proposition 1 in weight space. As discussed in the previous section A.3, there is redundancy in the way the trained Transformer can construct the matrix products leading to the weights corresponding to gradient descent.
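The redundancy referred to here is the scalar ambiguity corrected in Appendix A.3: dividing $W_{KQ}$ by β and multiplying $W_{PV}$ by β leaves the layer's output unchanged, because the update in Eq. (10) is linear in $W_{PV}$ and in $W_{KQ}$ separately. A minimal NumPy sketch of that rescaling and of the linear interpolation follows. The "trained" weights are hypothetical (the construction scaled by 4 plus small noise), the helper names are ours, and taking β as the mean over the x-block of the diagonal, excluding the last entry, is our reading of the procedure.

```python
import numpy as np


def lsa_merged(E, W_KQ, W_PV, n_train):
    """Eq. (10): e_j <- e_j + W_PV sum_{i<=N} e_i e_i^T W_KQ e_j, with rows of E as tokens."""
    S = E[:n_train] @ W_KQ @ E.T  # S[i, j] = e_i^T W_KQ e_j
    return E + S.T @ E[:n_train] @ W_PV.T


def construction(dx, dy, eta, n_train):
    """Merged products W_K^T W_Q and P W_V of the Prop. 1 weights with W_0 = 0."""
    d = dx + dy
    W_KQ = np.zeros((d, d))
    W_KQ[:dx, :dx] = np.eye(dx)
    W_PV = np.zeros((d, d))
    W_PV[dx:, dx:] = -(eta / n_train) * np.eye(dy)
    return W_KQ, W_PV


def rescale(W_KQ, W_PV, dx):
    """Remove the scalar ambiguity: W_KQ / beta and W_PV * beta.
    Assumption: beta is the mean of the x-block of diag(W_KQ)."""
    beta = np.mean(np.diag(W_KQ)[:dx])
    return W_KQ / beta, W_PV * beta


rng = np.random.default_rng(0)
dx, dy, n = 3, 1, 8
W_KQ_gd, W_PV_gd = construction(dx, dy, eta=0.5, n_train=n)
# Hypothetical "trained" weights: the construction up to a factor 4, plus small noise.
W_KQ_tf = 4.0 * W_KQ_gd + 0.01 * rng.normal(size=W_KQ_gd.shape)
W_PV_tf = W_PV_gd / 4.0 + 0.01 * rng.normal(size=W_PV_gd.shape)

W_KQ_rs, W_PV_rs = rescale(W_KQ_tf, W_PV_tf, dx)
# Linear interpolation of the matrix products, as described in Appendix A.3.
W_KQ_int = (W_KQ_gd + W_KQ_rs) / 2
W_PV_int = (W_PV_gd + W_PV_rs) / 2

E = rng.normal(size=(n + 2, dx + dy))  # n training tokens plus two query tokens
same = np.allclose(lsa_merged(E, W_KQ_tf, W_PV_tf, n), lsa_merged(E, W_KQ_rs, W_PV_rs, n))
print(same)  # True: rescaling does not change the layer's output
```

For a single layer one scalar suffices; with several untied layers the scalars of different layers interact, which is consistent with the difficulty reported above for non-recurrent Transformers.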
We therefore visualize $W_{KQ} = W_K^\top W_Q$ as well as $W_{PV} = P W_V$ in Figure 9.

A.5. Proof and discussion of Proposition 3
We restate Proposition 3 here and provide the necessary construction and a short discussion.

Proposition 3. Given a 1-head linear or softmax attention layer and the token construction $e_{2j} = (x_j)$, $e_{2j+1} = (0, y_j)$ with a zero vector $0$ of dimension $N_x - N_y$ and concatenated positional encodings, one can construct key, query and value matrices $W_K$, $W_Q$, $W_V$ as well as the projection matrix $P$ such that all tokens $e_j$ are transformed into tokens equivalent to the ones required in Proposition 1.

To get a simple and clean construction, we choose without loss of generality $x_j \in \mathbb{R}^{2N+1}$ and $(0, y_j) \in \mathbb{R}^{2N+1}$, model the positional encodings as unit vectors $p_j \in \mathbb{R}^{2N+1}$, and concatenate them to the tokens, i.e. $e_j = (x_{j/2}, p_j)$. We wish for a construction that realizes

$$\begin{pmatrix} x_{j/2} \\ p_j \end{pmatrix} \leftarrow \begin{pmatrix} x_{j/2} \\ p_j \end{pmatrix} + \begin{pmatrix} 0 \\ y_{j/2+1} - p_j \end{pmatrix}.$$

This means that a token replaces its own positional encoding by copying the target data of the next token to itself, leading to $e_j = (x_{j/2}, 0, y_{j/2+1})$, with slight abuse of notation. This can be realized, for example, by setting $P = I$,

$$W_V = \begin{pmatrix} 0 & 0 \\ I_x & -I_{x,\text{off}} \end{pmatrix}, \qquad W_K = \begin{pmatrix} 0 & 0 \\ 0 & I_x \end{pmatrix}, \qquad W_Q = \begin{pmatrix} 0 & 0 \\ 0 & I_{x,\text{off}}^\top \end{pmatrix},$$

with $I_{x,\text{off}}$ the lower-diagonal identity matrix of size $N_x$. Then $K^\top W_Q e_j = p_{j+1}$, i.e. the attention selects the $(j+1)$-th element of $V$; after the softmax this selection is retained (approximately, for a sufficiently large scale of $W_Q$). Since the $(j+1)$-th entry of $V$ is $(0, y_{j/2+1} - p_j)$, we obtain the desired result.

For the (toy) regression problems considered in this manuscript, the construction also produces $N/2$ tokens in which (parts of) $x_j$ are copied underneath $y_j$. This is desired for modalities such as language, where every two tokens could be considered an input-output pair for the implicit autoregressive inner-loop loss. These tokens do not necessarily have to be next to each other; see the experimental findings presented in (Olsson et al., 2022).
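The copying construction can be checked with a small NumPy sketch. It interleaves hypothetical tokens $x_0, y_0, x_1, y_1, \dots$ (0-based, so each $x$-token picks up the $y$ of its own pair, i.e. the $y_{j/2+1}$ of the text up to indexing conventions), pads the content and positional slots to a common width $D$, and implements the construction: keys read the position, queries shift it by one, and values write the token's content minus the previous position. The function name, the interleaving layout and the softmax scale are our assumptions.

```python
import numpy as np

def copy_next_token(X, Y, softmax_scale=None):
    """Merge interleaved tokens x_0, y_0, x_1, y_1, ... into (x_j, y_j) pairs.

    Token t is (a_t, p_t): content a_t and one-hot position p_t, both of width D.
    Keys are p_t, queries are shift @ p_t = p_{t+1}, values are a_t - p_{t-1}, so
    token t adds a_{t+1} - p_t to its position slot, which then equals a_{t+1}.
    softmax_scale=None uses linear attention; a number uses softmax with that scale.
    """
    N, nx = X.shape
    ny = Y.shape[1]
    T = 2 * N                              # number of interleaved tokens
    D = max(T, nx + ny)                    # common width of content and position slots
    content = np.zeros((T, D))
    content[0::2, :nx] = X                 # x-tokens: (x_j, 0)
    content[1::2, nx:nx + ny] = Y          # y-tokens: (0, y_j)
    pos = np.eye(T, D)                     # row t is p_t
    shift = np.eye(D, k=-1)                # shift @ p_t = p_{t+1}

    keys = pos                             # W_K reads the position slot
    queries = pos @ shift.T                # row t is p_{t+1}
    values = content - pos @ shift         # row t is a_t - p_{t-1}

    scores = queries @ keys.T              # scores[t, i] = 1 if i == t + 1 else 0
    if softmax_scale is None:
        attn = scores
    else:
        z = softmax_scale * scores
        z = z - z.max(axis=1, keepdims=True)
        attn = np.exp(z) / np.exp(z).sum(axis=1, keepdims=True)

    new_pos = pos + attn @ values          # residual update of the position slot (P = I)
    # x-tokens now hold x_j in the content slot and y_j in the position slot
    return np.hstack([content[0::2, :nx], new_pos[0::2, nx:nx + ny]])

rng = np.random.default_rng(0)
X, Y = rng.normal(size=(5, 3)), rng.normal(size=(5, 1))
merged_lin = copy_next_token(X, Y)
merged_soft = copy_next_token(X, Y, softmax_scale=30.0)
print(np.allclose(merged_lin, np.hstack([X, Y])))
print(np.abs(merged_soft - merged_lin).max())
```

With linear attention the copy is exact; with softmax it becomes exact only as the scale grows, which is why the construction needs a large $W_Q$ scale in the softmax case.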
For the experiments conducted here, one solution is to zero out these additional tokens (those with odd $j$), which could be constructed by a two-head self-attention layer that, for odd $j$, simply subtracts the token itself, resulting in a zero token. For all even tokens, we use the construction from above, which effectively coincides with the token construction required in Proposition 1.

[Figure 10 panels: "Rolling out experiment with different dampening strength"; loss vs. GD steps / Transformer layers (up to 50) for dampening 1, 0.875 and 0.75, comparing GD and the trained TF.]

Figure 10. Roll-out experiments: applying a trained single linear self-attention layer multiple times. We observe that different dampening strengths affect the generalization of both methods, with slightly better robustness for GD, and matching performance over 50 steps when λ = 0.75.

A.6. Dampening the self-attention layer
As an additional out-of-distribution experiment, we test the behavior when repeatedly applying a single LSA layer trained to lower our objective (see equation 5), with the aim of repeating the learned learning/update rule. Note that GD as well as the self-attention layer were optimized to be optimal for one step. For GD, we line-search the optimal learning rate η on 10,000 tasks. Interestingly, for both methods we observe quick divergence when they are applied multiple times, see the left plot of Figure 10. Nevertheless, both of our update functions are described by a linear self-attention layer whose norm we can control, post training, by a simple scale which we denote by λ. This results in the new update $y_{\text{test}} + \lambda \Delta W x_{\text{test}}$ for GD and $y_{\text{test}} + \lambda P V K^\top W_Q x_{\text{test}}$ for the trained self-attention layer, which effectively re-tunes the learning rate for GD and the trained self-attention layer.
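A minimal NumPy sketch of the GD side of this roll-out, assuming $W_0 = 0$, a single regression task with inputs drawn from $U(-1, 1)$, and a hypothetical learning rate chosen just above the stability limit of an undamped step (these choices are ours, not the paper's line-searched setup). Each application of the layer performs one GD step scaled by λ; the target slots of the training tokens carry the residuals $y_i - W x_i$ and the query's slot carries $-W x_{\text{test}}$.

```python
import numpy as np

def rollout(X, y, x_test, eta, lam, steps):
    """Apply the damped GD-equivalent LSA layer `steps` times, starting from W_0 = 0.

    X: (N, Nx) inputs, y: (N,) targets. r holds the target slots y_i - W x_i of the
    training tokens and t the target slot -W x_test of the query token.
    Returns the final prediction W x_test and the training loss after every step.
    """
    N = X.shape[0]
    r = y.astype(float).copy()
    t = 0.0
    losses = []
    for _ in range(steps):
        step_dir = X.T @ r / N                 # -grad of L(W) = 1/(2N) sum_i (W x_i - y_i)^2
        r = r - lam * eta * X @ step_dir       # W <- W + lam * eta * step_dir
        t = t - lam * eta * x_test @ step_dir
        losses.append(0.5 * np.mean(r ** 2))
    return -t, losses

rng = np.random.default_rng(0)
N = Nx = 10
X = rng.uniform(-1, 1, size=(N, Nx))
w_true = rng.normal(size=Nx)
y, x_test = X @ w_true, rng.uniform(-1, 1, size=Nx)

# hypothetical learning rate: lam * eta * mu_max > 2 makes the top eigendirection diverge
mu_max = np.linalg.eigvalsh(X.T @ X / N).max()
eta = 2.2 / mu_max

_, undamped = rollout(X, y, x_test, eta, lam=1.0, steps=50)
_, damped = rollout(X, y, x_test, eta, lam=0.75, steps=50)
print(undamped[0], undamped[-1])
print(damped[0], damped[-1])
```

Scaling the layer by λ is the same as running GD with learning rate λη, so λ = 0.75 brings every factor $|1 - \lambda\eta\mu_i|$ below one in this toy setting while λ = 1 does not.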
Intriguingly, both methods generalize similarly well (or poorly) in this out-of-distribution experiment when changing λ, see again Figure 10. We show in Figure 1 the behavior for λ = 0.75, for which both methods steadily decrease the loss within 50 steps.

A.7. Sine wave regression
For the sine wave regression tasks, we follow (Finn et al., 2017) and other meta-learning literature and sample for each task an amplitude $a \sim U(0.1, 5)$ and a phase $\rho \sim U(0, \pi)$. Each task consists of $N = 10$ data points, where inputs are sampled as $x \sim U(-5, 5)$ and targets are computed as $y = a \sin(\rho + x)$. Here, for the first time, we choose for GD as well as for the Transformer an input embedding emb that maps tokens $e_i = (x_i, y_i)$ into a 40-dimensional space, $\text{emb}(e_i) = W_{\text{emb}} e_i$, through an affine projection without bias. We skip the first self-attention layer but, as is usual in Transformers, then transform the embedded tokens through an MLP $m$ with a single hidden layer, a widening factor of 4 (160 hidden neurons) and GELU nonlinearity (Hendrycks & Gimpel, 2016), i.e. $e_j \leftarrow m(\text{emb}(e_j)) + \text{emb}(e_j)$. We interpret the last entry of the transformed tokens as the (transformed) target and the rest as a higher-dimensional input data representation on which we train a model with a single gradient descent step. We compare the obtained meta-learned GD solution with training a Transformer on the same token embeddings that instead learns a self-attention layer. Note that the embeddings of the tokens, including the transformation through the MLP, do not depend on an interplay between the tokens. Furthermore, the initial transformation depends on $e_i = (x_i, y_i)$, i.e. on the input as well as on the target data, except for the query token, for which $y_{\text{test}} = 0$. This means that this construction is, except for the additional dependency on targets, close to a large corpus of meta-learning literature that aims to find a deep representation optimized for (fast) fine-tuning and few-shot learning.
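A forward-pass sketch of this pipeline in NumPy, with randomly initialized (untrained) weights standing in for the meta-learned embedding and MLP; meta-training is omitted. The weight scales, the tanh approximation of GELU, the learning rate η = 0.1, and reading the prediction off as $\Delta W h_{\text{test}}$ are our assumptions, not details from the paper.

```python
import numpy as np

def gelu(z):
    # tanh approximation of GELU (Hendrycks & Gimpel, 2016)
    return 0.5 * z * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (z + 0.044715 * z ** 3)))

def sample_sine_task(rng, n=10):
    """a ~ U(0.1, 5), rho ~ U(0, pi), x ~ U(-5, 5), y = a sin(rho + x); n points + 1 query."""
    a, rho = rng.uniform(0.1, 5.0), rng.uniform(0.0, np.pi)
    x = rng.uniform(-5.0, 5.0, size=n + 1)
    return x, a * np.sin(rho + x)

def embed(tokens, params):
    """emb(e) = W_emb e, then e <- m(emb(e)) + emb(e); each token is processed independently."""
    h = tokens @ params["W_emb"].T                        # (n + 1, 40)
    hidden = gelu(h @ params["W1"].T + params["b1"])      # (n + 1, 160)
    return h + hidden @ params["W2"].T + params["b2"]     # residual MLP, (n + 1, 40)

rng = np.random.default_rng(0)
d, width = 40, 4 * 40                                     # widening factor 4 -> 160 hidden units
params = {                                                # random stand-ins for meta-learned weights
    "W_emb": rng.normal(scale=0.5, size=(d, 2)),
    "W1": rng.normal(scale=d ** -0.5, size=(width, d)),
    "b1": np.zeros(width),
    "W2": rng.normal(scale=width ** -0.5, size=(d, width)),
    "b2": np.zeros(d),
}

x, y = sample_sine_task(rng)
tokens = np.stack([x, y], axis=1)
tokens[-1, 1] = 0.0                                       # query token carries y_test = 0
z = embed(tokens, params)
feats, targets = z[:-1, :-1], z[:-1, -1]                  # last entry = transformed target
h_test = z[-1, :-1]

eta = 0.1                                                 # hypothetical learning rate
n = feats.shape[0]
dW = eta / n * targets @ feats                            # one GD step from W_0 = 0
pred = dW @ h_test
print(z.shape, pred)
```

Writing $k(x_i, x_{\text{test}}) = h_i^\top h_{\text{test}}$, the same prediction equals $\frac{\eta}{N}\sum_i t_i\, k(x_i, x_{\text{test}})$, a kernel-weighted sum of the (here transformed) targets, which is the view taken in Appendix A.8.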
In order to compare the meta-training of the MLP and the Transformer, we choose the same seed to initialize the network weights for the MLPs and the input embedding, trained either by meta-learning (i.e. backpropagation through training) or as part of the Transformer. This leads to the almost identical learned initial functions and updated functions shown in Figure 4.

A.8. Proposition 2 and connections between gradient descent, kernelized regression and kernel smoothing
Let us consider the data transformation induced by an MLP $\tilde m(x)$ and a residual connection commonly used in Transformer blocks, i.e. $e_j \leftarrow e_j + \tilde m(e_j) = (x_j, y_j) + (\tilde m(x_j), 0) = (m(x_j), y_j)$ with $m(x_j) = x_j + \tilde m(x_j)$ and $\tilde m$ not changing the targets $y$. When simply applying Proposition 1, it is easy to see that, given this new token construction, a linear self-attention layer can induce the token dynamics $e_j \leftarrow (m(x_j), y_j) + (0, -\Delta W m(x_j))$ with $\Delta W = -\eta \nabla_W L(W)$, given the loss function $L(W) = \frac{1}{2N}\sum_{i=1}^{N} \|W m(x_i) - y_i\|^2$.

Interestingly, for the test token $e_{\text{test}} = (x_{\text{test}}, 0)$ this induces, after a multiplication with $-1$, an initial prediction after a single Transformer block given by

$$\hat y = \Delta W\, m(x_{\text{test}}) = -\eta \nabla_W L(0)\, m(x_{\text{test}}) = \frac{\eta}{N}\sum_{i=1}^{N} y_i\, m(x_i)^\top m(x_{\text{test}}) = \frac{\eta}{N}\sum_{i=1}^{N} y_i\, k(x_i, x_{\text{test}}) \tag{13}$$

with $m(x_i)^\top m(x_{\text{test}}) = k(x_i, x_{\text{test}}) \in \mathbb{R}$ interpreted as a kernel function. In conclusion, the combination of MLPs and a single self-attention layer can lead to the dynamics induced by descending a kernelized regression (squared error) loss with a single step of gradient descent. Interestingly, when choosing $W_0 = 0$, we furthermore see that a single self-attention layer or Transformer block can be regarded as performing nonparametric kernel smoothing, $\hat y = \frac{\eta}{N}\sum_{i=1}^{N} y_i\, k(x_i, x_{\text{test}})$, based on the data given in context (Nadaraya, 1964; Watson, 1964). Note that we made a particular choice of kernel function here and that this view still holds when $m$ is the identity, $m(x_j) = x_j$, i.e. when considering Transformers without MLPs, or when leveraging the well-known view of the softmax self-attention layer as a kernel function used to measure similarity between tokens (e.g. Choromanski et al., 2021; Zhang et al., 2021). Thus, implementing one step of gradient descent through a self-attention layer (with or without softmax nonlinearity) is equivalent to performing kernel smoothing estimation. We argue, however, that this nonparametric kernel smoothing view of in-context learning is limited and arises from looking only at a single self-attention layer. When considering deeper Transformer architectures, we see that multiple Transformer blocks can iteratively transform the targets based on multiple steps of gradient descent, leading to the minimization of a kernelized squared error loss $L(W)$. One way to obtain a suitable construction is to neglect MLPs everywhere except in the first Transformer block. We leave the study of the exact mechanics, especially how the Transformer makes use of the possibility of transforming the targets through the MLPs and of iteratively changing the kernel function throughout depth, for future study.

A.9. Linear vs. softmax self-attention as well as Layer Norm Transformers
Although linear Transformers and their variants have been shown to be competitive with their softmax counterparts (Irie et al., 2021), the removal of this nonlinearity is still a major departure from classic Transformers and, more importantly, from the Transformers used in related studies analyzing in-context learning. In this section we investigate whether and when gradient-based learning emerges in trained softmax self-attention layers, and we provide an analytical argument to back our findings. First, we show (see Figure 12) that a single layer of softmax self-attention is not able to match GD performance. We tuned the learning rate as well as the weight initialization but found no significant difference across the hyperparameters we used throughout this study.
In general, we hypothesize that GD is an optimal update given the limited capacity of a single layer of (single-head) self-attention. We therefore argue that the softmax induces (at best) a linear offset of the matrix product of training data and query vector,

$$
\begin{align}
\mathrm{softmax}(K^\top q_j) &= \big(e^{k_1^\top q_j}, \dots, e^{k_N^\top q_j}\big)^\top \Big/ \textstyle\sum_i e^{k_i^\top q_j} \tag{14}\\
&= \big(e^{x_1^\top W_{KQ} x_j}, \dots, e^{x_N^\top W_{KQ} x_j}\big)^\top \Big/ \textstyle\sum_i e^{x_i^\top W_{KQ} x_j} \tag{15}\\
&\approx \big(1 + x_1^\top W_{KQ} x_j, \dots, 1 + x_N^\top W_{KQ} x_j\big)^\top \Big/ \textstyle\sum_i \big(1 + x_i^\top W_{KQ} x_j\big) \tag{16}\\
&\propto K^\top q_j + \epsilon, \tag{17}
\end{align}
$$

proportional to a factor dependent on all $\{x_{\tau,i}\}_{i=1}^{N+1}$. We speculate that the dependency on the specific task τ vanishes for large $N_x$, or that the $x$-dependent value matrix could introduce a correcting effect. In this case, the softmax operation introduces an additive error with respect to the optimal GD update. To overcome this disadvantageous offset, the Transformer can (approximately) introduce a correction with a second self-attention head by a simple subtraction, i.e.

$$
\begin{align}
&P_1 V_1\, \mathrm{softmax}(K_1^\top W_Q x_j) + P_2 V_2\, \mathrm{softmax}(K_2^\top W_Q x_j) \tag{18}\\
&\approx PV\Big(\big(1 + x_1^\top W_{1,KQ} x_j, \dots, 1 + x_N^\top W_{1,KQ} x_j\big) - \big(1 + x_1^\top W_{2,KQ} x_j, \dots, 1 + x_N^\top W_{2,KQ} x_j\big)\Big) \tag{19}\\
&= PV\big(x_1^\top (W_{1,KQ} - W_{2,KQ}) x_j, \dots, x_N^\top (W_{1,KQ} - W_{2,KQ}) x_j\big) \tag{20}\\
&\propto PV K^\top q_j. \tag{21}
\end{align}
$$

[Figure 11 panels: 11 × 11 heatmaps of the scaled products $\eta_1 W_{1,KQ}$ and $\eta_2 W_{2,KQ}$ of the two heads and of their sum $\eta_1 W_{1,KQ} + \eta_2 W_{2,KQ}$.]

Figure 11. Visualizing the correction to the softmax operation when training Transformers on regression tasks. The left and center plots show the matrix product $W_{KQ} = W_K^\top W_Q$, including its scaling by η induced through $P W_V$, for the two heads of the trained softmax self-attention layer. We observe that both matrices are approximately diagonal, with almost perfectly sign-reversed values in the off-diagonal terms. After adding the matrices (right plot), we observe a diagonal matrix and therefore a much improved approximation of our construction and hence of gradient descent dynamics.
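A NumPy sketch of the first-order argument in equations (14)-(21), under two assumptions made only for illustration: a small-norm $W_{KQ}$, so that $e^z \approx 1 + z$, and two heads whose $W_{KQ}$ are exact negatives of each other, with each head's softmax denominator absorbed into its $PV$ as assumed in the text. A single head keeps a constant offset of about 1 per entry, while the difference of the two heads leaves approximately $2\,x_i^\top W_{KQ} x_j$, i.e. the linear attention scores of the GD construction.

```python
import numpy as np

def softmax(z):
    z = z - z.max()                         # numerically stable softmax
    w = np.exp(z)
    return w / w.sum()

rng = np.random.default_rng(0)
N, Nx = 10, 10
X = rng.uniform(-1, 1, size=(N, Nx))        # context inputs x_i (rows)
x_q = rng.uniform(-1, 1, size=Nx)           # query input x_j
W_KQ = 0.01 * np.eye(Nx)                    # small-norm key-query product (assumption)

scores = X @ W_KQ @ x_q                     # linear attention scores x_i^T W_KQ x_j

# Single head, eqs. (14)-(16): softmax is close to its first-order expansion
single = softmax(scores)
first_order = (1 + scores) / np.sum(1 + scores)

# Two heads with sign-reversed W_KQ; multiplying by each head's denominator
# stands in for P V absorbing the softmax normalization
head_1 = softmax(scores) * np.exp(scores).sum()      # = exp(scores)
head_2 = softmax(-scores) * np.exp(-scores).sum()    # = exp(-scores)
two_heads = head_1 - head_2                          # = 2 sinh(scores): the constant 1 cancels

print(np.abs(single - first_order).max())
print(np.abs(head_1 - scores).max())         # about 1: the constant offset of one head
print(np.abs(two_heads - 2 * scores).max())  # small: third-order remainder only
```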
In deriving equations (18)-(21), we assume that $PV$ (1) subsumes the dividing factor of the softmax and (2) is the same (up to scaling) for each head. Note that if $(W_{1,KQ} - W_{2,KQ})$ is diagonal, and $P$ and $V$ are chosen as in the Proposition of Appendix A.1, we recover our gradient descent construction.

We base this derivation on empirical findings. First, they show (see Figure 12) that the softmax self-attention performance increases drastically when using two heads instead of one. Nevertheless, the self-attention layer has difficulties matching the loss values of a model trained with GD. Furthermore, this architecture change leads to a much improved alignment of the trained model and GD. Second, we observe that when training a two-headed softmax self-attention layer on regression tasks, the correction proposed above actually appears in weight space, see Figure 11. Here, we visualize the matrix product within the softmax operation, $W_{h,KQ}$, per head, which we scale by the last diagonal entry of $P_h W_{h,V}$, denoted by $\eta_h = P_h W_{h,V}(-1, -1)$. Intriguingly, this results in an almost perfect cancellation (right plot) of the off-diagonal terms and therefore, in sum, in an improved approximation of our construction, see the derivation above.

We would like to reiterate that the stronger inductive bias of the softmax layer for copying data remains and is not invalidated by the analysis above. Therefore, even for our shallow and simple constructions, softmax layers fulfill an important role in support of our hypotheses: the ability to merge or copy input and target data into single tokens, allowing for the dot-product computation necessary for the construction in Proposition 1, see Section 4 in the main text.

We end this section by analysing Transformers equipped with Layer Norm, which we apply, as usual, before the self-attention layer. Overall, we observe qualitatively similar results to Transformers with softmax self-attention layers, i.e. a decrease in performance compared to GD accompanied by a decrease in alignment between models generated by the Transformer and models trained with GD, see Figure 14. Here, we again test a single linear self-attention layer succeeding Layer Norm, as well as two layers where we skip the first Layer Norm and only include a Layer Norm between the two layers. Including more heads does not help substantially. We again assume the optimality of GD and argue that information about targets and inputs present in the tokens is lost by averaging when applying Layer Norm. This naturally leads to decreasing performance compared to GD, see the first row of Figure 14. Although the alignment to GD and GD++, especially for two layers, is high, we overall see inferior performance compared to one or two steps of GD or two steps of GD++. Nevertheless, we speculate that Layer Norm might not only stabilize Transformer training but could also act as some form of data normalization procedure that implicitly enables better generalization for larger inputs as well as targets provided in context, see the OOD experiments in Figure 14.

Overall, we conclude that common architecture choices like softmax and Layer Norm seem suboptimal for the constructed in-context learning settings when compared to GD or linear self-attention. Nevertheless, we speculate that the potentially small performance drops of in-context learning are negligible when turning to deep and wide Transformers, for which these architecture choices have empirically proven to be superior.

[Figure 12 panels: (a) one step of GD vs. a trained softmax one-headed self-attention layer; (b) one step of GD vs. a trained softmax two-headed self-attention layer. Each row shows the loss and the alignment (Preds diff, Model diff) vs. training steps, and tests on larger inputs and larger targets (with W ~ N(0, I)), comparing GD and the trained TF.]

Figure 12. Comparing trained two-headed and one-headed single-layer softmax self-attention with 1 step of gradient descent on linear regression tasks. Left column: softmax self-attention is not able to match gradient descent performance with a hand-tuned learning rate, but adding a second attention head significantly reduces the gap, as expected from our analytical argument. Center left: the alignment suffers significantly for single-head softmax SA. We observe good, but not as precise, alignment for the two-headed softmax SA layer when compared to linear Transformers. Center right & right: compared to the single-head layer, the two-headed self-attention layer shows out-of-distribution behavior that is similarly robust to that of gradient descent.

A.10. Details of curvature correction
We give here a precise construction showing how to implement, in a single head, a step of GD and the discussed data transformation, resulting in GD++. Recall again the linear self-attention operation with a single head,

$$e_j \leftarrow e_j + P W_V \sum_i e_i e_i^\top W_K^\top W_Q e_j. \tag{22}$$

We again provide the weight matrices of the construction of Proposition 1 in block form, but now additionally enabling our described data transformation:

$$W_K = W_Q = \begin{pmatrix} I_x & 0 \\ 0 & 0 \end{pmatrix},$$

with $I_x$ the identity matrix of size $N_x$ and $I_y$ the identity matrix of size $N_y$. Furthermore, we set

$$W_V = \begin{pmatrix} -I_x & 0 \\ W & -I_y \end{pmatrix}$$

with the weight matrix $W \in \mathbb{R}^{N_y \times N_x}$ of the linear model we wish to train, and

$$P = \begin{pmatrix} \gamma I_x & 0 \\ 0 & \frac{\eta}{N} I_y \end{pmatrix}.$$

This leads to the following update:

$$\begin{pmatrix} x_j \\ y_j \end{pmatrix} \leftarrow \begin{pmatrix} x_j \\ y_j \end{pmatrix} + \begin{pmatrix} \gamma I_x & 0 \\ 0 & \frac{\eta}{N} I_y \end{pmatrix} \begin{pmatrix} -I_x & 0 \\ W & -I_y \end{pmatrix} \sum_i \begin{pmatrix} x_i x_i^\top x_j \\ y_i x_i^\top x_j \end{pmatrix} = \begin{pmatrix} x_j - \gamma X X^\top x_j \\ y_j - \Delta W x_j \end{pmatrix},$$

where $XX^\top = \sum_i x_i x_i^\top$ and $\Delta W = -\eta \nabla_W L(W)$. This holds for every token $e_j = (x_j, y_j)$, including the query token $e_{N+1} = e_{\text{test}} = (x_{\text{test}}, 0)$, which gives us the desired result.

Why does GD++ perform better? We give here one possible explanation of the superior performance of GD++ compared to GD.
Note that there is a close resemblance between the GD transformation and a heavily truncated Neumann series approximation of the inverse of $XX^\top$. We provide here a more heuristic explanation for the observed acceleration. Given $\gamma \in \mathbb{R}$, GD++ transforms every input according to $x_i \leftarrow x_i - \gamma XX^\top x_i = (I - \gamma XX^\top) x_i$. We can therefore look at the change of the squared regression loss $L(W) = \frac{1}{2}\sum_{i=0}^{N} (W x_i - y_i)^2$ induced by this transformation, i.e.

$$L^{++}(W) = \frac{1}{2}\sum_{i=0}^{N} \big(W(I - \gamma XX^\top) x_i - y_i\big)^2 = \frac{1}{2}\big\|W(I - \gamma XX^\top)X - Y\big\|^2,$$

which in turn leads to a change of the loss Hessian from $\nabla^2 L = XX^\top$ to $\nabla^2 L^{++} = (I - \gamma XX^\top)X\big((I - \gamma XX^\top)X\big)^\top$. Given the original Hessian $H = XX^\top = U\Sigma U^\top$ with its set of sorted eigenvalues $\{\lambda_1, \dots, \lambda_n\}$, $\lambda_i \geq 0$, on the diagonal matrix $\Sigma$, we can express the new Hessian through $U$ and $\Sigma$, i.e. $H^{++} = (I - \gamma XX^\top)X\big((I - \gamma XX^\top)X\big)^\top = (I - \gamma U\Sigma U^\top)U\Sigma U^\top(I - \gamma U\Sigma U^\top)^\top$. We can simplify $H^{++}$ further as

$$
\begin{align}
H^{++} &= (I - \gamma U\Sigma U^\top)U\Sigma U^\top(I - \gamma U\Sigma U^\top)^\top \nonumber\\
&= U(\Sigma - \gamma\Sigma^2)U^\top U(I - \gamma\Sigma)U^\top \tag{24}\\
&= U(\Sigma - 2\gamma\Sigma^2 + \gamma^2\Sigma^3)U^\top. \tag{25}
\end{align}
$$

Given the eigenspectrum $\{\lambda_1, \dots, \lambda_n\}$ of $H$, we obtain an (unsorted) eigenspectrum of $H^{++}$, $\{\lambda_1 - 2\gamma\lambda_1^2 + \gamma^2\lambda_1^3, \dots, \lambda_n - 2\gamma\lambda_n^2 + \gamma^2\lambda_n^3\}$, which we visualize in Figure 13 for different γ observed in practice. We hypothesize that the Transformer chooses γ such that, on average across the distribution of tasks, the data transformation (iteratively) decreases the condition number $\lambda_1/\lambda_n$, leading to accelerated learning. This could be achieved, for example, by keeping the smallest eigenvalue $\lambda_n \approx \lambda_n^{++}$ fixed and choosing γ such that the largest eigenvalue of the transformed data, $\lambda_1^{++}$, is reduced, while the original $\lambda_1$ stays within $[\lambda_1^{++}, \lambda_n^{++}]$.

To support our hypothesis empirically, we computed the minimum and maximum eigenvalues of $XX^\top$ across 10,000 tasks while changing the number of data points $N \in \{10, 25, 50, 100\}$, i.e. $X = (x_0, \dots, x_N)$, which leads to better-conditioned loss Hessians: we obtain $[10^{-10}, 0.097, 0.666, 2.870]$ and $[4.6, 7.712, 10.845, 17.196]$ as the minimum and maximum eigenvalues of $XX^\top$ across all tasks, where we cut the smallest eigenvalue for $N = 10$ at $10^{-10}$. Furthermore, we extract the γ values from the weights of optimized recurrent 2-layer Transformers trained on the different task distributions and obtain γ values of $[0.179, 0.099, 0.056, 0.029]$, see again Figure 13. Note that the observed eigenvalues stay within $[0, 1/\gamma]$, i.e. between the two roots of $f(\lambda, \gamma) = \lambda - 2\gamma\lambda^2 + \gamma^2\lambda^3$. Given the derived function of eigenvalue change $f(\lambda, \gamma)$, we compute the condition number of $H^{++}$ by dividing the new maximum eigenvalue $\lambda_1^{++} = f(1/(3\gamma), \gamma)$, where $\lambda = 1/(3\gamma)$ is the local maximum of $f(\lambda, \gamma)$ for fixed γ, by the new minimum eigenvalue $\lambda_n^{++} = \min(f(\lambda_1, \gamma), f(\lambda_n, \gamma))$. Note that with too large γ, we move the original $\lambda_1$ closer to the root of $f(\lambda, \gamma)$ at $\lambda = 1/\gamma$ and can therefore change the smallest eigenvalue. Given the task distribution and its corresponding eigenvalue distribution, we see that choosing γ appropriately reduces the new condition number $\kappa^{++} = \lambda_1^{++}/\lambda_n^{++}$, which leads to better-conditioned learning, see the center plot of Figure 13.

Figure 13. GD++ analyses. Left: we visualize the change of the eigenspectrum induced by the input data transformation of GD++ for different γ observed in practice. Center: given the maximum and minimum eigenvalues $\lambda_1$, $\lambda_n$ of the loss Hessian $XX^\top$ with $X = (x_0, \dots, x_N)$ for different N, we compare the original condition number (depicted by asterisks at γ = 0) and the condition number (in log scale) of the GD++-altered loss Hessian when varying γ. Dotted lines mark the γ values that we observe in practice, which are close to the optimal ones, i.e. the local minimum derived through our analysis. Right: for N = 25, we plot for different γ values the distribution of condition numbers $\kappa = \lambda_1/\lambda_n$ for 10,000 tasks and observe favorable κ values close to 1 when approaching the value γ = 0.099 found in practice. The κ values quickly explode for γ > 0.1.
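The GD++ construction and the eigenvalue map $f(\lambda, \gamma)$ can be checked numerically. The NumPy sketch below builds the block matrices of the construction in Appendix A.10 for one hypothetical task (inputs from $U(-1, 1)$, $N = 25$, $N_x = 10$, a random current model $W$, η = 0.1 and γ = 0.05; all of these are our illustrative choices), applies one linear self-attention step, and compares the eigenvalues of $H^{++}$ with $f(\lambda_i, \gamma)$.

```python
import numpy as np

def gdpp_layer(E, W, eta, gamma, nx):
    """One linear self-attention step with the GD++ block weights.

    E: (nx + ny, N + 1), columns e_j = (x_j, y_j); the first N columns are the
    training tokens summed over in sum_i e_i e_i^T, the last one is the query.
    """
    ny = E.shape[0] - nx
    N = E.shape[1] - 1
    Z = np.zeros
    W_K = W_Q = np.block([[np.eye(nx), Z((nx, ny))], [Z((ny, nx)), Z((ny, ny))]])
    W_V = np.block([[-np.eye(nx), Z((nx, ny))], [W, -np.eye(ny)]])
    P = np.block([[gamma * np.eye(nx), Z((nx, ny))], [Z((ny, nx)), eta / N * np.eye(ny)]])
    ctx = E[:, :N]
    return E + P @ W_V @ (ctx @ ctx.T) @ W_K.T @ W_Q @ E

def f(lam, gamma):
    """Eigenvalue map of the GD++ data transformation."""
    return lam - 2 * gamma * lam ** 2 + gamma ** 2 * lam ** 3

rng = np.random.default_rng(0)
nx, ny, N, eta, gamma = 10, 1, 25, 0.1, 0.05
X = rng.uniform(-1, 1, size=(nx, N))                  # columns x_i
Y = rng.normal(size=(ny, nx)) @ X
x_test = rng.uniform(-1, 1, size=(nx, 1))
E = np.hstack([np.vstack([X, Y]), np.vstack([x_test, np.zeros((ny, 1))])])
W = rng.normal(size=(ny, nx))                         # current linear model

out = gdpp_layer(E, W, eta, gamma, nx)

H = X @ X.T                                           # Hessian of 1/2 sum_i (W x_i - y_i)^2
dW = -eta * (W @ H - Y @ X.T) / N                     # GD step on 1/(2N) sum_i ||W x_i - y_i||^2
print(np.allclose(out[:nx], (np.eye(nx) - gamma * H) @ E[:nx]))   # x_j <- (I - gamma XX^T) x_j
print(np.allclose(out[nx:], E[nx:] - dW @ E[:nx]))                 # y_j <- y_j - dW x_j

X_pp = (np.eye(nx) - gamma * H) @ X
lam = np.linalg.eigvalsh(H)
lam_pp = np.linalg.eigvalsh(X_pp @ X_pp.T)
print(np.allclose(np.sort(lam_pp), np.sort(f(lam, gamma))))       # eigenvalues follow f
print("kappa:", lam.max() / lam.min(), "kappa++:", lam_pp.max() / lam_pp.min())
```

Sorting both spectra is needed because $f$ is not monotone on $[0, \infty)$, so the transformed eigenvalues need not keep their original order.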
Note that the optimal γ from our derivation above is based on the maximum and minimum eigenvalues across all tasks and does not take the change of the eigenvalue distribution into account. We therefore argue that the simple arguments above do not entirely capture the task statistics and distribution shifts, and therefore yield a slightly larger γ as the optimal value. We furthermore visualize the condition number change for N = 25 and 10,000 tasks in the right plot of Figure 13 and observe the distribution moving to desirable κ values close to 1. For γ values larger than 0.1, the distribution quickly exhibits exploding condition numbers.

[Figure 14 panels: (a) one step of GD vs. a single-layer LSA Transformer with Layer Norm; (b) two steps of GD vs. a two-layer LSA Transformer with Layer Norm, including the alignment to GD++. Panels show loss and alignment (Preds diff, Model diff) vs. training steps and out-of-distribution tests on larger inputs and targets.]

Figure 14. Comparing trained 1-layer and 2-layer Transformers with Layer Norm to 1 step or 2 steps of gradient descent, respectively. Left column: the Transformers are not able to match the gradient descent performance with a hand-tuned learning rate. Alignment plots: the alignment suffers significantly compared to linear self-attention layers, although a reasonable alignment is still obtained, which decreases slightly when comparing to GD++ for the two-layer Transformer. Center right & right: the Layer Norm Transformer outperforms GD when provided with input data that is significantly larger than the data provided during training.

A.11. Phase transitions
We comment briefly on the curious-looking phase transitions of the training loss observed in many of our experiments, see Figure 2. Simply switching from a single-headed to a two-headed self-attention layer mitigates the random-seed-dependent training instabilities in our experiments presented in the main text, see the left and center plots of Figure 15. Furthermore, these transitions are reminiscent of the recently observed grokking behaviour (Power et al., 2022). Interestingly, when carefully tuning the learning rate and batch size, we can also make the Transformers trained on these linear regression tasks grok. For this, we train a single Transformer block (self-attention layer and MLP) on a limited amount of data (8192 tasks), see the right plot of Figure 15, and observe grokking-like train and test loss phase transitions, where the test loss first increases drastically before experiencing a sudden drop, almost matching the desired GD loss of 0.2. We leave a thorough investigation of these phenomena for future study.

Figure 15. Phase transitions during training. Left: loss for 10 different random seeds when optimizing a single-headed self-attention layer. For some seeds, we observe very long initial phases of virtually zero progress, after which the loss drops suddenly to the desired GD loss. Center: the same experiment, but optimizing a two-headed self-attention layer. We observe fast and robust convergence to the loss of GD. Right: training a single Transformer block, i.e. a self-attention layer with MLP, with a reduced training set size of 8192 tasks. We observe grokking-like train and test loss phase transitions, where the test loss first increases drastically before experiencing a sudden drop, almost matching the desired GD loss of 0.2.

A.12. Experimental details
For most experiments we use identical hyperparameters, tuned by hand, which we list here:
- Optimizer: Adam (Kingma & Ba, 2014) with default parameters and a learning rate of 0.001 for Transformers with depth K < 3 and 0.0005 otherwise. We use a batch size of 2048 and apply gradient clipping to obtain gradients with a global norm of 10. We used the Optax library.
- Haiku weight initialisation (fan-in) with a truncated normal and std 0.002/K, where K is the number of layers.
- We did not use any regularisation and observed, for deeper Transformers with K > 2, instabilities when reaching GD performance. We speculate that this occurs because the GD performance is, for the given training tasks, already close to divergence, as seen when providing tasks with larger input ranges. Therefore, training Transformers also becomes unstable when we approach GD with an optimal learning rate. In order to stabilize training, we simply clipped the token values to the range [−10, 10].
- When applicable, we use standard positional encodings of size 20, which we concatenate to all tokens.
- For simplicity, and to follow the provided weight construction closely, we use square key, value and query parameter matrices in all experiments.
- The training length varied throughout our experimental setups and can be read off our training plots in the article.
- When training the meta-parameters of gradient descent, i.e. η and γ, we used an identical training setup, but training usually required far fewer iterations.
- In all experiments, we choose the initial $W_0 = 0$ for models trained with gradient descent.

Inspired by (Garg et al., 2022), we additionally provide results when training a single linear self-attention layer on a fixed number of training tasks. To this end, we iterate over a single fixed batch of size B instead of drawing a new batch of tasks at every iteration. Results can be found in Figure 16.
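A small Python sketch of the two data regimes (a fresh batch of tasks at every step versus a single fixed batch of B tasks), together with the hand-tuned values listed above collected in one place. The task distribution (inputs from $U(-1, 1)$, $W \sim N(0, I)$, $N = N_x = 10$), the token layout and all function and key names are illustrative assumptions; the optimizer itself (Adam with global-norm clipping in Optax) and the Haiku initializer are not reproduced here.

```python
import numpy as np

def sample_tasks(rng, batch, n=10, nx=10):
    """Draw `batch` regression tasks as token arrays of shape (batch, n + 1, nx + 1).

    Assumptions for illustration: x ~ U(-1, 1), W ~ N(0, I), y = W x, and the last
    token is the query with its target slot set to 0.
    """
    X = rng.uniform(-1.0, 1.0, size=(batch, n + 1, nx))
    W = rng.normal(size=(batch, nx, 1))
    Y = X @ W                                          # (batch, n + 1, 1)
    tokens = np.concatenate([X, Y], axis=-1)
    targets = tokens[:, -1, -1].copy()                 # query targets to be predicted
    tokens[:, -1, -1] = 0.0
    return np.clip(tokens, -10.0, 10.0), targets       # token clipping to [-10, 10]

def batches(seed, batch_size, fixed):
    """Yield a fresh batch at every step, or the same fixed batch of B tasks forever."""
    rng = np.random.default_rng(seed)
    fixed_batch = sample_tasks(rng, batch_size) if fixed else None
    while True:
        yield fixed_batch if fixed else sample_tasks(rng, batch_size)

def hparams(num_layers):
    """Hand-tuned values from the list above (the key names are ours)."""
    return {
        "learning_rate": 1e-3 if num_layers < 3 else 5e-4,   # Adam
        "batch_size": 2048,
        "clip_global_norm": 10.0,
        "init_std": 0.002 / num_layers,                      # truncated normal, fan-in
        "positional_encoding_size": 20,
    }

fixed_stream = batches(seed=0, batch_size=128, fixed=True)
fresh_stream = batches(seed=0, batch_size=128, fixed=False)
t1, _ = next(fixed_stream)
t2, _ = next(fixed_stream)
f1, _ = next(fresh_stream)
f2, _ = next(fresh_stream)
print(np.array_equal(t1, t2), np.array_equal(f1, f2))
print(hparams(2), hparams(5))
```

In the fixed regime the model sees the same B tasks at every step, so full-batch training on them corresponds to the (non-stochastic) gradient descent setting of Figure 16.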
Intriguingly, we find that (meta-)gradient descent finds Transformer weights that align remarkably well with the provided construction, and therefore with gradient descent, even when provided with an arguably very small number of training tasks. We argue that this again highlights the strong inductive bias of the LSA layer to match (approximately) gradient descent learning in its forward pass.

[Figure 16 panels: one step of gradient descent vs. an LSA layer trained on (a) 128, (b) 512, (c) 2048 and (d) 8192 tasks. Each row shows the loss and the alignment (Preds diff, Model diff) vs. training steps, and tests on larger inputs and larger targets (with W ~ N(0, I)), comparing GD, Interpolated and Trained TF.]

Figure 16. Comparing trained Transformers with GD and their weight interpolation when training the Transformer on a fixed training set of size B. Across our alignment measures as well as our tests of out-of-training behaviour, trained Transformers fail to align with GD when provided with a very small number of tasks. Nevertheless, we already see almost perfect alignment in our base setting $N = N_x = 10$ when provided with B > 2048 tasks. In all settings, we train the Transformer with (non-stochastic) gradient descent, iterating over a single batch of tasks of size B equal to the number given in the panel titles (128, 512, 2048, 8192).