Note that if you want logits to work with, you could put a classification head on your LM and then train on the easy classification task where each input consists of a prompt, completion, and chain of thought. (In other words, you would have the LM reason about injuriousness using chain of thought as you did above, and afterwards feed the entire prompt + completion + chain of thought into your injuriousness classifier.)
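For concreteness, here’s a minimal sketch of what that head could look like (PyTorch + HuggingFace transformers; the base model, the two-class head, and all the names are illustrative assumptions, not a claim about your actual setup):

```python
import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoTokenizer

class InjuriousnessClassifier(nn.Module):
    """Causal LM with a 2-way classification head on the last token's hidden state."""
    def __init__(self, base_model_name="gpt2"):  # base model is a placeholder
        super().__init__()
        self.lm = AutoModelForCausalLM.from_pretrained(base_model_name)
        self.head = nn.Linear(self.lm.config.hidden_size, 2)  # injurious / not

    def forward(self, input_ids, attention_mask):
        out = self.lm(input_ids=input_ids,
                      attention_mask=attention_mask,
                      output_hidden_states=True)
        last_hidden = out.hidden_states[-1]       # (batch, seq, hidden)
        last_pos = attention_mask.sum(dim=1) - 1  # index of final real token
        pooled = last_hidden[torch.arange(last_hidden.size(0)), last_pos]
        return self.head(pooled)                  # (batch, 2) class logits

tok = AutoTokenizer.from_pretrained("gpt2")
clf = InjuriousnessClassifier("gpt2")

# The classifier input is the concatenation of all three pieces:
prompt, completion, cot = "PROMPT...", " COMPLETION...", " CHAIN OF THOUGHT..."
enc = tok(prompt + completion + cot, return_tensors="pt")
class_logits = clf(enc["input_ids"], enc["attention_mask"])  # train with cross-entropy
```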
This would let you backprop to tokens in the prompt + completion + chain of thought, and if you’re willing to store the computational graphs for all the forward passes, then you could further backprop to just the tokens in the original prompt + completion. (Though I suppose this wouldn’t work at temperature 0, and it would only give you the dependence of the classification logit on prompt + completion tokens via paths that go through the completions actually sampled, not via paths that go through counterfactual alternative completions.)
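A sketch of the first of those steps, continuing the code above: since token ids are discrete, the gradient of the classification logit lands on the input token *embeddings* rather than the tokens themselves, which is the usual saliency-style workaround. (The second step, backpropping through the sampled chain of thought to just the original prompt + completion, would additionally require keeping the graph of every generation forward pass alive, as noted.)

```python
# Re-run the forward pass on embeddings so we can ask for their gradient.
embed_layer = clf.lm.get_input_embeddings()
inputs_embeds = embed_layer(enc["input_ids"]).detach().requires_grad_(True)

out = clf.lm(inputs_embeds=inputs_embeds,
             attention_mask=enc["attention_mask"],
             output_hidden_states=True)
pooled = out.hidden_states[-1][:, -1]      # hidden state at the final token
injurious_logit = clf.head(pooled)[0, 1]   # logit of the "injurious" class
injurious_logit.backward()

# (seq,) saliency scores: sensitivity of the logit to each input token's
# embedding, covering prompt, completion, and chain-of-thought positions alike.
saliency = inputs_embeds.grad.norm(dim=-1).squeeze(0)
```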