PiKE: Adaptive Data Mixing for Large-Scale Multi-Task Learning Under Low Gradient Conflicts
Abstract
Reviews and Discussion
This work studies the problem of creating a batch composed of multiple tasks for the training of large language models (LLMs). To this end, the authors adopt a multitask learning (MTL) perspective and pose the problem as a dynamic gradient weighting problem (studied in MTL in the context of negative transfer). Then, the authors verify existing claims on the lack of gradient conflict on larger models than before, and study the effect of weighted batch mixing under SGD and bounded gradient cosine similarities. This idealized scenario is then relaxed to more realistic settings where we may use other optimizers and expect a balanced decrease in task losses. The authors then verify the effectiveness of the proposed approach on a number of datasets and GPT-2 models, comparing with previous baselines.
Strengths and Weaknesses
Strengths
- The paper is generally well-written, making all their points clear and easy to read.
- I like the idea of posing the generation of batches in the context of LLMs as an MTL problem, and its connection with gradient conflicts.
- The empirical results are convincing.
- The validation on the low-conflict of large models is very welcome.
- The proposed algorithm is well-motivated, although I would've liked a bit more detail in the derivations.
- The problem is clearly relevant to the community.
Weaknesses
- I would have liked to see the performance of the idealized version (or the one which updates every X iterations) to see its drawbacks.
- While efficient, the algorithm still needs to compute and store the gradients for each task. Moreover, I could see a rich-get-richer situation arising, where tasks with low representation in the batch get worse estimates of their variance due to the smaller number of samples.
- The workflow for using the proposed method is not completely clear: Should we pretrain the weights? What value of the hyperparameter should we pick?
Questions
- How do you derive the mirror descent step in section 3.2?
- Is there any assumption regarding the capacity of the neural network for Remark 1 to hold?
- What are the empirical values for the bound of the cosine similarity that you find in practice? Based on l165, it seems that gradients are almost orthogonal for LLMs.
- How is the bound in Theorem 3.4 tight (L164)? Isn't this bound some form of a truncated Taylor expansion?
- Do you have any intuition on how to select the hyperparameter?
Limitations
yes
Final Justification
I can confirm my support on accepting this submission. This is good and solid work. After the rebuttal, my biggest concerns were resolved and the few remaining are rather unimportant or context-dependent, let me explain:
1. Lack of empirical results for the idealized algorithm (weight 25%): I was somewhat concerned that the authors derived an algorithm that was never tested directly, only its approximations. The authors added experiments with a version of the algorithm that is closer to the ideal one than before (still far, as the ideal one is naturally expensive), which I think resolves this issue.
2. Batch-dependent variance estimators (weight 50%): My biggest concern was that tasks with low representation in the batch could get worse statistical estimates that lower their representation even further. After the rebuttal, the authors clarified that, at fixed intervals, the variance of each task is estimated individually with an equal number of samples, resolving this concern.
3. Assumptions on large models and over-parametrization (weight 15%): The premise of the paper heavily relies on having large models that can learn every task perfectly and with almost no interaction between tasks. While this is said upfront, some results assumed it implicitly, as acknowledged by the authors (e.g., Remark 1 and, I assume, also Theorem 3.4). While this is resolved, I'd like it to be even more explicit in the final version.
4. Intuition about the hyperparameter (weight 10%): The algorithm provides a hyperparameter to balance the learning across tasks. While the authors recommend a value of 3 or 5, it is still a bit too ad hoc to my taste, and the lack of interpretability of its values makes it a bit hard to tune without some model selection. This is, however, a minor point.
Formatting Issues
- L61: missing "continuous" after L-Lipschitz.
Thank you for your time and for the thoughtful and constructive feedback. We are encouraged that you found the paper well-written, the problem important, and the empirical results convincing. We especially appreciate your recognition of our contribution of framing data mixing as a multi-task learning problem, and of our validation of the low-gradient-conflict hypothesis in large models. Below, we respond to your specific questions and suggestions in detail.
I would have liked to see the performance of the idealized version (or the one which updates every X iterations) to see its drawbacks.
This is a great question. During the rebuttal period, we ran additional experiments on the GLaM dataset to study the effect of the PiKE update interval. As shown in Table 1 below, more frequent updates tend to improve performance. However, this benefit comes at an additional computational cost: while our original setting incurs a negligible training overhead, shorter update intervals increase this overhead. We think that a reasonable update interval strikes an effective trade-off between model performance and computational efficiency. If the paper is accepted, we will include this additional table in the final version.
Table 1. Pre-training 110M GPT-2 style LMs with context length 2048 for 20K steps using GLaM datasets.
| Method | Mean Accuracy | ArcE | CSQA | HellaSwag | PIQA |
|---|---|---|---|---|---|
| Mix | 37.69 | 38.35 | 25.80 | 28.50 | 58.11 |
| Round_Robin | 36.47 | 37.17 | 22.77 | 27.57 | 58.38 |
| Random | 36.07 | 37.55 | 21.79 | 27.23 | 57.73 |
| GLaM | 37.10 | 37.89 | 24.90 | 28.17 | 57.45 |
| DoReMi | 38.17 | 39.49 | 26.21 | 28.50 | 58.49 |
| PiKE (Update Interval - 100) | 38.32 | 40.85 | 27.60 | 28.47 | 56.37 |
| PiKE (Update Interval - 500) | 38.67 | 40.68 | 26.45 | 28.80 | 58.76 |
| PiKE (Update Interval - 1000) | 38.23 | 39.37 | 27.35 | 28.30 | 57.89 |
While efficient, the algorithm still needs to compute and store the gradients for each task. Moreover, I could see a rich-get-richer situation arising, where tasks with low representation in the batch get worse estimates of their variance due to the smaller number of samples.
We agree that computational overhead is an important consideration. PiKE is designed to be computationally efficient. In particular, PiKE does not require storing full per-task gradients. Instead, for each task, we only store two scalar values: the gradient norm and its variance. We also reported the overhead from estimating these statistics during pre-training: in our experiments, this step adds only a negligible amount to the total training time.
Your question about the impact on tasks with low representation is very important. To answer it, we would like to clarify one detail about PiKE: the variance estimation is done only once every fixed number of steps, and we perform it over the tasks one at a time. In particular, for each task, we select a relatively large batch (we used 256 samples for all experiments) and use it to estimate the gradient norm and variance. So even if a task has low representation during regular training, it still gets a large batch for variance estimation at each estimation step. Moreover, all tasks use the same batch size for the variance estimation step. This is an important detail, and we will clarify it further in the revised version of the paper.
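To make this procedure concrete, here is a minimal sketch in NumPy-style Python (our own illustration; helper names such as `per_example_grads` and `sampler.sample` are hypothetical) of estimating the per-task statistics with the same fixed batch size for every task, independent of its current sampling weight:

```python
def estimate_task_stats(tasks, per_example_grads, est_batch_size=256):
    """Estimate the gradient norm and variance for every task.

    `tasks` maps a task name to a data sampler, and `per_example_grads` is a
    hypothetical helper returning an (est_batch_size, num_params) NumPy array
    of per-example gradients at the current model parameters. Every task uses
    the same estimation batch size, so low-weight tasks do not receive noisier
    statistics than high-weight tasks.
    """
    stats = {}
    for name, sampler in tasks.items():
        batch = sampler.sample(est_batch_size)          # equal batch size for all tasks
        grads = per_example_grads(batch)                # shape: (B, d)
        mean_grad = grads.mean(axis=0)
        grad_norm_sq = float(mean_grad @ mean_grad)     # estimate of ||E[grad]||^2
        variance = float(((grads - mean_grad) ** 2).sum(axis=1).mean())  # total gradient variance estimate
        stats[name] = {"grad_norm_sq": grad_norm_sq, "variance": variance}
    return stats
```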
The workflow for using the proposed method is not completely clear: Should we pretrain the weights? What value of the hyperparameter should we pick?
Our method doesn't require pretraining the weights. The batch mixture is computed dynamically on the fly during training. For the hyperparameter, we recommend setting it to 3 or 5. Our experiments consistently show that these values deliver strong and balanced performance.
How do you derive the mirror descent step in section 3.2?
Here is a more detailed description of how we derive the mirror descent step. Starting from equation (9), we solve the following constrained optimization problem:
where and .
This can be written more explicitly as:
To apply mirror descent over the simplex, we need to use a multiplicative update followed by normalization. Taking the gradient with respect to $w_k$, we have
$$
g_{w_k} = -\eta\left(\beta - L \eta \gamma w_k\right)\left\| \nabla \mathcal{L}_k(\theta) \right\|^2 + \frac{L \eta^2}{2b} \sigma_k^2
$$
Let $\alpha$ be our mirror descent learning rate. The update rule becomes
$$
w_k \leftarrow w_k \exp \left(\alpha \eta\left(\beta - L \eta \gamma w_k\right)\left\|\nabla \mathcal{L}_k(\theta)\right\|^2 - \frac{\alpha L \eta^2}{2 b} \sigma_k^2 \right)
$$
followed by the normalization $w_k \leftarrow w_k / \sum_{j} w_j$, which projects the weights back onto the simplex.
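To illustrate the two steps above (multiplicative update, then renormalization onto the simplex), here is a short NumPy sketch; it is our paraphrase of the update, with all per-task statistics assumed to be precomputed:

```python
import numpy as np

def mirror_descent_step(w, grad_norm_sq, sigma_sq, eta, beta, gamma, L, b, alpha):
    """One multiplicative (mirror-descent) update of the mixture weights.

    w            : current mixture weights over K tasks, summing to 1
    grad_norm_sq : per-task squared gradient norms ||grad L_k(theta)||^2
    sigma_sq     : per-task gradient variances sigma_k^2
    eta, L, b    : model learning rate, smoothness constant, batch size
    beta, gamma  : constants appearing in the descent bound
    alpha        : mirror-descent step size
    """
    exponent = (alpha * eta * (beta - L * eta * gamma * w) * grad_norm_sq
                - alpha * L * eta**2 / (2 * b) * sigma_sq)
    w_new = w * np.exp(exponent)
    return w_new / w_new.sum()   # renormalize onto the probability simplex

# Example with made-up numbers: tasks with larger gradient norm and lower
# variance receive a larger share of the next batch.
w = np.array([0.5, 0.3, 0.2])
w = mirror_descent_step(w,
                        grad_norm_sq=np.array([1.0, 0.5, 0.2]),
                        sigma_sq=np.array([0.1, 0.3, 0.05]),
                        eta=1e-3, beta=1.0, gamma=1.0, L=10.0, b=512, alpha=1.0)
print(w)
```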
Is there any assumption regarding the capacity of the neural network for Remark 1 to hold?
You are correct that Remark 1 implicitly relies on an assumption about the model’s capacity. More precisely, we assume the model is in the over-parameterized regime, with enough capacity to learn all tasks jointly without representational bottlenecks. In particular, we assume that small gradients imply small loss values. This matches our experimental setting, which uses large-scale modern networks that are, by design, highly over-parameterized.
We will revise Remark 1 to make this assumption explicit. Thank you for highlighting the need for clarity here.
What are the empirical values for the bound of the cosine similarity that you find in practice? Based on l165, it seems that gradients are almost orthogonal for LLMs.
Correct! The per-task gradients are indeed nearly orthogonal in practice, which is one of the core assumptions behind our method's design. As illustrated in Figures 5 and 6, we measured the cosine similarity across different task gradients for both our multilingual and GLaM pre-training experiments. We found that most values are consistently low, with an empirical upper bound of 0.05.
How is the bound in Theorem 3.4 tight (L164)? Isn't this bound some form of a truncated Taylor expansion?
To demonstrate that this bound is tight (in the minimax sense), we must show that there exists a case where the inequality is met with equality. As we state on L163-164, we constructed such an example in Theorem H.4 (Appendix H). In that section, we define a specific loss function for which the inequality in our bound becomes an exact equality. Therefore, by definition, the bound is tight for this example.
Do you have any intuition on how to select the hyperparameter?
Intuitively, the hyperparameter allows users to navigate the trade-off between optimizing for average performance versus worst-case performance (i.e., fair learning across tasks). We recommend setting it to 3 or 5; our experiments consistently show that these values deliver strong and balanced performance.
I thank the authors for their helpful responses. The few concerns I had have been mostly solved after the rebuttal (especially the one regarding variance estimation per task) and I can confirm my choice on recommending this submission for acceptance!
We thank you for your prompt review and for your favorable recommendation of our manuscript. Should you have any further questions or concerns, we would be pleased to address them.
This paper proposes PiKE, an adaptive data mixing algorithm. PiKE draws inspiration from MTL, which focuses on mitigating gradient conflicts. However, the paper finds that in pre-training, gradient conflicts are rare, so using this finding, they derive a theoretical bound on the decrease in training loss as a function of mixture weights and gradient properties. By maximizing this decrease, they propose PiKE, which updates weights using mirror descent as a function of gradient norm and gradient variance. An additional balanced PiKE is proposed to encourage good performance on all tasks. Pretraining experiments are conducted on multilingual C4 and GLaM.
Strengths and Weaknesses
Strengths
Quality:
- The series of building blocks establishes why the authors arrived at this particular setup/method, providing strong motivating evidence.
- Theoretical derivation aside, this idea of having gradient variance in the mixing algorithm makes intuitive sense to me. One unsatisfying thing about existing algorithms is that they do not encode the possibility that existing domains are poorly defined/noisy. To me, this variance term factors in this upstream noise into the mixing algorithm, which seems like the right direction.
Significance:
The work proposes a theoretically motivated data mixing algorithm that is also computationally efficient. It differs from several other data mixing works that simply propose an algorithm and do not describe its assumptions. I hope this work can continue to push for more principled data mixing algorithms.
Clarity:
Theoretical results are presented at a good level of abstraction and are easy to follow.
Weaknesses
Originality:
- It would be interesting to compare the theoretical derivation in this paper to the derivation in https://arxiv.org/abs/2310.15393. That paper uses first-order approximations to arrive at an objective that is linear in the mixture weights $w_k$; except that it uses a gradient dot product rather than a gradient norm, and also doesn't have the variance term. For clarity, it would also help to explain why there is no gradient dot product term in PiKE, since its nod to MTL is based on gradient conflict/similarity.
Quality:
- While example 2.1 shows how an adaptive mixing procedure is optimal, the loss/data assumptions are not realistic. It would be good to have a real pretraining experiment that motivates adaptive mixing. As an example, in https://arxiv.org/abs/2411.05735 Table 10, they do a two-stage brute-force search over data mixes and show that the best two-stage approach does not result in the same mixture weights at each stage, for certain domains.
- One property of the mirror descent weight update is that in the limit, the updates move less and less (because the weight update accumulates). For instance, in your figure 8 you can see that the mixture weights converge fairly quickly. As a result, $w$ is not able to "move dramatically" later on in pretraining, even if that may be optimal.
- In Table 1, it is surprising to me that regular PiKE performs poorly in terms of validation PPL. My understanding is that PiKE aims to maximize the training loss reduction at each time-step, so it would be good to have some verification of regular PiKE's performance (i.e., comparing training loss curves for PiKE + baselines).
- In Table 2, it is surprising to me why PiKE (regular and balanced) have high perplexity but do well on downstream tasks. I understand that perplexity and downstream are not always correlated; however, PiKE and most of these other data mixing methods are designed to minimize some function of perplexity/loss - so why aren't they able to minimize these properly in practice?
Clarity:
- No related work in the body; more comparison to other data mixing methods would be helpful
- To demonstrate how balanced-PiKE promotes fair learning on many tasks/domains, it would be helpful to also display the range/variance across tasks/domains in Table 1. Only after writing out the ranges of PPL in Table 1 by hand did I convince myself of what balanced-PiKE is doing.
Questions
- Can you compare derivation/results with DoGE (https://arxiv.org/abs/2310.15393)?
- More realistic pre-training experiment to motivate when adaptive data mixing is optimal.
- More validation of regular PiKE, such as arguing that the training loss curve goes down more quickly than other mixing approaches.
Limitations
Yes
Formatting Issues
N/A
Thank you for your detailed and insightful feedback. We’re glad you saw our work as a meaningful step toward more principled data mixing algorithms. We appreciate that you recognized its strong theoretical grounding, efficiency, and clarity. We also agree that many existing methods lack a clear or principled foundation. In contrast, we designed PiKE with well-stated assumptions and a focus on principled development. This approach has helped us better understand both the strengths and limitations of PiKE. We are very happy that you noticed and appreciated our efforts in this direction.
We address your questions below.
Weaknesses: It would be interesting to compare the theoretical derivation in this paper to the derivation in DoGE. That paper uses first order approximations to arrive at an objective that is linear in w_k except that it uses a gradient dot product rather than gradient norm, and also doesn't have the variance term. For clarity, it would also help to explain why there is no gradient dot product term in PiKE, since its nod to MTL is based on gradient conflict/similarity.
Question: Can you compare derivation/results with DoGE?
Thanks for bringing up this related work. We will discuss it in our revision. As you correctly pointed out, PiKE relies on both the gradient norm and the variance of the gradients, while DoGE relies on the inner product between task gradients. There are two important points to highlight here:
- In large models, the gradients are almost orthogonal (under the assumptions of low conflict and low alignment, as motivated in the paper based on our experiments). This means the inner product between task gradients is either negligible or equal to the norm of the per-task gradients. Specifically, the inner product between gradients of two different tasks is close to zero, while the inner product of a task's gradient with itself equals the squared norm of that gradient. Therefore, from the expected value viewpoint, DoGE and PiKE have a similar flavor (since DoGE relies on the inner product, which is related to the squared gradient norm under our assumptions). However, PiKE uses the near orthogonality to simplify calculations and reduce memory requirements. In particular, PiKE does not need to maintain or store the overall gradient, whereas DoGE does, which adds memory overhead. PiKE only needs to store two scalar values per task: the gradient norm and variance.
- As you noted, PiKE also relies on gradient variance. We believe gradient variance is an important signal. Ignoring this signal reduces performance, as shown in Table 11 in the Appendix. The importance of this signal is also intuitive: if two tasks have the same expected gradient but one has much higher variance due to noisier data, it makes sense to focus more on the task with less variance. This point has been mostly overlooked in recent data mixing literature, though classical statistics has long considered it; for example, noise is accounted for in data weighting by the Feasible Generalized Least Squares (FGLS) method.
In summary, PiKE uses gradient variance in addition to the mean, unlike DoGE. Also, leveraging near orthogonality simplifies derivations and reduces memory overhead. We will clarify these points and the connection to DoGE in our revision.
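As a small, self-contained numerical illustration of this near-orthogonality point (our own example, not taken from the paper or from DoGE): for independent high-dimensional gradient-like vectors, the cross-task inner product is negligible relative to the per-task squared norms, which is why an inner-product-based signal and a squared-norm-based signal carry similar information in this regime.

```python
import numpy as np

# Illustration (ours): independent high-dimensional vectors are nearly
# orthogonal, so cross-task inner products are negligible compared with
# per-task squared norms.
rng = np.random.default_rng(0)
d = 1_000_000                               # stand-in for the number of model parameters
g1, g2 = rng.standard_normal(d), rng.standard_normal(d)

cos = g1 @ g2 / (np.linalg.norm(g1) * np.linalg.norm(g2))
print(f"cross-task cosine similarity ~ {cos:.4f}")                    # close to 0
print(f"||g1||^2 = {g1 @ g1:.0f}, |g1 . g2| = {abs(g1 @ g2):.0f}")    # norm term dominates
```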
While example 2.1 shows how an adaptive mixing procedure is optimal, the loss/data assumptions are not realistic. It would be good to have a real pretraining experiment that motivates adaptive mixing. As an example, in Aioli Table 10, they do a two-stage brute-force search over data mixes and show that the best two-stage approach does not result in the same mixture weights at each stage, for certain domains.
More realistic pre-training experiment to motivate when adaptive data mixing is optimal.
Thank you for your suggestion. To provide a more realistic motivation for adaptive mixing, we conducted a brute-force search over various static (fixed-weight) data mixtures of English and Hindi. We then compared these baselines against our adaptive method, PiKE. As shown in Table 1 below, PiKE consistently outperforms all static mixture configurations we tested. This result is empirically in line with our intuition, illustrating the necessity of an adaptive strategy for achieving the best performance in pre-training.
Table 1. Pre-training 270M GPT-2 style LMs with context length 2048 for 20K steps using C4 (English and Hindi) datasets.
| Method | English Weight | Hindi Weight | Mean Accuracy |
|---|---|---|---|
| Mix | 0.2 | 0.8 | 27.43 |
| Mix | 0.4 | 0.6 | 27.38 |
| Mix | 0.5 | 0.5 | 27.65 |
| Mix | 0.6 | 0.4 | 27.90 |
| Mix | 0.8 | 0.2 | 28.67 |
| PiKE | Adaptive | Adaptive | 29.27 |
In our revision, we will include this experiment in the paper, as it addresses the important question of the relevance of adaptive mixing in complex LLM pre-training. We will also keep our current simple regression example because it is easy to replicate and can be verified very quickly on any computer.
One property of the mirror descent weight update is that in the limit, the updates move less and less (because the weight update accumulates). For instance, in your figure 8 you can see that the mixture weights converge fairly quickly. As a result, w is not able to "move dramatically" later on in pretraining, even if that may be optimal.
Could you please clarify what you mean by “accumulation”? In the update rule, we weight the step by both the gradient and its variance: if that combined term grows large, it can produce abrupt changes to the previous update (since our update is multiplicative).
In Table 1, it is surprising to me that regular PiKE performs poorly in terms of validation PPL. My understanding is that PiKE aims to maximize the training loss reduction at each time-step, so it would be good to have some verification of regular PiKE's performance (i.e., comparing training loss curves for PiKE + baselines).
In Table 2, it is surprising to me why PiKE (regular and balanced) have high perplexity but do well on downstream tasks. I understand that perplexity and downstream are not always correlated; however, PiKE and most of these other data mixing methods are designed to minimize some function of perplexity/loss - so why aren't they able to minimize these properly in practice?
More validation of regular PiKE, such as arguing that the training loss curve goes down more quickly than other mixing approaches.
Thank you for raising this important point. The confusion arises from a reporting error in our manuscript. We mistakenly reported the arithmetic mean of perplexities instead of the correct aggregate metric, which is the average loss. When using average loss (which is related to the geometric mean of ppls), PiKE's performance is highly competitive. For example, in the GLaM 740M pre-training results (Table 2), PiKE’s average loss of 2.55 is very close to the 2.50 achieved by the strong Mix baseline. We will correct this issue if our paper is accepted.
We also acknowledge that perfect alignment between our method's objective and the final validation loss is not always guaranteed. This may be due to several factors: first, our theory assumes an SGD optimizer, while we use AdamW in experiments, which is standard for LLM pre-training. Second, PiKE optimizes the training loss directly, whereas reported metrics are on the validation set, which may have a distribution shift. Third, our hyperparameters might not be perfectly tuned to maximize the descent in the objective during training (our theory is for perfectly tuned hyper-parameters).
That said, the loss values remain comparable, as noted above. Moreover, as discussed in earlier works (such as the ADO paper), simple mixing strategies can be very competitive in final perplexity but may not perform as well in model generalization. Understanding this gap requires further study and likely several follow-up works. We will highlight this finding in the paper and invite the community to explore the reasons behind this phenomenon.
No related work in the body; more comparison to other data mixing methods would be helpful
In the camera-ready version, we will move our related work section from the appendix into the main body and expand it to include a more detailed comparison with other data mixing methods, as requested. The additional page provided for the camera-ready version will allow us to make these important improvements. We will also include more detailed discussions on related works (such as connection to DoGE, which we discussed above).
To demonstrate how balanced-PiKE promotes fair learning on many tasks/domains, it would be helpful to also display the range/variance across tasks/domains in Table 1. Only after writing out the ranges of PPL in Table 1 by hand did I convince myself of what balanced-PiKE is doing.
This is a great suggestion! In the revised version, we will add new columns to Table 1 to include this information and to quantitatively demonstrate the reduced performance disparity.
Thank you again for your detailed and thoughtful feedback on our paper. We really appreciated your comments and suggestions. They helped us clarify several key aspects of the work and also motivated new experiments and improvements that we’re excited to incorporate into the revision.
We completely understand that this is a busy time, but just wanted to kindly follow up in case you’ve had a chance to read our response. If you feel that your concerns have been sufficiently addressed, we’d be grateful if you’d consider updating your score. Of course, we completely respect your judgment either way and are very thankful for the time and care you’ve put into reviewing our submission.
Thank you so much for your response. Following up on a few points:
- Mirror descent "accumulation": To clarify this, the update has the form $w_k \leftarrow w_k \exp(c_k)$, where $c_k$ is whatever is inside the exp. When you unroll this, you end up with $w_k^{(t)} = w_k^{(0)} \exp\left(\sum_{s<t} c_k^{(s)}\right)$. So when you go from step $t$ to step $t+1$, you are adding $c_k^{(t)}$ inside each exp. As $t$ grows, the term already in the exp is quite large, making the relative impact of adding $c_k^{(t)}$ smaller. This is reflected in e.g., DoGE figure 3d, where the weight trajectories start flattening out.
- Can you provide the revised numbers for Table 1 with average loss?
Thank you for the helpful clarification. While we agree that the cumulative terms appear in the exponent, we’d like to clarify that this alone doesn’t necessarily imply that the relative impact of new terms becomes small. Since the update is multiplicative (due to the exponential), even large existing exponents can still be significantly influenced by new additions.
To illustrate this with your notation: even when the accumulated term inside each task's exponent is already very large, a sufficiently large new term added to one task's exponent can still shift the normalized weights substantially. For example (using illustrative numbers of our own), if two tasks both have accumulated exponents of 100, the normalized weights are (0.5, 0.5); if the next update adds 5 only to the first task's exponent, the weights become approximately (0.99, 0.01) after normalization. This shows that even for a very large accumulated exponent, the weights can still shift dramatically based on the new update, due to the multiplicative nature of the mirror descent formulation.
That said, we typically see that weight trajectories often flatten over time. This is expected behavior, as the mirror descent update is optimizing the objective in the RHS of (8). Hence, once the gradients and variances stabilize (i.e., the RHS of equation (8) remains constant), the update naturally converges to the minimizer. But as long as the gradient and variance terms vary meaningfully, the update remains responsive and does not necessarily stagnate.
Can you provide the revised numbers for Table 1 with average loss?
We will update the paper to report average validation loss instead of average perplexity. The requested data is presented below, where the Balanced-PiKE variant's validation loss (2.0445) is comparable to the Mix baseline (2.0444).
Table 1. Average validation loss of 1B GPT-2 (en+hi+de) experiments.
| Method | Average validation loss |
|---|---|
| Mix | 2.0444 |
| Random | 2.0926 |
| Round Robin | 2.0677 |
| FAMO | 2.0418 |
| ADO | 2.0448 |
| PiKE | 2.0830 |
| Balanced-PiKE | 2.0445 |
As noted in our earlier response, Mix is known to be a strong baseline in terms of average loss, though it may not always lead to the best downstream performance. We'll highlight this in our revision, along with the potential reasons for theory-practice gaps mentioned in our previous response (e.g., optimizer mismatch, validation distribution shift, imperfect hyperparameter tuning). We believe these findings open up valuable directions for future investigation.
Finally, we would like to thank you again for your thoughtful feedback. If you feel your main concerns have been addressed, we would be very grateful if you would consider updating your score. We appreciate your time and judgment.
This paper proposes the PiKE framework, aiming to address multi-task training in large language models. The authors propose an adaptive data mixing algorithm that dynamically adjusts sampling weights during training, ensuring efficient and effective learning across diverse datasets. With rigorous theoretical analysis, the algorithm is shown to approach optimality. In experiments conducted on pretraining, including the mC4 and GLaM datasets, PiKE demonstrates superior performance in multilingual and multi-domain settings, achieving improved downstream performance compared to existing approaches.
Strengths and Weaknesses
Strengths:
- PiKE dynamically assigns sampling weights to each task based on the gradient norm and variance, leading to better applicability to downstream tasks.
- This paper includes formal convergence guarantees and a solid theoretical analysis.
- PiKE outperforms all other baselines in downstream tasks, highlighting its effectiveness in multilingual and multi-domain settings.
Weaknesses:
- The experiments primarily concentrate on the pre-training phase. It would be beneficial to explore how PiKE performs when directly trained on downstream datasets jointly, to assess its applicability in end-to-end training scenarios.
- The hyperparameters significantly influence the results, as evidenced in Table 11. For a thorough understanding of the outcomes, it is crucial to clearly provide the specific hyperparameter values utilized in the experiments.
- While PiKE shows promise in balancing learning across tasks, its impact on dataset imbalance, particularly for low-resource and high-resource tasks, remains unexplored.
Questions
- How does PiKE perform when directly trained on downstream datasets, rather than just during the pre-training phase?
- Could you provide the specific values of the hyperparameters used in the experiments?
- Can PiKE’s adaptive weighting scheme be extended to incorporate dataset size or task-specific downstream importance, beyond just gradient statistics?
Limitations
yes
Final Justification
Weakness 1 (weight 50%). Authors added a concise fine-tuning experiment showing PiKE matches or exceeds manual mixing without hyper-parameter search. This directly answers my primary objection.
Weakness 2 (weight 25%). Authors now provide exact values per setting. Please briefly report the total number of hyper-parameter combinations evaluated and the approximate GPU-hours required for the hyperparameter grid search. This will help readers gauge reproducibility and computational overhead.
Weakness 3 (weight 25%). Added empirical tracking and theoretical explanation demonstrate PiKE implicitly adapts to dataset size. Issue resolved.
Formatting Issues
no
Thank you for your time and for providing such detailed and constructive feedback. We are encouraged that you found our theoretical analysis solid, our experiments comprehensive, and the effectiveness of PiKE convincing. We address your specific questions below:
Weakness 1: The experiments primarily concentrate on the pre-training phase. It would be beneficial to explore how PiKE performs when directly trained on downstream datasets jointly, to assess its applicability in end-to-end training scenarios.
Question 1: How does PiKE perform when directly trained on downstream datasets, rather than just during the pre-training phase?
While PiKE was developed for pre-training data mixing, we agree that evaluating it in a direct multi-task fine-tuning scenario would help assess its broader applicability. Our experiments in the paper focused on the computationally intensive pre-training phase, but PiKE is designed as a general framework (assuming no gradient conflicts), not limited to any specific training stage.
To address your question, we conducted a new fine-tuning experiment during the rebuttal period:
- Base Model: We used a Llama-1B style architecture pre-trained on 70B tokens from the Fineweb-Edu dataset.
- Task: We fine-tuned the model jointly on a mixture of HellaSwag and PIQA for 2,000 steps with a fixed learning rate.
The results are shown in Table 1. While simple static data mixing can be effective during fine-tuning—especially since the lower computational cost allows manual tuning of mixing ratios—our experiment shows that PiKE achieves competitive performance without the need for manual tuning. This experiment demonstrates that PiKE is a viable and effective method for fine-tuning on downstream datasets.
Due to time constraints, we only conducted this experiment during the rebuttal phase. However, this result confirms PiKE's potential use in post-training and fine-tuning stages. It is worth noting that fine-tuning is typically easier and less computationally intensive than pre-training. As a result, reasonable mixing coefficients can sometimes be found by trial and error. Nonetheless, the need for a reliable, adaptive algorithm is more critical during pre-training, which is why our work primarily focuses on that phase.
Table 1. Downstream Performance of different methods for fine-tuning a Llama-1B base model on a mixture of the HellaSwag and PIQA datasets.
| Method | Mean Accuracy | HellaSwag | PIQA |
|---|---|---|---|
| Base Model (No Fine-tuning) | 63.46 | 55.70 | 71.21 |
| Mix (Uniform, Static) | 65.75 | 58.43 | 73.07 |
| PiKE (Adaptive) | 65.94 | 58.37 | 73.50 |
Weakness 2: The hyperparameters significantly influence the results, as evidenced in Table 11. For a thorough understanding of the outcomes, it is crucial to clearly provide the specific hyperparameter values utilized in the experiments.
Question 2: Could you provide the specific values of the hyperparameters used in the experiments?
Thank you for pointing this out. We reported our hyper-parameter grid in Table 7 in the Appendix. To address your question, we will explicitly mention the result of the hyper-parameter search in our grid for each experiment. In particular, we will include the table below into the experiment section of the revised manuscript:
Table 1. Specific hyperparameter values for PiKE across different experiment settings.
| Experiment Setting | First hyperparameter | Second hyperparameter |
|---|---|---|
| en+hi, 270M, PiKE | 0.1 | 0.01 |
| en+hi+de, 270M, PiKE | 0.075 | 0.005 |
| en+hi, 1B, PiKE | 0.05 | 0.005 |
| en+hi+de, 1B, PiKE | 0.1 | 0.01 |
| GLaM, 110M, PiKE (Uniform) | 0.1 | 0.01 |
| GLaM, 110M, PiKE (GLaM) | 0.075 | 0.005 |
| GLaM, 750M, PiKE (Uniform) | 0.15 | 0.01 |
| GLaM, 750M, PiKE (GLaM) | 0.075 | 0.005 |
Weakness 3: While PiKE shows promise in balancing learning across tasks, its impact on dataset imbalance, particularly for low-resource and high-resource tasks, remains unexplored.
Question 3: Can PiKE’s adaptive weighting scheme be extended to incorporate dataset size or task-specific downstream importance, beyond just gradient statistics?
As you mentioned, dataset size plays an important role in determining the optimal mixing strategy. While PiKE does not explicitly use dataset size as an input, it implicitly accounts for it through its gradient-based sampling mechanism.
The intuition is as follows:
- For a low-resource task (e.g., Wikipedia with 3B tokens in GLaM), the model quickly learns the underlying data distribution and begins to converge or overfit. As this happens, the gradient norm for that task decreases, indicating that there is less left to learn.
- Conversely, a high-resource task (e.g., Conv with 174B tokens in GLaM) offers a richer learning signal for a longer period, resulting in a sustained, larger gradient norm.
Since PiKE increases the sampling rate for tasks with larger gradient norms, it naturally reduces focus on low-resource tasks as they converge, and shifts focus to high-resource tasks where more can still be learned. In this way, PiKE implicitly adapts based on how much information remains in each dataset.
To demonstrate this empirically, we tracked the sampling rates PiKE assigned during GLaM pre-training. As shown in Table 2, PiKE initially assigns a reasonable weight to the Wiki dataset. Over time, this weight decreases, while the weight for the larger Conv dataset increases. This dynamic behavior shows that PiKE inherently adjusts for dataset imbalance. We will include this discussion in the paper if it is accepted.
Table 2. Sampling Rate of Wikipedia (3B tokens) and Conversation (174B tokens) during PiKE-based pre-training of a 740M Models with GLaM dataset.
| Iteration Steps (K) | 30 | 60 | 100 | 120 |
|---|---|---|---|---|
| Wikipedia sampling rate | 0.026 | 0.023 | 0.021 | 0.019 |
| Conversation sampling rate | 0.278 | 0.346 | 0.344 | 0.35 |
Regarding your question on incorporating downstream task importance: yes, PiKE's framework can be easily extended to account for task-specific priorities. While we derived fair-PiKE under a specific fairness-promoting utility function, the framework is modular and general. For example, one can easily introduce a set of user-defined importance weights, where each weight represents the downstream importance of its task. These weights can then scale the loss for each task, effectively guiding PiKE to allocate more resources to higher-priority tasks, even if their gradient signals are temporarily weak (before scaling).
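To make this concrete, a minimal sketch of one way such importance weights could enter the statistics that feed the mixture update is shown below; this is our illustration of the idea described above, not the paper's implementation:

```python
import numpy as np

def importance_weighted_stats(grad_norm_sq, sigma_sq, importance):
    """Hypothetical extension (our sketch): scale each task's statistics by a
    user-defined importance weight before the PiKE mixture update, so
    higher-priority tasks receive more weight even when their raw gradient
    signal is temporarily weak.

    Scaling task k's loss by an importance weight c_k scales its gradient by
    c_k, and hence its squared gradient norm and variance by c_k**2.
    """
    c = np.asarray(importance, dtype=float)
    return c**2 * np.asarray(grad_norm_sq), c**2 * np.asarray(sigma_sq)
```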
This illustrates the flexibility of our approach. We will add this important discussion to the paper if it is accepted.
Thank you again for your valuable time and feedback. We hope our response has resolved your concerns and welcome any further questions you may have.
Thank you for your prompt response and for taking the time to read our rebuttal. If you have any questions or lingering concerns, please don’t hesitate to reach out. We would be glad to address them.
Weakness 1 (weight 50%). Authors added a concise fine-tuning experiment showing PiKE matches or exceeds manual mixing without hyper-parameter search. This directly answers my primary objection.
Weakness 2 (weight 25%). Authors now provide exact values per setting. Please briefly report the total number of hyper-parameter combinations evaluated and the approximate GPU-hours required for the hyperparameter grid search. This will help readers gauge reproducibility and computational overhead.
Weakness 3 (weight 25%). Added empirical tracking and theoretical explanation demonstrate PiKE implicitly adapts to dataset size. Issue resolved.
We are glad that our fine-tuning experiments, along with our empirical tracking and theoretical explanation for dataset size, have addressed your main concerns. We are happy to provide the additional details you requested.
We have reported our hyperparameter grid in Table 7 of the Appendix, which contains six different hyperparameter combinations for PiKE. For the initial round of grid search, we estimate the total computation time was about 36 TPU hours for the 110M model. We agree with you that these numbers will help the reader understand the reproducibility and computational overhead and we will report them in the revision.
It is also worth mentioning that our reported grid search may be conservative, as PiKE can achieve strong performance without much tuning. As shown in our results, the same hyperparameter values work well across a variety of scales and settings. This demonstrates that the hyperparameter choices can often be successfully transferred from smaller- to larger-scale models.
Thank you again for your detailed and constructive feedback, as well as for being responsive to our messages and for your clear communication throughout.
The paper proposes PiKE, a data sampling approach for foundation model pre-training, addressing "how to effectively mix and sample data from multiple sources". The paper postulates that pre-training LLMs on a big language corpus usually exhibits low gradient conflicts and presents PiKE as a method that leverages this phenomenon. Empirical results on multiple datasets and varied model sizes confirm the viability of PiKE.
Strengths and Weaknesses
Strengths:
- The paper is extremely well written, with sound figures and colors used appropriately, improving readability and understanding. I especially appreciate the box detailing the key features of PiKE and relating them to the structure of the paper.
- Empirical results are solid, with clear improvements over baselines. The datasets used are plentiful.
- The paper has the potential to be high impact with many key insights used in practice.
Weakness:
- How do these insights translate into other problem domains, especially more structured tasks? Currently, the evaluations have been performed on simple QA style datasets and multilingual (en/hi/de). I wonder how the results would be on structured understanding, like tabular data, time series, code, etc. Does the gradient non-conflict hold there too?
- How do the hyperparameters affect convergence? How do they affect performance? I understand that due to high running costs, it is hard to perform a comprehensive empirical study, but an intuitive paragraph might be helpful.
Questions
See Weakness.
Limitations
Yes
Final Justification
Good work - I maintain my Accept rating
Formatting Issues
None.
Thank you for your detailed and constructive feedback. We are glad that you found our contributions solid and the paper impactful. We also appreciate your recognition of the care we put into presenting the algorithm and results. Please find our detailed responses to your questions below.
How do these insights translate into other problem domains, especially more structured tasks? Currently, the evaluations have been performed on simple QA style datasets and multilingual (en/hi/de). I wonder how the results would be on structured understanding, like tabular data, time series, code, etc. Does the gradient non-conflict hold there too?
Thank you for raising this important question about the applicability of our findings to structured problem domains. To explore this, we conducted an additional experiment using code, which is a highly structured data type.
Intuitively, as model size increases, gradients from different tasks are more likely to become orthogonal. This is because in high-dimensional spaces, two (noisy) vectors tend to be nearly orthogonal. To test this idea, we used the final GLaM 740M checkpoint and computed task gradients on the CodeSearchNet dataset. We then measured the cosine similarity between these code-derived gradients and those from the original six GLaM domains.
The results, presented in Table 1, show that the cosine similarities are consistently near zero. This outcome supports our hypothesis, suggesting that the gradient non-conflict property does extend to structured domains such as code.
Table 1: Cosine Similarity of Code vs. six different GLaM Task Gradients at the Final 740M LM Checkpoint.
| Datasets | glam_wiki | glam_web | glam_conv | glam_forums | glam_books | glam_news |
|---|---|---|---|---|---|---|
| Cosine Similarity (Between 1 and -1) | 0.00148 | 0.00097 | -0.01256 | 0.00409 | -0.00213 | -0.01023 |
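For reference, a measurement of this kind can be set up in a few lines; the sketch below is our own (it assumes each domain's mean gradient has already been computed and flattened into a single vector) and simply computes the pairwise cosine similarities reported in tables like the one above:

```python
import numpy as np

def cosine_similarity(g_a, g_b):
    """Cosine similarity between two flattened gradient vectors."""
    return float(g_a @ g_b) / (np.linalg.norm(g_a) * np.linalg.norm(g_b) + 1e-12)

def pairwise_task_grad_similarity(task_grads):
    """Given {task_name: flattened mean gradient at a checkpoint}, return the
    cosine similarity of every pair of tasks."""
    names = list(task_grads)
    return {(a, b): cosine_similarity(task_grads[a], task_grads[b])
            for i, a in enumerate(names) for b in names[i + 1:]}
```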
How do the hyperparameters affect convergence? How do they affect performance? I understand that due to high running costs, it is hard to perform a comprehensive empirical study, but an intuitive paragraph might be helpful.
Understanding the role of these hyperparameters is key to easily implementing our method in practice. Below we provide an intuitive explanation of the role of each hyperparameter.
K (Number of Domains): Our experiments show robust performance of PiKE across different numbers of domains (K = 2 to 6), consistently outperforming the baselines. Unlike some other existing methods with quadratic computational complexity in K, PiKE has linear computational complexity in K, which makes it favorable for large values of K. However, PiKE is based on the observation that per-task gradients are nearly orthogonal. This observation intuitively holds in the high-dimensional setting where the number of training parameters is much larger than the number of tasks K. Hence, we expect that PiKE offers no benefit when training small/medium-size models with a huge number of tasks.
(Hyperparameters of the PiKE algorithm): We assume you are referring to the two hyperparameters in the PiKE algorithm here. As derived from our theoretical analysis, PiKE's data sampling rate is designed to prioritize tasks with a high gradient norm (indicating difficulty) and low gradient variance (indicating training stability). The two hyperparameters control the model's sensitivity to these two terms, respectively.
The first hyperparameter scales the influence of the gradient norm, effectively prioritizing tasks that the model currently finds more difficult. We conducted a sensitivity analysis for this hyperparameter. Due to computational constraints, these experiments were performed over a limited number of training steps. The results in Table 2 show that our method is robust across different values.
Table 2. Pre-training a 110M model on GLaM for 20K steps with different values of the gradient-norm hyperparameter, keeping the other hyperparameter fixed.
| Hyperparameter value | 0.08 | 0.09 | 0.1 | 0.11 | 0.12 |
|---|---|---|---|---|---|
| Mean Accuracy (%) | 37.18 | 37.75 | 37.92 | 37.81 | 37.46 |
The second hyperparameter penalizes tasks with high gradient variance, encouraging the model to favor tasks with more stable and reliable gradients. When it is set to zero, the variance consideration is removed, and the sampling strategy simplifies to a greedy approach focused only on the most difficult tasks. As shown in Table 11 in the Appendix, this ablation causes a clear drop in mean accuracy. This performance degradation highlights the importance of considering gradient variance in achieving an optimal sampling strategy.
All in all, we recommend the above values for practical use based on our ablation. We will add these discussions on the role of the hyperparameters to the paper if the paper gets accepted.
We once again thank you for your time and valuable feedback. We hope that our responses and planned revisions have fully addressed your concerns. Please let us know if you have further questions.
I appreciate the detailed responses. I believe this paper should be accepted.
Thank you again for your thoughtful review and positive feedback on our paper, as well as for your encouraging comment in the rebuttal phase. We’re very glad to hear that you believe the paper should be accepted.
If you feel that the response and clarifications have addressed your concerns sufficiently, and that the paper may merit a higher score, we would be grateful if you’d consider adjusting it accordingly. We understand this is entirely at your discretion and truly appreciate your support either way.
The paper proposes a multi-task learning algorithm based on adaptive data mixing, motivated by the observation that gradient conflicts are rare in pretraining. The method connects gradient norm and variance to mixture weight updates via mirror descent, with a balanced variant to ensure fairness across tasks. It is supported by solid theoretical derivations and convergence guarantees, and experimental results demonstrate clear benefits in multilingual and multi-domain settings. Reviewers were mostly positive, and while there were concerns about applicability to fine-tuning and sensitivity to hyperparameters, they were sufficiently addressed through additional experiments during the rebuttal period. The applicability of this method to LLM training, together with the motivation that it leverages non-conflicting gradients, represents a meaningful advancement in multi-task learning, and thus the paper is a clear accept.