CITER: Collaborative Inference for Efficient Large Language Model Decoding with Token-Level Routing
A novel Collaborative Inference with Token-lEvel Routing (CITER) framework that introduces a token-level routing mechanism, enabling efficient collaboration between small and large language models (SLMs & LLMs).
Abstract
Reviews and Discussion
This paper proposes a token-level router between a large language model and a small language model to improve inference time and computation cost. The paper includes an extensive ablation study and a substantial set of comparisons to previous related work, which demonstrates that the proposed approach provides a better balance between accuracy and cost.
Reasons to Accept
- The formulation of the router training as a policy optimization problem is novel and demonstrates significant improvement over previous work.
- The paper demonstrates the approach's advantages over previous work on multiple benchmarks, and the extensive ablation study helps to clarify the role of the different method components.
Reasons to Reject
The paper is missing some more details on the training setup of the router, specifically:
- Are the router training samples (prompts and responses) taken from the benchmarks? If so, are they taken from all the benchmarks/some?
- How generalizable is the trained router to new tasks/benchmarks/domains?
- What is a typical size of the router training set?
- During training, is the output token of the model required to be identical to the expected next token from the ground truth? If not, how are they compared? If it is required to be identical, what about minor rephrasing of the output - couldn't the SLM still be correct even though the next token is not identical?
Questions for the Authors
- Do you have a hypothesis on why the improvement of CITER over the baselines and models is more modest for the Llama models (Figure 6) compared with the Qwen models?
- What is the computational cost of training the router?
- There are a few grammar mistakes and typos throughout the paper. For example, in line 199: "In addition, its, Furthermore,…" and in line 113 - a reference to algorithm 1, line 10 that I think should be to line 9.
Q6: There are a few grammar mistakes and typos throughout the paper. For example, in line 199: "In addition, its, Furthermore,…" and in line 113 - a reference to algorithm 1, line 10 that I think should be to line 9.
R6: We thank the reviewer for pointing out these issues. We have carefully proofread the revised manuscript to address all grammar mistakes and typos, including those mentioned in Lines 113 and 199. Specifically, we deleted "In addition, its," in Line 199 and corrected the reference to Algorithm 1 in Line 113. We have also fixed the typo in Line 137, along with other similar issues.
Once again, we appreciate the reviewer’s helpful comments and will incorporate all suggested clarifications and corrections in the final version.
Q3: During training, is the output token of the model required to be identical to the expected next token from the ground truth? If not, how are they compared? What about minor rephrasings—couldn’t the SLM still be correct even if the next token is not identical?
R3: Thank you for raising this important question. In our current implementation, we do require the output token generated by either the SLM or LLM to be exactly identical to the corresponding ground truth next token in order to determine whether the prediction at each step is correct. We acknowledge that this approach can sometimes mark minor variations or acceptable rephrasing by the SLM as errors, even though the SLM might still be able to answer the question correctly in the end. However, we want to emphasize that this is not an inherent limitation of our method.
Before we introduced the shortcut mechanism, this issue did not arise. Originally, we would generate the entire output sequence and determine model preference based on the correctness of the complete response. In this setup, small rephrasing or local differences would not affect the final assessment, as the choice of model for each step was based on the overall output’s correctness rather than strict per-token matching.
To accelerate the collection of routing labels, we introduced a shortcut: when the SLM’s output at a given step is incorrect, but the LLM’s output matches the ground truth, we immediately assign the routing preference to the LLM for that token, rather than evaluating the complete response. We acknowledge that this shortcut can potentially overlook cases where the SLM’s output, despite not matching the ground truth token-for-token, could still lead to a correct final answer. However, we believe this trade-off offers a reasonable approximation for our approach. By introducing this shortcut, we are able to significantly reduce the computational cost of collecting preference labels, at the expense of possibly sacrificing a small amount of inference efficiency. If computational resources permit, removing this shortcut and evaluating routing choices based on the correctness of the complete response could further improve CITER’s performance.
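For concreteness, the shortcut label-collection logic described above can be sketched roughly as follows. This is an illustrative sketch only: the helper callables (`slm_next`, `llm_next`, `rollout`, `is_correct`) are hypothetical placeholders, and the exact rules follow the paper rather than this code.

```python
def collect_routing_label(prefix, gt_next_token, slm_next, llm_next, rollout, is_correct):
    """Sketch of per-token routing-preference collection with the shortcut.

    All helpers are passed in as callables and are placeholders for illustration:
    `slm_next`/`llm_next` return each model's next token for the prefix,
    `rollout` continues a prefix to a complete response, and
    `is_correct` checks whether a complete response answers the query correctly.
    """
    slm_token = slm_next(prefix)
    if slm_token == gt_next_token:
        return "SLM"  # SLM already matches the ground truth: keep the cheap model

    llm_token = llm_next(prefix)
    if llm_token == gt_next_token:
        return "LLM"  # shortcut: SLM is wrong but the LLM matches the ground truth

    # Both next tokens are wrong: fall back to judging complete responses.
    slm_ok = is_correct(rollout(prefix + [slm_token]))
    llm_ok = is_correct(rollout(prefix + [llm_token]))
    if llm_ok and not slm_ok:
        return "LLM"
    return "SLM"  # otherwise default to the cheaper SLM
```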
Q4: Do you have a hypothesis on why the improvement of CITER over the baselines is more modest for the Llama models compared with the Qwen models?
R4: Thank you for this insightful observation. We believe the more modest improvement of CITER over the baselines for the Llama models, compared to the Qwen models, is largely due to differences in the performance-to-size ratio among the SLMs used. As our experimental results show, Qwen2-1.5B achieves performance comparable to or even surpassing that of Llama3.1-7B on our benchmarks, despite being substantially smaller in model size. This smaller size translates to significantly lower KV cache requirements and data transfer costs.
As a result, when using Qwen2-1.5B as the small model in our framework, we can achieve strong performance with fewer routes to the large model and substantially less data movement, leading to greater efficiency gains. In contrast, the efficiency gap is narrower with the Llama models because the performance difference between the SLM and LLM in that family is smaller relative to their size difference.
Q5: What is the computational cost of training the router?
R5: Thank you for your question regarding the computational cost of training the router. Training the router itself is relatively inexpensive, as it is a lightweight MLP trained on a fixed dataset of SLM/LLM outputs. For example, training the router for CommonsenseQA on an NVIDIA H100 GPU takes less than one hour, and across all benchmarks, the router training can be completed efficiently on a single modern GPU.
However, the collection of token-level preference labels is significantly more computationally demanding, since it requires generating full responses for certain steps in the training data. The total computational cost for collecting these labels depends on the average prompt and response length of each dataset. In our experience, the label collection process typically takes between two and eight hours on 8 H100 GPUs, depending on the dataset.
We will clarify these points in the revised version to provide a more complete picture of the training pipeline’s resource requirements.
Dear reviewer iqRc,
We sincerely thank you for your constructive feedback and positive assessment of our work. We address each concern and question in detail below.
Q1: Are the router training samples (prompts and responses) taken from the benchmarks? If so, are they taken from all the benchmarks/some? What is a typical size of the router training set?
R1: Yes, the router training samples (i.e., prompts and responses) are taken from the benchmark datasets used in our experiments. For each main experiment, we train the router on the training split of the corresponding benchmark (e.g., CSQA, ARC-Challenge, MMLU-PP, GSM8k, or MATH) and evaluate it on the respective test set. In each case, the training set comprises all available samples from the benchmark's training split (see Table 1 in the paper for exact dataset sizes). Specifically, the CommonsenseQA training set contains 9,741 examples; ARC-Challenge has 1,119 examples; MMLU-PP includes 612 samples; GSM8k and MATH have 7,473 and 7,500 samples, respectively.
Q2: How generalizable is the trained router to new tasks, benchmarks, or domains?
R2: We appreciate the reviewer for raising this point. In our paper, we primarily focus on in-domain experiments, where CITER is trained and evaluated on the training and test sets of the same benchmark dataset. However, we note that the router’s input features (i.e., the hidden states of the SLM) are generic representations, which indicates a potential for generalization across different domains. To explore this, we conducted additional experiments where we trained our router on a mixture of all the datasets used in our study, including CSQA, ARC-Challenge, MMLU-PP, GSM8k, and MATH, and then evaluated its performance on the OpenBookQA dataset. The results are presented in the table below.
Due to the limitations of the rebuttal system, we are unable to include figures or plots in our response. Therefore, we present our results in table format, where each row corresponds to a different method, each column represents a specific value of the x-axis variable (e.g., data transfer amount), and each cell shows the accuracy achieved under that setting. To facilitate a fair and direct comparison between methods, we performed linear interpolation on the original data to obtain accuracy values for all methods at the same set of x-axis points. This approach ensures that our comparisons are consistent and transparent across all methods and conditions.
Table R2: Performance comparison between CITER and our baselines on the OpenBookQA dataset. Higher values indicate better performance. The best result is bolded.
| Accuracy (%) \ Data Transfer Amount (GB) | 7.1 | 13.4 | 50.0 | 100.0 | 150.0 | 179.3 | 191.1 |
|---|---|---|---|---|---|---|---|
| SLM | 56.6 | / | / | / | / | / | / |
| LLM | / | / | / | / | / | / | 91.1 |
| Speculative Decoding | / | / | / | / | / | 91.1 | / |
| RouteLLM | / | 58.0 | 64.7 | 73.8 | 83.6 | 88.8 | / |
| Co-LLM | / | 59.6 | 66.9 | 76.7 | 86.2 | 89.4 | / |
| CITER (Ours) | / | 60.0 | 67.3 | 79.1 | 89.8 | 90.6 | / |
These results demonstrate that CITER continues to outperform all baselines on the new dataset, highlighting the generalizability of our approach.
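For reference, the linear interpolation step described above amounts to something like the following sketch; the measurement values here are hypothetical placeholders, not the numbers behind the table.

```python
import numpy as np

# Hypothetical raw (data transfer in GB, accuracy in %) measurements for one method.
raw_x = np.array([13.4, 62.0, 118.5, 179.3])
raw_y = np.array([60.0, 69.5, 82.4, 90.6])

# Common grid of x-axis points shared by all methods in the table.
grid = np.array([13.4, 50.0, 100.0, 150.0, 179.3])

# np.interp performs piecewise-linear interpolation between measured points;
# we only query points inside the measured range to avoid extrapolation.
interp_y = np.interp(grid, raw_x, raw_y)
print(dict(zip(grid.tolist(), np.round(interp_y, 1).tolist())))
```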
Dear reviewer iqRc,
We would like to follow up to see if the response addresses your concerns or if you have any further questions. We would really appreciate the opportunity to discuss this further if our response has not already addressed your concerns. Thank you again!
Dear reviewer iqRc,
Thank you very much for your positive evaluation of our manuscript. We are pleased to learn that our responses and modifications have successfully addressed your previous concerns. We sincerely appreciate your increasing the score.
Thank you once again for your time and professional guidance.
Thank you for the detailed response! The authors have addressed all my questions and concerns. I increased my score and look forward to reading the updated paper.
This paper introduces CITER, which features the training of a token-level decoding router with a cost-aware RL objective, a shortcut reward, and long-term planning. The high-level idea generally follows Co-LLM, but it manages to deliver sizable cost and accuracy gains over Co-LLM.
Reasons to Accept
- The proposed router training approach that considers long-term impact of routing decisions is novel and effective.
- Comprehensive baselines, experiments and analyses that show solid improvements brought by CITER.
Reasons to Reject
- The reliance on in-domain ground-truth data for the tuning of the router model makes CITER less generalizable compared to methods like speculative decoding.
- The contribution of CITER is somewhat incremental to Co-LLM in terms of the idea on token-level decoding routing.
Questions for the Authors
- Will the code be open-sourced?
- It would be helpful if the paper included a discussion of CITER’s current limitations and suggested possible ways to address them in future research.
Q2: The contribution of CITER is somewhat incremental to Co-LLM in terms of the idea on token-level decoding routing.
R2: We acknowledge that CITER builds upon the general idea of token-level routing, as introduced by Co-LLM. However, our contributions go significantly beyond prior work in two main aspects:
First, CITER explicitly formulates router training as a cost-aware RL objective with long-term planning, allowing the router to optimize for the global response quality rather than per-token accuracy. Unlike Co-LLM, which bases routing largely on token-wise classification or local confidence scores, CITER treats the routing process as a Markov Decision Process (MDP) and leverages reinforcement learning (RL) techniques to account for the long-term impact of routing decisions. This means the router is trained not just to maximize immediate correctness at each step, but to consider how current choices influence the quality and efficiency of the entire generated sequence. By incorporating cost-aware reward functions and evaluating complete response quality, CITER’s RL-based framework allows for a more holistic and principled trade-off between accuracy and inference cost, resulting in smarter and more globally optimal routing behavior.
Second, the shortcut reward and iterative optimization strategies in CITER enable efficient and scalable training while preserving performance gains. One of the main challenges in training a token-level router is the substantial computational overhead of simulating all possible routing trajectories. To address this, CITER introduces a shortcut reward mechanism that accelerates the label collection process by using local correctness when appropriate, without compromising final performance. Furthermore, our iterative optimization scheme alternates between collecting routing labels based on current policy and refining the router, which both reduces training cost and improves convergence stability. These strategies together allow CITER to scale to large datasets and models while maintaining strong empirical performance.
Our experimental results demonstrate that these innovations lead to sizable improvements over Co-LLM across multiple benchmarks, establishing CITER as a substantial advancement over previous token-level routing approaches.
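To make the cost-aware objective in the first point more concrete, one schematic way to write it is shown below. This is illustrative notation only, not the paper's exact equations: R denotes a final-response quality reward, c_LLM a per-call cost for invoking the LLM, and lambda the accuracy-cost trade-off coefficient.

```latex
% Schematic cost-aware RL objective for the router policy \pi_\phi
% (illustrative notation; not the paper's exact formulation).
\max_{\phi}\;\;
\mathbb{E}_{x \sim \mathcal{D},\; a_t \sim \pi_\phi(\cdot \mid s_t)}
\Bigl[\, R\bigl(y_{1:T},\, y^{*}\bigr)
      \;-\; \lambda \sum_{t=1}^{T} \mathbb{1}\{a_t = \text{LLM}\}\, c_{\text{LLM}} \Bigr]
```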
Q3: Will the code be open-sourced?
R3: Yes, we will open-source our code upon the acceptance of the paper. We believe this will benefit the community and facilitate further research on efficient LLM inference.
Q4: It would be helpful if the paper included a discussion of CITER’s current limitations and suggested possible ways to address them in future research.
R4: Thank you for this valuable suggestion. We agree that discussing CITER’s current limitations and potential future directions is important for clarifying its scope and impact. We will include a dedicated section on limitations and future work. Specifically, the current limitations of CITER and avenues for future research are as follows:
- Although CITER is designed to be agnostic to the base model’s modality, our experiments are limited to language models and do not include other modalities. Applying CITER to additional domains, such as vision or multimodal tasks, to evaluate its effectiveness and generalizability beyond language models is a promising direction for future research.
- This work does not explore integrating CITER with other LLM inference acceleration techniques, such as model compression, speculative decoding, or more advanced model architectures like those with MoE (Mixture-of-Experts) structures. We believe such integrations could further reduce inference costs and enhance CITER’s suitability for real-time or edge-based applications.
Dear reviewer xyvj,
We sincerely appreciate your thoughtful feedback and suggestions. We address the main concerns and questions below.
Q1: The reliance on in-domain ground-truth data for the tuning of the router model makes CITER less generalizable compared to methods like speculative decoding.
R1: We appreciate the reviewer’s observation regarding the generalizability of CITER. We acknowledge that our current approach relies on in-domain ground-truth data for training the router, which may limit its direct applicability to entirely new domains compared to approaches like speculative decoding that do not require explicit supervision. However, the router’s input features, specifically, the hidden states from the SLM, are generic representations, which suggests the potential for generalization across different domains. To explore this, we conducted additional experiments where we trained our router on a mixture of all the datasets used in our study, including CSQA, ARC-Challenge, MMLU-PP, GSM8k, and MATH, and then evaluated its performance on the OpenBookQA dataset. The results are presented in the table below.
Due to the limitations of the rebuttal system, we are unable to include figures or plots in our response. Therefore, we present our results in table format, where each row corresponds to a different method, each column represents a specific value of the x-axis variable (e.g., data transfer amount), and each cell shows the accuracy achieved under that setting. To facilitate a fair and direct comparison between methods, we performed linear interpolation on the original data to obtain accuracy values for all methods at the same set of x-axis points. This approach ensures that our comparisons are consistent and transparent across all methods and conditions.
Table R1: Performance comparison between CITER and our baselines on the OpenBookQA dataset. Higher values indicate better performance. The best result is bolded.
| Accuracy (%) \ Data Transfer Amount (GB) | 7.1 | 13.4 | 50.0 | 100.0 | 150.0 | 179.3 | 191.1 |
|---|---|---|---|---|---|---|---|
| SLM | 56.6 | / | / | / | / | / | / |
| LLM | / | / | / | / | / | / | 91.1 |
| Speculative Decoding | / | / | / | / | / | 91.1 | / |
| RouteLLM | / | 58.0 | 64.7 | 73.8 | 83.6 | 88.8 | / |
| Co-LLM | / | 59.6 | 66.9 | 76.7 | 86.2 | 89.4 | / |
| CITER (Ours) | / | 60.0 | 67.3 | 79.1 | 89.8 | 90.6 | / |
These results demonstrate that CITER continues to outperform all baselines on the new dataset, highlighting the generalizability of our approach.
Thanks for your detailed response. I have raised the rating but believe that more comprehensive experiments are required to better understand the cross domain/dataset generalizability of the proposed method.
Dear Reviewer xyvj,
Thank you for your positive feedback and for raising the rating of our manuscript. We are very grateful for your support and recognition of our work.
We completely agree that a more comprehensive evaluation of cross-domain/dataset generalizability is a valuable direction. Following your insightful suggestion, we now explicitly state that our current validation focuses primarily on in-domain tasks in our conclusion section. We mention that while our initial experiments on the OpenBookQA dataset show potential for generalization, a more thorough investigation across diverse datasets and settings will be helpful for better understanding our method.
Thank you again for your insightful guidance. Your feedback has been invaluable in helping us not only improve the current manuscript but also shape our future research.
This paper explores policy optimization as a way to route between small and large LLMs for inference cost savings. The authors frame token-level routing as an MDP w/ a policy that can be optimized. They evaluate this method on CSQA, ARC-Challenge, MMLU-PP, GSM-8k and MATH. The baselines are classifier/algorithm based token level and query level methods like Speculative Decoding, RouteLLM, Co-LLM. They also do some analysis on the possible reasons behind this method being better.
Reasons to Accept
- Routing a sequence of tokens between 2 possible models is an MDP, exploring a Policy Optimization for this instead of independently classifying these tokens makes sense to me and is well motivated.
- The authors way of using ground truth to get preference labels for cheap makes sense (like teacher forcing).
- Intuitions given for differences in quality+tradeoffs between different methods is good.
Note on thoroughness:
- I checked the math and it seems to be correct.
- I skimmed the Related Work and it seems reasonable, though it's always possible to miss a few papers.
Reasons to Reject
N/A - this paper should be published, though I think spending time on my suggestions for improving draft quality would greatly improve clarity.
Questions for the Authors
Questions:
- Could you explain why choosing just SLM for CITER-S ablates on long-term influence of the trajectory chosen by the policy? Is this just hoping that LLM has better representation even if the token prediction is wrong?
Suggestions to improve:
- More thorough setup of background for readers without a RL background would be helpful. (Eg. Why this specific RL policy method? Most of this setup is clear from the DPO paper, but more setup is helpful along w/ citing DPO earlier in 2.1, though you do this in 2.2)
- More description of how you get data for RouteLLM and Co-LLM would be helpful
- More description of (data transformation) x-axis definition in Figure 2-5. Details in Appendix are good, but a 1-2 line high level definition in paper would be good.
- Move Section 3.7 to Appendix if you need space for more clear descriptions. More extensive analysis on this could have been nice.
Typos etc:
- Line 199: "In addition, its , Furthermore,"
- Compatibility Analysis: maybe just say "Results on Different Model Families" or something
- Case Study Analysis on the Router -> Qualitative Analysis on the Router
Q3: Typos and Wording
R3: We appreciate the reviewer for pointing out these issues.
Line 199: "In addition, its , Furthermore,"
Compatibility Analysis: maybe just say "Results on Different Model Families" or something
Case Study Analysis on the Router -> Qualitative Analysis on the Router
We have carefully proofread the revised manuscript to address all grammar mistakes and typos, including the one in Line 199. Specifically, we deleted "In addition, its," in Line 199 and corrected the reference to Algorithm 1 in Line 113. We have also fixed the typo in Line 137, along with other similar issues.
We have revised section headings for clarity, e.g., “Compatibility Analysis” to “Results on Different Model Families” and “Case Study Analysis on the Router” to “Qualitative Analysis on the Router.”
Once again, we are grateful for the reviewer’s detailed feedback and support for acceptance. We believe these changes will further improve the clarity and impact of our work.
Dear reviewer ZgLP,
Thanks for your thorough and positive evaluation, as well as the constructive suggestions to further improve the clarity and accessibility of our paper.
Q1: Could you explain why choosing just SLM for CITER-S ablates on long-term influence of the trajectory chosen by the policy? Is this just hoping that LLM has better representation even if the token prediction is wrong?
R1: Thank you for this insightful question. In CITER-S, we ablate the long-term influence by assigning the SLM whenever both the SLM and LLM predict the next token incorrectly, rather than evaluating the impact of this choice on the final output. By selecting the SLM in this scenario, our primary goal is to reduce inference cost. Since the only available information is that both models have made incorrect predictions, we do not assume that the LLM’s output is inherently better than the SLM’s when both are wrong; therefore, there is no justification for preferring the LLM in such cases, and we default to using the SLM.
Q2: Suggestions to Improve
R2: We appreciate the reviewer for suggesting the following improvements and will address them in the revised version.
More thorough setup of background for readers without a RL background would be helpful. (Eg. Why this specific RL policy method? Most of this setup is clear from the DPO paper, but more setup is helpful along w/ citing DPO earlier in 2.1, though you do this in 2.2)
We will add an introductory explanation clarifying our use of a reinforcement learning (RL) framework. Specifically, we adopt RL because the router functions as a decision policy model, and this approach enables the router to learn routing strategies that minimize inference cost while maintaining high performance. We will emphasize this motivation in Section 2.1, as mentioned at the beginning of Section 2, to make the rationale behind our approach clearer for readers.
More description of how you get data for RouteLLM and Co-LLM would be helpful
In our experiments, all datasets used contain gold-standard labels. Therefore, we adopt the same procedure as RouteLLM for the MMLU dataset to generate ground-truth preference labels across all our datasets. Specifically, we use both the pure LLM and SLM to generate complete responses for each query in the training set, then apply a rule-based approach to extract the final answer from these responses. The preference label is subsequently determined by the correctness of the extracted final answer.
Co-LLM operates as a token-level router, requiring token-level ground-truth next-token labels for training. We follow the same generation trajectory as CITER, which first produces a rationale for query analysis followed by the final answer to the question. For datasets such as GSM8k, MATH, and MMLU-PP, which already provide reasoning steps, we directly use the provided chain-of-thought (CoT) rationales. For datasets like CSQA and ARC-Challenge, which lack such rationales, we employ ChatGPT 3.5 turbo to generate reasoning steps by prompting it to produce a reasoning process for the original question, given the ground-truth answer. We also instruct ChatGPT 3.5 turbo to provide a predicted answer based on the generated rationale, and filter out any rationales whose predicted answers do not match the ground-truth answers.
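As an illustration of the query-level preference labels described above for RouteLLM, the procedure can be sketched as follows; the regular-expression pattern and the generation callables are hypothetical placeholders, not the exact implementation.

```python
import re

def extract_final_answer(response: str):
    """Rule-based final-answer extraction (illustrative pattern, not the exact rules)."""
    m = re.search(r"answer is\s*\(?([A-E0-9\.\-]+)\)?", response, flags=re.IGNORECASE)
    return m.group(1) if m else None

def query_level_preference(prompt, gold_answer, slm_generate, llm_generate):
    """Prefer the SLM whenever it already answers correctly; otherwise prefer the LLM
    if it is correct; return None when neither model answers correctly."""
    slm_ans = extract_final_answer(slm_generate(prompt))
    llm_ans = extract_final_answer(llm_generate(prompt))
    if slm_ans == gold_answer:
        return "SLM"
    if llm_ans == gold_answer:
        return "LLM"
    return None
```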
More description of (data transformation) x-axis definition in Figure 2-5. Details in Appendix are good, but a 1-2 line high level definition in paper would be good.
We agree that providing a high-level explanation in the main text will help readers. We will add a one- to two-sentence summary clarifying the definition of the data transfer amount. Specifically, it refers to the volume of data, including both model weights and the KV cache, that must be transferred from GPU HBM to the on-chip cache.
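As a rough sketch of how such a quantity can be estimated per decoding step (the shapes and byte widths below are illustrative assumptions, not the paper's measurement code):

```python
def step_data_transfer_gb(n_params, context_len, n_layers, n_kv_heads, head_dim,
                          bytes_per_elem=2):
    """Estimate bytes moved from GPU HBM to on-chip cache for one decoding step:
    all model weights plus the KV cache for the current context (illustrative)."""
    weight_bytes = n_params * bytes_per_elem
    kv_bytes = 2 * n_layers * n_kv_heads * head_dim * context_len * bytes_per_elem  # K and V
    return (weight_bytes + kv_bytes) / 1e9

# Example with assumed 8B-parameter model shapes (hypothetical configuration).
print(step_data_transfer_gb(n_params=8e9, context_len=1024,
                            n_layers=32, n_kv_heads=8, head_dim=128))
```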
Move Section 3.7 to Appendix if you need space for more clear descriptions. More extensive analysis on this could have been nice.
Thank you for the space-saving suggestion. If additional clarification or description is required elsewhere in the paper, we will move Section 3.7 (case study analysis) to the Appendix to accommodate more essential content or further analysis.
Thank you for your response! I will maintain my score, and look forward to seeing the updated draft.
Dear Reviewer ZgLP,
Thank you very much for suggesting the revisions to our manuscript. We believe they have significantly improved our paper's quality. We also want to express our sincere appreciation for your high evaluation of our paper.
Thank you once again for your time and professional guidance.
The paper focuses on improving the inference efficiency of a large language model (LLM) via collaborative inference with a small language model (SLM). More specifically, the paper employs a token-level router that assigns the task of generating a token at a particular step to SLM or LLM. The paper proposes an RL-based training procedure to develop such a token-level router. The paper also proposes an efficient method to obtain the reward for the token-level routing decisions made by the router. The paper evaluates the proposed approach, namely CITER, with other techniques to realize LLM inference efficiency, including query-level routing, token-level routing without RL-based training, and speculative decoding. As for the evaluation metric, the paper considers avg. quality vs. avg. amount of data transferred during decoding on multiple benchmarks.
Reasons to Accept
- The paper studies an important and timely problem of improving the inference efficiency for LLMs.
- The proposed token-level routing method (CITER) demonstrates better performance compared to existing query- and token-level routing methods.
- CITER makes novel technical contributions compared to existing token-level routing methods by employing RL-based router training while taking the long-term effect of routing decisions into account.
- The paper provides adequate empirical support for the utility of CITER and carries out multiple ablation studies to identify the value of individual components in CITER.
Reasons to Reject
- The paper is missing the discussion on/comparison with model cascading, which is another approach to realize efficient LLM inference. E.g., https://openreview.net/forum?id=KgaBScZ4VI (query-level cascading) and https://openreview.net/forum?id=vo9t20wsmd (Section 3.1; token-level cascading).
- KV caching is an important component of modern LLM inference. The paper needs to expand the relevant discussion (Remark 2.1) to make it more comprehensive. For instance, the SLM and LLM each maintain their own KV cache. If the router returns to the LLM after, say, t tokens (with the last t tokens decoded by the SLM), then the LLM would have to compute key-values for those last t tokens in the context (which were generated by the SLM). Thus, it appears that the proposed method would require running a forward pass on all the tokens by both the SLM and LLM. Could the authors clarify this point? Also, did the authors take this into account while studying the quality vs. cost tradeoff for the proposed method?
- There is significant scope for improving the presentation of the paper. E.g.,
- Eq. (4) and Eq. (2) use inconsistent notation.
- The text after Eq. (2) mentions an expectation over the prompt, but the dependence of the objective on the prompt is not immediately apparent from the equation.
- How does the stated result follow in the equation after Line 140?
- (Minor) Is there a typo in Line 137?
Questions for the Authors
- Could the authors provide more details on the router architecture? Is it a bag-of-words style model? Does it take context text as input, or does it act on some kind of representation of the context?
- Speculative decoding typically focuses on latency. Could authors include quality vs. latency plots, even for a subset of experimental settings?
Q5: Speculative decoding typically focuses on latency. Could authors include quality vs. latency plots, even for a subset of experimental settings?
R5: We appreciate the reviewer's suggestion. In our paper, our main focus was on “quality vs. computation cost” (measured by data transfer/FLOPs) to provide a hardware-agnostic comparison. However, we agree that latency is a practically important metric, especially for speculative decoding. We have conducted experiments on the CommonsenseQA dataset to collect end-to-end latency measurements (average wall-clock inference time per sample) for both our method and Speculative Decoding. The results are shown in the table below. Note that the latency results are influenced by the hardware environment, so the results may vary across different machines and configurations.
Due to the limitations of the rebuttal system, we are unable to include figures or plots in our response. Therefore, we present our results in table format, where each row corresponds to a different method, each column represents a specific value of the x-axis variable (e.g., inference time), and each cell shows the accuracy achieved under that setting. To facilitate a fair and direct comparison between methods, we performed linear interpolation on the original data to obtain accuracy values for all methods at the same set of x-axis points. This approach ensures that our comparisons are consistent and transparent across all methods and conditions.
Table R5: Latency comparison between CITER and Speculative Decoding on the CommonSense QA dataset.
| Accuracy (%) \ inference time (s) | 0.4 | 1.6 | 2.0 | 2.8 | 4.0 | 6.1 | 10.9 | 12.8 |
|---|---|---|---|---|---|---|---|---|
| SLM | 34.4 | / | / | / | / | / | / | / |
| LLM | / | / | / | / | / | / | / | 87.0 |
| Speculative Decoding | / | / | / | / | / | / | 87.0 | / |
| CITER (Ours) | / | 57.6 | 65.4 | 76.6 | 80.8 | 81.9 | 85.6 | / |
Preliminary results show that CITER can achieve competitive quality while reducing inference latency by approximately 44%. This improvement is primarily due to the router mechanism, which requires the LLM only to generate a few tokens, eliminating the need for the LLM to verify every token during the generation process.
Q3: There is significant scope for improving the presentation of the paper. E.g.,
Thank you for carefully reading our notation and derivations.
Eq. (4) and Eq. (2) use inconsistent notation.
R3: We have revised the notation in Eq. (4) to match the notation used in Eq. (2). Additionally, we have corrected the remaining inconsistent occurrences in the subsequent equations in the revised version, ensuring that all definitions are clear and consistent throughout the paper. In this context, the policy in Eq. (5) refers to the policy model that we aim to optimize; it starts as an initial policy and, after training, converges to the optimal policy.
The text after Eq. (2) mentions an expectation over the prompt, but the dependence of the objective on the prompt is not immediately apparent from the equation.
The prompt is defined as the initial input of the generation process. While it does not appear explicitly in the objective function, the expectation is indeed taken over the prompt as well. This is because the prompt, together with the inherent randomness from the language model and the transition kernel, determines the distribution over the entire state sequence. In other words, the generation process starts from the prompt, and subsequent states are determined by both the initial prompt and the stochastic transitions governed by the model and policy. Therefore, although the prompt is not written directly in the equation, the expectation implicitly integrates over it. We will clarify this point in the revised paper to improve the reader’s understanding.
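In schematic notation (not the paper's exact Eq. (2)), making the prompt explicit, the objective can be written as:

```latex
% Illustrative notation only: x is the prompt, s_t the state, a_t the routing action.
J(\pi) \;=\; \mathbb{E}_{x \sim \mathcal{D}}\;
\mathbb{E}_{a_t \sim \pi(\cdot \mid s_t),\; s_{t+1} \sim P(\cdot \mid s_t, a_t)}
\Bigl[\, \sum_{t=1}^{T} r(s_t, a_t) \;\Bigm|\; s_1 = x \Bigr]
```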
How does the stated result follow in the equation after Line 140?
This step is derived under an additional assumption, which we will state explicitly in the revised version. Combining this assumption with the definitions given earlier yields the equation after Line 140.
(Minor) Is there a typo in Line 137?
Yes. We have carefully proofread the revised manuscript to address all grammar mistakes and typos, including the one in Line 137. Specifically, we corrected the typo in Line 137, deleted "In addition, its," in Line 199, and fixed the reference to Algorithm 1 in Line 113, among other corrections.
Q4: Could the authors provide more details on the router architecture? Is it a bag-of-words style model? Does it take context text as input, or does it act on some kind of representation of the context?
R4: The router is implemented as a multi-layer perceptron (MLP) with three hidden layers, ReLU activations, and batch normalization, as described in the implementation details of Section 3.1. We use the hidden state corresponding to the last generated token from the SLM as the input to our router. This approach enables the router to utilize the rich representations extracted by the SLM, allowing routing decisions to be informed not only by the current token but also by the broader context accumulated thus far.
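A minimal sketch of such a router is given below, assuming the SLM's last-token hidden state as input; the layer width, hidden size, and decision threshold are illustrative choices, not the paper's exact configuration.

```python
import torch
import torch.nn as nn

class TokenRouter(nn.Module):
    """MLP router: maps the SLM's last-token hidden state to P(route to LLM)."""
    def __init__(self, hidden_dim: int, width: int = 1024):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(hidden_dim, width), nn.BatchNorm1d(width), nn.ReLU(),
            nn.Linear(width, width), nn.BatchNorm1d(width), nn.ReLU(),
            nn.Linear(width, width), nn.BatchNorm1d(width), nn.ReLU(),
            nn.Linear(width, 1),
        )

    def forward(self, h_last: torch.Tensor) -> torch.Tensor:
        # h_last: (batch, hidden_dim) hidden state of the last generated token from the SLM
        return torch.sigmoid(self.net(h_last)).squeeze(-1)

# Usage sketch: route the next token to the LLM when the predicted probability
# exceeds a threshold (eval mode so BatchNorm handles a batch of one).
router = TokenRouter(hidden_dim=1536).eval()   # 1536 assumes a Qwen2-1.5B-sized SLM
p_llm = router(torch.randn(1, 1536))
use_llm = bool(p_llm.item() > 0.5)
```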
Table R1: Performance comparison between CITER and our baselines on the CommonSense QA dataset. Higher values indicate better performance. The best result is bolded.
| Accuracy (%) \ Data Transfer Amount (GB) | 9.1 | 29.7 | 100.0 | 200.0 | 300.0 | 400.0 | 487.9 | 547.7 |
|---|---|---|---|---|---|---|---|---|
| SLM | 34.4 | / | / | / | / | / | / | / |
| LLM | / | / | / | / | / | / | / | 87.0 |
| Speculative Decoding | / | / | / | / | / | / | 87.0 | / |
| RouteLLM | / | 37.8 | 49.1 | 59.5 | 66.4 | 71.5 | 81.4 | / |
| Co-LLM | / | 41.0 | 62.0 | 67.6 | 76.3 | 79.9 | 83.9 | / |
| LMC | / | 36.3 | 42.9 | 52.1 | 62.7 | 72.5 | 81.2 | / |
| FCSD | / | 40.5 | 49.2 | 62.0 | / | / | / | / |
| CITER (Ours) | / | 57.6 | 63.6 | 71.3 | 78.0 | 81.5 | 84.6 | / |
Q2: KV caching is an important component of modern LLM inference. The paper needs to expand the relevant discussion (Remark 2.1) to make it more comprehensive. For instance, let's say and denote the KV caches of SLM and LLM, respectively. If the router returns to LLM, say after t tokens (with last t tokens decoded by SLM), then LLM would have to compute Key-values for the last t tokens in the context (which were generated by the LLM). Thus, it appears that the proposed method would require running a forward pass on all the tokens by both SLM and LLM. Could the authors clarify this point? Also, did the authors take this into account while studying the quality vs. cost tradeoff for the proposed method?
R2: We thank the reviewer for highlighting the critical issue of KV cache reuse and the true cost of switching between SLM and LLM.
We acknowledge that if the router decides to route to the large model (LLM) near the end of the generation process, then both the SLM and LLM will have processed almost the entire sequence of tokens, except for the tokens generated by the SLM after the final LLM invocation. In our experimental setup, we have already included this recomputation and the associated data movement overhead in our calculation of inference cost (data transfer amount), as reported in the main results.
Although in the worst-case scenario the LLM may need to process nearly all tokens, in practice, the majority of tokens during generation are produced by the SLM. The number of LLM forward passes is thus greatly reduced compared to always using the LLM. At the decoding stage, although KV cache mechanisms significantly reduce the per-step computational requirements of LLM inference, each forward pass still incurs the cost of transferring the model’s KV cache from off-chip memory (GPU HBM) to on-chip cache. By reducing the number of LLM forward passes, our method substantially reduces the frequency of KV cache transfers, thereby achieving significant savings in data transfer and associated inference cost.
To be more concrete, suppose the input has m tokens and the output has n tokens. In standard LLM decoding, the total data movement (in terms of KV cache transfer) grows with both m and n, with the proportionality constant determined by the LLM’s per-token KV cache size. In our method, only a fraction of the tokens are routed to the LLM (assuming routing is roughly uniformly distributed over the sequence). Ignoring the comparatively negligible KV cache size of the SLM, our approach incurs total data movement roughly proportional to that fraction of the standard-decoding cost. For clarity, this analytic estimate neglects the SLM’s data movement, but in our experiments we do account for the SLM’s actual data transfer contribution. Our results show that even when fewer than 5% of tokens are routed to the LLM, CITER delivers a significant accuracy boost compared to the SLM alone; when between one third and one half of the tokens are routed to the LLM, CITER achieves output quality comparable to the full LLM, but at much lower total inference cost. As we discuss in Section 3.1, since LLM generation is memory-bound, reducing the total data transfer and KV cache movement directly alleviates the I/O bottleneck in practical deployments.
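One plausible way to write the estimate above is shown below; the symbol p (the fraction of tokens routed to the LLM) and the exact expression are illustrative, and the constants depend on the LLM's per-token KV-cache size rather than being taken from the paper.

```latex
% Illustrative estimate of total KV-cache movement (in units of the LLM's
% per-token KV cache); not the paper's exact expression.
\text{standard decoding:}\quad \sum_{t=1}^{n} (m + t) \;=\; n\,m + \tfrac{n(n+1)}{2},
\qquad
\text{CITER:}\quad \approx\; p \,\Bigl( n\,m + \tfrac{n(n+1)}{2} \Bigr)
```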
Dear reviewer mhTX,
We sincerely appreciate your thoughtful comments and constructive suggestions.
Q1: The paper is missing the discussion on/comparison with model cascading, which is another approach to realize efficient LLM inference. E.g., https://openreview.net/forum?id=KgaBScZ4VI (query-level cascading) and https://openreview.net/forum?id=vo9t20wsmd (Section 3.1; token-level cascading).
R1: We thank the reviewer for raising the important point regarding the comparison with model cascading methods. We fully agree that model cascading is a relevant and influential line of work for efficient LLM inference, and we will include a detailed discussion of these methods in the Related Work section.
Specifically, regarding the two representative works cited:
- Language Model Cascades: Token-Level Uncertainty and Beyond (LMC, KgaBScZ4VI)
This approach performs query-level cascading, where the small model (SLM) first generates the entire sequence, and then the decision to defer to the large model (LLM) is made based on the uncertainty (e.g., entropy or confidence derived from logits) during generation. This means that whenever deferral happens, both the SLM and LLM have to process the same input tokens, leading to redundant computational costs. Moreover, similar to query-level routing, query-level cascading lacks the flexibility to distinguish between easy and difficult parts of the sequence, thereby missing opportunities for fine-grained efficiency gains that token-level routing can provide.
- Faster Cascades via Speculative Decoding (FCSD, vo9t20wsmd)
This work innovatively introduces token-level cascading and tackles the challenge of implementing deferral rules by leveraging speculative decoding. However, their mechanism relies exclusively on the logits output by the model to determine when to defer to the LLM. As a result, their method is highly dependent on the calibration of the underlying model—i.e., how well the model’s predicted confidences align with actual correctness. Additionally, similar to the CITER-S variant discussed in our Section 3.3, this approach only considers the current token’s prediction quality, without accounting for the long-term effects that current routing decisions may have on the overall generation process.
In contrast, our approach offers two key advantages: 1. By introducing an independent router model to generate routing decisions, CITER reduces reliance on the SLM/LLM’s intrinsic calibration, allowing for more robust and generalizable routing. 2. During router training, we explicitly optimize not just for token-level correctness but for global response quality, thereby directly modeling the impact of each routing decision on the entire generation process, which is a crucial distinction that leads to improved overall performance.
To further address the reviewer’s concern, we have implemented both of them as query-level and token-level cascading baselines and compared them with CITER on the CommonsenseQA dataset. The results are presented in the table below.
Due to the limitations of the rebuttal system, we are unable to include figures or plots in our response. Therefore, we present our results in table format, where each row corresponds to a different method, each column represents a specific value of the x-axis variable (e.g., data transfer amount), and each cell shows the accuracy achieved under that setting. To facilitate a fair and direct comparison between methods, we performed linear interpolation on the original data to obtain accuracy values for all methods at the same set of x-axis points. This approach ensures that our comparisons are consistent and transparent across all methods and conditions.
The results align with our previous discussion: CITER consistently outperforms both query-level and token-level cascading methods. Notably, all cascading baselines perform worse than corresponding router-based approaches, largely due to the uncalibrated nature of the logits produced by the models. In particular, for the token-level cascading method, the maximum achievable accuracy is limited to 64.2% with Data Transfer Amount = 226.7GB even when using a threshold of 0. This means that even if we assume deferring to the LLM incurs no additional cost, and always defer to the LLM whenever its maximum logit score is equal to or greater than that of the SLM, the overall performance still remains suboptimal. The underlying reason is that, at least in our experiments, the logits are poorly calibrated with respect to true token correctness: the LLM’s output can be more accurate even when its logit score is lower than that of the SLM.
Thank you for providing a detailed response. Most of my questions and concerns have been successfully addressed. Including new discussion and results in the revised paper will significantly enhance its quality. I have increased my score.
Dear reviewer mhTX,
Thank you very much for your positive evaluation and for suggesting the revisions to our manuscript. We are pleased to learn that our responses and modifications have successfully addressed your previous concerns. We sincerely appreciate your increasing the score.
Thank you once again for your time and professional guidance.
This paper introduced a novel technique in the framework of using a hybrid of SLM and LLM to improve decoding efficiency of LLMs while maintaining the accuracy. The reviewers unanimously recommend acceptance of the paper and so do I. The technique is simple, practical and scalable and has the potential for large scale impact when implemented correctly.
Great work and looking forward to the impact!