
Why your gradients explode

Initialization is not folklore. It falls out of one line of variance algebra, and it's the difference between a network that trains and a network that sits dead at step zero.


A network that won't move

Build a 10-layer MLP. 200 hidden units per layer. Tanh activations. Initialize every weight from $\mathcal{N}(0, 1)$, a “reasonable”-looking standard-normal sample. Now feed in a batch of unit-variance inputs.

By layer 2 or 3, every tanh is pinned at $\pm 1$. By layer 10, every neuron is fully saturated; derivatives are essentially zero everywhere. Now run backprop. The gradient is the chain of all those tiny derivatives multiplied together. It vanishes. The optimizer cannot move parameters whose gradient is numerically zero. The loss does nothing. Your network is dead, and the loss curve looks like the “dead curve” pathology from the previous lesson.

There is no error message. No NaN. The optimizer ran “successfully.” It just produced no learning. And the bug is one number: the standard deviation of the initial weights.
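To watch the saturation happen, here is a small pure-Python sketch (standard library only, so no framework behavior is assumed) of the network just described: 10 tanh layers of 200 units, weights drawn from $\mathcal{N}(0, 1)$, no biases, and a batch of 16 unit-variance inputs. For each layer it reports the share of activations with $|\tanh| > 0.99$ and the average local derivative $\tanh'(z) = 1 - \tanh^2(z)$, the per-unit factor backprop multiplies in. The batch size, the 0.99 cutoff, and the function name tanh_mlp_stats are illustrative choices.

```python
import math
import operator
import random

def tanh_mlp_stats(n=200, depth=10, sigma_w=1.0, batch=16, seed=0):
    """Push unit-variance inputs through a depth-layer tanh MLP (n units per layer,
    weights i.i.d. N(0, sigma_w^2), no biases) and return, for each layer,
    (fraction of activations with |tanh| > 0.99, mean local derivative 1 - tanh(z)^2)."""
    rng = random.Random(seed)
    xs = [[rng.gauss(0.0, 1.0) for _ in range(n)] for _ in range(batch)]
    stats = []
    for _ in range(depth):
        W = [[rng.gauss(0.0, sigma_w) for _ in range(n)] for _ in range(n)]
        # a = tanh(W x) for every input in the batch
        xs = [[math.tanh(sum(map(operator.mul, row, x))) for row in W] for x in xs]
        acts = [a for x in xs for a in x]
        saturated = sum(abs(a) > 0.99 for a in acts) / len(acts)
        mean_deriv = sum(1.0 - a * a for a in acts) / len(acts)  # tanh'(z) = 1 - tanh(z)^2
        stats.append((saturated, mean_deriv))
    return stats

for layer, (sat, deriv) in enumerate(tanh_mlp_stats(), start=1):
    print(f"layer {layer:2d}: {sat:6.1%} saturated, mean tanh'(z) = {deriv:.3f}")
```

Calling tanh_mlp_stats(sigma_w=1 / math.sqrt(200)) instead, which is the variance-preserving scale worked out below, brings the saturated share down to around 1% or less at every layer.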

Why does it matter so much? Algebra.

[Interactive figure: a σ_W slider with readouts for layer-10 variance, saturation at layer 10, and layer-1 ∂L/∂x, above histograms of post-tanh activations at layers 1, 3, 5, 7, 9 and 10, each on an axis from −1 to +1.]

Each histogram shows the distribution of post-tanh activations across a batch of 192 inputs at one layer of a 10-layer MLP with 64 units per layer. Drag σ_W low and the activations decay toward zero: every histogram becomes a single spike at 0, and gradients die. Drag it high and tanh saturates: every histogram becomes two walls at ±1, and gradients also die. The presets land at the known good values; only Xavier and He keep the histograms stable across all 10 layers.

Drag the slider toward the right (σ_W ≈ 0.4 and above) and watch every layer’s histogram collapse into two walls at $\pm 1$: that’s tanh saturation. Drag it to the far left (σ_W ≈ 0.01) and the activations decay toward zero. Hit the He preset and the histograms settle into a stable, bell-shaped spread at every depth. The layer-1 gradient readout in the corner tells you whether the network can learn at all.
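The figure's experiment can be approximated in a few lines of standard-library Python. This sketch assumes the setup the caption describes (10 tanh layers of 64 units, a batch of 192 unit-variance inputs) plus no biases and a fixed random seed, which the caption does not specify, and reports the post-tanh variance at the six layers the figure plots. The four σ_W values are the two slider extremes mentioned above and the Xavier and He scales for 64 inputs per unit; the helper name layer_variances is hypothetical.

```python
import math
import operator
import random

def layer_variances(sigma_w, n=64, depth=10, batch=192, seed=0):
    """Variance of the post-tanh activations at each layer of a depth-layer tanh MLP
    (n units per layer, no biases) with weights drawn i.i.d. from N(0, sigma_w^2)."""
    rng = random.Random(seed)
    xs = [[rng.gauss(0.0, 1.0) for _ in range(n)] for _ in range(batch)]  # unit-variance inputs
    variances = []
    for _ in range(depth):
        W = [[rng.gauss(0.0, sigma_w) for _ in range(n)] for _ in range(n)]
        xs = [[math.tanh(sum(map(operator.mul, row, x))) for row in W] for x in xs]
        acts = [a for x in xs for a in x]
        mean = sum(acts) / len(acts)
        variances.append(sum((a - mean) ** 2 for a in acts) / len(acts))
    return variances

settings = {
    "0.01 (far left)": 0.01,
    "Xavier, 1/sqrt(64)": 1 / math.sqrt(64),
    "He, sqrt(2/64)": math.sqrt(2 / 64),
    "0.4 (far right)": 0.4,
}
for name, sigma_w in settings.items():
    v = layer_variances(sigma_w)
    panels = "  ".join(f"L{k}={v[k - 1]:.2g}" for k in (1, 3, 5, 7, 9, 10))
    print(f"{name:<20} {panels}")
```

The far-left run should shrink by roughly a factor of $64 \times 0.01^2 \approx 0.006$ per layer, the far-right run should hold a large variance from layer 1 on because most activations are pushed toward $\pm 1$, and the He run should settle at a moderate, roughly constant spread.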

Variance through one linear layer

A linear layer computes $z = \sum_{i=1}^{n} w_i x_i$, where the $x_i$ are i.i.d. with mean 0 and variance $\sigma_x^2$, the $w_i$ are i.i.d. with mean 0 and variance $\sigma_w^2$, and all weights and inputs are mutually independent.

Compute the variance of $z$. Independence lets us drop the cross-covariance terms, and since $\mathbb{E}[w_i x_i] = 0$, each remaining term is just a product of second moments:

$$\text{Var}(z) \;=\; \sum_{i=1}^n \text{Var}(w_i x_i) \;=\; \sum_{i=1}^n \mathbb{E}[w_i^2]\,\mathbb{E}[x_i^2] \;=\; n\,\sigma_w^2\,\sigma_x^2.$$

So $\text{Var}(z) = n \cdot \text{Var}(W) \cdot \text{Var}(x)$. The variance of the output of one layer is $n$ times the variance of the inputs, scaled by the variance of the weights.
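A quick Monte Carlo check of this formula, as a sketch: draw every $w_i$ and $x_i$ independently, form $z$, repeat, and compare the sample variance with $n\,\sigma_w^2\,\sigma_x^2$. Gaussian draws and the particular values $n = 200$, $\sigma_w = 0.1$, $\sigma_x = 2$ are arbitrary choices; the derivation only needs zero means and independence.

```python
import random
import statistics

def empirical_var_z(n, sigma_w, sigma_x, trials=5000, seed=0):
    """Sample z = sum_{i=1}^n w_i x_i `trials` times, with w_i ~ N(0, sigma_w^2) and
    x_i ~ N(0, sigma_x^2) all independent, and return the sample variance of z."""
    rng = random.Random(seed)
    zs = [
        sum(rng.gauss(0.0, sigma_w) * rng.gauss(0.0, sigma_x) for _ in range(n))
        for _ in range(trials)
    ]
    return statistics.pvariance(zs)

n, sigma_w, sigma_x = 200, 0.1, 2.0
predicted = n * sigma_w**2 * sigma_x**2  # n * Var(W) * Var(x) = 200 * 0.01 * 4 = 8
print(f"predicted {predicted:.2f}, empirical {empirical_var_z(n, sigma_w, sigma_x):.2f}")
```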

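That extra factor of $n$ is what kills you. With $\sigma_w^2 = 1$ (the bad init we started with) and $n = 200$, every layer multiplies variance by 200. In a purely linear stack, ten layers would give a variance of $200^{10} \approx 10^{23}$. With tanh in between, the pre-activations are far outside tanh's linear range from the first layer on, so tanh saturates; gradients die.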
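The $200^{10}$ figure is what a purely linear stack would give. This standard-library sketch pushes a few unit-variance inputs through ten 200-unit linear layers (no activation, no bias) with $\text{Var}(W) = 1$ and, for contrast, repeats the run with $\text{Var}(W) = 1/n$, the scaling the next section arrives at. The batch size, seed, and helper name linear_stack_variance are arbitrary.

```python
import operator
import random
import statistics

def linear_stack_variance(weight_var, n=200, depth=10, batch=4, seed=0):
    """Variance of the outputs of `depth` purely linear n->n layers (no activation,
    no bias) with weights i.i.d. N(0, weight_var), fed unit-variance inputs."""
    rng = random.Random(seed)
    std = weight_var ** 0.5
    xs = [[rng.gauss(0.0, 1.0) for _ in range(n)] for _ in range(batch)]
    for _ in range(depth):
        W = [[rng.gauss(0.0, std) for _ in range(n)] for _ in range(n)]
        xs = [[sum(map(operator.mul, row, x)) for row in W] for x in xs]
    return statistics.pvariance([v for x in xs for v in x])

print(f"Var(W) = 1:     {linear_stack_variance(1.0):.3g}   theory 200**10 = {200**10:.3g}")
print(f"Var(W) = 1/200: {linear_stack_variance(1.0 / 200):.3g}   theory 1")
```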

The variance-preservation rule

You want activations to keep roughly the same scale, layer after layer. So pick the variance of the weights to kill the factor of $n$:

$$\sigma_w^2 \;=\; \frac{1}{n}.$$

Now $\text{Var}(z) = n \cdot \tfrac{1}{n} \cdot \text{Var}(x) = \text{Var}(x)$. Activation variance is preserved across the layer. Stack ten and the input still looks like the input.

This is Xavier initialization in its simplest form (also called Glorot, after Xavier Glorot). The forward-only version is $\sigma_w^2 = 1/n_{\text{in}}$. Running the same argument on the backward pass, where gradients are summed over $n_{\text{out}}$ terms, gives $\sigma_w^2 = 1/n_{\text{out}}$; the symmetric version compromises between the two: $\sigma_w^2 = 2/(n_{\text{in}} + n_{\text{out}})$. Either works; library Xavier/Glorot initializers typically use the symmetric form.
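Both Xavier variants fit in a couple of small Python helpers. The names xavier_var and xavier_normal are hypothetical, the matrix is a plain list of rows rather than any library's tensor type, and sampling from a normal distribution (rather than the uniform variant that is also common) is an assumption of this sketch.

```python
import math
import random

def xavier_var(n_in, n_out=None):
    """Xavier/Glorot weight variance: 1/n_in (forward-only form),
    or 2/(n_in + n_out) (symmetric form) when n_out is given."""
    if n_out is None:
        return 1.0 / n_in
    return 2.0 / (n_in + n_out)

def xavier_normal(n_in, n_out, seed=0):
    """An n_out x n_in weight matrix (list of rows) with entries drawn from
    N(0, 2 / (n_in + n_out)), i.e. the symmetric Xavier form."""
    rng = random.Random(seed)
    std = math.sqrt(xavier_var(n_in, n_out))
    return [[rng.gauss(0.0, std) for _ in range(n_in)] for _ in range(n_out)]

print(xavier_var(300), xavier_var(300, 100))  # forward-only vs. symmetric
W = xavier_normal(300, 100)
print(len(W), len(W[0]))                      # 100 rows (outputs) x 300 columns (inputs)
```

For a layer with 300 inputs and 100 outputs the two forms give $1/300 \approx 0.0033$ and $2/400 = 0.005$; for a square layer they coincide.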

Xavier in numbers

A fully-connected layer has 256 inputs and a tanh nonlinearity. Using Xavier (forward-only) initialization, what is $\text{Var}(W)$? (Answer to four decimal places is fine.)

Why Xavier breaks for ReLU

The variance derivation assumed the inputs had mean 0, so that $\mathbb{E}[x_i^2] = \sigma_x^2$. Tanh outputs roughly do, since tanh is symmetric around the origin. But ReLU is not symmetric: it zeros out everything below zero. For a zero-mean, symmetric pre-activation $z$, half the inputs go to zero in expectation, and $\mathbb{E}[\text{ReLU}(z)^2] = \tfrac{1}{2}\text{Var}(z)$. So the signal a ReLU layer passes on to the next layer is half what the linear analysis predicts.

If you use Xavier with ReLU, you’ve lost half your variance per layer, and your activations decay toward zero exponentially with depth. Same kind of failure, opposite direction.
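The factor of one half can be checked numerically. This sketch samples a zero-mean Gaussian pre-activation $z$ (Gaussian is an assumption; any distribution symmetric about zero gives the same half) and estimates $\mathbb{E}[\text{ReLU}(z)^2]$, the quantity that enters the next layer's variance formula. The helper name relu_second_moment is hypothetical.

```python
import random

def relu_second_moment(var_z, samples=200_000, seed=0):
    """Monte Carlo estimate of E[ReLU(z)^2] for z ~ N(0, var_z)."""
    rng = random.Random(seed)
    std = var_z ** 0.5
    total = 0.0
    for _ in range(samples):
        z = rng.gauss(0.0, std)
        if z > 0.0:          # ReLU zeros out the negative half
            total += z * z
    return total / samples

for var_z in (1.0, 4.0):
    print(f"Var(z) = {var_z}: E[ReLU(z)^2] ≈ {relu_second_moment(var_z):.3f}  (half would be {var_z / 2})")

# Losing half per layer compounds: ten Xavier + ReLU layers scale the second moment by
print(0.5 ** 10)
```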

The fix is to compensate. Double the weight variance to make up for what ReLU eats:

$$\sigma_w^2 \;=\; \frac{2}{n_{\text{in}}}.$$

This is He initialization (also called Kaiming, after Kaiming He). It is the default for any layer followed by a ReLU or ReLU variant, which means it is the default for almost everything you’ll build.
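A sketch comparing the two rules on a 10-layer, 200-unit ReLU stack with no biases (sizes borrowed from the earlier tanh example; batch size, seed, and the helper name relu_stack_second_moments are arbitrary). It tracks the mean of $a^2$ after each layer, the second moment the variance formula carries forward: with $\text{Var}(W) = 1/n$ it should halve at every layer, and with $\text{Var}(W) = 2/n$ it should hold roughly steady.

```python
import math
import operator
import random

def relu_stack_second_moments(weight_var, n=200, depth=10, batch=4, seed=0):
    """Mean of a^2 after each layer of a depth-layer ReLU MLP (n units per layer,
    no biases) whose weights are i.i.d. N(0, weight_var); inputs have unit variance."""
    rng = random.Random(seed)
    std = math.sqrt(weight_var)
    xs = [[rng.gauss(0.0, 1.0) for _ in range(n)] for _ in range(batch)]
    moments = []
    for _ in range(depth):
        W = [[rng.gauss(0.0, std) for _ in range(n)] for _ in range(n)]
        xs = [[max(0.0, sum(map(operator.mul, row, x))) for row in W] for x in xs]
        moments.append(sum(a * a for x in xs for a in x) / (n * batch))
    return moments

n = 200
xavier = relu_stack_second_moments(1.0 / n)  # theory: halves every layer, 0.5**10 ≈ 0.001 at layer 10
he = relu_stack_second_moments(2.0 / n)      # theory: stays near 1
print(f"layer 10 mean a^2: Xavier {xavier[-1]:.2g}, He {he[-1]:.2g}")
```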

He in numbers

A fully-connected layer has 512 inputs and a ReLU nonlinearity. Using He initialization, what is $\text{Var}(W)$?

Why you can't initialize to zero

You might think: “If randomness matters, what about all zeros? It’s symmetric, deterministic, easy.”

It does not work. If every weight is zero (or any constant), every neuron in a layer computes the same output. Forward pass: identical activations. Backward pass: identical gradients. Update: identical step. The neurons stay identical forever. You haven’t built a hidden layer of $n$ neurons; you’ve built one neuron $n$ times, and the layer collapses to a single effective unit.

This is symmetry breaking. The randomness is there not for any deep statistical reason but because two neurons that start the same and see the same gradients will always be the same. Random init breaks that symmetry before the first step.
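A minimal demonstration of the symmetry argument. It assumes a toy one-hidden-layer tanh network trained with plain SGD on squared error; the architecture, data, learning rate, helper name train, and the constant 0.5 are arbitrary choices. A nonzero constant is used so the hidden weights actually receive gradient; the point is that every hidden unit receives the same one.

```python
import math
import random

def train(W1, w2, data, lr=0.05, epochs=20):
    """Plain SGD on a one-hidden-layer tanh net, y_hat = sum_j w2[j] * tanh(W1[j] . x),
    with loss (y_hat - y)^2 / 2. W1 holds one weight row per hidden unit. Updates in place."""
    for _ in range(epochs):
        for x, y in data:
            h = [math.tanh(sum(w * xi for w, xi in zip(row, x))) for row in W1]
            err = sum(v * hj for v, hj in zip(w2, h)) - y      # dL/dy_hat
            for j in range(len(W1)):
                delta = err * w2[j] * (1.0 - h[j] ** 2)          # dL/d(pre-activation of unit j)
                W1[j] = [w - lr * delta * xi for w, xi in zip(W1[j], x)]
                w2[j] -= lr * err * h[j]
    return W1, w2

rng = random.Random(0)
data = []
for _ in range(20):
    x = [rng.uniform(-1.0, 1.0) for _ in range(3)]
    data.append((x, x[0] - 2.0 * x[1] + 0.5 * x[2]))  # arbitrary linear target

# Constant init: all 4 hidden units start identical.
W1_const, _ = train([[0.5] * 3 for _ in range(4)], [0.5] * 4, data)
# Random init: each hidden unit starts somewhere different.
W1_rand, _ = train([[rng.gauss(0.0, 0.5) for _ in range(3)] for _ in range(4)],
                   [rng.gauss(0.0, 0.5) for _ in range(4)], data)

print("distinct hidden units after training, constant init:", len({tuple(r) for r in W1_const}))  # 1
print("distinct hidden units after training, random init:  ", len({tuple(r) for r in W1_rand}))   # 4
```

Under the constant start, every hidden unit receives bit-for-bit the same update at every step, so the four units end as one; under the random start they stay distinct.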

The transformer's initialization, exactly

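Modern transformers add one more wrinkle. Even with He init in each individual layer, the residual stream (the running sum $x_0 + F_1(x_0) + F_2(x_1) + \dots$, where $x_\ell$ is the stream after the $\ell$-th residual addition) has variance that grows with depth: each addition contributes a roughly independent term, so the variance grows roughly linearly with the number of additions. The fix from the GPT-2 paper, used by nanoGPT and the codebases modeled on it: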

$$\sigma_W \;=\; \frac{0.02}{\sqrt{2 N_{\text{layers}}}}$$

applied to the output projection of each residual sub-block. (The factor of 2 is because each transformer block has two residual additions: one around attention, one around the MLP.) Everything else in the network uses the standard $\mathcal{N}(0, 0.02^2)$.

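This is what the special-case init of c_proj.weight in a nanoGPT codebase is doing when it scales the standard deviation by $1/\sqrt{2N_{\text{layers}}}$: each of the $2N_{\text{layers}}$ residual additions then contributes only $1/(2N_{\text{layers}})$ of the variance it otherwise would, so the total added to the residual stream stays roughly constant regardless of how many blocks you stack.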
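Two small pieces as a sketch. The first returns an init standard deviation under the rule above; keying on the parameter-name suffix c_proj.weight mirrors nanoGPT's naming, and the example names passed to it are illustrative. The second is a deliberately crude toy of the residual stream in which every residual addition contributes independent, unit-variance noise before scaling; real branch outputs are neither independent nor unit-variance, so it only shows how the total scales with depth. Both function names are hypothetical.

```python
import math
import random

def gpt2_style_init_std(param_name, n_layer, base_std=0.02):
    """Init std under the rule above: residual output projections get
    base_std / sqrt(2 * n_layer); every other weight gets base_std.
    Matching on the suffix "c_proj.weight" follows nanoGPT-style naming (an assumption)."""
    if param_name.endswith("c_proj.weight"):
        return base_std / math.sqrt(2 * n_layer)
    return base_std

def toy_residual_variance(n_layer, scaled, dim=2048, seed=0):
    """Toy residual stream: start at unit variance and add 2 * n_layer branch outputs
    (attention + MLP per block). Each branch output is modeled as independent unit-variance
    noise, multiplied by 1 / sqrt(2 * n_layer) when `scaled` is True.
    Returns the mean squared value of the final stream."""
    rng = random.Random(seed)
    scale = 1.0 / math.sqrt(2 * n_layer) if scaled else 1.0
    stream = [rng.gauss(0.0, 1.0) for _ in range(dim)]
    for _ in range(2 * n_layer):
        stream = [s + scale * rng.gauss(0.0, 1.0) for s in stream]
    return sum(s * s for s in stream) / dim

print(gpt2_style_init_std("transformer.h.0.attn.c_proj.weight", n_layer=12))
print(gpt2_style_init_std("transformer.h.0.mlp.c_fc.weight", n_layer=12))
for n_layer in (12, 48):
    print(n_layer, round(toy_residual_variance(n_layer, scaled=False), 1),
          round(toy_residual_variance(n_layer, scaled=True), 2))
```

Unscaled, the toy stream's variance is about $1 + 2N_{\text{layers}}$ (25 for 12 layers, 97 for 48); scaled, it stays near 2 at both depths.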

Initialization is one line of variance algebra. It is also the difference between a 12-layer transformer that converges and one that sits at chance forever.
