If this is your implementation:
https://github.com/amudide/switch_sae/blob/main/dictionary_learning/trainers/switch.py
It looks like when you encode, you do a dense encoder forward pass and then mask the activations using the expert router.
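For concreteness, here is roughly what that pattern looks like (a minimal hypothetical sketch, not the actual code from the repo; `W_enc`, `b_enc`, `router`, and the top-1 routing are my own assumptions):

```python
import torch
import torch.nn.functional as F

def dense_then_mask_encode(x, W_enc, b_enc, router):
    # x: [batch, d_model]; W_enc: [n_experts, d_model, d_expert]; b_enc: [n_experts, d_expert]
    acts = F.relu(torch.einsum("bd,edk->bek", x, W_enc) + b_enc)  # every expert's activations computed
    expert_idx = router(x).argmax(dim=-1)                         # top-1 expert per input, [batch]
    mask = F.one_hot(expert_idx, num_classes=W_enc.shape[0]).to(acts.dtype)  # [batch, n_experts]
    return acts * mask.unsqueeze(-1)                              # non-selected experts zeroed after the fact
```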
I think this makes the FLOP scaling laws claim misleading, because (my impression is that) your current training code uses many more FLOPs than the scaling law graphs account for: it computes every expert's activations for every input.
But I think the empirical claims about the learned features, and the FLOP scaling laws themselves, should still hold up for implementations that actually do the conditional computation.
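A conditional-computation version would apply only the routed expert's weights to each input, e.g. (again a hypothetical sketch with the same made-up names; a real implementation would presumably use a vectorized gather/scatter rather than a Python loop):

```python
import torch
import torch.nn.functional as F

def conditional_encode(x, W_enc, b_enc, router):
    # Same shapes as above; only the selected expert's weights are applied per input.
    expert_idx = router(x).argmax(dim=-1)                         # [batch]
    out = x.new_zeros(x.shape[0], W_enc.shape[0], W_enc.shape[2])
    for e in range(W_enc.shape[0]):
        sel = expert_idx == e
        if sel.any():
            # FLOPs here scale with the number of tokens routed to expert e,
            # not with the total number of experts.
            out[sel, e] = F.relu(x[sel] @ W_enc[e] + b_enc[e])
    return out
```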
I also expect H100/B100 wall-clock-time scaling charts to be more informative than FLOP charts for future work, because I now think memory bandwidth has decent odds of being the main bottleneck for training time.
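For wall-clock charts, one simple way to get the numbers is to time training steps directly with CUDA events (`train_step` and its arguments here are placeholders, not functions from the repo):

```python
import torch

def ms_per_step(train_step, *args, warmup=3, iters=10):
    # Average GPU wall-clock time per call to train_step, in milliseconds.
    for _ in range(warmup):
        train_step(*args)
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        train_step(*args)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters
```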