The harmless-looking constant
The boxed equation from the last lesson had a $\sqrt{d_k}$ in the denominator that we waved off. Today’s lesson is about why it’s there.
When you read transformer code for the first time, $\sqrt{d_k}$ looks like the kind of magic constant you skim past: yet another normalization fiddle, like an Adam $\epsilon$ or a layer-norm $\epsilon$. It is not. Without it, attention silently breaks at any reasonable model dimension, and the failure mode is invisible until you go looking for it.
The argument is one paragraph of probability theory and one slider on a widget.
The variance of a dot product
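Suppose the components of $q$ and $k$ are independent, with mean 0 and variance 1. (At initialization this is approximately true for any sensibly-initialized linear projection, Xavier or He from m13.) Then:
$$\mathbb{E}[q \cdot k] = \sum_{i=1}^{d_k} \mathbb{E}[q_i]\,\mathbb{E}[k_i] = 0$$
The expectation is zero. So far so quiet. But the variance is not. Each product $q_i k_i$ has mean 0 and variance 1, and the $d_k$ products are independent, so their variances add:
$$\mathrm{Var}(q \cdot k) = \sum_{i=1}^{d_k} \mathrm{Var}(q_i k_i) = \sum_{i=1}^{d_k} \mathbb{E}[q_i^2]\,\mathbb{E}[k_i^2] = d_k$$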
The standard deviation grows like $\sqrt{d_k}$. At $d_k = 2$ it’s about 1.4. At $d_k = 64$ it’s 8. At $d_k = 512$ (typical for a real transformer) it’s roughly 22.
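The growth is easy to check numerically. The sketch below is a minimal Monte Carlo estimate, assuming independent standard-normal components for both vectors; the function name `dot_product_std` and the sample count are illustrative choices, not part of the lesson's widget.

```python
import math
import random
import statistics

def dot_product_std(d_k, n_samples=2000, seed=0):
    """Estimate the standard deviation of q . k when every component
    of q and k is drawn independently from N(0, 1) (an assumption)."""
    rng = random.Random(seed)
    scores = []
    for _ in range(n_samples):
        q = [rng.gauss(0.0, 1.0) for _ in range(d_k)]
        k = [rng.gauss(0.0, 1.0) for _ in range(d_k)]
        scores.append(sum(qi * ki for qi, ki in zip(q, k)))
    return statistics.pstdev(scores)

# Compare the measured spread with the sqrt(d_k) prediction.
for d_k in (2, 64, 512):
    measured = dot_product_std(d_k)
    print(f"d_k={d_k:4d}  measured std={measured:6.2f}  sqrt(d_k)={math.sqrt(d_k):6.2f}")
```

The measured values should track $\sqrt{d_k}$ up to sampling noise; more samples tighten the match.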
Now feed numbers with standard deviation 22 into a softmax.
Stdev of an unscaled score
With components of $q$ and $k$ drawn with mean 0 and variance 1, what is the standard deviation of an unscaled score $q \cdot k$ as a function of $d_k$?
Watch the softmax die
Softmax of a vector with very large magnitude collapses to a one-hot: the argmax position approaches 1, everything else approaches 0. Once that happens, the gradient through the softmax with respect to all non-argmax positions is essentially zero. Those keys can no longer learn from the loss. They’re frozen.
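To see the gradient vanish, it helps to look at the softmax Jacobian directly: for output $p$, the entry $\partial p_i / \partial z_j$ equals $p_i(\delta_{ij} - p_j)$, which goes to zero everywhere once $p$ is close to a one-hot. The sketch below uses a small hand-picked score vector (illustrative, not taken from the lesson) and multiplies it by growing factors to mimic growing score magnitude.

```python
import math

def softmax(z):
    """Numerically stable softmax of a list of floats."""
    m = max(z)
    exps = [math.exp(v - m) for v in z]
    total = sum(exps)
    return [e / total for e in exps]

def softmax_jacobian(p):
    """Jacobian of softmax at output p: J[i][j] = p_i * (delta_ij - p_j)."""
    n = len(p)
    return [[p[i] * ((1.0 if i == j else 0.0) - p[j]) for j in range(n)]
            for i in range(n)]

def largest_gradient_entry(z):
    """Largest absolute Jacobian entry: a rough proxy for how much
    gradient can still flow through the softmax at logits z."""
    jac = softmax_jacobian(softmax(z))
    return max(abs(v) for row in jac for v in row)

base = [0.5, -0.3, 1.0, -1.2]  # illustrative scores, roughly unit scale
for scale in (1, 8, 22):
    z = [scale * v for v in base]
    weights = [round(x, 4) for x in softmax(z)]
    print(f"scale={scale:2d}  weights={weights}  max |dp/dz|={largest_gradient_entry(z):.2e}")
```

As the scale grows the weights approach a one-hot and every Jacobian entry shrinks with them, including the entries for the winning position.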
Start with $d_k$ small and the scale off. The softmax is reasonably diffuse; most positions are getting a nonzero weight, the entropy is healthy. Now drag $d_k$ up. Partway along, the distribution starts pulling toward one position. Near the top of the range it’s almost a one-hot. The “gradient survival” indicator drops from healthy through fading to dying and finally dead. Resample to convince yourself it isn’t a fluke.
That is what a real transformer’s attention layer does at typical model dimensions, if you forget the scale. Most of the gradient signal disappears, and only one key per query trains. The model learns nothing from the rest of the sequence.
The patch
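Divide the scores by $\sqrt{d_k}$ before the softmax. That is the entirety of the fix:
$$\mathrm{score}(q, k) = \frac{q \cdot k}{\sqrt{d_k}}$$
What the scale does, at the level of the variance argument:
$$\mathrm{Var}\!\left(\frac{q \cdot k}{\sqrt{d_k}}\right) = \frac{\mathrm{Var}(q \cdot k)}{d_k} = \frac{d_k}{d_k} = 1$$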
The variance is now 1 regardless of $d_k$. The softmax sees roughly the same magnitude of input at every model size. Saturation does not happen at scale; entropy does not collapse; gradients survive on every key.
Flip the toggle in the widget to ÷ √dₖ and ramp $d_k$ all the way up. The softmax stays diffuse. The argmax-favoring concentration that was killing gradients disappears.
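The same experiment can be approximated offline. This is a rough stand-in for the widget, not its actual implementation: it draws one random query and a handful of random keys with standard-normal components (an assumption), computes the attention weights with and without the $\sqrt{d_k}$ division, and averages the entropy of the weights over several trials. The names `mean_attention_entropy`, `n_keys`, and `n_trials` are hypothetical.

```python
import math
import random

def softmax(z):
    """Numerically stable softmax of a list of floats."""
    m = max(z)
    exps = [math.exp(v - m) for v in z]
    total = sum(exps)
    return [e / total for e in exps]

def entropy(p):
    """Shannon entropy in nats; zero-probability entries contribute nothing."""
    return -sum(x * math.log(x) for x in p if x > 0.0)

def mean_attention_entropy(d_k, scaled, n_keys=16, n_trials=50, seed=0):
    """Average entropy of softmax attention weights for one random query
    against n_keys random keys, all components drawn from N(0, 1)."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n_trials):
        q = [rng.gauss(0.0, 1.0) for _ in range(d_k)]
        scores = []
        for _ in range(n_keys):
            k = [rng.gauss(0.0, 1.0) for _ in range(d_k)]
            s = sum(a * b for a, b in zip(q, k))
            scores.append(s / math.sqrt(d_k) if scaled else s)
        total += entropy(softmax(scores))
    return total / n_trials

print(f"maximum possible entropy for 16 keys: {math.log(16):.2f} nats")
for d_k in (4, 64, 512):
    raw = mean_attention_entropy(d_k, scaled=False)
    fixed = mean_attention_entropy(d_k, scaled=True)
    print(f"d_k={d_k:4d}  unscaled entropy={raw:.2f}  scaled entropy={fixed:.2f}")
```

Without the scale the entropy should fall toward zero as $d_k$ grows; with it, the entropy should stay in roughly the same range at every $d_k$.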
Variance after the scale
With the scale, what is the variance of the score $q \cdot k / \sqrt{d_k}$ (assuming unit-variance components)?
Why this scale, of all the scales
A working scale needs to fix the variance to a constant. Several choices accomplish that: dividing by the magnitude of each key, adding a learned per-layer temperature, or replacing the dot product with an additive scoring function (Bahdanau-style). All of them have been tried.
Scaled dot-product attention won for two reasons. First, it is cheap: a single multiply-add per element, vs. an MLP per pair for additive attention. Second, the scaling factor is static: no learned parameters, no per-batch statistics, no eval-vs-train asymmetry. You set it once at architecture time and it is correct forever.
The whole derivation rests on one assumption: that the components of $q$ and $k$ are roughly unit-variance. If your initialization or layer norm strays from that, the scale stops being correct. This is one of the quiet reasons that initialization and normalization (m13) and attention (this module) are pieces of the same puzzle: they are jointly responsible for keeping pre-softmax magnitudes well-behaved at every depth.
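How badly the scale misses when the assumption fails can be worked out from the same variance argument: if the components of $q$ have standard deviation $\sigma_q$ and those of $k$ have $\sigma_k$ (still independent and zero-mean), each product has variance $\sigma_q^2 \sigma_k^2$, so the scaled score has variance $\sigma_q^2 \sigma_k^2$ instead of 1. The sketch below checks the case $\sigma_q = \sigma_k = \sigma$, where the prediction is $\sigma^4$; the function name and parameters are illustrative.

```python
import math
import random
import statistics

def scaled_score_variance(component_std, d_k=64, n_samples=2000, seed=0):
    """Variance of (q . k) / sqrt(d_k) when every component of q and k
    is drawn independently from N(0, component_std**2)."""
    rng = random.Random(seed)
    scores = []
    for _ in range(n_samples):
        q = [rng.gauss(0.0, component_std) for _ in range(d_k)]
        k = [rng.gauss(0.0, component_std) for _ in range(d_k)]
        dot = sum(a * b for a, b in zip(q, k))
        scores.append(dot / math.sqrt(d_k))
    return statistics.pvariance(scores)

# Doubling the component std multiplies the score variance by about 16.
for std in (0.5, 1.0, 2.0):
    measured = scaled_score_variance(std)
    print(f"component std={std:3.1f}  measured variance={measured:7.3f}  predicted std**4={std ** 4:7.3f}")
```

The fourth-power dependence is why modest drift in component scale can push a nominally scaled softmax back toward saturation.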
Next lesson: now that the scaled dot product is locked in, we widen the operation. One query against all the keys becomes every query against all the keys: the matrix form, computed in parallel, with three linear projections of the same input.