The readout mechanism for S (2nd max) in the presence of M (max) combines two computations in a shared low-dimensional subspace
Phase Wheel
The hidden state follows a spiral trajectory through time, implemented as a rotating phase. The W_out projection converts phase angle into position logits. The shape of the main spiral does not differ between the forward (M first) and reverse (S first) cases.
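As a toy illustration of the phase wheel (everything here is idealized and not the trained weights: a pure 2D rotation plane, and a W_out whose rows tile position tuning around the circle):

import numpy as np

T = 20                                     # number of positions
theta = 2 * np.pi / T                      # phase advance per step
R = np.array([[np.cos(theta), -np.sin(theta)],
              [np.sin(theta),  np.cos(theta)]])   # one step of the wheel

angles = 2 * np.pi * np.arange(T) / T
W_out = np.stack([np.cos(angles), np.sin(angles)], axis=1)  # (T, 2) readout

h = np.array([1.0, 0.0])                   # phase set by the impulse
for t in range(T):
    logits = W_out @ h                     # phase angle -> position logits
    assert logits.argmax() == t            # readout tracks elapsed time
    h = R @ h                              # rotate the phase one step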
Discrimination Offset
The network must discriminate between the very similar forward and reverse cases. The final hidden states differ by an offset:
h_forward = h_reverse + offset(m, s)
The offset is separable and antisymmetric:
offset(m, s) = f(m) + g(s), where g(s) = -f(s); equivalently, offset(m, s) = f(m) - f(s).
The network applies +f for the M position and -f for the S position.
The offset has effective rank ~2, and is also an approximate spiral in PCA space.
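Both claims are easy to check numerically. A sketch, assuming h_fwd and h_rev are hypothetical (P, P, H) arrays of final hidden states recorded over a grid of (m, s) positions:

import numpy as np

def offset_diagnostics(h_fwd, h_rev):
    offset = h_fwd - h_rev                        # offset(m, s), (P, P, H)
    # Separable fit offset(m, s) ~ f(m) + g(s) from row/column means
    f = offset.mean(axis=1)                       # f up to a constant
    g = offset.mean(axis=0)                       # g up to a constant
    recon = f[:, None, :] + g[None, :, :] - offset.mean(axis=(0, 1))
    sep_err = np.linalg.norm(offset - recon) / np.linalg.norm(offset)
    # Antisymmetry: offset(m, s) ~ -offset(s, m)
    anti = offset + offset.transpose(1, 0, 2)
    anti_err = np.linalg.norm(anti) / np.linalg.norm(offset)
    # Effective rank of f from its singular value spectrum
    svals = np.linalg.svd(f - f.mean(axis=0), compute_uv=False)
    return sep_err, anti_err, svals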
Shared Subspace
Both mechanisms operate in the same low-dimensional subspace of the hidden state.
- f(m) PC1 ≈ Main PC2 (cosine = 0.92)
- f(m) PC2 ≈ Main PC1 (cosine = 0.67)
The position-by-position correlation between the two spirals is only 0.04: they share a subspace but carry essentially orthogonal information.
The discrimination offset is the smaller of the two, at roughly 1/4 of the main spiral’s magnitude. The main spiral does the bulk of the position encoding; the offset provides a correction that shifts the readout between M and S.
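The cosines above come from comparing PCA bases of the two trajectories. A sketch of that comparison, where main_traj is a (T, H) array of main-spiral states and f_vals a (P, H) array of f(m) vectors (both names are mine):

import numpy as np

def pca_basis(X, k=2):
    # Top-k principal directions of the rows of X
    Xc = X - X.mean(axis=0)
    _, _, Vt = np.linalg.svd(Xc, full_matrices=False)
    return Vt[:k]                                 # (k, H)

main_pcs = pca_basis(main_traj)                   # main spiral PCs
f_pcs = pca_basis(f_vals)                         # f(m) PCs
cos = np.abs(f_pcs @ main_pcs.T)                  # |cosine| matrix (PC signs
print(cos)                                        # are arbitrary)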
How Discrimination Works
The offset f(m) - f(s) projects through W_out to create discriminative logits. In the forward case, the offset suppresses the early position (M) and boosts the late position (S). In the reverse case, the sign of the offset flips, so it boosts the early position (S) and suppresses the late position (M).
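In code this is one line: push the offset through the readout and check the signs at positions m and s (f and W_out as above):

import numpy as np

def offset_logits(W_out, f, m, s):
    # Logit contribution of the offset in the forward (M-first) case;
    # the reverse case uses -delta, flipping both signs
    delta = W_out @ (f[m] - f[s])
    return delta                     # expect delta[s] > 0 > delta[m]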
ReLU boundary crossing
The offset and readout are primarily linear: ReLU boundaries are crossed relatively infrequently as we vary the M and S positions and magnitudes.
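One way to quantify "relatively infrequent" is to record the ReLU activation pattern for each input and count how often it changes between neighbouring inputs. A sketch, where run_rnn is a hypothetical helper returning the (T, H) pre-activations for given M and S positions, and positions are consecutive integers:

import numpy as np

def boundary_crossing_rate(run_rnn, positions):
    # Activation pattern D for each (m, s) input on the grid
    masks = {(m, s): run_rnn(m, s) > 0
             for m in positions for s in positions}
    changed, total = 0, 0
    for m in positions[:-1]:
        for s in positions[:-1]:
            for nbr in [(m + 1, s), (m, s + 1)]:  # grid neighbours
                changed += (masks[(m, s)] != masks[nbr]).any()
                total += 1
    return changed / total      # fraction of neighbour pairs that cross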
Can you turn this argument into a mechanistic estimate of the model’s accuracy? (You’d need to do things like deduce correlations from the weights, rather than just observe them empirically—but it seems like you’re getting close.)
I’m getting close indeed! I did take a big detour into tropical geometry… The key approach at the moment is deriving a sequence D[t] of which ReLUs are active during the sequence; this turns out to be quite stable after an impulse. Then we can mask W_hh via D[t] and combine all the (now-linear) steps into an effective linear operator for the whole sequence, which we can then investigate with normal linear methods.
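As a sketch of that masking-and-composition step (W_hh is the recurrent weight matrix, D the recorded (T, H) boolean activation sequence; the helper name is mine):

import numpy as np

def effective_operator(W_hh, D):
    # Each linearized step is h -> diag(D[t]) @ W_hh @ h; compose them all
    Phi = np.eye(W_hh.shape[0])
    for t in range(D.shape[0]):
        Phi = (D[t][:, None] * W_hh) @ Phi    # row-mask W_hh by D[t]
    return Phi                                # whole-sequence linear operator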
Now, I’m not sure I’ve exactly followed the brief, but I think there is some interesting stuff here: https://gist.github.com/mrsirrisrm/d6850ff8647d1ed2f67cc92d5bce3ed0
If we focus on the compute_final_state func:
With known D sequences, the RNN dynamics are piecewise linear.
The final state is computed as:
pre = first_val * w_first[gap] + second_val * W_ih
h_second = D_second_mask * max(0, pre)
h_final = Phi_post[(gap,dir)][steps_after] @ h_second
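Read as runnable numpy (my reconstruction from the lines above, not the gist’s exact code; w_first, W_ih, D_second_mask, and Phi_post are assumed precomputed as in the gist):

import numpy as np

def compute_final_state(first_val, second_val, gap, direction, steps_after,
                        w_first, W_ih, D_second_mask, Phi_post):
    # Linear contributions of both impulses at the second-impulse step
    pre = first_val * w_first[gap] + second_val * W_ih
    # ReLU, with the known activation pattern for that step
    h_second = D_second_mask * np.maximum(0.0, pre)
    # The remaining steps collapse into one effective linear operator
    return Phi_post[(gap, direction)][steps_after] @ h_second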
The nonlinearities have been incorporated into Phi_post, so we can do eigenvalue analysis on it, e.g. seeing the difference between the forward (M then S) and reverse (S then M) directions. Note that there is a different Phi_post for every (M, S) position pair.
In the forward direction the spectral radius is 0.76 – 1.05 with a small spectral gap, while in the reverse direction it is 1.37 – 2.91 and the spectral gap is larger. So quite different dynamics are in play in the two directions.
W_out @ Phi_post is effectively 2–3 dimensional: in the forward direction the top 3 singular values explain 94–98% of logit variance, while in the reverse direction the top singular value alone explains 87–93%.
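For reference, those numbers come from standard routines applied to each effective operator. A sketch, with Phi_post here a single (H, H) matrix and squared singular values used as the logit-variance proxy:

import numpy as np

def spectral_summary(Phi_post, W_out):
    mags = np.sort(np.abs(np.linalg.eigvals(Phi_post)))[::-1]
    radius, gap = mags[0], mags[0] - mags[1]      # spectral radius and gap
    svals = np.linalg.svd(W_out @ Phi_post, compute_uv=False)
    explained = svals**2 / (svals**2).sum()       # fraction per direction
    return radius, gap, explained[:3]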