If anyone wants to work on this or knows people who might, I’d be interested in funding work on this (or helping secure funding—I expect that to be pretty easy to do).
Sounds plausible, but why does this differentially impact the generalizing algorithm over the memorizing algorithm?
Perhaps under normal circumstances both are learned so fast that you just don’t notice that one is slower than the other, and this slows both of them down enough that you can see the difference?
The claim that "list sorting does not play well with few-shot" mostly doesn't replicate with davinci-002.
When using length-10 lists (it crushes length-5 no matter the prompt), I get:
32-shot, no fancy prompt: ~25%
0-shot, fancy python prompt: ~60%
0-shot, no fancy prompt: ~60%
So few-shot hurts, but the fancy prompt does not seem to help. Code here.
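For reference, here's a minimal sketch of the kind of evaluation loop I mean (the exact prompt formats and the `query_model` helper are hypothetical placeholders, not the linked code):

```python
import random

def make_list(n=10):
    return [random.randint(0, 99) for _ in range(n)]

def fewshot_prompt(xs, k=32):
    # k worked examples followed by the query, with no fancy framing.
    shots = []
    for _ in range(k):
        ex = make_list()
        shots.append(f"Input: {ex}\nOutput: {sorted(ex)}")
    return "\n\n".join(shots) + f"\n\nInput: {xs}\nOutput:"

def fancy_python_prompt(xs):
    # Zero-shot "fancy" prompt: frame the task as a Python REPL transcript.
    return f">>> sorted({xs})\n"

def accuracy(prompt_fn, query_model, n_trials=100):
    correct = 0
    for _ in range(n_trials):
        xs = make_list()
        completion = query_model(prompt_fn(xs))  # placeholder for the API call
        correct += str(sorted(xs)) in completion
    return correct / n_trials
```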
I’m interested if anyone knows another case where a fancy prompt increases performance more than few-shot prompting, where a fancy prompt is a prompt that does not contain information that a human would use to solve the task. This is because I’m looking for counterexamples to the following conjecture: “fine-tuning on k examples beats fancy prompting, even when fancy prompting beats k-shot prompting” (for a reasonable value of k, e.g. the number of examples it would take a human to understand what is going on).
My speculation for Omni-Grok in particular is that in settings like MNIST you already have two of the ingredients for grokking (that there are both memorising and generalising solutions, and that the generalising solution is more efficient), and then having large parameter norms at initialisation provides the third ingredient (generalising solutions are learned more slowly), for some reason I still don’t know.
Higher weight norm means lower effective learning rate with Adam, no? In that paper they used a constant learning rate across weight norms, but Adam tries to normalize the gradients to be of size 1 per parameter, regardless of the size of the weights. So the weights change more slowly with larger initializations (especially since they constrain the weights to be of fixed norm by projecting after the Adam step).
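A toy check of the effective-learning-rate point (my own construction, not from the paper): the same loss at two initialization scales gives roughly the same absolute update after one Adam step, hence a smaller relative change for the larger initialization.

```python
import torch

# With Adam, the per-parameter update size is roughly the learning rate
# regardless of weight scale, so larger weights change relatively more slowly.
for scale in [1.0, 10.0]:
    w = torch.nn.Parameter(scale * torch.randn(1000))
    opt = torch.optim.Adam([w], lr=1e-3)
    w0 = w.detach().clone()
    loss = (w ** 2).sum()  # any loss; gradient magnitude scales with w
    loss.backward()
    opt.step()
    rel_change = ((w.detach() - w0).norm() / w0.norm()).item()
    print(f"init scale {scale}: relative weight change ~ {rel_change:.2e}")
```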
Daniel Filan: But I would’ve guessed that there wouldn’t be a significant complexity difference between the frequencies. I guess there’s a complexity difference in how many frequencies you use.
Vikrant Varma: Yes. That’s one of the differences: how many you use and their relative strength and so on. Yeah, I’m not really sure. I think this is a question we pick out as a thing we would like to see future work on.
My pet hypothesis here is that (a) by default, the network uses whichever frequencies were highest at initialization (for which there is significant circumstantial evidence) and (b) the amount of interference differs significantly based on which frequencies you use (which in turn changes the quality of the logits holding parameter norm fixed, and thus changes efficiency).
In principle this can be tested by randomly sampling frequency sets, simulating the level of interference you’d get, using that to estimate the efficiency + critical dataset size for that grokking circuit. This gives you a predicted distribution over critical dataset sizes, which you could compare against the actual distribution.
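A minimal Monte Carlo skeleton for that test (the `interference` function is a hypothetical stand-in for the real simulation, and `p`, `k` are assumed values):

```python
import numpy as np

p = 113       # modulus for modular addition (assumed setting)
k = 5         # number of key frequencies used by the grokking circuit

def interference(freqs):
    """Hypothetical stand-in: replace with an actual simulation of the
    constructive/destructive interference in the circuit's logits for
    this frequency set."""
    diffs = np.abs(freqs[:, None] - freqs[None, :])
    iu = np.triu_indices(k, k=1)
    return float(np.sum(1.0 / (1.0 + diffs[iu])))

rng = np.random.default_rng(0)
# Sample random frequency sets and build the predicted distribution of
# interference levels (and hence, via efficiency, critical dataset sizes).
levels = []
for _ in range(10_000):
    freqs = rng.choice(np.arange(1, (p - 1) // 2 + 1), size=k, replace=False)
    levels.append(interference(np.sort(freqs)))
print("predicted 5/50/95 percentiles:", np.percentile(levels, [5, 50, 95]))
```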
Tbc there are other hypotheses too, e.g. perhaps different frequency sets are easier / harder to implement by the neural network architecture.
This suggestion seems less expressive than (but similar in spirit to) the “rescale & shift” baseline we compare to in Figure 9. The rescale & shift baseline is sufficient to resolve shrinkage, but it doesn’t capture all the benefits of Gated SAEs.
The core point is that L1 regularization adds lots of biases, of which shrinkage is just one example, so you want to localize the effect of L1 as much as possible. In our setup L1 applies to $\mathrm{ReLU}(\pi_{\text{gate}}(\mathbf{x}))$, so you might think of $\pi_{\text{gate}}(\mathbf{x})$ as "tainted", and want to use it as little as possible. The only thing you really need L1 for is to deter the model from setting too many features active, i.e. you need it to apply to one bit per feature (whether that feature is on / off). The Heaviside step function makes sure we are extracting just that one bit, and relying on $\mathbf{f}_{\text{mag}}(\mathbf{x})$ for everything else.
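Concretely, a minimal sketch of that split (shapes are assumed, and the paper additionally ties the two weight matrices via a learned per-feature rescaling, which is omitted here):

```python
import torch

def gated_encode(x, W_gate, b_gate, W_mag, b_mag):
    pi_gate = x @ W_gate + b_gate
    f_gate = (pi_gate > 0).float()          # Heaviside: just the on/off bit
    f_mag = torch.relu(x @ W_mag + b_mag)   # magnitudes, untouched by L1
    # L1 is applied to relu(pi_gate) only; everything else rides on f_mag.
    return f_gate * f_mag
```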
Re dictionary width, 2**17 (~131K) for most Gated SAEs, 3*(2**16) for baseline SAEs, except for the (Pythia-2.8B, Residual Stream) sites, where we used 2**15 for Gated and 3*(2**14) for baseline, since early runs of these had lots of feature death. (This'll be added to the paper soon, sorry!) I'll leave the other Qs for my co-authors.
Great paper! The gating approach is an interesting way to learn the JumpReLU threshold and it’s exciting that it works well. We’ve been working on some related directions at OpenAI based on similar intuitions about feature shrinking.
Some questions:
Is b_mag still necessary in the gated autoencoder?
Did you sweep learning rates for the baseline and your approach?
How large is the dictionary of the autoencoder?
Yep, you’re totally right—thanks!
I haven't fully worked through the maths, but I think both IG and attribution patching break down here? The fundamental problem is that the discontinuity is invisible to IG because it only takes derivatives. E.g. the ReLU and Jump ReLU below look identical from the perspective of IG, but not from the perspective of activation patching, I think.
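A toy illustration of that picture (my own construction, not the exact figure referenced above): a ReLU shifted to start at the threshold and a JumpReLU with the same threshold have identical derivatives almost everywhere, even though their outputs differ by the jump size above the threshold.

```python
import torch

theta = 1.0

def shifted_relu(x):
    # Continuous ramp starting at theta: ReLU(x - theta).
    return torch.clamp(x - theta, min=0.0)

def jump_relu(x):
    # Discontinuous: 0 below theta, x above it (a jump of size theta).
    return x * (x > theta).float()

# Away from x = theta both have derivative 0 below and 1 above, so a purely
# derivative-based method like IG can't tell them apart.
xs = torch.tensor([-1.0, 0.5, 2.0], requires_grad=True)
for f in (shifted_relu, jump_relu):
    (grad,) = torch.autograd.grad(f(xs).sum(), xs)
    print(f.__name__, grad.tolist())
```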
Nice work! I’m not sure I fully understand what the “gated-ness” is adding, i.e. what the role the Heaviside step function is playing. What would happen if we did away with it? Namely, consider this setup:
Let $\mathrm{enc}$ and $\mathrm{dec}$ be the encoder and decoder functions, as in your paper, and let $\mathbf{x}$ be the model activation that is fed into the SAE.
The usual SAE reconstruction is $\mathrm{dec}(\mathrm{enc}(\mathbf{x}))$, which suffers from the shrinkage problem.
Now, introduce a new learned parameter $\mathbf{s}$, and define an "expanded" reconstruction $\mathrm{dec}(\mathbf{s} \odot \mathrm{enc}(\mathbf{x}))$, where $\odot$ denotes elementwise multiplication.
Finally, take the loss to be:
$$\mathcal{L} = \|\mathbf{x} - \overline{\mathrm{dec}}(\mathrm{enc}(\mathbf{x}))\|_2^2 + \lambda \|\mathrm{enc}(\mathbf{x})\|_1 + \|\mathbf{x} - \mathrm{dec}(\mathbf{s} \odot \mathrm{enc}(\mathbf{x}))\|_2^2,$$
where $\overline{\mathrm{dec}}$ denotes a frozen copy of the decoder, which ensures the decoder gets no gradients from the first term. As I understand it, this is exactly the loss appearing in your paper. The only difference in the setup is the lack of the Heaviside step function.
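In code, the setup I have in mind would look something like this (initialization and shapes are placeholder choices, not from the paper):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class RescaledSAE(nn.Module):
    """Sketch of the proposed variant: a standard ReLU SAE plus a learned
    per-feature scale s, with no Heaviside step."""

    def __init__(self, d_model: int, d_hidden: int):
        super().__init__()
        self.W_enc = nn.Parameter(0.01 * torch.randn(d_model, d_hidden))
        self.b_enc = nn.Parameter(torch.zeros(d_hidden))
        self.W_dec = nn.Parameter(0.01 * torch.randn(d_hidden, d_model))
        self.b_dec = nn.Parameter(torch.zeros(d_model))
        self.s = nn.Parameter(torch.ones(d_hidden))  # the new learned rescale

    def loss(self, x: torch.Tensor, lam: float) -> torch.Tensor:
        f = F.relu(x @ self.W_enc + self.b_enc)  # enc(x)
        # First term: frozen decoder, so dec gets no gradients from it.
        recon_frozen = f @ self.W_dec.detach() + self.b_dec.detach()
        # Third term: "expanded" reconstruction using the learned scale s.
        recon = (self.s * f) @ self.W_dec + self.b_dec
        return (
            ((x - recon_frozen) ** 2).sum(-1).mean()
            + lam * f.abs().sum(-1).mean()
            + ((x - recon) ** 2).sum(-1).mean()
        )
```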
Did you try this setup? Or does it fail for an obvious reason I missed?
Yep, the intuition here indeed was that L1 penalised reconstruction seems to be okay for teaching a standard SAE’s encoder to detect which features are on (even if features get shrunk as a result), so that is effectively what this auxiliary loss is teaching the gate sub-layer to do, alongside the sparsity penalty. (The key difference being we freeze the decoder in the auxiliary task, which the ablation study shows helps performance.) Maybe to put it another way, this was an auxiliary task that we had good evidence would teach the gate sublayer to detect active features reasonably well, and it turned out to give good results in practice. It’s totally possible though that there are better auxiliary tasks (or even completely different loss functions) out there that we’ve not explored.
Hey Sam, thanks—you're right. The definition of reconstruction bias is actually the argmin of
$$\mathbb{E}_{\mathbf{x}} \left\| \gamma\,\mathbf{x} - \hat{\mathbf{x}}(\mathbf{x}) \right\|_2^2,$$
which I'd (incorrectly) rearranged as the expression in the paper. As a result, the optimum is
$$\gamma = \frac{\mathbb{E}_{\mathbf{x}}\left[\hat{\mathbf{x}} \cdot \mathbf{x}\right]}{\mathbb{E}_{\mathbf{x}}\left[\|\mathbf{x}\|_2^2\right]}.$$
That being said, the derivation we gave was not quite right, as I’d incorrectly substituted the optimised loss rather than the original reconstruction loss, which makes equation (10) incorrect. However the difference between the two is small exactly when gamma is close to one (and indeed vanishes when there is no shrinkage), which is probably why we didn’t pick this up. Anyway, we plan to correct these two equations and update the graphs, and will submit a revised version.
Thinking on this a bit more, this might actually reflect a general issue with the way we think about feature shrinkage; namely, that whenever there is a nonzero angle between two vectors of the same length, the best way to make either vector close to the other will be by shrinking it.
This was actually the key motivation for building this metric in the first place, instead of just looking at the ratio $\mathbb{E}\left[\|\hat{\mathbf{x}}\|_2\right] / \mathbb{E}\left[\|\mathbf{x}\|_2\right]$. Looking at the $\gamma$ that would optimize the reconstruction loss ensures that we're capturing only bias from the L1 regularization, and not capturing the "inherent" need to shrink the vector given these nonzero angles. (In particular, if we computed $\mathbb{E}\left[\|\hat{\mathbf{x}}\|_2\right] / \mathbb{E}\left[\|\mathbf{x}\|_2\right]$ for Gated SAEs, I expect that would be below 1.)
I think the main thing we got wrong is that we accidentally treated $\gamma^{-1}$ as though it were $\gamma$. To the extent that was the main mistake, I think it explains why our results still look how we expected them to—usually $\gamma$ is going to be close to 1 (and should be almost exactly 1 if shrinkage is solved), so in practice the error introduced from this mistake is going to be extremely small.
We’re going to take a closer look at this tomorrow, check everything more carefully, and post an update after doing that. I think it’s probably worth waiting for that—I expect we’ll provide much more detailed derivations that make everything a lot clearer.
Nice. I tried to do something similar (except making everything leaky with polynomial tails), so

```python
y = (y + torch.sqrt(y**2 + scale**2)) * (1 + (y + threshold) / torch.sqrt((y + threshold)**2 + scale**2)) / 4
```

where the first part, `(y + torch.sqrt(y**2 + scale**2))`, is a softplus, and the second part, `(1 + (y + threshold) / torch.sqrt((y + threshold)**2 + scale**2))`, is a leaky cutoff at the value `threshold`.
But I don’t think I got such clearly better results, so I’m going to have to read more thoroughly to see what else you were doing that I wasn’t :)
Oh, one other issue relating to this: in the paper it's claimed that if $\gamma$ is the argmin of $\mathbb{E}_{\mathbf{x}} \left\| \hat{\mathbf{x}} / \gamma - \mathbf{x} \right\|_2^2$, then $\gamma^{-1}$ is the argmin of $\mathbb{E}_{\mathbf{x}} \left\| \gamma' \mathbf{x} - \hat{\mathbf{x}} \right\|_2^2$. However, this is not actually true: the argmin of the latter expression is $\mathbb{E}_{\mathbf{x}}\left[\mathbf{x} \cdot \hat{\mathbf{x}}\right] / \mathbb{E}_{\mathbf{x}}\left[\|\mathbf{x}\|_2^2\right]$. To get an intuition here, consider the case where $\mathbf{x}$ and $\hat{\mathbf{x}}$ are very nearly perpendicular, with the angle between them just slightly less than $\pi/2$. Then you should be able to convince yourself that the best factor to scale either $\mathbf{x}$ or $\hat{\mathbf{x}}$ by in order to minimize the distance to the other will be just slightly greater than 0. Thus the optimal scaling factors cannot be reciprocals of each other.
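A quick numeric check of this, with toy vectors of my own choosing:

```python
import numpy as np

# For nearly-perpendicular x and x_hat, the best scaling of either vector
# toward the other is close to 0, so the two optimal scalings cannot be
# reciprocals of each other.
x = np.array([1.0, 0.0])
x_hat = np.array([0.05, 1.0])   # angle with x just under pi/2

c1 = (x @ x_hat) / (x_hat @ x_hat)  # argmin_c ||c * x_hat - x||^2
c2 = (x @ x_hat) / (x @ x)          # argmin_c ||c * x - x_hat||^2
print(c1, c2, 1 / c1)               # c1 ~ c2 ~ 0.05, but 1/c1 ~ 20
```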
ETA: Thinking on this a bit more, this might actually reflect a general issue with the way we think about feature shrinkage; namely, that whenever there is a nonzero angle between two vectors of the same length, the best way to make either vector close to the other will be by shrinking it. I’ll need to think about whether this makes me less convinced that the usual measures of feature shrinkage are capturing a real thing.
ETA2: In fact, now I’m a bit confused why your figure 6 shows no shrinkage. Based on what I wrote above in this comment, we should generally expect to see shrinkage (according to the definition given in equation (9)) whenever the autoencoder isn’t perfect. I guess the answer must somehow be “equation (10) actually is a good measure of shrinkage, in fact a better measure of shrinkage than the ‘corrected’ version of equation (10).” That’s pretty cool and surprising, because I don’t really have a great intuition for what equation (10) is actually capturing.
Ah thanks, you're totally right—that mostly resolves my confusion. I'm still a little bit dissatisfied, though, because the $\mathcal{L}_{\text{aux}}$ term is optimizing for something that we don't especially want (i.e. for $\mathrm{ReLU}(\pi_{\text{gate}}(\mathbf{x}))$ to do a good job of reconstructing $\mathbf{x}$). But I do see how you do need to have some sort of a reconstruction-esque term that actually allows gradients to pass through to the gated network.
Possibly I'm missing something, but if you don't have $\mathcal{L}_{\text{aux}}$, then the only gradients to $\mathbf{W}_{\text{gate}}$ and $\mathbf{b}_{\text{gate}}$ come from $\mathcal{L}_{\text{sparsity}}$ (the binarizing Heaviside activation function kills gradients from $\mathcal{L}_{\text{reconstruct}}$), and so $\pi_{\text{gate}}(\mathbf{x})$ would be always non-positive to get perfect zero sparsity loss. (That is, if you only optimize for L1 sparsity, the obvious solution is "none of the features are active".)
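A tiny sanity check of the gradient-flow claim (toy tensors, not the actual SAE code):

```python
import torch

torch.manual_seed(0)
pi_gate = torch.randn(8, requires_grad=True)
w_mag = torch.randn(8, requires_grad=True)
f = (pi_gate > 0).float() * torch.relu(w_mag)  # gated features
loss = ((f - 1.0) ** 2).sum()                  # stand-in reconstruction term
loss.backward()
print(w_mag.grad)    # gradients flow to the magnitude path (where the gate is open)
print(pi_gate.grad)  # None: the Heaviside cut all gradient to the gate
```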
(You could use a smooth activation function as the gate, e.g. an element-wise sigmoid, and then you could just stick with $\mathcal{L}_{\text{reconstruct}} + \mathcal{L}_{\text{sparsity}}$ from the beginning of Section 3.2.2.)
(The question in this comment is more narrow and probably not interesting to most people.)
The limitations section includes this paragraph:
One worry about increasing the expressivity of sparse autoencoders is that they will overfit when reconstructing activations (Olah et al., 2023, Dictionary Learning Worries), since the underlying model only uses simple MLPs and attention heads, and in particular lacks discontinuities such as step functions. Overall we do not see evidence for this. Our evaluations use held-out test data and we check for interpretability manually. But these evaluations are not totally comprehensive: for example, they do not test that the dictionaries learned contain causally meaningful intermediate variables in the model's computation. The discontinuity in particular introduces issues with methods like integrated gradients (Sundararajan et al., 2017) that discretely approximate a path integral, as applied to SAEs by Marks et al. (2024).

I'm not sure I understand the point about integrated gradients here. I understand this sentence as meaning: since model outputs are a discontinuous function of feature activations, integrated gradients will do a bad job of estimating the effect of patching feature activations to counterfactual values.
If that interpretation is correct, then I guess I’m confused because I think IG actually handles this sort of thing pretty gracefully. As long as the number of intermediate points you’re using is large enough that you’re sampling points pretty close to the discontinuity on both sides, then your error won’t be too large. This is in contrast to attribution patching which will have a pretty rough time here (but not really that much worse than with the normal ReLU encoders, I guess). (And maybe you also meant for this point to apply to attribution patching?)
We use learning rate 0.0003 for all Gated SAE experiments, and also for the GELU-1L baseline experiment. We swept learning rates for the baseline SAE on GELU-1L to arrive at this value.
For the Pythia-2.8B and Gemma-7B baseline SAE experiments, we divided the L2 loss by $\mathbb{E}\,\|\mathbf{x}\|_2$, motivated by wanting better hyperparameter transfer, and so changed the learning rate to 0.001 or 0.00075 for all the runs (currently in Figure 1, only attention output pre-linear uses 0.00075; in the rerelease we'll state all the values used). We didn't see a noticeable difference in the Pareto frontier between 0.001 and 0.00075, so we did not sweep this baseline hyperparameter further.