Weight decay induces low-rank attention layers
Abstract
Reviews and Discussion
The paper studies the effect of weight decay on losses where the trained parameters include two matrices that are multiplied together, and specifically the bias of such losses towards low rank. It is shown theoretically that under certain conditions, local minima of the L_2-regularized loss coincide with minima of the L_* loss, regularized by the nuclear norm, which serves as a convex surrogate for low rank. It is also shown that, under gradient flow, the gap between the L_2 and L_* losses converges exponentially fast to 0. Several experiments are given, mostly on transformers, that verify these findings empirically.
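For reference, the correspondence described above rests on the classical variational characterization of the nuclear norm (notation reconstructed from the discussion; the paper's exact constants may differ):

$$\|W\|_* = \min_{A,B \,:\, AB^\top = W} \tfrac{1}{2}\left(\|A\|_F^2 + \|B\|_F^2\right),$$

so that minimizing $L(AB^\top) + \lambda\left(\|A\|_F^2 + \|B\|_F^2\right)$ over factorizations is naturally related to minimizing $L(W) + 2\lambda\|W\|_*$ over W.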
Strengths
I think that overall this is a nice paper that should be accepted. The results, although pretty simple mathematically, show a nice observation on the bias of certain parameterizations towards low rank. The experiments seem thorough and cover both toy examples with smaller models, and larger-scale transformers.
Weaknesses
My biggest issue with the paper is its framing. The theoretical results seem to apply more to compressed sensing or matrix completion, where the loss is L(A^\top B) for some differentiable L. However, the paper (and also the title) seem to discuss mostly transformers. It might be popular to write papers about transformers these days, but this specific paper looks more like a paper about implicit bias in general, with experiments on transformers.
Second, the discussion about linear models focuses on the underspecified regime, since the matrix S is invertible (i.e., the input dimension is larger than the number of samples). This is not the standard regime for linear regression. To make this discussion more relevant to what happens in practice, I would suggest discussing kernel regimes (e.g., NTK), where it is standard for the kernel's dimension to be larger than the number of samples.
Questions
- What happens if L in Eq. (2) is non-differentiable? In practice, L is non-differentiable for transformers, since it incorporates an MLP with a ReLU activation.
- In Figure 3, what is the performance of each setting? It is not clear, for the more realistic values of weight decay (i.e., those with the best performance), what the rank is.
- Line 212: How do the authors define an unstable equilibrium?
Limitations
The authors discuss the limitations of their results.
We thank the reviewer for their positive feedback and constructive criticism. Below, we address your concerns point by point.
The theoretical results seem to apply more to compressed sensing or matrix completion, where the loss is L(A^\top B) for some differentiable L. However, the paper (and also the title) seem to discuss mostly transformers.
This is a valid point. Yet, we stress that the motivation of our research is to understand the inductive bias of weight decay in attention layers. Weight decay is virtually always used by default, and attention layers are almost ubiquitous in modern architectures. Therefore, we believe shedding light on this inductive bias is of great relevance to the community, and justifies the current framing despite our results ultimately being useful in other fields.
The discussion about linear models focuses on the underspecified regime, since the matrix S is invertible. This is not the standard regime for linear regression.
We may have misunderstood the reviewer here, but we would like to clarify that our setting considers the case where N (the number of data points) is larger than d (the dimension of the input), since we assume the matrix S is the identity. As you point out, this is the standard regime for linear regression. Furthermore, even in the case where d > N, as long as the input and input-output covariance matrices are still co-diagonalizable (a convenient assumption also made by Saxe et al. [2013]), we believe the derivation to be very similar, and that the weight regularization would simply set the weights corresponding to the degenerate dimensions to zero.
What happens if L in Eq. (2) is non-differentiable? In practice, L is non-differentiable for transformers, since it incorporates an MLP with a ReLU activation.
This is indeed an important question. We point out that modern transformers, and in particular LLMs, use the GeLU or SwiGLU activation, making them differentiable.
However, when the model is merely almost everywhere differentiable, we do not know whether the correspondence of the local minima between the L2-regularized and nuclear-norm-regularized losses, and in particular Lemma 3.2, still holds. Note that if the Lemma does hold, then, given that our proof of Theorem 3.3 does not use differentiability, the correspondence result would hold as well.
As for our result on the optimization dynamics (Theorem 3.4), our theory uses the gradient flow limit to study the problem. The flow is, of course, undefined at non-differentiable points. However, the Brownian noise accounting for the minibatch noise would allow escape from any non-differentiable point, so the flow can still be defined soundly (e.g., by setting the gradient to zero, or to any subgradient, at those points). Thus, almost-everywhere differentiability is enough for the dynamics to be well described by our approximation.
In Figure 3 - what is the performance of each setting? It is not clear, for the more realistic values of weight decay (i.e., those with the best performance), what the rank is.
The performance of the LLM experiment for various values of weight decay can be found in Table 1. For ViT, interestingly, we found that removing the weight decay from the attention layers did not result in a significant gain in performance. In fact, removing weight decay even slightly degraded performance, indicating that for visual tasks, low rank doesn't seem to hurt as much as it does for language tasks.
This highlights a subtlety in our message that we believe was not clear in the original manuscript. The takeaway of our work should not be that practitioners should always turn off weight decay from attention layers. Our work aims to uncover the confounding, low-rank-inducing inductive bias of weight decay coupled with attention layers and demonstrate the relevance of this inductive bias in real applications, including popular foundation models. The benefit of this bias is problem- and architecture-dependent. By taking an off-the-shelf model and hyperparameters and showing that performance can be improved by turning off weight decay in its attention layers, we aim to highlight the need for practitioners to take this inductive bias seriously.
We will include in the final manuscript the performance on ViT, and add the above clarification point.
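For concreteness, the selective removal of weight decay discussed above is typically implemented with optimizer parameter groups. Below is a minimal PyTorch sketch; the "attn" substring used to identify attention projections is a hypothetical naming convention, not our actual code.

```python
import torch

def split_attention_decay(model: torch.nn.Module, weight_decay: float = 0.1):
    """Build AdamW parameter groups with weight decay disabled on attention
    projections (identified here by a hypothetical "attn" name substring)."""
    decay, no_decay = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        (no_decay if "attn" in name else decay).append(param)
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]

# Usage:
# optimizer = torch.optim.AdamW(split_attention_decay(model), lr=3e-4)
```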
Line 212 - How do the authors define an unstable equilibrium?
In this setting, by "unstable equilibrium" we mean a stationary point that is not a local minimum, i.e., the gradient is null but the Hessian is not positive semi-definite.
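Written out, with θ* denoting the stationary point in question, this reads:

$$\nabla L(\theta^\star) = 0 \quad \text{and} \quad \exists\, v \neq 0:\; v^\top \nabla^2 L(\theta^\star)\, v < 0.$$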
We thank the reviewer again for their review and appreciation of our work. We remain available if you have any further questions.
I thank the authors for their response. My point about S being invertible may have been due to confusion on my side, and the authors clarified it. All the rest of my questions have been answered and I keep my score. I think this paper should be accepted.
This paper studies the effect of weight decay on the product of matrices, as they appear in attention layers. The paper shows that weight decay has the effect of reducing the rank, hence hurting generalization. The theoretical results are verified in extensive experiments.
Strengths
- Understanding attention layers is an important question and hence of great interest.
- The experiments are extensive and well support the theoretical results.
Weaknesses
The novelty of this paper is quite weak:
- Some theoretical results have already been shown by Wang and Jacot [2023]. Specifically, as mentioned by the authors, Theorem 3.1 in Wang and Jacot [2023] is a general version of Theorem 3.3.
- Theorem 3.4 is for gradient flow, which seems far from the algorithms used in practice.
- The empirical observation that low-rank weights hurt generalization has also been shown in [Sharma et al., 2023].
Please correct me if I have missed any key contributions of this paper.
Questions
- About Theorem 3.4, how can one see that "... long before stationary points are found"?
Limitations
The contribution of this paper, as outlined in the weaknesses section, does not seem substantial enough.
We thank the reviewer for their feedback. We address your concerns point by point below:
Some theoretical results have already been shown
We stress that while Wang and Jacot [2023] show a similar result, their result does not apply to transformers, where the matrix product is not full rank due to an architectural bottleneck. More importantly, their result only describes what happens once training has converged to an equilibrium point; we stress that this setting is practically never achieved. Our theoretical contribution goes beyond this and considers what happens during optimization, which is theoretically far more complex and meaningful than studying behaviour at stationary points. Our results therefore establish that rank regularization happens already early in training. This not only provides a theoretical explanation of past empirical observations (e.g., Khodak et al. [2022]), but is especially relevant for understanding the (online) training of large foundation models such as LLMs, where optimization is typically never brought to completion. Our empirical results, and in particular the analyses we performed on the pretrained foundation model weights, clearly demonstrate this relevance. We hope that this clarifies the novelty and relevance of our work.
Theorem 3.4 is for gradient flow which seems far from the algorithms used in practice.
We note that we provide in the appendix a similar result when considering gradient flow with noise, as well as with momentum and decoupled weight decay.
As for the continuous dynamics, we believe this to be a benign approximation. Indeed, for the non-stochastic version (Lemma B.1), the proof remains unchanged: one obtains an exponential decrease with a per-step factor of (1 - 2ηλ)², where η is the learning rate; for ηλ → 0, we retrieve the e^{-4λt} factor from the continuous case. For the stochastic version, one would get an exponential decrease of the gap until it reaches a noise floor of order ησ²M, where σ² is the variance modeling the stochasticity of SGD, and M is an upper bound on the norms of A and B. Things get more complicated if we want to model discrete time, stochasticity, and momentum: the discreteness of time adds additional interaction terms between the noise terms and the matrices A and B that need to be taken into account. The continuous SDE offers a nice approximation of the real dynamics, as η is often chosen to be small, while still capturing all the intuition.
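For readers following the thread, here is the continuous-time computation that underlies these decay rates, written as a sketch under the assumption that the regularizer is λ(‖A‖_F² + ‖B‖_F²); the paper's exact constants may differ by convention. Writing S = A^\top ∇L B + B^\top ∇L^\top A, the loss terms cancel in the imbalance:

$$\dot{A} = -\nabla L(AB^\top)\,B - 2\lambda A, \qquad \dot{B} = -\nabla L(AB^\top)^\top A - 2\lambda B,$$

$$\frac{d}{dt}(A^\top A) = -S - 4\lambda\, A^\top A, \qquad \frac{d}{dt}(B^\top B) = -S - 4\lambda\, B^\top B,$$

$$\frac{d}{dt}\left(A^\top A - B^\top B\right) = -4\lambda\left(A^\top A - B^\top B\right).$$

The imbalance thus decays as e^{-4λt}, independently of L, and at balance (A^\top A = B^\top B) one has ‖A‖_F² + ‖B‖_F² = 2‖AB^\top‖_*.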
We agree however that the quality of these approximations to the practical training dynamics is nonetheless an important point to discuss, and will add a short paragraph in the final manuscript.
The empirical observation that low-rank weights hurt generalization has also been shown in [Sharma et al., 2023].
We respectfully disagree with this assessment. Sharma et al. [2023] show that after pretraining, when surgically reducing the rank of the MLP weight matrices, the performance of various LLMs improves on downstream tasks. As part of their hyperparameter search, they found no setting in which rank reduction of (even subsets of) attention matrices improved performance, and show some evidence that it may even hurt on, e.g., the CounterFact dataset.
In contrast, to the best of our knowledge, we are the first to show that the low-rank regularization induced by weight decay on attention layers during optimization can hurt the perplexity of LLMs. Furthermore, our empirical results demonstrate, by showing the equality of the entries of, e.g., W_Q^\top W_Q and W_K^\top W_K (Fig. 4 and Proposition 3.1), that the attention weights of popular foundation models, such as Llama 2 and ViTs, are being rank regularized through the mechanism we describe.
These are, to the best of our knowledge, very different insights from those of Sharma et al. [2023]. We welcome the reviewer to clarify and point out results in Sharma et al. [2023] where the two aforementioned points were observed empirically.
About Theorem 3.4, how to see that "... long before stationary points are found"?
This is a good point and deserves some additional clarification in the main text. What we theoretically show is that, under reasonable conditions, the timescale at which the rank regularization can be observed is independent of the rest of the optimization (with a characteristic time of order 1/λ). This means that for long enough optimization, such as that of a foundation model, there is ample time for the co-optimization of the two regularizations to happen. Our analyses on pretrained model weights in Fig. 4 once again support that view.
We hope we could clarify our contribution to the reviewer and convince them that the paper deserves acceptance. Please let us know if you have further questions or concerns.
I thank the authors for the detailed response. I do not think my concerns have been addressed. I will clarify them further below.
- ...More importantly, their result only describes what happens when training is converged to an equilibrium point....
Regarding Theorem 3.3, I believe a similar result has been shown in [Wang and Jacot, 2023]. Specifically, I believe Theorem 3.3 also describes what happens at the equilibrium point, hence there is not much difference between your results and their results. Additionally, I do not think the product matrix not being full rank is a big technical difficulty.
- ...note that we provide in the appendix a similar result when considering gradient flow with noise...
Gradient flow with noise, as well as the continuous SDE, is also far from practice. I am not sure how much insight we can obtain from analyzing these dynamics. I am not saying the insights shown in this paper are wrong; I just want to emphasize that the results are not very strong.
- ..to the best of our knowledge, we are the first to show that the low-rank regularization induced by weight decay on attention layers during optimization can hurt the perplexity of LLMs...
This is what you said in lines 322-324:
these findings complement the recent observation that reducing the rank of language model MLP matrices post-training improves their reasoning performance, while doing the same for attention layer matrices mostly hurt it [Sharma et al., 2023].
I am confused. Didn't [Sharma et al., 2023] already show the results for attention layers?
- What we theoretically show is that under reasonable conditions, the timescale at which the rank regularization can be observed is independent of the rest of the optimization.
Note that Theorem 3.4 requires the norms of A and B to be uniformly bounded during the whole training process. You should make this clear in the statement of Theorem 3.4. Furthermore, this is a very strong assumption, hence making Theorem 3.4 quite weak. While it might be empirically true, it cannot be proved analytically.
- ... long before stationary points are found
To me this cannot be easily seen from Theorem 3.4. If it is straightforward, I would suggest adding a corollary and proving the result, otherwise, it is just an empirical observation and should be made clear.
Overall I think this is an interesting paper with many insightful observations. However, the main focus of the paper is not very clear. This paper seems to make theoretical contributions, as it presents the theory first and then uses empirical results to verify it. To me, the theoretical results are not significant enough for NeurIPS, as I have outlined in the weaknesses and responses. I would suggest shifting the focus to empirical contributions and then providing theoretical insights.
Thank you for your response. We will address your points below.
For our theoretical contributions: the reviewer finds them "not very strong" since, e.g., gradient flow with noise is far from practice. We reiterate that we also show our result for stochastic gradient flow with momentum and decoupled weight decay, which is by no means trivial. This is not only an obviously good approximation of practical training dynamics when using the still widely used SGD with momentum and decoupled weight decay (as we have explained previously, continuous dynamics are a benign approximation), it is a theoretically tractable approximation that is as close as one can get to the popular AdamW optimizer dynamics. Besides, we proved a set of novel and general matrix inequalities in order to link the norm of the imbalance A^\top A - B^\top B to the discrepancy between the L2 norm and the nuclear norm, and its generalization to the product of arbitrarily many matrices. We believe these inequalities are nontrivial and can be of interest to a larger community. Together, our results go beyond and complement the description at equilibrium, and provide an understanding of the training dynamics of weight decay applied to attention layers, shedding light on empirical observations made in previous works.
As for our empirical contributions: we refer the reviewer to our rebuttal, where we clarified what the contributions of [Sharma et al., 2023] were. Given this clarification, if the reviewer still disagrees that our empirical contributions are novel, could they, as we requested, point out results in Sharma et al. [2023] where our empirical contributions were already made? If not, together with the clarification of our theoretical results, we kindly ask whether the reviewer still believes our contribution is "Poor", i.e., the lowest possible score.
Note that Theorem 3.4 requires the norms of A and B to be uniformly bounded during the whole training process. You should make this clear in the statement of Theorem 3.4. Furthermore, this is a very strong assumption, hence making Theorem 3.4 quite weak.
We are confused by this point. We clearly state the boundedness assumption in Theorem 3.4, i.e.
"(...) If remain bounded during training, then (...)".
As those values are real values, the only meaning for boundedness is that there exists M such that those values remain smaller than M. Uniform boundedness would have made sense if we were considering a family of functions, which we do not. Also, the boundedness constants of A and B may differ, should the confusion come from this.
As for the necessity of the assumption itself: one can easily find sufficient conditions on the loss such that the boundedness assumption provably holds. We would happily elaborate if the reviewer is interested. However, it is also possible to construct pathological losses such that the norm of A or B diverges. In that case, little can be said about the gap between the two regularized losses. Ultimately, we used the boundedness assumption because any training dynamic that converges will trivially verify it. This is the practical scenario that we are interested in, given that in practice, virtually any stable training on a realistic loss, coupled with weight decay, will result in a converging dynamic.
"long before stationary points are found". To me this cannot be easily seen from Theorem 3.4.
We will reformulate Theorem 3.4 in the following way for clarity:
"(...) then we have that converges exponentially to 0, with a characteristic timescale equal to ".
This hopefully highlights the point we made in the rebuttal. Furthermore, to better reflect our point, as clarified in the rebuttal, we will also modify our sentence from "(...) long before stationary points are found." to "(...) potentially long before stationary points are found."
We acknowledge it may have been confusing, and thank the reviewer for pointing it out to us.
We hope we could address the remaining concerns, and that together with our other additional clarifications, it convinces the reviewer to reconsider the score.
I thank the authors for the response. I have a few more comments.
- As for our empirical contributions: we refer the reviewer to our rebuttal where we have clarified what the contributions of [Sharma et al., 2023] were....
I was just confused by your wording in the paper. "...while doing the same for attention layer matrices mostly hurt it [Sharma et al., 2023]". I am fine if this is your contribution.
I think observing that low rank can hurt generalization for attention layers might be an interesting extension from fully-connected neural networks, but it is not a significant contribution, given the similarity between the two architectures (product of weights). Furthermore, more experiments are needed to verify these results. As said by the authors, the current evidence can only suggest the possibility of such an effect.
- As those values are real values, the only meaning for boundedness is that there exists M such that those values remain smaller than M. Uniform boundedness would have made sense if we were considering a family of functions, which we do not.
I believe the weights can be thought of as functions of the input data and the time step, since the gradient is a function of these two. Are you suggesting your results hold for any input data?
- As for the necessity of the assumption itself: one can easily find sufficient conditions on the loss, such that the boundedness assumption provably hold...Ultimately, we used the boundedness assumption because any training dynamic that converges will trivially verify it.
I am not sure about this. For example, the weight norm can keep increasing for the cross-entropy loss or logistic loss [1]. Additionally, if there is a softmax layer, the weight matrices can be very large without causing divergence. These are used in practice. Could you please elaborate on why it is trivial?
[1] Soudry, Daniel, et al. "The implicit bias of gradient descent on separable data." Journal of Machine Learning Research 19.70 (2018): 1-57.
- We will reformulate Theorem 3.4 in the following way for clarity:...
Thanks for the reformulation. I would further suggest writing out the explicit equation with the convergence rate.
I think observing that low rank can hurt generalization for attention layers might be an interesting extension from fully-connected neural networks, but it is not a significant contribution, given the similarity between the two architectures (product of weights).
We stress that our theory only holds when the weight matrices interact with the loss function as a product only. Fully-connected neural networks with non-linear activation functions do not verify this assumption. Furthermore, in terms of empirical observations, we stress that Sharma et al. [2023] studied applying low-rank approximation to the fully-connected layer weights of language models after training. We, on the other hand, investigate the rank reduction induced during training by weight decay, and demonstrate that this rank-regularizing effect does in practice affect the standard training of popular foundation models.
I believe the weights can be thought of as the functions of the input data and the time step, since the gradient is a function of these two. Are you suggesting your results hold for any input data?
We understand the confusion. We clarify that we consider a setting with a fixed dataset, which induces a fixed loss function that is optimized during training. We model the standard i.i.d. minibatch sampling by the Wiener processes. We claimed that in practice, when training with weight decay, the training converges to a finite point in parameter space (i.e., there exists a constant bounding the parameter norm during training; in your words, uniformly over time, given the particular input sequence the model observed). The boundedness assumption is thus of wide practical relevance, since it describes most training with weight decay, and we make use of this assumption to make a statement about the rank regularization effect that affects these trainings.
I am not sure about this. For example, the weight norm can keep increasing for the cross entropy loss or logistic loss[1]. Additionally, if there is a softmax layer, the weight matrices can be very large without causing divergence. (...) Could you please elaborate on why it is trivial?
By "any training dynamic that converges will trivially verify it" we meant that if the parameters converge (i.e. they have finite value for ), then there exists a constant such that for all time , the parameter norm is upper bounded by it.
As for, e.g., the cross-entropy loss, we agree with the reviewer that without weight decay, the weights would grow arbitrarily when the data is linearly separable. However, weight regularization in practice prevents such divergence. This can be seen intuitively in the gradient flow setting: given that realistic losses are positive (thus lower bounded by 0), as the loss is optimized, the gradient norm will ultimately decrease. If the optimization takes the parameters towards infinite norm, at some point the weight decay term will become stronger than the gradient. Note that, as noted before, it is possible to construct losses where, despite weight decay, the parameter norm diverges to infinity; an example of such a loss is simply L(AB^\top) = -\|AB^\top\|_F^2. But we maintain that for realistic losses, the assumption widely holds in practice.
Thanks for the reformulation. I would further suggest writing out the explicit equation with the convergence rate.
Thank you for this suggestion. We will update our manuscript accordingly.
We hope we could clarify and address some of your concerns, and that, together with our other clarifications, this convinces the reviewer to reconsider the score.
This paper explores how applying weight decay to matrices affects the rank of their product. They show that L2 regularization of the operands is equivalent to regularizing the nuclear norm (sum of singular values) of the product which could result in a lower rank. The attention block of transformers contains several plain matrix products that this theory applies to. The authors experiment with different transformer variants, showing that applying weight decay to these matrices results in a loss of rank which is detrimental to performance. They suggest not applying weight decay to these specific matrices, while keeping it for the rest.
Strengths
- Interesting topic of high relevance to the community, weight decay is almost universally used for transformer training and improved guidelines and understanding could result in practical performance gains.
- The theory looks sound with simple experiments that support the main conclusion.
- The empirical evaluation on real transformers supports the theory.
- The paper is well written overall although a bit dense at times. The figures are sufficiently clear although a bit small and not in a vector format.
Weaknesses
- The theory relies on analyzing the converged solution with gradient flow. I’m not sure how well this corresponds to real training (it would be nice to discuss this).
- The experiments could be stronger to eliminate some potential confounding explanations (see details below).
- (minor) Missing a couple of related work on weight decay (see details below).
Questions
Experiments:
- The ViT experiments show a much larger loss of rank but the performance impact of this is not quantified, why not?
- For the GPT experiments, the loss of rank seems very small, as you point out. I am not convinced that this is the root cause of the performance loss, rather than some temperature effect in the softmax.
- Many related works (see below) explore weight decay in terms of effective learning rates. They would suggest that when changing the weight decay the learning rate should be adjusted to compensate. Otherwise it is unclear if the performance degradation truly results from changes in the rank and weight decay or changes in the effective learning rate. Showing that these rank effects occur from the effective learning rate as well would strengthen the results and relate it to existing lines of work.
- Suggested experiment: Disentangle the softmax temperature effects (from the magnitude of the weight matrices / activations) from the rank effects. Maybe you could repeat the GPT or ViT experiment using a scale-invariant softmax from [1], reporting the rank and performance again. Since this softmax alternative does not depend on the scale of the inputs, it would eliminate this confounding effect (a generic sketch of one such construction follows this list).
- Suggested experiment: Disentangle the effective learning rate effects from the rank effects. When changing the weight decay, keep the weight decay * learning rate product constant. This will keep the effective learning rate in the steady state constant as described by [5] for AdamW, but also change the resulting weight magnitude, so it should be combined with something like the scale-invariant softmax to eliminate that effect.
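To illustrate the suggestion above, here is one simple construction of a scale-invariant softmax in Python; note that this is a generic sketch, not necessarily the specific architecture of [1]. Normalizing the logits by their RMS magnitude makes the output invariant to rescaling the attention scores, which removes the temperature confound.

```python
import torch
import torch.nn.functional as F

def scale_invariant_softmax(scores: torch.Tensor, tau: float = 1.0,
                            eps: float = 1e-6) -> torch.Tensor:
    """Softmax over scores normalized by their RMS magnitude.

    Invariant to scores -> c * scores for any c > 0 (up to the eps guard),
    so the effective temperature no longer depends on the scale of the
    key/query weight matrices.
    """
    rms = scores.pow(2).mean(dim=-1, keepdim=True).clamp_min(eps).sqrt()
    return F.softmax(tau * scores / rms, dim=-1)
```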
Related work:
- Overall you provide a good overview of prior weight decay literature but are missing one important line of work that explores weight decay in terms of the effective learning rate. Here are a couple of notable works from this line: [2] [3] [4] [5]
Other Questions:
- In my experience high (effective) learning rates cause a loss of rank (at least certain measures like the stable rank) even in standard matrices (not just the products). Is this something you observe in your experiments (e.g. a loss of rank in A and B that you apply weight decay to, not just the product AB)?
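For reference, the stable rank mentioned here is commonly defined as the ratio of the squared Frobenius norm to the squared spectral norm, a smooth lower bound on the rank:

$$\operatorname{srank}(W) = \frac{\|W\|_F^2}{\|W\|_2^2} = \frac{\sum_i \sigma_i^2(W)}{\sigma_{\max}(W)^2} \;\le\; \operatorname{rank}(W).$$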
[1]: Li, Zhiyuan, Srinadh Bhojanapalli, Manzil Zaheer, Sashank Reddi, and Sanjiv Kumar. "Robust training of neural networks using scale invariant architectures." In International Conference on Machine Learning, pp. 12656-12684. PMLR, 2022.
[2]: Li, Zhiyuan, and Sanjeev Arora. "An exponential learning rate schedule for deep learning." ICLR 2020.
[3]: Wan, Ruosi, Zhanxing Zhu, Xiangyu Zhang, and Jian Sun. "Spherical motion dynamics: Learning dynamics of normalized neural network using sgd and weight decay." Advances in Neural Information Processing Systems 34 (2021): 6380-6391.
[4]: Li, Zhiyuan, Kaifeng Lyu, and Sanjeev Arora. "Reconciling modern deep learning with traditional optimization analyses: The intrinsic learning rate." Advances in Neural Information Processing Systems 33 (2020): 14544-14555.
[5]: Kosson, Atli, Bettina Messmer, and Martin Jaggi. "Rotational equilibrium: How weight decay balances learning across neural networks." In International Conference on Machine Learning, 2024.
Limitations
The paper does not really discuss limitations in any depth. I would not expect any particular adverse societal impact.
We sincerely thank the reviewer for providing us with valuable feedback. We address your concerns and questions point by point below.
The theory relies on analyzing the converged solution with gradient flow. I’m not sure how well this corresponds to real training (it would be nice to discuss this).
We clarify that we have three main theoretical contributions: the correspondence of the local optima of the two regularized losses; the matrix inequalities; and the exponential decay of the difference between the two losses during optimization. The first two results are independent of the optimization method and, therefore, are relevant for real training. Specifically, if the solution found is an approximation of a local (or global) minimum of the L2-regularized loss, it is also an approximation of a minimum of the nuclear-norm-regularized loss.
As for the result on the training dynamics, we provide results for both the simple gradient flow regime and stochastic gradient flow with momentum. There are indeed several apparent sources of discrepancy between these dynamics and real training dynamics, e.g., continuous vs. discrete dynamics, how to model the minibatch noise exactly, and how to approximate the theoretically intractable AdamW dynamics. While we believe some of these approximations are benign (see our 2nd response to reviewer fDM3), others may not be (see, e.g., [1]). Ultimately, we had to strike the right balance between theoretical tractability and proximity to real training, such that relevant insights may be extracted without overcomplicating proofs, and we believe our empirical verifications validate the modeling. We agree that these are nonetheless important points to discuss, and will add these points as a limitation of the current theory in the final manuscript.
[1] Qi Meng et al. Dynamic of Stochastic Gradient Descent with State-Dependent Noise. arXiv, 2020.
(minor) Missing a couple of related work on weight decay (see details below).
Thank you for providing these references. The relationship between weight decay and the effective learning rate is very relevant in our setting, as you pointed out. We appreciate the reviewer for bringing this to our attention. We will incorporate these references into the final version of the manuscript.
The ViT experiments show a much larger loss of rank but the performance impact of this is not quantified, why not?
This is indeed a good point that needs clarification. In our experiments, we found that removing the weight decay from the attention layers in ViTs did not result in a significant gain in performance. In fact, removing weight decay even slightly degraded performance, indicating that for visual tasks, low rank doesn't seem to hurt as much as it does for language tasks.
This highlights a subtlety in our message that we now clarified in the revised version. The practical takeaway of our work should not be that practitioners should always turn off weight decay in attention layers. In fact, we speculate that when increasing, for example, the key-query dimension, the rank regularization of weight decay may become beneficial, even in language modeling tasks. Our work aims to uncover the confounding, low-rank-inducing inductive bias of weight decay coupled with attention layers and demonstrate the relevance of this inductive bias in real applications, including popular foundation models. The benefit of this bias is problem- and architecture-dependent. By taking an off-the-shelf model and hyperparameters and showing that performance can be improved by turning off weight decay in its attention layers, we aim to highlight the need for practitioners to take this inductive bias seriously.
Disentangling the softmax temperature effects from the rank effects.
This is an excellent point. During our experimentation, we tried selectively turning off the weight decay on the key-query matrices only, as well as on the value-projection matrices only. Both of these changes, in fact, improved upon the baseline where weight decay was left at 0.1, yielding about half of the improvement achieved by turning weight decay off in all attention matrices. This suggests that higher rank in both the key-query and value-projection products is beneficial, and that the benefits are cumulative.
Now, since turning the weight decay off in the value-projection matrices does not affect the effective temperature of the softmax attention, we believe the performance improvement can be (at least partially) disentangled from it.
Disentangling the effective learning rate effects from the rank effects.
This is also an excellent point. We thank the reviewer for bringing this confounding effect to our attention. We conducted the following experiment: to understand whether reducing the effective learning rate of attention layers can account for the performance improvement, we reused the off-the-shelf hyperparameters (where weight decay is set to 0.1 on all attention layers) and modulated the learning rate of the value-projection matrices by 1, 0.1, and 0.01. We left the key-query matrices untouched to disentangle the softmax temperature effect, as you suggested. Our early results suggest that reducing the learning rate in this manner results in significantly worse performance than the baseline, with a decrease of about 1% for the 0.1 modulation and 3% for the 0.01 modulation.
While it is difficult to truly disentangle the various confounding effects, we hope these new pieces of evidence are sufficient to convince the reviewer that the improved performance is, at least partially, due to the reduction in rank.
We thank the reviewer again for their inputs, particularly for their great ideas for analyzing confounding factors, which we also will add to our discussion. We believe that the various new clarifications and the emphasis on the new controls strengthen our paper. We hope the reviewer agrees, and will raise their score accordingly.
I thank the authors for their detailed responses and clarifications. The additional experiments and discussion alleviate my concerns about confounding effects to a degree (although the scale-invariant versions would have been more convincing), and strengthen the paper overall. I will raise my review score to 6.
The authors investigate the landscape of two different optimization problems. For a general objective function L defined on a matrix space, they consider the two regularized objectives L_2(A, B) = L(AB^\top) + λ(\|A\|_F^2 + \|B\|_F^2) and L_*(W) = L(W) + 2λ\|W\|_*.
The authors prove that these two objective functions share the same set of critical points up to equivalence. Moreover, along integral lines of the gradient flow of L_2, the difference L_2 - L_* decays exponentially if both A and B remain bounded, effectively showing an implicit bias towards low-rank solutions. These results naturally apply to transformer training, for which the authors present some numerical results.
Strengths
The work is very well-presented, and the contribution is self-contained. The theoretical analysis covers cases of practical interest. Presenting the results through gradient flows enhances readability.
Weaknesses
- Lack of an additional lemma addressing the effect of time discretization in the gradient flow scenario.
- The stochastic case is studied through SDEs, which may differ significantly from the exact setting encountered in deep learning.
Questions
- The balance condition A^\top A = B^\top B is often enforced directly at initialization and can be proven [1] to be conserved during the continuous-time flow. By Proposition 3.1, this would imply that the gradient flow of L_2 is exactly the same as the gradient flow of L_*. This kind of condition is for example satisfied by spectral initialization, which is commonly used. I believe a comment about this would be interesting to see in the manuscript.
- Do the authors see any way to address the two weaknesses? I believe further clarifications in this direction in the manuscript would make the result sound more solid.
- I suggest organizing the references section better, to make them all uniform in style.
[1] S.S.Du, W. Hu, J.D. Lee. "Algorithmic Regularization in Learning Deep Homogeneous Models: Layers are Automatically Balanced", NeurIPS 2018.
Limitations
I have no suggestions for the authors.
We thank the reviewer for their constructive feedback and interest. Below, we address your concerns point by point.
Lack of an additional lemma addressing the effect of time discretization in the gradient flow scenario. The stochastic case is studied through SDEs, which may differ significantly from the exact setting encountered in deep learning.
We thank you for bringing up this point. We would like to briefly comment on how our result could be extended to the discrete dynamic setting and why we omitted it in the manuscript.
For the non-stochastic version (Lemma B.1), the proof remains unchanged. One essentially obtains an exponential decrease with a per-step factor of (1 - 2ηλ)², where η is the learning rate. Note that for ηλ → 0, we retrieve the e^{-4λt} factor from the continuous case.
For a stochastic version, assuming we model the minibatch noise similarly, one would get an exponential decrease of the gap until it reaches a noise floor of order ησ²M, where σ² is the variance modeling the stochasticity of SGD, and M is an upper bound on the norms of A and B.
Things get one order more complicated if we want to model discrete time, stochasticity, and momentum. The discreteness of time adds additional interaction terms between the noise terms and the matrices A and B, which we would then need to take into account. While this is totally within reach, it would add a lot of technicalities that provide no additional insight or intuition: the proofs would be barely readable, and even the phrasing of a precise proposition would be much more complicated. As you rightfully pointed out, the continuous SDE offers a nice approximation of the real dynamics, as η is often chosen to be small, while still capturing all the intuition.
We agree, however, that the quality of these approximations to the practical training dynamics is nonetheless an important point to discuss, and will add a short paragraph in the final manuscript.
The balance condition A^\top A = B^\top B is often enforced directly at initialization and can be proven [1] to be conserved during the continuous-time flow. By Proposition 3.1, this would imply that the gradient flow of L_2 is exactly the same as the gradient flow of L_*. This kind of condition is, for example, satisfied by spectral initialization, which is commonly used. I believe a comment about this would be interesting to see in the manuscript.
This is an intriguing point, and we thank the reviewer for bringing it to our attention. There is indeed a deep connection with spectral initialization, which will now have a dedicated paragraph in the discussion for the final manuscript.
However, we stress that the balance condition A^\top A = B^\top B holding for all times during optimization of L_2 does not imply the equality of the gradients of the two losses. We can theoretically show that the equality of the gradients would in fact hold whenever the singular values of AB^\top are all equal, but this does not hold in general.
Obviously, however, the balance condition implies that L_* is co-optimized from the beginning, thus the rank of AB^\top is regularized, and, given our theoretical result, optimization will find a local minimum of L_*. We stress that while optimizing L_* directly would also find a local optimum, it may take a different trajectory and find a different optimum. The study of this difference may be an interesting future work.
I suggest organizing the references section better to make them all uniform in style.
We fully agree with this point, and we have now fixed it. We thank the reviewer for bringing this to our attention. We will keep on improving our manuscript for the final version.
We thank the reviewer again for bringing intriguing connections to our attention, which will help improve the discussion. We hope we have addressed your concerns and remain at your disposal for any further questions.
I wish first of all to thank the authors for their thorough rebuttal.
I apologize for the imprecise statement. I agree that the flows of L_2 and L_* are not the same even under spectral initialization, but they are co-optimized along the whole path.
I am satisfied with the answers and I will keep my score.
In this paper, the authors study the role of weight decay in the training of matrices, especially when they appear in multiplicative form. First establishing the equivalence between the L2-regularized and nuclear-norm-regularized losses at stationary points and local minima, they show how even training on the L2 loss invariably leads to the latter, and hence to low-rank solutions at the minima. Using this main observation, they empirically validate their claim on various practical LLMs and provide interesting insights about their behavior.
Strengths
The authors consider a well-motivated problem and characterize it mathematically with substantive evidence. Given the impact of transformers and the underlying need to study them fundamentally, such analyses are indeed quite insightful and useful. In particular, a rigorous theoretical analysis of how low-rank solutions emerge is a very interesting observation.
Weaknesses
While overall I enjoyed reading the paper, I wish the authors did a slightly better job at the following two things which could have made it even more solid:
- The writing and exposition could be greatly improved; at various places, especially concerning mathematical statements, it feels a bit loose, even though the details are provided in the appendix. For example, Theorem 3.4 concerns the gradient flow analysis, but neither the flow equation nor the precise statement of the "L2-nuclear norm" loss gap going to zero is written mathematically. Since this is an important result, it would be better to present enough mathematical details and expound upon them later. Similarly in Theorem 3.3, regarding the rank r: it is not clear what this "r" is supposed to be, whether it is attained at the optimum or fixed beforehand. Things like this should be fixed. Also there are a couple of typos: Eq. (8) should end with a full stop; line 176 reads better with "two losses" rather than "2 losses" (the same comment applies at various places, like line 173); etc.
- While the L2 loss in equation (2) concerns a pair of matrices (A, B) with a loss of the form L(AB^\top), given that the authors' main motivation stems from transformers, it would have been nice to have a discussion of how the analysis extends to, say, Eq. (1). Here there are two such multiplicative terms, W_Q W_K^\top and W_V W_O. In this case, can your results about low-rank minima still carry over? Or not? Some insights about this?
Questions
Please refer to the above weakness part for questions.
Minor note: In a recent work, the authors observe a similar low-rank structure when training transformers with Markov chains. They reconstruct the formula for these low-rank matrices, and indeed your low-rank optimality conditions seem to be satisfied there (Appendix C, I think). Thought I would share it in case the authors find it interesting (not as a comparison): https://arxiv.org/abs/2402.04161
Limitations
Yes.
We thank the reviewer for their encouraging feedback and constructive criticism. We address your concerns point by point below.
The writing and exposition could be greatly improved
We really appreciate all the feedback about the writing. We fully agree with all the above points, all of which are now taken into account in the revised version.
Can the results about low-rank minima still carry over when the multiplicative terms are W_Q W_K^\top and W_V W_O?
This is indeed a crucial point which could be made clearer. We thank the reviewer for bringing it to our attention. For our theory to hold, the paramount condition is that the two matrices, A and B, enter the unregularized loss only in their multiplied form, AB^\top. That unregularized loss may depend on other parameters and inputs, which obviously may interact with AB^\top. But as long as this condition holds (as is the case for the attention matrices), the dependence of the loss on these other quantities can be safely ignored without loss of generality. For any such two matrices A and B, we can write the loss as L(AB^\top, θ), where θ denotes all the remaining parameters. Stationary points of the full loss are also stationary in (A, B), and so Lemma 3.2 still holds. For Theorem 3.3, the exact same proof would allow us to show that (A, B) is a local minimum of L_2 if and only if AB^\top is a local minimum of L_* (constrained to ...). For Theorem 3.4, in the warm-up proof on line 534, the gradient of L hides a dependence on θ, but the result still holds as this dependence cancels out. The same applies to the gradient flow proof on line 542, where θ now depends on time, and thus the gradient indirectly does as well.
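To make the argument concrete, here is the gradient computation underlying it (a sketch in the notation introduced above, which is our reconstruction rather than the paper's exact statement): for fixed θ, the regularized gradients take exactly the same form as in the two-matrix case,

$$\nabla_A\left[L(AB^\top, \theta) + \lambda\left(\|A\|_F^2 + \|B\|_F^2\right)\right] = \nabla_W L(AB^\top, \theta)\, B + 2\lambda A,$$

and symmetrically for B, so any statement that only manipulates these gradients goes through with ∇_W L(·, θ) in place of ∇L(·).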
We realize we did not provide enough explanation of why the loss can be assumed to depend solely on and without loss of generality. In the updated paper, we now very clearly motivate, using the specific example of transformers, why studying our loss in this manner is the right thing to do.
We thank the reviewer again for these suggestions and questions, which we believe help improve the paper. We remain at your disposal if you have any further questions, and hope our clarifications convince the reviewer to support a strong acceptance of the paper.
I thank the authors for addressing my concerns and trust them to make the promised changes. I will keep my score.
We sincerely thank all the reviewers for their efforts in evaluating our work. Please find your personalized responses addressing your specific concerns. We welcome any further questions and remain open to continued discussions.
Dear Authors,
Whilst I appreciate many of the paper's contributions (despite some substantial similarities with the work of Jacot et al.), like reviewer vYXP, I am also a little bit concerned by the "Assumption" in Theorem 3.4 that A and B remain bounded during the entire training process: whilst it doesn't seem unreasonable at all for this to happen, it feels like this is something that ought to be proved based on initial conditions and conditions on the loss function, rather than assumed. I see that in your rebuttal to the reviewer in question, you mention that "one can easily find sufficient conditions on the loss, such that the boundedness assumption provably holds" and that you are ready to elaborate if the reviewer is interested.
I am interested. Could you please try to achieve this (discuss realistic cases where the condition provably holds and pathological cases where it doesn't), and submit it as an attachment, so that the paper is in more complete shape for publication? I believe this would significantly improve the quality of the work.
Best,
AC
Thank you for the reminder - we had indeed not seen the first message. Below is the requested discussion of the boundedness assumption, with some sufficient conditions on the loss under which it is provably satisfied. We will include this in the appendix of the final version of the manuscript. We are happy to discuss further if you have any questions.
Sufficient conditions
Henceforth, we will denote the full vector of trained parameters by θ.
Sufficient condition 1: Gradient flow with lower bounded loss function
We consider the following gradient flow dynamics with weight decay:

$$\dot{\theta} = -\eta\,(\nabla_{\theta}L + \lambda\,\theta),$$

where η is the learning rate hyperparameter, and λ the weight decay strength.
A sufficient condition on the loss is that it is lower bounded, which is the case for most common losses.
Indeed, the above dynamics are the gradient flow dynamics of the regularized loss $\tilde{L}(\theta) = L(\theta) + \frac{\lambda}{2}\|\theta\|^2$. Given that $\tilde{L}$ is also lower bounded, and that $\tilde{L}(\theta_t)$ is a monotonically decreasing function of time, $\tilde{L}(\theta_t)$ must converge to a constant real value, i.e., $\tilde{L}(\theta_t) \to c$ for some $c \in \mathbb{R}$. If $\|\theta_t\|$ is unbounded from above, then necessarily $L(\theta_t) = \tilde{L}(\theta_t) - \frac{\lambda}{2}\|\theta_t\|^2$ is unbounded from below, which is a contradiction.
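Written out, the bound implicit in this argument (using the λ/2 convention above) is:

$$\frac{\lambda}{2}\|\theta_t\|^2 = \tilde{L}(\theta_t) - L(\theta_t) \le \tilde{L}(\theta_0) - \inf_{\theta} L(\theta) < \infty,$$

which holds for all t, since $\tilde{L}$ is non-increasing along the flow and L is lower bounded.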
Sufficient condition 2: Gradient flow with momentum and Lipschitz gradient
We consider the following gradient flow dynamics with momentum and decoupled weight decay:

$$\dot{\theta} = -\eta\,(m + \lambda\,\theta), \qquad \dot{m} = \beta\,(\nabla_{\theta}L - m),$$

where m is the exponential average of the gradient of L, β is the momentum hyperparameter, η the learning rate, and λ the weight decay strength.
A sufficient condition on the gradient is that it be K-Lipschitz with respect to the parameters sufficiently far from the origin, i.e., for $\|\theta\|, \|\theta'\| \ge R$ for a given R:

$$\|\nabla_{\theta}L(\theta) - \nabla_{\theta}L(\theta')\| \le K\,\|\theta - \theta'\|.$$
The momentum makes the analysis of the dynamics more complicated. However, defining a joint variable z that combines suitably rescaled copies of θ and m, one can rewrite the two equations above as a single flow in z. The derivative of the squared norm of z then verifies a dissipation property: outside a ball of sufficiently large radius M, the weight decay term dominates the (at most linearly growing) gradient term. The dynamics of z are flow dynamics: whenever ‖z‖ reaches M, its norm is decreasing, so ‖z‖ can never exceed max(‖z_0‖, M). As this is an upper bound on ‖z‖, the same holds for ‖θ‖. We also observe that the Lipschitz condition does not need to hold for all θ. In fact, it suffices that it holds on any borderless submanifold of codimension 1 (for example, the sphere of radius R) that contains the initialization point.
Note on stochasticity
To deal with stochasticity, we consider similar equations, with the minibatch noise modeled by an additional Brownian term, and work with their integral form. The process can diverge because of the stochasticity. However, similarly to Proposition B.2, we can fix, for any δ > 0, an upper bound on the noise contribution that holds with probability at least 1 - δ; this allows us to bound the contribution of the stochasticity. We deal with the gradient component similarly to the non-stochastic proof.
Pathological examples where the boundedness does not hold
An example of a loss for which the parameters will diverge arises when we allow losses to be negative and to diverge to minus infinity faster than the weight regularization term can compensate.
An obvious, albeit contrived, such loss is L(AB^\top) = -\|AB^\top\|_F^2. Then, the gradient of L w.r.t. A (resp. B) is -2(AB^\top)B (resp. -2(AB^\top)^\top A). Even with the decay term, one can see that if A and B are initialized to be, e.g., orthogonal matrices scaled by some sufficiently large α, both will diverge to infinity.
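A quick sanity check of this divergence (a sketch; the flow convention $\dot{\theta} = -\nabla L - \lambda\theta$ and the constants are illustrative, and the specific loss above is our reconstruction): taking A = B = αQ with Q orthogonal gives AB^\top = α² I, so the flow reduces to a scalar equation for α,

$$\dot{\alpha} = 2\alpha^3 - \lambda\alpha = \alpha\,(2\alpha^2 - \lambda),$$

which diverges whenever the initialization satisfies $\alpha > \sqrt{\lambda/2}$, despite the weight decay.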
Such negatively unbounded objective functions may be found in, e.g., the reinforcement learning setting, when using undiscounted returns.
One can come up with an increasing set of sufficient conditions for the stochastic dynamics with momentum to result in bounded parameters, capturing more and more losses. However, in practical scenarios, there exist losses for which boundedness provably does not hold, yet for which the hyperparameter tuning done by practitioners still allows for stable dynamics. Our theorem still holds in those scenarios, which we do not want to exclude.
The boundedness assumption is thus of wide practical relevance, since it describes most practical training with weight decay, and we make use of this assumption to make a statement about the rank regularization effect that affects these trainings.
Dear Authors,
As I am not sure whether the above comment triggered a notification, I am posting this further comment to trigger one.
Best,
AC
This paper studies the rank-sparsity-inducing effect of the weight decay regularizer applied to factor matrices, with a particular focus on transformer architectures. It also proves the convergence of the gradient flow in this case. Theorem 3.3 shows that the local minima of (2) are the same as those of (8) (the only difference with existing works [1,2,3] is the arbitrary loss L; the techniques are by now classic). Theorem 3.4 shows that if the matrices A and B remain bounded during training, the gap between the two regularized losses converges to zero exponentially during gradient flow training. Interestingly, the authors argue that such low-rank effects are undesirable in transformer architectures, and demonstrate that this type of regularization harms performance in many cases.
The reviewers are mostly positive about the paper. One exception is reviewer vYXP, who is concerned by the similarity to existing literature, including [1]. I share this concern generally, as similar results have been shown in [1,2,3]. However, I like the fact that this paper considers a completely arbitrary loss function over the matrix, which makes the results slightly more general and elegant. In addition, I find some of the proofs to be of independent interest, such as the quantitative result in Proposition B.4. I did also have concerns about the "assumption" that A and B remain bounded during training in Theorem 3.4; however, I really appreciate that the authors have been very reactive during their discussion with me, and provided additional results which show examples of loss functions where this assumption can be shown to hold, as well as examples where it does not. Overall, this paper makes an interesting addition to the existing literature.
Comments for the camera-ready:
Please incorporate the interesting analysis you provided in your comment (https://openreview.net/forum?id=oDeqjIM9Sk&noteId=w5ZHaJbJ1D) in the revision. I believe this greatly improves the quality of the work, since the originality needs a little push to be just above the borderline. In addition, I also wonder whether you could provide a generalization of Proposition B.14 and Theorems 3.3 and 3.4 to the case of deep matrix completion (Schatten-p quasi-norm regularization), though I understand this is mostly a tangential question, since it is not relevant to transformers.
Please fix the citations: none of the venues are showing up in your bibliography. For instance, paper [1] was published at ICLR 2023 (notable, top 25 percent). There are other closely related works which are not even cited, e.g., [3]. It might also be advisable to expand the discussion of related works by clarifying the difference in perspective between the present paper and the rest of the literature: in most of the literature, be it on neural networks, matrix completion, or deep matrix completion, the rank-sparsity-inducing effect is a positive thing, which results in favorable generalization guarantees. In the present paper, the emphasis is on the way it can harm representation power in LLMs.
References
[1] Arthur Jacot. "Implicit Bias of Large Depth Networks: A Notion of Rank for Nonlinear Functions." ICLR 2023.
[2] Zhen Dai, Mina Karzand, Nathan Srebro. "Representation Costs of Linear Neural Networks: Analysis and Design." NeurIPS 2021.
[3] Zihan Wang, Arthur Jacot. "Implicit Bias of SGD in L2-Regularized Linear DNNs: One-Way Jumps from High to Low Rank."