I think that when we talk about regularization in the context of “efficiency”, we should include implicit regularization of this type and any other phenomenon that encourages lower-weight-norm solutions.
It does seem like small initialisation is a regularisation of a sort, but it seems pretty hard to imagine how it might first allow a memorising solution to be fully learned, and then a generalising solution. Maybe gradient descent in general tends to destroy memorising circuits for reasons like the “edge of stability” stuff Dmitry alludes to. But is the low initial weight norm playing much of a role there? Maybe there’s a norm-dependent factor?
It does seem like small initialisation is a regularisation of a sort, but it seems pretty hard to imagine how it might first allow a memorising solution to be fully learned, and then a generalising solution.
“Memorization” is more parallelizable and incrementally learnable than a generalizing solution, and it can occur in a subspace of the parameter space that is orthogonal to the one the generalizing solution uses.
So one handwavy model I have of this is that a low initial parameter norm starts the model closer to the generalizing solution than a larger one would, so a higher proportion of the full parameter space ends up being used for the generalizing solution.
The actual training dynamics here would be that the model first memorizes a high proportion of the training data while simultaneously learning a lossy/inaccurate version of the generalizing solution in another subspace (with the initialization norm affecting how much the memorization is “prioritized”, i.e. how many dimensions it uses). Then, later in training, the generalization can “win out” (due to greater stability / higher performance / other regularization).
In particular, in most of the unregularized models we see that do generalize (and I think also the ones in Omnigrok), grokking happens early, usually before full memorization (so it’s “grokking” in the redefinition I gave above).
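To make the norm part of the handwavy model a bit more concrete, here is a minimal toy sketch. It is not the grokking setup: it assumes an overparameterized linear regression (3 random data points, 10 weights) fit by plain gradient descent with no weight decay, and the function name `train` and all of the sizes and scales are made up for illustration. It only shows that, among the many weight vectors that fit the training data exactly, gradient descent from a smaller initialization ends up at a lower-norm one.

```python
import math
import random


def train(init_scale, steps=5000, n=3, d=10, seed=0):
    """Fit n random points with d > n weights by plain gradient descent.

    Returns (final training loss, final weight norm). All sizes and the
    seed are illustrative assumptions, not taken from any experiment.
    """
    rng = random.Random(seed)
    X = [[rng.gauss(0, 1) for _ in range(d)] for _ in range(n)]
    y = [rng.gauss(0, 1) for _ in range(n)]
    # Same random direction for every run; only its scale changes.
    direction = [rng.gauss(0, 1) for _ in range(d)]
    w = [init_scale * g for g in direction]

    # Step size 1 / trace(X^T X) is below 2 / lambda_max, so GD is stable.
    lr = 1.0 / sum(v * v for row in X for v in row)

    def residuals(weights):
        return [sum(wi * xi for wi, xi in zip(weights, row)) - yi
                for row, yi in zip(X, y)]

    for _ in range(steps):
        r = residuals(w)
        # Gradient of 0.5 * sum(r^2) with respect to w is X^T r.
        grad = [sum(ri * row[j] for ri, row in zip(r, X)) for j in range(d)]
        w = [wi - lr * gj for wi, gj in zip(w, grad)]

    loss = 0.5 * sum(ri * ri for ri in residuals(w))
    norm = math.sqrt(sum(wi * wi for wi in w))
    return loss, norm


for scale in (0.01, 1.0, 3.0):
    loss, norm = train(scale)
    print(f"init scale {scale:>4}: train loss {loss:.2e}, final norm {norm:.3f}")
```

In this toy problem the effect is exact rather than statistical: the gradient always lies in the span of the data rows, so gradient descent never changes the component of the initial weights orthogonal to that span, and that component's norm carries straight through to the final solution. Whether anything like this carries over to memorizing versus generalizing circuits in a real network is exactly the open question in this thread.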