If I understand what you’re saying here, it’s true but fairly well-known? See e.g. footnote 26 of the post “Simulators.”
My favorite way of looking at this is:
The usual intuitive view of causal attention is that it’s an operation that “looks back” at earlier positions. At each position i, it computes a “query” based on information from position i, and this query is used to search over “keys and values” computed at positions i-1, i-2, etc. (as well as i itself).
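To make that concrete, here is a minimal single-head sketch of causal attention in PyTorch. The names and shapes here are mine, chosen for illustration; a real layer would also have multiple heads, an output projection, and so on.

```python
import torch
import torch.nn.functional as F

def causal_attention(x, W_Q, W_K, W_V):
    """Single-head causal self-attention over x of shape (N, d_model)."""
    N = x.shape[0]
    Q, K, V = x @ W_Q, x @ W_K, x @ W_V            # a query, key, and value at every position
    scores = Q @ K.T / K.shape[-1] ** 0.5           # scores[i, j] = (scaled) query_i . key_j
    mask = torch.triu(torch.ones(N, N), diagonal=1).bool()
    scores = scores.masked_fill(mask, float("-inf"))  # position i may only see positions j <= i
    return F.softmax(scores, dim=-1) @ V             # each row mixes the values at positions <= i
```

The mask is what gives the "looks back" picture: row i of the softmax can only put weight on columns j ≤ i, so the query at position i only ever interacts with keys and values from position i and earlier.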
OK, so at each position, attention computes a query. What makes a query “good”? Well, a good query is one that will “do something useful” in conjunction with keys and values computed at earlier positions.
But attention is also computing keys and values at each position. What makes a key or value “good”? Precisely that it will “do something useful” in conjunction with the queries computed at later positions!
The latter observation is just the flipside of the former. Queries at position i are encouraged to do useful lookback, on average over the “pasts” (i-1, …) encountered in training; keys and values at position i are encouraged to be useful for the lookbacks performed by later queries, on average over the “futures” (i+1, …) encountered in training.
This is complicated slightly by the fact that causal attention lets positions attend to themselves, but it’s easy to see that this is not a huge deal in practice. Consider that the keys and values computed at position i get used by...
...the attention operation at position i, when it attends to itself (along with all earlier positions)
...the attention operation at positions i+1, i+2, …, when they “look back” to position i
The K and V weights get gradients from all of these positions. So for a context window of size N, on average the gradient will be a sum over ~N/2 terms from future positions, plus just a single term from the current position. Since N >> 2 in practice, all else being equal we should expect this sum to be dominated by the future terms.
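If you want to see this directly, here is a small PyTorch check, reusing the sketch above; again, the sizes and names are arbitrary illustrations. It asks, for one fixed position i, which per-position losses actually send gradient into the key computed at i. The values behave the same way.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
N, d = 8, 16
x = torch.randn(N, d)
W_Q, W_K, W_V = (torch.randn(d, d, requires_grad=True) for _ in range(3))

# Same causal attention as the sketch above, kept inline so we can inspect K.
Q, K, V = x @ W_Q, x @ W_K, x @ W_V
mask = torch.triu(torch.ones(N, N), diagonal=1).bool()
scores = (Q @ K.T / d ** 0.5).masked_fill(mask, float("-inf"))
out = F.softmax(scores, dim=-1) @ V

# Which per-position losses send gradient into the key computed at position i?
i = 3
for t in range(N):
    g = torch.autograd.grad(out[t].sum(), K, retain_graph=True)[0]
    print(f"loss at position {t} reaches key {i}: {bool(g[i].abs().sum() > 0)}")
# False for t = 0, 1, 2; True for t = 3, 4, ..., 7.
# That's one term from the current position and N - 1 - i terms from future positions,
# which is ~N/2 future terms when averaged over i.
```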
(Moreover, note that the keys and values are more useful at future positions than at the current position, giving us even more reason to expect them to be mainly computed for the sake of future positions rather than the current one. The current position “already knows about itself” and doesn’t need attention to move information from itself to itself, whereas future positions can only learn about the current position by attending to it.
Sometimes there may be a computational role for a position attending to itself – such as doing something by default if nothing else “matched” a query – but all of the “magic” of attention is in the way it can move information between positions. Note that a self-attention layer which could only attend to the current position would just be equivalent to a linear layer.)
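A quick sanity check of that last claim, assuming a single head and ignoring the output projection: if the mask allows only the diagonal, every softmax row collapses to a weight of 1 on the current position, and the whole layer reduces to the fixed linear map x @ W_V.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
N, d = 8, 16
x = torch.randn(N, d)
W_Q, W_K, W_V = torch.randn(d, d), torch.randn(d, d), torch.randn(d, d)

# "Attention" that may only attend to the current position: mask off everything
# except the diagonal, so each softmax row puts all its weight on itself.
Q, K, V = x @ W_Q, x @ W_K, x @ W_V
scores = Q @ K.T / d ** 0.5
off_diagonal = ~torch.eye(N, dtype=torch.bool)
out = F.softmax(scores.masked_fill(off_diagonal, float("-inf")), dim=-1) @ V

print(torch.allclose(out, x @ W_V))  # True: the queries and keys no longer matter at all
```

With an output projection it would be x @ W_V @ W_O instead, which is still just a single linear layer; none of it depends on the queries or keys.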