FlashAttention Internals: How IO-Aware Attention Rewrote Transformer Efficiency
How FlashAttention works technically: GPU memory hierarchy, tiling for SRAM, the online softmax trick, and FlashAttention-2 warp partitioning.
Standard attention is memory-bound despite appearing compute-intensive. The conventional implementation reads Q, K, and V from GPU high-bandwidth memory (HBM), computes the full N×N attention score matrix, writes it back to HBM, reads it again to apply the softmax, writes the softmax output, then reads it once more for the multiplication with V. For a sequence of 8,192 tokens, the N×N attention matrix holds 8,192² ≈ 67 million entries. At 2 bytes each (float16), that is 134 MB moving between HBM and SRAM on every attention computation, and it moves several times: the matrix is written after the QK^T computation, read and rewritten for the softmax, and read again for the V multiplication.
FlashAttention (Dao et al., 2022) asked a different question: instead of computing the mathematically clean O = softmax(QK^T/√d)V, can we compute the same result by processing Q, K, and V in blocks that fit in SRAM, avoiding the large HBM reads and writes entirely? The answer is yes, through the online softmax algorithm that computes the correct softmax normalization incrementally without materializing the full attention matrix.
The result: FlashAttention achieves the same mathematical output as standard attention while reducing HBM reads/writes from O(N²) to O(N). For long sequences, this is the difference between attention being the bottleneck and attention being fast. FlashAttention-2 further improved this by achieving 50-73% of theoretical peak FLOPs/s on A100 GPUs, compared to 25-40% for the original, and producing a 2x overall speedup for attention computation.
The sections below cover the GPU memory hierarchy that makes standard attention slow, the tiling strategy, the online softmax algorithm that makes tiling possible, the FA-2 improvements to GPU parallelism, and Ring Attention for extending the approach to multi-GPU long-context settings.
The GPU Memory Hierarchy: Why Standard Attention Is Slow
Understanding FlashAttention requires first understanding the GPU memory hierarchy and why it makes standard attention slow.
Modern GPUs have three relevant memory levels:
Registers: ~256 KB per streaming multiprocessor (SM), fastest, only accessible by a single thread. Register-level data is used directly in arithmetic instructions with essentially zero latency.
Shared Memory (SRAM): ~192 KB per SM on A100, shared within a thread block, ~19 TB/s aggregate bandwidth, latency on the order of tens of nanoseconds. This is the fast cache that sits between the compute units and main memory.
High Bandwidth Memory (HBM): 40-80 GB on A100, shared across all SMs, ~2 TB/s bandwidth, latency on the order of hundreds of nanoseconds. This is where all model activations, weights, and intermediate tensors live.
The bandwidth gap between SRAM and HBM is approximately 10x. Operations that must read from HBM are 10x bandwidth-limited compared to operations that stay in SRAM.
Standard attention is IO-bound, not compute-bound. The attention computation itself (matrix multiplications) is arithmetic-intensive and runs near peak FLOPs. But the operations around it (loading the N×N attention matrix, writing it, reading it, applying softmax) are memory-bound and determine the wall-clock time.
A concrete example for N=4096 (4K context), H=1 head, D=128 head dimension, float16:
Standard attention HBM accesses:
1. Load Q: N × D × 2 bytes = 4096 × 128 × 2 = 1 MB
2. Load K: N × D × 2 bytes = 1 MB
3. Load V: N × D × 2 bytes = 1 MB
4. Write S=QK^T: N × N × 2 bytes = 4096 × 4096 × 2 = 32 MB
5. Load S: 32 MB
6. Write P=softmax(S): 32 MB
7. Load P: 32 MB
8. Write O=PV: N × D × 2 bytes = 1 MB
Total HBM reads/writes: 3 MB + (32 + 32 + 32 + 32) MB + 1 MB = 132 MB per head
FlashAttention HBM accesses:
1. Load Q, K, V blocks: 3 MB (idealized; all blocks combined cover Q, K, V once, though real kernels re-read K/V blocks once per Q block)
2. Write O: 1 MB
Total HBM reads/writes: ≈ 4 MB per head, a ~33x reduction.
At HBM bandwidth of 2 TB/s on A100:
- Standard attention: 132 MB / 2,000 GB/s ≈ 66 μs per head
- FlashAttention: 4 MB / 2,000 GB/s = 2 μs per head
The ~33x reduction in HBM accesses translates to roughly 5x actual speedup (the gap is smaller than 33x because compute time also contributes, and FlashAttention adds overhead for the tiling logic).
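The accounting above can be reproduced with a few lines of arithmetic. This is a rough model, not a kernel measurement: it counts four passes over the N×N matrix (write S, read S, write P, read P) and treats the tiled version as loading Q, K, V once.

```python
def attention_hbm_traffic(N: int, d: int, bytes_per_el: int = 2):
    """Per-head HBM bytes for standard vs tiled attention (idealized model).

    Standard: load Q/K/V, four passes over the N×N matrix
    (write S, read S, write P, read P), then write O.
    Tiled: load Q/K/V once, write O once (real kernels re-read K/V blocks).
    """
    qkv = 3 * N * d * bytes_per_el   # load Q, K, V
    out = N * d * bytes_per_el       # write O
    nxn = N * N * bytes_per_el       # one N×N materialization
    standard = qkv + 4 * nxn + out
    tiled = qkv + out
    return standard, tiled

standard, tiled = attention_hbm_traffic(4096, 128)
# standard ≈ 132 MB, tiled = 4 MB per head
```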
Standard Attention: Counting the HBM Accesses
The standard attention algorithm in PyTorch-style pseudocode:
```python
import torch
import math

def standard_attention(Q, K, V, mask=None):
    """
    Standard attention implementation.
    Q, K, V: [batch, heads, seq_len, head_dim]
    Returns: [batch, heads, seq_len, head_dim]
    """
    N = Q.shape[-2]  # Sequence length
    d = Q.shape[-1]  # Head dimension
    scale = 1.0 / math.sqrt(d)

    # Step 1: Compute attention scores
    # Shape: [batch, heads, N, N]
    # HBM ops: Load Q (N×d), Load K (N×d), Write S (N×N)
    S = torch.matmul(Q, K.transpose(-2, -1)) * scale

    # Step 2: Apply mask if provided
    if mask is not None:
        S = S.masked_fill(mask == 0, float('-inf'))

    # Step 3: Softmax over last dimension
    # HBM ops: Load S (N×N), Write P (N×N)
    P = torch.softmax(S, dim=-1)

    # Step 4: Weighted sum of values
    # HBM ops: Load P (N×N), Load V (N×d), Write O (N×d)
    O = torch.matmul(P, V)
    return O

# Memory analysis (element counts; float16 → ×2 bytes):
# HBM reads:  Q (Nd) + K (Nd) + S (N²) + P (N²) + V (Nd) = 2N² + 3Nd
# HBM writes: S (N²) + P (N²) + O (Nd) = 2N² + Nd
# Total: 4N² + 4Nd ≈ O(N²) for large N
```

The O(N²) HBM access pattern is the problem. For N=32,768 (32K context):
- N² = 1 billion elements per head
- At 2 bytes per element: 2 GB per head
- With 32 heads: 64 GB of HBM traffic per attention layer
- With 32 layers: 2 TB per forward pass, just for attention
This is why long-context inference is expensive: the attention computation requires quadratic HBM traffic regardless of the numerical precision.
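The scaling claim above can be checked directly. A quick sketch of the per-head, per-layer, and per-model traffic for a single materialization of the score matrix (float16, 32 heads, 32 layers as in the text):

```python
def attention_matrix_traffic(N: int, heads: int, layers: int, bytes_per_el: int = 2):
    """Bytes consumed by one materialization of the N×N score matrix,
    per head, per layer, and across the whole model."""
    per_head = N * N * bytes_per_el
    per_layer = per_head * heads
    per_model = per_layer * layers
    return per_head, per_layer, per_model

per_head, per_layer, per_model = attention_matrix_traffic(32_768, 32, 32)
# per_head ≈ 2.1 GB, per_layer ≈ 69 GB, per_model ≈ 2.2 TB
```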
The Online Softmax Trick: Computing Without the Full Matrix
The key mathematical insight that enables FlashAttention is the online softmax algorithm. Standard softmax requires two passes over the data: one to find the maximum (for numerical stability), one to compute the exponentials and normalize. Both passes require the full attention score vector.
The online softmax algorithm maintains running statistics (the current maximum and the current normalization factor) and updates them incrementally as new scores are processed. This allows the correct softmax to be computed from left to right without ever materializing the full N-length score vector.
The algorithm:
```python
import math

def online_softmax_sequential(scores: list[float]) -> list[float]:
    """
    Compute softmax in a single pass using the online algorithm.
    Identical output to standard softmax but processes elements one at a time.
    """
    running_max = float('-inf')  # m: running maximum
    running_sum = 0.0            # l: running normalization factor
    exp_values = []              # intermediate exponentials (not kept in the blocked version)

    # Single pass through scores
    for s in scores:
        if s > running_max:
            # New maximum found: rescale everything accumulated so far
            rescale = math.exp(running_max - s)
            running_sum *= rescale
            exp_values = [e * rescale for e in exp_values]
            running_max = s
        running_sum += math.exp(s - running_max)
        exp_values.append(math.exp(s - running_max))

    # Normalize
    return [e / running_sum for e in exp_values]
```
```python
def online_softmax_blocked(scores: list[list[float]]) -> list[float]:
    """
    Process scores in blocks. Returns the same output as standard softmax.
    This is the key to FlashAttention: each block is processed independently
    and the running statistics are merged across blocks.
    """
    running_max = float('-inf')
    running_sum = 0.0
    block_results = []

    for block in scores:
        # Local statistics for this block
        block_max = max(block)
        block_exp = [math.exp(s - block_max) for s in block]
        block_sum = sum(block_exp)

        # Merge with running statistics; running_sum is kept relative to the new max
        new_max = max(running_max, block_max)
        running_sum = (running_sum * math.exp(running_max - new_max)
                       + block_sum * math.exp(block_max - new_max))
        running_max = new_max

        # Keep each block's exponentials relative to its own max;
        # the final pass rescales them to the global max exactly once
        block_results.append((block_exp, block_max))

    # Final normalization pass
    result = []
    for block_exp, block_max in block_results:
        rescale = math.exp(block_max - running_max)
        result.extend([e * rescale / running_sum for e in block_exp])
    return result
```

The blocked version produces mathematically identical output to standard softmax. When processing a block of scores, the algorithm only needs to hold one block in memory at a time (SRAM-sized), update the running statistics, and accumulate the output. No N×N matrix is ever fully materialized.
The merge operation for combining block statistics is the non-obvious piece. When a new block has a higher maximum than the running maximum, all previous exponentials must be rescaled by exp(old_max - new_max). This rescaling can be applied lazily, accumulated in the running sum rather than applied to each previous value, which is what makes the blocked computation exact.
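To make the merge concrete, here is a tiny numeric check (illustrative only) that merging per-block (max, sum-of-exponentials) statistics reproduces the statistics of the concatenated score list:

```python
import math

def merge(m1, l1, m2, l2):
    """Exactly merge two (max, sum-of-exp) softmax statistics."""
    m = max(m1, m2)
    l = l1 * math.exp(m1 - m) + l2 * math.exp(m2 - m)
    return m, l

def stats(xs):
    """(max, sum of exp(x - max)) for one block of scores."""
    m = max(xs)
    return m, sum(math.exp(x - m) for x in xs)

scores_a, scores_b = [1.0, 3.0], [2.0, 5.0]
m, l = merge(*stats(scores_a), *stats(scores_b))

# Reference: statistics of the full concatenated list
m_ref, l_ref = stats(scores_a + scores_b)
assert abs(m - m_ref) < 1e-12
assert abs(l - l_ref) < 1e-12
```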
FlashAttention Tiling: The Core Algorithm
With the online softmax in hand, the full FlashAttention forward pass algorithm:
```
Input:  Q, K, V ∈ ℝ^{N×d} (in HBM)
Output: O ∈ ℝ^{N×d} (in HBM)
Block sizes: B_r (rows of Q block), B_c (columns of K/V block),
chosen such that one Q block plus one K/V block fits in SRAM

1. Divide Q into T_r = ⌈N/B_r⌉ blocks: Q_1, ..., Q_{T_r}
2. Divide K, V into T_c = ⌈N/B_c⌉ blocks: K_1, ..., K_{T_c} and V_1, ..., V_{T_c}
3. For i = 1 to T_r:                              // Outer loop: over Q blocks
     a. Load Q_i from HBM to SRAM                 // One read of size B_r × d
     b. Initialize O_i = 0, l_i = 0, m_i = -inf   // Running statistics in registers
     For j = 1 to T_c:                            // Inner loop: over K/V blocks
       c. Load K_j, V_j from HBM to SRAM          // One read of size B_c × d each
       d. Compute S_ij = Q_i K_j^T / √d           // Stays in SRAM: B_r × B_c
       e. m̃_ij = rowmax(S_ij)                     // Row maximum of current block
          P̃_ij = exp(S_ij - m̃_ij)                 // Numerically stable exponentials
          l̃_ij = rowsum(P̃_ij)                     // Row sum for this block
       f. Merge with running statistics:
          m_new = max(m_i, m̃_ij)
          l_new = l_i · exp(m_i - m_new) + l̃_ij · exp(m̃_ij - m_new)
       g. Update output accumulator:
          O_i = diag(l_new)^{-1} · (
                  diag(l_i) · exp(m_i - m_new) · O_i    // Rescale previous output
                + P̃_ij · exp(m̃_ij - m_new) · V_j        // Add current block's contribution
                )
       h. Update statistics: m_i = m_new, l_i = l_new
     i. Write O_i to HBM                          // One write of size B_r × d
```

Total HBM reads: Q once (N×d) plus T_r passes through K and V (T_r × 2 × N×d). Total HBM writes: O (N×d). With the largest block that fits in SRAM, B_r ≈ M/d for SRAM size M, this comes to O(N²d²/M) HBM accesses. Since M >> d², the N² term is cut by a factor of M/d², which for realistic sequence lengths brings HBM traffic down to near-linear in N.

In Python, the block computation looks like:
```python
import torch

def flash_attention_forward(Q, K, V, block_size: int = 64) -> torch.Tensor:
    """
    FlashAttention forward pass in Python (for illustration, not CUDA-optimized).
    Q, K, V: [N, d] (single head, single batch)
    """
    N, d = Q.shape
    scale = 1.0 / (d ** 0.5)

    # Output accumulator and running statistics
    O = torch.zeros_like(Q)                    # [N, d]
    l = torch.zeros(N, 1)                      # Row normalization factors [N, 1]
    m = torch.full((N, 1), float('-inf'))      # Row maximums [N, 1]

    # Block sizes (in practice chosen to fit SRAM)
    B_r = block_size                           # Rows of Q block
    B_c = block_size                           # Columns of K/V block
    T_r = (N + B_r - 1) // B_r                 # Number of Q blocks
    T_c = (N + B_c - 1) // B_c                 # Number of K/V blocks

    for i in range(T_r):
        # Query block [B_r, d]
        r_start, r_end = i * B_r, min((i + 1) * B_r, N)
        Q_i = Q[r_start:r_end]                 # Load from "HBM" to "SRAM"
        O_i = torch.zeros_like(Q_i)            # Unnormalized output accumulator
        l_i = l[r_start:r_end]
        m_i = m[r_start:r_end]

        for j in range(T_c):
            # Key/value block [B_c, d]
            c_start, c_end = j * B_c, min((j + 1) * B_c, N)
            K_j = K[c_start:c_end]             # Load from "HBM" to "SRAM"
            V_j = V[c_start:c_end]

            # Attention scores for this block: [B_r, B_c]
            S_ij = (Q_i @ K_j.T) * scale

            # Block statistics
            m_tilde = S_ij.max(dim=-1, keepdim=True).values   # [B_r, 1]
            P_tilde = torch.exp(S_ij - m_tilde)               # [B_r, B_c]
            l_tilde = P_tilde.sum(dim=-1, keepdim=True)       # [B_r, 1]

            # Merge with running statistics
            m_new = torch.maximum(m_i, m_tilde)               # [B_r, 1]
            l_new = (l_i * torch.exp(m_i - m_new) +
                     l_tilde * torch.exp(m_tilde - m_new))    # [B_r, 1]

            # Rescale the previous (unnormalized) output and add this block's
            # contribution; normalization by l is deferred to the end of the
            # row block (the FA-2 style deferred rescaling)
            O_i = (torch.exp(m_i - m_new) * O_i +
                   torch.exp(m_tilde - m_new) * P_tilde @ V_j)

            # Update running statistics
            m_i = m_new
            l_i = l_new

        # Finalize output for this row block
        O[r_start:r_end] = O_i / l_i           # Normalize once per row block
        l[r_start:r_end] = l_i
        m[r_start:r_end] = m_i

    return O
```

The block sizes must be chosen such that Q_i (B_r × d) + K_j (B_c × d) + V_j (B_c × d) + S_ij (B_r × B_c) fit together in SRAM. On A100 with 192 KB SRAM per SM, and d=128, float16:
- Q_i: B_r × 128 × 2 bytes
- K_j + V_j: 2 × B_c × 128 × 2 bytes
- S_ij: B_r × B_c × 2 bytes
- Total: 256·B_r + 512·B_c + 2·B_r·B_c ≤ 196,608 bytes
A typical configuration: B_r = B_c = 64 → 16,384 + 32,768 + 8,192 = 57,344 bytes ≈ 56 KB. Fits comfortably.
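The budget check can be scripted. A sketch of the same arithmetic (it ignores double-buffering and accumulator storage, which real kernels also budget for):

```python
def sram_bytes(B_r: int, B_c: int, d: int = 128, bytes_per_el: int = 2) -> int:
    """SRAM needed for one Q block, one K block, one V block, and the score tile."""
    q = B_r * d * bytes_per_el           # Q_i
    kv = 2 * B_c * d * bytes_per_el      # K_j + V_j
    s = B_r * B_c * bytes_per_el         # S_ij
    return q + kv + s

assert sram_bytes(64, 64) == 57_344          # ≈ 56 KB
assert sram_bytes(128, 128) <= 192 * 1024    # 128-wide blocks also fit at d=128
```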
FlashAttention-2: Warp Partitioning and Reduced Non-Matmul FLOPs
FlashAttention-2 (Dao, 2023) achieved a 2x speedup over FlashAttention-1 through three improvements:
Improvement 1: Reduced non-matmul FLOPs
FA1 accumulated the output O_i in a way that required per-row rescaling on every K/V block iteration, an O(B_r) operation per block. FA2 defers this rescaling to the end of the row block, doing it only once instead of T_c times. This reduces the number of non-matrix-multiply operations by a factor of T_c (the number of K/V blocks).
On A100, matrix multiplications run at 312 TFLOPs/s, while non-matmul operations run at approximately 19.5 TFLOPs/s, a 16x gap. Reducing non-matmul operations has disproportionate impact on total compute time.
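A back-of-the-envelope split makes the point. Assuming the A100 throughput figures above (312 TFLOPs/s tensor-core, 19.5 TFLOPs/s general) and, as an illustrative guess, ~5 non-matmul FLOPs per score entry for softmax bookkeeping:

```python
def step_time_us(matmul_flops: float, other_flops: float,
                 matmul_tflops: float = 312.0, other_tflops: float = 19.5):
    """Rough time split on A100: tensor-core matmul vs general-purpose FLOPs."""
    matmul_us = matmul_flops / (matmul_tflops * 1e12) * 1e6
    other_us = other_flops / (other_tflops * 1e12) * 1e6
    return matmul_us, other_us

# One 64×64 score tile at d=128: the matmul is 2·B_r·B_c·d FLOPs;
# assume ~5 non-matmul FLOPs per score entry (illustrative).
mm_us, other_us = step_time_us(2 * 64 * 64 * 128, 5 * 64 * 64)
# 51x fewer non-matmul FLOPs still cost ~31% as much time as the matmul
```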
Improvement 2: Better parallelism across thread blocks
FA1 parallelized the outer loop (Q blocks) across thread blocks. FA2 also parallelizes over batch size and heads, maximizing GPU occupancy. For multi-head attention with H heads and batch size B, FA2 launches B × H thread blocks, each handling one head's row-block loop. This improves GPU utilization especially for smaller sequence lengths where FA1 might underutilize the GPU.
Improvement 3: Warp-level parallelism for the inner loop
Within a thread block, FA2 partitions the inner loop (over K/V columns) across warps using a different partitioning scheme than FA1:
FA1 warp partitioning:
- Each warp processes rows of Q_i independently
- Different warps need to share the output O_i accumulator → synchronization overhead
FA2 warp partitioning:
- Each warp processes columns of K_j / V_j
- Each warp accumulates its own partial O_i (different attention score slice)
- Final reduction via shared memory: one synchronization per block instead of per warp step

The FA2 scheme reduces shared memory communication overhead from an exchange on every inner step to a single O(B_r × d) reduction once per outer step. For typical block sizes, this reduces synchronization overhead by a factor of T_c.
The combined effect: FlashAttention-2 achieves 50-73% of A100 theoretical peak FLOPs/s (depending on sequence length and batch size), compared to 25-40% for FA1. The remaining gap from 100% theoretical peak is due to non-matmul operations and memory access latency that cannot be fully hidden.
The Backward Pass: Recomputation Instead of Storage
The backward pass for attention requires the attention probability matrix P (the softmax output) to compute gradients. Standard backpropagation stores P in HBM during the forward pass for use during the backward pass. This is the O(N²) memory requirement of standard attention.
FlashAttention uses a different approach: recomputation. During the backward pass, instead of loading P from HBM (where it was stored during forward), recompute P from Q and K. Since Q and K are much smaller than P (O(Nd) vs O(N²)), this trades memory for compute.
The tradeoff at 4K context, 128 head dimension:
- Storing P: 4096² × 2 bytes = 32 MB per head
- Recomputing P: 4096 × 128 × 2 bytes = 1 MB per head + compute time
For long sequences (32K+), recomputation is roughly 250x more memory-efficient than storing P (the ratio is N/d), at the cost of roughly 30% additional compute during the backward pass. Since memory bandwidth, not compute, is the bottleneck, this tradeoff is favorable.
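The arithmetic behind the tradeoff (per head, float16), as a quick sketch:

```python
def backward_memory_bytes(N: int, d: int, bytes_per_el: int = 2):
    """Per-head bytes: storing the softmax matrix P vs keeping Q for recomputation."""
    store_p = N * N * bytes_per_el    # the N×N probability matrix
    keep_q = N * d * bytes_per_el     # Q (K is the same size and already resident)
    return store_p, keep_q

store_p, keep_q = backward_memory_bytes(32_768, 128)
# store_p = 2 GiB, keep_q = 8 MiB → ratio N/d = 256x
```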
```python
import torch

def flash_attention_backward_sketch(dO, Q, K, V, O, L, M, block_size: int = 64):
    """
    Sketch of FlashAttention backward pass.
    L: normalization factors from the forward pass (the final running l_i)
    M: row maximums from the forward pass (the final running m_i)
    Key: P is RECOMPUTED, not loaded from HBM.
    """
    N, d = Q.shape
    dQ = torch.zeros_like(Q)
    dK = torch.zeros_like(K)
    dV = torch.zeros_like(V)
    scale = 1.0 / (d ** 0.5)

    # For each block of K/V:
    for j in range(0, N, block_size):
        K_j = K[j:j+block_size]
        V_j = V[j:j+block_size]

        # For each block of Q:
        for i in range(0, N, block_size):
            Q_i = Q[i:i+block_size]

            # Recompute attention probabilities (not loaded from HBM)
            S_ij = Q_i @ K_j.T * scale
            P_ij = torch.exp(S_ij - M[i:i+block_size]) / L[i:i+block_size]

            # Standard attention gradient formulas, blockwise
            dV_j = P_ij.T @ dO[i:i+block_size]
            dP_ij = dO[i:i+block_size] @ V_j.T
            dS_ij = P_ij * (dP_ij - (dO[i:i+block_size] * O[i:i+block_size]).sum(-1, keepdim=True))
            dQ[i:i+block_size] += dS_ij @ K_j * scale
            dK[j:j+block_size] += dS_ij.T @ Q_i * scale
            dV[j:j+block_size] += dV_j

    return dQ, dK, dV
```

Performance Numbers: What the Benchmarks Show
FlashAttention achieves significant improvements across the GPU model landscape:
A100 attention speedup (forward + backward, seq_len=2048, head_dim=64):
| Implementation | FLOPs/s | % peak FLOPs | Speedup vs standard |
|---|---|---|---|
| Standard PyTorch attention | ~95 TFLOPs/s | ~30% | 1.0x |
| FlashAttention-1 | ~180 TFLOPs/s | ~58% | 1.9x |
| FlashAttention-2 | ~225 TFLOPs/s | ~72% | 2.4x |
| FlashAttention-3 (H100, FP16) | ~740 TFLOPs/s | ~75% of H100 peak | n/a (different GPU than rows above) |
End-to-end GPT training speedup (reported by Dao et al.):
- GPT-2 (1.5B): FA-1 gives 1.7x training speedup vs standard attention
- GPT-3 style (175B): FA-1 gives 3x memory reduction and 1.5x training speedup
- Long-context (sequence 16K vs 1K): FA-1 gives 10x memory reduction, 6x attention speedup
Memory reduction (no intermediate N×N matrix):
At seq_len=32K, d=128, float16:
- Standard attention: 32,768² × 2 bytes = 2 GB per head per layer
- FlashAttention: ~0 GB for intermediates (O(N) for running statistics only)
The memory reduction is the more significant practical benefit for long-context training. Without FlashAttention, training a 70B model at 32K context is impractical: the N×N attention intermediates alone would dominate GPU memory. With FlashAttention, attention requires only the O(N) storage for Q, K, V, O, and the running statistics. No N×N intermediate is needed.
Ring Attention: Multi-GPU Long-Context Scaling
Ring Attention (Liu et al., 2023) extends FlashAttention to multi-GPU settings, enabling training and inference at context lengths beyond what fits on a single GPU.
The approach: distribute the sequence across P GPUs, with each GPU holding N/P tokens. For attention, each GPU needs to compute its portion of the attention output using Q from its own slice and K/V from all P slices. Ring Attention arranges GPUs in a ring topology where each GPU passes its K/V block to the next GPU while simultaneously processing the K/V block it received from the previous GPU.
```
GPU 0:              GPU 1:              GPU 2:              GPU 3:
Holds Q[0:N/P]      Holds Q[N/P:2N/P]   Holds Q[2N/P:3N/P]  Holds Q[3N/P:N]
      ↓                   ↓                   ↓                   ↓
Step 1: Each GPU computes attention with its own K/V
Step 2: Each GPU sends its K/V to next in ring, receives K/V from previous
        → compute → pass → compute → pass → ...
After P steps: each GPU has processed all K/V blocks for its Q slice

Communication: each GPU sends/receives (N/P) × d × 2 bytes per step × (P-1) steps
             ≈ (N × d × 2) bytes total — the size of the K/V tensors, not N² traffic
```

Ring Attention achieves linear memory scaling with sequence length: adding more GPUs scales the total context length proportionally. For a 70B model with 80 GB VRAM per GPU:
- Single GPU: ~32K context maximum (after FA memory reduction)
- 4 GPUs in ring: ~128K context
- 8 GPUs: ~256K context
- 16 GPUs: ~512K context
The communication overhead is O(N×d), the same order as the K/V matrix size, which hides well behind the attention computation time for long sequences on high-bandwidth interconnects (NVLink: 600 GB/s bidirectional).
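The overlap claim can be sanity-checked with rough numbers. This sketch assumes NVLink at 600 GB/s, A100 tensor-core throughput at 50% efficiency, and 32 heads; all of these are illustrative assumptions, not measurements:

```python
def ring_step_times_ms(N: int, P: int, d: int, heads: int = 32,
                       link_gbs: float = 600.0, tflops: float = 312.0,
                       efficiency: float = 0.5):
    """Per ring step: time to ship one K/V shard vs time to compute
    attention of the local Q shard against that shard."""
    shard = N // P
    comm_bytes = 2 * shard * d * heads * 2            # K and V, float16
    comm_ms = comm_bytes / (link_gbs * 1e9) * 1e3
    flops = 4 * shard * shard * d * heads             # QK^T plus PV: 2 matmuls
    compute_ms = flops / (tflops * 1e12 * efficiency) * 1e3
    return comm_ms, compute_ms

comm_ms, compute_ms = ring_step_times_ms(N=131_072, P=4, d=128)
# compute per step (~113 ms) dwarfs the transfer (~0.9 ms), so the
# K/V pass hides entirely behind the attention computation
```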
Ring Attention-style communication is the mechanism behind long-context training in many frontier labs: training models at 100K+ context plausibly relies on multi-GPU attention kernels with this communication pattern.
Practical Usage: When FlashAttention Helps Most
FlashAttention is available as a drop-in replacement for standard attention in PyTorch through torch.nn.functional.scaled_dot_product_attention (which selects a FlashAttention-2 backend automatically on supported CUDA GPUs and dtypes) and through the flash-attn package.
```python
import torch
import torch.nn.functional as F

# Method 1: PyTorch native (uses the FA2 backend automatically when available)
def efficient_attention(Q, K, V, attn_mask=None, dropout_p=0.0, is_causal=False):
    """
    PyTorch will select the FlashAttention backend automatically if:
    - Running on CUDA
    - Head dim ≤ 256
    - Inputs are float16 or bfloat16
    - No custom attention mask (or causal mask only)
    """
    with torch.backends.cuda.sdp_kernel(
        enable_flash=True,
        enable_math=False,          # Disable naive implementation
        enable_mem_efficient=True,
    ):
        return F.scaled_dot_product_attention(
            Q, K, V,
            attn_mask=attn_mask,
            dropout_p=dropout_p,
            is_causal=is_causal,
        )

# Method 2: Direct flash-attn package
from flash_attn import flash_attn_varlen_func
from flash_attn.bert_padding import unpad_input, pad_input

def flash_attention_varlen(Q, K, V, attention_mask):
    """Variable-length sequences with padding removal.
    Q, K, V: [batch, seqlen, heads, head_dim]; attention_mask: [batch, seqlen]."""
    # Unpad for efficient processing
    Q_unpad, indices, cu_seqlens, max_seqlen = unpad_input(Q, attention_mask)
    K_unpad, _, _, _ = unpad_input(K, attention_mask)
    V_unpad, _, _, _ = unpad_input(V, attention_mask)

    # FlashAttention with variable-length handling
    output_unpad = flash_attn_varlen_func(
        Q_unpad, K_unpad, V_unpad,
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_k=cu_seqlens,
        max_seqlen_q=max_seqlen,
        max_seqlen_k=max_seqlen,
        causal=True,
    )

    # Repad
    return pad_input(output_unpad, indices, Q.shape[0], Q.shape[1])
```

FlashAttention helps most when:
- Sequence length ≥ 1024 (below this, overhead doesn't justify gains)
- Running float16 or bfloat16 (required for FA2; float32 uses the math backend)
- Using causal attention (decoder-only models): FA optimizes the causal mask efficiently
- Memory-constrained training at long context lengths
FlashAttention helps least when:
- Short sequences (<512 tokens): standard attention may be faster due to lower overhead
- Float32 precision required: FA2 only supports 16-bit
- Complex custom attention patterns: sparse attention, sliding window with non-trivial patterns
- CPU inference: FA is a CUDA kernel with no benefit on CPU
Key Takeaways
- FlashAttention is IO-aware attention: it reduces HBM reads/writes from O(N²) to approximately O(N) by tiling the computation into blocks that fit in SRAM. The memory reduction at 32K context is approximately 1000x (from 2 GB to ~2 MB for attention intermediates per head), making long-context training and inference feasible.
- The online softmax algorithm is the mathematical enabler: it computes the exact same softmax output as the standard two-pass algorithm but processes data in blocks using running maximum and normalization statistics. The merge operation between blocks ensures numerical correctness without ever materializing the full N×N attention matrix.
- FlashAttention-2 achieves 2x speedup over FA-1 through three mechanisms: deferred rescaling (reduces non-matmul FLOPs by T_c times), better thread block parallelism across batch/heads, and warp-level partitioning over columns rather than rows (reduces shared memory synchronization).
- On A100 GPUs, FlashAttention-2 achieves 50-73% of theoretical peak FLOPs/s compared to 25-40% for FA-1 and ~30% for standard attention. The remaining gap from 100% peak is due to irreducible non-matmul operations and memory latency that cannot be fully hidden by tiling.
- The backward pass uses recomputation instead of storage: P (the N×N softmax output) is recomputed during backpropagation rather than stored from the forward pass. This trades ~30% extra compute during backward for the elimination of 2 GB per head of HBM storage at 32K context, and is almost always the right tradeoff for long sequences.
- Ring Attention extends FlashAttention to multi-GPU settings by distributing sequences across GPUs in a ring topology. Communication overhead is O(N×d), linear in sequence length, enabling linear scaling of maximum context length with GPU count. This is the mechanism behind 128K+ context training in frontier models.
FAQ
What problem does FlashAttention solve?
FlashAttention solves the memory bandwidth bottleneck in standard transformer attention. Standard attention computes the full N×N attention matrix, which must be written to and read from high-bandwidth memory (HBM): 32 MB per head at 4K context, and 134 MB at 8K. This HBM traffic, not the arithmetic operations, is the bottleneck. FlashAttention uses a tiling algorithm that processes the attention computation in small blocks that fit in the faster on-chip SRAM (roughly a 10x bandwidth improvement over HBM), computing the exact same result without ever materializing the full N×N matrix. The result is over an order of magnitude fewer HBM accesses, translating to 2-3x overall speedup and dramatically reduced memory usage, eliminating the O(N²) memory requirement that made long-context training prohibitively expensive.
What is the online softmax trick in FlashAttention?
The online softmax trick is the mathematical technique that makes FlashAttention's block processing produce numerically correct results. Standard softmax requires two passes over all N scores: one to find the maximum (for numerical stability), one to compute exponentials and normalize. The online algorithm maintains two running statistics (the current maximum m and the current normalization sum l) and updates them as each block of scores is processed. When a new block has a higher maximum, previous accumulated values are rescaled by exp(old_max - new_max). When the new block's maximum is lower, the block's contribution is rescaled. After all blocks, the correct softmax normalization is known without ever materializing the full N-score vector. This is mathematically identical to standard softmax: same output, different computation order that enables the block processing FlashAttention requires.
How much faster is FlashAttention compared to standard attention?
FlashAttention-2 achieves approximately 2-3x end-to-end speedup compared to standard PyTorch attention for typical sequence lengths (2K-8K). At A100 GPU peak utilization, FA-2 reaches 50-73% of theoretical peak FLOPs/s (225 TFLOPs/s on A100) compared to 25-40% for FA-1 and ~30% for standard attention. The speedup increases with sequence length: at 32K context, FlashAttention is approximately 5-8x faster for the attention computation alone. In end-to-end GPT-scale training, FlashAttention reduces training time by 1.5-3x and enables 3-10x reduction in peak memory usage for long-context experiments.
What is Ring Attention and how does it scale context length?
Ring Attention is a multi-GPU extension of FlashAttention that enables context lengths beyond the memory capacity of a single GPU. The sequence is distributed across P GPUs, each holding N/P tokens. GPUs are arranged in a ring topology: during attention computation, each GPU processes its Q block against the K/V block it currently holds, then passes its K/V block to the next GPU and receives the previous GPU's K/V block. After P steps, each GPU has processed its Q block against all K/V blocks. The communication cost is O(N×d) total, linear in sequence length rather than quadratic, which is the same order as reading K/V once. Ring Attention enables roughly linear scaling: maximum context length grows in proportion to GPU count, allowing 128K+ context training by distributing across 8-16 GPUs on high-bandwidth NVLink interconnects.
FlashAttention is not a trick or an approximation. It computes the exact same attention function as the standard algorithm, using the same mathematical operations in a different order. The O(N²) HBM access cost is not a consequence of the attention mechanism's mathematics. It is a consequence of how naive implementations materialize intermediate computations. FlashAttention reveals that the intermediate materialization was never necessary.
The tiling insight applies broadly beyond attention: any computation that involves large intermediate matrices can potentially be restructured to avoid HBM round-trips. The online softmax algorithm that makes FlashAttention possible is one instance of a general technique for computing aggregations over blocked data without materialization.
For practitioners, the action items are straightforward: use torch.nn.functional.scaled_dot_product_attention which selects the FlashAttention backend automatically, or the flash-attn package for more control. At sequence lengths above 1K and in float16, you get the speedup without any changes to your model architecture.
The longer-term implication: FlashAttention makes 100K-1M context lengths economically viable. The models trained on these context lengths have qualitatively different capabilities from models with 4K-8K context. The efficiency improvement is not faster training alone. It expands the space of possible models.
Written & published by Chaitanya Prabuddha