This looks really interesting. Is there any intention to use these insights to design even more interpretable models than transformers? I’ve had the feeling that transformer models may be too general-purpose for their own good, in terms of training efficiency and interpretability. What I mean is this: fully connected neural networks technically have at least as much computational/representational power as convolutional neural networks, yet they are much harder to train for general image processing than their more constrained counterparts, which take full advantage of translational equivariance. In the same way, transformer-type language models might not have enough built-in constraints to be efficient enough for AGI.
In these models, some representation of every token is compared against a representation of every other token encountered so far, which gives quadratic complexity in sequence length for every attention layer at runtime. The data is then transformed further after each attention block, creating what is effectively a new sequence of abstract tokens, each of which is some hard-to-interpret combination of the token representations in the level below. The only information added to the vector representation of each token, as far as I understand it, is some vector encoding the relative position of the tokens within the sequence (which itself necessitates a special normalization step later on). Otherwise, it’s up to the model to learn to assign implicit roles/functions to each token through the attention module. This hides away the information about what each token is doing, which a more constrained model could instead represent explicitly.
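To make the quadratic cost concrete, here is a minimal single-head sketch of scaled dot-product self-attention (toy dimensions, random weights, no masking or multi-head machinery). The `n_tokens × n_tokens` score matrix is exactly the all-pairs comparison described above:

```python
import numpy as np

def self_attention(X, Wq, Wk, Wv):
    """X: (n_tokens, d_model) token representations. Returns the new,
    mixed representations produced by one attention layer."""
    Q, K, V = X @ Wq, X @ Wk, X @ Wv
    # Every token's query is compared against every token's key:
    # an (n_tokens, n_tokens) score matrix, hence O(n^2) in length.
    scores = Q @ K.T / np.sqrt(K.shape[-1])
    # Row-wise softmax turns scores into attention weights.
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    # Each output token is a weighted mixture of all input tokens --
    # the "hard-to-interpret combination" of the level below.
    return weights @ V

rng = np.random.default_rng(0)
n, d = 5, 8
X = rng.normal(size=(n, d))
Wq, Wk, Wv = (rng.normal(size=(d, d)) for _ in range(3))
out = self_attention(X, Wq, Wk, Wv)
print(out.shape)  # one abstract token per input token: (5, 8)
```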
It seems to me that we could do better. For instance, suppose we had a model with “slots” (I’m thinking something like CPU registers) that it would fill in with token vectors as it went along. The LM would learn to assign functions like “subject”, “verb”, “direct object modifier”, etc., with one part learning which tokens should get routed to which slots, another part learning to predict which slot will be filled next (e.g., which part of speech comes next) based on what information has already been filled in and on learned rules of syntax, and another part predicting what information should go into the still-empty slots (allowing it to “read between the lines”). That last part could also be hooked up to a long-term memory database of learned relations that it could fill in and update as it accumulates training data (something like what DeepMind published recently: https://deepmind.com/research/publications/2021/improving-language-models-by-retrieving-from-trillions-of-tokens).
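As a purely hypothetical sketch of the routing part of this idea (all names, shapes, and the scoring rule here are my own illustrative assumptions, not a real architecture): a router scores each incoming token vector against a small set of named slot keys and writes the token into the best-matching slot, so each semantic role has one explicit, inspectable location.

```python
import numpy as np

# Illustrative slot names; a trained model would assign roles implicitly.
SLOT_NAMES = ["subject", "verb", "direct_object", "modifier"]

class SlotRouter:
    def __init__(self, d_model, n_slots, seed=0):
        rng = np.random.default_rng(seed)
        # Random keys stand in for learned slot queries.
        self.keys = rng.normal(size=(n_slots, d_model))
        self.slots = np.zeros((n_slots, d_model))
        self.filled = np.zeros(n_slots, dtype=bool)

    def route(self, token_vec):
        # Score the token against every slot key and write it into the
        # best match; a real model would learn this routing end to end.
        scores = self.keys @ token_vec
        i = int(np.argmax(scores))
        self.slots[i] = token_vec
        self.filled[i] = True
        return SLOT_NAMES[i]

rng = np.random.default_rng(1)
router = SlotRouter(d_model=8, n_slots=len(SLOT_NAMES))
for tok in rng.normal(size=(3, 8)):  # three toy "tokens"
    router.route(tok)

# The interpretability hook: filled and "unused" slots are directly
# visible, rather than hidden in attention weights.
print([name for name, f in zip(SLOT_NAMES, router.filled) if f])
```

The point of the sketch is only that the model's state lives at fixed, named addresses, so probing "what does the model think the subject is?" becomes a lookup rather than a weight-interpretation exercise.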
Although the role of each slot would be arbitrary and assigned only through training, I think this type of architecture would make it easier to extract semantic roles for the tokens it reads in, since those semantic roles would have explicit locations where they can always be found. In other words, you could use the same method to find out what the LM thinks about the who, what, when, where, why, and how of what it reads or says (along with what it thinks about everything it doesn’t read or say, by looking into the “unused” slots). With transformers, this would be much more difficult, since semantic roles are assigned much more implicitly and a lot could be hiding in the weights.
That was just an idea, but I think that interpretability will come more easily the more we constrain the language model with both functional and representational modularity. Perhaps the work you do could help inform what sorts of constraints would be most effective to that end.