Introducing RWKV - An RNN with the advantages of a transformer
ChatGPT and chatbot-powered applications have captured significant attention in the Natural Language Processing (NLP) domain. The community is constantly seeking strong, reliable and open-source models for their applications and use cases. The rise of these powerful models stems from the democratization and widespread adoption of transformer-based models, first introduced by Vaswani et al. in 2017. These models significantly outperformed the previous SoTA NLP models based on Recurrent Neural Networks (RNNs), which were largely considered obsolete after that paper. In this blogpost, we introduce the integration of a new architecture, RWKV, which combines the advantages of both RNNs and transformers, and which has recently been integrated into the Hugging Face transformers library.
Overview of the RWKV project
The RWKV project was kicked off and is being led by Bo Peng, who is actively contributing to and maintaining the project. The community, organized in the official Discord channel, is constantly enhancing the project’s artifacts on various topics such as performance (RWKV.cpp, quantization, etc.), scalability (dataset processing & scraping) and research (chat fine-tuning, multi-modal fine-tuning, etc.). The GPUs for training RWKV models are donated by Stability AI.
You can get involved by joining the official Discord channel and learn more about the general ideas behind RWKV in these two blogposts: https://johanwind.github.io/2023/03/23/rwkv_overview.html / https://johanwind.github.io/2023/03/23/rwkv_details.html
Transformer Architecture vs RNNs
The RNN architecture is one of the first widely used Neural Network architectures for processing a sequence of data, contrary to classic architectures that take a fixed-size input. It takes as input the current “token” (i.e. the current data point of the data stream) and the previous “state”, and computes the predicted next token and the predicted next state. The new state is then used to compute the prediction of the next token, and so on. An RNN can also be used in different “modes”, enabling RNNs to be applied in different scenarios, as described in Andrej Karpathy’s blogpost, such as one-to-one (image classification), one-to-many (image captioning), many-to-one (sequence classification), many-to-many (sequence generation), etc.
Overview of possible configurations of using RNNs. Source: Andrej Karpathy's blogpost
Because RNNs use the same weights to compute predictions at every step, they struggle to memorize information for long-range sequences due to the vanishing gradient issue. Efforts have been made to address this limitation by introducing new architectures such as LSTMs or GRUs. However, the transformer architecture proved to be the most effective thus far in resolving this issue.
In the transformer architecture, the input tokens are processed simultaneously in the self-attention module. The tokens are first linearly projected into different spaces using the query, key and value weights. The resulting matrices are directly used to compute the attention scores (through softmax, as shown below), then multiplied by the value hidden states to obtain the final hidden states. This design enables the architecture to effectively mitigate the long-range sequence issue, and also perform faster inference and training compared to RNN models.
Formulation of attention scores in transformer models. Source: Jay Alammar's blogpost
Formulation of attention scores in RWKV models. Source: RWKV blogpost
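To make the comparison concrete, here is a minimal sketch of the softmax attention described above, written with plain PyTorch tensors. The function and variable names are illustrative and do not correspond to the actual transformers implementation:

import torch

def softmax_attention(x, w_q, w_k, w_v):
    # x: (seq_len, d_model) - every token of the sequence is processed at once
    q = x @ w_q  # queries
    k = x @ w_k  # keys
    v = x @ w_v  # values
    # pairwise attention scores: a (seq_len, seq_len) matrix
    # (the causal mask used in language models is omitted for brevity)
    scores = torch.softmax(q @ k.T / k.shape[-1] ** 0.5, dim=-1)
    # each output position is a weighted sum of all value vectors
    return scores @ v

Note that the scores matrix grows quadratically with the sequence length, which is exactly what makes very long contexts expensive for transformers.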
During training, the Transformer architecture has several advantages over traditional RNNs and CNNs. One of the most significant is its ability to learn contextual representations. Unlike RNNs and CNNs, which process input sequences one token at a time, the Transformer processes the input sequence as a whole. This allows it to capture long-range dependencies between tokens, which is particularly useful for tasks such as language translation and question answering.
During inference, RNNs have some advantages in speed and memory efficiency: they only need matrix-vector operations, and their memory requirements do not grow during generation. Furthermore, the computation speed per token stays the same regardless of the context window length, since each step only operates on the current token and the state.
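As an illustration of this point, the snippet below sketches one inference step of a generic (hypothetical) RNN cell - not RWKV's actual recurrence - showing that each step only needs matrix-vector products and a fixed-size state:

import torch

def rnn_step(token_embedding, state, w_in, w_state, w_out):
    # a single inference step: only matrix-vector operations are needed,
    # and the state keeps the same size no matter how long the context is
    new_state = torch.tanh(w_in @ token_embedding + w_state @ state)
    logits = w_out @ new_state  # prediction scores for the next token
    return logits, new_state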
The RWKV architecture
RWKV is inspired by Apple’s Attention Free Transformer. The architecture has been carefully simplified and optimized so that it can be reformulated as an RNN. In addition, a number of tricks have been added, such as TokenShift and SmallInitEmb (the full list is in the README of the official GitHub repository), to boost its performance to match GPT. Without these, the model wouldn't be as performant.
For training, there is an infrastructure to scale training up to 14B parameters as of now, and some issues, such as numerical instability, have been iteratively fixed in RWKV-4 (the latest version as of today).
RWKV as a combination of RNNs and transformers
How can we combine the best of transformers and RNNs? The main drawback of transformer-based models is that running a model with a context window larger than a certain length becomes challenging, because the attention scores are computed for the entire sequence simultaneously.
RNNs natively support very long context lengths - only limited by the context length seen in training, but this can be extended to millions of tokens with careful coding. Currently, there are RWKV models trained on a context length of 8192 (ctx8192), and they are as fast as ctx1024 models while requiring the same amount of RAM.
The major drawbacks of traditional RNN models and how RWKV is different:
- Traditional RNN models are unable to utilize very long contexts (an LSTM can only manage around 100 tokens when used as a language model). RWKV, on the other hand, can utilize thousands of tokens and beyond, as shown below:
LM loss with respect to different context lengths and model sizes. Source: RWKV original repository
- Traditional RNN models cannot be parallelized over the sequence during training. RWKV is similar to a “linearized GPT” and trains faster than GPT.
By combining both advantages into a single architecture, the hope is that RWKV can grow to become more than the sum of its parts.
RWKV attention formulation
The model architecture is very similar to classic transformer-based models (i.e. an embedding layer, multiple identical layers, layer normalization, and a Causal Language Modeling head to predict the next token). The only difference is the attention layer, which is completely different from that of traditional transformer-based models.
To gain a more comprehensive understanding of the attention layer, we recommend delving into the detailed explanation provided in a blog post by Johan Sokrates Wind.
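For a rough intuition of how this works in recurrent form, the sketch below implements a simplified, numerically unstabilized version of the RWKV "WKV" computation for a single token. The real implementation rescales the exponentials to avoid overflow, and the variable names here are ours, not the library's:

import torch

def wkv_step(k_t, v_t, num, den, w, u):
    # k_t, v_t: key and value vectors for the current token (one entry per channel)
    # num, den: running numerator and denominator over past tokens (the "state")
    # w: per-channel time decay, u: per-channel bonus for the current token
    exp_k = torch.exp(k_t)
    # output for the current token: a weighted average of past and current values
    wkv = (num + torch.exp(u) * exp_k * v_t) / (den + torch.exp(u) * exp_k)
    # update the state: decay past contributions and add the current token
    num = torch.exp(-w) * num + exp_k * v_t
    den = torch.exp(-w) * den + exp_k
    return wkv, num, den

Because each step only carries num and den forward, the memory footprint stays constant regardless of the context length - which is the RNN-style property discussed above.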
Existing checkpoints
Pure language models: RWKV-4 models
The most adopted RWKV models range from ~170M parameters to 14B parameters. According to the RWKV overview blog post, these models have been trained on the Pile dataset and evaluated against other SoTA models on different benchmarks, where they perform quite well, with results comparable to the competition.
RWKV-4 compared to other common architectures. Source: Johan Wind's blogpost
Instruction Fine-tuned/Chat Version: RWKV-4 Raven
Bo has also trained a “chat” version of the RWKV architecture, the RWKV-4 Raven model. It is an RWKV-4 Pile model (an RWKV model pretrained on The Pile dataset) fine-tuned on ALPACA, CodeAlpaca, Guanaco, GPT4All, ShareGPT and more. The model comes in multiple versions, trained on different languages (English only, English + Chinese + Japanese, English + Japanese, etc.) and in different sizes (1.5B, 7B, and 14B parameters).
All the HF converted models are available on the Hugging Face Hub, in the RWKV organization.
🤗 Transformers integration
The architecture has been added to the transformers library thanks to this Pull Request. As of the time of writing, you can use it by installing transformers from source, or by using the main branch of the library. The architecture is tightly integrated with the library, and you can use it as you would any other architecture.
Let us walk through some examples below.
Text Generation Example
To generate text given an input prompt, you can use pipeline:
from transformers import pipeline
model_id = "RWKV/rwkv-4-169m-pile"
prompt = "\nIn a shocking finding, scientist discovered a herd of dragons living in a remote, previously unexplored valley, in Tibet. Even more surprising to the researchers was the fact that the dragons spoke perfect Chinese."
pipe = pipeline("text-generation", model=model_id)
print(pipe(prompt, max_new_tokens=20))
>>> [{'generated_text': '\nIn a shocking finding, scientist discovered a herd of dragons living in a remote, previously unexplored valley, in Tibet. Even more surprising to the researchers was the fact that the dragons spoke perfect Chinese.\n\nThe researchers found that the dragons were able to communicate with each other, and that they were'}]
Or you can run and adapt the snippet below:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained("RWKV/rwkv-4-169m-pile")
tokenizer = AutoTokenizer.from_pretrained("RWKV/rwkv-4-169m-pile")
prompt = "\nIn a shocking finding, scientist discovered a herd of dragons living in a remote, previously unexplored valley, in Tibet. Even more surprising to the researchers was the fact that the dragons spoke perfect Chinese."
inputs = tokenizer(prompt, return_tensors="pt")
output = model.generate(inputs["input_ids"], max_new_tokens=20)
print(tokenizer.decode(output[0].tolist()))
>>> In a shocking finding, scientist discovered a herd of dragons living in a remote, previously unexplored valley, in Tibet. Even more surprising to the researchers was the fact that the dragons spoke perfect Chinese.\n\nThe researchers found that the dragons were able to communicate with each other, and that they were
Use the Raven models (chat models)
You can prompt the chat model in the Alpaca style, as shown in the example below:
from transformers import AutoTokenizer, AutoModelForCausalLM
model_id = "RWKV/rwkv-raven-1b5"
model = AutoModelForCausalLM.from_pretrained(model_id).to(0)
tokenizer = AutoTokenizer.from_pretrained(model_id)
question = "Tell me about ravens"
prompt = f"### Instruction: {question}\n### Response:"
inputs = tokenizer(prompt, return_tensors="pt").to(0)
output = model.generate(inputs["input_ids"], max_new_tokens=100)
print(tokenizer.decode(output[0].tolist(), skip_special_tokens=True))
>>> ### Instruction: Tell me about ravens
### Response: RAVENS are a type of bird that is native to the Middle East and North Africa. They are known for their intelligence, adaptability, and their ability to live in a variety of environments. RAVENS are known for their intelligence, adaptability, and their ability to live in a variety of environments. They are known for their intelligence, adaptability, and their ability to live in a variety of environments.
According to Bo, better instruction techniques are detailed in this Discord message (make sure to join the channel before clicking).
Weights conversion
Any user can easily convert the original RWKV weights to the HF format by simply running the conversion script provided in the transformers library. First, push the "raw" weights to the Hugging Face Hub (let's denote that repo as RAW_HUB_REPO, and the raw file RAW_FILE), then run the conversion script:
python convert_rwkv_checkpoint_to_hf.py --repo_id RAW_HUB_REPO --checkpoint_file RAW_FILE --output_dir OUTPUT_DIR
If you want to push the converted model to the Hub (let's say, under dummy_user/converted-rwkv), don't forget to log in with huggingface-cli login before pushing the model, then run:
python convert_rwkv_checkpoint_to_hf.py --repo_id RAW_HUB_REPO --checkpoint_file RAW_FILE --output_dir OUTPUT_DIR --push_to_hub --model_name dummy_user/converted-rwkv
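Once the conversion has finished, the converted checkpoint can be loaded like any other transformers model (using the hypothetical dummy_user/converted-rwkv repository name from the example above):

from transformers import AutoModelForCausalLM, AutoTokenizer

# "dummy_user/converted-rwkv" is the placeholder repository name used above
model = AutoModelForCausalLM.from_pretrained("dummy_user/converted-rwkv")
tokenizer = AutoTokenizer.from_pretrained("dummy_user/converted-rwkv")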
Future work
Multi-lingual RWKV
Bo is currently working on a multilingual corpus to train RWKV models. Recently a new multilingual tokenizer has been released.
Community-oriented and research projects
The RWKV community is very active and working on several follow-up directions; a list of cool projects can be found in a dedicated channel on Discord (make sure to join the channel before clicking the link). There is also a channel dedicated to research around this architecture, feel free to join and contribute!
Model Compression and Acceleration
Due to only needing matrix-vector operations, RWKV is an ideal candidate for non-standard and experimental computing hardware, such as photonic processors/accelerators.
Therefore, the architecture can also naturally benefit from classic acceleration and compression techniques (such as ONNX, 4-bit/8-bit quantization, etc.), and we hope this will be democratized for developers and practitioners together with the transformers integration of the architecture.
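For instance, if bitsandbytes and accelerate are installed, a model can in principle be loaded in 8-bit precision directly through from_pretrained. This is a hedged sketch rather than a benchmarked recommendation, since 8-bit support for a given architecture has to be verified:

from transformers import AutoModelForCausalLM

# requires `bitsandbytes`, `accelerate` and a CUDA GPU; 8-bit support for
# this architecture is an assumption here, not a guarantee
model = AutoModelForCausalLM.from_pretrained(
    "RWKV/rwkv-4-169m-pile",
    load_in_8bit=True,
    device_map="auto",
)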
RWKV can also benefit from the acceleration techniques proposed by the optimum library in the near future.
Some of these techniques are highlighted in the rwkv.cpp repository or in the rwkv-cpp-cuda repository.
Acknowledgements
The Hugging Face team would like to thank Bo and the RWKV community for their time and for answering our questions about the architecture. We would also like to thank them for their help and support, and we look forward to seeing more adoption of RWKV models in the HF ecosystem.
We also would like to acknowledge the work of Johan Wind for his blogpost on RWKV, which helped us a lot to understand the architecture and its potential.
And finally, we would like to highlight and acknowledge the work of ArEnSc for starting the initial transformers PR.
Also big kudos to Merve Noyan, Maria Khalusova and Pedro Cuenca for kindly reviewing this blogpost to make it much better!
Citation
If you use RWKV for your work, please use the following cff citation.