The same forward pass, no loss
Training and sampling share one function: the M16 forward pass. The difference is what you do with its output.
Training:
prompt → forward → logits → cross_entropy(logits, target) → backward → step
Sampling:
prompt → forward → logits[:, -1, :] → pick one token → append → repeat
That’s the entire autoregressive loop. No loss, no backward, no optimizer. Take the last position’s logits, turn them into a probability distribution, draw one token from it, append to the context, run the model again. Stop when you hit a max length or a stop condition.
Concretely:
for _ in range(max_new_tokens):
    idx_cond = idx[:, -block_size:]                   # crop to context
    logits = model(idx_cond)[:, -1, :]                # last position only
    probs = sample(logits, tau, k, p)
    next_id = torch.multinomial(probs, num_samples=1)
    idx = torch.cat((idx, next_id), dim=1)
The sample function is the entire content of this lesson.
Greedy: why deterministic decoding loops
The simplest possible sample is argmax. Pick the highest-logit token. Always.
It loops. Run any small generative LM with greedy decoding for 30 tokens and watch the output collapse into a repeating cycle. “the the the the the.” “and the and the and the.” This isn’t a bug. It’s the consequence of deterministic decoding on a model whose distribution is genuinely multimodal at every step: many continuations are plausible, and greedy decoding only ever follows one of them.
Holtzman et al. (2020) showed this empirically: in a generative LM, the joint maximum-likelihood path is rarely a good path. The path that maximizes the joint probability $P(x_1, \dots, x_T)$ concentrates on degenerate, repetitive cycles, because each individual step picks the locally most-likely token, and the locally most-likely token from “the” is often “the” again.
This matters less for translation, where the source language tightly constrains the target distribution. It matters a lot for open-ended language modeling, where diversity is what makes the output sound human. Greedy is wrong for the same reason that a maze-solver that always picks “north” is wrong: at every step it’s the highest-probability move, and the path it produces is useless.
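The looping behavior is easy to reproduce on a toy model. The sketch below assumes a small hand-written bigram table (`BIGRAM` and its probabilities are hypothetical, chosen only for illustration) and always follows the most likely next token.

```python
# Hypothetical bigram table: next-token probabilities given the current token.
BIGRAM = {
    "the": {"cat": 0.40, "dog": 0.35, "end": 0.25},
    "cat": {"and": 0.50, "sat": 0.30, "the": 0.20},
    "and": {"the": 0.60, "cat": 0.25, "dog": 0.15},
    "dog": {"and": 0.45, "sat": 0.35, "the": 0.20},
    "sat": {"and": 0.55, "the": 0.45},
    "end": {"the": 1.0},
}


def greedy_decode(start, steps):
    """Always follow the single most likely next token (argmax)."""
    out = [start]
    for _ in range(steps):
        next_probs = BIGRAM[out[-1]]
        out.append(max(next_probs, key=next_probs.get))
    return out


tokens = greedy_decode("the", 9)
print(" ".join(tokens))
# prints: the cat and the cat and the cat and the
```

No step in this table is certain (the top choice never exceeds 0.6), yet the deterministic rule still locks onto a three-token cycle.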
Greedy as a degenerate top-k
Greedy / argmax decoding is what top-k sampling collapses to at what value of $k$?
Temperature: divide before softmax
Temperature is not a creativity dial. It is a single arithmetic operation applied to the logits $z$ before softmax:
$$p_i = \frac{\exp(z_i / \tau)}{\sum_j \exp(z_j / \tau)}$$
That’s it. Divide every logit by $\tau$, then run the usual softmax. Three regimes:
- $\tau = 1$ is the model’s own distribution. Unmodified.
- $\tau \to 0$ sharpens the distribution. Logit gaps grow, the highest logit dominates, softmax collapses to a one-hot at the argmax. Greedy.
- $\tau \to \infty$ flattens the distribution. Logit gaps shrink toward zero, every token’s probability approaches $1/V$ for a vocabulary of size $V$. Uniform random.
Concrete: fix one logit vector and vary only $\tau$. At $\tau = 1$, softmax gives the model’s own probabilities. At $\tau = 0.5$, the logits effectively double, and the softmax sharpens: the top token takes a larger share, the bottom token nearly disappears. At $\tau = 2$, the logits halve, and the distribution moves toward uniform.
The widget reuses the bigram you saw in M14, but the operation is identical for transformer logits. Drag the temperature slider and watch the bars.
One useful invariant: dividing by $\tau$ does not change the argmax (for any $\tau > 0$). The most-likely token is the most-likely token at any temperature. What changes is how much more likely it is than its neighbors.
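The three regimes can be checked numerically. This is a minimal sketch in plain Python; the logit values `[2.0, 1.0, 0.0]` are hypothetical, picked for illustration, and `softmax_with_temperature` is a helper name introduced here, not part of the lesson code.

```python
import math


def softmax_with_temperature(logits, tau):
    """Divide every logit by tau, then apply a numerically stable softmax."""
    scaled = [z / tau for z in logits]
    m = max(scaled)                       # subtract the max to avoid overflow
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


logits = [2.0, 1.0, 0.0]                  # hypothetical values for illustration
for tau in (1.0, 0.5, 2.0):
    probs = softmax_with_temperature(logits, tau)
    print(tau, [round(q, 2) for q in probs])
# prints:
# 1.0 [0.67, 0.24, 0.09]
# 0.5 [0.87, 0.12, 0.02]
# 2.0 [0.51, 0.31, 0.19]
```

The top token stays in first place at every temperature; only the size of its lead changes.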
Temperature 1, top probability
Logits at $\tau = 1$. What is the top token’s probability, to two decimals?
Temperature 0.5, top probability
Same logits, now at $\tau = 0.5$. What is the top token’s probability, to two decimals?
Top-k: chop off the tail by rank
Temperature alone leaves the long tail of the distribution alive. Even with $\tau < 1$, the 50,000th most-likely token still has nonzero probability, and rare tokens summed over a long tail have surprisingly high total mass. Sometimes a typo wins.
Top-k sampling solves this by chopping the tail off by rank. Keep only the $k$ highest-logit tokens; set every other logit to $-\infty$ (so its softmax probability becomes 0); renormalize the kept tokens.
Top-k = 1 is greedy. Top-k = $V$ (the vocabulary size) is “off.” Production LMs often use $k$ in the 40–200 range. The widget above has a top-k slider; slide it down and watch the bars not in the top-$k$ vanish.
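The rank cutoff can be sketched in a few lines of plain Python. The helper names `top_k_filter` and `softmax` and the logit values are assumptions made for this example; treating `k <= 0` as “off” is one possible convention, not the only one.

```python
import math


def top_k_filter(logits, k):
    """Keep the k largest logits and set the rest to -inf.

    If several logits tie with the k-th largest, all of them are kept.
    """
    if k <= 0 or k >= len(logits):        # convention assumed here: filter is off
        return list(logits)
    kth = sorted(logits, reverse=True)[k - 1]
    return [z if z >= kth else -math.inf for z in logits]


def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]   # exp(-inf) evaluates to 0.0
    total = sum(exps)
    return [e / total for e in exps]


logits = [3.0, 2.5, 1.0, 0.2, -1.0]       # hypothetical values for illustration
probs = softmax(top_k_filter(logits, k=2))
print([round(q, 2) for q in probs])
# prints: [0.62, 0.38, 0.0, 0.0, 0.0]
```

The two surviving tokens keep their original ratio; the softmax over the masked logits does the renormalization.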
Top-p (nucleus): chop off the tail by mass
Top-k has a problem: a fixed-cardinality cutoff doesn’t track the model’s confidence. When the model is confident (one token has 0.99 of the mass), top-k = 50 still keeps 49 garbage tokens. When the model is uncertain (50 tokens each with ~0.018 of the mass), top-k = 10 cuts away most of the legitimate distribution.
Nucleus sampling fixes this with a confidence-tracking cutoff. Sort the distribution descending, take the smallest prefix whose cumulative mass is at least $p$, throw the rest away. Writing the sorted probabilities as $p_{(1)} \ge p_{(2)} \ge \dots$, the nucleus size is
$$n^* = \min\Big\{ n : \sum_{i=1}^{n} p_{(i)} \ge p \Big\}$$
Keep the first $n^*$ tokens; renormalize. The widget below shows it for two example distributions. Switch between “confident” (one peaked bar) and “long-tailed” (many small bars). Drag the gold $p$-line and watch the nucleus grow when the distribution is uncertain and shrink when it’s confident.
In production this is typically top_p = 0.9, occasionally 0.95. The kept set sizes itself dynamically.
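To see the kept set size itself, here is a minimal sketch that computes only the nucleus size. The function name `nucleus_size` and both example distributions are hypothetical; the uniform case uses 64 tokens so that the running sum is exact in floating point.

```python
def nucleus_size(probs, p):
    """Smallest number of top-ranked tokens whose cumulative mass is >= p."""
    cum = 0.0
    for n, q in enumerate(sorted(probs, reverse=True), start=1):
        cum += q
        if cum >= p:
            return n
    return len(probs)                     # p exceeds the total mass: keep everything


confident = [0.99] + [0.001] * 10         # one dominant token plus a thin tail
uncertain = [1 / 64] * 64                 # 64 equally likely tokens

print(nucleus_size(confident, 0.9))       # prints: 1
print(nucleus_size(uncertain, 0.9))       # prints: 58
```

With the same threshold, the nucleus is a single token on the confident step and most of the vocabulary on the uncertain one, which a fixed `k` cannot do.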
Nucleus on the confident step
The sorted distribution is . With , what is the nucleus size?
The pipeline order
Three knobs, applied in a fixed order. The order is conventional, and most major LM libraries follow it:
1. Temperature. Reshape the entire distribution. Sharper or flatter, but full-vocab.
2. Top-k. Truncate by rank. The bottom $V - k$ tokens get $-\infty$.
3. Top-p. Truncate by cumulative mass on what’s left. Tail tokens that survived top-k but fall outside the nucleus get $-\infty$.
4. Sample. One softmax over what’s left, one categorical draw.
import torch

def sample(logits, tau, k, p):
    logits = logits / tau                                  # 1. temperature
    if k > 0:
        # k-th largest logit per row; keep a trailing dim so it broadcasts
        kth = torch.topk(logits, k).values[..., -1, None]
        logits = logits.masked_fill(logits < kth, float('-inf'))   # 2. top-k
    if p < 1.0:
        sorted_logits, sorted_idx = logits.sort(descending=True)
        cum = sorted_logits.softmax(-1).cumsum(-1)
        mask = cum > p
        mask[..., 1:] = mask[..., :-1].clone()             # keep the token that crosses p
        mask[..., 0] = False                               # always keep the top token
        # map the mask from sorted order back to vocabulary order
        remove = mask.scatter(-1, sorted_idx, mask)
        logits = logits.masked_fill(remove, float('-inf'))         # 3. top-p
    probs = logits.softmax(-1)
    return probs                                           # 4. sample externally
Three transforms applied to the logit vector, then one categorical draw. None of them know about the others; they each just write into the same array.
Why we don't use beam search
Beam search tracks the top-$B$ partial sequences by joint probability, where $B$ is the beam width, expanding each one step at a time and pruning to keep only the $B$ best overall. It is the dominant decoding algorithm for translation, where the source sentence tightly constrains the target distribution and the joint maximum-likelihood path is genuinely the right path.
Open-ended language modeling is the wrong setting for it. The Holtzman finding from earlier in this lesson holds: the joint MLE path concentrates on degenerate cycles. Beam search makes this worse, not better: it explicitly searches for the highest-likelihood path, which is the path most likely to repeat itself. GPT-2/3/4 do not use beam search. LLaMA does not. Sampling-based decoding is the rule for generative text.
Beam search remains the right choice for translation, summarization, and other “constrained generation” tasks where there is a single best answer. It is the wrong choice everywhere creativity matters.
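For reference, the beam search procedure described in this section can be sketched on a toy model. Everything here is an assumption for illustration: the `BIGRAM` table, its probabilities, and the `beam_search` helper are hypothetical, and real implementations also handle end-of-sequence tokens and length normalization.

```python
import math

# Hypothetical next-token model: probabilities depend only on the last token.
BIGRAM = {
    "<s>": {"a": 0.6, "b": 0.4},
    "a": {"x": 0.4, "y": 0.3, "z": 0.3},
    "b": {"x": 0.9, "y": 0.1},
}


def beam_search(start, steps, width):
    """Keep the `width` best partial sequences by joint log-probability."""
    beams = [(0.0, [start])]
    for _ in range(steps):
        candidates = []
        for score, seq in beams:
            for tok, prob in BIGRAM[seq[-1]].items():
                candidates.append((score + math.log(prob), seq + [tok]))
        candidates.sort(key=lambda c: c[0], reverse=True)
        beams = candidates[:width]        # prune to the best `width` overall
    return beams


greedy_score, greedy_seq = beam_search("<s>", 2, width=1)[0]
beam_score, beam_seq = beam_search("<s>", 2, width=2)[0]
print(greedy_seq, round(math.exp(greedy_score), 2))
print(beam_seq, round(math.exp(beam_score), 2))
# prints:
# ['<s>', 'a', 'x'] 0.24
# ['<s>', 'b', 'x'] 0.36
```

Width 1 is greedy decoding and commits to `a` too early; width 2 recovers the higher-likelihood path through `b`. That is exactly the property that helps in translation and hurts in open-ended generation.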
The KV cache, again
Every step of the autoregressive loop runs the M16 forward pass on a growing prefix. Naively, that means at step $t$ the model recomputes keys and values for all $t$ tokens, even though the keys and values for tokens 0 to $t-2$ haven’t changed since the previous step.
The KV cache (M15) avoids that recompute. Maintain a per-layer cache of every key and value vector ever computed; at each step, only project the new token’s key and value, then run attention against the cached prefix. Same final output, much less arithmetic.
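The cumulative cost without a cache scales as $O(T^3)$ for $T$ generated tokens: at every step you redo all the past work, and a full forward pass over $t$ tokens costs $O(t^2)$ in attention alone. With the cache, every step is $O(t)$, dominated by the new token’s attention over the cache.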
This is inference-only. Training already has all positions in parallel inside one matrix multiply; there is no “past” to cache.
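A rough way to see the gap is to count work rather than time it. The sketch below counts only query-key dot products for one attention head in one layer, ignores the causal mask and the MLP, and uses a hypothetical helper name, so treat the numbers as an illustration of growth rates rather than a cost model.

```python
def attention_dot_products(T, use_cache):
    """Count query-key dot products over T generation steps (one head, one layer)."""
    total = 0
    for t in range(1, T + 1):             # t = prefix length at this step
        if use_cache:
            total += t                    # one new query against t cached keys
        else:
            total += t * t                # all t queries recomputed against t keys
    return total


for T in (10, 100, 1000):
    no_cache = attention_dot_products(T, use_cache=False)
    cached = attention_dot_products(T, use_cache=True)
    print(T, no_cache, cached, round(no_cache / cached, 1))
# prints:
# 10 385 55 7.0
# 100 338350 5050 67.0
# 1000 333833500 500500 667.0
```

The ratio between the two totals grows linearly with the sequence length, so the saving matters more the longer you generate.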
The cost of not caching
Without a KV cache, generating $T$ tokens has total compute complexity $O(T^x)$. What is the exponent $x$?
Seed, prompt, sampler: the full control surface
Once the weights are frozen, three things, and only three things, change the output:
- The prompt. What you condition on.
- The sampler settings. $\tau$, $k$, $p$.
- The random seed. The PRNG state from which multinomial draws.
Same seed + same sampler + same prompt = same output, byte-for-byte. Change any one of those and you take a different path through the same model. Change the seed alone and the model says a different thing while remaining the same model.
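The reproducibility claim can be sketched without a real model. Here `toy_logits` is a hypothetical stand-in for a frozen network (its logits depend only on the last token), and Python’s `random.Random` plays the role of the seeded PRNG.

```python
import math
import random


def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def toy_logits(context):
    """Hypothetical frozen 'model': logits over 5 tokens from the last token."""
    last = context[-1]
    return [math.cos(last + v) for v in range(5)]


def generate(prompt, n_new, tau, seed):
    rng = random.Random(seed)             # private PRNG, seeded explicitly
    out = list(prompt)
    for _ in range(n_new):
        probs = softmax([z / tau for z in toy_logits(out)])
        out.append(rng.choices(range(len(probs)), weights=probs, k=1)[0])
    return out


a = generate([0], 20, tau=1.0, seed=42)
b = generate([0], 20, tau=1.0, seed=42)
print(a == b)                             # prints: True
```

Calling `generate` again with a different `seed`, `tau`, or `prompt` walks a different path through the same fixed function.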
This is the entire interface for “controlling” a generative LM at inference time. There is no fourth knob. There is no “make the model smarter.” There is only re-sampling the same distribution differently.
That distribution came from training. Training came from the loss curve. The loss curve came from the data loader. The data loader read tokens from a tokenizer. The tokenizer was a list of merges fit to a corpus.
That is the whole stack: bytes to vectors, vectors to weights, weights back to bytes. Next module: you train one in your browser.