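PaperHub
Average rating: 6.3/10
Poster, 3 reviewers
Individual ratings: 5, 7, 7 (min 5, max 7, std. dev. 0.9)
Confidence: 3.7
Soundness: 3.7
Contribution: 2.7
Presentation: 3.3
NeurIPS 2024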

Faster Neighborhood Attention: Reducing the O(n^2) Cost of Self Attention at the Threadblock Level

Submitted: 2024-04-27, Updated: 2024-11-06

Keywords: fused attention kernel, neighborhood attention, sliding window attention

Reviews and Discussion

Official Review
Rating: 5

This paper proposes two CUDA kernel optimization techniques: batched GEMM and fusion for neighborhood attention. On average, batched GEMM optimization provides 895% (548%) and 272% (193%) improvement in full precision (half precision) latency compared to existing naive CUDA kernels for 1-D and 2-D neighborhood attention, respectively. Fusion optimization improves naive kernels by an average of 1759% and 958% in 1-D and 2-D problems respectively. These optimizations translate into up to 104% improvement in inference and 39% improvement in training existing models based on neighborhood attention. The fused kernels can match or outperform the authors' self-attention baseline in approximately 100% of 1-D, 98.6% of 2-D, and 97.3% of 3-D problem sizes that they benchmarked.

Strengths

  1. The paper clearly explains the overhead one can expect from switching from a standard self-attention kernel to neighborhood attention.
  2. The paper clearly illustrates the CUDA kernel optimization techniques they use, which is challenging because the kernels are complex.

Weaknesses

  1. The paper does not quantitatively compare the neighborhood attention kernel with the state-of-the-art fused dot-product attention kernel.
  2. The paper only evaluates one type of GPU, the NVIDIA A100, which calls into question the generality of the proposed kernel optimization techniques.

Questions

  1. The paper does not quantitatively explain where the performance benefits come from. It would help readers assess the kernels if you could show the arithmetic intensity of the naive, GEMM, and fused kernels under different inputs and configurations. Could you also show some measurements of memory, cache, and occupancy that would help explain the performance improvement?
  2. Could you show the best performance achieved by torch.compile in PyTorch 2? That would help motivate the necessity of developing CUDA kernels.
  3. Could you show how the optimized kernels impact the algorithm, such as enabling larger models or longer contexts?
  4. Could you explain whether the new tensor memory accelerator introduced in H100 can help with the neighborhood attention kernel?
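As a rough illustration of the arithmetic-intensity question (question 1 above), the sketch below estimates FLOP/byte for the QK step of an unfused, BMM-style 1-D neighborhood attention. The traffic model is our own simplifying assumption, not a measurement and not the paper's analysis: Q and K are each read from global memory once, the n x k attention weights are written once, and softmax and the AV product are ignored. The function name `qk_arithmetic_intensity`, the head dimension of 32, and the sequence length of 4096 are hypothetical example choices.

```python
def qk_arithmetic_intensity(n: int, k: int, d: int, bytes_per_elem: int) -> float:
    """Estimated FLOP/byte of the QK step of unfused 1-D neighborhood attention.

    Idealized traffic model (an assumption): Q and K are each read from global
    memory once and the n x k attention weights are written back once.
    Softmax and the AV product are ignored.
    """
    flops = 2 * n * k * d                  # n queries x k neighbors x d-dim dot product (mul + add)
    elems_moved = n * d + n * d + n * k    # read Q, read K, write attention weights
    return flops / (elems_moved * bytes_per_elem)


if __name__ == "__main__":
    d = 32
    for k in (7, 13, 31):
        ai = qk_arithmetic_intensity(n=4096, k=k, d=d, bytes_per_elem=2)
        print(f"window k={k:>2}: {ai:5.2f} FLOP/byte (FP16)")
    # Dense self-attention for comparison: every query attends to all n keys.
    n = 4096
    print(f"dense   k={n}: {qk_arithmetic_intensity(n, n, d, 2):5.2f} FLOP/byte (FP16)")
```

Because n cancels out, this estimate depends only on k and d. Small windows stay at a few FLOP/byte, in line with the authors' point below that such kernels are memory-bound, while the dense case comes out roughly an order of magnitude higher than k = 7.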

Limitations

  1. The paper only shows that the optimized kernels can accelerate current algorithms, not that they extend their applicability to more applications.
Author Rebuttal

We thank you for your feedback. To answer your questions:

  1. Our performance evaluations are limited to runtime, mainly because the naive kernel does not exploit anything beyond occupancy (it launches as large a wave as possible and assigns a single dot product's worth of work to each thread). It does not use Tensor Cores, so, at least to our knowledge, there is no meaningful way to measure its math utilization. The naive kernel also does not explicitly use any level of cache. The naive kernel outperforms ours in only two cases: a) when GEMM and FNA hit tile quantization; or b) when the wave is very small, so that the prologue, epilogue, and gmem-to-smem (global-to-shared memory) overheads bottleneck GEMM and FNA, which happens in any naive vs. tiled comparison. As we understand it, the naive kernel, which is our baseline, was a proof-of-concept implementation and never a performance-optimized baseline. As for GEMM vs. FNA, the implementations are very similar, and the only bottleneck specific to GEMM is the memory alignment issue described in the paper. FNA does not run into it, because the first GEMM in FNA skips the epilogue and moves its results to shared memory instead of gmem. In addition, prior research such as Flash Attention clearly shows that BMM-style implementations, which include both the naive baseline and our GEMM-based kernels, are typically memory-bound. Finally, neighborhood attention is inherently a GEMV problem, which is also heavily memory-bandwidth-bound, so runtime and achievable FLOPS are really the only metrics we can use. We would be happy to add more benchmarks if you have other metrics in mind.

  2. Unfortunately, there is no way to implement this idea with torch.compile other than the very new FlexAttention, which to our knowledge is still a prototype. torch.compile uses its Inductor engine to try to find patterns in the graph for which a fused template kernel exists. These are typically limited to fusing a single elementwise operation, or at best a reduction, into a GEMM kernel, and do not extend beyond that. As a result, while torch.compile is great for common fusions such as elementwise or reduction fusion into GEMMs, there still aren't many specific templates for attention. The only functionally correct neighborhood attention implementations in Python were done using im2col and padding, which cannot be recognized as traversals/views of the same tensor, and therefore cannot be fused into an attention kernel.

  3. We would note that our implementation also adds features that were previously not implemented for neighborhood attention. To name a few, causal masking and varying per-axis parameters unlock many attention patterns that previously did not exist. A very important one is spatio-temporal attention (with causal masking along time but not space), which we foresee will have great applications in video generation and video recognition. Aside from the new features, as you pointed out, acceleration for larger contexts and larger inputs is one key aspect; unfortunately, we have not found many applications of neighborhood attention at very large scales, mainly because most NA applications are vision-focused and applications with very large contexts are limited. However, we would be happy to add inference and even training results on higher-resolution tasks such as object detection and segmentation, which can better illustrate model-level improvements than image classification.

  4. Thank you for asking! The Tensor Memory Accelerator (TMA) is a hardware engine for bulk data movement from global memory into shared memory. It handles some of the layout transformation, but primarily pointer movement and predication, among other things. One interesting property of the TMA is that bulk memory accesses need not be contiguous in memory, and can follow any layout of up to rank 5. This means that our implementation's primary bottleneck, data movement in the presence of non-trivial sequence "modes", will be handled natively by the TMA in a Hopper-oriented implementation. One open source example is the CUTLASS GETT (General Tensor-Tensor Contraction) example, which illustrates how the same GEMM kernels written for Hopper can be manipulated through layout transformations and the TMA to perform tensor contractions. All fused attention kernels so far have fused back-to-back GEMMs with a softmax in between, whereas FNA actually fuses back-to-back GETTs. The same logic extends to FNA: our sequence "mode" can be 1-D, 2-D, or 3-D and can be tiled in different shapes, all through host-side layout transformations and the TMA. To our knowledge, the details of a TMA copy (given a fixed number of transaction bytes) do not affect its performance much, and certainly do not lead to unavoidable branching (since it is a single instruction), which is the issue we face in FNA.
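The point about non-contiguous multi-dimensional tiles can be illustrated with plain address arithmetic. The sketch below assumes a densely packed row-major [H, W, D] layout for a single head; the helper names are hypothetical, and this is not the TMA or CUTLASS API. It lists the contiguous memory runs covered by a tile: a 1-D tile is a single run, while a 2-D tile of a 2-D feature map breaks into one run per tile row, which is the kind of strided, multi-dimensional access a rank-N TMA copy can describe directly.

```python
from itertools import product
from typing import List, Sequence, Tuple


def row_major_strides(shape: Sequence[int]) -> List[int]:
    """Element strides of a densely packed row-major tensor."""
    strides, acc = [], 1
    for extent in reversed(shape):
        strides.append(acc)
        acc *= extent
    return strides[::-1]


def tile_runs(shape: Sequence[int], origin: Sequence[int],
              tile: Sequence[int]) -> List[Tuple[int, int]]:
    """(offset, length) pairs of the contiguous memory runs covered by a tile.

    shape, origin and tile include the trailing head-dim axis, e.g. (H, W, D).
    """
    strides = row_major_strides(shape)
    offsets = sorted(
        sum((o + i) * s for o, i, s in zip(origin, idx, strides))
        for idx in product(*(range(t) for t in tile))
    )
    runs: List[Tuple[int, int]] = []
    for off in offsets:
        if runs and runs[-1][0] + runs[-1][1] == off:
            runs[-1] = (runs[-1][0], runs[-1][1] + 1)   # extend the current run
        else:
            runs.append((off, 1))                        # gap in memory: start a new run
    return runs


if __name__ == "__main__":
    # 1-D sequence of 64 tokens, head dim 4: a 16-token tile is a single run.
    print(tile_runs((64, 4), (8, 0), (16, 4)))        # [(32, 64)]
    # 8x8 feature map, head dim 4: a 2x4-token tile splits into two runs.
    print(tile_runs((8, 8, 4), (2, 2, 0), (2, 4, 4)))  # [(72, 16), (104, 16)]
```

Tiles that span the full width collapse back into a single run, so the choice of tile shape relative to the layout decides how fragmented each load is. The same arithmetic extends to 3-D (T, H, W, D) inputs.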

With regard to weaknesses, it is true that we only compare against FMHA, which is our baseline, and the reason is precisely that state-of-the-art methods such as Flash Attention v2 and v3 use other optimization techniques that neither FMHA nor FNA use. However, this does not mean that our methodology is limited to FMHA or a specific architecture; FMHA was simply the more generic baseline.

In addition, our comparisons are done on Ampere (specifically the A100), mostly because it is the only data-center-class card at our disposal, and also because it is still widely used in the field. However, our implementation is multi-architecture and supports all architectures since Pascal, and architectures newer than Ampere can still run the kernels.

Comment

We are also trying to gain access to H100s, so we can develop future versions of the kernels specifically targeting Hopper. We are hopeful that we will gain access in the coming months, and if time permits we will add those results as well.

Finally, as stated, we fully intend to extend our methodology not just to newer CUDA architectures, but to other platforms and hardware as well (ROCm and Metal, to name a few).

Comment

Thank you for your response. My concerns have been resolved, and I would like to see training results on higher-resolution tasks, which can demonstrate the model-level improvements brought by the GPU kernel optimizations.

Comment

Of course; we're attaching them below. However, we note that these measurements cover only the backbone, not the entire detector/segmentor end to end. The reason is that the original papers used MMDet and MMSeg, which have since completely changed their APIs, and the version used in the original NAT/DiNAT papers was built with CUDA Toolkit 11.3, while our GEMM and FNA kernels require 11.8 at a minimum. We tried multiple newer MMDet/MMSeg versions (really MMCV, their co-dependency), and unfortunately the only ones compatible with our newer kernels do not support the old NAT/DiNAT configurations and break during inference.

Because of that, the only solution we could think of was benchmarking the backbone models separately.

In addition, benchmarking the backbone alone reflects our performance improvement better than an end-to-end measurement, because the detection/segmentation heads are usually highly unoptimized and use CPU operations fairly frequently (e.g., when producing masks or RoI maps), which would bias these measurements.

Overall, for FNA we observe up to 113% speedup over naive, 65% speedup over tiled naive, and 40% speedup over our own GEMM-based kernel.

Detection

Backbone benchmarked on an A100-PCIe with 800x1216 resolution inputs (following the detection resolution in the code from the original NAT/DiNAT papers).

Disclaimer: throughput measurements are of the backbone only, and do not include the detection head or post-processing.

| Backbone | mAP | Mask mAP | Naive (FPS) | Tiled naive (FPS) | GEMM (FPS) | FNA (FPS) |
|---|---|---|---|---|---|---|
| NAT-Mini | 50.3 | 43.6 | 71.9 | 82.0 | 82.0 | 98.1 |
| NAT-Tiny | 51.4 | 44.5 | 49.8 | 49.5 | 50.2 | 62.6 |
| NAT-Small | 52.0 | 44.9 | 36.4 | 46.5 | 48.7 | 61.1 |
| NAT-Base | 52.3 | 45.1 | 28.4 | 36.8 | 44.7 | 56.8 |
| DiNAT_s-Tiny | 51.0 | 44.1 | 73.5 | 92.9 | 107.2 | 138.8 |
| DiNAT_s-Small | 52.3 | 45.2 | 45.5 | 57.1 | 61.2 | 76.1 |
| DiNAT_s-Base | 52.6 | 45.3 | 35.3 | 44.5 | 51.5 | 70.9 |
| DiNAT_s-Large | 54.8 | 47.2 | 23.6 | 29.5 | 33.7 | 46.1 |
| DiNAT-Mini | 51.2 | 44.4 | 71.0 | 83.2 | 82.7 | 100.2 |
| DiNAT-Tiny | 52.2 | 45.1 | 49.4 | 51.2 | 51.5 | 61.6 |
| DiNAT-Small | 52.9 | 45.8 | 36.9 | 46.9 | 51.2 | 63.0 |
| DiNAT-Base | 53.4 | 46.2 | 28.0 | 35.9 | 42.0 | 57.5 |
| DiNAT-Large | 55.3 | 47.8 | 19.5 | 25.2 | 29.7 | 41.6 |
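The headline speedups can be recomputed directly from the detection table. The sketch below copies the throughput columns and reports, for each baseline, the largest FNA speedup and the model where it occurs. It assumes that "X% speedup" means (FNA FPS / baseline FPS - 1) x 100; that reading, and the names used here, are our assumptions rather than the authors' code.

```python
# Backbone throughput (FPS) copied from the detection table:
# (naive, tiled naive, GEMM, FNA)
DETECTION_FPS = {
    "NAT-Mini": (71.9, 82.0, 82.0, 98.1),
    "NAT-Tiny": (49.8, 49.5, 50.2, 62.6),
    "NAT-Small": (36.4, 46.5, 48.7, 61.1),
    "NAT-Base": (28.4, 36.8, 44.7, 56.8),
    "DiNAT_s-Tiny": (73.5, 92.9, 107.2, 138.8),
    "DiNAT_s-Small": (45.5, 57.1, 61.2, 76.1),
    "DiNAT_s-Base": (35.3, 44.5, 51.5, 70.9),
    "DiNAT_s-Large": (23.6, 29.5, 33.7, 46.1),
    "DiNAT-Mini": (71.0, 83.2, 82.7, 100.2),
    "DiNAT-Tiny": (49.4, 51.2, 51.5, 61.6),
    "DiNAT-Small": (36.9, 46.9, 51.2, 63.0),
    "DiNAT-Base": (28.0, 35.9, 42.0, 57.5),
    "DiNAT-Large": (19.5, 25.2, 29.7, 41.6),
}
BASELINES = ("naive", "tiled naive", "GEMM")
FNA = 3  # column index of FNA throughput


def max_speedup(rows, baseline):
    """Largest FNA speedup over one baseline column, as (model, percent),
    where percent = 100 * (FNA FPS / baseline FPS - 1)."""
    model, fps = max(rows.items(), key=lambda kv: kv[1][FNA] / kv[1][baseline])
    return model, 100.0 * (fps[FNA] / fps[baseline] - 1.0)


if __name__ == "__main__":
    for col, name in enumerate(BASELINES):
        model, pct = max_speedup(DETECTION_FPS, col)
        print(f"FNA vs {name}: up to {pct:.0f}% ({model})")
```

Under this reading, all three maxima occur for DiNAT-Large: about 113% over naive, 65% over tiled naive (41.6 / 25.2 FPS), and 40% over GEMM.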

Segmentation

Backbone benchmarked on an A100-PCIe with 512x512 resolution inputs for the Mini, Tiny, Small, and Base variants, and 640x640 for the Large variant (following the segmentation resolution in the code from the original NAT/DiNAT papers).

Disclaimer: throughput measurements are of the backbone only, and do not include the segmentation head or post-processing.

| Backbone | mIoU | mIoU (multiscale) | Naive (FPS) | Tiled naive (FPS) | GEMM (FPS) | FNA (FPS) |
|---|---|---|---|---|---|---|
| NAT-Mini | 45.1 | 46.4 | 82.4 | 84.2 | 81.1 | 102.7 |
| NAT-Tiny | 47.1 | 48.4 | 50.3 | 50.6 | 48.9 | 61.6 |
| NAT-Small | 48.0 | 49.5 | 48.2 | 47.5 | 47.7 | 57.6 |
| NAT-Base | 48.5 | 49.7 | 47.4 | 47.6 | 46.8 | 57.9 |
| DiNAT_s-Tiny | 46.0 | 47.4 | 116.6 | 114.8 | 109.4 | 139.7 |
| DiNAT_s-Small | 48.6 | 49.9 | 63.4 | 56.7 | 60.0 | 72.4 |
| DiNAT_s-Base | 49.4 | 50.2 | 61.7 | 59.7 | 63.3 | 73.0 |
| DiNAT_s-Large | 53.4 | 54.6 | 49.4 | 59.2 | 59.8 | 72.1 |
| DiNAT-Mini | 45.8 | 47.2 | 80.1 | 80.4 | 81.5 | 99.0 |
| DiNAT-Tiny | 47.8 | 48.8 | 49.8 | 51.1 | 49.6 | 62.2 |
| DiNAT-Small | 48.9 | 49.9 | 48.6 | 49.6 | 49.2 | 59.3 |
| DiNAT-Base | 49.6 | 50.4 | 47.2 | 48.6 | 47.3 | 60.9 |
| DiNAT-Large | 54.0 | 54.9 | 41.1 | 49.1 | 49.1 | 59.2 |

Comment

Thank you for the results; I will keep my rating.

Comment

We thank you again for your time, feedback, questions, and your rating.

Official Review
Rating: 7

The authors implement a fused neighborhood attention kernel (N-Atten). N-Atten is very useful for reducing the computational cost of various tasks, because tokens in a sequence usually attend to nearby tokens (e.g., Mistral and StreamingLLM).
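The neighborhood pattern summarized here can be pinned down with a small, unfused reference. The sketch below is our own pure-Python illustration of 1-D neighborhood attention without dilation, assuming the commonly used definition in which each query attends to a window of k keys centered on it, shifted inward at the sequence borders so every query keeps exactly k neighbors. The explicit neighbor-index table plays the role of the gather (im2col-style) step that unfused implementations materialize. All names are hypothetical.

```python
import math
from typing import List

Matrix = List[List[float]]


def neighbor_indices(n: int, k: int) -> List[List[int]]:
    """Row i lists the k key positions that query i attends to.

    Assumes an odd window size k <= n. Near the borders the window is shifted
    inward rather than truncated, so every query keeps exactly k neighbors.
    """
    if k % 2 == 0 or k > n:
        raise ValueError("expected an odd window size k <= n")
    table = []
    for i in range(n):
        start = min(max(i - k // 2, 0), n - k)
        table.append(list(range(start, start + k)))
    return table


def neighborhood_attention_1d(q: Matrix, keys: Matrix, v: Matrix, k: int) -> Matrix:
    """Unfused O(n * k * d) 1-D neighborhood attention with 1/sqrt(d) scaling."""
    n, d, dv = len(q), len(q[0]), len(v[0])
    scale = 1.0 / math.sqrt(d)
    out = []
    for i, nbrs in enumerate(neighbor_indices(n, k)):
        # Only k dot products per query, instead of n in full self-attention.
        scores = [scale * sum(a * b for a, b in zip(q[i], keys[j])) for j in nbrs]
        m = max(scores)  # subtract the max for numerical stability
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        out.append([sum(w * v[j][c] for w, j in zip(weights, nbrs)) / total
                    for c in range(dv)])
    return out


if __name__ == "__main__":
    print(neighbor_indices(6, 3))
    # [[0, 1, 2], [0, 1, 2], [1, 2, 3], [2, 3, 4], [3, 4, 5], [3, 4, 5]]
```

With k equal to n every query sees every key, so the function reduces to ordinary softmax attention; the savings come entirely from k being much smaller than n.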

However, previous implementations of N-Atten kernels are very inefficient because they do not (1) utilize hardware tensor acceleration (matmul units such as Tensor Cores) or (2) use a fused softmax attention scheme (flash-attention).

They provide a solution to each problem: (1) by padding the attention matrix (masking the QK matmul), the implementation can use GEMM and is much more easily accelerated by modern hardware; (2) they use fused softmax attention (flash-attention) to compute the attention output.
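The fused-softmax idea referred to in point (2) can be sketched for a single query row. This is a generic, pure-Python illustration of the online-softmax recurrence used by flash-attention-style kernels, not the authors' CUDA implementation; the block size and function names are arbitrary assumptions. In a neighborhood attention kernel, `keys` and `values` would be only the query's neighbors.

```python
import math
from typing import List, Sequence

Vector = Sequence[float]


def reference_attention_row(q: Vector, keys: Sequence[Vector],
                            values: Sequence[Vector]) -> List[float]:
    """Materializes all scores, then applies softmax and the weighted sum."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(scores)
    p = [math.exp(s - m) for s in scores]
    total = sum(p)
    return [sum(pi * v[c] for pi, v in zip(p, values)) / total
            for c in range(len(values[0]))]


def fused_attention_row(q: Vector, keys: Sequence[Vector],
                        values: Sequence[Vector], block: int = 2) -> List[float]:
    """Same result, computed block by block with an online softmax, so the
    full row of scores is never stored."""
    scale = 1.0 / math.sqrt(len(q))
    m = -math.inf                  # running max of the scores seen so far
    denom = 0.0                    # running sum of exp(score - m)
    acc = [0.0] * len(values[0])   # running, unnormalized weighted sum of values
    for start in range(0, len(keys), block):
        ks, vs = keys[start:start + block], values[start:start + block]
        scores = [scale * sum(a * b for a, b in zip(q, k)) for k in ks]
        m_new = max(m, max(scores))
        correction = math.exp(m - m_new)  # rescales earlier partials (0.0 on the first block)
        p = [math.exp(s - m_new) for s in scores]
        denom = denom * correction + sum(p)
        acc = [a * correction + sum(pi * v[c] for pi, v in zip(p, vs))
               for c, a in enumerate(acc)]
        m = m_new
    return [a / denom for a in acc]  # normalize once, at the end
```

Each block only needs the running max, the running denominator, and the accumulator, which is what lets a fused kernel keep the attention weights in registers or shared memory instead of writing them to global memory.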

Strengths

They provide clear procedures for building N-Atten efficiently. The writing and presentation are very clear, especially in Figures 3 and 4. Their implementation is general enough to handle 1-D, 2-D, and 3-D attention.

Weaknesses

The scientific contribution is not novel enough. (1) GEMM-based acceleration is not a unique technique: in many sparse matrix multiplication applications, partially padding and using GEMM units is quite a natural solution (e.g., structured sparsity). (2) Fused softmax attention is now a common acceleration technique (since the release of flash-attention).

Questions

  1. Can you provide the code?
  2. This could be easily implemented in tile-based acceleration languages like OpenAI Triton and Apache TVM Tensor Expression. How about implementing this method in those languages and evaluating its performance on heterogeneous architectures (e.g., AMD ROCm, Intel Gaudi)?

Limitations

The proposed method only supports N-Atten. In recent LLM research (e.g., StreamingLLM), many works propose novel methods to perform linear or sub-quadratic attention using sparse masks. The current implementation does not consider such global or regional attention.

Author Rebuttal

We thank you for your feedback. To answer your questions:

  1. Yes; we can link to or upload an anonymized version of our code if you would like, but our intention is to open-source all of our code and integrate it into the existing neighborhood attention package. Given the volume of the code, it did not seem appropriate or necessary, but we would be happy to scrub the core implementation and share it here.

  2. Yes; you are correct that other frameworks like Triton and TVM would help implement these kernels across architectures, and we fully intend to use such solutions to expand our implementations to more architectures and platforms. We chose CUTLASS and CUDA specifically mostly because of the level of freedom in optimization and customization that CUDA C++ provides, and because we are more familiar with those platforms and their performance optimization process. Our eventual goal is to provide fast, generic, cross-platform solutions for neighborhood attention and, more generally, multi-axis attention, and we certainly would not limit ourselves to one platform or one case. For this work, we mainly wished to show with a proof of concept that the neighborhood attention family of patterns can be reformulated as a more generic multi-axis attention and achieve performance close to a practical baseline used in production (namely xFormers' FMHA). Reaching the performance of newer fused attention kernels, extending to other platforms and architectures, and supporting other attention patterns are definitely on our list of future work. That said, we note that the most performant Flash Attention implementations (the revised v1, v2, and the recent v3) were also based on CUTLASS and CUDA C++, like the GEMM-based and fused kernels introduced in the paper. In fact, the key differentiators of the more recent FAv3, such as persistent warp-specialized kernels, are concepts that existed in CUTLASS before Triton, given that CUTLASS is an open source solution created by NVIDIA in the first place. While these are specific examples, we merely wish to illustrate another advantage of CUTLASS: quicker adoption of NVIDIA's new architectural designs on NVIDIA hardware.
Our aim is to provide the best implementation for different hardware and architectures in order to advance research in this direction, and the hardware we had at our disposal happens to be NVIDIA hardware. But again, your point stands that we should not limit our implementations to specific platforms and hardware, and we will definitely acknowledge this in the paper.

Regarding the weaknesses, we understand that the techniques used in implementing structured sparsity already existed, and we wish to clarify that we absolutely did not intend to take credit for them; we will revise the paper to reflect this. However, to our knowledge, the work is novel in presenting the first multi-dimensional fused attention kernel (one that fuses back-to-back GETTs rather than GEMMs).

With regard to limitations, in theory we can actually support any non-explicit attention mask. Our implementation of neighborhood attention masking in FNA provides a very simple interface through which new arbitrary sparse masks are easy to implement.
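As an illustration of how such a mask can be expressed as a simple predicate, the sketch below encodes the spatio-temporal pattern mentioned in the earlier rebuttal: causal along time, non-causal neighborhood along space. The predicate form, the function names, and the exact causal window (the current frame plus the kt - 1 preceding frames) are our assumptions for illustration, not FNA's actual C++ template interface.

```python
from typing import Tuple

Coord = Tuple[int, int]  # (time, x)


def spatial_window_start(i: int, n: int, k: int) -> int:
    """Start of a non-causal neighborhood of odd size k, shifted inward at the borders."""
    return min(max(i - k // 2, 0), n - k)


def attends(query: Coord, key: Coord, shape: Coord, window: Coord) -> bool:
    """Hypothetical predicate: may query (t, x) attend to key (t', x')?

    shape = (num_frames, width), window = (kt, kx). Causal along time (only the
    current and the kt - 1 previous frames), neighborhood along space.
    """
    (tq, xq), (tk, xk) = query, key
    (_, width), (kt, kx) = shape, window
    causal_time = tq - kt < tk <= tq
    x0 = spatial_window_start(xq, width, kx)
    local_space = x0 <= xk < x0 + kx
    return causal_time and local_space


if __name__ == "__main__":
    shape, window = (4, 5), (2, 3)
    for query in ((0, 2), (3, 2)):
        allowed = [(t, x) for t in range(4) for x in range(5)
                   if attends(query, (t, x), shape, window)]
        print(query, "->", allowed)
```

The first frame only sees itself, while later frames see kt frames of kx spatial neighbors each; swapping the time condition for the spatial one gives the fully non-causal 2-D pattern.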

It is perfectly feasible to connect our C++ template interface to JIT engines (torch.compile) and AOT engines (AITemplate), which can "translate" user-specified masks into FNA masks, somewhat similarly to FlexAttention. We did not mention this in the paper, but are happy to do so.
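A common way fused kernels exploit such masks, similar in spirit to FlexAttention's block masks, is to classify each (query tile, key tile) pair before doing any work: fully masked tiles are skipped, unmasked tiles need no per-element masking, and only partial tiles pay for it. The sketch below assumes a hypothetical predicate interface `mask(query_index, key_index) -> bool` and uses the 1-D neighborhood pattern (windows shifted inward at the borders) as the example. It brute-forces the classification, whereas a real kernel would derive tile bounds analytically.

```python
from enum import Enum
from typing import Callable, Dict

MaskFn = Callable[[int, int], bool]  # hypothetical interface: may query i attend to key j?


class TileKind(Enum):
    EMPTY = "empty"      # every pair masked out: the whole tile can be skipped
    PARTIAL = "partial"  # needs per-element masking
    FULL = "full"        # no masking needed


def classify_tile(mask: MaskFn, q_range: range, k_range: range) -> TileKind:
    allowed = [mask(i, j) for i in q_range for j in k_range]
    if not any(allowed):
        return TileKind.EMPTY
    return TileKind.FULL if all(allowed) else TileKind.PARTIAL


def neighborhood_mask(n: int, k: int) -> MaskFn:
    """1-D neighborhood with odd window k; windows shift inward at the borders."""
    def mask(i: int, j: int) -> bool:
        start = min(max(i - k // 2, 0), n - k)
        return start <= j < start + k
    return mask


def tile_census(mask: MaskFn, n: int, tile: int) -> Dict[TileKind, int]:
    """Counts how many (query tile, key tile) pairs fall into each category."""
    counts = {kind: 0 for kind in TileKind}
    for q0 in range(0, n, tile):
        for k0 in range(0, n, tile):
            kind = classify_tile(mask, range(q0, min(q0 + tile, n)),
                                 range(k0, min(k0 + tile, n)))
            counts[kind] += 1
    return counts


if __name__ == "__main__":
    census = tile_census(neighborhood_mask(n=64, k=7), n=64, tile=8)
    print({kind.value: count for kind, count in census.items()})
```

With n = 64, k = 7, and 8 x 8 tiles, 42 of the 64 tile pairs are empty and can be skipped outright, and the remaining 22 are partial; a dense mask (`lambda i, j: True`) makes every tile full.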

In the future, we can and will support arbitrary non-explicit masks. We did not implement explicit masks mainly for simplicity, but nothing fundamentally prevents us from doing so. We will clarify this further in our revision.

Comment

Thank you for the detailed and kind responses. I read most of your responses (including the one to me), and I want to raise my score, because I now understand how much detailed engineering consideration lies behind this work. I think this paper is good enough to be accepted, because it will give many insights into fused N-Atten applications. I hope these detailed considerations are described in a future revision (in the Appendix).

Comment

We sincerely thank you for your invaluable feedback; we will indeed adjust the writing and add more details following the suggestions from you and the other reviewers. Our goal is not only to accelerate research and inference in these cases, but also to provide context and information on accelerated AI and hopefully benefit future researchers, so that they can realize their ideas unencumbered by the tools available to them.

Official Review
Rating: 7

The paper introduces methods to improve the performance of neighborhood attention mechanisms. The authors present two new implementations: GEMM-based kernels and fused kernels. These implementations aim to reduce the latency and memory footprint of neighborhood attention in deep learning models, particularly in higher-dimensional spaces. The proposed methods show significant speedups in both full and half precision compared to existing naive CUDA kernels.

Strengths

  1. The proposed GEMM-based and fused kernels significantly reduce latency, with average improvements of 895% and 272% in full-precision latency for 1-D and 2-D neighborhood attention, respectively.
  2. The implementations can also reduce the memory consumption of neighborhood attention.
  3. The methods are designed to work efficiently across different spatial ranks (1-D, 2-D, and 3-D), which makes them more generally applicable.

Weaknesses

  1. The performance improvements in half precision are limited due to inefficiencies in gather/scatter operations.
  2. The performance gains are more significant on newer architectures like Ampere and Hopper, limiting the generalizability of the results to other hardware setups.

Questions

  1. What are the implications of the FP16 acceleration for more advanced low-precision formats such as FP8 or even FP6?
  2. While the paper provides significant speedups for neighborhood attention, more discussion of its overall bottlenecks would be interesting, including whether the speed improvements can be further translated into better accuracy.

Limitations

No significant negative societal impacts.

Author Rebuttal

We thank you for your feedback. To respond to your questions:

  1. To clarify, the FP16 performance of our GEMM-based approach is significantly affected by the gather and scatter of attention weights; Fused Neighborhood Attention is not. FNA does not store attention weights to or load them from global memory, so it has no need to gather/scatter them, and as a result does not run into the "under-alignment" issues on modern hardware. Because of this, we spent more time on the fused approach, which does not suffer from this issue and massively improves FP16 performance compared to both the original implementation and our GEMM-based approach. With regard to FP8 and lower precision, we do not believe there are any significant blockers, mainly because Fused Neighborhood Attention as a concept can be applied to any fused attention implementation, including open source FP8 implementations such as Colfax's FP8 Flash Attention and the recent Flash Attention v3.

  2. Thank you for mentioning this; we would be happy to add those discussions and expand further on the bottlenecks and the implications for accuracy. Improving the speed of neighborhood attention alone won't directly affect accuracy, since functionality is unaffected. However, faster implementations like fused neighborhood attention unlock many new features and provide much better scalability, both of which give researchers more flexibility when building their models, and that will definitely help accuracy in the long term.
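The "under-alignment" problem of the GEMM-based approach can be illustrated with simple arithmetic. The sketch below rests on our own assumptions, for illustration only: the attention-weight matrix of 2-D neighborhood attention has rows of k * k elements packed without padding, the base pointer is 16-byte aligned, and 16 bytes (128 bits) is the widest vector access. It finds the widest power-of-two access that keeps every row start aligned. The function name is hypothetical.

```python
def max_access_bytes(row_elems: int, elem_bytes: int, max_vector_bytes: int = 16) -> int:
    """Widest power-of-two access width (in bytes, at most max_vector_bytes) that
    keeps the start of every row aligned, assuming rows are packed back to back
    and the base pointer is max_vector_bytes-aligned."""
    row_bytes = row_elems * elem_bytes
    width = max_vector_bytes
    # A single element is always naturally aligned, so never go below elem_bytes.
    while width > elem_bytes and row_bytes % width != 0:
        width //= 2
    return width


if __name__ == "__main__":
    for kernel_size in (7, 8):
        row = kernel_size * kernel_size  # 2-D window: k * k weights per query
        for name, nbytes in (("FP32", 4), ("FP16", 2)):
            width = max_access_bytes(row, nbytes)
            print(f"k={kernel_size} {name}: {width}-byte accesses ({width // nbytes} element(s))")
```

With a 7 x 7 window, rows are 49 elements long: FP16 falls to 2-byte accesses (1/8 of the 16-byte width) while FP32 keeps 4-byte accesses (1/4). This is one plausible way to see why the loss is felt more in half precision, whereas an 8 x 8 window would keep full-width accesses in both.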

With regard to the weaknesses, we would like to clarify that:

  1. The performance improvements of the GEMM-based approach are limited in FP16 due to the gather/scatter operations, and for that reason we actually "recommend" that researchers working on local attention kernels move away from BMM-style implementations and work on fused attention directly instead. The memory alignment issue will affect any BMM-style multi-axis local attention kernel, and the limitation is hardware-related.

  2. To clarify, FNA supports all NVIDIA architectures since Pascal, natively targets all architectures up to and including Ampere, and in theory can also target Ada Lovelace. Natively targeting Hopper's Tensor Cores, TMA, and programming model, along with other architectures, is on our list of future work. Our intention is to provide the best possible implementation to accelerate research in this direction; thus far we have only worked with NVIDIA hardware and only had NVIDIA hardware at our disposal. That said, it is possible to extend our implementations to ROCm (AMD GPUs), Metal (Apple Silicon), and more, and, as other reviewers have suggested, we may also look into Triton-based implementations to get there more quickly.

Comment

I appreciate the response from the authors, especially the clarification on the applicable architecture. I would like to raise the score.

Comment

We sincerely thank you for your feedback and your rating.

Author Rebuttal

We thank the reviewers for their time and their valuable feedback and suggestions.

We've posted individual rebuttals, and hope to have answered their questions and concerns.

Please let us know if there are any more questions, and we would be happy to elaborate further.

Final Decision

The paper sits at the interface of systems hardware and neural architectures, and provides optimized fused CUDA kernels that significantly speed up neighborhood attention variants (such as the popular sliding window attention), a key element of modern transformer models. Reviewers valued the principled engineering work performed here and the level of efficiency gains.

For the camera-ready version, we urge the authors to incorporate the mentioned feedback by the reviewers.