Hello! One approach we (i.e., some folks on the Anthropic interp team) have found useful for thinking about causal localization on arithmetic prompts is to capture residual stream values over a big distribution of prompts like
“For example, {a}{op}{b}={c}”,
and then see what happens when you run residual stream patching experiments with various marginals computed from that big table of values.
Here, I took all pairs of
For example, for 7x24=168 we would get the residual stream at each layer at the positions of =, 1, 6, and 8.
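A minimal sketch of the data-collection side, assuming plain Python and one token per digit. The operand ranges are left as arguments because the text does not say which pairs were used; OPS, build_prompts, and capture_positions are hypothetical helpers, not code from the original experiments. It reproduces the 7x24=168 example, where the captured positions are =, 1, 6, and 8.

```python
from itertools import product
from operator import add, mul

# Hypothetical mapping from the {op} slot in the template to a Python operator.
OPS = {"+": add, "x": mul}


def build_prompts(a_values, b_values, op="x"):
    """Fill the template 'For example, {a}{op}{b}={c}' for every (a, b) pair.

    Returns (prompt, answer) tuples; the answer string is what the residual
    table gets keyed on later.
    """
    rows = []
    for a, b in product(a_values, b_values):
        c = OPS[op](a, b)
        rows.append((f"For example, {a}{op}{b}={c}", str(c)))
    return rows


def capture_positions(prompt, answer):
    """Offsets of '=' and of every answer digit in `prompt`.

    These are character offsets; with a real tokenizer you would map them to
    token indices (assumption: each digit is its own token).
    """
    eq = prompt.rindex("=")
    return [eq + k for k in range(len(answer) + 1)]


# Example from the text: 7x24=168 -> positions of '=', '1', '6', '8'.
prompt, answer = build_prompts([7], [24])[0]
positions = capture_positions(prompt, answer)
print([prompt[i] for i in positions])  # ['=', '1', '6', '8']
```

The residual stream at each of these positions, at every layer, then becomes one row of the big table.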
I then tried three patching schemes:
per-digit marginal: compute the average residual stream over problems whose output has digit d in position i, and patch it into the corresponding position.
full output marginal: compute the average residual stream over prompts where the output is exactly c (since c has three digits, this would be a separate vector over = and each of the first two digits; we would patch each into the appropriate place).
single-prompt swap: just patch in the residual stream from one arbitrary prompt with answer c.
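The three schemes can be sketched as functions that turn a table of residuals (one layer at a time) into the vectors to patch in. The table layout is hypothetical, not the original code: each row is (answer, resids), with resids[k] the residual at the position that predicts answer digit k (k = 0 is =). Restricting the per-digit marginal to answers of the same length is an assumption here, so that "position k" always means the same place value. Writing the vectors into the forward pass depends on your framework's hooks and is not shown.

```python
import random


def mean(vectors):
    """Elementwise mean of equal-length residual vectors (lists of floats)."""
    if not vectors:
        raise ValueError("no prompts match this condition")
    return [sum(col) / len(vectors) for col in zip(*vectors)]


def per_digit_marginal(table, c):
    """For each digit position k of c: mean residual at position k over prompts
    whose answer has the same digit c[k] in position k.

    Assumption: only answers with the same length as c are used.
    """
    return [
        mean([r[k] for ans, r in table if len(ans) == len(c) and ans[k] == d])
        for k, d in enumerate(c)
    ]


def full_output_marginal(table, c):
    """One mean vector per predicting position, over prompts whose answer is exactly c."""
    matches = [r for ans, r in table if ans == c]
    return [mean([r[k] for r in matches]) for k in range(len(c))]


def single_prompt_swap(table, c, rng=random):
    """Residuals from one arbitrary prompt whose answer is c."""
    matches = [r for ans, r in table if ans == c]
    return list(rng.choice(matches)[: len(c)])


# Toy 2-d "residuals" for one layer; the last vector (at the final digit) is unused.
table = [
    ("168", [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [9.0, 9.0]]),  # e.g. 7x24
    ("168", [[3.0, 0.0], [0.0, 3.0], [3.0, 3.0], [9.0, 9.0]]),  # e.g. 8x21
    ("184", [[2.0, 3.0], [2.0, 2.0], [0.0, 0.0], [9.0, 9.0]]),  # e.g. 8x23
]
print(full_output_marginal(table, "168"))  # [[2.0, 0.0], [0.0, 2.0], [2.0, 2.0]]
print(per_digit_marginal(table, "168"))    # [[2.0, 1.0], [0.0, 2.0], [2.0, 2.0]]
```

Note how the per-digit marginal for the first digit also averages in the 184 prompt, since it shares the leading 1; that is what makes it a looser condition than the full output marginal.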
You can see below that the probability of the patched-in answer being emitted is very low at layer 20, and rises over layers 22-24 (depending on the operand, the digit and the patching scheme), with all fully saturated by 24.
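One way to compute the plotted quantity: in the patched run, read off the logits at each predicting position and take the probability of the patched-in digit. The text doesn't say whether the curves are per digit or for the whole answer, so this sketch returns per-digit probabilities; their product is the joint probability under teacher forcing. Restricting to a ten-digit vocabulary and the 0.9 cutoff in the hypothetical first_layer_reaching helper are illustrative assumptions.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over a list of logits."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]


def digit_probs(position_logits, answer, vocab="0123456789"):
    """P(answer digit k) at the position that predicts it, for each k.

    position_logits[k] holds logits over `vocab` (assumed to be the ten digits)
    at the position predicting digit k: '=' for k = 0, then each earlier digit.
    """
    if len(position_logits) != len(answer):
        raise ValueError("need one logit vector per answer digit")
    return [math.exp(log_softmax(lg)[vocab.index(d)])
            for lg, d in zip(position_logits, answer)]


def first_layer_reaching(probs_by_layer, threshold=0.9):
    """Earliest layer whose probability is >= threshold, or None."""
    for layer in sorted(probs_by_layer):
        if probs_by_layer[layer] >= threshold:
            return layer
    return None


# Joint probability of the whole answer under teacher forcing:
uniform = [[0.0] * 10 for _ in range(3)]
print(math.prod(digit_probs(uniform, "168")))  # roughly 0.001
```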
This means that, from a causal perspective, layer 20, where you run the NLAs, is too early to capture the full story (consistent with your AR editing experiments), but that the logit lens, which has barely started to move by layer 24, trails behind the causal story. (Possibly some change of basis would let you see that the answer was already fully computed in the residual stream by then.)
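For reference, the logit lens decodes an intermediate residual by pushing it through the model's final norm and unembedding as if it were the last layer. The sketch below assumes a LayerNorm final norm (many models use RMSNorm instead) and a W_U stored as one row per vocabulary item. The optional translator argument is a hypothetical affine map applied first, one form the "change of basis" could take; learned per-layer maps of this kind are the idea behind the tuned lens.

```python
import math


def layer_norm(x, gain, bias, eps=1e-5):
    """LayerNorm over one residual vector (stands in for the model's final norm)."""
    mu = sum(x) / len(x)
    var = sum((v - mu) ** 2 for v in x) / len(x)
    scale = 1.0 / math.sqrt(var + eps)
    return [g * (v - mu) * scale + b for v, g, b in zip(x, gain, bias)]


def matvec(rows, x):
    """rows @ x for a matrix stored as a list of rows."""
    return [sum(r_i * x_i for r_i, x_i in zip(row, x)) for row in rows]


def logit_lens(resid, gain, bias, W_U, translator=None):
    """Decode an intermediate-layer residual as if it were the final layer's.

    W_U: vocab x d_model (one row per token, an assumed layout).
    translator: optional (A, c) affine map applied to the residual first.
    """
    if translator is not None:
        A, c = translator
        resid = [y + ci for y, ci in zip(matvec(A, resid), c)]
    return matvec(W_U, layer_norm(resid, gain, bias))


# Toy example: 2-d residual, 3-token vocabulary.
W_U = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
logits = logit_lens([1.0, -1.0], [1.0, 1.0], [0.0, 0.0], W_U)
print(max(range(len(logits)), key=logits.__getitem__))  # 0
```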
The probe is about as confident at layer 5 as at layer 20, and I’m pretty sure the model hasn’t done any active computation that early. I suspect this is just because you can probe for information that ‘spilled over’ via attention well before the model uses its MLPs to do anything. So while the NLA is very powerful (it can do a lot of computation itself), a linear probe has a bit of power on its own!
See https://arxiv.org/abs/2005.00719 for an example. That’s why I like patching experiments for estimating what the model has computed: whatever they pick up has to be legible to the model itself.
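To make the probe-versus-patching contrast concrete, here is a minimal linear probe: multinomial logistic regression fit by plain per-example gradient descent, one probe per layer. The {layer: (X, y)} data layout, the helper names, and the hyperparameters are assumptions for illustration. A high held-out accuracy at an early layer only shows that the label is linearly decodable from the residual there, not that the model reads it.

```python
import math
import random


def softmax(z):
    m = max(z)
    e = [math.exp(v - m) for v in z]
    s = sum(e)
    return [v / s for v in e]


def scores(W, b, x):
    return [sum(w_j * x_j for w_j, x_j in zip(w, x)) + bc for w, bc in zip(W, b)]


def train_probe(X, y, n_classes, lr=0.1, epochs=200, seed=0):
    """Fit W, b for softmax(W x + b) by per-example gradient descent on cross-entropy.

    X: residual vectors at one layer/position; y: integer labels (e.g. an answer digit).
    """
    rng = random.Random(seed)
    d = len(X[0])
    W = [[0.0] * d for _ in range(n_classes)]
    b = [0.0] * n_classes
    order = list(range(len(X)))
    for _ in range(epochs):
        rng.shuffle(order)
        for i in order:
            p = softmax(scores(W, b, X[i]))
            for c in range(n_classes):
                g = p[c] - (1.0 if c == y[i] else 0.0)  # dLoss/dlogit_c
                b[c] -= lr * g
                for j in range(d):
                    W[c][j] -= lr * g * X[i][j]
    return W, b


def accuracy(W, b, X, y):
    preds = [max(range(len(b)), key=scores(W, b, x).__getitem__) for x in X]
    return sum(p == t for p, t in zip(preds, y)) / len(y)


def probe_by_layer(train, test, n_classes):
    """train/test: {layer: (X, y)}. Returns held-out probe accuracy per layer."""
    return {layer: accuracy(*train_probe(*train[layer], n_classes), *test[layer])
            for layer in train}
```

Comparing the per-layer accuracies from probe_by_layer against the layer at which patching starts to work is one way to see the gap between "decodable" and "used".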