Why scalars do not scale
A linear layer that maps 768 inputs to 64 outputs holds $768 \times 64 = 49{,}152$ weights, plus 64 biases. With micrograd, that is roughly fifty thousand separate Value objects, each with its own _backward closure, evaluated one at a time.
That is fine for teaching. It is catastrophic for actually running anything. The fix is to repackage all of those scalars into one matrix and do the entire layer’s worth of arithmetic in a single matmul. The same chain rule applies — but now the gradients on the edges are matrices, not scalars, and the closures have to handle that.
The good news: the rules don’t change. Local derivative times upstream gradient, summed across paths. Only the types of the things being multiplied change.
Matmul backward, by shape-check
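For a single linear layer, the forward is
$$Y = XW, \qquad X \in \mathbb{R}^{B \times d_{\text{in}}},\quad W \in \mathbb{R}^{d_{\text{in}} \times d_{\text{out}}},\quad Y \in \mathbb{R}^{B \times d_{\text{out}}}$$
where $B$ is the batch size.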
The chain rule says we need $dW = \partial L / \partial W$ and $dX = \partial L / \partial X$ in terms of $dY = \partial L / \partial Y$ (which we know, because everything downstream deposited it). You can derive these from the scalar rule plus index sums, but here is the cheap trick: shape-check.
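We want $dW$ to have the same shape as $W$, which is $(d_{\text{in}}, d_{\text{out}})$. We have $X$ with shape $(B, d_{\text{in}})$ and $dY$ with shape $(B, d_{\text{out}})$. The only matmul that produces $(d_{\text{in}}, d_{\text{out}})$ from those two is
$$dW = X^\top dY,$$
since $X^\top$ has shape $(d_{\text{in}}, B)$.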
The same argument forces
$$dX = dY\, W^\top,$$
which has the right shape $(B, d_{\text{in}})$.
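If you want to see the shape-check in code, here is a minimal NumPy sketch, assuming batch-first arrays (one example per row). `linear_backward` is a hypothetical helper name, and `dY` is random noise standing in for whatever the downstream layers would deposit.

```python
import numpy as np

def linear_backward(X, W, dY):
    """Backward of Y = X @ W, given the upstream gradient dY = dL/dY."""
    dW = X.T @ dY   # (d_in, B) @ (B, d_out) -> (d_in, d_out), the shape of W
    dX = dY @ W.T   # (B, d_out) @ (d_out, d_in) -> (B, d_in), the shape of X
    return dW, dX

rng = np.random.default_rng(0)
B, d_in, d_out = 4, 3, 2
X = rng.standard_normal((B, d_in))
W = rng.standard_normal((d_in, d_out))
dY = rng.standard_normal((B, d_out))   # stand-in for the gradient from downstream

dW, dX = linear_backward(X, W, dY)
print(dW.shape, dX.shape)  # (3, 2) (4, 3)
```

Any other pairing fails the check: `dY.T @ X`, for instance, yields `(d_out, d_in)`.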
You can derive these from first principles by writing $Y_{ij} = \sum_k X_{ik} W_{kj}$ and computing $\partial L / \partial W_{kj}$ via the chain rule. The shape-check argument arrives at the same answer in five seconds.
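For the record, that first-principles route takes two lines. Only the entries in column $j$ of $Y$ depend on $W_{kj}$, and only the entries in row $i$ depend on $X_{ik}$, so each chain-rule sum runs over a single index:

```latex
\begin{aligned}
\frac{\partial L}{\partial W_{kj}}
  &= \sum_i \frac{\partial L}{\partial Y_{ij}}\,\frac{\partial Y_{ij}}{\partial W_{kj}}
   = \sum_i X_{ik}\, dY_{ij}
   = (X^\top dY)_{kj}, \\
\frac{\partial L}{\partial X_{ik}}
  &= \sum_j \frac{\partial L}{\partial Y_{ij}}\,\frac{\partial Y_{ij}}{\partial X_{ik}}
   = \sum_j dY_{ij}\, W_{kj}
   = (dY\, W^\top)_{ik}.
\end{aligned}
```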
Shape of dW
With , , and :
What is the second dimension of $dW$?
(The shape will be (100, ?).)
Elementwise nonlinearities, unchanged
For an elementwise operation $Y = f(X)$ (relu, tanh, or sigmoid applied to every entry), the backward is also elementwise. The chain rule per entry, written in tensor form:
$$dX = dY \odot f'(X)$$
where $\odot$ is the elementwise (Hadamard) product. No matmul, no shape acrobatics: just pointwise multiplication by the local derivative, with $\tanh'(x) = 1 - \tanh^2(x)$, $\sigma'(x) = \sigma(x)(1 - \sigma(x))$, and $\mathrm{relu}'(x) = [x > 0]$. These are the scalar formulas you already wrote in lesson 12.3, applied per entry.
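Here is a short NumPy sketch of the elementwise rule, using the three local derivatives listed above. The helper names are illustrative, not from any library.

```python
import numpy as np

# Local derivatives f'(x), evaluated entrywise on the forward input.
def tanh_grad(x):
    return 1.0 - np.tanh(x) ** 2

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def sigmoid_grad(x):
    s = sigmoid(x)
    return s * (1.0 - s)

def relu_grad(x):
    return (x > 0).astype(x.dtype)   # the indicator [x > 0] as 0.0 or 1.0

def elementwise_backward(X, dY, local_grad):
    # dX = dY ⊙ f'(X): plain `*` is the Hadamard product for same-shape arrays.
    return dY * local_grad(X)

rng = np.random.default_rng(1)
X = rng.standard_normal((4, 3))
dY = rng.standard_normal((4, 3))
dX = elementwise_backward(X, dY, tanh_grad)
print(dX.shape)  # (4, 3), the same shape as X and dY
```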
Broadcasting: forward replicates, backward sums
Bias vectors expose the trickiest tensor case. The forward is straightforward:
$$Y = XW + b$$
The bias $b$ has shape $(d_{\text{out}},)$ but $XW$ has shape $(B, d_{\text{out}})$. PyTorch “broadcasts” $b$ across the batch: it is conceptually replicated $B$ times so each row of $Y$ gets the same bias.
What is $db$? You have $dY$ with shape $(B, d_{\text{out}})$. You need something of shape $(d_{\text{out}},)$ to match $b$.
The rule, which generalizes across every tensor library and every broadcast pattern:
Every axis that was broadcast in the forward must be reduce-summed in the backward.
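The rule fits in one small helper. Below is a NumPy sketch, assuming NumPy's broadcasting conventions (shapes aligned from the right, missing leading axes added, size-1 axes stretched); `sum_to_shape` is a hypothetical name.

```python
import numpy as np

def sum_to_shape(grad, shape):
    """Reduce-sum `grad` over every axis that broadcasting added or stretched,
    so the result has `shape`, the shape of the original un-broadcast input."""
    # Axes that broadcasting prepended on the left: sum them away entirely.
    extra = grad.ndim - len(shape)
    if extra > 0:
        grad = grad.sum(axis=tuple(range(extra)))
    # Axes that were size 1 and got stretched: sum, but keep them as size 1.
    stretched = tuple(i for i, n in enumerate(shape) if n == 1 and grad.shape[i] != 1)
    if stretched:
        grad = grad.sum(axis=stretched, keepdims=True)
    return grad

dY = np.ones((5, 3))
print(sum_to_shape(dY, (3,)))          # [5. 5. 5.]
print(sum_to_shape(dY, (5, 1)).shape)  # (5, 1)
```

PyTorch exposes a comparable reduction as `Tensor.sum_to_size`.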
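For a bias broadcast across the batch axis, the gradient is dY.sum(axis=0): each of the $B$ copies of $b$ sends back its own row of gradient, and summing across those paths adds them up.
Summing over axis 1 instead gives a vector of length $B$, the wrong shape; summing over axis 0 gives shape $(d_{\text{out}},)$, matching $b$.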
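To check the bias rule numerically, here is a sketch built on a toy loss $L = \sum_{ij} Y_{ij} G_{ij}$, chosen (an assumption of this example, not of the lesson) because its $dY$ is exactly the fixed array `G`. Central finite differences on `b` should then agree with `dY.sum(axis=0)`.

```python
import numpy as np

rng = np.random.default_rng(2)
B, d_in, d_out = 6, 4, 3
X = rng.standard_normal((B, d_in))
W = rng.standard_normal((d_in, d_out))
b = rng.standard_normal(d_out)
G = rng.standard_normal((B, d_out))   # fixed weights defining the toy loss

def loss(b_):
    Y = X @ W + b_          # b_ has shape (d_out,) and broadcasts across the B rows
    return np.sum(Y * G)    # toy loss, so dL/dY is exactly G

dY = G
db = dY.sum(axis=0)         # reduce over the broadcast (batch) axis
wrong = dY.sum(axis=1)      # shape (B,): does not match b

# Central finite differences, one bias entry at a time.
eps = 1e-6
num = np.array([(loss(b + eps * e) - loss(b - eps * e)) / (2 * eps) for e in np.eye(d_out)])
print(db.shape, wrong.shape)              # (3,) (6,)
print(np.allclose(db, num, atol=1e-6))    # True
```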
The bias backward
Forward: with shape and shape .
You have with shape . What is the only number in the shape of ?
The softmax + cross-entropy collapse
The output layer of a standard classifier looks the same every time: a linear layer produces logits $z$, softmax turns those into probabilities $p$, and cross-entropy with a one-hot target $y$ produces a scalar loss $L$:
$$p_i = \frac{e^{z_i}}{\sum_j e^{z_j}}, \qquad L = -\sum_i y_i \log p_i$$
A naive autograd implementation runs the backward through both functions: first the gradient of $L$ with respect to $p$ (which is $-y_i / p_i$ component-wise), then the Jacobian of softmax, $\partial p_i / \partial z_j = p_i(\delta_{ij} - p_j)$ (a full matrix), and multiplies them.
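The algebra collapses miraculously. Working through the chain rule:
$$\frac{\partial L}{\partial z_j} = \sum_i \left(-\frac{y_i}{p_i}\right) p_i(\delta_{ij} - p_j) = -y_j + p_j \sum_i y_i = p_j - y_j,$$
using $\sum_i y_i = 1$ for a one-hot target. The whole ugly Jacobian-times-gradient stack reduces to:
$$\frac{\partial L}{\partial z} = p - y$$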
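The collapse is easy to confirm numerically. Here is a sketch with made-up logits and a made-up one-hot target: it builds the full softmax Jacobian the naive way and compares the result against $p - y$.

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())   # shifting by max(z) leaves the softmax unchanged
    return e / e.sum()

z = np.array([2.0, 1.0, -1.0])   # illustrative logits
y = np.array([0.0, 1.0, 0.0])    # illustrative one-hot target, class 1
p = softmax(z)

# Naive path: dL/dp, then the full Jacobian J[i, j] = p_i * (delta_ij - p_j).
dL_dp = -y / p
J = np.diag(p) - np.outer(p, p)
naive = J.T @ dL_dp              # chain rule: dL/dz_j = sum_i dL/dp_i * J[i, j]

collapsed = p - y
print(np.allclose(naive, collapsed))  # True
```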
In words: the gradient of softmax cross-entropy with respect to the logits is the prediction error. Every classifier trained with softmax cross-entropy relies on this one identity. PyTorch’s nn.CrossEntropyLoss takes raw logits and combines log-softmax with the negative log-likelihood in a single criterion, specifically so that this collapse is exploited and numerical stability (log-sum-exp) is handled in one place.
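For a batch, a fused version computes log-softmax with the log-sum-exp shift and returns $p - y$ directly, without ever forming the Jacobian. This is a sketch, not PyTorch's implementation: `softmax_cross_entropy` is a hypothetical name, targets are integer class indices, and the loss is averaged over the batch (mirroring the default `reduction='mean'` of `nn.CrossEntropyLoss`), which is why the gradient is divided by $B$.

```python
import numpy as np

def softmax_cross_entropy(Z, labels):
    """Mean softmax cross-entropy over a batch, plus its gradient w.r.t. the logits.

    Z: (B, C) logits. labels: (B,) integer class indices.
    """
    B = Z.shape[0]
    rows = np.arange(B)
    shifted = Z - Z.max(axis=1, keepdims=True)       # log-sum-exp shift: no overflow
    log_norm = np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    log_p = shifted - log_norm                       # log-softmax, computed stably
    loss = -log_p[rows, labels].mean()
    dZ = np.exp(log_p)                               # p
    dZ[rows, labels] -= 1.0                          # p - y, without building y
    return loss, dZ / B                              # 1/B comes from the mean

Z = np.array([[1000.0, 0.0, -1000.0],   # huge logits: a naive exp() would overflow
              [0.5, 0.5, 0.5]])
labels = np.array([0, 2])
loss, dZ = softmax_cross_entropy(Z, labels)
print(np.isfinite(loss), dZ.shape)  # True (2, 3)
```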
One backward pass through the classifier head
Logits , target (one-hot, class 0).
Compute $p = \mathrm{softmax}(z)$ and then $\partial L / \partial z = p - y$.
What is the middle entry of $\partial L / \partial z$, to three decimals?
Three identities, one engine
You now have the three tensor identities that show up in every neural network’s backward pass:
$$dW = X^\top dY, \qquad dX = dY\, W^\top, \qquad db = \sum_{i=1}^{B} dY_{i,:}$$
Plus the elementwise rule for nonlinearities, plus the softmax+CE collapse. That is the entire tensor backward toolkit for a modern feed-forward network. A transformer layer adds a few more shapes (multi-head reshape, layer norm) but every individual op uses one of these patterns.
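Putting the pieces together, here is a sketch of the full backward pass for a two-layer network (linear, tanh, linear, softmax cross-entropy), written by hand using nothing but the three identities, the elementwise rule, and the softmax+CE collapse, then spot-checked against a central finite difference. The sizes, random initialization, and mean-over-batch loss are arbitrary choices for the example.

```python
import numpy as np

rng = np.random.default_rng(4)
B, d_in, d_hid, C = 5, 4, 6, 3
X = rng.standard_normal((B, d_in))
labels = rng.integers(0, C, size=B)
W1, b1 = rng.standard_normal((d_in, d_hid)) * 0.5, np.zeros(d_hid)
W2, b2 = rng.standard_normal((d_hid, C)) * 0.5, np.zeros(C)

def forward(W1, b1, W2, b2):
    H_pre = X @ W1 + b1                    # linear layer with broadcast bias
    H = np.tanh(H_pre)                     # elementwise nonlinearity
    Z = H @ W2 + b2                        # logits
    Z = Z - Z.max(axis=1, keepdims=True)   # stability shift; softmax is unchanged
    P = np.exp(Z) / np.exp(Z).sum(axis=1, keepdims=True)
    loss = -np.log(P[np.arange(B), labels]).mean()
    return loss, (H, P)

loss, (H, P) = forward(W1, b1, W2, b2)

# Backward, one rule per line.
dZ = P.copy()
dZ[np.arange(B), labels] -= 1.0
dZ /= B                          # softmax+CE collapse: (p - y) / B for the mean loss
dW2 = H.T @ dZ                   # dW = X^T dY
db2 = dZ.sum(axis=0)             # broadcast axis gets reduce-summed
dH = dZ @ W2.T                   # dX = dY W^T
dH_pre = dH * (1.0 - H ** 2)     # elementwise: tanh' = 1 - tanh^2
dW1 = X.T @ dH_pre
db1 = dH_pre.sum(axis=0)

# Spot-check one weight against a central finite difference.
eps = 1e-6
W1p, W1m = W1.copy(), W1.copy()
W1p[1, 2] += eps
W1m[1, 2] -= eps
num = (forward(W1p, b1, W2, b2)[0] - forward(W1m, b1, W2, b2)[0]) / (2 * eps)
print(abs(num - dW1[1, 2]) < 1e-6)   # True
```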
PyTorch and JAX hide all of this: PyTorch records ops on a tape as you run them, and JAX traces your function before differentiating it. But what they record, the backward rule for each op, is exactly what you just learned. There is no other secret.
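To make “the tape’s contents are the rules” concrete, here is a toy reverse-mode tape in NumPy. It is a hypothetical teaching sketch, not how PyTorch or JAX are actually built: each op computes its forward result and appends a closure holding its backward rule, and `backward()` replays the closures in reverse order, summing gradients that arrive along multiple paths. Arrays are keyed by `id()`, which is only safe here because the closures keep every recorded array alive.

```python
import numpy as np

class Tape:
    """Toy reverse-mode tape (hypothetical, for illustration only)."""

    def __init__(self):
        self.records = []   # backward closures, in the order the ops ran
        self.grads = {}     # id(array) -> accumulated gradient

    def _accumulate(self, x, g):
        # Sum across paths: a value used twice receives both contributions.
        self.grads[id(x)] = self.grads.get(id(x), 0) + g

    def _record(self, out, rule):
        def back():
            d_out = self.grads.get(id(out))
            if d_out is not None:          # skip ops that do not feed the output
                rule(d_out)
        self.records.append(back)
        return out

    def matmul(self, X, W):
        Y = X @ W
        def rule(dY):
            self._accumulate(X, dY @ W.T)   # dX = dY W^T
            self._accumulate(W, X.T @ dY)   # dW = X^T dY
        return self._record(Y, rule)

    def add_bias(self, X, b):
        Y = X + b                           # b broadcasts across the batch axis
        def rule(dY):
            self._accumulate(X, dY)
            self._accumulate(b, dY.sum(axis=0))   # reduce over the broadcast axis
        return self._record(Y, rule)

    def tanh(self, X):
        Y = np.tanh(X)
        def rule(dY):
            self._accumulate(X, dY * (1.0 - Y ** 2))   # elementwise rule
        return self._record(Y, rule)

    def backward(self, out, d_out):
        self.grads[id(out)] = d_out
        for back in reversed(self.records):   # reverse of forward order
            back()

    def grad(self, x):
        return self.grads[id(x)]

rng = np.random.default_rng(5)
X = rng.standard_normal((4, 3))
W = rng.standard_normal((3, 2))
b = rng.standard_normal(2)

tape = Tape()
H = tape.tanh(tape.add_bias(tape.matmul(X, W), b))
tape.backward(H, np.ones_like(H))   # seed: pretend L = H.sum(), so dL/dH = 1

# The same numbers by hand, applying the identities directly.
dPre = np.ones_like(H) * (1.0 - np.tanh(X @ W + b) ** 2)
print(np.allclose(tape.grad(W), X.T @ dPre))          # True
print(np.allclose(tape.grad(b), dPre.sum(axis=0)))    # True
```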
One lesson left: the four classic backprop bugs and the test that catches all of them. After that, you have shipped m12.