Could you elaborate on what you mean about unlearning techniques during pretraining?
I mean (for GradDiff) train on Loss = -NTP(tokens) if the tokens are harmful else NTP(tokens), instead of Loss = 0 if the tokens are harmful else NTP(tokens).
I don’t think datafiltering+distillation is analogous to unlearning+distillation. I do think unlearning+distillation is conceptually analogous to datafiltering+pretraining.
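To make the two objectives concrete, here is a minimal per-token sketch in plain Python. It assumes NTP means the usual next-token cross-entropy and that a per-token harmful label is available. The function names (`ntp_loss`, `graddiff_style_loss`, `filtering_style_loss`) are made up for illustration and are not taken from any paper's code.

```python
import math

def ntp_loss(logits, target):
    """Next-token-prediction loss: cross-entropy of `target` under softmax(logits)."""
    m = max(logits)  # subtract the max for numerical stability
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_z - logits[target]

def graddiff_style_loss(logits, target, is_harmful):
    """Negated NTP loss on harmful tokens (gradient ascent), plain NTP loss otherwise."""
    loss = ntp_loss(logits, target)
    return -loss if is_harmful else loss

def filtering_style_loss(logits, target, is_harmful):
    """Zero loss on harmful tokens (as if they were filtered out), plain NTP loss otherwise."""
    return 0.0 if is_harmful else ntp_loss(logits, target)

if __name__ == "__main__":
    logits = [0.0, 0.0]  # toy two-token vocabulary, uniform prediction
    print(graddiff_style_loss(logits, 0, is_harmful=True))   # negative: pushes probability away
    print(filtering_style_loss(logits, 0, is_harmful=True))  # zero: no gradient signal at all
```

The two losses agree on non-harmful tokens and differ only on harmful ones, where the first actively pushes probability away from the target and the second contributes nothing.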
Ok I see your point. I wonder if this is true in practice for unlearning techniques like RMU though. My understanding of RMU is that the logprobs are roughly “noise if detect harmful else original”, in which case filtering+distillation would be roughly the same as unlearning (except if training on noise is better than training on 0). I expect that for most tokens where an RMU base model does not produce noise, it would have a tiny KL divergence from the original model. And to the extent that your RMU train set is bad enough that RMU “misclassified” some datapoints and does not produce noise on them, I expect that if the original model would have leaked information about those datapoints, RMU will leak them too. But that’s an empirical question, and I am only ~60% sure that I am correct here. Did you run experiments to test this?
(The fact that you are using an RMU base model and using tokens from pretrain as opposed to tokens generated by the model itself matters a little bit here. I think you would get more robustness but also less distillation efficiency by fine-tuning on sequences generated by a model trained to refuse to talk about harmful topics. RMU = the noise thing + refusal, but you would not use refusal for a base model, and it would not help anyway if you used pretrain tokens because refusal is often just a few tokens deep.)
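The “noise if detect harmful else original” picture can be written down as a toy teacher distribution, which makes the KL claim easy to check on small examples. This is a sketch of the hypothesis only, not of how RMU is implemented: using a uniform distribution as the “noise”, and the names `toy_rmu_probs` and `detected_harmful`, are assumptions made for illustration.

```python
import math

def softmax(logits):
    """Convert logits to a probability distribution."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    z = sum(exps)
    return [e / z for e in exps]

def kl(p, q):
    """KL(p || q) for two discrete distributions over the same vocabulary."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def toy_rmu_probs(original_probs, detected_harmful):
    """Hypothesized behavior: uniform 'noise' if the input is detected as harmful,
    otherwise exactly the original model's distribution."""
    if detected_harmful:
        n = len(original_probs)
        return [1.0 / n] * n
    return list(original_probs)

if __name__ == "__main__":
    original = softmax([2.0, 0.0, -1.0])
    # Not detected as harmful: the teacher matches the original, so KL is zero.
    print(kl(original, toy_rmu_probs(original, detected_harmful=False)))
    # Detected as harmful: the teacher is noise, so KL is strictly positive.
    print(kl(original, toy_rmu_probs(original, detected_harmful=True)))
```

Under this toy model, a harmful datapoint that the detector misses is passed to the student exactly as the original model would have produced it, which is the leak described above. Whether real RMU models behave this way is the open empirical question.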
I see what you mean. I would have guessed that the unlearned model’s behavior is meaningfully different from “produce noise on harmful else original”. My guess is that the “noise if harmful” part is accurate, but the small differences in logits on non-harmful data are quite important. We didn’t run experiments on this. It would be an interesting empirical question to answer!
Also, there could be some variation on how true this is between different unlearning methods. We did find that RMU+distillation was less robust in the arithmetic setting than the other initial unlearning methods.
Fwiw, I’m not sure that RMU is a better unlearning method than simpler alternatives. I think it might just appear better on WMDP because the WMDP datasets are very messy and don’t isolate the capability well, which could be done better with a cleaned dataset. Then, the performance on the evaluation relies on unnecessary generalization.
the small differences in logits on non-harmful data are quite important
My guess is that if you used mech interp on RMU models, you would find that the internals look a lot like “if harmful, add a big vector to the residual stream, else keep it as is”. If this is the case, then I don’t see why there would be a difference in logprobs on non-harmful tokens.
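A toy version of that guessed mechanism shows why it would leave non-harmful logprobs untouched. Everything here is hypothetical: the gate `harmful_score`, the `threshold`, the steering vector, and the tiny unembedding matrix are illustrative stand-ins, not measurements from any RMU model.

```python
def toy_rmu_layer(residual, harmful_score, steering_vector, threshold=0.5):
    """Guessed mechanism: add a big vector to the residual stream when the input
    is detected as harmful, otherwise pass the residual through unchanged."""
    if harmful_score > threshold:
        return [r + s for r, s in zip(residual, steering_vector)]
    return list(residual)

def unembed(residual, weight_rows):
    """Map a residual vector to logits with a plain matrix-vector product."""
    return [sum(w * r for w, r in zip(row, residual)) for row in weight_rows]

if __name__ == "__main__":
    W = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # 3-token vocabulary, 2-dim residual
    residual = [1.0, -2.0]
    steering = [100.0, 100.0]

    original_logits = unembed(residual, W)
    benign_logits = unembed(toy_rmu_layer(residual, 0.1, steering), W)
    harmful_logits = unembed(toy_rmu_layer(residual, 0.9, steering), W)

    print(original_logits == benign_logits)   # the gate is closed, nothing changes
    print(original_logits == harmful_logits)  # the gate is open, logits are swamped
```

If the real mechanism were this clean, logits on non-harmful inputs would be exactly those of the original model. Any small differences observed in practice would mean the gate is soft or that the fine-tuning also changed other weights.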
I was just singling out RMU because I believe I understand its effects a bit more than for other methods.
We did find that RMU+distillation was less robust in the arithmetic setting than the other initial unlearning methods.
This is interesting! I think I would have guessed the opposite. I don’t have a great hypothesis for what GradDiff does mechanistically.