FlashAttention is an IO-aware algorithm that computes exact attention without ever writing the full attention matrix to GPU high-bandwidth memory. By tiling the computation into blocks that fit in fast on-chip SRAM and using an online softmax, it reduces memory traffic and turns attention's memory cost from quadratic to linear in sequence length.
What is FlashAttention?
FlashAttention is an algorithm for computing the exact self-attention of a transformer faster and with far less memory by being aware of how data moves through the GPU memory hierarchy. Introduced by Tri Dao and colleagues in 2022, it produces the same result as standard attention, up to floating-point rounding, but avoids the dominant cost of standard implementations: repeatedly reading and writing the large intermediate attention matrix to slow GPU high-bandwidth memory (HBM).
The central observation is that attention on modern GPUs is memory-bound, not compute-bound. The matrices of attention scores and probabilities scale with the square of the sequence length, and moving them between fast on-chip SRAM and slow HBM dominates the runtime. FlashAttention restructures the computation so these intermediate matrices are never fully written out.
- Computes exact attention, not an approximation.
- Reduces HBM memory access, which is the real bottleneck on GPUs.
- Cuts attention's memory footprint from quadratic to linear in sequence length.
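To make the quadratic-versus-linear claim concrete, here is a rough back-of-the-envelope sketch. It assumes a single attention head, 2-byte (fp16) score entries, and one 4-byte statistic saved per query row; the function names are illustrative, and the query, key, value, and output tensors (which are linear in sequence length either way) are left out.

```python
def score_matrix_bytes(seq_len, bytes_per_element=2):
    """Bytes needed to store a full seq_len x seq_len score matrix (assumed fp16)."""
    return seq_len * seq_len * bytes_per_element

def row_statistics_bytes(seq_len, bytes_per_element=4):
    """Bytes needed to store one softmax statistic per query row (assumed fp32)."""
    return seq_len * bytes_per_element

for n in (1024, 4096, 16384):
    full = score_matrix_bytes(n) / 2**20   # MiB
    stats = row_statistics_bytes(n) / 2**10  # KiB
    print(f"{n:>6} tokens: {full:>5.0f} MiB score matrix vs {stats:>3.0f} KiB row statistics")
```

Under these assumptions, a 16,384-token sequence needs 512 MiB per head for the full score matrix but only 64 KiB for the per-row statistics.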
How IO-aware tiling works
FlashAttention splits the query, key, and value matrices into blocks. It loads a block of queries and iterates over blocks of keys and values, computing partial attention outputs entirely inside fast SRAM. Because softmax normalization requires the maximum and the sum of exponentials over the whole row, FlashAttention uses an online softmax that updates a running maximum and a running normalizer as each new key block is processed, then rescales the accumulated output. This lets it produce the correct normalized result without holding the entire score row in memory.
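The block-wise loop and online softmax described above can be sketched in NumPy. This is a minimal illustration of the arithmetic only, not the actual CUDA kernel: the function names and block sizes are assumptions, it handles a single head without masking, and NumPy arrays stand in for SRAM tiles.

```python
import numpy as np

def tiled_attention(Q, K, V, block_q=32, block_k=32):
    """Exact attention computed tile by tile with an online softmax."""
    n, d = Q.shape
    scale = 1.0 / np.sqrt(d)
    O = np.zeros((n, V.shape[1]))
    for qs in range(0, n, block_q):
        Qb = Q[qs:qs + block_q]
        m = np.full(Qb.shape[0], -np.inf)          # running row maximum
        l = np.zeros(Qb.shape[0])                  # running sum of exponentials
        acc = np.zeros((Qb.shape[0], V.shape[1]))  # unnormalized output
        for ks in range(0, K.shape[0], block_k):
            S = (Qb @ K[ks:ks + block_k].T) * scale  # scores for this tile only
            m_new = np.maximum(m, S.max(axis=1))
            correction = np.exp(m - m_new)           # rescales earlier partial results
            P = np.exp(S - m_new[:, None])
            l = l * correction + P.sum(axis=1)
            acc = acc * correction[:, None] + P @ V[ks:ks + block_k]
            m = m_new
        O[qs:qs + block_q] = acc / l[:, None]        # normalize once per query block
    return O

def reference_attention(Q, K, V):
    """Standard attention that materializes the full score matrix."""
    S = (Q @ K.T) / np.sqrt(Q.shape[1])
    P = np.exp(S - S.max(axis=1, keepdims=True))
    return (P / P.sum(axis=1, keepdims=True)) @ V

rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((100, 16)) for _ in range(3))
assert np.allclose(tiled_attention(Q, K, V), reference_attention(Q, K, V))
```

The only state carried between key blocks is `m`, `l`, and `acc`, each with one row per query in the block, so no buffer ever grows with the square of the sequence length.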
For the backward pass, FlashAttention does not cache the full attention matrix. It saves only the per-row softmax statistics from the forward pass and recomputes the attention scores block by block from the queries and keys. This recomputation is cheaper than the memory traffic it replaces, which is why the method is faster overall despite doing more arithmetic.
- Tiling keeps working data in SRAM, which has far higher bandwidth than HBM.
- An online softmax merges block-wise results without materializing the full score matrix.
- The backward pass recomputes scores from the queries, keys, and saved row statistics instead of storing the attention matrix.
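The recomputation idea can be illustrated with a short NumPy sketch. It assumes the saved statistic is the per-row log-sum-exp of the scaled scores, and the helper names are hypothetical. For brevity the statistic is computed densely here, whereas a tiled forward pass would obtain it from its running maximum and running sum.

```python
import numpy as np

def row_logsumexp(Q, K):
    """Per-row log-sum-exp of the scaled scores (the statistic assumed to be saved)."""
    S = (Q @ K.T) / np.sqrt(Q.shape[1])
    m = S.max(axis=1)
    return m + np.log(np.exp(S - m[:, None]).sum(axis=1))

def recompute_block(Q_blk, K_blk, lse_blk):
    """Rebuild one tile of attention probabilities from Q, K, and the saved statistic."""
    S = (Q_blk @ K_blk.T) / np.sqrt(Q_blk.shape[1])
    return np.exp(S - lse_blk[:, None])

rng = np.random.default_rng(0)
Q, K = rng.standard_normal((64, 16)), rng.standard_normal((64, 16))
lse = row_logsumexp(Q, K)

# Dense softmax, used only to check the recomputed tile.
S = (Q @ K.T) / np.sqrt(16)
P_full = np.exp(S - S.max(axis=1, keepdims=True))
P_full /= P_full.sum(axis=1, keepdims=True)

tile = recompute_block(Q[0:16], K[32:48], lse[0:16])
assert np.allclose(tile, P_full[0:16, 32:48])
```

Any tile of the probability matrix can be rebuilt on demand this way, so the backward pass needs only one saved number per query row rather than one per query-key pair.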
Using FlashAttention in practice
FlashAttention ships as a CUDA library and is integrated into PyTorch's scaled dot product attention (torch.nn.functional.scaled_dot_product_attention), as well as Hugging Face Transformers and most major training stacks. In PyTorch the kernel is selected automatically when the inputs meet its requirements, such as a supported GPU and half-precision (fp16 or bf16) tensors, and it can be requested explicitly through the torch.nn.attention.sdpa_kernel context manager. The official flash-attention package also exposes the kernel directly.
Impact and later versions
The original FlashAttention paper reported concrete speedups: roughly 15 percent faster BERT-large training at sequence length 512, about 3x faster GPT-2 training at sequence length 1K, and a 2.4x speedup on the Long Range Arena benchmark. It also made sequence lengths of 16K and 64K tractable. Because the algorithm produces exact results, it can be dropped in without changing model quality.
Successor versions improved hardware utilization further. FlashAttention-2 reorganized the work partitioning to make better use of GPU compute units, and FlashAttention-3 added optimizations for newer Hopper-class GPUs, including lower-precision (FP8) arithmetic and asynchronous execution. The approach is now standard in production LLM training and inference.
- Original paper: 3x GPT-2 training speedup and tractable 16K to 64K contexts.
- FlashAttention-2 improved parallelism and work partitioning.
- FlashAttention-3 targets Hopper GPUs with lower precision and async execution.
Key takeaways
- FlashAttention computes exact attention but avoids writing the full attention matrix to slow GPU memory.
- Tiling plus an online softmax keeps intermediate results in fast on-chip SRAM, reducing memory traffic.
- It changes attention memory from quadratic to linear in sequence length, enabling much longer contexts.
- It is integrated into PyTorch SDPA and Hugging Face, with FlashAttention-2 and -3 adding further GPU optimizations.