SDEs for Adaptive Methods: The Role of Noise
We derive novel SDEs for SignSGD, RMSprop(W), and Adam(W), providing a more accurate theoretical understanding of their dynamics, convergence, and robustness. We validate our findings with experiments on various neural network architectures.
Abstract
Reviews and Discussion
This work derives SDEs for adaptive gradient methods and studies the role of gradient noise. The analysis starts by theoretically deriving the SDE for SignSGD and highlights its significant difference from SGD. The work further generalizes the SDE analysis to AdamW and RMSpropW, two popular adaptive optimizers with decoupled weight decay, and reveals key properties of weight decay. Finally, the work integrates the derived SDEs with the Euler-Maruyama scheme to confirm that the SDEs faithfully track their respective optimizers on various modern neural networks.
Strengths
-The theoretical results are novel. To my knowledge, this is the first SDE analysis of SignSGD with quantitatively accurate descriptions.
-The theoretical analysis reports some novel properties in terms of gradient noise and convergence. These properties are interesting.
-The proofs seem complete and reasonable.
-A useful theory should be quantitatively verifiable, and this work definitely achieves this. The experiments showing that the SDEs fit the empirical results for various optimizers and models are informative and impressive.
Weaknesses
-It seems that the reported theoretical results and insights cannot directly lead to some theory-inspired and improved methods. This raises a question about the significance of this work.
-While this work did a literature review, some important references are still missing, such as [1], which analyzes Adam using SDEs. As weight decay plays a key role in the results, it may be helpful to review recent papers analyzing novel or overlooked properties of weight decay.
Reference:
[1] Xie, Z., Wang, X., Zhang, H., Sato, I., & Sugiyama, M. (2022). Adaptive inertia: Disentangling the effects of adaptive learning rate and momentum. In International Conference on Machine Learning (pp. 24430-24459). PMLR.
Questions
-
Please see the weaknesses.
-
Could you please explain more how L2 regularization and decoupled weight decay behave differently in your results?
Limitations
This work discussed the limitations in the appendix.
We sincerely thank the Reviewer for the significant effort put into this review: We appreciate the acknowledgement of the value of our research. We also thank you for the questions, as they stimulated us to include more references and to dig deeper to showcase the explanatory power of our SDEs even more.
Weakness 1:
"It seems that the reported theoretical results and insights cannot directly lead to some theory-inspired and improved methods. This raise a question on the significance of this work."
Answer:
We acknowledge that our work has limitations in terms of developing improved methods. However, we aimed to offer new insights into existing adaptive methods that are known to perform well in practice, even though the reasons for their effectiveness are not yet fully understood. We respectfully believe that, from this perspective, our work holds significant value and is of interest to the community.
Weakness 2:
"While this work did literature review, some important references are still missing, such as [1] on analyzing Adam using SDEs. As weight decay plays a key role in the results, it may be helpful to review recent papers analyzing novel or overlooked properties of weight decay."
Answer:
We thank the Reviewer for reminding us about this interesting paper, which we are familiar with but unfortunately forgot to cite. Rather than studying the role of noise in the dynamics of Adam, its focus is mainly on disentangling the effects of learning-rate adaptivity and momentum on saddle-point escaping and flat-minima selection. The authors use an SDE to study how momentum helps SGD escape saddle points and minima. Analogously, they repeat the analysis for Adam and find that learning-rate adaptivity helps to escape saddle points but leads to sharper minima than SGD. Inspired by their results (Thm. 2 and Thm. 3), they propose Adai (Adaptive Inertia Optimization), which, rather than opting for adaptivity of the learning rate, opts for adaptivity of the momentum parameters: They theoretically predict and experimentally validate that Adai is both fast at escaping saddles and successful at finding flat minima.
As highlighted by Malladi et al., the SDE presented in [1] is not derived within any formal framework and therefore does not come with formal approximation guarantees. However, this is a very insightful and valuable work that we will cite in the final version of the paper. We are of course happy to know if there is a specific point about [1] that we should pay attention to, or if other important references have been missed.
Regarding Weight Decay, we kindly request that the Reviewer provide us with any specific references they have in mind.
Question 2:
"Could you please explain more how L2 regularization and decoupled weight decay behaves differently in your results?"
Answer:
This is a very interesting question: Please find below the SDE induced by using Adam on an $\ell_2$-regularized loss, together with the equivalent of Lemma 3.13. Most importantly, we observe that regularization used in this way does not provide additional resilience against noise w.r.t. Adam: The asymptotic loss level scales linearly in the noise level $\sigma$, exactly as it does for Adam. On the contrary, when regularization is used in a decoupled way, as in AdamW, the asymptotic loss level is upper-bounded in $\sigma$.
When Adam is used to optimize the $\ell_2$-regularized loss with regularization strength $\gamma > 0$, the SDE of the method is:
$
d X_t =-\frac{\sqrt{\gamma_2(t)}}{\gamma_1(t)} P_t^{-1} (M_t + \eta \rho_1 \left(\nabla f\left(X_t\right)-M_t\right) - \gamma X_t) d t
$
$
d M_t =\rho_1\left(\nabla f\left(X_t\right)-M_t\right) d t+\sqrt{\eta} \rho_1 \Sigma^{1 / 2}\left(X_t\right) d W_t
$
$
d V_t =\rho_2\left( (\nabla f(X_t))^2 + diag\left(\Sigma\left(X_t\right)\right)-V_t\right) d t,
$
where the remaining quantities are defined as in the SDE of Adam in the main paper and $\gamma > 0$ is the regularization coefficient.
Under the same assumptions as Lemma 3.5, the dynamics of Adam on an $\ell_2$-regularized loss implies that
$
\mathbb{E}[f(X_t) - f(X_*)] \overset{t \rightarrow \infty}{\leq} \frac{\eta \mathcal{L}_\tau \sigma }{2} \frac{L}{2 \mu L + \gamma (L + \mu)},
$
meaning that the asymptotic loss level grows linearly in $\sigma$, as it already does for Adam.
Much differently, the asymptotic loss level for AdamW is
$
\mathbb{E}[f(X_t) - f(X_*)] \overset{t \rightarrow \infty}{\leq} \frac{\eta \mathcal{L}_\tau \sigma }{2} \frac{L}{2 \mu L + \sigma \gamma (L + \mu)},
$
which is upper-bounded in $\sigma$.
Please, find an empirical validation of this bound in Figure 2 of the attached .pdf file.
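In addition, a minimal Python sketch (with arbitrary illustrative constants $\eta$, $\mathcal{L}_\tau$, $L$, $\mu$, $\gamma$, not taken from our experiments) makes the contrast between the two asymptotic levels visible as $\sigma$ grows:

```python
import numpy as np

# Arbitrary illustrative constants (not from the paper's experiments)
eta, L_tau, L, mu, gamma = 0.01, 1.0, 10.0, 0.1, 0.1

def adam_l2_level(sigma):
    # Asymptotic loss level for Adam on the l2-regularized loss: linear in sigma
    return (eta * L_tau * sigma / 2) * L / (2 * mu * L + gamma * (L + mu))

def adamw_level(sigma):
    # Asymptotic loss level for AdamW: sigma also enters the denominator,
    # so the level saturates as sigma grows
    return (eta * L_tau * sigma / 2) * L / (2 * mu * L + sigma * gamma * (L + mu))

for sigma in [0.1, 1.0, 10.0, 100.0, 1000.0]:
    print(f"sigma={sigma:7.1f}  Adam+l2: {adam_l2_level(sigma):10.4f}  AdamW: {adamw_level(sigma):8.4f}")
# Adam+l2 grows without bound, while AdamW approaches eta*L_tau*L / (2*gamma*(L+mu)).
```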
Thanks for the rebuttal and addressing some of the concerns.
I will keep the rating as 6. I tend to accept this work.
Dear Reviewer,
Thank you for your response.
We are glad to know that some of your concerns have been resolved: Could you please share any remaining issues or suggestions you might have? Your feedback is invaluable and will assist us in refining our manuscript further.
We appreciate your time and consideration.
Best regards,
The Authors
The authors derive SDEs for SignSGD and Adam(W). The experiments show that the algorithms converge toward the limits that the theorems indicate.
Strengths
The authors propose "accurate" SDEs for algorithms Sign-SGD and Adam(W).
Weaknesses
-
In the Remark after Lemma 3.6, the authors claim that Sign-SGD is (almost) linear in $\sigma_{\max}$. However, with the $\Delta$ appearing in either Phase 2 or Phase 3, there should be a $\sigma_{\max}^2$ term in the final bound.
-
All the stationarity results hold when the Hessian is constant over time, and convergence holds only in the strongly convex case. However, the Hessian changes a lot during network training.
Questions
-
The notation $W_t$ is not defined. What is $W_t$?
-
How can we extend Lemma 3.13 to the convex setting (or even the nonconvex case)?
Limitations
N/A
We sincerely thank the Reviewer: We appreciate the questions as they stimulated us to clarify certain aspects and dig deeper to showcase the explanatory power of our SDEs even more.
Weakness 1:
"In Remark after Lemma 3.6, the authors claim that Sign-SGD is (almost) linear in . However, with either in Phase 2 or Phase 3, there should be in the final bound."
Answer:
We fully agree with this observation, which is why we say that the dependence is "almost linear" in $\sigma_{\max}$. We can rewrite the asymptotic loss level as:
$
\frac{\eta}{2} \frac{\mathcal{L}_{\tau}}{ 2 \mu } \frac{1}{\Delta},
$
and observe that
$
\frac{1}{\Delta} =\frac{\pi \sigma_{\text{max}}^2 }{\sqrt{2 \pi} \sigma_{\text{max}} + \eta \mu} = \frac{\pi \sigma_{\text{max}} }{\sqrt{2 \pi} + \frac{\eta \mu}{\sigma_{\text{max}}}}.
$
Therefore, when the noise dominates over the learning rate and/or over the minimum eigenvalue of the Hessian, or more generally when $\eta \mu \ll \sqrt{2 \pi}\, \sigma_{\max}$, we can conclude that the behavior is essentially linear in $\sigma_{\max}$: We will most certainly clarify this aspect better in the final version of the paper.
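As a quick numerical check (with arbitrary illustrative values $\eta = 10^{-3}$ and $\mu = 1$, chosen only for this sketch), one can see how fast $\frac{1}{\Delta}$ approaches the purely linear proxy $\frac{\pi \sigma_{\max}}{\sqrt{2\pi}}$:

```python
import numpy as np

eta, mu = 1e-3, 1.0  # arbitrary illustrative values

def one_over_delta(sigma_max):
    return np.pi * sigma_max**2 / (np.sqrt(2 * np.pi) * sigma_max + eta * mu)

for s in [0.01, 0.1, 1.0, 10.0]:
    ratio = one_over_delta(s) / (np.pi * s / np.sqrt(2 * np.pi))  # vs. the purely linear proxy
    print(f"sigma_max={s:6.2f}  1/Delta={one_over_delta(s):9.5f}  / linear proxy = {ratio:.4f}")
# The ratio tends to 1 as soon as sqrt(2*pi)*sigma_max dominates eta*mu.
```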
Weakness 2:
"All the stationarity holds when Hessian is the same from to and convergence holds for strongly convex. However, the hessian changes a lot during network training."
Answer:
We agree that the Hessian of the loss function can change dramatically during training. However, as we specify in Line 186, we are not studying the properties of the iterates during training, but rather characterizing the stationary distribution around minima: These are the only points where the optimizer can reach stationarity and possibly stop. With this in mind, as we specify in Lines 125 to 128, it is common in the literature to approximate the loss function with a quadratic function in a neighborhood around these points. Therefore, the Hessian is constant in this neighborhood. These two reasons justify why we only study the stationary distribution of SignSGD in Phase 3 for a quadratic loss function. We also add that whatever happens before Phase 3 does not influence what happens at convergence, e.g. the stationary distribution.
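Concretely, the quadratic approximation we have in mind is the standard second-order expansion around the minimum $X_*$:
$
f(x) \approx f\left(X_*\right) + \frac{1}{2}\left(x - X_*\right)^{\top} H \left(x - X_*\right), \qquad H := \nabla^2 f\left(X_*\right).
$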
In response to the second part of your comment, we have strengthened our convergence analysis beyond the strongly convex case. Specifically, we extended Lemma 3.5 to the general smooth non-convex case (i.e. only requiring $L$-smoothness):
Let $f$ be $L$-smooth and $\phi_t$ be a learning rate scheduler such that $\phi^1_t \rightarrow \infty$ and $\frac{\phi^2_t}{\phi^1_t} \rightarrow 0$, where $\phi^1_t := \int_0^t \phi_s \, d s$ and $\phi^2_t := \int_0^t \phi_s^2 \, d s$. Then, during
- Phase 1, ;
- Phase 2,
- Phase 3, ; where , , , and are random times with distribution .
Interestingly, in Phase 1 SignSGD implicitly minimizes the $\ell_1$-norm of the gradient, in Phase 2 it implicitly minimizes a linear combination of $\lVert \nabla f \rVert_1$ and $\lVert \nabla f \rVert_2^2$, and in Phase 3 it implicitly minimizes $\lVert \nabla f \rVert_2^2$: This result is novel as well, and we thank the Reviewer for asking this great question.
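As a quick sanity check of the Phase 1 statement (neglecting the diffusion term and assuming that the Phase 1 drift is $-\operatorname{sign}(\nabla f(X_t))$, i.e. the low-noise limit of the SDE), we have
$
\frac{d}{d t} f\left(X_t\right) = \left\langle \nabla f\left(X_t\right), \dot{X}_t \right\rangle = - \left\langle \nabla f\left(X_t\right), \operatorname{sign}\left(\nabla f\left(X_t\right)\right) \right\rangle = - \lVert \nabla f\left(X_t\right) \rVert_1.
$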
Question 1:
"Notation of is not defined. What is ?"
Answer:
We apologize for not defining this symbol in the main paper. $W_t$ is the Brownian motion, and we will specify this clearly in the final version of the paper. Importantly, we highlight that we included a whole chapter on Stochastic Calculus in Appendix B.
Question 2:
"How can we extend Lemma 3.13 to convex setting (or even nonconvex case)?"
Answer:
Due to some technical issues with AdamW that we will address in the future, we now put forward a generalization of Lemma 3.13 for Adam where we only require $L$-smoothness:
Let $f$ be $L$-smooth and $\phi_t$ be a learning rate scheduler such that $\phi^1_t \rightarrow \infty$ and $\frac{\phi^2_t}{\phi^1_t} \rightarrow 0$, where $\phi^1_t := \int_0^t \phi_s \, d s$ and $\phi^2_t := \int_0^t \phi_s^2 \, d s$. Then
$
\mathbb{E} \lVert \nabla f \left(X_{\tilde{t}} \right) \rVert_2^2 \leq \left[ f(X_0) - f(X_*) + \mathcal{L}_{\tau} \left( \frac{ \delta B}{\rho_1^2 \sigma^2} \frac{\lVert M_0 \rVert_2^2}{2} + \frac{\phi^2_t \eta \kappa^2}{2} \right) \right] \frac{\sigma}{\kappa \sqrt{\delta B}} \frac{1}{\phi^1_t} \overset{t \rightarrow \infty}{\rightarrow} 0
$
where $\tilde{t}$ is a random time with distribution $\frac{\phi_t}{\phi^1_t}$.
Thanks for the authors' response. I have no further questions. Since the authors claim to add clarification in the final version, I raise my score to 6.
Dear Reviewer,
Thank you for your trust and the updated score: We truly appreciate it.
Best regards,
The Authors
This paper derives SDEs for SignSGD, RMSprop, and Adam. The analysis offers insights into the convergence speed, stationary distribution, and robustness to heavy-tail noise of adaptive methods.
Strengths
-
The derived SDE for SignSGD exhibits three different phases of the dynamics.
-
The analysis reveals the difference between SignSGD and SGD in terms of the asymptotic expected loss, the robustness of noise variance, etc.
-
The analysis of AdamW provides insights into the different roles of noise, curvature, and weight decay.
Weaknesses
Refer to Questions and Limitations.
Questions
-
What learning rate (lr) do the experiments in Figure 4 use? Within what range of lr does this SDE align well with the original algorithm (experimentally)?
-
Could the authors intuitively explain why the asymptotic expected loss of SignSGD is proportional to $\sigma$ instead of $\sigma^2$?
-
How can the derived SDE explain the loss spike phenomenon of SignSGD/AdamW?
-
Many works about SGD noise ([1][2][3]) admit the noise structure . What conclusions (such as those related to the training phases) can be derived from the SDE if we change the noise assumption in Corollary 3.3 to ?
[1] Ziyin et al. Strength of minibatch noise in SGD.
[2] Wojtowytsch. Stochastic gradient descent with noise of machine learning type. part II: Continuous time analysis.
[3] Wu et al. The alignment property of SGD noise and how it helps select flat minima: A stability analysis.
Limitations
The SDE for AdamW is limited to quadratic functions.
Given the length limit for the Rebuttal, we decided to include these minor results in an Official Comment.
Continuation of Answer to Q4 - The Third Noise Structure:
[2] discusses two possible assumptions on the noise covariance $\Sigma(x)$. Even though neither was exactly in line with the prescription of the Reviewer, we still thought that the one they use, i.e. a covariance that scales with the loss as per their Section 2.4, is interesting. Therefore, we adopt it, changing the constant to $\sigma$ to maintain consistency with the rest of our paper.
Under this assumption, we have that, for $Y_t$ defined accordingly, Corollary 3.3 becomes:
$
d X_t = - Erf \left( Y_t \right) dt + \sqrt{\eta} \sqrt{I_d - diag(Erf \left(Y_t \right))^2} d W_t.
$
As a consequence, Lemma 3.5 becomes:
Let $f$ be $\mu$-strongly convex and $S_t := f(X_t) - f(X_*)$. Then, during
- Phase 1, the loss will reach before because ;
- Phase 2 with and ,
$
\mathbb{E}[S_t] \leq \frac{\beta^2 \left( \mathcal{W}\left( \frac{\beta + \sqrt{S_0}\, \alpha}{\beta} \exp\left(-\frac{\alpha^2 t - 2 \sqrt{S_0}\, \alpha}{2 \beta} - 1 \right) \right) + 1 \right)^2}{\alpha^2} \overset{t \rightarrow \infty}{\rightarrow} \frac{\beta^2}{\alpha^2},
$
where $\mathcal{W}$ is the Lambert function;
- Phase 3, it is the same as Phase 2 but with different values of $\alpha$ and $\beta$.
Please, find an empirical validation of these bounds in Figure 1.c of the attached .pdf file.
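For completeness, a small Python check (with arbitrary placeholder values for $\alpha$, $\beta$, and $S_0$, chosen only for this sketch) confirms that the Phase 2 expression indeed decays towards $\beta^2/\alpha^2$:

```python
import numpy as np
from scipy.special import lambertw

# Placeholder values; alpha, beta, S0 stand for the quantities in the Phase 2 bound above
alpha, beta, S0 = 0.5, 0.2, 4.0

def phase2_bound(t):
    z = (beta + np.sqrt(S0) * alpha) / beta * np.exp(-(alpha**2 * t - 2 * np.sqrt(S0) * alpha) / (2 * beta) - 1)
    w = np.real(lambertw(z))  # principal branch of the Lambert W function
    return beta**2 * (w + 1) ** 2 / alpha**2

for t in [0, 10, 100, 1000]:
    print(f"t={t:5d}  bound={phase2_bound(t):.5f}")
print("limit beta^2/alpha^2 =", beta**2 / alpha**2)
```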
Continuation of Answer to Q3 - Reviewer's curiosity: Conjecture on Spiking Phenomena:
This is a very interesting question that we do not address in this paper. While this is not a fundamental element for the flow and contribution of our paper, we gladly try to answer it, both for our and the Reviewer's curiosity.
Although we cannot answer this in the general case, we offer the following conjecture to provide an intuition of how one could explain the spiking behavior of the mentioned optimizers.
Since the SDE of RMSprop is simpler and easier to work with, we restrict ourselves to this case: Generalizing is only a matter of technicalities.
Let us recall that the SDE of RMSprop is
$
d X_t = - \frac{\nabla f\left(X_t\right)}{\sqrt{V_t}} d t + \sqrt{\eta}\, \frac{\Sigma^{1 / 2}\left(X_t\right)}{\sqrt{V_t}} d W_t, \qquad d V_t = \rho_2\left( (\nabla f(X_t))^2 + diag\left(\Sigma\left(X_t\right)\right)-V_t\right) d t,
$
where $\Sigma\left(X_t\right)$ is the covariance of the gradient noise and the division by $\sqrt{V_t}$ is intended coordinate-wise.
Intuitively, the dynamics of the parameters $X_t$ is a preconditioned version of SGD. Much differently, $V_t$ is a process that tracks the squared gradient and its noise.
This implies that the expected iterates follow the dynamics:
$
d \mathbb{E}[X_t] = - \mathbb{E}\left[\frac{\nabla f(X_t)}{\sqrt{V_t}}\right] dt.
$
Consistently with the noise structure proposed by the Reviewer and used in many papers (see [3] and references therein), let us assume that the covariance of the noise scales proportionally to the loss function, e.g. $\Sigma\left(X_t\right) \propto f\left(X_t\right) I_d$.
Spikes seem to happen when the loss is essentially $0$, meaning that $\nabla f(X_t) \approx 0$, $\Sigma(X_t) \approx 0$, and $V_t \approx 0$. However, if we now draw a minibatch of data for which the gradient is not $0$, e.g. some data points that are outliers, $V_t$ might not have the time to "catch up" with this anomaly. Therefore, the numerator is non-$0$ while the denominator is still essentially $0$, meaning that the ratio spikes to infinity, drastically disturbing the dynamics of the iterates and, in turn, that of the loss function, which might spike.
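The following toy Python simulation (a 1-D quadratic loss with hand-picked constants, purely meant to illustrate the conjecture, not to reproduce RMSprop exactly) shows how a rare outlier gradient can make the preconditioned update $\nabla f / \sqrt{V}$ spike once $V_t$ has decayed to almost $0$:

```python
import numpy as np

rho2, eps, lr = 0.1, 1e-8, 0.01    # illustrative constants, not taken from the paper
x, v = 0.0, 1.0                    # 1-D quadratic loss f(x) = x^2 / 2, so grad(x) = x
history = []

for step in range(300):
    grad = x
    if step == 250:
        grad += 5.0                          # a single rare "outlier" minibatch gradient
    # We deliberately use the pre-update v to mimic the lag of V_t in continuous time
    update = grad / (np.sqrt(v) + eps)       # numerator jumps, denominator has decayed to ~0
    v = (1 - rho2) * v + rho2 * grad**2      # V only catches up at rate rho2, i.e. too late
    x -= lr * update
    history.append(abs(update))

print("max |update| before the outlier:", max(history[:250]))
print("|update| at the outlier step   :", history[250])
```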
Many thanks to the authors for your careful explanation and detailed rebuttal. I feel that this paper is of great help in understanding signSGD/Adam. I have raised my score.
Dear Reviewer,
Thank you for your kind words and the updated score: We truly appreciate it.
Best regards,
The Authors
We thank the Reviewer for their thorough and thoughtful review. We appreciate the questions posed, as they motivated us to delve deeper and further showcase the explanatory power of our SDEs. However, we would like to clarify that, contrary to what is mentioned under "Limitations", none of our SDEs is limited to quadratic functions: The theory applies to general smooth functions. We conducted extensive experimental validation showing that our SDEs correctly model the respective algorithms on a variety of architectures and datasets (see Figures 1, 4, 8, 9, and 11 and the respective experimental details in Appendix F).
Answers to Q1:
- As per Appendix F.5, the learning rates (lrs) used for AdamW are for the Transformer and for the ResNet. For RMSpropW, they are and , respectively: We will add these details in the caption of the figures;
- In our experiments, we first fine-tuned the hyperparameters to ensure the convergence of the "real" optimizers. Then, we used the same hyperparameters to simulate the SDEs. While we did not ablate the range of the lr over which the SDEs align well with the algorithms, our experiments use a wide range of lrs across different datasets and architectures: From to for SignSGD, from to for RMSprop(W), and from to for Adam(W). Our SDEs match the respective algorithms well in all such cases.
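For reference, our SDE simulations rely on the standard Euler-Maruyama scheme; below is a minimal Python sketch for a generic 1-D SDE $d X_t = b(X_t) \, d t + \sqrt{\eta} \, s(X_t) \, d W_t$, where the drift `b` and diffusion `s` used in the example are placeholders inspired by the SignSGD SDE, not the exact coefficients from our theorems:

```python
import numpy as np
from scipy.special import erf

def euler_maruyama(b, s, x0, eta, dt, n_steps, seed=0):
    """Simulate dX_t = b(X_t) dt + sqrt(eta) * s(X_t) dW_t with step size dt."""
    rng = np.random.default_rng(seed)
    x = np.array(x0, dtype=float)
    path = [x.copy()]
    for _ in range(n_steps):
        dw = rng.normal(0.0, np.sqrt(dt), size=x.shape)   # Brownian increment
        x = x + b(x) * dt + np.sqrt(eta) * s(x) * dw
        path.append(x.copy())
    return np.array(path)

# Example: SignSGD-like drift on a 1-D quadratic with constant noise level sigma (placeholders)
sigma = 1.0
path = euler_maruyama(
    b=lambda x: -erf(x / (np.sqrt(2) * sigma)),
    s=lambda x: np.sqrt(1.0 - erf(x / (np.sqrt(2) * sigma)) ** 2),
    x0=[2.0], eta=0.01, dt=0.01, n_steps=5000,
)
print("final iterate:", path[-1])
```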
Answer to Q2:
In SGD, the error/noise on the update scales with $\sigma^2$. In SignSGD, the $\operatorname{sign}$ operator clips the stochastic gradient and hence it also clips its noise: This clipping/normalization implies that this error scales with $\sigma$.
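A toy 1-D experiment (quadratic loss, Gaussian gradient noise, hand-picked step size; only meant to illustrate the scaling, not to reproduce any result in the paper) makes this difference visible:

```python
import numpy as np

def asymptotic_loss(sigma, sign_update, eta=0.01, n_steps=100000, seed=0):
    """Average late-time loss of (Sign)SGD on f(x) = x^2 / 2 with N(0, sigma^2) gradient noise."""
    rng = np.random.default_rng(seed)
    x, losses = 0.0, []
    for t in range(n_steps):
        g = x + sigma * rng.normal()
        x -= eta * (np.sign(g) if sign_update else g)
        if t > n_steps // 2:
            losses.append(0.5 * x**2)
    return np.mean(losses)

for sigma in [0.5, 1.0, 2.0, 4.0]:
    print(f"sigma={sigma:4.1f}  SGD: {asymptotic_loss(sigma, False):8.5f}  "
          f"SignSGD: {asymptotic_loss(sigma, True):8.5f}")
# Doubling sigma roughly quadruples the SGD level (~sigma^2) but only doubles the SignSGD level (~sigma).
```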
Answer to Q3:
We attempted to address this question while writing the paper, but we were unable to formally explain these phenomena. To satisfy both our curiosity and that of the reviewer, we offer our conjecture in an Official Comment, providing some technical details.
Answer to Q4:
Since it was unclear which assumption was precisely meant, we have read the references and selected three noise structures: We study two below and the third one in an Official Comment.
Under these assumptions, we generalized Cor. 3.3 and provided convergence in the same fashion as Lemma 3.5. Additionally, see the Answer to Question 2 from Reviewer PBh6 for a generalized version of Cor. 3.3 where we only require the loss function to be -smooth.
Assumption from [1]
[1] proposes several expressions for $\Sigma(x)$: We took the only one in line with that prescribed by the Reviewer, i.e. the one in Eq. (16) of their Corollary 2, where $\sigma$ controls the scale of the noise and $X_*$ is an optimum.
Therefore, for and , Cor. 3.3 becomes:
Therefore, Lemma 3.5 becomes: Let $f$ be $\mu$-strongly convex and $\lambda_{\max}$ be the largest eigenvalue of the Hessian. Then, during
- Phase 1, the loss will reach before because ;
- Phase 2 with : ;
- Phase 3 with : .
Please, find an empirical validation in Figure 1.a of the attached .pdf file.
Assumption from [3]
[3] assumes that $\Sigma(x)$ is aligned with the FIM and proportional to the loss. Consistently with this and with the prescription of the Reviewer, we take a noise covariance of this form, changing the constants to $\sigma$ to maintain consistency with the rest of our paper.
Therefore, we have that for and , Cor. 3.3 becomes:
Therefore, Lemma 3.5 becomes: Let $f$ be $\mu$-strongly convex and $L$-smooth. Then, during
- Phase 1, the loss will reach before because ;
- Phase 2 with and ,
where $\mathcal{W}$ is the Lambert function;
- Phase 3, it is the same as Phase 2 but with different values of $\alpha$ and $\beta$.
Please, find an empirical validation in Figure 1.b of the attached .pdf file.
Dear Reviewers,
We sincerely appreciate your thorough reviews, insightful comments, and interesting questions regarding our paper: Your feedback has helped enhance our work.
The considerable time and effort we devoted during this rebuttal period were rewarding, as we derived new insights that complemented our paper and made it even more interesting and rich.
We are pleased to report that we have addressed your questions and comments comprehensively, exploring new settings as a result: These responses are detailed in our rebuttals to each of the Reviewers and will be incorporated into the final version of the paper.
We look forward to the upcoming author-reviewer discussion period and kindly ask you to re-evaluate our paper, considering raising your scores and confidence in your assessments.
Thank you for your attention.
Best regards,
The Authors
This paper introduces novel Stochastic Differential Equations (SDEs) for adaptive optimization methods such as SignSGD, RMSProp(W), and Adam(W). The authors provide a detailed analysis of these methods, focusing on their dynamics, convergence behavior, and robustness to noise. The theoretical findings are supported by experiments on various neural network architectures, confirming the accuracy of the derived SDEs in modeling the behavior of these optimizers, which are appreciated by the reviewers.
However, I have concerns about the correctness of one of the main theoretical results in this paper, Theorem 3.2 (or its formal version, Theorem C.5 in the appendix). Theorem C.5 mainly says that SignSGD (equation 18) and the SDE (equation 16) are an order-1 weak approximation of each other. However, for the setting studied in this paper, i.e., batches sampled i.i.d. from the uniform distribution over a training set of size $N$, both the drift term and the diffusion matrix in equation 16 are not continuous, because the stochastic gradient only takes values in a finite set of discrete values. Therefore even the uniqueness and existence of a solution of (16) is not obvious, and the proof framework introduced by [Li et al., 17] does not apply here. This mistake is also reflected in the proof of Theorem 3.5, which applies Lemma 3.3 without satisfying its conditions (the drift and diffusion matrix need to be Lipschitz continuous).
With that being said, this issue could probably be fixed by considering only an infinite-size training dataset or, equivalently, a noise distribution with a continuous/smooth density, like the Gaussian distribution. However, even in that case, the proof framework of [Li et al., 17] (which is further refined in [Li et al., 19]) does not apply to the SDE approximation of SignSGD directly. Therefore I think the current version is not ready to be published and the amount of modification needed is significant, requiring another round of review. I encourage the authors to fix the statement and the proof and resubmit to another venue.
-
Li, Q., Tai, C. and Weinan, E., 2017, July. Stochastic modified equations and adaptive stochastic gradient algorithms. In International Conference on Machine Learning (pp. 2101-2110). PMLR.
-
Li, Q., Tai, C. and Weinan, E., 2019. Stochastic modified equations and dynamics of stochastic gradient algorithms i: Mathematical foundations. Journal of Machine Learning Research, 20(40), pp.1-47.
I would like to respectfully express my concern regarding the decision to reject this paper:
-
All three reviewers recommended acceptance. However, the paper was rejected by the Area Chair due to a minor issue regarding the regularity of the SDE coefficients, which could have been easily addressed with a straightforward clarification during the rebuttal period.
-
For example, this minor issue is easily addressed by assuming Gaussian noise (a weaker assumption would suffice), which is a widely accepted assumption in the literature, even when considering finite batch sizes.
-
Alternatively, as highlighted in the seminal paper [1], any potential regularity issues could also be resolved using a mollification approach, thereby bypassing the need for specific noise assumptions altogether.
-
With a score of 6.33, the paper is well above the typical acceptance threshold. Rejecting it based on a debatable concern about an assumption — particularly when the theoretical results are strongly validated by experimental evidence — feels inconsistent with the established standards of our field. It risks undermining the fairness and pragmatism that should guide our evaluation process.
Best regards,
Enea Monzio Compagnoni
[1] Li, Qianxiao, Cheng Tai, and E. Weinan. "Stochastic modified equations and dynamics of stochastic gradient algorithms i: Mathematical foundations." Journal of Machine Learning Research 20.40 (2019): 1-47.