Triton Fused Softmax

Chapter 3 — One kernel pass for numerically stable softmax with tl.max, tl.exp, and tl.sum

Why Fuse Softmax?

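Composed from individual PyTorch tensor ops, the row-wise softmax that torch.softmax(x, dim=1) computes runs as a chain of separate operations (torch.softmax itself dispatches to a dedicated CUDA kernel, so this chain is what you get when you write the math out by hand):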

  1. Compute max of each row (reduction)
  2. Subtract max and exponentiate (element-wise)
  3. Sum the exponentials (reduction)
  4. Divide by the sum (element-wise)

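Each step launches its own GPU kernel (in eager mode, the subtraction and the exponentiation in step 2 are two separate kernels), and every kernel reads its input from global memory and writes its result back. The reductions read the whole tensor and write one value per row, while the element-wise steps read and write the whole tensor. For a (1024, 512) float32 matrix (2 MiB), that adds up to roughly four round trips through memory, about 16 MiB of traffic, where one read and one write (4 MiB) would suffice.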
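For reference, the chain can be sketched as an eager-mode PyTorch function. `naive_softmax` and `traffic_bytes` are illustrative names, not library APIs, and the byte counts are an idealized estimate that assumes every intermediate goes to global memory (real GPUs keep some of it in L2 cache).

```python
import torch


def naive_softmax(x: torch.Tensor) -> torch.Tensor:
    """Row-wise softmax built from separate ops; on a GPU each line is its own kernel."""
    row_max = torch.amax(x, dim=1, keepdim=True)  # step 1:  read M*N,     write M
    z = x - row_max                               # step 2a: read M*N + M, write M*N
    num = torch.exp(z)                            # step 2b: read M*N,     write M*N
    denom = num.sum(dim=1, keepdim=True)          # step 3:  read M*N,     write M
    return num / denom                            # step 4:  read M*N + M, write M*N


def traffic_bytes(m: int, n: int, elem_bytes: int = 4) -> tuple[int, int]:
    """Idealized global-memory traffic in bytes: (naive chain, fused kernel)."""
    naive_elems = (5 * m * n + 2 * m) + (3 * m * n + 2 * m)  # reads + writes listed above
    fused_elems = 2 * m * n                                  # read each row once, write once
    return naive_elems * elem_bytes, fused_elems * elem_bytes


naive_b, fused_b = traffic_bytes(1024, 512)
print(naive_b, fused_b, round(naive_b / fused_b, 2))  # 16793600 4194304 4.0
```

The roughly 4x ratio is the headroom fusion can recover for a bandwidth-bound op; measured speedups depend on the GPU and on caching.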

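A fused Triton kernel does all four steps in one pass: it reads each row from global memory once, keeps it in registers while computing the max, the exponentials, the sum, and the division, and writes the result once. This is faster because softmax performs only a handful of arithmetic operations per element, so on modern GPUs its speed is limited by memory bandwidth rather than by arithmetic throughput.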

Kernel fusion is the primary use case for Triton

Most real-world Triton kernels are fusions — combining multiple elementwise or reduction ops that PyTorch would otherwise execute separately. FlashAttention is the extreme case: it fuses the entire scaled dot-product attention (QK^T, softmax, AV) into one kernel.

Numerical Stability: Why Subtract the Max

Naïve softmax computes exp(x_i) / sum(exp(x_j)) directly. When x_i is large (e.g. 100), exp(100) = 2.69e43 — well beyond float32's maximum of ~3.4e38, causing overflow.

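The fix is to subtract the row maximum first: exp(x_i - max) / sum(exp(x_j - max)). This is mathematically equivalent, because the common factor exp(-max) cancels between numerator and denominator, but it keeps every shifted value in (-∞, 0], so exp returns values in (0, 1] and cannot overflow. The maximum element itself contributes exp(0) = 1, so the denominator is at least 1 and never underflows to zero.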
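A small CPU check in PyTorch makes the difference concrete. The inputs are arbitrary illustrative values, and `unstable_softmax` and `stable_softmax` are hypothetical helper names.

```python
import torch

x = torch.tensor([100.0, 101.0, 102.0])  # float32 by default


def unstable_softmax(v: torch.Tensor) -> torch.Tensor:
    e = torch.exp(v)    # exp(100) ~ 2.69e43 exceeds float32's max, so every entry is inf
    return e / e.sum()  # inf / inf -> nan


def stable_softmax(v: torch.Tensor) -> torch.Tensor:
    e = torch.exp(v - v.max())  # shifted values lie in (-inf, 0], so e lies in (0, 1]
    return e / e.sum()          # the max entry contributes exp(0) = 1, so the sum is >= 1


print(unstable_softmax(x))  # tensor([nan, nan, nan])
print(stable_softmax(x))
```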

The Fused Softmax Kernel

fused_softmax.py
import torch
import triton
import triton.language as tl

@triton.jit
def softmax_kernel(
    output_ptr, input_ptr,
    input_row_stride,   # stride between rows of the input
    output_row_stride,  # stride between rows of the output
    n_cols,             # number of columns (vocab size or sequence length)
    BLOCK_SIZE: tl.constexpr,  # must be >= n_cols, rounded up to next power of 2
):
    """
    Fused softmax: one kernel pass instead of four PyTorch ops.
    Each program instance handles one row of the input matrix.
    """
    # One program per row
    row_idx = tl.program_id(0)

    # Pointer to the start of this row
    row_start_ptr = input_ptr + row_idx * input_row_stride

    # Column offsets for loading the full row
    col_offsets = tl.arange(0, BLOCK_SIZE)
    input_ptrs = row_start_ptr + col_offsets

    # Load the row, masking columns beyond n_cols with -inf
    # (so they don't affect max or sum)
    mask = col_offsets < n_cols
    row = tl.load(input_ptrs, mask=mask, other=float("-inf"))

    # Step 1: subtract max for numerical stability
    # Without this, exp() of large values overflows to inf
    row_max = tl.max(row, axis=0)
    row = row - row_max

    # Step 2: exponentiate
    row_exp = tl.exp(row)

    # Step 3: sum of exponentials (normalization factor)
    row_sum = tl.sum(row_exp, axis=0)

    # Step 4: normalize
    softmax_output = row_exp / row_sum

    # Store result (only valid columns)
    output_row_start_ptr = output_ptr + row_idx * output_row_stride
    tl.store(output_row_start_ptr + col_offsets, softmax_output, mask=mask)


def fused_softmax(x: torch.Tensor) -> torch.Tensor:
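    assert x.is_cuda and x.ndim == 2, "Expected 2D CUDA tensor"
    # The kernel walks each row with unit column stride (row_start_ptr + col_offsets),
    # so force a row-major layout; this is a no-op if x is already contiguous.
    x = x.contiguous()
    n_rows, n_cols = x.shape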

    # BLOCK_SIZE must be a power of 2 and >= n_cols
    BLOCK_SIZE = triton.next_power_of_2(n_cols)

    out = torch.empty_like(x)
    # Launch one program per row
    softmax_kernel[(n_rows,)](
        out, x,
        x.stride(0), out.stride(0),
        n_cols,
        BLOCK_SIZE=BLOCK_SIZE,
    )
    return out


# --- Test and compare to PyTorch ---
torch.manual_seed(42)
x = torch.randn(1024, 512, device="cuda")

y_triton = fused_softmax(x)
y_torch  = torch.softmax(x, dim=1)

max_err = (y_triton - y_torch).abs().max().item()
print(f"Max error vs torch.softmax: {max_err:.2e}")

# Verify rows sum to 1
row_sums = y_triton.sum(dim=1)
print(f"Rows sum to 1: {torch.allclose(row_sums, torch.ones_like(row_sums), atol=1e-5)}")
print(f"Input shape: {x.shape}  Output shape: {y_triton.shape}")

One Program Per Row

The grid is (n_rows,) — one program per row. This is simpler than GEMM because rows are independent: each program can read its row, compute softmax, and write the result without coordinating with other programs.
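To make the per-row addressing concrete, here is a hypothetical CPU emulation of the launch grid; `emulate_softmax_grid` is an illustrative name, and a flat tensor stands in for global memory. The input is a column slice of a wider buffer, so its row stride (10) differs from its column count (6), which is why the kernel takes input_row_stride instead of assuming rows are packed back to back.

```python
import torch


def emulate_softmax_grid(buf: torch.Tensor, n_rows: int, n_cols: int,
                         row_stride: int) -> torch.Tensor:
    """CPU stand-in for softmax_kernel[(n_rows,)], with `buf` playing global memory."""
    flat_in = buf.reshape(-1)           # 1-D view of the input buffer
    out = torch.empty(n_rows, n_cols)   # contiguous output: out.stride(0) == n_cols
    flat_out = out.view(-1)
    for row_idx in range(n_rows):       # each iteration = one independent "program"
        start = row_idx * row_stride    # input_ptr + row_idx * input_row_stride
        row = flat_in[start:start + n_cols]
        e = torch.exp(row - row.max())
        out_start = row_idx * out.stride(0)
        flat_out[out_start:out_start + n_cols] = e / e.sum()
    return out


buf = torch.randn(4, 10)
x = buf[:, :6]  # shape (4, 6), but x.stride(0) == 10
y = emulate_softmax_grid(buf, x.shape[0], x.shape[1], x.stride(0))
print(x.stride(), torch.allclose(y, torch.softmax(x, dim=1), atol=1e-6))  # (10, 1) True
```

Because no iteration reads another iteration's output, the loop could run in any order, which is exactly what lets the GPU run the programs in parallel.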

BLOCK_SIZE = triton.next_power_of_2(n_cols) ensures the block is large enough to hold the entire row in one load. For n_cols=512, BLOCK_SIZE=512. For n_cols=1000, BLOCK_SIZE=1024.
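The padding logic can be checked on the CPU as well. This sketch pads a row out to the next power of two with -inf, the same trick the kernel's masked tl.load uses; `next_pow2` and `padded_softmax` are hypothetical helpers, with `next_pow2` intended to return the same value as triton.next_power_of_2 for positive inputs.

```python
import torch


def next_pow2(n: int) -> int:
    """Smallest power of two >= n, for n >= 1."""
    return 1 << (n - 1).bit_length()


def padded_softmax(row: torch.Tensor) -> torch.Tensor:
    n_cols = row.numel()
    block = next_pow2(n_cols)
    mask = torch.arange(block) < n_cols                   # like col_offsets < n_cols
    tile = torch.full((block,), float("-inf"), dtype=row.dtype)
    tile[mask] = row                                      # masked "load" with other=-inf
    e = torch.exp(tile - tile.max())                      # exp(-inf) == 0 in padded lanes
    return (e / e.sum())[mask]                            # masked "store"


print(next_pow2(512), next_pow2(1000))  # 512 1024
```

For n_cols = 1000, 24 of the 1024 lanes are padding; they never win the max and add exactly zero to the sum.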

BLOCK_SIZE limit

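Each program instance keeps its BLOCK_SIZE-element row in registers, spread across the program's threads. With Triton's default of 4 warps (128 threads), a 32768-element row (for example, softmax over a 32K-token LLM vocabulary) works out to 256 fp32 values per thread for each tile held live, more than the 255 registers a thread can address on NVIDIA GPUs, so values spill to slower local memory and occupancy drops. Raising num_warps spreads the row over more threads, and @triton.autotune can search over such configurations; for still larger rows, production kernels loop over the row in chunks instead of loading it all at once.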
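Chunked kernels are often built on the online softmax technique: one pass over the row keeps a running maximum and a running sum, rescaling the sum whenever the maximum grows, and a second pass normalizes. The plain-Python sketch below shows only the arithmetic; a Triton version would replace each chunk with a masked tl.load of BLOCK_SIZE elements inside a loop, at the cost of reading the row twice. `online_softmax` and its `chunk` parameter are illustrative, not a library API.

```python
import math


def online_softmax(row: list[float], chunk: int = 4) -> list[float]:
    """Softmax over `row`, reading it `chunk` elements at a time."""
    m = float("-inf")  # running maximum
    s = 0.0            # running sum of exp(x - m)
    for start in range(0, len(row), chunk):
        block = row[start:start + chunk]
        m_new = max(m, max(block))
        # Rescale the old sum to the new maximum, then add this chunk's terms.
        s = s * math.exp(m - m_new) + sum(math.exp(v - m_new) for v in block)
        m = m_new
    # Second pass: normalize with the final maximum and sum.
    return [math.exp(v - m) / s for v in row]


row = [0.5, 1000.0, 999.0, -3.0, 2.0, 998.5, 1.0]
probs = online_softmax(row, chunk=3)
print(round(sum(probs), 12))  # 1.0
```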

tl.max, tl.sum, tl.exp

These primitives operate on a tile already loaded into registers; tl.max and tl.sum reduce it, while tl.exp is element-wise:

  • tl.max(row, axis=0) — scalar: maximum value across all elements
  • tl.sum(row_exp, axis=0) — scalar: sum of all elements
  • tl.exp(row) — element-wise: same shape as input

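Because these operate on data already held in registers, they add no global-memory traffic beyond the initial load and final store. The reductions still have to combine values across the program's threads, which Triton compiles to warp shuffles and, when the tile spans several warps, a small amount of shared memory, but that is far cheaper than another pass over global memory.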
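Conceptually, a block-wide reduction combines values pairwise, so a power-of-two tile needs only log2(BLOCK_SIZE) combining rounds. The plain-Python sketch below illustrates that idea only; it is not Triton's actual lowering, which the compiler decides, and `tree_reduce` is a hypothetical helper.

```python
from typing import Callable


def tree_reduce(values: list[float], combine: Callable[[float, float], float]) -> float:
    """Pairwise reduction over a power-of-two-length list in log2(len) rounds."""
    n = len(values)
    assert n > 0 and n & (n - 1) == 0, "length must be a power of two"
    vals = list(values)
    while len(vals) > 1:
        half = len(vals) // 2
        # Lane i combines with lane i + half, halving the active lanes each round.
        vals = [combine(vals[i], vals[i + half]) for i in range(half)]
    return vals[0]


# 5 valid lanes plus 3 masked lanes filled with -inf, as in the kernel's load
tile = [3.0, 1.0, 4.0, 1.0, 5.0, float("-inf"), float("-inf"), float("-inf")]
print(tree_reduce(tile, max))  # 5.0
```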