Tensor Attention Training: Provably Efficient Learning of Higher-order Transformers
Abstract
Reviews and Discussion
The paper proposes a closed-form solution to compute gradients for Tensor Attention in higher-order transformer models efficiently. By utilizing polynomial approximation methods and tensor algebraic techniques, the algorithm achieves almost linear time n^{1+o(1)}, the same complexity as the forward computation, for both forward and backward passes under certain assumptions, which are proven to be necessary and tight in the paper. The theoretical results establish the feasibility of efficient higher-order transformer training and may promote practical applications of tensor attention architectures.
Strengths
- The paper proposes an algorithm that computes the backward gradient of tensor attention training in almost linear time by utilizing polynomial approximation methods and tensor computation techniques, based on the closed-form solution for the gradient of tensor attention provided in the paper.
- Through a hardness analysis, the paper proves that its assumption is necessary and tight, meaning that without this assumption no algorithm can achieve almost linear time for tensor attention gradient computation, which is very important for future related research.
Weaknesses
- The paper's assumptions are relatively strict and may be difficult to meet in real-world situations. The paper lacks discussion on the practical applicability and limitations of the method.
- The work should say more about the benefits of using a tensor-based attention mechanism; otherwise the motivation is quite weak, although I understand its theoretical novelty.
- The authors should discuss the tradeoffs of Algorithm 1 in more detail: namely, what it sacrifices for faster computation, and how it compares with matrix attention.
Questions
- The paper mentions that the "bounded entries assumption" needs to be met. Are there any studies or experimental results supporting the validity of this assumption in actual transformer models? If this assumption is not met in some practical scenarios, can your algorithm still remain efficient, or what adjustments need to be made?
We thank the reviewer for their valuable suggestions. We provide our response below and hope it will address your concerns.
W1 & Q1: The paper's assumptions are relatively strict and may be difficult to meet in real-world situations. The paper lacks discussion on the practical applicability and limitations of the method.
In Remark 5.3 on Line 391, we discuss this assumption in detail. Additionally, several recent works [3,4] have shown that large entries are very sparse in the attention matrix. This suggests that our algorithm could inspire many practical implementations. One straightforward approach is to handle large entries separately, as done in [4], and then apply our algorithm to the remaining parts. There is undoubtedly a broad algorithmic design space, and we hope our work provides valuable insights.
W2: The work should talk more about the benefits of using a tensor-based attention mechanism, because otherwise the motivation is quite weak, although I understand its theoretical novelty.
Thank you for your suggestions. Classical matrix attention mechanisms capture pairwise but not higher-order correlations, limiting their effectiveness in multimodal models. Tensor Attention [1,2] addresses this by capturing higher-order correlations. We add more motivation in the revision Line 51-53 and Line 57.
W3: The authors should talk more about the tradeoffs of Algorithm 1. Namely what it sacrifices for faster computation, and how does it compares with matrix attention.
Algorithm 1 sacrifices an approximation error to achieve a faster algorithm, as detailed in Theorem 4.3 and Theorem 5.2.
Compared to classical matrix attention, tensor attention captures higher-order information, but it incurs greater computational complexity when computed exactly. See our response to W2 for details.
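For illustration, here is a minimal PyTorch sketch (not the paper's Algorithm 1; the shapes, the 1/sqrt(d) scaling, and the variable names are our own assumptions) contrasting standard matrix attention with a naive exact third-order tensor attention built from a column-wise Kronecker product of keys and values; the tensor variant forms an n x n^2 score matrix, which is why exact computation scales cubically in the sequence length n:

```python
import torch

n, d = 64, 16
Q, K1, K2, V1, V2 = (torch.randn(n, d) for _ in range(5))

# Standard matrix attention: n x n scores, pairwise token interactions.
scores2 = torch.softmax(Q @ K1.T / d ** 0.5, dim=-1)            # (n, n)
out2 = scores2 @ V1                                              # (n, d)

# Naive exact third-order tensor attention: each query attends to pairs of tokens.
# The column-wise Kronecker (Khatri-Rao) product pairs up keys (and values).
K_pair = torch.einsum('id,jd->ijd', K1, K2).reshape(n * n, d)    # (n^2, d)
V_pair = torch.einsum('id,jd->ijd', V1, V2).reshape(n * n, d)    # (n^2, d)
scores3 = torch.softmax(Q @ K_pair.T / d ** 0.5, dim=-1)         # (n, n^2)
out3 = scores3 @ V_pair                                          # (n, d)
```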
[1] Sanford, C., Hsu, D. J., & Telgarsky, M. Representational strengths and limitations of transformers. NeurIPS’23.
[2] Alman, J., & Song, Z. How to capture higher-order correlations? generalizing matrix softmax attention to kronecker computation. ICLR’24.
[3] Sun, M., Chen, X., Kolter, J. Z., & Liu, Z. Massive activations in large language models. COLM’24.
[4] Han, I., Jayaram, R., Karbasi, A., Mirrokni, V., Woodruff, D. P., & Zandieh, A. Hyperattention: Long-context attention in near-linear time. ICLR’24.
I thank the authors for their clarification. I still have some doubts about the practical applicability of tensor attention, but I do believe this work provides some interesting insights, so I will keep my score as is.
We sincerely thank you for your valuable suggestions and comments. We appreciate your time and positive feedback!
The authors consider tensor attention within the framework of transformer models and propose an algorithm that calculates gradients for tensor attention with linear complexity. While standard attention captures pairwise interactions between tokens, tensor attention (the third-order variant) was suggested by [Sanford et al.; NeurIPS-2023] in the context of capturing triplet interactions. In the basic implementation, such a mechanism has cubic complexity on the forward pass, and in a subsequent work [Alman et al.; ICLR-2024] a new algorithm based on polynomial expansions was proposed that (theoretically) achieves linear complexity for forward passes. The authors, continuing this line of work, construct an algorithm for calculating derivatives with (theoretically) linear complexity, which potentially allows the method to be practically implemented and the theoretical cubic time complexity barrier to be overcome in both inference and training.
Strengths
- The work is neatly structured and the presentation is consistent.
- The formulations seem correct and the evidence rigorous.
- If successfully implemented in practice, tensor attention could be in demand in the ML community, as it has the potential to account for more complex correlations in data.
Weaknesses
- If in [Alman et al.; ICLR-2024] the lack of experimental confirmation of the proposed approach could be attributed to the absence of an effective method for calculating gradients, then in this work, as stated, the last ingredient for a successful implementation of tensor attention has been prepared. In this regard, it is confusing that no practical implementation or numerical experiments are provided.
- In the context of the previous point, the degree of scientific novelty raises certain doubts, and the work appears to be only an incremental development of the previous theoretical works [Sanford et al.; NeurIPS-2023], [Alman et al.; ICLR-2024].
- The list of references takes up 9 full pages. The reasons for such extensive citation are not entirely clear, and the usefulness of these references for the reader is questionable. For example, after the phrase "Moreover, tensors are crucial in numerous machine learning applications..." there are about 9 references to various works of little relevance to the subject of this manuscript. Instead, the authors could have referred here to one or two survey works.
Questions
We thank the reviewer for their valuable suggestions. We provide our response below and hope it will address your concerns.
W1: No practical implementation or numerical experiments are provided
We thank the reviewer for this suggestion. We refer the reviewer to the Empirical validation part of the global response.
W2: Novelty
We discussed our technical novelty in Line 440-453. Our key novel techniques are:
- We abstract the most challenging part (the highest time complexity operation) in high-order attention into a clear mathematical problem and provide a solution.
- We provide an analysis of the tensor attention gradient and the fast algorithm design, which utilizes several tensor lemmas tailored to our problem setup.
We believe that the complicated form of the tensor gradient presents a significant challenge to address. Designing an almost linear-time solution using advanced tensor calculus is non-trivial. Additionally, we provide a corresponding hardness lower bound. While building on previous work, we believe the problem we tackled has not been explored before and offers substantial contributions.
W3: References
Thank you for your suggestions. We have double-checked the related work section and kept only a few of the most important related works for each concept.
Dear Reviewer kn9h,
We hope we have adequately addressed your issues. We would be very grateful if you could provide feedback on our rebuttal since the discussion deadline is approaching in one day. If you require further clarification or have any additional concerns, please do not hesitate to contact us. We are more than willing to continue communicating with you.
Warmest regards,
Authors
Dear authors, I thank you for your clarifications and revision of your work! I think it's possible to raise my rating (3 -> 5) for the updated version.
We are glad that our response addressed your concerns, and we thank you for your thoughtful feedback. We appreciate your valuable time and the increased score!
This paper studies an approximation algorithm for computing gradients of high-order tensor attention. While the exact computation has cubic complexity, the authors prove that the proposed approximation algorithm has almost linear complexity. Also, the authors prove that the assumptions cannot be further weakened to achieve sub-cubic complexity.
Strengths
- It is important to study the gradient approximation of tensor attention, which could make this structure practically usable.
- For proving the results, the authors establish some results of tensor computation, which may be useful for related fields.
Weaknesses
- Since this paper aims to propose a fast training algorithm for tensor attention, the lack of experiments could be a weakness. I am wondering about the practicality of the proposed algorithm. Are there still large gaps to implementing or running experiments on these algorithms?
- I am wondering whether the approximation error would be a major issue when training large models in practice. For example, for large models with deep architectures, will these errors accumulate and cause training difficulties?
Questions
See above
We thank the reviewer for their valuable suggestions. We provide our response below and hope it will address your concerns.
W1: Lack of experiments
We thank the reviewer for this suggestion. We refer the reviewer to the Empirical validation part of the global response.
W2: Approximation error
The error will increase by a factor of L additively for an L-layer model, so the final error will be Lε, where ε = 1/poly(n) is the per-layer approximation error.
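For illustration, here is a minimal sketch of this accumulation bound, under the additional assumption (not stated in this thread) that each exact layer map is 1-Lipschitz:

```latex
% Sketch: if every layer is computed with additive error at most \epsilon and each
% exact layer map f_\ell is 1-Lipschitz (an assumption for this sketch only), then
\bigl\| \widehat{F}_L(x) - F_L(x) \bigr\|
  \;\le\; \sum_{\ell=1}^{L} \epsilon \;=\; L\,\epsilon ,
\qquad F_L := f_L \circ \cdots \circ f_1 , \quad \epsilon = 1/\mathrm{poly}(n).
```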
This paper derives an efficient algorithm to approximate the gradients of high-order transformers. Such models extend the classical self attention mechanism, where only pairwise interactions are modelled, to higher order interactions where e.g. each token can attend to every possible pairs of tokens in the sentence. This model can also be used to model interactions between different modalities or views of a same object.
The attention model itself was introduced previously, in 2023 and 2024, along with an efficient algorithm to approximate the forward pass computation. The contribution of the submission is to design an analogous efficient algorithm for fast approximation of the backward pass computation.
The main theoretical result shows that the proposed approximation algorithm can achieve an ε = 1/poly(n) approximation guarantee in almost linear time for computing gradients over a sequence of length n. The authors also provide a hardness analysis showing that their assumptions are tight, in some sense.
The contribution is only theoretical, no experiments are provided.
Strengths
- The problem of efficient computation of higher-order attention mechanisms is relevant.
- A hardness analysis is provided to strengthen the result.
Weaknesses
- The paper is poorly written and hard to follow. In my opinion, it is not ready for publication.
- No experiments are provided to demonstrate the effectiveness (and correctness) of the proposed analysis / algorithm.
- Some key aspects of the results are (at least to me) very difficult to get from the paper. E.g., what are the "assumptions" referred to in the statement "we proved the necessity and tightness of our assumption"? What were the key novel techniques needed to extend the previous approach to the gradient computation?
- The proof (in the appendix) is very difficult to follow and the results are hard to interpret because they are presented in a very convoluted manner. E.g., Lemma 4.1 (which is central to the derivation and understanding of the main theorem) involves a quantity F(x) which is only defined in the Appendix. Furthermore, the definition given in the appendix itself relies on other quantities defined previously in the appendix (S and W), which are themselves functions of other quantities previously defined (K, V, L, ...). After working backward through this series of definitions, one realizes that F(x) is actually a function of many key variables of the problem (A1, A2, A3, Y, ...). In the end it is thus very difficult to understand what Lemma 4.1 is supposed to highlight, since F(x) hides many dependencies on central variables of the attention computation. Note that not much context is given before the lemma to help/guide the reader. This is unfortunately not an isolated example.
- This is minor, but there are many grammatical imprecisions / errors that need to be fixed.
Questions
- Lines 260-261 are not clear at all: Z has not been defined, and it is impossible to guess what A1 to A5 are. A posteriori, after reading the next page, the reader can retrospectively guess what was meant by this statement, but it is, in my opinion, not reasonable to expect the reader not to stumble on these lines and be confused (as I was).
- Similarly, when reading Definition 3.8 it is not clear at all why the minimization is only w.r.t. X. Only after the definition is it explained that minimization w.r.t. the other variables is ignored because the minimization w.r.t. X is the computational bottleneck. In my opinion, the authors should explain this before the definition.
- Line 3.9: unless I am mistaken the log(n) bit model has never been introduced.
- Introducing the column-wise Kronecker product as "one kind of tensor operation" (line 59 and in other places) can be confusing. I would suggest directly calling it by its name along with a forward reference to the definition.
- The column-wise Kronecker product is known as the Khatri-Rao product. The row-wise Kronecker product is sometimes referred to as the face-splitting product (https://en.wikipedia.org/wiki/Khatri%E2%80%93Rao_product); this should be mentioned.
- Most of the proofs of the basic facts in the appendix can be found in the seminal survey (Kolda and Bader, 2009), e.g. C.7, C.11, C.12. Some are trivial, e.g. part 1 of C.10. While I understand the desire to be self-contained, this long list of basic facts intertwined with less obvious identities sometimes hinders the readability and understanding of the overall proof.
Typos / imprecisions
line 170: w.r.t -> w.r.t.
line 173: "to denote a length nd vector" is confusing. Directly give the definition: "to denote the length nd vector obtained by stacking the rows of A into a column vector".
line 176: "denote an identity matrix" -> "denote the identity matrix" (there is only one). Idem for the following one when defining the identity tensor (not "an" identity tensor).
line 181: "And we have X = mat(X)" is not a proper way to introduce a mathematical statement, in my opinion.
line 180: tesnorization
def 3.2 and 3.3. : "define matrix" -> "define the matrix"
[...] I stopped keeping track of such small typos / grammatical errors at this point. The paper needs a thorough proof read.
[1] Feng, G., Zhang, B., Gu, Y., Ye, H., He, D., & Wang, L. Towards revealing the mystery behind chain of thought: a theoretical perspective. NeurIPS’23.
[2] Liu, B., Ash, J. T., Goel, S., Krishnamurthy, A., & Zhang, C. Transformers Learn Shortcuts to Automata. ICLR’23.
[3] Merrill, W., & Sabharwal, A. The parallelism tradeoff: Limitations of log-precision transformers. Transactions of the Association for Computational Linguistics (ACL’23).
[4] Kolda, T. G., & Bader, B. W. (2009). Tensor decompositions and applications. SIAM review, 51(3), 455-500.
Dear Reviewer tbYL,
We hope we have adequately addressed your issues. We would be very grateful if you could provide feedback on our rebuttal since the discussion deadline is approaching in one day. If you require further clarification or have any additional concerns, please do not hesitate to contact us. We are more than willing to continue communicating with you.
Warmest regards,
Authors
We thank the reviewer for their valuable suggestions. We provide our response below and hope it will address your concerns.
W1, W5, and Typos: The paper is poorly written and hard to follow.
Thank you for pointing this out. We appreciate your time and careful reading. We have refined our writing and fixed all typos. We refer the reviewer to our revised version. For the definition of mat, we have updated it in Line 167-168.
W2: No experiments are provided to demonstrate the effectiveness (and correctness) of the proposed analysis / algorithm.
We thank the reviewer for this suggestion. We refer the reviewer to the Empirical validation part of the global response.
W3.1: Some key aspects of the results are (at least to me) very difficult to get from the paper. E.g., what are the "assumptions" referred to in the statement "we proved the necessity and tightness of our assumption"?
Thanks for pointing this out. Our assumptions are stated in Definition 3.9, the definition of the problem, in Line 297-299.
- We assume the entries of the input matrices are bounded by a positive number B, i.e., every entry is at most B in absolute value.
- We assume all numbers in these matrices are in the log(n)-bit model (see our response to Q3 for details).
- We also rely on the assumption stated in Theorem 5.2, Line 381.
W3.2: What were the key novel techniques needed to extend the previous approach to the gradient computation?
We discussed our novelty in Line 440-453. Our key novel techniques are:
- We abstract the most challenging part (the highest time complexity operation) in high-order attention into a clear mathematical problem and provide a solution.
- We provide an analysis of the tensor attention gradient and the fast algorithm design, which utilizes several tensor lemmas tailored to our problem setup.
W4: The proof (in the appendix) is very difficult to follow and the results hard to interpret because they are presented in a very convoluted manner.
We provided a clear figure to show the variable dependencies in Line 270-288 in the revision.
Q1: Line 260-261 are not clear at all: Z has not been defined, it is impossible to guess what A1 to A5 are. A posteriori, after reading the next page, the reader can retrospectively guess what was meant by this statement but it is, in my opinion, not reasonable to expect the reader to not stumble on these lines and be confused (as I was).
Actually, we define these quantities as the hidden representations/input sequences in Line 42 and in Definition 3.4, Line 200. We revised the paper to be clearer in Line 258-259.
Q2: Similarly, when reading Definition 3.8 it is not clear at all why the minimization is only w.r.t. X. Only after the definition is explained that minimization w.r.t. the other variables are ignored because the minimization w.r.t. X is the computational bottleneck. In my opinion, the authors should explained this before the definition.
Thank you for your suggestions. We have reorganized the paper following your comments in Line 260-265.
Q3: Line 3.9: unless I am mistaken the log(n) bit model has never been introduced.
In Definition 3.9, we state that numbers in the matrices follow the log(n)-bit model, i.e., each entry in the matrix is represented by at most log(n) bits. This assumption is well-accepted and widely used in the computational complexity community, e.g., [1,2,3]. It aligns well with practical scenarios, where machine precision (e.g., 16 or 32 bits) is typically far smaller than the input context length (e.g., 128k). We have updated the above in Line 321-323.
Q4: Introducing the column-wise Kronecker product as "one kind of tensor operation" (line 59 and in other places) can be confusing. I would suggest directly calling it by its name along with a forward reference to the definition.
Thanks for your suggestion. We have revised this in the revision.
Q5: The column-wise Kronecker product is known as the Kathri-Rao product. The row-wise Kronecker product is sometimes referred to as the face-splitting product (https://en.wikipedia.org/wiki/Khatri%E2%80%93Rao_product), this should be mentioned.
Thanks for your suggestion. We have revised this in the revision, in Line 184 and Line 191.
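For readers less familiar with these names, here is a minimal PyTorch sketch of the two products (purely illustrative; the shapes and variable names are our own):

```python
import torch

A = torch.randn(3, 4)
B = torch.randn(5, 4)
# Column-wise Kronecker (Khatri-Rao) product: column c is A[:, c] kron B[:, c].
khatri_rao = torch.einsum('id,jd->ijd', A, B).reshape(-1, 4)      # shape (15, 4)

C = torch.randn(3, 6)
# Row-wise Kronecker (face-splitting) product: row i is A[i, :] kron C[i, :].
face_splitting = torch.einsum('ij,ik->ijk', A, C).reshape(3, -1)  # shape (3, 24)
```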
Q6: Most of the proofs of the basic facts in the appendix can be found in the seminal survey (Kolda and Bader, 2009), e.g. C.7, C.11, C.12. Some are trivial, e.g. part 1 of C.10. While I understand the desire to be self-contained, this long list of basic facts intertwined with less obvious identities sometimes hinders the readability and understanding of the overall proof.
We did find that [4] has many useful results. However, for the sake of completeness, we present proofs for those basic facts. We revised the paper to include citations for [4] in Line 945.
Deepest thanks to all reviewers for their thoughtful and valuable reviews.
We appreciate reviewers tbYL, egs4, kn9h, and ZjZd for recognizing the relevance and impact of our work on the efficient computation of tensor attention gradients. tbYL and ZjZd commend the importance of the hardness analysis, emphasizing its role in establishing the necessity and tightness of our assumptions. egs4 highlights the practical potential of gradient approximation for tensor attention and the broader applicability of our tensor computation results to related fields. kn9h values the structured and rigorous presentation, noting the potential demand for tensor attention in the ML community due to its ability to capture complex data correlations. ZjZd specifically acknowledges our almost linear time algorithm for tensor attention gradient computation, achieved through polynomial approximation and tensor techniques, as a significant contribution to advancing the field.
We have uploaded a revision of our draft and marked all major updates in brown. We also include the code for the experiments in the supplemental material. Here, we summarize all major updates in the revision. Note that all line numbers in the rebuttal correspond to the revised version.
- Line 167-168: Give a definition of function mat.
- Line 321-323: Give explanation bits model.
- Line 270-288: Provide a figure to illustrate the proof flow. And also mention the figure in Line 310
- Line 258-259: Provide the reference for the variable definitions.
- Line 260-265: Reorganize the paper to provide intuition first.
- Line 184 and Line 191: We give the other name of column/row-wise Kronecker product.
- Line 945: Add citation [9].
- Line 51-53 and 57: Add more motivation about high-order attention.
Here, we address common concerns. We will address other questions and concerns in individual rebuttals.
Empirical validation
We admit that the practical implementation of high-order attention to real-world tasks poses additional significant challenges. From the authors’ perspective, these challenges include but are not limited to (1) as the kernel module, i.e., attention, changes, numerous other modules, such as layer normalization, residual connections, positional embeddings, and many others, may require corresponding updates; (2) integrating this module directly into existing Large Language Models or Large Vision-Language Models to save training time is also a challenge. Despite these challenges, these uncertainties may open a broad space for algorithm design. We hope our work inspires new model architectures and further algorithmic design.
We have carefully checked how our algorithm could be practically implemented on GPU with the current deep learning package PyTorch, for numerical verification. Due to the time limit, we provide approximation results for our Lemma E.1 below; the approximation in Lemma E.1 is the key to our whole approximation process.
All input matrix elements are drawn from a standard Gaussian distribution. We compare the difference between the exact quantity and its approximation (see details in Lemma E.1).
| Feature dim d | Input token length n | Approximation Error |
|---|---|---|
| 2 | 10*2^d = 40 | 0.0938 |
| 4 | 10*2^d = 160 | 0.0490 |
| 6 | 10*2^d = 640 | 0.0291 |
| 8 | 10*2^d = 2560 | 0.0179 |
| 10 | 10*2^d = 10240 | 0.0124 |
| 12 | 10*2^d = 40960 | 0.0092 |
In the above table, we use d = O(log n) (since n = 10 * 2^d), as in our assumptions. We can see that the error becomes smaller as n becomes larger, which is consistent with our theorem that the error bound is 1/poly(n). This well supports the correctness of our theoretical bounds. This approximation may inspire a new training paradigm for large multimodal models.
We include the code for the above experiments in the supplemental material. We also provide the kernel code below. The key idea is to use torch.einsum() to implement the tensor operations.
import torch
import math

CHAR_LIST = 'abcdefghijklmnopqrstuvw'

def poly_order_approx(x, order=1):
    # x is a 4-D tensor; the output keeps the first three dimensions and replaces
    # the last (feature) dimension with all degree-`order` monomials of its entries.
    shape = list(x.shape)
    shape[-1] = -1
    denorm = (1 / math.factorial(order)) ** 0.5
    operation_in = [f'xyz{CHAR_LIST[i]}' for i in range(order)]
    operation_in = ','.join(operation_in)
    operation_out = [f'{CHAR_LIST[i]}' for i in range(order)]
    operation_out = 'xyz' + ''.join(operation_out)
    operation = operation_in + '->' + operation_out
    # e.g. order=5: torch.einsum('xyza,xyzb,xyzc,xyzd,xyze->xyzabcde', x, x, x, x, x)
    x_order = torch.einsum(operation, *([x] * order)).reshape(shape)
    # scaled so that the inner product of two such feature maps equals (q.T k)^order / order!
    return x_order * denorm
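As a hypothetical usage sketch (not part of the released code; the tensor layout, the number of Taylor terms, and the variable names below are our own assumptions), the order-wise features produced by poly_order_approx can be combined to approximate exp(<q, k>) through its Taylor series, which is how we read the role of this kernel in the approximation:

```python
# Assumes poly_order_approx (defined above) and torch are available.
torch.manual_seed(0)
d = 4
q = torch.randn(1, 1, 8, d) / d ** 0.5             # assumed layout: (batch, head, token, feature)
k = torch.randn(1, 1, 8, d) / d ** 0.5
approx = torch.ones(1, 1, 8, 8)                     # zeroth-order Taylor term
for order in range(1, 6):
    qf = poly_order_approx(q, order)                # (1, 1, 8, d**order)
    kf = poly_order_approx(k, order)
    approx = approx + qf @ kf.transpose(-1, -2)     # adds (q_i . k_j)^order / order!
exact = torch.exp(q @ k.transpose(-1, -2))
print((approx - exact).abs().max())                 # small truncation error
```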
We will check other bounds further and include them in the revision. We hope the above experiment addresses your concern, and we are glad to discuss it further per the reviewers' request.
On the other hand, we would like to clarify that our work is a purely theoretical study, which has an established place at top machine learning conferences such as NeurIPS, ICLR, and ICML. For instance, ICLR accepted papers [1,2,3,4], all of which are entirely theoretical and do not contain any experiments. Additionally, the papers [4,6,7,8] that focus on designing efficient algorithms for low-rank approximation and attention approximation also do not contain experimental results. Our work is similar to theoretical studies like [4,5,6], concentrating on the theoretical aspects of attention computation.
Moreover, we abstract the most challenging part (the highest time complexity operation) in high-order attention into a clear mathematical problem and provide a solution. Our work introduces a new concept to the community, suggesting that cubic time complexity may not be the bottleneck in implementing three-order attention during training. We believe that our extensive theoretical analysis can outweigh the lack of experimental results.
[1] Zhan, W., Uehara, M., Kallus, N., Lee, J. D., & Sun, W. Provable Offline Preference-Based Reinforcement Learning. ICLR’24.
[2] Chen, S., Chewi, S., Li, J., Li, Y., Salim, A., & Zhang, A. R. Sampling is as easy as learning the score: theory for diffusion models with minimal data assumptions. ICLR’23.
[3] Wen, K., Ma, T., & Li, Z. How Sharpness-Aware Minimization Minimizes Sharpness?. ICLR’23.
[4] Alman, J., & Song, Z. How to capture higher-order correlations? generalizing matrix softmax attention to kronecker computation. ICLR’24.
[5] Han, I., Jayaram, R., Karbasi, A., Mirrokni, V., Woodruff, D. P., & Zandieh, A. Hyperattention: Long-context attention in near-linear time. ICLR’24.
[6] Alman, J., & Song, Z. Fast attention requires bounded entries. NeurIPS’23.
[7] Sarlos, T., Song, X., Woodruff, D., & Zhang, R. Hardness of low rank approximation of entrywise transformed matrix products. NeurIPS’23.
[8] Dexter, G., Drineas, P., Woodruff, D., & Yasuda, T. Sketching algorithms for sparse dictionary learning: PTAS and turnstile streaming. NeurIPS’23.
[9] Kolda, T. G., & Bader, B. W. (2009). Tensor decompositions and applications. SIAM review, 51(3), 455-500.
This paper gives an algorithm for approximating the gradient of tensor attention in n^{1+o(1)} time under certain bounds on the weights, and shows a lower bound when the weights are not bounded. The reviewers have a large disagreement in their evaluation of this paper. Most reviewers agree that if the method worked in practice it would be interesting; however, there are major concerns about how practical the algorithms are, even after the authors included more empirical evidence in the response period (and this applies to both the most positive and most negative reviewers, except that they decided to focus on different aspects of the paper).
Additional comments from the reviewer discussion
The authors provided additional empirical evaluation as well as clarifications during the response period. One of the reviewers did not respond to the authors but participated in the reviewer discussion period. Overall, all reviewers still seem to be concerned about the applicability despite the empirical evidence, but weighted this differently.
Reject