Very interesting! Some thoughts:
Is there a clear motivation for choosing the MLP activations as the autoencoder target? There are other choices of target that seem more intuitive to me (as I’ll explain below), namely:
the MLP’s residual stream update (i.e. MLP activations times MLP output weights)
the residual stream itself (after the MLP update is added), as in Cunningham et al.
In principle, we could also imagine using the “logit versions” of each of these as the target:
the change in logits due to the residual stream update[1]
the logits themselves
(In practice, the “logit versions” might be prohibitively expensive because the vocab is larger than other dimensions in the problem. But it’s worth thinking through what might happen if we did autoencode these quantities.)
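To make the candidate targets concrete, here is a minimal NumPy sketch. The sizes, the random weights, and the names `W_out`, `W_U`, `layer_norm`, and `candidate_targets` are all assumptions made for illustration and are not taken from the paper; the final layer norm is modeled without learned scale or bias.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical sizes; the real model's dimensions differ.
d_model, d_mlp, d_vocab = 8, 32, 50

W_out = rng.normal(size=(d_mlp, d_model))   # stand-in for the MLP output weights
W_U = rng.normal(size=(d_model, d_vocab))   # stand-in for the unembedding matrix


def layer_norm(x, eps=1e-5):
    """Layer norm without learned scale/bias (a simplifying assumption)."""
    x = x - x.mean(axis=-1, keepdims=True)
    return x / np.sqrt(x.var(axis=-1, keepdims=True) + eps)


def candidate_targets(resid_pre, acts):
    """Return the candidate autoencoder targets for one token position."""
    update = acts @ W_out                     # residual stream update
    resid_post = resid_pre + update           # residual stream after the MLP
    logits = layer_norm(resid_post) @ W_U     # the logits themselves
    # Exact change in logits caused by adding the update.
    logit_update = logits - layer_norm(resid_pre) @ W_U
    return {
        "activations": acts,
        "resid_update": update,
        "resid_post": resid_post,
        "logit_update": logit_update,
        "logits": logits,
    }


resid_pre = rng.normal(size=d_model)
acts = np.maximum(rng.normal(size=d_mlp), 0.0)  # stand-in for MLP activations
targets = candidate_targets(resid_pre, acts)
for name, value in targets.items():
    print(name, value.shape)
```

The logit-space targets have `d_vocab` entries per token, which is where the cost concern comes from once the vocabulary is much larger than the MLP width.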
At the outset, our goal is something like “understand what the MLP is doing.” But that could really mean one of 2 things:
(1) understand the role that the function computed by the MLP sub-block plays in the function computed by the network as a whole
(2) understand the role that the function computed by the MLP neurons plays in the function computed by the network as a whole
The feature decomposition in the paper provides a potentially satisfying answer for (1). If someone runs the network on a particular input, and asks you to explain what the MLP was doing during the forward pass, you can say something like:
Here is a list of features that were activated by the input. Each of these features is active because of a particular, intuitive/“interpretable” property of the input.
Each of these features has an effect on the logits (its logit weights), which is intuitive/“interpretable” on the basis of the input properties that cause it to be active.
The net effect of the MLP on the network’s output (i.e. the logits) is approximately[2] a weighted sum over these effects, weighted by how active the features were. So if you understand the list of features, you understand the effect of the MLP on the output.
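The weighted-sum claim in the last point can be checked numerically. The sketch below assumes a hypothetical dictionary (decoder) matrix `D`, treats the path from activations to logits as a single linear map `M` (that is, it ignores the layer norm nonlinearity), and assumes the autoencoder reconstructs the activations perfectly. None of these names or sizes come from the paper.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical sizes for illustration only.
n_features, d_mlp, d_vocab = 64, 16, 40

D = rng.normal(size=(n_features, d_mlp))  # dictionary: one row per feature
M = rng.normal(size=(d_mlp, d_vocab))     # assumed linear map: activations -> logit update

# Each feature's "logit weights": its dictionary vector pushed through M.
logit_weights = D @ M                     # shape (n_features, d_vocab)

# Sparse, non-negative feature activations for one input (3 active features).
f = np.zeros(n_features)
active = rng.choice(n_features, size=3, replace=False)
f[active] = rng.uniform(0.5, 2.0, size=3)

acts_hat = f @ D                          # reconstructed MLP activations
logit_update = acts_hat @ M               # effect of the reconstruction on the logits

# The same update, written as a sum over the active features only.
weighted_sum = sum(f[i] * logit_weights[i] for i in active)

print(np.allclose(logit_update, weighted_sum))
```

Under these assumptions the equality is exact by linearity; in the real model it only holds approximately, because of the layer norm and the autoencoder's reconstruction error.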
However, if this person now asks you to explain what MLP neuron A/neurons/472 was doing during the forward pass, you may not be able to provide a satisfying answer, even with the feature decomposition in hand.
The story above appealed to the interpretability of each feature’s logit weights. To explain individual neuron activations in the same manner, we’d need the dictionary weights to be similarly interpretable. The paper doesn’t directly address this question (I think?), but I expect that the matrix of dictionary weights is fairly dense[3] and thus difficult to interpret, with each neuron being a long and complicated sum over many apparently unrelated features. So, even if we understand all the features, we still don’t understand how they combine to “cause” any particular neuron’s activation.
Is this a bad thing? I don’t think so!
An MLP sub-block in a transformer only affects the function computed by the transformer through the update it adds to the residual stream. If we understand this update, then we fully understand “what the MLP is doing” as a component of that larger computation. The activations are a sort of “epiphenomenon” or “implementation detail”; any information in the activations that is not in the update is inaccessible to the rest of the network, and has no effect on the function it computes[4].
From this perspective, the activations don’t seem like the right target for a feature decomposition. The residual stream update seems more appropriate, since it’s what the rest of the network can actually see[5].
In the paper, the MLP that is decomposed into features is the last sub-block in the network.
Because this MLP is the last sub-block, the “residual stream update” is really just an update to the logits. There are no indirect paths going through later layers, only the direct path.
Note also that MLP activations have a much more direct relationship with this logit update than they do with the inputs. If we ignore the nonlinear part of the layernorm, the logit update is just a (low-rank) linear transformation of the activations. The input, on the other hand, is related to the activations in a much more complex and distant manner, involving several nonlinearities and indeed most of the network.
With this in mind, consider a feature like A/1/2357. Is it...
...“a base64-input detector, which causes logit increases for tokens like ‘zc’ and ‘qn’ because they are more likely next-tokens in base64 text”?
...“a direction in logit-update space pointing towards ‘zc’ and ‘qn’ (among other tokens), which typically has ~0 projection on the logit update, but has large projection in a rare set of input contexts corresponding to base64”?
The paper implicitly takes the former view: the features are fundamentally a sparse and interpretable decomposition of the inputs, which also have interpretable effects on the logits as a derived consequence of the relationship between inputs and correct language-modeling predictions.
(For instance, although the automated interpretability experiments involved both input and logit information[6], the presentation of these results in the paper and the web app (e.g. the “Autointerp” and its score) focuses on the relationship between features and inputs, not features and outputs.)
Yet the second view, in which features are fundamentally directions in logit-update space, seems closer to the way the autoencoder works mechanistically.
The features are a decomposition of activations, and activations in the final MLP are approximately equivalent to logit updates. So, the features found by the autoencoder are
directions in logit-update space (because logit-updates are, approximately[7], what gets autoencoded),
which usually have small projection onto the update (i.e. they are sparse, they can usually be replaced with 0 with minimal degradation),
but have large projection in certain rare sets of input contexts (i.e. they have predictive value for the autoencoder, they can’t be replaced with 0 in every context)
To illustrate the value of this perspective, consider the token-in-context features. When viewed as detectors for specific kinds of inputs, these can seem mysterious or surprising:
But why do we see hundreds of different features for “the” (such as “the” in Physics, as distinct from “the” in mathematics)? We also observe this for other common words (e.g. “a”, “of”), and for punctuation like periods. These features are not what we expected to find when we set out to investigate one-layer models!
An example of such a feature is A/1/1078, which Claude glosses as
The [feature] fires on the word “the”, especially in materials science writing.
This is, indeed, a weird-sounding category to delineate in the space of inputs.
But now consider this feature as a direction in logit-update space, whose properties as a “detector” in input space derive from its logit weights—it “detects” exactly those inputs on which the MLP wants to move the logits in this particular, rarely-deployed direction.
The question “when is this feature active?” has a simple, non-mysterious answer in terms of the logit updates it causes: “this feature is active when the MLP wants to increase the logit for the particular tokens ‘ magnetic’, ‘ coupling’, ‘electron’, ‘ scattering’ (etc.)”
Which inputs correspond to logit updates in this direction? One can imagine multiple scenarios in which this update would be appropriate. But if we go looking for inputs on which the update was actually deployed, our search will be weighted by
the ease of learning a given input-output pattern (esp. b/c this network is so low-capacity), and
how often a given input-output pattern occurs in the Pile.
The Pile includes a large arXiv subset, so it contains a lot of materials science papers. And these papers contain a lot of “materials science noun phrases”: phrases that start with “the,” followed by a word like “magnetic” or “coupling,” and possibly more words.
This is not necessarily the only input pattern “detected” by this feature[8] (because it is not necessarily the only case where this update direction is appropriate), but it is an especially common one, so it appears at a glance to be “the thing the feature is ‘detecting.’” Further inspection of the activation might complicate this story, making the feature seem like a “detector” of an even weirder and more non-obvious category, and thus even more mysterious from the “detector” perspective. Yet these traits are non-mysterious, and perhaps even predictable in advance, from the “direction in logit-update space” perspective.
That’s a lot of words. What does it all imply? Does it matter?
I’m not sure.
The fact that other teams have gotten similar-looking results, while (1) interpreting inner layers from real, deep LMs and (2) interpreting the residual stream rather than the MLP activations, suggests that these results are not a quirk of the experimental setup in the paper.
But in deep networks, eventually the idea that “features are just logit directions” has to break down somewhere, because inner MLPs are not only working through the direct path. Maybe there is some principled way to get the autoencoder to split things up into “direct-path features” (with interpretable logit weights) and “indirect-path features” (with non-interpretable logit weights)? But IDK if that’s even desirable.
[1] We could compute this exactly, or we could use a linear approximation that ignores the layer norm rescaling. I’m not sure one choice is better motivated than the other, and the difference is presumably small.
[3] There’s a figure in the paper showing dictionary weights from one feature (A/1/3450) to all neurons. It has many large values, both positive and negative. I’m imagining that this case is typical, so that the matrix of dictionary vectors looks like a bunch of these dense vectors stacked together.
It’s possible that slicing this matrix along the other axis (i.e. weights from all features to a single neuron) might reveal more readily interpretable structure (and I’m curious to know whether that’s the case!), but it seems a priori unlikely based on the evidence available in the paper.
[4] However, while the “implementation details” of the MLP don’t affect the function computed during inference, they do affect the training dynamics. Cf. the distinctive training dynamics of deep linear networks, even though they are equivalent to single linear layers during inference.
[5] If the MLP is wider than the residual stream, as it is in real transformers, then the MLP output weights have a nontrivial null space, and thus some of the information in the activation vector gets discarded when the update is computed.
A feature decomposition of the activations has to explain this “irrelevant” structure along with the “relevant” stuff that gets handed onwards.
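A small numerical sketch of the null-space point, with made-up sizes and a random `W_out` (both assumptions, not values from the paper): when the MLP is wider than the residual stream, adding any vector from the null space to the activations leaves the residual stream update unchanged.

```python
import numpy as np

rng = np.random.default_rng(0)

d_mlp, d_model = 32, 8                     # hypothetical: MLP wider than the residual stream
W_out = rng.normal(size=(d_mlp, d_model))  # random stand-in for the MLP output weights

# Full SVD: the columns of U beyond the rank of W_out span the directions in
# activation space that W_out maps to zero.
U, S, Vt = np.linalg.svd(W_out, full_matrices=True)
rank = int(np.sum(S > 1e-10))
null_basis = U[:, rank:]                   # shape (d_mlp, d_mlp - rank)

acts = rng.normal(size=d_mlp)
# Add an arbitrary vector from the null space to the activations.
perturbation = null_basis @ rng.normal(size=null_basis.shape[1])
acts_perturbed = acts + perturbation

print("null-space dimension:", null_basis.shape[1])
print("activations changed by:", np.linalg.norm(perturbation))
print("update changed by:", np.linalg.norm(acts_perturbed @ W_out - acts @ W_out))
```

An autoencoder trained on the activations is penalized for errors along `null_basis`, even though nothing downstream can see them.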
[6] Claude was given logit information when asked to describe inputs on which a feature is active; also, in a separate experiment, it was asked to predict parts of the logit update.
[7] Caveat: L2 reconstruction loss on logit updates != L2 reconstruction loss on activations, and one might not even be a close approximation to the other.
That said, I have a hunch they will give similar results in practice, based on a vague intuition that the training loss will tend to encourage the neurons to have approximately equal “importance” in terms of average impacts on the logits.
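To spell out the caveat, assume the logit update is a linear function of the activations (i.e. ignore the layer norm). Write $a$ for the activation row vector, $\hat{a}$ for its reconstruction, and $M$ for the assumed linear map from activations to logit updates; these symbols are introduced here for illustration and are not the paper's notation.

```latex
\[
\underbrace{\lVert (a - \hat{a})\, M \rVert_2^2}_{\text{loss on logit updates}}
  = (a - \hat{a})\, M M^{\top} (a - \hat{a})^{\top}
\qquad \text{vs.} \qquad
\underbrace{\lVert a - \hat{a} \rVert_2^2}_{\text{loss on activations}}
  = (a - \hat{a})\, I\, (a - \hat{a})^{\top}
\]
```

The two losses agree up to scale only if $M M^{\top}$ is proportional to the identity, which cannot happen when $M$ has a nontrivial null space. Otherwise the logit-space loss reweights reconstruction errors by how strongly each direction in activation space moves the logits. The hunch above amounts to guessing that this reweighting is close to uniform in practice.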
[2] because of the (hopefully small) nonlinear effect of the layer norm
[8] At a glance, it seems to also activate sometimes on tokens like “ each” or “ with” in similar contexts.