Instruction-Following Pruning for Large Language Models
Abstract
Reviews and Discussion
This article tackles the problem of letting LLMs select the parameters best suited to each prompted task and proposes a novel instruction-following pruning paradigm called IFPruning. Specifically, IFPruning uses a sparse mask predictor to predict an input-dependent mask for each input context. To train the predictor, IFPruning optimizes it jointly with the LLM, which makes the new paradigm effective. Experimental results on a series of tasks demonstrate the effectiveness of IFPruning.
Update after rebuttal
I appreciate the comprehensive response the authors provided during the rebuttal phase. I have decided to maintain my overall rating.
Questions for Authors
Will the code be published upon acceptance?
Claims and Evidence
The claims in this article are reasonable.
Methods and Evaluation Criteria
The design of methods is clear and satisfactory.
Theoretical Claims
N/A
Experimental Designs or Analyses
- How are the sub-network overlap rates calculated? I would like to see the complete pipeline for that calculation (i.e., data selection, forwarding, and the calculation metric).
- More details on constructing the SFT dataset: I would like to see the proportion of each sub-domain and its corresponding source.
- Model speedup: The experiments have revealed the sparsity factor that IFPruning achieves. However, sparsity does not necessarily lead to inference speedup. I would like to see a comprehensive comparison of model speedup when IFPruning is applied.
- Design of the sparsity predictor: What is the source of the applied predictors? What are their sizes?
Supplementary Material
N/A
Relation to Broader Scientific Literature
The findings of this study may inspire researchers to explore the issue of creating task-specific LLMs from the pre-trained base model, thereby advancing the real-world applications of LLMs.
Essential References Not Discussed
There is a recently emerged field focusing on "SFT for efficient LLMs". Notable works in this field include [1,2]. Both [1,2] and IFPruning target the goal of creating task-specific efficient LLMs, and they should be compared and discussed in the article. There is no need for experimental results since those works are very recent.
[1] TrimLLM: Progressive Layer Dropping for Domain-Specific LLMs
[2] UniAttn: Reducing Inference Costs via Softmax Unification for Post-Training LLMs
Other Strengths and Weaknesses
Notable strengths of this article:
- Novelty: The idea of instruction-following pruning is novel and worth investigating, since real-world applications often cannot support directly deploying large LLMs. Therefore, the motivation and methodology of IFPruning are promising for further research.
- Writing: This article is well written and structured.
- Strong results: IFPruning achieves significant improvements compared to existing methods.
See other fields for weaknesses.
Other Comments or Suggestions
N/A
We thank Reviewer bo9C for the support and the valuable suggestions. Below we address each point raised.
Q1: Clarification on actual model speedup.
As far as inference speed is concerned, we would like to first clarify that our method is motivated and designed for on-device models (e.g., on smartphone / laptop / desktop), where inference typically samples a few responses given the same user query (or the same task). In this case, the same activated parameters are selected and cached as a dense model, therefore achieving the same speedup as the static pruning and dense baselines. We discussed the limitations of our work and possible extensions to batch inference in Section 5. We will better clarify our limitations in the next version. Thank you!
We evaluate the speedups by pruning the open-sourced LLaMA-3.1-8B-Instruct model to 3B. Although the tests are done on GPUs, we used batch size 1 and 4 generations per query, reflecting on-device usage. We report the time-to-first-token (TTFT) and the decoding time, both measured in seconds.
For dense models (8B and 3B), the TTFT consists of pre-filling only. For our method, we break down TTFT into its components:
- Sub-network selection (via the sparsity predictor)
- Parameter loading (load the selected parameters and cache the sub-network as a dense model)
- Pre-filling using the 3B sub-network
| Input length | GPU | Model | Sub-network selection | Parameter loading | Pre-filling | TTFT | Decoding Time |
|---|---|---|---|---|---|---|---|
| 4k | A6000 | Llama-8b | - | - | 0.702 | 0.702 | 5.47 |
| 4k | A6000 | Llama-3b | - | - | 0.317 | 0.317 | 3.52 |
| 4k | A6000 | Ours 8b->3b | 0.070 | 0.016 | 0.315 | 0.402 | 3.53 |
| 4k | RTX3090 | Llama-8b | - | - | 0.947 | 0.947 | 5.48 |
| 4k | RTX3090 | Llama-3b | - | - | 0.396 | 0.396 | 3.18 |
| 4k | RTX3090 | Ours 8b->3b | 0.088 | 0.013 | 0.396 | 0.498 | 3.21 |
| 2k | A6000 | Llama-8b | - | - | 0.336 | 0.336 | 4.11 |
| 2k | A6000 | Llama-3b | - | - | 0.155 | 0.155 | 3.20 |
| 2k | A6000 | Ours 8b->3b | 0.037 | 0.016 | 0.155 | 0.208 | 3.25 |
| 2k | RTX3090 | Llama-8b | - | - | 0.467 | 0.467 | 3.76 |
| 2k | RTX3090 | Llama-3b | - | - | 0.203 | 0.203 | 2.70 |
| 2k | RTX3090 | Ours 8b->3b | 0.045 | 0.013 | 0.202 | 0.260 | 2.75 |
We highlight the following observations:
- Practical inference efficiency gains: TTFT decreased by up to 57%, and decoding time decreased by up to 41%.
- Minimal overhead from dynamic pruning & parameter caching: the overhead is negligible (~0.05s, ~2% of the total generation time).
- Despite dynamic masking, the runtime of IFPruning is on par with static pruning, while offering input-specific adaptivity and superior accuracy.
We will include this analysis in the final version of our paper.
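For concreteness, the parameter-loading step in the TTFT breakdown above can be sketched as follows. This is a minimal PyTorch sketch with hypothetical tensor names and toy shapes, not our production implementation; it only illustrates how a binary FFN mask is turned into a cached dense sub-network.

```python
import torch

def cache_subnetwork(ffn_up, ffn_down, mask):
    # Keep only the FFN units selected by the binary mask and store them as
    # contiguous (dense) weights, so decoding runs on a small dense model.
    idx = mask.nonzero(as_tuple=True)[0]        # indices of selected FFN units
    up_small = ffn_up[idx, :].contiguous()      # (k, hidden_dim)
    down_small = ffn_down[:, idx].contiguous()  # (hidden_dim, k)
    return up_small, down_small

# Toy example: keep 3 of 8 FFN units in one layer.
ffn_up, ffn_down = torch.randn(8, 16), torch.randn(16, 8)
mask = torch.tensor([1, 0, 0, 1, 0, 0, 1, 0])
up_s, down_s = cache_subnetwork(ffn_up, ffn_down, mask)
print(up_s.shape, down_s.shape)  # torch.Size([3, 16]) torch.Size([16, 3])
```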
Q2: Sub-network overlap rate calculation
Thank you for asking. We clarify the process in three steps:
- 1. Data: We sample 128 inputs per dataset (MMLU, GSM8K, CodeAlpaca-20K, GPTTeacher). Inputs to the sparsity predictor are formatted with in-context examples (MMLU: 5-shot, GSM8K: 8-shot) or raw prompts (CodeAlpaca, GPTTeacher).
- 2. Sub-network generation: Each input is passed through the sparsity predictor, which selects a fixed number of FFN units (a binary mask) at each layer, producing 128 sub-networks per dataset.
- 3. Overlap calculation: We compare every pair of sub-networks within the 128 examples. For each pair, we compute the fraction of selected FFN units they share. The final overlap rate is the average of these pairwise overlaps.
We will include these details in the final version of our paper.
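For concreteness, here is a minimal sketch of the pairwise computation in step 3. The function and array shapes below are illustrative, not our exact evaluation script; it assumes binary masks with a fixed number of selected units per sub-network.

```python
import numpy as np

def overlap_rate(masks):
    """masks: (n_inputs, n_layers, ffn_dim) binary arrays, one sub-network per input.
    Returns the average pairwise fraction of shared selected FFN units."""
    n = len(masks)
    k = masks[0].sum()                      # selected units per sub-network (fixed)
    rates = []
    for i in range(n):
        for j in range(i + 1, n):
            shared = np.logical_and(masks[i], masks[j]).sum()
            rates.append(shared / k)        # fraction of selected units in common
    return float(np.mean(rates))

# Toy example: 3 inputs, 2 layers, 4 FFN units per layer, 2 selected per layer.
masks = np.array([
    [[1, 1, 0, 0], [1, 0, 1, 0]],
    [[1, 0, 1, 0], [1, 0, 1, 0]],
    [[0, 1, 0, 1], [0, 1, 0, 1]],
])
print(overlap_rate(masks))
```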
Q3: More details on SFT datasets
We will expand the final version with a detailed breakdown of the SFT datasets, including data sources, instruction formats, and size per domain.
Q4: Sparsity predictor architecture
The predictor is a lightweight model with 302M parameters, built on a pre-trained LM backbone. It consists of:
- A feature extractor (last hidden state of the final input token)
- Two-layer MLP:
- Linear(hidden_dim → 128)
- Linear(128 → num_layers × ffn_dim)
The output is a tensor of FFN importance scores per layer. The SoftTopK operator then converts the scores into structured binary masks.
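A minimal PyTorch-style sketch of this predictor head is given below. The backbone hidden size and the use of a hard top-k in place of the differentiable SoftTopK operator are illustrative assumptions, not the exact implementation.

```python
import torch
import torch.nn as nn

class SparsityPredictorHead(nn.Module):
    """Maps the backbone's last-token hidden state to per-layer FFN masks.
    Sketch only: the real model uses a differentiable SoftTopK operator,
    approximated here by a hard top-k for clarity."""
    def __init__(self, hidden_dim, num_layers, ffn_dim, k):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(hidden_dim, 128),
            nn.Linear(128, num_layers * ffn_dim),
        )
        self.num_layers, self.ffn_dim, self.k = num_layers, ffn_dim, k

    def forward(self, last_hidden):                      # (batch, hidden_dim)
        scores = self.mlp(last_hidden)                   # per-layer FFN importance scores
        scores = scores.view(-1, self.num_layers, self.ffn_dim)
        topk = scores.topk(self.k, dim=-1).indices       # keep k FFN units per layer
        mask = torch.zeros_like(scores).scatter_(-1, topk, 1.0)
        return mask                                      # (batch, num_layers, ffn_dim)

# Hypothetical sizes for illustration only.
head = SparsityPredictorHead(hidden_dim=1024, num_layers=32, ffn_dim=8192, k=3072)
mask = head(torch.randn(2, 1024))
print(mask.shape, mask.sum(-1)[0, 0])  # torch.Size([2, 32, 8192]) tensor(3072.)
```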
Q5: Related work on SFT for efficient LLMs
Thank you for the pointers. We will add these references in the final paper to strengthen the related work section.
Q6: Will the code be published upon acceptance?
We will make our best effort to release the full implementation on top of an open-source model upon acceptance.
Thanks for the comprehensive response. After carefully checking all reviews and responses, I still consider this article as an insightful work that studies the emerging task-specific efficiency field.
We thank Reviewer bo9C for the thoughtful and encouraging feedback. We also sincerely appreciate your continued support for our paper after the rebuttal phase. We will incorporate the additional results and details as suggested in the final version.
The paper introduces Instruction-Following Pruning (IFPruning), a dynamic structured pruning method for large language models (LLMs). Instead of using a fixed sparsity mask, IFPruning employs a sparse mask predictor that selects the most relevant model weights (specifically, rows/columns of transformer feed-forward layers) on a per-instruction basis. The model is thus pruned on-the-fly per query, using only the parameters most relevant to that instruction. The authors jointly train the mask predictor and the LLM on instruction-following data utilizing both pre-training and additional fine-tuning data. Empirically, a pruned 3B-LLM using IFPruning (activating ~3B parameters out of a larger 9B or 12B model) achieves significantly better performance on domain-specific tasks like math and coding than a static 3B dense model, even rivaling a full 9B model on those tasks.
Questions for Authors
Regarding input-dependent pruning (as opposed to task-dependent pruning), I would like to see an analysis of inference time, since the method requires two forward passes due to the mask prediction.
Claims and Evidence
The key claim is that input-dependent dynamic pruning can exceed static dense models of the same activated size. This is well supported by experiments: with an equal parameter budget, IFPruning outperforms a dense counterpart on multiple benchmarks. Notably, the 3B dynamic model nearly matches a dense 9B model’s accuracy on coding and math benchmarks (See Table 1). The authors also compare against a static structured pruning baseline (Pruning+Distill, which prunes to 3B and distills from a larger model) – IFPruning consistently beats the baseline. These results substantiate the claim that selecting weights per input yields a more effective submodel than the naive fixed mask.
However, MoE-structured LLMs also share the same spirit: activating fewer parameters within the large parent model. The authors should compare with MoEs, which have fixed masks, in order to claim that dynamic pruning is effective.
Methods and Evaluation Criteria
Overall, the method is well-designed for the stated goal of task-adaptive efficiency. The mask predictor is a lightweight network that adds minimal overhead (close to linear probing), and it produces differentiable masks via a SoftTopK operator to choose top neurons per layer. Pruning is done at the structured level, which keeps the model hardware-friendly. The evaluation spans a wide range of benchmarks: instruction following, reasoning, math, coding, tool use, and general NLP tasks.
That said, the authors should provide a detailed configuration for the mask predictor network, just as the pre-trained LLM configurations are listed in the Appendix. Also, I cannot clearly understand the purpose of continued pre-training. The paper claims it provides a good initialization, but why does utilizing similar chunks from the same context help stabilize training of the mask predictor? I would like to see the experimental results without pre-training.
Theoretical Claims
N/A
Experimental Designs or Analyses
I have several concerns about the experimental setup.
As there is no codebase, it is very difficult to tell which models were used or trained, and whether the architecture is based on LLaMA or uses specific attention and normalization modules. Further, there are also no experiments on pruning other components, such as attention heads. The authors say it is natural to extend, yet I find this should depend on the model choice. Moreover, the paper does not provide a clear comparison to baselines: although the authors mention distilled models following Sheared LLaMA and Minitron, they neither compare against these baselines nor specify which distillation objectives are employed.
Supplementary Material
I read the Appendix.
Relation to Broader Scientific Literature
N/A.
Essential References Not Discussed
Well covered.
Other Strengths and Weaknesses
N/A.
Other Comments or Suggestions
N/A.
--------------------After Rebuttal--------------------
I have raised my score from 2 to 3, and lean towards acceptance, only if the authors faithfully include new experimental results in the final manuscript.
We thank Reviewer Kh2x for the support and the valuable suggestions. Please see our response below.
Q1: Why does continued pre-training help and what if we remove it?
Thank you for the question. We first explain our motivation followed by the ablation study on continued pre-training.
- Intuition of continued pre-training: In continued pre-training, we split long text into chunks; the model predicts the next chunk after selecting a sparse sub-network from the current one. Consider the extreme case where we split a pre-training text into just two chunks: the first acts as a prompt to select parameters, and the second as the target for prediction, closely mirroring the instruction-response format of SFT. By training the model in this way across millions of natural examples, the sparsity predictor learns to select the right sub-networks given different input contexts.
- Empirical results: We ran an ablation study with and without continued pre-training (6B → 3B, trained for 400B tokens):
| | HumanEval | MBPP | MultiPL-E | GSM8K | MATH | MMLU |
|---|---|---|---|---|---|---|
| No continued pre-training | 25.3 | 24.4 | 15.3 | 50.9 | 13.5 | 55.2 |
| Continued pre-training | 31.9 | 35.3 | 22.4 | 61.3 | 20.1 | 59.3 |
We observe consistent and notable gains across all benchmarks. We will add these results and clarify this design choice in the final version.
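To make the chunked objective concrete, here is a toy sketch of one continued pre-training step. The predictor and LLM interfaces below are stand-ins with hypothetical names, not our training code; the point is only that chunk i selects the sub-network and chunk i+1 supplies the next-token targets.

```python
# Toy stand-ins so the sketch runs end to end (illustrative only).
def predictor(chunk):
    return [1] * 8                          # dummy per-layer FFN mask

class ToyLLM:
    def lm_loss(self, context, target, ffn_mask):
        return float(len(target))           # placeholder for next-token loss

def continued_pretraining_step(tokens, chunk_len, predictor, llm):
    """Split a long sequence into chunks: chunk i selects the sub-network,
    chunk i+1 is the prediction target (mirrors the SFT prompt/response format)."""
    chunks = [tokens[i:i + chunk_len] for i in range(0, len(tokens), chunk_len)]
    loss = 0.0
    for prompt_chunk, target_chunk in zip(chunks[:-1], chunks[1:]):
        mask = predictor(prompt_chunk)      # select a sparse sub-network from the current chunk
        loss += llm.lm_loss(context=prompt_chunk, target=target_chunk, ffn_mask=mask)
    return loss

print(continued_pretraining_step(list(range(1000)), 256, predictor, ToyLLM()))
```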
Q2: Configuration details for the mask predictor network and the LLMs
The predictor consists of:
- An LLM with 302M parameters as the feature extractor (last hidden state of the final input token)
- Two-layer MLP:
- Linear(hidden_dim → 128)
- Linear(128 → num_layers × ffn_dim)
Regarding the LLMs used in this work: they follow standard LLM design, such as grouped-query attention and RMSNorm, with no custom components—similar to LLaMA. We will include these details in the final version.
Q3: Why no comparison with ShearedLLaMA or Minitron? What is the model distillation objective?
Thank you for your helpful question. We clarify both aspects:
- On comparison: We did not include ShearedLLaMA or Minitron due to differences in training data and model setup. Instead, we implemented a fairer pruning + distillation baseline using techniques similar to those in ShearedLLaMA and Minitron: structured pruning using learned masks, and logit distillation during continued pre-training. As shown below, our baseline outperforms the results in ShearedLLaMA.
| Model | ARC-C | ARC-E | PiQA | Winogrande | MMLU | Avg. |
|---|---|---|---|---|---|---|
| ShearedLLaMA 2.7B | 41.2 | 67.0 | 75.8 | 64.2 | 26.4 | 54.9 |
| Our pruning baseline (3B) | 46.2 | 79.9 | 77.3 | 69.1 | 62.8 | 67.6 |
| IFPruning 9B→3B | 50.4 | 81.4 | 78.0 | 68.4 | 65.5 | 68.7 |
Minitron results are only partially available, but our pruning baseline already outperforms its reported MMLU score. We will add these comparisons as a reference.
- Distillation objective: We apply KL divergence between the output distributions of the student and the teacher model for each output token. The teacher distribution only keeps the highest-scoring tokens, similar to Minitron. We minimize a combined loss of the standard next-token prediction and KL divergence loss.
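A minimal sketch of this combined objective is shown below. The top-k truncation size, the renormalization of the teacher distribution over its top tokens, and the loss weight alpha are illustrative assumptions, not the exact training recipe.

```python
import torch
import torch.nn.functional as F

def distill_loss(student_logits, teacher_logits, targets, k=64, alpha=0.5):
    """Combine next-token cross-entropy with a KL term against a teacher whose
    distribution is truncated to its top-k tokens (k and alpha are illustrative)."""
    # Standard next-token prediction loss.
    ce = F.cross_entropy(student_logits.view(-1, student_logits.size(-1)), targets.view(-1))
    # Keep only the teacher's highest-scoring tokens, renormalize, then KL(teacher || student)
    # evaluated on that truncated support.
    topk_vals, topk_idx = teacher_logits.topk(k, dim=-1)
    teacher_probs = F.softmax(topk_vals, dim=-1)
    student_logprobs = F.log_softmax(student_logits, dim=-1).gather(-1, topk_idx)
    kl = (teacher_probs * (torch.log(teacher_probs) - student_logprobs)).sum(-1).mean()
    return ce + alpha * kl

# Toy example: batch of 2 sequences, length 5, vocabulary of 1000 tokens.
s = torch.randn(2, 5, 1000)
t = torch.randn(2, 5, 1000)
y = torch.randint(0, 1000, (2, 5))
print(distill_loss(s, t, y))
```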
Q4: Comparison with MoE.
Thank you for the question. We did not include MoE models due to a fundamental difference in inference scenarios. We will improve the clarity of our paper:
- Our method is designed for edge devices (e.g., smartphones) where memory and compute resources are limited, and the inference batch size is small.
- While MoE models are great for server-side large batch size inference, they are not efficient for on-device inference when generating responses given a single query.
- In this case, decoding is bottlenecked by weight loading. Since MoE requires reading many expert weights (e.g., 1-2 for each token), the cost of MoE is multiple times higher than a dense model and our method.
To illustrate, we compare our method with the open-source model, Qwen1.5-MoE-A2.7B. It activates 2.7B parameters per token. For our method, we prune LLaMA-3-8B to 3B parameters.
We report time-to-first-token (TTFT) and decoding time with input length = 4k, generation length = 100, and sample 4 responses for each query. For our method, we also report the latency for sub-network selection and loading parameters for selected sub-network.
| GPU | Model | Sub-network selection | Parameter loading | Pre-filling | TTFT | Decoding Time |
|---|---|---|---|---|---|---|
| A6000 | Llama-8b | - | - | 0.702 | 0.702 | 5.47 |
| A6000 | Qwen-MoE | - | - | 0.621 | 0.621 | 28.43 |
| A6000 | Llama-3b | - | - | 0.317 | 0.317 | 3.52 |
| A6000 | Ours 8b->3b | 0.070 | 0.016 | 0.315 | 0.402 | 3.53 |
| RTX3090 | Llama-8b | - | - | 0.947 | 0.947 | 5.48 |
| RTX3090 | Qwen-MoE | - | - | OOM | OOM | OOM |
| RTX3090 | Llama-3b | - | - | 0.396 | 0.396 | 3.18 |
| RTX3090 | Ours 8b->3b | 0.088 | 0.013 | 0.396 | 0.498 | 3.21 |
We can see the dense baseline and our method have significantly better latency and throughput than MoE.
Finally, we agree with the reviewer's point that MoE and our method share the same spirit by dynamically activating parameters. In this regard, our model is a sparse model designed for on-device scenarios.
I thank the authors for the additional results, which resolved most of the raised issues. I thus raise my score from 2 to 3, and lean towards acceptance, only if the authors faithfully include the new experimental results in the final manuscript.
We thank Reviewer Kh2x for the updated evaluation and constructive feedback. We’re glad to hear that the additional results addressed most of the concerns. We will make sure to faithfully include all new experimental results in the final version of the manuscript, as requested. We truly appreciate your support and consideration toward acceptance.
The paper proposes "Instruction-Following Pruning" (IFPRUNING), a novel approach to dynamic structured pruning of large language models (LLMs). Unlike traditional static pruning methods that determine a fixed pruning mask for a model, this approach generates input-dependent pruning masks that adapt based on the user's instruction. The method introduces a sparse mask predictor that takes the user instruction as input and dynamically selects the most relevant model parameters for the given task, focusing primarily on pruning feed-forward neural network layers. The architecture consists of two main components: (1) a sparsity predictor that extracts features from user prompts and generates masks, and (2) a dense LLM that gets pruned dynamically. The approach uses the SoftTopK algorithm to generate differentiable masks that activate only the most relevant parameters for specific inputs. The authors demonstrate that their method, which activates 3B parameters from larger models (6B, 9B, and 12B), outperforms dense 3B models and shows comparable performance to larger dense models in various tasks including math, coding, and general instruction-following benchmarks.
Questions for Authors
- The method requires generating a new mask for each input (or task) during inference. Have you measured the computational overhead of this process compared to static pruning? This information would help clarify whether the performance gains outweigh the additional inference complexity.
- Why did you choose not to compare against other structured pruning methods like LLM-Pruner and SliceGPT, which would provide more meaningful baselines than just dense models and a simple distillation approach?
- How would your approach perform on widely-used open-source models like LLaMA 3.1 or Qwen 2.5? Results on these models would strengthen the practical applicability of your method.
- What is the reasoning behind designing a method that requires task-specific pruning for general-purpose LLMs? Most applications require general models that can handle diverse tasks without generating new masks for each input.
- For the per-task pruning scenario, how did you handle variations in task descriptions that might refer to the same underlying task? Is there a way to automatically cluster similar tasks to reuse masks?
Claims and Evidence
The paper's central claim that dynamic pruning based on task descriptions leads to better performance than static pruning is supported by experimental evidence, but with limitations:
- The authors show that their 3B activated models outperform dense 3B models on various benchmarks, which supports their main claim.
- The claim that their approach avoids parameter reloading costs during decoding (compared to other dynamic methods) is reasonable but lacks direct empirical validation in terms of efficiency measurements.
- The claim about the interpretability of parameter selection is supported by the analysis in Section 4.2, showing that inputs requiring similar skills yield similar pruning patterns.
- However, the claim that their method rivals the performance of 9B models is only partially supported - while there are improvements over a dense 3B model, the gap to the 9B model remains noticeable in most benchmarks.
Methods and Evaluation Criteria
The proposed method makes sense for the problem of efficient LLM inference, but has several limitations:
- The sparsity predictor design is reasonable, using a smaller model to extract features from prompts and predict masking scores.
- The evaluation benchmarks are comprehensive, covering instruction-following, coding, math, NLP understanding, and tool use tasks.
- However, the method's practicality is questionable - if the goal is to have specialized models for specific tasks, traditional pruning followed by task-specific fine-tuning might be more straightforward and efficient.
- The evaluation compares against limited baselines - a dense 3B model and a pruned+distilled 3B model - but misses comparisons with competitive structured pruning methods like LLM-Pruner and SliceGPT, or other contextual sparsity approaches.
Theoretical Claims
The paper makes limited theoretical claims and provides no formal proofs. The SoftTopK algorithm for generating differentiable masks is cited from previous work without detailed theoretical analysis of its properties in this context. The paper would benefit from theoretical analysis of why task-specific pruning works better than static pruning and under what conditions this advantage would hold.
Experimental Designs or Analyses
The experimental design has several issues:
- The use of AXLearn framework and JAX raises questions about whether the results would generalize to more commonly used frameworks.
- The experiments do not use widely recognized open-source models like LLaMA 3.1 or Qwen 2.5, limiting the practical applicability of the findings.
- The baselines are limited - there are no comparisons with state-of-the-art structured pruning methods like LLM-Pruner or SliceGPT.
- While the authors present per-task pruning analysis, they don't adequately address the practical overhead of generating task-specific masks for each new input, which could negate the efficiency gains from pruning.
- The paper lacks ablation studies on key components such as the size of the sparsity predictor or different mask generation algorithms.
Supplementary Material
The supplementary material provides information about model architecture details, MMLU domain subsets, and task-specific prompts used for evaluation. It also includes licensing information for the datasets used. However, the supplementary material lacks detailed analysis on computational efficiency or additional ablation studies that could strengthen the paper's claims.
Relation to Broader Scientific Literature
The paper adequately positions itself relative to three areas of related work:
- Model pruning: The authors acknowledge prior work on structured pruning techniques for LLMs, including LLM-PRUNER, SLICEGPT, and SHORTGPT.
- Contextual sparsity: The paper discusses how their approach differs from other contextual sparsity methods that require pruning at each decoding step.
- Mixture-of-experts: The authors compare their approach to MoE models, noting that while MoE activates different parameters per token, their method fixes parameters based on the task description.
Essential References Not Discussed
Dynamic pruning is not a novel topic. Actually, there are many previous works in this area, especially in computer vision. Some of them are listed here:
- Elkerdawy, S., Elhoushi, M., Zhang, H. and Ray, N., 2022. Fire together wire together: A dynamic pruning approach with self-supervised mask prediction. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (pp. 12454-12463).
- Gao, S., Zhang, Y., Huang, F. and Huang, H., 2024. BilevelPruning: unified dynamic and static channel pruning for convolutional neural networks. In Proceedings of the IEEE/CVF conference on computer vision and pattern recognition (pp. 16090-16100).
- Le, Q., Diao, E., Wang, Z., Wang, X., Ding, J., Yang, L. and Anwar, A., 2025. Probe Pruning: Accelerating LLMs through Dynamic Pruning via Model-Probing. arXiv preprint arXiv:2502.15618.
Other Strengths and Weaknesses
Strengths:
- The idea of task-specific pruning is conceptually interesting and represents a novel compromise between static pruning and token-level dynamic approaches.
- The analysis of activation patterns across different domains provides interesting insights into how the model specializes for different tasks.
- The performance improvements over a dense 3B model across different tasks are notable, demonstrating the potential of the approach.
- The authors explore both per-input and per-task pruning scenarios, showing the flexibility of their method.
Weaknesses:
- The core innovation is limited and has conceptual flaws. The main differentiation from other contextual sparsity work is that pruning happens at the task level rather than per token, but this approach still requires task-specific masks to be generated for most inputs in practical scenarios.
- The method section lacks technical details and theoretical support. The description of the training process and objective functions is cursory.
- The experimental section has significant limitations, with missing comparisons to competitive baselines and experiments predominantly conducted on non-standard frameworks and models.
- The practical utility is questionable - for general-purpose LLMs, adaptability across tasks is essential. If task specialization is the goal, traditional pruning plus fine-tuning might be more straightforward.
- The presentation quality is lacking, with poor figure quality (e.g., Figure 3's small font size) and insufficient technical details in key sections.
Other Comments or Suggestions
- The paper would benefit from clearer explanations of the technical details, especially regarding the sparsity predictor architecture and training.
- Implementation details about how the masks are efficiently computed during inference would strengthen the paper.
- It would be valuable to explicitly measure and report the computational overhead of the sparsity predictor.
- The authors should improve figure quality, particularly in Figure 3, where the text is difficult to read.
- For real-world applications, it would be helpful to discuss how the approach handles out-of-distribution task descriptions.
We thank Reviewer epS3 for the support and the valuable feedback. Please see our response below.
Q1: What is the reasoning behind designing a method that requires task-specific pruning for general-purpose LLMs?
We fully agree with the reviewer that having "general-purpose LLMs and adaptability across tasks" is essential. Indeed, our method aims to improve the general applicability of smaller LMs. Please allow us to clarify a few things:
- Given a user instruction (which can be a general-purpose question, such as a coding problem, a travel suggestion, or a knowledge-seeking question), our method dynamically prunes the model and uses the most suited parameters for inference.
- As the pruning mask is generated by "reading" the input instruction, our model can handle questions/tasks that are not seen during training. See, for example, our evaluation results on AlpacaEval and Arena-Hard.
In other words, our approach offers the following advantages:
- Improve inference efficiency over large dense model (see Q3);
- Improve the model quality over static pruning;
- Pruning based on natural language description enables zero-shot generalization to unseen tasks and instructions, making the model still a general-purpose LLM.
In contrast, naive task-specific fine-tuning can have problems: no model available for unseen tasks during inference; memory/storage overhead for deploying many task-specific models; additional cost to collect training data for each task, etc.
Q2: The included baseline is not strong enough; comparisons with baselines such as LLM-Pruner or SliceGPT are needed.
We would like to clarify that our pruning+distill baseline reflects the SOTA practice in large-scale model development. It first prunes the model and then continues pre-training on trillions of tokens. This method, combined with logit distillation, is consistent with recent model developments like LLaMA 3.2, Nvidia Minitron, and Gemma 3.
To show the effectiveness of our method and our baseline, the table below summarizes the (relative) performance drop compared to the source model with different pruning methods:
| Method | Sparsity | ARC-C | ARC-E | PiQA | HellaSwag | Winogrande | Average | Performance drop |
|---|---|---|---|---|---|---|---|---|
| LLM-Pruner | Source model: 7B | 47.6 | 72.8 | 79.8 | 76.1 | 70.1 | 69.3 | - |
| LLM-Pruner | 20% | 37.9 | 63.4 | 76.4 | 68.1 | 65.1 | 62.2 | 7.1% (relative 10%) |
| SliceGPT/ShortGPT | Source model: 7B | 46.3 | 74.6 | 79.1 | 75.9 | 69.1 | 69.0 | - |
| SliceGPT | 30% | 34.1 | 50.7 | 67.4 | 55.7 | 63.2 | 54.2 | 14.8% (relative 21%) |
| ShortGPT | 31% | 40.9 | 56.6 | 67.8 | 62.2 | 64.4 | 58.4 | 10.6% (relative 15%) |
| IFPruning | Source model: 9B | 53.9 | 83.4 | 79.4 | 57.7 | 74.3 | 69.7 | - |
| Our baseline | 66% | 46.2 | 79.9 | 77.3 | 53.0 | 69.1 | 65.1 | 4.8% (relative 7%) |
| IFPruning | 66% | 50.4 | 81.4 | 78.0 | 55.5 | 68.4 | 66.7 | 3.0% (relative 4%) |
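For clarity, the performance drop is computed from the average scores (absolute points and the corresponding relative change). For example, for the LLM-Pruner 20% row:

$$\text{absolute drop} = 69.3 - 62.2 = 7.1 \text{ points}, \qquad \text{relative drop} = \frac{7.1}{69.3} \approx 10\%.$$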
Our baseline and IFPruning achieve much higher sparsity and smaller accuracy degradation, validating the effectiveness of our method.
Q3: Overhead of generating masks for each new input
We evaluate the latency with LLaMA-3.1 8B as the source model and prune it to 3B. Due to the space limit, please also see our response to Q1 by reviewer Lv3b. In brief, mask generation adds only ~0.1s overhead. Specifically, with input length 4000, output length 100, and response sample size 4, we show:
- Time-to-first token (TTFT) decreased by up to 57%
- Decoding time decreased by up to 41%
- Comparable latency to static 3B models
| GPU | Model | Sub-network selection | Parameter loading | Pre-filling | TTFT | Decoding Time |
|---|---|---|---|---|---|---|
| A6000 | Llama-8b | - | - | 0.702 | 0.702 | 5.47 |
| A6000 | Llama-3b | - | - | 0.317 | 0.317 | 3.52 |
| A6000 | Ours 8b->3b | 0.070 | 0.016 | 0.315 | 0.402 | 3.53 |
| RTX3090 | Llama-8b | - | - | 0.947 | 0.947 | 5.48 |
| RTX3090 | Llama-3b | - | - | 0.396 | 0.396 | 3.18 |
| RTX3090 | Ours 8b->3b | 0.088 | 0.013 | 0.396 | 0.498 | 3.21 |
The overhead from sub-network selection (mask generation) & parameter loading is negligible (~0.1s, ~2% of the total generation time).
Q4: Concerns about JAX/AXLearn and framework portability
Our core method is framework-agnostic and can be easily implemented in PyTorch, as shown in our latency experiments using LLaMA-3.1-8B-Instruct and PyTorch. We will make our best effort to release the full implementation on top of an open-source model.
Q5: Sparsity predictor architecture
The predictor consists of:
- An LLM with 302M parameters as the feature extractor (last hidden state of the final input token)
- Two-layer MLP:
- Linear(hidden_dim → 128)
- Linear(128 → num_layers × ffn_dim)
Training is end-to-end with the masked LLM using standard language modeling loss.
Q6: Missing citations, discussions with other dynamic pruning papers, and figure issues
Thank you for your suggestion. We promise that we will address your comments in the final revision of our paper.
The paper proposes a dynamic pruning method in which a router determines the pruning strategy of the FFN layers in an LLM model using the input instruction. The sparse mask predictor and LLM weights are jointly trained using instruction-following data and the pre-training corpus. Experiments on different target benchmarks demonstrate that the proposed pruning method can generally outperform the dense baseline of similar size.
Questions for Authors
Please check my comments above.
Claims and Evidence
Yes, the main claim of the paper is that using dynamic, input-dependent sparsity can be beneficial for the pruned model's performance rather than using the same sparsity pattern for all the inputs. The experimental results support such a hypothesis.
Methods and Evaluation Criteria
The components of the proposed method have been introduced in previous work so the idea does not have novel elements, but it combines them in an effective manner. I think the proposed method is sound and simple yet effective in practice, so I do not complain about novelty.
For the evaluation, I think the paper can be improved in the following aspect:
- Although it has become a convention in the Mixture-of-Experts literature to compare different models only based on active parameters, I think doing so does not paint a complete picture of the model's real performance in terms of inference latency. Depending on the network's topology, two models with 3B activated parameters can have significantly different latency values in practice on GPU/TPU. Therefore, the paper should provide at least the inference latency of the model pruned by IFPruning and the dense baselines.
Theoretical Claims
The paper has no theoretical claims.
Experimental Designs or Analyses
As the main contribution of the paper is in empirical results rather than methodological development, I think the paper should be improved in the following aspects:
- In Sec. 4.1, the paper only mentions that it uses an internal SFT dataset for training. I understand that they may not be able to release this dataset, but they should at least provide some statistics and the general structure of the dataset.
- Also, the result that the larger the base model, the better the performance of the pruned model, is not very surprising. There have been empirical and theoretical papers [1, 2] in the literature indicating that overparameterization helps with training a better base model and improves the quality of the pruned model. The paper should indicate this connection in its discussion.
- In Line 167, the paper admits that one can use the HardConcrete trick to do pruning, yet it does not indicate why it chooses not to do so. It would be nice to have a comparison between the current approach and HardConcrete, as it is widely used in practical scenarios. However, I understand that doing experiments in the rebuttal period can be challenging, and I don't ask for new experiments.
[1] Learning and Generalization in Overparameterized Neural Networks, Allen-Zhu et al., 2019.
[2] Stronger generalization bounds for deep nets via a compression approach, Arora et al., 2019.
Supplementary Material
Yes, I checked the appendix A.1 for the model architecture.
Relation to Broader Scientific Literature
The paper shows advantages compared to previous static pruning baselines in terms of performance as it makes the pruning strategy dependent on the input. However, I believe that the paper should indicate the following downsides compared to static pruning:
- Static pruning enables batch parallelism on GPUs, while dynamic pruning cannot benefit from it as each sample in the batch may have a different pruning strategy.
- Static pruning enables lower memory usage for saving the model on disk and also consumes less GPU memory. In contrast, dynamic pruning cannot do so in practice.
- Static pruning can achieve real inference speedup on GPUs/TPUs. It is harder to achieve inference latency reduction with dynamic pruning.
Essential References Not Discussed
Please check the "Experimental Designs Or Analyses" section about the missing references.
Other Strengths and Weaknesses
I cannot think of any other points other than the ones mentioned above.
Other Comments or Suggestions
I suggest that the authors provide the inference latency of their models and the baselines. It is fine to me that the method does not beat the baselines (especially the static pruning ones), but doing so will make the paper more useful for readers and practitioners. I would be happy to raise my score if the authors provide this analysis.
We thank Reviewer Lv3b for the support and the valuable suggestions for our paper. We will address the writing feedback, such as adding statistics of the dataset and discussing additional related work, in the next version. Please see our response to the questions and/or major comments below.
Q1: Comparison with static pruning, and latency numbers.
We agree with the reviewer that static pruning enjoys better batch parallelism and lower memory usage compared to dynamic pruning. We discussed the downside of our work in Section 5 and will improve the clarification in our next version. Thank you!
As far as inference speed is concerned, we would like to clarify that our method is designed for on-device models (e.g. on smartphone / laptop / desktop), where the inference typically samples a few responses given the same user query (or the same task). In this case, the same activated parameters are selected and cached as a dense model, therefore achieving the same speedup as the static pruning and dense baseline.
To illustrate the real inference speedup, we test the inference latency by pruning the open-sourced LLaMA-3.1-8B-Instruct model to 3B. Although the tests are done on GPUs, we used batch size 1 and 4 generations per query, reflecting on-device usage. We report the time-to-first-token (TTFT) and the decoding time, both measured in seconds.
For dense models (8B and 3B), the TTFT consists of pre-filling only. For our method, we break down TTFT into its components:
- Sub-network selection (via the sparsity predictor)
- Parameter loading (load the selected parameters and cache the sub-network as a dense model)
- Pre-filling using the 3B sub-network
| Input length | GPU | Model | Sub-network selection | Parameter loading | Pre-filling | TTFT | Decoding Time |
|---|---|---|---|---|---|---|---|
| 4k | A6000 | Llama-8b | - | - | 0.702 | 0.702 | 5.47 |
| 4k | A6000 | Llama-3b | - | - | 0.317 | 0.317 | 3.52 |
| 4k | A6000 | Ours 8b->3b | 0.070 | 0.016 | 0.315 | 0.402 | 3.53 |
| 4k | RTX3090 | Llama-8b | - | - | 0.947 | 0.947 | 5.48 |
| 4k | RTX3090 | Llama-3b | - | - | 0.396 | 0.396 | 3.18 |
| 4k | RTX3090 | Ours 8b->3b | 0.088 | 0.013 | 0.396 | 0.498 | 3.21 |
| 2k | A6000 | Llama-8b | - | - | 0.336 | 0.336 | 4.11 |
| 2k | A6000 | Llama-3b | - | - | 0.155 | 0.155 | 3.20 |
| 2k | A6000 | Ours 8b->3b | 0.037 | 0.016 | 0.155 | 0.208 | 3.25 |
| 2k | RTX3090 | Llama-8b | - | - | 0.467 | 0.467 | 3.76 |
| 2k | RTX3090 | Llama-3b | - | - | 0.203 | 0.203 | 2.70 |
| 2k | RTX3090 | Ours 8b->3b | 0.045 | 0.013 | 0.202 | 0.260 | 2.75 |
Key takeaways:
- TTFT decreased by up to 57%, decoding time decreased by up to 41%. In total, we achieve 1.8x speedup compared to Llama-8b.
- Overhead from dynamic pruning & parameter caching is negligible (~0.05s, ~2% of the total generation time)
- Despite dynamic masking, runtime of IFPruning is on par with static pruning, while offering input-specific adaptivity and superior accuracy.
We will include this analysis in the final version of our paper.
Q2: Choice of SoftTopK over HardConcrete.
We appreciate the suggestion. Both SoftTopK and HardConcrete are mask generation operators introduced in previous work, and neither is a contribution of our work. We chose SoftTopK for the following reasons:
- Both methods achieve similar task performance.
- SoftTopK does not require another auxiliary loss, whereas HardConcrete needs some tuning (e.g. on the auxiliary loss weight and its learning rate).
- SoftTopK seems more stable in our preliminary study.
Q3: More details on SFT datasets
We will expand the final version with a detailed breakdown of the SFT datasets, including data sources, instruction formats, and size per domain.
Q4: Missing related work.
Thank you for highlighting these valuable references—we will include them in the final revision to strengthen the related work section.
The paper introduces Instruction-Following Pruning, a dynamic pruning method where a router determines the pruning strategy for the FFN layers in an LLM model based on the input instruction. The proposed method is well-designed and technically sound. Experiments across various target benchmarks show that the pruning method generally outperforms the dense baseline of similar size. Overall, this is a solid contribution. To improve the paper, the authors should add comparisons with more state-of-the-art structured pruning methods, demonstrate the method's generalizability to commonly used frameworks, and provide additional discussions on related work, datasets, and static pruning techniques.