FlashMask: Reducing the Complexity of Attention Computation through Sparse Mask Representation
Abstract
Reviews and Discussion
This paper proposes a novel method to address the high computational and memory complexity of current large-scale transformers. By adopting a simple yet effective column-wise sparse representation of attention masks, the algorithm achieves reduced memory and computational complexity while maintaining the accuracy of attention computation.
Strengths
- This paper investigates a topic of interest, given the current trend toward increasing context lengths in LLMs.
- The method proposed in this paper is straightforward and easy to implement.
- The paper is well-written and clearly presented.
Weaknesses
- It is crucial to highlight the advantages of this method over related work to help readers fully understand its significance. However, in the subsection "Attention Optimization Techniques," the authors only mention the drawback of FlashAttention and discuss its relationship to their work. The introduction of other related works is confusing and makes it difficult to comprehend their relevance to this paper. The overall conclusion, "Both of the previously discussed solutions either compromise precision or yield only marginal enhancements in efficiency. Conversely, our proposed FlashMask is capable of delivering exact computations." is general and non-specific. It is unclear which methods compromise precision and which ones only offer marginal improvements.
- In the experiments, the baseline algorithms are limited to Vanilla Attention and FlashAttention. Are there more efficient Transformer algorithms that could be used for comparison? If not, the authors should explain the rationale behind the selection of these specific baselines.
- As a non-expert in this field, I found the writing of this paper confusing. For instance, the initialism "HBN" is introduced without any explanation or context.
Questions
- Has there been any prior work on efficient attention mask computation?
- Why was the FA-Varlen method not included in the DPO and RM scenarios? Could the authors provide an explanation in their paper?
- In Figure 4, the latency of both FA-Window and FlashMask is almost identical. If the authors aim to demonstrate the efficiency of FlashMask, could they explain the experimental results more clearly?
- In Figures 3 and 5, it appears that the performance and efficiency of FA-Varlen are comparable to FlashMask in the SFT setting. Could the authors clarify this comparison?
Limitations
None
Thanks for your review.
For each weakness you mentioned:
- It is crucial to highlight the advantages of this method over related work to help readers fully understand its significance. However, in the subsection "Attention Optimization Techniques," the authors only mention the drawback of FlashAttention and discuss its relationship to their work. The introduction of other related works is confusing and makes it difficult to comprehend their relevance to this paper. The overall conclusion, "Both of the previously discussed solutions either compromise precision or yield only marginal enhancements in efficiency. Conversely, our proposed FlashMask is capable of delivering exact computations." is general and non-specific. It is unclear which methods compromise precision and which ones only offer marginal improvements.
Reply:
Vanilla Attention:
- Advantages: Simple implementation and easy to understand. It was initially proposed in the seminal paper "Attention Is All You Need".
- Disadvantages: Slow computation speed and high memory usage, with a quadratic relationship to sequence length, resulting in a complexity of O(N^2).
Memory Efficient Attention:
- Advantages: Lower memory usage with a complexity of , significantly reducing resource requirements.
- Disadvantages: Slower computation speed compared to IO-aware FlashAttention. See the kernel latency in Figure 3 of the attached PDF file.
FlashAttention:
- Advantages: Fast computation speed with no loss in accuracy compared to Vanilla Attention.
- Disadvantages: The official implementation supports only a limited set of masks and lacks the mask functionality required for downstream tasks such as SFT, DPO, and RM.
FlashAttention with DenseMask:
- Advantages: Third-party implementations support DenseMask in FlashAttention, enabling various masks for downstream tasks like SFT, DPO, and RM.
- Disadvantages: High memory usage with a quadratic relationship to sequence length, resulting in a complexity of O(N^2).
FlashAttention VarLen:
- Advantages: Fast computation speed and support for variable-length sequences, suitable for training tasks such as SFT.
- Disadvantages: Limited support for sparse computation modes, unable to support training tasks like DPO and RM.
Other Approximate Algorithms (e.g., Reformer and Linformer):
- Advantages: Achieve lower memory usage and faster computation speed through approximate sparse attention calculations.
- Disadvantages: Model convergence performance cannot match that of Full Attention, leading to reduced accuracy.
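To put the quadratic-versus-linear mask memory in concrete terms, the following Python sketch compares the bytes needed for a dense N x N mask with those needed for a handful of per-column int32 row-index vectors. The element sizes (a 2-byte fp16/bf16 dense mask, four int32 vectors) are illustrative assumptions, not measurements from the paper.

```python
# Rough memory footprint of an attention mask for one head and one sample.
# Assumptions (not from the rebuttal): the dense mask stores one element per
# (query, key) pair; the column-wise form stores a few int32 vectors of length N.


def dense_mask_bytes(n: int, bytes_per_elem: int = 1) -> int:
    """N x N mask, e.g. bool (1 byte) or fp16/bf16 (2 bytes) per element."""
    return n * n * bytes_per_elem


def column_mask_bytes(n: int, num_vectors: int = 4, bytes_per_elem: int = 4) -> int:
    """num_vectors int32 vectors of row indices, one entry per key column."""
    return num_vectors * n * bytes_per_elem


for n in (4_096, 32_768, 131_072):
    dense = dense_mask_bytes(n, bytes_per_elem=2)
    column = column_mask_bytes(n)
    # The ratio grows linearly with N (here N / 8), i.e. O(N^2) versus O(N).
    print(f"N={n:>7}: dense {dense / 2**20:10.1f} MiB, "
          f"column-wise {column / 2**20:8.3f} MiB, ratio {dense // column}x")
```

Under these assumptions the saving is a factor of N/8, so it widens as the context grows.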
- In the experiments, the baseline algorithms are limited to Vanilla Attention and FlashAttention. Are there more efficient Transformer algorithms that could be used for comparison? If not, the authors should explain the rationale behind the selection of these specific baselines.
Reply: As explained in the previous reply, Vanilla Attention and FlashAttention with DenseMask are the appropriate baselines for comparison. We have added another efficient Transformer algorithm, Memory Efficient Attention (MEA), to Figure 3 of the attached PDF. FlashAttention-DenseMask is the most competitive baseline. FlashAttention with VarLen is faster than FlashAttention-DenseMask, but it can only be used in SFT, where FlashMask performs on par with it.
- As a non-expert in this field, I found the writing of this paper confusing. For instance, the initialism "HBN" is introduced without any explanation or context.
Reply: We assume that the "HBN" you mentioned refers to "HBM". Thank you for the suggestion. HBM is short for High Bandwidth Memory, which is the global memory of the GPU.
For each question you mentioned:
1. Has there been any prior work on efficient attention mask computation?
Reply: Please refer to our replies to the first and second weaknesses.
2. Why was the FA-Varlen method not included in the DPO and RM scenarios? Could the authors provide an explanation in their paper?
Reply: Thank you for your suggestion. As mentioned in our reply to the first weakness, FA-Varlen cannot represent the sparse attention masks in the DPO and RM scenarios. We will provide a comprehensive explanation in the camera-ready version.
3. In Figure 4, the latency of both FA-Window and FlashMask is almost identical. If the authors aim to demonstrate the efficiency of FlashMask, could they explain the experimental results more clearly?
Reply: Thank you for the advice. Figure 4 is intended to show that FlashMask can speed up not only the SFT, DPO, and RM scenarios but also training with sliding-window sparse attention masks, where it matches the dedicated FA-Window implementation. FlashMask is more general than existing techniques and can be used in many other NLP training tasks.
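As a rough illustration of how a sliding-window mask fits the column-wise form, the NumPy sketch below assumes a causal window of `window` keys and encodes it with a single per-column start row (the FMS described in this thread), leaving the end row at its default of the number of rows. The helper names are hypothetical.

```python
import numpy as np


def sliding_window_fms(n: int, window: int) -> np.ndarray:
    # Key column j is visible to query rows j .. j + window - 1, so the
    # lower-left masked interval starts at row j + window (clipped to n).
    return np.minimum(np.arange(n) + window, n)


def causal_mask_from_fms(fms: np.ndarray) -> np.ndarray:
    # allowed[i, j] is True when query row i may attend key column j:
    # causal (j <= i) and not inside the masked rows [fms[j], n) of column j.
    n = fms.shape[0]
    rows = np.arange(n)[:, None]
    cols = np.arange(n)[None, :]
    return (cols <= rows) & (rows < fms[None, :])


n, window = 8, 3
fms = sliding_window_fms(n, window)
# Reference: row i sees the current key and the window - 1 keys before it.
reference = np.array([[0 <= i - j < window for j in range(n)] for i in range(n)])
assert np.array_equal(causal_mask_from_fms(fms), reference)
print(fms.tolist())  # [3, 4, 5, 6, 7, 8, 8, 8]
```

One length-N vector suffices here, which is consistent with the observation that FlashMask and a dedicated window kernel can reach similar latency on this pattern.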
4. In Figures 3 and 5, it appears that the performance and efficiency of FA-Varlen are comparable to FlashMask in the SFT setting. Could the authors clarify this comparison?
Reply: In fact, if a sparse attention mask can be represented by both FlashMask and FA-Varlen, the efficiency of the two methods is almost the same, as shown in Figures 3 and 5. However, FA-Varlen cannot represent the sparse attention masks in DPO and RM. Therefore, FlashMask is a more general method applicable to almost all NLP tasks, while FA-Varlen is limited to SFT.
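The following NumPy sketch illustrates why FA-Varlen covers packed SFT but not DPO. It assumes FA-Varlen computes a block-diagonal causal mask from `cu_seqlens` segment boundaries (the convention of FlashAttention's varlen interface); the helper functions are hypothetical. For packed SFT, the column-wise FMS vector reproduces the varlen mask exactly, while an exhaustive search over every contiguous segmentation of the 10-token DPO example finds none that matches.

```python
import itertools

import numpy as np


def fms_from_cu_seqlens(cu_seqlens):
    # Packed causal SFT: every column of a segment [start, end) is masked from
    # row `end` onward, because later rows belong to later sequences.
    return np.repeat(cu_seqlens[1:], np.diff(cu_seqlens))


def causal_from_fms(fms):
    rows = np.arange(len(fms))[:, None]
    cols = np.arange(len(fms))[None, :]
    return (cols <= rows) & (rows < fms)


def varlen_mask(cu_seqlens):
    # Block-diagonal causal mask, i.e. what a cu_seqlens-based varlen kernel computes.
    seg = np.repeat(np.arange(len(cu_seqlens) - 1), np.diff(cu_seqlens))
    rows = np.arange(cu_seqlens[-1])[:, None]
    cols = np.arange(cu_seqlens[-1])[None, :]
    return (cols <= rows) & (seg[:, None] == seg[None, :])


cu = [0, 4, 7, 10]  # SFT packing with lengths [4, 3, 3]
print(fms_from_cu_seqlens(cu).tolist())  # [4, 4, 4, 4, 7, 7, 7, 10, 10, 10]
assert np.array_equal(causal_from_fms(fms_from_cu_seqlens(cu)), varlen_mask(cu))

# DPO: Query (4 tokens) + Answer1 (3) + Answer2 (3), FMS values from this thread.
dpo = causal_from_fms(np.array([10, 10, 10, 10, 7, 7, 7, 10, 10, 10]))
matches = 0
for cuts in itertools.product([False, True], repeat=9):
    boundaries = [0] + [i + 1 for i, cut in enumerate(cuts) if cut] + [10]
    matches += np.array_equal(varlen_mask(boundaries), dpo)
print(matches)  # 0: no contiguous segmentation reproduces the DPO mask
```

The search fails for a structural reason: Answer2 must see the Query but not Answer1, which sits between them, and contiguous segments cannot skip over a span.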
Thank you for your detailed response, most of my concerns are addressed during rebuttal and I keep my score.
Dear Reviewer,
Thank you for taking the time to review our paper and for your response to our rebuttal. We understand your decision to maintain the original score and respect your evaluation.
We would like to further clarify some of the improvements and experimental results we mentioned in the rebuttal, which we believe demonstrate the strengths of our work. If there are any points that you feel were not fully addressed, we would be more than happy to engage in further discussion to ensure that you have a comprehensive understanding of our contributions.
Thank you again for your time and valuable feedback.
Best regards,
Authors
The paper introduces FlashMask, an innovative algorithm designed to address the computational and memory challenges associated with conventional attention mechanisms in large-scale Transformers. FlashMask employs a column-wise sparse representation for attention masks, significantly reducing the computational complexity from quadratic to linear with respect to sequence length. The authors demonstrate FlashMask's effectiveness across various masking scenarios and training modalities, including Supervised Fine-Tuning (SFT), Direct Preference Optimization (DPO), and Reward Model (RM).
Strengths
This paper presents a novel solution to a well-known problem in the field of natural language processing, offering a practical method to reduce the computational burden of attention mechanisms in Transformers.
The paper provides extensive empirical evidence to support its claims, including comparisons with state-of-the-art techniques like FlashAttention, demonstrating FlashMask's superiority in terms of speed and efficiency.
FlashMask's performance across different masking scenarios and training modalities shows its versatility and robustness, indicating its potential applicability to a wide range of models and tasks.
Practical Impact: The paper not only presents theoretical advancements but also demonstrates practical benefits, such as enabling the
Weaknesses
The scaling ability of the proposed method deserves further verification on large-scale datasets.
While the paper demonstrates FlashMask's effectiveness in specific scenarios, it may lack broader evidence on how it performs across different types of NLP tasks or diverse datasets.
The paper could provide more detailed insights into how FlashMask handles different sparsity levels and the impact on various model sizes and complexities.
Questions
See the weaknesses.
Limitations
yes
Thanks for your review.
We apologize for not clarifying our key points earlier. In SFT/DPO/RM training, the sparsity of the attention mask arises naturally from the task; it is not introduced to speed up training at the expense of accuracy. For example, as illustrated in Figure 1, the sparsity may come from (1) the padding mask, (2) the InToken mask, and (3) the question-answering mask. We do not design a new sparse attention mask that trades model accuracy for training efficiency. Instead, we exploit the sparsity that these NLP tasks already have to speed up training. The key points of our paper are: (1) an efficient sparse mask representation that can be used in SFT, DPO, RM, and many other scenarios; and (2) an efficient kernel implementation, built on this representation, that speeds up training. Since we introduce no approximate computation (the sparsity comes from the task itself), model accuracy should be exactly the same with or without our method. We also conducted additional experiments evaluating model quality, reported in the attached PDF file (Figure 1 and Table 1).
Due to time constraints, we have only supplemented experiments in the SFT and DPO scenarios, and the benchmark tests were conducted on a limited set of datasets to demonstrate that our method preserves accuracy. We used the Hugging Face LLaMA2-7B pre-trained model and conducted SFT training using Vanilla Attention, FlashAttention-DenseMask, and FlashMask, followed by DPO training. Both SFT and DPO use the Packing/InToken data training strategy. The SFT phase was performed on the "allenai/tulu-v2-sft-mixture" dataset, using the AdamW optimizer with beta1=0.9, beta2=0.999, a learning rate of 2e-05, an end learning rate of 1e-07, weight decay of 0.0, a total training step count of 12000, a warmup step count of 360, and a global train batch size of 16. The DPO phase was conducted on the "HuggingFaceH4/ultrafeedback_binarized" dataset, using the AdamW optimizer with beta1=0.9, beta2=0.999, a learning rate of 1e-06, an end learning rate of 1e-07, weight decay of 0.0, a total training step count of 1600, a warmup step count of 10, and a global train batch size of 16. We will provide a comprehensive comparison in the camera-ready version. The results show that the training loss and benchmark results of our method are the same as those of the baselines.
The detailed replies are as follows:
- The scaling ability of the proposed method deserves further verified on large scale datasets.
Reply: In Sections 4.3 and 4.4, we conducted experiments showing the speedup and memory savings of our proposed method. The test data included LongBench (an open-source benchmark for long-context understanding) and our synthetic data. As mentioned above, our method speeds up training without sacrificing accuracy, so we expect it to scale well to large datasets.
- While the paper demonstrates FlashMask's effectiveness in specific scenarios, it may lack broader evidence on how it performs across different types of NLP tasks or diverse datasets.
Reply: Our proposed method performs exact computation whenever the NLP task itself uses a sparse attention mask. Therefore, different datasets do not affect the results of our method. SFT, DPO, and RM are among the most important downstream LLM training tasks, and we believe our proposed method is a general speedup method for most NLP tasks.
- The paper could provide more detailed insights into how FlashMask handles different sparsity levels and the impact on various model sizes and complexities.
Reply: Our proposed method performs exact computation whenever the NLP task itself uses a sparse attention mask. Therefore, model accuracy is unchanged by our method, regardless of sparsity level, model size, or model complexity.
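The exactness argument can be checked numerically. The sketch below is a NumPy stand-in, not the authors' CUDA kernel: it runs FlashAttention-style online softmax over 2x2 tiles, skips every tile whose mask is entirely False, and compares the result with ordinary masked softmax attention on the packed-SFT mask with sequence lengths [4, 3, 3]. A skipped tile would only contribute exp(-inf) = 0, so the outputs agree up to floating-point rounding. For simplicity the skip test reads the dense mask; a real kernel would derive it from the per-column indices.

```python
import numpy as np


def reference_attention(q, k, v, allowed):
    # Ordinary softmax attention with a dense boolean mask (True = may attend).
    scores = np.where(allowed, q @ k.T / np.sqrt(q.shape[1]), -np.inf)
    p = np.exp(scores - scores.max(axis=1, keepdims=True))
    return (p / p.sum(axis=1, keepdims=True)) @ v


def tiled_attention_with_skipping(q, k, v, allowed, tile=2):
    # Online softmax over (row tile, key tile) pairs; fully masked tiles are skipped.
    n, d = q.shape
    out = np.zeros((n, v.shape[1]))
    skipped = 0
    for r0 in range(0, n, tile):
        qr = q[r0:r0 + tile]
        m = np.full(len(qr), -np.inf)          # running row maximum
        l = np.zeros(len(qr))                  # running softmax denominator
        acc = np.zeros((len(qr), v.shape[1]))  # running unnormalized output
        for c0 in range(0, n, tile):
            block = allowed[r0:r0 + tile, c0:c0 + tile]
            if not block.any():
                skipped += 1
                continue
            s = np.where(block, qr @ k[c0:c0 + tile].T / np.sqrt(d), -np.inf)
            m_new = np.maximum(m, s.max(axis=1))
            # Rows with no visible key yet keep m_new = -inf; shift by 0 to avoid nan.
            shift = np.where(np.isneginf(m_new), 0.0, m_new)
            scale = np.exp(m - shift)
            p = np.exp(s - shift[:, None])
            l = l * scale + p.sum(axis=1)
            acc = acc * scale[:, None] + p @ v[c0:c0 + tile]
            m = m_new
        out[r0:r0 + tile] = acc / l[:, None]
    return out, skipped


rng = np.random.default_rng(0)
n, d = 10, 8
q, k, v = (rng.standard_normal((n, d)) for _ in range(3))
fms = np.array([4, 4, 4, 4, 7, 7, 7, 10, 10, 10])  # packed SFT, lengths [4, 3, 3]
rows, cols = np.arange(n)[:, None], np.arange(n)[None, :]
allowed = (cols <= rows) & (rows < fms)
out, skipped = tiled_attention_with_skipping(q, k, v, allowed)
assert np.allclose(out, reference_attention(q, k, v, allowed))
print(f"{skipped} of 25 tiles skipped")  # 17 of 25 tiles skipped
```

The speedup comes entirely from the skipped tiles; the visible entries are processed exactly as in dense masked attention, which is why the sparsity level changes runtime but not the result.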
As we near the author-reviewer discussion deadline, we seek your feedback on our rebuttal. Your insights have been crucial in improving our work, and we're grateful for the time and effort you've dedicated to our manuscript.
Thank you for your valuable guidance. We look forward to your response and are ready to make further adjustments if needed.
Best regards,
Authors
This paper proposes FlashMask, which accelerates masked attention, reducing the complexity of the original attention from O(N^2) to O(N) while also reducing memory cost. Experimental results show that the proposed FlashMask significantly reduces training time without accuracy degradation.
Strengths
- This paper provides a comprehensive study and analysis of sparse attention in terms of efficiency. It also covers existing attention optimizations such as FlashAttention and explains the motivation for the proposed FlashMask, which lies in the lack of optimization for sparse attention.
- This paper proposes an optimization for column-based sparse attention, which significantly improves memory efficiency and reduces computational costs.
- This paper provides a comprehensive complexity analysis, evaluation, and comparison with existing methods. The authors appear to have put considerable effort into the proposed approach.
Weaknesses
- Even though FlashMask significantly improves the memory efficiency of sparse attention, the key idea is similar to FlashAttention, merely applied to sparse attention mechanisms. For this reason, the novelty of this paper is not strong. I recommend that the authors explain why the red part of the algorithm is designed as it is and why it is unique to sparse attention.
- The authors only present an optimization for column-based sparse attention. The performance for other types of sparse attention is unknown. If the proposed approach could be applied to all sparse attention, the contribution of this paper would be very significant. However, the current version is not comprehensive.
- Based on the experiments, the practical latency is not significantly reduced compared to other methods, even though the theoretical complexity drops from N^2 to N. Besides, the authors do not provide accuracy results.
Questions
See Weaknesses.
Limitations
N/A
Thanks for your review.
- Even though FlashMask achieves significant improvement in the memory efficiency of sparse attention, the key idea is similar to FlashAttention, but it is just for sparse attention mechanisms. Based on this fact, the novelty of this paper is not strong. I recommend the authors explain why the red part in the algorithm is designed and why it is unique for sparse attention.
Reply: This paper makes two contributions: (1) a column-based sparse attention mask representation, which replaces the large dense tensor normally used to represent the sparse attention mask; and (2) an efficient kernel implementation, based on FlashAttention, that uses this representation. We consider the novelty of this paper to be identifying a sparse attention mask representation common to most NLP tasks and accelerating the attention phase by combining this representation with FlashAttention. Our sparse attention mask representation applies to most NLP tasks, including SFT, DPO, RM, etc.
As shown in Figure 1(b) of the main text, consider the attention mask required for a long sequence composed of three sequences in the SFT scenario, with sequence lengths [4,3,3]. The FMS values are [4,4,4,4,7,7,7,10,10,10], and FME does not need to be set, defaulting to the maximum number of rows. For example, in the second column, the FMS value is 4, indicating that the elements in rows 4 through 9 are masked.
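The causal SFT example can be reproduced with a few lines of NumPy. This is a sketch assuming the semantics described in this reply: column j masks rows in [FMS[j], FME[j]), FME defaults to the number of rows, and the upper triangle is masked by causality.

```python
import numpy as np

lengths = [4, 3, 3]
n = sum(lengths)
ends = np.cumsum(lengths)                        # [4, 7, 10]
fms = np.repeat(ends, lengths)                   # end row of each column's sequence
fme = np.full(n, n)                              # default: maximum number of rows

rows = np.arange(n)[:, None]
cols = np.arange(n)[None, :]
masked_by_column = (rows >= fms) & (rows < fme)  # lower-left masked interval
allowed = (cols <= rows) & ~masked_by_column     # plus the causal upper triangle

print(fms.tolist())                                     # [4, 4, 4, 4, 7, 7, 7, 10, 10, 10]
print(np.flatnonzero(masked_by_column[:, 1]).tolist())  # [4, 5, 6, 7, 8, 9]
```

Only the length-10 FMS vector has to be stored; the 10 x 10 `allowed` matrix is built here purely for inspection.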
As shown in Figure 1(c) of the main text, consider the attention mask required for a long sequence composed of three sequences in the bidirectional SFT scenario, with sequence lengths [4,3,3]. FlashMask uses two pairs of FMS and FME values to describe the bidirectional scenario, one pair for the lower-left part and one for the upper-right part. The lower-left start row is [4,4,4,4,7,7,7,10,10,10], and the lower-left end row does not need to be set, defaulting to the maximum number of rows; the upper-right start row does not need to be set, defaulting to 0, and the upper-right end row is [0,0,0,0,4,4,4,7,7,7]. For example, in the fourth column, the lower-left start row is 7, indicating that the elements in rows 7 through 9 are masked, and the upper-right end row is 4, indicating that the elements in rows 0 through 3 are masked.
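The bidirectional example can be checked the same way. In this sketch, `lower_start`, `lower_end`, `upper_start`, and `upper_end` are hypothetical labels for the two FMS/FME pairs; expanding them yields exactly the block-diagonal mask of three independent sequences, with no causal constraint.

```python
import numpy as np

lengths = [4, 3, 3]
n = sum(lengths)
ends = np.cumsum(lengths)                  # [4, 7, 10]
starts = ends - np.array(lengths)          # [0, 4, 7]

lower_start = np.repeat(ends, lengths)     # [4, 4, 4, 4, 7, 7, 7, 10, 10, 10]
lower_end = np.full(n, n)                  # default: maximum number of rows
upper_start = np.zeros(n, dtype=int)       # default: 0
upper_end = np.repeat(starts, lengths)     # [0, 0, 0, 0, 4, 4, 4, 7, 7, 7]

rows = np.arange(n)[:, None]
masked = ((rows >= lower_start) & (rows < lower_end)) | (
    (rows >= upper_start) & (rows < upper_end))
allowed = ~masked                          # bidirectional: no causal constraint

seg = np.repeat(np.arange(len(lengths)), lengths)
assert np.array_equal(allowed, seg[:, None] == seg[None, :])  # block-diagonal
print(np.flatnonzero(masked[:, 4]).tolist())  # [0, 1, 2, 3, 7, 8, 9]
```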
As shown in Figure 1(d) of the main text, consider the attention mask required for one Query and two Answers in the DPO scenario, with a Query length of 4 and Answer1 and Answer2 lengths of 3 each. The FMS values are [10,10,10,10,7,7,7,10,10,10]. For example, in the zeroth column, the FMS value is 10, which equals the number of rows, so no element below the diagonal is masked and both Answers can attend to the Query. In the fourth column (the first token of Answer1), the FMS value is 7, so rows 7 through 9 (Answer2) are masked.
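A similar sketch for the DPO example, again assuming causal masking plus the lower-left interval [FMS[j], number of rows), confirms the intended visibility: both Answers see the shared Query, but Answer2 cannot see Answer1.

```python
import numpy as np

# Query of length 4, Answer1 and Answer2 of length 3 each (causal).
fms = np.array([10, 10, 10, 10, 7, 7, 7, 10, 10, 10])
n = fms.size
rows = np.arange(n)[:, None]
cols = np.arange(n)[None, :]
allowed = (cols <= rows) & (rows < fms)

query, answer1, answer2 = range(0, 4), range(4, 7), range(7, 10)
assert all(allowed[i, j] for i in answer1 for j in query)        # Answer1 sees the Query
assert all(allowed[i, j] for i in answer2 for j in query)        # Answer2 sees the Query
assert not any(allowed[i, j] for i in answer2 for j in answer1)  # but not Answer1
```

Sharing the Query this way lets one packed sequence carry both responses without duplicating the prompt.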
- The authors only present optimization for column-based sparse attention. The performance for other types of sparse attention is unknown. If the proposed approach can be applied to all sparse attention, the contribution of this paper is extremely great. However, the existing version is not comprehensive.
Reply: Thank you for your suggestions. In fact, it is hard to handle all kinds of sparse attention masks efficiently. If we use a large dense tensor to represent the sparse attention mask, the memory usage and HBM access costs are unacceptable; if we use a COO-like sparse format to represent the mask, it is hard to implement an efficient CUDA kernel. In this paper, we found that the sparse attention masks in most NLP tasks can be represented column-wise. Therefore, we can use FMS and FME, as defined in our paper, to represent the mask efficiently. We believe our method applies to most NLP tasks, such as SFT, DPO, and RM.
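Whether a given mask fits this scheme can be tested column by column. The sketch below assumes that FlashMask allows at most one contiguous masked interval on or below the diagonal and at most one above it in each column (the two FMS/FME pairs described in this rebuttal); the function names and the strided example pattern are hypothetical.

```python
import numpy as np


def column_intervals(mask_col):
    """Split a boolean column into maximal runs of True; return [(start, end), ...]."""
    runs, start = [], None
    for i, m in enumerate(mask_col):
        if m and start is None:
            start = i
        elif not m and start is not None:
            runs.append((start, i))
            start = None
    if start is not None:
        runs.append((start, len(mask_col)))
    return runs


def representable(masked):
    """True if every column has at most one masked run on/below the diagonal
    and at most one masked run above it."""
    n = masked.shape[0]
    for j in range(n):
        lower = masked[j:, j]  # rows j..n-1 (diagonal counted as lower-left)
        upper = masked[:j, j]  # rows 0..j-1
        if len(column_intervals(lower)) > 1 or len(column_intervals(upper)) > 1:
            return False
    return True


n = 10
rows, cols = np.arange(n)[:, None], np.arange(n)[None, :]
seg = np.repeat([0, 1, 2], [4, 3, 3])
packed_causal = ~((cols <= rows) & (seg[:, None] == seg[None, :]))
print(representable(packed_causal))  # True

strided = (rows - cols) % 2 == 1     # mask every other lag (hypothetical pattern)
print(representable(strided))        # False
```

This makes the trade-off in the reply concrete: the format is cheap because it is restricted, and masks with several gaps per column (such as the strided one) fall outside it.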
- Based on the experiments, the practical latency is not significantly reduced as compared to other methods, even though the theoretical complexity is from N^2 to N. Besides, the authors do not provide results for accuracy.
Reply: The theoretical complexity corresponds to FA-DenseMask and Vanilla Attention in Figure 3. We provide a new curve in the attached PDF file (Figure 3) comparing the kernel latency of FA-DenseMask, Vanilla Attention, and our proposed FlashMask method. It shows that the theoretical complexity is reduced from O(N^2) (FA-DenseMask, Vanilla Attention) to O(N) (our FlashMask method). We also conducted additional experiments evaluating model quality, reported in the attached PDF file (Figure 1 and Table 1).
Due to time constraints, we have only supplemented experiments in the SFT and DPO scenarios, and the benchmark tests were conducted on a limited set of datasets to demonstrate that our method preserves accuracy. We used the Hugging Face LLaMA2-7B pre-trained model and conducted SFT training using Vanilla Attention, FlashAttention-DenseMask, and FlashMask, followed by DPO training. Both SFT and DPO use the Packing/InToken data training strategy. The SFT phase was performed on the "allenai/tulu-v2-sft-mixture" dataset, using the AdamW optimizer with beta1=0.9, beta2=0.999, a learning rate of 2e-05, an end learning rate of 1e-07, weight decay of 0.0, a total training step count of 12000, a warmup step count of 360, and a global train batch size of 16. The DPO phase was conducted on the "HuggingFaceH4/ultrafeedback_binarized" dataset, using the AdamW optimizer with beta1=0.9, beta2=0.999, a learning rate of 1e-06, an end learning rate of 1e-07, weight decay of 0.0, a total training step count of 1600, a warmup step count of 10, and a global train batch size of 16. We will provide a comprehensive comparison in the camera-ready version. The results show that the model accuracy of our method matches that of FA-DenseMask and Vanilla Attention.
As we near the author-reviewer discussion deadline, we seek your feedback on our rebuttal. Your insights have been crucial in improving our work, and we're grateful for the time and effort you've dedicated to our manuscript.
Thank you for your valuable guidance. We look forward to your response and are ready to make further adjustments if needed.
Best regards,
Authors
I thank the authors for the rebuttal. Unfortunately, my main concern is not well addressed: this work is specific to column-based sparse attention, which is not general enough. In addition, the implementation is based on the existing FlashAttention. Overall, I keep my original rating.
This paper proposes FlashMask, a modification of FlashAttention with fixed masks. The paper shows speedup of FlashAttention when using sparse masks in the attention matrix.
Strengths
FlashAttention is an important algorithm, and sparsity in the attention matrix is an important feature. Further study of these aspects is helpful.
Weaknesses
The paper seems to make claims that are unsubstantiated by experiments. In the abstract and introduction, the paper claims speedup without sacrificing model quality. However, there is no experiment evaluating model quality in the experiments. This is a critical flaw.
Further, the contribution of the paper is unclear. Block-sparsity is already supported in FlashAttention (see section 3.3 of FlashAttention). It is unclear how this paper is different. There are also more recent works such as "Fast Attention Over Long Sequences With Dynamic Sparse Flash Attention" (NeurIPS 2023), which seem to be strictly more expressive in features than this paper.
Questions
See weaknesses
Limitations
The paper discusses superlinear scaling in sequence length as a limitation, but is lacking in discussion of model quality.
Thanks for your review.
- The paper seems to make claims that are unsubstantiated by experiments. In the abstract and introduction, the paper claims speedup without sacrificing model quality. However, there is no experiment evaluating model quality in the experiments. This is a critical flaw.
Reply: We apologize for not clarifying our key points earlier. In SFT/DPO/RM training, the sparsity of the attention mask arises naturally from the task; it is not introduced to speed up training at the expense of accuracy. For example, as illustrated in Figure 1, the sparsity may come from (1) the padding mask, (2) the InToken mask, and (3) the question-answering mask. We mainly focus on speeding up training when such naturally sparse attention masks are used. Our baselines are non-accelerated methods using the same sparse attention mask. Therefore, we claim a speedup without sacrificing model accuracy.
Due to time constraints, we have only supplemented experiments in the SFT and DPO scenarios, and the benchmark tests were conducted on a limited set of datasets to demonstrate that our method preserves accuracy. We will provide a comprehensive comparison in the camera-ready version. We used the Hugging Face LLaMA2-7B pre-trained model and conducted SFT training using Vanilla Attention, FlashAttention-DenseMask, and FlashMask, followed by DPO training. Both SFT and DPO use the Packing/InToken data training strategy. The SFT phase was performed on the "allenai/tulu-v2-sft-mixture" dataset, using the AdamW optimizer with beta1=0.9, beta2=0.999, a learning rate of 2e-05, an end learning rate of 1e-07, weight decay of 0.0, a total training step count of 12000, a warmup step count of 360, and a global train batch size of 16. The DPO phase was conducted on the "HuggingFaceH4/ultrafeedback_binarized" dataset, using the AdamW optimizer with beta1=0.9, beta2=0.999, a learning rate of 1e-06, an end learning rate of 1e-07, weight decay of 0.0, a total training step count of 1600, a warmup step count of 10, and a global train batch size of 16.
From the training loss curves in Figure 1 of the attached PDF, it is evident that FlashMask achieves convergence comparable to the baselines (Vanilla Attention and FlashAttention-DenseMask). Additionally, FlashMask and FlashAttention-DenseMask exhibit identical loss values. The evaluation metrics in the benchmark tables (Table 1 in the attached PDF) indicate that FlashMask achieves accuracy on par with FlashAttention-DenseMask. Therefore, compared with FlashAttention-DenseMask, FlashMask is an exact algorithm with no loss in accuracy.
- Further, the contribution of the paper is unclear. Block-sparsity is already supported in FlashAttention (see section 3.3 of FlashAttention). It is unclear how this paper is different. There are also more recent works such as "Fast Attention Over Long Sequences With Dynamic Sparse Flash Attention" (NeurIPS 2023), which seem to be strictly more expressive in features than this paper.
Reply: SFT, DPO, and RM are very important NLP downstream training scenarios. Although FlashAttention already supports block-sparsity, block-sparsity does not match the sparse attention mask patterns in SFT, DPO, and RM shown in Figure 1. The paper you mentioned ("Fast Attention Over Long Sequences With Dynamic Sparse Flash Attention") speeds up training with QK-Sparse and Hash-Sparse attention, but it also cannot represent the sparse attention masks in SFT, DPO, and RM shown in Figure 1. We propose an efficient sparse mask representation that can be used in SFT, DPO, RM, and many other scenarios, including the QK-Sparse/Hash-Sparse attention in the mentioned paper. Building on this representation, we propose an efficient kernel implementation to speed up the sparse attention phase. Our approach is more general than existing work and can be used in most NLP training scenarios.
For example, as shown in Figure 2 of the attached PDF, FlashMask can represent QK-sparse and Hash-sparse masks. In the figure, one index denotes the starting row of the mask in the lower-left triangle, and another denotes the ending row of the mask in the upper-right triangle. The diagonal elements are considered part of the lower-left triangle. In FlashMask, the ending row index of the lower-left mask does not need to be set and defaults to the maximum number of rows, while the starting row index of the upper-right mask defaults to 0. Here, the two starting row indices correspond to FMS in Algorithm 1, and the two ending row indices correspond to FME in Algorithm 1.
For example, in Figure 2(a) of the attached PDF (QK-sparse), in the second column, the lower-left starting row index is 6, indicating that the elements from row 6 to the last row are masked, and the upper-right ending row index is 3, indicating that the elements in rows 0 through 2 are masked. In Figure 2(b) of the attached PDF (Hash-sparse), in the first column, the lower-left starting row index is 1, indicating that the elements from row 1 to the last row are masked, and the upper-right ending row index is 1, indicating that the element in row 0 is masked. In the second column of Figure 2(b), the lower-left starting row index is 5, indicating that the elements from row 5 to the last row are masked, and the upper-right ending row index is 2, indicating that the elements in rows 0 and 1 are masked.
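To make the four indices concrete, the following sketch recovers them from a dense boolean mask and checks that expanding them reproduces the mask. It assumes half-open row intervals, the defaults stated above for unused intervals, and at most one masked run per triangle in each column; the helper names and the bidirectional local-window example are hypothetical and are not taken from Figure 2 of the attached PDF.

```python
import numpy as np


def to_column_bounds(masked):
    """Recover (lower_start, lower_end, upper_start, upper_end) per column.

    Assumes each column has at most one masked run on/below the diagonal and
    at most one above it; an unused interval keeps its default (empty) value.
    """
    n = masked.shape[0]
    lts, lte = np.full(n, n), np.full(n, n)                     # default: [n, n)
    uts, ute = np.zeros(n, dtype=int), np.zeros(n, dtype=int)   # default: [0, 0)
    for j in range(n):
        low = np.flatnonzero(masked[j:, j]) + j  # masked rows on/below the diagonal
        if low.size:
            lts[j], lte[j] = low[0], low[-1] + 1
        up = np.flatnonzero(masked[:j, j])       # masked rows above the diagonal
        if up.size:
            uts[j], ute[j] = up[0], up[-1] + 1
    return lts, lte, uts, ute


def from_column_bounds(lts, lte, uts, ute):
    rows = np.arange(lts.size)[:, None]
    return ((rows >= lts) & (rows < lte)) | ((rows >= uts) & (rows < ute))


n, w = 8, 2
rows, cols = np.arange(n)[:, None], np.arange(n)[None, :]
local_masked = np.abs(rows - cols) >= w  # bidirectional local window (example)
bounds = to_column_bounds(local_masked)
assert np.array_equal(from_column_bounds(*bounds), local_masked)
print(bounds[0].tolist())  # [2, 3, 4, 5, 6, 7, 8, 8]
print(bounds[3].tolist())  # [0, 0, 1, 2, 3, 4, 5, 6]
```

Here both pairs are in use: the lower-left start tracks j + w and the upper-right end tracks j - w + 1, while the other two vectors stay at their defaults.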
Thank you for your rebuttal. I now better understand your contributions. It appears that one of the key observations is that column-wise padding for attention is useful for some situations (analogous to the row-wise "key padding mask" in BERT-style models).
Unfortunately, it is now clear to me that the paper is quite poorly written for the contribution. The submitted draft of the paper does not make it clear that FlashMask is an optimization for specific use cases where the opportunity for per-column masking already exists. Even then, given the examples in the paper, I'm not convinced that they can't be covered by a block-sparse mask. I would recommend revising the paper for clarity and precision and re-submitting to a later conference. I will not be changing my score.
- The core contribution of this paper is FlashMask, an extension of FlashAttention that accelerates sparse-mask attention for downstream NLP tasks such as SFT, DPO, and RM. FlashMask introduces a column-based sparse mask representation and an efficient CUDA kernel implementation. FlashMask's mask representation is not only applicable to common NLP downstream tasks such as SFT, DPO, and RM, but can also express more customized masks, as shown in Figures 1 and 2 of the main paper. Additionally, FlashMask can efficiently represent and implement the masks corresponding to QK-sparse and Hash-sparse attention in papers such as "Fast Attention Over Long Sequences With Dynamic Sparse Flash Attention". This method not only achieves exactly the same accuracy as FlashAttention with a dense mask (described as the baseline algorithm in our main paper), but also reduces the memory complexity of the mask from O(N^2) to O(N). In end-to-end training, SFT, DPO, and RM tasks achieve speedups of more than 2.4x.
- In the originally submitted paper, we did not provide model accuracy experiments, although the Abstract and Introduction state that FlashMask is an exact algorithm. During this rebuttal period, we added model accuracy experiments to demonstrate that FlashMask is lossless in terms of accuracy.
Due to time constraints, we have only supplemented experiments in SFT and DPO scenarios, and the benchmark tests were conducted on a limited set of datasets to demonstrate the accuracy preservation of our method. We will provide a comprehensive comparison in the Camera-Ready version. We used the Huggingface LLaMA2-7B pre-trained model and conducted SFT training using Vanilla Attention, FlashAttention-DenseMask, and FlashMask, followed by DPO training. Both SFT and DPO phases used the Packing or InToken data training strategy described in Figure 1(a) and Figure 1(d) in the main paper respectively. The SFT training was performed on the "allenai/tulu-v2-sft-mixture" dataset, using the AdamW optimizer with beta1=0.9, beta2=0.999, a learning rate of 2e-05, an end learning rate of 1e-07, weight decay of 0.0, a total training step count of 12000, a warmup step count of 360, and a global train batch size of 16. The DPO training was conducted on the "HuggingFaceH4/ultrafeedback_binarized" dataset, using the AdamW optimizer with beta1=0.9, beta2=0.999, a learning rate of 1e-06, an end learning rate of 1e-07, weight decay of 0.0, a total training step count of 1600, a warmup step count of 10, and a global train batch size of 16.
From the training loss curves in Figure 1 of the attached PDF, it is evident that FlashMask achieves convergence comparable to the baselines (Vanilla Attention and FlashAttention-DenseMask). Additionally, FlashMask and FlashAttention-DenseMask exhibit identical loss values. The evaluation metrics in the benchmark tables (Table 1 in the attached PDF) indicate that FlashMask achieves accuracy on par with FlashAttention-DenseMask. Therefore, compared with FlashAttention-DenseMask, FlashMask is an exact algorithm that does not sacrifice model accuracy.
This paper introduces FlashMask, which aims to reduce the computational and memory complexity of attention mechanisms in large-scale Transformers through a column-wise sparse mask representation. The method is presented as a straightforward and effective solution, showing improvements in training efficiency without compromising accuracy.
However, the reviewers raised significant concerns, including the limited novelty of the approach, the lack of generality beyond specific scenarios, and the insufficient clarity in writing. While the authors provided additional data and clarifications during the discussion period, these efforts did not fully address the reviewers' concerns. The improvement in experimental results and the overall contribution of the method were not found to be substantial.
Despite the paper receiving an average borderline score, the consensus is that the contributions are not strong enough to merit acceptance. The reviewers believe that the paper could benefit from further refinement and a broader evaluation to enhance its impact. As the value proposition is not entirely convincing, the AC recommends rejecting this paper, encouraging the authors to address the feedback and resubmit after significant revision.