Taking the parameters which seem to matter and rotating them until they don’t

A big bottleneck in interpretability is that neural networks are non-local. That is, given a standard layer setup (activations multiplied by a weight matrix, then passed through a nonlinearity), if we change a small bit of the original activations, then a large portion of the new activations is affected.
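As a concrete illustration (my own sketch, not from the post), consider the Jacobian of a single dense ReLU layer: every active output unit depends on essentially every input coordinate, so most Jacobian entries are nonzero.

```python
# A minimal sketch (mine, not the post's) of the non-locality claim: for a
# dense layer with a ReLU nonlinearity, the Jacobian of outputs with respect
# to inputs has a dense row for every active unit, so most entries are nonzero.
import torch

torch.manual_seed(0)
d_in, d_out = 8, 8
W = torch.randn(d_out, d_in)
b = torch.randn(d_out)

def layer(x):
    return torch.relu(W @ x + b)

x = torch.randn(d_in)
J = torch.autograd.functional.jacobian(layer, x)  # shape (d_out, d_in)
print((J != 0).float().mean())  # roughly the fraction of active ReLUs (~0.5), far from sparse
```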

This is an impediment to finding the circuit-structure of networks. It is difficult to figure out how something works when changing one thing affects everything.

The project I’m currently working on aims to fix this issue, without affecting the training dynamics of networks or the function which the network is implementing[1]. The idea is to find a rotation matrix and insert it, together with its inverse, into the layer, then group the rotation with the original activations, and the inverse with the weights and nonlinear function.
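Here is a minimal sketch of that insertion (my own code and notation; the post's implementation may differ). Since a rotation times its inverse is the identity, inserting the pair changes nothing about the function the network computes: we simply relabel the rotated activations as the object we study and fold the inverse into the weights.

```python
# A minimal sketch of inserting a rotation R and its inverse around the
# activations. Because R is orthogonal, its inverse is its transpose, and the
# network's input-output function is unchanged.
import torch

torch.manual_seed(0)
d = 16
W = torch.randn(d, d)                      # original weight matrix of the next layer
R = torch.linalg.qr(torch.randn(d, d)).Q   # random orthogonal matrix (a rotation, up to reflection)

x = torch.randn(d)                         # original activations

# Original computation feeding the next layer (pre-nonlinearity).
original = W @ x

# Insert R followed by its inverse:  W x = (W R^T) (R x)
x_rot = R @ x      # "rotated activations" -- the object we study
W_rot = W @ R.T    # the inverse rotation is absorbed into the weights

assert torch.allclose(W_rot @ x_rot, original, atol=1e-5)
```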

We can then optimize the rotation matrix and its inverse so that local changes in the rotated activations have local effects on the outputted activations. This locality is measured by the average sparsity of the Jacobian across all the training inputs.

We do this because the Jacobian is a representation of how each of the inputs affects each of the outputs. Large entries represent large effects; small entries represent small effects. So if many entries are zero, each input affects only a few outputs, i.e. local changes to the input cause local changes to the output.
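Concretely, the objective might look something like the following rough sketch. It rests on my own assumptions (a separate rotation on the layer's input and output sides, and the mean absolute value of Jacobian entries as the sparsity proxy); the actual project may measure sparsity differently.

```python
# A rough sketch of the objective: the average "sparsity" (here, mean absolute
# value) of the Jacobian of the rotated layer, taken over training inputs,
# to be minimized with respect to the rotations.
import torch

def rotated_layer(x_rot, W, R_in, R_out):
    # The layer as seen in the rotated bases: undo the input rotation,
    # apply the original weights and nonlinearity, then rotate the outputs.
    return R_out @ torch.relu(W @ (R_in.T @ x_rot))

def jacobian_sparsity_loss(W, R_in, R_out, xs):
    """Average L1-style penalty on the layer's Jacobian over a batch of inputs."""
    total = 0.0
    for x in xs:
        f = lambda z: rotated_layer(z, W, R_in, R_out)
        # create_graph=True lets gradients flow back to the rotation parameters.
        J = torch.autograd.functional.jacobian(f, R_in @ x, create_graph=True)
        total = total + J.abs().mean()
    return total / len(xs)
```

To keep the matrices actual rotations during optimization, one option is to parametrize them as orthogonal matrices (e.g. with torch.nn.utils.parametrizations.orthogonal, or as the matrix exponential of a skew-symmetric matrix); this is a design choice of mine, not something the post specifies.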

This should give us a representation of the activations, and interpretations of the matrix multiplies, that “make sense” in the context of the rest of the network.

Another way of thinking about this is that our goal is to find the basis our network is thinking in.

Currently I’m getting this method to work on a simple 3-layer, fully connected MNIST digit-classification network. If applying it there seems to give insight into the mechanics of the network, the plan is to adapt it to a more complicated network such as a transformer or ResNet.
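For reference, the kind of network in question looks roughly like this (the hidden width is a placeholder of mine; the post does not give the exact sizes):

```python
# A minimal sketch of a 3-layer fully connected MNIST classifier.
# The hidden width (128) is a guess, not the post's value.
import torch.nn as nn

mnist_mlp = nn.Sequential(
    nn.Flatten(),            # 28x28 image -> 784-dim vector
    nn.Linear(784, 128),
    nn.ReLU(),
    nn.Linear(128, 128),
    nn.ReLU(),
    nn.Linear(128, 10),      # 10 digit classes
)
```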

I only have preliminary results right now, but they are looking promising:

This is the normalized Jacobian of the middle layer before a rough version of my method:

And here is the normalized Jacobian after a rough version of my method (the Jacobian’s output has been set to a basis which maximizes its sparsity):


Thanks to David Udell for feedback on the post. I did not listen to everything you said, and if I had, the post would have been better.

  1. ^

    This seems important if we’d like to use interpretability work to produce useful conjectures about agency and selection more generally.