Can you turn this argument into a mechanistic estimate of the model’s accuracy? (You’d need to do things like deduce correlations from the weights, rather than just observe them empirically—but it seems like you’re getting close.)
I’m getting close indeed! I did take a big detour into tropical geometry… The key approach at the moment is deriving a sequence D[t] recording which ReLUs are active at each step of the sequence; this turns out to be quite stable after an impulse. Then we can mask W_hh via D[t] and combine all the (now-linear) steps to get an effective linear operator for the whole sequence, which we can then investigate with normal linear methods.
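To make the masking idea concrete, here is a minimal NumPy sketch, assuming a plain ReLU RNN with no bias and no input after the impulse (h <- relu(W_hh @ h)). The function name `effective_operator`, the toy weights, and the starting state are all made up for illustration; this is not the code from the gist.

```python
import numpy as np

def relu(x):
    return np.maximum(x, 0.0)

def effective_operator(W_hh, h0, n_steps):
    """Run h <- relu(W_hh @ h) for n_steps, recording the activation
    mask D[t] at each step, and accumulate the product of the masked
    matrices diag(D[t]) @ W_hh into a single linear operator Phi."""
    n = W_hh.shape[0]
    Phi = np.eye(n)
    h = h0.copy()
    masks = []
    for _ in range(n_steps):
        pre = W_hh @ h
        D = (pre > 0).astype(float)       # which ReLUs are active at this step
        h = relu(pre)                     # the actual nonlinear update
        Phi = (D[:, None] * W_hh) @ Phi   # diag(D) @ W_hh, composed with earlier steps
        masks.append(D)
    return Phi, h, masks

# Hypothetical weights and post-impulse state, for illustration only.
rng = np.random.default_rng(0)
W_hh = rng.normal(scale=0.5, size=(8, 8))
h0 = relu(rng.normal(size=8))

Phi, h_final, masks = effective_operator(W_hh, h0, n_steps=5)
# For this particular trajectory the linear operator reproduces the nonlinear run.
print(np.allclose(Phi @ h0, h_final))
```

Note that Phi is only valid for inputs that produce the same mask sequence, which is why the stability of D[t] after an impulse matters.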
Now, I’m not sure I’ve exactly followed the brief, but I think there is some interesting stuff here: https://gist.github.com/mrsirrisrm/d6850ff8647d1ed2f67cc92d5bce3ed0
If we focus on the compute_final_state func:
with known D sequences, the RNN dynamics are piecewise-linear.
The final state is computed as:
pre = first_val * w_first[gap] + second_val * W_ih
h_second = D_second_mask * max(0, pre)
h_final = Phi_post[(gap,dir)][steps_after] @ h_second
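The three lines above could be written out roughly as follows. This is a sketch under assumptions, not the gist's implementation: the argument layout, the container types (`w_first` as a dict keyed by gap, `Phi_post` as a dict keyed by (gap, direction) holding one matrix per `steps_after`), and the toy values are all hypothetical.

```python
import numpy as np

def compute_final_state(first_val, second_val, gap, direction, steps_after,
                        w_first, W_ih, D_second_mask, Phi_post):
    """Piecewise-linear final state, given known activation masks.

    Assumed shapes (hypothetical): w_first[gap], W_ih and D_second_mask are
    vectors of length n; Phi_post[(gap, direction)][steps_after] is n x n.
    """
    # Pre-activation at the second impulse: propagated first impulse plus the new input.
    pre = first_val * w_first[gap] + second_val * W_ih
    # Apply the ReLU, then keep only the units the known mask marks as active.
    h_second = D_second_mask * np.maximum(0.0, pre)
    # All remaining steps are folded into one linear operator.
    return Phi_post[(gap, direction)][steps_after] @ h_second

# Toy data, for illustration only.
n = 4
rng = np.random.default_rng(1)
w_first = {2: rng.normal(size=n)}
W_ih = rng.normal(size=n)
D_second_mask = np.array([1.0, 0.0, 1.0, 1.0])
Phi_post = {(2, "fwd"): [np.eye(n), 0.5 * np.eye(n)]}

h_final = compute_final_state(1.0, 0.8, 2, "fwd", 1,
                              w_first, W_ih, D_second_mask, Phi_post)
print(h_final)
```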
The nonlinearities have been incorporated into Phi_post, so we can do eigenvalue analysis on it, e.g. to see the difference between the forward (M then S) and reverse (S then M) directions. Note that there is a different Phi_post for every M,S position pair.
In the forward direction, the spectral radius is 0.76–1.05 with a small spectral gap, while in the reverse direction it is 1.37–2.91 and the spectral gap is larger. So quite different dynamics are in play in the two directions.
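For reference, these two quantities can be read off an operator as in the sketch below. The message does not say which definition of spectral gap is used, so taking it as the difference between the two largest eigenvalue magnitudes is an assumption; `spectral_summary` and the toy matrix are hypothetical.

```python
import numpy as np

def spectral_summary(Phi):
    """Return (spectral radius, spectral gap) of a square matrix.

    The gap is taken here as |lambda_1| - |lambda_2|; other conventions
    (e.g. the ratio of the two magnitudes) are equally plausible.
    """
    mags = np.sort(np.abs(np.linalg.eigvals(Phi)))[::-1]  # magnitudes, descending
    return mags[0], mags[0] - mags[1]

# Toy operator with known eigenvalues 1.5, 0.5 and 0.2.
Phi_toy = np.diag([1.5, 0.5, 0.2])
radius, gap = spectral_summary(Phi_toy)
print(radius, gap)
```

A radius above 1 means the dominant mode grows with each application of the operator, while a radius below 1 means it decays.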
W_out @ Phi_post is effectively 2–3 dimensional: in the forward direction, the top 3 singular values explain 94–98% of logit variance. In the reverse direction, the top singular value alone explains 87–93%.
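One way to get numbers like these is the share of squared singular values carried by the top k, sketched below. This assumes "variance explained" means that energy fraction, which treats all input directions equally; the original analysis may weight by the actual distribution of h_second. `explained_by_top_k` and the toy matrices are hypothetical.

```python
import numpy as np

def explained_by_top_k(W_out, Phi, k):
    """Fraction of the squared singular values of W_out @ Phi
    that is carried by the k largest ones."""
    s = np.linalg.svd(W_out @ Phi, compute_uv=False)  # returned in descending order
    energy = s ** 2
    return energy[:k].sum() / energy.sum()

# Toy example: a rank-2 Phi, so two singular values should carry everything.
rng = np.random.default_rng(2)
W_out = rng.normal(size=(10, 16))
Phi = rng.normal(size=(16, 2)) @ rng.normal(size=(2, 16))

frac1 = explained_by_top_k(W_out, Phi, 1)
frac2 = explained_by_top_k(W_out, Phi, 2)
print(frac1, frac2)
```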