July 11th, 2024

FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-Precision

FlashAttention-3, a new implementation of exact attention for Transformers, speeds up attention on Hopper GPUs, reaching up to 75% of the H100's theoretical peak throughput. By exploiting asynchrony and low-precision computation, it runs 1.5-2x faster than its predecessor, and FP8 support enables quicker computation at lower cost. FlashAttention-3 is optimized for new Hopper hardware features, improving efficiency and AI capabilities. Integration into PyTorch is planned.

FlashAttention-3, a new advance in attention kernels for Transformer architectures, has been introduced to speed up attention on Hopper GPUs. By exploiting asynchrony and low-precision computation, FlashAttention-3 reaches up to 75% utilization of the H100 GPU's theoretical peak throughput, making it 1.5-2x faster than its predecessor, FlashAttention-2. FP8 support allows faster computation while maintaining accuracy, potentially reducing memory usage and cost for large-scale AI workloads. Faster attention also lets models handle longer text sequences efficiently, improving performance on tasks that require long-context understanding. By targeting Hopper hardware features such as WGMMA (warpgroup matrix multiply-accumulate), TMA (Tensor Memory Accelerator), and FP8, FlashAttention-3 reaches up to 1.2 PFLOPS in FP8 mode. These advances point to greater efficiency and expanded capabilities for AI applications, and integration into PyTorch is planned.
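
For a sense of what FP8 trades away, here is a pure-Python sketch that simulates rounding a float to E4M3, one of the two FP8 formats Hopper supports (4 exponent bits, 3 mantissa bits). It assumes round-to-nearest-even with saturation at ±448; the name `round_to_e4m3` is hypothetical, and this is an illustration of the number format, not how FlashAttention-3 quantizes its inputs.

```python
import math

E4M3_MAX = 448.0  # largest finite E4M3 value: 1.75 * 2**8


def round_to_e4m3(x: float) -> float:
    """Round x to the nearest FP8 E4M3 value (assumed: ties-to-even, saturating)."""
    if x == 0.0 or math.isnan(x):
        return x
    if math.isinf(x):
        return math.copysign(E4M3_MAX, x)  # assumed saturating conversion
    mag = abs(x)
    _, e = math.frexp(mag)  # mag = m * 2**e with 0.5 <= m < 1
    # With 3 mantissa bits, each binade [2**k, 2**(k+1)) holds 8 evenly spaced
    # values, so the spacing is 2**(k - 3) where k = e - 1. Below 2**-6
    # (the subnormal range) the spacing stays fixed at 2**-9.
    step = 2.0 ** max(e - 4, -9)
    q = round(mag / step) * step  # Python's round() uses round-half-to-even
    return math.copysign(min(q, E4M3_MAX), x)


values = [0.1, 1.0625, 3.3, 500.0]
print([round_to_e4m3(v) for v in values])  # [0.1015625, 1.0, 3.25, 448.0]
```

Between 1 and 2 the representable values are 0.125 apart, and anything above 448 saturates, which is why FP8 pipelines typically apply a scaling factor to bring tensors into range before conversion.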

12 comments
By @refibrillator - 3 months
The code has a comment which seems to hint that Tri Dao was working on FA3 as early as April 2022, the month after Hopper/H100 was announced. I find it mildly curious that over 2 years have elapsed before the code was released today. Perhaps it’s because there are now better solutions in the pipeline?

Tri’s publication history has been leaning toward SSM and Mamba-style architectures recently. Unlike Flash Attention, which has quadratic time complexity wrt sequence length, these latest algorithms are subquadratic. Thus they do much less computation, instead of just doing it more efficiently à la Flash Attention.

Dao and Gu published a really long paper this year which demonstrated (among other things) how Mamba/SSM can be formulated such that it’s amenable to acceleration using the same hardware primitives that Transformers benefit from.
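
A toy sketch of both points, assuming a single-channel SSM with fixed scalar parameters `a`, `b`, `c` (real Mamba layers use input-dependent, multi-dimensional parameters; the function names here are hypothetical). The recurrence costs O(n) in sequence length n, whereas writing the same map as a causally masked n×n matrix costs O(n²) but turns it into dense matrix math of the kind Transformer kernels already accelerate.

```python
def ssm_scan(x, a, b, c):
    """Toy scalar SSM as a recurrence: h_t = a*h_{t-1} + b*x_t, y_t = c*h_t.

    One pass over the sequence, so O(n) work.
    """
    h = 0.0
    ys = []
    for xt in x:
        h = a * h + b * xt
        ys.append(c * h)
    return ys


def ssm_matrix(x, a, b, c):
    """The same toy SSM as a lower-triangular (causally masked) matrix product.

    M[t][s] = c * a**(t - s) * b for s <= t, else 0. Building M is O(n^2),
    the same shape as an attention score matrix.
    """
    n = len(x)
    M = [[c * a ** (t - s) * b if s <= t else 0.0 for s in range(n)] for t in range(n)]
    return [sum(M[t][s] * x[s] for s in range(n)) for t in range(n)]


x = [1.0, 2.0, 3.0]
print(ssm_scan(x, a=0.5, b=1.0, c=2.0))    # [2.0, 5.0, 8.5]
print(ssm_matrix(x, a=0.5, b=1.0, c=2.0))  # [2.0, 5.0, 8.5]
```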

By @edude03 - 3 months
How much is the flash attention algorithm tied to the hardware? For example, in this announcement they mention taking advantage of the async capabilities of the H100 GPUs, which I assume means you won't get those speedups on non-H-series cards. Second, the actual flash attention library requires CUDA, although the algorithm has apparently?[^0] been ported to Metal. I would imagine that if the algorithm were literally just a pure function, it could be implemented for any GPU/ML framework?

[0]: https://github.com/philipturner/metal-flash-attention
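
On the "pure function" point: FlashAttention computes exactly the same result as standard attention; what changes is the order of evaluation. A minimal pure-Python sketch of the tiling idea, assuming a single query vector and no batching, heads, or masking (the function names are hypothetical): keys and values are visited one block at a time with an "online softmax" that rescales partial results, and the output matches the naive computation up to floating-point rounding. The hardware-specific part, such as keeping blocks in on-chip memory and, on Hopper, overlapping TMA loads, WGMMA matmuls and softmax, is exactly what this sketch leaves out.

```python
import math


def naive_attention(q, K, V):
    """Reference: softmax(q . K^T / sqrt(d)) @ V for a single query vector q."""
    d = len(q)
    scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in K]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    return [sum(w * v[j] for w, v in zip(weights, V)) / total for j in range(len(V[0]))]


def tiled_attention(q, K, V, block_size=2):
    """Same result, visiting K/V block by block with an online softmax.

    Only a running max (m), running denominator (l) and running unnormalized
    output (acc) are kept, so the full row of scores is never stored.
    """
    d = len(q)
    m = -math.inf
    l = 0.0
    acc = [0.0] * len(V[0])
    for start in range(0, len(K), block_size):
        K_blk = K[start:start + block_size]
        V_blk = V[start:start + block_size]
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in K_blk]
        m_new = max(m, max(scores))
        scale = math.exp(m - m_new)  # rescale old statistics to the new max (0.0 on first block)
        l *= scale
        acc = [a * scale for a in acc]
        for s, v in zip(scores, V_blk):
            p = math.exp(s - m_new)
            l += p
            acc = [a + p * vj for a, vj in zip(acc, v)]
        m = m_new
    return [a / l for a in acc]


q = [1.0, 0.0]
K = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 0.0], [0.5, 0.5]]
V = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0]]
ref = naive_attention(q, K, V)
for bs in (1, 2, 3, 5):
    out = tiled_attention(q, K, V, block_size=bs)
    assert all(abs(a - b) < 1e-12 for a, b in zip(out, ref))
```

The block size only changes how much work is done per step, not the answer, which is why the same math can be ported to Metal, ROCm or anything else; the speed comes from how each platform keeps the blocks on-chip.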

By @WanderPanda - 3 months
Compiler folks: Is there any chance compilers will be able to find optimizations like FlashAttention on their own? Seems like TVM and tinygrad are working in that direction, but I find it hard to believe that would be feasible.

By @latchkey - 3 months
If anyone wants to port this over to ROCm / AMD MI300X, reach out to me: hello@hotaisle.xyz (we won't ever spam you).

Happy to donate the compute time for this work.

By @lxe - 3 months
> FlashAttention-3 is optimized for Hopper GPUs (e.g. H100).

How does FA3 fare for consumer GPUs such as 3090 and 4090?

By @saagarjha - 3 months
> TMA (Tensor Memory Accelerator). This is a special hardware unit that accelerates the transfer of data between global memory and shared memory, taking care of all index calculation and out-of-bound predication. This frees up registers, which is a valuable resource to increase tile size and efficiency.

My understanding was that while it frees up registers, more importantly it lets the hardware handle address generation, which can become a bottleneck as the operations around it become faster.
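
Roughly, the work TMA absorbs looks like the per-element index arithmetic and bounds checking below, written as a hypothetical pure-Python model (not a CUDA API) of loading one 2D tile from a row-major matrix. Without TMA, each thread spends registers and instructions on this for every element it loads; with TMA, one bulk copy request covers the whole tile and the hardware unit does this arithmetic.

```python
def tile_offsets(row0, col0, tile_rows, tile_cols, n_rows, n_cols):
    """Hypothetical model of per-element address generation for one 2D tile.

    For a row-major matrix of shape (n_rows, n_cols), returns one
    (linear_offset, in_bounds) pair per tile element in row-major tile order.
    Out-of-bounds elements get offset -1 and would be skipped or zero-filled.
    """
    pairs = []
    for i in range(tile_rows):
        for j in range(tile_cols):
            r, c = row0 + i, col0 + j
            in_bounds = r < n_rows and c < n_cols  # the "predication" check
            pairs.append((r * n_cols + c if in_bounds else -1, in_bounds))
    return pairs


# A 2x2 tile hanging off the bottom-right corner of a 3x3 matrix:
print(tile_offsets(2, 2, 2, 2, 3, 3))
# [(8, True), (-1, False), (-1, False), (-1, False)]
```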

By @Der_Einzige - 3 months
This is one of the most important improvements in all of AI, because it benefits most AI users by giving them access to more, faster, for the same hardware with little to no tradeoffs.

By @ex3ndr - 3 months
I am wondering why flash attention is something like 5x slower with variable masking than without it. Lack of good masking support almost zeroes out the optimizations.

By @andy_xor_andrew - 3 months
hoping an expert can answer a few Qs I have :)

Is FlashAttention simply a drop-in replacement for the attention operation in an LLM? Can it be used anywhere that an "attention" operation is used? Or does an LLM need to be trained specially to use FA?

How does FA relate to attention strategies like GQA (grouped query attention) or sliding-window attention? Are they orthogonal concepts? Or you need a specific FA implementation for each strategy?

Recently llama.cpp added flash attention support - does this just mean they started consuming a flash attention-provided CUDA kernel or something?

lastly, in this post, they compare FlashAttention to Triton. I thought Triton was like an abstraction layer? Couldn't FA be implemented in Triton? I just don't really get what it means to say "FlashAttention vs. Triton".

By @LarsDu88 - 3 months
I was wondering... this post mentions that ops like sigmoid are very slow.

A lot of modern LLMs use activation functions that involve a sigmoid or softmax, like SiLU, Swish, and SoLU.

Does ReLU take less of a performance hit, and if so, maybe it'd be better to go back to good old ReLU?
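
For reference, the definitions behind that list, as a plain-Python sketch: SiLU is x·sigmoid(x), Swish is x·sigmoid(βx) (SiLU is the β = 1 case), and SoLU is x·softmax(x) over a vector. Each needs one exponential per element, whereas ReLU is a single comparison. This only shows the arithmetic involved; it says nothing about how any of these perform on a GPU.

```python
import math


def relu(x):
    return max(0.0, x)  # a single comparison, no transcendental function


def sigmoid(x):
    # Numerically stable logistic function; either branch costs one exp().
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def swish(x, beta=1.0):
    return x * sigmoid(beta * x)  # Swish; beta = 1 is exactly SiLU


def silu(x):
    return x * sigmoid(x)


def solu(xs):
    # SoLU acts on a whole vector: x * softmax(x), one exp() per element plus a sum.
    m = max(xs)
    exps = [math.exp(v - m) for v in xs]
    total = sum(exps)
    return [v * e / total for v, e in zip(xs, exps)]


print(relu(-2.0), silu(0.0), solu([1.0, 1.0]))  # 0.0 0.0 [0.5, 0.5]
```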

By @localfirst - 3 months
spoiler: $xxx,xxx hardware required to run