ProLU: A Nonlinearity for Sparse Autoencoders
Abstract
This paper presents ProLU, an alternative to ReLU for the activation function in sparse autoencoders that produces a Pareto improvement over both standard sparse autoencoders trained with an L1 penalty and sparse autoencoders trained with a Sqrt(L1) penalty.
$$\mathrm{ProLU}(m_i, b_i) = \begin{cases} m_i & \text{if } m_i + b_i > 0 \text{ and } m_i > 0 \\ 0 & \text{otherwise} \end{cases}$$
$$\mathrm{SAE}_{\mathrm{ProLU}}(x) = \mathrm{ProLU}((x - b_{dec})W_{enc},\ b_{enc})\,W_{dec} + b_{dec}$$
The gradient wrt. $b$ is zero, so we generate two candidate classes of differentiable ProLU:
$\mathrm{ProLU}_{\mathrm{ReLU}}$:
$$\frac{\partial^* \mathrm{ProLU}_{\mathrm{ReLU}}(m_i, b_i)}{\partial b_i} = \frac{\partial\, \mathrm{ProLU}_{\mathrm{ReLU}}(m_i, b_i)}{\partial m_i} = \begin{cases} 1 & \text{if } m_i + b_i > 0 \text{ and } m_i > 0 \\ 0 & \text{otherwise} \end{cases}$$
$\mathrm{ProLU}_{\mathrm{STE}}$:
$$\frac{\partial^* \mathrm{ProLU}_{\mathrm{STE}}(m_i, b_i)}{\partial m_i} = \begin{cases} 1 + m_i & \text{if } m_i > 0 \text{ and } m_i + b_i > 0 \\ 0 & \text{otherwise} \end{cases}$$
$$\frac{\partial^* \mathrm{ProLU}_{\mathrm{STE}}(m_i, b_i)}{\partial b_i} = \begin{cases} m_i & \text{if } m_i > 0 \text{ and } m_i + b_i > 0 \\ 0 & \text{otherwise} \end{cases}$$
PyTorch Implementation
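The original code is not reproduced here, so the following is a minimal sketch of how the two variants could be written with torch.autograd.Function, directly transcribing the synthetic gradients above. The shapes and the broadcast handling ($m$ of shape [batch, d_hidden], $b$ of shape [d_hidden]) and the class names are my assumptions, not the post's actual code.

```python
import torch
from torch import Tensor


class ProLU_ReLU(torch.autograd.Function):
    """ProLU with ReLU-style synthetic gradients: d*/db := d/dm."""

    @staticmethod
    def forward(ctx, m: Tensor, b: Tensor) -> Tensor:
        gate = ((m + b) > 0) & (m > 0)  # active where both conditions hold
        ctx.save_for_backward(gate)
        return torch.where(gate, m, torch.zeros_like(m))

    @staticmethod
    def backward(ctx, grad_out: Tensor):
        (gate,) = ctx.saved_tensors
        grad = grad_out * gate  # 1 on the active set, 0 elsewhere
        # b is broadcast over the batch dimension, so sum its gradient out
        return grad, grad.sum(dim=0)


class ProLU_STE(torch.autograd.Function):
    """ProLU with straight-through-estimator synthetic gradients."""

    @staticmethod
    def forward(ctx, m: Tensor, b: Tensor) -> Tensor:
        gate = ((m + b) > 0) & (m > 0)
        ctx.save_for_backward(m, gate)
        return torch.where(gate, m, torch.zeros_like(m))

    @staticmethod
    def backward(ctx, grad_out: Tensor):
        m, gate = ctx.saved_tensors
        grad_m = grad_out * (1 + m) * gate  # (1 + m_i) on the active set
        grad_b = grad_out * m * gate        # m_i on the active set
        return grad_m, grad_b.sum(dim=0)
```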
Introduction
SAE Context and Terminology
Learnable parameters of a sparse autoencoder:
$W_{enc}$: encoder weights
$W_{dec}$: decoder weights
$b_{enc}$: encoder bias
$b_{dec}$: decoder bias
The output of an SAE is given by
$$\mathrm{SAE}(x) = \mathrm{ReLU}((x - b_{dec})W_{enc} + b_{enc})\,W_{dec} + b_{dec}$$
We write this in terms of an encoder and a decoder,
$$\mathrm{encode}(x) = \mathrm{ReLU}((x - b_{dec})W_{enc} + b_{enc})$$
$$\mathrm{decode}(a) = aW_{dec} + b_{dec}$$
so that the full computation done by an SAE can be expressed as
$$\mathrm{SAE}(x) = \mathrm{decode}(\mathrm{encode}(x))$$
Training
An SAE is trained with gradient descent on
$$\mathcal{L}_{train} = \|x - \mathrm{SAE}(x)\|_2^2 + \lambda P(\mathrm{encode}(x))$$
where $\lambda$ is the sparsity penalty coefficient (often "L1 coefficient") and $P$ is the sparsity penalty function, used to encourage sparsity.
$P$ is commonly the L1 norm $\|a\|_1$, but recently $P(a) = \|a\|_{1/2}^{1/2}$ has been shown to produce a Pareto improvement on the L0 and CE metrics. We will use this as a further baseline, in addition to the standard ReLU-based SAE with L1 penalty, when assessing our models.
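As a concrete reading of this objective, here is a short sketch; the `sae.encode`/`sae.decode` methods and argument names are illustrative assumptions, not the post's actual code. Note that $\|a\|_{1/2}^{1/2} = \sum_i |a_i|^{1/2}$.

```python
import torch

def sae_train_loss(sae, x, sparsity_coeff, sqrt_l1=False):
    acts = sae.encode(x)        # a = encode(x)
    recon = sae.decode(acts)    # SAE(x)
    mse = (x - recon).pow(2).sum(dim=-1).mean()  # ||x - SAE(x)||_2^2
    if sqrt_l1:
        # Sqrt(L1) penalty: P(a) = ||a||_{1/2}^{1/2} = sum_i |a_i|^{1/2}
        penalty = acts.abs().sqrt().sum(dim=-1).mean()
    else:
        # standard L1 penalty: P(a) = ||a||_1
        penalty = acts.abs().sum(dim=-1).mean()
    return mse + sparsity_coeff * penalty
```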
Motivation: Inconsistent Scaling in Sparse Autoencoders
Due to the affine translation, sparse autoencoder features with nonzero encoder biases only perfectly reconstruct feature magnitudes at a single point.
This poses difficulties if activation magnitudes for a fixed feature tend to vary over a wide range. This potential problem motivates the concept of scale consistency:
[Figure: a scale-consistent response curve]
The bias maintains its role in noise suppression, but no longer translates activation magnitudes when the feature is active.
The lack of gradients for the encoder bias term poses a challenge for learning with gradient descent. This paper formalizes an activation function that gives SAEs this scale-consistent response curve, motivates and proposes two plausible synthetic gradients for it, and compares scale-consistent models trained with each synthetic gradient to standard SAEs and to SAEs trained with the Sqrt(L1) penalty.
Scale Consistency Desiderata
Notation: Centered Submodule
The use of the decoder bias can be viewed as performing centering on the inputs to a centered SAE, then reversing the centering on the outputs:
$$\mathrm{SAE}(x) = \mathrm{SAE}^{cent}(x - b_{dec}) + b_{dec}, \qquad \mathrm{SAE}^{cent}(v) = \mathrm{ReLU}(vW_{enc} + b_{enc})\,W_{dec}$$
Below, $\mathrm{SAE}^{cent}_i$ denotes the response of the $i$-th feature of the centered SAE.
Conditional Linearity
1. $\mathrm{SAE}^{cent}_i(v_1) > 0 \wedge \mathrm{SAE}^{cent}_i(v_2) > 0 \implies \mathrm{SAE}^{cent}_i(v_1) + \mathrm{SAE}^{cent}_i(v_2) = \mathrm{SAE}^{cent}_i(v_1 + v_2)$
2. $\forall v.\ \mathrm{SAE}^{cent}_i(v) > 0 \wedge k > 1 \implies \mathrm{SAE}^{cent}_i(kv) = k \cdot \mathrm{SAE}^{cent}_i(v)$
Noise Suppression Threshold
3. $b_{enc} < 0 \implies \exists \eta \in (0, \infty)\ \forall \epsilon \in (0, \infty) \text{ s.t. } \mathrm{SAE}^{cent}_i(\eta \cdot v) = 0 \wedge \mathrm{SAE}^{cent}_i((\eta + \epsilon) \cdot v) > 0$
Proportional ReLU (ProLU)
We define the Proportional ReLU (ProLU) as:
$$\mathrm{ProLU}(m_i, b_i) = \begin{cases} m_i & \text{if } m_i + b_i > 0 \text{ and } m_i > 0 \\ 0 & \text{otherwise} \end{cases}$$
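For a concrete illustration (my numbers): with $b_i = -3$, a standard ReLU unit computes $\mathrm{ReLU}(m_i + b_i)$, mapping $m_i = 4$ to $1$ and $m_i = 8$ to $5$, so doubling the input quintuples the output. ProLU maps $m_i = 4$ to $4$ and $m_i = 8$ to $8$, preserving proportionality, while still zeroing every input with $m_i \le 3$.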
Backprop with ProLU
To use ProLU in SGD-optimized models, we first address the lack of gradients wrt. the $b$ term.
ReLU gradients:
For comparison and later use, we will first consider ReLU: partial derivatives are well defined for ReLU at all points other than $x_i = 0$:
$$\frac{\partial\, \mathrm{ReLU}(x_i)}{\partial x_i} = \begin{cases} 1 & \text{if } x_i > 0 \\ 0 & \text{if } x_i < 0 \end{cases}$$
Gradients of ProLU:
Partials of ProLU wrt. $m$ are similarly well defined:
$$\frac{\partial\, \mathrm{ProLU}(m_i, b_i)}{\partial m_i} = \begin{cases} 1 & \text{if } m_i + b_i > 0 \text{ and } m_i > 0 \\ 0 & \text{otherwise} \end{cases}$$
However, they are not well defined wrt. $b$, so we must synthesize these.
Methods
Notation: Synthetic Gradients
Let $\frac{\partial^* f}{\partial x}$ denote the synthetic partial derivative of $f$ wrt. $x$, and $\nabla^* f$ the synthetic gradient of $f$, used for backpropagation as a stand-in for the gradient.
Different synthetic gradient types
We train two classes of ProLU with different synthetic gradients. These are distinguished by their subscript:
$\mathrm{ProLU}_{\mathrm{ReLU}}$
$\mathrm{ProLU}_{\mathrm{STE}}$
They are identical in output but have different synthetic gradients:
$$\mathrm{ProLU}_{\mathrm{ReLU}}(m, b) = \mathrm{ProLU}_{\mathrm{STE}}(m, b)$$
$$\nabla^* \mathrm{ProLU}_{\mathrm{ReLU}}(m, b) \not\equiv \nabla^* \mathrm{ProLU}_{\mathrm{STE}}(m, b)$$
Defining ProLU_ReLU: ReLU-like gradients
The first synthetic gradient is very similar to the gradient for ReLU. We retain the gradient wrt. $m$, and define the synthetic gradient wrt. $b$ to be the same as the gradient wrt. $m$:
$$\frac{\partial^* \mathrm{ProLU}_{\mathrm{ReLU}}(m_i, b_i)}{\partial b_i} = \frac{\partial\, \mathrm{ProLU}_{\mathrm{ReLU}}(m_i, b_i)}{\partial m_i} = \begin{cases} 1 & \text{if } m_i + b_i > 0 \text{ and } m_i > 0 \\ 0 & \text{otherwise} \end{cases}$$
Defining ProLU_STE: Derivation from straight-through estimator
The second class of ProLU uses synthetic gradients for both $b$ and $m$, and can be motivated by framing ProLU and ReLU in terms of the threshold function, together with a common choice of straight-through estimator (STE) for the threshold function. This is a plausible explanation for the observed empirical performance, but it should be noted that there are many degrees of freedom and possible alternatives in this derivation.
Setup
The threshold function $\mathrm{Thresh}$ is defined as follows:
$$\mathrm{Thresh}(x) = \begin{cases} 1 & \text{if } x > 0 \\ 0 & \text{otherwise} \end{cases}$$
We will rephrase the partial derivative of ReLU in terms of the threshold function for ease of later notation:
$$\frac{\partial\, \mathrm{ReLU}(x_i)}{\partial x_i} = \begin{cases} 1 & \text{if } x_i > 0 \\ 0 & \text{if } x_i < 0 \end{cases} = \mathrm{Thresh}(x_i)$$
It is common to use a straight-through estimator (STE) to approximate the gradient of the threshold function:
$$\frac{\partial^* \mathrm{Thresh}(x_i)}{\partial x_i} = \mathrm{ST}_{\mathrm{Thresh}}(x_i)$$
We can reframe ProLU in terms of the threshold function:
$$\mathrm{ProLU}(m_i, b_i) = \mathrm{ReLU}(m_i) \cdot \mathrm{Thresh}(m_i + b_i)$$
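A quick numerical check of this identity (a throwaway snippet of mine, not from the post):

```python
import torch

m, b = torch.randn(10_000), torch.randn(10_000)
# ProLU(m, b) from the piecewise definition
prolu = torch.where(((m + b) > 0) & (m > 0), m, torch.zeros_like(m))
# ReLU(m) * Thresh(m + b)
relu_thresh = torch.relu(m) * ((m + b) > 0).float()
assert torch.equal(prolu, relu_thresh)
```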
Synthetic Gradients wrt. m
Now, we take partial derivatives of ProLU wrt. $m$ using the STE approximation for the threshold function:
$$\begin{aligned}
\frac{\partial^* \mathrm{ProLU}_{\mathrm{STE}}(m_i, b_i)}{\partial m_i}
&= \frac{\partial^*}{\partial m_i}\left(\mathrm{ReLU}(m_i) \cdot \mathrm{Thresh}(m_i + b_i)\right) \\
&= \frac{\partial\, \mathrm{ReLU}(m_i)}{\partial m_i} \cdot \mathrm{Thresh}(m_i + b_i) + \mathrm{ReLU}(m_i) \cdot \frac{\partial^* \mathrm{Thresh}(m_i + b_i)}{\partial m_i} \\
&= \mathrm{Thresh}(m_i) \cdot \mathrm{Thresh}(m_i + b_i) + \mathrm{ReLU}(m_i) \cdot \mathrm{ST}_{\mathrm{Thresh}}(m_i + b_i) \\
&= \mathrm{Thresh}(m_i) \cdot \mathrm{Thresh}(m_i + b_i) + m_i\, \mathrm{Thresh}(m_i) \cdot \mathrm{ST}_{\mathrm{Thresh}}(m_i + b_i)
\end{aligned}$$
Synthetic Gradients wrt. b
$$\begin{aligned}
\frac{\partial^* \mathrm{ProLU}_{\mathrm{STE}}(m_i, b_i)}{\partial b_i}
&= \frac{\partial^*}{\partial b_i}\left(\mathrm{ReLU}(m_i) \cdot \mathrm{Thresh}(m_i + b_i)\right) \\
&= \frac{\partial\, \mathrm{ReLU}(m_i)}{\partial b_i} \cdot \mathrm{Thresh}(m_i + b_i) + \mathrm{ReLU}(m_i) \cdot \frac{\partial^* \mathrm{Thresh}(m_i + b_i)}{\partial b_i} \\
&= 0 \cdot \mathrm{Thresh}(m_i + b_i) + \mathrm{ReLU}(m_i) \cdot \mathrm{ST}_{\mathrm{Thresh}}(m_i + b_i) \\
&= m_i\, \mathrm{Thresh}(m_i) \cdot \mathrm{ST}_{\mathrm{Thresh}}(m_i + b_i)
\end{aligned}$$
Choice of Straight-Through Estimator
There are many possible functions to use for $\mathrm{ST}_{\mathrm{Thresh}}(x)$. In our experiments, we take the derivative of ReLU as the choice of straight-through estimator; this choice has been used in training quantized neural nets:
$$\mathrm{ST}_{\mathrm{Thresh}}(x) := \mathrm{Thresh}(x)$$
Then, synthetic gradients wrt. $m$ are given by
$$\begin{aligned}
\frac{\partial^* \mathrm{ProLU}_{\mathrm{STE}}(m_i, b_i)}{\partial m_i}
&= \mathrm{Thresh}(m_i) \cdot \mathrm{Thresh}(m_i + b_i) + m_i\, \mathrm{Thresh}(m_i) \cdot \mathrm{Thresh}(m_i + b_i) \\
&= (1 + m_i) \cdot \mathrm{Thresh}(m_i) \cdot \mathrm{Thresh}(m_i + b_i) \\
&= \begin{cases} 1 + m_i & \text{if } m_i > 0 \text{ and } m_i + b_i > 0 \\ 0 & \text{otherwise} \end{cases}
\end{aligned}$$
and wrt. $b$ by
$$\frac{\partial^* \mathrm{ProLU}_{\mathrm{STE}}(m_i, b_i)}{\partial b_i} = m_i\, \mathrm{Thresh}(m_i) \cdot \mathrm{Thresh}(m_i + b_i) = \begin{cases} m_i & \text{if } m_i > 0 \text{ and } m_i + b_i > 0 \\ 0 & \text{otherwise} \end{cases}$$
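As a sanity check on the ProLU_STE sketch from the implementation section above, its custom backward reproduces these case expressions (illustrative values of mine):

```python
import torch

m = torch.tensor([[2.0, -1.0, 0.5]], requires_grad=True)
b = torch.tensor([-1.5, 0.5, 1.0], requires_grad=True)

out = ProLU_STE.apply(m, b)  # active set: columns 0 and 2
out.sum().backward()

print(m.grad)  # (1 + m_i) on the active set -> [[3.0, 0.0, 1.5]]
print(b.grad)  # m_i on the active set       -> [2.0, 0.0, 0.5]
```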
ProLU Sparse Autoencoder
We can express the encoder of a ProLU SAE as
$$\mathrm{encode}_{\mathrm{ProLU}}(x) = \mathrm{ProLU}((x - b_{dec})W_{enc},\ b_{enc})$$
No change is needed to the decoder. Thus,
$$\mathrm{SAE}_{\mathrm{ProLU}}(x) = \mathrm{decode}(\mathrm{encode}_{\mathrm{ProLU}}(x))$$
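Putting the pieces together, here is a sketch of the full module, reusing the autograd Functions from the implementation section; parameter shapes and initialization are illustrative assumptions, not the post's code.

```python
import torch
import torch.nn as nn

class ProLUSAE(nn.Module):
    def __init__(self, d_model: int, d_hidden: int, ste: bool = True):
        super().__init__()
        self.W_enc = nn.Parameter(torch.randn(d_model, d_hidden) * 0.01)
        self.W_dec = nn.Parameter(torch.randn(d_hidden, d_model) * 0.01)
        self.b_enc = nn.Parameter(torch.zeros(d_hidden))
        self.b_dec = nn.Parameter(torch.zeros(d_model))
        self.prolu = ProLU_STE if ste else ProLU_ReLU

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        # encode_ProLU(x) = ProLU((x - b_dec) W_enc, b_enc):
        # the encoder bias acts only as a gate, never shifting magnitudes
        return self.prolu.apply((x - self.b_dec) @ self.W_enc, self.b_enc)

    def decode(self, acts: torch.Tensor) -> torch.Tensor:
        return acts @ self.W_dec + self.b_dec

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.decode(self.encode(x))
```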
Experiment Setup
Shared among all sweeps:
Adam optimizer, with:
$\beta_1 = 0.9$, $\beta_2 = 0.999$
batch size $= 4096$
Data
Trained on GPT-2 layer 6 pre-residual activations
Tokens: ~400M tokens from The Pile (thanks to Alan Cooney's pre-tokenized Pile)
→ ~100k gradient steps
LR schedule
Warmup for $\frac{2}{1 - \beta_2} = 2{,}000$ steps, in accordance with "On the adequacy of untuned warmup for adaptive optimization"
Linear warmup after each resample, with the same value (2,000 steps)
Linear cooldown to 1/10 of the initial value over 20,000 steps, starting at step 75,000 (a sketch of the full schedule as a multiplier function follows below)
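My reading of that schedule as a single learning-rate multiplier; the exact piecewise shapes are assumptions:

```python
def lr_multiplier(step: int,
                  warmup: int = 2_000,
                  resample_steps: tuple = (25_000, 50_000),
                  cooldown_start: int = 75_000,
                  cooldown_len: int = 20_000) -> float:
    # fresh linear warmup from step 0 and again after each resample
    last_restart = max(s for s in (0,) + resample_steps if s <= step)
    mult = min((step - last_restart) / warmup, 1.0)
    # linear cooldown to 1/10 of the initial value
    if step >= cooldown_start:
        frac = min((step - cooldown_start) / cooldown_len, 1.0)
        mult *= 1.0 - 0.9 * frac
    return mult
```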
Anthropic resampling
I used 3e-6 as the dead threshold rather than 0
Resample at 25,000 and 50,000 steps
The proportion of the average encoder norm that features were resampled to varied between sweeps
Normalization:
L2 normalization as proposed by Anthropic
SAE details
Dictionary expansion factor of 16
Tied decoder bias, untied encoder/decoder weights
Varying between sweeps:
Experiment 1:
lr=0.001
30 total runs
Resampled to 0.02 of avg encoder norm
Experiment 2:
lr=0.0003
48 total runs
Resampled to 0.02 of avg encoder norm
Experiment 3:
lr=0.001
30 total runs
Resampled to 0.2 of avg encoder norm
L1-coefficient ranges were adjusted for each model to get more overlap in L0 ranges, since different architectures respond very differently to L1 coefficients.
Varying within sweeps:
L1 coefficient
Architecture choice of nonlinearity:
ReLU
$\mathrm{ProLU}_{\mathrm{ReLU}}$
$\mathrm{ProLU}_{\mathrm{STE}}$
L1 Penalty type
L1: $P(a) = \|a\|_1$
Sqrt(L1): $P(a) = \|a\|_{1/2}^{1/2}$
Results
Let:
$L_{model}$ be the CE loss of the model, unperturbed, on the data distribution
$L_{reconstructed}$ be the CE loss of the model when activations are replaced with the reconstructed activations
$L_{zero}$ be the CE loss of the model when activations are replaced with the zero vector
Degradation, or information lost, measures how much information about the correct next token the model loses by having its activations $a$ replaced with the SAE's reconstruction $\mathrm{SAE}(a)$:
$$\mathrm{degradation} = L_{reconstructed} - L_{model}$$
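A hedged sketch of how these three losses can be measured with a standard PyTorch forward hook; the `ce_loss` helper and `hook_module` handle are hypothetical names, and the snippet assumes the hooked submodule outputs a plain activation tensor.

```python
import torch

@torch.no_grad()
def measure_degradation(model, sae, tokens, hook_module, ce_loss):
    """ce_loss(model, tokens) -> mean next-token CE (hypothetical helper);
    hook_module: the submodule whose output is the activation the SAE reads."""
    def patch(replace_fn):
        # returning a value from a forward hook replaces the module's output
        return hook_module.register_forward_hook(
            lambda mod, inp, out: replace_fn(out))

    loss_model = ce_loss(model, tokens)            # L_model

    handle = patch(lambda a: sae(a))
    loss_reconstructed = ce_loss(model, tokens)    # L_reconstructed
    handle.remove()

    handle = patch(lambda a: torch.zeros_like(a))
    loss_zero = ce_loss(model, tokens)             # L_zero
    handle.remove()

    return loss_reconstructed - loss_model  # degradation
```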
For $L_0 < 25$:
The Pareto-best architecture uses the $\mathrm{ProLU}_{\mathrm{STE}}$ nonlinearity with an L1 sparsity penalty.
For $L_0 > 25$:
There are no models using $\mathrm{ProLU}_{\mathrm{STE}}$ with the L1 penalty in this $L_0$ range.
Of the remaining models, $\mathrm{ProLU}_{\mathrm{ReLU}}$ with the Sqrt(L1) penalty is Pareto-best.
Further Investigation
MSE/L1 Pareto Frontier
The gradients of ProLU are not the gradients of the loss landscape, so it would be a reasonable default to expect these models to perform worse than a vanilla SAE. Indeed, I expect they may perform worse on the optimization target. The reason this can still work is that there is slack in the problem, introduced by our inability to optimize for the actual target directly: our current options are to optimize for L1 or Sqrt(L1) as sparsity proxies for what we actually want, because L0 is not a differentiable metric.
Actual target: minimize L0 and bits lost
Optimization (proxy) target: minimize L1 (or √L1) and MSE
Because we’re not optimizing for the actual target, I am not so surprised that there may be weird tricks we can do to get more of what we want.
On this vein of thought, my prediction after seeing the good performance on the actual target (and prior to checking this prediction) was:
Despite improved performance on degradation/L0, ProLU SAEs will have the same or worse performance on the MSE/L1 curve. We may also see that the higher-performing architectures have greater L1/L0.
Let’s check:
In favor of the hypothesis, while other architectures sometimes join it on the frontier, the vanilla ReLU is present for the entirety of this Pareto frontier. On the other hand, at lower sparsity levels $\mathrm{ProLU}_{\mathrm{STE}}$ joins it at the frontier. So the claim that this change does not improve performance on the optimization target seems true, but it's not clear that better performance on the actual target comes from worse performance on the optimization target.
This suggests a possible reason for why the technique works well:
Possibly the gains from this technique do not come from scale consistency so much as that it forced us to synthesize some gradients and those gradients happened to point more in the direction of what we actually want.
Here is the graph of L1 norm versus L0 norm:
This suggests that the learned features may be experiencing less suppression, though that may not be the only thing going on. Feature suppression is still consistent with the scale-consistency hypothesis: consistent undershooting would be an expected side effect if scale inconsistency is a real problem, since regular SAEs may be less able to filter unwanted activations if they keep biases near zero to minimize the errors it induces.
More investigation is needed to build a complete or confident picture of the cause of the performance gains in ProLU SAEs.
Unfortunately, I did not log √L1, so I can't compare with that curve, but I could load the models to create those graphs in follow-up work.
Acknowledgements
Noa Nabeshima and Arunim Agarwal gave useful feedback and editing help on the draft of this post.
Thanks to Mason Krug for in-depth editing of my grant proposal, which helped seed this writeup and clarify my communication.
How to Cite