PaperHub
Rating: 4.8 / 10 · Decision: Rejected · 5 reviewers
Scores: 3, 3, 6, 6, 6 (min 3, max 6, std 1.5)
Confidence: 3.2 · Correctness: 2.0 · Contribution: 2.4 · Presentation: 2.4
ICLR 2025

Generalization Aware Minimization

OpenReview · PDF
Submitted: 2024-09-27 · Updated: 2025-02-05
TL;DR

We design a generalized version of sharpness aware minimization that directly optimizes the expected test loss landscape, enhancing generalization.

Abstract

Keywords
generalization, sharpness aware minimization, loss landscape, optimization

Reviews and Discussion

Review (Rating: 3)

The paper proposes Generalization Aware Minimization (GAM), a generalized version of SAM. The method is grounded in two theorems: the first presents the relationship between the curvature of the general and empirical losses, and the second shows a sequence of transformations that map the Hessian of the training loss to that of the test loss. The method was tested on CIFAR-10, CIFAR-100, and SVHN using some architectures including

Strengths

The algorithm is based on an interesting theoretical framework that could motivate new insights. Specifically, Theorem 1 demonstrates the relationship between the curvature directions of expected and training losses. Additionally, Theorem 2 presents a series of efficient transformations that align the Hessian of the training loss with that of the test loss. Even though there are certain limitations with the theoretical analysis, I believe that if these limitations were addressed properly, the insights could find applications in various domains.

Weaknesses

Despite the novel insights from the theory, the theoretical results have certain limitations, notably that they are restricted to scenarios where losses are quadratic. While Section 3.4 mentions that the approach could be extended to general loss functions using local quadratic approximations, the analysis would benefit from a deeper investigation of how these results generalize to broader loss functions and what additional assumptions might be necessary for such extensions. Additionally, there may be an issue with Theorem 1; further details are provided in the Questions section.

On the other hand, the experimental evaluation is limited and could be substantially improved. The experiments were conducted on a narrow set of datasets and older architectures, and the method only demonstrates less than a 1% improvement, which is relatively minor. Broadening the scope to include larger datasets, such as CIFAR-100, and more modern architectures, such as WideResNet [1] or EfficientNet [2], would make the experimental results more robust and impactful.

Questions

Here are a few questions I have for the authors:

  • There might be a substantial issue with Theorem 1, where the statement appears somewhat trivial. Specifically, for any values $x$ and $y$ in general, it is possible to express $\mathbb{E}[x \mid y]$ as a function of $y$, allowing us to set $D = 0$ and write $\mathbb{E}[L(\theta) \mid \tilde{\theta}^*, \tilde{M}, \tilde{c}]$ as a scalar function of $\tilde{\theta}^*, \tilde{M}, \tilde{c}$, hence making the result of this theorem trivial. Can you clarify the non-trivial aspects of this theorem; for instance, are there any specific constraints on $D$ or $\tilde{c}$?
  • Could you clarify with a concrete example or a mathematical formulation showing that SAM can be derived as a special case of GAM? I believe that including a brief explanation of this relationship would enhance the clarity of the contribution of the paper.
  • Could you elaborate on the meaning of Equations (16) and (17)? Specifically, does this mean that the right-hand side of Equation (15) is equivalent to both Equations (16) and (17)?
  • Can you provide the results on WN or clarify why these results on WN were omitted?
  • I am also interested in the performance of GAM under different hyperparameter settings. It would be useful to include the details on the sensitivity of GAM with respect to the key hyperparameters, including $\epsilon$ and the number of GAM steps $T$.

References:

[1] Zagoruyko, S., & Komodakis, N. (2016). Wide Residual Networks. In Proceedings of the British Machine Vision Conference (BMVC) 2016.

[2] Tan, M., & Le, Q. V. (2019). EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks. In Proceedings of the 36th International Conference on Machine Learning (ICML) 2019.

Comment

Thank you for your insightful review and for recognizing the potential applications of our theoretical framework.

Concerns on Quadratic Loss Function Assumption

We acknowledge that the assumption of quadratic loss functions limits the direct applicability of our theoretical results. However, we emphasize that any smooth loss function can be locally approximated by a quadratic function around a point of interest. This local approximation is a standard technique in optimization and allows us to extend our insights to the non-convex loss landscapes typical in deep learning.

In practice, GAM leverages this local quadratic approximation to adjust the optimization trajectory, improving generalization even when the global loss landscape is complex.
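
Concretely, the local quadratic model we refer to is the standard second-order Taylor expansion of the loss around a reference point $\theta_0$:

$$
L(\theta) \;\approx\; L(\theta_0) + \nabla L(\theta_0)^\top (\theta - \theta_0) + \tfrac{1}{2}\,(\theta - \theta_0)^\top \nabla^2 L(\theta_0)\,(\theta - \theta_0),
$$

which is accurate in a neighborhood of $\theta_0$ where higher-order terms are negligible.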

Experimental Evaluation

We agree that a broader experimental evaluation would strengthen our paper. In the limited time of the rebuttal period, we conducted additional experiments on the ImageNet dataset (trained for 50 epochs) using a ResNet-20-like architecture. GAM achieved a top-1 test accuracy of 5.12%, compared to 4.89% for SAM at $\gamma=0.01$ and 4.04% for standard SGD (SAM at $\gamma=0.1$: 3.77%, SAM at $\gamma=0.001$: 4.55%). These results suggest that GAM's advantages extend to larger-scale datasets and modern architectures.

Clarification on Theorem 1

We apologize for any confusion regarding Theorem 1. The expectation on the left-hand side of Equation (7) is taken over the randomness in the loss functions, not over $\theta$. Thus, the expected loss remains a function of $\theta$ in general. The first term on the right-hand side captures this dependence on $\theta$, while the second term is independent of $\theta$.

SAM as a Special Case of GAM

Yes, SAM can be seen as a special case of GAM with a single perturbation step ($T=1$) and a fixed perturbation coefficient ($\gamma_1$). GAM generalizes this approach by allowing multiple perturbation steps and learning the perturbation coefficients during training. This flexibility enables GAM to capture higher-order information about the loss landscape.
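
For illustration, here is a minimal NumPy-style sketch of this relationship (simplified pseudocode for this discussion; `grad_fn`, `gammas`, and the exact normalization are placeholders rather than our actual implementation):

```python
import numpy as np

def perturb(theta, grad_fn, gammas, eps=1e-12):
    """Apply T perturbation steps along normalized (minibatch) gradient directions."""
    theta_pert = theta.copy()
    for gamma_t in gammas:                      # len(gammas) == T
        g = grad_fn(theta_pert)
        theta_pert = theta_pert + gamma_t * g / (np.linalg.norm(g) + eps)
    return theta_pert

def gam_like_update(theta, grad_fn, gammas, lr):
    """Descend using the gradient evaluated at the perturbed parameters."""
    return theta - lr * grad_fn(perturb(theta, grad_fn, gammas))

# With a single, fixed coefficient the loop collapses to SAM's one ascent step:
#   theta_next = gam_like_update(theta, grad_fn, gammas=[rho], lr=lr)
```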

We have added a clarification in the paper to make this relationship explicit.

Clarification of Equations (16) and (17)

You are correct that Equations (16) and (17) both represent the right-hand side of Equation (15), broken down into different components.

We have revised the text to ensure that the progression from Equation (15) to Equations (16) and (17) is clear and logical.

Results on Wide Network (WN)

We omitted results on WN for MNIST because smaller models already achieve near-perfect accuracy on this dataset, leaving little room for improvement. Training larger models like WN on MNIST provides limited additional insight, and we did not consider it an experimental priority given the computational cost of training WN. However, if you consider it important, we are willing to include these results in the revised paper.

Effect of Hyperparameters

Thank you for suggesting an analysis of how these parameters affect GAM's performance. In the limited time of the rebuttal period, we have swept over the following hyperparameters to investigate their impact (fixing a CNN model on CIFAR-10):

  • Batch Size
  • Perturbation Steps ($T$)
  • Perturbation Step Size ($\epsilon$)

We have included our results in Appendix E, Figure 5; in summary, we find that GAM can become unstable under large $T$, performs best at small batch sizes, and is relatively insensitive to $\epsilon$. We hypothesize that reducing the learning rate of $\gamma$ can be one way to mitigate the instability for large $T$.

Review (Rating: 3)

This paper finds that the test loss landscape can be regarded as a rescaled version of the training loss landscape. Based on this finding, the paper proposes a novel training strategy called Generalization-Aware Minimization (GAM). This strategy aligns the training loss gradient with the test loss gradient through multiple perturbations, guiding the model parameters towards regions with better generalization, thereby improving performance.

Strengths

Originality:

The paper seems original. This paper theoretically demonstrates that for quadratic loss functions, the test loss landscape is a rescaled version of the training loss landscape. Moreover, this paper proposes that the test loss gradient can be obtained from the training loss gradient through multiple perturbations.

Significance:

The perspective of this paper is meaningful.

Weaknesses

  1. Insufficient experiments. The experiments are conducted only on three small datasets (i.e., MNIST, CIFAR-10, and SVHN) and three models (i.e., MLP, CNN, and WN), lacking experiments on commonly used models (e.g., ResNet, ViT).

  2. Marginal improvement. Despite introducing significant computational complexity, the gains provided by GAM are not very evident.

  3. Computation cost. The computation cost introduced by GAM is significant and cannot be overlooked.

  4. Poor writing. For example,

Line 159: The vertical relationships are not clearly defined, and '$p$' is not defined here.

Line 264: I suggest that the authors clarify the definition of $\Delta$ in the main paper.

Line 409: "$\gamma_1 \in 0.001, 0.01, 0.1$" should be "$\gamma_1 \in \{0.001, 0.01, 0.1\}$".

Questions

  1. Could you provide more comprehensive experimental results, such as those on ResNet and ViT?

  2. Could you provide an ablation study and analysis on the GAM steps $T$, discussing how the size of $T$ affects the model's performance?

  3. How was the discrepancy function used in the method determined?

Comment

Thank you for your review and for acknowledging the originality and meaningful perspective of our work.

Experiments on Larger Models

We agree that evaluating GAM on commonly used models is important for demonstrating its practical impact. In the limited time of the rebuttal period, we conducted additional experiments on the ImageNet dataset (trained for 50 epochs) using a ResNet-20-like architecture. GAM achieved a top-1 test accuracy of 5.12%, compared to 4.89% for SAM at $\gamma=0.01$ and 4.04% for standard SGD (SAM at $\gamma=0.1$: 3.77%, SAM at $\gamma=0.001$: 4.55%). These results suggest that GAM's advantages extend to larger-scale datasets and modern architectures.

Marginal Improvement

We understand your concern regarding the size of the performance gains. However, these improvements are consistent across all evaluated datasets and architectures. It's important to note that in the context of deep learning, especially on benchmark datasets, even small improvements can be considered significant due to the already high performance levels achieved by state-of-the-art models. Our results are in line with the improvements reported in the original SAM paper over standard SGD (which are often less than 1%). The consistent outperformance of SAM by GAM suggests that our method offers meaningful benefits in terms of generalization.

Computation Cost

We appreciate your important point regarding the computational overhead introduced by GAM.

To mitigate this cost, we propose updating the GAM perturbation coefficients $\gamma_t$ less frequently during training. Instead of optimizing $\gamma_t$ at every iteration, we can update them periodically (e.g., every few iterations). This approach significantly reduces the computational overhead while maintaining the performance benefits of GAM.

In the limited time of the rebuttal period, we have added additional experiments in Section 4 evaluating the computational cost of GAM relative to SGD and SAM on CIFAR-10 with a CNN model. We find that standard GAM is roughly $4\times$ as costly in training time as SGD (relative to $1.3\times$ for SAM), but this cost can be reduced to $3\times$ when updating $\gamma_t$ periodically. Periodic updates hurt GAM's accuracy, but still enable it to outperform SAM and SGD's accuracies.

Writing Clarifications

Thank you for highlighting areas where the writing could be improved. We have made all the clarifications you have suggested in our revision and are happy to make any further revision as well. We apologize for any confusion caused and will ensure that the revised paper is clear and precise.

Effect of Perturbation Steps ($T$)

Thank you for suggesting an analysis of how $T$ affects GAM's performance. In the limited time of the rebuttal period, we have swept over the following hyperparameters to investigate their impact (fixing a CNN model on CIFAR-10):

  • Batch Size
  • Perturbation Steps ($T$)
  • Perturbation Step Size ($\epsilon$)

We have included our results in Appendix E, Figure 5; in summary, we find that GAM can become unstable under large $T$, performs best at small batch sizes, and is relatively insensitive to $\epsilon$. We hypothesize that reducing the learning rate of $\gamma$ can be one way to mitigate the instability for large $T$.

Definition of Discrepancy Function

The discrepancy function $\Delta$ used in our method is the negative dot product between the gradient at the perturbed parameters and the gradient on an auxiliary minibatch. This function encourages alignment between the perturbed gradient and the gradient that better approximates the test loss.

We have updated the text to clarify this definition.

Comment

Thanks for your response.

Regarding your response, I have the following concerns:

  1. I noticed some inconsistencies in your reply. Initially, you mentioned that GAM was 15 times slower than SGD (i.e., "Specifically, our initial implementation shows that GAM is approximately 15 times slower than SGD, while SAM is about 1.7 times slower."), but in the latest version, this was revised to 4 times (i.e., "We find that standard GAM is roughly $4\times$ as costly in training time as SGD."). Could you explain the reason for this change? Is it due to differences in how time complexity is measured theoretically versus experimentally?

  2. Are there alternative options for the discrepancy function, or could you conduct an ablation study on this discrepancy function?

  3. I am not entirely satisfied with the experimental results on large models and datasets. Although a ResNet-20-like architecture was used instead of more advanced models, the improvement seems limited. Considering time constraints, I suggest incorporating more advanced models in future updates and experimenting with NLP datasets.

Comment

Thank you for your follow-up comments and for giving us the opportunity to clarify and address your concerns.

Inconsistencies in Reported Computational Costs

We apologize for any confusion regarding the reported computational costs of GAM. Initially, we mentioned that GAM was approximately 15 times slower than SGD based on a preliminary estimate focused solely on the time required for computing each individual parameter update, excluding factors like data loading, preprocessing, and other overheads common to both GAM and SGD.

In our latest experiments, we conducted a more comprehensive and precise measurement of the total training time, including all steps of the training process. Under these conditions, we found that standard GAM is roughly $4\times$ as costly in training time as SGD. These measurements are averaged over 5 separate trials, giving us confidence in the validity of this measurement.

Alternative Options for the Discrepancy Function and Ablation Study

Thank you for highlighting the importance of the discrepancy function in GAM. In our current implementation, we use the negative dot product as the discrepancy function. This choice encourages alignment between the gradient at the perturbed parameters and the gradient on the auxiliary minibatch, promoting updates that generalize better.

We agree that exploring alternative discrepancy functions could provide valuable insights into the effectiveness of GAM.

Potential alternatives include:

  • Squared error between the gradients
  • Cosine similarity to measure the angle between the gradients
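
For concreteness, simple sketches of these candidates alongside our current choice (illustrative NumPy code on flattened gradient vectors, not our actual implementation):

```python
import numpy as np

def delta_neg_dot(g_pert, g_aux):
    # Current choice: negative dot product (lower = better-aligned gradients).
    return -np.dot(g_pert, g_aux)

def delta_squared_error(g_pert, g_aux):
    # Squared error between the gradients.
    return np.sum((g_pert - g_aux) ** 2)

def delta_neg_cosine(g_pert, g_aux, eps=1e-12):
    # Negative cosine similarity, i.e. penalizing the angle between the gradients.
    return -np.dot(g_pert, g_aux) / (np.linalg.norm(g_pert) * np.linalg.norm(g_aux) + eps)
```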

While we have not yet conducted an ablation study on different discrepancy functions, we are happy to do so. Due to the limited time remaining in the rebuttal period, we may not be able to present these experiments in the rebuttal. However, we recognize their importance and will prioritize them in the revised version of the paper.

Experimental Results on Large Models and Datasets

We understand your desire to see GAM evaluated on more advanced models and larger datasets, including NLP tasks. During the limited rebuttal period, we were able to show preliminary experimental results on ImageNet (in addition to several other experiments); these experiments showed that GAM outperforms SAM and SGD in this setting as well.

We acknowledge that the improvement may seem modest in percentage terms. Still, we highlight that in the context of SAM-based optimization methods, improvements on the order of 1% are actually quite significant: SAM's improvements over SGD are also relatively small in percentage terms.

As for NLP datasets, we appreciate your suggestion and recognize the importance of evaluating GAM in different domains. Had this recommendation been made in your initial review, we may have prioritized allocating time to conduct NLP experiments (in addition to the several other additional experiments we did conduct). We are happy to conduct an experiment on an NLP task, but will unfortunately likely be unable to share our results by the close of the rebuttal period.

Review (Rating: 6)

This paper proposes a new optimization procedure, Generalization-Aware Minimization (GAM), based on an analysis which shows that the curvature of the true loss landscape is a rescaled version of that of the observed loss landscape. Multiple perturbation steps are used to calculate the gradient update such that it is a good approximation of the gradient update in the true loss landscape. Experiments on small-scale datasets show the effectiveness of GAM.

Strengths

  1. The theoretical analysis looks plausible, with a clear explanation of the rationale of assumptions.

  2. The idea of multiple perturbation steps, and learning a perturbation coefficient via the difference between the perturbed gradient and the original gradient looks interesting.

Weaknesses

  1. More explanation should be given of why the true loss landscape is a rescaled version of the observed loss landscape. When $D(\tilde{\Lambda})=\tilde{\Lambda}$, the analysis seems trivial. Intuition or examples of when and why $D(\tilde{\Lambda})$ differs from $\tilde{\Lambda}$, and what implications this has for the optimization process, could be given.

  2. I would like to see more discussion on the effect of batch size, perturbation steps, and $\epsilon$, as these parameters affect GAM's approximation of the true loss landscape. An analysis of how they impact the trade-off between computational cost and performance would give a better understanding of the effectiveness of the proposed GAM.

  3. The experiments are not convincing enough. The baseline experiments could be more complete: it seems that as $\gamma$ increases, SAM's performance also increases, so GAM should be compared to SAM with larger $\gamma$. I understand that it is computationally intensive to calculate the gradient $d$ and $\gamma$. However, experiments on larger architectures and datasets are expected to show GAM's effectiveness (refer to the empirical evaluation of SAM [1]).

[1] https://arxiv.org/pdf/2010.01412

Questions

  1. In Appendix A, how does (15) come out? It is not clear why $\theta^*$ is eliminated.
  2. Please give more explanation of Equation (7), especially on $D(\tilde{\Lambda})$.
  3. How would batch size, perturbation step, and $\epsilon$ affect the empirical performance of GAM? Ablation studies on these parameters will be helpful.

Comment

Thank you for your constructive feedback and for recognizing the plausibility and interest of our theoretical analysis.

Explanation of the Relationship Between Test Loss and Training Loss

We appreciate your request for a more intuitive explanation. The key insight is that the training loss landscape accurately captures the directions of curvature (the eigenvectors of the Hessian matrix) but may distort the curvature magnitudes (the eigenvalues). This distortion occurs due to noise from the data sampling process.

By understanding that the expected test loss landscape is a rescaled version of the training loss landscape, we can design optimization strategies that adjust for this discrepancy. Specifically, we can transform the training loss gradient to approximate the test loss gradient, guiding the optimization process toward solutions that generalize better.

Figure 1 in the paper provides a visual illustration of this concept.

We are happy to provide any further clarification on this, especially since this is a key aspect of our paper.

Effect of Batch Size, Perturbation Steps, and $\epsilon$

Thank you for suggesting an analysis of how these parameters affect GAM's performance. In the limited time of the rebuttal period, we have swept over the following hyperparameters to investigate their impact (fixing a CNN model on CIFAR-10):

  • Batch Size
  • Perturbation Steps (TT)
  • Perturbation Step Size (ϵ\epsilon)

We have included our results in Appendix E, Figure 5; in summary, we find that GAM can become unstable under large $T$, performs best at small batch sizes, and is relatively insensitive to $\epsilon$. We hypothesize that reducing the learning rate of $\gamma$ can be one way to mitigate the instability for large $T$.

Comparison with SAM Using Larger $\gamma$

We appreciate your point about comparing GAM to SAM with larger perturbation sizes. In the limited time of the rebuttal window, we have conducted additional experiments on CIFAR-10 using the CNN architecture and found the following results for SAM over 5 seeds:

At $\gamma=0.3$, accuracy: 0.64534 $\pm$ 0.02078

At $\gamma=1.0$, accuracy: 0.44048 $\pm$ 0.01161

Both these results are worse than our existing SAM results, giving us confidence that we have chosen an appropriate range of $\gamma$.

We note that simply increasing $\gamma$ in SAM does not capture the nonlinear transformation of the curvature magnitudes that GAM learns through its adaptive perturbation strategy. As shown in Figure 3, GAM learns a more highly curved mapping between the training and test loss landscapes, which cannot be replicated by merely increasing $\gamma$ in SAM.

Experiments on Larger Datasets

We agree that demonstrating GAM's effectiveness on larger datasets and architectures strengthens our work. In the limited time of the rebuttal period, we conducted additional experiments on the ImageNet dataset (trained for 50 epochs) using a ResNet-20-like architecture. GAM achieved a top-1 test accuracy of 5.12%, compared to 4.89% for SAM at $\gamma=0.01$ and 4.04% for standard SGD (SAM at $\gamma=0.1$: 3.77%, SAM at $\gamma=0.001$: 4.55%). These results suggest that GAM's advantages extend to larger-scale datasets and modern architectures.

Clarification on Equation (15)

In Equation (15), we reorganize the terms of the expected test loss into components based on their dependence on $\theta$. Specifically:

  • Quadratic Term: Represents the curvature of the loss landscape with respect to $\theta$.

  • Linear Term: Represents the gradient component that depends linearly on $\theta$.

  • Constant Term: Collects all terms independent of $\theta$, including those dependent only on $\theta^*$.

Explanation of Equation (7)

Equation (7) conveys that the expected test loss is also quadratic, like the training loss, and has the same eigenvectors as the training loss but with eigenvalues transformed by an arbitrary function $D$. This transformation reflects the rescaling of the curvature magnitudes between the training and test losses. We are happy to further clarify this, since this equation is central to our submission.
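
Schematically, using the notation from the reviews above (an illustrative form rather than a verbatim restatement of Equation (7)):

$$
\mathbb{E}\!\left[L(\theta) \mid \tilde{\theta}^*, \tilde{M}, \tilde{c}\right] \;=\; \tfrac{1}{2}\,(\theta - \tilde{\theta}^*)^\top \tilde{U}\, D(\tilde{\Lambda})\, \tilde{U}^\top (\theta - \tilde{\theta}^*) \;+\; \text{const},
$$

where $\tilde{M} = \tilde{U} \tilde{\Lambda} \tilde{U}^\top$ is the eigendecomposition of the training-loss Hessian and $D$ acts elementwise on the eigenvalues $\tilde{\Lambda}$.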

Comment

Thank you for the response. There are some points on which I would like to seek your clarification:

Explanation of the Relationship Between Test Loss and Training Loss

Noise from the data sampling process might make the test loss a rescaled version of the training loss. The effectiveness of the proposed GAM then lies in how multiple perturbation steps and the auxiliary minibatch can help in the approximation of the true loss landscape. I am looking forward to seeing

  1. more discussion on how to get a good approximation of the true loss landscape, given the training data affected by noisy data sampling.
  2. experiments on the effect of batch size, perturbation steps, and perturbation step size.
  3. more discussion on the transformation $D$. As mentioned by Reviewer dh8v and me, Theorem 1 can be trivial. Even if $D(\tilde{\Lambda})=\tilde{\Lambda}$, will the proposed GAM still help in getting more generalizable minima?

In addition, is it possible to show the eigenvalues of the Hessian matrix of the loss on training and test data, perhaps using a simple network with simulated data? This would give a better understanding of the relationship between the train and test loss.

Experiments on Larger Datasets

Thank you for sharing the results of GAM on ImageNet. Did the training converge? If not, it is less convincing to suggest the superiority of GAM compared to SAM and SGD.

Comment

Thank you for your timely follow-up questions and for giving us the opportunity to provide further clarification.

Approximating the True Loss Landscape

To get a good approximation of the true loss landscape, GAM leverages multiple perturbation steps and utilizes an auxiliary minibatch to estimate the discrepancy between the perturbed gradient and an approximation of the test loss gradient. The auxiliary minibatch serves as a proxy for unseen data, helping to adjust the perturbation coefficients $\gamma_t$ to better align the training loss gradient with the test loss gradient.

By iteratively updating $\gamma_t$ based on the discrepancy function $\Delta$, GAM adapts to the noise inherent in the training data. This adaptive process allows GAM to correct for the distortions in the curvature magnitudes caused by overfitting or sampling noise, thereby improving the approximation of the true loss landscape.
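
Schematically, each coefficient is adjusted by a gradient step on the discrepancy (an illustrative update rule consistent with the description above, with $\eta_\gamma$ denoting the learning rate of $\gamma$):

$$
\gamma_t \;\leftarrow\; \gamma_t - \eta_\gamma \,\frac{\partial \Delta}{\partial \gamma_t}, \qquad t = 1, \dots, T.
$$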

Effect of Batch Size, Perturbation Steps, and $\epsilon$

Thank you for suggesting an analysis of how these parameters affect GAM's performance. In the limited time of the rebuttal period, we have swept over the following hyperparameters to investigate their impact (fixing a CNN model on CIFAR-10):

  • Batch Size
  • Perturbation Steps ($T$)
  • Perturbation Step Size ($\epsilon$)

We have included our results in Appendix E, Figure 5; in summary, we find that GAM can become unstable under large $T$, performs best at small batch sizes, and is relatively insensitive to $\epsilon$. We hypothesize that reducing the learning rate of $\gamma$ can be one way to mitigate the instability for large $T$.

Discussion on the Transformation $D$ and the Triviality of Theorem 1

We understand your concern about the potential triviality of Theorem 1 when the transformation $D$ is the identity function (i.e., $D(\tilde{\Lambda}) = \tilde{\Lambda}$). In such a case, where the training and test loss landscapes have identical curvature magnitudes, there is no rescaling to correct, and GAM would not provide significant benefits over standard optimization methods.

However, in practice, the transformation $D$ is often nonlinear: factors such as sampling noise and overfitting cause discrepancies between the curvature magnitudes of the training and test losses. GAM is designed to learn and correct for these nonlinear transformations by adjusting the perturbation coefficients $\gamma_t$ during training.

Our synthetic experiments, as shown in Figure 2 of the paper, demonstrate cases where $D$ is highly nonlinear. In these scenarios, GAM effectively captures the complex relationship between the training and test loss landscapes, leading to improved generalization. Therefore, while GAM may not offer advantages when $D$ is identity, it is valuable in the common case where $D$ is nonlinear.

Visualization of Hessian Eigenvalues

Figure 2 in the paper, in fact, exactly shows the eigenvalues of the true loss Hessian and those of the observed (training) loss Hessian, as well as the mapping between the two. If you believe that a different type of visualization or additional data would be more helpful, we are happy to include it in the revised paper.
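
As a quick illustration of the kind of comparison you describe, here is a toy linear-regression sketch constructed for this discussion (it is not the simulated setup behind Figure 2):

```python
import numpy as np

rng = np.random.default_rng(0)
d, n_train, n_test = 10, 50, 100_000

def hessian_eigvals(X):
    # For squared-error loss L(w) = ||Xw - y||^2 / (2n), the Hessian is X^T X / n
    # (independent of y), so its eigenvalues describe the curvature of the loss.
    return np.linalg.eigvalsh(X.T @ X / len(X))

X_train = rng.normal(size=(n_train, d))   # small, noisy sample -> training loss
X_test = rng.normal(size=(n_test, d))     # large sample -> proxy for expected loss

print("train-loss Hessian eigenvalues:", np.round(hessian_eigvals(X_train), 2))
print("test-loss  Hessian eigenvalues:", np.round(hessian_eigvals(X_test), 2))
```

With the large test sample the eigenvalues concentrate near 1, while the small training sample yields noticeably spread-out (rescaled) eigenvalues.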

Experiments on Larger Datasets

Regarding your question about the ImageNet experiments, due to the limited time during the rebuttal period, we were able to run the training for only 50 epochs. At this time, although the initial results for GAM are promising, the models have not fully converged.

We acknowledge that full convergence is essential for a conclusive comparison. We are still training these models and will include the final, fully converged results in the revised paper.

Please let us know if there are any further questions or concerns we can address.

Comment

Thank you for the response and the additional experiments conducted. It is interesting to see how GAM is affected by perturbation step, batch size, and perturbation step size.

The proposed idea that the test landscape is a rescaled version of the training loss landscape provides a unique perspective. The experiments and visualization show some support for the rescaling of loss landscape under training and test data. However, I am still concerned about the GAM's performance, especially under larger architectures and datasets.

Based on the above, I would raise my score to 6.

Comment

Thank you for your thoughtful response and for reassessing our submission. We are pleased that you found our analysis of how GAM is affected by perturbation steps, batch size, and perturbation step size to be interesting and informative.

We understand your concerns regarding GAM's performance on larger architectures and datasets. We would like to highlight that although GAM's improvement may seem modest in percentage terms, improvements on the order of 1% are actually quite significant. In fact, in the original SAM paper, SAM's improvements over SGD are also relatively small in percentage terms.

We are committed to expanding our experiments to include more complex models and larger datasets in future work.

Review (Rating: 6)

This paper introduces Generalization-Aware Minimization (GAM), a generalized version of SAM that employs multiple perturbation steps (whereas SAM employs a single-step perturbation). GAM directly optimizes for the expected test loss in order to achieve better generalization.

Strengths

  1. The paper is written clearly. Theoretical insights are supported empirically.
  2. The theoretical results can provide valuable insights into how sharpness-based algorithms work.

Weaknesses

  1. While I agree that the performance gains in Table 1 illustrate that GAM > SAM > SGD, the relative gains of GAM over SAM seem relatively small.
  2. It would be nice to see some results in other modalities (e.g., maybe some language related tasks. Aside: for language related tasks, people care about OOD performance as well, so maybe expected test loss is not as meaningful?)

Questions

As suggested in the paper, the curvature of the expected test loss is a rescaled version of train loss, aligned along the same principle directions. Is there any connection to CR-SAM: Curvature Regularized Sharpness-Aware Minimization (https://arxiv.org/abs/2312.13555)?

Comment

Thank you for your positive feedback and for acknowledging the clarity and potential impact of our work.

Size of Performance Gains

We understand your observation regarding the relatively small performance gains of GAM over SAM. However, these improvements are consistent across all evaluated datasets and architectures. It's important to note that in the context of deep learning, especially on benchmark datasets, even small improvements can be considered significant due to the already high performance levels achieved by state-of-the-art models.

Our results are in line with the improvements reported in the original SAM paper over standard SGD (which are often less than 1%). The consistent outperformance of SAM by GAM suggests that our method offers meaningful benefits in terms of generalization.

Additional Modalities

We appreciate your suggestion to evaluate GAM on other modalities. In the limited time of the rebuttal period, we conducted additional experiments on the ImageNet dataset (trained for 50 epochs) using a ResNet-20-like architecture. GAM achieved a top-1 test accuracy of 5.12%, compared to 4.89% for SAM at $\gamma=0.01$ and 4.04% for standard SGD (SAM at $\gamma=0.1$: 3.77%, SAM at $\gamma=0.001$: 4.55%).

These results suggest that GAM's advantages extend to larger-scale datasets and modern architectures. We agree that exploring GAM's effectiveness on language-related tasks and out-of-distribution performance is valuable, and we plan to investigate these areas in future work.

Connection to CR-SAM

Thank you for bringing up CR-SAM. While both CR-SAM and GAM aim to improve generalization by modifying the SAM framework, there are key differences:

Motivation and Approach: CR-SAM introduces a curvature regularization term to penalize high curvature in the loss landscape, thereby encouraging flatter minima. In contrast, GAM is derived from a perspective that directly targets the expected test loss by transforming the training loss gradient through multiple perturbations.

Flexibility in Perturbations: GAM allows for an arbitrary number of perturbation steps and learns the perturbation sizes online during training. This flexibility enables GAM to capture higher-order information about the loss landscape. CR-SAM, on the other hand, typically employs a fixed perturbation strategy.

Comment

Thank you for your response and for the clarifications! I will maintain my score.

Review (Rating: 6)

The paper introduces Generalization-Aware Minimization (GAM), an optimization algorithm that aims to improve neural network generalization by optimizing for expected test loss. The authors extend the Sharpness-Aware Minimization (SAM) concept by using multiple perturbation steps and adaptive perturbation sizes. The key theoretical contribution shows that the expected test loss landscape is a rescaled version of the training loss landscape for quadratic losses. The method consistently outperforms SAM and standard SGD on several benchmark datasets.

Strengths

  • Provides a rigorous theoretical analysis of the relationship between training and test loss landscapes

  • Offers insights into why SAM works, moving beyond the simple "flat minima" explanation

  • Shows consistent improvements over baseline methods

Weaknesses

  • The method requires multiple perturbation steps, each involving forward and backward passes. While the authors dismiss computational concerns with a brief statement that "the overhead remains manageable," this deserves more thorough treatment. The paper would benefit from a detailed analysis of computational costs, especially for larger models or when more perturbation steps are needed.

  • Theorem 1, a cornerstone of the paper's theoretical contribution, is preceded by 14 lines of dense assumptions without adequate explanation. The presentation would be more accessible if the authors could provide clear motivation and intuition for critical assumptions and better explain the theorem's implications and practical significance.

  • While the theory is primarily based on quadratic loss assumptions with convex landscapes, the authors do address this limitation. The extension to non-quadratic losses in the practical considerations section provides important context.

Questions

  • Could you explain more intuitively what it means that "the expected test loss landscape is a rescaled version of the training loss landscape"? Why is this important for generalization?

  • Have you explored any techniques to reduce the computational cost while maintaining the benefits of multiple perturbation steps?

Comment

Thank you for your thoughtful review and for highlighting both the strengths and areas for improvement in our paper.

Concerns on Computational Costs

We appreciate your important point regarding the computational overhead introduced by GAM. You are correct that GAM requires multiple perturbation steps, which increases computational cost compared to SAM and standard SGD.

To mitigate this cost, we propose updating the GAM perturbation coefficients $\gamma_t$ less frequently during training. Instead of optimizing $\gamma_t$ at every iteration, we can update them periodically (e.g., every few iterations). This approach significantly reduces the computational overhead while maintaining the performance benefits of GAM.

In the limited time of the rebuttal period, we have added additional experiments in Section 4 evaluating the computational cost of GAM relative to SGD and SAM on CIFAR-10 with a CNN model. We find that standard GAM is roughly $4\times$ as costly in training time as SGD (relative to $1.3\times$ for SAM), but this cost can be reduced to $3\times$ when updating $\gamma_t$ periodically. Periodic updates hurt GAM's accuracy, but still enable it to outperform SAM and SGD's accuracies.

Intuition for Assumptions and Practical Significance of Theorem 1

Thank you for requesting a clearer explanation of Theorem 1's assumptions and their practical implications. In Appendix C of the paper, we provide justifications for each assumption. To summarize:

Quadratic Loss Functions: We assume that both the true (test) and observed (training) loss functions are quadratic. This serves as a local approximation to the loss landscape around a minimum, which is a common practice in optimization theory.
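
In symbols, a schematic parameterization consistent with the notation used in the reviews (not a verbatim restatement from the paper): the observed loss is modeled as

$$
\tilde{L}(\theta) \;=\; \tfrac{1}{2}\,(\theta - \tilde{\theta}^*)^\top \tilde{M}\,(\theta - \tilde{\theta}^*) + \tilde{c},
$$

with minimizer $\tilde{\theta}^*$, Hessian $\tilde{M}$, and offset $\tilde{c}$, and the true (test) loss is modeled analogously.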

Rotational and Location Invariance: We assume that the generative process of the loss functions is invariant under rotations and translations. This means that the statistical properties of the loss landscape are the same in all directions and locations, simplifying the analysis.

Practically, Theorem 1 implies that the training loss landscape accurately captures the curvature directions (eigenvectors) of the test loss landscape but distorts the curvature magnitudes (eigenvalues). In other words, while the orientation information is preserved, the scale information is distorted. This understanding allows us to design optimization algorithms that correct for this distortion, leading to improved generalization.

Extension to Non-Quadratic Losses

You raise an important point about the applicability of our theoretical results to non-quadratic loss functions. We highlight that any smooth loss function can be approximated by a quadratic function in a sufficiently small neighborhood around a point. Therefore, we expect Theorem 1 to hold locally for any smooth loss function.

In practice, this means that our insights are relevant for the complex, non-convex loss landscapes encountered in deep learning. By considering local quadratic approximations, GAM can be effectively applied to train neural networks.

Intuition for the Test Loss Being a Rescaled Version of the Training Loss

Intuitively, the training loss landscape provides accurate information about the directions in which the loss changes most rapidly (the curvature directions), but it may misrepresent how sharply the loss changes along those directions (the curvature magnitudes). This distortion occurs due to noise from the data sampling process.

By recognizing that the test loss is a rescaled version of the training loss, we can adjust our optimization process to compensate for the distorted scale information. This involves modifying the training loss landscape to better align with the test loss landscape, ultimately improving generalization performance.

We are happy to provide any further clarification on this, especially since this is a key aspect of our paper.

Comment

Thank you for your detailed response addressing my concerns. The additional experiments with periodic updates present an interesting direction, and while runtime remains a limitation, I believe this trade-off would be a promising direction. I am upgrading my rating to 6.

Comment

Thank you for your positive feedback and for taking the time to reconsider your evaluation of our work. We are glad that our detailed responses and the additional experiments addressing computational costs have provided clarity and value.

AC Meta-Review

This paper proposes a generalized version of sharpness-aware minimization (SAM), which employs multiple perturbation steps instead of SAM's single step. The algorithm is motivated by theoretical analysis that shows the expected test loss landscape is a rescaled version of the training loss landscape. Experiments show that the proposed method can work better than SGD and SAM on MNIST, CIFAR-10, and SVHN, with MLP and CNNs.

We had three borderline scores (6, 6, 6) and two negative scores (3, 3), with an average of 4.8, tending towards rejection. Many reviewers mentioned the theoretical analysis and the new algorithm as the strengths of the paper, and that the presentation is clear. On the other hand, most reviewers were not convinced by the limited experimental results, since they were based on small datasets such as MNIST and CIFAR-10, and only basic models such as MLPs and CNNs. Several reviewers were concerned with the additional computational costs (with marginal generalization-performance gains). There were also concerns about the limitations/assumptions of the theorems.

The authors have provided additional results for ImageNet but the results are pre-convergence and the top-1 accuracy is very low for all methods (between 4.5%-5.2%). While these early results may suggest the potential of GAM, we believe the paper can significantly benefit from further investigation and another round of reviews at a future conference or journal. Based on these considerations, we recommend rejection at this time while encouraging the authors to revise.

Additional Comments from the Reviewer Discussion

Two reviewers raised their scores after reading the rebuttal and having further discussions. One reviewer was satisfied with the additional results including the hyperparameter study. The other reviewer was satisfied with the additional experiments with periodic updates and the discussions of other points such as computational costs and the intuition behind the theory. The final meta-review was mostly influenced by the concern about limited experimental results, which require significantly more time than the rebuttal/discussion period to investigate carefully.

Final Decision

Reject