Circuits in Superposition 2: Now with Less Wrong Math
Summary & Motivation
This post is a continuation and clarification of Circuits in Superposition: Compressing many small neural networks into one. That post presented a sketch of a general mathematical framework for compressing different circuits into a network in superposition. On closer inspection, some of it turned out to be wrong, though: the error propagation calculations for networks with multiple layers were incorrect. With the framework used in that post, the errors blow up too much over multiple layers.
This post presents a slightly changed construction that fixes those problems, and improves on the original in some other ways as well.[1]
By computation in superposition we mean that a network represents features in superposition and performs more computations with them than it has neurons, across multiple layers. Having better models of this is important for understanding how and even if networks use superposition, which in turn is important for mech-interp in general.
Performing computation in superposition over multiple layers introduces additional noise compared to just storing features in superposition[2]. This restricts the amount and type of computation that can be implemented in a network of a given size, because the noise needs to be reduced or suppressed to stay smaller than the signal.
Takeaways
Our setup in this post (see the Construction section for details) is as follows:
We have T small circuits, each of which can be described as a d-dimensional multilayer perceptron (MLP) with L layers.
We have one large D-dimensional MLP with L layers, where D>d , but D<Td. So we can’t just dedicate d neurons in the large MLP to each circuit.
We embed all T circuits into the large network, such that the network approximately implements the computations of every circuit, conditional on no more than z≪T circuits being used on any given forward pass.
The number of circuits we can fit in scales linearly with the number of network parameters
Similar to the previous post, we end up concluding that the total number of parameters in the circuits must be smaller than the number of parameters in the large network. This result makes a lot of intuitive sense, since the parameters determine the maximum amount of information a network can possibly store.[3]
More specifically, we find that the term √(zTd²/D²) needs to be smaller than 1 for the errors on the computations of individual circuits to stay smaller than the signal.
Here T is the total number of circuits, d is the width of each small circuit, D is the layer width of the large network and z is the number of circuits active on a given forward pass.
This gives us an approximate upper bound on the maximum number of d-dimensional circuits we can fit into a network with D neurons per layer:
Tmax=O((1/z)·(D²/d²))(0.1)
If you only remember one formula from this post, let it be that one.
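For concreteness, with made-up numbers: a network with D=1000 neurons per layer, hosting d=10-dimensional circuits with z=10 of them active at a time, could fit on the order of Tmax≈D²/(zd²)=10⁶/(10·100)=1000 circuits, roughly ten times more than the D/d=100 circuits it could hold by dedicating d neurons to each.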
This is much smaller than the number of d-dimensional features we can store in superposition in a layer of width D, if we don’t intend to use them for calculations within the network. That number is[4]
Tmax storage=O((1/d)·e^(D/(8zd)))(0.2)
So, while storage capacity scales exponentially with D, capacity for computation only scales quadratically.
Each circuit will only use a small subset of neurons in the larger network
For this construction to work, each circuit can only use a small fraction of the large network’s neurons.
I, Linda, expect this to be true more generally. I think basically any construction that achieves computation in superposition, across multiple layers, in the sense we mean here will (approximately) have this property. My reasons for thinking this are pretty entangled with details of the error propagation math, so I’ve relegated them to the Discussion section.
Implications for experiments on computation in superposition
The leading source of error in this construction is signals from the active circuits (used on the forward pass) bleeding into currently inactive circuits that shouldn’t be doing anything. This bleed-over then enters back into the active circuits as noise in the next layer.
This means that the biggest errors don’t appear until layer 2[5]. This is important to keep in mind for experimental investigations of computation in superposition, e.g. when training toy models. If your network only has one computational layer, then it doesn’t have to implement a way to reduce this noise.[6]
Reality really does have a surprising amount of detail
To make sure the math was actually correct this time around, Linda coded up a little model implementing some circuits in superposition by hand.
Naturally, while doing this, she found that there were still a bunch of fiddly details left to figure out before circuits in superposition actually work in real life, even in a pretty simple example. The math makes a bunch of vague assumptions about the circuits that turn out to be important once you actually get down to making things work in practice.
The math presented in this post won’t deal with those fiddly details. It is intended to be a relatively simple, high-level description of a general framework. E.g., we assume that individual circuits have some level of noise robustness around the values representing ‘inactive’ for that circuit, without worrying about how that robustness is achieved.
So, in actual practice, the details of this construction may need some adjustment depending on how exactly individual circuits implement their noise robustness, and whether any of them are doing similar things.[7]
A post with the hand coded model and the fiddly details should be coming out “SoonTM”.
Construction
The construction in this post has some significant differences from the previous one.
To simplify things, the networks here don’t have a residual stream, they’re just deep MLPs.[8] We have one large MLP with L layers, neuron dimension D, activation vectors Al∈RD, and weight matrices Wl∈RD×D.
Al=ReLUD(WlAl−1)forl≥1(1.1)
We also have T circuits, indexed by 0,…,T−1, each described by a small MLP with L layers, neuron dimension d, activation vectors alt∈Rd, and weight matrices wlt∈Rd×d.
alt=ReLUd(wltal−1t)forl≥1,t=0,…,T−1(1.2)
Our goal is to figure out a construction for the weight matrices Wl, which embeds the circuits into the network, such that the outputs of each circuit can be read-out from the final output of the network AL with linear projections, up to some small error.[9]
Assumptions
For this construction to work as intended, we need to assume that:
Only z≪D/d circuits can be active on any given forward pass.
Small circuits are robust to noise when inactive. I.e. a small deviation to the activation value of an inactive circuit applied in layer l will not change the activation value of that circuit in layer l+1.[10]
If a circuit is inactive, all of its neurons have activation value zero. I.e. alt=0 if circuit t is inactive.
The entries of the weight matrices wlt for different circuits in the same layer are uncorrelated with each other.
Assumption 1 is just the standard sparsity condition for superposition.
Assumption 2 is necessary, but if it is not true for some of the circuits we want to implement, we can make it true by modifying them slightly, in a way that doesn’t change their functionality.[11] How this works will not be covered in this post though.
Assumptions 3 and 4 are not actually necessary for something similar to this construction to work, but without them the construction becomes more complicated. The details of this are also beyond the scope of this post.
Embedding the circuits into the network
The important takeaways from this section are Equations (1.11) and (1.13)-(1.14), which we will make use of in the Error Calculation section. If you’re happy with these and don’t think they require further motivation, then you don’t need to read the rest of this section.
Remember that the circuit weights wlt and their activation vectors alt are handed to us and we can’t change them, but we are free to choose the weights Wl of the network to be whatever we like. We also assume that we get to choose how to linearly embed the input vectors of the circuits a0t into the input vector of the network A0 at the start.
To help with the embedding we will introduce:
Embedding matrices Elt∈RD×d for each circuit t in each layer l≥1
Unembedding matrices Ult∈Rd×D for each circuit t in each layer l≥1.
Our goal is to calculate Equation (1.2) using the network, which (due to our choice of Ult and Elt, see next section) can be re-expressed as
alt=ReLUd(wltal−1t)=UltReLUD(Eltwltal−1t)forl≥1(1.3)
We approximate this as
alt≈UltAlforl≥1(1.4)
and
Al≈ReLUD(∑tEltwltal−1t)forl≥1(1.5)
If we combine Equations (1.4) and (1.5) while pretending[12] they are exact relations, we get
Al=ReLUD(∑tEltwltUl−1tAl−1)forl≥2(1.6)
If we combine that with Equation (1.1) the network weights Wl for l≥2 are now defined via the embedding and unembedding matrices as
Wl=∑tEltwltUl−1tforl≥2(1.7)
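For concreteness, here is a hypothetical numpy sketch of how the network weights from Equation (1.7) could be assembled for a single layer. The random disjoint-row embeddings here are only stand-ins for the actual Elt and Ult construction described in the next subsection, and all sizes are made up:
import numpy as np

rng = np.random.default_rng(0)
T, d, D, S = 50, 4, 40, 2

def random_embedding():
    # Stand-in for a real E_t^l: a D x d matrix with S ones per column,
    # columns non-overlapping within the circuit (requirement 1 below).
    rows = rng.choice(D, size=S * d, replace=False).reshape(d, S)
    E = np.zeros((D, d))
    for i in range(d):
        E[rows[i], i] = 1.0
    return E

w = rng.normal(size=(T, d, d))                           # small-circuit weights w_t^l for one layer
E_this = [random_embedding() for _ in range(T)]          # E_t^l
U_prev = [random_embedding().T / S for _ in range(T)]    # U_t^{l-1} = (E_t^{l-1})^T / S, Equation (1.10)
W = sum(E_this[t] @ w[t] @ U_prev[t] for t in range(T))  # Equation (1.7)
print(W.shape)                                           # (D, D)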
You might notice that this leaves W1 undefined, and that there are no embedding and unembedding matrices for l=0. That’s because layer zero is a bit special.
Layer 0
There are no embedding and unembedding matrices for layer 0, because we can just skip ahead and use our free choice of how to linearly embed a0t into A0 to implement the first matrix multiplications in each circuit, w1t, without any interference errors.
We choose
W1=I and A0=∑tE1tw1ta0t(1.8)
which gives us:
A1=ReLUD(∑tE1tw1ta0t)(1.9)
I.e, (1.4) is exactly true in the first layer. Down in the Error Calculation section, this will have consequences for which layer each error term first shows up in.
Maybe you think this is sort of cheating. Perhaps it is, but a model can train to cheat like this as well. That’s part of the point we want to make in this post: Having more than one layer makes a difference. From layer 2 onward, this kind of thing doesn’t work anymore. We no longer get to assume that every feature always starts embedded exactly where we want it. We have to actually compute on the features in a way that leaves every intermediate result embedded such that later layers can effectively compute on them in turn.
Constructing the Embedding and Unembedding matrices
In this subsection, we first lay out some requirements for the embedding matrices Elt and unembedding matrices Ult, then describe an algorithm for constructing them to fulfil those requirements.
The main takeaway here is that we can construct Elt and Ult such that they have the properties described in the final subsection. If you are happy to just trust us on this, then you don’t need to read the rest of this section.
Remember that each neuron in the large network will be used in the embeddings of multiple small circuit neurons. This is inevitable because the total number of neurons across all small circuits is larger than the number of neurons in the large network, Td>D, as per our setup.
Requirements
Neurons from the same circuit should be embedded into non-overlapping sets of neurons in the large network. We want this because neurons in the same circuit will tend to coactivate a lot, so reducing interference between them is especially important.
Neurons from different small circuits should be embedded into sets of neurons in the large network that have at most overlap one. So, no pair of circuit neurons shares more than one network neuron. This ensures that there is a hard bound on how bad errors from interference between different circuits can get.
Elt and Ult should only contain non-negative values. This ensures that the embedding does not change the sign of any circuit neuron pre-activations, so that the ReLUs still work correctly.
The embedding should distribute circuit neurons approximately evenly over the network neurons. Otherwise we’re just wasting space.
Our construction will satisfy these requirements and redundantly embed each circuit neuron into a different subset of S>1 network neurons. We call S the embedding redundancy.
Generally speaking, larger values of S are good, because they make each circuit more robust to worst-case errors from interference with other circuits. However, our requirement 2 creates an upper bound on S, because the more network neurons we assign to each circuit neuron, the harder it becomes to ensure that no two circuit neurons share more than one network neuron.[13]
Our allocation scheme is based on prime number factorisation. We promise it’s not as complicated as that description may make it sound. The details of how it works are not important for understanding the rest of this post though, so feel free to skip it if you’re not interested.
Step 1
First, we split the network neurons (for each layer) into d sets of D/d neurons[14]. The first neuron of each circuit will be embedded into the first of these sets, the second neuron of each circuit will be embedded into the second set, and so on, with the d-th neuron of each circuit embedded into the d-th set of network neurons.
So next, we need to describe how to embed each set of T circuit neurons into a set of D/d network neurons.
Step 2
The following algorithm allocates S>1 out of the D/d large neurons to each small circuit, while ensuring that any two such allocations overlap in at most one neuron, and that the allocations are fairly evenly spread out over the D/d neurons.
The allocation algorithm:
Each small circuit neuron is assigned to S large network neurons that are spaced step neurons apart. The set of possible_steps is chosen such that no two different neuron_allocations will overlap more than once. The set of possible_steps is based on prime factors.
First, we define a set of possible_steps[15] and a function that draws a new instance from that set:
PS:={p∈N | smallest prime factor of p is larger than S}(a.1)
possible_steps:={step=pSn | p∈PS; n∈N0; pSn(S−1)≤D/d}(a.2)
Here, N denotes the natural numbers, N={1,2,3…}, and N0 denotes the natural numbers including zero, N0={0,1,2,3…}.
Next we choose one step from possible_steps. We use this to generate approximately D/(Sd) non-overlapping neuron_allocations, where each neuron_allocation consists of S neurons, each of which is step neurons away from the previous one.
When we can’t generate any more non-overlapping neuron_allocations from our current step, we draw a new step from possible_steps. Repeat until we have T neuron allocations, i.e., one for each small circuit.
Pseudo code:
def allocate_neurons(T, S, n_neurons, next_step):
    # n_neurons = D/d, the size of one set of network neurons from Step 1.
    # next_step: a function that returns a new element of possible_steps each time it is called.
    neuron_allocation = {}
    step = next_step()
    shift = 0
    start = shift
    current_small_circuit = 0
    while current_small_circuit < T:
        # S neurons spaced `step` apart
        neuron_allocation[current_small_circuit] = [start + i * step for i in range(S)]
        start += S * step
        if start + (S - 1) * step >= n_neurons:  # the next allocation would run off the end
            shift += 1
            if shift == step:  # all shifts for this step are used up
                step = next_step()
                shift = 0
            start = shift
        current_small_circuit += 1
    return neuron_allocation
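As a quick sanity check of requirement 2, here is a hypothetical usage of the sketch above. The values, and the listed elements of possible_steps (valid for S=3, D/d=50 with n=0 in (a.2)), are purely illustrative:
import itertools

S, T, n_neurons = 3, 25, 50                # n_neurons = D/d; illustrative values only
steps = iter([5, 7, 11, 13])               # a few elements of possible_steps for S=3, D/d=50
alloc = allocate_neurons(T, S, n_neurons, lambda: next(steps))
for a, b in itertools.combinations(alloc.values(), 2):
    assert len(set(a) & set(b)) <= 1       # no pair of allocations shares more than one neuron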
Why it works:
This will make sure no neuron_allocations share more than one neuron, because[17]
i×stepx≠j×stepy for i,j∈{1,2,…,(S−1)}; stepx,stepy∈possible_steps; stepx≠stepy(a.3)
which means that if two different neuron_allocations with different steps share one neuron, all their other neurons are guaranteed to be different.
Step 3
We construct the embedding matrix Elt from d column vectors, each of which has S non-zero values, based on the allocation from Step 2.
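A minimal sketch of this step, assuming the neuron_allocation dictionary from Step 2 (entries of Elt are 0 or 1, in line with requirement 3):
import numpy as np

def embedding_matrix(t, neuron_allocation, d, D):
    # Column i of E_t places circuit neuron i (with weight 1) into its S allocated
    # network neurons, inside the i-th block of D/d network neurons from Step 1.
    block = D // d
    E = np.zeros((D, d))
    for i in range(d):
        for k in neuron_allocation[t]:      # the S positions allocated to circuit t in Step 2
            E[i * block + k, i] = 1.0
    return E

# e.g. embedding_matrix(0, alloc, d=4, D=200), with the allocation from the usage example above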
Step 4
To reduce noise, it’s important that circuit neurons don’t keep having the same “neighbours” from layer to layer, where “neighbours” are neurons from other circuits that share a network neuron.[18] To ensure this, we just re-shuffle the circuit index t at every layer. I.e. for each layer, we perform a random permutation on what set of S network neurons is allocated to each circuit neuron.
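A minimal sketch of this re-shuffle, with a made-up stand-in for the Step 2 allocations:
import numpy as np

rng = np.random.default_rng(0)
T, L = 25, 4                                         # illustrative values
base_alloc = [[2 * t, 2 * t + 1] for t in range(T)]  # stands in for the Step 2 allocations
alloc_per_layer = [
    [base_alloc[p] for p in rng.permutation(T)]      # circuit t uses alloc_per_layer[l][t] in layer l
    for _ in range(L)
]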
Step 5
The unembedding matrix is now just the transpose of the embedding matrix normalised by the embedding redundancy S:
Ult=1S(Elt)⊤(1.10)
Real python code
If you prefer real code over words + pseudo code, then here’s a Colab Notebook with the full code for creating the embedding and unembedding matrices.
Properties of E and U
From I, II, and (1.10) we get that:
ReLUd(wltal−1t)=UltReLUD(Eltwltal−1t)(1.11)
From III and (1.10) we can derive that, for any vector v∈Rd:
Et≠u[UltEluv]=(dS/D)v(1.12)
Et≠u[|UltEluv|2]=(d/D)|v|2(1.13)
maxt≠u[|UltEluv|]=|v|/S(1.14)
As a reminder, S here is the number of network neurons each circuit neuron is redundantly embedded into.
The derivation for this assumes that all network neurons are used by the same number of small circuits. In general this will not be strictly true, but it will be close enough to the truth.
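As a sanity check, here is a small numpy sketch (with arbitrary toy sizes, and a random allocation rather than the prime-factor one) verifying property (1.11) for a single circuit:
import numpy as np

rng = np.random.default_rng(0)
d, D, S = 4, 60, 3
rows = rng.choice(D, size=S * d, replace=False).reshape(d, S)  # disjoint rows per column (requirement 1)
E = np.zeros((D, d))
for i in range(d):
    E[rows[i], i] = 1.0
U = E.T / S                                    # Equation (1.10)
relu = lambda x: np.maximum(x, 0)
v = rng.normal(size=d)                         # stands in for w_t^l a_t^{l-1}
assert np.allclose(U @ relu(E @ v), relu(v))   # property (1.11)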
Error calculation
In this section we’ll calculate approximately how much error is produced and propagated through each layer. We start by giving a mathematical definition of the total error we want to calculate. Then we split this error term into three parts, depending on the origin of the error. Then we estimate each error term separately; finally, we add them all together.
The main results can be found in the summary subsection at the end.
Defining the error terms
In order to define the errors, we introduce blt, which is the linear read-off of alt that the large network can access as input for the next layer:
b0t:=a0t(2.1)
blt:=UltAlforl≥1(2.2)
Inserting this into Equations (1.6) and (1.9), we get:
blt=UltReLUD(∑uEluwlubl−1u)forl≥1(2.3)
The error term ϵlt is then defined as the discrepancy between the activations of the small networks alt and their linear read-off in the big network blt:
ϵlt:=blt−alt(2.4)
Inserting the definitions for both alt and blt, we find that the discrepancy is 0 at layer 0.
ϵ0t=0(2.5)
For later layers, the error is:
ϵlt=UltReLUD(∑uEluwlu(al−1u+ϵl−1u))−ReLUd(wltal−1t)forl≥1(2.6)
We can use Equation (1.11) to make the second term more similar to the first term.
ϵlt=UltReLUD(∑uEluwlu(al−1u+ϵl−1u))−UltReLUD(Eltwltal−1t)forl≥1(2.7)
In order to simplify this further, we will break up the expression inside the first ReLU.
To do this, we first notice that for a fixed pre-activation, a ReLU acts as multiplication by a diagonal matrix, with diagonal values equal to 1 or 0, depending on the sign of the pre-activation. We use this fact to replace the first ReLUD with the diagonal matrix Rl∈RD×D and the second one with Rlt∈RD×D:
ϵlt=UltRl∑uEluwlu(al−1u+ϵl−1u)−UltRltEltwltal−1tforl≥1(2.8)
(Rl)i,j:=1 if i=j and (∑uEluwlu(al−1u+ϵl−1u))i>0, and 0 otherwise(2.9)
(Rlt)i,j:=1 if i=j and (Eltwltal−1t)i>0, and 0 otherwise(2.10)
Note that Rl depends on the input; all other matrices here are the same for all inputs.
Now, we split the expression for the error in Equation (2.8) into three parts:
1) The embedding overlap error ˚ϵlt is the part of the error that is due to correct activations in active circuits spilling over into other circuits, because we are using an overcomplete basis. I.e., it’s because the embeddings are not completely orthogonal.
˚ϵlt:=UltRl∑u≠tEluwlual−1u(2.11)
This will turn out to be the leading error term for inactive circuits.
2) The propagation error~ϵlt is the part of the error that is caused by propagating errors from the previous layer.
~ϵlt:=UltRl∑uEluwluϵl−1u(2.12)
This will turn out to be the leading error term in active circuits from layer 2 onward, and the largest error overall.
3) The ReLU activation error ¨ϵlt occurs when the ReLUs used by a circuit activate differently than they would if there were no interference, i.e. no noise and no other circuits active:
¨ϵlt:=Ult(Rl−Rlt)Eltwltal−1t(2.13)
This error term will turn out to be basically irrelevant in our construction.
The total error in Equation (2.8) is the sum of these three error components, Equations (2.11)-(2.13):
ϵlt=˚ϵlt+~ϵlt+¨ϵltforl≥1(2.14)
˚ϵlt – The embedding overlap error
The embedding overlap error, defined by Equation (2.11), is the error we get from storing the circuit neurons into the network neurons in superposition.
Calculations:
Looking at Equation (2.11), we first note that we only have to sum over active circuits, since we assumed that al−1u=0 for inactive circuits. Remember that there are T circuits in total, but only z≪T are active at a time:
˚ϵlt:=UltRl∑u≠tEluwlual−1u=UltRl∑u is activeu≠tEluwlual−1u(3.1)
So now we only have to care about the network neurons that are used by active circuits. In general we can’t make any assumptions about whether these neurons are active (i.e. (Rl)i,i=1) or inactive (i.e. (Rl)i,i=0). We’ll therefore go with the most conservative estimate that increases the error the most, which is Rl≈I:
∣∣˚ϵlt∣∣2=∣∣UltRl∑u is activeu≠tEluwlual−1u∣∣2≲∣∣Ult∑u is activeu≠tEluwlual−1u∣∣2(3.2)
To calculate the mean square of this error, we assume that wlual−1u for different circuits u are uncorrelated, and then use Equation (1.13):
E[∣∣˚ϵlt∣∣2]≲E[∣∣Ult∑u is activeu≠tEluwlual−1u∣∣2]=dD∑u is activeu≠tE[∣∣wlual−1u∣∣2](3.3)
The sum over u then has (z−1) or z terms depending on whether circuit t is active or inactive:
E[∣∣˚ϵlt∣∣2]={O((z−1)d/D) if t is active; O(zd/D) if t is inactive}(3.4)
This gives us the typical size of the embedding overlap error:
˚ϵactive=O(√((z−1)d/D))(3.5)
˚ϵinactive=O(√(zd/D))(3.6)
~ϵlt – The propagation error
The propagation error, defined by Equation (2.12), is the largest and most important error overall. This error occurs when we perform computation in superposition over multiple layers, instead of just storing variables in superposition, or performing other types of single-layer superposition operations.
The existence of this error term is related to the fact that we need to embed not just the neurons, but also the weights of the circuits into the network. As opposed to the neuron activations of a circuit, the weights of a circuit don’t go away just because the circuit is turned off. This is why we get a factor √T in this term, where T is the total number of circuits, not just the number of active circuits z≪T. And this is why this error ends up being so large.
Since there are no errors at layer 0, we get ~ϵ1t=0, i.e. the propagation error does not show up until layer 2.
Calculation:
If we were to just conservatively assume that Rl=I for the purpose of this error estimate, as we did with the embedding error, we’d end up with an estimate:
~ϵlt=O(√(Td/D))×previous error(not true!)
Since Td>D, such an error term would quickly overtake the signal. Fortunately, the propagation error is actually much smaller than this, because of how our construction influences Rl.
As a reminder, Rl is a diagonal matrix of ReLU activations, defined as:
(Rl)i,j:=1 if i=j and (∑uEluwlu(al−1u+ϵl−1u))i>0, and 0 otherwise(4.1)
We will estimate the propagation error in Equation (2.12) by breaking it up into two cases: The error on neurons that are only used by inactive circuits, and the error on neurons that are used by at least one active circuit.
Case 1: Neurons used only by inactive circuits. For neurons i that are only used by inactive circuits:
(∑uEluwlu(al−1u+ϵl−1u))i=(∑u is inactiveEluwlu(al−1u+ϵl−1u))i(4.2)
Our assumption 3 at the start of the Construction section was that alu=0 for inactive circuits. Combining this with Equation (1.2), we have:
ReLU(wlual−1u)=0⇒wlual−1u≤0(4.3)
This is where our crucial assumption 2 from the start of the Construction section comes into play. We required the circuits to be noise robust when inactive, meaning that:
ReLU(wlu(al−1u+small noise))=0⇒wlu(al−1u+small noise)≤0(4.4)
So, assuming that the previous error ϵl−1u is sufficiently small compared to the noise tolerance of the circuits, we get (Rl)i,i=0, provided that neuron i is only used by inactive circuits.
Case 2: Neurons that are used by at least one active circuit. Here, we can assume the same conservative estimate on Rl we made use of when calculating the embedding overlap error, i.e. (Rl)i,i=1 for these neurons.
This means that the propagation error can flow from active to active, inactive to active, and active to inactive circuits.
There is also a small amount of error propagation from inactive to inactive circuits, whenever the embedding overlap between two inactive circuits also overlaps with an active circuit. But this flow is very suppressed.
To model this, we start with the approximation of Rl we derived in our two cases above:
(Rl)i,i≈min[1,∑v is active(ElvUlv)i,i](4.5)
The minimum in this expression is very annoying to deal with. So we overestimate the error a tiny bit more by using the approximation:
UltRlElu≈UltElu if either circuit t or u is active, and UltRlElu≈Ult∑v is activeElvUlvElu otherwise(4.6)
As a reminder, the definition of the propagation error ~ϵlt was:
~ϵlt:=UltRl∑uEluwluϵl−1u(2.12)
Inserting our approximation of UltRlElu into this yields:
~ϵlt is active≲wltϵl−1t+Ult∑u≠tEluwluϵl−1u(4.7)
~ϵlt is inactive≲Ult∑u is activeEluwluϵl−1u+Ult∑v is activeElvUlv∑u is inactiveEluwluϵl−1u(4.8)
Using similar calculations as those for the embedding overlap error, we can then estimate the size of these propagation error terms; the resulting leading-order sizes are collected in the summary at the end of this section.
¨ϵlt – The ReLU activation error
For this error term, defined in Equation (2.13), let Δt denote the total interference on circuit t’s pre-activations, with components Δi:=(∑u≠tEluwlual−1u+∑uEluwluϵl−1u)i, i.e. everything in the pre-activation except circuit t’s own contribution (Eltwltal−1t)i.
There are two combinations of (Rl)i,i and (Rlt)i,i that can contribute to ¨ϵlt. These are (Rl)i,i=1, (Rlt)i,i=0, and (Rl)i,i=0, (Rlt)i,i=1.
Case 1: (Rl)i,i=1,(Rlt)i,i=0
This will happen if and only if
Δi>−(Eltwltal−1t)i≥0(5.4)
and in this case the ¨ϵlt contribution term will be:
(Rl−Rlt)i,i(Eltwltal−1t)i=(Eltwltal−1t)i(5.5)
Δt is the source of the error calculated in the previous sections. Notice that in this case the ReLU activation error contribution (Eltwltal−1t)i is smaller and has opposite sign compared to Δt. We can therefore safely assume that it will not increase the overall error.
Case 2: (Rl)i,i=0,(Rlt)i,i=1
This will happen if and only if:
−Δi≥(Eltwltal−1t)i>0(5.6)
In this case, the ¨ϵlt contribution term will be:
(Rl−Rlt)i,i(Eltwltal−1t)i=−(Eltwltal−1t)i(5.7)
So, the ReLU activation error contribution −(Eltwltal−1t)i is still smaller in magnitude than Δt but does have the same sign as Δt.
However, since (Rl)i,i=0, Δt does not contribute to the total error at all. But in our calculations for the other two error terms, we didn’t know the value of (Rl)i,i, so we included this error term anyway.
So in this case, the error term coming from the ReLU activation error is also already more than accounted for.
ϵlt – Adding up all the errors
Let’s see how the three error terms add up, layer by layer:
Layer 0
There are no errors here because nothing has happened yet. See equation (2.5):
ϵ0active=0(6.1)
ϵ0inactive=0(6.2)
Layer 1
Since there was no error in the previous layer, we only get the embedding overlap error. From Equations (3.5) and (3.6), this gives ϵ1active≈˚ϵactive=O(√((z−1)d/D)) and ϵ1inactive≈˚ϵinactive=O(√(zd/D)).
For the same reason as last layer, the inactive error is
ϵ3inactive≈˚ϵinactive=O(√zdD)(6.11)
From here on, it just keeps going like this.
Worst-case errors vs mean square errors
So far, our error calculations have only dealt with mean square errors. However, we also need to briefly address worst-case errors. Those are why we need to have an embedding redundancy S>1 in the construction.
For this, we’re mainly concerned with the embedding overlap error, because that error term comes from just a few sources (the z active circuits), which means its variance may be high. In contrast, the propagation error comes from adding up many smaller contributions (the approximately zS2TdD[19] circuits that have non-zero embedding overlap error in the previous layer), so we expect it to be well behaved, i.e. well described by the previous mean square error calculations.
The worst-case embedding overlap error happens if some circuit is unlucky enough to be an embedding neighbour to all z active circuits. For active circuits, the maximum number of active neighbours is (z-1) since a circuit can’t be its own neighbour.
By Equation (1.14), each active neighbour can contribute at most on the order of 1/S times a typical circuit activation to the read-off, so this worst case is roughly z/S times a typical activation. So, as long as the embedding redundancy S is sufficiently large compared to the number of active circuits z, we should be fine.
Summary:
The main source of error is signal from the active circuits bleeding over into the inactive ones, which then enters back into the active circuits as noise in the next layer.
The noise in the active circuits accumulates from layer to layer. The noise in the inactive circuits does not accumulate.
At layer 0, there are no errors, because nothing has happened yet.
ϵ0active=0(6.14)
ϵ0inactive=0(6.15)
At layer 1 and onward, the leading term for the error on inactive circuits is:
ϵlinactive=O(√(zd/D)) for l≥1(6.16)
At layer 1, the leading term for the error on active circuits is:
ϵ1active=O(√((z−1)d/D))(6.17)
But from layer 2 onward, the leading term for the error on active circuits is:
ϵlactive=O(√((l−1)zTd²/D²)) for l≥2(6.18)
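To connect this to the takeaway formula (0.1): requiring the final-layer error to stay smaller than the O(1) signal means √((L−1)zTd²/D²)≲1, i.e. T≲D²/((L−1)zd²). Up to the depth-dependent factor (L−1), which the big-O in (0.1) absorbs for a fixed number of layers, this is the bound Tmax=O((1/z)·(D²/d²)) stated in the Takeaways.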
Discussion
Noise correction/suppression is necessary
Without any type of noise correction or error suppression, the error on the circuit activations would grow by O(√(Td/D)) per layer.
Td is the total number of neurons per layer for all small networks combined, and D is the number of neurons per layer in the large network. If Td≤D then we might as well encode one feature per neuron, and not bother with superposition. So, we assume Td>D, ideally even Td≫D.
In our construction, the way we suppress errors is to use the flat part of the ReLU, both to clear away noise in inactive circuits, and to prevent noise from moving between inactive circuits. Specifically, we assumed in assumption 2 of our construction that each small circuit is somewhat noise robust, such that any network neuron that is not connected to a currently active circuit will be inactive (i.e. the ReLU pre-activation is ≤0), provided that the error on the circuit activations in the preceding layer is small enough. This means that for the error to propagate to the next layer, it has to pass through a very small fraction of ‘open’ neurons, which is what keeps the error down to a more manageable O(√(zTd²/D²)).
However, we do not in general predict sparse ReLU activations for networks implementing computation in superposition
The above might then seem to predict a very low activation rate for neurons in LLMs and other neural networks, if they are indeed implementing computation in superposition. That’s not what we see in real large networks: e.g. MLP neurons in gpt2 have an activation rate of about 20%, much higher than in our construction.
But this low activation rate is actually just an artefact of the specific setup for computation in superposition we present here. Instead of suppressing noise with the flat part of a single ReLU function, we can also create a flat slope using combinations of multiple active neurons. E.g. a network could combine two ReLU neurons to create the ‘meta activation function’ f(x)=ReLU(x)−ReLU(x−1). This combined function is flat for both x<0 (both neurons are ‘off’) and x>1 (both neurons are ‘on’). We can then embed circuit neurons into different ‘meta neurons’ f instead of embedding them into the raw network neurons.
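A minimal numpy sketch of this ‘meta activation function’, purely for illustration:
import numpy as np

relu = lambda x: np.maximum(x, 0)
f = lambda x: relu(x) - relu(x - 1)            # flat for x < 0 and for x > 1
x = np.array([-2.0, -0.5, 0.3, 0.9, 1.5, 3.0])
print(f(x))                                    # [0.  0.  0.3 0.9 1.  1. ]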
At a glance, this might seem inefficient compared to using the raw ReLUs, but we don’t think it necessarily is. If f is a more suitable activation function for implementing the computation of many of the circuits, those circuits might effectively have a smaller width d under this implementation. The network might even mix and match different ‘meta activation functions’ like this in the same layer to suit different circuits.
But we do tentatively predict that circuits only use small subsets of network neurons
So, while neurons in the network don’t have to activate sparsely, it is necessary that each circuit only uses a small fraction of network neurons for anything like this construction to work. This is because any connection that lets through signal will also let through noise, and at least the neurons[20] used by active circuits must let signal through, or else they won’t be of any use.
Getting around this would require some completely different kind of noise reduction, which seems difficult to do using MLPs alone. Perhaps operations like layer norm and softmax can help with noise reduction somehow; that’s not something we have investigated yet.
Linda: But using few neurons per circuit does just seem like a pretty easy way to reduce noise, so I expect that networks implementing computation in superposition would do this to some extent, whether or not they also use other noise suppression methods. I have some very preliminary experimental evidence that maybe supports this conclusion[21]. More on that in a future post, hopefully.
The linked blogpost has a section called “Computation in Superposition”. However, on closer inspection, this section only presents a model with one layer in superposition. See the section “Implications for experiments on computation in superposition” for why this is insufficient.
This result also seems to hold if the circuits don’t have a uniform width d across the L layers. However, it might not straightforwardly hold if different circuits interact with each other, e.g. if some circuits take the outputs of previous circuits as inputs. We think this makes intuitive sense from an information-theoretic perspective as well. If circuits interact, we have to take those interactions into account in the description length.
By ‘layers’ we mean A1,A2,…AL, as defined down in equation (1.1). So, ‘layer 2’ refers to A2 or any linear readouts of A2. We don’t count A0 in the indexing because no computation has happened yet at that point, and when training toy models of computation in superposition A0 will often not even be explicitly represented.
If we want to study this noise, but don’t want to construct a complicated multi-layer toy model, we can add some noise to the inputs, to simulate the situation we’d be in at a later layer.
If circuits are doing similar things, this lets us reuse some computational machinery between circuits, but it can also make worst-case errors worse if we’re not careful, because errors in different circuits can add systematically instead of randomly.
The results here can be generalised to networks with residual streams, although for this, the embedding has to be done differently from the post linked above, otherwise the error propagation will not be contained.
You might notice that for both assumption 2 and assumption 3 to hold simultaneously, each small network needs a negative bias. We did not include a separate bias term in the construction, and the bias can’t be baked into w (as is often done), because that would require one of the neurons to be a non-zero constant, which contradicts assumption 2.
This is one of several ways that reality turned out to be more complicated than our theory. The good news is that this can be dealt with, in a way that preserves the general result, but that is beyond the scope of this post.
There exists a theoretical upper bound on S: S(S−1)≤(D/d)((D/d)−1)/T. However, the proof[22][23] of this upper bound is not constructive, so there is no guarantee that it can be reached.
Our allocation algorithm falls short of this upper bound. If you think you can devise a construction that gets closer to it, we’d love to hear from you.
which would give us more possible allocations (i.e. larger T for a given S and Dd, or larger possible S for a given T and Dd). But that would result in a more uneven distribution of how much each neuron is used.
The later calculations assume that the allocation is evenly distributed over neurons. Dropping that assumption would make both the calculation harder and the errors larger. Getting to increase S is probably not worth it.
Case: i=j
We know that (a.3) must be true in this case because stepx≠stepy
Case: i≠j
From (a.2) we get that
stepx=pSn;stepy=qSmforp,q∈PS;n,m∈N0(a.4)
We have proved (a.3) in this case if we can prove that
ipSn≠jqSm(a.5)
We also know that (a.5) must be true in this case because iSn will differ from jSm in at least one prime factor that is smaller than S. p and q can’t make up for that difference since, by definition (a.1), they don’t contain prime factors smaller than S.
We initially did not think of this, and only noticed the importance of re-shuffling from layer to layer when implementing this construction in code.
In the error calculation, when calculating how much noise is transported from inactive circuits to active circuits, we assume no correlation between the amount of noise in the inactive circuits and how much they share neurons with active circuits. But the noise depends on how much neuron overlap they had with active circuits in the previous layer. Therefore this assumption will be false if we don’t re-shuffle the neuron allocation from layer to layer.
Not only will our calculation be wrong (this could be solved with more calculations), but the errors will also be much larger, which simply makes for a worse construction.
The probability of two circuits sharing one large network neuron (per small circuit neuron) is S2dD. Given that there are T circuits in total, this gives us S2TdD “neighbour” circuits for each small circuit. Since there are z active circuits, there are approximately zS2TdD active circuit neighbours.
Linda: I trained some toy models of superposition with only one computational layer. This did not result in circuits connecting sparsely to the network’s neurons. Then I trained again with some noise added to the network inputs (in order to simulate the situation in a 2+ layer network doing computation in superposition), to see how the network would learn to filter it. This did result in circuits connecting sparsely to the network’s neurons.
This suggests to me that there is no completely different way to filter superposition noise in an MLP that we haven’t thought of yet. So networks doing computation in superposition would basically be forced to connect circuits to neurons sparsely to deal with the noise, as the math in this post suggests.
However, this experimental investigation is still a work in progress.[24]
There are (D/d)((D/d)−1) pairs of neurons in a set of D/d neurons. Each small circuit is allocated S neurons out of that set, accounting for S(S−1) pairs. No two small circuits can share a pair, which gives us the bound T≤(D/d)((D/d)−1)/(S(S−1)).
I (Linda) first got this bound and proof from ChatGPT (free version). According to ChatGPT it is a “known upper bound (due to the Erdős–Ko–Rado-type results)”.
My general experience is that ChatGPT is very good at finding known theorems (i.e. known to someone else, but not to me) that apply to any math problem I give it.
I also gave this problem to Claude as an experiment (some level of paid version offered by a friend). Claude tried to figure it out itself, but kept getting confused and just produced a lot of nonsense.
These networks were trained on L2 loss, which is probably the wrong loss function for incentivising superposition. When using L2 loss, the network doesn’t care much about separating different circuits. It’s happy to just embed two circuits right on top of each other into the same set of network neurons. I don’t really consider this to be computation in superposition. However, this should not affect the need for the network to prevent noise amplification, which is why I think these results are already some weak evidence for the prediction.
I’ll make a better toy setup and hopefully present the results of these experiments in a future post.
Circuits in Superposition 2: Now with Less Wrong Math
Summary & Motivation
This post is a continuation and clarification of Circuits in Superposition: Compressing many small neural networks into one. That post presented a sketch of a general mathematical framework for compressing different circuits into a network in superposition. On closer inspection, some of it turned out to be wrong, though. The error propagation calculations for networks with multiple layers were incorrect. With the framework used in that post, the errors blow up too much over multiple layers.
This post presents a slightly changed construction that fixes those problems, and improves on the original in some other ways as well.[1]
By computation in superposition we mean that a network represents features in superposition and performs more computations with them than it has neurons, across multiple layers. Having better models of this is important for understanding how and even if networks use superposition, which in turn is important for mech-interp in general.
Performing computation in superposition over multiple layers introduces additional noise compared to just storing features in superposition[2]. This restricts the amount and type of computation that can be implemented in a network of a given size, because the noise needs to be reduced or suppressed to stay smaller than the signal.
Takeaways
Our setup in this post (see the Section Construction for details) is as follows:
We have T small circuits, each of which can be described as a d-dimensional multilayer perceptron (MLP) with L layers.
We have one large D-dimensional MLP with L layers, where D>d , but D<Td. So we can’t just dedicate d neurons in the large MLP to each circuit.
We embed all T circuits into the large network, such that the network approximately implements the computations of every circuit, conditional on no more than z<<T circuits being used on any given forward pass.
The number of circuits we can fit in scales linearly with the number of network parameters
Similar to the previous post, we end up concluding that the total number of parameters in the circuits must be smaller than the number of parameters in the large network. This result makes a lot of intuitive sense, since the parameters determine the maximum amount of information a network can possibly store.[3]
More specifically, we find that the term √zTd2D2 needs to be smaller than 1 for the errors on the computations of individual circuits to stay smaller than the signal.
Here T is the total number of circuits, d is the width of each small circuit, D is the layer width of the large network and z is the number of circuits active on a given forward pass.
This gives us an approximate upper bound on the maximum number of d-dimensional circuits we can fit into a network with D neurons per layer:
Tmax=O(1zD2d2)(0.1)If you only remember one formula from this post, let it be that one.
This is much smaller than the number of d-dimensional features we can store in superposition in a layer of width D, if we don’t intend to use them for calculations within the network. That number is[4]
Tmax storage=O(1deD8zd)(0.2)So, while storage capacity scales exponentially with D, capacity for computation only scales quadratically.
Each circuit will only use a small subset of neurons in the larger network
For this construction to work, each circuit can only be using a small fraction of the large network’s neurons.
I, Linda, expect this to be true more generally. I think basically any construction that achieves computation in superposition, across multiple layers, in the sense we mean here will (approximately) have this property. My reasons for thinking this are pretty entangled with details of the error propagation math, so I’ve relegated them to the Discussion section.
Implications for experiments on computation in superposition
The leading source of error in this construction is signals from the active circuits (used on the forward pass) bleeding into currently inactive circuits that shouldn’t be doing anything. This bleed-over then enters back into the active circuits as noise in the next layer.
This means that the biggest errors don’t appear until layer 2[5]. This is important to keep in mind for experimental investigations of computation in superposition, e.g. when training toy models. If your network only has one computational layer, then it doesn’t have to implement a way to reduce this noise.[6]
Reality really does have a surprising amount of detail
To make sure the math was actually really correct this time around, Linda coded up a little model implementing some circuits in superposition by hand.
Naturally, while doing this, she found that there were still a bunch of fiddly details left to figure out how to make circuits in superposition actually work in real life, even on a pretty simple example, because the math makes a bunch of vague assumptions about the circuits that turn out to be important when you actually get down to making things work in practice.
The math presented in this post won’t deal with those fiddly details. It is intended to be a relatively simple high level description of a general framework. E.g, we assume that individual circuits have some level of noise robustness around values representing ‘inactive’ for that circuit, without worrying about how it’s achieved.
So, in actual practice, the details of this construction may need some adjustment depending on how exactly individual circuits implement their noise robustness, and whether any of them are doing similar things.[7]
A post with the hand coded model and the fiddly details should be coming out “SoonTM”.
Construction
The construction in this post has some significant differences from the previous one.
To simplify things, the networks here don’t have a residual stream, they’re just deep MLPs.[8] We have one large MLP with L layers, neuron dimension D, activation vectors Al∈RD, and weight matrices Wl∈RD×D.
Al=ReLUD(WlAl−1)forl≥1(1.1)We also have T circuits, indexed by 0,…,T−1, each described by a small MLP with L layers, neuron dimension d, activations vectors alt∈Rd, and weight matrices wlt∈Rd×d.
alt=ReLUd(wltal−1t)forl≥1,t=0,…,T−1(1.2)Our goal is to figure out a construction for the weight matrices Wl, which embeds the circuits into the network, such that the outputs of each circuit can be read-out from the final output of the network AL with linear projections, up to some small error.[9]
Assumptions
For this construction to work as intended, we need to assume that:
Only z≪Dd circuits can be active on any given forward pass.
Small circuits are robust to noise when inactive. I.e. a small deviation to the activation value of an inactive circuit applied in layer l will not change the activation value of that circuit in layer l+1.[10]
If a circuit is inactive, all of its neurons have activation value zero. I.e. alt=0 if circuit t is inactive.
The entries of the weight matrices wlt for different circuits in the same layer are uncorrelated with each other.
Assumption 1 is just the standard sparsity condition for superposition.
Assumption 2 is necessary, but if it is not true for some of the circuits we want to implement, we can make it true by modifying them slightly, in a way that doesn’t change their functionality.[11] How this works will not be covered in this post though.
Assumptions 3 and 4 are not actually necessary for something similar to this construction to work, but without them the construction becomes more complicated. The details of this are also beyond the the scope of this post.
Embedding the circuits into the network
The important takeaways from this section are Equations (1.11) and (1.13)-(1.14), which we will make use of in the Error Calculation section. If you’re happy with these and don’t think they require further motivation, then you don’t need to read the rest of this section.
Remember that the circuit weights wlt and their activation vectors alt are handed to us and we can’t change them, but we are free to choose the weights Wl of the network to be whatever we like. We also assume that we get to choose how to linearly embed the input vectors of the circuits a0t into the input vector of the network A0 at the start.
To help with the embedding we will introduce:
Embedding matrices Elt∈RD×d for each circuit t in each layer l≥1
Unembedding matrices Ult∈Rd×D for each circuit t in each layer l≥1.
Our goal is to calculate Equation (1.2) using the network, which (due to our choice of Ult and Elt, see next section) can be re-expressed as
alt=ReLUd(wltal−1t)=UltReLUD(Eltwltal−1t)forl≥1(1.3)We approximate this as
alt≈UltAlforl≥1(1.4)and
Al≈ReLUD(∑tEltwltal−1t)forl≥1(1.5)If we combine Equations (1.4) and (1.5) while pretending[12] they are exact relations, we get
Al=ReLUD(∑tEltwltUl−1tAl−1)forl≥2(1.6)If we combine that with Equation (1.1) the network weights Wl for l≥2 are now defined via the embedding and unembedding matrices as
Wl=∑tEltwltUl−1tforl≥2(1.7)You might notice that this leaves W1 undefined, and that there are no embedding and unembedding matrices for l=0. That’s because layer zero is a bit special.
Layer 0
There are no embedding and unembedding matrices for layer 0, because we can just skip ahead and use our free choice of how to linearly embed a0t into A0t to implement the first matrix multiplications in each circuit w1t without any interference errors.
We choose
W1=IandA0=∑tE1tw1ta0t(1.8)which gives us:
A1=ReLUD(∑tE1tw1ta0t)(1.9)I.e, (1.4) is exactly true in the first layer. Down in the Error Calculation section, this will have consequences for which layer each error term first shows up in.
Maybe you think this is sort of cheating. Perhaps it is, but a model can train to cheat like this as well. That’s part of the point we want to make in this post: Having more than one layer makes a difference. From layer 2 onward, this kind of thing doesn’t work anymore. We no longer gets to assume that every feature always starts embedded exactly where we want it. We have to actually compute on the features in a way that leaves every intermediary result embedded such that later layers can effectively compute on them in turn.
Constructing the Embedding and Unembedding matrices
In this subsection, we first lay out some requirements for the embedding matrices Elt and unembedding matrices Ult, then describe an algorithm for constructing them to fulfil those requirements.
The main takeaway here is that we can construct Elt and Ult such that they have the properties described in the final subsection. If you are happy to just trust us on this, then you don’t need to read the rest of this section.
Remember that each neuron in the large network will be used in the embeddings of multiple small circuit neurons. This is inevitable because the total numbers of neurons across all small circuits is larger than the number of neurons in the large network, dT>D, as per assumption 1.
Requirements
Neurons from the same circuit should be embedded into non-overlapping sets of neurons in the large network. We want this because neurons in the same circuit will tend to coactivate a lot, so reducing interference between them is especially important.
Neurons from different small circuits should be embedded into sets of neurons in the large network that have at most overlap one. So, no pair of circuit neurons shares more than one network neuron. This ensures that there is a hard bound on how bad errors from interference between different circuits can get.
The embedding should distribute circuit neurons approximately evenly over the network neurons. Otherwise we’re just wasting space.
Our construction will satisfy these requirements and redundantly embed each circuit neuron into a different subset of S>1 network neurons. We call S the embedding redundancy.
Generally speaking, larger values of S are good, because they make each circuit more robust to worst-case errors from interference with other circuits. However, our requirement 2 creates an upper bound on S, because the more network neurons we assign to each circuit neuron, the harder it becomes to ensure that no two circuit neurons share more than one network neuron.[13]
Our allocation scheme is based on prime number factorisation. We promise it’s not as complicated as that description may make it sound. The details of how it works are not important for understanding the rest of this post though, so feel free to skip it if you’re not interested.
Step 1
First, we split the network neurons (for each layer) into d sets of Dd neurons[14]. The first neuron of each circuit will be embedded in the first of these sets, the second neuron of each circuit will be embedded into the second set and so on, with the d-th neuron of each circuit embedded into the d-th set of network neurons.
So next, we need to describe how to embed each set of T circuit neurons into a set of Dd network neurons.
Step 2
The following algorithm allocates S>1 out of Dd large neurons to each small circuit, while ensuring that each pair of circuits shares at most one such allocation, and that the allocations are fairly evenly spread out over the Dd neurons.
The allocation algorithm:
Each small circuit neuron is assigned to S large network neurons that are spaced step neurons apart. The set of possible_steps is chosen such that no two different neuron_allocations will overlap more than once. The set of possible_steps is based on prime factors.
First, we define a set of possible_steps[15] and a function that draws a new instance from that set:
PS:={p∈N | smallest prime factor of p is larger than S}(a.1)possible_steps:={step=pSn | p∈PS ; n∈N0 ; pSn(S−1)≤Dd}(a.2)Here, N denotes the natural numbers, N={1,2,3…}, and N0 denotes the natural numbers including zero, N0={0,1,2,3…}.
Next we chose one step from possible_steps. We use this to generate approximately DSd non-overlapping neuron_allocations, where each neuron_alocation consists of S subsequent neurons where each one is step neurons away from the previous one.
When we can’t generate more non-overlapping neuron_allocations from our first step we draw a new step from possible_steps. Repeat until we have T neuron allocations, i.e, one for each small circuit.
Pseudo code:
Why it works:
This will make sure no neuron_allocations share more than one neuron because[17]
i×stepx≠j×stepyfor⎧⎪⎨⎪⎩i,j∈{1,2,…(S−1)}stepx,stepy∈possible_stepsstepx≠stepy(a.3)which means that if two different neuron_allocations with different steps, share one neuron, all their other neurons are guaranteed to be different.
Step 3
We construct the embedding matrix Elt from d column vectors, each of which has S non-zero values, based on the allocation from Step 2.
Pseudo code:
Step 4
To reduce noise, it’s important that circuit neurons don’t keep having the same “neighbours” from layer to layer, where “neighbours” are neurons from other different circuits that share a network neuron.[18] To ensure this, we just re-shuffle the circuit index t at every layer. I.e. for each layer, we perform a random permutation on what set of S network neurons is allocated to each circuit neuron.
Step 5
The unembedding matrix is now just the transpose of the embedding matrix normalised by the embedding redundancy S:
Ult=1S(Elt)⊤(1.10)Real python code
If you prefer real code over words + pseudo code, then here’s a Colab Notebook with the full code for creating the embedding and unembedding matrices.
Properties of E and U
From I, II, and (1.10) we get that:
ReLUd(wltal−1t)=UltReLUD(Eltwltal−1t)(1.11)From III and (1.10) we can derive that, for any vector v∈Rd:
Et≠u[UltEluv]=dSDv(1.12)Et≠u[|UltEluv|2]=dD|v|2(1.13)maxt≠u[|UltEluv|]=|v|S(1.14)As a reminder, S here is the number of network neurons each circuit neuron is redundantly embedded into.
The derivation for this assumes that all network neurons are used by the same number of small circuits. In general this will not be strictly true, but it will be close enough to the truth.
Error calculation
In this section we’ll calculate approximately how much error is produced and propagated though each layer. We start by giving a mathematical definition of the total error we want to calculate. Then we split this error term into three parts, depending on the origin of the error. Then we estimate each error term separately; finally, we add them all together.
The main results can be found in the summary subsection at the end.
Defining the error terms
In order to define the errors, we introduce blt, which is the linear read-off of alt that the large networks can access as input for the next layer:
b0t:=a0t(2.1)blt:=UltAlforl≥1(2.2)Inserting this into Equations (1.6) and (1.9), we get:
blt=UltReLUD(∑uEluwlubl−1u)forl≥1(2.3)The error term ϵlt is then defined as the discrepancy between the activations of the small networks alt and their linear read-off in the big network blt:
ϵlt:=blt−alt(2.4)Inserting the definitions for both alt and blt, we find that the discrepancy is 0 at the first layer.
ϵ0t=0(2.5)For later layers, the error is:
ϵlt=UltReLUD(∑uEluwlu(al−1u+ϵl−1u))−ReLUd(wltal−1t)forl≥1(2.6)We can use Equation (1.11) to make the second term more similar to the first term.
ϵlt=UltReLUD(∑uEluwlu(al−1u+ϵl−1u))−UltReLUD(Eltwltal−1t) for l≥1(2.7)In order to simplify this further, we will break up the expression inside the first ReLU.
To do this, we first notice that if we are holding pre-activation constant, then a ReLU is just a diagonal matrix, with diagonal values equal to 1 or 0, depending on the sign of the pre-activation. We use this fact to replace the first ReLUD with the diagonal matrix Rl∈RD×D and the second one with Rlt∈RD×D:
ϵlt=UltRl∑uEluwlu(al−1u+ϵl−1u)−UltRltEluwltal−1tforl≥1(2.8)(Rl)i,j:={1ifi=jand(∑uEluwlu(al−1u+ϵl−1u))i>00otherwise(2.9)(Rlt)i,j:={1ifi=jand(Eltwltal−1t)i>00otherwise(2.10)Note that Rl depends on the input; all other matrices here are the same for all inputs.
Now, we split the expression for the error in Equation (2.8) into three parts:
1) The embedding overlap error ˚ϵlt is the part of the error that is due to correct activation in active circuits is spilling over into other circuits, because we are using an overcomplete basis. I.e, it’s because the embeddings are not completely orthogonal.
˚ϵlt:=UltRl∑u≠tEluwlual−1u(2.11)This will turn out to be the leading error term for inactive circuits.
2) The propagation error ~ϵlt is the part of the error that is caused by propagating errors from the previous layer.
~ϵlt:=UltRl∑uEluwluϵl−1u(2.12)This will turn out to be the leading error term in active circuits from layer 2 onward, and the largest error overall.
3) The ReLU activation error ˚ϵlt occurs when the ReLUs used by a circuit activate differently that they would if there were no interference, i.e. no noise and no other circuits active:
¨ϵlt:=Ult(Rl−Rlt)Eltwltal−1t(2.13)This error term will turn out to be basically irrelevant in our construction.
The total error in Equation (2.8) is the sum of these three error components, Equations (2.11)-(2.13):
ϵlt=˚ϵlt+~ϵlt+¨ϵltforl≥1(2.14)˚ϵlt – The embedding overlap error
The embedding overlap error, defined by Equation (2.11), is the error we get from storing the circuit neurons into the network neurons in superposition.
Calculations:
Looking at Equation (2.11), we first note that we only have to sum over active circuits, since we assumed that al−1u=0 for inactive circuits. Remember that there are T circuits in total, but only z≪T are active at a time:
$$\mathring{\epsilon}^l_t := U^l_t R^l \sum_{u \neq t} E^l_u w^l_u a^{l-1}_u = U^l_t R^l \sum_{\substack{u\ \text{active} \\ u \neq t}} E^l_u w^l_u a^{l-1}_u \tag{3.1}$$
So now we only have to care about the network neurons that are used by active circuits. In general we can’t make any assumptions about whether these neurons are active (i.e. $(R^l)_{i,i} = 1$) or inactive (i.e. $(R^l)_{i,i} = 0$). We’ll therefore go with the most conservative estimate, the one that increases the error the most, which is $R^l \approx I$:
$$\left\|\mathring{\epsilon}^l_t\right\|^2 = \left\|U^l_t R^l \sum_{\substack{u\ \text{active} \\ u \neq t}} E^l_u w^l_u a^{l-1}_u\right\|^2 \lesssim \left\|U^l_t \sum_{\substack{u\ \text{active} \\ u \neq t}} E^l_u w^l_u a^{l-1}_u\right\|^2 \tag{3.2}$$
To calculate the mean square of this error, we assume that $w^l_u a^{l-1}_u$ for different circuits $u$ are uncorrelated, and then use Equation (1.13):
$$\mathbb{E}\!\left[\left\|\mathring{\epsilon}^l_t\right\|^2\right] \lesssim \mathbb{E}\!\left[\left\|U^l_t \sum_{\substack{u\ \text{active} \\ u \neq t}} E^l_u w^l_u a^{l-1}_u\right\|^2\right] = \frac{d}{D} \sum_{\substack{u\ \text{active} \\ u \neq t}} \mathbb{E}\!\left[\left\|w^l_u a^{l-1}_u\right\|^2\right] \tag{3.3}$$
The sum over $u$ then has $(z-1)$ or $z$ terms, depending on whether circuit $t$ is active or inactive:
$$\mathbb{E}\!\left[\left\|\mathring{\epsilon}^l_t\right\|^2\right] = \begin{cases} O\!\left(\frac{(z-1)\,d}{D}\right) & \text{if } t \text{ is active} \\ O\!\left(\frac{z\,d}{D}\right) & \text{if } t \text{ is inactive} \end{cases} \tag{3.4}$$
This gives us the typical size of the embedding overlap error:
$$\mathring{\epsilon}_{\text{active}} = O\!\left(\sqrt{\frac{(z-1)\,d}{D}}\right) \tag{3.5}$$
$$\mathring{\epsilon}_{\text{inactive}} = O\!\left(\sqrt{\frac{z\,d}{D}}\right) \tag{3.6}$$
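As a quick numerical sanity check of this scaling (our own sketch, using dense random embeddings rather than the sparse allocation of the actual construction; the $d/D$ overlap statistics behind Equation (1.13) are similar either way):

```python
import numpy as np

rng = np.random.default_rng(0)
D, d, z, trials = 1024, 8, 16, 200

errs = []
for _ in range(trials):
    # Circuit 0 plays the role of t; circuits 1..z are the other active circuits.
    E = rng.standard_normal((z + 1, D, d)) / np.sqrt(D)   # approximately orthonormal columns
    U = E.transpose(0, 2, 1)                              # read-off U_u = E_u^T
    s = rng.standard_normal((z + 1, d))
    s /= np.linalg.norm(s, axis=1, keepdims=True)         # unit-norm signals w_u a_u
    noise = sum(E[u] @ s[u] for u in range(1, z + 1))     # interference hitting circuit 0
    errs.append(np.linalg.norm(U[0] @ noise))

print(np.mean(errs), np.sqrt(z * d / D))  # both come out around 0.35
```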
$\tilde{\epsilon}^l_t$ – The propagation error
The propagation error, defined by Equation (2.12), is the largest and most important error overall. This error occurs when we perform computation in superposition over multiple layers, instead of just storing variables in superposition or performing other types of single-layer superposition operations.
The existence of this error term is related to the fact that we need to embed not just the neurons, but also the weights of the circuits into the network. As opposed to the neuron activations of a circuit, the weights of a circuit don’t go away just because the circuit is turned off. This is why we get a factor of $\sqrt{T}$ in this term, where $T$ is the total number of circuits, not just the number of active circuits $z \ll T$. This is why this error ends up being so large.
Since there are no errors at layer 0, we get $\tilde{\epsilon}^1_t = 0$, i.e. the propagation error does not show up until layer 2.
Calculation:
If we were to just conservatively assume that $R^l = I$ for the purpose of this error estimate, as we did with the embedding error, we’d end up with an estimate:
$$\tilde{\epsilon}^l_t = O\!\left(\sqrt{\frac{Td}{D}}\right) \times \left(\text{previous error}\right) \quad \text{(not true!)}$$
Since $Td > D$, such an error term would quickly overtake the signal. Fortunately, the propagation error is actually much smaller than this, because of how our construction influences $R^l$.
As a reminder, $R^l$ is a diagonal matrix of ReLU activations, defined as:
$$\left(R^l\right)_{i,j} := \begin{cases} 1 & \text{if } i = j \text{ and } \left(\sum_u E^l_u w^l_u \left(a^{l-1}_u + \epsilon^{l-1}_u\right)\right)_i > 0 \\ 0 & \text{otherwise} \end{cases} \tag{4.1}$$
We will estimate the propagation error in Equation (2.12) by breaking it up into two cases: the error on neurons that are only used by inactive circuits, and the error on neurons that are used by at least one active circuit.
Case 1: Neurons used only by inactive circuits. For neurons $i$ that are only used by inactive circuits:
$$\left(\sum_u E^l_u w^l_u \left(a^{l-1}_u + \epsilon^{l-1}_u\right)\right)_i = \left(\sum_{u\ \text{inactive}} E^l_u w^l_u \left(a^{l-1}_u + \epsilon^{l-1}_u\right)\right)_i \tag{4.2}$$
Our assumption 3 at the start of the Construction section was that $a^l_u = 0$ for inactive circuits. Combining this with Equation (1.2), we have:
$$\mathrm{ReLU}\!\left(w^l_u a^{l-1}_u\right) = 0 \;\Rightarrow\; w^l_u a^{l-1}_u \leq 0 \tag{4.3}$$
This is where our crucial assumption 2 from the start of the Construction section comes into play. We required the circuits to be noise robust when inactive, meaning that:
$$\mathrm{ReLU}\!\left(w^l_u \left(a^{l-1}_u + \text{small noise}\right)\right) = 0 \;\Rightarrow\; w^l_u \left(a^{l-1}_u + \text{small noise}\right) \leq 0 \tag{4.4}$$
So, assuming that the previous error $\epsilon^{l-1}_u$ is sufficiently small compared to the noise tolerance of the circuits, we get $(R^l)_{i,i} = 0$, provided that neuron $i$ is only used by inactive circuits.
Case 2: Neurons that are used by at least one active circuit. Here, we make the same conservative estimate on $R^l$ that we used when calculating the embedding overlap error, i.e. $(R^l)_{i,i} = 1$ for these neurons.
This means that the propagation error can flow from active to active, inactive to active, and active to inactive circuits.
There is also a small amount of error propagation from inactive to inactive circuits, whenever the embedding overlap between two inactive circuits also overlaps with an active circuit. But this flow is very suppressed.
To model this, we start with the approximation of $R^l$ we derived in our two cases above:
$$\left(R^l\right)_{i,i} \approx \min\!\left[1,\; \sum_{v\ \text{active}} \left(E^l_v U^l_v\right)_{i,i}\right] \tag{4.5}$$
The minimum in this expression is very annoying to deal with, so we overestimate the error a tiny bit more by using the approximation:
$$U^l_t R^l E^l_u \approx \begin{cases} U^l_t E^l_u & \text{if either circuit } t \text{ or } u \text{ is active} \\ U^l_t \sum_{v\ \text{active}} E^l_v U^l_v E^l_u & \text{otherwise} \end{cases} \tag{4.6}$$
As a reminder, the definition of the propagation error $\tilde{\epsilon}^l_t$ was:
$$\tilde{\epsilon}^l_t := U^l_t R^l \sum_u E^l_u w^l_u \epsilon^{l-1}_u \tag{2.12}$$
Inserting our approximation of $U^l_t R^l E^l_u$ into this yields:
$$\tilde{\epsilon}^l_{t\ \text{active}} \lesssim w^l_t \epsilon^{l-1}_t + U^l_t \sum_{u \neq t} E^l_u w^l_u \epsilon^{l-1}_u \tag{4.7}$$
$$\tilde{\epsilon}^l_{t\ \text{inactive}} \lesssim U^l_t \sum_{u\ \text{active}} E^l_u w^l_u \epsilon^{l-1}_u + U^l_t \sum_{v\ \text{active}} E^l_v U^l_v \sum_{u\ \text{inactive}} E^l_u w^l_u \epsilon^{l-1}_u \tag{4.8}$$
Using similar calculations to those for the embedding overlap error, we get:
$$\tilde{\epsilon}^l_{\text{active}} = O(1)\,\epsilon^{l-1}_{\text{active}} + O\!\left(\sqrt{\frac{Td}{D}}\right)\epsilon^{l-1}_{\text{inactive}} \tag{4.9}$$
$$\tilde{\epsilon}^l_{\text{inactive}} = O\!\left(\sqrt{\frac{zd}{D}}\right)\epsilon^{l-1}_{\text{active}} + O\!\left(\sqrt{\frac{zTd^2}{D^2}}\right)\epsilon^{l-1}_{\text{inactive}} \tag{4.10}$$
$\ddot{\epsilon}^l_t$ – The ReLU activation error
This error term, defined in Equation (2.13), ends up being negligible. Sometimes it might even reduce the overall error a little.
Calculations:
To help us show this, we introduce $\Delta_t$:
$$\Delta_t := \sum_{u \neq t} E^l_u w^l_u a^{l-1}_u + \sum_u E^l_u w^l_u \epsilon^{l-1}_u \tag{5.1}$$
With this, the definition of $R^l$, Equation (2.9), becomes:
$$\left(R^l\right)_{i,j} := \begin{cases} 1 & \text{if } i = j \text{ and } \left(E^l_t w^l_t a^{l-1}_t + \Delta_t\right)_i > 0 \\ 0 & \text{otherwise} \end{cases} \tag{5.2}$$
Remember that the definition of $R^l_t$ is:
$$\left(R^l_t\right)_{i,j} := \begin{cases} 1 & \text{if } i = j \text{ and } \left(E^l_t w^l_t a^{l-1}_t\right)_i > 0 \\ 0 & \text{otherwise} \end{cases} \tag{2.10}$$
$\ddot{\epsilon}^l_t$, Equation (2.13), can thus be written as:
$$\left(\ddot{\epsilon}^l_t\right)_j := \sum_i \left(U^l_t\right)_{j,i} \left(R^l - R^l_t\right)_{i,i} \left(E^l_t w^l_t a^{l-1}_t\right)_i \tag{5.3}$$
There are two combinations of $(R^l)_{i,i}$ and $(R^l_t)_{i,i}$ that can contribute to $\ddot{\epsilon}^l_t$. These are $(R^l)_{i,i} = 1,\ (R^l_t)_{i,i} = 0$, and $(R^l)_{i,i} = 0,\ (R^l_t)_{i,i} = 1$.
Case 1: $(R^l)_{i,i} = 1,\ (R^l_t)_{i,i} = 0$
This will happen if and only if
$$\left(\Delta_t\right)_i > -\left(E^l_t w^l_t a^{l-1}_t\right)_i \geq 0 \tag{5.4}$$
and in this case the $\ddot{\epsilon}^l_t$ contribution term will be:
$$\left(R^l - R^l_t\right)_{i,i} \left(E^l_t w^l_t a^{l-1}_t\right)_i = \left(E^l_t w^l_t a^{l-1}_t\right)_i \tag{5.5}$$
$\Delta_t$ is the source of the error calculated in the previous sections. Notice that in this case the ReLU activation error contribution $\left(E^l_t w^l_t a^{l-1}_t\right)_i$ is smaller in magnitude and has opposite sign compared to $\left(\Delta_t\right)_i$. We can therefore safely assume that it will not increase the overall error.
Case 2: $(R^l)_{i,i} = 0,\ (R^l_t)_{i,i} = 1$
This will happen if and only if:
$$-\left(\Delta_t\right)_i \geq \left(E^l_t w^l_t a^{l-1}_t\right)_i > 0 \tag{5.6}$$
In this case, the $\ddot{\epsilon}^l_t$ contribution term will be:
$$\left(R^l - R^l_t\right)_{i,i} \left(E^l_t w^l_t a^{l-1}_t\right)_i = -\left(E^l_t w^l_t a^{l-1}_t\right)_i \tag{5.7}$$
So, the ReLU activation error contribution $-\left(E^l_t w^l_t a^{l-1}_t\right)_i$ is still smaller in magnitude than $\left(\Delta_t\right)_i$, but does have the same sign as $\left(\Delta_t\right)_i$.
However, since $(R^l)_{i,i} = 0$, $\left(\Delta_t\right)_i$ does not contribute to the total error at all. But in our calculations for the other two error terms, we didn’t know the value of $(R^l)_{i,i}$, so we included this error term anyway.
So in this case, the error term coming from the ReLU activation error is also already more than accounted for.
$\epsilon^l_t$ – Adding up all the errors
Let’s see how the three error terms add up, layer by layer:
Layer 0
There are no errors here because nothing has happened yet. See equation (2.5):
$$\epsilon^0_{\text{active}} = 0 \tag{6.1}$$
$$\epsilon^0_{\text{inactive}} = 0 \tag{6.2}$$
Layer 1
Since there was no error in the previous layer, we only get the embedding overlap error. From Equations (3.5) and (3.6):
$$\epsilon^1_{\text{active}} = \mathring{\epsilon}_{\text{active}} = O\!\left(\sqrt{\frac{(z-1)\,d}{D}}\right) \tag{6.3}$$
$$\epsilon^1_{\text{inactive}} = \mathring{\epsilon}_{\text{inactive}} = O\!\left(\sqrt{\frac{z\,d}{D}}\right) \tag{6.4}$$
Layer 2
This is the first layer where we get both the embedding overlap error (3.5)-(3.6) and the propagation error (4.9)-(4.10):
$$\epsilon^2_{\text{active}} = \mathring{\epsilon}_{\text{active}} + \tilde{\epsilon}^2_{\text{active}} = O\!\left(\sqrt{\frac{(z-1)\,d}{D}}\right) + O(1)\,\epsilon^1_{\text{active}} + O\!\left(\sqrt{\frac{Td}{D}}\right)\epsilon^1_{\text{inactive}} \tag{6.5}$$
The leading term is the last term, i.e. the propagation error flowing in from the inactive circuits in the previous layer:
$$\epsilon^2_{\text{active}} \approx O\!\left(\sqrt{\frac{Td}{D}}\right)\epsilon^1_{\text{inactive}} = O\!\left(\sqrt{\frac{zTd^2}{D^2}}\right) \tag{6.6}$$
The typical noise in any inactive circuit is:
$$\epsilon^2_{\text{inactive}} = \mathring{\epsilon}_{\text{inactive}} + \tilde{\epsilon}^2_{\text{inactive}} = O\!\left(\sqrt{\frac{zd}{D}}\right) + O\!\left(\sqrt{\frac{zd}{D}}\right)\epsilon^1_{\text{active}} + O\!\left(\sqrt{\frac{zTd^2}{D^2}}\right)\epsilon^1_{\text{inactive}} \tag{6.7}$$
Assuming that the noise in the previous layer is small, the leading term is the embedding overlap error, $\mathring{\epsilon}$:
$$\epsilon^2_{\text{inactive}} \approx \mathring{\epsilon}_{\text{inactive}} = O\!\left(\sqrt{\frac{zd}{D}}\right) \tag{6.8}$$
Layer 3
$$\epsilon^3_{\text{active}} = \mathring{\epsilon}_{\text{active}} + \tilde{\epsilon}^3_{\text{active}} = O\!\left(\sqrt{\frac{(z-1)\,d}{D}}\right) + O(1)\,\epsilon^2_{\text{active}} + O\!\left(\sqrt{\frac{Td}{D}}\right)\epsilon^2_{\text{inactive}} \tag{6.9}$$
Now the last two terms are of the same size.
$$\epsilon^3_{\text{active}} \approx O\!\left(\sqrt{\frac{Td}{D}}\right)\sqrt{\left(\epsilon^1_{\text{inactive}}\right)^2 + \left(\epsilon^2_{\text{inactive}}\right)^2} \approx O\!\left(\sqrt{\frac{2Td}{D}}\right)\mathring{\epsilon}_{\text{inactive}} = O\!\left(\sqrt{\frac{2zTd^2}{D^2}}\right) \tag{6.10}$$
For the same reason as in the previous layer, the inactive error is
$$\epsilon^3_{\text{inactive}} \approx \mathring{\epsilon}_{\text{inactive}} = O\!\left(\sqrt{\frac{zd}{D}}\right) \tag{6.11}$$
From here on, it just keeps going like this.
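To see the pattern continue, here is a small sketch (our own code, with arbitrary example parameters) that iterates the leading-order recursion from Equations (3.5)-(3.6) and (4.9)-(4.10), with all constants hidden in the $O(\cdot)$’s set to 1 and independent error components added in quadrature:

```python
import numpy as np

def leading_errors(D, d, T, z, L):
    """Iterate the leading-order error recursion per layer.
    Only tracks how the errors scale, not their exact size."""
    emb_active = np.sqrt((z - 1) * d / D)    # embedding overlap error, active circuits
    emb_inactive = np.sqrt(z * d / D)        # embedding overlap error, inactive circuits
    eps_active, eps_inactive = [0.0], [0.0]  # layer 0: no error yet
    for _ in range(L):
        prev_a, prev_i = eps_active[-1], eps_inactive[-1]
        prop_active = np.sqrt(prev_a**2 + (T * d / D) * prev_i**2)
        prop_inactive = np.sqrt((z * d / D) * prev_a**2 + (z * T * d**2 / D**2) * prev_i**2)
        eps_active.append(np.sqrt(emb_active**2 + prop_active**2))
        eps_inactive.append(np.sqrt(emb_inactive**2 + prop_inactive**2))
    return eps_active, eps_inactive

# Example parameters, chosen so that sqrt(z*T*d^2/D^2) = 0.5 < 1:
act, inact = leading_errors(D=1000, d=10, T=500, z=5, L=6)
for l, (a, i) in enumerate(zip(act, inact)):
    print(f"layer {l}: active ~ {a:.2f}, inactive ~ {i:.2f}")
# The active error grows roughly like sqrt((l-1) * z * T * d^2 / D^2) from layer 2 on,
# while the inactive error stays close to sqrt(z * d / D).
```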
Worst-case errors vs mean square errors
So far, our error calculations have only dealt with mean square errors. However, we also need to briefly address worst-case errors. These are the reason we need an embedding redundancy $S > 1$ in the construction.
For this, we’re mainly concerned with the embedding overlap error, because that error term comes from just a few sources (the $z$ active circuits), which means its variance may be high. In contrast, the propagation error comes from adding up many smaller contributions (the approximately $\frac{zS^2Td}{D}$[19] circuits that have non-zero embedding overlap error in the previous layer), so we expect it to be well behaved, i.e. well described by the previous mean square error calculations.
The worst-case embedding overlap error happens if some circuit is unlucky enough to be an embedding neighbour to all $z$ active circuits. For active circuits, the maximum number of active neighbours is $(z-1)$, since a circuit can’t be its own neighbour.
From (3.1) and (1.13), we calculate
$$\max\!\left[\left|\mathring{\epsilon}_{\text{active}}\right|\right] = O\!\left(\frac{z-1}{S}\right) \tag{6.12}$$
$$\max\!\left[\left|\mathring{\epsilon}_{\text{inactive}}\right|\right] = O\!\left(\frac{z}{S}\right) \tag{6.13}$$
So, as long as the embedding redundancy $S$ is sufficiently large compared to the number of active circuits $z$, we should be fine.
Summary:
The main source of error is signal from the active circuits bleeding over into the inactive ones, which then enters back into the active circuits as noise in the next layer.
The noise in the active circuits accumulates from layer to layer. The noise in the inactive circuits does not accumulate.
At layer 0, there are no errors, because nothing has happened yet.
$$\epsilon^0_{\text{active}} = 0 \tag{6.14}$$
$$\epsilon^0_{\text{inactive}} = 0 \tag{6.15}$$
At layer 1 and onward, the leading term for the error on inactive circuits is:
$$\epsilon^l_{\text{inactive}} = O\!\left(\sqrt{\frac{zd}{D}}\right) \quad \text{for } l \geq 1 \tag{6.16}$$
At layer 1, the leading term for the error on active circuits is:
$$\epsilon^1_{\text{active}} = O\!\left(\sqrt{\frac{(z-1)\,d}{D}}\right) \tag{6.17}$$
But from layer 2 onward, the leading term for the error on active circuits is:
$$\epsilon^l_{\text{active}} = O\!\left(\sqrt{\frac{(l-1)\,zTd^2}{D^2}}\right) \quad \text{for } l \geq 2 \tag{6.18}$$
Discussion
Noise correction/suppression is necessary
Without any type of noise correction or error suppression, the error on the circuit activations would grow by a factor of $O\!\left(\sqrt{\frac{Td}{D}}\right)$ per layer.
$Td$ is the total number of neurons per layer for all small networks combined, and $D$ is the number of neurons per layer in the large network. If $Td \leq D$ then we might as well encode one feature per neuron, and not bother with superposition. So, we assume $Td > D$, ideally even $Td \gg D$.
In our construction, the way we suppress errors is to use the flat part of the ReLU, both to clear away noise in inactive circuits, and to prevent noise from moving between inactive circuits. Specifically, we assumed in assumption 2 of our construction that each small circuit is somewhat noise robust, such that any network neuron that is not connected to a currently active circuit will be inactive (i.e. the ReLU pre-activation is $\leq 0$), provided that the error on the circuit activations in the preceding layer is small enough. This means that for the error to propagate to the next layer, it has to pass through a very small fraction of ‘open’ neurons, which is what keeps the error down to a more manageable $O\!\left(\sqrt{\frac{zTd^2}{D^2}}\right)$ per layer.
However, we do not in general predict sparse ReLU activations for networks implementing computation in superposition
The above might then seem to predict a very low activation rate for neurons in LLMs and other neural networks, if they are indeed implementing computation in superposition. That’s not what we see in real large networks: e.g. MLP neurons in gpt2 have an activation rate of about 20%, much higher than in our construction.
But this predicted low activation rate is actually just an artefact of the specific setup for computation in superposition we present here. Instead of suppressing noise with the flat part of a single ReLU function, we can also create a flat slope using combinations of multiple active neurons. E.g. a network could combine two ReLU neurons to create the ‘meta activation function’ $f(x) = \mathrm{ReLU}(x) - \mathrm{ReLU}(x-1)$. This combined function is flat for both $x < 0$ (both neurons are ‘off’) and $x > 1$ (both neurons are ‘on’). We can then embed circuit neurons into different ‘meta neurons’ $f$ instead of embedding them into the raw network neurons.
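As a minimal illustration (our own sketch, not code from any experiment mentioned in this post):

```python
import numpy as np

def relu(x):
    return np.maximum(x, 0.0)

def meta_neuron(x):
    """'Meta activation function' built from two ReLU neurons.
    Flat (and hence noise-suppressing) both for x < 0 and for x > 1."""
    return relu(x) - relu(x - 1.0)

xs = np.array([-0.5, 0.0, 0.3, 0.7, 1.0, 1.5])
print(meta_neuron(xs))  # [0.  0.  0.3 0.7 1.  1. ] -- flat regions on both sides
```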
At a glance, this might seem inefficient compared to using the raw ReLUs, but we don’t think it necessarily is. If f is a more suitable activation function for implementing the computation of many of the circuits, those circuits might effectively have a smaller width d under this implementation. The network might even mix and match different ‘meta activation functions’ like this in the same layer to suit different circuits.
But we do tentatively predict that circuits only use small subsets of network neurons
So, while neurons in the network don’t have to activate sparsely, it is necessary that each circuit only uses a small fraction of network neurons for anything like this construction to work. This is because any connection that lets through signal will also let through noise, and at least the neurons[20] used by active circuits must let signal through, or else they won’t be of any use.
Getting around this would require some completely different kind of noise reduction. It seems difficult to do this using MLPs alone. Perhaps operations like layer norm and softmax can help with noise reduction somehow, but that’s not something we have investigated yet.
Linda: But using few neurons per circuit does just seem like a pretty easy way to reduce noise, so I expect that networks implementing computation in superposition would do this to some extent, whether or not they also use other noise suppression methods. I have some very preliminary experimental evidence that maybe supports this conclusion[21]. More on that in a future post, hopefully.
Acknowledgements
This work was supported by Open Philanthropy.
The previous post had annoying log factors in the formulas everywhere. Here, we get to do away with those.
The linked blogpost has a section called “Computation in Superposition”. However, on closer inspection, this section only presents a model with one layer in superposition. See the section “Implications for experiments in computation in superposition” for why this is insufficient.
This result also seems to hold if the circuits don’t have a uniform width d across the L layers. However, it might not straightforwardly hold if different circuits interact with each other, e.g. if some circuits take the outputs of previous circuits as inputs. We think this makes intuitive sense from an information-theoretic perspective as well. If circuits interact, we have to take those interactions into account in the description length.
We get this formula from requiring the noise derived in Some costs of superposition to be small.
By ‘layers’ we mean $A^1, A^2, \ldots, A^L$, as defined in equation (1.1). So, ‘layer 2’ refers to $A^2$ or any linear readouts of $A^2$. We don’t count $A^0$ in the indexing because no computation has happened yet at that point, and when training toy models of computation in superposition $A^0$ will often not even be explicitly represented.
If we want to study this noise, but don’t want to construct a complicated multi-layer toy model, we can add some noise to the inputs, to simulate the situation we’d be in at a later layer.
If circuits are doing similar things, this lets us reuse some computational machinery between circuits, but it can also make worst-case errors worse if we’re not careful, because errors in different circuits can add systematically instead of randomly.
The results here can be generalised to networks with residual streams, although for this, the embedding has to be done differently from the post linked above, otherwise the error propagation will not be contained.
In other words, we want the outputs of all $T$ circuits to be $\epsilon$-linearly represented.
You might notice that for both assumptions 2 and 3 to hold simultaneously, each small network needs a negative bias. However, we did not include a separate bias term in the construction, and the bias can’t be baked into $w$ (as is often done) because that would require one of the neurons to be a non-zero constant, which contradicts assumption 2.
This is one of several ways that reality turned out to be more complicated than our theory. The good news is that this can be dealt with, in a way that preserves the general result, but that is beyond the scope of this post.
This involves adding an extra neuron to the circuit, i.e. increasing d by 1.
We’ll deal with the consequences of this pretence in the Error calculation section.
There exists a theoretical upper bound on $S$: $S(S-1) \leq \frac{\frac{D}{d}\left(\frac{D}{d}-1\right)}{T}$. However, the proof[22][23] of this upper bound is not constructive, so there is no guarantee that it can be reached.
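For example (with made-up numbers): if $\frac{D}{d} = 100$ and $T = 1000$, the bound gives $S(S-1) \leq \frac{100 \cdot 99}{1000} = 9.9$, i.e. $S \leq 3$.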
Our allocation algorithm falls short of this upper bound. If you think you can devise a construction that gets closer to it, we’d love to hear from you.
We assume for simplicity that D is divisible by d.
Linda initially thought this set would include only primes. Thanks to Stefan Heimersheim for pointing out that more numbers could be permitted.
This row could instead be
which would give us more possible allocations (i.e. larger $T$ for a given $S$ and $\frac{D}{d}$, or larger possible $S$ for a given $T$ and $\frac{D}{d}$). But that would result in a more uneven distribution of how much each neuron is used.
The later calculations assume that the allocation is evenly distributed over neurons. Dropping that assumption would make both the calculation harder and the errors larger. Getting to increase $S$ is probably not worth it.
Proof of (a.3):
Case: $i = j$
We know that (a.3) must be true in this case because $\mathrm{step}_x \neq \mathrm{step}_y$.
Case: $i \neq j$
From (a.2) we get that
$$\mathrm{step}_x = p S^n; \quad \mathrm{step}_y = q S^m \quad \text{for } p, q \in P_S;\ n, m \in \mathbb{N}_0 \tag{a.4}$$
We have proved (a.3) in this case if we can prove that
$$i\,p\,S^n \neq j\,q\,S^m \tag{a.5}$$
We also know that (a.5) must be true in this case, because $iS^n$ will differ from $jS^m$ in at least one prime factor that is smaller than $S$. $p$ and $q$ can’t make up for that difference, since by definition (a.1) they don’t contain prime factors smaller than $S$.
We initially did not think of this, and only noticed the importance of re-shuffling from layer to layer when implementing this construction in code.
In the error calculation, when calculating how much noise is transported from inactive circuits to active circuits, we assume no correlation between the amount of noise in the inactive circuits and the extent to which they share neurons with active circuits. But the noise depends on how much neuron overlap they had with active circuits in the previous layer. Therefore this assumption will be false if we don’t re-shuffle the neuron allocation from layer to layer.
Not only will our calculation be wrong (this could be solved with more calculations), but the errors will also be much larger, which simply makes for a worse construction.
The probability of two circuits sharing one large network neuron (per small circuit neuron) is $\frac{S^2 d}{D}$. Given that there are $T$ circuits in total, this gives us $\frac{S^2 T d}{D}$ ‘neighbour’ circuits for each small circuit. Since there are $z$ active circuits, there are approximately $\frac{z S^2 T d}{D}$ active circuit neighbours.
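For example (with made-up numbers): $D = 1000$, $d = 10$, $S = 2$, $T = 2000$ and $z = 5$ give roughly $\frac{zS^2Td}{D} = \frac{5 \cdot 4 \cdot 2000 \cdot 10}{1000} = 400$ active circuit neighbours.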
Or ‘meta neurons’ like the function f we discussed above.
Linda: I trained some toy models of superposition with only one computational layer. This did not result in circuits connecting sparsely to the network’s neurons. Then I trained again with some noise added to the network inputs (in order to simulate the situation in a 2+ layer network doing computation in superposition), to see how the network would learn to filter it. This did result in circuits connecting sparsely to the network’s neurons.
This suggests to me that there is no completely different way to filter superposition noise in an MLP that we haven’t thought of yet. So networks doing computation in superposition would basically be forced to connect circuits to neurons sparsely to deal with the noise, as the math in this post suggests.
However, this experimental investigation is still a work in progress.[24]
Proof:
There are $\frac{D}{d}\left(\frac{D}{d}-1\right)$ ordered pairs of neurons in the set of $\frac{D}{d}$ neurons. Each small circuit is allocated $S$ neurons out of that set, accounting for $S(S-1)$ pairs. No two small circuits can share a pair, which gives us the bound $T \leq \frac{\frac{D}{d}\left(\frac{D}{d}-1\right)}{S(S-1)}$.
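To illustrate the counting argument (this only checks the no-shared-pair condition, not the other requirements on the allocation; the example allocation is a hypothetical toy case):

```python
from itertools import combinations

def check_allocation(allocation, n_neurons, S):
    """Verify that no two circuits share a pair of neurons, and compare the
    number of circuits T to the pair-counting bound n(n-1) / (S(S-1))."""
    seen_pairs = set()
    for circuit in allocation:
        assert len(circuit) == S
        for pair in combinations(sorted(circuit), 2):
            assert pair not in seen_pairs, "two circuits share a pair of neurons"
            seen_pairs.add(pair)
    T = len(allocation)
    bound = n_neurons * (n_neurons - 1) / (S * (S - 1))
    assert T <= bound
    return T, bound

# Toy example with D/d = 7 and S = 3: the Fano plane, which happens to attain
# the bound T = 7 * 6 / (3 * 2) = 7 for these particular parameters.
fano = [{0, 1, 2}, {0, 3, 4}, {0, 5, 6}, {1, 3, 5}, {1, 4, 6}, {2, 3, 6}, {2, 4, 5}]
print(check_allocation(fano, n_neurons=7, S=3))  # (7, 7.0)
```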
I (Linda) first got this bound and proof from ChatGPT (free version). According to ChatGPT it is a “known upper bound (due to the Erdős–Ko–Rado-type results)”.
My general experience is that ChatGPT is very good at finding known theorems (i.e. known to someone else, but not to me) that apply to any math problem I give it.
I also gave this problem to Claude as an experiment (some level of paid version offered by a friend). Claude tried to figure it out itself, but kept getting confused and just produced a lot of nonsense.
These networks were trained on L2 loss, which is probably the wrong loss function for incentivising superposition. When using the L2 loss, the network doesn’t care much about separating different circuits. It’s happy to just embed two circuits right on top of each other into the same set of network neurons. I don’t really consider this to be computation in superposition. However, this should not affect the need for the network to prevent noise amplification, which is why I think these results are already some weak evidence for the prediction.
I’ll make a better toy setup and hopefully present the results of these experiments in a future post.