Very interesting! Like some others here, I am curious about what’s going on with the lora mask results. Copy/pasting something I wrote in another venue:
about the lora mask, i kind of wonder whether lora masking simply makes any arbitrary task harder to learn, because any given vector in the KV cache now “means a different thing” depending on whether it is before or after a special separator token. this is out-of-distribution and the model has to spend some of the lora capacity learning to “translate back and forth” between these two slightly different embedding spaces in cases where this is useful.
another argument for why lora mask might just make everything harder: finetuning typically only changes the weights a little bit, especially with lora. this means that for quadratic (and higher-order) interactions between weight tensors, most of the effect is in the first-order terms capturing how the interaction changes when one tensor is changed but the others stay fixed. for instance, if there’s something like K^T * Q and it gets updated to (K + dK)^T * (Q + dQ) where the “d” terms are small, most of the change is in the dK^T * Q and K^T * dQ parts of the expanded sum, not in the dK^T * dQ part.
with lora masking, the dK^T * Q contribution is 0 when the model is trying to do (honest persona → assistant persona) attention lookups, and so it can only change those lookups through the K^T * dQ term. so, even if it is just trying to retrieve information that was already in the original keys (but not retrieved by the original attention pattern), it only has half the capacity available for improving that process relative to the no-lora-mask case. (in other words: in normal finetuning, the model can of course look up new information from the original key vectors by changing the queries, but it can also achieve the same goal by changing the keys so that that same information now matches the original queries. but the lora-mask model can only do the former, not the latter.) and the same argument applies not just to key-query interaction but also to any interaction with key or value embeddings on one side, such as the interactions of values with tensors in later layers.
Both of these arguments involve factors other than the honest character trying to retrieve information that was never expressed in KV in the first place; instead, the issue would be that it’s just harder to route the information already present in KV to other parts of the model in a way that helps the model perform the task. (Possibly both of these things are happening in some mixture.)
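To make the quadratic-expansion point concrete, here is a toy numpy sketch. (The shapes and the size of the updates are made up, and the "mask" only appears as a comment; this illustrates the algebra, not anything about the actual model.)

```python
# Toy illustration of the first-order argument (shapes are arbitrary).
import numpy as np

rng = np.random.default_rng(0)
d, n = 64, 16                      # head dim, number of positions
K = rng.normal(size=(d, n))        # "original" keys, one column per position
Q = rng.normal(size=(d, n))        # "original" queries
eps = 1e-2                         # LoRA-style updates are small
dK = eps * rng.normal(size=(d, n))
dQ = eps * rng.normal(size=(d, n))

full_change = (K + dK).T @ (Q + dQ) - K.T @ Q
first_order = dK.T @ Q + K.T @ dQ   # the two cross terms
second_order = dK.T @ dQ            # the quadratic term

print(np.linalg.norm(first_order) / np.linalg.norm(full_change))   # ~1.0
print(np.linalg.norm(second_order) / np.linalg.norm(full_change))  # ~eps, tiny

# Under the LoRA mask (as I understand it), dK is zero for positions before the
# separator token, so attention from post-separator queries to those positions
# can only change through the K^T @ dQ half of the first-order terms.
```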
I had similar intuitions. Basically, the LoRA patch may prevent the model from learning new ways for queries after the <split-personality-token> (Q + dQ) to attend to keys before that token, since those keys stay at K rather than becoming K + dK.
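For concreteness, here is a rough sketch of how I picture the masked LoRA forward pass (hypothetical PyTorch, not necessarily how it is actually implemented; the point is just that the low-rank delta is gated off for positions before the split token):

```python
# Hypothetical sketch of a position-masked LoRA linear layer (my reading of the
# setup; the real implementation may differ).
import torch
import torch.nn as nn

class MaskedLoRALinear(nn.Module):
    def __init__(self, base: nn.Linear, rank: int = 8):
        super().__init__()
        self.base = base                      # frozen pretrained projection
        self.A = nn.Linear(base.in_features, rank, bias=False)
        self.B = nn.Linear(rank, base.out_features, bias=False)
        nn.init.zeros_(self.B.weight)         # start as a no-op, as in standard LoRA

    def forward(self, x: torch.Tensor, lora_mask: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq, in_features); lora_mask: (batch, seq), 1.0 for positions
        # after the <split-personality-token> and 0.0 before it.
        delta = self.B(self.A(x))             # low-rank update
        return self.base(x) + lora_mask.unsqueeze(-1) * delta
```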
So the hypothesis is that the dK^T * Q term is doing important work. Here is an attempt to flesh that out into empirical predictions.
Important to note that we are currently applying LoRA to the MLP layers too. Not a big issue for the argument imo.
We could try training LoRA on everything except K. Prediction: it performs roughly like full LoRA with the mask, since both are missing the dK^T * Q term.

Then we could try K-only LoRA (or K + MLP). Here, masking should severely harm performance. The optimistic prediction is that unmasked K-only LoRA gets most of the way to full-LoRA performance, if dK^T * Q was indeed doing some heavy lifting.
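Concretely, with something like the peft library the two ablations would just be different target_modules lists. (Module names below assume a Llama-style architecture; this is only a sketch, not the actual training config.)

```python
# Sketch of the two ablations via the peft library (module names assume a
# Llama-style model; swap in the right names for the actual architecture).
from peft import LoraConfig

# 1) LoRA on everything except the key projections: no dK anywhere.
no_k_config = LoraConfig(
    r=16,
    target_modules=["q_proj", "v_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj"],
)

# 2) K-only LoRA (optionally plus MLP): the whole update has to flow through dK,
#    so masking it should hurt a lot if dK^T * Q really matters.
k_only_config = LoraConfig(
    r=16,
    target_modules=["k_proj"],  # or ["k_proj", "gate_proj", "up_proj", "down_proj"]
)
```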
Having said that, I don’t have super strong intuitions for dismissing the dK^T * dQ term out of hand. If that term were doing the job, it would seem less interesting than if dK^T * Q were single-handedly amplifying the signal.
Yes! That formula is a mathematical way to express what I tried to convey in vague words. Thank you!
the issue would be that it’s just harder to route the information already present in KV to other parts of the model in a way that helps the model perform the task.
I agree. I also think this is the most likely explanation. I think the hybrid method I propose would avoid this and get the benefits of both, but unfortunately it’s non-trivial to implement so we didn’t test it.
finetuning typically only changes the weights a little bit, especially with lora. this means that for quadratic (and higher-order) interactions between weight tensors, most of the effect is in the first-order terms capturing how the interaction changes when one tensor is changed but the others stay fixed
This is kind of what I thought as well.
I’m mainly curious about how “chaotic” the residual stream is.

The reason the LoRA patch seemed promising to me initially was that I thought of the residual stream as very chaotic: if we use normal LoRA, it will instantly cause the model to have very divergent thoughts from the untrained model.

But if this is not true, then maybe it doesn’t matter. If having access to the unadulterated thoughts of the base model* were advantageous, the LoRA weights could just learn to not mess them up (and this would not be hard for them to do).
* (a model that has been through pretraining, instruction tuning, RLHF, and maybe several other stages of training, but has not been through our additional, very specific SP training. I feel there’s a need for a new word here.)
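If we wanted to put a number on “chaotic”, one rough check would be to compare the hidden states of the tuned and untuned models on the same prompts. A hypothetical sketch using transformers and peft (the model name and adapter path are placeholders):

```python
# Hypothetical check of how far a LoRA-tuned model's residual stream drifts from
# the untuned model on the same prompt (model name and adapter path are placeholders).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

name = "base-model-name"                      # placeholder
tok = AutoTokenizer.from_pretrained(name)
base = AutoModelForCausalLM.from_pretrained(name)
tuned = PeftModel.from_pretrained(
    AutoModelForCausalLM.from_pretrained(name),
    "path/to/lora-adapter",                   # placeholder
)

inputs = tok("some probe prompt", return_tensors="pt")
with torch.no_grad():
    h_base = base(**inputs, output_hidden_states=True).hidden_states
    h_tuned = tuned(**inputs, output_hidden_states=True).hidden_states

for layer, (a, b) in enumerate(zip(h_base, h_tuned)):
    sim = torch.nn.functional.cosine_similarity(a, b, dim=-1).mean()
    print(layer, sim.item())                  # ~1.0 everywhere = not very "chaotic"
```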
if having access to the unadulterated thoughts of the base model* was advantageous, the LoRA weights can just learn to not mess them up
I agree in principle, but I’m worried that the training data might not be good enough: any low-quality training data that slips through our quality filtering would cause more damage if we don’t use lora-masking than if we do.