Rethinking Batch Normalization

Yes­ter­day we saw a glimpse into the in­ner work­ings of batch nor­mal­iza­tion, a pop­u­lar tech­nique in the field of deep learn­ing. Given that the effec­tive­ness of batch nor­mal­iza­tion has been demon­strated be­yond any rea­son­able doubt, it may come as a sur­prise that re­searchers don’t re­ally know how it works. At the very least, we sure didn’t know how it worked when the idea was first pro­posed.

One might first con­sider that last state­ment to be un­likely. In the last post I out­lined a rel­a­tively sim­ple the­o­ret­i­cal frame­work for ex­plain­ing the suc­cess of batch nor­mal­iza­tion. The idea is that batch nor­mal­iza­tion re­duces the in­ter­nal co­vari­ate shift (ICS) of lay­ers in a net­work. In turn, we have a neu­ral net­work that is more sta­ble, and ro­bust to large learn­ing rates, and al­lows much quicker train­ing.

And this was the stan­dard story in the field for years, un­til a few re­searchers de­cided to ac­tu­ally in­ves­ti­gate it.

Here, I hope to con­vince you that the the­ory re­ally is wrong. While I’m fully pre­pared to make ad­di­tional epistemic shifts on this ques­tion in the fu­ture, I also fully ex­pect to never shift my opinion back.

When I first read the original batch normalization paper, I felt like I really understood the hypothesis. It felt simple enough, was reasonably descriptive, and intuitive. But I didn’t get a perfect visual of what was going on; I sort of hand-waved the step where ICS contributed to an unstable gradient step. Instead I, like the paper, argued by analogy: since controlling for covariate shift had been known for decades to help training, a technique to reduce internal covariate shift is a natural extension of this concept.

It turned out this the­ory wasn’t even a lit­tle bit right. It’s not that co­vari­ate shifts aren’t im­por­tant at all, but that the en­tire idea is based on a false premise.

Or at least, that’s the im­pres­sion I got while read­ing Shibani San­turkar et al.’s How Does Batch Nor­mal­iza­tion Help Op­ti­miza­tion? Whereas the origi­nal batch nor­mal­iza­tion pa­per gave me a sense of “I kinda sorta see how this works,” this pa­per com­pletely shat­tered my in­tu­itions. It wasn’t just the weight of the em­piri­cal ev­i­dence, or the the­o­ret­i­cal un­der­pin­ning they pre­sent; in­stead what won me over was the sur­gi­cal pre­ci­sion of their re­but­tal. They saw how to for­mal­ize the the­ory of im­prove­ment via ICS re­duc­tion and tested it on BatchNorm di­rectly. The the­ory turned out to be sim­ple, in­tu­itive, and false.

In fair­ness, it wasn’t laz­i­ness that pro­hibited re­searchers from reach­ing our cur­rent level of un­der­stand­ing. In the origi­nal batch nor­mal­iza­tion pa­per, the au­thors in­deed pro­posed a test for mea­sur­ing batch nor­mal­iza­tion’s effect on ICS.

The problem was instead twofold: their method for measuring ICS was inadequate, and their testing conditions failed to consistently apply their proposed mechanism for how ICS reduction was supposed to work. More importantly, however, they didn’t even test the theory that ICS reduction contributed to performance gains. Instead their argument was based on a simple heuristic: we know that covariate shifts are bad, we think that batch normalization reduces ICS, and we also know batch normalization improves performance characteristics; therefore batch normalization works due to ICS reduction. As far as I can tell, most of the articles that came after the original paper just took this heuristic at face value, citing the paper and calling it a day.

And it’s not a bad heuristic, all in all. But perhaps it’s a tiny bit telling that on yesterday’s post, LessWrong user crabman was able to anticipate the true reason for batch normalization’s success, defying both my post and the supposed years that it took researchers to figure this stuff out. Quoth crabman,

I am imag­in­ing this in­ter­nal co­vari­ate shift thing like this: the neu­ral net­work to­gether with its loss is a func­tion which takes pa­ram­e­ters θ as in­put and out­puts a real num­ber. Large in­ter­nal co­vari­ate shift means that if we choose ε>0, perform some SGD steps, get some θ, and look at the func­tion’s graph in ε-area of θ, it doesn’t re­ally look like a plane, it’s more curvy like.

In fact, the above paragraph doesn’t actually describe internal covariate shift, but instead the smoothness of the loss function around some parameters θ. I concede, it is perhaps possible that this is really what the original researchers meant when they coined the term internal covariate shift. It is therefore also possible that this whole critique of the original theory is based on nothing but a misunderstanding.

But I’m not buy­ing it.

Take a look at how the origi­nal pa­per defines ICS,

We define In­ter­nal Co­vari­ate Shift as the change in the dis­tri­bu­tion of net­work ac­ti­va­tions due to the change in net­work pa­ram­e­ters dur­ing train­ing.

This defi­ni­tion can’t merely re­fer to the smooth­ness of the gra­di­ent around θ. For ex­am­ple, the gra­di­ent could be ex­tremely bumpy and have sharp edges and yet ICS could be ab­sent. Can you think of an ex­am­ple of a neu­ral net­work like this? Here’s one: think of a net­work with just one layer whose loss func­tion is some ex­tremely con­torted shape be­cause its ac­ti­va­tion func­tion is some crazy non-lin­ear func­tion. It wouldn’t be smooth, but its in­put dis­tri­bu­tion would be con­stant over time, given that it’s only one layer.

I can in­stead think of two in­ter­pre­ta­tions of the above defi­ni­tion for ICS. The first in­ter­pre­ta­tion is that ICS sim­ply refers to the change of ac­ti­va­tions in a layer dur­ing train­ing. The sec­ond in­ter­pre­ta­tion is that this defi­ni­tion speci­fi­cally refers to change of ac­ti­va­tions caused by changes in net­work pa­ram­e­ters at pre­vi­ous lay­ers.

This is a sub­tle differ­ence, but I be­lieve it’s im­por­tant to un­der­stand. The first in­ter­pre­ta­tion al­lows ease of mea­sure­ment, since we can sim­ply plot the mean and var­i­ance of the in­put dis­tri­bu­tions of a layer dur­ing train­ing. This is in fact how the pa­per (sec­tion 4.1) tests batch nor­mal­iza­tion’s effect on ICS. But re­ally, the sec­ond in­ter­pre­ta­tion sounds closer to the hy­poth­e­sized mechanism for how ICS was sup­posed to work in the first place.

On the level of experimentation, the crucial part of the above definition is the part that says “change [...] due to the change in network parameters.” Merely measuring the change in activation distributions over time is insufficient. Why? Because the hypothesis was that if activation distributions change too quickly, then a layer will have its gradient pushed into a vanishing or exploding region. In the first interpretation, a change over time could still be slow enough for each layer to adapt appropriately. Therefore, we need additional information to discover whether ICS is occurring in the way that is described.

To mea­sure ICS un­der the sec­ond in­ter­pre­ta­tion, we have to mea­sure the coun­ter­fac­tual change of pa­ram­e­ters — in other words, the amount that the net­work ac­ti­va­tions change as a re­sult of other pa­ram­e­ters be­ing al­tered. And we also need a way of see­ing whether the gra­di­ent is be­ing pushed into ex­treme re­gions as a re­sult of these pa­ram­e­ters be­ing changed. Only then can we see whether this par­tic­u­lar phe­nomenon is ac­tu­ally oc­cur­ring.

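The newer paper comes down heavily in favor of this interpretation, and adds a level of formalization on top of it. Their definition focuses on measuring the difference between two different gradients: one gradient with all of the previous layers altered by backpropagation, and one gradient where all of the previous layers have been left unaltered. Specifically, let $\mathcal{L}$ be the loss function for a neural network of $k$ layers with parameters $W_1^{(t)}, \dots, W_k^{(t)}$ at time $t$. Then, their definition of ICS for the activation $i$ and time $t$ is $\|G_{t,i} - G'_{t,i}\|_2$, where

$$G_{t,i} = \nabla_{W_i^{(t)}} \mathcal{L}(W_1^{(t)}, \dots, W_k^{(t)}; x^{(t)}, y^{(t)})$$

$$G'_{t,i} = \nabla_{W_i^{(t)}} \mathcal{L}(W_1^{(t+1)}, \dots, W_{i-1}^{(t+1)}, W_i^{(t)}, W_{i+1}^{(t)}, \dots, W_k^{(t)}; x^{(t)}, y^{(t)})$$

and $(x^{(t)}, y^{(t)})$ is the batch of input-label pairs used to train the network at time $t$.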
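To make the gradient-difference idea concrete, here is a minimal sketch of the measurement on a toy problem. Everything specific in it is my own assumption for illustration rather than the paper's setup: a two-layer scalar "network" `y_hat = w2 * (w1 * x)`, a squared-error loss, hand-derived gradients, and made-up data. It computes the second layer's gradient twice, once with the first layer at its current value and once with the first layer already updated, and reports the gap.

```python
# Toy illustration only: a two-layer scalar network y_hat = w2 * (w1 * x).
# The network, loss, data, and learning rates are assumptions for this sketch.

def grad_w1(w1, w2, batch):
    """Gradient of the mean squared-error loss with respect to the first layer."""
    return sum((w2 * w1 * x - y) * w2 * x for x, y in batch) / len(batch)

def grad_w2(w1, w2, batch):
    """Gradient of the mean squared-error loss with respect to the second layer."""
    return sum((w2 * w1 * x - y) * w1 * x for x, y in batch) / len(batch)

def ics_layer2(w1, w2, batch, lr):
    """Gap between the layer-2 gradient before and after layer 1 is updated."""
    # Gradient with every layer at its current (time t) value.
    g = grad_w2(w1, w2, batch)
    # Same gradient, but with the earlier layer already moved to time t+1.
    w1_next = w1 - lr * grad_w1(w1, w2, batch)
    g_prime = grad_w2(w1_next, w2, batch)
    return abs(g - g_prime)

batch = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0)]  # hypothetical data with y = 2x
ics_small_step = ics_layer2(0.5, 0.5, batch, lr=0.01)
ics_large_step = ics_layer2(0.5, 0.5, batch, lr=0.1)
print(ics_small_step, ics_large_step)
```

In this toy case a larger update to the first layer moves the second layer's gradient further, which is the kind of "shift caused by earlier layers" the definition is trying to isolate.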

The first thing to note about this defi­ni­tion is that it al­lows a clear, pre­cise mea­sure­ment of ICS, which is based solely on the change of the gra­di­ent due to shift­ing pa­ram­e­ters be­neath a layer dur­ing back­prop­a­ga­tion.

What Shibani San­turkar et al. found when they ap­plied this defi­ni­tion was a bit shock­ing. Not only did batch nor­mal­iza­tion fail to de­crease ICS, in some cases it even in­creased it when com­pared to naive feed­for­ward neu­ral net­works. And to top that off, they found that even in net­works where they ar­tifi­cially in­creased ICS, perfor­mance barely suffered.

In one experiment they applied batch normalization to each hidden layer in a neural network, and at each step, they added noise after the batch normalization transform in order to induce ICS. This noise wasn’t just Gaussian noise either. Instead they chose the noise such that it was a different Gaussian at every time step and every layer, with the Gaussian parameters (specifically mean and variance) themselves varying according to yet another meta Gaussian distribution. What they discovered was that even though this increased measured ICS dramatically, the time it took to train the networks to the baseline accuracy was almost identical to regular batch normalization.
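A rough sketch of what such a noisy layer could look like is below. I am assuming additive noise, and the names `noisy_batch_norm`, `meta_mean_std`, and `meta_std_scale` are hypothetical, as are the particular meta-distribution parameters; the paper's exact noise model may differ. The point is only the structure: normalize, draw fresh noise parameters for this step, then perturb every activation.

```python
import random

def batch_norm(values, eps=1e-5):
    """Normalize one batch of activations to zero mean and unit variance."""
    mean = sum(values) / len(values)
    var = sum((v - mean) ** 2 for v in values) / len(values)
    return [(v - mean) / (var + eps) ** 0.5 for v in values]

def noisy_batch_norm(values, rng, meta_mean_std=1.0, meta_std_scale=1.0):
    """Batch norm followed by Gaussian noise whose parameters are themselves random."""
    normalized = batch_norm(values)
    # Draw this step's noise mean and standard deviation from a "meta" Gaussian
    # (the meta parameters here are assumed values, chosen for illustration).
    noise_mean = rng.gauss(0.0, meta_mean_std)
    noise_std = abs(rng.gauss(1.0, meta_std_scale))
    # Perturb every activation with noise from that step-specific Gaussian.
    return [v + rng.gauss(noise_mean, noise_std) for v in normalized]

rng = random.Random(0)
activations = [1.0, 2.0, 3.0, 4.0]
for step in range(3):
    shifted = noisy_batch_norm(activations, rng)
    # The batch mean is no longer pinned at zero and drifts from step to step.
    print(step, sum(shifted) / len(shifted))
```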

And re­mem­ber that batch nor­mal­iza­tion ac­tu­ally does work. In all of the ex­per­i­ments for mere perfor­mance in­creases, batch nor­mal­iza­tion has passed the tests with fly­ing col­ors. So clearly, since batch nor­mal­iza­tion works, it must be for a differ­ent rea­son than sim­ply re­duc­ing ICS. But that leaves one ques­tion re­main­ing: how on Earth does it work?

I have already hinted at the rea­son above. The an­swer lies in some­thing even sim­pler to un­der­stand than ICS. Take a look at this plot.

Imag­ine the red ball is rol­ling down this slope, ap­ply­ing gra­di­ent de­scent at each step. And con­sider for a sec­ond that the red ball isn’t us­ing any mo­men­tum. It sim­ply looks at each step which di­rec­tion to move and moves in that di­rec­tion in pro­por­tion to the slope at that point.

A problem immediately arises. Depending on how we choose our learning rate, the red ball could end up getting stuck almost immediately. If the learning rate is too low, then it will probably get stuck on the flat plane to the right of it. And in practice, if its learning rate is too high, then it might move over to another valley entirely, getting itself into an exploding region.

The way that batch nor­mal­iza­tion helps is by chang­ing the loss land­scape from this bumpy shape into one more like this.

Now it no longer matters that much what we set the learning rate to. The ball will be able to find its way down even if it’s too small. What used to be a flat plane has now been rounded out such that the ball will roll right down.
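Here is a small sketch of the same intuition in one dimension. The two loss functions are stand-ins I made up (a quadratic bowl with and without sinusoidal bumps), not anything measured from a real network, and the ball uses plain gradient descent with no momentum.

```python
import math

def rough_loss(x):
    """A bumpy stand-in landscape: a bowl with sinusoidal ripples (assumed shape)."""
    return 0.5 * x * x + 2.0 * math.sin(5.0 * x)

def rough_grad(x):
    return x + 10.0 * math.cos(5.0 * x)

def smooth_loss(x):
    """The same bowl with the ripples removed."""
    return 0.5 * x * x

def smooth_grad(x):
    return x

def descend(grad, x0, lr, steps=500):
    """Plain gradient descent without momentum; returns the final position."""
    x = x0
    for _ in range(steps):
        x -= lr * grad(x)
    return x

results = {}
for lr in (0.01, 0.1, 1.0):
    rough_final = rough_loss(descend(rough_grad, 3.0, lr))
    smooth_final = smooth_loss(descend(smooth_grad, 3.0, lr))
    results[lr] = (rough_final, smooth_final)
    print(f"lr={lr}: rough loss {rough_final:.3f}, smooth loss {smooth_final:.6f}")
```

On the smooth bowl every one of these learning rates ends up near the bottom. On the bumpy version, the smallest learning rate settles into the nearest ripple, well above the lowest point of the landscape, and the larger ones bounce between ripples.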

The specific way that the paper tests this hypothesis is by applying pretty standard ideas from the real analysis toolkit. In particular, the researchers attempted to measure the Lipschitzness of the loss function around the parameters for various types of deep networks (both empirically and theoretically). Formally, a function $f$ is $L$-Lipschitz if $|f(x_1) - f(x_2)| \le L\|x_1 - x_2\|$ for all $x_1$ and $x_2$. Intuitively, this is a measure of how smooth the function is. The smaller the constant $L$, the fewer and less extreme the jumps the function makes over small intervals in some direction.
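As a sanity check on the definition, one can estimate a Lipschitz constant empirically by sampling pairs of points and taking the largest observed ratio of output change to input change. The sketch below does this for two made-up one-dimensional functions; the sampling scheme and the function `estimate_lipschitz` are my own illustration, and the estimate is only a lower bound on the true constant over the interval.

```python
import math
import random

def estimate_lipschitz(f, low, high, samples=2000, seed=0):
    """Largest observed |f(x1) - f(x2)| / |x1 - x2| over random pairs (a lower bound)."""
    rng = random.Random(seed)
    best = 0.0
    for _ in range(samples):
        x1, x2 = rng.uniform(low, high), rng.uniform(low, high)
        if x1 != x2:
            best = max(best, abs(f(x1) - f(x2)) / abs(x1 - x2))
    return best

def smooth_f(x):
    return 0.5 * x * x

def bumpy_f(x):
    # Same bowl with ripples added; an assumed stand-in for a rough loss surface.
    return 0.5 * x * x + 2.0 * math.sin(5.0 * x)

smooth_estimate = estimate_lipschitz(smooth_f, -3.0, 3.0)
bumpy_estimate = estimate_lipschitz(bumpy_f, -3.0, 3.0)
print(smooth_estimate, bumpy_estimate)
```

The rippled function needs a noticeably larger constant, which matches the intuition that a smaller $L$ means a tamer landscape.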

This way of think­ing about the smooth­ness of the loss func­tion has the ad­van­tage of in­clud­ing a rather nat­u­ral in­ter­pre­ta­tion. One can imag­ine that the mag­ni­tude of some gra­di­ent es­ti­mate is a pre­dic­tion of how much we ex­pect the func­tion to fall if we move in that di­rec­tion. We can then eval­u­ate how good we are at mak­ing pre­dic­tions across differ­ent neu­ral net­work schemes and across train­ing steps. When gra­di­ent pre­dic­tive­ness was tested, there were no sur­prises — the net­works with batch nor­mal­iza­tion had the most pre­dic­tive gra­di­ents.
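One simple way to put a number on this, sketched below under my own assumptions, is to compare the drop in loss that the gradient predicts for a step (the first-order estimate, learning rate times squared gradient) with the drop that actually happens. The paper's own predictiveness measurements are more involved; the functions and step size here are invented for illustration.

```python
import math

def prediction_error(loss, grad, x, lr):
    """Gap between the first-order predicted loss change and the actual change."""
    g = grad(x)
    predicted_change = -lr * g * g               # linear extrapolation along -gradient
    actual_change = loss(x - lr * g) - loss(x)   # what really happens after the step
    return abs(actual_change - predicted_change)

def smooth_loss(x):
    return 0.5 * x * x

def smooth_grad(x):
    return x

def bumpy_loss(x):
    # Assumed stand-in for a rough landscape: the same bowl plus ripples.
    return 0.5 * x * x + 2.0 * math.sin(5.0 * x)

def bumpy_grad(x):
    return x + 10.0 * math.cos(5.0 * x)

smooth_error = prediction_error(smooth_loss, smooth_grad, 3.0, lr=0.05)
bumpy_error = prediction_error(bumpy_loss, bumpy_grad, 3.0, lr=0.05)
print(smooth_error, bumpy_error)
```

On the smooth bowl the gradient's forecast is off by only a second-order term, while on the rippled surface the same step size produces a much larger miss.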

Perhaps even more damning is that not only did the loss function become more smooth, the gradient landscape itself became more smooth, a property known as β-smoothness. This had the effect of not only making the gradients more predictive of the loss, but the gradients themselves were easier to predict in a certain sense: they were fairly consistent throughout training.
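For reference, the standard textbook form of this property (in my notation, not quoted from the paper) says that a differentiable function $f$ is $\beta$-smooth when its gradient is itself $\beta$-Lipschitz:

$$\|\nabla f(x_1) - \nabla f(x_2)\| \le \beta \|x_1 - x_2\| \quad \text{for all } x_1, x_2.$$

So Lipschitzness bounds how fast the loss can change, while $\beta$-smoothness bounds how fast the gradient can change.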

Perhaps the way that batch normalization works is by simply smoothing out the loss function. At each layer we are just applying some normalizing transformation which helps remove extreme points in the loss function. This has the additional prediction that other transformation schemes will work just as well, which is exactly what the researchers found. Before, the fact that we added the learned parameters γ and β was confusing, since it wasn’t clear how this contributed to ICS reduction. Now, we can see that ICS reduction shouldn’t even be the goal, perhaps shedding light on why this works.

In fact, there was pretty much noth­ing spe­cial with the ex­act way that batch nor­mal­iza­tion trans­forms the in­put, other than the prop­er­ties that con­tribute to smooth­ness. And given that so many more meth­ods have now come out which build on batch nor­mal­iza­tion de­spite us­ing quite differ­ent op­er­a­tions, isn’t this ex­actly what we would ex­pect?

Is this the way batch nor­mal­iza­tion re­ally works? I’m no ex­pert, but I found this in­ter­pre­ta­tion much eas­ier to un­der­stand, and also a much sim­pler hy­poth­e­sis. Maybe we should ap­ply Oc­cam’s ra­zor here. I cer­tainly did.

In light of this discussion, it’s also worth reflecting once again that the argument “We are going to be building the AI so of course we’ll understand how it works” is not a very good one. Clearly the field can stumble on solutions that work, and yet the reason why they work can remain almost completely unknown for years, even when the answer is hiding in plain sight. I honestly can’t say for certain whether this happens a lot, or too much. I only have my one example here.

In the next post, I’ll be tak­ing a step back from neu­ral net­work tech­niques to an­a­lyze gen­er­al­iza­tion in ma­chine learn­ing mod­els. I will briefly cover the ba­sics of statis­ti­cal learn­ing the­ory and will then move to a fram­ing of learn­ing the­ory in light of re­cent deep learn­ing progress. This will give us a new test bed to see if old the­o­ries can ad­e­quately adapt to new tech­niques. What I find might sur­prise you.