The bigram is forgetful
Sample some text from our trained bigram and you get gems like “enny.”, “ax.”, “liona.”, and “xrnxg.”. The model recognizes that “e” is followed by some characters more than others. Beyond that: nothing.
Specifically, the bigram is amnesiac about everything except the previous character. It doesn’t know that “tho” looks more like the start of an English word than “xrn”. It doesn’t know that after “quee” the next letter is almost surely “n”. It can’t tell “the cat sat on the” from “wxqg vbz mtl pxc q ne”; both end in the same character, so both produce identical predictions for the next one.
We need more context.
The naive fix: count longer n-grams
The Markov assumption let us truncate history to one previous character. Why not two? Three? An n-gram model conditions on the previous $n-1$ characters:

$$P(x_t \mid x_1, \dots, x_{t-1}) \approx P(x_t \mid x_{t-n+1}, \dots, x_{t-1})$$
The implementation is the same shape as the bigram: a lookup table whose row is keyed by the previous $n-1$ characters and whose column is the next character. The trigram table is a 3D tensor of shape $V \times V \times V$. The 4-gram is a 4D tensor of shape $V \times V \times V \times V$. And so on.
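As a rough sketch of that lookup table, the snippet below counts next characters per context. It swaps the dense tensor for a dictionary that stores only the cells actually seen. The function names, the "." padding scheme, and the three toy words are assumptions for illustration, not the lesson's code or data.

```python
from collections import Counter, defaultdict

def train_ngram_counts(words, n):
    """Count next-character occurrences for each (n-1)-character context."""
    counts = defaultdict(Counter)
    for word in words:
        # Assumed convention: pad with "." so the first real character has a
        # full context and the end of the word is itself a predictable token.
        padded = "." * (n - 1) + word + "."
        for i in range(n - 1, len(padded)):
            context = padded[i - (n - 1):i]
            counts[context][padded[i]] += 1
    return counts

def next_char_probs(counts, context):
    """Normalize one row of the table into a probability distribution."""
    row = counts.get(context)
    if not row:
        return {}  # unseen context: the table has nothing to say
    total = sum(row.values())
    return {ch: c / total for ch, c in row.items()}

words = ["emma", "ella", "anna"]  # toy data, not the lesson's dataset
trigram = train_ngram_counts(words, n=3)
print(next_char_probs(trigram, ".e"))  # distribution after the context ".e"
print(next_char_probs(trigram, "zq"))  # a context that never occurred
```

The empty result for an unseen context is the sparse-table analogue of an all-zero row in the dense tensor.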
The table size is $V^n$. With $V = 27$ characters, here’s the growth:
| n | table cells |
|---|---|
| 2 | 729 |
| 3 | 19,683 |
| 4 | 531,441 |
| 5 | 14,348,907 |
| 6 | 387,420,489 |
| 10 | about 206 trillion |
By n = 6 the table already takes about 1.5 GB at 4 bytes per count. By n = 10 it would need roughly 820 TB, far more than any single machine holds. And almost every cell is zero; most strings of 10 characters never occur in any real corpus.
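The growth in the table can be reproduced with a few lines of arithmetic. This is a minimal sketch: `V = 27` matches the table's first row (729 = 27²), while the 4-bytes-per-cell figure is an assumption made only to turn cell counts into a rough memory estimate.

```python
V = 27              # vocabulary size implied by the table (27**2 == 729)
BYTES_PER_CELL = 4  # assumption: one 32-bit count per cell

def table_cells(n, vocab_size=V):
    """Number of cells in a dense n-gram count table."""
    return vocab_size ** n

for n in (2, 3, 4, 5, 6, 10):
    cells = table_cells(n)
    gigabytes = cells * BYTES_PER_CELL / 1e9
    print(f"n={n:2d}  cells={cells:,}  ~{gigabytes:,.3f} GB")
```

Each step in n multiplies the table by 27, which is why no realistic hardware budget keeps up for long.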
The 5-gram explosion
For our character vocabulary ($V = 27$), how many cells does a 5-gram count table have? Enter as a plain integer.
Sparsity ruins the count model
Even with infinite memory, the 5-gram table would be mostly empty. Only about 0.1% of the $27^5 \approx 14.3$ million possible 5-character sequences occur commonly in English; the other 99.9% of cells are zero.
That’s catastrophic for a count-based model. For almost any slightly novel context, the model assigns zero probability to whatever character actually comes next, which breaks the loss (because $-\log 0 = \infty$). Smoothing helps but cannot rescue a model that is mostly fitting noise.
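To make the failure concrete, here is a small sketch of the loss on a zero-count cell, with and without smoothing. Add-alpha (Laplace) smoothing is used as one common choice; the lesson does not say which smoothing scheme it has in mind, so treat that and the function names as assumptions.

```python
import math

def nll(prob):
    """Negative log-likelihood of one observed character."""
    # math.log(0.0) raises ValueError, so the infinite case is handled explicitly.
    return math.inf if prob == 0.0 else -math.log(prob)

def smoothed_prob(count, row_total, vocab_size=27, alpha=1.0):
    """Add-alpha smoothing: pretend every cell was seen alpha extra times."""
    return (count + alpha) / (row_total + alpha * vocab_size)

# An unseen context: every count in the row is zero.
print(nll(0.0))                  # the unsmoothed loss is infinite
print(nll(smoothed_prob(0, 0)))  # smoothing falls back to a uniform guess
```

Smoothing makes the loss finite, but for an unseen context it only ever returns the uniform distribution, so the model has learned nothing about that context.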
The deeper problem is that the count model has no notion of similarity between contexts. If it has seen “the cat” once but not “the dog,” there’s no way for it to use the experience of one to inform the other. Each context is a separate row, completely independent of all others.
We want a model that can generalize. Bengio et al. (2003) showed how to do it with a neural network.
Bengio 2003: replace the table with an MLP
The architecture is small. Three pieces.
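1. An embedding matrix $E \in \mathbb{R}^{V \times d}$. Each character is mapped to a $d$-dimensional vector: its row of $E$. $d$ is much smaller than $V$ (in our setup, $V = 27$). This compresses each character into a dense representation.
2. Concatenate the embeddings of the last $k$ characters into one vector $x \in \mathbb{R}^{kd}$. This is the model’s representation of the context.
3. Feed $x$ through a one-hidden-layer MLP to get logits over the vocabulary, then softmax for probabilities:

$$h = \tanh(W_1 x + b_1), \qquad P(x_t \mid \text{context}) = \operatorname{softmax}(W_2 h + b_2)$$

Here $W_1 \in \mathbb{R}^{H \times kd}$ and $W_2 \in \mathbb{R}^{V \times H}$, where $H$ is the hidden size; the $\tanh$ nonlinearity follows Bengio’s paper.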
That’s the whole architecture. Train it by gradient descent on NLL (the same loss as the bigram), and update the embedding matrix $E$, the weights $W_1$ and $W_2$, and the biases jointly. The model that comes out is qualitatively better than any n-gram count table, including ones that would not fit in memory.
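The three pieces can be sketched as a forward pass in plain Python. This is an untrained illustration under stated assumptions: the names `E`, `W1`, `b1`, `W2`, `b2`, the tanh hidden layer, the sizes `d`, `k`, `H`, and the token ids are all chosen here for the example and are not the lesson's configuration.

```python
import math
import random

random.seed(0)
V, d, k, H = 27, 4, 3, 16  # d, k, H are illustrative, not the lesson's values

def rand_matrix(rows, cols, scale=0.1):
    return [[random.gauss(0.0, scale) for _ in range(cols)] for _ in range(rows)]

E = rand_matrix(V, d)       # embedding matrix: one d-dim row per character
W1 = rand_matrix(H, k * d)  # hidden weights, one d-wide column block per context slot
b1 = [0.0] * H
W2 = rand_matrix(V, H)      # output weights
b2 = [0.0] * V

def matvec(W, x, b):
    """Compute W @ x + b for list-of-lists W."""
    return [sum(w * xi for w, xi in zip(row, x)) + bi for row, bi in zip(W, b)]

def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def forward(context_ids):
    """Return P(next char | context) for a list of k token ids."""
    x = [v for i in context_ids for v in E[i]]     # look up and concatenate
    h = [math.tanh(a) for a in matvec(W1, x, b1)]  # hidden layer
    return softmax(matvec(W2, h, b2))              # logits -> probabilities

# Hypothetical ids for ".", "e", "m" under a "."=0, a=1, ..., z=26 indexing.
probs = forward([0, 5, 13])
print(len(probs), round(sum(probs), 6))
```

With random weights the output is close to uniform; training would adjust every list above by gradient descent on the NLL of the true next character.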
Pull the context window from k = 1 up to k = 8 and watch the parameter counter. The MLP scales linearly in k (one extra column block in $W_1$ per additional context character). The n-gram count table (shown in the comparison panel) scales exponentially. By k = 5, the table needs about 14 million rows, one per possible context (387 million cells in all); the MLP has around 15,000 parameters.
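The two growth rates can be compared directly. In this sketch the table is counted in cells (contexts times next characters), and the embedding size `d` and hidden size `H` are placeholder values rather than the lesson's configuration, so the MLP totals will differ from the counter in the widget.

```python
def mlp_param_count(V, d, k, H):
    """Parameters in the embedding, hidden layer, and output layer."""
    embedding = V * d
    hidden = (k * d) * H + H  # W1 and b1
    output = H * V + V        # W2 and b2
    return embedding + hidden + output

def table_cells(V, k):
    """Cells in a count table with k context characters."""
    return V ** (k + 1)       # V**k contexts, V next-character cells each

V, d, H = 27, 8, 64  # d and H are placeholders, not the lesson's values
for k in range(1, 9):
    print(k, mlp_param_count(V, d, k, H), table_cells(V, k))
```

Each extra context character adds exactly `d * H` parameters to the MLP, while the table is multiplied by `V`.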
Bengio MLP parameter count
Configuration: vocabulary size $V$, embedding dim $d$, context window $k$, hidden size $H$. How many parameters are in the hidden weight matrix $W_1$?
The embedding matrix is just learned parameters
A learner who has read pop-science articles about word2vec sometimes thinks of “embeddings” as a separate, mystical object that gets downloaded from the internet. In this lesson the embedding matrix is the same kind of object as everything else: a lookup table, equivalent to a bias-free `nn.Linear` applied to a one-hot input, whose entries are learned by gradient descent jointly with the MLP weights $W_1$ and $W_2$.
Row $i$ of the embedding matrix $E$ is the embedding of token $i$. There is nothing special about it. It’s a $d$-dimensional vector of parameters that happens to get touched whenever token $i$ appears in the input.
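A quick way to see that the lookup is an ordinary linear map is to compare row indexing against a one-hot vector multiplied by the matrix. The matrix values below are made up for the example, and the helper names are hypothetical.

```python
# Toy 3-token, 2-dimensional embedding matrix (made-up values).
E = [
    [0.1, -0.3],  # row 0
    [0.7, 0.2],   # row 1
    [-0.5, 0.9],  # row 2
]

def lookup(E, token_id):
    """Embedding lookup: just pick a row."""
    return E[token_id]

def one_hot_times_matrix(E, token_id):
    """The same result computed as one_hot(token_id) @ E."""
    one_hot = [1.0 if i == token_id else 0.0 for i in range(len(E))]
    dim = len(E[0])
    return [sum(one_hot[i] * E[i][j] for i in range(len(E))) for j in range(dim)]

for token_id in range(len(E)):
    print(token_id, lookup(E, token_id) == one_hot_times_matrix(E, token_id))
```

Indexing is simply the cheaper way to compute the product, since every term but one is multiplied by zero.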
What’s interesting is what gradient descent does to those vectors. If two tokens tend to predict the same following characters, gradient descent will pull their embeddings close together: with similar embeddings, the network can reuse the same hidden-layer pattern to make the same prediction for both, which lowers the loss. The embedding space ends up encoding which tokens behave alike, automatically. This is the generalization the count model can’t do.
Click any character in the picker below the scatter plot to pluck it; drag the highlighted dot around. Watch the simulated loss respond. Pluck a vowel and drag it far from the other vowels; the loss climbs because the model “wanted” that vowel near its cluster. Snap it back to its trained position and the loss drops. The dots are not decoration; they are the rows of the embedding matrix $E$, and gradient descent moves them around in space.
Where the wall still is
The Bengio MLP is a real upgrade. It generalizes across similar contexts via the embedding space. Its parameter count grows linearly in window size, not exponentially. In Bengio et al.’s experiments it beat the smoothed n-gram baselines it was compared against.
But there is still a wall, and it’s the same wall the bigram had, only farther away.
The model can only see $k$ tokens back. If you set $k = 8$, it cannot see the 9th token back, no matter what’s there. Doubling the window doubles $W_1$’s parameters, which means roughly twice the memory and FLOPs in that layer, and more data to train it without overfitting.
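The wall is easy to demonstrate without any model at all, because the model's input is only ever the last k characters. The helper below is a hypothetical sketch of that windowing step, with "." assumed as the padding character.

```python
def context_window(text, k, pad="."):
    """Return the last k characters, left-padded if the text is shorter (k >= 1)."""
    return (pad * k + text)[-k:]

a = "the cat sat on the"
b = "wxqg vbz mtl pxc q ne"
print(context_window(a, 1) == context_window(b, 1))  # bigram view: identical
print(context_window(a, 8) == context_window(b, 8))  # k = 8 view: different

# Two inputs that differ only in the 9th character back.
long_a = "x" + "a" * 8
long_b = "y" + "a" * 8
print(context_window(long_a, 8) == context_window(long_b, 8))  # identical again
```

Any two inputs with the same window produce the same prediction, whatever came before it.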
For our character-level toy this isn’t terrible. Names are short. But for a real language model that needs to handle long-range dependencies (references back to a sentence three paragraphs ago, code that closes a brace opened a hundred lines earlier) a fixed window is fundamentally the wrong shape. Whatever the window is, the relevant context is sometimes outside it.
The next lesson introduces a model that doesn’t have a fixed window. It carries a hidden state forward through time, with the same parameters reused at every step. We trade the “see exactly k characters back” guarantee for “see arbitrarily far back, but only as much as fits in one vector.” That’s the recurrent neural network.