
FlashAttention Internals: How IO-Aware Attention Rewrote Transformer Efficiency

How FlashAttention works technically: GPU memory hierarchy, tiling for SRAM, the online softmax trick, and FlashAttention-2 warp partitioning.


Standard attention is memory-bound despite appearing compute-intensive. The conventional implementation reads Q, K, and V from GPU high-bandwidth memory (HBM), computes the full N×N attention-score matrix, writes it back to HBM, reads it again to apply the softmax, writes the softmax output, and reads that again to multiply it by V. For a sequence of 8,192 tokens, the N×N matrix holds 8,192² ≈ 67 million values. At 2 bytes each (float16), that is 134 MB per head, and it crosses the boundary between HBM and on-chip SRAM several times per attention computation: the scores are written after QK^T and read back for the softmax, and the probabilities are written and read back again before the multiplication by V.

FlashAttention (Dao et al., 2022) asked a different question: instead of computing the mathematically clean O = softmax(QK^T/√d)V step by step, can we compute the same result by processing Q, K, and V in blocks that fit in SRAM, so that the N×N matrix never has to be written to or read from HBM? The answer is yes, through the online softmax algorithm, which computes the correct softmax normalization incrementally without materializing the full attention matrix.

The result: FlashAttention produces exactly the same output as standard attention while never storing the N×N matrix, cutting the extra memory from O(N²) to O(N) and reducing HBM reads/writes many-fold. For long sequences, this is the difference between attention being the bottleneck and attention being fast. FlashAttention-2 went further, reaching 50-73% of theoretical peak FLOPs/s on A100 GPUs (versus 25-40% for the original) and roughly a 2x speedup over FlashAttention-1.

The sections below cover the GPU memory hierarchy that makes standard attention slow, the tiling strategy, the online softmax algorithm that makes tiling possible, the FA-2 improvements to GPU parallelism, and Ring Attention for extending the approach to multi-GPU long-context settings.

The GPU Memory Hierarchy: Why Standard Attention Is Slow

Understanding FlashAttention requires first understanding the GPU memory hierarchy and why it makes standard attention slow.

Modern GPUs have three relevant memory levels:

Registers: ~256 KB per streaming multiprocessor (SM), fastest, only accessible by a single thread. Register-level data is used directly in arithmetic instructions with essentially zero latency.

Shared Memory (SRAM): ~192 KB per SM on A100 (shared with the L1 cache), shared within a thread block, ~19 TB/s aggregate bandwidth, latency on the order of tens of nanoseconds. This is the fast on-chip memory that sits between the compute units and main memory.

High Bandwidth Memory (HBM): 40-80 GB on A100, shared across all SMs, ~1.5-2 TB/s bandwidth, latency on the order of several hundred nanoseconds. This is where all model activations, weights, and intermediate tensors live.

The bandwidth gap between SRAM and HBM is roughly 10x (~19 TB/s vs ~2 TB/s), so an operation that streams its data through HBM is limited to about one-tenth of the bandwidth available to one that stays in SRAM.

Standard attention is IO-bound, not compute-bound. The attention computation itself (matrix multiplications) is arithmetic-intensive and runs near peak FLOPs. But the operations around it (loading the N×N attention matrix, writing it, reading it, applying softmax) are memory-bound and determine the wall-clock time.
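A quick arithmetic-intensity estimate makes the IO-bound claim concrete. The sketch below assumes nominal A100 peaks (312 TFLOPs/s for FP16 matmul, 19.5 TFLOPs/s for non-matmul FP32 work, ~2 TB/s of HBM bandwidth) and roughly 5 FLOPs per score for a numerically stable softmax; all of these are illustrative assumptions rather than measurements.

```python
# Back-of-the-envelope roofline check for the softmax step of standard attention.
# All constants are nominal A100 figures assumed for illustration.
PEAK_MATMUL_FLOPS = 312e12       # FP16 Tensor Core matmul, FLOPs/s
PEAK_NON_MATMUL_FLOPS = 19.5e12  # FP32 non-matmul work, FLOPs/s
HBM_BANDWIDTH = 2e12             # bytes/s
BYTES_PER_ELEMENT = 2            # float16 storage


def ridge_point(peak_flops: float, bandwidth: float) -> float:
    """FLOPs per byte a kernel needs before it stops being memory-bound."""
    return peak_flops / bandwidth


def softmax_intensity(n: int, flops_per_element: float = 5.0) -> float:
    """FLOPs per HBM byte when softmax reads S and writes P (both N x N)."""
    elements = n * n
    flops = flops_per_element * elements            # assumed: max, subtract, exp, sum, divide
    bytes_moved = 2 * elements * BYTES_PER_ELEMENT  # read S once, write P once
    return flops / bytes_moved


print(ridge_point(PEAK_MATMUL_FLOPS, HBM_BANDWIDTH))      # 156.0
print(ridge_point(PEAK_NON_MATMUL_FLOPS, HBM_BANDWIDTH))  # 9.75
print(softmax_intensity(4096))                            # 1.25
```

At about 1.25 FLOPs per byte, the softmax sits far below even the non-matmul ridge point, so its run time is set by how fast S and P move through HBM rather than by arithmetic. The intensity does not depend on N: longer sequences make the step bigger but no less memory-bound.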

A concrete example for N=4096 (4K context), H=1 head, D=128 head dimension, float16:

Standard attention HBM accesses:
1. Load Q:        N × D × 2 bytes = 4096 × 128 × 2 = 1 MB
2. Load K:        N × D × 2 bytes = 1 MB
3. Load V:        N × D × 2 bytes = 1 MB
4. Write S=QK^T:  N × N × 2 bytes = 4096 × 4096 × 2 = 32 MB
5. Load S:        32 MB
6. Write P=softmax(S): 32 MB
7. Load P:        32 MB
8. Write O=PV:    N × D × 2 bytes = 1 MB
 
Total HBM reads/writes: 3 MB + (4 × 32 MB) + 1 MB ≈ 132 MB per head

FlashAttention HBM accesses:
1. Load Q, K, V blocks: 3 MB (idealized: each of Q, K, V read exactly once)
2. Write O:           1 MB

Total HBM reads/writes: ≈ 4 MB per head (~33x reduction)

At HBM bandwidth of 2 TB/s on A100:

  • Standard attention: 132 MB / 2,000 GB/s ≈ 66 μs per head
  • FlashAttention: 4 MB / 2,000 GB/s = 2 μs per head

The idealized ~33x reduction in HBM traffic does not translate one-for-one into wall-clock time: compute also takes time, the tiled loop re-reads K and V once per block of Q, and the tiling logic adds overhead, so the realized speedup is closer to ~5x.
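Recounting the access list above in code is a simple way to check the arithmetic. The sketch assumes one head and float16 storage, and reports sizes in MiB (2^20 bytes), which is the convention behind the "1 MB" figures above; the FlashAttention count is the idealized one in which each of Q, K, V is read once and O is written once.

```python
MiB = 2 ** 20


def standard_attention_hbm_bytes(n: int, d: int, bytes_per_el: int = 2) -> int:
    """Unfused sequence: load Q, K, V; write/load S; write/load P; write O."""
    qkv_loads = 3 * n * d
    score_traffic = 4 * n * n  # write S, load S, write P, load P
    output_write = n * d
    return (qkv_loads + score_traffic + output_write) * bytes_per_el


def flash_attention_hbm_bytes_ideal(n: int, d: int, bytes_per_el: int = 2) -> int:
    """Idealized lower bound: Q, K, V each read once, O written once."""
    return 4 * n * d * bytes_per_el


std = standard_attention_hbm_bytes(4096, 128)
fa = flash_attention_hbm_bytes_ideal(4096, 128)
print(std / MiB, fa / MiB, std / fa)  # 132.0 4.0 33.0
```

The ratio works out to exactly 1 + N/d, so under this idealized count the advantage grows linearly with sequence length for a fixed head dimension.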

Standard Attention: Counting the HBM Accesses

The standard attention algorithm in PyTorch-style pseudocode:

import torch
import math
 
def standard_attention(Q, K, V, mask=None):
    """
    Standard attention implementation.
    Q, K, V: [batch, heads, seq_len, head_dim]
    Returns: [batch, heads, seq_len, head_dim]
    """
    N = Q.shape[-2]        # Sequence length
    d = Q.shape[-1]        # Head dimension
    scale = 1.0 / math.sqrt(d)
 
    # Step 1: Compute attention scores
    # Shape: [batch, heads, N, N]
    # HBM ops: Load Q (N×d), Load K (N×d), Write S (N×N)
    S = torch.matmul(Q, K.transpose(-2, -1)) * scale
 
    # Step 2: Apply mask if provided
    if mask is not None:
        S = S.masked_fill(mask == 0, float('-inf'))
 
    # Step 3: Softmax over last dimension
    # HBM ops: Load S (N×N), Write P (N×N)
    P = torch.softmax(S, dim=-1)
 
    # Step 4: Weighted sum of values
    # HBM ops: Load P (N×N), Load V (N×d), Write O (N×d)
    O = torch.matmul(P, V)
 
    return O
 
# Memory analysis (counting elements; multiply by 2 bytes for float16):
# HBM reads:  Q (Nd), K (Nd), S (N²), P (N²), V (Nd) = 2N² + 3Nd
# HBM writes: S (N²), P (N²), O (Nd)                 = 2N² + Nd
# Total: 4N² + 4Nd = O(N²) for large N (an optional mask adds another read and write of S)

The O(N²) HBM access pattern is the problem. For N=32,768 (32K context):

  • N² = 1 billion elements per head
  • At 2 bytes per element: 2 GB per head
  • With 32 heads: 64 GB of HBM traffic per attention layer
  • With 32 layers: 2 TB per forward pass, just for attention

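This is why long-context inference is expensive. These figures are the size of one materialized copy of the matrix; because S and P are each written and then read back, the actual HBM traffic is roughly four times larger, and it grows quadratically with N regardless of the numerical precision.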

The Online Softmax Trick: Computing Without the Full Matrix

The key mathematical insight that enables FlashAttention is the online softmax algorithm. Standard softmax requires two passes over the data: one to find the maximum (for numerical stability), one to compute the exponentials and normalize. Both passes require the full attention score vector.

The online softmax algorithm maintains running statistics (the current maximum and the current normalization factor) and updates them incrementally as new scores are processed. This lets the softmax normalizer be computed in a single left-to-right pass and, in FlashAttention, lets the softmax-weighted sum of values be accumulated the same way, without ever materializing the full N-length score vector.

The algorithm:

import math

def online_softmax_sequential(scores: list[float]) -> list[float]:
    """
    Compute softmax with the online algorithm: one pass builds the
    normalization statistics (max and sum) together, a second pass
    produces the outputs. Identical output to standard softmax.
    """
    # Initialize running statistics
    running_max = float('-inf')   # m: running maximum
    running_sum = 0.0             # l: running normalization factor

    # Single pass through scores to build (m, l)
    for s in scores:
        if s > running_max:
            # New maximum found: rescale the previous sum to the new maximum
            running_sum = running_sum * math.exp(running_max - s)
            running_max = s
        running_sum += math.exp(s - running_max)

    # Normalize every element against the final maximum
    return [math.exp(s - running_max) / running_sum for s in scores]

def online_softmax_blocked(scores: list[list[float]]) -> list[float]:
    """
    Process scores in blocks. Returns the same output as standard softmax
    over the concatenated blocks. This is the key to FlashAttention: each
    block is processed independently and the running statistics are merged
    across blocks.
    """
    running_max = float('-inf')
    running_sum = 0.0
    block_results = []

    for block in scores:
        # Local statistics for this block, relative to the block's own maximum
        block_max = max(block)
        block_exp = [math.exp(s - block_max) for s in block]
        block_sum = sum(block_exp)

        # Merge with running statistics: rescale both sums to the new maximum
        new_max = max(running_max, block_max)
        running_sum = (running_sum * math.exp(running_max - new_max)
                       + block_sum * math.exp(block_max - new_max))
        running_max = new_max

        # Keep block_exp relative to block_max; the final pass rescales it once
        block_results.append((block_exp, block_max))

    # Final normalization pass
    result = []
    for block_exp, block_max in block_results:
        rescale = math.exp(block_max - running_max)
        result.extend([e * rescale / running_sum for e in block_exp])

    return result

The blocked version produces mathematically identical output to standard softmax. (This illustration keeps every block's exponentials for the final pass.) Inside FlashAttention, only one block of B_c scores needs to be in SRAM at a time: the kernel updates the running statistics and immediately folds the block's exponentials into the output accumulator, so no N×N matrix is ever fully materialized.

The merge operation for combining block statistics is the non-obvious piece. When a new block has a higher maximum than the running maximum, all previous exponentials must be rescaled by exp(old_max - new_max). Because exp(s - old_max) × exp(old_max - new_max) = exp(s - new_max), this rescaling is exact, and it can be applied lazily to the running sum (and, in FlashAttention, to the output accumulator) rather than to each previous value individually, which is what keeps the blocked computation cheap.
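The merge rule can be isolated into a small function. In the sketch below, block_stats and merge_stats are illustrative names, not library APIs; the check confirms that merging per-block (max, sum) pairs reproduces the statistics of the whole vector and that the grouping of merges does not change the result. That associativity is what allows blocks to be processed in any order and combined afterwards.

```python
import math


def block_stats(scores: list[float]) -> tuple[float, float]:
    """Return (m, l) for one block: its max and the sum of exp(s - m)."""
    m = max(scores)
    return m, sum(math.exp(s - m) for s in scores)


def merge_stats(a: tuple[float, float], b: tuple[float, float]) -> tuple[float, float]:
    """Exactly combine two (max, sum) pairs by rescaling each sum to the larger max."""
    m_a, l_a = a
    m_b, l_b = b
    m = max(m_a, m_b)
    return m, l_a * math.exp(m_a - m) + l_b * math.exp(m_b - m)


scores = [0.5, 3.0, -1.0, 2.5, 4.0, 0.0]
blocks = [block_stats(scores[i:i + 2]) for i in range(0, len(scores), 2)]

whole = block_stats(scores)
left_to_right = merge_stats(merge_stats(blocks[0], blocks[1]), blocks[2])
other_grouping = merge_stats(blocks[0], merge_stats(blocks[1], blocks[2]))

# Any softmax probability can be recovered from the merged statistics
m, l = left_to_right
probs = [math.exp(s - m) / l for s in scores]
```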

FlashAttention Tiling: The Core Algorithm

With the online softmax in hand, the full FlashAttention forward pass algorithm:

Input: Q, K, V ∈ ℝ^{N×d}  (in HBM)
Output: O ∈ ℝ^{N×d}  (in HBM)
Block sizes: B_r (rows of Q block), B_c (columns of K/V block)
  - Choose B_r, B_c such that one Q block + one K/V block fits in SRAM
 
Algorithm:
1. Divide Q into T_r = ⌈N/B_r⌉ blocks: Q_1, ..., Q_{T_r}
2. Divide K, V into T_c = ⌈N/B_c⌉ blocks: K_1, ..., K_{T_c} and V_1, ..., V_{T_c}
 
3. For i = 1 to T_r:                          // Outer loop: over Q blocks
   a. Load Q_i from HBM to SRAM               // One read of size B_r × d
   b. Initialize O_i = 0, l_i = 0, m_i = -inf // Running statistics in registers
 
   For j = 1 to T_c:                          // Inner loop: over K/V blocks
     c. Load K_j, V_j from HBM to SRAM        // One read of size B_c × d each
 
     d. Compute S_ij = Q_i K_j^T / √d         // Stays in SRAM: B_r × B_c
 
     e. m̃_ij = rowmax(S_ij)                   // Row maximum of current block
        P̃_ij = exp(S_ij - m̃_ij)              // Numerically stable exponentials
        l̃_ij = rowsum(P̃_ij)                  // Row sum for this block
 
     f. Merge with running statistics:
        m_new = max(m_i, m̃_ij)
        l_new = l_i * exp(m_i - m_new) + l̃_ij * exp(m̃_ij - m_new)
 
     g. Update output accumulator:
        O_i = diag(l_new)^{-1} * (
          diag(l_i) * exp(m_i - m_new) * O_i    // Rescale previous output
          + P̃_ij * exp(m̃_ij - m_new) * V_j     // Add contribution from current block
        )
 
     h. Update statistics: m_i = m_new, l_i = l_new
 
   i. Write O_i to HBM                         // One write of size B_r × d
 
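Total HBM reads:  Q once (N×d) + K and V once per Q block = N×d + 2 × T_r × N×d
Total HBM writes: O (N×d)
With T_r = N/B_r and B_r = Θ(M/d), where M = SRAM size: Θ(N²d²/M) reads
For typical d (64-128) and M (~100 KB), d² is many times smaller than M, so this is many times
fewer accesses than the Θ(N² + Nd) of standard attention; the extra memory needed is only O(N)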

In Python, the block computation looks like:

import torch
 
def flash_attention_forward(Q, K, V, block_size: int = 64) -> torch.Tensor:
    """
    FlashAttention forward pass in Python (for illustration, not CUDA-optimized).
    Q, K, V: [N, d] (single head, single batch)
    """
    N, d = Q.shape
    scale = 1.0 / (d ** 0.5)
 
    # Output accumulator and running statistics (same dtype/device as Q)
    O = torch.zeros_like(Q)                                                 # [N, d]
    l = torch.zeros(N, 1, dtype=Q.dtype, device=Q.device)                   # Row normalization factors [N, 1]
    m = torch.full((N, 1), float('-inf'), dtype=Q.dtype, device=Q.device)  # Row maximums [N, 1]
 
    # Block sizes (in practice chosen to fit SRAM)
    B_r = block_size   # Rows of Q block
    B_c = block_size   # Columns of K/V block
 
    T_r = (N + B_r - 1) // B_r   # Number of Q blocks
    T_c = (N + B_c - 1) // B_c   # Number of K/V blocks
 
    for i in range(T_r):
        # Slice of query block [B_r, d]
        r_start, r_end = i * B_r, min((i + 1) * B_r, N)
        Q_i = Q[r_start:r_end]          # Load from "HBM" to "SRAM"
        O_i = O[r_start:r_end]
        l_i = l[r_start:r_end]
        m_i = m[r_start:r_end]
 
        for j in range(T_c):
            # Slice of key/value block [B_c, d]
            c_start, c_end = j * B_c, min((j + 1) * B_c, N)
            K_j = K[c_start:c_end]      # Load from "HBM" to "SRAM"
            V_j = V[c_start:c_end]
 
            # Attention scores for this block: [B_r, B_c]
            S_ij = (Q_i @ K_j.T) * scale
 
            # Block statistics
            m_tilde = S_ij.max(dim=-1, keepdim=True).values  # [B_r, 1]
            P_tilde = torch.exp(S_ij - m_tilde)              # [B_r, B_c]
            l_tilde = P_tilde.sum(dim=-1, keepdim=True)       # [B_r, 1]
 
            # Merge with running statistics
            m_new = torch.maximum(m_i, m_tilde)               # [B_r, 1]
            l_new = (l_i * torch.exp(m_i - m_new) +
                     l_tilde * torch.exp(m_tilde - m_new))    # [B_r, 1]
 
            # Update output: rescale previous accumulator + add current block's contribution.
            # O_i is kept unnormalized and divided by l_i once, after the inner loop.
            O_i = (torch.exp(m_i - m_new) * O_i +
                   torch.exp(m_tilde - m_new) * (P_tilde @ V_j))
 
            # Update running statistics
            m_i = m_new
            l_i = l_new
 
        # Finalize output for this row block
        O[r_start:r_end] = O_i / l_i    # Normalize
        l[r_start:r_end] = l_i
        m[r_start:r_end] = m_i
 
    return O

The block size B_r × B_c must be chosen such that Q_i (B_r × d) + K_j (B_c × d) + V_j (B_c × d) + S_ij (B_r × B_c) fit in SRAM. On A100 with 192 KB SRAM per SM, and d=128, float16:

  • Q_i: B_r × 128 × 2 bytes
  • K_j + V_j: 2 × B_c × 128 × 2 bytes
  • S_ij: B_r × B_c × 2 bytes
  • Total: 256 × (B_r + 2B_c) + 2 × B_r × B_c ≤ 196,608 bytes (192 KB)

A typical configuration: B_r = B_c = 64 → 256 × 192 + 2 × 4096 = 57,344 bytes = 56 KB. Fits comfortably.
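The budget check is easy to script. This sketch uses the same accounting as above (Q_i, K_j, V_j, and S_ij in float16) against a nominal 192 KB; it deliberately ignores the output accumulator O_i, higher-precision score storage, and the fact that an A100 exposes at most about 164 KB of that 192 KB as shared memory, so real kernels choose smaller tiles than this bound allows.

```python
SRAM_BUDGET = 192 * 1024  # bytes; nominal A100 per-SM figure (assumption)


def tile_bytes(b_r: int, b_c: int, d: int = 128, bytes_per_el: int = 2) -> int:
    """SRAM for Q_i (b_r x d), K_j and V_j (b_c x d each), and S_ij (b_r x b_c)."""
    q_i = b_r * d
    k_j_and_v_j = 2 * b_c * d
    s_ij = b_r * b_c
    return (q_i + k_j_and_v_j + s_ij) * bytes_per_el


def largest_square_tile(d: int = 128, budget: int = SRAM_BUDGET, step: int = 16) -> int:
    """Largest B (a multiple of `step`) such that B_r = B_c = B fits in `budget`."""
    b = step
    while tile_bytes(b + step, b + step, d) <= budget:
        b += step
    return b


print(tile_bytes(64, 64))     # 57344 bytes = 56 KB
print(largest_square_tile())  # 160
```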

FlashAttention-2: Warp Partitioning and Reduced Non-Matmul FLOPs

FlashAttention-2 (Dao, 2023) achieved a 2x speedup over FlashAttention-1 through three improvements:

Improvement 1: Reduced non-matmul FLOPs

FA1 kept the output accumulator O_i normalized, so every K/V block iteration rescaled it by diag(l_i) and divided it by diag(l_new): extra elementwise work on all B_r × d entries of O_i per block. FA2 keeps O_i unnormalized: it still rescales by exp(m_old - m_new) on each iteration, but divides by l only once, at the end of the row block, so T_c divisions per output element become one. FA2 also saves a single logsumexp statistic per row for the backward pass instead of both m and l.

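On A100, FP16/BF16 matrix multiplications run at up to 312 TFLOPs/s on Tensor Cores, while non-matmul FP32 operations run at approximately 19.5 TFLOPs/s, a 16x gap. Each non-matmul FLOP is therefore about 16 times as expensive as a matmul FLOP, so trimming them has a disproportionate impact on total compute time.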
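A single-query sketch in plain Python shows the difference between the two update rules. attend_fa1 and attend_fa2 are hypothetical helpers written for this illustration; both produce the same output, but the FA-1 form divides every output element by l_new on each K/V block, while the FA-2 form keeps an unnormalized accumulator and divides once at the end. Both still apply the exp(m_old - m_new) rescaling on every block.

```python
import math
import random


def attend_fa1(scores, values, block):
    """FA-1 style: keep the output normalized, dividing by l_new on every block."""
    m, l = float("-inf"), 0.0
    out = [0.0] * len(values[0])
    divisions = 0
    for start in range(0, len(scores), block):
        s_blk, v_blk = scores[start:start + block], values[start:start + block]
        m_blk = max(s_blk)
        p = [math.exp(s - m_blk) for s in s_blk]
        m_new = max(m, m_blk)
        alpha, beta = math.exp(m - m_new), math.exp(m_blk - m_new)
        l_new = l * alpha + sum(p) * beta
        for k in range(len(out)):
            pv = sum(pj * vj[k] for pj, vj in zip(p, v_blk))
            out[k] = (l * alpha * out[k] + beta * pv) / l_new
            divisions += 1
        m, l = m_new, l_new
    return out, divisions


def attend_fa2(scores, values, block):
    """FA-2 style: unnormalized accumulator, divided by l once at the end."""
    m, l = float("-inf"), 0.0
    acc = [0.0] * len(values[0])
    for start in range(0, len(scores), block):
        s_blk, v_blk = scores[start:start + block], values[start:start + block]
        m_blk = max(s_blk)
        p = [math.exp(s - m_blk) for s in s_blk]
        m_new = max(m, m_blk)
        alpha, beta = math.exp(m - m_new), math.exp(m_blk - m_new)
        l = l * alpha + sum(p) * beta
        for k in range(len(acc)):
            pv = sum(pj * vj[k] for pj, vj in zip(p, v_blk))
            acc[k] = alpha * acc[k] + beta * pv  # still rescaled, but no division
        m = m_new
    return [a / l for a in acc], len(acc)


random.seed(0)
scores = [random.gauss(0, 2) for _ in range(16)]                      # one query, 16 keys
values = [[random.gauss(0, 1) for _ in range(3)] for _ in range(16)]  # d = 3
out1, div1 = attend_fa1(scores, values, block=4)
out2, div2 = attend_fa2(scores, values, block=4)
print(div1, div2)  # 12 3  (T_c * d divisions vs d)
```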

Improvement 2: Better parallelism across thread blocks

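FA1 parallelized over batch size and heads: one thread block per (batch, head) pair, B × H thread blocks in total, each running the full loop over that head's sequence. That keeps the GPU busy only when B × H is large relative to the number of SMs (108 on an A100). Long sequences usually force small batches, so FA2 additionally parallelizes over the sequence dimension, launching one thread block per block of Q rows (B × H × T_r thread blocks). This improves GPU utilization especially for long sequences with small batches or few heads, where FA1 would leave many SMs idle.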
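The grid-size difference is simple arithmetic. The sketch assumes 108 SMs (A100) and a Q row-block size of 128, both illustrative; FA-1 launches one thread block per (batch, head) pair, while FA-2 additionally splits each head's sequence into row blocks.

```python
import math

NUM_SMS = 108  # A100 (assumed target GPU)


def fa1_thread_blocks(batch: int, heads: int) -> int:
    """One thread block per (batch, head) pair."""
    return batch * heads


def fa2_thread_blocks(batch: int, heads: int, seq_len: int, block_rows: int = 128) -> int:
    """One thread block per (batch, head, block of Q rows)."""
    return batch * heads * math.ceil(seq_len / block_rows)


# Long-context case: a single 16K-token sequence with 16 heads
fa1 = fa1_thread_blocks(batch=1, heads=16)
fa2 = fa2_thread_blocks(batch=1, heads=16, seq_len=16384)
print(fa1, fa2)  # 16 2048
```

Sixteen thread blocks can occupy at most 16 of the 108 SMs, whereas 2,048 blocks give the scheduler enough work to fill the whole GPU.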

Improvement 3: Warp-level parallelism for the inner loop

Within a thread block, FA2 also changes how the work of each block iteration is split across warps:

FA1 warp partitioning ("split-K"):
- K_j / V_j are split across warps; every warp can read all of Q_i
- Each warp computes a slice of Q_i K_j^T and multiplies it by its slice of V_j, producing a partial result for all rows of O_i
- The partial results must be written to shared memory, synchronized, and summed → communication overhead

FA2 warp partitioning:
- Q_i is split across warps; every warp can read all of K_j / V_j
- Each warp computes complete output rows of O_i for its own slice of Q_i
- No partial results need to be exchanged between warps

Because each warp owns distinct output rows in FA2, the shared-memory writes, reads, and synchronizations that FA1 needed to combine partial outputs disappear, which speeds up every inner-loop iteration.

The combined effect: FlashAttention-2 achieves 50-73% of A100 theoretical peak FLOPs/s (depending on sequence length and batch size), compared to 25-40% for FA1. The remaining gap from 100% theoretical peak is due to non-matmul operations and memory access latency that cannot be fully hidden.

The Backward Pass: Recomputation Instead of Storage

The backward pass for attention requires the attention probability matrix P (the softmax output) to compute gradients. Standard backpropagation stores P in HBM during the forward pass for use during the backward pass. This is the O(N²) memory requirement of standard attention.

FlashAttention uses a different approach: recomputation. During the backward pass, instead of loading P from HBM, it recomputes P blockwise from Q and K plus the per-row softmax statistics saved in the forward pass. Since Q and K are much smaller than P (O(Nd) vs O(N²)) and the statistics are only O(N), this trades extra compute for a large memory saving.

The tradeoff at 4K context, 128 head dimension:

  • Storing P: 4096² × 2 bytes = 32 MB per head
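  • Recomputing P: no N×N storage; Q and K (1 MB each) are needed for the backward pass anyway, plus per-row softmax statistics (one or two fp32 values per row: 16-32 KB per head), at the cost of extra compute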

At 32K context, storing P would take 2 GB per head versus 128-256 KB of row statistics, and the recomputation costs roughly 30% additional compute during the backward pass. Since memory bandwidth, not compute, is the bottleneck, this tradeoff is favorable.

import torch

def flash_attention_backward_sketch(dO, Q, K, V, O, L, M, block_size: int = 64):
    """
    Sketch of the FlashAttention backward pass (single head, single batch).
    L: row normalization sums l from the forward pass, shape [N, 1]
    M: row maximums m from the forward pass, shape [N, 1]

    Key: P is RECOMPUTED blockwise from Q, K, L, M, not loaded from HBM.
    """
    N, d = Q.shape
    dQ = torch.zeros_like(Q)
    dK = torch.zeros_like(K)
    dV = torch.zeros_like(V)

    scale = 1.0 / (d ** 0.5)

    # D_i = rowsum(dO_i * O_i), needed by the softmax gradient; only O(N) values
    D = (dO * O).sum(dim=-1, keepdim=True)

    # Outer loop over K/V blocks
    for j in range(0, N, block_size):
        K_j = K[j:j + block_size]
        V_j = V[j:j + block_size]

        # Inner loop over Q blocks
        for i in range(0, N, block_size):
            Q_i = Q[i:i + block_size]
            dO_i = dO[i:i + block_size]

            # Recompute attention probabilities (not loaded from HBM)
            S_ij = (Q_i @ K_j.T) * scale
            P_ij = torch.exp(S_ij - M[i:i + block_size]) / L[i:i + block_size]

            # Standard attention gradient formulas, applied per block
            dV[j:j + block_size] += P_ij.T @ dO_i
            dP_ij = dO_i @ V_j.T
            dS_ij = P_ij * (dP_ij - D[i:i + block_size])
            dQ[i:i + block_size] += (dS_ij @ K_j) * scale
            dK[j:j + block_size] += (dS_ij.T @ Q_i) * scale

    return dQ, dK, dV
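FlashAttention-2 simplifies what the forward pass saves: instead of both m and l, it stores a single logsumexp per row, L = m + log(l), and the backward pass recomputes P as exp(S - L). The plain-Python sketch below (helper names are illustrative) checks that the two forms agree and compares the storage for N = 32,768, assuming P would be kept in float16 and the statistic in float32.

```python
import math


def row_stats(scores: list[float]) -> tuple[float, float]:
    """Row maximum m and normalizer l = sum(exp(s - m))."""
    m = max(scores)
    return m, sum(math.exp(s - m) for s in scores)


scores = [1.0, -0.5, 2.0, 0.25]
m, l = row_stats(scores)
lse = m + math.log(l)  # the single per-row statistic

p_from_m_and_l = [math.exp(s - m) / l for s in scores]
p_from_lse = [math.exp(s - lse) for s in scores]


def stored_bytes(n: int) -> dict[str, int]:
    """Per-head storage: full P in float16 vs one float32 logsumexp per row."""
    return {"P_fp16": 2 * n * n, "lse_fp32": 4 * n}


sizes = stored_bytes(32768)
print(sizes["P_fp16"] // 2**30, "GiB vs", sizes["lse_fp32"] // 2**10, "KiB")  # 2 GiB vs 128 KiB
```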

Performance Numbers: What the Benchmarks Show

FlashAttention achieves significant improvements across the GPU model landscape:

A100 attention throughput (forward + backward, seq_len=2048, head_dim=64):

Implementation                FLOPs/s          % of A100 peak   Speedup vs standard
Standard PyTorch attention    ~95 TFLOPs/s     ~30%             1.0x
FlashAttention-1              ~180 TFLOPs/s    ~58%             1.9x
FlashAttention-2              ~225 TFLOPs/s    ~72%             2.4x

FlashAttention-3 targets H100 GPUs, so it is not directly comparable with these A100 rows; its authors report up to ~740 TFLOPs/s in FP16 on H100, about 75% of that GPU's peak.

End-to-end GPT training speedup (reported by Dao et al.):

  • GPT-2 small and medium (1K context): FA-1 trains roughly 3x faster than the HuggingFace implementation and up to 1.7x faster than Megatron-LM
  • GPT-3 style (175B): FA-1 gives 3x memory reduction and 1.5x training speedup
  • Long-context (sequence 16K vs 1K): FA-1 gives 10x memory reduction, 6x attention speedup

Memory reduction (no intermediate N×N matrix):

At seq_len=32K, d=128, float16:

  • Standard attention: 32,768² × 2 bytes = 2 GB per head per layer
  • FlashAttention: ~0 GB for intermediates (O(N) for running statistics only)

The memory reduction is the more significant practical benefit for long-context training. Without FlashAttention, training a 70B model at 32K context would require materializing 2 GB per head per layer for the attention matrix alone, which is impractical even on large GPU clusters. With FlashAttention, attention needs only the O(N·d) storage for Q, K, V, and O plus O(N) row statistics; no N×N intermediate is needed.

Ring Attention: Multi-GPU Long-Context Scaling

Ring Attention (Liu et al., 2023) extends FlashAttention to multi-GPU settings, enabling training and inference at context lengths beyond what fits on a single GPU.

The approach: distribute the sequence across P GPUs, with each GPU holding N/P tokens. For attention, each GPU needs to compute its portion of the attention output using Q from its own slice and K/V from all P slices. Ring Attention arranges GPUs in a ring topology where each GPU passes its K/V block to the next GPU while simultaneously processing the K/V block it received from the previous GPU.

GPU 0:                GPU 1:                GPU 2:                GPU 3:
Holds Q[0:N/P]        Holds Q[N/P:2N/P]     Holds Q[2N/P:3N/P]    Holds Q[3N/P:N]
     ↓                      ↓                      ↓                      ↓
Step 1: Each GPU computes attention with its own K/V
Step 2: Each GPU sends its K/V to next in ring, receives K/V from previous
        → compute → pass → compute → pass → ...
After P steps: each GPU has processed all K/V blocks for its Q slice
 
Communication: per step, each GPU sends its K and V blocks: 2 × (N/P) × d × 2 bytes (fp16)
             × (P-1) steps ≈ 4 × N × d bytes total ≈ the size of the full K and V (linear in N, not N²)
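The sketch below simulates that schedule for a single head in plain Python: devices are list indices, passing a K/V block just changes which block each device holds, and each device merges every incoming block into running softmax statistics for its own queries. The names ring_attention and reference_attention are illustrative; a real implementation overlaps the send/receive with a FlashAttention kernel on each GPU.

```python
import math
import random


def ring_attention(q, k, v, num_devices):
    """Simulate Ring Attention for one head using plain lists.

    q, k, v: N rows of d floats each; device p owns rows p*c:(p+1)*c,
    with c = N // num_devices (N is assumed divisible by num_devices).
    """
    n, d = len(q), len(q[0])
    c = n // num_devices
    scale = 1.0 / math.sqrt(d)
    # Running (max, sum, unnormalized output) for every local query row
    state = [[(float("-inf"), 0.0, [0.0] * d) for _ in range(c)]
             for _ in range(num_devices)]
    holder = list(range(num_devices))  # holder[p]: whose K/V block device p holds

    for step in range(num_devices):
        for p in range(num_devices):
            src = holder[p]
            k_blk, v_blk = k[src * c:(src + 1) * c], v[src * c:(src + 1) * c]
            for r in range(c):
                s = [scale * sum(a * b for a, b in zip(q[p * c + r], kj)) for kj in k_blk]
                m_old, l_old, acc = state[p][r]
                m_new = max(m_old, max(s))
                alpha = math.exp(m_old - m_new)
                w = [math.exp(x - m_new) for x in s]
                acc = [alpha * a + sum(wj * vj[t] for wj, vj in zip(w, v_blk))
                       for t, a in enumerate(acc)]
                state[p][r] = (m_new, alpha * l_old + sum(w), acc)
        if step < num_devices - 1:
            # One ring hop: device p receives the block held by device p - 1
            holder = [holder[(p - 1) % num_devices] for p in range(num_devices)]

    return [[a / l for a in acc] for device in state for (_, l, acc) in device]


def reference_attention(q, k, v):
    """Plain softmax(q k^T / sqrt(d)) v for comparison."""
    d = len(q[0])
    out = []
    for qi in q:
        s = [sum(a * b for a, b in zip(qi, kj)) / math.sqrt(d) for kj in k]
        m = max(s)
        w = [math.exp(x - m) for x in s]
        out.append([sum(wj * vj[t] for wj, vj in zip(w, v)) / sum(w) for t in range(d)])
    return out
```

Only P-1 hops are needed, which is why the loop skips the rotation after the final step.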

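Ring Attention achieves linear memory scaling with sequence length: each GPU holds only its N/P slice, so adding GPUs scales the total context length proportionally. As an illustration, for a 70B model on 80 GB GPUs (with the model's weights themselves sharded across devices), if one GPU's activation budget allows ~32K tokens after FlashAttention's memory reduction: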

  • Single GPU: ~32K context maximum (after FA memory reduction)
  • 4 GPUs in ring: ~128K context
  • 8 GPUs: ~256K context
  • 16 GPUs: ~512K context

The communication overhead is O(N×d), the same order as the K/V matrix size, while each step's computation grows as O((N/P)²×d). For long enough per-GPU blocks, computation therefore dominates and the transfers can be hidden behind it on high-bandwidth interconnects (NVLink on A100: 600 GB/s total bidirectional).
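Whether the transfers actually hide behind compute can be estimated per ring step. With a per-GPU block of c tokens and head dimension d, a step computes roughly 4c²d FLOPs (QK^T and PV) and moves the K and V blocks, 4cd bytes in float16, so compute time covers transfer time once c ≥ F/B, where F is peak FLOPs/s and B is per-direction link bandwidth. The figures below (312 TFLOPs/s, ~300 GB/s per direction, i.e. half of NVLink's 600 GB/s bidirectional total) are assumptions; causal masking, the backward pass, and real kernel efficiency all shift the threshold.

```python
def min_block_tokens(peak_flops: float, link_bytes_per_s: float) -> float:
    """Smallest per-GPU block c (tokens) with compute time >= transfer time.

    4*c*c*d / F >= 4*c*d / B  simplifies to  c >= F / B, independent of d.
    """
    return peak_flops / link_bytes_per_s


def step_times(c: int, d: int, peak_flops: float, link_bytes_per_s: float) -> tuple[float, float]:
    """(compute seconds, transfer seconds) for one ring step with fp16 K/V."""
    compute = 4 * c * c * d / peak_flops         # QK^T plus PV on a c x c block
    transfer = 2 * c * d * 2 / link_bytes_per_s  # K and V blocks, 2 bytes per element
    return compute, transfer


F, B = 312e12, 300e9
print(min_block_tokens(F, B))  # 1040.0
compute, transfer = step_times(c=4096, d=128, peak_flops=F, link_bytes_per_s=B)
```

At c = 4096 tokens per GPU (32K context over 8 GPUs), compute per step comes out at roughly four times the transfer time under these assumptions.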

Ring Attention and related sequence-parallel attention schemes are a standard way to train at 100K+ context; the exact methods behind proprietary long-context models such as Claude or GPT-4 have not been published.

Practical Usage: When FlashAttention Helps Most

FlashAttention is available as a drop-in replacement for standard attention in PyTorch through torch.nn.functional.scaled_dot_product_attention (which gained a FlashAttention-2 backend in PyTorch 2.2) and through the flash-attn package.

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel  # PyTorch >= 2.3

# Method 1: PyTorch native (uses the FlashAttention backend when eligible)
def efficient_attention(Q, K, V, attn_mask=None, dropout_p=0.0, is_causal=False):
    """
    Q, K, V: [batch, heads, seq_len, head_dim].
    PyTorch can use the FlashAttention backend if:
    - Running on CUDA
    - Head dim ≤ 256
    - Inputs are float16 or bfloat16
    - No custom attention mask (causal masking via is_causal=True is fine)
    """
    # Restrict SDPA to the fused kernels; the naive "math" backend is excluded
    with sdpa_kernel([SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]):
        return F.scaled_dot_product_attention(
            Q, K, V,
            attn_mask=attn_mask,
            dropout_p=dropout_p,
            is_causal=is_causal,
        )

# Method 2: Direct flash-attn package (variable-length API)
from flash_attn import flash_attn_varlen_func
from flash_attn.bert_padding import unpad_input, pad_input

def flash_attention_varlen(Q, K, V, attention_mask, causal=True):
    """
    Variable-length sequences with padding removal.
    Q, K, V: [batch, seq_len, heads, head_dim] (flash-attn layout), fp16/bf16 on CUDA.
    attention_mask: [batch, seq_len] bool, True for real tokens.
    """
    batch, seq_len = Q.shape[0], Q.shape[1]

    # Drop padding tokens; [:4] because some flash-attn versions return a fifth value
    Q_unpad, indices, cu_seqlens, max_seqlen = unpad_input(Q, attention_mask)[:4]
    K_unpad = unpad_input(K, attention_mask)[0]
    V_unpad = unpad_input(V, attention_mask)[0]

    # FlashAttention over the packed tokens, using cumulative sequence lengths
    output_unpad = flash_attn_varlen_func(
        Q_unpad, K_unpad, V_unpad,
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_k=cu_seqlens,
        max_seqlen_q=max_seqlen,
        max_seqlen_k=max_seqlen,
        causal=causal,
    )

    # Scatter back to the padded [batch, seq_len, heads, head_dim] layout
    return pad_input(output_unpad, indices, batch, seq_len)

FlashAttention helps most when:

  • Sequence length ≥ 1024 (a rule of thumb: below this, the savings are small relative to fixed overhead)
  • Running float16 or bfloat16 (required by the FlashAttention backend; float32 inputs fall back to another SDPA backend)
  • Using causal attention (decoder-only models): FA skips K/V blocks that lie entirely above the diagonal, roughly halving the work
  • Memory-constrained training at long context lengths

FlashAttention helps least when:

  • Short sequences (<512 tokens): standard attention may be faster due to lower overhead
  • Float32 precision required: FA2 only supports 16-bit
  • Complex custom attention patterns: sparse attention, sliding window with non-trivial patterns
  • CPU inference: FA is a CUDA kernel with no benefit on CPU

Key Takeaways

  • FlashAttention is IO-aware attention: by tiling the computation into blocks that fit in SRAM, it never writes the N×N matrix to HBM, cutting HBM reads/writes many-fold and the extra memory from O(N²) to O(N). At 32K context, attention intermediates shrink from 2 GB per head to a few hundred KB of row statistics, making long-context training and inference feasible.

  • The online softmax algorithm is the mathematical enabler: it computes the exact same softmax output as the standard two-pass algorithm but processes data in blocks using running maximum and normalization statistics. The merge operation between blocks ensures numerical correctness without ever materializing the full N×N attention matrix.

  • FlashAttention-2 achieves 2x speedup over FA-1 through three mechanisms: deferred normalization (dividing the output by l once per row block instead of on every K/V block), parallelism over the sequence dimension in addition to batch and heads, and warp-level partitioning over rows of Q rather than over K/V (which removes inter-warp reductions through shared memory).

  • On A100 GPUs, FlashAttention-2 achieves 50-73% of theoretical peak FLOPs/s compared to 25-40% for FA-1 and ~30% for standard attention. The remaining gap from 100% peak is due to irreducible non-matmul operations and memory latency that cannot be fully hidden by tiling.

  • The backward pass uses recomputation instead of storage: P (the N×N softmax output) is recomputed during backpropagation from Q, K, and saved per-row statistics rather than stored from the forward pass. This trades roughly 30% extra compute during backward for eliminating 2 GB per head of HBM storage at 32K context, a clear win for long sequences.

  • Ring Attention extends FlashAttention to multi-GPU settings by distributing sequences across GPUs in a ring topology. Communication overhead is O(N×d), linear in sequence length, enabling linear scaling of maximum context length with GPU count. This makes it one of the main mechanisms for 128K+ context training.

FAQ

What problem does FlashAttention solve?

FlashAttention solves the memory bandwidth bottleneck in standard transformer attention. Standard attention computes the full N×N attention matrix, which must be written to and read from high-bandwidth memory (HBM): 32 MB per head at 4K context, and 134 MB at 8K. This HBM traffic, not the arithmetic operations, is the bottleneck. FlashAttention uses a tiling algorithm that processes the attention computation in small blocks that fit in the faster on-chip SRAM (roughly 10x the bandwidth of HBM), computing the exact same result without ever materializing the full N×N matrix. The result is far fewer HBM accesses (about 33x fewer in the idealized 4K example), translating to a 2-3x overall speedup and dramatically reduced memory usage, eliminating the O(N²) memory requirement that made long-context training prohibitively expensive.

What is the online softmax trick in FlashAttention?

The online softmax trick is the mathematical technique that makes FlashAttention's block processing produce numerically correct results. Standard softmax requires two passes over all N scores: one to find the maximum (for numerical stability), one to compute exponentials and normalize. The online algorithm maintains two running statistics (the current maximum m and the current normalization sum l) and updates them as each block of scores is processed. When a new block has a higher maximum, previous accumulated values are rescaled by exp(old_max - new_max). When the new block's maximum is lower, the block's contribution is rescaled. After all blocks, the correct softmax normalization is known without ever materializing the full N-score vector. This is mathematically identical to standard softmax: same output, different computation order that enables the block processing FlashAttention requires.

How much faster is FlashAttention compared to standard attention?

FlashAttention-2 achieves approximately 2-3x end-to-end speedup compared to standard PyTorch attention for typical sequence lengths (2K-8K). On A100 GPUs, FA-2 reaches 50-73% of theoretical peak FLOPs/s (up to ~225 TFLOPs/s) compared to 25-40% for FA-1 and ~30% for standard attention. The speedup increases with sequence length: at 32K context, FlashAttention is approximately 5-8x faster for the attention computation alone. In end-to-end GPT-scale training, FlashAttention reduces training time by 1.5-3x and enables 3-10x reduction in peak memory usage for long-context experiments.

What is Ring Attention and how does it scale context length?

Ring Attention is a multi-GPU extension of FlashAttention that enables context lengths beyond the memory capacity of a single GPU. The sequence is distributed across P GPUs, each holding N/P tokens. GPUs are arranged in a ring topology: during attention computation, each GPU processes its Q block against the K/V block it currently holds, then passes its K/V block to the next GPU and receives the previous GPU's K/V block. After P steps, each GPU has processed its Q block against all K/V blocks. The communication cost is O(N×d) total, linear in sequence length rather than quadratic, which is the same order as reading K/V once. Ring Attention enables linear scaling: the maximum context length grows in proportion to the number of GPUs, allowing 128K+ context training by distributing across 8-16 GPUs on high-bandwidth NVLink interconnects.

FlashAttention is not a trick or an approximation. It computes the exact same attention function as the standard algorithm, using the same mathematical operations in a different order. The O(N²) HBM access cost is not a consequence of the attention mechanism's mathematics. It is a consequence of how naive implementations materialize intermediate computations. FlashAttention reveals that the intermediate materialization was never necessary.

The tiling insight applies broadly beyond attention: any computation that involves large intermediate matrices can potentially be restructured to avoid HBM round-trips. The online softmax algorithm that makes FlashAttention possible is one instance of a general technique for computing aggregations over blocked data without materialization.

For practitioners, the action items are straightforward: use torch.nn.functional.scaled_dot_product_attention, which selects the FlashAttention backend automatically, or the flash-attn package for more control. At sequence lengths above 1K and in float16 or bfloat16, you get the speedup without any changes to your model architecture.

The longer-term implication: FlashAttention makes 100K-1M context lengths economically viable. The models trained on these context lengths have qualitatively different capabilities from models with 4K-8K context. The efficiency improvement is not faster training alone. It expands the space of possible models.

Written & published by Chaitanya Prabuddha