Attention

Soft dictionary lookup

A hash table maps a query to one entry. A soft hash table returns every entry, weighted by similarity. Replace argmax with softmax and the lookup becomes differentiable. That's attention; the rest of this module is plumbing.


What every previous model couldn't do

The bigram knew one character. The Bengio MLP knew $k$ characters. The RNN tried to know everything in a single carrying vector and faded after about ten characters. None of these models can look up anything from earlier in the sequence on demand. Each one is forced to compress the past into a fixed-size container before getting to use it.

What we want is something different: at each position, given the current state, go back and read the parts of the past that are relevant. Not all of it. Not the most recent $k$ tokens. The relevant parts.

That’s a lookup.

Hard lookup is dead in the water

The lookup operation we know from code is a hash table. Compute a query, hash it, get the matching entry. Trouble is that exact matching is brittle and non-differentiable. Two queries that are nearly identical can return totally different entries. There’s no gradient telling us “this query was almost right.”

For a learnable model we need something gentler. The first instinct is argmax: find the key most aligned with the query and return its value. The score for key $j$ is the dot product $q\cdot k_j$: large and positive when the two vectors point the same way, near zero when they are orthogonal, and negative when they point in opposite directions. Argmax picks the largest.
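A minimal NumPy sketch of the scoring step and the argmax pick. The query, keys, and values here are made-up numbers chosen so that one key points along $q$, one is partly aligned, and one points the opposite way; they are not the widget's vectors.

```python
import numpy as np

q = np.array([1.0, 0.5])                  # hypothetical query
keys = np.array([[ 2.0,  1.0],            # same direction as q
                 [ 0.0,  1.0],            # partly aligned
                 [-1.0, -0.5]])           # opposite direction
values = np.array([10.0, 20.0, 30.0])     # hypothetical payloads, one per key

scores = keys @ q                         # s_j = q . k_j for every key at once: [2.5, 0.5, -1.25]
best = int(np.argmax(scores))             # hard lookup: index of the top score
print(best, values[best])                 # → 0 10.0
```

Stacking the keys as rows lets one matrix-vector product compute every score; the hard lookup then keeps only the winner's value and discards the rest.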

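[Interactive figure “lookup”: drag $q$, $k_1$, or $k_2$ in the plane; readouts show the scores $q\cdot k_j$, the weights $\alpha_j$, and the output $o = \alpha_1 k_1 + \alpha_2 k_2$ (in the captured state, $q\cdot k_1 = 2.92$, $q\cdot k_2 = 0.64$, $\alpha_1 = 91\%$, $\alpha_2 = 9\%$). For now, the values are the keys themselves.]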

Set the toggle to hard and drag the red query around. Whichever key has the larger dot product wins; the output snaps to that key with no contribution from the other one. The yellow output vector teleports between $k_1$ and $k_2$ as you cross the boundary.

That snap is the problem. The score is continuous in $q$, but the output is a step function in the score, and a step function has zero gradient almost everywhere. The model can’t learn through it.
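To see the zero gradient numerically, here is a sketch that estimates the derivative of a hard lookup with central finite differences. The keys, the scalar values 10 and 20, and the helper names `hard_lookup` and `finite_diff` are illustrative assumptions.

```python
import numpy as np

def hard_lookup(q, keys, values):
    # Return the value of the key with the largest score q . k_j (argmax, no blending).
    return values[int(np.argmax(keys @ q))]

keys = np.array([[1.0, 0.0], [0.0, 1.0]])   # k_1, k_2
values = np.array([10.0, 20.0])             # hypothetical scalar values

def finite_diff(f, q, eps=1e-4):
    """Central-difference estimate of df/dq, one component at a time."""
    grad = np.zeros_like(q)
    for i in range(len(q)):
        step = np.zeros_like(q)
        step[i] = eps
        grad[i] = (f(q + step) - f(q - step)) / (2 * eps)
    return grad

f = lambda q: hard_lookup(q, keys, values)
print(finite_diff(f, np.array([0.9, 0.3])))            # → [0. 0.]  flat: no learning signal
print(f(np.array([0.5001, 0.5])), f(np.array([0.4999, 0.5])))  # → 10.0 20.0  the jump at the boundary
```

Away from the boundary the output does not move at all under a small nudge to $q$; at the boundary it jumps. Neither gives gradient descent anything to follow.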

Replace argmax with softmax

Switch the toggle to soft. Now every key contributes to the output, weighted by how aligned it is with the query.

Compute the score $s_j = q\cdot k_j$ for each key. Apply softmax across the scores to get a probability distribution $\alpha$ over keys: a vector of weights that sum to 1, with the largest weight on the key most aligned with $q$. The output is the corresponding weighted sum of values:

$$o = \sum_{j} \alpha_j \, v_j \qquad \alpha_j = \frac{\exp(q\cdot k_j)}{\sum_{j'} \exp(q\cdot k_{j'})}$$
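The two formulas translate almost line for line into NumPy. This is a minimal sketch, assuming keys are stored as the rows of one array and values as the rows (or entries) of another; `softmax` and `soft_lookup` are illustrative names, not a library API.

```python
import numpy as np

def softmax(s):
    # Subtracting the max leaves softmax unchanged but keeps exp from overflowing.
    e = np.exp(s - np.max(s))
    return e / e.sum()

def soft_lookup(q, keys, values):
    """Soft dictionary lookup: o = sum_j alpha_j v_j with alpha = softmax(q . k_j)."""
    scores = keys @ q            # s_j = q . k_j
    alpha = softmax(scores)      # non-negative weights summing to 1
    return alpha @ values, alpha # weighted sum of values, plus the weights

# The widget readout's scores, 2.92 and 0.64, give its 91% / 9% split:
print(np.round(softmax(np.array([2.92, 0.64])), 2))   # → [0.91 0.09]

# Values equal to keys, as in the widget: the output lies between k_1 and k_2.
keys = np.array([[1.0, 0.0], [0.0, 1.0]])
o, alpha = soft_lookup(np.array([0.9, 0.3]), keys, keys)
```

Nothing is discarded here: every key contributes in proportion to its weight, which is exactly what makes the output move smoothly with $q$.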

Drag $q$ toward $k_1$ and the bar shifts mass to position 1. Drag it between the two keys and the bar splits. The yellow output vector slides smoothly between $k_1$ and $k_2$, with no snap.

This is differentiable. We can train it.
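A quick numerical check of that claim, as a sketch with the same illustrative keys and hypothetical scalar values 10 and 20: the finite-difference gradient of the soft lookup is nonzero, and it matches the closed form $\partial o / \partial q = \sum_j \alpha_j (v_j - o)\, k_j$, which follows from $\partial o/\partial s_j = \alpha_j (v_j - o)$ and $\partial s_j/\partial q = k_j$.

```python
import numpy as np

keys = np.array([[1.0, 0.0], [0.0, 1.0]])    # k_1, k_2
values = np.array([10.0, 20.0])              # hypothetical scalar values

def soft_lookup(q):
    s = keys @ q                             # scores q . k_j
    alpha = np.exp(s - s.max())
    alpha /= alpha.sum()                     # softmax weights
    return alpha @ values, alpha

def finite_diff(f, q, eps=1e-5):
    """Central-difference estimate of d f(q) / d q."""
    grad = np.zeros_like(q)
    for i in range(len(q)):
        step = np.zeros_like(q)
        step[i] = eps
        grad[i] = (f(q + step) - f(q - step)) / (2 * eps)
    return grad

q = np.array([0.9, 0.3])
o, alpha = soft_lookup(q)
numeric = finite_diff(lambda x: soft_lookup(x)[0], q)
analytic = (alpha * (values - o)) @ keys     # sum_j alpha_j (v_j - o) k_j
print(numeric, analytic)                     # nonzero, and the two agree
```

The gradient says "move $q$ toward $k_2$ to raise the output", which is precisely the signal the hard version could not provide.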

(In the widget, the values are the keys themselves: $v_j = k_j$. That’s the self-attention setup we’ll formalize two lessons from now. For now just notice the operation: query, score, softmax, weighted sum.)

Hand-compute one query

For $q = (1, 0)$, keys $k_1 = (1, 0)$ and $k_2 = (0, 1)$, and scalar values $v_1 = 10$, $v_2 = 20$, what is the attention output?

Skip the $\sqrt{d_k}$ scale for this example; we’ll motivate it next lesson.
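If you want to check your hand computation, this short NumPy sketch follows the same steps (unscaled scores, softmax, weighted sum). The scores are $1$ and $0$, so the weights are $e/(e+1)$ and $1/(e+1)$.

```python
import numpy as np

q = np.array([1.0, 0.0])
keys = np.array([[1.0, 0.0],     # k_1
                 [0.0, 1.0]])    # k_2
values = np.array([10.0, 20.0])  # scalar v_1, v_2

scores = keys @ q                              # [1, 0]
alpha = np.exp(scores) / np.exp(scores).sum()  # [e/(e+1), 1/(e+1)]
o = alpha @ values                             # 10*alpha_1 + 20*alpha_2
print(np.round(alpha, 3), round(float(o), 2))  # → [0.731 0.269] 12.69
```

The output sits closer to $v_1 = 10$ because $q$ matches $k_1$, but $v_2$ still pulls it up by about 2.7.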

Three roles, one kind of object

Three names for vectors that play different roles in this lookup:

  • Query $q$: what am I looking for? The thing matched against the keys.
  • Key $k_j$: what do I advertise? The thing the query is matched against.
  • Value $v_j$: what do I deliver if matched? The payload returned, weighted by the match.

These are three labels for the same kind of object. They’re all vectors. In the self-attention layer we’ll build two lessons from now, they’ll all be different linear projections of the same input sequence: same source, three different views. For now we just have to remember which role each plays.
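As a preview of the "same source, three views" idea, here is a shape-level sketch. It is an assumption about how the later lesson sets things up: the matrix names `W_q`, `W_k`, `W_v`, the sizes, and the random weights (standing in for learned ones) are all hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
T, d_model, d_k = 5, 8, 4               # hypothetical: 5 positions, 8-dim inputs, 4-dim q/k/v

X = rng.normal(size=(T, d_model))       # one input sequence, one row per position
W_q = rng.normal(size=(d_model, d_k))   # hypothetical projections; learned in a real model
W_k = rng.normal(size=(d_model, d_k))
W_v = rng.normal(size=(d_model, d_k))

Q, K, V = X @ W_q, X @ W_k, X @ W_v     # same source, three different views
print(Q.shape, K.shape, V.shape)        # → (5, 4) (5, 4) (5, 4)
```

Row $t$ of `Q`, `K`, and `V` gives the query, key, and value for position $t$; only the projection differs.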

The single-query equation

The full single-query attention, with a placeholder for one detail we’ll patch next lesson:

$$\boxed{\;o \;=\; \sum_{j=1}^{T} \alpha_j \, v_j,\qquad \alpha_j \;=\; \frac{\exp(q\cdot k_j / \sqrt{d_k})}{\sum_{j'} \exp(q\cdot k_{j'} / \sqrt{d_k})}\;}$$

That’s all of attention. Two lines.

Everything else in this module (Q, K, V matrices, multi-head, masking, positional encoding, KV cache) is plumbing around those two lines. The $\sqrt{d_k}$ in the denominator, where $d_k$ is the dimension of the query and key vectors, we’ll motivate next lesson; today’s widget left it out because at $d_k = 2$ it doesn’t matter much. But the rest of the formula is the operation in full generality.
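A minimal sketch of the boxed equation as a single function, assuming keys and values are stacked as rows; `attend` is an illustrative name, not a library call. The only change from the unscaled lookup is dividing the scores by $\sqrt{d_k}$ before the softmax.

```python
import numpy as np

def attend(q, K, V):
    """Single-query scaled dot-product attention.

    q: (d_k,) query; K: (T, d_k) keys, one per row; V: (T, d_v) values, one per row.
    Returns o = sum_j alpha_j v_j with alpha = softmax(q . k_j / sqrt(d_k)), and alpha.
    """
    d_k = q.shape[0]
    scores = K @ q / np.sqrt(d_k)        # (T,) scaled scores
    scores = scores - scores.max()       # numerical stability; softmax unchanged
    alpha = np.exp(scores) / np.exp(scores).sum()
    return alpha @ V, alpha

# Same q, keys, and values as the hand-computed example, now with the scale.
q = np.array([1.0, 0.0])
K = np.array([[1.0, 0.0], [0.0, 1.0]])
V = np.array([[10.0], [20.0]])
o, alpha = attend(q, K, V)
```

With $d_k = 2$ the scores shrink by a factor of about 1.41, so the weights soften a little, from about 0.731 / 0.269 unscaled to about 0.670 / 0.330. The operation itself is unchanged.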

Re-read the boxed equation until it feels inert. We’re going to build everything from it.

Uniform attention is just averaging

Notice what happens when all the scores $q\cdot k_j$ are equal. Softmax of equal values is uniform: $\alpha_j = 1/T$ for every $j$. The output is the unweighted average of the values.
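A small sketch confirming this with four made-up value vectors. The common score value, 2.5 here, is arbitrary: softmax ignores a constant shared by all scores, so any equal scores give the same uniform weights.

```python
import numpy as np

values = np.array([[1.0, 2.0],
                   [3.0, 0.0],
                   [5.0, 4.0],
                   [7.0, 6.0]])     # T = 4 hypothetical value vectors
scores = np.full(4, 2.5)            # every q . k_j equal

alpha = np.exp(scores) / np.exp(scores).sum()
print(alpha)                                # → [0.25 0.25 0.25 0.25]
print(alpha @ values, values.mean(axis=0))  # → [4. 3.] [4. 3.]
```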

Uniform attention is the same as “just take the mean.” That’s not useless: it’s the running-average baseline that an RNN’s hidden state was, in effect, trying to improve on. Attention contains the running mean as a special case and gives the model a learned, differentiable way to deviate from it.

When the query points strongly toward one key, attention picks that value. When the query is neutral, attention falls back to the average. The same operation handles both extremes and every blend in between. There is no separate “averaging mode” and “selecting mode”; there is one operation, parameterized by a continuous distribution $\alpha$.
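One way to watch that continuum is to fix the query's direction and scale its length by a factor $c$, which multiplies every score by $c$. The sketch below uses three hypothetical keys and scalar values, and an unscaled lookup named `lookup` for illustration: at $c = 0$ the weights are uniform and the output is the mean, 20; as $c$ grows the weights concentrate on $k_1$ and the output approaches $v_1 = 10$.

```python
import numpy as np

def lookup(q, K, V):
    # Unscaled soft lookup: softmax over q . k_j, then a weighted sum of values.
    s = K @ q
    a = np.exp(s - s.max())
    a = a / a.sum()
    return a @ V, a

K = np.array([[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]])  # three hypothetical keys
V = np.array([10.0, 20.0, 30.0])                      # hypothetical scalar values
direction = np.array([1.0, 0.0])                      # query points toward k_1

for c in [0.0, 1.0, 5.0, 50.0]:
    o, a = lookup(c * direction, K, V)
    print(c, np.round(a, 3), round(float(o), 2))      # averaging at c = 0, selecting as c grows
```

Nothing switches modes along the way; the weights slide continuously from uniform to nearly one-hot as the scores spread apart.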

The uniform-scores case

With $T = 3$ keys and all scores equal, what is the attention weight assigned to each key?

Round to three decimal places.
