Some open-source dictionaries and dictionary learning infrastructure

As more people begin work on interpretability projects which incorporate dictionary learning, it will be valuable to have high-quality dictionaries publicly available.[1] To get the ball rolling on this, my collaborator (Aaron Mueller) and I are:

  • open-sourcing a number of sparse autoencoder dictionaries trained on Pythia-70m MLPs

  • releasing our repository for training these dictionaries[2].

Let’s discuss the dictionaries first, and then the repo.

The dictionaries

[EDIT 02/07/2024: Better dictionaries are now available at the repo. Also, the originally reported MSE loss numbers were wrong and have been updated in the tables below. (The correct numbers were much lower, i.e. better.)]

The dictionaries can be downloaded from here. See the sections “Downloading our open-source dictionaries” and “Using trained dictionaries” here for information about how to download and use them. If you use these dictionaries in a published paper, we ask that you mention us in the acknowledgements.

We’re releasing two sets of dictionaries for EleutherAI’s 6-layer pythia-70m-deduped model. The dictionaries in both sets were trained on 512-dimensional MLP output activations (not the MLP hidden layer like Anthropic used), using ~800M tokens from The Pile.

  • The first set, called 0_8192, consists of dictionaries of size 8192. These were trained with an L1 penalty of 1e-3.

  • The second set, called 1_32768, consists of dictionaries of size 32768. These were trained with an L1 penalty of 3e-3.
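To make the setup concrete, here is an illustrative numpy sketch of a sparse autoencoder with the 0_8192 set's dimensions and L1 coefficient. The actual training code is in the linked repo (and uses PyTorch); all names here are our own, and the loss formula is a standard MSE-plus-L1 objective rather than a guaranteed match to the repo's exact normalization.

```python
import numpy as np

# Dimensions and sparsity coefficient from the 0_8192 set described above.
ACT_DIM, DICT_SIZE, L1_COEF = 512, 8192, 1e-3

rng = np.random.default_rng(0)
W_enc = rng.normal(0.0, 0.01, (DICT_SIZE, ACT_DIM))  # untied encoder weights
b_enc = np.zeros(DICT_SIZE)
W_dec = rng.normal(0.0, 0.01, (ACT_DIM, DICT_SIZE))
b_dec = np.zeros(ACT_DIM)
W_dec /= np.linalg.norm(W_dec, axis=0, keepdims=True)  # unit-norm decoder cols

def encode(x):
    """ReLU encoder: MLP output activations -> sparse feature activations."""
    return np.maximum(0.0, x @ W_enc.T + b_enc)

def decode(f):
    """Linear decoder: feature activations -> reconstructed activations."""
    return f @ W_dec.T + b_dec

def loss(x):
    """Reconstruction MSE plus an L1 sparsity penalty on the features."""
    f = encode(x)
    x_hat = decode(f)
    mse = np.mean((x - x_hat) ** 2)
    l1 = np.mean(np.abs(f).sum(axis=-1))
    return mse + L1_COEF * l1

batch = rng.normal(0.0, 1.0, (4, ACT_DIM))  # stand-in MLP output activations
total = loss(batch)
```

The ReLU guarantees nonnegative feature activations, and renormalizing the decoder columns prevents the L1 penalty from being gamed by shrinking features while growing decoder vectors.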

Here are some statistics. (See our repo’s readme for more info on what these statistics mean.)

For dictionaries in the 0_8192 set:

Layer   MSE Loss   L1 Loss   L0        % Alive   % Loss Recovered
0       0.003      6.132     9.951     0.998     0.984
1       0.008      6.677     44.739    0.887     0.924
2       0.011      11.44     62.156    0.587     0.867
3       0.018      23.773    175.303   0.588     0.902
4       0.022      27.084    174.07    0.806     0.927
5       0.032      47.126    235.05    0.672     0.972

For dictionaries in the 1_32768 set:

Layer   MSE Loss   L1 Loss   L0        % Alive   % Loss Recovered
0       0.0018     4.32      2.873     0.174     0.946
1       0.017      2.798     11.256    0.159     0.768
2       0.023      6.151     16.381    0.118     0.724
3       0.044      11.571    39.863    0.226     0.765
4       0.048      13.665    29.235    0.19      0.816
5       0.069      26.4      43.846    0.13      0.931
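For readers who want to reproduce statistics like these for their own dictionaries, here is a hedged numpy sketch. The repo's readme gives the authoritative definitions; the formulas and names below are our own guesses at the natural versions of each metric.

```python
import numpy as np

def dictionary_stats(x, x_hat, f):
    """Rough sketch of the tables' statistics (our formulas, not necessarily
    the repo's exact definitions).

    x:     (n_tokens, act_dim) original activations
    x_hat: (n_tokens, act_dim) dictionary reconstructions
    f:     (n_tokens, dict_size) feature activations
    """
    mse_loss = float(np.mean((x - x_hat) ** 2))        # reconstruction error
    l1_loss = float(np.mean(np.abs(f).sum(axis=-1)))   # mean L1 norm of features
    l0 = float(np.mean((f != 0).sum(axis=-1)))         # mean # active features
    pct_alive = float(np.mean((f != 0).any(axis=0)))   # fired at least once
    # "% loss recovered" additionally requires running the model with x_hat
    # patched in place of x, so it is omitted from this sketch.
    return mse_loss, l1_loss, l0, pct_alive

x = np.ones((2, 3))
x_hat = x.copy()                            # perfect reconstruction
f = np.array([[1.0, 0.0], [3.0, 0.0]])      # feature 0 fires twice, feature 1 never
mse, l1, l0, alive = dictionary_stats(x, x_hat, f)
```

Note that "% alive" depends on how many tokens you measure over; in practice it is computed over a large evaluation batch.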

And here are some histograms of feature frequencies.

Overall, I’d guess that these dictionaries are decent, but not amazing.

We trained these dictionaries because we wanted to work on a downstream application of dictionary learning, but lacked the dictionaries. These dictionaries are more than good enough to get us off the ground on our mainline project, but I expect that in not too long we’ll come back to train some better dictionaries (which we’ll also open source). I think the same is true for other folks: these dictionaries should be sufficient to get started on projects that require dictionaries; and when better dictionaries are available later, you can swap them in for optimal results.

Some miscellaneous notes about these dictionaries (you can find more in the repo).

  • The later layer dictionaries in 0_8192 have too-high L0s. However, looking at the feature frequency histograms, it looks like this might be because of a spike in high-frequency features. Without this spike, the L0s would be much more reasonable, and features outside of this spike look pretty decent (see here for more).

    • We speculate with very low confidence that these spikes might be an artifact of our timing for resampling dead neurons. We resample every 30000 steps, including at step 90000 out of 100000 total steps. The resampled features tend to be very high-frequency, and it might take more than 10000 steps for the peak to move to the left.

  • The L1 penalty for 1_32768 seems to have been too large; only 10-20% of the neurons are alive, and the loss recovered is much worse. That said, we’ll remark that after examining features from both sets of dictionaries, the dictionaries from the 1_32768 set seem to have more interpretable features than those from the 0_8192 set (though it’s hard to tell).

    • In particular, we suspect that for 0_8192, the many high-frequency features in the later layers are uninterpretable but help significantly with reconstructing activations, resulting in deceptively good-looking statistics.

  • As we progress through the layers, the dictionaries tend to get worse along most metrics (except for % loss recovered). This may have to do with the growing scale of the activations themselves as one moves through the layers of pythia models (h/t to Arthur Conmy for raising this hypothesis).

  • We note that our dictionary features are significantly higher frequency overall than the features in Anthropic’s and Neel Nanda’s. We don’t know if this difference is because we are working with a multi-layer model or if it is because of a difference in hyperparameters. We generally suspect it would be better if we were learning features of lower frequency.

    • We’ll note, however, that after layer 0, it doesn’t seem like many of our features are of the form “always fire on a particular token,” whereas many of Anthropic’s features were. So it’s possible that more interesting features also tend to be higher-frequency. See here for some flavor.
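The per-feature firing frequencies behind comparisons like the one above are simple to compute: a feature's frequency is the fraction of tokens on which its activation is nonzero. This sketch is our own; binning the alive features in log10-frequency (as in the typical histograms) is a presentation choice, not necessarily the post's.

```python
import numpy as np

def feature_frequencies(f):
    """f: (n_tokens, dict_size) feature activations -> (dict_size,) array of
    firing frequencies, i.e. fraction of tokens where each feature is > 0."""
    return (f > 0).mean(axis=0)

rng = np.random.default_rng(0)
# Toy stand-in for SAE features: rectified Gaussians, so most entries are zero.
f = np.maximum(0.0, rng.normal(-1.0, 1.0, (1000, 64)))
freqs = feature_frequencies(f)
alive = freqs[freqs > 0]
counts, edges = np.histogram(np.log10(alive), bins=20)  # log-frequency histogram
```

Dead features (frequency exactly zero) must be dropped before taking the log, which is why histograms of this kind usually report the dead fraction separately.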

The dictionary learning repository

Again, this can be found here. We followed the approach detailed in Anthropic’s paper (including using untied encoder/decoder weights, constraining the decoder vectors to have unit norm, and resampling dead neurons according to their wacky scheme), except for the following:

  • We didn’t have the space to store activations for our entire dataset, so – following Neel Nanda’s replication – we maintain a buffer of tokens from a few thousand contexts and randomly sample from this buffer until it’s half-empty (at which point we refresh it with tokens from new contexts).

  • We used a brief linear learning rate warm-up to fix a problem where Adam would kill too many of our neurons in the first few training steps, before the Adam parameters had a chance to calibrate.
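As a toy illustration of the two bullets above, here is a sketch of the half-empty-refresh buffer and the linear warm-up. The class and function names are our own, and real code stores model activations from new contexts rather than placeholder items.

```python
import random

class ActivationBuffer:
    """Toy version of the buffer described above: sample items at random,
    refreshing with fresh contexts once the buffer is half-empty."""

    def __init__(self, capacity, new_activation_fn, seed=0):
        self.capacity = capacity
        self.new_activation_fn = new_activation_fn  # stands in for running the
        self.rng = random.Random(seed)              # model on new contexts
        self.items = []
        self._refresh()

    def _refresh(self):
        while len(self.items) < self.capacity:
            self.items.append(self.new_activation_fn())
        self.rng.shuffle(self.items)  # popping from the end ~ random sampling

    def sample(self):
        if len(self.items) <= self.capacity // 2:  # half-empty -> refresh
            self._refresh()
        return self.items.pop()

def warmed_up_lr(step, base_lr=1e-4, warmup_steps=1000):
    """Brief linear warm-up before the base learning rate takes over."""
    return base_lr * min(1.0, (step + 1) / warmup_steps)
```

Shuffling on refresh means consecutive training batches mix tokens from many contexts, which is the point of the buffer: it approximates sampling from the whole dataset without storing all activations at once.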

(A brief plug: this repository is built using nnsight, a new interpretability tooling library (like transformer_lens and baukit) being developed by Jaden Fiotto-Kaufman and others in the Bau lab. nnsight is still under development, so I only recommend trying to dive into it now if you’re okay with occasional bugs, memory leaks, etc. (which you can report in the feedback channel of this Discord server). But I’m overall very excited about the project – aside from providing a very clean user experience, one major design goal is that nnsight code is highly portable: you should ideally be able to prototype an experiment with Pythia-70m, switch seamlessly to running it on LLaMA-2-70B split across multiple GPUs, and then ship your code to Anthropic to be run on Claude.)

In addition to the mainline functionality, our repo also supports some experimental features, which we briefly investigated as alternative approaches to training dictionaries:

  • MLP stretchers. Based on the perspective that one may be able to identify features with “neurons in a sufficiently large model,” we experimented with training “autoencoders” which, given an MLP input activation x, output MLP(x) (the MLP output). For instance, given an MLP which maps a 512-dimensional input x to a 1024-dimensional hidden state and then a 512-dimensional output y, we train a dictionary with a hidden dimension larger than the MLP’s so that the dictionary’s output is close to y (and, as usual, so that the hidden state of the dictionary is sparse).

    • The resulting dictionaries seemed decent, but we decided not to pursue the idea further.

    • (h/t to Max Li for this suggestion.)

  • Replacing L1 loss with entropy. Based on the ideas in this post, we experimented with using entropy to regularize a dictionary’s hidden state instead of L1 loss. This seemed to cause the features to either be dead features (which never fired) or very high-frequency features which fired on nearly every input, which was not the desired behavior. But plausibly there is a way to make this work better.
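One plausible way to implement the entropy regularizer from the second bullet is to normalize each sample's nonnegative feature activations into a probability distribution and penalize its Shannon entropy. The linked post's exact formulation may differ; this function is our own sketch.

```python
import numpy as np

def entropy_penalty(f, eps=1e-8):
    """Mean Shannon entropy of each sample's feature activations, treated as a
    probability distribution. A sparser code concentrates its mass on fewer
    features, has lower entropy, and so incurs a smaller penalty.
    f: (batch, dict_size) nonnegative feature activations."""
    p = f / (f.sum(axis=-1, keepdims=True) + eps)
    return float(np.mean(-np.sum(p * np.log(p + eps), axis=-1)))

sparse_code = np.array([[10.0, 0.0, 0.0, 0.0]])  # one dominant feature
dense_code = np.array([[2.5, 2.5, 2.5, 2.5]])    # uniformly active
```

On these toy codes the penalty for `sparse_code` is near zero while the penalty for `dense_code` is log 4, so minimizing it pushes toward sparse codes in principle – though, as noted above, in practice it tended to drive features to be either dead or very high-frequency.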

If you want to pursue one of the ideas in the above bullet points, I ask that you get in touch with me (Sam) once you have preliminary results – I may be interested in discussing results or collaborating.

  1. ^

    This is both for the sake of reproducibility, and because each dictionary takes some effort to train.

  2. ^

    Of course, the repository from the Cunningham et al. paper is also available here.