Visible loss landscape basins don’t correspond to distinct algorithms

Thanks to Justis, Arthur Conmy, Neel Nanda, Joseph Miller, and Tilman Räuker for their feedback on a draft.

I feel like many people haven’t noticed an important result from the mechanistic interpretability analysis of grokking, and so haven’t updated how they think about loss landscapes and the algorithms that neural networks end up implementing. I think this has implications for alignment research.

When thinking about grokking, people often imagine something like this: the neural network implements Algorithm 1 (e.g., memorizes the training data), achieves roughly the lowest loss available via memorization, then moves around the bottom of the Algorithm 1 basin and, after a while, stumbles across a path to Algorithm 2 (e.g., the general algorithm for modular addition).
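
For concreteness, here is a minimal sketch of a grokking-style experiment: a small model trained on modular addition, seeing only a fraction of all input pairs, with strong weight decay. This is the kind of setup where the delayed drop in test loss is usually observed. The architecture and hyperparameters below are my own illustrative assumptions, not the exact setup from the papers discussed here, and I haven’t verified that this particular configuration groks.

```python
# Minimal grokking-style setup (illustrative; architecture and hyperparameters are
# assumptions, not the exact ones used in the grokking papers).
import torch
import torch.nn as nn

p = 113                # modulus for modular addition
train_frac = 0.3       # fraction of all (a, b) pairs used for training

# Full dataset of (a, b) -> (a + b) mod p.
pairs = torch.cartesian_prod(torch.arange(p), torch.arange(p))
labels = (pairs[:, 0] + pairs[:, 1]) % p
perm = torch.randperm(len(pairs))
n_train = int(train_frac * len(pairs))
train_idx, test_idx = perm[:n_train], perm[n_train:]

def one_hot(ab):
    """Concatenate one-hot encodings of a and b into a single input vector."""
    return torch.cat([nn.functional.one_hot(ab[:, 0], p),
                      nn.functional.one_hot(ab[:, 1], p)], dim=1).float()

model = nn.Sequential(nn.Linear(2 * p, 256), nn.ReLU(), nn.Linear(256, p))
opt = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1.0)
loss_fn = nn.CrossEntropyLoss()

x_train, y_train = one_hot(pairs[train_idx]), labels[train_idx]
x_test, y_test = one_hot(pairs[test_idx]), labels[test_idx]

for step in range(50_000):
    opt.zero_grad()
    train_loss = loss_fn(model(x_train), y_train)
    train_loss.backward()
    opt.step()
    if step % 1000 == 0:
        with torch.no_grad():
            test_loss = loss_fn(model(x_test), y_test)
        # The grokking signature: train loss falls early, test loss only falls much later.
        print(step, train_loss.item(), test_loss.item())
```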

People have the intuition that the neural network implements Algorithm 1 while moving around the wider basin and starts implementing Algorithm 2 when it stumbles across the narrow basin.

But the mechanistic interpretability analysis of grokking has shown that this is not true!

Approximately from the start of training, Algorithm 1 is most of what the circuits are doing and is what almost entirely determines the neural network’s output. But the entire time the neural network’s parameters visibly move down the wider basin, they don’t just get better at memorization: they increasingly implement both the circuits for Algorithm 1 and the circuits for Algorithm 2, in superposition.

(Neel Nanda et al. have shown that the circuits that end up implementing the general algorithm for modular addition start forming approximately at the start of training: the gradient was mostly an arrow pointing towards memorization, but also, immediately from the initialization of the weights, a bit of an arrow pointing towards the general algorithm. The circuits were tuned gradually throughout training. The noticeable change in the test loss only starts once those circuits are already almost right.)
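
One way to see this concretely is to track a progress measure for the generalizing circuit over training. In the reverse-engineered modular-addition model, the general algorithm is built out of a small number of Fourier components, so a crude proxy (in the spirit of Nanda et al.’s Fourier analysis, but not their exact metrics, and applied here to the toy MLP sketched above rather than their transformer) is how concentrated the spectral power of the input weights is in a few frequencies:

```python
# Crude progress measure for the generalizing circuit (illustrative sketch only):
# the general modular-addition algorithm is Fourier-sparse, so we track how much of
# the spectral power of the first-layer weights sits in the top few frequencies.
import numpy as np

def fourier_concentration(W_in, p, k=8):
    """W_in: (hidden, 2*p) first-layer weight matrix acting on one-hot (a, b) inputs.
    Returns the fraction of spectral power in the top-k frequencies, averaged over
    the `a` half and the `b` half of the inputs. Roughly uniform at initialization;
    it should climb toward 1 as the Fourier circuit forms (assuming this toy model
    generalizes in the same Fourier-based way as the reverse-engineered transformer)."""
    scores = []
    for half in (W_in[:, :p], W_in[:, p:]):
        power = np.abs(np.fft.rfft(half, axis=1)) ** 2   # spectral power per neuron, per frequency
        power_per_freq = power.sum(axis=0)               # total power at each frequency
        scores.append(np.sort(power_per_freq)[-k:].sum() / power_per_freq.sum())
    return float(np.mean(scores))

# Usage on checkpoints saved during the training loop above (hypothetical):
# for step, W in checkpoints:                # W = model[0].weight.detach().numpy()
#     print(step, fourier_concentration(W, p=113))
```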

A path through the loss landscape visible in 3D doesn’t correspond to how and what the neural network is actually learning. Almost all of the change in the loss is due to the increasingly good implementation of Algorithm 1; but apparently, the entire time, the gradient also points towards some faraway implementation of Algorithm 2. Somehow, the direction in which Algorithm 2 lies is also visible to the derivative, and moving the parameters in the direction the gradient points means mostly implementing Algorithm 1 better, while also increasingly implementing the faraway Algorithm 2.

The “grokking” visible in the test loss is due to the change that happens once the parameters already implement Algorithm 2 accurately enough that switching from mostly outputting the results of an implementation of Algorithm 1 to the results of an improving implementation of Algorithm 2 no longer hurts performance. Once that’s the case, the neural network puts more weight on Algorithm 2 and at the same time quickly tunes it to be even more accurate (which gets increasingly easy as the output is increasingly determined by the implementation of Algorithm 2).
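
A cartoon of why the switch shows up so abruptly in the test loss (a toy calculation under made-up assumptions, not a model of the real network): treat the output logits as a weighted mix of a memorizing component, which is useless on held-out inputs, and a generalizing component whose accuracy improves over training. The test loss stays near chance until the generalizing component is both accurate enough and weighted heavily enough, and then drops quickly.

```python
# Toy calculation: test loss of a mixture of a memorizing component (noise on unseen
# inputs) and a generalizing component of varying accuracy and weight. All numbers
# are made up for illustration.
import numpy as np

rng = np.random.default_rng(0)
p, n_test = 113, 1000
y = rng.integers(0, p, size=n_test)                      # true labels on held-out inputs

def test_loss(weight_on_gen, gen_noise):
    mem_logits = rng.normal(size=(n_test, p))            # memorization: pure noise on test data
    gen_logits = 5 * np.eye(p)[y] + gen_noise * rng.normal(size=(n_test, p))
    logits = weight_on_gen * gen_logits + (1 - weight_on_gen) * mem_logits
    logp = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    return -logp[np.arange(n_test), y].mean()             # cross-entropy on test labels

# As Algorithm 2 gets more accurate (lower noise) and more heavily weighted,
# the test loss stays near chance (ln 113 ≈ 4.7) at first, then falls.
for w, noise in [(0.05, 3.0), (0.2, 2.0), (0.5, 1.0), (0.9, 0.3)]:
    print(f"weight on Algorithm 2 = {w:.2f}, noise = {noise:.1f}: test loss = {test_loss(w, noise):.2f}")
```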

This is something many people seem to have missed. I did not expect it to be the case, was surprised, and updated how I think about loss landscapes.

Does this generalize?

Maybe. I’m not sure whether it’s correct to generalize from the mechanistic interpretability analysis of grokking to neural networks in general: real LLMs are under-parameterised while the grokking model is very over-parameterised. But I guess it might be reasonable to expect that this is how deep learning generally works.

People seem to think that the multi-dimensional loss landscapes of neural networks have basins for specific algorithms, and that neural networks end up in these basins depending on how relatively large the basins are, which might be determined by how simple the algorithms are, how path-dependent their implementation is, etc. I think this picture makes a wrong prediction about what happens in grokking.

Maybe there are better ways to think about what’s actually going on. Clearly visible basins correspond to the implementation of the algorithms that currently influence performance the most. But the neural network might be implementing whatever algorithms output predictions that have mutual information with whatever the gradient communicates, and the algorithms that you see are not necessarily the better algorithms that the neural network is already slowly implementing (but not yet relying on heavily).

It might be helpful to imagine a separate basin for each of the two algorithms, reflecting both how much that algorithm reduces the loss and how well the neural network implements it. If you sum the two basins and look at an area of the loss landscape, you’ll mostly only notice the wider basin; but in a high-dimensional space, gradient descent might be going down both at the same time, and the combination might not interfere enough to prevent this from happening. So at the end, you might end up with the optimal algorithm, even if for most of the training you thought you were looking only at a suboptimal one.
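
Here’s a toy numerical version of that picture (purely illustrative assumptions: progress toward each algorithm lives along its own direction in a high-dimensional parameter space, and the two basins simply add). Gradient descent descends both basins at once, even though almost all of the visible loss reduction comes from the wide Algorithm 1 basin, and the projection onto the Algorithm 2 direction keeps creeping up the whole time:

```python
# Toy model of two summed basins in a high-dimensional parameter space
# (illustrative assumptions only, not a claim about real networks).
import numpy as np

rng = np.random.default_rng(0)
d = 1000
Q, _ = np.linalg.qr(rng.normal(size=(d, 2)))   # two random orthonormal directions
u1, u2 = Q[:, 0], Q[:, 1]                      # "Algorithm 1" and "Algorithm 2" directions

def loss_and_grad(theta):
    a1 = u1 @ theta                            # progress toward Algorithm 1
    a2 = u2 @ theta                            # progress toward Algorithm 2
    # Wide, steep basin for Algorithm 1; a much gentler slope toward Algorithm 2's optimum.
    loss = 2.0 * (a1 - 3.0) ** 2 + 0.02 * (a2 - 10.0) ** 2
    grad = 4.0 * (a1 - 3.0) * u1 + 0.04 * (a2 - 10.0) * u2
    return loss, grad, a1, a2

theta = np.zeros(d)
for step in range(2001):
    loss, grad, a1, a2 = loss_and_grad(theta)
    theta -= 0.1 * grad
    if step % 400 == 0:
        # a1 saturates early and accounts for almost all of the visible loss drop;
        # a2 keeps increasing slowly the entire time.
        print(f"step {step:5d}  loss {loss:8.4f}  Alg 1 progress {a1:5.2f}  Alg 2 progress {a2:5.2f}")
```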

Some implications for alignment

If you were imagining that the loss landscape looks like this:

then you might have hoped you could find a way to shape it so that some simple algorithm exhibiting aligned behaviour somehow has a much larger basin, so that you’d likely end up and remain in it even if some less aligned algorithms would achieve a better loss. You might also have hoped to use interpretability tools to understand what’s going on in the neural network and what the algorithm it implements thinks about.

This might not work; speculatively, if misaligned algorithms can be implemented by the neural network and would perform better than the aligned algorithms you were hoping for, the neural network might end up implementing them no matter what the visible basins looked like. Your interpretability tools might not distinguish the gradual implementation of a misaligned algorithm from noise. Something seemingly aligned might be mostly responsible for the outputs; but if there’s an agent with different goals that can achieve a lower loss, its implementation might be slowly building up the whole time. I think this adds another difficulty on top of the classical sharp left turn: you need to deal not only with changes in a specific algorithm whose alignment doesn’t necessarily generalise together with capabilities, but also with the possibility that your neural network is directly implementing a generally capable algorithm that never had alignment properties in the first place. You might end up not noticing that 2% of the activations of a neuron you thought distinguishes cats from dogs are devoted to planning to kill you.

Further research

It might be valuable to explore this more. Do similar dynamics generally occur during training, especially in models that aren’t over-parameterised? If you reverse-engineer and understand the final algorithm, when and how did the neural network start implementing it?
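
One crude way to get at the last question, assuming you have training checkpoints and have identified which parameters the final algorithm’s circuit uses (both assumptions on my part, not a method from the work discussed above): measure how much of each checkpoint’s movement away from initialization already points in the direction of the final circuit.

```python
# Sketch: how aligned is each checkpoint's movement with the final circuit's direction?
# (Hypothetical helper; `checkpoints`, `w_init`, and `w_final` are assumed to be
# flattened weight vectors restricted to the parameters the final circuit uses.)
import numpy as np

def progress_toward_final(checkpoint, w_init, w_final):
    moved = checkpoint - w_init            # how the weights have changed so far
    target = w_final - w_init              # the direction of the final, reverse-engineered circuit
    return float(moved @ target / (np.linalg.norm(moved) * np.linalg.norm(target) + 1e-12))

# for step, w in checkpoints:
#     print(step, progress_toward_final(w, w_init, w_final))
```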