I’ve trained some sparse MLPs with 20K neurons on a 4L TinyStories model with ReLU activations and no layernorm, and I took a look at them after reading this post. For varying integer S, I applied an L1 penalty with coefficient 2^S on the average of the activations per token, which seems pretty close to doing an L1 with coefficient 2^S/20,000 on the sum of the activations per token (there’s a quick sketch of this setup after the links). Your L1 with 12K neurons works out to something like a lower S in my setup. After reading your post, I checked out the cosine similarity between the encoder/decoder weights of the original MLP neurons and the sparse MLP neurons for varying values of S (make sure to scroll down once you click one of the links!):
S=3: https://plotly-figs.s3.amazonaws.com/sparse_mlp_L1_2exp3
S=4: https://plotly-figs.s3.amazonaws.com/sparse_mlp_L1_2exp4
S=5: https://plotly-figs.s3.amazonaws.com/sparse_mlp_L1_2exp5
S=6: https://plotly-figs.s3.amazonaws.com/sparse_mlp_L1_2exp6
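To pin down the penalty parameterization, here’s a minimal sketch of the loss as I set it up (PyTorch; names and shapes are mine, not from either codebase):

```python
import torch

S = 6              # integer sweep value; the L1 coefficient is 2**S
d_sparse = 20_000  # sparse MLP width

def sparse_mlp_loss(recon_loss: torch.Tensor, acts: torch.Tensor) -> torch.Tensor:
    # acts: [n_tokens, d_sparse] sparse-MLP activations (post-ReLU, so already >= 0)
    l1_on_mean = (2 ** S) * acts.mean()
    # Essentially the same as a coefficient of 2**S / d_sparse on the per-token sum:
    # l1_on_sum = (2 ** S / d_sparse) * acts.sum(dim=-1).mean()
    return recon_loss + l1_on_mean
```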
I think the behavior you’re pointing at is clearly there at lower L1s on layers other than layer 0 (what’s up with layer 0?), and it sort of decreases at higher L1 values, to the point that it’s only there a bit at S=5 and almost not there at S=6. I think the non-dead sparse neurons are almost all interpretable at S=5 and S=6.
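For reference, the quantity in those plots is just a best-match cosine similarity between neuron weight vectors, along these lines (variable names are illustrative):

```python
import torch
import torch.nn.functional as F

def best_match_cos_sims(W_orig: torch.Tensor, W_sparse: torch.Tensor) -> torch.Tensor:
    # W_orig: [d_model, d_mlp] original MLP weights (each column is one neuron's direction);
    # W_sparse: [d_model, d_sparse] sparse MLP weights. Use either the encoder (input)
    # or decoder (output) weights, transposed into this shape as needed.
    orig = F.normalize(W_orig, dim=0)      # unit-normalize each neuron direction
    sparse = F.normalize(W_sparse, dim=0)
    sims = orig.T @ sparse                 # [d_mlp, d_sparse] all-pairs cosine sims
    return sims.max(dim=0).values          # closest original neuron per sparse neuron
```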
Original val loss of the model: 1.128 ≈ 1.13.
Val loss with each layer’s MLP zero-ablated: [3.72, 1.84, 1.56, 2.07].
S=6 loss recovered per layer:
Layer 0: 1 - (1.24 - 1.13)/(3.72 - 1.13) ≈ 96% of loss recovered
Layer 1: 1 - (1.18 - 1.13)/(1.84 - 1.13) ≈ 93% of loss recovered
Layer 2: 1 - (1.21 - 1.13)/(1.56 - 1.13) ≈ 81% of loss recovered
Layer 3: 1 - (1.26 - 1.13)/(2.07 - 1.13) ≈ 86% of loss recovered
Compare to the 79% loss recovered by Anthropic’s A/1 autoencoder, which has 4K features and a pretty different setup.
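For anyone who wants to sanity-check the arithmetic, here’s the loss-recovered formula spelled out (using the unrounded 1.128 for the original loss):

```python
orig_loss = 1.128
zero_ablated = [3.72, 1.84, 1.56, 2.07]  # val loss with each layer's MLP zero-ablated
with_sparse = [1.24, 1.18, 1.21, 1.26]   # val loss with each layer's MLP swapped for the S=6 sparse MLP

for layer, (z, s) in enumerate(zip(zero_ablated, with_sparse)):
    recovered = 1 - (s - orig_loss) / (z - orig_loss)
    print(f"Layer {layer}: {recovered:.0%} of loss recovered")
```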
(Also, I was going to focus on the S=5 MLPs for layers 1 and 2, but now I think I’ll stick with S=6 instead. This is a little tricky because I wouldn’t be surprised if TinyStories MLP neurons are interpretable at higher rates than those of other models.)
Basically I think sparse MLPs aren’t a dead end and that you probably just want a higher L1.
I think at least some GPT-2 models have a really high-magnitude direction in their residual stream that might be used to preserve some scale information after LayerNorm. (I think Adam Scherlis originally mentioned or showed the direction to me, but it may have been someone else.) It’s maybe akin to the water-droplet artifacts in StyleGAN touched on here: https://arxiv.org/pdf/1912.04958.pdf
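If anyone wants to poke at this, here’s a quick sketch of how you might look for such a direction; this is my assumed setup using TransformerLens, not how the direction was originally found:

```python
import torch
from transformer_lens import HookedTransformer  # assumed dependency

model = HookedTransformer.from_pretrained("gpt2")
tokens = model.to_tokens("The quick brown fox jumps over the lazy dog.")
_, cache = model.run_with_cache(tokens)

resid = cache["resid_post", 5]            # [batch, pos, d_model] residual stream, mid-network
mean_abs = resid.abs().mean(dim=(0, 1))   # average magnitude per residual dimension
top = torch.topk(mean_abs, k=5)
print(top.indices, top.values)            # if the direction exists, a dim or two should dominate
```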