Sparse trinary weighted RNNs as a path to better language model interpretability
Epistemic status: Strongly arguing for what I feel is a neglected approach. May somewhat overstate the case and fail to adequately steelman counterarguments. I hope and expect that readers will point out flaws in my logic.
Currently, large transformers with dense floating point weights are state of the art in language modeling. Despite recent progress by Anthropic and others, they remain difficult to understand.
Why are they hard to understand?
Lack of native ordering structure: setting aside attention masking, transformers have no native concept of token ordering. They likewise have no well defined state between tokens. Transformers are not truly sequence models, even with positional encoding, but text is sequential.
Large magnitude, continuous weights: although weights sometimes behave somewhat like boolean logic, we cannot easily treat them as such.
Dense matrices: within a matrix of nonzero weights, all outputs are affected by all inputs to a nonzero extent. It is difficult to know which connections are important. Information leaks in difficult-to-track ways.
We can fix these issues.
RNNs have well-defined state. They process tokens in order. We can feed a token and observe how the state is or is not changed. This is easier to visualize.
Weights can be quantized to trinary and activations to binary.
Weights can be extremely sparsified. (Which is essentially trinary quantization with a preference for 0 weights.)
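To make the last two points concrete, here is a minimal sketch of threshold-based trinary quantization. The function names and the `threshold` knob are my own illustrative assumptions, not taken from any particular paper or library; real methods typically learn or tune thresholds and scales during training.

```python
def ternarize(weights, threshold):
    """Map each float weight to -1, 0, or +1.

    Weights whose magnitude is at or below `threshold` become 0.
    `threshold` is an illustrative knob, not a value from any paper.
    """
    return [0 if abs(w) <= threshold else (1 if w > 0 else -1) for w in weights]


def sparsity(trits):
    """Fraction of weights that are exactly zero."""
    return sum(1 for t in trits if t == 0) / len(trits)


weights = [0.9, -0.05, 0.3, -0.7, 0.02, -0.25]  # made-up example weights
loose = ternarize(weights, threshold=0.1)
tight = ternarize(weights, threshold=0.5)
print(loose, sparsity(loose))
print(tight, sparsity(tight))
```

Raising the threshold is the "preference for 0 weights": the same quantizer, applied more aggressively, yields a sparser matrix.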
Once we quantize our RNN to trinary weights we can replace addition and tanh activations with adder trees and digital comparators and apply boolean logic simplification tools to further simplify the logic. Now we have a fully combinatorial logic finite state machine. Understanding the behavior of such a machine intuitively feels far more tractable than understanding current transformers. Once we can understand the internal workings of large language models, we can likely use this understanding to improve safety/alignment.
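As a rough illustration of why this conversion is possible, consider a single neuron with trinary weights and binary activations. The sketch below assumes activations in {0, 1} and a simple integer threshold in place of tanh; both are my simplifying assumptions, and the names are hypothetical. Because the input space is finite, the neuron is just a truth table.

```python
from itertools import product


def neuron(inputs, trits, threshold):
    """Binary-in, binary-out neuron with weights in {-1, 0, +1}.

    The sum counts active inputs on +1 wires and subtracts active inputs
    on -1 wires (the job of an adder tree). The comparison with
    `threshold` plays the role of a digital comparator.
    """
    total = sum(x * t for x, t in zip(inputs, trits))
    return 1 if total >= threshold else 0


def truth_table(trits, threshold):
    """Enumerate every binary input pattern and the neuron's output."""
    n = len(trits)
    return {bits: neuron(bits, trits, threshold)
            for bits in product((0, 1), repeat=n)}


trits = [1, 0, -1]  # the middle input has weight 0, so it is disconnected
table = truth_table(trits, threshold=1)
for bits, out in table.items():
    print(bits, "->", out)
```

With these weights the neuron fires exactly when the first input is 1 and the third is 0, so it reduces to `x0 AND NOT x2`, and the zero-weight input drops out entirely. This is the kind of reduction a logic simplification tool would find automatically.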
Objection: We cannot train, quantize and sparsify RNNs of comparable accuracy to modern transformers.
RNNs are less powerful than transformers, hard to train on GPUs, and quantization kills accuracy. What good is it if we can understand small toy language models? They are not what is being deployed. If it costs 10x more to train, few will actually use it and SoTA will stay uninterpretable.
A: We use transformers instead of large RNNs because they are easier to train on GPUs, not because they are a fundamentally better architecture. RWKV-LM appears to be making progress on large RNN training on GPUs. TernaryBERT is one example of successful trinary quantization of transformers (although it does not trinarize activations). Some work has also been done on trinary weights in RNNs. I suspect that both could be significantly improved upon with further effort. Trinary weights are sparse binary, and once we convert activations to binary, more weights will die.
Objection: Even if we have such a SoTA LLM, it would not improve interpretability.
It is still a huge mess of gates. This is hard to understand.
A: If the whole model is logic gates end to end, we can potentially apply SAT solvers to it. This feels like a significant improvement over current transformer interpretability. Also, I suspect that if we sparsify it sufficiently and apply logic simplification transformations, the logic graph will tend to fall apart into mostly independent modules which can be studied independently.
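To give a flavor of the questions one could pose to a gate-level model, here is a toy sketch. It uses exhaustive search as a stand-in for a SAT solver, which only works for a handful of inputs; a real solver would be needed at any interesting scale. The circuit and all names are hypothetical examples of mine.

```python
from itertools import product


def find_witness(predicate, n_inputs):
    """Return an input assignment satisfying `predicate`, or None.

    Exhaustive search over 2**n_inputs assignments. This is a stand-in
    for a SAT solver and is only viable for tiny circuits.
    """
    for bits in product((0, 1), repeat=n_inputs):
        if predicate(bits):
            return bits
    return None


def circuit(bits):
    """Hypothetical 3-input gate network: (a AND b) OR (a AND NOT c)."""
    a, b, c = bits
    return (a & b) | (a & (1 - c))


def simplified(bits):
    """Candidate simplification: a AND (b OR NOT c)."""
    a, b, c = bits
    return a & (b | (1 - c))


# Query 1: can the output ever be 1 while input a is 0?
leak = find_witness(lambda bits: circuit(bits) == 1 and bits[0] == 0, 3)

# Query 2: is there any input on which the two circuits disagree?
mismatch = find_witness(lambda bits: circuit(bits) != simplified(bits), 3)

print(leak, mismatch)
```

Both queries return `None` here: no input fires the output without `a`, and no input distinguishes the circuit from its simplification. "Unsatisfiable" answers like these are proofs about behavior over all inputs, which is what I mean by an improvement over probing a dense float model with samples.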
Objection: LLM interpretability does not ultimately help with alignment.
A: We have little idea what sort of model will end up being dangerous over the coming decade, but currently, LLMs are the frontier of general reasoning capabilities. Practicing on what we currently have seems better than not practicing.
Objection: Trinarised LLMs will make inference very cheap. This will accelerate capabilities more than it improves interpretability, and will therefore be net negative.
A: Assuming that it does not make training cheaper, cheaper inference is probably not a large accelerant to AI timelines? (I am not at all sure about this.)
I make four assertions:
It is possible to create state of the art language models implemented as a finite state machine of combinatorial logic.
Given such a combinatorial logic FSM, it would be easier to interpret than current state of the art transformer based LLMs.
Creating and releasing such an LLM will not cause significant harm.
Good interpretability of LLMs reduces AI risk.
I am looking for disagreements with any of these assertions, but particularly with the last two.
I thank Joel Burget and Nathan Showell for insightful discussion and comments on a draft of this post.
The mask probably does provide some additional sense of ordering. I do not yet understand transformers well enough to understand the full implications of the mask. But still, this is not very helpful for interpretability.
Or LSTM. When the state is binary, the distinction becomes less clear.
A traditional RNN is a state machine with continuous state, and if we quantize the state to binary, it becomes a proper finite state machine.
I do not have a solid argument for this suspicion, only vague intuition about generalization requiring information destruction. I intend to make it fall apart, by whatever means are needed.
While SAT solvers may have difficulty scaling to billions or even millions of edges, force directed graph layout algorithms have been able to scale to hundreds of millions of edges on a single GPU for some years, allowing us to visually identify interesting subgraphs to which SAT solvers may be applied.
I am aware that both generating quantized LLMs and interpreting them once generated have a low probability of success, but I intend to attempt both as long as A) the attempt will not cause significant harm, and B) the result, if successful, is actually useful.