Performing an SVD on a time-series matrix of gradient updates on an MNIST network produces 92.5 singular values

(work in progress) Colab notebook here.

Summary

I perform a singular value decomposition on a time-series matrix of gradient updates from the training of a basic MNIST network across 2 epochs (roughly 2,000 time-steps), and find about 92.5 relevant singular values, roughly 5% of what would be expected if there was nothing interesting going on. Then I propose various possible interpretations of the results, including the left singular vectors representing capability phase changes, their representing some dependency structure across cognition, and their representing some indication of what could be considered a shard. Then I outline a particular track of experimentation I’d like to do to get a better understanding of what’s going on here, and ask interested capable programmers, people with >0 hours of RL experience, or people with a decent amount of linear algebra experience for help or a collaborative partnership by setting up a meeting with me.

What is SVD

Very intuitive, low technical detail explanation

Imagine you have some physical system you’d like to study, and this physical system can be described accurately by some linear transformation you’ve figured out. What the singular value decomposition does is tell you the correct frame in which to view the physical system’s inputs and outputs, so that the value of each of your reframed outputs is completely determined by the value of only one of your reframed inputs. The SVD gives you a very straightforward representation of a naively quite messy situation.

Less intuitive, medium technical detail

All matrices can be described as a high dimensional rotation, various stretches along different directions, then another high dimensional rotation. The high dimensional rotations correspond to the frame shifts in the previous explanation, and the stretches correspond to the ‘the value of each of your reframed outputs is completely determined by the value of only one of your reframed inputs’ property.

Not intuitive, lots of technical detail

From Wikipedia

The singular value decomposition of an $m \times n$ complex matrix $M$ is a factorization of the form $M = U \Sigma V^*$, where $U$ is an $m \times m$ complex unitary matrix, $\Sigma$ is an $m \times n$ rectangular diagonal matrix with non-negative real numbers on the diagonal, $V$ is an $n \times n$ complex unitary matrix, and $V^*$ is the conjugate transpose of $V$. Such a decomposition always exists for any complex matrix. If $M$ is real, then $U$ and $V$ can be guaranteed to be real orthogonal matrices; in such contexts, the SVD is often denoted $U \Sigma V^T$.

The diagonal entries $\sigma_i = \Sigma_{ii}$ of $\Sigma$ are uniquely determined by $M$ and are known as the singular values of $M$. The number of non-zero singular values is equal to the rank of $M$. The columns of $U$ and the columns of $V$ are called left-singular vectors and right-singular vectors of $M$, respectively. They form two sets of orthonormal bases $u_1, \ldots, u_m$ and $v_1, \ldots, v_n$, and if they are sorted so that the singular values $\sigma_i$ with value zero are all in the highest-numbered columns (or rows), the singular value decomposition can be written as

$$M = \sum_{i=1}^{r} \sigma_i u_i v_i^*$$

where $r \leq \min\{m, n\}$ is the rank of $M$.
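For concreteness, here is a tiny numerical illustration of the decomposition, a minimal sketch in PyTorch (any SVD routine would do the same job):

```python
# A small numerical illustration: reconstruct M from U, Σ, Vᵀ and check that
# the number of nonzero singular values equals the rank.
import torch

M = torch.randn(5, 3) @ torch.randn(3, 8)      # a 5x8 real matrix of rank (at most) 3
U, S, Vh = torch.linalg.svd(M, full_matrices=False)

print(torch.allclose(U @ torch.diag(S) @ Vh, M, atol=1e-5))  # True: M = U Σ Vᵀ
print((S > 1e-4 * S[0]).sum().item())                        # 3 nonzero singular values
print(torch.linalg.matrix_rank(M).item())                    # rank is also 3
```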

What did I do?

At each gradient step (batch) I saved the gradient of the loss with respect to the parameters of the model, flattened these gradients, and stacked them on top of each other to get one long vector with one entry per model parameter. Then I put all these vectors into a big matrix, so that the $i$th row represented the $i$th parameter, and the $j$th column represented the $j$th gradient step.

Doing this for 2 epochs on a basic MNIST network gave me a giant matrix with one row per parameter and one column per gradient step.
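A rough sketch of how this gradient-collection step could look in PyTorch follows; the architecture, batch size, and optimizer here are placeholder assumptions of mine, not necessarily what the actual notebook uses.

```python
# Sketch (not the author's exact notebook): collect per-batch gradients into a
# parameters-by-time-steps matrix during ordinary MNIST training.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

model = nn.Sequential(nn.Flatten(), nn.Linear(784, 128), nn.ReLU(), nn.Linear(128, 10))
opt = torch.optim.SGD(model.parameters(), lr=1e-2)
loader = DataLoader(
    datasets.MNIST("data", train=True, download=True, transform=transforms.ToTensor()),
    batch_size=64, shuffle=True)

grad_columns = []  # one flattened gradient vector per gradient step
for epoch in range(2):
    for x, y in loader:
        opt.zero_grad()
        loss = F.cross_entropy(model(x), y)
        loss.backward()
        # flatten every parameter's gradient and concatenate into one long vector
        grad_columns.append(torch.cat([p.grad.reshape(-1) for p in model.parameters()]).clone())
        opt.step()

G = torch.stack(grad_columns, dim=1)  # rows = parameters, columns = gradient steps
```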

I tried doing an SVD on this matrix, but it crashed my computer, so instead I randomly sampled a subset of time-steps to do the SVD on (discarding the time-steps not sampled), and verified that the residual of the resulting singular vectors with respect to a bunch of the other time-step gradients was low. I found that randomly sampling about 300 time-steps gives a pretty low residual.
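Concretely, the subsampling-plus-residual check could look something like the sketch below; the residual definition (how much of each held-out gradient column lies outside the span of the computed left singular vectors) is my guess at a reasonable choice, not necessarily the exact one used in the notebook.

```python
# Subsample time-steps, do the SVD on the subsample, then check how well the
# resulting left singular vectors explain the held-out gradient columns.
# G is the hypothetical parameters-by-time-steps matrix from the previous sketch.
import torch

num_samples = 300
perm = torch.randperm(G.shape[1])
sampled, held_out = perm[:num_samples], perm[num_samples:]

U, S, Vh = torch.linalg.svd(G[:, sampled], full_matrices=False)

H = G[:, held_out]          # held-out gradient columns
proj = U @ (U.T @ H)        # projection onto the span of the left singular vectors
residual = torch.linalg.norm(H - proj) / torch.linalg.norm(H)
print(f"relative residual on held-out time-steps: {residual:.3f}")
```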

What we’d expect if nothing went on

If nothing interesting was going on, we would expect about 2,000 nonzero singular values (since the rank of the full matrix is at most the number of time-steps, which is about 2,000), or at least some large fraction of 2,000 nonzero singular values.

If some stuff was going on, we would expect a lowish fraction (like 20-60%) of 2,000 nonzero singular values. This would be an interesting result, but not particularly useful, since 20-60% of 2,000 is still a lot, and this probably means lots of what we care about isn’t being captured by the SVD reframing.

If very interesting stuff was going on, we would expect some very small fraction of 2,000 singular values. Like, <10%.

What we actually see

This is the graph of normalized singular values I found

And here’s what we have for 400 randomly sampled indices, just so you know it doesn’t really change if you add more.

The x-axis is the singular value’s index, and the y-axis is the singular value’s… well… value.

A common rule of thumb is that all singular values less than 10% of the largest singular value are irrelevant, so what this graph tells us is that we have something like 70-150 relevant singular values[1], about 3.5-7.5% of the roughly 2,000 we would expect if nothing interesting was going on.

I ran this procedure (minus the graphs) 10 times, with 400 samples each, and took the average of the number of normalized singular values greater than 0.1, resulting in an average of 92.5 singular values, with a standard deviation of 11.6. In terms of percentages of the full dimension, this is roughly 5%.
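As a sketch, that counting procedure might look like the following (again reusing the hypothetical gradient matrix G from the earlier sketch):

```python
# Normalize by the largest singular value, count how many exceed 0.1, and
# repeat over 10 random subsamples of 400 time-steps.
import torch

counts = []
for _ in range(10):
    cols = torch.randperm(G.shape[1])[:400]
    S = torch.linalg.svdvals(G[:, cols])   # returned in descending order
    counts.append(((S / S[0]) > 0.1).sum().item())

counts = torch.tensor(counts, dtype=torch.float)
print(counts.mean().item(), counts.std().item())  # the post reports ~92.5 ± 11.6
```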

Theories about what’s going on

The following theories are not necessarily mutually exclusive.

  • Left singular vectors represent capability phase changes

  • Left singular vectors represent capability topologies/dependency structures

  • Left singular vectors represent parameters associated with shard-like components

  • Weird linear algebra stuff that gives little leverage to interpretability or alignment theory

  • We’re detecting some form of nontrivial natural abstraction present during training (The trivial version of this is obvious)

Experiments here I’d like to run

The main thing I want to do now is replicate the results from a particular paper whose name I can’t remember right now, in which an RL agent was trained to navigate to a cheese in the top right corner of a maze. I’d then apply this method to the training gradients, and see whether we can locate which parameters are responsible for the if bottom_left(), then navigate_to_top_right() cognition, and which are responsible for the if top_right(), then navigate_to_cheese() cognition, which should be determinable by their time-step distribution.

[EDIT: The name of the paper turned out to be Goal Misgeneralization by Langosco et al., and I was slightly wrong about what it concluded. It found that the RL agent learned to go to the top right corner, and also, if there was a cheese near it, to go to that cheese. Slightly different from what I had remembered, but the experiments described seem simple to adapt to this new situation.]

That is, if bottom_left(), then navigate_to_top_right() should be associated with reinforcement events earlier in training rather than later, so the left singular vectors locating the parameters responsible for that computation should have corresponding right singular vectors whose entries are large in magnitude early in training and small in magnitude later on. Similarly, if top_right(), then navigate_to_cheese() should be associated with reinforcement events later in training, so the opposite should hold.
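One hypothetical way to operationalize ‘determinable by their time-step distribution’: for each singular component, measure how much of its right singular vector’s energy falls in the first half of training versus the second. The sketch below assumes the same kind of parameters-by-time-steps gradient matrix G has been built for the RL training run.

```python
# Score each singular component by how much of its right singular vector's
# energy lies in the first half of training (a hypothetical heuristic, not a
# method from the post or the paper).
import torch

cols = torch.sort(torch.randperm(G.shape[1])[:400]).values    # subsample but keep time order
U, S, Vh = torch.linalg.svd(G[:, cols], full_matrices=False)  # rows of Vh index components

energy = Vh ** 2                     # per-component energy across time-steps
half = Vh.shape[1] // 2
early_mass = energy[:, :half].sum(dim=1) / energy.sum(dim=1)

# early_mass near 1 suggests a component reinforced early in training,
# near 0 suggests one reinforced late
for i in range(10):
    print(f"component {i}: fraction of energy in first half of training = {early_mass[i]:.2f}")
```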

Then I want to verify that we have indeed found the right parameters by ablating the model’s tendency to go to the cheese after it has reached the top right corner.

It would also be interesting to see whether we can ablate its ability to go to the top right corner while keeping its ability to go to the cheese if the cheese is sufficiently close or it is already in the top right corner. However, this seems harder, and not as clearly possible even given that we’ve found the correct parameters.

I’d then want to make the environment richer in a whole bunch of ways-I-have-not-yet-determined, like adding more subgoals, such as also finding stars. If we make it find stars and cheeses, can we ablate its ability to find stars? Perhaps we can’t, because it thinks of stars and cheese as being basically the same thing, so we can’t distinguish between the two. In that case, what do we have to do to the training dynamics to make sure we are able to ablate its ability to find stars but keep its ability to find cheese?

Call for people to help

I am slow at programming and inexperienced at RL. If you are fast at programming, have >0 hours of RL experience, or are good with linear algebra[2], and want to help do some very interesting alignment research, schedule a meeting with me! We can talk and see if we’re a good fit to work together on this. (You can also just comment underneath this post with your ideas!)

  1. ^

    This is such a wide range because the graph is fairly flat around the 0.1 mark. If I had instead decided on a cutoff of 0.15, I would have gotten something like 40 singular values, and if I’d decided on 0.05, I would have gotten 200.

  2. ^

    You can never have too much linear algebra!