Maybe I was too harsh on deep learning theory (three days ago)
A few days ago, I reviewed a paper titled “There Will Be a Scientific Theory of Deep Learning”. In it, I expressed appreciation to the authors for writing the piece, but skepticism about the stronger forms of its titular claim.
Since then I’ve spoken with various past collaborators (via text and in person), and read or reread quite a few deep learning theory papers, including the bombshell Zhang et al. 2016 and Nagarajan et al. 2019 papers that I wrote about on LessWrong.
And the thing is, parts of the infinite width/depth-limit work turned out to be much more interesting than I thought they were. Perhaps I have judged deep learning theory (a bit) too harshly.
(Thanks to Dmitry Vaintrob and Kaarel Hänni in particular for conversations on this topic. Much of this was in private, but was spurred on by a comment from Dmitry that can be found on LessWrong. Also thanks again to the authors of the scientific theory of deep learning paper, which provided a bunch of references to papers that I had forgotten or been previously unaware of.)
A lot of my impression of the infinite-width and depth-limit work comes from the neural tangent kernel (NTK) / neural network Gaussian Process (NNGP) line of work. This line of work starts from Radford Neal’s 1994 paper, where he noted that an infinitely wide single-hidden-layer neural network with random weights is a Gaussian Process. In 2017/2018, this was extended to deep neural networks: Lee et al.[1] showed that a randomly initialized deep neural network is, if you take a certain type of infinite-width limit, also a Gaussian Process. This was then extended by the Neural Tangent Kernel work, which described the training dynamics of these infinitely wide neural networks and showed that training is equivalent to kernel gradient descent with a fixed kernel (the eponymous Neural Tangent Kernel). This allowed people to derive convergence properties and nontrivial generalization bounds.
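To make the fixed-kernel claim concrete, here is a rough sketch (my notation, glossing over the precise conditions in the papers). The Neural Tangent Kernel of a network $f(x;\theta)$ is

$$\Theta(x, x') = \nabla_\theta f(x;\theta_0)^\top\, \nabla_\theta f(x';\theta_0),$$

and in the infinite-width limit, gradient-flow training on squared loss makes the network’s outputs evolve by kernel gradient descent with this kernel frozen at its initialization value:

$$\frac{d f_t(x)}{dt} = -\sum_{i=1}^{n} \Theta(x, x_i)\,\big(f_t(x_i) - y_i\big).$$

Because the kernel never changes during training, convergence and generalization questions reduce to well-studied questions about kernel regression.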
Unfortunately, while beautiful, it was definitely not how neural networks learn. In the NTK limit, the network behaves as if it were doing linear regression in a feature space whose dimension is the number of neural net parameters. Notably, there is no feature learning, and only the last layer weights are updated by a noticeable amount. Unsurprisingly, this does not describe the behavior of neural networks; small (finite width) neural networks have been shown to outperform their equivalent tangent kernels.
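Equivalently (again as a sketch in my notation), in this limit the trained network never leaves the linearization of the model around its initialization:

$$f(x;\theta) \;\approx\; f(x;\theta_0) + \nabla_\theta f(x;\theta_0)^\top (\theta - \theta_0),$$

so training only fits coefficients on the fixed, random feature map $x \mapsto \nabla_\theta f(x;\theta_0)$; the features themselves never move.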
An alternative way of taking an infinite-width limit is Mean Field Theory (MFT), as applied to neural networks. As I understand it, the basic idea behind Mean Field Theory in physics is that, instead of calculating the interactions between many objects, you replace the many-body interactions with an average “field” that captures the overall dynamics of the system. (Hence the name.) In neural network land, it turns out that you can take a different infinite-width limit in which the empirical distribution of hidden-unit parameters, viewed as a probability measure on parameter space, evolves under a deterministic flow. This was worked out around 2018 by Mei, Montanari, and Nguyen; Chizat and Bach; Rotskoff and Vanden-Eijnden; and Sirignano and Spiliopoulos.
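For a two-layer network, the mean-field setup looks roughly like this (a sketch of the standard formulation, not the exact notation of any one of those papers):

$$f(x) \;=\; \frac{1}{N}\sum_{i=1}^{N} a_i\, \sigma(w_i^\top x) \;\;\xrightarrow{\;N\to\infty\;}\;\; \int a\, \sigma(w^\top x)\, d\rho(a, w),$$

where $\rho$ is the distribution of single-neuron parameters $(a, w)$. In the limit, (stochastic) gradient descent on the weights becomes a deterministic gradient flow on the measure $\rho$ itself, a Wasserstein gradient flow on the loss viewed as a functional of $\rho$.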
Notably, in this different infinite width limit, networks actually learn features. NTK uses 1/√N scaling, which makes parameters move only O(1/√N) during training: too small to change the effective kernel. Mean-field uses 1/N scaling, which lets parameters move Θ(1), so the kernel evolves and hidden representations change over the course of training. In MFT, the model is doing something other than glorified linear regression in a fixed random feature space. That being said, for a few years, MFT was entirely a theory of 2-layer neural networks, and it was genuinely unclear how to extend this to deeper networks.
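The difference between the two limits is, at bottom, one factor of $\sqrt{N}$ in the output scaling (sketching again, with conventions that vary across papers):

$$f_{\text{NTK}}(x) = \frac{1}{\sqrt{N}}\sum_{i=1}^{N} a_i\,\sigma(w_i^\top x), \qquad f_{\text{MF}}(x) = \frac{1}{N}\sum_{i=1}^{N} a_i\,\sigma(w_i^\top x).$$

With the $1/\sqrt{N}$ scaling, an $O(1)$ change in the output only requires each weight to move $O(1/\sqrt{N})$, so in the limit the feature distribution stays frozen; with the $1/N$ scaling, the weights have to move $\Theta(1)$ to change the output, so the features themselves evolve.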
Like most of the deep learning community, I was very impressed by the Tensor Programs work of Greg Yang, which was an extension (though not an obvious one) of the 2-layer MFT work. Yang proved a series of theorems that allowed him to create a unifying framework (the abc parameterization) for deep neural networks, in which NNGP/NTK and MFT are special cases of one family. Notably, this allowed him to derive μP (the maximal-update parameterization), which allows hyperparameter transfer across width (later work would extend this to depth as well). This is widely considered to be perhaps the clearest application (some would say the only clear application) of modern deep learning theory.
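For reference, and to give a sense of what the abc parameterization is: as I understand it, a parameterization in this family writes each weight matrix as $W^l = n^{-a_l}\, w^l$ (with $w^l$ the trainable parameter), initializes $w^l_{ij} \sim \mathcal{N}(0,\, n^{-2 b_l})$, and trains with learning rate $\eta\, n^{-c}$, where $n$ is the width. NNGP/NTK, mean-field, and μP then correspond to particular choices of the exponents $(a_l, b_l, c)$, and the theory classifies which choices give well-defined infinite-width limits and which of those limits actually learn features.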
At the time, I chalked this up to Greg Yang being a genius. Of the work itself, I remembered only μP and the toy shallow-network model Yang created that allows one to rederive it.
What I missed, and only learned in the past few days, is that Yang didn’t invent this machinery from whole cloth.[2] There was a different line of work, done by a team at Google Brain and confusingly also called mean-field theory, which studied how signals travel forward and backward through a network at initialization (though not the training dynamics). Two pioneering examples of this work are Poole et al.‘s Exponential expressivity in deep neural networks through transient chaos and Schoenholz et al.’s Deep Information Propagation. Greg Yang’s Tensor Programs work descended from this line of work, and Yang collaborated with Schoenholz and others.
Reading the papers, it’s clear how Yang’s work draws inspiration from this signal-propagation branch of MFT.[3] For example, the signal-propagation MFT work contains special cases of Greg Yang’s Master Theorem: both use the fact that, at infinite width, pre-activations are Gaussian, and both track their statistics layer by layer via a deterministic recursion on covariances.
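Concretely, the shared object is (schematically, in notation I’m paraphrasing rather than quoting from either line of work) the layer-to-layer covariance recursion

$$\Sigma^{l+1}(x, x') \;=\; \sigma_w^2\, \mathbb{E}_{z \sim \mathcal{N}(0,\, \Sigma^{l})}\big[\phi(z(x))\,\phi(z(x'))\big] + \sigma_b^2,$$

where $\phi$ is the nonlinearity and $\sigma_w^2, \sigma_b^2$ are the weight and bias variances. The signal-propagation papers study the fixed points and stability of this map (the “edge of chaos” picture), while the Master Theorem is the machinery that justifies this kind of deterministic recursion, and extensions of it, for a much broader class of computations.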
(My guess is the namespace collision is why I somehow missed this line of work; I had read up on the 2-layer training-dynamics branch of MFT, thought I had understood the relevant parts of MFT, and missed the signal-propagation branch entirely.)
I still think the strong version of “there will be a scientific theory of deep learning” (one that explains why SGD on overparameterized nets generalizes, why particular architectural choices work, and what particular features get learned) is far from established. I also think that the Zhang et al. and Nagarajan et al. results remain genuinely damning for the older PAC-Bayes / uniform-convergence approaches. I don’t think anything in the MFT/TP literature addresses the core puzzles those papers raised (they address very different questions in very different regimes).
But a lot of my pessimism about deep learning theory came from feeling like there was not a coherent intellectual tradition that could point to concrete wins. Insofar as MFT (both the signal-propagation and training-dynamics branches) and Tensor Programs constitute such a tradition (as opposed to primarily the work of a single brilliant individual), there is at least one tradition in deep learning theory that has produced cumulative progress and made falsifiable predictions that have been confirmed in practice. That deserves more credit than I was giving the field.
Oops.
I sometimes run into bright young AI people with plenty of interest in math but not so much in engineering, who ask me what they should study. Beyond the very basics of deep learning (e.g. optimizers, basic RL theory), I used to give a shrug and say “Maybe computation in superposition? Maybe Singular Learning Theory?”. From now on I think I’ll start my answer with “probably the Mean Field Theory and Tensor Programs work.”
- ^
Edited to add: As Adria says in a comment below, this was also shown in concurrent work by Matthews et al 2018: https://arxiv.org/abs/1804.11271.
- ^
Yes, this was obvious in retrospect. As I say later in the post, oops.
- ^
Of course, there’s a lot of work on initializations (e.g. the Xavier and He initializations), most of which relied on 1) tracking the forward and backward passes, 2) heuristic calculations of the scale of various parameters, and 3) an independence assumption between parameters and gradients at initialization, and were substantially less sophisticated than the MFT work. While the μP Tensor Programs paper also provides these heuristic calculations (allowing one to rederive μP from a toy model), it formalizes these assumptions with tools from free probability and random matrix theory.
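(As a concrete example of the kind of heuristic calculation meant here: for a ReLU network, demanding that the forward-pass variance stay constant from layer to layer gives the He initialization, roughly $\mathrm{Var}(W_{ij}) = 2/\text{fan-in}$, while balancing the analogous forward and backward conditions gives the Xavier initialization, roughly $\mathrm{Var}(W_{ij}) = 2/(\text{fan-in} + \text{fan-out})$.)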
- ^
The closest work I’m aware of that touches on this is Rubin, Seroussi, and Ringel’s Grokking as a First Order Phase Transition in Two Layer Networks.
Wait, I thought the singular learning theory stuff already did this part? (Just the “why SGD on overparameterized nets generalizes” part, not the “why particular architectural choices work” or “what particular features get learned” parts.) Neural networks being singular means that the parameter–function map is not a one-to-one correspondence, which means that simpler hypotheses (those that need fewer parameters to be specified or can correct “errors” in some parameters) occupy more volume in parameter-space and are easier for SGD to find first, such that training is implicitly doing a form of minimum-description-length program induction (with the learning coefficient being the measure of complexity rather than the parameter count). Is that too “qualitative” to count as an answer (because the architecture and feature prediction parts are the true test of knowledge)?
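(For concreteness, the quantity I have in mind is Watanabe’s learning coefficient $\lambda$. My rough understanding of the relevant asymptotics, stated from memory rather than quoted from a specific source, is that the Bayesian free energy expands as

$$F_n \;\approx\; n L_n(w_0) + \lambda \log n - (m-1)\log\log n + O_p(1),$$

where $w_0$ is an optimal parameter and $m$ is the multiplicity; for a regular model $\lambda = d/2$ with $d$ the parameter count, so a smaller $\lambda$ plays the role of a smaller effective parameter count.)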
I don’t think SLT explains why SGD on overparametrised nets generalises. I actually think “overparametrised” is a kind of classical term that we shouldn’t be using anymore, but anyway, SLT does provide a mathematical framework in which Bayesian learning with very large models need not generalise poorly, which would have been a very useful prior for generations of theorists thinking about deep learning to have (and if they had, then many years of confusion might have been avoided imo). However, as Lucius pointed out in his comment, Bayesian learning is not SGD and even if that gap is bridged, just because generalisation is possible doesn’t mean you have a sufficient explanation of why it is actually happening.
I wrote about this at some length in this old comment which you might find useful.
Having said that, I think that in time we’ll see the gap between Bayesian learning and SGD is not as profound as it seems right now. While some of the ways they could be directly related are not true, it will turn out to be true I think that comparison of probabilities of regions of parameter space according to the Bayesian posterior do tend to govern the statistics of SGD trajectories to a significant degree. At that point a lot of the qualitative conclusions one might draw from the basic picture of SLT will just be good descriptions of what SGD is up to; but that work remains in the future.
Not quite. SLT is for a specific subcase of Bayesian learning only, not SGD. Maybe more importantly for this point, it also doesn’t really show why neural network priors are good, just that neural network priors strongly favour some solutions over others.
Some SLT-adjacent stuff is pretty strongly suggestive of a proper answer, but I don’t think there’s a proper full proof of what we want in generality written up publicly yet.
some more thoughts quickly:
SLT studies the limit as the number of data points goes to infinity. this is the opposite of overparametrized! also this seems at least on the face of it like a bizarre setting for studying generalization, which is about guessing correctly after seeing only a small amount of data
i think the subset of weight space corresponding to a function is generally not well thought of as a small local region around any weight vector, especially not in the overparametrized case.
edit added later: however it’s plausible that with the mean field prior scaling you get a contribution to [the prior on a function] from ( of) a macroscopic ball around a certain weight vector which is of [the prior on that function] but however in a weaker sense a decent chunk of the entire prior on that function anyway.
[1]
so in that sense there might be an interesting semi-local thing going on. sorry i’ll need to think more about this
imo it’s good to scrap a bunch of the story given. the part i’d keep is “in cases where NN bayesianism has good generalization properties [2], simpler functions [3] generally have more prior weight than more complicated functions” (but this is roughly an obvious logical truth that has basically nothing to do with SLT?), and then the question is “why do simpler functions have more prior?”, ie “why do simpler functions have some combination of implementations having smaller weight norm and taking more volume”, and i think one is better off approaching that question basically from scratch. (also this is all about understanding NN bayes. SGD is a meaningfully different thing.)
sorry i’m aware this is very much not clear but making it clear would be a bunch of work and i’m not going to do it atm
[2] which probably isn’t always. eg it’s probably pretty false for the prior scaling that gives NNGP in the wide limit. a good story would be able to “see” this difference between differently scaled gaussian priors
[3] btw the correct meaning of simplicity in this setting is not kolmogorov complexity, but instead circuit size
SLT is not about a limit as the number of data points goes to infinity. Or at least, it is about such a limit insofar as talking about the mean of a random variable is about “studying the limit as the number of data points goes to infinity” which is not how one would normally talk about such things. In particular, I think when people use this phrasing they are (either deliberately or not) making a comparison to “infinite width limits” and I do not think this is a correct analogy.
So there’s two things here: one is the relation between an empirical loss $L_n$ and the population loss $L$, which is the expectation of $L_n$ with respect to the dataset. The mean of a random variable (in this case a function) is an idealised quantity never encountered in practice, for sure. However, to describe a theory that is organised around means of random variables as being “about” infinite limits seems at odds with how most people would think about statistics. Most of the nontrivial mathematical content of SLT is exactly about accounting for the difference between $L$ and the actual $L_n$ you encounter in the real world (and as you say, generalisation has this form): the conceptual content of the theory is a surprising fact, that geometry of the mean object governs the generalisation behaviour at finite $n$, and these are not some exotic effects that are only visible at enormous $n$, as many of the examples in Watanabe’s textbooks will show you.
The other is the role of $n \to \infty$ in the asymptotic expansions that characterise some of the central theorems in SLT. Here it is true that one would expect these analyses to be more correct as $n$ becomes larger, and at any value of $n$ one cannot a priori rule out that “lower order terms” in fact contribute more than higher order terms. But this is not a phenomenon or situation unique to SLT, and indeed has the same shape as applications of Laplace approximations everywhere (and is a situation also commonly encountered in mathematical physics). At the end of the day such asymptotic expansions are commonly used across applications of mathematics to real world phenomena, they are highly successful, and theory alone cannot tell you when they are valid: you have to actually do experiments. But it is strange to rule out in principle the application of such techniques to study phenomena at finite $n$, as though this was a theory whose domain of applicability is restricted to enormous $n$. There are separate questions one can ask about effective theories etc, and finite $n$ phenomena that are not accounted for by the asymptotic expansions, I don’t mean to dismiss any of that as unimportant (and indeed we think about that kind of thing and continue to work on it). However, I want to push back against some oversimplified characterisation of SLT as a “theory about the infinite $n$ limit”.
I agree it would be strange to strongly rule out the application of this at finite $n$ in principle. I think I’ve made a fine simple defeasible argument against this for the overparametrized case, and I think the claim will turn out to be true with more careful investigation, but I haven’t really carried out a rigorous version of this investigation and certainly haven’t spelled the reasoning out, in my comment above or elsewhere publicly. I agree the argument I gave above is not definitive.
I think you agree that a central crux is whether the $\lambda \log n$ and $(m-1)\log\log n$ terms the SLT expansion takes to be leading order in the posterior are in fact larger than the “lower order terms”, in the overparametrized case?[1]
I would weakly guess that in most of the reasonable cases of NN bayesianism, correct generalization happens at $n$ meaningfully smaller than when the terms SLT considers “lower order” in the comparison of posteriors become actually smaller than these “leading terms”[2]. I would guess this strongly about the strongly overparametrized case (the infinite width limit), and reasonably strongly about “mildly overparametrized” cases. Do you agree that if these guesses turn out to be right, then that would be a strong argument against thinking in SLT terms for understanding NNbayes generalization? Ditto for specifically the overparametrized case. If you agree this would be a strong argument, would you like to register opposite guesses to some or all of these?[3]
(Actually, imo the nicest version of NN bayesianism is where learning is just conditioning the NN prior on outputs on train data inputs being closer than some fixed precision, say when the labels are $\pm 1$, to the given outputs. Idk how to think about this in SLT terms, given that in the realizable case there is then just a full-dimensional region of solutions, with number-of-mistakes-loss being just const locally around all but the points on the boundary? If SLT aspires to be the theory of NN bayesian learning, it should help us think about this case. But idk what “the terms SLT considers high order” should even mean here, ie I’m not sure what the pro-SLT side of the statements above should even be for this case.[4])
not a very important disagreement but: I disagree that you have to do experiments to figure out if these expansions are valid. [5] There are at least two other things you can do: you can think heuristically about whether the lower-order terms are actually smaller in cases of interest, and you can work out some cases theoretically.
[1] more precisely/correctly, it’s whether the difference between posteriors of different functions with perfect agreement with the data so far mostly comes from these terms + these are useful auxiliary variables to track to make sense of generalization
[2] by this I mean that simpler functions are already preferred earlier
[3] If you were to agree these would be strong arguments but were not willing to make these guesses, then I’d feel like you’re responding to the claim not-P with sth like “it can remain defensible to think P”, and I guess I’d agree that can remain defensible, but we should figure out what’s true? :)
[4] Btw I think that Dmitry Vaintrob and I can probably prove a generalization result for this case, for certain scalings of the prior “stronger”/”smaller” than the scaling which gives NNGP, at least modulo restricting the support to weight vectors satisfying some “robustness condition” which I’m still unsure if I should think of as being contrived or not (there could also be some less contrived version — work remains on this). Anyway, that argument doesn’t think about local geometry at all. This isn’t published yet, it isn’t fully worked out, it could turn out to not work, and you don’t have to believe me, but hopefully this partly explains where I’m coming from.
[5] Also, I wouldn’t be surprised if the experiments one does are not actually measuring what one thinks they are measuring, eg I wouldn’t be surprised if one’s “RLCT estimator” is not actually close to the RLCT.
Do you know about the double descent phenomenon? It contradicts the bias-variance tradeoff. Welch Labs did a very good breakdown recently.
Also contemporaneously Alexander G. de G. Matthews et al.! And, while less famous, that paper was better in one way: it took the limit of the width of all layers simultaneously, instead of one by one. That is, Lee et al was a statement about:
$$\lim_{\text{width}\to\infty}\Big[\, b_2 + W_2\,\mathrm{nonlinearity}\big(\lim_{\text{width}\to\infty}\,[\,W_1 x + b_1\,]\big)\Big]$$
whereas Matthews et al was a statement about:
$$\lim_{\text{width}\to\infty}\Big[\, b_2 + W_2\,\mathrm{nonlinearity}(W_1 x + b_1)\Big]$$
which is more complicated
Good citation, that paper seems to have slipped my recollection (probably because it’s less famous, as you said). Added a footnote to clarify.
Hey, commendations on sharing your update.
Another similar line of work I like is Roberts+Yaida’s “Principles of Deep Learning Theory”—this is a similar-in-spirit approach to MFT, but they perturb around a different limit and get feature-learning as a finite-width effect. I haven’t studied MFT to compare the validity of the two; my guess is MFT is the more relevant description. PDLT at least does a very good job modernizing the NTK approach and connecting to the older literature. I’m a fanboy as it was my gateway drug for learning theory lol.
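(If it helps to have a one-line version of the difference: my understanding is that PDLT keeps the width $n$ large but finite and computes corrections to the infinite-width/NTK description perturbatively in $1/n$, with the depth-to-width ratio $L/n$ acting as the effective expansion parameter, so feature learning shows up as a finite-width effect rather than in the strict $n \to \infty$ limit.)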
Yes! I was familiar with PDLT as well, and I do think it’s a similar-in-spirit approach to MFT (if not a continuation of the signal-propagation MFT work). Thanks for the pointer.
This inspired me to read up on mu-Parametrization, and though it’s interesting (I might end up using it next time I’m training deep neural networks and want to find good hyperparameters at scale), it really doesn’t seem like anything that could potentially lead to deep safety-relevant understanding. It’s just solving for the values of the parameters that keep activation magnitudes stable. I don’t know about tensor programs or the other things you mentioned. Maybe there’s a case for those.
Yeah, the main application of deep learning theory is muP; the main application to safety is probably not that. muP by itself is not relevant to safety, except insofar as it means people don’t use NTKs as their toy model (though they probably weren’t anyways).
I bring up muP because it’s the main (or only) concrete application of deep learning theory; insofar as you dismiss theory b/c there are no wins, muP is evidence against that conclusion, in the same way that a lack of other wins is evidence for it.