Neural Tangent Kernel Distillation

Produced As Part Of The SERI ML Alignment Theory Scholars Program 2022 Under John Wentworth.

Introduction

Consider the following example from Goal Misgeneralization: you train an RL agent to pursue a coin, which is always at the same location (at the end of the level) during training. At test time, if the coin position is randomized, does the RL agent pursue the coin, or does it go to the fixed location?

Since the training data could not distinguish between the two goals (‘go to the coin’ and ‘go to the location’), which one is chosen is determined purely by the neural network's prior. It would be nice to be able to use simplicity priors like the Solomonoff prior to think about this: the neural network might tend to choose the ‘simplest’ extrapolation. But what is the right notion of simplicity for neural networks? It’s clearly not program length in any normal programming language, because the parity function[1] is very short but hard for neural networks to learn.

In this post, we will summarize some recent advances in DNN theory that have given us the ability to describe the prior of a deep learning network, and discuss their relevance to alignment. Our goal is to quickly communicate the insights that we found difficult to understand or slow to extract from the sources we were using.

Neural Tangent Kernel (NTK) theory compares neural networks to kernel methods. [2] The most interesting takeaways (to us) are:

  • We can make a “linearized” neural network and prove that training this is equivalent to kernel regression.

  • As we increase the width of a normal neural network, we can prove that it will behave more like the linearized network.

  • By relating this to Gaussian Process inference, we can think of the neural network as doing Bayesian inference, and describe and visualize the prior distribution that it is using.

    • We can print out features/​eigenmodes (functions from input vectors to output vectors), and think of the neural network as finding the best linear combination of these, with a preference for using the earlier eigenmodes.

Prerequisites: Linear Algebra, Gaussian Processes (GPs), to the level explained here, a source we highly recommend playing with to gain GP intuition.

Background

Notation

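Let $f$ denote the network architecture, a function $f(\theta, x)$ that takes in the parameters and a network input, and outputs a predicted label. $\theta \in \mathbb{R}^p$ denotes the parameters, $x$ is the input to the network, and $f(\theta, x)$ is the output of the network.[3] We will use $X = (x_1, \dots, x_n)$ to denote the network inputs, $Y = (y_1, \dots, y_n)$ to denote the corresponding training labels, and $f(\theta, X)$ to denote the vector of network outputs evaluated on each $x_i$ in $X$. $\phi$ denotes a feature map associated with a kernel.

We will use L2 loss: $L(\theta) = \frac{1}{2}\|f(\theta, X) - Y\|^2$.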

Kernel Methods

Kernel methods were studied extensively in the pre-deep-learning era, and it turns out that we can use insights from this area to better understand neural networks.

A kernel is a function that tells you how a priori similar any two data points are.

There are three intuitive ways I think about kernel methods:

  1. A kernel method predicts a label by taking a weighted average of the labels of nearby data points, weighted by how close the kernel thinks the data points are.

  2. From a Bayesian point of view, a kernel gives you the a priori covariance between data points, which we can use as a prior to do Bayesian inference.

  3. A kernel method transforms data into a fixed feature space, then does linear regression on the data points in that space (with some prior over the linear regression parameters).

Kernel Linear Regression

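Classical linear regression works as follows: you want to find a parameter vector $\theta$ to predict data labels $Y$ from inputs $X \in \mathbb{R}^{n \times d}$ (one input per row), i.e. you want $X\theta \approx Y$.[4] It’s pretty easy to just solve for the $\theta$ that is closest in the least-squares sense:

$$\theta = (X^\top X)^{-1} X^\top Y$$

To get a prediction for a new input $x$, now that you have $\theta$, you can simply multiply it by $\theta$:

$$\hat{y} = x^\top \theta = x^\top (X^\top X)^{-1} X^\top Y$$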
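Here is a minimal NumPy sketch of the two formulas above, assuming noiseless data (as in footnote 4) and an invertible $X^\top X$, which needs at least as many data points as input dimensions. The function names are our own, for illustration only.

```python
import numpy as np

def fit_linear_regression(X: np.ndarray, Y: np.ndarray) -> np.ndarray:
    """Least-squares theta for X @ theta ≈ Y via the normal equations.

    X has shape (n, d), one training input per row. Assumes X^T X is invertible.
    """
    # Solve (X^T X) theta = X^T Y instead of forming the inverse explicitly.
    return np.linalg.solve(X.T @ X, X.T @ Y)

def predict_linear(theta: np.ndarray, x_new: np.ndarray) -> float:
    """Prediction for a new input: x_new^T theta."""
    return float(x_new @ theta)

rng = np.random.default_rng(0)
X = rng.normal(size=(20, 3))              # 20 inputs in R^3
theta_true = np.array([1.0, -2.0, 0.5])
Y = X @ theta_true                        # noiseless labels

theta_hat = fit_linear_regression(X, Y)   # recovers theta_true up to rounding
print(theta_hat)
print(predict_linear(theta_hat, np.array([1.0, 1.0, 1.0])))  # ≈ -0.5
```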

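Kernel regression generalizes linear regression: instead of fitting a linear predictor directly on the input space, we pick a kernel function $k(x, x') = \phi(x)^\top \phi(x')$ that picks out features $\phi(x)$ of the input, and then do linear regression in that (typically higher-dimensional) feature space. We also assume that the parameters learned by the linear regression, $\theta$, can be expressed as $\theta = \Phi^\top \alpha = \sum_i \alpha_i \phi(x_i)$, i.e., that $\theta$ is a linear combination of the features of the training points, where $\Phi$ is the matrix whose $i$-th row is $\phi(x_i)^\top$.[5] So:

$$\Phi\theta = \Phi\Phi^\top \alpha = K\alpha \approx Y, \qquad K_{ij} = k(x_i, x_j)$$

Solving for $\alpha$ (assuming $K$ is invertible) gives:

$$\alpha = K^{-1} Y$$

Now, we can substitute this into the prediction for a new input $x$:

$$\hat{y} = \phi(x)^\top \theta = \phi(x)^\top \Phi^\top \alpha = k(x, X)\, K^{-1} Y, \qquad k(x, X) = \big(k(x, x_1), \dots, k(x, x_n)\big)$$

This is the equation for kernel regression, where we can understand the first two factors, $k(x, X) K^{-1}$, as a weighted similarity vector between our test data point and each of the training data points, which is then dotted with the training labels $Y$.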
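To check the derivation numerically, here is a small sketch with a hypothetical polynomial feature map $\phi(x) = (1, x, x^2, x^3, x^4)$ on scalar inputs (our choice, not from any source). It computes the prediction both as $k(x, X)K^{-1}Y$ and as $\phi(x)^\top\theta$ with $\theta = \Phi^\top\alpha$; with more features than data points, the fit also interpolates the training labels.

```python
import numpy as np

def poly_features(x: np.ndarray, degree: int = 4) -> np.ndarray:
    """Hypothetical feature map phi(x) = [1, x, ..., x^degree] for scalar inputs."""
    return np.vander(np.atleast_1d(x), degree + 1, increasing=True)

def kernel(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """k(a_i, b_j) = phi(a_i)^T phi(b_j) for every pair: shape (len(a), len(b))."""
    return poly_features(a) @ poly_features(b).T

def kernel_regression_predict(x_new, X, Y):
    """Kernel regression prediction k(x, X) K^{-1} Y."""
    K = kernel(X, X)                  # Gram matrix K_ij = k(x_i, x_j)
    alpha = np.linalg.solve(K, Y)     # alpha = K^{-1} Y
    return kernel(x_new, X) @ alpha

X = np.array([-1.0, 0.0, 2.0])        # 3 training points, 5 features each
Y = np.array([0.5, -1.0, 3.0])

# The same prediction computed in feature space: theta = Phi^T alpha.
Phi = poly_features(X)
theta = Phi.T @ np.linalg.solve(Phi @ Phi.T, Y)

x_new = np.array([0.7, 1.5])
print(kernel_regression_predict(x_new, X, Y))
print(poly_features(x_new) @ theta)   # identical up to rounding
print(kernel_regression_predict(X, X, Y))  # reproduces Y on the training set
```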

Neural Tangent Kernel

How are neural networks kernel methods?

Normally, you treat the neural network as $f(\theta, \cdot)$, fixing $\theta$, so you simply get a map from inputs to outputs. But there’s another way of thinking about it: as a map from parameters to functions, $\theta \mapsto f(\theta, x)$, given a fixed $x$. In particular, we can do a first-order Taylor expansion of this parameter-to-function map around $\theta_0$, the initialization:

$$f(\theta, x) \approx f(\theta_0, x) + \nabla_\theta f(\theta_0, x)^\top (\theta - \theta_0)$$

The error of this approximation is $O(\|\theta - \theta_0\|^2)$, so the less the parameters are updated during training, the better this approximation is. One of the key results behind NTK research is that as the width of a network increases toward infinity, the parameters change less during training.
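A quick numerical illustration of the $O(\|\theta - \theta_0\|^2)$ error, using a hypothetical 2-layer tanh network with scalar input, $f(\theta, x) = \frac{1}{\sqrt{m}}\sum_j v_j \tanh(w_j x + b_j)$, all parameters drawn from $\mathcal{N}(0, 1)$ (the architecture and width are our assumptions). If the error is quadratic, halving the parameter step should divide it by roughly 4.

```python
import numpy as np

m = 50                                    # hidden width (illustrative choice)
rng = np.random.default_rng(0)

def f(theta, x):
    """Hypothetical 2-layer tanh network; theta = (w, b, v), each of shape (m,)."""
    w, b, v = theta
    return v @ np.tanh(w * x + b) / np.sqrt(len(v))

def grad_f(theta, x):
    """Analytic gradient of f w.r.t. (w, b, v), same shape as theta."""
    w, b, v = theta
    h = np.tanh(w * x + b)
    dh = v * (1.0 - h**2)                 # v_j * tanh'(preactivation)
    return np.stack([dh * x, dh, h]) / np.sqrt(len(v))

theta0 = rng.normal(size=(3, m))          # initialization
direction = rng.normal(size=theta0.shape) # random step direction
x = 0.8

def linearization_error(eps):
    """|f(theta) - f_lin(theta)| for theta = theta0 + eps * direction."""
    theta = theta0 + eps * direction
    f_lin = f(theta0, x) + np.sum(grad_f(theta0, x) * (theta - theta0))
    return abs(f(theta, x) - f_lin)

ratio = linearization_error(1e-3) / linearization_error(5e-4)
print(ratio)                              # close to 4 for a quadratic error
```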

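But for a moment, let’s keep the width finite. At finite width, we can define a new learning algorithm called a “linearized neural network”, which replaces the network with its first-order Taylor expansion:

$$f_{\text{lin}}(\theta, x) = f(\theta_0, x) + \nabla_\theta f(\theta_0, x)^\top (\theta - \theta_0)$$

This equation describes (almost) linear regression, with parameters $\theta - \theta_0$, on a particular feature space $\phi(x) = \nabla_\theta f(\theta_0, x)$; the "almost" is due to the fixed offset $f(\theta_0, x)$:

$$f_{\text{lin}}(\theta, x) - f(\theta_0, x) = \phi(x)^\top (\theta - \theta_0)$$

As we learned above in the Kernel Linear Regression section, linear regression on a feature space $\phi$ is equivalent to kernel regression with $k(x, x') = \phi(x)^\top \phi(x') = \nabla_\theta f(\theta_0, x)^\top \nabla_\theta f(\theta_0, x')$!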

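Hence, training $f_{\text{lin}}$ is equivalent to doing:

$$\hat{y} = f(\theta_0, x) + k(x, X)\, K^{-1} \big(Y - f(\theta_0, X)\big)$$

where $K_{ij} = k(x_i, x_j) = \nabla_\theta f(\theta_0, x_i)^\top \nabla_\theta f(\theta_0, x_j)$.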
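The equivalence can be checked on a small example. This sketch again assumes a hypothetical 2-layer tanh network; "training" the linearized network is done by taking the minimum-norm least-squares step $\theta - \theta_0$ (via `np.linalg.lstsq`), which is what gradient descent started at $\theta_0$ converges to for a linear model. It then compares the result with the kernel regression formula using $k(x, x') = \phi(x)^\top\phi(x')$.

```python
import numpy as np

def f(theta, x):
    """Hypothetical 2-layer tanh network; theta = (w, b, v), each of shape (m,)."""
    w, b, v = theta
    return v @ np.tanh(w * x + b) / np.sqrt(len(v))

def grad_f(theta, x):
    """Flattened feature vector phi(x) = grad_theta f(theta, x)."""
    w, b, v = theta
    h = np.tanh(w * x + b)
    dh = v * (1.0 - h**2)
    return np.stack([dh * x, dh, h]).ravel() / np.sqrt(len(v))

rng = np.random.default_rng(1)
theta0 = rng.normal(size=(3, 200))        # width m = 200, so p = 600 parameters

X = np.array([-2.0, -0.5, 1.0, 2.5])      # training inputs
Y = np.array([1.0, -1.0, 0.5, 2.0])       # training labels
x_test = np.array([0.0, 1.7])

Phi = np.stack([grad_f(theta0, x) for x in X])   # (n, p) feature matrix
f0 = np.array([f(theta0, x) for x in X])         # f(theta0, X)

# (1) Linearized training: minimum-norm solution of Phi @ delta = Y - f0.
delta, *_ = np.linalg.lstsq(Phi, Y - f0, rcond=None)
pred_trained = np.array([f(theta0, x) + grad_f(theta0, x) @ delta for x in x_test])

# (2) Kernel regression on the residual Y - f0 with the empirical NTK.
K = Phi @ Phi.T
k_test = np.stack([grad_f(theta0, x) for x in x_test]) @ Phi.T
f0_test = np.array([f(theta0, x) for x in x_test])
pred_kernel = f0_test + k_test @ np.linalg.solve(K, Y - f0)

print(pred_trained)
print(pred_kernel)                        # agrees up to floating-point error
```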

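The only major insight left is that infinite width implies that training barely moves $\theta$ away from $\theta_0$ (relative to the scale of the parameters), which means $f(\theta, x) \approx f_{\text{lin}}(\theta, x)$ throughout training. This is non-trivial to prove and depends on the initialization distribution.[6]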

The NTK function

This kernel regression motivates the definition of the NTK function:

$$\Theta(x, x') = \nabla_\theta f(\theta_0, x)^\top \nabla_\theta f(\theta_0, x')$$

We can think of the NTK function as telling you the ‘similarity’ of two given data points according to the feature map at initialization. This is not to be confused with the NTK matrix $H$: the matrix whose $(i, j)$-th component is $\Theta(x_i, x_j)$, for a set of input data points $x_1, \dots, x_n$.

The last result that we need to know is that the NTK stops depending on $\theta_0$ when the width is infinite. We won’t prove this here, but the sketch is that if we expand out $\Theta$ for a particular neural network architecture, it ends up having a lot of sums over the weights of the hidden units. As the width goes to infinity, these (suitably normalized) sums become expectations over the initialization distribution, so the particular random draw of $\theta_0$ no longer matters.
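A sketch of this effect, assuming a hypothetical 2-layer ReLU network $f(x) = \frac{1}{\sqrt{m}}\sum_j v_j\,\mathrm{ReLU}(w_j x + b_j)$ with scalar input and $w_j, b_j, v_j \sim \mathcal{N}(0, 1)$. We evaluate $\Theta(x, x')$ at ten random initializations for a narrow and a very wide network; the spread across seeds should shrink roughly like $1/\sqrt{m}$.

```python
import numpy as np

def empirical_ntk(x1, x2, m, seed):
    """Theta(x1, x2) at one random init of the hypothetical 2-layer ReLU net."""
    rng = np.random.default_rng(seed)
    w, b, v = rng.normal(size=(3, m))

    def grad(x):
        pre = w * x + b
        active = (pre > 0).astype(float)   # ReLU derivative
        # gradients w.r.t. w, b and v, concatenated into one vector
        return np.concatenate([v * active * x, v * active, np.maximum(pre, 0.0)]) / np.sqrt(m)

    return grad(x1) @ grad(x2)

spread = {}
for m in (10, 100_000):
    values = [empirical_ntk(0.5, -1.0, m, seed) for seed in range(10)]
    spread[m] = np.std(values)
    print(f"width {m}: mean {np.mean(values):.4f}, std across seeds {spread[m]:.4f}")
```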

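The NTK in the infinite-width limit can be written out analytically; e.g. the one for a 2-layer ReLU network is:[7]

$$\Theta(x, x') = \|x\|\,\|x'\|\;\kappa\!\left(\frac{\langle x, x' \rangle}{\|x\|\,\|x'\|}\right), \qquad \kappa(u) = u\,\kappa_0(u) + \kappa_1(u)$$

where:

$$\kappa_0(u) = \frac{1}{\pi}\big(\pi - \arccos u\big), \qquad \kappa_1(u) = \frac{1}{\pi}\Big(u\,(\pi - \arccos u) + \sqrt{1 - u^2}\Big)$$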
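As a sanity check, this sketch compares the analytic 2-layer ReLU formula with a Monte Carlo estimate from one very wide random network. We assume the parameterization $f(x) = \sqrt{2/m}\,\sum_j v_j\,\mathrm{ReLU}(w_j^\top x)$ with $w_j \sim \mathcal{N}(0, I)$, $v_j \sim \mathcal{N}(0, 1)$ and no biases; this is the scaling under which the constant in the formula comes out as written. Other parameterizations change the kernel by a constant factor.

```python
import numpy as np

def kappa(u):
    """kappa(u) = u * kappa0(u) + kappa1(u), using the arc-cosine expressions."""
    u = np.clip(u, -1.0, 1.0)
    k0 = (np.pi - np.arccos(u)) / np.pi
    k1 = (u * (np.pi - np.arccos(u)) + np.sqrt(1.0 - u**2)) / np.pi
    return u * k0 + k1

def ntk_relu_analytic(x, y):
    nx, ny = np.linalg.norm(x), np.linalg.norm(y)
    return nx * ny * kappa(x @ y / (nx * ny))

def ntk_relu_monte_carlo(x, y, m=200_000, seed=0):
    """Empirical NTK of f(x) = sqrt(2/m) * sum_j v_j relu(w_j . x) at one init."""
    rng = np.random.default_rng(seed)
    W = rng.normal(size=(m, x.shape[0]))
    v = rng.normal(size=m)
    px, py = W @ x, W @ y
    # d f / d v_j gives relu(w_j.x) * relu(w_j.y);
    # d f / d w_j gives v_j^2 * 1[w_j.x > 0] * 1[w_j.y > 0] * (x . y).
    term_v = np.maximum(px, 0) * np.maximum(py, 0)
    term_w = v**2 * (px > 0) * (py > 0) * (x @ y)
    return 2.0 / m * np.sum(term_v + term_w)

x = np.array([1.0, 0.5, -0.3])
y = np.array([-0.2, 1.0, 0.8])
print(ntk_relu_analytic(x, y), ntk_relu_monte_carlo(x, y))
```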

Prior over data

We can view any kernel method as giving us the posterior mean of a Gaussian Process (see the Marginalization and Conditioning section of this to see why, although beware that they use different variable names, which can be confusing[8]).

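So we can think of the NTK as giving us a prior over the labels $Y$ of a given set of inputs $X$, specifically:

$$p(Y \mid X) = \frac{1}{Z} \exp\!\left(-\frac{1}{2} Y^\top H^{-1} Y\right)$$

where $H$ is the NTK matrix, which depends on our dataset: $H_{ij} = \Theta(x_i, x_j)$. We will abbreviate the constant denominator as $Z = \sqrt{(2\pi)^n \det H}$.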
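To make "prior over labels" concrete, this sketch builds $H$ for five 1-D inputs and draws label vectors $Y \sim \mathcal{N}(0, H)$. For $H$ we use the analytic 2-layer ReLU NTK, with each input $x$ augmented to $(x, 1)$ as a stand-in for a first-layer bias; that choice, and the helper names, are our assumptions. `log_prior` evaluates the density written above.

```python
import numpy as np

def relu_ntk(a, b):
    """Analytic 2-layer ReLU NTK between 1-D inputs, each augmented to (x, 1)
    as a stand-in for a first-layer bias (an assumption, for illustration)."""
    A = np.stack([a, np.ones_like(a)], axis=1)
    B = np.stack([b, np.ones_like(b)], axis=1)
    na = np.linalg.norm(A, axis=1)[:, None]
    nb = np.linalg.norm(B, axis=1)[None, :]
    u = np.clip(A @ B.T / (na * nb), -1.0, 1.0)
    # kappa(u) = u * kappa0(u) + kappa1(u), simplified
    kappa = (2 * u * (np.pi - np.arccos(u)) + np.sqrt(1 - u**2)) / np.pi
    return na * nb * kappa

def log_prior(Y, H):
    """log p(Y | X) = -0.5 Y^T H^{-1} Y - log Z, with Z = sqrt((2 pi)^n det H)."""
    _, logdet = np.linalg.slogdet(2 * np.pi * H)
    return -0.5 * Y @ np.linalg.solve(H, Y) - 0.5 * logdet

X = np.linspace(-3, 3, 5)
H = relu_ntk(X, X)                        # NTK matrix H_ij = Theta(x_i, x_j)

rng = np.random.default_rng(0)
samples = rng.multivariate_normal(np.zeros(len(X)), H, size=200_000)
print(samples[:3])                        # three label vectors drawn from the prior
print(log_prior(samples[0], H))
```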

Kernel Eigenmodes

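We can understand this prior via eigendecomposition. Since $H$ is symmetric (because $\Theta(x_i, x_j) = \Theta(x_j, x_i)$) and positive semidefinite, the Spectral Theorem applies, allowing us to eigendecompose it into $H = Q \Lambda Q^\top$, where $\Lambda$ is a diagonal matrix of the (non-negative) eigenvalues, and $Q$ is the orthogonal matrix whose columns are the eigenvectors of $H$:

$$H = \sum_{i=1}^{n} \lambda_i\, q_i q_i^\top$$

Here, $q_1, \dots, q_n$ are the eigenvectors of $H$, with corresponding eigenvalues $\lambda_1 \ge \dots \ge \lambda_n$.

Eigendecomposing $H^{-1}$ gives $H^{-1} = \sum_i \lambda_i^{-1} q_i q_i^\top$ (assuming $H$ is invertible, i.e. all $\lambda_i > 0$), and the labels are sampled from a Gaussian Process, so the log probability density is:

$$\log p(Y \mid X) = -\frac{1}{2} \sum_{i=1}^{n} \frac{(q_i^\top Y)^2}{\lambda_i} - \log Z$$

We can think of $q_i^\top Y$ as the correlation between the dataset's labels and each eigenvector. Because each term is divided by $\lambda_i$, labels that line up with high-eigenvalue eigenvectors are penalized less, i.e. they are a priori more likely. See footnote for the full derivation.[9]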
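This sketch checks the eigenvector form of the log density against a direct evaluation of $\log \mathcal{N}(Y; 0, H)$. As before, $H$ is the analytic 2-layer ReLU NTK on 1-D inputs augmented to $(x, 1)$ (our assumption), and the labels $Y = \sin(X)$ are arbitrary placeholders.

```python
import numpy as np

def relu_ntk(a, b):
    """Analytic 2-layer ReLU NTK between 1-D inputs, each augmented to (x, 1)
    as a stand-in for a first-layer bias (an assumption, for illustration)."""
    A = np.stack([a, np.ones_like(a)], axis=1)
    B = np.stack([b, np.ones_like(b)], axis=1)
    na = np.linalg.norm(A, axis=1)[:, None]
    nb = np.linalg.norm(B, axis=1)[None, :]
    u = np.clip(A @ B.T / (na * nb), -1.0, 1.0)
    kappa = (2 * u * (np.pi - np.arccos(u)) + np.sqrt(1 - u**2)) / np.pi
    return na * nb * kappa

X = np.linspace(-3, 3, 8)
H = relu_ntk(X, X)
lam, Q = np.linalg.eigh(H)             # ascending eigenvalues, orthonormal columns
lam, Q = lam[::-1], Q[:, ::-1]         # reorder so q_1 has the largest eigenvalue

Y = np.sin(X)                          # placeholder labels
coeffs = Q.T @ Y                       # q_i^T Y: overlap of the labels with each mode
log_Z = 0.5 * np.sum(np.log(2 * np.pi * lam))
log_p_eigen = -0.5 * np.sum(coeffs**2 / lam) - log_Z

# Direct evaluation of log N(Y; 0, H) for comparison.
_, logdet = np.linalg.slogdet(2 * np.pi * H)
log_p_direct = -0.5 * Y @ np.linalg.solve(H, Y) - 0.5 * logdet

print(log_p_eigen, log_p_direct)
```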

This is an explicit prior for a neural network. We can use it to predict how a neural network will generalize to test points $X_{\text{test}}$ when trained on the training data points $X$ with labels $Y$. The way to do this is to compute the NTK matrix on $X \cup X_{\text{test}}$, then calculate $\log p(Y, Y_{\text{test}})$ for several different candidate test labels $Y_{\text{test}}$. The candidate that gives the highest prior probability is the generalization most likely to be chosen. This is analogous to conditioning this Gaussian on the labels $Y$.
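The procedure above can be sketched directly. We build the joint NTK matrix on training and test inputs (again the analytic ReLU NTK with inputs augmented to $(x, 1)$, our assumption), and compute the most probable test labels by conditioning the Gaussian: $Y_{\text{test}}^* = H_{tX} H_{XX}^{-1} Y$, which is the kernel regression prediction. Randomly perturbed candidates should always score lower.

```python
import numpy as np

def relu_ntk(a, b):
    """Analytic 2-layer ReLU NTK between 1-D inputs, each augmented to (x, 1)
    as a stand-in for a first-layer bias (an assumption, for illustration)."""
    A = np.stack([a, np.ones_like(a)], axis=1)
    B = np.stack([b, np.ones_like(b)], axis=1)
    na = np.linalg.norm(A, axis=1)[:, None]
    nb = np.linalg.norm(B, axis=1)[None, :]
    u = np.clip(A @ B.T / (na * nb), -1.0, 1.0)
    kappa = (2 * u * (np.pi - np.arccos(u)) + np.sqrt(1 - u**2)) / np.pi
    return na * nb * kappa

X_train = np.array([-2.0, -1.0, 0.5, 2.0])
Y_train = np.array([0.3, -0.4, 0.8, 1.1])
X_test = np.array([0.0, 1.0])
n = len(X_train)

H = relu_ntk(np.concatenate([X_train, X_test]), np.concatenate([X_train, X_test]))

def log_joint(Y_test):
    """log p(Y_train, Y_test) up to the constant -log Z."""
    Y_all = np.concatenate([Y_train, Y_test])
    return -0.5 * Y_all @ np.linalg.solve(H, Y_all)

# Conditioning the Gaussian: the most probable Y_test given Y_train.
Y_test_best = H[n:, :n] @ np.linalg.solve(H[:n, :n], Y_train)
print(Y_test_best, log_joint(Y_test_best))

rng = np.random.default_rng(0)
for _ in range(3):
    candidate = Y_test_best + rng.normal(scale=0.1, size=Y_test_best.shape)
    print(candidate, log_joint(candidate))  # always lower than the optimum
```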

Visualizing Eigenvectors

We visualize these for a specific neural network below:

The first four eigenvectors of $H$ for a 2-layer fully connected neural network, with training inputs ranging from −3 to 3.

When we get to the later eigenvectors, they all turn out to be sinusoidal. We can think of neural network training as finding a linear combination of these functions, where it prefers using the functions with higher eigenvalues. It also learns the components with higher eigenvalues earlier in training: training effectively runs gradient descent on each eigenvector's component separately, with an effective learning rate proportional to its eigenvalue (see the appendix for why this is true).
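A rough way to reproduce this numerically, assuming the analytic 2-layer ReLU NTK on a grid of 40 inputs in $[-3, 3]$, each augmented to $(x, 1)$ (our stand-in, not necessarily the network in the figure or the Colab). Sign changes per eigenvector give a crude frequency measure (printed only; we don't assert anything about the shapes), and $e^{-\lambda_i t}$ is the fraction of each eigenvector's component of the error still left after gradient-flow time $t$.

```python
import numpy as np

def relu_ntk(a, b):
    """Analytic 2-layer ReLU NTK between 1-D inputs, each augmented to (x, 1)
    as a stand-in for a first-layer bias (an assumption, for illustration)."""
    A = np.stack([a, np.ones_like(a)], axis=1)
    B = np.stack([b, np.ones_like(b)], axis=1)
    na = np.linalg.norm(A, axis=1)[:, None]
    nb = np.linalg.norm(B, axis=1)[None, :]
    u = np.clip(A @ B.T / (na * nb), -1.0, 1.0)
    kappa = (2 * u * (np.pi - np.arccos(u)) + np.sqrt(1 - u**2)) / np.pi
    return na * nb * kappa

X = np.linspace(-3, 3, 40)
H = relu_ntk(X, X)
lam, Q = np.linalg.eigh(H)
lam, Q = lam[::-1], Q[:, ::-1]          # largest eigenvalue first

# Sign changes of each of the first six eigenvectors across the grid.
sign_changes = [int(np.sum(np.diff(np.sign(Q[:, i])) != 0)) for i in range(6)]
print(sign_changes)

# Fraction of each mode's error remaining at a hypothetical training time t.
t = 1.0 / lam[3]
remaining = np.exp(-lam * t)
print(remaining[:6])                    # small for top modes, close to 1 for later ones
```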

Here is our Google Colab Notebook to generate these results.

Alignment relevance

So we have a mathematically precise notion of the simplicity prior! What does this tell us about alignment?

Unfortunately, not too much. The key problem is abstraction: it’s really hard for us to express abstract concepts like ‘is this network deceptive?’ in the language of the kernel's eigenfunctions, i.e. a decomposition into sine waves. I am excited for future work to tackle this problem and use NTK theory to predict how neural networks will generalize. For example, could we prove something like “this neural network is very unlikely to learn an algorithm in the set of bounded tree search algorithms”?

We should be able to put any two data points into the kernel and get a measure of how similar they are. This should let us test whether the trained neural network is going to treat an image of a lion in the snow as more similar to a training image of a husky in the snow or to one of a lion in grass.

Appendix: Modeling training dynamics with gradient flow

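A result that we thought was cool, but that didn’t fit anywhere else in this post, is the proof that infinitely wide neural networks can always reach zero training loss (provided the NTK matrix is positive definite). We model the training dynamics of a neural network as having infinitesimally small step sizes, which is called gradient flow. This allows us to model training as a differential equation, where we continuously update the parameters over time $t$ according to how they perform on the loss $L$:

$$\frac{d\theta_t}{dt} = -\nabla_\theta L(\theta_t) = -\nabla_\theta f(\theta_t, X)^\top \big(f(\theta_t, X) - Y\big)$$

where we sometimes abbreviate $f(\theta_t, X)$ as just $f_t$, and $\nabla_\theta f(\theta_t, X)$ is the $n \times p$ Jacobian of the outputs on the training set. This models the gradient flow in parameter space, but what we really care about is the dynamics in function space: how $f_t$ changes as $t$ increases. Fortunately, the chain rule lets us compute this on the training set:

$$\frac{df_t}{dt} = \nabla_\theta f(\theta_t, X)\, \frac{d\theta_t}{dt} = -\nabla_\theta f(\theta_t, X)\, \nabla_\theta f(\theta_t, X)^\top (f_t - Y)$$

We have now found a crucial quantity: $\nabla_\theta f(\theta_t, X)\, \nabla_\theta f(\theta_t, X)^\top$ is the NTK matrix (evaluated at $\theta_t$). Let’s call it $H$. It turns out that in the limit of infinite width, this quantity is constant over time. Thus:

$$\frac{df_t}{dt} = -H (f_t - Y)$$

$f_t = Y$ is clearly an equilibrium of this ODE, because when it holds, the right-hand side is $0$. We can explicitly solve this ODE by making the substitution $u_t = f_t - Y$, so:

$$\frac{du_t}{dt} = -H u_t$$

This is a well-known linear ODE, with solution given by:

$$u_t = e^{-Ht} u_0, \qquad \text{i.e.} \qquad f_t = Y + e^{-Ht} (f_0 - Y)$$

In terms of the eigenvectors $q_i$ and eigenvalues $\lambda_i$ of $H$, $e^{-Ht} = \sum_i e^{-\lambda_i t} q_i q_i^\top$. If every $\lambda_i > 0$, each term decays to $0$ as $t \to \infty$, so $f_t \to Y$. This gives us a proof of global convergence on the training data. It also shows that the error along $q_i$ shrinks at rate $\lambda_i$, which is why the high-eigenvalue modes are learned first.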
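A small numerical cross-check of the closed-form solution: we integrate $\frac{df_t}{dt} = -H(f_t - Y)$ with forward Euler and compare against $f_t = Y + e^{-Ht}(f_0 - Y)$, computed through the eigendecomposition of $H$. The kernel (analytic ReLU NTK on inputs augmented to $(x, 1)$), the labels, the zero initial outputs $f_0$, and the step size are all assumptions for illustration.

```python
import numpy as np

def relu_ntk(a, b):
    """Analytic 2-layer ReLU NTK between 1-D inputs, each augmented to (x, 1)
    as a stand-in for a first-layer bias (an assumption, for illustration)."""
    A = np.stack([a, np.ones_like(a)], axis=1)
    B = np.stack([b, np.ones_like(b)], axis=1)
    na = np.linalg.norm(A, axis=1)[:, None]
    nb = np.linalg.norm(B, axis=1)[None, :]
    u = np.clip(A @ B.T / (na * nb), -1.0, 1.0)
    kappa = (2 * u * (np.pi - np.arccos(u)) + np.sqrt(1 - u**2)) / np.pi
    return na * nb * kappa

X = np.linspace(-3, 3, 6)
H = relu_ntk(X, X)                     # constant NTK matrix (infinite-width limit)
Y = np.cos(X)                          # placeholder training labels
f0 = np.zeros_like(Y)                  # hypothetical outputs at initialization
lam_all, Q = np.linalg.eigh(H)
lam_min = lam_all.min()

def closed_form(t):
    """f_t = Y + exp(-H t) (f0 - Y), using exp(-H t) = Q diag(exp(-lam t)) Q^T."""
    return Y + Q @ (np.exp(-lam_all * t) * (Q.T @ (f0 - Y)))

# Forward-Euler integration of df/dt = -H (f - Y) up to t = 1.
dt, steps = 1e-4, 10_000
f = f0.copy()
for _ in range(steps):
    f = f - dt * H @ (f - Y)

print(np.max(np.abs(f - closed_form(1.0))))           # small discretization error
print(np.max(np.abs(closed_form(50.0 / lam_min) - Y))) # training error has vanished
```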

  1. ^

    The parity function is $f: \{0, 1\}^n \to \{0, 1\}$, $f(x) = \left(\sum_i x_i\right) \bmod 2$, and returns $1$ if and only if the input has an odd number of ones.

  2. ^
  3. ^

    We will assume 1-dimensional outputs, because it makes the math much more manageable: the NTK becomes a 4-tensor when the output dimension is more than 1.

  4. ^

    Assume for simplicity that there is no noise and we can perfectly fit the data with a linear function.

  5. ^

    There is a theorem (the Representer theorem) which says that the loss minimizing hypothesis (within the space of functions associated with this kernel) has a representation of this form.

  6. ^

    For the actual proof, start at the bottom of p14 of the original paper and work backwards. There’s a simplified version here in the One hidden Layer Network proof.

  7. ^

    I haven’t checked all of this derivation; I got it from On the Inductive Bias of Neural Tangent Kernels, whose authors seem to have obtained it by combining the NTK definition in Appendix A of this with the analytical evaluation of the integrals from here.

  8. ^

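    See the equations for conditioning a Gaussian, and assume that the prior means are 0. Then, translating back into our variables, we get:

    $$y \mid Y \sim \mathcal{N}\Big(k(x, X)\, K^{-1} Y,\;\; k(x, x) - k(x, X)\, K^{-1} k(X, x)\Big)$$

    whose mean is exactly the kernel regression prediction.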

  9. ^

  10. ^