A Naive Proposal for Constructing Interpretable AI

Epistemic Status: Extremely speculative; I’m not an experienced ML researcher

OpenAI discovered that language models can explain neurons by attaching a natural-language label to them. As part of their verification process, they simulated each neuron: a language model predicted, from the label alone, what activations a neuron with that label should produce on a given input.
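Roughly, a label is then judged by how closely the simulated activations track the neuron's real activations, for example via a correlation. The sketch below is my own minimal illustration of that core idea in plain Python, not OpenAI's code (their scoring is more involved); the function name `correlation_score` is hypothetical.

```python
import math
from typing import Sequence


def correlation_score(true_acts: Sequence[float], simulated_acts: Sequence[float]) -> float:
    """Pearson correlation between a neuron's real activations and the
    activations a language model predicted from the neuron's label."""
    n = len(true_acts)
    if n != len(simulated_acts) or n < 2:
        raise ValueError("need two equal-length sequences of length >= 2")
    mean_t = sum(true_acts) / n
    mean_s = sum(simulated_acts) / n
    cov = sum((t - mean_t) * (s - mean_s) for t, s in zip(true_acts, simulated_acts))
    var_t = sum((t - mean_t) ** 2 for t in true_acts)
    var_s = sum((s - mean_s) ** 2 for s in simulated_acts)
    if var_t == 0 or var_s == 0:
        return 0.0  # a constant sequence carries no signal to correlate
    return cov / math.sqrt(var_t * var_s)


# A simulation that is off by a constant factor still scores perfectly.
print(correlation_score([0, 1, 0, 3], [0, 2, 0, 6]))  # 1.0
```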

This suggests a possible method for building an interpretable AI:

  1. Start off with all layers unfrozen

  2. Label the neurons in the first unfrozen layer using a language model

  3. Train the neurons in this layer to match the simulated neurons generated from their labels. Each label should now be a much better summary of what its neuron does.

  4. Freeze this layer, then train the rest of the network on the original objective function

  5. Return to step 2 and repeat
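The loop above could be sketched as follows. This is a minimal illustration of the schedule only, not a real training pipeline: `label_layer`, `fit_to_simulation` and `train_unfrozen` are hypothetical hooks that would wrap the language-model labeler, the simulation-matching step and ordinary training on the original objective.

```python
from typing import Callable, List


def interpretable_training(
    num_layers: int,
    label_layer: Callable[[int], List[str]],
    fit_to_simulation: Callable[[int, List[str]], None],
    train_unfrozen: Callable[[List[int]], None],
) -> List[List[str]]:
    """Layer-by-layer schedule: label, fit to labels, freeze, retrain the rest.

    All three callbacks are hypothetical hooks into a real training setup.
    """
    frozen = set()  # step 1: every layer starts unfrozen
    all_labels = []
    for layer in range(num_layers):
        # Step 2: ask a language model to label each neuron in the first unfrozen layer.
        labels = label_layer(layer)
        # Step 3: train this layer's neurons to match activations simulated from their labels.
        fit_to_simulation(layer, labels)
        # Step 4: freeze the layer, then train the remaining layers on the original objective.
        frozen.add(layer)
        unfrozen = [l for l in range(num_layers) if l not in frozen]
        if unfrozen:
            train_unfrozen(unfrozen)
        all_labels.append(labels)
        # Step 5: the loop returns to step 2 with the next unfrozen layer.
    return all_labels
```

Passing stub callbacks that record their arguments is an easy way to check the schedule: with three layers, each layer is fitted and frozen, and then only the layers above it are retrained.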

A few points:

  • Interpreting a network in the general case seems really, really hard, so why not make it easier by constructing a network to be interpretable?

  • We can use strategies like dictionary learning to help deal with polysemanticity so that the labels are likely to be useful.

  • Obviously, this will damage performance, but what the tradeoff looks like is an empirical question. We may also be able to repair some of this damage. For example, we could add some extra neurons that aren’t forced to correspond to a legible label in the way described above. These neurons would likely be messy, but at least the mess would be confined to a particular part of the network.

  • It’s been suggested to me that the first layers won’t be very interpretable, so it might only make sense to start this later in the network.

  • As well as using activations to label neurons, it would likely make sense to also give the labeling model information about the incoming neurons with the strongest weights, such as their labels.

  • If we get lucky, then because we are training each neuron to perform a particular role, we may crowd out any opportunity for neurons to play a role in deceptive alignment, except where this would be really obvious from the neuron labels.
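Two of the points above can be sketched in a few lines, assuming per-neuron activations for a single input are available as plain lists. The first function is a masked mean-squared-error loss for step 3 that constrains only labeled neurons and leaves any extra "free" neurons unconstrained; the second picks the incoming neurons with the largest absolute weights, whose information could be passed to the labeling model. Both function names and the choice of mean squared error are my own assumptions, not part of an existing method.

```python
from typing import List, Sequence


def masked_simulation_loss(
    actual: Sequence[float], simulated: Sequence[float], labeled: Sequence[bool]
) -> float:
    # Mean squared error between each neuron's actual activation and the activation
    # simulated from its label, counted only for labeled neurons (mask True).
    # Unlabeled "free" neurons contribute nothing, so they stay unconstrained.
    if not (len(actual) == len(simulated) == len(labeled)):
        raise ValueError("actual, simulated and labeled must have the same length")
    errors = [(a - s) ** 2 for a, s, m in zip(actual, simulated, labeled) if m]
    return sum(errors) / len(errors) if errors else 0.0


def strongest_inputs(incoming_weights: Sequence[float], k: int) -> List[int]:
    # Indices of the k incoming neurons with the largest absolute weight,
    # i.e. the candidates whose labels could be shown to the labeling model.
    ranked = sorted(
        range(len(incoming_weights)),
        key=lambda i: abs(incoming_weights[i]),
        reverse=True,
    )
    return ranked[:k]


# The second neuron is a free neuron, so its large error is ignored.
print(masked_simulation_loss([1.0, 5.0, 2.0], [0.0, 0.0, 2.0], [True, False, True]))  # 0.5
print(strongest_inputs([0.1, -0.9, 0.5, 0.0], 2))  # [1, 2]
```

In practice the loss would be averaged over many inputs and computed inside a training framework; plain lists just keep the idea visible.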

Caveats:

  • This would likely require a lot of compute, because we’re labeling every neuron and then using our language model to predict activations for each label.

  • We’re asking our labeling model to infer each neuron’s role from just a few activations. This seems like a very challenging problem: the labeler could be completely off, missing key information simply because it doesn’t arise in any of the examples.
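A back-of-envelope way to see the compute cost, under a hypothetical cost model: one language-model call to propose each label, plus one simulation call per text excerpt used to check it. The function name and the cost model are assumptions for illustration only.

```python
from typing import Sequence


def labeling_calls(neurons_per_layer: Sequence[int], excerpts_per_neuron: int) -> int:
    # Hypothetical cost model: for each labeled neuron, one language-model call to
    # propose a label plus one simulation call per text excerpt used to check it.
    return sum(n * (1 + excerpts_per_neuron) for n in neurons_per_layer)


# Four hypothetical layers of 1,024 neurons, each checked on 10 excerpts.
print(labeling_calls([1024] * 4, 10))  # 45056
```

That is 4 × 1,024 × 11 = 45,056 calls for a single pass, before any relabeling, and the count grows linearly with both network width and the number of excerpts, which is exactly the tension with the next caveat: more excerpts make the labels more reliable but multiply the cost.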

Why might this be worthwhile:

  • Interpreting neural networks in complete generality doesn’t seem like a tractable problem

  • This suggests that it might be worthwhile trying to pick a class of models that will be more interpretable

  • And one of the best ways to do this could be to decide for ourselves what the inside of the network should look like. This proposal doesn’t go all the way there, since the labels are chosen autonomously, but it moves a significant distance in that direction.

Anyway, I’m very keen to hear any feedback on this idea, or whether you know of anyone investigating it.

If you think this is promising, feel free to pick it up. This project isn’t an immediate priority for me.