I just tried to replicate this on GPT-2 with an expansion factor of 4 (so the total number of centroids is 768 * 4 = 3072). I find that clustering recovers ~87% of the variance, while a k = 32 SAE explains more like 95%. To give k-means the biggest advantage possible, I used the nonlinear version of finding nearest neighbors, and I ran the k-means clustering itself with the FAISS clustering library.
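For readers who want the comparison spelled out, here is a minimal NumPy sketch of the metric, not the code in the linked repo. It makes several assumptions: the k-means reconstruction of an activation is taken to be its nearest centroid by Euclidean distance (one plausible reading of the "nonlinear version"), plain Lloyd iterations stand in for the FAISS clustering used above, and the top-k SAE uses a common formulation (subtract the decoder bias, keep the k largest pre-activations, ReLU, decode). All function names are hypothetical, and real SAE weights would come from a trained model rather than being constructed by hand.

```python
import numpy as np

def fraction_variance_explained(x, x_hat):
    """FVE = 1 - (sum of squared residuals) / (sum of squared deviations from the mean)."""
    resid = np.sum((x - x_hat) ** 2)
    total = np.sum((x - x.mean(axis=0)) ** 2)
    return 1.0 - resid / total

def _nearest(x, centroids):
    # Squared Euclidean distances via ||x||^2 - 2 x.c + ||c||^2, then argmin per row.
    d = (np.sum(x ** 2, axis=1, keepdims=True)
         - 2.0 * x @ centroids.T
         + np.sum(centroids ** 2, axis=1))
    return np.argmin(d, axis=1)

def nearest_centroid_reconstruct(x, centroids):
    # Hypothetical helper: replace each activation with its closest centroid.
    return centroids[_nearest(x, centroids)]

def kmeans(x, n_centroids, n_iter=20, seed=0):
    # Plain Lloyd's algorithm (a stand-in for FAISS, assumed here for self-containment).
    rng = np.random.default_rng(seed)
    centroids = x[rng.choice(len(x), n_centroids, replace=False)].copy()
    for _ in range(n_iter):
        assign = _nearest(x, centroids)
        sums = np.zeros_like(centroids)
        np.add.at(sums, assign, x)  # accumulate each point into its cluster's sum
        counts = np.bincount(assign, minlength=n_centroids)
        nonempty = counts > 0
        centroids[nonempty] = sums[nonempty] / counts[nonempty, None]
    return centroids

def topk_sae_reconstruct(x, W_enc, b_enc, W_dec, b_dec, k):
    # Assumed TopK SAE forward pass: keep the k largest latents per row, ReLU, decode.
    pre = (x - b_dec) @ W_enc + b_enc
    idx = np.argpartition(pre, -k, axis=1)[:, -k:]
    rows = np.arange(len(x))[:, None]
    acts = np.zeros_like(pre)
    acts[rows, idx] = np.maximum(pre[rows, idx], 0.0)
    return acts @ W_dec + b_dec

if __name__ == "__main__":
    rng = np.random.default_rng(0)
    d_model, expansion = 64, 4  # toy sizes; GPT-2 small would be d_model = 768
    x = rng.standard_normal((2000, d_model))  # random stand-in for residual-stream activations
    C = kmeans(x, d_model * expansion)
    print("k-means FVE:", fraction_variance_explained(x, nearest_centroid_reconstruct(x, C)))
```

The design point the sketch makes concrete: k-means reconstructs each activation from exactly one centroid, while a k = 32 SAE can combine up to 32 decoder directions with learned magnitudes, so on the same dictionary size the SAE has far more expressive reconstructions available.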
Definitely take this with a grain of salt. I’m going to look through my code and see whether I can reproduce your results on Pythia too, and if so, try a larger model as well. Code: https://github.com/JoshEngels/CheckClustering/tree/main