Maybe I am misunderstanding something, but to me it is very intuitive that there is a big jump from the embedding output to the first transformer block output. The embedding is trained by backpropagation from the same next-word loss, so it makes sense to see all intermediate representations as representations of the prediction we are trying to make, i.e. of the next word.
But the embedding is a prediction of the next word based on only a single word: the word being embedded. So that prediction is by necessity very bad (BPE ensures this, IIUC, because tokens that would always follow one another get merged into a single token).
The first transformer block integrates hundreds of words of context into the prediction; that’s where the big jump comes from.
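To make that concrete, here is a minimal sketch of what I mean (my own, not from the post): decode the bare embedding of a single token straight through the output head and see what "next word" it encodes. I'm using the HuggingFace GPT-2 implementation; `transformer.wte`, `transformer.ln_f` and `lm_head` are how that library exposes the embedding, final layer norm and output head, and projecting through `ln_f` then `lm_head` is my assumption about how the post's decoding trick works:

```python
# Sketch: decode the raw embedding of one token straight through GPT-2's output
# head, skipping all transformer blocks, to see the context-free "prediction" it encodes.
# Attribute names follow the HuggingFace transformers GPT-2 implementation.
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel

tok = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2").eval()

ids = tok(" the", return_tensors="pt").input_ids         # a single token, no context at all
with torch.no_grad():
    emb = model.transformer.wte(ids)                      # embedding output only, no blocks
    logits = model.lm_head(model.transformer.ln_f(emb))   # project straight to the vocabulary
print(tok.decode([int(logits[0, -1].argmax())]))          # the bigram-ish guess this encodes
```

Whatever token that prints, the point is that it can only be a context-free guess; the first block is the earliest place where attention can mix in the rest of the prompt.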
Great post! I had trouble wrapping my head around the "inconsistency" in the first paper, but now I think I get it. TL;DR in my own words:
There are three regimes of increasing information uptake, ordered by how cheap they are in terms of compute (rough numbers in the sketch after the list):
- Increasing sampling efficiency by increasing model size —> this runs into diminishing returns because sampling efficiency has a hard upper bound. —> context window increase?
- Accessing more information by training over more unique samples —> will run into diminishing returns when unique data runs out. —> multi-modal data?
- Extracting more information by running over the same samples several times —> this intuitively crashes sampling efficiency because you can only learn the information not already extracted in earlier passes. —> prime candidate for active learning?
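A rough numerical sketch of how I read the first paper's combined fit L(N, D), just to see regimes 1 and 2 in numbers. The exponents and constants below are quoted from memory and may be off; only the functional form matters for the argument:

```python
# Sketch of the combined scaling-law fit L(N, D) from the first paper, with
# constants quoted from memory (treat them as approximate, not authoritative).
ALPHA_N, ALPHA_D = 0.076, 0.095      # model-size and data-size exponents
N_C, D_C = 8.8e13, 5.4e13            # critical parameter count / token count

def loss(n_params, n_tokens):
    """Approximate test loss in nats/token for a model of n_params trained on n_tokens."""
    return ((N_C / n_params) ** (ALPHA_N / ALPHA_D) + D_C / n_tokens) ** ALPHA_D

# Regime 1: fixed data budget, growing model. Returns diminish once the data term dominates.
for n_params in (1e9, 1e10, 1e11, 1e12):
    print(f"N={n_params:.0e}, D=3e11 tokens -> L ~ {loss(n_params, 3e11):.2f}")
```

The same function shows regime 2: hold N fixed and grow the token count, and the loss flattens out once the model-size term dominates instead.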
I had also missed the implication of the figure in the second paper that shows that GPT-3 is already very close to optimal sampling efficiency. So it seems that pure text models will only see another order of magnitude increase in parameters or so.
If you are looking for inspiration for another post about this topic: Gwern mentions the human level of language modeling, and Steve Omohundro also alludes to the loss that would signify human level, but I don’t really understand either the math or where the numbers come from. It would be very interesting to me to see an explanation of the "human level loss" to put the scaling laws in perspective. Of course I assume that a "human level" LM would have very different strengths and weaknesses compared to a human, but still.