Author order randomized. Authors contributed roughly equally — see attribution section for details.
Update as of July 2024: we have collaborated with @LawrenceC to expand section 1 of this post into an arXiv paper, which culminates in a formal proof that computation in superposition can be leveraged to emulate sparse boolean circuits of arbitrary depth in small neural networks.
What kind of document is this?
What you have in front of you is so far a rough writeup rather than a clean text. As we realized that our work is currently highly relevant to recent questions posed by interpretability researchers, we put together a lightly edited version of private notes we’ve written over the last ~4 months. If you’d be interested in writing up a cleaner version, get in touch, or just do it. We’re making these notes public before we’re done with the project because of some combination of (1) seeing others think along similar lines and wanting to make it less likely that people (including us) spend time duplicating work, (2) providing a frame which we think offers plenty of concrete immediate problems for people to work on independently[1], and (3) seeking feedback to decrease the chance we spend a bunch of time on nonsense.
1 minute summary
Superposition is a mechanism that might allow neural networks to represent the values of many more features than they have neurons, provided that those features are present sparsely in the dataset. However, until now, an understanding of how computation can be done in a compressed way directly on these stored features has been limited to a few very specific tasks (for example here). The goal of this post is to lay the groundwork for a picture of how computation in superposition can be done in general. We hope this will enable future research to build interpretability techniques for reverse engineering circuits that are manifestly in superposition.
Our main contributions are:
Formalisation of some tasks performed by MLPs and attention layers in terms of computation on boolean features stored in superposition.
A family of novel constructions which allow a single layer MLP to compute a large number of boolean functions of features entirely in superposition.
Discussion of how these constructions could be leveraged:
to emulate arbitrary large sparse boolean circuits entirely in superposition
to allow the QK circuit of an attention head to dynamically choose a boolean expression and attend to past token positions where this expression is true.
A construction which allows the QK circuit of an attention head to check for the presence of surprisingly many query-key feature pairs simultaneously in superposition, on the order of one pair per parameter[2].
10 minute summary
Thanks to Nicholas Goldowsky-Dill for producing an early version of this summary/diagrams and generally for being instrumental in distilling this post.
Central to our analysis of MLPs is the Universal-AND (U-AND) problem:
We are given m input boolean features f_1, f_2, …, f_m. These features are sparse, meaning that on most inputs only a few features are true, and they are encoded as directions in the input space R^{d_0}.
We want to compute all (m choose 2) possible binary conjunctions of these inputs (f_1∧f_2, f_1∧f_3, …), and output them in different linear directions. Some small bounded error in these output values is tolerated.
We want to compute this in a single MLP layer (R·ReLU(W→x+→b)) with as few neurons as possible, for a weight matrix W, bias →b, and ‘readoff’ matrix R.
This problem is central to understanding computation in superposition because:
Many features that people think of are boolean in nature, and reverse engineering the circuits that are involved in constructing them consists of understanding how simpler boolean features are combined to make them. For example, in a vision model, the feature which is 1 if there is a car in the image may be computed by combining the ‘wheels at the bottom of the image’ feature AND the ‘windows at the top’ feature [3].
We will be focusing on the part of the network before the readoff with the matrix R. In an analogous way to the toy model of superposition, the readoff matrix R stands in for whatever downstream model components later read in these features; we don’t claim a layer of this size is literally present in the network. If we can do this task with an MLP with fewer than (m choose 2) neurons, then in a sense we have computed more boolean functions than we have neurons, and the values of these functions will be stored in superposition in the MLP activation space.
Any boolean function can be written as a linear combination of ANDs with different numbers of inputs. For example, XOR(A,B,C) = A + B + C − 2A∧B − 2A∧C − 2B∧C + 4A∧B∧C.
Therefore, if we can compute and linearly represent all the ANDs in superposition, then we can do so for any boolean function.
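As a quick concreteness check (ours, not part of the original notes), the identity above can be verified by brute force over all 2³ boolean inputs; the same kind of exhaustive check works for the AND-expansion of any small boolean function.

```python
# Check that XOR(A,B,C) = A + B + C - 2*A∧B - 2*A∧C - 2*B∧C + 4*A∧B∧C
# on all 8 boolean inputs (on 0/1 values, ∧ is just multiplication).
from itertools import product

for A, B, C in product([0, 1], repeat=3):
    lhs = A ^ B ^ C
    rhs = A + B + C - 2*A*B - 2*A*C - 2*B*C + 4*A*B*C
    assert lhs == rhs
print("XOR(A,B,C) matches its AND-expansion on all 8 inputs")
```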
If m = d_0 (the dimension of the input space), then we can store the input features using an orthonormal basis such as the neuron basis. A naive solution in this case would be to have one neuron per pair which is active if both inputs are true and 0 otherwise. This requires (m choose 2) = Θ(d_0^2) neurons, and involves no superposition:
On this input x1,x2 and x5 are true, and all other inputs are false.
We can do much better than this, computing all the pairwise ANDs up to a small error with many fewer neurons. To achieve this, we have each neuron care about a random subset of inputs, and we choose the bias such that each neuron is activated when at least two of them are on. This requires only d = Θ(polylog(d_0)) neurons:
Importantly:
A modified version works even when the input features are in superposition. In this case we cannot compute all ANDs of the potentially exponentially many features. Instead, we must pick up to ~Θ(d^2) logical gates to calculate at each stage.
A solution to U-AND can be generalized to compute many ANDs of more than two inputs, and therefore to compute arbitrary boolean expressions involving a small number of input variables, with surprisingly efficient asymptotic performance (superpolynomially many functions computed at once). This can be done simply by increasing the density of connections between inputs and neurons, which comes at the cost of interference terms going to zero more slowly.
It may be possible to stack multiple of these constructions in a row and therefore to emulate a large boolean circuit, in which each layer computes boolean functions on the outputs of the previous layer. However, if the interference is not carefully managed, the errors are likely to propagate and eventually become unmanageable. The details of how the errors propagate and how to mitigate this are beyond the scope of this work.
We study the performance of our constructions asymptotically in d, and expect that insofar as real models implement something like them, they will likely be importantly different in order to have low error at finite d.
If the ReLU is replaced by a quadratic activation function, we can provide a construction that is much more efficient in terms of computations per neuron. We suspect that this implies the existence of similarly efficient constructions with ReLU, and constructions that may perform better at finite d.
Our analysis of the QK part of an attention head centers on the task of skip feature-bigram checking:
Given residual stream vectors →a_1, …, →a_T (for sequence length T) storing boolean features in superposition.
Given a set B of skip feature-bigrams (SFBs) which specify which keys to attend to from each query in terms of features present in the query and key. A skip feature-bigram is a pair of features such as (→f6,→f13), and we say that an SFB is present in a query key pair if the first feature is present in the key and the second in the query.
We want to compute an attention score which contains, in each entry, the number of SFBs in B present in the query and key that correspond to that entry. To do so, we look for a suitable choice of the parameters in the weight matrix WQK, a dresid×dresid matrix of rank dhead. Some small bounded error is tolerated.
This framing is valuable for understanding the role played by the attention mechanism in superposed computation because:
It is a natural modification of the ‘attention head as information movement’ story that engages with the many independent features stored in residual stream vectors in parallel, rather than treating the vectors as atomic units. Each SFB can be thought of as implementing an operation corresponding to statements like ‘if feature →f13 is present in the query, then attend to keys for which feature →f6 is present’.
The stories normally given for the role played by a QK circuit can be reproduced as particular choices of B. For example, consider the set of ‘identity’ skip feature-bigrams: BId={‘if feature →fi is present in the query, then attend to keys for which feature →fi is also present’|∀i}. Checking for the presence of all SFBs in BId corresponds to attending to keys which are the same as the query.
There are also many sets B which are most naturally thought of in terms of performing each check in B individually.
A nice way to construct W_QK is as a sum of terms, one for each skip feature-bigram, each of which is a rank one matrix equal to the outer product of the two feature vectors in the SFB. In the case that all feature vectors are orthogonal (no superposition) you should be thinking of something like this:
where each of the rank one matrices, when multiplied by a residual stream vector on the right and left, performs a dot product on each side:
→a_s^T W_QK →a_t = ∑_i (→a_s · →f_{k_i})(→f_{q_i} · →a_t)
where (→f_{k_1}, →f_{q_1}), …, (→f_{k_|B|}, →f_{q_|B|}) are the feature bigrams in B with feature directions (→f_{k_i}, →f_{q_i}), and →a_s is a residual stream vector at sequence position s. Each of these rank one matrices contributes a value of 1 to →a_s^T W_QK →a_t if and only if the corresponding SFB is present. Since the matrix cannot have rank higher than d_head, typically we can only check for up to ~Θ(d_head) SFBs this way.
In fact we can check for many more SFBs than this, if we tolerate some small error. The construction is straightforward once we think of WQK as this sum of tensor products: we simply add more rank one matrices to the sum, and then approximate the sum as a rank dhead matrix, using the SVD or even a random projection matrix P. This construction can be easily generalised to the case that the residual stream stores features in superposition (provided we take care to manage the size of the interference terms) in which case WQK can be thought of as being constructed like this:
When multiplied by a residual stream vector on the right and left, this expression is →a_s^T W_QK →a_t = ∑_i (→a_s · →f_{k_i})(P→f_{q_i} · →a_t)
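Here is a minimal numpy sketch (ours, with made-up sizes) of the exact version of this construction: W_QK is literally the sum of one rank-one term per SFB, the features are in superposition (m > d_resid), and the bilinear form approximately counts how many SFBs from B are present in a given (key, query) pair. The further compression of this sum down to rank d_head (via the SVD or a random projection), and the extra interference it introduces, are omitted for brevity.

```python
import numpy as np

rng = np.random.default_rng(0)
d_resid, m, n_sfb = 512, 600, 128     # made-up sizes; m > d_resid, so the
                                      # features are stored in superposition

# almost-orthogonal feature directions in the residual stream
F = rng.normal(size=(m, d_resid)) / np.sqrt(d_resid)

# SFB number g pairs key feature 2g with query feature 2g+1
B = [(2 * g, 2 * g + 1) for g in range(n_sfb)]

# W_QK as a sum of rank-one outer products, one per skip feature-bigram
W_qk = np.zeros((d_resid, d_resid))
for k, q in B:
    W_qk += np.outer(F[k], F[q])

def resid(active):
    """Residual-stream vector with a sparse set of boolean features switched on."""
    return F[list(active)].sum(axis=0)

a_query = resid({1, 3, 301})                       # query-side features
print(resid({0, 2, 300}) @ W_qk @ a_query)         # two SFBs present  -> ≈ 2
print(resid({0, 299, 300}) @ W_qk @ a_query)       # one SFB present   -> ≈ 1
print(resid({298, 299, 300}) @ W_qk @ a_query)     # no SFB present    -> ≈ 0
```

At these sizes the interference in each printed score is roughly 0.1, so the three cases separate cleanly.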
Importantly:
It turns out that the interference becomes larger than the signal when roughly one SFB has been checked for per parameter: |B| = ~Θ(d_resid · d_head).
When there is structure to the set of SFBs that are being checked for, we can exploit this to check for even more SFBs with a single attention head.
If there is a particular linear structure to the geometric arrangement of feature vectors in the residual stream, many more SFBs can be checked for at once, but this time the story of how this happens isn’t the simplest to describe in terms of a list of SFBs. This suggests that our current description of what the QK circuit does is lacking. In fact, this example exemplifies computation performed by neural nets that we don’t think is best described by our current sparse boolean picture. It may be a good starting point for building a broader theory than we have so far that takes into account other structures.
Indeed, there are many open directions for improving our understanding of computation in superposition, and we’d be excited for others to do future research (theoretical and empirical) in this area.
Some theoretical directions include:
Fitting the OV circuit into the boolean computation picture
Studying error propagation when U-AND is applied sequentially
Finding constructions with better interference at finite d
Making the story of boolean computation in transformers more complete by studying things that have not been captured by our current tasks
Generalisations to continuous variables
Empirical directions include:
Training toy models to understand if NNs can learn U-AND and related tasks, and how learned algorithms differ.
Throwing existing interp techniques at NNs trained on these tasks and trying to study what we find. Which techniques can handle the superposition adequately?
Trying to find instances of computation in superposition happening in small language models.
Structure of the Post
In Section 1, we define the U-AND task precisely, and then walk through our construction and show that it solves the task. Then we generalise the construction in 2 important ways: in Section 1.1, we modify the construction to compute ANDs of input features which are stored in superposition, allowing us to stack multiple U-AND layers together to simulate a boolean circuit. In Section 1.2 we modify the construction to compute ANDs of more than 2 variables at the same time, allowing us to compute all sufficiently small[4] boolean functions of the inputs with a single MLP. Then in Section 1.3 we explore efficiency gains from replacing the ReLU with a quadratic activation function, and explore the consequences.
In Section 2 we explore a series of questions around how to interpret the maths in Section 1, in the style of FAQs. Each part of Section 2 is standalone and can be skipped, but we think that many of the concepts discussed there are valuable and frequently misunderstood.
In section 3 we turn to the QK circuit, carefully introducing the skip feature-bigram checking task, and we explain our construction. We also discuss two scenarios that allow for more SFBs to be checked for than the simplest construction would allow.
We discuss the relevance of our constructions to real models in Section 4, and conclude in Section 5 with more discussion on Open Directions.
Notation and Conventions
d is the dimension of some activation space. d_0 may also be used for the dimension of the input space, and d for the number of neurons in an MLP.
m is the number of input features. If the input features are stored in superposition, m>d, otherwise m=d
→e1,→e2,…,→ed denotes an orthogonal basis of vectors. The standard basis refers to the neuron basis.
All vectors are denoted with arrows on top like this: →fi
We use single lines to denote the size of a set like this: |Si| or the L2 norm of a vector like this: |→fi|
We say that a boolean function g has been computed ϵ-accurately for some small parameter ϵ if the computed output never differs from g by more than ϵ. That is, whenever the function outputs 1, the computation outputs a number in [1−ϵ, 1+ϵ], and whenever the function outputs 0, the computation outputs a number in [−ϵ, ϵ].
We say that a pair of unit vectors is ϵ-almost orthogonal (for a fixed parameter ϵ) if their dot product is <ϵ (equivalently, if they are orthogonal to ϵ-accuracy). We say that a collection of unit vectors is ϵ-almost-orthogonal if they are pairwise almost orthogonal. We assume ϵ to be a fixed small number throughout the paper (unless specified otherwise).
It is known that for fixed ϵ, one can fit exponentially (in d) many almost orthogonal vectors in a d-dimensional Euclidean space. Throughout this paper, we will assume present in each NN activation space a suitably “large” collection of almost-orthogonal vectors, which we call an overbasis.
Vectors in this overbasis will be called f-vectors[5], and denoted →f1,→f2,…,→fm. We assume they correspond to binary properties of inputs relevant to a neural net (such as “Does this picture contain a cat?”). When convenient, we will assume these f-vectors are generated in a suitably random way: it is known that a random collection of vectors is, with high probability, an almost orthogonal overbasis, so long as the number of vectors is not superexponentially large in d[6].
In this post we make extensive use of Big-O notation and its variants: little o, Θ, Ω, ω. See Wikipedia for definitions. We also make use of tilde notation, which means we ignore log factors. For example, by saying a function f(n) is Θ(g(n)), we mean that there are constants c_1, c_2 > 0 and a natural number N such that for all n > N, we have c_1·g(n) ≤ f(n) ≤ c_2·g(n). By saying a quantity is ~Θ(f(d)), we mean that this is true up to a factor that is a polynomial of log d — i.e., that it is asymptotically between f(d)/polylog(d) and f(d)·polylog(d).
1 The Universal AND
We introduce a simple and central component in our framework, which we call the Universal AND component or U-AND for short. We start by introducing the most basic version of the problem this component solves. We then provide our solution to the simplest version of this problem. We later discuss a few generalizations: to inputs which store features in superposition, and to higher numbers of inputs to each AND gate. More elaboration on U-AND — in particular, addressing why we think it’s a good question to ask — is provided in Section 2.
1.1 The U-AND task
The basic boolean Universal AND problem: Given an input vector which stores an orthogonal set of boolean features, compute a vector from which can be linearly read off the value of every pairwise AND of input features, up to a small error. You are allowed to use only a single-layer MLP and the challenge is to make this MLP as narrow as possible.
More precisely: Fix a small parameter ϵ>0 and let d_0 and ℓ be integers with d_0 ≥ ℓ[7]. Let →e_1, …, →e_{d_0} be the standard basis in R^{d_0}, i.e. →e_i is the vector whose ith component is 1 and whose other components are 0. Inputs are all at most ℓ-composite vectors, i.e., for each index set I ⊆ [d_0] with |I| ≤ ℓ, we have the input →x_I = ∑_{i∈I} →e_i ∈ R^{d_0}. So, our inputs are in bijection with binary strings that contain at most ℓ ones[8]. Our task is to compute all (d_0 choose 2) pairwise ANDs of these input bits, where the notion of ‘computing’ a property is that of making it linearly represented in the output activation vector →a(→x) ∈ R^d. That is, for each pair of inputs i, j, there should be a linear function r_{i,j}: R^d → R, or more concretely, a vector →r_{i,j} ∈ R^d, such that →r_{i,j}^T →a(→x) ≈_ϵ AND_{i,j}(→x). Here, ≈_ϵ indicates equality up to an additive error ϵ, and AND_{i,j}(→x) is 1 iff both bits i and j of →x are 1. We will drop the subscript ϵ going forward.
We will provide a construction that computes these Θ(d_0^2) features with a single d-neuron ReLU layer, i.e., a d×d_0 matrix W and a vector →b ∈ R^d such that →a(→x) = ReLU(W→x + →b), with d ≪ d_0. Stacking the readoff vectors →r_{i,j} we provide as the rows of a readout matrix R, you can also see us as providing a parameter setting solving →ANDs(→x) ≈_ϵ R·ReLU(W→x + →b), where →ANDs(→x) denotes the vector of all (d_0 choose 2) pairwise ANDs. But we’d like to stress that we don’t claim there is ever something like this large, size (d_0 choose 2), layer present in any practical neural net we are trying to model. Instead, these features would be read in by another future model component, like how the components we present below (in particular, our U-AND construction with inputs in superposition and our QK circuit) do.
There is another kind of notion of a set of features having been computed, perhaps one that’s more native to the superposition picture: that of the activation vector (approximately) being a linear combination of f-vectors corresponding to these properties, with coefficients that are functions of the values of the features. We can also consider a version of the U-AND problem that asks for output vectors which represent the set of all pairwise ANDs in this sense, maybe with the additional requirement that the f-vectors be almost orthogonal. Our U-AND construction solves this problem, too — it computes all pairwise ANDs in both senses. See the appendix for a discussion of how the linear readoff notion of stuff having been computed, the linear combination notion of something having been computed, and almost orthogonality hang together.
1.2 The U-AND construction
We now present a solution to the U-AND task, computing (d_0 choose 2) new features with an MLP width that can be much smaller than (d_0 choose 2). We will go on to show how our solution can be tweaked to compute ANDs of more than 2 features at a time, and to compute ANDs of features which are stored in superposition in the inputs.
To solve the base problem, we present a random construction: W (with shape d×d_0) has entries that are iid random variables which are 1 with probability p(d) ≪ 1, and each entry in the bias vector is −1. We will pin down what p should be later.
We will denote by S_i the set of neurons that are ‘connected’ to the ith input, in the sense that a neuron belongs to S_i iff the ith entry of the row of W corresponding to that neuron is 1. →S_i is used to denote the indicator vector of S_i: the vector which is 1 for every neuron in S_i and 0 otherwise. So →S_i is also the ith column of W.
Then we claim that for this choice of weight matrix, all the ANDs are approximately linearly represented in the MLP activation space with readoff vectors (and feature vectors, in the sense of Appendix B) given by
$$\vec v(x_i\wedge x_j)=\vec v_{ij}=\frac{\overrightarrow{S_i\cap S_j}}{|S_i\cap S_j|}$$
for all i, j, where the arrow over S_i ∩ S_j denotes the indicator vector of the intersection set, and |S_i ∩ S_j| is the size of the set.
We preface our explanation of why this works with a technical note. We are going to choose d and p (as functions of d_0) so that with high probability, all sets we talk about have size close to their expectation. To do this formally, one first shows that the probability of each individual set having size far from its expectation is smaller than any 1/poly(d_0) using the Chernoff bound (Theorem 4 here), and one follows this by a union bound over all only poly(d_0) sets to say that with probability 1−o(1), none of these events happen. For instance, if a set S_i ∩ S_j has expected size log^4(d_0), then the probability its size is outside of the range log^4(d_0) ± log^3(d_0) is at most 2e^{−μδ²/3} = 2e^{−log^2(d_0)/3} = 2d_0^{−log(d_0)/3} (following these notes, we let μ denote the expectation and δ denote the number of μ-sized deviations from the expectation — this bound works for δ < 1, which is the case here). Technically, before each construction to follow, we should list our parameters d, p and all the sets we care about (for this first construction, these are the double and triple intersections between the S_i) and then argue as described above that with high probability, they all have sizes that only deviate by a factor of 1+o(1) from their expected size, and we should always carry these error terms around in everything we say, but we will omit all this in the rest of the U-AND section.
So, ignoring this technicality, let’s argue that the construction above indeed solves the U-AND problem (with high probability). First, note that |S_i ∩ S_j| ∼ Bin(d, p²). We require that p is big enough to ensure that all intersection sets are non-empty with high probability, but subject to that constraint we probably want p to be as small as possible to minimise interference[9]. We’ll choose p = log^2(d_0)/√d, such that the intersection sets have size |S_i ∩ S_j| ≈ log^4(d_0). We split the check that the readoff works out into a few cases:
Firstly, if input features i, j, and at most ℓ−2 other input features are present (recall that we are working with ℓ-composite inputs), then letting →a denote the post-ReLU activation vector, we have →f_{AND_{ij}}·→a = 1 plus an error that is at most ℓ times [the sum of sizes of triple intersections involving i, j and each of the ℓ−2 other features which are on, divided by the size of S_i ∩ S_j]. This is very likely less than O(1/log^2(d_0)) for all polynomially many pairs and sets of ℓ−2 other inputs at once[10], at least assuming d = ω(log^8(d_0)). The expected value of this error is log^2(d_0)/√d.
Secondly, if only one of i, j is present, together with at most ℓ−1 other features, then, expanding the dot product →f_{AND_{ij}}·→a, we get nonzero terms precisely for neurons in a triple intersection of i, j, and one of the ℓ−1 other features, so the readoff is ≈ 0 — more precisely, O(1/log^2(d_0)) (again assuming d = ω(log^8(d_0))), and log^2(d_0)/√d in expectation.
Finally, if neither of i, j is present, then the error corresponds to quadruple intersections, so it is even more likely at most O(1/log^2(d_0)) (still assuming d = ω(log^8(d_0))), and log^4(d_0)/d in expectation.
So we see that this readoff is indeed the AND of i and j up to error ϵ = O(1/log^2(d_0)).
To finish, we note without much proof that everything is also computed in the sense that ‘the activation vector is a linear combination of almost orthogonal features’ (defined in Appendix B). The activation vector being an approximate linear combination of pairwise intersection indicator vectors with coefficients being given by the ANDs follows from triple intersections being small, as does the almost-orthogonality of these feature vectors.
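Before moving on, here is a small numerical sketch (ours) of the construction just described. At toy scale the asymptotic choice p = log^2(d_0)/√d would exceed 1, so we simply pick a small constant p by hand; the readoffs then behave as claimed, with interference that is larger than the asymptotic O(1/log^2(d_0)) but still small.

```python
import numpy as np

rng = np.random.default_rng(0)
d0, d = 1000, 2000
p = 0.07                       # toy-scale stand-in for p = log^2(d0)/sqrt(d)
W = (rng.random((d, d0)) < p).astype(float)   # neuron n hears input i iff W[n, i] = 1
b = -np.ones(d)

def mlp(x):
    return np.maximum(W @ x + b, 0.0)

def readoff(i, j):
    """Normalized indicator vector of S_i ∩ S_j."""
    overlap = W[:, i] * W[:, j]
    return overlap / overlap.sum()

x = np.zeros(d0)
x[[3, 17, 256]] = 1.0          # a 3-composite input: features 3, 17, 256 are on
a = mlp(x)
print(readoff(3, 17) @ a)      # both on    -> ≈ 1 (plus a triple-intersection error)
print(readoff(3, 99) @ a)      # one on     -> ≈ 0
print(readoff(5, 99) @ a)      # neither on -> ≈ 0, even smaller
```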
U-AND allows for arbitrary XORs to be efficiently calculated
A consequence of the precise (up to ϵ) nature of our universal AND is the existence of a universal XOR, in the sense of every XOR of features being computed. In this post by Sam Marks, it is tentatively observed that real-life transformers linearly compute the XOR of arbitrary features in the weak sense of being able to identify, using a linear probe, tokens where the XOR of two features is true (not necessarily with ϵ accuracy). This weak readoff behavior for AND would be unsurprising, as the residual stream already has this property (using the readoff vector →f_i + →f_j, which has maximal value if and only if f_i and f_j are both present). However, as Sam Marks observes, it is not possible to read off XOR in this weak way from the residual stream. We can however see that such a universal XOR (indeed, in the strong sense of ϵ-accuracy) can be constructed from our strong (i.e., ϵ-accurate) universal AND. To do so, assume that in addition to the residual stream containing feature vectors →f_i and →f_j, we’ve also already almost orthogonally computed universal AND features →f_{AND_{i,j}} into the residual stream. Then we can weakly (and in fact, ϵ-accurately) read off XOR from this space by taking the dot product with the vector →f_{XOR_{i,j}} := →f_i + →f_j − 2→f_{AND_{i,j}}. Then we see that if we had started with the two-hot pair →f_{i′} + →f_{j′}, the result of this readoff will be, up to a small error O(ϵ), the XOR of features i and j on that input.
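Concretely, the three cases work out as follows (a worked check using only the definitions above, with the O(ϵ) corrections from almost-orthogonality and from the approximate AND readoff suppressed):

$$\vec f^{\,\mathrm{XOR}}_{i,j}\cdot\bigl(\vec f_{i'}+\vec f_{j'}\bigr)
 =\bigl(\vec f_i+\vec f_j-2\vec f^{\,\mathrm{AND}}_{i,j}\bigr)\cdot\bigl(\vec f_{i'}+\vec f_{j'}\bigr)
 \approx
 \begin{cases}
 1+1-2\cdot 1=0, & \{i,j\}=\{i',j'\},\\
 1+0-2\cdot 0=1, & |\{i,j\}\cap\{i',j'\}|=1,\\
 0+0-2\cdot 0=0, & \{i,j\}\cap\{i',j'\}=\varnothing,
 \end{cases}$$

matching the XOR truth table in each case.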
This gives a theoretical feasibility proof of an efficiently computable universal XOR circuit, something Sam Marks believed to be impossible.
1.3 Handling inputs in superposition: sparse boolean computers
Any boolean circuit can be written as a sequence of layers executing pairwise ANDs and XORs[11] on the binary entries of a memory vector. Since our U-AND can be used to compute any pairwise ANDs or XORs of features, this suggests that we might be able to emulate any boolean circuit by applying something like U-AND repeatedly. However, since the outputs of U-AND store features in superposition, if we want to pass these outputs as inputs to a subsequent U-AND circuit, we need to work out the details of a U-AND construction that can take in features in superposition. In this section we explore the subtleties of modifying U-AND in this way. In so doing, we construct an example of a circuit which acts entirely in superposition from start to finish — nowhere in the construction are there as many dimensions as features! We consider this to be an interesting result in its own right.
U-AND’s ability to compute many boolean functions of input features stored in superposition provides an efficient way to use all the parameters of the neural net to compute (up to a small error) a boolean circuit with a memory vector that is wider than the layers of the NN[12]. We call this emulating a ‘boolean computer’. However, three limitations prevent arbitrary boolean circuits from being computed:
An injudicious choice of a layer executing XORs applied to a sparse input can fail to give a sparse output vector. Since U-AND only works on inputs with sparse features, this means that we can only emulate circuits with the property that on sparse inputs, their memory vector is sparse throughout the computation. We call these circuits ‘sparse boolean circuits’.
Even if the outputs of the circuit remain sparse at every layer, the ϵ errors involved in the boolean read-offs compound from layer to layer. We hope that it is possible to manage this interference (perhaps via subtle modifications to the constructions) enough to allow multiple steps of sequential computation, although we leave an exploration of error propagation to future work.
We can’t compute an unbounded number of new features with a finite-dimensional hidden layer. As we will see in this section, when input features are stored in superposition (which is true for outputs of U-AND and therefore certainly true for all but possibly the first layer of an emulated boolean circuit), we cannot compute more than ~Θ(d_0·d) (the number of parameters in the layer) many new boolean functions at a time.
Therefore, the boolean circuits we expect can be emulated in superposition (1) are sparse circuits (2) have few layers (3) have memory vectors which are not larger than the square of the activation space dimension.
Construction details for inputs in superposition
Now we generalize U-AND to the case where input features can be in superposition. With f-vectors →f_1, …, →f_m ∈ R^{d_0}, we give each feature a random set of neurons to map to, as before. After coming up with such an assignment, we set the ith row of W to be the sum of the f-vectors for features which map to the ith neuron. In other words, let F be the m×d_0 matrix with ith row given by the components of →f_i in the neuron basis:
$$F=\begin{pmatrix}\vec f_1^{\,T}\\ \vdots\\ \vec f_m^{\,T}\end{pmatrix}$$
Now let Ŵ be a sparse matrix (with shape d×m) with entries that are iid Bernoulli random variables which are 1 with probability p(d) ≪ 1. Then:
W = ŴF
Unfortunately, since the →f_1, …, →f_m are random vectors, their inner products will have a typical size of 1/√d_0. So, on an input which has no features connected to neuron i, the preactivation for that neuron will not be zero: it will be a sum of these interference terms, one for each feature that is connected to the neuron. Since the interference terms are uncorrelated and mean zero, they start to cause neurons to fire incorrectly when Θ(d_0) features are connected to each neuron. Since each feature is connected to each neuron with probability p = log^2(d_0)/√d, this means neurons start to misfire when m = ~Θ(d_0√d)[13]. At this point, the number of pairwise ANDs we have computed is (m choose 2) = ~Θ(d_0^2 d).
This is a problem, if we want to be able to do computation on input vectors storing potentially exponentially many features in superposition, or even if we want to be able to do any sequential boolean computation at all:
Consider an MLP with several layers, all of width d_MLP, and assume that each layer is doing a U-AND on the features of the previous layer. Then if the features start without superposition, there are initially d_MLP features. After the first U-AND, we have Θ(d_MLP^2) new features, which is already too many to do a second U-AND on these features!
Therefore, we will have to modify our goal when features are in superposition. That said, we’re not completely sure there isn’t any modification of the construction that bypasses such small polynomial bounds. But, e.g., one can’t just naively make Ŵ sparser — p can’t be taken below d^{−1/2} without intersection sets like S_i ∩ S_j becoming empty. When features were not stored in superposition, solving U-AND corresponded to computing d_0^2 many new features. Instead of trying to compute all pairwise ANDs of all (potentially exponentially many) input features in superposition, perhaps we should try to compute a reasonably sized subset of these ANDs. In the next section we do just that.
A construction which computes a subset of ANDs of inputs in superposition
Here, we give a way to compute ANDs of up to d_0·d particular feature pairs (rather than all (m choose 2) ANDs) that works even for m that is superpolynomial in d_0[14]. (We’ll be ignoring log factors in much of what follows.)
In U-AND, we take Ŵ to be a random matrix with iid 0/1 entries which are 1 with probability p = log^2(d_0)/√d. If we only need/want to compute a subset of all the pairwise ANDs — let E be the set of all pairs of inputs {i,j} for which we want to compute the AND of i and j — then whenever {i,j} ∈ E, we might want each pair of corresponding entries in columns i and j of the adjacency matrix Ŵ, i.e., each pair (Ŵ)_{ki}, (Ŵ)_{kj}, to be a bit more correlated than an analogous pair in columns i′ and j′ with {i′,j′} ∉ E. More precisely, we want to make such pairs of columns {i,j} have a surprisingly large intersection given the general density of the matrix — this is to make sure that we get some neurons which we can use to read off the AND of {i,j}, while choosing the general density of Ŵ to be low enough that we don’t cross the density threshold at which a neuron needs to care about too many input features.
One way to do this is to pick a uniformly random set of log^4(d_0) neurons for each {i,j} ∈ E, and to set the column of Ŵ corresponding to input i to be the indicator vector of the union of these sets (i.e., just those assigned to gates involving i). This way, we can compute up to around |E| = ~Θ(d_0·d) pairwise ANDs without having any neuron care about more than d_0 input features, which is the requirement from the previous section to prevent neurons misfiring when input f-vectors are random vectors in superposition with typical interference size Θ(1/√d_0).
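A minimal numpy sketch (ours, with made-up sizes) of this targeted construction: each chosen gate owns a small random set of neurons, each feature’s column of Ŵ is the indicator of the union of its gates’ sets, and the AND of a chosen pair is read off from that pair’s set — even though the m features live in superposition in only d_0 dimensions.

```python
import numpy as np

rng = np.random.default_rng(0)
d0, d, m = 256, 2000, 1000            # input dim, neurons, features (m > d0)
n_gates, k_per_gate = 300, 10         # which pairwise ANDs we compute; k_per_gate
                                      # is a toy stand-in for log^4(d0)

F = rng.normal(size=(m, d0)) / np.sqrt(d0)     # nearly orthogonal f-vectors

# gate g computes the AND of features (2g, 2g+1) and owns a random set of neurons
E = [(2 * g, 2 * g + 1) for g in range(n_gates)]
gate_neurons = [rng.choice(d, size=k_per_gate, replace=False) for _ in E]

# column i of W_hat is the indicator of the union of the sets for gates involving i
W_hat = np.zeros((d, m))
for (i, j), neurons in zip(E, gate_neurons):
    W_hat[neurons, i] = 1.0
    W_hat[neurons, j] = 1.0

W = W_hat @ F                          # weights act directly on the superposed input
b = -np.ones(d)
a = np.maximum(W @ (F[0] + F[1] + F[500]) + b, 0.0)   # features 0, 1, 500 are on

r01 = np.zeros(d); r01[gate_neurons[0]] = 1.0 / k_per_gate   # readoff for AND(0, 1)
r23 = np.zeros(d); r23[gate_neurons[1]] = 1.0 / k_per_gate   # readoff for AND(2, 3)
print(r01 @ a)   # ≈ 1: both inputs of gate (0, 1) are on
print(r23 @ a)   # ≈ 0: neither input of gate (2, 3) is on
```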
1.4 ANDs with many inputs: computation of small boolean circuits in a single layer
It is known that any boolean circuit with k inputs can be written as a linear combination (with possibly exponentially many terms in k, which is a substantial caveat) of ANDs with up to k inputs (fan-in up to k)[15]. This means that, if we can compute not just pairwise ANDs, but ANDs of all fan-ins up to k, then we can write down a ‘universal’ computation that computes (simultaneously, in a linearly-readable sense) all possible circuits that depend on some set of up to k inputs.
The U-AND construction for higher fan-in
We will modify the standard, non-superpositional U-AND construction to allow us to compute all ANDs of a specific fan-in k.
We’ll need two modifications:
We’re now interested in k-wise intersections between the S_i. These intersections are smaller than the double intersections, so we need to increase p to guarantee they are nonempty. A sensible choice for fan-in k is p = log^2(d_0)/d^{1/k}.
We only want neurons to fire when k of the features that connect to them are present at the same time, so we require the bias to be −k+1.
Now we read off the AND of a set I of input features along the normalized indicator vector of ⋂_{i∈I} S_i.
We can straightforwardly simultaneously compute all ANDs of fan-ins ranging from 2 to k by just evenly partitioning the d neurons into k−1 groups — let’s label these 2,3,…,k — and setting the weights into group i and the biases of group i as in the fan-in i U-AND construction.
A clever choice of density can give us all the fan-ins at once
Actually, we can calculate all ANDs of up to some constant fan-in k in a way that feels more symmetric than the option involving a partition above[16], by reusing the fan-in 2 U-AND with (let’s say) d = d_0 and a careful choice of p = 1/log^2(d_0). This choice of p is larger than log^2(d_0)/d^{1/k} for any k, ensuring that every intersection set is non-empty. Then, one can read off AND_{i,j} from S_i ∩ S_j as usual, but one can also read off AND_{i,j,k} with the composite vector
$$-\frac{\overrightarrow{S_i\cap S_j\cap S_k}}{|S_i\cap S_j\cap S_k|}+\frac{\overrightarrow{S_i\cap S_j}}{|S_i\cap S_j|}+\frac{\overrightarrow{S_i\cap S_k}}{|S_i\cap S_k|}+\frac{\overrightarrow{S_j\cap S_k}}{|S_j\cap S_k|}.$$ In general, one can read off the AND of an index set I with the vector $$\sum_{\substack{I'\subseteq I\\ |I'|\ge 2}} (-1)^{|I'|}\,\vec v_{I'}, \qquad \text{where } \vec v_{I'}=\frac{\overrightarrow{\bigcap_{i\in I'}S_i}}{\bigl|\bigcap_{i\in I'}S_i\bigr|}.$$ One can show that this inclusion-exclusion style formula works by noting that if the subset of indices of I which are on is J, then the readoff will be approximately $$\sum_{\substack{I'\subseteq I\\ |I'|\ge 2}} (-1)^{|I'|}\max(0,|I'\cap J|-1).$$ We’ll leave it as an exercise to show that this is 0 if J≠I and 1 if J=I.
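The claim at the end is easy to verify by brute force. The snippet below (ours) checks it for |I| = 3, 4, 5, using the approximation that reading off v_{I′} on an input whose on-set within I is J gives max(0, |I′ ∩ J| − 1):

```python
from itertools import combinations

def readoff(I, J):
    """Approximate readoff of the AND over index set I when the on-set is J ⊆ I."""
    total = 0
    for size in range(2, len(I) + 1):
        for Iprime in combinations(sorted(I), size):
            total += (-1) ** size * max(0, len(set(Iprime) & J) - 1)
    return total

for n in (3, 4, 5):
    I = set(range(n))
    for k in range(n + 1):
        for J in map(set, combinations(sorted(I), k)):
            assert readoff(I, J) == (1 if J == I else 0)
print("inclusion-exclusion readoff identity holds for |I| = 3, 4, 5")
```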
Extending the targeted superpositional AND to other fan-ins
It is also fairly straightforward to extend the construction for a subset of ANDs with inputs in superposition to other fan-ins, doing all fan-ins on a common set of neurons. Instead of picking a set for each pair that we need to AND as above, we now pick a set for each larger AND gate that we care about. As in the previous sparse U-AND, each input feature gets sent to the union of the sets for its gates, but this time we make the weights depend on the fan-in. Letting K denote the max fan-in over all gates, for a fan-in k gate we set the weight from each of its inputs to K/k, and set the bias to −K+1. This way, still with at most about ~Θ(d^2) gates, and at least assuming inputs have at most some constant number of features active, we can read the output of a gate off with the indicator vector of its set.
1.5 Improved Efficiency with a Quadratic Nonlinearity
It turns out that, if we use quadratic activation functions x ↦ x² instead of ReLUs x ↦ ReLU(x), we can write down a much more efficient universal AND construction. Indeed, the ReLU universal AND we constructed can compute the universal AND of up to ~Θ(d^{3/2}) features in a d-dimensional residual stream. However, in this section we will show that with a quadratic activation, for ℓ-composite vectors, we can compute all pairwise ANDs of up to m = Ω(exp(ϵ²√d/(2ℓ)))[17] features stored in superposition — this is exponential in √d, so superpolynomial in d(!) — with a single-layer universal AND circuit.
The idea of the construction is that, on the large space of features R^m, the AND of the boolean-valued feature variables f_i, f_j can be written as a quadratic function q_{i,j}: {0,1}^m → {0,1}; explicitly, q_{i,j}(f_1,…,f_m) = f_i·f_j. Now if we embed feature space R^m into a smaller R^r in an ϵ-almost-orthogonal way, it is possible to show that the quadratic function q_{i,j} on R^m is well-approximated on sparse vectors by a quadratic function on R^r (with error bounded above by 2ϵ on 2-sparse inputs in particular). The advantage of using quadratic functions is that any quadratic function on R^r can be expressed as a linear read-off of a special quadratic function Q: R^r → R^{r²}, given by the composition of a linear map R^r → R^{r²} and an element-wise quadratic activation on R^{r²}; this creates a set of neurons which collectively form a basis for all quadratic functions. Now we can set d = r² to be the dimension of the residual stream and work with an r-dimensional subspace V of the residual stream, taking the almost-orthogonal embedding R^m → V. Then the map Q: V → R^d provides the requisite universal AND construction. We make this recipe precise in the following section.
Construction Details
In this section we use slightly different notation to the rest of the post, dropping overarrows for vectors, and we drop the distinction between features and f-vectors.
Let V = R^r be as above. There is a finite-dimensional space of quadratic functions on R^r, with basis q_{ij} = x_i x_j of size r² (such that we can write every quadratic function as a linear combination of these basis functions); alternatively, we can write q_{ij}(v) = (v·e_i)(v·e_j), for e_i, e_j the basis vectors. We note that this space is spanned by a set of functions which are squares of linear functions of {x_i}: L^{(1)}_i = x_i, L^{(2)}_{i,j} = x_i + x_j, and L^{(3)}_{i,j} = x_i − x_j (for i < j).
The squares of these functions are a valid basis for the space of quadratic functions on R^r, since q_{ii} = (L^{(1)}_i)² and, for i ≠ j, q_{ij} = [(L^{(2)}_{i,j})² − (L^{(3)}_{i,j})²]/4. There are r distinct functions of type (1), and (r choose 2) functions each of types (2) and (3), for a total of r² basis functions as before. Thus there exists a single-layer quadratic-activation neural net Q: x ↦ y from R^r → R^{r²} such that any quadratic function on R^r is realizable as a “linear read-off”, i.e., given by composing Q with a linear function R^{r²} → R. In particular, we have linear “read-off” functions Λ_{ij}: R^{r²} → R such that Λ_{ij}(Q(x)) = q_{ij}(x).
Now suppose that f_1, …, f_m is a collection of f-vectors which are ϵ-almost-orthogonal, i.e., such that |f_i| = 1 for every i and |f_i·f_j| < ϵ for all i < j ≤ m. Note that (for fixed ϵ < 1), there exist such collections with a number of vectors m that is exponential in r. We can define a new collection of bilinear functions (i.e., functions of two vectors v, w ∈ R^r which are linear in each input independently), ϕ_{i,j}, for each pair of (not necessarily distinct) indices 0 < i ≤ j ≤ m, defined by ϕ_{i,j}(v,w) = (v·f_i)(w·f_j) (a product of two linear functions, hence quadratic). We will use the following result:
Proposition 1. Suppose ϕ_{i,j} is as above and 0 < i′ ≤ j′ ≤ m is another pair of (not necessarily distinct) indices, associated to feature vectors v_{i′}, v_{j′}. Then
$$\phi_{i,j}(v_{i'},v_{j'})\ \begin{cases} =1, & i=i' \text{ and } j=j' \\ \in(-\epsilon,\epsilon), & (i,j)\neq(i',j') \\ \in(-\epsilon^2,\epsilon^2), & \{i,j\}\cap\{i',j'\}=\varnothing \ \text{(i.e., no indices in common)} \end{cases}$$
This proposition follows immediately from the definition of ϕ_{i,j} and the almost orthogonality property. □
Now define the single-variable quadratic function ϕ^single_{i,j}(v) := ϕ_{i,j}(v,v), obtained by applying the bilinear form to two copies of the same vector. Then the proposition above implies that, for two pairs of distinct indices 0 < i < j ≤ m and 0 < i′ < j′ ≤ m, we have the following behavior on the sum of two features (the superpositional analog of a two-hot vector):
The first formula follows from bilinearity (which is equivalent to the statement that the two entries of ϕ_{i,j} behave distributively) and the last formula follows from the proposition, since we assumed (i,j) are distinct indices, hence they cannot match up with a pair of identical indices (i′,i′) or (j′,j′). Moreover, the O(ϵ) term in the formula above is bounded in absolute value by 2ϵ + ϵ².
Combining this formula with Proposition 1, we deduce:
Proposition 2
$$\phi^{\mathrm{single}}_{i,j}(v_{i'}+v_{j'})=\begin{cases} 1+O(\epsilon), & i=i' \text{ and } j=j' \\ O(\epsilon), & (i,j)\neq(i',j') \\ O(\epsilon^2), & \{i,j\}\cap\{i',j'\}=\varnothing. \end{cases}$$
Moreover, by the triangle inequality, the linear constants inherent in the O(...) notation are ≤2.□
Corollary. ϕ^single_{i,j}(v_{i′}+v_{j′}) = δ_{(i,j),(i′,j′)} + O(ϵ), where the δ notation returns 1 when the two pairs of indices are equal and 0 otherwise.
We can now write down the universal AND function by setting d = r² above. Assume we have m < exp(ϵ²r/2). This guarantees (with probability approaching 1) that m random vectors in V ≅ R^r are (ϵ-)almost orthogonal, i.e., have pairwise dot products < ϵ. We assume the vectors v_1, …, v_m are initially embedded in V ⊂ R^d. (Note that we can instead assume they were initially randomly embedded in R^d, then re-embedded in R^r by applying a random projection and rescaling appropriately.) Let Q: R^r → R^d, d = r², be the universal quadratic map as above; we let q_{ij}: R^r → R be the quadratic functions as above. Now we claim that Q is a universal AND with respect to the feature vectors v_1, …, v_m. Note that, since the function ϕ^single_{i,j}(v) is quadratic on R^r, it can be factorized as ϕ^single_{i,j}(x) = Φ_{i,j}(Q(x)), for Φ_{i,j} some linear function on R^{r²}[18]. We now see that the linear maps Φ_{i,j} are valid linear read-offs for ANDs of features: indeed,
Φ_{i,j}(Q(v_{i′}+v_{j′})) = ϕ^single_{i,j}(v_{i′}+v_{j′}) = δ_{(i,j),(i′,j′)} + O(ϵ) = AND_{i,j}(b_{i′,j′}) + O(ϵ), where b_{i′,j′} is the two-hot boolean indicator vector with 1s in positions i′ and j′. Thus the AND of any two indices i, j can be computed via the readout linear function Φ_{i,j} on any two-hot input b_{i′,j′}. Moreover, applying the same argument to a larger sparse sum gives Φ_{i,j}(Q(∑_{k=1}^m b_k v_k)) = AND(b_i, b_j) + O(s²ϵ), where s = ∑_{k=1}^m b_k is the sparsity[19].
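Here is a minimal numpy sketch (ours, with made-up sizes) of this construction. For simplicity we write the quadratic layer directly in the monomial basis, Q(x) = (x_a x_b)_{a,b}, which is a linear change of basis away from the squared-linear-functions implementation above, and we read ANDs off with Φ_{i,j} = f_i ⊗ f_j:

```python
import numpy as np

rng = np.random.default_rng(0)
r, m = 100, 20_000                  # d = r**2 = 10_000; far more features than r
F = rng.normal(size=(m, r))
F /= np.linalg.norm(F, axis=1, keepdims=True)   # unit, nearly orthogonal f-vectors

def Q(x):
    """Quadratic layer, written directly in the monomial basis x_a * x_b."""
    return np.outer(x, x).ravel()               # dimension d = r**2

def and_readoff(i, j):
    """Linear read-off with  and_readoff(i, j) @ Q(x) = (x·f_i)(x·f_j)."""
    return np.outer(F[i], F[j]).ravel()

x = F[3] + F[7] + F[15_000]          # three features on, stored in superposition
print(and_readoff(3, 7) @ Q(x))      # both on    -> ≈ 1
print(and_readoff(3, 999) @ Q(x))    # one on     -> O(ε)
print(and_readoff(12, 999) @ Q(x))   # neither on -> O(ε²)
```

With r = 100 the typical interference ϵ is about 1/√r = 0.1, so the three values come out near 1, of order 0.1, and of order 0.01 respectively.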
Scaling and comparison with ReLU activations
It is surprising that the universal AND circuit we wrote down for quadratic activations is so much more expressive than the one we have for ReLU activations, since the conventional wisdom for neural nets is that expressivity does not change significantly when one (suitably smooth) activation function is replaced by another, such as a quadratic one. We do not know if this is a genuine advantage of quadratic activations over others (which might indeed be implemented in transformers in some sophisticated way involving attention nonlinearities), or whether there is some yet-unknown reason that (perhaps assuming nice properties of our features) ReLUs can give more expressive universal AND circuits than we have been able to find in the present work. We list this discrepancy as an interesting open problem that follows from our work.
Generalizations
Note that the nonlinear function Q above lets us read off not only the AND of two sparse boolean vectors, but more generally the sum of products of coordinates of any sufficiently sparse linear combination of feature vectors v_i (not necessarily boolean). More generally, if we replace quadratic activations with cubic or higher, we can get cubic expressions, such as the sum of triple ANDs (or, more generally, products of triples of coordinates). A similar effect can be obtained by chaining l sequential levels of quadratic activations to get polynomial nonlinearities with exponent e = 2^l. Then so long as we can fit O(r^e)[20] features in the residual stream in an almost-orthogonal way (corresponding to a basis of monomials of degree e on r-dimensional space), we can compute sums of any degree-e monomial over features, and thus any boolean circuit of degree e, up to O(ϵ), where the linear constant implicit in the O depends on the exponent e. This implies that for any value e, there is a dimension-d universal nonlinear map R^d → R^d with ⌈log₂(e)⌉ quadratic activations such that any sparse boolean circuit involving ≤ e elements is linearly represented (via an appropriate readoff vector). Moreover, keeping e fixed, d grows only as O(log(n))^e. However, the constant associated with the big-O notation might grow quite quickly as the exponent e increases. It would be interesting to analyse this scaling behavior more carefully, but that is outside the scope of the present work.
1.6 Universal Keys: an application of parallel boolean computation
So far, we have used our universal boolean computation picture to show that superpositional computation in a fully-connected neural network can be more efficient than non-superpositional implementations: it can compute roughly as many logical gates as there are parameters, rather than being bounded by the number of neurons. This does not fully use the universality of our constructions: i.e., we must at every step read off a polynomial (at most quadratic) number of features from a vector which can (in either the fan-in-k or quadratic-activation contexts) compute a superpolynomial number of boolean circuits. At the same time, there is a context in transformers where precisely this universality can give a remarkable (specifically, superpolynomial in certain asymptotics) efficiency improvement. Namely, recall that the attention mechanism of a transformer can be understood as a way for the last-token residual stream to read information from past tokens which pass a certain test associated to the query-key component. In our simplified boolean model, we can conceptualize this as follows:
Each token possesses a collection of “key features” which indicate bits of information about contexts where reading information from this token is useful. These can include properties of grammar, logic, mood, or context (food, politics, cats, etc.)
The current token attends to past tokens whose key features have a certain combination of features, which we conceptualize as tokens on whose features a certain boolean “relevance” function g_{last token} returns 1. For example, the current token may ‘want’ to attend to all keys which have feature 1 and feature 4 but not feature 9, or exactly one of feature 2 and feature 8. This corresponds to the boolean function g = (f_1∧f_4∧¬f_9)∨(f_2⊕f_8). Importantly, the choice of g varies from token to token. We abstract away the question of generating this relevance function as some (possibly complicated) nonlinear computation implemented in previous layers.
Each past token generates a key vector in a certain vector space (associated with an attention head) which is some (possibly nonlinear) function of the key features; the last token then generates a query vector which functions as a linear read-off, and should return a high value on past tokens for which the relevance formula evaluates to True. Note that the key vector is generated before the query vector, and before the choice of which g to use is made.
Importantly, there is an information asymmetry between the “past” tokens (which contribute the key) and the last token that implements the linear read-off via query: in generating the boolean relevance function, the past token can use information that is not accessible to the token generating the key (as it is in its “future” – this is captured e.g. by the attention mask). One might previously have assumed that in generating a key vector, tokens need to “guess” which specific combinations of key features may be relevant to future tokens, and separately generate some read-off for each; this limits the possible expressivity of choosing the relevance function g to a small (e.g. linear in parameter number) number of possibilities.
However, our discovery of circuits that implement universal calculation suggests a surprising way to resolve this information asymmetry: namely, using a universal calculation, the key can simultaneously compute, in an approximately linearly-readable way, ALL possible simple circuits of up to O(log(d_resid)) inputs. This increases the number of possibilities for the relevance function g to all such simple circuits; this can be significantly larger than the number of parameters and asymptotically (for logarithmic fan-ins) will in fact be superpolynomial[21]. As far as we are aware, this presents a qualitative (from a complexity-theoretic point of view) update to the expressivity of the attention mechanism compared to what was known before.
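To make this concrete, here is a toy numpy sketch (ours) that reuses the quadratic map from Section 1.5 as the key: each past token publishes Q of its key features without knowing which test will be run, and the query then chooses, after the fact, a linear read-off implementing (say) the XOR of two particular key features.

```python
import numpy as np

rng = np.random.default_rng(0)
r, m = 200, 5_000                     # the key space has r**2 = 40_000 dims here
F = rng.normal(size=(m, r))
F /= np.linalg.norm(F, axis=1, keepdims=True)   # key-feature directions

def key_vector(active):
    """A past token publishes the full quadratic map of its key features;
    it does not need to know which boolean test a future query will run."""
    v = F[list(active)].sum(axis=0)
    return np.outer(v, v).ravel()

def query_xor(i, j):
    """The query picks g = XOR(f_i, f_j) only after the keys exist, as the
    linear read-off  f_i⊗f_i + f_j⊗f_j − (f_i⊗f_j + f_j⊗f_i)."""
    return (np.outer(F[i], F[i]) + np.outer(F[j], F[j])
            - np.outer(F[i], F[j]) - np.outer(F[j], F[i])).ravel()

q = query_xor(3, 7)
print(q @ key_vector({3, 7, 40}))     # both features on -> XOR = 0, score ≈ 0
print(q @ key_vector({3, 41, 40}))    # exactly one on   -> XOR = 1, score ≈ 1
print(q @ key_vector({42, 41, 40}))   # neither on       -> XOR = 0, score ≈ 0
```

The score works out to ((v·f_i) − (v·f_j))², which on sparse boolean keys is the XOR of the two features up to O(ϵ) interference.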
Sam Marks’ discovery of the universal XOR was done in this context: he observed using a probe that it is possible for the last token of a transformer to attend to past tokens that return True as the XOR of an arbitrary pair of features, something that he originally believed was computationally infeasible.
We speculate that this will be noticeable in real-life transformers, and can partially explain the observation that transformers tend to implement more superposition than fully-connected neural networks.
2 U-AND: discussion
We discuss some conceptual matters broadly having to do with whether the formal setup from the previous section captures questions of practical interest. Each of these subsections is standalone, and you needn’t read any to read Section 3.
Aren’t the ANDs already kinda linearly represented in the U-AND input?
This subsection refers to the basic U-AND construction from Section 1.1, with inputs not in superposition, but the objection we consider here could also be raised against other U-AND variants. The objection is this: aren’t ANDs already linearly present in the input, so in what sense have we computed them with the U-AND? Indeed, if we take the dot product of a particular 2-hot input with (→ei+→ej)/2, we get 0 if neither the ith nor the jth features are present, 1/2 if 1 of them is present, and 1 if they are both present. If we add a bias of −1/4, then without any nonlinearity at all, we get a way to read off pairwise U-AND for ϵ=1/4. The only thing the nonlinearity lets us do is to reduce this “interference” ϵ=1/4 to a smaller ϵ. Why is this important?
In fact, one can show that you can’t get more accurate than ϵ=1/4 without a nonlinearity, even with a bias, and ϵ=1/4 is not good enough for any interesting boolean circuit. Here’s an example to illustrate the point:
Suppose that I am interested in the variable z=∧(xi,xj)+∧(xk,xl). z takes on a value in {0,1,2} depending on whether both, one, or neither of the ANDs are on. The best linear approximation to z is 1/2(xi+xj+xk+xl−1), which has completely lost the structure of z. In this case, we have lost any information about which way the 4 variables were paired up in the ANDs.
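This is easy to see numerically; the snippet below (ours) least-squares fits an affine function to z over all 16 boolean inputs, recovering the 1/2(x_i+x_j+x_k+x_l−1) above, and shows that re-pairing the variables gives exactly the same fit, so the pairing is invisible to any linear readoff.

```python
import numpy as np
from itertools import product

X = np.array(list(product([0, 1], repeat=4)), dtype=float)   # all 16 boolean inputs
A = np.hstack([X, np.ones((16, 1))])                         # affine design matrix

z_paired   = X[:, 0] * X[:, 1] + X[:, 2] * X[:, 3]   # AND(x_i,x_j) + AND(x_k,x_l)
z_repaired = X[:, 0] * X[:, 2] + X[:, 1] * X[:, 3]   # AND(x_i,x_k) + AND(x_j,x_l)

coef_paired, *_   = np.linalg.lstsq(A, z_paired, rcond=None)
coef_repaired, *_ = np.linalg.lstsq(A, z_repaired, rcond=None)
print(coef_paired)     # [0.5 0.5 0.5 0.5 -0.5]: the best affine fit quoted above
print(coef_repaired)   # identical coefficients: the pairing information is gone
```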
In general, computing a boolean expression with k terms without the signal being drowned out by the noise will require ϵ<1/k if the noise is correlated, and ϵ<1/k2 if the noise is uncorrelated. In other words, noise reduction matters! The precision provided by ϵ-accuracy allows us to go from only recording ANDs to executing more general circuits in an efficient or universal way. Indeed, linear combinations of linear combinations just give more linear combinations – the noise reduction is the difference between being able to express any boolean function and being unable to express anything nonlinear at all. The XOR construction (given above) is another example that can be expressed as a linear combination involving the U-AND and would not work without the nonlinearity.
Aren’t the ANDs already kinda nonlinearly represented in the U-AND input?
This subsection refers to the basic U-AND construction from Section 1.1, with inputs not in superposition, but the objection we consider here could also be raised against other U-AND variants. While one cannot read off the ANDs linearly before the ReLU, except with a large error, one could certainly read them off with a more expressive model class on the activations. In particular, one can easily read ANDi,j off with a ReLU probe, by which we mean ReLU(rTx+b), with r=ei+ej and b=−1. We think there’s some truth to this: we agree that if something can be read off with such a probe, it’s indeed at least almost already there. And if we allowed multi-layer probes, the ANDs would be present already when we only have some pre-input variables (that our input variables are themselves nonlinear functions of). To explore a limit in ridiculousness: if we take stuff to be computed if it is recoverable by a probe that has the architecture of GPT-3 minus the embed and unembed and followed by a projection on the last activation vector of the last position residual stream, then anything that is linearly accessible in the last layer of GPT-3 is already ‘computed’ in the tuple of input embeddings. And to take a broader perspective: any variable ever computed by a deterministic neural net is in fact a function of the input, and is thus already ‘there in the input’ in an information-theoretic sense (anything computed by the neural net has zero conditional entropy given the input). The information about the values of the ANDs is sort of always there, but we should think of it as not having been computed initially, and as having been computed later[22].
Anyway, while taking something to be computed when it is affinely accessible seems natural when considering reading that information into future MLPs, we do not have an incredibly strong case that it’s the right notion. However, it seems likely to us that once one fixes some specific notion of stuff having been computed, then either exactly our U-AND construction or some minor variation on it would still compute a large number of new features (with more expressive readoffs, these would just be more complex properties — in our case, boolean functions of the inputs involving more gates). In fact, maybe instead of having a notion of stuff having been computed, we should have a notion of stuff having been computed for a particular model component, i.e. having been represented such that a particular kind of model component can access it to ‘use it as an input’. In the case of transformers, maybe the set of properties that have been computed as far as MLPs can tell is different than the set of properties that have been computed as far as attention heads (or maybe the QK circuit and OV circuit separately) can tell. So, we’re very sympathetic to considering alternative notions of stuff having been computed, but we doubt U-AND would become much less interesting given some alternative reasonable such notion.
If you think all this points to something like it being weird to have such a discrete notion of stuff having been computed vs not at all, and that we should maybe instead see models as ‘more continuously cleaning up representations’ rather than performing computation: while we don’t at present know of a good quantitative notion of ‘representation cleanliness’, so we can’t at present tell you that our U-AND makes amount x of representation cleanliness progress and x is sort of large compared to some default, it does seem intuitively plausible to us that it makes a good deal of such progress. A place where linear read-offs are clearly qualitatively important and better than nonlinear read-offs is in application to the attention mechanism of a transformer.
Does our U-AND construction really demonstrate MLP superposition?
This subsection refers to the basic U-AND construction from Section 1.1, with inputs not in superposition, but the objection we consider here could also be raised against other U-AND variants. One could try to tell a story that interprets our U-AND construction in terms of the neuron basis: we can also describe the U-AND as approximately computing a family of functions each of which record whether at least two features are present out of a particular subset of features[23]. Why should we see the construction as computing outputs into superposition, instead of seeing it as computing these different outputs on the neurons? Perhaps the ‘natural’ units for understanding the NN is in terms of these functions, as unintuitive as they may seem to a human.
In fact, there is a sense in which describing the sampled construction in the most natural way available in the superposition picture requires more bits than describing it in the most natural way available in this neuron picture. In the neuron picture, one needs to specify a subset of inputs of size ~Θ(d0/√d) for each neuron, which takes d log2(d0 choose ~Θ(d0/√d)) ≤ ~Θ(d20/√d) bits to specify. In the superpositional picture, one needs to specify (d02) subsets of size ~Θ(1), which takes about ~Θ(d20) bits to specify[24]. If, let’s say, d=d0, then from the point of view of saving bits when representing such constructions, we might even prefer to see them in a non-superpositional manner!
We can imagine cases (of something that looks like this U-AND showing up in a model) in which we’d agree with this counterargument. For any fixed U-AND construction, we could imagine a setup where, for each neuron, the inputs feeding into it form some natural family — slightly more precisely, where whether two elements of this family are present is a very natural property to track. We could even imagine a case where future computation is best seen as being about these properties computed by the neurons — for instance, the output of the neural net might just be the sum of the activations of these neurons, perhaps because having two elements of one of these families present is necessary and sufficient for an image to be that of a dog. In such a case, we agree it would be silly to think of the output as a linear combination of pairwise AND features.
However, we think there are plausible contexts in which such a circuit would show up in which it seems intuitively right to see the output as a sparse sum of pairwise ANDs: when the families tracked by particular neurons do not seem at all natural and/or when it is reasonable to see future model components as taking these pairwise AND features as inputs. Conditional on thinking that superposition is generic, it seems fairly reasonable to think that these latter contexts would be generic.
Is universal calculation generic?
The construction of the universal AND circuit in the “quadratic nonlinearity” section above can be shown to be stable to perturbations; a large family of suitably “random” circuits in this paradigm contain all AND computations in a linearly-readable way. This updates us towards suspecting that at least some of our universal calculation picture might be generic: i.e., that a random neural net, or a random net satisfying some mild set of conditions (that we can’t yet make precise), is sufficiently expressive to (weakly) compute any small circuit. Linear probe experiments such as Sam Marks’ identification of the “universal XOR” in a transformer may then be explainable as a consequence of sufficiently complex, “random-looking” networks. This means that the correct framing for what happens in a neural net executing superposition might not be that the MLP learns to encode universal calculation (such as the U-AND circuit), but rather that such circuits exist by default, and what the neural network needs to learn is a readoff vector for the circuit that needs to be executed. While this would change much of the story (in particular, the question of “memorization” vs. “generalization” of a subset of such boolean circuit features would be moot if general computation generically exists), it would not change the core fact that such universal calculation is possible, and therefore likely to be learned by a network executing (or partially executing) superposition. In fact, such an update would make it more likely that such circuits can be utilized by the computational scheme, and would make it even more likely that such a scheme would be learned by default.
We hope to do a series of experiments to check whether this is the case: whether a random network in a particular class executes universal computation by default. If we find this is the case, we plan to train a network to learn an appropriate read-off vector starting from a suitably random MLP circuit, and, separately, to check whether existing neural networks take advantage of such structure (i.e., have features – e.g. found by dictionary learning methods – which linearly read off the results of such circuits). We think this would be particularly productive in the attention mechanism (in the context of “universal key” generation, as explained above).
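As a concrete starting point, here is a minimal numpy sketch (with hypothetical toy sizes) of the first of these experiments: freeze a random 0/1 weight matrix, learn nothing except a linear readoff for a single AND, and measure the held-out error. All parameter choices below are arbitrary stand-ins rather than part of any construction from this post.

```python
import numpy as np

rng = np.random.default_rng(0)
d0, d, n = 100, 500, 4000            # hypothetical toy sizes: inputs, neurons, samples

# A frozen *random* layer: nothing below is trained except the final linear readoff.
W = (rng.random((d, d0)) < 0.2).astype(float)
b = -1.0
acts = lambda X: np.maximum(X @ W.T + b, 0.0)

# Sparse boolean inputs; features 0 and 1 are made dense so that AND(f0, f1) is balanced.
X = (rng.random((n, d0)) < 3 / d0).astype(float)
X[:, :2] = (rng.random((n, 2)) < 0.5).astype(float)
y = X[:, 0] * X[:, 1]

A = acts(X)
train, test = slice(0, n // 2), slice(n // 2, n)
readoff, *_ = np.linalg.lstsq(A[train], y[train], rcond=None)    # learn only the readoff vector
pred = A[test] @ readoff
print("mean |error| of the learned readoff on held-out inputs:", np.abs(pred - y[test]).mean())
```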
What are the implications of using ϵ-accuracy? How does this compare to behavior found by minimizing some loss function?
A specific question here is:
Are algorithms that are ϵ-accurate at U-AND the same as algorithms which minimize the MSE or some other loss function we might write down for training a neural net on the task?
The answer is that sometimes they are not going to be the same. In particular, our algorithm may not be given a low loss by MSE. Nevertheless, we think that ϵ-accuracy is a better thing to study for understanding superposition than MSE or other commonly considered loss functions (cross entropy would be much less wise than either!). This point is worth addressing properly, because it has implications for how we think about superposition and how we interpret results from the toy models of superposition paper and from sparse autoencoders, both of which typically use MSE.
For our U-AND task, we ask for a construction →f(→x) that approximately equals a 1-hot target vector →y, with each coordinate allowed to differ from its target value by at most ϵ. A loss function which would correspond to this task would look like a cube well with vertical sides (the inside of the region L∞(→f(→x),→y)<ϵ). This non-differentiable loss function would be useless for training. Let’s compare this choice to alternatives and defend it.
If we know that our target is always a 1-hot vector, then maybe we should have a softmax at the end of the network and use cross-entropy loss. We purposefully avoid this, because we are trying to construct a toy model of the computation that happens in intermediate layers of a deep neural network, taking one activation vector to a subsequent activation vector. In the process there is typically no softmax involved. Also, we want to be able to handle datapoints in which more than 1 AND is present at a time: the task is not to choose which AND is present, but *which of the ANDs* are present.
The other ubiquitous choice of loss function is MSE. This is the loss function used to evaluate model performance in two tasks that are similar to U-AND: the toy model of superposition and SAEs. Two reasons why this loss function might be principled are
If there is reason to think of the model as a Gaussian probability model
If we would like our loss function to be basis independent.
We see no reason to assume the former here, and while the latter is a nice property to have, we shouldn’t expect basis independence here: we would like the ANDs to be computed in a particular basis and are happy with a loss function that privileges that basis.
Our issue with MSE (and Lp in general for finite p) can be demonstrated with the following example:
Suppose the target is y=(1,0,0,…). Let ^y=(0,0,…) and ~y=(1+ϵ,ϵ,ϵ,…), where all vectors are (d02)-dimensional. Then ||y−^y||p=1 and ||y−~y||p=(d02)1/pϵ. For large enough (d02)>ϵ−p, the latter loss is larger than 1[25]. Yet intuitively, the latter model output is likely to be a much better approximation to the target value, from the perspective of the way the activation vector will be used for subsequent computation. Intuitively, we expect that for the activation vector to be good enough to trigger the right subsequent computation, it needs to be unambiguous whether a particular AND is present, and the noise in the value needs to be below a certain critical scale that depends on the way the AND is used subsequently, to avoid noise drowning out signal. To understand this properly we’d like a better model of error propagation.
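To make this concrete, here is a quick numerical check (with arbitrary toy numbers standing in for (d02) and ϵ): the ‘intuitively better’ output ~y has the larger mean squared error, even though it is the one that is ϵ-accurate.

```python
import numpy as np

n, eps = 10_000, 0.05          # n stands in for (d0 choose 2); both numbers are arbitrary
y      = np.zeros(n); y[0] = 1.0                 # target: exactly one AND is on
y_hat  = np.zeros(n)                             # output that misses the AND entirely
y_tild = np.full(n, eps); y_tild[0] = 1.0 + eps  # eps-accurate output with noise everywhere

mse  = lambda a, b: np.mean((a - b) ** 2)
linf = lambda a, b: np.max(np.abs(a - b))

print("MSE  :", mse(y, y_hat),  "vs", mse(y, y_tild))   # the eps-accurate output has the larger MSE
print("L_inf:", linf(y, y_hat), "vs", linf(y, y_tild))  # but only y_tild is within eps everywhere
```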
It is no coincidence that our U-AND algorithm may be ϵ-accurate for small ϵ, but is not a minimum of the MSE. In general, ϵ-accuracy permits much more superposition than minimising the MSE, because it penalises interference less.
For a demonstration of this, consider a simplified toy model of superposition with hidden dimension d and inputs which are all 1-hot unit vectors. We consider taking the limit as the number of input features goes to infinity and ask: what is the optimum number N(d) of inputs that the model should store in superposition, before sending the rest to the zero vector?
If we look for ϵ-accurate reconstruction, then we know how to answer this: a random construction allows us to fit at least Nϵ(d)=C exp(ϵ2d) vectors into d-dimensional space.
As for the algorithm that minimises the MSE reconstruction loss (i.e. on the inputs not sent to the zero vector in the hidden space), consider that we have already put n of the inputs into superposition, and we are trying to decide whether it is a good idea to squeeze another one in there. Separating the loss function into reconstruction terms and interference terms (as in the original paper):
The n+1th input being stored subtracts a term of order 1 from the reconstruction loss
Storing this input will also lead to an increase in the interference loss. As for how much, let us write δ(n)2 for the mean squared dot product between the n+1th feature vector and one of the n feature vectors that were already there. Since the n+1th feature has n distinct features to interfere with, storing it will contribute a term of order nδ(n)2 to the interference loss.
So, the optimum number of features to store can be found by asking when the contribution to the loss ℓ(n+1)∼nδ(n)2−1 switches from negative to positive, so we need an estimate of δ(n). If feature vectors are chosen randomly, then δ(n)2=O(1/d) and we find that the optimal number of features to store is O(d). In fact, feature vectors are chosen to minimise interference, which allows us to fit a few more feature vectors in (the advantage this gives us is most significant at small n) before the accumulating interferences become too large, and empirically we observe that the optimal number of features to store is NL2(d)=O(dlogd). This is much, much less superposition than we are allowed with ϵ-accurate reconstruction!
See the figure below for experimental values of NLp(d) for a range of p,d. We conjecture that for each p,NLp(d) is the minimum of an exponential function which is independent of p and something like a polynomial which depends on p.
3 The QK part of an attention head can check for many skip feature-bigrams, in superposition
In this section, we present a story for the QK part of an attention head which is analogous to the MLP story from the previous section. Note that although both focus on the QK component, this is a different (though related) story to the story about universal keys from section 1.4.
We begin by specifying a simple task that we think might capture a large fraction of the role performed by the QK part of an attention head. Roughly, the task (analogous to the U-AND task for the MLP) is to check which of a large set of ‘skip bigrams’[26] of features[27] are present.
We’ll then provide a construction of the QK part of an attention head that can perform this task in a superposed manner — i.e., a specification of a low-rank matrix WQK=WTKWQ that checks for a given set of skip feature-bigrams. A naive construction could only check for dhead feature bigrams; ours can check for ~Θ(dheaddresid) feature bigrams. This construction is analogous to our construction solving the targeted superpositional AND from the previous sections.
3.1 The skip feature-bigram checking task
Let B be a set of ‘skip feature-bigrams’; each element of B is a pair of features (→fi,→fj)∈Rdresid×Rdresid. Let’s define what we mean by a skip feature-bigram being present in a pair of residual stream positions. Looking at residual stream activation vectors just before a particular attention head (after layernorm is applied), we say that the activation vectors →as,→at∈Rdresid at positions s,t contain the skip feature-bigram (→fi,→fj) if feature →fi is present in →at and feature →fj is present in →as. There are two things we could mean by the feature →fi being present in an activation vector →a. The first is that →fi⋅→a′ is always either ≈0 or ≈1 for any a′ in some relevant data set of activation vectors, and →fi⋅→a=1. The second notion assumes the existence of some background set →f1,→f2,…,→fm in terms of which each activation vector a has a given background decomposition, a=∑mi=1ci→fi. In fact, we assume that all ci∈{0,1}, with at most some constant number of ci=1 for any one activation vector, and we also assume that the →fi are random vectors (we need them to be almost orthogonal). The second notion guarantees the first but with better control on the errors, so we’ll run with the second notion for this section[28].
Plausible candidates for skip feature-bigrams (→fi,→fj) to check for come from cases where if the query residual stream vector has feature →fj, then it is helpful to do something with the information at positions where →fi is present. Here are some examples of checks this can capture:
If the query is a first name, then the key should be a surname.
If the query is a preposition associated with an indirect object, then the key should be a noun/name (useful for IOI).
If the query is token T, then the key should also be token T (useful for induction heads, if we can do this for all possible tokens).
If the query is ‘Jorge Luis Borges’, then the key should be ‘Tlön, Uqbar, Orbis Tertius’.
If the mood of the paragraph before the query is solemn, then the topic of the paragraph before the key should be statistical mechanics.
If the query is the end of a true sentence, then the key should be the end of a false sentence.
If the query is a type of pet, then the key should be a type of furniture.
The task is to use the attention score S (the attention pattern pre-softmax) to count how many of these conditions are satisfied by each choice of query token position and key token position. That is, we’d like to construct a low-rank bilinear form WTKWQ such that the (s,t) entry of the attention score matrix Sst=→aTsWTKWQ→at contains the number of conditions in B which are satisfied for the query residual stream vector in token position s and the key residual stream vector in the token position t. We’ll henceforth refer to the expression WTKWQ as WQK, a matrix of size dresid×dresid that we choose freely to solve the task subject to the constraint that its rank is at most dhead<dresid. If each property is present sparsely, then most conditions are not satisfied for most positions in the attention score most of the time.
We will present a family of algorithms which allow us to perform this task for various set sizes |B|. We will start with a simple case without superposition analogous to the ‘standard’ method for computing ANDs without superposition. Unlike for U-AND though, the algorithm for performing this task in superposition is a generalization of the non-superpositional case. In fact, given our presentation of the non-superpositional case, this generalization is fairly immediate, with the main additional difficulty being to keep track of errors from approximate calculations.
3.2 A superposition-free algorithm
Let’s make the assumption that m is at most dresid. For the simplest possible algorithm, let’s make the further (definitely invalid) assumption that the feature basis is the neuron basis. This means that →as is a vector in {0,1}dresid. In the absence of superposition, we do not require that these features are sparse in the dataset.
To start, consider the case where B contains only one feature bigram (→ei,→ej). The task becomes: ensure that Sst=→aTsWQK→at is 1 if feature →fi is present in →as and feature →fj is present in →at, and 0 otherwise. The solution to this task is to choose WQK to be a matrix with zero everywhere except in the i,j component: (WQK)kl=δkiδlj — with this matrix, →aTsWQK→at=1 iff the i entry of →as is 1 and the j entry of →at is 1. Note that we can write WQK=→k⊗→q where →k=→ei, →q=→ej, and ⊗ denotes the outer product/tensor product/Kronecker product. This expression makes it manifest that WQK is rank 1. Whenever we can decompose a matrix into a tensor product of two vectors (this will prove useful), we will call it a _pure tensor_ in accordance with the literature. Note that this decomposition allows us to think of WQK in terms of the query part and key part separately: first we project the residual stream vector in the query position onto the ith feature vector which tells us if feature i is present at the query position, then we do the same for the key, and then we multiply the results.
In the next simplest case, we take the set B to consist of several pairs (→ei,→ej). To solve the task for this B, we can simply sum the corresponding pure tensors, one for each bigram in B, since there is no interference. That is, we choose
WQK=∑(i,j)∈B→ei⊗→ej
The only new subtlety that is introduced in this modification comes from the requirement that the rank of WQK be at most dhead, which won’t be true in general. The rank of WQK is not trivial to calculate for a given B, because we can factorize terms in the sum: for subsets K,Q of features with K×Q⊆B,
∑(i,j)∈K×Q→ei⊗→ej=(∑i∈K→ei)⊗(∑j∈Q→ej),
which is a pure tensor. The rank requirement is equivalent to the statement that WQK can contain at most dhead terms _after maximum factorisation_ (a priori, not necessarily in terms of such pure tensors of sums of subsets of basis vectors). Visualizing the set B as a bipartite graph with m nodes on the left and right, we notice that pure tensors correspond to subgraphs of B which are _complete_ bipartite subgraphs (cliques). A sufficient condition for the rank of WQK being at most dhead is that the edges of B can be partitioned into at most dhead cliques. Thus, whether we can check for all feature bigrams in B this way depends not only on the size of B, but also on its structure. In general, we can’t use this construction to guarantee that we can check for more than dhead skip feature-bigrams.
Generalizing our algorithm to deal with the case when the feature basis is not neuron-aligned (although it is still an orthogonal basis) could not be simpler: we replace {→ei} with the new feature basis, use the same expression for WQK, and we are done.
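For concreteness, here is a minimal numpy sketch of this superposition-free construction, with hypothetical toy sizes and a randomly chosen clique structure (we use the standard basis here; any orthonormal feature basis behaves identically).

```python
import numpy as np

rng = np.random.default_rng(0)
d_resid, d_head = 64, 8
E = np.eye(d_resid)

# Hypothetical clique structure: B is partitioned into d_head complete bipartite pieces K_c x Q_c.
cliques = [(rng.choice(d_resid, 3, replace=False), rng.choice(d_resid, 3, replace=False))
           for _ in range(d_head)]

# W_QK = sum over cliques of (sum of key basis vectors) outer (sum of query basis vectors).
W_QK = sum(np.outer(E[K].sum(0), E[Q].sum(0)) for K, Q in cliques)
print("bigrams checked (with multiplicity):", sum(len(K) * len(Q) for K, Q in cliques),
      "| rank of W_QK:", np.linalg.matrix_rank(W_QK))          # more bigrams than the rank d_head

K0, Q0 = cliques[0]
used = set(np.concatenate([np.concatenate([K, Q]) for K, Q in cliques]))
i_free, j_free = sorted(set(range(d_resid)) - used)[:2]

# The score a_s^T W_QK a_t counts how many checked bigrams the pair contains.
print("pair containing a bigram in B:", E[K0[0]] @ W_QK @ E[Q0[0]])    # >= 1
print("pair containing no bigram    :", E[i_free] @ W_QK @ E[j_free])  # 0
```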
3.3 Checking for a structured set of skip feature-bigrams with activation superposition
We now consider the case where the residual stream contains m>dresid sparsely activated features stored in superposition. We’ll assume that the feature vectors are random unit vectors, and we’ll switch notation from e1,…,edresid to f1,…,fm from now on to emphasize that the f-vectors are not an orthogonal basis. We’d like to generalize the superposition-free algorithm to the case when the residual stream vector stores features in superposition, but to do so, we’ll have to keep track of the interference between non-orthogonal f-vectors. We know that the root mean square dot product between two f-vectors is 1/√dresid. Every time we check for a bigram that isn’t present and pick up an interference term, the noise accumulates—for the signal to beat the noise here, we need the sum of interference terms to be less than 1. We’ll ignore log factors in the rest of this section.
We’ll assume that most of the interference comes from checking for bigrams (→fi,→fj) where →fi isn’t in →as and also →fj isn’t in →at — that cases where one feature is present but not the other are rare enough to contribute less can be checked later. These pure tensors typically contribute an interference of 1/dresid. We can also consider the interference that comes from checking for a clique of bigrams: let K and Q be sets of features such that B=K×Q. Then, we can check for the entire clique using the pure tensor (∑j∈K→fj)⊗(∑i∈Q→fi). Checking for this clique of feature bigrams on key-query pairs which don’t contain any bigram in the clique contributes an interference term of √|K||Q|/dresid, assuming interferences are uncorrelated. Now we require that the sum over interferences for checking all cliques of bigrams (of which there are at most dhead) is less than one. Since there are at most dhead cliques, assuming each clique is the same size (slightly more generally, one can also make the cliques differently-sized as long as the total number of edges in their union is at most dresid) and assuming the noise is independent between cliques, we require √|K||Q|/dresid<1/√dhead. Further assuming |K|=|Q|, this gives |K|=|Q| at most dresid/√dhead. In this way, over all dhead cliques, we can check for up to d2resid bigrams, which can collectively involve up to dresid√dhead distinct features, in each attention head.
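Here is a minimal numerical sketch of this estimate, with toy sizes, with each key/query activation taken to be a single f-vector for simplicity, and with the clique sides kept somewhat below dresid/√dhead so that the signal clearly beats the interference.

```python
import numpy as np

rng = np.random.default_rng(1)
d_resid, d_head, m = 512, 16, 2000        # toy sizes; m features stored in superposition
kq = 16                                   # clique side length, kept below d_resid / sqrt(d_head)

F = rng.normal(size=(m, d_resid))
F /= np.linalg.norm(F, axis=1, keepdims=True)         # random unit f-vectors

# d_head cliques K_c x Q_c built from the first 1000 features; the remaining features go unchecked.
cliques = [(rng.choice(1000, kq, replace=False), rng.choice(1000, kq, replace=False))
           for _ in range(d_head)]
W_QK = sum(np.outer(F[K].sum(0), F[Q].sum(0)) for K, Q in cliques)    # rank <= d_head

K0, Q0 = cliques[0]
signal = [F[i] @ W_QK @ F[j] for i in K0[:8] for j in Q0[:8]]                       # checked bigrams
noise  = [F[i] @ W_QK @ F[j] for i in range(1000, 1008) for j in range(1200, 1208)] # unchecked pairs
print(f"checked bigram present: {np.mean(signal):.2f} +/- {np.std(signal):.2f}")    # ~1
print(f"no checked bigram     : {np.mean(noise):.2f} +/- {np.std(noise):.2f}")      # ~0 (interference)
```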
Note also that one can involve up to dheaddresid features if one chooses |K|=1 and |Q|=dresid for each clique.
Many features that people think of are boolean in nature, and reverse engineering the circuits that are involved in constructing them consists of understanding how simpler boolean features are combined to make them. For example, in a vision model, the feature which is 1 if there is a car in the image may be computed by combining the ‘wheels at the bottom of the image’ feature AND the ‘windows at the top’ feature [3].
We will be focusing on the part of the network before the readoff with the matrix R. In an analogous way to the toy model of superposition, we consider the first two layers to represent the computation we are studying, with the readoff matrix R standing in for how later parts of the network would use the computed features. If we can do this task with an MLP with fewer than (m2) neurons, then in a sense we have computed more boolean functions than we have neurons, and the values of these functions will be stored in superposition in the MLP activation space.
Any boolean function can be written as a linear combination of ANDs with different numbers of inputs. For example, XOR(A,B,C)=A+B+C−2A∧B−2A∧C−2B∧C+4A∧B∧C.
Therefore, if we can compute and linearly represent all the ANDs in superposition, then we can do so for any boolean function.
If m=d0 (the dimension of the input space), then we can store the input features using an orthonormal basis such as the neuron basis. A naive solution in this case would be to have one neuron per pair which is active if both inputs are true and 0 otherwise. This requires (m2)=Θ(d20) neurons, and involves no superposition:
On this input x1,x2 and x5 are true, and all other inputs are false.
We can do much better than this, computing all the pairwise ANDs up to a small error with many fewer neurons. To achieve this, we have each neuron care about a random subset of inputs, and we choose the bias such that each neuron is activated when at least two of them are on. This requires d=Θ(polylog(d0)) neurons:
Importantly:
A modified version works even when the input features are in superposition. In this case we cannot compute all ANDs of potentially exponentially many features. Instead, we must pick up to ~Θ(d2) logical gates to calculate at each stage.
A solution to U-AND can be generalized to compute many ANDs of more than two inputs, and therefore to compute arbitrary boolean expressions involving a small number of input variables, with surprisingly efficient asymptotic performance (superpolynomially many functions computed at once). This can be done simply by increasing the density of connections between inputs and neurons, which comes at the cost of interference terms going to zero more slowly.
It may be possible to stack multiple of these constructions in a row and therefore to emulate a large boolean circuit, in which each layer computes boolean functions on the outputs of the previous layer. However, if the interference is not carefully managed, the errors are likely to propagate and eventually become unmanageable. The details of how the errors propagate and how to mitigate this are beyond the scope of this work.
We study the performance of our constructions asymptotically in d, and expect that insofar as real models implement something like them, they will likely be importantly different in order to have low error at finite d.
If the ReLU is replaced by a quadratic activation function, we can provide a construction that is much more efficient in terms of computations per neuron. We suspect that this implies the existence of similarly efficient constructions with ReLU, and constructions that may perform better at finite d.
Our analysis of the QK part of an attention head centers on the task of skip feature-bigram checking:
Given residual stream vectors →a1,…,→aT (for sequence length T) storing boolean features in superposition.
Given a set B of skip feature-bigrams (SFBs) which specify which keys to attend to from each query in terms of features present in the query and key. A skip feature-bigram is a pair of features such as (→f6,→f13), and we say that an SFB is present in a query key pair if the first feature is present in the key and the second in the query.
We want to compute an attention score which contains, in each entry, the number of SFBs in B present in the query and key that correspond to that entry. To do so, we look for a suitable choice of the parameters in the weight matrix WQK, a dresid×dresid matrix of rank dhead. Some small bounded error is tolerated.
This framing is valuable for understanding the role played by the attention mechanism in superposed computation because:
It is a natural modification of the ‘attention head as information movement’ story that engages with the many independent features stored in residual stream vectors in parallel, rather than treating the vectors as atomic units. Each SFB can be thought of as implementing an operation corresponding to statements like ‘if feature →f13 is present in the query, then attend to keys for which feature →f6 is present’.
The stories normally given for the role played by a QK circuit can be reproduced as particular choices of B. For example, consider the set of ‘identity’ skip feature-bigrams: BId={‘if feature →fi is present in the query, then attend to keys for which feature →fi is also present’|∀i}. Checking for the presence of all SFBs in BId corresponds to attending to keys which are the same as the query.
There are also many sets B which are most naturally thought of in terms of performing each check in B individually.
A nice way to construct WQK is as a sum of terms for each skip feature-bigram, each of which is a rank one matrix equal to the outer product of the two feature vectors in the SFB. In the case that all feature vectors are orthogonal (no superposition) you should be thinking of something like this: WQK=∑i→fki⊗→fqi,
where each of the rank one matrices, when multiplied by a residual stream vector on the right and left, performs a dot product on each side:
→aTsWQK→at=∑i(→as⋅→fki)(→fqi⋅→at)
where (fk1,fq1),…,(fk|B|,fq|B|) are the feature bigrams in B with feature directions (→fki,→fqi), and →as is a residual stream vector at sequence position s. Each of these rank one matrices contributes a value of 1 to the value of →aTsWQK→at if and only if the corresponding SFB is present. Since the matrix cannot be higher rank than dhead, typically we can only check for up to ~Θ(dhead) SFBs this way.
In fact we can check for many more SFBs than this, if we tolerate some small error. The construction is straightforward once we think of WQK as this sum of tensor products: we simply add more rank one matrices to the sum, and then approximate the sum as a rank dhead matrix, using the SVD or even a random projection matrix P. This construction can be easily generalised to the case that the residual stream stores features in superposition (provided we take care to manage the size of the interference terms), in which case WQK can be thought of as being constructed like this: WQK=∑i→fki⊗(P→fqi).
When multiplied by a residual stream vector on the right and left, this expression is →aTsWQK→at=∑i(→as⋅→fki)(P→fqi⋅→at)
Importantly:
It turns out that the interference becomes larger than the signal when roughly one SFB has been checked for per parameter: |B|=~Θ(dresiddhead)
When there is structure to the set of SFBs that are being checked for, we can exploit this to check for even more SFBs with a single attention head.
If there is a particular linear structure to the geometric arrangement of feature vectors in the residual stream, many more SFBs can be checked for at once, but this time the story of how this happens isn’t the simplest to describe in terms of a list of SFBs. This suggests that our current description of what the QK circuit does is lacking. In fact, this example exemplifies computation performed by neural nets that we don’t think is best described by our current sparse boolean picture. It may be a good starting point for building a broader theory than we have so far that takes into account other structures.
Indeed, there are many open directions for improving our understanding of computation in superposition, and we’d be excited for others to do future research (theoretical and empirical) in this area.
Some theoretical directions include:
Fitting the OV circuit into the boolean computation picture
Studying error propagation when U-AND is applied sequentially
Finding constructions with better interference at finite d
Making the story of boolean computation in transformers more complete by studying things that have not been captured by our current tasks
Generalisations to continuous variables
Empirical directions include:
Training toy models to understand if NNs can learn U-AND and related tasks, and how learned algorithms differ.
Throwing existing interp techniques at NNs trained on these tasks and trying to study what we find. Which techniques can handle the superposition adequately?
Trying to find instances of computation in superposition happening in small language models.
Structure of the Post
In Section 1, we define the U-AND task precisely, and then walk through our construction and show that it solves the task. Then we generalise the construction in 2 important ways: in Section 1.3, we modify the construction to compute ANDs of input features which are stored in superposition, allowing us to stack multiple U-AND layers together to simulate a boolean circuit. In Section 1.4 we modify the construction to compute ANDs of more than 2 variables at the same time, allowing us to compute all sufficiently small[4] boolean functions of the inputs with a single MLP. Then in Section 1.5 we explore efficiency gains from replacing the ReLU with a quadratic activation function, and explore the consequences.
In Section 2 we explore a series of questions around how to interpret the maths in Section 1, in the style of FAQs. Each part of Section 2 is standalone and can be skipped, but we think that many of the concepts discussed there are valuable and frequently misunderstood.
In section 3 we turn to the QK circuit, carefully introducing the skip feature-bigram checking task, and we explain our construction. We also discuss two scenarios that allow for more SFBs to be checked for than the simplest construction would allow.
We discuss the relevance of our constructions to real models in Section 4, and conclude in Section 5 with more discussion on Open Directions.
Notation and Conventions
d is the dimension of some activation space. d0 may also be used for the dimension of the input space, and d for the number of neurons in an MLP
m is the number of input features. If the input features are stored in superposition, m>d, otherwise m=d
→e1,→e2,…,→ed denotes an orthogonal basis of vectors. The standard basis refers to the neuron basis.
All vectors are denoted with arrows on top like this: →fi
We use single lines to denote the size of a set like this: |Si| or the L2 norm of a vector like this: |→fi|
We say that a boolean function g has been computed ϵ-accurately for some small parameter ϵ if the computed output never differs from g by more than ϵ. That is, whenever the function has the output 1, the computation outputs a number between 1±ϵ and whenever the function outputs 0, the computation outputs a number between ±ϵ.
We say that a pair of unit vectors is ϵ-almost orthogonal (for a fixed parameter ϵ) if their dot product is <ϵ (equivalently, if they are orthogonal to ϵ-accuracy). We say that a collection of unit vectors is ϵ-almost-orthogonal if they are pairwise almost orthogonal. We assume ϵ to be a fixed small number throughout the paper (unless specified otherwise).
It is known that for fixed ϵ, one can fit exponentially (in d) many almost orthogonal vectors in a d-dimensional Euclidean space. Throughout this paper, we will assume present in each NN activation space a suitably “large” collection of almost-orthogonal vectors, which we call an overbasis.
Vectors in this overbasis will be called f-vectors[5], and denoted →f1,→f2,…,→fm. We assume they correspond to binary properties of inputs relevant to a neural net (such as “Does this picture contain a cat?”). When convenient, we will assume these f-vectors are generated in a suitably random way: it is known that a random collection of vectors is, with high probability, an almost orthogonal overbasis, so long as the number of vectors is not superexponentially large in d[6].
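As a quick numerical illustration of this (with arbitrary toy sizes), many more random unit vectors than dimensions are already pairwise almost orthogonal:

```python
import numpy as np

rng = np.random.default_rng(0)
d, N, eps = 1024, 5000, 0.3      # toy sizes: many more vectors than dimensions

V = rng.normal(size=(N, d))
V /= np.linalg.norm(V, axis=1, keepdims=True)

# Largest |dot product| over all distinct pairs, computed in blocks to limit memory use.
worst = 0.0
for start in range(0, N, 1000):
    G = V[start:start + 1000] @ V.T
    for k in range(G.shape[0]):
        G[k, start + k] = 0.0            # ignore each vector's dot product with itself
    worst = max(worst, np.abs(G).max())
print(f"{N} random unit vectors in R^{d}: largest |dot product| = {worst:.3f} (eps = {eps})")
```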
In this post we make extensive use of Big-O notation and its variants, little o, Θ,Ω,ω. See wikipedia for definitions. We also make use of tilde notation, which means we ignore log factors. For example, by saying a function f(n) is Θ(g(n)), we mean that there are nonzero constants c1,c2>0 and a natural number N such that for all n>N, we have c1g(n)≤f(n)≤c2g(n). By saying a quantity is ~Θ(f(d)), we mean that this is true up to a factor that is a polynomial of logd — i.e., that it is asymptotically between f(d)/polylog(d) and f(d)polylog(d).
1 The Universal AND
We introduce a simple and central component in our framework, which we call the Universal AND component or U-AND for short. We start by introducing the most basic version of the problem this component solves. We then provide our solution to the simplest version of this problem. We later discuss a few generalizations: to inputs which store features in superposition, and to higher numbers of inputs to each AND gate. More elaboration on U-AND — in particular, addressing why we think it’s a good question to ask — is provided in Section 2.
1.1 The U-AND task
The basic boolean Universal AND problem: Given an input vector which stores an orthogonal set of boolean features, compute a vector from which can be linearly read off the value of every pairwise AND of input features, up to a small error. You are allowed to use only a single-layer MLP and the challenge is to make this MLP as narrow as possible.
More precisely: Fix a small parameter ϵ>0 and let d0 and ℓ be integers with d0≥ℓ[7]. Let →e1,…,→ed0 be the standard basis in Rd0, i.e. →ei is the vector whose ith component is 1 and whose other components are 0. Inputs are all at most ℓ-composite vectors, i.e., for each index set I⊆[d0] with |I|≤ℓ, we have the input →xI=∑i∈I→ei∈Rd0. So, our inputs are in bijection with binary strings that contain at most ℓ ones[8]. Our task is to compute all (d02) pairwise ANDs of these input bits, where the notion of ‘computing’ a property is that of making it linearly represented in the output activation vector →a(→x)∈Rd. That is, for each pair of inputs i,j, there should be a linear function ri,j:Rd→R, or more concretely, a vector →ri,j∈Rd, such that →rTi,j→a(x)≈ϵANDi,j(x). Here, the ≈ϵ indicates equality up to an additive error ϵ and ANDi,j is 1 iff both bits i and j of x are 1. We will drop the subscript ϵ going forward.
We will provide a construction that computes these Θ(d20) features with a single d-neuron ReLU layer, i.e., a d×d0 matrix W and a vector →b∈Rd such that →a(x)=ReLU(W→x+→b), with d≪d0. Stacking the readoff vectors →ri,j we provide as the rows of a readout matrix R, you can also see us as providing a parameter setting solving →ANDs(→x)≈ϵR(ReLU(W→x+→b)), where →ANDs(→x) denotes the vector of all (d02) pairwise ANDs. But we’d like to stress that we don’t claim there is ever something like this large, size (d02), layer present in any practical neural net we are trying to model. Instead, these features would be read in by another future model component, like how the components we present below (in particular, our U-AND construction with inputs in superposition and our QK circuit) do.
There is another kind of notion of a set of features having been computed, perhaps one that’s more native to the superposition picture: that of the activation vector (approximately) being a linear combination of the f-vectors corresponding to these properties, with coefficients that are functions of the values of the features. We can also consider a version of the U-AND problem that asks for output vectors which represent the set of all pairwise ANDs in this sense, maybe with the additional requirement that the f-vectors be almost orthogonal. Our U-AND construction solves this problem, too — it computes all pairwise ANDs in both senses. See the appendix for a discussion of some aspects of how the linear readoff notion of stuff having been computed, the linear combination notion of stuff having been computed, and almost orthogonality hang together.
1.2 The U-AND construction
We now present a solution to the U-AND task, computing (d02) new features with an MLP width that can be much smaller than (d02). We will go on to show how our solution can be tweaked to compute ANDs of more than 2 features at a time, and to compute ANDs of features which are stored in superposition in the inputs.
To solve the base problem, we present a random construction: W (with shape d×d0) has entries that are iid random variables which are 1 with probability p(d)≪1, and each entry in the bias vector is −1. We will pin down what p should be later.
We will denote by Si the set of neurons that are ‘connected’ to the ith input, in the sense that elements of the set are neurons for which the ith entry of the corresponding row of the weight matrix is 1. →Si is used to denote the indicator vector of Si: the vector which is 1 for every neuron in Si and 0 otherwise. So →Si is also the ith column of W.
Then we claim that for this choice of weight matrix, all the ANDs are approximately linearly represented in the MLP activation space with readoff vectors (and feature vectors, in the sense of Appendix B) given by
v(xi∧xj) = vij = →(Si∩Sj) / |Si∩Sj|
for all i,j, where we continue our abuse of notation to write Si∩Sj as shorthand for the vector which is an indicator for the intersection set, and |Si∩Sj| is the size of the set.
We preface our explanation of why this works with a technical note. We are going to choose d and p (as functions of d0) so that with high probability, all sets we talk about have size close to their expectation. To do this formally, one first shows that the probability of each individual set having size far from its expectation is smaller than any 1/poly(d0) using the Chernoff bound (Theorem 4 here), and one follows this by a union bound over all only poly(d0) sets to say that with probability 1−o(1), none of these events happen. For instance, if a set Si∩Sj has expected size log4d0, then the probability its size is outside of the range log4d0±log3d0 is at most 2e^(−μδ2/3)=2e^(−log2d0)=2d0^(−logd0) (following these notes, we let μ denote the expectation and δ denote the number of μ-sized deviations from the expectation — this bound works for δ<1 which is the case here). Technically, before each construction to follow, we should list our parameters d,p and all the sets we care about (for this first construction, these are the double and triple intersections between the Si) and then argue as described above that with high probability, they all have sizes that only deviate by a factor of 1+o(1) from their expected size, and always carry these error terms around in everything we say, but we will omit all this in the rest of the U-AND section.
So, ignoring this technicality, let’s argue that the construction above indeed solves the U-AND problem (with high probability). First, note that |Si∩Sj|∼Bin(d,p2). We require that p is big enough to ensure that all intersection sets are non-empty with high probability, but subject to that constraint we probably want p to be as small as possible to minimise interference[9]. We’ll choose p=log2d0/√d, such that the intersection sets have size |Si∩Sj|≈log4d0. We split the check that the readoff works out into a few cases:
Firstly, if input features i, j, and at most ℓ−2 other input features are present (recall that we are working with ℓ-composite inputs), then letting →a denote the post-ReLU activation vector, we have →fANDij⋅→a=1 plus an error that is at most ℓ times [the sum of sizes of triple intersections involving i,j and each of the ℓ−2 other features which are on, divided by the size of Si∩Sj]. This is very likely less than O(1/log2d0) for all polynomially many pairs and sets of ℓ−2 other inputs at once[10], at least assuming d=ω(log8d0). The expected value of this error is log2d0/√d.
Secondly, if only one of i,j is present together with at most ℓ−1 other features, then we get nonzero terms in the sum expanding the dot product →fANDij⋅→a precisely for neurons in a triple intersection of i, j, and one of the ℓ−1 other features, so the readoff is ≈0 — more precisely, O(1/log2d0) (again, assuming d=ω(log8d0)), and log2d0/√d in expectation.
Finally, if neither of i,j is present, then the error corresponds to quadruple intersections, so it is even more likely to be at most O(1/log2d0) (still assuming d=ω(log8d0)), and it is log4d0/d in expectation.
So we see that this readoff is indeed the AND of i and j up to error ϵ=O(1/log2d0).
To finish, we note without much proof that everything is also computed in the sense that ‘the activation vector is a linear combination of almost orthogonal features’ (defined in Appendix B). The activation vector being an approximate linear combination of pairwise intersection indicator vectors with coefficients being given by the ANDs follows from triple intersections being small, as does the almost-orthogonality of these feature vectors.
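For concreteness, here is a minimal numpy sketch of this construction. The sizes and the value of p below are hand-picked toy choices (the asymptotic formula p=log2d0/√d is not meaningful at this scale), so the interference terms are larger than they would be asymptotically, but the qualitative behaviour of the readoff is visible.

```python
import numpy as np

rng = np.random.default_rng(0)
d0, d = 1000, 2000         # toy sizes; note that d is still far smaller than (d0 choose 2) ~ 5e5
p = 0.1                    # hand-picked connection probability for this toy scale

W = (rng.random((d, d0)) < p).astype(float)    # iid 0/1 weights, shape d x d0
b = -1.0

def acts(on):              # post-ReLU activations on an input whose active feature set is `on`
    x = np.zeros(d0); x[list(on)] = 1.0
    return np.maximum(W @ x + b, 0.0)

def and_readoff(i, j, a):  # read off AND_{i,j} along the normalized indicator of S_i ∩ S_j
    v = W[:, i] * W[:, j]
    return (v / v.sum()) @ a

print("i and j both on:", and_readoff(3, 7, acts({3, 7, 50})))   # ~1 plus a small interference term
print("only i on      :", and_readoff(3, 7, acts({3, 50})))      # ~0 (triple-intersection error)
print("neither on     :", and_readoff(3, 7, acts({50, 81})))     # ~0 (quadruple-intersection error)
```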
U-AND allows for arbitrary XORs to be efficiently calculated
A consequence of the precise (up to ϵ) nature of our universal AND is the existence of a universal XOR, in the sense of every XOR of features being computed. In this post by Sam Marks, it is tentatively observed that real-life transformers linearly compute the XOR of arbitrary features in the weak sense that, at each token, whether the XOR of two features is true can be read off using a linear probe (not necessarily with ϵ accuracy). This weak readoff behavior for AND would be unsurprising, as the residual stream already has this property (using the readoff vector →fi+→fj, which has maximal value if and only if fi and fj are both present). However, as Sam Marks observes, it is not possible to read off XOR in this weak way from the residual stream. We can however see that such a universal XOR (indeed, in the strong sense of ϵ-accuracy) can be constructed from our strong (i.e., ϵ-accurate) universal AND. To do so, assume that in addition to the residual stream containing feature vectors →fi and →fj, we’ve also already almost orthogonally computed universal AND features →fANDi,j into the residual stream. Then we can weakly (and in fact, ϵ-accurately) read off XOR from this space by taking the dot product with the vector →fXORi,j:=→fi+→fj−2→fANDi,j. If we had started with the two-hot pair →fi′+→fj′, the result of this readoff will be, up to a small error O(ϵ):
0 = 0 − 0, if |{i,j}∩{i′,j′}| = 0 (neither coefficient agrees)
1 = 1 − 0, if |{i,j}∩{i′,j′}| = 1 (one coefficient agrees)
0 = 2 − 2, if {i,j} = {i′,j′} (both coefficients agree)
This gives a theoretical feasibility proof of an efficiently computable universal XOR circuit, something Sam Marks believed to be impossible.
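Here is a quick numerical check of this readoff, with random vectors standing in for the f-vectors and for the already-computed AND directions (toy sizes, so the O(ϵ) interference is visible but small).

```python
import numpy as np

rng = np.random.default_rng(2)
d, m = 512, 50
F = rng.normal(size=(m, d)) / np.sqrt(d)          # random f-vectors, roughly unit norm
F_AND = {(i, j): rng.normal(size=d) / np.sqrt(d)  # stand-ins for the computed AND directions
         for i in range(m) for j in range(i + 1, m)}

def resid(i, j):
    # Residual stream for a two-hot input: both features plus their (already computed) AND.
    return F[i] + F[j] + F_AND[(i, j)]

def xor_readoff(i, j, a):
    return (F[i] + F[j] - 2 * F_AND[(i, j)]) @ a

a = resid(10, 20)
print("no index in common :", round(xor_readoff(30, 40, a), 2))   # ~0
print("one index in common:", round(xor_readoff(10, 40, a), 2))   # ~1
print("both indices match :", round(xor_readoff(10, 20, a), 2))   # ~0
```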
1.3 Handling inputs in superposition: sparse boolean computers
Any boolean circuit can be written as a sequence of layers executing pairwise ANDs and XORs[11] on the binary entries of a memory vector. Since our U-AND can be used to compute any pairwise ANDs or XORs of features, this suggests that we might be able to emulate any boolean circuit by applying something like U-AND repeatedly. However, since the outputs of U-AND store features in superposition, if we want to pass these outputs as inputs to a subsequent U-AND circuit, we need to work out the details of a U-AND construction that can take in features in superposition. In this section we explore the subtleties of modifying U-AND in this way. In so doing, we construct an example of a circuit which acts entirely in superposition from start to finish — nowhere in the construction are there as many dimensions as features! We consider this to be an interesting result in its own right.
U-AND’s ability to compute many boolean functions of input features stored in superposition provides an efficient way to use all the parameters of the neural net to compute (up to a small error) a boolean circuit with a memory vector that is wider than the layers of the NN[12]. We call this emulating a ‘boolean computer’. However, three limitations mean that not every boolean circuit can be emulated:
An injudicious choice of a layer executing XORs applied to a sparse input can fail to give a sparse output vector. Since U-AND only works on inputs with sparse features, this means that we can only emulate circuits with the property that, on sparse inputs, their memory vector is sparse throughout the computation. We call these circuits ‘sparse boolean circuits’.
Even if the outputs of the circuit remain sparse at every layer, the ϵ errors involved in the boolean read-offs compound from layer to layer. We hope that it is possible to manage this interference (perhaps via subtle modifications to the constructions) enough to allow multiple steps of sequential computation, although we leave an exploration of error propagation to future work.
We can’t compute an unbounded number of new features with a finite-dimensional hidden layer. As we will see in this section, when input features are stored in superposition (which is true for outputs of U-AND and therefore certainly true for all but possibly the first layer of an emulated boolean circuit), we cannot compute more than ~Θ(d0d) (number of parameters in the layer) many new boolean functions at a time.
Therefore, the boolean circuits we expect can be emulated in superposition (1) are sparse circuits (2) have few layers (3) have memory vectors which are not larger than the square of the activation space dimension.
Construction details for inputs in superposition
Now we generalize U-AND to the case where input features can be in superposition. With f-vectors →f1,…,→fm∈Rd0, we give each feature a random set of neurons to map to, as before. After coming up with such an assignment, we set the ith row of W to be the sum of the f-vectors for features which map to the ith neuron. In other words, let F be the m×d0 matrix with ith row given by the components of →fi in the neuron basis:
F = ( →f1 )
    (  ⋮  )
    ( →fm )
Now let ^W be a sparse matrix (with shape d×m) with entries that are iid Bernoulli random variables which are 1 with probability p(d)≪1. Then:
W=^WF
Unfortunately, since the →f1,…,→fm are random vectors, their inner product will have a typical size of 1/√d0. So, on an input which has no features connected to neuron i, the preactivation for that neuron will not be zero: it will be a sum of these interference terms, one for each feature that is connected to the neuron. Since the interference terms are uncorrelated and mean zero, they start to cause neurons to fire incorrectly when Θ(d0) features are connected to each neuron. Since each feature is connected to each neuron with probability p=log2d0/√d, this means neurons start to misfire when m=~Θ(d0√d)[13]. At this point, the number of pairwise ANDs we have computed is (m2)=~Θ(d20d).
This is a problem, if we want to be able to do computation on input vectors storing potentially exponentially many features in superposition, or even if we want to be able to do any sequential boolean computation at all:
Consider an MLP with several layers, all of width dMLP, and assume that each layer is doing a U-AND on the features of the previous layer. Then if the features start without superposition, there are initially dMLP features. After the first U-AND, we have Θ(d2MLP) new features, which is already too many to do a second U-AND on these features!
Therefore, we will have to modify our goal when features are in superposition. That said, we’re not completely sure there isn’t any modification of the construction that bypasses such small polynomial bounds. But e.g. one can’t just naively make ^W sparser — p can’t be taken below d−1/2 without the intersection sets like |Si∩Sj| becoming empty. When features were not stored in superposition, solving U-AND corresponded to computing d20 many new features. Instead of trying to compute all pairwise ANDs of all (potentially exponentially many) input features in superposition, perhaps we should try to compute a reasonably sized subset of these ANDs. In the next section we do just that.
A construction which computes a subset of ANDs of inputs in superposition
Here, we give a way to compute ANDs of up to d0d particular feature pairs (rather than all (m2) ANDs) that works even for m that is superpolynomial in d0[14]. (We’ll be ignoring log factors in much of what follows.)
In U-AND, we take ^W to be a random matrix with iid 0/1 entries with probability p=log2d0/√d. If we only need/want to compute a subset of all the pairwise ANDs — let E be this set of all pairs of inputs {i,j} for which we want to compute the AND of i and j — then whenever {i,j}∈E, we might want each pair of corresponding entries in the corresponding columns i and j of the adjacency matrix ^W, i.e., each pair (^W)ki, (^W)kj, to be a bit more correlated than an analogous pair in columns i′ and j′ with {i′,j′}∉E. Or more precisely, we want to make such pairs of columns {i,j} have a surprisingly large intersection for the general density of the matrix — this is to make sure that we get some neurons which we can use to read off the AND of {i,j}, while choosing the general density in ^W to be low enough that we don’t cross the density threshold at which a neuron needs to care about too many input features.
One way to do this is to pick a uniformly random set of log4d0 neurons for each {i,j}∈E, and to set the column of ^W corresponding to input i to be the indicator vector of the union of these sets (i.e., just those assigned to gates involving i). This way, we can compute up to around |E|=~Θ(d0d) pairwise ANDs without having any neuron care about more than d0 input features, which is the requirement from the previous section to prevent neurons misfiring when input f-vectors are random vectors in superposition with typical interference size Θ(1/√d0).
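Here is a minimal numpy sketch of this targeted construction, with toy sizes; the number of gates, the number of neurons per gate, and the other constants are arbitrary choices for illustration rather than the asymptotic ones.

```python
import numpy as np

rng = np.random.default_rng(3)
d0, d, m = 512, 2048, 4000       # toy sizes: input dim, neurons, features in superposition
n_gates, n_per_gate = 100, 40    # |E| target AND gates, neurons assigned to each gate

F = rng.normal(size=(m, d0))
F /= np.linalg.norm(F, axis=1, keepdims=True)            # random unit f-vectors
E_gates = [(2 * g, 2 * g + 1) for g in range(n_gates)]   # the feature pairs whose ANDs we target

# Each gate gets a random set of neurons; column i of ^W is the union over gates involving i.
W_hat = np.zeros((d, m))
gate_neurons = {}
for (i, j) in E_gates:
    ns = rng.choice(d, n_per_gate, replace=False)
    gate_neurons[(i, j)] = ns
    W_hat[ns, i] = 1.0
    W_hat[ns, j] = 1.0

W = W_hat @ F                                     # W = ^W F, shape d x d0
b = -1.0

def acts(features):                               # activations on a sparse sum of f-vectors
    x = F[list(features)].sum(0)
    return np.maximum(W @ x + b, 0.0)

def gate_readoff(gate, a):                        # average activation over the gate's neurons
    return a[gate_neurons[gate]].mean()

print("AND(0,1), both on:", gate_readoff((0, 1), acts({0, 1, 900, 2500})))   # ~1
print("AND(0,1), one on :", gate_readoff((0, 1), acts({0, 900, 2500})))      # ~0 (small interference)
```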
1.4 ANDs with many inputs: computation of small boolean circuits in a single layer
It is known that any boolean circuit with k inputs can be written as a linear combination (with possibly exponentially many terms in k, which is a substantial caveat) of ANDs with up to k inputs (fan-in up to k)[15]. This means that, if we can compute not just pairwise ANDs, but ANDs of all fan-ins up to k, then we can write down a ‘universal’ computation that computes (simultaneously, in a linearly-readable sense) all possible circuits that depend on at most k inputs.
The U-AND construction for higher fan-in
We will modify the standard, non-superpositional U-AND construction to allow us to compute all ANDs of a specific fan-in k.
We’ll need two modifications:
We’re now interested in k-wise intersections between the Si. The size of these intersections is smaller than that of double intersections, so we need to increase p to guarantee they are nonempty. A sensible choice for fan-in k is p=log2d0/d^(1/k).
We only want neurons to fire when k of the features that connect to them are present at the same time, so we require the bias to be −k+1.
Now we read off the AND of a set I of input features along the vector ⋂i∈ISi.
We can straightforwardly simultaneously compute all ANDs of fan-ins ranging from 2 to k by just evenly partitioning the d neurons into k−1 groups — let’s label these 2,3,…,k — and setting the weights into group i and the biases of group i as in the fan-in i U-AND construction.
A clever choice of density can give us all the fan-ins at once
Actually, we can calculate all ANDs of up to some constant fan-in k in a way that feels more symmetric than the option involving a partition above[16] by reusing the fan-in 2 U-AND with (let’s say) d=d0 and a careful choice of p=1/log2d0. This choice of p is larger than log2d0/d^(1/k) for any k, ensuring that every intersection set is non-empty. Then, one can read off ANDi,j from Si∩Sj as usual, but one can also read off ANDi,j,k with the composite vector
−(Si∩Sj∩Sk)/|Si∩Sj∩Sk| + (Si∩Sj)/|Si∩Sj| + (Si∩Sk)/|Si∩Sk| + (Sj∩Sk)/|Sj∩Sk|.
In general, one can read off the AND of an index set I with the vector ∑I′⊆I s.t. |I′|≥2 (−1)^|I′| vI′, where vI′ = (⋂i∈I′Si)/|⋂i∈I′Si| (note that the sign (−1)^|I′| matches the fan-in 3 example above). One can show that this inclusion-exclusion style formula works by noting that if the subset of indices of I which are on is J, then the readoff will be approximately ∑I′⊆I s.t. |I′|≥2 (−1)^|I′| max(0,|I′∩J|−1). We’ll leave it as an exercise to show that this is 0 if J≠I and 1 if J=I.
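For those who would like to skip the exercise, here is a brute-force check of the identity (with the (−1)^|I′| coefficients) for small fan-ins.

```python
from itertools import chain, combinations

def readoff(I, J):
    """Approximate readoff of AND over index set I when the active subset of I is J,
    using the inclusion-exclusion coefficients (-1)^{|I'|}."""
    subsets = chain.from_iterable(combinations(I, s) for s in range(2, len(I) + 1))
    return sum((-1) ** len(Ip) * max(0, len(set(Ip) & set(J)) - 1) for Ip in subsets)

# The readoff is 1 exactly when all of I is active, and 0 otherwise, for fan-ins 2..5.
for k in range(2, 6):
    I = tuple(range(k))
    assert all(readoff(I, J) == (1 if set(J) == set(I) else 0)
               for s in range(k + 1) for J in combinations(I, s))
print("inclusion-exclusion readoff check passed for fan-ins 2-5")
```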
Extending the targeted superpositional AND to other fan-ins
It is also fairly straightforward to extend the construction for a targeted subset of ANDs, when inputs are in superposition, to other fan-ins, doing all fan-ins on a common set of neurons. Instead of picking a set of neurons for each pair that we need to AND as above, we now pick a set for each larger AND gate that we care about. As in the previous sparse U-AND, each input feature gets sent to the union of the sets for its gates, but this time we make the weights depend on the fan-in. Letting $K$ denote the max fan-in over all gates, for a fan-in $k$ gate we set the weight from each of its inputs to $K/k$, and set the bias to $-K+1$. This way, still with at most about $\tilde\Theta(d^2)$ gates, and at least assuming inputs have at most some constant number of features active, we can read the output of a gate off with the indicator vector of its set.
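A sketch of this mixed-fan-in variant, with a hypothetical three-gate list and toy sizes of ours (the gates' neuron sets are small and chosen independently here, so there is essentially no interference to see):

```python
# Sketch: each gate gets its own random set of neurons; inputs to a fan-in-k gate
# connect to that set with weight K/k (K = max fan-in), and the bias is -(K-1).
import numpy as np

rng = np.random.default_rng(1)
d0, d, neurons_per_gate = 300, 600, 8
gates = [(0, 1), (2, 3, 4), (5, 6, 7, 8)]          # example gates of fan-in 2, 3, 4
K = max(len(g) for g in gates)

W = np.zeros((d, d0))
gate_neurons = []
for g in gates:
    S = rng.choice(d, neurons_per_gate, replace=False)
    gate_neurons.append(S)
    W[np.ix_(S, list(g))] = K / len(g)             # weight K/k into this gate's neurons

def readoff(x, gate_idx):
    acts = np.maximum(0.0, W @ x - (K - 1))        # bias -(K-1)
    return acts[gate_neurons[gate_idx]].mean()

x = np.zeros(d0); x[[2, 3, 4]] = 1.0               # all inputs of the fan-in-3 gate on
print(readoff(x, 1))                               # ~1
x[4] = 0.0                                         # one input missing
print(readoff(x, 1))                               # ~0
```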
1.5 Improved Efficiency with a Quadratic Nonlinearity
It turns out that, if we use quadratic activation functions $x\mapsto x^2$ instead of ReLUs $x\mapsto\mathrm{ReLU}(x)$, we can write down a much more efficient universal AND construction. Indeed, the ReLU universal AND we constructed can compute the universal AND of up to $\tilde\Theta(d^{3/2})$ features in a $d$-dimensional residual stream. However, in this section we will show that with a quadratic activation, for $\ell$-composite input vectors, there is a single-layer universal AND circuit for up to $m=\Omega\!\left(\exp\!\left(\tfrac{1}{2\ell}\epsilon^2\sqrt{d}\right)\right)$[17] features stored in superposition (this is exponential in $\sqrt{d}$, so superpolynomial in $d$!).
The idea of the construction is that, on the large space of features $\mathbb{R}^m$, the AND of the boolean-valued feature variables $f_i,f_j$ can be written as a quadratic function $q_{i,j}:\{0,1\}^m\to\{0,1\}$; explicitly, $q_{i,j}(f_1,\dots,f_m)=f_i\cdot f_j$. Now if we embed feature space $\mathbb{R}^m$ into a smaller $\mathbb{R}^r$ in an $\epsilon$-almost-orthogonal way, it is possible to show that the quadratic function $q_{i,j}$ on $\mathbb{R}^m$ is well-approximated on sparse vectors by a quadratic function on $\mathbb{R}^r$ (with error bounded above by $2\epsilon$ on 2-sparse inputs in particular). Now the advantage of using quadratic functions is that any quadratic function on $\mathbb{R}^r$ can be expressed as a linear read-off of a special quadratic function $Q:\mathbb{R}^r\to\mathbb{R}^{r^2}$, given by the composition of a linear function $\mathbb{R}^r\to\mathbb{R}^{r^2}$ and a quadratic element-wise activation function on $\mathbb{R}^{r^2}$, which creates a set of neurons which collectively form a basis for all quadratic functions. Now we can set $d=r^2$ to be the dimension of the residual stream and work with an $r$-dimensional subspace $V$ of the residual stream, taking the almost-orthogonal embedding $\mathbb{R}^m\to V$. Then the map $Q:V\to\mathbb{R}^d$ provides the requisite universal AND construction. We make this recipe precise in the following section.
Construction Details
In this section we use slightly different notation to the rest of the post, dropping overarrows for vectors, and we drop the distinction between features and f-vectors.
Let $V=\mathbb{R}^r$ be as above. There is a finite-dimensional space of quadratic functions on $\mathbb{R}^r$, with basis $q_{ij}=x_ix_j$ of size $r^2$ (such that we can write every quadratic function as a linear combination of these basis functions); alternatively, we can write $q_{ij}(v)=(v\cdot e_i)(v\cdot e_j)$, for $e_i,e_j$ the basis vectors. We note that this space is spanned by a set of functions which are squares of linear functions of $\{x_i\}$:
$$L^{(1)}_i(x_1,\dots,x_r)=x_i,\qquad L^{(2)}_{i,j}(x_1,\dots,x_r)=x_i+x_j,\qquad L^{(3)}_{i,j}(x_1,\dots,x_r)=x_i-x_j.$$
The squares of these functions are a valid basis for the space of quadratic functions on $\mathbb{R}^r$, since $q_{ii}=(L^{(1)}_i)^2$ and, for $i\neq j$, we have $q_{ij}=\frac{(L^{(2)}_{i,j})^2-(L^{(3)}_{i,j})^2}{4}$. There are $r$ distinct functions of type (1), and $\binom{r}{2}$ functions each of types (2) and (3), for a total of $r^2$ basis functions as before. Thus there exists a single-layer quadratic-activation neural net $Q:x\mapsto y$ from $\mathbb{R}^r\to\mathbb{R}^{r^2}$ such that any quadratic function on $\mathbb{R}^r$ is realizable as a "linear read-off", i.e., given by composing $Q$ with a linear function $\mathbb{R}^{r^2}\to\mathbb{R}$. In particular, we have linear "read-off" functions $\Lambda_{ij}:\mathbb{R}^{r^2}\to\mathbb{R}$ such that $\Lambda_{ij}(Q(x))=q_{ij}(x)$.
Now suppose that $f_1,\dots,f_m$ is a collection of f-vectors which are $\epsilon$-almost-orthogonal, i.e., such that $|f_i|=1$ for every $i$ and $|f_i\cdot f_j|<\epsilon$ for all $i<j\le m$. Note that (for fixed $\epsilon<1$) there exist such collections with an exponential (in $r$) number of vectors $m$. We can define a new collection of symmetric bilinear functions (i.e., functions of two vectors $v,w\in\mathbb{R}^r$ which are linear in each input independently and symmetric under switching $v,w$), $\phi_{i,j}$, for each pair of (not necessarily distinct) indices $0<i\le j\le m$, defined by symmetrizing $(v\cdot f_i)(w\cdot f_j)$ over $v,w$ (each term is a product of two linear functions, one in each argument, hence bilinear). We will use the following result:
Proposition 1. Suppose $\phi_{i,j}$ is as above and $0<i'\le j'\le m$ is another pair of (not necessarily distinct) indices, associated to feature vectors $v_{i'},v_{j'}$. Then
$$\phi_{i,j}(v_{i'},v_{j'})\ \begin{cases}=1, & i=i'\text{ and }j=j'\\ \in(-\epsilon,\epsilon), & (i,j)\neq(i',j')\\ \in(-\epsilon^2,\epsilon^2), & \{i,j\}\cap\{i',j'\}=\varnothing\ \text{(i.e., no indices in common)}\end{cases}$$
This proposition follows immediately from the definition of $\phi_{i,j}$ and the almost-orthogonality property. □
Now define the single-valued quadratic function $\phi^{\text{single}}_{i,j}(v):=\tfrac{1}{2}\phi_{i,j}(v,v)$, obtained by applying the bilinear form to two copies of the same vector and dividing by 2. Then the proposition above implies that, for two pairs of distinct indices $0<i<j\le m$ and $0<i'<j'\le m$, we have the following behavior on the sum of two features (the superpositional analog of a two-hot vector):
$$\phi^{\text{single}}_{i,j}(v_{i'}+v_{j'})=\frac{\phi_{i,j}(v_{i'},v_{i'})+2\phi_{i,j}(v_{i'},v_{j'})+\phi_{i,j}(v_{j'},v_{j'})}{2}=\phi_{i,j}(v_{i'},v_{j'})+O(\epsilon).$$
The first equality follows from bilinearity (which is equivalent to the statement that the two entries in $\phi_{i,j}$ behave distributively), and the second follows from the proposition, since we assumed $(i,j)$ are distinct indices and hence cannot match up with a pair of identical indices $(i',i')$ or $(j',j')$. Moreover, the $O(\epsilon)$ term in the formula above is bounded in absolute value by $2\cdot\tfrac{\epsilon}{2}=\epsilon$.
Combining this formula with Proposition 1, we deduce:
Proposition 2
$$\phi^{\text{single}}_{i,j}(v_{i'}+v_{j'})=\begin{cases}1+O(\epsilon), & i=i'\text{ and }j=j'\\ O(\epsilon), & (i,j)\neq(i',j')\\ O(\epsilon^2), & \{i,j\}\cap\{i',j'\}=\varnothing.\end{cases}$$
Moreover, by the triangle inequality, the constants implicit in the $O(\cdot)$ notation are $\le 2$. □
Corollary. $\phi^{\text{single}}_{i,j}(v_{i'}+v_{j'})=\delta_{(i,j),(i',j')}+O(\epsilon)$, where the $\delta$ notation returns 1 when the two pairs of indices are equal and 0 otherwise.
We can now write down the universal AND function by setting $d=r^2$ above. Assume we have $m<\exp\!\left(\tfrac{\epsilon^2}{2}r\right)$. This guarantees (with probability approaching 1) that $m$ random vectors in $V\cong\mathbb{R}^r$ are ($\epsilon$-)almost orthogonal, i.e., have pairwise dot products $<\epsilon$. We assume the vectors $v_1,\dots,v_m$ are initially embedded in $V\subset\mathbb{R}^d$. (Note that we can instead assume they were initially randomly embedded in $\mathbb{R}^d$, then re-embedded in $\mathbb{R}^r$ by applying a random projection and rescaling appropriately.) Let $Q:\mathbb{R}^r\to\mathbb{R}^{d=r^2}$ be the universal quadratic map as above; we let $q_{ij}$ be the quadratic functions as above. Now we claim that $Q$ is a universal AND with respect to the feature vectors $v_1,\dots,v_m$. Note that, since the function $\phi^{\text{single}}_{i,j}(v)$ is quadratic on $\mathbb{R}^r$, it can be factorized as $\phi^{\text{single}}_{i,j}(x)=\Phi_{i,j}(Q(x))$, for $\Phi_{i,j}$ some linear function on $\mathbb{R}^{r^2}$[18]. We now see that the linear maps $\Phi_{i,j}$ are valid linear read-offs for ANDs of features: indeed,
$$\Phi_{i,j}(Q(v_{i'}+v_{j'}))=\phi^{\text{single}}_{i,j}(v_{i'}+v_{j'})=\delta_{(i,j),(i',j')}+O(\epsilon)=\mathrm{AND}\!\left(b^{i',j'}_i,b^{i',j'}_j\right)+O(\epsilon),$$
where $b^{i',j'}$ is the two-hot boolean indicator vector with 1s in positions $i'$ and $j'$. Thus the AND of any two indices $i,j$ can be computed via the readout linear function $\Phi_{i,j}$ on any two-hot input $b^{i',j'}$. Moreover, applying the same argument to a larger sparse sum gives $\Phi_{i,j}\!\left(Q\!\left(\sum_{k=1}^m b_kv_k\right)\right)=\mathrm{AND}(b_i,b_j)+O(s^2\epsilon)$, where $s=\sum_{k=1}^m b_k$ is the sparsity[19].
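Here is a small numerical illustration of the whole recipe. The sizes $r,m$ and the random unit vectors standing in for the $\epsilon$-almost-orthogonal f-vectors are arbitrary choices of ours; the neurons are the squares of $x_a$, $x_a+x_b$, $x_a-x_b$ as in the spanning set above, and the read-off vector implements $\Phi_{i,j}(Q(v))=(v\cdot f_i)(v\cdot f_j)$:

```python
# Sketch of the quadratic-activation universal AND (our toy sizes, not the post's).
import numpy as np

rng = np.random.default_rng(0)
r, m = 200, 500
F = rng.normal(size=(m, r))
F /= np.linalg.norm(F, axis=1, keepdims=True)   # m nearly-orthogonal unit f-vectors

iu = np.triu_indices(r, k=1)                    # index pairs a < b

def Q(v):
    """Quadratic-activation layer: squares of x_a, x_a + x_b, x_a - x_b (r^2 neurons)."""
    return np.concatenate([v**2, (v[iu[0]] + v[iu[1]])**2, (v[iu[0]] - v[iu[1]])**2])

def readoff_vector(i, j):
    """Linear functional Phi_{i,j} with Phi_{i,j}(Q(v)) = (v . f_i)(v . f_j)."""
    M = np.outer(F[i], F[j])
    sym = M[iu] + M.T[iu]                       # coefficient of v_a v_b for a < b
    return np.concatenate([np.diag(M), sym / 4, -sym / 4])

i, j, k, l = 0, 1, 2, 3
phi = readoff_vector(i, j)
print(phi @ Q(F[i] + F[j]))    # ~1: both features present
print(phi @ Q(F[i] + F[k]))    # ~0, up to O(eps): only one present
print(phi @ Q(F[k] + F[l]))    # ~0, up to O(eps^2): neither present
```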
Scaling and comparison with ReLU activations
It is surprising that the universal AND circuit we wrote down for quadratic activations is so much more expressive than the one we have for ReLU activations, since the conventional wisdom for neural nets is that expressivity does not change significantly when one (suitably smooth) activation function is replaced by another, e.g. by a quadratic one. We do not know if this is a genuine advantage of quadratic activations over others (and indeed it might be implemented in transformers in some sophisticated way involving attention nonlinearities), or whether there is some yet-unknown way in which (perhaps assuming nice properties of our features) ReLUs can give more expressive universal AND circuits than we have been able to find in the present work. We list this discrepancy as an interesting open problem that follows from our work.
Generalizations
Note that the nonlinear function $Q$ above lets us read off not only the AND of two sparse boolean vectors, but more generally the sum of products of coordinates of any sufficiently sparse linear combination of feature vectors $v_i$ (not necessarily boolean). More generally, if we replace quadratic activations with cubic or higher, we can get cubic expressions, such as the sum of triple ANDs (or, more generally, products of triples of coordinates). A similar effect can be obtained by chaining $l$ sequential levels of quadratic activations to get polynomial nonlinearities with exponent $e=2^l$. Then so long as we can fit $O(r^e)$[20] features in the residual stream in an almost-orthogonal way (corresponding to a basis of monomials of degree $e$ on $r$-dimensional space), we can compute sums of any degree-$e$ monomials over features, and thus any boolean circuit of degree $e$, up to $O(\epsilon)$, where the constant implicit in the $O$ depends on the exponent $e$. This implies that for any value of $e$, there is a dimension-$d$ universal nonlinear map $\mathbb{R}^d\to\mathbb{R}^d$ with $\lceil\log_2(e)\rceil$ layers of quadratic activations such that any sparse boolean circuit involving $\le e$ elements is linearly represented (via an appropriate readoff vector). Moreover, keeping $e$ fixed, $d$ grows only as $O(\log(n))^e$. However, the constant associated with the big-$O$ notation might grow quite quickly as the exponent $e$ increases. It would be interesting to analyse this scaling behavior more carefully, but that is outside the scope of the present work.
1.6 Universal Keys: an application of parallel boolean computation
So far, we have used our universal boolean computation picture to show that superpositional computation in a fully-connected neural network can be more efficient (specifically, compute roughly as many logical gates as there are parameters rather than non-superpositional implementations, which are bounded by number of neurons). This does not fully use the universality of our constructions: i.e., we must at every step read a polynomial (at most quadratic) number of features from a vector which can (in either the fan-in-k or quadratic-activation contexts) compute a superpolynomial number of boolean circuits. At the same time, there is a context in transformers where precisely this universality can give a remarkable (specifically, superpolynomial in certain asymptotics) efficiency improvement. Namely, recall that the attention mechanism of a transformer can be understood as a way for the last-token residual stream to read information from past tokens which pass a certain test associated to the query-key component. In our simplified boolean model, we can conceptualize this as follows:
Each token possesses a collection of “key features” which indicate bits of information about contexts where reading information from this token is useful. These can include properties of grammar, logic, mood, or context (food, politics, cats, etc.)
The current token attends to past tokens whose key features occur in a certain combination, which we conceptualize as tokens on whose features a certain boolean "relevance" function $g_{\text{last token}}$ returns 1. For example, the current token may 'want' to attend to all keys which have feature 1 and feature 4 but not feature 9, or exactly one of feature 2 and feature 8. This corresponds to the boolean function $g=(f_1\wedge f_4\wedge\neg f_9)\vee(f_2\oplus f_8)$. Importantly, the choice of $g$ varies from token to token. We abstract away the question of generating this relevance function as some (possibly complicated) nonlinear computation implemented in previous layers.
Each past token generates a key vector in a certain vector space (associated with an attention head) which is some (possibly nonlinear) function of the key features; the last token then generates a query vector which functions as a linear read-off, and should return a high value on past tokens for which the relevance formula evaluates to True. Note that the key vector is generated before the query vector, and before the choice of which g to use is made.
Importantly, there is an information asymmetry between the “past” tokens (which contribute the key) and the last token that implements the linear read-off via query: in generating the boolean relevance function, the past token can use information that is not accessible to the token generating the key (as it is in its “future” – this is captured e.g. by the attention mask). One might previously have assumed that in generating a key vector, tokens need to “guess” which specific combinations of key features may be relevant to future tokens, and separately generate some read-off for each; this limits the possible expressivity of choosing the relevance function g to a small (e.g. linear in parameter number) number of possibilities.
However, our discovery of circuits that implement universal calculation suggests a surprising way to resolve this information asymmetry: namely, using universal calculation, the key can simultaneously compute, in an approximately linearly-readable way, ALL possible simple circuits of up to $O(\log(d_{\text{resid}}))$ inputs. This increases the number of possibilities for the relevance function $g$ to all such simple circuits; this can be significantly larger than the number of parameters, and asymptotically (for logarithmic fan-ins) will in fact be superpolynomial[21]. As far as we are aware, this presents a qualitative (from a complexity-theoretic point of view) update to the expressivity of the attention mechanism compared to what was known before.
Sam Marks’ discovery of the universal XOR was done in this context: he observed using a probe that it is possible for the last token of a transformer to attend to past tokens that return True as the XOR of an arbitrary pair of features, something that he originally believed was computationally infeasible.
We speculate that this will be noticeable in real-life transformers, and can partially explain the observation that transformers tend to implement more superposition than fully-connected neural networks.
2 U-AND: discussion
We discuss some conceptual matters broadly having to do with whether the formal setup from the previous section captures questions of practical interest. Each of these subsections is standalone, and you needn’t read any to read Section 3.
Aren’t the ANDs already kinda linearly represented in the U-AND input?
This subsection refers to the basic U-AND construction from Section 1.1, with inputs not in superposition, but the objection we consider here could also be raised against other U-AND variants. The objection is this: aren't ANDs already linearly present in the input, so in what sense have we computed them with the U-AND? Indeed, if we take the dot product of a particular 2-hot input with $(\vec e_i+\vec e_j)/2$, we get 0 if neither the $i$th nor the $j$th feature is present, $1/2$ if one of them is present, and 1 if they are both present. If we add a bias of $-1/4$, then without any nonlinearity at all, we get a way to read off pairwise U-AND for $\epsilon=1/4$. The only thing the nonlinearity lets us do is to reduce this "interference" $\epsilon=1/4$ to a smaller $\epsilon$. Why is this important?
In fact, one can show that you can’t get more accurate than ϵ=1/4 without a nonlinearity, even with a bias, and ϵ=1/4 is not good enough for any interesting boolean circuit. Here’s an example to illustrate the point:
Suppose that I am interested in the variable $z=(x_i\wedge x_j)+(x_k\wedge x_l)$. $z$ takes on a value in $\{0,1,2\}$ depending on whether both, one, or neither of the ANDs are on. The best linear approximation to $z$ is $\tfrac12(x_i+x_j+x_k+x_l-1)$, which has completely lost the structure of $z$. In this case, we have lost any information about which way the 4 variables were paired up in the ANDs.
In general, computing a boolean expression with $k$ terms without the signal being drowned out by the noise will require $\epsilon<1/k$ if the noise is correlated, and $\epsilon<1/\sqrt{k}$ if the noise is uncorrelated. In other words, noise reduction matters! The precision provided by $\epsilon$-accuracy allows us to go from only recording ANDs to executing more general circuits in an efficient or universal way. Indeed, linear combinations of linear combinations just give more linear combinations; the noise reduction is the difference between being able to express any boolean function and being unable to express anything nonlinear at all. The XOR construction (given above) is another example that can be expressed as a linear combination involving the U-AND and would not work without the nonlinearity.
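Both claims are easy to check numerically. In the sketch below (ours; all 16 boolean inputs on four variables, least squares for the best affine read-off), the best affine approximation to a single AND is off by exactly $1/4$ in the worst case, and the best affine fit to $z$ is symmetric in all four variables, so it cannot tell the pairing $(x_0\wedge x_1)+(x_2\wedge x_3)$ apart from $(x_0\wedge x_2)+(x_1\wedge x_3)$:

```python
import itertools
import numpy as np

X = np.array(list(itertools.product([0, 1], repeat=4)), dtype=float)  # all 16 inputs
A = np.hstack([X, np.ones((16, 1))])                                  # affine read-off

and01 = X[:, 0] * X[:, 1]
z = X[:, 0] * X[:, 1] + X[:, 2] * X[:, 3]

coef_and, *_ = np.linalg.lstsq(A, and01, rcond=None)
coef_z, *_ = np.linalg.lstsq(A, z, rcond=None)

print(coef_and)                             # ~[0.5, 0.5, 0, 0, -0.25]
print(np.abs(A @ coef_and - and01).max())   # worst-case error: 0.25
print(coef_z)                               # ~[0.5, 0.5, 0.5, 0.5, -0.5]: pairing lost
```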
Aren’t the ANDs already kinda nonlinearly represented in the U-AND input?
This subsection refers to the basic U-AND construction from Section 1.1, with inputs not in superposition, but the objection we consider here could also be raised against other U-AND variants. While one cannot read off the ANDs linearly before the ReLU, except with a large error, one could certainly read them off with a more expressive model class on the activations. In particular, one can easily read ANDi,j off with a ReLU probe, by which we mean ReLU(rTx+b), with r=ei+ej and b=−1. We think there’s some truth to this: we agree that if something can be read off with such a probe, it’s indeed at least almost already there. And if we allowed multi-layer probes, the ANDs would be present already when we only have some pre-input variables (that our input variables are themselves nonlinear functions of). To explore a limit in ridiculousness: if we take stuff to be computed if it is recoverable by a probe that has the architecture of GPT-3 minus the embed and unembed and followed by a projection on the last activation vector of the last position residual stream, then anything that is linearly accessible in the last layer of GPT-3 is already ‘computed’ in the tuple of input embeddings. And to take a broader perspective: any variable ever computed by a deterministic neural net is in fact a function of the input, and is thus already ‘there in the input’ in an information-theoretic sense (anything computed by the neural net has zero conditional entropy given the input). The information about the values of the ANDs is sort of always there, but we should think of it as not having been computed initially, and as having been computed later[22].
Anyway, while taking something to be computed when it is affinely accessible seems natural when considering reading that information into future MLPs, we do not have an incredibly strong case that it’s the right notion. However, it seems likely to us that once one fixes some specific notion of stuff having been computed, then either exactly our U-AND construction or some minor variation on it would still compute a large number of new features (with more expressive readoffs, these would just be more complex properties — in our case, boolean functions of the inputs involving more gates). In fact, maybe instead of having a notion of stuff having been computed, we should have a notion of stuff having been computed for a particular model component, i.e. having been represented such that a particular kind of model component can access it to ‘use it as an input’. In the case of transformers, maybe the set of properties that have been computed as far as MLPs can tell is different than the set of properties that have been computed as far as attention heads (or maybe the QK circuit and OV circuit separately) can tell. So, we’re very sympathetic to considering alternative notions of stuff having been computed, but we doubt U-AND would become much less interesting given some alternative reasonable such notion.
If you think all this points to something like it being weird to have such a discrete notion of stuff having been computed vs not at all, and that we should maybe instead see models as ‘more continuously cleaning up representations’ rather than performing computation: while we don’t at present know of a good quantitative notion of ‘representation cleanliness’, so we can’t at present tell you that our U-AND makes amount x of representation cleanliness progress and x is sort of large compared to some default, it does seem intuitively plausible to us that it makes a good deal of such progress. A place where linear read-offs are clearly qualitatively important and better than nonlinear read-offs is in application to the attention mechanism of a transformer.
Does our U-AND construction really demonstrate MLP superposition?
This subsection refers to the basic U-AND construction from Section 1.1, with inputs not in superposition, but the objection we consider here could also be raised against other U-AND variants. One could try to tell a story that interprets our U-AND construction in terms of the neuron basis: we can also describe the U-AND as approximately computing a family of functions each of which record whether at least two features are present out of a particular subset of features[23]. Why should we see the construction as computing outputs into superposition, instead of seeing it as computing these different outputs on the neurons? Perhaps the ‘natural’ units for understanding the NN is in terms of these functions, as unintuitive as they may seem to a human.
In fact, there is a sense in which describing the sampled construction in the most natural way it can be described in the superposition picture takes more bits than describing it in the most natural way it can be described in this neuron picture. In the neuron picture, one needs to specify a subset of size $\tilde\Theta(d_0/\sqrt{d})$ for each neuron, which takes $d\log_2\binom{d_0}{\tilde\Theta(d_0/\sqrt{d})}\le\tilde\Theta\!\left(\tfrac{d_0^2}{\sqrt{d}}\right)$ bits to specify. In the superpositional picture, one needs to specify $\binom{d_0}{2}$ subsets of size $\tilde\Theta(1)$, which takes about $\tilde\Theta(d_0^2)$ bits to specify[24]. If, let's say, $d=d_0$, then from the point of view of saving bits when representing such constructions, we might even prefer to see them in a non-superpositional manner!
We can imagine cases (of something that looks like this U-AND showing up in a model) in which we’d agree with this counterargument. For any fixed U-AND construction, we could imagine a setup where for each neuron, the inputs feeding into it form some natural family — slightly more precisely, that whether two elements of this family are present is a very natural property to track. In fact, we could imagine a case where we perform future computation that is best seen as being about these properties computed by the neurons — for instance, our output of the neural net might just be the sum of the activations of these neurons. For instance, perhaps this makes sense because having two elements of one of these families present is necessary and sufficient for an image to be that of a dog. In such a case, we agree it would be silly to think of the output as a linear combination of pairwise AND features.
However, we think there are plausible contexts in which such a circuit would show up in which it seems intuitively right to see the output as a sparse sum of pairwise ANDs: when the families tracked by particular neurons do not seem at all natural and/or when it is reasonable to see future model components as taking these pairwise AND features as inputs. Conditional on thinking that superposition is generic, it seems fairly reasonable to think that these latter contexts would be generic.
Is universal calculation generic?
The construction of the universal AND circuit in the “quadratic nonlinearity” section above can be shown to be stable to perturbations; a large family of suitably “random” circuits in this paradigm contain all AND computations in a linearly-readable way. This updates us to suspect that at least some of our universal calculation picture might be generic: i.e., that a random neural net, or a random net within some mild set of conditions (that we can’t yet make precise), is sufficiently expressive to (weakly) compute any small circuit. Thus linear probe experiments such as Sam Marks’ identification of the “universal XOR” in a transformer may be explainable as a consequence of sufficiently complex, “random-looking” networks. This means that the correct framing for what happens in a neural net executing superposition might not be that the MLP learns to encode universal calculation (such as the U-AND circuit), but rather that such circuits exist by default, and what the neural network needs to learn is, rather, a readoff vector for the circuit that needs to be executed. While we think that this would change much of the story (in particular, the question of “memorization” vs. “generalization” of a subset of such boolean circuit features would be moot if general computation generically exists), this would not change the core fact that such universal calculation is possible, and therefore likely to be learned by a network executing (or partially executing) superposition. In fact, such an update would make it more likely that such circuits can be utilized by the computational scheme, and would make it even more likely that such a scheme would be learned by default.
We hope to do a series of experiments to check whether this is the case: whether a random network in a particular class executes universal computation by default. If we find this is the case, we plan to train a network to learn an appropriate read-off vector starting from a suitably random MLP circuit, and, separately, to check whether existing neural networks take advantage of such structure (i.e., have features – e.g. found by dictionary learning methods – which linearly read off the results of such circuits). We think this would be particularly productive in the attention mechanism (in the context of “universal key” generation, as explained above).
What are the implications of using ϵ-accuracy? How does this compare to behavior found by minimizing some loss function?
A specific question here is: would a construction like ours, which is $\epsilon$-accurate, also be found by (or score well under) minimising a loss function like MSE?
The answer is that sometimes they are not going to be the same. In particular, our algorithm may not be given a low loss by MSE. Nevertheless, we think that ϵ-accuracy is a better thing to study for understanding superposition than MSE or other commonly considered loss functions (cross entropy would be much less wise than either!) This point is worth addressing properly, because it has implications for how we think about superposition and how we interpret results from the toy models of superposition paper and from sparse autoencoders, both of which typically use MSE.
For our U-AND task, we ask for a construction $\vec f(\vec x)$ that approximately equals a 1-hot target vector $\vec y$, with each coordinate allowed to differ from its target value by at most $\epsilon$. A loss function which would correspond to this task would look like a cube well with vertical sides (the inside of the region $L_\infty(\vec f(\vec x),\vec y)<\epsilon$). This non-differentiable loss function would be useless for training. Let's compare this choice to alternatives and defend it.
If we know that our target is always a 1-hot vector, then maybe we should have a softmax at the end of the network and use cross-entropy loss. We purposefully avoid this, because we are trying to construct a toy model of the computation that happens in intermediate layers of a deep neural network, taking one activation vector to a subsequent activation vector. In the process there is typically no softmax involved. Also, we want to be able to handle datapoints in which more than 1 AND is present at a time: the task is not to choose which AND is present, but *which of the ANDs* are present.
The other ubiquitous choice of loss function is MSE. This is the loss function used to evaluate model performance in two tasks that are similar to U-AND: the toy model of superposition and SAEs. Two reasons why this loss function might be principled are
If there is reason to think of the model as a Gaussian probability model
If we would like our loss function to be basis independent.
We see no reason to assume the former here, and while the latter is a nice property to have, we shouldn’t expect basis independence here: we would like the ANDs to be computed in a particular basis and are happy with a loss function that privileges that basis.
Our issue with MSE (and Lp in general for finite p) can be demonstrated with the following example:
Suppose the target is $y=(1,0,0,\dots)$. Let $\hat y=(0,0,\dots)$ and $\tilde y=(1+\epsilon,\epsilon,\epsilon,\dots)$, where all vectors are $\binom{d_0}{2}$-dimensional. Then $\|y-\hat y\|_p=1$ and $\|y-\tilde y\|_p=\binom{d_0}{2}^{1/p}\epsilon$. For large enough $\binom{d_0}{2}>\epsilon^{-p}$, the latter loss is larger than 1[25]. Yet intuitively, the latter model output is likely to be a much better approximation to the target value, from the perspective of the way the activation vector will be used for subsequent computation. Intuitively, we expect that for the activation vector to be good enough to trigger the right subsequent computation, it needs to be unambiguous whether a particular AND is present, and the noise in the value needs to be below a certain critical scale that depends on the way the AND is used subsequently, to avoid noise drowning out signal. To understand this properly we'd like a better model of error propagation.
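For concreteness, here is the example with some arbitrary numbers of ours plugged in ($d_0=100$, $\epsilon=0.05$, $p=2$): the "give up" output has loss 1, while the $\epsilon$-accurate output has a larger $L_p$ loss as soon as $\binom{d_0}{2}>\epsilon^{-p}$:

```python
import numpy as np

d0, eps, p = 100, 0.05, 2
n = d0 * (d0 - 1) // 2                   # (d0 choose 2)-dimensional output

y = np.zeros(n); y[0] = 1.0              # target: exactly one AND is on
y_hat = np.zeros(n)                      # "give up" output
y_tilde = np.full(n, eps); y_tilde[0] = 1.0 + eps   # epsilon-accurate output

lp = lambda v: np.sum(np.abs(v) ** p) ** (1 / p)
print(lp(y - y_hat))                     # 1.0
print(lp(y - y_tilde))                   # ~ n**(1/p) * eps ~ 3.5, i.e. > 1
```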
It is no coincidence that our U-AND algorithm may be ϵ-accurate for small ϵ, but is not a minimum of the MSE. In general, ϵ-accuracy permits much more superposition than minimising the MSE, because it penalises interference less.
For a demonstration of this, consider a simplified toy model of superposition with hidden dimension d and inputs which are all 1-hot unit vectors. We consider taking the limit as the number of input features goes to infinity and ask: what is the optimum number N(d) of inputs that the model should store in superposition, before sending the rest to the zero vector?
If we look for $\epsilon$-accurate reconstruction, then we know how to answer this: a random construction allows us to fit at least $N_\epsilon(d)=C\exp(\epsilon^2 d)$ vectors into $d$-dimensional space.
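As a quick sanity check of this scaling (with arbitrary sizes of ours), the largest pairwise dot product among $N$ random unit vectors in $\mathbb{R}^d$ concentrates around $\sqrt{2\log(N^2)/d}$, so for a tolerance $\epsilon$ around that value one can fit far more than $d$ vectors:

```python
import numpy as np

rng = np.random.default_rng(0)
d, N = 128, 3000
V = rng.normal(size=(N, d))
V /= np.linalg.norm(V, axis=1, keepdims=True)

G = np.abs(V @ V.T)
np.fill_diagonal(G, 0.0)
print(G.max())                            # ~0.5: 3000 vectors fit in 128 dims at eps ~ 0.5
print(np.sqrt(2 * np.log(N * N) / d))     # rough prediction of the same scale
```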
As for the algorithm that minimises the MSE reconstruction loss (ie not sent to the zero vector in the hidden space), consider that we have already put n of the inputs into superposition, and we are trying to decide whether it is a good idea to squeeze another one in there. Separating the loss function into reconstruction terms and interference terms (as in the original paper):
The n+1th input being stored subtracts a term of order 1 from the reconstruction loss
Storing this input will also lead to an increase in the interference loss. As for how much, let us write δ(n)2 for the average mean squared dot product between the n+1th feature vector and one of the n feature vectors that were already there. Since the n+1th feature has n distinct features to interfere with, storing it will contribute a term of order nδ(n)2 to the interference loss.
So, the optimum number of features to store can be found by asking when the contribution to the loss $\ell(n+1)\sim n\delta(n)^2-1$ switches from negative to positive, so we need an estimate of $\delta(n)$. If feature vectors are chosen randomly, then $\delta(n)^2=O(1/d)$ and we find that the optimal number of features to store is $O(d)$. In fact, feature vectors are chosen to minimise interference, which allows us to fit a few more feature vectors in (the advantage this gives us is most significant at small $n$) before the accumulating interferences become too large, and empirically we observe that the optimal number of features to store is $N_{L_2}(d)=O(d\log d)$. This is much, much less superposition than we are allowed with $\epsilon$-accurate reconstruction!
See the figure below for experimental values of NLp(d) for a range of p,d. We conjecture that for each p,NLp(d) is the minimum of an exponential function which is independent of p and something like a polynomial which depends on p.
3 The QK part of an attention head can check for many skip feature-bigrams, in superposition
In this section, we present a story for the QK part of an attention head which is analogous to the MLP story from the previous section. Note that although both focus on the QK component, this is a different (though related) story to the story about universal keys from section 1.4.
We begin by specifying a simple task that we think might capture a large fraction of the role performed by the QK part of an attention head. Roughly, the task (analogous to the U-AND task for the MLP) is to check for the presence of one in a large set of ‘skip bigrams’[26] of features[27].
We'll then provide a construction of the QK part of an attention head that can perform this task in a superposed manner — i.e., a specification of a low-rank matrix $W_{QK}=W_K^TW_Q$ that checks for a given set of skip feature-bigrams. A naive construction could only check for $d_{\text{head}}$ feature bigrams; ours can check for $\tilde\Theta(d_{\text{head}}d_{\text{resid}})$ feature bigrams. This construction is analogous to our construction solving the targeted superpositional AND from the previous sections.
3.1 The skip feature-bigram checking task
Let $B$ be a set of 'skip feature-bigrams'; each element of $B$ is a pair of features $(\vec f_i,\vec f_j)\in\mathbb{R}^{d_{\text{resid}}}\times\mathbb{R}^{d_{\text{resid}}}$. Let's define what we mean by a skip feature-bigram being present in a pair of residual stream positions. Looking at residual stream activation vectors just before a particular attention head (after layernorm is applied), we say that the activation vectors $\vec a_s,\vec a_t\in\mathbb{R}^{d_{\text{resid}}}$ at positions $s,t$ contain the skip feature-bigram $(\vec f_i,\vec f_j)$ if feature $\vec f_i$ is present in $\vec a_t$ and feature $\vec f_j$ is present in $\vec a_s$. There are two things we could mean by the feature $\vec f_i$ being present in an activation vector $\vec a$. The first is that $\vec f_i\cdot\vec a'$ is always either $\approx 0$ or $\approx 1$ for any $\vec a'$ in some relevant data set of activation vectors, and $\vec f_i\cdot\vec a=1$. The second notion assumes the existence of some background set $\vec f_1,\vec f_2,\dots,\vec f_m$ in terms of which each activation vector $\vec a$ has a given background decomposition, $\vec a=\sum_{i=1}^m c_i\vec f_i$. In fact, we assume that all $c_i\in\{0,1\}$, with at most some constant number of $c_i=1$ for any one activation vector, and we also assume that the $\vec f_i$ are random vectors (we need them to be almost orthogonal). The second notion guarantees the first but with better control on the errors, so we'll run with the second notion for this section[28].
Plausible candidates for skip feature-bigrams (→fi,→fj) to check for come from cases where if the query residual stream vector has feature →fj, then it is helpful to do something with the information at positions where →fi is present. Here are some examples of checks this can capture:
If the query is a first name, then the key should be a surname.
If the query is a preposition associated with an indirect object, then the key should be a noun/name (useful for IOI).
If the query is token T, then the key should also be token T (useful for induction heads, if we can do this for all possible tokens).
If the query is 'Jorge Luis Borges', then the key should be 'Tlön, Uqbar, Orbis Tertius'.
If the mood of the paragraph before the query is solemn, then the topic of the paragraph before the key should be statistical mechanics.
If the query is the end of a true sentence, then the key should be the end of a false sentence.
If the query is a type of pet, then the key should be a type of furniture.
The task is to use the attention score $S$ (the attention pattern pre-softmax) to count how many of these conditions are satisfied by each choice of query token position and key token position. That is, we'd like to construct a low-rank bilinear form $W_K^TW_Q$ such that the $(s,t)$ entry of the attention score matrix $S_{st}=\vec a_s^TW_K^TW_Q\vec a_t$ contains the number of conditions in $B$ which are satisfied for the query residual stream vector in token position $s$ and the key residual stream vector in token position $t$. We'll henceforth refer to the expression $W_K^TW_Q$ as $W_{QK}$, a matrix of size $d_{\text{resid}}\times d_{\text{resid}}$ that we choose freely to solve the task, subject to the constraint that its rank is at most $d_{\text{head}}<d_{\text{resid}}$. If each property is present sparsely, then most conditions are not satisfied for most pairs of positions most of the time.
We will present a family of algorithms which allow us to perform this task for various set sizes |B|. We will start with a simple case without superposition analogous to the ‘standard’ method for computing ANDs without superposition. Unlike for U-AND though, the algorithm for performing this task in superposition is a generalization of the non-superpositional case. In fact, given our presentation of the non-superpositional case, this generalization is fairly immediate, with the main additional difficulty being to keep track of errors from approximate calculations.
3.2 A superposition-free algorithm
Let’s make the assumption that m is at most dresid. For the simplest possible algorithm, let’s make the further (definitely invalid) assumption that the feature basis is the neuron basis. This means that →as is a vector in {0,1}dresid. In the absence of superposition, we do not require that these features are sparse in the dataset.
To start, consider the case where $B$ contains only one feature bigram $(\vec e_i,\vec e_j)$. The task becomes: ensure that $S_{st}=\vec a_s^TW_{QK}\vec a_t$ is 1 if feature $\vec f_i$ is present in $\vec a_s$ and feature $\vec f_j$ is present in $\vec a_t$, and 0 otherwise. The solution to this task is to choose $W_{QK}$ to be a matrix which is zero everywhere except in the $(i,j)$ component: $(W_{QK})_{kl}=\delta_{ki}\delta_{lj}$ — with this matrix, $\vec a_s^TW_{QK}\vec a_t=1$ iff the $i$ entry of $\vec a_s$ is 1 and the $j$ entry of $\vec a_t$ is 1. Note that we can write $W_{QK}=\vec k\otimes\vec q$, where $\vec k=\vec e_i$, $\vec q=\vec e_j$, and $\otimes$ denotes the outer product/tensor product/Kronecker product. This expression makes it manifest that $W_{QK}$ is rank 1. Whenever we can decompose a matrix into a tensor product of two vectors (this will prove useful), we will call it a _pure tensor_, in accordance with the literature. Note that this decomposition allows us to think of $W_{QK}$ in terms of the query part and key part separately: first we project the residual stream vector in the query position onto the $i$th feature vector, which tells us if feature $i$ is present at the query position, then we do the same for the key, and then we multiply the results.
In the next simplest case, we take the set $B$ to consist of several pairs $(\vec e_i,\vec e_j)$. To solve the task for this $B$, we can simply sum the $W_{QK}$ matrices for each bigram in $B$, since there is no interference. That is, we choose
$$W_{QK}=\sum_{(i,j)\in B}\vec e_i\otimes\vec e_j$$
The only new subtlety introduced by this modification comes from the requirement that the rank of $W_{QK}$ be at most $d_{\text{head}}$, which won't be true in general. The rank of $W_{QK}$ is not trivial to calculate for a given $B$. This is because we can factorize terms in the sum:
$$\vec e_{j_1}\otimes\vec e_{i_1}+\vec e_{j_1}\otimes\vec e_{i_2}+\vec e_{j_2}\otimes\vec e_{i_1}+\vec e_{j_2}\otimes\vec e_{i_2}=(\vec e_{j_1}+\vec e_{j_2})\otimes(\vec e_{i_1}+\vec e_{i_2})$$
which is a pure tensor. The rank requirement is equivalent to the statement that $W_{QK}$ can contain at most $d_{\text{head}}$ terms _after maximum factorisation_ (a priori, not necessarily in terms of such pure tensors of sums of subsets of basis vectors). Visualizing the set $B$ as a bipartite graph with $m$ nodes on the left and right, we notice that pure tensors correspond to subgraphs of $B$ that are _complete_ bipartite subgraphs (cliques). A sufficient condition for the rank of $W_{QK}$ being at most $d_{\text{head}}$ is that the edges of $B$ can be partitioned into at most $d_{\text{head}}$ cliques. Thus, whether we can check for all feature bigrams in $B$ this way depends not only on the size of $B$, but also on its structure. In general, we can't use this construction to guarantee that we can check for more than $d_{\text{head}}$ skip feature-bigrams.
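A minimal sketch of this superposition-free construction, with a toy $d_{\text{resid}}$ and a made-up bigram set $B$ whose first four bigrams form a single clique:

```python
import numpy as np

d_resid = 16
B = [(0, 1), (0, 2), (3, 2), (3, 1), (5, 7)]       # the first four bigrams form a clique

W_QK = np.zeros((d_resid, d_resid))
for i, j in B:
    W_QK += np.outer(np.eye(d_resid)[i], np.eye(d_resid)[j])   # sum of pure tensors e_i (x) e_j

a_s = np.zeros(d_resid); a_s[[0, 5]] = 1.0         # features at one position
a_t = np.zeros(d_resid); a_t[[1, 7]] = 1.0         # features at the other position
print(a_s @ W_QK @ a_t)                            # 2.0: bigrams (0, 1) and (5, 7) both fire
print(np.linalg.matrix_rank(W_QK))                 # 2: the clique factorises into one pure tensor
```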
Generalizing our algorithm to deal with the case when the feature basis is not neuron-aligned (although it is still an orthogonal basis) could not be simpler. All we do is replace $\{\vec e_i\}$ with the new feature basis, use the same expression for $W_{QK}$, and we are done.
3.3 Checking for a structured set of skip feature-bigrams with activation superposition
We now consider the case where the residual stream contains $m>d_{\text{resid}}$ sparsely activated features stored in superposition. We'll assume that the feature vectors are random unit vectors, and we'll switch notation from $\vec e_1,\dots,\vec e_{d_{\text{resid}}}$ to $\vec f_1,\dots,\vec f_m$ from now on to emphasize that the f-vectors are not an orthogonal basis. We'd like to generalize the superposition-free algorithm to the case when the residual stream vector stores features in superposition, but to do so, we'll have to keep track of the interference between non-orthogonal f-vectors. We know that the root mean square dot product between two f-vectors is $1/\sqrt{d_{\text{resid}}}$. Every time we check for a bigram that isn't present and pick up an interference term, the noise accumulates: for the signal to beat the noise here, we need the sum of interference terms to be less than 1. We'll ignore log factors in the rest of this section.
We'll assume that most of the interference comes from checking for bigrams $(\vec f_i,\vec f_j)$ where $\vec f_i$ isn't in $\vec a_s$ and also $\vec f_j$ isn't in $\vec a_t$ — that cases where one feature is present but not the other are rare enough to contribute less can be checked later. These pure tensors typically contribute an interference of $1/d_{\text{resid}}$. We can also consider the interference that comes from checking for a clique of bigrams: let $K$ and $Q$ be sets of features such that $B=K\times Q$. Then we can check for the entire clique using the pure tensor $\bigl(\sum_{j\in K}\vec f_j\bigr)\otimes\bigl(\sum_{i\in Q}\vec f_i\bigr)$. Checking for this clique of feature bigrams on key-query pairs which don't contain any bigram in the clique contributes an interference term of $\sqrt{|K||Q|}/d_{\text{resid}}$, assuming interferences are uncorrelated. Now we require that the sum over interferences for checking all cliques of bigrams — of which there are at most $d_{\text{head}}$ — is less than one. Since there are at most $d_{\text{head}}$ cliques, then assuming each clique is the same size (slightly more generally, one can also make the cliques differently-sized, as long as the total number of edges in their union is at most $d_{\text{resid}}^2$) and assuming the noise is independent between cliques, we require $\sqrt{|K||Q|}/d_{\text{resid}}<1/\sqrt{d_{\text{head}}}$. Further assuming $|K|=|Q|$, this gives at most $|K|=|Q|=d_{\text{resid}}/\sqrt{d_{\text{head}}}$. In this way, over all $d_{\text{head}}$ cliques, we can check for up to $d_{\text{resid}}^2$ bigrams, which can collectively involve up to $d_{\text{resid}}\sqrt{d_{\text{head}}}$ distinct features, in each attention head.
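A rough numerical sanity check of this analysis (our toy sizes, random unit f-vectors, randomly chosen cliques): the read-off for a bigram that is being checked for comes out near 1, while a pair of features appearing in no clique only picks up the small interference terms estimated above:

```python
import numpy as np

rng = np.random.default_rng(0)
d_resid, m = 512, 4000
F = rng.normal(size=(m, d_resid))
F /= np.linalg.norm(F, axis=1, keepdims=True)       # random unit feature vectors

n_cliques, K_size, Q_size = 8, 16, 16
cliques = [(rng.choice(m, K_size, replace=False), rng.choice(m, Q_size, replace=False))
           for _ in range(n_cliques)]
W_QK = sum(np.outer(F[K].sum(0), F[Q].sum(0)) for K, Q in cliques)  # one pure tensor per clique

K0, Q0 = cliques[0]
signal = F[K0[0]] @ W_QK @ F[Q0[0]]   # a checked-for bigram: one feature from each factor
noise = F[-1] @ W_QK @ F[-2]          # two features that appear in no clique
print(signal, noise)                  # roughly 1 vs a small interference term
```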
Note also that one can involve up to $d_{\text{head}}d_{\text{resid}}$ distinct features if one chooses $|K|=1$ and $|Q|=d_{\text{resid}}$ for each clique.