mustafaaljadery committed ece19aa (parent: e70c9c5): Update README.md

README.md (changed):
---
license: mit
---

# Gemma 2B - 10M Context

Gemma 2B with recurrent local attention and a context length of up to 10M tokens. Our implementation uses **<32GB** of memory!

![Graphic of our implementation context](./images/graphic.png)

**Features:**

- 10M sequence length on Gemma 2B.
- Runs on less than 32GB of memory.
- Native inference on Apple Silicon using MLX.
- Highly performant retrieval - needle in a haystack.

## Quick Start

> **Note:** This is a very early checkpoint of the model, trained for only 200 steps. We plan on training on a lot more tokens!

Download the model from Hugging Face - [Huggingface Model](https://huggingface.co/mustafaaljadery/gemma-10M-safetensor).
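
If you want the weights on disk first, one possible way (our suggestion, not part of the original instructions; it assumes the `huggingface_hub` package is installed) is to download them into the path that the inference snippet below expects:

```python
# Hypothetical download step (not from the original README): fetch the
# safetensors weights into the local path used by the inference example.
from huggingface_hub import snapshot_download

snapshot_download(
    repo_id="mustafaaljadery/gemma-10M-safetensor",
    local_dir="./models/gemma-2b-10m",
)
```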

```bash
python main.py
```

Change the inference code in `main.py` to the specific prompt you desire.

```python
import torch
from transformers import AutoTokenizer

# `GemmaForCausalLM` and `generate` below are the versions provided by this
# repository's source code (their imports are omitted here, as in the
# original snippet).
model_path = "./models/gemma-2b-10m"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = GemmaForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16
)

prompt_text = "Summarize this harry potter book..."

with torch.no_grad():
    generated_text = generate(
        model, tokenizer, prompt_text, max_length=512, temperature=0.8
    )

print(generated_text)
```

## How does this work?

The largest bottleneck (in terms of memory) for LLMs is the KV cache. It grows linearly with sequence length, and the cost of vanilla multi-head attention grows quadratically, which limits the sequence length you can fit.
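
To put rough numbers on that, here is a back-of-the-envelope estimate of what a plain, non-recurrent KV cache would cost for Gemma 2B. The configuration values are the published Gemma 2B defaults (18 layers, multi-query attention with a single KV head of dimension 256); the calculation is our illustration, not code from this repo:

```python
# Back-of-the-envelope KV-cache size for a vanilla (non-recurrent) cache.
# Assumes the published Gemma 2B config; adjust the constants if they differ.
NUM_LAYERS = 18       # Gemma 2B transformer layers
NUM_KV_HEADS = 1      # multi-query attention: one shared KV head
HEAD_DIM = 256
BYTES_PER_VALUE = 2   # bfloat16

def kv_cache_bytes(seq_len: int) -> int:
    # Factor of 2 covers the key tensor and the value tensor at every layer.
    return 2 * NUM_LAYERS * NUM_KV_HEADS * HEAD_DIM * BYTES_PER_VALUE * seq_len

for seq_len in (8_192, 1_000_000, 10_000_000):
    print(f"{seq_len:>12,} tokens -> {kv_cache_bytes(seq_len) / 2**30:6.1f} GiB")
```

At 10M tokens a naive cache would need roughly 170 GiB, far beyond the <32GB target, which is why the cache has to stay local and bounded.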

Our approach splits the attention into local attention blocks, as outlined by [InfiniAttention](https://arxiv.org/abs/2404.07143). We apply recurrence across those local attention blocks to get the final result of 10M-token global attention.
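
To make the shape of that idea concrete, here is a minimal sketch of chunked local attention with a recurrent memory carried across blocks. Everything in it (single-head tensors, the block size, the keep-the-last-block memory rule, the omitted causal mask) is an illustrative assumption, not this repository's implementation, which compresses the carried state rather than copying the previous block:

```python
# Illustrative sketch of recurrent local attention, in the spirit of
# InfiniAttention / Transformer-XL. Not the repo's actual code.
import torch
import torch.nn.functional as F

def recurrent_local_attention(q, k, v, block_size=2048):
    # q, k, v: (seq_len, head_dim) single-head tensors for clarity.
    seq_len, d = q.shape
    mem_k = q.new_zeros(0, d)  # recurrent memory of past keys
    mem_v = q.new_zeros(0, d)  # recurrent memory of past values
    outputs = []
    for start in range(0, seq_len, block_size):
        qb = q[start:start + block_size]
        kb = torch.cat([mem_k, k[start:start + block_size]])
        vb = torch.cat([mem_v, v[start:start + block_size]])
        # Local attention: each block only sees itself plus the bounded
        # memory, so per-block cost stays constant regardless of seq_len.
        attn = F.softmax(qb @ kb.T / d ** 0.5, dim=-1)
        outputs.append(attn @ vb)
        # Carry a bounded summary forward instead of the full history
        # (here: just the previous block; the real method compresses it).
        mem_k = k[start:start + block_size]
        mem_v = v[start:start + block_size]
    return torch.cat(outputs)
```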

A lot of the inspiration for our ideas comes from the [Transformer-XL](https://arxiv.org/abs/1901.02860) paper.

## Credits

This was built by:

- [Mustafa Aljadery](https://www.maxaljadery.com/)
- [Siddharth Sharma](https://stanford.edu/~sidshr/)
- [Aksh Garg](https://www.linkedin.com/in/aksh-garg/)