Frankly, I don’t really understand what you are saying here and I am open to the possibility that I don’t really understand how the gradient works in autoregressive transformers.
But as I said in my other comment, my current understanding is:
In standard attention (for example in an encoder) tokens are not ordered, so it is clear that the gradient of the loss of one of the token predictions (for example a masked token in BERT) flows through all other tokens equally. In autoregressive transformers an order is imposed by masking, but all later tokens attend to all earlier tokens in the same way.
The gradient of the loss of a later token flows through all earlier tokens in the same way. It doesn't matter whether a token is half the context back or the full context back, neither for the information flow nor for the gradient flow.
To put it another way: in the n-th layer, the last token attends to all the token outputs from the (n-1)-th layer. It doesn't somehow have to make do with the output of earlier layers for tokens that are further back.
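To make that concrete, here is a minimal PyTorch sketch (toy dimensions, a single head, made-up projection weights, so it's an illustration of the mechanism rather than anyone's actual implementation) of masked self-attention. The causal mask only hides later positions, and the gradient of a loss at the last position reaches every earlier position's layer-(n-1) output through the same single attention hop, no matter how far back it sits.

```python
import torch

torch.manual_seed(0)
seq_len, d_model = 6, 8

# "Layer n-1" outputs: one vector per position; requires_grad so we can inspect
# which positions the gradient reaches.
x = torch.randn(seq_len, d_model, requires_grad=True)

# Single-head attention with random (hypothetical) projection weights.
w_q = torch.randn(d_model, d_model) / d_model**0.5
w_k = torch.randn(d_model, d_model) / d_model**0.5
w_v = torch.randn(d_model, d_model) / d_model**0.5

q, k, v = x @ w_q, x @ w_k, x @ w_v
scores = q @ k.T / d_model**0.5                          # (seq_len, seq_len)

# Causal mask: position i may only attend to positions j <= i.
causal = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
scores = scores.masked_fill(causal, float("-inf"))
attn = torch.softmax(scores, dim=-1)
out = attn @ v                                           # "layer n" outputs

# Take a loss that only depends on the last position's output.
loss = out[-1].sum()
loss.backward()

# Every earlier position gets a nonzero gradient, and none of it had to route
# through intermediate positions: one attention hop covers the whole context.
print(x.grad.abs().sum(dim=-1))
```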
Yeah, I was indeed confused, sorry. I edited out the relevant section of the dialogue and replaced it with the correct relevant point (the aside here didn't matter because a somewhat stronger condition is true, which is that during training we always just condition on the right answer from the training set instead of conditioning on the model's own output for the next token).
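For what it's worth, the "condition on the right answer" point (usually called teacher forcing) is easy to see in code. A minimal sketch, assuming a toy embedding plus unembedding in place of the full transformer body: the inputs at every position are the ground-truth tokens, position t is scored on predicting token t+1, and the model's own samples never enter the context during training.

```python
import torch
import torch.nn.functional as F

vocab_size, d_model, seq_len = 100, 16, 8
tokens = torch.randint(0, vocab_size, (seq_len,))      # ground-truth training sequence

embed = torch.nn.Embedding(vocab_size, d_model)
unembed = torch.nn.Linear(d_model, vocab_size)

# Stand-in for the transformer body (omitted to keep the sketch tiny):
# the inputs are the true tokens, never the model's own previous outputs.
hidden = embed(tokens[:-1])
logits = unembed(hidden)                                # (seq_len - 1, vocab_size)

# Position t is scored on predicting the true token t+1.
loss = F.cross_entropy(logits, tokens[1:])
loss.backward()
```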
In autoregressive transformers an order is imposed by masking, but all later tokens attend to all earlier tokens in the same way.
Yeah, the masking is what threw me off. I was trying to think about whether any information would flow from the internal representations used to predict the second token to predicting the third token, and indeed, if you were to backpropagate the error after each specific token prediction, then there would be some information from predicting the second token available to predicting the third token (via the updated weights).
However, batching also makes this inapplicable (I think you would basically never do a backpropagation after each token; that would kind of get rid of the whole benefit of parallel training), and even without that, the amount of relevant information flowing this way would be minuscule and there wouldn't be any learning going on for how this information flows.
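On the batching point, a rough sketch of the usual setup (a generic toy model and placeholder names, just to show the shape of the loop): the losses at all positions of all sequences in the batch go into one backward pass and one optimizer step, so the weights used to predict the third token are exactly the weights used to predict the second one, and nothing can leak between them via updates.

```python
import torch
import torch.nn.functional as F

vocab_size, seq_len, batch = 100, 8, 4

# Toy stand-in for a language model: embedding followed by an unembedding.
model = torch.nn.Sequential(
    torch.nn.Embedding(vocab_size, 16),
    torch.nn.Linear(16, vocab_size),
)
opt = torch.optim.SGD(model.parameters(), lr=1e-2)

tokens = torch.randint(0, vocab_size, (batch, seq_len))
logits = model(tokens[:, :-1])                          # predictions for every position at once
loss = F.cross_entropy(logits.reshape(-1, vocab_size), tokens[:, 1:].reshape(-1))

opt.zero_grad()
loss.backward()
opt.step()                                              # one update per batch, not per token
```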