One thing that confused me about transformers is the question of when (as in, after how many layers) each embedding “flips” from representing the original token to finally representing the prediction of the next token.
By now, I think the answer is simply this: each embedding represents both at the same time (and more). For instance, GPT-3 has 12,288 embedding dimensions. At first I thought that all of them initially encode the original token, that after going through all the layers they eventually all encode the next token, and that somewhere in the layers in between this shift must happen. But upon some reflection, something very roughly like the following makes much more sense:
some 1000 dimensions encode the original token
some other 1000 dimensions encode the prediction of the next token
the remaining 10,288 dimensions encode information about all available context (which will start out “empty” and get filled with meaningful information through the layers).
In practice, things are of course much less clean, and most dimensions probably play some role in all of these things, to different degrees, since all of this is learned through gradient descent and hence ends up noisy and gradual. Additionally, there’s the whole positional encoding thing, which is also part of the embeddings and makes clear distinctions even more difficult. But the key point remains that a single embedding encodes many things, only one of which is the prediction, and this prediction is there from the beginning (when it’s still very superficial and bad) and then, together with the rest of the embedding, gets refined more and more throughout the layers.
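To make this concrete, here is a purely illustrative toy sketch (all sizes and slices are made up, and real models don’t have clean, axis-aligned subspaces like this); the point is just that layers add to the residual stream rather than overwriting it wholesale:

```python
import numpy as np

# Purely illustrative toy, not a real transformer. The idea: one d_model-dimensional
# residual stream carries several kinds of information at once, in (very roughly)
# different subspaces, and each layer adds refinements instead of overwriting it.
d_model = 12_288
token_part = slice(0, 1_000)      # hypothetical "current token" subspace
pred_part = slice(1_000, 2_000)   # hypothetical "next-token prediction" subspace
ctx_part = slice(2_000, d_model)  # hypothetical "context" subspace

rng = np.random.default_rng(0)
resid = np.zeros(d_model)
resid[token_part] = rng.normal(size=1_000)       # the embedding writes the current token
resid[pred_part] = 0.1 * rng.normal(size=1_000)  # plus a crude initial guess at the next token

for _ in range(96):                                  # stand-ins for attention/MLP layers
    resid = resid + 0.05 * rng.normal(size=d_model)  # each layer adds to the stream,
                                                     # gradually refining prediction and context
```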
Another misconception I had was that embedding and unembedding are very roughly symmetric operations that just “translate” from token space to embedding space and vice versa[1]. This made sense in relation to the initial & naive “embeddings represent tokens” interpretation, but with the updated view as described above, it becomes clear that unembedding is rather an “extraction” of the information content in the embedding that encodes the prediction.
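In code terms (shapes are GPT-3-ish but the weights here are random, and I’m ignoring layer norm), embedding is just a row lookup while unembedding is a linear readout, so the two are only even roughly inverse to each other if the model ties them:

```python
import torch

vocab_size, d_model = 50_257, 12_288
W_E = 0.02 * torch.randn(vocab_size, d_model)  # embedding matrix
W_U = 0.02 * torch.randn(d_model, vocab_size)  # separate unembedding matrix

token_id = 123
x = W_E[token_id]   # embedding: token -> vector, a simple table lookup

# ... in a real model, x would now be refined by all the attention/MLP layers ...

logits = x @ W_U    # unembedding: each column of W_U "extracts" how strongly the final
                    # residual points in that token's prediction direction; this is not
                    # an inverse of the embedding unless the model sets W_U = W_E.T
probs = torch.softmax(logits, dim=-1)
```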
One piece of evidence for this updated view is that this paper (thanks to Leon Lang for the hint) found that “Zero layer transformers model bigram statistics”. So, indeed, embedding + unembedding alone already perform some very basic next-token prediction. (Admittedly, I’m not sure whether this only holds when the transformer is trained with zero layers, or also in, say, GPT-3, if during inference you just skip all the layers.)
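For concreteness, a zero-layer transformer is just the embedding followed directly by the unembedding; since the logits can only depend on the current token, the best such a model can learn is bigram statistics. A minimal, untrained sketch (class name and sizes are mine):

```python
import torch
import torch.nn as nn

class ZeroLayerTransformer(nn.Module):
    """Embedding followed directly by unembedding: no attention, no MLPs."""
    def __init__(self, vocab_size: int, d_model: int):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)             # W_E
        self.unembed = nn.Linear(d_model, vocab_size, bias=False)  # W_U

    def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
        # Effectively a (low-rank) vocab x vocab table of "which token tends to
        # follow this one", i.e. bigram statistics once trained.
        return self.unembed(self.embed(token_ids))

model = ZeroLayerTransformer(vocab_size=50_257, d_model=768)
logits = model(torch.tensor([[464, 2068]]))   # arbitrary token ids -> shape (1, 2, 50257)
```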
I would guess that transformer-experienced people (unless they disagree with my description—in that case, please elaborate what I’m still getting wrong) will find all of this rather obvious. But for me, this was a major missing piece of understanding, even after once participating in an ML-themed bootcamp and watching all the 3Blue1Brown videos on transformers several times, where this idea either is not directly explained, or I somehow managed to consistently miss it.
[1] Of course, this is not entirely true to begin with, because the unembedding yields a distribution rather than a single token. But my assumption was that, if you embed the word “Good” and then unembed the embedding immediately, you would get a very high probability for “Good” back, when in practice (I haven’t verified this yet) you would probably obtain high probabilities for “morning”, “day”, etc.
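For what it’s worth, here is a sketch of how one might check this with GPT-2 via Hugging Face. Two caveats: GPT-2 ties its unembedding to its embedding, so the outcome may not transfer to a model with a separate unembedding matrix, and whether to apply the final layer norm to a raw embedding is a judgment call.

```python
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")
model.eval()

token_id = tokenizer.encode(" Good")[0]

with torch.no_grad():
    x = model.transformer.wte.weight[token_id]   # embed " Good" ...
    # ... and unembed immediately, skipping every transformer block. GPT-2's real
    # forward pass applies the final layer norm (ln_f) before the unembedding, so
    # it is applied here too, even though a raw embedding is out of distribution.
    logits = model.lm_head(model.transformer.ln_f(x))

top = torch.topk(logits, 5)
print([tokenizer.decode([int(i)]) for i in top.indices])
```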
Awkwardly, it depends on whether the model uses tied embeddings (unembed is embed transpose) or has separate embed and unembed matrices. Using tied embedding matrices like this means the model actually does have to do a sort of conversion.
Your discussion seems mostly accurate in the case of having separate embed and unembed, except that I don’t think the initial state is like “1k encode current, 1k encode predictions, rest start empty”. The model can just directly encode predictions for an initial state using the unembed.
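The tied case is easy to see even with random weights (toy shapes, no training, no layer norm): if W_U = W_E.T, embedding a token and unembedding straight away just hands you that token back, because in high dimensions a vector’s dot product with itself dwarfs its dot product with any other random row. That is why a tied model really does have to do the conversion over its layers.

```python
import torch

vocab_size, d_model = 50_257, 768
W_E = 0.02 * torch.randn(vocab_size, d_model)   # random stand-in for an embedding matrix

token_id = 123
x = W_E[token_id]      # embed
logits = x @ W_E.T     # unembed with tied weights (W_U = W_E.T)

# The original token wins by a wide margin: x @ x is ~d_model * 0.02**2, while
# x @ (any other random row) is close to zero.
print(logits.argmax().item() == token_id)   # True
```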
There has actually been some work visualizing this process, with a method called the “logit lens”.
The first example that I know of: https://www.lesswrong.com/posts/AcKRB8wDpdaN6v6ru/interpreting-gpt-the-logit-lens
A more thorough analysis: https://arxiv.org/abs/2303.08112
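In code terms, the logit lens just means applying the model’s own final layer norm and unembedding to the residual stream after every layer, instead of only at the end, so you can watch the prediction being refined with depth. A rough sketch with GPT-2 (the prompt is arbitrary):

```python
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")
model.eval()

inputs = tokenizer("The Eiffel Tower is located in the city of", return_tensors="pt")
with torch.no_grad():
    out = model(**inputs, output_hidden_states=True)

# out.hidden_states holds the embedding output plus the residual stream after
# each of the 12 blocks; unembed each one at the last position.
for layer, hidden in enumerate(out.hidden_states):
    logits = model.lm_head(model.transformer.ln_f(hidden[0, -1]))
    print(layer, repr(tokenizer.decode([int(logits.argmax())])))
```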
You can learn a per-token bias over all the layers to understand where in the model it stops representing the original embedding (or a linear transformation of it) like in https://www.lesswrong.com/posts/P8qLZco6Zq8LaLHe9/tokenized-saes-infusing-per-token-biases
You could also plot the cos-sims of the resulting biases to see how much the representation rotates.
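A small sketch of that last suggestion (the biases tensor here is a random placeholder for whatever per-layer, per-token biases were actually learned):

```python
import torch

n_layers, d_model = 12, 768
biases = torch.randn(n_layers, d_model)   # placeholder for the learned per-layer biases

# Pairwise cosine similarities: entry (i, j) says how aligned the layer-i and
# layer-j bias directions are, i.e. how much the direction rotates between layers.
normed = biases / biases.norm(dim=-1, keepdim=True)
cos_sims = normed @ normed.T              # shape (n_layers, n_layers)
print(cos_sims)
```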
Do it! I bet slightly against your prediction.