PaperHub
Overall rating: 7.0 / 10
Poster · 4 reviewers
Ratings: 8, 8, 6, 6 (min 6, max 8, std dev 1.0)
Confidence: 4.3
Correctness: 3.0
Contribution: 2.8
Presentation: 3.3
ICLR 2025

FlashMask: Efficient and Rich Mask Extension of FlashAttention

OpenReview · PDF
Submitted: 2024-09-27 · Updated: 2025-02-28

Abstract

The computational and memory demands of vanilla attention scale quadratically with the sequence length $N$, posing significant challenges for processing long sequences in Transformer models. FlashAttention alleviates these challenges by eliminating the $\mathcal{O}(N^2)$ memory dependency and reducing attention latency through IO-aware memory optimizations. However, its native support for certain attention mask types is limited, and it does not inherently accommodate more complex masking requirements. Previous approaches resort to using dense masks with $\mathcal{O}(N^2)$ memory complexity, leading to inefficiencies. In this paper, we propose FlashMask, an extension of FlashAttention that introduces a column-wise sparse representation of attention masks. This approach efficiently represents a wide range of mask types and facilitates the development of optimized kernel implementations. By adopting this novel representation, FlashMask achieves linear memory complexity $\mathcal{O}(N)$, making it suitable for modeling long-context sequences. Moreover, this representation enables kernel optimizations that eliminate unnecessary computations by leveraging sparsity in the attention mask, without sacrificing computational accuracy, resulting in higher computational efficiency. We evaluate FlashMask's performance in fine-tuning and alignment training of LLMs such as SFT, LoRA, DPO, and RM. FlashMask achieves significant throughput improvements, with end-to-end speedups ranging from 1.65x to 3.22x compared to the existing FlashAttention dense method. Additionally, our kernel-level comparisons demonstrate that FlashMask surpasses the latest counterpart, FlexAttention, by 12.1% to 60.7% in terms of kernel TFLOPs/s, achieving 37.8% to 62.3% of the theoretical maximum FLOPs/s on the A100 GPU. The code is open-sourced on PaddlePaddle (https://github.com/PaddlePaddle/Paddle) and integrated into PaddleNLP (https://github.com/PaddlePaddle/PaddleNLP), supporting models with over 100 billion parameters for contexts extending up to 128K tokens.
Keywords
Attention Mask Efficient Representation · Efficient Attention Computation · Long context · IO complexity · GPUs · LLMs

Reviews and Discussion

Official Review
Rating: 8

The article presents FlashMask, a method to incorporate a wide class of attention masks into FlashAttention. Algorithmically, this means that the non-trivial structure of the mask is incorporated into the online softmax calculation involved in self-attention, without materializing the full mask and paying memory quadratic in the sequence length. Furthermore, this algorithm is implemented in a hardware-aware fashion, much like FlashAttention, to minimize memory access and data movement while exploiting thread-group-level parallelism on GPUs. Empirically, this method is benchmarked against the dense mask of FlashAttention-2 and also against FlexAttention, an alternative state-of-the-art method for incorporating structured masks into efficient attention computations, and it shows noticeable gains both in inference and in training.

Strengths

The strengths of this paper lie in the novelty of incorporating a class of structured masks into the online softmax calculation of memory-efficient attention, and further into the hardware-aware version of the algorithm that is FlashAttention, as well as in the comprehensive validation of the algorithm's superior efficiency when benchmarked against dense masking in conventional FlashAttention-2 and against FlexAttention. The algorithms for the forward and backward passes are presented clearly, as are the empirical results.

Weaknesses

As admitted by the authors, this method cannot handle irregular masking patterns within a column of the mask, or completely arbitrary masking patterns.

Questions

Following up on the limitations, I would like to know whether the authors think that arbitrary masking patterns can be incorporated in a GPU-friendly manner, or if memory-efficient implementations of such masking patterns would require alternative hardware architectures (such as those developed by Cerebras).

Comment

Dear Reviewer,

Thank you for your insightful evaluation and the positive feedback on our work. We are pleased to hear your interest in the potential extension of FlashMask's capabilities.

Weaknesses: As admitted by the authors, this method cannot handle irregular masking patterns within a column of the mask, or completely arbitrary masking patterns.

While it is true that FlashMask does not currently support arbitrary masking patterns, it effectively manages the common mask types utilized in LLM training. We are dedicated to enhancing the expressiveness of FlashMask by addressing some limitations inherent in the current two-range-based (LTS, LTE, UTS, UTE) sparse representation. Specifically, we aim to improve its capability to handle multi-range masks and introduce expression-based descriptions, which would remove the necessity for explicit attention mask inputs. For instance, if an attention mask can be described using a formula or expression, the kernel can compute mask positions directly, potentially eliminating the need for explicit mask data. Furthermore, we are focused on refining FlashMask's implementation to reduce unnecessary memory access and condition checks, particularly in relation to fully masked blocks, thus further enhancing its performance.
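To make the range-based representation concrete, below is a minimal NumPy sketch of how a causal document mask could be encoded column-wise. The semantics assumed here (each key column j masks out query rows in the half-open range [LTS[j], LTE[j]) of the lower-triangular region) are our reading of the paper's description, not the released kernel code, and the dense expansion is for verification only.

```python
import numpy as np

def causal_document_mask_ranges(doc_lens):
    """Sketch: encode a causal document mask with two column-wise vectors.

    Assumed semantics: for key column j, query rows in [LTS[j], LTE[j]) of the
    lower-triangular region are masked, i.e. queries in later documents cannot
    attend to column j.
    """
    N = sum(doc_lens)
    LTS = np.empty(N, dtype=np.int64)
    LTE = np.full(N, N, dtype=np.int64)   # masked range runs to the end of the sequence
    start = 0
    for length in doc_lens:
        end = start + length
        LTS[start:end] = end              # rows at or beyond this document's end are masked
        start = end
    return LTS, LTE

def dense_causal_document_mask(LTS, LTE):
    """Materialize the O(N^2) boolean mask (True = attend), only to check the encoding."""
    N = len(LTS)
    rows = np.arange(N)[:, None]
    cols = np.arange(N)[None, :]
    causal = rows >= cols
    masked = (rows >= LTS[None, :]) & (rows < LTE[None, :])
    return causal & ~masked

LTS, LTE = causal_document_mask_ranges([3, 2, 4])   # three documents, 9 tokens total
print(dense_causal_document_mask(LTS, LTE).astype(int))
```

Each column costs two integers instead of N boolean entries, which is where the $\mathcal{O}(N)$ memory footprint of the representation comes from.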

Q: Following up on the limitations, I would like to know whether the authors think that arbitrary masking patterns can be incorporated in a GPU-friendly manner or if memory-efficient implementations of such masking patterns would require alternative hardware architectures (such as those developed by Cerebras).

A: Based on the information we found on the Cerebras website, Cerebras' hardware architecture features large global memory, extensive SRAM, and high memory bandwidth. Theoretically, these characteristics allow it to incorporate arbitrary masking patterns more effectively than traditional GPUs.

Thank you once again for your constructive feedback, which continues to guide our development efforts.

Best regards

Comment

I appreciate the authors' clarifications and acknowledge that their method is more than adequate to handle the kinds of masking patterns commonly used in LLM training.

Official Review
Rating: 8

This paper introduces an extension for SDPA that supports different types of masks in an easy-to-understand way. The method is novel due to its sparse representation. Experiments show FlashMask outperforms FlexAttention by a significant gap.

Strengths

Although this representation might be similar to COO/CSR/CSC, this is the first time I have ever seen these techniques used in attention, one of the most important operators in LLMs.

Weaknesses

This paper lacks two baselines:

  1. FlashInfer with dense masks;
  2. FlashInfer with sparse masks (https://docs.flashinfer.ai/api/python/sparse.html).

Although FlashInfer does not support the backward pass, I believe it is an important baseline among SOTA attention implementations. If this comparison is presented, I will raise my score.

Questions

  1. How can this technique be integrated with page attention?
  2. Can tree-based speculative decoding benefit from this customized attention?
  3. Can you report evaluation results on machines such as P100, V100, A10G, and H100? (other than A100)

Ethics Concerns

No.

Comment

Q: How can this technique be integrated with page attention?

A: FlashMask's innovative column-wise mask representation significantly reduces memory complexity by using an index range vector to characterize each column. It classifies tiling blocks into Fully masked, Partially masked, or Unmasked categories, allowing computations to be skipped for Fully masked blocks and optimizing other blocks by eliminating redundant mask applications. On the other hand, Page Attention deals with KV cache tokens that are managed across discontiguous physical storage at page-size granularity. Since FlashMask reduces memory complexity from $\mathcal{O}(N^2)$ to $\mathcal{O}(N)$, the necessity for page management is eliminated. These two techniques are orthogonal, facilitating seamless integration. By leveraging FlashMask's capabilities, Page Attention can support a more sophisticated expression of masks.
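To illustrate the block classification mentioned above, here is a simplified sketch (our own illustration, not the kernel code) that labels each tiling block using only the assumed column-wise masked ranges [LTS[j], LTE[j]); the causal upper-triangular part and the kernel's aggregated min/max shortcuts are omitted for brevity.

```python
import numpy as np

def classify_tiles(LTS, LTE, Br=128, Bc=128):
    """Label each (Br x Bc) tiling block from column-wise masked ranges.

    Assumed semantics: column j masks query rows in [LTS[j], LTE[j]).
    """
    LTS, LTE = np.asarray(LTS), np.asarray(LTE)
    N = len(LTS)
    Tr, Tc = -(-N // Br), -(-N // Bc)                 # ceil-divide into tiles
    labels = np.empty((Tr, Tc), dtype=object)
    for ti in range(Tr):
        r0, r1 = ti * Br, min((ti + 1) * Br, N)
        for tj in range(Tc):
            c0, c1 = tj * Bc, min((tj + 1) * Bc, N)
            lo = np.maximum(LTS[c0:c1], r0)           # clip each column's masked range
            hi = np.minimum(LTE[c0:c1], r1)           # to this tile's row extent
            masked = np.clip(hi - lo, 0, None).sum()
            total = (r1 - r0) * (c1 - c0)
            labels[ti, tj] = ("Unmasked" if masked == 0 else
                              "Fully masked" if masked == total else "Partially masked")
    return labels

# Example: document mask over documents of lengths 200, 120, 192 (N = 512, 4 x 4 tiles).
doc_lens, N = [200, 120, 192], 512
LTS, LTE = np.empty(N, dtype=np.int64), np.full(N, N, dtype=np.int64)
start = 0
for length in doc_lens:
    LTS[start:start + length] = start + length
    start += length
print(classify_tiles(LTS, LTE))
```

Only the Fully masked blocks are skipped entirely; Partially masked blocks still apply the mask element-wise, and Unmasked blocks follow the plain FlashAttention path.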

Q: Can tree-based speculative decoding benefit from this customized attention?

A: Indeed, FlashMask supports tree-based speculative decoding methods, as depicted in SpecInfer's Tree-based Parallel Decoding example (Figure 4), with vectors such as LTS = [1, 2, 3, 4, 5, 6, 7, 8], LTE = [1, 2, 3, 4, 5, 6, 7, 8], UTS = [0, 0, 0, 0, 3, 3, 2, 2], and UTE = [0, 0, 0, 0, 4, 4, 6, 6].

Q: Can you report evaluation results on machines such as P100, V100, A10G, and H100? (other than A100)

A: The current version of FlashMask extends FlashAttention-2 without substantial code restructuring, so it is not supported on legacy architectures such as P100 and V100. Unfortunately, due to a lack of access to H100 and A10G hardware, we have not yet conducted evaluations on Hopper and other Ampere GPUs. However, the methodology itself is adaptable across hardware platforms.

Comment

Comparison of FlashInfer DenseMask (single_prefill_with_kv_cache), FlashInfer SparseMask (BlockSparseAttentionWrapper), and FlashMask on Document Masks at a 128K Sequence Length, with Varying R/C

Although FlashInfer DenseMask and FlashMask do not have specific R/C values, the masks they represent are consistent under different R/C due to the total sequence length and each sub-document's sequence length being divisible by 64. We have included them in the table for comparison with FlashInfer SparseMask, conducting multiple tests to ensure a fair evaluation.

| Method | Mask Type | R/C | Seq Length | Sparsity | FW Time (ms) | FW TFLOPs | FW TFLOPs/s |
|---|---|---|---|---|---|---|---|
| FlashInfer SparseMask | Document Mask | 1 | 131072 | 0.911591 | 1571.12 | 24.8848 | 15.8381 |
| FlashInfer SparseMask | Document Mask | 2 | 131072 | 0.911591 | 783.622 | 24.8848 | 31.7548 |
| FlashInfer SparseMask | Document Mask | 4 | 131072 | 0.911591 | 391.201 | 24.8848 | 63.6058 |
| FlashInfer SparseMask | Document Mask | 8 | 131072 | 0.911591 | 288.973 | 24.8848 | 86.1052 |
| FlashInfer SparseMask | Document Mask | 16 | 131072 | 0.911591 | 145.126 | 24.8848 | 171.447 |
| FlashInfer SparseMask | Document Mask | 32 | 131072 | 0.911591 | 131.312 | 24.8848 | 189.495 |
| FlashInfer SparseMask | Document Mask | 64 | 131072 | 0.911591 | 131.325 | 24.8848 | 189.476 |
| FlashInfer DenseMask | Document Mask | 1 | 131072 | 0.911591 | 2946.7 | 24.8848 | 8.44507 |
| FlashInfer DenseMask | Document Mask | 2 | 131072 | 0.911591 | 2946.88 | 24.8848 | 8.44471 |
| FlashInfer DenseMask | Document Mask | 4 | 131072 | 0.911591 | 2947.32 | 24.8848 | 8.44338 |
| FlashInfer DenseMask | Document Mask | 8 | 131072 | 0.911591 | 2947 | 24.8848 | 8.44419 |
| FlashInfer DenseMask | Document Mask | 16 | 131072 | 0.911591 | 2946.96 | 24.8848 | 8.44435 |
| FlashInfer DenseMask | Document Mask | 32 | 131072 | 0.911591 | 2946.73 | 24.8848 | 8.44511 |
| FlashInfer DenseMask | Document Mask | 64 | 131072 | 0.911591 | 2946.81 | 24.8848 | 8.44475 |
| FlashMask | Document Mask | 1 | 131072 | 0.911591 | 172.883 | 24.8848 | 143.616 |
| FlashMask | Document Mask | 2 | 131072 | 0.911591 | 172.859 | 24.8848 | 143.635 |
| FlashMask | Document Mask | 4 | 131072 | 0.911591 | 172.837 | 24.8848 | 143.654 |
| FlashMask | Document Mask | 8 | 131072 | 0.911591 | 172.822 | 24.8848 | 143.666 |
| FlashMask | Document Mask | 16 | 131072 | 0.911591 | 172.811 | 24.8848 | 143.675 |
| FlashMask | Document Mask | 32 | 131072 | 0.911591 | 172.807 | 24.8848 | 143.679 |
| FlashMask | Document Mask | 64 | 131072 | 0.911591 | 172.809 | 24.8848 | 143.677 |
Comment

Comparison of FlashInfer DenseMask (single_prefill_with_kv_cache), FlashInfer SparseMask (BlockSparseAttentionWrapper), and FlashMask on Document Masks at a 32K Sequence Length, with Varying R/C

Although FlashInfer DenseMask and FlashMask do not have specific R/C values, the masks they represent are consistent under different R/C due to the total sequence length and each sub-document's sequence length being divisible by 64. We have included them in the table for comparison with FlashInfer SparseMask, conducting multiple tests to ensure a fair evaluation.

| Method | Mask Type | R/C | Seq Length | Sparsity | FW Time (ms) | FW TFLOPs | FW TFLOPs/s |
|---|---|---|---|---|---|---|---|
| FlashInfer SparseMask | Document Mask | 1 | 32768 | 0.906438 | 104.646 | 1.64597 | 15.7256 |
| FlashInfer SparseMask | Document Mask | 2 | 32768 | 0.906438 | 52.4631 | 1.64597 | 31.3644 |
| FlashInfer SparseMask | Document Mask | 4 | 32768 | 0.906438 | 25.9617 | 1.64597 | 63.47 |
| FlashInfer SparseMask | Document Mask | 8 | 32768 | 0.906438 | 19.6772 | 1.64597 | 83.5854 |
| FlashInfer SparseMask | Document Mask | 16 | 32768 | 0.906438 | 9.87305 | 1.64597 | 166.575 |
| FlashInfer SparseMask | Document Mask | 32 | 32768 | 0.906438 | 8.88797 | 1.64597 | 185.125 |
| FlashInfer SparseMask | Document Mask | 64 | 32768 | 0.906438 | 8.88604 | 1.64597 | 185.16 |
| FlashInfer DenseMask | Document Mask | 1 | 32768 | 0.906438 | 184.097 | 1.64597 | 8.94137 |
| FlashInfer DenseMask | Document Mask | 2 | 32768 | 0.906438 | 183.982 | 1.64597 | 8.94653 |
| FlashInfer DenseMask | Document Mask | 4 | 32768 | 0.906438 | 183.995 | 1.64597 | 8.94587 |
| FlashInfer DenseMask | Document Mask | 8 | 32768 | 0.906438 | 184.033 | 1.64597 | 8.94402 |
| FlashInfer DenseMask | Document Mask | 16 | 32768 | 0.906438 | 183.995 | 1.64597 | 8.94599 |
| FlashInfer DenseMask | Document Mask | 32 | 32768 | 0.906438 | 183.997 | 1.64597 | 8.94577 |
| FlashInfer DenseMask | Document Mask | 64 | 32768 | 0.906438 | 183.986 | 1.64597 | 8.94627 |
| FlashMask | Document Mask | 1 | 32768 | 0.906438 | 11.747 | 1.64597 | 139.665 |
| FlashMask | Document Mask | 2 | 32768 | 0.906438 | 11.7424 | 1.64597 | 139.72 |
| FlashMask | Document Mask | 4 | 32768 | 0.906438 | 11.7429 | 1.64597 | 139.715 |
| FlashMask | Document Mask | 8 | 32768 | 0.906438 | 11.7343 | 1.64597 | 139.819 |
| FlashMask | Document Mask | 16 | 32768 | 0.906438 | 11.7329 | 1.64597 | 139.836 |
| FlashMask | Document Mask | 32 | 32768 | 0.906438 | 11.7327 | 1.64597 | 139.837 |
| FlashMask | Document Mask | 64 | 32768 | 0.906438 | 11.7301 | 1.64597 | 139.869 |
Comment

Comparison of FlashInfer DenseMask (single_prefill_with_kv_cache), FlashInfer SparseMask (BlockSparseAttentionWrapper), and FlashMask on Document Masks at an 8K Sequence Length, with Varying R/C

Although FlashInfer DenseMask and FlashMask do not have specific R/C values, the masks they represent are consistent under different R/C due to the total sequence length and each sub-document's sequence length being divisible by 64. We have included them in the table for comparison with FlashInfer SparseMask, conducting multiple tests to ensure a fair evaluation.

| Method | Mask Type | R/C | Seq Length | Sparsity | FW Time (ms) | FW TFLOPs | FW TFLOPs/s |
|---|---|---|---|---|---|---|---|
| FlashInfer SparseMask | Document Mask | 1 | 8192 | 0.786804 | 15.3937 | 0.234411 | 15.1884 |
| FlashInfer SparseMask | Document Mask | 2 | 8192 | 0.761304 | 8.56867 | 0.262449 | 30.482 |
| FlashInfer SparseMask | Document Mask | 4 | 8192 | 0.761304 | 4.31257 | 0.262449 | 60.5664 |
| FlashInfer SparseMask | Document Mask | 8 | 8192 | 0.761304 | 3.23164 | 0.262449 | 80.9699 |
| FlashInfer SparseMask | Document Mask | 16 | 8192 | 0.761304 | 1.64875 | 0.262449 | 158.545 |
| FlashInfer SparseMask | Document Mask | 32 | 8192 | 0.761304 | 1.51439 | 0.262449 | 172.609 |
| FlashInfer SparseMask | Document Mask | 64 | 8192 | 0.761304 | 1.5123 | 0.262449 | 172.817 |
| FlashInfer DenseMask | Document Mask | 1 | 8192 | 0.786804 | 11.921 | 0.234411 | 19.6652 |
| FlashInfer DenseMask | Document Mask | 2 | 8192 | 0.761304 | 11.911 | 0.262449 | 22.037 |
| FlashInfer DenseMask | Document Mask | 4 | 8192 | 0.761304 | 11.9116 | 0.262449 | 22.0376 |
| FlashInfer DenseMask | Document Mask | 8 | 8192 | 0.761304 | 11.9152 | 0.262449 | 22.0298 |
| FlashInfer DenseMask | Document Mask | 16 | 8192 | 0.761304 | 11.9125 | 0.262449 | 22.035 |
| FlashInfer DenseMask | Document Mask | 32 | 8192 | 0.761304 | 11.9135 | 0.262449 | 22.0331 |
| FlashInfer DenseMask | Document Mask | 64 | 8192 | 0.761304 | 11.9119 | 0.262449 | 22.0358 |
| FlashMask | Document Mask | 1 | 8192 | 0.786804 | 1.51063 | 0.234411 | 154.494 |
| FlashMask | Document Mask | 2 | 8192 | 0.761304 | 1.68858 | 0.262449 | 154.789 |
| FlashMask | Document Mask | 4 | 8192 | 0.761304 | 1.67081 | 0.262449 | 156.027 |
| FlashMask | Document Mask | 8 | 8192 | 0.761304 | 1.66058 | 0.262449 | 156.736 |
| FlashMask | Document Mask | 16 | 8192 | 0.761304 | 1.6599 | 0.262449 | 156.823 |
| FlashMask | Document Mask | 32 | 8192 | 0.761304 | 1.65797 | 0.262449 | 157.007 |
| FlashMask | Document Mask | 64 | 8192 | 0.761304 | 1.65883 | 0.262449 | 156.941 |
Comment

Dear Reviewer,

Thank you for engaging with our work and for highlighting the importance of extending evaluation to include FlashInfer as a baseline. We greatly appreciate your insights, which aid in advancing our research on FlashMask.

Response to Weaknesses and Questions:

We acknowledge the value that FlashInfer brings, bridging training and inference applications. Should our paper be accepted at ICLR 2025, we plan to reference FlashInfer in the camera-ready version and include the following experimental results in the appendix. Our primary focus with FlashMask has been addressing the extensive high bandwidth memory (HBM) requirements posed by diverse attention mask types in large-scale model training. Nevertheless, FlashMask's capabilities extend into inference stages as well, warranting consideration among leading attention mechanisms.

We defined the "mask block size" based on FlashInfer's description, utilizing the BSR API's parameters R and C. The "tiling block size" was designed to match the operational dimensions of the kernel. Our experiments were conducted on an A100-SXM 80G GPU using the official version of FlashInfer v0.1.6, with CUDA 12.1, PyTorch 2.4, and BF16 as the data type. The configuration settings included batch_size = 1, num_qo_heads = 32, num_kv_heads = 8, and head_dim = 128. We compared FlashMask against the dense mask API single_prefill_with_kv_cache and the sparse mask API BlockSparseAttentionWrapper (with varying R/C values) in FlashInfer. Typical attention masks, such as the Causal Document Mask, Document Mask, and Shared Question Mask, were selected for evaluation.

The datasets used were from Section A.5.2, but were slightly modified to ensure that each sub-document sequence length was divisible by 64, allowing for experiments with a FlashInfer sparse mask where C=64. As shown in the experimental results below, FlashMask demonstrated superior TFLOPs/s, effectively addressing the inefficiencies observed with FlashInfer's dense mask API. While FlashInfer with sparse masks shows considerable improvements as the mask block sizes (R and C) increase, especially when R, C ≥ 16, such large mask block sizes are rarely required in practice.
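As a side note on how the FW TFLOPs and FW TFLOPs/s columns relate to the sparsity and timing values, the following sanity-check sketch (our own arithmetic, using the batch_size, num_qo_heads, and head_dim configuration quoted above) closely matches the reported values; the remaining differences come from rounding of the reported sparsity and timings.

```python
def attention_fwd_tflops(seq_len, sparsity, num_qo_heads=32, head_dim=128, batch_size=1):
    """Forward-pass attention FLOPs: two N x N x head_dim matmuls per head,
    scaled by the fraction of unmasked score entries (1 - sparsity)."""
    dense_flops = 4 * batch_size * num_qo_heads * head_dim * seq_len ** 2
    return dense_flops * (1.0 - sparsity) / 1e12

# e.g. the FlashMask 128K Document Mask measurement (24.8848 TFLOPs, 172.883 ms)
tflops = attention_fwd_tflops(131072, sparsity=0.911591)
print(f"{tflops:.4f} TFLOPs, {tflops / (172.883 / 1e3):.1f} TFLOPs/s")  # ~24.89 TFLOPs, ~144 TFLOPs/s
```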

Within FlashInfer's single_prefill_with_kv_cache DenseMask implementation in prefill.cuh#L1234-L1241, the use of token-by-token dense masks results in substantial inefficiencies, especially since calculations for fully masked blocks are entirely unnecessary. Regarding FlashInfer's BlockSparseAttentionWrapper, the mask block column size C functions as the page size. Smaller mask block sizes lead to a marked increase in the padded batch size nblks(padded_batch_size, 1, num_kv_heads), which adversely affects performance because of non-optimal kernel hyper-parameter tuning. In contrast, with larger mask block sizes, FlashInfer maximizes the advantages of BSR's sparse representation by calculating only the necessary tiling blocks and avoiding excess calculations for fully masked blocks, thereby achieving high TFLOPs/s.

The hyper-parameter tuning results of FlashInfer are as follows:

R1C1: request_idx=8180, packed_qo_len=4, kv_len=2496, qo_chunk_size=16, kv_chunk_size=2496, num_tiles_q=1, num_tiles_kv=1
R2C2: request_idx=4094, packed_qo_len=8, kv_len=960, qo_chunk_size=16, kv_chunk_size=1344, num_tiles_q=1, num_tiles_kv=1
R4C4: request_idx=2047, packed_qo_len=16, kv_len=720, qo_chunk_size=16, kv_chunk_size=784, num_tiles_q=1, num_tiles_kv=1
R8C8: request_idx=1023, packed_qo_len=32, kv_len=360, qo_chunk_size=64, kv_chunk_size=392, num_tiles_q=1, num_tiles_kv=1
R16C16: request_idx=511, packed_qo_len=64, kv_len=180, qo_chunk_size=64, kv_chunk_size=196, num_tiles_q=1, num_tiles_kv=1
R32C32: request_idx=255, packed_qo_len=128, kv_len=90, qo_chunk_size=128, kv_chunk_size=98, num_tiles_q=1, num_tiles_kv=1
R64C64: request_idx=127, packed_qo_len=256, kv_len=45, qo_chunk_size=128, kv_chunk_size=49, num_tiles_q=2, num_tiles_kv=1

R2C2: partition_kv=0, padded_batch_size=4096, num_warps_x=1, num_warps_z=4, num_frags_x=1, num_frags_y=8, num_frags_z=2
R4C4: partition_kv=0, padded_batch_size=2048, num_warps_x=1, num_warps_z=4, num_frags_x=1, num_frags_y=8, num_frags_z=2
R8C8: partition_kv=0, padded_batch_size=1024, num_warps_x=4, num_warps_z=1, num_frags_x=1, num_frags_y=8, num_frags_z=8
R16C16: partition_kv=0, padded_batch_size=512, num_warps_x=4, num_warps_z=1, num_frags_x=1, num_frags_y=8, num_frags_z=8
R32C32: partition_kv=0, padded_batch_size=256, num_warps_x=4, num_warps_z=1, num_frags_x=2, num_frags_y=8, num_frags_z=4
R64C64: partition_kv=0, padded_batch_size=256, num_warps_x=4, num_warps_z=1, num_frags_x=2, num_frags_y=8, num_frags_z=4

We hope this response clarifies our approach and experimental scope, and we look forward to any additional feedback you might provide.

Best regards

Comment

Comparison Results of FlashInfer DenseMask (single_prefill_with_kv_cache), FlashInfer SparseMask (BlockSparseAttentionWrapper), and FlashMask on Causal Document Mask at 8K, 32K, and 128K

| Method | Mask Type | Seq Length | Sparsity | FW Time (ms) | FW TFLOPs | FW TFLOPs/s |
|---|---|---|---|---|---|---|
| FlashInfer SparseMask | Causal Document Mask | 8192 | 0.880591 | 9.32757 | 0.131292 | 13.9464 |
| FlashInfer DenseMask | Causal Document Mask | 8192 | 0.880591 | 11.9305 | 0.131292 | 11.0058 |
| FlashMask | Causal Document Mask | 8192 | 0.880591 | 0.960801 | 0.131292 | 135.071 |
| FlashInfer SparseMask | Causal Document Mask | 32768 | 0.953204 | 54.7695 | 0.823251 | 15.0114 |
| FlashInfer DenseMask | Causal Document Mask | 32768 | 0.953204 | 184.198 | 0.823251 | 4.46944 |
| FlashMask | Causal Document Mask | 32768 | 0.953204 | 5.98586 | 0.823251 | 137.087 |
| FlashInfer SparseMask | Causal Document Mask | 131072 | 0.955792 | 788.939 | 12.4435 | 15.7707 |
| FlashInfer DenseMask | Causal Document Mask | 131072 | 0.955792 | 2948.23 | 12.4435 | 4.22073 |
| FlashMask | Causal Document Mask | 131072 | 0.955792 | 84.127 | 12.4435 | 147.582 |

Comparison Results of FlashInfer DenseMask (single_prefill_with_kv_cache), FlashInfer SparseMask (BlockSparseAttentionWrapper), and FlashMask on Share Question Mask at 8K, 32K, and 128K

| Method | Mask Type | Seq Length | Sparsity | FW Time (ms) | FW TFLOPs | FW TFLOPs/s |
|---|---|---|---|---|---|---|
| FlashInfer SparseMask | Share Question Mask | 8192 | 0.932401 | 6.11861 | 0.0743262 | 11.7586 |
| FlashInfer DenseMask | Share Question Mask | 8192 | 0.932401 | 11.938 | 0.0743262 | 6.22726 |
| FlashMask | Share Question Mask | 8192 | 0.932401 | 0.727063 | 0.0743262 | 98.4887 |
| FlashInfer SparseMask | Share Question Mask | 32768 | 0.974209 | 32.872 | 0.453728 | 13.7364 |
| FlashInfer DenseMask | Share Question Mask | 32768 | 0.974209 | 184.396 | 0.453728 | 2.46063 |
| FlashMask | Share Question Mask | 32768 | 0.974209 | 4.58727 | 0.453728 | 98.262 |
| FlashInfer SparseMask | Share Question Mask | 131072 | 0.975079 | 443.804 | 7.01461 | 15.7953 |
| FlashInfer DenseMask | Share Question Mask | 131072 | 0.975079 | 2948.89 | 7.01461 | 2.37871 |
| FlashMask | Share Question Mask | 131072 | 0.975079 | 61.5699 | 7.01461 | 113.21 |
Comment

Dear Reviewer gr4R,

I hope this message finds you well. I am writing to kindly follow up on my response to your feedback. I have addressed the questions and comments you raised, including conducting the additional experiments and providing the requested results. I sincerely appreciate the time and effort you have dedicated to reviewing my work, and I am grateful for your constructive feedback.

If there is anything further you would like me to clarify or expand upon, please do not hesitate to let me know. Additionally, if the rebuttal has satisfactorily addressed your concerns, I would kindly appreciate your consideration in reflecting this in the final scoring.

Thank you once again for your thoughtful feedback and your time. I look forward to your response at your earliest convenience.

Best regards,

The Authors

Comment

I think the experiments are convincing. I will raise my score.

Comment

Thank you very much for your thoughtful feedback and for taking the time to review our rebuttal; we greatly appreciate your support and consideration. We will incorporate the experimental results into the revised version.

Official Review
Rating: 6

The paper proposes an efficient sparse mask representation by using a composition of LT and UT ranges to express complex patterns. The proposed mask is compatible with FlashAttention-2 and can bring speed-ups when applied.

Strengths

  1. The paper open-sourced a rather general sparse self-attention representation framework, which could facilitate many research and production attempts in the field.
  2. The implementation is practical, showing wall-clock speed-ups over FlashAttention-2.

Weaknesses

  1. It seems the implementation is limited to Paddle. It would be good to see if it can also be made more general so that the Torch/Megatron community can also leverage the framework.
  2. Inference support is missing. It would make more sense to discuss how such a sparse mask can be put into actual inference/serving.
  3. [1] was published earlier, and also provides a general sparse self-attention training & serving framework. It would be ideal to also cite [1].

[1] S2-Attention: Hardware-Aware Context Sharding Among Attention Heads

Questions

  1. What would be the block size supported by FlashMask? Namely, what would be the granularity of mask/unmask chunks?
  2. How do different block/chunk sizes affect the speed-up for different mask types?
  3. When tiling, it seems some masks may lead to different workloads among thread blocks, which could hurt the overall performance. Is there any mitigation to this?
  4. Can we have a comparison between the theoretical FLOPs reduction and the wall-clock speed-up for different mask types?
  5. How do tensor parallelism and pipeline parallelism affect the speed-up?
Comment

Q: What would be the block size supported by FlashMask? Namely, what would be the granularity of mask/unmask chunks?

A: In response to your question, we define the "block size" or "mask/unmask chunks" as the "mask block size." FlashMask represents masks with token-level granularity, thereby setting the mask block size to 1. This minimal granularity supports arbitrary mask block sizes. As detailed in Sections 4.1 and 4.2, FlashMask divides tiling blocks into Fully masked, Partially masked, and Unmasked categories by representing each column with a range of indices, thereby bypassing computation on Fully masked blocks while applying masks only to Partially masked blocks.

Q: How do different block/chunk sizes affect the speed-up for different mask types?

A: Again defining "block size" or "chunk size" as the "mask block size", and distinguishing it from the computation tiling block size in the FlashAttention kernels, FlashMask maintains a mask block size of 1 and thus supports any mask block size. The tiling block size remains unchanged from FlashAttention-2. FlashMask primarily focuses on reducing memory complexity via the novel column-wise mask representation. Our experiments (see Section 5.3, Figure 4a) reveal a strong correlation between block sparsity and processing latency: more structured attention masks result in fewer partially masked blocks, thereby decreasing the computational load and enhancing performance. It should be noted that FlashMask is not yet optimized for maximum performance, as there are additional memory accesses and condition checks that have not been fully optimized; we plan to address these in future work.

Q: When tiling, it seems some masks may lead to different workloads among thread blocks, which could hurt the overall performance. Is there any mitigation to this?

A: The current FlashMask implementation inherits from FlashAttention-2, assigning each thread block to handle calculations for a given batch, head, and tiling row. Despite potential workload imbalances due to varying mask types within tiling blocks (Fully masked, Partially masked, Unmasked), the SMs (Streaming Multiprocessors) remain fully utilized because multiple thread blocks can be assigned to each SM. Thus, mask diversity does not negatively impact overall performance. That said, future optimizations will focus on reducing unnecessary memory accesses and condition checks, as depicted in Algorithm 1, lines 9-14.

Q: Can we have a comparison between the theoretical FLOPs reduction and wall-clock speed-up for different mask types?

A: As stated in Section 4.3, the theoretical FLOPs reduction is linearly related to the block sparsity ($\rho$) in the attention mask and is independent of the specific mask type. This indicates that kernel latency should be proportional to $\mathcal{O}((1-\rho)T_rT_c)$. The actual speed-up observed in the experiments detailed in Figure 4(a) further supports this relationship, showing a strong correlation between the theoretical and empirical results.

Q: How do tensor parallelism and pipeline parallelism affect the speed-up?

A: FlashMask is an optimization at the operator level, and its effects are orthogonal to those of tensor parallelism and pipeline parallelism. These parallel strategies distribute computation equally across different devices. Therefore, we believe that different parallel strategies have minimal impact on the performance gains provided by FlashMask.

Comment

Dear Reviewer,

Thank you for the valuable feedback and questions on our paper. We appreciate the chance to elaborate on our work and address your concerns.

Weaknesses 1: Implementation limited to PaddlePaddle.

Our current implementation of FlashMask is indeed conducted within the PaddlePaddle framework. However, it is integrated as a third-party module and is decoupled from the framework, so it can easily be integrated into PyTorch or other frameworks. We aim to expand our validation efforts to ensure broader compatibility and adoption across platforms.

Weaknesses 2: Inference support is missing.

FlashMask can certainly be extended to inference applications, like Page Attention and tree-based speculative decoding. Although this paper primarily focuses on training aspects, we acknowledge the importance of discussing inference use cases and plan to explore these in future work.

Weaknesses 3: Missing citation to [1].

We appreciate your suggestion regarding [1] (S2-Attention: Hardware-Aware Context Sharding Among Attention Heads). It is indeed a commendable work, and we intend to cite it in the camera-ready version of our submission. S2-Attention employs a Block-Sparse Attention strategy focusing on approximate attention sparsity at both per-head and per-context-range levels for long-sequence modeling. In contrast, FlashMask emphasizes supporting a wide variety of attention masks used in transformer training through precise computation rather than approximate methods. Our approach reduces memory complexity to $\mathcal{O}(N)$ by introducing a novel column-wise mask representation. This enables handling longer sequences efficiently while supporting mainstream attention masks precisely.

Thank you once again for your thoughtful review and questions, and we look forward to refining FlashMask further.

Best regards

Comment

Dear TH6Z,

I hope you’re doing well. I wanted to follow up on my rebuttal, where I’ve addressed your questions and comments. I sincerely appreciate your time and constructive feedback.

If there’s anything else you’d like me to clarify or expand on, please let me know. Your insights are invaluable in improving my submission.

Thank you again for your thoughtful review. I look forward to your response.

Best regards,

The Authors

Official Review
Rating: 6

The paper introduces a novel compression scheme for the attention mask in which only the boundary indices of the masks are stored for every column. For a specific set of attention masks, it is sufficient to store two sets of boundary indices per column to represent the attention mask. This reduces the memory complexity of attention masks from quadratic in the sequence length to linear, enabling handling of longer context lengths. The column-wise sparse representation is also used to skip fully masked blocks, increasing the overall compute efficiency of the attention mechanism. This technique is combined with the FlashAttention algorithm for efficient computation of the attention mechanism, and the modified algorithms for both the forward and backward passes are presented. The experiments section shows that FlashMask is faster than the FlashAttention dense method by up to 3.22x and can achieve up to 60.23% more throughput than FlexAttention. The proposed method also doesn't alter convergence during training.

Strengths

The paper is well-written and easy to understand. The results section is elaborate with a wide range of benchmarks to demonstrate the advantages of the proposed methods. The appendix section and the analysis with synthetic data to corroborate the claims are very insightful. The compute and memory utilization advantages of FlashMask are well demonstrated. The proposed sparse representation scheme is novel and should be adopted wherever applicable for its memory efficiency and ability to support longer context lengths.

Weaknesses

While the results section shows that FlashMask achieves higher computational efficiency, I'm not sure if it's attributable to the proposed column-wise sparse representation. The computational efficiency of FlashMask comes from skipping computation on entirely masked blocks, as discussed in Section 4.3. However, this technique is also used in Block-Sparse FlashAttention and FlexAttention. The advantage of FlashMask over Block-Sparse FlashAttention and FlexAttention in terms of computational efficiency is not clear. Also, as mentioned in the paper, the column-wise sparse representation used in FlashMask is limited to specific attention patterns; any other pattern can't be handled even naively.

Questions

You mentioned FlexAttention can also exploit sparsity by skipping computation on fully masked blocks. If that's the case, where is the compute throughput advantage of FlashMask coming from? Would Block-Sparse FlashAttention be able to handle the mask types described in Fig 1(a)? If yes, it should be used instead of the DenseMask variant for the throughput comparisons across the paper. If not, please mention why. In Fig 4b, why is FlexAttention's memory utilization lower than that of FlashMask for sequence lengths below 16K?

Comment

Dear Reviewer,

Thank you for your insightful feedback and questions regarding our paper. We appreciate the opportunity to address your concerns and provide additional clarifications.

Q: You mentioned FlexAttention can also exploit sparsity by skipping computation on fully masked blocks. If that's the case, where is the compute throughput advantage of FlashMask coming from?

A: Similar to FlexAttention, FlashMask also capitalizes on the sparsity of attention masks. The key distinction is that FlashMask is built upon FlashAttention2, inheriting all its manual and hyperparameter optimizations. As demonstrated in the performance chart on the official FlexAttention blog, the current FlexAttention is slower than FlashAttention2. For more details, you can refer to the blog post at https://pytorch.org/blog/flexattention/#performance. FlexAttention, however, is still in development, utilizing Triton-based compiler technology, and has not yet reached its full performance potential.

Moreover, although FlashMask's current performance is not optimal, as it only achieves 37.8% to 62.3% of the theoretical maximum FLOPs/s on A100 GPUs, we plan to enhance it by reconstructing the traversal method, further improving throughput for all mask types.

Q: Would the Block-Sparse FlashAttention handle the mask types described in Fig 1(a)? If yes, it should be used instead of the DenseMask variant for throughput comparisons across the paper. If not, please mention why.

A: No. As explained in Section 2.3, Block-Sparse FlashAttention represents masks at tiling-block granularity, which makes it unable to express the token-level masks that DenseMask can. Previously, DenseMask was the only comprehensive method capable of representing the attention masks shown in Fig 1(a). Therefore, we used DenseMask as the performance baseline in our paper.

Furthermore, Block-Sparse FlashAttention was developed as part of a new approximate attention algorithm. In contrast, our FlashMask is designed to accelerate the attention module when the sparsity of the attention mask naturally arises from the problem itself. For instance, the causal document mask, which is inherently sparse, is commonly used in SFT training. Unlike the approximate computations in Block-Sparse FlashAttention, our approach leverages the inherent sparsity of the mask to achieve acceleration without any loss of precision.

Q: Why is FlexAttention's memory utilization lower than that of FlashMask for sequence lengths below 16K in Fig 4b?

A: As discussed in Section 2.2, the memory complexity of FlexAttention is $O\left(\frac{N^2}{B_cB_r}\right)$, where $B_r=128$ and $B_c=128$. FlashMask, on the other hand, employs an innovative column-wise mask representation, maintaining a linear memory complexity of $O(N)$. When the sequence length is less than 16K, $\frac{N^2}{B_cB_r}$ is actually less than $N$. Although the FlexAttention blog suggests potential memory savings by increasing the block sizes $B_r$ and $B_c$, its complexity still remains quadratic with respect to the sequence length $N$. Conversely, when the sequence length exceeds 16K, FlashMask's memory usage is lower than that of FlexAttention.
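As a quick illustration of that crossover, the following back-of-the-envelope count compares mask-metadata entries only (ignoring per-entry byte sizes), assuming one entry per $B_r \times B_c$ block for a block-level mask versus a small constant number of entries per column for FlashMask:

```python
# Entry counts only; actual byte footprints also depend on per-entry sizes.
Br = Bc = 128
for N in (4_096, 8_192, 16_384, 32_768, 65_536):
    block_mask_entries = N * N // (Br * Bc)   # one entry per Br x Bc block
    column_entries = N                        # per-column metadata, times a small constant
    print(f"N={N:>6}: block-mask entries={block_mask_entries:>7}, column entries={column_entries:>6}")
```

Under these assumptions the two counts cross at N = B_r * B_c = 16384, consistent with the ~16K turnover described above.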

We appreciate your thorough scrutiny and hope these responses satisfactorily address your concerns.

Best regards

Comment

Dear Uyji,

I hope you’re doing well. I wanted to follow up on my rebuttal, where I’ve addressed your questions and comments. I sincerely appreciate your time and constructive feedback.

If there’s anything else you’d like me to clarify or expand on, please let me know. Your insights are invaluable in improving my submission.

Thank you again for your thoughtful review. I look forward to your response.

Best regards,

The Authors

AC Meta-Review

The paper presents a further optimization over FlashAttention-2 for better memory management of the attention mechanism. I think the novelty of the approach in this paper is limited, but given that the reviewers are enthusiastic about it, it would be OK to accept.

Additional Comments on Reviewer Discussion

The rebuttal led to an increase in score.

Final Decision

Accept (Poster)