Spike No More: Stabilizing the Pre-training of Large Language Models
Abstract
Reviews and Discussion
This paper analyzes gradient norms to explain and reduce spikes in the training loss. However, the paper does not address the differences between LLMs and small models, the differences between deep and shallow layers, or the relationship between spikes and the gradient norm. Most importantly, the risk of training LLMs is reduced to evaluating performance on benchmark datasets, while the risk of wasted computing resources is not well addressed; this unfortunately requires substantial improvement.
Strengths
The paper effectively addresses the critical issue of loss spikes during training, providing a detailed analysis of their relationship with gradient norms and embedding means. It evaluates existing approaches like "Scaled Embed" and "Embed LN," discussing their effectiveness in mitigating spikes. Additionally, the paper offers valuable insights into the impact of learning rate adjustments on model stability and compares the behavior of spikes in large language models (LLMs) with smaller models, providing a broader context for its findings.
Weaknesses
The paper suffers from an unclear relationship between spikes and poor performance, with insufficient explanations of key terms and assumptions. The evaluation section is not well-explained, and there are inconsistencies in terminology. Additionally, it lacks necessary plots and data to support its assumptions, and some figures are difficult to interpret due to overlapping lines. Reproducibility is a concern as common public architectures are not used, and some discussions are considered irrelevant to the main topic.
Questions
The relationship between spikes and poor performance is not clearly established.
Some loss spikes can be recovered during training, while others cannot.
Are you assuming that spikes are (1) always bad, (2) risks that result in divergence in some cases, or (3) sometimes acceptable and not harmful?
L40: What is catastrophic divergence? Does it mean the spike never goes down again?
What model is used in Figure 1?
What is the difference between spikes in LLMs and other smaller models?
L47: The assumption between spikes and gradient norm is unclear. What did you observe in L121?
L53: What is the standard deviation of embedding means? Does this equate to a large shortcut in the later sections?
L63: Evaluation is not explained.
L83: Conventionally, does shortcut mean residual?
Can you comment on the parallel FF+Attn setting where there is no intermediate vector? It is broadly used in many models, e.g., Pythia, Mesh-Transformers, and PaLM.
clear writing in section 2.1, 2.2
L130: What is the difference between W_* and W?
L134: You should mention the single-head attention and F as the identity function (linear settings) in the abstract and introduction.
L134: Please plot the distribution of x, x', and W, and note the standard deviation and mean of each layer to support your assumption. In Appendix F, there are only plots for initialized models. I could not find plots for W or during training and pretrained models.
L160: What are d and d_ffn? Did you mention them somewhere?
L167: This is an estimation as shown in L159. Please be consistent.
L174: It's unclear what this well-known formula is.
Eq 13: Why is variance the degree of freedom?
Condition 2 of large shortcut comes from assumption 1.
L219, Eq 19: Where does the 2 come from?
L234: The writing in Section 4 could follow the format in Section 5.2 to improve readability.
L250: Both "Scaled Embed" and "Embed LN" are existing approaches, not introduced by this paper.
L291: Why should it be close to 1 and not as large as possible? If so, why is it called 'large shortcut' in the previous sections? Similarly, should the sublayer be as small as possible or close to some value? Why do we scale the embed by some value but not a larger one? What is the optimal scale from your theory?
L309: C4 is too small to be the pretraining data for LLM.
L317: Why not use a common public architecture (e.g., LLama, Mixtral, Gemma) for reproducibility? Is there any common pretraining baseline in the literature (e.g., Dettmers et al. (2022) and Le Scao et al. (2022))?
Figure 3 (a): The lines overlap, so it is hard to see what happened. Is there another metric, such as the number of spikes, beyond visualization?
L367: Only in the embed layer?
The training risk mentioned in the abstract is not addressed in the experiment. We can only see some marginal improvement over perplexity. In the abstract, it indicates that 'if we don't take spikes seriously, the whole effort will be in vain due to catastrophic divergence.' However, both vanilla and embed detach do not suffer from catastrophic divergence in the experiment results.
Section 6.1: How does the learning rate affect the gradient norm? How does the learning rate affect the four baseline methods? Did you indicate that stabilization can be achieved by a smaller learning rate? If so, there is no actual need for "Scaled Embed" or "Embed LN"? In Figure 5, we can see that vanilla performs better than Scaled Embed. To achieve better final performance as depicted in Table 3, can we use a small learning rate to travel through the risky early training stage and increase the learning rate later to avoid spiking?
Section 6.2: Good point! I expect a similar discussion in Section 6.1 to explain why a smaller learning rate can reduce spikes. Please remind me if I missed something. Also, this indicates the settings of 'short seq in the early stage and long seq later.' Can we say that 'small lr in the early stage and large one later' is also possible?
L490: Can you use your theory to explain why preLN is more stable than post-LN?
L509: The efficiency and learning rate discussion is a bit too far and not quite relevant.
The paper suffers from an unclear relationship between spikes and poor performance, with insufficient explanations of key terms and assumptions. The evaluation section is not well-explained, and there are inconsistencies in terminology. Additionally, it lacks necessary plots and data to support its assumptions, and some figures are difficult to interpret due to overlapping lines. Reproducibility is a concern as common public architectures are not used, and some discussions are considered irrelevant to the main topic.
We suppose that the main concerns of the reviewer are as follows:
- The need to clarify the harms caused by loss spikes to establish the advantages of preventing them.
- A suspicion that we used an uncommon architecture, leading the reviewer to question the practicality of our study.
Regarding the first concern, we demonstrated in our experiments that models without loss spikes achieved better performance. For instance, we extracted the perplexities of the 1.7B parameter models from Table 2:
| Model | WikiText | LAMBADA |
|---|---|---|
| Vanilla | 22.58 | 15.22 |
| Embed Detach | 22.00 | 13.88 |
| Embed LN | 21.29 | 13.00 |
| Scaled Embed | 21.29 | 12.53 |
As shown in the table, Embed LN and Scaled Embed, which satisfy the conditions for preventing the loss spike, outperformed Vanilla and Embed Detach, which violate the conditions.
Furthermore, Figures 1, 5 (d), and 9 (a) show that Vanilla failed during pre-training due to catastrophic divergence. We kindly ask the reviewer to consider these empirical results alongside the training costs described in our response to the summary. If the reviewer still believes that certain discussions are unrelated to the main topic, we would greatly appreciate it if the reviewer could specify the points in question.
Regarding the second concern, we used the Pre-LN Transformer architecture, which is commonly employed in various LLMs, such as the Llama series. Therefore, we believe that our findings are practically useful. In addition, as described in Reproducibility Statement (lines 546--548), we add only several lines to a widely used implementation, i.e., Megatron-LM. Therefore, we also believe that our experimental findings are straightforward to reproduce. We would greatly appreciate it if the reviewer could point out any essential differences between the models.
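For illustration, the kind of few-line change involved can be sketched as follows (a minimal PyTorch sketch with hypothetical module and variable names; this is not our actual Megatron-LM patch):

```python
import math
import torch
import torch.nn as nn

class StabilizedEmbedding(nn.Module):
    """Minimal sketch (hypothetical names) of the two embedding-side fixes.

    - "Scaled Embed": multiply token embeddings by sqrt(d_model) so that, with a
      typical small initialization std, each embedding's std becomes close to 1.
    - "Embed LN": apply LayerNorm directly to the embedding output.
    Either change targets the same condition (embedding std close to 1).
    """

    def __init__(self, vocab_size: int, d_model: int, mode: str = "scaled"):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        # Common LLM initializations use a small std (e.g., 0.02).
        nn.init.normal_(self.embed.weight, mean=0.0, std=0.02)
        self.mode = mode
        self.scale = math.sqrt(d_model)
        self.embed_ln = nn.LayerNorm(d_model) if mode == "ln" else None

    def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
        h = self.embed(token_ids)
        if self.mode == "scaled":   # Scaled Embed
            return h * self.scale
        if self.mode == "ln":       # Embed LN
            return self.embed_ln(h)
        return h                    # Vanilla
```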
L309: C4 is too small to be the pretraining data for LLM.
The essential contribution of our paper, the theoretical analysis of stability during pre-training, does not depend on the size of the training data. Additionally, we would appreciate it if the reviewer could consider the computational costs required to train each model. We would like to re-emphasize that constructing each 1.7B parameter model costs approximately $5,120.
L317: Why not use a common public architecture (e.g., LLama, Mixtral, Gemma) for reproducibility? Is there any common pretraining baseline in the literature (e.g., Dettmers et al. (2022) and Le Scao et al. (2022))?
As mentioned in the response to the weaknesses, our focus is on the Pre-LN Transformer architecture, which is a common architecture for LLMs.
Figure 3 (a): The lines overlap, so it is hard to see what happened. Is there another metric, such as the number of spikes, beyond visualization?
We consider the number of spikes to be a subjective indicator, as it can be defined in a way that is unfavorable to Vanilla. Furthermore, not only Figure 3 (a) but also Figures 3 (b), 5, and 7 support our theoretical analysis. We would appreciate it if the reviewer could confirm these figures.
L367: Only in the embed layer?
Yes, Embed LN applies layer normalization to embeddings, as per our theoretical analysis. We would appreciate it if the reviewer could provide more details regarding this comment for further discussion.
The training risk mentioned in the abstract is not addressed in the experiment. We can only see some marginal improvement over perplexity. In the abstract, it indicates that 'if we don't take spikes seriously, the whole effort will be in vain due to catastrophic divergence.' However, both vanilla and embed detach do not suffer from catastrophic divergence in the experiment results. Section 6.1: How does the learning rate affect the gradient norm? How does the learning rate affect the four baseline methods? Did you indicate that stabilization can be achieved by a smaller learning rate? If so, there is no actual need for "Scaled Embed" or "Embed LN"? In Figure 5, we can see that vanilla performs better than Scaled Embed. To achieve better final performance as depicted in Table 3, can we use a small learning rate to travel through the risky early training stage and increase the learning rate later to avoid spiking?
Figures 1, 5 (d), and 9 (a) clearly show that the model violating the conditions we provided, Vanilla, diverged during pre-training. Moreover, we would like to emphasize that while a small learning rate stabilizes pre-training, it significantly degrades performance, as shown in Table 3.
Section 6.2: Good point! I expect a similar discussion in Section 6.1 to explain why a smaller learning rate can reduce spikes. Please remind me if I missed something. Also, this indicates the settings of 'short seq in the early stage and long seq later.' Can we say that 'small lr in the early stage and large one later' is also possible?
We do not think so. Section 6.1 discusses the relationship between our theoretical analysis and Li et al., 22. Specifically, we provided a theoretical justification for the strategy of using a short sequence to stabilize training and empirically demonstrated stability when varying the input sequence length. However, we did not address whether it is necessary to extend the input sequence length during pre-training, as proposed by Li et al., 22, because the purpose of our study is not to re-examine their work. Furthermore, we are unable to link our findings to the strategy of varying the learning rate during pre-training.
L490: Can you use your theory to explain why preLN is more stable than post-LN?
No, the difference between Pre-LN and Post-LN is beyond the scope of our study. As stated in lines 490--492, Xiong et al., 2020 proved that the Pre-LN Transformer is more stable than the Post-LN Transformer.
L509: The efficiency and learning rate discussion is a bit too far and not quite relevant.
As described in the response to the weaknesses, our study is useful for constructing a better model within a limited budget while avoiding pre-training failure. Therefore, we believe that our study can be regarded as contributing to efficient pre-training of LLMs.
Thanks to the authors for the detailed responses to the questions; I don't have any further questions at this point.
L134: You should mention the single-head attention and F as the identity function (linear settings) in the abstract and introduction.
Line 134 states, "This assumption is valid when we initialize parameters with the normal distribution, the number of heads in Equation (4) is 1, and F is an identity function", describing the conditions under which our assumption holds. However, we did not restrict our analysis to single-head attention or to F as the identity function. In Section 3.2, we discussed multi-head attention, and Appendix G presents the case where ReLU is used as F.
L134: Please plot the distribution of x, x', and W, and note the standard deviation and mean of each layer to support your assumption. In Appendix F, there are only plots for initialized models. I could not find plots for W or during training and pretrained models.
Figure 12 in Appendix F shows the distributions of x and x', both of which follow a normal distribution. Each W was initialized with a normal distribution as described in Sections 4 and 5, and therefore clearly follows a normal distribution without requiring confirmation through plotting.
L160: What are d and d_ffn? Did you mention them somewhere?
For d and d_ffn: d is defined in line 76 ("d denotes the dimension of the layer"), and d_ffn is defined in lines 92--93 as the internal dimension of the FFN sub-layer, introduced there together with the internal dimension of the multi-head self-attention sub-layer.
L167: This is an estimation as shown in L159. Please be consistent.
We thank the reviewer for finding this inconsistency. We will replace "=" with "$\approx$".
L174: It's unclear what this well-known formula is.
As described in lines 174--176, we provided the equation that transforms the variable x', which follows a normal distribution, into the variable z, which follows the standard normal distribution: $z = (x' - \mu)/\sigma$, where $\mu$ and $\sigma$ are the mean and standard deviation of x'.
Eq 13: Why is variance the degree of freedom?
We apologize, but we are still unable to understand the intent of this question, and we would appreciate it if the reviewer could provide more details. Since we interpret this comment as a question about the degrees of freedom, we would like to clarify. As stated in our paper, the squared standardized variable follows a $\chi^2$ distribution with 1 degree of freedom, and the variance of such a distribution is equal to 2. This is a property of the $\chi^2$ distribution, and the statement aligns with this property. For the cross terms, i.e., the product of two independent variables that each follow the standard normal distribution, the variance is equal to 1.
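For completeness, both variance values follow from standard moments of the normal distribution:
$
\mathrm{Var}(z^2) = \mathbb{E}[z^4] - \bigl(\mathbb{E}[z^2]\bigr)^2 = 3 - 1 = 2 \ \text{ for } z \sim \mathcal{N}(0,1), \qquad \mathrm{Var}(z_1 z_2) = \mathbb{E}[z_1^2]\,\mathbb{E}[z_2^2] - \bigl(\mathbb{E}[z_1 z_2]\bigr)^2 = 1 - 0 = 1 \ \text{ for independent } z_1, z_2 \sim \mathcal{N}(0,1).
$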
Condition 2 of large shortcut comes from assumption 1.
Is this sentence intended to be a question? We are unable to understand the reasoning of the reviewer.
L219, Eq 19: Where does the 2 come from?
We obtained it by applying the description in lines 158--159 to the corresponding term in Equation (19), which yields the factor of 2.
L234: The writing in Section 4 could follow the format in Section 5.2 to improve readability.
We thank the reviewer for their feedback, but we are confused as we are unable to understand the intent. We would appreciate it if the reviewer could provide more details.
L250: Both "Scaled Embed" and "Embed LN" are existing approaches, not introduced by this paper.
As mentioned in line 543, "We do not aim to propose a novel method in this paper", we did not claim to propose a novel method. Moreover, we would like to emphasize that Scaled Embed and Embed LN were proposed in previous studies; however, they are not standard practices in the pre-training of LLMs.
L291: Why should it be close to 1 and not as large as possible? If so, why is it called 'large shortcut' in the previous sections? Similarly, should the sublayer be as small as possible or close to some value? Why do we scale the embed by some value but not a larger one? What is the optimal scale from your theory?
As mentioned in the response to the summary, our motivation is to construct a model that performs as well as possible while avoiding the failure of pre-training, as the pre-training of LLMs requires an enormous amount of computational resources. Appendix E demonstrates that setting the standard deviation of the embeddings to a value larger than 1 degrades the performance. Therefore, we do not recommend using a significantly larger value.
Then, we would like to address the questions, but we would greatly appreciate it if the reviewer could provide more details for certain questions to facilitate a fruitful discussion.
The relationship between spikes and poor performance is not clearly established. Some loss spikes can be recovered during training, while others cannot. Are you assuming that spikes are (1) always bad, (2) risks that result in divergence in some cases, or (3) sometimes acceptable and not harmful?
We assume (1). In fact, our experimental results indicate that loss spikes negatively impact performance, even if a model recovers from them. For example, in the 1.7B parameter models, Vanilla and Embed Detach, which frequently encounter loss spikes, underperformed compared to Scaled Embed and Embed LN, as shown in the above table.
L40: What is catastrophic divergence? Does it mean the spike never goes down again?
Yes, we used it to mean that the loss never goes down again after the spike, as illustrated by Vanilla in Figure 1.
What model is used in Figure 1?
As mentioned in the caption of Figure 1, we used the 1.7B parameter models to illustrate the figure, with the same learning rate as in Figure 5 (d); in short, the configuration is identical to that of Figure 5 (d). However, since Figures 1 and 5 (d) show the loss curves on the training data and validation data, respectively, the two figures are not identical.
What is the difference between spikes in LLMs and other smaller models?
If "other smaller models" refers to Transformer-based decoder-only models, we believe our findings are essentially applicable. However, we would need more details about the "other smaller models" for further discussion.
L47: The assumption between spikes and gradient norm is unclear. What did you observe in L121?
We explained in lines 119--120 that "In our preliminary experiments, when the gradient norms grow suddenly during LLM pre-training, we observe that the loss spike problem is likely to occur". This correlation is demonstrated in the pairs of Figures 3 and 4, as well as Figures 7 and 8.
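As a side note, this correlation is straightforward to monitor during training; the following sketch (a hypothetical helper, not part of our implementation) logs the global gradient norm after each backward pass:

```python
import torch

def global_grad_norm(model: torch.nn.Module) -> float:
    """Return the L2 norm over all parameter gradients (call after backward())."""
    norms = [p.grad.detach().norm(2) for p in model.parameters() if p.grad is not None]
    if not norms:
        return 0.0
    return torch.norm(torch.stack(norms), 2).item()

# Usage inside a training loop (sketch):
#   loss.backward()
#   g = global_grad_norm(model)
#   if g > spike_threshold:  # e.g., several times a running average of g
#       print(f"step {step}: sudden growth of the gradient norm: {g:.2f}")
#   optimizer.step()
```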
L53: What is the standard deviation of embedding means? Does this equate to a large shortcut in the later sections?
For the former question, we used "embedding" to refer to the vector representation of an input subword. Thus, "the standard deviation of embedding" refers to the standard deviation of this vector representation. For the latter question, yes, we explained in lines 291–292 that "(2) large shortcut; making the standard deviation of each embedding close to 1."
L63: Evaluation is not explained.
We interpret this comment as a critique that the Introduction lacks details of experiments. However, we would appreciate it if the reviewer could confirm Sections 5 and 6, which describe our experiments in detail.
L83: Conventionally, does shortcut mean residual?
No. In the residual block $x + \mathcal{F}(x)$, $x$ refers to the shortcut, and $\mathcal{F}(x)$ refers to the residual branch. We kindly ask the reviewer to refer to footnote 3 of Liu et al., 20 as an example.
Can you comment on the parallel FF+Attn setting where there is no intermediate vector? It is broadly used in many models, e.g., Pythia, Mesh-Transformers, and PaLM.
Our theoretical conclusion remains valid in the parallel FFN+Attn setting. We can easily derive an upper bound as follows:
$
\left\| J_{\mathrm{FFN}+\mathrm{Attn}} \right\|_2 = \left\| \frac{\partial \left( x + \mathrm{FFN}(\mathrm{LN}(x)) + \mathrm{Attn}(\mathrm{LN}(x)) \right)}{\partial x} \right\|_2 = \left\| I + \frac{\partial \mathrm{FFN}(\mathrm{LN}(x))}{\partial x} + \frac{\partial \mathrm{Attn}(\mathrm{LN}(x))}{\partial x} \right\|_2 \leq 1 + \left\| \frac{\partial \mathrm{FFN}(\mathrm{LN}(x))}{\partial x} \right\|_2 + \left\| \frac{\partial \mathrm{Attn}(\mathrm{LN}(x))}{\partial x} \right\|_2
$
using the same procedure as for Equations (8) and (9). For the upper bounds of $\left\| \frac{\partial \mathrm{FFN}(\mathrm{LN}(x))}{\partial x} \right\|_2$ and $\left\| \frac{\partial \mathrm{Attn}(\mathrm{LN}(x))}{\partial x} \right\|_2$, we have already estimated them with Equations (15) and (20), respectively.
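For reference, a schematic PyTorch sketch of this parallel residual block (a simplification with hypothetical sub-module names, matching the equation above rather than the exact implementation of any of the cited models) is:

```python
import torch
import torch.nn as nn

class ParallelBlock(nn.Module):
    """Parallel FFN + Attn residual block (PaLM/Pythia-style layout).

    Both sub-layers read the same normalized input and their outputs are added
    to the shortcut x, so there is no intermediate vector between them.
    """

    def __init__(self, ln: nn.Module, attn: nn.Module, ffn: nn.Module):
        super().__init__()
        self.ln, self.attn, self.ffn = ln, attn, ffn

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.ln(x)
        return x + self.attn(h) + self.ffn(h)
```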
clear writing in section 2.1, 2.2 L130: What is the difference between W_* and W?
We interpret this comment as asking whether one of the two notations in line 131 is a typo of the other. Yes, it is; we thank the reviewer for identifying this typo.
We thank the reviewer for reading our paper and feedback. Before addressing each question, we would like to explain the motivation behind our study and address main concerns as follows to clarify our contributions:
Most importantly, the risk of training LLMs is reduced to evaluating performance on benchmark datasets, while the risk of wasted computing resources is not well addressed; this unfortunately requires substantial improvement.
Based on the weaknesses and questions raised in the review, we believe the reviewer may not have fully recognized the computational cost of LLM pre-training. Therefore, we would first like to explain the associated costs. As an example, we focus on the 1.7B parameter model used in our experiments. Training each model in our configuration required approximately 1,000 A100 80GB GPU hours. Given the AWS EC2 pricing of $40.96 per hour for an 8-GPU A100 80GB instance (i.e., $5.12 per GPU hour), training each model costs approximately $5,120.
Considering this substantial expense, it is crucial to construct a model that performs as well as possible while avoiding the failure of pre-training. This motivation underpins our approach. In fact, the worst-case scenario is a failure of pre-training due to the loss spike, as shown in Vanilla in Figures 1, 5 (d), and 9 (a). Such failures result in wasted costs.
In contrast, models that satisfy our provided conditions successfully completed pre-training using the same hyper-parameters.
The authors provide theoretical support to reduce the likelihood of divergence based on stabilization via small sub-layer initialization and large shortcut values. They then propose two techniques to effectively support this assumption and evaluate accordingly.
Strengths
The paper conducts mathematical analysis to demonstrate the requisite terms they later leverage. Paper is clear and provides actionable results.
Weaknesses
The authors only tested on smaller models; it is well established that most instability problems happen with larger models (>100B parameters). It would be beneficial to evaluate the loss curves on larger models or more diverse datasets.
Although the focus of this paper was to stabilize training, they underperform on loss curves compared to vanilla approaches to disprove Le Scao et al.'s findings. This is hypothesized to be related to learning rates, which is demonstrated by looking at an absolute minimum score. However, by only exploring LR adjustments on smaller models, it isn't immediately clear that the proposed method is consistently better, as sub-optimal LRs are more stable; an LR scheduler could account for differences in the long term.
Questions
Were the approaches evaluated on benchmarks post-training?
We thank the reviewer for reading our paper and feedback. We would like to address the weaknesses and questions as follows:
The authors only tested on smaller models; it is well established that most instability problems happen with larger models (>100B parameters). It would be beneficial to evaluate the loss curves on larger models or more diverse datasets.
We theoretically demonstrated that larger models are more unstable, as the upper bounds of the gradient norms for each sub-layer, shown in Equations (15) and (20), depend on the dimension size.
On the other hand, we would like to emphasize that it is impractical to conduct experiments with extremely large models, such as those exceeding 100B parameters, as suggested by the reviewer, due to the tremendous budget requirements for such experiments. For example, at least $25 million USD would be required to train two 100B parameter models (a baseline and a stable method) with 2T tokens, which is the compute-optimal training data size suggested by Hoffmann et al., 22. We estimated this cost as follows:
- Training a single 100B parameter model requires approximately 2,457,600 GPU hours on A100 80GB GPU, based on the Llama2 paper (https://arxiv.org/abs/2307.09288).
- The AWS EC2 cost, for instance, with 8 A100 80GB GPUs, is $40.96 per hour (source: https://aws.amazon.com/ec2/instance-types/p4/).
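Combining these two figures gives the following back-of-the-envelope calculation:
$
2 \times 2{,}457{,}600 \ \text{GPU hours} \times \frac{\$40.96 \ \text{per hour}}{8 \ \text{GPUs}} = \$25{,}165{,}824 \approx \$25.2 \ \text{million USD}
$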
Therefore, the estimated cost to train two 100-billion-parameter models (baseline and stable method) is approximately $25.2 million USD. This estimation underscores the substantial cost involved in training such large models.
In addition, as described in Appendix C, we conducted experiments with 13B parameter models. This configuration represents the large model in studies focused on the stability of LLM pre-training. For example, Wortsman et al., 24 conducted their experiments with models under 10B parameters.
Moreover, we would like to emphasize that our study provides more than just empirical findings; the essential contribution is the theoretical approach to prevent the loss spike problem. Therefore, we believe that experiments with models exceeding 100B parameters are not indispensable. We would appreciate it if the reviewer could take into account that most researchers lack access to large-scale computing resources, as noted in the review guidelines of a recent conference: https://colmweb.org/ReviewGuide.html.
Although the focus of this paper was to stabilize training, they underperform on loss curves compared to vanilla approaches to disprove Le Scao et al.'s findings. This is hypothesized to be related to learning rates, which is demonstrated by looking at an absolute minimum score. However, by only exploring LR adjustments on smaller models, it isn't immediately clear that the proposed method is consistently better, as sub-optimal LRs are more stable; an LR scheduler could account for differences in the long term.
We have read this comment multiple times, but we are still unable to understand its main point. Therefore, we will respond based on the assumption that the reviewer is raising a question about the learning rate scheduler.
We used the cosine learning rate scheduler, which is a standard choice in the pre-training of LLMs, as demonstrated in the Llama series (https://arxiv.org/abs/2307.09288) and Le Scao et al., 2022. In addition, we tested several schedulers in our preliminary experiments, and the Vanilla achieved the best performance with the cosine scheduler. If the reviewer has a suggestion for a more sophisticated scheduler, we would greatly appreciate it if the reviewer could share the idea with us.
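For reference, the warmup-plus-cosine recipe we refer to can be sketched generically as follows (the standard formula, not our exact hyper-parameters):

```python
import math

def cosine_lr(step: int, max_lr: float, min_lr: float,
              warmup_steps: int, total_steps: int) -> float:
    """Linear warmup followed by cosine decay, the common LLM pre-training schedule."""
    if step < warmup_steps:  # linear warmup
        return max_lr * (step + 1) / warmup_steps
    progress = min(1.0, (step - warmup_steps) / max(1, total_steps - warmup_steps))
    return min_lr + 0.5 * (max_lr - min_lr) * (1.0 + math.cos(math.pi * progress))
```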
Were the approaches evaluated on benchmarks post-training?
This paper focuses on stabilizing the pre-training phase, and as such, post-training is outside its scope. However, as shown in Appendix E of Fedus et al., 2021, negative log-likelihood correlates with the performance of downstream tasks. Therefore, we believe that a model achieving better negative log-likelihood during pre-training, such as Scaled Embed in our experiments, is also likely to achieve better performance after post-training.
The goal is not to require >100B models, but for the paper to provide evidence of more separation between the Vanilla and Scaled Embed approaches beyond 1B, aligning with the intuition that the results here may not apply in larger model schemas. For instance, in Table 6 we see that for a given lr and identical sizes at 13B, Vanilla outperforms Scaled Embed. Thus, it is not clear whether the gap could be addressed by other traditional techniques, raising the question of the applicable benefits of the proposed method against standard schemas for larger models. More comparison against these techniques, or evaluation on larger models to understand/support the technique, is recommended.
We thank the reviewer for the reply.
We would like to emphasize that our study addresses the loss spike problem, and thus, the purpose of the experiments is to demonstrate whether a model satisfying the conditions from our theoretical analysis is stable. In fact, the reviewer noted in the initial review that "it is well established that most instability problems happen with larger models (>100B parameters)". Therefore, we believe it is appropriate to focus primarily on stability during pre-training.
In the experiments with 13B parameter models, as shown in Figure 9 (a), the pre-training of Vanilla failed due to the loss spike when we used the larger learning rate. In contrast, the model satisfying our provided conditions, Scaled Embed, successfully completed pre-training with the same learning rate. This result indicates that our theoretical analysis is also effective in stabilizing the pre-training of 13B parameter models.
Since the pre-training of LLMs requires a vast amount of computational resources, it is desirable to use a larger learning rate during pre-training to construct a better model within a limited budget. Thus, stabilization at large learning rates is crucial. In other words, criticisms regarding performance when using a small learning rate to stabilize Vanilla are outside the primary scope of our study, as such learning rates degrade performance.
We interpret that the reviewer has focused on the results with the smaller learning rate and criticized that Scaled Embed underperformed Vanilla. However, Scaled Embed with the larger learning rate outperformed Vanilla in terms of perplexities on the WikiText and LAMBADA datasets (lower perplexity is better). For example, we extracted the results for the LAMBADA dataset from Table 6 as follows:
| Model | Larger lr | Smaller lr |
|---|---|---|
| Vanilla | N/A | 6.50 |
| Scaled Embed | 5.97 | 6.53 |
For the larger lr, we write "N/A" for Vanilla because it diverged during pre-training. Moreover, for the smaller lr, we consider the performance difference to be insignificant. For instance, the perplexities of Vanilla and Scaled Embed on LAMBADA are 6.50 and 6.53, respectively, as shown in the above table. Based on these results, we believe the modification based on our theoretical analysis stabilizes pre-training and poses no significant risks to performance, as mentioned in lines 890--891.
This paper presents a strategy to avoid spikes in loss during the training of LMs by keeping the gradient norm small. To manage the upper limit of gradient norms effectively, the method involves (i) using small initial values for sub-layers, and (ii) maintaining the standard deviation of embeddings around 1.
Strengths
- The findings (i) and (ii) from the analysis are well presented, although they have been previously utilized in past studies.
- This work examines various learning hyperparameters.
- This work also presents the results for the 13B model in Table 6 of the Appendix.
- The paper is well-written and easy to understand.
Weaknesses
- Although the theoretical analysis is intriguing, I question the practical value of this work, as most practices described in Section 4 are already in use. Utilizing small values for initialization to ensure stable training is well known, and both Scaled Embed and Embed LN have been introduced in prior literature. If this work could offer a novel, advanced method for embedding normalization, it might receive more interest from the community.
- The activation function F was assumed to be either an identity function or ReLU, as stated on line 152 of page 3. What would be the results if widely used activation functions in recent LLMs, such as SiLU and SwiGLU, were applied?
- I am curious about how loss spikes impact the performance of downstream tasks on LLM leaderboards, beyond just affecting perplexity. Are these spikes also harmful to the accuracy of downstream tasks?
- I believe it would be beneficial to conduct a theoretical analysis of the relationship between learning rate, loss spikes, and model sizes. This suggestion stems from the observation that the learning rates causing loss spikes differ according to model size.
Questions
Please see the above weaknesses section.
Details of Ethics Concerns
n/a
We thank the reviewer for reading our paper and feedback. We would like to address the weaknesses as follows:
Although the theoretical analysis is intriguing, I question the practical value of this work, as most practices described in Section 4 are already in use. Utilizing small values for initialization to ensure stable training is well known, and both Scaled Embed and Embed LN have been introduced in prior literature. If this work could offer a novel, advanced method for embedding normalization, it might receive more interest from the community.
We appreciate the reviewer's recognition of the value of our theoretical analysis of the loss spike problem. We would like to emphasize that this paper does not aim to propose a method as stated on line 543 of our paper; rather, it aims to understand why the loss spike problem occurs during the pre-training of LLMs and to provide conditions to prevent it. As described in "Novelty is in the Eye of the Beholder" (https://drive.google.com/file/d/1ydN247sEXjnP0P_JByf287ifXNcdoIBM/view), introduced in the ICLR 2022 Reviewer Tutorial (https://icml.cc/Conferences/2022/ReviewerTutorial), we believe that the novelty of a research paper is not limited to proposing “a novel method” but also includes offering insights through analysis and understanding, as our paper does.
In addition, we would like to emphasize that using Scaled Embed or Embed LN is not currently standard practice in the pre-training of LLMs. This can be verified by checking widely used open-source implementations for LLM pre-training, such as Megatron-LM and LLMFoundry, neither of which includes Scaled Embed or Embed LN as part of its functionality. This fact indicates that current LLM pre-training processes are generally conducted under conditions that conflict with the requirements we introduced to prevent the loss spike problem: (1) initializing the parameters of sub-layers with small values, and (2) ensuring the standard deviation of each embedding is close to 1. As a result, publicly available pre-trained models may potentially suffer from reduced performance and pre-training efficiency due to, for example, the use of smaller-than-ideal learning rates aimed at avoiding loss spikes.
Moreover, to show that the conditions we introduced for preventing loss spikes can be met easily without requiring major changes to the standard Transformer architecture, we intentionally used the simplest and most conventional methods available. The fact that loss spikes can be prevented with such simple modifications is expected to encourage many researchers and developers conducting LLM pre-training to adopt our conditions, potentially resulting in significant practical and societal impact. This is why we chose straightforward methods. Once again, our contribution lies not in proposing novel methods but in providing a theoretical justification for previously unresolved observations regarding why loss spikes occur, along with simple solutions to prevent them.
Thus, we believe our theoretical analysis of the loss spike problem, along with the empirical validation, offers valuable insights for the research community.
The activation function F was assumed to be either an identity function or ReLU, as stated on line 152 of page 3. What would be the results if widely used activation functions in recent LLMs, such as SiLU and SwiGLU, were applied?
We are planning to work on the proofs for other activation functions as well. While setting aside the detailed proofs for the upper bounds, we believe the overall theory will likely still hold.
I am curious about how loss spikes impact the performance of downstream tasks on LLM leaderboards, beyond just affecting perplexity. Are these spikes also harmful to the accuracy of downstream tasks?
As shown in Appendix E of Fedus et al., 2021, perplexity (i.e., negative log-likelihood) strongly correlates with the performance of downstream tasks. Consequently, when a loss spike degrades perplexity, it also reduces the accuracy of downstream tasks. Additionally, as demonstrated in our paper, loss spikes can sometimes ruin the pre-training process. Therefore, addressing the loss spike problem is crucial.
I believe it would be beneficial to conduct a theoretical analysis of the relationship between learning rate, loss spikes, and model sizes. This suggestion stems from the observation that the learning rates causing loss spikes differ according to model size
We have already demonstrated, from a theoretical perspective, that larger models are more unstable because the upper bounds of the gradient norms for each sub-layer, as shown in Equations (15) and (20), depend on the dimensionality. Due to this property, we consider it necessary to set a smaller learning rate to prevent the growth of each parameter in larger models. However, we would like to emphasize that we validated our theoretical findings for stabilizing the pre-training of LLMs through various experiments.
Thank you for the time and effort the authors have invested in this rebuttal. I appreciate the additional explanation that 'using Scaled Embed or Embed LN is not currently standard practice in the pre-training of LLMs,' with detailed examples such as Megatron-LM and LLMFoundry. I also agree with the point that 'the novelty of a research paper is not limited to proposing a novel method, but also includes offering insights through analysis and understanding.' Therefore, I am increasing my score from 3 to 5, as I believe this paper delivers a meaningful message for achieving effective LLM training.
However, I am still concerned about the lack of investigation into other activation functions and the somewhat limited scope of the experiments. It is now common practice to present results not only for PPL but also for downstream tasks. Although the authors provide a supporting reference, I am unsure if the differences in PPL between the models of this work are significant enough to impact downstream task performance. Considering that this paper focuses on small-scale models, conducting additional experiments on other benchmarks with further fine-tuning (using cost-effective techniques like LoRA) could enhance the experimental validity of the approach. Alternatively, the authors could present downstream task scores of the pretrained models without further training.
Additionally, concerning the remark, 'we consider it necessary to set a smaller learning rate to prevent the growth of each parameter in larger models,' it would be helpful to specify what range is considered "small" for learning rates. Detailed analysis or guidance on the relationship between learning rate and loss spikes would also be beneficial.
I feel that many recent papers with theoretical contributions are also supported by sufficient experimental evidence. However, it remains unclear to me if this paper meets such criteria.
We thank the reviewer for the reply. We appreciate the reviewer for acknowledging the factual situation regarding LLM pre-training and for sharing the opinions once again.
- For other activation functions:
The reviewer requested that we discuss the SiLU and SwiGLU functions as activation functions in the FFN layer. We have added a theoretical analysis of the SiLU and SwiGLU functions to Appendix G. As shown in Appendix G, we can control the upper bound of the gradient norms by making the standard deviations of the parameters small and making the standard deviation of each embedding close to 1, similar to the cases with the identity and ReLU functions. In short, the conditions we provided are also applicable to the SiLU and SwiGLU functions.
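For reference, the standard definitions of these activation functions (written in generic notation rather than the notation of Appendix G) are:
$
\mathrm{SiLU}(x) = x \cdot \sigma(x), \qquad \mathrm{SwiGLU}(x) = \mathrm{SiLU}(xW) \otimes (xV),
$
where $\sigma$ is the logistic sigmoid, $\otimes$ denotes element-wise multiplication, and $W$ and $V$ are the two input projections of the gated FFN.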
- For downstream task performance:
We interpret the reviewer's concern as questioning whether our theoretical analysis leads to improvements in downstream tasks. However, we would like to emphasize that this study primarily addresses the loss spike problem to stabilize the pre-training of LLMs. Our goal is to provide conditions to prevent the loss spike through theoretical analysis, but this does not directly improve the performance. Therefore, the main purpose of our experiments is to verify whether models satisfying our provided conditions, such as Scaled Embed, are more stable than models that violate the conditions, such as Vanilla. As important evidence, Figures 1, 5 (d), and 9 (a) demonstrate that the pre-training of Vanilla failed due to the loss spike. In contrast, Scaled Embed succeeded in pre-training. We believe these results validate our theoretical analysis for stabilizing LLM pre-training.
On the other hand, although we reported PPL on the WikiText and LAMBADA datasets, we agree that the performance on downstream tasks is useful as additional information. We understand that the reviewer is particularly interested in larger models, and thus we investigated the performance of the 13B parameter models. The following table shows the downstream-task performance of Vanilla with the smaller learning rate and Scaled Embed with the larger learning rate. We would like to note that Vanilla with the larger learning rate diverged during pre-training due to the loss spike.
| Model | PIQA | OpenBookQA | HellaSwag | Winogrande | ARC-easy | ARC-challenge |
|---|---|---|---|---|---|---|
| Vanilla | 77.80 | 38.40 | 69.10 | 61.64 | 57.95 | 33.53 |
| Scaled Embed | 78.94 | 39.20 | 71.03 | 63.77 | 60.31 | 35.49 |
This table shows that Scaled Embed, which achieved better PPL on the WikiText and LAMBADA datasets, also outperformed Vanilla in downstream tasks.
- For the learning rate:
It is a common strategy to choose a smaller learning rate as the model size increases, as shown in Table 2.1 of the GPT-3 paper. Practically, the learning rates widely used for LLM pre-training fall within a fairly narrow range, as demonstrated in the GPT-3 paper and the Llama series.
On the other hand, we believe that specifying the optimal value for every configuration is challenging, as it is impractical to account for all factors related to the pre-training of LLMs. For example, various hyper-parameters beyond initialization and Transformer architecture, such as the learning rate scheduler, the number of training tokens, and the input sequence length (as discussed in Section 6.2), must also be considered. Additionally, we believe that the state of the model during pre-training depends on the properties of the training data. Therefore, exploring the optimal learning rate for every configuration is beyond the scope of our study.
I implemented this method in my fork of Fish-Speech, and it has proven to be highly effective. The previously observed spikes in gradient norm and loss have vanished. My model, a 0.5B AR transformer-based text-to-speech system, has greatly benefited from this method.
We greatly appreciate your positive feedback and are delighted to hear that our findings are useful for a text-to-speech model!
Dear Reviewers and Area Chairs,
We sincerely thank the Reviewers and Area Chairs for taking the time to review our paper.
Although some reviewers have not responded to our previous comments yet, we believe that we have clarified the main contributions of this paper and addressed the reviewers’ concerns. In this post, we summarize the revised sections based on the discussion and highlight the main contributions of this study.
Revised sections
- We have added a theoretical analysis of SiLU and SwiGLU as activation functions in the FFN layer to Appendix G.
- We indicated that our provided conditions can also be applied to SiLU and SwiGLU, which are used in recent LLMs such as the Llama series and Qwen series, to stabilize pre-training.
- We have corrected several typographical errors.
Main contributions
- We propose a novel theoretical framework for stabilizing the pre-training of LLMs by identifying an upper bound for the gradient norms of each Transformer layer.
- To address the loss spike problem, we provided two conditions to suppress the upper bound and introduced a combination of previously proposed techniques to satisfy the conditions.
- As described in our response, while each individual technique has been introduced in prior studies, the combination of techniques to satisfy our provided conditions is not a standard practice in LLM pre-training. Consequently, LLM pre-training is generally conducted under conditions that violate the conditions necessary to prevent the loss spike problem.
- Experimental results demonstrate that models satisfying these conditions successfully avoided loss spikes. In contrast, models that violated the conditions experienced loss and gradient spikes, with pre-training sometimes failing due to the loss spike, as shown in Figures 1, 5 (d), and 9 (a).
The paper analyses the important question of divergence in optimizing the parameters of large neural language models. The authors investigate the high gradient norms as the cause for loss spikes, and discuss two mitigation strategies in terms of having small sub-layers and large shortcuts.
While the considered problem is of great interest, the paper has a series of shortcomings. The reviewers raised questions on the limited analysis of the work in understanding the effect between spikes and poor model performance. Furthermore, some sections of the paper are not clear enough, in particular in the evaluation section.
Reviewers expressed concerns about the significance of the findings. In particular, there are questions on whether the gains of the proposed remedies occur only when the model diverges. In that context, authors are advised to investigate the benefit of the small sub-layers and large shortcuts also in cases when the optimization hyper-parameters are carefully tuned to avoid a divergence.
Considering the mentioned limitations I recommend rejection, but simultaneously encourage reviewers to consider the provided comments in improving the paper.
Additional Comments on Reviewer Discussion
Overall the authors and reviewers engaged positively during the rebuttal, except for a discussion point on the necessity of model sizes. The AC agrees with the authors' argument that very large model sizes are not a must requirement for acceptance. However, I also see merit in the remaining arguments of the reviewers concerning the quality of the work.
Reject