Commit
•
92d683c
0
Parent(s):
Duplicate from togethercomputer/evo-1-131k-base
Browse filesCo-authored-by: Michael Poli <Zymrael@users.noreply.huggingface.co>
- .gitattributes +35 -0
- README.md +85 -0
- cache.py +44 -0
- config.json +90 -0
- configuration_hyena.py +92 -0
- engine.py +389 -0
- generation_config.json +4 -0
- layers.py +155 -0
- model-00001-of-00003.safetensors +3 -0
- model-00002-of-00003.safetensors +3 -0
- model-00003-of-00003.safetensors +3 -0
- model.py +474 -0
- model.safetensors.index.json +445 -0
- modeling_hyena.py +145 -0
- positional_embeddings.py +113 -0
- pytorch_model.pt +3 -0
- special_tokens_map.json +1 -0
- streamer.py +106 -0
- tokenizer.py +129 -0
- tokenizer_config.json +14 -0
- utils.py +96 -0
.gitattributes
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
README.md
ADDED
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
license: apache-2.0
|
3 |
+
tags:
|
4 |
+
- stripedhyena
|
5 |
+
- long context
|
6 |
+
- deep signal processing
|
7 |
+
- hybrid
|
8 |
+
- biology
|
9 |
+
- genomics
|
10 |
+
---
|
11 |
+
|
12 |
+
|
13 |
+
## Evo-1 (Phase 2)
|
14 |
+
|
15 |
+
<p align="center">
|
16 |
+
<img src="https://cdn-uploads.huggingface.co/production/uploads/62a1306bbe7fa896d2c8de44/JoEHcvLTUlHoMcgh3mmAz.png" width="70%" />
|
17 |
+
</p>
|
18 |
+
|
19 |
+
|
20 |
+
### About
|
21 |
+
|
22 |
+
Evo is a biological foundation model capable of long-context modeling and design.
|
23 |
+
|
24 |
+
Evo uses the [StripedHyena architecture](https://github.com/togethercomputer/stripedhyena) to enable modeling of sequences at a single-nucleotide, byte-level resolution with near-linear scaling of compute and memory relative to context length.
|
25 |
+
Evo has 7 billion parameters and is trained on OpenGenome, a prokaryotic whole-genome dataset containing ~300 billion tokens.
|
26 |
+
|
27 |
+
Technical details about Evo can be found in our preprint and our accompanying blog posts. Evo was collaboratively developed by the [Arc Institute](https://arcinstitute.org/) and TogetherAI.
|
28 |
+
|
29 |
+
As part of our commitment to open science, we release **weights of 15 intermediate pretraining checkpoints** for phase 1 and phase 2 of pretraining. The checkpoints are available as branches of the corresponding HuggingFace repository.
|
30 |
+
|
31 |
+
**Evo-1 (Phase 2)** is our **longer context model** in the Evo family, trained at a context length of 131k and tested on generation of sequences of length >650k
|
32 |
+
|
33 |
+
| Checkpoint Name | Description |
|
34 |
+
|----------------------------------------|-------------|
|
35 |
+
| `evo-1-8k-base` | A model pretrained with 8,192 context. We use this model as the base model for molecular-scale finetuning tasks. |
|
36 |
+
| `evo-1-131k-base` | A model pretrained with 131,072 context using `evo-1-8k-base` as the initialization. We use this model to reason about and generate sequences at the genome scale. |
|
37 |
+
|
38 |
+
### Model Architecture
|
39 |
+
|
40 |
+
StripedHyena is a deep signal processing, hybrid architecture composed of multi-head attention and gated convolutions arranged in [Hyena](https://arxiv.org/abs/2302.10866) blocks, improving over decoder-only Transformers.
|
41 |
+
|
42 |
+
StripedHyena is designed to leverage the specialization of each of its layer classes, with Hyena layers implementing the bulk of the computation required for sequence processing and attention layers supplementing the ability to perform targeted pattern recall.
|
43 |
+
|
44 |
+
|
45 |
+
Some highlights of the architecture:
|
46 |
+
- **Efficient autoregressive generation** via a recurrent mode (>500k generation with a single 80GB GPU)
|
47 |
+
- **Significantly faster training and finetuning** at long context (>3x at 131k)
|
48 |
+
- **Improved scaling laws over state-of-the-art architectures** (e.g., Transformer++) on both natural language and biological sequences.
|
49 |
+
- **Robust to training beyond the compute-optimal frontier** e.g., training way beyond Chinchilla-optimal token amounts (see preprint for details -- more details to come)
|
50 |
+
|
51 |
+
|
52 |
+
### How to use Evo
|
53 |
+
|
54 |
+
Example usage is provided in the [standalone repo](https://github.com/evo-design/evo).
|
55 |
+
|
56 |
+
|
57 |
+
#### Parametrization for Inference and Finetuning
|
58 |
+
|
59 |
+
One of the advantages of deep signal processing models is their flexibility. Different parametrizations of convolutions can be used depending on the memory, expressivity and causality requirements of pretraining, finetuning or inference workloads.
|
60 |
+
|
61 |
+
The main classes are:
|
62 |
+
- Modal canonical: unconstrained poles ([reference](https://arxiv.org/pdf/2203.14343.pdf), [reference](https://arxiv.org/abs/2310.18780)), or constrained poles ([reference](https://arxiv.org/abs/2206.11893), [reference](https://arxiv.org/pdf/2303.06349.pdf)).
|
63 |
+
- Companion canonical / rational: TBA.
|
64 |
+
- Hypernetworks: hypernetwork ([reference](https://arxiv.org/abs/2102.02611)), modulated hypernetwork ([reference](https://arxiv.org/abs/2302.10866)).
|
65 |
+
- Explicit: modulated explicit ([reference](https://arxiv.org/pdf/2210.09298.pdf)).
|
66 |
+
|
67 |
+
StripedHyena is a mixed precision model. Make sure to keep your `poles` and `residues` in `float32` precision, especially for longer prompts or training.
|
68 |
+
|
69 |
+
|
70 |
+
|
71 |
+
### Disclaimer
|
72 |
+
|
73 |
+
To use StripedHyena outside of the playground, you will need to install custom kernels. Please follow the instructions from the [standalone repository](https://github.com/togethercomputer/stripedhyena).
|
74 |
+
|
75 |
+
## Cite
|
76 |
+
|
77 |
+
```
|
78 |
+
@article{nguyen2024sequence,
|
79 |
+
author = {Eric Nguyen and Michael Poli and Matthew G. Durrant and Armin W. Thomas and Brian Kang and Jeremy Sullivan and Madelena Y. Ng and Ashley Lewis and Aman Patel and Aaron Lou and Stefano Ermon and Stephen A. Baccus and Tina Hernandez-Boussard and Christopher Ré and Patrick D. Hsu and Brian L. Hie},
|
80 |
+
journal = {Arc Institute manuscripts},
|
81 |
+
title = {Sequence modeling and design from molecular to genome scale with Evo},
|
82 |
+
url = {https://arcinstitute.org/manuscripts/Evo},
|
83 |
+
year = {2024},
|
84 |
+
}
|
85 |
+
```
|
cache.py
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Together
|
2 |
+
# This software is distributed under the terms of the Apache License, Version 2.0
|
3 |
+
# Author: Michael Poli
|
4 |
+
|
5 |
+
from torch import Tensor
|
6 |
+
from dataclasses import dataclass, field
|
7 |
+
from typing import Optional
|
8 |
+
|
9 |
+
|
10 |
+
# https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/utils/generation.py
|
11 |
+
@dataclass
|
12 |
+
class InferenceParams:
|
13 |
+
"""Inference parameters that are passed to the main model in order
|
14 |
+
to efficienly calculate and store the context during inference."""
|
15 |
+
|
16 |
+
max_seqlen: int
|
17 |
+
max_batch_size: int
|
18 |
+
seqlen_offset: int = 0
|
19 |
+
batch_size_offset: int = 0
|
20 |
+
key_value_memory_dict: dict = field(default_factory=dict)
|
21 |
+
lengths_per_sample: Optional[Tensor] = None
|
22 |
+
|
23 |
+
def reset(self, max_seqlen, max_batch_size):
|
24 |
+
self.max_seqlen = max_seqlen
|
25 |
+
self.max_batch_size = max_batch_size
|
26 |
+
self.seqlen_offset = 0
|
27 |
+
if self.lengths_per_sample is not None:
|
28 |
+
self.lengths_per_sample.zero_()
|
29 |
+
|
30 |
+
|
31 |
+
@dataclass
|
32 |
+
class RecurrentInferenceParams:
|
33 |
+
"""Inference parameters passed to blocks with recurrent mode."""
|
34 |
+
|
35 |
+
fir_filter_length: int = 3
|
36 |
+
state_dim: int = 16
|
37 |
+
seqlen_offset: int = 0
|
38 |
+
fir_state_dict: dict = field(default_factory=dict)
|
39 |
+
state_dict: dict = field(default_factory=dict)
|
40 |
+
|
41 |
+
def reset(self):
|
42 |
+
self.fir_filter_length = 3
|
43 |
+
self.state_dim = 16
|
44 |
+
self.seqlen_offset = 0
|
config.json
ADDED
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_commit_hash": null,
|
3 |
+
"_name_or_path": "togethercomputer/evo-1-131k-base",
|
4 |
+
"architectures": [
|
5 |
+
"StripedHyenaModelForCausalLM"
|
6 |
+
],
|
7 |
+
"attn_layer_idxs": [
|
8 |
+
8,
|
9 |
+
16,
|
10 |
+
24
|
11 |
+
],
|
12 |
+
"auto_map": {
|
13 |
+
"AutoConfig": "configuration_hyena.StripedHyenaConfig",
|
14 |
+
"AutoModelForCausalLM": "modeling_hyena.StripedHyenaModelForCausalLM",
|
15 |
+
"AutoTokenizer": [
|
16 |
+
"tokenizer.ByteTokenizer",
|
17 |
+
null
|
18 |
+
]
|
19 |
+
},
|
20 |
+
"column_split": false,
|
21 |
+
"column_split_hyena": true,
|
22 |
+
"eps": 1e-06,
|
23 |
+
"final_norm": true,
|
24 |
+
"hidden_size": 4096,
|
25 |
+
"hyena_filter_groups": 1,
|
26 |
+
"hyena_layer_idxs": [
|
27 |
+
0,
|
28 |
+
1,
|
29 |
+
2,
|
30 |
+
3,
|
31 |
+
4,
|
32 |
+
5,
|
33 |
+
6,
|
34 |
+
7,
|
35 |
+
9,
|
36 |
+
10,
|
37 |
+
11,
|
38 |
+
12,
|
39 |
+
13,
|
40 |
+
14,
|
41 |
+
15,
|
42 |
+
17,
|
43 |
+
18,
|
44 |
+
19,
|
45 |
+
20,
|
46 |
+
21,
|
47 |
+
22,
|
48 |
+
23,
|
49 |
+
25,
|
50 |
+
26,
|
51 |
+
27,
|
52 |
+
28,
|
53 |
+
29,
|
54 |
+
30,
|
55 |
+
31
|
56 |
+
],
|
57 |
+
"inference_mode": false,
|
58 |
+
"inner_mlp_size": 10928,
|
59 |
+
"log_intermediate_values": false,
|
60 |
+
"make_vocab_size_divisible_by": 8,
|
61 |
+
"max_seqlen": 131072,
|
62 |
+
"mha_out_proj_bias": true,
|
63 |
+
"mlp_activation": "gelu",
|
64 |
+
"model_parallel_size": 1,
|
65 |
+
"model_type": "stripedhyena",
|
66 |
+
"num_attention_heads": 32,
|
67 |
+
"num_filters": 4096,
|
68 |
+
"num_layers": 32,
|
69 |
+
"pipe_parallel_size": 1,
|
70 |
+
"prefill_style": "fft",
|
71 |
+
"proj_groups": 1,
|
72 |
+
"qkv_proj_bias": true,
|
73 |
+
"rotary_emb_base": 10000,
|
74 |
+
"rotary_emb_scaling_factor": 16,
|
75 |
+
"short_filter_bias": true,
|
76 |
+
"short_filter_length": 3,
|
77 |
+
"smeared_gqa": false,
|
78 |
+
"split_k0": true,
|
79 |
+
"state_size": 8,
|
80 |
+
"tie_embeddings": true,
|
81 |
+
"torch_dtype": "bfloat16",
|
82 |
+
"transformers_version": null,
|
83 |
+
"use_cache": true,
|
84 |
+
"use_flash_attention_2": true,
|
85 |
+
"use_flash_depthwise": false,
|
86 |
+
"use_flash_rmsnorm": false,
|
87 |
+
"use_flashfft": false,
|
88 |
+
"use_interpolated_rotary_pos_emb": true,
|
89 |
+
"vocab_size": 512
|
90 |
+
}
|
configuration_hyena.py
ADDED
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import PretrainedConfig
|
2 |
+
import json
|
3 |
+
|
4 |
+
|
5 |
+
class StripedHyenaConfig(PretrainedConfig):
|
6 |
+
model_type = "stripedhyena"
|
7 |
+
|
8 |
+
def __init__(
|
9 |
+
self,
|
10 |
+
vocab_size=32000,
|
11 |
+
hidden_size=4096,
|
12 |
+
num_filters=4096,
|
13 |
+
inner_mlp_size=14336,
|
14 |
+
attn_layer_idxs=[],
|
15 |
+
hyena_layer_idxs=[],
|
16 |
+
num_layers=32,
|
17 |
+
tie_embeddings=False,
|
18 |
+
short_filter_length=3,
|
19 |
+
num_attention_heads=32,
|
20 |
+
proj_groups=4,
|
21 |
+
hyena_filter_groups=1,
|
22 |
+
split_k0=True,
|
23 |
+
column_split_hyena=True,
|
24 |
+
column_split=False,
|
25 |
+
model_parallel_size=1,
|
26 |
+
pipe_parallel_size=1,
|
27 |
+
short_filter_bias=True,
|
28 |
+
mha_out_proj_bias=False,
|
29 |
+
qkv_proj_bias=False,
|
30 |
+
final_norm=True,
|
31 |
+
use_cache=True,
|
32 |
+
use_flash_attention_2=True,
|
33 |
+
use_flash_rmsnorm=True,
|
34 |
+
use_flash_depthwise=False,
|
35 |
+
use_flashfft=False,
|
36 |
+
inference_mode=False,
|
37 |
+
prefill_style="fft",
|
38 |
+
max_seqlen=32768,
|
39 |
+
eps=1e-5,
|
40 |
+
state_size=2,
|
41 |
+
rotary_emb_base=500000,
|
42 |
+
smeared_gqa=False,
|
43 |
+
make_vocab_size_divisible_by=8,
|
44 |
+
log_intermediate_values=False,
|
45 |
+
**kwargs,
|
46 |
+
):
|
47 |
+
self.vocab_size = vocab_size
|
48 |
+
self.hidden_size = hidden_size
|
49 |
+
self.num_filters = num_filters
|
50 |
+
self.inner_mlp_size = inner_mlp_size
|
51 |
+
self.attn_layer_idxs = attn_layer_idxs
|
52 |
+
self.hyena_layer_idxs = hyena_layer_idxs
|
53 |
+
self.num_layers = num_layers
|
54 |
+
self.tie_embeddings = tie_embeddings
|
55 |
+
self.short_filter_length = short_filter_length
|
56 |
+
self.num_attention_heads = num_attention_heads
|
57 |
+
self.proj_groups = proj_groups
|
58 |
+
self.hyena_filter_groups = hyena_filter_groups
|
59 |
+
self.split_k0 = split_k0
|
60 |
+
self.column_split_hyena = column_split_hyena
|
61 |
+
self.column_split = column_split
|
62 |
+
self.model_parallel_size = model_parallel_size
|
63 |
+
self.pipe_parallel_size = pipe_parallel_size
|
64 |
+
self.short_filter_bias = short_filter_bias
|
65 |
+
self.mha_out_proj_bias = mha_out_proj_bias
|
66 |
+
self.qkv_proj_bias = qkv_proj_bias
|
67 |
+
self.final_norm = final_norm
|
68 |
+
self.use_cache = use_cache
|
69 |
+
self.use_flash_attention_2 = use_flash_attention_2
|
70 |
+
self.use_flash_rmsnorm = use_flash_rmsnorm
|
71 |
+
self.use_flash_depthwise = use_flash_depthwise
|
72 |
+
self.use_flashfft = use_flashfft
|
73 |
+
self.inference_mode = inference_mode
|
74 |
+
self.prefill_style = prefill_style
|
75 |
+
self.max_seqlen = max_seqlen
|
76 |
+
self.eps = eps
|
77 |
+
self.state_size = state_size
|
78 |
+
self.rotary_emb_base = rotary_emb_base
|
79 |
+
self.smeared_gqa = smeared_gqa
|
80 |
+
self.make_vocab_size_divisible_by = make_vocab_size_divisible_by
|
81 |
+
self.log_intermediate_values = log_intermediate_values
|
82 |
+
super().__init__(**kwargs)
|
83 |
+
|
84 |
+
def to_dict(self):
|
85 |
+
return {attr: getattr(self, attr) for attr in self.__dict__}
|
86 |
+
|
87 |
+
@classmethod
|
88 |
+
def from_original_config(cls, config_path, **kwargs):
|
89 |
+
with open(config_path, "r") as f:
|
90 |
+
config = json.load(f)
|
91 |
+
|
92 |
+
return cls(**config, **kwargs)
|
engine.py
ADDED
@@ -0,0 +1,389 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Together
|
2 |
+
# This software is distributed under the terms of the Apache License, Version 2.0
|
3 |
+
# Author: Michael Poli
|
4 |
+
|
5 |
+
import gc
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
import torch.nn.functional as F
|
10 |
+
|
11 |
+
try:
|
12 |
+
import conv1d_cpp
|
13 |
+
except:
|
14 |
+
pass
|
15 |
+
from .utils import column_split
|
16 |
+
|
17 |
+
IIR_PREFILL_MODES = [
|
18 |
+
"recurrence",
|
19 |
+
"modal-fft",
|
20 |
+
"hybrid-modal-recurrence",
|
21 |
+
"modal-scan",
|
22 |
+
"canonical-fft",
|
23 |
+
"iir-fir-caching",
|
24 |
+
]
|
25 |
+
|
26 |
+
|
27 |
+
def canonicalize_modal_system(poles, residues):
|
28 |
+
"""Canonicalize a modal system.
|
29 |
+
|
30 |
+
Args:
|
31 |
+
poles (Tensor): The poles of the system.
|
32 |
+
residues (Tensor): The residues of the system.
|
33 |
+
|
34 |
+
Returns:
|
35 |
+
Tuple[Tensor, Tensor]: The canonicalized poles and residues.
|
36 |
+
"""
|
37 |
+
raise NotImplementedError
|
38 |
+
|
39 |
+
|
40 |
+
def list_tensors(idx):
|
41 |
+
for obj in gc.get_objects():
|
42 |
+
try:
|
43 |
+
if torch.is_tensor(obj) and isinstance(obj, torch.Tensor):
|
44 |
+
# dump to log
|
45 |
+
print(type(obj), obj.size())
|
46 |
+
el = obj[0]
|
47 |
+
with open(f"tensors_{idx}.txt", "a") as f:
|
48 |
+
f.write(f"{type(obj)} {obj.size()} {el}\n")
|
49 |
+
except Exception as e:
|
50 |
+
pass
|
51 |
+
|
52 |
+
|
53 |
+
class HyenaInferenceEngine:
|
54 |
+
def __init__(
|
55 |
+
self,
|
56 |
+
fir_fn=None,
|
57 |
+
iir_prefill_style="modal-fft",
|
58 |
+
layer_idx=None,
|
59 |
+
) -> None:
|
60 |
+
self.fir_fn = fir_fn
|
61 |
+
assert iir_prefill_style in IIR_PREFILL_MODES, f"iir_prefill_style must be one of {IIR_PREFILL_MODES}"
|
62 |
+
self.iir_prefill_style = iir_prefill_style
|
63 |
+
self.layer_idx = layer_idx
|
64 |
+
self.low_mem_mode = False
|
65 |
+
|
66 |
+
def parallel_fir(
|
67 |
+
self,
|
68 |
+
fir_fn,
|
69 |
+
u,
|
70 |
+
weight,
|
71 |
+
bias,
|
72 |
+
L,
|
73 |
+
fir_length=3,
|
74 |
+
inference_params=None,
|
75 |
+
prefill_mode=None,
|
76 |
+
padding_mask=None,
|
77 |
+
):
|
78 |
+
"""Compute the output state of the long convolutional filter."""
|
79 |
+
# prepare input layout, dimensions and dispatch to fir kernel
|
80 |
+
if fir_fn != torch.nn.functional.conv1d:
|
81 |
+
z_pre = fir_fn(u)[:, :L] # B, L, D
|
82 |
+
z_pre = z_pre.permute(0, 2, 1)
|
83 |
+
else:
|
84 |
+
u = u.permute(0, 2, 1) # B, D, L
|
85 |
+
z_pre = fir_fn(
|
86 |
+
u,
|
87 |
+
weight,
|
88 |
+
bias=None, # don't pass it here, add manually instead! source of small error
|
89 |
+
stride=1,
|
90 |
+
padding=fir_length - 1,
|
91 |
+
groups=u.shape[1],
|
92 |
+
)[..., :L]
|
93 |
+
|
94 |
+
# add manually instead! source of small error
|
95 |
+
z_pre = z_pre + bias[None, :, None]
|
96 |
+
|
97 |
+
# handle padding post fir, the only place with biases
|
98 |
+
if type(padding_mask) == torch.Tensor:
|
99 |
+
z_pre = z_pre * padding_mask[:, None]
|
100 |
+
|
101 |
+
if inference_params is not None:
|
102 |
+
# handle seqlen last and dim last cases for `u`
|
103 |
+
if fir_fn != torch.nn.functional.conv1d:
|
104 |
+
fir_state = u[:, -fir_length + 1 :].permute(0, 2, 1)
|
105 |
+
else:
|
106 |
+
fir_state = u[..., -fir_length + 1 :]
|
107 |
+
else:
|
108 |
+
fir_state = None
|
109 |
+
|
110 |
+
return z_pre, fir_state
|
111 |
+
|
112 |
+
def parallel_iir(
|
113 |
+
self,
|
114 |
+
z_pre,
|
115 |
+
h,
|
116 |
+
D,
|
117 |
+
L,
|
118 |
+
poles,
|
119 |
+
residues,
|
120 |
+
t,
|
121 |
+
dims,
|
122 |
+
layer_idx,
|
123 |
+
inference_params=None,
|
124 |
+
prefill_style="fft",
|
125 |
+
fftconv_fn=None,
|
126 |
+
padding_mask=None,
|
127 |
+
use_flashfft=False,
|
128 |
+
column_split_hyena=False,
|
129 |
+
long_fir_threshold=None,
|
130 |
+
):
|
131 |
+
"""Compute the output state of the short convolutional filter."""
|
132 |
+
fft_size = 2 * L
|
133 |
+
hidden_size, num_attention_heads, hidden_size_per_attention_head, _, _ = dims
|
134 |
+
# Compatibility with training infra that column splits the projections
|
135 |
+
if column_split_hyena:
|
136 |
+
z = z_pre.reshape(
|
137 |
+
z_pre.shape[0],
|
138 |
+
num_attention_heads,
|
139 |
+
3 * hidden_size_per_attention_head,
|
140 |
+
z_pre.shape[2],
|
141 |
+
)
|
142 |
+
x2, x1, v = (
|
143 |
+
z[:, :, :hidden_size_per_attention_head],
|
144 |
+
z[
|
145 |
+
:,
|
146 |
+
:,
|
147 |
+
hidden_size_per_attention_head : 2 * hidden_size_per_attention_head,
|
148 |
+
],
|
149 |
+
z[:, :, 2 * hidden_size_per_attention_head :],
|
150 |
+
)
|
151 |
+
x2, x1, v = (
|
152 |
+
x2.reshape(x2.shape[0], -1, x2.shape[-1]),
|
153 |
+
x1.reshape(x1.shape[0], -1, x1.shape[-1]),
|
154 |
+
v.reshape(v.shape[0], -1, v.shape[-1]),
|
155 |
+
)
|
156 |
+
else:
|
157 |
+
x2, x1, v = z_pre.split([hidden_size, hidden_size, hidden_size], dim=1)
|
158 |
+
|
159 |
+
x1v = x1 * v
|
160 |
+
|
161 |
+
if inference_params is not None and prefill_style == "recurrence":
|
162 |
+
y = self.prefill_via_direct_recurrence(
|
163 |
+
inference_params=inference_params,
|
164 |
+
x1v=x1v,
|
165 |
+
L=L,
|
166 |
+
poles=poles,
|
167 |
+
residues=residues,
|
168 |
+
)
|
169 |
+
|
170 |
+
else:
|
171 |
+
if use_flashfft and (L % 2) == 0: # only works with even L
|
172 |
+
y = fftconv_fn(
|
173 |
+
x1v.to(dtype=torch.bfloat16).contiguous(),
|
174 |
+
h.to(dtype=torch.float32),
|
175 |
+
)
|
176 |
+
X_s = None
|
177 |
+
|
178 |
+
elif long_fir_threshold is None:
|
179 |
+
H = torch.fft.rfft(h.to(dtype=torch.float32), n=fft_size) / fft_size
|
180 |
+
X_s = torch.fft.fft(x1v.to(dtype=torch.float32), n=fft_size)
|
181 |
+
X = X_s[..., : H.shape[-1]]
|
182 |
+
if len(z_pre.shape) > 3:
|
183 |
+
H = H.unsqueeze(1)
|
184 |
+
y = torch.fft.irfft(X * H, n=fft_size, norm="forward")[..., :L]
|
185 |
+
|
186 |
+
else:
|
187 |
+
assert h.shape[0] == 1, "batch size must be 1 for long_fir_threshold"
|
188 |
+
h = h[0][:, None] # rearrange to d, 1, l for depthwise conv1d
|
189 |
+
h = h[..., :long_fir_threshold]
|
190 |
+
y = F.conv1d(
|
191 |
+
x1v,
|
192 |
+
h.to(dtype=x1v.dtype),
|
193 |
+
stride=1,
|
194 |
+
groups=x1v.shape[1],
|
195 |
+
padding=h.shape[-1] - 1,
|
196 |
+
)[..., :L]
|
197 |
+
|
198 |
+
y = y.to(dtype=x1v.dtype)
|
199 |
+
y = (y + x1v * D.unsqueeze(-1)) * x2
|
200 |
+
|
201 |
+
if inference_params is not None:
|
202 |
+
if prefill_style == "fft":
|
203 |
+
self.prefill_via_modal_fft(
|
204 |
+
inference_params=inference_params,
|
205 |
+
x1v=x1v,
|
206 |
+
X_s=X_s,
|
207 |
+
L=L,
|
208 |
+
t=t,
|
209 |
+
poles=poles,
|
210 |
+
dims=dims,
|
211 |
+
layer_idx=layer_idx,
|
212 |
+
use_flashfft=use_flashfft,
|
213 |
+
fftconv_fn=fftconv_fn,
|
214 |
+
)
|
215 |
+
|
216 |
+
elif prefill_style == "recurrence":
|
217 |
+
# recurrent prefill is done before
|
218 |
+
pass
|
219 |
+
else:
|
220 |
+
raise NotImplementedError
|
221 |
+
if self.low_mem_mode:
|
222 |
+
# TODO: smarter gc
|
223 |
+
del z_pre, x2, x1, v, x1v, h, poles, residues
|
224 |
+
torch.cuda.empty_cache()
|
225 |
+
|
226 |
+
return y.permute(0, 2, 1)
|
227 |
+
|
228 |
+
def step_fir(self, u, fir_state, weight, bias=None):
|
229 |
+
"""Step the FIR filter.
|
230 |
+
|
231 |
+
Note:
|
232 |
+
`fir_state` contains the last `short_filter_length - 1` elements of `u`: `u_(L-2), u_{L-1), ...`
|
233 |
+
We assume dimensions of `short_filter_weight` to be `[d, 1, short_filter_len]` (SISO / multi SISO layout).
|
234 |
+
"""
|
235 |
+
h0, h = weight[..., 0, -1], weight[..., 0, :-1]
|
236 |
+
h0, h = h0[None], h[None]
|
237 |
+
y = h0 * u + torch.sum(fir_state * h, dim=-1) + bias
|
238 |
+
|
239 |
+
# update
|
240 |
+
fir_state = torch.roll(fir_state, -1, dims=2)
|
241 |
+
fir_state[..., -1] = u
|
242 |
+
return y, fir_state
|
243 |
+
|
244 |
+
def step_iir(self, x2, x1, v, D, residues, poles, iir_state, iir_groups=1):
|
245 |
+
x1v = x1 * v
|
246 |
+
|
247 |
+
residues, poles = (
|
248 |
+
torch.view_as_complex(residues.to(torch.float32)),
|
249 |
+
torch.view_as_complex(poles.to(torch.float32)),
|
250 |
+
)
|
251 |
+
# squeeze the dummy seqlen dimension
|
252 |
+
# D, state_dim, 1 -> 1, D, state_dim
|
253 |
+
residues, poles = residues[..., 0][None], poles[..., 0][None]
|
254 |
+
iir_state = poles * iir_state + x1v[..., None]
|
255 |
+
|
256 |
+
res_state = torch.sum(residues * iir_state, dim=-1).real
|
257 |
+
|
258 |
+
if iir_groups > 1:
|
259 |
+
raise NotImplementedError
|
260 |
+
y = x2 * (res_state + D * x1v)
|
261 |
+
|
262 |
+
return y, iir_state
|
263 |
+
|
264 |
+
def prefill_via_fir_caching(self, u, inference_params, L, *args, **kwargs):
|
265 |
+
"""Turns the IIR filter into a FIR and uses a cache for decoding."""
|
266 |
+
raise NotImplementedError(":)")
|
267 |
+
|
268 |
+
def prefill_via_direct_recurrence(
|
269 |
+
self, inference_params, x1v, L, residues, poles, *args, **kwargs
|
270 |
+
) -> torch.Tensor:
|
271 |
+
"""
|
272 |
+
Compute the IIR state via explicit SSM recurrence (modal form)
|
273 |
+
|
274 |
+
This is the most memory efficient prefilling method for Hyena filters.
|
275 |
+
|
276 |
+
Note:
|
277 |
+
dtypes: [state: float32, poles: float32, x1v: bfloat16, output: bfloat16]
|
278 |
+
"""
|
279 |
+
state_dim = poles.shape[1]
|
280 |
+
x1v_ = x1v[..., None, None] # b, d, l, sdim, reim
|
281 |
+
x1v_ = x1v_.repeat(1, 1, 1, state_dim, 2) # b, d, l, sdim, reim
|
282 |
+
x1v_[..., 1] = 0
|
283 |
+
|
284 |
+
state = 0 * x1v_[:, :, 0]
|
285 |
+
output = 0 * x1v_[:, :, :, 0, 0] # b, d, l
|
286 |
+
|
287 |
+
# suppress dummy seqlen dimension
|
288 |
+
poles = poles[:, :, 0][None]
|
289 |
+
residues = residues[:, :, 0][None].repeat(x1v_.shape[0], 1, 1, 1) # b, d, sdim, reim
|
290 |
+
|
291 |
+
# state: b, d, sdim, reim
|
292 |
+
# poles: 1, d, sdim, reim
|
293 |
+
# x1v_: b, d, l, sdim, reim
|
294 |
+
for i in range(L):
|
295 |
+
state[..., 0] = poles[..., 0] * state[..., 0] - poles[..., 1] * state[..., 1] + x1v_[:, :, i, :, 0]
|
296 |
+
state[..., 1] = poles[..., 0] * state[..., 1] + poles[..., 1] * state[..., 0] + x1v_[:, :, i, :, 1]
|
297 |
+
output[:, :, i] = torch.sum(residues * state, dim=-2)[..., 0] # .real
|
298 |
+
|
299 |
+
inference_params.state_dict[self.layer_idx] = torch.view_as_complex(state.to(dtype=torch.float32))
|
300 |
+
|
301 |
+
return output
|
302 |
+
|
303 |
+
def prefill_via_hybrid_recurrence(self, inference_params, u, log_poles, x1v_f_a, L, *args, **kwargs):
|
304 |
+
"""
|
305 |
+
Compute the IIR state via hybrid recurrence-convolution over blocks
|
306 |
+
"""
|
307 |
+
raise NotImplementedError(":)")
|
308 |
+
|
309 |
+
def prefill_via_scan(self, u, inference_params=None, *args, **kwargs):
|
310 |
+
raise NotImplementedError
|
311 |
+
|
312 |
+
def prefill_via_canonical_fft(self, u, inference_params=None, *args, **kwargs):
|
313 |
+
"""
|
314 |
+
Compute the IIR state via a single FFT with the denominator of the SSM in companion form.
|
315 |
+
|
316 |
+
This is the most memory efficient "parallelized" prefilling method for Hyena.
|
317 |
+
|
318 |
+
From: https://arxiv.org/abs/2310.18780
|
319 |
+
"""
|
320 |
+
raise NotImplementedError(":)")
|
321 |
+
|
322 |
+
def prefill_via_modal_fft(
|
323 |
+
self,
|
324 |
+
inference_params,
|
325 |
+
x1v,
|
326 |
+
L,
|
327 |
+
poles,
|
328 |
+
t,
|
329 |
+
dims,
|
330 |
+
layer_idx,
|
331 |
+
X_s=None,
|
332 |
+
use_flashfft=False,
|
333 |
+
fftconv_fn=None,
|
334 |
+
state_dtype=torch.complex64,
|
335 |
+
*args,
|
336 |
+
**kwargs,
|
337 |
+
):
|
338 |
+
"""
|
339 |
+
Compute the IIR state via a single FFT, using the poles of the SSM in modal form.
|
340 |
+
"""
|
341 |
+
# When the model has a long convolution derived from a SSM in modal form and prefill_style is "fft",
|
342 |
+
# we split the filter into poles and residues and reuse FFT computation on the input.
|
343 |
+
# This optimization is currently not supported when using flashfftconv.
|
344 |
+
hidden_size, _, _, state_size, hyena_filter_groups = dims
|
345 |
+
|
346 |
+
if use_flashfft:
|
347 |
+
# using real states
|
348 |
+
poles = poles.squeeze().reshape(poles.shape[0], -1)[..., None]
|
349 |
+
|
350 |
+
state_s = poles**t
|
351 |
+
if hyena_filter_groups > 1:
|
352 |
+
raise NotImplementedError
|
353 |
+
|
354 |
+
x1v = x1v[:, :, None].repeat(1, 1, 2 * state_size, 1)
|
355 |
+
x1v = x1v.reshape(x1v.shape[0], -1, x1v.shape[-1])
|
356 |
+
state_s = state_s[None]
|
357 |
+
|
358 |
+
state = fftconv_fn(
|
359 |
+
x1v.contiguous(),
|
360 |
+
state_s.to(dtype=torch.float32),
|
361 |
+
)
|
362 |
+
state = state[..., L - 1].reshape(x1v.shape[0], hidden_size, state_size, 2)
|
363 |
+
state = torch.view_as_complex(state.contiguous().to(dtype=torch.float32))
|
364 |
+
inference_params.state_dict[self.layer_idx] = state
|
365 |
+
else:
|
366 |
+
assert X_s is not None
|
367 |
+
bs = x1v.shape[0]
|
368 |
+
fft_size = 2 * L
|
369 |
+
poles = torch.view_as_complex(poles.to(torch.float32))
|
370 |
+
state_s = poles**t
|
371 |
+
state_S = torch.fft.fft(state_s, n=fft_size).repeat(bs, 1, 1, 1) # B, D, state_dim, 2 * L
|
372 |
+
if hyena_filter_groups > 1:
|
373 |
+
state_S = state_S.repeat_interleave(hidden_size // hyena_filter_groups, 1)
|
374 |
+
state = torch.fft.ifft(X_s[..., None, :] * state_S, n=fft_size)
|
375 |
+
inference_params.state_dict[layer_idx] = state[..., L - 1].to(dtype=state_dtype)
|
376 |
+
|
377 |
+
def _compute_state(self, log_poles, u, t, L, *args, **kwargs):
|
378 |
+
"""
|
379 |
+
Compute the IIR state given an input `u` and log_poles of the modal system.
|
380 |
+
"""
|
381 |
+
bs = u.shape[0]
|
382 |
+
fft_size = 2 * L
|
383 |
+
U = torch.fft.rfft(u.to(torch.float32), n=fft_size)
|
384 |
+
fft_size = 2 * L
|
385 |
+
x = (log_poles * t).exp()
|
386 |
+
# [batch, hidden_size, state_dim, 2 * seqlen]
|
387 |
+
X = torch.fft.fft(x, n=fft_size).repeat(bs, 1, 1, 1)
|
388 |
+
state = torch.fft.ifft(U[..., None, :] * X, n=fft_size)[..., :L]
|
389 |
+
return state
|
generation_config.json
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_from_model_config": true,
|
3 |
+
"transformers_version": "4.36.2"
|
4 |
+
}
|
layers.py
ADDED
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Together
|
2 |
+
# This software is distributed under the terms of the Apache License, Version 2.0
|
3 |
+
# Author: Michael Poli
|
4 |
+
|
5 |
+
import torch
|
6 |
+
from torch import Tensor
|
7 |
+
import torch.nn.functional as F
|
8 |
+
import torch.nn as nn
|
9 |
+
from .utils import grab_first_if_tuple
|
10 |
+
|
11 |
+
def grab_first_if_tuple(x):
|
12 |
+
if x.__class__.__name__ == "tuple":
|
13 |
+
return x[0]
|
14 |
+
else:
|
15 |
+
return x
|
16 |
+
|
17 |
+
class RMSNorm(torch.nn.Module):
|
18 |
+
def __init__(self, config):
|
19 |
+
super(RMSNorm, self).__init__()
|
20 |
+
self.eps, self.hidden_size = config.eps, config.hidden_size
|
21 |
+
self.scale = torch.nn.Parameter(torch.ones(self.hidden_size))
|
22 |
+
self.register_parameter("scale", self.scale)
|
23 |
+
self.use_flash_rmsnorm = config.get("use_flash_rmsnorm", False)
|
24 |
+
|
25 |
+
if self.use_flash_rmsnorm:
|
26 |
+
try:
|
27 |
+
from flash_attn.ops.rms_norm import rms_norm as rmsnorm_func
|
28 |
+
|
29 |
+
self.rmsnorm_func = rmsnorm_func
|
30 |
+
except:
|
31 |
+
raise ImportError(
|
32 |
+
"For `use_flash_rmsnorm`: `pip install git+https://github.com/HazyResearch/flash-attention.git#subdirectory=csrc/layer_norm`"
|
33 |
+
)
|
34 |
+
|
35 |
+
def forward(self, x):
|
36 |
+
if self.use_flash_rmsnorm:
|
37 |
+
return self.rmsnorm_func(x, self.scale, self.eps)
|
38 |
+
else:
|
39 |
+
y = x / (x.norm(2, dim=-1, keepdim=True) * self.hidden_size ** (-1.0 / 2) + self.eps)
|
40 |
+
return self.scale * y
|
41 |
+
|
42 |
+
|
43 |
+
class ParallelGatedMLP(nn.Module):
|
44 |
+
def __init__(
|
45 |
+
self,
|
46 |
+
config,
|
47 |
+
):
|
48 |
+
super().__init__()
|
49 |
+
|
50 |
+
multiple_of = config.get("inner_size_multiple_of", 64)
|
51 |
+
self.act_type = config.get("mlp_activation", "silu")
|
52 |
+
if self.act_type == "gelu":
|
53 |
+
self.act = F.gelu
|
54 |
+
elif self.act_type == "silu":
|
55 |
+
self.act = F.silu
|
56 |
+
else:
|
57 |
+
raise NotImplementedError
|
58 |
+
|
59 |
+
self.multiple_of = multiple_of * config.model_parallel_size
|
60 |
+
|
61 |
+
inner_size = int(2 * config.hidden_size * 4 / 3)
|
62 |
+
inner_size = self.multiple_of * ((inner_size + self.multiple_of - 1) // self.multiple_of)
|
63 |
+
if config.get("inner_mlp_size", None) is not None:
|
64 |
+
inner_size = config.inner_mlp_size
|
65 |
+
|
66 |
+
self.l1 = nn.Linear(
|
67 |
+
in_features=config.hidden_size,
|
68 |
+
out_features=inner_size,
|
69 |
+
bias=False,
|
70 |
+
)
|
71 |
+
self.l2 = nn.Linear(
|
72 |
+
in_features=config.hidden_size,
|
73 |
+
out_features=inner_size,
|
74 |
+
bias=False,
|
75 |
+
)
|
76 |
+
self.l3 = nn.Linear(
|
77 |
+
in_features=inner_size,
|
78 |
+
out_features=config.hidden_size,
|
79 |
+
bias=False,
|
80 |
+
)
|
81 |
+
|
82 |
+
def forward(self, z):
|
83 |
+
z1, z2 = self.l1(z), self.l2(z)
|
84 |
+
z1, z2 = grab_first_if_tuple(z1), grab_first_if_tuple(z2)
|
85 |
+
y = self.l3(self.act(z1) * z2)
|
86 |
+
return grab_first_if_tuple(y)
|
87 |
+
|
88 |
+
|
89 |
+
class Embedding(nn.Module):
|
90 |
+
_train_dtype = "bf16"
|
91 |
+
|
92 |
+
def __init__(self, config):
|
93 |
+
super().__init__()
|
94 |
+
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=0)
|
95 |
+
|
96 |
+
def embed(self, input_ids, position_ids=None, tokentype_ids=None):
|
97 |
+
embeddings = self.word_embeddings(input_ids)
|
98 |
+
return embeddings
|
99 |
+
|
100 |
+
def unembed(self, u):
|
101 |
+
weight = self.word_embeddings.weight
|
102 |
+
return torch.matmul(u, weight)
|
103 |
+
|
104 |
+
|
105 |
+
class VocabParallelEmbedding(nn.Embedding):
|
106 |
+
"Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/embedding.py"
|
107 |
+
|
108 |
+
def __init__(self, config):
|
109 |
+
vocab_size, process_group, padding_idx = (
|
110 |
+
config.vocab_size,
|
111 |
+
config.get("process_group", None),
|
112 |
+
config.get("padding_idx", None),
|
113 |
+
)
|
114 |
+
self.process_group = process_group
|
115 |
+
if process_group is not None:
|
116 |
+
world_size = torch.distributed.get_world_size(process_group)
|
117 |
+
if vocab_size % world_size != 0:
|
118 |
+
raise ValueError(
|
119 |
+
f"vocab_size ({vocab_size}) must be divisible by " f"world_size ({world_size})"
|
120 |
+
)
|
121 |
+
if world_size > 1 and padding_idx is not None:
|
122 |
+
raise RuntimeError("ParallelEmbedding does not support padding_idx")
|
123 |
+
else:
|
124 |
+
world_size = 1
|
125 |
+
super().__init__(
|
126 |
+
vocab_size // world_size,
|
127 |
+
embedding_dim=config.hidden_size,
|
128 |
+
padding_idx=padding_idx,
|
129 |
+
)
|
130 |
+
|
131 |
+
def embed(self, x: Tensor) -> Tensor:
|
132 |
+
if self.process_group is None:
|
133 |
+
return self.forward(x)
|
134 |
+
else:
|
135 |
+
rank = torch.distributed.get_rank(self.process_group)
|
136 |
+
vocab_size = self.num_embeddings
|
137 |
+
vocab_start_index, vocab_end_index = (
|
138 |
+
rank * vocab_size,
|
139 |
+
(rank + 1) * vocab_size,
|
140 |
+
)
|
141 |
+
# Create a mask of valid vocab ids (1 means it needs to be masked).
|
142 |
+
input_ids_mask = (x < vocab_start_index) | (x >= vocab_end_index)
|
143 |
+
x = x - vocab_start_index
|
144 |
+
x[input_ids_mask] = 0
|
145 |
+
embeddings = self.forward(x)
|
146 |
+
embeddings[input_ids_mask] = 0.0
|
147 |
+
# Reduce to the global process group
|
148 |
+
torch.distributed.all_reduce(embeddings, group=self.process_group)
|
149 |
+
return embeddings
|
150 |
+
|
151 |
+
def unembed(self, u: Tensor) -> Tensor:
|
152 |
+
if self.process_group is None:
|
153 |
+
return u @ self.weight.T
|
154 |
+
else:
|
155 |
+
raise NotImplementedError
|
model-00001-of-00003.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:fc5f3b6258c1a7e513cc9e41a326d8d5e0f32d112408273be54ba69b522b50de
|
3 |
+
size 4980059464
|
model-00002-of-00003.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:e6c6760a34950595555656d00f19cb5b1620e5a47cc3a1a0c56a3a1f057ebfa1
|
3 |
+
size 4929849248
|
model-00003-of-00003.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:fdeef9c8c68ed48bc97e6ddab502fa95d8327dfe917a33c4db079e0fc29a7267
|
3 |
+
size 3003304856
|
model.py
ADDED
@@ -0,0 +1,474 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Together
|
2 |
+
# This software is distributed under the terms of the Apache License, Version 2.0
|
3 |
+
# Author: Michael Poli
|
4 |
+
# Note: MP and PP utilities are removed for ease of use and editing.
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
import torch.nn.functional as F
|
9 |
+
|
10 |
+
from .cache import InferenceParams, RecurrentInferenceParams
|
11 |
+
from .engine import HyenaInferenceEngine
|
12 |
+
from .layers import ParallelGatedMLP, RMSNorm, VocabParallelEmbedding
|
13 |
+
from .utils import column_split, print_rank_0
|
14 |
+
|
15 |
+
try:
|
16 |
+
from flash_attn.modules.mha import MHA
|
17 |
+
except ImportError:
|
18 |
+
"flash_attn not installed"
|
19 |
+
|
20 |
+
try:
|
21 |
+
from .positional_embeddings import swap_mha_rope
|
22 |
+
except ImportError:
|
23 |
+
"could not import swap_mha_rope from positional_embeddings.py"
|
24 |
+
|
25 |
+
# dummy import to force huggingface to bundle the tokenizer
|
26 |
+
from .tokenizer import ByteTokenizer
|
27 |
+
|
28 |
+
|
29 |
+
class AttentionBlock(nn.Module):
|
30 |
+
def __init__(self, config, layer_idx) -> None:
|
31 |
+
super().__init__()
|
32 |
+
self.config = config
|
33 |
+
self.pre_norm, self.post_norm = RMSNorm(config), RMSNorm(config)
|
34 |
+
self.layer_idx = layer_idx
|
35 |
+
self.proj_groups = config.get("proj_groups", 1)
|
36 |
+
dtype = config.get("attn_block_dtype", torch.bfloat16)
|
37 |
+
mlp_dtype = config.get("mlp_dtype", torch.bfloat16)
|
38 |
+
self.num_attention_heads = config.num_attention_heads
|
39 |
+
self.hidden_size_per_attention_head = config.hidden_size // config.num_attention_heads
|
40 |
+
|
41 |
+
self.counter = 0
|
42 |
+
self.inner_mha_cls = MHA(
|
43 |
+
embed_dim=config.hidden_size,
|
44 |
+
num_heads=config.num_attention_heads,
|
45 |
+
num_heads_kv=config.num_attention_heads // self.proj_groups,
|
46 |
+
rotary_emb_dim=config.hidden_size // config.num_attention_heads,
|
47 |
+
qkv_proj_bias=config.get("qkv_proj_bias", True),
|
48 |
+
rotary_emb_base=config.get("rotary_emb_base", 10000),
|
49 |
+
causal=True,
|
50 |
+
layer_idx=layer_idx,
|
51 |
+
out_proj_bias=config.get("mha_out_proj_bias", True),
|
52 |
+
use_flash_attn=self.config.use_flash_attn,
|
53 |
+
).to(dtype=dtype)
|
54 |
+
|
55 |
+
# check if using interpolated rotary pos emb from config, and swap the rope emb
|
56 |
+
if config.get("use_interpolated_rotary_pos_emb", False):
|
57 |
+
swap_mha_rope(
|
58 |
+
mha=self.inner_mha_cls,
|
59 |
+
kwargs_new_rope={'scaling_factor': config.get("rotary_emb_scaling_factor", 1.)},
|
60 |
+
)
|
61 |
+
|
62 |
+
if self.config.get("smeared_gqa", False):
|
63 |
+
self.inner_mha_cls.num_heads_kv = self.inner_mha_cls.num_heads
|
64 |
+
self.inner_mha_cls.rotary_emb.register_buffer("inv_freq", self.inner_mha_cls.rotary_emb.inv_freq)
|
65 |
+
|
66 |
+
self.mlp = ParallelGatedMLP(config).to(dtype=mlp_dtype)
|
67 |
+
|
68 |
+
def forward(self, u, inference_params=None, padding_mask=None, *args, **kwargs):
|
69 |
+
if (
|
70 |
+
type(padding_mask) == torch.Tensor
|
71 |
+
): # workaround for masking bug in FA. This works because Wqkv does not have bias
|
72 |
+
# and attention scores will be also automatically zeroed.
|
73 |
+
u = u * padding_mask[..., None]
|
74 |
+
u = (
|
75 |
+
self.inner_mha_cls(
|
76 |
+
self.pre_norm(u),
|
77 |
+
inference_params=inference_params,
|
78 |
+
)
|
79 |
+
+ u
|
80 |
+
)
|
81 |
+
if type(padding_mask) == torch.Tensor: # guard against bias
|
82 |
+
u = u * padding_mask[..., None]
|
83 |
+
u = self.mlp(self.post_norm(u)) + u
|
84 |
+
return u, None
|
85 |
+
|
86 |
+
|
87 |
+
class ParallelHyenaFilter(nn.Module):
|
88 |
+
def __init__(self, config, layer_idx) -> None:
|
89 |
+
super().__init__()
|
90 |
+
self.config = config
|
91 |
+
self.layer_idx = layer_idx
|
92 |
+
self.hyena_filter_groups = config.get("hyena_filter_groups", self.config.hidden_size)
|
93 |
+
|
94 |
+
self.use_flashfft = config.get("use_flashfft", False)
|
95 |
+
self.state_size = config.state_size
|
96 |
+
self.hidden_size = config.hidden_size
|
97 |
+
self.num_filters = config.num_filters
|
98 |
+
self.inference_mode = config.get("inference_mode", True)
|
99 |
+
self.counter = 0
|
100 |
+
self.column_split_hyena = config.get("column_split_hyena", True)
|
101 |
+
|
102 |
+
assert self.hidden_size % self.num_filters == 0 and self.num_filters <= self.hidden_size
|
103 |
+
|
104 |
+
self.D = nn.Parameter(torch.zeros(self.hidden_size))
|
105 |
+
|
106 |
+
# attention heads are not used except to split post short_filter
|
107 |
+
# projections in the same way as the checkpoint
|
108 |
+
self.num_attention_heads = config.num_attention_heads
|
109 |
+
self.hidden_size_per_attention_head = self.hidden_size // self.num_attention_heads
|
110 |
+
|
111 |
+
# after preprocessing here we can save the new checkpoint
|
112 |
+
self.short_filter_length = config.short_filter_length
|
113 |
+
self.short_filter_weight = nn.Parameter(torch.randn(3 * config.hidden_size, 1, config.short_filter_length))
|
114 |
+
self.short_filter_bias = (
|
115 |
+
nn.Parameter(torch.randn(3 * config.hidden_size)) if config.short_filter_bias else None
|
116 |
+
)
|
117 |
+
|
118 |
+
self.engine = HyenaInferenceEngine(layer_idx=layer_idx)
|
119 |
+
self.use_flash_depthwise = config.get("use_flash_depthwise", False)
|
120 |
+
self.data_dtype = None
|
121 |
+
|
122 |
+
if self.use_flash_depthwise:
|
123 |
+
self.fir_fn = FlashDepthwiseConv1d(
|
124 |
+
channels=3 * self.hidden_size,
|
125 |
+
kernel_size=self.short_filter_length,
|
126 |
+
padding=self.short_filter_length - 1,
|
127 |
+
weights=self.short_filter_weight,
|
128 |
+
bias=self.short_filter_bias,
|
129 |
+
device=None,
|
130 |
+
dtype=self.config.get("depthwise_dtype", torch.bfloat16),
|
131 |
+
)
|
132 |
+
else:
|
133 |
+
self.fir_fn = F.conv1d
|
134 |
+
|
135 |
+
self.fftconv_fn = None
|
136 |
+
self.long_fir_threshold = config.get("long_fir_threshold", None)
|
137 |
+
if self.long_fir_threshold is not None:
|
138 |
+
assert self.use_flashfft is False, "long_fir_threshold not compatible with fused flashfft"
|
139 |
+
|
140 |
+
self.num_systems = self.hidden_size // self.hyena_filter_groups
|
141 |
+
|
142 |
+
poles = torch.randn(self.num_systems, self.state_size, 1, 2)
|
143 |
+
|
144 |
+
# TODO: bring over init from internals
|
145 |
+
poles[..., 0] = 1e-2 * torch.randn(self.num_systems, self.state_size, 1)
|
146 |
+
poles[..., 1] = 1e-3 * torch.randn(self.num_systems, self.state_size, 1)
|
147 |
+
|
148 |
+
self.poles = nn.Parameter(poles)
|
149 |
+
|
150 |
+
self.residues = nn.Parameter(torch.randn(self.num_systems, self.state_size, 1, 2))
|
151 |
+
self.h = None
|
152 |
+
|
153 |
+
def forward(self, u, inference_params=None, padding_mask=None, *args, **kwargs):
|
154 |
+
if inference_params is not None and self.layer_idx in inference_params.fir_state_dict.keys():
|
155 |
+
return self.sequential_forward(u, inference_params)
|
156 |
+
|
157 |
+
else:
|
158 |
+
return self.parallel_forward(u, inference_params, padding_mask)
|
159 |
+
|
160 |
+
def parallel_forward(self, u, inference_params=None, padding_mask=None):
|
161 |
+
L = u.shape[1]
|
162 |
+
z_pre, fir_state = self.engine.parallel_fir(
|
163 |
+
self.fir_fn,
|
164 |
+
u,
|
165 |
+
self.short_filter_weight,
|
166 |
+
self.short_filter_bias,
|
167 |
+
L,
|
168 |
+
fir_length=self.short_filter_length,
|
169 |
+
inference_params=inference_params,
|
170 |
+
padding_mask=padding_mask,
|
171 |
+
)
|
172 |
+
if inference_params:
|
173 |
+
inference_params.fir_state_dict[self.layer_idx] = fir_state
|
174 |
+
|
175 |
+
if self.h is None:
|
176 |
+
h, filter_dtype, poles, residues = self.compute_filter(L, u.device)
|
177 |
+
else:
|
178 |
+
h = self.h
|
179 |
+
filter_dtype = self.h.dtype
|
180 |
+
|
181 |
+
if self.hyena_filter_groups > 1:
|
182 |
+
h = h.repeat_interleave(self.hidden_size // self.hyena_filter_groups, 1)
|
183 |
+
|
184 |
+
# if inference_params is not None, we plan to perform generation:
|
185 |
+
# prefilling is handled by the engine.
|
186 |
+
dims = (
|
187 |
+
self.hidden_size,
|
188 |
+
self.num_attention_heads,
|
189 |
+
self.hidden_size_per_attention_head,
|
190 |
+
self.state_size,
|
191 |
+
self.hyena_filter_groups,
|
192 |
+
)
|
193 |
+
y = self.engine.parallel_iir(
|
194 |
+
z_pre,
|
195 |
+
h,
|
196 |
+
self.D,
|
197 |
+
L,
|
198 |
+
t=self.t,
|
199 |
+
poles=self.poles,
|
200 |
+
residues=self.residues,
|
201 |
+
dims=dims,
|
202 |
+
inference_params=inference_params,
|
203 |
+
layer_idx=self.layer_idx,
|
204 |
+
prefill_style=self.config.get("prefill_style", "fft"),
|
205 |
+
use_flashfft=self.use_flashfft,
|
206 |
+
fftconv_fn=self.fftconv_fn,
|
207 |
+
column_split_hyena=self.column_split_hyena,
|
208 |
+
long_fir_threshold=self.long_fir_threshold,
|
209 |
+
padding_mask=padding_mask,
|
210 |
+
)
|
211 |
+
|
212 |
+
return y, inference_params
|
213 |
+
|
214 |
+
def sequential_forward(self, u, inference_params):
|
215 |
+
if self.data_dtype is None:
|
216 |
+
self.data_dtype = u.dtype
|
217 |
+
if len(u.shape) > 2:
|
218 |
+
u = u[:, -1]
|
219 |
+
|
220 |
+
fir_state, iir_state = (
|
221 |
+
inference_params.fir_state_dict[self.layer_idx],
|
222 |
+
inference_params.state_dict[self.layer_idx],
|
223 |
+
)
|
224 |
+
|
225 |
+
z_pre, fir_state = self.engine.step_fir(
|
226 |
+
u, fir_state, weight=self.short_filter_weight, bias=self.short_filter_bias
|
227 |
+
)
|
228 |
+
x2, x1, v = (
|
229 |
+
column_split(z_pre, self.num_attention_heads, self.hidden_size_per_attention_head)
|
230 |
+
if self.column_split_hyena
|
231 |
+
else z_pre.split([self.hidden_size, self.hidden_size, self.hidden_size], dim=1)
|
232 |
+
)
|
233 |
+
|
234 |
+
y, iir_state = self.engine.step_iir(
|
235 |
+
x2,
|
236 |
+
x1,
|
237 |
+
v,
|
238 |
+
self.D,
|
239 |
+
self.residues,
|
240 |
+
self.poles,
|
241 |
+
iir_state,
|
242 |
+
iir_groups=self.hyena_filter_groups,
|
243 |
+
)
|
244 |
+
|
245 |
+
inference_params.fir_state_dict[self.layer_idx] = fir_state
|
246 |
+
inference_params.state_dict[self.layer_idx] = iir_state
|
247 |
+
y = y.to(dtype=self.data_dtype)
|
248 |
+
return y[:, None], inference_params
|
249 |
+
|
250 |
+
def update_time(self, L, device):
|
251 |
+
"""
|
252 |
+
Set [0, 1, ..., L-1] where L is the length of the current batch of inputs.
|
253 |
+
If L is greater than the length of the previous batch, then the time vector is
|
254 |
+
reinitialized. Otherwise, the time vector is truncated from cache.
|
255 |
+
"""
|
256 |
+
if not hasattr(self, "t"):
|
257 |
+
self.t = torch.arange(L, device=device)[None, None]
|
258 |
+
elif self.t.shape[-1] < L:
|
259 |
+
self.t = torch.arange(L, device=device)[None, None]
|
260 |
+
else:
|
261 |
+
self.t = self.t[..., :L]
|
262 |
+
|
263 |
+
def compute_filter(self, L, device):
|
264 |
+
self.update_time(L, device)
|
265 |
+
filter_dtype = torch.float32
|
266 |
+
residues, log_poles = (
|
267 |
+
torch.view_as_complex(self.residues.to(filter_dtype)),
|
268 |
+
torch.view_as_complex(self.poles.to(filter_dtype)).log(),
|
269 |
+
)
|
270 |
+
h = (residues * (log_poles * self.t).exp()).real.sum(1)[None]
|
271 |
+
return h, filter_dtype, log_poles, residues
|
272 |
+
|
273 |
+
|
274 |
+
class ParallelGatedConvBlock(nn.Module):
|
275 |
+
def __init__(self, config, layer_idx) -> None:
|
276 |
+
super().__init__()
|
277 |
+
self.config = config
|
278 |
+
self.layer_idx = layer_idx
|
279 |
+
self.low_mem_mode = config.get("low_mem_mode", False)
|
280 |
+
dtype = config.get("hyena_block_dtype", torch.float32)
|
281 |
+
mlp_dtype = config.get("mlp_dtype", torch.bfloat16)
|
282 |
+
self.pre_norm, self.post_norm = RMSNorm(config).to(dtype=dtype), RMSNorm(config).to(dtype=dtype)
|
283 |
+
self.filter = ParallelHyenaFilter(config, layer_idx).to(dtype=dtype)
|
284 |
+
self.projections = nn.Linear(config.hidden_size, 3 * config.hidden_size)
|
285 |
+
self.out_filter_dense = nn.Linear(config.hidden_size, config.hidden_size).to(dtype)
|
286 |
+
self.mlp = ParallelGatedMLP(config).to(dtype=mlp_dtype)
|
287 |
+
|
288 |
+
self.proj_norm_fn = self.proj_norm
|
289 |
+
self.res_mlp_norm_fn = self.res_mlp_norm
|
290 |
+
|
291 |
+
if self.config.get("compile", False):
|
292 |
+
self.proj_norm_fn = torch.compile(self.proj_norm, fullgraph=True, dynamic=False, mode="reduce-overhead")
|
293 |
+
self.res_mlp_norm_fn = torch.compile(
|
294 |
+
self.res_mlp_norm, fullgraph=True, dynamic=False, mode="reduce-overhead"
|
295 |
+
)
|
296 |
+
|
297 |
+
def proj_norm(self, x):
|
298 |
+
return self.projections(self.pre_norm(x))
|
299 |
+
|
300 |
+
def res_mlp_norm(self, x):
|
301 |
+
return self.mlp(self.post_norm(x)) + x
|
302 |
+
|
303 |
+
def forward(self, u, inference_params=None, padding_mask=None, *args, **kwargs):
|
304 |
+
z = self.proj_norm_fn(u)
|
305 |
+
|
306 |
+
if type(padding_mask) == torch.Tensor: # guard against bias
|
307 |
+
z = z * padding_mask[..., None]
|
308 |
+
|
309 |
+
z, inference_params = self.filter(z, inference_params=inference_params, padding_mask=padding_mask)
|
310 |
+
|
311 |
+
z_in = self.out_filter_dense(z) + u
|
312 |
+
|
313 |
+
if type(padding_mask) == torch.Tensor: # guard against bias
|
314 |
+
z_in = z_in * padding_mask[..., None]
|
315 |
+
|
316 |
+
y = self.res_mlp_norm_fn(z_in)
|
317 |
+
|
318 |
+
return y, inference_params
|
319 |
+
|
320 |
+
|
321 |
+
def get_block(config, layer_idx, flash_fft=None):
|
322 |
+
if layer_idx in config.attn_layer_idxs:
|
323 |
+
return AttentionBlock(config, layer_idx)
|
324 |
+
elif layer_idx in config.hyena_layer_idxs:
|
325 |
+
block = ParallelGatedConvBlock(config, layer_idx)
|
326 |
+
if config.get("use_flashfft", "False"):
|
327 |
+
block.filter.fftconv_fn = flash_fft
|
328 |
+
return block
|
329 |
+
else:
|
330 |
+
raise NotImplementedError
|
331 |
+
|
332 |
+
|
333 |
+
class StripedHyena(nn.Module):
|
334 |
+
def __init__(self, config):
|
335 |
+
super().__init__()
|
336 |
+
self.config = config
|
337 |
+
self.embedding_layer = VocabParallelEmbedding(config)
|
338 |
+
self.norm = RMSNorm(config) if config.get("final_norm", True) else None
|
339 |
+
self.unembed = self.embedding_layer if config.tie_embeddings else VocabParallelEmbedding(config)
|
340 |
+
|
341 |
+
if config.get("use_flashfft", "False"):
|
342 |
+
try:
|
343 |
+
from flashfftconv import FlashFFTConv
|
344 |
+
except:
|
345 |
+
raise ImportError
|
346 |
+
self.flash_fft = FlashFFTConv(2 * config.seqlen, dtype=torch.bfloat16)
|
347 |
+
else:
|
348 |
+
self.flash_fft = None
|
349 |
+
|
350 |
+
self.blocks = nn.ModuleList(
|
351 |
+
get_block(config, layer_idx, flash_fft=self.flash_fft) for layer_idx in range(config.num_layers)
|
352 |
+
)
|
353 |
+
|
354 |
+
def forward(self, x, inference_params_dict=None, padding_mask=None):
|
355 |
+
L = x.shape[1]
|
356 |
+
x = self.embedding_layer.embed(x)
|
357 |
+
if inference_params_dict is not None:
|
358 |
+
x, inference_params_dict_out = self.stateful_forward(
|
359 |
+
x,
|
360 |
+
inference_params_dict=inference_params_dict,
|
361 |
+
)
|
362 |
+
else:
|
363 |
+
x, inference_params_dict_out = self.stateless_forward(x, padding_mask=padding_mask)
|
364 |
+
|
365 |
+
x = self.norm(x)
|
366 |
+
x = self.unembed.unembed(x)
|
367 |
+
return x, inference_params_dict_out
|
368 |
+
|
369 |
+
def stateful_forward(self, x, inference_params_dict=None):
|
370 |
+
for block_idx, block in enumerate(self.blocks):
|
371 |
+
block_name = "mha" if block_idx in self.config.attn_layer_idxs else "hyena"
|
372 |
+
inference_params = inference_params_dict[block_name]
|
373 |
+
x, _ = block(x, inference_params=inference_params)
|
374 |
+
|
375 |
+
return x, inference_params_dict
|
376 |
+
|
377 |
+
def stateless_forward(self, x, padding_mask=None):
|
378 |
+
if type(padding_mask) == torch.Tensor:
|
379 |
+
x = x * padding_mask[..., None]
|
380 |
+
|
381 |
+
for _, block in enumerate(self.blocks):
|
382 |
+
x, _ = block(x, inference_params=None, padding_mask=padding_mask)
|
383 |
+
return x, None
|
384 |
+
|
385 |
+
def initialize_inference_params(self):
|
386 |
+
print_rank_0("Initializing inference params...")
|
387 |
+
inference_params_dict = {
|
388 |
+
"mha": InferenceParams(
|
389 |
+
max_seqlen=self.config.get("max_seqlen", 8192),
|
390 |
+
max_batch_size=self.config.get("max_batch_size", 1),
|
391 |
+
seqlen_offset=0,
|
392 |
+
),
|
393 |
+
"hyena": RecurrentInferenceParams(
|
394 |
+
fir_filter_length=self.config.short_filter_length,
|
395 |
+
state_dim=self.config.state_size,
|
396 |
+
seqlen_offset=0,
|
397 |
+
),
|
398 |
+
}
|
399 |
+
return inference_params_dict
|
400 |
+
|
401 |
+
def precompute_filters(self, L, device):
|
402 |
+
for block_idx, block in enumerate(self.blocks):
|
403 |
+
if type(block) == ParallelGatedConvBlock:
|
404 |
+
if type(block.filter) == ParallelHyenaFilter:
|
405 |
+
L = block.filter.long_fir_threshold or L
|
406 |
+
print_rank_0(f"Precomputing filters, L={L}...")
|
407 |
+
|
408 |
+
filter_dtype = torch.float16 if L >= 2048 else torch.float32
|
409 |
+
|
410 |
+
block.filter._set_time(L, device)
|
411 |
+
residues, poles = (
|
412 |
+
torch.view_as_complex(block.filter.residues.to(torch.float16)),
|
413 |
+
torch.view_as_complex(block.filter.poles.to(torch.float16)),
|
414 |
+
)
|
415 |
+
|
416 |
+
block.filter.h = (residues * poles**block.filter.t).real.sum(1)[None]
|
417 |
+
block.filter.h = block.filter.h.to(dtype=filter_dtype)
|
418 |
+
|
419 |
+
def load_poles_residues(self, path):
|
420 |
+
"Load different poles and residues for each layer."
|
421 |
+
for block_idx, block in enumerate(self.blocks):
|
422 |
+
if type(block) == ParallelGatedConvBlock:
|
423 |
+
if type(block.filter) == ParallelHyenaFilter:
|
424 |
+
print(f"Loading poles and residues for block {block_idx}")
|
425 |
+
poles = torch.load(path + f"/approx_poles_{block_idx+1}.pt", map_location="cpu")
|
426 |
+
poles = torch.view_as_real(poles)
|
427 |
+
residues = torch.load(path + f"/approx_residues_{block_idx+1}.pt", map_location="cpu")
|
428 |
+
residues = torch.view_as_real(residues)
|
429 |
+
poles = poles.permute(1, 0, 2).unsqueeze(-2)
|
430 |
+
residues = residues.permute(1, 0, 2).unsqueeze(-2)
|
431 |
+
|
432 |
+
block.filter.poles = nn.Parameter(poles)
|
433 |
+
block.filter.residues = nn.Parameter(residues)
|
434 |
+
|
435 |
+
def to_bfloat16_except_poles_residues(self):
|
436 |
+
"""Convert all parameters to bfloat16 except for the poles and residues.
|
437 |
+
|
438 |
+
Particularly important for longer prompts.
|
439 |
+
"""
|
440 |
+
for k, p in self.named_parameters():
|
441 |
+
if "poles" not in k and "residues" not in k:
|
442 |
+
p.data = p.data.to(torch.bfloat16)
|
443 |
+
|
444 |
+
def load_from_split_converted_state_dict(self, path):
|
445 |
+
|
446 |
+
print("Loading from split converted state dict")
|
447 |
+
|
448 |
+
embedding_weight = torch.load(path + "/layer_00.pt")["word_embeddings.weight"]
|
449 |
+
self.embedding_layer.weight = nn.Parameter(embedding_weight.to(self.embedding_layer.weight.dtype))
|
450 |
+
|
451 |
+
print("Loading embedding weight ok")
|
452 |
+
|
453 |
+
if self.config.get("final_norm", False) is not None:
|
454 |
+
idx = len(self.blocks) + 1
|
455 |
+
final_norm_scale = torch.load(path + f"/layer_{idx:02d}.pt")["norm.scale"]
|
456 |
+
self.norm.scale = nn.Parameter(final_norm_scale.to(self.norm.scale.dtype))
|
457 |
+
|
458 |
+
print("loading final norm ok")
|
459 |
+
|
460 |
+
if not self.config.get("tie_embeddings", True):
|
461 |
+
idx = len(self.blocks) + 2
|
462 |
+
embedding_weight = torch.load(path + f"/layer_{idx:02d}.pt")["word_embeddings.weight"]
|
463 |
+
self.unembed.weight = nn.Parameter(embedding_weight.to(self.unembed.weight.dtype))
|
464 |
+
|
465 |
+
print("loading unembed weight ok")
|
466 |
+
|
467 |
+
for block_idx, block in enumerate(self.blocks):
|
468 |
+
print("loading block {}...".format(block_idx))
|
469 |
+
# strict = False if type(block) == ParallelGatedConvBlock else True
|
470 |
+
# some blocks (optionally) go through a round of conv distillation on some parameters
|
471 |
+
strict = True # safer to be strict and account for every layer
|
472 |
+
|
473 |
+
loaded_dict = torch.load(path + f"/layer_{block_idx + 1:02d}.pt")
|
474 |
+
block.load_state_dict(loaded_dict, strict=strict)
|
model.safetensors.index.json
ADDED
@@ -0,0 +1,445 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"metadata": {
|
3 |
+
"total_size": 12913164672
|
4 |
+
},
|
5 |
+
"weight_map": {
|
6 |
+
"backbone.blocks.0.filter.D": "model-00001-of-00003.safetensors",
|
7 |
+
"backbone.blocks.0.filter.poles": "model-00001-of-00003.safetensors",
|
8 |
+
"backbone.blocks.0.filter.residues": "model-00001-of-00003.safetensors",
|
9 |
+
"backbone.blocks.0.filter.short_filter_bias": "model-00001-of-00003.safetensors",
|
10 |
+
"backbone.blocks.0.filter.short_filter_weight": "model-00001-of-00003.safetensors",
|
11 |
+
"backbone.blocks.0.mlp.l1.weight": "model-00001-of-00003.safetensors",
|
12 |
+
"backbone.blocks.0.mlp.l2.weight": "model-00001-of-00003.safetensors",
|
13 |
+
"backbone.blocks.0.mlp.l3.weight": "model-00001-of-00003.safetensors",
|
14 |
+
"backbone.blocks.0.out_filter_dense.bias": "model-00001-of-00003.safetensors",
|
15 |
+
"backbone.blocks.0.out_filter_dense.weight": "model-00001-of-00003.safetensors",
|
16 |
+
"backbone.blocks.0.post_norm.scale": "model-00001-of-00003.safetensors",
|
17 |
+
"backbone.blocks.0.pre_norm.scale": "model-00001-of-00003.safetensors",
|
18 |
+
"backbone.blocks.0.projections.bias": "model-00001-of-00003.safetensors",
|
19 |
+
"backbone.blocks.0.projections.weight": "model-00001-of-00003.safetensors",
|
20 |
+
"backbone.blocks.1.filter.D": "model-00001-of-00003.safetensors",
|
21 |
+
"backbone.blocks.1.filter.poles": "model-00001-of-00003.safetensors",
|
22 |
+
"backbone.blocks.1.filter.residues": "model-00001-of-00003.safetensors",
|
23 |
+
"backbone.blocks.1.filter.short_filter_bias": "model-00001-of-00003.safetensors",
|
24 |
+
"backbone.blocks.1.filter.short_filter_weight": "model-00001-of-00003.safetensors",
|
25 |
+
"backbone.blocks.1.mlp.l1.weight": "model-00001-of-00003.safetensors",
|
26 |
+
"backbone.blocks.1.mlp.l2.weight": "model-00001-of-00003.safetensors",
|
27 |
+
"backbone.blocks.1.mlp.l3.weight": "model-00001-of-00003.safetensors",
|
28 |
+
"backbone.blocks.1.out_filter_dense.bias": "model-00001-of-00003.safetensors",
|
29 |
+
"backbone.blocks.1.out_filter_dense.weight": "model-00001-of-00003.safetensors",
|
30 |
+
"backbone.blocks.1.post_norm.scale": "model-00001-of-00003.safetensors",
|
31 |
+
"backbone.blocks.1.pre_norm.scale": "model-00001-of-00003.safetensors",
|
32 |
+
"backbone.blocks.1.projections.bias": "model-00001-of-00003.safetensors",
|
33 |
+
"backbone.blocks.1.projections.weight": "model-00001-of-00003.safetensors",
|
34 |
+
"backbone.blocks.10.filter.D": "model-00001-of-00003.safetensors",
|
35 |
+
"backbone.blocks.10.filter.poles": "model-00001-of-00003.safetensors",
|
36 |
+
"backbone.blocks.10.filter.residues": "model-00001-of-00003.safetensors",
|
37 |
+
"backbone.blocks.10.filter.short_filter_bias": "model-00001-of-00003.safetensors",
|
38 |
+
"backbone.blocks.10.filter.short_filter_weight": "model-00001-of-00003.safetensors",
|
39 |
+
"backbone.blocks.10.mlp.l1.weight": "model-00001-of-00003.safetensors",
|
40 |
+
"backbone.blocks.10.mlp.l2.weight": "model-00001-of-00003.safetensors",
|
41 |
+
"backbone.blocks.10.mlp.l3.weight": "model-00001-of-00003.safetensors",
|
42 |
+
"backbone.blocks.10.out_filter_dense.bias": "model-00001-of-00003.safetensors",
|
43 |
+
"backbone.blocks.10.out_filter_dense.weight": "model-00001-of-00003.safetensors",
|
44 |
+
"backbone.blocks.10.post_norm.scale": "model-00001-of-00003.safetensors",
|
45 |
+
"backbone.blocks.10.pre_norm.scale": "model-00001-of-00003.safetensors",
|
46 |
+
"backbone.blocks.10.projections.bias": "model-00001-of-00003.safetensors",
|
47 |
+
"backbone.blocks.10.projections.weight": "model-00001-of-00003.safetensors",
|
48 |
+
"backbone.blocks.11.filter.D": "model-00001-of-00003.safetensors",
|
49 |
+
"backbone.blocks.11.filter.poles": "model-00001-of-00003.safetensors",
|
50 |
+
"backbone.blocks.11.filter.residues": "model-00001-of-00003.safetensors",
|
51 |
+
"backbone.blocks.11.filter.short_filter_bias": "model-00001-of-00003.safetensors",
|
52 |
+
"backbone.blocks.11.filter.short_filter_weight": "model-00001-of-00003.safetensors",
|
53 |
+
"backbone.blocks.11.mlp.l1.weight": "model-00001-of-00003.safetensors",
|
54 |
+
"backbone.blocks.11.mlp.l2.weight": "model-00001-of-00003.safetensors",
|
55 |
+
"backbone.blocks.11.mlp.l3.weight": "model-00001-of-00003.safetensors",
|
56 |
+
"backbone.blocks.11.out_filter_dense.bias": "model-00001-of-00003.safetensors",
|
57 |
+
"backbone.blocks.11.out_filter_dense.weight": "model-00001-of-00003.safetensors",
|
58 |
+
"backbone.blocks.11.post_norm.scale": "model-00001-of-00003.safetensors",
|
59 |
+
"backbone.blocks.11.pre_norm.scale": "model-00001-of-00003.safetensors",
|
60 |
+
"backbone.blocks.11.projections.bias": "model-00001-of-00003.safetensors",
|
61 |
+
"backbone.blocks.11.projections.weight": "model-00001-of-00003.safetensors",
|
62 |
+
"backbone.blocks.12.filter.D": "model-00001-of-00003.safetensors",
|
63 |
+
"backbone.blocks.12.filter.poles": "model-00001-of-00003.safetensors",
|
64 |
+
"backbone.blocks.12.filter.residues": "model-00001-of-00003.safetensors",
|
65 |
+
"backbone.blocks.12.filter.short_filter_bias": "model-00001-of-00003.safetensors",
|
66 |
+
"backbone.blocks.12.filter.short_filter_weight": "model-00001-of-00003.safetensors",
|
67 |
+
"backbone.blocks.12.mlp.l1.weight": "model-00002-of-00003.safetensors",
|
68 |
+
"backbone.blocks.12.mlp.l2.weight": "model-00002-of-00003.safetensors",
|
69 |
+
"backbone.blocks.12.mlp.l3.weight": "model-00002-of-00003.safetensors",
|
70 |
+
"backbone.blocks.12.out_filter_dense.bias": "model-00001-of-00003.safetensors",
|
71 |
+
"backbone.blocks.12.out_filter_dense.weight": "model-00001-of-00003.safetensors",
|
72 |
+
"backbone.blocks.12.post_norm.scale": "model-00001-of-00003.safetensors",
|
73 |
+
"backbone.blocks.12.pre_norm.scale": "model-00001-of-00003.safetensors",
|
74 |
+
"backbone.blocks.12.projections.bias": "model-00001-of-00003.safetensors",
|
75 |
+
"backbone.blocks.12.projections.weight": "model-00001-of-00003.safetensors",
|
76 |
+
"backbone.blocks.13.filter.D": "model-00002-of-00003.safetensors",
|
77 |
+
"backbone.blocks.13.filter.poles": "model-00002-of-00003.safetensors",
|
78 |
+
"backbone.blocks.13.filter.residues": "model-00002-of-00003.safetensors",
|
79 |
+
"backbone.blocks.13.filter.short_filter_bias": "model-00002-of-00003.safetensors",
|
80 |
+
"backbone.blocks.13.filter.short_filter_weight": "model-00002-of-00003.safetensors",
|
81 |
+
"backbone.blocks.13.mlp.l1.weight": "model-00002-of-00003.safetensors",
|
82 |
+
"backbone.blocks.13.mlp.l2.weight": "model-00002-of-00003.safetensors",
|
83 |
+
"backbone.blocks.13.mlp.l3.weight": "model-00002-of-00003.safetensors",
|
84 |
+
"backbone.blocks.13.out_filter_dense.bias": "model-00002-of-00003.safetensors",
|
85 |
+
"backbone.blocks.13.out_filter_dense.weight": "model-00002-of-00003.safetensors",
|
86 |
+
"backbone.blocks.13.post_norm.scale": "model-00002-of-00003.safetensors",
|
87 |
+
"backbone.blocks.13.pre_norm.scale": "model-00002-of-00003.safetensors",
|
88 |
+
"backbone.blocks.13.projections.bias": "model-00002-of-00003.safetensors",
|
89 |
+
"backbone.blocks.13.projections.weight": "model-00002-of-00003.safetensors",
|
90 |
+
"backbone.blocks.14.filter.D": "model-00002-of-00003.safetensors",
|
91 |
+
"backbone.blocks.14.filter.poles": "model-00002-of-00003.safetensors",
|
92 |
+
"backbone.blocks.14.filter.residues": "model-00002-of-00003.safetensors",
|
93 |
+
"backbone.blocks.14.filter.short_filter_bias": "model-00002-of-00003.safetensors",
|
94 |
+
"backbone.blocks.14.filter.short_filter_weight": "model-00002-of-00003.safetensors",
|
95 |
+
"backbone.blocks.14.mlp.l1.weight": "model-00002-of-00003.safetensors",
|
96 |
+
"backbone.blocks.14.mlp.l2.weight": "model-00002-of-00003.safetensors",
|
97 |
+
"backbone.blocks.14.mlp.l3.weight": "model-00002-of-00003.safetensors",
|
98 |
+
"backbone.blocks.14.out_filter_dense.bias": "model-00002-of-00003.safetensors",
|
99 |
+
"backbone.blocks.14.out_filter_dense.weight": "model-00002-of-00003.safetensors",
|
100 |
+
"backbone.blocks.14.post_norm.scale": "model-00002-of-00003.safetensors",
|
101 |
+
"backbone.blocks.14.pre_norm.scale": "model-00002-of-00003.safetensors",
|
102 |
+
"backbone.blocks.14.projections.bias": "model-00002-of-00003.safetensors",
|
103 |
+
"backbone.blocks.14.projections.weight": "model-00002-of-00003.safetensors",
|
104 |
+
"backbone.blocks.15.filter.D": "model-00002-of-00003.safetensors",
|
105 |
+
"backbone.blocks.15.filter.poles": "model-00002-of-00003.safetensors",
|
106 |
+
"backbone.blocks.15.filter.residues": "model-00002-of-00003.safetensors",
|
107 |
+
"backbone.blocks.15.filter.short_filter_bias": "model-00002-of-00003.safetensors",
|
108 |
+
"backbone.blocks.15.filter.short_filter_weight": "model-00002-of-00003.safetensors",
|
109 |
+
"backbone.blocks.15.mlp.l1.weight": "model-00002-of-00003.safetensors",
|
110 |
+
"backbone.blocks.15.mlp.l2.weight": "model-00002-of-00003.safetensors",
|
111 |
+
"backbone.blocks.15.mlp.l3.weight": "model-00002-of-00003.safetensors",
|
112 |
+
"backbone.blocks.15.out_filter_dense.bias": "model-00002-of-00003.safetensors",
|
113 |
+
"backbone.blocks.15.out_filter_dense.weight": "model-00002-of-00003.safetensors",
|
114 |
+
"backbone.blocks.15.post_norm.scale": "model-00002-of-00003.safetensors",
|
115 |
+
"backbone.blocks.15.pre_norm.scale": "model-00002-of-00003.safetensors",
|
116 |
+
"backbone.blocks.15.projections.bias": "model-00002-of-00003.safetensors",
|
117 |
+
"backbone.blocks.15.projections.weight": "model-00002-of-00003.safetensors",
|
118 |
+
"backbone.blocks.16.inner_mha_cls.Wqkv.bias": "model-00002-of-00003.safetensors",
|
119 |
+
"backbone.blocks.16.inner_mha_cls.Wqkv.weight": "model-00002-of-00003.safetensors",
|
120 |
+
"backbone.blocks.16.inner_mha_cls.out_proj.bias": "model-00002-of-00003.safetensors",
|
121 |
+
"backbone.blocks.16.inner_mha_cls.out_proj.weight": "model-00002-of-00003.safetensors",
|
122 |
+
"backbone.blocks.16.inner_mha_cls.rotary_emb.inv_freq": "model-00002-of-00003.safetensors",
|
123 |
+
"backbone.blocks.16.mlp.l1.weight": "model-00002-of-00003.safetensors",
|
124 |
+
"backbone.blocks.16.mlp.l2.weight": "model-00002-of-00003.safetensors",
|
125 |
+
"backbone.blocks.16.mlp.l3.weight": "model-00002-of-00003.safetensors",
|
126 |
+
"backbone.blocks.16.post_norm.scale": "model-00002-of-00003.safetensors",
|
127 |
+
"backbone.blocks.16.pre_norm.scale": "model-00002-of-00003.safetensors",
|
128 |
+
"backbone.blocks.17.filter.D": "model-00002-of-00003.safetensors",
|
129 |
+
"backbone.blocks.17.filter.poles": "model-00002-of-00003.safetensors",
|
130 |
+
"backbone.blocks.17.filter.residues": "model-00002-of-00003.safetensors",
|
131 |
+
"backbone.blocks.17.filter.short_filter_bias": "model-00002-of-00003.safetensors",
|
132 |
+
"backbone.blocks.17.filter.short_filter_weight": "model-00002-of-00003.safetensors",
|
133 |
+
"backbone.blocks.17.mlp.l1.weight": "model-00002-of-00003.safetensors",
|
134 |
+
"backbone.blocks.17.mlp.l2.weight": "model-00002-of-00003.safetensors",
|
135 |
+
"backbone.blocks.17.mlp.l3.weight": "model-00002-of-00003.safetensors",
|
136 |
+
"backbone.blocks.17.out_filter_dense.bias": "model-00002-of-00003.safetensors",
|
137 |
+
"backbone.blocks.17.out_filter_dense.weight": "model-00002-of-00003.safetensors",
|
138 |
+
"backbone.blocks.17.post_norm.scale": "model-00002-of-00003.safetensors",
|
139 |
+
"backbone.blocks.17.pre_norm.scale": "model-00002-of-00003.safetensors",
|
140 |
+
"backbone.blocks.17.projections.bias": "model-00002-of-00003.safetensors",
|
141 |
+
"backbone.blocks.17.projections.weight": "model-00002-of-00003.safetensors",
|
142 |
+
"backbone.blocks.18.filter.D": "model-00002-of-00003.safetensors",
|
143 |
+
"backbone.blocks.18.filter.poles": "model-00002-of-00003.safetensors",
|
144 |
+
"backbone.blocks.18.filter.residues": "model-00002-of-00003.safetensors",
|
145 |
+
"backbone.blocks.18.filter.short_filter_bias": "model-00002-of-00003.safetensors",
|
146 |
+
"backbone.blocks.18.filter.short_filter_weight": "model-00002-of-00003.safetensors",
|
147 |
+
"backbone.blocks.18.mlp.l1.weight": "model-00002-of-00003.safetensors",
|
148 |
+
"backbone.blocks.18.mlp.l2.weight": "model-00002-of-00003.safetensors",
|
149 |
+
"backbone.blocks.18.mlp.l3.weight": "model-00002-of-00003.safetensors",
|
150 |
+
"backbone.blocks.18.out_filter_dense.bias": "model-00002-of-00003.safetensors",
|
151 |
+
"backbone.blocks.18.out_filter_dense.weight": "model-00002-of-00003.safetensors",
|
152 |
+
"backbone.blocks.18.post_norm.scale": "model-00002-of-00003.safetensors",
|
153 |
+
"backbone.blocks.18.pre_norm.scale": "model-00002-of-00003.safetensors",
|
154 |
+
"backbone.blocks.18.projections.bias": "model-00002-of-00003.safetensors",
|
155 |
+
"backbone.blocks.18.projections.weight": "model-00002-of-00003.safetensors",
|
156 |
+
"backbone.blocks.19.filter.D": "model-00002-of-00003.safetensors",
|
157 |
+
"backbone.blocks.19.filter.poles": "model-00002-of-00003.safetensors",
|
158 |
+
"backbone.blocks.19.filter.residues": "model-00002-of-00003.safetensors",
|
159 |
+
"backbone.blocks.19.filter.short_filter_bias": "model-00002-of-00003.safetensors",
|
160 |
+
"backbone.blocks.19.filter.short_filter_weight": "model-00002-of-00003.safetensors",
|
161 |
+
"backbone.blocks.19.mlp.l1.weight": "model-00002-of-00003.safetensors",
|
162 |
+
"backbone.blocks.19.mlp.l2.weight": "model-00002-of-00003.safetensors",
|
163 |
+
"backbone.blocks.19.mlp.l3.weight": "model-00002-of-00003.safetensors",
|
164 |
+
"backbone.blocks.19.out_filter_dense.bias": "model-00002-of-00003.safetensors",
|
165 |
+
"backbone.blocks.19.out_filter_dense.weight": "model-00002-of-00003.safetensors",
|
166 |
+
"backbone.blocks.19.post_norm.scale": "model-00002-of-00003.safetensors",
|
167 |
+
"backbone.blocks.19.pre_norm.scale": "model-00002-of-00003.safetensors",
|
168 |
+
"backbone.blocks.19.projections.bias": "model-00002-of-00003.safetensors",
|
169 |
+
"backbone.blocks.19.projections.weight": "model-00002-of-00003.safetensors",
|
170 |
+
"backbone.blocks.2.filter.D": "model-00001-of-00003.safetensors",
|
171 |
+
"backbone.blocks.2.filter.poles": "model-00001-of-00003.safetensors",
|
172 |
+
"backbone.blocks.2.filter.residues": "model-00001-of-00003.safetensors",
|
173 |
+
"backbone.blocks.2.filter.short_filter_bias": "model-00001-of-00003.safetensors",
|
174 |
+
"backbone.blocks.2.filter.short_filter_weight": "model-00001-of-00003.safetensors",
|
175 |
+
"backbone.blocks.2.mlp.l1.weight": "model-00001-of-00003.safetensors",
|
176 |
+
"backbone.blocks.2.mlp.l2.weight": "model-00001-of-00003.safetensors",
|
177 |
+
"backbone.blocks.2.mlp.l3.weight": "model-00001-of-00003.safetensors",
|
178 |
+
"backbone.blocks.2.out_filter_dense.bias": "model-00001-of-00003.safetensors",
|
179 |
+
"backbone.blocks.2.out_filter_dense.weight": "model-00001-of-00003.safetensors",
|
180 |
+
"backbone.blocks.2.post_norm.scale": "model-00001-of-00003.safetensors",
|
181 |
+
"backbone.blocks.2.pre_norm.scale": "model-00001-of-00003.safetensors",
|
182 |
+
"backbone.blocks.2.projections.bias": "model-00001-of-00003.safetensors",
|
183 |
+
"backbone.blocks.2.projections.weight": "model-00001-of-00003.safetensors",
|
184 |
+
"backbone.blocks.20.filter.D": "model-00002-of-00003.safetensors",
|
185 |
+
"backbone.blocks.20.filter.poles": "model-00002-of-00003.safetensors",
|
186 |
+
"backbone.blocks.20.filter.residues": "model-00002-of-00003.safetensors",
|
187 |
+
"backbone.blocks.20.filter.short_filter_bias": "model-00002-of-00003.safetensors",
|
188 |
+
"backbone.blocks.20.filter.short_filter_weight": "model-00002-of-00003.safetensors",
|
189 |
+
"backbone.blocks.20.mlp.l1.weight": "model-00002-of-00003.safetensors",
|
190 |
+
"backbone.blocks.20.mlp.l2.weight": "model-00002-of-00003.safetensors",
|
191 |
+
"backbone.blocks.20.mlp.l3.weight": "model-00002-of-00003.safetensors",
|
192 |
+
"backbone.blocks.20.out_filter_dense.bias": "model-00002-of-00003.safetensors",
|
193 |
+
"backbone.blocks.20.out_filter_dense.weight": "model-00002-of-00003.safetensors",
|
194 |
+
"backbone.blocks.20.post_norm.scale": "model-00002-of-00003.safetensors",
|
195 |
+
"backbone.blocks.20.pre_norm.scale": "model-00002-of-00003.safetensors",
|
196 |
+
"backbone.blocks.20.projections.bias": "model-00002-of-00003.safetensors",
|
197 |
+
"backbone.blocks.20.projections.weight": "model-00002-of-00003.safetensors",
|
198 |
+
"backbone.blocks.21.filter.D": "model-00002-of-00003.safetensors",
|
199 |
+
"backbone.blocks.21.filter.poles": "model-00002-of-00003.safetensors",
|
200 |
+
"backbone.blocks.21.filter.residues": "model-00002-of-00003.safetensors",
|
201 |
+
"backbone.blocks.21.filter.short_filter_bias": "model-00002-of-00003.safetensors",
|
202 |
+
"backbone.blocks.21.filter.short_filter_weight": "model-00002-of-00003.safetensors",
|
203 |
+
"backbone.blocks.21.mlp.l1.weight": "model-00002-of-00003.safetensors",
|
204 |
+
"backbone.blocks.21.mlp.l2.weight": "model-00002-of-00003.safetensors",
|
205 |
+
"backbone.blocks.21.mlp.l3.weight": "model-00002-of-00003.safetensors",
|
206 |
+
"backbone.blocks.21.out_filter_dense.bias": "model-00002-of-00003.safetensors",
|
207 |
+
"backbone.blocks.21.out_filter_dense.weight": "model-00002-of-00003.safetensors",
|
208 |
+
"backbone.blocks.21.post_norm.scale": "model-00002-of-00003.safetensors",
|
209 |
+
"backbone.blocks.21.pre_norm.scale": "model-00002-of-00003.safetensors",
|
210 |
+
"backbone.blocks.21.projections.bias": "model-00002-of-00003.safetensors",
|
211 |
+
"backbone.blocks.21.projections.weight": "model-00002-of-00003.safetensors",
|
212 |
+
"backbone.blocks.22.filter.D": "model-00002-of-00003.safetensors",
|
213 |
+
"backbone.blocks.22.filter.poles": "model-00002-of-00003.safetensors",
|
214 |
+
"backbone.blocks.22.filter.residues": "model-00002-of-00003.safetensors",
|
215 |
+
"backbone.blocks.22.filter.short_filter_bias": "model-00002-of-00003.safetensors",
|
216 |
+
"backbone.blocks.22.filter.short_filter_weight": "model-00002-of-00003.safetensors",
|
217 |
+
"backbone.blocks.22.mlp.l1.weight": "model-00002-of-00003.safetensors",
|
218 |
+
"backbone.blocks.22.mlp.l2.weight": "model-00002-of-00003.safetensors",
|
219 |
+
"backbone.blocks.22.mlp.l3.weight": "model-00002-of-00003.safetensors",
|
220 |
+
"backbone.blocks.22.out_filter_dense.bias": "model-00002-of-00003.safetensors",
|
221 |
+
"backbone.blocks.22.out_filter_dense.weight": "model-00002-of-00003.safetensors",
|
222 |
+
"backbone.blocks.22.post_norm.scale": "model-00002-of-00003.safetensors",
|
223 |
+
"backbone.blocks.22.pre_norm.scale": "model-00002-of-00003.safetensors",
|
224 |
+
"backbone.blocks.22.projections.bias": "model-00002-of-00003.safetensors",
|
225 |
+
"backbone.blocks.22.projections.weight": "model-00002-of-00003.safetensors",
|
226 |
+
"backbone.blocks.23.filter.D": "model-00002-of-00003.safetensors",
|
227 |
+
"backbone.blocks.23.filter.poles": "model-00002-of-00003.safetensors",
|
228 |
+
"backbone.blocks.23.filter.residues": "model-00002-of-00003.safetensors",
|
229 |
+
"backbone.blocks.23.filter.short_filter_bias": "model-00002-of-00003.safetensors",
|
230 |
+
"backbone.blocks.23.filter.short_filter_weight": "model-00002-of-00003.safetensors",
|
231 |
+
"backbone.blocks.23.mlp.l1.weight": "model-00002-of-00003.safetensors",
|
232 |
+
"backbone.blocks.23.mlp.l2.weight": "model-00002-of-00003.safetensors",
|
233 |
+
"backbone.blocks.23.mlp.l3.weight": "model-00002-of-00003.safetensors",
|
234 |
+
"backbone.blocks.23.out_filter_dense.bias": "model-00002-of-00003.safetensors",
|
235 |
+
"backbone.blocks.23.out_filter_dense.weight": "model-00002-of-00003.safetensors",
|
236 |
+
"backbone.blocks.23.post_norm.scale": "model-00002-of-00003.safetensors",
|
237 |
+
"backbone.blocks.23.pre_norm.scale": "model-00002-of-00003.safetensors",
|
238 |
+
"backbone.blocks.23.projections.bias": "model-00002-of-00003.safetensors",
|
239 |
+
"backbone.blocks.23.projections.weight": "model-00002-of-00003.safetensors",
|
240 |
+
"backbone.blocks.24.inner_mha_cls.Wqkv.bias": "model-00002-of-00003.safetensors",
|
241 |
+
"backbone.blocks.24.inner_mha_cls.Wqkv.weight": "model-00002-of-00003.safetensors",
|
242 |
+
"backbone.blocks.24.inner_mha_cls.out_proj.bias": "model-00002-of-00003.safetensors",
|
243 |
+
"backbone.blocks.24.inner_mha_cls.out_proj.weight": "model-00002-of-00003.safetensors",
|
244 |
+
"backbone.blocks.24.inner_mha_cls.rotary_emb.inv_freq": "model-00002-of-00003.safetensors",
|
245 |
+
"backbone.blocks.24.mlp.l1.weight": "model-00002-of-00003.safetensors",
|
246 |
+
"backbone.blocks.24.mlp.l2.weight": "model-00003-of-00003.safetensors",
|
247 |
+
"backbone.blocks.24.mlp.l3.weight": "model-00003-of-00003.safetensors",
|
248 |
+
"backbone.blocks.24.post_norm.scale": "model-00002-of-00003.safetensors",
|
249 |
+
"backbone.blocks.24.pre_norm.scale": "model-00002-of-00003.safetensors",
|
250 |
+
"backbone.blocks.25.filter.D": "model-00003-of-00003.safetensors",
|
251 |
+
"backbone.blocks.25.filter.poles": "model-00003-of-00003.safetensors",
|
252 |
+
"backbone.blocks.25.filter.residues": "model-00003-of-00003.safetensors",
|
253 |
+
"backbone.blocks.25.filter.short_filter_bias": "model-00003-of-00003.safetensors",
|
254 |
+
"backbone.blocks.25.filter.short_filter_weight": "model-00003-of-00003.safetensors",
|
255 |
+
"backbone.blocks.25.mlp.l1.weight": "model-00003-of-00003.safetensors",
|
256 |
+
"backbone.blocks.25.mlp.l2.weight": "model-00003-of-00003.safetensors",
|
257 |
+
"backbone.blocks.25.mlp.l3.weight": "model-00003-of-00003.safetensors",
|
258 |
+
"backbone.blocks.25.out_filter_dense.bias": "model-00003-of-00003.safetensors",
|
259 |
+
"backbone.blocks.25.out_filter_dense.weight": "model-00003-of-00003.safetensors",
|
260 |
+
"backbone.blocks.25.post_norm.scale": "model-00003-of-00003.safetensors",
|
261 |
+
"backbone.blocks.25.pre_norm.scale": "model-00003-of-00003.safetensors",
|
262 |
+
"backbone.blocks.25.projections.bias": "model-00003-of-00003.safetensors",
|
263 |
+
"backbone.blocks.25.projections.weight": "model-00003-of-00003.safetensors",
|
264 |
+
"backbone.blocks.26.filter.D": "model-00003-of-00003.safetensors",
|
265 |
+
"backbone.blocks.26.filter.poles": "model-00003-of-00003.safetensors",
|
266 |
+
"backbone.blocks.26.filter.residues": "model-00003-of-00003.safetensors",
|
267 |
+
"backbone.blocks.26.filter.short_filter_bias": "model-00003-of-00003.safetensors",
|
268 |
+
"backbone.blocks.26.filter.short_filter_weight": "model-00003-of-00003.safetensors",
|
269 |
+
"backbone.blocks.26.mlp.l1.weight": "model-00003-of-00003.safetensors",
|
270 |
+
"backbone.blocks.26.mlp.l2.weight": "model-00003-of-00003.safetensors",
|
271 |
+
"backbone.blocks.26.mlp.l3.weight": "model-00003-of-00003.safetensors",
|
272 |
+
"backbone.blocks.26.out_filter_dense.bias": "model-00003-of-00003.safetensors",
|
273 |
+
"backbone.blocks.26.out_filter_dense.weight": "model-00003-of-00003.safetensors",
|
274 |
+
"backbone.blocks.26.post_norm.scale": "model-00003-of-00003.safetensors",
|
275 |
+
"backbone.blocks.26.pre_norm.scale": "model-00003-of-00003.safetensors",
|
276 |
+
"backbone.blocks.26.projections.bias": "model-00003-of-00003.safetensors",
|
277 |
+
"backbone.blocks.26.projections.weight": "model-00003-of-00003.safetensors",
|
278 |
+
"backbone.blocks.27.filter.D": "model-00003-of-00003.safetensors",
|
279 |
+
"backbone.blocks.27.filter.poles": "model-00003-of-00003.safetensors",
|
280 |
+
"backbone.blocks.27.filter.residues": "model-00003-of-00003.safetensors",
|
281 |
+
"backbone.blocks.27.filter.short_filter_bias": "model-00003-of-00003.safetensors",
|
282 |
+
"backbone.blocks.27.filter.short_filter_weight": "model-00003-of-00003.safetensors",
|
283 |
+
"backbone.blocks.27.mlp.l1.weight": "model-00003-of-00003.safetensors",
|
284 |
+
"backbone.blocks.27.mlp.l2.weight": "model-00003-of-00003.safetensors",
|
285 |
+
"backbone.blocks.27.mlp.l3.weight": "model-00003-of-00003.safetensors",
|
286 |
+
"backbone.blocks.27.out_filter_dense.bias": "model-00003-of-00003.safetensors",
|
287 |
+
"backbone.blocks.27.out_filter_dense.weight": "model-00003-of-00003.safetensors",
|
288 |
+
"backbone.blocks.27.post_norm.scale": "model-00003-of-00003.safetensors",
|
289 |
+
"backbone.blocks.27.pre_norm.scale": "model-00003-of-00003.safetensors",
|
290 |
+
"backbone.blocks.27.projections.bias": "model-00003-of-00003.safetensors",
|
291 |
+
"backbone.blocks.27.projections.weight": "model-00003-of-00003.safetensors",
|
292 |
+
"backbone.blocks.28.filter.D": "model-00003-of-00003.safetensors",
|
293 |
+
"backbone.blocks.28.filter.poles": "model-00003-of-00003.safetensors",
|
294 |
+
"backbone.blocks.28.filter.residues": "model-00003-of-00003.safetensors",
|
295 |
+
"backbone.blocks.28.filter.short_filter_bias": "model-00003-of-00003.safetensors",
|
296 |
+
"backbone.blocks.28.filter.short_filter_weight": "model-00003-of-00003.safetensors",
|
297 |
+
"backbone.blocks.28.mlp.l1.weight": "model-00003-of-00003.safetensors",
|
298 |
+
"backbone.blocks.28.mlp.l2.weight": "model-00003-of-00003.safetensors",
|
299 |
+
"backbone.blocks.28.mlp.l3.weight": "model-00003-of-00003.safetensors",
|
300 |
+
"backbone.blocks.28.out_filter_dense.bias": "model-00003-of-00003.safetensors",
|
301 |
+
"backbone.blocks.28.out_filter_dense.weight": "model-00003-of-00003.safetensors",
|
302 |
+
"backbone.blocks.28.post_norm.scale": "model-00003-of-00003.safetensors",
|
303 |
+
"backbone.blocks.28.pre_norm.scale": "model-00003-of-00003.safetensors",
|
304 |
+
"backbone.blocks.28.projections.bias": "model-00003-of-00003.safetensors",
|
305 |
+
"backbone.blocks.28.projections.weight": "model-00003-of-00003.safetensors",
|
306 |
+
"backbone.blocks.29.filter.D": "model-00003-of-00003.safetensors",
|
307 |
+
"backbone.blocks.29.filter.poles": "model-00003-of-00003.safetensors",
|
308 |
+
"backbone.blocks.29.filter.residues": "model-00003-of-00003.safetensors",
|
309 |
+
"backbone.blocks.29.filter.short_filter_bias": "model-00003-of-00003.safetensors",
|
310 |
+
"backbone.blocks.29.filter.short_filter_weight": "model-00003-of-00003.safetensors",
|
311 |
+
"backbone.blocks.29.mlp.l1.weight": "model-00003-of-00003.safetensors",
|
312 |
+
"backbone.blocks.29.mlp.l2.weight": "model-00003-of-00003.safetensors",
|
313 |
+
"backbone.blocks.29.mlp.l3.weight": "model-00003-of-00003.safetensors",
|
314 |
+
"backbone.blocks.29.out_filter_dense.bias": "model-00003-of-00003.safetensors",
|
315 |
+
"backbone.blocks.29.out_filter_dense.weight": "model-00003-of-00003.safetensors",
|
316 |
+
"backbone.blocks.29.post_norm.scale": "model-00003-of-00003.safetensors",
|
317 |
+
"backbone.blocks.29.pre_norm.scale": "model-00003-of-00003.safetensors",
|
318 |
+
"backbone.blocks.29.projections.bias": "model-00003-of-00003.safetensors",
|
319 |
+
"backbone.blocks.29.projections.weight": "model-00003-of-00003.safetensors",
|
320 |
+
"backbone.blocks.3.filter.D": "model-00001-of-00003.safetensors",
|
321 |
+
"backbone.blocks.3.filter.poles": "model-00001-of-00003.safetensors",
|
322 |
+
"backbone.blocks.3.filter.residues": "model-00001-of-00003.safetensors",
|
323 |
+
"backbone.blocks.3.filter.short_filter_bias": "model-00001-of-00003.safetensors",
|
324 |
+
"backbone.blocks.3.filter.short_filter_weight": "model-00001-of-00003.safetensors",
|
325 |
+
"backbone.blocks.3.mlp.l1.weight": "model-00001-of-00003.safetensors",
|
326 |
+
"backbone.blocks.3.mlp.l2.weight": "model-00001-of-00003.safetensors",
|
327 |
+
"backbone.blocks.3.mlp.l3.weight": "model-00001-of-00003.safetensors",
|
328 |
+
"backbone.blocks.3.out_filter_dense.bias": "model-00001-of-00003.safetensors",
|
329 |
+
"backbone.blocks.3.out_filter_dense.weight": "model-00001-of-00003.safetensors",
|
330 |
+
"backbone.blocks.3.post_norm.scale": "model-00001-of-00003.safetensors",
|
331 |
+
"backbone.blocks.3.pre_norm.scale": "model-00001-of-00003.safetensors",
|
332 |
+
"backbone.blocks.3.projections.bias": "model-00001-of-00003.safetensors",
|
333 |
+
"backbone.blocks.3.projections.weight": "model-00001-of-00003.safetensors",
|
334 |
+
"backbone.blocks.30.filter.D": "model-00003-of-00003.safetensors",
|
335 |
+
"backbone.blocks.30.filter.poles": "model-00003-of-00003.safetensors",
|
336 |
+
"backbone.blocks.30.filter.residues": "model-00003-of-00003.safetensors",
|
337 |
+
"backbone.blocks.30.filter.short_filter_bias": "model-00003-of-00003.safetensors",
|
338 |
+
"backbone.blocks.30.filter.short_filter_weight": "model-00003-of-00003.safetensors",
|
339 |
+
"backbone.blocks.30.mlp.l1.weight": "model-00003-of-00003.safetensors",
|
340 |
+
"backbone.blocks.30.mlp.l2.weight": "model-00003-of-00003.safetensors",
|
341 |
+
"backbone.blocks.30.mlp.l3.weight": "model-00003-of-00003.safetensors",
|
342 |
+
"backbone.blocks.30.out_filter_dense.bias": "model-00003-of-00003.safetensors",
|
343 |
+
"backbone.blocks.30.out_filter_dense.weight": "model-00003-of-00003.safetensors",
|
344 |
+
"backbone.blocks.30.post_norm.scale": "model-00003-of-00003.safetensors",
|
345 |
+
"backbone.blocks.30.pre_norm.scale": "model-00003-of-00003.safetensors",
|
346 |
+
"backbone.blocks.30.projections.bias": "model-00003-of-00003.safetensors",
|
347 |
+
"backbone.blocks.30.projections.weight": "model-00003-of-00003.safetensors",
|
348 |
+
"backbone.blocks.31.filter.D": "model-00003-of-00003.safetensors",
|
349 |
+
"backbone.blocks.31.filter.poles": "model-00003-of-00003.safetensors",
|
350 |
+
"backbone.blocks.31.filter.residues": "model-00003-of-00003.safetensors",
|
351 |
+
"backbone.blocks.31.filter.short_filter_bias": "model-00003-of-00003.safetensors",
|
352 |
+
"backbone.blocks.31.filter.short_filter_weight": "model-00003-of-00003.safetensors",
|
353 |
+
"backbone.blocks.31.mlp.l1.weight": "model-00003-of-00003.safetensors",
|
354 |
+
"backbone.blocks.31.mlp.l2.weight": "model-00003-of-00003.safetensors",
|
355 |
+
"backbone.blocks.31.mlp.l3.weight": "model-00003-of-00003.safetensors",
|
356 |
+
"backbone.blocks.31.out_filter_dense.bias": "model-00003-of-00003.safetensors",
|
357 |
+
"backbone.blocks.31.out_filter_dense.weight": "model-00003-of-00003.safetensors",
|
358 |
+
"backbone.blocks.31.post_norm.scale": "model-00003-of-00003.safetensors",
|
359 |
+
"backbone.blocks.31.pre_norm.scale": "model-00003-of-00003.safetensors",
|
360 |
+
"backbone.blocks.31.projections.bias": "model-00003-of-00003.safetensors",
|
361 |
+
"backbone.blocks.31.projections.weight": "model-00003-of-00003.safetensors",
|
362 |
+
"backbone.blocks.4.filter.D": "model-00001-of-00003.safetensors",
|
363 |
+
"backbone.blocks.4.filter.poles": "model-00001-of-00003.safetensors",
|
364 |
+
"backbone.blocks.4.filter.residues": "model-00001-of-00003.safetensors",
|
365 |
+
"backbone.blocks.4.filter.short_filter_bias": "model-00001-of-00003.safetensors",
|
366 |
+
"backbone.blocks.4.filter.short_filter_weight": "model-00001-of-00003.safetensors",
|
367 |
+
"backbone.blocks.4.mlp.l1.weight": "model-00001-of-00003.safetensors",
|
368 |
+
"backbone.blocks.4.mlp.l2.weight": "model-00001-of-00003.safetensors",
|
369 |
+
"backbone.blocks.4.mlp.l3.weight": "model-00001-of-00003.safetensors",
|
370 |
+
"backbone.blocks.4.out_filter_dense.bias": "model-00001-of-00003.safetensors",
|
371 |
+
"backbone.blocks.4.out_filter_dense.weight": "model-00001-of-00003.safetensors",
|
372 |
+
"backbone.blocks.4.post_norm.scale": "model-00001-of-00003.safetensors",
|
373 |
+
"backbone.blocks.4.pre_norm.scale": "model-00001-of-00003.safetensors",
|
374 |
+
"backbone.blocks.4.projections.bias": "model-00001-of-00003.safetensors",
|
375 |
+
"backbone.blocks.4.projections.weight": "model-00001-of-00003.safetensors",
|
376 |
+
"backbone.blocks.5.filter.D": "model-00001-of-00003.safetensors",
|
377 |
+
"backbone.blocks.5.filter.poles": "model-00001-of-00003.safetensors",
|
378 |
+
"backbone.blocks.5.filter.residues": "model-00001-of-00003.safetensors",
|
379 |
+
"backbone.blocks.5.filter.short_filter_bias": "model-00001-of-00003.safetensors",
|
380 |
+
"backbone.blocks.5.filter.short_filter_weight": "model-00001-of-00003.safetensors",
|
381 |
+
"backbone.blocks.5.mlp.l1.weight": "model-00001-of-00003.safetensors",
|
382 |
+
"backbone.blocks.5.mlp.l2.weight": "model-00001-of-00003.safetensors",
|
383 |
+
"backbone.blocks.5.mlp.l3.weight": "model-00001-of-00003.safetensors",
|
384 |
+
"backbone.blocks.5.out_filter_dense.bias": "model-00001-of-00003.safetensors",
|
385 |
+
"backbone.blocks.5.out_filter_dense.weight": "model-00001-of-00003.safetensors",
|
386 |
+
"backbone.blocks.5.post_norm.scale": "model-00001-of-00003.safetensors",
|
387 |
+
"backbone.blocks.5.pre_norm.scale": "model-00001-of-00003.safetensors",
|
388 |
+
"backbone.blocks.5.projections.bias": "model-00001-of-00003.safetensors",
|
389 |
+
"backbone.blocks.5.projections.weight": "model-00001-of-00003.safetensors",
|
390 |
+
"backbone.blocks.6.filter.D": "model-00001-of-00003.safetensors",
|
391 |
+
"backbone.blocks.6.filter.poles": "model-00001-of-00003.safetensors",
|
392 |
+
"backbone.blocks.6.filter.residues": "model-00001-of-00003.safetensors",
|
393 |
+
"backbone.blocks.6.filter.short_filter_bias": "model-00001-of-00003.safetensors",
|
394 |
+
"backbone.blocks.6.filter.short_filter_weight": "model-00001-of-00003.safetensors",
|
395 |
+
"backbone.blocks.6.mlp.l1.weight": "model-00001-of-00003.safetensors",
|
396 |
+
"backbone.blocks.6.mlp.l2.weight": "model-00001-of-00003.safetensors",
|
397 |
+
"backbone.blocks.6.mlp.l3.weight": "model-00001-of-00003.safetensors",
|
398 |
+
"backbone.blocks.6.out_filter_dense.bias": "model-00001-of-00003.safetensors",
|
399 |
+
"backbone.blocks.6.out_filter_dense.weight": "model-00001-of-00003.safetensors",
|
400 |
+
"backbone.blocks.6.post_norm.scale": "model-00001-of-00003.safetensors",
|
401 |
+
"backbone.blocks.6.pre_norm.scale": "model-00001-of-00003.safetensors",
|
402 |
+
"backbone.blocks.6.projections.bias": "model-00001-of-00003.safetensors",
|
403 |
+
"backbone.blocks.6.projections.weight": "model-00001-of-00003.safetensors",
|
404 |
+
"backbone.blocks.7.filter.D": "model-00001-of-00003.safetensors",
|
405 |
+
"backbone.blocks.7.filter.poles": "model-00001-of-00003.safetensors",
|
406 |
+
"backbone.blocks.7.filter.residues": "model-00001-of-00003.safetensors",
|
407 |
+
"backbone.blocks.7.filter.short_filter_bias": "model-00001-of-00003.safetensors",
|
408 |
+
"backbone.blocks.7.filter.short_filter_weight": "model-00001-of-00003.safetensors",
|
409 |
+
"backbone.blocks.7.mlp.l1.weight": "model-00001-of-00003.safetensors",
|
410 |
+
"backbone.blocks.7.mlp.l2.weight": "model-00001-of-00003.safetensors",
|
411 |
+
"backbone.blocks.7.mlp.l3.weight": "model-00001-of-00003.safetensors",
|
412 |
+
"backbone.blocks.7.out_filter_dense.bias": "model-00001-of-00003.safetensors",
|
413 |
+
"backbone.blocks.7.out_filter_dense.weight": "model-00001-of-00003.safetensors",
|
414 |
+
"backbone.blocks.7.post_norm.scale": "model-00001-of-00003.safetensors",
|
415 |
+
"backbone.blocks.7.pre_norm.scale": "model-00001-of-00003.safetensors",
|
416 |
+
"backbone.blocks.7.projections.bias": "model-00001-of-00003.safetensors",
|
417 |
+
"backbone.blocks.7.projections.weight": "model-00001-of-00003.safetensors",
|
418 |
+
"backbone.blocks.8.inner_mha_cls.Wqkv.bias": "model-00001-of-00003.safetensors",
|
419 |
+
"backbone.blocks.8.inner_mha_cls.Wqkv.weight": "model-00001-of-00003.safetensors",
|
420 |
+
"backbone.blocks.8.inner_mha_cls.out_proj.bias": "model-00001-of-00003.safetensors",
|
421 |
+
"backbone.blocks.8.inner_mha_cls.out_proj.weight": "model-00001-of-00003.safetensors",
|
422 |
+
"backbone.blocks.8.inner_mha_cls.rotary_emb.inv_freq": "model-00001-of-00003.safetensors",
|
423 |
+
"backbone.blocks.8.mlp.l1.weight": "model-00001-of-00003.safetensors",
|
424 |
+
"backbone.blocks.8.mlp.l2.weight": "model-00001-of-00003.safetensors",
|
425 |
+
"backbone.blocks.8.mlp.l3.weight": "model-00001-of-00003.safetensors",
|
426 |
+
"backbone.blocks.8.post_norm.scale": "model-00001-of-00003.safetensors",
|
427 |
+
"backbone.blocks.8.pre_norm.scale": "model-00001-of-00003.safetensors",
|
428 |
+
"backbone.blocks.9.filter.D": "model-00001-of-00003.safetensors",
|
429 |
+
"backbone.blocks.9.filter.poles": "model-00001-of-00003.safetensors",
|
430 |
+
"backbone.blocks.9.filter.residues": "model-00001-of-00003.safetensors",
|
431 |
+
"backbone.blocks.9.filter.short_filter_bias": "model-00001-of-00003.safetensors",
|
432 |
+
"backbone.blocks.9.filter.short_filter_weight": "model-00001-of-00003.safetensors",
|
433 |
+
"backbone.blocks.9.mlp.l1.weight": "model-00001-of-00003.safetensors",
|
434 |
+
"backbone.blocks.9.mlp.l2.weight": "model-00001-of-00003.safetensors",
|
435 |
+
"backbone.blocks.9.mlp.l3.weight": "model-00001-of-00003.safetensors",
|
436 |
+
"backbone.blocks.9.out_filter_dense.bias": "model-00001-of-00003.safetensors",
|
437 |
+
"backbone.blocks.9.out_filter_dense.weight": "model-00001-of-00003.safetensors",
|
438 |
+
"backbone.blocks.9.post_norm.scale": "model-00001-of-00003.safetensors",
|
439 |
+
"backbone.blocks.9.pre_norm.scale": "model-00001-of-00003.safetensors",
|
440 |
+
"backbone.blocks.9.projections.bias": "model-00001-of-00003.safetensors",
|
441 |
+
"backbone.blocks.9.projections.weight": "model-00001-of-00003.safetensors",
|
442 |
+
"backbone.embedding_layer.weight": "model-00001-of-00003.safetensors",
|
443 |
+
"backbone.norm.scale": "model-00001-of-00003.safetensors"
|
444 |
+
}
|
445 |
+
}
|
modeling_hyena.py
ADDED
@@ -0,0 +1,145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
"""StripedHyena custom code port for the Hugging Face Hub"""
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from torch.nn import functional as F
|
6 |
+
from .configuration_hyena import StripedHyenaConfig
|
7 |
+
from transformers import PreTrainedModel
|
8 |
+
from transformers.modeling_outputs import CausalLMOutput, CausalLMOutputWithPast
|
9 |
+
from transformers.utils import logging
|
10 |
+
from typing import Optional, Tuple, Union
|
11 |
+
from .model import StripedHyena
|
12 |
+
from .utils import dotdict
|
13 |
+
from .cache import InferenceParams
|
14 |
+
from .engine import HyenaInferenceEngine
|
15 |
+
from .layers import RMSNorm
|
16 |
+
from .utils import dotdict, column_split
|
17 |
+
|
18 |
+
logger = logging.get_logger(__name__)
|
19 |
+
|
20 |
+
|
21 |
+
class StripedHyenaPreTrainedModel(PreTrainedModel):
|
22 |
+
config_class = StripedHyenaConfig
|
23 |
+
base_model_prefix = "sh"
|
24 |
+
supports_gradient_checkpointing = False
|
25 |
+
_no_split_modules = ["AttentionBlock", "ParallelGatedConvBlock"]
|
26 |
+
_skip_keys_device_placement = "past_key_values"
|
27 |
+
_keys_to_ignore_on_load_missing = [r"freq"]
|
28 |
+
_keys_to_ignore_on_load_unexpected = [r"fftconv", r"twiddle_factors"]
|
29 |
+
_supports_flash_attn_2 = True
|
30 |
+
|
31 |
+
|
32 |
+
class StripedHyenaModelForCausalLM(StripedHyenaPreTrainedModel):
|
33 |
+
supports_gradient_checkpointing = True
|
34 |
+
|
35 |
+
def __init__(self, config, **kwargs):
|
36 |
+
super().__init__(config, **kwargs)
|
37 |
+
model_config = dotdict(config.to_dict())
|
38 |
+
self.backbone = StripedHyena(model_config)
|
39 |
+
self.backbone.gradient_checkpointing = False
|
40 |
+
self.config = config
|
41 |
+
vocab_size = config.vocab_size
|
42 |
+
if vocab_size % config.make_vocab_size_divisible_by != 0:
|
43 |
+
vocab_size += config.make_vocab_size_divisible_by - (
|
44 |
+
vocab_size % config.make_vocab_size_divisible_by
|
45 |
+
)
|
46 |
+
self.vocab_size = vocab_size
|
47 |
+
self.post_init()
|
48 |
+
self.force_dtype()
|
49 |
+
|
50 |
+
def force_dtype(self):
|
51 |
+
self.backbone.to_bfloat16_except_poles_residues()
|
52 |
+
|
53 |
+
def _set_gradient_checkpointing(self, enable, gradient_checkpointing_func):
|
54 |
+
self.backbone.gradient_checkpointing = enable
|
55 |
+
|
56 |
+
def get_input_embeddings(self):
|
57 |
+
return self.backbone.embedding_layer
|
58 |
+
|
59 |
+
def forward(
|
60 |
+
self,
|
61 |
+
input_ids: torch.LongTensor = None,
|
62 |
+
attention_mask: Optional[torch.LongTensor] = None,
|
63 |
+
labels: Optional[torch.LongTensor] = None,
|
64 |
+
use_cache: Optional[bool] = None,
|
65 |
+
output_attentions: Optional[bool] = None,
|
66 |
+
output_hidden_states: Optional[bool] = None,
|
67 |
+
past_key_values=None,
|
68 |
+
return_dict: Optional[bool] = None,
|
69 |
+
) -> Union[Tuple, CausalLMOutputWithPast]:
|
70 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
71 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
72 |
+
|
73 |
+
if use_cache:
|
74 |
+
if self.backbone.gradient_checkpointing and self.backbone.training:
|
75 |
+
logger.warning_once(
|
76 |
+
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
77 |
+
)
|
78 |
+
use_cache = False
|
79 |
+
elif labels is not None:
|
80 |
+
logger.warning_once(
|
81 |
+
"`use_cache=True` is incompatible with loss calculation. Setting `use_cache=False`..."
|
82 |
+
)
|
83 |
+
use_cache = False
|
84 |
+
|
85 |
+
inputs = input_ids
|
86 |
+
if use_cache:
|
87 |
+
if past_key_values is None:
|
88 |
+
past_key_values = self.backbone.initialize_inference_params()
|
89 |
+
|
90 |
+
batch_size = input_ids.shape[0]
|
91 |
+
past_key_values["mha"].max_batch_size = batch_size
|
92 |
+
past_key_values["hyena"].max_batch_size = batch_size
|
93 |
+
else:
|
94 |
+
seqlen_offset = past_key_values["mha"].seqlen_offset
|
95 |
+
if seqlen_offset == 0:
|
96 |
+
# second loop through generate will have prompt_len + 1 as seqlen
|
97 |
+
seqlen_offset = input_ids.shape[-1] - 1
|
98 |
+
past_key_values["hyena"].seqlen_offset = seqlen_offset
|
99 |
+
past_key_values["mha"].seqlen_offset = seqlen_offset
|
100 |
+
else:
|
101 |
+
past_key_values["mha"].seqlen_offset += 1
|
102 |
+
past_key_values["hyena"].seqlen_offset += 1
|
103 |
+
|
104 |
+
inputs = input_ids[
|
105 |
+
:,
|
106 |
+
-1:,
|
107 |
+
]
|
108 |
+
|
109 |
+
logits, past_key_values = self.backbone(
|
110 |
+
inputs,
|
111 |
+
padding_mask=attention_mask,
|
112 |
+
inference_params_dict=past_key_values if use_cache else None,
|
113 |
+
)
|
114 |
+
|
115 |
+
loss = None
|
116 |
+
if labels is not None:
|
117 |
+
shift_logits = logits[..., :-1, :].contiguous()
|
118 |
+
shift_labels = labels[..., 1:].contiguous()
|
119 |
+
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
120 |
+
shift_labels = shift_labels.view(-1)
|
121 |
+
shift_labels = shift_labels.to(shift_logits.device)
|
122 |
+
loss = F.cross_entropy(shift_logits, shift_labels)
|
123 |
+
|
124 |
+
if return_dict:
|
125 |
+
return CausalLMOutputWithPast(
|
126 |
+
logits=logits,
|
127 |
+
hidden_states=None,
|
128 |
+
past_key_values=past_key_values if use_cache else None,
|
129 |
+
loss=loss,
|
130 |
+
)
|
131 |
+
else:
|
132 |
+
return logits
|
133 |
+
|
134 |
+
@classmethod
|
135 |
+
def can_generate(cls) -> bool:
|
136 |
+
return True
|
137 |
+
|
138 |
+
def prepare_inputs_for_generation(
|
139 |
+
self, input_ids, attention_mask=None, past_key_values=None, **kwargs
|
140 |
+
):
|
141 |
+
return {
|
142 |
+
"input_ids": input_ids,
|
143 |
+
"attention_mask": attention_mask,
|
144 |
+
"past_key_values": past_key_values,
|
145 |
+
}
|
positional_embeddings.py
ADDED
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# This software is distributed under the terms of the Apache License, Version 2.0
|
2 |
+
# Author: Armin Thomas, Eric Nguyen
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import copy
|
6 |
+
from einops import rearrange
|
7 |
+
from flash_attn.layers.rotary import RotaryEmbedding
|
8 |
+
from flash_attn.modules.mha import MHA
|
9 |
+
|
10 |
+
|
11 |
+
# simple wrapper for flash-attn RoPE with linear scaling:
|
12 |
+
class LinearlyScaledRotaryEmbedding(RotaryEmbedding):
|
13 |
+
def __init__(
|
14 |
+
self,
|
15 |
+
dim: int,
|
16 |
+
scaling_factor: float=1.,
|
17 |
+
base=10000.0,
|
18 |
+
interleaved=False,
|
19 |
+
scale_base=None,
|
20 |
+
pos_idx_in_fp32=True,
|
21 |
+
device=None,
|
22 |
+
):
|
23 |
+
super().__init__(
|
24 |
+
dim=dim,
|
25 |
+
base=base,
|
26 |
+
interleaved=interleaved,
|
27 |
+
scale_base=scale_base,
|
28 |
+
pos_idx_in_fp32=pos_idx_in_fp32,
|
29 |
+
device=device
|
30 |
+
)
|
31 |
+
self._linear_scaling_factor = scaling_factor
|
32 |
+
# adpated from: https://github.com/Dao-AILab/flash-attention/blob/43ceab630bc6c27712428da5a33fc9cb5c369d91/flash_attn/layers/rotary.py#L368
|
33 |
+
def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
|
34 |
+
# Reset the tables if the sequence length has changed,
|
35 |
+
# if we're on a new device (possibly due to tracing for instance),
|
36 |
+
# or if we're switching from inference mode to training
|
37 |
+
if (
|
38 |
+
seqlen > self._seq_len_cached
|
39 |
+
or self._cos_cached is None
|
40 |
+
or self._cos_cached.device != device
|
41 |
+
or self._cos_cached.dtype != dtype
|
42 |
+
or (self.training and self._cos_cached.is_inference())
|
43 |
+
):
|
44 |
+
self._seq_len_cached = seqlen
|
45 |
+
# We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16
|
46 |
+
# And the output of arange can be quite large, so bf16 would lose a lot of precision.
|
47 |
+
# However, for compatibility reason, we add an option to use the dtype of self.inv_freq.
|
48 |
+
if self.pos_idx_in_fp32:
|
49 |
+
t = torch.arange(seqlen, device=device, dtype=torch.float32)
|
50 |
+
# linear scaling:
|
51 |
+
t = t / self._linear_scaling_factor
|
52 |
+
# We want fp32 here as well since inv_freq will be multiplied with t, and the output
|
53 |
+
# will be large. Having it in bf16 will lose a lot of precision and cause the
|
54 |
+
# cos & sin output to change significantly.
|
55 |
+
# We want to recompute self.inv_freq if it was not loaded in fp32
|
56 |
+
if self.inv_freq.dtype != torch.float32:
|
57 |
+
inv_freq = self._compute_inv_freq(device=device)
|
58 |
+
else:
|
59 |
+
inv_freq = self.inv_freq
|
60 |
+
else:
|
61 |
+
t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
|
62 |
+
# linear scaling:
|
63 |
+
t = t / self._linear_scaling_factor
|
64 |
+
inv_freq = self.inv_freq
|
65 |
+
# Don't do einsum, it converts fp32 to fp16 under AMP
|
66 |
+
# freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
67 |
+
freqs = torch.outer(t, inv_freq)
|
68 |
+
if self.scale is None:
|
69 |
+
self._cos_cached = torch.cos(freqs).to(dtype)
|
70 |
+
self._sin_cached = torch.sin(freqs).to(dtype)
|
71 |
+
else:
|
72 |
+
power = (
|
73 |
+
torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device)
|
74 |
+
- seqlen // 2
|
75 |
+
) / self.scale_base
|
76 |
+
scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
|
77 |
+
# We want the multiplication by scale to happen in fp32
|
78 |
+
self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
|
79 |
+
self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
|
80 |
+
self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
|
81 |
+
self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
|
82 |
+
|
83 |
+
# swap out RoPE of existing mha:
|
84 |
+
def swap_mha_rope(
|
85 |
+
mha,
|
86 |
+
new_rope: torch.nn.Module=LinearlyScaledRotaryEmbedding,
|
87 |
+
kwargs_new_rope: dict=None
|
88 |
+
):
|
89 |
+
# determine mha dtype and device:
|
90 |
+
dtype = mha.Wq.weight.dtype if mha.cross_attn else mha.Wqkv.weight.dtype
|
91 |
+
device = mha.Wq.weight.device if mha.cross_attn else mha.Wqkv.weight.device
|
92 |
+
# determine RoPE settings:
|
93 |
+
kwargs_old_rope = dict(
|
94 |
+
dim = mha.rotary_emb.dim,
|
95 |
+
base = mha.rotary_emb.base,
|
96 |
+
interleaved = mha.rotary_emb.interleaved,
|
97 |
+
scale_base = mha.rotary_emb.scale_base,
|
98 |
+
pos_idx_in_fp32 = mha.rotary_emb.pos_idx_in_fp32,
|
99 |
+
device = mha.rotary_emb.inv_freq.device
|
100 |
+
)
|
101 |
+
# delete old RoPE:
|
102 |
+
del mha.rotary_emb
|
103 |
+
# create new RoPE:
|
104 |
+
kwargs_new_rope = kwargs_new_rope or {'scaling_factor': 1.0}
|
105 |
+
scaled_rope = new_rope(
|
106 |
+
**kwargs_new_rope,
|
107 |
+
**kwargs_old_rope
|
108 |
+
).to(dtype)
|
109 |
+
# attach new RoPE to mha:
|
110 |
+
mha.rotary_emb = scaled_rope
|
111 |
+
# make new sure RoPE is correctly registered:
|
112 |
+
assert isinstance(mha.rotary_emb, new_rope)
|
113 |
+
return mha
|
pytorch_model.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c7f1a446b4063869fa7a5c7b0e94cd2f234c44f21c6168b0dd8747f0bf33ab46
|
3 |
+
size 16814399082
|
special_tokens_map.json
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
{}
|
streamer.py
ADDED
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import AutoTokenizer
|
2 |
+
|
3 |
+
|
4 |
+
class BaseStreamer:
|
5 |
+
"""
|
6 |
+
Base class from which `.generate()` streamers should inherit.
|
7 |
+
"""
|
8 |
+
|
9 |
+
def put(self, value):
|
10 |
+
"""Function that is called by `.generate()` to push new tokens"""
|
11 |
+
raise NotImplementedError()
|
12 |
+
|
13 |
+
def end(self):
|
14 |
+
"""Function that is called by `.generate()` to signal the end of generation"""
|
15 |
+
raise NotImplementedError()
|
16 |
+
|
17 |
+
|
18 |
+
class ByteStreamer(BaseStreamer):
|
19 |
+
"""
|
20 |
+
Simple text streamer that prints the token(s) to stdout as soon as entire words are formed.
|
21 |
+
|
22 |
+
<Tip warning={true}>
|
23 |
+
|
24 |
+
The API for the streamer classes is still under development and may change in the future.
|
25 |
+
|
26 |
+
</Tip>
|
27 |
+
|
28 |
+
Parameters:
|
29 |
+
tokenizer (`AutoTokenizer`):
|
30 |
+
The tokenized used to decode the tokens.
|
31 |
+
skip_prompt (`bool`, *optional*, defaults to `False`):
|
32 |
+
Whether to skip the prompt to `.generate()` or not. Useful e.g. for chatbots.
|
33 |
+
decode_kwargs (`dict`, *optional*):
|
34 |
+
Additional keyword arguments to pass to the tokenizer's `decode` method.
|
35 |
+
|
36 |
+
Examples:
|
37 |
+
|
38 |
+
```python
|
39 |
+
>>> from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
|
40 |
+
|
41 |
+
>>> tok = AutoTokenizer.from_pretrained("gpt2")
|
42 |
+
>>> model = AutoModelForCausalLM.from_pretrained("gpt2")
|
43 |
+
>>> inputs = tok(["An increasing sequence: one,"], return_tensors="pt")
|
44 |
+
>>> streamer = TextStreamer(tok)
|
45 |
+
|
46 |
+
>>> # Despite returning the usual output, the streamer will also print the generated text to stdout.
|
47 |
+
>>> _ = model.generate(**inputs, streamer=streamer, max_new_tokens=20)
|
48 |
+
An increasing sequence: one, two, three, four, five, six, seven, eight, nine, ten, eleven,
|
49 |
+
```
|
50 |
+
"""
|
51 |
+
|
52 |
+
def __init__(self, tokenizer: "AutoTokenizer", skip_prompt: bool = False, **decode_kwargs):
|
53 |
+
self.tokenizer = tokenizer
|
54 |
+
self.skip_prompt = skip_prompt
|
55 |
+
self.decode_kwargs = decode_kwargs
|
56 |
+
|
57 |
+
# variables used in the streaming process
|
58 |
+
self.token_cache = []
|
59 |
+
self.print_len = 0
|
60 |
+
self.next_tokens_are_prompt = True
|
61 |
+
|
62 |
+
def put(self, value):
|
63 |
+
"""
|
64 |
+
Receives tokens, decodes them, and prints them to stdout as soon as they form entire words.
|
65 |
+
"""
|
66 |
+
if len(value.shape) > 1 and value.shape[0] > 1:
|
67 |
+
raise ValueError("TextStreamer only supports batch size 1")
|
68 |
+
elif len(value.shape) > 1:
|
69 |
+
value = value[0]
|
70 |
+
|
71 |
+
if self.skip_prompt and self.next_tokens_are_prompt:
|
72 |
+
self.next_tokens_are_prompt = False
|
73 |
+
return
|
74 |
+
|
75 |
+
# Add the new token to the cache and decodes the entire thing.
|
76 |
+
self.token_cache.extend(value.tolist())
|
77 |
+
text = self.tokenizer.decode(self.token_cache, **self.decode_kwargs)
|
78 |
+
|
79 |
+
# After the symbol for a new line, we flush the cache.
|
80 |
+
if text.endswith("\n"):
|
81 |
+
printable_text = text[self.print_len :]
|
82 |
+
self.token_cache = []
|
83 |
+
self.print_len = 0
|
84 |
+
else:
|
85 |
+
printable_text = text[self.print_len : self.print_len + 1]
|
86 |
+
self.print_len += len(printable_text)
|
87 |
+
|
88 |
+
self.on_finalized_text(printable_text)
|
89 |
+
|
90 |
+
def end(self):
|
91 |
+
"""Flushes any remaining cache and prints a newline to stdout."""
|
92 |
+
# Flush the cache, if it exists
|
93 |
+
if len(self.token_cache) > 0:
|
94 |
+
text = self.tokenizer.decode(self.token_cache, **self.decode_kwargs)
|
95 |
+
printable_text = text[self.print_len :]
|
96 |
+
self.token_cache = []
|
97 |
+
self.print_len = 0
|
98 |
+
else:
|
99 |
+
printable_text = ""
|
100 |
+
|
101 |
+
self.next_tokens_are_prompt = True
|
102 |
+
self.on_finalized_text(printable_text, stream_end=True)
|
103 |
+
|
104 |
+
def on_finalized_text(self, text: str, stream_end: bool = False):
|
105 |
+
"""Prints the new text to stdout. If the stream is ending, also prints a newline."""
|
106 |
+
print(text, flush=True, end="" if not stream_end else None)
|
tokenizer.py
ADDED
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# based on https://github.com/EleutherAI/gpt-neox/blob/main/megatron/tokenizer/tokenizer.py
|
2 |
+
from __future__ import annotations
|
3 |
+
|
4 |
+
import torch
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
|
8 |
+
from os import PathLike
|
9 |
+
from typing import List, Tuple
|
10 |
+
|
11 |
+
from tokenizers import Tokenizer
|
12 |
+
from transformers.tokenization_utils import PreTrainedTokenizer
|
13 |
+
from transformers.tokenization_utils_base import BatchEncoding, TruncationStrategy
|
14 |
+
from transformers.utils.generic import TensorType, PaddingStrategy
|
15 |
+
|
16 |
+
|
17 |
+
EMPTY: str = ""
|
18 |
+
|
19 |
+
|
20 |
+
class ByteTokenizer(PreTrainedTokenizer):
|
21 |
+
|
22 |
+
"""UTF-8 Encoder."""
|
23 |
+
|
24 |
+
@classmethod
|
25 |
+
def from_pretrained(cls, model_id: str | PathLike, **kwargs) -> ByteTokenizer:
|
26 |
+
|
27 |
+
return cls(**kwargs, byte_level=True)
|
28 |
+
|
29 |
+
@property
|
30 |
+
def vocab_size(self) -> int:
|
31 |
+
|
32 |
+
return 512
|
33 |
+
|
34 |
+
@property
|
35 |
+
def byte_level(self) -> bool:
|
36 |
+
|
37 |
+
return self.init_kwargs.get('byte_level', True)
|
38 |
+
|
39 |
+
def get_vocab(self) -> Dict[str, int]:
|
40 |
+
|
41 |
+
return {chr(i): i for i in range(self.vocab_size)}
|
42 |
+
|
43 |
+
def __len__(self) -> int:
|
44 |
+
|
45 |
+
return self.vocab_size
|
46 |
+
|
47 |
+
def clamp(self, n: int) -> int:
|
48 |
+
|
49 |
+
return max(32, min(n, self.vocab_size))
|
50 |
+
|
51 |
+
def _tokenize(self, text: str, **kwargs) -> List[str]:
|
52 |
+
|
53 |
+
return list(text)
|
54 |
+
|
55 |
+
def byte_tokenize(self, text: str) -> np.ndarray:
|
56 |
+
|
57 |
+
return np.frombuffer(text.encode('utf-8'), dtype=np.uint8)
|
58 |
+
|
59 |
+
def _convert_token_to_id(self, token: str) -> int:
|
60 |
+
|
61 |
+
return self.clamp(ord(token))
|
62 |
+
|
63 |
+
def _convert_id_to_token(self, index: int) -> str:
|
64 |
+
|
65 |
+
return chr(self.clamp(index))
|
66 |
+
|
67 |
+
def convert_tokens_to_string(self, tokens: List[str]) -> str:
|
68 |
+
|
69 |
+
return EMPTY.join(tokens)
|
70 |
+
|
71 |
+
def _decode(self, token_ids: List[int], **kwargs) -> str:
|
72 |
+
|
73 |
+
indices = np.asarray(token_ids, dtype=np.uint8)
|
74 |
+
|
75 |
+
return (
|
76 |
+
indices.clip(min=32, max=self.vocab_size, out=indices)
|
77 |
+
.tobytes()
|
78 |
+
.decode('utf-8')
|
79 |
+
)
|
80 |
+
|
81 |
+
def _encode_plus(self, text: str, **kwargs) -> BatchEncoding:
|
82 |
+
|
83 |
+
first_ids = self.byte_tokenize(text).tolist()
|
84 |
+
|
85 |
+
return self.prepare_for_model(
|
86 |
+
first_ids,
|
87 |
+
pair_ids=None,
|
88 |
+
add_special_tokens=kwargs.get('add_special_tokens', False),
|
89 |
+
padding=kwargs.get('padding_strategy', PaddingStrategy.DO_NOT_PAD).value,
|
90 |
+
truncation=kwargs.get('truncation_strategy', TruncationStrategy.DO_NOT_TRUNCATE).value,
|
91 |
+
max_length=kwargs.get('max_length'),
|
92 |
+
stride=kwargs.get('stride', 0),
|
93 |
+
pad_to_multiple_of=kwargs.get('pad_to_multiple_of'),
|
94 |
+
return_tensors=kwargs.get('return_tensors'),
|
95 |
+
prepend_batch_axis=True,
|
96 |
+
return_attention_mask=kwargs.get('return_attention_mask'),
|
97 |
+
return_token_type_ids=kwargs.get('return_token_type_ids'),
|
98 |
+
return_overflowing_tokens=kwargs.get('return_overflowing_tokens', False),
|
99 |
+
return_special_tokens_mask=kwargs.get('return_special_tokens_mask', False),
|
100 |
+
return_length=kwargs.get('return_length', False),
|
101 |
+
verbose=kwargs.get('verbose', True),
|
102 |
+
)
|
103 |
+
|
104 |
+
def _batch_encode_plus(self, batch_text_or_text_pairs: List[str], **kwargs) -> BatchEncoding:
|
105 |
+
|
106 |
+
input_ids = [(self.byte_tokenize(text).tolist(), None) for text in batch_text_or_text_pairs]
|
107 |
+
|
108 |
+
return self._batch_prepare_for_model(
|
109 |
+
input_ids,
|
110 |
+
add_special_tokens=kwargs.get('add_special_tokens', False),
|
111 |
+
padding_strategy=kwargs.get('padding_strategy', PaddingStrategy.DO_NOT_PAD),
|
112 |
+
truncation_strategy=kwargs.get('truncation_strategy', TruncationStrategy.DO_NOT_TRUNCATE),
|
113 |
+
max_length=kwargs.get('max_length'),
|
114 |
+
stride=kwargs.get('stride', 0),
|
115 |
+
pad_to_multiple_of=kwargs.get('pad_to_multiple_of'),
|
116 |
+
return_attention_mask=kwargs.get('return_attention_mask'),
|
117 |
+
return_token_type_ids=kwargs.get('return_token_type_ids'),
|
118 |
+
return_overflowing_tokens=kwargs.get('return_overflowing_tokens', False),
|
119 |
+
return_special_tokens_mask=kwargs.get('return_special_tokens_mask', False),
|
120 |
+
return_length=kwargs.get('return_length', False),
|
121 |
+
return_tensors=kwargs.get('return_tensors'),
|
122 |
+
verbose=kwargs.get('verbose', True),
|
123 |
+
)
|
124 |
+
|
125 |
+
def _save_pretrained(
|
126 |
+
self, save_directory: str | PathLike, file_names: Tuple[str], **kwargs
|
127 |
+
) -> Tuple[str]:
|
128 |
+
|
129 |
+
return file_names
|
tokenizer_config.json
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"added_tokens_decoder": {},
|
3 |
+
"auto_map": {
|
4 |
+
"AutoTokenizer": [
|
5 |
+
"tokenizer.ByteTokenizer",
|
6 |
+
null
|
7 |
+
]
|
8 |
+
},
|
9 |
+
"byte_level": true,
|
10 |
+
"clean_up_tokenization_spaces": true,
|
11 |
+
"model_max_length": 1000000000000000019884624838656,
|
12 |
+
"padding_side": "left",
|
13 |
+
"truncation_side": "left"
|
14 |
+
}
|
utils.py
ADDED
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
|
4 |
+
def grab_first_if_tuple(x):
|
5 |
+
if x.__class__.__name__ == "tuple":
|
6 |
+
return x[0]
|
7 |
+
else:
|
8 |
+
return x
|
9 |
+
|
10 |
+
|
11 |
+
def column_split(x, num_heads, head_size):
|
12 |
+
"""Split a tensor with `num_heads` alongside the head dimension, instead of
|
13 |
+
across heads. Fixed to three projections
|
14 |
+
"""
|
15 |
+
|
16 |
+
x_reshaped = x.reshape(
|
17 |
+
x.shape[0],
|
18 |
+
num_heads,
|
19 |
+
3 * head_size,
|
20 |
+
)
|
21 |
+
|
22 |
+
x2, x1, v = (
|
23 |
+
x_reshaped[:, :, :head_size],
|
24 |
+
x_reshaped[
|
25 |
+
:,
|
26 |
+
:,
|
27 |
+
head_size : 2 * head_size,
|
28 |
+
],
|
29 |
+
x_reshaped[:, :, 2 * head_size :],
|
30 |
+
)
|
31 |
+
x2, x1, v = (
|
32 |
+
x2.reshape(x2.shape[0], -1),
|
33 |
+
x1.reshape(x1.shape[0], -1),
|
34 |
+
v.reshape(v.shape[0], -1),
|
35 |
+
)
|
36 |
+
return x2, x1, v
|
37 |
+
|
38 |
+
|
39 |
+
def get_init_from_string(init_str):
|
40 |
+
if type(init_str) == str:
|
41 |
+
if init_str == "torch.nn.init.zeros_":
|
42 |
+
return torch.nn.init.zeros_
|
43 |
+
elif init_str == "torch.nn.init.xavier_uniform_":
|
44 |
+
return torch.nn.init.xavier_uniform_
|
45 |
+
elif init_str == "torch.nn.init.xavier_normal_":
|
46 |
+
return torch.nn.init.xavier_normal_
|
47 |
+
else:
|
48 |
+
raise ValueError(f"Unrecognized init {init_str}")
|
49 |
+
|
50 |
+
|
51 |
+
def print_rank_0(message, debug=False, end="\n"):
|
52 |
+
"""Print from rank 0 only."""
|
53 |
+
if torch.distributed.is_initialized():
|
54 |
+
if torch.distributed.get_rank() == 0:
|
55 |
+
print(message, flush=True, end=end)
|
56 |
+
else:
|
57 |
+
print(message, flush=True, end=end)
|
58 |
+
|
59 |
+
|
60 |
+
class dotdict(dict):
|
61 |
+
"""dot.notation access to dictionary attributes"""
|
62 |
+
|
63 |
+
__getattr__ = dict.get
|
64 |
+
__setattr__ = dict.__setitem__
|
65 |
+
__delattr__ = dict.__delitem__
|
66 |
+
|
67 |
+
|
68 |
+
def ensure_divisibility(numerator, denominator):
|
69 |
+
"""Ensure that numerator is divisible by the denominator."""
|
70 |
+
assert numerator % denominator == 0, "{} is not divisible by {}".format(numerator, denominator)
|
71 |
+
|
72 |
+
|
73 |
+
def divide(numerator, denominator):
|
74 |
+
"""Ensure that numerator is divisible by the denominator and return
|
75 |
+
the division value."""
|
76 |
+
ensure_divisibility(numerator, denominator)
|
77 |
+
return numerator // denominator
|
78 |
+
|
79 |
+
|
80 |
+
class VocabUtility:
|
81 |
+
"""Split the vocabulary into `world_size` chunks amd return the
|
82 |
+
first and last index of the vocabulary belonging to the `rank`
|
83 |
+
partition: Note that indices in [first, last]"""
|
84 |
+
|
85 |
+
@staticmethod
|
86 |
+
def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank, world_size):
|
87 |
+
index_f = rank * per_partition_vocab_size
|
88 |
+
index_l = index_f + per_partition_vocab_size
|
89 |
+
return index_f, index_l
|
90 |
+
|
91 |
+
@staticmethod
|
92 |
+
def vocab_range_from_global_vocab_size(global_vocab_size, rank, world_size):
|
93 |
+
per_partition_vocab_size = divide(global_vocab_size, world_size)
|
94 |
+
return VocabUtility.vocab_range_from_per_partition_vocab_size(
|
95 |
+
per_partition_vocab_size, rank, world_size
|
96 |
+
)
|