Experiments with an alternative method to promote sparsity in sparse autoencoders

Summary

I experimented with alternatives to the standard L1 penalty used to promote sparsity in sparse autoencoders (SAEs). I found that including terms based on an alternative differentiable approximation of the feature sparsity in the loss function was an effective way to generate sparsity in SAEs trained on the residual stream of GPT2-small. The key findings include:

  1. SAEs trained with this new loss function had a lower L0 norm, lower mean-squared error, and fewer dead features compared to a reference SAE trained with an L1 penalty.

  2. This approach can effectively discourage the production of dead features by adding a penalty term to the loss function based on features with a sparsity below some threshold.

  3. SAEs trained with this loss function had different feature sparsity distributions and significantly higher L1 norms compared to L1-penalised models.

Loss functions that incorporate differentiable approximations of sparsity as an alternative to the standard L1 penalty appear to be an interesting direction for further investigation.

Motivation

Sparse autoencoders (SAEs) have been shown to be effective at extracting interpretable features from the internal activations of language models (e.g. Anthropic & Cunningham et al.). Ideally, we want SAEs to simultaneously (a) reproduce the original language model behaviour and (b) consist of monosemantic, interpretable features. SAE loss functions usually contain two components:

  1. Mean-squared error (MSE) between the SAE output and input activations, which helps with reconstructing the original language model activations, and ultimately with model behaviour.

  2. L1 penalty on the SAE feature activations (the sum of the magnitude of the feature activations) to promote sparsity in the learned representation.

The relative importance of each term is controlled by a coefficient on the L1 penalty, which allows the model to move along the trade-off between reconstruction of the language model behaviour and a highly sparse representation. In this post, I present experiments with alternatives to the standard L1 penalty to promote sparsity in SAEs.

Approximations of the sparsity

A key requirement for SAE features to be interpretable is that most of them are sparse. In this context, the sparsity, s, of a given SAE feature, f, is the fraction of tokens for which the feature has a nonzero activation. For instance, a sparsity of 0.01 means that the feature has a nonzero post-GELU activation for 1% of all tokens. We often use the L0 norm as an average measure of sparsity over the entire SAE, defined as the average number of features with nonzero post-GELU activations per token.
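As a minimal sketch of these two definitions, the snippet below computes the per-feature sparsity and the L0 norm for a toy batch of activations. The function names and the toy values are hypothetical and only serve to illustrate the definitions; a real SAE would hold activations in a tensor rather than nested lists.

```python
def feature_sparsities(activations):
    """Fraction of tokens on which each feature has a nonzero activation.

    `activations` is a list of per-token rows, one value per feature.
    """
    n_tokens = len(activations)
    n_features = len(activations[0])
    return [
        sum(1 for row in activations if row[j] != 0) / n_tokens
        for j in range(n_features)
    ]


def l0_norm(activations):
    """Average number of features with nonzero activation per token."""
    active = sum(sum(1 for a in row if a != 0) for row in activations)
    return active / len(activations)


# Toy batch: 4 tokens, 3 features (hypothetical values).
acts = [
    [0.0, 1.2, 0.0],
    [0.5, 0.0, 0.0],
    [0.0, 0.3, 0.0],
    [0.0, 2.0, 0.0],
]
sparsities = feature_sparsities(acts)  # one value per feature
l0 = l0_norm(acts)                     # equals the sum of the sparsities
```

Note that the L0 norm is simply the sum of the feature sparsities, which is why a differentiable approximation of one gives a differentiable approximation of the other.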

In principle, we may want to simply add the value of the L0 norm to the loss function, instead of the L1 norm. However, the calculation of the L0 norm from the feature activations, a, involves a function that evaluates to 0 if a = 0 and to 1 if a > 0 (see blue line in Figure 1). This function is not differentiable at a = 0 and has zero gradient everywhere else, so it cannot be used directly in the loss function.

Figure 1: The contribution of a given component to the sparsity calculation as a function of the feature activation for a range of different sparsity measures.

There are many differentiable measures of sparsity that approximate the L0 norm (Hurley & Rickard 2009). The L1 norm is one example. Another example, which Anthropic recently discussed in their updates, is the tanh function, which asymptotically approaches 1 for large values of the feature activation, a.

The usefulness of these approximations as a penalty for sparsity in SAE loss functions likely depends on a combination of how accurately they approximate the L0 norm, and the derivative of the measure as a function of feature activation that is used by the optimiser in the training process. To highlight this, Figure 2 shows the derivatives of the sparsity contribution with respect to the feature activation for each sparsity measure.

Figure 2: The derivative of the contribution of a given component to the sparsity calculation as a function of the feature activation for a range of different sparsity measures.
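To make the comparison concrete, here is a small sketch that evaluates the contribution and the derivative of three of the measures mentioned above (the L0 step function, L1 and tanh) at a few activation values. The function names are hypothetical, and activations are assumed to be non-negative.

```python
import math


def l0_contribution(a):
    # Step function: 0 at a = 0, 1 for a > 0. Its gradient is 0 wherever it is defined.
    return 1.0 if a > 0 else 0.0


def l1_contribution(a):
    return abs(a)


def tanh_contribution(a):
    return math.tanh(a)


def l1_derivative(a):
    # Constant gradient for any positive activation, however large.
    return 1.0


def tanh_derivative(a):
    # d/da tanh(a) = 1 - tanh(a)^2, which decays towards 0 for large a.
    return 1.0 - math.tanh(a) ** 2


for a in (0.01, 0.1, 1.0, 5.0):
    print(
        f"a={a:<5} L0={l0_contribution(a):.0f} L1={l1_contribution(a):.3f} "
        f"tanh={tanh_contribution(a):.3f} dL1={l1_derivative(a):.1f} "
        f"dtanh={tanh_derivative(a):.4f}"
    )
```

The L1 penalty keeps pushing on large activations with the same gradient, whereas tanh saturates and barely penalises features once they are clearly active.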


Figure 1 presents a further example of a sparsity measure, the function $a/(a+\epsilon)$. In this approximation, smaller values of $\epsilon$ provide a more accurate approximation of L0, while larger values of $\epsilon$ provide larger gradients for large feature activations and more moderate gradients for small feature activations. Under this approximation, the feature sparsities in a batch can be approximated as:

$$\mathbf{s} \approx \frac{1}{n_b} \sum_{i=1}^{n_b} \frac{\mathbf{a}_i}{\mathbf{a}_i + \epsilon}$$

where $\mathbf{s}$ is the vector of feature sparsities, $n_b$ is the batch size, $\mathbf{a}_i$ are the activations for each feature and each element in the batch, and $\epsilon$ is a small constant. One can approximate the L0 in a similar way,

$$L_0 \approx \frac{1}{n_b} \sum_{i=1}^{n_b} \sum_{j} \frac{a_{ij}}{a_{ij} + \epsilon}$$

and include this term in the loss function as an alternative to the L1 penalty.
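The sketch below illustrates this kind of approximation, assuming the smooth measure takes the form a / (a + eps), which matches the behaviour described above (accurate for small eps, larger gradients at large activations for large eps). The function names and toy values are hypothetical, and activations are assumed to be non-negative.

```python
def soft_indicator(a, eps):
    """Smooth stand-in for the 0/1 'is this feature active?' step function."""
    return a / (a + eps)


def soft_indicator_grad(a, eps):
    """Derivative of a / (a + eps) with respect to a."""
    return eps / (a + eps) ** 2


def approx_sparsities(activations, eps):
    """Per-feature average of the soft indicator over the batch."""
    n_b = len(activations)
    n_features = len(activations[0])
    return [
        sum(soft_indicator(row[j], eps) for row in activations) / n_b
        for j in range(n_features)
    ]


def approx_l0(activations, eps):
    """Approximate L0: the sum of the approximate feature sparsities."""
    return sum(approx_sparsities(activations, eps))


# Toy batch: 4 tokens, 3 features. The exact L0 of this batch is 1.0.
acts = [
    [0.0, 1.2, 0.0],
    [0.5, 0.0, 0.0],
    [0.0, 0.3, 0.0],
    [0.0, 2.0, 0.0],
]
tight = approx_l0(acts, eps=0.01)  # close to the exact L0
loose = approx_l0(acts, eps=0.5)   # noticeably underestimates the exact L0
```

With a small eps the estimate approaches the exact L0, but the gradient at a large activation (for example `soft_indicator_grad(5.0, 0.01)`) becomes very small, which is the trade-off described above.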

In addition to the loss function, recent work training SAEs on language model activations often included techniques in the training process to limit the number of dead SAE features that are produced (e.g. the resampling procedure described by Anthropic). As an attempt to limit the number of dead features that form, I experimented with adding a term to the loss function that penalises features with a sparsity below a given threshold.

This term applies a RELU to the gap between the desired minimum sparsity threshold, min_sparsity_target, and the feature sparsities, s, so that only features below the threshold are penalised. Figure 3 visualises the value of this term as a function of the feature sparsity for min_sparsity_target = 1e-5.

Figure 3: The minimum sparsity function that penalises features with a sparsity below a given threshold (e.g. 1e-5 in this figure).
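One possible form of such a penalty is sketched below. It is an assumption, not necessarily the exact expression used for these models: the gap to the threshold is measured in log10 space, and sparsities are floored at 1e-10 so that dead features (log sparsity of −10 in the figures) stay finite. The function name and the `floor` parameter are hypothetical.

```python
import math


def min_sparsity_penalty(sparsities, min_sparsity_target=1e-5, floor=1e-10):
    """Sum over features of ReLU(log10(target) - log10(sparsity)).

    Assumed form: zero for features at or above the threshold, and growing
    by 1 for every factor of 10 that a feature falls below it.
    """
    log_target = math.log10(min_sparsity_target)
    total = 0.0
    for s in sparsities:
        log_s = math.log10(max(s, floor))     # floor avoids log10(0) for dead features
        total += max(0.0, log_target - log_s)  # ReLU
    return total


healthy = min_sparsity_penalty([1e-3, 1e-5])  # no feature below the threshold
too_sparse = min_sparsity_penalty([1e-7])     # two decades below the threshold
dead = min_sparsity_penalty([0.0])            # floored at 1e-10
```

Working in log space keeps the penalty at a sensible scale even though the threshold itself is a very small number.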


Before this term can be directly included in the loss function, we must deal with the fact that, in the expression for s given above, the minimum sparsity it can resolve is limited by the batch size, e.g. a batch size of 4096 cannot resolve sparsities below ~0.001. To take into account arbitrarily low sparsity values, we can take the average of the sparsity of each feature over the last n training steps. We can then use this more accurate value of the sparsity in the RELU function, but with the gradients from the original expression for s above.
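The running average can be sketched as follows, using a fixed-length window of per-batch sparsities. The class name and the toy scenario are hypothetical. The gradient handling is not shown: in an autograd framework one would typically take the value from the running average while letting gradients flow only through the current batch estimate (a stop-gradient or straight-through style trick).

```python
from collections import deque


class RunningSparsity:
    """Average of per-batch feature sparsities over the last n training steps."""

    def __init__(self, n_steps):
        # deque with maxlen drops the oldest batch automatically.
        self.history = deque(maxlen=n_steps)

    def update(self, batch_sparsities):
        self.history.append(list(batch_sparsities))
        n = len(self.history)
        n_features = len(batch_sparsities)
        return [
            sum(step[j] for step in self.history) / n
            for j in range(n_features)
        ]


batch_size = 4096
tracker = RunningSparsity(n_steps=100)

# One feature that fires on exactly one token in the first of 100 batches.
for step in range(100):
    fired_tokens = 1 if step == 0 else 0
    avg = tracker.update([fired_tokens / batch_size])

# avg[0] is now 1 / (4096 * 100), far below what a single batch can resolve.
```

A longer window resolves lower sparsities but responds more slowly when a feature's firing rate changes, so n is a trade-off rather than a free improvement.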

In addition to the two terms presented here, I explored a wide range of alternative terms in the loss function. Many of these didn’t work, and some worked reasonably well. Some of these alternatives are discussed below.

Training the SAEs

I trained SAEs on activations of the residual stream of GPT2-small at layer 1 to have a reference point with Joseph Bloom’s models released a few weeks ago. I initially trained a model with a setup as similar as I could make it to the reference model for comparison purposes, e.g. the same learning rate, number of features, batch size and training steps. However, I had to remove the pre-encoder bias as I found the loss function didn’t work very well with it. I checked that simply removing the pre-encoder bias from the original model setup with the L1 + ghost gradients did not generate much improvement.

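I implemented a loss function with three terms: the MSE, the approximate L0 given by the expression above multiplied by a coefficient, and the RELU term that penalises features below the minimum sparsity threshold. I varied the coefficient on the approximate L0 term to vary the sparsity.

I trained 5 SAEs over a range of values of this coefficient. I’ll discuss the properties of these SAEs with reference to their coefficient.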
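A three-term loss of this kind might be assembled as in the sketch below. Everything specific here is an assumption for illustration: the plain mean-squared error, the a / (a + eps) form of the approximate L0, the log-space RELU penalty, and all names and coefficient values. It is not the exact loss used for the models in this post.

```python
import math


def mse(x, x_hat):
    """Mean squared error over all elements (assumed definition)."""
    n = sum(len(row) for row in x)
    return sum(
        (a - b) ** 2 for rx, rh in zip(x, x_hat) for a, b in zip(rx, rh)
    ) / n


def approx_l0(acts, eps):
    """Average over tokens of the summed soft indicators a / (a + eps)."""
    return sum(a / (a + eps) for row in acts for a in row) / len(acts)


def min_sparsity_penalty(sparsities, target, floor=1e-10):
    """Assumed log-space ReLU penalty on features below the target sparsity."""
    log_target = math.log10(target)
    return sum(
        max(0.0, log_target - math.log10(max(s, floor))) for s in sparsities
    )


def total_loss(x, x_hat, acts, sparsities, l0_coeff, min_coeff, eps, target):
    """MSE + l0_coeff * approximate L0 + min_coeff * minimum-sparsity penalty."""
    return (
        mse(x, x_hat)
        + l0_coeff * approx_l0(acts, eps)
        + min_coeff * min_sparsity_penalty(sparsities, target)
    )


# Toy example with hypothetical values.
x = [[1.0, 0.0], [0.0, 1.0]]
x_hat = [[1.0, 0.0], [0.0, 0.0]]
acts = [[1.0, 0.0], [0.0, 0.0]]
running_sparsities = [0.5, 1e-7]  # second feature is below the 1e-5 target

loss = total_loss(
    x, x_hat, acts, running_sparsities,
    l0_coeff=2.0, min_coeff=0.1, eps=1.0, target=1e-5,
)
```

In this toy case the three terms contribute 0.25, 0.5 and 0.2 respectively, which shows how the coefficients set the relative weight of reconstruction, sparsity and dead-feature avoidance.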

The L0, MSE and number of dead features of the 5 SAEs are summarised in the following table, along with the reference model from Joseph Bloom trained with an L1 penalty (JB L1 reference). Three of the new SAEs simultaneously achieve a lower L0 and a lower MSE than the reference L1 model. For instance, the model with an L0 of 13.76 has an L0 that is 6% lower and an MSE that is about 30% lower than the reference L1 model. This seems promising and worth exploring further.

Model [1]          L0       MSE       # Dead Features
JB L1 reference    14.60    1.1e-3    3777
New SAE            19.34    7.0e-4    79
New SAE            16.94    7.4e-4    86
New SAE            13.76    7.8e-4    94
New SAE            10.95    8.7e-4    161
New SAE            9.27     9.3e-4    218
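The comparison with the reference model can be checked directly from the rounded values in the table. The snippet below is only arithmetic on those numbers; the variable and function names are hypothetical.

```python
reference = {"l0": 14.60, "mse": 1.1e-3, "dead": 3777}

# The five new SAEs, in the order they appear in the table.
new_models = [
    {"l0": 19.34, "mse": 7.0e-4, "dead": 79},
    {"l0": 16.94, "mse": 7.4e-4, "dead": 86},
    {"l0": 13.76, "mse": 7.8e-4, "dead": 94},
    {"l0": 10.95, "mse": 8.7e-4, "dead": 161},
    {"l0": 9.27, "mse": 9.3e-4, "dead": 218},
]


def percent_lower(new, ref):
    """How much lower `new` is than `ref`, as a percentage of `ref`."""
    return 100.0 * (ref - new) / ref


# Models that beat the reference on both L0 and MSE.
better_on_both = [
    m for m in new_models
    if m["l0"] < reference["l0"] and m["mse"] < reference["mse"]
]

closest = new_models[2]  # the model with L0 = 13.76
l0_gain = percent_lower(closest["l0"], reference["l0"])     # about 5.8%
mse_gain = percent_lower(closest["mse"], reference["mse"])  # about 29% from rounded values
```

All five new models have a lower MSE and far fewer dead features than the reference; the three with a lower L0 as well are the ones that improve on it in both respects.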


Figure 4 shows the evolution of L0 and the mean-squared error during the training process for these 5 SAEs trained on the above loss function. We can see that they reach a better region of the parameter space in terms of L0 and the mean squared error, as compared to the reference L1 model.

Figure 4: Evolution of L0 and the mean squared error during training for the 5 models trained on the approximate L0 loss function compared to the reference model trained on an L1 penalty from Joseph Bloom.

Feature sparsity distributions

A useful metric to look at when training SAEs is the distribution of feature sparsities. Plotting these distributions can reveal artefacts or inefficiencies in the training process, such as large numbers of features with low sparsity (or dead features), large numbers of high density features, and the shape of the overall distribution of sparsities. Figure 5 shows the feature sparsities for the five new SAE models trained on the loss function described above, compared to the reference L1 model. The distributions of the 5 new models are slightly wider than the reference L1 model. We can also see the significant number of dead features (i.e. at a log sparsity of −10) in the reference L1 model compared to the new models. The light grey vertical line at a log sparsity of −5 indicates the value of min_sparsity_target, the sparsity threshold below which features are penalised in the loss function. We can see that there is a sharp drop-off in features just above and at this threshold. This suggests that the loss function term to discourage the formation of highly sparse features is working as intended.

Figure 5: Distribution of feature sparsities for the 5 models trained on the approximate L0 loss function compared to the reference model trained on an L1 penalty with ghost grads from Joseph Bloom. Light grey vertical line at 1e-5 indicates the value of min_sparsity_target, the sparsity threshold below which features are penalised in the loss function. Dead features are assigned a log sparsity of −10.

Figure 6 shows the same distribution for one of the new models and the L1 reference model on a log-scale. Here we see more significant differences between the feature distributions at higher sparsities. The new model is closer to a power law distribution than the L1 reference model, which contains a bump at around −2. This is reminiscent of Zipf’s law for the frequency of words in natural language. Since we are training on the residual stream before layer 1 of GPT2-small, it would not be surprising if the distribution of features closely reflected the distribution of words in natural language. However, this is just speculation and requires proper investigation. A quick comparison shows the distribution matches a power law with slope around −0.9, although there appears to still be a small bump in the feature sparsity distribution around a log sparsity of −2. This bump may reflect the true feature distribution in GPT2-small, or may be an artefact of the imperfect training process.

Figure 6: Same as Figure 5, but in log scale and just for one of the new models and the reference L1-penalised model (grey). The black line indicates a power law with a slope of −0.9.
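A power-law comparison of this kind can be done with a straight-line fit in log-log space. The sketch below assumes the slope is measured between histogram counts and sparsity, and uses synthetic data generated from an exact power law rather than the real feature sparsities; all names and values are hypothetical.

```python
import math


def fit_power_law_slope(xs, ys):
    """Least-squares line through (log10 x, log10 y); returns (slope, intercept)."""
    lx = [math.log10(x) for x in xs]
    ly = [math.log10(y) for y in ys]
    n = len(lx)
    mean_x = sum(lx) / n
    mean_y = sum(ly) / n
    cov = sum((a - mean_x) * (b - mean_y) for a, b in zip(lx, ly))
    var = sum((a - mean_x) ** 2 for a in lx)
    slope = cov / var
    intercept = mean_y - slope * mean_x
    return slope, intercept


# Synthetic histogram: counts proportional to sparsity^-0.9.
sparsity_bins = [10 ** e for e in (-4.5, -4.0, -3.5, -3.0, -2.5)]
counts = [50.0 * s ** -0.9 for s in sparsity_bins]

slope, intercept = fit_power_law_slope(sparsity_bins, counts)
```

On real data one would restrict the fit to the range where the distribution looks linear in log-log space, since features near the minimum sparsity threshold and the bump around −2 would otherwise bias the slope.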

High density features

One of the new models contains a small number (7) of high density features with sparsities above 0.2 that the reference L1 model does not contain. A quick inspection of the max activating tokens of these features suggests they are reasonably interpretable. Several appeared to be position-based features. For instance, one fired strongly on tokens at positions 1, 2 & 3, and more weakly for later positions. Another fired strongly at position 127 (the final token in each context) and more weakly for earlier positions. One fired on short prepositions such as “on” and “at”. Another fired strongly shortly after new line tokens. In principle, these features can be made more sparse, if desired for interpretability purposes, but it’s not clear whether that’s needed, desired, or what the cost associated with enforcing this would be. Interestingly, the same or very similar features are present in all of the new models across the range of coefficients tested.

Avoiding dead features

Dead features are a significant problem in the training of SAEs. Whatever procedure is used to promote sparsity also runs the risk of generating dead features that can no longer be useful in the SAE. Methods like re-sampling and ghost gradients have been proposed to try to improve this situation.

The third term in the loss function above helps to avoid the production of dead features. As a result, dead features can be greatly inhibited or almost completely eliminated in these new SAEs. The light grey vertical line in Figure 5 indicates the value of min_sparsity_target, the sparsity threshold below which features are penalised in the loss function. Note the sharp drop-off in feature sparsity below min_sparsity_target. Further experimentation with hyperparameters may reduce the number of dead features to ~0, although it’s possible that this comes at some cost to the rest of the model.

The behaviour of the RELU term in the loss function depends somewhat on the learning rate. A lower learning rate tends to nudge features back to the desired sparsity range, shortly after the sparsity drops outside the desired range. A large learning rate can either cause oscillations (for over-dense features) or can cause over-sparse features to be bumped back to high density features, almost as if they are resampled.

Comparison of training curves

Evolution of mean squared error & L0

Figures 7 & 8 show the evolution of the MSE and L0 during the training process. The L0 and MSE of the SAEs trained with the new loss function follow a slightly different evolution to the L1 reference model. In addition, the L0 and MSE are still noticeably declining after training on 80k steps (~300M tokens), whereas the reference L1 model seems to flatten out beyond a given time-step in the training process. This suggests that training on more tokens may improve the SAEs.

Figure 7: Evolution of the mean squared error vs. the number of training steps during the training procedure compared with the reference L1-penalised model
Figure 8: Evolution of the L0 vs. the number of training steps during the training procedure compared with the reference L1-penalised model

Evolution of L1

Figure 9 compares the L1 norms of the new models with the L1 reference model. The fact that the L1 norms of the new models are substantially different to the model with the L1 penalty (and note that W_dec is normalised in all models) is evidence that the SAEs are different. This is obviously not related to which SAE is better, only that they are different.

Figure 9: Evolution of the L1 norm vs. the number of training steps during the training procedure compared with the reference L1-penalised model

Discussion

Advantages of this loss function

  1. In principle, it seems like you can more directly access the trade-off between sparsity and model reconstruction compared to an L1 penalty by optimising for specific components of the sparsity distribution, and avoid requiring that the L1 norm be small.

  2. Dead features can be almost completely avoided by adding the RELU term discussed above. Whether this is ultimately good for the SAE overall needs to be explored further.

  3. It appears to be scalable. The sparsity distribution is something general that applies to all SAEs at all scales on all models.

Shortcomings and other considerations

  1. I did some tests on random features for interpretability, and found them to be interpretable. However, one would need to do a detailed comparison with SAEs trained on an L1 penalty to properly understand whether this loss function impacts interpretability. For what it’s worth, the distribution of feature sparsities suggests that we should expect reasonably interpretable features.

  2. It’s not yet clear to me if the RELU loss term that helps to avoid dead features is actually substantially helping the overall SAE, or simply avoiding dead features. While removing the RELU term from the loss function in the training process results in a much larger MSE, as otherwise many features end up dead, whether this is an appropriate way to avoid dead features is an open question.

  3. It’s not clear what value we should take for epsilon in the approximation of the sparsity used in the loss function, or if we need to start with a larger value to allow the gradients to propagate and then decrease it as the sparsity decreases. I chose a single value for these models, and did some tests with others. A smaller value resulted in a very small improvement to the MSE, but required more tokens to reach this improved model.

  4. The new models produce more high-frequency features (sparsity > 0.2) than the L1 reference model. I’m not sure that this is necessarily a problem and it might depend on the model.

  5. It’s worth making sure that any additional complexity in the model (e.g. more terms in the loss function) comes with sufficient advantages.

  6. Further comparisons with other models and different techniques are needed.

Alternative loss terms based on the sparsity

Given an approximation of the sparsity distribution in the loss function, there are many different terms that one could construct to add to the loss function. Some examples include:

  1. The mean of the sparsity distribution

  2. , where is a list of sparsities with length

  3. . This adds additional encouragement for sparsity for features above sparsity threshold .

I explored these terms and found that they all worked to varying extents. Ultimately, they were not more effective than the function I chose to discuss in detail above. Further investigation will probably uncover better loss function terms, or a similar function, but based on a better approximation of the feature sparsity.

Summary of other architecture and hyperparameter tests

  • Changing the learning rate up or down by a factor of two didn’t result in any improvement.

  • Reducing the value of epsilon in the approximation of the sparsity improves the final model slightly, but requires more tokens to reach the improved value.

  • Setting a negative initial bias for the encoder and scaling the initial weights of W_enc speeds up the generation of sparsity, but seems to result in a slightly worse final model.

  • I found that removing the pre-encoder bias generally helps. Including it provides a better starting point for training, but the end point is not as good. Anthropic have recently reported in their monthly updates that they no longer find that a pre-encoder bias is useful.

  • Normalising W_dec seems to help, even without an L1 penalty. I haven’t looked in detail as to why this is the case, or explored more flexible alternatives.

  • I tried approximating the sparsity with the tanh(x) function and found that, while it worked reasonably well, it was not as effective in terms of the L0 and MSE as the L0 approximation I presented above. However, I did not find that it produced high-density features.

  • I tested the same loss function on layers 2 and 9 of the residual stream of GPT2-small and found similar improvements with respect to reference L1-penalised models.
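For reference, normalising W_dec is usually understood as rescaling each feature's decoder direction to unit L2 norm. The sketch below assumes that W_dec is stored with one row per feature; the function name and toy values are hypothetical.

```python
import math


def normalise_decoder_rows(w_dec):
    """Rescale each feature's decoder vector (one row per feature) to unit L2 norm."""
    normalised = []
    for row in w_dec:
        norm = math.sqrt(sum(v * v for v in row))
        # Leave all-zero rows untouched to avoid dividing by zero.
        normalised.append([v / norm for v in row] if norm > 0 else list(row))
    return normalised


# Toy decoder: 2 features, 2-dimensional activations.
w_dec = [
    [3.0, 4.0],
    [0.0, 2.0],
]
unit_w_dec = normalise_decoder_rows(w_dec)
row_norms = [math.sqrt(sum(v * v for v in row)) for row in unit_w_dec]
```

With an L1 penalty this normalisation stops the SAE from shrinking its activations and compensating with larger decoder weights. With unit-norm decoder directions, activation magnitudes (and therefore L1 norms) are also comparable between models.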

Acknowledgements

I’d like to thank Evan Anders, Philip Quirke, Joseph Bloom and Neel Nanda for helpful discussion and feedback. This work was supported by a grant from Open Philanthropy.

  [1] MSE computed by Joseph’s old definition for comparison purposes.