Instruction-Following Pruning for Large Language Models

Bairu Hou 1 2  Qibin Chen 1  Jianyu Wang 1  Guoli Yin 1  Chong Wang 1  Nan Du 1  Ruoming Pang 1  Shiyu Chang 2  Tao Lei 1

1Apple AI/ML  2UC Santa Barbara. Work done while interning at Apple.

Abstract

With the rapid scaling of large language models (LLMs), structured pruning has become a widely used technique to learn efficient, smaller models from larger ones, delivering superior performance compared to training similarly sized models from scratch. In this paper, we move beyond the traditional static pruning approach of determining a fixed pruning mask for a model, and propose a dynamic approach to structured pruning. In our method, the pruning mask is input-dependent and adapts dynamically based on the information described in a user instruction. Our approach, termed instruction-following pruning, introduces a sparse mask predictor that takes the user instruction as input and dynamically selects the most relevant model parameters for the given task. To identify and activate effective parameters, we jointly optimize the sparse mask predictor and the LLM, leveraging both instruction-following data and the pre-training corpus. Experimental results demonstrate the effectiveness of our approach on a wide range of evaluation benchmarks. For example, our 3B activated model improves over the 3B dense model by 5-8 points of absolute margin on domains such as math and coding, and rivals the performance of a 9B model.

1. Introduction

Structured pruning techniques have become a widely adopted method for reducing the inference cost of large language models (Wang et al., 2020; Sreenivas et al., 2024; Muralidharan et al., 2024; Meta AI, 2024). These methods typically optimize a binary mask over language model parameters to minimize either language modeling or task-specific loss (Xia et al., 2024; Sreenivas et al., 2024; Meta AI, 2024).
Correspondence to: Bairu Hou , Tao Lei .

Proceedings of the 42nd International Conference on Machine Learning, Vancouver, Canada. PMLR 267, 2025. Copyright 2025 by the author(s).

Once the mask is optimized, the resulting mask is fixed, allowing deployment of a smaller, pruned model. However, the fixed nature of the pruned model poses challenges in real-world inference scenarios, where tasks can vary significantly, for instance, coding, mathematics, or domain-specific requirements, each demanding distinct skills and knowledge from the original language model. A static pruned model may struggle to balance inference efficiency with high performance across diverse tasks.

Given this, we explore a paradigm shift from static pruning masks to dynamic ones, addressing the central question: Can LLMs learn to select the most suited parameters based on the task description? We aim to automatically generate input-specific pruning masks tailored to the tasks described in user prompts. This dynamic, context-aware pruning mechanism enables the language model to perform inference using only the parameters necessary for the task, offering a compelling balance between efficiency and expressivity compared to using a static dense model. Moreover, because the parameters are selected and fixed, our method avoids reloading new parameters during the decoding process. This design choice contrasts with other dynamic methods such as contextual sparsity (Liu et al., 2023; Zhou et al., 2024) and mixture-of-experts (Lepikhin et al., 2020; Fedus et al., 2022; Dai et al., 2024), which load different parameters at each decoding step, leading to significant weight loading costs. Our design is particularly suited for on-device models (e.g., smartphones and laptops), where inference typically samples a few responses given the same user query (or the same task).
In this case, the same activated parameters are selected and cached as a dense model, thereby achieving the same speedup as static pruning.

To address the central question, we present Instruction-Following Pruning (IFPRUNING), a method that integrates a sparsity predictor with the language model to dynamically generate input-dependent pruning masks, as illustrated in Figure 1. Specifically, we focus on structured pruning of the feed-forward network layers, where entire rows or columns of the weight matrices are pruned (Xia et al., 2024; Gunter et al., 2024; Sreenivas et al., 2024).

Figure 1. Overview of IFPRUNING. (Left) For each given prompt, the sparsity predictor (much smaller than the LLM) determines which rows and columns of the FFN matrices should be activated. (Middle) The LLM is then pruned accordingly and uses the selected parameters to perform inference for that specific prompt. (Right) By pruning a 9B LLM to 3B for each input, IFPRUNING significantly outperforms the dense 3B model and achieves performance levels close to the dense 9B model. It also achieves nearly the same inference latency as the dense 3B model, as measured by time-to-first-token (TTFT).

The user prompt is first passed into the sparsity predictor, which assigns importance scores to the rows and columns of each feed-forward network layer. These scores are then transformed into differentiable masks using the SOFTTOPK operator (Ainslie et al., 2023) to achieve a predefined sparsity level (e.g., reducing a 9B language model to 3B active parameters). The resulting masks are applied to the language model, in which the feed-forward layers are pruned using the masks.
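As a high-level sketch of this pipeline, the following NumPy snippet maps a prompt feature to per-layer importance scores and keeps the top-scoring FFN dimensions. All module sizes are hypothetical, and a hard per-layer top-k stands in for the differentiable SOFTTOPK described later:

```python
import numpy as np

rng = np.random.default_rng(0)
L, d_ffn, t_ffn, feat = 4, 32, 12, 16   # layers, FFN dim, target dim, feature dim (illustrative)

# Stand-in for the small LM backbone: a feature vector for the user prompt.
prompt_feature = rng.normal(size=feat)

# Stand-in for the mask prediction head: a two-layer MLP with hypothetical shapes.
Wa = rng.normal(size=(feat, 64))
Wb = rng.normal(size=(64, L * d_ffn))
z = (np.maximum(prompt_feature @ Wa, 0.0) @ Wb).reshape(L, d_ffn)  # importance scores

# Hard top-k per layer in place of the differentiable SOFTTOPK operator.
mask = np.zeros_like(z)
for i in range(L):
    mask[i, np.argsort(-z[i])[:t_ffn]] = 1.0
```

Each row of `mask` then selects which intermediate FFN dimensions of the corresponding layer stay active for this prompt.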
During training, the differentiable mask generation mechanism allows us to jointly optimize both the sparsity predictor and the language model by minimizing the next-token prediction loss. We employ effective training strategies that leverage both pre-training and supervised fine-tuning data. At test time, only the selected parameters are activated for inference. Parameter selection can be performed either per-input or per-task: the input prompt can directly be used for parameter selection (Section 4.2), or a predefined task prompt can be used to select parameters shared across multiple inputs within the same task (Section 4.3).

We validate IFPRUNING through comprehensive experiments across diverse tasks in Section 4. Specifically, we fine-tune pre-trained language models of varying sizes (6B, 9B, and 12B parameters) using IFPRUNING and prune them to activate only 3B parameters. In particular, IFPRUNING consistently outperforms 3B dense models across tasks such as math, coding, tool use, MMLU (Hendrycks et al., 2021a), and AlpacaEval (Dubois et al., 2024). For example, when dynamically pruning the 9B model to 3B, our method improves over the 3B dense model by 8% on coding tasks and by 5% on math benchmarks, incurring only marginal performance degradation compared to the unpruned 9B model.

We conduct further analysis to better understand the pruning decisions. Specifically, we observe that instructions requiring similar skills or domain knowledge yield highly homogeneous pruning patterns. Inspired by this analysis, we explore per-task pruning in Section 4.3, where a single task prompt generates shared masks for all test instances within the same task. Results show that per-task pruning maintains robust performance while further reducing data loading overhead. We also show that IFPRUNING can significantly improve LLM inference efficiency in Section 4.4.
Compared to the full model, IFPRUNING reduces the time-to-first-token latency by up to 57% and the generation time by up to 41%. In addition, the overhead introduced by dynamic pruning and parameter caching is negligible, adding less than 0.1 seconds per example and accounting for only 1-2% of the total generation time.

2. Related Work

In this section, we provide an overview of prior research that closely relates to and partially motivates our work, including model pruning, contextual sparsity, and mixture-of-experts.

Model pruning Pruning has been extensively studied to compress neural networks and improve their efficiency (Han et al., 2015; Zhu & Gupta, 2017). Previous work has explored different techniques for both unstructured pruning (Narang et al., 2017; Frankle & Carbin, 2018; Li et al., 2020; Chen et al., 2020) and structured pruning (Wen et al., 2016; Voita et al., 2019; Louizos et al., 2018; Wang et al., 2020). As structured pruning removes entire components in the model, such as channels, attention heads, and feed-forward network intermediate dimensions, it is more hardware-friendly than unstructured pruning for compressing large models. Various methods have been proposed for structured pruning of LLMs (Yang et al., 2024b; Kim et al., 2024; Kurtić et al., 2024; Dery et al., 2024). LLM-PRUNER (Ma et al., 2023) adopts gradient information to find unimportant components in LLMs and remove them. SLICEGPT transforms each weight matrix in transformer blocks into a smaller one by applying orthogonal transformations that reduce the embedding dimensions of the weight matrices. SHORTGPT (Men et al., 2024) proposes to identify and remove less important layers, where layer importance is measured by the similarity between the inputs and outputs of a layer. In comparison, other optimization-based methods directly learn the parameter masks.
For example, SHEARED LLAMA (Xia et al., 2024) uses HARDCONCRETE masking (Louizos et al., 2018; Wang et al., 2020) to generate differentiable masks and optimizes the model and masks on pre-training data. Our method also directly optimizes the sparsity predictor and the LLM, and we further extend static pruning to input-dependent pruning.

Contextual sparsity Our approach is also directly motivated by the contextual sparsity of LLMs (Liu et al., 2023; Akhauri et al., 2024; Lee et al., 2024). Previous work has identified the existence of input-dependent sub-networks (e.g., attention heads and MLP parameters) within ReLU-based LLMs that can generate approximately the same output as the full model for a given input. By predicting such sparsity patterns at each decoding step, one can achieve a favorable balance between accuracy and speedup. However, state-of-the-art LLMs (Dubey et al., 2024; Liu et al., 2024a; Yang et al., 2024a) design MLP blocks with more complex non-linear activation functions such as SwiGLU (Shazeer, 2020), SiLU (Elfwing et al., 2018; Ramachandran et al., 2017), and GELU (Hendrycks & Gimpel, 2016) that do not inherently induce sparsity (Mirzadeh et al., 2023; Song et al., 2024). Therefore, directly predicting the sparsity patterns can lead to significant performance degradation (Zhou et al., 2024; Dong et al., 2024). In comparison, we co-optimize the sparsity predictor and the LLM with non-ReLU activation functions to achieve better contextual sparsity with minimal performance degradation. Also, most contextual sparsity methods require predicting sparsity and loading different parameters at each decoding step. Our method eliminates this overhead by selecting the parameters based on the input or task description before decoding starts. The selected parameters are fixed for the entire decoding process, avoiding the parameter reloading cost.
Mixture-of-experts Mixture-of-Experts (MoE) models have emerged as a popular architecture for scaling LLMs while managing inference costs (Lepikhin et al., 2020; Du et al., 2022; Fedus et al., 2022; Zhou et al., 2022; Dai et al., 2024; Liu et al., 2024b). These models organize every FFN layer into multiple large FFN blocks referred to as experts, and selectively activate a few experts for each input token via a routing mechanism (Lepikhin et al., 2020; Zoph et al., 2022; Sun et al., 2024). Our method shares the same spirit as MoE by dynamically activating a subset of parameters. However, our method selects the activated parameters given the input prompt, and reuses the same activated parameters during decoding. Although this choice loses the flexibility of using different parameters per token, it significantly reduces weight loading costs during decoding. In this regard, our model is a sparse model designed for on-device scenarios where both memory and computational resources are constrained. Another difference is that our method performs a more fine-grained selection of parameters by activating or pruning each FFN dimension independently, which enhances model expressivity.

3. Method

In this section, we elaborate on the details of IFPRUNING, including the architecture design, data mixture, and training method. We focus on pruning the feed-forward blocks (FFNs) in this work, but our method can be easily extended to pruning other components such as attention heads.

3.1. Overview of Structured Pruning

Denote the hidden dimension of the LLM as $d$, the intermediate dimension of the FFN blocks as $d_{\text{ffn}}$, the input length as $n$, and $X \in \mathbb{R}^{n \times d}$ as the input of a transformer FFN block $F_{\text{ffn}}(\cdot)$. The goal of our structured pruning method is to reduce the FFN intermediate dimension from $d_{\text{ffn}}$ to $t_{\text{ffn}}$. Without loss of generality, consider a standard FFN block defined as

$$F_{\text{ffn}}(X) = \mathrm{FF}_2(\mathrm{FF}_1(X)) = \sigma(X W_1) W_2, \qquad (1)$$

where $W_1 \in \mathbb{R}^{d \times d_{\text{ffn}}}$, $W_2 \in \mathbb{R}^{d_{\text{ffn}} \times d}$ are weight matrices, and $\sigma$ is the non-linear activation function.
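A minimal NumPy sketch of the standard FFN block in Eq. (1); the dimensions are illustrative, and ReLU stands in for the unspecified activation $\sigma$:

```python
import numpy as np

def ffn(X, W1, W2, act=lambda h: np.maximum(h, 0.0)):
    """F_ffn(X) = sigma(X W1) W2, with ReLU as a stand-in for sigma."""
    return act(X @ W1) @ W2

n, d, d_ffn = 4, 8, 32            # sequence length, hidden dim, FFN intermediate dim
rng = np.random.default_rng(0)
X  = rng.normal(size=(n, d))      # X in R^{n x d}
W1 = rng.normal(size=(d, d_ffn))  # W1 in R^{d x d_ffn}
W2 = rng.normal(size=(d_ffn, d))  # W2 in R^{d_ffn x d}
Y = ffn(X, W1, W2)                # output keeps the input shape (n, d)
```

Structured pruning shrinks the inner dimension `d_ffn` of `W1` and `W2` while leaving the outer shape `(n, d)` unchanged.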
The structured pruning of the FFN block can be expressed as applying a mask variable $m \in \{0, 1\}^{d_{\text{ffn}}}$ to the output of the first linear transformation:

$$F_{\text{ffn}}(X, m) = \mathrm{FF}_2(\mathrm{FF}_1(X) \odot m), \qquad (2)$$

where $\odot$ is an element-wise multiplication between $m$ and each row of $\mathrm{FF}_1(X)$. For each dimension of $m$, $m_i = 0$ indicates that the $i$-th column of $W_1$ and the $i$-th row of $W_2$ are pruned. This is because the output $F_{\text{ffn}}(X, m)$ is equivalent to the output of the FFN layer after we prune the $i$-th column of $W_1$ and the $i$-th row of $W_2$. Here $m$ satisfies the sparsity constraint $\sum_i m_i = t_{\text{ffn}}$, where $t_{\text{ffn}}$ is the target intermediate dimension of the FFN blocks after pruning.

3.2. Architecture

As shown in Figure 1, our architecture comprises two key components: a sparsity predictor and a dense LLM to be dynamically pruned. For any given user prompt, the sparsity predictor generates masks that are applied to the LLM backbone, pruning the corresponding rows and columns of the FFN blocks.

Sparsity predictor The sparsity predictor consists of two modules: ❶ a much smaller LLM backbone that extracts the features of user prompts and ❷ a mask prediction head. Specifically, the LLM backbone takes the prompt $x = (x_1, \ldots, x_n)$ of length $n$ as input, and we use the hidden states of the last token $x_n$ in the last layer to represent the prompt. The mask prediction head is a two-layer MLP, which predicts the masks given the prompt representations. The output of the FFN mask prediction head is the masking score $z \in \mathbb{R}^{L \times d_{\text{ffn}}}$, where $L$ is the number of layers of the LLM. We include more details about the architecture of the sparsity predictor in Appendix A.1.

Given the predicted masking score $z$, a mask generation operator is applied to $z$ to convert it into the mask $m \in [0, 1]^{L \times d_{\text{ffn}}}$, which contains $t_{\text{ffn}}$ nonzero elements.
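The equivalence claimed after Eq. (2), that masking the intermediate activation gives the same output as physically pruning the corresponding columns of $W_1$ and rows of $W_2$, can be verified numerically with a small sketch (shapes illustrative, ReLU as a stand-in activation):

```python
import numpy as np

def ffn(X, W1, W2):
    return np.maximum(X @ W1, 0.0) @ W2        # sigma = ReLU stand-in

def masked_ffn(X, W1, W2, m):
    return (np.maximum(X @ W1, 0.0) * m) @ W2  # Eq. (2): mask each row of FF1(X)

d, d_ffn, t_ffn = 8, 32, 12
rng = np.random.default_rng(0)
X  = rng.normal(size=(4, d))
W1 = rng.normal(size=(d, d_ffn))
W2 = rng.normal(size=(d_ffn, d))

m = np.zeros(d_ffn)                            # binary mask with t_ffn active dims
m[rng.choice(d_ffn, size=t_ffn, replace=False)] = 1.0
kept = m.astype(bool)

# Masking the activation == pruning columns of W1 and rows of W2.
same = np.allclose(masked_ffn(X, W1, W2, m), ffn(X, W1[:, kept], W2[kept, :]))
```

This is what lets the method cache a genuinely smaller dense model once the mask is fixed.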
In this paper, we use the SoftTopK algorithm (Lei et al., 2023; Ainslie et al., 2023) to generate a differentiable $m$, but we acknowledge that other algorithms such as HardConcrete masking (Louizos et al., 2018; Wang et al., 2020) are also applicable. Particularly, given the FFN masking score $z$, SoftTopK converts it to masks $m$ via:

$$\lambda^{(i)} = g(z^{(i)}), \qquad m^{(i)} = \lambda^{(i)} \odot \mathrm{Top}(\lambda^{(i)}, t_{\text{ffn}}). \qquad (3)$$

Here $z^{(i)}$, $\lambda^{(i)}$ and $m^{(i)}$ represent the $i$-th row of each matrix, $g(\cdot): \mathbb{R}^{d_{\text{ffn}}} \to [0, 1]^{d_{\text{ffn}}}$ is a normalization function, and $\mathrm{Top}(\cdot, t_{\text{ffn}}) \in \{0, 1\}^{d_{\text{ffn}}}$ is an indicator function that returns a binary mask marking the top-$k$ values in $\lambda$. The normalization function $g(\cdot)$ ensures that $\lambda$ satisfies the sparsity constraint, i.e., $\sum_k \lambda^{(i)}_k = t_{\text{ffn}}$, where $t_{\text{ffn}}$ is the target size of the FFN layers. More details of SoftTopK can be found in previous work (Lei et al., 2023; Ainslie et al., 2023).

Masked LLM During training, the LLM takes the masks $m$ as an additional input and prunes its FFN blocks. We use the standard next-token prediction loss computed over tokens within a training batch, and we co-optimize the LLM and the sparsity predictor.

3.3. Model Training

The training of IFPRUNING incorporates two stages. We first perform continued pre-training, in which we initialize our model from a pretrained dense model, and then perform supervised fine-tuning (SFT) on instruction-following data. In what follows, we elaborate on the details of the two training stages.

Continued pre-training Learning to select input-specific sub-networks may require a large amount of training data. Instead of directly training the models on the SFT data only, we first use pre-training data to jointly optimize the sparsity predictor and the masked LLM. Specifically, denoting the input text as $x = (x_1, \ldots, x_n)$, we split it into $K$ consecutive chunks of fixed size:

$$x^{(k)} = \left(x_{(k-1)s+1}, \ldots, x_{ks}\right), \quad k = 1, \ldots, K, \qquad (4)$$

where $s = n/K$ is the fixed size of each chunk.
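Eq. (3) leaves the normalization $g(\cdot)$ abstract. One simple instantiation in the spirit of SoftTopK, not the exact algorithm of Lei et al. (2023), normalizes scores with a shifted sigmoid whose threshold is found by bisection so that the soft scores sum to the budget, then zeroes all but the top-$k$ entries:

```python
import numpy as np

def soft_top_k(z, k, temp=1.0, iters=50):
    """Illustrative SoftTopK-style operator for one score row z.

    Finds a threshold t by bisection so that sigmoid((z - t)/temp)
    sums to k (the role of g in Eq. 3), then applies the hard
    Top(., k) indicator: m = lambda * Top(lambda, k).
    """
    lo, hi = z.min() - 10 * temp, z.max() + 10 * temp
    for _ in range(iters):
        t = 0.5 * (lo + hi)
        s = 1.0 / (1.0 + np.exp(-(z - t) / temp))
        if s.sum() > k:          # too many soft "on" units: raise threshold
            lo = t
        else:
            hi = t
    lam = 1.0 / (1.0 + np.exp(-(z - t) / temp))  # normalized scores, sum ~ k
    top = np.zeros_like(z)
    top[np.argsort(-lam)[:k]] = 1.0              # binary Top(., k) indicator
    return lam * top

z = np.random.default_rng(0).normal(size=16)
m = soft_top_k(z, 4)             # exactly 4 nonzero entries, each in (0, 1]
```

The surviving entries stay fractional, which is what keeps the mask differentiable with respect to the predictor's scores during training.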
We then use each chunk to select parameters of the LLM for the next-token predictions in the next chunk, i.e., $x_i \in x^{(k+1)}$ ℓ h f(x