SAE reconstruction errors are (empirically) pathological

Summary

Sparse Autoencoder (SAE) errors are empirically pathological: when a reconstructed activation vector $\mathrm{SAE}(x)$ is a distance ε from the original activation vector $x$, substituting a randomly chosen point at the same distance ε changes the next-token prediction probabilities significantly less than substituting the SAE reconstruction[1] (as measured by both KL divergence and loss). This is true for all layers of the model (~2x to ~4.5x increase in KL and loss over baseline) and is not caused by feature suppression/shrinkage. Assuming these results replicate, they suggest the proxy reconstruction objective is behaving pathologically. I am not sure why these errors occur, but I expect that understanding this gap will give us deeper insight into SAEs while also providing an additional metric to guide methodological progress.

Introduction

As the interpretability community allocates more resources to and increases its reliance on SAEs, it is important to understand the limitations and potential flaws of this method.

SAEs are designed to find a sparse, overcomplete feature basis for a model's latent space. This is done by jointly minimizing the reconstruction error of the input activations and the L1 norm of the intermediate feature activations (to promote sparsity).
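In its standard form (a generic sketch; exact parameterizations, e.g. bias terms and normalization, vary between implementations), the SAE and its training objective look like:

$$f(x) = \mathrm{ReLU}(W_{\mathrm{enc}}\, x + b_{\mathrm{enc}}), \qquad \hat{x} = W_{\mathrm{dec}}\, f(x) + b_{\mathrm{dec}}, \qquad \mathcal{L}(x) = \underbrace{\lVert x - \hat{x} \rVert_2^2}_{\text{reconstruction}} + \underbrace{\lambda\, \lVert f(x) \rVert_1}_{\text{sparsity}}.$$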

However, the true goal is to find a faithful feature decomposition that accurately captures the true causal variables in the model; reconstruction error and sparsity are only easy-to-optimize proxy objectives. This raises the questions: how good a proxy is this objective? Do the reconstructed representations faithfully preserve other model behavior? How much are we proxy gaming?

Naively, this training objective defines faithfulness as L2 distance. But another natural property of a “faithful” reconstruction is that substituting the original activation with the reconstruction should approximately preserve the next-token prediction probabilities. More formally, for a sequence of tokens $t$ and a model $M$, let $P$ be the model’s true next-token probabilities. Then let $Q_{\mathrm{SAE}}$ be the next-token probabilities after intervening on the model by replacing a particular activation $x$ (e.g. a residual stream state or a layer of MLP activations) with its SAE reconstruction $\hat{x} = \mathrm{SAE}(x)$. The more faithful the reconstruction, the lower the KL divergence between $P$ and $Q_{\mathrm{SAE}}$ (denoted $D_{\mathrm{KL}}(P \,\|\, Q_{\mathrm{SAE}})$) should be.
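For concreteness, here is a rough sketch of how this substitution KL can be measured (a simplified illustration, not the exact code used in these experiments; the `sae` callable stands in for whichever SAE is being evaluated, and the layer 6 residual stream hook is just an example):

```python
import torch
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("gpt2")            # GPT2-small
HOOK_NAME = "blocks.6.hook_resid_pre"                        # example: layer 6 residual stream


def kl_with_substitution(tokens: torch.Tensor, substitute_fn) -> torch.Tensor:
    """Per-token KL(P || Q), where P uses the clean activations and Q uses
    substitute_fn(activation) patched in at HOOK_NAME."""
    clean_logits = model(tokens)

    def patch_hook(acts, hook):                              # acts: [batch, pos, d_model]
        return substitute_fn(acts)

    patched_logits = model.run_with_hooks(tokens, fwd_hooks=[(HOOK_NAME, patch_hook)])

    log_p = clean_logits.log_softmax(dim=-1)
    log_q = patched_logits.log_softmax(dim=-1)
    return (log_p.exp() * (log_p - log_q)).sum(dim=-1)       # [batch, pos]


# Example usage with a stand-in `sae` reconstruction function:
# kl_sae = kl_with_substitution(tokens, lambda acts: sae(acts))
```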

In this post, I study how $D_{\mathrm{KL}}(P \,\|\, Q_{\mathrm{SAE}})$ compares to several natural baselines based on random perturbations of the activation vectors which preserve some error property of the SAE reconstruction (e.g., having the same reconstruction error norm or the same cosine similarity). I find that the KL divergence is significantly higher (2.2x-4.5x) for the residual stream SAE reconstruction compared to the random perturbations, and moderately higher (0.9x-1.7x) for attention-out SAEs. This suggests that the SAE reconstruction is not faithful by this definition, as it does not preserve the next-token prediction probabilities.

This observation is important because it suggests that SAEs make systematic, rather than random, errors and that continuing to drive down reconstruction error may not actually increase SAE faithfulness. This potentially indicates that current SAEs are missing out on important parts of the learned representations of the model. The good news is that this KL-gap presents a clear target for methodological improvement and a new metric for evaluating SAEs. I intend to explore this in future work.

Intuition: how big a deal is this (KL) difference?

For some intuition, here are several real examples of the top-25 output token probabilities at the end of a prompt when patching in the SAE and ε-random reconstructions, compared to the original model’s next-token distribution (note the use of log-probabilities and the KL values in the legend).

For additional intuition on KL divergence, see this excellent post.

Experiments and Results

I conduct most of my experiments on Joseph Bloom’s GPT2-small residual stream SAEs with 32x expansion factor on 2 million tokens (16k sequences of length 128). I also replicate the basic results on these Attention SAEs.

My code can be found in this branch of a fork of Joseph’s library.

Intervention Types

To evaluate the faithfulness of the SAE reconstruction, I consider several types of interventions. Assume that $x$ is the original activation vector and $\hat{x} = \mathrm{SAE}(x)$ is its SAE reconstruction.

  • ε-random substitution: $x'$ is a random vector with $\lVert x' - x \rVert_2 = \lVert \hat{x} - x \rVert_2 = \epsilon$. I.e., both $x'$ and $\hat{x}$ lie on the same ε-ball around $x$.

  • θ-random substitution: $x'$ is a random vector with $\cos(x', x) = \cos(\hat{x}, x)$. I consider both versions where the norm of $x'$ is adjusted to be $\lVert x \rVert_2$ and $\lVert \hat{x} \rVert_2$.

  • SAE-norm substitution: this is the same as the original activation vector $x$ except the norm is rescaled to the SAE reconstruction norm $\lVert \hat{x} \rVert_2$. This is a baseline to isolate the effect of the norm change from the rest of the SAE reconstruction error, a known pathology identified here.

  • norm-corrected SAE substitution: this is the same as $\hat{x}$ except the norm is rescaled to the true norm $\lVert x \rVert_2$. Similar motivation as above. (All of these constructions are sketched in code below.)
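For illustration, here is a rough sketch of these constructions (simplified PyTorch, not the exact experimental code; `x` is the original activation and `x_hat` the SAE reconstruction):

```python
import torch


def epsilon_random(x: torch.Tensor, x_hat: torch.Tensor) -> torch.Tensor:
    """epsilon-random: a random vector at the same L2 distance from x as x_hat."""
    eps = (x_hat - x).norm(dim=-1, keepdim=True)      # SAE reconstruction error norm
    direction = torch.randn_like(x)
    direction = direction / direction.norm(dim=-1, keepdim=True)
    return x + eps * direction


def theta_random(x: torch.Tensor, x_hat: torch.Tensor, use_sae_norm: bool = False) -> torch.Tensor:
    """theta-random: a random vector with the same cosine similarity to x as x_hat."""
    x_unit = x / x.norm(dim=-1, keepdim=True)
    cos = (x_unit * x_hat).sum(dim=-1, keepdim=True) / x_hat.norm(dim=-1, keepdim=True)
    r = torch.randn_like(x)                           # random direction orthogonal to x
    r = r - (r * x_unit).sum(dim=-1, keepdim=True) * x_unit
    r = r / r.norm(dim=-1, keepdim=True)
    out_unit = cos * x_unit + (1 - cos**2).clamp(min=0).sqrt() * r
    norm = (x_hat if use_sae_norm else x).norm(dim=-1, keepdim=True)
    return norm * out_unit


def sae_norm_substitution(x: torch.Tensor, x_hat: torch.Tensor) -> torch.Tensor:
    """Original direction, rescaled to the (typically smaller) SAE reconstruction norm."""
    return x * x_hat.norm(dim=-1, keepdim=True) / x.norm(dim=-1, keepdim=True)


def norm_corrected_sae(x: torch.Tensor, x_hat: torch.Tensor) -> torch.Tensor:
    """SAE reconstruction direction, rescaled to the true activation norm."""
    return x_hat * x.norm(dim=-1, keepdim=True) / x_hat.norm(dim=-1, keepdim=True)
```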

In addition to these different kinds of perturbations, I also consider applying the perturbations to (1) all tokens in the context or (2) just a single token. This tests the hypothesis that the pathology is caused by compounding, correlated errors (since the ε-random substitution errors are uncorrelated across tokens).

Here are the average KL differences (across 2M tokens) for each intervention when intervening on all tokens in the context:

There are 3 clusters of error magnitudes:

  • The SAE and norm-corrected SAE substitutions both have high KL, with the norm-corrected version slightly higher (this makes sense because it has a higher L2 reconstruction error).

  • The ε-random and both variants of the θ-random substitution have much lower, but non-trivial, KL compared to the SAE reconstruction. They are all about the same because random vectors in a high-dimensional space are almost-surely almost-orthogonal, so the ε-random and θ-random perturbations have similar effects (see the quick numeric check below).

  • Most importantly, the SAE-norm substitution has almost zero KL divergence. This is important because it shows that the gap is not caused by the smaller norm (a known problem with SAEs) but by the direction of the error.
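As a quick sanity check on the near-orthogonality claim (a toy calculation for intuition, not part of the experiments):

```python
import torch

d = 768                                         # GPT2-small residual stream width
u = torch.randn(10_000, d)
v = torch.randn(10_000, d)
cos = torch.nn.functional.cosine_similarity(u, v, dim=-1)
print(cos.abs().mean())                         # ~0.03: random pairs are nearly orthogonal
```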

Given these observations, in the rest of the post I mostly focus on the ε-random substitution as the most natural baseline.

Layerwise Intervention Results in More Detail

Next, I consider distributional statistics to get a better sense for how the errors are distributed and how this distribution varies between layers.

This is a histogram of the KL differences for all layers under the ε-random substitution and the SAE reconstruction (and since I clamp the tails at 1.0 for legibility, I also report the 99.9th percentile). Again the substitution happens for all tokens in the context (and again for a single layer at a time). Note the log scale.

Observe the whole distribution is shifted, rather than a few outliers driving the mean increase.

Here is the same plot, but instead of KL divergence I plot the cross-entropy loss difference (reporting the mean instead of the 99.9th percentile). While KL measures deviation from the original distribution, the loss difference measures the degradation in the model’s ability to predict the true next token.
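Written out in the earlier notation, for the true next token $t^{*}$ the per-token loss difference is

$$\Delta \mathrm{CE} = -\log Q_{\mathrm{SAE}}(t^{*}) - \big(-\log P(t^{*})\big),$$

i.e. how many extra nats of loss the substitution adds on the token the model was actually supposed to predict.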

Just as with KL, the mean loss increase of the SAE substitution is 2-4x higher compared to the ε-random baseline.

Finally, here is a breakdown of the KL differences by position in the context.

Single Token Intervention Results

One possibility is that the KL divergence gap is driven by compounding errors which are correlated in the SAE substitutions but uncorrelated in the baselines (since the noise is isotropic). To test this, I consider the KL divergence when applying the substitution to a single token in the context.

In this experiment I intervene on token 32 in the context and measure the KL divergence for the next 16 tokens (averaged across 16,000 contexts). As before, there is a clear gap between the SAE and ε-random substitution, and this gap persists through the following tokens (although the magnitude of the effect depends on how early the layer is).

For clarity, here is the KL bar chart for just token 32 and the following token 33.

While the KL divergence of all interventions is lower overall for the single-token intervention, the SAE substitution KL gap is preserved: it is still always >2x higher than the ε-random substitution KL for the intervened token and the following token (except for token 33 at layer 11).

How pathological are the errors?

To get additional intuition on how pathological the SAE errors are, I randomly sample many ε-random vectors for the same token and compare the KL divergence of the SAE substitution to the distribution of ε-random substitution KLs.

Each subplot below depicts the KL divergence distribution for 500 ε-random vectors and the KL of the true SAE substitution, for a single token at position 48 in the context. The substitution is performed only for this token, on the layer 6 residual stream. Note the number of standard deviations from the ε-random mean labeled in the legend.
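For illustration, each subplot’s statistic can be computed roughly as follows (a sketch reusing the `kl_with_substitution` and `epsilon_random` helpers and the stand-in `sae` and `tokens` from the earlier sketches; it measures the KL at the intervened position itself, and is naive/slow but hopefully clear):

```python
import torch

POS = 48                       # intervene only at this position (layer 6 residual stream)


def substitute_at(pos: int, fn):
    """Wrap fn so the substitution applies at a single position only."""
    def sub(acts):             # acts: [batch, pos, d_model]
        out = acts.clone()
        out[:, pos] = fn(acts[:, pos])
        return out
    return sub


# KL of the actual SAE substitution at this token.
kl_sae = kl_with_substitution(tokens, substitute_at(POS, lambda x: sae(x)))[0, POS]

# KL distribution over 500 epsilon-random substitutions at the same token.
kl_rand = torch.stack([
    kl_with_substitution(tokens, substitute_at(POS, lambda x: epsilon_random(x, sae(x))))[0, POS]
    for _ in range(500)
])
z = (kl_sae - kl_rand.mean()) / kl_rand.std()   # standard deviations from the eps-random mean
```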

What I take from this plot is that the gap has pretty high variance. It is not the case that every SAE substitution is kind-of-bad; rather, there are both many SAE reconstructions that are around the ε-random expectation and many reconstructions that are very bad.

When do these errors happen?

Is there some pattern in when the KL gap is large? Previously I showed there is some relationship with absolute position in the context. As expected, there is also a relationship with reconstruction cosine similarity (a larger error creates a larger gap, all else equal). Because SAE L0 is correlated with reconstruction cosine similarity, there is also a small correlation with the number of active features.

However, the strongest correlations I could find were with respect to the KL gap of other layers.

This suggests that some tokens are consistently more difficult for SAEs to faithfully represent. What are these tokens? Here are the top 20 by average KL gap for layer 6 (restricted to tokens that occur at least 5 times):

Beyond there not being an obvious pattern, notice that the variance is quite high. I take this to mean the representational failures are more contextual. While these tokens seem relatively rare, there is no overall correlation between token frequency and KL gap.

For additional analysis on reconstruction failures, see this post.

Replication with Attention SAEs

Finally, I run a basic replication on SAEs trained on the concatenated z-vectors of the attention heads of GPT2-small.

While there is still a KL gap between the SAE and ε-random substitution, it is smaller (0.9x-1.7x) than for the residual stream SAEs, and a larger fraction of the difference is due to the norm change (though this depends on the layer). This was expected, since substituting the output of a single layer is a much smaller change than substituting the entire residual stream. Specifically, a residual stream SAE tries to reconstruct the sum of all previous layer outputs, and therefore replacing it in effect replaces the entire history of the model, in contrast to updating just a single layer output.
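For concreteness (my notation, standard GPT2-small dimensions), the input to an attention SAE at layer $\ell$ is the concatenation of all 12 heads’ z-vectors (the per-head outputs before the output projection): $z^{\ell} = [\,z^{\ell}_{1}; \dots; z^{\ell}_{12}\,] \in \mathbb{R}^{12 \cdot 64} = \mathbb{R}^{768}$.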

Concluding Thoughts

Why is this happening?

I am not sure yet! My very basic exploratory analysis did not turn up anything obvious. Here are a few hypotheses:

  • ε-random is a bad baseline because activation space is not isotropic (or for some other reason I do not understand), and this result is not actually that unexpected or interesting. Consider a hypothetical 1000-dim activation space where activations mostly lie in a 500-dim subspace and the model largely ignores the other 500 dimensions (e.g. for robustness?). Then the random perturbation gets spread across all 1000 dimensions, so the perturbation within the effective activation subspace is smaller, leading to a smaller KL.
  • Some features are dense (or groupwise dense, i.e., they frequently co-occur). Due to the L1 penalty, some of these dense features are not represented. However, for KL it ends up being better to noisily represent all the features than to accurately represent only some fraction of them. For examples, consider:

    • [Dense] The position embedding matrix in GPT2-small has effective rank ~20. Therefore, to accurately reconstruct absolute context position, you already need ~20 active features! (A quick way to check this is sketched after this list.)

    • [Groupwise dense] Feature manifold families, where a continuous feature is represented by a finite discretization (e.g., curve detectors, where multiple activate at once), or hierarchical feature families like space and time. E.g., an in_central_park feature is potentially a subtype of in_new_york_city, in_new_york_state, in_usa, etc., which might all activate at once; similarly for a date-time, which would require activating features at many temporal scales (in addition to adjacent points on the manifold). Perhaps such features are also learned in combinations.

  • Some features are fundamentally nonlinear, and the SAE is having difficulty reconstructing these.

  • Training FLOPs: perhaps these SAEs are undertrained. One test for this is checking if the KL gap gets better or worse with more training.

  • Training recipe: each of these SAEs is trained to have approximately the same average L0 (average number of active features) and has the same fan-out width. In practice, I expect the number of active features and the number of in-principle representable features to vary throughout network depth. The large variability in KL gap by layer is suggestive.

  • Quirk of GPT2: while I tested two different families of SAEs trained by two different groups, they were both trained on GPT2! I’d guess that other models behave differently (e.g. something like this might matter).
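As an aside on the [Dense] hypothesis above, here is a rough way to check the "rank ~20" claim about the position embeddings (a sketch that interprets "rank" as the number of directions capturing almost all of the variance):

```python
import torch
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("gpt2")        # GPT2-small
W_pos = model.W_pos.detach().float()                     # [n_ctx, d_model] = [1024, 768]

s = torch.linalg.svdvals(W_pos)                          # singular values
var_explained = (s**2) / (s**2).sum()
k = int((var_explained.cumsum(0) < 0.99).sum()) + 1      # directions for ~99% of the variance
print(f"~{k} directions capture 99% of the positional variance")
```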

Takeaways

Assuming these findings replicate for other SAEs (please try to replicate on your own models!):

  • SAEs empirically make non-random, pathological errors. In particular, SAE reconstructions consistently lie on an unusually bad part of the ε-ball around the original activation, as measured by KL divergence and absolute loss increase.

  • Both the SAE substitution KL and the ε-random substitution KL divergence should be standard SAE evaluation metrics for measuring the faithfulness of the reconstruction.

    • Conceptually, loss recovered seems a worse metric than KL divergence. Faithful reconstructions should preserve all token probabilities, but loss only compares the probabilities for the true next token[2].

  • Closing the gap between the SAE and ε-random substitution KL divergence is a promising direction for future work (one way to formalize this gap is sketched below).
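One way to make the proposed metric concrete (a sketch of one possible formalization, in the notation above): for each SAE, report both $D_{\mathrm{KL}}(P \,\|\, Q_{\mathrm{SAE}})$ and the matched baseline $\mathbb{E}_{x'}\big[D_{\mathrm{KL}}(P \,\|\, Q_{x'})\big]$ over ε-random substitutions $x'$, and track their difference (or ratio), i.e. the KL gap, which a faithful SAE should drive toward zero (or one, respectively).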

Future work

I intend to continue working in this direction. The three main work streams are:

  • Check that this replicates in other SAEs and that the ε-random baseline is actually sensible.

  • Do a more thorough analysis of when and why these errors occur and test some of the aforementioned hypotheses.

  • Develop SAE training methods which close the KL-gap.

Acknowledgements

I would like to thank Carl Guo, Janice Yang, Joseph Bloom, and Neel Nanda for feedback on this post. I am also grateful to be supported by an Openphil early career grant.

  1. ^

    That is, substituting the SAE reconstruction $\hat{x} = \mathrm{SAE}(x)$ for the original activation vector $x$ changes the model prediction much more than substituting a random vector $x'$ with $\lVert x' - x \rVert_2 = \lVert \hat{x} - x \rVert_2$.

  2. ^

    E.g., consider the case where both the original model and the SAE-substituted model place the same probability on the correct token, but their other top token probabilities are all different. Loss recovered will imply that the reconstruction is perfect when it is actually quite bad.