Attention · 12 min

Why divide by √dₖ

At high dimension, the dot product q·k has a large variance. Without correction, softmax saturates to one-hot and gradients to all but the argmax die. Divide by √dₖ and the score variance returns to 1, and the model can keep learning.

The harmless-looking constant

The boxed equation from the last lesson had a $\sqrt{d_k}$ in the denominator that we waved off. Today’s lesson is about why it’s there.

When you read transformer code for the first time, $\sqrt{d_k}$ looks like the kind of magic constant you skim past: yet another normalization fiddle, like an Adam $\epsilon$ or a layer-norm $\gamma$. It is not. Without it, attention silently breaks at any reasonable model dimension, and the failure mode is invisible until you go looking for it.

The argument is one paragraph of probability theory and one slider on a widget.

The variance of a dot product

Suppose the components of $q$ and $k$ are independent, with mean 0 and variance 1. (At initialization this is approximately true for any sensibly-initialized linear projection, Xavier or He from m13.) Then:

$$\mathbb{E}[q\cdot k] \;=\; \sum_{i=1}^{d_k} \mathbb{E}[q_i]\,\mathbb{E}[k_i] \;=\; 0$$

The expectation is zero. So far so quiet. But the variance is not:

$$\mathrm{Var}(q\cdot k) \;=\; \sum_{i=1}^{d_k} \mathrm{Var}(q_i k_i) \;=\; \sum_{i=1}^{d_k} 1 \;=\; d_k$$
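The middle step, $\mathrm{Var}(q_i k_i) = 1$, is stated without proof. A short derivation, assuming as before that $q_i$ and $k_i$ are independent with mean 0 and variance 1:

```latex
\begin{aligned}
\mathrm{Var}(q_i k_i)
  &= \mathbb{E}[q_i^2 k_i^2] - \big(\mathbb{E}[q_i k_i]\big)^2 \\
  &= \mathbb{E}[q_i^2]\,\mathbb{E}[k_i^2] - \big(\mathbb{E}[q_i]\,\mathbb{E}[k_i]\big)^2 \\
  &= 1 \cdot 1 - 0 \;=\; 1
\end{aligned}
```

The variance of the sum equals the sum of the variances because, under the same independence assumption, the $d_k$ products $q_i k_i$ are independent of one another.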

The standard deviation grows like $\sqrt{d_k}$. At $d_k = 2$ it’s about 1.4. At $d_k = 64$ it’s 8. At $d_k = 512$ (the full model width of the original Transformer, whose per-head $d_k$ is 64) it’s roughly 22.6.
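If you would rather check the variance claim than trust the algebra, a quick simulation does it. The sketch below is an illustration only: it uses Python's standard library, draws every component from $\mathcal{N}(0, 1)$, and the function name `unscaled_score_stats` and its trial count are made up for this example.

```python
import math
import random


def unscaled_score_stats(d_k, trials=2000, seed=0):
    """Estimate the mean and variance of q·k for i.i.d. N(0, 1) components."""
    rng = random.Random(seed)
    scores = []
    for _ in range(trials):
        q = [rng.gauss(0.0, 1.0) for _ in range(d_k)]
        k = [rng.gauss(0.0, 1.0) for _ in range(d_k)]
        scores.append(sum(qi * ki for qi, ki in zip(q, k)))
    mean = sum(scores) / trials
    # Unbiased sample variance of the simulated scores.
    var = sum((s - mean) ** 2 for s in scores) / (trials - 1)
    return mean, var


for d_k in (2, 64, 512):
    mean, var = unscaled_score_stats(d_k)
    print(f"d_k={d_k:4d}  mean={mean:+.2f}  var={var:8.1f}  "
          f"std={math.sqrt(var):5.2f}  sqrt(d_k)={math.sqrt(d_k):5.2f}")
```

The estimated variance should land near $d_k$ and the standard deviation near $\sqrt{d_k}$, up to sampling noise.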

Now feed numbers with a standard deviation of about 22 into a softmax.

Stdev of an unscaled score

At $d_k = 64$ with components $\sim \mathcal{N}(0, 1)$, what is the standard deviation of an unscaled score $q\cdot k$?

Watch the softmax die

Softmax of a vector with very large magnitude collapses to a one-hot: the argmax position approaches 1, everything else approaches 0. Once that happens, the gradient through the softmax with respect to all non-argmax positions is essentially zero. Those keys can no longer learn from the loss. They’re frozen.
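The claim about dead gradients can be checked directly from the softmax Jacobian, $\partial \alpha_i / \partial s_j = \alpha_i(\delta_{ij} - \alpha_j)$, where $s_j$ is the score for key $j$. A minimal sketch, assuming made-up scores that are stretched by a factor to mimic the larger spread at higher $d_k$ (the helper names are hypothetical):

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def softmax_jacobian(alpha):
    """J[i][j] = d alpha_i / d score_j = alpha_i * (delta_ij - alpha_j)."""
    n = len(alpha)
    return [[alpha[i] * ((1.0 if i == j else 0.0) - alpha[j]) for j in range(n)]
            for i in range(n)]


def largest_gradient_entry(scores):
    """Largest |d alpha_i / d score_j| over the whole Jacobian."""
    jac = softmax_jacobian(softmax(scores))
    return max(abs(x) for row in jac for x in row)


base = [1.0, 0.5, 0.0, -0.5]  # made-up scores with a modest spread
for factor in (1, 8, 22):     # stretch factors; 8 and 22 echo the stdevs quoted earlier
    scores = [factor * s for s in base]
    alpha = softmax(scores)
    print(f"x{factor:<3d} max alpha={max(alpha):.4f}  "
          f"largest Jacobian entry={largest_gradient_entry(scores):.2e}")
```

As the stretch factor grows, the largest weight approaches 1 and every entry of the Jacobian shrinks toward zero, including the argmax's own entry $\alpha(1 - \alpha)$.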

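[Interactive widget snapshot: scores q · kⱼ / √dₖ and softmax weights αⱼ for eight keys]

k1: score -0.11, weight α1 = 7.0%
k2: score +0.29, weight α2 = 10.4%
k3: score +0.85, weight α3 = 18.2%
k4: score +1.01, weight α4 = 21.5%
k5: score -0.30, weight α5 = 5.8%
k6: score -1.17, weight α6 = 2.4%
k7: score +0.32, weight α7 = 10.7%
k8: score +1.12, weight α8 = 23.9%

Readouts: max α 24%, entropy 1.90 / 2.08, gradient survival healthy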

Start with $d_k = 8$ and the scale off. The softmax is reasonably diffuse; most positions are getting a nonzero weight, the entropy is healthy. Now drag $d_k$ up. At $d_k = 64$ the distribution starts pulling toward one position. By $d_k = 256$ it’s almost a one-hot. The “gradient survival” indicator drops from healthy through fading to dying and finally dead. Resample to convince yourself it isn’t a fluke.

That is what a real transformer’s attention layer does at typical model dimensions, if you forget the scale. Most of the gradient signal disappears, and only one key per query trains. The model learns nothing from the rest of the sequence.
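If the widget is not at hand, its readouts can be approximated offline. The sketch below is an assumption about what the widget computes, not its actual source: one random query against eight random keys with $\mathcal{N}(0, 1)$ components, reporting the average max weight and entropy (in nats) with the scale off and on. The name `widget_readout` is invented for this example.

```python
import math
import random


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def entropy(weights):
    """Shannon entropy in nats; terms with zero weight contribute 0."""
    return -sum(w * math.log(w) for w in weights if w > 0.0)


def widget_readout(d_k, scale, n_keys=8, resamples=100, seed=0):
    """Average max weight and entropy of softmax over q·k_j, across random draws."""
    rng = random.Random(seed)
    max_sum = ent_sum = 0.0
    for _ in range(resamples):
        q = [rng.gauss(0.0, 1.0) for _ in range(d_k)]
        scores = []
        for _ in range(n_keys):
            k = [rng.gauss(0.0, 1.0) for _ in range(d_k)]
            s = sum(qi * ki for qi, ki in zip(q, k))
            scores.append(s / math.sqrt(d_k) if scale else s)
        alpha = softmax(scores)
        max_sum += max(alpha)
        ent_sum += entropy(alpha)
    return max_sum / resamples, ent_sum / resamples


for d_k in (8, 64, 256):
    for scale in (False, True):
        avg_max, avg_ent = widget_readout(d_k, scale)
        label = "scaled  " if scale else "unscaled"
        print(f"d_k={d_k:3d} {label} max alpha={avg_max:.2f} "
              f"entropy={avg_ent:.2f} / {math.log(8):.2f}")
```

With the scale off, the average entropy should fall and the average max weight should rise as $d_k$ grows; with the scale on, both should stay roughly flat.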

The patch

Divide the scores by $\sqrt{d_k}$ before the softmax. That is the entirety of the fix:

$$\alpha_j \;=\; \mathrm{softmax}_j\!\left(\frac{q\cdot k_j}{\sqrt{d_k}}\right)$$
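In code, the fix is one division. A minimal sketch for a single query, using plain Python lists; `attention_weights` and the toy vectors are illustrative names and values, not taken from any library:

```python
import math


def attention_weights(q, keys):
    """alpha_j = softmax_j(q·k_j / sqrt(d_k)) for one query and a list of keys."""
    d_k = len(q)
    scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d_k) for k in keys]
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


q = [1.0, 0.0, -1.0, 2.0]          # toy query, d_k = 4
keys = [
    [1.0, 0.0, -1.0, 2.0],         # aligned with q
    [0.0, 1.0, 0.0, 0.0],          # orthogonal to q
    [-1.0, 0.0, 1.0, -2.0],        # opposite to q
]
print([round(a, 3) for a in attention_weights(q, keys)])
```

The weights sum to 1, and the key aligned with the query receives the largest share.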

What the scale does, at the level of the variance argument:

$$\mathrm{Var}\!\left(\tfrac{q\cdot k}{\sqrt{d_k}}\right) \;=\; \frac{\mathrm{Var}(q\cdot k)}{d_k} \;=\; 1$$

The variance is now $1$ regardless of $d_k$. The softmax sees roughly the same magnitude of input at every model size. Saturation does not happen at scale; entropy does not collapse; gradients survive on every key.

Flip the toggle in the widget to ÷ √dₖ and ramp $d_k$ all the way up. The softmax stays diffuse. The argmax-favoring concentration that was killing gradients disappears.

Variance after the scale

With the $1/\sqrt{d_k}$ scale, what is the variance of the score $q\cdot k / \sqrt{d_k}$ (assuming unit-variance components)?

Why this scale, of all the scales

A working scale needs to fix the variance to a constant. Several choices accomplish that: dividing by the magnitude of each key, adding a learned per-layer temperature, or replacing the dot product with an additive scoring function $\mathbf{w}^\top\tanh(W_q q + W_k k)$ (Bahdanau-style). All of them have been tried.

Scaled dot-product attention won for two reasons. First, it is cheap: a single multiply-add per element, vs. an MLP per pair for additive attention. Second, the scaling factor $\sqrt{d_k}$ is static: no learned parameters, no per-batch statistics, no eval-vs-train asymmetry. You set it once at architecture time and it is correct forever.
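To make the cost comparison concrete, here is a rough side-by-side of the two scoring functions for a single query-key pair. The matrices `W_q`, `W_k` and the vector `w` are hypothetical placeholder weights with an arbitrary hidden size, chosen only so the example runs:

```python
import math


def dot_score(q, k):
    """Scaled dot-product score: one multiply-add per component, then one division."""
    return sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(len(q))


def additive_score(q, k, W_q, W_k, w):
    """Additive (Bahdanau-style) score: w^T tanh(W_q q + W_k k)."""
    hidden = []
    for row_q, row_k in zip(W_q, W_k):
        pre = sum(a * qi for a, qi in zip(row_q, q)) + sum(b * ki for b, ki in zip(row_k, k))
        hidden.append(math.tanh(pre))
    return sum(wh * h for wh, h in zip(w, hidden))


q = [1.0, 2.0]
k = [0.5, -1.0]
W_q = [[0.1, 0.2], [0.3, -0.1]]   # placeholder weights, hidden size 2
W_k = [[0.0, 0.5], [-0.2, 0.1]]   # placeholder weights, hidden size 2
w = [1.0, -1.0]                   # placeholder output vector

print("dot-product score:", round(dot_score(q, k), 4))
print("additive score:   ", round(additive_score(q, k, W_q, W_k, w), 4))
```

The dot-product score touches each component once; the additive score runs two matrix-vector products and a `tanh` for every query-key pair.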

The whole derivation rests on one assumption: that the components of $q$ and $k$ are roughly unit-variance. If your initialization or layer norm strays from that, the scale stops being correct. This is one of the quiet reasons that initialization and normalization (m13) and attention (this module) are pieces of the same puzzle: they are jointly responsible for keeping pre-softmax magnitudes well-behaved at every depth.
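To see how the scale goes wrong, rerun the variance argument with zero-mean, independent components whose variances are $\sigma_q^2$ and $\sigma_k^2$ rather than 1 (these two symbols are introduced here for illustration):

```latex
\mathrm{Var}\!\left(\frac{q\cdot k}{\sqrt{d_k}}\right)
  \;=\; \frac{1}{d_k}\sum_{i=1}^{d_k} \mathbb{E}[q_i^2]\,\mathbb{E}[k_i^2]
  \;=\; \frac{d_k\,\sigma_q^2\,\sigma_k^2}{d_k}
  \;=\; \sigma_q^2\,\sigma_k^2
```

With $\sigma_q = \sigma_k = 1$ this recovers the variance of 1 from the scaled case. With $\sigma_q = \sigma_k = 2$ the scaled score has variance 16 and standard deviation 4, so the softmax input is four times wider than intended at every $d_k$.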

Next lesson: now that the scaled dot product is locked in, we widen the operation. One query against $T$ keys becomes $T$ queries against $T$ keys: the matrix form, computed in parallel, with three linear projections of the same input.
