This is great. I’m a bit surprised you get such a big performance improvement from adding additional sparse layers; all of my experiments above have been adding non-sparse layers, but it looks like the MSE benefit you’re getting with added sparse layers is in the same ballpark. You have certainly convinced me to try muon.
Another approach that I’ve (very recently) found quite effective at reducing the number of dead neurons, with minimal MSE hit, is adding a small penalty term on the standard deviation of the encoder pre-activation means (i.e., before the top-k) across the batch dimension. This has basically eliminated my dead neuron woes, and it is what I’m currently running with. I’ll probably try it in combination with muon sometime over the next couple of days.
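For concreteness, here is a minimal sketch of one way to read that penalty. It assumes the pre-activations are averaged over the batch to get one mean per latent, and that the standard deviation is then taken across latents. The function name `preact_mean_std_penalty` and the coefficient `coeff` are hypothetical, and NumPy is used only to show the arithmetic; in an actual training run the same expression would be written in an autodiff framework and added to the MSE loss.

```python
import numpy as np

def preact_mean_std_penalty(pre_acts: np.ndarray, coeff: float = 1e-3) -> float:
    """Penalty on the spread of per-latent mean pre-activations.

    pre_acts: array of shape (batch, n_latents), the encoder outputs
              before the top-k is applied.
    coeff:    hypothetical penalty weight; the value here is a placeholder.
    """
    # Mean pre-activation of each latent over the batch dimension.
    per_latent_mean = pre_acts.mean(axis=0)  # shape: (n_latents,)
    # Standard deviation of those means across latents (assumed reading).
    return coeff * float(per_latent_mean.std())

# Toy illustration: latents whose mean pre-activation sits far below the
# others (and so would rarely make the top-k) increase the penalty.
rng = np.random.default_rng(0)
balanced = rng.normal(size=(256, 64))
skewed = balanced.copy()
skewed[:, :8] -= 5.0  # push 8 latents well below the rest

print(preact_mean_std_penalty(balanced))
print(preact_mean_std_penalty(skewed))
```

Under this reading, the penalty is zero when every latent has the same mean pre-activation and grows as some latents drift away from the rest, which is the situation in which they stop being selected by the top-k.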
And these ideas all sound great.