# Trainable Transformer in Transformer

Abhishek Panigrahi\*¹, Sadhika Malladi\*¹, Mengzhou Xia¹, Sanjeev Arora¹

\*Equal contribution. ¹Department of Computer Science, Princeton University. Correspondence to: Abhishek Panigrahi, Sadhika Malladi. Proceedings of the 41st International Conference on Machine Learning, Vienna, Austria. PMLR 235, 2024. Copyright 2024 by the author(s).

Recent works attribute the capability of in-context learning (ICL) in large pre-trained language models to implicitly simulating and fine-tuning an internal model (e.g., linear or 2-layer MLP) during inference. However, such constructions require large memory overhead, which makes simulation of more sophisticated internal models intractable. In this work, we propose a new efficient construction, Transformer in Transformer (in short, TINT), that allows a transformer to simulate and fine-tune more complex models during inference (e.g., pre-trained language models). In particular, we introduce innovative approximation techniques that allow a TINT model with less than 2 billion parameters to simulate and fine-tune a 125 million parameter transformer model within a single forward pass. TINT accommodates many common transformer variants and its design ideas also improve the efficiency of past instantiations of simple models inside transformers. We conduct end-to-end experiments to validate the internal fine-tuning procedure of TINT on various language modeling and downstream tasks. For example, even with a limited one-step budget, we observe that TINT for an OPT-125M model improves performance by 4-16% absolute on average compared to OPT-125M. These findings suggest that large pre-trained language models are capable of performing intricate subroutines. To facilitate further work, a modular and extensible codebase for TINT is included.

1. Introduction

Large transformers (Vaswani et al., 2017) have brought about a revolution in language modeling, with scaling yielding significant advancements in capabilities (Brown et al., 2020; Chowdhery et al., 2022). These capabilities include performing in-context learning or following natural language instructions at inference time. Researchers have tried to understand how these models can learn new tasks without parameter updates (Garg et al., 2022; von Oswald et al., 2023; Xie et al., 2022; Nanda et al., 2023). A popular hypothesis is that in-context learning corresponds to the transformer (referred to as the simulator from now on) simulating gradient-based learning of a smaller model (called the auxiliary model) that is embedded within it.

From the perspective of AI safety and alignment (Amodei et al., 2016; Leike et al., 2018; Askell et al., 2021), the ability of a larger model to use input data (which could be arbitrary in a deployed setting) to implicitly train an auxiliary model feels worrisome. This concern previously felt minor due to efficiency considerations: earlier analyses and experiments required the auxiliary model to be quite tiny compared to the simulator. For instance, simulating and training an auxiliary model that is a linear layer requires tens of millions of parameters in the simulator (Akyurek et al., 2022). This scaling is even more dramatic if the auxiliary model is a multi-layer fully-connected net (Giannou et al., 2023).

Our primary contribution is an explicit and nontrivial construction of a simulator called TINT that adapts to the context without parameter updates. In particular, we show that a forward pass through a modestly sized TINT can involve gradient-based training of an auxiliary model that is itself a large transformer. For example, we show that TINT with 2B parameters can faithfully simulate fine-tuning a 125M parameter auxiliary transformer in a single forward pass. (Prior constructions would have required trillions of parameters in the simulator for a far simpler auxiliary model.)
Our main result is described in Theorem 1.1, which details how the size of TINT depends on the auxiliary model. Our construction is generally applicable to diverse variants of pre-trained language models. The rest of the paper is structured to highlight the key design choices and considerations in TINT.

1. Section 2 discusses the overall design decisions required to make TINT, including how the simulator can read from and write to the auxiliary model and how the data must be formatted.
2. Section 3 uses the linear layer as an example to describe how highly parallelized computation and careful rearrangement of activations enable TINT to efficiently simulate the forward pass of the auxiliary model.
3. Section 4 describes how TINT uses first-order approximations and stop gradients to compute the simulated gradient of the auxiliary model.
4. Section 5 performs experiments comparing TINT to suitable baselines in language modeling and in-context learning settings. Our findings validate that the simulated gradient can effectively update large pre-trained auxiliary models. Notably, we instantiate TINT in a highly extensible codebase, making TINT the first such construction to undergo end-to-end evaluation.

TINT can efficiently perform simulated gradient descent of an auxiliary model.

Theorem 1.1. Consider an auxiliary transformer with $L$ layers, $D_{\text{aux}}$ embedding dimension, $H_{\text{aux}}$ attention heads, and a maximum sequence length of $T_{\text{aux}}$. Given a hyperparameter $S$ (see Section 3.1), TINT can perform an efficient forward pass (Section 3), compute the simulated gradient (Section 4), and evaluate the updated auxiliary model with a total of

(c1S2 + c3)D2 aux min(Haux, S2) D2 aux + c2SDaux min(S2, Haux) + c3 Taux Daux S min(Haux, S2)

parameters, with constants $c_1, c_2, c_3 < 150$. The TINT model has $D_{\text{sim}} = S D_{\text{aux}}$ embedding dimension and $H_{\text{sim}} = \min(S^2, H_{\text{aux}})$ attention heads. See Table 3 for a detailed breakdown of the parameters.
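To make the two dimension formulas in Theorem 1.1 concrete, the following minimal sketch evaluates $D_{\text{sim}} = S D_{\text{aux}}$ and $H_{\text{sim}} = \min(S^2, H_{\text{aux}})$. The function name `tint_dimensions` is hypothetical, and the example values (a 768-dimensional auxiliary model with 12 attention heads) are illustrative assumptions rather than figures taken from the theorem. The sketch does not attempt to evaluate the total parameter count.

```python
def tint_dimensions(d_aux, h_aux, s):
    """Return (D_sim, H_sim) following the relations stated in Theorem 1.1.

    d_aux: auxiliary embedding dimension D_aux
    h_aux: number of auxiliary attention heads H_aux
    s:     stacking hyperparameter S
    """
    d_sim = s * d_aux          # D_sim = S * D_aux
    h_sim = min(s * s, h_aux)  # H_sim = min(S^2, H_aux)
    return d_sim, h_sim


# Illustrative (assumed) auxiliary model: 768 dimensions, 12 heads.
for s in (2, 4):
    print(s, tint_dimensions(d_aux=768, h_aux=12, s=s))
```

Under these assumed values, a larger $S$ widens the simulator's embedding while the head count saturates at $H_{\text{aux}}$ once $S^2 \geq H_{\text{aux}}$.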
Due to the complexity of the construction, we defer the formal details of TINT to the appendix.

Notations. For a general model, we use $D$ to denote its embedding dimension, $H$ to denote the number of attention heads, $\theta$ to denote its set of all parameters, and $T$ to denote the length of an input sequence. We use subscripts sim and aux to differentiate the quantities between TINT and the auxiliary model. For example, $D_{\text{aux}}$ and $D_{\text{sim}}$ refer to the embedding dimension of the auxiliary model and TINT respectively. We use $e_t^{(\ell)} \in \mathbb{R}^{D_{\text{sim}}}$ and $x_t^{(\ell)} \in \mathbb{R}^{D_{\text{aux}}}$ to denote token embeddings in TINT and the auxiliary model at layer $\ell$ and sequence position $t$ respectively. For a matrix $A$, $a_j$ refers to its $j$th row, and for any vector $b$, $b_j$ refers to its $j$th element.

2. Design Considerations

Our goal is to construct a simulator that can train an auxiliary model over the course of an inference pass. This procedure requires four steps:

1. Forward Pass: A forward pass to compute the auxiliary model output $f(\xi; \theta_{\text{aux}})$ on training input $\xi$ and a loss $\mathcal{L}$.
2. Backward Pass: Backpropagation to compute the gradient of the auxiliary model, $\nabla_{\theta_{\text{aux}}} \mathcal{L}(f(\xi; \theta_{\text{aux}}))$.
3. Parameter Update: Update the auxiliary model using gradient descent, setting $\theta'_{\text{aux}} = \theta_{\text{aux}} - \eta \nabla_{\theta_{\text{aux}}} \mathcal{L}(f(\xi; \theta_{\text{aux}}))$.
4. Output: Output next-token predictions $f(\xi'; \theta'_{\text{aux}})$ on a test input $\xi'$ using the updated auxiliary model.

Note that steps 1-3 can be looped to train the auxiliary model for a few steps¹, either on the same training data or on different training data for each step, before evaluating it on the test input (Giannou et al., 2023). The above method highlights two crucial features of the simulator: (1) it has access to some amount of training data, and (2) it can use (i.e., read) and update (i.e., write) the auxiliary model. Below, we discuss how to design a modest-sized simulator around these two considerations.

2.1.
Input structure

For simplicity, we describe only one update step on a single batch of training data $\xi$ but note that our formal construction and our experiments handle multiple training steps (see Definition 5.1).

¹ Looping steps 1-3 scales the depth of the simulator model.

Figure 1: The overall structure of TINT (see Section 2 for an overview). Each forward, backward, and descent module is represented using combinations of linear, self-attention, layernorm, and activation layers. The input consists of prefix embeddings (Definition 2.1) that represent relevant auxiliary model parameters in each layer followed by natural language input. A prefix mask separates the train and test segments of the input (§2.1). [Diagram panels: ① simulated forward pass; ② backward simulation of the (i-1)th layer; ③ descent simulation of the ith layer.]

Steps 1 and 4 show that the simulator must access some training data $\xi$ to train the auxiliary model and some testing data $\xi'$ on which it evaluates the updated auxiliary model. For the sake of illustration we consider the following simple setting: given a sequence of input tokens $e_1, \ldots, e_T$, we split it into training data $\xi = e_1, \ldots, e_r$ and testing data $\xi' = e_{r+1}, \ldots, e_T$. Suppose $\xi$ contains an in-context input-output exemplar and $\xi'$ contains a test input. Then, the simulator performs a very natural operation of training the auxiliary model on a task-specific example and outputs results for the test example. On the other hand, if the input is not specially formatted, $\xi$ and $\xi'$ may simply contain some natural language tokens.
In this case, the simulator is using the first part of the context tokens to do a quick fine-tune of the auxiliary model for some task before outputting the subsequent tokens with the auxiliary model. In a worst-case scenario, users might provide harmful content, leading the model to implicitly fine-tune on it and potentially output even more harmful content. Our experiments consider many options for splitting a sequence into $\xi$ and $\xi'$, and we defer a more detailed discussion of possible setups to Section 5.

Accessing Training Labels. The simulator must be able to see the labels of the training tokens in order to compute the loss $\mathcal{L}$ (usually, the autoregressive cross-entropy loss) in step 1. For example, in Figure 1, when we compute the loss for the token $e_2$ in the second position, we need to use its label $e_3$ in the third position. However, this is not possible if the simulator uses strictly autoregressive attention (Appendix G contains a more general discussion). We thus use a bidirectional attention mask on the training tokens and autoregressive attention on the evaluation portion. We note that encoding relevant (e.g., retrieved) context with bidirectional attention is a popular way to improve autoregressive capabilities in language modeling and natural language tasks (Raffel et al., 2020; Borgeaud et al., 2022; Izacard and Grave, 2020; Izacard et al., 2023; Wang et al., 2023a; Tay et al., 2022). This empirical approach is similar in motivation to how TINT uses a few context tokens to adapt the auxiliary model to a given input. Having established the training and testing data, we can now move to discussing how the simulator can access (i.e., read) and update (i.e., write to) the auxiliary model at inference time.

2.2. Read and write access to auxiliary model

As discussed at the start of this section, the simulator must have read and write access to the parameters of the auxiliary model.
Crucially, the simulator must do at least two forward passes through the auxiliary model, one with the current parameters $\theta_{\text{aux}}$ and one with the updated parameters $\theta'_{\text{aux}}$. The straightforward way to simulate the forward pass of the auxiliary model would be to store its weights in the simulator's weights and run a forward pass as usual. One can analogously simulate the backward pass according to the loss $\mathcal{L}$ to compute the gradients. However, the simulator cannot update its own weights at inference time, so this strategy would not permit the model to write the updated parameters $\theta'_{\text{aux}}$ and later read them when simulating the second forward pass. Therefore, the auxiliary model $\theta_{\text{aux}}$ must be available in the activations of the simulator.

To this end, Wei et al. (2021) and Perez et al. (2021) model the simulator after a Turing machine, where the activation $e_t^{(\ell)} \in \mathbb{R}^{D_{\text{sim}}}$ in each layer acts as a workspace for operations, and computation results are copied to and from memory using attention operations. In this paradigm, if $D_{\text{aux}} = 768$, computing a dot product $\langle w, x_t^{(\ell)} \rangle$ with weight $w \in \mathbb{R}^{768}$ requires at least 6.4 million parameters in the simulator². Given the pervasiveness of dot products in neural network modules, this strategy would yield a simulator with trillions of parameters.

² Using a feedforward module to mimic the dot product (as in Akyurek et al. (2022), see thm. B.5), where the simulator embedding comprises $[w, x_t] \in \mathbb{R}^{1536}$, necessitates a minimum of 4.7 million parameters. Using an attention module to copy the weight from memory adds another 1.7 million parameters.

Figure 2: TINT simulates the forward pass of a linear layer with an $H_{\text{sim}}$-head attention layer ($H_{\text{sim}} = 6$ here). We stack $S$ weights per prefix embedding to reduce the number of prefix embeddings required ($S = 2$ here). We furthermore shard each weight and token embedding $x_t$ into $S'$ shards and compute inner products of each shard in parallel using $S \times S'$ attention heads ($S' = 3$ here).
Please see Section 3.1.

Alternatively, one can store parameters in the first few context tokens and allow the attention modules to attend to those tokens (Giannou et al., 2023). This removes the need for copying and token-wise operations. Then, the same dot product requires only a self-attention module with 1.7 million parameters. We thus adopt this strategy to provide relevant auxiliary model weights as prefix embeddings.

Definition 2.1 (Prefix Embeddings). $\{v_j^{(\ell)}\}_{j=1}^{K}$ denotes the $K$ prefix embeddings at the $\ell$th layer in TINT. These contain relevant auxiliary model weights or simulated activations.

We now consider how to efficiently simulate the building block of neural networks: matrix-vector multiplication. In the next section, we demonstrate that a careful construction of the prefix embeddings enables efficient parallelization of matrix-vector products across attention heads.

3. Efficient Forward Propagation

We now discuss how TINT performs a highly efficient forward pass through the auxiliary model. Here, we focus on the linear layer because it is repeated many times in various transformer modules (e.g., in self-attention), so improving the efficiency dramatically reduces TINT's size.

Definition 3.1 (Linear layer). For a weight $W \in \mathbb{R}^{D_{\text{aux}} \times D_{\text{aux}}}$, a linear layer takes $x \in \mathbb{R}^{D_{\text{aux}}}$ as input and outputs $y = Wx$.³

³ Linear layers are applied token-wise, so we can consider a single position $t$ without loss of generality.

We compute $y$ coordinate-wise, i.e., $\langle w_i, x_t \rangle$ for all $i \in [D_{\text{aux}}]$, where $w_i$ is the $i$th row of $W$. The simulator represents $\langle w_i, x_t \rangle$ as an attention score between the row $w_i$ and the input $x_t$. So, the input embeddings $e_t$ contain $x_t$ in the first $D_{\text{aux}}$ coordinates, and the rows $\{w_i\}$ of the weight matrix $W$ are in prefix embeddings $\{v_j\}$ (def. 2.1). We strategically distribute the weights (§3.1) and aggregate the parallelized computation results (§3.2).
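The idea that each coordinate of $Wx_t$ is an attention score between a prefix row $w_i$ and the input $x_t$ can be sketched in a few lines. This is a simplified illustration, not the TINT construction itself: it assumes unnormalized (linear) attention without softmax, one prefix token per row, and one-hot value vectors, and it ignores the head-level layout discussed in Sections 3.1 and 3.2. The function names are hypothetical.

```python
import random


def dot(u, v):
    """Plain inner product <u, v>."""
    return sum(a * b for a, b in zip(u, v))


def linear_layer_via_attention(prefix_rows, x_t):
    """Compute y = W x_t with one linear-attention read over prefix tokens.

    Keys are the rows w_i of W (held in the prefix), the query is x_t, and
    the value attached to row i is the one-hot vector e_i. Without softmax,
    the attention output is sum_i <w_i, x_t> e_i, which equals W x_t.
    """
    n = len(prefix_rows)
    scores = [dot(w_i, x_t) for w_i in prefix_rows]  # attention scores
    values = [[1.0 if c == i else 0.0 for c in range(n)] for i in range(n)]
    return [sum(scores[i] * values[i][c] for i in range(n)) for c in range(n)]


def matvec(W, x):
    """Reference matrix-vector product, accumulated column by column."""
    y = [0.0] * len(W)
    for j, x_j in enumerate(x):
        for i in range(len(W)):
            y[i] += W[i][j] * x_j
    return y


random.seed(0)
D_AUX = 6  # small illustrative dimension
W = [[random.gauss(0.0, 1.0) for _ in range(D_AUX)] for _ in range(D_AUX)]
x_t = [random.gauss(0.0, 1.0) for _ in range(D_AUX)]
attention_gap = max(
    abs(a - b) for a, b in zip(linear_layer_via_attention(W, x_t), matvec(W, x_t))
)
print(f"max |attention - matvec| = {attention_gap:.2e}")
```

In this simplified form every row needs its own prefix token and a single head does all the work, which is the inefficiency that distributing weights across heads is meant to address.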
As we briefly mentioned in the previous section, a straightforward construction of the linear layer would use the context and attention heads inefficiently. Our construction instead parallelizes the computation across attention heads in such a way that aggregating the output of the linear operation can also be conducted efficiently.

3.1. Stacking and Sharding

We partition the inner product computation across attention heads by carefully rearranging the weights and activations via stacking and sharding (Figure 2). Instead of representing each weight $w_i$ as its own prefix token $v_i$, we stack $S$ weights on top of each other to form each prefix embedding $v_i$. $S$ drives a trade-off between the embedding dimension of TINT, $D_{\text{sim}} = D_{\text{aux}} S$, and the context length of TINT, $T_{\text{sim}} = K + T_{\text{aux}}$. We set $S = 4$.

A simple strategy now would be to use different attention heads to operate on different rows; however, this would still use only $S$ attention heads whereas we could parallelize across many more heads. We instead parallelize across more attention heads, where each head is responsible for computing the inner product on a subset of the coordinates. We shard each individual weight and the activation into $S'$ parts and compute the inner product on each of the $S'$ parts in parallel. We set $S$ and $S'$ such that $H_{\text{sim}} = S \times S'$, thereby using all of TINT's heads to efficiently compute the dot products.

3.2. Efficient Aggregation

The attention module outputs a sparse matrix with shape $(D_{\text{sim}}/H_{\text{sim}}) \times H_{\text{sim}}$ containing the inner products on various subsets of the coordinates in its entries. To complete the linear forward pass, we need to sum the appropriate terms to form a $D_{\text{sim}}$-length vector with $Wx$ in the first $D_{\text{aux}}$ coordinates. Straightforwardly summing along an axis aggregates incorrect terms, since the model was sharded. On the other hand, rearranging the matrix would require an additional $D_{\text{sim}} \times D_{\text{sim}}$ linear layer.
Instead, TINT saves a factor of $H_{\text{sim}}$ parameters by leveraging the local structure of the attention output. We illustrate this visually in Appendix C.1. This procedure requires $D_{\text{sim}}^2/H_{\text{sim}} + D_{\text{sim}} H_{\text{sim}}$ parameters. This efficient aggregation also compresses the constructions for TINT's backpropagation modules for layer normalization and activations (Appendices E and F).

4. Simulated Gradient

TINT adapts backpropagation to compute gradients (Figure 1). We aim to train a capable (i.e., pre-trained) auxiliary model for just a few steps, so high precision gradients may be unnecessary. Instead, TINT performs an approximate backpropagation. TINT then uses this simulated gradient to update the auxiliary model. Prior works computed similar approximate gradients in hopes of more faithfully modeling neurobiology (Scellier and Bengio, 2017; Hinton, 2022) or improving the efficiency of training models (Hu et al., 2021; Malladi et al., 2023). We note that the approximations in the simulated gradients can be made stronger at the cost of enlarging TINT. Indeed, one could construct a simulator to exactly perform the procedure outlined in §2, though it would be orders of magnitude larger than TINT. For brevity's sake, we focus on the key approximations and design choices and defer formal details to the appendix.

4.1. First-order approximations

We use first-order approximations of gradients to backpropagate through the layer normalization layer.⁴ It normalizes the input using its mean and standard deviation across the input dimensions. Since the operation is token-wise, we can consider a single position $t$ without loss of generality.

Definition 4.1 (Layer normalization). A layer normalization layer $f_{\text{ln}}$ takes input $x \in \mathbb{R}^{D_{\text{aux}}}$ and outputs $y = (x - \mu)/\sigma$, where $\mu$ and $\sigma$ denote its mean and standard deviation.

⁴ We discuss a layer normalization layer $f_{\text{ln}}$ without scale and bias parameters, but Appendix E contains a general construction.

High precision gradients: Formally, for input-output pair $(x, y)$, we can compute the gradients $\partial_y$, $\partial_x$ with the chain rule:

$$\partial_x = \left(\frac{\partial f_{\text{ln}}(x)}{\partial x}\right)^{\top} \partial_y = \frac{1}{\sigma}\left(\partial_y - \frac{1}{D_{\text{aux}}}\sum_{i=1}^{D_{\text{aux}}} \partial_{y_i}\,\mathbf{1} - \frac{1}{D_{\text{aux}}}\langle \partial_y, y\rangle\, y\right). \tag{1}$$

Inefficiency of exact computation: A TINT layer simulating backpropagation through an auxiliary model's layer normalization layer receives $\partial_{y_t}$ and $x_t$ in its input embeddings. We go through the exact gradient and why it is inefficient. For exact computation one could first compute $y_t$ using a normalization layer and store it in the embeddings. However, inefficiency arises from computing the term $\langle \partial_{y_t}, y_t \rangle y_t$. To calculate $\langle \partial_{y_t}, y_t \rangle y_t$ at each token position $t$, we could either (1) use a two-layer MLP that focuses on each token separately, or (2) use a single self-attention module to treat the operation as a sequence-to-sequence task.

For (1) we could initially compute $\langle \partial_{y_t}, y_t \rangle$ via an MLP, followed by computation of $\langle \partial_{y_t}, y_t \rangle y_t$ using another MLP. The element-wise multiplication in embeddings would be facilitated with a nonlinear activation function like GELU (Akyurek et al., 2022) (refer to thm. B.5 for details). However, this approach would need a substantial number of simulator parameters to represent the MLPs.

Alternatively, we could use a single self-attention module. Constructing such a module would require careful engineering to make sure the input tokens only attend to themselves while keeping an attention score of 0 to others. If we used a linear attention, we would need to space out the gradient $\partial_{y_t}$ and $x_t$ in each position $t$, such that the attention score is 0 between different tokens. This would require an embedding dimension proportional to the context length. On the other hand, if we used a softmax attention module, we would need an additional superfluous token in the sequence. Then, a token at position $t$ would attend to itself with attention $\langle \partial_{y_t}, y_t \rangle$ and to the extra token with an attention score of $1 - \langle \partial_{y_t}, y_t \rangle$. The extra token would return a value vector $0$. To avoid such inefficiency, we opt for a first-order approximation instead.
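The exact layer-normalization gradient and a first-order alternative can be compared numerically. The sketch below is illustrative only. The closed form is derived directly from Definition 4.1 under the assumption that $\sigma$ is the population standard deviation with no epsilon term. The first-order estimate is a finite difference $(f_{\text{ln}}(x + \epsilon\,\partial_y) - f_{\text{ln}}(x))/\epsilon$, and the step size $\epsilon = 10^{-6}$ and all function names are assumptions of this example rather than part of the construction.

```python
import math
import random


def layer_norm(x):
    """Layer norm without scale or bias: y = (x - mean) / std (population std)."""
    d = len(x)
    mu = sum(x) / d
    sigma = math.sqrt(sum((v - mu) ** 2 for v in x) / d)
    return [(v - mu) / sigma for v in x], sigma


def layer_norm_backward_exact(x, grad_y):
    """Exact gradient w.r.t. x obtained from the chain rule."""
    d = len(x)
    y, sigma = layer_norm(x)
    mean_g = sum(grad_y) / d
    g_dot_y = sum(g * y_i for g, y_i in zip(grad_y, y))
    return [(g - mean_g - y_i * g_dot_y / d) / sigma for g, y_i in zip(grad_y, y)]


def layer_norm_backward_first_order(x, grad_y, eps=1e-6):
    """Finite-difference estimate (f(x + eps * grad_y) - f(x)) / eps."""
    y_shifted, _ = layer_norm([v + eps * g for v, g in zip(x, grad_y)])
    y_base, _ = layer_norm(x)
    return [(a - b) / eps for a, b in zip(y_shifted, y_base)]


random.seed(0)
D_AUX = 8  # small illustrative dimension
x = [random.gauss(0.0, 1.0) for _ in range(D_AUX)]
grad_y = [random.gauss(0.0, 1.0) for _ in range(D_AUX)]
exact = layer_norm_backward_exact(x, grad_y)
approx = layer_norm_backward_first_order(x, grad_y)
ln_gap = max(abs(a - b) for a, b in zip(exact, approx))
print(f"max |exact - first-order| = {ln_gap:.2e}")
```

The estimate needs only two evaluations of the normalization itself and no explicit inner products with $y$. It matches the exact gradient here because the Jacobian of this particular normalization is symmetric, so the directional derivative along $\partial_y$ equals the backpropagated gradient.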
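Efficient approximation: Instead of explicitly computing each term in the chain rule of $\left(\frac{\partial f_{\text{ln}}(x)}{\partial x}\right)^{\top} \partial_y$ in Eq. 1, we instead use a first order Taylor expansion of $f_{\text{ln}}$:

$$f_{\text{ln}}(x + \epsilon\,\partial_y) = f_{\text{ln}}(x) + \epsilon\,\frac{\partial f_{\text{ln}}(x)}{\partial x}\,\partial_y + O(\epsilon^2).$$

Rearranging allows us to write

$$\frac{\partial f_{\text{ln}}(x)}{\partial x}\,\partial_y = \frac{1}{\epsilon}\left(f_{\text{ln}}(x + \epsilon\,\partial_y) - f_{\text{ln}}(x)\right) + O(\epsilon).$$

Similar to the computation of Eq. 1, we can show

$$\frac{\partial f_{\text{ln}}(x)}{\partial x} = \frac{1}{\sigma}\left(I - \frac{1}{D_{\text{aux}}}\mathbf{1}\mathbf{1}^{\top} - \frac{1}{D_{\text{aux}}} f_{\text{ln}}(x) f_{\text{ln}}(x)^{\top}\right).$$

Because $\partial f_{\text{ln}}(x)/\partial x$ is symmetric⁵, we can write

$$\partial_x = \left(\frac{\partial f_{\text{ln}}(x)}{\partial x}\right)^{\top} \partial_y = \frac{\partial f_{\text{ln}}(x)}{\partial x}\,\partial_y = \frac{1}{\epsilon}\left(f_{\text{ln}}(x + \epsilon\,\partial_y) - f_{\text{ln}}(x)\right) + O(\epsilon).$$

Then, ignoring the small error term, we can use just two linear layers, separated by a normalization layer, to simulate the approximation.

4.2. Fuzzy backpropagation via stop gradients

Self-attention is inherently quadratic, because it uses the keys and queries to compute attention scores between every possible pair of tokens in the sequence. These scores then linearly combine the value vectors (see def. B.1 for a formal definition). Computing the gradient exactly is thus a very complex operation. Instead, we stop the gradient computation through attention scores in the self-attention layer. For similar reasons, we only update the value parameter in the self-attention module.

Gradient backpropagation: For an input, output sequence pair $\{x_t\}, \{y_t\}$, if $\{q_t, k_t, v_t\}$ denote the intermediate query, key, value vectors, then given the gradients $\{\partial_{y_t}\}$, the gradient $\{\partial_{x_t}\}$ is given via the chain rule:

$$\partial_{x_t} = Q^{\top} \partial_{q_t} + K^{\top} \partial_{k_t} + V^{\top} \partial_{v_t}. \tag{2}$$

Here, $Q$, $K$, $V$ denote the query, key, and value matrices.

Inefficiency in exact computation: Here, we demonstrate that simulating computation of the three terms in Eq. 2 is inefficient, because $\partial_{q_t}$, $\partial_{k_t}$ depend on the derivatives w.r.t. the attention scores. As an example, we focus on $\partial_{k_t}$:

$$\sum_j a_{t,j}\left((\partial_{y_t})^{\top} v_j\right)\Big(k_j - \sum_{j'} a_{t,j'} k_{j'}\Big).$$

Computing this term would require at least 2 self-attention layers and an MLP layer. The first attention layer would compute $(\partial_{y_t})^{\top} v_j$ for different token pairs, similar to the forward simulation of a linear layer with linear attention (§3).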
These would then be multiplied with the pair-wise attention scores $a_{t,j}$ by an MLP to compute $a_{t,j}\left((\partial_{y_t})^{\top} v_j\right)$, with the elementwise product facilitated by the GeLU non-linearity (thm. B.5). These would be finally used by an attention layer to combine the different key vectors. A similar simulation would be necessary to compute $\partial_{q_t}$.

⁵ For a linear function $f$ with matrix $W$, $\partial f(x)/\partial x = W$. Since $W$ may not be a symmetric matrix, this method can't be generally applied to approximately backpropagate linear layers or causal self-attention layers.

Table 1: Language modeling results on WIKITEXT-103. We use 30%, 50%, 70% and 90% of sequences for training in the language modeling setting (§5.2). TINT improves the auxiliary model perplexities by 0.3-0.7 absolute on average. The small perplexity difference between TINT and explicitly updating the auxiliary model suggests that the simulated gradient (Section 4) can still effectively fine-tune the auxiliary model.

| Model | Evaluating with | 30% (training proportion) | 50% | 70% | 90% |
|---|---|---|---|---|---|
| GPT-2 | Auxiliary Model | 25.6 | 24.9 | 24.5 | 23.3 |
| GPT-2 | Fine-tuning | 24.9 | 24.0 | 23.5 | 22.2 |
| GPT-2 | TINT | 25.1 | 24.3 | 23.8 | 22.6 |
| OPT-125M | Auxiliary Model | 29.6 | 28.8 | 28.0 | 28.0 |
| OPT-125M | Fine-tuning | 29.0 | 28.2 | 27.4 | 27.4 |
| OPT-125M | TINT | 29.3 | 28.4 | 27.5 | 27.4 |

Stop gradients through query and key vectors: In order to reduce the necessary resources, we ignore the query and key gradients in Eq. 2. When we ignore these gradient components, $\{\partial_{x_t}\}$ can be simplified as

$$\partial_{x_t} \approx V^{\top} \partial_{v_t} = V^{\top} \sum_j a_{j,t}\, \partial_{y_j}. \tag{3}$$

A single self-attention layer can compute this by using the attention scores to combine the token-wise gradients.

Why won't it hurt performance? Estimating $\partial_{x_t}$ as described is motivated by recent work (Malladi et al., 2023) showing that fuzzy gradient estimates don't adversely affect fine-tuning of pre-trained models. Furthermore, we theoretically show that when the attention head for each position pays a lot of attention to a single token (i.e., behaves like hard attention (Perez et al., 2021)), the approximate gradient in Eq. 3 is entry-wise close to the true gradients (thm. D.5).

The other approximation is to update only the value parameters $V$ of the auxiliary model (§D). This is motivated by parameter efficient fine-tuning methods like LoRA (Hu et al., 2021) and IA3 (Liu et al., 2022), which restrict the expressivity of the gradient updates without degrading the quality of the resulting model. We similarly show in the next section that the simulated gradients in TINT can effectively tune large pre-trained transformers.

5. Experiments

We evaluate the performance of the TINTs constructed using GPT2 and OPT-125M as auxiliary models. The findings from our experiments in the language modeling and in-context learning settings confirm that fine-tuning with the simulated gradients (Section 4) still allows for effective learning in the auxiliary model. We loop the training steps (i.e., steps 1-3) outlined in Section 2 to accommodate solving real-world natural language tasks. We formalize the setting below.

Figure 3: Different settings in few-shot learning (k = 3) using TINT. The Single mode (left) treats each example as a training datapoint, and the auxiliary model is updated with a batch of inputs (see def. 5.1). The Multi. mode (right) concatenates all examples to form a single input and uses batch size 1 in def. 5.1. For Label loss, only underlined label words are used as training signal, while full context loss includes all tokens. [Example inputs shown in both panels: "Review: goes to absurd lengths. Sentiment: Negative"; "Review: contains no wit , only labored gags . Sentiment: Negative"; "Review: the greatest musicians Sentiment: Positive".]

5.1.
Setting: N-step Fine-Tuning

We formalize the procedure in Section 2 to construct a suitable setting in which we can compare TINT to explicitly training the auxiliary model.

Definition 5.1 (N-step Fine-Tuning). Given a batch of training datapoints $\xi_1, \ldots, \xi_B$ and a validation input $\xi'$, we compute and apply gradient updates on the auxiliary model $\theta_{\text{aux}}$ for timesteps $t = 0, \ldots, N-1$ as

$$\theta^{t+1}_{\text{aux}} = \theta^{t}_{\text{aux}} - \eta \sum_{i=1}^{B} \nabla_{\theta} \mathcal{L}(f(\xi_i; \theta^{t}_{\text{aux}})),$$

where $\eta$ is the learning rate and $\mathcal{L}$ is a self-supervised loss function on each input $\xi_i$. Then, we evaluate the model $\theta^{N}_{\text{aux}}$ on $\xi'$. $\theta^{0}_{\text{aux}}$ denotes the pre-trained auxiliary model.

Below, we instantiate this setting with text inputs of different formats and different self-supervised loss functions $\mathcal{L}$. To manage computational demands, we limit $N$ to 3 or fewer.⁶

⁶ Performing many gradient steps scales the depth of TINT and makes experimentation computationally infeasible.

5.2. Case Study: Language Modeling

The first case we consider is language modeling, where the input data $e_1, \ldots, e_T$ is natural language without any additional formatting. We use a batch size of 1 in def. 5.1, and delegate $\xi_1 = e_1, \ldots, e_t$ and $\xi' = e_{t+1}, \ldots, e_T$. The loss $\mathcal{L}$ is the sum of the token-wise autoregressive cross-entropy loss in the sequence $\xi_1$. For example, given an input "Machine learning is a useful tool for solving problems.", we use the leading part of the sentence (colored red in the original) as the training data $\xi_1$, and the remainder (colored brown) as the validation data $\xi'$. We perform language modeling experiments on WIKITEXT-103 (Merity et al., 2016) and vary the number of tokens $t$ used as training data $\xi_1$.

Results. In Table 1, we observe that TINT achieves a performance comparable to explicit fine-tuning of the auxiliary model, indicating that the simulated gradient (Section 4) is largely effective for fine-tuning. Both TINT and explicitly fine-tuning the auxiliary model show improvement over the base model, confirming that minimal tuning on the context indeed enhances predictions on the test portion.

5.3.
Case Study: In-Context Learning

For in-context learning, we consider input data to be a supervised classification task transformed into a next-token prediction task using surrogate labels (see Figure 3). Using binary sentiment classification of movie reviews as an example, given an input (e.g., the review), the model's predicted label is computed as follows. First, we design a simple task-specific prompt (e.g., "Sentiment:") and select label words $c_1, \ldots, c_n$ to serve as surrogates for each class (e.g., "positive" and "negative"). Then, we provide the input along with the prompt to the model, and the label assigned the highest probability is treated as the model's prediction. We describe the zero-shot and few-shot settings below.

Zero-shot. In the zero-shot setting, we are given text with the first $T-1$ tokens as the input text and the final token as the surrogate text label. Hence, we adapt def. 5.1 to use batch size $B = 1$, training data $\xi_1 = x_1, \ldots, x_{T-1}$, and testing data $\xi' = x_T$. The loss $\mathcal{L}$ is again the sum of the token-wise autoregressive cross-entropy losses.

Few-shot. In the few-shot setting, we are given input texts that are a concatenation of $k$ sequences $\xi_1, \ldots, \xi_k$. Each sequence contains the input text followed by the surrogate label for the in-context exemplar. These $k$ exemplars are followed by test data $\xi'$. In this case, we can compute the gradient updates to $\theta_{\text{aux}}$ in two different ways (Figure 3). The first setting, denoted Single, treats the $k$ sequences as a batch of $B = k$ training datapoints $\xi_1, \ldots, \xi_B$. The second setting, denoted Multi, treats the concatenation of the $k$ sequences as a single training datapoint $\xi_1$.

Furthermore, $\mathcal{L}$ for a training datapoint can be defined in two different ways. The first setting, denoted as Full context loss, defines $\mathcal{L}$ for a training datapoint $\xi_i$ as the sum of cross entropy loss over all tokens. The second setting, denoted as Label loss, defines $\mathcal{L}$ for a training datapoint $\xi_i$ in def. 5.1 as the sum of cross entropy loss over the surrogate label tokens.

Table 2: Zero-shot and few-shot in-context learning results across 7 downstream tasks. All the few-shot results are averaged over three training seeds. TINT consistently surpasses its auxiliary model and achieves comparable performance to fine-tuning. TINT outperforms auxiliary models by 3-4% and 12-16% absolute points on average in 0-shot and 32-shot experiments respectively. TINT performs competitively with a similar-sized pre-trained model (OPT-1.3B) in both 0-shot and 32-shot settings. We show the standard deviation for few-shot settings in parentheses.

Without Calibration

| Model | Shots | Subj | AGNews | SST2 | CR | MR | MPQA | Amazon | Avg. |
|---|---|---|---|---|---|---|---|---|---|
| OPT-125M | 0 | 64.0 | 66.0 | 70.5 | 64.5 | 71.0 | 68.0 | 76.5 | 68.6 |
| OPT-1.3B | 0 | 59.0 | 55.5 | 54.0 | 50.5 | 52.5 | 74.0 | 57.0 | 57.5 |
| OPT-125M Fine-tuning | 0 | 71.0 | 67.0 | 79.5 | 71.5 | 70.0 | 68.0 | 85.5 | 73.2 |
| OPT-125M TINT | 0 | 67.5 | 66.0 | 76.5 | 69.0 | 76.0 | 70.5 | 78.5 | 72.0 |
| OPT-125M | 32 | 58.7 (4.9) | 33.7 (8.4) | 50.8 (1.2) | 51.3 (1.9) | 50.0 (0.0) | 54.3 (2.5) | 55.0 (6.7) | 50.5 (1.9) |
| OPT-1.3B | 32 | 74.2 (6.1) | 71.3 (5.3) | 89.8 (3.6) | 71.5 (4.5) | 68.3 (6.1) | 81.7 (3.3) | 70.3 (9.9) | 75.3 (0.4) |
| OPT-125M Fine-tuning | 32 | 78.0 (1.4) | 66.7 (1.6) | 71.5 (1.4) | 73.7 (3.3) | 72.0 (0.0) | 80.7 (0.6) | 79.8 (0.2) | 74.6 (2.7) |
| OPT-125M TINT | 32 | 82.3 (2.7) | 69.3 (0.9) | 73.7 (0.8) | 75.7 (1.9) | 72.3 (1.2) | 83.2 (1.0) | 78.2 (0.2) | 76.4 (0.7) |

With Calibration

| Model | Shots | Subj | AGNews | SST2 | CR | MR | MPQA | Amazon | Avg. |
|---|---|---|---|---|---|---|---|---|---|
| OPT-125M | 0 | 64.0 | 66.0 | 53.0 | 54.5 | 52.5 | 55.5 | 58.0 | 57.6 |
| OPT-1.3B | 0 | 73.5 | 61.5 | 57.5 | 53.0 | 54.5 | 79.5 | 61.0 | 62.9 |
| OPT-125M Fine-tuning | 0 | 62.5 | 66.0 | 60.5 | 53.5 | 54.0 | 56.5 | 74.5 | 61.1 |
| OPT-125M TINT | 0 | 64.0 | 66.0 | 56.5 | 59.0 | 53.5 | 62.0 | 66.5 | 61.1 |
| OPT-125M | 32 | 83.5 (2.4) | 40.7 (10.4) | 50.8 (0.8) | 67.7 (4.1) | 57.7 (10.8) | 79.2 (8.4) | 56.0 (8.1) | 62.2 (2.7) |
| OPT-1.3B | 32 | 51.8 (1.9) | 66.2 (3.1) | 93.7 (1.0) | 82.8 (2.8) | 91.3 (1.9) | 83.5 (2.5) | 92.0 (2.9) | 80.2 (0.7) |
| OPT-125M Fine-tuning | 32 | 87.2 (0.2) | 67.2 (0.6) | 72.8 (5.9) | 73.3 (2.6) | 66.7 (7.4) | 81.5 (3.7) | 70.3 (2.1) | 74.1 (2.9) |
| OPT-125M TINT | 32 | 85.3 (1.9) | 67.3 (0.6) | 71.8 (3.8) | 70.7 (1.9) | 63.7 (0.2) | 83.5 (1.6) | 77.5 (1.2) | 74.3 (1.4) |

Tasks.
We evaluate 7 classification tasks for zero-shot and few-shot settings: SST-2 (Socher et al., 2013), MR (Pang and Lee, 2004), CR (Hu and Liu, 2004), MPQA (Wiebe et al., 2005), Amazon Polarity (Zhang et al., 2015), AGNews (Zhang et al., 2015), and Subj (Pang and Lee, 2005).

Model. We compare a TINT model that uses a pre-trained OPT-125M model as its auxiliary model against two alternative approaches: (1) directly fine-tuning OPT-125M, and (2) performing standard evaluation using OPT-1.3B, which is of a similar size to TINT. Our construction is generally applicable to diverse variants of pre-trained language models (Appendix J).

Calibration. We report the performance in Table 2 in two settings: without calibration and with calibration. With calibration, the predicted probabilities of the surrogate labels are normalized by the probabilities obtained when only the prompt is given as input.

No calibration: $\arg\max_{c_i} \Pr[c_i \mid \text{input}, \text{prompt}]$

Calibration: $\arg\max_{c_i} \dfrac{\Pr[c_i \mid \text{input}, \text{prompt}]}{\Pr[c_i \mid \text{prompt}]}$

This is a widely used calibration technique (Holtzman et al., 2021) for prompting language models.

Observations. We observe that inference passes through TINT perform on par with directly fine-tuning the auxiliary model, affirming the validity of the construction design (see Section 2). As expected, TINT outperforms the base auxiliary model, since it simulates training the auxiliary model. More intriguingly, TINT demonstrates performance comparable to a pre-trained model of similar size (OPT-1.3B). This suggests that the capabilities of existing pre-trained models may be understood via the simulation of smaller auxiliary models. We observe that calibration is not beneficial in every setting; such inconsistencies in the calibration method have also been observed in previous work (Brown et al., 2020). However, even with calibration, TINT remains competitive with fine-tuning of OPT models. The performance of OPT-1.3B improves with calibration.
In this case, TINT lags behind OPT-1.3B in the few-shot setting. For further details and results of the experiments, please refer to Appendix K.

6. Related Work

Gradient-based learning and in-context learning: Several works relate in-context learning to gradient-based learning algorithms. Bai et al. (2023) explicitly constructed transformers to simulate simple gradient-based learning algorithms. Mahankali et al. (2023) and Ahn et al. (2023) suggested that one attention layer mimics gradient descent on a linear layer, and Zhang et al. (2023b) showed polynomial convergence. Cheng et al. (2023) and Han et al. (2023) extended these ideas to non-linear attention. Experiments in Dai et al. (2022) suggest that LLM activations during in-context learning mirror those of fine-tuned models. These works focus on using a standard transformer as the simulator and hence cannot accommodate more complex auxiliary models; our work, on the other hand, uses structural modifications and approximations to construct an efficient simulator for complex auxiliary models. In contrast, our work attempts to build even stronger transformers by introducing a few structural modifications that can run gradient descent on auxiliary transformers.

Transformer Expressivity: Perez et al. (2021) and Pérez et al. (2019) show that transformers with hard attention are Turing complete, and Wei et al. (2021) construct transformers to study statistical learnability, but the proposed constructions are extremely large. Other works have investigated encoding specific algorithms in smaller simulators, e.g., bounded-depth Dyck languages (Yao et al., 2021), modular prefix sums (Anil et al., 2022), adders (Nanda et al., 2023), regular languages (Bhattamishra et al., 2020), and sparse logical predicates (Edelman et al., 2022). Liu et al. (2023) aim to understand automata-like mechanisms within transformers. Ba et al.
(2016) connect self-attention and fast weight programmers (FWPs), which compute input-dependent weight updates during inference. Follow-up works (Schlag et al., 2021; Irie et al., 2021) use self-attention layers to update linear and recurrent networks during inference. Clark et al. (2022) add and efficiently tune Fast Weights Layers (FWL) on a frozen pre-trained model.

7. Discussion

We present a parameter-efficient construction, TINT, capable of simulating gradient descent on an internal transformer model during inference. Using fewer than 2 billion parameters, it can simulate fine-tuning a 125 million parameter transformer (e.g., GPT-2) internally, dramatically reducing the scale required by previous works. Language modeling and in-context learning experiments demonstrate that the efficient approximations still allow TINT to fine-tune the model. Our work emphasizes that the inference behavior of complex models may rely on the training dynamics of smaller models. As such, the existence of TINT has strong implications for interpretability and AI alignment research.

Similar to prior research in this area, our insights into existing pre-trained models are limited. TINT was designed to understand the power of in-context reasoning with an explicit construction, and thereby to understand the safety risk of transformers trained with moderate compute. Hence, we introduce two major architectural modifications. TINT uses bidirectional attention for efficiently computing the loss on the training portion ξ (Section 2). Furthermore, we use prefix embeddings to efficiently represent the relevant auxiliary model parameters in each layer of TINT (Section 2). Both of these design principles are largely motivated by existing performant architectures (Tay et al., 2022; Raffel et al., 2020; Cheng et al., 2023; Izacard et al., 2023; Borgeaud et al., 2022) and fine-tuning strategies (Liu et al., 2021; Lester et al., 2021; Zhang et al., 2023a; Li and Liang, 2021).
However, because of the architecture modifications, TINT cannot be used to explain the in-context capability of existing popular autoregressive models (Touvron et al., 2023; Brown et al., 2020).

TINT provides a possible connection between fine-tuning and in-context reasoning with transformer models. As such, understanding the inference-time behavior of large language models may require understanding the training dynamics of smaller transformers. On the other hand, such a connection can also lead to explorations of improved architecture designs by measuring generalization behaviors of the underlying simulated auxiliary model (Li and Zhang, 2021; Ju et al., 2022). Furthermore, we have not yet examined potential biases that may arise in the auxiliary models due to one-step gradient descent. We plan to investigate these aspects in future work.

Impact Statement

We note that the construction of TINT does not appear to increase the probability of harmful behavior, because the construction's primary objective is to implicitly tune an internal model (Section 2). Such tuning has been possible for a long time and is not made more expressive by TINT.

Our findings suggest that existing transformer-based language models can plausibly possess the ability to learn and adapt to context by internally fine-tuning a complex model even during inference. Consequently, although users are unable to directly modify deployed models, these models may still undergo dynamic updates while processing a context left-to-right, resulting in previously unseen behavior by the time the model reaches the end of the context. This has significant implications for the field of model alignment. It is challenging to impose restrictions on a model that can perform such dynamic updates internally, so malicious content can influence the output of deployed models.

Alternatively, we recognize the potential benefits of pre-training constructed models that integrate explicit fine-tuning mechanisms.
By embedding the functionalities typically achieved through explicit fine-tuning, such as detecting malicious content and intent, within the models themselves, the need for external modules can be reduced. Pre-training the constructed model may offer a self-contained solution for ensuring safe and responsible language processing without relying on external dependencies.

Acknowledgements

The authors acknowledge funding from NSF, ONR, Simons Foundation, and DARPA. We thank Danqi Chen, Jason Lee, Zhiyuan Li, Kaifeng Lyu, Simran Kaur, Tianyu Gao, and Colin Wang for their suggestions and helpful discussions at different stages of our work. We thank the anonymous reviewers and the Area Chairs of NeurIPS '23, ICLR '24, and ICML '24 assigned to our paper for their helpful and detailed reviews and meta-reviews to improve the quality of our paper.

References

Kwangjun Ahn, Xiang Cheng, Hadi Daneshmand, and Suvrit Sra. Transformers learn to implement preconditioned gradient descent for in-context learning. arXiv preprint arXiv:2306.00297, 2023.

Ekin Akyurek, Dale Schuurmans, Jacob Andreas, Tengyu Ma, and Denny Zhou. What learning algorithm is in-context learning? investigations with linear models. arXiv preprint arXiv:2211.15661, 2022.

Dario Amodei, Chris Olah, Jacob Steinhardt, Paul Christiano, John Schulman, and Dan Mané. Concrete problems in ai safety. arXiv preprint arXiv:1606.06565, 2016.

Cem Anil, Yuhuai Wu, Anders Andreassen, Aitor Lewkowycz, Vedant Misra, Vinay Ramasesh, Ambrose Slone, Guy Gur-Ari, Ethan Dyer, and Behnam Neyshabur. Exploring length generalization in large language models. arXiv preprint arXiv:2207.04901, 2022.

Amanda Askell, Yuntao Bai, Anna Chen, Dawn Drain, Deep Ganguli, Tom Henighan, Andy Jones, Nicholas Joseph, Ben Mann, Nova DasSarma, et al. A general language assistant as a laboratory for alignment. arXiv preprint arXiv:2112.00861, 2021.

Jimmy Ba, Geoffrey Hinton, Volodymyr Mnih, Joel Z. Leibo, and Catalin Ionescu.
Using fast weights to attend to the recent past, 2016.

Yu Bai, Fan Chen, Huan Wang, Caiming Xiong, and Song Mei. Transformers as statisticians: Provable in-context learning with in-context algorithm selection. arXiv preprint arXiv:2306.04637, 2023.

Satwik Bhattamishra, Kabir Ahuja, and Navin Goyal. On the ability and limitations of transformers to recognize formal languages. arXiv preprint arXiv:2009.11264, 2020.

Sebastian Borgeaud, Arthur Mensch, Jordan Hoffmann, Trevor Cai, Eliza Rutherford, Katie Millican, George Bm Van Den Driessche, Jean-Baptiste Lespiau, Bogdan Damoc, Aidan Clark, et al. Improving language models by retrieving from trillions of tokens. In International conference on machine learning, pages 2206-2240. PMLR, 2022.

Tom Brown, Benjamin Mann, Nick Ryder, Melanie Subbiah, Jared D Kaplan, Prafulla Dhariwal, Arvind Neelakantan, Pranav Shyam, Girish Sastry, Amanda Askell, et al. Language models are few-shot learners. Advances in neural information processing systems, 33:1877-1901, 2020.

Stephanie Chan, Adam Santoro, Andrew Lampinen, Jane Wang, Aaditya Singh, Pierre Richemond, James McClelland, and Felix Hill. Data distributional properties drive emergent in-context learning in transformers. Advances in Neural Information Processing Systems, 35:18878-18891, 2022.

Xiang Cheng, Yuxin Chen, and Suvrit Sra. Transformers implement functional gradient descent to learn non-linear functions in context. arXiv preprint arXiv:2312.06528, 2023.

Aakanksha Chowdhery, Sharan Narang, Jacob Devlin, Maarten Bosma, Gaurav Mishra, Adam Roberts, Paul Barham, Hyung Won Chung, Charles Sutton, Sebastian Gehrmann, et al. Palm: Scaling language modeling with pathways. arXiv preprint arXiv:2204.02311, 2022.

Bilal Chughtai, Lawrence Chan, and Neel Nanda. A toy model of universality: Reverse engineering how networks learn group operations. arXiv preprint arXiv:2302.03025, 2023.
Kevin Clark, Kelvin Guu, Ming-Wei Chang, Panupong Pasupat, Geoffrey Hinton, and Mohammad Norouzi. Meta-learning fast weight language models. In Proceedings of the 2022 Conference on Empirical Methods in Natural Language Processing, pages 9751-9757, Abu Dhabi, United Arab Emirates, December 2022. Association for Computational Linguistics. URL https://aclanthology.org/2022.emnlp-main.661.

Arthur Conmy, Augustine N Mavor-Parker, Aengus Lynch, Stefan Heimersheim, and Adrià Garriga-Alonso. Towards automated circuit discovery for mechanistic interpretability. arXiv preprint arXiv:2304.14997, 2023.

Damai Dai, Yutao Sun, Li Dong, Yaru Hao, Zhifang Sui, and Furu Wei. Why can gpt learn in-context? language models secretly perform gradient descent as meta-optimizers, 2022.

Benjamin L Edelman, Surbhi Goel, Sham Kakade, and Cyril Zhang. Inductive biases and variable creation in self-attention mechanisms. In International Conference on Machine Learning, pages 5793-5831. PMLR, 2022.

N Elhage, N Nanda, C Olsson, T Henighan, N Joseph, B Mann, A Askell, Y Bai, A Chen, T Conerly, et al. A mathematical framework for transformer circuits. Transformer Circuits Thread, 2021.

Shivam Garg, Dimitris Tsipras, Percy S Liang, and Gregory Valiant. What can transformers learn in-context? a case study of simple function classes. Advances in Neural Information Processing Systems, 35:30583-30598, 2022.

Angeliki Giannou, Shashank Rajput, Jy-yong Sohn, Kangwook Lee, Jason D. Lee, and Dimitris Papailiopoulos. Looped transformers as programmable computers, 2023.

Linyuan Gong, Di He, Zhuohan Li, Tao Qin, Liwei Wang, and Tieyan Liu. Efficient training of bert by progressively stacking. In International conference on machine learning, pages 2337-2346. PMLR, 2019.

Michael Hahn and Navin Goyal. A theory of emergent in-context learning as implicit structure induction. arXiv preprint arXiv:2303.07971, 2023.

Chi Han, Ziqi Wang, Han Zhao, and Heng Ji.
In-context learning of large language models explained as kernel regression. arXiv preprint arXiv:2305.12766, 2023.

Dan Hendrycks and Kevin Gimpel. Gaussian error linear units (gelus). arXiv preprint arXiv:1606.08415, 2016.

Geoffrey Hinton. The forward-forward algorithm: Some preliminary investigations, 2022.

Ari Holtzman, Peter West, Vered Shwartz, Yejin Choi, and Luke Zettlemoyer. Surface form competition: Why the highest probability answer isn't always right. arXiv preprint arXiv:2104.08315, 2021.

Edward J Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, and Weizhu Chen. Lora: Low-rank adaptation of large language models. arXiv preprint arXiv:2106.09685, 2021.

Minqing Hu and Bing Liu. Mining and summarizing customer reviews. In Proceedings of the tenth ACM SIGKDD international conference on Knowledge discovery and data mining, pages 168-177, 2004.

Kazuki Irie, Imanol Schlag, Róbert Csordás, and Jürgen Schmidhuber. Going beyond linear transformers with recurrent fast weight programmers. In A. Beygelzimer, Y. Dauphin, P. Liang, and J. Wortman Vaughan, editors, Advances in Neural Information Processing Systems, 2021. URL https://openreview.net/forum?id=ot2ORiBqTa1.

Gautier Izacard and Edouard Grave. Leveraging passage retrieval with generative models for open domain question answering. arXiv preprint arXiv:2007.01282, 2020.

Gautier Izacard, Patrick Lewis, Maria Lomeli, Lucas Hosseini, Fabio Petroni, Timo Schick, Jane Dwivedi-Yu, Armand Joulin, Sebastian Riedel, and Edouard Grave. Atlas: Few-shot learning with retrieval augmented language models. Journal of Machine Learning Research, 24(251):1-43, 2023. URL http://jmlr.org/papers/v24/23-0037.html.

Hui Jiang. A latent space theory for emergent abilities in large language models. arXiv preprint arXiv:2304.09960, 2023.

Haotian Ju, Dongyue Li, and Hongyang R Zhang. Robust fine-tuning of deep neural networks with hessian-based generalization guarantees.
In International Conference on Machine Learning, pages 10431-10461. PMLR, 2022.

Ananya Kumar, Ruoqi Shen, Sébastien Bubeck, and Suriya Gunasekar. How to fine-tune vision models with sgd, 2022.

Jan Leike, David Krueger, Tom Everitt, Miljan Martic, Vishal Maini, and Shane Legg. Scalable agent alignment via reward modeling: a research direction. arXiv preprint arXiv:1811.07871, 2018.

Brian Lester, Rami Al-Rfou, and Noah Constant. The power of scale for parameter-efficient prompt tuning. arXiv preprint arXiv:2104.08691, 2021.

Dongyue Li and Hongyang Zhang. Improved regularization and robustness for fine-tuning in neural networks. Advances in Neural Information Processing Systems, 34:27249-27262, 2021.

Xiang Lisa Li and Percy Liang. Prefix-tuning: Optimizing continuous prompts for generation. arXiv preprint arXiv:2101.00190, 2021.

David Lindner, János Kramár, Matthew Rahtz, Thomas McGrath, and Vladimir Mikulik. Tracr: Compiled transformers as a laboratory for interpretability. arXiv preprint arXiv:2301.05062, 2023.

Bingbin Liu, Jordan T. Ash, Surbhi Goel, Akshay Krishnamurthy, and Cyril Zhang. Transformers learn shortcuts to automata. In The Eleventh International Conference on Learning Representations, 2023. URL https://openreview.net/forum?id=De4FYqjFueZ.

Haokun Liu, Derek Tam, Mohammed Muqeeth, Jay Mohta, Tenghao Huang, Mohit Bansal, and Colin A Raffel. Few-shot parameter-efficient fine-tuning is better and cheaper than in-context learning. Advances in Neural Information Processing Systems, 35:1950-1965, 2022.

Xiao Liu, Kaixuan Ji, Yicheng Fu, Weng Lam Tam, Zhengxiao Du, Zhilin Yang, and Jie Tang. P-tuning v2: Prompt tuning can be comparable to fine-tuning universally across scales and tasks. arXiv preprint arXiv:2110.07602, 2021.

Arvind Mahankali, Tatsunori B Hashimoto, and Tengyu Ma. One step of gradient descent is provably the optimal in-context learner with one layer of linear self-attention.
arXiv preprint arXiv:2307.03576, 2023.

Sadhika Malladi, Tianyu Gao, Eshaan Nichani, Alex Damian, Jason D. Lee, Danqi Chen, and Sanjeev Arora. Fine-tuning language models with just forward passes. In Thirty-seventh Conference on Neural Information Processing Systems, 2023. URL https://openreview.net/forum?id=Vota6rFhBQ.

Stephen Merity, Caiming Xiong, James Bradbury, and Richard Socher. Pointer sentinel mixture models. arXiv preprint arXiv:1609.07843, 2016.

Neel Nanda, Lawrence Chan, Tom Lieberum, Jess Smith, and Jacob Steinhardt. Progress measures for grokking via mechanistic interpretability. arXiv preprint arXiv:2301.05217, 2023.

Catherine Olsson, Nelson Elhage, Neel Nanda, Nicholas Joseph, Nova DasSarma, Tom Henighan, Ben Mann, Amanda Askell, Yuntao Bai, Anna Chen, et al. In-context learning and induction heads. arXiv preprint arXiv:2209.11895, 2022.

Bo Pang and Lillian Lee. A sentimental education: Sentiment analysis using subjectivity summarization based on minimum cuts. In Proceedings of the 42nd Annual Meeting of the Association for Computational Linguistics (ACL-04), pages 271-278, 2004.

Bo Pang and Lillian Lee. Seeing stars: Exploiting class relationships for sentiment categorization with respect to rating scales. In Proceedings of the 43rd Annual Meeting of the Association for Computational Linguistics (ACL'05), pages 115-124, 2005.

Jorge Perez, Pablo Barcelo, and Javier Marinkovic. Attention is turing-complete. Journal of Machine Learning Research, 22(75):1-35, 2021. URL http://jmlr.org/papers/v22/20-302.html.

Ofir Press, Noah A Smith, and Mike Lewis. Train short, test long: Attention with linear biases enables input length extrapolation. arXiv preprint arXiv:2108.12409, 2021.

Jorge Pérez, Javier Marinković, and Pablo Barceló. On the turing completeness of modern neural network architectures. In International Conference on Learning Representations, 2019. URL https://openreview.net/forum?id=HyGBdo0qFm.
Colin Raffel, Noam Shazeer, Adam Roberts, Katherine Lee, Sharan Narang, Michael Matena, Yanqi Zhou, Wei Li, and Peter J Liu. Exploring the limits of transfer learning with a unified text-to-text transformer. The Journal of Machine Learning Research, 21(1):5485-5551, 2020.

Sashank J Reddi, Sobhan Miryoosefi, Stefani Karp, Shankar Krishnan, Satyen Kale, Seungyeon Kim, and Sanjiv Kumar. Efficient training of language models using few-shot learning. 2023.

Nikunj Saunshi, Sadhika Malladi, and Sanjeev Arora. A mathematical exploration of why language models help solve downstream tasks. arXiv preprint arXiv:2010.03648, 2020.

Teven Le Scao, Angela Fan, Christopher Akiki, Ellie Pavlick, Suzana Ilić, Daniel Hesslow, Roman Castagné, Alexandra Sasha Luccioni, François Yvon, Matthias Gallé, et al. Bloom: A 176b-parameter open-access multilingual language model. arXiv preprint arXiv:2211.05100, 2022.

Benjamin Scellier and Yoshua Bengio. Equilibrium propagation: Bridging the gap between energy-based models and backpropagation. Frontiers in computational neuroscience, 11:24, 2017.

Imanol Schlag, Kazuki Irie, and Jürgen Schmidhuber. Linear transformers are secretly fast weight memory systems. CoRR, abs/2102.11174, 2021. URL https://arxiv.org/abs/2102.11174.

Noam Shazeer. Glu variants improve transformer. arXiv preprint arXiv:2002.05202, 2020.

Richard Socher, Alex Perelygin, Jean Wu, Jason Chuang, Christopher D Manning, Andrew Y Ng, and Christopher Potts. Recursive deep models for semantic compositionality over a sentiment treebank. In Proceedings of the 2013 conference on empirical methods in natural language processing, pages 1631-1642, 2013.

Jianlin Su, Yu Lu, Shengfeng Pan, Ahmed Murtadha, Bo Wen, and Yunfeng Liu. Roformer: Enhanced transformer with rotary position embedding. arXiv preprint arXiv:2104.09864, 2021.
Yi Tay, Mostafa Dehghani, Vinh Q Tran, Xavier Garcia, Jason Wei, Xuezhi Wang, Hyung Won Chung, Dara Bahri, Tal Schuster, Steven Zheng, et al. Ul2: Unifying language learning paradigms. In The Eleventh International Conference on Learning Representations, 2022.

Hugo Touvron, Thibaut Lavril, Gautier Izacard, Xavier Martinet, Marie-Anne Lachaux, Timothée Lacroix, Baptiste Rozière, Naman Goyal, Eric Hambro, Faisal Azhar, et al. Llama: Open and efficient foundation language models. arXiv preprint arXiv:2302.13971, 2023.

Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Łukasz Kaiser, and Illia Polosukhin. Attention is all you need. Advances in neural information processing systems, 30, 2017.

Johannes von Oswald, Eyvind Niklasson, Maximilian Schlegel, Seijin Kobayashi, Nicolas Zucchet, Nino Scherrer, Nolan Miller, Mark Sandler, Max Vladymyrov, Razvan Pascanu, et al. Uncovering mesa-optimization algorithms in transformers. arXiv preprint arXiv:2309.05858, 2023.

Boxin Wang, Wei Ping, Peng Xu, Lawrence McAfee, Zihan Liu, Mohammad Shoeybi, Yi Dong, Oleksii Kuchaiev, Bo Li, Chaowei Xiao, Anima Anandkumar, and Bryan Catanzaro. Shall we pretrain autoregressive language models with retrieval? a comprehensive study. In Houda Bouamor, Juan Pino, and Kalika Bali, editors, Proceedings of the 2023 Conference on Empirical Methods in Natural Language Processing, pages 7763-7786, Singapore, December 2023a. Association for Computational Linguistics. doi: 10.18653/v1/2023.emnlp-main.482. URL https://aclanthology.org/2023.emnlp-main.482.

Kevin Ro Wang, Alexandre Variengien, Arthur Conmy, Buck Shlegeris, and Jacob Steinhardt. Interpretability in the wild: a circuit for indirect object identification in GPT-2 small. In NeurIPS ML Safety Workshop, 2022. URL https://openreview.net/forum?id=rvi3Wa768B-.

Xinyi Wang, Wanrong Zhu, and William Yang Wang.
Large language models are implicitly topic models: Explaining and finding good demonstrations for in-context learning. arXiv preprint arXiv:2301.11916, 2023b.

Colin Wei, Yining Chen, and Tengyu Ma. Statistically meaningful approximation: a case study on approximating turing machines with transformers. CoRR, abs/2107.13163, 2021. URL https://arxiv.org/abs/2107.13163.

Gail Weiss, Yoav Goldberg, and Eran Yahav. Thinking like transformers. In International Conference on Machine Learning, pages 11080-11090. PMLR, 2021.

Janyce Wiebe, Theresa Wilson, and Claire Cardie. Annotating expressions of opinions and emotions in language. Language resources and evaluation, 39:165-210, 2005.

Noam Wies, Yoav Levine, and Amnon Shashua. The learnability of in-context learning. arXiv preprint arXiv:2303.07895, 2023.

Sang Michael Xie, Aditi Raghunathan, Percy Liang, and Tengyu Ma. An explanation of in-context learning as implicit bayesian inference. In International Conference on Learning Representations, 2022. URL https://openreview.net/forum?id=RdJVFCHjUMI.

Shunyu Yao, Binghui Peng, Christos Papadimitriou, and Karthik Narasimhan. Self-attention networks can process bounded hierarchical languages. arXiv preprint arXiv:2105.11115, 2021.

Biao Zhang and Rico Sennrich. Root mean square layer normalization. Advances in Neural Information Processing Systems, 32, 2019.

Renrui Zhang, Jiaming Han, Chris Liu, Peng Gao, Aojun Zhou, Xiangfei Hu, Shilin Yan, Pan Lu, Hongsheng Li, and Yu Qiao. Llama-adapter: Efficient fine-tuning of language models with zero-init attention. arXiv preprint arXiv:2303.16199, 2023a.

Ruiqi Zhang, Spencer Frei, and Peter L Bartlett. Trained transformers learn linear models in-context. arXiv preprint arXiv:2306.09927, 2023b.

Xiang Zhang, Junbo Zhao, and Yann LeCun. Character-level convolutional networks for text classification. Advances in neural information processing systems, 28, 2015.

Yufeng Zhang, Fengzhuo Zhang, Zhuoran Yang, and Zhaoran Wang.
What and how does in-context learning learn? bayesian model averaging, parameterization, and generalization. arXiv preprint arXiv:2305.19420, 2023c.

Hattie Zhou, Arwen Bradley, Etai Littwin, Noam Razin, Omid Saremi, Josh Susskind, Samy Bengio, and Preetum Nakkiran. What algorithms can transformers learn? a study in length generalization. arXiv preprint arXiv:2310.16028, 2023.

1 Introduction
2 Design Considerations
  2.1 Input structure
  2.2 Read and write access to auxiliary model
3 Efficient Forward Propagation
  3.1 Stacking and Sharding
  3.2 Efficient Aggregation
4 Simulated Gradient
  4.1 First-order approximations
  4.2 Fuzzy backpropagation via stop gradients
5 Experiments
  5.1 Setting: N-step Fine-Tuning
  5.2 Case Study: Language Modeling
  5.3 Case Study: In-Context Learning
6 Related Work
7 Discussion
A Additional related works
B Notations
  B.1 Simulating Multiplication from (Akyurek et al., 2022)
C Linear layer
  C.1 Hsim-split operation
D Self-attention layer
  D.1 Proofs of theorems and gradient definitions
E Layer normalization
  E.1 Additional definitions
  E.2 Proof of theorems and gradient definitions
F Activation layer
  F.1 Proofs of theorems
G Language model head
H Parameter sharing
I Additional modules
  I.1 Root mean square normalization (RMSnorm)
  I.2 Attention variants
  I.3 Gated linear units (GLUs)
J Construction of other variants of pre-trained models
K Experiments

Brief overview of the appendix

In Appendix A, we report a few additional related works. In Appendix B, we present all the important notations used to present the design of TINT. In Appendices C to F, we present the simulation details of all operations on linear, self-attention, layer normalization, and activation layers respectively for an auxiliary model. In Appendix G, we present the details for simulating loss computation with the language model head of the auxiliary model. In Appendix I, we discuss simulation of additional modules necessary to simulate transformer variants like LLaMA (Touvron et al., 2023) and BLOOM (Scao et al., 2022). Finally, in Appendix K, we discuss the deferred experimental details from the main paper.

A. Additional related works

Interpretability: Mechanistic interpretability works reverse-engineer the algorithms simulated by these models (Elhage et al., 2021; Olsson et al., 2022; Wang et al., 2022; Nanda et al., 2023; Chughtai et al., 2023; Conmy et al., 2023). These works study local patterns, e.g., activations and attention heads, to derive interpretable insights. Other works (Weiss et al., 2021; Lindner et al., 2023) use declarative programs to algorithmically describe transformer models. Zhou et al. (2023) use these to explain task-specific length generalization of transformer models.

Alternative Explanations for ICL: Some works study ICL using a Bayesian framework. Xie et al. (2022) model pre-training data as a mixture of HMMs and cast ICL as identifying one such component. Hahn and Goyal (2023) later modeled language as a compositional grammar and proposed ICL as a composition of operations. Zhang et al. (2023c), Jiang (2023), Wang et al. (2023b), and Wies et al. (2023) further strengthen this hypothesis by generalizing the underlying latent space. On the other hand, careful experiments in Chan et al. (2022) show that data distributional properties (e.g., Zipf's law) drive in-context learning in transformers.

Transfer learning: Our construction uses a pre-trained model to initialize a larger transformer, which is similar to several other more empirically oriented works (Gong et al., 2019; Reddi et al., 2023).

B. Notations

For simplicity, we repeat notations from the main paper. Let $D$ denote the embedding dimension for a token and $T$ denote the length of an input sequence. $H$ denotes the number of attention heads. With the exception of contextual embeddings, we use subscripts to indicate whether the quantity is from TINT or from the auxiliary model. For example, $D_{aux}$ refers to the embedding dimension of the auxiliary model and $D_{sim}$ refers to the TINT embedding dimension.
For contextual embeddings, we use e(ℓ) t RDsim to denote activations in TINT and x(ℓ) t RDaux to denote activations in the auxiliary model, where ℓis the layer and t is the sequence position. When convenient, we drop the superscript that represents the layer index and the subscript that represents the position index. For a matrix A, aj refers to its jth row, and for any vector b, bj refers to its jth element. However, at a few places, for typographical reasons, for a matrix A, we have also used (A)j to refer to its jth row, and for any vector b, (b)j to refer to its jth element. TINT uses one-hot positional embeddings {p TINT i RTsim}i Tsim. Trainable Transformer in Transformer Table 3: Number of parameters of TINT for the forward, backward, and gradient update operations on various modules. For simplicity, we have ignored biases in the following computation. We set S = 4, i.e. stack 4 weights in each prefix embedding. We set Hsim = 12 for OPT-125M and Hsim = 16 for the other models, Dsim = 4Daux for all the models, and Tsim = Taux + K, with Taux = 2048 for OPT models, and K = Daux/4. Q = 4Qsplit + 3Tsim Dsim/Hsim, where Qsplit = 1 Hsim (Dsim)2 + Hsim Dsim, denotes the number of parameters in a TINT Linear Forward module (Section 3). 
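| Module Name | Forward | Backward | Descent | Total |
|---|---|---|---|---|
| Linear layer | $Q$ | $Q$ | $Q$ | $3Q$ |
| Layer norms | $Q$ | $Q + 2 D_{\text{sim}} H_{\text{sim}}$ | $Q$ | $3Q + 2 D_{\text{sim}} H_{\text{sim}}$ |
| Self-Attention | $2Q$ | $2Q$ | $2Q$ | $6Q$ |
| Activation | $Q_{\text{split}}$ | $2 D_{\text{sim}} H_{\text{sim}}$ | $0$ | $Q_{\text{split}} + 2 D_{\text{sim}} H_{\text{sim}}$ |
| Self-Attention block | $4Q$ | $4Q + 2 D_{\text{sim}} H_{\text{sim}}$ | $4Q$ | $12Q + 2 D_{\text{sim}} H_{\text{sim}}$ |
| Feed-forward block | $3Q + Q_{\text{split}}$ | $3Q + 4 D_{\text{sim}} H_{\text{sim}}$ | $3Q$ | $9Q + 4 D_{\text{sim}} H_{\text{sim}}$ |
| Transformer block | $7Q + Q_{\text{split}}$ | $7Q + 6 D_{\text{sim}} H_{\text{sim}}$ | $7Q$ | $21Q + 6 D_{\text{sim}} H_{\text{sim}} + Q_{\text{split}}$ |
| Transformer | $7QL + L Q_{\text{split}}$ | $(7Q + 6 D_{\text{sim}} H_{\text{sim}}) L$ | $7QL$ | $(21Q + 6 D_{\text{sim}} H_{\text{sim}} + Q_{\text{split}}) L$ |
| OPT-125M | 0.4B | 0.4B | 0.4B | 1.2B |
| OPT-350M | 1.2B | 1.1B | 1.1B | 3.4B |
| OPT-1.3B | 3.7B | 3.6B | 3.5B | 10.8B |
| OPT-2.7B | 7.4B | 7.2B | 7.2B | 21.8B |

The Forward, Backward, Descent, and Total columns report module sizes.

We differentiate the parameters of the auxiliary model and TINT by using an explicit superscript TINT for TINT parameters; for example, the weights of a linear layer in TINT are represented by $W^{\text{TINT}}$.

We use two operations throughout: $\mathrm{SPLIT}_h$ and $\mathrm{VECTORIZE}$. The function $\mathrm{SPLIT}_h : \mathbb{R}^d \to \mathbb{R}^{h \times d/h}$ takes an input $x \in \mathbb{R}^d$ and outputs $h$ equal splits of $x$, for any arbitrary dimension $d$. The function $\mathrm{VECTORIZE} : \mathbb{R}^{h \times d} \to \mathbb{R}^{dh}$ concatenates the elements of a sequence $\{x_i \in \mathbb{R}^d\}_{i \le h}$ into one single vector, for any arbitrary $d$ and $h$.

Auxiliary model's self-attention. We first start with the definition of a single-head self-attention layer for the auxiliary model. The definition can be easily extended to a multi-head self-attention layer. A self-attention layer first computes the query, key, and value vectors at each position by token-wise linear transformations of the input embeddings. The query and the key vectors are used to compute pairwise self-attention scores. These scores are then used to linearly combine the value vectors.

Definition B.1 (Auxiliary model softmax self-attention). A self-attention layer with parameters $\{W_Q, W_K, W_V\}$ takes a sequence $\{x_t\}_{t \le T_{\text{aux}}}$ and outputs a sequence $\{y_t\}_{t \le T_{\text{aux}}}$, such that

$$y_t = \sum_j a_{t,j} v_j, \quad \text{with } a_{t,j} = \mathrm{softmax}(K q_t)_j, \quad q_t = W_Q x_t, \quad k_t = W_K x_t, \quad v_t = W_V x_t,$$

for all $t \le T_{\text{aux}}$, and $K \in \mathbb{R}^{T_{\text{aux}} \times D_{\text{aux}}}$ defined with rows $\{k_t\}_{t=1}^{T_{\text{aux}}}$.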
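The two helper operations and Definition B.1 can be illustrated with a minimal sketch. The function names below (`split`, `vectorize`, `single_head_attention`) are hypothetical and chosen for illustration only; the sketch assumes vectors are plain Python lists, ignores biases, and applies no masking or score scaling, exactly as the definition is written above.

```python
import math

def split(x, h):
    """SPLIT_h: cut a length-d vector into h equal contiguous chunks."""
    assert len(x) % h == 0, "d must be divisible by h"
    size = len(x) // h
    return [x[i * size:(i + 1) * size] for i in range(h)]

def vectorize(chunks):
    """VECTORIZE: concatenate a sequence of vectors into a single vector."""
    return [value for chunk in chunks for value in chunk]

def matvec(W, x):
    """Matrix-vector product for a matrix stored as a list of rows."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def softmax(scores):
    """Numerically stable softmax over a list of scores."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def single_head_attention(xs, W_Q, W_K, W_V):
    """Definition B.1: y_t = sum_j a_{t,j} v_j with a_{t,.} = softmax(K q_t)."""
    qs = [matvec(W_Q, x) for x in xs]
    ks = [matvec(W_K, x) for x in xs]
    vs = [matvec(W_V, x) for x in xs]
    ys = []
    for q in qs:
        a = softmax(matvec(ks, q))  # the rows of K are the key vectors
        ys.append([sum(a_j * v[c] for a_j, v in zip(a, vs))
                   for c in range(len(vs[0]))])
    return ys
```

Note that `vectorize(split(x, h))` returns `x` unchanged, which is the property the multi-head definitions later in this appendix rely on.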
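TINT Attention Module. We modify the usual attention module to include the position embeddings $\{p^{\text{TINT}}_i \in \mathbb{R}^{T_{\text{sim}}}\}_{i \le T_{\text{sim}}}$. In usual self-attention modules, the query, key, and value vectors at each position are computed by token-wise linear transformations of the input embeddings. In TINT's Attention Module, we perform additional linear transformations on the position embeddings, using parameters $W^p_Q, W^p_K, W^p_V$, and decision vectors $\lambda^Q, \lambda^K, \lambda^V \in \mathbb{R}^{H_{\text{sim}}}$ decide whether to add these transformed position vectors to the query, key, and value vectors of different attention heads.

For the following definition, we use $\hat{e}$ to represent the input sequence and $\tilde{e}$ to represent the output sequence. We introduce these general notations to avoid confusion with the notations for token and prefix embeddings for TINT illustrated in Figure 1.

Definition B.2 (TINT's self-attention with $H_{\text{sim}}$ heads). For parameters $\{W^{\text{TINT}}_Q, W^{\text{TINT}}_K, W^{\text{TINT}}_V \in \mathbb{R}^{D_{\text{sim}} \times D_{\text{sim}}}\}$, $\{b^{\text{TINT}}_Q, b^{\text{TINT}}_K, b^{\text{TINT}}_V \in \mathbb{R}^{D_{\text{sim}}}\}$, $\{W^p_Q, W^p_K, W^p_V \in \mathbb{R}^{T_{\text{sim}} \times D_{\text{sim}}/H_{\text{sim}}}\}$ and $\{\lambda^Q, \lambda^K, \lambda^V \in \mathbb{R}^{H_{\text{sim}}}\}$, TINT self-attention with $H_{\text{sim}}$ attention heads and a function $f_{\text{attn}} : \mathbb{R}^{T_{\text{sim}}} \to \mathbb{R}^{T_{\text{sim}}}$ takes a sequence $\{\hat{e}_t \in \mathbb{R}^{D_{\text{sim}}}\}_{t \le T_{\text{sim}}}$ as input and outputs $\{\tilde{e}_t \in \mathbb{R}^{D_{\text{sim}}}\}_{t \le T_{\text{sim}}}$, with

$$\tilde{e}_t = \mathrm{VECTORIZE}\Big(\Big\{\sum_{j \le T_{\text{sim}}} a^h_{t,j} \tilde{v}^h_j\Big\}_{h \le H_{\text{sim}}}\Big), \quad \text{with } a^h_{t,j} = f_{\text{attn}}(\tilde{K}^h \tilde{q}^h_t)_j,$$

$$\tilde{q}^h_t = \mathrm{SPLIT}_{H_{\text{sim}}}(q_t)_h + \lambda^Q_h W^p_Q p^{\text{TINT}}_t; \quad \tilde{k}^h_t = \mathrm{SPLIT}_{H_{\text{sim}}}(k_t)_h + \lambda^K_h W^p_K p^{\text{TINT}}_t; \quad \tilde{v}^h_t = \mathrm{SPLIT}_{H_{\text{sim}}}(v_t)_h + \lambda^V_h W^p_V p^{\text{TINT}}_t.$$

Here, $q_t$, $k_t$, $v_t$ denote the query, key, and value vectors at each position $t$, computed as $W^{\text{TINT}}_Q \hat{e}_t + b^{\text{TINT}}_Q$, $W^{\text{TINT}}_K \hat{e}_t + b^{\text{TINT}}_K$, and $W^{\text{TINT}}_V \hat{e}_t + b^{\text{TINT}}_V$ respectively. $\tilde{K}^h \in \mathbb{R}^{T_{\text{sim}} \times D_{\text{sim}}/H_{\text{sim}}}$ is defined with its rows as $\{\tilde{k}^h_t\}_{t \le T_{\text{sim}}}$ for all $h \le H_{\text{sim}}$. $f_{\text{attn}}$ can be either the linear or the softmax function.

Bounded parameters and input sequence: We define a linear self-attention layer to be $B_w$-bounded if the $\ell_2$ norms of all the parameters are bounded by $B_w$.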
Going by Definition B.2, this implies

$$\max\{\|W^{\text{TINT}}_Q\|_2, \|W^{\text{TINT}}_K\|_2, \|W^{\text{TINT}}_V\|_2\} \le B_w, \quad \max\{\|b^{\text{TINT}}_Q\|_2, \|b^{\text{TINT}}_K\|_2, \|b^{\text{TINT}}_V\|_2\} \le B_w,$$

$$\max\{\|W^p_Q\|_2, \|W^p_K\|_2, \|W^p_V\|_2\} \le B_w, \quad \max\{\|\lambda^Q\|_2, \|\lambda^K\|_2, \|\lambda^V\|_2\} \le B_w.$$

Furthermore, we define an input sequence $\{\hat{e}_t\}_{t \le T_{\text{sim}}}$ to be $B_x$-bounded if $\|\hat{e}_t\|_2 \le B_x$ for all $t$.

Recall from the main paper (Section 3) that we used the Linear TINT Self-Attention layer to represent the linear operations of the auxiliary model. In the following theorem, we show that a linear attention layer can be represented as a softmax attention layer that uses an additional attention head and an extra token $u$, followed by a linear layer. Therefore, replacing softmax attention with linear attention does not deviate too far from the canonical transformer. We use the Linear TINT Self-Attention layers in several places throughout the model.

Theorem B.3. For any $B_w > 0$, consider a $B_w$-bounded linear self-attention layer that returns $\{e^{\text{linear}}_t \in \mathbb{R}^{D_{\text{sim}}}\}_{t \le T_{\text{sim}}}$ on any input $\{\hat{e}_t \in \mathbb{R}^{D_{\text{sim}}}\}_{t \le T_{\text{sim}}}$. Consider a softmax self-attention layer with $2 H_{\text{sim}}$ attention heads and an additional token $u \in \mathbb{R}^{2 D_{\text{sim}}}$ such that for any $B_x$-bounded input $\{\hat{e}_t\}_{t \le T_{\text{sim}}}$, it takes a modified input sequence $\{\bar{e}_1, \cdots, \bar{e}_{T_{\text{sim}}}, u\}$ and returns $\{e^{\text{softmax}}_t \in \mathbb{R}^{2 D_{\text{sim}}}\}_{t \le T_{\text{sim}}}$. Each modified input token $\bar{e}_t \in \mathbb{R}^{2 D_{\text{sim}}}$ is obtained by concatenating additional 0s to $\hat{e}_t$. Then, for any $B_x > 0$ and $\epsilon \le O(T_{\text{sim}}^{-2} B_w^{-5} B_x^{-5})$, there exists $W_O \in \mathbb{R}^{D_{\text{sim}} \times 2 D_{\text{sim}}}$ and such a softmax self-attention layer such that

$$\|W_O e^{\text{softmax}}_t - e^{\text{linear}}_t\|_2 \le O(\sqrt{\epsilon}), \quad \text{for all } t \le T_{\text{sim}}.$$

Proof. Consider an input sequence $\{x_t\}_{t \le T_{\text{sim}}}$. Let the attention scores of any linear head $h \le H_{\text{sim}}$ in the linear attention layer be given by $\{a^h_{t,j}\}_{j \le T_{\text{sim}}}$ at any given position $t$. Additionally, let the value vectors for the linear attention be given by $v_t$. To repeat our self-attention definition, the output of the attention layer at any position $t$ is given by $\mathrm{VECTORIZE}(\{e^{\text{linear},h}_t\}_{h \le H_{\text{sim}}})$, where

$$e^{\text{linear},h}_t = \sum_{j \le T_{\text{sim}}} a^h_{t,j} v^h_j.$$
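Under our assumption, $B_w$ denotes the maximum $\ell_2$ norm of all the parameters in the linear self-attention layer and $B_x$ the maximum $\ell_2$ norm in the input sequence, i.e. $\max_{t \le T_{\text{sim}}} \|x_t\|_2 \le B_x$. With a simple application of the Cauchy-Schwarz inequality, we can show that $\max_{j \le T_{\text{sim}}} |a^h_{t,j}| \le O(B_w^2 B_x^2)$ and $\max_{t \le T_{\text{sim}}} \|v^h_t\|_2 \le O(B_w B_x)$.

For $\epsilon \le O(T_{\text{sim}}^{-10/9} B_w^{-40/9} B_x^{-40/9})$, we can then use Lemma B.4 to represent, for each $t, j \le T_{\text{sim}}$,

$$a^h_{t,j} = \frac{\epsilon^{-3} e^{\epsilon a^h_{t,j}}}{\sum_{t' \le T_{\text{sim}}} e^{\epsilon a^h_{t,t'}} + e^{-2 \log \epsilon}} - \epsilon^{-1} + O\big(\epsilon (T_{\text{sim}} + a^h_{t,j})\big) := \epsilon^{-3} \, \mathrm{softmax}\big(\{\epsilon a^h_{t,1}, \epsilon a^h_{t,2}, \cdots, \epsilon a^h_{t,T_{\text{sim}}}, -2 \log \epsilon\}\big)_j - \epsilon^{-1} + O(\epsilon^{0.9}).$$

Softmax attention construction: We define $u$ and the query and key parameters of the softmax attention layer such that, for the first $H_{\text{sim}}$ attention heads, the query-key dot products between any pair $\{(\bar{e}_t, \bar{e}_j)\}_{t,j \le T_{\text{sim}}}$ are given by $\{\epsilon a^h_{t,j}\}_{h \le H_{\text{sim}}}$, while being $-2 \log \epsilon$ between $u$ and any token $\bar{e}_t$ with $t \le T_{\text{sim}}$. For the remaining $H_{\text{sim}}$ attention heads, the attention scores are uniformly distributed across all pairs of tokens (the attention score between any pair of tokens is $\frac{1}{T_{\text{sim}} + 1}$). We set the value parameters of the softmax attention layer such that at any position $t \le T_{\text{sim}}$, the value vector is given by $\mathrm{VECTORIZE}(\{\epsilon^{-3} v_t, v_t\})$. The value vector returned for $u$ contains all 0s.

Softmax attention computation: Consider an attention head $h \le H_{\text{sim}}$ in the softmax attention layer. The output of the attention head at any position $t \le T_{\text{sim}}$ is given by

$$e^{\text{softmax},h}_t = \sum_{j \le T_{\text{sim}}} \epsilon^{-3} \, \mathrm{softmax}\big(\{\epsilon a^h_{t,1}, \epsilon a^h_{t,2}, \cdots, \epsilon a^h_{t,T_{\text{sim}}}, -2 \log \epsilon\}\big)_j \, v^h_j = \sum_{j \le T_{\text{sim}}} \big(a^h_{t,j} + \epsilon^{-1} + O(\epsilon^{0.9})\big) v^h_j.$$

This has an additional $\sum_{j \le T_{\text{sim}}} \big(\epsilon^{-1} + O(\epsilon^{0.9})\big) v^h_j$ compared to $e^{\text{linear},h}_t$. However, consider the output of the attention head $H_{\text{sim}} + h$ at the same position:

$$e^{\text{softmax},H_{\text{sim}}+h}_t = \frac{1}{T_{\text{sim}} + 1} \sum_{j \le T_{\text{sim}}} v^h_j.$$

Hence, we can use the output matrix $W_O$ to get

$$e^{\text{softmax},h}_t - \frac{T_{\text{sim}} + 1}{\epsilon} e^{\text{softmax},H_{\text{sim}}+h}_t = \sum_{j \le T_{\text{sim}}} \big(a^h_{t,j} + O(\epsilon^{0.9})\big) v^h_j.$$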
The additional term $O(\epsilon^{0.9}) \sum_{j \le T_{\text{sim}}} v^h_j$ can be further shown to be $O(\epsilon^{0.5})$ with the assumed bound on $\epsilon$, since each $v^h_j$ is at most $O(B_w B_x)$ in $\ell_2$ norm by the Cauchy-Schwarz inequality.

Lemma B.4. For $\epsilon > 0$, $B > 0$, and a sequence $\{a_1, a_2, \cdots, a_T\}$ with each $a_i \in \mathbb{R}$ and $|a_i| \le B$, the following holds true for all $i \le T$:

$$\frac{\epsilon^{-3} e^{\epsilon a_i}}{\sum_{t \le T} e^{\epsilon a_t} + e^{-2 \log \epsilon}} = a_i + \frac{1}{\epsilon} + O(\epsilon^{0.9}),$$

provided $\epsilon \le O(T^{-10/9} B^{-20/9})$.

Proof. We will use the following first-order Taylor expansions:

\begin{align}
e^x &= 1 + x + O(x^2), \tag{4} \\
\frac{1}{1 + x} &= 1 - O(x). \tag{5}
\end{align}

Hence, for any $x \ll 1$, $x \approx e^x - 1$. Simplifying the L.H.S. of the desired bound, we have

\begin{align}
\frac{\epsilon^{-3} e^{\epsilon a_i}}{\sum_{t \le T} e^{\epsilon a_t} + e^{-2 \log \epsilon}}
&= \frac{\epsilon^{-3} \big(1 + \epsilon a_i + O(\epsilon^2 a_i^2)\big)}{\sum_{t \le T} \big(1 + \epsilon a_t + O(\epsilon^2 a_t^2)\big) + e^{-2 \log \epsilon}} \tag{6} \\
&= \frac{\epsilon^{-1} + a_i + O(\epsilon a_i^2)}{\sum_{t \le T} \big(\epsilon^2 + \epsilon^3 a_t + O(\epsilon^4 a_t^2)\big) + 1} \tag{7} \\
&= \big(\epsilon^{-1} + a_i + O(\epsilon a_i^2)\big) \big(1 + O(\epsilon^2 T)\big) \tag{8} \\
&= \epsilon^{-1} + a_i + O(\epsilon T + a_i^2 T \epsilon^2 + a_i^2 T \epsilon^3 + \epsilon a_i^2) \nonumber \\
&= \epsilon^{-1} + a_i + O(\epsilon^{0.9}). \nonumber
\end{align}

We used the Taylor expansion of the exponential function (Equation (4)) in Equation (6) to get Equation (7), and the Taylor expansion of the inverse function (Equation (5)) to get Equation (8) from Equation (7). Furthermore, with the upper bound assumption on $\epsilon$, $\sum_{t \le T} \big(\epsilon^2 + \epsilon^3 a_t + O(\epsilon^4 a_t^2)\big)$ can be shown to be at most $3 \epsilon^2 T$, which amounts to an $O(\epsilon^2 T)$ error in Equation (8). The final error bound has again been simplified using the upper bound assumption on $\epsilon$.

B.1. Simulating Multiplication from (Akyurek et al., 2022)

We refer to the multiplication strategy of Akyurek et al. (2022) at various places.

Lemma B.5. [Lemma 4 in (Akyurek et al., 2022)] The GeLU (Hendrycks and Gimpel, 2016) nonlinearity can be used to perform multiplication: specifically,

$$\sqrt{\pi/2} \, \big(\mathrm{GeLU}(x + y) - \mathrm{GeLU}(x) - \mathrm{GeLU}(y)\big) = xy + O(x^3 + y^3).$$

Thus, to represent an element-wise product or a dot product between two sub-vectors in a token embedding, we can use an MLP with a GeLU activation.

C. Linear layer

In the main paper, we defined the linear layer without the bias term for simplicity (Definition 3.1).
In this section, we redefine the linear layer with the bias term and present a comprehensive construction of the Linear Forward module.

Definition C.1 (Linear layer). For a weight $W \in \mathbb{R}^{D_{\text{aux}} \times D_{\text{aux}}}$ and bias $b \in \mathbb{R}^{D_{\text{aux}}}$, a linear layer takes $x \in \mathbb{R}^{D_{\text{aux}}}$ as input and outputs $y = W x + b$.

In the discussions below, we consider a linear layer in the auxiliary model with parameters $\{W, b\}$ that takes in an input sequence $x_1, \cdots, x_{T_{\text{aux}}}$ and outputs $y_1, \cdots, y_{T_{\text{aux}}}$, with $y_t = W x_t + b$ for each $t \le T_{\text{aux}}$. Since this involves a token-wise operation, we present our constructed modules with a general token position $t$ and the prefix tokens $\{v_j\}$.

TINT Linear Forward module. Continuing our discussion from Section 3, we represent $S$ stacked rows of $W$ as a prefix embedding. In addition, we store the bias $b$ in the first prefix embedding ($v_1$). Using a set of $S$ unique attention heads in a TINT attention module (Definition B.2), we copy the bias $b$ to the respective token embeddings and use a TINT linear layer to add the biases to the final output.

Auxiliary's backpropagation through linear layer. For a linear layer as defined in Definition C.1, the linear backpropagation layer takes in the loss gradient w.r.t. the output ($\partial_y$) and computes the loss gradient w.r.t. the input ($\partial_x$).

Definition C.2 (Linear backpropagation). For a weight $W \in \mathbb{R}^{D_{\text{aux}} \times D_{\text{aux}}}$, the linear backpropagation layer takes $\partial_y \in \mathbb{R}^{D_{\text{aux}}}$ as input and outputs $\partial_x = W^\top \partial_y$.

TINT Linear backpropagation module. This module aims to simulate the auxiliary's linear backpropagation. The input embedding $e_t$ to this module contains the gradient of the loss w.r.t. $y_t$, i.e. $\partial_{y_t}$. As given in Definition C.2, this module outputs the gradient of the loss w.r.t. $x_t$, given by $\partial_{x_t} = W^\top \partial_{y_t}$. We first use the residual connection to copy the prefix embeddings $\{v_j\}$ (i.e., the rows of $W$) from the forward propagation module.
A straightforward construction would be to use the Linear Forward module but with the columns of $W$ stored in the prefix tokens, thereby simulating multiplication with $W^\top$. However, such a construction requires applying attention to the prefix tokens, which increases the size of the construction substantially. We instead perform the operation more efficiently by splitting it across attention heads. In particular, once we view the operation as $\partial_{x_t} = \sum_i (\partial_{y_t})_i w_i$, we can see that the attention score between the current token and the prefix token containing $w_i$ must be $(\partial_{y_t})_i$. Using the rows of $W$ as value vectors returns the desired output. Similar to the Linear Forward module, we shard the weights into $S$ parts to parallelize across more attention heads. Please see Figure 4.

Auxiliary's linear descent update. Finally, the linear descent layer updates the weight and the bias parameters using a batch of inputs $\{x_t\}_{t \le T_{\text{aux}}}$ and the loss gradient w.r.t. the corresponding outputs $\{\partial_{y_t}\}_{t \le T_{\text{aux}}}$.

Definition C.3 (Linear descent). For a weight $W \in \mathbb{R}^{D_{\text{aux}} \times D_{\text{aux}}}$ and a bias $b \in \mathbb{R}^{D_{\text{aux}}}$, the linear descent layer takes in a batch of inputs $\{x_t \in \mathbb{R}^{D_{\text{aux}}}\}_{t \le T_{\text{aux}}}$ and gradients $\{\partial_{y_t} \in \mathbb{R}^{D_{\text{aux}}}\}_{t \le T_{\text{aux}}}$ and updates the parameters as follows:

$$W \leftarrow W - \eta \sum_{t \le T_{\text{aux}}} \partial_{y_t} x_t^\top; \qquad b \leftarrow b - \eta \sum_{t \le T_{\text{aux}}} \partial_{y_t}.$$

Figure 4: TINT simulates the backward pass of a linear layer as an $H$-head attention layer ($H = 6$ pictured), with the gradient of the loss w.r.t. the linear layer output ($\partial_{y_t}$) as the query, the positional one-hot vector of prefix embeddings as the key, and the parameters of the auxiliary model stored in the prefix embeddings as the value. Similar to the Linear Forward module (Figure 2), we distribute the dot product computations across all attention heads by sharding the vectors into $S$ ($S = 3$ here) parts.
We omitted the identical transformation for query, and value matrices, and permutation-based transformation for key matrix for illustration purposes.

TINT Linear descent module. The input embedding $e_t$ to this module contains the gradient of the loss w.r.t. $y_t$, i.e. $\partial_{y_t}$. As in the Linear backpropagation module, the prefix tokens $\{v_j\}$ contain the rows of $W$ and $b$, which have been copied from the Linear Forward module using residual connections. Since, in addition to the gradients, we also require the input to the linear layer, we use residual connections to copy the inputs $\{x_t\}$ to their respective embeddings $\{e_t\}$ from the Linear Forward module. As given in Definition C.3, this module updates $W$ and $b$ using the gradient descent rule.

Focusing on $w_i$, the descent update is given by $w_i \leftarrow w_i - \eta \sum_t (\partial_{y_t})_i x_t$. For the prefix token $v_j$ that contains $w_i$, the update term $-\eta \sum_t (\partial_{y_t})_i x_t$ can be expressed with an attention head that represents the attention between the prefix token $v_j$ and any token $e_t$ with score $(\partial_{y_t})_i$ and value $-\eta x_t$. The residual connection can then be used to update the weights $w_i$ in $v_j$.

For the bias $b$, the descent update is given by $b \leftarrow b - \eta \sum_t \partial_{y_t}$. With $b$ present in $v_1$, we use one attention head to represent the attention score between the prefix token $v_1$ and any token $e_t$ as 1, with the value being $-\eta \partial_{y_t}$. The residual connection can then be used to update the bias $b$ in $v_1$.

The above process can be further parallelized across multiple attention heads by sharding each weight computation into $S$ parts. Please see Figure 5.

C.1. $H_{\text{sim}}$-split operation

We leverage local structure within the linear operations of TINT to make the construction smaller. We build two $H_{\text{sim}}$-split operations to replace all the linear operations. We use $d_{\text{sim}}$ to denote $D_{\text{sim}}/H_{\text{sim}}$ in the following definitions.

Definition C.4 (Split-wise $H_{\text{sim}}$-split Linear operation).
For weight and bias parameters $W^{\text{TINT}} \in \mathbb{R}^{H_{\text{sim}} \times d_{\text{sim}} \times d_{\text{sim}}}$, $B^{\text{TINT}} \in \mathbb{R}^{H_{\text{sim}} \times d_{\text{sim}}}$, this layer takes in input $e \in \mathbb{R}^{D_{\text{sim}}}$ and returns $\tilde{e} = \mathrm{VECTORIZE}(\tilde{S} + B^{\text{TINT}})$, with $\tilde{S} \in \mathbb{R}^{H_{\text{sim}} \times d_{\text{sim}}}$ defined with rows $\{W^{\text{TINT}}_h \, \mathrm{SPLIT}_{H_{\text{sim}}}(e)_h\}_{h \le H_{\text{sim}}}$.

Definition C.5 (Dimension-wise $H_{\text{sim}}$-split Linear operation). For weight and bias parameters $W^{\text{TINT}} \in \mathbb{R}^{d_{\text{sim}} \times H_{\text{sim}} \times H_{\text{sim}}}$, $B^{\text{TINT}} \in \mathbb{R}^{d_{\text{sim}} \times H_{\text{sim}}}$, this layer takes in input $e \in \mathbb{R}^{D_{\text{sim}}}$, defines $S \in \mathbb{R}^{d_{\text{sim}} \times H_{\text{sim}}}$ with columns $\{\mathrm{SPLIT}_{H_{\text{sim}}}(e)_h\}_{h \le H_{\text{sim}}}$, and returns $\tilde{e} = \mathrm{VECTORIZE}((\tilde{S} + B^{\text{TINT}})^\top)$, where $\tilde{S} \in \mathbb{R}^{d_{\text{sim}} \times H_{\text{sim}}}$ is defined with rows $\{W^{\text{TINT}}_d s_d\}_{d \le d_{\text{sim}}}$ and $s_d$ denotes the $d$th row of $S$.

We find that we can replace all the linear operations with a split-wise $H_{\text{sim}}$-split Linear operation followed by a dimension-wise $H_{\text{sim}}$-split Linear operation, and an additional split-wise $H_{\text{sim}}$-split Linear operation, if necessary. A linear operation on a $D_{\text{sim}}$-dimensional space involves $D_{\text{sim}}^2$ parameters, while its replacement requires $D_{\text{sim}}^2 / H_{\text{sim}} + 2 D_{\text{sim}} H_{\text{sim}}$ parameters, reducing the total number of necessary parameters by a factor of roughly $H_{\text{sim}}$.

Figure 5: TINT computes the parameter gradients for a linear layer as an $H$-head attention layer ($H = 6$ pictured), with the gradient of the loss w.r.t. the linear layer output ($\partial_{y_t}$) as the query, the positional one-hot vector of prefix embeddings as the key, and the input to the linear layer ($x_t$) as the value. The auxiliary model parameters in the prefix embeddings are then updated using a residual connection. Similar to the Linear Forward module (Figure 2), we distribute the dot product computations across all attention heads by sharding the vectors into $S$ ($S = 3$ here) parts. We omitted the identical transformation for query, and value matrices, and permutation-based transformation for key matrix for simplicity.

We motivate the $H_{\text{sim}}$-split linear operations with an example.
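As a small companion to Definitions C.4 and C.5, the following sketch implements both split operations on plain Python lists and evaluates the parameter count quoted above. The function names are hypothetical, the weight layouts (`W[h][row][col]` for the split-wise operation and `W[d][row][col]` for the dimension-wise one) are an assumption about how the tensors would be indexed, and the example sizes $D_{\text{sim}} = 3072$, $H_{\text{sim}} = 12$ follow from the Table 3 settings for OPT-125M ($D_{\text{sim}} = 4 D_{\text{aux}}$ with $D_{\text{aux}} = 768$).

```python
def split(x, h):
    """SPLIT_h: cut a length-d vector into h equal contiguous chunks."""
    assert len(x) % h == 0, "d must be divisible by h"
    size = len(x) // h
    return [x[i * size:(i + 1) * size] for i in range(h)]

def split_wise(e, W, B):
    """Definition C.4: apply a separate d x d map W[h] (plus bias B[h]) to each chunk."""
    chunks = split(e, len(W))
    out = []
    for W_h, b_h, chunk in zip(W, B, chunks):
        out.extend(sum(w * c for w, c in zip(row, chunk)) + b
                   for row, b in zip(W_h, b_h))
    return out

def dimension_wise(e, W, B):
    """Definition C.5: for each coordinate d, mix the H chunk entries with an H x H map W[d]."""
    d_sim, H = len(W), len(W[0])
    chunks = split(e, H)  # S[d][h] = chunks[h][d]
    S_tilde = [[sum(W[d][r][h] * chunks[h][d] for h in range(H)) + B[d][r]
                for r in range(H)]
               for d in range(d_sim)]
    # VECTORIZE of the transpose: concatenate the H rows of length d_sim
    return [S_tilde[d][h] for h in range(H) for d in range(d_sim)]

def dense_params(D):
    """Parameters of an unrestricted linear map on a D-dimensional space (no bias)."""
    return D * D

def split_params(D, H):
    """Parameter count quoted in the text for the replacement: D^2 / H + 2 D H."""
    return D * D // H + 2 * D * H
```

With these sizes, `dense_params(3072)` is 9,437,184 while `split_params(3072, 12)` is 860,160, a reduction of roughly 11x, which is close to the factor $H_{\text{sim}} = 12$ stated above.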
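We consider the Linear Forward module in Figure 2 for simulating a linear operation with parameters $W \in \mathbb{R}^{D_{\text{aux}} \times D_{\text{aux}}}$ and no biases. For simplicity of presentation, we assume $D_{\text{aux}}$ is divisible by 4. We stack 2 rows of weights per prefix embedding. We distribute the dot-product computation across the $H_{\text{sim}} = 6$ attention heads by sharding each weight into 3 parts. Since we need enough space to store all the sharded computation from the linear attention heads, we require $D_{\text{sim}} = 3 D_{\text{aux}}$ (we get 3 values for each of the $D_{\text{aux}}$ weights in $W$). For presentation, for a given vector $v \in \mathbb{R}^{D_{\text{aux}}}$, we represent $\mathrm{SPLIT}_3(v)_i$ by $v^i$ for all $1 \le i \le 3$.

Now, consider the final linear operation responsible for combining the output of the attention heads. The output, after the linear operation, should contain $W x_t$ in the first $D_{\text{aux}}$ coordinates. At any position $t$, if we stack the output of the linear attention heads as rows of a matrix $S_t \in \mathbb{R}^{H_{\text{sim}} \times D_{\text{sim}}/H_{\text{sim}}}$, we get

$$S_t = \begin{bmatrix}
\langle w^1_1, x^1_t \rangle & \langle w^1_3, x^1_t \rangle & \langle w^1_5, x^1_t \rangle & \cdots & \langle w^1_{D_{\text{aux}}-1}, x^1_t \rangle \\
\langle w^2_1, x^2_t \rangle & \langle w^2_3, x^2_t \rangle & \langle w^2_5, x^2_t \rangle & \cdots & \langle w^2_{D_{\text{aux}}-1}, x^2_t \rangle \\
\langle w^3_1, x^3_t \rangle & \langle w^3_3, x^3_t \rangle & \langle w^3_5, x^3_t \rangle & \cdots & \langle w^3_{D_{\text{aux}}-1}, x^3_t \rangle \\
\langle w^1_2, x^1_t \rangle & \langle w^1_4, x^1_t \rangle & \langle w^1_6, x^1_t \rangle & \cdots & \langle w^1_{D_{\text{aux}}}, x^1_t \rangle \\
\langle w^2_2, x^2_t \rangle & \langle w^2_4, x^2_t \rangle & \langle w^2_6, x^2_t \rangle & \cdots & \langle w^2_{D_{\text{aux}}}, x^2_t \rangle \\
\langle w^3_2, x^3_t \rangle & \langle w^3_4, x^3_t \rangle & \langle w^3_6, x^3_t \rangle & \cdots & \langle w^3_{D_{\text{aux}}}, x^3_t \rangle
\end{bmatrix}.$$

Note that for each $j \le D_{\text{aux}}$, we have $\langle w_j, x_t \rangle = \sum_{i=1}^{3} \langle w^i_j, x^i_t \rangle$.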
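The sharding identity can be checked with a short sketch that builds the six rows of $S_t$ and then sums the shard rows column by column. The helper names (`head_outputs`, `recombine`) are hypothetical, and the sketch assumes $D_{\text{aux}}$ is divisible by both 2 and the number of shards; it reproduces only the arithmetic of the example, not the attention-based construction itself.

```python
def split(x, h):
    """SPLIT_h: cut a length-d vector into h equal contiguous chunks."""
    assert len(x) % h == 0, "d must be divisible by h"
    size = len(x) // h
    return [x[i * size:(i + 1) * size] for i in range(h)]

def dot(u, v):
    """Inner product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))

def head_outputs(W, x, shards=3):
    """Rows of S_t: the shard-wise partial dot products for w_1, w_3, ...,
    followed by those for w_2, w_4, ... (one row per attention head)."""
    x_shards = split(x, shards)
    rows = []
    for first in (0, 1):  # 0 selects w_1, w_3, ...; 1 selects w_2, w_4, ...
        for i in range(shards):
            rows.append([dot(split(W[j], shards)[i], x_shards[i])
                         for j in range(first, len(W), 2)])
    return rows

def recombine(S, shards=3):
    """Sum the shard rows column by column, then interleave to recover W x."""
    n_cols = len(S[0])
    odd = [sum(S[i][c] for i in range(shards)) for c in range(n_cols)]
    even = [sum(S[shards + i][c] for i in range(shards)) for c in range(n_cols)]
    return [value for pair in zip(odd, even) for value in pair]
```

For a $12 \times 12$ integer matrix `W` and a length-12 vector `x`, `recombine(head_outputs(W, x))` equals the list of full dot products `dot(W[j], x)`, which is the content of the identity above.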
Thus, with a column-wise linear operation on $S_t$, we can sum the relevant elements in each column to get

$$S^{\text{col}}_t = \begin{bmatrix}
\langle w_1, x_t \rangle & \langle w_3, x_t \rangle & \cdots & \langle w_{D_{\text{aux}}/2-1}, x_t \rangle & 0 & 0 & \cdots & 0 \\
\langle w_2, x_t \rangle & \langle w_4, x_t \rangle & \cdots & \langle w_{D_{\text{aux}}/2}, x_t \rangle & 0 & 0 & \cdots & 0 \\
0 & 0 & \cdots & 0 & \langle w_{D_{\text{aux}}/2+1}, x_t \rangle & \langle w_{D_{\text{aux}}/2+3}, x_t \rangle & \cdots & \langle w_{D_{\text{aux}}-1}, x_t \rangle \\
0 & 0 & \cdots & 0 & \langle w_{D_{\text{aux}}/2+2}, x_t \rangle & \langle w_{D_{\text{aux}}/2+4}, x_t \rangle & \cdots & \langle w_{D_{\text{aux}}}, x_t \rangle \\
0 & 0 & \cdots & 0 & 0 & 0 & \cdots & 0 \\
0 & 0 & \cdots & 0 & 0 & 0 & \cdots & 0
\end{bmatrix}.$$

A row-wise linear operation on $S^{\text{col}}_t$ can space out the non-zero elements in the matrix and give us

$$S^{\text{row}}_t = \begin{bmatrix}
\langle w_1, x_t \rangle & 0 & \langle w_3, x_t \rangle & 0 & \cdots & \langle w_{D_{\text{aux}}/2-1}, x_t \rangle & 0 \\
0 & \langle w_2, x_t \rangle & 0 & \langle w_4, x_t \rangle & \cdots & 0 & \langle w_{D_{\text{aux}}/2}, x_t \rangle \\
\langle w_{D_{\text{aux}}/2+1}, x_t \rangle & 0 & \langle w_{D_{\text{aux}}/2+3}, x_t \rangle & 0 & \cdots & \langle w_{D_{\text{aux}}-1}, x_t \rangle & 0 \\
0 & \langle w_{D_{\text{aux}}/2+2}, x_t \rangle & 0 & \langle w_{D_{\text{aux}}/2+4}, x_t \rangle & \cdots & 0 & \langle w_{D_{\text{aux}}}, x_t \rangle \\
0 & 0 & 0 & 0 & \cdots & 0 & 0 \\
0 & 0 & 0 & 0 & \cdots & 0 & 0
\end{bmatrix}.$$

Finally, a column-wise linear operation on $S^{\text{row}}_t$ helps to get the non-zero elements in the correct order:

$$\begin{bmatrix}
\langle w_1, x_t \rangle & \langle w_2, x_t \rangle & \langle w_3, x_t \rangle & \langle w_4, x_t \rangle & \cdots & \langle w_{D_{\text{aux}}/2-1}, x_t \rangle & \langle w_{D_{\text{aux}}/2}, x_t \rangle \\
\langle w_{D_{\text{aux}}/2+1}, x_t \rangle & \langle w_{D_{\text{aux}}/2+2}, x_t \rangle & \langle w_{D_{\text{aux}}/2+3}, x_t \rangle & \langle w_{D_{\text{aux}}/2+4}, x_t \rangle & \cdots & \langle w_{D_{\text{aux}}-1}, x_t \rangle & \langle w_{D_{\text{aux}}}, x_t \rangle \\
0 & 0 & 0 & 0 & \cdots & 0 & 0 \\
\vdots & \vdots & \vdots & \vdots & & \vdots & \vdots \\
0 & 0 & 0 & 0 & \cdots & 0 & 0
\end{bmatrix}.$$

The desired output is then given by applying $\mathrm{VECTORIZE}$ to the rows of this final matrix, which contains $W x_t$ in the first $D_{\text{aux}}$ coordinates. The operations that convert $S_t$ to $S^{\text{col}}_t$ and $S^{\text{row}}_t$ to the final matrix represent a split-wise 6-split linear operation, while the operation that converts $S^{\text{col}}_t$ to $S^{\text{row}}_t$ represents a dimension-wise 6-split linear operation. A naive linear operation on the output of the attention heads would require $D_{\text{sim}}^2$ parameters, while its replacement requires $D_{\text{sim}}^2/6$ parameters to represent a dimension-wise 6-split linear operation, and an additional $12 D_{\text{sim}}$ parameters to represent the split-wise 6-split linear operations.

Figure 6: TINT simulates the forward pass of a self-attention layer of the auxiliary model with a Linear Forward module (Figure 2) and a TINT softmax attention layer (Definition B.2). The Linear Forward module computes the query, key, and value vectors from the current embeddings, changing the prefix embeddings to correspond to $W_Q$, $W_K$, and $W_V$ respectively.

D. Self-attention layer

We first introduce multi-head attention, generalizing single-head attention (Definition B.1).

Definition D.1 (Auxiliary self-attention with $H_{\text{aux}}$ heads). For query, key, and value weights $W_Q, W_K, W_V \in \mathbb{R}^{D_{\text{aux}} \times D_{\text{aux}}}$ and bias $b_Q, b_K, b_V \in \mathbb{R}^{D_{\text{aux}}}$, a self-attention layer with $H_{\text{aux}}$ attention heads and a function $f_{\text{attn}} : \mathbb{R}^{T_{\text{aux}}} \to \mathbb{R}^{T_{\text{aux}}}$ takes a sequence $\{x_t \in \mathbb{R}^{D_{\text{aux}}}\}_{t \le T_{\text{aux}}}$ as input and outputs $\{y_t\}_{t \le T_{\text{aux}}}$, with

$$y_t = \mathrm{VECTORIZE}\Big(\Big\{\sum_{j \le T_{\text{aux}}} a^h_{t,j} v^h_j\Big\}_{h \le H_{\text{aux}}}\Big). \tag{9}$$

$a^h_{t,j}$ is defined as the attention score of head $h$ between tokens at positions $t$ and $j$, and is given by

$$a^h_{t,j} = \mathrm{softmax}(K^h q^h_t)_j. \tag{10}$$

Here, $q_t$, $k_t$, $v_t$ denote the query, key, and value vectors at each position $t$, computed as $W_Q x_t + b_Q$, $W_K x_t + b_K$, and $W_V x_t + b_V$ respectively. In addition, $q^h_t$, $k^h_t$, $v^h_t$ denote $\mathrm{SPLIT}_{H_{\text{aux}}}(q_t)_h$, $\mathrm{SPLIT}_{H_{\text{aux}}}(k_t)_h$, and $\mathrm{SPLIT}_{H_{\text{aux}}}(v_t)_h$ respectively for all $t \le T_{\text{aux}}$ and $h \le H_{\text{aux}}$. $K^h \in \mathbb{R}^{T_{\text{aux}} \times D_{\text{aux}}}$ is defined with its rows as $\{k^h_t\}_{t \le T_{\text{aux}}}$ for all $h \le H_{\text{aux}}$.

Figure 7: The gradient w.r.t. the value vectors $\{\partial_{v_t}\}$ (Definition D.2) forms the integral component for both the TINT self-attention backward and descent update modules. TINT computes $\{\partial_{v_t}\}$ using a softmax attention and a linear attention layer. We first use residual connections to copy the query and key vectors to the current embeddings from the TINT Self-attention Forward module (Figure 6). The softmax attention layer re-computes the attention scores $\{a^h_{t,j}\}$ between all token pairs $\{(t, j)\}$ and stores them in the token embeddings. The linear attention layer uses the one-hot position embeddings of the input tokens as the query to use the transposed attention scores $\{a^h_{j,t}\}$ for all token pairs $\{(t, j)\}$ and uses the gradients $\{\partial_{y_t}\}$ as the value vectors to compute $\{\partial_{v_t}\}$.

In the discussions below, we consider a self-attention layer in the auxiliary model with parameters $\{W_Q, b_Q, W_K, b_K, W_V, b_V\}$ that takes in an input sequence $x_1, \cdots, x_{T_{\text{aux}}}$ and outputs $y_1, \cdots, y_{T_{\text{aux}}}$, with $\{y_t\}_{t=1}^{T_{\text{aux}}}$ given by (9). As in the definition, $q_t$, $k_t$, $v_t$ denote the query, key, and value vectors for position $t$. We will use TINT self-attention modules to simulate the operations of the auxiliary's self-attention layer. To do so, we need $H_{\text{sim}} \ge H_{\text{aux}}$ in the corresponding TINT self-attention modules.

TINT Self-attention forward module. The input embedding $e_t$ to this module at each position $t$ contains $x_t$ in its first $D_{\text{aux}}$ coordinates. The self-attention module can be divided into four sub-operations: computation of (a) query vectors $\{q_t\}_{t \le T}$, (b) key vectors $\{k_t\}_{t \le T}$, (c) value vectors $\{v_t\}_{t \le T}$, and (d) $\{y_t\}_{t \le T}$ using (9). Please see Figure 6.

Sub-operation (a): The computation of the query vector $q_t := W_Q x_t + b_Q$ at each position $t$ is a linear operation involving parameters $W_Q, b_Q$. Thus, we can first feed in the stacked rows of $W_Q$ and $b_Q$ onto the prefix embeddings $\{v_j\}$.
We use a Linear Forward module (Appendix C) on the current embeddings and the prefix embeddings to get an embedding $e^q_t$ at each position $t$ that contains $q_t$ in the first $D_{\text{aux}}$ coordinates.

Sub-operations (b, c): Similar to (a), we feed in the stacked rows of the necessary parameters onto the prefix embeddings $\{v_j\}$, and call two Linear Forward modules (Appendix C) independently to get embeddings $e^k_t$ and $e^v_t$ containing $k_t$ and $v_t$ respectively. We now combine the embeddings $e^q_t$, $e^k_t$, and $e^v_t$ to get an embedding $e_t$ that contains $q_t, k_t, v_t$ in the first $3 D_{\text{aux}}$ coordinates.

Sub-operation (d): Finally, we call a TINT self-attention module (Definition B.2) on our current embeddings $\{e_t\}_{t \le T}$ to compute $\{y_t\}_{t \le T}$. The query, key, and value parameters in the self-attention module contain sub-identity blocks that pick out the relevant information from $q_t, k_t, v_t$ stored in $e_t$.

Remark: Sub-operations (a), (b), and (c) can be represented as a single linear operation with a weight $W \in \mathbb{R}^{3 D_{\text{aux}} \times D_{\text{aux}}}$ that concatenates the rows of $\{W_Q, W_K, W_V\}$ and a bias $b \in \mathbb{R}^{3 D_{\text{aux}}}$ that concatenates $\{b_Q, b_K, b_V\}$. Thus, they can be simulated with a single Linear Forward module, with $W, b$ fed into the prefix embeddings. However, we decide to separate them in order to limit the number of prefix embeddings and the embedding size. E.g. for GPT-2, $D_{\text{aux}} = 768$; the combined operation demands either a $3\times$ increase in the embedding size in TINT or a $3\times$ increase in the number of prefix embeddings. Hence, in order to minimize the parameter cost, we call the Linear Forward module separately to compute $q_t$, $k_t$, and $v_t$ at each position $t$.

Auxiliary's backpropagation through self-attention. For an auxiliary self-attention layer as defined in Definition D.1, the backpropagation layer takes in the loss gradient w.r.t. the output ($\{\partial_{y_t}\}_{t \le T_{\text{aux}}}$) and computes the loss gradient w.r.t. the input tokens ($\{\partial_{x_t}\}_{t \le T_{\text{aux}}}$).

Definition D.2.
[Auxiliary self-attention backpropagation] For query, key, and value weights $W_Q, W_K, W_V \in \mathbb{R}^{D_{\text{aux}} \times D_{\text{aux}}}$ and bias $b_Q, b_K, b_V \in \mathbb{R}^{D_{\text{aux}}}$, the backpropagation layer corresponding to a self-attention layer with $H_{\text{aux}}$ attention heads takes a sequence $\{\partial_{y_t} \in \mathbb{R}^{D_{\text{aux}}}\}_{t \le T_{\text{aux}}}$ and $\{x_t \in \mathbb{R}^{D_{\text{aux}}}\}_{t \le T_{\text{aux}}}$ as input and outputs $\{\partial_{x_t}\}_{t \le T_{\text{aux}}}$, with

$$\partial_{x_t} = W_Q^\top \partial_{q_t} + W_K^\top \partial_{k_t} + W_V^\top \partial_{v_t},$$

with

$$\partial_{q_t} = \mathrm{VECTORIZE}\Big(\Big\{\sum_j a^h_{t,j} \big((\partial_{y^h_t})^\top v^h_j\big) \Big[k^h_j - \sum_{j'} a^h_{t,j'} k^h_{j'}\Big]\Big\}_{h \le H_{\text{aux}}}\Big);$$

$$\partial_{k_t} = \mathrm{VECTORIZE}\Big(\Big\{\sum_j a^h_{j,t} \, q^h_j \Big[(\partial_{y^h_j})^\top \Big(v^h_t - \sum_{j'} a^h_{j,j'} v^h_{j'}\Big)\Big]\Big\}_{h \le H_{\text{aux}}}\Big);$$

$$\partial_{v_t} = \mathrm{VECTORIZE}\Big(\Big\{\sum_j a^h_{j,t} \, \partial_{y^h_j}\Big\}_{h \le H_{\text{aux}}}\Big).$$

Here, $q_t$, $k_t$, and $v_t$ refer to the query, key, and value vectors at each position $t$, with the attention scores $\{a^h_{t,j}\}_{t,j \le T_{\text{aux}}, h \le H_{\text{aux}}}$.

Complexity of true backpropagation. The much-involved computation in the above operation is due to the computation of $\partial_{q_t}$ and $\partial_{k_t}$ at each position $t$. For the following discussion, we assume that our current embeddings $e_t$ contain $q_t, k_t, v_t$, in addition to the gradient $\partial_{y_t}$. The computation of $\partial_{q_t}$ (and similarly $\partial_{k_t}$) at any position $t$ involves the following sequential computations and the necessary TINT modules.

- $\{\{(\partial_{y^h_t})^\top v^h_j\}_{j \le T_{\text{aux}}}\}_{h \le H_{\text{aux}}}$ with a TINT linear self-attention module (Definition B.2), with at least $H_{\text{aux}}$ attention heads that represent the attention score between $e_t$ and any other token $e_j$ by $\{(\partial_{y^h_t})^\top v^h_j\}_{h \le H_{\text{aux}}}$.
- Attention scores $\{a^h_{t,j}\}_{h \le H_{\text{aux}}}$, which requires a TINT softmax self-attention module (Definition B.2), with at least $H_{\text{aux}}$ heads, that uses the already present $\{q_t, k_t, v_t\}$ in the current embeddings $e_t$ to re-compute the attention scores.
- $\{a^h_{t,j} (\partial_{y^h_t})^\top v^h_j\}_{h \le H_{\text{aux}}}$ for all $j \le T_{\text{aux}}$, by multiplying the attention scores $\{a^h_{t,j}\}_{h \le H_{\text{aux}}}$ with $\{(\partial_{y^h_t})^\top v^h_j\}_{h \le H_{\text{aux}}}$ using an MLP layer (Lemma B.5). Furthermore, $\{\sum_j a^h_{t,j} k^h_j\}_{h \le H_{\text{aux}}}$ needs to be computed in parallel as well, with additional attention heads.
- $\partial_{q_t}$ with a TINT linear self-attention module (Definition B.2), with at least $H_{\text{aux}}$ attention heads that represent the attention score between any token $e_j$ and $e_t$ by $\{a^h_{t,j} (\partial_{y^h_t})^\top v^h_j\}_{h \le H_{\text{aux}}}$, with value vectors given by $\{k^h_j - \sum_{j'} a^h_{t,j'} k^h_{j'}\}_{h \le H_{\text{aux}}}$.

The sequential computation requires the simulator to store $\{\{(\partial_{y^h_t})^\top v^h_j\}_{j \le T_{\text{aux}}}\}_{h \le H_{\text{aux}}}$ and $\{a^h_{t,j}\}_{h \le H_{\text{aux}}}$ in the token embedding $e_t$, which requires an additional $2 T_{\text{aux}} H_{\text{aux}}$ embedding dimension size. To avoid the much-involved computation for the true gradient propagation, we instead only use the gradients w.r.t. $v_t$.

Figure 8: TINT simulates the backward pass of a self-attention layer of the auxiliary model using a Linear Backward module (Figure 4). The input embeddings contain the gradient of the loss w.r.t. the value vectors ($\partial_{v_t}$) computed in Figure 7. The value matrix $W_V$ is encoded in the prefix embeddings. We call the Linear Backward module on this sequence.

Approximate auxiliary self-attention backpropagation. We formally extend the definition of approximate gradients $\{\partial_{x_t}\}_{t=1}^{T_{\text{aux}}}$ from the main paper to multi-head attention in Definition D.3.

Definition D.3. For query, key, and value weights $W_Q, W_K, W_V \in \mathbb{R}^{D_{\text{aux}} \times D_{\text{aux}}}$ and bias $b_Q, b_K, b_V \in \mathbb{R}^{D_{\text{aux}}}$, the approximate backpropagation layer corresponding to a self-attention layer with $H_{\text{aux}}$ attention heads takes a sequence $\{\partial_{y_t} \in \mathbb{R}^{D_{\text{aux}}}\}_{t \le T_{\text{aux}}}$ and $\{x_t \in \mathbb{R}^{D_{\text{aux}}}\}_{t \le T_{\text{aux}}}$ as input and outputs $\{\widehat{\partial_{x_t}} := \mathrm{VECTORIZE}(\{\widehat{\partial_{x^h_t}}\}_{h \le H_{\text{aux}}})\}_{t \le T_{\text{aux}}}$, with

$$\widehat{\partial_{x_t}} = W_V^\top \partial_{v_t}, \quad \text{where } \partial_{v_t} = \mathrm{VECTORIZE}\Big(\Big\{\sum_j a^h_{j,t} \, \partial_{y^h_j}\Big\}_{h \le H_{\text{aux}}}\Big).$$

Here, $q_t$, $k_t$, and $v_t$ refer to the query, key, and value vectors at each position $t$, as defined in Definition D.1, with the attention scores $\{a^h_{t,j}\}_{t,j \le T_{\text{aux}}, h \le H_{\text{aux}}}$ defined in Equation (10).
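To make the approximation in Definition D.3 concrete, the following is a minimal single-head sketch on plain Python lists. The function name `approx_attention_backprop` is hypothetical, biases are omitted, and the multi-head case (which would apply the same computation per head after SPLIT) is left out for brevity. It computes $\partial_{v_t}$ from the transposed attention scores and then applies $W_V^\top$.

```python
import math

def matvec(W, x):
    """Matrix-vector product for a matrix stored as a list of rows."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def softmax(scores):
    """Numerically stable softmax over a list of scores."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def approx_attention_backprop(xs, dys, W_Q, W_K, W_V):
    """Single-head sketch of Definition D.3 (biases omitted).

    Returns (dvs, dxs) with dv_t = sum_j a_{j,t} dy_j and dx_t = W_V^T dv_t.
    The query and key gradients of the true backward pass are dropped.
    """
    qs = [matvec(W_Q, x) for x in xs]
    ks = [matvec(W_K, x) for x in xs]
    # a[t][j] = softmax(K q_t)_j, the forward attention scores
    a = [softmax(matvec(ks, q)) for q in qs]
    T, D = len(xs), len(dys[0])
    # Transposed index: position t collects gradients from every j that attended to it
    dvs = [[sum(a[j][t] * dys[j][c] for j in range(T)) for c in range(D)]
           for t in range(T)]
    # (W_V^T dv)_c = sum_r W_V[r][c] * dv[r]
    dxs = [[sum(W_V[r][c] * dv[r] for r in range(len(W_V)))
            for c in range(len(W_V[0]))]
           for dv in dvs]
    return dvs, dxs
```

Because every row of the attention matrix sums to one, the value gradients satisfy $\sum_t \partial_{v_t} = \sum_j \partial_{y_j}$, which gives a cheap sanity check on any implementation of this step.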
In the upcoming theorem, we formally show that if, on a given sequence $\{x_t\}_{t \le T_{\text{aux}}}$, all the attention heads in a self-attention layer primarily attend to a single token at every position, then the approximate gradient $\widehat{\partial_{x_t}}$ is close to the true gradient $\partial_{x_t}$ at each position $t$.

Definition D.4 ($\varepsilon$-hard attention head). For the self-attention layer of $H_{\text{aux}}$ heads in Definition D.1, on a given input sequence $\{x_t\}_{t=1}^{T_{\text{aux}}}$, an attention head $h \le H_{\text{aux}}$ is defined to be $\varepsilon$-hard on the input sequence if for all positions $t \le T_{\text{aux}}$, there exists a position $t_0 \le T_{\text{aux}}$ such that $a^h_{t,t_0} \ge 1 - \varepsilon$.

Theorem D.5. With the notations in Definitions D.1 to D.3, if on a given input sequence $\{x_t\}_{t=1}^{T_{\text{aux}}}$, with its query, key, and value vectors $\{q_t, k_t, v_t\}_{t=1}^{T_{\text{aux}}}$, all the $H_{\text{aux}}$ attention heads are $\varepsilon$-hard for some $\varepsilon > 0$, then for a given sequence of gradients $\{\partial_{y_t}\}_{t=1}^{T_{\text{aux}}}$,

$$\|\partial_{q_t}\|_2, \|\partial_{k_t}\|_2 \le O(\varepsilon B_x^2 B_w^2 B_y), \quad \text{for all } t \le T_{\text{aux}},$$

where $B_x = \max_{t \le T_{\text{aux}}} \|x_t\|_2$, $B_y = \max_{t \le T_{\text{aux}}} \|\partial_{y_t}\|_2$, and $B_w = \max\{\|W_K\|_2, \|W_Q\|_2, \|W_V\|_2, \|b_Q\|_2, \|b_K\|_2, \|b_V\|_2\}$. This implies, for each position $t$,

$$\|\widehat{\partial_{x_t}} - \partial_{x_t}\|_2 \le O(\varepsilon B_x^2 B_w^3 B_y).$$

Figure 9: TINT simulates the backward pass of the self-attention layer in the auxiliary model by employing the Linear Descent module (Figure 5). The input embeddings consist of the gradient of the loss with respect to the value vectors ($\partial_{v_t}$) computed in Figure 7. Additionally, we incorporate a residual connection to copy the input from the Self-attention Forward module (Figure 6) into $x_t$. Before invoking the Linear Descent module, we represent the value parameters ($W_V$) in the prefix embeddings.

TINT Self-attention backpropagation module. The input embeddings $e_t$ contain $\partial_{y_t}$ in the first $D_{\text{aux}}$ coordinates.
Since we need to re-compute the attention scores $\{a^h_{t,j}\}_{j \le T_{\text{aux}}, h \le H_{\text{aux}}}$, we copy the query, key, and value vectors $q_t$, $k_t$, and $v_t$ from the TINT self-attention Forward module at each position $t$. Furthermore, we use the residual connection to copy the prefix embeddings $\{v_j\}$, which contain the rows of $W_V$, from the TINT self-attention Forward module.

The operation can be divided into three sub-operations: computing (a) attention scores $\{a^h_{t,j}\}_{h \le H_{\text{aux}}}$ for all $j \le T_{\text{aux}}$ at each position $t$, (b) $\partial_{v_t}$ from $\{a^h_{t,j}\}_{h \le H_{\text{aux}}}$ and $\partial_{y_t}$, and (c) $\widehat{\partial_{x_t}}$ from $\partial_{v_t}$.

Sub-operation (a): Since the current embeddings $e_t$ contain $q_t, k_t$, we can simply call a TINT softmax self-attention module to compute the attention scores $\{a^h_{t,j}\}_{h \le H_{\text{aux}}}$ for all $j \le T$ and store them in the current embeddings. We retain $\partial_{y_t}$ and $v_t$ for further operations using residual connections.

Sub-operation (b): With the current embeddings $e_t$ containing the attention scores $\{a^h_{t,j}\}_{h \le H_{\text{aux}}}$ for all $j \le T$ and the gradient $\partial_{y_t}$, we can compute $\partial_{v_t}$ using a TINT linear self-attention module with at least $H_{\text{aux}}$ attention heads that represent the attention scores between tokens $e_t$ and $e_j$ for any $j$ as $\{a^h_{j,t}\}_{h \le H_{\text{aux}}}$ and use $\mathrm{SPLIT}_{H_{\text{aux}}}(\partial_{y_t})$ as their value vectors.

Sub-operation (c): Finally, the computation of $\widehat{\partial_{x_t}}$ is identical to the backpropagation through a linear layer with parameters $W_V$ and $b_V$. Hence, we call a Linear backpropagation module on the current embeddings, which contain $\partial_{v_t}$, and the prefix embeddings, which contain $W_V$ and $b_V$.

Separating sub-operations (a) and (b). The operation for computing $\partial_{v_t}$ in Definition D.3 looks very similar to the computation of $y_t$ in Equation (9). However, the major difference is that instead of the attention scores being $\{a^h_{t,j}\}_{h \le H_{\text{aux}}}$ between token $t$ and any token $j$, we need the attention scores to be $\{a^h_{j,t}\}_{h \le H_{\text{aux}}}$.
Thus, unless our model allows a transpose operation on the attention scores, we need to first store them in our embeddings and then use an additional self-attention module that can pick the right attention scores between tokens using position embeddings. Please see Figure 8.

Auxiliary's value descent update

Similar to the complexity of true backpropagation, the descent updates for $W_Q, b_Q, W_K, b_K$ are quite expensive to express with the transformer layers. Hence, we focus simply on updating $W_V, b_V$, while keeping the others fixed.

Definition D.6 (Auxiliary self-attention value descent). For query, key, and value weights $W_Q, W_K, W_V \in \mathbb{R}^{D_{\text{aux}} \times D_{\text{aux}}}$ and bias $b_Q, b_K, b_V \in \mathbb{R}^{D_{\text{aux}}}$, the value descent layer corresponding to a self-attention layer with $H_{\text{aux}}$ attention heads and any function $f_{\text{attn}} : \mathbb{R}^{T_{\text{aux}}} \to \mathbb{R}^{T_{\text{aux}}}$ takes in a batch of gradients $\{\partial_{y_t} \in \mathbb{R}^{D_{\text{aux}}}\}_{t \le T_{\text{aux}}}$ and inputs $\{x_t \in \mathbb{R}^{D_{\text{aux}}}\}_{t \le T_{\text{aux}}}$ and updates $W_V, b_V$ as follows:

$$W_V \leftarrow W_V - \eta \sum_{t \le T_{\text{aux}}} \partial_{v_t} x_t^\top, \qquad b_V \leftarrow b_V - \eta \sum_{t \le T_{\text{aux}}} \partial_{v_t}, \qquad \text{where } \partial_{v_t} = \text{VECTORIZE}\Big(\Big\{\sum_j a^h_{j,t}\, \partial_{y^h_j}\Big\}_{h \le H_{\text{aux}}}\Big).$$

Here, $v_t$ refers to the value vectors at each position $t$, as defined in Definition D.1.

TINT Self-attention descent module

The input embeddings contain $\partial_{v_t}$ in the first $D_{\text{aux}}$ coordinates, from the TINT self-attention backpropagation module. Furthermore, the prefix embeddings $\{v_j\}$ contain the stacked rows of $W_V$ and $b_V$, continuing from the TINT self-attention backpropagation module. Since we further need the input $x_t$ to the auxiliary self-attention layer under consideration, we use residual connections to copy $x_t$ from the TINT self-attention Forward module at each position $t$.

The updates of $W_V$ and $b_V$ are equivalent to the parameter update in a linear layer, involving gradients $\{\partial_{v_t}\}$ and inputs $\{x_t\}$. Thus, we call a Linear descent module on the current embeddings and the prefix embeddings to get the updated value parameters. Please see Figure 9.

D.1.
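Proofs of theorems and gradient definitions

We restate the theorems and definitions before presenting their proofs, for easy reference.

Definition D.2. [Auxiliary self-attention backpropagation] For query, key, and value weights $W_Q, W_K, W_V \in \mathbb{R}^{D_{\text{aux}} \times D_{\text{aux}}}$ and bias $b_Q, b_K, b_V \in \mathbb{R}^{D_{\text{aux}}}$, the backpropagation layer corresponding to a self-attention layer with $H_{\text{aux}}$ attention heads takes a sequence $\{\partial_{y_t} \in \mathbb{R}^{D_{\text{aux}}}\}_{t \le T_{\text{aux}}}$ and $\{x_t \in \mathbb{R}^{D_{\text{aux}}}\}_{t \le T_{\text{aux}}}$ as input and outputs $\{\partial_{x_t}\}_{t \le T_{\text{aux}}}$, with

$$\partial_{x_t} = W_Q^\top \partial_{q_t} + W_K^\top \partial_{k_t} + W_V^\top \partial_{v_t},$$

$$\partial_{q_t} = \text{VECTORIZE}\Big(\Big\{\sum_j a^h_{t,j}\big((\partial_{y^h_t})^\top v^h_j\big)\Big[k^h_j - \sum_{j'} a^h_{t,j'} k^h_{j'}\Big]\Big\}_{h \le H_{\text{aux}}}\Big);$$

$$\partial_{k_t} = \text{VECTORIZE}\Big(\Big\{\sum_j a^h_{j,t}\, q^h_j \Big[(\partial_{y^h_j})^\top \Big(v^h_t - \sum_{j'} a^h_{j,j'} v^h_{j'}\Big)\Big]\Big\}_{h \le H_{\text{aux}}}\Big);$$

$$\partial_{v_t} = \text{VECTORIZE}\Big(\Big\{\sum_j a^h_{j,t}\, \partial_{y^h_j}\Big\}_{h \le H_{\text{aux}}}\Big).$$

Here, $q_t$, $k_t$, and $v_t$ refer to query, key, and value vectors at each position $t$, with the attention scores $\{a^h_{t,j}\}_{t,j \le T_{\text{aux}}, h \le H_{\text{aux}}}$.

Derivation of gradient in Definition D.2. Recalling the definition of $y_t$ from Definition D.1,

$$y_t = \text{VECTORIZE}\Big(\Big\{\sum_{j \le T_{\text{aux}}} a^h_{t,j} v^h_j\Big\}_{h \le H_{\text{aux}}}\Big); \qquad a^h_{t,j} = \text{softmax}(K^h q^h_t)_j,$$

$$q_t = W_Q x_t + b_Q, \qquad k_t = W_K x_t + b_K, \qquad v_t = W_V x_t + b_V.$$

$q^h_t$, $k^h_t$, $v^h_t$ denote $\text{SPLIT}_{H_{\text{aux}}}(q_t)_h$, $\text{SPLIT}_{H_{\text{aux}}}(k_t)_h$, and $\text{SPLIT}_{H_{\text{aux}}}(v_t)_h$ respectively for all $t \le T_{\text{aux}}$ and $h \le H_{\text{aux}}$. $K^h \in \mathbb{R}^{T_{\text{aux}} \times D_{\text{aux}}}$ is defined with its rows as $\{k^h_t\}_{t \le T_{\text{aux}}}$ for all $h \le H_{\text{aux}}$.

We explain the proof for an arbitrary token position $t$. With the application of the chain rule, we have

$$\partial_{x_t} = \Big(\frac{\partial q_t}{\partial x_t}\Big)^\top \partial_{q_t} + \Big(\frac{\partial k_t}{\partial x_t}\Big)^\top \partial_{k_t} + \Big(\frac{\partial v_t}{\partial x_t}\Big)^\top \partial_{v_t} = W_Q^\top \partial_{q_t} + W_K^\top \partial_{k_t} + W_V^\top \partial_{v_t},$$

where the second step follows from the definitions of $q_t$, $k_t$, and $v_t$ respectively.

Computation of $\partial_{q_t}$: With the SPLIT operation of $q_t$ across $H_{\text{aux}}$ heads for the computation of $y_t$, the computation of the backpropagated gradient $\partial_{q_t}$ itself needs to be split across $H_{\text{aux}}$ heads. Furthermore, the query vector $q_t$ only affects $y_t$, implying $\frac{\partial y_{t'}}{\partial q_t} = 0$ for any $t' \ne t$.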
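Thus, for any head $h \le H_{\text{aux}}$, if $y^h_t$ represents the output of attention head $h$, given by $\sum_{j \le T_{\text{aux}}} a^h_{t,j} v^h_j$, we have

$$\begin{aligned}
\partial_{q^h_t} &= \Big(\frac{\partial y^h_t}{\partial q^h_t}\Big)^\top \partial_{y^h_t} = \sum_{j \le T_{\text{aux}}} \langle v^h_j, \partial_{y^h_t}\rangle \frac{\partial a^h_{t,j}}{\partial q^h_t} \\
&= \sum_{j \le T_{\text{aux}}} \langle v^h_j, \partial_{y^h_t}\rangle \frac{\partial}{\partial q^h_t}\Bigg(\frac{e^{\langle k^h_j, q^h_t\rangle}}{\sum_{t' \le T_{\text{aux}}} e^{\langle k^h_{t'}, q^h_t\rangle}}\Bigg) && (11) \\
&= \sum_{j \le T_{\text{aux}}} \langle v^h_j, \partial_{y^h_t}\rangle \Bigg[\frac{1}{\sum_{t' \le T_{\text{aux}}} e^{\langle k^h_{t'}, q^h_t\rangle}} \frac{\partial e^{\langle k^h_j, q^h_t\rangle}}{\partial q^h_t} - \frac{e^{\langle k^h_j, q^h_t\rangle}}{\big(\sum_{t' \le T_{\text{aux}}} e^{\langle k^h_{t'}, q^h_t\rangle}\big)^2} \sum_{j' \le T_{\text{aux}}} \frac{\partial e^{\langle k^h_{j'}, q^h_t\rangle}}{\partial q^h_t}\Bigg] && (12) \\
&= \sum_{j \le T_{\text{aux}}} \langle v^h_j, \partial_{y^h_t}\rangle \Bigg[\frac{e^{\langle k^h_j, q^h_t\rangle}}{\sum_{t'} e^{\langle k^h_{t'}, q^h_t\rangle}}\, k^h_j - \frac{e^{\langle k^h_j, q^h_t\rangle}}{\sum_{t'} e^{\langle k^h_{t'}, q^h_t\rangle}} \sum_{j' \le T_{\text{aux}}} \frac{e^{\langle k^h_{j'}, q^h_t\rangle}}{\sum_{t'} e^{\langle k^h_{t'}, q^h_t\rangle}}\, k^h_{j'}\Bigg] \\
&= \sum_{j \le T_{\text{aux}}} a^h_{t,j} \langle v^h_j, \partial_{y^h_t}\rangle \Big(k^h_j - \sum_{j' \le T_{\text{aux}}} a^h_{t,j'} k^h_{j'}\Big).
\end{aligned}$$

In Equation (11), we have expanded the definition of softmax in $a^h_{t,j} := \text{softmax}(K^h q^h_t)_j$ in order to better motivate the derivative of $a^h_{t,j}$ w.r.t. $q^h_t$. Finally, $\partial_{q_t}$ is given by $\text{VECTORIZE}(\{\partial_{q^h_t}\}_{h \le H_{\text{aux}}})$.

Computation of $\partial_{k_t}$: As in the computation of $\partial_{q_t}$, we split the computation of $\partial_{k_t}$ across the $H_{\text{aux}}$ attention heads. However, unlike $q_t$, $k_t$ affects $y_j$ for all $j \le T_{\text{aux}}$. For any head $h \le H_{\text{aux}}$, we follow the chain rule step by step to get

$$\begin{aligned}
\partial_{k^h_t} &= \sum_{j \le T_{\text{aux}}} \Big(\frac{\partial y^h_j}{\partial k^h_t}\Big)^\top \partial_{y^h_j} = \sum_{j \le T_{\text{aux}}} \Bigg(\frac{\partial \sum_{j' \le T_{\text{aux}}} a^h_{j,j'} v^h_{j'}}{\partial k^h_t}\Bigg)^\top \partial_{y^h_j} \\
&= \sum_{j \le T_{\text{aux}}} \langle v^h_t, \partial_{y^h_j}\rangle \frac{\partial a^h_{j,t}}{\partial k^h_t} + \sum_{j \le T_{\text{aux}}} \sum_{j' \le T_{\text{aux}};\, j' \ne t} \langle v^h_{j'}, \partial_{y^h_j}\rangle \frac{\partial a^h_{j,j'}}{\partial k^h_t} && (14) \\
&= \sum_{j \le T_{\text{aux}}} \langle v^h_t, \partial_{y^h_j}\rangle \frac{\partial}{\partial k^h_t}\Bigg(\frac{e^{\langle k^h_t, q^h_j\rangle}}{\sum_{t' \le T_{\text{aux}}} e^{\langle k^h_{t'}, q^h_j\rangle}}\Bigg) && (15) \\
&\quad + \sum_{j \le T_{\text{aux}}} \sum_{j' \le T_{\text{aux}};\, j' \ne t} \langle v^h_{j'}, \partial_{y^h_j}\rangle \frac{\partial}{\partial k^h_t}\Bigg(\frac{e^{\langle k^h_{j'}, q^h_j\rangle}}{\sum_{t' \le T_{\text{aux}}} e^{\langle k^h_{t'}, q^h_j\rangle}}\Bigg) && (16) \\
&= \sum_{j \le T_{\text{aux}}} \langle v^h_t, \partial_{y^h_j}\rangle \Bigg[\frac{e^{\langle k^h_t, q^h_j\rangle}}{\sum_{t'} e^{\langle k^h_{t'}, q^h_j\rangle}} - \Bigg(\frac{e^{\langle k^h_t, q^h_j\rangle}}{\sum_{t'} e^{\langle k^h_{t'}, q^h_j\rangle}}\Bigg)^2\Bigg] q^h_j && (17) \\
&\quad - \sum_{j \le T_{\text{aux}}} \sum_{j' \le T_{\text{aux}};\, j' \ne t} \langle v^h_{j'}, \partial_{y^h_j}\rangle \Bigg(\frac{e^{\langle k^h_{j'}, q^h_j\rangle}}{\sum_{t'} e^{\langle k^h_{t'}, q^h_j\rangle}}\Bigg) \Bigg(\frac{e^{\langle k^h_t, q^h_j\rangle}}{\sum_{t'} e^{\langle k^h_{t'}, q^h_j\rangle}}\Bigg) q^h_j && (18) \\
&= \sum_{j \le T_{\text{aux}}} \langle v^h_t, \partial_{y^h_j}\rangle \big(a^h_{j,t} - (a^h_{j,t})^2\big) q^h_j - \sum_{j \le T_{\text{aux}}} \sum_{j' \le T_{\text{aux}};\, j' \ne t} \langle v^h_{j'}, \partial_{y^h_j}\rangle\, a^h_{j,j'}\, a^h_{j,t}\, q^h_j \\
&= \sum_{j \le T_{\text{aux}}} a^h_{j,t} \Big\langle \partial_{y^h_j},\, v^h_t - \sum_{j'} a^h_{j,j'} v^h_{j'} \Big\rangle q^h_j.
\end{aligned}$$

In Equation (14), we separate the inside sum into two components, since the derivative w.r.t. $k^h_t$ differs for the two components, as outlined in the derivation of Equation (17) from Equation (15), and Equation (18) from Equation (16). We have skipped a step going from Equations (15) and (16) to Equations (17) and (18) for typographical simplicity. The skipped step is extremely similar to Equation (12) in the derivation of $\partial_{q^h_t}$.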
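Finally, $\partial_{k_t}$ is given by $\text{VECTORIZE}(\{\partial_{k^h_t}\}_{h \le H_{\text{aux}}})$.

Computation of $\partial_{v_t}$: Similar to the gradient computation of $q_t$, the computation of $\partial_{v_t}$ needs to be split across the $H_{\text{aux}}$ attention heads. However, like $k_t$, $v_t$ affects $y_j$ for all $j \le T_{\text{aux}}$. For any head $h \le H_{\text{aux}}$, we follow the chain rule step by step to get

$$\partial_{v^h_t} = \sum_{j \le T_{\text{aux}}} \Big(\frac{\partial y^h_j}{\partial v^h_t}\Big)^\top \partial_{y^h_j} = \sum_{j \le T_{\text{aux}}} \Bigg(\frac{\partial \sum_{j' \le T_{\text{aux}}} a^h_{j,j'} v^h_{j'}}{\partial v^h_t}\Bigg)^\top \partial_{y^h_j} = \sum_{j \le T_{\text{aux}}} a^h_{j,t}\, \partial_{y^h_j}.$$

Theorem D.5. With the notations in Definitions D.1 to D.3, if on a given input sequence $\{x_t\}_{t=1}^{T_{\text{aux}}}$, with its query, key, and value vectors $\{q_t, k_t, v_t\}_{t=1}^{T_{\text{aux}}}$, all the $H_{\text{aux}}$ attention heads are ε-hard for some $\varepsilon > 0$, then for a given sequence of gradients $\{\partial_{y_t}\}_{t=1}^{T_{\text{aux}}}$,

$$\|\partial_{q_t}\|_2,\ \|\partial_{k_t}\|_2 \le O(\varepsilon B_x^2 B_w^2 B_y) \quad \text{for all } t \le T_{\text{aux}},$$

where $B_x = \max_{t \le T_{\text{aux}}} \|x_t\|_2$, $B_y = \max_{t \le T_{\text{aux}}} \|\partial_{y_t}\|_2$, and $B_w = \max\{\|W_K\|_2, \|W_Q\|_2, \|W_V\|_2, \|b_Q\|_2, \|b_K\|_2, \|b_V\|_2\}$. This implies, for each position $t$, $\|\widehat{\partial_{x_t}} - \partial_{x_t}\|_2 \le O(\varepsilon B_x^2 B_w^3 B_y)$.

Proof of Theorem D.5. For typographical simplicity, we discuss the proof at an arbitrary position $t$. Recall the definition of an ε-hard attention head from Definition D.4: an attention head is ε-hard on an input sequence $\{x_t\}_{t=1}^{T_{\text{aux}}}$ if, for each position $t$, there exists a position $t_0$ such that the attention score $a_{t,t_0} \ge 1 - \varepsilon$.

For the proof, we focus on $\partial_{q_t}$; the proof for $\partial_{k_t}$ follows likewise.

Bounds on $\partial_{q_t}$: Recalling the definition of $\partial_{q_t}$ from Definition D.2, we have

$$\partial_{q_t} = \text{VECTORIZE}\Big(\Big\{\sum_j a^h_{t,j}\big((\partial_{y^h_t})^\top v^h_j\big)\Big[k^h_j - \sum_{j'} a^h_{t,j'} k^h_{j'}\Big]\Big\}_{h \le H_{\text{aux}}}\Big).$$

Focusing on a head $h \le H_{\text{aux}}$, define $\partial_{q^h_t} = \sum_j a^h_{t,j}\big((\partial_{y^h_t})^\top v^h_j\big)\big[k^h_j - \sum_{j'} a^h_{t,j'} k^h_{j'}\big]$ and let $t_0 \le T_{\text{aux}}$ be the token position that $q_t$ attends to the most, i.e. $a^h_{t,t_0} \ge 1 - \varepsilon$ and $\sum_{j \le T_{\text{aux}};\, j \ne t_0} a^h_{t,j} \le \varepsilon$. Then,

$$\begin{aligned}
\|\partial_{q^h_t}\|_2 &= \Big\|\sum_j a^h_{t,j}\big((\partial_{y^h_t})^\top v^h_j\big)\Big[k^h_j - \sum_{j'} a^h_{t,j'} k^h_{j'}\Big]\Big\|_2 \\
&= \Big\|a^h_{t,t_0}\big((\partial_{y^h_t})^\top v^h_{t_0}\big)\Big[k^h_{t_0} - \sum_{j'} a^h_{t,j'} k^h_{j'}\Big] + \sum_{j \ne t_0} a^h_{t,j}\big((\partial_{y^h_t})^\top v^h_j\big)\Big[k^h_j - \sum_{j'} a^h_{t,j'} k^h_{j'}\Big]\Big\|_2 \\
&\le \underbrace{\Big\|a^h_{t,t_0}\big((\partial_{y^h_t})^\top v^h_{t_0}\big)\Big[k^h_{t_0} - \sum_{j'} a^h_{t,j'} k^h_{j'}\Big]\Big\|_2}_{\text{Term1}} + \underbrace{\Big\|\sum_{j \ne t_0} a^h_{t,j}\big((\partial_{y^h_t})^\top v^h_j\big)\Big[k^h_j - \sum_{j'} a^h_{t,j'} k^h_{j'}\Big]\Big\|_2}_{\text{Term2}},
\end{aligned}$$

where the final step uses a Cauchy-Schwarz inequality.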
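We focus on the two terms separately.

1. Term1: Focusing on $k^h_{t_0} - \sum_j a^h_{t,j} k^h_j$, we have

$$\begin{aligned}
\Big\|k^h_{t_0} - \sum_j a^h_{t,j} k^h_j\Big\|_2 &= \Big\|(1 - a_{t,t_0}) k^h_{t_0} - \sum_{j \ne t_0} a^h_{t,j} k^h_j\Big\|_2 \\
&\le (1 - a_{t,t_0}) \|k^h_{t_0}\|_2 + \sum_{j \ne t_0} a^h_{t,j} \|k^h_j\|_2 \\
&\le \Big((1 - a_{t,t_0}) + \sum_{j \ne t_0} a^h_{t,j}\Big) \max_j \|k^h_j\|_2 \le 2\varepsilon \max_j \|k^h_j\|_2. && (19)
\end{aligned}$$

We use a Cauchy-Schwarz inequality in the second and third steps and the attention head behavior in the final step. Hence, Term1 can now be bounded as follows:

$$\Big\|a^h_{t,t_0}\big((\partial_{y^h_t})^\top v^h_{t_0}\big)\Big[k^h_{t_0} - \sum_j a^h_{t,j} k^h_j\Big]\Big\|_2 = a^h_{t,t_0} \big|(\partial_{y^h_t})^\top v^h_{t_0}\big| \Big\|k^h_{t_0} - \sum_j a^h_{t,j} k^h_j\Big\|_2 \le 2\varepsilon \|\partial_{y^h_t}\|_2 \|v^h_{t_0}\|_2 \max_j \|k^h_j\|_2.$$

In the final step, in addition to the bound from Equation (19), we use a Cauchy-Schwarz inequality to bound $|(\partial_{y^h_t})^\top v^h_{t_0}|$ and bound the attention score $a^h_{t,t_0}$ by 1.

2. Term2: Focusing on $k^h_j - \sum_{j'} a^h_{t,j'} k^h_{j'}$ for any $j \le T_{\text{aux}}$, we have, using two Cauchy-Schwarz inequalities:

$$\Big\|k^h_j - \sum_{j'} a^h_{t,j'} k^h_{j'}\Big\|_2 \le \|k^h_j\|_2 + \Big\|\sum_{j'} a^h_{t,j'} k^h_{j'}\Big\|_2 \le \Big(1 + \sum_{j'} a^h_{t,j'}\Big) \max_j \|k^h_j\|_2 = 2 \max_j \|k^h_j\|_2. \qquad (20)$$

Hence,

$$\Big\|\sum_{j \ne t_0} a^h_{t,j}\big((\partial_{y^h_t})^\top v^h_j\big)\Big[k^h_j - \sum_{j'} a^h_{t,j'} k^h_{j'}\Big]\Big\|_2 \le \sum_{j \ne t_0} a^h_{t,j} \big|(\partial_{y^h_t})^\top v^h_j\big| \Big\|k^h_j - \sum_{j'} a^h_{t,j'} k^h_{j'}\Big\|_2 \le 2\varepsilon \|\partial_{y^h_t}\|_2 \max_j \|v^h_j\|_2 \max_j \|k^h_j\|_2.$$

In the final step, in addition to the bound from Equation (20), we use a Cauchy-Schwarz inequality to bound $|(\partial_{y^h_t})^\top v^h_j|$ and use the ε-hard behavior of the attention head to bound $\sum_{j \ne t_0} a^h_{t,j}$.

Combining the bounds on both terms, we have

$$\|\partial_{q^h_t}\|_2 \le 2\varepsilon \|\partial_{y^h_t}\|_2 \|v^h_{t_0}\|_2 \max_j \|k^h_j\|_2 + 2\varepsilon \|\partial_{y^h_t}\|_2 \max_j \|v^h_j\|_2 \max_j \|k^h_j\|_2 \le 4\varepsilon \|\partial_{y^h_t}\|_2 \max_j \|v^h_j\|_2 \max_j \|k^h_j\|_2.$$

We bound the remaining terms as follows.

- $\|\partial_{y^h_t}\|_2 \le B_y$, under the bounded assumption on the gradients.
- For any $j \le T_{\text{aux}}$, we have $\|k^h_j\|_2 \le \|k_j\|_2$ since $k_j = \text{VECTORIZE}(\{k^h_j\}_{h \le H_{\text{aux}}})$. Furthermore, from the definition of the key vector $k_j$, $\|k_j\|_2 = \|W_K x_j + b_K\|_2 \le \|W_K\|_2 \|x_j\|_2 + \|b_K\|_2$ with a Cauchy-Schwarz inequality. Under the bounded assumptions on $W_K$, $b_K$, and the input $x_j$, we have $\|k_j\|_2 \le B_w(1 + B_x)$.
- A similar procedure can be followed for bounding $\max_j \|v^h_j\|_2$.

Thus, we have $\|\partial_{q^h_t}\|_2 \le 4\varepsilon \|\partial_{y^h_t}\|_2 \max_j \|v^h_j\|_2 \max_j \|k^h_j\|_2 \le 4\varepsilon B_w^2 (1 + B_x)^2 B_y$.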
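As a numerical companion to the bound just derived, the following minimal sketch checks the per-head query gradient of Definition D.2 against a central-difference gradient, and then checks the inequality $\|\partial_{q^h_t}\|_2 \le 4\varepsilon \|\partial_{y^h_t}\|_2 \max_j \|v^h_j\|_2 \max_j \|k^h_j\|_2$ as the head is made harder. It assumes a single head and a single query position with random inputs; the helper names (`grad_q`, `hardness_and_bound`) and the trick of scaling the query to sharpen the softmax are illustrative choices, not part of the TINT codebase.

```python
import math
import random


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def norm(u):
    return math.sqrt(dot(u, u))


def softmax(scores):
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attention(q, keys):
    # a_j = softmax(K q)_j for one head and one query position
    return softmax([dot(k, q) for k in keys])


def grad_q(q, keys, values, g):
    # Per-head formula of Definition D.2:
    # sum_j a_j <g, v_j> (k_j - sum_j' a_j' k_j')
    a = attention(q, keys)
    d = len(q)
    k_bar = [sum(a_j * k[i] for a_j, k in zip(a, keys)) for i in range(d)]
    out = [0.0] * d
    for a_j, k, v in zip(a, keys, values):
        w = a_j * dot(g, v)
        for i in range(d):
            out[i] += w * (k[i] - k_bar[i])
    return out


def loss(q, keys, values, g):
    # <g, y> with y = sum_j a_j v_j; its gradient in q is the quantity above
    a = attention(q, keys)
    return sum(a_j * dot(g, v) for a_j, v in zip(a, values))


def numeric_grad_q(q, keys, values, g, h=1e-6):
    # Central differences, used only as an independent reference
    out = []
    for i in range(len(q)):
        q_plus = list(q)
        q_minus = list(q)
        q_plus[i] += h
        q_minus[i] -= h
        diff = loss(q_plus, keys, values, g) - loss(q_minus, keys, values, g)
        out.append(diff / (2 * h))
    return out


def hardness_and_bound(scale, q, keys, values, g):
    # Scaling the query sharpens the softmax (an assumption made for
    # illustration), so eps = total attention mass off the top token shrinks.
    q_scaled = [scale * x for x in q]
    a = attention(q_scaled, keys)
    t0 = max(range(len(a)), key=a.__getitem__)
    eps = sum(a_j for j, a_j in enumerate(a) if j != t0)
    lhs = norm(grad_q(q_scaled, keys, values, g))
    rhs = 4 * eps * norm(g) * max(norm(v) for v in values) * max(norm(k) for k in keys)
    return eps, lhs, rhs


rng = random.Random(0)
T, d = 6, 4  # hypothetical toy sizes
keys = [[rng.gauss(0, 1) for _ in range(d)] for _ in range(T)]
values = [[rng.gauss(0, 1) for _ in range(d)] for _ in range(T)]
q = [rng.gauss(0, 1) for _ in range(d)]
g = [rng.gauss(0, 1) for _ in range(d)]  # plays the role of the output gradient

analytic = grad_q(q, keys, values, g)
numeric = numeric_grad_q(q, keys, values, g)
max_err = max(abs(a - n) for a, n in zip(analytic, numeric))
rows = [hardness_and_bound(s, q, keys, values, g) for s in (1.0, 4.0, 16.0)]

print("max |analytic - numeric| =", max_err)
for eps, lhs, rhs in rows:
    print("eps =", eps, " ||grad_q|| =", lhs, " bound =", rhs)
```

The check uses $\varepsilon$ equal to the attention mass outside the most-attended token, which is the smallest value for which the head is ε-hard at that position.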
Bounds on $\|\widehat{\partial_{x_t}} - \partial_{x_t}\|_2$: From the definitions of $\widehat{\partial_{x_t}}$ and $\partial_{x_t}$ in Definition D.3, we have

$$\|\widehat{\partial_{x_t}} - \partial_{x_t}\|_2 = \|W_K^\top \partial_{k_t} + W_Q^\top \partial_{q_t}\|_2 \le \|W_K\|_2 \|\partial_{k_t}\|_2 + \|W_Q\|_2 \|\partial_{q_t}\|_2 \le 8\varepsilon B_w^3 (1 + B_x)^2 B_y = O(\varepsilon B_w^3 B_x^2 B_y),$$

where we use a Cauchy-Schwarz inequality in the second step. We use the assumed bounds on $\|W_Q\|_2$, $\|W_K\|_2$, and the computed bounds on $\|\partial_{q_t}\|_2$, $\|\partial_{k_t}\|_2$ in the pre-final step.

E. Layer normalization

Definition E.1. [Layer Normalization] Define a normalization function $f : \mathbb{R}^d \to \mathbb{R}^d$ that performs $f(x) = (x - \mu)/\sigma$, where $\mu$ and $\sigma$ are the mean and standard deviation of $x$, respectively. Then, layer normalization with parameters $\gamma, b \in \mathbb{R}^{D_{\text{aux}}}$ takes as input $x \in \mathbb{R}^{D_{\text{aux}}}$ and outputs $y \in \mathbb{R}^{D_{\text{aux}}}$, which is computed as $z = f(x)$, $y = \gamma \odot z + b$.

Definition E.2. [Exact Gradient for Layer Normalization] Using notations in Definition E.1, given the gradient of the loss w.r.t. the output of the layer normalization $\partial_y$, backpropagation computes $\partial_x$ as

$$\partial_x = \Big(\partial_z - D_{\text{aux}}^{-1} \sum_{i=1}^{D_{\text{aux}}} \partial_{z_i} - \langle \partial_z, z\rangle z\Big)/\sigma, \qquad \partial_z = \gamma \odot \partial_y.$$

Exact backpropagation is expensive because $\langle \partial_z, z\rangle z$ requires using at least two sequential MLPs. We thus approximate it with a first-order Taylor expansion, which is entry-wise close to the true gradient.

Definition E.3. [ϵ-approximate Layer Normalization Gradient] With notations defined above, this layer takes $\partial_y, x \in \mathbb{R}^{D_{\text{aux}}}$ as input and outputs $\widehat{\partial_x} = \frac{1}{\epsilon}\big(f(x + \epsilon \gamma \odot \partial_y) - f(x)\big)$.

In the discussions below, we consider a layer normalization layer in the auxiliary model with parameters $\{\gamma, b\}$ that takes in an input sequence $x_1, \cdots, x_{T_{\text{aux}}}$ and outputs $y_1, \cdots, y_{T_{\text{aux}}}$, with $y_t = \gamma \odot z_t + b$; $z_t = f(x_t)$ for each $t \le T_{\text{aux}}$. Since this involves a token-wise operation, we will present our constructed modules with a general token position $t$ and the prefix tokens $\{v_j\}$. We will use $W_\gamma$ as a diagonal matrix in $\mathbb{R}^{D_{\text{aux}} \times D_{\text{aux}}}$, containing $\gamma$ on its main diagonal.

TINT Layer normalization Forward module

The input embedding to this module $e_t$ will contain $x_t$ in its first $D_{\text{aux}}$ coordinates.
The layer normalization computation can be divided into two sub-operations: (a) application of $f$, and (b) linear computation using $\gamma, b$. We will present a TINT module for each sub-operation.

We can represent the function $f$ using a layer normalization operation itself, with its weight and bias parameters set as 1 and 0 respectively. However, since the relevant input exists only in the first $D_{\text{aux}}$ coordinates, the operation on the first $D_{\text{aux}}$ coordinates needs to be independent of the rest of the coordinates. To do so, we instead use Group normalization (Definition E.6) on $e_t$, with groups of size $D_{\text{aux}}$.

Now, the embedding $e_t$ contains $f(x_t)$ in its first $D_{\text{aux}}$ coordinates. The second sub-operation can then be viewed as a Linear Layer computation, i.e. $y_t = W_\gamma x_t + b$. Hence, we simply stack the rows of $W_\gamma$ and $b$ onto the prefix tokens $\{v_j\}$ and call the TINT Linear Forward module (Appendix C).

Auxiliary's gradient backpropagation through layer normalization

With the definition of layer normalization and the normalization function $f$ in Definition E.1, the auxiliary's backpropagation operation takes in the loss gradient w.r.t. the output ($\partial_y$) and computes the loss gradient w.r.t. the input ($\partial_x$).

Definition E.2. [Exact Gradient for Layer Normalization] Using notations in Definition E.1, given the gradient of the loss w.r.t. the output of the layer normalization $\partial_y$, backpropagation computes $\partial_x$ as

$$\partial_x = \Big(\partial_z - D_{\text{aux}}^{-1} \sum_{i=1}^{D_{\text{aux}}} \partial_{z_i} - \langle \partial_z, z\rangle z\Big)/\sigma, \qquad \partial_z = \gamma \odot \partial_y.$$

Complexity of true backpropagation

The above operation is computation heavy since it involves computing (a) $\partial_z$, (b) $f(\partial_z)$, (c) $\langle \partial_z, z\rangle z$, and (d) multiplying by a factor of $\frac{1}{\sigma}$. $\langle \partial_z, z\rangle z$ in itself will require two MLP layers, following Lemma B.5. In order to reduce the number of layers, we turn to a first-order Taylor expansion for approximating the above operation.

Definition E.3.
[ϵ-approximate Layer Normalization Gradient] With notations defined above, this layer takes $\partial_y, x \in \mathbb{R}^{D_{\text{aux}}}$ as input and outputs $\widehat{\partial_x} = \frac{1}{\epsilon}\big(f(x + \epsilon \gamma \odot \partial_y) - f(x)\big)$.

The following theorem shows that the first-order gradient is a good approximation of the true gradient, and in the limit of $\epsilon$ tending to 0, the approximation error tends to 0 as well.

Theorem E.4. For any $\epsilon > 0$, and a layer normalization layer with parameters $\gamma, b \in \mathbb{R}^{D_{\text{aux}}}$, for an input $x \in \mathbb{R}^{D_{\text{aux}}}$ and gradient $\partial_y \in \mathbb{R}^{D_{\text{aux}}}$,

$$\|\widehat{\partial_x} - \partial_x\|_2 \le O\big(\epsilon D_{\text{aux}}^{3/2} \sigma^{-2} \|\gamma\|_2^2 \|\partial_y\|_2^2\big),$$

where $\sigma$ denotes the standard deviation of $x$. $\partial_x$ and $\widehat{\partial_x}$ have been computed from $x$, $\partial_y$, and $\epsilon$ using Definitions E.2 and E.3.

TINT Layer normalization backpropagation module

The input embeddings $e_t$ contain $\partial_{y_t}$ at each position $t$ in the first $D_{\text{aux}}$ coordinates. Since we further need the input to the auxiliary's layer normalization layer under consideration, we copy $x_t$ from the TINT Layer normalization Forward module at each position $t$ using residual connections. Furthermore, residual connections have been used to copy the contents of the prefix tokens $\{v_j\}$ from the Layer normalization Forward module, which contain $W_\gamma, b$. Recall that for ease of presentation, we use $z_t$ to represent $f(x_t)$.

We set $\epsilon$ as a hyperparameter and return $\widehat{\partial_x}$ as the output of this module. The computation of $\widehat{\partial_x}$ can be divided into two sub-operations: (a) computation of $\partial_{z_t} := \gamma \odot \partial_{y_t}$, and (b) computation of $\frac{1}{\epsilon}\big(f(x_t + \epsilon \partial_{z_t}) - f(x_t)\big)$. We represent each sub-operation as a TINT module.

To compute $\partial_{z_t} := \gamma \odot \partial_{y_t} = W_\gamma \partial_{y_t}$, we can observe that the required operation is identical to backpropagating through a linear layer with parameters $W_\gamma$ and $b$. Hence, we simply call the Linear Backpropagation module on the current embeddings. We use residual connections to retain $x_t$ at each location $t$, and the contents of the prefix tokens $\{v_j\}$.

Now, the embedding $e_t$ contains $\partial_{z_t}$ and $x_t$. In order to backpropagate through $f$, we first use a linear layer to compute $x_t + \epsilon \partial_{z_t}$ and retain $x_t$.
Following the same procedure as the Forward module, we use a Group normalization layer with weight and bias parameters 1 and 0 respectively, to compute $f(x_t + \epsilon \partial_{z_t})$ and $f(x_t)$. Finally, we use a linear layer to compute $\frac{1}{\epsilon}\big(f(x_t + \epsilon \partial_{z_t}) - f(x_t)\big)$.

Auxiliary's Descent update

Finally, the auxiliary's descent operation updates parameters $\gamma, b$ using a batch of inputs $\{x_t\}_{t \le T_{\text{aux}}}$ and the loss gradient w.r.t. the corresponding outputs $\{\partial_{y_t}\}_{t \le T_{\text{aux}}}$.

Definition E.5 (Auxiliary's layer normalization descent). For parameters $\gamma, b \in \mathbb{R}^{D_{\text{aux}}}$, the descent update takes in a batch of inputs $\{x_t \in \mathbb{R}^{D_{\text{aux}}}\}_{t \le T_{\text{aux}}}$ and gradients $\{\partial_{y_t} \in \mathbb{R}^{D_{\text{aux}}}\}_{t \le T_{\text{aux}}}$ and updates the parameters as follows:

$$\gamma \leftarrow \gamma - \eta \sum_{t \le T_{\text{aux}}} \partial_{y_t} \odot z_t; \qquad b \leftarrow b - \eta \sum_{t \le T_{\text{aux}}} \partial_{y_t},$$

where $z_t$ represents $f(x_t)$.

The update of $\gamma$ involves an elementwise multiplication between $\partial_{y_t}$ and $z_t$, which requires an MLP layer (Lemma B.5). With the prefix tokens containing the rows of $W_\gamma$ and $b$, we instead consider the update of $b$ alone with the descent update.

TINT Layer normalization descent module

The input embeddings contain $\partial_{y_t}$ in the first $D_{\text{aux}}$ coordinates. The prefix tokens contain $W_\gamma, b$, which have been copied from the Forward module using residual connections. The update of $b$ is identical to the auxiliary's descent update through a linear layer. Hence, we apply a TINT Linear descent module to the current embeddings, updating only the bias $b$ and switching off the update to $W_\gamma$.

E.1. Additional definitions

We describe the TINT group normalization layer below, which we use in different modules to simulate the auxiliary's layer normalization operations.

Definition E.6 (TINT $D_{\text{aux}}$-Group normalization). Define a normalization function $f : \mathbb{R}^d \to \mathbb{R}^d$ that performs $f(x) = (x - \mu)/\sigma$, where $\mu$ and $\sigma$ are the mean and standard deviation of $x$, respectively. Then, $D_{\text{aux}}$-Group normalization with parameters $\gamma^{\text{TINT}}, b^{\text{TINT}} \in \mathbb{R}^{D_{\text{aux}}}$ takes as input $x \in \mathbb{R}^{D_{\text{sim}}}$ and outputs $y = \text{VECTORIZE}(\{y^h \in \mathbb{R}^{D_{\text{aux}}}\}_{h \le D_{\text{sim}}/D_{\text{aux}}})$, with

$$y^h = \gamma^{\text{TINT}} \odot f(x^h) + b^{\text{TINT}},$$

where $x^h = \text{SPLIT}_{D_{\text{sim}}/D_{\text{aux}}}(x)_h$.

E.2.
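Proof of theorems and gradient definitions

We restate the theorems and definitions before presenting their proofs, for easy reference.

Definition E.2. [Exact Gradient for Layer Normalization] Using notations in Definition E.1, given the gradient of the loss w.r.t. the output of the layer normalization $\partial_y$, backpropagation computes $\partial_x$ as

$$\partial_x = \Big(\partial_z - D_{\text{aux}}^{-1} \sum_{i=1}^{D_{\text{aux}}} \partial_{z_i} - \langle \partial_z, z\rangle z\Big)/\sigma, \qquad \partial_z = \gamma \odot \partial_y.$$

Derivation of gradient in Definition E.2. With the normalization function $f$ and parameters $\gamma, b \in \mathbb{R}^{D_{\text{aux}}}$, recall from Definition E.1 that given an input $x \in \mathbb{R}^{D_{\text{aux}}}$, a layer normalization layer returns $y = \gamma \odot z + b$; $z = f(x)$. Let $\mu$ and $\sigma$ denote the mean and standard deviation of $x$. They can be computed as

$$\mu = \frac{1}{D_{\text{aux}}} \sum_{i=1}^{D_{\text{aux}}} x_i, \qquad \sigma = \sqrt{\frac{1}{D_{\text{aux}}} \sum_{i=1}^{D_{\text{aux}}} (x_i - \mu)^2}.$$

With the chain rule, we can compute $\partial_x$ from $\partial_y$ as follows:

$$\partial_x = \Big(\frac{\partial z}{\partial x}\Big)^\top \partial_z; \qquad \text{with } \partial_z = \Big(\frac{\partial y}{\partial z}\Big)^\top \partial_y. \qquad (21)$$

Since $y = \gamma \odot z + b$, we have $\frac{\partial y}{\partial z} = W_\gamma$, where $W_\gamma$ represents a diagonal matrix with $\gamma$ on the main diagonal. Thus, $\partial_z = W_\gamma \partial_y = \gamma \odot \partial_y$.

With $z = f(x) = \frac{x - \mu}{\sigma}$, we have

$$\frac{\partial z}{\partial x} = \frac{\partial}{\partial x}\Big(\frac{x - \mu}{\sigma}\Big) = \frac{1}{\sigma}\Big(\frac{\partial x}{\partial x} - \mathbf{1}\Big(\frac{\partial \mu}{\partial x}\Big)^\top\Big) - \frac{1}{\sigma^2}(x - \mu)\Big(\frac{\partial \sigma}{\partial x}\Big)^\top = \frac{1}{\sigma}\Big(I - \frac{1}{D_{\text{aux}}} \mathbf{1}\mathbf{1}^\top - z z^\top\Big). \qquad (22)$$

In the final step, we require $\frac{\partial \mu}{\partial x}$ and $\frac{\partial \sigma}{\partial x}$, which are computed as follows.

- $\frac{\partial \mu}{\partial x} \in \mathbb{R}^{D_{\text{aux}}}$, with its $j$th element given by $\frac{\partial \mu}{\partial x_j} = \frac{\partial}{\partial x_j}\Big(\frac{1}{D_{\text{aux}}} \sum_{i=1}^{D_{\text{aux}}} x_i\Big) = \frac{1}{D_{\text{aux}}}$.
- $\frac{\partial \sigma}{\partial x} \in \mathbb{R}^{D_{\text{aux}}}$, with its $j$th element given by

$$\frac{\partial \sigma}{\partial x_j} = \frac{\partial}{\partial x_j} \sqrt{\frac{1}{D_{\text{aux}}} \sum_{i=1}^{D_{\text{aux}}} (x_i - \mu)^2} = \frac{1}{\sqrt{\sum_{i=1}^{D_{\text{aux}}} (x_i - \mu)^2}} \sum_{i=1}^{D_{\text{aux}}} (x_i - \mu) \frac{\partial (x_i - \mu)}{\partial x_j} = \frac{1}{\sqrt{\sum_{i=1}^{D_{\text{aux}}} (x_i - \mu)^2}} \Big((x_j - \mu) - \frac{1}{D_{\text{aux}}} \sum_{i=1}^{D_{\text{aux}}} (x_i - \mu)\Big),$$

where we have re-utilized $\frac{\partial \mu}{\partial x}$ in the pre-final step.

Hence, from Equation (21),

$$\partial_x = \Big(\frac{\partial z}{\partial x}\Big)^\top \partial_z = \frac{1}{\sigma}\Big(I - \frac{1}{D_{\text{aux}}} \mathbf{1}\mathbf{1}^\top - z z^\top\Big) \partial_z = \frac{1}{\sigma}\Big(\partial_z - \frac{1}{D_{\text{aux}}} \langle \mathbf{1}, \partial_z\rangle \mathbf{1} - \langle z, \partial_z\rangle z\Big).$$

We repeat Theorem E.4 for easier reference.

Theorem E.4. For any $\epsilon > 0$, and a layer normalization layer with parameters $\gamma, b \in \mathbb{R}^{D_{\text{aux}}}$, for an input $x \in \mathbb{R}^{D_{\text{aux}}}$ and gradient $\partial_y \in \mathbb{R}^{D_{\text{aux}}}$,

$$\|\widehat{\partial_x} - \partial_x\|_2 \le O\big(\epsilon D_{\text{aux}}^{3/2} \sigma^{-2} \|\gamma\|_2^2 \|\partial_y\|_2^2\big),$$

where $\sigma$ denotes the standard deviation of $x$. $\partial_x$ and $\widehat{\partial_x}$ have been computed from $x$, $\partial_y$, and $\epsilon$ using Definitions E.2 and E.3.

Proof of Theorem E.4. With the normalization function $f$ and parameters $\gamma, b \in \mathbb{R}^{D_{\text{aux}}}$, recall from Definition E.1 that given an input $x \in \mathbb{R}^{D_{\text{aux}}}$, a layer normalization layer returns $y = \gamma \odot z + b$; $z = f(x)$.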
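Let $\mu$ and $\sigma$ denote the mean and standard deviation of $x$. They can be computed as

$$\mu = \frac{1}{D_{\text{aux}}} \sum_{i=1}^{D_{\text{aux}}} x_i, \qquad \sigma = \sqrt{\frac{1}{D_{\text{aux}}} \sum_{i=1}^{D_{\text{aux}}} (x_i - \mu)^2}.$$

We will refer to $\frac{\partial z}{\partial x}$ from Equation (22) and the formulation of $\partial_x$ from Equation (21) for our current proof. To recall, they are

$$\frac{\partial z}{\partial x} = \frac{1}{\sigma}\Big(I - \frac{1}{D_{\text{aux}}} \mathbf{1}\mathbf{1}^\top - z z^\top\Big), \qquad \partial_x = \Big(\frac{\partial z}{\partial x}\Big)^\top \partial_z.$$

Using a second-order Taylor expansion of the normalization function $f$ around $x$, we have

$$f(x + \epsilon \partial_z) = f(x) + \epsilon \frac{\partial f(x)}{\partial x} \partial_z + R_\epsilon,$$

where $R_\epsilon$ denotes the second-order remainder. It is evaluated at intermediate points $x_\theta = x + \theta \partial_z$, with $z_\theta = f(x_\theta)$, and collects terms in $\|\partial_z\|_2^2$, $\frac{1}{D_{\text{aux}}} \langle \mathbf{1}, \partial_z\rangle^2$, and $\langle z_\theta, \partial_z\rangle^2 z_\theta$. Its explicit form follows steps similar to those for computing $\frac{\partial z}{\partial x}$ in Equation (22). We avoid this computation since we only need to make sure that the second-order term is bounded. Furthermore, if $\epsilon \le O\Big(\frac{\sigma}{\sqrt{D_{\text{aux}}}\, \|\partial_z\|_2}\Big)$, we can show that the $\ell_2$-norm of the second-order term can be bounded by $O(\epsilon^2 D_{\text{aux}}^{3/2} \sigma^{-2} \|\partial_z\|_2^2)$. We avoid this computation as well.

Thus, from the above formulation, we have

$$\lim_{\epsilon \to 0} \frac{f(x + \epsilon \partial_z) - f(x)}{\epsilon} = \frac{\partial f(x)}{\partial x} \partial_z = \Big(\frac{\partial f(x)}{\partial x}\Big)^\top \partial_z = \partial_x.$$

The pre-final step follows from Equation (22), where $\frac{\partial f(x)}{\partial x} = \frac{\partial z}{\partial x} = \frac{1}{\sigma}\big(I - \frac{1}{D_{\text{aux}}} \mathbf{1}\mathbf{1}^\top - z z^\top\big)$ can be shown to be symmetric. The final step follows from the gradient formulation in Equation (21). Including the error term, we have the final bound as

$$\Big\|\frac{f(x + \epsilon \partial_z) - f(x)}{\epsilon} - \partial_x\Big\|_2 \le O\big(\epsilon D_{\text{aux}}^{3/2} \sigma^{-2} \|\partial_z\|_2^2\big).$$

Using $\partial_z = \gamma \odot \partial_y$ and a Cauchy-Schwarz inequality gives the final bound.

F. Activation layer

Definition F.1 (Auxiliary activation). For a continuous function $\sigma_{\text{act}} : \mathbb{R} \to \mathbb{R}$, an activation layer takes $x \in \mathbb{R}^{D_{\text{aux}}}$ as input and outputs $y = \sigma_{\text{act}}(x)$ with $y_i = \sigma_{\text{act}}(x_i)$ for all $i \le D_{\text{aux}}$.

In the discussions below, we consider an activation layer in the auxiliary model with activation function $\sigma_{\text{act}}$ that takes in an input sequence $x_1, \cdots, x_{T_{\text{aux}}}$ and outputs $y_1, \cdots, y_{T_{\text{aux}}}$, with $y_t = \sigma_{\text{act}}(x_t)$ for each $t \le T_{\text{aux}}$. Since this involves a token-wise operation, we will present our constructed modules with a general token position $t$. Since no parameters of the auxiliary model are involved in this operation, the prefix tokens $\{v_j\}$ contain 0 in the following modules.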
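The finite-difference pattern of Definition E.3 can be tried numerically. The sketch below is only illustrative: it applies one hypothetical helper, `fd_backprop`, first to the normalization function $f$ and then to an elementwise activation (tanh is chosen here purely as a smooth example), and compares the result with a reference gradient. The reference for $f$ is a central-difference estimate of the gradient of $\langle g, f(x)\rangle$ with respect to $x$, used as a stand-in for exact backpropagation; the comparison relies on the Jacobian of $f$ being symmetric, as noted in the proof of Theorem E.4. All function names and toy sizes are assumptions made for this example.

```python
import math
import random


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def norm(u):
    return math.sqrt(dot(u, u))


def normalize(x):
    # f(x) = (x - mu) / sigma, with sigma the population standard deviation
    d = len(x)
    mu = sum(x) / d
    sigma = math.sqrt(sum((xi - mu) ** 2 for xi in x) / d)
    return [(xi - mu) / sigma for xi in x]


def tanh_layer(x):
    # Elementwise activation, used as a smooth example of sigma_act
    return [math.tanh(xi) for xi in x]


def fd_backprop(fn, x, g, eps):
    # (fn(x + eps * g) - fn(x)) / eps: the one-sided finite difference
    shifted = fn([xi + eps * gi for xi, gi in zip(x, g)])
    base = fn(x)
    return [(s - b) / eps for s, b in zip(shifted, base)]


def vjp_numeric(fn, x, g, h=1e-6):
    # Gradient of <g, fn(x)> in x by central differences (reference only)
    out = []
    for j in range(len(x)):
        x_plus = list(x)
        x_minus = list(x)
        x_plus[j] += h
        x_minus[j] -= h
        out.append((dot(g, fn(x_plus)) - dot(g, fn(x_minus))) / (2 * h))
    return out


def error(approx, reference):
    return norm([a - r for a, r in zip(approx, reference)])


rng = random.Random(0)
d = 8  # hypothetical toy dimension
x = [rng.gauss(0, 1) for _ in range(d)]
g = [rng.gauss(0, 1) for _ in range(d)]  # plays the role of the incoming gradient

eps_values = (1e-2, 1e-3, 1e-4)

# Normalization: the error should shrink roughly linearly in eps
reference_norm = vjp_numeric(normalize, x, g)
norm_errors = [error(fd_backprop(normalize, x, g, e), reference_norm) for e in eps_values]

# Activation: the exact gradient is sigma_act'(x) * g with tanh' = 1 - tanh^2
reference_act = [(1.0 - math.tanh(xi) ** 2) * gi for xi, gi in zip(x, g)]
act_errors = [error(fd_backprop(tanh_layer, x, g, e), reference_act) for e in eps_values]

for e, en, ea in zip(eps_values, norm_errors, act_errors):
    print("eps =", e, " normalization error =", en, " activation error =", ea)
```

In both cases the approximation needs only two forward evaluations and one subtraction, which is the reason the construction prefers it over an explicit elementwise product.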
TINT Activation Forward module

The embedding $e_t$ contains $x_t$ in its first $D_{\text{aux}}$ indices. We simply pass the embeddings into the activation $\sigma_{\text{act}}$, which returns $\sigma_{\text{act}}(x_t)$ in its first $D_{\text{aux}}$ indices.

Auxiliary's backpropagation through activation

With the definition in Definition F.1, the auxiliary's backpropagation takes in the loss gradient w.r.t. the output ($\partial_y$) and computes the loss gradient w.r.t. the input ($\partial_x$). We further assume that the derivative of $\sigma_{\text{act}}$ is well-defined everywhere. This assumption includes non-differentiable activation functions with well-defined derivatives like ReLU.

Definition F.2 (Auxiliary activation backpropagation). For a continuous function $\sigma_{\text{act}} : \mathbb{R} \to \mathbb{R}$, with a well-defined derivative $\sigma'_{\text{act}}(x) = \partial \sigma_{\text{act}}(x)/\partial x$ for each $x \in \mathbb{R}$, the backpropagation takes $\partial_y, x \in \mathbb{R}^{D_{\text{aux}}}$ as input and outputs

$$\partial_x = \sigma'_{\text{act}}(x) \odot \partial_y,$$

where $\sigma'_{\text{act}}(x) \in \mathbb{R}^{D_{\text{aux}}}$ with $\sigma'_{\text{act}}(x)_i = \sigma'_{\text{act}}(x_i)$ at each $i \le D_{\text{aux}}$.

Complexity of true backpropagation

The above operation is computation heavy since it involves $\sigma'_{\text{act}}(x) \odot \partial_y$. As mentioned for the layer normalization module, the element-wise multiplication between $\sigma'_{\text{act}}(x)$ and $\partial_y$ will require an MLP module following Lemma B.5. Furthermore, it involves changing the activation function in TINT in specific modules to $\sigma'_{\text{act}}$. To circumvent this, we instead turn to a first-order Taylor approximation.

Definition F.3 (Approximate Activation backpropagation). For a continuous function $\sigma_{\text{act}} : \mathbb{R} \to \mathbb{R}$ and a hyperparameter $\epsilon$, the layer takes $\partial_y, x \in \mathbb{R}^{D_{\text{aux}}}$ as input and outputs

$$\widehat{\partial_x} = \frac{1}{\epsilon}\big(\sigma_{\text{act}}(x + \epsilon \partial_y) - \sigma_{\text{act}}(x)\big).$$

The following theorems show that under mild assumptions on the activation function and the input, gradient pair, the first-order gradient is a good approximation to the true gradient.

Theorem F.4. For any $\epsilon > 0$, $B_y, B_{\text{act}} > 0$, consider a second-order differentiable activation function $\sigma_{\text{act}} : \mathbb{R} \to \mathbb{R}$, with $\partial^2 \sigma_{\text{act}}(x)/\partial (x^2)$ bounded by $B_{\text{act}}$ for each $x \in \mathbb{R}$.
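Then, for any input $x \in \mathbb{R}^{D_{\text{aux}}}$ and gradient $\partial_y \in \mathbb{R}^{D_{\text{aux}}}$ with $\|\partial_y\|_2 \le B_y$, the following holds true:

$$\|\partial_x - \widehat{\partial_x}\|_2 \le O(B_{\text{act}} B_y^2 \epsilon),$$

where $\partial_x$ and $\widehat{\partial_x}$ have been defined using $x$, $\partial_y$, and $\epsilon$ in Definitions F.2 and F.3.

For ReLU activation, which is not second-order differentiable at 0, we instead bound the difference between $\partial_x$ and $\widehat{\partial_x}$ by defining a form of alignment between the input and gradient pair $x, \partial_y$.

Definition F.5 ((ϵ, ρ)-alignment). Input and gradient $x, \partial_y \in \mathbb{R}^{D_{\text{aux}}}$ are said to be (ϵ, ρ)-aligned if there exists a set $C \subseteq [D_{\text{aux}}]$, with $|C| \ge (1 - \rho) D_{\text{aux}}$, such that for each $i$ in $C$, $|x_i| > \epsilon |(\partial_y)_i|$.

$\epsilon$ controls the fraction of coordinates where $|x_i| \le \epsilon |(\partial_y)_i|$. As $\epsilon \to 0$, $\rho \to 0$ as well for bounded gradients.

Example F.6. For any $B_{\min}, B_{\max} > 0$, all inputs $x$ that satisfy $\min_i |x_i| > B_{\min}$, and gradients $\partial_y$ that satisfy $\max_j |(\partial_y)_j| \le B_{\max}$, are $(B_{\min}/B_{\max}, 0)$-aligned.

Theorem F.7. For any $\epsilon, \rho > 0$ and $B_y > 0$, for any input $x \in \mathbb{R}^{D_{\text{aux}}}$ and gradient $\partial_y \in \mathbb{R}^{D_{\text{aux}}}$, with $\|\partial_y\|_\infty \le B_y$, that are (ϵ, ρ)-aligned by Definition F.5,

$$\|\partial_x - \widehat{\partial_x}\|_2 \le O\big(B_y \sqrt{\rho D_{\text{aux}}}\big),$$

where $\partial_x$ and $\widehat{\partial_x}$ have been defined using $x$, $\partial_y$, $\epsilon$, and $\sigma_{\text{act}} = \text{ReLU}$ in Definitions F.2 and F.3.

TINT Activation backpropagation module

The input embeddings contain $\partial_{y_t}$ in the first $D_{\text{aux}}$ coordinates. Since the gradient computation requires the input to the activation layer, we copy $x_t$ from the Forward module at each position $t$. We set $\epsilon$ as a hyperparameter and return $\widehat{\partial_{x_t}}$ as the output of this module.

$\widehat{\partial_{x_t}}$ will be computed using a single-layer MLP with activation $\sigma_{\text{act}}$ as follows. The first linear layer of the MLP will be used to compute $x_t + \epsilon \partial_{y_t}$ and $x_t$. After the activation $\sigma_{\text{act}}$, the embedding $e_t$ contains $\sigma_{\text{act}}(x_t + \epsilon \partial_{y_t})$ and $\sigma_{\text{act}}(x_t)$. The final linear layer of the MLP will be used to compute $\frac{1}{\epsilon}\big(\sigma_{\text{act}}(x_t + \epsilon \partial_{y_t}) - \sigma_{\text{act}}(x_t)\big)$.

F.1. Proofs of theorems

We restate the theorems before presenting their proofs, for easy reference.

Theorem F.4. For any $\epsilon > 0$, $B_y, B_{\text{act}} > 0$, consider a second-order differentiable activation function $\sigma_{\text{act}} : \mathbb{R} \to \mathbb{R}$, with $\partial^2 \sigma_{\text{act}}(x)/\partial (x^2)$ bounded by $B_{\text{act}}$ for each $x \in \mathbb{R}$.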
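Then, for any input $x \in \mathbb{R}^{D_{\text{aux}}}$ and gradient $\partial_y \in \mathbb{R}^{D_{\text{aux}}}$ with $\|\partial_y\|_2 \le B_y$, the following holds true:

$$\|\partial_x - \widehat{\partial_x}\|_2 \le O(B_{\text{act}} B_y^2 \epsilon),$$

where $\partial_x$ and $\widehat{\partial_x}$ have been defined using $x$, $\partial_y$, and $\epsilon$ in Definitions F.2 and F.3.

Proof. The proof follows along the lines of Theorem E.4. Recall that given an input $x$, the activation layer outputs $y = \sigma_{\text{act}}(x)$, where the function $\sigma_{\text{act}}$ is applied coordinate-wise on $x$. Given input $x$ and the output gradient $\partial_y$, the gradient w.r.t. the input is given by $\partial_x = \sigma'_{\text{act}}(x) \odot \partial_y$, where the $\sigma'_{\text{act}}$ function is also applied coordinate-wise to $x$. We defined $\widehat{\partial_x}$ as an ϵ-approximate gradient, given by $\frac{1}{\epsilon}\big(\sigma_{\text{act}}(x + \epsilon \partial_y) - \sigma_{\text{act}}(x)\big)$. Since both $\sigma_{\text{act}}$ and $\sigma'_{\text{act}}$ are applied coordinate-wise, we can look at the coordinate-wise difference between $\partial_x$ and $\widehat{\partial_x}$.

Consider an arbitrary coordinate $i \le D_{\text{aux}}$. Under the assumption that $\sigma_{\text{act}}$ is second-order differentiable, we have

$$\begin{aligned}
\frac{1}{\epsilon}\big(\sigma_{\text{act}}(x_i + \epsilon (\partial_y)_i) - \sigma_{\text{act}}(x_i)\big) &= \sigma'_{\text{act}}(x_i)(\partial_y)_i + \frac{1}{\epsilon} \int_{\theta=0}^{\epsilon} \frac{\partial^2 \sigma_{\text{act}}(x_\theta)}{\partial x_\theta^2} (\partial_y)_i^2\, \theta\, d\theta \\
&= \sigma'_{\text{act}}(x_i)(\partial_y)_i + O(\epsilon B_{\text{act}} (\partial_y)_i^2),
\end{aligned}$$

where $x_\theta$ represents $x_i + \theta (\partial_y)_i$ in the second step. In the final step, we utilize the upper bound assumption on $\frac{\partial^2 \sigma_{\text{act}}(x)}{\partial x^2}$.

Thus, $(\partial_x)_i - (\widehat{\partial_x})_i = O(\epsilon B_{\text{act}} (\partial_y)_i^2)$, and so

$$\|\partial_x - \widehat{\partial_x}\|_2 = O\Big(\epsilon B_{\text{act}} \sum_{i=1}^{D_{\text{aux}}} (\partial_y)_i^2\Big) = O(\epsilon B_{\text{act}} \|\partial_y\|_2^2) \le O(\epsilon B_{\text{act}} B_y^2).$$

Example F.6. For any $B_{\min}, B_{\max} > 0$, all inputs $x$ that satisfy $\min_i |x_i| > B_{\min}$, and gradients $\partial_y$ that satisfy $\max_j |(\partial_y)_j| \le B_{\max}$, are $(B_{\min}/B_{\max}, 0)$-aligned.

Proof. Recall the definition of (ϵ, ρ)-alignment from Definition F.5. Input and gradient $x, \partial_y \in \mathbb{R}^{D_{\text{aux}}}$ are said to be (ϵ, ρ)-aligned if there exists a set $C \subseteq [D_{\text{aux}}]$, with $|C| \ge (1 - \rho) D_{\text{aux}}$, such that for each $i$ in $C$, $|x_i| > \epsilon |(\partial_y)_i|$.

Consider an arbitrary coordinate $i \le D_{\text{aux}}$. We have $|x_i| > \epsilon |(\partial_y)_i|$ for any $\epsilon < |x_i| / |(\partial_y)_i|$. Under the assumption that $|x_i| > B_{\min}$ and $|(\partial_y)_i| \le B_{\max}$, a bound of $B_{\min}/B_{\max}$ suffices.

Theorem F.7.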
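For any $\epsilon, \rho > 0$ and $B_y > 0$, for any input $x \in \mathbb{R}^{D_{\text{aux}}}$ and gradient $\partial_y \in \mathbb{R}^{D_{\text{aux}}}$, with $\|\partial_y\|_\infty \le B_y$, that are (ϵ, ρ)-aligned by Definition F.5,

$$\|\partial_x - \widehat{\partial_x}\|_2 \le O\big(B_y \sqrt{\rho D_{\text{aux}}}\big),$$

where $\partial_x$ and $\widehat{\partial_x}$ have been defined using $x$, $\partial_y$, $\epsilon$, and $\sigma_{\text{act}} = \text{ReLU}$ in Definitions F.2 and F.3.

Proof. Recall that given an input $x$, the activation layer outputs $y = \sigma_{\text{act}}(x)$, where the function $\sigma_{\text{act}}$ is applied coordinate-wise on $x$. Given input $x$ and the output gradient $\partial_y$, the gradient w.r.t. the input is given by $\partial_x = \sigma'_{\text{act}}(x) \odot \partial_y$, where the $\sigma'_{\text{act}}$ function is also applied coordinate-wise to $x$. We defined $\widehat{\partial_x}$ as an ϵ-approximate gradient, given by $\frac{1}{\epsilon}\big(\sigma_{\text{act}}(x + \epsilon \partial_y) - \sigma_{\text{act}}(x)\big)$. Since both $\sigma_{\text{act}}$ and $\sigma'_{\text{act}}$ are applied coordinate-wise, we can look at the coordinate-wise difference between $\partial_x$ and $\widehat{\partial_x}$.

For ReLU activation, $\sigma'_{\text{act}}(x) = \text{sign}(x)$ for all $x \in \mathbb{R} \setminus \{0\}$, with $\sigma'_{\text{act}}(0) = 1$ to avoid ambiguity.

Going by the definition of (ϵ, ρ)-alignment of the input and gradient from Definition F.5, we have a set $C$ with $|C| \ge (1 - \rho) D_{\text{aux}}$ such that for each $i \in C$, $|x_i| > \epsilon |(\partial_y)_i|$. For all coordinates $i \in C$, we can then observe that $\text{sign}(x_i + \epsilon (\partial_y)_i) = \text{sign}(x_i)$, implying

$$\sigma_{\text{act}}(x_i + \epsilon (\partial_y)_i) - \sigma_{\text{act}}(x_i) = \epsilon (\partial_y)_i\, \sigma'_{\text{act}}(x_i) = \epsilon (\partial_x)_i.$$

For coordinates $i \notin C$, we have three possible cases:

- $\text{sign}(x_i) = \text{sign}(x_i + \epsilon (\partial_y)_i)$: In this case, we can again show $\sigma_{\text{act}}(x_i + \epsilon (\partial_y)_i) - \sigma_{\text{act}}(x_i) = \epsilon (\partial_y)_i\, \sigma'_{\text{act}}(x_i) = \epsilon (\partial_x)_i$.
- $\text{sign}(x_i) = 0$, $\text{sign}(x_i + \epsilon (\partial_y)_i) = 1$: In this case, we have $\sigma'_{\text{act}}(x_i) = 0$, and so $(\partial_x)_i = 0$. Additionally, $\text{sign}((\partial_y)_i) = 1$, and so
$$|\sigma_{\text{act}}(x_i + \epsilon (\partial_y)_i) - \sigma_{\text{act}}(x_i) - \epsilon (\partial_x)_i| = |x_i + \epsilon (\partial_y)_i| \le \epsilon |(\partial_y)_i|,$$
where in the final step, we use the fact that $x_i < 0$ and $|x_i| < \epsilon |(\partial_y)_i|$.
- $\text{sign}(x_i) = 1$, $\text{sign}(x_i + \epsilon (\partial_y)_i) = 0$: In this case, we have $\sigma'_{\text{act}}(x_i) = 1$, and so $(\partial_x)_i = (\partial_y)_i$. Additionally, $\text{sign}((\partial_y)_i) = 0$, and so
$$|\sigma_{\text{act}}(x_i + \epsilon (\partial_y)_i) - \sigma_{\text{act}}(x_i) - \epsilon (\partial_x)_i| = |-x_i - \epsilon (\partial_y)_i| \le |\epsilon (\partial_y)_i|,$$
where in the final step, we use the fact that $x_i \ge 0$ and $|x_i| < \epsilon |(\partial_y)_i|$.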
Thus, from the above discussion, we have

$$\begin{aligned}
\|\partial_x - \widehat{\partial_x}\|_2 &= \frac{1}{\epsilon}\Bigg(\sum_{i=1}^{D_{\text{aux}}} \big(\sigma_{\text{act}}(x_i + \epsilon (\partial_y)_i) - \sigma_{\text{act}}(x_i) - \epsilon (\partial_x)_i\big)^2\Bigg)^{1/2} \\
&= \frac{1}{\epsilon}\Bigg(\sum_{i \notin C} \big(\sigma_{\text{act}}(x_i + \epsilon (\partial_y)_i) - \sigma_{\text{act}}(x_i) - \epsilon (\partial_x)_i\big)^2\Bigg)^{1/2} \\
&\le \Bigg(\sum_{i \notin C} (\partial_y)_i^2\Bigg)^{1/2} \le \sqrt{\rho D_{\text{aux}}}\, \max_{i \notin C} |(\partial_y)_i| \le \sqrt{\rho D_{\text{aux}}}\, B_y.
\end{aligned}$$

The final step includes a simple Cauchy-Schwarz inequality, and the desired bound comes from the assumed bound on $\partial_y$.

G. Language model head

Additionally, we provide a description of the gradient computation for the loss function that involves the language model head. This computation entails performing a softmax operation over the entire vocabulary. If $\mathcal{V}$ denotes the vocabulary set of the auxiliary model, and $E \in \mathbb{R}^{|\mathcal{V}| \times D_{\text{aux}}}$ denotes the embedding matrix of the auxiliary model, we directly utilize the embedding matrix for the auto-regressive loss in TINT. Additionally, we do not update the embedding matrix of the auxiliary model; instead, we solely backpropagate the gradients through the language model head. Recent work (Kumar et al., 2022) has shown that keeping the embedding matrix fixed while updating the model can stabilize SGD. We demonstrate that the backpropagated gradients can be expressed as the combination of the language model head and a self-attention layer.

Definition G.1 (KL-loss gradient through auxiliary's language model head). Given an embedding matrix $E \in \mathbb{R}^{|\mathcal{V}| \times D_{\text{aux}}}$, the language model head takes in input $x \in \mathbb{R}^{D_{\text{aux}}}$ and a target distribution $q \in \mathbb{R}^{|\mathcal{V}|}$ and returns gradient $\partial_x \in \mathbb{R}^{D_{\text{aux}}}$, with $\partial_x = E^\top(\text{softmax}(Ex) - q)$.

In the autoregressive loss on a sequence of tokens, the target output distribution at any position is the next occurring token. If $\{x^{\text{un}}_t\}_{t=1}^{T_{\text{aux}}}$ denote the uncontextualized embeddings of a sequence of tokens after encoding them via the embedding matrix, and $\{x_t\}_{t=1}^{T_{\text{aux}}}$ denote their contextualized embeddings after passing through the auxiliary model, then the gradient $\partial_{x_t}$ at any position $t$ can be simplified as $E^\top \text{softmax}(E x_t) - x^{\text{un}}_{t+1}$. We illustrate the involved TINT module w.r.t. an arbitrary position $t$.
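The gradient in Definition G.1 can be checked on a toy example. The sketch below assumes a small random embedding matrix and a one-hot target (the next token), compares $E^\top(\text{softmax}(Ex) - q)$ with a central-difference gradient of the cross-entropy loss, which has the same gradient in $x$ as the KL loss, and confirms that for a one-hot target the expression reduces to $E^\top \text{softmax}(Ex)$ minus the embedding row of the target token. The names `lm_head_grad` and `cross_entropy` and the toy sizes are hypothetical.

```python
import math
import random


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def softmax(scores):
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def lm_head_grad(E, x, q):
    # Definition G.1: E^T (softmax(E x) - q)
    p = softmax([dot(row, x) for row in E])
    return [sum((p[v] - q[v]) * E[v][i] for v in range(len(E))) for i in range(len(x))]


def cross_entropy(E, x, q):
    # -sum_v q_v log softmax(E x)_v; differs from the KL loss by a constant in x
    logits = [dot(row, x) for row in E]
    m = max(logits)
    log_z = m + math.log(sum(math.exp(l - m) for l in logits))
    return -sum(q_v * (l - log_z) for q_v, l in zip(q, logits))


def numeric_grad(E, x, q, h=1e-6):
    # Central differences, used only as an independent reference
    out = []
    for i in range(len(x)):
        x_plus = list(x)
        x_minus = list(x)
        x_plus[i] += h
        x_minus[i] -= h
        out.append((cross_entropy(E, x_plus, q) - cross_entropy(E, x_minus, q)) / (2 * h))
    return out


rng = random.Random(0)
vocab_size, d = 10, 4  # hypothetical toy sizes
E = [[rng.gauss(0, 1) for _ in range(d)] for _ in range(vocab_size)]
x = [rng.gauss(0, 1) for _ in range(d)]  # contextualized embedding x_t
next_token = 3  # index of the next token in the sequence
q = [1.0 if v == next_token else 0.0 for v in range(vocab_size)]

analytic = lm_head_grad(E, x, q)
numeric = numeric_grad(E, x, q)
max_err = max(abs(a - n) for a, n in zip(analytic, numeric))

# One-hot simplification: E^T softmax(E x) - (embedding row of the next token)
p = softmax([dot(row, x) for row in E])
y = [sum(p[v] * E[v][i] for v in range(vocab_size)) for i in range(d)]
simplified = [y_i - e_i for y_i, e_i in zip(y, E[next_token])]
simplification_gap = max(abs(a - s) for a, s in zip(analytic, simplified))

print("max |analytic - numeric| =", max_err)
print("max |analytic - simplified| =", simplification_gap)
```

Here `E[next_token]` plays the role of the uncontextualized embedding $x^{\text{un}}_{t+1}$, since encoding a token through the embedding matrix selects its row.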
TINT autoregressive loss gradient module

The current embedding $e_t$ contains the contextualized embedding $x_t$ in its first $D_{\text{aux}}$ coordinates. Furthermore, $e_t$ includes the uncontextualized embedding $x^{\text{un}}_t$, copied from the input layer using residual connections. The prefix tokens $v_j$ are assigned a value of 0 and do not participate in the subsequent computations.

The loss computation can be decomposed into two sub-operations: (a) computing $y_t := E^\top \text{softmax}(E x_t)$, and (b) calculating $\partial_{x_t} = y_t - x^{\text{un}}_{t+1}$.

For the first sub-operation, we use a feed-forward layer with softmax activation, with hidden and output weights $E$ and $E^\top$ respectively, that takes in the first $D_{\text{aux}}$ coordinates of $e_t$ and returns $y_t$ in the first $D_{\text{aux}}$ coordinates. We retain $x^{\text{un}}_t$ using a residual connection.

The final sub-operation can be interpreted as a TINT self-attention layer. With $e_t$ containing both $y_t$ and $x^{\text{un}}_t$, we use a linear self-attention layer (Definition B.2) with two attention heads. The first attention head assigns an attention score of 1 to pairs $\{(t, t+1)\}_{t \le T_{\text{aux}} - 1}$, while assigning an attention score of 0 to the remaining pairs. At any position $t$, $x^{\text{un}}_t$ is considered the value vector. The second attention head assigns an attention score of 1 to pairs $\{(t, t)\}_{t \le T_{\text{aux}}}$, while assigning an attention score of 0 to the remaining pairs. At any position $t$, $y_t$ is considered the value vector. The outputs of both attention heads are subsequently combined using a linear layer.

Remark G.2. We conducted experiments using mean-squared loss and Quad loss (Saunshi et al., 2020), which do not necessitate softmax computations for gradient computation. As an example, in the case of mean-squared loss, if our objective is to minimize $\frac{1}{2} \sum_{t=1}^{T} \|x_t - x^{\text{un}}_{t+1}\|^2$, the gradient can be computed as $\partial_{x_t} = x_t - x^{\text{un}}_{t+1}$. Similarly, in the case of Quad loss, the gradient is $\partial_{x_t} = \frac{1}{|\mathcal{V}|} \sum_i e_i - x^{\text{un}}_{t+1}$.
However, in all of our language model experiments (Section 5), both gradients resulted in minimal improvement in perplexity compared to the auxiliary model. Therefore, we continue utilizing the standard KL loss for optimization.

Remark G.3. For ease of implementation in the codebase, we utilize a dedicated loss module that takes in $y_t, x^{\text{un}}_{t+1}$ as input and directly computes $\partial x_t = y_t - x^{\text{un}}_{t+1}$.

H. Parameter sharing

Feed-forward layer of auxiliary model: In a standard auxiliary transformer, like GPT-2, the feed-forward layer is a token-wise operation that takes in an input $x \in \mathbb{R}^{D_{\text{aux}}}$ and returns $y = A \sigma(W x)$, with $A \in \mathbb{R}^{D_{\text{aux}} \times 4 D_{\text{aux}}}$ and $W \in \mathbb{R}^{4 D_{\text{aux}} \times D_{\text{aux}}}$. A naive construction of TINT to simulate its forward operation will have 2 Linear Forward modules (Section 3), separated by an activation. However, this requires $4\times$ more prefix embeddings to represent the parameters, compared to other linear operations in the auxiliary transformer that use $\mathbb{R}^{D_{\text{aux}} \times D_{\text{aux}}}$ weight parameters.

To avoid this, we can instead break down the computation into 4 sub-feed-forward layers, each with its own parameters $\{W^i, A^i\}_{1 \le i \le 4}$. Here $\{W^i\}_{1 \le i \le 4}$ represent 4 shards of the rows of $W$, and $\{A^i\}_{1 \le i \le 4}$ represent 4 shards of the columns of $A$. The forward, backward, and descent operations on these 4 sub-feed-forward layers can be effectively parallelized. For example, the forward operation of each layer can be simulated by a single TINT module, consisting of two Linear Forward modules and an activation, changing only the prefix embeddings to correspond to $\{W^i, A^i\}_{1 \le i \le 4}$.

I. Additional modules

We describe the forward, backward, and descent update operations of additional modules used in different model families, like LLaMA (Touvron et al., 2023) and BLOOM (Scao et al., 2022). We discuss the simulation of these modules using similar TINT modules.

I.1. Root mean square normalization (RMSnorm)

The operation of RMSnorm (Zhang and Sennrich, 2019) is very similar to layer normalization.
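As a concrete companion to this subsection, here is a minimal sketch of RMSnorm and of its grouped variant, following the formulas as stated in Definitions I.1 and I.2: $f(x) = x / \mathrm{RMS}(x)$ with $\mathrm{RMS}(x) = (\sum_i x_i^2)^{1/2}$, applied either to a whole $D_{\text{aux}}$-dimensional vector or independently to each $D_{\text{aux}}$-sized chunk of a $D_{\text{sim}}$-dimensional vector with shared $\gamma$ and $b$. The function names and toy values are assumptions for illustration. Note that common library implementations of RMSnorm divide by the root of the mean of squares and add a small stabilizing constant; the sketch deliberately follows the formula written in this appendix instead.

```python
import math

def rms_normalize(x):
    # f(x) = x / RMS(x), with RMS(x) = sqrt(sum_i x_i^2) as written in Definition I.1.
    rms = math.sqrt(sum(v * v for v in x))
    return [v / rms for v in x]

def rmsnorm(x, gamma, b):
    # y = gamma (elementwise) f(x) + b.
    z = rms_normalize(x)
    return [g * zi + bi for g, zi, bi in zip(gamma, z, b)]

def group_rmsnorm(x, gamma, b, d_aux):
    # Split a D_sim-dimensional vector into chunks of size d_aux, normalize each
    # chunk independently with the same (gamma, b), then concatenate the results.
    assert len(x) % d_aux == 0, "D_sim must be a multiple of D_aux in this sketch"
    out = []
    for start in range(0, len(x), d_aux):
        out.extend(rmsnorm(x[start:start + d_aux], gamma, b))
    return out

# Toy values (illustrative only): D_aux = 2, D_sim = 4.
gamma = [2.0, 1.0]
b = [0.1, 0.0]
single = rmsnorm([3.0, 4.0], gamma, b)
grouped = group_rmsnorm([3.0, 4.0, 0.0, 5.0], gamma, b, d_aux=2)
print(single)
print(grouped)
```

Each chunk is normalized using only its own coordinates, so the first half of `grouped` matches `single` regardless of what the second chunk contains.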
Definition I.1 (RMSnorm). For an arbitrary dimension $d$, define a normalization function $f : \mathbb{R}^d \to \mathbb{R}^d$ that performs $f(x) = x / \mathrm{RMS}(x)$, where $\mathrm{RMS}(x) = (\sum_{i=1}^{d} x_i^2)^{1/2}$. Then, RMSnorm with parameters $\gamma, b \in \mathbb{R}^{D_{\text{aux}}}$ takes as input $x \in \mathbb{R}^{D_{\text{aux}}}$ and outputs $y \in \mathbb{R}^{D_{\text{aux}}}$, which is computed as $z = f(x)$, $y = \gamma \odot z + b$.

The close similarity between RMSnorm and layer normalization (Definition E.1) lets us create TINT modules similar to those described in Appendix E, where instead of Group normalization layers, we use the Group RMSnorm layers described below.

Definition I.2 (TINT $D_{\text{aux}}$-Group RMSnorm). For an arbitrary dimension $d$, define a normalization function $f : \mathbb{R}^d \to \mathbb{R}^d$ that performs $f(x) = x / \mathrm{RMS}(x)$, where $\mathrm{RMS}(x) = (\sum_{i=1}^{d} x_i^2)^{1/2}$. Then, $D_{\text{aux}}$-Group RMSnorm with parameters $\gamma^{\text{TINT}}, b^{\text{TINT}} \in \mathbb{R}^{D_{\text{aux}}}$ takes as input $x \in \mathbb{R}^{D_{\text{sim}}}$ and outputs $y = \mathrm{VECTORIZE}(\{y^h \in \mathbb{R}^{D_{\text{aux}}}\}_{h \le D_{\text{sim}} / D_{\text{aux}}})$, with $y^h = \gamma^{\text{TINT}} \odot f(x^h) + b^{\text{TINT}}$, where $x^h = \mathrm{SPLIT}_{D_{\text{sim}} / D_{\text{aux}}}(x)_h$.

I.2. Attention variants

In order to incorporate additional attention variants, e.g., Attention with Linear Biases (ALiBi) (Press et al., 2021) and rotary position embeddings (Su et al., 2021), we can change the definition of the softmax attention layer in Definition B.2 accordingly. We showcase the changes for ALiBi.

Definition I.3 (Auxiliary ALiBi self-attention with $H_{\text{aux}}$ heads). For query, key, and value weights $W_Q, W_K, W_V \in \mathbb{R}^{D_{\text{aux}} \times D_{\text{aux}}}$, biases $b_Q, b_K, b_V \in \mathbb{R}^{D_{\text{aux}}}$ and $m \in \mathbb{R}^{H_{\text{aux}}}$, an ALiBi self-attention layer with $H_{\text{aux}}$ attention heads and a function $f_{\text{attn}} : \mathbb{R}^{T_{\text{aux}}} \to \mathbb{R}^{T_{\text{aux}}}$ takes a sequence $\{x_t \in \mathbb{R}^{D_{\text{aux}}}\}_{t \le T_{\text{aux}}}$ as input and outputs $\{y_t\}_{t \le T_{\text{aux}}}$, with

$$y_t = \mathrm{VECTORIZE}\Big(\Big\{ \sum_{j \le T_{\text{aux}}} a^h_{t,j} v^h_j \Big\}_{h \le H_{\text{aux}}}\Big). \quad (23)$$

$a^h_{t,j}$ is defined as the attention score of head $h$ between tokens at positions $t$ and $j$, and is given by

$$a^h_{t,j} = \mathrm{softmax}(K^h q^h_t + m_h r_t)_j. \quad (24)$$

Here $r_t \in \mathbb{R}^{T_{\text{aux}}}$ denotes a relative position vector at each position $t$ that contains $(j - t)$ at each coordinate $j \le T_{\text{aux}}$.
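Here, $q_t, k_t, v_t$ denote the query, key, and value vectors at each position $t$, computed as $W_Q x_t + b_Q$, $W_K x_t + b_K$, and $W_V x_t + b_V$ respectively. In addition, $q^h_t, k^h_t, v^h_t$ denote $\mathrm{SPLIT}_{H_{\text{aux}}}(q_t)_h$, $\mathrm{SPLIT}_{H_{\text{aux}}}(k_t)_h$, and $\mathrm{SPLIT}_{H_{\text{aux}}}(v_t)_h$ respectively for all $t \le T_{\text{aux}}$ and $h \le H_{\text{aux}}$. $K^h \in \mathbb{R}^{T_{\text{aux}} \times D_{\text{aux}}}$ is defined with its rows as $\{k^h_t\}_{t \le T_{\text{aux}}}$ for all $h \le H_{\text{aux}}$.

To include operations involving ALiBi, we modify the self-attention module of TINT to change the definition of the attention scores as in Equation (24).

Definition I.4 (Modified TINT self-attention for ALiBi with $H_{\text{sim}}$ heads). For parameters $\{W^{\text{TINT}}_Q, W^{\text{TINT}}_K, W^{\text{TINT}}_V \in \mathbb{R}^{D_{\text{sim}} \times D_{\text{sim}}}\}$, $\{b^{\text{TINT}}_Q, b^{\text{TINT}}_K, b^{\text{TINT}}_V \in \mathbb{R}^{D_{\text{sim}}}\}$, $\{W^p_Q, W^p_K, W^p_V \in \mathbb{R}^{T_{\text{sim}} \times D_{\text{sim}} / H_{\text{sim}}}\}$, $\{\lambda^Q, \lambda^K, \lambda^V \in \mathbb{R}^{H_{\text{sim}}}\}$ and $m^{\text{TINT}} \in \mathbb{R}^{T_{\text{sim}}}$, TINT self-attention with $H_{\text{sim}}$ attention heads and a function $f_{\text{attn}} : \mathbb{R}^{T_{\text{sim}}} \to \mathbb{R}^{T_{\text{sim}}}$ takes a sequence $\{\hat{e}_t \in \mathbb{R}^{D_{\text{sim}}}\}_{t \le T_{\text{sim}}}$ as input and outputs $\{e_t \in \mathbb{R}^{D_{\text{sim}}}\}_{t \le T_{\text{sim}}}$, with

$$e_t = \mathrm{VECTORIZE}\Big(\Big\{ \sum_{j \le T_{\text{sim}}} a^h_{t,j} \tilde{v}^h_j \Big\}_{h \le H_{\text{sim}}}\Big), \quad \text{with } a^h_{t,j} = f_{\text{attn}}(\widetilde{K}^h \tilde{q}^h_t + m^{\text{TINT}}_h r_t)_j,$$

$$\tilde{q}^h_t = \mathrm{SPLIT}_H(q_t)_h + \lambda^Q_h W^p_Q p^{\text{TINT}}_t; \quad \tilde{k}^h_t = \mathrm{SPLIT}_H(k_t)_h + \lambda^K_h W^p_K p^{\text{TINT}}_t; \quad \tilde{v}^h_t = \mathrm{SPLIT}_H(v_t)_h + \lambda^V_h W^p_V p^{\text{TINT}}_t.$$

Here $r_t \in \mathbb{R}^{T_{\text{sim}}}$ denotes a relative position vector at each position $t$ that contains $(j - t)$ at each coordinate $j \le T_{\text{sim}}$. Here, $q_t, k_t, v_t$ denote the query, key, and value vectors at each position $t$, computed as $W^{\text{TINT}}_Q \hat{e}_t + b^{\text{TINT}}_Q$, $W^{\text{TINT}}_K \hat{e}_t + b^{\text{TINT}}_K$, and $W^{\text{TINT}}_V \hat{e}_t + b^{\text{TINT}}_V$ respectively. $\widetilde{K}^h \in \mathbb{R}^{T_{\text{sim}} \times D_{\text{sim}} / H_{\text{sim}}}$ is defined with its rows as $\{\tilde{k}^h_t\}_{t \le T_{\text{sim}}}$ for all $h \le H_{\text{sim}}$.

Following Appendix D, we make the following modifications to the Forward, Backward, and Descent modules. In the Forward module, we incorporate the modified self-attention module to compute the attention scores using ALiBi attention. In the Backward module, since we do not propagate gradients through the attention scores of the auxiliary model, the backpropagation formulation remains unchanged from Definition D.3 when we have access to the attention scores.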
Similarly, in the Descent module, we update the value matrix while keeping the query and key parameters fixed. The formulation of the gradient update remains unchanged from Definition D.6 when we have access to the attention scores. Consequently, we simply modify all the self-attention modules in the simulator to include ALiBi attention, as defined by Definition I.4.

I.3. Gated linear units (GLUs)

We describe the operations of GLUs (Shazeer, 2020) using similar GLU units available to TINT.

Definition I.5. For parameters $W, V, W^o \in \mathbb{R}^{D_{\text{aux}} \times D_{\text{aux}}}$, and biases $b_W, b_V, b_{W^o} \in \mathbb{R}^{D_{\text{aux}}}$, a GLU layer with activation $\sigma_{\text{act}} : \mathbb{R} \to \mathbb{R}$ takes input $x \in \mathbb{R}^{D_{\text{aux}}}$ and outputs $\hat{y} \in \mathbb{R}^{D_{\text{aux}}}$, with

$$y = (W x + b_W) \odot \sigma_{\text{act}}(V x + b_V); \quad \hat{y} = W^o y + b_{W^o}.$$

Typical GLUs have $8/3 \times D_{\text{aux}}$ as a hidden dimension (i.e., the dimension of $y$). We can use the parameter-sharing techniques discussed for feed-forward layers (Appendix H) with the TINT modules presented here. Furthermore, since $\hat{y}$ can be expressed as a combination of the gated operation and a linear operation, we focus on the computation of $y$ here.

For the discussion below, we consider a GLU (without the output linear layer) in the auxiliary model, with parameters $W, V, b_W, b_V$, that takes in input sequence $x_1, \cdots, x_T$ and outputs $y_1, \cdots, y_T$, with $y_t = (W x_t + b_W) \odot \sigma_{\text{act}}(V x_t + b_V)$ for each $t \le T_{\text{sim}}$. Since this involves a token-wise operation, we will present our constructed modules with a general token position $t$ and the prefix tokens $\{v_j\}$.

TINT GLU Forward module

The embedding $e_t$ contains $x_t$ in its first $D_{\text{aux}}$ coordinates. The output $y_t$ can be computed using three sub-operations: (a) linear operation for $W x_t + b_W$, (b) linear operation for $V x_t + b_V$, and (c) gate operation to get $(W x_t + b_W) \odot \sigma_{\text{act}}(V x_t + b_V)$. We use three TINT modules, representing each sub-operation.
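The three sub-operations of the auxiliary GLU forward pass can be sketched in plain Python as follows. This illustrates only the arithmetic of Definition I.5 (without the output linear layer), not the TINT modules that simulate it. The sigmoid activation, the toy parameters, and the helper names are assumptions chosen for illustration.

```python
import math

def linear_forward(W, b, x):
    # Sub-operations (a) and (b): an affine map W x + b for a list-of-rows matrix W.
    return [sum(w * xi for w, xi in zip(row, x)) + bi for row, bi in zip(W, b)]

def gate(u, v, act):
    # Sub-operation (c): elementwise product u * act(v).
    return [ui * act(vi) for ui, vi in zip(u, v)]

def sigmoid(z):
    # Assumed activation for this sketch; any scalar activation could be passed instead.
    return 1.0 / (1.0 + math.exp(-z))

def glu_forward(W, b_W, V, b_V, x, act=sigmoid):
    u = linear_forward(W, b_W, x)   # (a) W x + b_W
    v = linear_forward(V, b_V, x)   # (b) V x + b_V, same routine with different parameters
    return gate(u, v, act)          # (c) (W x + b_W) * act(V x + b_V)

# Toy parameters (illustrative only), D_aux = 2.
W = [[1.0, 0.0], [0.0, 2.0]]
b_W = [0.5, -0.5]
V = [[0.0, 1.0], [1.0, 0.0]]
b_V = [0.0, 0.0]
x = [1.0, 0.0]

y = glu_forward(W, b_W, V, b_V, x)
print(y)
```

Steps (a) and (b) call the same `linear_forward` routine and differ only in the parameters they receive.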
(a) $W x_t + b_W$ is a linear operation, hence we can use a TINT Linear Forward module (Appendix C) with the current embedding $e_t$ and $\{v_j\}$ containing $W, b_W$ to get an embedding $e_t$ containing $W x_t + b_W$ in its first $D_{\text{aux}}$ coordinates.

(b) $V x_t + b_V$ is a linear operation, hence we can similarly use a TINT Linear Forward module (Appendix C) with the embedding $e_t$ and $\{v_j\}$ containing $V, b_V$ to get an embedding $\hat{e}_t$ containing $V x_t + b_V$ in its first $D_{\text{aux}}$ coordinates. $\hat{e}_t$ and $e_t$ are now combined to get an embedding $e_t$ that contains $W x_t + b_W$ and $V x_t + b_V$ in its first $2 D_{\text{aux}}$ coordinates.

(c) Finally, we can use a TINT GLU layer that carries out the elementwise multiplication of $W x_t + b_W$ and $\sigma_{\text{act}}(V x_t + b_V)$ to get $y_t$ in the first $D_{\text{aux}}$ coordinates.

Parameter sharing: Since (a) and (b) involve a Linear Forward module, we can additionally leverage parameter sharing to apply a single Linear Forward module for each of the two computations, changing only the prefix embeddings to correspond to $W, b_W$ or $V, b_V$.

Auxiliary GLU backpropagation

For the GLU layer defined in Definition I.5, the backpropagation layer takes in the loss gradient w.r.t. the output ($\partial y$) and computes the loss gradient w.r.t. the input ($\partial x$).

Definition I.6 (Auxiliary GLU backpropagation). For the weights $W, V \in \mathbb{R}^{D_{\text{aux}} \times D_{\text{aux}}}$, the backpropagation layer takes $\partial y \in \mathbb{R}^{D_{\text{aux}}}$ as input and outputs $\partial x \in \mathbb{R}^{D_{\text{aux}}}$, with $\partial x = W^\top \widehat{\partial x} + V^\top \widetilde{\partial x}$, where

$$\widehat{\partial x} = \partial y \odot \sigma_{\text{act}}(V x + b_V); \quad \widetilde{\partial x} = \sigma'_{\text{act}}(V x + b_V) \odot \partial y \odot (W x + b_W).$$

A direct computation of $\widetilde{\partial x}$ involves changing the activation function to $\sigma'_{\text{act}}$. Following a similar strategy for backpropagation through an activation layer (Appendix F), we instead use a first-order Taylor expansion to approximate $\widetilde{\partial x}$.

Definition I.7 (Auxiliary GLU approximate backpropagation).
For a hyper-parameter $\epsilon > 0$ and weights $W, V \in \mathbb{R}^{D_{\text{aux}} \times D_{\text{aux}}}$, the approximate backpropagation layer takes $\partial y \in \mathbb{R}^{D_{\text{aux}}}$ as input and outputs $\partial x \in \mathbb{R}^{D_{\text{aux}}}$, with $\partial x = W^\top \widehat{\partial x} + V^\top \widehat{\widetilde{\partial x}}$, where

$$\widehat{\partial x} = \partial y \odot \sigma_{\text{act}}(V x + b_V),$$

$$\widehat{\widetilde{\partial x}} = \sigma_{\text{act}}(V x + b_V + \epsilon \partial y) \odot \frac{1}{\epsilon} (W x + b_W) - \sigma_{\text{act}}(V x + b_V) \odot \frac{1}{\epsilon} (W x + b_W).$$

TINT GLU backpropagation module

The current embedding contains $\partial y_t$ in its first $D_{\text{aux}}$ coordinates. Furthermore, since we need $W x_t + b_W$ and $V x_t + b_V$ in the gradient computations, we copy them from the Forward module using residual connections. We discuss the computation of $W^\top \widehat{\partial x_t}$ and $V^\top \widehat{\widetilde{\partial x_t}}$ as separate sub-modules acting on the same embedding $e_t$ in parallel.

1. The computation of $W^\top \widehat{\partial x_t}$ involves two sub-operations: (a) gate operation to get $\widehat{\partial x_t} := \partial y_t \odot \sigma_{\text{act}}(V x_t + b_V)$, and (b) linear backward operation to get $W^\top \widehat{\partial x_t}$. Since this operation requires $W$, we copy the contents of the prefix embeddings containing $W, b_W$ from the Forward module.

(a) Since the current embedding $e_t$ contains both $\partial y_t$ and $V x_t + b_V$, we can use a TINT GLU layer to get an embedding $\hat{e}^{(1)}_t$ that contains $\widehat{\partial x_t}$.

(b) The final linear backward operation can be performed by using a TINT Linear backpropagation module (Appendix C) with the embeddings $\hat{e}^{(1)}_t$ and the prefix embeddings. The final embedding $\hat{e}_t$ contains $W^\top \widehat{\partial x_t}$ in the first $D_{\text{aux}}$ coordinates.

2. The computation of $V^\top \widehat{\widetilde{\partial x_t}}$ involves four sub-operations: (a) gate operation to get $\frac{1}{\epsilon} (W x_t + b_W) \odot \sigma_{\text{act}}(V x_t + b_V + \epsilon \partial y_t)$, (b) gate operation to get $\frac{1}{\epsilon} (W x_t + b_W) \odot \sigma_{\text{act}}(V x_t + b_V)$, (c) a linear layer to compute $\widehat{\widetilde{\partial x_t}}$, and (d) linear backward operation to get $V^\top \widehat{\widetilde{\partial x_t}}$. Since this operation requires $V$, we copy the contents of the prefix embeddings containing $V, b_V$ from the Forward module.

(a) Since the current embedding $e_t$ contains $\partial y_t$, $V x_t + b_V$ and $W x_t + b_W$, we can use two TINT GLU layers to get an embedding $e^{(1)}_t$ that contains both $\frac{1}{\epsilon} (W x_t + b_W) \odot \sigma_{\text{act}}(V x_t + b_V + \epsilon \partial y_t)$ and $\frac{1}{\epsilon} (W x_t + b_W) \odot \sigma_{\text{act}}(V x_t + b_V)$.
(b) A linear layer on $e^{(1)}_t$ can then return an embedding $e^{(2)}_t$ containing $\widehat{\widetilde{\partial x_t}}$ in the first $D_{\text{aux}}$ coordinates.

(c) The final operation can be performed by using a TINT Linear backpropagation module (Appendix C) with the embeddings $e^{(2)}_t$ and the prefix embeddings containing $V, b_V$. The final embedding $e_t$ contains $V^\top \widehat{\widetilde{\partial x_t}}$ in the first $D_{\text{aux}}$ coordinates.

After the two parallel computations, we can sum up $\hat{e}_t$ and $e_t$ to get an embedding $e_t$ containing $\partial x_t$ (Definition I.7) in the first $D_{\text{aux}}$ coordinates.

Auxiliary GLU descent

Finally, the auxiliary's descent updates the weight and the bias parameters using a batch of inputs $\{x_t\}_{t \le T}$ and the loss gradient w.r.t. the corresponding outputs $\{\partial y_t\}_{t \le T}$.

Definition I.8 (Auxiliary GLU descent). For weights $W, V \in \mathbb{R}^{D_{\text{aux}} \times D_{\text{aux}}}$ and biases $b_W, b_V \in \mathbb{R}^{D_{\text{aux}}}$, the linear descent layer takes in a batch of inputs $\{x_t \in \mathbb{R}^{D_{\text{aux}}}\}_{t \le T_{\text{aux}}}$ and gradients $\{\partial y_t \in \mathbb{R}^{D_{\text{aux}}}\}_{t \le T_{\text{aux}}}$ and updates the parameters as follows:

$$W \leftarrow W - \eta \sum_{t \le T_{\text{aux}}} \widehat{\partial x_t}\, x_t^\top; \quad b_W \leftarrow b_W - \eta \sum_{t \le T_{\text{aux}}} \widehat{\partial x_t},$$

$$V \leftarrow V - \eta \sum_{t \le T_{\text{aux}}} \widetilde{\partial x_t}\, x_t^\top; \quad b_V \leftarrow b_V - \eta \sum_{t \le T_{\text{aux}}} \widetilde{\partial x_t},$$

where $\widehat{\partial x_t}$ and $\widetilde{\partial x_t}$ have been computed as in Definition I.6.

Due to similar concerns as in gradient backpropagation, we instead use $\widehat{\widetilde{\partial x_t}}$ (Definition I.7) in place of $\widetilde{\partial x_t}$ for each $t \le T_{\text{aux}}$ to update $V, b_V$.

TINT GLU descent module

We discuss the two descent operations separately.

1. Update of $W, b_W$: We start with the embeddings $\hat{e}^{(1)}_t$ from the backpropagation module, which contain $\widehat{\partial x_t}$ in the first $D_{\text{aux}}$ coordinates. For the update, we additionally require the input to the auxiliary GLU layer under consideration, and hence we copy $x_t$ from the Forward module using residual connections. Furthermore, we copy the contents of the prefix embeddings that contain $W, b_W$ from the Forward module. With both $\widehat{\partial x_t}$ and $x_t$ in the embeddings, the necessary operation turns out to be the descent update of a linear layer with parameters $W, b_W$.
That implies we can call a TINT Linear descent module (Appendix C) on the current embeddings and prefix embeddings to get the desired update.

2. Update of $V, b_V$: We start with the embeddings $e^{(2)}_t$ from the backpropagation module, which contain $\widehat{\widetilde{\partial x_t}}$ in the first $D_{\text{aux}}$ coordinates. For the update, we additionally require the input to the auxiliary GLU layer under consideration, and hence we copy $x_t$ from the Forward module using residual connections. Furthermore, we copy the contents of the prefix embeddings that contain $V, b_V$ from the Forward module. With both $\widehat{\widetilde{\partial x_t}}$ and $x_t$ in the embeddings, the necessary operation turns out to be the descent update of a linear layer with parameters $V, b_V$. That implies we can call a TINT Linear descent module on the current embeddings and prefix embeddings to get the desired update.

Parameter sharing: Since both descent updates involve a Linear descent module, we can additionally leverage parameter sharing to apply a single TINT Linear descent module for each of the two computations, changing the input to correspond to $\{\hat{e}^{(1)}_t\}$ and the prefix to correspond to $W, b_W$, or the input to correspond to $\{e^{(2)}_t\}$ and the prefix to correspond to $V, b_V$, respectively.

J. Construction of other variants of pre-trained models

Though we only conduct experiments on an OPT-125M model, our construction is generally applicable to diverse variants of pre-trained language models. Table 3 highlights many types of modules and the required size and computation for each. The size of a constructed model is influenced by various factors, including the number of layers and the embedding dimension of the auxiliary model.

K. Experiments

Computing environment: All the experiments are conducted on a single A100 80G GPU.

Hyperparameters: In the few-shot setting, we employ three different random seeds to select distinct sets of training examples.
Grid search is performed for each seed to determine the optimal learning rate for both constructed models and dynamic evaluation. The learning rates considered for the descent update operations in TINT are 1e-3, 1e-4, and 1e-5.⁹ Additionally, we explore various layer-step combinations to allocate a fixed budget for one full forward pass. Specifically, we update the top 3 layers for 4 steps, the top 6 layers for 2 steps, or 12 layers for 1 step. These specific combinations were chosen to demonstrate the flexibility of TINT in simulating fine-tuning for any number of layers and steps while staying within computational constraints. In all of these scenarios, TINT performs as well as fine-tuning the auxiliary model.

Results of different settings. Table 4 displays the results of few-shot learning with calibration across various settings, encompassing different loss types, input formats, and layer-step configurations. Our analysis reveals that employing a label-only loss, utilizing a single-example input format, and updating all layers of the internal model for a single step yields the most favorable average result. The multi-example format is at a disadvantage on tasks with long sequences, such as Amazon Polarity. In general, we observe that calibrated results tend to be more consistent and stable.

Table 4: Few-shot (k = 32) results with different loss types, input formats, and layer-step configurations with a fixed compute budget, with calibration.

| Loss Type | Format | Layer | Step | Subj | AGNews | SST2 | CR | MR | MPQA | Amazon | Avg. |
|---|---|---|---|---|---|---|---|---|---|---|---|
| Label | Single | 12 | 1 | 66.0(1.9) | 64.7(0.2) | 68.7(1.3) | 69.0(0.7) | 63.7(0.2) | 82.8(0.5) | 73.7(0.6) | 69.8(0.1) |
| Label | Single | 6 | 2 | 62.7(0.2) | 66.3(0.2) | 68.3(6.1) | 67.2(0.2) | 61.8(1.6) | 81.0(3.6) | 74.3(0.5) | 68.8(1.4) |
| Label | Single | 3 | 4 | 63.5(0.0) | 67.2(0.8) | 62.5(0.4) | 68.7(1.4) | 61.7(0.6) | 76.8(3.3) | 75.2(0.8) | 67.9(0.8) |
| Label | Multi. | 12 | 1 | 83.2(2.5) | 43.7(6.6) | 60.7(5.7) | 70.3(6.1) | 62.8(8.9) | 84.2(1.6) | 66.3(12.3) | 67.3(0.9) |
| Label | Multi. | 6 | 2 | 83.5(2.9) | 43.2(8.4) | 52.0(1.5) | 70.5(6.0) | 58.5(11.3) | 82.0(0.4) | 55.8(7.6) | 63.6(2.7) |
| Label | Multi. | 3 | 4 | 84.0(2.3) | 42.3(8.4) | 51.5(1.8) | 68.2(4.6) | 58.5(12.0) | 80.2(2.1) | 58.5(7.9) | 63.3(3.0) |
| Full-context | Single | 12 | 1 | 64.5(0.4) | 65.8(0.2) | 63.2(0.9) | 67.3(0.5) | 60.8(1.4) | 73.5(0.8) | 75.0(0.4) | 67.2(0.1) |
| Full-context | Single | 6 | 2 | 66.7(2.0) | 66.0(0.4) | 62.7(0.6) | 70.5(2.1) | 59.7(0.9) | 77.7(2.2) | 76.0(0.0) | 68.5(0.4) |
| Full-context | Single | 3 | 4 | 64.0(0.0) | 65.8(0.6) | 65.0(1.9) | 67.3(0.2) | 59.5(0.4) | 74.2(1.3) | 77.0(1.9) | 67.5(0.8) |
| Full-context | Multi. | 12 | 1 | 83.8(2.9) | 41.0(10.6) | 51.2(0.8) | 68.0(4.5) | 58.3(11.1) | 79.0(3.6) | 56.0(8.1) | 62.5(2.8) |
| Full-context | Multi. | 6 | 2 | 85.3(1.9) | 41.2(10.7) | 51.2(1.3) | 67.7(4.5) | 57.7(10.8) | 79.2(3.7) | 55.8(7.9) | 62.6(2.6) |
| Full-context | Multi. | 3 | 4 | 83.3(2.5) | 41.7(11.3) | 51.0(1.1) | 68.2(4.7) | 57.7(10.8) | 79.0(3.2) | 56.0(8.1) | 62.4(2.8) |

⁹ When utilizing the full-context loss, the learning rates considered are 1e-5, 1e-6, and 1e-7 due to gradient summations in TINT.

Table 5: Few-shot (k = 32) results with different loss types, input formats, and layer-step configurations with a fixed compute budget, without calibration.

| Loss Type | Format | Layer | Step | Subj | AGNews | SST2 | CR | MR | MPQA | Amazon | Avg. |
|---|---|---|---|---|---|---|---|---|---|---|---|
| Label | Single | 12 | 1 | 63.3(0.2) | 65.7(0.2) | 71.3(0.6) | 65.0(1.4) | 70.7(0.9) | 65.0(0.0) | 76.7(0.2) | 68.2(0.1) |
| Label | Single | 6 | 2 | 63.5(0.0) | 65.2(0.5) | 73.3(1.3) | 68.5(3.7) | 71.3(0.2) | 66.0(0.0) | 77.5(0.4) | 69.3(0.3) |
| Label | Single | 3 | 4 | 64.2(0.2) | 66.5(1.1) | 73.2(0.6) | 75.7(0.5) | 72.0(0.0) | 83.2(1.0) | 78.0(0.4) | 73.2(0.1) |
| Label | Multi. | 12 | 1 | 64.5(7.8) | 35.5(7.4) | 56.8(9.7) | 63.0(6.7) | 58.7(8.9) | 75.2(10.8) | 62.2(8.3) | 59.4(0.6) |
| Label | Multi. | 6 | 2 | 77.7(7.0) | 35.5(7.4) | 57.0(9.9) | 60.0(6.3) | 52.3(2.1) | 58.5(6.1) | 55.8(7.9) | 56.7(2.6) |
| Label | Multi. | 3 | 4 | 67.5(11.5) | 38.5(8.2) | 55.3(5.2) | 67.0(3.5) | 61.0(8.0) | 65.2(11.2) | 62.5(8.9) | 59.6(1.3) |
| Full-context | Single | 12 | 1 | 65.5(1.1) | 66.5(0.0) | 70.7(0.2) | 64.8(0.5) | 72.0(1.4) | 67.0(0.0) | 76.5(0.0) | 69.0(0.3) |
| Full-context | Single | 6 | 2 | 64.7(0.6) | 66.2(0.2) | 71.2(0.2) | 65.3(0.6) | 71.5(0.4) | 67.0(0.0) | 76.7(0.2) | 68.9(0.0) |
| Full-context | Single | 3 | 4 | 64.2(0.2) | 66.2(0.2) | 71.3(0.2) | 64.7(0.2) | 71.0(0.0) | 67.0(0.0) | 76.5(0.0) | 68.7(0.0) |
| Full-context | Multi. | 12 | 1 | 62.2(7.5) | 33.8(8.3) | 52.2(3.1) | 52.8(4.0) | 50.8(1.2) | 55.8(4.3) | 55.3(7.2) | 51.9(2.2) |
| Full-context | Multi. | 6 | 2 | 60.0(5.5) | 33.7(8.4) | 50.8(1.2) | 52.2(2.4) | 50.2(0.2) | 54.3(2.5) | 55.0(6.7) | 50.9(1.8) |
| Full-context | Multi. | 3 | 4 | 58.7(4.9) | 33.7(8.4) | 50.8(1.2) | 51.3(1.9) | 50.0(0.0) | 54.3(2.5) | 55.3(7.2) | 50.6(2.0) |