Read the code
The next four minutes are not new content.
M18 · the credits roll
Every line below shipped in the engine you just trained.
Six files. About four hundred lines of WGSL and TypeScript. Imported into
this page via Vite's ?raw loader, so what you see is what ran.
The model above is sampling from the same reference checkpoint right now.
cyrb128 hashes a string into four 32-bit seeds; sfc32 turns those seeds into a deterministic stream of floats. This is the entire reason the .bin you saved is byte-identical across runs.
// sfc32 PRNG seeded by cyrb128. Adapted from
// https://github.com/bryc/code/blob/master/jshash/PRNGs.md (public domain).
// Slice 3 uses this for weight init + train batch indices so a convergence
// run is reproducible within the slice. Slice 4 ships the full determinism
// surface (dropout masks, sampler, twin-seed test).
/**
 * Hashes a string into four unsigned 32-bit words.
 *
 * Each UTF-16 code unit of `str` is folded into four chained 32-bit lanes
 * using `Math.imul`; a final avalanche step then mixes the lanes into each
 * other before they are combined into the returned words.
 *
 * @param str - Text to hash; the empty string is valid.
 * @returns A 4-tuple of integers in `[0, 2^32)`.
 */
export function cyrb128(str: string): [number, number, number, number] {
  const MUL_A = 597399067;
  const MUL_B = 2869860233;
  const MUL_C = 951274213;
  const MUL_D = 2716044179;

  let a = 1779033703;
  let b = 3144134277;
  let c = 1013904242;
  let d = 2773480762;

  // Walk UTF-16 code units (not code points). Statement order matters: the
  // last lane is updated from the already-updated first lane.
  for (let pos = 0; pos < str.length; pos += 1) {
    const unit = str.charCodeAt(pos);
    a = b ^ Math.imul(a ^ unit, MUL_A);
    b = c ^ Math.imul(b ^ unit, MUL_B);
    c = d ^ Math.imul(c ^ unit, MUL_C);
    d = a ^ Math.imul(d ^ unit, MUL_D);
  }

  // Final avalanche; the third and fourth lines read the freshly written a and b.
  a = Math.imul(c ^ (a >>> 18), MUL_A);
  b = Math.imul(d ^ (b >>> 22), MUL_B);
  c = Math.imul(a ^ (c >>> 17), MUL_C);
  d = Math.imul(b ^ (d >>> 19), MUL_D);

  // `>>> 0` reinterprets the signed 32-bit results as unsigned.
  const toU32 = (word: number): number => word >>> 0;
  return [toU32(a ^ b ^ c ^ d), toU32(b ^ a), toU32(c ^ a), toU32(d ^ a)];
}
/**
 * Builds an sfc32 pseudo-random generator from four 32-bit seed words.
 *
 * The returned closure owns its state: every call advances it and yields the
 * next float in `[0, 1)`. Equal seeds give equal sequences.
 *
 * @param a - First state word (coerced to a 32-bit integer).
 * @param b - Second state word.
 * @param c - Third state word.
 * @param d - Counter word, incremented by one on every call.
 * @returns A zero-argument function producing the next sample.
 */
export function sfc32(a: number, b: number, c: number, d: number): () => number {
  const TWO_POW_32 = 4294967296;

  // Every update below yields an int32, so coercing once here is sufficient.
  let x = a | 0;
  let y = b | 0;
  let z = c | 0;
  let counter = d | 0;

  function next(): number {
    const mixed = (((x + y) | 0) + counter) | 0;
    counter = (counter + 1) | 0;
    x = y ^ (y >>> 9);
    y = (z + (z << 3)) | 0;
    // Rotate z left by 21 bits, then fold in this step's output word.
    const rotated = (z << 21) | (z >>> 11);
    z = (rotated + mixed) | 0;
    // Read the output word as unsigned and scale it into [0, 1).
    return (mixed >>> 0) / TWO_POW_32;
  }

  return next;
}
/**
 * One-call seeding helper: turns a seed into a ready-to-use generator.
 *
 * A numeric seed is first converted with `String(seed)`; the resulting text is
 * hashed by `cyrb128`, and the four hash words become the `sfc32` state.
 *
 * @param seed - Number or string identifying the stream.
 * @returns A function yielding the next float of the seeded stream.
 */
export function seededRng(seed: number | string): () => number {
  const key = typeof seed === 'string' ? seed : String(seed);
  const state = cyrb128(key);
  return sfc32(state[0], state[1], state[2], state[3]);
}
/**
 * Box-Muller transform: draws two uniforms from `rng` and maps them to a pair
 * of standard-normal samples.
 *
 * The first uniform is floored at 1e-300 so that `Math.log` never receives 0.
 * `rng` is called exactly twice, radius draw first, angle draw second.
 *
 * @param rng - Source of uniform samples.
 * @returns `[z0, z1]`, the cosine and sine branches of the transform.
 */
export function boxMuller(rng: () => number): [number, number] {
  const MIN_UNIFORM = 1e-300;
  const radial = Math.max(rng(), MIN_UNIFORM);
  const angular = rng();
  const radius = Math.sqrt(-2 * Math.log(radial));
  const angle = 2 * Math.PI * angular;
  return [radius * Math.cos(angle), radius * Math.sin(angle)];
}
/**
 * Overwrites every element of `out` with a zero-mean normal sample scaled by
 * `std`.
 *
 * Samples are produced two at a time by `boxMuller`. When the length is odd,
 * one extra pair is drawn for the final slot and its second value is dropped.
 *
 * @param out - Destination buffer, filled in place.
 * @param std - Standard deviation applied to each N(0, 1) draw.
 * @param rng - Uniform source handed through to `boxMuller`.
 */
export function fillNormal(out: Float32Array, std: number, rng: () => number): void {
  const total = out.length;
  const pairedEnd = total - (total % 2);

  for (let base = 0; base < pairedEnd; base += 2) {
    const [z0, z1] = boxMuller(rng);
    out[base] = z0 * std;
    out[base + 1] = z1 * std;
  }

  // Odd length: one slot is left over after the pairs.
  if (pairedEnd < total) {
    const [z0] = boxMuller(rng);
    out[pairedEnd] = z0 * std;
  }
}
one workgroup per row, two-pass reduce (mean, then variance), gamma-scale and beta-shift. Pre-LN, applied twice per block. The "what exactly to normalize" question from M13, in WGSL.
// layerNorm: per row r, y[r,k] = gamma[k] * (x[r,k] - mu) / sqrt(var + eps) + beta[k]
// where mu is the mean of row r and var is its variance computed as the mean
// of squared deviations (divide by d, not d - 1).
// Dispatch: rows workgroups of 64 threads each. Workgroup cooperates on the
// row-reduction; threads stride k by workgroup_size.
// Threads per workgroup; also the number of slots in the reduction scratch.
const WG: u32 = 64u;
// Added to the variance before the square root so the division stays finite.
const EPS: f32 = 1e-5;
@group(0) @binding(0) var<storage, read> x: array<f32>; // [rows, d]
@group(0) @binding(1) var<storage, read> gamma: array<f32>; // [d]
@group(0) @binding(2) var<storage, read> beta: array<f32>; // [d]
@group(0) @binding(3) var<storage, read_write> y: array<f32>; // [rows, d]
@group(0) @binding(4) var<uniform> dims: vec4<u32>; // (rows, d, _, _)
// Workgroup-shared scratch for wgReduce: one partial sum per thread.
var<workgroup> sumScratch: array<f32, WG>;
// Workgroup-wide sum of `val`. Every thread of the workgroup must call this
// with its own `local` index (0 .. WG-1); every thread receives the total.
// The result lives in sumScratch[0], so callers must place a barrier before
// any later call overwrites the scratch.
fn wgReduce(local: u32, val: f32) -> f32 {
  sumScratch[local] = val;
  workgroupBarrier();
  // Tree reduction: each round folds the upper half of the live range onto
  // the lower half, halving the range until only slot 0 remains.
  for (var stride: u32 = WG >> 1u; stride > 0u; stride = stride >> 1u) {
    if (local < stride) {
      sumScratch[local] += sumScratch[local + stride];
    }
    // All threads reach the barrier each round, including the idle ones.
    workgroupBarrier();
  }
  return sumScratch[0];
}
// Entry point: workgroup `wid.x` normalises row `wid.x` of x into y.
// Each thread handles columns lane, lane + WG, lane + 2*WG, ... of that row.
@compute @workgroup_size(WG)
fn main(
  @builtin(workgroup_id) wid: vec3<u32>,
  @builtin(local_invocation_id) lid: vec3<u32>,
) {
  let row = wid.x;
  let lane = lid.x;
  let width = dims.y;
  // The row index is shared by the whole workgroup, so a surplus workgroup
  // exits together and no thread is left waiting at a barrier.
  if (row >= dims.x) { return; }

  let rowBase = row * width;
  let widthF = f32(width);

  // Mean: per-thread strided partial sums, then a workgroup reduction.
  var laneSum: f32 = 0.0;
  for (var k: u32 = lane; k < width; k = k + WG) {
    laneSum = laneSum + x[rowBase + k];
  }
  let mu = wgReduce(lane, laneSum) / widthF;

  // Every thread must have read the first total out of the scratch before
  // the second reduction starts overwriting it.
  workgroupBarrier();

  // Variance: mean of squared deviations from mu.
  var laneSq: f32 = 0.0;
  for (var k: u32 = lane; k < width; k = k + WG) {
    let centered = x[rowBase + k] - mu;
    laneSq = laneSq + centered * centered;
  }
  let variance = wgReduce(lane, laneSq) / widthF;
  let invStd = 1.0 / sqrt(variance + EPS);

  workgroupBarrier();

  // Normalise, then apply the per-column scale (gamma) and shift (beta).
  for (var k: u32 = lane; k < width; k = k + WG) {
    y[rowBase + k] = gamma[k] * (x[rowBase + k] - mu) * invStd + beta[k];
  }
}
softmax((Q·K^T + mask) / sqrt(d_k)) · V. One workgroup per (batch, head, query-position) triple. The causal mask is upper-triangular -infinity. M15 in 71 lines.
// Causal scaled-dot-product attention. One workgroup per (b, head, query-position).
// Layout assumptions: qkv is [B*T, 3*d] row-major; within each row the slab
// order is [Q | K | V], each slab of length d. Heads occupy contiguous strips
// of length dHead inside Q, K, V. Workgroup size = 64; this implementation
// supports T ≤ 64 and dHead ≤ 64 (our locked config: T=64, dHead=16).
// Threads per workgroup. Also the capacity of `scores`, which is why T and
// dHead are limited to 64: one thread per key position, then one per channel.
const WG: u32 = 64u;
// Large-magnitude negative used as a softmax mask. Avoids the literal
// -3.4028235e38 which the WGSL parser rounds to slightly more than f32::MIN
// and rejects on Chromium 124+. -1e30 is plenty for exp(-1e30) → 0.
const NEG_INF: f32 = -1e30;
@group(0) @binding(0) var<storage, read> qkv: array<f32>; // [B*T, 3*d]
@group(0) @binding(1) var<storage, read_write> out: array<f32>; // [B*T, d]
@group(0) @binding(2) var<uniform> cfg: vec4<u32>; // (B, T, d, nHead)
// Workgroup-shared row of attention scores for the current query: raw scaled
// dot products after step 1, softmax weights after step 2.
var<workgroup> scores: array<f32, WG>;
// Entry point. Workgroup coordinates select the work item:
// wid.x = query position tq, wid.y = head h, wid.z = batch index b.
// Within the workgroup, thread i first acts as key position i (step 1) and
// later as output channel i of the head (step 3).
@compute @workgroup_size(WG)
fn main(
@builtin(workgroup_id) wid: vec3<u32>,
@builtin(local_invocation_id) lid: vec3<u32>,
) {
let B = cfg.x; let T = cfg.y; let d = cfg.z; let nHead = cfg.w;
let tq = wid.x; let h = wid.y; let b = wid.z;
// b, h and tq come from workgroup_id, so the whole workgroup takes this exit
// together and no thread is left behind at the barriers below.
if (b >= B || h >= nHead || tq >= T) { return; }
let dHead = d / nHead;
// 1 / sqrt(d_k) scaling applied to every dot product.
let scale = 1.0 / sqrt(f32(dHead));
let i = lid.x;
// 1. Compute one score per thread (one t' = i). Mask t' > tq to -∞.
if (i < T) {
if (i <= tq) {
var s: f32 = 0.0;
for (var c: u32 = 0u; c < dHead; c = c + 1u) {
// Row (b*T + t) of qkv, slab 0 = Q / slab 1 = K, strip h of that slab.
let qi = (b * T + tq) * 3u * d + 0u * d + h * dHead + c;
let ki = (b * T + i ) * 3u * d + 1u * d + h * dHead + c;
s = s + qkv[qi] * qkv[ki];
}
scores[i] = s * scale;
} else {
scores[i] = NEG_INF;
}
}
// All scores must be written before thread 0 reads them.
workgroupBarrier();
// 2. Softmax along T. Single-thread reduction (T ≤ 64).
if (i == 0u) {
// scores[0] is never masked (key 0 satisfies 0 <= tq), so it is a valid
// starting point for the running maximum.
var mx: f32 = scores[0];
for (var t: u32 = 1u; t < T; t = t + 1u) { if (scores[t] > mx) { mx = scores[t]; } }
var sum: f32 = 0.0;
for (var t: u32 = 0u; t < T; t = t + 1u) {
// Subtract the row maximum before exponentiating so exp() cannot overflow.
scores[t] = exp(scores[t] - mx);
sum = sum + scores[t];
}
let inv = 1.0 / sum;
for (var t: u32 = 0u; t < T; t = t + 1u) { scores[t] = scores[t] * inv; }
}
// The normalised weights must be visible to every thread before step 3.
workgroupBarrier();
// 3. Output: thread i (i < dHead) computes one channel of this head.
if (i < dHead) {
var acc: f32 = 0.0;
// Iterates over all T keys; masked positions carry the weight produced by
// exp(NEG_INF - mx) in step 2.
for (var tk: u32 = 0u; tk < T; tk = tk + 1u) {
// Slab 2 = V.
let vi = (b * T + tk) * 3u * d + 2u * d + h * dHead + i;
acc = acc + scores[tk] * qkv[vi];
}
out[(b * T + tq) * d + h * dHead + i] = acc;
}
}
C = A · B^T. Used four times in the backward pass to compute dA = dC · B^T from any forward matmul C = A · B. The chain rule made manifest as one tiled matmul against the transpose of the right-hand side.
// Matmul with the right-hand operand transposed. C = A · Bᵀ.
// A: [M, K] row-major.
// B: [N, K] row-major (i.e. logically [K, N] then transposed → store rows of N indexed by k).
// C: [M, N] row-major.
// Backs unembedding: A = hidden states [B·T, d], B = wte [V, d] (tied embedding),
// C = logits [B·T, V] computed as logits[i, v] = Σ_k h[i, k] · wte[v, k].
// Edge length of the square tile each 16×16 workgroup computes.
const TILE: u32 = 16u;
@group(0) @binding(0) var<storage, read> A: array<f32>;
@group(0) @binding(1) var<storage, read> B: array<f32>;
@group(0) @binding(2) var<storage, read_write> C: array<f32>;
@group(0) @binding(3) var<uniform> dims: vec4<u32>; // (M, K, N, _)
// Workgroup-shared staging tiles, 256 = TILE * TILE floats each.
// aTile is indexed [rowInTile * TILE + k]; bTile is indexed [k * TILE + colInTile].
var<workgroup> aTile: array<f32, 256>;
var<workgroup> bTile: array<f32, 256>;
// Entry point: each 16×16 workgroup produces one TILE×TILE tile of C = A · Bᵀ.
// Thread (lid.x, lid.y) owns C[row, col] and accumulates its dot product over
// K in chunks of TILE, staging each chunk of A and B in workgroup memory.
// Out-of-range tile entries are zero-filled, so M, N and K need not be
// multiples of TILE.
@compute @workgroup_size(16, 16, 1)
fn main(
  @builtin(workgroup_id) gid: vec3<u32>,
  @builtin(local_invocation_id) lid: vec3<u32>,
) {
  let M = dims.x; let K = dims.y; let N = dims.z;
  let row = gid.y * TILE + lid.y; // row of C and of A owned by this thread
  let col = gid.x * TILE + lid.x; // column of C, i.e. row of B, owned by this thread
  var sum: f32 = 0.0;
  let nTiles = (K + TILE - 1u) / TILE; // ceil(K / TILE)
  for (var t: u32 = 0u; t < nTiles; t = t + 1u) {
    // aTile[r * TILE + k] := A[gid.y * TILE + r, t * TILE + k]
    // with r = lid.y and k = lid.x.
    let kA = t * TILE + lid.x;
    let aIdx = lid.y * TILE + lid.x;
    if (row < M && kA < K) { aTile[aIdx] = A[row * K + kA]; } else { aTile[aIdx] = 0.0; }

    // bTile[k * TILE + c] := B[gid.x * TILE + c, t * TILE + k]
    // with c = lid.x and k = lid.y. Here lid.y selects k (for aTile it was
    // lid.x), so the inner loop below can read bTile as [k, column].
    let kB = t * TILE + lid.y;
    let bIdx = lid.y * TILE + lid.x;
    if (col < N && kB < K) { bTile[bIdx] = B[col * K + kB]; } else { bTile[bIdx] = 0.0; }

    // Both tiles must be fully written before any thread reads them.
    workgroupBarrier();
    for (var k: u32 = 0u; k < TILE; k = k + 1u) {
      sum = sum + aTile[lid.y * TILE + k] * bTile[k * TILE + lid.x];
    }
    // Nobody may start overwriting the tiles until every thread is done reading.
    workgroupBarrier();
  }
  if (row < M && col < N) { C[row * N + col] = sum; }
}
this is the line: dlogits[i] = (probs[i] - target_onehot[i]) / (B*T). The composite gradient of softmax + cross-entropy. The "magic" of loss.backward(), made of arithmetic. M9 + M12. You wrote it.
// softmax + cross-entropy backward.
// Inputs : logits [rows, V] (read), targets [rows] (read)
// Outputs: dlogits [rows, V] (write)
// Math : dlogits[i, v] = (softmax(logits[i])[v] − 1[v == targets[i]]) / rows.
// One workgroup per row. Thread `local` handles vocab entries local,
// local + 64, local + 128, ... so V may exceed the workgroup size (here V=65:
// thread 0 handles entries 0 and 64, every other thread handles one entry).
// Computes the loss accumulator on the side: each WG writes
// loss[r] = log Σ_v exp(z_v − maxz) + (maxz − z_target).
// The driver mean-reduces loss[] on the CPU side.
// Threads per workgroup; also the length of both reduction scratch arrays.
const WG: u32 = 64u;
@group(0) @binding(0) var<storage, read> logits: array<f32>;
@group(0) @binding(1) var<storage, read> targets: array<i32>;
@group(0) @binding(2) var<storage, read_write> dlogits: array<f32>;
@group(0) @binding(3) var<storage, read_write> lossPer: array<f32>; // [rows] per-row NLL
@group(0) @binding(4) var<uniform> dims: vec4<u32>; // (rows, V, _, _)
// Separate scratch arrays for the max and sum reductions (wgMax / wgSum).
var<workgroup> mxScratch: array<f32, WG>;
var<workgroup> sumScratch: array<f32, WG>;
// Written by thread 0 of the workgroup in main: log of the row's exp-sum,
// and the row maximum.
var<workgroup> logSum: f32;
var<workgroup> rowMax: f32;
// Naming note for main() below: 'target' is a reserved word in WGSL, so the
// per-row target index is bound to 'tgt' to keep the parser from rejecting it.
// Workgroup-wide maximum of `val`. Every thread of the workgroup must call
// this with its own `local` index (0 .. WG-1); every thread receives the
// maximum, which ends up in mxScratch[0].
fn wgMax(local: u32, val: f32) -> f32 {
  mxScratch[local] = val;
  workgroupBarrier();
  // Tree reduction: each round keeps the larger of a slot and its partner
  // `stride` slots above, halving the live range each time.
  for (var stride: u32 = WG >> 1u; stride > 0u; stride = stride >> 1u) {
    if (local < stride) {
      let mine = mxScratch[local];
      let theirs = mxScratch[local + stride];
      // Take the partner only when it is strictly greater.
      mxScratch[local] = select(mine, theirs, theirs > mine);
    }
    // All threads reach the barrier each round, including the idle ones.
    workgroupBarrier();
  }
  return mxScratch[0];
}
// Workgroup-wide sum of `val`. Every thread of the workgroup must call this
// with its own `local` index (0 .. WG-1); every thread receives the total,
// which ends up in sumScratch[0].
fn wgSum(local: u32, val: f32) -> f32 {
  sumScratch[local] = val;
  workgroupBarrier();
  // Tree reduction: fold the upper half of the live range onto the lower
  // half until only slot 0 remains.
  for (var stride: u32 = WG >> 1u; stride > 0u; stride = stride >> 1u) {
    if (local < stride) {
      sumScratch[local] += sumScratch[local + stride];
    }
    // All threads reach the barrier each round, including the idle ones.
    workgroupBarrier();
  }
  return sumScratch[0];
}
// Entry point: workgroup wid.x processes row wid.x of logits in three passes
// (row max, exp-sum, gradient write). It writes that row of dlogits and one
// entry of lossPer.
@compute @workgroup_size(WG)
fn main(
@builtin(workgroup_id) wid: vec3<u32>,
@builtin(local_invocation_id) lid: vec3<u32>,
) {
let rows = dims.x; let V = dims.y;
// r comes from workgroup_id, so the whole workgroup exits together here.
let r = wid.x; if (r >= rows) { return; }
let local = lid.x;
// Folds the mean over rows into the gradient (the "/ rows" in the header).
let inv = 1.0 / f32(rows);
// Pass 1: per-row max.
// -1e30 is the starting value for threads that see no entry at all.
var mx: f32 = -1e30;
var v: u32 = local;
loop {
if (v >= V) { break; }
let z = logits[r * V + v];
if (z > mx) { mx = z; }
v = v + WG;
}
let rmax = wgMax(local, mx);
// Thread 0 publishes the row max; the barrier makes it visible to all.
if (local == 0u) { rowMax = rmax; }
workgroupBarrier();
// Pass 2: per-row exp-sum.
// Logits are shifted by the row max before exp().
var s: f32 = 0.0;
v = local;
loop {
if (v >= V) { break; }
s = s + exp(logits[r * V + v] - rowMax);
v = v + WG;
}
let rsum = wgSum(local, s);
// logSum is written and later read by thread 0 only.
if (local == 0u) { logSum = log(rsum); }
workgroupBarrier();
// Pass 3: write dlogits and the per-row loss.
let tgt = targets[r];
v = local;
loop {
if (v >= V) { break; }
// Softmax probability, recomputed from the logit instead of being stored.
let p = exp(logits[r * V + v] - rowMax) / rsum;
// One-hot target: 1.0 at the target index, 0.0 elsewhere.
let one = select(0.0, 1.0, i32(v) == tgt);
dlogits[r * V + v] = (p - one) * inv;
v = v + WG;
}
if (local == 0u) {
// NOTE(review): u32(tgt) assumes 0 <= tgt < V; a negative or oversized
// target would index outside this row. Confirm the driver guarantees it.
let zt = logits[r * V + u32(tgt)];
// Per-row NLL = log Σ_v exp(z_v − max) + max − z_target.
lossPer[r] = logSum + rowMax - zt;
}
}
every weight reads (theta, grad, m, v) and writes (theta_new, m_new, v_new). No autograd, no graph. Each call updates 206,016 floats in parallel, then the loss curve takes another step down.
// AdamW one-step (PyTorch ordering: decoupled weight decay → m/v update →
// bias-corrected step). Mirrors torch.optim.AdamW's per-element math line for
// line so the curve matches the reference oracle in docs/research/.
//
// 1. θ ← θ · (1 − lr · λ)
// 2. m ← β₁ m + (1 − β₁) g
// 3. v ← β₂ v + (1 − β₂) g²
// 4. θ ← θ − (lr / (1 − β₁ᵗ)) · m / (√(v / (1 − β₂ᵗ)) + ε)
//
// One thread per parameter element. The driver passes (lr, β₁, β₂, ε, λ,
// 1−β₁ᵗ, 1−β₂ᵗ) in a single uniform block.
// Optimizer hyper-parameters for one step. Eight f32 fields, 32 bytes in total.
struct HyperParams {
lr: f32,
beta1: f32,
beta2: f32,
eps: f32,
lambda: f32,
bc1: f32, // 1 − β₁ᵗ (precomputed by driver)
bc2: f32, // 1 − β₂ᵗ
_pad: f32, // not read by the kernel; rounds the struct up to eight floats
};
// theta, m and v are updated in place; grad is read-only.
@group(0) @binding(0) var<storage, read_write> theta: array<f32>;
@group(0) @binding(1) var<storage, read> grad: array<f32>;
@group(0) @binding(2) var<storage, read_write> m: array<f32>;
@group(0) @binding(3) var<storage, read_write> v: array<f32>;
// u.x is the element count N used as the bounds check in main.
@group(0) @binding(4) var<uniform> u: vec4<u32>; // (N, _, _, _)
@group(0) @binding(5) var<uniform> p: HyperParams;
// Entry point: invocation gid.x applies one AdamW step to element gid.x of
// theta, updating the matching entries of m and v in place.
@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
  let idx = gid.x;
  // Invocations with an index at or beyond u.x have no element to update.
  if (idx >= u.x) { return; }

  let gradient = grad[idx];

  // Moment estimates: exponential moving averages of g and of g².
  let firstMoment = p.beta1 * m[idx] + (1.0 - p.beta1) * gradient;
  let secondMoment = p.beta2 * v[idx] + (1.0 - p.beta2) * gradient * gradient;
  m[idx] = firstMoment;
  v[idx] = secondMoment;

  // Decoupled weight decay: shrink the parameter directly rather than adding
  // a decay term to the gradient.
  let decayed = theta[idx] * (1.0 - p.lr * p.lambda);

  // Bias-corrected step. bc1 and bc2 arrive precomputed as 1 − β₁ᵗ and 1 − β₂ᵗ.
  let correctedLr = p.lr / p.bc1;
  let rms = sqrt(secondMoment) / sqrt(p.bc2) + p.eps;
  theta[idx] = decayed - correctedLr * firstMoment / rms;
}
That is the whole engine.
Plus the boring scaffolding: a Tensor type, a few matmul kernels, a corpus
loader, a checkpoint writer. None of it is hidden. Read the rest at apps/docs/src/lib/m18/engine/.
The model above is sampling from the reference checkpoint, one character every three hundred and fifty milliseconds, while six pedagogically dense source files scroll past below. The model and the code are not separate things. The model is the code (plus 206,016 floats). The code is what runs every time anyone presses Start on any page in this module.
Scroll, or let it scroll itself. Pause whenever you want to read something twice.
The endgame
This is the entire course.
You started in M5 with a single derivative. You finish here with a transformer you trained yourself, on hardware you own, with weights you can save, in roughly the time it takes to brew coffee.
There is no module after this one. There is the model, the artifact, and the next thing you decide to learn.
Push the button again. Watch it learn. Then close this tab and go read a paper.
Name the artifact
When this course says you built a tiny transformer, what is the actual artifact?
A trained model is code that defines the computation plus weights that set the behavior.
Where to next
Three honest pointers, none of them about Tinker.
Read GPT-2 in nanoGPT. Karpathy’s nanoGPT is about 600 lines of PyTorch. You can read every line. The architecture is identical to what you just trained. The differences are scale (12 layers, d_model 768, vocab 50,257), and that someone trained it for a few hundred hours on a few hundred GPUs instead of five minutes on your laptop. The code is the same shape.
Scale up your own checkpoint. Same code, swap tinyshakespeare.txt for a 30 MB corpus you care about. Bump nLayer to 6, dModel to 128. Train for thirty minutes instead of five. The .bin will be larger. The sampler will be the same. Nothing else needs to change.
Mechanistic interpretability. Anthropic publishes a research line called transformer-circuits that tries to understand, line by line, what individual attention heads and individual neurons are for in a trained transformer. The residual-stream framing you read in M16 maps directly onto the model you trained here. The smallest interesting subject in mech-interp is a model of about your scale. Your .bin is enough to start reading.
That’s the end. Thank you for working through it.
Lesson complete
Nice tinkering.
Before you go