Tokenization, Training & Sampling · 18 min

Wrapping the training loop around the transformer

From a 1-D stream of token ids to a working data loader, batched NLL across all positions, and AdamW with warmup, cosine decay, and gradient clipping. The training loop is the outside of the transformer; the forward pass from M16 sits inside it, untouched.


What this lesson is, and isn't

M16 ended with the full transformer forward pass. Tokens at the bottom, attention + FFN blocks in the middle, an unembedding at the top, a softmax over the vocabulary at the very end. That entire stack is one function, $f_\theta(\text{ids}) \to \text{logits}$; in training, the final softmax is folded into the cross-entropy loss.

The training loop is what you wrap around that function to make $\theta$ get good. It does not change the forward pass. It does not change the architecture. It samples batches, computes a loss, runs backprop, and steps an optimizer, all of which you’ve seen separately in earlier modules. This lesson assembles them in the order they actually run.

The corpus is one tensor

Step zero: tokenize the entire training corpus, once, and concatenate the result into a single 1-D tensor of integer ids. For tiny-shakespeare-character that’s 1,115,394 ids. The tokenizer is the boundary you crossed in the last two lessons; everything downstream of it starts from integer ids.
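A minimal sketch of step zero, assuming a character-level tokenizer built from the distinct characters of the text itself. build_char_tokenizer is a hypothetical helper, not the tokenizer from the previous lessons, and the short string stands in for the full corpus.

```python
import torch

def build_char_tokenizer(text):
    """Hypothetical helper: one integer id per distinct character, in sorted order."""
    chars = sorted(set(text))
    stoi = {ch: i for i, ch in enumerate(chars)}   # char -> id
    itos = dict(enumerate(chars))                   # id -> char

    def encode(s):
        return [stoi[ch] for ch in s]

    def decode(ids):
        return "".join(itos[i] for i in ids)

    return encode, decode, len(chars)

# Stand-in for the full corpus string.
text = "First Citizen:\nBefore we proceed any further, hear me speak.\n"
encode, decode, vocab_size = build_char_tokenizer(text)

# Tokenize once: the whole corpus becomes a single 1-D tensor of ids.
data = torch.tensor(encode(text), dtype=torch.long)
print(data.shape, vocab_size)
print(decode(data.tolist()) == text)  # character-level encoding is lossless
```

Run on the full tiny-shakespeare text, the same construction yields the 65-symbol vocabulary that appears later in this lesson.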

Then split off a validation suffix. nanoGPT’s convention for tiny corpora is a contiguous 90/10 split: the last 10% of the stream is held out, never trained on, and used only to measure loss:

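```python
import torch

# text: the raw corpus string; encode: the tokenizer from the previous lessons (assumed to return a list of ints)
data = torch.tensor(encode(text), dtype=torch.long)  # one long 1-D tensor, length 1,115,394
n = int(0.9 * len(data))
train_data = data[:n]          # 1,003,854 tokens
val_data   = data[n:]          # 111,540 tokens
```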

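A contiguous split, not a shuffle. Why? Because the model is autoregressive: it reads sequences. If individual tokens were assigned to train or val at random, almost every validation token would sit inside some training context window, and almost every validation window would contain training tokens, so validation loss would no longer measure performance on unseen text. A clean tail is the honest version.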

Train tokens after the split

Tiny-shakespeare-character has 1,115,394 tokens. With a 90/10 contiguous split, how many train tokens?

One batch, B random offsets

The training step samples $B$ random positions in the train stream. From each position, it carves a context window of length $T$ as x, and the same window shifted by one as y. Karpathy’s get_batch is only a few lines:

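```python
import torch

# Assumes train_data / val_data from the split above, plus globals B (batch size) and T (context length).
def get_batch(split):
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - T, (B,))              # B random starts; the largest is len(data) - T - 1
    x = torch.stack([data[i : i+T]     for i in ix])     # (B, T) inputs
    y = torch.stack([data[i+1 : i+T+1] for i in ix])     # (B, T) targets, shifted by one
    return x, y
```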

That’s it. x has shape (B, T); y has the same shape, shifted one position later in the stream. Here is what that looks like:

With B = 4 and T = 8 (so B · T = 32 predictions per step), sampling from this stretch of the stream (· marks a space):

First·Citizen:·Before·we·proceed·any·further,·hear·me·speak.·Speak.·You·are·all·resolved·rather·to·die·than·to·famish?·Resolved.

  • row 0: x = "resolved", y = "esolved·"
  • row 1: x = "than·to·", y = "han·to·f"
  • row 2: x = "e·procee", y = "·proceed"
  • row 3: x = "peak.·Yo", y = "eak.·You"

Each row is one (x, y) pair. Position t in x is supervised by position t in y, and because the causal mask only lets the transformer see x[:t+1] when predicting y[t], each row yields T honest next-token predictions in one forward pass. With B = 4 and T = 8, one training step therefore supervises B · T = 32 predictions.

The trade-off: larger T means more next-token predictions per row, but a longer attention window, and attention cost grows quadratically with T. Larger B means more independent rows processed in parallel. The product B · T is the supervision count per gradient step.
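A self-contained variant of get_batch, for checking the shapes and the shift by hand. Unlike the version above, it takes data, B and T as arguments instead of reading globals (our choice for the sketch), and it samples from a toy stream of consecutive ids, so every target is visibly its input plus one.

```python
import torch

def get_batch(data, B, T, generator=None):
    """Sample B random windows: x[b] = data[i : i+T], y[b] = data[i+1 : i+T+1]."""
    ix = torch.randint(len(data) - T, (B,), generator=generator)
    x = torch.stack([data[i : i + T] for i in ix])
    y = torch.stack([data[i + 1 : i + T + 1] for i in ix])
    return x, y

data = torch.arange(100)                    # toy stream: ids 0, 1, 2, ..., 99
g = torch.Generator().manual_seed(0)
x, y = get_batch(data, B=4, T=8, generator=g)

print(x.shape, y.shape)                     # x and y are both (B, T)
print(torch.equal(y, x + 1))                # consecutive ids: each target is its input + 1
print(torch.equal(y[:, :-1], x[:, 1:]))     # y is x shifted left by one position
print(y.numel())                            # B * T supervised next-token targets
```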

The shifted-by-one trick

Why shift by one? Because the transformer’s causal mask (M15) ensures that at position $t$ in the forward pass, the model has only seen x[:t+1]. So predicting y[t] from that prefix is honest next-token supervision, with no information leak.

This means one row gives you $T$ training examples instead of one. With $T = 8$: position 0 predicts the token at position 1, position 1 predicts position 2, and so on up to position 7, which predicts position 8. All in the same forward pass, sharing the same attention computation. The cross-entropy loss is averaged over every position in every row:

$$\mathcal{L} = -\frac{1}{BT} \sum_{b=1}^{B} \sum_{t=1}^{T} \log p_\theta\left(y_{b,t} \mid x_{b,1:t}\right)$$

Computing the loss at every position is the part that trips people up. If you compute it only at the last position of each row, you’re throwing away $T - 1$ of every $T$ supervision signals the forward pass already paid for, so each step learns from roughly $T$ times fewer predictions than it could. nanoGPT computes the loss everywhere. So does GPT.
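A quick numerical check of the formula, with random logits standing in for model(x): flattening (B, T) into B·T rows and calling F.cross_entropy with its default mean reduction gives the same number as the double sum written out by hand. The last lines compare how many targets a last-position-only loss would use.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, T, V = 4, 8, 65
logits = torch.randn(B, T, V)      # stand-in for model(x): one distribution per position
y = torch.randint(V, (B, T))       # next-token targets

# Library form: B*T rows, averaged with the default reduction='mean'.
loss_lib = F.cross_entropy(logits.view(-1, V), y.view(-1))

# Written out: -1/(B*T) * sum over b, t of log p(y[b, t] | prefix).
log_probs = F.log_softmax(logits, dim=-1)                     # (B, T, V)
picked = log_probs.gather(-1, y.unsqueeze(-1)).squeeze(-1)    # (B, T): log-prob of each target
loss_manual = -picked.sum() / (B * T)

# Last-position-only loss: B targets per step instead of B*T.
loss_last = F.cross_entropy(logits[:, -1, :], y[:, -1])

print(loss_lib.item(), loss_manual.item())
print("targets used:", y.numel(), "vs", y[:, -1].numel())
```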

Predictions per step

With $B = 64$ and $T = 64$, how many supervised next-token predictions does one gradient step produce?

The training loop, in full

You have the forward pass (M16). You have the data loader (above). You have the loss (NLL, averaged over $B \cdot T$ positions). You have the optimizer (AdamW, from M10).

The whole training loop is eight lines:

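```python
import torch
import torch.nn.functional as F

# Assumes model (M16), optimizer (AdamW), get_batch (above) and max_iters are already defined.
for step in range(max_iters):
    x, y = get_batch('train')                                  # (B, T) inputs and targets
    logits = model(x)                                          # (B, T, V)
    loss = F.cross_entropy(logits.view(-1, logits.size(-1)), y.view(-1))  # mean over all B*T positions
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)    # clip the global grad norm to 1.0
    optimizer.step()
```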

Nothing in it should surprise you. Sample a batch. Forward. Loss. Zero grads. Backward. Clip. Step. Loop.

The two pieces that haven’t appeared elsewhere in the course:

  • set_to_none=True is a small memory-and-speed optimization. Instead of zeroing the .grad tensors in place, PyTorch sets them to None, and the next backward pass allocates fresh ones. It has been the default for zero_grad since PyTorch 2.0; writing it out makes the intent explicit.
  • clip_grad_norm_(..., 1.0) is gradient clipping, covered in the next section.
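To see the loop run end to end, here is a minimal sketch with a hypothetical stand-in model: a bigram table (BigramLM) that maps each token id straight to next-token logits, used in place of the M16 transformer, and trained on a toy repeating string. The loop body matches the one above line for line, plus a list that records the loss; the model, data, and hyperparameters (B = 16, T = 8, 300 iterations, lr = 5e-2) are assumptions for the demo.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

class BigramLM(nn.Module):
    """Hypothetical stand-in for the transformer: (B, T) ids -> (B, T, V) logits."""
    def __init__(self, vocab_size):
        super().__init__()
        self.table = nn.Embedding(vocab_size, vocab_size)  # row i = logits for the token after i

    def forward(self, idx):
        return self.table(idx)                              # (B, T, V)

# Toy corpus with a learnable pattern.
text = "hello world. " * 200
chars = sorted(set(text))
stoi = {c: i for i, c in enumerate(chars)}
data = torch.tensor([stoi[c] for c in text], dtype=torch.long)
V, B, T, max_iters = len(chars), 16, 8, 300

def get_batch():
    ix = torch.randint(len(data) - T, (B,))
    x = torch.stack([data[i : i + T] for i in ix])
    y = torch.stack([data[i + 1 : i + T + 1] for i in ix])
    return x, y

model = BigramLM(V)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-2)

losses = []
for step in range(max_iters):
    x, y = get_batch()
    logits = model(x)                                        # (B, T, V)
    loss = F.cross_entropy(logits.view(-1, V), y.view(-1))   # mean over all B*T positions
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()
    losses.append(loss.item())

print(f"first loss {losses[0]:.3f}, last loss {losses[-1]:.3f}")
```

Swapping BigramLM for the real transformer changes nothing in the loop: the forward pass sits inside it, untouched.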

Gradient clipping

Sometimes a single batch produces a gigantic gradient. Maybe the batch had a string of rare tokens. Maybe the loss spiked because the model got confidently wrong. Whatever the reason, applying that gradient at the current learning rate would step the parameters far enough to wreck the run.

Gradient clipping is cheap insurance. Compute the global L2 norm of the gradient: flatten every parameter’s gradient into one big vector and take its length. If that length is greater than a threshold $c$ (typically 1.0), rescale the entire gradient down so its length is exactly $c$:

$$g \leftarrow g \cdot \min\left(1, \frac{c}{\|g\|_2}\right)$$

When the global norm is at or below $c$, this does nothing. When it spikes, it rescales the step without changing its direction. It’s one line of code in PyTorch, and it eliminates most “loss exploded out of nowhere” failures.
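A hand-rolled version of the clipping rule, checked against torch.nn.utils.clip_grad_norm_. The helper clip_global_norm_ and the fake “spike” gradients are ours, for illustration. The 1e-6 added to the norm mirrors the small epsilon PyTorch uses, so the clipped length lands a hair under $c$ rather than exactly on it.

```python
import torch

def clip_global_norm_(params, c):
    """Scale every grad in place by min(1, c / ||g||_2), with g = all grads flattened together."""
    grads = [p.grad for p in params if p.grad is not None]
    total_norm = torch.sqrt(sum((g.detach() ** 2).sum() for g in grads)).item()
    scale = min(1.0, c / (total_norm + 1e-6))
    for g in grads:
        g.mul_(scale)   # one shared factor for all tensors, so the direction is preserved
    return total_norm   # norm before clipping

def make_params(seed):
    """Two parameter tensors with deliberately huge gradients."""
    torch.manual_seed(seed)
    ps = [torch.nn.Parameter(torch.randn(3, 4)), torch.nn.Parameter(torch.randn(5))]
    for p in ps:
        p.grad = 100.0 * torch.randn_like(p)
    return ps

mine, ref = make_params(0), make_params(0)          # identical copies
before = torch.cat([p.grad.flatten() for p in mine])  # torch.cat copies, so this survives the in-place clip

clip_global_norm_(mine, 1.0)
torch.nn.utils.clip_grad_norm_(ref, 1.0)

after = torch.cat([p.grad.flatten() for p in mine])
print(before.norm().item(), after.norm().item())
print(torch.allclose(after, torch.cat([p.grad.flatten() for p in ref]), atol=1e-6))
```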

Warmup, then cosine decay

The learning rate is not constant. It follows a schedule with two phases:

  1. Linear warmup for the first $W$ iterations (typically 100–2000 depending on scale). The learning rate climbs linearly from 0 to its peak $\eta_{\max}$, that is, $\eta_t = \eta_{\max} \cdot t / W$ for $t < W$. Why warmup? AdamW’s running statistics (the exponentially weighted moving averages of the gradient and the squared gradient) start at zero and take many steps to become reliable estimates. Stepping at full $\eta_{\max}$ before they settle tends to cause the spikes that gradient clipping then has to clean up. Better not to have the spikes in the first place.
  2. Cosine decay for the rest of training. The learning rate falls from $\eta_{\max}$ to a small floor $\eta_{\min}$ following half a cosine wave. For $W \le t \le D$, where $D$ is the iteration at which decay ends (usually the last training iteration):
$$\eta_t = \eta_{\min} + \tfrac{1}{2}(\eta_{\max} - \eta_{\min})\left[1 + \cos\left(\pi \cdot \frac{t - W}{D - W}\right)\right]$$

Cosine decay is a common default in modern training. It decays smoothly, stays near the peak for a long time before falling off, and, crucially, ends at the small floor $\eta_{\min}$, so the final iterations are tiny correction steps rather than full-strength updates. The model “settles.”
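A minimal sketch of the two-phase schedule as a plain function of the iteration number, using the same symbols as the equations above. The floor eta_min = 1e-4 (a tenth of the peak) is our assumption for the demo; the other values match the run described in the next section.

```python
import math

def get_lr(t, eta_max, eta_min, W, D):
    """Linear warmup to eta_max over W iters, then half-cosine decay to eta_min at iter D."""
    if t < W:
        return eta_max * t / W                   # 0 -> eta_max linearly
    if t >= D:
        return eta_min                           # hold the floor once decay ends
    progress = (t - W) / (D - W)                 # 0 at the end of warmup, 1 at D
    return eta_min + 0.5 * (eta_max - eta_min) * (1 + math.cos(math.pi * progress))

for t in (0, 50, 100, 2550, 5000):
    print(t, get_lr(t, eta_max=1e-3, eta_min=1e-4, W=100, D=5000))
```

To use it in the loop above, compute lr = get_lr(step, ...) at the top of each iteration and write it into every entry of optimizer.param_groups (each one is a dict with an 'lr' key) before optimizer.step().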

Iter-0 loss on tiny-shakespeare-char

Tiny-shakespeare-char has $|V| = 65$. With reasonable initialization, the model’s output distribution at iter 0 is approximately uniform. What is the expected initial training loss in nats?
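A quick check of the uniform-prediction argument. Equal logits make the softmax exactly uniform, so every position costs $-\ln(1/|V|) = \ln |V|$; small random logits (the 0.02 scale is an illustrative choice, not a prescribed init) land very close to the same value.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
V, N = 65, 4096                        # vocabulary size; N positions (any number works)
targets = torch.randint(V, (N,))

# Exactly uniform: equal logits -> p = 1/V for every token.
uniform_loss = F.cross_entropy(torch.zeros(N, V), targets)

# Near-uniform: small random logits, as from a small-scale initialization.
small_loss = F.cross_entropy(0.02 * torch.randn(N, V), targets)

print(uniform_loss.item(), small_loss.item(), math.log(V))
```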

What you would actually see

Wire all of this up against a small transformer (4 layers, 4 heads, $d = 128$, $T = 64$, capstone size; with a standard 4× FFN each block holds about $12d^2 \approx 197$k weights, so roughly 0.8M parameters in total). Train on tiny-shakespeare-char with $B = 64$, $\eta_{\max} = 10^{-3}$, a 100-iter warmup, and cosine decay over 5,000 iterations.

Start point: $\ln 65 \approx 4.17$ nats. After ~100 iters of warmup, loss drops fast; by iter 500 it’s around 1.7. By iter 2,000 it’s around 1.4. Final val loss around 1.5, a perplexity near 4.5.

Qualitatively, that curve has the same shape at every model size, from a sub-million-parameter character model up to GPT-3: start near $\log |V|$, drop fast in the first few percent of training, then plateau slowly for the rest. Next lesson: how to read it.
