Multi-Component Learning and S-Curves

(Thanks to Oliver Balfour, Ben Toner, and various MLAB participants for early investigations into S-curves. Thanks to Nate Thomas and Evan Hubinger for helpful comments.)

Introduction

Some machine learning tasks depend on just one component in a model. By this we mean that there is a single parameter or vector inside a model which determines the model’s performance on a task. An example of this is learning a scalar using gradient descent, which we might model with the loss function

$$L = (x - a)^2$$

Here $a$ is the target scalar and $x$ is our model of that scalar. Because the loss gradients are linear, gradient descent converges exponentially quickly, as we see below:
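A minimal sketch of the scalar case, assuming plain gradient descent with a fixed learning rate `lr` (our choice, not a setting from the original experiments). Because the gradient $2(x - a)$ is linear in $x$, every step multiplies the error by the same factor, which is exactly exponential convergence:

```python
import numpy as np

def scalar_gd(a, x0, lr, steps):
    """Plain gradient descent on L(x) = (x - a)**2; returns the loss before each step."""
    x = x0
    losses = []
    for _ in range(steps):
        losses.append((x - a) ** 2)
        x = x - lr * 2 * (x - a)   # dL/dx = 2 (x - a)
    return np.array(losses)

losses = scalar_gd(a=3.0, x0=-1.0, lr=0.1, steps=50)
# The error x - a is multiplied by (1 - 2*lr) = 0.8 every step, so the loss
# shrinks by the constant factor 0.8**2 = 0.64: a straight line on a log plot.
ratios = losses[1:] / losses[:-1]
print(ratios.min(), ratios.max())
```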

The same holds for learning a vector using gradient descent with the loss

$$L = \sum_i (x_i - a_i)^2$$

because the loss is a sum of several terms, each of which depends on only a single component.
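The vector case separates in the same way. In this sketch (random target, fixed learning rate, both chosen only for illustration), each coordinate’s error evolves on its own and is scaled by the same factor every step, just as in the scalar case:

```python
import numpy as np

rng = np.random.default_rng(0)
a = rng.normal(size=5)              # target vector (random, for illustration)
x = rng.normal(size=5)              # model vector
lr = 0.1
err0 = x - a
for _ in range(20):
    x = x - lr * 2 * (x - a)        # gradient of sum_i (x_i - a_i)**2 is 2 (x - a)
# Coordinate i only ever sees its own error, and every error is scaled by the
# same factor (1 - 2*lr) per step: five independent copies of the scalar problem.
print(np.allclose(x - a, err0 * (1 - 2 * lr) ** 20))
```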

By contrast, some tasks depend on multiple components simultaneously. That is, the model will only perform well if multiple different parameters are near their optimal values simultaneously. Attention heads are a good example here: they only perform well if the key, query, value, and output matrices are all close to some target. If any of these is far off, the whole structure fails to perform.

We think that such multi-component tasks are behind at least some occurrences of the commonly-seen S-curve motif, where the loss plateaus for an extended period before suddenly dropping. Below we provide toy examples of multi-component tasks, reason through why their losses exhibit S-curves, and explore how the structures of these S-curves vary with the ranks of the components involved. We additionally provide evidence for our explanations from numerical experiments.

Examples

Rank-1 learning

Suppose we’re using gradient descent on the regression task:

$$L = \sum_{ij}\left(a_i b_j - x_i y_j\right)^2$$

Here there is a target matrix $a_i b_j$ which has rank 1, and which we model by an outer product of two vectors $x$ and $y$.

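The loss gradients are:

$$\frac{\partial L}{\partial x_i} = -2\sum_j y_j\left(a_i b_j - x_i y_j\right), \qquad \frac{\partial L}{\partial y_j} = -2\sum_i x_i\left(a_i b_j - x_i y_j\right)$$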

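If we write

$$\frac{dx_i}{dt} = -\frac{\partial L}{\partial x_i}$$

and likewise for $y$, we get a non-linear ordinary differential equation (ODE) of the form

$$\frac{dx_i}{dt} = 2a_i\,(b\cdot y) - 2x_i\,|y|^2, \qquad \frac{dy_j}{dt} = 2b_j\,(a\cdot x) - 2y_j\,|x|^2$$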
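As a sanity check on the rank-1 ODE, the sketch below compares its right-hand side against central finite differences of the loss. The helper names `loss`, `flow`, and `numerical_grad` are ours, introduced only for illustration:

```python
import numpy as np

def loss(x, y, a, b):
    """Rank-1 regression loss L = sum_ij (a_i b_j - x_i y_j)**2."""
    return np.sum((np.outer(a, b) - np.outer(x, y)) ** 2)

def flow(x, y, a, b):
    """Right-hand side of the ODE: (dx/dt, dy/dt) = (-dL/dx, -dL/dy)."""
    dx = 2 * a * (b @ y) - 2 * x * (y @ y)
    dy = 2 * b * (a @ x) - 2 * y * (x @ x)
    return dx, dy

def numerical_grad(f, v, eps=1e-6):
    """Central finite-difference gradient of a scalar function f at vector v."""
    g = np.zeros_like(v)
    for i in range(v.size):
        e = np.zeros_like(v)
        e[i] = eps
        g[i] = (f(v + e) - f(v - e)) / (2 * eps)
    return g

rng = np.random.default_rng(1)
a, b, x, y = (rng.normal(size=8) for _ in range(4))
dx, dy = flow(x, y, a, b)
gx = numerical_grad(lambda v: loss(v, y, a, b), x)
gy = numerical_grad(lambda v: loss(x, v, a, b), y)
print(np.allclose(dx, -gx, atol=1e-5), np.allclose(dy, -gy, atol=1e-5))
```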

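If the vectors are high-dimensional and we choose a random initialization, then we approximately have $a\cdot x \approx 0$ and $b\cdot y \approx 0$ at early times. This means the above equations are, to first order,

$$\frac{dx_i}{dt} \approx -2x_i\,|y|^2, \qquad \frac{dy_j}{dt} \approx -2y_j\,|x|^2$$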

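The equation for $x$ is linear in $x$, and likewise for $y$, so the initial solution decays (approximately) exponentially, which decreases the loss. At the same time, the next-order correction contains a term

$$\frac{dx_i}{dt} \supset 2a_i\,(b\cdot y)$$

and correspondingly $2b_j\,(a\cdot x)$ in the equation for $y$.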

This causes the model to learn the ground truth. The true component grows linearly at first (during the exponential decay of the initialization). Then this component becomes the dominant term, so that $x$ is nearly parallel to $a$ and $y$ to $b$, and we get exponential growth of the true solution. Eventually $x$ and $y$ approach $a$ and $b$ in magnitude and the growth levels off (becomes logistic).
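The early-time approximation that target and model vectors are nearly orthogonal rests on a standard fact about random high-dimensional vectors. A quick Monte Carlo sketch, assuming Gaussian vectors as our stand-in for “random initialization”, shows the typical |cosine| shrinking roughly like $\sqrt{2/(\pi n)}$:

```python
import numpy as np

def mean_abs_cosine(n, trials=2000, seed=0):
    """Average |cos angle| between pairs of independent Gaussian vectors in R^n."""
    rng = np.random.default_rng(seed)
    u = rng.normal(size=(trials, n))
    v = rng.normal(size=(trials, n))
    cos = np.sum(u * v, axis=1) / (np.linalg.norm(u, axis=1) * np.linalg.norm(v, axis=1))
    return np.mean(np.abs(cos))

for n in (10, 100, 1000):
    # The overlap shrinks with dimension, so inner products with the target start small.
    print(n, mean_abs_cosine(n))
```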

To summarize:

  1. At early times the loss decays exponentially because the initial guess decays away.

  2. For a while the loss plateaus because the growing (correct) component is still small.

  3. Once the correct solution is significant the loss rapidly falls.

  4. As the correct solution comes to dominate, learning slows and we see a leveling off.

Thus, we get an exponential decay followed by a sigmoid.

We again confirm this experimentally:

Initially the vectors just decay, giving an exponential loss improvement. Then the growing part takes over, bringing them into alignment with the ground truth and raising their norms, resulting in a second exponential and hence an S-curve.
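A minimal simulation of the rank-1 problem, assuming unit-norm random targets, Gaussian initialization with $|x|, |y| \approx 1$, and a fixed learning rate (all our choices; these are not the settings behind the plots above). The loss is evaluated through the expansion $\|ab^T - xy^T\|_F^2 = |a|^2|b|^2 - 2(a\cdot x)(b\cdot y) + |x|^2|y|^2$, so the $n\times n$ matrix is never formed:

```python
import numpy as np

def train_rank1(n=1000, lr=0.01, steps=3000, seed=0):
    """Gradient descent on L = ||a b^T - x y^T||_F^2.

    Returns the loss before each step and the final |cos| between x and a.
    """
    rng = np.random.default_rng(seed)
    a = rng.normal(size=n)
    a /= np.linalg.norm(a)                      # unit-norm targets (assumption)
    b = rng.normal(size=n)
    b /= np.linalg.norm(b)
    x = rng.normal(size=n) / np.sqrt(n)         # random init with |x|, |y| close to 1
    y = rng.normal(size=n) / np.sqrt(n)
    losses = []
    for _ in range(steps):
        ax, by = a @ x, b @ y
        # ||a b^T - x y^T||_F^2 expanded, so the n-by-n matrix is never formed
        losses.append((a @ a) * (b @ b) - 2 * ax * by + (x @ x) * (y @ y))
        dx = 2 * a * by - 2 * x * (y @ y)       # dx/dt from the rank-1 ODE
        dy = 2 * b * ax - 2 * y * (x @ x)       # dy/dt
        x, y = x + lr * dx, y + lr * dy
    return np.array(losses), abs(a @ x) / np.linalg.norm(x)

losses, cos = train_rank1()
print(losses[::300])
print(cos)
```

Plotting `np.log(losses)` against the step index is one way to look for the decay, plateau, and drop described above; how long the plateau lasts depends on `n` and on the initialization scale.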

Low-rank learning

Similar reasoning applies to the case of learning a low-rank matrix with a low-rank representation of that matrix. Concretely, our task is now:

$$L = \sum_{ij}\left(\sum_k a_{ik}b_{jk} - \sum_k x_{ik}y_{jk}\right)^2$$

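The loss gradients are:

$$\frac{\partial L}{\partial x_{ik}} = -2\sum_j y_{jk}\left(\sum_{k'} a_{ik'}b_{jk'} - \sum_{k'} x_{ik'}y_{jk'}\right), \qquad \frac{\partial L}{\partial y_{jk}} = -2\sum_i x_{ik}\left(\sum_{k'} a_{ik'}b_{jk'} - \sum_{k'} x_{ik'}y_{jk'}\right)$$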

Notice that the resulting ODE is organized by the rank index $k$: the vectors for different values of $k$ interact only through their inner products.

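If we write

$$\frac{dx_{ik}}{dt} = -\frac{\partial L}{\partial x_{ik}}$$

and likewise for $y$, and let $x_k$ denote the vector with components $x_{ik}$ (similarly $y_k$, $a_k$, $b_k$), we get a non-linear ODE of the form

$$\frac{dx_{ik}}{dt} = 2\sum_{k'} a_{ik'}\,(b_{k'}\cdot y_k) - 2\sum_{k'} x_{ik'}\,(y_{k'}\cdot y_k), \qquad \frac{dy_{jk}}{dt} = 2\sum_{k'} b_{jk'}\,(a_{k'}\cdot x_k) - 2\sum_{k'} y_{jk'}\,(x_{k'}\cdot x_k)$$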

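Once more, if the vectors are high-dimensional and we choose a random initialization, then we approximately have $a_{k'}\cdot x_k \approx 0$ and $b_{k'}\cdot y_k \approx 0$ at early times. We also have $y_{k'}\cdot y_k \approx \delta_{kk'}|y_k|^2$, and likewise for $x$. This means the above equations are, to first order,

$$\frac{dx_{ik}}{dt} \approx -2x_{ik}\,|y_k|^2, \qquad \frac{dy_{jk}}{dt} \approx -2y_{jk}\,|x_k|^2$$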

so the initial solution decays (approximately) exponentially, which decreases the loss. At the same time, the next-order correction contains a term

$$\frac{dx_{ik}}{dt} \supset 2\sum_{k'} a_{ik'}\,(b_{k'}\cdot y_k)$$

This causes the model to learn the truth. The true component grows linearly at first (during the exponential decay of the initialization). Then this component becomes the dominant term, and we get exponential growth with a rate of roughly $2|a_{k'}||b_{k'}|$ for a pair $(x_k, y_k)$ aligning with the target pair $(a_{k'}, b_{k'})$.

Eventually $x_k$ and $y_k$ approach $a_{k'}$ and $b_{k'}$ in magnitude and the growth levels off (becomes logistic).
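The exponential growth phase can be checked on a reduced system. Assume each vector already lies along its target, $x = u\,\hat a$ and $y = v\,\hat b$ with unit vectors $\hat a$, $\hat b$; then the rank-1 ODE becomes $du/dt = 2|a||b|\,v - 2uv^2$ and $dv/dt = 2|a||b|\,u - 2vu^2$, whose growth rate for small $u = v$ is $2|a||b|$. For the low-rank case this sketch additionally assumes the target vectors are mutually orthogonal, so the cross terms vanish. The forward-Euler integrator and step size are our choices:

```python
import numpy as np

def growth_rate(norm_a, norm_b, eps=1e-6, dt=1e-4, t_end=1.0):
    """Integrate the reduced rank-1 ODE from u = v = eps and return d(log u)/dt.

    Reduced system (x = u * a_hat, y = v * b_hat):
        du/dt = 2|a||b| v - 2 u v**2,   dv/dt = 2|a||b| u - 2 v u**2
    For small u and v the cubic terms are negligible.
    """
    u = v = eps
    c = norm_a * norm_b
    for _ in range(int(round(t_end / dt))):
        u, v = u + dt * (2 * c * v - 2 * u * v**2), v + dt * (2 * c * u - 2 * v * u**2)
    return np.log(u / eps) / t_end

print(growth_rate(1.0, 1.5), 2 * 1.0 * 1.5)
```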

This results in the same phenomenology as in the rank-1 case, and that’s indeed what we see:
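The same experiment can be sketched in the low-rank case. Here `A` and `B` are assumed to have orthonormal columns (so every nonzero singular value of the target is 1), the factors are Gaussian, and the update is the matrix form of the gradient, $-\partial L/\partial X = 2(AB^T - XY^T)Y$; none of these settings come from the original experiments:

```python
import numpy as np

def train_low_rank(n=100, k=3, lr=0.01, steps=5000, seed=0):
    """Gradient descent on L = ||A B^T - X Y^T||_F^2 with n-by-k factors.

    Assumption for this sketch: A and B have orthonormal columns, so every
    nonzero singular value of the target A B^T equals 1.
    """
    rng = np.random.default_rng(seed)
    A, _ = np.linalg.qr(rng.normal(size=(n, k)))
    B, _ = np.linalg.qr(rng.normal(size=(n, k)))
    X = rng.normal(size=(n, k)) / np.sqrt(n)    # columns of norm close to 1
    Y = rng.normal(size=(n, k)) / np.sqrt(n)
    target = A @ B.T
    losses = []
    for _ in range(steps):
        R = target - X @ Y.T                    # residual matrix
        losses.append(np.sum(R ** 2))
        # -dL/dX = 2 R Y and -dL/dY = 2 R^T X, applied simultaneously
        X, Y = X + lr * 2 * R @ Y, Y + lr * 2 * R.T @ X
    return np.array(losses)

losses = train_low_rank()
print(losses[::500])
```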

High-rank learning

(Edit: Due to a plotting bug the inner product panels in this section were not correctly normalized. I’ve corrected the plots below, though nothing qualitative changes.)

The phenomenology changes as we increase the rank. Working with 100-dimensional vectors, we see that rank-10 matrices have a similar phenomenology but with a more extended plateau:

Rank-100 (full-rank) shows no second phase of exponential decay! It just transitions from exponential decay at the start (as norms decay) into a power law:

(Just showing a few vectors.)

Note that the final vectors are not parallel to the ground truth: this is possible because there are many vectors, so they just need to find directions that allow them to span the same vector space as the ground truth.

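To make sense of this change we start with the same ODE as before, but write it with $x$ and $y$ (and the targets $a$ and $b$) as matrices whose second index runs through the different vectors:

$$\frac{dx}{dt} = 2\,ab^Ty - 2\,xy^Ty, \qquad \frac{dy}{dt} = 2\,ba^Tx - 2\,yx^Tx$$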
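To confirm that the matrix form and the column-by-column form describe the same ODE, the sketch below evaluates both on random inputs. `flow_matrix` and `flow_index` are illustrative names of ours, and the sizes are arbitrary:

```python
import numpy as np

def flow_matrix(X, Y, A, B):
    """Matrix form: dX/dt = 2 A B^T Y - 2 X Y^T Y and dY/dt = 2 B A^T X - 2 Y X^T X."""
    dX = 2 * A @ (B.T @ Y) - 2 * X @ (Y.T @ Y)
    dY = 2 * B @ (A.T @ X) - 2 * Y @ (X.T @ X)
    return dX, dY

def flow_index(X, Y, A, B):
    """Column-by-column form: sums over k' weighted by inner products of columns."""
    k = X.shape[1]
    dX = np.zeros_like(X)
    dY = np.zeros_like(Y)
    for kk in range(k):
        for kp in range(k):
            dX[:, kk] += 2 * A[:, kp] * (B[:, kp] @ Y[:, kk]) - 2 * X[:, kp] * (Y[:, kp] @ Y[:, kk])
            dY[:, kk] += 2 * B[:, kp] * (A[:, kp] @ X[:, kk]) - 2 * Y[:, kp] * (X[:, kp] @ X[:, kk])
    return dX, dY

rng = np.random.default_rng(2)
X, Y, A, B = (rng.normal(size=(6, 3)) for _ in range(4))
dX_m, dY_m = flow_matrix(X, Y, A, B)
dX_i, dY_i = flow_index(X, Y, A, B)
print(np.allclose(dX_m, dX_i), np.allclose(dY_m, dY_i))
```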

Now comes the fun part: $y^Ty$ is a positive semi-definite matrix so long as $y$ is real (which we’ll assume applies to our matrices). That means that the second term in each equation causes decay (which is exponential if the other of $x$, $y$ is held constant). For example, if we ignore the first term in the first equation we have

$$\frac{dx}{dt} = -2\,x\,y^Ty$$

So long as the dual vectors (rows) of $x$ lie in the span of the dual vectors of $y$, we can decompose $x$ into a sum of the right-eigenvectors of $y^Ty$ (the eigenvectors that live in the dual space). The projection onto each eigenvector decays independently, at a rate set by its eigenvalue, and so we find exponential decay (ignoring the evolution of $y$).
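With $y$ frozen, $dx/dt = -2\,x\,y^Ty$ is linear and solvable exactly: diagonalizing $y^Ty = V\Lambda V^T$ gives $x(t) = x(0)\,V e^{-2\Lambda t} V^T$, so each row of $x$ decays independently along each eigenvector (our reading of the dual-vector decomposition). The sketch below, with shapes and step size chosen by us, checks that forward-Euler integration matches this closed form and that the eigenvalues are non-negative:

```python
import numpy as np

rng = np.random.default_rng(3)
n, k = 20, 4
Y = rng.normal(size=(n, k)) / np.sqrt(n)   # held fixed throughout (assumption of this sketch)
X0 = rng.normal(size=(n, k))
G = Y.T @ Y                                 # k-by-k Gram matrix y^T y
lam, V = np.linalg.eigh(G)                  # G = V diag(lam) V^T with orthonormal V

# Forward-Euler integration of dX/dt = -2 X Y^T Y up to t_end.
dt, t_end = 1e-4, 2.0
X = X0.copy()
for _ in range(int(round(t_end / dt))):
    X = X - dt * 2 * X @ G

# Closed form: the projection of each row of X onto eigenvector j of G decays
# independently as exp(-2 * lam_j * t).
X_exact = ((X0 @ V) * np.exp(-2 * lam * t_end)) @ V.T
print(lam.min() >= -1e-12, np.allclose(X, X_exact, atol=1e-3))
```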

If the dual vectors of $x$ don’t lie in the span of the dual vectors of $y$, then there will be components of $x$ which do not decay. As the rank of the target matrix increases, the fraction of $x$ outside of the span of $y$ falls, making the decay more purely exponential. This explains why we see more exponential decay in the rank-100 case than in the rank-10 case.
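When the rows of $y$ do not span the whole dual space, $y^Ty$ has a null direction and the corresponding component of $x$ never decays. This sketch forces that situation (an assumption for illustration) by making the third column of `Y` the sum of the first two, so $z \propto (1, 1, -1)$ satisfies $Yz = 0$:

```python
import numpy as np

rng = np.random.default_rng(4)
n = 20
y1, y2 = rng.normal(size=n), rng.normal(size=n)
Y = np.stack([y1, y2, y1 + y2], axis=1)      # rank 2, so Y^T Y has a null direction
z = np.array([1.0, 1.0, -1.0]) / np.sqrt(3)  # unit vector with Y @ z = 0
G = Y.T @ Y
X = rng.normal(size=(n, 3))
before = X @ z                               # component of each row of X along z

for _ in range(2000):
    X = X - 1e-3 * 2 * X @ G                 # Euler steps of dX/dt = -2 X Y^T Y

# The component along z is untouched; everything else has decayed away.
print(np.allclose(X @ z, before), np.allclose(X, np.outer(before, z), atol=1e-6))
```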

At the same time, we also have a term $2ab^Ty$, which causes a growing component in $x$ proportional to the truth. This component grows until it comes to dominate, at which point the vector norms rebound and, simultaneously, the vectors come into better alignment with the target vectors. Due to the non-uniqueness of matrix decompositions they do not come into as obvious an alignment as before, but we do generally see inner products increase during this phase.

As the rank increases, the initial plateau for each target vector shortens because random initial vectors are closer to aligning with the ground truth. Moreover, because different vectors are learned at different rates, increasing the rank smears the transition out. The net result is that as the rank increases towards full, the S-curve loses its plateau and the exponential tail turns into a power law.
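One way to quantify how much closer random initial vectors start as the rank grows is to measure how much of a random $n\times k$ matrix already lies in the span of $k$ random target directions. Under Gaussian assumptions (ours), the expected fraction is $k/n$, reaching 1 at full rank:

```python
import numpy as np

def captured_fraction(n, k, seed=0):
    """Fraction of a random n-by-k matrix's squared norm lying in the span of
    k random target directions."""
    rng = np.random.default_rng(seed)
    target, _ = np.linalg.qr(rng.normal(size=(n, k)))  # orthonormal basis of the target span
    X = rng.normal(size=(n, k))
    proj = target @ (target.T @ X)                      # projection onto the target span
    return np.sum(proj ** 2) / np.sum(X ** 2)

for k in (1, 10, 50, 100):
    # Expected value is about k / n; at k = n the spans coincide and the fraction is 1.
    print(k, captured_fraction(100, k))
```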

Attention Head

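As one last example, consider learning an attention head of the form:

$$T_{ijkl} = (QK^T)_{ij}\,(OV^T)_{kl} = \sum_m Q_{im}K_{jm}\sum_n O_{kn}V_{ln}$$

where the inner dimensions of $QK^T$ and $OV^T$ (the indices $m$ and $n$) are small, so both matrices are low-rank. To keep the setup simple, we’ll try to learn this using the loss

$$L = \sum_{ijkl}\left(\hat T_{ijkl} - T_{ijkl}\right)^2$$

where $\hat T$ is built in the same way from target matrices $\hat Q$, $\hat K$, $\hat O$, $\hat V$.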

Note that this drops the softmax and is a totally artificial loss function. Nonetheless, it has the property that the loss is only low if all four components of the attention head are close to their target matrices.
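A sketch of this toy loss, assuming the head is the four-index tensor $T_{ijkl} = (QK^T)_{ij}(OV^T)_{kl}$ with hatted matrices as targets (our notation; the sizes `d` and `r` are arbitrary). Because $T$ is an outer product of two matrices, the summed squared error reduces to products of Frobenius norms and inner products, which the check below confirms:

```python
import numpy as np

def head_tensor(Q, K, O, V):
    """T_ijkl = (Q K^T)_ij (O V^T)_kl: the softmax-free head of the toy model."""
    return np.einsum('im,jm,kn,ln->ijkl', Q, K, O, V)

def toy_loss(params, target):
    return np.sum((head_tensor(*target) - head_tensor(*params)) ** 2)

rng = np.random.default_rng(5)
d, r = 6, 2                                            # width and head rank (assumed)
params = [rng.normal(size=(d, r)) for _ in range(4)]   # Q, K, O, V
target = [rng.normal(size=(d, r)) for _ in range(4)]   # Q_hat, K_hat, O_hat, V_hat

Q, K, O, V = params
Qh, Kh, Oh, Vh = target
qk, ov, qk_h, ov_h = Q @ K.T, O @ V.T, Qh @ Kh.T, Oh @ Vh.T
# ||T_hat - T||^2 = ||W_hat_QK||^2 ||W_hat_OV||^2 - 2 <W_hat_QK, W_QK><W_hat_OV, W_OV>
#                   + ||W_QK||^2 ||W_OV||^2
factored = (np.sum(qk_h**2) * np.sum(ov_h**2)
            - 2 * np.sum(qk_h * qk) * np.sum(ov_h * ov)
            + np.sum(qk**2) * np.sum(ov**2))
print(np.isclose(toy_loss(params, target), factored))
```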

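Expanding the square gives three terms, $\sum\hat T^2 - 2\sum\hat T\,T + \sum T^2$. The first term is irrelevant for the learning dynamics, so we drop that. The other terms we expand, finding

$$L = -2\,\mathrm{Tr}(\hat K\hat Q^TQK^T)\,\mathrm{Tr}(\hat V\hat O^TOV^T) + \|QK^T\|_F^2\,\|OV^T\|_F^2 + \text{const}$$

The indexing on this gets tricky, so let’s make that explicit:

$$L = \sum_{ijklmnpq}\left(Q_{im}K_{jm}O_{kn}V_{ln} - 2\hat Q_{im}\hat K_{jm}\hat O_{kn}\hat V_{ln}\right)Q_{ip}K_{jp}O_{kq}V_{lq} + \text{const}$$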

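Taking the gradient with respect to $Q$, we find:

$$\frac{\partial L}{\partial Q_{ip}} = 2\sum_{jklmnq}\left(Q_{im}K_{jm}O_{kn}V_{ln} - \hat Q_{im}\hat K_{jm}\hat O_{kn}\hat V_{ln}\right)K_{jp}O_{kq}V_{lq}$$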

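The gradients with respect to $K$, $O$, and $V$ are structurally the same, so we focus on just this one. Cleaning it up a bit, and writing $dQ/dt = -\partial L/\partial Q$, we find

$$\frac{dQ}{dt} = -2\,\|OV^T\|_F^2\;QK^TK + 2\,\hat Q\hat K^TK\;\mathrm{Tr}(\hat V\hat O^TOV^T)$$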
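The gradient with respect to $Q$ can be checked numerically. This sketch compares the analytic $dQ/dt = -\partial L/\partial Q$ against central finite differences, and also checks the identity $\mathrm{Tr}(Q^TQK^TK) = \|QK^T\|_F^2$ that makes the early-time norm change non-positive. The function names, the head tensor $T_{ijkl} = (QK^T)_{ij}(OV^T)_{kl}$, and the random sizes are our assumptions:

```python
import numpy as np

def loss(Q, K, O, V, Qh, Kh, Oh, Vh):
    """Toy loss sum_ijkl (T_hat - T)^2, expanded via W_QK = Q K^T and W_OV = O V^T."""
    qk, ov = Q @ K.T, O @ V.T
    qk_h, ov_h = Qh @ Kh.T, Oh @ Vh.T
    return (np.sum(qk_h**2) * np.sum(ov_h**2)
            - 2 * np.sum(qk_h * qk) * np.sum(ov_h * ov)
            + np.sum(qk**2) * np.sum(ov**2))

def dQ_dt(Q, K, O, V, Qh, Kh, Oh, Vh):
    """Analytic -dL/dQ = -2 ||O V^T||^2 Q K^T K + 2 Q_hat K_hat^T K Tr(V_hat O_hat^T O V^T)."""
    ov = O @ V.T
    trace = np.trace(Vh @ Oh.T @ ov)
    return -2 * np.sum(ov**2) * Q @ K.T @ K + 2 * Qh @ Kh.T @ K * trace

rng = np.random.default_rng(6)
d, r = 5, 2
mats = [rng.normal(size=(d, r)) for _ in range(8)]   # Q, K, O, V, then the four targets
Q, K, O, V = mats[:4]

# Central finite differences of the loss with respect to each entry of Q.
eps = 1e-6
num = np.zeros_like(Q)
for idx in np.ndindex(*Q.shape):
    E = np.zeros_like(Q)
    E[idx] = eps
    num[idx] = (loss(Q + E, *mats[1:]) - loss(Q - E, *mats[1:])) / (2 * eps)
print(np.allclose(dQ_dt(*mats), -num, atol=1e-4))

# Early-time norm change with the trace term dropped:
# d||Q||^2/dt = 2 Tr(Q^T dQ/dt) = -4 ||O V^T||^2 ||Q K^T||^2 <= 0.
early = -2 * np.sum((O @ V.T) ** 2) * Q @ K.T @ K
dnorm = 2 * np.trace(Q.T @ early)
print(dnorm <= 0, np.isclose(dnorm, -4 * np.sum((O @ V.T) ** 2) * np.sum((Q @ K.T) ** 2)))
```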

Early in training, the final trace is nearly zero because the vectors are mostly orthogonal. So we just have the first term, giving

$$\frac{dQ}{dt} \approx -2\,\|OV^T\|_F^2\;QK^TK$$

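What does this evolution do to the norm of $Q$? Well,

$$\frac{d}{dt}\|Q\|_F^2 = 2\,\mathrm{Tr}\!\left(Q^T\frac{dQ}{dt}\right) \approx -4\,\|OV^T\|_F^2\,\mathrm{Tr}(Q^TQK^TK) = -4\,\|OV^T\|_F^2\,\|QK^T\|_F^2 \le 0$$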

So the norm decays exponentially at early times. What happens at later times? The second term comes to dominate, so

$$\frac{dQ}{dt} \approx 2\,\hat Q\hat K^TK\;\mathrm{Tr}(\hat V\hat O^TOV^T)$$

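This looks a bit like terms we’ve seen before. In particular, the Gram matrices $\hat Q^T\hat Q$ and $\hat K^T\hat K$ are positive-definite (for full-rank targets), so, much as with $a\cdot x$ and $b\cdot y$ in the rank-1 case, the overlaps $\hat Q^TQ$ and $\hat K^TK$ reinforce each other and the first factor grows a component of $Q$ correlated with the truth $\hat Q$. The trouble is that the trace factor that follows need not be positive.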

So we hit a problem: the system needs to get all of the vectors close enough that the gradients have a sign pointing towards a basin, at which point we should see rapid learning.

Fortunately there are many valid basins, because we can flip the signs of any pair of vectors and leave the system unchanged, and similarly matrices can be rotated by orthogonal operators in the dual space and leave everything unchanged (for example, $Q\to QR$ and $K\to KR$ for any orthogonal $R$ leaves $QK^T$ unchanged). So we probably get a basin somewhere nearby.
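The degeneracy of the basins can be seen directly. This sketch assumes the head is $T_{ijkl} = (QK^T)_{ij}(OV^T)_{kl}$ with illustrative sizes, and checks that flipping the sign of a matching column of $Q$ and $K$, rotating the rank index with an orthogonal $R$ ($Q\to QR$, $K\to KR$), or negating both $QK^T$ and $OV^T$ at once all leave the head unchanged:

```python
import numpy as np

def head_tensor(Q, K, O, V):
    """T_ijkl = (Q K^T)_ij (O V^T)_kl."""
    return np.einsum('im,jm,kn,ln->ijkl', Q, K, O, V)

rng = np.random.default_rng(7)
d, r = 5, 2
Q, K, O, V = (rng.normal(size=(d, r)) for _ in range(4))
T = head_tensor(Q, K, O, V)

# Flip the sign of column 0 of both Q and K: Q K^T is unchanged.
S = np.diag([-1.0, 1.0])
same_sign_flip = np.allclose(head_tensor(Q @ S, K @ S, O, V), T)

# Rotate the rank index: (Q R)(K R)^T = Q R R^T K^T = Q K^T for orthogonal R.
R, _ = np.linalg.qr(rng.normal(size=(r, r)))
same_rotation = np.allclose(head_tensor(Q @ R, K @ R, O, V), T)

# Negate both circuits at once: (-Q K^T)(-O V^T) = (Q K^T)(O V^T).
same_double_flip = np.allclose(head_tensor(-Q, K, -O, V), T)
print(same_sign_flip, same_rotation, same_double_flip)
```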

All of which is to say, we should expect a plateau followed by a sudden drop once the basin is found, which is exactly what we see (with rank-2):

Conclusions

S-curves are a natural outcome of trying to learn tasks where performance depends simultaneously on multiple components. The more components involved, the longer the initial plateau, because there are more pieces that have to be in place to achieve low loss and hence to get a strong gradient signal.

As components align, the gradient signal on the other components strengthens, and the whole process snowballs. This is why we see a sudden drop in the loss following the plateau.

When the components being learned are low-rank, it takes longer to learn them. This is because random initializations are farther (on average) from values that minimize the loss. Put another way, with a rank-k ground truth and a rank-k model, there are ~k chances for each learned component to randomly be near one of the targets. So as the rank falls, each component on average starts farther from the nearest target, and the time spent on the plateau rises.
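The “~k chances” argument can be illustrated by Monte Carlo: for one random vector, take the largest |cosine| with any of k random target directions. This sketch assumes Gaussian vectors and orthonormalized targets (our simplifications); the average best overlap rises with k:

```python
import numpy as np

def best_overlap(n, k, trials=500, seed=0):
    """Average over trials of the largest |cos| between one random vector and
    k random (orthonormalized) target directions in R^n."""
    rng = np.random.default_rng(seed)
    out = []
    for _ in range(trials):
        targets, _ = np.linalg.qr(rng.normal(size=(n, k)))
        x = rng.normal(size=n)
        out.append(np.max(np.abs(targets.T @ x)) / np.linalg.norm(x))
    return np.mean(out)

for k in (1, 5, 20):
    # More targets means more chances for the random start to be near one of them.
    print(k, best_overlap(100, k))
```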

As components approach full rank, the plateau disappears. This is because the span of the random initialization vectors approaches the span of the target vectors, so there is a significant gradient signal on each component from the beginning.

We think that this picture is consistent with the findings of Barak+2022, who show that for parity learning SGD gradually (and increasingly rapidly) amplifies a piece of the gradient signal known as the Fourier gap. Once this component is large enough, learning proceeds rapidly and the loss drops to zero.

Questions

Finally, a few questions we haven’t answered and would be keen to hear more about:

  1. Are there examples of S-curves that can’t be thought of as learning a multi-component task?

  2. Are there examples of low-rank multi-component tasks which do not produce S-curves?