Good init isn't enough
You picked your weights with He scaling. Layer-1 activations have unit variance, layer-10 activations have unit variance. You start training. After a thousand steps, weights have moved. The careful variance-preservation argument from the previous lesson assumed initial weights. The trained ones are different.
In a deep network, even with perfect initialization, the activation distributions at each layer drift around during training. Some layers get systematically larger, some smaller, some skew. The next layer was happy with the old distribution; suddenly the inputs look different and it has to relearn how to use them. Training slows down or breaks.
The fix the field landed on, around 2015: at every layer, normalize the activations during training. Force them back to a known scale. This is what BatchNorm and LayerNorm do.
BatchNorm: the formula
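BatchNorm normalizes across the batch dimension, one feature at a time. For a minibatch of $m$ examples, with $x_i$ the value of one feature on the $i$-th example, compute the batch mean and variance:
$$\mu_B = \frac{1}{m}\sum_{i=1}^{m} x_i, \qquad \sigma_B^2 = \frac{1}{m}\sum_{i=1}^{m} (x_i - \mu_B)^2$$
Then normalize each example's value of that feature:
$$\hat{x}_i = \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}}$$
The $\epsilon$ is a tiny constant (PyTorch's default is $10^{-5}$) that prevents division by zero when the batch variance is zero. Now $\hat{x}_i$ has mean 0 and variance 1 across the batch (up to the small $\epsilon$ correction), by construction.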
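A minimal sketch of the statistics and normalization steps for one feature column, in plain Python so it runs without a framework. The function name `batchnorm_normalize` and the default `eps=1e-5` are illustrative assumptions; the variance divides by the batch size, not the batch size minus one.

```python
import math


def batchnorm_normalize(xs, eps=1e-5):
    """Normalize one feature across a minibatch; xs[i] is that feature on example i."""
    m = len(xs)
    mu = sum(xs) / m                              # batch mean
    var = sum((x - mu) ** 2 for x in xs) / m      # batch variance (divide by m, not m - 1)
    return [(x - mu) / math.sqrt(var + eps) for x in xs]


x_hat = batchnorm_normalize([2.0, 4.0, 6.0, 8.0])
mean = sum(x_hat) / len(x_hat)
var = sum((v - mean) ** 2 for v in x_hat) / len(x_hat)
print(mean, var)  # mean is ~0; var is just under 1 because of eps
```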
Finally (and this matters): BN learns a per-feature affine transform on top:
$$y_i = \gamma \hat{x}_i + \beta$$
$\gamma$ and $\beta$ are trainable parameters, one pair per feature. Why? So the network can undo the normalization if it turns out the original distribution was useful. With $\gamma = \sqrt{\sigma_B^2 + \epsilon}$ and $\beta = \mu_B$, you've recovered the input exactly. Normalization therefore never locks the network into unit scale; it just gives the network the option to choose a different one.
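To check the "undo" claim numerically, here is a sketch (plain Python, hypothetical helper name `batchnorm_train`) that applies the affine step and then sets $\gamma = \sqrt{\sigma_B^2 + \epsilon}$ and $\beta = \mu_B$. The output matches the input up to floating-point rounding.

```python
import math


def batchnorm_train(xs, gamma, beta, eps=1e-5):
    """Training-mode BN for one feature: batch stats, normalize, then the learned affine."""
    m = len(xs)
    mu = sum(xs) / m
    var = sum((x - mu) ** 2 for x in xs) / m
    ys = [gamma * (x - mu) / math.sqrt(var + eps) + beta for x in xs]
    return ys, mu, var


xs = [1.0, 3.0, 10.0]
# First pass only to read off the batch statistics.
_, mu, var = batchnorm_train(xs, gamma=1.0, beta=0.0)
# gamma = sqrt(var + eps) and beta = mu cancel the normalization.
ys, _, _ = batchnorm_train(xs, gamma=math.sqrt(var + 1e-5), beta=mu)
print(ys)  # equal to xs up to floating-point rounding
```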
Train mode and inference mode are not the same
The above formula uses batch statistics. During training that’s fine; we have a batch.
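At inference time you might be feeding in a single example. There is no batch. So during training BN keeps a running exponential moving average of the per-feature mean and variance, with a momentum $\alpha$:
$$\mu_{\text{run}} \leftarrow (1 - \alpha)\,\mu_{\text{run}} + \alpha\,\mu_B, \qquad \sigma^2_{\text{run}} \leftarrow (1 - \alpha)\,\sigma^2_{\text{run}} + \alpha\,\sigma_B^2$$
At inference, BN swaps in those stored stats and normalizes deterministically:
$$y = \gamma\,\frac{x - \mu_{\text{run}}}{\sqrt{\sigma^2_{\text{run}} + \epsilon}} + \beta$$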
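A sketch of a single-feature BN layer that keeps running statistics and switches between the two modes. The class name `BatchNorm1Feature` and its `training` attribute are hypothetical stand-ins for what a framework manages for you; `momentum=0.1` is the EMA smoothing weight and mirrors PyTorch's default. As a simplification, the running variance here averages the biased batch variance, whereas PyTorch stores an unbiased estimate.

```python
import math


class BatchNorm1Feature:
    """Illustrative one-feature BatchNorm with running statistics (not a framework API)."""

    def __init__(self, momentum=0.1, eps=1e-5):
        self.momentum = momentum          # EMA smoothing weight
        self.eps = eps
        self.gamma, self.beta = 1.0, 0.0  # learnable in a real layer; fixed here
        self.running_mean, self.running_var = 0.0, 1.0
        self.training = True              # True ~ model.train(), False ~ model.eval()

    def __call__(self, xs):
        if self.training:
            m = len(xs)
            mu = sum(xs) / m
            var = sum((x - mu) ** 2 for x in xs) / m
            a = self.momentum
            # Exponential moving averages of the batch statistics.
            self.running_mean = (1 - a) * self.running_mean + a * mu
            self.running_var = (1 - a) * self.running_var + a * var
        else:
            # Inference: use the stored statistics, so the output is deterministic.
            mu, var = self.running_mean, self.running_var
        return [self.gamma * (x - mu) / math.sqrt(var + self.eps) + self.beta for x in xs]


bn = BatchNorm1Feature()
for _ in range(200):
    bn([4.0, 6.0, 8.0])       # every batch has mean 6 and variance 8/3
bn.training = False
print(bn.running_mean, bn.running_var)  # close to 6 and 8/3
print(bn([8.0]))              # a single example is fine in eval mode
```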
The framework controls which mode BN is in via a flag (PyTorch’s model.eval() flips this; model.train() flips it back). If you forget to switch modes at inference, BN computes batch statistics from whatever you happen to be feeding it. With a batch of one, this is catastrophic.
The most common BN bug, exactly
A network is trained on real batches. You then run inference on a single example. You forget to switch to eval mode.
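In a fully connected layer, BN's statistics come from the batch axis alone, so with a single example $x_1$ they are $\mu_B = x_1$ exactly and $\sigma_B^2 = 0$ exactly. So
$$\hat{x}_1 = \frac{x_1 - x_1}{\sqrt{0 + \epsilon}} = 0$$
Every feature is zeroed. The output of the BN layer is just $y = \gamma \cdot 0 + \beta = \beta$. The learned biases are all you get back. The signal you fed in is gone.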
The fix is one line: call model.eval() before inference. This is the single most common BN bug, and you have now seen it derived from the formula.
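The derivation can be checked directly. This sketch uses a hypothetical `bn_forward` function whose `training` flag stands in for `model.train()`/`model.eval()`, with made-up values for $\gamma$, $\beta$, and the running statistics. In training mode with a batch of one, every input comes out as exactly $\beta$; in eval mode the input survives.

```python
import math


def bn_forward(xs, gamma, beta, running_mean, running_var, training, eps=1e-5):
    """One-feature BN forward pass; `training` plays the role of model.train()/model.eval()."""
    if training:
        mu = sum(xs) / len(xs)
        var = sum((x - mu) ** 2 for x in xs) / len(xs)
    else:
        mu, var = running_mean, running_var
    return [gamma * (x - mu) / math.sqrt(var + eps) + beta for x in xs]


# Hypothetical learned parameters and running statistics.
gamma, beta = 2.0, 0.5
running_mean, running_var = 3.0, 4.0

for x in [7.0, -100.0, 3.14]:
    forgot_eval = bn_forward([x], gamma, beta, running_mean, running_var, training=True)
    correct = bn_forward([x], gamma, beta, running_mean, running_var, training=False)
    print(x, forgot_eval, correct)  # forgot_eval is always [beta]
```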
BN with batch size 1
A BatchNorm layer is run in training mode on a batch of size 1. The single example has feature value 7. What does $\hat{x}$ equal after BN normalizes it? (Take any small $\epsilon > 0$.)
BN's real mechanism (it's not what the original paper said)
The 2015 BN paper argued that BN works by reducing “internal covariate shift,” the idea that each layer’s input distribution drifts during training and BN re-stabilizes it.
Three years later Santurkar et al. demonstrated this story is essentially wrong. They artificially injected non-stationary noise after the BN layer (so distributions still drift) and training was unaffected. The covariate-shift narrative does not predict the experiment.
Their alternative explanation, which has held up better: BN makes the loss landscape smoother. It tightens the Lipschitz bounds on the loss and on the gradient, which makes gradient steps more reliable predictors of where the loss will actually go. You can take bigger steps without overshooting.
The mechanism is “more predictive gradients,” not “stationary distributions.” BN is a smoother of the optimization problem, not a stabilizer of activations.
LayerNorm: same shape, different axis
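Now flip the axes. LayerNorm normalizes across the feature dimension, one example at a time. For a single example with feature vector $x \in \mathbb{R}^d$:
$$\mu = \frac{1}{d}\sum_{j=1}^{d} x_j, \qquad \sigma^2 = \frac{1}{d}\sum_{j=1}^{d} (x_j - \mu)^2$$
Then $\hat{x}_j = (x_j - \mu)/\sqrt{\sigma^2 + \epsilon}$, and $\gamma$ and $\beta$ again learn an affine $y_j = \gamma_j \hat{x}_j + \beta_j$. As in BN, there is one $\gamma_j$ and one $\beta_j$ per feature, so $\gamma, \beta \in \mathbb{R}^d$; what changes is the axis the statistics are computed over, not the shape of the affine parameters.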
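A plain-Python sketch of the LayerNorm formula for one example. The helper name `layer_norm` and the input vector are hypothetical; the point is that $\mu$ and $\sigma^2$ come from the example's own $d$ features, and $\gamma$, $\beta$ are applied elementwise.

```python
import math


def layer_norm(x, gamma, beta, eps=1e-5):
    """LayerNorm for one example: statistics come from the example's own d features."""
    d = len(x)
    mu = sum(x) / d
    var = sum((v - mu) ** 2 for v in x) / d
    return [g * (v - mu) / math.sqrt(var + eps) + b for v, g, b in zip(x, gamma, beta)]


x = [1.0, 2.0, 6.0]  # hypothetical token vector, d = 3
y = layer_norm(x, gamma=[1.0, 1.0, 1.0], beta=[0.0, 0.0, 0.0])
print(y)  # mean ~0 across the three features
```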
The crucial difference: LN does not care about the batch. It computes its statistics from a single example. There are no running stats. There is no train/eval discrepancy. With batch size 1, it works fine.
This is exactly why transformers prefer LN.
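To make the contrast concrete, this sketch runs the same normalization core along the two axes of a batch that holds a single example (values hypothetical). BN-style statistics per feature see one number each, so every output is 0; LN-style statistics per example see all $d$ features and give a usable vector.

```python
import math


def normalize(values, eps=1e-5):
    """Shared core: subtract the mean of `values` and divide by their standard deviation."""
    mu = sum(values) / len(values)
    var = sum((v - mu) ** 2 for v in values) / len(values)
    return [(v - mu) / math.sqrt(var + eps) for v in values]


batch = [[1.0, 2.0, 6.0]]  # batch size 1, three features
n, d = len(batch), len(batch[0])

# BN in train mode: one set of statistics per feature, taken down the batch axis.
columns = [normalize([batch[i][j] for i in range(n)]) for j in range(d)]
bn_out = [[columns[j][i] for j in range(d)] for i in range(n)]

# LN: one set of statistics per example, taken across the feature axis.
ln_out = [normalize(row) for row in batch]

print(bn_out)  # [[0.0, 0.0, 0.0]]
print(ln_out)
```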
LayerNorm by hand
Apply LayerNorm to the token vector with , , . What is ? (the middle entry, after normalization)
When to use BN and when to use LN
The taxonomy is “which axes do you reduce over.”
BN: reduce over the batch dimension, normalize per feature. Default for vision and CNNs with stable, large batches. Fragile when the batch is small or non-i.i.d., as you’ve seen in the size-1 disaster.
LN: reduce over the feature dimension, normalize per example. Default for transformers. Variable sequence lengths, padded tokens, and single-example inference would all corrupt BN’s batch statistics; LN is immune to all of it.
The picture generalizes: GroupNorm reduces over a group of features per example. InstanceNorm reduces over each individual feature map per example. RMSNorm drops the mean term entirely and just rescales by the root mean square; it is slightly cheaper and is used by LLaMA-family models:
$$\text{RMSNorm}(x)_j = \gamma_j\,\frac{x_j}{\sqrt{\frac{1}{d}\sum_{k=1}^{d} x_k^2 + \epsilon}}$$
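A sketch of the RMSNorm formula in plain Python. The function name and `eps` default are illustrative assumptions, not taken from any particular model's code. There is no mean subtraction and no $\beta$, so the output is rescaled but not centered.

```python
import math


def rms_norm(x, gamma, eps=1e-5):
    """RMSNorm for one example: divide by the root mean square, then scale by gamma."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)  # no mean is subtracted
    return [g * v / rms for v, g in zip(x, gamma)]


y = rms_norm([3.0, 4.0], gamma=[1.0, 1.0])
print(y)                               # roughly [0.849, 1.131]: same signs and ratio as the input
print(sum(v * v for v in y) / len(y))  # mean square is ~1
```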
They are all variations on the same formula (RMSNorm simply skips the mean subtraction). The main thing that differs is which axes go into the average.
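One way to see "only the axes differ" is a single hypothetical helper, `norm_over`, that normalizes a 2-D batch (rows are examples, columns are features) along whichever axis you pass. `axis=0` gives BatchNorm-style statistics and `axis=1` gives LayerNorm-style statistics; the affine step is omitted for brevity.

```python
import math


def norm_over(batch, axis, eps=1e-5):
    """Normalize a 2-D batch (rows = examples, columns = features) over one axis.

    axis=0: statistics down each column (BatchNorm-style, per feature across the batch).
    axis=1: statistics along each row (LayerNorm-style, per example across features).
    """
    n, d = len(batch), len(batch[0])
    if axis == 0:
        groups = [[(i, j) for i in range(n)] for j in range(d)]
    else:
        groups = [[(i, j) for j in range(d)] for i in range(n)]
    out = [[0.0] * d for _ in range(n)]
    for idx in groups:
        vals = [batch[i][j] for i, j in idx]
        mu = sum(vals) / len(vals)
        var = sum((v - mu) ** 2 for v in vals) / len(vals)
        for i, j in idx:
            out[i][j] = (batch[i][j] - mu) / math.sqrt(var + eps)
    return out


batch = [[1.0, 10.0, 100.0],
         [3.0, 30.0, 300.0]]
bn_style = norm_over(batch, axis=0)  # each column has mean ~0
ln_style = norm_over(batch, axis=1)  # each row has mean ~0
```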
What this looks like in the transformer
Every block in a pre-LN transformer is exactly two sub-blocks:
```
x = x + Attention(LayerNorm(x))
x = x + MLP(LayerNorm(x))
```
Two LayerNorms per block. No BatchNorm anywhere. The next lesson covers the x + part (residual connections) and the warmup schedule that makes the whole thing trainable.
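A runnable sketch of the pre-LN block wiring in plain Python. `toy_attention` and `toy_mlp` are hypothetical stand-ins (uniform averaging over tokens and an elementwise ReLU), not real attention or a real MLP, and the LayerNorm omits $\gamma$ and $\beta$; only the pattern of normalize, sublayer, and residual add follows the two lines above. Each token is normalized on its own, so a single token from a single example works.

```python
import math


def layer_norm(x, eps=1e-5):
    """Per-token LayerNorm without the affine step, to keep the sketch short."""
    mu = sum(x) / len(x)
    var = sum((v - mu) ** 2 for v in x) / len(x)
    return [(v - mu) / math.sqrt(var + eps) for v in x]


def toy_attention(tokens):
    """Hypothetical stand-in for attention: every token attends uniformly to all tokens."""
    n, d = len(tokens), len(tokens[0])
    avg = [sum(t[j] for t in tokens) / n for j in range(d)]
    return [avg[:] for _ in range(n)]


def toy_mlp(tokens):
    """Hypothetical stand-in for the MLP: an elementwise ReLU applied per token."""
    return [[max(0.0, v) for v in t] for t in tokens]


def pre_ln_block(tokens):
    # x = x + Attention(LayerNorm(x)), with LayerNorm applied to each token independently.
    normed = [layer_norm(t) for t in tokens]
    tokens = [[a + b for a, b in zip(t, u)] for t, u in zip(tokens, toy_attention(normed))]
    # x = x + MLP(LayerNorm(x))
    normed = [layer_norm(t) for t in tokens]
    tokens = [[a + b for a, b in zip(t, u)] for t, u in zip(tokens, toy_mlp(normed))]
    return tokens


out = pre_ln_block([[1.0, 2.0, 3.0]])  # one token, one example: no batch statistics needed
print(out)
```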