Explaining grokking through circuit efficiency


This is a linkpost for our paper Explaining grokking through circuit efficiency, which provides a general theory explaining when and why grokking (aka delayed generalisation) occurs, and makes several interesting and novel predictions which we confirm experimentally (abstract and introduction copied below). You might also enjoy our explainer on X/Twitter.

Abstract

One of the most surprising puzzles in neural network generalisation is grokking: a network with perfect training accuracy but poor generalisation will, upon further training, transition to perfect generalisation. We propose that grokking occurs when the task admits a generalising solution and a memorising solution, where the generalising solution is slower to learn but more efficient, producing larger logits with the same parameter norm. We hypothesise that memorising circuits become more inefficient with larger training datasets while generalising circuits do not, suggesting there is a critical dataset size at which memorisation and generalisation are equally efficient. We make and confirm four novel predictions about grokking, providing significant evidence in favour of our explanation. Most strikingly, we demonstrate two novel and surprising behaviours: ungrokking, in which a network regresses from perfect to low test accuracy, and semi-grokking, in which a network shows delayed generalisation to partial rather than perfect test accuracy.

Introduction

When training a neural network, we expect that once training loss converges to a low value, the network will no longer change much. Power et al. (2021) discovered a phenomenon dubbed grokking that drastically violates this expectation. The network first “memorises” the data, achieving low and stable training loss with poor generalisation, but, with further training, transitions to perfect generalisation. We are left with the question: why does the network’s test performance improve dramatically upon continued training, having already achieved nearly perfect training performance?

Recent answers to this question vary widely, including the difficulty of representation learning (Liu et al., 2022), the scale of parameters at initialisation (Liu et al., 2023), spikes in loss (“slingshots”) (Thilak et al., 2022), random walks among optimal solutions (Millidge et al., 2022), and the simplicity of the generalising solution (Nanda et al., 2023, Appendix E). In this paper, we argue that the last explanation is correct, by stating a specific theory in this genre, deriving novel predictions from the theory, and confirming the predictions empirically.

We analyse the interplay between the internal mechanisms that the neural network uses to calculate the outputs, which we loosely call “circuits” (Olah et al., 2020). We hypothesise that there are two families of circuits that both achieve good training performance: one which generalises well ($C_{\text{gen}}$) and one which memorises the training dataset ($C_{\text{mem}}$). The key insight is that when there are multiple circuits that achieve strong training performance, weight decay prefers circuits with high “efficiency”, that is, circuits that require less parameter norm to produce a given logit value.
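
As a rough shorthand (our notation here, not a definition quoted from the paper), one can think of a circuit's efficiency as the logit it buys per unit of parameter norm:

$$\text{efficiency}(C) \;\approx\; \frac{\ell(C)}{\lVert \theta_C \rVert},$$

where $\theta_C$ denotes the parameters circuit $C$ relies on and $\ell(C)$ the correct-label logit it produces. Under cross-entropy loss with weight decay, a more efficient circuit achieves the same loss at a smaller norm penalty, which is why weight decay favours it.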

Efficiency answers our question above: if $C_{\text{gen}}$ is more efficient than $C_{\text{mem}}$, gradient descent can reduce the already-low training loss even further by strengthening $C_{\text{gen}}$ while weakening $C_{\text{mem}}$, which then leads to a transition in test performance. With this understanding, we demonstrate in Section 3 that three key properties are sufficient for grokking: (1) $C_{\text{gen}}$ generalises well while $C_{\text{mem}}$ does not, (2) $C_{\text{gen}}$ is more efficient than $C_{\text{mem}}$, and (3) $C_{\text{gen}}$ is learned more slowly than $C_{\text{mem}}$.
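
To see how these three ingredients produce a grokking-shaped training curve, here is a minimal toy sketch (our own construction for illustration, not the simulation from Section 3): two scalar "circuit strengths" both fit the training data, the generalising one is more efficient but learned more slowly, and both are subject to weight decay.

```python
import numpy as np

# Toy sketch of the proposed mechanism (illustrative only, not the paper's
# Section 3 simulation). Two scalar "circuit strengths" w_gen and w_mem both
# contribute to the training logit; only the generalising circuit contributes
# to the test logit. Logits scale quadratically with each circuit's strength,
# loosely mimicking multi-layer circuits.
eff_gen, eff_mem = 4.0, 1.0   # C_gen is more efficient: more logit per unit norm
lr_gen, lr_mem = 0.005, 0.1   # ...but C_gen is learned more slowly than C_mem
weight_decay = 0.05
w_gen, w_mem = 0.1, 0.1       # initial circuit strengths

for step in range(10001):
    train_logit = eff_gen * w_gen**2 + eff_mem * w_mem**2  # both circuits fit train
    test_logit = eff_gen * w_gen**2                        # only C_gen generalises
    # cross-entropy on the correct class: loss = log(1 + exp(-logit))
    dloss_dlogit = -1.0 / (1.0 + np.exp(train_logit))
    # gradient descent, with weight decay applied to each circuit's parameters
    w_gen -= lr_gen * (dloss_dlogit * 2 * eff_gen * w_gen + weight_decay * w_gen)
    w_mem -= lr_mem * (dloss_dlogit * 2 * eff_mem * w_mem + weight_decay * w_mem)
    if step % 1000 == 0:
        print(f"step {step:5d}  train_logit {train_logit:6.2f}  "
              f"test_logit {test_logit:6.2f}  w_mem {w_mem:5.2f}  w_gen {w_gen:5.2f}")
```

In this toy run the training logit saturates within the first few hundred steps while the test logit stays small; only a few thousand steps later does weight decay shift parameter norm from w_mem to the more efficient w_gen, at which point the test logit catches up, mirroring the grokking pattern.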

Since $C_{\text{gen}}$ generalises well, it automatically works for any new data points that are added to the training dataset, and so its efficiency should be independent of the size of the training dataset. In contrast, $C_{\text{mem}}$ must memorise any additional data points added to the training dataset, and so its efficiency should decrease as training dataset size increases. We validate these predictions by quantifying efficiencies for various dataset sizes for both $C_{\text{gen}}$ and $C_{\text{mem}}$.

This suggests that there exists a crossover point at which $C_{\text{gen}}$ becomes more efficient than $C_{\text{mem}}$, which we call the critical dataset size $D_{\text{crit}}$. By analysing dynamics at $D_{\text{crit}}$, we predict and demonstrate two new behaviours (Figure 1). In ungrokking, a model that has successfully grokked returns to poor test accuracy when further trained on a dataset much smaller than $D_{\text{crit}}$. In semi-grokking, we choose a dataset size at which $C_{\text{gen}}$ and $C_{\text{mem}}$ are similarly efficient, leading to a phase transition but only to middling test accuracy.
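
In the same rough shorthand as above (ours, not the paper's notation), $C_{\text{gen}}$'s efficiency is roughly constant in the training-set size $D$ while $C_{\text{mem}}$'s falls as $D$ grows, so the critical dataset size is the crossover point

$$\text{efficiency}_{\text{mem}}(D_{\text{crit}}) = \text{efficiency}_{\text{gen}},$$

with memorisation favoured for $D \ll D_{\text{crit}}$ (hence ungrokking when a grokked network is retrained on such a dataset) and the two circuits roughly tied for $D \approx D_{\text{crit}}$ (hence semi-grokking).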

We make the following contributions:

  1. We demonstrate the sufficiency of three ingredients for grokking through a constructed simulation (Section 3).

  2. By analysing dynamics at the “critical dataset size” implied by our theory, we predict two novel behaviours: semi-grokking and ungrokking (Section 4).

  3. We confirm our predictions through careful experiments, including demonstrating semi-grokking and ungrokking in practice (Section 5).