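PaperHub
Overall rating: 5.5/10
Decision: Rejected (4 reviewers)
Lowest rating: 3; highest rating: 8; standard deviation: 1.8
Individual ratings: 8, 5, 6, 3
Average confidence: 3.8
Correctness: 2.8; Contribution: 3.0; Presentation: 3.8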
TL;DR

We leverage small LMs as teachers during knowledge distillation to improve large LM pre-training on both quality and training efficiency and rigorously support our methods with novel statistical results.

Keywords
Large language models, knowledge distillation, data selection, efficiency

Reviews and Discussion

Review
Rating: 8

This paper explores using a small language model (SLM) to provide useful supervision in the early stage of large language model (LLM) training. Specifically, LLM training is divided into two stages. The first stage employs KD from an SLM, whose soft labels serve as additional training supervision. The second stage reverts to standard MLE training. Although employing a smaller model to teach a larger one looks counterintuitive, this paper provides theoretical justification for the approach. Empirical results on language modeling also validate its utility: it saves 28% of training time compared to standard training.

Strengths

  1. A significant contribution of this paper is its theoretical understanding of knowledge distillation. The bias-variance trade-off in generalization bounds is particularly interesting, which effectively elucidates why even small teacher models can be helpful. However, since I am not an expert in this field, I would need to refer to the opinions of other reviewers on this part.
  2. The proposed method shows good experimental results, which outperforms standard training while saving 28% training time.

Weaknesses

  1. The small teacher model employed in the experiments is not small enough (1.5B teacher vs. 2.8B student). The authors assume that a small language model is readily available, but this is not always the case. For example, before the release of Llama-405B, the largest Llama model had only 70B parameters. It is unknown whether such a much smaller teacher model can help the training of the student model.
  2. This paper lacks discussion and comparison with recent progress of efficient LLM training [1,2,3].

[1] Yao et al., Masked Structural Growth for 2x Faster Language Model Pre-training, https://arxiv.org/pdf/2305.02869
[2] Li et al., FLM-101B: An Open LLM and How to Train It with $100K Budget, https://arxiv.org/pdf/2309.03852
[3] Shao et al., Patch-Level Training for Large Language Models, https://arxiv.org/pdf/2407.12665

Questions

None

Comment

We thank the reviewer for taking the time to review our submission and providing encouraging comments about both our theoretical and empirical contributions. Please find point-wise responses to your concerns below.

The small teacher model employed in the experiments is not small enough (1.5B teacher vs. 2.8B student).

Thank you for your comment. Reviewer o2jS raised a similar question. To address it, we performed additional experiments in which we trained an 8.6B-parameter student LLM via our proposed SALT framework while employing a 2.8B-parameter teacher SLM. Note that this corresponds to a higher student-to-teacher (size) ratio, $8.6\text{B} / 2.6\text{B} \approx 3.01$, compared to $2.8\text{B} / 1.5\text{B} \approx 1.87$ in our initial submission. Our experiments show that the SALT framework still provides both quality and efficiency gains in the new setting, which has a larger gap between teacher and student model sizes.

We have included the experimental details along with the comprehensive few-shot and downstream evaluations in Appendix K of the revised submission (blue colored text). Below we present the domain-wise few-shot performance and the post-supervised fine-tuning (SFT) performance of the 8.6B LM trained via the SALT framework, comparing it with the natural baseline, i.e., an 8.6B LM trained via standard self-supervised training. (Note that boldfaced and italicized numbers represent the best and the second best results, respectively, in the corresponding category.)

Domain-wise few-shot performance of 8.6B LLM pre-trained via 2.8B small LM teacher

| Domain | # Tasks | SLM | Baseline @100% steps | SALT @70% steps | SALT @100% steps | $\text{SALT}_{\text{DS}}$ @70% steps | $\text{SALT}_{\text{DS}}$ @100% steps |
|---|---|---|---|---|---|---|---|
| World Knowledge | 4 | 22.19 | 26.91 | 27.66 | 28.97 | 28.04 | 28.47 |
| Reading Comprehension | 4 | 53.00 | 56.40 | 56.83 | 57.42 | 56.10 | 57.48 |
| Commonsense Reasoning | 7 | 61.99 | 66.01 | 66.89 | 67.09 | 66.61 | 67.24 |
| LAMBADA | 1 | 36.20 | 58.70 | 65.50 | 64.80 | 54.30 | 55.00 |
| SuperGLUE | 8 | 65.53 | 69.69 | 69.19 | 70.38 | 71.06 | 71.26 |
| NLG | 3 | 4.60 | 5.40 | 5.97 | 5.97 | 5.23 | 5.30 |
| MBPP | 1 | 16.20 | 20.80 | 19.80 | 22.00 | 22.80 | 23.20 |
| Average | 28 | 47.32 | 51.73 | 52.24 | 52.96 | 52.29 | 52.81 |
Comment

Post-SFT results for 8.6B LLM pre-trained via 2.8B small LM teacher

| Model | GSM8K Accuracy | XSUM Rouge-1 | XSUM Rouge-2 | XSUM Rouge-L | CNN/DailyMail Rouge-1 | CNN/DailyMail Rouge-2 | CNN/DailyMail Rouge-L | ANLI-R1 Accuracy | ANLI-R2 Accuracy | ANLI-R3 Accuracy |
|---|---|---|---|---|---|---|---|---|---|---|
| Baseline | 41.85 | 45.10 | 22.68 | 37.36 | 43.73 | 21.19 | 41.29 | 68.80 | 58.90 | 60.58 |
| SALT | 42.84 | 45.37 | 23.04 | 37.69 | 43.69 | 21.16 | 41.22 | 70.20 | 59.30 | 63.25 |
| $\text{SALT}_{\text{DS}}$ | 42.23 | 45.81 | 23.34 | 38.14 | 43.80 | 21.28 | 41.35 | 69.30 | 59.50 | 62.17 |

Takeaway: The few-shot evaluation results exhibit performance and efficiency gains similar to those we observed for 2.8B model training in our initial submission: 1) at 70% of training steps, SALT already performs better than or on par with the fully trained Baseline (@100% steps); and 2) at 100% of training steps, SALT significantly outperforms Baseline. Furthermore, the LLMs trained via SALT (with and without data selection) exhibit strong gains in post-SFT performance across a wide range of downstream tasks.

The authors assume that the small language model is readily available, but it is not always the case. For example, before the release of Llama-405B, the largest Llama model only has 70B parameters. It is unknown whether such a much smaller teacher model can help the training of student model.

Thank you for the comment. We agree with the reviewer that for the widespread adoption of SALT, it has to be useful for various sizes of (small teacher, large student) model pairs. We hope that our results for the (1.5B teacher, 2.8B student) pair in the initial submission and the (2.8B teacher, 8.6B student) pair during the discussion phase help convince the reviewer that SALT indeed has utility for different model sizes and student-to-teacher (size) ratios.

Here, we would also like to highlight that, given the limited timeframe of the discussion phase, to train an 8.6B model via SALT we simply reused the hyperparameters from our initial submission, where they were used to train a 2.8B model. Despite this, we do see significant gains in both final model quality and training efficiency for 8.6B model training via our proposed SALT framework. This also highlights the robustness of the SALT framework to various hyperparameter choices.

This paper lacks discussion and comparison with recent progress of efficient LLM training [1,2,3].

Thank you for pointing out these recent efforts on improving LLM training efficiency.

Please note that the works of Yao et al. [1] and Li et al. [2] fall under the line of work on progressive or stage-wise training that we already discussed in the related work section of our initial submission (line 479 in Section 6). We have revised this discussion to include [1] and [2] (see blue colored text in line 479 of Section 6 of the revised version).

Shao et al. [3] propose a novel training method in which the first phase performs patch-level training, where each patch is obtained by aggregating multiple tokens. The model switches to standard token-level training in the second phase. We plan to add the following text at the end of Section 6 of the revised version (we are happy to include any other work that the reviewer finds relevant here).

Finally, we note that there is significant interest in going beyond next-token prediction objectives and utilizing richer training tasks such as multiple next-token prediction (Gloeckle et al., 2024) or patch-level training (Shao et al., 2024) to make LLM pre-training faster. Such efforts are complementary to our exploration in our work that aims to leverage smaller LMs to improve LLM pre-training.

[1] Yao et al., Masked Structural Growth for 2x Faster Language Model Pre-training, https://arxiv.org/pdf/2305.02869

[2] Li et al., FLM-101B: An Open LLM and How to Train It with $100K Budget, https://arxiv.org/pdf/2309.03852

[3] Shao et al., Patch-Level Training for Large Language Models, https://arxiv.org/pdf/2407.12665

[4] Gloeckle et al., 2024, Better & faster large language models via multi-token prediction, https://arxiv.org/pdf/2404.19737

Comment

I am satisfied with the authors' response, particularly the additional experiments (2.8B teacher vs. 8.6B student) that demonstrate the method's scalability. For the camera-ready version, I recommend that the authors consider conducting experiments across a broader range of teacher sizes while fixing the student model, to illustrate the impact of the teacher model. These experiments could be carried out on a smaller scale, so they would not consume too many computational resources.

The references I provided have been validated in the training of large language models. Perhaps adding an experimental comparison in the camera-ready version (if easily reproducible) would be beneficial. Efficient LLM training is a very important research topic, yet the resources required are prohibitively expensive, making open explorations somewhat limited. This work makes an important step forward, so I will maintain my score (already high) and recommend acceptance.

Comment

Thank you for your encouraging feedback. We will incorporate your suggestions in the final version.

Review
Rating: 5

This paper studies reversed knowledge distillation for language modeling, where a smaller language model provides target probabilities to a larger model. First, the authors derive two excess surrogate risk bounds under different assumptions: one that bounds this risk based mostly on the distribution of the loss on the training data distribution, and the other based on the expected robustness of the teacher model when increasing past context. They then show that the variance on this robustness can be regularized by the $\omega$ weight of the distillation loss.

They leverage their theoretical analysis of KD to (pre)train large language models in two phases, first using a distillation loss, and then using a cross-entropy loss. Overall, they show that their KD approach outperforms both a vanilla baseline and a one-phase reverse KD baseline in terms of final training performance. They also propose a variant, $\text{SALT}_{\text{DS}}$, where the second training phase uses data labeled as hard by a heuristic classifier based on the small LM perplexity. They evaluate their models for few-shot performance on classical benchmarks, fine-tune them, and provide a short analysis of performance against sample difficulty along training.

Overall, their approach provides a promising method to accelerate language model training using smaller model distillation as a loss regularization factor.
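The data-selection variant summarized above can be pictured with a toy filter. This is only a minimal sketch, assuming that "hard" simply means a high per-sequence log-perplexity under the small LM; the function names, the `keep_fraction` parameter, and the toy probabilities are hypothetical, and the heuristic used in the paper may differ.

```python
import math

def sequence_log_perplexity(token_probs):
    """Average negative log-probability assigned to the ground-truth tokens."""
    return -sum(math.log(p) for p in token_probs) / len(token_probs)

def select_hard_sequences(slm_token_probs, keep_fraction=0.5):
    """Return (sorted) indices of the sequences the small LM finds hardest.

    slm_token_probs[i][t] is a hypothetical probability that the small LM
    assigned to the ground-truth token t of sequence i.
    """
    scores = [sequence_log_perplexity(p) for p in slm_token_probs]
    n_keep = max(1, int(len(scores) * keep_fraction))
    # Rank sequences from highest to lowest small-LM log-perplexity.
    ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return sorted(ranked[:n_keep])

# Toy data: four sequences of three tokens each.
toy = [
    [0.9, 0.8, 0.95],  # easy for the small LM
    [0.2, 0.1, 0.3],   # hard
    [0.5, 0.6, 0.4],   # medium
    [0.05, 0.5, 0.1],  # hard
]
hard_indices = select_hard_sequences(toy, keep_fraction=0.5)
print(hard_indices)
```

In this toy setting the second-phase training data would be restricted to the sequences at `hard_indices`.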

Strengths

The authors conduct a very novel and interesting study on knowledge distillation for language modeling in a relatively broad setting. They make reasonable assumptions, and obtain informative bounds on the performance that can be achieved through distillation. They conduct experiments on relatively large models (2.8B), which confirms the relevance of their approach in a setup that approaches industrial standards. This paper could motivate further work on distilling small models into larger ones, as it provides a theoretical justification for such approaches, and an example of empirical confirmation of this idea. The paper is also well-written and clear to follow.

Weaknesses

The main weakness of this paper lies in the lack of articulation from the theoretical analysis to the experimental work. Several concerns can be raised on this subject:

  • The theoretical part is purely based on causal language modeling, while the models are trained with the UL2 procedure, which includes prefix language modeling (which is a variation of causal language modeling), and most importantly the span corruption task, which is not causal. Although it should be possible to extend the theoretical work in that direction, the experiments seem not to accurately reflect the theoretical results.
  • Conversely, the paper lacks empirical experiments that would help validate the introduced assumptions, and help quantify some of the factors of the bounds. For instance, empirical estimations based on existing language models could help refine the comments on the dependency of $C$ and $\{V_t\}$ on $T$ and the robustness of models. This could for instance motivate the choice of the size of the SLM during training.
  • Experimentation at smaller scale could be done to provide intuition on the benefits obtained when $\omega$ varies, and help illustrate how the mentioned tradeoff can be found.
  • More generally, although the authors define several values that should be affected by KD in the theoretical part (e.g. $\xi$), they do not explore the effect of KD on these metrics in the experimental part.

Additionally, I noted several unnatural design choices in the experimental part, which I happen to be more knowledgeable about:

  • In Figure 2, which summarizes (and introduces) the benefits of the methods, the authors use top-1 accuracy as the main LM performance metric, while it is more common to show cross-entropy or perplexity in this case. They also report log-perplexity in Table 1, instead of perplexity, which seems like a more natural choice.
  • Regarding the model architecture, the chosen vocabulary size seems particularly large (256000) compared to standards in the literature for (a priori) monolingual English datasets (e.g. GPTNeo: 50k; Llama2: 32k; Llama3, which is trained on multilingual data: 150k). The choice of the tokenizer in this case could be important, as it can deeply affect the nature of the modeled distribution $\mathcal{D}$, and thus of the tradeoff offered by $\omega$.
  • Regarding the optimization, my opinion is that Adam (or AdamW) would have been a more neutral choice than Adafactor, although it has been used for several models as well. The authors also do not discuss the impact of learning rate when using the KD loss, which is understandable given the cost of running such a hyperparameter search at scale.

In my opinion, another weakness of the paper lies in its lack of transparency and precision on the computational overhead of the proposed method. The authors briefly discuss this overhead, arguing that "As a rule of thumb, a forward pass constitutes 1/4th cost of a training step". This claim lacks support in my opinion, as it depends on implementation and hardware among other things, and as previous works also mention 1/3rd as an estimate (cf. https://arxiv.org/pdf/2203.15556). Moreover, their study does not take the cost of pretraining the SLM into account for the overall training time, or the memory cost of having two models in memory simultaneously. Finally, given the 12% overhead mentioned by the authors and the plot from Figure 2, it is not clear that their approach performs better for pure in-domain language modeling performance when using a FLOPS-related metric on the x-axis. These points are not prohibitive for the method, but should be discussed in the paper.
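The sensitivity of the overhead estimate to the forward-pass rule of thumb can be illustrated with simple arithmetic. The sketch below is not the authors' accounting and is not meant to reproduce their 12% figure: it assumes per-step compute scales linearly with parameter count, and the names `kd_step_overhead`, `relative_training_cost`, and the example `kd_step_fraction` and `steps_fraction` values are hypothetical.

```python
def kd_step_overhead(teacher_params, student_params, forward_fraction):
    """Extra cost of one teacher forward pass, relative to one student training step.

    Assumes compute scales linearly with parameter count (a simplification).
    forward_fraction is the share of a training step spent in the forward pass.
    """
    return forward_fraction * teacher_params / student_params

def relative_training_cost(overhead, kd_step_fraction, steps_fraction=1.0):
    """Total cost relative to a full baseline run.

    kd_step_fraction: share of the executed steps that use distillation.
    steps_fraction: share of the baseline step budget that is actually run.
    """
    return steps_fraction * (1.0 + overhead * kd_step_fraction)

# Sizes from the discussion: 1.5B teacher, 2.8B student.
overhead_quarter = kd_step_overhead(1.5e9, 2.8e9, 1 / 4)  # about 0.134
overhead_third = kd_step_overhead(1.5e9, 2.8e9, 1 / 3)    # about 0.179

# Hypothetical schedule: distillation on 20% of steps, stopping at 70% of the budget.
for name, overhead in (("1/4", overhead_quarter), ("1/3", overhead_third)):
    cost = relative_training_cost(overhead, kd_step_fraction=0.2, steps_fraction=0.7)
    print(f"forward fraction {name}: per-step overhead {overhead:.3f}, relative cost {cost:.3f}")
```

Pretraining the small LM and the memory needed to hold two models are left out of this sketch; both would have to be added for a full comparison.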

Questions

  • I found it a bit confusing to distinguish between $V_t$ and $V_N$. Although they model similar aspects in both sections, it would probably be safer to make them more distinguishable.
  • In Figure 2, do you display results for SALT or for SALT-DS? Why didn't you display both in this graph?
  • Interestingly, the SALT models seem to not drastically outperform the vanilla baseline on in-domain language modeling evaluation (Figure 2, Table 1), but have a noticeable edge in few-shot evaluation and fine-tuning. It could be estimated that they are only 5-10% more data-efficient for pre-training, but are 30% more efficient for evaluation, and in other words that a KD model that reaches a given perplexity gives better evaluation results than a vanilla model at the same perplexity level. Do you have an explanation for this?
  • In a way, the SALT approach could be purely justified from an intuitive standpoint: obtain a relatively strong language model quickly through reverse distillation up to the point where the student is slightly below teacher level, and then proceed with regular training. How would you justify precisely the usefulness of the theoretical part in the design of the experimental part?
  • Similarly, the SALT-DS approach can also be framed as a (relatively) basic data-filtering technique. Do you think it would be relevant to invoke the curriculum learning literature in that case?

Typos

  • There is a parenthesis issue with equation 50.
Comment

In Figure 2, which summarizes (and introduces) the benefits of the methods, the authors use top-1 accuracy as the main LM performance metric, while it is more common to show cross-entropy or perplexity in this case. They also report log-perplexity in Table 1, instead of perplexity, which seems like a more natural choice.

Please note that the counterpart of Figure 2 with train log-perplexity, i.e., cross-entropy loss, is presented as Figure 3 in Appendix I (train log-perplexity corresponds exactly to cross-entropy loss, while perplexity corresponds to the exponential of cross-entropy loss). In our experience, top-1 accuracy is a better predictor of the model's performance. This is especially true when utilizing a weaker teacher LM since, unlike the standard training loss (cf. Eq. (1)), the knowledge distillation loss (cf. Eq. (2)) becomes a biased estimate of the log-perplexity in this case. This fact also manifests in the form of the $D_{\text{TV}}$ term in our theoretical analysis. (See also our response to Reviewer 7nA9.) Thus, focusing on train log-perplexity, especially during the early stage when distillation from a weaker/smaller teacher is utilized, might be misleading, as it would show higher log-perplexity compared to Baseline, which directly optimizes train log-perplexity or cross-entropy loss with respect to the ground-truth next-token distribution (cf. Eq. (1)). Here, we would also like to highlight a growing body of literature (see, e.g., [1, 2]) that argues against relying solely on (log-)perplexity to assess the quality of a language model.

Regarding using perplexity vs. log-perplexity, they are monotonic transformations of each other and convey the same information. We used log-perplexity as it directly corresponds to the cross-entropy loss (whereas perplexity corresponds to the exponential of cross-entropy). However, if the reviewer feels strongly about it, we would be happy to report perplexity as well.
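The relationship between the three metrics discussed here can be made concrete with a small sketch. The probabilities and token ids below are made-up toy values, and the helper names are hypothetical; the only facts relied on are that log-perplexity is the average cross-entropy and that perplexity is its exponential.

```python
import math

def log_perplexity(true_token_probs):
    """Average cross-entropy: mean negative log-probability of the true tokens."""
    return -sum(math.log(p) for p in true_token_probs) / len(true_token_probs)

def perplexity(true_token_probs):
    """Perplexity is the exponential of the log-perplexity."""
    return math.exp(log_perplexity(true_token_probs))

def top1_accuracy(predicted_ids, true_ids):
    """Fraction of positions where the most likely token equals the true token."""
    hits = sum(1 for p, t in zip(predicted_ids, true_ids) if p == t)
    return hits / len(true_ids)

# Toy probabilities that two models assign to the same four ground-truth tokens.
model_a = [0.5, 0.25, 0.5, 0.25]
model_b = [0.4, 0.2, 0.3, 0.1]

# exp is monotonic, so both metrics rank the two models identically.
same_ranking = (log_perplexity(model_a) < log_perplexity(model_b)) == (
    perplexity(model_a) < perplexity(model_b)
)
print(log_perplexity(model_a), perplexity(model_a), same_ranking)
print(top1_accuracy([3, 7, 7, 1], [3, 7, 2, 1]))
```

Top-1 accuracy, by contrast, depends only on the argmax at each position, so it is not a function of log-perplexity and can rank models differently.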

[1] Liang et al., Holistic Evaluation of Language Models, TMLR 2023.

[2] Hu et al., Can Perplexity Reflect Large Language Model's Ability in Long Text Understanding?, Tiny Papers, ICLR 2024.

Regarding the model architecture, the chosen vocabulary size seems particularly large (256000) compared to standards in the literature for (a priori) monolingual English datasets (e.g. GPTNeo: 50k; Llama2: 32k; Llama3, which is trained on multilingual data: 150k). The choice of the tokenizer in this case could be important, as it can deeply affect the nature of the modeled distribution $\mathcal{D}$, and thus of the tradeoff offered by $\omega$.

As mentioned in line 338, we utilized the SentencePiece tokenizer with vocabulary size 256K as per [1]. A similar vocabulary size was also utilized in [2, 3]. We believe that conducting a detailed study of how tokenizer and vocabulary size affect the distillation performance and various design choices is an interesting topic for future research.

[1] Du et al., GLaM: Efficient Scaling of Language Models with Mixture-of-Experts, https://arxiv.org/abs/2112.06905

[2] Chowdhery et al., PaLM: Scaling Language Modeling with Pathways, https://arxiv.org/abs/2204.02311

[3] Gemma Team, Gemma: Open Models Based on Gemini Research and Technology, https://arxiv.org/abs/2403.08295

Regarding the optimization, my opinion is that Adam (or AdamW) would have been a more neutral choice than Adafactor, although it has been used for several models as well. The authors also do not discuss the impact of learning rate when using the KD loss, which is understandable given the cost of running such a hyperparameter search at scale.

Thank you for the comment. We agree with the reviewer that running a detailed hyperparameter search covering different optimizers, learning rates, and learning rate schedules is quite expensive for language model pre-training. Given this, we simply chose to work with Adafactor, as per [1, 2]. A detailed comparative analysis of the effect of these choices on final performance is out of the scope of this work.

[1] Du et al., GLaM: Efficient Scaling of Language Models with Mixture-of-Experts, https://arxiv.org/abs/2112.06905

[2] Chowdhery et al., PaLM: Scaling Language Modeling with Pathways, https://arxiv.org/abs/2204.02311

Comment

Conversely, the paper lacks empirical experiments that would help validate the introduced assumptions, and help quantify some of the factors of the bounds. For instance, empirical estimations based on existing language models could help refine the comments on the dependency of $C$ and $V_t$ on $T$ and the robustness of models. This could for instance motivate the choice of the size of the SLM during training.

Thank you for this suggestion. Please note that our bounds in Theorem 3.5 hold for all possible values of $C$ and $\{V_t\}$. As explained in Remark 3.6, if the underlying setup (as defined by the function class $\Theta$ and data distribution $\mathcal{D}$) happens to have a favorable dependence of $C$ and $\{V_t\}$ on $T$, then Theorem 3.5 provides tighter bounds than Theorem 3.3. On the other hand, if the underlying setup leads to slower decay of $C$ and $\{V_t\}$ as a function of $T$, then the bound provided by Theorem 3.3 would be more useful in characterizing the language modeling performance.

Regarding the comment about leveraging empirical estimates of $C$ and $\{V_t\}$ for SLM teacher selection, Assumption 3.4 primarily deals with the underlying data distribution and the stability/robustness of the student function class for a given teacher model.

That said, based on the reviewer's suggestion, we estimated the quantities $\xi_t$ in detail over the validation set. The new analysis is in Appendix L of the revision. To compute expectations of the form $\mathbb{E}_{\mathbf{z}\sim \mathcal{D}} \left[\ell^{\omega}(\mathbf{z}; \mathbf{\theta}) \mid \mathbf{z}_{\leq t} = \mathbf{x}_{\leq t}\right]$ that comprise the definition of $\xi_t$, we sample $n_{\rm com} = 64$ completions of the prefix $\mathbf{x}_{\leq t}$ from an oracle 8.6B Baseline model and compute the average loss for a 1.5B LM. We perform this estimation for about 200K prefixes of lengths $t = 1, 5, 10, 100$ from a set held out from training. We examined the distribution of $\xi_t$ at $t \in \{1, 5, 10, 100\}$ for sequences with different total lengths $T \in \{64, 128, 256\}$.

Figure 4 (in Appendix L) clearly shows that the $\xi_t$ become better behaved and increasingly concentrate around $0$ as $t$ increases. Below, we see that the mean $|\widehat\xi_t|$ for $t = 1, 5, 10, 100$ (over the validation set) decreases as the sequence length increases over $T = 64, 128, 256$. Further, the mean of $|\widehat\xi_t|$ decreases rapidly with $t$ (please see Appendix L in the revision for more details). This suggests that $C_t, V_t$ are likely to decrease rapidly with $t$, validating our approach to obtain generalization bounds in Theorem 3.5 through Assumption 3.4.

| $T$ | $t=1$ | $t=5$ | $t=10$ | $t=100$ |
|---|---|---|---|---|
| 64 | 0.34 | 0.135 | 0.108 | n/a |
| 128 | 0.26 | 0.102 | 0.076 | 0.042 |
| 256 | 0.187 | 0.085 | 0.062 | 0.033 |
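The sampling procedure described in this response (draw completions of a prefix from an oracle, then average the loss of the scored model) can be sketched on a toy problem. Everything below is an illustrative stand-in: the "oracle" and the scored model are i.i.d. binary token models rather than 8.6B and 1.5B LMs, the loss is plain negative log-likelihood rather than $\ell^{\omega}$, and the function names are hypothetical.

```python
import math
import random

def sequence_loss(tokens, scored_p1):
    """Average negative log-likelihood of a binary sequence under an i.i.d.
    model that assigns probability scored_p1 to token 1."""
    total = 0.0
    for tok in tokens:
        p = scored_p1 if tok == 1 else 1.0 - scored_p1
        total -= math.log(p)
    return total / len(tokens)

def conditional_expected_loss(prefix, total_len, oracle_p1, scored_p1,
                              n_com=64, rng=None):
    """Monte Carlo estimate of E[loss(z) | z_{<=t} = prefix].

    Completions of the prefix are sampled from the toy oracle (token 1 with
    probability oracle_p1), and the scored model's loss is averaged over them.
    """
    rng = rng or random.Random(0)
    losses = []
    for _ in range(n_com):
        completion = [1 if rng.random() < oracle_p1 else 0
                      for _ in range(total_len - len(prefix))]
        losses.append(sequence_loss(list(prefix) + completion, scored_p1))
    return sum(losses) / n_com

# Estimate for a short prefix of a length-8 sequence, using 64 completions.
estimate = conditional_expected_loss([1, 0], total_len=8, oracle_p1=0.6,
                                     scored_p1=0.7, n_com=64,
                                     rng=random.Random(1))
print(estimate)
```

When the prefix already covers the whole sequence there is nothing left to sample, so the estimate collapses to the exact loss of that sequence; longer prefixes leave less randomness in the estimate.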

Experimentation at smaller scale could be done to provide intuition on the benefits obtained when $\omega$ varies, and help illustrate how the mentioned tradeoff can be found.

Thank you for your suggestions. During the initial phase of our study, we had indeed conducted a hyperparameter search to identify a good set of knowledge distillation parameters: distillation loss weight $\omega = 0.667$ and temperature $\rho = 0.25$ (cf. Lines 343 and 344) for distillation from the 1.5B teacher to the 2.8B student. However, given the resource-intensive nature of pre-training experiments, since then we have focused most of our resources on evaluating the SALT framework along with the key ablation studies for various SALT-specific design choices (see Appendix H). Please also see our response to the next question.

More generally, although the authors define several values that should be affected by KD in the theoretical part (e.g. $\xi$), they do not explore the effect of KD on these metrics in the experimental part.

Please note that our theoretical analysis serves as a motivation and justification for selective knowledge transfer from small LM teachers to LLM students during pre-training. We thoroughly explore and validate the utility of such a selective knowledge transfer in our empirical study.

As for studying the effect of KD on $\xi$, which in turn affects the generalization bound, we have provided an analytical characterization in Appendix C. Thus, we believe that empirical verification of such impacts, while nice to have, is not strictly needed. Furthermore, devoting space and discussion to such empirical results might distract the reader from the main focus of this work: showcasing comprehensive improvement of LLM pre-training quality and efficiency via the SALT framework.

Comment

We thank the reviewer for their time and valuable feedback. We are happy that the reviewer found our study of knowledge distillation for language modeling novel and interesting, and appreciated both our theoretical bounds as well as experimental results.

Before we provide a point-wise response to the reviewer's main concerns/questions below, we would like to bring the reviewer's attention to new experiments we conducted based on other reviewers' suggestions that further validate the utility of the proposed SALT framework (see Appendix K in the revised submission or our responses to other reviewers).

The theoretical part is purely based on causal language modeling, while the models are trained with the UL2 procedure, which includes prefix language modeling (which is a variation of causal language modeling), and most importantly the span corruption task, which is not causal. Although it should be possible to extend the theoretical work in that direction, the experiments seem not to accurately reflect the theoretical results.

Thank you for this question. We primarily focused our theoretical analysis around causal language modeling for the ease of notation and clarity of exposition. The reviewer is right in noting that our theoretical results extend to other language modeling tasks such as prefix language modeling and span corruption in a straightforward manner.

For instance, for prefix language modeling, one can assume each input sequence to take the form $\mathbf{x} = [\mathbf{p}, \mathbf{s}] = [p_1, \cdots, p_m, s_1,\ldots, s_T]$, where $\mathbf{p}$ and $\mathbf{s}$ denote the prefix and the suffix/target parts of the input sequence, respectively. Accordingly, the language model parameterized by $\mathbf{\theta}$ assigns the following likelihood to the $T$-token long suffix:

$$P_{\mathbf{\theta}}(\mathbf{s} \mid \mathbf{p}) = P_{\mathbf{\theta}}(s_1 \mid \mathbf{p}) \, P_{\mathbf{\theta}}(s_2 \mid \mathbf{p}, s_1)\cdots P_{\mathbf{\theta}}(s_T \mid \mathbf{p}, s_1,\ldots, s_{T-1}).$$

Now, the sequence level loss (analogue of Eq. (1)) in our submission takes the form:

$$\ell(\mathbf{x}, \mathbf{\theta}) = \frac{1}{T} \sum_{t \in [T]} - \log P_{\mathbf{\theta}}(s_t \mid \mathbf{p}, \mathbf{s}_{\leq t-1}).$$

One can similarly define the knowledge distillation loss (analogue of Eq. (2)).

As for the span corruption task, given an input sequence $\mathbf{x}$ we can define its span-corrupted version as $\tilde{\mathbf{x}} = [\mathbf{c}, \mathbf{u}]$, where $\mathbf{c}$ denotes the corrupted version of $\mathbf{x}$ and $\mathbf{u} = [u_1,\ldots, u_T]$ represents the target tokens corresponding to the masked spans. Now, the loss would take the form

$$\ell(\mathbf{x}, \mathbf{\theta}) = \frac{1}{T} \sum_{t \in [T]} - \log P_{\mathbf{\theta}}(u_t \mid \mathbf{c}, \mathbf{u}_{\leq t-1}).$$

Once we have loss terms defined, we can provide excess risk bounds analogous to Theorem 3.3 & 3.5 for both prefix language modeling and span corruption tasks. We are happy to provide additional details if the reviewer deems them necessary.
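The sequence-level losses written above translate directly into code. The following is a minimal sketch that takes the per-step conditional distributions as given (they already condition on the prefix $\mathbf{p}$ or the corrupted input $\mathbf{c}$, plus the earlier targets). The convex combination with weight `omega` is an assumed, commonly used form of a distillation objective, since Eq. (2) is not reproduced in this thread; the function names and toy distributions are hypothetical, and temperature scaling of the teacher is omitted.

```python
import math

def token_level_loss(step_probs, targets):
    """(1/T) * sum_t -log P(target_t | conditioning, targets_{<t}).

    step_probs[t] is the model's distribution over the vocabulary at step t.
    """
    total = -sum(math.log(dist[tok]) for dist, tok in zip(step_probs, targets))
    return total / len(targets)

def distillation_loss(student_probs, teacher_probs):
    """(1/T) * sum_t cross-entropy between teacher and student distributions."""
    total = 0.0
    for s_dist, t_dist in zip(student_probs, teacher_probs):
        total -= sum(t * math.log(s) for s, t in zip(s_dist, t_dist))
    return total / len(student_probs)

def mixed_loss(student_probs, teacher_probs, targets, omega):
    """Assumed convex combination of ground-truth loss and distillation loss."""
    return ((1.0 - omega) * token_level_loss(student_probs, targets)
            + omega * distillation_loss(student_probs, teacher_probs))

# Toy example: T = 2 target tokens, vocabulary of size 3.
student = [[0.7, 0.2, 0.1], [0.1, 0.6, 0.3]]
teacher = [[0.6, 0.3, 0.1], [0.2, 0.5, 0.3]]
targets = [0, 1]

print(token_level_loss(student, targets))
print(mixed_loss(student, teacher, targets, omega=0.667))
```

Setting `omega=0` recovers the standard loss, and a teacher that puts all its mass on the ground-truth tokens makes the distillation term coincide with it.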

Overall, we believe that our statistical analysis is general enough to cover various language modeling tasks and would provide similar conclusions for those tasks. Thus, our deliberate decision to present results for causal language modeling, made to keep the notation simpler and the messaging cleaner, does not reflect a weakness of our theoretical analysis.

Comment

I am truly grateful for the time taken by the authors in addressing my comments.

  • Theoretical part: I am glad that the extension of the theoretical results is straightforward. Nevertheless, my point is that using the causal LM framework to justify experiments that are performed with a different objective hurts the overall coherence of the article.
  • Experiments to validate theoretical claims: Thank you for adding this part in the Appendix. I believe that this result strengthens the justification of the method. It is regrettable that you did not keep results from your initial study, as a detailed discussion about how you chose $\omega$ and the stability of this parameter choice would have been interesting. Moreover, I understand that the bounds hold for all possible values of $C$ and $V_t$, but my question regards potential empirical estimates of these variables that would illustrate the decay explanation you propose.
  • LM metrics: I agree that top-1 accuracy is also indicative of performance, I simply argue that since this choice is not the most natural in the language modeling literature, it should at least be discussed in the paper. Regarding displaying the log-perplexity, I am not convinced by the authors' arguments: I agree that log-perplexity is cross-entropy loss, which then questions the choice of phrasing it as "log-perplexity" instead of e.g. "Validation cross-entropy".
  • Hyperparameter choice: I still have concerns about hyperparameter choice. Although it is not unreasonable to use hyperparameters from the GLaM article, it can be wondered if this work is still in line with state-of-the-art architectural designs in the language modeling field (which has evolved a lot since 2021), and it can be argued safely that monolingual models are now commonly trained using smaller vocabularies. This choice can also have had an impact on the models themselves, as between 250M and 500M parameters may have been dedicated to the embedding and unembedding layers (depending on whether the authors used weight tying). This does not disqualify the results of this work at all, but it raises questions about the reproducibility of these results for smaller tokenizer vocabularies (which may have a noticeable impact on the nature of the modeled distributions).
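The embedding-size remark in the last bullet can be checked with a quick calculation. The vocabulary sizes come from the discussion; the model width `d_model = 1024` is a hypothetical value chosen only for illustration, since the actual width is not stated in this thread.

```python
def embedding_parameters(vocab_size, d_model, tied):
    """Parameters in the input embedding plus the output (unembedding) layer.

    With weight tying the two layers share a single vocab_size x d_model table.
    """
    table = vocab_size * d_model
    return table if tied else 2 * table

d_model = 1024  # hypothetical model width, not taken from the paper

for vocab_size in (32_000, 256_000):
    tied = embedding_parameters(vocab_size, d_model, tied=True)
    untied = embedding_parameters(vocab_size, d_model, tied=False)
    print(f"vocab {vocab_size}: tied {tied / 1e6:.1f}M, untied {untied / 1e6:.1f}M")
```

Under this assumed width, a 256K vocabulary costs roughly 262M parameters with tying and 524M without, close to the 250M to 500M range mentioned above, and eight times the cost of a 32K vocabulary.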

In summary, I appreciate the new study provided in Appendix L, but I still find the overall connection between the theoretical part and the empirical part a bit artificial. I believe that using theoretical analyses as a justification and then conducting empirical studies that do not confirm the actual validity and relevance of this specific justification in detail is detrimental to the quality of this paper. As I argued in my initial review, a purely intuitive motivation ("obtain a relatively strong language model quickly through reverse distillation up to the point where the student is slightly below teacher level, and then proceed with regular training") would have roughly led to the same experimental design, which in my opinion points out a lack of connection between theoretical and empirical sections.

Comment

In a way, the SALT approach could be purely justified from an intuitive standpoint: obtain a relatively strong language model quickly through reverse distillation up to the point where the student is slightly below teacher level, and then proceed with regular training. How precisely would you justify the usefulness of the theoretical part in the design of the experimental part?

Thank you for the question. We would like to begin by highlighting that our theoretical analysis provides the first rigorous treatment of knowledge distillation for language modeling: it leads to novel risk bounds that also have implications for the generalization of standard pre-training as a special case (see Lines 66-69). Importantly, as discussed in Section 3.3, our analysis further provides a rigorous justification for how a seemingly weaker, smaller teacher LM can aid the training of a larger student LM via selective knowledge transfer.

When it comes to how specifically such selective knowledge transfer must be performed, we rely on the implicit bias of modern networks, which typically tend to learn the easier supervision before focusing on complex supervision (Lines 273-289). This motivated our two-stage training method, namely SALT (cf. Algorithm 1). Even once one has agreed on such a two-stage training method, there are specific hyperparameters to consider: when should the first stage end, and how should one transition between the two stages (e.g., step transition or linear decay)? We conducted a systematic exploration of these design choices for the SALT framework and demonstrated significant improvements in both quality and efficiency for LLM pre-training.
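
The two transition styles named above can be sketched as a schedule for the weight placed on the distillation loss. This is an illustrative sketch only: the function name `kd_weight`, its parameters, and the exact shape of the linear decay (from the initial weight at step 0 down to zero at the transition step) are assumptions, not the schedule defined in Algorithm 1.

```python
def kd_weight(step: int, transition_step: int, mode: str = "step",
              initial_weight: float = 1.0) -> float:
    """Weight on the distillation loss at a training step (hypothetical helper).

    mode="step":   constant weight during the first stage, zero afterwards.
    mode="linear": weight decays linearly from initial_weight at step 0
                   to zero at transition_step (an assumed decay shape).
    """
    if step < 0 or transition_step <= 0:
        raise ValueError("step must be >= 0 and transition_step must be > 0")
    if step >= transition_step:
        # Second stage: standard training on ground-truth tokens only.
        return 0.0
    if mode == "step":
        return initial_weight
    if mode == "linear":
        return initial_weight * (1.0 - step / transition_step)
    raise ValueError(f"unknown mode: {mode}")


# Example with the 36K-step early stage mentioned elsewhere in this thread.
for s in (0, 18_000, 36_000, 208_000):
    print(s, kd_weight(s, 36_000, "step"), kd_weight(s, 36_000, "linear"))
```

With a step transition the student sees the full distillation signal until the first stage ends, whereas a linear decay hands control over to the ground-truth loss gradually.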

Thus, as claimed in Lines 91-103, we consider both our theoretical analysis and our empirical study to be key contributions with standalone merit, while the theoretical analysis also provides a strong formal foundation and qualitative guidance for the empirical work.

Similarly, the SALT-DS approach can also be framed as a (relatively) basic data-filtering technique. Do you think it would be relevant to invoke the curriculum learning literature in that case?

Thank you for this comment and for bringing up an interesting point of view: interpreting the value of data selection in $\text{SALT}_{\text{DS}}$ through the lens of curriculum learning. However, we would like to point out that such a point of view would be useful not only for $\text{SALT}_{\text{DS}}$ but for any data selection procedure for (language) model training in general. Obtaining more precise results for $\text{SALT}_{\text{DS}}$ with our proposed data selection methods, or designing novel data selection methods through the lens of curriculum learning, is an interesting topic for future research. If the reviewer believes that any existing work in the literature already provides some results in this direction, we would be happy to add a discussion of it in a revised version of our submission.

Comment

In my opinion, another weakness of the paper lies in its lack of transparency and precision on the computational overhead of the proposed method. The authors briefly discuss this overhead, arguing that "As a rule of thumb, a forward pass constitutes 1/4th cost of a training step". This claim lacks support in my opinion, as it depends on implementation and hardware among other things, and as previous works also mention 1/3rd as an estimate (cf. https://arxiv.org/pdf/2203.15556). Moreover, their study does not take the cost of pretraining the SLM into account for the overall training time, or the memory cost of having two models in memory simultaneously. Finally, given the 12% overhead mentioned by the authors and the plot from Figure 2, it is not clear that their approach performs better for pure in-domain language modeling performance when using a FLOPS-related metric on the x-axis. These points are not prohibitive for the method, but should be discussed in the paper.

We thank the reviewer for raising this very important point.

Regarding the relative cost of a forward versus a backward pass, please note that we rely on gradient checkpointing (rematerialization) to reduce the memory requirement during LM pre-training. Thus, in each backward pass, we have to recompute various layer activations, which further increases the cost of the backward pass while enabling a smaller memory footprint and a larger batch size on TPU-v5e chips. In contrast, the works mentioned by the reviewer do not take these memory reduction techniques into consideration when computing the cost of the backward pass. We will add a note to clarify this in the revised paper.
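
The difference between the 1/3 and 1/4 estimates can be illustrated with a back-of-the-envelope calculation. The sketch below assumes the common rule of thumb that a backward pass costs about twice a forward pass, and that rematerialization recomputes one full forward pass; both are simplifying assumptions rather than measurements from the paper, and the function name is hypothetical.

```python
def forward_fraction_of_step(backward_to_forward: float = 2.0,
                             rematerialize: bool = False) -> float:
    """Fraction of one training step spent on a single forward pass.

    Costs are in units of one forward pass. With rematerialization we
    assume one extra full forward pass is recomputed during the backward
    pass (a simplification; real checkpointing policies recompute less).
    """
    forward = 1.0
    backward = backward_to_forward
    recompute = forward if rematerialize else 0.0
    return forward / (forward + backward + recompute)


print(forward_fraction_of_step(rematerialize=False))  # one third of a step
print(forward_fraction_of_step(rematerialize=True))   # one quarter of a step
```

Under these assumptions the two estimates are consistent with each other and differ only in whether activation recomputation is counted.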

Regarding including the training cost of the SLM: as discussed in the introduction (Lines 52-53), such SLMs might already be available for their own use cases without being specifically trained for the SALT framework. Moreover, a single SLM (even if specifically trained for the SALT framework) can be used to train multiple large LMs (potentially of varying sizes). Thus, it might not be appropriate to add its training cost to the training cost of an LLM trained via the SALT framework. This is in line with the fact that when distilling from a large model to a small model, we typically don't add the cost of training the large model to the development cost of the small model.

Again, we will enhance the discussion of the training cost incurred by the SALT framework to clarify these points in the revised version.

I found it a bit confusing to distinguish between V_t and V_N. Although they model similar aspects in both sections, it would probably be safer to make both more distinguishable.

Thank you for the suggestion. If the reviewer agrees, we would be happy to change $V_N$ (in Theorem 3.3) to ${\rm Var}_N$ in the revised version to make the distinction from $V_t$ (in Theorem 3.5) clearer.

In Figure 2, do you display results for SALT or for SALT-DS? Why didn't you display both in this graph?

Thank you for the questions. In Figure 2, we only display the results for SALT. Please note that unlike the methods shown in Figure 2 (i.e., SLM, Baseline, RKD, and SALT), $\text{SALT}_{\text{DS}}$ trains on a different training data distribution during the first phase: the distribution defined by the selected training sequences. Since the selected training sequences are by design "harder", the training performance of $\text{SALT}_{\text{DS}}$ in the early phase would be lower and not comparable to that of the other methods (due to distribution mismatch). We therefore decided not to include $\text{SALT}_{\text{DS}}$ in Figure 2.

Interestingly, the SALT models seem to not drastically outperform the vanilla baseline on in-domain language modeling evaluation (Figure 2, Table 1), but have a noticeable edge in few-shot evaluation and fine-tuning. It could be estimated that they are only 5-10% more data-efficient for pre-training, but are 30% more efficient for evaluation, and in other words that a KD model that reaches a given perplexity gives better evaluation results than a vanilla model at the same perplexity level. Do you have an explanation for this?

Thank you for making this important observation. Indeed, SALT models show relatively smaller improvements over Baseline in terms of top-1 accuracy and perplexity but exhibit much larger few-shot and post-SFT downstream improvements. Although we do not have a precise explanation for this phenomenon, we believe that two factors might be at play here:

  1. Regularization enabled by KD in the early phase might produce more robust final models, leading to better performance on relatively out-of-domain few-shot tasks and a better initialization for SFT on downstream tasks.
  2. The improved top-1 accuracy realized by the SALT models can lead to disproportionately large improvements in discontinuous evaluation metrics such as accuracy and exact match.

Comment

Hyperparameter choice: I still have concerns about hyperparameter choice. Although it is not unreasonable to use hyperparameters from the GLaM article, it can be wondered if this work is still in line with state-of-the-art architectural designs in the language modeling field (which has evolved a lot since 2021), and it can be argued safely that monolingual models are now commonly trained using smaller vocabularies. This choice can also have had an impact on the models themselves, as between 250M and 500M parameters may have been dedicated to the embedding and unembedding layers (depending on whether the authors used weight tying). This does not disqualify the results of this work at all, but it raises questions about the reproducibility of these results for smaller tokenizer vocabularies (which may have a noticeable impact on the nature of the modeled distributions).

Thank you for underlining this important point. To clarify, we do utilize weight tying in our paper (we have updated Appendix E to reflect this).

We would like to highlight that, as mentioned in our previous response, our hyperparameter choices regarding vocabulary and/or optimizer persist even in more recent publications [1, 2]. Thus, we do consider them to be a reasonable choice for our study.

More generally, as we also mentioned in our previous response, we agree with the reviewer that a more detailed study of the impact of these choices on SALT, and even on the standard knowledge distillation setup (large teacher to small student), is an interesting topic for future research.

[1] Chowdhery et al., PaLM: Scaling Language Modeling with Pathways, https://arxiv.org/abs/2204.02311

[2] Gemma Team, Gemma: Open Models Based on Gemini Research and Technology, https://arxiv.org/abs/2403.08295

In summary, I appreciate the new study provided in Appendix L, but I still find the overall connection between the theoretical part and the empirical part a bit artificial. I believe that using theoretical analyses as a justification and then proceeding to conduct empirical studies that do not confirm the actual validity and relevance of this specific justification in detail is detrimental to the quality of this paper. As I argued in my initial review, a purely intuitive motivation ("obtain a relatively strong language model quickly through reverse distillation up to the point where the student is slightly below teacher level, and then proceed with regular training") would have roughly led to the same experimental design, which in my opinion points to a lack of connection between the theoretical and empirical sections.

Thank you for your constructive feedback. We would like to reiterate that we have formalized the connection between our analysis and empirical setup in Section 3.3. We have also highlighted the importance of both our theoretical analysis and experimental study as two major contributions that have individual merit, while the theoretical analysis provides a rigorous foundation for the SALT framework.

In general, there are always many plausible intuitions (often post-hoc explanations) for many practical phenomena in deep learning. We believe that it is always good to ground those in a rigorous analytical framework even if it does not cover all aspects of a practical phenomenon (such as transition method and precise transition point between two stages in our case). As with any research work, we acknowledge that our theoretical analysis has scope for improvement and future work can build on it to further refine and strengthen the analysis.

We again thank the reviewer for raising multiple valid points and constructively engaging with us to improve our submission. We request the reviewer to reassess their evaluation of our work in light of our new results and responses to their questions.

Comment

We thank the reviewer for their prompt response. Below we express our thoughts on the concerns/comments highlighted by the reviewer in their response.

Theoretical part: I am glad that the extension of the theoretical results is straightforward. Nevertheless, my point is that using the causal LM framework to justify experiments that are performed with a different objective hurts the overall coherence of the article.

As illustrated in our previous response, our analysis readily extends to different language modeling tasks with similar conclusions; focusing on causal language modeling was a deliberate choice to make the exposition simpler and the messaging clearer. We will add a note to highlight this point in the final version.

Furthermore, as we mentioned in Appendix E (Experimental setup details) of our initial submission, causal language modeling comprises 60% of the overall training objective (see Line 1484 in the current version), making causal language modeling the natural choice for the analysis to focus on.

Experiments to validate theoretical claims: Thank you for adding this part in the Appendix. I believe that this result strengthens the justification of the method. It is regrettable that you did not keep results from your initial study, as a detailed discussion about how you chose ω and the stability of this parameter choice would have been interesting. Moreover, I understand that the bounds hold for all possible values of C and V_t, but my question regards potential empirical estimates of these variables that would illustrate the decay explanation you propose.

Thank you for recognizing that including our empirical estimates of the distribution of $\{\xi_t\}$ in Appendix L strengthens our submission. We would like to note that we have expanded Appendix L with such estimates for the 2.8B LM as well. We again see the same behaviour: the distribution of $\{\xi_t\}$ becomes more concentrated around 0 as we increase $t$, thereby highlighting the decay of $C$ (max deviation in $\{\xi_t\}$) and $V_t$ (variance of $\{\xi_t\}$) in a real LM trained on a real-world dataset.

Regarding the question about the stability of our parameter choices (e.g., for $\omega$), please note that the newly added results for training an 8.6B LM via the SALT framework in Appendix K already demonstrate the stability of these choices. Given the limited timeframe of the discussion phase, we simply reused the hyperparameters (e.g., $\omega$, $\rho$, etc.) from our initial submission, where they were used to train a 2.8B model. Despite this, as evident in Appendix K, we do see significant gains in both final model quality and training efficiency for 8.6B model training via our proposed SALT framework. This highlights the robustness of the SALT framework to various hyperparameter choices.

LM metrics: I agree that top-1 accuracy is also indicative of performance, I simply argue that since this choice is not the most natural in the language modeling literature, it should at least be discussed in the paper. Regarding displaying the log-perplexity, I am not convinced by the authors' arguments: I agree that log-perplexity is cross-entropy loss, which then questions the choice of phrasing it as "log-perplexity" instead of e.g. "Validation cross-entropy".

Thank you for the suggestion. We will add a note in the final version to highlight that top-1 accuracy has proven to be quite informative in our study.

Regarding our choice of the phrase log-perplexity, we found it very natural to refer to the logarithm of perplexity as log-perplexity (à la the logarithm of likelihood being commonly referred to as log-likelihood). That said, if the reviewer prefers, we are happy to rename it Validation cross-entropy in the revised version. We will also clarify that it corresponds to the cross-entropy loss with respect to the ground-truth next-token distribution (as in Eq. (1)) and not with respect to the teacher's per-token distribution (as in Eq. (2)).
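
The equivalence under discussion can be checked numerically. The sketch below uses made-up per-token probabilities and hypothetical helper names; it only illustrates that perplexity is the exponential of the mean negative log-likelihood of the ground-truth tokens, so that log-perplexity and (validation) cross-entropy are the same number.

```python
import math


def cross_entropy(token_probs):
    """Mean negative log-probability assigned to the ground-truth tokens."""
    return -sum(math.log(p) for p in token_probs) / len(token_probs)


def perplexity(token_probs):
    """Perplexity is the exponential of the mean cross-entropy."""
    return math.exp(cross_entropy(token_probs))


# Illustrative probabilities a model assigns to three ground-truth tokens.
probs = [0.5, 0.25, 0.125]
print(cross_entropy(probs))         # cross-entropy in nats
print(math.log(perplexity(probs)))  # log-perplexity: the same value
```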

Review
6

This paper proposes an efficient two-stage pre-training method called SALT. In the early stage of LLM pre-training, SALT utilizes an SLM (Small Language Model) as the teacher model for knowledge distillation, while also using the SLM to select challenging yet learnable data to ensure effective knowledge transfer. In the second stage, the training switches to standard self-supervised learning, thereby significantly improving the quality and efficiency of LLM training without increasing pre-training costs. Experimental results demonstrate that the SALT method outperforms standard LLM training on several tasks, while also achieving notable training time savings.

Strengths

  1. The paper leverages a small language model (SLM) as the teacher model for the knowledge distillation process. Instead of using all data for distillation, it selectively transfers knowledge by choosing challenging yet learnable data points, reducing the risk of transferring erroneous knowledge and enhancing the effectiveness and specificity of the knowledge transfer.
  2. The 2.8B parameter LLM trained with SALT achieves better performance than a 2.8B LLM trained with standard pre-training on various popular few-shot benchmarks, while requiring only about 70% of the training steps and saving approximately 28% in wall-clock time. SALT models also consistently show substantial performance gains in multiple domains after SFT.

Weaknesses

  1. In Section 4, this method selects samples with high information content and those that are learnable by the SLM through Equation 11, focusing on top-m score sequences. This introduces some uncertainty into the model’s learning process, especially when using an early checkpoint of the SLM, which may further amplify this uncertainty. This can lead to the model learning unreliable information in the early stages, thereby causing a negative impact. As shown in Table 1, SALT_{DS} does not yield performance improvement over the baseline in the early stage, and even exhibits negative effects, indicating that excessive knowledge distillation on uncertain samples may introduce noise rather than effective knowledge, thus affecting model performance.

  2. In Table 1, at the early stage, both SALT and SALT_{DS} show limited improvement over the baseline, and even experience noticeable drops in some metrics.

Questions

  1. For Table 1, from the experimental results, it appears that in the Early phase, SALT and SALT_{DS} did not improve performance compared to the Baseline, and even showed a slight decrease in some metrics. However, in the Final phase results, SALT and SALT_{DS} outperform the Baseline. Where do you think this benefit comes from? Additionally, although SALT and SALT_{DS} show some improvement, they do not significantly outperform the Baseline. Have you conducted multiple runs to confirm the stability of these results?
  2. Do you plan to test this approach on a larger model to further validate its feasibility?

Comment

Do you plan to test this approach on a larger model to further validate its feasibility?

Thank you for the question. Given that this was a common question raised by multiple reviewers, we have performed additional experiments to showcase the utility of our proposed SALT framework for larger models. In particular, we have pre-trained an 8.6B-parameter LM on the Pile dataset via the SALT framework while leveraging a 2.8B small teacher LM. We have provided the experimental details along with comprehensive few-shot and downstream evaluations in Appendix K of the revised submission (blue colored text). Below we present the domain-wise few-shot performance and post supervised fine-tuning (SFT) performance of the 8.6B LM trained via the SALT framework, comparing it with the natural baseline, i.e., an 8.6B LM trained via standard self-supervised training. (Note that boldfaced and italicized numbers represent the best and the second best results, respectively, in the corresponding category.)

Domain-wise few-shot performance of 8.6B LLM pre-trained via 2.8B small LM teacher

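| Domain | # Tasks | SLM | Baseline @100% steps | SALT @70% steps | SALT @100% steps | $\text{SALT}_{\text{DS}}$ @70% steps | $\text{SALT}_{\text{DS}}$ @100% steps |
|---|---|---|---|---|---|---|---|
| World Knowledge | 4 | 22.19 | 26.91 | 27.66 | 28.97 | 28.04 | 28.47 |
| Reading Comprehension | 4 | 53.00 | 56.40 | 56.83 | 57.42 | 56.10 | 57.48 |
| Commonsense Reasoning | 7 | 61.99 | 66.01 | 66.89 | 67.09 | 66.61 | 67.24 |
| LAMBADA | 1 | 36.20 | 58.70 | 65.50 | 64.80 | 54.30 | 55.00 |
| SuperGLUE | 8 | 65.53 | 69.69 | 69.19 | 70.38 | 71.06 | 71.26 |
| NLG | 3 | 4.60 | 5.40 | 5.97 | 5.97 | 5.23 | 5.30 |
| MBPP | 1 | 16.20 | 20.80 | 19.80 | 22.00 | 22.80 | 23.20 |
| Average | 28 | 47.32 | 51.73 | 52.24 | 52.96 | 52.29 | 52.81 |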

Post SFT results for 8.6B LLM pre-trained via 2.8B small LM teacher

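| Method | GSM8K Accuracy | XSUM Rouge-1 | XSUM Rouge-2 | XSUM Rouge-L | CNN/DailyMail Rouge-1 | CNN/DailyMail Rouge-2 | CNN/DailyMail Rouge-L | ANLI-R1 Accuracy | ANLI-R2 Accuracy | ANLI-R3 Accuracy |
|---|---|---|---|---|---|---|---|---|---|---|
| Baseline | 41.85 | 45.10 | 22.68 | 37.36 | 43.73 | 21.19 | 41.29 | 68.80 | 58.90 | 60.58 |
| SALT | 42.84 | 45.37 | 23.04 | 37.69 | 43.69 | 21.16 | 41.22 | 70.20 | 59.30 | 63.25 |
| $\text{SALT}_{\text{DS}}$ | 42.23 | 45.81 | 23.34 | 38.14 | 43.80 | 21.28 | 41.35 | 69.30 | 59.50 | 62.17 |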

Take away: The few-shot eval results exhibit performance and efficiency gains similar to those we observed for 2.8B model training in our initial submission: 1) at 70% of the training steps, SALT already performs better than or on par with the fully trained Baseline (@100% steps); and 2) at 100% of the training steps, SALT significantly outperforms Baseline. Furthermore, the LLMs trained via SALT (with and without data selection) exhibit strong gains in post-SFT performance across a wide range of downstream tasks.

Comment

However, crucially, SALT consistently achieves better next-token prediction accuracy than Baseline both at the end of the early stage (36K steps) and at the end of the overall training (208K steps). Furthermore, it achieves better log-perplexity than Baseline at the end (208K steps), after focusing only on the ground-truth next-token-based cross-entropy loss during the second stage of training (the same holds for $\text{SALT}_{\text{DS}}$ as well). Together, these observations explain the improved model quality via SALT across multiple tasks. Also, note that recent works [1, 2] have cautioned against relying solely on (log-)perplexity as the measure of an LM's quality.

If the reviewer deems it necessary, we are happy to add this explanation about various metrics appearing in Table 1 to the revised version of our submission.

[1] Liang et al., Holistic Evaluation of Language Models, TMLR 2023.

[2] Hu et al., Can Perplexity Reflect Large Language Model's Ability in Long Text Understanding?, Tiny Papers, ICLR 2024.

Additionally, although SALT and SALT_{DS} show some improvement, they do not significantly outperform the Baseline. Have you conducted multiple runs to confirm the stability of these results?

Thank you for the question. Given that we focus on LM pre-training, which is quite expensive in terms of compute, it is not feasible to run each experiment multiple times. That said, whenever we happened to run an experiment twice during our exploration, we observed that pre-training performance for a given method is quite stable and does not change much between the two runs. More importantly, these minor changes are much smaller than the gains we observe in final next-token prediction accuracy and log-perplexity in Table 1.

Here, it is worth noting that the seemingly small gains in next-token prediction accuracy and log-perplexity achieved by SALT and $\text{SALT}_{\text{DS}}$ do lead to significant improvements in few-shot and post-SFT performance (see Tables 2, 3, and 5).

Furthermore, we believe that our observations in Table 1 are broadly in line with modern scaling laws [1, 2], which suggest that as one increases the total compute or data for a given model, the absolute improvements in log-perplexity become much smaller.

[1] Kaplan et al., Scaling Laws for Neural Language Models, https://arxiv.org/abs/2001.08361

[2] Hoffmann et al., Training Compute-Optimal Large Language Models, NeurIPS 2022

Comment

Thank you for your time and effort in reviewing our submission. Below we provide point-wise responses to your insightful comments/questions.

In Section 4, this method selects samples with high information content and those that are learnable by the SLM through Equation 11, focusing on top-m score sequences. This introduces some uncertainty into the model’s learning process, especially when using an early checkpoint of the SLM, which may further amplify this uncertainty. This can lead to the model learning unreliable information in the early stages, thereby causing a negative impact. As shown in Table 1, SALT_{DS} does not yield performance improvement over the baseline in the early stage, and even exhibits negative effects, indicating that excessive knowledge distillation on uncertain samples may introduce noise rather than effective knowledge, thus affecting model performance.

We thank the reviewer for raising this important question.

Before clarifying why $\text{SALT}_{\text{DS}}$ in Table 1 seemingly degrades model performance in the early stage of training, we would like to note that our data selection method explicitly aims to prevent potentially noisy tokens from affecting the selection process. In particular, while assigning a score $S_{\mathbf{\zeta}, k}(\cdot)$ to a training sequence, we remove the contribution of those tokens where the small LM's top-$k$ predictions do not contain the ground-truth token (see Eq. (11) and Lines 297-299).
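
The top-$k$ filtering step described above can be sketched as follows. Eq. (11) is not reproduced in this thread, so the per-token quantity used here (the small LM's negative log-probability of the ground-truth token) and the normalization by sequence length are stand-in assumptions, and all function names are hypothetical; only the masking rule (drop tokens whose ground truth is outside the small LM's top-$k$ predictions) follows the description.

```python
import math


def topk_contains(probs, target, k):
    """True if `target` is among the k most probable tokens under `probs`."""
    topk = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]
    return target in topk


def masked_sequence_score(slm_probs, targets, k):
    """Sequence score with contributions from 'noisy' tokens removed.

    slm_probs: one next-token distribution (list of floats) per position,
               as predicted by the small LM.
    targets:   ground-truth token id per position.
    The per-token score (-log p of the ground truth) is an assumption
    standing in for the quantity used in the paper's Eq. (11).
    """
    total = 0.0
    for probs, target in zip(slm_probs, targets):
        if topk_contains(probs, target, k):
            total += -math.log(probs[target])
        # Otherwise the token contributes nothing to the score.
    return total / len(targets)


# Toy example: vocabulary of 3 tokens, sequence of 2 positions.
slm_probs = [[0.7, 0.2, 0.1], [0.1, 0.6, 0.3]]
targets = [0, 0]
print(masked_sequence_score(slm_probs, targets, k=1))  # second token is masked
print(masked_sequence_score(slm_probs, targets, k=3))  # nothing is masked
```

Sequences would then be ranked by this score and the highest-scoring ones kept, matching the top-$m$ selection the reviewer refers to.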

From Table 1, it appears that $\text{SALT}_{\text{DS}}$ hurts the model performance in the early phase. In particular, the next-token prediction accuracy of $\text{SALT}_{\text{DS}}$ on the held-out set of the Pile at the end of the first/early stage of training (56.47) is worse than that of Baseline (56.68). Recall that $\text{SALT}_{\text{DS}}$ utilizes selected data during the first stage of training, which by design has a data distribution different from the original distribution of the Pile training set. Since the held-out set of the Pile follows a distribution close to that of the original Pile training set, the lower performance of $\text{SALT}_{\text{DS}}$ on this held-out set is just an artifact of the mismatch between the training and test distributions. The same distributional mismatch explains the seemingly poorer performance of $\text{SALT}_{\text{DS}}$ compared to SALT/RKD at the end of the first stage.

Importantly, $\text{SALT}_{\text{DS}}$ gets to train on the original Pile set during the second stage of training, which eventually leads to better next-token prediction accuracy than Baseline at the end of pre-training (at 208K steps). This, along with the improved few-shot (Tables 2 and 5) and post-SFT (Table 3) downstream performance via $\text{SALT}_{\text{DS}}$, clearly establishes the utility of the method.

For Table 1, from the experimental results, it appears that in the Early phase, SALT and SALT_{DS} did not improve performance compared to the Baseline, and even showed a slight decrease in some metrics. However, in the Final phase results, SALT and SALT_{DS} outperform the Baseline. Where do you think this benefit comes from?

Please see our earlier response, which highlights the train vs. test distribution mismatch for $\text{SALT}_{\text{DS}}$ and explains why it might appear to under-perform Baseline at the end of the first/early stage.

As the reviewer noticed, SALT does have higher log-perplexity than Baseline at the end of the early stage. However, it is important to note that SALT (as well as $\text{SALT}_{\text{DS}}$) minimizes a combination of the ground-truth next-token-based cross-entropy loss (Eq. (1)) and the token-level KD loss (Eq. (2)) during the first stage of training. By design, the ground-truth next-token-based cross-entropy loss directly optimizes log-perplexity, as the former is an unbiased estimate of the latter. On the other hand, token-level KD with a small teacher LM forms a biased estimate of log-perplexity, which also manifests in the $\text{D}_{\text{TV}}(\cdot, \cdot)$ terms in Theorem 3.3 and Theorem 3.5. As a result, SALT has inferior log-perplexity in the early stage compared to Baseline, which focuses only on the ground-truth next-token-based cross-entropy loss, an unbiased estimate of the log-perplexity.
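
A toy, single-token illustration of the two loss terms may help. The sketch assumes the first-stage objective is a convex combination of the two losses with a mixing weight (called `kd_weight` here, a hypothetical name); the exact form and weighting are given by Eqs. (1) and (2) and Algorithm 1 of the paper, which are not reproduced in this thread.

```python
import math


def ce_loss(student_probs, target):
    """Cross-entropy against the ground-truth next token (cf. Eq. (1))."""
    return -math.log(student_probs[target])


def kd_loss(student_probs, teacher_probs):
    """Token-level KD: cross-entropy against the teacher distribution (cf. Eq. (2))."""
    return -sum(q * math.log(p)
                for q, p in zip(teacher_probs, student_probs) if q > 0)


def first_stage_loss(student_probs, teacher_probs, target, kd_weight):
    """Assumed convex combination of the two losses for one token."""
    return ((1.0 - kd_weight) * ce_loss(student_probs, target)
            + kd_weight * kd_loss(student_probs, teacher_probs))


student = [0.5, 0.25, 0.25]
target = 0

# A teacher that puts all mass on the ground truth reproduces the CE loss.
print(first_stage_loss(student, [1.0, 0.0, 0.0], target, kd_weight=0.3))

# An imperfect teacher pulls the objective away from the CE loss,
# which is the source of the bias discussed above.
print(first_stage_loss(student, [0.5, 0.5, 0.0], target, kd_weight=0.3))
print(ce_loss(student, target))
```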

Comment

Again, we thank the reviewer for their time and effort reviewing our submission. We are following up to see if you had a chance to review our response to your comments/concerns and if you have any further questions.

Review
3

This paper addresses the significant computational challenges in training large language models (LLMs) by proposing an approach that leverages smaller language models (SLMs) to improve both training efficiency and model quality. The authors introduce a two-stage training methodology where knowledge distillation from an SLM is used during the early phase of LLM training, followed by standard self-supervised training.

The work makes three main contributions: (1) A theoretical framework analyzing knowledge distillation in language modeling, providing novel risk bounds that explain how even a weaker teacher model can benefit a (stronger) student model; (2) The SALT methodology, which uses SLMs both for knowledge distillation and data selection; and (3) Empirical validation showing that SALT can reduce LLM training time while maintaining or improving performance across various benchmarks.

Strengths

The authors validate their approach by training a 2.8B parameter model using a 1.5B parameter teacher model on the Pile dataset. They demonstrate improvements in few-shot learning across multiple downstream tasks. The theoretical analysis provides insights into the bias-variance trade-off in knowledge distillation and justifies why selective knowledge transfer from smaller models can be beneficial. Further strengths:

  • The paper presents a novel and counter-intuitive approach of using smaller models to improve larger model training (though what counts as "large" requires further clarification).
  • The theoretical framework for analyzing knowledge distillation in language modeling is novel.
  • The proposed data selection mechanism using SLMs is innovative and well-motivated.
  • Addresses a crucial need: more effective utilization of computational resources (training efficiency).

Such a theoretical framework is in demand in the fast-developing field of data cleansing and knowledge distillation for LLMs, which is mostly driven by empirical results.

Weaknesses

  • The size gap between the 1.5B teacher and the 2.8B student makes it questionable to describe them as a pair of relatively "small" and "large" LMs. If the authors do not have sufficient resources to scale up the large model, they need to prove that a smaller

  • The paper doesn't thoroughly explore the impact of different teacher-student size ratios. (Also in Question 1)

  • The paper primarily focuses on relatively small models (1.5B and 2.8B parameters) compared to state-of-the-art LLMs. It's unclear if the benefits would scale to much larger models (e.g., 7B).

  • Limited analysis of computational overhead from the knowledge distillation phase: the authors only provide the resource difference between SALT and the baseline, but do not provide the cost of the data selection part.

Questions

  1. How would SALT perform with larger/smaller teacher-student size ratios? For example, would the benefits persist when training a 70B parameter model with a 7B teacher? When the ratio decreases to 1, will the framework still work (e.g., similar to bootstrapping)?

  2. Is there a minimum performance threshold the teacher must achieve for the method to be effective? Is it reasonable to set this as the baseline performance?

Comment

We thank the reviewer for recognizing the novelty and importance of our theoretical framework. We are also glad that the reviewer found our proposed method SALT effective and the accompanying data selection mechanism well-motivated. Below we provide point-wise responses to the reviewer's main questions/concerns.

The size gap of the 1.5B teacher and 2.8B student remains questionable to define them as a pair of relatively "small" and "large" LMs…The paper doesn't thoroughly explore the impact of different teacher-student size ratios.

Thank you for these comments. In order to address your concerns, we have carried out additional experiments where we aim to train a larger LLM with 8.6B parameters on the Pile dataset while utilizing a 2.8B parameter small teacher LM in our proposed $\text{SALT}$ framework. Note that this experiment aims to both 1) increase the student-to-teacher ratio to $8.6\text{B} / 2.8\text{B} \sim 3.07$ compared to $2.8\text{B} / 1.5\text{B} \sim 1.87$ in our initial submission; and 2) train a larger model (8.6B compared to 2.8B in our initial submission).

Given the limited timeframe of the discussion phase, we simply utilized the hyperparameters that we had used in our initial submission to train a 2.8B sized model. Despite this, we do see significant gains in both final model quality and training efficiency for 8.6B sized model training via our proposed $\text{SALT}$ framework. This not only strengthens the utility of our contributions but also highlights the robustness of the $\text{SALT}$ framework to various hyperparameter choices.

We have revised the submission to include the training setup along with the detailed empirical results for 8.6B parameter model training in Appendix K. Below we present the domain-wise few-shot performance and the post supervised fine-tuning (SFT) performance of the 8.6B sized LM trained via the $\text{SALT}$ framework, compared with the natural baseline, i.e., an 8.6B sized LM trained via standard self-supervised training. (Note that boldfaced and italicized numbers represent the best and the second best results, respectively, in the corresponding category.)

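Domain-wise few-shot performance of 8.6B LLM pre-trained via 2.8B small LM teacher

| Domain | # Tasks | $\text{SLM}$ | $\text{Baseline}$ @100% steps | $\text{SALT}$ @70% steps | $\text{SALT}$ @100% steps | $\text{SALT}_{\text{DS}}$ @70% steps | $\text{SALT}_{\text{DS}}$ @100% steps |
|---|---|---|---|---|---|---|---|
| World Knowledge | 4 | 22.19 | 26.91 | 27.66 | 28.97 | 28.04 | 28.47 |
| Reading Comprehension | 4 | 53.00 | 56.40 | 56.83 | 57.42 | 56.10 | 57.48 |
| Commonsense Reasoning | 7 | 61.99 | 66.01 | 66.89 | 67.09 | 66.61 | 67.24 |
| LAMBADA | 1 | 36.20 | 58.70 | 65.50 | 64.80 | 54.30 | 55.00 |
| SuperGLUE | 8 | 65.53 | 69.69 | 69.19 | 70.38 | 71.06 | 71.26 |
| NLG | 3 | 4.60 | 5.40 | 5.97 | 5.97 | 5.23 | 5.30 |
| MBPP | 1 | 16.20 | 20.80 | 19.80 | 22.00 | 22.80 | 23.20 |
| Average | 28 | 47.32 | 51.73 | 52.24 | 52.96 | 52.29 | 52.81 |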
Comment

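Post SFT results for 8.6B LLM pre-trained via 2.8B small LM teacher

| Model | GSM8K Accuracy | XSUM Rouge-1 | XSUM Rouge-2 | XSUM Rouge-L | CNN/DailyMail Rouge-1 | CNN/DailyMail Rouge-2 | CNN/DailyMail Rouge-L | ANLI-R1 Accuracy | ANLI-R2 Accuracy | ANLI-R3 Accuracy |
|---|---|---|---|---|---|---|---|---|---|---|
| $\text{Baseline}$ | 41.85 | 45.10 | 22.68 | 37.36 | 43.73 | 21.19 | 41.29 | 68.80 | 58.90 | 60.58 |
| $\text{SALT}$ | 42.84 | 45.37 | 23.04 | 37.69 | 43.69 | 21.16 | 41.22 | 70.20 | 59.30 | 63.25 |
| $\text{SALT}_{\text{DS}}$ | 42.23 | 45.81 | 23.34 | 38.14 | 43.80 | 21.28 | 41.35 | 69.30 | 59.50 | 62.17 |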

Takeaway: The few-shot evaluation results exhibit performance and efficiency gains similar to those we observed for 2.8B model training in our initial submission: 1) at 70% of the training steps, $\text{SALT}$ already performs better than or on par with the fully trained $\text{Baseline}$ (@100% steps); and 2) at 100% of the training steps, $\text{SALT}$ significantly outperforms $\text{Baseline}$. Furthermore, the LLMs trained via $\text{SALT}$ (with and without data selection) exhibit strong gains in post-SFT performance across a wide range of downstream tasks.
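As a quick sanity check on the two claims in the takeaway, the following sketch recomputes the per-domain gaps between $\text{SALT}$ and the fully trained $\text{Baseline}$. The numbers are copied from the few-shot table above; the dictionary layout and the helper name `gaps_vs_baseline` are illustrative choices, not part of the submission.

```python
# Domain-wise few-shot scores copied from the table above, as
# (Baseline @100% steps, SALT @70% steps, SALT @100% steps).
SCORES = {
    "World Knowledge": (26.91, 27.66, 28.97),
    "Reading Comprehension": (56.40, 56.83, 57.42),
    "Commonsense Reasoning": (66.01, 66.89, 67.09),
    "LAMBADA": (58.70, 65.50, 64.80),
    "SuperGLUE": (69.69, 69.19, 70.38),
    "NLG": (5.40, 5.97, 5.97),
    "MBPP": (20.80, 19.80, 22.00),
    "Average": (51.73, 52.24, 52.96),
}


def gaps_vs_baseline(scores):
    """Return {domain: (SALT@70% - Baseline, SALT@100% - Baseline)}."""
    return {
        domain: (round(s70 - base, 2), round(s100 - base, 2))
        for domain, (base, s70, s100) in scores.items()
    }


if __name__ == "__main__":
    for domain, (gap70, gap100) in gaps_vs_baseline(SCORES).items():
        print(f"{domain:<22} @70%: {gap70:+.2f}   @100%: {gap100:+.2f}")
```

With these numbers, $\text{SALT}$ at 70% of the steps trails the fully trained baseline only on SuperGLUE and MBPP, and $\text{SALT}$ at 100% of the steps is ahead in every row, including the average.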

…It's unclear if the benefits would scale to much larger models (eg. 7B).

Thank you for this comment. Please refer to our response to your earlier comment above where we showcased the utility of the proposed $\text{SALT}$ framework for 8.6B sized LM training.

…cost of data selection part.

Thanks for pointing this out. We will add efficiency details for SALT with data selection.

The main cost of data selection comes from computing the small LM's top-k probabilities assigned to the next token at each position in a training sequence. In our experiments, a 1.5B model checkpoint at $n_0 = 26\text{K}$ steps is used to score training sequences. The 1.5B checkpoint at 26K steps is readily available from the training of the 1.5B SLM. Scoring an input sequence is equivalent to a forward pass on that sequence. We score sequences asynchronously in parallel, on relatively inexpensive hardware (e.g., TPU v5 lite for inference vs. TPU v5e pod for training), without cross-chip communication.

Furthermore, we would like to highlight that since the data selection is independent of the LLM being pre-trained, these selected data can be leveraged to train multiple large student LMs, potentially of different sizes. This essentially amortizes the data selection cost across multiple pre-training runs.
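To make the scoring step concrete, here is a minimal sketch of the per-position top-k extraction described above. It assumes the small LM's forward pass has already produced one next-token probability list per position; the toy lists, the function names `top_k_probs` and `score_sequence`, and the choice of plain Python are all hypothetical. How the extracted values are turned into a keep-or-drop decision is not covered by this sketch.

```python
import heapq


def top_k_probs(next_token_probs, k):
    """Return the k largest (token_id, prob) pairs, highest probability first."""
    return heapq.nlargest(k, enumerate(next_token_probs), key=lambda pair: pair[1])


def score_sequence(per_position_probs, k):
    """Collect the top-k next-token probabilities for every position.

    `per_position_probs` stands in for the output of one forward pass of
    the small LM over a training sequence (an assumption of this sketch).
    """
    return [top_k_probs(probs, k) for probs in per_position_probs]


if __name__ == "__main__":
    # Toy example: a 2-position sequence over a 3-token vocabulary.
    toy_probs = [[0.1, 0.6, 0.3], [0.5, 0.2, 0.3]]
    print(score_sequence(toy_probs, k=2))
```

Because each sequence is scored independently of every other sequence, calls like `score_sequence` can be run in parallel without coordination between workers, which is consistent with the asynchronous scoring setup described above.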

How would SALT perform with larger/smaller teacher-student size ratios? For example, would the benefits persist when training a 70B parameter model with a 7B teacher? When the ratio decreases to 1, will the framework still work (e.g., similar to bootstrapping)?

For a larger student-to-teacher ratio (compared to what was present in our initial submission), please see our response above where we presented new results for training an 8.6B LLM via $\text{SALT}$ with a 2.8B sized teacher, i.e., a student-to-teacher ratio of ~$3.07$ (compared to the student-to-teacher ratio of ~$1.87$ in our initial submission).

We strongly believe that our method would also benefit training a 70B parameter model with a 7B (small) teacher. Unfortunately, given the resource requirements for testing this setup and the limited timeframe of the rebuttal phase, it is not feasible for us to demonstrate this. We hope that the reviewer would understand this and assess the merit of our contributions on the comprehensive empirical evidence we have already provided in the initial submission and the revised version during the rebuttal.

Regarding the question of whether our framework will work with a student-to-teacher ratio of 1, this setting is similar to the self-distillation [1] setting studied in the literature. This should also work to improve the student LM’s performance. However, since our work is motivated by improving both quality and training efficiency of LLM pre-training via small LMs, we focus on student-to-teacher ratios greater than 1. That said, we do believe that a systematic study of knowledge distillation for language modeling beyond the $\text{SALT}$ setting is an interesting direction for future research.

[1] Furlanello et al., Born Again Neural Networks, ICML 2018.

Comment

Is there a minimum performance threshold the teacher must achieve for the method to be effective?

Please note that $\text{SALT}$ aims to selectively transfer the knowledge from a small teacher LM to a larger student LM in those regions of the data domain where the small LM performs well. Thus, as long as the teacher exhibits good performance in certain regions, it should aid in the training of the larger LM. Crucially, the duration of the first stage in $\text{SALT}$, where we utilize a small LM as a teacher, will get shortened when the small teacher has weaker performance.
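The two-stage structure referred to here (distillation from the small LM first, standard training afterwards) can be illustrated with a minimal per-token loss sketch. This is an illustration under stated assumptions rather than the exact objective of the submission: the mixing weight `kd_weight`, the hard switch at `n_kd_steps`, and all function names are hypothetical.

```python
import math


def cross_entropy(student_probs, target_id):
    """Standard next-token loss: negative log-probability of the true token."""
    return -math.log(student_probs[target_id])


def distillation_loss(student_probs, teacher_probs):
    """Cross-entropy of the student against the teacher's soft labels."""
    return -sum(
        t * math.log(s) for t, s in zip(teacher_probs, student_probs) if t > 0
    )


def two_stage_loss(step, n_kd_steps, student_probs, teacher_probs, target_id,
                   kd_weight=0.5):
    """Stage 1 (step < n_kd_steps): mix ground-truth and teacher supervision.
    Stage 2 (afterwards): standard training on the ground-truth token only.

    kd_weight and the hard switch at n_kd_steps are assumptions of this sketch.
    """
    ce = cross_entropy(student_probs, target_id)
    if step < n_kd_steps:
        kd = distillation_loss(student_probs, teacher_probs)
        return (1 - kd_weight) * ce + kd_weight * kd
    return ce


if __name__ == "__main__":
    student, teacher = [0.8, 0.2], [0.5, 0.5]
    print(two_stage_loss(0, 10, student, teacher, target_id=0))   # stage 1
    print(two_stage_loss(10, 10, student, teacher, target_id=0))  # stage 2
```

In this sketch, a weaker teacher would correspond to choosing a smaller `n_kd_steps`, so that the student switches to standard training earlier.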

We again thank the reviewer for their time and effort. We would urge the reviewer to reevaluate their assessment of our submission in light of our detailed response to their main concerns.

Comment

Again, we thank the reviewer for their time and effort reviewing our submission. We are following up to see if you had a chance to review our response to your comments/concerns and if you have any further questions.

AC Meta-Review

This paper introduces a novel approach to improve large language model (LLM) pre-training efficiency and quality by leveraging a small language model (SLM) for soft label supervision and selective data sampling. The method reduces training time and enhances model quality, supported by a theoretical framework that explains how adaptive utilization of SLM guidance balances bias and variance, with empirical validation on a 2.8B parameter LLM using the Pile dataset.

The paper was reviewed by three reviewers. All of them agree this paper has issues and may not be ready for publication at this stage. I would encourage the authors to please go through the reviews, and address them in the next iteration of this paper.

Additional Comments on Reviewer Discussion

The paper was reviewed by three reviewers. All of them agree this paper has issues and may not be ready for publication at this stage. I would encourage the authors to please go through the reviews, and address them in the next iteration of this paper.

Final Decision

Reject