Yeah, we've thought about it but haven't run any experiments yet. An easy trick would be to add an $L_{\text{diff}}$ term to the crosscoder reconstruction loss:
$$L_{\text{diff}} = \text{MSE}\big((\text{chat} - \text{base}) - (\text{chat}_{\text{recon}} - \text{base}_{\text{recon}})\big) = \text{MSE}(e_{\text{chat}}) + \text{MSE}(e_{\text{base}}) - 2\, e_{\text{chat}} \cdot e_{\text{base}}$$
with
$$e_{\text{chat}} = \text{chat} - \text{chat}_{\text{recon}}, \qquad e_{\text{base}} = \text{base} - \text{base}_{\text{recon}}.$$
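For concreteness, here is a quick numeric check of that expansion (a sketch, not from the original discussion; it reads $\text{MSE}(x)$ as the summed squared error $\lVert x \rVert^2$, so the cross term carries no normalization factor; with a mean over dimensions, the cross term would pick up the same $1/d$ factor):

```python
import torch

# Sanity-check the algebraic identity above on random vectors.
torch.manual_seed(0)
d = 16
chat, base = torch.randn(d), torch.randn(d)
chat_recon, base_recon = torch.randn(d), torch.randn(d)

e_chat = chat - chat_recon
e_base = base - base_recon

lhs = ((chat - base) - (chat_recon - base_recon)).square().sum()
rhs = e_chat.square().sum() + e_base.square().sum() - 2 * (e_chat * e_base).sum()
assert torch.allclose(lhs, rhs)  # L_diff expands exactly as claimed
```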
So basically, a generalization is to change the crosscoder loss to:
$$L = \text{MSE}(e_{\text{chat}}) + \text{MSE}(e_{\text{base}}) + \lambda \cdot (2\, e_{\text{chat}} \cdot e_{\text{base}}), \quad \lambda \in [-1, 0].$$
With $\lambda = -1$, you only focus on reconstructing the diff; with $\lambda = 0$, you get the normal crosscoder reconstruction objective back. $\lambda = -1$ is quite close to a diff-SAE; the only difference is that the inputs are chat and base instead of chat − base. It's unclear what kind of advantage this gives you, but maybe crosscoders turn out to be more interpretable, and by choosing the right $\lambda$ you get the best of both worlds?
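A minimal sketch of what this interpolated objective could look like in PyTorch (the function name, the default $\lambda$, and summing squared errors over the hidden dimension are my assumptions; the usual sparsity penalty on the crosscoder latents is omitted):

```python
import torch

def interpolated_recon_loss(chat, base, chat_recon, base_recon, lam=-0.5):
    """Hypothetical interpolated reconstruction loss, lam in [-1, 0].

    lam = 0  -> the usual crosscoder objective (sum of both models' errors);
    lam = -1 -> L_diff, the squared error of the reconstructed chat - base diff.
    """
    e_chat = chat - chat_recon  # per-model reconstruction errors
    e_base = base - base_recon
    per_example = (
        e_chat.square().sum(dim=-1)
        + e_base.square().sum(dim=-1)
        + lam * 2.0 * (e_chat * e_base).sum(dim=-1)
    )
    return per_example.mean()
```

Since $|2\, e_{\text{chat}} \cdot e_{\text{base}}| \le \lVert e_{\text{chat}} \rVert^2 + \lVert e_{\text{base}} \rVert^2$, the loss stays nonnegative for any $\lambda \in [-1, 0]$.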
I'd like to investigate the downstream usefulness of this modification, as well as using a Matryoshka loss, with our diffing toolkit.