Training Dynamics of In-Context Learning in Linear Attention
We theoretically characterize how in-context learning abilities evolve during gradient descent training of linear attention, revealing abrupt acquisition or progressive improvements depending on how the key and query are parametrized.
Abstract
Reviews and Discussion
The authors consider the training dynamics of linear attention with a joint query-key matrix versus separate query and key matrices when learning to learn in-context linear regression. They show that in the former case there are two (classes of) fixed points, while in the latter case there are 2^D of them, where D is the dimension of the x-vector. Consequently, training in the former case shows a single abrupt loss drop, while in the latter case multiple smaller drops are present. The authors also show that in the latter case training experiences D drops, corresponding to learning the eigenstructure of the covariance matrix. Several experiments are provided, which mostly align with the theoretical analysis. The time courses of convergence are also analyzed for some simplified cases, and these also match the experiments.
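The count of 2^D fixed points and the D-step loss staircase can be illustrated with a toy calculation. This is a minimal sketch under assumptions that are ours, not the paper's: w ~ N(0, I), x ~ N(0, Λ) with eigenvalues λ_i, the infinite-context limit, and a fixed point indexed by a subset S of eigenvectors that regresses exactly on those directions. Under these assumptions the residual loss at S is the sum of the eigenvalues outside S. The function names are hypothetical, and the paper's finite-context expressions contain additional terms.

```python
from itertools import combinations


def fixed_point_losses(eigvals):
    """Toy infinite-context loss at each subset fixed point.

    Assumes w ~ N(0, I), x ~ N(0, Lambda) with eigenvalues `eigvals`, and that
    the fixed point indexed by subset S regresses exactly on the eigen-directions
    in S. The residual loss is then sum_{i not in S} lambda_i.
    """
    d = len(eigvals)
    total = sum(eigvals)
    losses = {}
    for k in range(d + 1):
        for subset in combinations(range(d), k):
            losses[subset] = total - sum(eigvals[i] for i in subset)
    return losses


def staircase(eigvals):
    """Loss levels visited if directions are learned one at a time, largest eigenvalue first."""
    order = sorted(range(len(eigvals)), key=lambda i: -eigvals[i])
    levels = [sum(eigvals)]
    for i in order:
        levels.append(levels[-1] - eigvals[i])
    return levels


eigvals = [4.0, 2.0, 1.0]
losses = fixed_point_losses(eigvals)
print(len(losses))          # one fixed point per subset: 2**3 of them
print(staircase(eigvals))   # D + 1 loss levels, i.e. D drops
```

In this toy, sequential learning visits only D + 1 of the 2^D fixed points, one plateau per level.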
Update after rebuttal
Thanks for the response. I like the results and increased the score to 4.
Questions for Authors
- Will non-zero assignments of v_i and u_i change the structure of the fixed points?
- Will using the cumulative loss as in Garg et al. change the training dynamics?
- Will using softmax change the structure of the training dynamics?
I would consider raising the score if the authors can answer these questions satisfactorily.
Claims and Evidence
The empirical studies are well designed and appear to support the theory well.
Methods and Evaluation Criteria
The authors train the linear attention models for the setting, and compare the observed behaviors with the theoretical analysis.
Theoretical Claims
Yes, the proofs mostly look correct.
Experimental Designs or Analyses
Yes, the experiments are described with sufficient details, and support the theory.
Supplementary Material
Appendices B and C.
Relation to Broader Scientific Literature
N/A.
Essential References Not Discussed
Sufficiently discussed.
Other Strengths and Weaknesses
While I like the results in general, there are several weaknesses inherent in the study and the formulation of the problem:
- One main concern is the assignment of zeros to v_i and u_i in the query and key matrices. Though this might be acceptable for analyzing the eventual optimal solution, it is less acceptable for analyzing the training dynamics, because these entries might affect the dynamics in unpredictable ways. Note that in practice, these matrices are usually initialized with random values rather than zeros.
- In Garg et al.'s work (and some others), the training loss is the average of the losses over the whole context window, \sum_{n=1}^{N}, instead of only the loss at the last position N+1. I feel the former more accurately reflects the concept of "in-context learning", while the latter seems unrealistic.
- The linear attention model allows a more precise analysis, but it is also overly simplified. The softmax function is important in many settings. Similarly, in practice, a feedforward network follows the attention mechanism, and multiple attention layers are usually stacked.
- While the difference in training dynamics is interesting, it is not clear whether this distinction matters in practice, because it seems unique to linear attention and will likely break down with a softmax function, a feedforward network, and multiple layers.
Other Comments or Suggestions
N/A
Thank you very much for your interest in our work and for raising clear, constructive questions. We're glad to hear that you like our results. Below, we present some new experimental results and respond to questions. The numbering follows the reviewer's "Questions for Authors" list.
- Nonzero v_i, u_i
Thank you for this insightful question. The validity of the fixed points is not affected by v_i and u_i. In other words, setting v_i = u_i = 0 and the rest of the weights to Equation (9) or (15) yields valid fixed points. This follows from our proof that (i) v_i and u_i have zero gradient when they are zero, and (ii) the rest of the weights have zero gradient when they satisfy Equation (9) or (15). Therefore, all weights in the full model have zero gradient.
As for dynamics, we agree that in practice v_i and u_i can be nonzero and evolve during training, affecting the dynamics in interesting ways. We discuss this issue in Appendix F, "Training Dynamics of In-Context and In-Weight Learning". In our setup, nonzero v_i, u_i can be interpreted as a form of in-weight learning (IWL). In our main setup, where the task vector is sampled i.i.d. from a fixed distribution, v_i and u_i remain near zero throughout training, and the model develops only ICL, as shown in Figure 9a. However, when a portion of the task vector is fixed, the evolution of v_i and u_i is non-negligible. In this case, the model first develops IWL, then loses IWL while simultaneously developing ICL, as shown in Figure 9b-e. This competition between ICL and IWL is indeed intriguing. While a full theoretical treatment is beyond our scope, we hope future studies can build on our analysis of the purely ICL dynamics to explore its interaction with IWL dynamics. If the reviewer finds it beneficial, we're happy to move part of Appendix F to the main text.
We also note that setting v_i = u_i = 0 is common in the theory literature on ICL in linear [1,2] and softmax [3,4] attention.
- Cumulative Loss
We have added Figure 14 (see URL) to demonstrate that using the cumulative loss does not change the training dynamics, apart from the loss values themselves. Consistent with the last-token loss, training ATTNM with cumulative loss exhibits a single abrupt loss drop, and training ATTNS exhibits multiple loss drops. In our in-context linear regression setup, training with cumulative loss is effectively equivalent to training with the last-token loss on sequences of varying lengths. Because the linear attention model includes a 1/N scaling factor, it can be trained and tested with varying sequence lengths N. For a detailed discussion of the scaling, please refer to our rebuttal to reviewer etiN.
It may be worth noting that computing loss for only the last token is also common in the literature, such as in Von Oswald et al. [5] and Zhang et al. [1].
- Softmax Attention
We have added Figure 10 (see URL) to show that the different training dynamics of linear ATTNM and linear ATTNS also occur in their softmax counterparts. Figure 10 follows the same setup as Figures 1–2 for linear attention, the only difference being the softmax activation function in the attention calculation. We observe that softmax ATTNM exhibits a single abrupt loss drop, whereas softmax ATTNS undergoes multiple loss drops separated by phases of conspicuously slower training. This demonstrates that our findings and theoretical intuition are not unique to linear attention but can also extend to softmax attention.
We've added the new figures to our appendix and will use the extra page in final revision to include as many as possible in the main text.
[1] Ruiqi Zhang, Spencer Frei, Peter Bartlett. Trained transformers learn linear models in-context. JMLR 2024.
[2] Kwangjun Ahn, Xiang Cheng, Hadi Daneshmand, Suvrit Sra. Transformers learn to implement preconditioned gradient descent for in-context learning. NeurIPS 2023.
[3] Juno Kim, Taiji Suzuki. Transformers Learn Nonlinear Features In Context: Nonconvex Mean-field Dynamics on the Attention Landscape. ICML 2024.
[4] Yu Huang, Yuan Cheng, Yingbin Liang. In-context Convergence of Transformers. ICML 2024.
[5] Johannes Von Oswald et al. Transformers learn in-context by gradient descent. ICML 2023.
- I understand that many theoretical works assume v_i = u_i = 0, but the problem considered here, the training dynamics, is more complex, and this issue becomes more important. I feel the authors gave a good explanation of the issue and have some empirical results in Appendix F. The theoretical aspect is not yet satisfactory, but that's understandable. I recommend moving some of the discussion forward.
- I also understand that many theoretical works assume last-token ICL, but it is still not quite consistent with how we usually train transformers. It is true that "training with cumulative loss is effectively equivalent to training with the last-token loss on sequences of varying lengths"; however, in the former, the parameters have to balance early token predictions (when not enough in-context demonstrations are available) against later token predictions (when enough are available). I'd like to hear the authors' findings and/or explanation of how these conflicting requirements play a role in the setting studied here, versus last-token-only training.
- A recent work by Nichani et al. analyzed the dynamics with the softmax function. My question is more about a theoretical analysis with softmax than about empirical observations. However, I accept the authors' explanation and find it sufficient.
I'm not sure ICML allows URLs, but the linked figures do not load. Please provide some more detailed verbal descriptions.
Thank you for your attentive engagement in the rebuttal process.
- We appreciate that you raised the point that the question we study, the training dynamics, is more complex than those in many theoretical works; we agree. We are glad to hear that you accept our explanation, and we will move part of Appendix F to the main text.
- You're correct that training with cumulative loss, or with last-token loss on sequences of varying lengths, raises the question of balancing early and later token predictions. Here we present an analytical treatment.
Based on our derivations in Appendix C.4, the converged model implements
To consider varying context lengths N, we treat N as a random variable with some distribution over positive integers. We then need to compute the expectation in the corresponding equation over the distribution not only of the inputs and task vector but also of N, which yields
Therefore, the context lengths influence the converged model only through the expectation E[1/N]. For a fixed context length, E[1/N] = 1/N, which recovers the corresponding equation in our paper. For cumulative loss, N follows a uniform distribution over {1, ..., N_max}, where N_max is the full sequence length. The expectation E[1/N] under this uniform distribution is the harmonic number H_{N_max} divided by N_max, which has no closed-form expression but can be easily computed for a specific finite N_max.
Similarly, for ATTNS, the fixed-point condition in the corresponding equation becomes
where the term E[1/N] was simply 1/N in the fixed-context-length case.
We'll add this new analysis to the paper. It is indeed an interesting fact, and worth mentioning, that varying context lengths influence our results only through the expectation E[1/N].
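The claim that context lengths enter only through E[1/N] can be illustrated with a standard calculation. This is a minimal sketch under assumptions that are ours, not necessarily those of the paper's Appendix C.4: noiseless targets y = w^T x, w ~ N(0, I), x ~ N(0, Λ), N drawn independently of the data, and a predictor x_q^T Γ h with h = (1/N) Σ_n y_n x_n. Minimizing the expected squared error then gives Γ = [(1 + E[1/N]) Λ + E[1/N] tr(Λ) I]^{-1}. For a fixed N, this has the same form as the solution reported by Zhang et al. [1]. The function names are hypothetical.

```python
import numpy as np


def expected_inv_length(lengths, probs=None):
    """E[1/N] for a discrete context-length distribution (uniform if probs is None)."""
    lengths = np.asarray(list(lengths), dtype=float)
    if probs is None:
        probs = np.full(lengths.shape, 1.0 / lengths.size)
    return float(np.sum(np.asarray(probs, dtype=float) / lengths))


def converged_gamma(cov, e_inv_n):
    """Minimizer of E[(y_q - x_q^T Gamma h)^2] with h = (1/N) sum_n y_n x_n.

    Assumes noiseless y = w^T x, w ~ N(0, I), x ~ N(0, cov), and N independent
    of the data, so that N enters only through e_inv_n = E[1/N].
    """
    d = cov.shape[0]
    return np.linalg.inv((1.0 + e_inv_n) * cov + e_inv_n * np.trace(cov) * np.eye(d))


n_max = 20
e_fixed = expected_inv_length([n_max])                # fixed length: 1/N
e_cum = expected_inv_length(range(1, n_max + 1))      # cumulative loss: H_N / N
harmonic = sum(1.0 / k for k in range(1, n_max + 1))
print(e_fixed, e_cum, harmonic / n_max)

cov = np.diag([3.0, 1.0])
print(converged_gamma(cov, e_fixed))
print(converged_gamma(cov, e_cum))   # larger E[1/N] -> smaller (more shrunk) Gamma
```

Because the cumulative loss weights short prefixes, E[1/N] is larger than 1/N_max, so the converged preconditioner shrinks more than in fixed-length training.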
The URL has been fixed. We're sorry that it was down for a while. ICML author guidelines indicate anonymous URLs are allowed, though reviewers aren't obliged to click.
This paper provides a theoretical analysis of how linear attention models acquire in-context learning abilities through gradient descent training on an in-context linear regression task. The authors study two parameterizations of multi-head linear attention. In the merged key-query setting (ATTNM), they show that the model has exactly two fixed points (a zero solution and a global minimum manifold), and that training from small initialization follows a single abrupt drop in loss before converging to a least-squares in-context regression solution. The authors also derive an analytical time course in the special case of an identity covariance matrix and infinitesimal initialization. By contrast, in the separate key-query setting (ATTNS), they show that the model admits exponentially many fixed points corresponding to different subsets of eigenvectors of the token covariance. In this scenario, the training dynamics involve multiple abrupt loss drops separated by plateaus, each plateau corresponding to learning one additional principal component in context. Although the final solution again approximates least-squares regression, early stopping yields principal component regression on a subset of the components.
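Principal component regression on a subset of components can be sketched as follows. This is a generic in-context PCR estimator: it assumes the population covariance is known and fits least squares on the context inputs projected onto the top-k eigenvectors of that covariance. It is not the paper's exact fixed-point predictor, which includes finite-context corrections. The function name is hypothetical.

```python
import numpy as np


def pcr_predict(x_ctx, y_ctx, x_q, cov, k):
    """In-context principal component regression on the top-k eigenvectors of cov (k >= 1)."""
    eigvals, eigvecs = np.linalg.eigh(cov)               # eigenvalues in ascending order
    top = eigvecs[:, np.argsort(eigvals)[::-1][:k]]      # d x k, largest eigenvalues first
    z = x_ctx @ top                                      # project context inputs: N x k
    coef, *_ = np.linalg.lstsq(z, y_ctx, rcond=None)     # least squares within the subspace
    return float((x_q @ top) @ coef)


rng = np.random.default_rng(0)
cov = np.diag([3.0, 2.0, 1.0])
x_ctx = rng.standard_normal((30, 3)) @ np.sqrt(cov)      # rows x_n ~ N(0, cov)
w = rng.standard_normal(3)
y_ctx = x_ctx @ w                                         # noiseless in-context targets
x_q = rng.standard_normal(3)
for k in (1, 2, 3):
    print(k, pcr_predict(x_ctx, y_ctx, x_q, cov, k), x_q @ w)
```

With k = D and noiseless data, the estimator reduces to ordinary least squares and recovers x_q^T w exactly. Smaller k corresponds to stopping before the remaining components are learned.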
Questions for Authors
- I want to ask why the authors chose to scale the model's output by 1/N. Wouldn't the training procedure naturally allow the weights to adapt to the sequence length without that additional scaling?
- Although Section 2.4 makes the low-rank matrix assumption to match the realistic multi-head architecture, Section 2.3 omits that assumption. Why the difference? Doesn't this setup then diverge from the established multi-head structure of transformers? In this case, you're effectively repeating the same input with different parameters once per head. Moreover, I think this might also hurt the equivalence to the two-layer neural net, since it forces a low-rank structure, whereas such a structure does not naturally arise under random initialization.
- I'm also a bit confused by the limit with respect to N in line 246. As I said, introducing 1/N is artificial, and we could instead scale the two weight matrices by 1/N. This change is not supposed to affect the input covariance, yet it seems that removing 1/N from line 246 affects the result. Why? This also seems to affect Equations (18) and (21).
- In your analysis of ATTNS, you assume the eigenvalues are distinct and that each head learns one eigenvector sequentially. How important is this assumption in your proofs? How does the training behavior change if the input covariance has multiple identical top eigenvalues? Is it possible for several heads to learn these degenerate eigenvectors in parallel, and does that merge or skip some of the plateaus?
Claims and Evidence
The paper’s primary claims, that multi-head linear attention with merged key–query exhibits an abrupt loss drop while the separate key–query parameterization shows stage-wise, saddle-to-saddle dynamics, are supported by both mathematical derivations and simulation results. The evidence is convincing within the assumed setting (small initialization).
Methods and Evaluation Criteria
The authors’ method is essentially to analyze multi-head linear attention in two parameterizations (merged vs. separate key/query) under gradient-flow training on an in-context linear regression task. The evaluation criteria revolve around loss trajectories under gradient flow and the limiting in-context learning algorithm (whether the model replicates least-squares regression or principal component regression). Given that the paper’s focus is on theoretical understanding, the chosen task (toy in-context linear regression) and the analysis of training trajectories are appropriate.
Theoretical Claims
The authors’ method is essentially to analyze multi-head linear attention in two parameterizations (merged vs. separate key/query) under gradient-flow training on an in-context linear regression task. The evaluation criteria revolve around loss trajectories under gradient flow and the limiting in-context learning algorithm (whether the model replicates least-squares regression or principal component regression). Given that the paper’s focus is on theoretical understanding, the chosen task (toy in-context linear regression) and the analysis of training trajectories are appropriate.
Experimental Designs or Analyses
I think the experimental design is well executed. The loss trajectories and weight evolution curves match the analytical results and confirm the soundness of the analysis. One potential limitation is that the experiments focus on relatively small problem sizes; it would be valuable to see experiments in larger settings.
Other than that, some additional things that might be worth trying are cases where the covariance matrix has repeated eigenvalues, to check whether multiple heads learn in parallel or whether training still unfolds in discrete stages. Also, since the paper’s analysis hinges on small initialization, it would be instructive to demonstrate how differently scaled initializations affect the plateau structure (especially in ATTNS).
Supplementary Material
I skimmed through the proofs in appendix, see Theoretical Claims section.
Relation to Broader Scientific Literature
This paper draws on recent work on linear attention (e.g., Von Oswald et al., 2023; Zhang et al., 2024a,b) and connects the phenomenon of abrupt in-context learning during transformer training to an analytically tractable setting. It builds on results for deep linear networks (Saxe et al., 2014; Woodworth et al., 2020) to show how parameterization (merged vs. separate key-query) yields distinct training behaviors: a single abrupt loss drop or multiple progressive plateaus. Overall, the work extends in-context linear regression analyses to a multi-head linear attention setting and also connects to known convergence results for linear networks (Arora et al., 2019; Shamir, 2019).
Essential References Not Discussed
One specific omission is “Transformers as Statisticians: Provable In-Context Learning with In-Context Algorithm Selection” (Bai et al., 2023), which frames multi-head attention learning as algorithm selection. Even though it does not provide a convergence proof, it is still relevant to how, in the rank-1 case of this paper, individual heads each specialize to a different principal component and capture separate directions in the data.
Other Strengths and Weaknesses
Overall, I think the paper is well written and presented. One of its main strengths is a rigorous theoretical analysis that clarifies how different parameterizations affect the training dynamics of multi-head attention models. Besides, the paper integrates prior results on linear networks well. The experiments complement the theoretical discussions. The connection shown in the paper between progressive eigenvector learning and in-context abilities is also instructive.
That said, I believe the paper needs more justification of why the model used for multi-head linear self-attention with merged key-query matrices is appropriate. Several discussions rely on the limit N → ∞, and I believe equivalence to the model that does not normalize the output by 1/N should be discussed. Also, some discussion or experiments on how deviations from the idealized assumptions (such as small initialization or white input covariance) might change the results would be helpful.
Other Comments or Suggestions
- This is something minor but the matrix is not defined before being referenced in Section 2.1.
- Again something minor, the paper’s main focus is on analyzing linear attention, yet Section 2.2 provides a relatively detailed introduction to the multi-head softmax attention model, which is not used in later analysis. Why did the authors choose to present the standard softmax-based model in detail rather than only mentioning it briefly?
Thank you very much for a very detailed and thoughtful review. We're glad to know that you find our findings instructive, rigorous, and well-presented. We'd like to present some new experimental results and respond to questions. All of our added new figures can be found at URL.
- Rationale for 1/N Scaling
We follow the scaling choice in seminal theory works on linear attention trained for in-context linear regression [1,2]. You're correct that the model can be trained to adapt to the sequence length without the scaling. However, including the scaling factor gives the model the flexibility to be trained and tested with varying context lengths N. We've added Figure 14, in which the models are trained on sequences of varying lengths; this would be infeasible without the scaling. To explain this point, consider the token products that the attention layer computes.
Without scaling, their blocks contain sums over the N context tokens, which diverge as N increases. In contrast, with the 1/N scaling, the blocks contain averages over the context tokens, which remain bounded as N increases. Including the scaling therefore allows the model to process sequences of varying lengths without needing to adapt its weights specifically for each N.
"we could also instead scale and by ..."
You're correct that we can either build the 1/N scaling into the model output or scale the weights by 1/N. In both cases, however, we arrive at the same expressions in line 246 and Equations (18) and (21). In our case, the 1/N term is present because we build the factor into the model; in your case, it is present because the weights are scaled by 1/N. Thus, line 246 and Equations (18) and (21) are not affected by whether the scaling is built into the model or adapted through training.
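Both scaling points can be checked numerically: token sums diverge while averages stay bounded, and building in 1/N is equivalent to absorbing it into the weights for a fixed N. This is a minimal sketch with assumptions that are ours. It uses Gaussian tokens with an illustrative covariance COV and a hypothetical single-head readout w_v^T (Z^T Z) W_kq z_q, which is not the paper's exact architecture. The names `token_second_moment` and `toy_linear_attention` are likewise hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
COV = np.diag([2.0, 1.0, 0.5])        # illustrative input covariance (assumed)
CHOL = np.linalg.cholesky(COV)


def token_second_moment(n, scaled):
    """sum_n x_n x_n^T over n tokens drawn from N(0, COV), divided by n if scaled."""
    x = rng.standard_normal((n, COV.shape[0])) @ CHOL.T   # rows x_n ~ N(0, COV)
    s = x.T @ x
    return s / n if scaled else s


def toy_linear_attention(z, w_v, w_kq, z_q, scale_output=True):
    """Hypothetical single-head readout w_v^T (Z^T Z) W_kq z_q, optionally divided by N."""
    out = w_v @ (z.T @ z) @ w_kq @ z_q
    return out / z.shape[0] if scale_output else out


# 1) Unscaled sums grow linearly in N, while 1/N-scaled averages approach COV.
for n in [10, 1_000, 100_000]:
    print(n, np.linalg.norm(token_second_moment(n, scaled=False)),
          np.linalg.norm(token_second_moment(n, scaled=True)))

# 2) For a fixed N, building 1/N into the model equals absorbing 1/N into W_kq.
n, dim = 50, 4
z = rng.standard_normal((n, dim))
w_v, z_q = rng.standard_normal(dim), rng.standard_normal(dim)
w_kq = rng.standard_normal((dim, dim))
print(np.isclose(toy_linear_attention(z, w_v, w_kq, z_q),
                 toy_linear_attention(z, w_v, w_kq / n, z_q, scale_output=False)))
```

The second check holds for one N at a time only. Weights that absorb 1/N for one length are mis-scaled for another, so a built-in factor is what makes training on varying lengths straightforward.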
Thank you for this constructive feedback. We'll incorporate this more detailed explanation for the scaling in revision.
- Why Full-Rank in ATTNM?
We agree that ATTNM diverges from the established multi-head attention structure in practical transformers. Our intention in studying both ATTNM and the low-rank ATTNS is as follows: ATTNM is widely used in the theoretical literature on transformers (see the references in lines 60-65), while the low-rank ATTNS is closer to the parametrization in practical transformers. Given that a significant portion of the theoretical literature has used ATTNM, we believe that analyzing it is of independent interest to transformer theorists, if not to practitioners.
Your insight that the multiple heads in ATTNM are somehow "repetitive" or redundant is also right. In other words, the multi-head ATTNM is overparameterized. While overparameterization seems unnecessary in terms of expressivity, its effect on training dynamics and convergence is a fundamental, long-standing question in deep learning and is worth studying. Perhaps surprisingly, we find that the weights in different heads of ATTNM stay parallel during training, as shown in Figure 4 and Appendix C.7. This offers a hint about how overparameterized attention models learn and generalize.
- Effect of Initialization Scale
We've added Figure 11 for this. Please refer to our rebuttal to reviewer Hj7w for details (we reached space limit for this rebuttal).
- Multiplicity of Eigenvalues
Thank you for this interesting question. We have added Figure 12, in which the input covariance has repeated eigenvalues. Six runs from different random initializations are plotted. Linear ATTNS with rank-1 key and query may still exhibit plateaus when learning directions with repeated eigenvalues, owing to the different random initial weights in each head. It may also skip the plateau for certain random seeds.
In the case with distinct eigenvalues, the plateau duration is determined by both the size of the eigenvalue and the random initialization. In the case with equal eigenvalues, the plateau duration is determined by only the random initialization.
- Simulations with Larger Dimension
We have added Figure 13, in which we train ATTNM and ATTNS on a higher-dimensional dataset. The loss trajectories are qualitatively similar to those in lower-dimensional cases, although noisier.
- We'll add the reference “Transformers as Statisticians” (Bai et al., 2023). It's indeed an omission; thanks for your note!
- We'll cut back the introduction to softmax attention in Section 2.2 and clarify in Section 2.1 that the subscript in that matrix's notation denotes the index of a training sample. Thanks for your suggestion.
[1] Ahn, Cheng, Daneshmand & Sra. NeurIPS 2023. arXiv:2306.00297
[2] Zhang, Frei & Bartlett. JMLR 2024. arXiv:2306.09927
I want to thank the authors for their rebuttal. My concerns are satisfactorily addressed. I will maintain my positive score.
Thank you very much for your engagement in the rebuttal process, and for your thoughtful review. We are glad to hear that your concerns have been satisfactorily addressed.
If you feel it is appropriate, we would be grateful if you might consider raising the score.
This paper investigates the training dynamics of in-context learning (ICL) in multi-head linear attention models trained on in-context linear regression tasks. The authors examine two parametrizations: one with merged key and query weights (ATTNM) and one with separate key and query weights (ATTNS). The paper provides a theoretical explanation of how ICL abilities evolve during gradient descent training, revealing either abrupt acquisition or progressive improvements depending on the parametrization of the key and query matrices.
Questions for Authors
How sensitive are your findings to the choice of initialization scale? The paper demonstrates that ATTNM is equivalent to a two-layer fully-connected network with cubic features. Could this have practical implications for transformer architecture choices?
Claims and Evidence
The authors provide:
- mathematical analysis of fixed points in the loss landscape
- precise characterization of gradient descent trajectories
- analytical solutions for special cases
- simulations that match their theoretical predictions
Methods and Evaluation Criteria
The authors rely on
- a well-established ICL regression task
- empirical validation through simulations that match theoretical results
Theoretical Claims
I did not check the proofs of the theoretical claims.
Experimental Designs or Analyses
The simulations given in the paper cover
- training simulation of both ATTNM and ATTNS models with different initializations
- tracking of the evolution of weights and loss during training
- comparison of empirical results with theoretical predictions
Supplementary Material
I only skimmed through the supplementary material
Relation to Broader Scientific Literature
The paper is in the intersection of several research topics including:
- work showing abrupt emergence of ICL
- work on the optimization properties of attention
- work on the stage-wise dynamics of transformers
Essential References Not Discussed
The paper thoroughly covers the relevant literature.
Other Strengths and Weaknesses
Strengths:
- The work bridges optimization theory and the emergent capabilities of transformer models
- The characterization of saddle-to-saddle dynamics in ATTNS offers a clear explanation for progressive ICL acquisition
Weaknesses:
- The analysis is limited to linear attention, and it's unclear how much of it generalizes to the softmax attention used in practical transformers.
- The paper focuses on ICL regression tasks, which are simpler than the complex language tasks where ICL is typically observed.
Other Comments or Suggestions
It would be valuable to include a discussion of how the theoretical insights might translate into practical training recommendations.
Thank you very much for your interest in our work and for a thoughtful review. We appreciate how you positioned our contribution at the intersection of three very important and relevant topics, and we’re glad that you found our results provide a clear explanation for progressive ICL acquisition. Below, we present some new experimental results and respond to questions.
- Softmax Attention
We have added Figure 10 (see URL) to show that the different training dynamics of linear ATTNM and linear ATTNS also occur in their softmax counterparts. Figure 10 follows the same setup as Figures 1–2 for linear attention, the only difference being the softmax activation function in the attention calculation. We observe that softmax ATTNM exhibits a single abrupt loss drop, whereas softmax ATTNS undergoes multiple loss drops separated by phases of conspicuously slower training. This demonstrates that our findings and theoretical intuition are not unique to linear attention but can also extend to softmax attention.
- Effect of Initialization Scale
We'd like to clarify that many of our results do not rely on the initialization: the fixed points, the connection to MLPs, and the ICL algorithm we identify for the converged network. We then analyze the training dynamics under small initialization.
We've added Figure 11 (see URL) to demonstrate the effect of initialization scale on the dynamics. For ATTNM, increasing the initialization scale shortens the plateau before the single abrupt loss drop. For ATTNS, increasing the initialization scale shortens all the plateaus between successive abrupt loss drops. For both models, the loss trajectories with the largest initialization exhibit an exponential-decay shape, the hallmark of lazy learning (the NTK regime). In the existing theoretical literature, the standard choice is either large initialization, giving exponential-shaped lazy learning, or small initialization, giving rich feature learning as in our paper. Practical initialization schemes usually lie somewhere in between. In Figure 11, the dynamics from intermediate initialization look like a mix of the exponential-shaped and the abrupt sigmoid-shaped curves. Such mixed curves are often seen in practice, e.g., in induction head emergence in natural language settings (Olsson et al. 2022; Argument 1). Our dynamics analysis focuses on the rich learning regime and provides analytical insight into such phenomena, which we believe is a first step toward understanding dynamics in naturalistic settings.
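A toy analogue from deep linear network theory (Saxe et al., 2014) shows plateaus shortening as the initialization scale grows. This is a minimal sketch built on our own assumptions, not the paper's ATTNM or ATTNS equations. It uses independent modes, each a product of two balanced scalar factors u with strength λ, trained toward target 1 under the gradient flow du/dt = λ u (1 - u^2). For such a mode, p = u^2 follows a logistic curve, so the plateau length scales like ln(1/u0^2)/(2λ). A smaller initialization or a smaller λ therefore means a longer plateau. All names are hypothetical.

```python
import math


def mode_progress(t, lam, u0):
    """Learned fraction p(t) = u(t)**2 for the toy flow du/dt = lam * u * (1 - u**2).

    p obeys the logistic equation dp/dt = 2 * lam * p * (1 - p), whose closed form is
    p(t) = 1 / (1 + (1/p0 - 1) * exp(-2 * lam * t)) with p0 = u0**2 and 0 < u0 < 1.
    """
    p0 = u0 ** 2
    return 1.0 / (1.0 + (1.0 / p0 - 1.0) * math.exp(-2.0 * lam * t))


def half_time(lam, u0):
    """Time at which p reaches 1/2; for small u0 this approximates the plateau length."""
    p0 = u0 ** 2
    return math.log(1.0 / p0 - 1.0) / (2.0 * lam)


def toy_loss(t, lams, u0):
    """Total loss sum_i lam_i * (1 - p_i)**2 / 2 over independent modes."""
    return sum(lam * (1.0 - mode_progress(t, lam, u0)) ** 2 / 2.0 for lam in lams)


lams = [4.0, 1.0]                       # two modes with well-separated strengths
for u0 in (1e-4, 1e-2, 1e-1):           # smaller init -> longer plateaus
    print(u0, [round(half_time(lam, u0), 2) for lam in lams])

# Sampled loss curve for small init: two sigmoid-shaped drops separated by a plateau.
for t in range(0, 13, 2):
    print(t, round(toy_loss(t, lams, 1e-4), 4))
```

As u0 approaches 1, each mode starts near its solution and 1 - p decays roughly exponentially, which loosely resembles the lazy-learning curves described above. Giving equal-λ modes different u0 shifts their drops apart, so the plateau durations then depend only on initialization.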
- Practical Implications
Since the scope of our paper is theoretical, we believe our findings and intuition can inform practice by providing a deeper understanding of common phenomena observed in practical settings, such as abrupt learning and the progressive acquisition of ICL. Although it does not directly yield practical recommendations, we hope our work provides a foundation for future work along the spectrum from theory to practice, ultimately leading to more effective architectural and optimization choices.
Regarding the implications of the equivalence between ATTNM and an MLP with cubic features, it adds a perspective to the open question of whether MLPs can learn in context (Boix-Adsera et al. [1]; Tong and Pehlevan [2]). Moreover, we reveal that an MLP may perform ICL more comparably to attention models when provided with polynomial features instead of the original sequence. This may explain why Boix-Adsera et al. (Figure 25) observed that an MLP fails to learn ICL with the original sequence as input but succeeds when the input is augmented with additional features. We'll add a discussion of this point in the revision.
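The ATTNM-MLP equivalence can be made concrete with a small sketch. The sketch assumes, as our own simplified reading rather than the paper's exact notation, that each head's prediction reduces to v_h (1/N) Σ_n y_n x_n^T U_h x_q. Under that reading, the model is a two-layer linear network applied to features f_ij = x_q[i] · (1/N) Σ_n y_n x_n[j], and each feature is cubic in the inputs (x_q, y_n, x_n). The function names are hypothetical.

```python
import numpy as np


def cubic_features(x_ctx, y_ctx, x_q):
    """Flattened f[i, j] = x_q[i] * (1/N) sum_n y_n x_n[j]; each entry is cubic in the inputs."""
    h = (y_ctx[:, None] * x_ctx).mean(axis=0)      # h = (1/N) sum_n y_n x_n
    return np.outer(x_q, h).ravel()


def attention_style_prediction(x_ctx, y_ctx, x_q, gamma):
    """x_q^T Gamma h, the readout form assumed here for the merged key-query model."""
    h = (y_ctx[:, None] * x_ctx).mean(axis=0)
    return float(x_q @ gamma @ h)


def two_layer_on_features(features, first_layer, second_layer):
    """Hypothetical two-layer linear network: hidden = first_layer @ f, out = second_layer @ hidden."""
    return float(second_layer @ (first_layer @ features))


rng = np.random.default_rng(1)
d, n = 3, 40
x_ctx = rng.standard_normal((n, d))
y_ctx = x_ctx @ rng.standard_normal(d)             # noiseless in-context targets
x_q = rng.standard_normal(d)
gamma = rng.standard_normal((d, d))

f = cubic_features(x_ctx, y_ctx, x_q)
first = np.stack([gamma.ravel(), gamma.ravel()])    # two "heads", each with U_h = Gamma
second = np.array([0.5, 0.5])                       # v_h weights so that sum_h v_h U_h = Gamma
pred = attention_style_prediction(x_ctx, y_ctx, x_q, gamma)
print(np.isclose(pred, two_layer_on_features(f, first, second)))
```

The check prints True because x_q^T Γ h = Σ_ij Γ_ij x_q[i] h[j], a linear readout of the cubic features. An MLP given only the raw sequence would have to construct these products itself.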
[1] Boix-Adsera, E., Saremi, O., Abbe, E., Bengio, S., Littwin, E., & Susskind, J. M. When can transformers reason with abstract symbols? ICLR 2024.
[2] Tong, W. L., & Pehlevan, C. MLPs Learn In-Context on Regression and Classification Tasks. ICLR 2025.
We've added the new figures to our appendix and will use the extra page in final revision to include as many as possible in the main text.
The paper investigates the theoretical understanding of gradient descent training dynamics for multi-head linear self-attention models performing in-context linear regression tasks. It analyzes two parametrizations of linear self-attention: one where key and query matrices are merged (ATTNM), and another with separate key and query matrices (ATTNS). For the merged model (ATTNM), the authors identify two fixed points of the training dynamics and an abrupt loss drop, which they solve analytically under specific conditions. In the separate model (ATTNS), exponentially many fixed points are found, with the loss trajectory showing saddle-to-saddle dynamics in which the model progressively learns the principal components of the data. Ultimately, both models implement approximate least-squares regression in context upon convergence, providing insights into how parameterization influences the evolution of in-context learning (ICL) abilities during gradient descent training.
Questions for Authors
Generalizability Beyond Linear Tasks: The paper thoroughly analyzes linear attention trained on linear regression tasks. Have you investigated whether similar saddle-to-saddle dynamics and progressive improvement phenomena also occur for non-linear tasks or in practical transformer models with softmax attention?
Effects of Initialization Conditions: The analytical derivations assume very small (or infinitesimal) initialization scales. How sensitive are your theoretical results (fixed points, saddle-to-saddle dynamics) to larger or more realistic initialization conditions often used in practice?
Claims and Evidence
The claims made in the submission are supported by clear and convincing evidence. Specifically, the authors provide detailed theoretical analyses, derivations, and simulation results that align closely with their main theoretical claims. Here’s a breakdown of how effectively the claims are supported:
Claims Supported by Strong Evidence:
- Existence of Fixed Points:
  - Evidence: The authors derive precise analytic forms for fixed points in both the merged (ATTNM) and separate (ATTNS) parametrizations. The existence of these fixed points is rigorously demonstrated mathematically (e.g., explicit equations such as (9a)-(9b) and (15)-(17)), providing strong support.
- Abrupt Loss Drop (ATTNM):
  - Evidence: Both theoretical analysis (with analytical solutions, such as Equation (10)) and detailed simulations (Figure 1) clearly show the existence and nature of the single abrupt drop in loss. The analytical solutions are explicitly derived, and the simulations match these theoretical predictions.
-
Saddle-to-Saddle Dynamics (ATTNS):
- Evidence: Authors convincingly illustrate through detailed analytic derivations and numerical simulations (Figure 2) the sequence of abrupt loss drops separated by plateaus. Their scalar ordinary differential equation approximation (equation (20)) closely matches simulated trajectories, reinforcing their claims.
-
ICL Algorithm Interpretation (Principal Component Regression):
- Evidence: The authors explicitly derive the algorithm implemented at each fixed point, demonstrating how the model progressively learns principal components (equations (18), (21)). Simulations confirm the model aligns with eigenvectors sequentially during training (Figure 2c).
-
Impact of Parametrization (Rank of Key-Query Matrices):
- Evidence: Simulation results (Figure 3) convincingly demonstrate the effects of different ranks (R = 1, 2, 4, 8) on the dynamics and duration of loss plateaus, supporting claims about parametrization affecting training dynamics.
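The saddle-to-saddle picture can be caricatured with decoupled scalar modes. The sketch below is not the ATTNS equations: it assumes, hypothetically, that each principal direction behaves like an independent product a_i * b_i trained by gradient descent toward a target strength s_i (standing in for that direction's eigenvalue), with all modes starting from the same small initialization. Under that assumption, stronger modes escape their plateau first, so the summed loss falls in a staircase ordered by strength. `mode_learning_steps` is a hypothetical name.

```python
def mode_learning_steps(strengths, init_scale=1e-3, lr=0.01, frac=0.5, max_steps=100_000):
    """Toy saddle-to-saddle sketch (assumed decoupled modes).
    Runs gradient descent on 0.5 * sum_i (s_i - a_i * b_i)**2 from a_i = b_i = init_scale
    and returns, per mode, the first step at which a_i * b_i reaches frac * s_i
    (None if a mode never gets there within max_steps)."""
    n = len(strengths)
    a = [init_scale] * n
    b = [init_scale] * n
    first = [None] * n
    for step in range(max_steps):
        # Record the first step at which each mode has escaped its plateau.
        for i, s in enumerate(strengths):
            if first[i] is None and a[i] * b[i] >= frac * s:
                first[i] = step
        if all(t is not None for t in first):
            break
        # Simultaneous gradient step on each independent mode.
        for i, s in enumerate(strengths):
            err = s - a[i] * b[i]
            a[i], b[i] = a[i] + lr * err * b[i], b[i] + lr * err * a[i]
    return first
```

For strengths [1.0, 3.0, 2.0], the mode with strength 3.0 crosses half its target first, then 2.0, then 1.0. In the continuous-time limit each escape time scales roughly as ln(1/init_scale) / s_i, which is why the plateaus get longer for weaker directions in this toy.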
Potential Minor Limitations (though not significantly problematic):
- Simulation Scope: The numerical simulations strongly align with the theory, yet they predominantly focus on low-dimensional cases (e.g., dimension ). Although suitable for clear visualization and validation, higher-dimensional simulations could strengthen the general applicability of their results.
In summary, all key claims in this paper are backed by clear, precise, and convincing evidence.
Methods And Evaluation Criteria
N/A. This is a theory paper, so there are few methods or evaluations to assess.
Theoretical Claims
Yes. They all appear reasonable to me.
Experimental Designs Or Analyses
n/a
Supplementary Material
n/a
Relation To Broader Scientific Literature
The paper "Training Dynamics of In-Context Learning in Linear Attention" contributes to the broader scientific literature by providing a theoretical analysis of how multi-head linear self-attention models develop in-context learning abilities through gradient descent training. This work builds upon and extends existing research in several key areas:
- In-Context Learning in Transformers: The study aligns with previous findings that transformers can perform in-context learning, adapting to new tasks based on input data without explicit parameter updates. Notably, it complements empirical observations by offering a theoretical framework that explains how these capabilities emerge during training.
- Training Dynamics and Fixed Points: The identification of fixed points in training dynamics and the characterization of abrupt loss drops or saddle-to-saddle dynamics provide a deeper understanding of how self-attention models learn. This insight is valuable for designing more efficient training protocols and architectures.
- Parametrization Impact: By analyzing different parametrizations (merged vs. separate key and query matrices), the paper highlights how architectural choices influence learning dynamics. This perspective is crucial for developing models that balance computational efficiency with learning efficacy.
Essential References Not Discussed
n/a
Other Strengths And Weaknesses
n/a
Other Comments Or Suggestions
n/a
We thank the reviewer for a detailed review and for raising constructive questions. We're glad to know that you think our claims are backed by clear, precise, and convincing evidence. Below, we present some new experimental results and respond to questions.
Softmax Attention
We've added Figure 10 (see URL) to show that the different training dynamics of linear ATTNM and linear ATTNS also occur in their softmax counterparts. Figure 10 follows the same setup as Figures 1–2 for linear attention; the only difference is that a softmax activation is applied in the attention calculation. We observe that softmax ATTNM exhibits a single abrupt loss drop, whereas softmax ATTNS undergoes multiple loss drops separated by phases of conspicuously slower training. This demonstrates that our findings and theoretical intuition are not unique to linear attention but can also extend to softmax attention.
Effect of Initialization Scale
We'd like to clarify that many of our results do not rely on the initialization: the fixed points, the connection to MLP, and the ICL algorithm we identified for the converged network. Our analysis of the training dynamics, in contrast, assumes small initialization.
We've added Figure 11 (see URL) to demonstrate the effect of initialization scale on the dynamics. For ATTNM, increasing the initialization scale shortens the plateau before the single abrupt loss drop. For ATTNS, increasing the initialization scale shortens every plateau between successive abrupt loss drops. For both models, the loss trajectories with the largest initialization decay exponentially, which is the hallmark of lazy learning (the NTK regime). In the existing theoretical literature, the standard choice is either a large initialization, which gives exponential-shaped lazy learning, or a small initialization, which gives rich feature learning as in our paper. Practical initialization schemes usually lie somewhere in between. In Figure 11, the dynamics from intermediate initialization look like a mix of the exponential-shaped and the abrupt sigmoid-shaped curves. Such mixed curves are often seen in practice, e.g., in induction head emergence in natural language settings (Olsson et al. 2022; Argument 1). Our dynamics analysis focuses on the rich learning regime and provides analytical insight into such phenomena, which we believe is a first step toward understanding dynamics in naturalistic settings.
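A scalar toy shows why a larger initialization shortens the plateau. This is a two-layer linear product a * b, not the attention model itself; the rebuttal's mention of a connection to MLPs motivates the analogy, but treating it as a stand-in for ATTNM is an assumption, and `time_to_learn` is a hypothetical name. Starting from a = b = init_scale, gradient descent lingers near the saddle at zero for roughly ln(1/init_scale) time units (steps times lr) before a sigmoid-shaped jump, so shrinking the initialization lengthens the plateau.

```python
def time_to_learn(init_scale, target=1.0, lr=0.01, frac=0.5, max_steps=100_000):
    """Toy plateau-length proxy (assumed scalar model).
    Gradient descent on L(a, b) = 0.5 * (target - a * b)**2 from a = b = init_scale;
    returns the first step at which a * b reaches frac * target."""
    a = b = init_scale
    for step in range(max_steps):
        if a * b >= frac * target:
            return step
        err = target - a * b
        # Simultaneous update: dL/da = -err * b, dL/db = -err * a.
        a, b = a + lr * err * b, b + lr * err * a
    return max_steps
```

Comparing init_scale values of 1e-6, 1e-3, and 1e-1 gives strictly decreasing escape times in this toy, mirroring the trend the rebuttal reports for Figure 11; a large enough initialization removes the plateau entirely, which is where the exponential-shaped lazy regime takes over.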
Simulations with Larger D
Thank you for this suggestion. We've added Figure 13 (see URL), in which we train ATTNM and ATTNS on a dataset with a larger input dimension D. The loss trajectories are qualitatively similar to those in lower-dimensional cases, though noisier. This suggests that our findings do not break down in high-dimensional settings.
We've added the new figures to our appendix and will use the extra page in final revision to include as many as possible in the main text.
This paper studies the gradient dynamics of multi-head linear self-attention trained for in-context linear regression and finds that two different parameterizations lead to different evolution of ICL ability. The reviewers agreed upon positive ratings. Overall I recommend accept.