I think this article fails to list the key consideration around generation: output tokens require using a KV (key-value) cache, which demands substantial memory bandwidth and takes up a considerable amount of memory.
From my understanding the basic situation is:
For input (not output) tokens, you can get pretty close to the maximum FLOP utilization for realistic workloads. To make this efficient (and avoid memory bandwidth issues), you’ll need to batch up a bunch of tokens at once. This can be done by batching multiple input sequences, or even a single long sequence can be enough on its own. So, memory bandwidth isn’t currently a binding constraint for input tokens.
(You might also note that input tokens have a pretty similar work profile to model training as the forward pass and backward pass are pretty structurally similar.)
However, for generating output tokens, a key bottleneck is that you have to read the entire KV cache for each output token in order to implement attention. In practice, this means that on long sequences, the memory bandwidth for attention (due to needing to touch the whole KV cache) can be a limiting constraint. A further issue is that KV cache memory consumption forces us to use a smaller batch size. More details:
It will still be key to batch up tokens, but now we’re only doing computation on one new token per sequence, which means we’ll need to batch up many more sequences: the optimal number of sequences to batch for generating output tokens will be very different from the optimal number for input tokens (where we can run the transformer on the whole sequence at once).
A further difficulty is that because we need a higher batch size, we need a larger amount of KV cache data (a rough sizing sketch follows below). I think it’s common to use an otherwise suboptimally small batch size for generation due to constraints on VRAM (at least for consumer applications, e.g. llama-70b inference on 8xH100; I assume this also comes up for bigger models). We could store the KV cache in CPU memory, but then we might get bottlenecked on memory bandwidth to the CPU.
Note that in some sense the operations for each output token are the same as for each input token. So, why are the memory bandwidth requirements worse? The key thing is that we potentially get much worse cache locality on output tokens, due to computing only one token while needing to read the KVs for many tokens (while on input we do many-to-many).
However, it is possible to substantially reduce KV cache sizes using optimizations like sparse attention or architectures like Mamba. This can substantially improve inference speeds, both by reducing memory bandwidth requirements and by allowing for higher batch sizes. See e.g. the Mamba paper, where allowing for higher batch sizes results in substantially higher speeds.
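To make the memory pressure concrete, here’s a rough sizing sketch for llama-3-70b (80 layers, 8 KV heads via grouped-query attention, head dim 128), assuming an fp16 cache; the sequence length and batch size are just illustrative, not measured from any particular setup:

```python
# Rough KV cache sizing for llama-3-70b (80 layers, 8 KV heads with GQA,
# head_dim 128), assuming an fp16/bf16 cache. Sequence length and batch
# size are illustrative.
n_layers, n_kv_heads, head_dim, bytes_per_val = 80, 8, 128, 2

# K and V are each n_kv_heads * head_dim values per layer per token.
kv_bytes_per_token = 2 * n_layers * n_kv_heads * head_dim * bytes_per_val
print(f"~{kv_bytes_per_token / 1e3:.0f} KB of cache per token")                     # ~328 KB

seq_len, batch_size = 8192, 64
print(f"~{kv_bytes_per_token * seq_len / 1e9:.1f} GB per sequence")                 # ~2.7 GB
print(f"~{kv_bytes_per_token * seq_len * batch_size / 1e9:.0f} GB for the batch")   # ~172 GB

# For comparison, the fp16 weights are ~140 GB, so at moderate batch sizes
# the cache rivals the weights and eats most of the remaining VRAM budget.
print(f"~{70e9 * 2 / 1e9:.0f} GB of weights")
```

So a single long sequence’s cache (a few GB) is small next to the weights, but the batch sizes you want for good decode utilization multiply it up quickly, which is where the VRAM constraint above comes from.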
One additional note: I recently set up an inference setup for llama-3-70b on 8xH100. I can get about 100,000 tok/s on inputs, which is pretty close to full utilization (1e15 FLOP/s * 8 GPUs / 7e10 FLOP per forward pass). However, I get dramatically worse performance on generation, perhaps 3,200 tok/s. I’m doing generation with long prompts, and llama-3-70b has no sparse attention or other feature for reducing the KV cache (beyond grouped-query attention, which is standard these days), so the KV cache bites pretty hard. My setup probably isn’t very close to optimal, especially on output tok/s; I’m just using basic out-of-the-box stuff (vLLM).
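For intuition on why the input side can sit near the compute ceiling while generation can’t, here’s a rough roofline sketch; the H100 numbers are approximate and the whole thing is illustrative rather than exact:

```python
# Rough roofline: how many tokens must share each weight read before the
# matmuls become compute-bound rather than bandwidth-bound.
flops_per_s = 1e15          # approx. dense BF16 compute for one H100
hbm_bytes_per_s = 3.35e12   # approx. HBM bandwidth for one H100

bytes_per_param = 2             # fp16/bf16 weights
flops_per_param_per_token = 2   # ~2 FLOP per parameter per token

# A batch of B tokens does ~2*B FLOPs per parameter while reading it once
# (2 bytes), so the matmuls become compute-bound once B exceeds roughly:
critical_batch = (flops_per_s / hbm_bytes_per_s) * bytes_per_param / flops_per_param_per_token
print(f"~{critical_batch:.0f} tokens per batch to saturate compute")  # ~300

# Prefill gets this for free from a single long prompt. Decode only adds
# one token per sequence per step, so you need on the order of 300
# concurrent sequences just to keep the weights from being the bottleneck,
# and the per-sequence KV cache reads add bandwidth cost on top of that.
```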
Thanks, that’s interesting!
Can I double check: do you think this affects the bottom lines?
The bottom line is supposed to be that FLOP/s vs. FLOP per forward pass can be used as an upper bound, memory bandwidth vs. model size can be used as a lower bound, and real-life efficiency falls somewhere in the middle depending on many factors (incl. length of KV cache), which I don’t try to get into, but is plausibly around 15% of the upper bound for GPT-4 on H100s.
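Concretely, since GPT-4’s size isn’t public, here is what those two bounds look like for a hypothetical 70B-parameter dense model in fp16 on a single H100 (approximate specs; purely illustrative):

```python
# The two bounds from above, per GPU: FLOP/s vs. FLOP per forward pass
# (upper), and memory bandwidth vs. model size (lower). Illustrative
# numbers for a hypothetical 70B dense model in fp16 on one H100.
flops_per_s = 1e15            # approx. dense BF16 compute
hbm_bytes_per_s = 3.35e12     # approx. HBM bandwidth
n_params = 70e9

flop_per_token = 2 * n_params   # ~2 FLOP per parameter per token
model_bytes = 2 * n_params      # fp16 weights

upper = flops_per_s / flop_per_token    # compute-bound ceiling
lower = hbm_bytes_per_s / model_bytes   # re-read all weights for every token

print(f"upper bound: ~{upper:.0f} tok/s per GPU")   # ~7,100
print(f"lower bound: ~{lower:.0f} tok/s per GPU")   # ~24
# The bounds are ~300x apart, which is why the "somewhere in the middle"
# estimate is doing most of the work.
```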
Are you saying that the lower bound for output tokens should maybe be even lower, because the KV cache can be larger than the model weights?
The lower bound of “memory bandwidth vs. model size” is effectively equivalent to assuming that the batch size is a single token. I think this isn’t at all close to realistic operating conditions and thus won’t be a very tight lower bound. (Or reflect the most important bottlenecks.)
I think that the KV cache for a single sequence won’t be larger than the model weights for realistic workloads, so the lower bound should still be a valid lower bound. (Though not a tight one.)
I think the bottom-line number you provide for “rough estimate of actual throughput” ends up being pretty reasonable for output tokens and considerably too low for input tokens. (I think input tokens probably get more like 50% or 75% FLOP utilization rather than 15%. See also the difference between input and output token prices for Anthropic models.)
That said, aggregating the lower and upper bounds you have doesn’t seem like a good mechanism for estimating throughput, as the lower bound doesn’t have much correspondence with actual bottlenecks. (For instance, this lower bound would miss that Mamba would get much higher throughput.)
I also think that insofar as you care about factors of 3-5 on inference efficiency, you need to do separate analyses for input tokens and output tokens.
(I also think that input tokens get pretty close to the pure FLOP estimate. So, another estimation approach you can use if you don’t care about factors of 5 is to just take the pure FLOP estimate and then halve it to account for other slowdowns. I think this estimate gets input tokens basically right and is wrong by a factor of 3-5 for output tokens.)
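As a concrete version of that heuristic (same illustrative assumptions as before: a hypothetical 70B dense model and one H100 at roughly 1e15 FLOP/s; the 2x and 3-5x factors are the rough discounts described above, not measurements):

```python
# Rough estimation heuristic from the discussion: start from the pure-FLOP
# ceiling, halve it for input tokens, and apply a further ~3-5x discount
# for output tokens. All numbers are illustrative.
flops_per_s = 1e15     # one H100, approximate dense BF16
n_params = 70e9        # hypothetical dense model

pure_flop_ceiling = flops_per_s / (2 * n_params)
input_estimate = pure_flop_ceiling / 2
output_estimate = input_estimate / 4   # middle of the 3-5x range

print(f"pure FLOP ceiling: ~{pure_flop_ceiling:.0f} tok/s per GPU")
print(f"input-token estimate: ~{input_estimate:.0f} tok/s per GPU")
print(f"output-token estimate: ~{output_estimate:.0f} tok/s per GPU")
```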
It seems like your actual mechanism for estimating utilization on output tokens was to take the number from SemiAnalysis and extrapolate it to other GPUs. (At least, the number matches this?) This does seem like a reasonable approach, but it isn’t particularly tethered to your lower bound.
I agree the lower bound for output isn’t very tight. I’d be very interested to hear other simple rules of thumb you could use to provide a tighter one.
I’ll add a note to the section on input tokens that since they don’t require KV cache, it’s possible to get much closer to the upper bound.