# DISTILLSPEC: IMPROVING SPECULATIVE DECODING VIA KNOWLEDGE DISTILLATION

Published as a conference paper at ICLR 2024

Yongchao Zhou1,3, Kaifeng Lyu1,4, Ankit Singh Rawat1, Aditya Krishna Menon1, Afshin Rostamizadeh1, Sanjiv Kumar1, Jean-François Kagy1, Rishabh Agarwal2,5
1Google Research 2Google DeepMind 3University of Toronto 4Princeton University 5Mila

ABSTRACT

Speculative decoding (SD) accelerates large language model inference by employing a faster draft model to generate multiple tokens, which are then verified in parallel by the larger target model, so that the generated text follows the target model distribution. However, identifying a compact draft model that is well-aligned with the target model is challenging. To tackle this issue, we propose DistillSpec, a method that uses knowledge distillation to better align the draft model with the target model before applying SD. DistillSpec makes two key design choices, which we demonstrate via systematic study to be crucial for improving draft-target alignment: utilizing on-policy data generation from the draft model, and tailoring the divergence function to the task and decoding strategy. Notably, DistillSpec yields 10-45% speedups over standard SD on a range of benchmarks, using both greedy and non-greedy sampling. We show that the distilled model transfers well to various tasks, with an average speedup of 26%. Furthermore, we combine DistillSpec with lossy SD to achieve fine-grained control over the latency vs. task performance trade-off. Finally, in practical scenarios with models of varying sizes, first using distillation to boost the performance of the target model and then applying DistillSpec to train a well-aligned draft model can reduce decoding latency by 6-10x with minimal performance drop, compared to standard decoding without distillation.

1 INTRODUCTION

Large language models (LLMs) have revolutionized natural language understanding and generation across diverse applications (OpenAI, 2023; Anil et al., 2023). However, the autoregressive nature of their generation poses significant computational challenges, especially in real-time deployments with stringent latency constraints (Thoppilan et al., 2022; Pope et al., 2023). Conversely, smaller language models, while computationally efficient, often lack the expressive power of their larger counterparts and achieve subpar performance. Reducing the inference cost of larger models, e.g., via quantization or pruning, or improving the performance of smaller models, e.g., via knowledge distillation (KD) (Hinton et al., 2015), are natural approaches to enable a favorable performance versus inference cost trade-off, but they frequently result in an unacceptable performance gap compared to the high-quality large models. This has inspired a growing literature on mechanisms that combine large and small models at inference time to approximate the performance of the larger models without incurring their high computational cost. Among conventional approaches, model cascading aims to identify easy instances where smaller models suffice to achieve good performance, and to solicit larger models only on a subset of hard instances (Rowley et al., 1998; Xu et al., 2014) or tasks (Cai et al., 2023b).
Different from such task- or instance-level cascading, speculative decoding (SD) (Leviathan et al., 2023; Chen et al., 2023) exploits the token-level variability in the computation demand during LLM inference by interactively invoking a small draft model and a large target model. At a given stage during inference, the draft model generates successive candidate tokens for multiple inference steps via autoregressive decoding. The target model then verifies the candidate tokens via parallel decoding, and employs rejection sampling to accept a subset of candidate tokens at contiguous positions. The main objective of SD is to speed up text generation while guaranteeing that the decoded tokens follow the target model distribution. SD relies on the insight that the combined cost of autoregressive decoding with a small draft model followed by parallel verification with the target model is lower than the cost of autoregressive decoding with the target model alone. However, the realized inference cost reduction or latency improvement crucially depends on the acceptance rate of the draft-generated tokens by the target model, which can be shown to be directly tied to the alignment between the token distributions of the draft and target models. Thus, a successful application of SD hinges on identifying a compact draft model that simultaneously has small autoregressive decoding cost and is closely aligned with the target model.

Figure 1: Performance comparison of standard speculative decoding (SD) vs. our proposed DistillSpec, with small- and XL-sized models from the T5 v1.1 family (Raffel et al., 2020) used as the draft and target models, respectively. DistillSpec enhances SD speed by better aligning the draft with the target via white-box knowledge distillation, resulting in a consistent 10-45% speedup improvement over standard SD across various datasets, under both greedy and non-greedy sampling. The distilled draft model from GSM8K transfers well to 23 unseen BIG-Bench Hard tasks (Suzgun et al., 2022), resulting in an average speedup of 26%. See §5.1 for additional details.

In this work, we propose DistillSpec, a novel approach that relies on KD (Hinton et al., 2015) to obtain an effective draft model. Unlike the standard application of KD, which primarily focuses on improving the task performance of a small student model, DistillSpec aims at aligning the student (draft) model with the teacher (target) model to enhance the acceptance rate during SD. We undertake a comprehensive exploration of the distillation process for speeding up SD, considering several factors including the composition of training data, the choice of divergence function defining the training objective for KD, and the decoding strategy. Notably, our findings underscore that using model-generated data is crucial for ensuring strong student-teacher alignment across various tasks via KD, and that the selection of the best-performing divergence function in DistillSpec is highly task-dependent and sensitive to the decoding strategy (i.e., greedy versus non-greedy). Furthermore, we explore the utility of DistillSpec for lossy SD (Leviathan et al., 2023), which allows for sampling away from the target model distribution.
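To make the link between alignment and acceptance concrete: under SD's verification rule, a draft token sampled from q is accepted with probability min(1, p(y)/q(y)), so the marginal per-token acceptance probability is the sum over the vocabulary of min(p(y), q(y)), which equals 1 - TV(p, q), one minus the total variation distance between the two next-token distributions (Leviathan et al., 2023). The NumPy sketch below illustrates this relationship; the toy distributions and numbers are illustrative only and not taken from the paper.

```python
import numpy as np

def token_acceptance_prob(p: np.ndarray, q: np.ndarray) -> float:
    """Probability that a token sampled from draft q is accepted by target p.

    Under speculative decoding's verification rule, a draft token y ~ q is
    accepted with probability min(1, p(y) / q(y)). Marginalizing over y gives
    sum_y min(p(y), q(y)) = 1 - TV(p, q): acceptance is exactly one minus the
    total variation distance between the two next-token distributions.
    """
    return float(np.minimum(p, q).sum())

# Toy 4-token vocabulary (illustrative numbers only).
p = np.array([0.70, 0.20, 0.05, 0.05])            # target next-token distribution
q_misaligned = np.array([0.25, 0.25, 0.25, 0.25])  # poorly aligned draft
q_aligned = np.array([0.65, 0.25, 0.05, 0.05])     # well aligned draft

print(token_acceptance_prob(p, q_misaligned))  # 0.55: many draft tokens rejected
print(token_acceptance_prob(p, q_aligned))     # 0.95: better alignment, higher acceptance
```

The better the draft matches the target at each position, the more drafted tokens survive verification per step, which is exactly the quantity DistillSpec optimizes through distillation.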
We show that combining DistillSpec with lossy SD enables more fine-grained control over the latency versus task performance trade-off. Finally, we carry out a systematic study of how to design an efficient inference scheme in a practical setting where one has access to multiple models of increasing size and quality. Leveraging the insights that we have laid out about KD and SD, our study concludes that the most effective strategy involves first distilling a large model into a smaller one as the potential target model for performance optimization, followed by applying DistillSpec to distill an even smaller model to be used as the draft model in SD. This approach results in a remarkable 6-10x reduction in latency, compared to a standalone non-distilled target model of the same size, with minimal performance degradation.

Our key contributions are: (i) We propose DistillSpec, a method that uses KD to enhance draft model alignment with the target model (§4), and show that our method can improve SD speed by 10-45% while preserving model performance across diverse datasets under greedy and non-greedy sampling (Figure 1). (ii) We conduct an extensive analysis of the optimal distillation recipe (§5.2) for model alignment, encompassing factors such as training data generation and different divergences, and emphasizing the distinctions between standard KD and distillation tailored for SD. (iii) We extend DistillSpec to lossy SD, enabling refined control over the quality-latency trade-off. Moreover, we offer insights for combining KD and SD when several models are available (§5.3).

2 RELATED WORK

Speculative decoding (SD). Due to the inherent sequential nature of autoregressive decoding, the primary latency bottleneck in LLM inference arises from memory read/write operations rather than arithmetic computations (Pope et al., 2023). Speculative decoding (SD) (Leviathan et al., 2023; Chen et al., 2023) addresses this challenge by utilizing a compact draft model to generate a batch of tokens sequentially, while validating them in parallel with a larger target model. Prior to SD, various parallel computing paradigms have been explored for autoregressive models, including block parallel sampling (Stern et al., 2018), shallow aggressive decoding (Sun et al., 2021), and aggressive decoding (Ge et al., 2022). However, these approaches are not readily adaptable to typical language models due to potential deviations from the target model's distribution, strict input constraints, or limited support for general stochastic sampling. Notably, recent variants of SD have considered different interactions between the draft and target model to reduce unnecessary computation (Kim et al., 2023) and incorporated parallel computation along the batch axis, sometimes combined with token tree verification, as seen in SpecTr (Sun et al., 2023), SpecInfer (Miao et al., 2023), and Medusa (Cai et al., 2023a). In contrast, our work focuses on enhancing SD by improving the alignment between the small draft model and the large target model through KD, which does not require any changes to serving infrastructures already implementing SD and is complementary to the recent variants of SD. Furthermore, we conduct a systematic study of lossy SD for providing nuanced control over the trade-off between quality and latency for specific serving models.

Knowledge distillation (KD) for LLMs. KD (Buciluǎ et al., 2006; Hinton et al., 2015), which trains high-quality smaller student models under the supervision of larger teacher models, has emerged as a vital technique for reducing inference cost while maintaining model quality across a range of domains. In the context of LLMs, prior uses of KD (Taori et al., 2023; Fu et al., 2023) have mostly focused on black-box KD, wherein only the teacher's output generations, generally obtained via APIs, are accessible during student training. However, with the proliferation of open-source LLMs (Zhang et al., 2022; Touvron et al., 2023), which enable access to teacher weights and logits, there is a growing interest in white-box KD. White-box KD allows student models to benefit from the richer supervision signals provided by white-box teacher models, leading to enhanced language abilities (Agarwal et al., 2023; Gu et al., 2023; Wen et al., 2023). Unlike prior works focused on creating highly capable standalone student models, we harness KD to foster closer collaboration between smaller and larger models in SD, which may be particularly valuable when a small distilled model alone cannot meet stringent quality requirements. While Stern et al. (2018) use a black-box KD approach (SeqKD) to enhance blockwise parallel decoding, their samples are generated from the large target model, which is prohibitively expensive for LLMs. Furthermore, they ignore the teacher model's logits and train their draft model using only one-hot teacher labels, a reasonable choice for greedy decoding but a less effective one for non-greedy sampling (Figure 2). Concurrently, Liu et al. (2023) propose to improve SD using KD, but they assume an online setup with a changing query distribution, and focus on improving the acceptance rate rather than reducing the actual latency.
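To make the contrast above concrete, the sketch below shows what one white-box distillation step on on-policy (draft-generated) data can look like: both models score sequences sampled from the draft model, and the draft is trained to match the target's token-level distribution rather than one-hot labels. The function name, the forward-KL objective, the temperature, and the stand-in tensors are illustrative assumptions, not the exact DistillSpec recipe (which is specified in §4 and ablated in §5.2).

```python
import torch
import torch.nn.functional as F

def white_box_onpolicy_kd_step(draft_logits: torch.Tensor,
                               target_logits: torch.Tensor,
                               temperature: float = 1.0) -> torch.Tensor:
    """One illustrative white-box KD step on draft-generated sequences.

    Both logit tensors have shape [batch, seq_len, vocab] and are computed on
    the *same* sequences, sampled from the draft (student) model itself
    (on-policy data). The loss is the token-level forward KL, KL(target || draft);
    DistillSpec studies several divergences and finds the best choice to be
    task- and decoding-dependent, so treat this one as a placeholder.
    """
    t_logp = F.log_softmax(target_logits / temperature, dim=-1)
    d_logp = F.log_softmax(draft_logits / temperature, dim=-1)
    kl_per_token = (t_logp.exp() * (t_logp - d_logp)).sum(dim=-1)  # [batch, seq_len]
    return kl_per_token.mean()

# Stand-in tensors; in practice they come from running the target and draft
# models over prompts completed by sampling from the draft model.
batch, seq_len, vocab = 2, 8, 32
draft_logits = torch.randn(batch, seq_len, vocab, requires_grad=True)  # student
target_logits = torch.randn(batch, seq_len, vocab)                     # teacher (frozen)

loss = white_box_onpolicy_kd_step(draft_logits, target_logits)
loss.backward()  # gradients update only the draft model's parameters
```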
3 BACKGROUND: SPECULATIVE DECODING

Notation. Given an input sequence x comprising tokens from a pre-defined vocabulary, a language model M provides a distribution over possible output sequences y. Suppose we employ SD with a compact draft model Mq to assist a larger target model Mp. Let p(yt | x, y<t) and q(yt | x, y<t), abbreviated pt(y) and qt(y), denote the next-token distributions of the target model Mp and the draft model Mq, respectively, conditioned on the input x and the previously decoded tokens y<t.

Speculative decoding step. Given the current context {x, y<t} and a block size γ, the draft model Mq first samples γ candidate tokens autoregressively, with ỹt+i ~ qt+i(y) for i = 0, ..., γ-1. The target model Mp then computes pt(y), ..., pt+γ(y) for all positions in a single parallel pass. Drawing rt+i ~ U(0, 1) independently for each position, the number of accepted tokens is

n = min({i | rt+i > pt+i(ỹt+i) / qt+i(ỹt+i)} ∪ {γ}),

i.e., the offset of the first rejected candidate, or γ if every candidate is accepted. If n < γ, the token at the first rejected position is resampled from the adjusted distribution, yt+n ~ norm(max(0, pt+n(y) - qt+n(y))); otherwise, an extra token is sampled directly from the target, yt+n ~ pt+n(y). The step returns {x, y≤t+n}, and this accept/resample rule guarantees that the decoded tokens follow the target model distribution (Leviathan et al., 2023; Chen et al., 2023).
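A minimal, runnable NumPy sketch of one such step is given below. The callables `draft_next_dist` and `target_next_dist` are hypothetical stand-ins for Mq and Mp that map a token-id prefix to a full next-token probability vector; in a real system the γ+1 target distributions come from a single batched forward pass, which is what makes the step cheaper than γ+1 sequential target calls.

```python
import numpy as np

rng = np.random.default_rng(0)

def speculative_decoding_step(prefix, draft_next_dist, target_next_dist, gamma):
    """One speculative decoding step over toy categorical distributions.

    prefix: list of token ids decoded so far ({x, y<t} flattened to ids).
    draft_next_dist / target_next_dist: callables mapping a prefix to a
        normalized next-token probability vector (stand-ins for Mq and Mp).
    gamma: block size, i.e. number of candidate tokens drafted per step.
    """
    # 1) Draft model proposes gamma candidate tokens autoregressively.
    drafted, q_dists = [], []
    for _ in range(gamma):
        q = draft_next_dist(prefix + drafted)
        tok = int(rng.choice(len(q), p=q))
        drafted.append(tok)
        q_dists.append(q)

    # 2) Target model scores all gamma + 1 positions (one parallel pass in practice).
    p_dists = [target_next_dist(prefix + drafted[:i]) for i in range(gamma + 1)]

    # 3) Accept candidate i with probability min(1, p_i(tok) / q_i(tok)).
    accepted = []
    for i, tok in enumerate(drafted):
        if rng.random() < min(1.0, p_dists[i][tok] / q_dists[i][tok]):
            accepted.append(tok)
        else:
            # First rejection: resample from the adjusted distribution
            # norm(max(0, p - q)), then stop.
            residual = np.maximum(p_dists[i] - q_dists[i], 0.0)
            residual /= residual.sum()
            accepted.append(int(rng.choice(len(residual), p=residual)))
            return prefix + accepted

    # 4) All gamma candidates accepted: sample one bonus token from the target.
    accepted.append(int(rng.choice(len(p_dists[gamma]), p=p_dists[gamma])))
    return prefix + accepted
```

The number of tokens accepted per step, and hence the realized speedup, grows with how often pt+i(ỹt+i)/qt+i(ỹt+i) is close to or above one, which is precisely the draft-target alignment that DistillSpec improves through distillation.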