Yes, your model is correct. I wanted to make things as simple as possible when writing the blogpost, but probably went too far with this one and ended up making it confusing / partially inaccurate. There are two reasons autoregressive LLM inference is inefficient at long contexts:
- You need to load the whole KV cache from VRAM at every forward pass.
- Since you need to store the whole KV cache in VRAM for each sequence and KV caches are big, you can only fit a small number of them, so you can only run small batch sizes. Small batches make inference inefficient because the weights still have to be loaded from VRAM at every forward pass, and that cost is now amortized over fewer tokens (see the sketch right after this list for rough numbers).
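To make the second point concrete, here is a back-of-the-envelope sketch of how the KV cache eats into batch size as context grows. All the numbers are my own illustrative assumptions (roughly a Llama-3-8B-like config with 8 KV heads, an 80 GB GPU, fp16 everywhere), not anything from the post:

```python
# Back-of-the-envelope KV-cache sizing. All config and hardware numbers are
# illustrative assumptions (~Llama-3-8B-like model, 80 GB GPU, fp16).

def kv_cache_bytes(seq_len, n_layers=32, n_kv_heads=8, d_head=128, bytes_per_elem=2):
    # 2x for keys and values, stored for every layer and KV head at every position.
    return 2 * n_layers * n_kv_heads * d_head * seq_len * bytes_per_elem

vram_bytes = 80e9      # assumed 80 GB GPU
weight_bytes = 16e9    # ~8B params in fp16
free_for_kv = vram_bytes - weight_bytes

for seq_len in (4_000, 32_000, 128_000):
    per_seq = kv_cache_bytes(seq_len)
    max_batch = int(free_for_kv // per_seq)
    print(f"seq_len={seq_len:>7,}: KV cache per sequence = {per_seq / 1e9:.1f} GB, "
          f"max batch size = {max_batch}")
```

On these assumed numbers the maximum batch size drops from over a hundred at 4k context to single digits at 128k, which is exactly the regime where the weight-loading cost per token blows up.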
-- Explanation of why big batch sizes are important for making LLM inference efficient (skip if you already know): GPUs have a lot more FLOPs than they have memory bandwidth. If you multiply `batch_size` vectors of dimension `d_model` by a `d_model x d_model` (or `d_model x d_mlp`, or whatever) matrix, you need `O(d_model * d_model + batch_size * d_model)` memory reads and `O(batch_size * d_model * d_model)` FLOPs. So at small batch sizes this is bottlenecked by VRAM reads and most compute units just sit idle, while at big batch sizes it is bottlenecked by FLOPs.

I also am somewhat surprised that it's so hard to make attention more efficient.
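To put rough numbers on the bandwidth-vs-FLOPs point above, here is a minimal sketch of the arithmetic intensity of that matmul at different batch sizes. The hardware figures (A100-ish: ~312 TFLOP/s fp16, ~2 TB/s memory bandwidth) and `d_model = 4096` are assumptions for illustration:

```python
# Rough arithmetic-intensity estimate for y = x @ W with x: (batch_size, d_model)
# and W: (d_model, d_model). Hardware numbers are illustrative assumptions
# (A100-ish: ~312 TFLOP/s fp16, ~2 TB/s HBM bandwidth).

def matmul_stats(batch_size, d_model, bytes_per_elem=2):
    flops = 2 * batch_size * d_model * d_model  # one multiply-add per weight per token
    bytes_moved = bytes_per_elem * (d_model * d_model + 2 * batch_size * d_model)  # load W and x, store y
    return flops, bytes_moved

peak_flops = 312e12      # FLOP/s (assumed)
peak_bandwidth = 2e12    # bytes/s (assumed)
critical_intensity = peak_flops / peak_bandwidth  # FLOPs per byte needed to keep compute busy

d_model = 4096
for batch_size in (1, 16, 256, 4096):
    flops, bytes_moved = matmul_stats(batch_size, d_model)
    intensity = flops / bytes_moved
    regime = "memory-bound" if intensity < critical_intensity else "compute-bound"
    print(f"batch_size={batch_size:>5}: ~{intensity:.0f} FLOPs per byte ({regime})")
```

On these assumed numbers the crossover from memory-bound to compute-bound lands at a batch size in the low hundreds, which is why the small, KV-cache-limited batches you get at long contexts leave most of the compute idle.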