After first learning about transformers, I couldn’t help but wonder why on Earth this works. How can this totally made-up, complicated structure somehow end up learning to write meaningful text and to build a mostly sound model of our world?
(tl;dr: no novel insights here, just me writing down some thoughts I’ve had after/while learning more about neural nets and transformers.)
When I once asked someone more experienced, they essentially told me “nobody really knows, but the closest thing we have to an answer is ‘the blessing of dimensionality’ - with so many dimensions in your loss landscape, you basically don’t get stuck in local minima; the thing just keeps improving if you throw enough data and compute at it”.
I think this makes sense, and my view on how/why/when deep neural networks work is currently something along the lines of:
- For every problem you want to solve (given some definition of the problem and of what counts as solving it), there’s some unknown minimal network size, or maybe rather a “minimal network frontier”, since different architectures come with different minimal sizes. Your network needs to be at least that big to even be able to solve the problem.
- The network size & architecture also determine how much training data you need to get anywhere.
- Basically, you try to find network architectures that encode sensible priors about the modality you’re working with, priors that are essentially always true, while eliminating a priori useless weights from your network. This way, training lets the network quickly learn the important things rather than first having to figure out the priors itself.
- For text, you might realize that different parts of the text refer to each other, so you need a way to effectively pass information around, and hence you end up with something like the attention mechanism (see the first sketch after this list).
- For image recognition, you realize that the prior probability of one pixel being relevant to another is higher the closer they are, so you end up with something like CNNs, where you start by looking at low-level features and, throughout the layers of the network, let it successively “convert” the raw pixel data into semantic data (the second sketch below illustrates this).
- In theory, you could probably just use a huge feed-forward network (as long as it’s not so huge that it overfits instead of generalizing to anything useful), and it might even end up solving problems in ways similar to “smarter” architectures (though I’m not sure about this). But you would need far more parameters and far more training data to achieve similar results, much of which would be wasted on “low-quality parameters” that could just as well be omitted.
- So encoding these modality priors into your network architecture probably spares you orders of magnitude of compute compared to naive approaches.
- While the bitter lesson makes sense, maybe it under-emphasizes the degree to which choosing a suitable network architecture + high-quality training data matters?
- Lastly, the question of which problem you’re trying to solve can’t be answered just at the high level of “I want to minimize next-token prediction loss”; the exact problem the network solves depends strongly on the training data. Loss minimization is a trade-off between all the things you’re minimizing, so the more rambling, gossip, meaningless binary data and so on your training data contains, the more parameters and training time you’ll need just for those, and the less capable the network will be of predicting more meaningful tokens.
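To make the “passing information around” point concrete, here is a minimal sketch of single-head, scaled dot-product self-attention in plain numpy. The names and shapes are illustrative only, not taken from any particular implementation:

```python
import numpy as np

def self_attention(x, Wq, Wk, Wv):
    """Single-head scaled dot-product self-attention.

    x:  (seq_len, d_model) token embeddings
    Wq, Wk, Wv: (d_model, d_head) projection matrices (learned in practice)
    """
    q, k, v = x @ Wq, x @ Wk, x @ Wv            # project into query/key/value spaces
    scores = q @ k.T / np.sqrt(k.shape[-1])     # how much each position "cares" about each other position
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)  # softmax over positions
    return weights @ v                          # each position receives a weighted mix of the others

# toy usage: 5 tokens, 8-dim embeddings, 4-dim head
rng = np.random.default_rng(0)
x = rng.normal(size=(5, 8))
Wq, Wk, Wv = (rng.normal(size=(8, 4)) for _ in range(3))
out = self_attention(x, Wq, Wk, Wv)   # (5, 4): every token now carries context from the rest
```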
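And for the locality prior behind CNNs, here is a tiny PyTorch stack showing the idea: each small convolution only mixes nearby pixels, while stacking and pooling lets deeper layers see larger, more “semantic” regions. The layer sizes are arbitrary:

```python
import torch
import torch.nn as nn

# Each 3x3 conv only mixes neighbouring pixels (the locality prior);
# stacking convs and pooling lets deeper layers "see" ever larger regions.
tiny_cnn = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),   # low-level features: edges, colours
    nn.ReLU(),
    nn.MaxPool2d(2),                              # halve resolution, grow the effective receptive field
    nn.Conv2d(16, 32, kernel_size=3, padding=1),  # mid-level features: textures, simple shapes
    nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Conv2d(32, 64, kernel_size=3, padding=1),  # higher-level, more "semantic" features
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(64, 10),                            # e.g. 10 image classes
)

x = torch.randn(1, 3, 32, 32)    # one 32x32 RGB image
logits = tiny_cnn(x)             # shape (1, 10)
```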
Related to that last point, I recently worked on a small project where you, as the user, play Pong against an AI. That AI is controlled by a small neural network (on the order of 2 or 3 hidden layers and a few dozen neurons), initialized randomly, so at first it’s very easy for the human to win. While you play, though, the game collects your behavior as training data and constantly trains the neural network, which eventually learns to mirror you. So after a few minutes of playing, it plays very similarly to the human and becomes much harder to beat.
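To give a sense of scale, a network of roughly that size looks like the sketch below; the concrete input features and output actions are just one plausible choice for illustration, not necessarily what the project uses:

```python
import torch
import torch.nn as nn

# Hypothetical feature layout: ball x/y, ball velocity x/y, paddle y  ->  paddle action
pong_net = nn.Sequential(
    nn.Linear(5, 32),   # a few dozen neurons per hidden layer
    nn.ReLU(),
    nn.Linear(32, 32),
    nn.ReLU(),
    nn.Linear(32, 3),   # e.g. logits for "move up" / "stay" / "move down"
)

state = torch.tensor([[0.4, 0.7, -0.02, 0.01, 0.5]])  # one game state
action_logits = pong_net(state)                       # shape (1, 3)
```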
One thing I noticed while working on this is that the naive approach to training this AI was far from optimal: much of the training data I collected ended up being pretty irrelevant for playing well! E.g., how the paddle moves while the ball is closing in matters a lot, while what you do right after hitting the ball is almost entirely irrelevant. There were several such small insights, leading me to tweak how exactly the training data is collected (e.g. sampling it with lower probability while the ball is moving away than when it’s getting closer, sketched below), which greatly reduced the time it took for the AI to learn, even with the network architecture staying the same.
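A minimal sketch of that sampling tweak; the probabilities, state fields, and the helper name `maybe_record` are made up for illustration:

```python
import random

def maybe_record(state, action, dataset,
                 p_approaching=0.9, p_receding=0.2):
    """Record a (state, action) pair with a probability that depends on
    whether the ball is moving towards the AI's paddle or away from it."""
    ball_moving_towards_paddle = state["ball_vx"] > 0   # assumes the AI's paddle is on the right
    p = p_approaching if ball_moving_towards_paddle else p_receding
    if random.random() < p:
        dataset.append((state, action))

# inside the game loop, something like:
# maybe_record(current_state, human_action, training_data)
```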
Notably, this does not necessarily mean the loss curve dropped more quickly: since I changed the training data, the loss curves before and after the change measured quite different things. The same loss on higher-quality data is much more useful than on noisy or irrelevant data.
There are just so many degrees of freedom in all of this that it seems very likely that, even if there were no hardware advances at all, research would be able to come up with faster/cheaper/better-performing models for a long time.
> for text, you might realize that different parts of the text refer to each other, so you need a way to effectively pass information around, and hence you end up with something like the attention mechanism
If you are trying to convince yourself that a Transformer could work and to make it ‘obvious’ to yourself that you can model sequences usefully that way, it might be better to start with Bengio’s simple 2003 LM and MLP-Mixer. Then Transformers may just look like a fancier MLP which happens to implement a complicated way of doing token-mixing, inspired by RNNs and heavily tweaked empirically to eke out a bit more performance with various add-ons and doodads.
(AFAIK, no one has written a “You Could Have Invented Transformers”, going from n-grams to Bengio’s LM to MLP-Mixer to RNN to Set Transformer to Vaswani Transformer to a contemporary Transformer, but I think it is doable and useful.)
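To make “token-mixing with an MLP” concrete, here is a simplified MLP-Mixer-style block: one MLP mixes information across token positions (the role attention plays in a Transformer), another mixes across channels within each position. Sizes and details are illustrative rather than faithful to the original paper:

```python
import torch
import torch.nn as nn

class MixerBlock(nn.Module):
    """Simplified MLP-Mixer block: one MLP mixes across tokens, another across channels."""
    def __init__(self, num_tokens, dim, hidden=128):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.token_mlp = nn.Sequential(       # mixes information *between positions*,
            nn.Linear(num_tokens, hidden),    # i.e. the job attention does in a Transformer
            nn.GELU(),
            nn.Linear(hidden, num_tokens),
        )
        self.norm2 = nn.LayerNorm(dim)
        self.channel_mlp = nn.Sequential(     # per-position MLP, like a Transformer's feed-forward block
            nn.Linear(dim, hidden),
            nn.GELU(),
            nn.Linear(hidden, dim),
        )

    def forward(self, x):                     # x: (batch, num_tokens, dim)
        x = x + self.token_mlp(self.norm1(x).transpose(1, 2)).transpose(1, 2)
        x = x + self.channel_mlp(self.norm2(x))
        return x

x = torch.randn(2, 16, 64)                   # 2 sequences of 16 tokens, 64 dims each
out = MixerBlock(num_tokens=16, dim=64)(x)   # same shape, but positions have exchanged information
```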