Tokenization, Training & Sampling · 22 min

How distributions become text

The autoregressive loop wrapped around the M16 forward pass. Why argmax fails for open-ended LMs. Temperature, top-k, top-p as logit-space transforms in a fixed pipeline. Beam search and why we don't use it. The KV cache as the reason inference is O(T) instead of O(T²).


The same forward pass, no loss

Training and sampling share one function: the M16 forward pass. The difference is what you do with its output.

Training:

prompt → forward → logits → cross_entropy(logits, target) → backward → step

Sampling:

prompt → forward → logits[:, -1, :] → pick one token → append → repeat

That’s the entire autoregressive loop. No loss, no backward, no optimizer. Take the last position’s logits, turn them into a probability distribution, draw one token from it, append to the context, run the model again. Stop when you hit a max length or a stop condition.

Concretely:

for _ in range(max_new_tokens):
    idx_cond = idx[:, -block_size:]              # crop to context
    logits = model(idx_cond)[:, -1, :]           # last position only
    probs = sample(logits, tau, k, p)
    next_id = torch.multinomial(probs, num_samples=1)
    idx = torch.cat((idx, next_id), dim=1)

The sample function is the entire content of this lesson.

Greedy: why deterministic decoding loops

The simplest possible sample is argmax. Pick the highest-logit token. Always.

It loops. Run any small generative LM with greedy decoding for 30 tokens and watch the output collapse into a repeating cycle. “the the the the the.” “and the and the and the.” This isn’t a bug. It’s the consequence of deterministic decoding on a model whose distribution is genuinely multimodal at every step: many continuations are plausible, and argmax discards all but one of them.

Holtzman et al. (2020) showed this empirically: in a generative LM, the joint maximum-likelihood path is rarely a good path. The path that maximizes $\prod_t P(x_t \mid x_{<t})$ concentrates on degenerate, repetitive cycles, because each individual step picks the locally most-likely token, and the locally most-likely token from “the” is often “the” again.

This matters less for translation, where the source language tightly constrains the target distribution. It matters a lot for open-ended language modeling, where diversity is what makes the output sound human. Greedy is wrong for the same reason that a maze-solver that always picks “north” is wrong: at every step it’s the highest-probability move, and the path it produces is useless.
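The looping behavior can be reproduced without a neural network. The sketch below uses a hypothetical hand-written bigram table (the tokens and probabilities are made up for illustration) and always follows the argmax; because the decoder is deterministic, the first repeated token locks it into a cycle.

```python
# Hypothetical toy bigram table: next-token probabilities given the current token.
# The numbers are invented for illustration, not taken from any trained model.
BIGRAM = {
    "the": {"cat": 0.40, "dog": 0.35, "end": 0.25},
    "cat": {"and": 0.50, "sat": 0.30, "end": 0.20},
    "dog": {"and": 0.45, "ran": 0.35, "end": 0.20},
    "and": {"the": 0.60, "cat": 0.25, "dog": 0.15},
    "sat": {"and": 0.55, "end": 0.45},
    "ran": {"and": 0.55, "end": 0.45},
    "end": {"the": 1.00},
}

def greedy_decode(start, steps):
    """Always follow the single most likely next token (argmax)."""
    out = [start]
    for _ in range(steps):
        nxt = BIGRAM[out[-1]]
        out.append(max(nxt, key=nxt.get))  # deterministic: same state, same choice
    return out

tokens = greedy_decode("the", 9)
print(" ".join(tokens))  # the cat and the cat and the cat and the
```

No single step here is a bad choice, yet "dog", "sat" and "ran" can never appear, however long the decoder runs.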

Greedy as a degenerate top-k

Greedy / argmax decoding is what top-k sampling collapses to at $k = ?$

Temperature: divide before softmax

Temperature is not a creativity dial. It is a single arithmetic operation applied to the logits before softmax:

$$\mathrm{softmax}_\tau(\ell)_i \;=\; \frac{\exp(\ell_i / \tau)}{\sum_j \exp(\ell_j / \tau)}$$

That’s it. Divide every logit by $\tau$, then run the usual softmax. Three regimes:

  • $\tau = 1$ is the model’s own distribution. Unmodified.
  • $\tau \to 0^+$ sharpens the distribution. Logit gaps grow, the highest logit dominates, softmax collapses to a one-hot at the argmax. Greedy.
  • $\tau \to \infty$ flattens the distribution. Logit gaps shrink toward zero, every token’s probability approaches $1/\lvert V \rvert$. Uniform random.

Concrete: logits $\ell = (2, 1, 0)$. At $\tau = 1$, softmax gives $(0.665,\, 0.245,\, 0.090)$. At $\tau = 0.5$, the logits effectively double to $(4, 2, 0)$, and softmax gives $(0.867,\, 0.117,\, 0.016)$: the top token’s share grows from about two thirds to almost 87%, and the bottom token nearly disappears. At $\tau = 2$, the logits halve, and the distribution moves toward uniform: $(0.506,\, 0.307,\, 0.186)$.
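The arithmetic above can be checked with a few lines of plain Python. This is a minimal sketch, not the lesson's own implementation; the function name `softmax_with_temperature` is made up here, and the max-subtraction is a standard numerical-stability step that does not change the result.

```python
import math

def softmax_with_temperature(logits, tau):
    """Divide each logit by tau, then apply a numerically stable softmax."""
    scaled = [l / tau for l in logits]
    m = max(scaled)                            # subtracting the max avoids overflow
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

logits = [2.0, 1.0, 0.0]
for tau in (0.5, 1.0, 2.0):
    probs = softmax_with_temperature(logits, tau)
    top = probs.index(max(probs))              # position of the most likely token
    print(tau, [round(p, 3) for p in probs], "argmax:", top)
```

The argmax position printed on each line is the same for all three temperatures; only the gaps between the probabilities change.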

The widget reuses the bigram you saw in M14, but the operation is identical for transformer logits. Drag the temperature slider and watch the bars.

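[Interactive widget: bar chart of the bigram next-character distribution P over · and a to z, with a temperature (T) slider, a top-k slider, and readouts for the argmax, P(argmax), T, and the number of empirical samples.]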

Drag any bar's top to set its probability directly. The temperature slider stretches or squeezes the whole distribution, but notice that the argmax (highlighted character) doesn't change for any T > 0. Top-k censors all but the largest k bars before sampling. Empirical samples appear as coral ticks above each bar; with enough samples they lock onto the bar tops.

One useful invariant: $\tau$ does not change the argmax (for any $\tau > 0$). The most-likely token is the most-likely token at any temperature. What changes is how much more likely it is than its neighbors.

Temperature 1, top probability

Logits $\ell = (2,\, 1,\, 0)$ at $\tau = 1$. What is $p_1$, to two decimals?

Temperature 0.5, top probability

Same logits $(2,\, 1,\, 0)$, now at $\tau = 0.5$. What is $p_1$, to two decimals?

Top-k: chop off the tail by rank

Temperature alone leaves the long tail of the distribution alive. Even with $\tau = 0.7$, the 50,000th most-likely token still has nonzero probability, and rare tokens summed over a long tail have surprisingly high total mass. Sometimes a typo wins.

Top-k sampling solves this by chopping the tail off by rank. Keep only the $k$ highest-logit tokens; set every other logit to $-\infty$ (so its softmax probability becomes 0); renormalize the kept tokens.

$$\ell'_i \;=\; \begin{cases} \ell_i & i \in \mathrm{TopK}(\ell, k) \\ -\infty & \text{otherwise} \end{cases}$$

Top-k = 1 is greedy. Top-k = $\lvert V \rvert$ is “off.” Production LMs often use $k$ in the 40–200 range. The widget above has a top-k slider; slide it down and watch the bars not in the top-$k$ vanish.
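A minimal pure-Python sketch of the same masking rule, assuming a plain list of logits. The helper names `top_k_filter` and `softmax` are illustrative, and ties at the cutoff are all kept, which is one reasonable convention among several.

```python
import math

def top_k_filter(logits, k):
    """Keep the k largest logits; set the rest to -inf. Ties at the cutoff are kept."""
    if k <= 0 or k >= len(logits):
        return list(logits)                    # treat as "off"
    kth = sorted(logits, reverse=True)[k - 1]  # value of the k-th largest logit
    return [l if l >= kth else -math.inf for l in logits]

def softmax(logits):
    m = max(logits)
    exps = [math.exp(l - m) for l in logits]   # exp(-inf) evaluates to 0.0
    total = sum(exps)
    return [e / total for e in exps]

logits = [2.0, 1.0, 0.0, -1.0]
filtered = top_k_filter(logits, k=2)
probs = softmax(filtered)
print(filtered)                       # the two smallest logits become -inf
print([round(p, 3) for p in probs])   # masked tokens get probability 0
```

Because the mask is applied in logit space, the renormalization happens for free inside the softmax.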

Top-p (nucleus): chop off the tail by mass

Top-k has a problem: a fixed-cardinality cutoff doesn’t track the model’s confidence. When the model is confident (one token has 0.99 of the mass), top-k = 50 still keeps 49 garbage tokens. When the model is uncertain (50 tokens each with ~0.018 of the mass), top-k = 10 cuts away most of the legitimate distribution.

Nucleus sampling fixes this with a confidence-tracking cutoff. Sort the distribution descending, take the smallest prefix whose cumulative mass is at least $p$, throw the rest away.

$$\pi = \mathrm{argsort}_\downarrow(p),\quad C_j = \sum_{i=1}^{j} p_{\pi_i},\quad n^* = \min\{j : C_j \geq p\}$$

Keep the first $n^*$ tokens; renormalize. The widget below shows it for two example distributions. Switch between “confident” (one peaked bar) and “long-tailed” (many small bars). Drag the gold $p$-line and watch the nucleus grow when the distribution is uncertain and shrink when it’s confident.

[Interactive widget: sorted probabilities with a cumulative-mass curve and a draggable p-line (default p = 0.90); readouts for nucleus size, tokens kept, and kept mass.]

Bars are the sorted descending probabilities. The dashed teal line is cumulative mass: it climbs from the first bar's probability to 1 by the last. Drag the gold p-line vertically: bars stay up to and including the first one whose cumulative mass reaches the line; the rest are masked. Coral ticks show the renormalized distribution over the kept (nucleus) tokens. The nucleus shrinks when the model is confident and grows when it isn't; that's what "top-p" means.

In production this is typically top_p = 0.9, occasionally 0.95. The kept set sizes itself dynamically.
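To see the kept set size itself, here is a small sketch that works on a plain list of probabilities. The function name `nucleus` and both test distributions are made up for illustration; the uncertain case uses 64 equal tokens so that the cumulative sums are exact in floating point.

```python
def nucleus(probs, p):
    """Return (kept indices, renormalized probabilities) for top-p sampling."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, cum = [], 0.0
    for i in order:
        kept.append(i)
        cum += probs[i]
        if cum >= p:                       # smallest prefix with mass >= p
            break
    return kept, [probs[i] / cum for i in kept]

confident = [0.99] + [0.01 / 49] * 49      # one token holds almost all the mass
uncertain = [1 / 64] * 64                  # mass spread evenly over 64 tokens

kept_confident, _ = nucleus(confident, 0.9)
kept_uncertain, renorm = nucleus(uncertain, 0.9)
print(len(kept_confident))                 # 1
print(len(kept_uncertain))                 # 58
```

The same threshold keeps one token in the first case and most of the vocabulary in the second, which a fixed k cannot do.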

Nucleus on the confident step

The sorted distribution is $(0.50,\, 0.20,\, 0.15,\, 0.08,\, 0.05,\, 0.02)$. With $p = 0.9$, what is the nucleus size?

The pipeline order

Three knobs, applied in a fixed order. The order is conventional, and it is the one most major LM libraries use:

  1. Temperature. Reshape the entire distribution. Sharper or flatter, but full-vocab.
  2. Top-k. Truncate by rank. The bottom $\lvert V \rvert - k$ tokens get $-\infty$.
  3. Top-p. Truncate by cumulative mass on what’s left. Tail tokens that survived top-k but fall outside the nucleus get $-\infty$.
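  4. Sample. One softmax over what’s left, one categorical draw.

import torch

def sample(logits, tau, k, p):
    logits = logits / tau                            # 1. temperature (returns a new tensor)
    if k > 0:
        k = min(k, logits.size(-1))                  # guard against k > vocab size
        kth = torch.topk(logits, k).values[..., -1:] # k-th largest; trailing dim kept for broadcasting
        logits = logits.masked_fill(logits < kth, -float('inf'))   # 2. top-k
    if p < 1.0:
        sorted_logits, sorted_idx = logits.sort(dim=-1, descending=True)
        cum = sorted_logits.softmax(-1).cumsum(-1)
        mask = cum >= p                              # True once cumulative mass reaches p
        mask[..., 1:] = mask[..., :-1].clone()       # shift right: keep the token that reaches p
        mask[..., 0] = False                         # always keep the top token
        # the mask is in sorted order; scatter it back to vocabulary order
        remove = torch.zeros_like(mask).scatter(-1, sorted_idx, mask)
        logits = logits.masked_fill(remove, -float('inf'))          # 3. top-p
    probs = logits.softmax(-1)
    return probs                                     # 4. the caller draws the sample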

Three transforms applied to the logit vector, then one categorical draw. None of them know about the others; they each just write into the same array.

Why we don't use beam search

Beam search tracks the top-$B$ partial sequences by joint probability, expanding each one step at a time, pruning to keep only the best $B$ overall. It is the dominant decoding algorithm for translation, where the source sentence tightly constrains the target distribution and the joint maximum-likelihood path is genuinely the right path.

Open-ended language modeling is the wrong setting for it. The Holtzman finding from earlier in this lesson holds: the joint MLE path concentrates on degenerate cycles. Beam search makes this worse, not better: it explicitly searches for the highest-likelihood path, which is the path most likely to repeat itself. GPT-2/3/4 do not use beam search. LLaMA does not. Sampling-based decoding is the rule for generative text.

Beam search remains the right choice for translation, summarization, and other “constrained generation” tasks where there is a single best answer. It is the wrong choice everywhere creativity matters.
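For reference, the expand-and-prune mechanics fit in a few lines. This is a sketch over a hypothetical toy model in which the next-token distribution depends only on the last token; the table `NEXT` and its probabilities are invented, and real implementations add length normalization and end-of-sequence handling that are omitted here.

```python
import math

# Hypothetical toy model: next-token probabilities depend only on the last token.
NEXT = {
    "<s>": {"a": 0.5, "b": 0.4, "c": 0.1},
    "a":   {"a": 0.4, "b": 0.3, "c": 0.3},
    "b":   {"b": 0.9, "a": 0.05, "c": 0.05},
    "c":   {"a": 0.5, "b": 0.5},
}

def beam_search(beam_width, steps):
    """Keep the beam_width highest joint-log-probability prefixes at each step."""
    beams = [(0.0, ["<s>"])]                       # (joint log prob, tokens)
    for _ in range(steps):
        candidates = []
        for logp, seq in beams:
            for tok, prob in NEXT[seq[-1]].items():
                candidates.append((logp + math.log(prob), seq + [tok]))
        candidates.sort(key=lambda c: c[0], reverse=True)
        beams = candidates[:beam_width]            # prune to the best B overall
    return beams

best_logp, best_seq = beam_search(beam_width=2, steps=3)[0]
greedy_logp, greedy_seq = beam_search(beam_width=1, steps=3)[0]
print(best_seq, round(math.exp(best_logp), 3))     # ['<s>', 'b', 'b', 'b'] 0.324
print(greedy_seq, round(math.exp(greedy_logp), 3)) # ['<s>', 'a', 'a', 'a'] 0.08
```

With a beam of 2 the search finds a path four times more probable than the greedy one, and in this toy table that path is pure repetition.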

The KV cache, again

Every step of the autoregressive loop runs the M16 forward pass on a growing prefix. Naively, that means at step $t$ the model recomputes keys and values for all $t$ tokens, even though the keys and values for tokens $0$ to $t-1$ haven’t changed since the previous step.

The KV cache (M15) avoids that recompute. Maintain a per-layer cache of every key and value vector ever computed; at each step, only project the new token’s key and value, then run attention against the cached prefix. Same final output, much less arithmetic.

[Interactive widget: KV cache step-through over 12 steps. K-cache and V-cache with one row per token (k1 to k12, v1 to v12); readouts for cumulative work with the cache on, cumulative work with the cache off, and the savings ratio.]

cache on: each token's K and V are computed once at its step (yellow flash) and reused on every subsequent step. Past columns stay filled.

work scales as $O(T)$ with the cache, $O(T^2)$ without. The ratio grows linearly with sequence length; at T = 1024 the final step alone does roughly 1000× less work with the cache, and the cumulative saving is roughly 500×.

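Counting token positions pushed through the model, the cumulative cost without a cache scales as $O(T^2)$: at every step you redo all the past work. With the cache, each step projects only the new token, so that cumulative cost drops to $O(T)$; what remains per step is the new token’s attention over the cache, which is $O(t)$ at step $t$.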

This is inference-only. Training already has all positions in parallel inside one matrix multiply; there is no “past” to cache.
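The "same output, less arithmetic" claim can be checked on a toy single-head attention. Everything in this sketch is an assumption made for illustration: `project` stands in for learned linear projections with plain scaling, the input vectors are fixed in advance instead of being produced by sampling, and work is counted as the number of key/value projections.

```python
import math

def project(x, scale):
    """Stand-in for a learned linear projection (hypothetical: plain scaling)."""
    return [scale * xi for xi in x]

def attend(q, keys, values):
    """Single-head softmax attention of one query over all keys and values."""
    scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(len(q)) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    z = sum(weights)
    dim = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) / z for d in range(dim)]

def run(xs, use_cache):
    k_cache, v_cache, outputs, kv_projections = [], [], [], 0
    for t in range(len(xs)):
        if use_cache:
            new = [xs[t]]                  # project only the newest token
        else:
            k_cache, v_cache = [], []      # discard everything
            new = xs[: t + 1]              # and re-project the whole prefix
        for x in new:
            k_cache.append(project(x, 0.5))
            v_cache.append(project(x, 2.0))
            kv_projections += 1
        outputs.append(attend(project(xs[t], 1.0), k_cache, v_cache))
    return outputs, kv_projections

xs = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, -0.5]]
out_cached, n_cached = run(xs, use_cache=True)
out_naive, n_naive = run(xs, use_cache=False)
print(out_cached == out_naive, n_cached, n_naive)   # True 4 10
```

For T tokens the uncached count is T(T+1)/2 against T with the cache, so under this cost model the cumulative ratio is (T+1)/2 and the ratio at the final step is T.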

The cost of not caching

Without a KV cache, generating $T$ tokens has total compute complexity $\mathcal{O}(T^?)$. What is the exponent?

Seed, prompt, sampler: the full control surface

Once the weights are frozen, three things, and only three things, change the output:

  1. The prompt. What you condition on.
  2. The sampler settings. $\tau$, $k$, $p$.
  3. The random seed. The PRNG state from which multinomial draws.

Same seed + same sampler + same prompt = same output, byte-for-byte. Change any one of those and you take a different path through the same model. Change the seed alone and the model says a different thing while remaining the same model.
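The role of the seed can be shown with Python's standard-library PRNG. This is a sketch with a made-up vocabulary and a fixed, made-up distribution; in the PyTorch loop earlier in the lesson, `torch.manual_seed` plays the same role for `torch.multinomial`.

```python
import random

VOCAB = ["the", "cat", "dog", "sat"]
PROBS = [0.4, 0.3, 0.2, 0.1]          # made-up next-token distribution

def draw(n, seed):
    """Draw n tokens from the fixed distribution using a privately seeded PRNG."""
    rng = random.Random(seed)         # the seed fully determines the sequence of draws
    return rng.choices(VOCAB, weights=PROBS, k=n)

run_a = draw(20, seed=0)
run_b = draw(20, seed=0)
run_c = draw(20, seed=1)              # same distribution, different path (almost always)
print(run_a == run_b)                 # True
print(" ".join(run_a))
print(" ".join(run_c))
```

The distribution never changes between the three runs; only the path drawn from it does.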

This is the entire interface for “controlling” a generative LM at inference time. There is no fourth knob. There is no “make the model smarter.” There is only re-sampling the same distribution differently.

That distribution came from training. Training came from the loss curve. The loss curve came from the data loader. The data loader read tokens from a tokenizer. The tokenizer was a list of merges fit to a corpus.

That is the whole stack: bytes to vectors, vectors to weights, weights back to bytes. Next module: you train one in your browser.
