PaperHub
COLM 2024 · Poster
Overall rating: 5.5/10 from 4 reviewers (individual ratings 7, 7, 4, 4; min 4, max 7, std 1.5)
Average confidence: 4.0

Early Weight Averaging meets High Learning Rates for LLM Pre-training

Submitted: 2024-03-22 · Updated: 2024-08-26
TL;DR

LLM pre-training requires large batch sizes paired with higher learning rates; we investigate a specific weight averaging scheme that improves language modelling in this pre-training setting.

Abstract

Keywords
LLM pre-training · High learning rates · Weight Averaging · Model merging

Reviews and Discussion

Review (Rating: 7)

This paper proposes a method for weight averaging along the LLM pretraining trajectory. The experimental results show that the checkpoints obtained this way are better than the original checkpoints at the same training step, in terms of both validation loss and zero-shot performance on downstream tasks.

Reasons to Accept

  • The proposed method is simple and easy to implement, and it incurs relatively little extra computation. It can be used for the pretraining of any LLM.
  • To the best of my knowledge, this is the first work to use weight averaging during pretraining, which IMO gives it some novelty.
  • The intuition discussed in Section 2.2 makes a lot of sense to me.
  • The experimental results are extensive. The proposed method seems effective, as it improves both the perplexity and the zero-shot performance of LLMs, especially for models trained with large learning rates. Its effect in mitigating loss spikes is also interesting.
  • Appendix A discusses some FAQs, which was helpful after I finished reading the main text.

Reasons to Reject

  • The writing of some parts could be improved. As someone not familiar with LAWA/EMA/SWA, after finishing Section 2.3 I wrongly assumed the proposed method adds checkpoints obtained by weight averaging back into pretraining, and it took me some time to understand that the proposed method does not affect the checkpoints used for training.
  • Even though I love the intuition in Section 2.2, it may not be convincing enough for some readers:
    • The Optimization Viewpoint seems to hold for the toy setting, but does it hold in the much more complicated setting of large language models?
    • The Diversity Viewpoint appeals to ensembling, but at the early stage of training the models are not strong yet (i.e., they may not have the could-be-essential capabilities that are much more common in previous ensembling works). In this case, how does the ensembling intuition explain the gains?
  • Even though the authors argue that (1) "where conducting a grid search is challenging due to the model's size, adopting our proposed training recipe could be advantageous" and (2) "it is safe to say that LAWA provides some gains for all LRs but works the best with high LRs", I still have the concern that the proposed method relies too heavily on the assumption that the model is pretrained with a large learning rate.
    • As previous works have already pointed out that we can use a small LR (e.g., Sophia), why should we assume that people in the future have to do an extensive grid search for the pretraining learning rate?
    • Also, the improvement seems very marginal with a large LR, e.g., Figure 2(a).
    • Fortunately, your proposed method does not bring (too much) degradation in any of the experiments conducted, so you may argue that "we can always adopt the method in any situation". But the facts mentioned above show that the effectiveness is somewhat limited, so you may need a stronger argument regarding the learning rate.
  • I noticed some discussion of hyperparameters in Appendix C.2, but it is rather empirical. I hope to see a more intuitive discussion or explanation that can guide future users of the proposed method. However, this is not a major concern, because people should be able to tune the hyperparameters given how lightweight the method is.

Questions for Authors

As I said in Reasons to Reject, I had a misunderstanding about the proposed method. I am curious what the authors think about the method I was wrongly imagining: the checkpoints obtained by weight averaging are also added to the training trajectory (i.e., such checkpoints continue to be trained).

Author Response

Thank you for your positive assessment of our work.

(A1) Appendix A discusses …

We are very grateful for such a thoughtful remark; we will include a FAQ section in our upcoming works as well.

(R1) The writing of some parts..

Noted; we will update the draft based on this remark.

(R2) Even though I love the intuition …

We agree: the toy setting is intended to provide some intuition for why LAWA works in a non-convex setting, but we make no rigorous claims there. We defer a more rigorous analysis of the optimization perspective to future work.

In Section C.4, we discussed that LLMs achieve linear mode connectivity (LMC) not at the onset of training, but relatively early in the process. For example, for a Pythia 1B model, early weight averaging (EWA) is effective starting from 8,000 steps. This phase transition facilitates LMC, making weight ensembling feasible. Our understanding is shaped by these experimental observations.
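
For readers less familiar with LMC, below is a minimal sketch of the kind of interpolation check we refer to (a simplified illustration, not the paper's code; `evaluate_loss` is a placeholder for a validation-loss routine):

```python
import torch

def lmc_curve(state_a, state_b, model, evaluate_loss, num_points=11):
    """Evaluate the loss along the straight line between two checkpoints.

    A (near-)flat curve with no barrier between the endpoints indicates
    linear mode connectivity between the two checkpoints.
    """
    losses = []
    for i in range(num_points):
        alpha = i / (num_points - 1)
        # interpolate every parameter tensor between the two state dicts
        mixed = {k: (1 - alpha) * state_a[k].float() + alpha * state_b[k].float()
                 for k in state_a}
        model.load_state_dict(mixed, strict=False)
        losses.append(evaluate_loss(model))
    return losses
```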

(R3) Even though …

LLMs are trained with (some sort of) large LRs due to the use of large batch sizes; this relationship is well studied for non-adaptive [1] and adaptive optimizers [2]. The Pythia LLMs are trained with well-tuned LRs, and we observe gains in both perplexity and downstream tasks. We emphasize that high LRs are preferred, but our recipe also helps with well-tuned LRs.

(R3 contd) As previous work…

Sophia is an optimizer, and LAWA can work on top of any optimizer to amplify its gains. It is a bit unfair to compare these two works, as the choice of LR may vary based on the optimizer.

(R3 contd) Then … argument regarding learning rate.

As noted above, LLMs are typically trained with larger batch sizes [3] for systems-efficiency reasons; large LRs are typically needed in such scenarios, and LAWA helps there.

(R4) I noticed ...

As shown in C.2, LAWA works better than the original training run for various values of k and ν. One has to perform some ablations to select the k and ν best suited to their use case, which may vary across settings.
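
For concreteness, here is a minimal sketch of how k and ν enter the averaging (illustrative PyTorch pseudocode; the class and method names are ours, not the paper's implementation):

```python
from collections import deque

import torch

class LawaBuffer:
    """Sliding window of the k latest checkpoints, spaced nu optimizer steps apart."""

    def __init__(self, k: int, nu: int):
        self.k, self.nu = k, nu
        self.window = deque(maxlen=k)  # the oldest snapshot is dropped automatically

    def maybe_collect(self, step: int, model: torch.nn.Module):
        # every nu steps, snapshot the current weights on CPU; training is unaffected
        if step > 0 and step % self.nu == 0:
            self.window.append(
                {n: p.detach().float().cpu().clone() for n, p in model.named_parameters()}
            )

    def averaged_params(self):
        # uniform average over the checkpoints currently in the window
        avg = {n: torch.zeros_like(p) for n, p in self.window[0].items()}
        for snapshot in self.window:
            for n, p in snapshot.items():
                avg[n] += p / len(self.window)
        return avg
```

Larger ν corresponds to the "far" averaging studied in the paper; the averaged weights are used only for evaluation and are never fed back into training.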

(Q1) As I said …

Using the averaged checkpoints for further pre-training is something we are considering as future work.

[1] Goyal et al.

[2] Malladi et al.

[3] Biderman et al.

Comment

Dear Reviewer rseH,

Thank you for taking the time to review our work; we feel your thoughtful remarks have improved its quality. Since we are close to the end of the discussion period, we would love to get additional feedback on the comments we made during the rebuttal. We also invite you to update your score if you believe we have sufficiently addressed your concerns.

Review (Rating: 7)

This paper studies early weight merging to speed up LLM pretraining. The authors find that models trained with high learning rates see larger gains from checkpoint averaging. They propose one simple but effective strategy, which surpasses popular checkpoint-averaging baselines such as exponential moving average (EMA) and stochastic weight averaging (SWA) when training nanoGPT-2 models from 125M to 770M parameters.

Reasons to Accept

  1. The paper is very well written. Section 2 is friendly to readers who are not knowledgeable about weight averaging.
  2. The results are promising under the authors' current setting. There appears to be a clear convergence speedup at the early stage of training.

Reasons to Reject

  1. It would be great to discuss more about the differences between this approach and existing approaches. For instance, can one say it simply changes the weighting distribution of EMA and stochastic weight averaging?
  2. The improvement seems very small at the late stage of training. Perhaps the reason is that the averaging effectively provides a smaller learning rate in practice, which learns what would be learned with a smaller LR anyway. If so, I am concerned whether the approach will really work when training on a very large dataset, such as the 15T tokens used for Llama 3.
  3. Experimental results on real throughput are missing. Loading/merging checkpoints could also be time-consuming when training very large models on a GPU cluster. Would the speedup benefits diminish in this kind of setting because of lower GPU utilization?

Questions for Authors

See above.

Author Response

Thank you for your positive assessment of our work.

R1 It would … weight average?

We have discussed various weight averaging (WA) approaches in Sec. 1 and emphasize that, conventionally, no WA approach is used during LLM pre-training. We discuss the differences from similar tail-averaging schemes such as EMA and SWA in Sec. 1 and in supplementary Sec. B, and we use EMA and SWA as baselines, with results shown in Fig. 1.
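
To make the contrast concrete, here is one simplified way to write the three schemes (using our k and ν notation; warm-up, scheduling, and implementation details are omitted, so this is only a schematic view rather than the exact update used in the paper):

```latex
\begin{aligned}
\text{EMA:}\;  & \bar{\theta}_t = \beta\,\bar{\theta}_{t-1} + (1-\beta)\,\theta_t
  && \text{(exponentially decaying weights over all past iterates)} \\
\text{SWA:}\;  & \bar{\theta}_n = \frac{1}{n}\sum_{i=1}^{n} \theta_{c_i}
  && \text{(uniform running average over all collected checkpoints)} \\
\text{LAWA:}\; & \bar{\theta}_t = \frac{1}{k}\sum_{i=0}^{k-1} \theta_{t-i\nu}
  && \text{(uniform average over a sliding window of $k$ checkpoints spaced $\nu$ steps apart)}
\end{aligned}
```

Under this view, the key difference is the support of the weighting: LAWA places uniform mass on a finite window of far-apart checkpoints and discards everything older, whereas EMA and SWA keep (decaying or uniform) mass on the full history of collected iterates.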

(R1 contd..) For instance … stochastic weight average?

Additional results: LAWA with far WA outperforms EMA and SWA with far WA. Thanks for this thoughtful remark; we provide additional experiments with GPT2-small and GPT2-medium. New Fig1. https://i.postimg.cc/LX13SqPM/GPT2-small-far-comparison.png

New Fig2. https://i.postimg.cc/gjbq2YyL/GPT2-medium-far-comparison.png

(R2) It seems that the …

True, the improvement may look small, but it is significant in val loss and downstream performance for all GPT2 models and for the Pythia 1B-2.8B LLMs. Most importantly, this gain is almost a free lunch. The LR is fully annealed towards the end of training, hence we observe smaller gains there than at intermediate steps. But with LAWA one can now sample intermediate checkpoints (say at 75% of training), so this recipe can help even in longer as well as compute-optimal training runs. One could also design better LR schedulers for LAWA training runs for further gains.

(R3) The experimental … GPU utilization?

LAWA is completely decoupled from the actual training run, hence real throughput is not a meaningful metric for our setting. One can save checkpoints and perform LAWA even on CPUs, both during and after training.
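
As an illustration of this decoupling, here is a minimal sketch of the offline variant (the checkpoint file layout and names below are hypothetical, not our actual scripts):

```python
import torch

def average_checkpoints(steps, path_fmt="checkpoints/step_{:07d}.pt"):
    """Uniformly average the checkpoints saved at the given steps, entirely on CPU."""
    avg, n = None, len(steps)
    for step in steps:
        state = torch.load(path_fmt.format(step), map_location="cpu")
        if avg is None:
            avg = {name: p.float() / n for name, p in state.items()}
        else:
            for name, p in state.items():
                avg[name] += p.float() / n
    return avg

# e.g. k=5 checkpoints spaced nu=1000 steps apart, ending at step 50000
merged_state = average_checkpoints([46000, 47000, 48000, 49000, 50000])
torch.save(merged_state, "checkpoints/lawa_step_0050000.pt")
```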

If you are convinced by our responses and additional results, we hope you will reconsider your score.

Comment

Dear Reviewer BhE4,

Thank you for your thoughtful remarks. We have added additional experiments to address your remark on the weighting distributions of EMA and SWA. We would be very grateful if you could engage in further discussion with us. Meanwhile, if you believe we have addressed your concerns well enough, please consider raising your score.

Review (Rating: 4)

This paper studies the use of LAWA (latest weight averaging; Kaddour, 2022) -- a type of checkpoint averaging technique -- in the context of pre-training language models. The key insights discussed by the paper are that

(1) Models trained with high learning rates particularly benefit from checkpoint averaging at the early stage of training. The loss trajectories of such models mimic those of models trained with low learning rates. Thus, checkpoint averaging can be used as a technique to improve high-LR training's generalization.

(2) For LAWA and high-LR training, using distant checkpoints performs better.

The authors conducted thorough experiments, studying models with different LRs, different scales, and different training data. The evaluation is done on both perplexity and zero-shot downstream task performance. The experiments show that LAWA is better than other methods such as EMA (exponential moving average; Szegedy et al., 2015) and SWA (stochastic weight averaging; Izmailov et al., 2018). LAWA also leads to faster convergence at the early stage of training.

Reasons to Accept

(1) The experiments in the paper are very thorough, spanning different models, different scales, different data, and different evaluations. The proposed training paradigm leads to faster convergence for high-LR training at the early training stage.

(2) The two insights -- high-LR training + checkpoint averaging helps improve generalization, and LAWA should use larger intervals -- are effective and are verified by thorough experiments.

Reasons to Reject

(1) Novelty: the main method, LAWA, is not new. The use of LAWA in high-LR training and the use of larger intervals for LAWA (the two key insights of the paper) seem to me a natural application (and hyperparameter tuning) of the original LAWA method.

(2) From the main results, it seems that the proposed training method can at most match standard low-learning-rate training (Figure 2). Please let me know if I missed something here (I assume that in most experiments you fix the learning rate to the high one, except in Figure 2). Now, the authors can argue that high learning rates are useful in (a) the large-batch-size scenario and (b) very short training. However, there are no experiments on (a) in the paper (batch sizes are all fixed?), and for (b), it seems that the gain is only significant at very, very early training stages (from Figure 2). Note that compared to today's pre-training scale (1T tokens), the experiments are already quite toy (10B, 207B tokens), so "early-stage training" in this setting may have few empirical implications.

Questions for Authors

Please see my "reasons to reject", and let me know if I misunderstand anything!

Author Response

(R1) Novelty ...

LAWA [1] applies Polyak's iterate averaging (WA) [2] to deep NNs. To our knowledge, no prior work has studied which iterates (checkpoints) are best suited for iterative WA. We show that early, far-apart checkpoints from models trained with high LRs are best suited for LAWA (ours) in LLM pre-training. Additionally, we add EMA and SWA as baselines for LLMs and thoroughly analyze LAWA. New Fig1. Far-away WA improves over LAWA [1]; https://i.postimg.cc/MpbR8vYh/lawa1-vs-lawa2.png

(R1 contd..) The use of LAWA …

Conventionally, the latest checkpoints are thought to be better than far-away older checkpoints in large-scale optimization, specifically in the early stages of training, and the same goes for high LRs. Our results are therefore not trivial, and given its benefits for LLM pre-training, this deserves to be studied.

(R2) From the main results …

Clarification: in Fig. 2 a), LAWA (ours) with lr=6e-4 (low LR) outperforms Original (no LAWA) with the low LR at all intermediate steps and also at 50K, but the gains are smaller than at intermediate steps because the LR schedule is fully annealed there. Moreover, for the well-tuned LRs used by the Pythia 1B-2B LLMs (Figs. 4 and 5, Table 1), we see a clear final improvement with LAWA (ours).

(R2 contd …)

Consistent gains with LAWA (ours): LR details are provided in Table 3. GPT2-small, -medium, and -large are trained with batch sizes (BS) of 131K, 50K, and 16K, respectively; the Pythia LLMs are trained with BS = 2M tokens. LAWA achieves consistent gains with variable high LRs and BS for the GPT2 models in Fig. 1, and with fixed BS and variable LRs for the Pythia LLMs.

Additional results

a) Variable BS and fixed LR = 6e-3 for GPT2-small. New Fig3: https://i.postimg.cc/tCvLyNXx/GPT2-small-batchsize-comparison.png

b) Fixed BS for two GPT2 models, plus EMA and SWA with far WA. New Fig4:

https://i.postimg.cc/LX13SqPM/GPT2-small-far-comparison.png

https://i.postimg.cc/gjbq2YyL/GPT2-medium-far-comparison.png

We emphasize that the gains are not limited to the early stages. Moreover, this is the first work to show improvements in LLM pre-training with WA. Our recipe has the potential to be effective when training with 1T tokens, specifically for compute-optimal training runs like Chinchilla that employ early stopping. For more implications, please see the A.2 FAQs in the supplementary.

[1] Kaddour et al.

[2] Polyak et al.

We will update the draft with new results.

Comment

Dear Reviewer ZNy2,

Since we are very close to the end of the discussion period, this is a gentle reminder to please engage with us in discussing our comments and newly added results. We worked hard to generate four new figures of additional experiments. Let us know your thoughts, and if you feel we have sufficiently addressed your concerns, we invite you to raise your score.

Comment

I acknowledge that I have read the authors' rebuttal; however, I stand by my original evaluation and will keep my score unchanged. My main concern remains how novel the proposed method is compared to the original LAWA (even with the updated results, the gain is pretty small), as well as the applicability of the finding.

Comment

Thank you for engaging with us in further discussion.

Clear final-loss improvements with far averaging, not studied in the original LAWA paper.

Below we provide the final improvements in val loss with 1K-far averaging compared to LAWA with less-far averaging. The table is constructed from New Fig. 1 in the rebuttal and Fig. 2 b) in the paper.

| Model | Step | LAWA (far=100) | LAWA (far=1K) | Original |
|---|---|---|---|---|
| GPT-2 Small | 70K | 3.28 | 3.26 | 3.28 |

| Model | Step | LAWA (far=200) | LAWA (far=1K) | Original |
|---|---|---|---|---|
| GPT-2 Medium | 70K | 2.828 | 2.819 | 2.845 |

We show thematically similar results with Pythia-1B (see Fig. 10 b).

Major differences from the original LAWA

We emphasize that the key novelty of this work is not in showing that LAWA works, but in demonstrating that LAWA works better with early and far-away averaging. This insight is neither trivial nor a mere application; rather, it is counter-intuitive from an optimization perspective, where the latest checkpoints are considered better than far-away older ones. Below we list some other major differences.

| Topic | LAWA (ours) | LAWA (original) |
|---|---|---|
| SWA baseline | ✔️ | |
| EMA baseline | ✔️ | |
| Large-scale experiments | ✔️ | |
| Impact on models trained with large LR | ✔️ | |
| Impact on downstream performance | ✔️ | |
| Additional analysis with linear mode connectivity | ✔️ | |

Review (Rating: 4)

The paper investigates the effect of checkpoint averaging along the trajectory of model training. It adapts a previously proposed method, LAtest Weight Averaging (LAWA), to decoder-only LLMs. The authors are able to train models with a higher learning rate and observe gains in validation loss and some downstream tasks. They also adapt exponential moving average (EMA) and stochastic weight averaging (SWA) to their setup and compare the methods.

Reasons to Accept

  • The method of weight averaging for models trained with a higher learning rate is simple and shows improvements in validation loss and some downstream tasks.
  • The paper is easy to follow and compares against some relevant methods.

Reasons to Reject

  • The method shows improvements in the early stages of training, and the difference gets increasingly smaller at later stages, in both loss value and downstream task performance. This raises the question: will the conclusions still hold for longer training regimes such as the Llama family of models?
  • The experimental setup is limited, and there might be an issue with the GPT2 training experiment. Generally, we expect the loss to get smaller as the model size increases; however, GPT2-large has worse val loss than GPT2-medium.
  • The baseline setup is weak. Training without any averaging can improve further with better-optimized batch size and learning rate, and the two are correlated. A higher learning rate together with a larger batch size can lead to better performance; e.g., Llama 7B was trained with a 4M batch size and LR=3e-4.
  • The EMA and SWA adaptations require a corresponding update to the averaging interval, similar to LAWA.
  • The cost of the hyperparameter search for LAWA should be discussed when comparing with the baseline.

Questions for Authors

  • How were the moving-window interval and the number of model checkpoints to average decided? What is the corresponding cost? Does it vary with model size and architecture?
  • It would be good to add the learning rate schedules for the different training methods.

Author Response

Thanks for the assessment.

(R1) The method shows improvements in the early stages…

Clarification: the gains are not only in the early stages; the difference gets smaller at later stages but remains significant. In the later stages the LR is fully annealed, hence the gains should hold for longer training runs, where LR schedules anneal much later in training. For all GPT2 and Pythia 1B-2.8B models (reasonably large), we observe consistent gains in final loss/perplexity and downstream performance. For an extended discussion, see A.1 and A.2 in the supplementary.

(R2) The experimental setup ..

Our experiments are reasonably large, e.g., GPT2 (770M parameters, 9B tokens) and large-scale results with Pythia (1B-12B parameters, 207B tokens). GPT2-medium is trained with batch size (BS) = 50K and GPT2-large with BS = 16K due to GPU memory constraints, which is why GPT2-large has higher loss than GPT2-medium.

(R3) The baseline setup is weak…

LAWA sees larger gains with high LRs, but it also helps with a well-tuned LR and batch size (BS); for example, the Pythia LLMs are trained with optimized settings and still show gains in final loss and downstream tasks for the 1B-2.8B LLMs, plus additional intermediate gains for the 6.9B and 12B LLMs. We believe additional gains could be observed if Llama 7B were trained with LAWA.

(R4) The EMA and SWA adaptation requires corresponding update to averaging interval similar to LAWA.

Additional results: LAWA with far WA outperforms EMA and SWA with far WA.

New Fig1. https://i.postimg.cc/LX13SqPM/GPT2-small-far-comparison.png

New Fig2. https://i.postimg.cc/gjbq2YyL/GPT2-medium-far-comparison.png

(R5) The cost of …

The cost of hyperparameter search is well known to be extremely expensive, as shown in several prior works such as [1][2]. Due to limited compute we did not run such an experiment, and this is in any case an intuitive and well-established point in large-scale pre-training.

Q1 How was ..

Based on the ablations shown in supplementary C.2 (Ablations). Our ablations use the Pythia 1B model; no, we did not vary model size and architecture, due to compute limitations. The compute cost and the savings are provided in supplementary C.1 and Figure 9.

Q2 It'll be good …

All GPT2 training runs in this paper use the standard cosine LR schedule following [1]. The Pythia models are also trained with a similar schedule.

[1] Liu et al. [2] Yang et al.

Comment

Dear Reviewer GUxG,

Since we are very close to the end of the discussion period, we are eager to know whether our responses have satisfactorily addressed your concerns about the paper. We have also provided additional results to address your concerns about the EMA and SWA baselines. If you feel we have sufficiently addressed your concerns, please reconsider your score.

Final Decision

This paper makes an interesting twist on model averaging by suggesting that averaging be done with large learning rates and from spaced-out checkpoints. The authors show significant gains over previous methods in the early stages of training; at convergence the gains are small but consistent. The core idea is thus quite simple, and the paper does not provide theoretical insights. However, the paper may have practical impact for anyone trying to train a small or medium-scale model fast, and all reviewers unanimously agree on this point. Experiments on large-scale LLMs are lacking, and it is unclear how far these conclusions will generalize to larger models. However, at the small and medium scale, the paper provides exhaustive experiments.