Should we be optimizing SAEs for disentanglement instead of sparsity?
The primary motivation behind SAEs is to learn monosemantic latents, where monosemanticity ~= “features correspond to single concepts”. In practice, sparsity is used as a proxy metric for monosemanticity.
There’s a closely related notion in the literature, disentanglement, which ~= “features correspond to single concepts, and can be varied independently of each other.”
The literature contains known objectives to induce disentanglement directly, without needing proxy metrics.
Claim #1: Training SAEs to optimize for disentanglement (+ lambda * sparsity) could result in ‘better’ latents.
Avoids the failure mode of memorizing features in the infinite-width limit, and might thereby fix feature splitting.
Might also fix feature absorption (low-confidence take).
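One way to make Claim #1 concrete is to add a disentanglement term to the usual SAE loss. The sketch below is a minimal numpy version under stated assumptions: the SAE is a standard ReLU encoder with a linear decoder, and the disentanglement term is a hypothetical stand-in that penalizes off-diagonal entries of the batch latent covariance (loosely in the spirit of DIP-VAE-style decorrelation penalties). These notes do not name a specific objective, and decorrelation only constrains second-order statistics, so a total-correlation penalty would be the stricter choice. The weights `beta` and `lam` are illustrative, not tuned.

```python
import numpy as np


def sae_forward(x, W_enc, b_enc, W_dec, b_dec):
    """Standard one-layer SAE: ReLU encoder, linear decoder."""
    z = np.maximum(x @ W_enc + b_enc, 0.0)   # latents, (batch, n_latents)
    x_hat = z @ W_dec + b_dec                # reconstruction, (batch, d_model)
    return z, x_hat


def decorrelation_penalty(z):
    """Sum of squared off-diagonal entries of the batch latent covariance.
    Zero when latents are pairwise uncorrelated over the batch. Hypothetical
    stand-in for a disentanglement objective (weaker than full independence)."""
    zc = z - z.mean(axis=0, keepdims=True)
    cov = zc.T @ zc / max(z.shape[0] - 1, 1)
    off_diag = cov - np.diag(np.diag(cov))
    return float(np.sum(off_diag ** 2))


def sae_loss(x, W_enc, b_enc, W_dec, b_dec, lam=1e-3, beta=1.0):
    """Reconstruction + beta * disentanglement + lam * L1 sparsity."""
    z, x_hat = sae_forward(x, W_enc, b_enc, W_dec, b_dec)
    recon = np.mean(np.sum((x - x_hat) ** 2, axis=1))
    sparsity = np.mean(np.sum(np.abs(z), axis=1))
    return float(recon + beta * decorrelation_penalty(z) + lam * sparsity)


# Usage with random parameters and activations.
rng = np.random.default_rng(0)
d_model, n_latents, batch = 8, 32, 64
W_enc = rng.normal(scale=0.1, size=(d_model, n_latents))
b_enc = np.zeros(n_latents)
W_dec = rng.normal(scale=0.1, size=(n_latents, d_model))
b_dec = np.zeros(d_model)
x = rng.normal(size=(batch, d_model))
print(sae_loss(x, W_enc, b_enc, W_dec, b_dec))
```

Because the penalty is estimated from batch statistics, its value depends on batch size, and actual training would need an autodiff framework rather than plain numpy.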
Claim #2: Optimizing for disentanglement is like optimizing for modular controllability.
For example, training with a MELBO / DCT objective results in learning control-oriented representations. At the same time, these representations are forced to be modular using orthogonality. (Strict orthogonality may not be desirable; we may want to relax it to almost-orthogonality.)
The MELBO / DCT objective may be comparable to (or better than) the disentanglement learning objectives above.
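The strict-vs-relaxed orthogonality distinction can be sketched as two alternative constraints on K learned directions (rows of `V`). Both functions are hypothetical illustrations, not taken from the MELBO / DCT code: `orthonormalize` enforces exact orthonormality with a QR projection, while `near_orthogonality_penalty` only penalizes pairwise cosine similarities whose magnitude exceeds a tolerance `eps`, which is one possible way to express almost-orthogonality.

```python
import numpy as np


def orthonormalize(V):
    """Strict constraint: map K directions (rows of V, K <= d) to an exactly
    orthonormal set spanning the same subspace, via QR (Gram-Schmidt order)."""
    Q, R = np.linalg.qr(V.T)                     # V.T is (d, K); Q has orthonormal columns
    signs = np.where(np.diag(R) < 0, -1.0, 1.0)  # QR may flip signs; undo that
    return (Q * signs).T                         # (K, d)


def near_orthogonality_penalty(V, eps=0.1):
    """Relaxed constraint: hinge penalty on pairwise |cosine similarity| above eps,
    so directions may overlap slightly (almost-orthogonality)."""
    U = V / np.linalg.norm(V, axis=1, keepdims=True)
    cos = U @ U.T
    off_diag = cos[~np.eye(len(V), dtype=bool)]
    return float(np.sum(np.maximum(np.abs(off_diag) - eps, 0.0) ** 2))


# Usage: 3 random directions in 8 dimensions.
rng = np.random.default_rng(0)
V = rng.normal(size=(3, 8))
W = orthonormalize(V)
print(np.round(W @ W.T, 6))           # identity up to rounding
print(near_orthogonality_penalty(V))  # penalty for the raw, non-orthogonalized directions
```

The relaxation matters for SAEs in particular: at most d vectors can be exactly orthogonal in d dimensions, whereas an overcomplete dictionary (more latents than d) can still keep pairwise overlaps small.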
Concrete experiment idea: Include the MELBO objective (or the relaxed version) in SAE training, then compare these to ‘standard’ SAEs on SAE-bench. Also compare to MDL-SAEs.
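A minimal sketch of what the combined training loss might look like, under loud assumptions: decoder directions are treated as steering vectors, each is rescaled to a fixed norm `radius` and added to the SAE's input activations, and a MELBO-style term rewards how much this moves activations at a later layer. Here `downstream` is a hypothetical toy MLP standing in for the model's forward pass between the two layers, and `gamma`, `lam`, and `radius` are illustrative values. This is one reading of how the objectives could be combined, not the MELBO authors' formulation.

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, d_hidden, n_latents, batch = 8, 16, 12, 32

# Hypothetical toy stand-in for the model between the steered layer and the
# target layer: a fixed, randomly initialized two-layer MLP.
A = rng.normal(size=(d_model, d_hidden)) / np.sqrt(d_model)
B = rng.normal(size=(d_hidden, d_model)) / np.sqrt(d_hidden)


def downstream(h):
    """Map activations at the steered layer to activations at a later layer."""
    return np.tanh(h @ A) @ B


def melbo_style_score(x, W_dec, radius=4.0):
    """Mean squared downstream displacement when each decoder direction,
    rescaled to norm `radius`, is added to every input activation in x."""
    dirs = radius * W_dec / np.linalg.norm(W_dec, axis=1, keepdims=True)  # (n_latents, d_model)
    base = downstream(x)                                                  # (batch, d_model)
    steered = downstream(x[None, :, :] + dirs[:, None, :])                # (n_latents, batch, d_model)
    return float(np.mean(np.sum((steered - base[None]) ** 2, axis=-1)))


def combined_loss(x, W_enc, b_enc, W_dec, b_dec, lam=1e-3, gamma=0.1):
    """Reconstruction + lam * L1 sparsity - gamma * MELBO-style score."""
    z = np.maximum(x @ W_enc + b_enc, 0.0)  # ReLU encoder
    x_hat = z @ W_dec + b_dec               # linear decoder
    recon = np.mean(np.sum((x - x_hat) ** 2, axis=1))
    sparsity = np.mean(np.sum(np.abs(z), axis=1))
    # Subtracting the score means directions with larger downstream effects lower the loss.
    return float(recon + lam * sparsity - gamma * melbo_style_score(x, W_dec))


# Usage with random parameters and activations.
W_enc = rng.normal(scale=0.1, size=(d_model, n_latents))
b_enc = np.zeros(n_latents)
W_dec = rng.normal(scale=0.1, size=(n_latents, d_model))
b_dec = np.zeros(d_model)
x = rng.normal(size=(batch, d_model))
print(combined_loss(x, W_enc, b_enc, W_dec, b_dec))
```

`gamma` trades reconstruction and sparsity against downstream effect; the relaxed orthogonality constraint on decoder directions could be added as a further penalty term.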
Meta note: This experiment could do with being scoped down slightly to make it tractable for a short sprint.