My guess would be that the model is ‘grokking’ something: https://mathai-iclr.github.io/papers/papers/MATHAI_29_paper.pdf
IOW it’s found a much better internal representation, and now has to rework a lot of its belief space to make use of it.
“The training algorithm has found a better representation”?? That seems strange to me, since the loss should be lower in that case, not spiking. Or maybe you mean that the training broke free of a kind of local minimum (without it having found a better one yet). Also, I’d guess the people training these models observed that waiting out these spikes doesn’t lead to better performance, or they would not have removed them from the training.
Around this idea, and after looking at the “grokking” paper, I would guess that it’s more likely caused by the weight decay (or something similar) pushing the training out of a kind of local minimum. An interesting point may be that larger/better LMs have significantly sharper internal models and are thus more prone to this phenomenon (the weight decay more easily breaking the more sensitive/better/sharper models).
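To make that mechanism concrete (this is just the standard decoupled weight-decay update, not something taken from the paper): each optimizer step does roughly

w ← w − η·∇L(w) − η·λ·w

so the weights are shrunk by a fraction ηλ of their magnitude on every step, regardless of the gradient. In a very sharp minimum, even that small multiplicative shrinkage can be enough to knock the weights off the narrow solution and make the loss jump.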
It should be very easy to check whether these spikes are caused by the weight decay “damaging” very sharp internal models: for example, replay the spiky segment several times with less and less weight decay, as in the sketch below. (I’d also be curious about similar tests varying the momentum, dropout, etc., about whether the spikes are initially triggered by some subset of the network, and about how many training steps the spikes last...)
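A minimal sketch of the kind of replay experiment I mean, assuming a PyTorch-style setup; the toy model, the synthetic batches, and the weight-decay values below are placeholders standing in for the real pre-spike checkpoint and data stream:

```python
import copy
import torch
import torch.nn as nn

# Toy stand-ins for the real model and the data stream around the spike.
model = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 1))
spike_batches = [(torch.randn(16, 32), torch.randn(16, 1)) for _ in range(100)]

# Pretend this is the checkpoint taken shortly before the spike started.
pre_spike_state = copy.deepcopy(model.state_dict())

def replay(weight_decay, lr=1e-3):
    """Re-run the 'spiky' segment from the same checkpoint with a given weight decay."""
    model.load_state_dict(copy.deepcopy(pre_spike_state))
    opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    losses = []
    for x, y in spike_batches:
        opt.zero_grad()
        loss = nn.functional.mse_loss(model(x), y)
        loss.backward()
        opt.step()
        losses.append(loss.item())
    return losses

# Replay the same segment with progressively smaller weight decay and see
# whether the loss excursion shrinks or disappears.
for wd in (1e-1, 1e-2, 1e-3, 0.0):
    curve = replay(weight_decay=wd)
    print(f"weight_decay={wd}: max loss {max(curve):.3f}, final loss {curve[-1]:.3f}")
```

If the spike really is weight-decay damage to a sharp solution, the excursion should shrink as the decay is reduced; if it persists at zero decay, the cause is presumably elsewhere (data, optimizer state, numerics).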
You use different terminology for both. Perhaps exiting local minima is not always a good thing?
Am I right in thinking that, according to your theory, the “fix” they used (restarting training from a checkpoint 100 steps before the spike started, but with different data, to avoid the spike) is actually counterproductive, because it prevents the model from grokking? And that instead they should have just kept training to “push through the spike” and reach a new, lower-loss regime?
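For concreteness, my mental model of that fix is something like the sketch below; the function name and the 200-batch skip window are illustrative placeholders, not the actual recipe they used:

```python
def resume_and_skip(model, optimizer, data_iter, checkpoint, skip_batches=200):
    """Restart from a checkpoint taken shortly before the spike, then advance the
    data stream past the batches the model was about to see, so training resumes
    on different data. (Sketch only: names and the skip window are made up.)"""
    model.load_state_dict(checkpoint["model"])
    optimizer.load_state_dict(checkpoint["optimizer"])
    for _ in range(skip_batches):
        next(data_iter)  # discard the data around where the spike occurred
    return model, optimizer, data_iter
```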
Now I’m not saying it’s anthropic pressure, but if that’s true, maybe we shouldn’t just keep training, at least not until we know what exactly it is that the model is grokking.
Whatever is happening, I’m really concerned about the current “sufficiently big model starts to exhibit <weird behaviour A>. I don’t understand, but also don’t care, here is a dirty workaround and just give it more compute lol” paradigm. I don’t think this is very safe.
If I could get people to change that paradigm, you bet I would.