Efficient Dictionary Learning with Switch Sparse Autoencoders

Produced as part of the ML Alignment & Theory Scholars Program—Summer 2024 Cohort

0. Summary

To recover all the relevant features from a superintelligent language model, we will likely need to scale sparse autoencoders (SAEs) to billions of features. Using current architectures, training extremely wide SAEs across multiple layers and sublayers at various sparsity levels is computationally intractable. Conditional computation has been used to scale transformers (Fedus et al.) to trillions of parameters while retaining computational efficiency. We introduce the Switch SAE, a novel architecture that leverages conditional computation to efficiently scale SAEs to many more features.

1. Introduction

The internal computations of large language models are inscrutable to humans. We can observe the inputs and the outputs, as well as every intermediate step in between, and yet, we have little to no sense of what the model is actually doing. For example, is the model inserting security vulnerabilities or backdoors into the code that it writes? Is the model lying, deceiving or seeking power? Deploying a superintelligent model into the real world without being aware of when these dangerous capabilities may arise leaves humanity vulnerable. Mechanistic interpretability (Olah et al.) aims to open the black-box of neural networks and rigorously explain the underlying computations. Early attempts to identify the behavior of individual neurons were thwarted by polysemanticity, the phenomenon in which a single neuron is activated by several unrelated features (Olah et al.). Language models must pack an extremely vast amount of information (e.g., the entire internet) within a limited capacity, encouraging the model to rely on superposition to represent many more features than there are dimensions in the model state (Elhage et al.).

Sharkey et al. and Cunningham et al. propose to disentangle superimposed model representations into monosemantic, cleanly interpretable features by training unsupervised sparse autoencoders (SAEs) on intermediate language model activations. Recent work (Templeton et al., Gao et al.) has focused on scaling sparse autoencoders to frontier language models such as Claude 3 Sonnet and GPT-4. Despite scaling SAEs to 34 million features, Templeton et al. estimate that they are likely orders of magnitude short of capturing all features. Furthermore, Gao et al. train SAEs on a series of language models and find that larger models require more features to achieve the same reconstruction error. Thus, to capture all relevant features of future large, superintelligent models, we will likely need to scale SAEs to several billions of features. With current methodologies, training SAEs with billions of features at various layers, sublayers and sparsity levels is computationally infeasible.

Training a sparse autoencoder generally consists of six major computations: the encoder forward pass, the encoder gradient, the decoder forward pass, the decoder gradient, the latent gradient and the pre-bias gradient. Gao et al. introduce kernels and tricks that leverage the sparsity of the TopK activation function to dramatically optimize all computations excluding the encoder forward pass, which is not (yet) sparse. After implementing these optimizations, Gao et al. attribute the majority of the compute to the dense encoder forward pass and the majority of the memory to the latent pre-activations. No work has attempted to accelerate or improve the memory efficiency of the encoder forward pass, which remains the sole dense matrix multiplication.
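To make the role of sparsity concrete, here is a minimal PyTorch sketch of a decoder forward pass that only reads the $k$ active decoder columns. This illustrates the idea rather than reproducing Gao et al.'s fused kernels; the function and argument names are ours.

```python
import torch

def sparse_decoder_forward(top_vals, top_idx, W_dec, b_pre):
    """Decoder forward pass that only reads the k active columns of W_dec.

    top_vals: (batch, k) values of the k largest latent pre-activations
    top_idx:  (batch, k) indices of those latents
    W_dec:    (d, M) decoder matrix
    b_pre:    (d,) pre-bias
    """
    active_cols = W_dec.T[top_idx]                                 # (batch, k, d) active feature directions
    return torch.einsum("bk,bkd->bd", top_vals, active_cols) + b_pre
```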

In a standard deep learning model, every parameter is used for every input. An alternative approach is conditional computation, where only a small subset of the parameters are active depending on the input. This allows us to scale model capacity and parameter count without suffering from commensurate increases in computational cost. Shazeer et al. introduce the Sparsely-Gated Mixture-of-Experts (MoE) layer, the first general purpose architecture to realize the potential of conditional computation at huge scales. The Mixture-of-Experts layer consists of (1) a set of expert networks and (2) a routing network that determines which experts should be active on a given input. The entire model is trained end-to-end, simultaneously updating the routing network and the expert networks. The underlying intuition is that each expert network will learn to specialize and perform a specific task, boosting the overall model capacity. Shazeer et al. successfully use MoE to scale LSTMs to 137 billion parameters, surpassing the performance of previous dense models on language modeling and machine translation benchmarks.

Shazeer et al. restrict their attention to settings in which the input is routed to several experts. Fedus et al. introduce the Switch layer, a simplification to the MoE layer which routes to just a single expert. This simplification reduces communication costs and boosts training stability. By replacing the MLP layer of a transformer with a Switch layer, Fedus et al. scale transformers to over a trillion parameters.

In this work, we introduce the Switch Sparse Autoencoder, which combines the Switch layer (Fedus et al.) with the TopK SAE (Gao et al.). The Switch SAE is composed of many smaller expert SAEs as well as a trainable routing network that determines which expert SAE will process a given input. We demonstrate that the Switch SAE is a Pareto improvement over existing architectures while holding training compute fixed. We additionally show that Switch SAEs are significantly more sample-efficient than existing architectures.

2. Methods

2.1 Baseline Sparse Autoencoder

Let $d$ be the dimension of the language model activations. The linear representation hypothesis states that each feature is represented by a unit vector in $\mathbb{R}^d$. Under the superposition hypothesis, there exists a dictionary of $M \gg d$ features $f_1, f_2, \ldots, f_M$ represented as almost orthogonal unit vectors in $\mathbb{R}^d$. A given activation $x \in \mathbb{R}^d$ can be written as a sparse, weighted sum of these feature vectors. Let $a \in \mathbb{R}^M$ be a sparse vector representing how strongly each feature is activated. Then, we have:

$$x \approx \sum_{i=1}^{M} a_i f_i$$

A sparse autoencoder learns to detect the presence and strength of the features given an input activation $x$. SAE architectures generally share three main components: a pre-bias $b_{\mathrm{pre}} \in \mathbb{R}^d$, an encoder matrix $W_{\mathrm{enc}} \in \mathbb{R}^{M \times d}$ and a decoder matrix $W_{\mathrm{dec}} \in \mathbb{R}^{d \times M}$. The TopK SAE defined by Gao et al. takes the following form:

$$z = \mathrm{TopK}\!\left(W_{\mathrm{enc}}(x - b_{\mathrm{pre}})\right), \qquad \hat{x} = W_{\mathrm{dec}} z + b_{\mathrm{pre}}$$

The latent vector $z \in \mathbb{R}^M$ represents how strongly each feature is activated. Since $z$ is sparse, the decoder forward pass can be optimized by a suitable kernel. The bias term $b_{\mathrm{pre}}$ is designed to model $\mathbb{E}[x]$, so that the encoder and decoder act on approximately centered activations. Note that $W_{\mathrm{enc}}$ and $W_{\mathrm{dec}}$ are not necessarily transposes of each other. Row $i$ of the encoder matrix learns to detect feature $f_i$ while simultaneously minimizing interference with the other almost orthogonal features. Column $i$ of the decoder matrix corresponds to $f_i$. Altogether, the SAE consists of $2Md + d$ parameters.
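As a concrete reference, below is a minimal PyTorch sketch of the TopK SAE described above. This is an illustration rather than the exact training code; the class, variable names and initialization details are ours.

```python
import torch
import torch.nn as nn

class TopKSAE(nn.Module):
    """Baseline TopK SAE: z = TopK(W_enc (x - b_pre)), x_hat = W_dec z + b_pre."""

    def __init__(self, d: int, m: int, k: int):
        super().__init__()
        self.k = k
        self.b_pre = nn.Parameter(torch.zeros(d))
        self.W_enc = nn.Parameter(torch.randn(m, d) / d**0.5)   # rows detect features
        self.W_dec = nn.Parameter(self.W_enc.data.T.clone())    # columns are feature directions

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        pre_acts = (x - self.b_pre) @ self.W_enc.T              # (batch, m) latent pre-activations
        top_vals, top_idx = pre_acts.topk(self.k, dim=-1)       # keep only the k largest latents
        z = torch.zeros_like(pre_acts).scatter_(-1, top_idx, top_vals)
        return z @ self.W_dec.T + self.b_pre                    # sparse reconstruction
```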

We additionally benchmark against the ReLU SAE (Conerly et al.) and the Gated SAE (Rajamanoharan et al.). The ReLU SAE applies an L1 penalty to the latent activations to encourage sparsity. The Gated SAE separately determines which features should be active and how strongly activated they should be to avoid activation shrinkage (Wright and Sharkey).

2.2 Switch Sparse Autoencoder Architecture

The Switch Sparse Autoencoder avoids the dense encoder matrix multiplication. Instead of being one large sparse autoencoder, the Switch Sparse Autoencoder is composed of $N$ smaller expert SAEs $\{E_1, E_2, \ldots, E_N\}$. Each expert SAE $E_i$ resembles a TopK SAE with no bias term:

$$E_i(x) = W_{\mathrm{dec}}^{(i)} \, \mathrm{TopK}\!\left(W_{\mathrm{enc}}^{(i)} x\right)$$

Each expert SAE is $N$ times smaller than the original SAE. Specifically, $W_{\mathrm{enc}}^{(i)} \in \mathbb{R}^{\frac{M}{N} \times d}$ and $W_{\mathrm{dec}}^{(i)} \in \mathbb{R}^{d \times \frac{M}{N}}$. Across all $N$ experts, the Switch SAE represents $M$ features.

The Switch layer takes in an input activation $x$ and routes it to the best expert. To determine the best expert, we first subtract a bias $b_{\mathrm{router}} \in \mathbb{R}^d$. Then, we multiply by $W_{\mathrm{router}} \in \mathbb{R}^{N \times d}$, which produces logits that we normalize via a softmax. Let $\sigma(\cdot)$ denote the softmax function. The probability distribution over the experts is given by:

$$p(x) = \sigma\!\left(W_{\mathrm{router}}(x - b_{\mathrm{router}})\right)$$

We route the input to the expert $i^* = \arg\max_i p_i(x)$ with the highest probability and weight the output by that probability to allow gradients to propagate to the router. We subtract a bias $b_{\mathrm{pre}}$ before passing the input to the selected expert and add it back after weighting by the corresponding probability:

$$\hat{x} = p_{i^*}(x) \cdot E_{i^*}(x - b_{\mathrm{pre}}) + b_{\mathrm{pre}}$$

Figure 1: Switch Sparse Autoencoder Architecture. The input activation passes through a router which sends it to the relevant expert SAE.
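The sketch below implements the routing and reconstruction described above for a batch of activations. It gathers per-example expert weights for clarity rather than efficiency, and all names are ours; it is not the implementation used in our experiments.

```python
import torch
import torch.nn as nn

class SwitchSAE(nn.Module):
    """Switch SAE sketch: route each activation to a single expert TopK SAE and
    weight that expert's output by the router probability."""

    def __init__(self, d: int, m: int, n_experts: int, k: int):
        super().__init__()
        assert m % n_experts == 0
        self.k = k
        m_e = m // n_experts                                        # features per expert
        self.b_pre = nn.Parameter(torch.zeros(d))
        self.b_router = nn.Parameter(torch.zeros(d))
        self.W_router = nn.Parameter(torch.randn(n_experts, d) / d**0.5)
        self.W_enc = nn.Parameter(torch.randn(n_experts, m_e, d) / d**0.5)
        self.W_dec = nn.Parameter(self.W_enc.data.transpose(1, 2).clone())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # p(x) = softmax(W_router (x - b_router)); pick the single best expert
        probs = torch.softmax((x - self.b_router) @ self.W_router.T, dim=-1)  # (batch, N)
        best_prob, best_idx = probs.max(dim=-1)

        # Per-example expert weights (gathered for clarity; an efficient implementation
        # would instead group activations by expert)
        W_enc_e = self.W_enc[best_idx]                              # (batch, m_e, d)
        W_dec_e = self.W_dec[best_idx]                              # (batch, d, m_e)
        pre_acts = torch.einsum("bmd,bd->bm", W_enc_e, x - self.b_pre)
        top_vals, top_idx = pre_acts.topk(self.k, dim=-1)
        z = torch.zeros_like(pre_acts).scatter_(-1, top_idx, top_vals)
        recon = torch.einsum("bdm,bm->bd", W_dec_e, z)
        # Weighting by the router probability lets gradients reach the router
        return best_prob.unsqueeze(-1) * recon + self.b_pre
```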

In total, the Switch Sparse Autoencoder contains $2Md + (N + 2)d$ parameters, whereas the TopK SAE has $2Md + d$ parameters. The additional parameters we introduce through the router are an insignificant proportion of the total parameters because $N \ll M$.

During the forward pass of a TopK SAE, $Md$ parameters are used during the encoder forward pass, $kd$ parameters are used during the decoder forward pass (only $k$ latents are active) and $d$ parameters are used for the bias, for a total of $Md + kd + d$ parameters used. Since $k \ll M$, the number of parameters used is dominated by the $Md$ encoder parameters. During the forward pass of a Switch SAE, $Nd$ parameters are used for the router, $\frac{M}{N}d$ parameters are used during the encoder forward pass, $kd$ parameters are used during the decoder forward pass and $2d$ parameters are used for the biases, for a total of $\frac{M}{N}d + kd + Nd + 2d$ parameters used. Since the encoder forward pass takes up the majority of the compute, we effectively reduce the compute by a factor of $N$. This approximation becomes better as we scale $M$, which will be required to capture all the safety-relevant features of future superintelligent language models. Furthermore, the TopK SAE must compute and store $M$ pre-activations. Due to the sparse router, the Switch SAE only needs to compute and store $\frac{M}{N}$ pre-activations, improving memory efficiency by a factor of $N$ as well.
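The arithmetic above can be checked with a few lines of code. The concrete values of $d$, $M$, $k$ and $N$ below are illustrative only, not necessarily the exact configurations trained in this work.

```python
def params_used_topk(d: int, M: int, k: int) -> int:
    """Parameters touched per activation by a dense TopK SAE: Md + kd + d."""
    return M * d + k * d + d

def params_used_switch(d: int, M: int, k: int, N: int) -> int:
    """Parameters touched per activation by a Switch SAE: Nd + (M/N)d + kd + 2d."""
    return N * d + (M // N) * d + k * d + 2 * d

# Illustrative values only
d, M, k, N = 768, 2**15, 32, 16
dense, switch = params_used_topk(d, M, k), params_used_switch(d, M, k, N)
print(dense, switch, dense / switch)  # 25191168 1611264 ~15.6, approaching N = 16
```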

2.3 Switch Sparse Autoencoder Training

We train the Switch Sparse Autoencoder end-to-end. Weighting the selected expert's output by $p_{i^*}(x)$ in the calculation of $\hat{x}$ allows gradients to flow to the router. We adopt many of the training strategies described in Bricken et al. and Gao et al., with a few exceptions. We initialize the rows (features) of $W_{\mathrm{enc}}^{(i)}$ to be parallel to the columns (features) of $W_{\mathrm{dec}}^{(i)}$ for all $i$. We initialize both $b_{\mathrm{pre}}$ and $b_{\mathrm{router}}$ to the geometric median of a batch of samples (but we do not tie $b_{\mathrm{pre}}$ and $b_{\mathrm{router}}$). We additionally normalize the decoder column vectors to unit norm at initialization and after each gradient step. We remove gradient information parallel to the decoder feature directions. We set the learning rate based on the scaling law from Gao et al. and linearly decay it over the last 20% of training. We do not include neuron resampling (Bricken et al.), ghost grads (Jermyn et al.) or the AuxK loss (Gao et al.).
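Two of these details (unit-norm decoder columns and removing the gradient component parallel to each decoder direction) are sketched below. This is a simplified illustration, not the exact implementation in the dictionary_learning repository; the function names are ours.

```python
import torch

@torch.no_grad()
def normalize_decoder_(W_dec: torch.Tensor) -> None:
    """Renormalize every decoder column (feature direction) to unit norm.
    Called at initialization and after each gradient step.
    W_dec has shape (d, M), or (N, d, M/N) for a Switch SAE (the same op applies per expert)."""
    W_dec /= W_dec.norm(dim=-2, keepdim=True)

@torch.no_grad()
def remove_parallel_grad_(W_dec: torch.Tensor) -> None:
    """Remove the component of the gradient parallel to each (unit-norm) decoder column,
    so optimizer updates leave the column norms unchanged to first order."""
    if W_dec.grad is None:
        return
    parallel = (W_dec.grad * W_dec).sum(dim=-2, keepdim=True) * W_dec
    W_dec.grad -= parallel
```

In a training loop, remove_parallel_grad_ would run after the backward pass and before the optimizer step, with normalize_decoder_ applied immediately after the step.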

The ReLU SAE loss consists of a weighted combination of the reconstruction MSE and an L1 penalty on the latents to encourage sparsity. The TopK SAE directly enforces sparsity via its activation function and thus directly optimizes the reconstruction MSE. Following Fedus et al., we train our Switch SAEs using a weighted combination of the reconstruction MSE and an auxiliary loss that encourages the router to send an equal number of activations to each expert, to reduce overhead. Empirically, we also find that the auxiliary loss improves reconstruction fidelity.

For a batch $\mathcal{B}$ of $T$ activations, we first compute vectors $f \in \mathbb{R}^N$ and $P \in \mathbb{R}^N$. Here $f_i$ represents the proportion of activations that are routed to expert $i$, while $P_i$ represents the proportion of router probability that is assigned to expert $i$. Formally,

$$f_i = \frac{1}{T} \sum_{x \in \mathcal{B}} \mathbb{1}\!\left[\,i = \arg\max_j p_j(x)\right], \qquad P_i = \frac{1}{T} \sum_{x \in \mathcal{B}} p_i(x)$$

The auxiliary loss is then defined to be:

$$\mathcal{L}_{\mathrm{aux}} = N \sum_{i=1}^{N} f_i \cdot P_i$$

The auxiliary loss achieves its minimum when the expert distribution is uniform. We scale by $N$ so that $\mathcal{L}_{\mathrm{aux}} = 1$ for a uniformly random router. The inclusion of $P$ makes the loss differentiable, since $f$ itself is not.

The reconstruction loss is defined to be:

$$\mathcal{L}_{\mathrm{recon}} = \lVert x - \hat{x} \rVert_2^2$$

Note that $\mathcal{L}_{\mathrm{recon}}$ scales with $d$, while $\mathcal{L}_{\mathrm{aux}}$ does not. Let $\alpha$ represent a tunable load-balancing hyperparameter. The total loss is then defined to be:

$$\mathcal{L}_{\mathrm{total}} = \mathcal{L}_{\mathrm{recon}} + \alpha \cdot d \cdot \mathcal{L}_{\mathrm{aux}}$$

We optimize $\mathcal{L}_{\mathrm{total}}$ using Adam ($\beta_1 = 0.9$, $\beta_2 = 0.999$).
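Putting the pieces together, a minimal sketch of the loss computation described in this subsection is given below. It assumes argmax routing and the $\alpha \cdot d$ scaling of the auxiliary term defined above; the function name is ours.

```python
import torch

def switch_sae_loss(x, x_hat, router_probs, alpha: float):
    """Reconstruction MSE plus the Switch-style load-balancing auxiliary loss.

    x, x_hat:     (T, d) input activations and SAE reconstructions
    router_probs: (T, N) softmax output of the router, p(x)
    alpha:        tunable load-balancing coefficient
    """
    T, N = router_probs.shape
    d = x.shape[-1]

    # f_i: fraction of activations routed (by argmax) to expert i -- not differentiable
    chosen = router_probs.argmax(dim=-1)
    f = torch.bincount(chosen, minlength=N).float() / T
    # P_i: mean router probability assigned to expert i -- differentiable
    P = router_probs.mean(dim=0)

    aux_loss = N * (f * P).sum()                        # equals 1 for a uniformly random router
    recon_loss = (x - x_hat).pow(2).sum(dim=-1).mean()  # ||x - x_hat||^2, averaged over the batch
    return recon_loss + alpha * d * aux_loss            # aux term scaled by d, as in the total loss above
```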

3. Results

We train SAEs on the residual stream activations of GPT-2 small ($d = 768$). In this work, we follow Gao et al. and focus on layer 8. Using text data from OpenWebText, we train for 100K steps with a batch size of 8192, for a total of ~820M tokens. We benchmark the Switch SAE against the ReLU SAE (Conerly et al.), the Gated SAE (Rajamanoharan et al.) and the TopK SAE (Gao et al.). We present results for two settings.

  1. Fixed Width: Each SAE is trained with the same total number of features $M$. We train Switch SAEs with 16, 32, 64 and 128 experts. Each expert of the Switch SAE with $N$ experts has $M/N$ features. The Switch SAE performs roughly $N$ times fewer FLOPs per activation compared to the TopK SAE.

  2. FLOP-Matched: The ReLU, Gated and TopK SAEs are trained with $M$ features. We train Switch SAEs with 2, 4 and 8 experts. Each expert of the Switch SAE with $N$ experts has $M$ features, for a total of $NM$ features. The Switch SAE performs roughly the same number of FLOPs per activation as the TopK SAE.

For a wide range of sparsity (L0) values, we report the reconstruction MSE and the proportion of cross-entropy loss recovered when the sparse autoencoder output is patched into the language model. A loss recovered value of 1 corresponds to a perfect reconstruction, while a loss recovered value of 0 corresponds to a zero-ablation.
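For reference, the loss recovered metric can be computed as follows; the helper name is ours.

```python
def loss_recovered(clean_ce: float, patched_ce: float, zero_abl_ce: float) -> float:
    """Fraction of cross-entropy loss recovered when the SAE output is patched in.

    clean_ce:    LM loss with the original activations
    patched_ce:  LM loss with the SAE reconstruction spliced into the residual stream
    zero_abl_ce: LM loss with the activations zero-ablated
    Returns 1.0 for a perfect reconstruction and 0.0 for zero-ablation.
    """
    return (zero_abl_ce - patched_ce) / (zero_abl_ce - clean_ce)
```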

3.1 Fixed Width Results

We train Switch SAEs with 16, 32, 64 and 128 experts (Figures 2, 3). The Switch SAEs consistently underperform the TopK SAE in terms of MSE and loss recovered. The Switch SAE with 16 experts is a Pareto improvement over the Gated SAE in terms of both MSE and loss recovered, despite performing roughly 16x fewer FLOPs per activation. The Switch SAE with 32 experts is a Pareto improvement over the Gated SAE in terms of loss recovered. The Switch SAE with 64 experts is a Pareto improvement over the ReLU SAE in terms of both MSE and loss recovered. The Switch SAE with 128 experts is a Pareto improvement over the ReLU SAE in terms of loss recovered, and in terms of MSE at all but the highest L0 value we test. That setting is an extreme case for the 128-expert Switch SAE: each expert has only $M/128$ features, comparable to the number of active latents, so the TopK activation is effectively irrelevant. When L0 is low, Switch SAEs perform particularly well. This suggests that the features that improve reconstruction fidelity the most for a given activation lie within the same cluster.

Figure 2: L0 vs. MSE for fixed width SAEs. The 16 expert Switch SAE outperforms the Gated SAE. The 32 and 64 expert Switch SAEs outperform the ReLU SAE. The 128 expert Switch SAE outperforms the ReLU SAE excluding the extreme setting.
Figure 3: L0 vs. Loss Recovered for fixed width SAEs. The 16 and 32 expert Switch SAEs outperform the Gated SAE. The 64 and 128 expert Switch SAEs outperform the ReLU SAE.

These results demonstrate that Switch SAEs can reduce the number of FLOPs per activation by up to 128x while still retaining the performance of a ReLU SAE. Switch SAEs can likely achieve greater acceleration on larger language models.

3.2 FLOP-Matched Results

We train Switch SAEs with 2, 4 and 8 experts (Figures 4, 5, 6). The Switch SAEs are a Pareto improvement over the TopK, Gated and ReLU SAEs in terms of both MSE and loss recovered. As we scale up the number of experts and represent more features, performance continues to improve while computational cost and memory cost (from storing the pre-activations) stay roughly constant.

Figure 4: L0 vs. MSE for FLOP-matched SAEs. The Switch SAEs consistently outperform the TopK, Gated and ReLU SAEs. Performance improves with a greater number of experts.
Figure 5: L0 vs. Loss Recovered for FLOP-matched SAEs. The Switch SAEs consistently outperform the TopK, Gated and ReLU SAEs. Performance improves with a greater number of experts.

Fedus et al. find that their sparsely-activated Switch Transformer is significantly more sample-efficient than FLOP-matched, dense transformer variants. We similarly find that our Switch SAEs are 5x more sample-efficient than the FLOP-matched TopK SAE baseline: they achieve the reconstruction MSE of a TopK SAE trained for 100K steps in fewer than 20K steps. This result is consistent across the 2, 4 and 8 expert Switch SAEs.

Figure 6: Sample efficiency of Switch SAEs compared to the TopK SAE. Switch SAEs achieve the same MSE as the TopK SAE in 5x fewer training steps.

Switch SAEs speed up training while capturing more features and keeping the number of FLOPs per activation fixed. Kaplan et al. similarly find that larger models are more sample efficient.

4. Conclusion

The diverse capabilities (e.g., trigonometry, 1960s history, TV show trivia) of frontier models suggest the presence of a huge number of features. Templeton et al. and Gao et al. make massive strides by successfully scaling sparse autoencoders to millions of features. Unfortunately, millions of features are not sufficient to capture all the relevant features of frontier models. Templeton et al. estimate that Claude 3 Sonnet may have billions of features, and Gao et al. empirically predict that future larger models will require more features to achieve the same reconstruction fidelity. If we are unable to train sufficiently wide SAEs, we may miss safety-crucial features such as those related to security vulnerabilities, deception and CBRN. Thus, further research must be done to improve the efficiency and scalability of SAE training. To monitor future superintelligent language models, we will likely need to perform SAE inference during the forward pass of the language model to detect safety-relevant features. Large-scale labs may be unwilling to perform this extra computation unless it is both computationally and memory efficient and does not dramatically slow down model inference. It is therefore crucial that we additionally improve the inference time of SAEs.

Thus far, the field has been bottlenecked by the encoder forward pass, the sole dense matrix multiplication involved in SAE training and inference. This work presents the first attempt to overcome the encoder forward pass bottleneck. Taking inspiration from Shazeer et al. and Fedus et al., we introduce the Switch Sparse Autoencoder, which replaces the standard large SAE with many smaller expert SAEs. The Switch Sparse Autoencoder leverages a trainable router that determines which expert is used, allowing us to scale the number of features without increasing the computational cost. When keeping the width of the SAE fixed, we find that we can reduce the number of FLOPs per activation by up to 128x while still maintaining a Pareto improvement over the ReLU SAE. When fixing the number of FLOPs per activation, we find that Switch SAEs train 5x faster and are a Pareto improvement over TopK, Gated and ReLU SAEs.

Future Work

This work is the first to combine Mixture-of-Experts with Sparse Autoencoders to improve the efficiency of dictionary learning. There are many potential avenues to expand upon this work.

  • We restrict our attention to combining the Switch layer (Fedus et al.) with the TopK SAE (Gao et al.). It is possible that combining the Switch layer with the ReLU SAE or the Gated SAE may have superior qualities.

  • We require that every expert within a Switch SAE is homogeneous in terms of the number of features and the sparsity level. Future work could relax this constraint to allow for non-uniform feature cluster sizes and adaptive sparsity.

  • Switch SAEs trained on larger language models may begin to suffer from dead latents. Future work could include a modified AuxK loss to prevent this.

  • We restrict our attention to a single router. Future work could explore the possibility of further scaling the number of experts with hierarchical routers. Doing so may provide additional insight into feature splitting and geometry.

  • Following Fedus et al., we route to a single expert SAE. It is possible that selecting several experts will improve performance. The computational cost will scale with the number of experts chosen.

  • The routing network resembles the encoder of a sparse autoencoder. How do the feature directions of the routing network relate to the features of the corresponding expert SAEs?

  • In this work, we train Switch SAEs on the residual stream, but future work could train Switch SAEs on the MLPs and attention heads.

Acknowledgements

This work was supervised by Christian Schroeder de Witt and Josh Engels. I used the dictionary learning repository to train my SAEs. I would like to thank Samuel Marks and Can Rager for advice on how to use the repository. I would also like to thank Jacob Goldman-Wetzler, Achyuta Rajaram, Michael Pearce, Gitanjali Rao, Satvik Golechha, Kola Ayonrinde, Rupali Bhati, Louis Jaburi, Vedang Lad, Adam Karvonen, Shiva Mudide, Sandy Tanwisuth, JP Rivera and Juan Gil for helpful discussions.