Some Notes on the mathematics of Toy Autoencoding Problems

Anthropic’s recent mechanistic interpretability paper, Toy Models of Superposition, helps to demonstrate the conceptual richness of very small feedforward neural networks. Even when trained on synthetic, hand-coded data to reconstruct a very straightforward function (the identity map), there appears to be non-trivial mathematics at play, and the analysis of these small networks seems to provide an interesting playground for mechanistic interpretability.

While trying to understand their work and train my own toy models, I ended up making various notes on the underlying mathematics. This post is a slightly neatened-up version of those notes, but is still quite rough and un-edited and is a far-from-optimal presentation of the material. In particular, these notes may contain errors, which are my responsibility.

1. Directly Analyzing the Critical Points of a Linear Toy Model

Throughout we will be considering feedforward neural networks with one hidden layer. The input and output layers will be of the same size and the hidden layer is smaller. We will only be considering the autoencoding problem, which means that our networks are being trained to reconstruct the data. The first couple of subsections here are largely taken from the Appendix to the paper “Neural networks and principal component analysis: Learning from examples without local minima.” by Pierre Baldi and Kurt Hornik. (Neural networks 2.1 (1989): 53-58).

Consider to begin with a completely linear model, i.e. one without any activation functions or biases. Suppose the input and output layers have $n$ neurons and that the middle layer has $m < n$ neurons. This means that the function that the model is implementing is of the form $x \mapsto ABx$, where $x \in \mathbb{R}^n$, $B$ is an $m \times n$ matrix, and $A$ is an $n \times m$ matrix. That is, the matrix $B$ contains the weights of the connections between the input layer and the hidden layer, and the matrix $A$ contains the weights of the connections between the hidden layer and the output layer. It is important to realise that even though, for a given set of weights, the function being implemented here is linear, the mathematics of this model and the dynamics of training are not completely linear.

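The error on a given input $x$ will be measured by $\|x - ABx\|^2$, and on the data set $X \subset \mathbb{R}^n$ (a finite collection of data points), the total loss is

$$L(A,B) = \sum_{x \in X} \|x - ABx\|^2.$$

Define $\Sigma$ to be the $n \times n$ matrix whose $(i,j)$ entry is given by

$$\Sigma_{ij} = \sum_{x \in X} x_i x_j, \qquad \text{i.e.} \qquad \Sigma = \sum_{x \in X} x x^T.$$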

Clearly this matrix is symmetric.
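
As a quick numerical sanity check of these definitions, here is a minimal NumPy sketch. The random stand-in data and the helper names `loss` and `second_moment` are illustrative assumptions and do not come from the original notes. It confirms two things: the loss depends on the data only through $\Sigma$, via $L(A,B) = \operatorname{tr}\big((\mathrm{Id} - AB)\Sigma(\mathrm{Id} - AB)^T\big)$, and $\Sigma$ is symmetric.

```python
import numpy as np

def loss(A, B, X):
    """Total loss: sum over data points x (rows of X) of ||x - ABx||^2."""
    residuals = X - X @ (A @ B).T   # row t equals (x - ABx)^T for x = X[t]
    return float(np.sum(residuals ** 2))

def second_moment(X):
    """Sigma with entries Sigma_ij = sum over data points of x_i x_j."""
    return X.T @ X

rng = np.random.default_rng(0)
n, m, N = 6, 2, 50                  # input size, hidden size, number of data points
X = rng.normal(size=(N, n))         # hypothetical stand-in data set
A = rng.normal(size=(n, m))         # hidden -> output weights
B = rng.normal(size=(m, n))         # input -> hidden weights

Sigma = second_moment(X)
M = np.eye(n) - A @ B
trace_form = float(np.trace(M @ Sigma @ M.T))

print(np.isclose(loss(A, B, X), trace_form))  # True
print(np.allclose(Sigma, Sigma.T))            # True
```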

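Assumption. We will assume that the data is such that a) $\Sigma$ is invertible and b) $\Sigma$ has distinct eigenvalues.

Let $\lambda_1 > \lambda_2 > \dots > \lambda_n > 0$ be the eigenvalues of $\Sigma$. They are positive because $\Sigma$ is positive semi-definite and invertible.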

1.1 The Global Minimum


Proposition 1. (Characterization of Critical Points) Fix the dataset $X$ and consider $L$ to be a function of the two matrix variables $A$ and $B$. For any critical point $(A,B)$ of $L$ at which $A$ has full rank, there is a subset $I \subset \{1, \dots, n\}$ of size $m$ for which

  1. $AB$ is an orthogonal projection onto an $m$-dimensional subspace spanned by orthonormal eigenvectors of $\Sigma$ corresponding to the eigenvalues $\{\lambda_i\}_{i \in I}$; and

  2. $L(A,B) = \sum_{i \notin I} \lambda_i$.

Corollary 2. (Characterization of the Minimum) The loss has a unique minimum value $\sum_{i=m+1}^{n} \lambda_i$. It is attained when $I = \{1, \dots, m\}$, i.e. when $AB$ is the orthogonal projection onto the $m$-dimensional subspace spanned by the eigendirections of $\Sigma$ that have the $m$ largest eigenvalues.

Remarks. We won’t try to spell out all of the various connections to other closely related things. For those who want some more keywords to go away and investigate further, we just remark that the minimization problem being studied here is about finding a low-rank approximation to the identity map and is closely related to Principal Component Analysis. See also the Eckart–Young–Mirsky Theorem.
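
Corollary 2 can also be illustrated numerically. The sketch below makes the following assumptions:

- The data is synthetic Gaussian data whose coordinates have deliberately unequal variances, so that the eigenvalues of $\Sigma$ are well separated.
- $A$ is set to the matrix of top-$m$ eigenvectors of $\Sigma$, and $B = A^T$.

It compares the resulting loss with the sum of the $n - m$ smallest eigenvalues, and with the loss at many random pairs $(A,B)$.

```python
import numpy as np

def loss(A, B, Sigma):
    """L(A, B) = tr((Id - AB) Sigma (Id - AB)^T)."""
    M = np.eye(Sigma.shape[0]) - A @ B
    return float(np.trace(M @ Sigma @ M.T))

rng = np.random.default_rng(2)
n, m, N = 6, 2, 200
X = rng.normal(size=(N, n)) * np.arange(1, n + 1)  # hypothetical data, unequal variances
Sigma = X.T @ X

lam, U = np.linalg.eigh(Sigma)       # eigh returns eigenvalues in ascending order
lam, U = lam[::-1], U[:, ::-1]       # reorder so that lambda_1 > ... > lambda_n

U_top = U[:, :m]                     # eigenvectors for the m largest eigenvalues
best = loss(U_top, U_top.T, Sigma)
predicted = lam[m:].sum()            # sum of the n - m smallest eigenvalues

random_losses = [
    loss(rng.normal(size=(n, m)), rng.normal(size=(m, n)), Sigma)
    for _ in range(1000)
]
print(np.isclose(best, predicted))   # True
print(min(random_losses) >= best)    # True
```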

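We begin by directly differentiating $L$ with respect to the entries of $A$ and $B$. Using summation convention on repeated indices, we first take the derivative with respect to $A_{ab}$:

$$\frac{\partial L}{\partial A_{ab}} = -2 \sum_{x \in X} \left(x_a - A_{al} B_{lk} x_k\right) B_{bj} x_j = -2\left(\Sigma_{aj} B_{bj} - A_{al} B_{lk} \Sigma_{kj} B_{bj}\right). \tag{1}$$

Setting this equal to zero and interpreting this equation for all $a$ and $b$ gives us that

$$(\mathrm{Id} - AB)\Sigma B^T = 0.$$

Then, separately, we differentiate with respect to $B_{ab}$:

$$\frac{\partial L}{\partial B_{ab}} = -2 \sum_{x \in X} \left(x_c - A_{cl} B_{lk} x_k\right) A_{ca} x_b = -2\left(A_{ca}\Sigma_{cb} - A_{ca} A_{cl} B_{lk} \Sigma_{kb}\right). \tag{2}$$

Setting this equation equal to zero for every $a$ and $b$ we have that:

$$A^T (\mathrm{Id} - AB) \Sigma = 0.$$

Thus

$$AB\Sigma B^T = \Sigma B^T \tag{3}$$

and

$$A^T A B \Sigma = A^T \Sigma.$$

Since we have assumed that $\Sigma$ is invertible, the second of these immediately implies that $A^T A B = A^T$. If we assume in addition that $A$ has full rank (a reasonable assumption in any case of practical interest), then $A^T A$ is invertible and we have that

$$B = (A^T A)^{-1} A^T, \tag{4}$$

which in turn implies that

$$AB = A (A^T A)^{-1} A^T = P_A, \tag{5}$$

where we have written $P_A$ to denote the orthogonal projection onto the column space of $A$.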
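
The two gradient formulas can be checked against finite differences. In matrix form they read $-2(\mathrm{Id} - AB)\Sigma B^T$ and $-2A^T(\mathrm{Id} - AB)\Sigma$. This is a sketch that assumes random Gaussian stand-in data; `grads` and `finite_diff` are hypothetical helper names. Because $L$ is quadratic in $A$ for fixed $B$ (and vice versa), central differences should agree with the closed-form gradients up to rounding error.

```python
import numpy as np

def loss(A, B, Sigma):
    """L(A, B) = tr((Id - AB) Sigma (Id - AB)^T)."""
    M = np.eye(Sigma.shape[0]) - A @ B
    return float(np.trace(M @ Sigma @ M.T))

def grads(A, B, Sigma):
    """Matrix form of the two gradients derived above."""
    M = np.eye(Sigma.shape[0]) - A @ B
    return -2 * M @ Sigma @ B.T, -2 * A.T @ M @ Sigma

def finite_diff(f, P, h=1e-5):
    """Central-difference gradient of a scalar function f at the matrix P."""
    G = np.zeros_like(P)
    for idx in np.ndindex(P.shape):
        E = np.zeros_like(P)
        E[idx] = h
        G[idx] = (f(P + E) - f(P - E)) / (2 * h)
    return G

rng = np.random.default_rng(1)
n, m, N = 5, 2, 20
X = rng.normal(size=(N, n))          # hypothetical stand-in data
Sigma = X.T @ X
A = rng.normal(size=(n, m))
B = rng.normal(size=(m, n))

gA, gB = grads(A, B, Sigma)
fdA = finite_diff(lambda P: loss(P, B, Sigma), A)
fdB = finite_diff(lambda P: loss(A, P, Sigma), B)
print(np.abs(gA - fdA).max(), np.abs(gB - fdB).max())
```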

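Claim. We next claim that $P_A$ commutes with $\Sigma$.

Proof of claim. Plugging (5) into (3), we have:

$$P_A \Sigma B^T = \Sigma B^T. \tag{6}$$

Then, right-multiply by $A^T$ and use the fact that $B^T A^T = (AB)^T = P_A^T = P_A$ to get:

$$\Sigma P_A = P_A \Sigma P_A. \tag{7}$$

The right-hand side is manifestly a symmetric matrix, so we deduce that $\Sigma P_A$ is symmetric. If the product of two symmetric matrices is symmetric, then they commute (here $\Sigma P_A = (\Sigma P_A)^T = P_A \Sigma$). So this indeed shows that $P_A$ commutes with $\Sigma$ and completes the proof of the claim.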

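Now let $U$ be the orthogonal matrix which diagonalizes $\Sigma$, i.e. the matrix for which

$$\Sigma = U \Lambda U^T, \qquad \text{equivalently} \qquad U^T \Sigma U = \Lambda, \tag{8}$$

where $\Lambda$ is a diagonal matrix with entries $\lambda_1, \dots, \lambda_n$.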

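Claim. We next claim that $P_{U^T A} = U^T P_A U$ and that $P_{U^T A}$ is diagonal.

Proof of Claim. Firstly, using the standard formula for orthogonal projections, $P_M = M(M^TM)^{-1}M^T$ for a matrix $M$ of full column rank, we have

$$P_{U^T A} = U^T A \left(A^T U U^T A\right)^{-1} A^T U = U^T A \left(A^T A\right)^{-1} A^T U = U^T P_A U,$$

which implies that

$$P_A = U P_{U^T A} U^T. \tag{9}$$

To show that $P_{U^T A}$ is diagonal, we show that it commutes with the diagonal matrix $\Lambda$. A matrix that commutes with a diagonal matrix whose diagonal entries are distinct must itself be diagonal, since $(\Lambda M - M\Lambda)_{ij} = (\lambda_i - \lambda_j)M_{ij}$. This is where the assumption of distinct eigenvalues is used. Starting from $P_{U^T A}\Lambda$, we first insert the identity matrix in the form $U^T U$, and then use (8) and (9) thus:

$$P_{U^T A} \Lambda = U^T U P_{U^T A} U^T U \Lambda U^T U = U^T \left(U P_{U^T A} U^T\right)\left(U \Lambda U^T\right) U = U^T P_A \Sigma U.$$

Then recall that we have already established that $P_A$ commutes with $\Sigma$. So we can swap them and then perform the same trick in reverse:

$$U^T P_A \Sigma U = U^T \Sigma P_A U = U^T \left(U \Lambda U^T\right)\left(U P_{U^T A} U^T\right) U = \Lambda P_{U^T A}.$$

This shows that $P_{U^T A}$ commutes with $\Lambda$ and completes the proof of the claim.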

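So, $P_{U^T A}$ is an orthogonal projection of rank $m$ (it projects onto the column space of $U^T A$, which has rank $m$), and it is diagonal. A diagonal idempotent matrix has diagonal entries in $\{0, 1\}$. So there exists a set of indices $I \subset \{1, \dots, n\}$ with $|I| = m$ such that the $(i,i)$ entry of $P_{U^T A}$ is zero if $i \notin I$ and $1$ if $i \in I$. And since $P_A = U P_{U^T A} U^T$, we see that

$$P_A = U_I U_I^T, \tag{10}$$

where $U_I$ is the $n \times m$ matrix formed from $U$ by simply deleting the $i$-th column whenever $i \notin I$. This is manifestly an orthogonal projection onto the span of $\{u_i\}_{i \in I}$. Here $\{u_i\}_{i=1}^n$ is an orthonormal basis of eigenvectors of $\Sigma$ (and indeed the columns of $U$). Combining these observations with (5), we have that

$$AB = P_A = U_I U_I^T. \tag{11}$$

This proves the first claim of the proposition.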

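To prove the second part, write $L$ in terms of $\Sigma$ and compute thus:

$$\begin{aligned} L(A,B) &= \sum_{x \in X} \operatorname{tr}\left((\mathrm{Id} - AB)\, x x^T (\mathrm{Id} - AB)^T\right) = \operatorname{tr}\left((\mathrm{Id} - AB)\Sigma(\mathrm{Id} - AB)^T\right) \\ &= \operatorname{tr}(\Sigma) - \operatorname{tr}(AB\Sigma) - \operatorname{tr}(\Sigma B^T A^T) + \operatorname{tr}(AB \Sigma B^T A^T). \end{aligned} \tag{12}$$

But we know from (7) and (11) that $AB\Sigma B^TA^T = P_A \Sigma P_A = \Sigma P_A = \Sigma B^T A^T$, and so this last line is actually just equal to

$$L(A,B) = \operatorname{tr}(\Sigma) - \operatorname{tr}(AB\Sigma).$$

We now focus on the second term. We use (11), then (9) and (8), then cancel $U^T U = \mathrm{Id}$. Finally, to reach the last line, we cyclically permute the matrices inside the trace operator to produce another cancellation. This gives:

$$\operatorname{tr}(AB\Sigma) = \operatorname{tr}(P_A \Sigma) = \operatorname{tr}\left(U P_{U^T A} U^T U \Lambda U^T\right) = \operatorname{tr}\left(U P_{U^T A} \Lambda U^T\right) = \operatorname{tr}\left(P_{U^T A} \Lambda U^T U\right) = \operatorname{tr}\left(P_{U^T A} \Lambda\right).$$

The diagonal form of $P_{U^T A}$ means that this final expression is equal to $\sum_{i \in I} \lambda_i$, meaning that

$$L(A,B) = \operatorname{tr}(\Sigma) - \sum_{i \in I} \lambda_i.$$

The trace is always equal to the sum of the eigenvalues, so $\operatorname{tr}(\Sigma) = \sum_{i=1}^n \lambda_i$. Hence $L(A,B) = \sum_{i \notin I} \lambda_i$, which completes the proof of the proposition.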

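Remarks. Equation (10) above tells us that the column space of $A$ is the span of $\{u_i\}_{i \in I}$, i.e. the column space of $U_I$. So there exists an invertible $m \times m$ matrix $C$ with $A = U_I C$. Then, using (4) and $U_I^T U_I = \mathrm{Id}$, we compute that

$$B = (A^T A)^{-1} A^T = (C^T C)^{-1} C^T U_I^T = C^{-1} U_I^T.$$

So we have:

$$A = U_I C, \qquad B = C^{-1} U_I^T.$$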
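
The characterization of the critical points can be checked by brute force over every index set $I$ of size $m$. In this sketch, the data, the fixed invertible matrix `C`, and the helper names are illustrative assumptions. Indices are 0-based, so $I = \{1, \dots, m\}$ becomes `(0, 1)` when $m = 2$. For each $I$, the sketch checks three things about $(U_I C, C^{-1}U_I^T)$:

- the gradient vanishes;
- the loss equals $\sum_{i \notin I} \lambda_i$;
- $U^T P_A U$ is diagonal with entries $0$ or $1$.

```python
import itertools
import numpy as np

def loss(A, B, Sigma):
    M = np.eye(Sigma.shape[0]) - A @ B
    return float(np.trace(M @ Sigma @ M.T))

def grads(A, B, Sigma):
    """Gradients of L with respect to A and B (matrix form)."""
    M = np.eye(Sigma.shape[0]) - A @ B
    return -2 * M @ Sigma @ B.T, -2 * A.T @ M @ Sigma

rng = np.random.default_rng(3)
n, m, N = 5, 2, 100
X = rng.normal(size=(N, n)) * np.arange(1, n + 1)   # hypothetical data
Sigma = X.T @ X
lam, U = np.linalg.eigh(Sigma)
lam, U = lam[::-1], U[:, ::-1]                      # descending eigenvalues
C = np.array([[2.0, 0.5], [-0.3, 1.5]])             # arbitrary invertible m x m matrix

results = {}
for I in itertools.combinations(range(n), m):       # 0-based index sets of size m
    U_I = U[:, list(I)]                             # n x m, orthonormal columns
    A, B = U_I @ C, np.linalg.inv(C) @ U_I.T
    gA, gB = grads(A, B, Sigma)
    results[I] = {
        "loss": loss(A, B, Sigma),
        "predicted": sum(lam[k] for k in range(n) if k not in I),
        "grad_norm": np.linalg.norm(gA) + np.linalg.norm(gB),
        "P_rot": U.T @ (A @ B) @ U,                 # U^T P_A U, expected diagonal
    }

best_I = min(results, key=lambda I: results[I]["loss"])
print(best_I)   # (0, 1): the m largest eigenvalues
```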

1.2 Characterizing Other Critical Points

This subsection is something of an aside, but it is included for completeness.

Proposition 3. (Other Critical Points are Saddle Points.) Fix the dataset $X$ and consider $L$ to be a function of the two matrix variables $A$ and $B$. Every other critical point is a saddle point. That is, suppose $(A,B)$ is a critical point (with $A$ of full rank) that is not a global minimum. Then there exist $\tilde{A}$ and $\tilde{B}$, arbitrarily close to $A$ and $B$ respectively, at which a lower loss is achieved.

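Proof. Since $(A,B)$ is not a global minimum, we know from Corollary 2 that $I \neq \{1, \dots, m\}$. This means that there are distinct indices $i$ and $j$ for which $i \in I$, $j \notin I$, and $\lambda_j > \lambda_i$ (take some $j \le m$ with $j \notin I$ and some $i > m$ with $i \in I$). In particular, bear in mind that $u_j$ is orthogonal to every $u_k$ with $k \in I$. By the Remarks at the end of Section 1.1, we may write $A = U_I C$ and $B = C^{-1} U_I^T$ for some invertible $m \times m$ matrix $C$.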

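Now, given any $\epsilon > 0$, put

$$\tilde{u}_i = \frac{u_i + \epsilon u_j}{\sqrt{1 + \epsilon^2}}. \tag{13}$$

And let us form the new matrix $\tilde{U}_I$ by starting with $U_I$ and replacing the column $u_i$ with $\tilde{u}_i$. Write

$$\tilde{A} = \tilde{U}_I C, \qquad \tilde{B} = C^{-1} \tilde{U}_I^T. \tag{14}$$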

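We want to calculate the loss of the model at $(\tilde{A}, \tilde{B})$. We ought to bear in mind that it is not a critical point, so we cannot assume the intermediate results in the proof of Proposition 1. It turns out, however, that the parts most useful for this computation rely only on algebra and (13), (14). We start from the equivalent of line (11), which is that $\tilde{A}\tilde{B} = \tilde{U}_I \tilde{U}_I^T$. This implies that $\tilde{B}^T\tilde{A}^T = \tilde{U}_I \tilde{U}_I^T$ as well. And so, just as in (12) above (using $\operatorname{tr}(\Sigma M) = \operatorname{tr}(M \Sigma)$ to combine the two middle terms), we have

$$L(\tilde{A}, \tilde{B}) = \operatorname{tr}(\Sigma) - 2\operatorname{tr}\left(\tilde{U}_I \tilde{U}_I^T \Sigma\right) + \operatorname{tr}\left(\tilde{U}_I \tilde{U}_I^T \Sigma \tilde{U}_I \tilde{U}_I^T\right). \tag{15}$$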

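Now, looking at the final term on the right-hand side, by cyclic permutation we have $\operatorname{tr}(\tilde{U}_I \tilde{U}_I^T \Sigma \tilde{U}_I \tilde{U}_I^T) = \operatorname{tr}(\tilde{U}_I^T \tilde{U}_I \tilde{U}_I^T \Sigma \tilde{U}_I)$. Both $u_i$ and $u_j$ are orthogonal to every $u_k$ with $k \in I \setminus \{i\}$, so $\tilde{u}_i$ is a unit vector orthogonal to all of them. Hence the columns of $\tilde{U}_I$ are still orthonormal, i.e.

$$\tilde{U}_I^T \tilde{U}_I = \mathrm{Id}. \tag{16}$$

So we have:

$$\operatorname{tr}\left(\tilde{U}_I \tilde{U}_I^T \Sigma \tilde{U}_I \tilde{U}_I^T\right) = \operatorname{tr}\left(\tilde{U}_I^T \Sigma \tilde{U}_I\right).$$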

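We also use cyclic permutation on the second term on the right-hand side of (15), $\operatorname{tr}(\tilde{U}_I \tilde{U}_I^T \Sigma) = \operatorname{tr}(\tilde{U}_I^T \Sigma \tilde{U}_I)$, to ultimately arrive at:

$$L(\tilde{A}, \tilde{B}) = \operatorname{tr}(\Sigma) - \operatorname{tr}\left(\tilde{U}_I^T \Sigma \tilde{U}_I\right).$$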

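So we are interested in computing the diagonal elements of the $m \times m$ matrix $\tilde{U}_I^T \Sigma \tilde{U}_I$. Fix $k \in I$ and write $\tilde{u}_k$ for the corresponding column of $\tilde{U}_I$ (so $\tilde{u}_k = u_k$ for $k \neq i$). The corresponding diagonal entry is given by:

$$\tilde{u}_k^T \Sigma \tilde{u}_k.$$

This can be computed directly from the definition of $\tilde{u}_i$, using $\Sigma u_k = \lambda_k u_k$ and the orthonormality of the $u_k$. The result is that this entry on the diagonal is equal to

$$\tilde{u}_k^T \Sigma \tilde{u}_k = \begin{cases} \lambda_k & \text{if } k \neq i, \\ \dfrac{\lambda_i + \epsilon^2 \lambda_j}{1 + \epsilon^2} & \text{if } k = i. \end{cases}$$

Therefore

$$L(\tilde{A}, \tilde{B}) = \operatorname{tr}(\Sigma) - \sum_{k \in I,\, k \neq i} \lambda_k - \frac{\lambda_i + \epsilon^2 \lambda_j}{1 + \epsilon^2} = L(A,B) - \frac{\epsilon^2 (\lambda_j - \lambda_i)}{1 + \epsilon^2}.$$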

Since $\lambda_j > \lambda_i$ and $(\tilde{A}, \tilde{B}) \to (A, B)$ as $\epsilon \to 0$, this shows that in an arbitrarily small neighbourhood of the critical point we can find a point where a smaller loss is achieved. We will not bother doing so here, but one can also check that $(A,B)$ is not a local maximum. This uses the fact that for fixed (full-rank) $B$, the function $A \mapsto L(A,B)$ is convex.
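
The perturbation argument in the proof of Proposition 3 can be replayed numerically. This sketch uses hypothetical data and 0-based indices: `I = [0, 3]` plays the role of a non-optimal index set, and `i, j = 3, 1`. The loss at the perturbed point should drop by exactly $\epsilon^2(\lambda_j - \lambda_i)/(1 + \epsilon^2)$ relative to the critical point.

```python
import numpy as np

def loss(A, B, Sigma):
    M = np.eye(Sigma.shape[0]) - A @ B
    return float(np.trace(M @ Sigma @ M.T))

rng = np.random.default_rng(4)
n, m, N = 5, 2, 100
X = rng.normal(size=(N, n)) * np.arange(1, n + 1)   # hypothetical data
Sigma = X.T @ X
lam, U = np.linalg.eigh(Sigma)
lam, U = lam[::-1], U[:, ::-1]                      # lam[0] > lam[1] > ... (0-based)

I = [0, 3]                 # non-optimal index set: index 1 is missing
i, j = 3, 1                # i in I, j not in I, and lam[j] > lam[i]
C = np.array([[2.0, 0.5], [-0.3, 1.5]])             # arbitrary invertible C
U_I = U[:, I]
A, B = U_I @ C, np.linalg.inv(C) @ U_I.T            # a non-optimal critical point

records = []
for eps in [1e-1, 1e-2, 1e-3]:
    u_tilde = (U[:, i] + eps * U[:, j]) / np.sqrt(1 + eps ** 2)
    U_I_tilde = U_I.copy()
    U_I_tilde[:, I.index(i)] = u_tilde              # replace the column u_i
    A_t, B_t = U_I_tilde @ C, np.linalg.inv(C) @ U_I_tilde.T
    actual_drop = loss(A, B, Sigma) - loss(A_t, B_t, Sigma)
    predicted_drop = eps ** 2 * (lam[j] - lam[i]) / (1 + eps ** 2)
    records.append((eps, actual_drop, predicted_drop))

for eps, actual, predicted in records:
    print(f"eps={eps:g}: loss drops by {actual:.6g} (predicted {predicted:.6g})")
```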

2. Sparse Data, Weight Tying, and Gradients

Abstractly analyzing critical points is not at all the same as training real models. In this section we start to think about data and the optimization process.

2.1 Sparse Synthetic Data

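Here we describe the kind of training data used in Anthropic’s toy experiments.

Fix a number $S \in [0, 1)$. This parameter is the sparsity of the data. We will typically be most interested in the case where $S$ is close to 1.

Let $\{\xi_i^{(t)}\}$ be an independent and identically distributed family of Bernoulli random variables with parameter $1 - S$ (so $\xi_i^{(t)} = 1$ with probability $1 - S$). And let $\{V_i^{(t)}\}$ be an IID family of $\mathrm{Uniform}[0,1]$ random variables, independent of the $\xi_i^{(t)}$. Write $x_i^{(t)} = \xi_i^{(t)} V_i^{(t)}$ and $x^{(t)} = (x_1^{(t)}, \dots, x_n^{(t)})$. Our datasets will be drawn from the IID family $\{x^{(t)}\}_t$. Notice that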

  • Independently, for each data point and for each $i$, we will have $x_i^{(t)} = 0$ with probability $S$.

  • So, the expected number of non-zero entries for each data point is $n(1 - S)$. To bring this in line with the way people say things like “$k$-sparse”, we can say that the data is, on average, $n(1 - S)$-sparse.

  • Every data point $x^{(t)}$ lies in the cube $[0,1]^n$.
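
Here is a minimal sketch of a sampler for this synthetic data. It assumes NumPy’s `Generator.random` for both the Bernoulli indicators and the uniform magnitudes; that method draws from the half-open interval $[0,1)$, which makes no difference here. The function name `sample_sparse` is hypothetical. The checks confirm that roughly a fraction $S$ of the entries are zero, and that each data point has about $n(1-S)$ non-zero entries on average.

```python
import numpy as np

def sample_sparse(num_points, n, S, rng):
    """Hypothetical sampler: x_i = xi_i * V_i with xi_i ~ Bernoulli(1 - S)
    and V_i ~ Uniform[0, 1), all independent."""
    xi = rng.random((num_points, n)) < 1 - S   # True (feature present) w.p. 1 - S
    V = rng.random((num_points, n))            # magnitudes, uniform on [0, 1)
    return xi * V

rng = np.random.default_rng(5)
n, S = 20, 0.9
X = sample_sparse(100_000, n, S, rng)

frac_zero = float(np.mean(X == 0))                           # close to S
mean_nonzeros = float(np.mean(np.count_nonzero(X, axis=1)))  # close to n(1 - S)
print(frac_zero, mean_nonzeros)
```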

Remark. Judging from some of the existing literature on the linear model that we analyze in Section 1 (e.g. “Exact solutions to the nonlinear dynamics of learning in deep linear neural networks.” by Andrew M. Saxe, James L. McClelland and Surya Ganguli), it seems tempting to make the assumption/simplification/approximation that $\Sigma = \mathrm{Id}$ (or at least $\Sigma \propto \mathrm{Id}$). I still don’t feel like I understand how justifiable that is; for me this question is a potential ‘jumping-off’ point for further analysis of the whole problem. Recall that the matrix $\Sigma$ is equal to $\sum_t x^{(t)} (x^{(t)})^T$. Certainly the probability that an off-diagonal entry of a single term $x^{(t)}(x^{(t)})^T$ is equal to zero is $1 - (1-S)^2$, whereas for the diagonal entries it is just $S$. And note that $\mathbb{E}[x_i x_j] = (1-S)^2/4$ if $i \neq j$ and $\mathbb{E}[x_i^2] = (1-S)/3$. But the diagonal entries are still independent random variables, and I’m not sure why thinking of them as equal makes sense.
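
To see the structure discussed in this Remark, here is a sketch with assumed values $n = 10$ and $S = 0.9$ and a hypothetical sample size. It estimates $\mathbb{E}[xx^T]$, i.e. $\Sigma$ normalized by the number of data points, for the sparse data of Section 2.1. The off-diagonal entries cluster around $(1-S)^2/4$, which is small compared with the diagonal value $(1-S)/3$ when $S$ is close to 1. The diagonal entries of the estimate still fluctuate, however, so $\Sigma$ is only approximately proportional to the identity.

```python
import numpy as np

rng = np.random.default_rng(6)
n, S, N = 10, 0.9, 100_000
present = rng.random((N, n)) < 1 - S          # Bernoulli(1 - S) indicators
X = present * rng.random((N, n))              # sparse data as described in Section 2.1

Sigma_hat = X.T @ X / N                       # empirical estimate of E[x x^T]
diag = np.diag(Sigma_hat)
off = Sigma_hat[~np.eye(n, dtype=bool)]       # all off-diagonal entries

expected_diag = (1 - S) / 3                   # E[x_i^2]
expected_off = (1 - S) ** 2 / 4               # E[x_i x_j] for i != j
print(diag.min(), diag.max(), expected_diag)
print(off.mean(), expected_off)
```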

The data (and the loss) are meant to model two main ideas. Firstly, the coordinate directions of the input space act as a natural set of features for the data. Secondly, when $S$ is close to 1, the sparsity of the data is supposed to capture the fact that features really do often tend to be sparse in real-world data. For any given object, or any given word/idea that appears in a language, most images don’t contain that object and most sentences don’t contain that word or idea.

2.2 Weight Tying and The Gradient Flow

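In practice, when we train an autoencoder like this, we do so with weight tying. Roughly speaking, this means that we only consider the case where $A = B^T$. Proposition 1 does indeed allow for a global minimum in which $A = B^T$. This is achieved by essentially taking $I = \{1, \dots, m\}$ and $C = \mathrm{Id}$ (or indeed any orthogonal $C$) in the equations $A = U_I C$, $B = C^{-1} U_I^T$ at the end of Section 1.1, i.e. we have:

$$A = U_I, \qquad B = U_I^T. \tag{17}$$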

But note that we don’t actually want to try to repeat the analysis of Section 1 on a loss of the form $L(B^T, B) = \sum_{x \in X} \|x - B^T B x\|^2$. This would be a quartic polynomial function of the entries of $B$, whereas $L(A,B)$ is only quadratic in each of $A$ and $B$ separately. So it is genuinely a different and potentially more complicated functional. The way that weight tying is done in practice is closer to insisting, during training, that updates are made that preserve the equality $A = B^T$.

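Equations (1) and (2) in subsection 1.1 are obtained as a direct result of differentiating the loss with respect to individual entries of the matrices (or individual ‘weights’ if we interpret this model as a feedforward neural network without activations). Our computations show that:

$$\frac{\partial L}{\partial A} = -2(\mathrm{Id} - AB)\Sigma B^T, \tag{18}$$

$$\frac{\partial L}{\partial B} = -2A^T(\mathrm{Id} - AB)\Sigma. \tag{19}$$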

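In an appropriate continuous-time limit, if we set the learning rate to 1, the weights evolve during training according to the differential equations:

$$\frac{dA}{dt} = 2(\mathrm{Id} - AB)\Sigma B^T, \tag{20}$$

$$\frac{dB}{dt} = 2A^T(\mathrm{Id} - AB)\Sigma. \tag{21}$$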

Remarks. Notice that there is a certain deliberate sloppiness here. One doesn’t really have a fixed matrix $\Sigma$ and then run this gradient flow for all time; the matrix $\Sigma$ is a function of (a batch of) training data. So we need to be careful about any further manipulations or interpretations of these equations.
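
As an illustration of these dynamics, the following sketch integrates the untied gradient flow with forward Euler steps. It deliberately ignores the caveat just mentioned by holding a single hypothetical $\Sigma$ fixed for all time; that $\Sigma$ is built from assumed eigenvalues and a random rotation. The step size and initialization scale are also assumptions. With a small step, the loss should decrease monotonically, and it can never go below the bound $\sum_{i > m} \lambda_i$ from Corollary 2.

```python
import numpy as np

def loss(A, B, Sigma):
    M = np.eye(Sigma.shape[0]) - A @ B
    return float(np.trace(M @ Sigma @ M.T))

rng = np.random.default_rng(7)
n, m = 6, 2
lam = np.array([3.0, 2.5, 2.0, 1.5, 1.0, 0.5])      # hypothetical eigenvalues
Q, _ = np.linalg.qr(rng.normal(size=(n, n)))        # random orthogonal matrix
Sigma = Q @ np.diag(lam) @ Q.T                      # held fixed for all time

A = 0.1 * rng.normal(size=(n, m))                   # small random initialization
B = 0.1 * rng.normal(size=(m, n))
dt, steps = 0.005, 8000
history = [loss(A, B, Sigma)]
for _ in range(steps):
    M = np.eye(n) - A @ B
    dA = 2 * M @ Sigma @ B.T                        # dA/dt = 2(Id - AB) Sigma B^T
    dB = 2 * A.T @ M @ Sigma                        # dB/dt = 2 A^T (Id - AB) Sigma
    A, B = A + dt * dA, B + dt * dB                 # simultaneous forward Euler step
    history.append(loss(A, B, Sigma))

print(history[0], history[-1], lam[m:].sum())
```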

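Those caveats having been noted, if we additionally add the weight-tying constraint $A = B^T$ to the equation for $B$, we get:

$$\frac{dB}{dt} = 2B(\mathrm{Id} - B^T B)\Sigma. \tag{22}$$

Substituting $A = B^T$ into the equation for $A$ instead gives $\frac{dB}{dt} = 2B\Sigma(\mathrm{Id} - B^T B)$. This agrees with (22) when, for example, $\Sigma$ commutes with $B^T B$. We can even make the substitution $W = B$, so that $A = W^T$, to introduce the form:

$$\frac{dW}{dt} = 2W(\mathrm{Id} - W^T W)\Sigma. \tag{23}$$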

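In components (and without summation convention) the equation reads

$$\frac{dW_{ki}}{dt} = 2\sum_{j=1}^n \sum_{p=1}^n W_{kj}\Big(\delta_{jp} - \sum_{l=1}^m W_{lj} W_{lp}\Big)\Sigma_{pi}. \tag{24}$$

Let $\{w_i\}_{i=1}^n \subset \mathbb{R}^m$ denote the set of columns of $W$, so that (24) becomes:

$$\frac{dw_i}{dt} = 2\sum_{j=1}^n \sum_{p=1}^n \big(\delta_{jp} - w_j \cdot w_p\big)\Sigma_{pi}\, w_j. \tag{25}$$

Expanding the brackets and executing the sum over $j$ in the first term gives:

$$\frac{dw_i}{dt} = 2\sum_{p=1}^n \Sigma_{pi}\, w_p - 2\sum_{j=1}^n \sum_{p=1}^n \Sigma_{pi}\,(w_j \cdot w_p)\, w_j. \tag{26}$$

Now write $v_i := \sum_{p=1}^n \Sigma_{pi} w_p$ for the $i$-th column of $W\Sigma$. The sum over $p$ can then be absorbed into $v_i$ in both terms to give:

$$\frac{dw_i}{dt} = 2 v_i - 2\sum_{j=1}^n (w_j \cdot v_i)\, w_j.$$

Finally, just peel off the $j = i$ term from the remaining summation to arrive at the equation

$$\frac{dw_i}{dt} = 2\big(v_i - (w_i \cdot v_i)\, w_i\big) - 2\sum_{j \neq i} (w_j \cdot v_i)\, w_j. \tag{27}$$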
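
The column-by-column equation can be checked against the matrix form $2W(\mathrm{Id} - W^TW)\Sigma$. In this sketch, `matrix_rhs` and `column_rhs` are hypothetical helper names, and $W$ and $\Sigma$ are random stand-ins. The vector `v_i` is computed as the $i$-th column of $W\Sigma$.

```python
import numpy as np

def matrix_rhs(W, Sigma):
    """Right-hand side of dW/dt = 2 W (Id - W^T W) Sigma."""
    return 2 * W @ (np.eye(W.shape[1]) - W.T @ W) @ Sigma

def column_rhs(W, Sigma):
    """Same right-hand side, assembled one column w_i at a time."""
    m, n = W.shape
    out = np.zeros_like(W)
    for i in range(n):
        w_i = W[:, i]
        v_i = W @ Sigma[:, i]                    # v_i = sum_p Sigma_pi w_p
        interference = np.zeros(m)
        for j in range(n):
            if j != i:
                interference += (W[:, j] @ v_i) * W[:, j]
        out[:, i] = 2 * (v_i - (w_i @ v_i) * w_i) - 2 * interference
    return out

rng = np.random.default_rng(8)
m, n = 3, 7
W = rng.normal(size=(m, n))                      # hypothetical tied weights
X = rng.normal(size=(50, n))
Sigma = X.T @ X                                  # hypothetical symmetric Sigma
print(np.allclose(matrix_rhs(W, Sigma), column_rhs(W, Sigma)))  # True
```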

Remarks. (cf. the previous two Remarks) If we assume that $\Sigma = \mathrm{Id}$, then $v_i = w_i$ and the equation above becomes

$$\frac{dw_i}{dt} = 2\big(1 - \|w_i\|^2\big) w_i - 2\sum_{j \neq i} (w_i \cdot w_j)\, w_j,$$

which arises as gradient descent on the energy functional

$$E(W) = \frac{1}{2}\sum_{i=1}^n \big(1 - \|w_i\|^2\big)^2 + \frac{1}{2}\sum_{i \neq j} (w_i \cdot w_j)^2. \tag{28}$$

It’s plausible that a reasonable line of argument to justify this is as follows: since no particular directions in the data are special, over time and on average the effects of the different eigenvalues of $\Sigma$ just somehow ‘average out’. But I don’t endorse or understand how that argument would actually go. Regardless, if we just assume this for now, then, as is explained in the Anthropic paper, we can think of the two terms in (28) as being in competition. The first term suggests that the model ‘wants’ to learn feature $i$ by arranging $\|w_i\|^2 = 1$. However, as it tries to do so, it incurs a penalty, given by $\sum_{j \neq i} (w_i \cdot w_j)^2$. This penalty can reasonably be interpreted as the extent to which the hidden representation of that feature interferes with the model’s attempts to represent and reconstruct the other features.
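
The energy functional can be checked in the same spirit. This sketch assumes $\Sigma = \mathrm{Id}$ and uses hypothetical helper names. It first confirms by finite differences that the tied flow is minus the gradient of $E$. It then runs forward Euler steps and records $E$ along the way. The energy should be non-increasing and can never drop below $(n - m)/2$, which is half of the optimal loss $n - m$ when $\Sigma = \mathrm{Id}$.

```python
import numpy as np

def energy(W):
    """Energy functional with Sigma = Id; equals 0.5 * ||Id - W^T W||_F^2."""
    G = W.T @ W                       # Gram matrix of the columns w_i
    return 0.5 * float(np.sum((np.eye(G.shape[0]) - G) ** 2))

def tied_flow(W):
    """Right-hand side of the tied flow with Sigma = Id: 2 W (Id - W^T W)."""
    return 2 * W @ (np.eye(W.shape[1]) - W.T @ W)

def numerical_grad(f, W, h=1e-6):
    """Central-difference gradient of a scalar function f at W."""
    g = np.zeros_like(W)
    for idx in np.ndindex(W.shape):
        E = np.zeros_like(W)
        E[idx] = h
        g[idx] = (f(W + E) - f(W - E)) / (2 * h)
    return g

rng = np.random.default_rng(9)
m, n = 2, 5                           # 2 hidden dimensions, 5 features
W = 0.3 * rng.normal(size=(m, n))     # hypothetical small initialization

grad_matches = np.allclose(tied_flow(W), -numerical_grad(energy, W), atol=1e-6)

dt, energies = 0.005, [energy(W)]
for _ in range(4000):
    W = W + dt * tied_flow(W)         # forward Euler step of the tied flow
    energies.append(energy(W))
print(grad_matches, energies[0], energies[-1])
```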

3. The ReLU Output Model

3.1 The Distribution of the Data and the Integral Loss

Perhaps a better way to try to incorporate information about the distribution of the data into the analysis here is to directly let $\mu$ be the distribution (i.e. in the proper measure-theoretic sense) of $x$ on $\mathbb{R}^n$ and to consider

$$L(A,B) = \int_{\mathbb{R}^n} \|x - ABx\|^2 \, d\mu(x).$$

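In the Anthropic paper and in my own work, we are ultimately more interested in a model with biases and ReLUs at the output layer, i.e. in the loss

$$L(W, b) = \int_{\mathbb{R}^n} \big\|x - \mathrm{ReLU}(W^T W x + b)\big\|^2 \, d\mu(x),$$

where $W$ is the $m \times n$ (tied) weight matrix, $b \in \mathbb{R}^n$ is the output bias, and $\mathrm{ReLU}$ is applied coordinatewise.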

Performing an analysis anything like that done in Section 1 seems much harder for this model, but perhaps more progress can be made studying the integral above.

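The synthetic data we described in the previous section is all contained in the cube $[0,1]^n$. In the sparse regime, i.e. with $S$ close to 1, the vast majority of the data is concentrated around the lower-dimensional skeletons of the cube. For $k = 0, 1, \dots, n$, write $Q_k$ for the set of points in the cube with exactly $k$ non-zero entries, i.e.

$$Q_k = \big\{x \in [0,1]^n : \#\{i : x_i \neq 0\} = k\big\}.$$

Then $[0,1]^n$ is the disjoint union

$$[0,1]^n = \bigsqcup_{k=0}^n Q_k,$$

and, for the synthetic data, $\mu(Q_k) = \binom{n}{k}(1-S)^k S^{n-k}$.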

Without a closer analysis of binomial tail bounds, I can’t immediately tell how well-justified it is to, say, ignore $\bigsqcup_{k \geq 2} Q_k$ and focus the analysis just on 1-sparse vectors in the dataset. That is, you might want to say that $\mu\big(\bigsqcup_{k \geq 2} Q_k\big)$ is sufficiently small that this region contributes only negligibly to the integral. Then you can start to work with more manageable expressions. To my mind this is another concrete potential ‘jumping-off’ point if one were to do more investigation. In particular, it points in the direction of the observations made in Toy Models of Superposition that suggest a link between this problem and the ‘Thomson Problem’.
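
To get a feel for the binomial question, here is a small sketch (helper names hypothetical). It computes the probability $\binom{n}{k}(1-S)^k S^{n-k}$ that a data point from Section 2.1 has exactly $k$ non-zero entries, and from that the mass of the region with $k \geq 2$.

```python
from math import comb

def skeleton_mass(n, S, k):
    """Probability that a data point has exactly k non-zero entries."""
    return comb(n, k) * (1 - S) ** k * S ** (n - k)

def mass_beyond_one_sparse(n, S):
    """Probability that a data point has at least two non-zero entries."""
    return 1.0 - skeleton_mass(n, S, 0) - skeleton_mass(n, S, 1)

for n, S in [(20, 0.99), (20, 0.9), (100, 0.99), (1000, 0.99)]:
    print(n, S, round(mass_beyond_one_sparse(n, S), 4))
```

Trying a few values suggests that the neglected mass is governed mostly by the expected number of non-zero entries $n(1-S)$ rather than by $S$ alone. For $n = 20$ and $S = 0.99$ it is under 2%, but for $n = 1000$ at the same $S$, almost all of the data has at least two non-zero entries.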