I’m getting close indeed! I did take a big detour into tropical geometry… The key approach at the moment is deriving a sequence D[t] recording which ReLUs are active at each step; this turns out to be quite stable after an impulse. Then we can mask W_hh via D[t] and combine all the (now-linear) steps to get an effective linear operator for the whole sequence, which we can then investigate with normal linear methods.
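A minimal sketch of that collapse, using a toy ReLU RNN with random weights as a stand-in for the trained network (the names W_hh, W_ih, D and all shapes here are illustrative, not taken from the actual model):

```python
import numpy as np

# Toy stand-ins for the trained weights; shapes and scales are illustrative only.
rng = np.random.default_rng(0)
H, T = 16, 20                        # hidden size, sequence length
W_hh = rng.normal(0, 0.3, (H, H))    # recurrent weights
W_ih = rng.normal(0, 0.3, H)         # input weights (scalar input per step)
x = rng.uniform(0, 1, T)             # input sequence

# Run the ReLU RNN once, recording which units are active at each step.
h = np.zeros(H)
D = []                               # D[t]: boolean mask of active ReLUs at step t
for t in range(T):
    pre = W_hh @ h + W_ih * x[t]
    D.append(pre > 0)
    h = np.maximum(pre, 0.0)

# With D[t] fixed, step t is linear: h[t] = diag(D[t]) @ (W_hh @ h[t-1] + W_ih * x[t]).
# Sanity check: replaying the masked linear steps reproduces the nonlinear run.
h_lin = np.zeros(H)
for t in range(T):
    h_lin = D[t] * (W_hh @ h_lin + W_ih * x[t])
assert np.allclose(h_lin, h)

# Compose the per-step masked maps into one effective linear operator for the
# whole sequence (this is the state-to-state part; each input adds its own term).
A_eff = np.eye(H)
for t in range(T):
    A_eff = np.diag(D[t].astype(float)) @ W_hh @ A_eff
print(np.abs(np.linalg.eigvals(A_eff)).max())   # spectral radius of the collapsed map
```

Once the masks are recorded, the whole run is one matrix product and can be probed with eigenvalues, SVDs, and so on.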
Martin Keane
The readout mechanism for S (2nd max) in the presence of M (max) combines two computations in a shared low-dimensional subspace.
Phase Wheel
The hidden state follows a spiral trajectory through time, implemented as a rotating phase. The W_out projection converts phase angle to position logits. The main spiral shape does not differ between the forward (M first) and reverse (S first) cases.
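A sketch of how that phase could be read off recorded hidden states. Everything here is a stand-in: hidden_states is a synthetic decaying spiral in a random 2D plane rather than the real trajectory, so the printed angular velocities only illustrate the procedure.

```python
import numpy as np

# Synthetic stand-in for hidden states h[t] recorded from the trained RNN:
# a decaying spiral embedded in a random 2D plane of a 16-dim state space.
rng = np.random.default_rng(1)
T, H = 20, 16
t = np.arange(T)
plane = np.linalg.qr(rng.normal(size=(H, 2)))[0]
spiral = np.stack([np.cos(0.5 * t), np.sin(0.5 * t)], axis=1) * (0.95 ** t)[:, None]
hidden_states = spiral @ plane.T + 0.01 * rng.normal(size=(T, H))

# PCA of the trajectory: the spiral lives in the top two principal components.
X = hidden_states - hidden_states.mean(axis=0)
_, _, Vt = np.linalg.svd(X, full_matrices=False)
pc = X @ Vt[:2].T                                   # (T, 2) coordinates in the spiral plane

# The "phase wheel": one phase angle per step, advancing at a roughly constant
# rate, so the phase at readout time encodes how many steps ago an input arrived.
phase = np.unwrap(np.arctan2(pc[:, 1], pc[:, 0]))
print(np.diff(phase).round(2))                      # approximately constant angular velocity
```

In the real model, W_out applied to the final state is what turns that angle into position logits.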
Discrimination Offset
The network must discriminate between the very similar forward and reverse cases. The final hidden states differ by an offset:
h_forward = h_reverse + offset(m, s)
The offset is separable and antisymmetric:
offset(m, s) = f(m) + g(s) where g(s) = -f(s)
The network applies +f for the M position and -f for the S position.
The offset has effective rank ~2, and is also an approximate spiral in PCA space.
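A sketch of how the offset and its effective rank could be measured. The RNN here is a toy with random weights, so the printed number demonstrates the procedure rather than the reported rank-2 structure.

```python
import numpy as np

# Toy ReLU RNN as a stand-in for the trained network.
rng = np.random.default_rng(2)
H, T = 16, 10
W_hh = rng.normal(0, 0.3, (H, H))
W_ih = rng.normal(0, 0.3, H)

def final_state(m_pos, s_pos, m_val=1.0, s_val=0.6):
    """Final hidden state for a sequence that is zero except for M and S."""
    x = np.zeros(T)
    x[m_pos], x[s_pos] = m_val, s_val
    h = np.zeros(H)
    for t in range(T):
        h = np.maximum(W_hh @ h + W_ih * x[t], 0.0)
    return h

# offset(m, s) = h_forward - h_reverse: swap which of the two positions holds M.
pairs = [(m, s) for m in range(T) for s in range(T) if m < s]
offsets = np.array([final_state(m, s) - final_state(s, m) for m, s in pairs])

# Effective rank of the offset family (the post reports roughly rank 2).
sv = np.linalg.svd(offsets - offsets.mean(axis=0), compute_uv=False)
print("fraction of variance in top 2 components:",
      ((sv[:2] ** 2).sum() / (sv ** 2).sum()).round(3))
```

The separability claim can be checked the same way, by fitting offset(m, s) with a least-squares f(m) - f(s) model over the grid of position pairs and inspecting the residual.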
Shared Subspace
Both mechanisms operate in the same low-dimensional subspace of the hidden state.
- f(m) PC1 ≈ Main PC2 (cosine = 0.92)
- f(m) PC2 ≈ Main PC1 (cosine = 0.67)
The position-by-position correlation is only 0.04 — the spirals carry orthogonal information.
The discrimination offset is smaller, ~1/4 magnitude. The main spiral does the bulk of the position encoding, and the offset provides a correction to shift the readout between M and S.
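A small sketch of the subspace comparison above; main_traj and offset_vecs are placeholder arrays standing in for the recorded main-spiral states and the f(m) offsets from the model.

```python
import numpy as np

def top_pcs(X, k=2):
    """Top-k principal directions of a set of vectors (rows of X)."""
    _, _, Vt = np.linalg.svd(X - X.mean(axis=0), full_matrices=False)
    return Vt[:k]

# Placeholders; in the actual analysis these come from the trained RNN.
rng = np.random.default_rng(3)
main_traj = rng.normal(size=(20, 16))     # hidden states along the main spiral
offset_vecs = rng.normal(size=(20, 16))   # f(m) offset vectors

# Cosine similarity between each offset PC and each main-spiral PC (the post
# reports offset-PC1 ~ main-PC2 at 0.92 and offset-PC2 ~ main-PC1 at 0.67).
cos = np.abs(top_pcs(offset_vecs) @ top_pcs(main_traj).T)
print(cos.round(2))
```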
How Discrimination Works
The offset f(m) - f(s) projects through W_out to create discriminative logits. For a forward case, the offset suppresses the early position (M) and boosts the late position (S). For the reverse case, the sign flips, so the offset boosts the early position and suppresses the late position.
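In code, the discrimination term is just the offset pushed through the readout; a minimal sketch with placeholder values (W_out, h_forward, h_reverse are stand-ins, not the model's):

```python
import numpy as np

# Placeholders; in the real analysis these come from the trained model.
rng = np.random.default_rng(4)
n_positions, H = 10, 16
W_out = rng.normal(size=(n_positions, H))
h_forward = rng.normal(size=H)        # final state, M before S
h_reverse = rng.normal(size=H)        # final state with M and S swapped

offset = h_forward - h_reverse
delta_logits = W_out @ offset         # discriminative contribution to the logits

# With a linear readout, the logit gap between the two cases is exactly this term;
# by construction it enters with the opposite sign in the reverse case.
assert np.allclose(W_out @ h_forward - W_out @ h_reverse, delta_logits)
```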
ReLU Boundary Crossing
The offset and readout are primarily linear; ReLU boundaries are crossed relatively infrequently as we vary the M and S positions and magnitudes.
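One way to quantify that, sketched on a toy random-weight RNN (the real check would use the trained weights): record the ReLU on/off masks for a reference input and count how many units flip as the S magnitude is swept; positions can be swept the same way.

```python
import numpy as np

# Toy ReLU RNN stand-in.
rng = np.random.default_rng(5)
H, T = 16, 10
W_hh = rng.normal(0, 0.3, (H, H))
W_ih = rng.normal(0, 0.3, H)

def relu_mask_trace(m_pos, s_pos, m_val, s_val):
    """Sequence of ReLU on/off patterns D[t] for one input configuration."""
    x = np.zeros(T)
    x[m_pos], x[s_pos] = m_val, s_val
    h, masks = np.zeros(H), []
    for t in range(T):
        pre = W_hh @ h + W_ih * x[t]
        masks.append(pre > 0)
        h = np.maximum(pre, 0.0)
    return np.array(masks)

# Count how many ReLU units flip relative to a reference run as s_val varies;
# few flips means the dynamics stay inside one linear region.
ref = relu_mask_trace(2, 7, 1.0, 0.6)
for s_val in np.linspace(0.3, 0.9, 7):
    flips = (relu_mask_trace(2, 7, 1.0, s_val) != ref).sum()
    print(f"s_val={s_val:.2f}  ReLU flips vs reference: {flips}")
```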
I likewise got nerd-sniped into taking this one on! It’s been good fun to work on.
My current description of the circuit behaviour is pretty lengthy and has a fair amount of hand waving, so I need to work on reaching a more compact description of what is going on. Some notes:
Zeroing out all the inputs except the largest two gets the network to 100% and makes it a lot easier to see the behaviour of some of the oscillatory sub-circuits. Zeroing out everything except the max helps by showing the impulse-response behaviour.
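A tiny sketch of that input ablation; keep_top_k is a hypothetical helper (not from the model code), and the real experiment feeds these masked sequences to the trained RNN.

```python
import numpy as np

def keep_top_k(x, k):
    """Zero out every element of the input sequence except the k largest."""
    x = np.asarray(x, dtype=float)
    keep = np.argsort(x)[-k:]
    out = np.zeros_like(x)
    out[keep] = x[keep]
    return out

x = np.array([0.2, 0.9, 0.1, 0.5, 0.3, 0.7])
print(keep_top_k(x, 2))   # only the max and 2nd max survive: isolates the pairwise circuit
print(keep_top_k(x, 1))   # only the max survives: impulse-response view of the dynamics
```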
Almost all ablations hurt the accuracy dramatically—the model makes use of all neurons. There appear to be two different ways in which the output is encoded, depending on whether the 2nd largest input comes before or after the largest.
Based on behaviour and the recurrence matrix, I've notionally divided the neurons up into:
- Comparators
- Wave neurons
- Bridge neurons
- Special cases: n2, n4, n9
There are some interesting clipping patterns among the comparator neurons: when the max input comes first, there is a unique clipping pattern for each gap between the max and the 2nd value; when the 2nd value comes first, all comparators clip due to the max value.
n7 does a fairly pure comparison with the running max val.
There is definitely more to the picture than what I currently understand! I'm going to keep working on it and see where I get to.
Now, I’m not sure I’ve exactly followed the brief, but I think there is some interesting stuff here: https://gist.github.com/mrsirrisrm/d6850ff8647d1ed2f67cc92d5bce3ed0
If we focus on the compute_final_state func:
With known D sequences, the RNN dynamics are piecewise-linear.
The final state is computed as:
pre = first_val * w_first[gap] + second_val * W_ih        # pre-activation at the step the 2nd value arrives
h_second = D_second_mask * max(0, pre)                     # ReLU, with the recorded active-unit mask applied
h_final = Phi_post[(gap,dir)][steps_after] @ h_second      # collapsed linear map for the remaining steps
The nonlinearities have been incorporated into Phi_post and so we can do eigenvalue analysis on it, e.g. seeing the difference between the forward (M then S) and reverse (S then M) directions. Note that there is a different Phi_post for every M,S position pair.
In the forward direction, spectral radius is 0.76 – 1.05 with a small spectral gap, while in the reverse direction it is 1.37 – 2.91 and the spectral gap is larger. So quite different dynamics are in play in the two directions.
W_out @ Phi_post is effectively 2-3 dimensional: for forward, the top 3 singular values explain 94–98% of logit variance. For reverse, the top singular value alone explains 87–93%.
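A sketch of both analyses on a stand-in Phi_post, built as described above from ReLU-masked copies of W_hh. The weights and masks here are random, so the printed numbers will not match the reported ranges; only the procedure is shown.

```python
import numpy as np

# Stand-in construction of Phi_post: a product of D-masked copies of W_hh for the
# steps after the second value arrives (random weights and masks, illustrative only).
rng = np.random.default_rng(6)
H, n_positions, steps_after = 16, 10, 5
W_hh = rng.normal(0, 0.3, (H, H))
W_out = rng.normal(size=(n_positions, H))
Phi_post = np.eye(H)
for _ in range(steps_after):
    D = np.diag((rng.random(H) > 0.3).astype(float))   # stand-in ReLU mask for this step
    Phi_post = D @ W_hh @ Phi_post

# Eigenvalue view: spectral radius and spectral gap of the collapsed operator.
eig = np.sort(np.abs(np.linalg.eigvals(Phi_post)))[::-1]
print("spectral radius:", eig[0].round(3), " gap to next:", (eig[0] - eig[1]).round(3))

# Readout view: how many directions of Phi_post actually matter for the logits.
sv = np.linalg.svd(W_out @ Phi_post, compute_uv=False)
explained = (sv ** 2).cumsum() / (sv ** 2).sum()
print("variance explained by top 1 / top 3 singular values:",
      explained[0].round(3), explained[2].round(3))
```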