One thing that confused me about transformers is the question of when (as in, after how many layers) each embedding “flips” from representing the original token to finally representing the prediction of the next token.
By now, I think the answer is simply this: each embedding represents both at the same time (and more). For instance, GPT-3 has 12,288 embedding dimensions. At first I thought that all of them initially encode the original token, that after going through all the layers they eventually all encode the next token, and that this shift must happen somewhere in between. But on reflection, something very roughly like the following makes much more sense:
- some 1000 dimensions encode the original token
- some other 1000 dimensions encode the prediction of the next token
- the remaining 10,288 dimensions encode information about all available context (which will start out “empty” and get filled with meaningful information through the layers).
In practice, things are of course much less clean: most dimensions probably play some role in all of these things, to different degrees, since all of this is learned through gradient descent and is therefore very noisy and gradual. Additionally, there’s the whole positional encoding thing, which is also part of the embeddings and makes clear distinctions even more difficult. But the key point remains that a single embedding encodes many things, only one of which is the prediction, and this prediction is there from the very beginning (when it’s still very superficial and bad) and then, together with the rest of the embedding, gets refined more and more throughout the layers.
Another misconception I had was that embedding and unembedding are very roughly symmetric operations that just “translate” from token space to embedding space and vice versa[1]. This made sense under my initial, naive “embeddings represent tokens” interpretation, but with the updated view described above, it becomes clear that unembedding is better understood as an “extraction” of the part of the embedding’s information content that encodes the prediction.
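To make the rough picture above concrete, here is a hand-built toy in Python (NumPy). Everything in it is an assumption for illustration: the four-word vocabulary, the split of the vector into a token slice, a prediction slice and a context slice, the made-up bigram logits, and the two hypothetical “layers”. Nothing is learned. It only shows how a single vector can carry the original token and a prediction at the same time, how an unembedding that reads just the prediction slice acts as an extraction rather than an inverse of the embedding, and how reading the prediction after every layer (similar in spirit to the so-called logit lens) shows it being refined.

```python
import numpy as np

VOCAB = ["good", "morning", "day", "night"]
V = len(VOCAB)

# Hypothetical layout of a residual vector with d = 3 * V dimensions, loosely
# mirroring the rough split described above (real models are far less clean).
TOKEN = slice(0, V)            # encodes the original token
PRED = slice(V, 2 * V)         # encodes the next-token prediction, as logits
CONTEXT = slice(2 * V, 3 * V)  # starts "empty", filled in by the layers

# Made-up bigram-style logits: row = current token, column = next token.
BIGRAM_LOGITS = np.array([
    [-2.0, 2.0, 1.5, 0.5],   # after "good": "morning" > "day" > "night" > "good"
    [0.0, -2.0, 0.0, 0.0],
    [0.0, 0.0, -2.0, 0.0],
    [0.0, 0.0, 0.0, -2.0],
])

def embed(token):
    i = VOCAB.index(token)
    x = np.zeros(3 * V)
    x[TOKEN.start + i] = 1.0      # "which token am I?"
    x[PRED] = BIGRAM_LOGITS[i]    # a rough prediction is present from the start
    return x

# Unembedding: reads only the PRED slice and ignores everything else.
W_U = np.zeros((V, 3 * V))
W_U[:, PRED] = np.eye(V)

def unembed(x):
    logits = W_U @ x
    p = np.exp(logits - logits.max())  # numerically stable softmax
    return p / p.sum()

def layer_1(x):
    # Hypothetical: attention found "late evening" earlier in the text
    # and writes that fact into the context slice.
    x = x.copy()
    x[CONTEXT.start] += 1.0
    return x

def layer_2(x):
    # Hypothetical: reads the context feature and refines the prediction toward "night".
    x = x.copy()
    x[PRED.start + VOCAB.index("night")] += 3.0 * x[CONTEXT.start]
    return x

x = embed("good")
readouts = [unembed(x)]           # read out the prediction after every layer
for layer in (layer_1, layer_2):
    x = layer(x)
    readouts.append(unembed(x))

for depth, p in enumerate(readouts):
    print(depth, VOCAB[int(np.argmax(p))], np.round(p, 3))
```

In this toy, embedding “good” and unembedding it immediately gives mostly “morning”, not “good”, because the unembedding never looks at the token slice. After the second layer the prediction has shifted to “night”, while the token slice still says “good”.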
One piece of evidence for this updated view is that this paper (thanks to Leon Lang for the hint) found that “Zero layer transformers model bigram statistics”. So, indeed, embedding + unembedding alone already perform some very basic next-token prediction. (Admittedly, I’m not sure whether this is only the case when the transformer is trained with zero layers, or also in, say, GPT-3 when you just skip all the layers during inference.)
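A minimal sketch of the zero-layer case, assuming a tiny made-up corpus and plain gradient descent in NumPy (the corpus, the dimension `d = 8`, the learning rate and the step count are arbitrary choices of mine, not taken from the paper). The model is just an embedding matrix followed by an unembedding matrix, trained with cross-entropy on next-token pairs. If the bigram claim holds, the learned P(next | "good") should approach the empirical bigram frequencies. Note that this trains a model that has zero layers from the start; it says nothing about skipping the layers of an already-trained model like GPT-3.

```python
import numpy as np

# Tiny made-up corpus (an assumption for illustration); "." separates phrases.
tokens = "good morning . good day . good morning . good night . good morning .".split()
vocab = sorted(set(tokens))
idx = {t: i for i, t in enumerate(vocab)}
V, d = len(vocab), 8

cur = np.array([idx[t] for t in tokens[:-1]])  # current tokens
nxt = np.array([idx[t] for t in tokens[1:]])   # next tokens (targets)
N = len(cur)

rng = np.random.default_rng(0)
W_E = 0.3 * rng.standard_normal((V, d))  # embedding: token -> vector
W_U = 0.3 * rng.standard_normal((V, d))  # unembedding: vector -> logits

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

lr = 0.5
for step in range(4000):
    E = W_E[cur]                      # (N, d) embeddings, no layers in between
    p = softmax(E @ W_U.T)            # (N, V) next-token distributions
    grad_logits = p.copy()
    grad_logits[np.arange(N), nxt] -= 1.0
    grad_logits /= N                  # gradient of mean cross-entropy w.r.t. logits
    grad_W_U = grad_logits.T @ E      # (V, d)
    grad_E = grad_logits @ W_U        # (N, d), computed before W_U is updated
    grad_W_E = np.zeros_like(W_E)
    np.add.at(grad_W_E, cur, grad_E)  # scatter-add per-example gradients into the table
    W_U -= lr * grad_W_U
    W_E -= lr * grad_W_E

def next_token_probs(token):
    # Embed, then unembed immediately.
    return softmax(W_U @ W_E[idx[token]])

# Empirical bigram distribution P(next | "good") for comparison.
counts = np.bincount(nxt[cur == idx["good"]], minlength=V)
empirical = counts / counts.sum()

learned = next_token_probs("good")
for t in vocab:
    print(f"{t:>8}: learned {learned[idx[t]]:.2f}   empirical {empirical[idx[t]]:.2f}")
```

In this toy, embedding “good” and unembedding it immediately puts most of the probability on “morning” and very little on “good” itself, which matches the guess in footnote [1].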
I would guess that transformer-experienced people (unless they disagree with my description, in which case please elaborate on what I’m still getting wrong) will find all of this rather obvious. But for me, this was a major missing piece of understanding, even after once participating in an ML-themed bootcamp and watching all the 3Blue1Brown videos on transformers several times; either the idea isn’t directly explained there, or I somehow managed to consistently miss it.
[1] Of course, this is not entirely true to begin with, because the unembedding yields a distribution rather than a single token. But my assumption was that if you embed the word “Good” and then immediately unembed the embedding, you would get a very high probability for “Good” back, whereas in practice (I haven’t verified this yet) you would probably obtain high probabilities for “morning”, “day”, etc.
After first learning about transformers, I couldn’t help but wonder why on Earth this works. How can this totally made-up, complicated structure somehow end up learning to write meaningful text and developing a mostly sound model of our world?
(tl;dr: no novel insights here, just me writing down some thoughts I’ve had after/while learning more about neural nets and transformers.)
When I once asked someone more experienced, they essentially told me “nobody really knows, but the closest thing we have to an answer is ‘the blessing of dimensionality’: with so many dimensions in your loss landscape, you basically don’t run into local minima, and the thing keeps improving if you just throw enough data and compute at it”.
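One common heuristic behind the “blessing of dimensionality” argument can be sketched numerically. At a critical point, a local minimum requires every eigenvalue of the Hessian to be positive. If you assume (strongly, and unrealistically) that Hessians at critical points behave like random symmetric matrices, the chance that all n eigenvalues are positive shrinks very quickly with n, so most critical points would be saddles rather than minima. The Monte Carlo estimate below uses a GOE-style random matrix as that stand-in; real loss landscapes are not random matrices, so treat this as intuition only.

```python
import numpy as np

def fraction_all_positive(n, samples=20000, seed=0):
    """Estimate P(all eigenvalues > 0) for random symmetric n x n matrices.

    Assumption: a GOE-style random matrix stands in for the Hessian at a critical point.
    """
    rng = np.random.default_rng(seed)
    A = rng.standard_normal((samples, n, n))
    H = (A + np.transpose(A, (0, 2, 1))) / 2  # symmetrize each matrix in the batch
    eigs = np.linalg.eigvalsh(H)              # batched eigenvalues, shape (samples, n)
    return float(np.mean(np.all(eigs > 0, axis=1)))

fractions = [fraction_all_positive(n) for n in range(1, 7)]
for n, f in enumerate(fractions, start=1):
    print(n, f)
```

The estimate is about 0.5 for n = 1 (a single Gaussian is positive half the time) and falls off sharply as n grows, while a neural network’s loss landscape has millions or billions of dimensions.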
I think this makes sense, and my view on how/why/when deep neural networks work is currently something along the lines of:
- for every problem you want to solve (given a certain understanding of the problem and of when you consider it solved), there’s some (unknown) minimal network size, or maybe rather a “minimal network frontier”, since different architectures end up with different minimal sizes; your network needs to be at least that big to even be able to solve the problem
- the network size & architecture also determine how much training data you need to get anywhere
- basically, you try to find network architectures that encode sensible priors about the modality you’re working with (priors that are nearly always true) while also eliminating a priori useless weights from your network; this way, training lets the network quickly learn important things rather than first having to figure out the priors itself
- for text, you might realize that different parts of the text refer to each other, so you need a way to effectively pass information around, and hence you end up with something like the attention mechanism
- for image recognition, you realize that any given pixel is more likely to be relevant to another pixel the closer the two are, so you end up with something like CNNs, which start by looking at low-level features and, throughout the layers of the network, successively “convert” the raw pixel data into semantic data
- in theory, you could probably just use a huge feed-forward network (as long as it’s not so huge that it overfits instead of generalizing to anything useful), and it might end up solving problems in similar ways to “smarter” architectures (though I’m not sure about this), but you would need way more parameters and way more training data to achieve similar results, much of which would be wasted on “low-quality parameters” that could just as well be omitted
- so, encoding these modality priors into your network architecture probably spares you orders of magnitude of compute compared to naive approaches
- while the bitter lesson makes sense, maybe it under-emphasizes how much a suitable network architecture and high-quality training data matter?
- lastly, the question of “which problem you’re trying to solve” can’t just be answered at a high level with “I want to minimize loss in next-token prediction”; the exact problem the network solves depends strongly on the training data. Loss minimization is a trade-off between all the things you’re minimizing over, so the more rambling, gossip, meaningless binary data and so on your training data contains, the more parameters and training time you’ll need just for those, and the less capable the network will be of predicting more meaningful tokens.
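The parameter savings from encoding a locality prior can be illustrated with simple arithmetic. The sketch below counts weights for a single-channel 28 × 28 image mapped to a same-sized output (sizes chosen arbitrarily, biases ignored, and zero padding assumed so every output pixel sees a full k × k patch). A dense layer connects everything to everything, a locally connected layer keeps only nearby connections, and a convolution additionally shares one kernel across all positions.

```python
def dense_params(h, w):
    # Every output pixel connected to every input pixel (one channel, no bias).
    return (h * w) ** 2

def locally_connected_params(h, w, k):
    # Locality prior only: each output pixel sees a k x k patch with its own weights.
    return h * w * k * k

def conv_params(k):
    # Locality + weight sharing: one k x k kernel reused at every position.
    return k * k

h = w = 28
print(dense_params(h, w))                 # 614656
print(locally_connected_params(h, w, 3))  # 7056
print(conv_params(3))                     # 9
```

Here locality alone removes almost two orders of magnitude of weights, and weight sharing removes nearly all of the rest; how much that helps on a real task is a separate, empirical question.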
Related to that last point, I recently worked on a small project where you, as the user, play Pong against an AI. That AI is controlled by a small neural network (something on the order of 2 or 3 hidden layers and a few dozen neurons), initialized randomly, so at first it’s very easy for the human to win. While you play, though, the game collects your behavior as training data and continuously trains the neural network, which eventually learns to mirror you. So after a few minutes of playing, it plays very similarly to the human, and it becomes much harder to beat.
One thing I noticed while working on this is that the naive approach to training this AI was far from optimal: much of the training data I collected ended up being pretty irrelevant for playing well! E.g., it’s much more important how the paddle moves while the ball is closing in, and almost entirely irrelevant what you do right after hitting the ball. There were several such small insights, leading me to tweak how exactly training data is collected (e.g. sampling it with lower probability while the ball is moving away than when it’s getting closer), which greatly reduced the time it took for the AI to learn, even with the network architecture staying the same.
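A minimal sketch of the kind of sampling tweak described above. The `Frame` fields, the `keep_frame` helper and the probabilities `p_toward = 1.0` and `p_away = 0.2` are all hypothetical; the text does not say what values were actually used. A frame is kept with high probability when the ball is moving toward the human’s paddle and with lower probability when it is moving away.

```python
import random
from dataclasses import dataclass

@dataclass
class Frame:
    """One hypothetical training sample recorded from the human player."""
    ball_x: float
    ball_vx: float   # horizontal ball velocity (units per tick)
    paddle_y: float
    action: int      # -1 = up, 0 = stay, +1 = down

def ball_approaching(frame: Frame, paddle_x: float) -> bool:
    # The ball approaches the paddle if its horizontal velocity points toward it.
    return (paddle_x - frame.ball_x) * frame.ball_vx > 0

def keep_frame(frame: Frame, paddle_x: float, rng: random.Random,
               p_toward: float = 1.0, p_away: float = 0.2) -> bool:
    # Keep approaching frames with probability p_toward, receding ones with p_away.
    p = p_toward if ball_approaching(frame, paddle_x) else p_away
    return rng.random() < p

# Usage: human paddle on the left edge (x = 0); half the frames approach, half recede.
rng = random.Random(0)
frames = [Frame(ball_x=50.0, ball_vx=-1.0, paddle_y=0.0, action=0) for _ in range(1000)]
frames += [Frame(ball_x=50.0, ball_vx=+1.0, paddle_y=0.0, action=0) for _ in range(1000)]
kept = [f for f in frames if keep_frame(f, paddle_x=0.0, rng=rng)]
n_toward = sum(ball_approaching(f, 0.0) for f in kept)
n_away = len(kept) - n_toward
print(n_toward, n_away)
```

With these settings, every approaching frame survives while only around a fifth of the receding ones do, so more of the training budget goes to the situations that matter for playing well.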
Notably, this does not necessarily mean the loss curve dropped more quickly: since I had tweaked the training data, the loss curves before and after referred to quite different things. The same loss on higher-quality data is much more useful than on noisy or irrelevant data.
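A small worked example of why loss values on different data are not directly comparable: the lowest expected cross-entropy any model can reach equals the conditional entropy of the targets given the inputs. Assume (hypothetically) two kinds of frames: “relevant” ones where the human’s action is fully determined by the game state, and “noisy” ones where the action is uniformly random among three choices (up, stay, down). A perfect predictor then still has an expected loss of q · ln 3 nats, where q is the fraction of noisy frames.

```python
import math

def best_achievable_loss(frac_noisy: float, n_actions: int = 3) -> float:
    """Minimum expected cross-entropy (nats) for a hypothetical mix of frames:
    relevant frames have entropy 0, noisy frames have entropy ln(n_actions)."""
    return frac_noisy * math.log(n_actions)

for q in (0.0, 0.25, 0.5):
    print(q, round(best_achievable_loss(q), 3))
```

So a loss of about 0.55 nats could be optimal on a half-noisy dataset and still be poor on a clean one.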
There are just so many degrees of freedom in all of this that it seems very likely that, even without any hardware advances at all, research would keep coming up with faster/cheaper/better-performing models for a long time.