SpiralAI Spiral-RetNet-3b-base
We pre-trained a 3B-parameter model based on the RetNet architecture (https://arxiv.org/abs/2307.08621) from scratch, using a mixed dataset of Japanese and English. This model is released primarily for basic research on the retention mechanism.
Model Description
- Developed by: SpiralAI
- Model type: SpiralAI Spiral-RetNet-3b-base is a language model equipped with a retention mechanism. It uses the cyberagent/calm2-7b-chat tokenizer.
- Languages: Japanese, English.
- License: MIT
- Training: Trained on 80B tokens.
- Context Length: 2,048 tokens.
Installation
pip install transformers==4.38 # The top_k_top_p_filtering feature has been removed in later versions.
Clone the repository from https://github.com/syncdoth/RetNet and follow the Getting Started guide provided there.
Example:
git clone https://github.com/syncdoth/RetNet.git
pip install torch transformers timm
cd RetNet
Usage
from transformers import AutoTokenizer
from retnet.modeling_retnet import RetNetForCausalLM

# The model reuses the tokenizer of cyberagent/calm2-7b-chat.
tokenizer = AutoTokenizer.from_pretrained("cyberagent/calm2-7b-chat")
tokenizer.pad_token = tokenizer.eos_token

model = RetNetForCausalLM.from_pretrained(
    "Spiral-AI/Spiral-RetNet-3b-base", device_map="auto"
)

inputs = tokenizer("最近、秋葉原周辺で興味深い", return_tensors="pt")
input_ids = inputs["input_ids"].to(model.device)

generated = model.generate(
    input_ids,
    max_new_tokens=32,
    repetition_penalty=1.2,  # setting this is recommended for a 3B model
)
print(tokenizer.decode(generated[0]))
Examples
input: 最近、秋葉原周辺で興味深い
output: お店がいくつかあります。
1. 神田カレー街「カレーハウスCoCo壱番屋」
2016年7月3日オープン
input: 近年、AI技術の進歩によって
output: 人間の仕事が奪われるのではないかという懸念がある。
しかしながら、AIは人間に取って代わるものではなく、「人間がコンピュータに仕事をさせる」という考え方
input: When I was a child, I used to play with
output: 3-D glasses. They were so much fun!
I have been playing around in the world of video games for years now and it is amazing how
Basic study
Visualization of the retention mechanism
This visualization shows the retention mechanism in action. The token being generated is represented by *.
The blue bars show how the tokens are weighted during generation.
Using the mathematical equivalence between the "recurrent mode" and the "parallel mode", we apply a visualization technique similar to the one used for attention mechanisms: the inner products between queries and keys are summed over all heads after taking their absolute values. The result shown here is for the last layer.
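The equivalence mentioned above can be illustrated with a minimal NumPy sketch. It follows the simplified single-head formulation from the RetNet paper (no positional rotation, no normalization), checks that the parallel and recurrent forms give the same output, and then sums absolute query-key scores over heads. The function names, the decay values, and the choice to keep the decay factor inside the visualized scores are assumptions made for illustration, not the code used to produce the figure.

```python
import numpy as np

def retention_parallel(q, k, v, gamma):
    """Parallel form: (Q K^T * D) V, with D[n, m] = gamma**(n - m) for n >= m, else 0."""
    idx = np.arange(q.shape[0])
    diff = idx[:, None] - idx[None, :]
    decay = np.where(diff >= 0, gamma ** np.maximum(diff, 0), 0.0)
    scores = (q @ k.T) * decay
    return scores @ v, scores

def retention_recurrent(q, k, v, gamma):
    """Recurrent form: S_n = gamma * S_{n-1} + k_n^T v_n, then o_n = q_n S_n."""
    state = np.zeros((k.shape[1], v.shape[1]))
    outputs = []
    for q_n, k_n, v_n in zip(q, k, v):
        state = gamma * state + np.outer(k_n, v_n)
        outputs.append(q_n @ state)
    return np.stack(outputs)

def token_weights(per_head_scores):
    """Sum absolute scores over heads; row n weights the tokens seen when producing token n."""
    return np.abs(np.stack(per_head_scores)).sum(axis=0)

rng = np.random.default_rng(0)
seq_len, d_k, d_v = 6, 4, 4
gammas = [0.9, 0.95]  # hypothetical per-head decay rates

all_scores = []
for gamma in gammas:
    q, k, v = (rng.standard_normal((seq_len, d)) for d in (d_k, d_k, d_v))
    out_par, scores = retention_parallel(q, k, v, gamma)
    out_rec = retention_recurrent(q, k, v, gamma)
    assert np.allclose(out_par, out_rec)  # both modes agree
    all_scores.append(scores)

weights = token_weights(all_scores)
print(weights[-1])  # weights over the prefix for the most recent token
```

Because the score matrix is only available in the parallel form, a sketch like this computes the weights there even when generation itself runs in the recurrent mode.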
Test loss comparison
We compared the test loss of Spiral-AI/Spiral-RetNet-3b-base and cyberagent/open-calm-3b on inputs of different token lengths. The first 100 examples of wikipedia-ja were used as the test dataset.
Key findings are:
- The test loss of Spiral-AI/Spiral-RetNet-3b-base goes as low as that of cyberagent/open-calm-3b, showing the effectiveness of the retention mechanism.
- The explosion of test loss is suppressed in Spiral-AI/Spiral-RetNet-3b-base when the context length exceeds 2,048 tokens, the maximum context length of the training data. Note that cyberagent/open-calm-3b was trained with the same context length.
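A comparison of this kind can be sketched as follows. The snippet computes the mean next-token cross-entropy for several prefix lengths from a matrix of per-position logits. It assumes a causal model, so the logits at position t depend only on tokens up to t and can be truncated after a single forward pass. The function names and the uniform placeholder logits are hypothetical; in practice the logits would come from each model, and this is not the evaluation script used for the results above.

```python
import numpy as np

def next_token_loss(logits, token_ids):
    """Mean cross-entropy of predicting token t+1 from the logits at position t."""
    logits = logits[:-1]      # the last position has no target
    targets = token_ids[1:]   # targets are the inputs shifted by one
    shifted = logits - logits.max(axis=-1, keepdims=True)  # numerically stable log-softmax
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    return float(-log_probs[np.arange(len(targets)), targets].mean())

def loss_by_length(logits, token_ids, lengths):
    """Loss on the first n tokens for each n; lengths beyond the sequence are skipped."""
    return {
        n: next_token_loss(logits[:n], token_ids[:n])
        for n in lengths
        if 2 <= n <= len(token_ids)
    }

# Placeholder data: uniform logits stand in for real model output.
vocab_size, total_len = 8, 16
logits = np.zeros((total_len, vocab_size))
token_ids = np.arange(total_len) % vocab_size

losses = loss_by_length(logits, token_ids, lengths=[4, 8, 16, 32])
print(losses)
```

Averaging these per-length losses over the test examples for each model would give curves comparable in spirit to the ones described above.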
Training Datasets
- izumi-lab/cc100-ja-filter-ja-normal (Japanese)
- izumi-lab/wikipedia-ja-20230720 (Japanese)
- wikipedia (English)
- uonlp/CulturaX (English, Japanese)
Limitations
This model is designed for broad applicability, but it may not fully meet the specific needs or contexts of all uses. Pre-training data may contain inappropriate content, which could be reflected in the texts generated by the model. Therefore, when using this model, it is important to carefully review its output and avoid situations where it might cause discomfort or harm to individuals or groups.
There are no specific restrictions on commercial use, but users are responsible for addressing any ethical or legal issues that may arise in connection with the use of the model.