“The training algorithm has found a better representation”?? That seems strange to me, since the loss should be lower in that case, not spiking. Or maybe you mean that the training broke free of a kind of local minimum (without having found a better one yet). Also, I would guess the people training these models observed that waiting out the spikes doesn't lead to better performance, or they would not have removed them from the training.
Along these lines, and after looking at the “grokking” paper, I would guess it's more likely caused by the weight decay (or something similar) pushing the training out of a kind of local minimum. An interesting point may be that larger/better LMs have significantly sharper internal models and are thus more prone to this phenomenon (the weight decay more easily breaking the more sensitive/sharper models).
It should be fairly easy to check whether these spikes are caused by weight decay “damaging” very sharp internal models: for example, replay the spiky part of training several times with progressively smaller weight decay. (I would also be curious about similar tests varying momentum, dropout, etc., about whether the spikes are initially triggered by some subset of the network, and about how many training steps the spikes last.)
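A minimal sketch of what I mean by the replay test, on a stand-in model rather than a real LM: use a 1-D quadratic with a sharp minimum as a proxy for a “sharp internal model”, restart from the same checkpoint with several weight-decay values, and record the peak loss seen in each replay. All names here (`replay_segment`, the `sharpness` constant, the hyperparameters) are hypothetical choices for illustration, not anyone's actual training setup.

```python
# Toy replay experiment: same checkpoint, same segment length,
# only the weight-decay coefficient changes between runs.

def loss_and_grad(w, sharpness=100.0):
    # Sharp quadratic valley centred at w = 1.0; `sharpness` stands in
    # for how sharp the model's internal representation is.
    diff = w - 1.0
    return sharpness * diff * diff, 2.0 * sharpness * diff

def replay_segment(checkpoint_w, steps, lr, weight_decay):
    """Rerun `steps` updates from the same checkpoint; return the peak loss."""
    w = checkpoint_w
    peak = 0.0
    for _ in range(steps):
        loss, grad = loss_and_grad(w)
        peak = max(peak, loss)
        # Decoupled weight decay, AdamW-style: pulls w toward 0
        # even when the gradient at the minimum is zero.
        w -= lr * grad + lr * weight_decay * w
    return peak

checkpoint = 1.0  # start exactly at the sharp minimum
for wd in (0.3, 0.1, 0.03, 0.0):
    print(f"weight_decay={wd}: peak loss {replay_segment(checkpoint, 50, 0.009, wd):.2e}")
```

In this toy, the decay term alone displaces the weights from the sharp minimum, so the peak loss shrinks as weight decay is lowered; whether real loss spikes behave the same way is exactly what the replay test would probe.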
You use different terminology for the two cases. Perhaps exiting a local minimum is not always a good thing?