Note
I started writing this post about two months ago, though I had been developing the ideas in it for much longer. I didn’t want to publish the article early, unpolished, or with errors. However, the field of AI is moving so quickly that I am literally unable to keep this article updated as I write it. This is why some sections will be incomplete, and errors may be present.
In this article, I present a variety of ideas and hypotheses about Transformers. Usually, I would go ahead and verify them, but I have a dozen ideas and neither the time nor the resources to test them. So, this article is meant to serve as inspiration for others.
What is self-attention, really?
Since ChatGPT was introduced, I have been interested in the Transformer architecture. It is the machine learning model that underlies practically all of the current best LLMs, like ChatGPT, GPT-4, LLaMA, and PaLM 2. As the title of the Transformer paper suggests, “Attention” is the secret ingredient that makes all these models so impressive (or, if I may quote, “unreasonably effective”).
So what is this “Attention”, really? This is both outlined in the paper and explained in many places on the internet. Usually, it is explained by comparing it to a relational database, like MySQL or MS Access. I personally enjoyed this animated explanation.
This analogy gives the intuitive understanding that, for a given token, the self-attention layer extracts relevant information from the context window and imbues the token with this context-specific information. The token’s query vector represents what the current token is “looking for”, the keys of the context’s tokens represent what kind of thing they “offer”, while their values are the “offering” itself.
The self-attention mechanism has proven very powerful and is currently used in various models, including ones that are not Transformers or even language models. Some examples include MMS, SAM, Video LDM, CLIP, and generally many diffusion models inspired by Latent Diffusion.
Still, the question of why the introduction of self-attention to a model significantly increases its capabilities remains largely unanswered. Let’s try to tackle this problem.
A. Attention is… a convex hull
Let’s revisit the self-attention equation: \[ Z=\mathrm{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V \] This form illustrates the bidirectional attention used in the encoder. Now, let’s consider a version of this that calculates the self-attention for a single token only. Recall that this happens in the decoder – because of masking, new tokens do not affect the previous ones, so self-attention is calculated for only one token at a time, as they are generated. Let’s say that the new token is \(\vec{x}\) and its query vector is \(\vec{q}:=\vec{x}W^Q\). For now, I’ll look at only a single attention head. In this case the token’s new value \(\vec{z}\) will be: \[ \vec{z}=\sum_{i}s_i{\left(\vec{q}\cdot\vec{k_i}\right)\vec{v_i}} \] Here, \(\vec{k_i}\) are the key vectors of previous tokens while \(\vec{v_i}\) are their value vectors. \(s_i\) is a scalar that represents the effect of applying \(\mathrm{softmax}\).
Because we are considering only one head, we can rewrite the keys and values as products of the original tokens’ values and the trained conversion matrices: \[ \vec{z}=\sum_{i}s_i{\left(\vec{q}\cdot\vec{t_i}W^K\right)\vec{t_i}W^V} \] Here, \(\vec{t_i}\) are the previous tokens’ vectors, as they came into the self-attention layer, and \(W^K\) and \(W^V\) are the self-attention conversion matrices. Now, let’s look at the attention relevance score calculation \(\vec{q}\cdot\vec{t_i}W^K\) and expand the vector-matrix multiplication: \[ \vec{t_i}W^K=\sum_jt_{i,j}\vec{W_j^K} \] Here, we decompose \(W^K\) into the list of its row vectors \(\vec{W_j^K}\). Here is the summation that happens in the \(\vec{t_i}W^K\) multiplication, visualized:
Now, let’s take a look at the dot product \(\vec{q}\cdot\vec{t_i}W^K\) again: \[ \vec{q}\cdot\vec{t_i}W^K=\vec{q}\cdot\sum_jt_{i,j}\vec{W_j^K} \] We take the dot product between the vector \(\vec{q}\) and the sum of vectors \(t_{i,j}\vec{W_j^K}\). The dot product is distributive (ie. \(\vec{a}\cdot(\vec{b}+\vec{c})=\vec{a}\cdot \vec{b}+\vec{a}\cdot \vec{c}\)), so we may rewrite this as a sum of dot products: \[ \vec{q}\cdot\vec{t_i}W^K=\sum_j\vec{q}\cdot \left(t_{i,j}\vec{W_j^K}\right) \] So our result is a sum of products of \(t_{i,j}\) and \(\vec{W_j^K}\cdot\vec{q}\): \[ \vec{q}\cdot\vec{t_i}W^K=\sum_j\left(\vec{W_j^K}\cdot\vec{q}\right)t_{i,j} \] This looks like another dot product: a sum of element-wise products of \(\vec{t_i}\) (a row vector) and \(W^K\vec{q}^T\) (a column vector):
So, finally, we obtain: \[ \vec{q}\cdot\vec{t_i}W^K=\vec{t_i}\cdot W^K\vec{q}^T \] To get rid of the column vector, we can transpose it \(\left(W^K\vec{q}^T\right)^T=\vec{q}\left({W^K}\right)^T\), as the dot product doesn’t change: \[ \vec{q}\cdot\vec{t_i}W^K=\vec{t_i}\cdot \vec{q}\left({W^K}\right)^T \] As there is only one query vector (the one for the new token), and the trained matrix doesn’t change, we can precompute this matrix-vector product and, substituting \(\vec{v}:=\vec{q}\left({W^K}\right)^T\), get \[ \vec{q}\cdot\vec{t_i}W^K=\vec{t_i}\cdot\vec{v} \] Formally, we just showed the identity \(\vec{a}\cdot\left(\vec{b}M\right)=\vec{b}\cdot\left(\vec{a}M^T\right)\) – a matrix can be moved to the other side of a dot product by transposing it. I couldn’t easily find this written out online, so that’s why I included the derivation here.
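If you would rather check this identity numerically than follow the algebra, a few lines of numpy confirm it (the dimensions are arbitrary):

import numpy as np

d, d_k = 8, 4                    # arbitrary embedding and head dimensions
rng = np.random.default_rng(0)
q = rng.normal(size=d_k)         # query vector of the new token
t = rng.normal(size=d)           # some previous token's vector
W_K = rng.normal(size=(d, d_k))  # key projection matrix

lhs = np.dot(q, t @ W_K)         # q · (t W^K)
rhs = np.dot(t, q @ W_K.T)       # t · (q (W^K)^T)
assert np.isclose(lhs, rhs)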
Missing nonlinearity in query calculation
We can see that \(\vec{v}:=\vec{q}\left({W^K}\right)^T=\vec{x}W^Q\left({W^K}\right)^T\). This is like having our token vector \(\vec{x}\) pass through two linear layers with nothing in-between them.
The intent of multi-headed attention and the introduction of the \(W^Q\) and \(W^K\) matrices was to reduce the dimensionality of the vectors (from \(d\) to \(d_k\)). However, this goes to show that even if the current multi-headed self-attention design is to be maintained, there should probably be a nonlinearity present between the two matrices. Otherwise, we are wasting time and memory on training two linear layers that have no nonlinearity between them.
To solve this issue, the query vectors \(\vec{q_i}\) should pass through a nonlinear transform \(f\), such as ReLU or the more recent SwiGLU. In math, we would compute \(\left(\vec{q_i}, \vec{k_i}\right)=\left(f\left(\vec{x}W^Q\right), \vec{x}W^K\right)\) instead of the current \(\left(\vec{q_i}, \vec{k_i}\right)=\left(\vec{x}W^Q, \vec{x}W^K\right)\). This should allow the queries and keys to capture more complex relationships more easily, just as happens in an FNN when a nonlinearity is introduced between two linear layers. As far as I know, this idea hasn’t been explored before.
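A minimal sketch of the proposed change (the function and variable names are mine; \(f\) could be ReLU, GELU, SwiGLU, etc.):

import numpy as np

def relu(x):
    return np.maximum(x, 0.0)

def queries_keys_current(x, W_Q, W_K):
    # today: two linear maps (W^Q, then later (W^K)^T) with nothing in between
    return x @ W_Q, x @ W_K

def queries_keys_proposed(x, W_Q, W_K):
    # proposed: a nonlinearity f on the query path before the dot product
    return relu(x @ W_Q), x @ W_K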
Now, let’s look at the simplified self-attention equation: \[ \vec{z}=\sum_{i}s_i{\left(\vec{t_i}\cdot\vec{v}\right)\vec{t_i}W^V} \] What is this? Essentially, this is a sum of vector-matrix products, \(s_i\left(\vec{t_i}\cdot\vec{v}\right)\vec{t_i}\) being the vectors and \(W^V\) being a matrix. Just like the dot product, matrix multiplication is distributive (ie. \(\left(\vec{v}+\vec{u}\right)M=\vec{v}M+\vec{u}M\)). This means that we can “factor out” the multiplication with \(W^V\): \[ \vec{z}=\left(\sum_{i}s_i{\left(\vec{t_i}\cdot\vec{v}\right)\vec{t_i}}\right)W^V \] Now, let’s think about what this equation means. Because we can multiply by the value matrix after doing the sum, the bulk of our computation is the sum, so let’s focus our attention on that: \(\sum_{i}{s_i\left(\vec{t_i}\cdot\vec{v}\right)\vec{t_i}}\).
Memory savings
We are about to go deeper and think about the nature of self-attention. But first, let’s consider what we already have. We showed that the bulk of self-attention is completely independent of the query, key, and value matrices. In practice, this means memory savings during inference: we do not need to store the key and value vectors of previous tokens, only the tokens themselves.
Suppose we have \(h\) attention heads, each storing a key and a value for each of \(n\) previous tokens. If a token vector has length \(d\), while the query, key, and value vectors have length \(d_k\), we changed the number of stored floats from \(2hnd_k\) to \(nd\). Importantly, memory usage is now independent of the headedness of the model. As far as I know, this idea hasn’t been explored before.
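As a rough worked example (the numbers are illustrative, roughly LLaMA-7B-like, and assume the common choice \(d=h\,d_k\)): \[ h=32,\quad d_k=128,\quad d=h\,d_k=4096:\qquad 2hnd_k=8192n\ \longrightarrow\ nd=4096n, \] i.e. a factor-of-two saving whenever \(d=h\,d_k\), and a larger one whenever \(2hd_k>d\).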
Now, to get the result, we accumulate a sum over the set of \(\vec{t_i}\) vectors. Let’s think about this. For now, I will temporarily omit the \(\mathrm{softmax}\) step and assume that \(s_i=1\). We’ll reintroduce it later. So, let’s have a look at \(\sum_{i}{\left(\vec{t_i}\cdot\vec{v}\right)\vec{t_i}}\).
How does the value of this sum change, depending on the input \(\vec{v}\)? It can easily be shown that changing the magnitude of \(\vec{v}\) only changes the magnitude of the result – they are proportional. So, it’s the direction of \(\vec{v}\) that bears significance. To try to understand what this sum really is, let’s visualize it.
Let’s start with something simple – a 2D case with, say, 3 tokens:
Ok, so we have 4 movable points in total – 3 “tokens” representing the previous tokens in the context and a “query” that represents \(\vec{v}\).
The relations between the point locations seem rather complex. Moving a token parallel to the axes moves the result along a parabola, but that doesn’t seem useful. Let’s rotate the query and see how the result changes:
Looks a bit like an ellipse. Let’s trace the path of the result and see what we get:
Looks like an ellipse indeed. And it seems that three tokens are too many, as two are enough to get any ellipse. But maybe this is just a coincidence and we’ll be able to get more complex shapes with more points.
Nope, still looks like an ellipse.
Given this observation, we may hypothesize that two points are enough to saturate all degrees of freedom in our 2D case. Extending this to higher dimensions, we can postulate that \(d\) vectors are enough to saturate the self-attention layer.
The query vector moves along a circle in the animation, while the result traces an ellipse. This means that the transformation that happens is linear. In this context, the \(d\) vectors make sense, as they are enough to determine a square \(d\times d\) transformation matrix.
After thinking for a while, we may notice that this scheme resembles a simple neural network. We know that neural networks need nonlinear layers interleaved with the linear ones, or otherwise the network will just “collapse” into a single layer. Similarly, here, adding extra tokens beyond the initial \(d\) gives no further control over the output shape. Remember that we temporarily ignored the \(\mathrm{softmax}\) function. If the analogy is correct, this could be interpreted as omitting the nonlinearity between layers. So, with this intuition, let’s try to interpret our sum \(\sum_{i}{s_i\left(\vec{t_i}\cdot\vec{v}\right)\vec{t_i}}\) as a simple neural network.
It turns out that it’s actually quite simple. We know that the input is a \(d\)-dimensional vector \(\vec{v}\) and that the output is some other \(d\)-dimensional vector. This means that the input and output layers need to have \(d\) neurons each. We also concluded that there should be a nonlinear \(\mathrm{softmax}\), so we’ll also need a hidden layer. This will get us a simple 2-layer FNN.
But what are the weights and biases? And what is the size of the hidden layer? Intuitively, that size should be the context length. After all, this network represents the self-attention mechanism, and we know the context changes over time, so it seems logical to assume that the hidden layer’s size depends on it. What about the parameters? They can be read directly from our equation.
\(\vec{v}\) is the input. \(\sum_{i}{s_i\left(\vec{t_i}\cdot\vec{v}\right)\vec{t_i}}\) is the output. How is the output computed? It is a sum of \(\vec{t_i}\) vectors scaled by coefficients. The coefficients are \(s_i\left(\vec{t_i}\cdot\vec{v}\right)\). So, in our neural network, \(s_i\left(\vec{t_i}\cdot\vec{v}\right)\) are the outputs of the \(\mathrm{softmax}\) layer, and \(\vec{t_i}\cdot\vec{v}\) are the inputs to the \(\mathrm{softmax}\) layer. Now we can easily see that the weights of the first linear layer are simply the \(\vec{t_i}\) vectors arranged in a matrix. Regarding the second linear layer, the i-th component of the output is the weighted sum of the context tokens’ i-th components. So, the weights of the second linear layer are also the \(\vec{t_i}\) vectors arranged in a matrix, but, this time, transposed.
Mathematically, we can express this as: \[ \sum_{i}{s_i\left(\vec{t_i}\cdot\vec{v}\right)\vec{t_i}}=\mathrm{softmax}(\vec{v}T)T^T \] Here, \(T\) is the matrix whose columns are the current context’s token vectors. Let’s see how this simplifies the self-attention equation: \[ \vec{z}=\mathrm{softmax}\left(\vec{v}T\right)T^TW^V \] And, substituting \(\vec{v}:=\vec{q}\left({W^K}\right)^T=\vec{x}W^Q\left({W^K}\right)^T\), we get: \[ \vec{z}=\mathrm{softmax}\left(\vec{x}W^Q\left({W^K}\right)^TT\right)T^TW^V \] We can also visualize this equation, which captures the entirety of the self-attention layer, as a simple FNN:
As noted earlier, there are nonlinear layers missing, which should be even more visible in this diagram. Also, we no longer need to store the context’s keys and values. This decreases memory usage as long as \(2hnd_k>nd\), where \(h\) is the number of self-attention heads, and \(n\) is the current context length. As implementations are often I/O bound, this might actually improve performance by saving memory bandwidth, despite performing more computations.
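As a sanity check, here is a small numpy sketch confirming that this rewritten form matches standard single-head attention for one token (dimensions are arbitrary; the \(1/\sqrt{d_k}\) scaling is omitted, as in the derivation above):

import numpy as np

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

d, d_k, n = 8, 4, 5
rng = np.random.default_rng(1)
x = rng.normal(size=d)            # the new token
tokens = rng.normal(size=(n, d))  # previous tokens, one per row
W_Q = rng.normal(size=(d, d_k))
W_K = rng.normal(size=(d, d_k))
W_V = rng.normal(size=(d, d))

# Standard single-head attention for one token
q = x @ W_Q
K = tokens @ W_K
V = tokens @ W_V
z_standard = softmax(q @ K.T) @ V

# The "FNN form": z = softmax(x W^Q (W^K)^T T) T^T W^V, with tokens as columns of T
T = tokens.T
z_fnn = softmax(x @ W_Q @ W_K.T @ T) @ T.T @ W_V

assert np.allclose(z_standard, z_fnn)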
Dynamic heads
What is cool about this new equation is that we don’t really care about the key and value matrices – we can make them arbitrary. This effectively allows us to perform self-attention for a token with a different attention head than the heads used for prior tokens. Potentially, the attention head could even be dynamic. We could, for example, calculate \(W^K\) and \(W^V\) using yet another neural network, based on \(\vec{q}\). So, instead of training concrete \(W^K\) and \(W^V\) matrices, we would train an NN that creates them from \(\vec{q}\). As far as I know, this idea hasn’t been explored before.
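A minimal sketch of what such a dynamic head could look like; the two-layer hypernetwork and all names here are my own arbitrary choice, not an established design:

import numpy as np

def relu(x):
    return np.maximum(x, 0.0)

def dynamic_head(q, H1, H2, d, d_k):
    """Generate this query's own W^K and W^V from q via a small hypernetwork.

    H1: (len(q), hidden) and H2: (hidden, d*d_k + d*d) are trained matrices."""
    h = relu(q @ H1)                        # hidden representation of the query
    params = h @ H2                         # enough numbers for both matrices
    W_K = params[: d * d_k].reshape(d, d_k)
    W_V = params[d * d_k :].reshape(d, d)
    return W_K, W_V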
Now that we reintroduced the \(\mathrm{softmax}\) nonlinearity, let’s go back to our diagram and see what we get.
Ok, so we got some colorful blobs. As a reminder, this diagram directly illustrates how the self-attention mechanism works. Given a set of tokens in the context (purple points), it shows what the layer does to every possible input. In the illustration, each circle gets blobified into a curve of the same color.
The resultant curves look a bit like splines or Bézier curves. Also, it seems that no matter how hard I try to position the tokens, the resultant curve is always closed, smooth, and not self-intersecting.
Here, circles end up transformed into curves. Extrapolating to higher dimensions, hyperspheres would end up transformed into differentiable manifolds. I note this explicitly, as, if the input vectors were normalized, they’d lie on a unit hypersphere.
I recommend opening this diagram in a larger window – you can zoom in really closely and see that the curve actually has fine details that depend on the placement of the token points. There is also a version that also displays the “density” of points on the curves, which explains why the result vector doesn’t traverse the curve at uniform speed.
To me, this looks like a flexible piece of cloth stretched between strings with beads, where token vectors are “wells” that attract the beads to some position. Here are my takeaways:
- Tokens that lie on the convex hull of the token point set are the most significant, as they define the overall shape of the curve.
- Tokens located inside the hull change the general shape only slightly, but affect the contour lines inside the hull.
- If the tokens are far from the origin, they attract the beads very strongly and the shape closely resembles the convex hull of the token set. Most points end up mapped near the tokens themselves.
It appears that the self-attention transformation warps the embedding space in a smooth and continuous manner. We can take a closer look at this in the following diagram, which showcases what the self-attention transformation does to the coordinate system grid:
We can again see that the space is heavily compressed, as points further than a few units from the origin are all squished into the edges or vertices of the polygon.
In a way, we can consider the tokens as singularities. After all, they are the points where the space collapses upon itself. We can also focus only on the coordinate system axes:
Their ends eventually end up in one of the tokens. This presents us with an interesting problem – given that there are five tokens but only four axis ends, there must be some point at which an end “jumps” between two tokens: a discontinuity. This is interesting, as a small change in the position of a token completely changes which way the space is distorted. This reminds me of a classifier, as it creates a discrete mapping of axis ends to tokens.
For a given arrangement of tokens, the mapping is continuous and the transformation is differentiable with respect to the query. However, the mapping is not differentiable with respect to the tokens.
Recall that there is a residual connection around the self-attention mechanism. What would happen if we added the mapped position to the original? Let’s look at what happens if the \(W^Q\), \(W^K\), and \(W^V\) matrices are all identity transforms (here is a version including gridlines in the range \([-1,1]\), if your computer can handle that):
At first glance, this seems surprisingly organized to me. I mean, this is just matrix multiplication and a \(\mathrm{softmax}\), with arbitrary token values. Yet, still, this complex shape is created.
After observing it for a while, we may see that the diagram features a few areas with distinct characteristics:
- The \((-1, -1), (-1, 1), (1, 1), (1, -1)\) square is transformed into a shape resembling the convex hull of the token point set.
- The region inside that square is warped. Non-hull tokens influence the nature of this warping.
- Far from the hull, the coordinate system’s axes are straight. They look like funnels that attract gridlines towards them. The ends of the axes are “appropriated” by the leftmost, topmost, rightmost, and bottommost tokens respectively (the AABB).
- The edges of the hull seem to project funnels outwards. It looks like the gridlines are “avoiding” them and, hence, stretch quickly to the other side.
- The remaining parts of the coordinate system seem to be left in peace.
I find it impressive that the self-attention layer essentially finds the convex hull of the token point set, as well as the AABB of the hull. This is all while having no conditional statements or control flow constructs, as it is just two matrix multiplications and a \(\mathrm{softmax}\) operation – add, sub, mul, div, exp.
We can think about the “density” of this new space. It looks to me like the hull and its projected funnels are regions of low density – points avoid them and prefer not to end up there, and gridlines are stretched across them. We can confirm this by plotting points instead of lines:
Now, if we vary the influence of the self-attention layer – multiply it by a constant in \((0,1)\) before adding it to the original value – we can clearly see that the embedding space ends up split into five regions:
I hope that this insight into how the self-attention layer “looks” will allow us to draw further conclusions about it or to optimise it, for example by constructing clever data structures that accelerate its computation. The convex hull seems to be of importance here. Furthermore, if all transformations were somehow made “cavity-preserving”, the convex hull would need to be calculated only once.
Multi-headed self-attention with pre-normalization. Notice how many nonlinear layers are missing.
Looking at the entirety of multi-headed self-attention, we can see that it is no more than a regular neural network. Still, transformer blocks are made up of two units: self-attention and a feed-forward network. The FNN is said to capture some intra-token relationships and, being independent from the tokens, potentially add context-independent information. Still, seeing that self-attention is just a neural network, do we really need to append a different neural network to it? Why not just work with what we already have and somehow put that universal context-independent knowledge inside the attention mechanism itself… continue in section B
B. Attention is… a vector database
Recently, it has become popular to imbue large language models with external knowledge by using a so-called “vector database”.
The goal is to make the model be able to factually answer questions and, preferably, cite its sources. This can be useful, for example, for online technical assistants. An LLM can be given access to the entire manual, support forum, changelog, issue database, etc. The assistant will serve as a more user-friendly interface to searching all these resources and it will also be able to synthesize some level of new responses that will, hopefully, solve the user’s problem – or at least, make them not dial the call center.
The problem is that LLMs have a limited context window and, hence, cannot simply be given the entire text of all these sources combined. As such, we can provide it with only a subset of our “knowledge”. So, how do we know what to tell it? This is the job of the vector database.
There exist machine learning models that convert text into a vector of floating point numbers. These are often called embedding models, for the fact that they create text embeddings. What we do is take the combined text of all our sources and embed it. We do this by splitting it into parts (these could be sentences, paragraphs, or pages – often these are intermixed) and then embedding each of the parts separately. Then, we store the (text, embedding) pairs in a database that we will later query. To query the database, we simply embed the user’s prompt and search for embeddings that are “similar” to it.
Many companies are raising millions of dollars to build complex, accelerated vector databases. Still, we can construct a naive and simple database in a few lines of code. We simply have to pick the subset of pairs that are most “similar” to our query:
userPrompt: str
db: list[tuple[str, list[float]]]  # (text, embedding) pairs
...
userEmbedding: list[float] = embed(userPrompt)
results: list[str] = []
for text, embedding in db:
    if similarEnough(embedding, userEmbedding):
        results.append(text)
return smartAI(context=results, prompt=userPrompt)
That would be the general idea. But what is “similarity”? Usually, it is the cosine of the angle between the text’s embedding and the user query’s embedding, as measured in their multidimensional space. As these embeddings are often normalized, calculating the cosine usually amounts to simply computing a dot product of the embeddings.
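For completeness, here is one possible implementation of the similarEnough check used in the snippet above (the threshold value is arbitrary):

import numpy as np

def similarEnough(embedding: list[float], userEmbedding: list[float],
                  threshold: float = 0.8) -> bool:
    a, b = np.asarray(embedding), np.asarray(userEmbedding)
    # Cosine similarity; for already-normalized embeddings this is just a dot product.
    cosine = float(a @ b) / (np.linalg.norm(a) * np.linalg.norm(b))
    return cosine >= threshold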
After thinking for a while about this, I saw a striking resemblance to something else I knew. “Query, dot product, text embedding vector, text string” – this all sounds just like the self-attention layer in the transformer. We have queries – users’ prompts, keys – text embeddings, and values – text strings themselves. The only difference between a vector database and the self-attention layer is apparently the storage format of the data. Here are some takeaways from this analogy:
- Why are the queries and the text strings using the same embedding model? In the transformer, there is a separate \(W^Q\) matrix for calculating the query vectors and a separate \(W^K\) matrix for the keys.
- Why one head? Each string and query has only one embedding. This is like having only a single head in a transformer. Understandably, this is probably caused by the storage requirements – storing separate embeddings for many “heads” would use too much space. Unless… we didn’t have to store them at all.
In the previous section, we saw that the self-attention layer works just fine, if instead of remembering the key and value vector for each head for each token, we remember only the token itself. We can do just that with our embeddings.
The current embedding models are trained to produce these, possibly optimal, query-key vectors. But we can change this design and instead train the model to create good token vectors, alongside a set of matrices that we will use to first transform the embeddings. Given that current embedding vectors are long already, this doesn’t seem to bring any additional computational cost. Well, we would need to multiply the query embedding by a set of query matrices, but that time is negligible compared to having to go through the entire database while querying it. The only component that needs to be changed is the embedding model. The database is good as-is and it is oblivious to the change – it doesn’t care about the “nature” of the vectors it stores.
Going a step further, we may also just get rid of the text itself. Because we already store the texts’ vector representations, we can use a value matrix to turn them into the values that we want to accumulate. That matrix would be trained along with the embedding model, just like the query and key matrices.
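A rough sketch of how such a multi-head database query could look; the matrices are assumed to come from the jointly trained embedding model, and all the names are mine:

import numpy as np

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

def query_database(prompt_embedding, db_embeddings, heads):
    """db_embeddings: (N, d) raw embeddings; heads: list of (W_Q, W_K, W_V) triples."""
    outputs = []
    for W_Q, W_K, W_V in heads:
        scores = (prompt_embedding @ W_Q) @ (db_embeddings @ W_K).T
        weights = softmax(scores)
        outputs.append(weights @ (db_embeddings @ W_V))  # accumulated "values"
    return np.concatenate(outputs)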
Ok, so now we can rest as we have freed ourselves from variable-length strings and operate only on vectors. But, wait… why do we need an external vector database at all? Given that we just replicated the self-attention mechanism, albeit with some additional steps, like training a new embedding model, why not just go ahead and use self-attention instead?
Traditionally, the self-attention layer operates on the context given to the transformer – hence self-attention. But what if we changed this and introduced a new attention layer – database-decoder cross-attention?
Currently, LLMs are trained on vast amounts of data and pick up knowledge along the way. Where it is stored, we don’t know (somewhere in the weights, duh). So why not create an explicit container for the model’s knowledge – an internal database? We would first train the model on general text, as usual, but then we would go ahead and teach it specific subjects. (Typically, this would be called fine-tuning, but that name suggests that we are taking a complete model and changing it, while in this case we are simply splitting the training process into steps.)
The model would have a database cross-attention layer in each of its decoders. We would append zeros at the end of the database, and train the model on a new subject, while keeping the rest of the parameters frozen. This would mean that whatever it learns must be located in the new scratch space, we provided it.
What about stability? Fine-tuning a model often decreases its ability to generalise and requires additional restorative tuning. Well, it is possible that this would also happen in this case. However, theoretically, there is a reason why it would not. Namely – usual fine-tuning appends new layers to a model and trains them. This means that everything the model produces gets processed by these new layers. As such, it’s easy for the model to become worse, as all its outputs are garbled by the new layers. Meanwhile, this approach of appending data to a database used in a cross-attention layer does not interfere with the model’s previous functioning. As the data is zero-initialized, it has no effect on the output of the model. When we train the model, we only add to its pool of knowledge, so theoretically it shouldn’t forget anything it already knew or lose any abilities, as, fundamentally, we are not modifying the model at all. Instead, we are adding a sort of plugin, or mix-in, to it.
TOME
Since writing this, I have learned that a similar idea has already been explored in TOME (de Jong et al., 2022). However, there are some differences. The process the paper outlines is roughly:
- Create a set of mentions. A mention is a certain type of text string that mentions named entities, their properties and relations between them.
- Train a transformer encoder (E) that will create a key and value vector for each mention. These keys and values form the memory (M).
- Train the main transformer with M-cross-attention. M contents are frozen.
- Add new knowledge to the transformer by encoding it with E and inserting the new keys and values into M.
Meanwhile, I propose:
- Create a fixed-sized memory (M) – a list of key and value vector pairs. Initialize it to random values.
- Train the main transformer with M-cross-attention. M’s contents are trained together with the transformer’s parameters.
- Add new knowledge to the transformer by adding some number of rows to M and training them. The transformer and the previous M’s contents are frozen.
Let’s compare the two. My approach does not require gathering any new data – the memory is trained on the same text corpus that the transformer is trained on – while the paper requires explicitly creating a set of mentions. This means that my approach only requires changing the implementation of the transformer, with the training process itself unchanged, whereas the paper needs changes to both. Using all text as the source of knowledge can allow the transformer to capture more information inside the memory. However, learning from mentions can be more efficient and possibly more effective, as mentions can be more information-dense than general text. Generally, the paper outlines a more “strict” learning paradigm than I do – in my approach, the transformer can learn arbitrary information in an arbitrary format.
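To make my proposal concrete, here is a rough PyTorch-style sketch of the memory cross-attention layer and the “freeze and extend” step (a structural illustration under my own assumptions, not TOME’s implementation):

import torch
import torch.nn as nn

class MemoryCrossAttention(nn.Module):
    """Cross-attention into a learned, extendable memory of (key, value) rows."""

    def __init__(self, d_model: int, n_rows: int):
        super().__init__()
        # Frozen part of the memory (empty at first) and the trainable part.
        self.register_buffer("frozen_k", torch.empty(0, d_model))
        self.register_buffer("frozen_v", torch.empty(0, d_model))
        self.new_k = nn.Parameter(torch.randn(n_rows, d_model) * 0.02)
        self.new_v = nn.Parameter(torch.randn(n_rows, d_model) * 0.02)
        self.W_q = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x):                       # x: (batch, seq, d_model)
        k = torch.cat([self.frozen_k, self.new_k])
        v = torch.cat([self.frozen_v, self.new_v])
        scores = self.W_q(x) @ k.T / k.shape[-1] ** 0.5
        return torch.softmax(scores, dim=-1) @ v

    def extend(self, n_new: int):
        """Freeze what the memory already holds and append fresh rows for new knowledge."""
        self.frozen_k = torch.cat([self.frozen_k, self.new_k.detach()])
        self.frozen_v = torch.cat([self.frozen_v, self.new_v.detach()])
        d = self.frozen_k.shape[1]
        # Zero-initialized, as suggested above, so the new rows barely perturb the outputs.
        self.new_k = nn.Parameter(torch.zeros(n_new, d))
        self.new_v = nn.Parameter(torch.zeros(n_new, d))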
Going further, we can ask whether we really need a separate cross-attention layer. Maybe we could query both our internal database and the context in the same attention layer. By combining the two, there would be no distinction between the model’s prior knowledge and its working context. Going further still, we could remove the FNNs and leave only the attention layers. After all, any necessary information can just be saved in our database.
C. Multi-decoder
Originally, if ChatGPT was asked for the sum of the squares of the first 20 primes, it would just make up some number. Now, it describes, step by step, how to calculate the result and gets it right. This is most likely the result of it having been fine-tuned by OpenAI on chain-of-thought examples.
The idea behind “let’s think step by step”, as well as earlier prompt-engineering guidance, is that an LLM needs more “space” when dealing with a “harder” task. Intuitively, the amount of computation that happens when a token is generated is constant, regardless of the prompt or the new token. We can treat this amount of computation as time that the model has to “think”. This is because all the LLM’s generation logic happens during, well, token generation. As such, when we ask an LLM a question like P=NP? [Y/N], we do two things – we give it a hard problem, and we give it little time, by forcing it to answer with only Yes or No. We can expect poor performance, as the LLM simply doesn’t have enough time to figure out an answer to our question. On the other hand, if we tell the model to “think step by step”, we suggest that it produce a longer output. This essentially means giving the model more time. Using this knowledge, we can estimate a task’s difficulty: \[ d=\frac{c}{\tau} \] The \(d\)ifficulty of a task is its \(c\)omplexity per unit \(\tau\)ime (solving a complex problem in little time is difficult – it is easier when more time is available or when the task is simpler).
The performance of a model depends on the type of task given. Some LLMs are better at solving particular types of tasks than others, but models’ general capability can also be compared. The performance of a model at a particular task can be roughly modeled as: \[ p=\frac{Cf(t)}{d} \] The performance is higher for more \(C\)apable models and for models better \(f\)it for the particular \(t\)ype of task at hand. At the same time, it decreases as the \(d\)ifficulty of the task gets higher.
By leveraging fine-tuning, we can teach an existing LLM to produce outputs that resemble a step-by-step reasoning process. Still, it is us controlling the “thought process” of the model. This may cause us to teach the model thought processes of inadequate length – different LLMs with different capabilities may require different lengths. Additionally, the model has to reason like a human – fine-tuning limits it to “thinking” only in words. Preferably, the model should be able to “think” in a “neural format” – arbitrary vector representations that need not map to textual tokens.
The key limitation is that LLMs have no real “scratch space”, as all their decoded tokens are included in the output. This means that their “thinking time” is directly bound to the length of their output. One option to decouple the two is to introduce scoping tokens, like <thought> and </thought>. LLMs are proficient at using scopes, both in code and in regular language (direct speech). The purpose of these scoping tokens would be simple – anything between them is not decoded and does not affect the loss. The problem with this is that there is no obvious way of making the model generate these tokens at all, or of limiting the model’s “thinking time”. Having the model generate tokens forever would certainly be an undesirable quality.
I suggest a mechanism designed to sidestep these problems, which allows the model to reason for as long as it sees fit. Additionally, it requires no explicit training, as it is an extension of the Transformer rather than a novel training method.
The original Transformer consisted of an encoder and a decoder. Since then, other models have been proposed that include only one of these elements. Notably, all leading models, like the GPTs, LLaMAs, and PaLMs, are decoder-only. This shows that the encoder-decoder cross-attention mechanism is not needed for achieving top performance.
Seeing these “reduced Transformers”, I naturally wondered if there are any models that do the reverse – ones that add additional encoders or decoders to the Transformer. I am not aware of any, and hence propose how one could look.
The key idea is to have the Transformer generate every single token in an autoregressive manner itself. This will allow the model to vary its “thinking time”. This is achieved by introducing an intermediary decoder.
The model consists of a masked encoder and two autoregressive decoders. The encoder (and the decoders) uses causal attention masking – ie. a token cannot attend to future tokens. In principle, this makes each step of the autoregression process self-contained, just like in a decoder-only model.
The first decoder generates an intermediary token sequence. Its format is left opaque – the model can generate token vectors as it sees fit. The produced tokens could be given to the decoder as-is or additionally positionally encoded. The decoder is first given some BOS (beginning-of-sequence) token. As this decoder does not have direct access to the prompt, it has to use encoder-decoder cross-attention to relate to it. This design allows the Transformer to “think” for as long as it needs. The autoregression stops when a certain criterion is met – this will be discussed shortly.
When the first decoder is done generating, the second decoder begins its work. This time, it has access to the input prompt, like in a decoder-only model. Additionally, it has cross-attention layers that imbue it with the “thoughts” generated by the first decoder. This decoder produces the next output token, just as usual. When it completes, the token is decoded, reencoded and the entire process autoregresses. The keys and values of prompt tokens can be (and should be) cached inside the encoder and the second decoder, as well as in their embeddings. The middle decoder can potentially use caching. If it does, it will have access to all previous “thoughts”, which could potentially save it from doing some redundant work. On the other hand, this would be more computationally expensive, and would make the autoregression step no longer self-contained – a token’s generation would rely on previous “reasoning”.
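A rough pseudocode sketch of one autoregression step of the proposed model (all component names are mine, and the stop criterion is discussed below):

def generate_next_token(prompt_tokens, encoder, thought_decoder, output_decoder,
                        bos_thought, stop_criterion):
    """One autoregression step of the proposed encoder + two-decoder model."""
    # 1. Encode the prompt with causal masking (keys/values can be cached).
    encoded = encoder(prompt_tokens)

    # 2. The intermediary decoder "thinks" in opaque token vectors, relating
    #    to the prompt only through encoder-decoder cross-attention.
    thoughts = [bos_thought]
    while not stop_criterion(thoughts):          # e.g. the "life" mechanism below
        thoughts.append(thought_decoder(thoughts, cross=encoded))

    # 3. The output decoder sees the prompt directly (decoder-only style) and the
    #    thoughts through cross-attention, and emits the next output token.
    return output_decoder(prompt_tokens, cross=thoughts)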
How could this be implemented? The first problem to solve is determining the stop criterion. Because this is an internal decoder, we don’t want to teach the model to produce a special EOS token. At the same time, we would like to be able to put an upper bound on the amount of time the model can spend. To achieve this, I propose introducing a life value. It would be a number in \([0, 1]\) representing how “done” the model is – when life reaches zero, the model has concluded “thinking”. After generating each token, this value would be decreased by some amount: life -= f(token). This allows the model to ask for more time or to finish early, by producing an appropriate type of token. What happens if life < 0? This is the interesting part.
Usually, creating such a transformer would come with a significant problem – it would be non-differentiable. If a model with parameters \(p\) produced \(n\) tokens in the intermediary layer given input \(i\), how many would it produce if the values of \(p\) were to change slightly? Normally, we couldn’t tell, because \(n(p, i)\) is not continuous. As a model can produce only an integer number of tokens, this function must necessarily be non-differentiable with respect to \(p\). But what if the model could produce “half” a token?
This is the other function of life. If it becomes negative after a token is produced, we know that the model “went too far”. To account for this, the model needs to “unproduce” a part of the token, so that life is capped at zero. This requires “cutting a token in half”. Fortunately, these tokens are only vectors, so, naturally, we can just decrease the magnitude of the token vector by a certain amount. This will decrease its respective influence in the cross-attention layer of the second decoder:
life: float
token: vector

dl = f(token)                    # how much "life" this token costs
nl = life - dl
credit = max(0, -nl)             # how far below zero we went
influence = 1 - (credit / dl)    # the fraction of the token that still "fits"
token *= influence               # scale the token down by that fraction
addToContext(token)
life = nl
if life <= 0:
    break
This allows the LLM to control its own execution (via the break statement), while retaining its differentiability.
All model parameters would be randomly initialized as usual. As long as f(token) is made to always return positive values, the model is guaranteed to stop working at some point. This can be controlled by changing the implementation of f. Teaching the model works as usual – the loss, computed as the difference between the model’s output and the desired one, is backpropagated. This alone can cause the model to work ineffectively, as it doesn’t take computation time into account. As such, I suggest including the length (number of tokens; 3.3 above) of the intermediary sequence as part of the loss.
D. Attention is… multiplication
One of the prevalent questions surrounding transformers is “why are they so good?” Usually, the self-attention mechanism is provided as the explanation. But what is so special about self-attention?
We saw that, fundamentally, self-attention is just an FNN. So what’s all the fuss about? Is it the \(\mathrm{softmax}\) activation function? I think the crucial component that self-attention brings is multiplication.
A neural network is a series of transformations that are applied to some input data. Usually, it is a chain of linear layers intermixed with nonlinear ones. Yes, models can be more complex, like an autoregressive transformer – it has a loop – or a diffusion model – it adds noise at each iteration. But still, if we limit ourselves to only linear and nonlinear layers, we miss a crucial component.
Typical neural networks never multiply two inputs together. If we look at the path a single input float traverses, we’ll see that it gets scaled by pretrained weights, has pretrained biases added to it, has expressions of other inputs added to it, and is transformed by nonlinear functions. But never are two input floats multiplied together.
Even if we look at RNNs, we can see that the hidden representation as well as the input only get multiplied by pretrained weights. Only LSTMs introduce self-multiplication in an indirect way, as expressions dependent on the cell’s current and previous states get multiplied. This happens when going through the various gates, as they are dependent on the cell state. Although LSTMs are significantly less advanced than Transformers, this introduction of self-multiplication could be a possible explanation of their improved performance over similarly-sized RNNs.
Perhaps it would be beneficial to test Productional layers, instead of only Linear ones. There are many ways to design a layer that involves products of its inputs. Still, I think that adding such a layer could potentially result in performance superior to purely linear ones.
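One of many possible designs for such a layer, sketched in numpy; the pairwise-product form below is my own arbitrary choice:

import numpy as np

class Productional:
    """A layer whose output features are built from products of pairs of its inputs."""

    def __init__(self, d_in: int, d_out: int, seed: int = 0):
        # One trainable weight per (input i, input j, output) triple.
        rng = np.random.default_rng(seed)
        self.W = rng.normal(size=(d_in, d_in, d_out)) / d_in

    def __call__(self, x: np.ndarray) -> np.ndarray:
        pairwise = np.outer(x, x)               # every product x_i * x_j
        return np.einsum("ij,ijo->o", pairwise, self.W)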
Attention is… protein folding
Protein folding is a task that involves predicting the final locations of a set of particles, given their initial locations. The difficulty stems from the fact that the positions of the particles change the magnitude of the electrostatic forces between them. If we think of this more abstractly, a given set of particles is defined only by their positions, and is associated with a well-defined set of couplings between the particles that is a direct product of those positions. We could think of self-attention in a similar manner. Tokens could be considered as particles, while their relative compatibilities could be interpreted as measures of some interactions between them. In this case, the token system passing through the successive attention layers could be seen as a set of particles undergoing folding over time. A critical consequence of this is that the system would have to either “converge” to a stable state or enter an oscillating pattern. In practice, this would mean that Transformers have scalability limitations – adding more layers results in more accurate predictions, but with diminishing returns.
Attention is… a knowledge graph
Consider a single attention head. It assigns a certain “compatibility value” to each pair of tokens – the dot product between their respective query and key. We can treat the tokens in the context as nodes of a clique, and these compatibilities as weights of edges between the nodes. If we look at multiple heads, we can interpret the token graph as a multilayer graph. Each head yields a different set of edges – a layer of the graph. This can be compared to a knowledge graph, in which the tokens model entities, while the heads capture relationships between them. Maybe attention proves so effective because it allows the model to understand its input in the form of a graph?
Perplexity
Beam width is a hyperparameter of the autoregressive token generation process. As such, its value is determined mostly by trial and error, trading off either performance or generation quality. Could it be possible to find an optimal value for beam width – one beyond which larger widths would not yield better results? My idea is to set it to the model’s perplexity. That is because perplexity is, fundamentally, a measure of the model’s uncertainty in generating its predictions. When a model has a perplexity of, say, 15, this means that when the model generated a new token, it was on average as unsure as if it had to pick between 15 options. Perplexity can be treated as a weighted branching factor. To me, this seems similar to beam width. After all, the beam width is the number of best predictions that are being pursued in parallel. So, intuitively, making it equal to the perplexity should explore all of the model’s plausible outputs.
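In practice, this could be as simple as the following sketch (perplexity would be measured on some held-out text, and the rounding is my choice):

import math

def beam_width_from_perplexity(mean_nll: float) -> int:
    """Perplexity is the exponential of the mean negative log-likelihood per token."""
    perplexity = math.exp(mean_nll)
    # Treat the model's "weighted branching factor" as the number of beams to keep.
    return max(1, round(perplexity))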
Teach – not train
How do (human) children learn to read? If you handed The Works of Shakespeare Combined to a kid, would you expect them to learn English from it? I wouldn’t. Yet, this is how we treat LLMs. LLMs are trained on trillions of tokens, orders of magnitude more than what it takes humans to learn linguistic capabilities. Admittedly, the training process teaches the models not only to communicate effectively, but also gives them knowledge and other capabilities. Yet, still, this seems highly inefficient compared to how humans learn. Usually, this is attributed either to an imperfect architecture or to limited model size. However, to me it seems that the problem lies in the training process, rather than in the model itself.
If artificial neural networks are to mimic their real counterparts in any way, their learning patterns should be similar as well. As such, it seems understandable to me that a model has to see billions of tokens in order to learn proper grammar – the learning material it is being given is complete chaos. It is no coincidence that people talk to children in a “childish” manner. This seems to have been designed by evolution to facilitate the language learning process for the kids. In general, optimal learning performance will be achieved only when the learning material’s difficulty is adequate to the learner’s skill.
This is why training models is so inefficient. Let’s say that we train a model to distinguish species of fish. This seems a bit awkward, as the model is suddenly shown pictures of some “things”, while it doesn’t even know “what a fish is”. In this context, fine-tuning a foundation model can be treated as simply teaching the model new skills. Recent research often suggests that large models learn the majority of their capabilities during pretraining, and fine-tuning only allows these capabilities to surface. This is a plausible explanation, but I also see a different one – pretraining teaches the model how to learn. This doesn’t happen directly. Pretraining makes the model a more skilled student that is able to learn harder tasks quicker, based on its previous knowledge. And the new task doesn’t have to be reflected in the pretraining process.
As such, it could be worth exploring how models would behave if they were taught, rather than trained. This would affect two aspects of the learning process – the dataset and the optimizer. The dataset would have to be tiered, from easiest to hardest examples. It would also preferably be dynamic – if a model struggles (has high loss on the validation set from a certain tier), it would be trained on that tier for additional epochs. In addition to this difficulty gradation, other pedagogical techniques could be used, for example repeating earlier examples to reinforce older, potentially partially forgotten, abilities.
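A rough sketch of such a tiered, dynamic curriculum; the thresholds and the epoch structure are placeholders:

def teach(model, tiers, val_sets, train_one_epoch, validation_loss,
          max_epochs_per_tier=10, acceptable_loss=1.0):
    """Train on tiers from easiest to hardest, lingering where the model struggles."""
    for tier, val_set in zip(tiers, val_sets):
        for _ in range(max_epochs_per_tier):
            train_one_epoch(model, tier)
            if validation_loss(model, val_set) <= acceptable_loss:
                break  # the model has "passed" this tier – move on to a harder one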
The other component to be affected is the optimizer. If it is to mimic human learning, it should begin with a very large learning rate and then gradually decrease it. This can sound contrary to the usual advice on training models – wouldn’t this cause overfitting? Well, overfitting happens when a model continuously relearns its training data. Here, however, after the model learns the simpler examples, it never returns to them – learning them is only a part of the process. In other words, I propose to overfit the model on purpose, as the learned data will get overwritten anyway. This is a major paradigm shift in the comprehension of the learning process. Currently, a learning model has only one goal – minimize the overall loss. I propose changing this goal to: facilitate further learning. This can seem counterintuitive – what is the point of teaching a model something, only to later overwrite it? I say that it is to make learning faster, easier, and more efficient. This overfitting can be treated not as training, but as a way of initializing the model’s weights. Instead of random values, we use weights associated with an easier problem.
Continuous attention
What would happen if the context contained literally all possible vectors? After all, as the context grows larger and larger, it will contain increasingly many vectors. Well, the answer would always be \(0\), as they would cancel each other out. So no – the context cannot contain all vectors. More generally, it cannot contain vectors that are distributed “homogeneously” in space. Speaking of homogeneity, this makes me wonder about the nature of the self-attention context. Currently, it is discrete – it is a set of vectors. However, intuitively, it seems that what we actually care about is the distribution of the vectors. If we think of the set of vectors as a point cloud, we don’t care about the “particles of the cloud”. From a distance, it is, well, a cloud – we care about its density. The current solution of storing a set of vectors is simply a discretized version of this.
So let’s reevaluate what our inputs are. We have a density function \(\rho(\vec{t}):\mathbb{R}^{d_k}\to(0,1)\) that can be roughly imagined as our set of points after having gone through a Photoshop blurring filter. And we have a single vector \(\vec{v}\) – from the single attention head under consideration.
I think we can reasonably redefine the attention mechanism using this analogy and make it continuous: \[ \vec{z}=\left(\int_{\mathbb{R}^{d_k}}\left(\vec{t}\cdot\vec{v}\right)\vec{t}\,\rho(\vec{t})\,d^{d_k}\vec{t}\right)W^V \]
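To see that this is a sensible generalization, note that a Monte Carlo estimate of this integral – sampling tokens from \(\rho\) – recovers, up to a \(1/n\) factor and the omitted softmax weighting, the discrete sum we started from. A rough numpy sketch, with \(\rho\) chosen as a Gaussian mixture purely for illustration:

import numpy as np

def continuous_attention(v, W_V, sample_rho, n_samples=10_000):
    """Monte Carlo estimate of z = (E_{t~rho}[(t·v) t]) W^V."""
    t = sample_rho(n_samples)          # (n_samples, d_k) tokens drawn from rho
    weighted = (t @ v)[:, None] * t    # (t·v) t for every sampled token
    return weighted.mean(axis=0) @ W_V

# rho as a small Gaussian mixture around a few "token centres" (purely illustrative)
rng = np.random.default_rng(2)
centres = rng.normal(size=(3, 4))

def sample_rho(n):
    idx = rng.integers(len(centres), size=n)
    return centres[idx] + 0.1 * rng.normal(size=(n, centres.shape[1]))

z = continuous_attention(rng.normal(size=4), rng.normal(size=(4, 4)), sample_rho)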