This paper claims to sample the Bayesian posterior of NN training, but I think it’s wrong.
“What Are Bayesian Neural Network Posteriors Really Like?” (Izmailov et al. 2021) claims to have sampled the Bayesian posterior of some neural networks conditional on their training data (CIFAR-10, MNIST, IMDB-type stuff) via Hamiltonian Monte Carlo (HMC) sampling. A grand feat if true! Actually crunching Bayesian updates over a whole training dataset for a neural network that isn’t incredibly tiny is an enormous computational challenge. But I think they’re mistaken, and their sampler actually isn’t covering the posterior properly.
They find that neural network ensembles trained by Bayesian updating, approximated through their HMC sampling, generalise worse than neural networks trained by stochastic gradient descent (SGD). This would be incredibly surprising to me if it were true. Bayesian updating is prohibitively expensive for real-world applications, but if you can afford it, it is the best way to incorporate new information. You can’t do better.[1]
This is in the genre of papers and takes that were around a few years back, which argued that the then still quite mysterious ability of deep learning to generalise was primarily due to some advantageous bias introduced by SGD, or momentum, or something along those lines. In the sense that SGD/momentum/whatever supposedly diverged from Bayesian updating in a way that was better rather than worse.
I think these papers were wrong, and that the generalisation ability of neural networks actually comes from their architecture, which assigns exponentially more weight configurations to simple functions than to complex functions. Most training algorithms will therefore tend to favour simple updates and find simple solutions that generalise well, just because there are exponentially more weight settings for simple functions than for complex ones. This is what Singular Learning Theory talks about. From an algorithmic information theory perspective, I think this happens for reasons similar to why exponentially more binary strings correspond to simple programs than to complex programs on Turing machines.
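To make the volume intuition concrete, here’s a small sketch of the kind of experiment that illustrates it (my own toy setup, not from the paper): sample Gaussian weights for a tiny tanh network on two ±1 inputs and tally which of the 16 possible Boolean functions each draw implements. In runs like this, the constant and single-input functions tend to soak up vastly more weight volume than XOR does.

```python
import itertools
from collections import Counter

import numpy as np

rng = np.random.default_rng(0)

# The four input points of a 2-bit Boolean function, encoded as +/-1.
inputs = np.array(list(itertools.product([-1.0, 1.0], repeat=2)))

def sampled_function(hidden=16):
    """Draw a random 2-`hidden`-1 tanh network and return its truth table."""
    W1 = rng.standard_normal((2, hidden))
    b1 = rng.standard_normal(hidden)
    w2 = rng.standard_normal(hidden)
    b2 = rng.standard_normal()
    outputs = np.sign(np.tanh(inputs @ W1 + b1) @ w2 + b2)
    return tuple(int(o) for o in outputs)   # the 4-entry truth table names the function

# Tally how much weight volume (under a Gaussian prior) each Boolean function receives.
counts = Counter(sampled_function() for _ in range(100_000))
for table, count in counts.most_common():
    print(table, count)
```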
This picture of neural network generalisation predicts that SGD and other training algorithms should all generalise worse than Bayesian updating, or at best do similarly. They shouldn’t do better.
So, what’s going on in the paper? How are they finding that neural network ensembles updated on the training data with Bayes rule make predictions that generalise worse than predictions made by neural networks trained the normal way?
My guess: Their HMC sampler isn’t actually covering the Bayesian posterior properly. They try to check that it’s doing a good job by comparing inter-chain and intra-chain variance in the functions learned. From the paper:
We apply the classic Gelman et al. (1992) “R̂” potential-scale-reduction diagnostic to our HMC runs. Given two or more chains, R̂ estimates the ratio between the between-chain variance (i.e., the variance estimated by pooling samples from all chains) and the average within-chain variance (i.e., the variances estimated from each chain independently). The intuition is that, if the chains are stuck in isolated regions, then combining samples from multiple chains will yield greater diversity than taking samples from a single chain.
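For reference, here’s a minimal sketch of that diagnostic for a single scalar quantity (one weight, or one test point’s predictive probability), following the standard Gelman-Rubin formula; the helper name is mine, not the paper’s code.

```python
import numpy as np

def r_hat(chains):
    """Gelman-Rubin potential scale reduction for one scalar quantity.

    chains: array of shape (n_chains, n_samples), the trace of the same
    scalar (one weight, or one test point's predictive probability) in
    each chain.
    """
    chains = np.asarray(chains, dtype=float)
    n = chains.shape[1]
    within = chains.var(axis=1, ddof=1).mean()       # average within-chain variance
    between = n * chains.mean(axis=1).var(ddof=1)    # between-chain variance
    pooled = (n - 1) / n * within + between / n      # pooled variance estimate
    return np.sqrt(pooled / within)
```

A value near 1 says the chains look statistically indistinguishable for that quantity; values much above ~1.1 are usually read as a failure to mix.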
They seem to think a good R̂ in function space implies that the chains are doing a good job of covering the important parts of the space. But I don’t think that’s true. You need to mix in weight space, not function space, because weight space is where the posterior lives. The map from weight space to function space is far from bijective; that’s exactly why it’s possible for simpler functions to have exponentially more prior mass than complex functions. So good mixing in function space does not necessarily imply good mixing in weight space, which is what we actually need. The chains could be jumping from basin to basin very rapidly instead of spending more time in the bigger basins corresponding to simpler solutions, like they should.
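Here’s a toy numerical illustration of that gap (my own construction, not the paper’s model): a “network” f(x) = a·b·x whose output depends only on the product of its two weights, so a whole hyperbola of weight settings realises the same function. Two chains stuck on different arms of the hyperbola look perfectly converged in function space while never agreeing in weight space.

```python
import numpy as np

def r_hat(chains):   # same split-variance diagnostic as in the sketch above
    n = chains.shape[1]
    within = chains.var(axis=1, ddof=1).mean()
    pooled = (n - 1) / n * within + chains.mean(axis=1).var(ddof=1)
    return np.sqrt(pooled / within)

rng = np.random.default_rng(0)
n = 5_000

# Toy degenerate model f(x) = a * b * x: only the product a*b is identified, so
# the posterior over (a, b) concentrates along a hyperbola a*b = const.
# Chain 1 sits near (a, b) = (4, 0.5); chain 2 sits near (a, b) = (0.5, 4).
a = np.stack([4.0 + 0.05 * rng.standard_normal(n),
              0.5 + 0.05 * rng.standard_normal(n)])
b = 2.0 / a + 0.01 * rng.standard_normal((2, n))   # both chains keep a*b close to 2

print("function-space R-hat (on a*b):", r_hat(a * b))   # close to 1: looks converged
print("weight-space R-hat (on a):    ", r_hat(a))       # huge: the chains never meet
```

Real networks have this kind of degeneracy in abundance: neuron permutations, rescaling symmetries, and the messier degeneracies that Singular Learning Theory studies.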
And indeed, they test their chains’ weight-space R̂ values as well, and find that they’re much worse:
Figure 2. Log-scale histograms of R̂ convergence diagnostics. Function-space R̂s are computed on the test-set softmax predictions of the classifiers and weight-space R̂s are computed on the raw weights. About 91% of CIFAR-10 and 98% of IMDB posterior-predictive probabilities get an R̂ less than 1.1. Most weight-space R̂ values are quite small, but enough parameters have very large R̂s to make it clear that the chains are sampling from different distributions in weight space. ... (From section 5.1) In weight space, although most parameters show no evidence of poor mixing, some have very large R̂s, indicating that there are directions in which the chains fail to mix.
... (From section 5.2) The qualitative differences between (a) and (b) suggest that while each HMC chain is able to navigate the posterior geometry the chains do not mix perfectly in the weight space, confirming our results in Section 5.1.
So I think they aren’t actually sampling the Bayesian posterior. Instead, their chains jump between modes a lot and thus unduly prioritise low-volume minima compared to high-volume minima. And those low-volume minima are exactly the kind of solutions we’d expect to generalise poorly.
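A crude picture of why that misweighting hurts (again my own toy numbers, not the paper’s): put one broad basin and one narrow basin of equal depth in a 1D “posterior”. Almost all the posterior mass lives in the broad basin, so a correct sampler should spend almost all its time there; a sampler that just hops between basins and lingers roughly equally in each behaves as if the basins had equal mass, drastically over-representing the narrow one in the resulting ensemble.

```python
import numpy as np

# Two equally deep Gaussian "basins" in a 1D toy posterior, widths 1.0 and 0.01.
w = np.linspace(-10.0, 10.0, 200_001)
dw = w[1] - w[0]
density = np.exp(-0.5 * ((w + 4.0) / 1.0) ** 2) + np.exp(-0.5 * ((w - 4.0) / 0.01) ** 2)

mass_broad = density[w < 0].sum() * dw      # mass of the wide basin (around w = -4)
mass_narrow = density[w >= 0].sum() * dw    # mass of the sharp basin (around w = +4)
total = mass_broad + mass_narrow

print(f"true posterior mass:   broad {mass_broad / total:.3f}, narrow {mass_narrow / total:.3f}")
print("equal-time-per-basin:  broad 0.500, narrow 0.500")
```

And this is one dimension; in the weight spaces of real networks the volume gap between broad and narrow basins grows exponentially with dimension.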
I don’t blame them here. It’s a paper from early 2021, back when very few people understood the importance of weight-space degeneracy properly, aside from some math professor in Japan whom almost nobody in the field had heard of. For the time, I think they were attempting something very interesting and informative. But since the paper has 300+ citations and seems like a good central example of the SGD-beats-Bayes genre, I figured I’d take the opportunity to comment on it now that we know so much more about this.
The subfield of understanding neural network generalisation has come a long way in the past four years.
Thanks to Lawrence Chan for pointing the paper out to me. Thanks also to Kaarel Hänni and Dmitry Vaintrob for sparking the argument that got us all talking about this in the first place.
[1] See e.g. the first chapters of Jaynes for why.