Yeah, the first 99 tokens would be optimized both to be locally the correct token, and also to set things up so that the 100th token is also correct.
That is how LLMs currently work. The gradient of each token prediction does flow back into all the earlier tokens whose information was integrated into the predicted token. So each token optimizes its own next token prediction but also tries to integrate the information that is most useful for future tokens.
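To make that concrete, here is a minimal sketch (PyTorch, toy dimensions, a single causal attention layer standing in for a full transformer, all names illustrative) checking that the loss on a later position produces gradients on the earlier positions’ input embeddings:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
seq_len, d_model, vocab = 8, 16, 32

# Toy "transformer": one causal self-attention layer plus a vocabulary projection.
attn = nn.MultiheadAttention(d_model, num_heads=2, batch_first=True)
to_logits = nn.Linear(d_model, vocab)

x = torch.randn(1, seq_len, d_model, requires_grad=True)  # token embeddings
# True above the diagonal = future positions are masked out.
causal_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)

h, _ = attn(x, x, x, attn_mask=causal_mask)  # each position sees only itself and earlier positions
logits = to_logits(h)

# Compute the loss on the *last* position only.
loss = nn.functional.cross_entropy(logits[:, -1], torch.tensor([3]))
loss.backward()

# The gradient is nonzero at earlier positions too: getting the last prediction
# right also shapes what the earlier tokens write into their representations.
print(x.grad[0].norm(dim=-1))
```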
I do think saying “the system is just predicting one token at a time” is wrong, but I guess the way the work a transformer puts into token N gets rewarded or punished when it predicts token N + M feels really weird and confusing to me, and it still seems like it can be summarized much more as “it’s taking one token at a time” than as “it’s doing reasoning across the whole context”.
IIRC, at least for a standard transformer (which may have been modified by the recent context-length extensions), the gradients only flow through a subset of the weights (for a token halfway through the context, the gradients flow through half the weights that were responsible for the first token, IIRC).
Frankly, I don’t really understand what you are saying here and I am open to the possibility that I don’t really understand how the gradient works in autoregressive transformers.
But as I said in my other comment, my current understanding is:
In standard attention (for example in an encoder) tokens are not ordered, so it is clear that the gradient of the loss of one of the token predictions (for example a masked token in BERT) flows through all other tokens equally. In autoregressive transformers an order is imposed by masking, but all later tokens attend to all earlier tokens in the same way.
The gradient of the loss of a later token flows through all earlier tokens in the same way. It doesn’t matter whether a token is half the context back or the full context back, neither for the information flow nor for the gradient flow.
To put it another way: in the n-th layer, the last token attends to all the output tokens from the (n−1)-th layer. It doesn’t somehow have to make do with the output of earlier layers for tokens that are further back.
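A small sketch of that point (single-head attention with random projections, no batching, purely illustrative): within one layer, the causal mask only decides *which* positions the last query may attend to; everything it attends to comes from the previous layer’s outputs, regardless of how far back in the context it sits.

```python
import torch

seq_len, d = 5, 8
h_prev = torch.randn(seq_len, d)        # outputs of layer n-1, one row per position

q = h_prev @ torch.randn(d, d)          # layer-n queries
k = h_prev @ torch.randn(d, d)          # layer-n keys
scores = q @ k.T / d ** 0.5

# Causal mask: position i may attend to positions j <= i.
mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
weights = scores.masked_fill(~mask, float("-inf")).softmax(dim=-1)

# The last row puts nonzero weight on every position of h_prev: tokens far back
# in the context are read from the same layer as the most recent ones.
print(weights[-1])
```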
Yeah, I was indeed confused, sorry. I edited out the relevant section of the dialogue and replaced it with the correct relevant point (the aside here didn’t matter because a somewhat stronger condition is true: during training we always condition on the right answer from the training set, rather than on the model’s own output for the next token).
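For what it’s worth, here is a sketch of that parenthetical (teacher forcing), with a trivial embedding-plus-linear stand-in for a causal LM and made-up sizes: during training the prefix fed to the model is always the ground truth from the training set, whereas generation feeds back the model’s own outputs.

```python
import torch
import torch.nn as nn

vocab, d_model = 32, 16
# Stand-in for a causal LM: embedding + linear head (a real model would insert
# masked self-attention blocks in between).
model = nn.Sequential(nn.Embedding(vocab, d_model), nn.Linear(d_model, vocab))
tokens = torch.randint(0, vocab, (4, 10))   # ground-truth training sequences

# Training (teacher forcing): condition on the right answers at every position.
logits = model(tokens[:, :-1])
loss = nn.functional.cross_entropy(
    logits.reshape(-1, vocab), tokens[:, 1:].reshape(-1)
)

# Generation, by contrast, conditions on the model's own previous outputs.
ids = tokens[:, :1]
for _ in range(5):
    next_id = model(ids)[:, -1].argmax(-1, keepdim=True)
    ids = torch.cat([ids, next_id], dim=-1)
```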
In autoregressive transformers an order is imposed by masking, but all later tokens attend to all earlier tokens in the same way.
Yeah, the masking is what threw me off. I was trying to think about whether any information would flow from the internal representations used to predict the second token into predicting the third token, and indeed, if you were to backpropagate the error after each individual token prediction, then some information from predicting the second token would be available when predicting the third token (via the updated weights).
However, batch sizes make this inapplicable anyway (I think you would basically never do a backpropagation after each token, since that would throw away the whole benefit of parallel training), and even without that, the amount of relevant information flowing this way would be minuscule, and there wouldn’t be any learning going on for how this information flows.
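A sketch of that training-loop point (same kind of toy stand-in model, illustrative optimizer and sizes): the loss is summed over every position’s next-token prediction and the weights are updated once per batch, so there is no per-token backpropagation that could leak information from predicting token 2 into predicting token 3 within the same pass.

```python
import torch
import torch.nn as nn

vocab, d_model = 32, 16
model = nn.Sequential(nn.Embedding(vocab, d_model), nn.Linear(d_model, vocab))
opt = torch.optim.SGD(model.parameters(), lr=1e-2)

for _ in range(3):                              # a few illustrative batches
    batch = torch.randint(0, vocab, (8, 12))    # (batch, seq_len) token ids
    logits = model(batch[:, :-1])
    loss = nn.functional.cross_entropy(         # one loss over all positions at once
        logits.reshape(-1, vocab), batch[:, 1:].reshape(-1)
    )
    opt.zero_grad()
    loss.backward()                             # a single backward pass per batch
    opt.step()                                  # no weight update happens per token
```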