Triton Matrix Multiplication

Chapter 2 — Tiled GEMM: 2D grids, tl.dot, and Tensor Core acceleration

Why Tiling Matters

Matrix multiplication is the core operation in neural networks: every linear layer, attention score, and projection is a GEMM (General Matrix Multiply). Naively multiplying an (M, K) matrix by a (K, N) matrix takes M × N × K multiply-add operations and re-reads the same rows of A and columns of B from slow GPU global memory many times.

The key insight behind fast GEMM is tiling: divide the output matrix into small tiles, load each tile's required input data into fast on-chip memory (registers/shared memory), and perform the computation there. This minimizes expensive memory traffic.
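To make the saving concrete, here is a rough back-of-the-envelope sketch in plain Python. It counts global-memory element loads under a simplified model that ignores caches and assumes each tile's inputs are loaded exactly once; the function names are illustrative and not part of Triton.

```python
def naive_loads(M, N, K):
    """Element loads if every output element re-reads its row of A and column of B."""
    return M * N * 2 * K


def tiled_loads(M, N, K, block_m, block_n):
    """Element loads if each (block_m x block_n) output tile reads its inputs once."""
    tiles_m = -(-M // block_m)  # ceiling division
    tiles_n = -(-N // block_n)
    # Each tile needs block_m rows of A and block_n columns of B, each of length K.
    per_tile = (block_m + block_n) * K
    return tiles_m * tiles_n * per_tile


M = N = K = 1024
naive = naive_loads(M, N, K)
tiled = tiled_loads(M, N, K, block_m=64, block_n=64)
print(f"naive: {naive:,} loads, tiled: {tiled:,} loads, ratio: {naive // tiled}x")
```

Under this model, 64 × 64 tiles cut global-memory traffic by a factor of 64 for a 1024 × 1024 × 1024 GEMM. Real hardware caches narrow the gap, but the trend is what motivates tiling.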

Tensor Cores via tl.dot

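tl.dot(a, b) maps to NVIDIA Tensor Cores, the specialized matrix-multiply hardware in modern GPUs, when the GPU has them and the input dtypes and tile shapes allow it. On a T4, this provides up to 65 TFLOPS for FP16, versus ~8 TFLOPS for FP32 on regular CUDA cores. Triton picks the hardware instruction from the input dtypes, so no explicit Tensor Core code is needed. Note that the TF32 path for float32 inputs requires an Ampere or newer GPU; on a T4, float32 inputs fall back to regular CUDA cores.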
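As a rough illustration of what those peak numbers mean, the sketch below computes the ideal (never reached in practice) time for a 1024 × 1024 × 1024 GEMM at each throughput. It assumes 2 floating-point operations per multiply-add and uses the peak figures quoted above; the helper name is hypothetical.

```python
def ideal_gemm_time_us(M, N, K, peak_tflops):
    """Lower bound on GEMM time in microseconds at a given peak throughput."""
    flops = 2 * M * N * K  # one multiply and one add per inner-product term
    return flops / (peak_tflops * 1e12) * 1e6


t_tensor = ideal_gemm_time_us(1024, 1024, 1024, peak_tflops=65.0)
t_cuda = ideal_gemm_time_us(1024, 1024, 1024, peak_tflops=8.0)
print(f"Tensor Cores: {t_tensor:.1f} us, CUDA cores: {t_cuda:.1f} us")
```

Measured kernels land well above these bounds because of memory traffic and launch overhead, but the roughly 8x ratio shows how much headroom Tensor Cores offer.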

2D Program Grid

In vector addition, the grid was 1D: one program per tile of elements. For matrix multiplication, the grid is 2D: each program computes one output tile of size (BLOCK_M × BLOCK_N).

  • pid_m = tl.program_id(axis=0) — which row-tile of C this program computes
  • pid_n = tl.program_id(axis=1) — which column-tile of C this program computes

To compute its (BLOCK_M × BLOCK_N) output tile, each program iterates over all ceil(K / BLOCK_K) tiles along the inner dimension, accumulating partial sums.
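The mapping from program IDs to output tiles can be sketched in plain Python, without a GPU. This is only an illustration of the indexing: the helper names are hypothetical, and the matrix sizes are chosen so that the last tiles in each dimension are partial.

```python
def cdiv(a, b):
    """Ceiling division, the same rounding triton.cdiv performs."""
    return (a + b - 1) // b


def tile_ranges(M, N, block_m, block_n):
    """Map each (pid_m, pid_n) to the half-open row and column ranges it owns."""
    grid = (cdiv(M, block_m), cdiv(N, block_n))
    mapping = {}
    for pid_m in range(grid[0]):
        for pid_n in range(grid[1]):
            rows = (pid_m * block_m, min((pid_m + 1) * block_m, M))
            cols = (pid_n * block_n, min((pid_n + 1) * block_n, N))
            mapping[(pid_m, pid_n)] = (rows, cols)
    return grid, mapping


grid, mapping = tile_ranges(M=100, N=130, block_m=64, block_n=64)
k_steps = cdiv(70, 32)  # K=70 with BLOCK_K=32 needs three K iterations

print("grid:", grid)
print("program (1, 2) covers rows", mapping[(1, 2)][0], "and columns", mapping[(1, 2)][1])
print("K iterations:", k_steps)
```

In the real kernel the partial tiles are not clipped with min(); every program works on full-size blocks and relies on masks to ignore the out-of-range rows and columns.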

The Tiled GEMM Kernel

matmul.py
import torch
import triton
import triton.language as tl

@triton.jit
def matmul_kernel(
    a_ptr, b_ptr, c_ptr,
    M, N, K,
    stride_am, stride_ak,   # A is (M, K)
    stride_bk, stride_bn,   # B is (K, N)
    stride_cm, stride_cn,   # C is (M, N)
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    """
    Tiled matrix multiplication: C = A @ B

    Each program instance computes a (BLOCK_M x BLOCK_N) tile of C.
    It iterates over K in steps of BLOCK_K, accumulating partial sums.
    """
    # 2D grid: pid_m indexes rows, pid_n indexes columns of the C matrix
    pid_m = tl.program_id(axis=0)
    pid_n = tl.program_id(axis=1)

    # Row/column offsets for this tile
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)  # shape: [BLOCK_M]
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)  # shape: [BLOCK_N]
    offs_k = tl.arange(0, BLOCK_K)                     # shape: [BLOCK_K]

    # Pointers to the first K-slice of A and B for this tile
    # a_ptrs shape: [BLOCK_M, BLOCK_K]  — rows of A for this tile
    # b_ptrs shape: [BLOCK_K, BLOCK_N]  — columns of B for this tile
    a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
    b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn

    # Accumulator for this output tile, initialized to zero
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

    # Iterate over the K dimension in tiles of BLOCK_K
    for k_start in range(0, tl.cdiv(K, BLOCK_K)):
        k_remaining = K - k_start * BLOCK_K

        # Load A tile [BLOCK_M, BLOCK_K] with masking
        a_mask = (offs_m[:, None] < M) & (offs_k[None, :] < k_remaining)
        a = tl.load(a_ptrs, mask=a_mask, other=0.0)

        # Load B tile [BLOCK_K, BLOCK_N] with masking
        b_mask = (offs_k[:, None] < k_remaining) & (offs_n[None, :] < N)
        b = tl.load(b_ptrs, mask=b_mask, other=0.0)

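        # Multiply the two tiles and accumulate in float32. tl.dot uses Tensor Cores
        # when the input dtype and GPU allow it; allow_tf32=True permits the TF32 path
        # for float32 inputs (newer Triton releases expose this as input_precision).
        acc += tl.dot(a, b, allow_tf32=True)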

        # Advance pointers to next K-slice
        a_ptrs += BLOCK_K * stride_ak
        b_ptrs += BLOCK_K * stride_bk

    # Store the computed tile to C
    c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
    c_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    tl.store(c_ptrs, acc, mask=c_mask)


def matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    assert a.is_cuda and b.is_cuda
    M, K = a.shape
    K2, N = b.shape
    assert K == K2, "Inner dimensions must match"

    # Allocate output (float32 for accuracy)
    c = torch.empty((M, N), device=a.device, dtype=torch.float32)

    BLOCK_M, BLOCK_N, BLOCK_K = 64, 64, 32

    # 2D grid: one program per (BLOCK_M x BLOCK_N) output tile
    grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N))

    matmul_kernel[grid](
        a, b, c,
        M, N, K,
        a.stride(0), a.stride(1),
        b.stride(0), b.stride(1),
        c.stride(0), c.stride(1),
        BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K,
    )
    return c


# --- Test & Benchmark ---
M, N, K = 1024, 1024, 1024
a = torch.randn(M, K, device="cuda", dtype=torch.float32)
b = torch.randn(K, N, device="cuda", dtype=torch.float32)

c_triton = matmul(a, b)
c_ref = torch.mm(a, b)

max_err = (c_triton - c_ref).abs().max().item()
print(f"Max error vs torch.mm: {max_err:.4f}")
print(f"Output shape: {c_triton.shape}")
print(f"Correct (atol=1e-2): {torch.allclose(c_triton, c_ref, atol=1e-2)}")

Understanding the K-Loop

The inner for k_start in range(0, tl.cdiv(K, BLOCK_K)) loop is the core of tiled GEMM:

  1. Load a (BLOCK_M × BLOCK_K) tile from matrix A
  2. Load a (BLOCK_K × BLOCK_N) tile from matrix B
  3. Compute their product with tl.dot and add to the accumulator
  4. Advance the pointers by BLOCK_K along the K dimension
  5. Repeat until all of K is consumed

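After the loop, the accumulator holds the complete output tile C[pid_m*BLOCK_M:(pid_m+1)*BLOCK_M, pid_n*BLOCK_N:(pid_n+1)*BLOCK_N], which is stored to global memory. Rows and columns of the tile that fall outside C are masked out on the store.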
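The loop can be emulated on the CPU with nested Python lists to see how the masking and accumulation fit together. This is a minimal sketch, not the Triton kernel itself: compute_tile and tiled_matmul are hypothetical names, the block sizes are deliberately tiny, and the conditional expressions stand in for tl.load with mask and other=0.0.

```python
def cdiv(a, b):
    return (a + b - 1) // b


def compute_tile(A, B, pid_m, pid_n, block_m, block_n, block_k):
    """Emulate one program instance and return its accumulator and offsets."""
    M, K, N = len(A), len(B), len(B[0])
    offs_m = [pid_m * block_m + i for i in range(block_m)]
    offs_n = [pid_n * block_n + j for j in range(block_n)]
    acc = [[0.0] * block_n for _ in range(block_m)]
    for k_step in range(cdiv(K, block_k)):
        offs_k = [k_step * block_k + k for k in range(block_k)]
        # Masked loads: out-of-bounds positions read as 0.0
        a = [[A[m][k] if m < M and k < K else 0.0 for k in offs_k] for m in offs_m]
        b = [[B[k][n] if k < K and n < N else 0.0 for n in offs_n] for k in offs_k]
        # Tile product added to the accumulator (the role tl.dot plays in the kernel)
        for i in range(block_m):
            for j in range(block_n):
                acc[i][j] += sum(a[i][k] * b[k][j] for k in range(block_k))
    return acc, offs_m, offs_n


def tiled_matmul(A, B, block_m=2, block_n=2, block_k=2):
    """Run every program in the 2D grid sequentially and assemble C."""
    M, N = len(A), len(B[0])
    C = [[0.0] * N for _ in range(M)]
    for pid_m in range(cdiv(M, block_m)):
        for pid_n in range(cdiv(N, block_n)):
            acc, offs_m, offs_n = compute_tile(A, B, pid_m, pid_n, block_m, block_n, block_k)
            # Masked store: only write positions that lie inside C
            for i, m in enumerate(offs_m):
                for j, n in enumerate(offs_n):
                    if m < M and n < N:
                        C[m][n] = acc[i][j]
    return C


A = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]  # shape (3, 3)
B = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]                  # shape (3, 2)
C = tiled_matmul(A, B)
print(C)
```

The 3 × 3 by 3 × 2 example is chosen so that neither M nor K is a multiple of the block size, which exercises both the load masks and the store mask.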

Block size constraints

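BLOCK_M, BLOCK_N, and BLOCK_K must all be powers of 2, because tl.arange only accepts power-of-two ranges. BLOCK_K must be at least 16 for tl.dot to use Tensor Cores, and many Triton versions require every tl.dot tile dimension to be at least 16. Common production values: BLOCK_M=128, BLOCK_N=256, BLOCK_K=64.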
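These constraints are easy to check on the host before launching a kernel. The helper below is a hypothetical sketch, not a Triton API; the default minimum of 16 for BLOCK_K simply mirrors the rule stated above.

```python
def is_power_of_two(n):
    """True for 1, 2, 4, 8, ... using the single-set-bit trick."""
    return n > 0 and (n & (n - 1)) == 0


def check_block_sizes(block_m, block_n, block_k, min_block_k=16):
    """Return a list of human-readable problems; an empty list means the sizes pass."""
    problems = []
    for name, value in (("BLOCK_M", block_m), ("BLOCK_N", block_n), ("BLOCK_K", block_k)):
        if not is_power_of_two(value):
            problems.append(f"{name}={value} is not a power of 2")
    if block_k < min_block_k:
        problems.append(f"BLOCK_K={block_k} is below {min_block_k}")
    return problems


ok = check_block_sizes(64, 64, 32)
bad = check_block_sizes(48, 64, 8)
print(ok)
print(bad)
```

Failing early with a clear message is usually friendlier than the compile-time error Triton raises when a block size is invalid.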

Strides: Memory Layout Awareness

We pass a.stride(0) and a.stride(1) to the kernel. These tell us how many elements to skip in memory to advance one row or one column:

  • For a C-contiguous (M, K) tensor: stride(0) = K, stride(1) = 1
  • For a Fortran-contiguous (column-major) (M, K) tensor: stride(0) = 1, stride(1) = M

Using .stride() instead of hardcoding makes the kernel work correctly with transposed tensors and non-contiguous memory views — crucial for real-world use.
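The address arithmetic behind strides can be shown on a flat Python list standing in for GPU memory. This is an illustrative sketch with a hypothetical element_at helper; it mirrors the offs_m * stride_am + offs_k * stride_ak expression in the kernel.

```python
def element_at(buf, strides, i, j):
    """Read element (i, j) of a 2D view over a flat buffer, given (row, col) strides."""
    return buf[i * strides[0] + j * strides[1]]


# A 2 x 3 matrix stored row-major: [[0, 1, 2], [3, 4, 5]]
buf = [0, 1, 2, 3, 4, 5]
row_major_strides = (3, 1)   # stride(0) = K = 3, stride(1) = 1

# The transpose is a 3 x 2 view over the SAME buffer with the strides swapped.
transposed_strides = (1, 3)

a_12 = element_at(buf, row_major_strides, 1, 2)    # A[1][2]
t_21 = element_at(buf, transposed_strides, 2, 1)   # A.T[2][1], the same element
print(a_12, t_21)
```

No data is copied to form the transposed view; only the strides change. That is why a kernel that takes strides as arguments handles a transposed input without any extra code.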