How transformers can compute distances along a curve locally.

Overview:

There is an interesting mechanism GPT-2 appears to use to measure distances between duplicate tokens. The mechanism reminds me a lot of twisted-pair cabling in communications.

The mechanism is fiddly to explain in context, so I’ve tried to abstract away most of the details and give a clean toy version of it. I think some of the structures in this mechanism could be useful in contexts other than computing distances.

Setup:

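We have the set of points $\{x_1, \ldots, x_n\}$, with $x_i = f(i)$, where $f: \mathbb{R} \to \mathbb{R}^d$ is a smooth function. We take as input $x_i, x_j$, with $i, j \in \{1, \ldots, n\}$. We want to construct a transformer which can estimate $i - j$ given $x_i, x_j$. For this transformer, we additionally assume that $n$ is relatively small compared to the embedding dimension $d$.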

The mechanism:

We set up $M$ “sentry points” $f(p_1), \ldots, f(p_M)$ uniformly along $f$, and we define $\sigma$ sending $x_i$ to the index of the closest sentry point.

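Then, for $j$ close to $i$, a first-order Taylor expansion of $f$ around the nearest sentry gives

$$v_{\sigma(i)} \cdot (x_i - x_j) \approx (i - j)\, v_{\sigma(i)} \cdot f'(p_{\sigma(i)}),$$

where $v_k$ is such that $v_k \cdot f'(p_k) = 1$ for all $k$.

So $v_{\sigma(i)} \cdot (x_i - x_j) \approx i - j$.

Therefore if we can approximate $v_{\sigma(i)} \cdot (x_i - x_j)$, then we can approximate $i - j$.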
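To make the local approximation concrete, here is a minimal Python sketch with no transformer machinery. It assumes a toy curve, a circular arc $f(t) = (\cos \omega t, \sin \omega t)$ with $\omega = 0.03$ and $n = 100$, and ten sentries at the hypothetical positions $p_k = 10k + 5.5$; none of these numbers come from GPT-2. It computes $\sigma$, $v_k$ and the estimate $v_{\sigma(i)} \cdot (x_i - x_j)$ directly.

```python
import math

# Hypothetical toy curve (an assumption, not GPT-2's positional embeddings):
# a circular arc that never comes close to meeting itself for t in [1, 100].
OMEGA = 0.03
# Ten sentries spaced 10 positions apart (offset chosen for illustration).
SENTRIES = [10 * k + 5.5 for k in range(10)]


def f(t):
    return (math.cos(OMEGA * t), math.sin(OMEGA * t))


def f_prime(t):
    # Analytic derivative of the arc above.
    return (-OMEGA * math.sin(OMEGA * t), OMEGA * math.cos(OMEGA * t))


def dot(a, b):
    return sum(p * q for p, q in zip(a, b))


def sub(a, b):
    return tuple(p - q for p, q in zip(a, b))


def sigma(x):
    """Index of the sentry point f(p_k) closest to x in embedding space."""
    def sq_dist(k):
        d = sub(x, f(SENTRIES[k]))
        return dot(d, d)
    return min(range(len(SENTRIES)), key=sq_dist)


def v(k):
    """v_k = f'(p_k) / |f'(p_k)|^2, so that v_k . f'(p_k) = 1."""
    d = f_prime(SENTRIES[k])
    norm_sq = dot(d, d)
    return tuple(c / norm_sq for c in d)


def estimate_offset(i, j):
    """Local estimate of i - j as v_{sigma(i)} . (x_i - x_j)."""
    x_i, x_j = f(i), f(j)
    return dot(v(sigma(x_i)), sub(x_i, x_j))
```

Under these assumptions `estimate_offset(52, 47)` comes out a little under 5 and `estimate_offset(47, 52)` a little above -5. The error grows as $|i - j|$ becomes large compared to the sentry spacing, since the Taylor expansion is only local.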

Attention mechanism:

If we are given a two-token input $x_j$, $x_i$ to the transformer, then, assuming each head can put almost all of its attention on a single token, two attention heads are sufficient to compute $x_i - x_j$ (have one head which outputs $-x_j$, and the other $x_i$). We write $x_i - x_j$ to a subspace orthogonal to $x_i$ so that the MLP can cleanly access $x_i$ later.
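A sketch of the two-head attention step, under loud assumptions: the attention logits are hand-set (a large hypothetical `sharpness` value on the target token) rather than computed from queries and keys, the OV circuits are $-I$ and $+I$, and the residual stream at position $i$ is modelled as $2d$ coordinates whose second half plays the role of the subspace orthogonal to $x_i$.

```python
import math


def softmax(scores):
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attend(weights, vectors):
    # Attention-weighted sum of the (identity-projected) value vectors.
    return [sum(w * vec[c] for w, vec in zip(weights, vectors)) for c in range(len(vectors[0]))]


def two_head_difference(x_j, x_i, sharpness=30.0):
    """Toy attention layer producing x_i - x_j at the second position (hypothetical)."""
    tokens = [x_j, x_i]
    # Head 1: hand-set logits put (almost) all attention on x_j; OV circuit = -I.
    head1 = [-c for c in attend(softmax([sharpness, 0.0]), tokens)]
    # Head 2: hand-set logits put (almost) all attention on x_i; OV circuit = +I.
    head2 = attend(softmax([0.0, sharpness]), tokens)
    # Residual stream at position i: x_i in the first d coordinates, the summed
    # head outputs (about x_i - x_j) in d extra coordinates orthogonal to them.
    return list(x_i) + [a + b for a, b in zip(head1, head2)]
```

For example, `two_head_difference([1.0, 2.0], [4.0, 6.0])` keeps `[4.0, 6.0]` in the first half and puts approximately `[3.0, 4.0]` in the second half; the leftover attention on the wrong token is of order $e^{-30}$.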

MLP mechanism:

The MLP mechanism consists of $2M$ neurons, with a pair of neurons associated with each sentry point.

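For each sentry point $k$, we define a neuron pair:

$$N_k^{+} = \mathrm{ReLU}\big(a_k \cdot x_i + b_k + w_k \cdot (x_i - x_j)\big), \qquad N_k^{-} = \mathrm{ReLU}\big(a_k \cdot x_i + b_k - w_k \cdot (x_i - x_j)\big),$$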
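where $a_k$, $b_k$ are tuned so that $a_k \cdot x_i + b_k > 0$ when $\sigma(i) = k$, and $a_k \cdot x_i + b_k < 0$ otherwise. Additionally we set up $w_k$ so that $w_k \cdot (x_i - x_j)$ always has a magnitude of less than $|a_k \cdot x_i + b_k|$.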

Example sentry neuron activations with $M = 5$ (each colour corresponds to the activation of a different sentry neuron). We can pick $M+1$ coprime to $n$ so that the active sentry neuron doesn’t vanish at $x_i$ for any $i$, and so its activation stays above some $\epsilon > 0$. A signal of magnitude less than $\epsilon$ can then be encoded in the difference between the activations of each pair of sentry neurons.

Setting up these sentries relies on $f$ not coming close to intersecting itself, so that the dot product with $a_k$ is only high on a connected interval.
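Here is a minimal check of the sentry construction on a hypothetical toy curve: a circular arc $f(t) = (\cos \omega t, \sin \omega t)$ with $\omega = 0.03$, $n = 100$ and $M = 10$ sentries at $p_k = 10k + 5.5$. Because the arc has constant norm, the linear choice $a_k = f(p_k)$, $b_k = -\cos(5\omega)$ is positive exactly on a connected interval around each sentry. This choice of $a_k, b_k$ is an assumption for the toy, not something taken from GPT-2.

```python
import math

# Hypothetical toy (an assumption): circular arc, n = 100 positions, M = 10 sentries.
OMEGA, N, M = 0.03, 100, 10
SPACING = N / M
# Sentries at 10k + 5.5, so interval boundaries fall on half-integers.
SENTRIES = [SPACING * k + SPACING / 2 + 0.5 for k in range(M)]


def f(t):
    return (math.cos(OMEGA * t), math.sin(OMEGA * t))


def sentry_term(k, x):
    """a_k . x + b_k with a_k = f(p_k) and b_k = -cos(OMEGA * SPACING / 2).

    On the arc, f(p_k) . f(t) = cos(OMEGA * (t - p_k)), so this is positive exactly
    when |t - p_k| < SPACING / 2, because OMEGA * n < pi keeps the arc from
    curving back towards f(p_k).
    """
    a_k = f(SENTRIES[k])
    b_k = -math.cos(OMEGA * SPACING / 2)
    return a_k[0] * x[0] + a_k[1] * x[1] + b_k


def active_sentries(i):
    return [k for k in range(M) if sentry_term(k, f(i)) > 0]


# Smallest activation of the active sentry over all positions 1..n.
epsilon = min(sentry_term(active_sentries(i)[0], f(i)) for i in range(1, N + 1))
```

On this toy every position activates exactly one sentry, and `epsilon` comes out at roughly 0.0021, so signals of smaller magnitude can be carried by the neuron pairs. If the sentries were instead placed at $10k + 5$, positions 10, 20 and so on would sit exactly on interval boundaries, where the active sentry term is essentially zero.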

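We wrote $x_i - x_j$ to a subspace orthogonal to $x_i$, so the sentry neurons can all be written in the standard form $\mathrm{ReLU}(u \cdot h + b)$, where $h$ is the residual stream post attention (for $N_k^{\pm}$, $u$ is $a_k$ on the $x_i$ subspace plus $\pm w_k$ on the difference subspace).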

We then output $N_k^{+} - N_k^{-}$ from the $k$th neuron pair.

Under this construction $N_k^{+}$ and $N_k^{-}$ always activate at the same time as each other, so the output of the $k$th neuron pair is $2\, w_k \cdot (x_i - x_j)$ if $\sigma(i) = k$, and $0$ if $\sigma(i) \neq k$.
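The case analysis for a single pair can be checked with a few lines of Python. Here `sentry_term` stands for $a_k \cdot x_i + b_k$ and `signal` for $w_k \cdot (x_i - x_j)$; passing both in as plain numbers is a simplification for the sketch.

```python
def relu(z):
    return max(0.0, z)


def neuron_pair_output(sentry_term, signal):
    """N_k^+ - N_k^- for one sentry neuron pair.

    sentry_term: a_k . x_i + b_k, positive only when sigma(i) == k.
    signal:      w_k . (x_i - x_j), assumed to satisfy |signal| < |sentry_term|.
    """
    n_plus = relu(sentry_term + signal)
    n_minus = relu(sentry_term - signal)
    return n_plus - n_minus
```

With `sentry_term = 1.0` and `signal = 0.25` the pair outputs 0.5 (that is, $2 \times 0.25$), and with `sentry_term = -1.0` it outputs 0. If the magnitude condition is violated, say `sentry_term = 0.1` and `signal = 0.5`, only $N_k^{+}$ fires and the output is 0.6 instead of 1.0.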

Since only a single neuron pair activates at once, the complete output of this MLP layer is $2\, w_{\sigma(i)} \cdot (x_i - x_j)$.

Then setting $w_k$ proportional to $v_k$, we get an output proportional to

$$v_{\sigma(i)} \cdot (x_i - x_j) \approx i - j.$$
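Putting the pieces together, here is a minimal sketch of the whole MLP layer on a hypothetical toy curve (a circular arc with $\omega = 0.03$, $n = 100$ and $M = 10$ sentries at $p_k = 10k + 5.5$; all assumptions). It uses $a_k = f(p_k)$, $b_k = -\cos(5\omega)$ for the sentry terms and $w_k = \lambda v_k$ with a small hypothetical $\lambda = 10^{-4}$, which in this toy keeps the signal below the sentry terms for $|i - j| \le 10$. Dividing the layer output by $2\lambda$ reads off the estimate of $i - j$.

```python
import math

# Hypothetical toy (an assumption): circular arc, n = 100, M = 10 sentries at 10k + 5.5.
OMEGA, N, M = 0.03, 100, 10
SPACING = N / M
SENTRIES = [SPACING * k + SPACING / 2 + 0.5 for k in range(M)]
LAMBDA = 1e-4  # w_k = LAMBDA * v_k; small so |w_k . (x_i - x_j)| < |sentry term|


def f(t):
    return (math.cos(OMEGA * t), math.sin(OMEGA * t))


def dot(a, b):
    return sum(p * q for p, q in zip(a, b))


def relu(z):
    return max(0.0, z)


def mlp_output(x_i, diff):
    """Sum over sentry pairs of N_k^+ - N_k^-, given x_i and diff = x_i - x_j.

    In a real layer x_i and diff live in orthogonal subspaces of one residual
    stream h; they are passed separately here for readability.
    """
    total = 0.0
    for p in SENTRIES:
        sentry = dot(f(p), x_i) - math.cos(OMEGA * SPACING / 2)  # a_k . x_i + b_k
        # v_k = f'(p_k) / |f'(p_k)|^2 = (-sin, cos)(OMEGA * p_k) / OMEGA on this arc.
        v_k = (-math.sin(OMEGA * p) / OMEGA, math.cos(OMEGA * p) / OMEGA)
        signal = LAMBDA * dot(v_k, diff)  # w_k . (x_i - x_j)
        total += relu(sentry + signal) - relu(sentry - signal)
    return total


def estimate_offset(i, j):
    x_i, x_j = f(i), f(j)
    diff = (x_i[0] - x_j[0], x_i[1] - x_j[1])
    # The layer outputs 2 * LAMBDA * v_{sigma(i)} . (x_i - x_j); rescale to read off i - j.
    return mlp_output(x_i, diff) / (2 * LAMBDA)
```

In this toy `estimate_offset(52, 47)` comes out near 4.9, and a larger offset such as `estimate_offset(51, 41)` (about 9.6) shows the error from the local approximation growing with $|i - j|$.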
Extensions of mechanism:

The above mechanism is clean and captures the key ideas. A nice thing about the mechanism is that the sentry term $a_k \cdot x_i + b_k$ can be noisy without it mattering: the noise is common to both neurons in a pair, so it cancels in $N_k^{+} - N_k^{-}$ (as long as both neurons stay active), similar to twisted-pair encoding.
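The twisted-pair style cancellation can be illustrated numerically. The sketch below adds hypothetical common-mode noise (uniform on $[-0.3, 0.3]$, an assumption) to a sentry term of 1.0 and compares a neuron pair carrying a signal of 0.2 against a single neuron carrying the same signal.

```python
import random


def relu(z):
    return max(0.0, z)


def pair_output(sentry_term, signal):
    # N^+ - N^- for one sentry neuron pair.
    return relu(sentry_term + signal) - relu(sentry_term - signal)


def single_neuron(sentry_term, signal):
    # A lone neuron carrying the same signal, for comparison.
    return relu(sentry_term + signal)


rng = random.Random(0)
signal = 0.2
# Hypothetical common-mode noise on the sentry term a_k . x_i + b_k (an assumption:
# uniform in [-0.3, 0.3], so the noisy term always stays above |signal|).
noisy_terms = [1.0 + rng.uniform(-0.3, 0.3) for _ in range(1000)]

pair_outputs = [pair_output(s, signal) for s in noisy_terms]
single_outputs = [single_neuron(s, signal) for s in noisy_terms]
```

Every pair output equals $2 \times 0.2 = 0.4$ up to floating-point rounding, while the single neuron's output moves with the noise. The cancellation only holds while the noisy sentry term stays above the signal's magnitude; noise large enough to switch one neuron off breaks it.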

However, noise can still cause problems at the transitions between sentries, where the sentry term is small. The mechanism is also not robust to $f$ intersecting itself, and the number of sentries required grows with $n$.

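There are cases where this kind of two-neuron interference cancellation could be useful outside of computing distances. For example, if you want to distinguish between British and Canadian English, you could have:

$$N^{+} = \mathrm{ReLU}(e \cdot h + b + u \cdot h), \qquad N^{-} = \mathrm{ReLU}(e \cdot h + b - u \cdot h),$$

where $e \cdot h + b$ detects English text in general and $u$ is a direction separating British from Canadian usage.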
And then take the difference between the two. The interference that would usually make it difficult to distinguish between the two very similar dialects gets cancelled out.

Though this is probably just PCA??