Spike No More: Stabilizing the Pre-training of Large Language Models
Theoretical analysis to prevent loss spikes during LLM pre-training
Abstract
Reviews and Discussion
This paper investigates the factors that maintain a small gradient norm for LLM pretraining by analyzing the spectral norms of Jacobian matrices across sub-layers, under the assumption that loss spikes are caused by a sudden increase in gradient norm. The authors find that stabilizing the pre-training process requires two key conditions: using smaller sub-layers and incorporating larger shortcut connections. They conduct a range of experiments to empirically validate their theoretical insights. Results show that methods satisfying these conditions effectively prevent loss spikes during pre-training.
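As a schematic illustration of these two conditions (the notation here is generic and not taken from the paper, whose precise bounds may differ): for a Pre-LN residual block $y = x + F(\mathrm{LN}(x))$, the chain rule gives

```latex
% Schematic sketch only; the paper's exact bounds may differ.
\frac{\partial y}{\partial x} = I + J_{F}\, J_{\mathrm{LN}},
\qquad
\Bigl\lVert \frac{\partial y}{\partial x} \Bigr\rVert_2
\;\le\; 1 + \lVert J_{F} \rVert_2 \, \lVert J_{\mathrm{LN}} \rVert_2 .
```

Since the spectral norm of the layer-normalization Jacobian shrinks roughly in proportion to the inverse of the standard deviation of its input, keeping the sub-layer Jacobian $J_F$ small (smaller sub-layers) and the residual stream $x$ large (larger shortcuts, e.g., via larger embeddings) keeps each block's Jacobian close to the identity, and hence keeps the backpropagated gradient norm bounded across layers.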
Reasons to Accept
Understanding and addressing loss spiking is a highly important and timely topic, particularly for stabilizing LLM pre-training.
Reasons to Reject
- The proposed solutions and findings are not entirely convincing. For instance, one of the suggested strategies, initializing sub-layer parameters with smaller values, oversimplifies the problem. Prior work has shown that loss spiking results from the interaction of training hyperparameters, data, and model state. In some cases, spiking occurs even when the data and configuration seem correct. Simply using smaller initialization values may not generalize. Moreover, the identification of sub-layers is highly data-dependent, and different data shuffling orders can yield different sub-layer behaviors. Also, since loss spikes typically occur after the warm-up phase, it's unclear whether smaller initializations have any effect at that point. These issues raise serious doubts about the reliability of both the method and the conclusions in this paper.
- The experimental setup is insufficient and seems incorrect. The authors rely primarily on the C4 dataset and perform most experiments on 350M and 1.7B models, with only limited testing on a 13B model. However, as noted in the LLaMA paper, significant loss spikes tend to appear at larger scales (e.g., 33B and 65B). For 13B models, such spiking is generally mild. This casts serious doubt on whether the experimental setup in this paper is accurately configured and reflects the loss spiking phenomena that the paper aims to study. As a result, the conclusions in this paper are not only unconvincing but also potentially misleading.
We appreciate the reviewer taking the time and effort to review our paper. Below, we address each of the concerns raised.
Reason to Reject 1
We respectfully note that the reviewer criticizes our study without citing any specific literature to support the claim that loss spiking results from the interaction of training hyperparameters, data, and model state. While we agree that training dynamics are influenced by multiple factors, our study provides theoretically grounded conditions for stabilizing the pre-training of large language models, rather than relying on heuristics or empirical tuning.
We also highlight that methods lacking theoretical justification may appear to stabilize training by chance but may in fact be unstable in other settings, as illustrated by the case of Embed Detach. A previous study [1] reported that Embed Detach empirically stabilizes models with over 100B parameters without theoretical justification, but our experimental results suggest that Embed Detach does not fundamentally address the loss spike problem.
Regarding the comment, "Also, since loss spikes typically occur after the warm-up phase, it's unclear whether smaller initializations have any effect at that point": we understand that the reviewer believes initialization has limited impact on training dynamics. However, prior studies [2, 3] have shown that initialization methods affect training behavior throughout training, not just during the early stages. Our experimental results further support this, demonstrating that the methods derived from our theoretical analysis suppress loss spikes even after the warm-up phase, as shown in Figure 3 (b). We would greatly appreciate it if the reviewer could share specific evidence to help facilitate further discussion.
Reason to Reject 2
We respectfully disagree with the reviewer’s claim that our experimental setup is insufficient or incorrect.
First, although the reviewer cites the LLaMA paper to argue that loss spiking is mild in 13B models, we would like to emphasize that the LLaMA paper (https://arxiv.org/abs/2302.13971) does not disclose its full training configuration. For instance, it does not specify the initialization method. Therefore, it is not possible to reproduce the exact same training conditions, and differences in results cannot be directly attributed to flaws in our setup. Furthermore, the LLaMA paper does not aim to provide general findings on the loss spike problem and does not analyze the details of pre-training stability.
Second, our experimental results clearly show that even a 13B-parameter model can suffer from training failure due to loss spikes, as described in Section 5.4, and that the method that meets the provided conditions successfully stabilizes training.
Finally, we note that recent work addressing pre-training stability in large language models also focuses on models in the range of billions of parameters [4, 5]. Notably, [4] includes experiments on models up to 13B parameters, supporting the relevance of our scale.
Taken together, we believe our experimental design is reasonable and valid for studying loss spiking, and the conclusions we draw are well supported by both theoretical analysis and empirical evidence.
[1] https://openreview.net/forum?id=-Aw0rrrPUF
[2] https://openreview.net/pdf?id=H1gsz30cKX
[3] https://proceedings.mlr.press/v119/huang20f.html
I appreciate the authors’ responses. Unfortunately, they did not adequately address my concerns and in fact raised some more:
- The authors cite papers from as early as 2019, using small-scale models (mainly ConvNets in CV), to argue that initialization continues to affect pretraining even after the warm-up phase. The warm-up phase typically uses a very small and gradually increasing learning rate precisely to neutralize the effects of initialization on early convergence in LLM pretraining. The authors' argument is highly questionable in the context of modern LLM pretraining. Another cited paper is from 2020, which is far removed from current large-scale LLM practices and thus not a convincing reference in this context. Applying such outdated and small-scale findings to today's LLMs is frankly misguided.
- The authors attempt to dismiss discrepancies with the LLaMA paper's observations by stating that its full training configuration was not disclosed or open-sourced. This is not a valid justification. LLaMA was trained on a much larger and broader dataset, and its findings are therefore far more generalizable and trustworthy than results based on the small models and limited datasets used in this work. Many groups, including ours, have successfully reproduced, and in some cases even surpassed LLaMA's results. The claims in this paper, based on limited-scale experiments, come across as speculative. Unless the authors can validate their hypotheses on large-scale models and datasets, their insights remain of limited practical value, especially for real-world or commercial LLMs. From our experience, simple improvements like better hyperparameter tuning, optimizer choice, or learning rate schedules are often more impactful than the proposed tricks.
- My intuition is that the loss spiking behavior observed by the authors on their relatively small models is likely due to suboptimal hyperparameter choices or optimizer settings. We have trained many models at a similar scale and have not encountered such severe loss instability. When it does occur, it's usually due to early configuration issues that can be resolved through basic tuning. This leads me to believe that the paper conveys a somewhat misleading conclusion, attributing instability to fundamental flaws when it may simply stem from an improper setup.
Moreover, regarding the other paper mentioned by the authors, "Initialization of Large Language Models via Reparameterization to Mitigate Loss Spikes" (EMNLP 2024 Main), which they use to claim that results on 13B models are sufficient to support their conclusions: I looked into the paper and found it has only two citations so far, one of which is even a translated version of the same work. Given that LLMs are receiving intense attention and widespread research interest, this lack of impact is highly unusual. It further supports my point: without solid, well-substantiated conclusions, such papers contribute little to the field and often bring more confusion than clarity.
This paper addresses the loss spikes that commonly occur during the pre-training of LLMs, hypothesizing that they are caused by sudden growth in gradient norms. The authors theoretically analyze the spectral norms of the Jacobian matrices of sub-layers in Transformer-based LLMs to identify factors that keep gradient norms small, and conduct various experiments to empirically validate the theoretical findings. Results show that methods satisfying the derived conditions effectively prevent loss spikes during pre-training.
Reasons to Accept
- Provides theoretical insights into the cause of loss spikes and conditions for stable training.
- Offers theoretical backing for empirical methods already used in the field.
Reasons to Reject
- It's unclear why preventing these spikes is crucial.
- Some theoretical analyses involve simplifying assumptions, such as using the identity function as the activation function in FFNs, even though the appendix explores other activation functions.
- Although they improve stability, the proposed methods might not always outperform other techniques in terms of final performance, especially at smaller learning rates.
Thank you for reading our paper and recognizing the potential usefulness of our work. We respond to each of the reviewer's concerns as follows.
It's unclear why preventing these spikes is crucial
As shown in Figures 1 and 5, loss spikes may lead to divergence during pre-training. Since the pre-training of large language models requires substantial computational resources, preventing loss spikes is important to minimize the risk of training failure and to avoid wasting those resources. If the reviewer still has concerns, we would be grateful for clarification to facilitate further discussion.
Some theoretical analyses involve simplifying assumptions, such as the identity function for activation functions in FFNs
For the activation function in FFNs, Appendix F discusses widely used activation functions in FFNs, i.e., ReLU, SiLU, and SwiGLU, and shows that they exhibit essentially the same behavior as in the case using the identity function, as described in Section 3.1.
The proposed methods might not always outperform other techniques, especially at smaller learning rates
We would like to emphasize that our primary goal is to stabilize the pre-training of large language models, rather than to maximize final performance. In the experiments with 13B-parameter models described in Section 5.4, Scaled Embed, which met our provided conditions for stabilizing pre-training, was successfully trained with a learning rate of lr = 3.0 × 10^-4, where the Vanilla model failed to train.
Because a model trained with a higher learning rate can achieve better performance, Scaled Embed trained with this learning rate achieved better scores on our benchmark tasks than Vanilla trained with a stable learning rate, as shown in Table 3.
In addition, we would like to emphasize that Scaled Embed also achieved a higher average score than Vanilla on the benchmark tasks using the same lower learning rate (lr = 1.0 × 10^-4): 56.40 for Vanilla and 56.82 for Scaled Embed, respectively. Therefore, the main experimental results indicate that satisfying our provided conditions does not degrade performance on downstream tasks.
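For readers who want a concrete picture of the two embedding-side variants discussed above, here is a minimal sketch under our reading of them (Scaled Embed multiplies the token embeddings by sqrt(d), as in the original Transformer, and Embed LN applies layer normalization to the embedding output); the paper's exact implementation may differ.

```python
import math
import torch
import torch.nn as nn

class EmbeddingFrontend(nn.Module):
    """Sketch (our reading, not the paper's code) of the embedding variants in this thread.

    mode="vanilla":      tok_emb(ids)                  # small per-dimension std (e.g., ~0.02)
    mode="scaled_embed": tok_emb(ids) * sqrt(d)        # raises the per-dimension std toward 1
    mode="embed_ln":     LayerNorm(tok_emb(ids))       # normalizes the per-dimension std to ~1
    """

    def __init__(self, vocab_size: int, d_model: int, mode: str = "scaled_embed"):
        super().__init__()
        self.tok_emb = nn.Embedding(vocab_size, d_model)
        nn.init.normal_(self.tok_emb.weight, mean=0.0, std=0.02)  # GPT-2-style small init (assumed)
        self.mode = mode
        self.d_model = d_model
        self.ln = nn.LayerNorm(d_model) if mode == "embed_ln" else None

    def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
        h = self.tok_emb(token_ids)
        if self.mode == "scaled_embed":
            h = h * math.sqrt(self.d_model)
        elif self.mode == "embed_ln":
            h = self.ln(h)
        return h
```

With an initialization std of about 0.02 and d on the order of a few thousand, multiplying by sqrt(d) brings the per-dimension standard deviation of the embeddings close to 1 (e.g., sqrt(2048) × 0.02 ≈ 0.9), which is the condition discussed in the responses above.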
We appreciate your valuable feedback, which has helped us improve the paper. We hope that our response has sufficiently addressed your concerns, and we would be grateful if you would reconsider your evaluation accordingly.
Thank you for the rebuttal. The explanation solves most of my concerns. I will maintain my original ratings.
This paper tries to understand how to prevent spiking in language model pretraining and explains why two known techniques, Embedding Normalization and layer-wise parameter initialization scaling, help stabilize the model during training.
The authors propose to control the upper bound of the gradients at initialization, because gradient spiking is commonly observed before loss spiking. They further show that the upper bound of the gradient norm of each layer correlates positively with that layer's weight norm and negatively with the embedding weight norm. Therefore, the authors propose to make sure that the embedding is of norm O(1) and that the layers have weight norm O(1/L) or O(1/\sqrt{L}).
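A toy numerical check of this qualitative claim (ours, not from the paper; it uses a generic Pre-LN FFN block and standard PyTorch modules): the deviation of the block Jacobian from the identity grows with the sub-layer weight scale and shrinks as the residual-stream input (the role played by the embeddings) gets larger, since LayerNorm's Jacobian scales roughly as 1 / std(x).

```python
# Toy sanity check (ours): Jacobian of a Pre-LN block y = x + FFN(LN(x)).
import torch

def block_jacobian_deviation(d: int, input_std: float, weight_gain: float) -> float:
    torch.manual_seed(0)
    ln = torch.nn.LayerNorm(d)
    w1 = torch.nn.Linear(d, 4 * d)
    w2 = torch.nn.Linear(4 * d, d)
    for lin in (w1, w2):
        with torch.no_grad():
            lin.weight.mul_(weight_gain)  # scale the sub-layer weights

    def block(x):
        return x + w2(torch.relu(w1(ln(x))))

    x = input_std * torch.randn(d)                      # "shortcut"/residual-stream input
    J = torch.autograd.functional.jacobian(block, x)    # (d, d) Jacobian of the block
    return torch.linalg.svdvals(J - torch.eye(d)).max().item()

for input_std in (0.02, 1.0):
    for weight_gain in (1.0, 4.0):
        dev = block_jacobian_deviation(d=64, input_std=input_std, weight_gain=weight_gain)
        print(f"input std {input_std:4.2f}  weight gain {weight_gain:3.1f}  ||J - I||_2 = {dev:.3f}")
```

This is only a sanity check of the qualitative mechanism; the paper derives explicit bounds for the actual attention, FFN, and layer-normalization sub-layers.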
Reasons to Accept
- The authors present a very detailed analysis considering the gradient upper bound and perform empirical experiments across different scales to support it.
- The paper is overall clear and well-written.
Reasons to Reject
- The empirical result seems to have some potential contradiction with the theory. The gradient norm seems to have a lot of spikes in Figure 4; it is just that the loss is not spiking.
- While the theoretical analysis is very detailed, the authors didn't argue how this differs from famous previous works: for example, the MuP paper (and its depth variant) [1, 2], which has similar calculations and arrives at a very similar initialization paradigm.
[1] https://arxiv.org/abs/2310.17813
[2] https://arxiv.org/abs/2310.02244
We thank the reviewer for reading our paper and providing feedback. In the following, we respond to each of the concerns raised.
The empirical result seems to have some potential contradiction with the theory
We are not entirely sure whether we have correctly understood the reviewer’s comment, but we assume that the reviewer considers that, although the gradient exhibits spikes, the loss does not. We will respond based on this interpretation. We believe there may be a misunderstanding regarding the interpretation of Figures 3 and 4. Figure 3 shows the loss curves, while Figure 4 shows the gradient norms. To make the correspondence easier to follow, let us focus on the models with 1.7B parameters. Figure 4 (b) indicates that spikes in gradient norms occur after 5000 steps in Vanilla and Embed Detach, which do not meet the conditions we proposed for stabilizing training. Correspondingly, Figure 3 (b) shows that loss spikes also occur at the same points. In contrast, the Embed LN and Scaled Embed methods prevent these spikes.
The authors didn't argue how this differs from famous previous works: for example, the MuP paper
While we appreciate the reviewer's suggestion to compare our work with the MuP framework, we respectfully note that the subjects being analyzed differ significantly. As we understand it, the MuP studies [1, 2] focus on the behavior of internal layers (including residual connections) in general neural networks. In contrast, our paper provides a practical theoretical analysis tailored to the actual sub-layers in (Pre-LN) Transformers, such as the FFN, self-attention, and LN layers, with the goal of stabilizing pre-training in large language models. Moreover, there may be some misunderstanding in the discussion of our paper, as the reviewer summarizes: "Therefore, the authors propose to make sure that the embedding is of norm O(1) and that the layers have weight norm O(1/L) or O(1/\sqrt{L})." Our theoretical analysis does not claim that the norm of the embedding must be O(1), which would imply a constant norm independent of the number of layers or hidden dimensions. Instead, our theoretical analysis establishes the necessity of making the standard deviation of each embedding close to 1. Our study is also distinct from studies on MuP in that we provide specific conditions on embeddings to stabilize the pre-training of large language models.
Furthermore, we would like to emphasize that we initialized weight parameters by taking into account the hidden dimension and the total number of layers to satisfy our provided condition, which is different from adjusting the weight norm to O(1/L) or O(1/\sqrt{L}).
If the reviewer could clarify more specifically which part is considered overlapping, we would be happy to provide a more direct comparison.
- Sorry for the confusion; I should be more clear. I am referring to Figure 4 (a), where Embed LN and Scaled Embed exhibit gradient spikes early in training but there is no loss spiking in Figure 3 (a).
- In [1], MuP actually assumes a constant depth and proposes to initialize the weights with standard deviation 1/\sqrt{fan_in} \cdot \min\{1, \sqrt{fan_out / fan_in}\}, which is similar to the \sqrt{2/(5 fan_in)} used in this paper. Further, MuP actually requires the hidden vector to be of norm \sqrt{d}, instead of O(1) ([1], Desideratum 1), which again is similar to the parameterization in this paper.
We greatly appreciate the reviewer's prompt reply. We would now like to address the two questions raised.
First Question
We appreciate the reviewer’s clarification regarding their focus on gradient norms at the beginning of pre-training (until around 2000 steps). During this period, the average gradient norm is initially large and gradually decreases with some fluctuations, similar to the behavior of the loss values. These fluctuations are distinct from spikes, which we define as sudden and large transient increases from a stable low value. In fact, the gradient norms of Scaled Embed and Embed LN do not exceed their initial values, as they consistently decrease over time.
We sincerely apologize for the confusion caused by the trimmed y-axis in Figure 4. Our intention was to make the spikes that occur after the loss becomes small more visible, rather than to focus on the initial behavior. To further clarify the initial behavior, we will update Figure 4 with an adjusted y-axis and include a more detailed explanation in the manuscript.
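To make the distinction concrete, here is a small, hypothetical helper (the function name, window, and threshold are ours, not from the paper) that flags only sudden transient increases above an already-stable recent level, so the large-but-decreasing fluctuations at the start of training are not counted as spikes.

```python
import numpy as np

def find_spikes(values, window: int = 100, factor: float = 3.0):
    """Return indices where a value jumps well above its recent running level.

    Hypothetical criterion (ours): a point is a spike if it exceeds `factor`
    times the median of the previous `window` steps. Early steps, where the
    running level is still high and decreasing, are naturally not flagged.
    """
    values = np.asarray(values, dtype=float)
    spikes = []
    for t in range(window, len(values)):
        baseline = np.median(values[t - window : t])
        if values[t] > factor * baseline:
            spikes.append(t)
    return spikes

# Example: a decaying gradient-norm curve with one injected spike at step 3000.
steps = np.arange(5000)
grad_norm = 2.0 * np.exp(-steps / 800) + 0.1
grad_norm[3000] = 5.0
print(find_spikes(grad_norm))  # -> [3000]
```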
Second Question
We are not entirely confident that we have correctly understood the reviewer’s concern, but we will provide an answer. If our interpretation is incorrect, we would appreciate it if the reviewer could kindly let us know again.
MuP actually assumes a constant depth and proposes to initialize the weights with standard deviation 1/\sqrt{fan_in} \cdot \min\{1, \sqrt{fan_out / fan_in}\}, which is similar to the \sqrt{2/(5 fan_in)} used in this paper.
We understand that the reviewer is suggesting our approach is similar to the MuP study because our initialization also depends on the hidden dimension size. However, incorporating hidden dimension sizes into initialization is a standard practice in neural networks and not specific to the MuP study. For example, Xavier initialization uses \sqrt{2/(fan_in + fan_out)} as the standard deviation.
In our study, we theoretically provided two conditions to stabilize pre-training and empirically investigated their effectiveness. In Section 4, we discussed whether the widely used initialization method for LLMs, with standard deviation \sqrt{2/(5 fan_in)}, satisfies the first condition, small sub-layers. In our main experiments, we used this widely used initialization for LLMs, and additionally provided results with Xavier initialization in Appendix C.1.
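To make the comparison concrete, here is a small calculation (ours) of the standard deviations implied by the three schemes mentioned in this exchange, for a hypothetical hidden size; the exact constants should be checked against the respective papers.

```python
import math

d = 2048                      # hypothetical hidden size
fan_in, fan_out = d, 4 * d    # e.g., the first FFN projection

xavier    = math.sqrt(2.0 / (fan_in + fan_out))                               # Xavier/Glorot
small_llm = math.sqrt(2.0 / (5.0 * fan_in))                                   # sqrt(2/(5 fan_in)), as quoted for this paper
mup_like  = (1.0 / math.sqrt(fan_in)) * min(1.0, math.sqrt(fan_out / fan_in)) # as quoted by the reviewer

print(f"Xavier:           {xavier:.5f}")
print(f"sqrt(2/(5 d)):    {small_llm:.5f}")
print(f"muP-like (quoted): {mup_like:.5f}")
```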
MuP actually requires the hidden vector to be of norm \sqrt{d}, instead of O(1) ([1], Desideratum 1), which again is similar to the parameterization in this paper.
Based on the summary of the first review, we interpret the reviewer as considering our condition on each embedding to be similar to the requirement in the MuP study. As mentioned in our previous response (and as mentioned in line 149 in our paper), our work theoretically provides the necessity to make the standard deviation of each embedding close to 1 in order to stabilize pre-training. In other words, our study does not impose a requirement on the norm of each embedding.
Furthermore, since we did not use the big-O notation such as O(1) in our paper, we would appreciate it if the reviewer could point out the specific lines or descriptions in our paper that may have led to this interpretation, to help clarify the discussion.
Thank you for the reply.
Regarding the first question, I agree it will be better to differentiate the scales of the spikes.
Regarding the second question, (1) I agree that using fan_in and fan_out in initialization is standard practice. The reason I think the authors should discuss the connection is that the authors' methodology and conclusion are both conceptually similar to MuP (trying to ensure that the gradient's spectral norm stays close to constant).
(2) Because the embedding is always zero-mean in this paper, there is a straightforward relationship between the embedding standard deviation and the norm, where a standard deviation of 1 implies that the expectation of the squared L2 norm is d. Is my understanding correct here? If this is correct, then this initialization scheme is indeed the same as the MuP initialization scheme.
(3) If I understand the analysis correctly, while the authors never used O(1) notation in the paper, all of the calculations in the paper are only correct up to a constant factor so considering the result in big O is still faithful to the paper.
Thank you for your prompt response and for your helpful interpretation of our paper.
We apologize for the delay in addressing the reviewer’s comments. We required additional time to thoroughly reconsider the relationship between MuP studies and our paper, as highlighted by the reviewer.
We agree that the MuP studies mentioned by the reviewer [1, 2] share conceptual similarities with our work, but the focus and context differ, as described in our first response. While MuP primarily analyzes MLP layers (and residual connections), our work focuses on the specific architecture, (Pre-LN) Transformers used in large language models. In other words, the MuP studies focus on the training dynamics of general neural networks, but we focus on the properties of the actual Transformer architecture, as a continuation of Transformer-specific analyses such as [3, 4, 5].
To avoid further potential confusion, we would like to summarize the similarities and differences below, without using Big-O notation at this time.
For [1] (A Spectral Condition for Feature Learning, https://arxiv.org/abs/2310.17813):

Similarities
- Both studies analyze training dynamics with a focus on the spectral norm of parameters and gradients.
- Initialization methods in both studies depend on the hidden dimension size.
  - However, using the hidden dimension size in initialization is a standard practice in neural networks, as the reviewer has noted in agreement.
- Both studies require a similar condition on the input.
  - As indicated by the reviewer, the expected squared L2 norm of a vector drawn from a normal distribution with mean 0 and standard deviation 1 is equal to the hidden dimension size (a short check appears after this comparison).

Differences
- This paper focuses on MLP layers, which limits the direct applicability of its insights to self-attention layers. In contrast, our paper analyzes the actual sub-layers of Transformers: the FFN, multi-head self-attention, and layer normalization.
- In addition, we address the residual connections used in Transformers. This paper provides a desideratum for each hidden layer (as defined in Equation (2)), whereas our paper provides a condition on the standard deviation of each embedding and of the vectors after the residual connections, specifically the vectors defined in Equations (1) and (2), in order to keep the right-hand terms in Equations (15) and (20) small.
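As referenced in the similarity list above, the standard-deviation-versus-norm relationship the reviewer points out is a one-line calculation (standard, not specific to either paper): for an embedding $e \in \mathbb{R}^d$ with independent zero-mean entries of standard deviation $\sigma$,

```latex
% Standard calculation; included only to make the similarity explicit.
\mathbb{E}\!\left[\lVert e \rVert_2^2\right]
  = \sum_{i=1}^{d} \mathbb{E}\!\left[e_i^2\right]
  = d\,\sigma^2 ,
\qquad \text{so } \sigma = 1 \;\Longrightarrow\; \lVert e \rVert_2 \approx \sqrt{d} \text{ for large } d.
```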
For [2] (Tensor Programs VI: Feature Learning in Infinite-Depth Neural Networks, https://arxiv.org/abs/2310.02244):

Similarities
- Both studies scale the residual branch by a depth-dependent factor.

Differences
- This paper assumes that the residual branch has block depth 1. However, as noted in the paper, the theoretical results may not directly transfer to Transformers, whose block depth is 2. In fact, the paper says: "Combined with the theoretical insights of Section 9, this leads us to conclude that while the scaling can potentially be practically useful in transformer training, it is likely to be brittle to architectural and algorithmic changes, or even simple things like training time." In contrast, we analyze the operations in the actual sub-layers of Transformers and investigate whether such depth-dependent scaling satisfies our provided conditions for stabilizing pre-training.
We appreciate the reviewer’s helpful discussion. Should the reviewer find our explanation reasonable, we would be happy to incorporate this comparison into our paper.
I think the comparison sounds reasonable to me and it would be beneficial to include this in the paper.
We sincerely appreciate the constructive and insightful discussion. We will, of course, incorporate the points raised in this discussion into the camera-ready version of the paper if it is accepted to the conference.
This paper analyzes the problem of loss spikes during the training of LLMs and provides conditions to bound the gradient norm of the models, connecting them to remedies, such as scaling the embeddings or applying layer norm, that satisfy these conditions. The reviewers appreciated the theoretical analysis and the experiments at various scales up to 13B. The discussed remedies have similarities to muP, as discussed by the reviewers. Reviewer epfy was concerned that the paper reports spikes on small models (350M, 1.7B), which are typically stable during training, and attributes this to possibly suboptimal hyperparameter tuning. The results on 13B show that reducing the learning rate also removes the loss spiking phenomenon. The type of spike seen in the smaller models, which does not lead to overall divergence of the training run, seems different from the type seen in the 13B model with a larger learning rate. Regardless, the suggested approach seems to resolve both types of spikes and results in similar performance of the final model, and it could make the setting of hyperparameters for stable LM training more robust. More extensive experiments on the "spiking regions" of hyperparameters for Vanilla vs. Scaled Embed could help clarify how much more robust to hyperparameter settings the proposed approach is.