A few points, none super confident.
- I like the search algorithm parallel, I had never thought of it that way!
- Since, as you said, it doesn’t reduce KV cache size (unless you do it on CPU), it can only speed up inference so much, because it won’t let you increase batch sizes (see my answers to Alex Gibson’s comment for why this matters, if you don’t already know).
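To illustrate why the KV cache is the binding constraint here, a back-of-envelope sketch (every model dimension below is a made-up, Llama-70B-ish assumption, not anyone’s real config):

```python
# Back-of-envelope KV-cache sizing with illustrative numbers:
# 80 layers, grouped-query attention with 8 KV heads, head dim 128, fp16.
n_layers = 80
n_kv_heads = 8
head_dim = 128
bytes_per_param = 2  # fp16/bf16

# Each token stores one K and one V vector per layer per KV head.
kv_bytes_per_token = 2 * n_layers * n_kv_heads * head_dim * bytes_per_param
print(kv_bytes_per_token)        # 327680 bytes, ~0.31 MiB per token

# At a 32k context, a single sequence's cache is already huge:
context_len = 32_768
kv_bytes_per_seq = kv_bytes_per_token * context_len
print(kv_bytes_per_seq / 2**30)  # 10.0 GiB per sequence
```

Since whatever VRAM isn’t holding weights has to hold these caches, the cache size per sequence directly caps the batch size, and a technique that speeds up the attention computation without shrinking the cache doesn’t raise that cap.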
- Unclear whether attention being efficient during training matters much because:
-- Pretraining is afaik done at context lengths short enough that attention being quadratic doesn’t matter that much.
-- Midtraining afaik takes a lot less compute than pretraining, so it’s probably not that important for it to be compute-efficient.
-- RL also requires inference (for the rollouts), so more efficient training would only help RL somewhat.
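A quick sanity check on the first sub-point, that the quadratic term barely matters at pretraining lengths. This compares the per-token FLOPs of the parameter-proportional "linear" parts of a transformer against the context-length-proportional attention term; all dimensions are made-up, Llama-70B-ish assumptions:

```python
# Rough per-token FLOPs split: linear parts scale with parameter count,
# the attention score/value matmuls scale with context length.
# All numbers below are illustrative assumptions.
n_params = 70e9
n_layers = 80
d_model = 8192

def attn_fraction(ctx_len):
    linear_flops = 2 * n_params  # ~2 FLOPs per parameter per token
    # QK^T plus attention-weighted V: ~2 matmuls over d_model
    # per past token per layer, ~2 FLOPs each.
    quad_flops = 4 * n_layers * ctx_len * d_model
    return quad_flops / (linear_flops + quad_flops)

print(f"{attn_fraction(4_096):.2f}")    # ~0.07 at a typical pretraining context
print(f"{attn_fraction(131_072):.2f}")  # ~0.71 at a 128k context
```

So under these assumptions the quadratic term is a small minority of compute at ~4k context but dominates at 128k, which is consistent with efficient attention mattering much more for long-context inference than for pretraining.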
- Yeah, Google seems to be good at efficient attention. Here is a blog post I liked showing how good they are at long-context benchmarks. I don’t have takes on whether they made attention subquadratic or just made it more efficient.
- Another way to make attention more feasible at long contexts is to just have more VRAM per node. Even without any architectural improvements, this gives you more room for the KV cache (so you can have bigger KV caches and bigger batch sizes). Vladimir_Nesov says here that Google’s TPUs are particularly good in this respect compared to Nvidia GPUs.
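To spell that last point out, a toy calculation (every number here is a hypothetical assumption): with fixed model weights, extra per-node memory goes straight into bigger batches, and decode throughput roughly tracks batch size when you’re memory-bound:

```python
# Toy sketch: per-node memory -> max batch size, all numbers hypothetical.
def max_batch(node_hbm_gib, weights_gib, kv_per_seq_gib):
    # Whatever memory isn't holding weights can hold KV caches.
    return int((node_hbm_gib - weights_gib) // kv_per_seq_gib)

weights = 140    # e.g. a ~70B model in fp16
kv_per_seq = 10  # assumed long-context KV cache per sequence, in GiB

print(max_batch(8 * 80, weights, kv_per_seq))   # 50 sequences on a 640 GiB node
print(max_batch(8 * 160, weights, kv_per_seq))  # 114 on a node with 2x the HBM
```

Note the batch more than doubles when HBM doubles, because the weights are a fixed cost.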