Attention · 10 min

The cost and the cache

Attention is O(T²·dₖ), the number that defines modern LLM economics. At inference, only the new token issues a query, so K and V for past tokens can be cached. That single trick converts streaming generation from cubic to quadratic in sequence length.


Where the FLOPs actually go

The attention layer wraps $\mathrm{softmax}(QK^\top / \sqrt{d_k})\,V$ in three stages of matmuls. The first and last (projecting $X$ into $Q, K, V$ and concat-then-$W_O$ at the end) cost $O(T \cdot d^2)$, the same as any linear layer. They are not the bottleneck.

The middle matmul ($QK^\top$) produces a $(T, T)$ matrix. Its cost is $O(T^2 \cdot d_k)$. So is the cost of multiplying that attention matrix by $V$. Both of them scale quadratically in the sequence length, regardless of model size.

That T² term is the reason context windows are expensive. Doubling the input length quadruples attention compute. Going from 2,048 tokens to 32,768 tokens is a 256× cost increase on this part of the model, at fixed model size. This is the central economics fact about transformers, and it is why “long context” is a research-frontier problem and not a free upgrade.
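The scaling claim above can be checked with a few lines of arithmetic. This is a minimal sketch that counts multiply-adds only; the function names and the head dimension `d_k = 64` are assumptions chosen for illustration, not values from the lesson.

```python
def score_flops(T: int, d_k: int) -> int:
    """Multiply-adds for Q @ K.T with Q, K of shape (T, d_k):
    one d_k-long dot product per (query, key) pair."""
    return T * T * d_k


def attention_flops(T: int, d_k: int) -> int:
    """QK^T plus (attention matrix) @ V: two matmuls of equal cost."""
    return 2 * score_flops(T, d_k)


d_k = 64  # assumed head dimension, illustration only
base = attention_flops(2048, d_k)
print(attention_flops(4096, d_k) / base)    # 4.0   (doubling T)
print(attention_flops(32768, d_k) / base)   # 256.0 (16x longer, 16**2)
```

The ratio does not depend on `d_k`, which is why the quadratic term shows up at every model size.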

The quadratic

By what factor does attention compute grow when you double the sequence length $T$?

Memory matches compute

The attention matrix itself is $(T, T)$: that’s $T^2$ floats, per head, per layer. At $T = 4096$, $h = 32$, $L = 32$, with 16-bit floats, this is roughly:

$$4096 \times 4096 \times 32 \times 32 \times 2 \;\text{bytes} \;\approx\; 34 \text{ GB}$$

That is over 40% of an 80 GB H100 before the model weights and the KV cache are added, and it quadruples with each doubling of $T$. This is why long-context inference uses tricks like FlashAttention (which never materializes the full $T \times T$ matrix in GPU main memory) or sliding-window attention (which sparsifies it). They are direct, named responses to the memory part of the same $T^2$ bottleneck.

For our purposes (a single attention layer at modest $T$) we don’t need those tricks yet. But the reason they exist sits in the arithmetic above: attention’s memory cost scales the same way as its FLOP cost, both quadratic, both ugly.
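The memory figure can be reproduced directly. A small sketch, where the helper name `attention_matrix_bytes` is made up for this example and the inputs are the ones quoted above:

```python
def attention_matrix_bytes(T: int, heads: int, layers: int,
                           bytes_per_float: int = 2) -> int:
    """Bytes needed to hold every (T, T) attention matrix at once:
    one matrix per head, per layer."""
    return T * T * heads * layers * bytes_per_float


total = attention_matrix_bytes(T=4096, heads=32, layers=32)
print(total)         # 34359738368
print(total / 1e9)   # about 34.4 (decimal GB)

# Doubling T quadruples the footprint.
print(attention_matrix_bytes(8192, 32, 32) // total)   # 4
```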

Inference has a structural shortcut

At inference time, a transformer generates one token at a time. Step 1 produces token 1 from the prompt; step 2 takes the prompt + token 1 and produces token 2; step 3 takes prompt + tokens 1–2 and produces token 3. And so on.

A naive implementation would re-run the entire prefix through the model at every step, recomputing $K$, $V$ for tokens $0 \ldots t-1$ at step $t$, even though those rows haven’t changed since the last step. Step $t$ then pays the full $O(t^2 \cdot d_k)$ attention cost for its prefix, so the total across $T$ steps would be $O(T^3 \cdot d_k)$, cubic. At $T = 1024$, $T^3$ alone is already about $10^9$: a billion-operation wall per layer, per head, for a single generated sequence.

Here is the structural observation that fixes it: at step $t$, only the new token issues a query. All past tokens have already been “spoken to” and never query again. So we never need their queries. We do, however, need their keys and values; every new query has to score against all past keys and read against all past values. But $K$ and $V$ for past tokens don’t change once those tokens have been emitted.

So cache them. Compute $K$ and $V$ for each token once, when it’s emitted, and stash both in memory. On every subsequent step, reuse the cached rows instead of recomputing them.
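The caching idea can be sketched for a single head in NumPy. This is an illustrative toy, not how any particular inference engine is written: the class name `KVCache`, the helper `attend_full`, and the tiny random shapes are all assumptions made for the example. It checks that the cached step returns the same output as recomputing the whole prefix.

```python
import numpy as np


def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)   # subtract max for stability
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)


def attend_full(X, Wq, Wk, Wv):
    """Cache off: recompute Q, K, V for the whole prefix X of shape (t, d),
    apply the causal mask, and return the output row of the newest token."""
    Q, K, V = X @ Wq, X @ Wk, X @ Wv
    t = X.shape[0]
    scores = Q @ K.T / np.sqrt(K.shape[-1])
    future = np.triu(np.ones((t, t), dtype=bool), k=1)
    scores = np.where(future, -np.inf, scores)
    return (softmax(scores) @ V)[-1]


class KVCache:
    """Cache on: each token's K and V rows are computed once and kept."""

    def __init__(self, Wq, Wk, Wv):
        self.Wq, self.Wk, self.Wv = Wq, Wk, Wv
        self.K = np.empty((0, Wk.shape[1]))
        self.V = np.empty((0, Wv.shape[1]))

    def step(self, x):
        """x: (d,) embedding of the newest token."""
        # Append one new row to each cache. vstack copies the array;
        # a real engine would preallocate instead.
        self.K = np.vstack([self.K, x @ self.Wk])
        self.V = np.vstack([self.V, x @ self.Wv])
        q = x @ self.Wq                       # only the new token queries
        scores = self.K @ q / np.sqrt(q.shape[-1])
        return softmax(scores) @ self.V       # no mask: the cache holds no future


rng = np.random.default_rng(0)
d, d_k, T = 8, 4, 6                           # toy sizes, illustration only
Wq, Wk, Wv = (rng.normal(size=(d, d_k)) for _ in range(3))
X = rng.normal(size=(T, d))

cache = KVCache(Wq, Wk, Wv)
for t in range(T):
    cached_out = cache.step(X[t])
    naive_out = attend_full(X[: t + 1], Wq, Wk, Wv)
    assert np.allclose(cached_out, naive_out)
```

The cached path needs no causal mask because the cache only ever contains the current token and its past.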

[Interactive: KV-cache stepper, 12 steps. A K-cache and a V-cache each hold one row per token (k1 to k12, v1 to v12). With the cache on, each token's K and V are computed once at its step (yellow flash) and reused on every subsequent step; past columns stay filled. Counters show cumulative work with the cache on, cumulative work with the cache off, and the savings ratio. The ratio grows linearly with sequence length; at T = 1024 the cache buys roughly a 1000× speedup.]

Step the generator a few times with cache on. Each token’s K and V are computed once (yellow flash on first appearance) and stay filled. Now flip to cache off and step it. Every previous column flashes coral every step; that work is being paid for, repeatedly, at every new token. The cumulative work counters at the bottom diverge fast.
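The divergence of the two counters can be reproduced offline. The sketch below makes its own assumption about what counts as "work": one unit per query-key dot product, for one head in one layer. The widget may count differently, so the constants here are illustrative; the growth rates are the point.

```python
def dot_products(T: int, cache: bool) -> int:
    """Query-key dot products summed over T generation steps."""
    total = 0
    for t in range(1, T + 1):
        # cache on:  one new query scores against t keys
        # cache off: all t prefix queries are re-scored against t keys
        total += t if cache else t * t
    return total


for T in (256, 512, 1024):
    on = dot_products(T, cache=True)
    off = dot_products(T, cache=False)
    print(T, on, off, round(off / on, 1))
```

Under this convention, doubling `T` multiplies the cache-on total by about 4 and the cache-off total by about 8, so the savings ratio roughly doubles with it: linear in `T`, with a constant that depends on what is counted as work.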

Speedup at 1024 tokens

Generating 1024 tokens autoregressively, what is the order-of-magnitude speedup ratio of cache-on vs cache-off? Round to the nearest power of 2.

What you trade for the speedup

The cache stores $2 \cdot L \cdot T \cdot d$ floats per sequence (keys and values, every layer, every position, every model dimension). For a 32-layer model with $d = 4096$ at $T = 4096$ in bf16, that’s roughly:

$$2 \cdot 32 \cdot 4096 \cdot 4096 \cdot 2 \,\text{bytes} \;\approx\; 2.1 \text{ GB}$$

Per request. That number is real money in a serving system: it limits how many concurrent users one GPU can hold in memory, and it grows linearly with both TT and model depth.
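To see how that per-request figure limits concurrency, here is a small sketch. The helper name `kv_cache_bytes` and the 40 GB memory budget are hypothetical, picked only to make the arithmetic concrete; a real serving system has its own budget after weights and activations.

```python
def kv_cache_bytes(layers: int, T: int, d: int, bytes_per_float: int = 2) -> int:
    """Keys and values (factor 2), for every layer, position, and model dimension."""
    return 2 * layers * T * d * bytes_per_float


per_request = kv_cache_bytes(layers=32, T=4096, d=4096)
print(per_request)                # 2147483648, about 2.1 GB

budget = 40 * 10**9               # hypothetical memory left for KV caches
print(budget // per_request)      # 18 full-length requests fit

# Linear in T: twice the context, twice the cache.
print(kv_cache_bytes(32, 8192, 4096) // per_request)   # 2
```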

So the trade is memory for compute. You commit gigabytes of KV-cache memory to avoid recomputing the K and V projections for every past token, $O(T \cdot d^2)$ FLOPs per layer, at every step. Every production inference engine (vLLM, TensorRT-LLM, your favorite chat product) makes this trade. There is no plausible alternative for streaming generation at scale.

Where you are

You have built every piece of the attention layer, from first principles, in six lessons:

  • Soft dictionary lookup: the operation, in scalar form, on one query.
  • Why √dₖ: the variance-fix that keeps softmax from saturating at scale.
  • Three projections of the same X: Q, K, V; matrix form; causal mask.
  • Position, three ways: sinusoidal, learned absolute, RoPE.
  • Multi-head, parallel subspaces: same budget, structural diversity.
  • The cost and the cache: T² is the bill, KV-cache is how you pay it.

What ships in M16 is the transformer block: take this attention layer, wrap it in a residual connection, normalize at the right point, follow it with a position-wise MLP, and stack the result $N$ times. None of those four pieces is conceptually new; you’ve already built every one of them in earlier modules. The work in M16 is composition, not discovery.

You are one layer of glue away from a real transformer.
