Very glad to hear you’re okay!
Point Out Anything Suspicious
I’ve found this highly valuable. Doctors/nurses almost indubitably know more than I do, but also indubitably have way less time to obsessively stare at scans and charts for hours on end.
And congrats on the wedding!
Thanks for the question! I think your picture of what MLPs are doing here aligns fairly well with my understanding. But attention is doing more than just rotation and alignment checking, and I think it’s plausible a priori to suspect it could be responsible for plateaus.
Due to the softmax, scaling a single one of the attention scores can naturally have a ‘plateau-like’ effect on the attention pattern—as its score becomes substantially smaller than the other scores, its value post-softmax will become ~0, and as it becomes substantially larger than the other scores, its value post-softmax will become ~1. So moving in the residual stream in a direction which disproportionately impacts one of the attention scores could be expected to produce plateau-like outputs from that head.
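To make the softmax saturation concrete, here is a minimal sketch (not code from our experiments) that sweeps one attention score while holding a few hypothetical "other" scores fixed. The weight on the swept token stays near 0 until its score approaches the others, then rises quickly and flattens out near 1.

```python
import math

def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical fixed scores for four other tokens (illustrative values only).
other_scores = [0.0, 0.5, -0.3, 0.2]

def weight_on_scaled_token(c):
    """Post-softmax attention weight on token 0 when its pre-softmax score is c."""
    return softmax([c] + other_scores)[0]

for c in [-10, -5, -2, 0, 2, 5, 10]:
    print(f"score={c:>4}: weight={weight_on_scaled_token(c):.4f}")
```

The flat regions at either end are the "plateaus": once one score is far below or far above the rest, further movement in that direction barely changes the pattern.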
For example, we can construct a residual stream vector which, when passed into attention, yields a query pointing in the same direction as the key for a chosen earlier token. Because the query for this vector and the key for that token point in the same direction, their product (the attention score) will tend to be large. If we interpolate between two such residual stream vectors (in this case selected to yield high scores on token 3 and token 6), we get a very plateau-like effect on the attention outputs.
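A toy version of this interpolation can be sketched as follows. Everything here is a hypothetical stand-in: one-hot keys, scalar values equal to the token index, W_Q taken to be the identity (so interpolating the residual vector is the same as interpolating the query), and a query norm `SCALE` chosen to make the effect visible. Keys are held fixed, i.e. we only move the query.

```python
import math

def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

D = 8          # hypothetical number of tokens / model dimension
SCALE = 20.0   # hypothetical query norm; larger values give sharper plateaus

# Hypothetical keys: token j has a one-hot key along axis j.
keys = [[1.0 if i == j else 0.0 for i in range(D)] for j in range(D)]
# Hypothetical scalar values, so the output reads off which token is attended to.
values = [float(j) for j in range(D)]

def query_for(token):
    """Query aligned with the key of `token` (assumes W_Q = identity)."""
    return [SCALE * k for k in keys[token]]

def attention_output(t, a=3, b=6):
    """Attention output as the query is linearly interpolated from token a to token b."""
    qa, qb = query_for(a), query_for(b)
    q = [(1 - t) * x + t * y for x, y in zip(qa, qb)]
    scores = [sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
    weights = softmax(scores)
    return sum(w * v for w, v in zip(weights, values))

for i in range(11):
    t = i / 10
    print(f"t={t:.1f}: output={attention_output(t):.3f}")
```

Even though the query moves linearly in t, the output sits near 3 for most of the first half, jumps through 4.5 around t = 0.5, and then sits near 6.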
So the softmax means we definitely can get plateau-like outputs from attention when interpolating in the residual stream, even though we don’t find attention to be responsible for the plateaus we’ve found in real models.
Side note: I think there’s an additional nonlinearity in that the QK product (before the softmax) is bilinear. Interpolating the last token in a sequence should have a linear effect on the QK product with earlier tokens, since only Q is changing. But this interpolation will change both the Q and K for the token at which we’re interpolating. Bilinearity therefore means residual stream interpolation will have a nonlinear effect on the attention score relating the last token to itself.
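A small numerical check of the bilinearity point, as a sketch under toy assumptions: 2D hypothetical W_Q, W_K, and residual vectors, no LayerNorm, and raw QK dot products. A function linear in t has its midpoint value equal to the average of its endpoints; the score against an earlier token passes that test, while the token's score with itself (a quadratic form in the interpolated residual) does not.

```python
def matvec(M, x):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(m * xi for m, xi in zip(row, x)) for row in M]

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def lerp(u, v, t):
    return [(1 - t) * a + t * b for a, b in zip(u, v)]

# Hypothetical weights and residual vectors, chosen only for illustration.
W_Q = [[1.0, 2.0], [0.0, 1.0]]
W_K = [[0.5, 0.0], [1.0, -1.0]]
x_a = [1.0, 0.0]
x_b = [0.0, 1.0]
x_earlier = [1.0, 1.0]  # earlier token whose residual is NOT interpolated

def scores(t):
    """QK scores of the interpolated last token against an earlier token and itself."""
    x = lerp(x_a, x_b, t)
    q = matvec(W_Q, x)
    cross = dot(q, matvec(W_K, x_earlier))  # only Q depends on t: linear in t
    self_score = dot(q, matvec(W_K, x))     # Q and K both depend on t: quadratic in t
    return cross, self_score

def midpoint_gap(f):
    """Zero for any function that is linear in t."""
    return f(0.5) - (f(0.0) + f(1.0)) / 2

print(f"cross-score midpoint gap: {midpoint_gap(lambda t: scores(t)[0]):.3f}")
print(f"self-score midpoint gap:  {midpoint_gap(lambda t: scores(t)[1]):.3f}")
```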