Sharpness-diversity tradeoff: improving flat ensembles with SharpBalance
This paper reveals a trade-off between sharpness and diversity in deep ensembles, both empirically and theoretically, and proposes SharpBalance, a novel ensemble algorithm that achieves an optimal balance between these two crucial metrics.
Abstract
Reviews and Discussion
This paper introduces a training approach for ensemble learning, called SharpBalance, that balances sharpness and diversity within ensembles. The paper shows theoretically that SharpBalance achieves a better sharpness-diversity trade-off.
Strengths
- Ensemble learning is an important research direction.
- Understanding of sharpness and diversity within deep ensembles is important for the study of generalization to both in-distribution and out-of-distribution data.
- The paper is technically sound.
Weaknesses
- Since SharpBalance focuses "on a diverse subset of the sharpest training data samples", it may not apply to small datasets where available data is already scarce.
- Empirical improvement over existing methods is marginal.
Questions
- Why does SharpBalance seem more effective on corrupted data?
- Do models in an ensemble converge on the same local minima?
Limitations
Limitations are adequately addressed.
Weakness 1
We conducted an additional experiment to verify the effectiveness of the proposed method on small datasets, with results shown in Table 9 of the rebuttal PDF. The small datasets were generated by randomly subsampling the training set with ratios of 0.3 and 0.5. The experiments used a three-member ResNet18 ensemble on CIFAR10. The results demonstrate that SharpBalance maintains its performance advantage on small datasets compared to the two baseline methods. The hyperparameter search and setup are consistent with Appendix D.3 of the submitted paper.
Weakness 2
We clarify that SharpBalance demonstrates more pronounced empirical improvements as the number of ensemble models increases. In Figure 16 of the rebuttal PDF, we present additional results showing the impact of increasing the number of models in the ensemble. The accuracy difference between SharpBalance and the baseline methods becomes more significant, especially on corrupted data. Specifically, SharpBalance outperforms the baseline by up to 1.30% when ensembling 5 models on CIFAR100-C.
Question 1
Ensembles trained with SharpBalance achieve greater diversity among individual models while maintaining their individual predictive power. The enhanced diversity allows the ensemble to perform well under the distribution shifts present in the corrupted datasets: in a diverse ensemble, each individual model may capture more distinct features of the data distribution, which effectively mitigates the impact of corrupted data [1-3]. Also, research in [4] suggests that high diversity of the features learned by a model promotes transferability to OOD data. On the other hand, the improved sharpness-diversity tradeoff also reduces the sharpness of the overall ensemble model, which reduces the impact of overfitting. On corrupted datasets, overfitting is more harmful to ensemble generalization performance due to the presence of noisy features, which makes SharpBalance more effective than the baseline methods.
Question 2
Rigorously defining "local minimum" in light of mode connectivity [5] can be tricky. Here we only provide an intuitive answer using commonly acknowledged statements in this field, which may be imprecise under certain conditions. In general, models in an ensemble are unlikely to converge to the same local minima. Studies in [6] show that when models are randomly initialized, their training trajectories tend to explore diverse modes (minima) in the loss landscape and, as a result, do not converge to the same local minimum. In particular, ensemble members trained by SharpBalance are unlikely to converge to the same local minimum either. This is implied by the increased diversity of the individual models' outputs, as models residing in the same local minimum are more likely to output similar logits, resulting in low diversity among ensemble members. Notice that this explanation is based on the common intuition that a loss landscape indeed contains many minima. In the overparameterized case, where all local minima may be connected through low-loss paths [5], this statement has to be taken with a grain of salt.
References
[1] Abe et al. Pathologies of Predictive Diversity in Deep Ensembles.
[2] Stickland et al. Diverse Ensembles Improve Calibration.
[3] Kumar et al. Calibrated ensembles can mitigate accuracy tradeoffs under distribution shift.
[4] Nayman et al. Diverse Imagenet Models Transfer Better.
[5] Garipov et al. Loss Surfaces, Mode Connectivity, and Fast Ensembling of DNNs.
[6] Fort et al. Deep Ensembles: A Loss Landscape Perspective.
Thanks for the authors' response, which clarified some points.
We sincerely thank the reviewer for acknowledging our response.
The paper proposes SharpBalance, a method that investigates the relationship between sharpness and diversity in deep ensembles.
Strengths
- SharpBalance looks quite effective for the out-of-distribution setting. The goal of balancing sharpness and diversity within ensembles is an important idea.
- Great theoretical analysis
Weaknesses
- Are the authors aware of the paper “Diversity-Aware Agnostic Ensemble of Sharpness Minimizers” [1]? Its idea is quite similar to that of this paper: it also investigates the relationship between sharpness and diversity in ensemble learning. I suggest the authors discuss the main differences between the two.
[1] Anh Bui, Vy Vo, Tung Pham, Dinh Phung and Trung Le, Diversity-Aware Agnostic Ensemble of Sharpness Minimizers, arXiv:2403.13204.
- Regarding the baselines, the authors only compare SharpBalance with SAM. Newer and stronger baselines like GSAM [2] and OBF [3] should also be benchmarked, since they are the current state of the art.
[2] Zhuang, J., Gong, B., Yuan, L., Cui, Y., Adam, H., Dvornek, N., Tatikonda, S., Duncan, J., and Liu, T. Surrogate gap minimization improves sharpness-aware training. arXiv preprint arXiv:2203.08065, 2022.
[3] Vani, A., Tung, F., Oliveira, G., and Sharifi, H. Forget Sharpness: Perturbed Forgetting of Model Biases Within SAM Dynamics. In International Conference on Machine Learning (ICML), 2024.
- Another point to improve is the datasets. I strongly suggest the authors benchmark on at least a couple of large-scale datasets. Options are ImageNet-V1 [4] for training, ImageNet-Real [5] and ImageNet-V2 [6] for testing, ImageNet-R [7] as an out-of-distribution robustness benchmark, and ImageNet-Sketch [8].
[4] Deng, J., Dong, W., Socher, R., Li, L.-J., Li, K., and Fei-Fei, L. ImageNet: A large-scale hierarchical image database. In 2009 IEEE conference on computer vision and pattern recognition, pp. 248–255. IEEE, 2009
[5] Beyer, L., Hénaff, O. J., Kolesnikov, A., Zhai, X., and Oord, A. v. d. Are we done with imagenet? arXiv preprint arXiv:2006.07159, 2020.
[6] Recht, B., Roelofs, R., Schmidt, L., and Shankar, V. Do imagenet classifiers generalize to imagenet? In International conference on machine learning, pp. 5389–5400. PMLR, 2019.
[7] Hendrycks, D., Basart, S., Mu, N., Kadavath, S., Wang, F., Dorundo, E., Desai, R., Zhu, T., Parajuli, S., Guo, M., et al. The many faces of robustness: A critical analysis of out-of-distribution generalization. In Proceedings of the IEEE/CVF International Conference on Computer Vision, pp. 8340–8349, 2021.
[8] Wang, H., Ge, S., Lipton, Z., and Xing, E. P. Learning robust global representations by penalizing local predictive power. In Advances in Neural Information Processing Systems, pp. 10506–10518, 2019.
Questions
- Regarding the new-discovery contribution: as I previously stated in the weaknesses section, were the authors aware of the paper “Diversity-Aware Agnostic Ensemble of Sharpness Minimizers”? Could the authors present the main differences between this method and SharpBalance?
Limitations
- The proposed method is not the significant improvement for ID scenarios that the authors claim. I would tone that down throughout the text. Maybe clearly state that SharpBalance is indeed superior in OOD scenarios and competitive in ID settings.
- I would not claim that the phenomenon called the sharpness-diversity trade-off is a discovery; paper [1] addresses the same phenomenon, and it was publicly available on arXiv before submission and on OpenReview since the beginning of the year.
Weakness 1 and Question
We present three key distinctions between DASH in [1] and SharpBalance. First, SharpBalance offers a comprehensive identification and rigorous analysis of the sharpness-diversity trade-off phenomenon. Second, our novel theoretical approach using random matrix theory provides precise quantifications and reveals the relationship between sharpness and diversity. Finally, our theory-motivated algorithm provably achieves greater diversity while maintaining the same level of sharpness, leading to an improved ensemble performance.
- (Comprehensive identification and analysis of the sharpness-diversity trade-off). SharpBalance provides a thorough examination of the sharpness-diversity trade-off through extensive experiments across various settings, including different sharpness/diversity measures and different model capacities, and rigorous theoretical analysis with random matrix theory. Our findings theoretically proved the existence of such a trade-off and empirically demonstrated the universality of this phenomenon through extensive experiments. DASH proposed in [1], in contrast, offers an intuitive explanation of why reducing sharpness might decrease diversity. The explanation is that the decrease in diversity is a result of models being initialized closely and updated with the same mini-batch of data.
- (Tight quantification using random matrix theory). SharpBalance develops a general theoretical framework using random matrix theory (RMT) to explain the interplay between sharpness and diversity in deep ensembles, which allows us to derive tight quantifications for these two important metrics and provides a more accurate characterization of the training dynamics of ensemble models. We used these quantification results to characterize the relationship between sharpness and diversity, further validating the correctness of our empirical observation. DASH, on the other hand, provides an upper bound on the generalization error in terms of the sharpness of both the base learners and the ensemble.
- (Theory-based algorithm provably achieves improved performance). We propose SharpBalance to balance sharpness and diversity within ensembles and theoretically show that the method achieves an improved trade-off between the two metrics. Empirical validations suggest our method indeed enhances the trade-off and therefore improves ensemble performance. The algorithm selects a sharpness-aware subset of data for each individual model to train on with the sharpness objective, and it is simple, effective, and computationally cheap to implement. In contrast, DASH adds a KL-divergence constraint on the output logits, whereas our method uses distinct subsets of data to train the individual models. While their method will introduce diversity, our method is theoretically guaranteed to improve diversity. Furthermore, we have extensive experiments demonstrating the improved diversity of SharpBalance. The two methods are in fact orthogonal and can be seen as complementary ways of promoting diversity.
Weakness 2
We conducted new experiments to include the stronger SAM variant GSAM. The results are shown in Table 7 of the rebuttal PDF. We combine GSAM with Deep Ensemble as a new baseline method, "Deep Ensemble + GSAM", and also incorporate GSAM into our method SharpBalance. The results show that the new baseline with GSAM outperforms the original baseline in ID and OOD performance but still underperforms SharpBalance (w/ SAM). Furthermore, we enhance SharpBalance by replacing SAM with GSAM, which leads to better ID performance. The hyperparameter search and setup are consistent with Appendix D.3 of the submitted paper.
Limitation 1 and 2
We thank the reviewer for the suggestions. In the revised draft, we will provide more precise statements on SharpBalance's performance in ID and OOD settings and clarify the contributions of [1] and our work, as suggested.
References
[1] Bui et al. Diversity-Aware Agnostic Ensemble of Sharpness Minimizers
I appreciate the authors' thorough rebuttal and the effort they put into addressing the concerns raised. After carefully considering their responses, I have decided to raise my score from 5 to 6 as a reflection of their efforts to clarify and improve upon the points.
We thank the reviewer for their positive feedback and for raising the score. We will ensure that the clarification and results are included in the updated manuscript.
This paper investigates the sharpness and diversity within deep ensembles. Specifically, it identifies the trade-off phenomenon between sharpness and diversity with both theoretical and empirical evidence. Additionally, it proposes a method called SharpBalance, which trains individual members on selectively chosen 'sharp' subsets. The experiments demonstrate the effectiveness of the proposed SharpBalance when applied to deep ensembles.
Strengths
There are several strengths in this paper:
- The exploration of sharpness and diversity in deep ensembles is both interesting and novel.
- Sufficient theoretical and empirical evidence has been provided for validation.
- The proposed method is simple, effective, and accompanied by code for verification.
Weaknesses
However, I still have the following concerns:
- The evaluation seems a bit weak. The authors should consider comparing with more ensemble baselines.
- What is the scale of the selected subset, and how does it change during training? Providing some details on this would help in understanding the proposed method.
- Refer to Line 166: How do the authors train individuals with the full datasets? Are these individuals trained with different initializations?
- (Optional) As described, the model's generalization is not merely correlated with sharpness, which aligns with some recent advanced SAM variants. Thus, integrating these advanced variants [1][2] with SharpBalance would be more beneficial for studying the trade-off between sharpness and diversity.
References:
[1] Random Sharpness-Aware Minimization. In NeurIPS 2022.
[2] Gradient Norm Aware Minimization Seeks First-Order Flatness and Improves Generalization. In CVPR 2023.
Questions
Please refer to the Weaknesses.
Limitations
The authors have provided a Limitations section.
Weakness 1
In addition to the main experiments, we compared SharpBalance with several ensemble baselines, both in the appendix of the submitted paper and in new experiments in the rebuttal PDF.
- Ensemble with models trained with different hyperparameters. In Appendix F.4, we compared with the "SAM+" baseline, which forms an ensemble using three models trained with different SAM perturbation ratios (0.05, 0.1, and 0.2).
- Ensemble of moving averages (EoA) method [1]. In Appendix F.4, we also compared SharpBalance with EoA, a strong baseline that uses an efficient model averaging protocol.
- Ensemble with another strong SAM optimizer, GSAM [2]. In Table 7 of the rebuttal PDF, we presented new experiments with the "Deep Ensemble + GSAM" baseline, which replaces the SAM optimizer with GSAM. The results show that SharpBalance outperforms all of these baseline methods.
Weakness 2
The scale of the sharpness-aware set is determined by a percentage hyperparameter. For each model, we choose the data samples with the highest "per-data-sample sharpness", up to this percentage, in accordance with the definition provided in Section 4.3 of the submitted paper. The sharpness-aware set for the i-th model is then formed by taking the union of all such subsets except the i-th. For instance, in the case of a three-member ensemble trained on CIFAR10, we set the percentage to 50%, and the scale of the sharpness-aware set is about 70% of the training set. The set is determined at a specific training step and remains constant for the i-th model until the end of training, so its scale remains unchanged.
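To make the partitioning concrete, below is a minimal sketch of how the sharpness-aware sets could be computed from a matrix of per-data-sample sharpness scores. The function name, the `k_percent` parameter, and the NumPy-based implementation are our own illustration, not the code released with the paper.

```python
import numpy as np

def partition_sharpness_aware_sets(sharpness, k_percent=50):
    """sharpness: (n_models, n_samples) per-data-sample sharpness scores.
    Returns, for each model i, the indices of its sharpness-aware set:
    the union of the top-k% sharpest samples of every member except the i-th."""
    n_models, n_samples = sharpness.shape
    k = int(n_samples * k_percent / 100)

    # Top-k% sharpest sample indices for each member.
    top_sets = [set(np.argsort(-sharpness[i])[:k]) for i in range(n_models)]

    # Sharpness-aware set of member i: union over all other members' top sets.
    return [set().union(*(top_sets[j] for j in range(n_models) if j != i))
            for i in range(n_models)]
```

For a three-member ensemble with `k_percent=50`, each member's set is the union of the other two members' sharpest halves, consistent with the roughly 70% coverage mentioned above.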
Weakness 3
We train each individual model with different random initializations and different data orderings, controlled by distinct random seeds.
References
[1] Arpit et al. Ensemble of averages: Improving model selection and boosting performance in domain generalization.
[2] Zhuang et al. Surrogate gap minimization improves sharpness-aware training.
I have read the authors' rebuttal to all reviewers, and I agree with Reviewer b225 that the current empirical improvements are marginal. However, I believe this paper still offers valuable insights on sharpness and diversity within deep ensembles. Therefore, I will maintain my score.
We sincerely thank the reviewer for their comments and for acknowledging our insights and contributions. We will ensure that our paper is updated to include the clarifications from the rebuttal.
Ensemble methods and sharpness-aware optimization techniques are well-known strategies for improving generalization. This work identifies a trade-off between sharpness and diversity, observing that reducing sharpness can diminish diversity and harm ensemble performance. Through theoretical and empirical analysis of this sharpness-diversity trade-off, the authors present SharpBalance, an algorithm for training ensembles with sharpness-aware solutions without sacrificing diversity. Evaluation results on CIFAR-10/100, TinyImageNet, and their corrupted variants confirm the effectiveness of SharpBalance.
Strengths
- Ensemble methods and sharpness-aware optimization techniques are both prominent approaches for improving generalization. The aim of this work, which combines these two approaches, is well-motivated.
- While the theoretical analysis uses the variance metric to indicate diversity, the experimental results show consistent trends across different diversity metrics. This suggests that the proposed analysis is widely applicable to the general concept of diversity.
- Extensive empirical results effectively validate the theoretical analysis. The summary plots of the results are generally highly readable.
Weaknesses
- The evaluation results are centered exclusively on classification accuracy; since ensembling usually highlights both predictive accuracy and uncertainty, relying solely on accuracy to assess overall performance is insufficient.
- Specifically, for the corrupted CIFAR benchmark, uncertainty metrics like negative log-likelihood or expected calibration error are more important than test accuracy, but these aspects are not currently considered.
- It seems that all experiments were conducted exclusively with residual networks. It is essential to verify if the proposed analysis and algorithm are applicable to other architecture families as well.
Questions
- It appears that the current emphasis is on logit-ensemble (lines 82-83). Does the same rationale apply when ensembling categorical probabilities (i.e., probability-ensemble)?
- In the proposed SharpBalance algorithm, it seems that the training data and objective for the i-th member are defined using other members (such as members i+1, i+2, as illustrated in the figure). Does this imply that in practice, each member is trained sequentially?
Limitations
Section 5 addressed the limitations.
Weakness 1 and 2
In Figure 14 of the rebuttal PDF, we present the results for negative log-likelihood and expected calibration error. These uncertainty metrics exhibit trends similar to the accuracy metrics reported in the main paper: "Deep Ensemble + SAM" outperforms "Deep Ensemble", and our method outperforms both baselines. The experiments were conducted using ResNet-18 on CIFAR100, with metrics reported on corrupted datasets. Additionally, we observe that both metrics improve as the number of ensemble members increases for all three methods.
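For reference, here is a minimal sketch of how these two metrics can be computed from ensemble predictive probabilities, using equal-width confidence bins for ECE; this is a standard formulation of the metrics, not necessarily the exact evaluation code used in the paper.

```python
import torch
import torch.nn.functional as F

def nll_and_ece(probs, labels, n_bins=15):
    """probs: (N, C) ensemble predictive probabilities; labels: (N,) targets."""
    # NLL: mean negative log-probability assigned to the true class.
    nll = F.nll_loss(probs.clamp_min(1e-12).log(), labels).item()

    # ECE: bin predictions by confidence and compare accuracy with confidence.
    conf, pred = probs.max(dim=-1)
    correct = pred.eq(labels).float()
    edges = torch.linspace(0, 1, n_bins + 1)
    ece = torch.tensor(0.0)
    for lo, hi in zip(edges[:-1], edges[1:]):
        in_bin = (conf > lo) & (conf <= hi)
        if in_bin.any():
            gap = (conf[in_bin].mean() - correct[in_bin].mean()).abs()
            ece += in_bin.float().mean() * gap
    return nll, ece.item()
```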
Weakness 3
We provide new experiments on transformer models for vision and language tasks. The results are shown in Table 8 of the rebuttal PDF. We show that "Deep Ensemble + SAM" outperforms the Deep Ensemble, while SharpBalance still outperforms both baselines. This observation is consistent with the residual network results in Figure 7 of the submitted paper.
Here we describe the experimental setup. For the vision task, we constructed a three-member ensemble by fine-tuning the pre-trained ViT-T/16 model on the CIFAR100 dataset and evaluated it on the in-distribution and CIFAR100-C test sets. For the language task, we constructed a three-member ensemble by fine-tuning the pre-trained ALBERT-Base model on the Microsoft Research Paraphrase Corpus (MRPC) dataset and evaluated the performance on its validation set. The hyperparameter search and setup are the same as in Appendix D.3 of the submitted paper.
Question 1
Yes, the rationale behind SharpBalance also applies when ensembling categorical probabilities. Studies in [1-3] reveal that the correlation between the output probabilities of probability-ensemble can significantly affect the classification error rate and uncertainty quantification. Therefore, obtaining more diverse ensemble members following the idea of SharpBalance is certainly beneficial to probability-ensembles.
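The distinction between the two aggregation schemes can be stated in a few lines; a minimal sketch (our illustration, assuming each member outputs raw logits):

```python
import torch.nn.functional as F

def logit_ensemble(logits):
    # logits: (n_members, batch, n_classes); average in logit space,
    # then normalize once with softmax.
    return F.softmax(logits.mean(dim=0), dim=-1)

def probability_ensemble(logits):
    # Normalize each member first, then average the categorical probabilities.
    return F.softmax(logits, dim=-1).mean(dim=0)
```

In both cases the ensemble prediction is an average over member outputs, so reducing the correlation between members benefits either aggregation.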
We conducted additional experiments to verify this insight, with new results shown in Figure 15 of the rebuttal PDF. Results show that SharpBalance outperforms both baseline methods, including Deep Ensemble and "Deep Ensemble + SAM".
Question 2
SharpBalance trains each ensemble member in parallel, distinct from classical boosting strategies. In particular, the i-th member's objective is computed based on the current status of the other members, and hence there is no sequential dependency in the training of the individual models. In practice, the sharpness-aware subsets for each model are selected at a synchronized time step that happens only once in the training process.
References
[1] Ryabinin et al. Scaling Ensemble Distribution Distillation to Many Classes with Proxy Targets.
[2] Brown et al. Managing Diversity in Regression Ensembles.
[3] Brocker et al. From ensemble forecasts to predictive distribution functions.
SharpBalance trains each ensemble member in parallel, distinct from classical boosting strategies.
Does SharpBalance only allow parallel training and not sequential training, similar to repulsive deep ensembles or particle-optimization-based variational inference? This is a crucial issue regarding scalability. If I can fit only one model in memory at a time and not multiple models, would this make SharpBalance impractical in such situations?
We thank the reviewer for their response and would like to clarify how SharpBalance can be applied to sequential training when memory constraints only allow for training one model at a time. Adapting our parallel training pipeline to a sequential approach is straightforward, and we outline the process below.
In sequential SharpBalance training, each model is iteratively trained to a predefined timestep using the full dataset. This timestep corresponds to the synchronization point in parallel training. Once all models reach this point, SharpBalance partitions the dataset into a sharpness-aware set and a normal set for each model. This partitioning can be done sequentially, fitting one model at a time in memory to compute the "per-data-sample sharpness". The other models are used only to determine the partition for the i-th model and are not required for subsequent computations. Finally, each model is trained to completion using its respective sharpness-aware and normal sets.
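The three phases can be summarized in the following sketch, which keeps a single model in memory at any time. The callables and their signatures are hypothetical placeholders, and `partition_sharpness_aware_sets` refers to the illustrative sketch given earlier in this thread.

```python
import numpy as np

def sequential_sharpbalance(ckpt_paths, load, save, train_until,
                            per_sample_sharpness, finish_training, sync_step):
    # Phase 1: bring each member to the synchronization point, one at a time.
    for path in ckpt_paths:
        model = load(path)
        train_until(model, sync_step)        # full-dataset training to sync_step
        save(model, path)

    # Phase 2: per-data-sample sharpness, still one model in memory at a time.
    sharpness = np.stack([per_sample_sharpness(load(p)) for p in ckpt_paths])
    sharp_sets = partition_sharpness_aware_sets(sharpness)  # earlier sketch

    # Phase 3: finish each member on its sharpness-aware / normal split.
    for path, sharp_idx in zip(ckpt_paths, sharp_sets):
        model = load(path)
        finish_training(model, sharp_idx)    # e.g., SAM on the sharp set
        save(model, path)
```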
Regarding the scalability of SharpBalance in sequential training:
- The dataset partition, which is the most computationally intensive procedure, occurs only once throughout the entire training process.
- The main computational bottleneck in the dataset partition is the "per-data-sample sharpness" calculation. However, this can be done efficiently by evaluating one model at a time, as it doesn't involve pairwise interactions between models. In a parallel computing scenario, only an ordered list of sharpness values needs to be transmitted from each ensemble member.
- As the number of ensemble members (N) increases, the dataset partition for each model essentially becomes a set-union problem over sorted arrays. Using min-heaps, this operation can be completed in O(m log N) time, where m is the number of distinct items in the union. If we consider m as a constant, the time complexity of the partitioning grows linearly with respect to the number of ensemble members asymptotically. Further optimizations on constants are possible by exploiting set theory: the union of subsets can be efficiently computed by subtracting the i-th model's sharpness-aware subset from the union of all sharpness-aware sets. It's worth noting that these set unions are performed on index sets, resulting in small computational overheads (see the sketch after this list).
- In most of our experiments, we can demonstrate improvement using a few ensemble members. We do not see a big improvement after increasing the number of ensemble members to more than five, similar to the phenomenon reported in Section 4.2 of [1].
- We think the technique suggested by the reviewer might be from the paper 'Repulsive Deep Ensembles are Bayesian,' which suggests that as the number of ensemble members goes to infinity, the KDE approximation can converge to the true Bayesian posterior. As discussed earlier, our method will not incur substantial overhead as the number of members increases, because the interactions between ensemble members occur through fast set-union computations.
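As a concrete illustration of the set-union step mentioned above, here is a minimal sketch using Python's `heapq.merge` over the members' sorted index lists; the function is our own illustration, not the paper's implementation.

```python
import heapq

def union_of_sorted(index_lists):
    """Union of several sorted index lists via one heap-merge pass:
    O(m log N) for N lists with m total items."""
    union, last = [], None
    for idx in heapq.merge(*index_lists):
        if idx != last:  # skip duplicates shared across members' subsets
            union.append(idx)
            last = idx
    return union

# Example: sharpest-sample index lists from two other ensemble members.
print(union_of_sorted([[0, 3, 5, 9], [3, 4, 9, 12]]))  # [0, 3, 4, 5, 9, 12]
```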
References
[1] Ovadia et al. Can you trust your model's uncertainty? evaluating predictive uncertainty under dataset shift.
Thank you for the further clarification. I believe that the NLL and ECE results for corrupted CIFAR benchmarks shared by the authors will help emphasize the advantages of the proposed SharpBalance even more. It would be helpful if you could also provide the NLL and ECE results for the existing test data in addition to the corrupted data. I am maintaining my current score as I am inclined to accept this work.
We thank the reviewer for the constructive feedback. The additional experiment results using ECE and NLL metrics on the in-distribution (existing) test data are provided in the following tables. Results show that SharpBalance outperforms the two other baselines in both metrics on the existing test data. We will include the new results and clarification in the updated draft.
Table. ECE metric on CIFAR100. "ECE" represents the expected calibration error. The model architecture is ResNet-18.
| Method / Number of models | 2 | 3 | 4 | 5 | 6 |
|---|---|---|---|---|---|
| Deep Ensemble | 0.0310 | 0.0206 | 0.0191 | 0.0186 | 0.0180 |
| Deep Ensemble + SAM | 0.0184 | 0.0179 | 0.0160 | 0.0151 | 0.0141 |
| SharpBalance | 0.0151 | 0.0142 | 0.0130 | 0.0128 | 0.0110 |
Table. NLL metric on CIFAR100. "NLL" represents negative log-likelihood. The model architecture is ResNet-18.
| Method / Number of models | 2 | 3 | 4 | 5 | 6 |
|---|---|---|---|---|---|
| Deep Ensemble | 0.820 | 0.785 | 0.763 | 0.749 | 0.742 |
| Deep Ensemble + SAM | 0.742 | 0.724 | 0.718 | 0.711 | 0.705 |
| SharpBalance | 0.732 | 0.720 | 0.715 | 0.709 | 0.695 |
We want to thank all the reviewers for the constructive feedback, which helps us improve our paper. Please refer to the attached PDF for our new experiments and see below for our responses to each comment.
Dear Reviewers,
The deadline for the reviewer-author discussion period is approaching. If you haven't done so already, please review the rebuttal and provide your response at your earliest convenience.
Best wishes, AC
While none of the reviewers expressed strong enthusiasm for the paper, they agree that it presents a solid framework offering a theory on the sharpness-diversity tradeoff, along with an algorithm to enhance the diversity of flat ensembles. The topic is very important to the ensemble learning and Bayesian deep learning communities, and I believe the proposed theory on the tradeoff would be a valuable contribution to the field. Although there are some valid concerns about the experiments, I believe these do not outweigh the merits of the work.