August 8th, 2024

FlexAttention: The Flexibility of PyTorch with the Performance of FlashAttention

FlexAttention is a new PyTorch API that enhances flexibility and performance in attention mechanisms, allowing users to implement various attention variants efficiently while leveraging existing infrastructure and improving performance through sparsity.

FlexAttention is a new PyTorch API designed to make attention mechanisms both flexible and fast. Existing attention implementations are heavily optimized, but researchers who want to try a new attention variant often have to write a custom kernel. FlexAttention addresses this by letting users express a variant in a few lines of PyTorch code, which it compiles into a fused FlashAttention kernel that stays competitive in performance and adds no extra memory overhead. A variant is defined by a user-supplied function that modifies the attention scores. This covers causal attention, relative positional embeddings, and more, and it lets users combine variants while reusing existing PyTorch infrastructure.

FlexAttention can also exploit sparsity in attention masks, which yields significant performance improvements. The authors expect the API to make new attention variants much easier to explore. Examples and applications can be found in the Attention Gym repository.

- FlexAttention enhances flexibility in implementing various attention mechanisms in PyTorch.

- It compiles user-defined attention variants into efficient FlashAttention kernels.

- The API allows for modifications of attention scores and supports multiple attention variants.

- It leverages sparsity in attention masks for improved performance.

- FlexAttention aims to foster innovation in machine learning research.
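
The sparsity point can be illustrated with a toy count, assuming a mask function of the form `mask_mod(b, h, q_idx, kv_idx) -> bool` and a hypothetical block size. Any block of the score matrix in which every position is masked needs no computation at all. The helper names below are made up for this example, and the loops are only a stand-in for how a real kernel would track which blocks to skip.

```python
def causal_mask(b, h, q_idx, kv_idx):
    """mask_mod for causal attention: True where the query may see the key."""
    return q_idx >= kv_idx

def block_has_work(mask_mod, q_block, kv_block, block_size):
    """Hypothetical helper: True if any position in the block is unmasked."""
    q0 = q_block * block_size
    k0 = kv_block * block_size
    return any(
        mask_mod(0, 0, q_idx, kv_idx)
        for q_idx in range(q0, q0 + block_size)
        for kv_idx in range(k0, k0 + block_size)
    )

def count_active_blocks(mask_mod, seq_len, block_size):
    """Count blocks that need computing out of all blocks in the score matrix."""
    n = seq_len // block_size
    active = sum(
        block_has_work(mask_mod, i, j, block_size)
        for i in range(n)
        for j in range(n)
    )
    return active, n * n

active, total = count_active_blocks(causal_mask, seq_len=512, block_size=64)
print(f"{active} of {total} blocks need computing")  # 36 of 64
```

For a causal mask, a little under half of the blocks sit entirely above the diagonal and can be skipped, which is where the performance gain from mask sparsity comes from.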

Related

The Illustrated Transformer

Jay Alammar's blog post explains the Transformer model and how its attention mechanism speeds up training. The Transformer outperforms Google's Neural Machine Translation model on some tasks, and its main advantage is that it parallelizes well. The post breaks down components such as self-attention and multi-headed attention to make them easier to understand.

FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-Precision

FlashAttention-3 speeds up attention on Hopper GPUs, running 1.5-2x faster than FlashAttention-2 and reaching up to 75% utilization of the H100's theoretical peak throughput. It exploits asynchrony and low-precision computation, using FP8 for faster and cheaper computation, and is tuned for new hardware features. Integration into PyTorch is planned.

Black Forest Labs – FLUX.1 open weights SOTA text to image model

Black Forest Labs has launched to develop generative deep learning models for media, securing $31 million in funding. Their FLUX.1 suite includes three model variants, outperforming competitors in image synthesis.

The open weight Flux text to image model is next level

Black Forest Labs has launched Flux, the largest open-source text-to-image model with 12 billion parameters, available in three versions. It features enhanced image quality and speed, alongside the release of AuraSR V2.

Navigating the Abstract: The Latent Space and the Abstract Ladder

The article explores the relationship between latent space and the abstract ladder in human-AI collaboration, emphasizing creativity, user engagement, and the role of spaced repetition systems in enhancing understanding and innovation.

8 comments
By @chillee - 7 months
Hi, I'm one of the authors of this blog post (Horace He), along with Driss Guessous, Yanbo Liang, and Joy Dong.

We’re quite happy with this abstraction - happy to answer any questions about it!

By @visarga - 7 months
It's interesting that optimizing a computation that can be described in a single line of math takes so much work. It took forever even to discover FlashAttention. And in the 6 years since transformers were invented, thousands of papers have worked on making them faster.

Attention(Q,K,V) = Softmax(Q*K^T/sqrt(d_k))*V

FlexAttention seems to have found the right abstraction for the task.

By @brrrrrm - 7 months
For most LLM workloads today (short text chats), hundreds or a couple thousand tokens suffice, and attention doesn't dominate (< 30% of compute). But as the modalities inevitably grow, work on attention approximation/compression is going to be paramount.

Nice to see PyTorch already elegantly supporting this next step in research.

By @hi_hi - 7 months
I didn't see any notice of this being CUDA-only (like FlashAttention). I tried running it on my Mac M3 with Python 3.11.8, following the quickstart (with the deviation of running it in a new venv). Got the following error:

/attention-gym/.venv/lib/python3.11/site-packages/torch/_subclasses/functional_tensor.py:258: UserWarning: Failed to initialize NumPy: No module named 'numpy' (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/utils/tensor_numpy.cpp:84.)
  cpu = _conversion_method_template(device=torch.device("cpu"))
Traceback (most recent call last):
  File "/attention-gym/attn_gym/masks/document_mask.py", line 7, in <module>
    from torch.nn.attention.flex_attention import _mask_mod_signature
ModuleNotFoundError: No module named 'torch.nn.attention.flex_attention'

By @alecco - 7 months
> FlexAttention achieves 90% of FlashAttention2’s performance in the forward pass and 85% in the backward pass.

It's very good. But note FlashAttention-3 is 1.5x - 2x faster than FlashAttention-2.

By @gchamonlive - 7 months
I've always been curious to put something together with PyTorch, but either the learning curve seemed steep or there wasn't a big motivator (a project, a problem to solve, something in my daily routine to optimize).

Does anybody have a good starting point for learning through hands-on projects that could also accommodate FlexAttention?

By @andy12_ - 7 months
This is so cool. I want to try to implement something with this right now.

By @barrenko - 7 months
Can someone do a short summary or TL;DR for this?