FlashAttention

By Arpit Tripathi, Founder

FlashAttention is an IO-aware algorithm that computes exact attention without ever writing the full attention matrix to GPU high-bandwidth memory. By tiling the computation into blocks that fit in fast on-chip SRAM and using an online softmax, it reduces memory traffic and turns attention's memory cost from quadratic to linear in sequence length.

What is FlashAttention?

FlashAttention is an algorithm for computing exact transformer self-attention faster and with far less memory, by accounting for how data moves through the GPU memory hierarchy. Introduced by Tri Dao and colleagues in 2022, it produces the same result as standard attention, up to floating-point rounding, but avoids the dominant cost of standard implementations: repeatedly reading and writing the large intermediate attention matrix to slow GPU high-bandwidth memory (HBM).

The central observation is that attention on modern GPUs is memory-bound, not compute-bound. The matrices of attention scores and probabilities scale with the square of the sequence length, and moving them between fast on-chip SRAM and slow HBM dominates the runtime. FlashAttention restructures the computation so these intermediate matrices are never fully written out.

  • Computes exact attention, not an approximation.
  • Reduces HBM memory access, which is the real bottleneck on GPUs.
  • Cuts attention's memory footprint from quadratic to linear in sequence length.
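To put numbers on the quadratic-versus-linear claim, the short calculation below assumes the tensor shapes from the PyTorch example later in this article (batch 2, 16 heads, sequence length 4096) with 2-byte bf16 scores, and assumes a tiled kernel keeps two fp32 statistics per query row instead. The helper `attention_extra_memory` is hypothetical and counts only one score matrix; a naive implementation that also stores the probability matrix would need twice as much.

```python
def attention_extra_memory(batch, heads, seq_len, elem_bytes=2, stat_bytes=4):
    """Return (bytes for a full score matrix, bytes for per-row softmax stats).

    Hypothetical helper: compares one N x N score matrix per head (what
    standard attention materializes) with two per-row statistics, the
    running max m and normalizer l, that a tiled kernel keeps instead.
    """
    full_scores = batch * heads * seq_len * seq_len * elem_bytes
    row_stats = batch * heads * seq_len * 2 * stat_bytes
    return full_scores, row_stats


scores, stats = attention_extra_memory(batch=2, heads=16, seq_len=4096)
print(f"full score matrix: {scores / 2**30:.2f} GiB")  # prints "full score matrix: 1.00 GiB"
print(f"row statistics:    {stats / 2**20:.2f} MiB")   # prints "row statistics:    1.00 MiB"

# Doubling the sequence length quadruples the first term but only doubles the second.
scores2, stats2 = attention_extra_memory(batch=2, heads=16, seq_len=8192)
assert scores2 == 4 * scores and stats2 == 2 * stats
```

Per head and per sequence, the score matrix grows with N squared while the statistics grow with N, which is the gap FlashAttention exploits.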

How IO-aware tiling works

FlashAttention splits the query, key, and value matrices into blocks. It loads a block of queries and iterates over blocks of keys and values, computing partial attention outputs entirely inside fast SRAM. Because softmax normalization requires the maximum and the sum of exponentials over the whole row, FlashAttention uses an online softmax that updates a running maximum and a running normalizer as each new key block is processed, then rescales the accumulated output. This lets it produce the correct normalized result without holding the entire score row in memory.
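A minimal NumPy sketch of the tiled forward pass follows, assuming a single attention head, no masking, the usual 1/sqrt(d) score scaling, and arbitrary block sizes. The function names are hypothetical, and the arrays stand in for SRAM tiles only conceptually: this models the arithmetic of the algorithm, not the memory behavior of the real CUDA kernel. Only one block_q x block_k tile of scores exists at a time, and the final check compares the result with a standard implementation that builds the full score matrix.

```python
import numpy as np


def standard_attention(q, k, v):
    """Reference: materializes the full N x N score matrix."""
    s = q @ k.T / np.sqrt(q.shape[-1])
    p = np.exp(s - s.max(axis=-1, keepdims=True))
    return (p / p.sum(axis=-1, keepdims=True)) @ v


def tiled_attention(q, k, v, block_q=16, block_k=16):
    """Block-wise attention with an online softmax (single head, no mask)."""
    n, d = q.shape
    scale = 1.0 / np.sqrt(d)
    out = np.empty((n, v.shape[1]))
    for i in range(0, n, block_q):
        qi = q[i:i + block_q]
        rows = qi.shape[0]
        m = np.full(rows, -np.inf)          # running row max
        l = np.zeros(rows)                  # running softmax denominator
        acc = np.zeros((rows, v.shape[1]))  # unnormalized output accumulator
        for j in range(0, k.shape[0], block_k):
            s = qi @ k[j:j + block_k].T * scale  # scores for this tile only
            m_new = np.maximum(m, s.max(axis=1))
            alpha = np.exp(m - m_new)            # rescales the old state
            p = np.exp(s - m_new[:, None])
            l = alpha * l + p.sum(axis=1)
            acc = alpha[:, None] * acc + p @ v[j:j + block_k]
            m = m_new
        out[i:i + block_q] = acc / l[:, None]    # final normalization
    return out


rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((64, 32)) for _ in range(3))
assert np.allclose(tiled_attention(q, k, v), standard_attention(q, k, v))
```

Changing block_q or block_k only changes the order of floating-point operations, so the output should match the reference up to rounding for any block size, including ones that do not divide the sequence length.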

For the backward pass, FlashAttention does not cache the full attention matrix. Instead, it recomputes the attention scores block by block from Q, K, and the softmax statistics saved during the forward pass. This recomputation costs extra FLOPs but is cheaper than the HBM traffic it replaces, which is why the method is faster overall despite doing more arithmetic.
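The recomputation idea can be checked in a few lines. This NumPy sketch assumes the saved statistic is the row-wise log-sum-exp, lse = m + log(l), which folds the running max and normalizer into one number per row; the function names are hypothetical, and the forward pass builds the full score matrix only for brevity. Given lse, any tile of the probability matrix can be rebuilt from the matching blocks of Q and K alone, since exp(s - lse) = exp(s - m) / l.

```python
import numpy as np


def forward_with_lse(q, k, v):
    """Forward pass that keeps only the output and one statistic per row.

    The full score matrix is built here only for brevity; a real kernel tiles it.
    """
    s = q @ k.T / np.sqrt(q.shape[-1])
    m = s.max(axis=1)
    l = np.exp(s - m[:, None]).sum(axis=1)
    lse = m + np.log(l)                  # O(N) statistic saved for backward
    out = np.exp(s - lse[:, None]) @ v
    return out, lse                      # the probability matrix is discarded


def recompute_probs_block(q_blk, k_blk, lse_blk):
    """Rebuild one tile of the probability matrix from Q, K and saved lse."""
    s = q_blk @ k_blk.T / np.sqrt(q_blk.shape[-1])
    return np.exp(s - lse_blk[:, None])


rng = np.random.default_rng(1)
q, k, v = (rng.standard_normal((32, 16)) for _ in range(3))
out, lse = forward_with_lse(q, k, v)

# Reference probabilities computed the standard way, for checking only.
s_full = q @ k.T / np.sqrt(16)
p_ref = np.exp(s_full - s_full.max(axis=1, keepdims=True))
p_ref /= p_ref.sum(axis=1, keepdims=True)

# Recompute an arbitrary 8 x 8 tile, as a backward pass would.
tile = recompute_probs_block(q[8:16], k[16:24], lse[8:16])
assert np.allclose(tile, p_ref[8:16, 16:24])
```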

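Online softmax update (per query row, for each new key/value block):
m_new = max(m_old, max(s_block))
l_new = exp(m_old - m_new)·l_old + sum(exp(s_block - m_new))
O_new = exp(m_old - m_new)·O_old + exp(s_block - m_new)·V_block
After the last block: O = O_new / l_new

Here s_block holds the scaled scores of the current query rows against the current key block; m starts at -inf, and l and O start at zero. As each key/value block is processed, the running max m, the running normalizer l, and the unnormalized output accumulator O are rescaled, and the final division by l yields the standard softmax attention output exactly, up to floating-point rounding.
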
  • Tiling keeps working data in SRAM, which has far higher bandwidth than HBM.
  • An online softmax merges block-wise results without materializing the full score matrix.
  • The backward pass recomputes scores from saved row statistics instead of storing them.
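The update rule above can also be read as a merge of two partial states (m, l, O). A fresh block contributes m_b = max(s_block), l_b = sum(exp(s_block - m_b)), and O_b = exp(s_block - m_b)·V_block, and merging it into the running state reproduces the three update formulas. The NumPy sketch below (hypothetical function names, synthetic scores) processes twelve keys as three blocks of four, divides by l once at the end, and compares the result with a one-shot softmax.

```python
import numpy as np


def block_state(s, v_blk):
    """Partial softmax state (m, l, O) for one tile of scores s and values v_blk."""
    m = s.max(axis=1)
    p = np.exp(s - m[:, None])
    return m, p.sum(axis=1), p @ v_blk


def merge(state_a, state_b):
    """Combine two partial states by rescaling both to a shared running max."""
    m_a, l_a, o_a = state_a
    m_b, l_b, o_b = state_b
    m = np.maximum(m_a, m_b)
    a, b = np.exp(m_a - m), np.exp(m_b - m)
    return m, a * l_a + b * l_b, a[:, None] * o_a + b[:, None] * o_b


rng = np.random.default_rng(2)
s = rng.standard_normal((4, 12)) * 5.0  # scores for 4 query rows and 12 keys
v = rng.standard_normal((12, 3))

# Process the keys as three blocks of four, merging as we go.
state = block_state(s[:, 0:4], v[0:4])
for j in (4, 8):
    state = merge(state, block_state(s[:, j:j + 4], v[j:j + 4]))
m, l, o = state
out = o / l[:, None]                    # final normalization O / l

p = np.exp(s - s.max(axis=1, keepdims=True))
reference = (p / p.sum(axis=1, keepdims=True)) @ v
assert np.allclose(out, reference)
```

Because the merge only rescales and adds, the result does not depend on how the keys are split into blocks.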

Using FlashAttention in practice

FlashAttention ships as a CUDA library and is integrated into PyTorch's scaled dot product attention (SDPA), Hugging Face Transformers, and most major training stacks. In PyTorch the kernel is selected automatically when the inputs qualify, for example fp16 or bf16 tensors on a supported CUDA GPU with no arbitrary attention mask, and it can be requested explicitly through the torch.nn.attention.sdpa_kernel context manager. The official flash-attn package also exposes the kernel directly.

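```python
import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

# SDPA layout: (batch, heads, seq_len, head_dim); the flash backend needs fp16/bf16 on CUDA
q = torch.randn(2, 16, 4096, 64, device="cuda", dtype=torch.bfloat16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Current API (PyTorch 2.3+): the old torch.backends.cuda.sdp_kernel is deprecated
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Or call the official package directly (flash-attn). It expects
# (batch, seq_len, heads, head_dim), so swap the head and sequence dims.
# from flash_attn import flash_attn_func
# out = flash_attn_func(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), causal=True)
# out = out.transpose(1, 2)  # back to the SDPA layout
```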
Two ways to invoke FlashAttention: PyTorch's current sdpa_kernel dispatcher and the official package.
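The PyTorch call above passes is_causal=True, so each query attends only to keys at or before its own position. In a tiled kernel this mask has a useful side effect: key blocks that lie entirely after a query block contribute nothing and can be skipped, which removes roughly half of the tiles for long sequences. The NumPy sketch below applies the online softmax described earlier with that mask; it assumes a single head, equal query and key lengths, and one block size for both, and its function names are hypothetical.

```python
import numpy as np


def causal_tiled_attention(q, k, v, block=16):
    """Tiled causal attention; key blocks entirely in the future are skipped."""
    n, d = q.shape
    scale = 1.0 / np.sqrt(d)
    out = np.empty((n, v.shape[1]))
    for i in range(0, n, block):
        qi = q[i:i + block]
        rows = np.arange(i, i + qi.shape[0])       # absolute query positions
        m = np.full(qi.shape[0], -np.inf)
        l = np.zeros(qi.shape[0])
        acc = np.zeros((qi.shape[0], v.shape[1]))
        # Stop after the block containing the diagonal; later blocks are fully masked.
        for j in range(0, i + qi.shape[0], block):
            cols = np.arange(j, min(j + block, n))  # absolute key positions
            s = qi @ k[j:j + block].T * scale
            s = np.where(cols[None, :] > rows[:, None], -np.inf, s)  # mask future keys
            m_new = np.maximum(m, s.max(axis=1))
            alpha = np.exp(m - m_new)
            p = np.exp(s - m_new[:, None])
            l = alpha * l + p.sum(axis=1)
            acc = alpha[:, None] * acc + p @ v[j:j + block]
            m = m_new
        out[i:i + block] = acc / l[:, None]
    return out


def causal_reference(q, k, v):
    """Standard causal attention with a full N x N masked score matrix."""
    n = q.shape[0]
    s = q @ k.T / np.sqrt(q.shape[1])
    s = np.where(np.triu(np.ones((n, n), dtype=bool), 1), -np.inf, s)
    p = np.exp(s - s.max(axis=1, keepdims=True))
    return (p / p.sum(axis=1, keepdims=True)) @ v


rng = np.random.default_rng(3)
q, k, v = (rng.standard_normal((50, 8)) for _ in range(3))
assert np.allclose(causal_tiled_attention(q, k, v), causal_reference(q, k, v))
```

Every processed block contains at least one unmasked key for each of its query rows (the first key block always does), so the running max stays finite and no NaNs arise from subtracting -inf from -inf.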

Impact and later versions

The original FlashAttention paper reported roughly 15 percent faster end-to-end BERT-large training at sequence length 512 (compared with the MLPerf 1.1 training speed record), about 3x faster GPT-2 training at sequence length 1K, and a 2.4x speedup on the Long Range Arena benchmark (sequence lengths 1K to 4K), while making sequence lengths of 16K and 64K tractable on the Path-X and Path-256 tasks. Because the algorithm produces exact results, it can be dropped in without changing model quality.

Successor versions improved hardware utilization further. FlashAttention-2 reorganized how work is partitioned across and within GPU thread blocks to make better use of the compute units, and FlashAttention-3 added optimizations for newer Hopper-class GPUs, including asynchronous execution that overlaps data movement with computation and support for low-precision FP8. The approach is now standard in production LLM training and inference.

  • Original paper: 3x GPT-2 training speedup and tractable 16K to 64K contexts.
  • FlashAttention-2 improved parallelism and work partitioning.
  • FlashAttention-3 targets Hopper GPUs with lower precision and async execution.

Key takeaways

  • FlashAttention computes exact attention but avoids writing the full attention matrix to slow GPU memory.
  • Tiling plus an online softmax keeps intermediate results in fast on-chip SRAM, reducing memory traffic.
  • It changes attention memory from quadratic to linear in sequence length, enabling much longer contexts.
  • It is integrated into PyTorch SDPA and Hugging Face, with FlashAttention-2 and -3 adding further GPU optimizations.

Frequently asked questions

Why is FlashAttention faster than standard attention?
Standard attention is bottlenecked by moving the large intermediate attention matrix between fast and slow GPU memory. FlashAttention restructures the computation so that matrix is never fully written to slow memory, cutting memory traffic and speeding up training and inference.

Is FlashAttention an approximation of attention?
No. FlashAttention computes exact attention. Its tiling and online softmax produce the same output as a standard implementation, up to floating-point rounding, so it can replace standard attention without changing model accuracy.

How does FlashAttention reduce memory usage?
It processes attention in blocks that fit in on-chip SRAM and uses a running softmax to combine them, so it never stores the full sequence-by-sequence score matrix. This makes memory grow linearly rather than quadratically with sequence length.

Do I need to change my code to use FlashAttention?
Usually no. It is integrated into PyTorch's scaled dot product attention and Hugging Face Transformers, so it is selected automatically when tensor shapes and dtypes qualify, or you can enable it explicitly through those interfaces.

What do FlashAttention-2 and FlashAttention-3 add?
FlashAttention-2 reorganized how work is split across GPU threads for better utilization. FlashAttention-3 adds optimizations for Hopper-class GPUs, including lower-precision (FP8) formats and asynchronous execution, pushing throughput higher on newer hardware.