Thanks for the post, I quite enjoyed it! I was especially happy to see the significantly more detailed graphs from the replication than the ones I had seen before.
Re: comparison to my explanation:
I think memorization is actually straightforward for a network to do. You just need a key-value system where the key detects the embedding associated with a given input, then activates the value, which produces the correct output for the detected input. Such systems are easy to implement in many ML architectures (feed forward, convolution, self attention, etc).
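(As an aside, the quoted mechanism is easy to exhibit directly. Below is a minimal hand-built sketch of such a key-value memorizer, my own illustration rather than code from either post: each hidden ReLU unit is a key that fires for exactly one one-hot input, and its outgoing weights are the value.)

```python
import numpy as np

# Hand-built key-value memorizer (illustrative sketch): hidden unit i is a
# "key" that fires only for one-hot input i; its outgoing weights are the
# "value" that votes for the memorized answer.
n_inputs, n_classes = 100, 10
rng = np.random.default_rng(0)
answers = rng.integers(0, n_classes, size=n_inputs)  # arbitrary lookup table

W_key = np.eye(n_inputs)          # key i exactly matches input i
b_key = -0.5 * np.ones(n_inputs)  # threshold so only an exact match fires
W_val = np.zeros((n_inputs, n_classes))
W_val[np.arange(n_inputs), answers] = 10.0  # large logit on the stored answer

def forward(x_onehot):
    h = np.maximum(W_key @ x_onehot + b_key, 0.0)  # ReLU key matching
    return h @ W_val                               # value readout (logits)

x = np.zeros(n_inputs)
x[42] = 1.0
assert forward(x).argmax() == answers[42]  # recalls the memorized answer
```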
I agree this is straightforward for us to understand, but it’s still a pretty complicated program that has to be implemented by a large number of parameters, and each individual key-value pair is only reinforced by a single (!!) training data point. Given how little these parameters are being reinforced by SGD, it seems plausible that it only gets to “put 0.95 probability on the answer rather than 0.99”.
But more fundamentally, if the memorization circuit was already putting probability ~1 on all the correct answers, then the gradients would be zero and the network would never change (at least in the absence of things like weight decay, and even weight decay often doesn’t change circuits when the rest of the gradients are zero). Clearly, the network is changing, therefore the memorized circuit can’t be placing probability ~1 on the correct answers.
(This might be wrong if there’s some source of gradients other than accuracy / cross-entropy loss. My vague recollection was that grokking happened even when that was the only source of gradients, but I could easily be wrong about that.)
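(To make the gradient argument concrete: for softmax plus cross-entropy, the gradient with respect to the correct-class logit is p_correct - 1, so the memorization circuit's confidence directly sets how much training signal is left. A quick numeric illustration, my own sketch:)

```python
# Gradient of softmax cross-entropy w.r.t. the correct-class logit is
# (p_correct - 1): near-certain memorization leaves almost no signal.
for p_correct in [0.5, 0.95, 0.99, 0.9999]:
    print(f"p = {p_correct:<7} grad on correct logit = {p_correct - 1:+.4f}")
# p = 0.95 leaves 5x more gradient than p = 0.99, and 500x more than
# p = 0.9999, which is the sense in which "0.95 rather than 0.99" matters.
```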
Additionally, it’s not the case that “once it hits upon the correctly generalizing function (or something close enough to it), it very quickly becomes confident in it”. This is an illusion caused by the log scale on the x-axis of the plots.
Yeah, this was perhaps too much of a simplification. It’s more that, at that point in the training process, the gradients are all tiny (since you are mostly predicting the right answer thanks to the memorization). The learning rate might also be significantly lower depending on what learning rate schedule was used; I don’t remember the details there. Given that, the speed of learning is striking to me.
Basically, I’m disagreeing with your statement here:
If the model stumbles upon a single general circuit that solves the entire problem, then you’d expect it to make the switch very quickly.
I don’t think this is true because thanks to the memorization your gradients are very small and you can’t make any switches very quickly.
(It would be interesting to see if, once grokking had clearly started, you could just 100x the learning rate and speed up the convergence to zero validation loss by 100x. That’s more strongly predicted by my story; I think it isn’t predicted by yours, since you require a more sequential process where circuits get combined over time.)
My account directly predicts that stochasticity and weight decay regularization would help with generalization, and even predicts that weight decay would be one of the most effective interventions to improve generalization.
How so? If you’re using the assumption “stochasticity and weight decay probably prefer general circuits over shallow circuits”, I feel like given that assumption my story makes the same prediction.
Finally, if we look at a loss plot on a log scale, we can see that the validation loss starts decreasing at ~step 1.3×10⁴, while the floor on the minimum training loss remains fairly constant (or even increases) until slightly after that (~step 2.3×10⁴). Thus, validation loss starts decreasing thousands of steps before training loss starts decreasing. Whatever is causing the generalization, it’s not doing so to decrease training loss (at least not at first).
Idk what’s going on from steps 1.3×10⁴ to 4×10⁴. I hadn’t seen those spikes before, I might interpret them as “failed attempts” to find the one general circuit. I mostly see my story as explaining what happens from 4×10⁴ onwards. (Notably to me, that’s when the validation accuracy looks like it starts increasing.) I do agree this seems like an important piece of data that’s not predicted by my story.
I think my biggest reason for preferring my story is actually that it is disanalogous to evolution, and it has complex structures emerging from scratch. Grokking is a really weird phenomenon that you don’t usually see in ML systems; I want a weird explanation that wouldn’t apply to all the other ML systems that seemingly don’t display grokking (e.g. language models).
In addition, it seems like a really big clue that grokking was discovered in this very simple abstract mathematical setting, that seems like exactly the sort of setting where there might be “only two solutions” (memorization and the “true” circuit). I think the “shallow circuits get combined into general circuits” probably is how most neural networks work, and leads to nice smooth loss curves with empirically predictable scaling laws, and is very nicely analogous to evolution—and this means that usually you don’t see grokking; grokking only happens when the structure of the environment means that you can’t combine a few shallow circuits into a slightly more general circuit, and instead you have to go directly from “shallow memorization” to “the true answer”.
(That being said, an alternative explanation is “the time at which grokking happens increases exponentially with ‘complexity’ ”, which suggests that the simple abstract mathematical setting is the only one in which we’d have trained models far enough to reach the point of grokking.)
I think the general picture here could also explain why RL often has more “jumpy” curves than language models or image classification—RL is often done in relatively toy environments where there are far fewer features to build circuits out of, relative to language or image data. That being said, “RL has to deal with exploration” is a very plausible alternative explanation for that fact. I do think OpenAI Five and other “big” RL projects had smoother loss curves than more toy RL projects, which supports this picture over “exploration is the problem”. Similarly I think this picture can help explain blessings of scale.
It would be interesting to see if, once grokking had clearly started, you could just 100x the learning rate and speed up the convergence to zero validation loss by 100x.
I ran a quick-and-dirty experiment and it does in fact look like you can just crank up the learning rate at the point where some part of grokking happens to speed up convergence significantly. See the wandb report:
https://wandb.ai/tomfrederik/interpreting_grokking/reports/Increasing-Learning-Rate-at-Grokking--VmlldzoxNTQ2ODY2?accessToken=y3f00qfxot60n709pu8d049wgci69g53pki6pq6khsemnncca1dnmocu7a3d43y8
I set the LR to 5x the normal value (100x tanked the accuracy, 10x still works though). Of course you would want to anneal it after grokking was finished.
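(For concreteness, this kind of mid-training intervention takes only a few lines in a standard PyTorch loop. The sketch below is hypothetical, with a placeholder model and illustrative threshold, not Frederik's actual code:)

```python
import torch

model = torch.nn.Linear(128, 128)  # placeholder for the actual transformer
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

GROKKING_START = 300  # step chosen by eye from the validation curve
LR_MULTIPLIER = 5.0   # 5x worked; 10x still worked, 100x tanked accuracy

def adjust_lr(step):
    # Boost the learning rate once grokking appears to start; you would
    # anneal it back down after grokking finishes.
    if step == GROKKING_START:
        for group in optimizer.param_groups:
            group["lr"] *= LR_MULTIPLIER
```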
Very nice! Thanks for actually running the experiment :)
It’s not clear to me which story this supports since 10x-ing the learning rate only brings the grokking phase down to 8×10⁴ steps, which is still the majority of the training run.
I’m not sure I understand.
I chose the grokking starting point as 300 steps, based on the yellow plot. I’d say it’s reasonable to say that ‘grokking is complete’ by the 2000 step mark in the default setting, whereas it is complete by the 450 step mark in the 10x setting (assuming appropriate LR decay to avoid overshooting). Also note that the plots in the report are not log-scale.
Ah, I just looked at your plots, verified that the grokking indeed still happened with 5x and 10x learning rates, and then just assumed 10x faster convergence in the original plots in the post. Apparently that reasoning was wrong. Presumably you’re using different hyperparameters than the ones used in this post? You seem to have faster grokking in the “default setting” than in the plots shown in the post.
(And it does look like, given some default setting, “10x faster convergence” is basically right, since in your case 10x higher LR makes the grokking stage go from 1700 steps to 150 steps.)
(Partly the issue was that I wasn’t sure whether the x-axis in your plots was starting from the beginning of training, or from the point that grokking started, so I instead reasoned about the impact on the graphs in this post. Though looking at the LR plot it’s now obvious that it’s from the beginning of training.)
I now think this is relatively strong evidence for my view, given that grokking happens pretty quickly (~a third of total training), though it probably is still decently slower than the memorization. (Do you happen to have the training loss curves, so we can estimate how long it takes to memorize under your hyperparameters?)
First, I’d like to note that I don’t see why faster convergence after changing the learning rate supports either story. After initial memorization, the loss decreases by ~3 OOM. Regardless of what’s going on inside the network, it wouldn’t be surprising if raising the learning rate increased convergence.
Also, I think what’s actually going on here is weirder than either of our interpretations. I ran experiments where I kept the learning rate the same for the first 1000 steps, then increased it by 10x and 50x for the rest of the training.
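(Schematically, that schedule can be expressed as a PyTorch LambdaLR, again with placeholder names; `boost` would be 10 or 50:)

```python
import torch

model = torch.nn.Linear(128, 128)  # placeholder model
opt = torch.optim.AdamW(model.parameters(), lr=1e-3)
boost, switch_step = 10.0, 1000

# Hold the base LR for the first 1000 steps, then multiply it by `boost`.
# LambdaLR scales the base LR by the lambda's return value at each step.
sched = torch.optim.lr_scheduler.LambdaLR(
    opt, lambda step: 1.0 if step < switch_step else boost
)
# Call sched.step() once per training step.
```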
Here is the accuracy curve with the default learning rate:
Here is the curve with 10x learning rate:
And here is the curve with 50x learning rate:
Note that increasing the learning rate doesn’t consistently speed up validation convergence. The 50x run does reach convergence faster, but the 10x run doesn’t even reach it at all.
In fact, increasing the learning rate causes the training accuracy to fall to the validation accuracy, after which they begin to increase together (at least for a while). For the 10x increase, the training accuracy quickly diverges from the validation accuracy. In the 50x run, the training and validation accuracies move in tandem throughout the run.
Frederik’s results are broadly similar. If you mouse over the accuracy and loss graphs, you’ll see that:
- Training performance drops significantly immediately after the learning rate increases.
- The losses and accuracies of the “5x” and “10x” lines correlate pretty well between training and validation. In contrast, the losses and accuracies of the “default” lines don’t correlate strongly between training and validation.
I think that increasing the learning rate after memorization causes some sort of “mode shift” in the training process. It goes from:
First, learn shallow patterns that strongly overfit to the training data, then learn general patterns.
to:
Immediately learn general patterns that perform about equally well on the training and validation data.
In the case of my 10x run, I think it actually has two mode transitions, first from “shallow first” to “immediately general”, then another transition back to “shallow first”, and that’s why you see the training accuracy diverge from the validation accuracy again.
I think results like these make a certain amount of sense, given that higher learning rates are associated with better generalization in more standard settings.
Regardless of what’s going on inside the network, it wouldn’t be surprising if raising the learning rate increased convergence.
I’m kinda confused by your perspective on learning rates. I usually think of learning rates as being set to the maximum possible value such that training is still stable. So it would in fact be surprising if you could just 10x them to speed up convergence. (So an additional aspect of my prediction would be that you can’t 10x the learning rate at the beginning of training; if you could then it seems like the hyperparameters were chosen poorly and that should be fixed first.)
Indeed, in your experiments, at the moment you 10x the learning rate, accuracy does in fact plummet! I’m a bit surprised it manages to recover, but you can see that the recovery is not nearly as stable as the original training before increasing the learning rate (this is even more obvious in the 50x case), and notably even the recovery for the training accuracy looks like it takes longer (1000-2000 steps) than the original increase in training accuracy (~400 steps).
I do think this suggests that you can’t in fact “just 10x the learning rate” once grokking starts, which seems like a hit to my story.
I updated the report with the training curves. Under default settings, 100% training accuracy is reached after 500 steps.
The train/val curves actually overlap as they go up. Might be an artifact of the simplicity of the task, or of my not having properly split the dataset (e.g. x+y being in train and y+x being in val). I might run it again for a harder task to verify.
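(For what it's worth, one way to rule out that kind of leak is to split over unordered pairs and only then expand both orders, so x+y and y+x always land on the same side. A sketch, with an assumed modulus:)

```python
import itertools
import random

p = 97  # assumed modulus, for illustration
unordered = list(itertools.combinations_with_replacement(range(p), 2))
random.Random(0).shuffle(unordered)
cut = len(unordered) // 2

# Expand both orders *after* splitting, so (x, y) and (y, x) can never
# end up on opposite sides of the train/val divide.
train = {(x, y) for x, y in unordered[:cut]} | {(y, x) for x, y in unordered[:cut]}
val = {(x, y) for x, y in unordered[cut:]} | {(y, x) for x, y in unordered[cut:]}
assert not (train & val)  # no equation appears on both sides in either order
```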
Huh, intriguing. Yeah, it might be worth running with a non-commutative function and seeing if it holds up—it seems like in the default setting the validation accuracy hits almost 0.5 once the training accuracy is 1, which is about what you’d get if you understood commutativity but nothing else about the function. So the “grokking” part is probably happening after that, i.e. at roughly the 1.5k-step mark in the default setting.
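(A quick sanity check of that ~0.5 figure, under the assumption of a uniformly random 50/50 split over ordered pairs: about half of the validation pairs have their mirror image in train, so a model that knows commutativity and nothing else scores about 50%.)

```python
import itertools
import random

p = 97  # assumed modulus
pairs = [(x, y) for x, y in itertools.product(range(p), repeat=2) if x != y]
random.Random(0).shuffle(pairs)
train = set(pairs[: len(pairs) // 2])
val = pairs[len(pairs) // 2:]

# A model that only knows x+y = y+x gets a validation pair right exactly
# when the mirrored pair appeared in train.
mirror_hits = sum((y, x) in train for x, y in val)
print(mirror_hits / len(val))  # ~0.5
```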
So I ran some experiments for the permutation group S_5 with the task x ∘ y = ?
Interestingly, here increasing the learning rate just never works. I’m very confused.
Also interestingly, in the default setting for these new experiments, grokking happens in ~1000 steps while memorization happens in ~1500 steps, so the grokking is already faster than the memorization, in stark contrast to the graphs in the original post.
(This does depend on when you start the counter for grokking, as there’s a long period of slowly increasing validation accuracy. You could reasonably say grokking took ~2500 steps.)
Oh, I thought figure 1 was S5, but it’s actually modular division. I’ll give that a go...
Here are results for modular division. Not super sure what to make of them. Small increases in learning rate work, but so does just choosing a larger learning rate from the beginning. In fact, increasing the LR to 5x from the beginning works super well, but switching to 5x once grokking arguably starts just destroys any progress. 10x LR from the start does not work (nor when switching later).
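(For reference, a sketch of how the modular division task can be generated; the modulus is a placeholder, and the inverse uses Fermat's little theorem, which is valid since p is prime:)

```python
p = 97  # placeholder prime modulus

def mod_div(a, b):
    # a / b (mod p) = a * b^(p-2) (mod p), by Fermat's little theorem
    return (a * pow(b, p - 2, p)) % p

# Every equation "a / b = c" with b nonzero; the model maps (a, b) to c.
dataset = [(a, b, mod_div(a, b)) for a in range(p) for b in range(1, p)]
assert all((b * c) % p == a for a, b, c in dataset)  # division inverts multiplication
```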
So maybe the initial observation is more a general/global property of the loss landscape for the task and not of the particular region during grokking?
Yeah, that seems right, I think I’m basically at ‘no, you can’t just 10x the learning rate once grokking starts’.
Increasing regularization (weight decay in this instance) might rescue the ones which don’t work.
I tried increasing weight decay and increasing batch sizes, but so far no real success compared to 5x LR. Not going to investigate this further atm.
Yep, I used my own re-implementation, which somehow has slightly different behavior.
I’ll also note that the task in the report is modular addition while figure 1 from the paper (the one with the red and green lines for train/val) is the significantly harder permutation group task.