This is a linkpost for our paper Explaining grokking through circuit efficiency, which provides a general theory explaining when and why grokking (aka delayed generalisation) occurs, and makes several novel predictions which we experimentally confirm (abstract and introduction copied below). You might also enjoy our explainer on X/Twitter.
Abstract
One of the most surprising puzzles in neural network generalisation is grokking: a network with perfect training accuracy but poor generalisation will, upon further training, transition to perfect generalisation. We propose that grokking occurs when the task admits a generalising solution and a memorising solution, where the generalising solution is slower to learn but more efficient, producing larger logits with the same parameter norm. We hypothesise that memorising circuits become more inefficient with larger training datasets while generalising circuits do not, suggesting there is a critical dataset size at which memorisation and generalisation are equally efficient. We make and confirm four novel predictions about grokking, providing significant evidence in favour of our explanation. Most strikingly, we demonstrate two novel and surprising behaviours: ungrokking, in which a network regresses from perfect to low test accuracy, and semi-grokking, in which a network shows delayed generalisation to partial rather than perfect test accuracy.
Introduction
When training a neural network, we expect that once training loss converges to a low value, the network will no longer change much. Power et al. (2021) discovered a phenomenon dubbed grokking that drastically violates this expectation. The network first “memorises” the data, achieving low and stable training loss with poor generalisation, but with further training transitions to perfect generalisation. We are left with the question: why does the network’s test performance improve dramatically upon continued training, having already achieved nearly perfect training performance?
Recent answers to this question vary widely, including the difficulty of representation learning (Liu et al., 2022), the scale of parameters at initialisation (Liu et al., 2023), spikes in loss (“slingshots”) (Thilak et al., 2022), random walks among optimal solutions (Millidge et al., 2022), and the simplicity of the generalising solution (Nanda et al., 2023, Appendix E). In this paper, we argue that the last explanation is correct, by stating a specific theory in this genre, deriving novel predictions from the theory, and confirming the predictions empirically.
We analyse the interplay between the internal mechanisms that the neural network uses to calculate the outputs, which we loosely call “circuits” (Olah et al., 2020). We hypothesise that there are two families of circuits that both achieve good training performance: one which generalises well (Cgen) and one which memorises the training dataset (Cmem). The key insight is that when there are multiple circuits that achieve strong training performance, weight decay prefers circuits with high “efficiency”, that is, circuits that require less parameter norm to produce a given logit value.
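One way to see why weight decay favours efficient circuits is to write the training objective as a loss plus a penalty on parameter norm. This is a sketch that assumes weight decay behaves like an L2 penalty with coefficient λ; the notation is illustrative and is not taken from the paper.

```latex
\mathcal{L}_{\text{total}}(\theta) = \mathcal{L}_{\text{train}}(\theta) + \frac{\lambda}{2}\,\lVert \theta \rVert_2^2
```

If two circuits produce the same logits on the training data, the first term is identical for both, so the total is lower for whichever circuit needs the smaller parameter norm.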
Efficiency answers our question above: if Cgen is more efficient than Cmem, gradient descent can reduce nearly perfect training loss even further by strengthening Cgen while weakening Cmem, which then leads to a transition in test performance. With this understanding, we demonstrate in Section 3 that three key properties are sufficient for grokking: (1) Cgen generalises well while Cmem does not, (2) Cgen is more efficient than Cmem, and (3) Cgen is learned more slowly than Cmem.
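The three properties can be illustrated with a deliberately tiny toy model. This sketch is not the constructed simulation from Section 3 of the paper; every modelling choice in it is an assumption made for illustration: each circuit is a single weight `w` contributing `e * w**2` to the correct-class logit, `e_gen > e_mem` stands in for Cgen being more efficient, a much smaller initial weight stands in for Cgen being learned more slowly, and only Cgen contributes to the test logit.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def simulate(e_gen=2.0, e_mem=1.0, wd=0.05, lr=0.1,
             w_gen0=1e-6, w_mem0=0.1, steps=6000):
    """Gradient descent on logistic loss plus weight decay for two toy circuits.

    All names and default values are hypothetical. Cmem only helps on
    training data; Cgen helps on both training and test data.
    """
    w_gen, w_mem = w_gen0, w_mem0
    train_logits, test_logits, mem_logits = [], [], []
    for _ in range(steps):
        o_gen = e_gen * w_gen ** 2   # logit contribution of Cgen
        o_mem = e_mem * w_mem ** 2   # logit contribution of Cmem
        train_logits.append(o_gen + o_mem)
        test_logits.append(o_gen)    # Cmem is assumed useless off the training set
        mem_logits.append(o_mem)
        # For loss = log(1 + exp(-logit)), -d(loss)/d(logit) = sigmoid(-logit).
        pull = sigmoid(-(o_gen + o_mem))
        # Gradient step on loss + (wd / 2) * w**2 for each weight.
        w_gen += lr * (2.0 * e_gen * w_gen * pull - wd * w_gen)
        w_mem += lr * (2.0 * e_mem * w_mem * pull - wd * w_mem)
    return train_logits, test_logits, mem_logits


def first_step_above(values, threshold):
    """Index of the first entry exceeding threshold, or None if there is none."""
    for step, value in enumerate(values):
        if value > threshold:
            return step
    return None


train, test, mem = simulate()
print("train logit first exceeds 2.0 at step", first_step_above(train, 2.0))
print("test logit first exceeds 2.0 at step", first_step_above(test, 2.0))
```

In this toy, Cmem quickly drives the training logit up, after which the loss gradient is small and weight decay dominates. Because Cgen buys more logit per unit of norm, it keeps growing slowly while Cmem shrinks, so the test logit rises long after the training logit has saturated.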
Since Cgen generalises well, it automatically works for any new data points that are added to the training dataset, and so its efficiency should be independent of the size of the training dataset. In contrast, Cmem must memorise any additional data points added to the training dataset, and so its efficiency should decrease as training dataset size increases. We validate these predictions by quantifying efficiencies for various dataset sizes for both Cmem and Cgen.
This suggests that there exists a crossover point at which Cgen becomes more efficient than Cmem, which we call the critical dataset size Dcrit. By analysing dynamics at Dcrit, we predict and demonstrate two new behaviours (Figure 1). In ungrokking, a model that has successfully grokked returns to poor test accuracy when further trained on a dataset much smaller than Dcrit. In semi-grokking, we choose a dataset size where Cgen and Cmem are similarly efficient, leading to a phase transition but only to middling test accuracy.
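The crossover idea can be sketched as a small helper. It assumes you have measured, at several training dataset sizes, the parameter norm each circuit needs to produce the same fixed logit (lower norm means higher efficiency). The function name, the linear interpolation, and the numbers in the usage example are all hypothetical placeholders, not results or methods from the paper.

```python
def estimate_critical_size(sizes, mem_norms, gen_norms):
    """Estimate the dataset size at which Cmem stops being more efficient than Cgen.

    sizes: dataset sizes in ascending order.
    mem_norms, gen_norms: parameter norm each circuit needs at each size to
        reach the same fixed logit (hypothetical measurements).
    Returns the linearly interpolated crossover size, or None if Cmem needs
    less norm than Cgen at every measured size.
    """
    gaps = [m - g for m, g in zip(mem_norms, gen_norms)]
    if gaps[0] >= 0:
        # Cgen is already at least as efficient at the smallest measured size,
        # so the crossover lies at or below it.
        return float(sizes[0])
    for i in range(1, len(sizes)):
        if gaps[i] >= 0:
            # gaps[i - 1] < 0 <= gaps[i], so the denominator is positive.
            frac = -gaps[i - 1] / (gaps[i] - gaps[i - 1])
            return sizes[i - 1] + frac * (sizes[i] - sizes[i - 1])
    return None


# Made-up placeholder values: Cmem's required norm grows with dataset size,
# Cgen's stays flat.
sizes = [100, 200, 400, 800]
mem_norms = [10.0, 20.0, 40.0, 80.0]
gen_norms = [30.0, 30.0, 30.0, 30.0]
print(estimate_critical_size(sizes, mem_norms, gen_norms))  # → 300.0
```

Under the theory, training well below the estimated size is where ungrokking would be expected, and training close to it is where semi-grokking would be expected.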
We make the following contributions:
We demonstrate the sufficiency of three ingredients for grokking through a constructed simulation (Section 3).
By analysing dynamics at the “critical dataset size” implied by our theory, we predict two novel behaviours: semi-grokking and ungrokking (Section 4).
We confirm our predictions through careful experiments, including demonstrating semi-grokking and ungrokking in practice (Section 5).