Quality over Quantity in Attention Layers: When Adding More Heads Hurts
We prove that attention layers with a small query/key dimension can have weak expressive power. This weakness cannot be overcome without using a massive number of attention heads.
Abstract
Reviews and Discussion
This paper questions the conventional choices the community uses for the rank of attention matrices and the number of attention heads. The authors provide in-depth theoretical explanations and experiments to support their arguments. In the high-accuracy regime, the required number of heads grows exponentially in order to maintain performance.
Strengths
- Transformers have dominated many areas, but there have been few studies on the choice of the number of attention heads and the dimensions used in the attention mechanism. This paper raises doubts about these conventions, which is valuable for the community.
- The paper is well written and easy to follow. In-depth theoretical explanations are provided.
- For a simple and natural target function -- the nearest neighbor function -- the authors show that low-rank attention is fundamentally weaker than full-rank attention, even when the number of heads is very large.
- The paper also explores solutions to mitigate the weakness of low-rank attention.
Weaknesses
- The paper is limited to shallow transformers, which are not representative of practical large models.
Questions
- I am wondering why the authors chose to analyze nearest neighbor functions, and whether there are other choices of functions.
We thank the reviewer for the positive feedback.
Weaknesses:
- We acknowledge that our lower bounds apply for shallow transformers, although Section 6 studies how depth can mitigate the low-rank issue up to some extent. This is an active direction we are working on. However, our experiments do go beyond shallow transformers. Figure 2 uses 5-layer transformers, while Figure 1 uses a 12-layer GPT-2 architecture. In both cases, the low-rank bottleneck persists. It is worth mentioning that also for neural networks, such exponential separation results are shown only for shallow networks. In fact, in “Neural Networks with Small Weights and Depth-Separation Barriers, Vardi & Shamir, 2020”, it is shown that proving such exponential separation results for deeper neural networks would solve a longstanding open question in circuit complexity, and thus might be very difficult. We strongly suspect that such complexity-theoretic barriers also apply to proving separation in deeper transformer models.
Questions:
- The main reasons for using this target function are (1) it has lots of symmetry and (2) it is natural and well-motivated. By symmetry, we mean that it is both permutation invariant (permuting the x_i’s doesn’t change the result) and rotationally invariant (applying any orthogonal matrix to the x_i’s and to y doesn’t change the result). Permutation invariance is beneficial since we don’t have to deal with positional encodings. Rotational invariance helps us prove a strong lower bound. The intuition is that the target function requires “knowledge” of all possible directions. In the analysis we use the fact that we can choose exponentially many directions that are “almost orthogonal” to each other. A low-rank approximation of this target function would have to cover each of these directions separately, which cannot be done efficiently. There could be other target functions that also have these symmetries, but they may be less natural than the nearest neighbor function. For example, in [Depth separation for neural networks, Daniely, 2017], the author uses a different target function to prove exponential separation in feed-forward neural networks. We think that this function can also be adapted for our uses in transformers, but it seems less natural.
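For concreteness, here is a minimal sketch (our own illustrative code, not taken from the paper) of the nearest-neighbor target and the two symmetries mentioned above; the exact input convention in the paper may differ:

```python
import numpy as np

def nearest_neighbor_target(y, xs):
    """Return the point x_i closest to the query y in Euclidean distance.

    Permutation invariance: shuffling the rows of xs does not change the output.
    Rotational invariance: replacing y -> Q y and each x_i -> Q x_i (Q orthogonal)
    rotates the output by the same Q, so the task itself is essentially unchanged.
    """
    dists = np.linalg.norm(xs - y, axis=1)  # distance from y to each x_i
    return xs[np.argmin(dists)]

# Example: N = 8 points and a query drawn uniformly from the unit sphere in R^16.
rng = np.random.default_rng(0)
d, N = 16, 8
xs = rng.normal(size=(N, d))
xs /= np.linalg.norm(xs, axis=1, keepdims=True)
y = rng.normal(size=d)
y /= np.linalg.norm(y)
print(nearest_neighbor_target(y, xs))
```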
The paper "THE LOW-RANK BOTTLENECK IN ATTENTION" investigates the impact of the rank of attention matrices on the representational capacity of attention-based mechanisms, particularly in transformers. It challenges the common practice of using low-rank attention and proposes that the rank can significantly influence the model's ability to approximate certain functions. Specifically, the authors present a simple and natural target function based on nearest neighbor search that can be represented using a single full-rank attention head for any context length. The paper presents theoretical analysis and empirical experiments to support its claims, suggesting that increasing the rank or the number of attention heads may lead to more expressive and parameter-efficient models.
Strengths
- Novel perspective on attention mechanisms: The paper offers a fresh perspective on the role of rank in attention mechanisms by using a simple and natural target function based on nearest neighbor search that can be represented using a single full-rank attention head for any context length, which is an interesting aspect of transformer architectures.
- Theoretical and empirical rigor: It combines theoretical proofs with empirical experiments, providing a robust exploration of the implications of low-rank attention on model capacity and efficiency.
Weaknesses
- The results may rely heavily on the assumption of rotational invariance in the data distribution, which may not hold in all real-world applications.
- To make it easier for readers to understand, I kindly suggest that the authors explain in more detail the differences between this paper and previous work [1]. [1] Low-Rank Bottleneck in Multi-head Attention Models. ICML 2020.
- Can the proposed method demonstrate its effectiveness on more attention-based models?
Questions
Please refer to the Weaknesses
We thank the reviewer for the positive and constructive feedback.
To address the weaknesses:
- This is a good point. We note that rotationally invariant data is a very common assumption in papers that prove exponential separation between different architectures, e.g. [Eldan & Shamir, 2016], [Daniely, 2017], [Safran & Shamir, 2017]. As in our own proofs, the intuition is that the target has energy in all possible directions, and in high dimensions there are exponentially many “almost orthogonal” directions that must be captured (a small numerical illustration of this fact appears after this list). It would be interesting to extend the results beyond this assumption, perhaps using techniques from “Depth separation beyond radial functions” (Venturi et al., 2022). However, we leave this to future work.
- Thank you for the feedback. In the final version, we can improve the comparison to this paper, which currently appears in lines 139-144. In a nutshell, their experiments are more realistic than ours, but our theoretical analysis is far stronger than theirs. Their experiments use three real-world NLP datasets, while we experiment only on synthetic tasks. They prove two theorems. The first says that a single low-rank attention head cannot exactly represent all possible stochastic matrices given any possible input sequence. However, this does not prevent a low-rank attention layer from approximating any particular function, especially since multiple heads can work together to produce richer attention behavior. The second states that increasing the rank while fixing all other parameters strictly increases the representational power. However, (1) it is unclear whether this difference matters for any natural function, and (2) it does not control for the additional computational cost of using a larger rank. It is not surprising that a model with more parameters is more powerful. Our work derives approximation upper and lower bounds which allow for a fine-grained quantitative analysis. Namely, we prove a trade-off between the rank and the number of attention heads required to even approximate the target with a certain accuracy, and this trade-off strongly favors full-rank attention.
- This is a good question. Our results apply to generalized attention layers as defined in lines 210-223. Thus, the low-rank bottleneck likely affects many attention or attention-like models besides transformers and softmax attention. It might also be interesting to explore analogies of our results for state-space models (e.g. Mamba) using the framework of “Transformers are SSMs”, Dao and Gu 2024.
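As promised in the first point above, here is a small numerical illustration (our own sketch, not taken from the paper) of the fact that in high dimension many random directions are pairwise “almost orthogonal”; the dimension and number of directions below are arbitrary:

```python
import numpy as np

# In high dimension d, many random unit vectors are pairwise "almost orthogonal":
# the maximum pairwise |inner product| stays well below 1 even for thousands of
# directions. This only illustrates the phenomenon invoked in the proofs.
rng = np.random.default_rng(0)
d, k = 512, 2000
v = rng.normal(size=(k, d))
v /= np.linalg.norm(v, axis=1, keepdims=True)
gram = np.abs(v @ v.T)
np.fill_diagonal(gram, 0.0)
print(gram.max())  # roughly 0.2 for these sizes -- far from 1
```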
This article examines the limitations and potential of low-rank attention mechanisms in transformer models, demonstrating that low-rank attention requires significantly more heads to match the performance of full-rank models in approximating functions such as nearest neighbors, and that these limitations can be partly mitigated by increasing the depth of the model. Through theoretical analysis and empirical validation, the study highlights that full-rank models inherently possess superior representational capabilities, especially with fewer heads, and suggests that adding more layers could partly overcome the deficiencies of low-rank models, though at the cost of increased computational complexity.
Strengths
- A deeper exploration of the low-rank problem in Transformer models.
- The paper is well written and easy to follow.
- Authors provide ample mathematical proofs to support their conclusions.
Weaknesses
- The authors mention after Theorem 2 that the theoretical framework should be extendable to cases where N > 2. Could you provide a more specific explanation of the reasoning behind this inference? This would help further understanding of the applicability of your theory to specific problems.
- Although the authors have demonstrated theoretically and experimentally that low-rank attention models are insufficient for fitting certain functions in various scenarios and are significantly weaker than full-rank attention models, further clarification is needed on how these issues impact current mainstream Transformer models (such as the newer models listed in the table in Appendix B.1), how the proposed methods in the paper apply to these models, and how performance improvements are achieved. I believe that related experimental results and methodological extensions would greatly help illustrate the contribution of the paper.
Questions
- Could simple experiments or additional references to other studies and conclusions be designed to intuitively show the impact of the low-rank problem on the performance of mainstream Transformer models?
- Could you further elaborate on how the proposed “majority voting” method for improving low-rank models enhances mainstream Transformer models and validate this with relevant experiments? For the experiments, model selection could refer to those in Appendix B.1 and the models used in [1][2], while the datasets could refer to those in [1][2] or other widely recognized benchmark datasets. [1] Bhojanapalli S, Yun C, Rawat A S, et al. Low-rank bottleneck in multi-head attention models[C]//International conference on machine learning. PMLR, 2020: 864-873. [2] Shazeer N, Lan Z, Cheng Y, et al. Talking-heads attention[J]. arXiv preprint arXiv:2003.02436, 2020.
- Also, please refer to weaknesses for other concerns.
We thank the reviewer for the constructive feedback.
To address the weaknesses:
- The intuition for why a lower bound like Theorem 2 should hold for N>2 is as follows: If N=1, the task is easy, because the only point to attend to is the correct answer. When N=2, the task is hard, because there is a second point that “distracts” our attention heads away from the right answer. Any attention paid to this distractor point is wasted; it does not help us represent the target function at all, since the points are orthogonal (a small numerical sketch of this orthogonality argument appears after this list). If N=3, the task is even harder, as there are now two orthogonal points distracting the attention heads. This reasoning is supported by Figure 7 in Appendix B.2, which shows that the task gets harder and harder for low-rank attention as N increases. We also want to emphasize that while Theorem 2 in its current form assumes N=2, Fact 1 holds for all N. (Likewise, the upper bound part of Theorem 3 holds for all N.) We think we can generalize the proof of Theorem 2 to N>2 points that are drawn iid from the sphere, without the assumption of orthogonality. We are actively working on this and hope to include it in the next revision.
- We acknowledge that our contributions are largely theoretical, and we think our paper can best be judged in comparison to other work in the field of representational capacity. While our work motivates important questions for practitioners working with real-world datasets, future work will be needed to answer them. But please keep in mind that some of the architectures used in our experiments are not so unrealistic. The architecture used in Figure 1 is a standard 12-layer GPT-2 transformer, except with a smaller embedding dimension. In addition, other papers (Bhojanapalli et al., Yang et al.) have observed the effects of the low-rank bottleneck on real-world datasets and production-scale models, as described in the following comment. As for our proposed methods, we are not proposing changes to the architecture or training procedure, only to the hyperparameters. Specifically, we propose that people try increasing the rank used in their transformers; to control for the extra computational cost, they can decrease some other hyperparameters, such as d and H. Figure 1 shows that this simple change can yield huge performance improvements without increasing computational cost at all.
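As mentioned in the first point above, here is a small illustrative sketch (our own, not part of the paper's proof) of why attention spread over orthogonal distractors is wasted, and why this worsens as N grows:

```python
import numpy as np

# Illustration of the "distractor" intuition with mutually orthogonal points.
# With orthogonal x_i, any attention weight placed on a wrong point contributes
# nothing along the correct direction, so the useful signal shrinks as N grows.
rng = np.random.default_rng(0)
d = 32
for N in [2, 3, 4, 8]:
    xs = np.linalg.qr(rng.normal(size=(d, N)))[0].T  # N mutually orthogonal unit vectors
    target_idx = 0
    # A hypothetical attention head that cannot separate the points well spreads
    # its weights; uniform weights are taken here as an extreme illustration.
    weights = np.full(N, 1.0 / N)
    output = weights @ xs
    print(N, output @ xs[target_idx])  # component along the correct point: 1/N
```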
To address the questions:
- A simple way to demonstrate the impact of the low-rank bottleneck is Figure 1, which uses the same task and architecture as Garg et al. This architecture is exactly that of GPT-2 (which is certainly mainstream), but with a smaller embedding dimension. The experiments of Bhojanapalli et al. demonstrate the impact of the low-rank bottleneck on the well-known BERT-Large transformer architecture, as applied to three real-world NLP datasets. Since submission, we also found an important reference to the dangers of setting the attention rank too small in the muP paper (Yang et al., cited below), which has been extremely influential on the way practitioners tune the hyperparameters of their transformers. They find that transformers start to behave badly when the rank is too small, so in some cases they set the rank larger than the standard d = Hr rule prescribes. See Appendices D.4 and E.2 of Yang et al. (which refer to the rank as d_head or d_k). The theoretical studies of Sanford et al. and Mahdavi et al. can also provide helpful intuition. If there are other papers that we missed, we’ll be happy to reference them in the final version.
- We would like to clarify that the “majority voting” construction described in Section 6 is not a new method, and is used only in the proof of Theorem 5. To prove Theorem 5, which is of the form “There exists a transformer such that X”, we construct a transformer by hand that has property X. In practice, of course, the weights of a transformer should be learned via training, not set by hand. It is possible that the reviewer is asking about the concatenated positional encodings described in the third paragraph of Section 6, not about the “majority voting” construction itself. In Theorem 5, we assume that the input points have been preprocessed by appending a 2-dimensional positional encoding to each. This preprocessing slightly enhances the power of the model, but it is not novel. Mapping the input points into a higher dimension and using absolute positional embeddings are both common; for example, the architecture used in Garg et al. and in our Figure 1 includes both of these features. We simply use them in a particular way to facilitate our proof. In practice, concatenated positional embeddings do not significantly affect the low-rank bottleneck. In sum, we are not proposing any changes to the transformer or attention architecture, only to the hyperparameters. As for the papers referenced in the review, our message is quite similar to Bhojanapalli et al., but we provide a much stronger mathematical justification. The introduction of Shazeer et al. describes the low-rank bottleneck, but their “Talking Heads” architecture does not help fix it in our setting. In fact, their architecture fits into our framework of generalized attention, so our lower bound applies to it. If there is a part of this question that we have not understood, please let us know and we would be happy to discuss it further.
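For illustration, a hedged sketch of the preprocessing step described above, i.e. appending a 2-dimensional positional encoding to each input point; the particular cos/sin encoding below is our own choice and may differ from the one used in the proof of Theorem 5:

```python
import numpy as np

def append_positional_encoding(xs):
    """Append a 2-dimensional positional encoding to each of the N input points.

    xs has shape (N, d); the output has shape (N, d + 2). The cos/sin encoding
    below is only one possible choice of a 2-d encoding.
    """
    N = xs.shape[0]
    angles = 2 * np.pi * np.arange(N) / N
    pos = np.stack([np.cos(angles), np.sin(angles)], axis=1)  # shape (N, 2)
    return np.concatenate([xs, pos], axis=1)

xs = np.random.default_rng(1).normal(size=(5, 16))
print(append_positional_encoding(xs).shape)  # (5, 18)
```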
Yang, Greg, Edward J. Hu, Igor Babuschkin, Szymon Sidor, Xiaodong Liu, David Farhi, Nick Ryder, Jakub Pachocki, Weizhu Chen, and Jianfeng Gao. "Tensor programs v: Tuning large neural networks via zero-shot hyperparameter transfer." arXiv preprint arXiv:2203.03466 (2022). https://arxiv.org/pdf/2203.03466.
This paper presents a theoretical analysis of the role of rank within attention mechanisms. It challenges the prevailing practice of employing low-rank attention and discusses the implications related to the selection of the number of heads. The author establishes that low-rank attention exhibits inferior performance compared to full-rank attention, indicating that the adoption of a higher rank has the potential to enhance attention performance. Preliminary experiments are conducted utilizing toy examples with synthetic data.
Strengths
The theoretical analysis seems correct.
Weaknesses
- The rank of attention is a significant hyperparameter in the design of transformers. A common convention involves the utilization of low-rank attention, typically establishing the number of heads as H = d/r. This paper, however, contests this design choice, proposing that a higher rank can enhance performance. It is crucial to note that the paper does not address the speed-accuracy trade-off associated with this adjustment. It is widely recognized that high-rank attention may yield superior performance at the expense of increased computational costs. When evaluating overall performance, particularly in terms of accuracy within a predetermined computational budget, prevailing practices may ultimately provide more favorable outcomes.
- The experiments presented in this study lack robustness, as they are primarily limited to toy experiments. I would appreciate observing performance metrics derived from real-world data applied to standard transformer sizes. It is well established that theoretical performance often diverges from practical outcomes in deep learning; thus, empirical experimentation is essential.
- This work indicates that shallow transformers may experience limitations due to low-rank attention. However, it is imperative to ascertain how these limitations manifest in deep transformers, as shallow transformers are not commonly employed in practice. If this limitation has been substantially mitigated in deep transformers, it may render further examination of this issue unnecessary.
Questions
N/A
We thank the reviewer for the thorough review.
While our results raise interesting questions for experimentalists and practitioners—such as the impact of rank on real-world datasets and standard transformer sizes—we do not claim to answer those questions here. Rather, we think our work can best be judged in comparison to other theoretical papers in the field of representational capacity. There is a large literature in deep learning theory that mathematically proves separations between the representational capacities of different models (e.g. different depths, different architectures, etc.). Some of this literature is reviewed in Section 2 of our paper, in the paragraph “Depth-width trade-offs in neural networks” and the end of “Expressive power of transformers”. Our work, like that of Sanford et al., Mahdavi et al., and Likhosherstov et al. (discussed in “Limitations of low-rank attention”), extends this literature to the study of attention rank. The review states, “It is widely recognized that high-rank attention may yield superior performance,” but in fact, a mathematical basis for this claim is lacking in the literature. This paper is actually the first work to prove a strong separation between low- and high-rank attention with multiple heads. Thus, we believe our work will be of interest to the community.
We now address the specific points raised in the review:
- The review claims that “the paper does not address the speed-accuracy trade-off” associated with full-rank attention, but in fact, we account for this trade-off throughout. It is true that, for a single attention head, the trade-off is simple, unavoidable, and well-understood: a full-rank attention head is slower but more powerful, and a low-rank attention head is faster but less powerful. However, for multi-head attention, the trade-off is far more subtle and poorly understood. As the review notes, it is essential to control for “the computational budget” when comparing full-rank to low-rank attention layers. The number of parameters and the computational complexity of a multi-head attention layer both scale with dHr. Thus, it is fair to compare full-rank attention (r = d) that uses one head (H = 1) to low-rank attention (r < d) that uses H = d/r heads, since both have a computational cost on the order of d^2 overall (a small numerical sketch of this budget accounting appears after this list). If we allow low-rank attention to use more than d/r heads, we are giving it extra computational budget. Our theoretical results compare single-head full-rank attention to low-rank attention with H = d/r heads or even many more, so we are giving low-rank attention a huge advantage in computational budget. We prove that full-rank attention is superior despite this advantage. Put differently, our target function can be approximated up to any accuracy by a full-rank attention layer using a budget on the order of d^2. But to achieve a similar accuracy using low-rank attention requires exponentially many parameters and an exponentially large computational budget. Our experiments also account for the speed-accuracy trade-off. Throughout our experiments, we scale H inversely proportional to r, so that the speed of the low-rank and full-rank transformers is the same. In Figure 1, all five bars have the same computational budget, but the low-rank transformers perform much worse. Similarly, in Figure 2, the computational budget is fixed along each line; for example, all points along the blue line share the same per-layer budget. Here too, the full-rank transformer performs much better despite having the same speed.
- We emphasize that the goal of our experiments is to support the theoretical claims rather than to provide an extensive experimental study on real-world datasets. This is a theoretical paper, and our main contribution is a thorough theoretical study that proves novel separation results. It is true that “theoretical performance often diverges from practical outcomes in deep learning,” and this is why we perform experiments to confirm that the predictions of our theory hold in practice, even in more general settings than our proofs do. For example, the experiment described in Figure 1 shows that the low-rank bottleneck is not specific to the nearest neighbor task; it is a more general phenomenon. Our paper does not draw any conclusions about language modeling or other “real-world data”. However, studying synthetic tasks is often valuable. Indeed, the literature on machine learning theory is full of such studies. A case in point is the paper by Garg et al., which we replicate in Figure 1 with modifications to the attention rank. Though it studies a “toy” task of in-context regression, this paper has had a significant impact on the understanding of transformers. This task is now well-established in the literature; several examples of direct follow-up papers that study the same task are cited below. We believe that the dramatic effect of the low-rank bottleneck (see Figure 1) on such a well-studied and important task will be of interest to the community.
- While our theoretical results concern shallow transformers, our experiments use deep transformers. In fact, Figure 1 uses transformers that are 12 layers deep, as described in Appendix B.1. This is the same depth used by Garg et al. and by GPT-2, so it is realistic. All our experiments suggest that the limitations of low-rank attention persist for deep transformers.
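As promised above, a small sketch (with hypothetical numbers of our own, not taken from the paper's experiments) of the dHr budget accounting for full-rank versus low-rank configurations:

```python
# Query/key cost of a multi-head attention layer scales with d * H * r
# (each of the H heads has query and key matrices of size d x r).
def qk_budget(d, H, r):
    return d * H * r

d = 256
full_rank = qk_budget(d, H=1, r=d)         # one full-rank head: d^2 = 65536
low_rank = qk_budget(d, H=d // 32, r=32)   # 8 heads of rank 32: also 65536
print(full_rank, low_rank)  # equal budgets, very different expressive power
```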
Some papers that use the in-context regression task of Garg et al.:
Oswald, Johannes von, Eyvind Niklasson, E. Randazzo, João Sacramento, Alexander Mordvintsev, Andrey Zhmoginov and Max Vladymyrov. “Transformers learn in-context by gradient descent.” ICML 2022, oral.
Akyürek, Ekin, Dale Schuurmans, Jacob Andreas, Tengyu Ma and Denny Zhou. “What learning algorithm is in-context learning? Investigations with linear models.” ICLR 2023 notable top 5%.
Kwangjun Ahn, Xiang Cheng, Hadi Daneshmand, Suvrit Sra. “Transformers learn to implement preconditioned gradient descent for in-context learning”. NeurIPS 2023.
Y Bai, F Chen, H Wang, C Xiong, S Mei. “Transformers as statisticians: Provable in-context learning with in-context algorithm selection”. NeurIPS 2023, oral.
Ruiqi Zhang, Spencer Frei, Peter L. Bartlett. “Trained Transformers Learn Linear Models In-Context”. JMLR 2024.
Again, we thank the reviewers for their comments. We hope that our response has addressed all the issues raised by the reviewers and that they will consider updating their scores accordingly. If there are remaining questions, we would of course be happy to address those before the public discussion phase ends.
Many thanks, The authors
This paper investigates the rank of attention in the Transformer framework. The authors find that the traditional use of low-rank attention leads to inferior performance and imply that a higher rank would be advantageous. Although, as pointed out by some reviewers, this finding challenges the existing practice of low-rank attention heads, the authors justify it with theoretical analysis and experiments. In my understanding, this work is a first attempt to theoretically analyze the role of rank in attention; though it still has some limitations, e.g., the theory only applies to shallow networks and the experiments are conducted on toy datasets, the overall quality of this work passes the threshold of ICLR and it could be accepted.
Additional Comments from Reviewer Discussion
Reviewer zXCj's main concerns are the limitation of the theory to deeper Transformers and the validity of the simple experiments conducted on toy datasets. In the rebuttal, the authors mainly highlighted that their major novelty is a first attempt to theoretically analyze the role of rank in attention. Although Reviewer zXCj has not raised his/her score, I believe the authors' response can be accepted. Overall, the authors' rebuttal has addressed most of the concerns of the reviewers.
Accept (Poster)