I’ve been leveraging your code to speed up the implementation of my own new formulation of neuron masks, and I noticed a bug:
def running_mean_tensor(old_mean, new_value, n):
    return old_mean + (new_value - old_mean) / n

def get_sae_means(mean_tokens, total_batches, batch_size, per_token_mask=False):
    for sae in saes:
        sae.mean_ablation = torch.zeros(sae.cfg.d_sae).float().to(device)
    with tqdm(total=total_batches*batch_size, desc="Mean Accum Progress") as pbar:
        for i in range(total_batches):
            for j in range(batch_size):
                with torch.no_grad():
                    _ = model.run_with_hooks(
                        mean_tokens[i, j],
                        return_type="logits",
                        fwd_hooks=build_hooks_list(mean_tokens[i, j], cache_sae_activations=True)
                    )
                for sae in saes:
                    sae.mean_ablation = running_mean_tensor(sae.mean_ablation, sae.feature_acts, i+1)
                cleanup_cuda()
                pbar.update(1)
            if i >= total_batches:
                break
get_sae_means(corr_tokens, 40, 16)
The running mean calculation is only correct if n is the total number of samples seen so far, but i+1 is the 1-indexed batch number we’re on. The value passed should be i * batch_size + j + 1. I ran a little test: below is a histogram from 1k runs of taking 104 random normal samples with batch_size=8, then comparing the difference between the true mean and the final running mean as calculated by the running_mean_tensor function. It looks like the expected difference is zero but with a fairly large variance, definitely larger than the standard error of the true mean estimate, which is standard_normal_sdev / sqrt(n) = 1 / sqrt(104) ≈ 0.1. Not sure how much adding a random error like this to the logit diffs affects the accuracy of the estimates.
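For what it’s worth, a minimal sketch of the fix, keeping your loop structure; only the count passed to running_mean_tensor changes (n_so_far is my name for it), and this replaces the update inside the inner j loop:

n_so_far = i * batch_size + j + 1  # samples seen so far, not the batch index
for sae in saes:
    sae.mean_ablation = running_mean_tensor(sae.mean_ablation, sae.feature_acts, n_so_far)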
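And in case it helps, here’s a self-contained sketch of that little test (my own variable names, not your code):

import torch

# 1k runs of 104 standard normal samples with batch_size=8, comparing the
# true mean against the running mean computed with the buggy divisor
# (the batch index i+1 instead of the number of samples seen so far).
torch.manual_seed(0)
n_samples, batch_size, n_runs = 104, 8, 1000
diffs = []
for _ in range(n_runs):
    x = torch.randn(n_samples)
    running = 0.0
    for k in range(n_samples):
        n_buggy = k // batch_size + 1  # what the code passes (i + 1)
        # n_correct = k + 1            # i * batch_size + j + 1
        running = running + (x[k] - running) / n_buggy
    diffs.append((x.mean() - running).item())
diffs = torch.tensor(diffs)
# Expect a mean near 0 but a std noticeably above 1 / sqrt(104) ≈ 0.1.
print(diffs.mean().item(), diffs.std().item())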