Training Dynamics & Modern Tricks · 15 min

Normalize what, exactly?

BatchNorm and LayerNorm are the same operation parameterized by which axes you reduce over. Pick the axis whose statistics are stable.


Good init isn't enough

You picked your weights with He scaling. Layer-1 activations have unit variance, layer-10 activations have unit variance. You start training. After a thousand steps, weights have moved. The careful variance-preservation argument from the previous lesson assumed initial weights. The trained ones are different.

In a deep network, even with perfect initialization, the activation distributions at each layer drift around during training. Some layers get systematically larger, some smaller, some skew. The next layer was happy with the old distribution; suddenly the inputs look different and it has to relearn how to use them. Training slows down or breaks.

The fix the field landed on, around 2015: at every layer, normalize the activations during training. Force them back to a known scale. This is what BatchNorm and LayerNorm do.

BatchNorm: the formula

BatchNorm normalizes across the batch dimension, one feature at a time. For a minibatch of $m$ examples, with $x_i$ the value of one feature on the $i$-th example:

$$\mu_B = \frac{1}{m}\sum_{i=1}^m x_i, \qquad \sigma_B^2 = \frac{1}{m}\sum_{i=1}^m (x_i - \mu_B)^2.$$

Then normalize each example’s value of that feature:

$$\hat x_i = \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}}.$$

The $\epsilon$ (typically $10^{-5}$) is a tiny constant that prevents divide-by-zero when the batch variance is zero. Now $\hat x$ has mean 0 and, up to the negligible $\epsilon$ correction, variance 1 across the batch, by construction.

Finally (and this matters): BN learns a per-feature affine transform on top:

$$y_i = \gamma\,\hat x_i + \beta.$$

$\gamma$ and $\beta$ are trainable parameters, one pair per feature. Why? So the network can undo the normalization if it turns out the original distribution was useful. With $\gamma = \sqrt{\sigma_B^2 + \epsilon}$ and $\beta = \mu_B$, you’ve recovered the input exactly. So BN costs the layer no expressive power; it just gives the network the option to choose a different scale and shift.
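The three formulas above can be sketched in a few lines of plain Python. This is a minimal illustration for a single feature, not a framework implementation; the function name `batchnorm_feature` and the batch values are made up for the example.

```python
import math

def batchnorm_feature(xs, gamma=1.0, beta=0.0, eps=1e-5):
    """Training-mode BatchNorm for ONE feature across a minibatch `xs`."""
    m = len(xs)
    mu = sum(xs) / m                            # batch mean
    var = sum((x - mu) ** 2 for x in xs) / m    # biased batch variance (divide by m)
    x_hat = [(x - mu) / math.sqrt(var + eps) for x in xs]
    y = [gamma * xh + beta for xh in x_hat]     # learned affine on top
    return y, mu, var

batch = [2.0, 4.0, 6.0, 8.0]                    # made-up values: one feature, four examples
y, mu, var = batchnorm_feature(batch)           # gamma=1, beta=0: plain normalization
y_mean = sum(y) / len(y)
y_var = sum((v - y_mean) ** 2 for v in y) / len(y)

# Choosing gamma and beta as in the text undoes the normalization.
restored, _, _ = batchnorm_feature(batch, gamma=math.sqrt(var + 1e-5), beta=mu)
```

Here `y_mean` comes out at zero and `y_var` just under one (the gap is the $\epsilon$ term), while `restored` matches `batch` up to floating-point rounding.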

Train mode and inference mode are not the same

The above formula uses batch statistics. During training that’s fine; we have a batch.

At inference time you might be feeding in a single example. There is no batch. So BN keeps a running exponential moving average of the per-feature mean and variance during training:

$$\mu_{\text{run}} \leftarrow (1 - \alpha)\,\mu_{\text{run}} + \alpha\,\mu_B$$

$$\sigma^2_{\text{run}} \leftarrow (1 - \alpha)\,\sigma^2_{\text{run}} + \alpha\,\sigma_B^2.$$

At inference, BN swaps in those stored stats and normalizes deterministically:

$$\hat x = \frac{x - \mu_{\text{run}}}{\sqrt{\sigma^2_{\text{run}} + \epsilon}}.$$

The framework controls which mode BN is in via a flag (PyTorch’s model.eval() flips this; model.train() flips it back). If you forget to switch modes at inference, BN computes batch statistics from whatever you happen to be feeding it. With a batch of one, this is catastrophic.
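A toy version of the two modes, assuming a single feature and leaving out $\gamma$ and $\beta$ for brevity. The class name `BatchNormFeature`, its `training` attribute, the starting values of the running stats, and $\alpha = 0.1$ are all choices made for this sketch rather than any framework's actual code.

```python
import math

class BatchNormFeature:
    """Toy single-feature BatchNorm with running statistics (illustrative only)."""

    def __init__(self, alpha=0.1, eps=1e-5):
        self.alpha, self.eps = alpha, eps
        self.mu_run, self.var_run = 0.0, 1.0    # assumed starting values
        self.training = True                    # the mode flag

    def __call__(self, xs):
        if self.training:
            m = len(xs)
            mu = sum(xs) / m
            var = sum((x - mu) ** 2 for x in xs) / m
            # Exponential moving average of the batch statistics.
            self.mu_run = (1 - self.alpha) * self.mu_run + self.alpha * mu
            self.var_run = (1 - self.alpha) * self.var_run + self.alpha * var
        else:
            # Inference: use the stored statistics, ignore the batch.
            mu, var = self.mu_run, self.var_run
        return [(x - mu) / math.sqrt(var + self.eps) for x in xs]

bn = BatchNormFeature()
for _ in range(200):                 # "train" on the same made-up batch (mean 5, variance 5)
    bn([2.0, 4.0, 6.0, 8.0])

bn.training = False                  # the equivalent of switching to eval mode
out = bn([7.0])                      # a single example is fine in this mode
```

After enough steps the running stats settle at the batch mean and variance, so the single example is normalized as roughly $(7 - 5)/\sqrt{5}$ instead of depending on whatever happens to share its batch.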

The most common BN bug, exactly

A network is trained on real batches. You then run inference on a single example. You forget to switch to eval mode.

BN computes batch statistics from the single example: $\mu_B = x$ exactly, and $\sigma_B^2 = 0$ exactly. So

$$\hat x = \frac{x - x}{\sqrt{0 + \epsilon}} = 0.$$

Every channel is zeroed. The output of the BN layer is just $\gamma \cdot 0 + \beta = \beta$. The learned biases are all you get back. The signal you fed in is gone.

The fix is one line: call model.eval() before inference. This is the single most common BN bug, and you have now seen it derived from the formula.
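The collapse is easy to reproduce numerically. The sketch below assumes one feature and uses made-up values for $\gamma$, $\beta$, and the running statistics; the helper names `bn_train_mode` and `bn_eval_mode` are hypothetical.

```python
import math

def bn_train_mode(xs, gamma, beta, eps=1e-5):
    """BatchNorm using statistics of the batch it is given."""
    m = len(xs)
    mu = sum(xs) / m
    var = sum((x - mu) ** 2 for x in xs) / m
    return [gamma * (x - mu) / math.sqrt(var + eps) + beta for x in xs]

def bn_eval_mode(xs, gamma, beta, mu_run, var_run, eps=1e-5):
    """BatchNorm using stored running statistics."""
    return [gamma * (x - mu_run) / math.sqrt(var_run + eps) + beta for x in xs]

gamma, beta = 1.5, 0.25              # made-up learned parameters
mu_run, var_run = 5.0, 4.0           # made-up running statistics

inputs = [-3.0, 0.5, 42.0]
# Feed each input alone, as a batch of size 1.
train_outs = [bn_train_mode([x], gamma, beta) for x in inputs]
eval_outs = [bn_eval_mode([x], gamma, beta, mu_run, var_run) for x in inputs]
```

Every entry of `train_outs` is `[0.25]`, i.e. just `beta`, no matter what went in, while `eval_outs` still distinguishes the three inputs.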

[Interactive demo: a resampled batch on a number line. Panel "From this batch" shows μ_B, σ²_B, and x̂ (train); panel "Running stats" shows μ_run, σ²_run, and x̂ (eval).]

The batch sits on the number line (small dots). μ_B moves with each resample; μ_run moves only when you press training step. The active panel (train or eval) shows what BatchNorm uses in that mode.

Drag the batch size to 1 and click training step a few times in train mode: x̂ for the test point collapses to zero, because the batch statistics are degenerate. Switch to eval at batch size 1 and the running stats save you. The asymmetry between modes is the bug.

BN with batch size 1

A BatchNorm layer is run in training mode on a batch of size 1. The single example has feature value 7. What does $\hat x$ equal after BN normalizes it? (Take $\epsilon$ to be small but positive.)

BN's real mechanism (it's not what the original paper said)

The 2015 BN paper argued that BN works by reducing “internal covariate shift,” the idea that each layer’s input distribution drifts during training and BN re-stabilizes it.

Three years later, Santurkar et al. showed that this story is essentially wrong. They artificially injected non-stationary noise after the BN layers (so the distributions still drift), and training was largely unaffected. The covariate-shift narrative does not predict that result.

Their alternative explanation, which has held up better: BN makes the loss landscape smoother. It tightens the Lipschitz bounds on the loss and on the gradient, which makes gradient steps more reliable predictors of where the loss will actually go. You can take bigger steps without overshooting.

The mechanism is “more predictive gradients,” not “stationary distributions.” BN is a smoother of the optimization problem, not a stabilizer of activations.

LayerNorm: same shape, different axis

Now flip the axes. LayerNorm normalizes across the feature dimension, one example at a time. For a single example with a feature vector of length $H$:

$$\mu = \frac{1}{H}\sum_{j=1}^H x_j, \qquad \sigma^2 = \frac{1}{H}\sum_{j=1}^H (x_j - \mu)^2.$$

Then $\hat x_j = (x_j - \mu)/\sqrt{\sigma^2 + \epsilon}$, and again learn an affine $y_j = \gamma_j \hat x_j + \beta_j$. Here $\gamma$ and $\beta$ are length-$H$ vectors with one entry per feature, applied identically to every example.

The crucial difference: LN does not care about the batch. It computes its statistics from a single example. There are no running stats. There is no train/eval discrepancy. With batch size 1, it works fine.

This is a large part of why transformers prefer LN.
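A minimal sketch of that batch independence, in plain Python. The function name `layernorm` and all the numbers are made up for illustration; the point is only that an example's output does not change when other examples are placed next to it.

```python
import math

def layernorm(x, gamma, beta, eps=1e-5):
    """LayerNorm for ONE example: statistics come from its own H features."""
    H = len(x)
    mu = sum(x) / H
    var = sum((v - mu) ** 2 for v in x) / H
    return [g * (v - mu) / math.sqrt(var + eps) + b
            for v, g, b in zip(x, gamma, beta)]

gamma = [1.0, 1.0, 1.0, 1.0]         # per-feature scale (identity here)
beta = [0.0, 0.0, 0.0, 0.0]          # per-feature shift (zero here)

example = [0.5, -1.0, 2.0, 4.5]
alone = layernorm(example, gamma, beta)              # "batch size 1"

# Put the same example in a batch with a very different neighbour.
batch = [example, [100.0, -50.0, 3.0, 0.0]]
in_batch = [layernorm(row, gamma, beta) for row in batch]
same = in_batch[0] == alone                          # neighbour has no effect
```

Because nothing in `layernorm` looks across rows, there are no running statistics to maintain and no mode flag to forget.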

LayerNorm by hand

Apply LayerNorm to the token vector $x = [1, 2, 3]$ with $\gamma = 1$, $\beta = 0$, $\epsilon = 0$. What is $\hat x_2$ (the middle entry, after normalization)?

When to use BN and when to use LN

The taxonomy is “which axes do you reduce over.”

BN: reduce over the batch dimension, normalize per feature. Default for vision and CNNs with stable, large batches. Fragile when the batch is small or non-i.i.d., as you’ve seen in the size-1 disaster.

LN: reduce over the feature dimension, normalize per example. Default for transformers. Variable sequence lengths, padded tokens, and single-example inference would all corrupt BN’s batch statistics; LN is immune to all of it.
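The "which axes" view can be sketched with one helper applied along two different directions of the same table. This is an illustration only: the affine parameters are omitted, and the helper `standardize` and the data are made up.

```python
import math

def standardize(values, eps=1e-5):
    """Subtract the mean and divide by the standard deviation of `values`."""
    n = len(values)
    mu = sum(values) / n
    var = sum((v - mu) ** 2 for v in values) / n
    return [(v - mu) / math.sqrt(var + eps) for v in values]

# Rows are examples, columns are features (made-up numbers).
X = [[1.0, 10.0, 100.0],
     [2.0, 20.0, 200.0],
     [3.0, 30.0, 300.0],
     [4.0, 40.0, 400.0]]

# BatchNorm-style: statistics per column, i.e. reduce over the batch axis.
columns = [standardize(list(col)) for col in zip(*X)]
bn_out = [list(row) for row in zip(*columns)]        # transpose back to rows

# LayerNorm-style: statistics per row, i.e. reduce over the feature axis.
ln_out = [standardize(row) for row in X]
```

In `bn_out` every column has zero mean, so the three features end up on a common scale. In `ln_out` every row has zero mean, and each row is computed without looking at any other row.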

The picture generalizes: GroupNorm reduces over a group of features per example. InstanceNorm reduces over each individual feature map per example. RMSNorm, used by LLaMA-family models, drops the mean term entirely and just rescales by the root-mean-square, which is slightly cheaper:

$$y_j = \frac{x_j}{\sqrt{\frac{1}{H}\sum_k x_k^2 + \epsilon}}\,\gamma_j.$$

They are all variations on one formula. Apart from RMSNorm's dropped mean, the only thing that differs is which axes go into the average.
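For comparison, a minimal sketch of the RMSNorm formula for one example. The function name `rmsnorm`, the default `eps`, and the input values are assumptions made for the example.

```python
import math

def rmsnorm(x, gamma, eps=1e-6):
    """RMSNorm for ONE example: rescale by root-mean-square, no mean subtraction."""
    H = len(x)
    rms = math.sqrt(sum(v * v for v in x) / H + eps)
    return [g * v / rms for v, g in zip(x, gamma)]

x = [3.0, -4.0, 0.0, 5.0]            # made-up feature vector
y = rmsnorm(x, gamma=[1.0, 1.0, 1.0, 1.0])

mean_square = sum(v * v for v in y) / len(y)   # close to 1 after rescaling
mean = sum(y) / len(y)                         # NOT forced to 0, unlike LayerNorm
```

The output has unit mean-square but keeps a nonzero mean, and a zero input entry stays zero: only the scale is touched.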

What this looks like in the transformer

Every sub-block in a pre-LN transformer is exactly:

x = x + Attention(LayerNorm(x))
x = x + MLP(LayerNorm(x))

Two LayerNorms per block. No BatchNorm anywhere. The next lesson covers the + x part (residual connections) and the warmup schedule that makes the whole thing trainable.
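The ordering in those two lines can be sketched for a single token vector as follows. The sublayers here are hypothetical stand-ins (real attention mixes information across tokens, and a real block gives each LayerNorm its own $\gamma$ and $\beta$); only the placement of the normalization is the point.

```python
import math

def layernorm(x, eps=1e-5):
    """LayerNorm over one token vector (affine parameters omitted for brevity)."""
    mu = sum(x) / len(x)
    var = sum((v - mu) ** 2 for v in x) / len(x)
    return [(v - mu) / math.sqrt(var + eps) for v in x]

def pre_ln_block(x, attention, mlp):
    """Normalize only the input of each sublayer; the residual path is left alone."""
    x = [a + b for a, b in zip(x, attention(layernorm(x)))]
    x = [a + b for a, b in zip(x, mlp(layernorm(x)))]
    return x

def toy_attention(h):
    # Hypothetical stand-in: a fixed rescaling, not real attention.
    return [0.5 * v for v in h]

def toy_mlp(h):
    # Hypothetical stand-in: a bare ReLU, not a real two-layer MLP.
    return [max(0.0, v) for v in h]

x = [1.0, -2.0, 4.0]
out = pre_ln_block(x, toy_attention, toy_mlp)
```

If both sublayers returned zeros, `pre_ln_block` would hand `x` back unchanged, which shows that the normalization sits on the branch and not on the residual stream itself.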
