Attention · 18 min

Three projections of the same X

Self-attention is the same operation as in the last lesson, applied in parallel. Every token is both a querier and a query target. Q, K, V are three linear projections of the same X. Causal masking takes one extra tensor, and a permutation-equivariance bug comes along for free.


Every position queries, every position answers

The last two lessons built attention with one query against $T$ keys. That is enough to define the operation but not enough to do useful work: a transformer layer needs every position to attend to every other position, in one parallel sweep.

The leap is small. Take the input sequence $X \in \mathbb{R}^{T\times d}$, where each row is a token's vector. We're going to make every row both a querier and a query target: same data, three roles.

Three projections of the same X

For each token vector, build a query, a key, and a value via three independent linear layers:

$$Q = X W_Q,\qquad K = X W_K,\qquad V = X W_V$$

with $W_Q, W_K \in \mathbb{R}^{d \times d_k}$ and $W_V \in \mathbb{R}^{d \times d_v}$. So $Q$, $K$, $V$ are matrices of shape $(T, d_k)$, $(T, d_k)$, $(T, d_v)$. Each row is the query, key, or value for one token.
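A minimal NumPy sketch of the three projections. The random matrices are stand-ins for learned weights, and the sizes are arbitrary illustrative choices; $d_k \ne d_v$ is picked on purpose so the value width visibly differs from the query/key width.

```python
import numpy as np

rng = np.random.default_rng(0)
T, d, d_k, d_v = 5, 8, 4, 6           # illustrative sizes; d_k != d_v on purpose

X = rng.standard_normal((T, d))       # one row per token
W_Q = rng.standard_normal((d, d_k))   # random stand-ins for learned weights
W_K = rng.standard_normal((d, d_k))
W_V = rng.standard_normal((d, d_v))

Q, K, V = X @ W_Q, X @ W_K, X @ W_V   # three views of the same X

print(Q.shape, K.shape, V.shape)      # (5, 4) (5, 4) (5, 6)

# Row i of Q is token i's query: projecting that single row gives the same vector.
i = 2
assert np.allclose(Q[i], X[i] @ W_Q)
```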

In real code the three projections are usually fused into one matrix multiply for speed:

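import torch, torch.nn as nn
d, x = 8, torch.randn(4, 8)                 # toy input: T = 4 tokens, model width d = 8
c_attn = nn.Linear(d, 3 * d, bias=False)    # W_Q, W_K, W_V fused; nn.Linear adds a bias unless bias=False
q, k, v = c_attn(x).split(d, dim=-1)        # one matmul, then three (T, d) views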

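This is the canonical nanoGPT pattern (nanoGPT calls the fused layer c_attn), and it is why a learner who sees qkv instead of q, k, v in code shouldn't panic: they are the same three projections, packed side by side. Splitting into three equal chunks of width d assumes $d_k = d_v = d$.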
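To check the "same three projections, packed" claim numerically, here is a minimal NumPy sketch (random weights as stand-ins, sizes chosen for illustration). It relies on the fact that a matrix product acts column block by column block: $X[W_Q \; W_K \; W_V] = [XW_Q \; XW_K \; XW_V]$. Splitting at column offsets rather than into equal chunks also covers the case $d_k \ne d_v$.

```python
import numpy as np

rng = np.random.default_rng(1)
T, d, d_k, d_v = 5, 8, 4, 6
X = rng.standard_normal((T, d))
W_Q = rng.standard_normal((d, d_k))
W_K = rng.standard_normal((d, d_k))
W_V = rng.standard_normal((d, d_v))

# Fuse: place the three weight matrices side by side, shape (d, 2*d_k + d_v).
W_qkv = np.concatenate([W_Q, W_K, W_V], axis=1)

qkv = X @ W_qkv                                   # one matmul, shape (T, 2*d_k + d_v)
q, k, v = np.split(qkv, [d_k, 2 * d_k], axis=1)   # cut the columns back apart at offsets d_k and 2*d_k

# The fused result matches three separate projections.
assert np.allclose(q, X @ W_Q)
assert np.allclose(k, X @ W_K)
assert np.allclose(v, X @ W_V)
```

When $d_k = d_v = d$ the three chunks have equal width, which is why the PyTorch version can simply call split(d).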

The point worth carrying with you: $Q$, $K$, $V$ are not three different inputs. They are three views of the same $X$, computed by three different matrices. A single token at position $i$ contributes its query $q_i$ (when it is doing the asking), its key $k_i$ (when others might match against it), and its value $v_i$ (the payload it delivers if matched).

The matrix form, in one line

Apply scaled dot-product attention to $Q, K, V$ in parallel:

$$\boxed{\;\mathrm{Attention}(Q, K, V) \;=\; \mathrm{softmax}\!\left(\frac{Q K^\top}{\sqrt{d_k}}\right) V\;}$$

That is the equation you came to this module for. Let’s walk through it shape by shape.

  • $Q K^\top$ is a $(T, T)$ matrix of all pairwise scores. Cell $(i, j)$ is $q_i \cdot k_j$: how strongly query $i$ matches key $j$.
  • Divide by $\sqrt{d_k}$: last lesson's variance fix.
  • Softmax along the last axis (the key axis): every row becomes a probability distribution over the $T$ keys.
  • Multiply by $V$: each row of the output is a weighted average of the value vectors, weighted by that row's attention distribution.
  • Output: $(T, d_v)$. Same number of rows as the input, one per token, but each row is now a mixture of value vectors drawn from across the sequence.

The “shape passport” (the chain of tensor shapes) is the single most useful debugging tool when reading attention code. If you can recite it, you can localize a bug in someone else’s transformer in about forty-five seconds.
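A minimal NumPy sketch of the boxed equation that asserts every shape in the passport as it goes. The helper names softmax and attention are ours, and random $Q$, $K$, $V$ stand in for projected inputs. The sizes are deliberately all different: if $T$, $d_k$ and $d_v$ were equal, swapping K and V or transposing the wrong operand could run without error, whereas here it trips an assertion.

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)   # subtract the row max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def attention(Q, K, V):
    """Scaled dot-product attention, checking the shape passport at each step."""
    T_q, d_k = Q.shape
    T_k, d_v = V.shape
    assert K.shape == (T_k, d_k)
    scores = Q @ K.T / np.sqrt(d_k)            # (T_q, T_k): cell (i, j) = q_i . k_j / sqrt(d_k)
    assert scores.shape == (T_q, T_k)
    weights = softmax(scores, axis=-1)         # each row: a distribution over the keys
    assert np.allclose(weights.sum(axis=-1), 1.0)
    out = weights @ V                          # (T_q, d_v): weighted averages of value rows
    assert out.shape == (T_q, d_v)
    return out, weights

rng = np.random.default_rng(0)
T, d_k, d_v = 5, 4, 6
Q = rng.standard_normal((T, d_k))
K = rng.standard_normal((T, d_k))
V = rng.standard_normal((T, d_v))
out, weights = attention(Q, K, V)
print(weights.shape, out.shape)   # (5, 5) (5, 6)
```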

Shape passport drill

With $T = 4$ tokens, $d = 8$ model dimension, and $d_k = 4$ key dimension, $Q K^\top$ has shape $(T, T) = (4, 4)$. How many scalar entries does $Q K^\top$ have?

Self vs. cross: same op, different bindings

A quick clarifier in case you read older transformer papers and got confused by encoder–decoder diagrams.

  • Self-attention: $Q$, $K$, $V$ all come from the same sequence: $X = X^{(\mathrm{src})}$, and the formula above is unchanged. This is what GPT-style decoder-only models use.
  • Cross-attention: $Q$ comes from one sequence, $K$ and $V$ from another. Used in encoder–decoder translation models, where the decoder's queries attend to the encoder's keys/values. Same equation, different argument bindings.

The implementation is identical. The codepath is one function. Don’t think of cross-attention as a different mechanism: think of it as the same Attention(Q, K, V) call with different inputs in two of its three slots. We will spend the rest of this module on self-attention because it is what a decoder-only transformer needs.
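A minimal NumPy sketch of "one function, different bindings". The decoder and encoder inputs are hypothetical random arrays, and reusing one set of projection weights for both calls is a simplification to keep the sketch short; a real encoder–decoder model gives each attention sublayer its own $W_Q$, $W_K$, $W_V$. Note that in cross-attention the weight matrix is $(T_\mathrm{dec}, T_\mathrm{enc})$, not square.

```python
import numpy as np

def attention(Q, K, V):
    # Plain scaled dot-product attention; nothing here knows "self" from "cross".
    scores = Q @ K.T / np.sqrt(Q.shape[-1])
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ V, weights

rng = np.random.default_rng(0)
d, d_k, d_v = 8, 4, 6
T_dec, T_enc = 3, 7                          # the two sequences have different lengths
X_dec = rng.standard_normal((T_dec, d))      # hypothetical decoder states
X_enc = rng.standard_normal((T_enc, d))      # hypothetical encoder outputs
W_Q = rng.standard_normal((d, d_k))          # shared across both calls only for brevity
W_K = rng.standard_normal((d, d_k))
W_V = rng.standard_normal((d, d_v))

# Self-attention: all three slots bound to the same sequence.
self_out, self_w = attention(X_dec @ W_Q, X_dec @ W_K, X_dec @ W_V)
# Cross-attention: queries from the decoder, keys and values from the encoder.
cross_out, cross_w = attention(X_dec @ W_Q, X_enc @ W_K, X_enc @ W_V)

print(self_w.shape, cross_w.shape)       # (3, 3) (3, 7)
print(self_out.shape, cross_out.shape)   # (3, 6) (3, 6)
```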

Mask the future

At training time, a decoder-only language model wants every position to predict the next token from positions up to and including itself. So position $i$ must not be allowed to read positions $j > i$. The future is what we are trying to predict: leaking it would let the model copy the answer instead of learning to predict it.

But at training time we do feed the entire $T$-length sequence into the layer in one forward pass. We don't recompute attention $T$ times for $T$ different prefixes. So the question is: how do you, in a single $\mathrm{softmax}(Q K^\top / \sqrt{d_k})\, V$ call, force every row $i$ to ignore the positions $j > i$?

You add a mask $M$ to the scores before the softmax, with $-\infty$ strictly above the diagonal and $0$ on and below it:

$$M_{ij} = \begin{cases} 0 & j \le i \\ -\infty & j > i \end{cases}$$

After softmax, $\exp(-\infty) = 0$, so every entry above the diagonal becomes a hard zero. The remaining entries renormalize across the unmasked positions, and each row is still a valid probability distribution.
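A minimal NumPy sketch of the mask step, assuming random numbers in place of real scaled scores $QK^\top/\sqrt{d_k}$. np.triu with k=1 keeps the strictly upper triangle of a matrix filled with $-\infty$ and zeros everything on and below the diagonal, which is exactly $M$ above. Because the diagonal is never masked, every row keeps at least one finite score, so the softmax never divides by zero.

```python
import numpy as np

rng = np.random.default_rng(0)
T = 4
scores = rng.standard_normal((T, T))          # stand-in for Q @ K.T / sqrt(d_k)

# M[i, j] = 0 for j <= i and -inf for j > i (strictly upper triangle).
M = np.triu(np.full((T, T), -np.inf), k=1)

masked = scores + M
masked = masked - masked.max(axis=-1, keepdims=True)  # row max is finite: the diagonal is unmasked
weights = np.exp(masked)                              # exp(-inf) evaluates to 0.0
weights = weights / weights.sum(axis=-1, keepdims=True)

print(np.round(weights, 2))   # lower-triangular; row 0 puts all its weight on key 0
```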

Interactive heatmap with a causal-mask toggle, shown with the mask off (every query sees every key, encoder-style). Rows: queries · columns: keys · cell = α_ij ∈ [0, 1] · every row sums to 1.

         t₁(k)  t₂(k)  t₃(k)  t₄(k)  t₅(k)  t₆(k)    Σ
t₁(q)    0.46   0.08   0.13   0.11   0.07   0.14   1.00
t₂(q)    0.19   0.34   0.15   0.09   0.14   0.08   1.00
t₃(q)    0.13   0.22   0.32   0.14   0.07   0.12   1.00
t₄(q)    0.08   0.15   0.20   0.34   0.14   0.09   1.00
t₅(q)    0.14   0.09   0.12   0.19   0.29   0.17   1.00
t₆(q)    0.42   0.10   0.07   0.11   0.09   0.21   1.00

Toggle the mask off and back on. Off: every query reads every key (this is what an encoder such as BERT does). On: every row's distribution lives only on the diagonal and below; the upper triangle is zeroed. The row sums on the right stay at 1.00 either way, because the masking happens in score space, not weight space.

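The same $Q$, $K$, $V$ computation trains as a bidirectional encoder or as a causal decoder, depending on a single line of code: whether you add the mask. Inside the attention layer, that is the whole architectural difference between BERT and GPT, and it lives in this one tensor. (The two also differ in training objective: BERT predicts masked-out tokens, GPT predicts the next token.)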

Inside the mask

With the causal mask on at $T = 4$, what is the value of the attention weight $\alpha_{2,3}$ (query at position 2 attending to key at position 3)?

Aside on interpretability. It is tempting to look at the heatmap above and read off “the model’s reasoning.” Don’t. Jain and Wallace (NAACL 2019) showed that one can construct alternative attention distributions that yield the same predictions, and that attention is often uncorrelated with gradient-based feature importance. Wiegreffe and Pinter (EMNLP 2019) refined the picture but did not rescue the strong claim. Treat attention heatmaps as a useful diagnostic, not a faithful explanation.

The bug nobody mentions out loud

There is something quietly broken about the operation you just built. Watch.

Interactive widget: tap two chips to swap their tokens; a toggle turns positional encoding on or off. Original ordering: [a, b, c, d].

Start with the chips in identity order $[a, b, c, d]$ and the positional encoding off. Tap any two chips to swap them. Notice the verdict: outputs match permutation. What's happening: when you swap positions $i$ and $j$ in $X$, the rows of $Q$, $K$, $V$ swap; the rows and columns of $Q K^\top$ swap; the softmax operates row-wise, so it commutes with that swap; and the final multiplication by $V$ produces an output that is just $\mathrm{Attention}(X)$ with the corresponding rows swapped.

Formally:

$$\mathrm{SelfAttn}(P X) \;=\; P\,\mathrm{SelfAttn}(X) \quad \text{for any permutation matrix } P.$$
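The prose argument can be written out as a short derivation. The symbol $S$ for the scaled score matrix is our notation. Two facts do the work: a permutation matrix satisfies $P^\top P = I$, and a row-wise softmax applied to a matrix whose rows and columns are reordered by the same permutation returns the softmax of the original, reordered the same way. The derivation assumes no causal mask: with $M$ added, the score matrix for $PX$ is $PSP^\top + M$, and since $PMP^\top \ne M$ in general, the fourth line no longer applies.

```latex
\begin{aligned}
S &= \frac{Q K^\top}{\sqrt{d_k}}, \qquad \mathrm{SelfAttn}(X) = \mathrm{softmax}(S)\,V \\
(PX)W_Q &= PQ, \qquad (PX)W_K = PK, \qquad (PX)W_V = PV \\
\frac{(PQ)(PK)^\top}{\sqrt{d_k}} &= \frac{P\,Q K^\top P^\top}{\sqrt{d_k}} = P\,S\,P^\top \\
\mathrm{softmax}(P\,S\,P^\top) &= P\,\mathrm{softmax}(S)\,P^\top \qquad \text{(row-wise softmax)} \\
\mathrm{SelfAttn}(PX) &= P\,\mathrm{softmax}(S)\,P^\top P\,V = P\,\mathrm{softmax}(S)\,V = P\,\mathrm{SelfAttn}(X)
\end{aligned}
```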

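That property is called permutation equivariance. It is the reason a transformer is fast and parallel: every position is processed identically, with no recurrence and no fixed window. It is also the reason a transformer, in the form we have built so far, literally cannot tell dog bites man from man bites dog. The model has no notion of position. It operates over sets, not sequences. (Strictly, the identity holds for attention without the causal mask. The mask is itself position-dependent, so masked attention is not exactly equivariant, but it still carries no explicit information about where each token sits.)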
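A minimal NumPy sketch that checks the identity numerically for unmasked self-attention with no positional information. The weights and inputs are random stand-ins, and the particular swap (positions 0 and 2) is an arbitrary choice mirroring a two-chip swap in the widget. np.eye(T)[perm] builds a permutation matrix whose row i picks out input row perm[i].

```python
import numpy as np

def self_attn(X, W_Q, W_K, W_V):
    # Unmasked scaled dot-product self-attention, no positional information.
    Q, K, V = X @ W_Q, X @ W_K, X @ W_V
    S = Q @ K.T / np.sqrt(K.shape[-1])
    S = S - S.max(axis=-1, keepdims=True)
    A = np.exp(S)
    A = A / A.sum(axis=-1, keepdims=True)
    return A @ V

rng = np.random.default_rng(0)
T, d, d_k = 4, 8, 4
X = rng.standard_normal((T, d))
W_Q, W_K, W_V = (rng.standard_normal((d, d_k)) for _ in range(3))

perm = np.array([2, 1, 0, 3])     # swap the tokens at positions 0 and 2
P = np.eye(T)[perm]               # permutation matrix: (P @ X)[i] == X[perm[i]]

Y = self_attn(X, W_Q, W_K, W_V)
Y_perm = self_attn(P @ X, W_Q, W_K, W_V)
print(np.abs(Y_perm - P @ Y).max())   # only floating-point round-off
```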

Permutation equivariance, in one number

With positional encoding off and any permutation $P$ applied to the input, what is $\|Y' - P\,Y\|$ (the difference between attention on the permuted input and the permutation of the original output)?

The fix is the next lesson

Now flip positional encoding on in the widget and swap two chips. The verdict flips to outputs do not match. Adding position information to the inputs breaks the permutation equivariance, on purpose. We want the model to know that position 0 is the start of the sequence and position $T-1$ is the most recent token, because the meaning of a sentence depends on order.
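A minimal NumPy sketch of what the widget's toggle demonstrates. The per-position vectors pos are hypothetical random placeholders, not any real encoding scheme. The key modeling choice is that position vectors belong to slots, not tokens, so they are added after the swap; that is exactly what breaks the symmetry.

```python
import numpy as np

def self_attn(X, W_Q, W_K, W_V):
    # Unmasked scaled dot-product self-attention.
    Q, K, V = X @ W_Q, X @ W_K, X @ W_V
    S = Q @ K.T / np.sqrt(K.shape[-1])
    S = S - S.max(axis=-1, keepdims=True)
    A = np.exp(S)
    return (A / A.sum(axis=-1, keepdims=True)) @ V

rng = np.random.default_rng(0)
T, d, d_k = 4, 8, 4
X = rng.standard_normal((T, d))      # token content
pos = rng.standard_normal((T, d))    # hypothetical per-position vectors (placeholder only)
W_Q, W_K, W_V = (rng.standard_normal((d, d_k)) for _ in range(3))

P = np.eye(T)[[2, 1, 0, 3]]          # swap the tokens at positions 0 and 2

# Tokens move under P; the position vectors stay attached to their slots.
Y = self_attn(X + pos, W_Q, W_K, W_V)
Y_perm = self_attn(P @ X + pos, W_Q, W_K, W_V)
print(np.abs(Y_perm - P @ Y).max())  # no longer at round-off level: order now matters
```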

There are several ways to inject position. Sinusoidal encodings, learned absolute encodings, rotary encodings: each has a different geometry and a different trade-off. That’s what the next lesson covers.

The key thing to leave this lesson with: positional encoding is not decoration. It is the fix to a specific, named symptom: the fact that scaled dot-product attention, on its own, is blind to order. Once you have seen the bug, the patch is unsurprising.
