Triton Fused Softmax
Chapter 3 — One kernel pass for numerically stable softmax with tl.max, tl.exp, and tl.sum
Why Fuse Softmax?
Written out by hand from elementary PyTorch ops, softmax over dim=1 executes as a chain of separate operations (torch.softmax itself already dispatches to a fused native kernel):
- Compute max of each row (reduction)
- Subtract max and exponentiate (element-wise)
- Sum the exponentials (reduction)
- Divide by the sum (element-wise)
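Each step launches a separate GPU kernel, and each kernel reads a full-size tensor from global memory. The element-wise steps also write a full-size intermediate back, while the reductions write only one value per row. For a (1024, 512) float32 matrix (2 MiB), counting the four steps above as four kernels, that is four full reads and two full writes, compared with one read and one write for a fused kernel. In eager PyTorch the subtraction and the exponential are themselves two kernels, so the real count is higher.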
A fused Triton kernel does all four steps in one pass, keeping the data in registers. This is faster because memory bandwidth — not arithmetic — is the bottleneck for softmax on modern GPUs.
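A back-of-the-envelope count makes the bandwidth argument concrete. This sketch assumes the four listed steps run as four kernels, float32 elements, and no reuse of data through the cache between kernels; the helper names are illustrative.

```python
# Global-memory traffic (bytes) for softmax over an (m, n) float32 matrix.
# Assumptions: the unfused version runs the four listed steps as four kernels,
# and nothing stays cached between kernels.

def unfused_bytes(m: int, n: int, elem_bytes: int = 4) -> int:
    full, vec = m * n, m          # full matrix vs. one value per row
    reads = (
        full                      # 1) row max: read x
        + full + vec              # 2) subtract + exp: read x and the row maxima
        + full                    # 3) row sum: read exp(x - max)
        + full + vec              # 4) divide: read exp(x - max) and the row sums
    )
    writes = vec + full + vec + full  # maxima, exp(x - max), sums, output
    return (reads + writes) * elem_bytes


def fused_bytes(m: int, n: int, elem_bytes: int = 4) -> int:
    # One read of the input, one write of the output; intermediates stay on-chip.
    return 2 * m * n * elem_bytes


m, n = 1024, 512
unfused, fused = unfused_bytes(m, n), fused_bytes(m, n)
print(f"unfused: {unfused / 2**20:.2f} MiB  fused: {fused / 2**20:.2f} MiB  "
      f"ratio: {unfused / fused:.2f}x")
```

Under these assumptions the fused kernel moves about a third of the bytes. For small tensors the L2 cache can absorb part of the unfused traffic, so measured speedups will vary.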
Most real-world Triton kernels are fusions — combining multiple elementwise or reduction ops that PyTorch would otherwise execute separately. FlashAttention is the extreme case: it fuses the entire scaled dot-product attention (QK^T, softmax, AV) into one kernel.
Numerical Stability: Why Subtract the Max
Naïve softmax computes exp(x_i) / sum(exp(x_j)) directly. When x_i is large (e.g. 100), exp(100) = 2.69e43 — well beyond float32's maximum of ~3.4e38, causing overflow.
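The fix is to subtract the row maximum m = max_j x_j first: exp(x_i - m) / sum(exp(x_j - m)). This is mathematically equivalent, because the common factor exp(-m) appears in both numerator and denominator and cancels. But every exponent x_i - m now lies in (-∞, 0], so exp returns values in [0, 1] and cannot overflow. The denominator is also at least 1, since the maximum element contributes exp(0) = 1, so very negative terms that underflow to 0 can never cause a division by zero.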
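Both formulas can be checked on the CPU in float32. In this sketch NumPy stands in for the GPU, and naive_softmax and stable_softmax are illustrative helpers, not part of the kernel:

```python
import numpy as np

def naive_softmax(x: np.ndarray) -> np.ndarray:
    # Direct formula exp(x_i) / sum_j exp(x_j): overflows for large inputs.
    e = np.exp(x)
    return e / e.sum()

def stable_softmax(x: np.ndarray) -> np.ndarray:
    # Subtract the max first: every exponent is <= 0, so exp() stays in [0, 1].
    e = np.exp(x - x.max())
    return e / e.sum()  # sum >= 1 because the max element contributes exp(0) = 1

x = np.array([100.0, 99.0, 98.0], dtype=np.float32)

with np.errstate(over="ignore", invalid="ignore"):
    # exp(98..100) exceeds float32's max, giving inf / inf = nan
    print(naive_softmax(x))
print(stable_softmax(x))

# For inputs that do not overflow, both formulas agree (the shift cancels).
small = np.array([1.0, 2.0, 3.0], dtype=np.float32)
print(np.allclose(naive_softmax(small), stable_softmax(small)))
```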
The Fused Softmax Kernel
```python
import torch
import triton
import triton.language as tl


@triton.jit
def softmax_kernel(
    output_ptr,
    input_ptr,
    input_row_stride,   # stride between rows of the input
    output_row_stride,  # stride between rows of the output
    n_cols,             # number of columns (vocab size or sequence length)
    BLOCK_SIZE: tl.constexpr,  # must be >= n_cols, rounded up to next power of 2
):
    """
    Fused softmax: one kernel pass instead of a chain of PyTorch ops.
    Each program instance handles one row of the input matrix.
    """
    # One program per row
    row_idx = tl.program_id(0)

    # Pointer to the start of this row
    row_start_ptr = input_ptr + row_idx * input_row_stride

    # Column offsets for loading the full row (assumes column stride 1)
    col_offsets = tl.arange(0, BLOCK_SIZE)
    input_ptrs = row_start_ptr + col_offsets

    # Load the row, masking columns beyond n_cols with -inf
    # (so they don't affect max or sum)
    mask = col_offsets < n_cols
    row = tl.load(input_ptrs, mask=mask, other=float("-inf"))

    # Step 1: subtract max for numerical stability
    # Without this, exp() of large values overflows to inf
    row_max = tl.max(row, axis=0)
    row = row - row_max

    # Step 2: exponentiate
    row_exp = tl.exp(row)

    # Step 3: sum of exponentials (normalization factor)
    row_sum = tl.sum(row_exp, axis=0)

    # Step 4: normalize
    softmax_output = row_exp / row_sum

    # Store result (only valid columns)
    output_row_start_ptr = output_ptr + row_idx * output_row_stride
    tl.store(output_row_start_ptr + col_offsets, softmax_output, mask=mask)


def fused_softmax(x: torch.Tensor) -> torch.Tensor:
    assert x.is_cuda and x.ndim == 2, "Expected 2D CUDA tensor"
    # The kernel walks each row with unit column stride; copy if that doesn't hold
    if x.stride(1) != 1:
        x = x.contiguous()
    n_rows, n_cols = x.shape

    # BLOCK_SIZE must be a power of 2 and >= n_cols
    BLOCK_SIZE = triton.next_power_of_2(n_cols)

    # Contiguous output, so its column stride is 1 as the kernel assumes
    out = torch.empty((n_rows, n_cols), device=x.device, dtype=x.dtype)

    # Launch one program per row
    softmax_kernel[(n_rows,)](
        out, x,
        x.stride(0), out.stride(0),
        n_cols,
        BLOCK_SIZE=BLOCK_SIZE,
    )
    return out


# --- Test and compare to PyTorch ---
torch.manual_seed(42)
x = torch.randn(1024, 512, device="cuda")

y_triton = fused_softmax(x)
y_torch = torch.softmax(x, dim=1)

max_err = (y_triton - y_torch).abs().max().item()
print(f"Max error vs torch.softmax: {max_err:.2e}")

# Verify rows sum to 1
row_sums = y_triton.sum(dim=1)
print(f"Rows sum to 1: {torch.allclose(row_sums, torch.ones_like(row_sums), atol=1e-5)}")
print(f"Input shape: {x.shape}  Output shape: {y_triton.shape}")
```
One Program Per Row
The grid is (n_rows,) — one program per row. This is simpler than GEMM because rows are independent: each program can read its row, compute softmax, and write the result without coordinating with other programs.
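Because the programs are independent, their execution order cannot change the result. A minimal CPU sketch makes this concrete under stated assumptions: flat NumPy arrays and integer offsets stand in for pointers, the column stride is 1, masking is left out because each slice is exactly n_cols long, and softmax_program and launch are hypothetical helpers, not Triton APIs. The input rows are padded to a row stride of 8 to mimic a sliced view such as big[:, :5], which is the situation input_row_stride exists to handle.

```python
import numpy as np

def softmax_program(row_idx, out_flat, in_flat, in_row_stride, out_row_stride, n_cols):
    # One program instance: it reads and writes only row `row_idx`.
    start = row_idx * in_row_stride                   # like input_ptr + row_idx * input_row_stride
    row = in_flat[start:start + n_cols]               # load (column stride 1)
    row = row - row.max()                             # step 1: subtract the row max
    num = np.exp(row)                                 # step 2: exponentiate
    out_start = row_idx * out_row_stride
    out_flat[out_start:out_start + n_cols] = num / num.sum()  # steps 3-4: sum, normalize, store

def launch(in_flat, n_rows, n_cols, in_row_stride, program_order):
    # Emulates the grid (n_rows,) with a contiguous output (row stride n_cols).
    out_flat = np.empty(n_rows * n_cols, dtype=in_flat.dtype)
    for pid in program_order:
        softmax_program(pid, out_flat, in_flat, in_row_stride, n_cols, n_cols)
    return out_flat.reshape(n_rows, n_cols)

rng = np.random.default_rng(0)
n_rows, n_cols, in_row_stride = 4, 5, 8               # each row padded to 8 elements
in_flat = rng.standard_normal(n_rows * in_row_stride).astype(np.float32)

forward = launch(in_flat, n_rows, n_cols, in_row_stride, range(n_rows))
backward = launch(in_flat, n_rows, n_cols, in_row_stride, reversed(range(n_rows)))
print(np.array_equal(forward, backward))              # same result in any program order
```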
BLOCK_SIZE = triton.next_power_of_2(n_cols) ensures the block is large enough to hold the entire row in one load. For n_cols=512, BLOCK_SIZE=512. For n_cols=1000, BLOCK_SIZE=1024.
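The padding argument can be checked directly. Below, next_power_of_2 is a pure-Python stand-in written for this sketch (assumed to match triton.next_power_of_2 for positive inputs), and a NumPy array pre-filled with -inf emulates the masked load with other=float("-inf"):

```python
import numpy as np

def next_power_of_2(n: int) -> int:
    # Smallest power of two >= n, for n >= 1 (stand-in for triton.next_power_of_2).
    return 1 << (n - 1).bit_length()

n_cols = 1000
block_size = next_power_of_2(n_cols)                  # 1024, so 24 lanes are masked off
row = np.random.default_rng(0).standard_normal(n_cols).astype(np.float32)

# Emulate tl.load(..., mask=mask, other=float("-inf")): padded lanes hold -inf.
padded = np.full(block_size, -np.inf, dtype=np.float32)
padded[:n_cols] = row

shifted = padded - padded.max()                       # -inf minus a finite max stays -inf
num = np.exp(shifted)                                 # exp(-inf) == 0.0 exactly
print(block_size, padded.max() == row.max(), float(num[n_cols:].sum()))
```

The padded lanes change neither the max nor the sum. Rounding up to a power of two is required because tl.arange only accepts power-of-two ranges; the cost is up to nearly 2x wasted lanes for an unlucky width such as n_cols=513.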
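Each program instance holds its BLOCK_SIZE-element row in registers, spread across the threads of the program's warps. Each SM has a fixed-size register file, so very large BLOCK_SIZE values (e.g. 32768 for vocabulary-size softmax in an LLM) raise register pressure: fewer programs fit on an SM at once, and if the row no longer fits in registers the compiler spills values to much slower local memory. In practice, vocabulary-size softmax is handled by processing each row in chunks, and launch parameters such as num_warps are tuned, for example with @triton.autotune.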
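One common chunked approach is the online softmax: it walks the row in fixed-size chunks, keeping a running maximum and a running sum of exponentials, and rescales the sum by exp(m_old - m_new) whenever the maximum grows; a second pass writes the normalized outputs. The NumPy sketch below illustrates the algorithm only, not a Triton kernel; in a Triton version each chunk would presumably be one BLOCK_SIZE-wide tl.load inside a loop. The function name online_softmax and the chunk sizes are illustrative choices. For scale, one 32768-element float32 row occupies 128 KiB.

```python
import numpy as np

def online_softmax(row: np.ndarray, chunk: int) -> np.ndarray:
    """Softmax of one long row, visiting at most `chunk` columns at a time."""
    # Pass 1: running max m and running sum s of exp(x - m).
    m, s = -np.inf, 0.0
    for start in range(0, row.size, chunk):
        block = row[start:start + chunk]
        m_new = max(m, float(block.max()))
        # Rescale the old sum to the new max, then add this chunk's contribution.
        s = s * np.exp(m - m_new) + float(np.exp(block - m_new).sum())
        m = m_new
    # Pass 2: normalize chunk by chunk with the final max and sum.
    out = np.empty_like(row)
    for start in range(0, row.size, chunk):
        block = row[start:start + chunk]
        out[start:start + chunk] = np.exp(block - m) / s
    return out

rng = np.random.default_rng(0)
logits = (10 * rng.standard_normal(32768)).astype(np.float32)   # vocab-size row

ref = np.exp(logits - logits.max())
ref /= ref.sum()
print(np.allclose(online_softmax(logits, chunk=4096), ref))
```

The rescaling step is what lets a single pass compute the normalizer without knowing the global maximum in advance; the same idea underlies FlashAttention's tiled softmax.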
tl.max, tl.sum, tl.exp
Triton's reduction primitives operate over a loaded tile in registers:
- tl.max(row, axis=0): scalar, the maximum value across all elements
- tl.sum(row_exp, axis=0): scalar, the sum of all elements
- tl.exp(row): element-wise, same shape as the input
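The same primitives also accept 2D tiles: reducing along axis=1 yields one value per row, which is how a kernel handling several rows per program could be written. This NumPy sketch only mirrors the shapes; np.max, np.sum, and np.exp stand in for tl.max, tl.sum, and tl.exp, and the (ROWS, BLOCK) multi-row tile is an assumption, not part of the kernel above.

```python
import numpy as np

ROWS, BLOCK = 4, 8                                   # hypothetical tile: 4 rows x 8 columns
tile = np.random.default_rng(0).standard_normal((ROWS, BLOCK)).astype(np.float32)

# 1D row, as in softmax_kernel: axis=0 reductions collapse to a scalar.
row = tile[0]
print(np.max(row, axis=0).shape, np.exp(row).shape)  # scalar shape vs. (8,)

# 2D tile: axis=1 reductions give one value per row; [:, None] broadcasts them back.
row_max = np.max(tile, axis=1)                       # shape (4,)
num = np.exp(tile - row_max[:, None])                # shape (4, 8), element-wise
probs = num / np.sum(num, axis=1)[:, None]           # shape (4, 8)
print(row_max.shape, probs.shape)
```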
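Because these operate on register-resident data rather than global memory, they add no global-memory traffic: the only global-memory accesses are the initial tl.load and the final tl.store. When a program spans several warps, the cross-warp part of a reduction goes through on-chip shared memory, which is still far faster than global memory.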