DEPTH-ADAPTIVE TRANSFORMER

Maha Elbayad (Univ. Grenoble Alpes)*
Jiatao Gu, Edouard Grave, Michael Auli (Facebook AI Research)

*Work done during an internship at Facebook AI Research.

ABSTRACT

State of the art sequence-to-sequence models for large scale tasks perform a fixed number of computations for each input sequence, regardless of whether it is easy or hard to process. In this paper, we train Transformer models which can make output predictions at different stages of the network and we investigate different ways to predict how much computation is required for a particular sequence. Unlike dynamic computation in Universal Transformers, which applies the same set of layers iteratively, we apply different layers at every step, to adjust both the amount of computation as well as the model capacity. On IWSLT German-English translation our approach matches the accuracy of a well tuned baseline Transformer while using less than a quarter of the decoder layers.

1 INTRODUCTION

The size of modern neural sequence models (Gehring et al., 2017; Vaswani et al., 2017; Devlin et al., 2019) can amount to billions of parameters (Radford et al., 2019). For example, the winning entry of the WMT'19 news machine translation task in English-German used an ensemble totaling two billion parameters (Ng et al., 2019). While large models are required to do better on hard examples, small models are likely to perform as well on easy ones; e.g., the aforementioned ensemble is probably not required to translate a short phrase such as "Thank you". However, current models apply the same amount of computation regardless of whether the input is easy or hard.

In this paper, we propose Transformers which adapt the number of layers to each input in order to achieve a good speed-accuracy trade-off at inference time. We extend Adaptive Computation Time (ACT; Graves, 2016), which introduced dynamic computation for recurrent neural networks, in several ways: we apply different layers at each stage, we investigate a range of designs and training targets for the halting module, and we explicitly supervise through simple oracles to achieve good performance on large-scale tasks.

Universal Transformers (UT) rely on ACT for dynamic computation and repeatedly apply the same layer (Dehghani et al., 2018). Our work considers a variety of mechanisms to estimate the network depth and applies a different layer at each step. Moreover, Dehghani et al. (2018) fix the number of steps for large-scale machine translation, whereas we vary the number of steps to demonstrate substantial improvements in speed at no loss in accuracy. UT uses a layer which contains as many weights as an entire standard Transformer, and this layer is applied several times, which impacts speed. Our approach does not increase the size of individual layers. We also extend the resource-efficient object classification work of Huang et al. (2017) and Bolukbasi et al. (2017) to structured prediction, where dynamic computation decisions impact future computation. Related work from computer vision includes Teerapittayanon et al. (2016), Figurnov et al. (2017) and Wang et al. (2018), who explored the idea of dynamic routing, either by exiting early or by skipping layers.

We encode the input sequence using a standard Transformer encoder and generate the output sequence with a varying amount of computation in the decoder network. Dynamic computation poses a challenge for self-attention because omitted layers in prior time-steps may be required in the future.
We experiment with two approaches to address this and show that a simple approach works well (§2). Next, we investigate different mechanisms to control the amount of computation in the decoder network, either for the entire sequence or on a per-token basis. This includes multinomial and binomial classifiers supervised by the model likelihood or by whether the argmax is already correct, as well as simply thresholding the model score (§3). Experiments on IWSLT'14 German-English translation (Cettolo et al., 2014) as well as WMT'14 English-French translation show that we can match the performance of well tuned baseline models at up to 76% less computation (§4).

Figure 1: Training regimes for decoder networks able to emit outputs at any layer. Aligned training optimizes all output classifiers C_n simultaneously, assuming all previous hidden states for the current layer are available. Mixed training samples M paths of random exits at which the model is assumed to have exited; missing previous hidden states are copied from below. (Panels: (a) aligned training, (b) mixed training; axes: decoder depth vs. decoding step.)

2 ANYTIME STRUCTURED PREDICTION

We first present a model that can make predictions at different layers. This is known as anytime prediction for computer vision models (Huang et al., 2017) and we extend it to structured prediction.

2.1 TRANSFORMER WITH MULTIPLE OUTPUT CLASSIFIERS

We base our approach on the Transformer sequence-to-sequence model (Vaswani et al., 2017). Both encoder and decoder networks contain N stacked blocks, where each block has several sub-blocks surrounded by residual skip-connections. The first sub-block is a multi-head dot-product self-attention and the second is a position-wise fully connected feed-forward network. The decoder has an additional sub-block after the self-attention which adds source context via another multi-head attention.

Given a pair of source-target sequences (x, y), x is processed with the encoder to give representations s = (s_1, ..., s_{|x|}). Next, the decoder generates y step-by-step. For every new token y_t input to the decoder at time t, the N decoder blocks process it to yield hidden states (h^n_t)_{1 \le n \le N}:

h^0_t = \mathrm{embed}(y_t), \quad h^n_t = \mathrm{block}_n(h^{n-1}_t, s),   (1)

where block_n is the mapping associated with the n-th block and embed is a lookup table. The output distribution for predicting the next token is computed by feeding the activations of the last decoder layer h^N_t into a softmax-normalized output classifier W:

p(y_{t+1} \mid h^N_t) = \mathrm{softmax}(W h^N_t)   (2)

Standard Transformers have a single output classifier attached to the top of the decoder network. However, for dynamic computation we need to be able to make predictions at different stages of the network. To achieve this, we attach output classifiers C_n, parameterized by W_n, to the output h^n_t of each of the N decoder blocks:

\forall n, \quad p(y_{t+1} \mid h^n_t) = \mathrm{softmax}(W_n h^n_t)   (3)

The classifiers can be parameterized independently or we can share the weights across the N blocks.
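To make Eq. (1)–(3) concrete, the following is a minimal PyTorch sketch of a decoder whose blocks each feed their own (optionally shared) output classifier. It is not the authors' implementation: the class name, the use of nn.TransformerDecoderLayer as a stand-in for one decoder block, and the omission of causal masking are all simplifying assumptions.

```python
import torch.nn as nn

class AnytimeDecoder(nn.Module):
    """Decoder with one output classifier C_n attached to every block (Eq. (1)-(3))."""

    def __init__(self, num_blocks, d_model, vocab_size, share_classifiers=False):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        # Stand-in for one decoder block: self-attention + encoder attention + feed-forward.
        self.blocks = nn.ModuleList(
            nn.TransformerDecoderLayer(d_model, nhead=4, batch_first=True)
            for _ in range(num_blocks)
        )
        self.share_classifiers = share_classifiers
        n_cls = 1 if share_classifiers else num_blocks
        self.classifiers = nn.ModuleList(
            nn.Linear(d_model, vocab_size) for _ in range(n_cls)
        )

    def forward(self, prev_tokens, src_states):
        """Return the logits of p(y_{t+1} | h^n_t) for every block n = 1..N.

        prev_tokens: [batch, T] token ids; src_states: [batch, S, d_model] encoder outputs.
        Causal masking is omitted for brevity.
        """
        h = self.embed(prev_tokens)           # h^0_t
        logits_per_exit = []
        for n, block in enumerate(self.blocks):
            h = block(h, src_states)          # h^n_t = block_n(h^{n-1}_t, s)
            c = self.classifiers[0 if self.share_classifiers else n]
            logits_per_exit.append(c(h))      # softmax over these logits gives Eq. (3)
        return logits_per_exit
```

At inference time, only the classifier of the block at which the model exits needs to be evaluated; during aligned training (below), all of them are.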
2.2 TRAINING MULTIPLE OUTPUT CLASSIFIERS

Dynamic computation enables the model to use any of the N exit classifiers instead of just the final one. Some of our models can choose a different output classifier at each time-step, which results in a number of possible output classifier combinations that is exponential in the sequence length.

We consider two possible ways to train the decoder network (Figure 1). Aligned training optimizes all classifiers simultaneously and assumes that all previous hidden states required by the self-attention are available. However, at test time this is often not the case when we choose a different exit for every token, which leads to misaligned states. Instead, mixed training samples several sequences of exits for a given sentence and exposes the model to hidden states from different layers. Generally, for a given output sequence y, we have a sequence of chosen exits (n_1, ..., n_{|y|}) and we denote the block at which we exit at time t as n_t.

2.2.1 ALIGNED TRAINING

Aligned training assumes that all hidden states h^{n-1}_1, ..., h^{n-1}_t are available in order to compute self-attention, and it optimizes N loss terms, one for each exit (Figure 1a):

LL^n_t = \log p(y_t \mid h^n_{t-1}), \quad LL^n = \sum_{t=1}^{|y|} LL^n_t, \quad \mathcal{L}_{dec}(x, y) = -\frac{1}{\sum_{n'} \omega_{n'}} \sum_{n=1}^{N} \omega_n\, LL^n.   (4)

The compound loss \mathcal{L}_{dec}(x, y) is a weighted average of the N terms with weights (ω_1, ..., ω_N). We found that uniform weights achieve better BLEU compared to other weighting schemes (cf. Appendix A).

At inference time, not all time-steps will have hidden states for the current layer since the model exited early. In this case, we simply copy the last computed state to all upper layers, similar to mixed training (§2.2.2). However, we do apply layer-specific key and value projections to the copied state.

2.2.2 MIXED TRAINING

Aligned training assumes that all hidden states of the previous time-steps are available, but this assumption is unrealistic since an early exit may have been chosen previously. This creates a mismatch between training and testing. Mixed training reduces the mismatch by training the model to use hidden states from different blocks of previous time-steps for self-attention. We sample M different exit sequences (n^{(m)}_1, ..., n^{(m)}_{|y|}), 1 \le m \le M, and evaluate the following loss:

LL(n_1, \dots, n_{|y|}) = \sum_{t=1}^{|y|} \log p(y_t \mid h^{n_t}_{t-1}), \quad \mathcal{L}_{dec}(x, y) = -\frac{1}{M} \sum_{m=1}^{M} LL(n^{(m)}_1, \dots, n^{(m)}_{|y|}).   (5)

When n_t < N, we copy the last evaluated hidden state h^{n_t}_t to the subsequent layers so that the self-attention of future time-steps can function as usual (see Figure 1b).

3 ADAPTIVE DEPTH ESTIMATION

We present a variety of mechanisms to predict the decoder block at which the model will stop and output the next token, i.e., when it should exit in order to achieve a good speed-accuracy trade-off. We consider two approaches: sequence-specific depth decodes all output tokens using the same block (§3.1), while token-specific depth determines a separate exit for each individual token (§3.2).

We model the distribution of exiting at time-step t with a parametric distribution q_t, where q_t(n) is the probability of computing block_1, ..., block_n and then emitting a prediction with C_n. The parameters of q_t are optimized to match an oracle distribution q^*_t with cross-entropy:

\mathcal{L}_{exit}(x, y) = \sum_t H(q^*_t(x, y), q_t(x))   (6)

The exit loss (\mathcal{L}_{exit}) is back-propagated to the encoder-decoder parameters. We simultaneously optimize the decoding loss (Eq. (4)) and the exit loss (Eq. (6)), balanced by a hyper-parameter α, to ensure that the model maintains good generation accuracy. The final loss takes the form:

\mathcal{L}(x, y) = \mathcal{L}_{dec}(x, y) + \alpha \mathcal{L}_{exit}(x, y).   (7)

In the following we describe for each approach how the exit distribution q_t is modeled (illustrated in Figure 2) and how the oracle distribution q^*_t is inferred.
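As a concrete, hedged illustration of the objectives above, the sketch below computes the aligned decoding loss of Eq. (4) with uniform weights ω_n = 1 and combines it with the exit loss of Eq. (6)–(7) for a hard (Dirac) oracle. Tensor layouts, helper names, and the default value of α are assumptions, not the paper's code.

```python
import torch.nn.functional as F

def aligned_decoder_loss(logits_per_exit, targets, weights=None):
    """Eq. (4): weighted average of the per-exit negative log-likelihoods.

    logits_per_exit: list of N tensors [batch, T, V]; targets: [batch, T] gold token ids.
    """
    weights = weights or [1.0] * len(logits_per_exit)
    loss = 0.0
    for w, logits in zip(weights, logits_per_exit):
        # -LL^n: token-level cross-entropy summed over the sequence
        loss = loss + w * F.cross_entropy(logits.transpose(1, 2), targets, reduction="sum")
    return loss / sum(weights)

def total_loss(logits_per_exit, targets, exit_logits, oracle_exit, alpha=1.0):
    """Eq. (7): decoding loss plus the alpha-weighted exit loss.

    exit_logits: [batch, N] scores of the exit classifier q; oracle_exit: [batch] indices of
    the Dirac oracle q*, so the cross-entropy of Eq. (6) reduces to a classification loss.
    alpha is a hyper-parameter; the default here is arbitrary.
    """
    l_dec = aligned_decoder_loss(logits_per_exit, targets)
    l_exit = F.cross_entropy(exit_logits, oracle_exit)   # H(q*, q) for a hard oracle
    return l_dec + alpha * l_exit
```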
Figure 2: Variants of the adaptive depth prediction classifiers. Sequence-specific depth uses a multinomial classifier to choose an exit for the entire output sequence based on the encoder output s (2a); it then outputs a token at this depth with classifier C_n. The token-specific multinomial classifier determines the exit after the first block and proceeds up to the predicted depth before outputting the next token (2b). The token-specific geometric-like classifier (2c) makes a binary decision after every block to dictate whether to continue (C) to the next block or to stop (S) and emit an output distribution.

3.1 SEQUENCE-SPECIFIC DEPTH

For sequence-specific depth, the exit distribution q and the oracle distribution q^* are independent of the time-step, so we drop the subscript t. We condition the exit on the source sequence by feeding the average \bar{s} of the encoder outputs to a multinomial classifier:

\bar{s} = \frac{1}{|x|} \sum_t s_t, \quad q(n \mid x) = \mathrm{softmax}(W_h \bar{s} + b_h) \in \mathbb{R}^N,   (8)

where W_h and b_h are the weights and biases of the halting mechanism. We consider two oracles to determine which of the N blocks should be chosen. The first is based on the sequence likelihood and the second looks at an aggregate of the correctly predicted tokens at each block.

Likelihood-based: This oracle is based on the likelihood of the entire sequence after each block; it is a Dirac delta centered on the exit with the highest sequence likelihood, q^*(x, y) = \delta(\arg\max_n LL^n). We add a regularization term to encourage lower exits that achieve good likelihood:

q^*(x, y) = \delta(\arg\max_n LL^n - \lambda n).   (9)

Correctness-based: Likelihood ignores whether the model already assigns the highest score to the correct target. Instead, this oracle chooses the lowest block that assigns the largest score to the correct prediction. For each block, we count the number of correctly predicted tokens over the sequence and choose the block with the most correct tokens. A regularization term controls the trade-off between speed and accuracy:

C^n = \#\{t \mid y_t = \arg\max_y p(y \mid h^n_{t-1})\}, \quad q^*(x, y) = \delta(\arg\max_n C^n - \lambda n).   (10)

Oracles based on test metrics such as BLEU are feasible but expensive to compute, since we would need to decode every training sentence N times. We leave this for future work.

3.2 TOKEN-SPECIFIC DEPTH

The token-specific approach can choose a different exit at every time-step. We consider two options for the exit distribution q_t at time-step t: a multinomial with a classifier conditioned on the first decoder hidden state h^1_t, and a geometric-like distribution where an exit probability χ^n_t is estimated after each block based on the activations of the current block h^n_t.

Multinomial q_t: the exit is predicted with a classifier conditioned on the first decoder hidden state, q_t(n \mid x, y_{<t}) = \mathrm{softmax}(W_h h^1_t + b_h) \in \mathbb{R}^N.

Geometric-like q_t: after each block n, a halting probability χ^n_t is estimated from the activations of the current block h^n_t; at inference time, geometric-like classifiers exit as soon as χ^n_t exceeds 0.5.
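The sequence-level oracles of Eq. (9) and Eq. (10) are straightforward to compute from the per-exit outputs produced during aligned training. The sketch below is an illustration under assumed conventions (one sequence at a time, logits of shape [T, V] per exit, 0-based exit indices), not the paper's implementation.

```python
import torch

def likelihood_oracle(logits_per_exit, targets, lam=0.0):
    """Eq. (9): exit with the best regularized sequence log-likelihood LL^n - lambda*n."""
    scores = []
    for logits in logits_per_exit:                                # each: [T, V]
        log_probs = torch.log_softmax(logits, dim=-1)
        scores.append(log_probs.gather(-1, targets.unsqueeze(-1)).sum())   # LL^n
    scores = torch.stack(scores)                                  # [N]
    n = torch.arange(1, len(scores) + 1, dtype=scores.dtype)
    return int(torch.argmax(scores - lam * n))                    # 0-based oracle exit

def correctness_oracle(logits_per_exit, targets, lam=0.0):
    """Eq. (10): exit with the most correctly predicted tokens, biased toward lower exits."""
    counts = torch.stack([
        (logits.argmax(dim=-1) == targets).sum().float()          # C^n
        for logits in logits_per_exit
    ])
    n = torch.arange(1, len(counts) + 1, dtype=counts.dtype)
    return int(torch.argmax(counts - lam * n))
```

The returned index is the Dirac oracle used as the target of the exit cross-entropy in Eq. (6).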
4.2 TRAINING MULTIPLE OUTPUT CLASSIFIERS

We first compare the two training regimes for our model (§2.2). Aligned training performs self-attention on aligned states (§2.2.1) and mixed training exposes self-attention to hidden states from different blocks (§2.2.2). We compare the two training modes when choosing either a uniformly sampled exit or a fixed exit n = 1, ..., 6 at inference time for every time-step. The sampled exit experiment tests the robustness to mixed hidden states and the fixed exit setup simulates an ideal setting where all previous states are available. As baselines we use six separate standard Transformers with N ∈ [1..6] decoder blocks. All models are trained with an equal number of updates, and mixed training with M = 6 paths is most comparable to aligned training since the number of losses per sample is identical.

Table 1 shows that aligned training outperforms mixed training both for fixed exits as well as for randomly sampled exits. The latter is surprising since aligned training never exposes the self-attention mechanism to hidden states from other blocks. We suspect that this is due to the residual connections which copy features from lower blocks to subsequent layers and which are ubiquitous in Transformer models (§2). Aligned training also performs very competitively compared to the individual baseline models.

Aligned training is conceptually simple and fast. We can process a training example with N exits in a single forward/backward pass, while M passes are needed for mixed training. In the remainder of the paper, we use the aligned mode to train our models. Appendix A reports experiments with weighting the various output classifiers differently, but we found that a uniform weighting scheme worked well. On our largest setup, WMT'14 English-French, the training time of an aligned model with six output classifiers increases only marginally, by about 1%, compared to a baseline with a single output classifier, keeping everything else equal.

Figure 3: Trade-off between speed (average exit, AE) and accuracy (BLEU) for depth-adaptive methods on the IWSLT'14 De-En test set. (Panels: (a) token-specific depth, (b) sequence-specific depth, (c) confidence thresholding; each panel plots BLEU against the average exit for the baseline, the aligned model and the adaptive-depth variants.)

Figure 4: Effect of the hyper-parameters σ and λ on the average exit (AE), measured on the valid set of IWSLT'14 De-En. (Panels: (a) effect of λ on AE, (b) effect of σ on AE for λ = 0.01 and λ = 0.05.)

4.3 ADAPTIVE DEPTH ESTIMATION

Next, we train models with aligned states and compare adaptive depth classifiers in terms of BLEU as well as computational effort. We measure the latter as the average exit per output token (AE). As baselines we again use six separate standard Transformers with N ∈ [1..6] blocks and a single output classifier. We also measure the performance of the aligned-mode trained model for fixed exits n ∈ [1..6]. For the adaptive-depth token-specific models (Tok), we train four combinations: likelihood-based oracle (LL) + geometric-like, likelihood-based oracle (LL) + multinomial, correctness-based oracle (C) + geometric-like, and correctness-based oracle (C) + multinomial. Sequence-specific models (Seq) are trained with the correctness oracle (C) and the likelihood oracle (LL) with different values for the regularization weight λ. All parameters are tuned on the valid set and we report results on the test set for a range of average exits.
Figure 3 shows that the aligned model (blue line) can match the accuracy of a standard 6-block Transformer (black line) at half the number of layers (n = 3) by always exiting at the third block. The aligned model outperforms the baseline for n = 2, ..., 6.

For token-specific halting mechanisms (Figure 3a), the geometric-like classifiers achieve a better speed-accuracy trade-off than the multinomial classifiers (filled vs. empty triangles). For geometric-like classifiers, the correctness oracle outperforms the likelihood oracle (Tok-C geometric-like vs. Tok-LL geometric-like), but the trend is less clear for multinomial classifiers. At the sequence level, likelihood is the better oracle (Figure 3b). The rightmost Tok-C geometric-like point (σ = 0, λ = 0.1) achieves 34.73 BLEU at AE = 1.42, which corresponds to similar accuracy as the N = 6 baseline at 76% fewer decoding blocks.

The best accuracy of the aligned model is 34.95 BLEU at exit 5, and the best comparable Tok-C geometric-like configuration achieves 34.99 BLEU at AE = 1.97, or 61% fewer decoding blocks. When fixing the budget to two decoder blocks, Tok-C geometric-like with AE = 1.97 achieves 34.99 BLEU, a 0.64 BLEU improvement over the baseline (N = 2) and the aligned model, which both achieve 34.35 BLEU.

Confidence thresholding (Figure 3c) performs very well but cannot outperform Tok-C geometric-like.

Ablation of hyper-parameters. In this section, we look at the effect of the two main hyper-parameters on IWSLT'14 De-En: the regularization scale λ (cf. Eq. (9)) and the RBF kernel width σ used to smooth the scores (cf. Eq. (15)). We train Tok-LL geometric-like models and evaluate them with their default thresholds (exit if χ^n_t > 0.5). Figure 4a shows that higher values of λ lead to lower exits. Figure 4b shows the effect of σ for two values of λ; in both curves, we see that wider kernels favor higher exits.

Figure 5: Speed and accuracy on the WMT'14 English-French benchmark (cf. Figure 3). (Panels: (a) BLEU vs. average exit (AE) on the test set, (b) BLEU vs. average FLOPs on the test set.)

4.4 SCALING THE ADAPTIVE-DEPTH MODELS

Finally, we take the best performing models from the IWSLT benchmark and test them on the large WMT'14 English-French benchmark. Results on the test set (Figure 5a) show that adaptive depth still shows improvements, but they are diminished in this very large-scale setup. Confidence thresholding works very well, and sequence-specific depth approaches improve only marginally over the baseline. Tok-LL geometric-like can match the best baseline result of 43.4 BLEU (N = 6) using only AE = 2.40, which corresponds to 40% of the decoder blocks; the best aligned result of 43.6 BLEU can be matched with AE = 3.25. In this setup, Tok-LL geometric-like slightly outperforms the Tok-C counterpart.

Confidence thresholding matches the accuracy of the N = 6 baseline with AE ≈ 2.5, or 59% fewer decoding blocks. However, confidence thresholding requires computing the output classifier at each block to determine whether to halt or continue. This is a large overhead, since output classifiers predict 44k types for this benchmark (§4.1). To better account for this, we measure the average number of FLOPs per output token (details in Appendix B).
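The difference in overhead between the two halting mechanisms can be illustrated with a short sketch: confidence thresholding evaluates a full |V|-way output projection and softmax at every block it traverses, whereas the geometric-like classifier only needs a d-dimensional projection per block. This is an illustrative sketch under assumed interfaces (single time-step, per-block modules passed in, encoder attention folded into `block`), not the evaluation code; the threshold value is an assumption.

```python
import torch

def exit_by_confidence(h, blocks, out_proj, threshold=0.9):
    """Exit once the most likely token exceeds `threshold`."""
    for n, block in enumerate(blocks, start=1):
        h = block(h)
        probs = torch.softmax(out_proj(h), dim=-1)   # ~2*V*d_d FLOPs at every visited block
        if probs.max() > threshold or n == len(blocks):
            return n, probs

def exit_by_geometric_like(h, blocks, out_proj, halt_proj):
    """Exit once the halting probability chi^n_t exceeds 0.5."""
    for n, block in enumerate(blocks, start=1):
        h = block(h)
        chi = torch.sigmoid(halt_proj(h))            # ~2*d_d FLOPs at every visited block
        if chi.item() > 0.5 or n == len(blocks):
            return n, torch.softmax(out_proj(h), dim=-1)   # output classifier evaluated once
```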
Figure 5b shows that the Tok-LL geometric-like approach provides a better trade-off when the overhead of the output classifiers is considered.

4.5 QUALITATIVE RESULTS

The exit distribution for a given sample can give insights into what a Depth-Adaptive Transformer decoder considers to be a difficult task. In this section, for each hypothesis ỹ we look at the sequence of selected exits (n_1, ..., n_{|ỹ|}) and the probability scores (p_1, ..., p_{|ỹ|}) with p_t = p(ỹ_t | h^{n_t}_{t-1}), i.e., the confidence of the model in the sampled token at the selected exit.

Figures 6 and 7 show hypotheses from the WMT'14 En-Fr and IWSLT'14 De-En test sets, respectively. For each hypothesis we state the exits and the probability scores.

Figure 6: Examples from the WMT'14 En-Fr test set (newstest14) with Tok-LL geometric-like depth estimation. Token exits are shown in blue and confidence scores in gray; the @@ markers are due to BPE or subword tokenization. (a) Src: "Chi@@rac , the Prime Minister , was there ." Ref: "Chi@@rac , Premier ministre , est là ." (b) Src: "But passengers shoul@@dn t expect changes to happen immediately ." Ref: "Mais les passagers ne devraient pas s attendre à des changements immédiats ."

Figure 7: Example from the IWSLT'14 De-En test set with Tok-LL geometric-like depth estimation (see Figure 6 for details). Src: "diesen trick können sie ihren freunden und nachbarn vor@@führen . danke ." Ref: "there is a trick you can do for your friends and neighb@@ors . thanks ."

In Figure 6a, predicting "présent" (meaning "present") is hard. A straightforward translation is "était là", but the model chooses "présent", which is also appropriate. In Figure 6b, the model uses more computation to predict the definite article "les", since the source has omitted the article for "passengers". A clear trend in both benchmarks is that the model requires less computation near the end of decoding, to generate the end-of-sequence marker and the preceding full-stop when relevant.

In Figure 8, we show the distribution of the exits at the beginning and near the end of test set hypotheses. We consider the beginning of a sequence to be the first 10% of tokens and the end to be the last 10% of tokens. The exit distributions are shown for three models on WMT'14 En-Fr: Model1 has an average exit of AE = 2.53, Model2 exits at AE = 3.79 on average, and Model3 at AE = 4.68. Within the same model, deeper exits are used at the beginning of the sequence and earlier exits are selected near the end. Model2 and Model3 are less regularized (higher AE) and tend to use late exits at the beginning of the sequence and early exits near the end; for the heavily regularized Model1 (AE = 2.53), the disparity between beginning and end is less severe, since the model exits early most of the time.

Figure 8: WMT'14 En-Fr test set: exit distributions at the beginning (relative position rpos < 0.1) and near the end (rpos > 0.9) of the hypotheses of three models (Model1, Model2, Model3).

Figure 9: Joint histogram of the exits and the confidence scores for three Tok-LL geometric-like models on newstest14.
There is also a correlation between the model probability and the amount of computation, particularly in models with low AE. Figure 9 shows the joint histogram of the scores and the selected exits. For both Model1 and Model2, low exits (n ≤ 2) are used in the high-confidence range [0.8, 1] and high exits (n ≥ 4) are used in the low-confidence range [0, 0.5]. Model3 has a high average exit (AE = 4.68), so most tokens exit late; however, in low-confidence ranges the model does not exit earlier than n = 5.

5 CONCLUSION

We extended anytime prediction to the structured prediction setting and introduced simple but effective methods to equip sequence models to make predictions at different points in the network. We compared a number of different mechanisms to predict the required network depth and found that a simple correctness-based geometric-like classifier obtains the best trade-off between speed and accuracy. Results show that the number of decoder layers can be reduced by more than three quarters at no loss in accuracy compared to a well tuned Transformer baseline.

ACKNOWLEDGMENTS

We thank Laurens van der Maaten for fruitful comments and suggestions.

REFERENCES

Tolga Bolukbasi, Joseph Wang, Ofer Dekel, and Venkatesh Saligrama. Adaptive neural networks for efficient inference. In Proc. of ICML, 2017.

M. Cettolo, J. Niehues, S. Stüker, L. Bentivogli, and M. Federico. Report on the 11th IWSLT evaluation campaign. In Proc. of IWSLT, 2014.

Mostafa Dehghani, Stephan Gouws, Oriol Vinyals, Jakob Uszkoreit, and Lukasz Kaiser. Universal Transformers. In Proc. of ICLR, 2018.

Jacob Devlin, Ming-Wei Chang, Kenton Lee, and Kristina Toutanova. BERT: Pre-training of deep bidirectional transformers for language understanding. In Proc. of NAACL, 2019.

Sergey Edunov, Myle Ott, Michael Auli, David Grangier, and Marc'Aurelio Ranzato. Classical structured prediction losses for sequence to sequence learning. In Proc. of NAACL, 2018.

Michael Figurnov, Artem Sobolev, and Dmitry P. Vetrov. Probabilistic adaptive computation time. arXiv preprint, 2017.

Jonas Gehring, Michael Auli, David Grangier, Denis Yarats, and Yann N. Dauphin. Convolutional sequence to sequence learning. In Proc. of ICML, 2017.

Alex Graves. Adaptive computation time for recurrent neural networks. arXiv preprint, 2016.

Gao Huang, Danlu Chen, Tianhong Li, Felix Wu, Laurens van der Maaten, and Kilian Q. Weinberger. Multi-scale dense networks for resource efficient image classification. In Proc. of ICLR, 2017.

D. Kingma and J. Ba. Adam: A method for stochastic optimization. In Proc. of ICLR, 2015.

Nathan Ng, Kyra Yee, Alexei Baevski, Myle Ott, Michael Auli, and Sergey Edunov. Facebook FAIR's WMT19 news translation task submission. In Proc. of WMT, 2019.

Myle Ott, Sergey Edunov, Alexei Baevski, Angela Fan, Sam Gross, Nathan Ng, David Grangier, and Michael Auli. fairseq: A fast, extensible toolkit for sequence modeling. In Proc. of NAACL, 2019.

K. Papineni, S. Roukos, T. Ward, and W.-J. Zhu. BLEU: a method for automatic evaluation of machine translation. In Proc. of ACL, 2002.

Alec Radford, Jeff Wu, Rewon Child, David Luan, Dario Amodei, and Ilya Sutskever. Language models are unsupervised multitask learners. Technical report, OpenAI, 2019.

R. Sennrich, B. Haddow, and A. Birch. Neural machine translation of rare words with subword units. In Proc. of ACL, 2016.

Surat Teerapittayanon, Bradley McDanel, and Hsiang-Tsung Kung. BranchyNet: Fast inference via early exiting from deep neural networks. In Proc. of ICPR, 2016.
A. Vaswani, N. Shazeer, N. Parmar, J. Uszkoreit, L. Jones, A. Gomez, L. Kaiser, and I. Polosukhin. Attention is all you need. In Proc. of NeurIPS, 2017.

Xin Wang, Fisher Yu, Zi-Yi Dou, Trevor Darrell, and Joseph E. Gonzalez. SkipNet: Learning dynamic routing in convolutional networks. In Proc. of ECCV, 2018.

APPENDIX A: LOSS SCALING

In this section we experiment with different weights for scaling the output classifier losses. Instead of uniform weighting, we bias towards specific output classifiers by assigning higher weights to their losses. Table 2 shows that weighting the classifiers equally provides good results.

(a) IWSLT De-En, valid set:

| Weights | Uniform | n=1 | n=2 | n=3 | n=4 | n=5 | n=6 | Average |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| Baseline | - | 34.2 | 35.3 | 35.6 | 35.7 | 35.6 | 35.9 | 35.4 |
| ω_n = 1 | 35.5 | 34.1 | 35.5 | 35.8 | 36.1 | 36.1 | 36.2 | 35.6 |
| ω_n = n | 35.3 | 32.2 | 35.0 | 35.8 | 36.0 | 36.2 | 36.3 | 35.2 |
| ω_n = √n | 35.4 | 33.3 | 35.2 | 35.8 | 35.9 | 36.1 | 36.1 | 35.4 |
| ω_n = 1/√n | 35.6 | 34.5 | 35.4 | 35.7 | 35.8 | 35.8 | 35.9 | 35.5 |
| ω_n = 1/n | 35.3 | 34.7 | 35.3 | 35.5 | 35.7 | 35.8 | 35.8 | 35.5 |

(b) IWSLT De-En, test set:

| Weights | Uniform | n=1 | n=2 | n=3 | n=4 | n=5 | n=6 | Average |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| Baseline | - | 33.7 | 34.6 | 34.6 | 34.6 | 34.6 | 34.8 | 34.5 |
| ω_n = 1 | 34.4 | 33.2 | 34.4 | 34.8 | 34.9 | 35.0 | 34.9 | 34.5 |
| ω_n = n | 34.2 | 31.4 | 33.8 | 34.7 | 34.8 | 34.8 | 34.9 | 34.1 |
| ω_n = √n | 34.4 | 32.5 | 34.1 | 34.8 | 34.9 | 35.0 | 35.1 | 34.4 |
| ω_n = 1/√n | 34.6 | 33.7 | 34.3 | 34.6 | 34.8 | 34.8 | 34.9 | 34.5 |
| ω_n = 1/n | 34.2 | 33.8 | 34.3 | 34.5 | 34.6 | 34.7 | 34.7 | 34.4 |

Table 2: Aligned training with different weights ω_n on IWSLT De-En. For each model we report BLEU evaluated with a uniformly sampled exit n ∼ U([1..6]) for each token and with a fixed exit n ∈ [1..6] throughout the sequence. The average corresponds to the average BLEU over the fixed exits.

Gradient scaling. Adding intermediate supervision at different levels of the decoder results in richer gradients for lower blocks compared to upper blocks. This is because earlier layers affect more loss terms in the compound loss of Eq. (4). To balance the gradients of each block in the decoder, we scale up the gradients of each loss term (LL^n) when it updates the parameters of its associated block (block_n with parameters θ_n) and revert them back to their normal scale before back-propagating them to the previous blocks. Figure 10 and Algorithm 1 illustrate this gradient scaling procedure. The parameters θ_n are updated with γ_n-amplified gradients from the block's own supervision and (N − n) gradients from the supervision of the subsequent blocks. We choose γ_n = γ(N − n) to control the ratio γ:1 of the block's own supervision to the supervision of the subsequent blocks. Table 3 shows that gradient scaling can benefit the lowest layer at the expense of higher layers; however, no scaling generally works very well.

Figure 10: Illustration of gradient scaling (blocks block_n, ..., block_N with parameters θ_n, ..., θ_N and their scaled loss terms γ_n LL^n, ..., γ_N LL^N).

Algorithm 1: Pseudo-code for gradient scaling (illustrated for a single step t)
1: for n ← 1..N do
2:   h^n_t = block_n(h^{n-1}_t)
3:   p(y_{t+1} | h^n_t) = softmax(W_n h^n_t)
4:   p(y_{t+1} | h^n_t) = SCALE_GRADIENT(p(y_{t+1} | h^n_t), γ_n)
5:   if n < N then h^n_t = SCALE_GRADIENT(h^n_t, 1/γ_{n+1})
6: end for
7: function SCALE_GRADIENT(Tensor x, scale γ)
8:   return γ·x + (1 − γ)·STOP_GRADIENT(x)
9:   (STOP_GRADIENT is implemented in PyTorch with x.detach().)
10: end function
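Algorithm 1's SCALE_GRADIENT can be written directly in PyTorch. The snippet below is a minimal, self-contained sketch of that trick, with an arbitrary toy check of the identity-forward / scaled-backward behaviour; it is not taken from the paper's code.

```python
import torch

def scale_gradient(x, scale):
    # gamma * x + (1 - gamma) * stop_gradient(x): unchanged in the forward pass,
    # while the gradient flowing back through x is multiplied by `scale`.
    return scale * x + (1.0 - scale) * x.detach()

# Toy check of the backward behaviour.
x = torch.tensor([2.0], requires_grad=True)
y = scale_gradient(x, 0.5) * 3.0   # forward value is still 2.0 * 3.0 = 6.0
y.backward()
print(y.item(), x.grad.item())     # 6.0 and 1.5 (= 0.5 * 3.0)
```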
(a) IWSLT De-En, valid set:

| Ratio | Uniform | n=1 | n=2 | n=3 | n=4 | n=5 | n=6 | Average |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| Baseline | - | 34.2 | 35.3 | 35.6 | 35.7 | 35.6 | 35.9 | 35.4 |
| No scaling | 35.5 | 34.1 | 35.5 | 35.8 | 36.1 | 36.1 | 36.2 | 35.6 |
| γ = 0.3 | 35.1 | 33.7 | 34.7 | 35.3 | 35.7 | 35.8 | 36.0 | 35.2 |
| γ = 0.5 | 35.4 | 34.8 | 35.4 | 35.6 | 35.6 | 35.7 | 35.6 | 35.4 |
| γ = 0.7 | 34.9 | 34.6 | 35.1 | 35.1 | 35.2 | 35.4 | 35.3 | 35.1 |
| γ = 0.9 | 34.9 | 34.8 | 35.3 | 35.3 | 35.3 | 35.4 | 35.5 | 35.3 |
| γ = 1.1 | 35.1 | 34.9 | 35.2 | 35.3 | 35.3 | 35.3 | 35.3 | 35.2 |

(b) IWSLT De-En, test set:

| Ratio | Uniform | n=1 | n=2 | n=3 | n=4 | n=5 | n=6 | Average |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| Baseline | - | 33.7 | 34.6 | 34.6 | 34.6 | 34.6 | 34.8 | 34.5 |
| No scaling | 34.4 | 33.2 | 34.4 | 34.8 | 34.9 | 35.0 | 34.9 | 34.5 |
| γ = 0.3 | 34.2 | 32.8 | 33.9 | 34.3 | 34.6 | 34.8 | 35.0 | 34.2 |
| γ = 0.5 | 34.5 | 33.8 | 34.2 | 34.6 | 34.5 | 34.7 | 34.7 | 34.6 |
| γ = 0.7 | 34.0 | 33.7 | 34.2 | 34.3 | 34.3 | 34.3 | 34.3 | 34.2 |
| γ = 0.9 | 34.1 | 34.0 | 34.2 | 34.3 | 34.4 | 34.4 | 34.4 | 34.3 |
| γ = 1.1 | 34.2 | 34.0 | 34.3 | 34.3 | 34.3 | 34.3 | 34.2 | 34.2 |

Table 3: Aligned training with different gradient scaling ratios γ:1 on IWSLT'14 De-En. For each model we report the BLEU4 score evaluated with a uniformly sampled exit n ∼ U([1..6]) for each token and with a fixed exit n ∈ [1..6]. The average corresponds to the average BLEU4 of all fixed exits.

APPENDIX B: FLOPS APPROXIMATION

This section details the computation of the FLOPS we report. The per-token FLOPS are for the decoder network only, since we use an encoder of the same size for all models. We break down the FLOPS of every operation in Algorithm 2 (the FLOPS annotations next to each algorithmic statement). We omit non-linearities, normalizations and residual connections. The main operations we account for are dot-products, and by extension matrix-vector products, since those represent the vast majority of FLOPS (we assume batch size one to simplify the calculation).

Notation: d_d is the decoder embedding dimension, d_e the encoder embedding dimension, d_f the feed-forward network dimension, |x| the source length, t the current time-step (t ≥ 1), and V the output vocabulary size.

| Operation | FLOPS |
| --- | --- |
| Dot-product (dimension d) | 2d − 1 |
| Linear d_in → d_out | 2 d_in d_out |

Table 4: FLOPS of basic operations, and key parameters and variables for the FLOPS estimation.

With this breakdown, the total computational cost at time-step t of a decoder block that we actually go through, denoted F_C, is:

F_C(x, t) = 12 d_d^2 + 4 d_f d_d + 4 t d_d + 4 |x| d_d + 4 [[\mathrm{FirstCall}]] |x| d_d d_e,

where the cost of mapping the source keys and values is incurred the first time the block is called (flagged with FirstCall). This occurs at t = 1 for the baseline model, but it is input-dependent with adaptive depth estimation and may never occur if all tokens exit early.

If a block is skipped, it still has to compute the key and value of its self-attention sub-block so that the self-attention of future time-steps can function. We denote this cost F_S = 4 d_d^2.

Depending on the halting mechanism, an exit prediction cost, denoted F_P, is added:

Sequence-specific depth: F_P(t, q(t)) = 2 [[t = 1]] N d_d
Token-specific multinomial: F_P(t, q(t)) = 2 N d_d
Token-specific geometric-like: F_P(t, q(t)) = 2 d_d q(t)
Confidence thresholding: F_P(t, q(t)) = 2 q(t) V d_d

For a set of source sequences {x^{(i)}}_{i ∈ I} and generated hypotheses {y^{(i)}}_{i ∈ I}, with q(t) denoting the exit at time-step t, the average FLOPS per output token is:

Baseline (N blocks): \frac{1}{\sum_i |y^{(i)}|} \sum_i \sum_{t=1}^{|y^{(i)}|} \big( N F_C(x^{(i)}, t) + 2 V d_d \big)

Adaptive depth: \frac{1}{\sum_i |y^{(i)}|} \sum_i \sum_{t=1}^{|y^{(i)}|} \big( q(t) F_C(x^{(i)}, t) + (N − q(t)) F_S + F_P(t, q(t)) + 2 V d_d \big)

In the case of confidence thresholding, the final output prediction cost (2 V d_d) is already accounted for in the exit prediction cost F_P.
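For reference, the accounting above can be turned into a small script. The sketch below computes the average per-token FLOPS of a Tok geometric-like model under the stated assumptions (batch size one, halting cost 2 d_d per visited block, one final output projection per token); the function names are illustrative.

```python
def block_cost(d_d, d_e, d_f, t, src_len, first_call):
    """F_C: cost of one decoder block that is actually executed at time-step t."""
    cost = 12 * d_d**2 + 4 * d_f * d_d + 4 * t * d_d + 4 * src_len * d_d
    if first_call:
        cost += 4 * src_len * d_d * d_e          # mapping the source keys/values once per block
    return cost

def flops_per_token(exits, src_len, n_blocks, d_d, d_e, d_f, vocab):
    """exits[t-1] is the block q(t) chosen at time-step t for one hypothesis."""
    skipped_cost = 4 * d_d**2                    # F_S: key/value projection of a copied state
    first_call = [True] * (n_blocks + 1)         # FirstCall flag per block
    total = 0
    for t, q in enumerate(exits, start=1):
        for n in range(1, q + 1):                # blocks we actually go through
            total += block_cost(d_d, d_e, d_f, t, src_len, first_call[n])
            first_call[n] = False
        total += (n_blocks - q) * skipped_cost   # skipped blocks
        total += 2 * d_d * q                     # F_P for geometric-like halting
        total += 2 * vocab * d_d                 # final output projection
    return total / len(exits)
```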
Algorithm 2: Adaptive decoding with Tok-geometric-like
1: Input: source codes s, incremental state
2: Initialization: t = 1, y_1 = ⟨s⟩
3: for n ← 1..N do
4:   FirstCall[n] = True   (a flag signaling whether the source keys and values still need to be evaluated)
5: end for
6: while y_t ≠ ⟨/s⟩ do
7:   Embed the last output token y_t.
8:   for n ← 1..N do
9:     Self-attention:
10:      - Map the input into a key (k) and value (v). FLOPS = 4 d_d^2
11:      - Map the input into a query (q). FLOPS = 2 d_d^2
12:      - Score the memory keys with q to get the attention weights α. FLOPS = 4 t d_d
13:      - Map the attention output. FLOPS = 2 d_d^2
14:    Encoder-decoder interaction:
15:      if FirstCall[n] then
16:        Map the source states into keys and values for the n-th block. FLOPS = 4 |x| d_e d_d
17:        FirstCall[n] = False
18:      end if
19:      - Map the input into a query (q). FLOPS = 2 d_d^2
20:      - Score the memory keys with q to get the attention weights α. FLOPS = 4 |x| d_d
21:      - Map the attention output. FLOPS = 2 d_d^2
22:    Feed-forward network. FLOPS = 4 d_d d_f
23:    Estimate the halting probability χ_{t,n}. FLOPS = 2 d_d
24:    if χ_{t,n} > 0.5 then
25:      Exit the loop (Line 8)
26:    end if
27:  end for
28:  if n < N then
29:    Skipped blocks:
30:    for n_s ← n+1..N do
31:      Copy the state and map it into a key (k) and value (v). FLOPS = 4 d_d^2
32:    end for
33:  end if
34:  Project the final state and sample a new output token. FLOPS = 2 V d_d
35:  t ← t + 1
36: end while