Triton Introduction: Vector Addition

Chapter 1 — Write your first Triton kernel with tl.load, tl.store, and program_id

The Triton Programming Model

Before writing code, internalize this mental shift: in CUDA C++, you write code for a single thread. In Triton, you write code for a single program — and each program handles a tile of data, not a single element.

When you launch a Triton kernel with a grid of N programs, Triton runs N instances of your kernel function, in parallel and in no guaranteed order. Each instance calls tl.program_id() to find out which tile it owns, loads that tile, computes on it, and stores the results back.
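To make "one program per tile" concrete without a GPU, here is a plain-Python emulation of a launch. It is a sketch, not Triton code: `program` and `launch` are hypothetical names chosen for illustration, `pid` stands in for tl.program_id(axis=0), and `cdiv` mirrors the arithmetic of `triton.cdiv`.

```python
# CPU-only emulation of Triton's launch model in plain Python (not Triton code).
# Each call to `program` stands in for one kernel instance; `pid` plays the
# role of tl.program_id(axis=0). `program` and `launch` are hypothetical names.

def cdiv(a: int, b: int) -> int:
    """Ceiling division: the same arithmetic triton.cdiv performs."""
    return (a + b - 1) // b


def program(pid, x, y, out, n, block_size):
    """One program instance: owns the tile starting at pid * block_size."""
    block_start = pid * block_size
    for lane in range(block_size):      # Triton processes the whole tile at once
        idx = block_start + lane
        if idx < n:                     # the mask: skip lanes past the end
            out[idx] = x[idx] + y[idx]


def launch(x, y, block_size):
    n = len(x)
    out = [0.0] * n
    # On a GPU the instances run concurrently in no fixed order;
    # here they simply run one after another.
    for pid in range(cdiv(n, block_size)):
        program(pid, x, y, out, n, block_size)
    return out


result = launch([float(i) for i in range(10)], [1.0] * 10, block_size=4)
print(result)  # 3 programs cover tiles [0..3], [4..7], [8..9]
```

The key difference from real Triton is the inner `for lane` loop: in a Triton kernel you never write it, because `x + y` already operates on the whole tile.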

Thread count is implicit

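You never write per-thread code in Triton, and you don't size a block in threads the way you do in CUDA. Each program instance runs on num_warps warps (a launch option that defaults to 4, i.e. 128 threads on NVIDIA GPUs), and the compiler decides how the BLOCK_SIZE elements of the tile are distributed across those threads. This is one of the key simplifications over CUDA.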
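As a rough picture of what happens under the hood, the arithmetic below assumes 32-thread NVIDIA warps and Triton's default launch option num_warps=4. The even split of a tile across threads is a simplifying assumption: the actual mapping is chosen by Triton's layout pass, although an even split is the typical result for a 1-D tile like the one in this chapter.

```python
# Back-of-the-envelope arithmetic for how a tile maps onto threads.
# Assumptions: NVIDIA warps of 32 threads, Triton's default num_warps=4,
# and an even split of the tile across threads (the compiler decides the
# real layout).

WARP_SIZE = 32


def threads_per_program(num_warps: int = 4) -> int:
    return num_warps * WARP_SIZE


def elements_per_thread(block_size: int, num_warps: int = 4) -> int:
    threads = threads_per_program(num_warps)
    assert block_size % threads == 0, "tile does not split evenly across threads"
    return block_size // threads


print(threads_per_program())          # 128
print(elements_per_thread(1024))      # 8
print(elements_per_thread(1024, 8))   # 4
```

If you do want a different warp count, it is passed as a launch keyword rather than inside the kernel, e.g. `add_kernel[grid](x, y, out, n, BLOCK_SIZE=1024, num_warps=8)`.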

Four Core Primitives

Nearly every Triton kernel is built on these four operations:

  1. tl.program_id(axis=0) — "which tile am I?" Returns an integer identifying this program instance. Like blockIdx.x in CUDA.
  2. tl.arange(0, BLOCK_SIZE) — creates a vector [0, 1, 2, ..., BLOCK_SIZE-1]. Add program_id * BLOCK_SIZE to get global indices for this tile.
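  3. tl.load(ptr + offsets, mask=mask) — loads a tile of memory. The mask prevents out-of-bounds access when the array size isn't a multiple of BLOCK_SIZE. Masked-off lanes read nothing from memory, and their values are undefined unless you supply a fill value, e.g. tl.load(ptr + offsets, mask=mask, other=0.0).
  4. tl.store(ptr + offsets, value, mask=mask) — writes a tile back. Lanes where the mask is False are not written.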
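The four primitives can be mimicked on Python lists to show exactly what each lane does for a partial tile. This is an illustrative sketch: `arange`, `masked_load`, and `masked_store` are hypothetical helpers, not Triton functions, and the `other` argument models the fill value of tl.load's `other=` parameter.

```python
# Plain-Python stand-ins for the four primitives, applied to one tile.
# arange / masked_load / masked_store are hypothetical helpers that model
# the semantics; they are not part of Triton.

def arange(start: int, end: int) -> list:
    return list(range(start, end))


def masked_load(buf: list, offsets: list, mask: list, other=0.0) -> list:
    # Only lanes whose mask is True touch memory; the rest get `other`.
    return [buf[o] if m else other for o, m in zip(offsets, mask)]


def masked_store(buf: list, offsets: list, values: list, mask: list) -> None:
    # Lanes whose mask is False are skipped entirely.
    for o, v, m in zip(offsets, values, mask):
        if m:
            buf[o] = v


BLOCK_SIZE = 4
x = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0]  # 6 elements: the last tile is partial
out = [0.0] * len(x)
n_elements = len(x)

pid = 1  # pretend tl.program_id(axis=0) returned 1
offsets = [pid * BLOCK_SIZE + i for i in arange(0, BLOCK_SIZE)]  # [4, 5, 6, 7]
mask = [o < n_elements for o in offsets]          # [True, True, False, False]
tile = masked_load(x, offsets, mask)              # [50.0, 60.0, 0.0, 0.0]
masked_store(out, offsets, [v * 2 for v in tile], mask)
print(out)  # [0.0, 0.0, 0.0, 0.0, 100.0, 120.0]
```

Offsets 6 and 7 would index past the end of `x`; the mask is what keeps both the load and the store from touching them.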

Vector Addition Kernel

The canonical first Triton kernel: add two vectors element-wise. Each program instance loads a tile of 1024 elements from both inputs, adds them, and stores the result.

vector_add.py
import torch
import triton
import triton.language as tl

@triton.jit
def add_kernel(
    x_ptr,        # pointer to first input vector
    y_ptr,        # pointer to second input vector
    out_ptr,      # pointer to output vector
    n_elements,   # total number of elements
    BLOCK_SIZE: tl.constexpr,  # elements per program instance (compile-time constant)
):
    # Each program instance handles a contiguous block of BLOCK_SIZE elements.
    # program_id(0) tells us which block this instance is responsible for.
    pid = tl.program_id(axis=0)

    # Compute the starting index of this block
    block_start = pid * BLOCK_SIZE

    # Generate the range of indices this program handles: [block_start, block_start+BLOCK_SIZE)
    offsets = block_start + tl.arange(0, BLOCK_SIZE)

    # Guard: don't read/write past the end of the array
    mask = offsets < n_elements

    # Load tiles from both input vectors
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)

    # Element-wise addition (operates on the entire tile at once)
    output = x + y

    # Write result back to GPU memory
    tl.store(out_ptr + offsets, output, mask=mask)


def vector_add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
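    assert x.is_cuda and y.is_cuda, "Inputs must be on GPU"
    assert x.shape == y.shape, "Inputs must have the same shape"
    # The kernel indexes memory as a flat 1-D array, so inputs must be contiguous.
    assert x.is_contiguous() and y.is_contiguous(), "Inputs must be contiguous"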

    out = torch.empty_like(x)
    n = x.numel()

    BLOCK_SIZE = 1024  # each program handles 1024 elements

    # Grid: number of program instances = ceil(n / BLOCK_SIZE)
    grid = lambda meta: (triton.cdiv(n, meta["BLOCK_SIZE"]),)

    add_kernel[grid](x, y, out, n, BLOCK_SIZE=BLOCK_SIZE)
    return out


# --- Test ---
torch.manual_seed(0)
x = torch.rand(1 << 20, device="cuda")  # 1M floats
y = torch.rand(1 << 20, device="cuda")

out = vector_add(x, y)
ref = x + y

max_err = (out - ref).abs().max().item()
print(f"Max error vs torch: {max_err:.2e}")
print(f"Match: {torch.allclose(out, ref)}")
print(f"Triton version: {triton.__version__}")

How the Grid Works

The grid lambda receives the kernel's meta-parameters (here, meta["BLOCK_SIZE"]) and returns a tuple giving how many program instances to launch. For a vector of n elements with BLOCK_SIZE=1024:

  • triton.cdiv(n, 1024) = ceil(n / 1024) = number of tiles needed.
  • Program 0 handles elements [0..1023].
  • Program 1 handles elements [1024..2047].
  • ... and so on in parallel.

The mask offsets < n_elements ensures the last tile doesn't read/write memory beyond the array, since the last tile may be partially filled.
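A worked example makes the partial last tile visible. The size n = 2500 below is an arbitrary illustrative choice; the sketch also checks the 1 << 20 elements used in the test script, which split into exactly 1024 full tiles, so there the mask never turns any lane off. `tile_ranges` is a hypothetical helper, and `cdiv` reproduces the result of `triton.cdiv` in plain Python.

```python
# Grid arithmetic for a size that is not a multiple of BLOCK_SIZE.
# tile_ranges is a hypothetical helper; cdiv matches triton.cdiv's result.

def cdiv(a: int, b: int) -> int:
    return (a + b - 1) // b


def tile_ranges(n: int, block_size: int) -> list:
    """(first index, last in-bounds index, valid lanes) for each program."""
    ranges = []
    for pid in range(cdiv(n, block_size)):
        first = pid * block_size
        last = min(first + block_size, n) - 1
        ranges.append((first, last, last - first + 1))
    return ranges


n, BLOCK_SIZE = 2500, 1024
r = tile_ranges(n, BLOCK_SIZE)
print(cdiv(n, BLOCK_SIZE))         # 3
print(r)                           # [(0, 1023, 1024), (1024, 2047, 1024), (2048, 2499, 452)]
print(3 * BLOCK_SIZE - n)          # 572 lanes masked off in the last tile
print(cdiv(1 << 20, BLOCK_SIZE))   # 1024 programs, every tile full
```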

Run it now

Run the script on an NVIDIA GPU such as a T4. You should see Max error vs torch: 0.00e+00 and Match: True, because Triton and PyTorch perform the same float32 addition on each element.

The @triton.jit Decorator

@triton.jit marks a function as a GPU kernel. On NVIDIA GPUs, Triton compiles it to PTX (NVIDIA's intermediate assembly) and then to a GPU binary the first time it is called with a new combination of argument types and tl.constexpr values. The compiled kernel is cached, so subsequent calls with the same input dtypes and the same BLOCK_SIZE skip compilation entirely.

Parameters annotated as tl.constexpr are baked into the compiled kernel. This allows Triton's compiler to unroll loops and optimize memory access patterns that depend on the tile size.
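One concrete consequence of tl.constexpr: tl.arange needs compile-time bounds, and Triton additionally requires its range to be a power of two. Host code therefore often rounds a problem-dependent width up to the next power of two; Triton provides `triton.next_power_of_2` for this. The sketch below is a plain-Python equivalent so it runs without Triton installed.

```python
# Round a width up to a power of two, as needed for tl.arange(0, BLOCK_SIZE).
# Plain-Python equivalent of triton.next_power_of_2, for n >= 1.

def next_power_of_2(n: int) -> int:
    """Smallest power of two >= n."""
    return 1 << (n - 1).bit_length()


def is_power_of_2(n: int) -> bool:
    return n > 0 and (n & (n - 1)) == 0


for width in (1, 5, 1000, 1024, 1025):
    block = next_power_of_2(width)
    assert is_power_of_2(block) and block >= width
    print(width, "->", block)
# 1 -> 1, 5 -> 8, 1000 -> 1024, 1024 -> 1024, 1025 -> 2048
```

Rounding up wastes some lanes (e.g. 1025 elements get a 2048-wide tile), which the mask then switches off, the same mechanism that handles the partial last tile in the vector addition kernel.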