Triton Matrix Multiplication
Chapter 2 — Tiled GEMM: 2D grids, tl.dot, and Tensor Core acceleration
Why Tiling Matters
Matrix multiplication is the foundational operation in nearly every neural network: every linear layer, attention score, and projection is a GEMM (General Matrix Multiply). Naively multiplying an (M, K) matrix by a (K, N) matrix takes M × N × K multiply-adds and rereads the inputs from slow GPU global memory many times: each row of A is read once per output column, and each column of B once per output row.
The key insight behind fast GEMM is tiling: divide the output matrix into small tiles, load each tile's required input data into fast on-chip memory (registers/shared memory), and perform the computation there. This minimizes expensive memory traffic.
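To put rough numbers on the memory-traffic argument, here is a minimal sketch that counts element loads from global memory with and without tiling. It is an idealized model, not a measurement: it assumes no hardware caching, and the helper names `naive_loads` and `tiled_loads` are made up for this illustration.

```python
def naive_loads(M, N, K):
    # Each of the M*N outputs reads a full row of A (K elements)
    # and a full column of B (K elements) from global memory.
    return M * N * 2 * K


def tiled_loads(M, N, K, BLOCK_M, BLOCK_N):
    # Each output tile reads a (BLOCK_M x K) strip of A and a
    # (K x BLOCK_N) strip of B once, then reuses them on-chip.
    tiles_m = -(-M // BLOCK_M)  # ceiling division
    tiles_n = -(-N // BLOCK_N)
    return tiles_m * tiles_n * (BLOCK_M * K + K * BLOCK_N)


M = N = K = 1024
naive = naive_loads(M, N, K)
tiled = tiled_loads(M, N, K, BLOCK_M=64, BLOCK_N=64)
print(naive // tiled)  # 64
```

Under these assumptions, 64 × 64 output tiles cut global-memory loads by a factor of 64, because every loaded element is reused across a whole tile row or column.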
tl.dot(a, b) automatically uses NVIDIA Tensor Cores when available, the specialized matrix-multiply hardware in modern GPUs. On a T4, this provides up to 65 TFLOPS for FP16, versus ~8 TFLOPS for FP32 on the regular CUDA cores. Triton selects the Tensor Core path from the input dtypes, so the same kernel source serves different precisions.
2D Program Grid
In vector addition, the grid was 1D: one program per element tile. For matrix multiplication, the grid is 2D: each program computes one output tile of size (BLOCK_M × BLOCK_N).
- pid_m = tl.program_id(axis=0): which row-tile of C this program computes
- pid_n = tl.program_id(axis=1): which column-tile of C this program computes
To compute its (BLOCK_M × BLOCK_N) output tile, each program iterates over all ceil(K / BLOCK_K) tiles along the inner dimension, accumulating partial sums.
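The mapping from program IDs to output tiles can be checked on the CPU. The sketch below is plain Python, not Triton; `cdiv` and `tile_bounds` are hypothetical helpers written for this illustration, and the matrix sizes are picked so that the last tiles overhang the edges.

```python
def cdiv(a, b):
    # Ceiling division: number of size-b blocks needed to cover a
    return (a + b - 1) // b


def tile_bounds(pid_m, pid_n, M, N, BLOCK_M, BLOCK_N):
    # Half-open row and column ranges of C covered by program (pid_m, pid_n),
    # clipped at the matrix edge (the kernel handles this with masks).
    rows = (pid_m * BLOCK_M, min((pid_m + 1) * BLOCK_M, M))
    cols = (pid_n * BLOCK_N, min((pid_n + 1) * BLOCK_N, N))
    return rows, cols


M, N, K = 1000, 520, 300
BLOCK_M, BLOCK_N, BLOCK_K = 64, 64, 32

grid = (cdiv(M, BLOCK_M), cdiv(N, BLOCK_N))
k_steps = cdiv(K, BLOCK_K)
print(grid, k_steps)                                # (16, 9) 10
print(tile_bounds(15, 8, M, N, BLOCK_M, BLOCK_N))   # ((960, 1000), (512, 520))
```

The last program along each axis covers a partial tile (40 rows, 8 columns here), which is why every load and store in a tiled kernel needs a bounds mask.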
The Tiled GEMM Kernel
```python
import torch
import triton
import triton.language as tl


@triton.jit
def matmul_kernel(
    a_ptr, b_ptr, c_ptr,
    M, N, K,
    stride_am, stride_ak,  # A is (M, K)
    stride_bk, stride_bn,  # B is (K, N)
    stride_cm, stride_cn,  # C is (M, N)
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
    """
    Tiled matrix multiplication: C = A @ B

    Each program instance computes a (BLOCK_M x BLOCK_N) tile of C.
    It iterates over K in steps of BLOCK_K, accumulating partial sums.
    """
    # 2D grid: pid_m indexes rows, pid_n indexes columns of the C matrix
    pid_m = tl.program_id(axis=0)
    pid_n = tl.program_id(axis=1)

    # Row/column offsets for this tile
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)  # shape: [BLOCK_M]
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)  # shape: [BLOCK_N]
    offs_k = tl.arange(0, BLOCK_K)                    # shape: [BLOCK_K]

    # Pointers to the first K-slice of A and B for this tile
    # a_ptrs shape: [BLOCK_M, BLOCK_K], rows of A for this tile
    # b_ptrs shape: [BLOCK_K, BLOCK_N], columns of B for this tile
    a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
    b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn

    # Accumulator for this output tile, initialized to zero
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

    # Iterate over the K dimension in tiles of BLOCK_K
    for k_start in range(0, tl.cdiv(K, BLOCK_K)):
        k_remaining = K - k_start * BLOCK_K

        # Load A tile [BLOCK_M, BLOCK_K] with masking
        a_mask = (offs_m[:, None] < M) & (offs_k[None, :] < k_remaining)
        a = tl.load(a_ptrs, mask=a_mask, other=0.0)

        # Load B tile [BLOCK_K, BLOCK_N] with masking
        b_mask = (offs_k[:, None] < k_remaining) & (offs_n[None, :] < N)
        b = tl.load(b_ptrs, mask=b_mask, other=0.0)

        # Fused multiply-add using Tensor Cores (auto-selects fp16/tf32 path)
        acc += tl.dot(a, b, allow_tf32=True)

        # Advance pointers to next K-slice
        a_ptrs += BLOCK_K * stride_ak
        b_ptrs += BLOCK_K * stride_bk

    # Store the computed tile to C
    c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
    c_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    tl.store(c_ptrs, acc, mask=c_mask)


def matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    assert a.is_cuda and b.is_cuda
    M, K = a.shape
    K2, N = b.shape
    assert K == K2, "Inner dimensions must match"

    # Allocate output (float32 for accuracy)
    c = torch.empty((M, N), device=a.device, dtype=torch.float32)

    BLOCK_M, BLOCK_N, BLOCK_K = 64, 64, 32

    # 2D grid: one program per (BLOCK_M x BLOCK_N) output tile
    grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N))

    matmul_kernel[grid](
        a, b, c,
        M, N, K,
        a.stride(0), a.stride(1),
        b.stride(0), b.stride(1),
        c.stride(0), c.stride(1),
        BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K,
    )
    return c


# --- Test & Benchmark ---
M, N, K = 1024, 1024, 1024
a = torch.randn(M, K, device="cuda", dtype=torch.float32)
b = torch.randn(K, N, device="cuda", dtype=torch.float32)

c_triton = matmul(a, b)
c_ref = torch.mm(a, b)

max_err = (c_triton - c_ref).abs().max().item()
print(f"Max error vs torch.mm: {max_err:.4f}")
print(f"Output shape: {c_triton.shape}")
print(f"Correct (atol=1e-2): {torch.allclose(c_triton, c_ref, atol=1e-2)}")
```
Understanding the K-Loop
The inner for k_start in range(0, tl.cdiv(K, BLOCK_K)) loop is the core of tiled GEMM:
- Load a (BLOCK_M × BLOCK_K) tile from matrix A
- Load a (BLOCK_K × BLOCK_N) tile from matrix B
- Compute their product with tl.dot and add to the accumulator
- Advance the pointers by BLOCK_K along the K dimension
- Repeat until all of K is consumed
After the loop, the accumulator holds the complete C[pid_m*BM:(pid_m+1)*BM, pid_n*BN:(pid_n+1)*BN] tile, which is stored to global memory.
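The loop can be emulated on the CPU to see what a single program computes. The following is a pure-Python sketch under simplifying assumptions: matrices are nested lists, `tile_matmul` is a hypothetical stand-in for one program instance, and out-of-range reads return 0.0 to mimic `tl.load(..., mask=..., other=0.0)`. It says nothing about how Triton schedules the real work.

```python
def tile_matmul(A, B, pid_m, pid_n, BLOCK_M, BLOCK_N, BLOCK_K):
    """Compute one (BLOCK_M x BLOCK_N) tile of C = A @ B, one K-slice at a time."""
    M, K, N = len(A), len(B), len(B[0])
    acc = [[0.0] * BLOCK_N for _ in range(BLOCK_M)]
    for k_start in range(0, K, BLOCK_K):
        # "Load" the A and B tiles; masked-out entries read as 0.0
        a = [[A[pid_m * BLOCK_M + i][k_start + k]
              if pid_m * BLOCK_M + i < M and k_start + k < K else 0.0
              for k in range(BLOCK_K)] for i in range(BLOCK_M)]
        b = [[B[k_start + k][pid_n * BLOCK_N + j]
              if k_start + k < K and pid_n * BLOCK_N + j < N else 0.0
              for j in range(BLOCK_N)] for k in range(BLOCK_K)]
        # acc += a @ b, the step tl.dot performs in the kernel
        for i in range(BLOCK_M):
            for j in range(BLOCK_N):
                acc[i][j] += sum(a[i][k] * b[k][j] for k in range(BLOCK_K))
    return acc


A = [[float(i + k) for k in range(5)] for i in range(3)]  # 3 x 5
B = [[float(k * j) for j in range(4)] for k in range(5)]  # 5 x 4
ref = [[sum(A[i][k] * B[k][j] for k in range(5)) for j in range(4)]
       for i in range(3)]

# Program (1, 1) with 2x2 tiles covers rows 2..3 and columns 2..3 of C;
# row 3 does not exist, so the second tile row stays zero.
tile = tile_matmul(A, B, pid_m=1, pid_n=1, BLOCK_M=2, BLOCK_N=2, BLOCK_K=2)
print(tile[0] == ref[2][2:4])  # True
```

Because K = 5 is not a multiple of BLOCK_K = 2, the last K-slice is half empty; the zero fill keeps the padded products from contributing to the sum.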
BLOCK_M, BLOCK_N, and BLOCK_K must all be powers of 2, because tl.arange only accepts power-of-two lengths. BLOCK_K must be at least 16 for tl.dot to use Tensor Cores. Common production values: BLOCK_M=128, BLOCK_N=256, BLOCK_K=64.
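These constraints are easy to check on the host before launching a kernel. The helper below is a hypothetical convenience, not part of Triton; the `min_k=16` default simply encodes the rule of thumb stated above and may need adjusting for a given Triton version or GPU.

```python
def is_power_of_two(n):
    # A positive integer is a power of two iff it has exactly one set bit
    return n > 0 and (n & (n - 1)) == 0


def check_block_sizes(BLOCK_M, BLOCK_N, BLOCK_K, min_k=16):
    """Hypothetical helper: return a list of problems, empty if none."""
    problems = []
    for name, value in (("BLOCK_M", BLOCK_M),
                        ("BLOCK_N", BLOCK_N),
                        ("BLOCK_K", BLOCK_K)):
        if not is_power_of_two(value):
            problems.append(f"{name}={value} is not a power of 2")
    if BLOCK_K < min_k:
        problems.append(f"BLOCK_K={BLOCK_K} is below {min_k}")
    return problems


print(check_block_sizes(128, 256, 64))  # []
print(check_block_sizes(64, 48, 8))
```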
Strides: Memory Layout Awareness
We pass a.stride(0) and a.stride(1) to the kernel. These tell us how many elements to skip in memory to advance one row or one column:
- For a C-contiguous (row-major) (M, K) tensor: stride(0) = K, stride(1) = 1
- For a Fortran-contiguous (column-major) (M, K) tensor: stride(0) = 1, stride(1) = M
Using .stride() instead of hardcoding makes the kernel work correctly with transposed tensors and non-contiguous memory views — crucial for real-world use.
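As an illustration of the address arithmetic, the sketch below indexes a flat Python list through explicit strides, the same `i * stride_m + j * stride_k` computation the kernel applies to its pointers. The helpers `contiguous_strides` and `offset` are invented for this example; in PyTorch the strides come from `tensor.stride()`.

```python
def contiguous_strides(shape):
    # Row-major (C-contiguous) strides, in elements, for a 2D shape
    _, cols = shape
    return (cols, 1)


def offset(i, j, strides):
    # Flat index of element [i][j]: i * stride(0) + j * stride(1)
    return i * strides[0] + j * strides[1]


M, K = 3, 4
storage = list(range(M * K))            # flat buffer holding a (3, 4) matrix
strides = contiguous_strides((M, K))    # (4, 1)
print(storage[offset(1, 2, strides)])   # 6

# A transposed view reuses the same buffer with swapped strides:
# its shape is (4, 3) and no data is copied.
t_strides = (strides[1], strides[0])    # (1, 4)
print(storage[offset(2, 1, t_strides)])  # 6, same element as original [1][2]
```

Because the kernel only ever sees strides, passing a transposed view costs nothing extra: the index math changes and the data stays where it is.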