Optimization · 20 min

Per-parameter Steps: RMSProp, and then Adam

Build Adam in pieces. Start with momentum. Add a per-parameter adaptive step via an EMA of squared gradients. Add bias correction for the cold start. Now you have the optimizer that trains transformers.

Different parameters, different gradient scales

A neural network has millions of parameters. Not all of them want the same learning rate.

Consider a word-embedding table. The row for “the” gets nudged every time “the” appears in a training example: hundreds of times per batch. Its gradient is big and frequent. The row for “gazpacho” gets nudged maybe once per epoch. Its gradient is small and rare.

A single global $\eta$ forces a compromise. Big enough for “gazpacho” to move? Too big for “the”: training diverges. Small enough for “the” to be stable? “Gazpacho” never learns.

You need a per-parameter step size that shrinks for parameters with big gradients and grows for parameters with small ones.
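A toy sketch of that compromise, under loud assumptions: the two-parameter quadratic loss and the curvatures `a` and `b` below are hypothetical stand-ins for a large-gradient row and a small-gradient row. They are not taken from a real embedding table.

```python
# Toy loss: L(w) = 0.5 * (a * w_big**2 + b * w_small**2), minimized at (0, 0).
# a and b are hypothetical curvatures chosen to put the two gradients
# four orders of magnitude apart.

def run_gd(lr, steps=50, a=100.0, b=0.01):
    """Plain gradient descent with one global learning rate."""
    w_big, w_small = 1.0, 1.0
    for _ in range(steps):
        w_big -= lr * a * w_big      # gradient of 0.5*a*w^2 is a*w
        w_small -= lr * b * w_small  # gradient of 0.5*b*w^2 is b*w
    return w_big, w_small

big_lr = run_gd(lr=10.0)     # large enough for the small-gradient parameter
small_lr = run_gd(lr=0.015)  # small enough for the large-gradient parameter

print("lr=10.0  ->", big_lr)
print("lr=0.015 ->", small_lr)
```

With the large rate the small-gradient parameter makes progress but the large-gradient one blows up. With the small rate the large-gradient parameter converges and the other one has barely left its starting value.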

RMSProp: divide by the running RMS

The idea: for each parameter $i$, keep an EMA of the squared gradient component $g_i^2$. The square root of that EMA is a rough estimate of how big $|g_i|$ typically is. Divide the update by it.

Let $s_t$ be a vector holding the EMA of squared gradients elementwise:

$$s_t \;=\; \beta_2\, s_{t-1} \;+\; (1 - \beta_2)\, g_t^{\odot 2}$$

where $g_t^{\odot 2}$ is elementwise squaring and $\beta_2 \approx 0.999$ is a second momentum coefficient (longer memory than the $\beta$ for velocity, since squared-gradient statistics are noisy and benefit from more smoothing).

Then the update is

$$\mathbf{w}_{t+1} \;=\; \mathbf{w}_t \;-\; \eta \, \frac{g_t}{\sqrt{s_t} + \varepsilon}.$$

The $\varepsilon$ ($\sim 10^{-8}$) avoids division by zero. The division is elementwise. Each parameter’s step is automatically scaled by its own recent gradient magnitude.
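Here is a minimal sketch of that update in plain Python, assuming parameters and gradients are stored as lists of floats. The function name `rmsprop_step` and the constant example gradients are illustrative choices, not part of any library.

```python
import math

def rmsprop_step(w, g, s, lr=1e-3, beta2=0.999, eps=1e-8):
    """One RMSProp update. Returns (new_w, new_s); all arguments are lists of floats."""
    # EMA of squared gradients, elementwise
    new_s = [beta2 * si + (1.0 - beta2) * gi * gi for si, gi in zip(s, g)]
    # Divide each gradient component by its own running RMS
    new_w = [wi - lr * gi / (math.sqrt(si) + eps)
             for wi, gi, si in zip(w, g, new_s)]
    return new_w, new_s

# Two parameters whose (hypothetical) gradients differ by a factor of 1000.
w = [0.0, 0.0]
s = [0.0, 0.0]
for _ in range(3):
    w, s = rmsprop_step(w, [10.0, 0.01], s, beta2=0.9)

print(w)  # both entries have moved by nearly the same amount
```

Because each component is divided by its own RMS, the factor of 1000 between the two gradients cancels and both parameters take almost identical steps.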

RMSProp update

With $\beta_2 = 0.9$, $s_0 = 0$, and gradient sequence $g_1 = 2$, $g_2 = 0$, compute $s_2$.

(Use the $(1 - \beta_2)$-weighted convention: $s_t = 0.9\, s_{t-1} + 0.1\, g_t^2$.)

Now add momentum back: that's Adam, almost

Adam combines both: momentum on the gradient itself (smoothed direction) and RMSProp (adaptive per-parameter scale).

$$m_t = \beta_1 m_{t-1} + (1 - \beta_1)\, g_t$$

$$v_t = \beta_2 v_{t-1} + (1 - \beta_2)\, g_t^{\odot 2}$$

$m$ is the momentum term (EMA of gradients). $v$ is the RMSProp term (EMA of squared gradients). The tentative update:

$$\mathbf{w}_{t+1} \;=\; \mathbf{w}_t \;-\; \eta \, \frac{m_t}{\sqrt{v_t} + \varepsilon}.$$

This is almost Adam. There’s a subtle problem at the start of training.

The cold-start problem, made visible.

$m_t$ and $v_t$ both start at zero. At step 1, $m_1 = (1 - \beta_1)\, g_1 = 0.1\, g_1$ for $\beta_1 = 0.9$. That’s ten times too small. The momentum estimate is biased toward zero for the first few steps: the EMA hasn’t had time to accumulate.

Same for $v$, which is worse: $\beta_2 = 0.999$, so $v$ is biased for thousands of steps.

Below: scrub the step counter $t$. The first two bars are the raw $m_t$ and $v_t$: they crawl up from zero. The third and fourth are the bias-corrected versions: 1 from step 1, exactly. The whole job of bias correction is to make those bars unit-scale immediately.

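[Interactive figure: scrub step $t$ to compare raw vs bias-corrected estimates. Gradient = 1 at every step, $\beta_1 = 0.90$, $\beta_2 = 0.999$, target = 1. At $t = 1$: raw $m_t$ = 0.10, raw $v_t$ = 0.00, corrected $\hat{m}_t$ = 1.00, corrected $\hat{v}_t$ = 1.00; correction factors $1 / (1 - \beta_1^t) = 10.00\times$ and $1 / (1 - \beta_2^t) = 1000.00\times$.]

The bigger the correction factor, the more biased the raw estimate is.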

Watch the raw $v_t$ in particular when $\beta_2 = 0.999$. It takes hundreds of steps just to look reasonable. Without correction, Adam’s adaptive denominator would be wildly wrong for the entire start of training.
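The same experiment can be reproduced in a few lines. This is a sketch assuming a constant gradient of 1 at every step, so the squared gradient is also 1; the helper name `ema_trace` is made up for illustration. The corrected column divides by $1 - \beta^t$.

```python
def ema_trace(beta, steps, value=1.0):
    """Return a list of (t, raw_ema, corrected_ema) for a zero-initialized EMA
    that is fed the same value at every step."""
    ema = 0.0
    trace = []
    for t in range(1, steps + 1):
        ema = beta * ema + (1.0 - beta) * value
        trace.append((t, ema, ema / (1.0 - beta ** t)))
    return trace

m_trace = ema_trace(0.9, 1000)    # first-moment EMA, beta1 = 0.9
v_trace = ema_trace(0.999, 1000)  # second-moment EMA, beta2 = 0.999

for t in (1, 10, 100, 1000):
    _, m_raw, m_hat = m_trace[t - 1]
    _, v_raw, v_hat = v_trace[t - 1]
    print(f"t={t:4d}  m={m_raw:.3f}  v={v_raw:.3f}  "
          f"m_hat={m_hat:.3f}  v_hat={v_hat:.3f}")
```

Even at step 1000 the raw $v_t$ has only reached about 0.63 of its target, while both corrected values sit at 1 from the first step.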

Bias correction: the fix

For an EMA with decay $\beta$ and zero init, the bias-corrected estimate is

$$\hat{m}_t \;=\; \frac{m_t}{1 - \beta_1^t}, \qquad \hat{v}_t \;=\; \frac{v_t}{1 - \beta_2^t}.$$

At step 1, $1 - \beta_1^1 = 0.1$, so $\hat{m}_1 = m_1 / 0.1 = g_1$. The corrected estimate is the true observation.

At step 100, $1 - 0.9^{100}$ is essentially 1. The correction becomes a no-op. As the EMA “warms up,” the correction fades.

$v$ with $\beta_2 = 0.999$ takes much longer to warm up (thousands of steps), so bias correction matters there for much longer.
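Where the factor $1 - \beta^t$ comes from can be seen by unrolling the recursion. The sketch below assumes, for illustration only, that the gradients have a roughly constant mean $\bar{g}$ over the window; the same argument applies to $v_t$ with $\beta_2$ and $g_t^{\odot 2}$.

```latex
m_t \;=\; (1 - \beta_1) \sum_{i=1}^{t} \beta_1^{\,t-i}\, g_i
\qquad \text{(unrolled from } m_0 = 0\text{)}

\mathbb{E}[m_t] \;\approx\; \bar{g}\,(1 - \beta_1) \sum_{i=1}^{t} \beta_1^{\,t-i}
\;=\; \bar{g}\,\bigl(1 - \beta_1^{\,t}\bigr)

\Rightarrow\quad
\mathbb{E}[\hat{m}_t] \;=\; \frac{\mathbb{E}[m_t]}{1 - \beta_1^{\,t}} \;\approx\; \bar{g}
```

The weights of a zero-initialized EMA sum to $1 - \beta^t$ rather than 1, and dividing by that sum is the whole correction.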

Bias-corrected first step

With $\beta_1 = 0.9$, $m_0 = 0$, and $g_1 = 0.1$, compute the bias-corrected $\hat{m}_1$.

Adam: the whole thing

Bolt it all together. At each step $t$:

  1. Compute the gradient $g_t$.
  2. Update momentum: $m_t = \beta_1 m_{t-1} + (1 - \beta_1)\, g_t$.
  3. Update the second moment (the squared-gradient EMA): $v_t = \beta_2 v_{t-1} + (1 - \beta_2)\, g_t^{\odot 2}$.
  4. Bias-correct: $\hat{m}_t = m_t / (1 - \beta_1^t)$, $\hat{v}_t = v_t / (1 - \beta_2^t)$.
  5. Update weights: $\mathbf{w}_{t+1} = \mathbf{w}_t - \eta \, \hat{m}_t / (\sqrt{\hat{v}_t} + \varepsilon)$.
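The five steps above can be sketched in plain Python as follows. This is a minimal illustration on lists of floats, not a production optimizer; `adam_step` and the toy `grad` function (a quadratic with two very different curvatures) are hypothetical names chosen for this example.

```python
import math

def adam_step(w, g, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update. t is the 1-based step count. Returns (w, m, v)."""
    # Steps 2 and 3: update the two EMAs
    m = [beta1 * mi + (1 - beta1) * gi for mi, gi in zip(m, g)]
    v = [beta2 * vi + (1 - beta2) * gi * gi for vi, gi in zip(v, g)]
    # Step 4: bias correction
    m_hat = [mi / (1 - beta1 ** t) for mi in m]
    v_hat = [vi / (1 - beta2 ** t) for vi in v]
    # Step 5: per-parameter scaled update
    w = [wi - lr * mh / (math.sqrt(vh) + eps)
         for wi, mh, vh in zip(w, m_hat, v_hat)]
    return w, m, v

def grad(w):
    """Step 1 for a toy loss 100*w0^2 + 0.01*w1^2 (hypothetical example)."""
    return [2 * 100.0 * w[0], 2 * 0.01 * w[1]]

w, m, v = [1.0, 1.0], [0.0, 0.0], [0.0, 0.0]
for t in range(1, 51):
    w, m, v = adam_step(w, grad(w), m, v, t, lr=1e-2)

print(w)  # the two coordinates stay almost equal to each other
```

The two gradients differ by a factor of 10,000, yet both coordinates move in near lockstep, because the ratio $\hat{m}_t / \sqrt{\hat{v}_t}$ is unchanged when a coordinate's gradients are all multiplied by a constant.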

Default hyperparameters: $\beta_1 = 0.9$, $\beta_2 = 0.999$, $\varepsilon = 10^{-8}$, $\eta = 10^{-3}$. These work out of the box on an astonishing range of problems. The reason: dividing $\hat{m}_t$ by $\sqrt{\hat{v}_t}$, with bias correction keeping both estimates correctly scaled from step 1, makes each coordinate’s effective step $\approx \eta \cdot \mathrm{sign}(g_i)$ early in training, regardless of gradient magnitude. Adam has automatic unit-step behavior.
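The unit-step claim can be checked by hand at $t = 1$, using only the update rules above and assuming $|g_{1,i}| \gg \varepsilon$:

```latex
\hat{m}_1 = \frac{(1 - \beta_1)\, g_1}{1 - \beta_1} = g_1,
\qquad
\hat{v}_1 = \frac{(1 - \beta_2)\, g_1^{\odot 2}}{1 - \beta_2} = g_1^{\odot 2}

w_{2,i} - w_{1,i}
\;=\; -\,\eta\, \frac{g_{1,i}}{|g_{1,i}| + \varepsilon}
\;\approx\; -\,\eta \cdot \mathrm{sign}(g_{1,i})
```

So the first update moves every coordinate by about $\eta$ in the downhill direction, whether its gradient is $10^{3}$ or $10^{-3}$.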

Adam isn't always best

You’ll see “just use Adam” as default advice. That’s mostly right. But it’s worth knowing when it’s wrong.

On certain convex tasks and some computer-vision architectures, plain SGD with momentum consistently generalizes better than Adam, even when Adam’s training loss is lower. Wilson et al. 2017 gave a famous linearly-separable classification example where Adam converges to a solution with ~50% test error while SGD achieves zero.

The hand-wave: Adam’s adaptive per-parameter scaling changes the geometry and can converge to sharper minima that generalize worse. For transformers (the focus of this course) Adam (or its sibling AdamW) is essentially always the right choice. For vision ResNets, SGD+momentum often wins.

Default to Adam. Know that “default” isn’t “always.”

One more twist: AdamW

Adam with weight decay (the $-\eta \lambda \mathbf{w}$ regularizer that shrinks weights toward zero) has a subtle bug. If you add the decay to the gradient (as $g \leftarrow g + \lambda \mathbf{w}$), then Adam’s adaptive denominator $\sqrt{\hat{v}}$ rescales it, and the effective regularization becomes per-parameter: large-gradient parameters get less regularization, small-gradient parameters get more. Almost always the opposite of what you want.

AdamW (Loshchilov & Hutter, 2017) fixes this by applying the decay outside the adaptive step:

$$\mathbf{w}_{t+1} \;=\; \mathbf{w}_t \;-\; \eta\!\left( \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \varepsilon} \;+\; \lambda \mathbf{w}_t \right).$$
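To see the difference concretely, here is a scalar sketch that implements both variants behind one flag. The function `adam_update` and its `decoupled` argument are hypothetical names for this illustration; the example uses a zero loss gradient so that only the decay term acts.

```python
import math

def adam_update(w, g, m, v, t, lr, weight_decay, decoupled,
                beta1=0.9, beta2=0.999, eps=1e-8):
    """One scalar Adam step with weight decay.
    decoupled=False: decay is added to the gradient (L2-style).
    decoupled=True:  decay is applied outside the adaptive ratio (AdamW-style)."""
    if not decoupled:
        g = g + weight_decay * w          # decay gets rescaled by sqrt(v_hat)
    m = beta1 * m + (1 - beta1) * g
    v = beta2 * v + (1 - beta2) * g * g
    m_hat = m / (1 - beta1 ** t)
    v_hat = v / (1 - beta2 ** t)
    step = m_hat / (math.sqrt(v_hat) + eps)
    if decoupled:
        step = step + weight_decay * w    # decay bypasses the rescaling
    return w - lr * step, m, v

# Zero loss gradient, so any movement comes from weight decay alone.
w_l2, _, _ = adam_update(1.0, 0.0, 0.0, 0.0, t=1, lr=0.1,
                         weight_decay=0.01, decoupled=False)
w_adamw, _, _ = adam_update(1.0, 0.0, 0.0, 0.0, t=1, lr=0.1,
                            weight_decay=0.01, decoupled=True)

print(w_l2, w_adamw)
```

In the decoupled case the weight shrinks by $\eta \lambda w$, as intended. In the coupled case the decay is normalized like any other gradient, so the weight moves by roughly $\eta$ no matter what $\lambda$ is.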

This is the optimizer nearly every production transformer uses. nanoGPT uses it. GPT-3 used it. Llama uses it. The W earns its letter.
