You could have designed state of the art positional encoding
Gall's Law
A complex system that works is invariably found to have evolved from a simple system that worked
John Gall
This post walks you through the step-by-step discovery of state-of-the-art positional encoding in transformer models. We will achieve this by iteratively improving our approach to encoding position, arriving at Rotary Positional Encoding (RoPE) used in the latest Llama 3.2 release and most modern transformers. This post intends to limit the mathematical knowledge required to follow along, but some basic linear algebra, trigonometry and understanding of self attention is expected.
Problem Statement
You shall know a word by the company it keeps
John Rupert Firth
As with any problem, it is best to start by understanding exactly what we are trying to achieve. The self attention mechanism in transformers is used to understand relationships between tokens in a sequence. Self attention is a set operation, which means it is permutation equivariant: permuting the input tokens simply permutes the outputs in the same way. If we do not enrich self attention with positional information, many important relationships cannot be determined.
This is best demonstrated by example.
Motivating Example
Consider this sentence with the same word in different positions: "The dog chased another dog".
Intuitively, "dog" refers to two different entities. Let's see what happens if we first tokenize the sentence, map the tokens to the real token embeddings of Llama 3.2 1B and pass them through torch.nn.MultiheadAttention.
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModel

model_id = "meta-llama/Llama-3.2-1B"
tok = AutoTokenizer.from_pretrained(model_id)
model = AutoModel.from_pretrained(model_id)

# Tokenize the sentence and look up the (position-free) token embeddings
text = "The dog chased another dog"
tokens = tok(text, return_tensors="pt")["input_ids"]
embeddings = model.embed_tokens(tokens)
hdim = embeddings.shape[-1]

# Randomly initialized projections and attention layer, with no positional information
W_q = nn.Linear(hdim, hdim, bias=False)
W_k = nn.Linear(hdim, hdim, bias=False)
W_v = nn.Linear(hdim, hdim, bias=False)
mha = nn.MultiheadAttention(embed_dim=hdim, num_heads=4, batch_first=True)

with torch.no_grad():
    for param in mha.parameters():
        nn.init.normal_(param, std=0.1)  # Initialize weights to be non-negligible

output, _ = mha(W_q(embeddings), W_k(embeddings), W_v(embeddings))

# With the begin-of-text token at index 0, the two "dog" tokens sit at indices 2 and 5
dog1_out = output[0, 2]
dog2_out = output[0, 5]
print(f"Dog output identical?: {torch.allclose(dog1_out, dog2_out, atol=1e-6)}")  # True
As we can see, without any positional information, the output of a (multi headed) self attention operation is identical for the same token in different positions, despite the tokens clearly representing distinct entities. Let's begin designing a method of enhancing self attention with positional information, such that it can determine relationships between words encoded by their positions.
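The same effect can be reproduced without downloading any weights. The following is a minimal sketch in dependency-free Python, assuming a single attention head with identity projections and made-up 2-dimensional "embeddings" (none of these values come from Llama). It checks two things: the two "dog" tokens receive the same output, and shuffling the sentence only moves that output to a new index.

```python
import math

def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def self_attention(tokens):
    """Single-head attention with identity projections: softmax(X X^T / sqrt(d)) X."""
    d = len(tokens[0])
    out = []
    for q in tokens:
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in tokens]
        weights = softmax(scores)
        out.append([sum(w * v[j] for w, v in zip(weights, tokens)) for j in range(d)])
    return out

# Hypothetical toy embeddings; "dog" appears at index 1 and 4 of the sentence.
the, dog, chased, another = [0.1, 0.3], [0.9, -0.2], [-0.4, 0.5], [0.2, 0.7]
sentence = [the, dog, chased, another, dog]
shuffled = [dog, another, the, dog, chased]

out = self_attention(sentence)
out_shuffled = self_attention(shuffled)

def close(u, v):
    return all(math.isclose(a, b, abs_tol=1e-12) for a, b in zip(u, v))

same_dog = close(out[1], out[4])            # both "dog" tokens get the same output
moved_dog = close(out[1], out_shuffled[0])  # shuffling only relocates that output
print(same_dog, moved_dog)
```

Nothing in the computation refers to an index, which is why no amount of training could let this layer tell the two dogs apart.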
How should an ideal positional encoding scheme behave?
Desirable Properties
Let's try and define some desirable properties that will make the optimization process as easy as possible.
Property 1 - Unique encoding for each position (across sequences)
Each position needs a unique encoding that remains consistent regardless of sequence length - a token at position 5 should have the same encoding whether the current sequence is of length 10 or 10,000.
Property 2 - Linear relation between two encoded positions
The relationship between positions should be mathematically simple. If we know the encoding for position $p$, it should be straightforward to compute the encoding for position $p + k$, making it easier for the model to learn positional patterns.
If you think about how we represent numbers on a number line, it's easy to understand that 5 is 2 steps away from 3, or that 10 is 5 steps from 15. The same intuitive relationship should exist in our encodings.
Property 3 - Generalizes to longer sequences than those encountered in training
To increase our models' utility in the real world, they should generalize outside their training distribution. Therefore, our encoding scheme needs to be adaptable enough to handle unexpected input lengths, without violating any of our other desirable properties.
Property 4 - Generated by a deterministic process the model can learn
It would be ideal if our positional encodings could be drawn from a deterministic process. This should allow the model to learn the mechanism behind our encoding scheme efficiently.
Property 5 - Extensible to multiple dimensions
With multimodal models becoming the norm, it is crucial that our positional encoding scheme can naturally extend from $1D$ to $nD$. This will allow models to consume data like images or brain scans, which are $2D$ and $4D$ respectively.
Now that we know the ideal properties (henceforth referred to as $Pr_n$, where $n$ is the property number), let's start designing and iterating on our encoding scheme.
Integer Position Encoding
The first approach that may jump to mind is simply to add the integer value of the token position to each component of the token embedding, with values ranging from $0$ to $L - 1$, where $L$ is the length of our current sequence.
In the above animation, we create our positional encoding vector for the token from the index and add it to our token embedding. The embedding values here are a subset of the real values from Llama 3.2 1B. We can observe that they're clustered around 0. This is desirable to avoid vanishing or exploding gradients during training and therefore is something we'd like to maintain throughout the model.
It's clear that our current naïve approach is going to cause problems. The magnitude of the position value vastly exceeds the actual values of our input. This means the signal-to-noise ratio is very low, and it's hard for the model to separate the semantic information from the positional information.
With this new knowledge, a natural follow-on might be to normalize the position value by $\frac{1}{N}$. This constrains the values between 0 and 1, but introduces another problem. If we choose $N$ to be the length of the current sequence, then the position values will be completely different for each sequence of differing lengths, violating Property 1 ($Pr_1$).
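Both failure modes are easy to see with a few numbers. This is a small sketch, assuming a made-up 4-component embedding with values near zero; the helper names are hypothetical and only serve the illustration.

```python
def integer_encoding(position, dim):
    """Naive scheme: the raw position index repeated on every component."""
    return [float(position)] * dim

def normalized_encoding(position, seq_len, dim):
    """Position divided by the current sequence length, so values stay in [0, 1)."""
    return [position / seq_len] * dim

embedding = [0.02, -0.01, 0.03, 0.00]  # made-up values clustered around 0

# Problem 1: the position value swamps the semantic signal.
with_integer = [e + p for e, p in zip(embedding, integer_encoding(252, len(embedding)))]
print(with_integer)

# Problem 2: after normalizing, position 5 depends on the sequence length.
short = normalized_encoding(5, seq_len=10, dim=4)
long = normalized_encoding(5, seq_len=10_000, dim=4)
print(short[0], long[0])  # 0.5 0.0005
```

The same position maps to 0.5 in one sequence and 0.0005 in the other, which is exactly the inconsistency that the first property rules out.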
Is there a better way to ensure our numbers are between 0 and 1? If we thought really hard about this for a while, we might come up with switching from decimal to binary numbers.
Binary Position Encoding
Instead of adding our (potentially normalized) integer position to each component of the embedding, we could instead convert it into its binary representation and s t r e t c h our value out to match our embedding dimension, as demonstrated below.
We've converted the position of interest (252) into its binary representation (11111100) and added each bit to the corresponding component of the token embedding. The least significant bit (LSB) will flip between 0 and 1 for every subsequent token, whilst the most significant bit (MSB) will flip only every $2^{n-1}$ tokens, where $n$ is the number of bits. You can see the positional encoding vector for different indices in the animation below [^1].
We've solved the value range problem, and we now have unique encodings that are consistent across different sequence lengths. What happens if we plot a low dimensional version of our token embedding and visualize the addition of our binary positional vector for different values?
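The bit pattern described above can be sketched in a couple of lines. The function name `binary_encoding` is an illustrative choice, and the bits are returned most significant first to match the way 252 is written out in the text.

```python
def binary_encoding(position, n_bits):
    """Return the bits of `position`, most significant bit first."""
    return [(position >> shift) & 1 for shift in range(n_bits - 1, -1, -1)]

print(binary_encoding(252, 8))  # [1, 1, 1, 1, 1, 1, 0, 0]

# The last entry (LSB) flips on every step, the first entry (MSB) far less often.
for position in range(4):
    print(position, binary_encoding(position, 3))
```

Each component only ever takes the value 0 or 1, so the value range is bounded no matter how long the sequence is.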
We can see that the result is very "jumpy" (as we might expect from the discrete nature of binary). The optimization process likes smooth, continuous and predictable changes. Do we know any functions with similar value ranges that are smooth and continuous?
If we looked around a little, we might notice that both $\sin$ and $\cos$ fit the bill!
Sinusoidal positional encoding
The above animation visualizes our position embedding if each component is alternately drawn from $\sin$ and $\cos$ with gradually increasing wavelengths. If you compare it with the previous animation, you'll notice a striking similarity!
We've now arrived at Sinusoidal embeddings, originally defined in the Attention is all you need paper. Let's look at the equations:
$$PE_{(p, 2i)} = \sin\left(\frac{p}{10000^{2i/d}}\right)$$
$$PE_{(p, 2i+1)} = \cos\left(\frac{p}{10000^{2i/d}}\right)$$
where $p$ is the token's position index, $i$ indexes the sine/cosine pair (so $2i$ and $2i+1$ are component indices in the positional encoding vector), and $d$ is the model dimension. $10000$ is the base wavelength (henceforth referred to as $\theta$), which we stretch or compress as a function of the component index. I encourage you to plug in some realistic values to get a feel for this geometric progression.
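To get a feel for the progression, here is a minimal sketch in plain Python. It assumes the standard formulation, $\sin(p / 10000^{2i/d})$ on even components and the matching $\cos$ on odd ones; the function names and the tiny dimension `d = 8` are choices made purely for readability.

```python
import math

def wavelengths(d, base=10_000.0):
    """Wavelength (in tokens) of each sin/cos pair: 2*pi * base^(2i/d)."""
    return [2 * math.pi * base ** (2 * i / d) for i in range(d // 2)]

def sinusoidal_encoding(p, d, base=10_000.0):
    """Encoding for position p: sin on even components, cos on odd ones."""
    encoding = []
    for i in range(d // 2):
        angle = p / base ** (2 * i / d)
        encoding += [math.sin(angle), math.cos(angle)]
    return encoding

d = 8  # tiny dimension so the output is readable
print([round(w, 1) for w in wavelengths(d)])
print([round(x, 3) for x in sinusoidal_encoding(0, d)])
print([round(x, 3) for x in sinusoidal_encoding(1, d)])
```

With `d = 8` the four wavelengths come out at roughly 6, 63, 628 and 6283 tokens: each pair oscillates ten times more slowly than the one before it, much like the bits of a binary counter.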
There are a few parts of this equation that are confusing at first glance. How did the authors choose $10000$? Why are we using $\sin$ and $\cos$ for even and odd components respectively?
It seems that using $10000$ for the base wavelength was determined experimentally [^2]. Deciphering the usage of both $\sin$ and $\cos$ is more involved, but crucial for our iterative approach to understanding. The key here is our desire for a linear relation between two encoded positions (Property 2, $Pr_2$). To understand how using $\sin$ and $\cos$ in tandem produces this linear relation, we will have to dive into some trigonometry.
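Consider a sequence of sine and cosine pairs, each associated with a frequency $\omega_i$. Our goal is to find a linear transformation matrix $\mathbf{M}$ that can shift these sinusoidal functions by a fixed offset $k$:
$$\mathbf{M} \cdot \begin{bmatrix} \sin(\omega_i p) \\ \cos(\omega_i p) \end{bmatrix} = \begin{bmatrix} \sin(\omega_i (p + k)) \\ \cos(\omega_i (p + k)) \end{bmatrix}$$
The frequencies $\omega_i$ follow a geometric progression that decreases with dimension index $i$, defined as:
$$\omega_i = \frac{1}{10000^{2i/d}}$$
To find this transformation matrix, we can express it as a general 2×2 matrix with unknown coefficients $u_1$, $v_1$, $u_2$, and $v_2$:
$$\begin{bmatrix} u_1 & v_1 \\ u_2 & v_2 \end{bmatrix} \cdot \begin{bmatrix} \sin(\omega_i p) \\ \cos(\omega_i p) \end{bmatrix} = \begin{bmatrix} \sin(\omega_i (p + k)) \\ \cos(\omega_i (p + k)) \end{bmatrix}$$
By applying the trigonometric addition theorem to the right-hand side, we can expand this into:
$$\begin{bmatrix} u_1 & v_1 \\ u_2 & v_2 \end{bmatrix} \cdot \begin{bmatrix} \sin(\omega_i p) \\ \cos(\omega_i p) \end{bmatrix} = \begin{bmatrix} \sin(\omega_i p)\cos(\omega_i k) + \cos(\omega_i p)\sin(\omega_i k) \\ \cos(\omega_i p)\cos(\omega_i k) - \sin(\omega_i p)\sin(\omega_i k) \end{bmatrix}$$
This expansion gives us a system of two equations by matching coefficients:
$$u_1 \sin(\omega_i p) + v_1 \cos(\omega_i p) = \cos(\omega_i k)\sin(\omega_i p) + \sin(\omega_i k)\cos(\omega_i p)$$
$$u_2 \sin(\omega_i p) + v_2 \cos(\omega_i p) = -\sin(\omega_i k)\sin(\omega_i p) + \cos(\omega_i k)\cos(\omega_i p)$$
By comparing terms with $\sin(\omega_i p)$ and $\cos(\omega_i p)$ on both sides, we can solve for the unknown coefficients:
$$u_1 = \cos(\omega_i k), \quad v_1 = \sin(\omega_i k), \quad u_2 = -\sin(\omega_i k), \quad v_2 = \cos(\omega_i k)$$
These solutions give us our final transformation matrix $\mathbf{M_k}$:
$$\mathbf{M_k} = \begin{bmatrix} \cos(\omega_i k) & \sin(\omega_i k) \\ -\sin(\omega_i k) & \cos(\omega_i k) \end{bmatrix}$$
If you've done any game programming before, you might notice that the result of our derivation is oddly familiar. That's right, it's the Rotation Matrix! [^3]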
So the encoding scheme designed by Noam Shazeer in Attention is all you need was already encoding relative position as a rotation back in 2017! It took another 4 years to go from Sinusoidal Encoding to RoPE, despite rotations already being on the table...
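The shift property is easy to check numerically. This is a small sketch, assuming the matrix $\begin{bmatrix} \cos(\omega k) & \sin(\omega k) \\ -\sin(\omega k) & \cos(\omega k) \end{bmatrix}$ acting on the pair $[\sin(\omega p), \cos(\omega p)]$; the frequency and offset are arbitrary illustrative values.

```python
import math

def shift_matrix(omega, k):
    """Rotation matrix [[cos, sin], [-sin, cos]] evaluated at the angle omega * k."""
    c, s = math.cos(omega * k), math.sin(omega * k)
    return [[c, s], [-s, c]]

def encode(omega, p):
    """One sin/cos pair of the sinusoidal encoding at position p."""
    return [math.sin(omega * p), math.cos(omega * p)]

def matvec(m, v):
    return [m[0][0] * v[0] + m[0][1] * v[1], m[1][0] * v[0] + m[1][1] * v[1]]

omega, k = 0.1, 3  # arbitrary frequency and offset
results = []
for p in (0, 5, 1000):
    shifted = matvec(shift_matrix(omega, k), encode(omega, p))
    target = encode(omega, p + k)
    ok = all(math.isclose(a, b, abs_tol=1e-9) for a, b in zip(shifted, target))
    results.append(ok)
    print(p, ok)
```

Note that the matrix depends only on the offset `k`, never on the starting position `p`: the same linear map moves any position forward by three tokens.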
Absolute vs Relative Position Encoding
With the knowledge in hand that rotations are important here, let's return to our motivating example and try to discover some intuitions for our next iteration.
Above we can see the absolute positions of our tokens, and the relative positions from one chosen token to every other token. With Sinusoidal Encoding, we generated a separate vector which represents the absolute position, and using some trigonometric trickery we were able to encode relative positions.
When we're trying to understand these sentences, does it matter that this word is the 2157th word in this blog post? Or do we care about its relationship to the words around it? The absolute position of a word rarely matters for meaning - what matters is how words relate to each other.
Positional encoding in context
From this point on, it's key to consider positional encoding in the context of self attention. To reiterate, the self-attention mechanism enables the model to weigh the importance of different elements in an input sequence and dynamically adjust their influence on the output.
In all our previous iterations, we've generated a separate positional encoding vector and added it to our token embedding prior to our $Q$, $K$ and $V$ projections. By adding the positional information directly to our token embedding, we are polluting the semantic information with the positional information. We should be attempting to encode the information without modifying the norm. Shifting from additive to multiplicative is the key.
Using the dictionary analogy, when looking up a word (query) in our dictionary (keys), nearby words should have more influence than distant ones. The influence of one token upon another is determined by the dot product - so that's exactly where we should focus our positional encoding!
The geometric interpretation of the dot product shown above gives us a magnificent insight. We can modulate the result of our dot product of two vectors purely by increasing or decreasing the angle between them. Furthermore, by rotating the vector, we have absolutely zero impact on the norm of the vector, which encodes the semantic information of our token.
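Both claims follow from $\mathbf{a} \cdot \mathbf{b} = \lVert \mathbf{a} \rVert \lVert \mathbf{b} \rVert \cos(\alpha)$, where $\alpha$ is the angle between the vectors, and can be checked with two arbitrary 2D vectors. The sketch below uses made-up values for `q` and `k`; it rotates `k` by increasing angles and records its norm and its dot product with `q`.

```python
import math

def rotate(v, angle):
    """Rotate a 2D vector counterclockwise by `angle` radians."""
    c, s = math.cos(angle), math.sin(angle)
    return [c * v[0] - s * v[1], s * v[0] + c * v[1]]

def dot(a, b):
    return a[0] * b[0] + a[1] * b[1]

def norm(a):
    return math.sqrt(dot(a, a))

q = [1.0, 0.5]  # made-up query
k = [0.8, 0.2]  # made-up key

rows = []
for degrees in (0, 45, 90, 180):
    k_rot = rotate(k, math.radians(degrees))
    rows.append((degrees, norm(k_rot), dot(q, k_rot)))
    print(degrees, round(norm(k_rot), 6), round(dot(q, k_rot), 6))
```

The norm column stays constant while the dot product changes with the angle, so a rotation gives us a way to change attention scores without touching the length of the vector.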
So now we know where to focus our attention, and have seen from another angle why a rotation might be a sensible "channel" in which to encode our positional information, let's put it all together!
Rotary Positional Encoding
Rotary Positional Encoding or RoPE was defined in the RoFormer paper (Jianlin Su designed it independently on his blog here and here). While it may seem like voodoo if you skip to the end result, by thinking about Sinusoidal Encoding in the context of self attention (and more specifically dot products), we can see how it all comes together.
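Much like in Sinusoidal Encoding, we decompose our vectors ($\mathbf{q}$ or $\mathbf{k}$, instead of the pre-projection $\mathbf{x}$) into 2D pairs/chunks. Rather than encoding absolute position directly by adding a vector we drew from sinusoidal functions of slowly decreasing frequencies, we cut to the chase and encode relative position by multiplying each pair with the rotation matrix.
Let $\mathbf{q}$ or $\mathbf{k}$ be our input vector at position $p$. We create a block diagonal matrix where $\mathbf{M_i}$ is the corresponding rotation matrix for that component pair's desired rotation:
$$R(\mathbf{q}, p) = \begin{pmatrix} \mathbf{M_0} & & & \\ & \mathbf{M_1} & & \\ & & \ddots & \\ & & & \mathbf{M_{d/2-1}} \end{pmatrix} \begin{pmatrix} q_0 \\ q_1 \\ \vdots \\ q_{d-1} \end{pmatrix}$$
Much the same as Sinusoidal Encoding, $\mathbf{M_i}$ is simply:
$$\mathbf{M_i} = \begin{bmatrix} \cos(\omega_i p) & \sin(\omega_i p) \\ -\sin(\omega_i p) & \cos(\omega_i p) \end{bmatrix}, \quad \omega_i = \frac{1}{10000^{2i/d}}$$
The direction of rotation is only a convention, provided $\mathbf{q}$ and $\mathbf{k}$ use the same one.
In practice, we don't use a matrix multiplication to compute RoPE as it would be computationally inefficient with such a sparse matrix. Instead, we can directly apply the rotations to pairs of elements independently, taking advantage of the regular pattern in the computation (here $\otimes$ is the element-wise product):
$$R(\mathbf{q}, p) = \begin{pmatrix} q_0 \\ q_1 \\ q_2 \\ q_3 \\ \vdots \\ q_{d-2} \\ q_{d-1} \end{pmatrix} \otimes \begin{pmatrix} \cos(\omega_0 p) \\ \cos(\omega_0 p) \\ \cos(\omega_1 p) \\ \cos(\omega_1 p) \\ \vdots \\ \cos(\omega_{d/2-1} p) \\ \cos(\omega_{d/2-1} p) \end{pmatrix} + \begin{pmatrix} q_1 \\ -q_0 \\ q_3 \\ -q_2 \\ \vdots \\ q_{d-1} \\ -q_{d-2} \end{pmatrix} \otimes \begin{pmatrix} \sin(\omega_0 p) \\ \sin(\omega_0 p) \\ \sin(\omega_1 p) \\ \sin(\omega_1 p) \\ \vdots \\ \sin(\omega_{d/2-1} p) \\ \sin(\omega_{d/2-1} p) \end{pmatrix}$$
That's all there is to it! By artfully applying our rotations to 2D chunks of $\mathbf{q}$ and $\mathbf{k}$ prior to their dot product, and switching from additive to multiplicative, we can gain a big performance boost in evaluations [^4].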
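The pairwise rotation can be sketched in a few lines of plain Python. This is an illustrative, unoptimized version, assuming adjacent components are paired and using the rotation matrix $\begin{bmatrix} \cos & \sin \\ -\sin & \cos \end{bmatrix}$ with frequencies $1 / 10000^{2i/d}$; the function name `rope` and the toy vectors are made up, and production implementations vectorize this over tensors and may lay out the pairs differently.

```python
import math

def rope(x, p, base=10_000.0):
    """Rotate each pair (x[2i], x[2i+1]) by the angle omega_i * p."""
    d = len(x)
    out = []
    for i in range(d // 2):
        omega = 1.0 / base ** (2 * i / d)
        c, s = math.cos(omega * p), math.sin(omega * p)
        a, b = x[2 * i], x[2 * i + 1]
        out += [a * c + b * s, b * c - a * s]
    return out

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

q = [0.3, -0.1, 0.7, 0.2]  # made-up query
k = [0.5, 0.4, -0.2, 0.1]  # made-up key

near = dot(rope(q, 2), rope(k, 5))       # offset 3
far = dot(rope(q, 1002), rope(k, 1005))  # same offset, different absolute positions
other = dot(rope(q, 2), rope(k, 9))      # offset 7

print(math.isclose(near, far, abs_tol=1e-9))    # score depends only on the offset
print(math.isclose(near, other, abs_tol=1e-9))  # a different offset changes the score
print(math.isclose(dot(rope(q, 7), rope(q, 7)), dot(q, q), abs_tol=1e-12))  # norm kept
```

The attention score between `q` and `k` is identical whether the pair sits at positions (2, 5) or (1002, 1005), which is the relative-position behaviour we were after, and the norm of each vector is untouched.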
Extending RoPE to $n$-Dimensions
We've explored the $1D$ case for RoPE and by this point I hope you've gained an intuitive understanding of an admittedly unintuitive component of transformers. Finally, let's explore extending it to higher dimensions, such as images.
A natural first intuition could be to directly use the $(x, y)$ coordinate pairs from the image. This might seem intuitive; after all, we were almost arbitrarily pairing up our components previously. However, this would be a mistake!
In the $1D$ case, we encode the relative position through a rotation of pairs of values from our input vector. For $2D$ data, we need to encode both horizontal and vertical relative positions, say $\Delta x$ and $\Delta y$, independently. RoPE's brilliance lies in how it handles multiple dimensions. Instead of trying to encode all positional information in a single rotation, we pair components within the same dimension and rotate those; otherwise we would be intermixing the $x$ and $y$ offset information. By handling each dimension independently, we maintain the natural structure of the space. This can generalize to as many dimensions as required!
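One way to realize this, sketched below under stated assumptions, is to split the vector in two and apply the $1D$ rotation to each half with a different coordinate. The layout (first half for rows, second half for columns), the function names and the toy vectors are all hypothetical choices for illustration, not the definitive way to build a $2D$ RoPE.

```python
import math

def rope_1d(x, p, base=10_000.0):
    """Rotate each pair (x[2i], x[2i+1]) by the angle omega_i * p."""
    d = len(x)
    out = []
    for i in range(d // 2):
        omega = 1.0 / base ** (2 * i / d)
        c, s = math.cos(omega * p), math.sin(omega * p)
        a, b = x[2 * i], x[2 * i + 1]
        out += [a * c + b * s, b * c - a * s]
    return out

def rope_2d(x, row, col):
    """Assumed layout: first half of x encodes the row, second half the column."""
    half = len(x) // 2
    return rope_1d(x[:half], row) + rope_1d(x[half:], col)

def dot(a, b):
    return sum(u * v for u, v in zip(a, b))

q = [0.3, -0.1, 0.7, 0.2, 0.5, 0.4, -0.2, 0.1]  # made-up query
k = [0.2, 0.6, -0.3, 0.1, 0.4, -0.5, 0.3, 0.2]  # made-up key

a = dot(rope_2d(q, 1, 1), rope_2d(k, 3, 4))      # offset (2, 3)
b = dot(rope_2d(q, 10, 20), rope_2d(k, 12, 23))  # offset (2, 3) elsewhere in the image
c = dot(rope_2d(q, 1, 1), rope_2d(k, 4, 3))      # offset (3, 2)

print(math.isclose(a, b, abs_tol=1e-9))  # same 2D offset gives the same score
print(math.isclose(a, c, abs_tol=1e-9))  # swapping the row and column offsets does not
```

Because no pair ever mixes a row component with a column component, the score distinguishes an offset of (2, 3) from (3, 2) while staying invariant to where in the image the two patches sit.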
The future of positional encoding
Is RoPE the final incarnation of positional encoding? This recent paper from DeepMind deeply analyses RoPE and highlights some fundamental problems. TL;DR: RoPE isn't a perfect solution. The models mostly focus on the lower frequencies, and removing the rotation for a certain percentage of the lowest frequencies improves performance on Gemma 2B!
I anticipate some future breakthroughs, perhaps taking inspiration from signal processing with ideas like wavelets or hierarchical implementations. As models are increasingly quantized for deployment, I'd also expect to see some innovation in encoding schemes that remain robust under low-precision arithmetic.
Conclusion
Positional encoding has been, and continues to be, treated as an afterthought in transformers. I believe we should view it differently - self attention has an Achilles heel that has been repeatedly patched.
I hope this blog post showed you that you too could have discovered state of the art positional encoding, despite it being unintuitive at first. In a follow-up post I'd love to explore practical implementation details for RoPE in order to maximise performance.
This post was originally published here.
References
- Transformer Architecture: The Positional Encoding
- Rotary Embeddings: A Relative Revolution
- How positional encoding works in transformers?
- Attention Is All You Need
- Round and round we go! What makes Rotary Positional Encodings useful?
- RoFormer: Enhanced Transformer with Rotary Position Embedding
[^1]: Binary and Sinusoidal animations are reproductions of animations contained in this video.
[^2]: Using $\theta = 10000$ gives us $2\pi \cdot 10000$ unique positions, or a theoretical upper bound on the context length at ~63,000.
[^3]: Pieces of this post are based on this fantastic post by Amirhossein Kazemnejad.
[^4]: For empirical evidence, see this great post by EleutherAI.