Decomposing independent generalizations in neural networks via Hessian analysis

In our joint SERI MATS project, we came up with a series of equations and experiments to mechanistically understand and steer the generalization behavior of neural nets. The core conceit is to locate the circuits (which we call “modules”) responsible for implementing different generalizations using a toolbox of techniques related to Hessian eigenvectors. This is a general-audience distillation of our work.

We hope most of the ideas and high-level goals are understandable to a non-expert; for most of our experiments, we also attempt to include “supplementary” material with the main mathematical intuitions and concrete equations that would allow someone to reproduce our work. In the coming weeks, we plan to write multiple follow-up distillations and discussions, both of some of the more technical parts of our work and of a few new insights into generalization behavior and phase transitions in general that came out of experiments involving our Hessian toolbox.

Introduction

A central problem for inner alignment is understanding how neural nets generalize off-distribution. For example, a powerful AI agent trained to make people happy can generalize either by choosing actions that deceptively look good to its overseers or by choosing actions that truly align with human values. The same diversity of generalization is already seen in existing real-world tasks, both in minor ways (image classifiers classifying cars by learning to recognize wheels vs. windows) and in serious ways (language models appearing honest by agreeing with the user vs. insisting on consensus opinions).

One approach to steer between generalizations is activation steering, which Nina is investigating as her other SERI MATS project. This aims to encourage the neural net to implement one possible generalization (in this case, honestly reflecting the LLM’s internal world model) instead of the other generalization (in this case, sounding good and correct to a particular user).

While activation steering, supervised finetuning, and RLHF work well in practice and can make systems behave better, there is still a risk that powerful models generalize in unpredictable and potentially undesirable ways on out-of-distribution examples. In particular, for subtle alignment-related questions like deception or power-seeking, activation steering or RLHF may fix the “symptoms” of the problem on examples similar to the training corpus but fail to fix the “underlying cause” and achieve aligned behavior.

A somewhat ambitious alternative way to get at the “root” of a generalization problem instead of fixing its symptoms is to try to access it on a mechanistic level. Namely, imagine that on the level of the “internal architecture” of the neural net (something that is notoriously hard to access but can sometimes be partially interpreted), the two generalizations get executed by at least somewhat independent modules (i.e., parallel circuits; the term comes from this paper). If we were able to identify and cleanly split up these two modules, we might be able to find weight perturbation vectors that destroy (“ablate”) one of them while preserving the other. The resulting method would then be robust in a strong, mechanistic sense: it prevents one of the generalizations (understood mechanistically as the underlying module) from getting executed at any level, thus addressing both the symptom and the underlying cause.

This algorithm for tuning generalizations can be possible only if the underlying mechanistic model (of different independent generalization “modules” which can be consistently found and independently ablated) is correct or partially correct to a relevant approximation. In order to even begin to engage with it, we need answers to the following questions.

Question 1. Do any real-life useful neural nets learn multiple generalization methods simultaneously in a single training run?

Question 2. If yes, are these generalizations independently tunable? More ambitiously, are these generalizations mechanistically well-conceptualized as implemented by distinct “generalization modules” that implement parallel and fully independent algorithms?

Question 3. Can we consistently (i.e., via some universal recipe) find and mechanistically reason about these generalization modules, given some supervised “guesses” about how they should behave on custom-generated inputs? (E.g., prompts designed to elicit very different responses from a sycophantic vs. an honest text prediction technique.) Much more ambitiously, can we check in an “unsupervised” way whether a given neural net is, in fact, implementing multiple parallel generalization modules and then separate and analyze the constituent modules in an interpretable way?

In our work, we demonstrate that the answers to all these questions—including the ambitious versions—are affirmative for appropriate networks. As our main tool for mechanistically accessing and manipulating generalization circuits, we use some techniques inspired by mathematical physics (and involving a study of the Hessian eigenvectors of a model: a mathematical method for finding spaces of “maximally salient” steering vectors). We perform our tests on certain small classifier neural nets trained on carefully selected data sets (chosen to have legibly distinct generalizations). However, our methods are extremely general and potentially applicable to any classifier neural net. In the three subsequent sections of this post, we will explain how we set up our experiments and use our Hessian eigenvector toolbox to answer these three questions one by one.
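
For readers who want a formula: by the loss Hessian we mean the matrix of second derivatives of the training loss with respect to the weights, and its top eigenvectors are the unit directions in weight space along which the loss curves most steeply, which is the sense in which they are “maximally salient”:

$$H_{ij}(\theta) \;=\; \frac{\partial^2 L(\theta)}{\partial \theta_i\, \partial \theta_j}, \qquad v_{\mathrm{top}} \;=\; \arg\max_{\lVert v \rVert = 1} v^{\top} H(\theta)\, v.$$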

While our current results should be understood as a proof of concept, we expect at least some of our methods to be applicable to a much larger class of neural nets encountered “in the wild.” We are excited about these results and believe they are new[1]. In particular, the following funny picture is our proudest achievement:

If you read through to the end of the post, you will know why.

Models with multiple generalizations

In order to test the feasibility of module-steering techniques, we need to work with a model which exhibits an affirmative answer to

Question 1: Do any real-life useful neural nets learn multiple generalization methods simultaneously in a single training run?

Unlike our answers to the other questions, this one concerns a phenomenon that is already known (though somewhat mysterious) to occur in neural networks. Later in this section, we will describe the specific model we designed to exhibit this generalization ambiguity in a particularly strong sense. However, essentially all nontrivial models which have been mechanistically studied turn out to learn multiple generalizations to some extent.

For example, it is known that dog breed classifiers have a snout detector that reacts strongly to the shape of the snout and an ear detector that reacts strongly to the shape of an ear. It would be reasonable to suppose, then, that there are somewhat independent modules in the dog breed classifier that look at the snout or at an ear and decide what breed of dog this part is likely to belong to. Now, almost all pictures of dogs will have both the ear and the snout visible, and it is possible in principle (e.g., for a dog breed classifying savant) to completely determine a dog’s breed from its ear or snout alone. Thus, in a certain asymptotic limit of maximally good feature detectors, the snout and ear classifiers become mutually redundant: either of them by itself is enough to fully classify a dog’s breed.

But in real systems, dog breed classifiers develop both detectors. This is not surprising. CNNs used for this task are relatively small networks, and it is unreasonable to expect a single feature detector to provide enough information to solve the classification problem fully. Though somewhat correlated (and hence partially redundant), breed-related information coming from different features of an image can be combined in useful ways to improve predictive power.

What is somewhat more surprising is that, even when a single generalization is entirely sufficient to solve a classification problem at close to 100% accuracy, networks will still tend to learn additional “redundant” generalizations (so long as the other generalizations have a roughly comparable level of “ease of learnability”). This phenomenon is not fully understood and interfaces both with questions of learning priors (in particular, it is inconsistent with either the speed prior or the “Solomonoff”/simplicity prior being mechanistically preferred by neural nets) and with the phase transitions studied by Singular Learning Theory: we plan to write another post later about the insights that our Hessian eigenvector experiments give into priors and phase transitions in “real-life” small networks. However, for the purposes of this post, it is sufficient to note that this phenomenon has been observed before, both “naturally” in convolutional networks and “synthetically” in sparse networks trained to separate into multiple modules. The latter paper (which we will encounter a few times later in this post) examines how (in the synthetically sparse setting) modules corresponding to different generalizations “vote” to provide the combined network with the final “consensus” classification.

To start off our experiment, we trained our own network on a classification task designed to encourage different generalizations by combining two labeled datasets (one consisting of “digits” and one of “patterns”) in a “redundant” way. We observed that the resulting model indeed learned both underlying classification problems, i.e., it learned a “redundant” classification algorithm for the mixed data (see the paragraphs and diagram below for details).

  • For the first dataset of “pure patterns”, we generated random boolean patterns, labeled in a random fashion by the numbers 0-9, and expanded these coarse patterns into full-size square images.

  • For our second dataset of “pure digits”, we used the standard MNIST dataset of grayscale handwritten digits and trained the same classifier to learn it.

  • We trained our CNN classifier on a “mixed” dataset consisting of (pixel-level) averages of (pattern, digit) pairs with the same label[2]. A minimal code sketch of this mixing step is given just below.
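
A minimal sketch of the mixing step (illustrative only: the pattern resolution, the one-pattern-per-label pairing, and all names here are simplified stand-ins rather than our exact data-generation code):

```python
# Illustrative stand-in for the mixed-dataset construction described above.
import numpy as np

rng = np.random.default_rng(0)
N_CLASSES, IMG, CELLS = 10, 28, 4   # 28x28 images; patterns are coarse 4x4 boolean grids

# One random boolean pattern per label 0-9, expanded so each boolean cell becomes
# a square block of pixels matching the MNIST image size.
patterns = rng.integers(0, 2, size=(N_CLASSES, CELLS, CELLS)).astype(np.float32)
patterns = np.kron(patterns, np.ones((IMG // CELLS, IMG // CELLS), dtype=np.float32))

def mix(digit_image: np.ndarray, label: int) -> np.ndarray:
    """Pixel-level average of an MNIST digit (28x28, values in [0, 1]) and the
    random pattern carrying the same label."""
    return 0.5 * (digit_image + patterns[label])
```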

The CNN classifier (trained only on “mixed” data) naturally reached ~95% accuracy on the “mixed” dataset and retained accuracy of >50% on both the “pure pattern” and the “pure digit” datasets, confirming that both generalizations are learned by default. For our experiments we used a modification of this model with additional pretraining to have high accuracy on both pure datasets as well as on the main, mixed dataset: we did this to get a cleaner signal from the pure datasets when measuring the strength of the two distinct “generalization” circuits. From some quick observations, the behavior of the pretrained model is qualitatively analogous to (but more interpretable than) that of a model trained directly on the mixed dataset.

This mixed MNIST data was the test bed of our “module-finding” experiments, and the loss functions associated with the three datasets (pure numbers, pure patterns, and mixed) provided the inputs for the “supervised” Hessian eigenvector techniques that we developed by playing around with this system.

Supervised steering and identification of mechanistic modules

The main desideratum for our “mechanistic steering” approach is the identification of steerable modules. Specifically, we would like to find a system that exhibits an affirmative answer to the “ambitious” part of our

Question 2. [If a network learns multiple generalizations], are these generalizations independently tunable? More ambitiously, are these generalizations mechanistically well-conceptualized as implemented by distinct “generalization modules” that implement parallel and fully independent algorithms?

It is the process of investigating this question that led us to the main technical tool of our study, namely manipulations involving spaces of large-eigenvalue Hessian eigenvectors. Since making precise the notions of circuits and independence (as well as the formulas in our Hessian toolbox) involves somewhat technical differential-geometric and linear-algebraic ideas, we will defer these details to a future post. Instead, here we will give a rough “intuitive” flavor of how our methods work, with snapshots of a few formulas for the benefit of those interested in the mathematical details.

It is perhaps helpful, as an intuition pump, to start by describing the results of the paper Seeing is Believing: Brain-inspired Modular Training for Mechanistic Interpretability. Its authors (Ziming Liu, Eric Gan & Max Tegmark) trained classifier networks on a task with redundant generalization behavior, but rather than training “classic” neural nets, they defined a synthetic training objective that encouraged the network to have sparse and modular weight matrices: i.e., to decompose into separable, legibly organized, and “approximately boolean” circuits[3]. They noticed that, for problems where the network learned multiple generalizations, the circuits naturally decomposed into modules associated with the different generalizations, which each carried out an independent classification problem and later participated in a “voting” algorithm to output a consensus solution.

In particular, in their “booleanized” model, we see that it is possible to ablate (i.e., break) one of the modules by zeroing out or—essentially equivalently—adding noise to its constituent neurons. If one were to do this, then the total “consensus” model would end up tuned to only express one generalization and, moreover, do so in a robust way (i.e., no part of the ablated module would be “secretly” participating in the classification task). In the messy “native” models like our MNIST models, nice neuronal decompositions essentially never make sense (as we expect relevant features to appear in bases where multiple neurons are linearly combined—or superposed—with each other). However, the picture presented in Liu et al.‘s paper can be potentially recovered on a linear-algebraic level, where we replace the explicitly boolean concept of a subset of neurons that appear in the distinct modules of Liu et al.’s models by a linearly independent set of weight vectors (in fact, the relevant parameter is not the set itself but the space it generates).
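
To make the neuron-level version of ablation concrete, here is a generic sketch (our own illustration, not Liu et al.’s code) of breaking a module by zeroing out a chosen set of hidden units:

```python
# Generic sketch of neuron-level ablation: zero out the activations of a chosen set
# of hidden units via a forward hook, breaking any module that relies on those
# neurons while leaving the rest of the network intact.
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(784, 256), nn.ReLU(), nn.Linear(256, 10))  # toy stand-in
neurons_to_ablate = [3, 17, 42]  # hypothetical indices of the module's constituent neurons

def ablate_hook(module, inputs, output):
    output = output.clone()
    output[:, neurons_to_ablate] = 0.0   # alternatively, add noise instead of zeroing
    return output

handle = model[1].register_forward_hook(ablate_hook)  # hook on the first ReLU's output
logits = model(torch.randn(8, 784))                   # forward pass with the module broken
handle.remove()
```

In superposed “native” networks, the analogous operation acts not on a subset of neurons but on a subspace of weight directions, which is what the rest of this section develops.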

Here we give a cartoon of (on the left) a pair of independent vs. (on the right) non-independent or “overlapping” modules in weight space. Rows correspond to vectors; red rows correspond to vectors critical to (i.e., in the “effective dimensionality” of) module A, and blue rows correspond to vectors critical to module B. Grey rows correspond to “generalization directions”: free parameters whose perturbations do not affect the performance of any relevant circuit. The purple vector on the right is associated with an overlap, i.e., a weight vector involved in an essential way in both circuits. In both pictures, the relevant vectors for the “combined” algorithm are all the colored vectors (this is pictured on the RHS), and the two sides correspond to alternative hypotheses for how the modules combine to form the total algorithm. Note that this picture is a simplified cartoon of a more sophisticated linear-algebraic phenomenon, and in particular, it obscures a relevant distinction between subspaces and quotient spaces. A more mathematically precise picture will be given in our planned follow-up post on the independence of circuits.

We are indeed able to recover the linear-algebraic analogs of modularity: in particular, we identify the two relevant modules in a linear-algebraic sense and verify the independence of the resulting two circuits in our model (at least at the level of their top principal components). We can then use this independence to generate steering vectors that robustly ablate one or the other circuit.[4]

Here the line on the bottom (either “pure numbers” or “pure patterns”) corresponds to loss for the circuit we want to conserve, and the line at the top corresponds to loss for the circuit we want to ablate. Loss for the “combined” algorithm is the middle line: note that it has better performance than the average of the two “pure” algorithms[5].

For completeness, we include corresponding graphs of accuracy, which, as one would expect, reverse this picture.

Note that our independence results imply that the ablations we obtain are particularly “mechanistically clean,” i.e., they ablate all sub-circuits of the “undesirable” algorithm by comparable (in expectation) amounts and should not be expected to leave “secret” components with strange out-of-distribution behavior. But even from the perspective of someone not concerned with mechanistic completeness (for example, someone performing “practical” steering without worrying about deceptive or out-of-distribution phenomena), this result is interesting. While ablations of objects that resemble modules (such as feature detectors) have been constructed in the literature in ad-hoc ways, our Hessian steering approach is, to the best of our knowledge, the first “out of the box” module ablation mechanism that (assuming knowledge of Hessian eigenvectors) works without having to perform any additional interpretability analysis.
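
To give a concrete flavor of what a check “at the level of top principal components” can look like (a simplified toy operationalization; the precise criterion we use is deferred to the follow-up post on independence): stack the top Hessian eigenvectors of the two pure losses into matrices and compute the principal angles between the subspaces they span. Cosines near zero mean the two sets of critical directions are nearly orthogonal, i.e., the modules are (to this approximation) independent.

```python
# Toy illustration of a subspace-overlap check, with random stand-ins for the
# eigenvector matrices. E_A, E_B stack top loss-Hessian eigenvectors as rows;
# the singular values of E_A @ E_B.T are the cosines of the principal angles
# between the two spanned subspaces.
import numpy as np

def subspace_overlap(E_A: np.ndarray, E_B: np.ndarray) -> np.ndarray:
    """E_A, E_B: (k, n) matrices with orthonormal rows. Returns cosines of principal angles."""
    return np.linalg.svd(E_A @ E_B.T, compute_uv=False)

# Random stand-ins for two 20-dimensional eigenvector subspaces of a 1000-parameter model.
rng = np.random.default_rng(0)
Q_A, _ = np.linalg.qr(rng.normal(size=(1000, 20)))   # columns: orthonormal basis A
Q_B, _ = np.linalg.qr(rng.normal(size=(1000, 20)))
print(subspace_overlap(Q_A.T, Q_B.T).max())          # small value -> nearly independent
```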

To steer towards a specific generalization, we perturb the weights of a trained model in a direction that lies in the space of top Hessian eigenvectors for one dataset and generalization directions (orthogonal complement of top Hessian eigenvectors) for another dataset, using the following equation:

$$\theta_{\text{steered}} \;=\; \theta \;+\; \alpha\, P^{\perp}_{\text{keep}}\, v_{\text{ablate}}$$

Here $\alpha$ is a scalar, $P^{\perp}_{\text{keep}}$ is the projection onto the orthogonal complement of the span of the top loss Hessian eigenvectors for the pure dataset you want to maintain performance on, and $v_{\text{ablate}}$ is a top loss Hessian eigenvector for the dataset you want to reduce performance on (corresponding to the circuit you want to ablate).
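
A minimal PyTorch sketch of this update (the names E_keep, v_ablate, and alpha are our notation for the quantities in the equation; here they are filled with random stand-ins so the snippet runs on its own, whereas in a real run they come from the top loss-Hessian eigenvectors of the two pure datasets, computed e.g. as in the Addendum below):

```python
# Minimal sketch of the projected-eigenvector steering update.
import torch
import torch.nn as nn
from torch.nn.utils import parameters_to_vector, vector_to_parameters

model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))   # toy stand-in classifier
n = sum(p.numel() for p in model.parameters())

# Rows of E_keep: top Hessian eigenvectors of the loss we want to PRESERVE.
# v_ablate: a top Hessian eigenvector of the loss we want to REDUCE.
E_keep = torch.linalg.qr(torch.randn(n, 20)).Q.T              # (20, n), orthonormal rows
v_ablate = torch.randn(n)
v_ablate /= v_ablate.norm()

# Project the ablation direction onto the orthogonal complement of span(E_keep) ...
v_proj = v_ablate - E_keep.T @ (E_keep @ v_ablate)
# ... and perturb the weights along it. The scalar alpha is tuned by monitoring
# the losses on both pure datasets, as in the plots above.
alpha = 1.0
theta = parameters_to_vector(model.parameters())
vector_to_parameters(theta + alpha * v_proj, model.parameters())
```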

Supervised vs. unsupervised module decomposition/steering

In the previous section, we assumed for the purposes of steering that we have access not only to the main, “combined” dataset (and the associated loss function) but also to the two “pure” datasets associated with its redundant generalizations (together with their respective loss functions). Indeed, this was the source of our colorful eigenvector decomposition diagram, which let us test for independence and generate steering data “for free.” This provided an affirmative answer to the initial “unambitious” version of our

Question 3. Can we consistently (i.e., via some universal recipe) find and mechanistically reason about these generalization modules, given some supervised “guesses” about how they should behave on custom-generated inputs? (E.g., prompts designed to elicit very different responses from a sycophantic vs. an honest text prediction technique.) Much more ambitiously, can we check in an “unsupervised” way whether a given neural net is, in fact, implementing multiple parallel generalization modules and then separate and analyze the constituent modules in an interpretable way?

In this section, we will sketch out our results on the “much more ambitious” part of this question. First, we note that for our MNIST problem, we have a technique that works reasonably well (but not as well as the fully supervised version) to solve what we call the “semi-supervised” problem: finding modules and finetuning directions when we know one generalization (either the “desired” or the “undesired” one) and want to find the other.

Here we use a simplified version of a “sphere search” algorithm that we describe in more detail below for the fully unsupervised case (we will also give more details in a future post on our unsupervised algorithms). We ablate the “known” generalization (patterns) to access the circuit for the “unknown” generalization (numbers/digits).

However, the most interesting unsupervised question, and the one we will devote the rest of the section to, is the “fully unsupervised” version of our problem. The goal here is to create a recipe for starting with a single “composite” algorithm, without any knowledge about potential generalization behaviors or specifics of the training distribution or “extraneous” mechanistic properties—and then using a Hessian eigenvector-inspired algorithm to obtain an independent generalization module inside of this algorithm.

A priori, this seems too good to be true: it feels almost like we’re trying to get something (detailed information about generalization behavior) out of close to nothing (mathematical properties of the Hessian at our solution and high-level properties of the surrounding basin).

It was the biggest surprise to us in the course of our experimentation that the unsupervised decomposition and steering methods can be made to work (at first we tried setting them up many times in many different contexts and kept failing). We plan to explain in a later installment why this is so much harder than the “supervised” approach: why it is not possible to simply split the Hessian eigenvalues for the “combined” algorithm into “good” and “bad” generalizations, and how we carefully designed a new “sphere search” neural net on top of a very simple and interpretable model (modular addition) to finally solve this task. But we believe that including a full description here would make an already long exposition unreadable. For now, we simply say that the algorithm works by executing a “sphere search” (another neural net built on top of the original one) which “pushes” solutions to move out of the local minimum along certain Hessian eigendirections; eventually, this exerts breaking “pressure” on the various component modules, and at a finite point, we expect (and, in our nice setting, observe) a symmetry-breaking phase transition which forces one of the modules to break and thus become interpretable.

A picture of the unsupervised algorithm working (!?!) on a simple modular addition task. In the left 2x2 figure, the first layers of the model’s modules (corresponding to its distinct Fourier modes) still look recognizably like circles, even after 10 units of deformation along our “sphere search” algorithm. In the right 2x2 figure, after 15 units of deformation, one circuit (the upper right) is finally forced to break, and its associated “mechanistic signature” (in the form of its generalization space) can be extracted. Pushing further will break the offending circuit even more; eventually, when pushed too far away, the network will find another stable configuration.

We note that this unsupervised technique is both the part of our algorithm that we least expect to generalize well to large models[6] and the part that seems most powerfully applicable to alignment if it can be made to work in a general context. As a motivating intuition pump, imagine that you are a megacorporation training a giant neural net that you expect to be an AGI[7], and which you worry may develop a form of inner misalignment similar to deceptive misalignment. Suppose you are confident that early in training, when it is at a level of sophistication similar to GPT-4, it is entirely safe, and that if you can prevent this one particular issue (an alternative generalization executing deceptive misalignment), the GPT-4-level network is guaranteed to turn into a safe AGI. Assume also that you have access to a cheap, guaranteed-effective version of our unsupervised steering algorithm, which furthermore can be applied mid-SGD, outside of a local minimum[8]. In this counterfactual, you can potentially guarantee alignment by constantly running your algorithm to test for the appearance of any new generalization directions and by ablating such directions whenever they begin to appear. This way, as soon as a deceptively aligned generalization algorithm shows up on the radar of the training model, it will get ablated, and the model will get steered back towards using its old, “safe” generalizations.

Note that we don’t actually believe that such a naive and optimistic technique would work directly to achieve inner alignment (we can already imagine about a dozen possible critiques explaining how misalignment might still occur even in our idealized context). However, the intuition that does seem plausible is that, given a fully general, fully unsupervised, and cheaply implementable algorithm to track the generalization directions of a neural net, one may be able to perform oversight tasks of a fundamentally more powerful nature than most existing or proposed forms of oversight to date. If this technique turns out to be part of a larger system of equally powerful, versatile, and cheaply implementable techniques, then together they may provide tools that can genuinely push us towards alignment.

While we have relatively low credence that a family of fully abstract techniques similar to ours would be powerful enough to solve alignment by itself, we have significantly higher credence that our unsupervised decomposition technique can be used as a mechanistic interpretability tool, allowing people to decompose larger models into smaller modules in the context of either alignment oversight or interpretability research. In this sense, it may fit into a “safety via interpretability” niche similar to, and perhaps related to, the hopes of Singular Learning Theorists.

Addendum: complexity of Hessian calculations

We note that, beyond our fondness for nice, small, interpretable models, there is an additional reason why we work with small models in this project. While training a neural net is (very roughly) linear in the number of parameters, finding its Hessian eigenvalues is (very roughly) quadratic. There are various possible speedups and shortcuts that can make this procedure faster, but ultimately we expect that the quadratic cost is not prohibitive if one looks at large models not by considering all their weights at once but by examining more manageable subsystems one at a time, either obtained from the architecture (e.g., layers of an LLM) or through other mechanistic techniques such as probing or feature detection algorithms that produce candidates for “important” subspaces of the space of weight deformations. If there is reason to expect that one such subspace contains enough information to be decomposable into modules, we expect some of our methods (particularly the supervised ones) to retain much of their functionality. However, this is something we have not yet tried and hope to see implemented in the future.
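
For concreteness, here is a generic sketch (not necessarily the exact implementation we used) of one standard way to compute top Hessian eigenvectors without ever materializing the Hessian: each Hessian-vector product costs roughly two backward passes (Pearlmutter’s trick), and a Lanczos-style eigensolver only ever needs such products.

```python
# Matrix-free top-k loss-Hessian eigenpairs. Assumes a CPU model and a callable
# loss_fn(model) returning a scalar loss on a fixed batch or dataset.
import numpy as np
import torch
from scipy.sparse.linalg import LinearOperator, eigsh

def top_hessian_eigenpairs(model, loss_fn, k=10):
    params = [p for p in model.parameters() if p.requires_grad]
    n = sum(p.numel() for p in params)

    def hvp(v_np):
        v = torch.from_numpy(np.asarray(v_np, dtype=np.float32).ravel())
        loss = loss_fn(model)
        grads = torch.autograd.grad(loss, params, create_graph=True)
        flat_grad = torch.cat([g.reshape(-1) for g in grads])
        hv = torch.autograd.grad(flat_grad @ v, params)       # Hessian-vector product
        return torch.cat([h.reshape(-1) for h in hv]).detach().numpy()

    op = LinearOperator((n, n), matvec=hvp, dtype=np.float32)
    eigvals, eigvecs = eigsh(op, k=k, which="LA")              # k largest eigenvalues
    return eigvals, eigvecs                                    # columns of eigvecs are eigenvectors
```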

This work was produced as part of the SERI ML Alignment Theory Scholars Program—Summer 2023 Cohort, as a joint project by the two authors. We would like to thank Evan Hubinger, whose stream we are in, for useful conversations and suggestions for exposition. Thanks also to John Wentworth for advice on the project and for pointing us at an efficient method for computing Hessian eigenvectors that made this work possible.

  1. ^

    Note that an affirmative answer to Question 1 has been known for a while. Ad hoc methods for mechanistically interpreting generalizations (as in Question 2) have been found in the past, for example in the context of feature detectors (see, e.g., The Building Blocks of Interpretability), but no general-purpose method has existed before to the best of our knowledge, and no previous work has demonstrated the modularity (i.e., mechanistic independence) of the circuits. Our answer to the “unsupervised” part of Question 3 appears to be entirely new.

  2. ^

    Note that this is very similar to the modified MNIST dataset in the paper Diversify and Disambiguate: Out-of-distribution robustness via disagreement, though our version is grayscale and our patterns are more coarse-grained.

  3. ^

    Note that the neural nets trained in loc. cit. (Liu et al.) would provide highly interpretable, highly steerable, and potentially directly alignable models if they were competitive with the “native,” illegible, and highly linearly superposed networks that occur in current applications. At the moment, the extreme reduction in efficiency resulting from the modified training algorithm in loc. cit. makes these networks useful as experimental test beds rather than as a direct “safety-ensuring” AI tool. In a certain sense, our entire project can be conceptualized as a verification that the “boolean” network phenomena observed by Liu et al. continue to hold, in an appropriately modified linear-algebraic sense, in the “native” (and highly non-boolean/superposed) neural network architectures that occur in modern applications.

  4. ^

    There are some technicalities that complicate the picture here (that will be explained in our follow-up post on independence). In particular, the “robustness” of the ablation means that not only is the module itself ablated, but every circuit that constitutes it is itself ablated with high probability.

  5. ^

    Note that this is in itself an interesting phenomenon, and one can hypothesize that it captures the intuition that having two partially redundant algorithms to solve a task is less than twice as useful as having a single high-accuracy algorithm to do so. In other words, it represents a “ghost” or “echo” of the simplicity/Solomonoff prior, which otherwise fails totally in our models; we will discuss this phenomenon more in our posts on priors and SLT.

  6. ^

    Though we expect it to generalize to intermediate models like CNN image classifiers: in particular, we see some promising, though not definitive, results for our MNIST model.

  7. ^

    Kids, don’t try this at home. Seriously: even in the improbable worlds where our hacky speculative safeguard does actually work to prevent one form of catastrophic misalignment, there are dozens of others that you will need other, more sophisticated, and not yet extant algorithms to prevent.

  8. ^

    This can potentially be achieved without too much additional complexity by waving a math magic wand and saying words along the lines of “restrict your root parameters to the orthogonal complement of the gradient and throw out negative eigenvalues of the restricted Hessian.” Note that we haven’t tested whether such a mid-SGD procedure ever works, but while we are in the middle of a speculative world where all our ideas work perfectly, and all our desiderata are met, we might as well ask for a flying pony.