We’ve been seeing similar things when pruning graphs of language model computations generated with parameter decomposition. I suspect something like this might be going on in the recent neuron interpretability work as well, though I haven’t verified that. If you just zero- or mean-ablate lots of nodes in a very big causal graph, you can get basically any end result you want while keeping very few nodes. This works because you can select sets of nodes to ablate that are computationally important but whose effects cancel each other out in exactly the way needed to get the right answer.[1]
I think the trick is not to do complete ablations, but instead to ablate stochastically chosen, or even adversarially chosen, subsets of nodes/edges:
You select the nodes you want to keep.
Among the nodes you did not choose to keep, the adversary picks which ones to zero/mean ablate and which to leave intact, choosing the subsets that make the loss as high as possible.[2] We do this by optimising masks for the nodes with gradient ascent.
This way, you also don’t need to freeze layer norms to prevent cheating.
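A minimal numpy sketch of the adversarial step, under assumptions that are mine rather than from the text: a single layer of hidden nodes feeding a linear readout, mean ablation, the adversary's binary ablate/leave choice relaxed to a sigmoid mask in [0, 1], and squared error against the unablated output as the loss. `adversarial_mask` and `ablated_output` are hypothetical helpers. Kept nodes are pinned to mask 1, the adversary runs gradient ascent on logits for the remaining nodes, and one mask is shared across the whole batch.

```python
import numpy as np


def ablated_output(h, w_out, mask, mu):
    # Mean ablation: mask 1 keeps a node's activation, mask 0 replaces it with
    # the node's mean activation; fractional values interpolate between the two.
    return (mask * h + (1.0 - mask) * mu) @ w_out


def adversarial_mask(h, w_out, keep, steps=200, lr=0.5):
    # h: (batch, nodes) activations; w_out: (nodes,) linear readout;
    # keep: indices of the candidate circuit, which are never ablated.
    mu = h.mean(axis=0)
    y_full = h @ w_out
    logits = np.zeros(h.shape[1])  # sigmoid(0) = 0.5: start undecided

    def evaluate(logits):
        mask = 1.0 / (1.0 + np.exp(-logits))
        mask[keep] = 1.0
        diff = ablated_output(h, w_out, mask, mu) - y_full
        return np.mean(diff ** 2), mask, diff

    for _ in range(steps):
        _, mask, diff = evaluate(logits)
        # dLoss/dmask_j = mean_b 2 * diff_b * w_out_j * (h_bj - mu_j)
        grad_mask = 2.0 * (diff[:, None] * (h - mu)).mean(axis=0) * w_out
        # Chain rule through the sigmoid; kept nodes receive no update.
        grad_logits = grad_mask * mask * (1.0 - mask)
        grad_logits[keep] = 0.0
        logits = logits + lr * grad_logits  # ascent: maximise the loss

    loss, mask, _ = evaluate(logits)
    return mask, loss


rng = np.random.default_rng(0)
x = rng.normal(size=(64, 8))
h = np.maximum(x @ rng.normal(size=(8, 16)), 0.0)  # ReLU hidden nodes
w_out = rng.normal(size=16)
keep = [0, 3, 7]  # hypothetical candidate circuit

mask, adv_loss = adversarial_mask(h, w_out, keep)
full = np.zeros(16)
full[keep] = 1.0  # naive scheme: ablate every non-kept node
naive_loss = np.mean((ablated_output(h, w_out, full, h.mean(axis=0)) - h @ w_out) ** 2)
print(f"naive: {naive_loss:.3f}  adversarial: {adv_loss:.3f}")
print(np.round(mask, 2))
```

Comparing `adv_loss` with `naive_loss` shows how differently the candidate circuit scores when the non-kept nodes are ablated selectively rather than all at once. In this toy the loss is convex in the mask, so each ascent step can only raise it, but the search may still stop at a local maximum.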
It’s for a different context, but we discuss the problem with using naive ablation schemes like these to infer causality in Appendix A of the first parameter decomposition paper. This is why we switched to training decompositions with stochastically chosen ablations, and later to training them adversarially.
There’s some subtlety to this. You probably want to place certain restrictions on the adversary, because otherwise there are situations where it can also break faithful circuits by exploiting random noise. We use a scheme where the adversary has to pick a single ablation pattern for a whole batch, specifying which nodes it does or does not want to ablate whenever they are not kept. This stops it from fine-tuning its choices to unstructured noise on particular inputs.
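The batch-level restriction can be illustrated with the same kind of toy (again assumed: one hidden layer, linear readout, mean ablation, squared error). Brute-forcing every binary ablation pattern over the non-kept nodes, a per-input adversary that picks its own pattern for each example always scores at least as high as one forced to use a single pattern for the whole batch. The gap is the extra loss it can only get by tailoring ablations to individual inputs. `ablation_losses` is a hypothetical helper.

```python
import itertools

import numpy as np


def ablation_losses(h, w_out, keep):
    # For every binary ablation pattern over the non-kept nodes, return the
    # per-example squared error between the mean-ablated and full outputs.
    mu = h.mean(axis=0)
    y_full = h @ w_out
    free = [j for j in range(h.shape[1]) if j not in keep]
    rows = []
    for pattern in itertools.product([0.0, 1.0], repeat=len(free)):
        mask = np.ones(h.shape[1])
        mask[free] = pattern  # 0 = ablate (replace with mean), 1 = leave intact
        y = (mask * h + (1.0 - mask) * mu) @ w_out
        rows.append((y - y_full) ** 2)
    return np.array(rows)  # shape: (num_patterns, batch)


rng = np.random.default_rng(0)
h = np.maximum(rng.normal(size=(32, 6)) @ rng.normal(size=(6, 10)), 0.0)
w_out = rng.normal(size=10)
keep = [2, 5, 8]  # hypothetical candidate circuit; 7 free nodes -> 128 patterns

losses = ablation_losses(h, w_out, keep)
shared = losses.mean(axis=1).max()     # one pattern for the whole batch
per_input = losses.max(axis=0).mean()  # a separate pattern for every input
print(f"shared pattern: {shared:.3f}  per-input patterns: {per_input:.3f}")
```

Per-input maximisation always dominates because the average of per-example maxima can never be smaller than the maximum of the batch averages.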