Attention · 10 min

Multi-head: parallel subspaces, same budget

A single attention head averages; that's its job. Several heads, run in parallel on disjoint subspaces of the same model dimension, give the layer a way to attend in several different patterns at once. Same total parameters, same total FLOPs, structural diversity for free.


A single head only does one thing

The attention layer you have built so far computes one weighted average per query. That weighted average can emphasize different keys for different queries, but the kind of relationship it captures (syntactic agreement? coreference? long-range topic? next-token continuation?) is one fixed thing per layer.

Trained transformers, when probed, turn out to use several distinct attention patterns at once. One head specializes in attending to the previous token. Another tracks subject-verb agreement. Another keeps an eye on the start-of-sequence token. A single head can’t do all of these simultaneously because softmax is a single probability distribution per query.

The fix is to run several attention heads in parallel, each operating on a different subspace of the same input.

Reshape, don't widen

Take the model dimension $d$ and split it into $h$ equal chunks of width $d_k = d / h$. Each chunk is one head’s working subspace.

$$X \in \mathbb{R}^{T \times d} \;\longrightarrow\; h \;\text{heads}\;\times\; \mathbb{R}^{T \times d_k}$$
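In code, the split is only a reshape of the feature axis: nothing gets wider. Below is a minimal NumPy sketch that works on any $(T, d)$ matrix. It assumes each head takes a contiguous block of $d_k$ columns, which is the usual convention. In a full layer, the matrix being split would typically be a projected one, such as the queries computed from $X$, rather than $X$ itself. `split_heads` and `merge_heads` are illustrative names, not a library API.

```python
import numpy as np

def split_heads(m: np.ndarray, h: int) -> np.ndarray:
    """Reshape a (T, d) matrix into (h, T, d_k), with d_k = d // h.

    Head j receives feature columns j*d_k .. (j+1)*d_k - 1; the d axis is
    re-indexed, not widened.
    """
    T, d = m.shape
    assert d % h == 0, "h must divide d evenly"
    d_k = d // h
    # (T, d) -> (T, h, d_k) -> (h, T, d_k)
    return m.reshape(T, h, d_k).transpose(1, 0, 2)

def merge_heads(heads: np.ndarray) -> np.ndarray:
    """Inverse of split_heads: (h, T, d_k) -> (T, h * d_k)."""
    h, T, d_k = heads.shape
    return heads.transpose(1, 0, 2).reshape(T, h * d_k)

T, d, h = 5, 24, 2  # the widget's sizes
rng = np.random.default_rng(0)
m = rng.standard_normal((T, d))
heads = split_heads(m, h)
print(heads.shape)  # (2, 5, 12)
```

Because `merge_heads` undoes `split_heads` exactly, the later concatenation step is the same reshape run backwards.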

Each head gets its own $W_Q^{(j)}, W_K^{(j)}, W_V^{(j)}$, each of shape $(d, d_k)$. Per head, the operation is exactly the scaled dot-product attention from the last lesson, just operating in a smaller subspace.
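Here is one head in isolation. This is a minimal NumPy sketch, assuming no masking and randomly initialized weights. The weights are hypothetical, scaled by $1/\sqrt{d}$ so the softmax does not saturate. `attention_head` is an illustrative name. Each head has its own `(d, d_k)` projections and returns a `(T, d_k)` output plus a `(T, T)` weight matrix.

```python
import numpy as np

def softmax(z: np.ndarray) -> np.ndarray:
    # Row-wise softmax; subtracting the row max avoids overflow in exp.
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def attention_head(X, W_q, W_k, W_v):
    """One head of (unmasked) scaled dot-product attention.

    X: (T, d); W_q, W_k, W_v: (d, d_k).
    Returns the head's (T, d_k) output and its (T, T) attention weights.
    """
    Q, K, V = X @ W_q, X @ W_k, X @ W_v      # each (T, d_k)
    d_k = W_q.shape[1]
    A = softmax(Q @ K.T / np.sqrt(d_k))       # (T, T); each row sums to 1
    return A @ V, A

T, d, h = 5, 24, 2                           # the widget's sizes
d_k = d // h
rng = np.random.default_rng(0)
X = rng.standard_normal((T, d))
# Hypothetical random projections: one (W_q, W_k, W_v) triple per head.
W = [tuple(rng.standard_normal((d, d_k)) / np.sqrt(d) for _ in range(3))
     for _ in range(h)]
outs = [attention_head(X, *w) for w in W]
for j, (out, A) in enumerate(outs, start=1):
    print(f"head {j}: output {out.shape}, weights {A.shape}")
```

Both heads read the same `X` but project it through different weights. Their `(T, T)` patterns therefore come out different, which is the behavior the widget's per-head heatmaps illustrate.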

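[Widget: multi-head shape passport, shown at $h = 2$, $T = 5$, $d = 24$]
input $X \in \mathbb{R}^{5 \times 24}$ → split $d \to (h, d/h)$: $2 \times \mathbb{R}^{5 \times 12}$ → per-head attention: $2 \times \mathbb{R}^{5 \times 5}$ → concat & $W_O$: $\mathbb{R}^{5 \times 24}$
heads: head 1 $d_k = 12$, head 2 $d_k = 12$
$W_Q$ per head: $24 \times 12 = 288$; total $W_Q$ params: $576$; single-head reference: $24 \times 24 = 576$

Same total parameters and FLOPs at every $h$: structural diversity is free.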

Click through $h = 1, 2, 4, 8$ and watch the shape passport at the top of the widget update. The model dimension stays at 24; what changes is how many heads slice it. The mini-heatmap for each head shows that head’s attention pattern on the same input. They look different because each head sees a different subspace and attends accordingly.

Concat the heads, project them down

After each head computes its $(T, d_k)$ output, concatenate along the feature axis:

$$\mathrm{Concat}(\mathrm{head}_1, \ldots, \mathrm{head}_h) \in \mathbb{R}^{T \times d}$$

Then mix them through a learned output projection $W_O \in \mathbb{R}^{d \times d}$:

$$\mathrm{MHA}(X) = \mathrm{Concat}(\mathrm{head}_1, \ldots, \mathrm{head}_h)\, W_O$$

That’s the whole multi-head attention layer. Output shape: $(T, d)$, same as the input, ready to feed into a residual connection.
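Putting the pieces together gives a minimal, unmasked NumPy sketch of $\mathrm{MHA}(X)$. It assumes the common "fused" layout: one $(d, d)$ matrix each for $W_Q$, $W_K$, $W_V$, where head $j$'s $W_Q^{(j)}$ is column block $j$. The random weights and the name `multi_head_attention` are illustrative only. The second half cross-checks the fused version against an explicit per-head loop.

```python
import numpy as np

def softmax(z):
    # Softmax over the last axis, stabilized by subtracting the max.
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def multi_head_attention(X, W_q, W_k, W_v, W_o, h):
    """MHA(X) = Concat(head_1, ..., head_h) W_O, with no masking.

    X: (T, d). W_q, W_k, W_v, W_o: (d, d). Head j uses column block j of
    W_q, W_k, W_v as its own (d, d_k) projections, with d_k = d // h.
    """
    T, d = X.shape
    d_k = d // h
    def split(M):  # (T, d) -> (h, T, d_k)
        return M.reshape(T, h, d_k).transpose(1, 0, 2)
    Q, K, V = split(X @ W_q), split(X @ W_k), split(X @ W_v)
    A = softmax(Q @ K.transpose(0, 2, 1) / np.sqrt(d_k))  # (h, T, T)
    heads = A @ V                                        # (h, T, d_k)
    concat = heads.transpose(1, 0, 2).reshape(T, d)       # (T, d)
    return concat @ W_o                                  # (T, d)

T, d, h = 5, 24, 4
rng = np.random.default_rng(1)
X = rng.standard_normal((T, d))
W_q, W_k, W_v, W_o = (rng.standard_normal((d, d)) / np.sqrt(d) for _ in range(4))
out = multi_head_attention(X, W_q, W_k, W_v, W_o, h)
print(out.shape)  # (5, 24): same shape as X, so X + out is well defined

# Cross-check: compute each head separately from its column block, then concat.
d_k = d // h
per_head = []
for j in range(h):
    cols = slice(j * d_k, (j + 1) * d_k)
    Qj, Kj, Vj = X @ W_q[:, cols], X @ W_k[:, cols], X @ W_v[:, cols]
    per_head.append(softmax(Qj @ Kj.T / np.sqrt(d_k)) @ Vj)
loop_out = np.concatenate(per_head, axis=1) @ W_o
print(np.allclose(out, loop_out))  # True
```

The fused layout is a design choice: three large matrix products instead of $3h$ small ones. It computes the same thing as the per-head formulation.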

The role of $W_O$ is small but real. It is the only place inside the attention layer where information from different heads mixes. Without it, each head’s output would stay in its own block of $d_k$ feature columns, and the residual stream would receive $h$ separate blocks rather than a learned combination of them. Standard implementations keep $W_O$ square, $d \times d$, so the layer’s output width matches its input width and the residual addition lines up.
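The mixing can be made concrete. Split $W_O$ into $h$ row blocks of height $d_k$. Then $\mathrm{Concat}(\mathrm{head}_1, \ldots, \mathrm{head}_h)\, W_O = \sum_j \mathrm{head}_j\, W_O^{[j]}$, so every output column receives a term from every head. A minimal NumPy check is below, using random stand-ins for the per-head outputs; the values are hypothetical and only the algebra matters.

```python
import numpy as np

T, d, h = 5, 24, 4
d_k = d // h
rng = np.random.default_rng(2)
# Hypothetical stand-ins for the h per-head outputs, each (T, d_k).
heads = [rng.standard_normal((T, d_k)) for _ in range(h)]
W_o = rng.standard_normal((d, d)) / np.sqrt(d)

concat = np.concatenate(heads, axis=1)   # (T, d)
out = concat @ W_o                       # (T, d)

# Same product written as a sum over heads: head j only meets row block j of W_O.
out_sum = sum(heads[j] @ W_o[j * d_k:(j + 1) * d_k, :] for j in range(h))
print(np.allclose(out, out_sum))  # True

# Perturb head 1 alone: every output column moves, because W_O mixes the heads.
bumped = [heads[0] + 1.0] + heads[1:]
delta = np.concatenate(bumped, axis=1) @ W_o - out
print(np.all(np.abs(delta).sum(axis=0) > 0))  # True
```

Without $W_O$, that same perturbation would only change the first $d_k$ columns of the output.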

Per-head dimension

With $d_{\text{model}} = 512$ and $h = 8$ heads, what is $d_k$ per head?

Param count is conserved

What is the ratio of total $W_Q$ parameters in $h$-head attention to $W_Q$ parameters in single-head attention (with $d_k = d_{\text{model}}$)? Multiply across all heads.

Why this is free

Notice from the widget what doesn’t change as you crank $h$ up: the parameter count.

Each head’s $W_Q$ is $(d, d/h)$, so the total $W_Q$ across all heads is $h \cdot d \cdot (d/h) = d^2$, identical to a single head with $d_k = d$. The same holds for $W_K$ and $W_V$. The FLOPs for the attention computation also work out the same. Each head spends $T^2 \cdot d_k$ multiply-adds on $QK^\top$ and the same again on the weighted sum over $V$. Across $h$ heads that totals $2 \cdot h \cdot T^2 \cdot d_k = 2 \cdot T^2 \cdot d$, exactly what one head of width $d$ costs.
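The bookkeeping above fits in a few lines of plain Python. This sketch assumes no bias terms and that $h$ divides $d$. It counts only the two $T \times T$ products per head, not the projections; $W_O$ adds another $d^2$ parameters regardless of $h$. `mha_costs` is an illustrative helper name.

```python
def mha_costs(d: int, h: int, T: int) -> dict:
    """Count W_Q/W_K/W_V parameters and attention-score multiply-adds.

    Assumes h divides d, d_k = d // h, and no biases. The multiply-add count
    covers only Q K^T and A V per head, not the input projections.
    """
    assert d % h == 0
    d_k = d // h
    qkv_params = 3 * h * d * d_k        # h heads, three (d, d_k) matrices each
    attn_macs = h * 2 * T * T * d_k     # per head: Q K^T and A V, T*T*d_k each
    return {"d_k": d_k, "qkv_params": qkv_params, "attn_macs": attn_macs}

# Widget sizes: d = 24, T = 5. Only d_k changes as h grows.
for h in (1, 2, 4, 8):
    print(h, mha_costs(d=24, h=h, T=5))
# qkv_params stays 1728 and attn_macs stays 1200 for every h
```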

So multi-head doesn’t cost more parameters or compute than a single big head. What it gives you is several attention patterns operating in parallel. Empirically this matters: single-head transformers underperform multi-head transformers at the same total compute. The patterns that emerge across heads are more diverse than what a single softmax can capture, even one with the same total dimension to play with.

A small caveat: “free” here means free in parameters and FLOPs, not in per-head capacity. Push $h$ high enough and each head’s $d_k = d/h$ becomes too narrow to represent the relationships it needs; at $d = 512$, $h = 64$ leaves each head only 8 dimensions. In practice, production transformers tend to keep $d_k$ around 64 to 128 and scale $h$ with $d$ (for example, 12 heads at $d = 768$, or 32 heads at $d = 4096$). There is a sweet spot, and it isn’t at the extremes.
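The arithmetic behind the caveat is easy to tabulate. This sketch fixes $d = 512$ and shows how quickly $d_k$ shrinks as $h$ grows. `per_head_width` is a hypothetical helper, and the $(d, h)$ pairs are illustrative.

```python
def per_head_width(d: int, h: int) -> int:
    """d_k = d / h; the split only works when h divides d evenly."""
    if d % h != 0:
        raise ValueError(f"h={h} does not divide d={d}")
    return d // h

# At a fixed model width d = 512, d_k shrinks as h grows.
for h in (1, 8, 16, 32, 64):
    print(f"h={h:>2}  d_k={per_head_width(512, h)}")
# d_k: 512, 64, 32, 16, 8

# Scaling h with d instead keeps d_k in the 64-128 range.
print(per_head_width(768, 12), per_head_width(4096, 32))  # 64 128
```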

What's left

You now have the entire single-block attention operation: scaled dot-product, packed into matrix form, masked when needed, position-aware via PE, and run in parallel across $h$ heads. That is the core operation of a transformer.

The next and final lesson of this module covers how much this operation actually costs to run, and the one inference trick (the KV-cache) that makes streaming generation feasible at production sequence lengths.
