My notes / thoughts: (apologies for overly harsh/critical tone, I’m stepping into the role of annoying reviewer)
Summary
Use residual stream SAEs spaced across layers; discover node circuits with learned binary masking on templated data.
binary masking outperforms integrated gradients on faithfulness metrics, and achieves comparable (though maybe narrowly worse) completeness metrics (a sketch of these metrics follows this summary)
demonstrated the approach on code output prediction
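To make those two metrics concrete, here is a minimal sketch of how faithfulness- and completeness-style numbers are commonly computed for a node circuit defined by a binary mask over SAE latents. The function names, the mean-ablation choice, and the normalization (roughly in the style of Marks et al.) are illustrative assumptions, not necessarily the post's exact definitions.

```python
# Illustrative sketch only: the tensors below stand in for quantities a circuit
# discovery pipeline would produce; this is not the post's actual code.
import torch

def ablate_latents(latents: torch.Tensor, keep_mask: torch.Tensor,
                   fill: torch.Tensor) -> torch.Tensor:
    """Keep the masked-in SAE latents and replace the rest with `fill`
    (e.g. their dataset-mean activation, i.e. mean ablation)."""
    return keep_mask * latents + (1 - keep_mask) * fill

def normalized_recovery(m_patched: float, m_full: float, m_empty: float) -> float:
    """(m(patched) - m(empty)) / (m(full) - m(empty)); 1.0 means the patched run
    fully recovers the unablated model's task metric."""
    return (m_patched - m_empty) / (m_full - m_empty)

# Faithfulness: keep only the circuit's latents (ablate everything else) and ask
# how much of the full model's task metric is recovered:
#   faithfulness = normalized_recovery(m_circuit_only, m_full, m_empty)
# Completeness: ablate the circuit's latents instead; if the circuit is complete,
# little of the behavior should survive in this complement run:
#   completeness = normalized_recovery(m_circuit_ablated, m_full, m_empty)
```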
Strengths:
to the best of my knowledge, first work to demonstrate learned binary masks for circuit discovery with SAEs
to the best of my knowledge, first work to compute completeness metrics for binary mask circuit discovery
Weaknesses
no comparison of spaced residual stream SAEs to more fine-grained SAEs
theoretical arguments / justifications for coarse-grained SAEs are weak. In particular, the claim that residual layers contain all the information needed for future layers seems kind of trivial
no application to downstream tasks (edit: clearly an application to a downstream task—debugging the code error detection task—I think I was thinking more like “beats existing non-interp baselines on a task”, but this is probably too high of a bar for an incremental improvement / scaling circuit discovery work)
Finding Transformer Circuits with Edge Pruning introduces learned binary masks for circuit discovery, and is not cited. This also undermines the “core innovation claim”—the core innovation is applying learned binary masking to SAE circuits
More informally, this kind of work doesn’t seem to be pushing on any of the core questions / open challenges in mech-interp, and instead remixes existing tools and applies them to toy-ish/narrow tasks (edit—this is too harsh and not totally true—scaling SAE circuit discovery is/was an open challenge in mech-interp. I guess I was going for “these results are marginally useful, but not all that surprising, and unlikely to move anyone already skeptical of mech-interp”)
Of the future research / ideas, I’m most excited about the non-templatic data / routing model
Thanks a lot for this review!
On strengths, we also believe that we are the first to examine “few SAEs” for scalable circuit discovery.
On weaknesses,
While we plan to do a more thorough sweep of SAE placements and comparisons, the first weakness remains true for this post.
Our main argument in support of using few SAEs is to view them as interpretable bottlenecks. Because they are so minimal and interpretable, they allow us to understand the blocks of the transformer between them functionally (in terms of input and output). We were going to include more intuition about this but were worried it might add unnecessary complications. We mention the fact about the residual stream to highlight that information cannot be passed to layer L+1 by any path other than the residual output of layer L. Thus, by training a mask at layer L, we find a minimal set of representations needed by future layers. To future layers, nothing other than these latents matters. We do agree that the nature of circuits found with coarse-grained SAEs will differ, and this needs further study.
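To illustrate that argument concretely: the mask only ever intervenes on the SAE latents at the layer-L bottleneck, and the decoded reconstruction is the only thing later layers receive, so whatever survives the mask is, by construction, all that downstream computation can use (up to SAE reconstruction error). Below is a minimal sketch of the general technique, assuming hypothetical helpers (`run_to_layer`, `run_from_layer`, an `sae` with `encode`/`decode`) rather than the post's actual implementation.

```python
# Hypothetical sketch of learning a relaxed binary mask over SAE latents at one
# residual-stream bottleneck; helper and variable names are assumptions, not the
# post's code.
import torch

def masked_bottleneck(resid: torch.Tensor, sae, mask_logits: torch.Tensor,
                      mean_latents: torch.Tensor) -> torch.Tensor:
    """Encode the residual stream, softly mask the latents, decode back."""
    latents = sae.encode(resid)                # (batch, seq, d_sae)
    gate = torch.sigmoid(mask_logits)          # relaxed binary mask in (0, 1)
    patched = gate * latents + (1 - gate) * mean_latents  # mean-ablate masked-out latents
    return sae.decode(patched)

def mask_objective(task_loss: torch.Tensor, mask_logits: torch.Tensor,
                   sparsity_coeff: float = 1e-2) -> torch.Tensor:
    """Stay faithful on the task while keeping as few latents switched on as possible."""
    return task_loss + sparsity_coeff * torch.sigmoid(mask_logits).sum()

# Training-loop sketch (run_to_layer, run_from_layer, task_loss are hypothetical):
#   mask_logits = torch.zeros(d_sae, requires_grad=True)
#   opt = torch.optim.Adam([mask_logits], lr=1e-2)
#   for batch in data:
#       resid = run_to_layer(model, batch, layer=L)
#       resid_hat = masked_bottleneck(resid, sae, mask_logits, mean_latents)
#       logits = run_from_layer(model, resid_hat, layer=L)
#       loss = mask_objective(task_loss(logits, batch), mask_logits)
#       opt.zero_grad(); loss.backward(); opt.step()
#   # Threshold sigmoid(mask_logits) afterwards to read off the discrete node circuit.
```

The latents that survive thresholding are then the candidate minimal set of representations that the rest of the model needs at that bottleneck.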
We plan to explore the “gender bias removal” of Marks et al. [1] to compare downstream application effectiveness. However, we do have a small application where we found a “bug” in the model, covered in section 5, where it over-relies on duplicate token latents. We can try to do something similar to Marks et al. [1] in trying to “fix” this bug.
Thanks for sharing the citation!
A core question shared in the community is whether the idea of circuits remains plausible as models continue to scale up. Current automated methods are either too computationally expensive or generate a subgraph that is too large to examine. We explore the idea of a few equally spaced SAEs with the goal of solving both of those issues. Though, as you mentioned, a more thorough comparison between circuits built from different numbers of SAEs is needed.
Thanks for the thorough response, and apologies for missing the case study!
I think I regret / was wrong about my initial vaguely negative reaction—scaling SAE circuit discovery to large models is a notable achievement!
Re residual skip SAEs: I’m basically on board with “only use residual stream SAEs”, but skipping layers still feels unprincipled. Imagine if you only trained an SAE on the final layer of the model: by including all the features, you could perfectly recover the model behavior up to the SAE reconstruction loss, but you would have ~no insight into how the model computed the final-layer features. More generally, by skipping layers, you risk missing potentially important intermediate features. Of course, to scale you need to make sacrifices somewhere, but approaches in the vicinity of Cross-Coders feel more promising.
Yes, by design, the circuits discovered in this manner might miss how/when something is computed. But we argue that finding the important representations at bottlenecks and how they change over layers can provide important/useful information about the model.
One of our future directions, in the spirit of crosscoders, is “Layer Output Buffer SAEs”, which aim to tackle the computation between bottlenecks.
Nice post!