Triton Flash Attention

Chapter 4 — O(N) memory attention with tiling and online softmax

The Problem with Standard Attention

Scaled dot-product attention computes: O = softmax(QK^T / √d) V

The bottleneck is the QK^T matrix, which is (N × N) where N is the sequence length. For a 4096-token context in float32, this is 4096 × 4096 × 4 bytes = 64 MB per attention head. For a 32K context, it's 4 GB. For 128K (common in modern LLMs), it's 64 GB for a single head, which fills or exceeds the memory of most GPUs before anything else is stored.
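
The figures above follow directly from N × N × bytes per element. A minimal sketch of that arithmetic, assuming a single head of a single sequence; the helper name attention_matrix_bytes is hypothetical and used only for illustration:

```python
def attention_matrix_bytes(seq_len: int, bytes_per_element: int = 4) -> int:
    """Bytes needed to store one (seq_len x seq_len) score matrix."""
    return seq_len * seq_len * bytes_per_element


# Context lengths discussed in the text: 4096, 32K, 128K tokens
for n in (4096, 32 * 1024, 128 * 1024):
    mib = attention_matrix_bytes(n) / 2**20
    print(f"N={n:>6}: {mib:>8.0f} MiB")
```

Multiplying by the number of heads and the batch size gives the footprint of a full layer, which is why the quadratic term dominates long before the weights do.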

Beyond the memory cost, standard attention reads and writes this matrix multiple times (once to compute, once for softmax, once for the weighted sum) — burning memory bandwidth even when the matrix fits.

The Flash Attention Insight: Tiling + Online Softmax

FlashAttention (Dao et al., 2022) avoids materializing the full attention matrix by computing attention in tiles and using an online softmax algorithm to merge partial results.

Online Softmax: the Key Algorithm

To compute softmax over a sequence we haven't fully seen yet, we maintain three running statistics:

  • m — running maximum of all logits seen so far
  • l — running sum of exp(logit - m) seen so far
  • O — running weighted sum of V values seen so far

When we see a new tile of K/V pairs, we update all three in a numerically stable way. After processing all tiles, O / l gives the exact softmax-weighted output. No N×N matrix ever exists in memory.
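
The three running statistics can be tried out on a toy stream before moving to tiles. The following is a minimal sketch in plain Python, assuming scalar values and one logit at a time (a tile of size 1); the function names are hypothetical, not part of any library:

```python
import math


def online_softmax_weighted_sum(logits, values):
    """Stream (logit, value) pairs once; return the softmax-weighted mean of values."""
    m = float("-inf")  # running maximum of logits seen so far
    l = 0.0            # running sum of exp(logit - m)
    o = 0.0            # running sum of exp(logit - m) * value
    for s, v in zip(logits, values):
        m_new = max(m, s)
        alpha = math.exp(m - m_new)  # rescales old state; exp(-inf) is 0.0 on the first step
        p = math.exp(s - m_new)
        l = alpha * l + p
        o = alpha * o + p * v
        m = m_new
    return o / l


def two_pass_reference(logits, values):
    """Conventional softmax that needs every logit before it can start."""
    m = max(logits)
    weights = [math.exp(s - m) for s in logits]
    return sum(w * v for w, v in zip(weights, values)) / sum(weights)


logits = [2.0, -1.0, 1000.0, 3.5]  # exp(1000.0) would overflow without max subtraction
values = [1.0, 2.0, 3.0, 4.0]
print(online_softmax_weighted_sum(logits, values), two_pass_reference(logits, values))
```

Both functions agree even though the streaming version never holds more than one logit at a time, which is the property the tiled kernel relies on.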

The Update Rule

When processing a new K/V tile, suppose the new maximum is m_new = max(m_old, max(S_new)):

  • Rescale factor: α = exp(m_old - m_new) (≤ 1, rescales old contributions)
  • New softmax weights: p = exp(S_new - m_new)
  • Updated sum: l_new = α × l_old + sum(p)
  • Updated output: O_new = α × O_old + p @ V

After all tiles, O / l equals the exact softmax attention output. The α rescaling ensures previously accumulated values are consistently normalized.
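
The four updates above can be checked numerically without a GPU. Below is a minimal sketch for a single query row in plain Python, assuming small list-of-lists inputs; the function names and the toy data are made up for illustration, and the last tile is deliberately partial:

```python
import math


def attention_row_tiled(q, K, V, tile=2):
    """Attention output for one query vector q, visiting K/V `tile` rows at a time."""
    scale = 1.0 / math.sqrt(len(q))
    m, l = float("-inf"), 0.0
    o = [0.0] * len(V[0])
    for start in range(0, len(K), tile):
        K_t, V_t = K[start:start + tile], V[start:start + tile]
        s = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in K_t]  # S_new
        m_new = max(m, max(s))
        alpha = math.exp(m - m_new)             # rescale factor, always <= 1
        p = [math.exp(si - m_new) for si in s]  # unnormalized softmax weights
        l = alpha * l + sum(p)                  # l_new = alpha * l_old + sum(p)
        # O_new = alpha * O_old + p @ V
        o = [alpha * oj + sum(pi * v[j] for pi, v in zip(p, V_t))
             for j, oj in enumerate(o)]
        m = m_new
    return [oj / l for oj in o]


def attention_row_full(q, K, V):
    """Reference: materialize every score for this row, then normalize."""
    scale = 1.0 / math.sqrt(len(q))
    s = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in K]
    m = max(s)
    w = [math.exp(si - m) for si in s]
    total = sum(w)
    return [sum(wi * v[j] for wi, v in zip(w, V)) / total for j in range(len(V[0]))]


q = [1.0, 0.5]
K = [[0.2, -1.0], [1.5, 0.3], [-0.7, 0.8], [2.0, 2.0], [0.0, -0.4]]
V = [[1.0, 0.0], [0.0, 1.0], [2.0, -1.0], [0.5, 0.5], [-1.0, 3.0]]
print(attention_row_tiled(q, K, V, tile=2))
print(attention_row_full(q, K, V))
```

The result does not depend on the tile size, only on floating-point rounding order, so any tile width should reproduce the reference to within rounding error.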

Flash Attention Kernel

flash_attention.py
import torch
import triton
import triton.language as tl
import math

@triton.jit
def flash_attention_kernel(
    Q_ptr, K_ptr, V_ptr, O_ptr,
    seq_len, head_dim,
    stride_qn, stride_qd,
    stride_kn, stride_kd,
    stride_vn, stride_vd,
    stride_on, stride_od,
    scale,                    # 1 / sqrt(head_dim)
    BLOCK_Q: tl.constexpr,   # query tile size
    BLOCK_KV: tl.constexpr,  # key/value tile size
    HEAD_DIM: tl.constexpr,  # must equal head_dim (constexpr for tl.dot)
):
    """
    Flash Attention forward pass (single-head, float32).

    Key insight: instead of materializing the full (N x N) attention matrix,
    we iterate over K/V tiles and maintain running (max, sum_exp, output)
    statistics to compute the final softmax-weighted output in one pass.

    Memory: O(N) instead of O(N^2).
    """
    # One program per query tile
    q_tile = tl.program_id(0)
    q_off  = q_tile * BLOCK_Q + tl.arange(0, BLOCK_Q)  # [BLOCK_Q]
    d_off  = tl.arange(0, HEAD_DIM)                      # [HEAD_DIM]
    q_mask = q_off < seq_len

    # Load Q tile: [BLOCK_Q, HEAD_DIM]
    Q = tl.load(
        Q_ptr + q_off[:, None] * stride_qn + d_off[None, :] * stride_qd,
        mask=q_mask[:, None],
        other=0.0,
    )

    # Online softmax state (per query position)
    # m: running maximum of attention logits seen so far
    # l: running sum of exp(logit - m) seen so far
    # O: running weighted-sum of V
    m = tl.full([BLOCK_Q], float("-inf"), dtype=tl.float32)
    l = tl.zeros([BLOCK_Q], dtype=tl.float32)
    O = tl.zeros([BLOCK_Q, HEAD_DIM], dtype=tl.float32)

    # Iterate over all K/V tiles
    for kv_start in range(0, seq_len, BLOCK_KV):
        kv_off  = kv_start + tl.arange(0, BLOCK_KV)  # [BLOCK_KV]
        kv_mask = kv_off < seq_len

        # Load K tile: [BLOCK_KV, HEAD_DIM]
        K = tl.load(
            K_ptr + kv_off[:, None] * stride_kn + d_off[None, :] * stride_kd,
            mask=kv_mask[:, None],
            other=0.0,
        )

        # Load V tile: [BLOCK_KV, HEAD_DIM]
        V = tl.load(
            V_ptr + kv_off[:, None] * stride_vn + d_off[None, :] * stride_vd,
            mask=kv_mask[:, None],
            other=0.0,
        )

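        # Attention scores: S = Q @ K^T * scale  [BLOCK_Q, BLOCK_KV]
        # Note: for float32 inputs, tl.dot may use TF32 on Ampere and newer GPUs,
        # which lowers matmul precision relative to a full-FP32 reference.
        S = tl.dot(Q, tl.trans(K)) * scale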

        # Mask out-of-bounds key positions
        S = tl.where(kv_mask[None, :], S, float("-inf"))

        # --- Online softmax update ---
        # New running max
        m_new = tl.maximum(m, tl.max(S, axis=1))

        # Rescale factor for previously accumulated values
        alpha = tl.exp(m - m_new)  # [BLOCK_Q]; 1 when the max is unchanged, 0 on the first tile (m = -inf)

        # Softmax numerators for this tile: p = exp(S - m_new)
        p = tl.exp(S - m_new[:, None])  # [BLOCK_Q, BLOCK_KV]

        # Update running sum: l = alpha * l_old + sum(p_new)
        l = alpha * l + tl.sum(p, axis=1)

        # Update running output: O = alpha * O_old + p @ V
        O = O * alpha[:, None] + tl.dot(p, V)

        m = m_new

    # Final normalization: divide by the sum of all softmax weights
    O = O / l[:, None]

    # Store output tile
    tl.store(
        O_ptr + q_off[:, None] * stride_on + d_off[None, :] * stride_od,
        O,
        mask=q_mask[:, None],
    )


def flash_attention(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor) -> torch.Tensor:
    """
    Flash Attention: O(N) memory, runs on a real GPU via Triton.
    Q, K, V: (seq_len, head_dim) float32 tensors on CUDA.
    """
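    assert Q.is_cuda and K.is_cuda and V.is_cuda, "Q, K, V must be CUDA tensors"
    assert Q.dtype == K.dtype == V.dtype == torch.float32, "kernel is float32-only"
    # The kernel uses one seq_len and one head_dim for all three inputs
    assert Q.dim() == 2 and Q.shape == K.shape == V.shape, "expected matching (seq_len, head_dim) shapes"
    N, D = Q.shape
    assert D in (16, 32, 64), f"HEAD_DIM must be 16/32/64, got {D}"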

    O = torch.zeros(N, D, device=Q.device, dtype=torch.float32)
    scale = 1.0 / math.sqrt(D)

    BLOCK_Q, BLOCK_KV = 32, 32
    grid = (triton.cdiv(N, BLOCK_Q),)

    flash_attention_kernel[grid](
        Q, K, V, O,
        N, D,
        Q.stride(0), Q.stride(1),
        K.stride(0), K.stride(1),
        V.stride(0), V.stride(1),
        O.stride(0), O.stride(1),
        scale,
        BLOCK_Q=BLOCK_Q, BLOCK_KV=BLOCK_KV, HEAD_DIM=D,
    )
    return O


# --- Verify against standard attention ---
torch.manual_seed(0)
N, D = 256, 64   # short sequence, head_dim=64

Q = torch.randn(N, D, device="cuda", dtype=torch.float32)
K = torch.randn(N, D, device="cuda", dtype=torch.float32)
V = torch.randn(N, D, device="cuda", dtype=torch.float32)

# Reference: materialize full N x N attention matrix
scale = 1.0 / math.sqrt(D)
scores  = (Q @ K.T) * scale          # (N, N)
weights = torch.softmax(scores, dim=-1)  # (N, N)
ref_out = weights @ V                # (N, D)

# Flash attention (O(N) memory)
fa_out = flash_attention(Q, K, V)

max_err = (fa_out - ref_out).abs().max().item()
print(f"Max error vs standard attention: {max_err:.6f}")
print(f"Match (atol=1e-4): {torch.allclose(fa_out, ref_out, atol=1e-4)}")
print(f"FA output shape: {fa_out.shape}")
print(f"Memory saved vs standard: {N*N*4 / (N*D*4):.1f}x (N={N}, D={D})")

Memory Complexity

Standard attention stores the (N × N) attention matrix: O(N²) memory.

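Flash attention keeps only per-program state on chip: m[BLOCK_Q], l[BLOCK_Q], O[BLOCK_Q × D]. That is O(BLOCK_Q × D) per program and independent of N. The only tensors in global memory are Q, K, V, and the output, each O(N × D), so total memory is linear in sequence length. For N=4096, D=64, the output is 1 MB vs 64 MB for the full attention matrix.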
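
To make the comparison concrete, here is a small back-of-the-envelope sketch, assuming float32 everywhere and the BLOCK_Q = 32 used by the wrapper above. The helper name flash_state_bytes is hypothetical, and the count ignores the Q, K, and V tiles a program also holds while it runs:

```python
def flash_state_bytes(block_q: int, head_dim: int, bytes_per_element: int = 4) -> int:
    """Running state held by one program: the m and l vectors plus the O accumulator."""
    return (2 * block_q + block_q * head_dim) * bytes_per_element


N, D, BLOCK_Q = 4096, 64, 32
per_program = flash_state_bytes(BLOCK_Q, D)  # does not grow with N
output_bytes = N * D * 4                     # (N x D) output in global memory
full_matrix = N * N * 4                      # (N x N) scores in standard attention

print(f"per-program state: {per_program} bytes")
print(f"output tensor:     {output_bytes / 2**20:.0f} MiB")
print(f"full score matrix: {full_matrix / 2**20:.0f} MiB")
```

Doubling N doubles the output tensor and quadruples the score matrix, while the per-program state stays the same size.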

This is a pedagogical implementation

Production FlashAttention (FA2, FA3) adds multi-head batching, causal masking, dropout, BF16/FP16 mixed precision, and advanced tiling strategies for H100 Tensor Cores. The kernel above demonstrates the core algorithm correctly but isn't optimized for throughput. For production, use torch.nn.functional.scaled_dot_product_attention (which can dispatch to a FlashAttention kernel on supported NVIDIA GPUs and dtypes) or the flash-attn package.
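
For reference, a minimal sketch of calling PyTorch's built-in function, assuming PyTorch 2.0 or newer. It runs on CPU here purely so the comparison is easy to reproduce; which backend is actually selected depends on the device, dtype, and PyTorch version, so this shows the call shape rather than any particular kernel:

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, heads, N, D = 1, 2, 128, 64  # illustrative sizes, not from the chapter

# scaled_dot_product_attention expects (..., seq_len, head_dim) layouts
q = torch.randn(batch, heads, N, D)
k = torch.randn(batch, heads, N, D)
v = torch.randn(batch, heads, N, D)

# Default scaling is 1 / sqrt(head_dim), matching the formula in this chapter
out = F.scaled_dot_product_attention(q, k, v)

# Reference that materializes the (N x N) matrix per head
ref = torch.softmax((q @ k.transpose(-2, -1)) / math.sqrt(D), dim=-1) @ v

print(out.shape, torch.allclose(out, ref, atol=1e-4))
```

Passing is_causal=True to the same call applies a causal mask, one of the features the pedagogical kernel leaves out.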

Why This Matters for LLMs

Flash Attention is now standard across major open LLM stacks, including Llama, Mistral, and Gemma. Without it, training and inference at long context lengths (16K–128K tokens) would be impractical within current GPU memory limits. Understanding Flash Attention gives you insight into the engineering that makes modern AI systems possible.

From here, you can explore:

  • FlashAttention-2 — improved parallelism and work partitioning
  • FlashAttention-3 — Hopper-specific TMA and warp-specialization
  • PagedAttention (vLLM) — paged KV cache for efficient LLM serving
  • Triton's official FlashAttention tutorial for a production-quality implementation