Information Theory Basics · 12 min

Why classifiers have a softmax head

For one-hot targets, cross-entropy collapses to negative log-likelihood of the true class. Softmax exists because that loss needs a valid Q, and softmax-plus-NLL produces the cleanest gradient in all of supervised learning.


Labels are not distributions

In most supervised classification problems, the target is a single class, not a probability distribution over classes. For an MNIST digit, the label is 3. For an email, it’s spam or not-spam. For a token prediction, it’s “the literal next token in the corpus.” One outcome, full mass.

So when we plug that into the cross-entropy formula $H(P, Q) = -\sum_x P(x) \log Q(x)$, what happens? Let’s see, and then watch the structural simplification that follows.

The one-hot collapse

Let $P$ be one-hot on class $y$: $P(y) = 1$, $P(x) = 0$ for all $x \ne y$. Substitute:

$$H(P, Q) \;=\; -\sum_{x} P(x) \log Q(x) \;=\; -\log Q(y).$$

Every term in the sum except one vanishes, because $P(x) = 0$ for every $x \ne y$. The single surviving term is the negative log probability your model gave to the right answer.

That’s negative log-likelihood under another name. We saw this in module 8 as the form of MLE for a categorical distribution. Here it shows up because the data distribution is degenerate (one-hot), and cross-entropy against a degenerate $P$ collapses to NLL of the true class. Three names (cross-entropy, NLL, log-loss) and one piece of math.

Watch 9 of 10 terms vanish, live

The widget below shows a 10-way classifier. The top half is the logit vector $z$ (drag any bar up or down). The middle is the softmax probability vector $q = \mathrm{softmax}(z)$. The bottom row is the class index; click any of them to set the one-hot target.

[Interactive widget: 10-way classifier. Example state: softmax probabilities $q \approx (0.085, 0.047, 0.127, 0.063, 0.094, 0.077, 0.042, 0.345, 0.070, 0.052)$ for classes 0 through 9. Click any class index to make it the target.]
Full cross-entropy sum (9 of 10 terms vanish): every $0 \cdot \log q_i$ term drops out, leaving $-1 \cdot \log(0.345) \approx 1.065$ nats.
Readout: target class 7 · q_true 0.345 · loss 1.065 nats

Look at the lower-left readout: the full cross-entropy sum has 10 terms, but 9 of them are greyed out; their $P(x_i)$ coefficient is zero. Only the red term (the target class) contributes. The collapsed form on the lower-right shows the same number, computed as a single $-\log q_{\text{target}}$. Two formulas, one value.

Click any other class index to change the target. Watch which term survives. This is the entire structural argument: cross-entropy for a classifier is the negative log probability of the right answer, period.
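If you’d rather check the collapse in code than in the widget, here is a minimal sketch using only Python’s standard library. It reuses the rounded probabilities from the widget readout above with target class 7; the variable names are illustrative.

```python
import math

# Softmax probabilities as displayed in the widget readout (rounded to 3 decimals).
q = [0.085, 0.047, 0.127, 0.063, 0.094, 0.077, 0.042, 0.345, 0.070, 0.052]
target = 7

# One-hot target distribution: all mass on the target class.
p = [1.0 if i == target else 0.0 for i in range(len(q))]

# Full cross-entropy sum: -sum_x P(x) log Q(x). Nine of the ten terms are 0 * log(q_i).
full_sum = -sum(p_i * math.log(q_i) for p_i, q_i in zip(p, q))

# Collapsed form: negative log-probability of the target class only.
collapsed = -math.log(q[target])

print(f"full sum  = {full_sum:.4f} nats")
print(f"collapsed = {collapsed:.4f} nats")
```

Both lines print the same value, about 1.064 nats. The widget shows 1.065, most likely because it works from unrounded probabilities rather than the three-decimal values used here.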

Cross-entropy = NLL = maximum likelihood

Pulling the threads from m8 and the previous lesson together. For a classifier with parameters $\theta$:

$$\mathcal{L}(\theta) \;=\; H(\hat P_{\text{data}}, Q_\theta) \;=\; \frac{1}{N}\sum_{i=1}^{N} -\log Q_\theta(y_i)$$

Three sentences from three different traditions:

  • Information theory: “average cross-entropy of the model on the data.”
  • Statistics: “negative log-likelihood, divided by $N$.”
  • ML: “the cross-entropy loss.”

All the same expression. Minimizing it is exactly maximum likelihood estimation, which is exactly minimizing forward KL up to an additive constant. Three vocabularies, one objective. Stop arguing about which name to use.
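The three readings can be checked numerically. The sketch below makes a simplifying assumption: it drops the inputs and fits a single categorical distribution to a handful of labels, so the labels and the model distribution `q` are made-up illustrative values, not anything from the lesson’s widget.

```python
import math
from collections import Counter

# Hypothetical observed labels from a 3-class problem (illustrative data only).
labels = [0, 0, 1, 2, 0, 1, 0, 2]
# Hypothetical model distribution Q over the three classes.
q = [0.5, 0.3, 0.2]

n = len(labels)

# Statistics view: negative log-likelihood divided by N.
avg_nll = sum(-math.log(q[y]) for y in labels) / n

# Information-theory view: cross-entropy H(P_hat, Q), with P_hat the empirical frequencies.
counts = Counter(labels)
p_hat = [counts[k] / n for k in range(len(q))]
cross_entropy = -sum(p * math.log(qk) for p, qk in zip(p_hat, q) if p > 0)

# Forward KL(P_hat || Q) plus the entropy of P_hat, which does not depend on Q.
entropy = -sum(p * math.log(p) for p in p_hat if p > 0)
forward_kl = sum(p * math.log(p / qk) for p, qk in zip(p_hat, q) if p > 0)

print(f"average NLL      = {avg_nll:.6f}")
print(f"cross-entropy    = {cross_entropy:.6f}")
print(f"entropy + fwd KL = {entropy + forward_kl:.6f}")
```

All three printed numbers agree up to floating-point error. Since the entropy term is fixed by the data, lowering the loss can only come from lowering the forward KL.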

Compute loss for z = (2, 0, 1), target 0

Logits $z = (2, 0, 1)$ over three classes. True class is $0$.

Compute the loss in nats (use the natural log; answer to three decimals).

The softmax + NLL gradient is q − p

Here’s the result that justifies why softmax-plus-NLL is the canonical pairing. Take the derivative of $\mathcal{L} = -\log q_{\text{target}}$ with respect to the logits $z$:

$$\frac{\partial \mathcal{L}}{\partial z_i} \;=\; q_i - p_i.$$

Here $p$ is the one-hot target vector. The gradient of the cross-entropy loss with respect to the logits is the predicted distribution minus the target distribution. Element-wise. Done.
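The derivation is short enough to sketch. Write $t$ for the target index (a symbol introduced here for convenience) and $q_i = e^{z_i} / \sum_j e^{z_j}$, so that $p_i = \mathbb{1}[i = t]$:

```latex
\begin{aligned}
\mathcal{L} &= -\log q_t = -\log \frac{e^{z_t}}{\sum_j e^{z_j}} = -z_t + \log \sum_j e^{z_j}, \\
\frac{\partial \mathcal{L}}{\partial z_i} &= -\mathbb{1}[i = t] + \frac{e^{z_i}}{\sum_j e^{z_j}} = q_i - p_i.
\end{aligned}
```

The first term contributes $-1$ only on the target logit; the log-sum-exp term contributes $q_i$ on every logit.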

Watch the orange arrows on each logit bar in the widget. Their length and direction show this gradient. The arrow on the target class points down because its gradient $q_{\text{target}} - 1$ is negative, so gradient descent, which steps against the gradient, pushes the target logit higher. Arrows on non-target classes point up because their gradient $q_i$ is positive, so descent pushes those logits lower. As you drag the target logit up and $q$ approaches the one-hot label, the arrows shrink toward zero; gradient descent has almost nothing left to do.

The $q - p$ form is the cleanest gradient in supervised learning. It scales linearly in confidence-of-error, vanishes at convergence, and never has the saturation problems that other loss-activation pairs suffer.
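A quick way to convince yourself of the formula is a finite-difference check. The sketch below uses hypothetical logits (deliberately different from the exercise values) and compares the analytic gradient with a central-difference estimate; the helper names `softmax` and `nll` are illustrative.

```python
import math

def softmax(z):
    """Softmax with the max subtracted first for numerical safety."""
    m = max(z)
    exps = [math.exp(v - m) for v in z]
    s = sum(exps)
    return [e / s for e in exps]

def nll(z, target):
    """Negative log-probability of the target class."""
    return -math.log(softmax(z)[target])

z = [0.5, -1.0, 2.0, 0.0]  # hypothetical logits
target = 1

# Analytic gradient: predicted distribution minus one-hot target.
q = softmax(z)
p = [1.0 if i == target else 0.0 for i in range(len(z))]
analytic = [qi - pi for qi, pi in zip(q, p)]

# Numerical gradient: central differences, one logit at a time.
eps = 1e-6
numeric = []
for i in range(len(z)):
    z_plus = list(z)
    z_plus[i] += eps
    z_minus = list(z)
    z_minus[i] -= eps
    numeric.append((nll(z_plus, target) - nll(z_minus, target)) / (2 * eps))

for i, (a, num) in enumerate(zip(analytic, numeric)):
    print(f"logit {i}: analytic {a:+.6f}, numeric {num:+.6f}")
```

The two columns should agree to several decimals. The target entry is negative, every other entry is positive, and the entries sum to zero because both $q$ and $p$ sum to one.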

Gradient on the target logit

Using the same softmax from before ($z = (2, 0, 1)$, target class 0, $q \approx (0.665, 0.090, 0.245)$), what is $\partial \mathcal{L} / \partial z_0$? (Three decimals; sign matters.)

Why not softmax + MSE?

A reasonable question: why pair softmax with NLL specifically? Why not softmax outputs with mean-squared error against the one-hot label?

The answer is gradient behavior. With softmax + MSE, the gradient through softmax-then-squared-error involves a product of softmax derivatives, and when the model is confidently wrong (e.g., $q_{\text{target}} \approx 0$), that gradient is small. The model fails to learn from its worst mistakes.

Softmax + NLL doesn’t have this pathology. The gradient $q - p$ is large precisely when the model is confidently wrong (because $q_{\text{target}}$ is far from 1). The worse the mistake, the bigger the correction signal. This is why nearly every modern classifier, from logistic regression (whose sigmoid is the two-class case of softmax) to the output layer of GPT-style language models, is trained with softmax + cross-entropy.
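To see the difference in numbers, here is a small sketch comparing the two gradients on a confidently wrong prediction. The logits are hypothetical, and the squared error is summed over classes without a 1/2 or 1/K factor, which is just one possible convention.

```python
import math

def softmax(z):
    m = max(z)
    exps = [math.exp(v - m) for v in z]
    s = sum(exps)
    return [e / s for e in exps]

def nll_grad(q, p):
    """Gradient of -log q_target w.r.t. the logits: q - p."""
    return [qi - pi for qi, pi in zip(q, p)]

def mse_grad(q, p):
    """Gradient of sum_k (q_k - p_k)^2 w.r.t. the logits.

    Chain rule through softmax, using dq_k/dz_i = q_k * ((k == i) - q_i).
    """
    n = len(q)
    return [
        sum(
            2.0 * (q[k] - p[k]) * q[k] * ((1.0 if k == i else 0.0) - q[i])
            for k in range(n)
        )
        for i in range(n)
    ]

z = [-6.0, 6.0, 0.0]  # hypothetical logits: nearly all mass on class 1
target = 0            # but the true class is 0, so the model is confidently wrong

q = softmax(z)
p = [1.0 if i == target else 0.0 for i in range(len(z))]

g_nll = nll_grad(q, p)
g_mse = mse_grad(q, p)

print(f"q_target            = {q[target]:.2e}")
print(f"NLL grad on target  = {g_nll[target]:+.6f}")
print(f"MSE grad on target  = {g_mse[target]:+.2e}")
```

With these logits the NLL gradient on the target logit is close to -1, while the MSE gradient is several orders of magnitude smaller because every term carries a factor of the tiny $q_{\text{target}}$.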

PyTorch's nn.CrossEntropyLoss vs nn.NLLLoss: same loss, different APIs

Working ML engineers see two names in PyTorch and assume they’re different objects:

  • nn.CrossEntropyLoss()(logits, target_idx): applies log-softmax internally, then NLL.
  • nn.NLLLoss()(log_probs, target_idx): assumes you’ve already computed log-softmax.

The math is identical. The two-name split is a numerical stability choice: computing log-softmax in one fused pass (the “log-sum-exp trick”) is more stable than computing softmax, then taking the log of small probabilities. In practice, pass raw logits to CrossEntropyLoss and let it handle the fused computation. That’s the right default in 99% of cases.
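The stability argument is easy to reproduce without PyTorch. The sketch below is plain Python, not PyTorch’s actual implementation; the function names and the extreme logits are made up for illustration.

```python
import math

def naive_log_softmax(z):
    """Softmax first, then log. exp() can underflow to 0.0, and log(0.0) raises."""
    exps = [math.exp(v) for v in z]
    s = sum(exps)
    return [math.log(e / s) for e in exps]

def stable_log_softmax(z):
    """log q_i = z_i - logsumexp(z), subtracting the max before exponentiating."""
    m = max(z)
    lse = m + math.log(sum(math.exp(v - m) for v in z))
    return [v - lse for v in z]

z = [0.0, -1000.0]  # hypothetical extreme logits

try:
    naive = naive_log_softmax(z)
except ValueError as err:
    # exp(-1000) underflows to 0.0, so the second probability is exactly 0.
    naive = None
    print("naive version failed:", err)

stable = stable_log_softmax(z)
print("stable log-probs:", stable)
```

The naive path fails on the log of an underflowed probability, while the fused form returns a finite log-probability of -1000 for the second class. A loss that takes raw logits can always use the second route.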

(There is one final source of confusion: CrossEntropyLoss in PyTorch is usually given a class index, not a one-hot vector. That’s not a different formula; it’s just an API exploiting the one-hot collapse to skip allocating a length-$K$ one-hot tensor. Pass an integer, get the loss. Recent PyTorch versions also accept a tensor of class probabilities as the target, which covers soft labels.)

Where this lands: the entire final layer of GPT

A GPT-style transformer ends in two operations: a linear layer producing $V$-dimensional logits over the vocabulary, and a softmax-plus-cross-entropy loss against the next-token index. Forward, backward, repeat for every position.

Every claim in this lesson maps directly:

  • The softmax is there to produce a valid $Q$ over the vocab.
  • The target is the integer index of the next token in the corpus (one-hot in disguise).
  • The loss is the negative log probability the model gave to the actual next token.
  • The gradient with respect to the logits is $q - p$, where $p$ is one-hot.
  • Average that gradient over the batch, backprop, take an Adam step.
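The bullets above can be sketched in a few lines of plain Python. Everything here is a toy assumption: a vocabulary of 4 tokens, 3 positions, and hand-written logits standing in for the output of the final linear layer. The backprop and Adam step are left out.

```python
import math

def log_softmax(z):
    """Stable log-softmax: z_i - logsumexp(z)."""
    m = max(z)
    lse = m + math.log(sum(math.exp(v - m) for v in z))
    return [v - lse for v in z]

# Hypothetical logits: one row per position, one column per vocabulary entry.
logits = [
    [2.0, 0.5, -1.0, 0.0],
    [0.1, 0.2, 3.0, -0.5],
    [-1.0, 1.5, 0.0, 0.3],
]
# Hypothetical integer index of the actual next token at each position.
next_tokens = [0, 2, 3]

num_positions = len(logits)
total_loss = 0.0
grads = []  # gradient of the mean loss w.r.t. each position's logits

for z, y in zip(logits, next_tokens):
    log_q = log_softmax(z)
    # Loss at this position: negative log-probability of the actual next token.
    total_loss += -log_q[y]
    # Gradient at this position: (q - p), scaled by 1/num_positions for the mean.
    grads.append([
        (math.exp(lq) - (1.0 if k == y else 0.0)) / num_positions
        for k, lq in enumerate(log_q)
    ])

mean_loss = total_loss / num_positions
print(f"mean loss = {mean_loss:.4f} nats")
for t, row in enumerate(grads):
    print(f"position {t}: grad = {[round(g, 4) for g in row]}")
```

Each gradient row is negative at the true next token and positive elsewhere, and sums to zero. In a real model these rows would be passed backward through the linear layer and the transformer stack.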

That’s the whole loss function. The next module (m10 onwards) handles how to minimize it efficiently, but you now know exactly what it is.

Lesson 9.4 closes the module by giving you a unit-free way to read training losses: perplexity. It also introduces the floor that no language model can break: the entropy of the language itself.
