This is wacky but seems like a plausible thing a model might do: by computing these XORs, the model would be able, in later layers, to make use of arbitrary boolean functions of early-layer features.
This seems less wacky from a reservoir-computing perspective, and given the lack of recurrency: early on, you have no idea what you will need*, so you compute a lot of random complicated stuff up front, which is then available for the self-attention in later layers to do gradient descent on, picking out the useful features tailored to that exact prompt and then trying to execute the inferred algorithm in the time left. ‘Arbitrary boolean functions of early-layer features’ is not necessarily something I’d have expected the model to compute, but put that way, it does seem like a bunch of those hanging around could be useful, no?
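A minimal sketch of that reservoir-computing intuition (purely illustrative: the toy features, widths, and random ReLU expansion below are my assumptions for the example, not anything measured in an actual model). A fixed random nonlinear expansion computed ‘up front’ makes the XOR of two binary features available to a later linear readout, whereas the same readout on the raw features alone can never get it:

```python
# Illustrative sketch of the reservoir-computing intuition (hypothetical toy
# setup, not the original experiment): a fixed *random* nonlinear expansion
# computed up front makes XOR linearly readable downstream, while a linear
# readout on the raw binary features cannot represent it.
import numpy as np

rng = np.random.default_rng(0)

# Two binary "early-layer features" and their XOR as the target.
X = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=float)
y = np.array([0, 1, 1, 0], dtype=float)

def linear_fit_mse(features, target):
    """Least-squares linear readout (with bias); return training MSE."""
    A = np.hstack([features, np.ones((len(features), 1))])
    w, *_ = np.linalg.lstsq(A, target, rcond=None)
    return float(np.mean((A @ w - target) ** 2))

# 1. Linear readout directly on the two features: cannot represent XOR.
print("raw features MSE:   ", linear_fit_mse(X, y))   # ~0.25, stuck

# 2. Fixed random nonlinear expansion ("speculative" early computation),
#    then the same linear readout: XOR becomes trivially readable.
W = rng.normal(size=(2, 64))
b = rng.normal(size=64)
H = np.maximum(X @ W + b, 0.0)   # random ReLU features, never trained
print("random reservoir MSE:", linear_fit_mse(H, y))  # ~0.0
```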
This is analogous to all the interpretability work which keeps finding a theme of ‘compute or retrieve everything which might be relevant up front, then throw most of it away and do some computation with the remainder to make the final prediction’. This is further consistent with the observations about sparsity & pruning: if most of a NN is purely speculative, and exists to enable learning in case one of those random functions turns out to be useful, then once learning has been accomplished, of course you can stop screwing around with speculatively computing loads of useless random things.
EDIT: if this perspective is ‘the reason’, I guess it would predict that at very low loss or perfect prediction, these XOR features would go away unless they had a clear-cut use. It would also predict that if you trained in a sparsity-inducing fashion or did knowledge distillation, the model would be incentivized to get rid of them. You might also be able to ablate/delete them in some way, and show that transfer learning and in-context learning become worse? I would expect any kind of recurrent or compressive approach to do less of every kind of speculative computation, so maybe that would work too.
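As a toy illustration of the ablate-and-degrade prediction (entirely synthetic: the directions, activations, and readouts below are made up for the example, not extracted from a real model), one can encode two base features and their XOR along orthogonal directions, project the XOR direction out, and check that only the readouts which relied on it get worse:

```python
# Toy sketch of the ablation idea (synthetic activations, not a real model):
# encode two base features and their XOR along three orthogonal directions,
# "ablate" the XOR direction by projecting it out, and check that only the
# readout which relied on it degrades.
import numpy as np

rng = np.random.default_rng(1)
d = 32                                   # hypothetical residual-stream width

# Three orthonormal feature directions: a, b, and xor(a, b).
Q, _ = np.linalg.qr(rng.normal(size=(d, 3)))
dir_a, dir_b, dir_xor = Q.T

# Synthetic "activations": each sample writes its features onto those directions.
bits = rng.integers(0, 2, size=(1000, 2))
xor = bits[:, 0] ^ bits[:, 1]
acts = (bits[:, :1] * dir_a + bits[:, 1:2] * dir_b
        + xor[:, None] * dir_xor + 0.05 * rng.normal(size=(1000, d)))

def readout_acc(activations, direction, labels):
    """Accuracy of a fixed linear readout along one direction."""
    return float(np.mean((activations @ direction > 0.5) == labels))

# Ablation: project the XOR direction out of every activation vector.
ablated = acts - np.outer(acts @ dir_xor, dir_xor)

print("xor readout, before ablation:", readout_acc(acts, dir_xor, xor))
print("xor readout, after  ablation:", readout_acc(ablated, dir_xor, xor))  # ~chance
print("'a' readout, after  ablation:", readout_acc(ablated, dir_a, bits[:, 0]))  # unharmed
```

The analogous experiment on a real model would presumably locate candidate XOR directions with probes, project them out of the residual stream mid-forward-pass, and then measure whether in-context or transfer performance actually degrades.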
* anything in the early layers of a Transformer has no idea what is going on: there is no recurrent state carrying over anything from previous timesteps which could summarize the task or the relevant state for you, and the tokens/activations are all still very local because you’re still early in the forward pass. This is part of why I think the blind spot might come from early layers throwing away data prematurely, making irreversible errors.