It may be possible to massively reduce memory usage in sparsely-connected mode.
Let B be batch size, K be num active latents per dictionary per token, and F be num latents per dictionary.
My current implementation of sparsely-connected mode has a terrible O(F²) memory usage, since each virtual weight matrix has F² elements. But how many of these virtual weights do we actually need to compute?
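To make the cost concrete, here is a minimal sketch of the naive computation, assuming the usual definition of a virtual weight (the dot product of an upstream latent's decoder direction with a downstream latent's encoder direction); the tensor names and shapes are placeholders, not the actual implementation.

```python
import torch

# Hypothetical shapes for illustration: d = model dimension, F = num latents per dictionary.
d, F = 768, 24_576

# Assumed definition of a virtual weight: the dot product of an upstream latent's decoder
# direction with a downstream latent's encoder direction.
W_dec_up = torch.randn(F, d)    # upstream decoder directions, one per row
W_enc_down = torch.randn(d, F)  # downstream encoder directions, one per column

# Naive approach: materialise the full virtual weight matrix.
V = W_dec_up @ W_enc_down       # shape (F, F) -- the O(F²) memory cost described above
```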
Upstream latents: On each token in the batch, we only need the virtual weights connecting to the K active upstream latents.
Downstream latents: Strictly speaking, we should compute activations for every downstream latent, since we don’t know in advance which will be active. But, insofar as vanilla mode closely approximates sparsely-connected mode, we should be okay to only compute virtual weights connecting to downstream latents that were active in vanilla mode.
So on each token we only need to compute K² virtual weights, and the memory requirement is BK², which is small.
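A minimal PyTorch sketch of this per-token gather, under the same assumed definition of virtual weights; up_idx and down_idx are hypothetical index tensors holding, for each token, the K active upstream latents and the K downstream latents that were active in vanilla mode.

```python
import torch

# Hypothetical shapes and names, continuing the sketch above.
B, K, d, F = 32, 64, 768, 24_576
W_dec_up = torch.randn(F, d)            # upstream decoder directions
W_enc_down = torch.randn(d, F)          # downstream encoder directions

up_idx = torch.randint(0, F, (B, K))    # per token: indices of the K active upstream latents
down_idx = torch.randint(0, F, (B, K))  # per token: downstream latents active in vanilla mode

# Gather only the directions we need, then contract over the model dimension:
# (B, K, d) @ (B, d, K) -> (B, K, K).
dec_active = W_dec_up[up_idx]                        # (B, K, d)
enc_active = W_enc_down.T[down_idx].transpose(1, 2)  # (B, d, K)
V_active = torch.bmm(dec_active, enc_active)         # (B, K, K): BK² virtual weights in total
```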
Of course, this new approach loses something: sparsely-connected mode now relies on vanilla mode to tell it which latents should activate. So much for a standalone replacement model! I think a reasonable middle ground is to only compute virtual weights to the 100×K (say) latents with largest vanilla preactivation. Then compute sparsely-connected preactivations for all those latents, and apply TopK to get the activations. The memory usage is then 100BK², which is still small.
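A sketch of this middle ground, again with assumed names (x, up_acts, cand_idx, etc. are placeholders) and with encoder biases and other dictionary details omitted: vanilla preactivations choose 100×K candidate downstream latents, virtual weights are computed only for (active upstream) × (candidate downstream) pairs, and TopK over the candidates gives the sparsely-connected activations.

```python
import torch

# Hypothetical shapes and names; encoder biases and other dictionary details are omitted.
B, K, d, F = 32, 64, 768, 24_576
C = 100 * K                               # number of candidate downstream latents per token

x = torch.randn(B, d)                     # input to the downstream dictionary
W_enc_down = torch.randn(d, F)            # downstream encoder directions
W_dec_up = torch.randn(F, d)              # upstream decoder directions
up_idx = torch.randint(0, F, (B, K))      # active upstream latents per token
up_acts = torch.rand(B, K)                # their sparsely-connected activations

# 1. Vanilla preactivations select the candidate downstream latents.
vanilla_pre = x @ W_enc_down                        # (B, F)
cand_idx = vanilla_pre.topk(C, dim=-1).indices      # (B, C)

# 2. Compute virtual weights only for (active upstream) × (candidate downstream) pairs,
#    then the sparsely-connected preactivations for the candidates.
dec_active = W_dec_up[up_idx]                       # (B, K, d)
enc_cand = W_enc_down.T[cand_idx].transpose(1, 2)   # (B, d, C)
V = torch.bmm(dec_active, enc_cand)                 # (B, K, C): 100·BK² virtual weights
sc_pre = torch.einsum('bk,bkc->bc', up_acts, V)     # (B, C)

# 3. TopK over the candidates gives the sparsely-connected activations.
top = sc_pre.topk(K, dim=-1)
sc_acts = top.values                                # (B, K)
sc_idx = cand_idx.gather(1, top.indices)            # (B, K) indices into the full dictionary
```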