# Keypoint-based Progressive Chain-of-Thought Distillation for LLMs

Kaituo Feng 1, Changsheng Li 1, Xiaolu Zhang 2, Jun Zhou 2, Ye Yuan 1, Guoren Wang 1 3

1 Beijing Institute of Technology, 2 Ant Group, 3 Hebei Province Key Laboratory of Big Data Science and Intelligent Technology. Correspondence to: Changsheng Li.

Proceedings of the 41st International Conference on Machine Learning, Vienna, Austria. PMLR 235, 2024. Copyright 2024 by the author(s).

Abstract

Chain-of-thought distillation is a powerful technique for transferring reasoning abilities from large language models (LLMs) to smaller student models. Previous methods typically require the student to mimic the step-by-step rationale produced by LLMs, often facing the following challenges: (i) Tokens within a rationale vary in significance, and treating them equally may fail to accurately mimic keypoint tokens, leading to reasoning errors. (ii) They usually distill knowledge by consistently predicting all the steps in a rationale, which falls short in distinguishing the learning order of step generation. This diverges from the human cognitive progression of starting with easy tasks and advancing to harder ones, resulting in sub-optimal outcomes. To this end, we propose a unified framework, called KPOD, to address these issues. Specifically, we propose a token weighting module utilizing mask learning to encourage accurate mimicry of keypoint tokens by the student during distillation. In addition, we develop an in-rationale progressive distillation strategy, starting with training the student to generate the final reasoning steps and gradually extending to cover the entire rationale. To accomplish this, a weighted token generation loss is proposed to assess step reasoning difficulty, and a value function is devised to schedule the progressive distillation by considering both step difficulty and question diversity. Extensive experiments on four reasoning benchmarks illustrate that our KPOD outperforms previous methods by a large margin.

1. Introduction

Large language models (LLMs) have demonstrated remarkable reasoning capabilities via chain-of-thought (CoT) prompting (e.g., "Let's think step-by-step"), which prompts LLMs to generate a step-by-step rationale to help reasoning (Kojima et al., 2022; Wei et al., 2022). However, such abilities usually emerge in extremely large models, especially those with over 100 billion parameters (Fu et al., 2023; Hoffmann et al., 2022), such as the 175B GPT-3 (Brown et al., 2020) and the 540B PaLM (Chowdhery et al., 2023). The substantial number of parameters unavoidably leads to high inference costs and makes it challenging to deploy LLMs in environments with limited computational resources (Hsieh et al., 2023). To tackle this, a recent surge of works, known as CoT distillation, has arisen as a promising avenue to distill the reasoning capabilities of LLMs into smaller student models (Li et al., 2023; Wang et al., 2023b; Fu et al., 2023). The core idea of these methods is to require the student model to mimic the step-by-step rationale generated by LLMs in response to a question.

However, current CoT distillation methods often encounter the following two issues: First, in a rationale, each token carries a different level of importance in the reasoning process. Certain keypoint tokens play a pivotal role in reasoning, while other tokens are of less importance or even irrelevant to the reasoning process.
For instance, consider a step in a rationale: "Next, we just need to simply add up the calories from the lettuce and cucumber: 30 + 80 = 110." Here, terms like "just" and "simply" are reasoning-irrelevant, whereas the calculation "30 + 80 = 110" stands out as the keypoint for reasoning. The reasoning-irrelevant tokens can be replaced without negative effects, but even a slight deviation from a keypoint token could result in reasoning errors. Therefore, it's crucial for the student model to focus on the precise mimicry of these keypoint tokens. Nevertheless, previous CoT distillation methods usually treat all tokens equally during distillation (Li et al., 2023; Wang et al., 2023b).

The second issue stems from the fact that previous approaches usually demand the student model to consistently learn all the steps in a rationale throughout the distillation process, without distinguishing the learning order of step generation. This distillation strategy diverges from the human cognitive pattern that progresses from easier tasks to more challenging ones, which might lead to sub-optimal outcomes. In the process of human or biological agent learning, ability acquisition doesn't simply stem from random tasks (Molina & Jouen, 1998). Instead, there is an organized progression from easy tasks to hard tasks in acquiring capabilities, especially for complex skills such as reasoning (Peterson, 2004; Krueger & Dayan, 2009; Benoit et al., 2013). In the field of machine learning, this ordered learning paradigm is known as curriculum learning (Bengio et al., 2009). Inspired by this, we intend to develop a progressive CoT distillation strategy that helps the student model acquire reasoning ability from easy to hard. However, directly applying previous curriculum learning strategies to CoT distillation could be inferior for the following two reasons: (i) They overlook the step-by-step reasoning nature, where each reasoning step within a rationale may possess varying reasoning difficulty, resulting in sub-optimal difficulty assessment. (ii) As aforementioned, a step in the rationale might contain many tokens that are not crucial to the reasoning process. When assessing the difficulty of step generation, the assessment may be dominated by these inessential tokens, thereby inaccurately reflecting the challenge of obtaining the expected outcome of a reasoning step.

In this paper, we propose Keypoint-based Progressive CoT Distillation for LLMs, dubbed KPOD, with the goal of addressing the above two issues in a unified framework. First, we propose a rationale token weighting module to determine token significance for distillation. It learns to generate masks for tokens that are inessential to the reasoning process via two loss functions: an answer prediction loss is introduced to encourage the module to derive the answer from the question together with the masked rationale, while a mask ratio loss is designed to maximize the ratio of masked tokens in the rationale. By doing so, the obtained probability of not masking a token can serve as an indicator of its significance weight.
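To make the mask-learning idea above concrete, the following is a minimal PyTorch-style sketch of such a token weighting module under assumed implementation details (the module architecture, the soft masking by keep probability, the toy answer head, and the 0.1 trade-off weight are illustrative choices, not the paper's exact design): each rationale token receives a keep probability, the answer prediction loss is computed on the question plus the softly masked rationale, and the mask ratio loss pushes the average keep probability down so that only reasoning-critical tokens retain high weight.

```python
import torch
import torch.nn as nn

class RationaleTokenWeighting(nn.Module):
    """Sketch: learn a per-token keep probability for a rationale.

    Tokens that can be masked without hurting answer prediction end up with a
    low keep probability (low significance); tokens the answer depends on keep
    a high probability (keypoint tokens).
    """

    def __init__(self, hidden_dim: int):
        super().__init__()
        # Tiny scorer over contextual token states -> keep probability in (0, 1).
        self.scorer = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim), nn.Tanh(), nn.Linear(hidden_dim, 1)
        )

    def forward(self, rationale_states: torch.Tensor) -> torch.Tensor:
        # rationale_states: (batch, seq_len, hidden_dim) contextual embeddings.
        return torch.sigmoid(self.scorer(rationale_states)).squeeze(-1)


def weighting_losses(keep_prob, rationale_emb, question_emb, answer_labels, answer_head):
    """The two losses described above (implementation details assumed):
    (1) answer prediction from the question plus the softly masked rationale,
    (2) mask-ratio loss that encourages masking as many tokens as possible.
    """
    masked_rationale = rationale_emb * keep_prob.unsqueeze(-1)       # soft masking
    pooled = torch.cat([question_emb.mean(dim=1), masked_rationale.mean(dim=1)], dim=-1)
    answer_logits = answer_head(pooled)                              # toy answer classifier
    loss_answer = nn.functional.cross_entropy(answer_logits, answer_labels)
    loss_mask_ratio = keep_prob.mean()   # minimizing the keep ratio = maximizing the mask ratio
    return loss_answer, loss_mask_ratio


# Toy usage: the keep probability of each token serves as its significance weight.
hidden = 64
weighter = RationaleTokenWeighting(hidden)
answer_head = nn.Linear(2 * hidden, 10)           # assumed 10 candidate answers
question = torch.randn(2, 12, hidden)             # (batch, question_len, hidden)
rationale = torch.randn(2, 20, hidden)            # (batch, rationale_len, hidden)
answers = torch.randint(0, 10, (2,))
keep_prob = weighter(rationale)                   # (batch, rationale_len)
loss_answer, loss_mask_ratio = weighting_losses(keep_prob, rationale, question, answers, answer_head)
(loss_answer + 0.1 * loss_mask_ratio).backward()  # 0.1 is an assumed trade-off weight
```

After training, the probability of keeping (not masking) each token is read off as its significance weight for the distillation loss.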
Second, we develop an in-rationale progressive distillation strategy that orders the learning sequence from easy reasoning to hard reasoning within the rationale of a question. This strategy begins by training the student model to generate the last few reasoning steps of the rationale, given the question with the preceding steps of this rationale as input. Subsequently, it progressively extends to generating the entire rationale using only the question as input. To precisely assess each step's reasoning difficulty, we propose a token generation loss based on the derived token significance, aiming to eliminate the negative effects of reasoning-irrelevant tokens. Finally, we design a value function to dynamically determine the number of steps taken as input at each stage, thereby automatically adjusting the learning difficulty. Meanwhile, we leverage the value function to select diverse questions, so as to prevent over-fitting (Jiang et al., 2014; Liang et al., 2021).

Our contributions can be summarized as follows: 1) We propose a general and principled framework for CoT distillation, which simultaneously considers token significance and reasoning difficulty within a rationale during distillation. 2) We design a rationale token weighting module based on mask learning to determine the token significance for reasoning. This allows the student to concentrate more on keypoint tokens. 3) We devise an in-rationale progressive CoT distillation strategy to schedule the learning order of reasoning steps within a rationale. This enables the student to progressively acquire reasoning abilities in an easy-to-hard manner. 4) Extensive experiments on four reasoning benchmarks validate the effectiveness of our KPOD, showcasing significant performance improvements compared to baselines.

2. Related Works

Chain-of-Thought Reasoning. The concept of employing step-by-step language rationales to aid in solving reasoning problems can be traced back to pioneering work (Ling et al., 2017). Inspired by this, chain-of-thought prompting (Wei et al., 2022) has been proposed to enable LLMs to generate intermediate reasoning steps that contribute to the final answer via few-shot CoT demonstrations. This prompting approach has shown remarkable performance gains for LLMs in reasoning-related tasks (Zhang et al., 2022; Wang et al., 2023a). In addition, researchers have found that LLMs can also obtain impressive reasoning performance with zero-shot CoT (Kojima et al., 2022), without task-related demonstrations, by using only the single sentence "Let's think step by step" for prompting. Recently, a number of CoT prompting methods have demonstrated effectiveness in enhancing the reasoning performance of LLMs (Diao et al., 2023; Yang et al., 2023), such as SC-CoT (Wang et al., 2022), Auto-CoT (Zhang et al., 2022), Multimodal-CoT (Zhang et al., 2023), etc. However, the emergence of CoT reasoning capabilities in LLMs typically requires models with more than 100 billion parameters (Wei et al., 2022; Fu et al., 2023), making them resource-consuming to deploy.

CoT Distillation. Knowledge distillation has been widely studied for model compression across various fields (Magister et al., 2023; Feng et al., 2024). Recently, CoT distillation has emerged as a promising avenue to transfer the step-by-step reasoning capabilities of LLMs to smaller student models (Hsieh et al., 2023; Ho et al., 2023). The key idea of CoT distillation is to make the student model mimic the step-by-step rationale generated by LLMs in response to a question. In this context, the rationale can be interpreted as the LLM's explanation of how to derive the final answer of a question, akin to the soft label used in conventional knowledge distillation (Hinton et al., 2015; Feng et al., 2022).
Representative works on CoT distillation include the following: SCoTD (Li et al., 2023) introduces a symbolic CoT distillation method that enables smaller models to self-rationalize for reasoning by learning rationales from LLMs. Specialized KD (Fu et al., 2023) is proposed to train a small language model specialized for reasoning in four distinct in-context scenarios. MCC-KD (Chen et al., 2023) adopts diverse rationales for distillation and attempts to ensure their consistency. SCOTT (Wang et al., 2023b) designs a faithful CoT distillation strategy to make the student reason faithfully via counterfactual training. However, these methods fail to consider a reasonable learning order for the reasoning steps within a rationale, leading to sub-optimal performance.

Curriculum Learning. Early research in cognitive science emphasizes the significance of the easy-to-hard learning pattern for acquiring knowledge (Elman, 1993). Inspired by this, the pioneering work of Bengio et al. (2009) introduces the concept of curriculum learning (CL) to the machine learning field by gradually including samples from easy to hard during training. In recent years, a variety of CL methods have been proposed to enhance model performance (Kong et al., 2021; Wang et al., 2021). For instance, Adaptive CL (Kong et al., 2021) proposes to utilize the loss of the model to dynamically adjust the difficulty score of each sample. SPL (Wan et al., 2020) brings curriculum learning to the neural machine translation domain by introducing token-level and sentence-level confidence scores. ICL (Jia et al., 2023) devises a curriculum learning method that organizes the curriculum within the token sequence of a sample for natural language generation tasks. However, as aforementioned, applying these CL methods directly to CoT distillation could yield inferior performance.

3. Proposed Method

3.1. Preliminaries and Problem Setting

The goal of CoT distillation is to transfer the reasoning capability of large language models (LLMs) to smaller student models by distilling the rationales produced by LLMs. We denote the dataset as $\mathcal{D} = \{(x^{(i)}, y^{(i)})\}$, where $x^{(i)}$ is the $i$-th reasoning question and $y^{(i)}$ is the corresponding answer. Following previous CoT distillation works (Ho et al., 2023; Chen et al., 2023), we adopt zero-shot CoT (Kojima et al., 2022) to prompt the teacher LLM to generate a step-by-step rationale $r^{(i)}$ for each question $x^{(i)}$. The reasoning template takes the following format: "Q: $\langle x^{(i)} \rangle$. A: $\langle p \rangle$, $\langle r^{(i)} \rangle$. Therefore, the answer is $\langle y^{(i)} \rangle$.", where $\langle p \rangle$ is the zero-shot CoT prompt such as "Let's think step by step".
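As a concrete illustration of this template, the snippet below assembles the zero-shot CoT teacher prompt and the corresponding student training pair; the helper names and the exact string layout are assumptions for illustration rather than the paper's verbatim format.

```python
COT_TRIGGER = "Let's think step by step."   # the zero-shot CoT prompt <p>

def build_teacher_prompt(question: str) -> str:
    # Sent to the teacher LLM to elicit a step-by-step rationale.
    return f"Q: {question}\nA: {COT_TRIGGER}"

def build_student_example(question: str, rationale: str, answer: str) -> dict:
    # The student learns to generate the rationale and answer given the question.
    return {
        "input": f"Q: {question}\nA: {COT_TRIGGER}",
        "target": f"{rationale}\nTherefore, the answer is {answer}.",
    }

example = build_student_example(
    question="A salad has 30 calories of lettuce and 80 calories of cucumber. How many calories in total?",
    rationale="The lettuce has 30 calories and the cucumber has 80 calories. 30 + 80 = 110.",
    answer="110",
)
print(example["input"])
print(example["target"])
```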
Then, the student is trained to generate the concatenated sequence of rationale tokens $r^{(i)}$ and answer tokens $y^{(i)}$, given the question $x^{(i)}$ as input. The standard negative log-likelihood loss for training the student model can be formulated as:

$$\mathcal{L}_{\mathrm{nll}} = -\sum_{i} \Big( \sum_{j} \log P\big(r^{(i)}_{j} \mid r^{(i)}_{<j}, x^{(i)}\big) + \sum_{j} \log P\big(y^{(i)}_{j} \mid y^{(i)}_{<j}, r^{(i)}, x^{(i)}\big) \Big). \quad (1)$$

To schedule the in-rationale progressive distillation from easy to hard, we let the overall learning difficulty $D(t)$ at stage $t$ grow at the rate $\frac{\mathrm{d}D(t)}{\mathrm{d}t} = u\, t^{p}$, where $p > 0$ and $u > 0$ are the parameters to control the growth rate. By integrating the growth rate with respect to $t$, we can derive $D(t)$ as:

$$D(t) = \frac{u\, t^{p+1}}{p+1} + C_0, \quad (10)$$

where $C_0$ represents the initial overall learning difficulty at stage 0. By letting $D(t)$ reach the maximum difficulty $B$ of the dataset at stage $T$, i.e., $D(T) = B = \sum_{i} \sum_{j=1}^{n_i} d^{(i)}_{j}$, we can derive $u = \frac{(B - C_0)(p+1)}{T^{p+1}}$, where $p$ and $C_0$ are pre-defined hyper-parameters.

When entering stage $t$ from stage $t-1$, it is required to select a set of questions whose difficulty is increased. We achieve this by reducing the number of input steps by $s$ for the selected questions:

$$c_i(t) = c_i(t-1) - q_i(t)\, s, \quad \text{s.t.} \;\; \Delta H(S(t)) \le \Delta D(t), \quad (11)$$

where $c_i(t)$ is the scheduled number of input steps of the $i$-th question at stage $t$. Let $S(t)$ denote the set of questions selected for increased difficulty at stage $t$. Then, $q_i(t) \in \{0, 1\}$ indicates whether question $i$ belongs to $S(t)$: if $i \in S(t)$, then $q_i(t) = 1$; otherwise, $q_i(t) = 0$. $s$ is the pre-defined number of input steps to remove. $\Delta H(S(t)) = \sum_{i} h_i(S(t)) - \sum_{i} h_i(S(t-1))$ is the total increased difficulty, and $\Delta D(t) = D(t) - \sum_{i} h_i(S(t-1))$ is the ceiling magnitude for the increased difficulty.

Then, in order to determine whether a question should have its difficulty increased, we design a value function $F$. The goal of this value function is two-fold: one is to align the increased difficulty as closely as possible with the defined magnitude, and the other is to ensure a diverse set of questions for escalating difficulty, so as to prevent over-fitting (Jiang et al., 2014). The value function $F$ is designed as:

$$F(S(t)) = -\big(\Delta D(t) - \Delta H(S(t))\big) + \beta \sum_{k=1}^{K} \sqrt{|S(t) \cap C_k|}, \quad (12)$$

where $\beta$ is a trade-off hyper-parameter. The first term measures the closeness of $\Delta H(S(t))$ to $\Delta D(t)$, and the second term measures the diversity of the selected question set based on clustering. Specifically, $C_k$ is the question set of the $k$-th cluster and $K$ is the number of clusters. In this paper, we conduct K-means clustering (Bradley et al., 2000) to cluster the questions based on their embeddings, computed as the average of the GloVe (Pennington et al., 2014) word embeddings. By using the square root operation, we aim to promote a balanced distribution of the selected questions across clusters, which ensures that the diversity of the chosen question set is maintained. The optimization of $F(S(t))$ can be formulated as:

$$\max_{S(t)} F(S(t)), \quad \text{s.t.} \;\; \Delta H(S(t)) \le \Delta D(t). \quad (13)$$

By maximizing $F(S(t))$, we achieve the goal of selecting diverse questions to increase difficulty while staying close to $\Delta D(t)$. However, this is a combinatorial optimization problem subject to a knapsack constraint, and solving it is known to be NP-hard. Fortunately, we can prove that $F(S(t))$ is monotone and submodular. Therefore, it can be approximately solved by the submodular maximization algorithm FTGP (Li et al., 2022) in linear time with an approximation ratio guarantee, as formulated in Proposition 3.1. The proof of Proposition 3.1 can be found in Appendix D.

Proposition 3.1. The optimization of $\max_{S(t)} F(S(t))$ subject to the knapsack constraint $\Delta H(S(t)) \le \Delta D(t)$ can be approximately solved in $O(n \epsilon^{-1} \log \epsilon^{-1})$ time with a $\frac{1}{2} - \epsilon$ approximation ratio guarantee, where $n$ represents the scale of the data.
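For intuition about this question-selection step, the toy sketch below implements the value function of Eq. (12) together with a plain cost-scaled greedy selection under the difficulty budget. The greedy rule is only a simple stand-in for the FTGP solver (so it carries no formal guarantee), and the function names, data structures, and numbers are illustrative assumptions.

```python
import math

def value_F(selected, delta_h, clusters, delta_D, beta=1.0):
    """Value of a selected question set S(t): closeness of the added
    difficulty to the budget, plus a square-root diversity bonus over clusters."""
    added = sum(delta_h[i] for i in selected)
    diversity = sum(math.sqrt(len(selected & cluster)) for cluster in clusters)
    return -(delta_D - added) + beta * diversity

def greedy_select(candidates, delta_h, clusters, delta_D, beta=1.0):
    """Cost-scaled greedy stand-in for the FTGP solver: repeatedly add the
    question with the largest marginal gain per unit of added difficulty,
    while the knapsack constraint (added difficulty <= budget) still holds."""
    selected = set()
    remaining = set(candidates)
    while remaining:
        used = sum(delta_h[i] for i in selected)
        base = value_F(selected, delta_h, clusters, delta_D, beta)
        best, best_gain = None, 0.0
        for i in remaining:
            if used + delta_h[i] > delta_D:
                continue  # would exceed the difficulty budget
            gain = (value_F(selected | {i}, delta_h, clusters, delta_D, beta) - base) / max(delta_h[i], 1e-8)
            if gain > best_gain:
                best, best_gain = i, gain
        if best is None:
            break
        selected.add(best)
        remaining.remove(best)
    return selected

# Toy usage: six questions in two clusters, with per-question difficulty increments.
delta_h = {0: 1.0, 1: 0.5, 2: 2.0, 3: 0.8, 4: 1.2, 5: 0.4}
clusters = [frozenset({0, 1, 2}), frozenset({3, 4, 5})]
print(sorted(greedy_select(range(6), delta_h, clusters, delta_D=2.5)))
```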
After obtaining the scheduled number of input steps $c_i(t)$ by solving Eq. (13), the rationale distillation loss at stage $t$ can be formulated as:

$$\mathcal{L}(t) = -\sum_{i} \sum_{j = p_{c_i(t)}+1} \log P\big(r^{(i)}_{j} \mid r^{(i)}_{<j}, x^{(i)}\big),$$

where $p_{c_i(t)}$ denotes the index of the last rationale token of the $c_i(t)$-th reasoning step, i.e., the first $c_i(t)$ steps of the rationale are provided to the student as input together with the question $x^{(i)}$.
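Finally, a small self-contained sketch of how the difficulty budget D(t) and the stage-wise loss masking could be wired together; the helper names, the step-boundary bookkeeping, and the toy numbers are assumptions made for illustration.

```python
def difficulty_budget(t, T, B, C0, p):
    """Overall difficulty budget D(t) at stage t (Eq. 10); the growth coefficient u
    is chosen so that D(T) equals the dataset's total difficulty B."""
    u = (B - C0) * (p + 1) / (T ** (p + 1))
    return u * t ** (p + 1) / (p + 1) + C0

def stage_target_mask(num_tokens, step_last_token, c_t):
    """Which rationale token positions are training targets at this stage:
    tokens of the first c_t steps are given as input (not scored), the rest
    must be generated by the student. `step_last_token[k]` is the index of the
    last token of step k (an assumed bookkeeping convention)."""
    cut = step_last_token[c_t - 1] + 1 if c_t > 0 else 0
    return [j >= cut for j in range(num_tokens)]

# Toy usage: 4 stages; one rationale with 3 steps spanning 12 tokens.
T, B, C0, p = 4, 10.0, 2.0, 1.0
for t in range(T + 1):
    print(f"stage {t}: difficulty budget D(t) = {difficulty_budget(t, T, B, C0, p):.2f}")

step_last_token = [3, 8, 11]                          # last token index of each step
print(stage_target_mask(12, step_last_token, c_t=2))  # only the last step's tokens are targets
```

In this sketch, tokens belonging to the first c_i(t) scheduled input steps are fed to the student and excluded from the loss, so the loss at stage t only covers the remaining steps, matching the easy-to-hard schedule described above.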