Sparse Autoencoders Find Highly Interpretable Directions in Language Models
This is a linkpost for Sparse Autoencoders Find Highly Interpretable Directions in Language Models
We use a scalable and unsupervised method called sparse autoencoders to find interpretable, monosemantic features in real LLMs (Pythia-70M/410M), in both the residual stream and MLPs. We showcase monosemantic features, feature replacement for Indirect Object Identification (IOI), and use OpenAI’s automatic interpretation protocol to demonstrate a significant improvement in interpretability.
Sparse Autoencoders & Superposition
To reverse engineer a neural network, we’d like to first break it down into smaller units (features) that can be analysed in isolation. Using individual neurons as these units can be useful, but neurons are often polysemantic, activating for several unrelated types of feature, so just looking at neurons is insufficient. Also, for some types of network activations, like the residual stream of a transformer, there is little reason to expect features to align with the neuron basis, so we don’t even have a good place to start.
Toy Models of Superposition investigates why polysemanticity might arise and hypothesises that it may result from models learning more distinct features than there are dimensions in the layer, taking advantage of the fact that features are sparse: each one is active only a small proportion of the time. This suggests that we may be able to recover the network’s features by finding a set of directions in activation space such that each activation vector can be reconstructed from a sparse linear combination of these directions.
We attempt to reconstruct these hypothesised network features by training sparse autoencoders on model activation vectors. We use a sparsity penalty on the hidden-layer activations (the embedding) and tied weights between the encoder and decoder, training each model on 10M to 50M activation vectors. For more detail on the methods used, see the paper.
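The sketch below illustrates the forward pass and loss of such an autoencoder for a single activation vector. It assumes, as a simplification for illustration, a ReLU encoder c = ReLU(Mx + b), a tied decoder x̂ = Mᵀc (the same matrix M, transposed), squared reconstruction error, and an L1 penalty on c weighted by a hypothetical coefficient `alpha`. The training loop over the 10M to 50M activation vectors and any normalisation of the dictionary rows are omitted. It is plain Python so it runs without dependencies; it is not the authors' implementation.

```python
def relu(v):
    """Elementwise ReLU on a list of floats."""
    return [max(0.0, x) for x in v]


def encode(M, b, x):
    """Hidden code c = ReLU(M x + b); M has one row per dictionary feature."""
    return relu([sum(m_j * x_j for m_j, x_j in zip(row, x)) + b_i
                 for row, b_i in zip(M, b)])


def decode(M, c):
    """Tied weights: x_hat = M^T c, i.e. a c-weighted sum of the dictionary rows."""
    d = len(M[0])
    return [sum(c_i * row[j] for c_i, row in zip(c, M)) for j in range(d)]


def sae_loss(M, b, x, alpha):
    """Squared reconstruction error plus an L1 sparsity penalty on the code."""
    c = encode(M, b, x)
    x_hat = decode(M, c)
    recon = sum((x_j - xh_j) ** 2 for x_j, xh_j in zip(x, x_hat))
    sparsity = sum(abs(c_i) for c_i in c)
    return recon + alpha * sparsity, c, x_hat


# Three dictionary directions in a 2-D activation space: an overcomplete dictionary.
M = [[1.0, 0.0], [0.0, 1.0], [-0.7071, -0.7071]]
b = [0.0, 0.0, 0.0]
loss, code, x_hat = sae_loss(M, b, [1.0, 0.0], alpha=0.1)
```

Because the dictionary has more rows than the activation space has dimensions, it matches the superposition picture above; an input aligned with the first row is reconstructed exactly from a single active feature, so the whole loss comes from the sparsity term.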
To analyse our features, we use the same automatic interpretation technique that OpenAI used to interpret the neurons in GPT-2, and we apply it to alternative decomposition methods as baselines. This was demonstrated in a previous post, but we now extend these results across all 6 layers in Pythia-70M, showing a clear improvement over all baselines in all but the final layers. Case studies later in the paper suggest that the features are still meaningful in these later layers, but that automatic interpretation struggles to perform well there.
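In OpenAI's protocol, one language model writes an explanation of a feature, a second model simulates the feature's activations from that explanation alone, and the explanation is scored by how well the simulated activations match the real ones. The minimal sketch below shows only that final scoring step, assuming a Pearson-correlation score; the explanation and simulation steps (both LLM calls) are not shown, and the activation lists are hypothetical inputs.

```python
import math


def correlation_score(true_acts, simulated_acts):
    """Pearson correlation between real and simulated activations (1.0 = perfect)."""
    n = len(true_acts)
    mean_t = sum(true_acts) / n
    mean_s = sum(simulated_acts) / n
    cov = sum((t - mean_t) * (s - mean_s) for t, s in zip(true_acts, simulated_acts))
    var_t = sum((t - mean_t) ** 2 for t in true_acts)
    var_s = sum((s - mean_s) ** 2 for s in simulated_acts)
    if var_t == 0 or var_s == 0:
        return 0.0  # a constant sequence carries no signal to correlate with
    return cov / math.sqrt(var_t * var_s)
```

Correlation ignores the overall scale of the simulated activations, so an explanation is rewarded for predicting *where* a feature fires rather than its exact magnitude.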
IOI Feature Identification
We are able to use less-than-rank-one ablations to precisely edit activations to restore uncorrupted behaviour on the IOI task. With normal activation patching, patches occur at a module-wide level, while here we perform interventions of the form
$$\bar{x} \leftarrow \bar{x} + \sum_{i \in F} (c_i - \bar{c}_i) f_i$$
where $\bar{x}$ is the embedding of the corrupted datapoint, $F$ is the set of patched features, and $c_i$ and $\bar{c}_i$ are the activations of feature $f_i$ on the clean and corrupted datapoint respectively.
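The feature-level patch can be written directly from that formula. This is a minimal sketch assuming the feature directions f_i and the clean and corrupted feature activations have already been computed by the autoencoder; the function name and argument layout are hypothetical.

```python
def patch_features(x_corrupt, c_clean, c_corrupt, features, patched):
    """Apply x_bar + sum_{i in F} (c_i - c_bar_i) f_i.

    x_corrupt: activation vector of the corrupted datapoint (x_bar)
    c_clean, c_corrupt: feature activations on the clean / corrupted datapoint
    features: dictionary directions f_i, one list per feature
    patched: the set F of feature indices to patch
    """
    out = list(x_corrupt)
    for i in patched:
        delta = c_clean[i] - c_corrupt[i]
        # Shift the activation along f_i only, leaving all other directions untouched.
        for j, f_ij in enumerate(features[i]):
            out[j] += delta * f_ij
    return out
```

Patching the empty set leaves the corrupted activation unchanged, and patching every feature moves it by the full difference between the clean and corrupted reconstructions, so the size of F controls how fine-grained the edit is.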
We show that our features are better able to precisely restore the clean behaviour than other activation decomposition methods (like PCA), and moreover that the fine-grainedness of our edits increases with dictionary sparsity. Unfortunately, as our autoencoders are not able to perfectly reconstruct the data, they have a positive minimum KL-divergence from the base model, while PCA does not.
Dictionary Features are Highly Monosemantic & Causal
(Left) Histogram of activations for a specific dictionary feature. The majority of activations are for apostrophes (in blue); the y-axis is the number of datapoints that activate in that bin. (Right) Histogram of the drop in logits (i.e. how much the LLM predicts a specific token) when ablating this dictionary feature direction.
This is in contrast to the residual stream basis, which appears highly polysemantic (i.e. has many semantic meanings). More examples can be found in Appendix E. We’ve also found many context-neurons (e.g. [medical/Biology/Stack Exchange/German]-context), some of which were shown in a previous post; these are an existence proof against the concern that this method only finds token-level features.
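The logit-drop histogram above comes from ablating a feature and measuring how the logits change. A minimal sketch, assuming "ablating" means removing the feature's contribution c_i f_i from the activation vector, and using a toy linear readout as a hypothetical stand-in for the rest of the model (in the real experiment the remaining layers of the LLM are run):

```python
def ablate_feature(x, c_i, f_i):
    """Remove one dictionary feature's contribution c_i * f_i from activation x."""
    return [x_j - c_i * f_j for x_j, f_j in zip(x, f_i)]


def linear_readout(W_U):
    """Toy stand-in for the downstream model: logits = W_U v (hypothetical)."""
    return lambda v: [sum(w * v_j for w, v_j in zip(row, v)) for row in W_U]


def logit_drop(x, c_i, f_i, readout):
    """Per-token drop in logits caused by ablating the feature (positive = less likely)."""
    before = readout(x)
    after = readout(ablate_feature(x, c_i, f_i))
    return [b - a for b, a in zip(before, after)]


readout = linear_readout([[1.0, 0.0], [0.0, 1.0]])
drops = logit_drop([2.0, 1.0], 2.0, [1.0, 0.0], readout)
```

With a purely linear readout the drop is simply c_i W_U f_i, so only tokens whose unembedding aligns with f_i are affected; a monosemantic feature should show its drops concentrated on a small, related set of tokens.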
Automatic Circuit Discovery
The previous section covered a dictionary feature’s relationship to the input tokens and its effect on the logits. We can also look at the relationships between features themselves.
Layer 5 is the last layer in Pythia-70M, and this feature directly unembeds into various forms of the closing parenthesis. We can view the previous layers as calculating “What are all the reasons one might predict a closing parenthesis?”.
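One way to ask "what are all the reasons one might predict a closing parenthesis?" is to ablate each active feature in an earlier layer and measure how much the target feature's activation drops. The sketch below assumes that ablation-based attribution; the post does not spell out the exact procedure, and `downstream_act` is a hypothetical callable standing in for running the intervening layers and encoding with the later layer's dictionary.

```python
def rank_upstream_features(x_up, codes_up, dict_up, downstream_act):
    """Rank upstream features by how much ablating each one lowers a downstream feature.

    x_up: activation vector at the upstream layer
    codes_up: dictionary-feature activations c_i for x_up
    dict_up: upstream dictionary directions f_i (one list per feature)
    downstream_act: hypothetical callable mapping an upstream activation vector to
        the downstream feature's activation
    """
    base = downstream_act(x_up)
    effects = []
    for i, (c_i, f_i) in enumerate(zip(codes_up, dict_up)):
        if c_i == 0:
            continue  # an inactive feature contributes nothing to remove
        ablated = [x_j - c_i * f_j for x_j, f_j in zip(x_up, f_i)]
        effects.append((i, base - downstream_act(ablated)))
    # Largest drop first: these are the strongest "reasons" for the downstream feature.
    return sorted(effects, key=lambda pair: pair[1], reverse=True)


# Toy downstream feature that reads mostly from upstream feature 0.
ranking = rank_upstream_features(
    [1.0, 1.0], [1.0, 1.0], [[1.0, 0.0], [0.0, 1.0]],
    lambda v: max(0.0, 2 * v[0] + v[1]),
)
```

Repeating this layer by layer, starting from a final-layer feature, yields a tree of upstream features that feed it.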
Sparse autoencoders are a scalable, unsupervised approach to disentangling language model network features from superposition. We have demonstrated that the dictionary features they learn are more interpretable as measured by automatic interpretation, are better for performing precise model steering, and are more monosemantic than comparable methods.
The ability to find these dictionary features gives us a new, fully unsupervised tool to investigate model behaviour, allows us to make targeted edits, and can be trained using a manageable amount of computing power.
An ambitious dream in the field of interpretability is enumerative safety: the ability to understand the full set of computations that a model applies. If this were achieved, it could allow us to create models for which we have strong guarantees that the model is not able to perform certain dangerous actions, such as deception or advanced bioengineering. While this is still remote, dictionary learning hopefully marks a small step towards making it possible.
In summary, sparse autoencoders bring a new tool to the interpretability and editing of language models, which we hope others can build upon. The potential for innovations and applications is vast, and we’re excited to see what happens next.
Bonus Section: Did We Find All the Features?
In general, training gives us a reconstruction loss, and if that loss is 0, then we’ve perfectly reconstructed, e.g., layer 4 with our sparse autoencoder. But what does a reconstruction loss of 0.01 mean compared to 0.0001?
We can ground this out in the difference in perplexity (a measure of prediction loss) on some dataset, which better measures functional equivalence (i.e. whether the original and reconstructed models have the same loss on the same data). We have non-released, preliminary results for this with GPT-2 (small), reconstructing layer 4, on a subset of OpenWebText.
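Perplexity is the exponential of the mean per-token negative log-likelihood of the correct next tokens. The minimal sketch below assumes we already have those per-token log-probabilities from two runs, one of the unmodified model and one with the chosen layer replaced by its autoencoder reconstruction; producing them requires running the model and is not shown, and the function names are hypothetical.

```python
import math


def perplexity(token_logprobs):
    """Perplexity = exp(mean negative log-likelihood of the correct next tokens)."""
    return math.exp(-sum(token_logprobs) / len(token_logprobs))


def perplexity_difference(orig_logprobs, recon_logprobs):
    """How much worse the model predicts when a layer is swapped for its reconstruction."""
    return perplexity(recon_logprobs) - perplexity(orig_logprobs)
```

A model that assigns probability 0.5 to every correct token has perplexity 2; dropping that to 0.25 doubles perplexity to 4, a difference of 2. Unlike a raw reconstruction loss, this number is directly comparable across layers and dictionary sizes.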
A difference in perplexity of 2.6 when training directly on KL-divergence is quite small, especially for 4 months of effort among 3 main researchers. The two possibilities are:
1. People better at maths/ML/sparse dictionary learning than us can get it to a ~0 perplexity difference.
2. A subset of features aren’t linearly represented.
If (2) is the case, then we’ll have a dataset of datapoints that aren’t linearly represented, which we can study! This would show that superposition only explains a subset of features, and would provide concrete counterexamples to the linear part of the hypothesis.
We would like to give two big caveats, though:
1. We don’t have a perfect monosemanticity metric, so even with 0 reconstruction loss we couldn’t claim that each feature is monosemantic, although a lower sparsity loss is partial evidence for that.
2. What if every 1000 features decreases the remaining reconstruction loss by half, so that we’re really infinitely many features away from perfect reconstruction?
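To make the second caveat concrete, the hypothetical halving model can be written as a formula, where $L_0$ (the starting loss) and $n$ (the number of dictionary features) are our own illustrative notation rather than quantities from the paper. Under this model the loss stays positive for every finite dictionary size, and each further factor of 100 reduction, such as going from 0.01 to 0.0001, costs about 6,644 more features:

```latex
L(n) = L_0 \cdot 2^{-n/1000}, \qquad L(n) > 0 \ \text{for all finite } n,
\qquad
\frac{L(n)}{L(n + \Delta n)} = 2^{\Delta n / 1000} = 100
\iff \Delta n = 1000 \log_2 100 \approx 6644
```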
Come Work With Us
We are currently discussing research in the #unsupervised-interp channel (under Interpretability) in the EleutherAI Discord server. If you’re a researcher and have directions you’d like to apply sparse autoencoders to, feel free to message Logan on Discord (loganriggs) or LW, and we can chat!
For specific questions on sections (we’re all on discord as well):
1. Hoagy: autointerp & MLP results
2. Aidan: IOI Feature Identification
3. Logan: Monosemantic features & Auto-circuits
KL-divergence is calculated by getting the original LLM’s output, then reconstructing e.g. layer 4 w/ the autoencoder to get a different output, then finding the KL-div between these two outputs. In practice, we found training on KL-div & reconstruction (and sparsity) to converge to lower perplexity.
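The KL-divergence described above compares two next-token distributions for the same input: one from the original model and one with the layer replaced by its reconstruction. The minimal sketch below assumes the direction KL(P_original ‖ P_reconstructed), takes raw logits as hypothetical inputs, and works in log space for numerical stability; when it is used as a training signal, it is added to the reconstruction and sparsity terms.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax of a list of logits."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_sum_exp for z in logits]


def kl_divergence(orig_logits, recon_logits):
    """KL(P_orig || P_recon) between the two next-token distributions (assumed direction)."""
    log_p = log_softmax(orig_logits)
    log_q = log_softmax(recon_logits)
    return sum(math.exp(lp) * (lp - lq) for lp, lq in zip(log_p, log_q))
```

Identical outputs give a KL of 0, and the value grows as the reconstructed model shifts probability away from the tokens the original model favoured, which is why it tracks the perplexity difference more directly than the raw reconstruction loss does.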
These datapoints can be found by finding datapoints with the highest perplexity-difference.
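Finding those datapoints amounts to ranking by per-datapoint perplexity gap. A minimal sketch, assuming each datapoint's mean per-token negative log-likelihood has already been computed both with and without the reconstructed layer; the function and argument names are hypothetical.

```python
import math


def top_perplexity_gaps(orig_nll, recon_nll, k):
    """Indices of the k datapoints whose perplexity rises most under reconstruction.

    orig_nll[i], recon_nll[i]: mean per-token NLL of datapoint i for the original
    model and for the model with the layer replaced by its reconstruction.
    """
    gaps = [math.exp(r) - math.exp(o) for o, r in zip(orig_nll, recon_nll)]
    return sorted(range(len(gaps)), key=lambda i: gaps[i], reverse=True)[:k]
```

The returned indices are candidates for inspection; a large gap alone does not prove a feature is non-linearly represented, since the dictionary might simply be too small.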