NTK/GP Models of Neural Nets Can’t Learn Features

Since people are talking about the NTK/GP hypothesis of neural nets again, I thought it might be worth bringing up some recent research in the area that casts doubt on their explanatory power. The upshot is: NTK/GP models of neural networks can’t learn features. By ‘feature learning’ I mean the process where intermediate neurons come to represent task-relevant features such as curves, elements of grammar, or cats. Closely related to feature learning is transfer learning, the typical practice whereby a neural net is trained on one task, then ‘fine-tuned’ with a lower learning rate to fit another task, usually with less data than the first. This is often a powerful way to approach learning in the low-data regime, but NTK/GP models can’t do it at all.

The reason for this is pretty simple. During training on the ‘old task’, a network in the NTK limit stays in the ‘tangent space’ of its initialization. This means that, to first order, none of the functions/derivatives computed by the individual neurons change at all; only the output function does.[1] Feature learning requires the intermediate neurons to adapt to structures in the data that are relevant to the task being learned, but in the NTK limit the intermediate neurons’ functions are frozen. Any meaningful function like a ‘car detector’ would need to be there at initialization, which is extremely unlikely for functions of any complexity. This lack of feature learning implies a lack of meaningful transfer learning as well: since the NTK is just doing linear regression using an (infinite) fixed set of functions, the only ‘transfer’ that can occur is shifting where the regression starts in this space. This could potentially speed up convergence, but it wouldn’t provide any benefits in terms of representation efficiency for tasks with few data points.[2] The same holds for the GP limit: the distribution of functions computed by intermediate neurons doesn’t change after conditioning on the outputs, so networks sampled from the GP posterior wouldn’t be useful for transfer learning either.
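A minimal sketch of the transfer point, under a loud simplifying assumption: the infinite-width NTK is replaced by linear least squares on a fixed, finite, random feature matrix (`Phi_old`, `Phi_new`, the sizes, and the targets are all hypothetical stand-ins, not anything from the papers). Gradient descent on such a model only moves the weights within the row space of the current task's features, so ‘fine-tuning’ from pretrained weights `w_pre` converges to `w_pre` plus the same minimum-norm correction that regression with the fixed kernel `Phi_new @ Phi_new.T` would compute from scratch.

```python
import numpy as np

# Stand-in for the NTK limit (an assumption): a linear model w -> Phi @ w on a
# FIXED feature matrix. Rows of Phi are data points, columns are features.

def gradient_descent_lsq(Phi, y, w0, lr=0.5, steps=2000):
    """Full-batch gradient descent on 0.5 * ||Phi @ w - y||^2, starting at w0."""
    w = w0.copy()
    for _ in range(steps):
        w -= lr * Phi.T @ (Phi @ w - y)  # every step lies in the row space of Phi
    return w

def closed_form_from_start(Phi, y, w0):
    """Limit of gradient descent from w0 when Phi has full row rank:
    w0 plus the minimum-norm correction that fits y exactly."""
    return w0 + np.linalg.pinv(Phi) @ (y - Phi @ w0)

rng = np.random.default_rng(0)
p = 200                                              # number of fixed features (hypothetical)
Phi_old = rng.standard_normal((30, p)) / np.sqrt(p)  # 30 points for the "old task"
Phi_new = rng.standard_normal((5, p)) / np.sqrt(p)   # 5 points for the "new task"
y_old = rng.standard_normal(30)                      # toy targets, purely illustrative
y_new = rng.standard_normal(5)

w_pre = gradient_descent_lsq(Phi_old, y_old, np.zeros(p))  # "pretraining"
w_ft = gradient_descent_lsq(Phi_new, y_new, w_pre)         # "fine-tuning"

# Fine-tuning lands on w_pre + the min-norm correction: the pretrained weights
# only set the starting point of the same fixed-feature regression.
print(np.allclose(w_ft, closed_form_from_start(Phi_new, y_new, w_pre), atol=1e-6))
```

In this toy model the old task enters the new-task solution only through the pretrained model's predictions `Phi_new @ w_pre`; how the correction generalizes is decided entirely by the fixed features, which is the trivial kind of transfer described in footnote 2.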

This also makes me skeptical of the Mingard et al. result that SGD is equivalent to picking a random neural net that achieves a given performance, since picking a random net is equivalent to running a GP regression in the wide-width limit. In particular, it makes me skeptical that this result will generalize to the complex models and tasks we care about. ‘GP/NTK performs similarly to SGD on simple tasks’ has been found before, but the match tends to break down as the tasks become more complex.[3]
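For concreteness, here is a sketch of what the GP side looks like in the simplest case with a closed form: a single hidden ReLU layer with standard-normal weights and a 1/sqrt(width) output scaling. This is an illustrative assumption, not the architecture or procedure Mingard et al. used; in particular it does ordinary Gaussian-noise GP regression rather than conditioning on correct classification. The closed-form kernel is half of Cho and Saul's degree-1 arc-cosine kernel, and all function names are hypothetical.

```python
import numpy as np

def relu_nngp_kernel(A, B):
    """Infinite-width prior covariance of f(x) = v . relu(W x) / sqrt(width),
    with i.i.d. standard normal entries in W and v (nonzero inputs assumed):
    K(x, x') = |x| |x'| (sin t + (pi - t) cos t) / (2 pi), t = angle(x, x')."""
    na = np.linalg.norm(A, axis=1)[:, None]
    nb = np.linalg.norm(B, axis=1)[None, :]
    cos_t = np.clip((A @ B.T) / (na * nb), -1.0, 1.0)  # guard arccos against rounding
    t = np.arccos(cos_t)
    return na * nb * (np.sin(t) + (np.pi - t) * cos_t) / (2.0 * np.pi)

def empirical_kernel(A, B, width, seed=0):
    """Monte Carlo version from one finite random first layer; the output
    weights are averaged out analytically, since E[v_i^2] = 1."""
    rng = np.random.default_rng(seed)
    W = rng.standard_normal((width, A.shape[1]))
    return np.maximum(A @ W.T, 0.0) @ np.maximum(B @ W.T, 0.0).T / width

def gp_posterior_mean(X_train, y_train, X_test, noise=1e-6):
    """GP regression with the kernel above and Gaussian observation noise.
    Building K takes D^2 kernel evaluations for D training points, and the
    exact solve costs O(D^3)."""
    K = relu_nngp_kernel(X_train, X_train) + noise * np.eye(len(X_train))
    K_star = relu_nngp_kernel(X_test, X_train)
    return K_star @ np.linalg.solve(K, y_train)
```

As `width` grows, `empirical_kernel` should approach `relu_nngp_kernel`, and with a small `noise` the posterior mean interpolates the training targets. Everything the ‘random net’ can express here is fixed by one kernel that never looks at the data.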

So are there any theoretical models of neural nets which are able to incorporate feature learning? Yes. In fact, there are a few candidate theories, of which I think Greg Yang’s Tensor Programs is the best. I got all the above anti-NTK/GP talking points from him, specifically his paper Feature Learning in Infinite Width Neural Networks. The basic idea of this paper is pretty neat: he derives a general framework for taking the ‘infinite-width limit’ of ‘tensor programs’, general computation graphs containing tensors with a width parameter. He then applies this framework to SGD itself. The successive iterates of SGD can be represented as just another type of computation graph, so the limit can be taken straightforwardly, leading to an infinite-width limit distinct from the NTK/GP one, and one in which the features computed by intermediate neurons can change. He also shows that this limit outperforms both finite-width nets and NTK/GP, and learns non-trivial feature embeddings. Two caveats: this ‘tensor program limit’ is much more difficult to compute than NTK/GP, so he’s only actually able to run experiments on networks with very few layers and/or linear activations. And the scaling used to take the limit is actually different from that used in practice. Nevertheless, I think this represents the best theoretical attempt yet to capture the non-kernel learning that seems to be going on in neural nets.
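A rough numerical illustration of the difference between the NTK limit and a feature-learning limit, for a two-layer ReLU net only. This is not Yang's code or experiment; the assumptions are mine: a toy regression task, full-batch gradient descent, and ‘mean-field’ scaling (output divided by width instead of sqrt(width), learning rate multiplied by width), which for a single hidden layer trained with SGD matches the μP scaling up to constants, as I understand it. The quantity tracked is the relative change of the hidden-layer activations between initialization and the end of training.

```python
import numpy as np

def feature_movement(width, output_exp, steps=300, lr=0.5, seed=0):
    """Train f(x) = v . relu(W x) / width**output_exp with full-batch gradient
    descent on 0.5 * mean squared error, and return ||H_T - H_0|| / ||H_0||,
    where H holds the hidden-layer activations on the training inputs.

    output_exp = 0.5: NTK parametrization.
    output_exp = 1.0: mean-field scaling (assumed to match muP for a single
    hidden layer, up to constants).
    """
    rng = np.random.default_rng(seed)
    n, d = 20, 5
    X = rng.standard_normal((n, d)) / np.sqrt(d)   # drawn first: same data for every width
    y = np.sin(3.0 * X[:, 0])                      # arbitrary toy target
    W = rng.standard_normal((width, d))            # first-layer weights, O(1) entries
    v = rng.standard_normal(width)                 # output weights, O(1) entries
    scale = width ** -output_exp
    # Rescale the step so the function-space step size is width-independent.
    eta = lr * width ** (2 * output_exp - 1)
    H0 = np.maximum(X @ W.T, 0.0)
    for _ in range(steps):
        pre = X @ W.T                              # (n, width) pre-activations
        H = np.maximum(pre, 0.0)
        r = scale * (H @ v) - y                    # residuals, shape (n,)
        grad_v = scale * (H.T @ r) / n
        grad_W = (scale * np.outer(r, v) * (pre > 0)).T @ X / n
        v -= eta * grad_v
        W -= eta * grad_W
    HT = np.maximum(X @ W.T, 0.0)
    return np.linalg.norm(HT - H0) / np.linalg.norm(H0)

for width in (100, 10_000):
    print(width, "NTK:", feature_movement(width, 0.5),
          "mean-field:", feature_movement(width, 1.0))
```

Under NTK scaling the printed number should shrink roughly like 1/sqrt(width), while under mean-field scaling it should stay roughly constant as width grows, which is the ‘features can change’ property in miniature. The factor `width ** (2 * output_exp - 1)` is what keeps the two runs comparable in how fast the output itself moves.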

To be clear, I think that the NTK/GP models have been a great advance in our understanding of neural networks, and it’s good to see people on LW discussing them. However, there are some important phenomena they fail to explain. They’re a good first step, but a comprehensive theoretical account of neural nets has yet to be written.[4]

  1. You might be wondering how it’s possible for the output function to change if none of the individual neurons’ functions change. Basically, since the output is the sum of N things, each of them only needs to change by O(1/N) to change the output by O(1), so in the wide-width limit they don’t change at all. (See also my discussion with johnswentworth in the comments.) ↩︎

  2. Sort of. A more exact statement might be that the NTK can technically do transfer learning, but only trivially so, i.e. it can only ‘transfer’ to a new task to the extent that the new task coincides with its original one. See this comment. ↩︎

  3. In fairness to the NTK/GP, they also haven’t been tried as much on more difficult problems because they scale worse than neural nets in the number of data points D: just computing the kernel between all pairs of points takes D^2 kernel evaluations, and an exact solve of the resulting D×D linear system costs O(D^3) on top of that. So it’s possible that they could do better if people had the chance to try them out more, iterate improved versions, and so on. ↩︎

  4. I’ll confess that I would personally find it kind of disappointing if neural nets were mostly just an efficient way to implement some fixed kernels, when it seems possible that they could be doing something much more interesting, perhaps even implementing something like a simplicity prior over a large class of functions, which I’m pretty sure NTK/GP can’t do. ↩︎