I was considering doing something like this, but I kept getting stuck at the issue that it doesn’t seem like gradients are an accurate attribution method. Have you tried comparing the attribution made by the gradients to a more straightforward attribution based on the counterfactual of enabling vs disabling a network component, to check how accurate they are? I guess I would especially be curious about its accuracy on real-world data, even if that data is relatively simple.
In earlier iterations we tried ablating parameter components one-by-one to calculate attributions and didn’t notice much of a difference (this was mostly on the hand-coded gated model in Appendix B). But yeah, we agree that it’s likely pure gradients won’t suffice when scaling up or when using different architectures. If/when that happens, we plan to either use integrated gradients or, more likely, try using a trained mask for the attributions.
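For concreteness, here’s roughly what that one-by-one ablation baseline can look like. This is just a minimal sketch, not the exact procedure from the paper; `model`, `components`, `loss_fn`, and `batch` are placeholders for whatever decomposition and loss you’re actually using.

```python
import torch

def ablation_attributions(model, components, loss_fn, batch):
    """Attribute the loss change to each parameter component by zero-ablating it.

    `components` is assumed to be a dict mapping names to parameter tensors
    (or views of them) that can be zeroed independently -- a hypothetical
    interface, not anything from the paper.
    """
    with torch.no_grad():
        base_loss = loss_fn(model(batch))
        attributions = {}
        for name, param in components.items():
            original = param.clone()
            param.zero_()                               # ablate this component
            attributions[name] = (loss_fn(model(batch)) - base_loss).item()
            param.copy_(original)                       # restore the weights
    return attributions
```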
In my experience gradient-based attributions (especially if you use integrated gradients) are almost identical to the attributions you get from ablating away each component. It’s kind of crazy, but it’s the reason people use edge attribution patching over older approaches like ACDC.
Look at page 15 of https://openreview.net/forum?id=lq7ZaYuwub (left is gradient attributions, right is attributions from ablating each component). This is for Mamba but I’ve observed similar things for transformers.
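If you want to check this yourself, here’s a minimal toy comparison (my own sketch with a made-up two-layer MLP, not the setup from that paper). The gradient estimate is the usual first-order Taylor approximation to zero-ablating each parameter tensor, which is the same idea attribution patching uses:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(8, 16), nn.GELU(), nn.Linear(16, 1))
x, y = torch.randn(64, 8), torch.randn(64, 1)
loss_fn = lambda: ((model(x) - y) ** 2).mean()

for name, p in model.named_parameters():
    # First-order (gradient) estimate of the effect of zero-ablating p:
    # loss(p=0) - loss(p) ~= grad . (0 - p) = -(grad * p).sum()
    loss = loss_fn()
    grad = torch.autograd.grad(loss, p)[0]
    grad_estimate = -(grad * p).sum().item()

    # Actual effect of zero-ablating p
    with torch.no_grad():
        original = p.clone()
        p.zero_()
        ablation_effect = (loss_fn() - loss).item()
        p.copy_(original)

    print(f"{name:10s} gradient estimate {grad_estimate:+.4f}   true ablation {ablation_effect:+.4f}")
```

How close the two columns are depends on the model and on how big a unit you ablate; zeroing a whole weight matrix is a large perturbation, so in practice people compare at the level of smaller components (heads, edges, activations) where the linear approximation holds better.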
I would be satisfied with integrated gradients too. There are certain cases where pure gradient-based attributions predictably don’t work (most notably when a softmax is saturated) and those are the ones I’m worried about (since it seems backwards to ignore all the things that a network has learned to reliably do when trying to attribute things, as they are presumably some of the most important structure in the network).
> There are certain cases where pure gradient-based attributions predictably don’t work (most notably when a softmax is saturated)
Do you have a source or writeup somewhere on this? (Or would you mind explaining more, or giving some examples where this is true?) Is this issue actually something that comes up for modern-day LLMs?
In my observations it works fine for the toy tasks people have tried it on. The challenge seems to be in interpreting the attributions, not issues with the attributions themselves.
Figure 3 in this paper (AtP*, Kramar et al.) illustrates the point nicely: https://arxiv.org/abs/2403.00745

It’s elementary that the derivative approaches zero when one of the inputs to a softmax is significantly bigger than the others. Then, when you apply the chain rule, this entire pathway for the gradient gets knocked out.
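A quick toy illustration of this (my own example, not tied to any particular model): the gradient of the winning softmax output with respect to its own logit is p(1−p), which goes to zero as the softmax saturates.

```python
import torch

def grad_through_softmax(logits):
    """Gradient of the largest softmax probability w.r.t. the first logit."""
    logits = logits.clone().requires_grad_(True)
    probs = torch.softmax(logits, dim=-1)
    probs.max().backward()
    return logits.grad[0].item()

print(grad_through_softmax(torch.tensor([1.0, 0.5, 0.2])))   # ~0.25: healthy gradient
print(grad_through_softmax(torch.tensor([10.0, 0.5, 0.2])))  # ~1e-4: saturated, pathway knocked out
```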
I don’t know to what extent it comes up with modern-day LLMs. I’d certainly bet one could generate a lot of interpretability work within the linear-approximation regime. I guess at some point it reduces to the question of why we’re doing mechanistic interpretability in the first place.