Normalizing Sparse Autoencoders

TL;DR

Sparse autoencoders (SAEs) present a promising direction toward automating mechanistic interpretability, but they are not without flaws. One known issue with the original SAEs is the feature suppression effect, which is caused by the conflict between the $L_1$ and $L_2$ losses combined with the unit-norm constraint on the SAE decoder. In theory, this effect should be more evident for inputs with high norms. Another observation is that training SAEs on multiple layers simultaneously results in inconsistent feature activation norms across layers: the scale of the feature activations in some layers differs greatly from that in others. Moreover, the residual states that are input to the SAEs during training also have different norms across layers. Hence, I argue that the current SAE architecture is not robust to inputs of varying norms, which are common in modern LLMs. In this post, I propose a modified SAE architecture, the Normalized Sparse Autoencoder (NSAE), and give a theoretical proof that it does not have the feature suppression problem. I then conduct experiments to verify the effectiveness of the proposed method, which show that:

  1. Feature suppression is suppressed in NSAEs

  2. The normalization removes the correlation between a layer's mean input norm and the norm of its feature activations.

  3. The normalization makes $L_1$ agree better with $L_0$.

I then further investigate the learned feature dictionaries and identify three types of feature vectors: correction vectors, pillar vectors, and direction vectors. I conclude the post with a discussion of the limitations of NSAEs and my suggestions for future directions.

Introduction

Training Sparse Autoencoders (SAEs) on the residual states of pretrained models is a recently proposed method in mechanistic interpretability to tackle the problem of superposition. This method is scalable and unsupervised, making it promising for auto-interpretability research.

More specifically, an SAE contains an encoder and a decoder. It is trained to generate sparse feature activations from the residual states of a source model through the encoder, and to reconstruct the residual state through the decoder. The expectation is that by training the SAE on a large set of activations, jointly optimizing a sparsity loss on the feature activations and a reconstruction loss, the model learns to decompose residual states into monosemantic feature vectors that are more interpretable.

In this post, I identify a flaw in the original SAE implementation, namely the inconsistency of the sparsity loss across layers, and propose a method to mitigate this problem. With the new method, we can significantly decrease the correlation between the norm of the source model's residual activations and the norm of the feature activations, making the training process more robust and controllable. The code is available on GitHub (note that you should use the dev branch rather than the others).

Motivations

Feature suppression is a known problem for SAEs. It originates from a conflict between the sparsity loss and the reconstruction loss: the reconstruction's norm is correlated with $\|f\|_1$, so the SAE learns to generate reconstructions with smaller norms to achieve a lower sparsity loss. This is undesirable, as we would like the reconstruction to match the original input activations as closely as possible. Therefore, it would be beneficial to disentangle the input norm from $\|f\|_1$ and the sparsity loss.

Also, in my own experiments training SAEs with the implementation from the AI Safety Foundation, I observed that the sparsity loss is inconsistent across layers:

Figure 1a. The $L_1$ loss of the feature activations in the layer indexed 1.
Figure 1b. The $L_1$ loss of the feature activations in the layer indexed 10.

The two figures above show the $L_1$ losses of two different layers from the same training run, yet their scales differ substantially.

Moreover, the sparsity measured by $L_0$ also differs vastly across layers:

Figure 2a. The $L_0$ norm of the feature activations in the layer indexed 1.
Figure 2b. The $L_0$ norm of the feature activations in the layer indexed 10.

I argue that this is also undesirable, since we introduced the coefficient $\lambda$ to control the balance between the $L_1$ and $L_2$ losses. Ideally, $\lambda$ should have a consistent effect across layers, which is not the case in practice.

Moreover, the norms of the source model's residual states are also inconsistent across layers. We can plot the distribution of the norms of residual states[1] in GPT-2 small across layers:

Figure 3. The norm distribution of residual states in different layers of the residual stream of GPT-2 small during inference.

Clearly, both the mean and the variance of the norms differ across layers.

This effect is common among LLMs, and we can find similar effects in more recent models like LLaMA-2 and Gemma:

Figure 4a. The norm distribution of residual states in different layers of the residual stream of LLaMA2-7B during inference.
Figure 4b. The norm distribution of residual states in different layers of the residual stream of Gemma-2B during inference.

This provides some evidence that the inconsistency of input norms might cause the undesirable behaviors of SAEs described above. In the next section, I conduct a theoretical analysis to further illustrate this problem.

Theoretical Analysis

Definitions

With these observations in mind, let's analyze the loss theoretically to see why these effects might occur.

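Formally, an SAE can be defined as follows:

$$f(x) = \mathrm{ReLU}(W_e x + b_e), \qquad \hat{x} = W_d f(x) + b_d.$$

We denote the output of the encoder, $f = f(x)$, as the feature activation.

The loss function for optimization is defined as

$$\mathcal{L}(x) = \|x - \hat{x}\|_2^2 + \lambda \|f\|_1,$$

where the coefficient $\lambda$ is a hyperparameter of the user's choice and $\|\cdot\|_k$ is the $k$-norm of a given vector.

We set another hyperparameter, the expansion factor $R$, and denote the source model's residual dimension as $d$. Then we define $m = Rd$, and we have $W_e \in \mathbb{R}^{m \times d}$, $b_e \in \mathbb{R}^{m}$, $W_d \in \mathbb{R}^{d \times m}$, and $b_d \in \mathbb{R}^{d}$.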
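To make the notation concrete, here is a minimal pure-Python sketch of one SAE forward pass and its loss. It assumes the common form $f = \mathrm{ReLU}(W_e x + b_e)$, $\hat{x} = W_d f + b_d$ with loss $\|x - \hat{x}\|_2^2 + \lambda\|f\|_1$; the AI Safety Foundation library is written in PyTorch and may place the biases differently, and the function names and toy weights are hypothetical.

```python
def relu(v):
    return [max(0.0, t) for t in v]


def matvec(W, v):
    # W is a list of rows, so this computes W @ v
    return [sum(w * t for w, t in zip(row, v)) for row in W]


def sae_forward(x, W_e, b_e, W_d, b_d):
    """Assumed form: encoder f = ReLU(W_e x + b_e); decoder x_hat = W_d f + b_d."""
    f = relu([z + b for z, b in zip(matvec(W_e, x), b_e)])
    x_hat = [y + b for y, b in zip(matvec(W_d, f), b_d)]
    return f, x_hat


def sae_loss(x, x_hat, f, lam):
    """Squared L2 reconstruction error plus lam times the L1 norm of f."""
    recon = sum((a - b) ** 2 for a, b in zip(x, x_hat))
    return recon + lam * sum(abs(t) for t in f)


# Hypothetical toy SAE: d = 2, expansion factor R = 2, so m = R * d = 4.
# W_d has unit-norm columns (+e1, +e2, -e1, -e2); W_e is its transpose.
W_d = [[1.0, 0.0, -1.0, 0.0],
       [0.0, 1.0, 0.0, -1.0]]
W_e = [list(col) for col in zip(*W_d)]
b_e, b_d = [0.0] * 4, [0.0] * 2

x = [3.0, -2.0]
f, x_hat = sae_forward(x, W_e, b_e, W_d, b_d)
loss = sae_loss(x, x_hat, f, lam=0.1)
```

In this toy case the reconstruction is exact, so the whole loss comes from the sparsity term.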

In the original implementation, the authors constrained the decoder to have unit-norm column vectors, so that during optimization the model cannot lower the $L_1$ loss by increasing the column norms of the decoder while generating dense feature activations of small magnitude. This design choice leads to a potential flaw in the method, which is discussed in a later section of this post.

The Effect of Input Norms on Feature Suppression

The authors who identified feature suppression have provided a nice theoretical analysis in the Feature Suppression section, but for completeness, I will conduct a similar analysis using the notation defined in this post.

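We first consider the extreme case where an input has a feature activation $f$ with only one positive entry $f_i$, and all other entries equal to 0. Ignoring the decoder bias $b_d$, we then have $\hat{x} = f_i d_i$, where $d_i$ is the $i$-th column vector of $W_d$. Since $W_d$ has unit-norm columns, we must have $\|\hat{x}\|_2 = f_i = \|f\|_1$.

More generally, I will show that when $f$ is sparse, we also have $\|\hat{x}\|_2 \approx \|f\|_1$.

Define $S = \{i : f_i > 0\}$, the index set of all nonzero entries in the feature activation. We assume that the feature vectors $\{d_i : i \in S\}$ are (almost) mutually orthogonal[2], that is, $d_i^\top d_j \approx 0$ for $i \ne j$. By the unit-norm constraint on the decoder, $d_i^\top d_i = 1$, so we have

$$\|\hat{x}\|_2^2 = \Big\|\sum_{i \in S} f_i d_i\Big\|_2^2 = \sum_{i \in S}\sum_{j \in S} f_i f_j\, d_i^\top d_j \approx \sum_{i \in S} f_i^2 = \|f\|_2^2.$$

In the case of sparse $f$, we have $\|f\|_2 \approx \|f\|_1$ (the two are equal when $|S| = 1$, and in general $\|f\|_2 \le \|f\|_1 \le \sqrt{|S|}\,\|f\|_2$), and therefore $\|\hat{x}\|_2 \approx \|f\|_1$.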
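A quick numerical check of this argument, assuming exactly orthonormal unit-norm decoder columns (the standard basis, chosen purely for illustration) and ignoring $b_d$: with one active feature $\|\hat{x}\|_2 = \|f\|_1$ exactly, while with several active features $\|\hat{x}\|_2 = \|f\|_2$, which is only within a factor $\sqrt{|S|}$ of $\|f\|_1$.

```python
import math


def l1(v):
    return sum(abs(t) for t in v)


def l2(v):
    return math.sqrt(sum(t * t for t in v))


def decode(cols, f):
    """x_hat = sum_i f_i * d_i, with the decoder columns d_i given as a list (bias ignored)."""
    dim = len(cols[0])
    return [sum(fi * col[k] for fi, col in zip(f, cols)) for k in range(dim)]


# Hypothetical orthonormal unit-norm columns: d_i . d_j = 0 for i != j, ||d_i||_2 = 1.
cols = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]

f_one = [2.0, 0.0, 0.0]  # |S| = 1
f_two = [3.0, 4.0, 0.0]  # |S| = 2

for f in (f_one, f_two):
    norm_xhat = l2(decode(cols, f))
    print(f"||x_hat||_2 = {norm_xhat:.3f}, ||f||_2 = {l2(f):.3f}, ||f||_1 = {l1(f):.3f}")
```

For `f_two` the script reports $\|\hat{x}\|_2 = 5$ but $\|f\|_1 = 7$, so the approximation is loose when several active features have similar magnitudes.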

Then our loss function becomes

$$\mathcal{L}(x) \approx \|x - \hat{x}\|_2^2 + \lambda \|\hat{x}\|_2.$$

If we attempt to minimize this loss, there is always a tradeoff between the reconstruction accuracy and the norm of the reconstruction. In most cases, the model learns to construct an $\hat{x}$ that is close to $x$ but has a slightly smaller norm, so as to keep both terms low.
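The shrinkage can be seen in a one-dimensional toy version of this loss. Assume $\hat{x}$ points in the same direction as $x$, with $\|x\|_2 = a$ and $\|\hat{x}\|_2 = s$; the loss is then $(a - s)^2 + \lambda s$, whose minimizer is $s^* = a - \lambda/2$ (for $a > \lambda/2$), strictly below $a$. The sketch checks this closed form against a brute-force grid search; the values of $a$ and $\lambda$ are arbitrary.

```python
def toy_loss(s, a, lam):
    # x_hat parallel to x: ||x - x_hat||_2 = |a - s| and ||x_hat||_2 = s
    return (a - s) ** 2 + lam * s


def best_norm_closed_form(a, lam):
    # d/ds [(a - s)^2 + lam * s] = -2 (a - s) + lam = 0  =>  s* = a - lam / 2, clipped at 0
    return max(0.0, a - lam / 2)


def best_norm_grid(a, lam, steps=100_000):
    # brute-force search over s in [0, a]
    grid = [a * i / steps for i in range(steps + 1)]
    return min(grid, key=lambda s: toy_loss(s, a, lam))


a, lam = 10.0, 1.0  # hypothetical input norm and sparsity coefficient
print(best_norm_closed_form(a, lam), best_norm_grid(a, lam))
```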

The Effect of Input Norms on the Inconsistency of the Sparsity Loss Across Layers

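Here, we make the same assumption that when $f$ is sparse, $\|f\|_1 \approx \|\hat{x}\|_2$.

For the $L_2$ term, letting $\theta$ denote the angle between $x$ and $\hat{x}$, we have

$$\|x - \hat{x}\|_2^2 = \|x\|_2^2 + \|\hat{x}\|_2^2 - 2\|x\|_2\|\hat{x}\|_2\cos\theta.$$

At first glance, the dependence on $\|x\|_2$ might not be obvious, but if our reconstruction is similar enough to $x$, we can take $\|\hat{x}\|_2 \approx \|x\|_2$[3], and the equation simplifies to

$$\|x - \hat{x}\|_2^2 \approx 2\|x\|_2^2(1 - \cos\theta).$$

Now we can rewrite our loss:

$$\mathcal{L}(x) \approx 2\|x\|_2^2(1 - \cos\theta) + \lambda\|x\|_2.$$

Notice that if $\cos\theta$ stays on a relatively fixed scale, the first (reconstruction) term scales as $\|x\|_2^2$, while the second (sparsity) term scales as $\|x\|_2$. Then, given a fixed $\lambda$, a larger $\|x\|_2$ shifts the balance of the loss toward the reconstruction term. This agrees with the observation I had earlier: the source model's residual states in deeper layers have larger norms than those in shallower layers, and the loss was significantly higher in deeper layers because it was dominated by the larger term.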
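Plugging illustrative numbers into the approximate loss $2\|x\|_2^2(1-\cos\theta) + \lambda\|x\|_2$ makes the imbalance concrete: holding $\cos\theta$ and $\lambda$ fixed, the ratio of the reconstruction term to the sparsity term grows linearly with $\|x\|_2$. The values of $\cos\theta$, $\lambda$, and the norms below are hypothetical, not measurements.

```python
def approx_loss_terms(x_norm, cos_theta, lam):
    recon = 2 * x_norm ** 2 * (1 - cos_theta)  # grows quadratically with ||x||_2
    sparsity = lam * x_norm                    # grows linearly with ||x||_2
    return recon, sparsity


cos_theta, lam = 0.9, 0.01  # hypothetical values
ratios = {}
for x_norm in (1.0, 10.0, 100.0):
    recon, sparsity = approx_loss_terms(x_norm, cos_theta, lam)
    ratios[x_norm] = recon / sparsity
    print(f"||x||_2 = {x_norm:6.1f}: reconstruction / sparsity = {ratios[x_norm]:.1f}")
```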

Normalizing SAEs

After this analysis, it is natural to ask: is there a way to solve these problems?

My answer is yes!

Here, I propose an architectural modification to the original SAE architecture, which I have named the Normalized Sparse Autoencoder (NSAE).

Architecture

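The modified architecture is defined as follows:

$$f(x) = \tanh\big(\mathrm{ReLU}(W_e x + b_e + \epsilon)\big), \qquad \hat{x} = W_d f(x) + b_d.$$

In this definition, $f$ is the new feature activation, and $W_d$ is no longer constrained to have unit-norm columns. A Gaussian error term $\epsilon$ is introduced to regularize the feature activation; it is sampled from $\mathcal{N}(0, \sigma^2 I)$ for some hyperparameter $\sigma$.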
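A minimal pure-Python sketch of the NSAE forward pass, assuming the noise $\epsilon \sim \mathcal{N}(0, \sigma^2)$ is drawn independently for each pre-activation entry and added before tanh(ReLU(·)). Whether the noise is also applied outside training is not stated here, so `sigma` is a parameter (0.0 disables it); the function names and weights are hypothetical.

```python
import math
import random


def nsae_encode(x, W_e, b_e, sigma, rng):
    """Assumed form: f = tanh(ReLU(W_e x + b_e + eps)), eps ~ N(0, sigma^2) drawn per entry."""
    pre = [sum(w * t for w, t in zip(row, x)) + b for row, b in zip(W_e, b_e)]
    noisy = [z + rng.gauss(0.0, sigma) for z in pre]
    # tanh(ReLU(.)) maps every entry into [0, 1); in floating point, very large inputs round to exactly 1.0
    return [math.tanh(max(0.0, z)) for z in noisy]


def nsae_decode(f, W_d, b_d):
    """x_hat = W_d f + b_d; the columns of W_d are NOT constrained to unit norm."""
    return [sum(w * t for w, t in zip(row, f)) + b for row, b in zip(W_d, b_d)]


rng = random.Random(0)
W_e = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]]  # hypothetical 3 x 2 encoder
b_e = [0.0, 0.0, 0.0]

# Inputs whose norms differ by a factor of 1000 still give entries bounded in [0, 1].
f_small = nsae_encode([0.5, -0.2], W_e, b_e, sigma=1.0, rng=rng)
f_large = nsae_encode([500.0, -200.0], W_e, b_e, sigma=1.0, rng=rng)
```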

The tanh normalizes every entry of $f$ to the range $[0, 1)$. The benefits of doing this are threefold:

  1. This makes the scale of $f$ independent of the norm of the input, which theoretically prevents feature suppression.

  2. When the entries of $f$ are in the range $[0, 1)$, $L_0$ and $L_1$ are much closer, making the $L_1$ loss a more accurate measure of sparsity.

  3. The decoder learns feature vectors with non-unit norms, which can potentially lead to better interpretability, as we can now consider both their directions and their norms.
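Point 2 can be illustrated directly: for entries in $[0, 1]$ we always have $\|f\|_1 \le \|f\|_0$, with equality only when every active entry equals 1. The sketch compares the gap between $L_1$ and $L_0$ for hypothetical unbounded ReLU activations and for the same values passed through tanh.

```python
import math


def l0(v):
    return sum(1 for t in v if t != 0.0)


def l1(v):
    return sum(abs(t) for t in v)


relu_acts = [0.0, 7.3, 0.0, 0.2, 15.0]         # hypothetical unbounded activations
tanh_acts = [math.tanh(t) for t in relu_acts]  # the same activations squashed into [0, 1)

for name, acts in (("ReLU", relu_acts), ("tanh(ReLU)", tanh_acts)):
    gap = abs(l1(acts) - l0(acts))
    print(f"{name:>10}: L0 = {l0(acts)}, L1 = {l1(acts):.3f}, |L1 - L0| = {gap:.3f}")
```

Most of the remaining gap in the tanh case comes from the small 0.2 entry, which is exactly the kind of activation the Gaussian noise is meant to discourage.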

The Gaussian noise term is also essential in this architecture. Without it, the model could minimize $\|f\|_1$ by mapping $x$ to very small positive values in the feature activation space and learning decoders with extremely large column norms.

To show why adding Gaussian noise solves this problem, I plotted the activation function in the following figure:

Figure 5. The tanh(ReLU(x)) function and the output ranges that different input ranges map to. For large inputs, the input range maps to a very small region of the y-axis, meaning that perturbations in that range barely change the output, while smaller inputs are much more sensitive to perturbation.

From the figure, we can see that when the inputs are small, the output of tanh(ReLU) is relatively sensitive to the input, so adding Gaussian noise can significantly perturb small feature activations. In contrast, larger inputs to the activation function are much more robust to perturbation, as they all map to similar values close to 1. Hence, this perturbation forces the model to generate feature activations that are either exactly 0 or close to 1, which makes $L_1$ behave even more like $L_0$, especially when $\sigma$ is large.
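The sensitivity argument can be checked by Monte Carlo: add $\epsilon \sim \mathcal{N}(0, \sigma^2)$ to a pre-activation $z$ and measure the spread of $\tanh(\mathrm{ReLU}(z + \epsilon))$. A small $z$ yields a widely spread output, while a large $z$ stays pinned near 1. The values of $z$, $\sigma$, and the sample size are arbitrary choices for illustration.

```python
import math
import random
import statistics


def act(z):
    return math.tanh(max(0.0, z))


def output_spread(z, sigma, n=20_000, seed=0):
    """Population std of tanh(ReLU(z + eps)) over n samples of eps ~ N(0, sigma^2)."""
    rng = random.Random(seed)
    outs = [act(z + rng.gauss(0.0, sigma)) for _ in range(n)]
    return statistics.pstdev(outs)


sigma = 1.0  # hypothetical noise level
for z in (0.2, 1.0, 3.0, 6.0):
    print(f"z = {z}: std of output = {output_spread(z, sigma):.4f}")
```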

Loss

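We also have to redefine the loss as follows:

$$\mathcal{L}(x) = \frac{\|x - \hat{x}\|_2^2}{\overline{\|x\|}_2^{\,2}} + \lambda\|f\|_1,$$

where $\overline{\|x\|}_2$ is the mean input norm of the given layer.

We introduce the additional step of scaling the $L_2$ loss by the square of the layer's mean input norm. This is because $\|x - \hat{x}\|_2^2 = \|x\|_2^2 + \|\hat{x}\|_2^2 - 2\|x\|_2\|\hat{x}\|_2\cos\theta$, where $\theta$ is the angle between $x$ and $\hat{x}$. If we assume that, without the unit-norm constraint, the best an optimizer can do is to achieve a fixed cosine similarity between $x$ and $\hat{x}$ (with $\|\hat{x}\|_2 \approx \|x\|_2$), then we can treat $\cos\theta$ as a constant, so the $L_2$ loss scales as $\|x\|_2^2$, while $\|f\|_1$, whose entries all lie in $[0, 1)$, should have a roughly constant scale across layers. Therefore, we can manually rescale the $L_2$ loss to match the scale of the $L_1$ loss. Another way to scale the loss is to use the actual $\|x\|_2^2$ of each sample. Theoretically, this might cause the model to overfit to inputs with large norms, but for conciseness I leave this problem to future work and use only the mean normalization in all of the following experiments.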
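A minimal sketch of this normalized loss over a batch, assuming the layer's mean input norm is passed in as a constant (in practice it would be estimated from a sample of that layer's activations; here, for illustration only, it is computed from the toy batch). The final check shows the property the normalization is meant to provide: scaling inputs and reconstructions by a common factor, as between a shallow and a deep layer, leaves the reconstruction term unchanged. All names and numbers are hypothetical.

```python
import math


def l2(v):
    return math.sqrt(sum(t * t for t in v))


def mean_input_norm(x_batch):
    return sum(l2(x) for x in x_batch) / len(x_batch)


def nsae_loss(x_batch, x_hat_batch, f_batch, lam, mean_norm):
    """Batch mean of ||x - x_hat||_2^2 / mean_norm^2 + lam * ||f||_1."""
    n = len(x_batch)
    recon = sum(sum((a - b) ** 2 for a, b in zip(x, xh)) for x, xh in zip(x_batch, x_hat_batch)) / n
    sparsity = sum(sum(abs(t) for t in f) for f in f_batch) / n
    return recon / mean_norm ** 2 + lam * sparsity


# Hypothetical toy batch; f stays in [0, 1) as in the NSAE.
x_batch = [[3.0, 4.0], [0.0, 2.0]]
x_hat_batch = [[2.5, 4.0], [0.0, 1.5]]
f_batch = [[0.9, 0.0], [0.0, 0.5]]

loss_small = nsae_loss(x_batch, x_hat_batch, f_batch, 0.01, mean_input_norm(x_batch))

# Scale inputs and reconstructions by 100, as if this were a higher-norm layer.
scale = 100.0
x_big = [[scale * t for t in x] for x in x_batch]
x_hat_big = [[scale * t for t in x] for x in x_hat_batch]
loss_big = nsae_loss(x_big, x_hat_big, f_batch, 0.01, mean_input_norm(x_big))
```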

Experiments

I trained two groups of SAEs, a baseline group and an experiment group, on all layers of GPT-2, with two training runs per group. The four runs used different combinations of the $L_1$ coefficient $\lambda$ and the learning rate; the baseline runs used the original SAE, while the experiment runs used the normalized SAE. I use "the experiment group" and "the normalized group" interchangeably.

Feature Suppression is Suppressed in Normalized SAE

To investigate feature suppression, I added a new validation metric that measures the ratio between the norm of the reconstruction and the norm of the source activation, $\|\hat{x}\|_2 / \|x\|_2$. Here is this metric during training:

Figure 6. Mean feature suppression score ($\|\hat{x}\|_2 / \|x\|_2$) during training; higher is better.
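One way to compute this metric, assuming it is the batch mean of the per-sample ratio $\|\hat{x}\|_2 / \|x\|_2$ (taking the ratio of mean norms would be an equally plausible variant); the function name and batch are hypothetical.

```python
import math


def l2(v):
    return math.sqrt(sum(t * t for t in v))


def feature_suppression_score(x_batch, x_hat_batch):
    """Assumed definition: mean of ||x_hat||_2 / ||x||_2 over the batch; 1.0 means no shrinkage."""
    ratios = [l2(xh) / l2(x) for x, xh in zip(x_batch, x_hat_batch)]
    return sum(ratios) / len(ratios)


x_batch = [[3.0, 4.0], [6.0, 8.0]]
x_hat_batch = [[3.0, 4.0], [3.0, 4.0]]  # second reconstruction is shrunk to half the input norm
score = feature_suppression_score(x_batch, x_hat_batch)
```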

Clearly, the normalized group scores significantly higher on this feature suppression metric than the baseline group, and its score is very close to one. Given that this NSAE had not fully converged, having seen only 200M training examples, and that the score shows no sign of flattening, I claim that NSAEs have little to no feature suppression.

Normalizing Removes the Correlation Between Input Norm and Feature Activation Norm

To investigate the effect of normalization, I collected the feature activation norms of the different layers at the end of training and plotted them against the mean input norms of those layers:

Figure 7. The correlation between mean input norms and the mean norm of the feature activations.[5]

The red and blue data points are from the baseline group, whereas the cyan and purple data points are from the experiment group. We can fit lines to these data points to find linear relationships between the mean input norm and the mean norm of the feature activations. Although the fit is poor, the fitted lines still show a rough positive linear correlation between the mean input norm and the feature activation norm in the baseline. In contrast, the two normalized runs do not exhibit a statistically significant positive linear relationship between input norm and feature activation norm.
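Assuming the fitted lines are ordinary least squares over per-layer pairs (mean input norm, mean feature activation norm), a dependency-free sketch that returns the slope, the intercept, the Pearson correlation $r$, and the usual $t$ statistic for testing $r = 0$ with $n - 2$ degrees of freedom. The example points are made up, not data from the runs.

```python
import math


def fit_line(xs, ys):
    """Ordinary least squares y = slope * x + intercept, plus Pearson r and its t statistic."""
    n = len(xs)
    mx, my = sum(xs) / n, sum(ys) / n
    sxx = sum((x - mx) ** 2 for x in xs)
    syy = sum((y - my) ** 2 for y in ys)
    sxy = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    slope = sxy / sxx
    intercept = my - slope * mx
    r = sxy / math.sqrt(sxx * syy)
    # t statistic for H0: no linear correlation, with n - 2 degrees of freedom
    if abs(r) >= 1.0:
        t = math.copysign(math.inf, r)
    else:
        t = r * math.sqrt((n - 2) / (1 - r * r))
    return slope, intercept, r, t


# Hypothetical per-layer points: (mean input norm, mean feature activation norm)
xs = [1.0, 2.0, 3.0, 4.0]
ys = [2.0, 4.0, 6.0, 8.0]
slope, intercept, r, t = fit_line(xs, ys)
```

The $t$ value would then be compared against a Student's $t$ distribution to judge significance.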

This linear fit is clearly unsatisfactory, so I investigated the reasons behind it further. I plotted the normalized group's feature activation norm against layer index, and here is the result:

Figure 8. The correlation between layer and the mean norm of the feature activation.

I conjecture that the feature activation norm in the normalized group reflects the level of discreteness of the source model's activations, since it exhibits an increase-then-decrease pattern. In the source model, earlier activations are more discrete, as they originate from discrete input embeddings, while deeper activations might be less discrete as they aggregate information. In the last layers, since the model has to make the next-token prediction as accurate as possible, the activations might become more discrete again for better next-token decoding, because the decoding layer is discrete. This discreteness might also be positively correlated with the monosemanticity of the activations, as more discrete activations are often more interpretable. I will not verify this conjecture in this post for reasons of length, and I welcome others to study this problem.

$L_1$ Agrees with $L_0$ Better

To investigate the agreement between $L_1$ and $L_0$, I plot the mean $L_1$ and $L_0$ of the feature activations for both groups:

Figure 9. Agreement between $L_1$ and $L_0$. What matters is the distance between the two lines of the same color.

Clearly, the cyan and purple solid lines are much closer to their corresponding dashed lines than the baseline lines are, indicating better agreement between $L_1$ and $L_0$.

Performance Validation

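To validate that the normalization does not substantially hurt performance, I report the reconstruction score metric. I first calculate the loss with no intervention, with zero intervention (replacing the hidden states in one layer with zero vectors), and with reconstruction intervention (replacing the hidden states in one layer with the SAE's reconstructions), and denote them as $\mathcal{L}_{\text{none}}$, $\mathcal{L}_{\text{zero}}$, and $\mathcal{L}_{\text{recon}}$, respectively. The score is then calculated as

$$\text{score} = \frac{\mathcal{L}_{\text{zero}} - \mathcal{L}_{\text{recon}}}{\mathcal{L}_{\text{zero}} - \mathcal{L}_{\text{none}}}.$$

Since we expect $\mathcal{L}_{\text{zero}}$ to be higher than $\mathcal{L}_{\text{none}}$, and we want $\mathcal{L}_{\text{recon}}$ to be close to $\mathcal{L}_{\text{none}}$, a higher score is better, and we expect a value close to 1. The score during training is shown below: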

Figure 10. Mean reconstruction score during training.
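Assuming the score is $(\mathcal{L}_{\text{zero}} - \mathcal{L}_{\text{recon}}) / (\mathcal{L}_{\text{zero}} - \mathcal{L}_{\text{none}})$, it is a one-line computation once the three language-model losses are known. The loss values below are hypothetical, chosen only to show how to read the score.

```python
def reconstruction_score(loss_none, loss_zero, loss_recon):
    """Assumed formula: 1.0 matches the clean model, 0.0 is as bad as zero-ablating the layer."""
    return (loss_zero - loss_recon) / (loss_zero - loss_none)


# Hypothetical losses for one layer
score = reconstruction_score(loss_none=3.0, loss_zero=8.0, loss_recon=3.5)
```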

There is no observable difference between the normalized group and the baseline group, except that the normalized group's score seems slightly more stable during training. This indicates that the normalization does not substantially hurt performance and might even improve training stability.

Since the mean reconstruction score is heavily affected by the sparsity of the feature activations, I also compared a layer where the $L_0$ norms of the baseline and experiment groups agree best:

Figure 11a. $L_0$ norm of layer 5 for experiment and baseline.
Figure 11b. Reconstruction score of layer 5 for experiment and baseline.

Still, there is no observable difference between the experiment group and the baseline after convergence. This provides further evidence that the normalization does not have an observable negative impact on the performance of SAEs.

NSAE Statistics

To further investigate what the new SAE has learned, I performed some statistical analysis on the NSAE feature dictionary from the first normalized run. For comparison, I used the original SAE trained in the first baseline run.

I first analyzed the norm distribution of the feature vectors across layers:

Figure 12. Norm distribution histogram of the feature vectors from the NSAE decoder across layers.

Interestingly, a large proportion of feature vectors have small norms, which might indicate that they are small correction vectors that are added to a larger vector to bring the reconstruction as close as possible to the input. In contrast, I hypothesize that feature vectors with large norms and high mean activation should be more interpretable, as they represent the general directions of the reconstruction. Hence, I call these vectors pillar vectors.
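The histogram in Figure 12 is over the norms of the decoder's columns (one column per feature vector). A small sketch of that computation, with a hypothetical 2 × 4 decoder and bin edges chosen for illustration (the post's actual bins are not given here):

```python
import math


def column_norms(W_d):
    """L2 norm of each decoder column (feature vector); W_d is a list of d rows of length m."""
    m = len(W_d[0])
    return [math.sqrt(sum(row[j] ** 2 for row in W_d)) for j in range(m)]


def histogram(values, edges):
    """Count values in [edges[k], edges[k + 1]); the last bin also includes edges[-1]."""
    counts = [0] * (len(edges) - 1)
    for v in values:
        for k in range(len(edges) - 1):
            in_bin = edges[k] <= v < edges[k + 1]
            if in_bin or (k == len(edges) - 2 and v == edges[-1]):
                counts[k] += 1
                break
    return counts


# Hypothetical decoder with two small "correction-like" columns and two large ones
W_d = [[0.1, 3.0, 0.0, 6.0],
       [0.0, 4.0, 0.2, 8.0]]
norms = column_norms(W_d)  # roughly [0.1, 5.0, 0.2, 10.0]
counts = histogram(norms, edges=[0.0, 1.0, 10.0])
```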

Next, I computed the distribution of cosine similarities within the feature dictionary:

Figure 13. Distribution of cosine similarity in the feature dictionary of the NSAE and original SAE, respectively.[4]

From the figure, the cosine similarity distributions of the NSAE and the SAE are very similar, except that the NSAE has some pairs with cosine similarity very close to one. My hypothesis is that the NSAE contains some direction vectors that appear frequently, with different norms, in the decomposition of source model activations, so the NSAE has to learn vectors with the same direction but different norms.
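A distribution like the one in Figure 13 can be approximated by sampling a subset of feature vectors and computing all pairwise cosine similarities within it (footnote 4 mentions random sampling; the exact scheme here is an assumption). The hypothetical toy dictionary contains two vectors with the same direction but different norms, the "direction vector" pattern, and their cosine similarity is exactly 1.

```python
import math
import random


def cosine(u, v):
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))


def sampled_pairwise_cosines(features, n_sample, seed=0):
    """Assumed scheme: cosine similarity of every pair within a random subset of feature vectors."""
    rng = random.Random(seed)
    idx = rng.sample(range(len(features)), min(n_sample, len(features)))
    return [cosine(features[i], features[j])
            for a, i in enumerate(idx) for j in idx[a + 1:]]


# Hypothetical dictionary: the first two vectors share a direction but differ in norm.
features = [[1.0, 0.0, 0.0], [2.5, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 3.0]]
sims = sampled_pairwise_cosines(features, n_sample=4)
```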

A natural question to ask is: do pillar vectors and direction vectors overlap? To answer this, I picked the top-$k$ vectors (by norm) of each layer from the feature dictionary as a set of pillar vectors and calculated their pairwise cosine similarities. Here is the distribution:

Figure 14. Distribution of cosine similarity for high-norm feature vectors (pillar vectors)

Since there are few to no pillar vectors with very high cosine similarity to each other, there is minimal overlap between pillar vectors and direction vectors.
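The pillar-vector check can be sketched as: take the top-$k$ feature vectors by norm within a layer and find the largest pairwise cosine similarity among them. Here $k$, the function names, and the toy dictionary are hypothetical.

```python
import math


def norm(v):
    return math.sqrt(sum(t * t for t in v))


def cosine(u, v):
    return sum(a * b for a, b in zip(u, v)) / (norm(u) * norm(v))


def top_k_by_norm(features, k):
    """The k feature vectors with the largest L2 norms (candidate pillar vectors)."""
    return sorted(features, key=norm, reverse=True)[:k]


def max_pairwise_cosine(vectors):
    return max(cosine(u, v) for i, u in enumerate(vectors) for v in vectors[i + 1:])


# Hypothetical dictionary: two large orthogonal vectors plus small vectors that repeat a direction.
features = [[9.0, 0.0], [0.0, 8.0], [0.1, 0.0], [0.3, 0.0], [0.0, 0.2]]
pillars = top_k_by_norm(features, k=2)
print(max_pairwise_cosine(pillars))   # among pillar vectors only
print(max_pairwise_cosine(features))  # across the whole toy dictionary
```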

As this post is already quite long, I will leave a more comprehensive analysis of the learned feature dictionary to a future post and conclude here.

Discussion

Limitations

The normalization does not come without cost. NSAEs generally have slightly higher reconstruction losses than the original SAEs, and they take longer to converge, as shown in the following figure:

Figure 15. L2 reconstruction loss during training, lower is better.

I suspect this is because the NSAE learns a dictionary with non-unit norms, and this dictionary has to capture all the norm information with a fixed size, whereas the original SAE can learn directions and add norm information through the feature activations.

Another metric that I don't know how to interpret is neuron activity. In the NSAE, neuron activity is significantly higher than in the original SAE:

Figure 16. Neuron activity for baseline and experiment groups.

Lastly, the experiments conducted are relatively small in scale due to compute limitations. Moreover, because the loss function changed, it is hard to directly match the scales of $\lambda$ between the baseline and the experiment groups.

Future Work

I suggest that future work pursue the following directions:

  1. Investigate other factors that might cause the inconsistency across layers. I conjectured that differences in the discreteness of the source model's activations across layers might cause this inconsistency.

  2. Interpret the learned feature dictionary of NSAE. Future work can further investigate the feature vectors, especially the pillar vectors and direction vectors, and find interpretations for them.

Appendix

Hyperparameters

I varied the hyperparameters l1_coefficient and the optimizer learning rate lr. For the two normalized runs, I also set the standard deviation of the Gaussian noise to $\sigma = 1$.

                 baseline 1     baseline 2     normalized 1   normalized 2
l1_coefficient   0.001073       0.0009642      0.00004065     0.0000965
lr               0.0006275      0.00005584     0.0009045      0.000657
σ (noise std)    N/A            N/A            1              1
Table A1. Hyperparameters used for training that varied across runs

expansion_factor                 16
context_size                     256
source_data_batch_size           16
train_batch_size                 4096
max_activations                  100,000,000
validation_frequency             5,000,000
max_store_size                   100,000
resample_interval                200,000,000
n_activations_activity_collate   100,000,000
threshold_is_dead_portion_fires  1e-6
max_n_resamples                  4
resample_dataset_size            100_000
cache_names                      blocks.{layer}.hook_mlp_out
Table A2. Fixed hyperparameters for all runs

Riggs et al. proposed using Sparse Autoencoders (SAEs) to discover interpretable features in large language models. Later, Wright et al. identified the feature suppression effect in SAEs and argued that the $L_1$ loss induces smaller feature activations that harm reconstruction performance. Wes Gurnee observed that the reconstruction errors of SAEs are empirically pathological, and compared different norm-aware interventions on the source model's inference. The results show that replacing the original residual state with the SAE reconstruction significantly changes the model's predictions, especially in deeper layers.

  1. ^

    In this and the following examples, I used the residual states from the MLP layer.

  2. ^

    This is a reasonable assumption, as the data in Figure 13 (baseline) show that most feature vector pairs in the original sparse autoencoder have cosine similarities close to 0.

  3. ^

    Empirically, $\|\hat{x}\|_2 \approx \|x\|_2$ holds closely enough for our analysis.

  4. ^

    For computational efficiency, I randomly sampled a subset of features from the cosine similarity matrix.

  5. ^

    Collected at step=3000. Input norms were computed on a relatively small sample of random text, the same text used to generate Figures 3, 4a, and 4b.