Where the FLOPs actually go
The attention operation has three matmuls. Two of them (projecting X into Q, K, V at the start, and the concat-then-W_O projection at the end) cost O(T·d²) for sequence length T and model width d, the same as any linear layer. They are not the bottleneck.
The middle matmul (QKᵀ) produces a T × T matrix. Its cost is O(T²·d). So is the cost of multiplying that attention matrix by V. Both of them scale quadratically in the sequence length, regardless of model size.
That T² term is the reason context windows are expensive. Doubling the input length quadruples attention compute. Going from 2,048 tokens to 32,768 tokens is a 16× increase in length, so a 256× cost increase on this part of the model, at fixed model size. This is the central economic fact about transformers, and it is why “long context” is a research-frontier problem and not a free upgrade.
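The scaling claim above can be checked with a few lines of arithmetic. This is a minimal sketch that counts only the two quadratic matmuls of one head; the function name `attention_matmul_flops` and the width `d = 64` are illustrative assumptions, not values from this lesson.

```python
def attention_matmul_flops(T: int, d: int) -> int:
    """Approximate FLOPs for the two T x T matmuls of one attention head.

    The scores (Q @ K.T) and the weighted read (A @ V) each take about
    2 * T * T * d floating-point operations (one multiply and one add per term).
    """
    scores = 2 * T * T * d   # (T, d) @ (d, T) -> (T, T)
    read = 2 * T * T * d     # (T, T) @ (T, d) -> (T, d)
    return scores + read


d = 64  # hypothetical head width, chosen only for illustration

doubling_ratio = attention_matmul_flops(4096, d) / attention_matmul_flops(2048, d)
long_context_ratio = attention_matmul_flops(32768, d) / attention_matmul_flops(2048, d)

print(doubling_ratio)      # 4.0
print(long_context_ratio)  # 256.0
```

The width `d` cancels in both ratios, which is the point: the growth factor depends only on the sequence length.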
The quadratic
By what factor does attention compute grow when you double the sequence length T?
Memory matches compute
The attention matrix itself is T × T: that’s T² floats, per head, per layer. With H heads, L layers, and 16-bit floats (2 bytes each), the total is roughly 2·T²·H·L bytes.
At long context lengths, that total does not fit on a single H100 alongside the model weights and the KV cache. This is why long-context inference uses tricks like FlashAttention (which never materializes the full T × T matrix in GPU main memory) or sliding-window attention (which sparsifies it). They are direct, named responses to the memory part of the same bottleneck.
For our purposes (a single attention layer at modest T) we don’t need those tricks yet. But the reason they exist sits in the line above: attention’s memory cost grows the same way its FLOP cost does, both quadratic in T, both ugly.
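The per-head, per-layer T² count can be turned into bytes directly. The sketch below assumes every attention matrix is held in memory at once, as the count above does; the helper name `attention_matrix_bytes` and the example configuration (32 heads, 32 layers) are hypothetical and are not the figures this lesson uses.

```python
def attention_matrix_bytes(T: int, n_heads: int, n_layers: int,
                           bytes_per_float: int = 2) -> int:
    """Bytes needed to hold every T x T attention matrix at once."""
    return T * T * n_heads * n_layers * bytes_per_float


GIB = 1024 ** 3

# Hypothetical model shape: 32 heads and 32 layers, 16-bit floats.
for T in (2048, 8192, 32768):
    total = attention_matrix_bytes(T, n_heads=32, n_layers=32)
    print(f"T={T:>6}: {total / GIB:,.0f} GiB")
```

Under these assumed settings the total goes from 8 GiB at T = 2,048 to 2,048 GiB at T = 32,768: a 16× longer input costs 256× the memory.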
Inference has a structural shortcut
At inference time, a transformer generates one token at a time. Step 1 produces token 1 from the prompt; step 2 takes the prompt + token 1 and produces token 2; step 3 takes prompt + tokens 1–2 and produces token 3. And so on.
A naive implementation would re-run the entire prefix through the model at every step, recomputing Q, K, V for all t tokens at step t, even though those rows haven’t changed since the last step. The total cost across T steps would be O(T³), cubic. At T ≈ 1,000 that’s already a billion-FLOP wall per layer, per head, for a single generated sequence.
Here is the structural observation that fixes it: at step t, only the new token issues a query. All past tokens have already been “spoken to” and never query again. So we never need their queries. We do, however, need their keys and values; every new query has to score against all past keys and read against all past values. But K and V for past tokens don’t change once those tokens have been emitted.
So cache them. Compute K and V for each token once, when it’s emitted, and stash both in memory. On every subsequent step, reuse the cached rows instead of recomputing them.
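The caching idea can be sketched for a single head in NumPy. This is a minimal illustration, not a production implementation: the class name `CachedAttentionHead`, the helper `full_causal_attention`, and the tiny random shapes are all assumptions made for the example. It checks that decoding one token at a time with a cache gives the same outputs as recomputing full causal attention over the whole sequence.

```python
import numpy as np


def softmax(x: np.ndarray) -> np.ndarray:
    x = x - x.max(axis=-1, keepdims=True)  # subtract the max for stability
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)


class CachedAttentionHead:
    """One attention head that decodes a token at a time with a KV cache."""

    def __init__(self, W_q, W_k, W_v):
        self.W_q, self.W_k, self.W_v = W_q, W_k, W_v
        self.K = np.zeros((0, W_k.shape[1]))  # cached keys, one row per token
        self.V = np.zeros((0, W_v.shape[1]))  # cached values, one row per token

    def step(self, x_new: np.ndarray) -> np.ndarray:
        """x_new: (d_model,) embedding of the newest token only."""
        q = x_new @ self.W_q                             # only the new token queries
        self.K = np.vstack([self.K, x_new @ self.W_k])   # its key, computed once
        self.V = np.vstack([self.V, x_new @ self.W_v])   # its value, computed once
        scores = self.K @ q / np.sqrt(q.shape[0])        # score against all keys
        return softmax(scores) @ self.V                  # read from all values


def full_causal_attention(X, W_q, W_k, W_v):
    """Reference: recompute Q, K, V for the whole sequence with a causal mask."""
    Q, K, V = X @ W_q, X @ W_k, X @ W_v
    scores = Q @ K.T / np.sqrt(Q.shape[1])
    mask = np.triu(np.ones_like(scores, dtype=bool), k=1)  # True above the diagonal
    scores = np.where(mask, -np.inf, scores)
    return softmax(scores) @ V


rng = np.random.default_rng(0)
T, d_model, d_k = 6, 8, 4  # tiny illustrative sizes
X = rng.normal(size=(T, d_model))
W_q, W_k, W_v = (rng.normal(size=(d_model, d_k)) for _ in range(3))

head = CachedAttentionHead(W_q, W_k, W_v)
cached_out = np.stack([head.step(x) for x in X])
reference_out = full_causal_attention(X, W_q, W_k, W_v)

print(np.allclose(cached_out, reference_out))  # True
```

Growing the cache with `np.vstack` copies it on every step, which keeps the sketch short; a real engine would more likely preallocate the buffers and write into the next free row.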
Step the generator a few times with cache on. Each token’s K and V are computed once (yellow flash on first appearance) and stay filled. Now flip to cache off and step it. Every previous column flashes coral every step; that work is being paid for, repeatedly, at every new token. The cumulative work counters at the bottom diverge fast.
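The two cumulative work counters can be imitated in a few lines. This sketch counts only how many K/V rows get projected, and assumes generation starts from an empty prefix; the function name `kv_rows_computed` is made up for the example.

```python
def kv_rows_computed(n_steps: int, use_cache: bool) -> int:
    """Count how many K/V rows are projected over n_steps of generation."""
    total = 0
    for t in range(1, n_steps + 1):
        # With a cache, only the newest token is projected. Without one,
        # all t tokens in the current prefix are projected again.
        total += 1 if use_cache else t
    return total


for n in (4, 16, 64):
    on = kv_rows_computed(n, use_cache=True)
    off = kv_rows_computed(n, use_cache=False)
    print(f"steps={n:>3}  cache on: {on:>3}  cache off: {off:>5}  ratio: {off / on:.1f}")
```

With the cache on, the counter grows by one row per step; with it off, each step adds the full prefix length, so the gap between the two widens with every new token.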
Speedup at 1024 tokens
Generating 1024 tokens autoregressively, what is the order-of-magnitude speedup ratio of cache-on vs cache-off? Round to the nearest power of 2.
What you trade for the speedup
The cache stores 2·L·T·d floats per sequence in the inference batch (keys and values, every layer, every position, every model dimension). For a 32-layer model in bf16 (2 bytes per float), that’s 2·32·T·d·2 = 128·T·d bytes.
Per request. That number is real money in a serving system: it limits how many concurrent users one GPU can hold in memory, and it grows linearly with both and model depth.
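The per-request footprint is simple to estimate. The sketch below follows the count of keys and values at every layer, position, and model dimension; the helper name `kv_cache_bytes` and the width `d_model = 4096` are hypothetical, and it assumes keys and values are stored at full model width (no grouped-query or similar sharing).

```python
def kv_cache_bytes(n_layers: int, T: int, d_model: int,
                   bytes_per_float: int = 2) -> int:
    """Bytes for cached keys and values: 2 tensors x layers x positions x width."""
    return 2 * n_layers * T * d_model * bytes_per_float


GIB = 1024 ** 3

# Hypothetical configuration: 32 layers (as in the text), d_model = 4096, bf16.
for T in (2048, 4096, 32768):
    size = kv_cache_bytes(n_layers=32, T=T, d_model=4096)
    print(f"T={T:>6}: {size / GIB:.0f} GiB per request")
```

Under these assumptions the cache takes 1 GiB at T = 2,048 and 16 GiB at T = 32,768, growing linearly with T rather than quadratically.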
So the trade is memory for compute. You commit gigabytes of KV-cache memory to avoid recomputing the K and V projections for the whole prefix at every step. Every production inference engine (vLLM, TensorRT-LLM, your favorite chat product) does this trade. There is no plausible alternative for streaming generation at scale.
Where you are
You have built every piece of the attention layer, from first principles, in six lessons:
- Soft dictionary lookup: the operation, in scalar form, on one query.
- Why √dₖ: the variance-fix that keeps softmax from saturating at scale.
- Three projections of the same X: Q, K, V; matrix form; causal mask.
- Position, three ways: sinusoidal, learned absolute, RoPE.
- Multi-head, parallel subspaces: same budget, structural diversity.
- The cost and the cache: T² is the bill, KV-cache is how you pay it.
What ships in M16 is the transformer block: take this attention layer, wrap it in a residual connection, normalize at the right point, follow it with a position-wise MLP, and stack the result N times. None of those four pieces is conceptually new; you’ve already built every one of them in earlier modules. The work in M16 is composition, not discovery.
You are one layer of glue away from a real transformer.