Sometimes FLOP/s isn’t the bottleneck for training models; it could, for example, be memory bandwidth. My impression from poking around with Nsight and some other observations is that wide SAEs might actually be FLOP/s bottlenecked, but I don’t trust that impression much. I’d be interested in someone comparing these SAE architectures in terms of H100-seconds (or something like that) in addition to FLOPs.
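To make “H100-seconds” concrete, I mean something like the sketch below (an invented helper, not anyone’s actual benchmarking code): time whole training steps with CUDA events, so memory-bandwidth-bound time counts the same as FLOP-bound time.

```python
import torch

def gpu_seconds_per_step(model, batch, n_warmup=10, n_steps=50):
    """Wall-clock cost of one training step, timed with CUDA events.

    `model` is assumed to be any SAE module whose forward returns a
    reconstruction of `batch`; this is a sketch, not the repo's code.
    """
    opt = torch.optim.Adam(model.parameters())
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    def step():
        opt.zero_grad(set_to_none=True)
        loss = (model(batch) - batch).pow(2).mean()  # plain reconstruction loss
        loss.backward()
        opt.step()

    for _ in range(n_warmup):  # warm up kernels and the allocator
        step()
    torch.cuda.synchronize()

    start.record()
    for _ in range(n_steps):
        step()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / 1000.0 / n_steps  # seconds per step
```

Plotting loss against summed step times rather than analytic FLOPs would make a memory-bandwidth bottleneck visible.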
Did it seem to you like this architecture also trained faster in terms of wall-time?
Anyway, nice work! It’s cool to see these results.
If this is your implementation:
https://github.com/amudide/switch_sae/blob/main/dictionary_learning/trainers/switch.py
It looks like when you encode, you do a dense encoder forward pass and then mask using the expert router.
I think this means the FLOP scaling-laws claim is misleading, because (my impression is that) your current training code uses far more FLOPs than the scaling-law graphs account for: it computes every expert’s activations for every input.
But I think the empirical claims about the learned features and the FLOP scaling laws should still hold up for implementations that actually do the conditional computation.
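Roughly, the distinction I have in mind (a schematic sketch with invented shapes and names, not your actual code):

```python
import torch

# Schematic only: d = model dim, E = num experts, m = latents per expert.
# W_enc has shape (E, m, d); `experts` holds the routed expert index per token.

def dense_then_mask(x, W_enc, experts):
    # Computes every expert's pre-activations for every token, then keeps one:
    # FLOPs scale with E * m * d per token even though most latents are discarded.
    acts = torch.einsum("emd,nd->nem", W_enc, x)                       # (n, E, m)
    mask = torch.nn.functional.one_hot(experts, acts.shape[1]).bool()  # (n, E)
    return acts[mask].view(x.shape[0], -1)                             # (n, m)

def conditional(x, W_enc, experts):
    # Only multiplies each token by its routed expert's weights:
    # FLOPs scale with m * d per token.
    out = torch.empty(x.shape[0], W_enc.shape[1], device=x.device, dtype=x.dtype)
    for e in range(W_enc.shape[0]):
        idx = (experts == e).nonzero(as_tuple=True)[0]
        if idx.numel():
            out[idx] = x[idx] @ W_enc[e].T
    return out
```

Both produce the same latents for the routed expert; the difference is just whether the other experts’ matmuls happen at all.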
I also expect H100/B100-time scaling charts to be more informative than FLOP charts for future work, because I now think memory bandwidth has decent odds of being the main bottleneck for training time.
Thanks for the comment! I trained TopK SAEs at various widths (all fitting within a single GPU) and observed that wider SAEs take substantially longer to train, which leads me to believe the encoder forward pass is a major bottleneck for wall-clock time. The Switch SAE also improves memory efficiency, since we do not need to store all M latents.
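For a back-of-the-envelope on the memory point (the numbers below are illustrative, not from our runs):

```python
# Encoder pre-activation memory that must be materialized per batch
# (illustrative sizes, fp32; not numbers from the paper).
n_tokens = 8192        # tokens per batch
M = 2**20              # total latents
E = 64                 # number of experts
bytes_per = 4          # fp32

dense_topk = n_tokens * M * bytes_per          # every latent computed and stored
switch     = n_tokens * (M // E) * bytes_per   # only the routed expert's latents

print(f"dense TopK: {dense_topk / 2**30:.0f} GiB, Switch: {switch / 2**30:.1f} GiB")
# -> dense TopK: 32 GiB, Switch: 0.5 GiB
```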
I’m currently working on implementing expert-parallelism, which I hope will lead to substantial improvements to wall-clock time.
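The rough shape of what I mean by expert parallelism (a sketch of the general pattern, not the code I’m writing; it assumes one expert per rank and that `experts` holds each token’s destination rank):

```python
import torch
import torch.distributed as dist

def expert_parallel_encode(x, experts, W_local):
    """Schematic expert parallelism: each rank owns one expert's encoder
    weights `W_local`, and tokens are exchanged with all_to_all so every
    token is processed on the rank that holds its routed expert."""
    world = dist.get_world_size()

    # Sort tokens by destination rank and count how many go to each rank.
    order = torch.argsort(experts)
    x_sorted = x[order]
    send_counts = torch.bincount(experts, minlength=world)

    # Exchange per-rank token counts, then the tokens themselves.
    recv_counts = torch.empty_like(send_counts)
    dist.all_to_all_single(recv_counts, send_counts)
    x_recv = torch.empty(int(recv_counts.sum()), x.shape[1],
                         device=x.device, dtype=x.dtype)
    dist.all_to_all_single(
        x_recv, x_sorted,
        output_split_sizes=recv_counts.tolist(),
        input_split_sizes=send_counts.tolist(),
    )

    # Local expert computation on this rank's shard of tokens.
    acts_local = torch.relu(x_recv @ W_local.T)

    # (A real implementation would all_to_all the activations back
    # and undo the sort before decoding.)
    return acts_local
```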