Understanding the tensor product formulation in Transformer Circuits
I was trying to understand the tensor product formulation in transformer circuits, and I had basically forgotten all I ever knew about tensor products, if I ever knew anything. This very brief post is aimed at me from Wednesday the 22nd, when I didn't understand why that formulation of attention was true. It basically just gives a bit more background and includes a few more steps. I hope it will be helpful to someone else, too.
Tensor product
To understand this, we first need tensor products. Given two finite-dimensional vector spaces $V, W$, we can construct the tensor product space $V \otimes W$ as the span[1] of all matrices $v \otimes w$, where $v \in V, w \in W$, with the property $(v \otimes w)_{ij} = v_i w_j$.[2] We can equivalently define it as the vector space with basis elements $e^V_i \otimes e^W_j$, where $e^V_i$ and $e^W_j$ are the basis elements of $V$ and $W$ respectively.
But we can define tensor products not only between vectors, but also between linear maps from one vector space to another (i.e. matrices!):
Given two linear maps (matrices) $A: V \to X$ and $B: W \to Y$, we can define $A \otimes B: V \otimes W \to X \otimes Y$, where each map simply operates on its own vector space without interacting with the other:

$$(A \otimes B)(v \otimes w) = A(v) \otimes B(w)$$
For more information on the tensor product, I recommend this intuitive explanation and the Wikipedia entry.
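As a quick sanity check of this defining property, here is a numpy sketch (the dimensions are made up for illustration): `np.outer` gives $v \otimes w$ as a matrix, and `np.kron` gives $A \otimes B$ acting on the row-major-flattened tensor.

```python
import numpy as np

rng = np.random.default_rng(0)

# v ∈ V = R^3, w ∈ W = R^4; their tensor product as a matrix: (v⊗w)_ij = v_i w_j
v, w = rng.normal(size=3), rng.normal(size=4)
vw = np.outer(v, w)

# Linear maps A: V → X = R^2 and B: W → Y = R^5
A = rng.normal(size=(2, 3))
B = rng.normal(size=(5, 4))

# Check (A⊗B)(v⊗w) = A(v) ⊗ B(w), comparing both sides as 2x5 matrices.
# On row-major-flattened tensors, A⊗B is the Kronecker product np.kron(A, B).
lhs = (np.kron(A, B) @ vw.reshape(-1)).reshape(2, 5)
rhs = np.outer(A @ v, B @ w)
assert np.allclose(lhs, rhs)
```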
How does this connect to the attention-only transformer?
In the "attention-only" formulation of the transformer, we can write the "residual" of a fixed head as $A X W_V W_O$, with the attention matrix $A$, the current embeddings at each position $X$, the value weight matrix $W_V$, and the output weight matrix $W_O$.
Let $E$ be the embedding dimension, $L$ the total context length, and $D$ the dimension of the values. Then:
$X$ is an $L \times E$ matrix,
$A$ is an $L \times L$ matrix,
$W_V$ is an $E \times D$ matrix, and
$W_O$ is a $D \times E$ matrix.
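To keep the shapes straight, here is a minimal numpy sketch; the concrete values of $L$, $E$, $D$ are assumptions for illustration only.

```python
import numpy as np

L, E, D = 8, 16, 4  # context length, embedding dim, value dim (illustrative)
rng = np.random.default_rng(0)

X = rng.normal(size=(L, E))    # embeddings at each position
A = rng.normal(size=(L, L))    # attention pattern
W_V = rng.normal(size=(E, D))  # value weights
W_O = rng.normal(size=(D, E))  # output weights

# The head's residual contribution A X W_V W_O is again an L x E matrix,
# i.e. it lives in the same space as X.
residual = A @ X @ W_V @ W_O
assert residual.shape == (L, E)
```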
Let’s identify the participating vector spaces:
$A$ maps from the "position" space back to the "position" space, which we will call $P$ (and which is isomorphic to $\mathbb{R}^L$). Similarly, we have the "embedding" space $E \cong \mathbb{R}^E$ and the "value" space $V \cong \mathbb{R}^D$.
It might become clear now that we can identify $X$ with an element of $P \otimes E$, i.e. that we can write $X = \sum_{ij} X_{ij} \, (e^P_i \otimes e^E_j)$.
From that lens, we can see that right-multiplying $X$ with $W_V$ is equivalent to applying $\mathrm{Id} \otimes W_V$, which maps an element of $P \otimes E$ to an element of $P \otimes V$ by applying $W_V$ to the $E$-part of the tensor:[3]
$$(\mathrm{Id} \otimes W_V)(X) = (\mathrm{Id} \otimes W_V) \sum_{ij} X_{ij} \, e^P_i \otimes e^E_j = \sum_{ij} X_{ij} \, e^P_i \otimes W_V(e^E_j) = \sum_{ij} X_{ij} \, e^P_i \otimes \sum_k (W_V)_{jk} \, e^V_k = \sum_{ik} \Big( \sum_j X_{ij} (W_V)_{jk} \Big) \, e^P_i \otimes e^V_k = \sum_{ik} (X W_V)_{ik} \, e^P_i \otimes e^V_k = X W_V$$
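We can also check this numerically. Under row-major flattening, $\mathrm{Id} \otimes W_V$ becomes the Kronecker product of the $L \times L$ identity with $W_V^T$; the transpose appears because, as footnote [3] notes, the underlying linear map is right multiplication by $W_V$. A sketch with illustrative dimensions:

```python
import numpy as np

L, E, D = 8, 16, 4  # illustrative dimensions
rng = np.random.default_rng(0)
X = rng.normal(size=(L, E))
W_V = rng.normal(size=(E, D))

# Id ⊗ W_V on the row-major-flattened tensor: identity on the P-part,
# W_V^T on the E-part (right multiplication, viewed as a linear map).
id_tensor_wv = np.kron(np.eye(L), W_V.T)            # shape (L*D, L*E)
lhs = (id_tensor_wv @ X.reshape(-1)).reshape(L, D)  # (Id ⊗ W_V)(X)
rhs = X @ W_V                                       # plain right multiplication
assert np.allclose(lhs, rhs)
```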
Identical arguments hold for $W_O$ and $A$, so that we get the formulation from the paper:

$$A X W_V W_O = (A \otimes W_O W_V) \cdot X$$
Note that there is nothing special here about what these matrices represent. So a takeaway message is that whenever you have a matrix product of the form $ABC$, you can rewrite it as $(A \otimes C) \cdot B$, where the $C$ factor acts by right multiplication (sorry to everyone who thought that was blatantly obvious from the get-go ;P).[4]
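Concretely, for any conformable matrices, flattening $B$ row-major turns $A B C$ into a single matrix-vector product; since the $C$ factor acts by right multiplication, it shows up as $C^T$ in the Kronecker product. A quick sketch with made-up shapes:

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.normal(size=(2, 3))
B = rng.normal(size=(3, 4))
C = rng.normal(size=(4, 5))

# (A ⊗ C) · B, with the C factor acting by right multiplication:
# under row-major flattening this is the Kronecker product kron(A, C^T).
lhs = (np.kron(A, C.T) @ B.reshape(-1)).reshape(2, 5)
rhs = A @ B @ C
assert np.allclose(lhs, rhs)
```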
A previous edition of this post said that it was the space of all such matrices, which is inaccurate: the span of a set of vectors/matrices is the space of all linear combinations of elements of that set.
I'm limiting myself to finite-dimensional spaces because that's what is relevant to the transformer circuits paper. The actual formal definition is more general/stricter, but imo doesn't add much to understanding the application in this paper.
Note that the "linear map" we use here is really right multiplication with $W_V$, so that it maps $e^E_k \mapsto W_V^T e^E_k$.
I should note that this is also mentioned in the paper's introduction to tensor products, but it didn't click with me, whereas going through the above steps did.