Triton Flash Attention
Chapter 4 — O(N) memory attention with tiling and online softmax
The Problem with Standard Attention
Scaled dot-product attention computes: O = softmax(QK^T / √d) V
The bottleneck is the QK^T matrix: it is (N × N), where N is the sequence length. For a 4096-token context in float32, this is 4096 × 4096 × 4 bytes = 64 MB per attention head. For a 32K context, it's 4 GB. For 128K (a context length several modern LLMs support), it's 64 GB for a single head, which is most of the memory on an 80 GB GPU and more than many GPUs have at all.
Beyond the memory cost, standard attention reads and writes this matrix multiple times (once to compute, once for softmax, once for the weighted sum) — burning memory bandwidth even when the matrix fits.
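The sizes quoted above are easy to reproduce. The short sketch below is only an illustration: the helper names `attention_matrix_bytes` and `human_readable` are made up for this example, and "MB" and "GB" are taken to mean binary units (2**20 and 2**30 bytes), which is the convention the numbers in the text follow.

```python
def attention_matrix_bytes(seq_len: int, bytes_per_element: int = 4) -> int:
    """Bytes needed to store one (seq_len x seq_len) score matrix."""
    return seq_len * seq_len * bytes_per_element


def human_readable(num_bytes: float) -> str:
    """Format a byte count using binary units (1 MB = 2**20 bytes here)."""
    for unit in ("B", "KB", "MB", "GB", "TB"):
        if num_bytes < 1024 or unit == "TB":
            return f"{num_bytes:g} {unit}"
        num_bytes /= 1024


# Sequence lengths used in the text: 4096, 32K, and 128K tokens (float32).
for n in (4096, 32 * 1024, 128 * 1024):
    print(f"N={n:>6}: {human_readable(attention_matrix_bytes(n))}")
```

These figures cover one attention head for one sequence; a real model multiplies them by the number of heads and the batch size.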
The Flash Attention Insight: Tiling + Online Softmax
FlashAttention (Dao et al., 2022) avoids materializing the full attention matrix by computing attention in tiles and using an online softmax algorithm to merge partial results:
To compute softmax over a sequence we haven't fully seen yet, we maintain three running statistics:
- m: running maximum of all logits seen so far
- l: running sum of exp(logit - m) seen so far
- O: running weighted sum of V values seen so far
When we see a new tile of K/V pairs, we update all three in a numerically stable way. After processing all tiles, O / l gives the exact softmax-weighted output. No N×N matrix ever exists in memory.
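To make the bookkeeping concrete, here is a minimal sketch for a single query whose values are plain scalars, using only the standard library. The function names and the sample numbers are hypothetical, chosen for illustration; the update formulas it applies are the ones spelled out in the Update Rule section.

```python
import math


def streaming_softmax_average(logit_chunks, value_chunks):
    """Softmax-weighted average of scalar values, consumed chunk by chunk."""
    m = float("-inf")  # running maximum of the logits seen so far
    l = 0.0            # running sum of exp(logit - m)
    o = 0.0            # running sum of exp(logit - m) * value
    for logits, values in zip(logit_chunks, value_chunks):
        m_new = max(m, max(logits))
        # Re-express the old sums relative to the new maximum.
        alpha = math.exp(m - m_new)
        weights = [math.exp(s - m_new) for s in logits]
        l = alpha * l + sum(weights)
        o = alpha * o + sum(w * v for w, v in zip(weights, values))
        m = m_new
    return o / l


def full_softmax_average(logits, values):
    """Reference: the same quantity computed with all logits in hand."""
    m = max(logits)
    weights = [math.exp(s - m) for s in logits]
    return sum(w * v for w, v in zip(weights, values)) / sum(weights)


# Illustrative data, split into three chunks of two elements each.
logits = [0.5, 2.0, -1.0, 3.5, 0.0, 1.25]
values = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
chunked = streaming_softmax_average(
    [logits[0:2], logits[2:4], logits[4:6]],
    [values[0:2], values[2:4], values[4:6]],
)
reference = full_softmax_average(logits, values)
print(f"chunked={chunked:.12f} reference={reference:.12f}")
```

The streaming version never holds more than one chunk of logits at a time, yet it should agree with the reference up to floating-point rounding.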
The Update Rule
When processing a new K/V tile, suppose the new maximum is m_new = max(m_old, max(S_new)):
- Rescale factor: α = exp(m_old - m_new) (≤ 1, rescales old contributions)
- New softmax weights: p = exp(S_new - m_new)
- Updated sum: l_new = α × l_old + sum(p)
- Updated output: O_new = α × O_old + p @ V
After all tiles, O / l equals the exact softmax attention output. The α rescaling re-expresses every previously accumulated term relative to the new maximum m_new, so old and new contributions share the same reference point before they are added.
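The update rule can be checked on the CPU before any GPU code is involved. The following is a sketch in NumPy, not the Triton kernel itself: `tiled_attention`, `reference_attention`, the tile size of 4, and the random test shapes are all assumptions made for this illustration. Only the key/value axis is tiled, which is enough to exercise the rule.

```python
import numpy as np


def tiled_attention(Q, K, V, block_kv=4):
    """Attention computed one K/V tile at a time with the update rule above."""
    n, d = Q.shape
    scale = 1.0 / np.sqrt(d)
    m = np.full(n, -np.inf)          # running row maxima
    l = np.zeros(n)                  # running row sums of exp(S - m)
    O = np.zeros((n, V.shape[1]))    # running unnormalized outputs
    for start in range(0, K.shape[0], block_kv):
        K_tile = K[start:start + block_kv]
        V_tile = V[start:start + block_kv]
        S = (Q @ K_tile.T) * scale               # [n, tile] scores only
        m_new = np.maximum(m, S.max(axis=1))
        alpha = np.exp(m - m_new)                # rescale factor, <= 1
        p = np.exp(S - m_new[:, None])           # softmax numerators for this tile
        l = alpha * l + p.sum(axis=1)
        O = alpha[:, None] * O + p @ V_tile
        m = m_new
    return O / l[:, None]


def reference_attention(Q, K, V):
    """Standard attention that materializes the full score matrix."""
    S = (Q @ K.T) / np.sqrt(Q.shape[1])
    W = np.exp(S - S.max(axis=1, keepdims=True))
    return (W / W.sum(axis=1, keepdims=True)) @ V


# 10 keys with a tile size of 4 leaves a final partial tile of 2.
rng = np.random.default_rng(0)
Q = rng.standard_normal((10, 8))
K = rng.standard_normal((10, 8))
V = rng.standard_normal((10, 8))
tiled = tiled_attention(Q, K, V, block_kv=4)
reference = reference_attention(Q, K, V)
print("max abs difference:", np.abs(tiled - reference).max())
```

The largest score array this version ever builds is n × block_kv, and the difference from the reference should be at the level of floating-point rounding.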
Flash Attention Kernel
import math

import torch
import triton
import triton.language as tl


@triton.jit
def flash_attention_kernel(
    Q_ptr, K_ptr, V_ptr, O_ptr,
    seq_len, head_dim,
    stride_qn, stride_qd,
    stride_kn, stride_kd,
    stride_vn, stride_vd,
    stride_on, stride_od,
    scale,                    # 1 / sqrt(head_dim)
    BLOCK_Q: tl.constexpr,    # query tile size
    BLOCK_KV: tl.constexpr,   # key/value tile size
    HEAD_DIM: tl.constexpr,   # must equal head_dim (constexpr for tl.dot)
):
    """
    Flash Attention forward pass (single-head, float32).

    Key insight: instead of materializing the full (N x N) attention matrix,
    we iterate over K/V tiles and maintain running (max, sum_exp, output)
    statistics to compute the final softmax-weighted output in one pass.

    Memory: O(N) instead of O(N^2).
    """
    # One program per query tile
    q_tile = tl.program_id(0)
    q_off = q_tile * BLOCK_Q + tl.arange(0, BLOCK_Q)   # [BLOCK_Q]
    d_off = tl.arange(0, HEAD_DIM)                     # [HEAD_DIM]
    q_mask = q_off < seq_len

    # Load Q tile: [BLOCK_Q, HEAD_DIM]
    Q = tl.load(
        Q_ptr + q_off[:, None] * stride_qn + d_off[None, :] * stride_qd,
        mask=q_mask[:, None],
        other=0.0,
    )

    # Online softmax state (per query position)
    # m: running maximum of attention logits seen so far
    # l: running sum of exp(logit - m) seen so far
    # O: running weighted sum of V
    m = tl.full([BLOCK_Q], float("-inf"), dtype=tl.float32)
    l = tl.zeros([BLOCK_Q], dtype=tl.float32)
    O = tl.zeros([BLOCK_Q, HEAD_DIM], dtype=tl.float32)

    # Iterate over all K/V tiles
    for kv_start in range(0, seq_len, BLOCK_KV):
        kv_off = kv_start + tl.arange(0, BLOCK_KV)     # [BLOCK_KV]
        kv_mask = kv_off < seq_len

        # Load K tile: [BLOCK_KV, HEAD_DIM]
        K = tl.load(
            K_ptr + kv_off[:, None] * stride_kn + d_off[None, :] * stride_kd,
            mask=kv_mask[:, None],
            other=0.0,
        )
        # Load V tile: [BLOCK_KV, HEAD_DIM]
        V = tl.load(
            V_ptr + kv_off[:, None] * stride_vn + d_off[None, :] * stride_vd,
            mask=kv_mask[:, None],
            other=0.0,
        )

        # Attention scores: S = Q @ K^T * scale    [BLOCK_Q, BLOCK_KV]
        # Note: for float32 inputs, tl.dot may use TF32 on Ampere or newer
        # GPUs, which can push the error above the 1e-4 tolerance checked
        # below. If that happens, recent Triton versions accept
        # input_precision="ieee" (older ones: allow_tf32=False).
        S = tl.dot(Q, tl.trans(K)) * scale

        # Mask out-of-bounds key positions so they get zero softmax weight
        S = tl.where(kv_mask[None, :], S, float("-inf"))

        # --- Online softmax update ---
        # New running max
        m_new = tl.maximum(m, tl.max(S, axis=1))

        # Rescale factor for previously accumulated values
        # [BLOCK_Q], close to 1 when the max doesn't change much
        alpha = tl.exp(m - m_new)

        # Softmax numerators for this tile: p = exp(S - m_new)
        p = tl.exp(S - m_new[:, None])                 # [BLOCK_Q, BLOCK_KV]

        # Update running sum: l = alpha * l_old + sum(p)
        l = alpha * l + tl.sum(p, axis=1)

        # Update running output: O = alpha * O_old + p @ V
        O = O * alpha[:, None] + tl.dot(p, V)

        m = m_new

    # Final normalization: divide by the sum of all softmax weights
    O = O / l[:, None]

    # Store output tile
    tl.store(
        O_ptr + q_off[:, None] * stride_on + d_off[None, :] * stride_od,
        O,
        mask=q_mask[:, None],
    )


def flash_attention(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor) -> torch.Tensor:
    """
    Flash Attention: O(N) memory, runs on a real GPU via Triton.
    Q, K, V: (seq_len, head_dim) float32 tensors on CUDA.
    """
    # Validate all three inputs, not just Q
    for t in (Q, K, V):
        assert t.is_cuda and t.dtype == torch.float32
    assert Q.shape == K.shape == V.shape, "Q, K, V must share (seq_len, head_dim)"
    N, D = Q.shape
    assert D in (16, 32, 64), f"HEAD_DIM must be 16/32/64, got {D}"

    O = torch.zeros(N, D, device=Q.device, dtype=torch.float32)
    scale = 1.0 / math.sqrt(D)
    BLOCK_Q, BLOCK_KV = 32, 32

    # One program per tile of BLOCK_Q queries
    grid = (triton.cdiv(N, BLOCK_Q),)
    flash_attention_kernel[grid](
        Q, K, V, O,
        N, D,
        Q.stride(0), Q.stride(1),
        K.stride(0), K.stride(1),
        V.stride(0), V.stride(1),
        O.stride(0), O.stride(1),
        scale,
        BLOCK_Q=BLOCK_Q,
        BLOCK_KV=BLOCK_KV,
        HEAD_DIM=D,
    )
    return O


# --- Verify against standard attention ---
torch.manual_seed(0)
N, D = 256, 64    # short sequence, head_dim=64

Q = torch.randn(N, D, device="cuda", dtype=torch.float32)
K = torch.randn(N, D, device="cuda", dtype=torch.float32)
V = torch.randn(N, D, device="cuda", dtype=torch.float32)

# Reference: materialize full N x N attention matrix
scale = 1.0 / math.sqrt(D)
scores = (Q @ K.T) * scale                  # (N, N)
weights = torch.softmax(scores, dim=-1)     # (N, N)
ref_out = weights @ V                       # (N, D)

# Flash attention (O(N) memory)
fa_out = flash_attention(Q, K, V)

max_err = (fa_out - ref_out).abs().max().item()
print(f"Max error vs standard attention: {max_err:.6f}")
print(f"Match (atol=1e-4): {torch.allclose(fa_out, ref_out, atol=1e-4)}")
print(f"FA output shape: {fa_out.shape}")
print(f"Memory saved vs standard: {N*N*4 / (N*D*4):.1f}x (N={N}, D={D})")
Memory Complexity
Standard attention stores the (N × N) attention matrix: O(N²) memory.
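Flash attention keeps only per-program state: m[BLOCK_Q], l[BLOCK_Q], and O[BLOCK_Q × D], plus the Q, K, and V tiles currently being processed. Summed over all N / BLOCK_Q programs, this is O(N × D), linear in sequence length. For N=4096, D=64 in float32, that is 1 MB for the output versus 64 MB for the full attention matrix.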
Production FlashAttention (FA2, FA3) adds multi-head batching, causal masking, dropout, BF16/FP16 mixed precision, and advanced tiling strategies for H100 Tensor Cores. The kernel above demonstrates the core algorithm correctly but isn't optimized for throughput. For production, use torch.nn.functional.scaled_dot_product_attention (which can dispatch to a FlashAttention-2 kernel on supported NVIDIA GPUs and dtypes) or the flash-attn package.
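As a rough illustration of the PyTorch route, the sketch below calls torch.nn.functional.scaled_dot_product_attention on a small batched, multi-head input and compares it with a reference that materializes the weights. The tensor sizes are hypothetical, and the backend PyTorch picks depends on the device, dtype, and PyTorch version; float32 inputs like these may well be served by a backend other than FlashAttention, which typically expects FP16 or BF16.

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)
device = "cuda" if torch.cuda.is_available() else "cpu"

# Hypothetical sizes for illustration: (batch, heads, seq_len, head_dim).
B, H, N, D = 2, 4, 128, 64
q = torch.randn(B, H, N, D, device=device)
k = torch.randn(B, H, N, D, device=device)
v = torch.randn(B, H, N, D, device=device)

# PyTorch selects a backend (fused kernel or plain math) on its own.
# The default scale is 1 / sqrt(head_dim), matching the formula in the text.
out = F.scaled_dot_product_attention(q, k, v)

# Reference that materializes the (N x N) weights for every head.
weights = torch.softmax((q @ k.transpose(-2, -1)) / math.sqrt(D), dim=-1)
ref = weights @ v

max_err = (out - ref).abs().max().item()
print(f"device={device}, max abs difference: {max_err:.2e}")
```

The call takes the same Q, K, V inputs as the kernel in this chapter, with leading batch and head dimensions added.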
Why This Matters for LLMs
FlashAttention-style kernels are now standard in the training and inference stacks used for open LLMs such as Llama, Mistral, and Gemma. Without them, training and inference at long context lengths (16K–128K tokens) would be far more constrained by GPU memory and memory bandwidth. Understanding Flash Attention gives you insight into the engineering that makes modern AI systems possible.
From here, you can explore:
- FlashAttention-2 — improved parallelism and work partitioning
- FlashAttention-3 — Hopper-specific TMA and warp-specialization
- PagedAttention (vLLM) — paged KV cache for efficient LLM serving
- Triton's official FlashAttention tutorial for a production-quality implementation