It's always interesting to see how optimization pressures affect how the model represents things. The batch Top-k fix is clever in that respect. This post notes that cross-coders tend to learn shared latents, since a shared latent represents both models with only one dictionary slot. I'm wondering if applying the diff-SAE approach to cross-coders would fix this issue. Is this something that's worth exploring, or is it something you've tried but that doesn't achieve significantly better results than diff-SAEs?
Yeah, we've thought about it but haven't run any experiments yet. An easy trick would be to add a term $\mathcal{L}_{\text{diff}}$ to the crosscoder reconstruction loss:
$$\mathcal{L}_{\text{diff}} = \text{MSE}\big((\text{chat} - \text{base}) - (\text{chat\_recon} - \text{base\_recon})\big) = \text{MSE}(e_{\text{chat}}) + \text{MSE}(e_{\text{base}}) - 2\, e_{\text{chat}} \cdot e_{\text{base}}$$

with

$$e_{\text{chat}} = \text{chat} - \text{chat\_recon}, \qquad e_{\text{base}} = \text{base} - \text{base\_recon}.$$

So basically a generalization is to change the crosscoder loss to:
$$\mathcal{L} = \text{MSE}(e_{\text{chat}}) + \text{MSE}(e_{\text{base}}) + \lambda \cdot (2\, e_{\text{chat}} \cdot e_{\text{base}}), \quad \lambda \in [-1, 0]$$

With $\lambda = -1$ you only focus on reconstructing the diff; with $\lambda = 0$ you recover the normal crosscoder reconstruction objective. $\lambda = -1$ is quite close to a diff-SAE, the only difference being that the inputs are chat and base instead of chat − base. It's unclear what kind of advantage this gives you, but maybe the crosscoders turn out to be more interpretable, and by choosing the right $\lambda$ you get the best of both worlds?
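To make the $\lambda$ knob concrete, here is a minimal PyTorch sketch of this generalized reconstruction term (sparsity penalty omitted). The function name and the squared-error convention (sum over the model dimension, mean over the batch) are assumptions for illustration, not taken from any existing implementation:

```python
import torch

def generalized_crosscoder_recon_loss(chat, base, chat_recon, base_recon, lam=-1.0):
    """Reconstruction part of the crosscoder loss with a tunable lambda in [-1, 0].

    lam =  0 -> standard crosscoder objective: MSE(e_chat) + MSE(e_base)
    lam = -1 -> pure diff objective: MSE((chat - base) - (chat_recon - base_recon))
    """
    e_chat = chat - chat_recon  # chat-model reconstruction error
    e_base = base - base_recon  # base-model reconstruction error
    mse_chat = e_chat.pow(2).sum(-1).mean()
    mse_base = e_base.pow(2).sum(-1).mean()
    # Cross term; at lam = -1 the three terms combine into ||e_chat - e_base||^2.
    cross = (e_chat * e_base).sum(-1).mean()
    return mse_chat + mse_base + lam * 2.0 * cross
```

Intermediate values of $\lambda$ interpolate between the two objectives.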
I'd like to investigate the downstream usefulness of this modification, as well as of using the Matryoshka loss, with our diffing toolkit.