Attention layers provably solve single-location regression
We introduce single-location regression, a novel regression task, and show that attention solves this task.
Abstract
Reviews and Discussion
This work examines the transformer mechanism for solving the single-location regression task within a simplified non-linear self-attention setting. It provides a Bayes optimality guarantee for the oracle network predictor and demonstrates how the training dynamics allow the network to converge to this solution using projected gradient descent.
Strengths
- The mathematical analysis from both statistical and optimization perspectives is insightful, offering a comprehensive theoretical guarantee.
- The paper is well-written and easy to follow.
- Experiments are provided to support the theoretical findings.
Weaknesses
My main concern lies in the motivation for studying this specific setting, characterized by token-wise sparsity and internal linear representations. While these features are relevant to real-world NLP scenarios, the essential mechanism for transformers to succeed in this setting is, first, to identify the underlying structure (the latent sparse pattern in this work) and then to perform a task-specific operation on this structure (linear transformations here). This intrinsic mechanism has been extensively explored across various settings (e.g., [1-3]), even with more realistic softmax attention. Thus, the technical significance of studying a simplified attention model in such a specific setting remains unclear. Could the authors elaborate on the specific technical difficulties encountered in this setting? What are the key technical challenges in extending the analysis to more realistic softmax attention?
[1] How Transformers Learn Causal Structure with Gradient Descent. Nichani et al., 2024
[2] Vision Transformers provably learn spatial structure. Jelassi et al., 2022
[3] Transformers Provably Learn Sparse Token Selection While Fully-Connected Nets Cannot. Wang et al., 2024
Questions
- What is the specific convergence rate in Theorem 5? Since this forms a key part of the contribution in characterizing the training dynamics, a more detailed presentation in the main paper would be beneficial.
- In line 334, why is the entire sum on the order of ?
- The PGD analysis relies on a relatively strong assumption that the initialization lies on the manifold. Could the authors discuss any possibilities for generalizing the current analysis to accommodate a broader range of initializations?
Answer to weaknesses: Even in the setting we consider, which is indeed simpler than practice, the analysis already involves non-convex optimization dynamics, which requires advanced mathematical tools, mainly from dynamical systems theory. We refer to the proof sketch in Appendix A, which details the main technical steps and difficulties of the proof. Regarding softmax, we refer the reviewer to the general rebuttal for a discussion.
We thank the reviewer for the references. The article [2] is already cited in our paper, but we were not aware of [1] and [3], and we will cite them in the next version. While we agree with the reviewer that from a high-level perspective our approach is similar to [1]-[3], we believe that there are significant technical differences between our setting and theirs. In particular, this makes it far from obvious that their proof techniques for handling softmax could be directly borrowed in our case. We will discuss these papers in the next version, and provide some elements at the end of this rebuttal in case the reviewer is interested.
Q1: Our proof technique relies on the use of the Center-Stable Manifold Theorem, a tool from dynamical systems theory. This tool does not provide quantitative rates of convergence. Obtaining a rate is a tricky matter because it requires quantifying the distance to the saddle points of the risk (since the dynamics is slower near saddle points), which in turn requires other tools of analysis and potentially additional assumptions. We will include this discussion in the next version. More generally, as mentioned above, Appendix A provides a 2-page proof sketch (that unfortunately does not fit within the page limit of the main paper).
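The remark that dynamics slow down near saddle points can be illustrated on a toy problem. The sketch below is not the risk studied in the paper: it runs plain gradient descent on the assumed function f(x, y) = 0.5 (x^2 - y^2), whose only critical point is a saddle at the origin, and counts how many steps are needed to leave its neighborhood. The function name `steps_to_escape` and all constants are hypothetical.

```python
def steps_to_escape(y0, lr=0.1, radius=1.0, max_steps=10_000):
    """Gradient descent on the toy function f(x, y) = 0.5 * (x**2 - y**2).

    The origin is a saddle point: x contracts towards 0 while y expands.
    Starts at (1.0, y0) and returns the number of steps until |y| >= radius.
    """
    x, y = 1.0, y0
    for step in range(max_steps):
        if abs(y) >= radius:
            return step
        grad_x, grad_y = x, -y      # gradient of f at (x, y)
        x -= lr * grad_x            # x shrinks by a factor (1 - lr)
        y -= lr * grad_y            # y grows by a factor (1 + lr)
    return max_steps                # never escaped (e.g. y0 == 0)


for y0 in (1e-2, 1e-4, 1e-8):
    print(f"initial distance {y0:g}: {steps_to_escape(y0)} steps to escape")
```

In this toy case the escape time grows like log(1 / y0), so any quantitative rate has to control how close the iterates come to the saddle, which is the difficulty described above.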
Q2: Informally, by the Central Limit Theorem (CLT), summing independent and zero-mean random variables with standard deviation gives a sum of the order of . As noted in the paper, the paragraph you refer to gives an informal intuition, while the proof provides a rigorous justification. We will mention in line 334 that the magnitude comes from the CLT.
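The scaling invoked in this answer can be checked numerically. The sketch below is only an illustration under assumed notation (the symbols of the original statement are not reproduced here): it sums `n_terms` independent zero-mean variables with standard deviation `sigma` and compares the empirical standard deviation of the sum with `sigma * sqrt(n_terms)`.

```python
import numpy as np


def sum_std(n_terms, sigma=1.0, n_trials=2000, seed=0):
    """Empirical standard deviation of a sum of n_terms i.i.d. zero-mean variables."""
    rng = np.random.default_rng(seed)
    # Uniform on [-a, a] has standard deviation a / sqrt(3); pick a so it equals sigma.
    a = sigma * np.sqrt(3.0)
    samples = rng.uniform(-a, a, size=(n_trials, n_terms))
    return samples.sum(axis=1).std()


for n in (10, 100, 1000):
    print(f"n={n}: empirical std {sum_std(n):.2f}, sqrt(n) = {np.sqrt(n):.2f}")
```

The summands are deliberately non-Gaussian (uniform) to show that the square-root scaling does not depend on the distribution, only on independence, zero mean, and finite variance.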
Q3: A possible direction to handle a general initialization is discussed in lines 500-505 of the article. In summary, the idea is to show that the manifold is stable in the sense of dynamical systems (i.e., if the dynamics are close to the manifold at a given time, they remain close to the manifold). Our numerical experiments suggest that such a generalization should hold, and that the temperature parameter should play a key role in this analysis.
Comparison with [1]-[3]:
- Solving the task in [1] requires learning a fixed latent causal graph over the positions of the tokens, making positional encodings a critical element of the analysis. In contrast, our task is invariant under permutations of the tokens. Moreover, in [1], the output is expressed as a function of the last token, with the previous tokens providing the necessary context for this computation. In our setup, however, the output depends on a token whose position varies and must be identified within the context.
- In [2], the argument of softmax (i.e., a matrix ) is directly a parameter of the model, instead of being a product of the data and of some parameters. This is a radically different structure from the usual attention, and from our setup, where the data appear in the nonlinearity .
- The closest work to our setup is the recent paper [3], which also incorporates a notion of token-wise sparsity: the output is computed as the average of a small subset of tokens, where the subset is identified by comparing the positional encodings of each token with that of the last token. A key difference is that, in our setting, the tokens also encode an output projection direction () on top of the information on the position of the informative token (). In other words, our task involves learning a linear regression in addition to identifying the relevant token, which is not the case in [3].
Thank you for the author's efforts in their response. I have no further concerns and will maintain my score.
Thank you for acknowledging the rebuttal and your positive assessment of our paper.
This paper introduces a new task, single-location regression, and shows that it is solvable by a predictor resembling an attention layer, whereas a linear predictor fails. The result is theoretically well grounded and of good significance given the limited theory on attention layers and their striking efficacy. This task relates to well-known NLP problems in the existing literature, serving as a good testbed for studying Transformers through a theoretical lens. The statistical performance of this attention-like optimal predictor is shown to exceed that of an optimal linear predictor, under minimal assumptions. (Projected) gradient descent is shown to reach the optimal solution, with experiments provided to support the theory.
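To make the comparison in this summary concrete, here is a minimal simulation sketch. It is not the paper's exact model: the data scaling (`mu`, `s`, `noise`), the orthogonal directions `k_star` and `v_star`, and the function names are all assumptions chosen for illustration. One hidden token per sequence is shifted along `k_star`, the output is the projection of that token onto `v_star`, and an erf-weighted predictor using the true directions (no training) is compared with in-sample linear least squares on the flattened sequence.

```python
import math
import numpy as np

erf = np.vectorize(math.erf)  # element-wise erf without a SciPy dependency


def sample_task(n, L, d, k_star, v_star, mu=3.0, s=0.1, noise=0.01, seed=0):
    """Assumed data model: zero-mean tokens, one hidden token shifted along k_star."""
    rng = np.random.default_rng(seed)
    X = s * rng.standard_normal((n, L, d))       # all tokens start zero-mean
    J0 = rng.integers(0, L, size=n)              # hidden informative position
    X[np.arange(n), J0] += mu * k_star           # mark the informative token
    Y = X[np.arange(n), J0] @ v_star + noise * rng.standard_normal(n)
    return X, Y


def attention_like(X, k, v, lam=1.0):
    """Sum over tokens of erf(lam * <x, k>) * <x, v> (element-wise nonlinearity)."""
    weights = erf(lam * (X @ k))                 # shape (n, L), close to 1 at position J0
    return (weights * (X @ v)).sum(axis=1)


def linear_least_squares(X, Y):
    """In-sample fit of a linear predictor on the flattened sequence."""
    flat = X.reshape(len(X), -1)
    beta, *_ = np.linalg.lstsq(flat, Y, rcond=None)
    return flat @ beta


d, L, n = 20, 10, 5000
k_star, v_star = np.eye(d)[0], np.eye(d)[1]      # assumed orthogonal unit directions
X, Y = sample_task(n, L, d, k_star, v_star)
mse_attn = np.mean((attention_like(X, k_star, v_star) - Y) ** 2)
mse_lin = np.mean((linear_least_squares(X, Y) - Y) ** 2)
mse_null = np.mean(Y ** 2)                       # error of always predicting 0
print(f"attention-like: {mse_attn:.5f}  linear: {mse_lin:.5f}  null: {mse_null:.5f}")
```

Under these assumed scalings the linear fit barely improves on predicting zero, because it cannot adapt to the varying position of the informative token, while the erf-weighted predictor recovers most of the signal.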
Strengths
- The paper is extremely well written and very easy and pleasant to read.
- The analysis is theoretically strong and rigorous.
- More generally, this work promotes an approach that is worth being acknowledged and valued: looking into a simpler problem than the ones practitioners can face, yet relevant, and solving it completely and rigorously.
- Additionally, the problem is well connected to practical concerns, and the authors make a convincing case for the significance of their analysis.
Weaknesses
In decreasing order of importance:
- The present work's approach is not so well connected to the existing literature in line 103: "note that our task shares similarities with single-index models (McCullagh & Nelder, 1983) and mixtures of linear regressions (De Veaux, 1989)". I see the differences between those works and the present one highlighted in the following sentence, but the exact nature of these similarities is not clear. Could this connection be elaborated?
- Minor: In the caption of Figure 2, it may be helpful to add a note about the size of the squares, which presumably indicates the level of alignment.
- line 234: "We emphasize that empirically this simplification of softmax using a component-wise nonlinearity has been shown not to degrade performance". The cited paper Wortsman et al, 2023 indeed indicates such behaviour but their experiments include a normalisation.
- Broadly speaking, the question of whether softmax is needed in attention layers remains an open and unresolved one in the community. To avoid overstating the case, consider rephrasing to reflect this ongoing debate. Note that I don't think having used erf in your analysis diminishes the value of your work in any way.
Questions
- Could you point out (and explain) the steps in your proofs that would break if softmax were considered instead of erf? My guess is that preserving the independence is key (via elementwise application), but could sigmoid or any other nonlinear, bounded, increasing, and differentiable elementwise activation work as well? It might be helpful to mention early on that the analysis could extend to a broader class of activations if that's the case (modulo adjustments of the formulae in Theorem 1).
- Linear predictors are shown to fail at solving the task, and the comparison to them is much appreciated. Do you have a sense of how nonlinear predictors would perform in turn? This could show why attention layers are preferable to fully-connected layers, for instance, in such contexts (an analogue of Proposition 3 in this case may not be true and is probably non-trivial to show, but it would in any case be interesting to discuss or investigate).
- Are the tokens assumed to be independent (line 74) or independent conditionally on J_0 (line 85)? I did not read the proofs so I couldn’t figure it out myself but would be good to clarify. More generally, could you explain at which points of your proofs the independence is needed to help readers understand how this assumption could be relaxed for future works concerned with more realistic scenarios.
- Minor: line 1832 "Our code is available at [XXX]" in the appendix.
- Extra minor: line 469: a space is missing between "Figure 4a" and "(right)"
W1: This is a good point, and we agree with the reviewer that there is no direct connection. However, we felt it was important to mention these models because they have some similarities to our approach.
- In De Veaux (1989), the conditional distribution of output given input follows a mixture model, expressed as . The connection to our model lies in the marginal distribution of each token, which in our case follows a Gaussian mixture and .
- In McCullagh & Nelder (1983), single-index models correspond to generalized linear models, represented as , where is an unknown regression function learned with a nonparametric model (e.g., a neural network). A slight generalization is the multi-index model, where is a matrix rather than a vector. In our case, the input is not a vector but a sequence of vectors. Nevertheless, the output is computed as a nonlinear function of the projection of onto the directions and . Thus, our model can be expressed as a specific instance of a multi-index model, where all tokens are stacked into a single large vector , , and with a specific link function .
W2 and W3: Thank you for pointing this out. We agree that the role of softmax remains an open question and will reformulate our discussion to avoid overstating its importance. We also acknowledge the critical role of normalization (even in the presence of softmax) and will emphasize this aspect more thoroughly, especially when referring to Wortsman et al. (2023). The importance of normalization is also evident in our experiments when the initialization is outside the manifold: the normalization parameter plays a central role in learning the parameters and .
Q1: The entire structure of the proof would break if softmax were used; see the common rebuttal for a specific example. The reviewer is correct in pointing out that the choice of erf over another nonlinear, bounded, increasing, equal to 0 at 0, and differentiable function is primarily for technical reasons. Specifically, erf is used because closed-form formulas exist for the expectation of erf and its derivatives applied to Gaussian random variables (see, e.g., Lemma 18). In principle, however, the behavior should remain similar for any such nonlinearity. We will include a discussion of this point in the next version.
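As an illustration of the kind of closed form mentioned here, a standard identity states that for Z ~ N(0, 1), E[erf(a + bZ)] = erf(a / sqrt(1 + 2 b^2)). Whether Lemma 18 takes exactly this form is an assumption; the sketch below only checks this identity against a Monte Carlo estimate, with hypothetical function names.

```python
import math
import numpy as np


def erf_gaussian_mean(a, b):
    """Closed form of E[erf(a + b * Z)] for Z ~ N(0, 1)."""
    return math.erf(a / math.sqrt(1.0 + 2.0 * b * b))


def monte_carlo_mean(a, b, n=200_000, seed=0):
    """Monte Carlo estimate of the same expectation."""
    rng = np.random.default_rng(seed)
    z = rng.standard_normal(n)
    return float(np.mean(np.vectorize(math.erf)(a + b * z)))


for a, b in [(0.5, 1.2), (-1.0, 0.3), (2.0, 2.0)]:
    print(f"a={a}, b={b}: closed form {erf_gaussian_mean(a, b):.4f}, "
          f"Monte Carlo {monte_carlo_mean(a, b):.4f}")
```

No comparably simple formula is available for softmax, since its normalization couples all tokens, which is consistent with the technical motivation given in this answer.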
Q2: You raise an interesting point. Indeed, one could imagine an MLP designed specifically for this task, where the weights have a diagonal structure with respect to the sequence index. In such a setup, the first layer could learn the projections along and , while subsequent layers could learn the link function (see above). However, this architecture is far from resembling those used in practice (and arguably halfway between attention and a standard MLP). If we do not assume a diagonal structure and instead use traditional MLPs, the number of parameters must scale at least linearly with the sequence length, which is highly suboptimal and may lead to very slow training. This highlights the efficiency of attention layers, which perform single-location regression with a fixed number of learnable parameters, independent of the input length. We will add a discussion on this matter.
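The parameter-count argument can be made explicit with a back-of-the-envelope sketch. The architectures below are assumptions for illustration only: a dense first layer acting on the flattened sequence, versus two direction vectors shared across positions (the hypothetical names `seq_len`, `d`, and `hidden` are not from the paper).

```python
def mlp_first_layer_params(seq_len, d, hidden):
    """Dense first layer on the flattened (seq_len * d)-dimensional input."""
    return seq_len * d * hidden + hidden   # weights + biases


def erf_attention_params(d):
    """Two d-dimensional directions shared by all positions."""
    return 2 * d


d, hidden = 16, 32
for seq_len in (10, 100, 1000):
    print(f"seq_len={seq_len}: MLP first layer {mlp_first_layer_params(seq_len, d, hidden)}, "
          f"attention-like {erf_attention_params(d)}")
```

The dense layer grows linearly with the sequence length, while the shared-direction predictor stays at a fixed size, matching the efficiency point made in this answer.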
Q3: The tokens are indeed independent conditionally on ; thank you for pointing this out. This independence is mainly used to simplify the cross-token interaction terms in the risk evaluation, especially in the proof of Lemma 6 (see e.g. line 994). We will include a mention of this in the next version of the paper.
Q4 and Q5: Thank you for the minor remarks. We will include the link to the code in the de-anonymized version of the paper.
I thank the authors for their responses and maintain my positive recommendation for their work, which I believe to be a good addition to ICLR. I hope this research inspires some conference attendees to delve deeper into Q2.
Thank you for acknowledging the rebuttal and your positive assessment of our paper.
This paper studies the ability of attention mechanisms to deal with token-wise sparsity and internal linear representations. In order to demonstrate the capability, the paper proposes a simplified version of a self-attention layer to solve the single-location regression task, showing an asymptotic Bayes optimality and analyzing training dynamics.
Strengths
The paper is well-written and easy to follow. A novel task called "single-location regression task" is introduced to satisfy the sparsity of the token and model real-world tasks to some extent. Despite the non-convexity and non-linearity, the paper is able to analyze the training dynamics and show the asymptotic Bayes optimality.
Weaknesses
- The proposed task may be over-simplified and lack generality. For instance, it assumes that the tokens other than have zero mean and only one token contains information.
- The paper shows the connection to a single self-attention layer by using the assumption that . Although the low-rank property may hold after the training process, it is too strong an assumption to make directly.
- It is uncommon to use the erf function to replace the softmax function. To demonstrate the feasibility of this simplification, more explanations or theoretical support should be provided.
- The paper shows the asymptotic results of , while a non-asymptotic result is needed to investigate the convergence rates of these two parameters.
- The initialization is limited to the specific manifold; it would be better to extend it to a more general one.
- The current experiments only validate the theoretical results on synthetic datasets. It is recommended that the authors consider adding some experiments on real datasets to test the effects.
Questions
See weaknesses.
We agree with the reviewer that our model is simplified with respect to practical settings, but our goal in this paper is precisely to present and analyze a simplified model in order to understand real-world phenomena in a tractable setting.
W1. The interesting but more complex case where the information is shared across multiple tokens instead of just one is related to multi-head attention. We have added an additional experiment and discussion on this topic (see the common answer for more details).
W2. and W3. Simplifying the architecture is a common approach in theoretical studies of Transformers, often involving adjustments such as linear attention, omitting skip connections, or removing normalization, among others. That being said, we also performed an additional experiment demonstrating that our simplifications do not alter the phenomena under investigation (see common answer for more details).
W4. Our proof technique relies on the use of the Center-Stable Manifold Theorem, a tool from dynamical systems theory. Unfortunately, this tool does not provide quantitative rates of convergence. Obtaining a rate is a challenging task as it would require quantifying the distance of the iterates to the saddle points of the risk (the dynamics is indeed slower near saddle points), which in turn requires other tools of analysis and potentially additional assumptions. We will include this discussion in the next version of the paper.
W5. We indeed restrict the analysis of the dynamics when initializing on the invariant manifold, as discussed in the paper (lines 468-475). Our numerical experiments suggest that similar convergence behaviors should hold for a more general initialization. A possible direction to extend the proof is discussed in lines 500-505, requiring in particular to show that the invariant manifold is stable in the sense of dynamical systems (i.e., if the dynamics are close to the manifold at a given time, they remain close to the manifold).
W6. Our approach is to start from the experimental evidence accumulated in the literature and provide a simplified model with similar patterns and tractable analysis. As a result, the reviewer is correct in stating that our contribution is mostly theoretical in nature, and does not consist of adding additional real-world empirical evidence to already well-established phenomena such as sparsity or internal linear representations. Besides, our experiment on linear probing on a pretrained BERT architecture, which is intended to verify the role of the [CLS] token as well as how BERT handles information contained in a single token, is closer to practice, although we acknowledge that the NLP task is synthetic.
I appreciate the detailed response given by the authors. As my concerns have been addressed to some extent, I will increase the rating from 5 to 6.
Thank you for acknowledging the rebuttal and increasing the score!
This paper introduces a new theoretical approach to understanding attention mechanisms by considering a task called single-location regression. The authors consider a simplified model related to a self-attention layer and show that the model achieves asymptotic Bayes optimality, while linear regressors fail. The authors also show that PGD converges despite the non-convex loss function.
Strengths
- The paper is well-written and easy to follow. For example, the step-by-step illustration of how to connect the construction to the attention mechanism in Section 3 is helpful for understanding.
- The choice of model is good: using the [CLS] property in (5), which is observed in empirical studies, is natural and reasonable, and using erf as the nonlinear weight function is also reasonable. In general, the theoretical results are solid.
- The paper also contains some empirical results showing the convergence of the constructed model.
Weaknesses
- The setting is restricted to a single informative token position. Although it focuses on the sparse attention setting and is already difficult to analyze, it is still far from the real-world case. Besides, the authors have not run experiments on real-world tasks (such as the sentiment task shown in Figure 1) to support some claims in the paper, which may somewhat reduce the impact of the theoretical analysis. But in general it is already good as a theory-centric paper.
Questions
- What is the role of the input in Figure 1(a)? It seems that the Y label depends only on the output, and that the input is not related to the sentiment label?
W1. Our approach is to start from the experimental evidence accumulated in the literature and provide a simplified model with similar patterns and tractable analysis. As a result, the reviewer is correct in stating that our contribution is mostly theoretical in nature, and does not consist of adding real-world empirical evidence to already well-established phenomena.
For example, with respect to sparsity, multiple papers (some of which we cite in lines 141-142) show that a wide range of NLP tasks exhibit token-wise sparsity, meaning that the relevant information is concentrated in tokens. It is true that in our paper we consider the limiting case and provide an explanation of how attention layers deal with sparsity in this case. Studying the case is a very interesting next step, and we provide additional preliminary experiments in the revised version of the paper (see the common rebuttal for more details).
In addition, our experiment on linear probing on a pretrained BERT architecture, to verify the role of the [CLS] token as well as how BERT handles information contained in a single token, is closer to practice, although we acknowledge that the NLP task is synthetic.
Q1. Regarding Figure 1, the synthetic data is described in detail in Appendix E.1. To summarize, the label is indeed related to the input sentence, since we set the label to when a positive sentiment adjective (such as nice, cute, etc.) appears in the sentence, and to in the case of a negative sentiment adjective. The color shading in Figure 1(a) was intended solely for visual illustration, but we acknowledge that the current figure and caption may be confusing (especially because we use both "output" and "label" in the caption to refer to the same quantity ). This will be clarified in the next version.
Thanks for your detailed experimental results. I will keep my positive score.
Dear Reviewers,
We sincerely thank you for your time and feedback on our paper, which will help us improve our work, and we appreciate your overall positive assessment. Thank you!
We emphasize that we are fully aware of the simplifications in our model with respect to practical settings. However, our goal in the present paper is precisely to present and analyze a simplified model in order to understand real-world phenomena in a tractable setting. It turns out that, even in this simpler setting, the analysis involves advanced mathematical tools, mainly from dynamical systems theory.
Below we answer two (related) questions raised by several reviewers: (1) the connection to practical models, and (2) the replacement of softmax by an element-wise nonlinearity.
(1) Connection to practical models: Following the reviewers’ comments, we performed additional experiments that demonstrate the relevance of our simplified architecture as a model of Transformer layers. Specifically, we train a full Transformer layer (consisting of a standard single-head attention layer and a token-wise MLP, with skip connections and layer normalization) on single-location regression. We find that this Transformer layer is able to solve single-location regression. Furthermore, the underlying structure of the problem (namely the parameters and ) is encoded in the weights of the trained Transformer, as is the case in our simplified architecture. We also investigate the more complex variant of multiple-location regression, where the output depends on multiple input tokens (with multiple and ), and the associated case of multihead attention.
We have uploaded a revised version of the paper with the new experimental results at the end of the appendix (and will move some of these results into the main text in the camera-ready version, if accepted). Of course, if the paper is accepted, we will also incorporate content related to the other comments of the reviewers.
(2) Replacing softmax: The role of softmax in attention is an active topic of discussion in the community, as reviewer 1VPX points out. On the practical side, many papers investigate other nonlinearities [1]-[3]. On the theoretical side, simplifying softmax to a linear activation, for example, is fairly standard [4]-[6]. In our mathematical analysis, the choice of an elementwise nonlinearity is key for the mathematical derivations. For example, we are able to compute a closed-form formula for the risk (Lemma 6), which relies on computing expectations of functions of Gaussian random variables (Lemma 18), which is specific to the erf activation. Nevertheless, as highlighted above, we observe numerically that a standard attention layer with softmax exhibits similar behavior to our simplified model. This suggests that softmax does not play a significant role in our context, and thus it is reasonable to consider the simpler case of elementwise nonlinearity. We will include this discussion in the next version of the paper.
With our thanks,
The authors.
[1] cosFormer: Rethinking Softmax in Attention, Qin, Sun, Deng, Li, Wei, Lv, Yan, Kong, Zhong, ICLR 2022
[2] A Study on ReLU and Softmax in Transformer, Shen, Guo, Tan, Tang, Wang, Bian, arXiv 2023
[3] Replacing softmax with ReLU in Vision Transformers, Wortsman, Lee, Gilmer, Kornblith, arXiv 2023
[4] Transformers learn to implement preconditioned gradient descent for in-context learning, Ahn, Cheng, Daneshmand, Sra, NeurIPS 2023
[5] Trained Transformers Learn Linear Models In-Context, Zhang, Frei, Bartlett, JMLR 2024
[6] How do Transformers perform In-Context Autoregressive Learning?, Sander, Giryes, Suzuki, Blondel, Peyré, ICML 2024
All reviewers find the paper well written, presenting an analysis of attention learning a rather simple problem. The authors argue that such a thorough analysis is missing for attention and that, even though it is done for a simpler setting, it involves non-trivial complexities. Other concerns were around model simplifications, such as removing softmax in attention for the analysis. Overall, I think this is a borderline paper.
Additional Comments on Reviewer Discussion
Reviewers expressed concerns about the simple problem setting considered in the paper. The authors highlight the non-trivialities in the analysis techniques.
Accept (Poster)