The Fine-Grained Complexity of Gradient Computation for Training Large Language Models
A theoretical study of gradient computation, from both the algorithmic and hardness perspectives
Abstract
Reviews and Discussion
This paper studies the fine-grained complexity of gradient computation for training attention networks. In particular, the paper shows that, under the SETH assumption, when the norm of certain matrices is bounded by B = o(√(log n)), gradient computation can be done in almost linear time; an almost optimal algorithm is used in the theorem to provide the upper bound, which nearly matches the lower bound.
Strengths
The paper is theoretically well-rounded. The work is among the first few to study computational complexity lower bounds for gradients of attention networks.
Weaknesses
(Please reply to the Questions section directly) The work is purely theoretical and lacks experimental verification; the work is dedicated to one-layer attention networks and arguably may not apply to "large" language models
Questions
First of all, I am not very familiar with the strong exponential time hypothesis (SETH) or the related literature. I evaluate the work from LLM and numerical analysis perspectives, which might not be adequate. In general, I think this is an interesting work with the potential to guide practice, yet it seems not fully explored. I have the following questions for the authors:
-
I think the paper is somewhat over-claimed for "training LLMs" since the authors only consider one layer of attention network. How does the gradient computation error behave if we use backpropagation with multiple layers? There should be a remark or corollary for it.
-
Can the authors discuss the implications of the algorithm used for constructing the upper bound in Theorem 1.6 for real applications? For example, since we have a theoretically efficient algorithm when the norm is small, should we adopt a strategy to keep the norm smaller than a certain constant? Also, should we follow the theoretically efficient algorithm when computing the gradient in practice?
-
Continuing the above point, is there any numerical evidence showing that having a small norm, or using the algorithm from the upper bound construction, has an advantage in practice? This may greatly strengthen the point made in this work.
Limitations
The authors are pretty clear about their limitations.
Thank you for the thoughtful review. Indeed our focus is on the theoretical foundations of these LLM computational tasks, although we are optimistic that these results will help future empirical work as well. For example, the prior work on the theory of these problems which we extend here [AS23] has motivated a number of new fast algorithms for attention ([KMZ23,HJK+23,ZBKR24,YAH+24]).
To answer your questions:
Our algorithms and lower bounds can both apply separately to each attention head in a larger LLM. In other words, in training, you can separately apply the algorithm for each attention head and attention layer, still giving a fast algorithm for any size model. We will expand on this in the final version.
Indeed, a main takeaway of our results is that one should focus on maintaining parameters with smaller norms throughout the training process in order to achieve faster algorithms. This mirrors the phenomenon observed in practice that quantization or other parameter approximations often lead to faster algorithms [AS23]. And indeed, our approach could be used for gradient computations in practice, although as we discuss in section 6, a practical implementation would likely need more algorithmic engineering work.
Indeed, prior work on LLM implementations has observed a similar phenomenon, that algorithmic techniques like quantization [ZBIW19] and low-degree polynomial approximation [KVPF20], which require bounded or low-precision entries, can substantially speed up LLM operations. We discuss this in section 1 (lines 66-72) of our paper, and the introductions of those papers discuss the phenomenon in more detail.
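To make this concrete, here is a minimal sketch (our own illustration, not code from the paper) of how one might monitor the relevant entry bound B for a single attention head; the names X, W_Q, W_K, the numpy implementation, and the constant c are illustrative assumptions.

```python
import numpy as np

def entry_bound(X, W_Q, W_K):
    """B = max_{i,j} |(X W_Q)(X W_K)^T|_{ij} for an input X in R^{n x d}."""
    Q = X @ W_Q                       # (n, d)
    K = X @ W_K                       # (n, d)
    # Note: this naive check is itself quadratic in n; in practice one
    # would estimate B from a subsample of rows.
    return float(np.abs(Q @ K.T).max())

def in_fast_regime(X, W_Q, W_K, c=1.0):
    """True when B stays below c * sqrt(log n), the regime in which the
    almost linear time gradient algorithm applies (c is a stand-in constant)."""
    n = X.shape[0]
    return entry_bound(X, W_Q, W_K) <= c * np.sqrt(np.log(n))
```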
[AS23] Josh Alman, and Zhao Song. Fast attention requires bounded entries. In NeurIPS 2023.
[KMZ23] Praneeth Kacham, Vahab Mirrokni, and Peilin Zhong. Polysketchformer: Fast transformers via sketches for polynomial kernels. In ICML 2024.
[HJK+23] Insu Han, Rajesh Jarayam, Amin Karbasi, Vahab Mirrokni, David P. Woodruff, and Amir Zandieh. Hyperattention: Long-context attention in near-linear time. In ICLR 2024.
[KVPF20] Angelos Katharopoulos, Apoorv Vyas, Nikolaos Pappas, and François Fleuret. Transformers are RNNs: Fast autoregressive transformers with linear attention. In ICML 2020.
[ZBIW19] Ofir Zafrir, Guy Boudoukh, Peter Izsak, and Moshe Wasserblat. Q8bert: Quantized 8bit bert. In 2019 Fifth Workshop on Energy Efficient Machine Learning and Cognitive Computing-NeurIPS Edition (EMC2-NIPS), pages 36–39. IEEE, 2019.
[ZBKR24] Michael Zhang, Kush Bhatia, Hermann Kumbong and Christopher Re. The Hedgehog & the Porcupine: Expressive Linear Attentions with Softmax Mimicry. In ICLR 2024.
[YAH+24] Kai Yang, Jan Ackermann, Zhenyu He, Guhao Feng, Bohang Zhang, Yunzhen Feng, Qiwei Ye, Di He, Liwei Wang. Do Efficient Transformers Really Save Computation? In ICML 2024.
This paper presents a comprehensive analysis of the fine-grained complexity involved in gradient computation for training large language models (LLMs). Building on previous work that characterized the complexity of forward computations in LLMs, this study extends the analysis to backward computations. The authors establish a threshold for gradient computation analogous to the one found for forward computation. They develop a near-linear algorithm for scenarios where the bound is below the threshold and prove the impossibility of a subquadratic algorithm under the Strong Exponential Time Hypothesis (SETH) when the bound exceeds the threshold.
Strengths
- The analysis provides theoretical evidence supporting recent successes in the efficient computation of transformers.
- The conclusions offer valuable insights for future algorithm design.
- The paper is well-organized and clearly presented.
Weaknesses
While the theoretical contributions are strong, the paper lacks empirical validation of the conclusions and the proposed algorithm.
Questions
- How does the proposed algorithm perform when the bound exceeds the threshold?
- Based on your conclusions, is there a principle in practice for adjusting the strategy for forward and backward computations?
- What is the memory cost associated with the proposed algorithm?
Limitations
The authors discussed the limitations in Section 6.
Thank you for the thoughtful review. Indeed our focus is on the theoretical foundations of these LLM computational tasks, although we are optimistic that these results will help future empirical work as well. For example, the prior work on the theory of these problems which we extend here [AS23] has motivated a number of new fast algorithms for attention ([KMZ23,HJK+23,ZBKR24,YAH+24]).
To answer your questions:
-
Indeed, when the bound exceeds the threshold, our algorithm may give incorrect answers, as it is not designed to handle larger inputs. However, our lower bound result shows that this phenomenon is necessary: it is not possible to design a fast algorithm beyond the threshold using any algorithmic technique.
-
One of the major insights behind our algorithm is that the matrices arising in training/inference of the attention network can be expressed as the product of two low-rank matrices. Such an idea has been widely used in fine-tuning throughout the study of LLMs (such as LoRA [HSW+21]). In a sense, our method gives a theoretical foundation for such a practical method. LoRA is mainly used in fine-tuning, but our method is for backward computation, indicating that this low-rank idea may be applied successfully to the training stage as well.
-
The memory cost is also almost linear: the algorithm only performs matrix-vector multiplications, so the memory and time usage are both almost linear.
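As a concrete illustration of both points above (a minimal numpy sketch, under the assumption that the relevant n x n matrix is already given as a product of two explicit low-rank factors, as in the construction described in the previous answer; the sizes below are arbitrary):

```python
import numpy as np

# Sketch: once an n x n matrix is represented as U @ V.T with small rank r,
# multiplying it by a vector takes O(n r) time and memory, and the n x n
# matrix itself is never materialized.
n, r = 2048, 16
rng = np.random.default_rng(0)
U = rng.standard_normal((n, r))
V = rng.standard_normal((n, r))
x = rng.standard_normal(n)

y_fast = U @ (V.T @ x)      # two matrix-vector products: O(n r) time and memory
y_naive = (U @ V.T) @ x     # materializes the n x n matrix: O(n^2) time and memory

assert np.allclose(y_fast, y_naive)
```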
[AS23] Josh Alman, and Zhao Song. Fast attention requires bounded entries. In NeurIPS 2023.
[KMZ23] Praneeth Kacham, Vahab Mirrokni, and Peilin Zhong. Polysketchformer: Fast transformers via sketches for polynomial kernels. In ICML 2024.
[HJK+23] Insu Han, Rajesh Jarayam, Amin Karbasi, Vahab Mirrokni, David P. Woodruff, and Amir Zandieh. Hyperattention: Long-context attention in near-linear time. In ICLR 2024.
[ZBKR24] Michael Zhang, Kush Bhatia, Hermann Kumbong and Christopher Re. The Hedgehog & the Porcupine: Expressive Linear Attentions with Softmax Mimicry. In ICLR 2024.
[YAH+24] Kai Yang, Jan Ackermann, Zhenyu He, Guhao Feng, Bohang Zhang, Yunzhen Feng, Qiwei Ye, Di He, Liwei Wang. Do Efficient Transformers Really Save Computation? In ICML 2024.
[HSW+21] Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, and Weizhu Chen. LoRA: Low-rank adaptation of large language models. In ICLR 2022.
Thanks to the authors for answering my questions. I currently have no more questions. I will maintain my score.
This paper studies the complexity of gradient computation in training models with attention layers. The paper claims that their results establish the same complexity for the backward/gradient computation as for the forward computation. The results show the impact of the size of the entries of the attention matrices on LLM training.
Strengths
- The complexity of gradient computation for attention layers is an interesting and important problem for LLM training. The approach to complexity analysis can provide useful insights to the community.
- The complexity results are theoretically sound, supported by proof sketches.
Weaknesses
- The clarity of the paper's writing could be improved. The paper could define the problem more clearly in the introduction and present the contributions clearly. It would be better to follow a flow that first introduces the computation problem, briefly describes the techniques and assumptions used to derive the complexity results, and then presents the complexity results. Currently, these aspects are intermixed.
- The significance of the paper is unclear. The complexity of gradient computation aims to solve Approximate Attention Loss Gradient Computation, but the accuracy of this approximation compared to practical gradient computation is not studied. Additionally, the analysis requires the SETH assumption, whose validity is not discussed in the paper. Thus, it is unclear if the complexity presented in this paper accurately reflects practical cases.
Questions
- Could the authors explain the key differences in technique between this paper and existing work that also explores the complexity of forward and backward/gradient computations?
- Could we verify the SETH assumption for attention models?
Limitations
NA
Thank you for the thoughtful review. We aimed to organise the paper by first introducing the problem we study (Definition 1.4 in section 1.1) then explaining our two main results and how they relate to each other (section 1.2), explaining the context of other work in this area (section 2), and then describing how we achieved these results (section 3 for standard tools we use, and then section 4 introduces our new idea). We chose this flow, of stating and explaining both the algorithmic and complexity results before getting into technical details, because we think that the contrast between the two results is one of the most important aspects of the paper that we want to get across: that having a bound on the entries is both necessary (from the complexity result) and sufficient (from the algorithm) for fast LLM training.
To answer your other questions:
The existing work requires quadratic time to compute the gradient. Our new technique, of applying the polynomial/low-rank factorization idea to speed up the gradient computation, is novel and has not been used before in gradient computation to our knowledge. Moreover, to the best of our knowledge, we're the first to give an almost linear time (n^{1+o(1)}) algorithm to compute the gradient with provable guarantees.
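To give a flavor of this polynomial/low-rank idea, here is a self-contained toy sketch (not the construction from the paper: the plain truncated Taylor series, the explicit tensor-power features, the tiny dimensions, and the omission of the softmax normalization are simplifying assumptions made only for illustration). The point is that a low-degree polynomial approximation of exp turns the attention matrix into a product of two explicit low-rank factors whenever the entries of Q K^T are bounded.

```python
import numpy as np
from itertools import product
from math import factorial

def taylor_features(X, degree):
    """Map each row x to features phi(x) so that
    <phi(q), phi(k)> = sum_{t <= degree} (q . k)^t / t!  (truncated Taylor series of exp)."""
    n, d = X.shape
    blocks = [np.ones((n, 1))]
    for t in range(1, degree + 1):
        cols = [X[:, list(idx)].prod(axis=1) for idx in product(range(d), repeat=t)]
        blocks.append(np.stack(cols, axis=1) / np.sqrt(factorial(t)))
    return np.concatenate(blocks, axis=1)

rng = np.random.default_rng(0)
n, d, degree = 512, 3, 5
Q = rng.uniform(-0.6, 0.6, (n, d))   # bounded entries => bounded entries of Q K^T
K = rng.uniform(-0.6, 0.6, (n, d))

U = taylor_features(Q, degree)       # rank <= 1 + d + d^2 + ... + d^degree
V = taylor_features(K, degree)
low_rank = U @ V.T                   # entrywise approximation of exp(Q K^T)
exact = np.exp(Q @ K.T)
print(np.abs(low_rank - exact).max())  # small, because the entries of Q K^T are bounded
```

In this toy the rank is not much smaller than n; the structural point is only that the n x n matrix is replaced by two thin factors, which is what enables almost linear time and memory when the degree and dimensions are chosen appropriately.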
SETH is the most popular conjecture in the area of fine-grained complexity, and forms the backbone of many results in that area, showing hardness of problems from ML, graph algorithms, string algorithms, similarity search, and more. The fact that it gives tight lower bounds for a wide variety of different algorithmic problems, and that we don't know how to beat SETH lower bounds for any of these problems, gives strong evidence for SETH. As we discuss in section 2 and section 3.2, SETH is actually a strengthening of the P != NP conjecture, and so verifying it would require, at the very least, first solving the P vs NP problem (which is far beyond the scope of this paper).
The authors study Approximate Attention Loss Gradient Computation (AALGC) (i.e., the complexity of training LLMs with backprop):
- they show that the magnitude of the entries of certain matrices (layer inputs times QKV matrices) affects whether LLM backprop is "almost-linear" or not. This connects techniques like quantization to complexity.
- if the max entry of the products is bounded by B and B is o(√(log n)), it's almost-linear time, and when B is larger than that, a fast algorithm is impossible under some assumptions (SETH) (see the schematic display after this summary)
- for the almost-linear case, they construct a novel algorithm for AALGC
- SETH: the assumption is the Strong Exponential Time Hypothesis: for every eps > 0 there exists k >= 3 for which solving k-SAT with n variables in O(2^{(1-eps)n}) time is impossible (even with randomized algorithms)
- part of the construction involves a polynomial approximation to softmax that works well for non-large entries.
- the authors use some nice techniques (like the tensor trick) that not everyone working between theory and ML is familiar with, so the techniques in the paper are useful in their own right
- more generally, the style of the paper is nice in that the results are clear and it's very pedagogical
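Reading "almost-linear" as n^{1+o(1)} time and "impossible" as "no subquadratic algorithm" (consistent with the results as summarized above), the dichotomy can be written schematically as:

```latex
\[
\text{Approximate attention loss gradient computation:}\quad
\begin{cases}
  \text{solvable in } n^{1+o(1)} \text{ time} & \text{if } B = o\!\left(\sqrt{\log n}\right),\\[2pt]
  \text{no } O(n^{2-\delta})\text{-time algorithm for any constant } \delta > 0 & \text{if } B \text{ exceeds this threshold, assuming SETH.}
\end{cases}
\]
```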
Strengths
- the results, as mentioned in the introduction: extending results in complexity of forward computation to complexity of backprop
- discussion of how complexity changes when certain assumptions are made (magnitude of entries) or approximations are made (polynomial softmax)
- connection to practice: discussion of e.g. quantization practices in the context of theory
Weaknesses
-
the softmax approximation passes by a little too quickly in the paper; I had to scroll up and down throughout the paper several times to piece the definitions together. If you Ctrl+F (search) for "softmax" or "polynomial" in the paper, only a few results come up, and it's mostly references. It would be really great to add a dedicated subsection that says "this is the precise softmax approximation we use" and later, in the algorithm section, to point out clearly where exactly it's used.
-
could use a little more discussion on whether such bounds on B hold in practice (e.g. in some open source models)
Questions
-
I can't tell if it's just notation or not, but at some points like Definition 1.4 it seems like "n" refers to (1) the sequence length, (2) the number of bits in the bound B = o(√(log n)), and also (3) the n in d = O(log n), where d is the embedding dimension? Could you elaborate on which of these are tied together?
-
You allude to the near-linear time algorithms possibly being tough to implement and needing more algorithmic/hardware innovations to make practical. Could you elaborate on whether there are any particular things that come to mind?
-
AALGC seems to switch between an epsilon parameter and a B parameter around Definition 1.4
Limitations
Yes
Thank you for the thoughtful review. We especially appreciate the suggestions in the "weaknesses" section, and we will make these suggested changes in the final version.
To answer your questions:
You're right: our main problem definition (Definition 1.4) defines these as separate parameters, and n is always the sequence length. Our results discuss when fast algorithms are possible, and the answer depends on how these different parameters relate to each other, which is why we set some of the other parameters in terms of n. For instance, we show that if B is larger than the √(log n) threshold then no fast algorithm is possible, but if B is smaller then we achieve an almost linear time algorithm.
Actually, some research groups are currently working to implement these polynomial approximation approaches in practice. See, for example, the papers [KMZ23,HJK+23] we cite, which discuss this in a lot more detail.
In our paper, the parameter B gives a maximum norm for the input matrices, whereas epsilon is the allowed output error. E.g., in Theorem 1.6, "1/poly(n) accuracy" means that epsilon can be any 1/poly(n). We'll try to clarify this more in the writing.
[KMZ23] Praneeth Kacham, Vahab Mirrokni, and Peilin Zhong. Polysketchformer: Fast transformers via sketches for polynomial kernels. In ICML 2024.
[HJK+23] Insu Han, Rajesh Jarayam, Amin Karbasi, Vahab Mirrokni, David P. Woodruff, and Amir Zandieh. Hyperattention: Long-context attention in near-linear time. In ICLR 2024.
The paper under review delves into the complexities of Approximate Attention Loss Gradient Computation within the framework of training large language models. It presents a significant contribution by connecting the magnitude of certain matrix entries to the computational complexity of backpropagation, thereby linking quantization techniques to training efficiency. The authors establish a nearly linear time complexity for gradient computation when the maximum entry of the product of layer inputs and QKV matrices is bounded by a function that grows more slowly than the square root of the logarithm of the input size. Conversely, they argue that under the Strong Exponential Time Hypothesis, achieving subquadratic time complexity is unlikely if this bound is exceeded. Reviewers' concerns are properly addressed. Overall, the paper provides a solid theoretical foundation and meets the acceptance criteria.