Abstract
A key step in reverse engineering neural networks is to decompose them into simpler parts that can be studied in relative isolation.
Linear parameter decomposition, a framework that has been proposed to resolve several issues with current decomposition methods, decomposes neural network parameters into a sum of sparsely used vectors in parameter space.
However, the current main method in this framework, Attribution-based Parameter Decomposition (APD), is impractical on account of its computational cost and sensitivity to hyperparameters.
In this work, we introduce Stochastic Parameter Decomposition (SPD), a method that is more scalable and robust to hyperparameters than APD, which we demonstrate by decomposing models that are slightly larger and more complex than was possible to decompose with APD.
We also show that SPD avoids other issues, such as shrinkage of the learned parameters, and better identifies ground truth mechanisms in toy models.
By bridging causal mediation analysis and network decomposition methods, this demonstration removes barriers to scaling linear parameter decomposition methods to larger models, opening up new research possibilities in mechanistic interpretability.
We release a library for running SPD and reproducing our experiments at https://github.com/goodfire-ai/spd.
I generally expect the authors of this paper to produce high-quality work, so my priors are on me misunderstanding something. Given that:
I don’t see how this method does anything at all, and the culprit is the faithfulness loss.
As you have demonstrated, if a sparse, overcomplete set of vectors is subject to a linear transformation, then SPD’s decomposition basically amounts to the whole linear transformation. The only exception is if your sparse basis is not overcomplete in the larger of the two dimensions, in which case SPD finds that basis in the smaller dimension. Since you’re working on just one weight at a time, linear transformations are the only case to consider. So all you can do is either find an exact basis or find the whole linear transformation.
We see this pretty clearly in the single-layer compressed computation example. The input is a linear basis with 100 elements (weirdly immersed in 1000-dimensional space) and SPD just finds the basis. The MLP middle section is an overcomplete set of 200 vectors (since each input will be mapped to a pair of half-lines joined at the origin); here SPD just finds the whole linear transformation.
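For concreteness, here’s the setup as I understand it (a sketch only: the 100/1000/50 dimensions are taken from this thread, and the target function, residual connection, and unembedding are omitted, so details may differ from the paper):

```python
# Sketch of the compressed computation toy model as I understand it from this
# thread; the dimensions and details here are my assumptions, not the paper's
# exact setup.
import numpy as np

rng = np.random.default_rng(0)
n_features, d_resid, d_mlp = 100, 1000, 50

W_E = rng.normal(size=(n_features, d_resid)) / np.sqrt(d_resid)  # fixed random embedding
W_in = rng.normal(size=(d_resid, d_mlp))    # trained in the real model
W_out = rng.normal(size=(d_mlp, d_resid))   # trained in the real model

x = np.zeros(n_features)                    # sparse input: one feature active
x[3] = 0.7
resid = x @ W_E                             # the "basis with 100 elements" lives here
hidden = np.maximum(resid @ W_in, 0.0)      # 50 ReLU neurons; each feature traces out
                                            # a pair of half-lines joined at the origin
out = hidden @ W_out                        # written back into the 1000-dim residual stream
```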
In fact, if we have superposition, I would expect the relevant components of the model to sum to more than the weights of the model. This is kind of just what superposition means: the same weights are being used for multiple computations at once.
Am I misunderstanding something here? Is there a reason that having overcomplete bases decomposed in this way is actually desirable?
No, that’s not how it works
Networks have non-linearities. SPD will decompose a matrix into a single linear transformation if what the network is doing with that matrix really is just applying one global linear transformation. If e.g. there are non-linearities right after the matrix that aren’t just always switched on, SPD will usually decompose the matrix into many sub-components.[1]
I’m not sure what you mean by ‘working on just one weight at a time’. The stochastic-layerwise reconstruction loss does do forward passes replacing only one matrix in the network at a time with a randomly ablated version of the same matrix. But the stochastic reconstruction loss does forward passes replacing all matrices at once.
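Very roughly, the difference looks like this (a sketch only: the function names, the uniform random masks, and the plain MSE are my simplifications here, not the actual losses or API from the paper):

```python
# Rough sketch of the two losses; names, masking distribution, and the MSE
# objective are simplifications, not the spd library's actual implementation.
import torch

def ablated(subcomponents):
    # Random ablation of one matrix given as rank-1 subcomponents (u, v):
    # W' = sum_c m_c * u_c v_c^T with random masks m_c in [0, 1].
    return sum(torch.rand(()) * torch.outer(u, v) for u, v in subcomponents)

def forward(x, weights):
    # Stand-in target network: a stack of linear layers with ReLUs.
    for W in weights:
        x = torch.relu(x @ W)
    return x

def stochastic_layerwise_recon_loss(x, weights, components):
    # Forward passes that replace ONE matrix at a time with its ablated version.
    target = forward(x, weights)
    loss = 0.0
    for i in range(len(weights)):
        ws = list(weights)
        ws[i] = ablated(components[i])
        loss = loss + ((forward(x, ws) - target) ** 2).mean()
    return loss

def stochastic_recon_loss(x, weights, components):
    # Forward passes that replace ALL matrices at once with ablated versions.
    target = forward(x, weights)
    ws = [ablated(c) for c in components]
    return ((forward(x, ws) - target) ** 2).mean()
```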
But I think I must be misunderstanding what you mean here, because even if we didn’t have the stochastic reconstruction loss I don’t see how that would matter for this.
That’s not how it works in our existing framework for circuits in superposition. The weights for particular circuits there actually literally do sum to the weights of the whole network. I’ve been unable to come up with a general framework that doesn’t exhibit this weight linearity.
I wouldn’t say that? Computation in superposition inevitably involves different circuits interfering with each other, because the weights of one circuit have non-zero inner product with the activations of another. But there is still a particular set of vectors in parameter space such that each vector implements one circuit.
Superposition can give you an overcomplete basis of variables in activation space, but it cannot give you an overcomplete basis of circuits acting on these variables in parameter space. There can’t be more circuits than weights.
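In symbols, the weight linearity I mean is just that each weight matrix in the network is a literal sum of per-circuit weights,

W = W_1 + W_2 + ... + W_C,

with one parameter-space vector W_c per circuit c (this is how it works in our circuits-in-superposition framework; I’m sketching the claim, not asserting it’s been established for trained networks). Since the W_c all live in the n×m-dimensional parameter space of an n×m matrix W, at most n×m of them can be linearly independent, which is the sense in which there can’t be more circuits than weights.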
[1] Well, depending on what the network is actually computing with these non-linearities, of course. If it’s not computing many different things, or not using the results of many of the computations for anything downstream, SPD probably won’t find many components that ever activate.
I went back to read “Compressed Computation is (probably) not Computation in Superposition” more thoroughly, and I can see that I’ve used “superposition” in a way which is less restrictive than the one which (I think) you use. Every usage of “superposition” in my first comment should be replaced with “compressed computation”.
My understanding is that SPD cannot decompose an n×m matrix into more than max(n,m) subcomponents, and if all subcomponents are “live” i.e. active on a decent fraction of the inputs, then it will have to have max(n,m) components to work. If, therefore, a computation is being done on 500 concepts at once using a 100-200-100 dimensional MLP, SPD can’t possibly figure out what’s going on. (This may be wrong and is not crucial to my point)
Edit: as you pointed out, this might only apply when there’s not a nonlinearity after the weight. But every W_out in a transformer has a connection running from it directly to the output logits through W_unembed. So SPD will struggle to interpret any of the output weights of transformer MLPs. This seems bad.
My main thoughts:
I still do not understand what the benefit of SPD is here. Your decompositions on the Compressed Computation model seem to reveal that the W_in matrices are reading off 100 specific vectors, and reveal absolutely nothing about W_out. But W_in is reading those vectors off a 1000-dimensional vector space where there’s no interference between features. The only interesting part is what the W_out matrices are doing, since they’re (allegedly) mapping information about 100 features from a 50-dimensional space to a 1000-dimensional space.
Thanks to CCi(p)nCiS, we know that the toy model is not even doing computation in superposition, which is the case which SPD seems to be based on. It’s actually doing something really weird with the “noise”, which doesn’t actually behave well.
But since the embedding and unembedding matrices (which determine the interference between features in the residual stream) are fixed during training of the toy model, the “noise” is not actually being optimized, which is a crucial difference between the toy model and actual models.
So we have a demonstration of SPD in which:
SPD doesn’t give us any interesting information (W_out matrices not decomposed, W_in matrices are not operating on compressed/superposed inputs)
The assumptions which SPD is based on (true computation in superposition) don’t seem to hold and
The toy model differs from real neural networks in a really important way (the learning process can’t influence how features interfere with each other in the residual stream)
SPD can decompose an n×m matrix into more than max(n,m) subcomponents.
I guess there aren’t any toy models in this paper that directly showcase this, but I’m pretty confident it’s true, because
I don’t see why it wouldn’t be able to.
I’ve decomposed a weight matrix in a tiny LLM and got out way more than max(n,m) live subcomponents. That’s a very preliminary result though, you probably shouldn’t put that much stock in it.
I think it’s the other way around. If you try to implement computation in superposition in a network with a residual stream, you will find that about the best thing you can do with the W_out is often to just use it as a global linear transformation. Most other things you might try to do with it drastically increase noise for not much pay-off. In the cases where networks are doing that, I would want SPD to show us this global linear transform.
They’re embedded randomly in the space, so there is interference between them in the sense of them having non-zero inner products.
Yes. I agree that this makes the model not as great a testbed as we originally hoped.
What do you mean by “a global linear transformation”? As in, what kinds of linear transformations are there other than this? If we have an MLP consisting of multiple computations going on in superposition (your sense), I would hope that the W_in would be decomposed into co-activating subcomponents corresponding to features being read into computations, and the W_out would also be decomposed into co-activating subcomponents corresponding to the outputs of those computations being read back into the residual stream. The fact that this doesn’t happen tells me something is wrong.
As to the issue with the maximum number of components: it seems to me like if you have five sparse features (in something like the SAE sense) in superposition and you apply a rotation (or reflection, or identity transformation), then the important information would be contained in a set of five rank-1 transformations, basically a set of maps from A to B. This doesn’t happen for the identity; does it happen for a rotation or reflection?
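To make “a set of maps from A to B” concrete, here is the decomposition I have in mind (a sketch; whether SPD would actually find these five rank-1 pieces rather than the whole map is exactly what I’m asking):

```python
# Any linear map acting on 5 (possibly non-orthogonal) feature directions f_i
# can be written as a sum of 5 rank-1 terms, each reading f_i off via its
# dual-frame vector and writing out the image of f_i under the map.
import numpy as np

rng = np.random.default_rng(0)
d, n_feat = 3, 5                                 # 5 sparse features superposed in 3 dims
feats = rng.normal(size=(n_feat, d))             # rows are the feature directions f_i
rot, _ = np.linalg.qr(rng.normal(size=(d, d)))   # a random rotation/reflection

frame_op = feats.T @ feats                       # invertible, since the f_i span the space
duals = feats @ np.linalg.inv(frame_op)          # rows are the dual-frame vectors
rank1_maps = [np.outer(rot @ feats[i], duals[i]) for i in range(n_feat)]

assert np.allclose(sum(rank1_maps), rot)         # the five rank-1 maps sum to the rotation
```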
Finally, as to “introducing noise” by doing things other than a global linear transformation, where have you seen evidence for this? On synthetic (and thus clean) datasets, or actually in real datasets? In real scenarios, your model will (I strongly believe) be set up such that the “noise” between interfering features is actually helpful for model performance, since the world has lots of structure which can be captured in the particular permutation in which you embed your overcomplete feature set into a lower dimensional space.
Linear transformations that are the sum of weights for different circuits in superposition, for example.
What I am trying to say is that I expect networks to implement computation in superposition by linearly adding many different subcomponents to create W_in, but I mostly do not expect networks to create W_out by linearly adding many different subcomponents that each read-out a particular circuit output back into the residual stream, because that’s actually an incredibly noisy operation. I made this mistake at first as well. This post still has a faulty construction for W_out because of my error. Linda Linsefors finally corrected me on this a couple months ago.
I disagree that, if all we’re doing is applying a linear transformation to the entire space of superposed features (rather than, say, performing different computations on the five different features), it would be desirable to split this linear transformation into the five features.
Uh, I think this would be a longer discussion than I feel up for at the moment, but I disagree with your prediction. I agree that the representational geometry in the model will be important and that it will be set up to help the model, but interference of circuits in superposition cannot be arranged to be helpful in full generality. If it were, I would take that as pretty strong evidence that whatever is going on in the model is not well-described by the framework of superposition at all.
Being linearly independent is sufficient in this case to read off each x_i with zero interference. Rank-one matrices are equivalent to (linear functional) * vector, and so we just pick the dual basis as our linear functionals and extend them to the whole space.
If you have 100 orthogonal linear probes to read with, yes. But since there are only 50 neurons, the actual circuits for different input features in the network will have interference to deal with.
Edit: To be clear, of course the network itself cannot perform the task precisely. I’m simply claiming that you can precisely mimic the behaviour of W_in with 100 rank-1 components, by just reading off the basis, as SPD does in this case. The fact that the v_i themselves are not orthogonal is irrelevant.
To be concrete: if we have 100 linearly independent vectors v_i, we can extend this to a basis of the whole 1000-dimensional space. Let V be the change of basis matrix from the standard basis to this basis. Then we can write W_in = (W_in V^{-1})V, where we can pick (W_in V^{-1}) arbitrarily.
If we write W_in V^{-1} as a sum of rank-1 matrices W_c, then the W_c V will sum to W_in, and each W_c V is still rank-one, since the image of W_c V equals the image of W_c.
So we can assume wlog that our v_i lie along the standard basis, i.e. that they are orthogonal with respect to the standard inner product.
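Spelled out numerically (a toy-sized sketch, with 10 vectors in a 50-dimensional space standing in for 100 in 1000):

```python
# Claim: W_in's behaviour on inputs spanned by linearly independent vectors v_i
# is exactly reproduced by one rank-1 component per v_i, each reading v_i off
# via its dual vector and writing the image of v_i under W_in.
import numpy as np

rng = np.random.default_rng(0)
n_feat, d_resid, d_mlp = 10, 50, 5
feats = rng.normal(size=(n_feat, d_resid))    # rows are the v_i
W_in = rng.normal(size=(d_resid, d_mlp))      # an arbitrary input weight matrix

duals = np.linalg.pinv(feats)                 # columns are duals g_i with g_i . v_j = delta_ij
components = [np.outer(duals[:, i], feats[i] @ W_in) for i in range(n_feat)]

x = rng.normal(size=n_feat) @ feats           # any input in the span of the v_i
assert np.allclose(x @ sum(components), x @ W_in)   # exact match with only n_feat rank-1 parts
```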
As a more productive question: say we had an LLM which, amongst other things, does the following. If there is a known bigram encoded in the residual stream in the form t_1 A + t_2 B (corresponding to the known bigram t_1 t_2), potentially with interference from other aspects, an MLP layer outputs a consistent vector v_{t_1 t_2} into the residual stream. This is how GPT-2 encodes known bigrams, hence the relevance.
And say that there are quadratically many known bigrams as a function of hidden neuron size, so that in particular there are more bigrams than the residual stream dimension. As far as I know, an appropriately randomly initialized network should be able to accomplish this task (or at least with W_in random).
Is the goal for SPD to learn components for W_in such that any given component only fires non-negligibly on a single bigram? Or is it ok if components fire on multiple different bigrams? I am trying to reason through how SPD would act in this case.
I’d have to think about the exact setup here to make sure there’s no weird caveats, but my first thought is that for W_in, this ought to be one component per bigram, firing exclusively for that bigram.
Sure, yes, that’s right. But I still wouldn’t take this to be equivalent to our v_i literally being orthogonal, because the trained network itself might not perfectly learn this transformation.
If there is some internal gradient descent reason for it being easier to learn to read off orthogonal vectors, then I take it back. I feel like I am being too pedantic here, in any case.
An intuition pump: Imagine the case of two scalar features c_1, c_2 being embedded along vectors f_1, f_2. If you consider a series that starts with f_1, f_2 being orthogonal, then gives them ever higher cosine similarity, I’d expect the network to have ever more trouble learning to read out c_1, c_2, until we hit f_1 = f_2, at which point the network definitely cannot learn to read the features out at all. I don’t know how the learning difficulty behaves over this series exactly, but it sure seems to me like it ought to go up monotonically at least.
Another intuition pump: The higher the cosine similarity between the features, the larger the norm of the rows of V^{-1} will be, with norm infinity in the limit of cosine similarity going to one.
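A quick numerical version of that (in this 2-D case both read-off directions turn out to have norm 1/sin θ, where θ is the angle between the features):

```python
# Row norms of V^{-1} as the two feature directions approach each other.
import numpy as np

for cos_sim in [0.0, 0.5, 0.9, 0.99, 0.999]:
    f1 = np.array([1.0, 0.0])
    f2 = np.array([cos_sim, np.sqrt(1.0 - cos_sim**2)])
    V = np.stack([f1, f2], axis=1)        # columns are the feature directions
    read_off = np.linalg.inv(V)           # rows are the dual read-off directions
    print(f"cos sim {cos_sim}: read-off norms {np.linalg.norm(read_off, axis=1)}")
```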
I agree that at cosine similarity O(1/√1000), it’s very unlikely to be a big deal yet.
Yeah that makes sense, thanks.