Learn2Mix: Training Neural Networks Using Adaptive Data Integration
This work introduces learn2mix, a new training strategy that adaptively adjusts class proportions in batches to accelerate neural network convergence in resource-constrained environments.
Abstract
Reviews and Discussion
This paper proposes a novel training strategy meant to optimize training convergence by adaptively adjusting class proportions within batches during training. Specifically, this work is motivated by the challenge posed by imbalanced training datasets. This method is called Learn2Mix, and it differs from existing methods that address imbalance by adjusting sample weights or by resampling, since it dynamically adjusts the composition of batches during training based on per-class error rates. The results show that convergence is faster with learn2mix (under conditions such as convergence of the class proportions), that overfitting is prevented, and that the method works across several tasks (classification, regression, and reconstruction). Datasets include MNIST, Fashion-MNIST, CIFAR-10, Imagenette, CIFAR-100, IMDB, Wine Quality, and California Housing. Results for all tasks across all datasets indicate best performance for Learn2Mix, relative to several benchmark and SoTA methods.
Strengths and Weaknesses
Quality:
- This paper seems to be technically sound and claims seem to be supported, both through analysis (theory in Section 2) and through experiments (Section 4), though I was not able to verify all proofs.
- This seems to be a complete work, wherein evaluations are comprehensive.
- The authors do not comment on the weaknesses or limitations of their work.
Clarity:
- I found this paper to be unusually heavy with notation. This meant that reading and understanding was encumbered, even though explanations were perhaps more precise than in the average paper. I would suggest reducing the heavy reliance on elaborate notation, and instead move these parts to the appendices.
- Section 3 was also difficult to follow, with heavy reliance on elaborate notation. I would suggest describing in writing more details of how the class proportions are updated, since this remained somewhat unclear.
- Otherwise, the paper is clear and has enough detail to reproduce results.
- The introduction section is well-written and clear, especially the listing of contributions, though background is rather lengthy.
- In Definition 2.2, it would seem that the loss for Learn2Mix is really the same, essentially; however, it is just that the mixing parameters are fixed for Classical training and are adaptive for Learn2Mix training. The only real piece of new information in this definition is the description, Eq. (3), of how the mixing parameters are updated. I would suggest revising this to emphasize the update step.
- In the paragraph starting on line 206, it is rather confusing reading about regression tasks with categorical variables. Is this not classification? Do you instead mean that the dataset is comprised of several distinct variables?
Significance:
- This work appears to be significant in that the method is novel and that it advances the field relative to all compared benchmark methods.
- This is a meaningful step towards the overall optimization of the training of deep neural networks.
Originality:
- The method is original and novel.
Questions
- In some of the results, such as those of Fig 3 (c), the Learn2Mix training error remains relatively high compared to the Classical training error, whereas test error is improved relative to classical. Is there some insight to be learned from this? Why does train error remain higher?
- Are there any downsides to your method relative to SoTA? This was not discussed. Could you imagine any types of scenarios in which your method would perform poorly or where existing methods would be preferable?
- Given that the notation was rather elaborate here, could you please describe in words how the class proportions are updated during training? Where does gamma come from? In Eq. (3), what is the meaning of the ‘1’ function in the denominator?
Limitations
- Are there any downsides to your method relative to SoTA? This was not discussed. Could you imagine any types of scenarios in which your method would perform poorly or where existing methods would be preferable?
- Additionally, consider any societal impacts, perhaps related to biases in training datasets and whether these could be exacerbated.
Final Justification
I maintain my rating of 5-accept. Authors have addressed all my concerns, especially those surrounding improving clarity. I have not updated my ratings in my review (based on original submission) though my concerns were addressed in rebuttals. See my comments for details.
Formatting Issues
None.
Dear Reviewer 43t3,
Thank you for your comprehensive feedback. We will improve the terminology and the presentation of the paper as suggested. Below, we have provided responses to your comments.
I found this paper to be unusually heavy with notation. This meant that reading and understanding was encumbered, even though explanations were perhaps more precise than in the average paper. I would suggest reducing the heavy reliance on elaborate notation, and instead move these parts to the appendices. Section 3 was also difficult to follow, with heavy reliance on elaborate notation. I would suggest describing in writing more details of how the class proportions are updated, since this remained somewhat unclear.
Thank you for the constructive feedback. We will simplify our notation and expand the prose description of the class-proportion update in Section 3 of the revised manuscript.
In Definition 2.2, it would seem that the loss for Learn2Mix is really the same, essentially; however, it is just that the mixing parameters are fixed for Classical training and are adaptive for Learn2Mix training. The only real piece of new information in this definition is the description, Eq. (3), of how the mixing parameters are updated. I would suggest revising this to emphasize the update step.
Thank you for this suggestion. We will revise Definition 2.2 to streamline the loss formulation and emphasize the mixing parameter update.
In the paragraph starting on line 206, it is rather confusing reading about regression tasks with categorical variables. Is this not classification? Do you instead mean that the dataset is comprised of several distinct variables?
Recalling Section 4, we write that "for regression tasks with a categorical variable taking distinct values, the samples that correspond to each of the values can be aggregated to obtain each class-specific training dataset." Here, "categorical variable" simply means any feature that can be used to partition the data into groups. Many regression benchmarks include such categorical features (e.g., California Housing), which can be leveraged to group samples by similarity. Learn2mix treats each group exactly like a "class", maintaining one mixing parameter per group and adapting batch proportions based on each group's error. No true class labels are required: any meaningful partition of the dataset suffices. We will clarify this grouping interpretation in the revised manuscript.
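To make this concrete, below is a minimal sketch (our own illustration, not code from the paper) of partitioning a regression dataset by a categorical feature so that each group plays the role of a "class"; the column names and values are hypothetical:

```python
import pandas as pd

# Toy regression dataframe with a hypothetical categorical feature "region".
df = pd.DataFrame({
    "region":  ["A", "B", "A", "C", "B", "C"],
    "feature": [0.2, 1.5, 0.7, 3.1, 2.2, 0.9],
    "target":  [1.0, 2.0, 1.2, 3.5, 2.4, 1.1],
})

# One class-like training subset per categorical value; learn2mix would then
# maintain one mixing parameter per group and adapt batch proportions by group error.
group_datasets = {value: sub[["feature", "target"]].to_numpy()
                  for value, sub in df.groupby("region")}

for value, data in group_datasets.items():
    print(value, data.shape)
```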
In some of the results, such as those of Fig 3 (c), the Learn2Mix training error remains relatively high compared to the Classical training error, whereas test error is improved relative to classical. Is there some insight to be learned from this? Why does train error remain higher?
This pattern reflects reduced overfitting under learn2mix: as we note in the manuscript, learn2mix "has a tighter alignment between training and test errors versus classical training. This correspondence indicates reduced overfitting, as learn2mix inherently adjusts class proportions based on class-specific error rates. By biasing the optimization procedure away from the original class distribution and towards the error-driven proportions, learn2mix achieves improved generalization." In other words, learn2mix reallocates learning effort toward higher-error groups instead of exhaustively minimizing loss on easier ones, resulting in a modestly higher overall training error and improved test performance (the aforementioned "tighter alignment").
Are there any downsides to your method relative to SoTA? This was not discussed. Could you imagine any types of scenarios in which your method would perform poorly or where existing methods would be preferable? Additionally, consider any societal impacts, perhaps related to biases in training datasets and whether these could be exacerbated.
As mentioned above, the application of learn2mix to regression and reconstruction tasks requires samples to be partitioned into groups by similarity (e.g., via a categorical variable taking distinct values). Likewise, in settings with extreme, systematic label corruption inside a group, learn2mix (like other loss-driven reweighting methods such as focal loss or importance sampling) could over-allocate updates to that group. We will include these limitations in the updated manuscript.
Given that the notation was rather elaborate here, could you please describe in words how the class proportions are updated during training? Where does gamma come from? In Eq. (3), what is the meaning of the ‘1’ function in the denominator?
At the end of each epoch, we compute the average loss for each of the $k$ classes, forming the class-wise loss vector $\boldsymbol{\varepsilon}_t$. We then normalize this vector into a probability distribution by dividing by its sum, $\mathbf{1}^{\top}\boldsymbol{\varepsilon}_t$, where $\mathbf{1}$ is a length-$k$ vector of ones (the '1' in the denominator), which ensures the entries sum to one. Next we nudge the current mixing vector $\boldsymbol{\alpha}_t$ toward that distribution:
- $\boldsymbol{\alpha}_{t+1} = \boldsymbol{\alpha}_t + \gamma\left(\frac{\boldsymbol{\varepsilon}_t}{\mathbf{1}^{\top}\boldsymbol{\varepsilon}_t} - \boldsymbol{\alpha}_t\right)$,
whereby classes with higher losses receive a larger share of samples in the next epoch. The scalar mixing rate, $\gamma$, is a user-defined step-size hyperparameter that controls how aggressively $\boldsymbol{\alpha}_t$ moves. In our experiments, we select $\gamma$ via a small validation sweep. This update gradually reallocates batch proportions toward underperforming classes.
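As a concrete illustration, here is a minimal sketch of this update (our own code, not from the paper; the array names, loss values, and the particular value of $\gamma$ are illustrative assumptions):

```python
import numpy as np

def update_mixing(alpha, class_losses, gamma):
    """Nudge class proportions toward the normalized class-wise loss vector."""
    target = class_losses / class_losses.sum()  # losses normalized into a probability distribution
    return alpha + gamma * (target - alpha)     # gamma = 0 recovers fixed (classical) proportions

alpha = np.array([0.25, 0.25, 0.25, 0.25])      # current mixing proportions (sum to 1)
losses = np.array([0.8, 0.2, 0.5, 1.5])         # per-class average losses from the last epoch
print(update_mixing(alpha, losses, gamma=0.2))  # the hardest class receives a larger share
```

Since both the current proportions and the normalized losses sum to one, the updated vector remains a valid probability distribution without renormalization.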
Thank you for your positive feedback and support of our paper. We sincerely appreciate your insightful comments, which have helped improve our work.
Thank you for the response. In general, this helps and improves the clarity of your paper and my understanding.
In a few cases, you have made promises to make updates ("We will simplify our notation and expand the prose description of the class-proportion update in Section 3 of the revised manuscript.", and "We will revise Definition 2.2 to streamline the loss formulation and emphasize the mixing parameter update."), but it would be most useful during this review stage if you could provide a snippet of what your planned revision would be (i.e., show me your revised notation, etc.).
Lastly, regarding the selection of gamma, please include this explanation in your paper. In the footnote on page 6, this is not mentioned and there is no linking to "empirical results".
Dear Reviewer 43t3,
Thank you for your follow-up comments. Regarding the selection of $\gamma$, we will include in the updated footnote the explanation that we select $\gamma$ via a small validation sweep. We will also include empirical results demonstrating the relative insensitivity to the choice of $\gamma$ (which we included in our response to Reviewer Ppx6).
As the reviewer has suggested, we have provided an updated version of Definition 2.2 below that better describes the mixing parameter update step (see italicized portions below). We include Eq. (2) solely to make explicit that the loss is evaluated with the current (time-varying) mixing parameters, differentiating it from classical training.
Definition 2.2 (Loss Function for Learn2Mix Training):
Consider $\boldsymbol{\alpha}_t$ and $\boldsymbol{\alpha}_{t+1}$ as the vectors of mixing parameters at times $t$ and $t+1$, and let $\boldsymbol{\varepsilon}_t(\boldsymbol{\theta}_t)$ and $\boldsymbol{\varepsilon}_{t+1}(\boldsymbol{\theta}_{t+1})$ denote the respective class-wise loss vectors at times $t$ and $t+1$. Consider $\gamma$ as the mixing rate. The loss for learn2mix training at time $t$ is given by:
- $\mathcal{L}_{\mathrm{l2m}}(\boldsymbol{\theta}_t) = \boldsymbol{\alpha}_t^{\top}\,\boldsymbol{\varepsilon}_t(\boldsymbol{\theta}_t)$, (2)
- where the mixing parameters are updated as $\boldsymbol{\alpha}_{t+1} = \boldsymbol{\alpha}_t + \gamma\left(\frac{\boldsymbol{\varepsilon}_t(\boldsymbol{\theta}_t)}{\mathbf{1}^{\top}\boldsymbol{\varepsilon}_t(\boldsymbol{\theta}_t)} - \boldsymbol{\alpha}_t\right)$, (3)
where $\mathbf{1}$ is a length-$k$ vector of ones.
We note that the denominator, $\mathbf{1}^{\top}\boldsymbol{\varepsilon}_t(\boldsymbol{\theta}_t)$, in Eq. (3) is the sum of losses across all classes, and dividing by it converts $\boldsymbol{\varepsilon}_t(\boldsymbol{\theta}_t)$ into a probability distribution. We update $\boldsymbol{\alpha}_t$ by nudging the mixing parameters toward this probability distribution, whereby classes with higher losses receive a larger share of samples in the next time step. The scalar mixing rate, $\gamma$, is a user-defined step-size hyperparameter that controls how aggressively $\boldsymbol{\alpha}_t$ moves. We note that classical training is recovered by setting $\gamma = 0$, which keeps the mixing parameters fixed across time steps.
Regarding Section 3, we note that the paragraph starting at line 173 in the current manuscript is now better aligned with the updated Definition 2.2 provided above. For additional clarity, we have revised it below:
Algorithm 2 outlines the method for dynamically adjusting class proportions using the mixing parameters, $\boldsymbol{\alpha}_t$, informed by computed class-wise losses. Specifically, for each class $c$, we first calculate the normalized loss by dividing the class-specific loss by the total cumulative loss summed over all classes. Each mixing parameter, $\alpha_{t,c}$, is then updated incrementally towards this normalized loss value. The magnitude of the update step is controlled by the mixing rate, $\gamma$, determining how quickly the proportions adapt. Thus, classes exhibiting higher relative losses are progressively given greater emphasis in subsequent training epochs, ensuring a balanced reduction of errors across all classes.
We hope these revised descriptions improve readability and will incorporate them into the updated manuscript. We welcome any further suggestions from the reviewer to enhance clarity.
Thank you for providing the updated text. I am content with the proposed changes you have provided.
Regarding the denominator, I suggest changing this to a sum for clarity (I personally find an explicit sum clearer / more obvious than matrix multiplication with a vector of ones), as is described in the text. I suggest changing the denominator $\mathbf{1}^{\top}\boldsymbol{\varepsilon}_t(\boldsymbol{\theta}_t)$
to $\sum_{c=1}^{k}\varepsilon_{t,c}(\boldsymbol{\theta}_t)$.
However, this is at your discretion. Overall, I am pleased with your response.
[Comment edited to get latex support working]
Learn2mix introduces an adaptive batching rule that increases the proportion of hard classes inside every mini-batch.
Strengths and Weaknesses
Weakness:
- Assumptions like every class-wise loss being strongly convex and Lipschitz in the network parameters are too strong and do not apply to modern neural networks.
- The paper’s experiment section does not scale to a reasonably large dataset like ImageNet.
- The difficulty of a class is judged solely by its loss at the current step, so classes with high label noise (wrong labels) could receive extra sampling and dominate training. This could happen in real-world datasets, while the datasets tested in the paper do not cover this issue.
- The appendix lists architectures and a single γ per dataset, but there is no sweep or variance report. For anyone who wants to try learn2mix, it is hard to determine which mixing rate to set.
- More analysis and visualization of how learn2mix affects per-class accuracy could be presented to better understand the proposed method.
- The proposed method can only be applied to classification tasks.
Questions
see weaknesses
Limitations
n/a
Final Justification
Overall, I acknowledge that this paper achieves good results within its scope (as pointed out by other reviewers), and it should definitely be accepted if it were a NeurIPS 2020 submission. However, I find a significant gap between this paper and deep learning in 2025. I will keep my current rating.
Formatting Issues
n/a
Dear Reviewer CXfQ,
Thank you for your helpful feedback. Below, we have provided responses to your comments.
Assumptions like every classwise loss being strongly convex and lipschitz in network parameters are too strong and do not apply to modern neural networks.
We note that assumptions like strong convexity are standard in the machine learning literature when deriving convergence rates for gradient-based optimization methods [1], [2]. These assumptions provide a manageable framework to obtain rigorous and provable insights into the behavior of optimization algorithms. While the strong convexity assumption may not hold strictly in practical neural network training, our theoretical findings offer valuable intuition about the convergence properties of learn2mix. Moreover, our empirical results demonstrate that learn2mix effectively accelerates convergence in real-world non-convex settings, reinforcing the practical relevance of our work.
The paper’s experiment section does not scale to a reasonably large dataset like ImageNet.
For the present study, we intentionally restricted ourselves to CIFAR-100-LT (using a logarithmically decaying imbalance factor to exhibit long-tail behavior), Imagenette (using a stepwise imbalance factor), and similar benchmarks because they are canonical in the imbalanced learning literature [1], admit controllable imbalance factors, and can be reproduced without access to large compute resources. Imagenette itself is a carefully chosen subset of ImageNet and reflects realistic high-dimensional image distributions. Additionally, our wall-clock results obtained on an NVIDIA GeForce RTX 3090 GPU confirm that learn2mix introduces no scaling bottlenecks beyond standard SGD.
The difficulty of a loss is judged solely by the loss at the step, classes with high label noise (wrong labels) could receive extra sampling and can dominate training. This could be happening for a real-world dataset while the dataset tested in the paper does not cover this issue.
We appreciate this point raised by the reviewer. However, we emphasize that this issue of incorrectly labeled (high-noise) classes dominating training is not specific to learn2mix, and is shared by all error-driven sampling methods in the imbalanced learning space, such as class-balanced loss functions (e.g., focal loss) and importance sampling. As stated previously, our present study focuses on a variety of canonical and standard benchmarks (e.g., CIFAR-100-LT [1]), since these datasets are widely used in the imbalance learning literature for comparisons under controlled and reproducible conditions. Extending learn2mix for applications in high label-noise scenarios remains an important direction for future work.
The appendix lists architectures and a single per dataset, but there is no sweep or variance report. For any other people who may want to try learn2mix, it is hard to determine which mixing rate to set. More analysis and visualization on how learn2mix affect per-class accuracy could be presented to understand the proposed method better.
We thank the reviewer for these suggestions. Regarding $\gamma$, choosing the value we recommend in Section 4 of the main text yields the best performance. We provide additional empirical results below for the worst per-class classification accuracy on Imagenette and IMDB across three settings of $\gamma$ to demonstrate relative insensitivity to this choice. We will include this result in the revised manuscript.
Table 1: Worst-class test performance comparison for ResNet-18 on Imagenette
| Method | Worst-Class Accuracy (57.5 s) | Worst-Class Accuracy (115.0 s) | Worst-Class Accuracy (230.0 s) |
|---|---|---|---|
| L2M () | 0.0000 % | 2.3866 % | 14.1623 % |
| L2M () | 1.3149 % | 10.9318 % | 14.1451 % |
| L2M () | 0.2542 % | 3.2815 % | 15.6304 % |
| CL | 0.7779 % | 0.6205 % | 0.6325 % |
| FCL | 0.0000 % | 0.0000 % | 4.8162 % |
| SMOTE | 0.0000 % | 0.0000 % | 8.9119 % |
| IS | 0.0000 % | 0.0000 % | 0.7924 % |
| CURR | 0.0000 % | 0.2482 % | 2.8140 % |
Table 2: Worst-class test performance comparison for Transformer on IMDB
| Method | Worst-Class Accuracy (38.0 s) | Worst-Class Accuracy (75.0 s) | Worst-Class Accuracy (150.0 s) |
|---|---|---|---|
| L2M () | 46.2067 % | 77.4133 % | 79.2200 % |
| L2M () | 40.2867 % | 76.6267 % | 80.7400 % |
| L2M () | 60.3733 % | 74.4467 % | 81.8533 % |
| CL | 0.0000 % | 47.6400 % | 67.3933 % |
| FCL | 0.0000 % | 34.5333 % | 60.7200 % |
| SMOTE | 11.3200 % | 28.7867 % | 62.5400 % |
| IS | 24.8467 % | 57.7200 % | 69.6267 % |
| CURR | 0.0000 % | 48.4400 % | 66.0333 % |
The proposed method can only be applied to classification tasks.
We note that the reviewer’s claim is not correct; we explicitly detail how learn2mix can be applied to regression and reconstruction tasks (not just classification tasks) in Section 4 of the main text through its ability to adaptively handle different data distributions. We further provide full empirical results in Section 4 on standard benchmarks for regression and image reconstruction, demonstrating that learn2mix maintains accelerated convergence over classical training.
Thank you again for your valuable feedback, which has been important in helping improve our paper.
References:
[1] Guanghui Wang, Shiyin Lu, Quan Cheng, Wei wei Tu, and Lijun Zhang. SAdam: A variant of adam for strongly convex functions. In International Conference on Learning Representations, 2020.
[2] Arindam Banerjee, Pedro Cisneros-Velarde, Libin Zhu, and Misha Belkin. Restricted Strong Convexity of Deep Learning Models with Smooth Activations. In International Conference on Learning Representations, 2023.
Dear Reviewer CXfQ,
Thank you for your thoughtful feedback. We note that recent works from the past year continue to adopt CIFAR-100-LT [3] as a canonical benchmark for long-tailed imbalanced learning, including papers from NeurIPS, ICML, and CVPR [4-7], affirming its relevance in current literature. We understand your broader concern about applicability beyond class-labeled data; however, learn2mix operates at the level of distributions wherein samples from the dataset can be grouped by similarity (via any user-defined similarity metric) and does not require class labels in a strict sense. The mean estimation benchmark from our empirical results is an example of this.
We appreciate the reviewer’s perspective regarding modern datasets; however, this claim does not fully reflect current practice. For instance, the widely used GLUE benchmark [8] comprises several class-labeled datasets, and remains a primary benchmark for training and evaluating language models, despite being significantly older than CIFAR-100-LT. If the age of a dataset alone were a disqualifying factor, then under the reviewer’s criterion, a large body of recent work evaluating on GLUE would also be misaligned with the current state of the field.
We would like to clarify that the CIFAR-100-LT dataset referenced in our study originates from AAAI 2021 [3], which was not present in our earlier rebuttal (we note this reference was included in our responses to the other reviewers). The ICLR 2020/2023 citations from our earlier rebuttal pertained to theoretical assumptions (e.g., strong convexity) used in our analysis.
We appreciate your recognition of the paper’s contributions within its defined scope.
References:
[3] Zhang, Yongshun, et al. "Bag of tricks for long-tailed visual recognition with deep convolutional neural networks." Proceedings of the AAAI conference on artificial intelligence. Vol. 35. No. 4. 2021.
[4] Li, Mengke, et al. "Improving visual prompt tuning by Gaussian neighborhood minimization for long-tailed visual recognition." Advances in Neural Information Processing Systems 37 (2024): 103985-104009.
[5] Shao, Jie, et al. "DiffuLT: Diffusion for Long-tail Recognition Without External Knowledge." Advances in Neural Information Processing Systems 37 (2024): 123007-123031.
[6] Gao, Jintong, et al. "Distribution alignment optimization through neural collapse for long-tailed classification." Forty-first International Conference on Machine Learning. 2024.
[7] Rangwani, Harsh, et al. "Deit-lt: Distillation strikes back for vision transformer training on long-tailed datasets." Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2024.
[8] Wang, Alex, et al. "GLUE: A multi-task benchmark and analysis platform for natural language understanding." arXiv preprint arXiv:1804.07461 (2018).
Thanks to the authors for the additional replies.
GLUE is considered outdated. We no longer see GLUE evaluated as a benchmark when new LLMs are released these days. I acknowledge that there seems to be a fundamental misalignment between the authors and me on this matter. However, I do believe a work that emphasizes its efficiency in training/convergence should keep large-scale datasets/training in mind. Due to the huge cost of data labeling, we most probably won't see a class-labeled dataset larger than ImageNet-21K.
I encourage the authors to think about how their approach could be applied to datasets with no labels (e.g., using some kind of grouping approach). For example, for LLM training, could the authors' method be applied to balance the data mixture from different domains? Finally, I also encourage the authors to revise their paper with lighter notation, as also mentioned by other reviewers.
I will raise my score to 3 (borderline reject) as I see potential in the authors' method for addressing the data imbalance issue at larger scale. However, I hold my original standpoint that there is a gap between this work and current deep learning practice.
Thanks to the authors for the rebuttal.
For the present study, we intentionally restricted ourselves to CIFAR-100-LT (using a logarithmically decaying imbalance factor to exhibit long-tail behavior), Imagenette (using a stepwise imbalance factor), and similar benchmarks because they are canonical in the imbalanced learning literature [1], admit controllable imbalance factors, and can be reproduced without access to large compute resources. Imagenette itself is a carefully chosen subset of ImageNet and reflects realistic high-dimensional image distributions. Additionally, our wall-clock results obtained on an NVIDIA GeForce RTX 3090 GPU confirm that learn2mix introduces no scaling bottlenecks beyond standard SGD.
The paper the authors cite here to justify using a relatively small dataset is from ICLR 2020. In a fast-evolving field, I doubt whether empirical results on MNIST, CIFAR, etc. are still of value in 2025. I encourage the authors to justify this by finding more recently published papers that conduct small-scale experiments.
We appreciate this point raised by the reviewer. However, we emphasize that this issue of incorrectly labeled (high-noise) classes dominating training is not specific to learn2mix, and is shared by all error-driven sampling methods in the imbalanced learning space, such as class-balanced loss functions (e.g., focal loss) and importance sampling. As stated previously, our present study focuses on a variety of canonical and standard benchmarks (e.g., CIFAR-100-LT [1]), since these datasets are widely used in the imbalance learning literature for comparisons under controlled and reproducible conditions. Extending learn2mix for applications in high label-noise scenarios remains an important direction for future work.
Although assuming a clean dataset seems to be standard practice in the field, I encourage the authors to consider this (by mixing in artificial label errors).
The proposed method can only be applied to classification tasks.
We note that the reviewer’s claim is not correct; we explicitly detail how learn2mix can be applied to regression and reconstruction tasks (not just classification tasks) in Section 4 of the main text through its ability to adaptively handle different data distributions. We further provide full empirical results in Section 4 on standard benchmarks for regression and image reconstruction, demonstrating that learn2mix maintains accelerated convergence over classical training.
My apologies for the confusion. What I mean here is that the method requires at least that the concept of classes exists for a dataset. However, today's large-scale datasets, be they vision-language datasets (LAION) or language datasets, no longer have class labels.
Overall, I acknowledge that this paper achieves good results within its scope (as pointed out by other reviewers), and it should definitely be accepted if it were a NeurIPS 2020 submission. However, I find a significant gap between this paper and deep learning in 2025. I will keep my current rating.
This paper proposes an adaptive training strategy, learn2mix, which dynamically adjusts the proportion of classes within a batch based on real-time class error rates. The proposed method presents accelerated convergence in classification, regression, and reconstruction tasks, and is particularly effective in resource-constrained and class-imbalanced situations. Extensive experiments demonstrate the effectiveness of the proposed method.
Strengths and Weaknesses
Strengths
- The proposed adaptive training strategy with real-time class-wise error rates is interesting.
- This paper builds a unified framework applicable to classification, regression, and reconstruction tasks.
- This paper provides a comprehensive theoretical foundation for the proposed methodology.
- The authors provide clear algorithmic implementations for reproduction.
Weaknesses
- The authors conducted experiments using classical CNNs and Transformers, while lacking validation of recently advanced architectures such as diffusion models and Mamba.
- The computational overhead of real-time error calculation for massive k is unquantified.
- The experiments are conducted on small-scale datasets, lacking validations on large-scale datasets, such as ImageNet-21k.
Questions
- Whether the proposed training strategy is applicable to other network architectures besides CNN and Transformer.
- Whether the sequence length in Transformers impacts adaptive batching.
Limitations
Please refer to the weakness.
Final Justification
Although the authors do not provide evidence that the proposed method is equally applicable to other architectures, I still believe that the method is interesting and innovative enough to be published at NIPS. However, I suggest that the authors conduct experiments to prove the effectiveness of the proposed method on diffusion models and Mamba. I would like to keep my initial score.
Formatting Issues
N/A
Dear Reviewer wW5J,
Thank you for your constructive feedback. Below, we have provided responses to your comments.
The authors conducted experiments using classical CNNs and Transformers, while lacking validation of recently advanced architectures such as diffusion models and Mamba.
We consider standard backbones including ResNet‑18, MobileNet‑V3, and Transformers, across classification, regression, and reconstruction tasks to demonstrate the efficacy of learn2mix under diverse yet reproducible settings. The empirical results in Section 4 of the main text and Section B of the Appendix show that performance gains are consistent, indicating that learn2mix is largely architecture-agnostic. We note that learn2mix does not assume a particular model family. Thus, the U‑Nets used in diffusion models and sequence models such as Mamba can be trained with the same procedure. Exploring domain-specific architectures is an important extension of learn2mix that we plan to explore in future work.
The computational overhead of real-time error calculation for massive k is unquantified.
We note that the computational overhead of real-time error calculation is quantified in our empirical results for classification tasks in Section 4 of the main text, and in Section B of the Appendix. All reported results, including those for CIFAR-100 (where we consider a large number of classes $k$), are benchmarked in terms of elapsed wall-clock training time on an NVIDIA GeForce RTX 3090 GPU, which accounts for all per-class error computations. As such, the overhead is fully integrated into our results.
The experiments are conducted on small-scale datasets, lacking validations on large-scale datasets, such as ImageNet-21k.
For the present study, we intentionally restricted ourselves to CIFAR-100-LT (using a logarithmically decaying imbalance factor to exhibit long-tail behavior), Imagenette (using a stepwise imbalance factor), and similar benchmarks because they are canonical in the imbalanced learning literature [1], admit controllable imbalance factors, and can be reproduced without access to large compute resources. Imagenette itself is a carefully chosen subset of ImageNet and reflects realistic high-dimensional image distributions. Additionally, our wall-clock results obtained on an NVIDIA GeForce RTX 3090 GPU confirm that learn2mix introduces no scaling bottlenecks beyond standard SGD.
Whether the proposed training strategy is applicable to other network architectures besides CNN and Transformer.
We note that learn2mix is architecture agnostic because it alters only the dataloader side of training, adapting class proportions within batches to accelerate convergence. We demonstrate the efficacy of learn2mix on a variety of neural network architectures apart from standard CNNs and transformers, including residual networks, mobile networks, and autoencoders.
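To illustrate the dataloader-side nature of the change, here is a minimal sketch (our own code with toy index sets, not the authors' implementation) of assembling a batch whose per-class counts follow the current mixing proportions:

```python
import numpy as np

def sample_batch(class_indices, alpha, batch_size, rng):
    """class_indices: one array of dataset indices per class; alpha: current mixing proportions."""
    counts = np.floor(alpha * batch_size).astype(int)
    counts[np.argmax(alpha)] += batch_size - counts.sum()    # assign the rounding remainder
    chosen = [rng.choice(idx, size=c, replace=c > len(idx))  # oversample a class if it is too small
              for idx, c in zip(class_indices, counts) if c > 0]
    return np.concatenate(chosen)

rng = np.random.default_rng(0)
class_indices = [np.arange(0, 500), np.arange(500, 550)]  # imbalanced toy index sets (two classes)
alpha = np.array([0.4, 0.6])                              # proportions shifted toward the harder class
print(sample_batch(class_indices, alpha, batch_size=32, rng=rng).shape)  # (32,)
```

The model and loss are untouched; only which indices enter each batch changes.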
Whether the sequence length in Transformers impacts adaptive batching.
We clarify that learn2mix is agnostic to sequence length (and more broadly to dataset dimensionality), since it adaptively modulates training based solely on the proportion of samples from each class, guided by instantaneous class-wise error rates. The sequence length of transformers pertains to input dimensionality rather than class proportions, and as demonstrated consistently across empirical evaluations, covering datasets of varying dimensions (e.g., CIFAR-100, Imagenette, IMDB), learn2mix’s adaptive batching operates independently of sequence length.
Thank you again for your feedback, which has been important in improving our paper.
References:
[1] Zhang, Yongshun, et al. "Bag of tricks for long-tailed visual recognition with deep convolutional neural networks." Proceedings of the AAAI conference on artificial intelligence. Vol. 35. No. 4. 2021.
Thank you for your responses. Although the authors do not provide evidence that the proposed method is equally applicable to other architectures, I still believe that the method is interesting and innovative enough to be published at NIPS. However, I suggest that the authors conduct experiments to prove the effectiveness of the proposed method on diffusion models and Mamba. I would like to keep my initial score.
This paper introduces learn2mix, a training strategy that dynamically adjusts class proportions within batches based on real-time class-wise error rates to accelerate convergence. Unlike classical training that maintains fixed class proportions mirroring the dataset distribution, learn2mix continuously adapts batch composition by emphasizing classes with higher instantaneous error rates. The authors provide theoretical convergence guarantees under strong convexity and Lipschitz continuity assumptions, and demonstrate consistent empirical improvements across classification, regression, and reconstruction tasks on multiple benchmark datasets.
Strengths and Weaknesses
Strengths
Clear Problem Motivation: The paper addresses a well-motivated problem in neural network training. The intuition that harder classes should receive more attention during training is compelling and well-articulated.
Solid Theoretical Foundation: The authors provide rigorous convergence analysis showing that under appropriate conditions (strong convexity, Lipschitz continuity), learn2mix converges faster than classical training.
Consistent Performance Gains: The results demonstrate consistent improvements across different settings, with learn2mix achieving faster convergence in nearly all evaluated scenarios. The improvements appear robust to architectural choices, optimizers, and hyperparameter settings based on the ablation studies.
Weaknesses
Narrow Baseline Comparison: While multiple baselines are included, the comparison relies on dated methods for imbalanced learning (SMOTE from 2002, Focal Loss from 2017, curriculum learning variant from 2019). Recent advances in imbalanced learning and adaptive training strategies are not considered, limiting assessment of competitiveness against state-of-the-art approaches.
Limited Scale Evaluation: The evaluation is restricted to relatively small benchmark datasets (MNIST, CIFAR-10/100, etc.) without demonstration on larger, more challenging datasets.
Restrictive Theoretical Assumptions: The convergence guarantees rely on strong convexity and Lipschitz continuity assumptions that rarely hold for neural networks in practice.
Limited Algorithmic Novelty: The core insight of emphasizing difficult examples is well-established in machine learning literature. While the specific implementation of adaptive batch composition is somewhat novel, the fundamental approach is incremental rather than groundbreaking.
Lack of Computational Analysis: The paper provides no analysis of computational overhead. The dynamic batch construction and class-wise loss computation likely introduce non-negligible computational costs that could offset the benefits of faster convergence.
Questions
- Theoretical Gap: How do the authors reconcile the restrictive theoretical assumptions with the practical success of neural networks, where these assumptions clearly don't hold? Can a tighter analysis be provided for the non-convex case?
- Computational Overhead: What is the computational cost of the dynamic batch construction and class-wise loss computation? How does this overhead compare to the time savings from faster convergence?
- Scalability: How does the method scale to datasets with very large numbers of classes?
Limitations
The authors adequately acknowledge some limitations, but could strengthen this discussion. The gap between theoretical assumptions and practical applicability should be emphasized more prominently. Additionally, the paper would benefit from explicit discussion of computational overhead, hyperparameter sensitivity, and scenarios where the method might be less effective.
The evaluation is also limited to relatively small benchmark datasets and does not demonstrate scalability to larger, more challenging datasets. Furthermore, the comparison baselines for imbalanced learning are dated. More recent advances in imbalanced learning and adaptive training strategies from the past few years are not considered, which limits the assessment of learn2mix's competitiveness against state-of-the-art approaches.
Final Justification
Computational overhead: Measuring wall-clock time adequately addresses this concern.
Baselines: Your justification doesn't resolve the core issue. Comparing against 2002-2019 methods in 2025 is problematic regardless of their "wide adoption." Recent work in adaptive sampling and curriculum learning would provide more meaningful comparisons.
Theory-practice gap: Saying strong convexity assumptions are "standard" sidesteps the real problem. These assumptions are violated in virtually all neural network settings, making the theoretical guarantees largely irrelevant. This limitation needs honest acknowledgment.
Scale: You didn't demonstrate scalability to large datasets or many classes. Wall-clock comparisons on small benchmarks don't address whether the method works on datasets with thousands of classes. The additional ablation results are useful, but the fundamental concerns about limited evaluation scope and outdated baselines remain unaddressed.
Formatting Issues
No
Dear Reviewer w9yV,
Thank you for your helpful feedback. Below, we have addressed your comments.
While multiple baselines are included, the comparison relies on dated methods for imbalanced learning (SMOTE from 2002, Focal Loss from 2017, curriculum learning variant from 2019). Recent advances in imbalanced learning and adaptive training strategies are not considered, limiting assessment of competitiveness against state-of-the-art approaches.
Our chosen benchmarks intentionally cover representative methods that have achieved wide acceptance and continued relevance in the machine learning community. These benchmarks cover the principal methods of imbalanced learning, including class-balanced loss functions (2017), importance sampling (2022/2018), oversampling (2017/2002), and curriculum learning (2019), to provide a comprehensive comparison. While newer approaches continue to emerge, the selected benchmarks remain broadly adopted in practice.
The evaluation is restricted to relatively small benchmark datasets (MNIST, CIFAR-10/100, etc.) without demonstration on larger, more challenging datasets. How does the method scale to datasets with very large numbers of classes?
Investigating the efficacy of learn2mix in these settings remains an important aspect of our ongoing research. For the present study we intentionally chose benchmarks that are canonical in the imbalanced learning literature, admit controllable imbalance factors, and can be reproduced without access to large clusters. CIFAR-100-LT (using a logarithmically decaying imbalance factor to exhibit long-tail behavior) and Imagenette (a subset of the high-dimensional ImageNet dataset) with stepwise imbalance satisfy these criteria and are standard in existing literature [1].
The convergence guarantees rely on strong convexity and Lipschitz continuity assumptions that rarely hold for neural networks in practice. How do the authors reconcile the restrictive theoretical assumptions with the practical success of neural networks, where these assumptions clearly don't hold? Can a tighter analysis be provided for the non-convex case?
We note that assumptions like strong convexity are standard in the machine learning literature when deriving convergence rates for gradient-based optimization methods [2], [3]. These assumptions provide a manageable framework to obtain rigorous and provable insights into the behavior of optimization algorithms. While the strong convexity assumption may not hold strictly in practical neural network training, our theoretical findings offer valuable intuition about the convergence properties of learn2mix. Moreover, our empirical results demonstrate that learn2mix effectively accelerates convergence in real-world non-convex settings, reinforcing the practical relevance of our work.
The core insight of emphasizing difficult examples is well-established in machine learning literature. While the specific implementation of adaptive batch composition is somewhat novel, the fundamental approach is incremental rather than groundbreaking.
While prioritizing hard examples is a long-standing idea, our contribution lies in how this prioritization is realized. Learn2mix is the first method, to our knowledge, that adapts class proportions in real-time during training via a provably convergent bilevel optimization procedure, while also retaining the simplicity and efficiency of vanilla SGD. We bridge this theory with rigorous empirical validation, demonstrating that learn2mix accelerates convergence across various classification, regression, and reconstruction benchmarks.
The paper provides no analysis of computational overhead. The dynamic batch construction and class-wise loss computation likely introduce non-negligible computational costs that could offset the benefits of faster convergence. What is the computational cost of the dynamic batch construction and class-wise loss computation? How does this overhead compare to the time savings from faster convergence?
The manuscript does measure and report computational cost: the classification results in Section 4 of the main text and all empirical results in Section B of the Appendix consider elapsed training time rather than epoch count. Each experiment was run for a fixed wall-clock duration (in seconds) on the same NVIDIA GeForce RTX 3090 GPU, which includes the overhead of dynamic batch construction and per-class loss computations; because the evaluation metric is end-to-end wall-clock time on fixed hardware, learn2mix's overhead is fully accounted for.
The authors adequately acknowledge some limitations, but could strengthen this discussion. The gap between theoretical assumptions and practical applicability should be emphasized more prominently. Additionally, the paper would benefit from explicit discussion of computational overhead, hyperparameter sensitivity, and scenarios where the method might be less effective.
Thank you for your comments. As detailed above, we report end-to-end wall-clock training time on a fixed NVIDIA GeForce RTX 3090 GPU, so any overhead from dynamic batch construction or class-wise loss computation is fully reflected in the results shown in Section 4 of the main text and Section B of the Appendix. The latter section also contains comprehensive ablations over batch size, learning rate, network architecture, and optimizer. To further demonstrate robustness, we provide worst per-class accuracy results on Imagenette and IMDB for several values of $\gamma$ in Table 1 and Table 2 (we will include this result in the revised manuscript). These results confirm minimal sensitivity to this parameter. We recall that in Section 4 of the main text, we recommend a particular choice of $\gamma$, as it yields similar improvements across all tasks.
Table 1: Worst-class test performance comparison for ResNet-18 on Imagenette
| Method | Worst-Class Accuracy (57.5 s) | Worst-Class Accuracy (115.0 s) | Worst-Class Accuracy (230.0 s) |
|---|---|---|---|
| L2M () | 0.0000 % | 2.3866 % | 14.1623 % |
| L2M () | 1.3149 % | 10.9318 % | 14.1451 % |
| L2M () | 0.2542 % | 3.2815 % | 15.6304 % |
| CL | 0.7779 % | 0.6205 % | 0.6325 % |
| FCL | 0.0000 % | 0.0000 % | 4.8162 % |
| SMOTE | 0.0000 % | 0.0000 % | 8.9119 % |
| IS | 0.0000 % | 0.0000 % | 0.7924 % |
| CURR | 0.0000 % | 0.2482 % | 2.8140 % |
Table 2: Worst-class test performance comparison for Transformer on IMDB
| Method | Worst-Class Accuracy (38.0 s) | Worst-Class Accuracy (75.0 s) | Worst-Class Accuracy (150.0 s) |
|---|---|---|---|
| L2M () | 46.2067 % | 77.4133 % | 79.2200 % |
| L2M () | 40.2867 % | 76.6267 % | 80.7400 % |
| L2M () | 60.3733 % | 74.4467 % | 81.8533 % |
| CL | 0.0000 % | 47.6400 % | 67.3933 % |
| FCL | 0.0000 % | 34.5333 % | 60.7200 % |
| SMOTE | 11.3200 % | 28.7867 % | 62.5400 % |
| IS | 24.8467 % | 57.7200 % | 69.6267 % |
| CURR | 0.0000 % | 48.4400 % | 66.0333 % |
Regarding limitations, we note that the application of learn2mix to regression and reconstruction tasks necessitates samples to be partitionable into groups by similarity (e.g., via a categorical variable taking distinct values). We will clarify this in the revised manuscript.
The evaluation is also limited to relatively small benchmark datasets and does not demonstrate scalability to larger, more challenging datasets. Furthermore, the comparison baselines for imbalanced learning are dated. More recent advances in imbalanced learning and adaptive training strategies from the past few years are not considered, which limits the assessment of learn2mix's competitiveness against state-of-the-art approaches.
As we detailed previously, for the present study, we intentionally restricted ourselves to CIFAR-100-LT (logarithmically decaying imbalance), Imagenette (stepwise imbalance), and similar benchmarks because they are canonical in the imbalanced learning literature [1], admit controllable imbalance factors, and can be reproduced without access to large compute resources. Our wall-clock results on an NVIDIA GeForce RTX 3090 GPU show that learn2mix introduces no scaling bottlenecks beyond standard SGD for extensions to larger datasets.
Thank you again for your detailed feedback, which has been significant in helping enhance our paper.
References:
[1] Zhang, Yongshun, et al. "Bag of tricks for long-tailed visual recognition with deep convolutional neural networks." Proceedings of the AAAI conference on artificial intelligence. Vol. 35. No. 4. 2021.
[2] Guanghui Wang, Shiyin Lu, Quan Cheng, Wei wei Tu, and Lijun Zhang. SAdam: A variant of adam for strongly convex functions. In International Conference on Learning Representations, 2020.
[3] Arindam Banerjee, Pedro Cisneros-Velarde, Libin Zhu, and Misha Belkin. Restricted Strong Convexity of Deep Learning Models with Smooth Activations. In International Conference on Learning Representations, 2023.
Thanks for the detailed response. A few key points:
Computational overhead: Measuring wall-clock time adequately addresses this concern.
Baselines: Your justification doesn't resolve the core issue. Comparing against 2002-2019 methods in 2025 is problematic regardless of their "wide adoption." Recent work in adaptive sampling and curriculum learning would provide more meaningful comparisons.
Theory-practice gap: Saying strong convexity assumptions are "standard" sidesteps the real problem. These assumptions are violated in virtually all neural network settings, making the theoretical guarantees largely irrelevant. This limitation needs honest acknowledgment.
Scale: You didn't demonstrate scalability to large datasets or many classes. Wall-clock comparisons on small benchmarks don't address whether the method works on datasets with thousands of classes. The additional ablation results are useful, but the fundamental concerns about limited evaluation scope and outdated baselines remain unaddressed.
Dear Reviewer HCei,
Thank you for your thoughtful feedback, and for acknowledging the wall-clock measurements in our experiments on identical hardware. Regarding our choice of datasets, we note that recent works from the past year continue to adopt CIFAR-100-LT [1] as a canonical benchmark for long-tailed imbalanced learning, including papers from NeurIPS, ICML, and CVPR [2-5], affirming its relevance in current literature. Due to our limited compute resources, we focused on these standard benchmarks in the present study. We understand your broader concern about scaling to larger class sizes, and note that benchmarking learn2mix in these settings is an important aspect of our ongoing work. We also note that our considered importance sampling approach is from 2022 [6], as we detail in the Appendix.
We respect the reviewer's judgement regarding the assumption of strong convexity for the theoretical results, and will explicitly acknowledge this in the limitations section of the updated manuscript. We note that the provisioning of these theoretical results was intended to be illustrative rather than prescriptive for deep neural networks. For example, we prove that learn2mix converges to a stable distribution that reflects the relative difficulty of each class under the optimal parameters (we provide empirical results in our response to Reviewer Ppx6 to verify this behavior).
References:
[1] Zhang, Yongshun, et al. "Bag of tricks for long-tailed visual recognition with deep convolutional neural networks." Proceedings of the AAAI conference on artificial intelligence. Vol. 35. No. 4. 2021.
[2] Li, Mengke, et al. "Improving visual prompt tuning by Gaussian neighborhood minimization for long-tailed visual recognition." Advances in Neural Information Processing Systems 37 (2024): 103985-104009.
[3] Shao, Jie, et al. "DiffuLT: Diffusion for Long-tail Recognition Without External Knowledge." Advances in Neural Information Processing Systems 37 (2024): 123007-123031.
[4] Gao, Jintong, et al. "Distribution alignment optimization through neural collapse for long-tailed classification." Forty-first International Conference on Machine Learning. 2024.
[5] Rangwani, Harsh, et al. "Deit-lt: Distillation strikes back for vision transformer training on long-tailed datasets." Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2024.
[6] Johansson, Mathias, and Emma Lindberg. "Importance sampling in deep learning: A broad investigation on importance sampling performance." (2022).
This paper proposes a form of adaptive class rebalancing that uses ideas from imbalanced classification for a form of curriculum learning. Some theoretical results are presented in the strongly-convex smooth setting, and experimental results show relatively small but consistent improvements over all considered competitors for several small image and tabular classification problems.
Strengths and Weaknesses
The proposed method is simple, easy to implement, and reasonably sensible. It also seems to consistently, though not earth-shatteringly, help on the settings you tried, although the considered experiments are a little more "2017-era" than a typical empirically-minded NeurIPS paper these days.
It would be worth mentioning, for readers who are not already familiar, that (a) a substantial portion (though certainly not all) of your theorems are textbook-standard results about the strongly convex smooth optimization setting, and (b) that neither of these conditions (particularly not strong convexity, but not smoothness either) apply to the neural networks you actually consider in experiments. Both are fine, but it would be better to be clear about this.
Questions
- What happens to $\boldsymbol{\alpha}_t$ during the training process? That is, you have an intuition that it orders the difficulty between classes or something along those lines, but...is that what actually happens? I'm very surprised that there were no experiments considering this.
- What about cases where there is a real fundamental difference in class difficulty? I don't think this is particularly the case in any of the datasets you considered here, but e.g. standard ImageNet classification has some classes which are far easier to confuse with one another, like the fine-grained different breeds of dogs, or certain inherently ambiguous classes.
- How do the per-class losses end up, compared to the standard case? The aggregate case is lower, but e.g. if you have a relatively rare class and a biggish $\gamma$, then I wonder how often getting lucky early on could lead to poor performance on that class in particular.
- Consider a situation where you can afford large batch sizes (in fact, you're running smallish models on a 3090, so you're basically in that situation). How would your sampling-based weighting scheme compare to an adaptive reweighting approach, where you still get a uniform / the dataset level of mixtures among classes, but the loss of each is reweighted?
Limitations
Societal impacts: it's conceivable this could have some kind of fairness type impacts, but I don't think that's especially likely or necessary to evaluate in this kind of paper.
Limitations of the proposed approach and analysis: not really discussed as thoroughly as they could be.
Formatting Issues
Dear Reviewer Ppx6,
Thank you for your valuable feedback. Below, we have addressed your comments.
It would be worth mentioning, for readers who are not already familiar, that (a) a substantial portion (though certainly not all) of your theorems are textbook-standard results about the strongly convex smooth optimization setting, and (b) that neither of these conditions (particularly not strong convexity, but not smoothness either) apply to the neural networks you actually consider in experiments. Both are fine, but it would be better to be clear about this.
Thank you for this suggestion. We will add a brief note in Section 2 stating that the propositions and their corollaries rest on standard results for strongly-convex objectives; we adopt this classical baseline intentionally because it lets us isolate how adaptive mixing influences convergence under well-understood conditions, before demonstrating the same mechanism’s empirical effectiveness for non-convex objectives via the architectures considered in Section 4.
What about cases where there is a real fundamental difference in class difficulty? I don't think this is particularly the case in any of the datasets you considered here, but e.g. standard ImageNet classification has some classes which are far easier to confuse with one another, like the fine-grained different breeds of dogs, or certain inherently ambiguous classes.
We provide empirical results on CIFAR-100 in Section 4 of the main text and in Section B of the Appendix, specifically to expose fine-grained confusion. The 100 classes in CIFAR-100 are arranged in 20 superclasses that group visually similar categories (e.g., fish types such as aquarium fish, flatfish, and trout). Learn2mix updates the class-sampling proportions, $\boldsymbol{\alpha}_t$, each epoch toward the normalized class-wise loss vector, so intrinsically harder-to-distinguish classes are automatically sampled more often during training. Empirically, this yields faster convergence: on CIFAR-100, the learn2mix-trained MobileNet-V3 reaches 50% test accuracy 20 seconds sooner than any fixed-proportion baseline, and on Imagenette (a compute-friendly subset of ImageNet), learn2mix lifts the worst-class accuracy curve throughout training (see Section B of the Appendix).
What happens to during the training process? That is, you have an intuition that it orders the difficulty between classes or something along those lines, but...is that what actually happens? I'm very surprised that there were no experiments considering this.
We note that the asymptotic behavior of $\boldsymbol{\alpha}_t$ is motivated by more than just intuition. Per Section 2 of the main text, learn2mix converges to a stable distribution that reflects the relative difficulty of each class under the optimal parameters (under standard strong convexity assumptions [1], [2]). Regarding experiments, on our Mean Estimation benchmark (where the Normal, Exponential, and Chi-squared cases have similar variance but the Uniform case is substantially more variable), $\boldsymbol{\alpha}_t$ converges to the values given in Table 1. These proportions confirm that learn2mix prioritizes the hardest class without overstating differences among the easier ones. We will include this result in Appendix Section B of the revised manuscript to illustrate how $\boldsymbol{\alpha}_t$ empirically orders class difficulty.
Table 1: Mixing parameter evolution over time on Mean Estimation benchmark
| Epoch | Class 1 | Class 2 | Class 3 | Class 4 |
|---|---|---|---|---|
|  | 0.33 | 0.33 | 0.264 | 0.076 |
|  | 0.1997 | 0.1997 | 0.1598 | 0.4407 |
|  | 0.1243 | 0.1226 | 0.1001 | 0.653 |
|  | 0.0865 | 0.0794 | 0.0717 | 0.7624 |
|  | 0.0630 | 0.0533 | 0.0546 | 0.8291 |
|  | 0.0479 | 0.0372 | 0.0419 | 0.8729 |
|  | 0.0382 | 0.0272 | 0.0337 | 0.9008 |
|  | 0.0325 | 0.0210 | 0.0297 | 0.9168 |
|  | 0.0299 | 0.0174 | 0.0273 | 0.9253 |
|  | 0.0280 | 0.0151 | 0.0240 | 0.9330 |
|  | 0.0270 | 0.0134 | 0.0227 | 0.9369 |
|  | 0.0258 | 0.0121 | 0.0236 | 0.9385 |
|  | 0.0261 | 0.0123 | 0.0233 | 0.9384 |
How do the per-class losses end up, compared to the standard case? The aggregate loss is lower, but e.g. if you have a relatively rare class and a biggish mixing rate, then I wonder how often getting lucky early on could lead to poor performance on that class in particular.
In Section B of the Appendix, we present the worst-class classification accuracy on Imagenette as an additional metric to gauge the efficacy of learn2mix for imbalanced classification settings. The class with the worst per-class loss is recorded at each elapsed training time step, and the learn2mix-trained ResNet-18 model exhibits improved learning dynamics over existing training approaches. Regarding the mixing-rate hyperparameter, the value recommended in Section 4 of the main text yields the best performance. Below, we provide additional empirical results for the worst per-class classification accuracy on Imagenette and IMDB, covering the recommended value as well as a "biggish" value (per the reviewer's comments), to demonstrate relative insensitivity to this choice; a short sketch of how this metric is computed follows Table 3. We will include these results in the revised manuscript.
Table 2: Worst-class test performance comparison for ResNet-18 on Imagenette
| Method | Worst-Class Accuracy (57.5 s) | Worst-Class Accuracy (115.0 s) | Worst-Class Accuracy (230.0 s) |
|---|---|---|---|
| L2M () | 0.0000 % | 2.3866 % | 14.1623 % |
| L2M () | 1.3149 % | 10.9318 % | 14.1451 % |
| L2M () | 0.2542 % | 3.2815 % | 15.6304 % |
| CL | 0.7779 % | 0.6205 % | 0.6325 % |
| FCL | 0.0000 % | 0.0000 % | 4.8162 % |
| SMOTE | 0.0000 % | 0.0000 % | 8.9119 % |
| IS | 0.0000 % | 0.0000 % | 0.7924 % |
| CURR | 0.0000 % | 0.2482 % | 2.8140 % |
Table 3: Worst-class test performance comparison for Transformer on IMDB
| Method | Worst-Class Accuracy (38.0 s) | Worst-Class Accuracy (75.0 s) | Worst-Class Accuracy (150.0 s) |
|---|---|---|---|
| L2M () | 46.2067 % | 77.4133 % | 79.2200 % |
| L2M () | 40.2867 % | 76.6267 % | 80.7400 % |
| L2M () | 60.3733 % | 74.4467 % | 81.8533 % |
| CL | 0.0000 % | 47.6400 % | 67.3933 % |
| FCL | 0.0000 % | 34.5333 % | 60.7200 % |
| SMOTE | 11.3200 % | 28.7867 % | 62.5400 % |
| IS | 24.8467 % | 57.7200 % | 69.6267 % |
| CURR | 0.0000 % | 48.4400 % | 66.0333 % |
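For clarity, the worst-class accuracy in Tables 2 and 3 is the minimum per-class test accuracy at a given elapsed-time checkpoint; a minimal sketch of this metric (array names are illustrative):

```python
import numpy as np

def worst_class_accuracy(y_true, y_pred, num_classes):
    """Return the lowest per-class accuracy; `y_true` and `y_pred` are
    integer label arrays of equal length."""
    per_class = []
    for c in range(num_classes):
        mask = (y_true == c)
        if mask.any():
            per_class.append((y_pred[mask] == c).mean())
    return min(per_class)

# Example: class 2 is never predicted correctly, so the metric is 0.
y_true = np.array([0, 0, 1, 1, 2, 2])
y_pred = np.array([0, 0, 1, 1, 0, 1])
print(worst_class_accuracy(y_true, y_pred, num_classes=3))  # 0.0
```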
Consider a situation where you can afford large batch sizes (in fact, you're running smallish models on a 3090, so you're basically in that situation). How would your sampling-based weighting scheme compare to an adaptive reweighting approach, where you keep a uniform (or dataset-level) mixture among classes but reweight each class's loss?
We agree that an adaptive loss-reweighting scheme is a natural baseline. In fact, our empirical results in Section 4 of the main text and Section B of the Appendix include the class-balanced focal loss (FCL) for this reason. Focal loss keeps the class mixture fixed while adaptively reweighting each class's contribution to the total loss, as the reviewer proposes. Because focal loss relies on class-level weighting factors that can differ substantially across classes, its stochastic-gradient estimates generally exhibit higher variance. By contrast, learn2mix samples more frequently from high-loss classes, a strategy known to reduce gradient variance in related work on importance sampling [3].
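To make the contrast concrete, below is a minimal sketch of the class-balanced weighting used by FCL-style losses versus learn2mix's sampling-based emphasis; the effective-number weighting and the parameter values (`beta`, `gamma`) are illustrative assumptions, not the exact settings from our experiments:

```python
import numpy as np

def class_balanced_focal_weights(class_counts, beta=0.999):
    """Effective-number class weights (illustrative; `beta` is a hypothetical choice).
    Rarer classes receive larger weights while the batch composition stays fixed."""
    effective_num = 1.0 - np.power(beta, class_counts)
    weights = (1.0 - beta) / effective_num
    return weights / weights.sum() * len(class_counts)

def focal_term(p_correct, gamma=2.0):
    """Focal modulation -(1 - p)^gamma * log(p) applied to the true-class probability."""
    return -((1.0 - p_correct) ** gamma) * np.log(p_correct)

# Per-sample loss = class weight of its label * focal term; the mixture of classes in
# each batch is unchanged, which is the key difference from learn2mix, where the
# sampling proportions themselves adapt to the class-wise losses.
counts = np.array([5000, 500, 50])
print(class_balanced_focal_weights(counts))
print(focal_term(np.array([0.9, 0.6, 0.2])))
```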
Thank you again for your feedback, which has been instrumental in improving our paper.
References:
[1] Guanghui Wang, Shiyin Lu, Quan Cheng, Wei-Wei Tu, and Lijun Zhang. SAdam: A variant of adam for strongly convex functions. In International Conference on Learning Representations, 2020.
[2] Arindam Banerjee, Pedro Cisneros-Velarde, Libin Zhu, and Misha Belkin. Restricted Strong Convexity of Deep Learning Models with Smooth Activations. In International Conference on Learning Representations, 2023.
[3] Angelos Katharopoulos and François Fleuret. Not All Samples Are Created Equal: Deep Learning with Importance Sampling. In International Conference on Machine Learning, 2018.
Thanks for your responses, the new results, and for pointing me to things I had missed that were already in your submission.
If space allows, it would be good to add a little bit about the relationship to FCL, SMOTE, etc. in the main body – currently there are empirical comparisons to these methods, but no discussion of what they actually are, so it was very easy to skim past them and not think about how they relate to your method.
I will keep my positive evaluation of the paper.
Dear Reviewer Ppx6,
Thank you again for your comments. In the original manuscript, due to the strict page limit, we referred readers to "Sections D.3, D.4, D.5, and D.6 of the Appendix" (line 222) for detailed explanations of FCL, SMOTE, and related methods. These sections provide comprehensive descriptions of each technique. We will include brief descriptions of these methods in Section 4 of the revised manuscript.
With the additional empirical results and the updates made in response to your comments, we hope that we have adequately addressed your concerns. If you feel that these revisions have strengthened the manuscript, we would be sincerely grateful if you would consider revisiting your score. We are grateful for your thorough review, which has significantly contributed to enhancing our paper.
Dear Reviewers,
This is a gentle reminder that the discussion period will be concluded soon.
Please take the time to review the authors' responses (if provided) to your feedback and any questions you raised. Your careful consideration of their rebuttals is crucial for ensuring a fair and comprehensive evaluation of the submissions.
Following this, we kindly ask you to actively participate in the discussion to share your updated perspectives or align with fellow reviewers' insights as needed.
Finally, please ensure that you complete the review confirmation process by the specified deadline to finalize your assessment. This step is essential for moving forward with the decision-making process.
Thank you for your dedication and timely contributions to the NeurIPS review process.
Best regards,
The AC
This paper proposes an adaptive training strategy that dynamically adjusts class proportions in batches based on real-time class-wise error rates to accelerate convergence, with applications across classification, regression, and reconstruction tasks. Key innovations include its dynamic batch composition (prioritizing high-error classes), architecture-agnostic design (validated on CNNs, Transformers, autoencoders), and integration of computational overhead into wall-clock metrics, ensuring practical efficiency.
Reviewers raised concerns about the theoretical assumptions (strong convexity does not hold for neural networks), reliance on dated baselines, scalability to large datasets and modern architectures, and hyperparameter sensitivity. The authors addressed these effectively: they clarified that the theoretical assumptions are standard ones adopted to obtain convergence insights while emphasizing empirical validity in non-convex settings, justified the baselines via recent citations (2024 NeurIPS/ICML papers using similar benchmarks), and confirmed that no scaling bottlenecks arise and that the method is architecture-agnostic.
Initial ratings were mostly "borderline accept" with one "borderline reject," but post-rebuttal the consensus shifted positively. Reviewers acknowledged that their concerns were resolved, with reviewer 43t3 maintaining "accept" and others recognizing the practical value despite residual critiques.