Zebra-Llama: Towards Extremely Efficient Hybrid Models
Abstract
Reviews and Discussion
The paper introduces Zebra-Llama, a novel approach for creating efficient hybrid language models by combining existing pre-trained Transformer models with State Space Models and Multi-head Latent Attention layers. The method significantly reduces computational overhead, memory requirements, and KV cache size, achieving performance competitive with traditional Transformers while drastically cutting the resources required for model deployment and training.
Strengths and Weaknesses
Strengths:
S1. Zebra-Llama demonstrates remarkable efficiency in reducing KV cache size (up to ~50x compression), significantly improving inference throughput (2.6x-3.8x faster than comparable models).
S2. The intermediate layer distillation and layer selection strategies are methodologically robust and clearly contribute to model effectiveness. The ablation studies clearly demonstrate the effectiveness of each idea (layer selection and intermediate layer distillation).
Weaknesses:
W1. Llamba-8B performs better than Zebra-Llama on MMLU. Could the authors provide insights into this result? It would be helpful to conduct a controlled study with the Llamba distillation pipeline using the same training dataset, since Llamba is a pure RNN model with a fixed-size state, which offers greater memory savings than the proposed approach.
W2. Scaling to larger models and different model families: The paper could be improved by exploring different models, such as Qwen3, or by scaling to larger model sizes.
Questions
Any insights into why the split is done in this way?
Limitations
-
Most importantly, Llamba-8B performs better than Zebra-Llama on MMLU. Could the authors provide insights into this result? It would be helpful to conduct a controlled study with the Llamba distillation pipeline using the same training dataset, since Llamba is a pure RNN model with a fixed-size state, which offers greater memory savings than the proposed approach.
-
Scaling to larger models and different model families: The paper could be improved by exploring different models, such as Qwen3, or by scaling to larger model sizes.
-
Consider evaluating on more challenging tasks, such as GSM8K and MATH, to better demonstrate the effectiveness of the proposed approach.
I can increase my score if some of these concerns are addressed.
Final Justification
Given the significant amount of work during the rebuttal, including the Qwen experiment, more challenging benchmarks, and the long-context evaluation, I have raised my score.
Formatting Concerns
N/A
We sincerely thank the reviewer for the comprehensive review and valuable feedback, particularly your recognition of our method's strengths in KV cache saving, inference speedup, and the effectiveness of our intermediate layer distillation and layer selection strategy. We have carefully addressed the insightful comments regarding the MMLU score comparison, the comparison with Llamba, model scalability, and performance on other datasets, and supported our responses with additional experimental evidence. We appreciate your consideration of these updates and would be grateful if you could kindly revisit your evaluation of our paper in light of the improvements.
[Answer 1] Comparing MMLU score with Llamba for the 8B model
We acknowledge the observed MMLU score gap between our 8B Zebra-Llama and Llamba-8B. As shown in [1], the MMLU task format can be particularly challenging for State Space Models (SSMs) to adapt to, often requiring more training effort to improve. After a thorough review of the Llamba paper, we believe this gap likely stems from two key factors:
- Fewer training tokens and smaller teacher model: Our 8B Zebra-Llama model used the 8B base model as its teacher and was trained on 11 billion tokens. Llamba's 8B model, conversely, was distilled from a more powerful 70B teacher model in its final stage and trained on an additional 1 billion tokens (12 billion total). We believe this combination of a stronger teacher and a larger training budget contributed to Llamba's better MMLU performance.
- More curated dataset selection: Our Zebra-Llama model utilizes the same training datasets as Mamba-In-Llama (OpenHermes-2.5, GenQA, and Infinity-Instruct). The Llamba paper, however, emphasizes the critical role of dataset selection/pre-processing and demonstrates that leveraging a carefully filtered version of FineWeb-Edu could significantly boost MMLU scores after distillation. We hypothesize that applying a similarly rigorous data curation strategy would significantly enhance Zebra-Llama's MMLU performance and lead to higher scores than Llamba. As presented in [Answer 2], our MMLU score significantly surpasses that of Llamba when using the same dataset and training pipeline.
[1] Waleffe, Roger, et al. "An empirical study of mamba-based language models." arXiv preprint arXiv:2406.07887 (2024).
[Answer 2] Evaluating Llamba and Zebra-Llama Under Consistent Training Conditions
Thanks for the thoughtful and constructive comment. Following the reviewer's suggestion, we conducted a controlled comparison between Zebra-Llama and Llamba using the same dataset and training pipeline. Due to time constraints, we were only able to complete this evaluation at the 1B model scale.
For both Zebra-Llama and Llamba, we adopt the three-stage training pipeline introduced in the Llamba paper. For the dataset, since the filtered FineWeb-Edu is not publicly available, we adopt the same training dataset as Zebra-Llama, with ~6B tokens in total. We maintained the same token distribution across stages as Llamba: 225M tokens for Stage 1, 2.025B for Stage 2, and 3.75B for Stage 3. All other training parameters were kept consistent and are listed below (see also the configuration sketch after the list):
- Target Model: Llama-3.2-1B-Instruct
- Teacher Model: Llama-3.2-1B-Instruct
- Initial Learning Rate: 8e-5
- Batch Size: 96
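For clarity, the controlled setup above can be captured in a small configuration sketch. This is our illustrative summary only; the field names below are ours and not taken from the actual training scripts.

```python
from dataclasses import dataclass, field

@dataclass
class ControlledComparisonConfig:
    """Illustrative summary of the controlled Zebra-Llama vs. Llamba setup."""
    target_model: str = "Llama-3.2-1B-Instruct"   # architecture being converted
    teacher_model: str = "Llama-3.2-1B-Instruct"  # same-size teacher for both methods
    initial_learning_rate: float = 8e-5
    batch_size: int = 96
    # Token budget per stage, mirroring Llamba's three-stage split (~6B tokens total)
    stage_tokens: dict = field(default_factory=lambda: {
        "stage_1": 225_000_000,
        "stage_2": 2_025_000_000,
        "stage_3": 3_750_000_000,
    })

config = ControlledComparisonConfig()
assert sum(config.stage_tokens.values()) == 6_000_000_000  # ~6B tokens in total
```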
The results, presented in the table below, demonstrate that our Zebra-Llama surpasses Llamba (pure SSM) across all 8 tasks with only 3.91% KV cache, highlighting the advantages of integrating efficient attention modules (i.e., MLA) with Mamba.
| Model & Setting | KV Size | Avg. | ARC | ARE | HS | MMLU | OBQA | PIQA | RA | WG |
|---|---|---|---|---|---|---|---|---|---|---|
| Llama3.2-1B-Instruct | 100% | 51.87 | 37.97 | 63.3 | 60.65 | 46.05 | 34.8 | 74.32 | 38.18 | 59.67 |
| Llamba-1B | 0% | 47.23 | 34.98 | 62.12 | 54.41 | 27.35 | 34.6 | 71.82 | 34.83 | 57.7 |
| Zebra-Llama-1B, 4MLA-12M2 (Ours) | 3.91% | 49.32 | 37.12 | 62.67 | 55.54 | 37.1 | 35.8 | 72.2 | 35.5 | 58.64 |
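For context, the KV-cache fractions above admit a simple back-of-the-envelope estimate. The sketch below is our own illustration: the Llama-3.2-1B attention shape (16 layers, 8 KV heads, head dimension 64) is public, while the MLA latent and decoupled-RoPE dimensions are assumed values chosen to match the reported 3.91% and are not necessarily the exact configuration; Mamba2 layers are treated as caching no KV entries (they keep only a small fixed-size state).

```python
def hybrid_kv_fraction(n_layers: int, n_mla_layers: int,
                       n_kv_heads: int, head_dim: int,
                       kv_latent_dim: int, rope_dim: int) -> float:
    """Per-token KV-cache size of a hybrid (MLA + Mamba2) model relative to the
    original full-attention model. Mamba2 layers are assumed to cache no KV."""
    full_per_layer = 2 * n_kv_heads * head_dim   # keys + values per token
    mla_per_layer = kv_latent_dim + rope_dim     # compressed latent + decoupled RoPE key
    return (n_mla_layers * mla_per_layer) / (n_layers * full_per_layer)

# Llama-3.2-1B-like shape with 4 MLA + 12 Mamba2 layers; latent/RoPE dims are assumed.
print(hybrid_kv_fraction(n_layers=16, n_mla_layers=4,
                         n_kv_heads=8, head_dim=64,
                         kv_latent_dim=128, rope_dim=32))  # -> 0.0390625 (~3.91%)
```

Under these assumptions, replacing 12 of the 16 attention layers with Mamba2 and compressing the remaining 4 into MLA yields roughly a 25x KV-cache reduction.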
[Answer 3] Scaling to larger models or other model families
We thank the reviewer for the comments. Theoretically, it is possible to scale our method to larger models, but computational resource and time constraints prevented us from conducting such experiments. Scaling to larger models remains part of our future plans.
However, we did perform additional experiments on Qwen2.5-0.5B-Instruct and Qwen2.5-1.5B-Instruct to demonstrate that our proposed method (now Zebra-Qwen) performs well across other model families. Using the same training process as for the Llama series, our 0.5B Zebra-Qwen model achieved better performance than the target model with only 6.25% KV cache. For the 1.5B Zebra-Qwen, even when using the same model as the teacher, we achieved similar performance to the target model with 12.5% KV cache size.
| Model & Setting | Teacher | KV Size | Avg. | ARC | ARE | HS | MMLU | OBQA | PIQA | RA | WG |
|---|---|---|---|---|---|---|---|---|---|---|---|
| Qwen2.5-0.5B-Instruct | - | 100% | 47.93 | 33.11 | 59.05 | 52.26 | 45.86 | 34.2 | 70.62 | 32.06 | 56.27 |
| Zebra-Qwen-0.5B, 4MLA-20M2 (Ours) | 1.5B | 6.25% | 48.68 | 38.74 | 66.92 | 50.83 | 38.43 | 37.2 | 69.91 | 32.34 | 55.09 |
| Qwen2.5-1.5B-Instruct | - | 100% | 58.59 | 47.01 | 75.84 | 68.24 | 60.13 | 41 | 76.01 | 37.8 | 62.67 |
| Zebra-Qwen-1.5B, 14MLA-14M2 (Ours) | 1.5B | 12.5% | 58.16 | 48.63 | 75.17 | 67.64 | 53.87 | 41.6 | 75.73 | 38.66 | 64.01 |
[Answer 4] Extending Evaluation to More Challenging Tasks
Thanks for the comments; we agree on the importance of including more challenging tasks. As the reviewer suggested, we include results on GSM8K and GPQA and summarize them below.
| Model & Setting | KV Size | Teacher | Tokens | Avg | GSM8K (8-shot) | GPQA (Main) | GPQA (Diamond) |
|---|---|---|---|---|---|---|---|
| Llama-3.1-8B-Instruct | 100% | - | - | 48.32 | 78.17 | 35.49 | 31.31 |
| MambaInLlama-8B | 50% | 70B | 20B | 45.94 | 75.05 | 29.46 | 33.3 |
| X-EcoMLA-8B | 9.37% | 8B | 7B | 43.6 | 70.81 | 29.69 | 30.3 |
| Llamba-8B | 0% | 8B+70B | 12B | 37.58 | 57.62 | 28.35 | 26.77 |
| Zebra-Llama-8B, 16MLA-16M2 (Ours) | 5.47% | 8B | 11B | 42.49 | 68.16 | 29.02 | 30.3 |
| Zebra-Llama-8B, 8MLA-24M2 (Ours) | 2.73% | 8B | 11B | 39.76 | 63.53 | 27.46 | 28.28 |
| Model & Setting | KV Size | Teacher | Tokens | Avg | GSM8K (8-shot) | GPQA (Main) | GPQA (Diamond) |
|---|---|---|---|---|---|---|---|
| Llama-3.2-3B-Instruct | 100% | - | - | 42.8 | 70.89 | 29.24 | 28.28 |
| MambaInLlama-3B | 50% | 70B | 20B | 37.01 | 53.06 | 27.68 | 30.3 |
| X-EcoMLA-3B | 9.37% | 8B | 7B | 41.01 | 65.28 | 27.46 | 30.3 |
| Llamba-3B | 0% | 3B+70B | 10B | 32.43 | 47.76 | 22.77 | 26.77 |
| Zebra-Llama-3B, 14MLA-14M2 (Ours) | 4.69% | 8B | 9B | 39.93 | 62.77 | 27.23 | 29.8 |
| Zebra-Llama-3B, 6MLA-22M2 (Ours) | 2.01% | 8B | 9B | 38.12 | 55.19 | 28.35 | 30.81 |
Key observations from the results:
- Our 1B model excels, achieving the best average score and outperforming the target Llama-3.2-1B-Instruct model with only 7.81% KV cache.
- Our 3B model shows a 6.7% and 2.63% performance drop compared to the target model and X-EcoMLA, respectively, with 21.32x and 2x KV cache compression. Even at 2.01% KV cache, our overall performance surpasses Mamba-In-Llama and Llamba.
- Our 8B model exhibits a 12% and 7.5% performance drop compared to the target model and MambaInLlama, with 18.28x and 9.14x KV cache compression, respectively. The reasoning task GSM8K is identified as a primary contributor to this gap. We hypothesize this is due to our efficient training scheme (8B teacher + 11B training tokens). Although Mamba-In-Llama exhibits a higher GSM8K score, it was trained with a 70B teacher and 20B training tokens. Since math reasoning tasks are generally harder, we expect Zebra-Llama to achieve further improvements with more training tokens.
- Our Zebra-Llama consistently outperforms Llamba across all model sizes.
To test the math reasoning potential of Zebra-Llama, we conducted additional experiments on our 1B model, where we insert more samples from the OpenMathInstruct dataset during the SFT stage. As shown in the table, GSM8K accuracy consistently improves as we add more math tokens. We believe the math reasoning performance of Zebra-Llama can be substantially improved with more data engineering effort.
| | Our Original SFT | Our Original SFT + 1M Math samples | Our Original SFT + 2M Math samples | Our Original SFT + 5M Math samples |
|---|---|---|---|---|
| GSM8K (8-shot) | 43.44 | 45 | 47.92 | 49.81 |
| Avg of 8 tasks in Table 1 | 51.35 | 51.08 | 51.2 | 50.94 |
[Answer 5] Splitting of $W^{UKV}$
The main goal of the split of $W^{UKV}$ is to align with the MLA decompression operation when creating the MLA layer. The process begins by concatenating the original trained key ($W^{K}$) and value ($W^{V}$) weight matrices into $[W^{K}; W^{V}]$ and creating a single low-rank approximation of them via SVD:

$[W^{K}; W^{V}] \approx U_r \Sigma_r V_r^{\top} = W^{UKV} W^{DKV}$

Considering how we concatenate $W^{K}$ and $W^{V}$, we know that $W^{UKV}$ can be divided into two parts, where the upper part maps the latent to keys and the lower part maps the latent to values:

$W^{UKV} = [W^{UK}_{\text{full}}; W^{UV}]$

It is straightforward to get $W^{UV}$. However, since MLA only decompresses the nope dimensions for each key head from the latent, we need to select $d^{\text{nope}}_h$ dimensions out of $d_h$ from the upper part $W^{UK}_{\text{full}}$. In this work, we assume that we always take the first $d^{\text{nope}}_h$ dimensions of each head as the nope dimensions.
Such a procedure can be expressed as:

$W^{UK}_{\text{full}} \rightarrow \text{reshape so that the per-head dimension } d_h \text{ is exposed} \rightarrow \text{take the first } d^{\text{nope}}_h \text{ dimensions of each head} \rightarrow \text{reshape back to obtain } W^{UK}$
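A minimal PyTorch sketch of this initialization step is given below. It is our own illustration under stated assumptions: the weight layout follows the usual (out_features, in_features) convention, the rank and nope dimension are placeholders, and per-head handling is simplified relative to the actual implementation.

```python
import torch

def init_mla_from_kv(W_K: torch.Tensor, W_V: torch.Tensor,
                     rank: int, n_kv_heads: int, d_nope: int):
    """Split a joint low-rank approximation of [W_K; W_V] into MLA projections.

    W_K, W_V: (d_out, d_model) weights of one attention layer's key/value projections.
    Returns the shared down-projection W_DKV and the split up-projections W_UK, W_UV.
    """
    d_out, _ = W_K.shape
    W_KV = torch.cat([W_K, W_V], dim=0)             # [W_K ; W_V], shape (2*d_out, d_model)

    # Truncated SVD gives the low-rank factorization  W_KV ≈ W_UKV @ W_DKV
    U, S, Vh = torch.linalg.svd(W_KV, full_matrices=False)
    W_UKV = U[:, :rank] * S[:rank]                  # (2*d_out, rank)  up-projection
    W_DKV = Vh[:rank, :]                            # (rank, d_model)  down-projection (latent)

    # Upper half maps the latent to keys, lower half maps it to values.
    W_UK_full, W_UV = W_UKV[:d_out], W_UKV[d_out:]

    # MLA only decompresses the nope dimensions of each key head from the latent,
    # so keep the first d_nope dims of every head of W_UK_full.
    head_dim = d_out // n_kv_heads
    W_UK = (W_UK_full.reshape(n_kv_heads, head_dim, rank)[:, :d_nope, :]
            .reshape(n_kv_heads * d_nope, rank))
    return W_DKV, W_UK, W_UV
```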
We sincerely thank the reviewer for the constructive feedback and the time dedicated to evaluating our responses. We have carefully addressed the insightful comments regarding MMLU comparisons, evaluations against Llamba, model scalability, and performance on additional datasets, including results on the RULER benchmark. Our updated experiments were designed to directly support and clarify the key points raised. We appreciate your consideration of these updates and would be grateful if you could kindly revisit your evaluation of our paper in light of these improvements.
Thanks for the reply. I think those responses make sense.
Could you please run some additional long-context benchmarks, such as the RULER benchmark? My concern is that the models may only perform well on short contexts.
Thank you.
[Answer 6] Adding RULER Benchmark Results
Thank you for raising this important point.
We would like to point out that strong long-context performance requires additional long-context training data. Due to constraints in computation and datasets, our models are trained only with a 2048 sequence length, which is generally not suitable for long-context benchmarks like RULER. For these models trained with a 2K sequence length, we have now conducted RULER benchmark evaluations at 4k, 8k, and 16k context lengths. The results are summarized below for both the 1B and 3B models.
As shown, the long-context performance of Zebra-Llama and the other methods (MambaInLlama and Llamba) is generally worse than that of the original Llama, which is expected since the original Llama is fine-tuned on long-context data with sequences of up to 128K tokens.
Nevertheless, Zebra-Llama consistently outperforms Llamba at both the 1B and 3B scales. Moreover, Zebra-Llama outperforms MambaInLlama at the 3B scale across all context lengths, and at the 1B scale it performs on par with MambaInLlama at 4k and outperforms it at 8k and 16k. Importantly, these gains come with a much smaller memory footprint: Zebra-Llama uses only ~5% of the KV cache, whereas MambaInLlama requires approximately 10x more KV memory. This demonstrates that Zebra-Llama achieves a superior trade-off between efficiency and long-context reasoning performance.
As a future step, we will fine-tune the model on an extended-context dataset to further improve performance on long-context tasks.
We appreciate your suggestion; it helped us further validate the robustness of our model under extended-context scenarios.
| Model | KV cache | 4k | 8k | 16k |
|---|---|---|---|---|
| Llama3.2-1B-Instruct | 100% | 69.62 | 65.05 | 61.24 |
| MambaInLlama-1B | 50% | 38.75 | 21.55 | 3.88 |
| Llamba-1B | 0% | 5.03 | 3.79 | Memory Access Fault |
| Zebra-Llama-1B | 3.91% | 35.75 | 24.8 | 13.37 |
| Model | KV cache | 4k | 8k | 16k |
|---|---|---|---|---|
| Llama3.2-3B-Instruct | 100% | 83.77 | 76.91 | 75.19 |
| MambaInLlama-3B | 50% | 41.71 | 22.62 | 0.88 |
| Llamba-3B | 0% | 8.88 | 4.6 | Memory Access Fault |
| Zebra-Llama-3B | 4.69% | 58.69 | 38.24 | 9.97 |
This paper introduces Zebra-Llama, a hybrid language model designed for high inference efficiency. The hybrid architecture is constructed from existing pre-trained Transformers. Through a post-training pipeline, a new architecture is created that interleaves Mamba2 layers with Multi-head Latent Attention (MLA) layers. This process involves three key stages: refined weight initialization from a teacher model, intermediate layer distillation to align representations, and a sensitivity-aware layer selection strategy called SMART that places the MLA layers where they are most necessary. The authors demonstrate that this method dramatically reduces KV cache size and improves throughput, while maintaining the performance of the original teacher model.
Strengths and Weaknesses
Strengths:
- The paper is generally well-written and easy to follow.
- The approach to construct the hybrid model is intuitive and the KV cache reduction is significant.
- The authors strengthen their claims by reporting direct inference throughput metrics.
Weaknesses:
-
Limited Novelty: While the hybrid architecture seems to yield meaningful efficiency gains, the novelty of the approach is limited. The use of distillation to transfer knowledge from pre-trained Transformers to hybrid architectures is common practice, and upcycling pre-trained attention into MLA to enable extreme KV cache compression was already proposed in X-EcoMLA, reducing the novelty of this paper's architectural contributions.
-
Sensitivity Measure-Aware Replacement Method: The Sensitivity Measure-Aware Replacement of Transformer Layers (SMART) strategy is central to the paper's contribution but is presented as a heuristic and is not well-motivated. This raises a major concern that the strategy may be overfitted to the distilled Llama architecture which is the only architecture family considered in this work. To substantiate the claim that SMART is a generally applicable technique, its effectiveness should be demonstrated on other architectural families, such as Qwen.
-
Incomplete Results: The few-shot performance is only shown for the 8B model. Why not report the results for all models?
Minor points:
- Line 92: should be "targeting extreme efficiency".
- Table 1 Caption: There is a typo in the word "except".
- Line 307: The word "accuracy" is repeated.
Questions
The results in Table 1 indicate that the performance gap between the distilled hybrid models and the teacher baselines is larger for bigger models. Could you elaborate on why this might be the case?
Limitations
The authors sufficiently discuss the limitations of their work.
Final Justification
After the rebuttal, the authors addressed my concerns regarding applying SMART Layer Selection to a new architecture family and provided the complete few-shot results.
Formatting Concerns
No concerns
We sincerely thank the reviewer for the constructive feedback. In this revised version, we have carefully addressed the main concerns raised in your review. Specifically, we clarified our key contributions [Answer 1], extended the evaluation of our SMART layer selection method to additional architectural families [Answer 2], completed the previously missing few-shot results [Answer 3], and provided further analysis to clarify the performance gap between the larger distilled and target models [Answer 4]. We hope these substantial updates address your concerns, and we would be grateful if you could revisit your evaluation based on the additional evidence presented.
[Answer 1] Our Key Contributions
Thank you for your comment. Given the importance of hybrid model architectures for balancing performance and efficiency, our paper introduces several key contributions spanning architecture design, training methodology, and empirical results:
Architecture: We propose the first hybrid model that combines MLA with Mamba2 layers to replace classical Transformer blocks. This design not only reduces memory usage but also addresses the quadratic attention bottleneck, enabling high-throughput inference with minimal performance degradation.
Training Methodology: We placed particular emphasis on improving initialization and layer selection, as both have a significant impact on model performance. Notably, layer selection has received limited attention in prior work. To this end, we developed careful initialization strategies to enable stable integration of Mamba2 and MLA layers, introduced auxiliary intermediate layer distillation (sketched below), and proposed our SMART layer selection algorithm.
Empirical Results: Using our training pipeline, we successfully trained a family of hybrid models that achieve competitive performance with baseline Transformers across a range of benchmarks. Notably, our models reduce KV cache usage to as low as 2–5% and improve inference throughput by up to 2.1×, while also requiring less training data and shorter training time, surpassing prior state-of-the-art hybrid methods.
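As a rough illustration of the auxiliary intermediate-layer distillation mentioned above, the sketch below aligns selected student hidden states with those of a frozen teacher via an MSE loss; the exact loss formulation, weighting, and layer pairing in our pipeline may differ.

```python
import torch
import torch.nn.functional as F

def intermediate_distill_loss(student_hiddens, teacher_hiddens, layer_pairs):
    """Auxiliary alignment of selected student layer outputs with the teacher's.

    student_hiddens / teacher_hiddens: lists of (batch, seq, d_model) hidden states.
    layer_pairs: (student_layer_idx, teacher_layer_idx) tuples to align.
    """
    losses = []
    for s_idx, t_idx in layer_pairs:
        target = teacher_hiddens[t_idx].detach()  # teacher is frozen
        losses.append(F.mse_loss(student_hiddens[s_idx], target))
    return torch.stack(losses).mean()

# Example usage: combine with the usual logit-level distillation objective.
# total_loss = kd_loss + aux_weight * intermediate_distill_loss(s_hs, t_hs, [(3, 3), (11, 11)])
```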
[Answer 2] SMART Layer Selection on other Architectural Families
Thanks for the comment. We agree that the proposed SMART layer selection was only tested on the Llama family and that it is necessary to test it on another model series such as Qwen. To show the generalizability of our SMART layer selection algorithm, we apply it to another scenario where we start from the Qwen-2.5-0.5B model (24 layers in total) and select 4 layers to be MLA and the remaining 20 to be Mamba2. Since we cannot include figures in the rebuttal, we include the layer sensitivity scores for each layer in the table below. The layers selected by SMART are those at indices {0, 8, 15, 22}.
| Layer Index | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| Sensitivity score | 98.02 | 11.33 | 96.01 | -1539.14 | 32.71 | 26.37 | 31.51 | -68.93 | 26.05 | -287.4 | 53.66 | -201.49 | -2.02 | 107.16 | 80.8 | 385.9 | 59.76 | 67.29 | 4 | -1.07 | 164.41 | 221.56 | 243.83 | 129.34 |
We also evaluated several alternative layer selection schemes, and the results remained consistent with those reported in the paper (Table X), further demonstrating the generalizability of the proposed SMART layer selection algorithm.
| Model & Setting | Avg. | ARC | ARE | HS | MMLU | OBQA | PIQA | RA | WG |
|---|---|---|---|---|---|---|---|---|---|
| Uniform [0,6,12,18] | 46.86 | 34.73 | 64.56 | 48.39 | 38.81 | 33.2 | 69.04 | 31.58 | 54.54 |
| Uniform [5,11,17,23] | 47.27 | 35.24 | 65.91 | 49.08 | 37.46 | 33.4 | 69.42 | 31.1 | 56.51 |
| Max score [15,20,21,22] | 47.38 | 34.22 | 65.66 | 48.38 | 38.01 | 34.60 | 69.64 | 32.44 | 56.12 |
| SMART [0,8,15,22] | 47.52 | 34.98 | 66.50 | 48.37 | 38.56 | 34.40 | 69.04 | 32.54 | 55.80 |
[Answer 3] Complete Few-shot Results
We focused on reporting few-shot performance for the 8B model in the main text due to space and time constraints; however, we acknowledge the importance of completeness and now provide the full few-shot performance results for both 1B and 3B models (and will include them in the final version of our paper). As shown, our Zebra-Llama models consistently demonstrate competitive few-shot performance across multiple benchmarks at smaller scales as well.
| Model and Setting | KV% | Avg. | ARC(25) | HS(10) | MMLU(5) | WG(5) | TQ(0) |
|---|---|---|---|---|---|---|---|
| Llama3.2-1B-Inst | 100% | 49.98 | 41.38 | 59.8 | 45.48 | 59.35 | 43.88 |
| MambaInLlama-1B-50%* | 50% | 48.60 | 42.32 | 60.46 | 35.55 | 59.35 | 45.31 |
| X-EcoMLA-1B | 9.37% | 47.97 | 41.04 | 56.13 | 35.32 | 60.77 | 46.59 |
| Llamba-1B | 0% | 47.57 | 41.72 | 60.34 | 31.88 | 60.69 | 43.2 |
| Zebra-Llama-1B,8MLA-8M2(Ours) | 7.80% | 49.68 | 45.56 | 59.44 | 37.81 | 60.77 | 44.8 |
| Zebra-Llama-1B,6MLA-10M2(Ours) | 5.86% | 49.06 | 44.03 | 59.22 | 36.06 | 60.54 | 45.46 |
| Model and Setting | KV% | Avg. | ARC(25) | HS(10) | MMLU(5) | WG(5) | TQ(0) |
|---|---|---|---|---|---|---|---|
| Llama3.2-3B-Inst | 100% | 60.536 | 52.39 | 73.51 | 59.71 | 67.32 | 49.75 |
| MambaInLlama-3B-50% | 50% | 61.184 | 51.88 | 74.58 | 52.31 | 67.64 | 59.51 |
| X-EcoMLA-3B* | 9.37% | 57.500 | 49.49 | 69.20 | 52.26 | 66.69 | 49.86 |
| Llamba-3B | 0% | 58.010 | 50.09 | 74.21 | 49.87 | 70.09 | 45.79 |
| Zebra-Llama-3B,14MLA-14M2(Ours) | 4.69% | 60.120 | 53.67 | 71.30 | 51.05 | 67.64 | 56.94 |
| Zebra-Llama-3B,8MLA-20M2(Ours) | 2.86% | 58.488 | 54.52 | 70.44 | 46.43 | 65.98 | 55.07 |
[Answer 4] Performance Gap between Larger Distilled and Target models
Thank you for the comment. The larger performance gap observed for bigger hybrid models (especially the 8B models) is likely due to two main factors:
- the use of the same size 8B teacher for training our Zebra-Llama 8B models—unlike the 1B and 3B models, which had larger teachers—and
- the lack of scaling in training data proportional to model capacity. Due to limitations in time and computational resources, we kept the DPO dataset size fixed and were only able to use up to 11B tokens for SFT when training the 8B models. We believe that with a larger teacher and more extensive training (both in terms of steps and data volume), the performance of the 8B hybrid models can be further improved. We consider this an important direction for future work.
I would like to thank the authors for their response. Their responses and additional experiments address several of my concerns. The gains achieved by SMART layer selection seem marginal but it's good that the method generalizes to the Qwen architecture. I am happy to raise my score accordingly.
Thank you for taking the time to review our responses and additional experiments. We're glad to hear that our clarifications helped address your concerns, and we sincerely appreciate you raising your score.
This paper studies the composition of efficient hybrid language models from existing pre-trained models. It proposes a pipeline (including weight initialization, intermediate layer distillation, the SMART layer selection strategy, etc.) to mix SSM and MLA layers, in order to greatly improve inference efficiency while maintaining foundational capabilities without significant impact.
Strengths and Weaknesses
Strengths
-
The proposed method is very effective, capable of greatly reducing inference overhead while keeping performance almost unchanged. For instance, the 8B model achieves an impressive 36.6x KV cache compression with only a 2.8% drop in average zero-shot performance.
-
The experimental evaluation is relatively comprehensive, and the ablation analysis, which validates key designs like the novel SMART strategy, is solid
Weaknesses
-
The paper requires more discussion regarding the specific performance differences between the accelerated hybrid model and the original model. For example, while the average score is high, a closer look at Table 1 reveals the 8B model has a notable performance drop on the MMLU benchmark, and a discussion of such specific trade-offs would be valuable.
-
In terms of capability measurement, the benchmarks used in Table 1 are not comprehensive enough. These more challenging benchmarks are crucial for probing deeper reasoning capabilities that may be affected by the architectural changes. I recommend further testing on benchmarks such as DROP (for reading comprehension), AIME (for math reasoning), GPQA (for graduate-level scientific QA), etc.
Questions
See Weaknesses
Limitations
None
Final Justification
I think this paper should be accepted to Neurips 2025. My concerns are mostly regarding the performance comparison and capability measurement in the Experiment section. The responses from the authors have largely addressed them.
Formatting Concerns
None
We sincerely thank the reviewer for acknowledging the effectiveness of our method in reducing inference overhead with minimal performance loss, as well as recognizing the strength of our experimental evaluation and the significance of the SMART layer selection strategy. In the rebuttal, we have carefully addressed your comments regarding [Answer 1] task-wise comparison between hybrid and base models, and [Answer 2] the extension of our evaluation to more challenging benchmarks in the following.
[Answer 1] Task-Wise Comparison of Hybrid and Base Models
Thank you for the helpful observation regarding performance differences between our hybrid models and their corresponding base models. We carefully reviewed the results across all tasks in Table 1 for the 1B, 3B, and 8B models. Across 7 out of 8 benchmarks, our hybrid models consistently achieve either superior or very close performance compared to their base models. The only notable deviation occurs on the MMLU benchmark, which we discuss in more detail below. When comparing the different-size Zebra-Llama models to their corresponding base models, we observed that MMLU performance reaches approximately 84% of the base model at the 1B scale, 86% at the 3B scale, and 83% at the 8B scale, which is fairly consistent across sizes. We explain the potential reasons for the observed MMLU performance gap below:
1) MMLU Task Formatting Difficulty: The observed gap in MMLU performance between hybrid models and pure Transformer baselines may stem from formatting sensitivity in MMLU’s multiple-choice structure. The study by [2] shows that state-space models (SSMs) like Mamba struggle with the standard MMLU format, which involves selecting a single letter (A/B/C/D) corresponding to the correct answer. While SSMs contain the same underlying knowledge as Transformers, they require more training to learn the task format, particularly the routing of information from multiple choices into a single output token—something Transformers handle more naturally via self-attention. This suggests that Zebra-Llama's hybrid use of Mamba/MLA layers may inherit some of this difficulty (inefficient decoding under MMLU’s standard prompting format), especially in tasks like MMLU that rely heavily on understanding structured input-output formatting, rather than open-ended generation or reasoning.
[2] Waleffe, Roger, et al. "An empirical study of mamba-based language models." arXiv preprint arXiv:2406.07887 (2024).
2) Importance of training data selection: We believe that more carefully curated dataset selection can significantly enhance performance, particularly on knowledge-intensive tasks like MMLU. Our Zebra-Llama models were trained using the same datasets as Mamba-In-Llama (which was our main initial baseline), namely OpenHermes-2.5, GenQA, and Infinity-Instruct. However, as highlighted in the Llamba paper [3], training on a carefully filtered version of FineWeb-Edu leads to substantial improvements in MMLU scores following distillation. To investigate this further, we trained a Llamba model using our dataset (without FineWeb-Edu). The MMLU score dropped from 38.11 (as reported in the original paper) to 27.35. In contrast, our Zebra-Llama model, trained on the same dataset and using the same training pipeline, significantly outperformed Llamba on MMLU (with a score of 37.1). This result suggests that Zebra-Llama would likely achieve even higher scores than Llamba if trained with similarly curated data, and we consider this an exciting direction for future work.
Table: MMLU Comparison with Different Training Datasets
| Model | Training Dataset | MMLU Score |
|---|---|---|
| Llamba (Original paper) | FineWeb-Edu (filtered) + OpenHermes-2.5 | 38.11 |
| Llamba (Our re-training) | OpenHermes-2.5 + GenQA + Infinity-Instruct | 27.35 |
| Zebra-Llama (Ours) | OpenHermes-2.5 + GenQA + Infinity-Instruct | 36.91 |
[3] Bick, Aviv, et al. "Llamba: Scaling distilled recurrent models for efficient language processing." arXiv preprint arXiv:2502.14458 (2025).
[Answer 2] Extending Evaluation to More Challenging Benchmarks
Thank you for the thoughtful suggestion. We fully agree that evaluating on more challenging benchmarks is important for assessing deeper reasoning capabilities, especially in light of the architectural changes introduced by our hybrid design. As recommended, we have added results on GSM8K for math reasoning and GPQA for graduate-level scientific question answering. These benchmarks help provide a more complete view of the model's capabilities beyond the core benchmarks. While we were not able to complete testing on DROP due to time and infrastructure constraints, we do include results on RACE, which is also a well-established benchmark for reading comprehension and reasoning over longer passages. We summarize the additional results below (and will add them to the final paper):
| Model & Setting | KV Size | Teacher | Tokens | Avg | GSM8K (8-shot) | GPQA (Main) | GPQA (Diamond) | RACE |
|---|---|---|---|---|---|---|---|---|
| Llama-3.2-1B-Instruct | 100% | - | - | 32.44 | 38.51 | 26.79 | 26.26 | 38.18 |
| Mamba-In-Llama-1B | 50% | 8B | 7B | 30.04 | 24.94 | 26.12 | 30.8 | 38.28 |
| X-EcoMLA-1B | 9.37% | 8B | 7B | 31.3 | 38.06 | 21.43 | 26.26 | 39.43 |
| Llamba-1B | 0% | 1B+70B | 8B | 27.67 | 22.82 | 25 | 25.25 | 37.61 |
| Zebra-Llama-1B,8MLA-8M2(Ours) | 7.81% | 8B | 7B | 32.64 | 41.09 | 25.67 | 25.25 | 38.56 |
| Zebra-Llama-1B,4MLA-12M2(Ours) | 3.91% | 8B | 7B | 29.36 | 29.57 | 23.21 | 26.77 | 37.89 |
| Model & Setting | KV Size | Teacher | Tokens | Avg | GSM8K (8-shot) | GPQA (Main) | GPQA (Diamond) | RACE |
|---|---|---|---|---|---|---|---|---|
| Llama-3.2-3B-Instruct | 100% | - | - | 42.32 | 70.89 | 29.24 | 28.28 | 40.86 |
| Mamba-In-Llama-3B | 50% | 70B | 20B | 38.62 | 53.06 | 27.68 | 30.3 | 43.44 |
| X-EcoMLA-3B | 9.37% | 8B | 7B | 41.93 | 65.28 | 27.46 | 30.3 | 44.69 |
| Llamba-3B | 0% | 3B+70B | 10B | 34.35 | 47.76 | 22.77 | 26.77 | 40.1 |
| Zebra-Llama-3B,14MLA-14M2 (Ours) | 4.69% | 8B | 9B | 41.7 | 62.77 | 27.23 | 29.8 | 46.99 |
| Zebra-Llama-3B,6MLA-22M2 (Ours) | 2.01% | 8B | 9B | 39.12 | 55.19 | 28.35 | 30.81 | 42.11 |
| Model & Setting | KV Size | Teacher | Tokens | Avg | GSM8K (8-shot) | GPQA (Main) | GPQA (Diamond) | RACE |
|---|---|---|---|---|---|---|---|---|
| Llama-3.1-8B-Instruct | 100% | - | - | 47.42 | 78.17 | 35.49 | 31.31 | 44.69 |
| Mamba-In-Llama-8B | 50% | 70B | 20B | 45.98 | 75.05 | 29.46 | 33.3 | 46.12 |
| X-EcoMLA-8B | 9.37% | 8B | 7B | 44.78 | 70.81 | 29.69 | 30.3 | 48.33 |
| Llamba-8B | 0% | 8B+70B | 12B | 38.28 | 57.62 | 28.35 | 26.77 | 40.38 |
| Zebra-Llama-8B,16MLA-16M2 (Ours) | 5.47% | 8B | 11B | 44.05 | 68.16 | 29.02 | 30.3 | 48.71 |
| Zebra-Llama-8B,8MLA-24M2 (Ours) | 2.73% | 8B | 11B | 40.9 | 63.53 | 27.46 | 28.28 | 44.31 |
Key observations from the results:
- Our 1B model excels, achieving the best average score and outperforming the target Llama-3.2-1B-Instruct model with only 7.81% KV cache.
- Our 3B model shows a 1.46% and 0.55% performance drop compared to the target model and X-EcoMLA, respectively, with 21.32x and 2x KV cache compression. Even at 2.01% KV cache, our overall performance surpasses Mamba-In-Llama and Llamba.
- Our 8B model exhibits a 7.11% and 4.2% performance drop compared to the target model and MambaInLlama, with 18.28x and 9.14x KV cache compression, respectively. The reasoning task GSM8K is identified as a primary contributor to this gap. We hypothesize this is due to our efficient training scheme (8B teacher + 11B training tokens). Although Mamba-In-Llama exhibits a higher GSM8K score, it was trained with a 70B teacher and 20B training tokens. Since math reasoning tasks are generally harder, we expect Zebra-Llama to achieve further improvements with more training tokens.
- Our Zebra-Llama consistently outperforms Llamba across all model sizes.
Dear authors,
Thanks for your detailed reply, which has addressed most of my concerns.
Thank you for your kind message and for taking the time to review our answers. We're glad to hear that our response has addressed most of your concerns.
The authors introduce Zebra-Llama, a hybrid LLM designed to significantly reduce KV cache size while preserving the performance of the original pre-trained model on one-shot and few-shot benchmarks.
The proposed workflow consists of four main steps:
-
First, separate Multi-head Latent Attention (MLA) and Mamba2 models are initialized using weights from the pre-trained LLM.
-
Next, this initialization is refined by aligning the intermediate representations of the new models with those of the original.
-
The best layer-mixing scheme is then identified to create an optimal hybrid architecture.
-
Finally, the mixed model is further fine-tuned using Supervised Fine-Tuning (SFT) and Direct Preference Optimization (DPO).
This hybrid model achieves a significant reduction in test-time computation and outperforms other competitive hybrid models in few-shot settings. Further ablation studies validate the layer selection strategy, analyze the trade-offs between model configurations, and examine scaling behaviors as the size of the teacher model changes.
Strengths and Weaknesses
Strengths:
- The paper is well written and quite involving. I appreciate the effort the authors made to deliver the implementation details.
- The authors report a comprehensive series of experimental results and ablation studies. The use of SMART allocation scheme is also justified by the ablation study.
Weakness:
- It seems that the authors are using 1B/3B/8B models for the proposed scheme, how does the workflow perform when the model size is extremely large/small?
- A typo in line 180. "MLA initialization" -> "Mamba2 initialization"
Questions
Given that long Chain-of-Thought (CoT) reasoning is susceptible to compounding errors [1], how does Zebra-Llama's performance on such tasks compare to its base model? Does its hybrid architecture of attention and state-space layers potentially exacerbate error propagation in multi-step reasoning?
[1] Luo, Renjie, et al. "Through the Valley: Path to Effective Long CoT Training for Small Language Models." arXiv preprint arXiv:2506.07712 (2025).
Limitations
yes
Final Justification
This hybrid model achieves a significant reduction in test-time computation and outperforms other competitive hybrid models in few-shot settings. The authors' additional experiments demonstrate the potential for improvement in long CoT setting.
Formatting Concerns
no concerns
We sincerely thank the reviewer for the thoughtful feedback and kind recognition of our contributions. In response to your comments, we have done our best to address the raised concerns and will incorporate the corresponding improvements in the final version of the paper. Specifically, we have provided further insights into [Answer 1] the scalability of the proposed workflow across different model sizes, and [Answer 2] its performance on long chain-of-thought (CoT) reasoning tasks.
[Answer 1] Scalability of the Proposed Workflow across Smaller/Larger Model Sizes
Thank you for the suggestion. While the core components of our workflow—such as SMART layer selection, distillation, and hybrid architectural design—are scalable and model-agnostic, we focused on 1B, 3B, and 8B models in this work for two main reasons:
Comparison with existing baselines: To the best of our knowledge, the largest post-trained hybrid models available in the literature are at the 8B scale. To ensure meaningful comparisons, we aligned our model sizes with those baselines.
Practical training constraints: The 8B model is the largest that could be reasonably trained on a single node (with 8 decent GPUs). Due to time and resource limitations, we were not able to explore training at larger scales.
However, based on your comment and within the limited time available for the rebuttal, we trained a smaller language model—Qwen-0.5B—to verify the scalability of our workflow to both smaller model sizes and a different model family.
| Model & Setting | Teacher | KV Size | Avg. | ARC | ARE | HS | MMLU | OBQA | PIQA | RA | WG |
|---|---|---|---|---|---|---|---|---|---|---|---|
| Qwen2.5-0.5B-Instruct | - | 100% | 47.93 | 33.11 | 59.05 | 52.26 | 45.86 | 34.20 | 70.62 | 32.06 | 56.27 |
| Zebra-Qwen-0.5B,4MLA-20M2 (Ours) | 1.5B | 6.25% | 48.68 | 38.74 | 66.92 | 50.83 | 38.43 | 37.20 | 69.91 | 32.34 | 55.09 |
Our results show that the proposed method (referred to as Zebra-Qwen) performs well on models from the Qwen family. For training, we applied the same process as used for the LLaMA series. Notably, the 0.5B Zebra-Qwen model achieved better performance than the original target model while using only 6.25% KV cache.
[Answer 2] Performance on Long CoT Reasoning
Thank you for the insightful question and for referencing the paper [1]. The paper shows that small language models (≤3B) can suffer performance degradation when trained on limited long Chain-of-Thought (CoT) data, raising concerns about the effectiveness of such training for small models. To investigate whether our hybrid architecture exacerbates this issue, we conducted a controlled evaluation of Zebra-Llama-1B, 8MLA-8M2 on GSM8K, training the model with 8k, 16k, and 32k samples from OpenR1-Math-220k, following a setup similar to the referenced paper. Due to time constraints, our initial evaluation was conducted on a subset comprising the first 150 questions of GSM8K. We follow the same evaluation protocol as [1], sampling 4 answers per question (temperature = 0.6, top-p = 0.95) and reporting the average accuracy.

Our experiments show that as Zebra-Llama is further trained with long CoT data, the average response length dramatically increases from approximately 390 to over 4000 tokens. However, this increased "thinking" does not translate to improved accuracy. The GSM8K accuracy decreases from 39.5% to 25.83% when further trained with 8k CoT samples and declines further to 25% with 16k CoT samples. This observation strongly aligns with the hypothesis in [1] that when small language models are trained on a limited amount of long CoT data, they tend to learn superficial reasoning patterns. This can lead to an accumulation of errors during the extended "thinking" process, ultimately degrading accuracy. When we increased the CoT training dataset to 32k samples, the reasoning ability of the model started to recover and the token efficiency increased, which also aligns with the findings in [1].
In summary, we have observed a similar trend that training small LLM on limited long CoT data hurts the reasoning accuracy, despite a marked increase in the response length. We hypothesize this phenomenon is general for small LLMs across various model structures, including hybrid models like Zebra-Llama. We are going to add this insightful analysis to the final version of the paper.
| | Original | 8k samples | 16k samples | 32k samples |
|---|---|---|---|---|
| Accuracy | 39.5 | 25.83 | 25 | 31.3 |
| Avg Response Length | 390.4 | 5322.59 | 5215.15 | 4533.46 |
[1] Luo, Renjie, et al. "Through the Valley: Path to Effective Long CoT Training for Small Language Models." arXiv preprint arXiv:2506.07712 (2025).
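For completeness, here is a minimal sketch of the sampling protocol described above (4 sampled completions per question at temperature 0.6, top-p 0.95, averaged accuracy), written against the Hugging Face `generate` API; the prompt construction and answer extraction are simplified assumptions rather than our exact evaluation code.

```python
import re
import torch

@torch.no_grad()
def gsm8k_avg_accuracy(model, tokenizer, questions, golds,
                       n_samples=4, temperature=0.6, top_p=0.95,
                       max_new_tokens=2048):
    """Average accuracy over n_samples sampled completions per question."""
    correct, total = 0, 0
    for question, gold in zip(questions, golds):
        inputs = tokenizer(question, return_tensors="pt").to(model.device)
        prompt_len = inputs["input_ids"].shape[1]
        outputs = model.generate(**inputs, do_sample=True,
                                 temperature=temperature, top_p=top_p,
                                 num_return_sequences=n_samples,
                                 max_new_tokens=max_new_tokens)
        for out in outputs:
            completion = tokenizer.decode(out[prompt_len:], skip_special_tokens=True)
            # Naive final-number extraction; real harnesses parse answers more carefully.
            numbers = re.findall(r"-?\d+\.?\d*", completion.replace(",", ""))
            if numbers and abs(float(numbers[-1]) - float(gold)) < 1e-4:
                correct += 1
            total += 1
    return correct / max(total, 1)
```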
Thanks for the detailed response. My positive evaluation of the paper remains unchanged. Meanwhile, it seems the results in the table of [Answer 2] is for the Zebra model only, do you also have the original model's performance when training with different CoT length for a comparison? Feel no pressure for this, it is totally fine if the results are not currently available. I just want to understand how the Zebra scheme affects the trends for long CoT compared to the base model.
[Answer 3] Performance of Llama on Long CoT Reasoning
Thank you for your attention and valuable comment. We had to make some modifications to our codebase to support this experiment on pure attention models. After doing so, we conducted similar evaluations on Llama-3.2-1B-Instruct. The results exhibit a trend very similar to what is reported in [1]. We will include these findings in the final version of the paper. Thank you again for the insightful suggestion; it helps us further strengthen our analysis.
| | Original | 8k samples | 16k samples | 32k samples |
|---|---|---|---|---|
| Accuracy | 42.8 | 32 | 37.5 | 37.75 |
| Avg Response Length | 187.8 | 3948 | 4396 | 3373 |
[1] Luo, Renjie, et al. "Through the Valley: Path to Effective Long CoT Training for Small Language Models." arXiv preprint arXiv:2506.07712 (2025).
Thanks for providing the additional results. All my concerns are addressed, and I appreciate the efforts made by the authors for evaluating the proposed methods in the long CoT scenarios.
Dear Reviewers,
Thank you for sharing your valuable insights and expertise, which have played an important role in the review process. In response to the initial feedback, the authors have submitted a detailed rebuttal addressing the comments raised by the reviewers. I would appreciate it if you could carefully review their response and consider how it may affect your initial evaluation. Please feel free to share your updated thoughts or any additional comments after reviewing the rebuttal.
Additionally, please note that to fulfill the 'Mandatory Acknowledgment' process, the reviewers are expected to:
(i) Carefully read the author rebuttal,
(ii) Engage in meaningful discussion with the authors—and preferably also with fellow reviewers.
(iii) Ask questions, consider responses, and actively participate in the exchange,
(iv) Clearly articulate any unresolved concerns to give authors a fair opportunity to respond. Please avoid situations where the discussion implies “everything is great,” but the final justification form states otherwise. The discussion phase is designed to surface and clarify such issues.
Kindly note that clicking the “Mandatory Acknowledgment” checkbox prematurely does not exempt reviewers from participating in the discussion. Reviewers who do not contribute meaningfully may be flagged using the “Insufficient Review” button, in line with this year’s responsible reviewing guidelines.
Thank you for your time and thoughtful contributions to the review process.
Summary: The paper introduces Zebra-Llama, a family of 1B, 3B, and 8B hybrid models that combine State Space Models (SSMs) and Multi-head Latent Attention (MLA) layers. The approach leverages a refined initialization and post-training pipeline to efficiently transfer knowledge from pre-trained Transformers. Zebra-Llama achieves Transformer-level accuracy with near-SSM efficiency using only 7–11B training tokens, a substantial reduction compared to the trillions of tokens typically required, with distillation from an 8B teacher.
Strengths highlighted by reviewers:
- The hybrid model achieves significant reductions in test-time computation while maintaining competitive performance, outperforming other hybrid models in few-shot settings.
- The method is effective in reducing inference overhead without sacrificing accuracy.
- Additional experiments (e.g., long CoT settings) highlight the potential for further improvements.
- The SMART layer selection method generalizes beyond Llama to Qwen, showing transferability.
Weaknesses (partially unresolved after rebuttal and discussions):
- More challenging benchmarks are needed (e.g., DROP, MMLU), and the LLaMBA results shared in rebuttal should be fully integrated into the paper.
- Long-context evaluation requires clearer presentation and interpretation in the final draft.
- Unclear scalability to very large model sizes—evaluation is limited to 1B/3B/8B models.
- The differences between the accelerated hybrid model and the original Transformer on harder benchmarks need more discussion.
- Proposed architectural contributions have incremental novelty, as knowledge distillation into hybrid architectures is well-studied, and prior work (e.g., X-EcoMLA) already explored MLA-based KV cache compression.
- Larger performance gaps between larger (8B) distilled and target models remain unexplored, with the rebuttal attributing them to limited teacher size and training data. These claims need validation with larger teachers and datasets for 8B models.
Discussion and Decision: The paper received two accept and two borderline accept ratings. Reviewers acknowledged that the rebuttal addressed many concerns by including MMLU comparisons, LLaMBA baselines, scalability analysis, RULER benchmark results, and evaluations on additional datasets. The Area Chair agrees that the paper has clear merits and that the additional results strengthen the contribution. However, some weaknesses remain, particularly regarding scalability, novelty, and evaluations on harder reasoning tasks.
The Area Chair recommends acceptance, provided that the authors:
- Incorporate all new results and clarifications from the rebuttal into the final version (main text or appendix).
- Provide deeper analysis of long-context evaluations and challenging benchmarks.
- Address novelty concerns more explicitly in the final draft.
- Release code and checkpoints as promised in the submission.