Text diffusion LLMs can be more efficient than autoregressive models in practice because it is usually more efficient on GPUs to do one big operation than many small operations in sequence, even when both require the same number of FLOPs[1].
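To illustrate that point (this is my own sketch, not from the post): the snippet below times one large matrix multiply against many small sequential multiplies that perform the same total FLOPs. The sizes are arbitrary assumptions; on a GPU the single large call is typically much faster because it keeps the compute units busy instead of paying per-kernel launch and memory-transfer overhead repeatedly.

```python
# Minimal sketch: one big GPU op vs. many small sequential ops at equal FLOPs.
# Sizes are illustrative assumptions, not taken from the post.
import time
import torch

assert torch.cuda.is_available()
d = 4096
n_chunks = 128  # how many sequential pieces the batch is split into

W = torch.randn(d, d, device="cuda", dtype=torch.float16)
X = torch.randn(n_chunks, 64, d, device="cuda", dtype=torch.float16)

def timed(fn):
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    fn()
    torch.cuda.synchronize()
    return time.perf_counter() - t0

_ = X[0] @ W  # warm-up so kernel compilation is not counted

# One big matmul over the whole batch.
big = timed(lambda: X.reshape(-1, d) @ W)

# The same FLOPs, issued as many small matmuls in sequence.
def small_ops():
    for i in range(n_chunks):
        _ = X[i] @ W

small = timed(small_ops)
print(f"one big op: {big*1e3:.2f} ms, {n_chunks} small ops: {small*1e3:.2f} ms")
```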
My mental model was that autoregressive models are slow during decoding because of the memory transfers required on each token: you need to swap the entire KV cache in and out of global memory for every token. I imagine this is what you mean here; I just wanted to check that this model is along the right lines?
I feel like there must be a way to allow for more flexibility in memory usage. There are clearly scenarios where what you want is MCMC-style, constant-memory Markov search, i.e. recall. In cases like this, autoregressive transformers are just painful because they inherently tie the number of iterations to memory usage. But there are scenarios where constant memory loses as well, i.e. sparse recall tasks.
Yes, your model is correct. I wanted to make things as simple as possible when writing the blogpost but probably went too far with this one and ended up making it confusing / partially inaccurate. There are two reasons autoregressive LLM inference is inefficient at long contexts:

- You need to load the whole KV cache from VRAM at every forward pass.
- Since you need to store the whole KV cache in VRAM for each sequence and KV caches are big, you can only fit a small number of them, so you can only run small batch sizes. This makes inference inefficient because the cost of loading the weights from VRAM at every forward pass is amortized over fewer sequences.

-- Explanation of why big batch sizes are important for making LLM inference efficient (skip if you already know): GPUs have far more FLOPs than memory bandwidth. So if you multiply `batch_size` vectors of dimension `d_model` by a `d_model x d_model` (or `d_model x d_mlp`, or whatever) matrix, you need O(d_model * d_model + batch_size * d_model) memory reads and O(batch_size * d_model * d_model) FLOPs. At small batch sizes this is bottlenecked by VRAM reads and most compute units just stay idle; at big batch sizes it is bottlenecked by FLOPs.
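To make that reads-vs-FLOPs argument concrete, here is a back-of-envelope roofline calculation (my own illustration; the model dimension and accelerator specs are assumptions, roughly H100-class) showing how the arithmetic intensity of that batched matmul grows with batch size and where it crosses from memory-bound to compute-bound.

```python
# Back-of-envelope roofline for multiplying batch_size vectors of dimension
# d_model by a d_model x d_model matrix. All numbers are illustrative assumptions.
d_model = 4096
bytes_per_param = 2          # fp16 / bf16

# Assumed accelerator specs (roughly H100-class; adjust for your hardware).
peak_flops = 1.0e15          # ~1000 TFLOP/s dense fp16
mem_bandwidth = 3.35e12      # ~3.35 TB/s HBM bandwidth

for batch_size in [1, 8, 64, 256, 1024]:
    flops = 2 * batch_size * d_model * d_model                    # multiply-adds
    bytes_moved = bytes_per_param * (d_model * d_model             # weight matrix
                                     + 2 * batch_size * d_model)   # activations in/out
    intensity = flops / bytes_moved                                 # FLOPs per byte read/written
    compute_time = flops / peak_flops
    memory_time = bytes_moved / mem_bandwidth
    bound = "memory-bound" if memory_time > compute_time else "compute-bound"
    print(f"batch {batch_size:5d}: {intensity:7.1f} FLOPs/byte -> {bound}")
```

With these assumed specs the crossover sits around batch sizes in the hundreds, which is why being forced into small batches (because KV caches fill VRAM) leaves most of the GPU's compute idle.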
I also am somewhat surprised that it’s so hard to make attention more efficient.