Thanks for the comment! I agree that exploring targeted noise is a very promising direction and could substantially speed up the method! Could you elaborate on what you mean about unlearning techniques during pretraining?
I don’t think data filtering+distillation is analogous to unlearning+distillation. During distillation, the student learns from the teacher’s predictions, not from the data itself, and those predictions can leak information about the undesired capability even on benign data. In a preliminary experiment, we found that data filtering+distillation was ineffective in a TinyStories setting, but more effective in the language setting (see this comment). It’s possible that real-world applications differ from the former setting. Maybe the contexts in which information about the forget capability is revealed are always different/identifiable, in which case data filtering+distillation would be effective, but I expect this isn’t usually the case.
As a concrete example, let’s say we want to unlearn the following fact:
The company x data center is in location y.
We filter out all of the sentences that give information about the data center in location y, but a benign sentence remains that says:
The company x data center is in location z.
Since the teacher model knows about the data centers in both location y and location z, it will assign high probability to both y and z as the next token, and the student will learn about both data centers.
Maybe there’s a way to build a classifier that predicts whether the teacher model will reveal any information about the forget capability, but this is complicated by the fact that you can’t just look at the top logit.
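To make the leakage concrete, here is a toy sketch (not one of our experiments). The 3-token vocabulary, the teacher’s probabilities, and the helper names are all hypothetical. It fits a student’s next-token distribution for the blank either to the one-hot label of the surviving benign sentence (data filtering+pretraining) or to the teacher’s soft predictions on that same sentence (distillation).

```python
import math

# Hypothetical vocabulary for the blank in "The company x data center is in location ___".
VOCAB = ["y", "z", "other"]


def softmax(logits):
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def kl(p, q):
    """KL(p || q) for discrete distributions; terms with p_i = 0 contribute 0."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def fit_student(target, steps=2000, lr=0.5):
    """Fit a student's logits to `target` by gradient descent on cross-entropy.

    The gradient of H(target, softmax(logits)) w.r.t. the logits is
    softmax(logits) - target, which is also the gradient of KL(target || student).
    """
    logits = [0.0] * len(target)
    for _ in range(steps):
        q = softmax(logits)
        logits = [v - lr * (qi - ti) for v, qi, ti in zip(logits, q, target)]
    return softmax(logits)


# Only the benign sentence "... is in location z." survives filtering.
hard_label = [0.0, 1.0, 0.0]  # data filtering + pretraining: one-hot on the observed token
teacher = [0.45, 0.45, 0.10]  # hypothetical teacher that knows both locations

student_from_data = fit_student(hard_label)
student_from_teacher = fit_student(teacher)

for name, student in [("data", student_from_data), ("teacher", student_from_teacher)]:
    print(f"trained on {name}: P(y) = {student[0]:.3f}, "
          f"KL(teacher || student) = {kl(teacher, student):.3f}")
```

The student fit to the hard label puts almost no probability on y, while the distilled student recovers the teacher’s probability on y, even though the only training sentence mentions z.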
I do think unlearning+distillation is conceptually analogous to data filtering+pretraining. However, I think there are practical differences, including the following:
- With Unlearn and Distill, it’s easier/cheaper to accurately control end behavior.
  - You can make many attempts at the initial unlearning until it is satisfactory, and expect the distilled student to behave like the teacher.
  - With data filtering+pretraining, you don’t get to see how the model will perform until it’s trained.
  - You can make many attempts at training a classifier, but it’s unclear what the ideal classifier would be.
  - It may be possible to learn undesired capabilities from a combination of seemingly benign data.
- The costs probably differ.
  - With data filtering+pretraining, you can probably use a smaller model as a classifier (or even just heuristics), so you remove the cost of distilling but add the cost of applying this classifier to the pretraining corpus.
  - In practice, I’m not sure how expensive distillation is compared to pretraining.
  - Distillation may already be part of the pipeline to get a smaller, faster model, so unlearning beforehand may not add much extra cost.
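One back-of-envelope way to frame the cost comparison, assuming the common approximations of about 6 FLOPs per parameter per token for training and about 2 for a forward pass. Pretraining itself is assumed to cost the same in both pipelines and is left out. The function names and every number below are placeholders for illustration, not measurements, and the sketch ignores the cost of iterating on the classifier or on the unlearning step.

```python
# Rough FLOP accounting using the usual approximations:
# training ~ 6 * params * tokens, forward pass ~ 2 * params * tokens.


def train_flops(n_params: int, n_tokens: int) -> int:
    return 6 * n_params * n_tokens


def forward_flops(n_params: int, n_tokens: int) -> int:
    return 2 * n_params * n_tokens


def extra_cost_filtering(n_classifier: int, corpus_tokens: int) -> int:
    """Added cost of data filtering: one classifier forward pass over the corpus."""
    return forward_flops(n_classifier, corpus_tokens)


def extra_cost_unlearn_distill(n_teacher: int, n_student: int, distill_tokens: int,
                               unlearn_cost: int, distill_already_planned: bool) -> int:
    """Added cost of unlearning + distillation relative to a pipeline without it.

    Distillation = teacher forward passes (for soft targets) + student training.
    If distillation is already planned (e.g. to get a smaller, faster model),
    only the unlearning step is extra.
    """
    distill_cost = forward_flops(n_teacher, distill_tokens) + train_flops(n_student, distill_tokens)
    return unlearn_cost if distill_already_planned else unlearn_cost + distill_cost


corpus = 10**12  # placeholder corpus size in tokens
print(extra_cost_filtering(n_classifier=10**8, corpus_tokens=corpus))
print(extra_cost_unlearn_distill(n_teacher=10**9, n_student=10**9, distill_tokens=corpus,
                                 unlearn_cost=10**18, distill_already_planned=False))
```

The interesting cases are the ratio of classifier size to student size and whether the distillation term is already being paid for anyway.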
Could you elaborate on what you mean about unlearning techniques during pretraining?
I mean (for GradDiff): train on Loss = -NTP(tokens) if the tokens are harmful, else NTP(tokens), instead of Loss = 0 if the tokens are harmful, else NTP(tokens), where NTP is the next-token-prediction loss.
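The two objectives can be sketched as follows, operating on toy per-token probabilities rather than a real model. The function names and the per-sequence harmfulness flag (which would come from a classifier) are assumptions for illustration. A real implementation would work on batched log-probs from the model, and would likely need to bound or reweight the ascent term, since the negated loss is unbounded below.

```python
import math
from typing import Sequence


def ntp_loss(target_token_probs: Sequence[float]) -> float:
    """Mean next-token-prediction (cross-entropy) loss of one sequence, given the
    probability the model assigned to each actual next token."""
    return -sum(math.log(p) for p in target_token_probs) / len(target_token_probs)


def filtering_loss(target_token_probs: Sequence[float], is_harmful: bool) -> float:
    # Data filtering: harmful sequences are dropped, i.e. contribute zero loss.
    return 0.0 if is_harmful else ntp_loss(target_token_probs)


def graddiff_pretraining_loss(target_token_probs: Sequence[float], is_harmful: bool) -> float:
    # GradDiff-style: gradient ascent on harmful sequences (negated NTP loss),
    # ordinary gradient descent on everything else.
    loss = ntp_loss(target_token_probs)
    return -loss if is_harmful else loss


# Toy batch: (probabilities of the actual next tokens, flagged harmful?)
batch = [([0.5, 0.25], False), ([0.5], True)]
for loss_fn in (filtering_loss, graddiff_pretraining_loss):
    total = sum(loss_fn(probs, harmful) for probs, harmful in batch) / len(batch)
    print(loss_fn.__name__, round(total, 4))
```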
I don’t think data filtering+distillation is analogous to unlearning+distillation
I do think unlearning+distillation is conceptually analogous to data filtering+pretraining
OK, I see your point. I wonder whether this holds in practice for unlearning techniques like RMU, though. My understanding of RMU is that the logprobs are roughly “noise if harmful is detected, else original”, in which case filtering+distillation would be roughly the same as unlearning+distillation (unless training on noise is better than training on 0). I expect that for most tokens where an RMU base model does not produce noise, its KL divergence from the original model would be tiny. And to the extent that your RMU train set is bad enough that RMU “misclassifies” some datapoints and does not produce noise on them, I expect that if the original model would have leaked information about those datapoints, the RMU model will leak it too. But that’s an empirical question, and I am only ~60% sure that I am correct here. Did you run experiments to test this?
(The fact that you are using an RMU base model, and tokens from the pretraining data rather than tokens generated by the model itself, matters a little bit here. I think you would get more robustness, but also less distillation efficiency, by fine-tuning on sequences generated by a model trained to refuse to talk about harmful topics. RMU = the noise thing + refusal, but you would not use refusal for a base model, and it would not help anyway with pretraining tokens, because refusal is often just a few tokens deep.)
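One way to set up that empirical test, sketched with hypothetical helper names and made-up distributions: compute the per-position KL divergence between the original and unlearned models’ next-token distributions, separately on positions flagged as harmful and benign. The `idealized_rmu_teacher` below encodes only the hypothesized “noise if harmful, else original” behavior, so its benign KL is zero by construction. With a real RMU model you would plug in its actual distributions, and a benign KL well above zero would count against the hypothesis.

```python
import math


def kl(p, q):
    """KL(p || q) for discrete next-token distributions; p_i = 0 terms contribute 0."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def mean_kl_by_group(orig_dists, unlearned_dists, is_harmful):
    """Mean per-position KL(original || unlearned), split into benign and harmful.
    Returns (benign_mean, harmful_mean); an empty group gives nan."""
    groups = {False: [], True: []}
    for p, q, harmful in zip(orig_dists, unlearned_dists, is_harmful):
        groups[harmful].append(kl(p, q))

    def mean(xs):
        return sum(xs) / len(xs) if xs else float("nan")

    return mean(groups[False]), mean(groups[True])


def idealized_rmu_teacher(orig_dists, is_harmful):
    """Hypothesized behavior only: uniform noise where harmful, otherwise the original."""
    return [[1.0 / len(p)] * len(p) if harmful else list(p)
            for p, harmful in zip(orig_dists, is_harmful)]


orig = [[0.7, 0.2, 0.1], [0.45, 0.45, 0.10], [0.9, 0.05, 0.05]]
harmful = [False, True, False]
benign_kl, harmful_kl = mean_kl_by_group(orig, idealized_rmu_teacher(orig, harmful), harmful)
print(f"benign KL = {benign_kl:.4f}, harmful KL = {harmful_kl:.4f}")
```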
I see what you mean. I would have guessed that the unlearned model’s behavior is meaningfully different from “produce noise on harmful, else original”. My guess is that the “noise if harmful” part is accurate, but the small differences in logits on non-harmful data are quite important. We didn’t run experiments on this. It would be an interesting empirical question to answer!
Also, there could be some variation in how true this is across different unlearning methods. We did find that RMU+distillation was less robust in the arithmetic setting than the other initial unlearning methods.
Fwiw, I’m not sure that RMU is a better unlearning method than simpler alternatives. I think it might just appear better on WMDP because the WMDP datasets are very messy and don’t isolate the capability well (a cleaned dataset could do this better), so performance on the evaluation relies on unnecessary generalization.
the small differences in logits on non-harmful data are quite important
My guess is that if you used mech interp on RMU models, you would find that the internals look a lot like if(harmful) then add a big vector to the residual stream else keep it as is. If this is the case, then I don’t see why there would be a difference in logprobs on non-harmful tokens.
I was just singling out RMU because I believe I understand its effects a bit better than those of other methods.
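For reference, a minimal NumPy sketch of the RMU objective as I understand it from the WMDP paper, up to normalization constants; the function names and hyperparameter values are placeholders. On forget data, a chosen layer’s activations are trained toward a fixed random direction scaled by a coefficient c; on retain data, they are trained to match the frozen original model. If training drives both terms to zero, that layer roughly matches the “if harmful, add a big vector to the residual stream, else keep it as is” picture when c is large relative to typical activation norms.

```python
import numpy as np


def random_control_vector(d_model: int, rng: np.random.Generator) -> np.ndarray:
    # Fixed random direction u: entries sampled uniformly from [0, 1), then normalized.
    u = rng.random(d_model)
    return u / np.linalg.norm(u)


def rmu_loss(h_forget: np.ndarray, h_retain: np.ndarray, h_retain_frozen: np.ndarray,
             u: np.ndarray, c: float, alpha: float) -> float:
    """Activations have shape (n_tokens, d_model) and come from one chosen layer.

    h_forget:        updated model's activations on forget (harmful) tokens
    h_retain:        updated model's activations on retain (benign) tokens
    h_retain_frozen: frozen original model's activations on the same retain tokens
    """
    # Forget term: push harmful-token activations toward the large vector c * u.
    forget = np.mean(np.sum((h_forget - c * u) ** 2, axis=-1))
    # Retain term: keep benign-token activations where the original model had them.
    retain = np.mean(np.sum((h_retain - h_retain_frozen) ** 2, axis=-1))
    return float(forget + alpha * retain)


rng = np.random.default_rng(0)
d_model = 8
u = random_control_vector(d_model, rng)
h_retain_frozen = rng.normal(size=(4, d_model))

# At the optimum both terms vanish: forget activations equal c * u and
# retain activations are unchanged.
print(rmu_loss(np.tile(6.5 * u, (3, 1)), h_retain_frozen, h_retain_frozen, u, c=6.5, alpha=100.0))
```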
We did find that RMU+distillation was less robust in the arithmetic setting than the other initial unlearning methods.
This is interesting! I think I would have guessed the opposite. I don’t have a great hypothesis for what GradDiff does mechanistically.