‘Fundamental’ vs ‘applied’ mechanistic interpretability research
When justifying my mechanistic interpretability research interests to others, I’ve occasionally found it useful to borrow a distinction from physics and distinguish between ‘fundamental’ versus ‘applied’ interpretability research.
Fundamental interpretability research is the kind that investigates better ways to think about the structure of the function learned by neural networks. It lets us make new categories of hypotheses about neural networks. In the ideal case, it suggests novel interpretability methods based on new insights, though it is not the methods themselves.
A Mathematical Framework for Transformer Circuits (Elhage et al., 2021)
Toy Models of Superposition (Elhage et al., 2022)
Polysemanticity and Capacity in Neural Networks (Scherlis et al., 2022)
Interpreting Neural Networks through the Polytope Lens (Black et al., 2022)
Causal Abstraction for Faithful Model Interpretation (Geiger et al., 2023)
Research agenda: Formalizing abstractions of computations (Jenner, 2023)
Work that looks for ways to identify modules in neural networks (see LessWrong ‘Modularity’ tag).
Applied interpretability research is the kind that uses existing methods to find the representations or circuits that particular neural networks have learned. It generally involves finding facts or testing hypotheses about a given network (or set of networks) based on assumptions provided by theory.
Steering GPT-2-XL by adding an activation vector (Turner et al., 2023)
Discovering Latent Knowledge in Language Models (Burns et al., 2022)
The Singular Value Decompositions of Transformer Weight Matrices are Highly Interpretable (Millidge et al., 2022)
In-context Learning and Induction Heads (Olsson et al., 2022)
We Found An Neuron in GPT-2 (Miller et al., 2023)
Language models can explain neurons in language models (Bills et al., 2023)
Acquisition of Chess Knowledge in AlphaZero (McGrath et al., 2021)
Although I’ve found the distinction between fundamental and applied interpretability useful, it’s not always clear cut:
Sometimes articles are part fundamental, part applied (e.g. arguably ‘A Mathematical Framework for Transformer Circuits’ is mostly theoretical, but also studies particular language models using new theory).
Sometimes articles take generally accepted ‘fundamental’—but underutilized—assumptions and develop methods based on them (e.g. Causal Scrubbing, where the key underutilized fundamental assumption was that the structure of neural networks can be well studied using causal interventions).
Other times the distinction is unclear because applied interpretability feeds back into fundamental interpretability, leading to fundamental insights about the structure of computation in networks (e.g. the Logit Lens lends weight to the theory that transformer language models do iterative inference).
Why I currently prioritize fundamental interpretability
Clearly both fundamental and applied interpretability research are essential. We need both in order to progress scientifically and to ensure future models are safe.
But given our current position on the tech tree, I find that I care more about fundamental interpretability.
The reason is that current interpretability methods are unsuitable for comprehensively interpreting networks on a mechanistic level. So far, our methods only seem to be able to identify particular representations that we look for or describe how particular behaviors are carried out. But they don’t let us identify all representations or circuits in a network or summarize the full computational graph of a neural network (whatever that might mean). Let’s call the ability to do these things ‘comprehensive interpretability’.
We need comprehensive interpretability in order to have strong-ish confidence about whether dangerous representations or circuits exist in our model. If we don’t have strong-ish confidence, then many theories of impact for interpretability are inordinately weakened:
We’re a lot less able to use interpretability as a ‘force multiplier on alignment research’ because we can’t trust that our methods haven’t missed something crucial. This is particularly true when models are plausibly optimizing against us and hiding dangerous thoughts in places we aren’t looking. A similar pattern holds for theories of impact based on ‘Empirical evidence for/against threat models’, ‘Improving human feedback’, and ‘Informed oversight’.
We can’t be confident about our interpretability audits. Not only does this raise the risk that we’ll miss something, but it makes it much harder to justify including interpretability in regulations, since effective regulation usually requires technical clarity. It also makes it harder for clear norms around safety to form.
We don’t get the coordination/cooperation benefits resulting from some actors being able to actually trust other actors’ systems.
We definitely can’t use our interpretability methods in the loss function. To be clear, we probably shouldn’t do this even if we believed we had comprehensive interpretability. We’d probably want provably comprehensive interpretability (or some other reason to believe that our interpretability methods wouldn’t simply be circumvented) before we could safely justify using them in the loss function.
For most of these theories of impact, the relationship feels like it might be nonlinear: A slight improvement to interpretability that nevertheless falls short of comprehensive interpretability does not lead to proportional safety gains; only when we cross a threshold to something resembling comprehensive interpretability would we get the bulk of the safety gains. And right now, even though there’s a lot of valuable applied work to be done, it feels to me like progress in fundamental interpretability is the main determinant of whether we cross that threshold.
Similar terms for ‘comprehensive interpretability’ include Anthropic’s notion of ‘enumerative safety’, Evan Hubinger’s notion of ‘worst-case inspection transparency’, and Erik Jenner’s notion of ‘quotient interpretability’.
How likely do you think bilinear layers & dictionary learning will lead to comprehensive interpretability?
Are there other specific areas you’re excited about?
Bilinear layers—not confident at all! It might make structure more amenable to mathematical analysis so it might help? But as yet there aren’t any empirical interpretability wins that have come from bilinear layers.
Dictionary learning—This is one of my main bets for comprehensive interpretability.
Other areas—I’m also generally excited by the line of research outlined in https://arxiv.org/abs/2301.04709
Now that I actually think about it, I have some ideas about how we can cluster neurons together if we are using bilinear layers. Because of this, I am starting to like bilinear layers a bit more, and I am feeling much more confident about the problem of interpreting neural networks, as long as the networks have an architecture that is suitable for interpretability. I am going to explain everything in terms of real-valued mappings, but everything I say can be extended to complex and quaternionic matrices (though one needs to be a little more careful about conjugations, transposes, and adjoints, so I will leave the complex and quaternionic cases as an exercise to the reader).
Suppose that $A_1,\dots,A_r$ are $n\times n$ real symmetric matrices. Then define a mapping $f_{A_1,\dots,A_r}:\mathbb{R}^n\to\mathbb{R}^r$ by setting $f_{A_1,\dots,A_r}(x)=(\langle A_1x,x\rangle,\dots,\langle A_rx,x\rangle)$.
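For concreteness, this quadratic map can be written in a few lines of NumPy (a quick sketch; the function name is my own):

```python
import numpy as np

def f_quadratic(As, x):
    """Evaluate f_{A_1,...,A_r}(x) = (<A_1 x, x>, ..., <A_r x, x>)."""
    return np.array([x @ (A @ x) for A in As])

# Example: with A_1 the 2x2 identity, f is just the squared norm of x.
A1 = np.eye(2)
print(f_quadratic([A1], np.array([3.0, 4.0])))  # [25.]
```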
Now, given a collection $A_1,\dots,A_r$ of $n\times n$ real matrices, define a partial mapping $L_{A_1,\dots,A_r;d}:M_d(\mathbb{R})^r\to[0,\infty)$ by setting
$$L_{A_1,\dots,A_r;d}(X_1,\dots,X_r)=\frac{\rho(A_1\otimes X_1+\dots+A_r\otimes X_r)}{\rho(X_1\otimes X_1+\dots+X_r\otimes X_r)^{1/2}},$$
where $\rho$ denotes the spectral radius and $\otimes$ denotes the tensor product. Then we say that $(X_1,\dots,X_r)\in M_d(\mathbb{R})^r$ is a real $L_{2,d}$-spectral radius dimensionality reduction (LSRDR) if $L_{A_1,\dots,A_r;d}(X_1,\dots,X_r)$ is locally maximized. One can compute LSRDRs using a variant of gradient ascent combined with the power iteration technique for finding the dominant left and right eigenvectors and eigenvalues of $A_1\otimes X_1+\dots+A_r\otimes X_r$ and $X_1\otimes X_1+\dots+X_r\otimes X_r$.
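Here is a direct (if inefficient) NumPy sketch of this objective, using dense eigenvalue computations rather than the power iteration just described; the function names are mine:

```python
import numpy as np

def spectral_radius(M):
    """rho(M): largest absolute value of an eigenvalue of M."""
    return np.max(np.abs(np.linalg.eigvals(M)))

def lsrdr_objective(As, Xs):
    """L_{A_1..A_r;d}(X_1..X_r) =
       rho(sum_j A_j (x) X_j) / rho(sum_j X_j (x) X_j)^(1/2),
    where (x) is the Kronecker (tensor) product."""
    num = spectral_radius(sum(np.kron(A, X) for A, X in zip(As, Xs)))
    den = spectral_radius(sum(np.kron(X, X) for X in Xs)) ** 0.5
    return num / den
```

An LSRDR would then be found by gradient ascent on `lsrdr_objective` over the $X_j$; for large $n$ and $d$, replacing `eigvals` with power iteration is essential.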
If $X_1,\dots,X_r$ is an LSRDR of $A_1,\dots,A_r$, then you should be able to find real matrices $R,S$ where $X_j=RA_jS$ for $1\le j\le r$. Furthermore, there should be a constant $\alpha$ where $RS=\alpha\cdot 1_d$. We say that the LSRDR $X_1,\dots,X_r$ is normalized if $\alpha=1$, so let’s assume that $X_1,\dots,X_r$ is a normalized LSRDR. Then define $P=SR$. Then $P$ should be a (not-necessarily orthogonal, so $P^2=P$ but we could have $P\ne P^T$) projection matrix of rank $d$. If $A_1,\dots,A_r$ are all symmetric, then the matrix $P$ should be an orthogonal projection. The vector space $\operatorname{im}(P)$ will be a cluster of neurons. We can also determine which elements of this cluster are most prominent.
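The claimed structure of a normalized LSRDR is easy to sanity-check numerically. Here is a sketch with a hand-built pair $R,S$ satisfying $RS=1_d$ (a hypothetical stand-in, not factors recovered from an actual trained LSRDR):

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 5, 2

# Hypothetical factors: S is n-by-d with full column rank, R = pinv(S),
# so that R S = 1_d (i.e. the normalization alpha = 1).
S = rng.standard_normal((n, d))
R = np.linalg.pinv(S)
assert np.allclose(R @ S, np.eye(d))

# P = S R is then a rank-d projection (here even orthogonal, since
# choosing R = pinv(S) makes P the orthogonal projection onto im(S)).
P = S @ R
assert np.allclose(P @ P, P)
assert np.linalg.matrix_rank(P) == d
```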
Now, define a linear superoperator $\Gamma(A_1,\dots,A_r;X_1,\dots,X_r):M_{n,d}(\mathbb{R})\to M_{n,d}(\mathbb{R})$ by setting $\Gamma(A_1,\dots,A_r;X_1,\dots,X_r)(X)=A_1XX_1^T+\dots+A_rXX_r^T$, and set $\Gamma(A_1,\dots,A_r;X_1,\dots,X_r)^T=\Gamma(A_1^T,\dots,A_r^T;X_1^T,\dots,X_r^T)$, which is the adjoint of $\Gamma(A_1,\dots,A_r;X_1,\dots,X_r)$ where we endow $M_{n,d}(\mathbb{R})$ with the Frobenius inner product. Let $U_R$ denote a dominant eigenvector of $\Gamma(A_1,\dots,A_r;X_1,\dots,X_r)$ and let $U_L$ denote a dominant eigenvector of $\Gamma(A_1,\dots,A_r;X_1,\dots,X_r)^T$. Then after multiplying $U_R,U_L$ by constant real factors, the matrices $U_R\cdot S^T,U_L\cdot R$ will be (typically distinct) positive semidefinite trace-1 matrices of rank $d$ with $\operatorname{im}(P)=\operatorname{im}(U_R\cdot S^T)=\operatorname{im}(U_L\cdot R)$. If we retrained the LSRDR with a different initialization, the matrices $P,U_R\cdot S^T,U_L\cdot R$ should still remain the same.
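The superoperator $\Gamma$ and a dominant eigenvector of it can be sketched with plain power iteration (this assumes the dominant eigenvalue is real and positive; the names are mine):

```python
import numpy as np

def gamma(As, Xs, M):
    """Apply Gamma(A_1..A_r; X_1..X_r)(M) = sum_j A_j M X_j^T."""
    return sum(A @ M @ X.T for A, X in zip(As, Xs))

def dominant_eigvec(As, Xs, n, d, iters=1000, seed=0):
    """Power iteration for a dominant eigenvector of Gamma on M_{n,d}(R),
    renormalized in the Frobenius norm at each step."""
    M = np.random.default_rng(seed).standard_normal((n, d))
    for _ in range(iters):
        M = gamma(As, Xs, M)
        M /= np.linalg.norm(M)
    return M
```

The adjoint $\Gamma^T$ is obtained by transposing every $A_j$ and $X_j$ before calling `gamma`.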
If $v\in\mathbb{R}^n$ with $\|v\|=1$, then the values $\langle U_R\cdot S^Tv,v\rangle,\langle U_L\cdot Rv,v\rangle$ will be numbers in the interval $[0,1]$ that measure how much the vector $v$ belongs to the cluster.
If $O$ is an $r\times r$ orthogonal matrix, then the matrices $P,U_R\cdot S^T,U_L\cdot R$ will remain the same if they were trained on $O\circ f_{A_1,\dots,A_r}$ instead of $f_{A_1,\dots,A_r}$, so the matrices $P,U_R\cdot S^T,U_L\cdot R$ care about just the inner product space structure of $\mathbb{R}^r$ while ignoring any of the other structure of $\mathbb{R}^r$. Let $P_f=U_R\cdot S^T$ and $Q_f=U_L\cdot R$.
We can then use LSRDRs to compute the backpropagation of a cluster throughout the network.
Suppose that $f=f_1\circ\dots\circ f_n$ where each $f_j$ is a bilinear layer. Then whenever $f_j:\mathbb{R}^n\to\mathbb{R}^r$ is a bilinear mapping, and $P\in M_r(\mathbb{R})$ is a positive semidefinite matrix that represents a cluster in $\mathbb{R}^r$, the positive semidefinite matrices $P_{P\circ f},Q_{P\circ f}$ represent clusters in $\mathbb{R}^n$.
I have not compared LSRDRs to other clustering and dimensionality reduction techniques such as higher-order singular value decompositions, but I like LSRDRs since my computer calculations indicate that they are often unique.
A coordinate-free perspective:
Suppose that $V,W$ are real finite-dimensional inner product spaces. Then we say that a function $f:V\to W$ is a quadratic form if for each $g\in W^*$, the mapping $g\circ f$ is a quadratic form. We say that a linear operator $A:V\to V\otimes W$ is symmetric if for each $w\in W^*$, the operator $(1_V\otimes w)A$ is symmetric. The quadratic forms $f:V\to W$ can be put into a canonical one-to-one correspondence with the symmetric linear operators $A:V\to V\otimes W$.
If $A:U_2\to V_2\otimes W$ and $B:U_1\to V_1\otimes W$ are arbitrary linear operators, then define $\Gamma(A,B):L(U_1,U_2)\to L(V_1,V_2)$ by letting $\Gamma(A,B)(X)=\operatorname{Tr}_W(AXB^T)$, where $\operatorname{Tr}_W$ denotes the partial trace.
Given a linear mapping $A:V\to V\otimes W$ and a $d$-dimensional real inner product space $U$, define a partial mapping $L_{A,U}:L(U,U\otimes W)\to[0,\infty)$ by setting $L_{A,U}(B)=\frac{\rho(\Gamma(A,B))}{\rho(\Gamma(B,B))^{1/2}}$. We say that a linear mapping $B:U\to U\otimes W$ is a real LSRDR of $A$ if the value $L_{A,U}(B)$ is locally maximized. If $B$ is a real LSRDR of $A$, one can as before (if everything goes right) find linear operators $R,S$ and a constant $\alpha$ where $RS=\alpha\cdot 1_U$ and where $B=(R\otimes 1_W)AS$. As before, we can normalize the LSRDR so that $\alpha=1$. In this case, we can set $U_R$ to be a dominant eigenvector of $\Gamma(A,B)$ and $U_L$ to be a dominant eigenvector of $\Gamma(A,B)^T$. We still define $P=SR$, and the mapping $P$ will be a non-orthogonal projection, and $U_R\cdot S^T,U_L\cdot R$ will still be positive semidefinite (up to a constant factor). The situation is exactly as before except that we are working with abstract finite-dimensional inner product spaces without any mention of coordinates.
The information that I have given here can be found in several articles that I have posted at https://circcashcore.com/blog/.
I have thought of LSRDRs as machine learning models themselves (such as word embeddings), but it looks like LSRDRs may also be used to interpret machine learning models.
When generalizing bilinear layers to a quaternionic setting, do we want the layers to be linear in both variables or do we want them to be linear in one variable and anti-linear in the other variable?
Set a random variable $X_A$ to be a trained model with bilinear layers, with random initialization and training data $A$. Then I would like to know whether various estimated upper bounds for various entropies of $X_A$ are much lower than they would be if $X_A$ were a more typical machine learning model where a linear layer is composed with a ReLU. It seems like entropy is a good objective measure of the lack of decipherability.
This was a nice description, thanks!
I think this is an incredibly optimistic hope that needs to be challenged more.
On my model, GPT-N has a mixture of a) crisp representations, b) fuzzy heuristics that are made crisp in GPT-(N+1), and c) noise and misgeneralizations. Unless we’re discussing models that perfectly fit their training distribution, I expect that comprehensively interpreting networks involves untangling many competing fuzzy heuristics which are all imperfectly implemented. Perhaps you expect this to be possible? However, I’m pretty skeptical that this is tractable, and I expect the best interpretability work to not confront these completeness guarantees.
Related (I consider “mechanistic interpretability essentially solved” to be similar to your ‘comprehensive interpretability’ goal)
It reminds me not only of my own writing on a similar theme, but of another one of these viewpoints/axes along which to carve interpretability work, mentioned in this post by jylin04:
I don’t necessarily totally agree with her phrasing but it does feel a bit like we are all gesturing at something vaguely similar (and I do agree with her that PDLT-esque work may have more insights in this direction than some people on our side of the community have appreciated).
FWIW, in a recent comment reply to Joseph Bloom, I also ended up saying a bit more about why I don’t actually see myself working much more in this direction, despite it seeming very interesting, but I’m still on the fence about that. (And one last point that didn’t make it into that comment is the difficulties posed by a world in which, increasingly, the plucky bands of interpretability researchers on the fringes literally don’t even know what the cutting-edge architectures and training processes in the biggest labs are.)