0379693d73141b4f796eb2f0edd3bebcf86c950fdac574056f6f09445e445196
Browse files- README.md +85 -0
- config.json +74 -0
- configuration_progen.py +89 -0
- generation_config.json +6 -0
- modeling_InstructProGen.py +700 -0
- smash_config.json +31 -0
- special_tokens_map.json +9 -0
- structure.py +287 -0
- tokenizer.json +96 -0
- tokenizer_config.json +39 -0
README.md
ADDED
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
thumbnail: "https://assets-global.website-files.com/646b351987a8d8ce158d1940/64ec9e96b4334c0e1ac41504_Logo%20with%20white%20text.svg"
|
3 |
+
base_model: InstructPLM/MPNN-ProGen2-xlarge-CATH42
|
4 |
+
metrics:
|
5 |
+
- memory_disk
|
6 |
+
- memory_inference
|
7 |
+
- inference_latency
|
8 |
+
- inference_throughput
|
9 |
+
- inference_CO2_emissions
|
10 |
+
- inference_energy_consumption
|
11 |
+
tags:
|
12 |
+
- pruna-ai
|
13 |
+
---
|
14 |
+
<!-- header start -->
|
15 |
+
<!-- 200823 -->
|
16 |
+
<div style="width: auto; margin-left: auto; margin-right: auto">
|
17 |
+
<a href="https://www.pruna.ai/" target="_blank" rel="noopener noreferrer">
|
18 |
+
<img src="https://i.imgur.com/eDAlcgk.png" alt="PrunaAI" style="width: 100%; min-width: 400px; display: block; margin: auto;">
|
19 |
+
</a>
|
20 |
+
</div>
|
21 |
+
<!-- header end -->
|
22 |
+
|
23 |
+
[![Twitter](https://img.shields.io/twitter/follow/PrunaAI?style=social)](https://twitter.com/PrunaAI)
|
24 |
+
[![GitHub](https://img.shields.io/github/followers/PrunaAI?label=Follow%20%40PrunaAI&style=social)](https://github.com/PrunaAI)
|
25 |
+
[![LinkedIn](https://img.shields.io/badge/LinkedIn-Connect-blue)](https://www.linkedin.com/company/93832878/admin/feed/posts/?feedType=following)
|
26 |
+
[![Discord](https://img.shields.io/badge/Discord-Join%20Us-blue?style=social&logo=discord)](https://discord.gg/rskEr4BZJx)
|
27 |
+
|
28 |
+
# Simply make AI models cheaper, smaller, faster, and greener!
|
29 |
+
|
30 |
+
- Give a thumbs up if you like this model!
|
31 |
+
- Contact us and tell us which model to compress next [here](https://www.pruna.ai/contact).
|
32 |
+
- Request access to easily compress your *own* AI models [here](https://z0halsaff74.typeform.com/pruna-access?typeform-source=www.pruna.ai).
|
33 |
+
- Read the documentations to know more [here](https://pruna-ai-pruna.readthedocs-hosted.com/en/latest/)
|
34 |
+
- Join Pruna AI community on Discord [here](https://discord.gg/CP4VSgck) to share feedback/suggestions or get help.
|
35 |
+
|
36 |
+
## Results
|
37 |
+
|
38 |
+
![image info](./plots.png)
|
39 |
+
|
40 |
+
**Frequently Asked Questions**
|
41 |
+
- ***How does the compression work?*** The model is compressed with llm-int8.
|
42 |
+
- ***How does the model quality change?*** The quality of the model output might vary compared to the base model.
|
43 |
+
- ***How is the model efficiency evaluated?*** These results were obtained on HARDWARE_NAME with configuration described in `model/smash_config.json` and are obtained after a hardware warmup. The smashed model is directly compared to the original base model. Efficiency results may vary in other settings (e.g. other hardware, image size, batch size, ...). We recommend to directly run them in the use-case conditions to know if the smashed model can benefit you.
|
44 |
+
- ***What is the model format?*** We use safetensors.
|
45 |
+
- ***What calibration data has been used?*** If needed by the compression method, we used WikiText as the calibration data.
|
46 |
+
- ***What is the naming convention for Pruna Huggingface models?*** We take the original model name and append "turbo", "tiny", or "green" if the smashed model has a measured inference speed, inference memory, or inference energy consumption which is less than 90% of the original base model.
|
47 |
+
- ***How to compress my own models?*** You can request premium access to more compression methods and tech support for your specific use-cases [here](https://z0halsaff74.typeform.com/pruna-access?typeform-source=www.pruna.ai).
|
48 |
+
- ***What are "first" metrics?*** Results mentioning "first" are obtained after the first run of the model. The first run might take more memory or be slower than the subsequent runs due cuda overheads.
|
49 |
+
- ***What are "Sync" and "Async" metrics?*** "Sync" metrics are obtained by syncing all GPU processes and stop measurement when all of them are executed. "Async" metrics are obtained without syncing all GPU processes and stop when the model output can be used by the CPU. We provide both metrics since both could be relevant depending on the use-case. We recommend to test the efficiency gains directly in your use-cases.
|
50 |
+
|
51 |
+
## Setup
|
52 |
+
|
53 |
+
You can run the smashed model with these steps:
|
54 |
+
|
55 |
+
0. Check requirements from the original repo InstructPLM/MPNN-ProGen2-xlarge-CATH42 installed. In particular, check python, cuda, and transformers versions.
|
56 |
+
1. Make sure that you have installed quantization related packages.
|
57 |
+
```bash
|
58 |
+
pip install transformers accelerate bitsandbytes>0.37.0
|
59 |
+
```
|
60 |
+
2. Load & run the model.
|
61 |
+
```python
|
62 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
63 |
+
|
64 |
+
|
65 |
+
model = AutoModelForCausalLM.from_pretrained("PrunaAI/InstructPLM-MPNN-ProGen2-xlarge-CATH42-bnb-4bit-smashed", trust_remote_code=True, device_map='auto')
|
66 |
+
tokenizer = AutoTokenizer.from_pretrained("InstructPLM/MPNN-ProGen2-xlarge-CATH42")
|
67 |
+
|
68 |
+
input_ids = tokenizer("What is the color of prunes?,", return_tensors='pt').to(model.device)["input_ids"]
|
69 |
+
|
70 |
+
outputs = model.generate(input_ids, max_new_tokens=216)
|
71 |
+
tokenizer.decode(outputs[0])
|
72 |
+
```
|
73 |
+
|
74 |
+
## Configurations
|
75 |
+
|
76 |
+
The configuration info are in `smash_config.json`.
|
77 |
+
|
78 |
+
## Credits & License
|
79 |
+
|
80 |
+
The license of the smashed model follows the license of the original model. Please check the license of the original model InstructPLM/MPNN-ProGen2-xlarge-CATH42 before using this model which provided the base model. The license of the `pruna-engine` is [here](https://pypi.org/project/pruna-engine/) on Pypi.
|
81 |
+
|
82 |
+
## Want to compress other models?
|
83 |
+
|
84 |
+
- Contact us and tell us which model to compress next [here](https://www.pruna.ai/contact).
|
85 |
+
- Request access to easily compress your own AI models [here](https://z0halsaff74.typeform.com/pruna-access?typeform-source=www.pruna.ai).
|
config.json
ADDED
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_name_or_path": "/ceph/hdd/staff/charpent/.cache/modelsvz_9j4okrownmzr7",
|
3 |
+
"activation_function": "gelu_new",
|
4 |
+
"architectures": [
|
5 |
+
"ProGenForCausalLM"
|
6 |
+
],
|
7 |
+
"attn_pdrop": 0.0,
|
8 |
+
"auto_map": {
|
9 |
+
"AutoConfig": "configuration_progen.ProGenConfig",
|
10 |
+
"AutoModelForCausalLM": "modeling_InstructProGen.ProGenForCausalLM"
|
11 |
+
},
|
12 |
+
"bos_token_id": 1,
|
13 |
+
"embd_pdrop": 0.0,
|
14 |
+
"eos_token_id": 2,
|
15 |
+
"gradient_checkpointing": false,
|
16 |
+
"initializer_range": 0.02,
|
17 |
+
"layer_norm_epsilon": 1e-05,
|
18 |
+
"model_type": "progen",
|
19 |
+
"n_ctx": 2048,
|
20 |
+
"n_embd": 4096,
|
21 |
+
"n_head": 16,
|
22 |
+
"n_inner": null,
|
23 |
+
"n_layer": 32,
|
24 |
+
"n_positions": 1024,
|
25 |
+
"quantization_config": {
|
26 |
+
"_load_in_4bit": true,
|
27 |
+
"_load_in_8bit": false,
|
28 |
+
"bnb_4bit_compute_dtype": "bfloat16",
|
29 |
+
"bnb_4bit_quant_storage": "uint8",
|
30 |
+
"bnb_4bit_quant_type": "fp4",
|
31 |
+
"bnb_4bit_use_double_quant": false,
|
32 |
+
"llm_int8_enable_fp32_cpu_offload": false,
|
33 |
+
"llm_int8_has_fp16_weight": false,
|
34 |
+
"llm_int8_skip_modules": [
|
35 |
+
"lm_head"
|
36 |
+
],
|
37 |
+
"llm_int8_threshold": 6.0,
|
38 |
+
"load_in_4bit": true,
|
39 |
+
"load_in_8bit": false,
|
40 |
+
"quant_method": "bitsandbytes"
|
41 |
+
},
|
42 |
+
"resid_pdrop": 0.0,
|
43 |
+
"rotary_dim": 64,
|
44 |
+
"scale_attn_weights": true,
|
45 |
+
"structure": {
|
46 |
+
"embedding_keys": [
|
47 |
+
"mpnn_emb"
|
48 |
+
],
|
49 |
+
"max_seqlen": 512,
|
50 |
+
"n_queries": 256,
|
51 |
+
"num_heads": 16,
|
52 |
+
"output_dim": 4096,
|
53 |
+
"structure_emb_path_prefix": "./structure_embeddings",
|
54 |
+
"width": 1152
|
55 |
+
},
|
56 |
+
"summary_activation": null,
|
57 |
+
"summary_first_dropout": 0.1,
|
58 |
+
"summary_proj_to_labels": true,
|
59 |
+
"summary_type": "cls_index",
|
60 |
+
"summary_use_proj": true,
|
61 |
+
"task_specific_params": {
|
62 |
+
"text-generation": {
|
63 |
+
"do_sample": true,
|
64 |
+
"max_length": 50,
|
65 |
+
"temperature": 1.0
|
66 |
+
}
|
67 |
+
},
|
68 |
+
"tie_word_embeddings": false,
|
69 |
+
"tokenizer_type": "iPLMTokenizer",
|
70 |
+
"torch_dtype": "float16",
|
71 |
+
"transformers_version": "4.42.4",
|
72 |
+
"use_cache": true,
|
73 |
+
"vocab_size": 30
|
74 |
+
}
|
configuration_progen.py
ADDED
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2021 The EleutherAI and HuggingFace Teams. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
# Modified configuration implementation based on https://github.com/huggingface/transformers/blob/main/src/transformers/models/gptj/configuration_gptj.py
|
17 |
+
|
18 |
+
from transformers.configuration_utils import PretrainedConfig
|
19 |
+
from transformers.utils import logging
|
20 |
+
|
21 |
+
logger = logging.get_logger(__name__)
|
22 |
+
|
23 |
+
|
24 |
+
class ProGenConfig(PretrainedConfig):
|
25 |
+
model_type = "progen"
|
26 |
+
|
27 |
+
def __init__(
|
28 |
+
self,
|
29 |
+
vocab_size=50400,
|
30 |
+
n_positions=2048,
|
31 |
+
n_ctx=2048,
|
32 |
+
n_embd=4096,
|
33 |
+
n_layer=28,
|
34 |
+
n_head=16,
|
35 |
+
rotary_dim=64,
|
36 |
+
n_inner=None,
|
37 |
+
activation_function="gelu_new",
|
38 |
+
resid_pdrop=0.0,
|
39 |
+
embd_pdrop=0.0,
|
40 |
+
attn_pdrop=0.0,
|
41 |
+
layer_norm_epsilon=1e-5,
|
42 |
+
initializer_range=0.02,
|
43 |
+
scale_attn_weights=True,
|
44 |
+
gradient_checkpointing=False,
|
45 |
+
use_cache=True,
|
46 |
+
bos_token_id=50256,
|
47 |
+
eos_token_id=50256,
|
48 |
+
tie_word_embeddings=False,
|
49 |
+
**kwargs
|
50 |
+
):
|
51 |
+
super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, tie_word_embeddings=tie_word_embeddings, **kwargs)
|
52 |
+
|
53 |
+
self.vocab_size = vocab_size
|
54 |
+
self.n_ctx = n_ctx
|
55 |
+
self.n_positions = n_positions
|
56 |
+
self.n_embd = n_embd
|
57 |
+
self.n_layer = n_layer
|
58 |
+
self.n_head = n_head
|
59 |
+
self.n_inner = n_inner
|
60 |
+
self.rotary_dim = rotary_dim
|
61 |
+
self.activation_function = activation_function
|
62 |
+
self.resid_pdrop = resid_pdrop
|
63 |
+
self.embd_pdrop = embd_pdrop
|
64 |
+
self.attn_pdrop = attn_pdrop
|
65 |
+
self.layer_norm_epsilon = layer_norm_epsilon
|
66 |
+
self.initializer_range = initializer_range
|
67 |
+
self.gradient_checkpointing = gradient_checkpointing
|
68 |
+
self.scale_attn_weights = scale_attn_weights
|
69 |
+
self.use_cache = use_cache
|
70 |
+
|
71 |
+
self.bos_token_id = bos_token_id
|
72 |
+
self.eos_token_id = eos_token_id
|
73 |
+
self.tie_word_embeddings = tie_word_embeddings
|
74 |
+
|
75 |
+
@property
|
76 |
+
def max_position_embeddings(self):
|
77 |
+
return self.n_positions
|
78 |
+
|
79 |
+
@property
|
80 |
+
def hidden_size(self):
|
81 |
+
return self.n_embd
|
82 |
+
|
83 |
+
@property
|
84 |
+
def num_attention_heads(self):
|
85 |
+
return self.n_head
|
86 |
+
|
87 |
+
@property
|
88 |
+
def num_hidden_layers(self):
|
89 |
+
return self.n_layer
|
generation_config.json
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_from_model_config": true,
|
3 |
+
"bos_token_id": 1,
|
4 |
+
"eos_token_id": 2,
|
5 |
+
"transformers_version": "4.42.4"
|
6 |
+
}
|
modeling_InstructProGen.py
ADDED
@@ -0,0 +1,700 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Callable, List, Optional, Tuple, Union
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.utils.checkpoint
|
7 |
+
from torch import nn
|
8 |
+
from torch.nn import CrossEntropyLoss
|
9 |
+
|
10 |
+
from transformers.activations import ACT2FN
|
11 |
+
from transformers.generation.configuration_utils import GenerationConfig
|
12 |
+
from transformers.generation.logits_process import LogitsProcessorList
|
13 |
+
from transformers.generation.stopping_criteria import StoppingCriteriaList
|
14 |
+
from transformers.generation.streamers import BaseStreamer
|
15 |
+
from transformers.generation.utils import GenerateOutput
|
16 |
+
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
17 |
+
from transformers.modeling_utils import PreTrainedModel
|
18 |
+
from transformers.utils import logging
|
19 |
+
from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
|
20 |
+
from .configuration_progen import ProGenConfig
|
21 |
+
|
22 |
+
logger = logging.get_logger(__name__)
|
23 |
+
|
24 |
+
from .structure import StructureTransformer
|
25 |
+
|
26 |
+
|
27 |
+
def fixed_pos_embedding(x, seq_dim=1, seq_len=None):
|
28 |
+
dim = x.shape[-1]
|
29 |
+
if seq_len is None:
|
30 |
+
seq_len = x.shape[seq_dim]
|
31 |
+
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2) / dim))
|
32 |
+
sinusoid_inp = torch.einsum("i , j -> i j", torch.arange(seq_len), inv_freq).to(x.device).float()
|
33 |
+
return torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)
|
34 |
+
|
35 |
+
|
36 |
+
def rotate_every_two(x):
|
37 |
+
x1 = x[:, :, :, ::2]
|
38 |
+
x2 = x[:, :, :, 1::2]
|
39 |
+
x = torch.stack((-x2, x1), axis=-1)
|
40 |
+
return x.flatten(-2) # in einsum notation: rearrange(x, '... d j -> ... (d j)')
|
41 |
+
|
42 |
+
|
43 |
+
def apply_rotary_pos_emb(x, sincos, offset=0):
|
44 |
+
sin, cos = map(lambda t: t[None, offset : x.shape[1] + offset, None, :].repeat_interleave(2, 3), sincos)
|
45 |
+
# einsum notation for lambda t: repeat(t[offset:x.shape[1]+offset,:], "n d -> () n () (d j)", j=2)
|
46 |
+
return (x * cos) + (rotate_every_two(x) * sin)
|
47 |
+
|
48 |
+
|
49 |
+
class ProGenAttention(nn.Module):
|
50 |
+
def __init__(self, config):
|
51 |
+
super().__init__()
|
52 |
+
|
53 |
+
max_positions = config.max_position_embeddings
|
54 |
+
self.register_buffer(
|
55 |
+
"bias",
|
56 |
+
torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
|
57 |
+
1, 1, max_positions, max_positions
|
58 |
+
),
|
59 |
+
)
|
60 |
+
self.register_buffer("masked_bias", torch.tensor(-1e9))
|
61 |
+
|
62 |
+
self.attn_dropout = nn.Dropout(config.attn_pdrop)
|
63 |
+
self.resid_dropout = nn.Dropout(config.resid_pdrop)
|
64 |
+
|
65 |
+
self.embed_dim = config.hidden_size
|
66 |
+
self.num_attention_heads = config.num_attention_heads
|
67 |
+
self.head_dim = self.embed_dim // self.num_attention_heads
|
68 |
+
if self.head_dim * self.num_attention_heads != self.embed_dim:
|
69 |
+
raise ValueError(
|
70 |
+
f"embed_dim must be divisible by num_attention_heads (got `embed_dim`: {self.embed_dim} and `num_attention_heads`: {self.num_attention_heads})."
|
71 |
+
)
|
72 |
+
self.scale_attn = torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32)).to(torch.get_default_dtype())
|
73 |
+
self.qkv_proj = nn.Linear(self.embed_dim, self.embed_dim * 3, bias=False)
|
74 |
+
|
75 |
+
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
|
76 |
+
self.rotary_dim = None
|
77 |
+
if config.rotary_dim is not None:
|
78 |
+
self.rotary_dim = config.rotary_dim
|
79 |
+
|
80 |
+
def _split_heads(self, x, n_head, dim_head, mp_num):
|
81 |
+
reshaped = x.reshape(x.shape[:-1] + (n_head//mp_num, dim_head))
|
82 |
+
reshaped = reshaped.reshape(x.shape[:-2] + (-1, ) + reshaped.shape[-1:])
|
83 |
+
return reshaped
|
84 |
+
|
85 |
+
def _merge_heads(self, tensor, num_attention_heads, attn_head_size):
|
86 |
+
"""
|
87 |
+
Merges attn_head_size dim and num_attn_heads dim into n_ctx
|
88 |
+
"""
|
89 |
+
if len(tensor.shape) == 5:
|
90 |
+
tensor = tensor.permute(0, 1, 3, 2, 4).contiguous()
|
91 |
+
elif len(tensor.shape) == 4:
|
92 |
+
tensor = tensor.permute(0, 2, 1, 3).contiguous()
|
93 |
+
else:
|
94 |
+
raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {len(tensor.shape)}")
|
95 |
+
new_shape = tensor.size()[:-2] + (num_attention_heads * attn_head_size,)
|
96 |
+
return tensor.view(new_shape)
|
97 |
+
|
98 |
+
def _attn(
|
99 |
+
self,
|
100 |
+
query,
|
101 |
+
key,
|
102 |
+
value,
|
103 |
+
attention_mask=None,
|
104 |
+
head_mask=None,
|
105 |
+
):
|
106 |
+
|
107 |
+
# compute causal mask from causal mask buffer
|
108 |
+
query_length, key_length = query.size(-2), key.size(-2)
|
109 |
+
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
|
110 |
+
|
111 |
+
# Keep the attention weights computation in fp32 to avoid overflow issues
|
112 |
+
query = query.to(torch.float32)
|
113 |
+
key = key.to(torch.float32)
|
114 |
+
|
115 |
+
attn_weights = torch.matmul(query, key.transpose(-1, -2))
|
116 |
+
|
117 |
+
attn_weights = attn_weights / self.scale_attn
|
118 |
+
attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype))
|
119 |
+
|
120 |
+
if attention_mask is not None:
|
121 |
+
# Apply the attention mask
|
122 |
+
attn_weights = attn_weights + attention_mask
|
123 |
+
|
124 |
+
attn_weights = nn.Softmax(dim=-1)(attn_weights)
|
125 |
+
attn_weights = attn_weights.to(value.dtype)
|
126 |
+
attn_weights = self.attn_dropout(attn_weights)
|
127 |
+
|
128 |
+
# Mask heads if we want to
|
129 |
+
if head_mask is not None:
|
130 |
+
attn_weights = attn_weights * head_mask
|
131 |
+
|
132 |
+
attn_output = torch.matmul(attn_weights, value)
|
133 |
+
|
134 |
+
return attn_output, attn_weights
|
135 |
+
|
136 |
+
def forward(
|
137 |
+
self,
|
138 |
+
hidden_states,
|
139 |
+
attention_mask=None,
|
140 |
+
layer_past=None,
|
141 |
+
head_mask=None,
|
142 |
+
use_cache=False,
|
143 |
+
output_attentions=False,
|
144 |
+
):
|
145 |
+
|
146 |
+
qkv = self.qkv_proj(hidden_states)
|
147 |
+
# TODO(enijkamp): factor out number of logical TPU-v3/v4 cores or make forward pass agnostic
|
148 |
+
# mp_num = 4
|
149 |
+
mp_num = 8
|
150 |
+
qkv_split = qkv.reshape(qkv.shape[:-1] + (mp_num, -1))
|
151 |
+
|
152 |
+
local_dim = self.head_dim * self.num_attention_heads // mp_num
|
153 |
+
query, value, key = torch.split(qkv_split, local_dim, dim=-1)
|
154 |
+
query = self._split_heads(query, self.num_attention_heads, self.head_dim, mp_num=mp_num)
|
155 |
+
key = self._split_heads(key, self.num_attention_heads, self.head_dim, mp_num=mp_num)
|
156 |
+
|
157 |
+
value = self._split_heads(value, self.num_attention_heads, self.head_dim, mp_num=mp_num)
|
158 |
+
value = value.permute(0, 2, 1, 3)
|
159 |
+
|
160 |
+
seq_len = key.shape[1]
|
161 |
+
offset = 0
|
162 |
+
|
163 |
+
if layer_past is not None:
|
164 |
+
offset = layer_past[0].shape[-2]
|
165 |
+
seq_len += offset
|
166 |
+
|
167 |
+
if self.rotary_dim is not None:
|
168 |
+
k_rot = key[:, :, :, : self.rotary_dim]
|
169 |
+
k_pass = key[:, :, :, self.rotary_dim :]
|
170 |
+
|
171 |
+
q_rot = query[:, :, :, : self.rotary_dim]
|
172 |
+
q_pass = query[:, :, :, self.rotary_dim :]
|
173 |
+
|
174 |
+
sincos = fixed_pos_embedding(k_rot, 1, seq_len=seq_len)
|
175 |
+
k_rot = apply_rotary_pos_emb(k_rot, sincos, offset=offset)
|
176 |
+
q_rot = apply_rotary_pos_emb(q_rot, sincos, offset=offset)
|
177 |
+
|
178 |
+
key = torch.cat([k_rot, k_pass], dim=-1)
|
179 |
+
query = torch.cat([q_rot, q_pass], dim=-1)
|
180 |
+
else:
|
181 |
+
sincos = fixed_pos_embedding(key, 1, seq_len=seq_len)
|
182 |
+
key = apply_rotary_pos_emb(key, sincos, offset=offset)
|
183 |
+
query = apply_rotary_pos_emb(query, sincos, offset=offset)
|
184 |
+
|
185 |
+
key = key.permute(0, 2, 1, 3)
|
186 |
+
query = query.permute(0, 2, 1, 3)
|
187 |
+
|
188 |
+
if layer_past is not None:
|
189 |
+
past_key = layer_past[0]
|
190 |
+
past_value = layer_past[1]
|
191 |
+
key = torch.cat((past_key, key), dim=-2)
|
192 |
+
value = torch.cat((past_value, value), dim=-2)
|
193 |
+
|
194 |
+
if use_cache is True:
|
195 |
+
present = (key, value)
|
196 |
+
else:
|
197 |
+
present = None
|
198 |
+
|
199 |
+
# compute self-attention: V x Softmax(QK^T)
|
200 |
+
attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
|
201 |
+
|
202 |
+
attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_dim)
|
203 |
+
|
204 |
+
attn_output = self.out_proj(attn_output)
|
205 |
+
attn_output = self.resid_dropout(attn_output)
|
206 |
+
|
207 |
+
outputs = (attn_output, present)
|
208 |
+
if output_attentions:
|
209 |
+
outputs += (attn_weights,)
|
210 |
+
|
211 |
+
return outputs # a, present, (attentions)
|
212 |
+
|
213 |
+
|
214 |
+
class ProGenMLP(nn.Module):
|
215 |
+
def __init__(self, intermediate_size, config): # in MLP: intermediate_size= 4 * embed_dim
|
216 |
+
super().__init__()
|
217 |
+
embed_dim = config.n_embd
|
218 |
+
|
219 |
+
self.fc_in = nn.Linear(embed_dim, intermediate_size)
|
220 |
+
self.fc_out = nn.Linear(intermediate_size, embed_dim)
|
221 |
+
|
222 |
+
self.act = ACT2FN[config.activation_function]
|
223 |
+
self.dropout = nn.Dropout(config.resid_pdrop)
|
224 |
+
|
225 |
+
def forward(self, hidden_states):
|
226 |
+
hidden_states = self.fc_in(hidden_states)
|
227 |
+
hidden_states = self.act(hidden_states)
|
228 |
+
hidden_states = self.fc_out(hidden_states)
|
229 |
+
hidden_states = self.dropout(hidden_states)
|
230 |
+
return hidden_states
|
231 |
+
|
232 |
+
|
233 |
+
class ProGenBlock(nn.Module):
|
234 |
+
def __init__(self, config):
|
235 |
+
super().__init__()
|
236 |
+
inner_dim = config.n_inner if config.n_inner is not None else 4 * config.n_embd
|
237 |
+
self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
|
238 |
+
self.attn = ProGenAttention(config)
|
239 |
+
self.mlp = ProGenMLP(inner_dim, config)
|
240 |
+
|
241 |
+
def forward(
|
242 |
+
self,
|
243 |
+
hidden_states,
|
244 |
+
layer_past=None,
|
245 |
+
attention_mask=None,
|
246 |
+
head_mask=None,
|
247 |
+
use_cache=False,
|
248 |
+
output_attentions=False,
|
249 |
+
):
|
250 |
+
residual = hidden_states
|
251 |
+
hidden_states = self.ln_1(hidden_states)
|
252 |
+
attn_outputs = self.attn(
|
253 |
+
hidden_states,
|
254 |
+
layer_past=layer_past,
|
255 |
+
attention_mask=attention_mask,
|
256 |
+
head_mask=head_mask,
|
257 |
+
use_cache=use_cache,
|
258 |
+
output_attentions=output_attentions,
|
259 |
+
)
|
260 |
+
attn_output = attn_outputs[0] # output_attn: a, present, (attentions)
|
261 |
+
outputs = attn_outputs[1:]
|
262 |
+
|
263 |
+
feed_forward_hidden_states = self.mlp(hidden_states)
|
264 |
+
hidden_states = attn_output + feed_forward_hidden_states + residual
|
265 |
+
|
266 |
+
if use_cache:
|
267 |
+
outputs = (hidden_states,) + outputs
|
268 |
+
else:
|
269 |
+
outputs = (hidden_states,) + outputs[1:]
|
270 |
+
|
271 |
+
return outputs # hidden_states, present, (attentions)
|
272 |
+
|
273 |
+
|
274 |
+
class ProGenPreTrainedModel(PreTrainedModel):
|
275 |
+
"""
|
276 |
+
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
277 |
+
models.
|
278 |
+
"""
|
279 |
+
|
280 |
+
config_class = ProGenConfig
|
281 |
+
base_model_prefix = "transformer"
|
282 |
+
supports_gradient_checkpointing = True
|
283 |
+
is_parallelizable = True
|
284 |
+
|
285 |
+
def __init__(self, *inputs, **kwargs):
|
286 |
+
super().__init__(*inputs, **kwargs)
|
287 |
+
|
288 |
+
def _init_weights(self, module):
|
289 |
+
"""Initialize the weights."""
|
290 |
+
if isinstance(module, (nn.Linear,)):
|
291 |
+
# Slightly different from Mesh Transformer JAX which uses truncated_normal for initialization
|
292 |
+
# cf https://github.com/pytorch/pytorch/pull/5617
|
293 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
294 |
+
if module.bias is not None:
|
295 |
+
module.bias.data.zero_()
|
296 |
+
elif isinstance(module, nn.Embedding):
|
297 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
298 |
+
if module.padding_idx is not None:
|
299 |
+
module.weight.data[module.padding_idx].zero_()
|
300 |
+
elif isinstance(module, nn.LayerNorm):
|
301 |
+
module.bias.data.zero_()
|
302 |
+
module.weight.data.fill_(1.0)
|
303 |
+
|
304 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
305 |
+
if isinstance(module, ProGenModel):
|
306 |
+
module.gradient_checkpointing = value
|
307 |
+
|
308 |
+
class ProGenModel(ProGenPreTrainedModel):
|
309 |
+
def __init__(self, config):
|
310 |
+
super().__init__(config)
|
311 |
+
|
312 |
+
self.embed_dim = config.n_embd
|
313 |
+
self.vocab_size = config.vocab_size
|
314 |
+
self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
|
315 |
+
self.drop = nn.Dropout(config.embd_pdrop)
|
316 |
+
self.h = nn.ModuleList([ProGenBlock(config) for _ in range(config.n_layer)])
|
317 |
+
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
|
318 |
+
self.rotary_dim = min(config.rotary_dim, config.n_ctx // config.num_attention_heads)
|
319 |
+
|
320 |
+
self.gradient_checkpointing = False
|
321 |
+
self.structure = StructureTransformer(**config.structure)
|
322 |
+
|
323 |
+
self.init_weights()
|
324 |
+
|
325 |
+
# Model parallel
|
326 |
+
self.model_parallel = False
|
327 |
+
self.device_map = None
|
328 |
+
|
329 |
+
|
330 |
+
def parallelize(self, device_map=None):
|
331 |
+
# Check validity of device_map
|
332 |
+
self.device_map = (
|
333 |
+
get_device_map(len(self.h), range(torch.cuda.device_count())) if device_map is None else device_map
|
334 |
+
)
|
335 |
+
assert_device_map(self.device_map, len(self.h))
|
336 |
+
self.model_parallel = True
|
337 |
+
self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys()))
|
338 |
+
self.last_device = "cuda:" + str(max(self.device_map.keys()))
|
339 |
+
self.wte = self.wte.to(self.first_device)
|
340 |
+
# Load onto devices
|
341 |
+
for k, v in self.device_map.items():
|
342 |
+
for block in v:
|
343 |
+
cuda_device = "cuda:" + str(k)
|
344 |
+
self.h[block] = self.h[block].to(cuda_device)
|
345 |
+
# ln_f to last
|
346 |
+
self.ln_f = self.ln_f.to(self.last_device)
|
347 |
+
|
348 |
+
|
349 |
+
def deparallelize(self):
|
350 |
+
self.model_parallel = False
|
351 |
+
self.device_map = None
|
352 |
+
self.first_device = "cpu"
|
353 |
+
self.last_device = "cpu"
|
354 |
+
self.wte = self.wte.to("cpu")
|
355 |
+
for index in range(len(self.h)):
|
356 |
+
self.h[index] = self.h[index].to("cpu")
|
357 |
+
self.ln_f = self.ln_f.to("cpu")
|
358 |
+
torch.cuda.empty_cache()
|
359 |
+
|
360 |
+
def get_input_embeddings(self):
|
361 |
+
return self.wte
|
362 |
+
|
363 |
+
def set_input_embeddings(self, new_embeddings):
|
364 |
+
self.wte = new_embeddings
|
365 |
+
|
366 |
+
def forward(
|
367 |
+
self,
|
368 |
+
input_ids=None,
|
369 |
+
past_key_values=None,
|
370 |
+
attention_mask=None,
|
371 |
+
token_type_ids=None,
|
372 |
+
position_ids=None,
|
373 |
+
head_mask=None,
|
374 |
+
inputs_embeds=None,
|
375 |
+
query_embeds=None,
|
376 |
+
use_cache=None,
|
377 |
+
output_attentions=None,
|
378 |
+
output_hidden_states=None,
|
379 |
+
return_dict=None,
|
380 |
+
):
|
381 |
+
if past_key_values is None:
|
382 |
+
# structure encode will check if input_ids contains valid
|
383 |
+
structure_embs = self.structure.encode(input_ids)
|
384 |
+
if structure_embs is not None:
|
385 |
+
input_ids = input_ids[:, self.structure.n_queries:]
|
386 |
+
else:
|
387 |
+
structure_embs = None
|
388 |
+
|
389 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
390 |
+
output_hidden_states = (
|
391 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
392 |
+
)
|
393 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
394 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
395 |
+
|
396 |
+
if input_ids is not None and inputs_embeds is not None:
|
397 |
+
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
398 |
+
elif input_ids is not None:
|
399 |
+
input_shape = input_ids.size()
|
400 |
+
input_ids = input_ids.view(-1, input_shape[-1])
|
401 |
+
batch_size = input_ids.shape[0]
|
402 |
+
elif inputs_embeds is not None:
|
403 |
+
input_shape = inputs_embeds.size()[:-1]
|
404 |
+
batch_size = inputs_embeds.shape[0]
|
405 |
+
else:
|
406 |
+
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
407 |
+
|
408 |
+
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
409 |
+
|
410 |
+
# if token_type_ids is not None:
|
411 |
+
# token_type_ids = token_type_ids.view(-1, input_shape[-1])
|
412 |
+
|
413 |
+
if position_ids is not None:
|
414 |
+
position_ids = position_ids.view(-1, input_shape[-1])
|
415 |
+
|
416 |
+
if past_key_values is None:
|
417 |
+
past_length = 0
|
418 |
+
past_key_values = tuple([None] * len(self.h))
|
419 |
+
else:
|
420 |
+
past_length = past_key_values[0][0].size(-2)
|
421 |
+
|
422 |
+
if position_ids is None:
|
423 |
+
position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
|
424 |
+
position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
|
425 |
+
|
426 |
+
# Attention mask.
|
427 |
+
if attention_mask is not None:
|
428 |
+
assert batch_size > 0, "batch_size has to be defined and > 0"
|
429 |
+
attention_mask = attention_mask.view(batch_size, -1)
|
430 |
+
# We create a 3D attention mask from a 2D tensor mask.
|
431 |
+
# Sizes are [batch_size, 1, 1, to_seq_length]
|
432 |
+
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
|
433 |
+
# this attention mask is more simple than the triangular masking of causal attention
|
434 |
+
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
|
435 |
+
attention_mask = attention_mask[:, None, None, :]
|
436 |
+
|
437 |
+
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
438 |
+
# masked positions, this operation will create a tensor which is 0.0 for
|
439 |
+
# positions we want to attend and -10000.0 for masked positions.
|
440 |
+
# Since we are adding it to the raw scores before the softmax, this is
|
441 |
+
# effectively the same as removing these entirely.
|
442 |
+
attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
|
443 |
+
attention_mask = (1.0 - attention_mask) * -10000.0
|
444 |
+
|
445 |
+
# Prepare head mask if needed
|
446 |
+
# 1.0 in head_mask indicate we keep the head
|
447 |
+
# attention_probs has shape bsz x num_attention_heads x N x N
|
448 |
+
# head_mask has shape n_layer x batch x num_attention_heads x N x N
|
449 |
+
head_mask = self.get_head_mask(head_mask, self.config.n_layer)
|
450 |
+
|
451 |
+
if inputs_embeds is None:
|
452 |
+
inputs_embeds = self.wte(input_ids)
|
453 |
+
|
454 |
+
if query_embeds is not None:
|
455 |
+
inputs_embeds = torch.cat([query_embeds, inputs_embeds], dim=1)
|
456 |
+
input_shape = inputs_embeds.size()[:-1]
|
457 |
+
|
458 |
+
if structure_embs is not None:
|
459 |
+
inputs_embeds = torch.cat([structure_embs, inputs_embeds], dim=1)
|
460 |
+
input_shape = inputs_embeds.size()[:-1]
|
461 |
+
|
462 |
+
hidden_states = inputs_embeds
|
463 |
+
|
464 |
+
# disable token_type_ids
|
465 |
+
# if token_type_ids is not None:
|
466 |
+
# token_type_embeds = self.wte(token_type_ids)
|
467 |
+
# hidden_states = hidden_states + token_type_embeds
|
468 |
+
|
469 |
+
hidden_states = self.drop(hidden_states)
|
470 |
+
|
471 |
+
output_shape = input_shape + (hidden_states.size(-1),)
|
472 |
+
|
473 |
+
presents = () if use_cache else None
|
474 |
+
all_self_attentions = () if output_attentions else None
|
475 |
+
all_hidden_states = () if output_hidden_states else None
|
476 |
+
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
|
477 |
+
|
478 |
+
# Model parallel
|
479 |
+
if self.model_parallel:
|
480 |
+
torch.cuda.set_device(hidden_states.device)
|
481 |
+
# Ensure layer_past is on same device as hidden_states (might not be correct)
|
482 |
+
if layer_past is not None:
|
483 |
+
layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past)
|
484 |
+
# Ensure that attention_mask is always on the same device as hidden_states
|
485 |
+
if attention_mask is not None:
|
486 |
+
attention_mask = attention_mask.to(hidden_states.device)
|
487 |
+
if isinstance(head_mask, torch.Tensor):
|
488 |
+
head_mask = head_mask.to(hidden_states.device)
|
489 |
+
if output_hidden_states:
|
490 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
491 |
+
|
492 |
+
if self.gradient_checkpointing and self.training:
|
493 |
+
|
494 |
+
if use_cache:
|
495 |
+
# logger.warning(
|
496 |
+
# "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
|
497 |
+
# "`use_cache=False`..."
|
498 |
+
# )
|
499 |
+
use_cache = False
|
500 |
+
|
501 |
+
def create_custom_forward(module):
|
502 |
+
def custom_forward(*inputs):
|
503 |
+
# None for past_key_value
|
504 |
+
return module(*inputs, use_cache, output_attentions)
|
505 |
+
|
506 |
+
return custom_forward
|
507 |
+
|
508 |
+
outputs = torch.utils.checkpoint.checkpoint(
|
509 |
+
create_custom_forward(block),
|
510 |
+
hidden_states,
|
511 |
+
None,
|
512 |
+
attention_mask,
|
513 |
+
head_mask[i],
|
514 |
+
)
|
515 |
+
else:
|
516 |
+
outputs = block(
|
517 |
+
hidden_states,
|
518 |
+
layer_past=layer_past,
|
519 |
+
attention_mask=attention_mask,
|
520 |
+
head_mask=head_mask[i],
|
521 |
+
use_cache=use_cache,
|
522 |
+
output_attentions=output_attentions,
|
523 |
+
)
|
524 |
+
|
525 |
+
hidden_states = outputs[0]
|
526 |
+
if use_cache is True:
|
527 |
+
presents = presents + (outputs[1],)
|
528 |
+
|
529 |
+
if output_attentions:
|
530 |
+
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
|
531 |
+
|
532 |
+
# Model Parallel: If it's the last layer for that device, put things on the next device
|
533 |
+
if self.model_parallel:
|
534 |
+
for k, v in self.device_map.items():
|
535 |
+
if i == v[-1] and "cuda:" + str(k) != self.last_device:
|
536 |
+
hidden_states = hidden_states.to("cuda:" + str(k + 1))
|
537 |
+
|
538 |
+
hidden_states = self.ln_f(hidden_states)
|
539 |
+
|
540 |
+
hidden_states = hidden_states.view(*output_shape)
|
541 |
+
# Add last hidden state
|
542 |
+
if output_hidden_states:
|
543 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
544 |
+
|
545 |
+
if not return_dict:
|
546 |
+
return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
|
547 |
+
|
548 |
+
return BaseModelOutputWithPast(
|
549 |
+
last_hidden_state=hidden_states,
|
550 |
+
past_key_values=presents,
|
551 |
+
hidden_states=all_hidden_states,
|
552 |
+
attentions=all_self_attentions,
|
553 |
+
)
|
554 |
+
|
555 |
+
|
556 |
+
class ProGenForCausalLM(ProGenPreTrainedModel):
|
557 |
+
_keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"h\.\d+\.attn\.bias", r"lm_head\.weight"]
|
558 |
+
|
559 |
+
def __init__(self, config):
|
560 |
+
super().__init__(config)
|
561 |
+
self.transformer = ProGenModel(config)
|
562 |
+
self.lm_head = nn.Linear(config.n_embd, config.vocab_size)
|
563 |
+
self.init_weights()
|
564 |
+
|
565 |
+
# Model parallel
|
566 |
+
self.model_parallel = False
|
567 |
+
self.device_map = None
|
568 |
+
|
569 |
+
def parallelize(self, device_map=None):
|
570 |
+
self.device_map = (
|
571 |
+
get_device_map(len(self.transformer.h), range(torch.cuda.device_count()))
|
572 |
+
if device_map is None
|
573 |
+
else device_map
|
574 |
+
)
|
575 |
+
assert_device_map(self.device_map, len(self.transformer.h))
|
576 |
+
self.transformer.parallelize(self.device_map)
|
577 |
+
self.lm_head = self.lm_head.to(self.transformer.first_device)
|
578 |
+
self.model_parallel = True
|
579 |
+
|
580 |
+
def deparallelize(self):
|
581 |
+
self.transformer.deparallelize()
|
582 |
+
self.transformer = self.transformer.to("cpu")
|
583 |
+
self.lm_head = self.lm_head.to("cpu")
|
584 |
+
self.model_parallel = False
|
585 |
+
torch.cuda.empty_cache()
|
586 |
+
|
587 |
+
def get_output_embeddings(self):
|
588 |
+
return self.lm_head
|
589 |
+
|
590 |
+
def set_output_embeddings(self, new_embeddings):
|
591 |
+
self.lm_head = new_embeddings
|
592 |
+
|
593 |
+
def prepare_inputs_for_generation(
|
594 |
+
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
|
595 |
+
):
|
596 |
+
if past_key_values:
|
597 |
+
input_ids = input_ids[:, -1:]
|
598 |
+
|
599 |
+
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
600 |
+
if inputs_embeds is not None and past_key_values is None:
|
601 |
+
model_inputs = {"inputs_embeds": inputs_embeds}
|
602 |
+
else:
|
603 |
+
model_inputs = {"input_ids": input_ids}
|
604 |
+
|
605 |
+
model_inputs.update(
|
606 |
+
{
|
607 |
+
"past_key_values": past_key_values,
|
608 |
+
"use_cache": kwargs.get("use_cache"),
|
609 |
+
"attention_mask": attention_mask,
|
610 |
+
}
|
611 |
+
)
|
612 |
+
return model_inputs
|
613 |
+
|
614 |
+
def forward(
|
615 |
+
self,
|
616 |
+
input_ids=None,
|
617 |
+
past_key_values=None,
|
618 |
+
attention_mask=None,
|
619 |
+
token_type_ids=None,
|
620 |
+
position_ids=None,
|
621 |
+
head_mask=None,
|
622 |
+
inputs_embeds=None,
|
623 |
+
labels=None,
|
624 |
+
use_cache=None,
|
625 |
+
query_embeds = None,
|
626 |
+
output_attentions=None,
|
627 |
+
output_hidden_states=None,
|
628 |
+
return_dict=None,
|
629 |
+
):
|
630 |
+
r"""
|
631 |
+
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
632 |
+
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
|
633 |
+
``labels = input_ids`` Indices are selected in ``[-100, 0, ..., config.vocab_size]`` All labels set to
|
634 |
+
``-100`` are ignored (masked), the loss is only computed for labels in ``[0, ..., config.vocab_size]``
|
635 |
+
"""
|
636 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
637 |
+
|
638 |
+
transformer_outputs = self.transformer(
|
639 |
+
input_ids,
|
640 |
+
past_key_values=past_key_values,
|
641 |
+
attention_mask=attention_mask,
|
642 |
+
token_type_ids=token_type_ids,
|
643 |
+
position_ids=position_ids,
|
644 |
+
head_mask=head_mask,
|
645 |
+
inputs_embeds=inputs_embeds,
|
646 |
+
query_embeds=query_embeds,
|
647 |
+
use_cache=use_cache,
|
648 |
+
output_attentions=output_attentions,
|
649 |
+
output_hidden_states=output_hidden_states,
|
650 |
+
return_dict=return_dict,
|
651 |
+
)
|
652 |
+
hidden_states = transformer_outputs[0]
|
653 |
+
|
654 |
+
# Set device for model parallelism
|
655 |
+
if self.model_parallel:
|
656 |
+
torch.cuda.set_device(self.transformer.first_device)
|
657 |
+
hidden_states = hidden_states.to(self.lm_head.weight.device)
|
658 |
+
|
659 |
+
# make sure sampling in fp16 works correctly and
|
660 |
+
# compute loss in fp32 to match with mesh-tf version
|
661 |
+
# https://github.com/EleutherAI/gpt-neo/blob/89ce74164da2fb16179106f54e2269b5da8db333/models/gpt2/gpt2.py#L179
|
662 |
+
lm_logits = self.lm_head(hidden_states).to(torch.float32)
|
663 |
+
|
664 |
+
loss = None
|
665 |
+
if labels is not None:
|
666 |
+
# Shift so that tokens < n predict n
|
667 |
+
shift_logits = lm_logits[..., :-1, :].contiguous()
|
668 |
+
shift_labels = labels[..., 1:].contiguous()
|
669 |
+
# Flatten the tokens
|
670 |
+
loss_fct = CrossEntropyLoss()
|
671 |
+
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
|
672 |
+
|
673 |
+
loss = loss.to(hidden_states.dtype)
|
674 |
+
|
675 |
+
if not return_dict:
|
676 |
+
output = (lm_logits,) + transformer_outputs[1:]
|
677 |
+
return ((loss,) + output) if loss is not None else output
|
678 |
+
|
679 |
+
return CausalLMOutputWithPast(
|
680 |
+
loss=loss,
|
681 |
+
logits=lm_logits,
|
682 |
+
past_key_values=transformer_outputs.past_key_values,
|
683 |
+
hidden_states=transformer_outputs.hidden_states,
|
684 |
+
attentions=transformer_outputs.attentions,
|
685 |
+
)
|
686 |
+
|
687 |
+
@staticmethod
|
688 |
+
def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]:
|
689 |
+
"""
|
690 |
+
This function is used to re-order the :obj:`past_key_values` cache if
|
691 |
+
:meth:`~transformers.PretrainedModel.beam_search` or :meth:`~transformers.PretrainedModel.beam_sample` is
|
692 |
+
called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step.
|
693 |
+
"""
|
694 |
+
return tuple(
|
695 |
+
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
|
696 |
+
for layer_past in past
|
697 |
+
)
|
698 |
+
|
699 |
+
# def generate(self, inputs: Tensor | None = None, generation_config: GenerationConfig | None = None, logits_processor: LogitsProcessorList | None = None, stopping_criteria: StoppingCriteriaList | None = None, prefix_allowed_tokens_fn: Callable[[int, Tensor], List[int]] | None = None, synced_gpus: bool | None = None, assistant_model: PreTrainedModel | None = None, streamer: BaseStreamer | None = None, negative_prompt_ids: Tensor | None = None, negative_prompt_attention_mask: Tensor | None = None, **kwargs) -> GenerateOutput | LongTensor:
|
700 |
+
# return super().generate(inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, negative_prompt_ids, negative_prompt_attention_mask, **kwargs)
|
smash_config.json
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"api_key": null,
|
3 |
+
"verify_url": "http://johnrachwan.pythonanywhere.com",
|
4 |
+
"smash_config": {
|
5 |
+
"pruners": "None",
|
6 |
+
"pruning_ratio": 0.0,
|
7 |
+
"factorizers": "None",
|
8 |
+
"quantizers": "['llm-int8']",
|
9 |
+
"weight_quantization_bits": 4,
|
10 |
+
"output_deviation": 0.005,
|
11 |
+
"compilers": "None",
|
12 |
+
"static_batch": true,
|
13 |
+
"static_shape": true,
|
14 |
+
"controlnet": "None",
|
15 |
+
"unet_dim": 4,
|
16 |
+
"device": "cuda",
|
17 |
+
"cache_dir": "/ceph/hdd/staff/charpent/.cache/modelsvz_9j4ok",
|
18 |
+
"batch_size": 1,
|
19 |
+
"model_name": "InstructPLM/MPNN-ProGen2-xlarge-CATH42",
|
20 |
+
"task": "text_text_generation",
|
21 |
+
"max_batch_size": 1,
|
22 |
+
"qtype_weight": "torch.qint8",
|
23 |
+
"qtype_activation": "torch.quint8",
|
24 |
+
"qobserver": "<class 'torch.ao.quantization.observer.MinMaxObserver'>",
|
25 |
+
"qscheme": "torch.per_tensor_symmetric",
|
26 |
+
"qconfig": "x86",
|
27 |
+
"group_size": 128,
|
28 |
+
"damp_percent": 0.1,
|
29 |
+
"save_load_fn": "bitsandbytes"
|
30 |
+
}
|
31 |
+
}
|
special_tokens_map.json
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"pad_token": {
|
3 |
+
"content": "<|pad|>",
|
4 |
+
"lstrip": false,
|
5 |
+
"normalized": false,
|
6 |
+
"rstrip": false,
|
7 |
+
"single_word": false
|
8 |
+
}
|
9 |
+
}
|
structure.py
ADDED
@@ -0,0 +1,287 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Alibaba Cloud.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
from collections import OrderedDict
|
7 |
+
import math
|
8 |
+
import requests
|
9 |
+
from io import BytesIO
|
10 |
+
from functools import partial
|
11 |
+
import pickle
|
12 |
+
from typing import Callable, Optional, Sequence, Tuple, List
|
13 |
+
import numpy as np
|
14 |
+
import os
|
15 |
+
import torch
|
16 |
+
from torch import nn
|
17 |
+
from torch.nn import functional as F
|
18 |
+
from torch.nn.init import trunc_normal_
|
19 |
+
from torchvision import transforms
|
20 |
+
from torchvision.transforms import InterpolationMode
|
21 |
+
|
22 |
+
class GLU(nn.Module):
|
23 |
+
def __init__(self,hidden_size):
|
24 |
+
super().__init__()
|
25 |
+
self.linear_proj = nn.Linear(hidden_size,hidden_size,bias=False)
|
26 |
+
self.norm1 = nn.LayerNorm(hidden_size)
|
27 |
+
self.act1 = nn.GELU()
|
28 |
+
self.act2 = nn.functional.silu
|
29 |
+
self.dense_h_to_4h = nn.Linear(hidden_size,hidden_size*4,bias=False)
|
30 |
+
self.gate_proj = nn.Linear(hidden_size,hidden_size*4,bias=False)
|
31 |
+
self.dense_4h_to_h = nn.Linear(hidden_size*4,hidden_size,bias=False)
|
32 |
+
|
33 |
+
def forward(self,x):
|
34 |
+
x = self.linear_proj(x)
|
35 |
+
x = self.act1(self.norm1(x))
|
36 |
+
x = self.act2(self.gate_proj(x))*self.dense_h_to_4h(x)
|
37 |
+
x = self.dense_4h_to_h(x)
|
38 |
+
return x
|
39 |
+
def swiglu(x):
|
40 |
+
x = torch.chunk(x, 2, dim=-1)
|
41 |
+
return nn.functional.silu(x[0]) * x[1]
|
42 |
+
|
43 |
+
class GLU_new(nn.Module):
|
44 |
+
def __init__(self,hidden_size, dropout=0.1):
|
45 |
+
super().__init__()
|
46 |
+
intermediate_size = int((4 * hidden_size * 2 / 3) / 64) * 64
|
47 |
+
intermediate_size = 1280
|
48 |
+
|
49 |
+
self.act = swiglu
|
50 |
+
self.dense_h_to_4h = nn.Linear(hidden_size, intermediate_size * 2, bias=False)
|
51 |
+
self.dense_4h_to_h = nn.Linear(intermediate_size, hidden_size, bias=False)
|
52 |
+
self.dropout = nn.Dropout(p=dropout)
|
53 |
+
|
54 |
+
def forward(self,x):
|
55 |
+
x = self.dense_h_to_4h(x)
|
56 |
+
x = self.act(x)
|
57 |
+
x = self.dense_4h_to_h(x)
|
58 |
+
x = self.dropout(x)
|
59 |
+
return x
|
60 |
+
|
61 |
+
|
62 |
+
n_queries = 32
|
63 |
+
def get_abs_pos(abs_pos, tgt_size):
|
64 |
+
# abs_pos: L, C
|
65 |
+
# tgt_size: M
|
66 |
+
# return: M, C
|
67 |
+
src_size = int(math.sqrt(abs_pos.size(0)))
|
68 |
+
tgt_size = int(math.sqrt(tgt_size))
|
69 |
+
dtype = abs_pos.dtype
|
70 |
+
|
71 |
+
if src_size != tgt_size:
|
72 |
+
return F.interpolate(
|
73 |
+
abs_pos.float().reshape(1, src_size, src_size, -1).permute(0, 3, 1, 2),
|
74 |
+
size=(tgt_size, tgt_size),
|
75 |
+
mode="bicubic",
|
76 |
+
align_corners=False,
|
77 |
+
).permute(0, 2, 3, 1).flatten(0, 2).to(dtype=dtype)
|
78 |
+
else:
|
79 |
+
return abs_pos
|
80 |
+
|
81 |
+
from einops import rearrange, repeat
|
82 |
+
|
83 |
+
def get_1d_sincos_pos_embed(embed_dim, pos):
|
84 |
+
"""
|
85 |
+
embed_dim: output dimension for each position
|
86 |
+
pos: a list of positions to be encoded: size (M,)
|
87 |
+
out: (M, D)
|
88 |
+
"""
|
89 |
+
assert embed_dim % 2 == 0
|
90 |
+
omega = np.arange(embed_dim // 2, dtype=np.float32)
|
91 |
+
omega /= embed_dim / 2.
|
92 |
+
omega = 1. / 10000**omega # (D/2,)
|
93 |
+
|
94 |
+
pos = pos.reshape(-1) # (M,)
|
95 |
+
out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
|
96 |
+
|
97 |
+
emb_sin = np.sin(out) # (M, D/2)
|
98 |
+
emb_cos = np.cos(out) # (M, D/2)
|
99 |
+
|
100 |
+
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
|
101 |
+
return emb
|
102 |
+
|
103 |
+
class Resampler(nn.Module):
|
104 |
+
def __init__(
|
105 |
+
self,
|
106 |
+
kv_dim,
|
107 |
+
embed_dim,
|
108 |
+
num_heads=8,
|
109 |
+
n_queries=64,
|
110 |
+
max_seqlen=1024,
|
111 |
+
perceiver_resampler_positional_emb=True,
|
112 |
+
use_GLU=False,
|
113 |
+
bos_init=False,
|
114 |
+
dropout=0.0
|
115 |
+
):
|
116 |
+
super().__init__()
|
117 |
+
self.perceiver_resampler_positional_emb = perceiver_resampler_positional_emb
|
118 |
+
|
119 |
+
if self.perceiver_resampler_positional_emb:
|
120 |
+
assert n_queries <= max_seqlen
|
121 |
+
self.stride = max_seqlen // n_queries
|
122 |
+
# self.nan_emb = nn.Parameter(torch.randn(1, kv_dim))
|
123 |
+
# nn.init.trunc_normal_(self.nan_emb, std=.02)
|
124 |
+
pos = np.arange(max_seqlen, dtype=np.float32)
|
125 |
+
self.register_buffer(
|
126 |
+
"pos_embed",
|
127 |
+
torch.from_numpy(get_1d_sincos_pos_embed(embed_dim, pos)).float()
|
128 |
+
)
|
129 |
+
self.latents = nn.Parameter(torch.randn(n_queries, embed_dim))
|
130 |
+
if bos_init:
|
131 |
+
self.latents.load('')
|
132 |
+
else:
|
133 |
+
nn.init.trunc_normal_(self.latents, std=1e-3)
|
134 |
+
|
135 |
+
self.kv_proj = nn.Linear(kv_dim, embed_dim, bias=False)
|
136 |
+
self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True, dropout=dropout)
|
137 |
+
self.ln_q = nn.LayerNorm(embed_dim)
|
138 |
+
self.ln_kv = nn.LayerNorm(embed_dim)
|
139 |
+
self.ln_post = nn.LayerNorm(embed_dim)
|
140 |
+
if use_GLU:
|
141 |
+
print('GLU *********************************')
|
142 |
+
self.proj = GLU_new(embed_dim, dropout=dropout)
|
143 |
+
else:
|
144 |
+
self.proj = nn.Linear(embed_dim, embed_dim, bias=False)
|
145 |
+
|
146 |
+
self.apply(self._init_weights)
|
147 |
+
|
148 |
+
def _init_weights(self, m):
|
149 |
+
if isinstance(m, nn.Linear):
|
150 |
+
nn.init.trunc_normal_(m.weight, std=1e-3)
|
151 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
152 |
+
nn.init.constant_(m.bias, 0)
|
153 |
+
elif isinstance(m, nn.LayerNorm):
|
154 |
+
nn.init.constant_(m.bias, 0)
|
155 |
+
nn.init.constant_(m.weight, 1.0)
|
156 |
+
|
157 |
+
def forward(self, struc_x):
|
158 |
+
"""
|
159 |
+
Args:
|
160 |
+
x (torch.Tensor): protein structure features
|
161 |
+
shape (B, L, C)
|
162 |
+
Returns:
|
163 |
+
shape (B, n, C) where n is self.num_latents
|
164 |
+
"""
|
165 |
+
x = struc_x["encoder_out"]
|
166 |
+
mask = struc_x["encoder_padding_mask"]
|
167 |
+
|
168 |
+
|
169 |
+
nan_mask = torch.isnan(x)
|
170 |
+
if nan_mask.any():
|
171 |
+
x = x.masked_fill(nan_mask, 0.0)
|
172 |
+
# nan_mask = nan_mask.sum(dim=-1).bool()
|
173 |
+
# x[nan_mask] += self.nan_emb
|
174 |
+
|
175 |
+
x = self.kv_proj(x)
|
176 |
+
x = self.ln_kv(x)
|
177 |
+
|
178 |
+
b, seqlen = x.shape[:2]
|
179 |
+
|
180 |
+
latents = self.ln_q(self.latents)
|
181 |
+
if self.perceiver_resampler_positional_emb:
|
182 |
+
# TODO: interpolate
|
183 |
+
latents = latents + self.pos_embed[::self.stride].contiguous()
|
184 |
+
pos_emb = self.pos_embed[:seqlen].unsqueeze(0)
|
185 |
+
x = x + pos_emb.contiguous()
|
186 |
+
|
187 |
+
# blocks
|
188 |
+
latents = repeat(latents, "n d -> b n d", b=b)
|
189 |
+
out = self.attn(latents, x, x, key_padding_mask=~mask)[0]
|
190 |
+
|
191 |
+
out = self.ln_post(out)
|
192 |
+
out = self.proj(out)
|
193 |
+
|
194 |
+
return out
|
195 |
+
|
196 |
+
class StructureTransformer(nn.Module):
|
197 |
+
|
198 |
+
def __init__(
|
199 |
+
self,
|
200 |
+
width: int = 640,
|
201 |
+
n_queries: int = 32,
|
202 |
+
output_dim: int = 4096,
|
203 |
+
embedding_keys=set(["mpnn_emb"]),
|
204 |
+
max_seqlen: int=1024,
|
205 |
+
num_heads: int=8,
|
206 |
+
structure_emb_path_prefix='structure_emb',
|
207 |
+
**kwargs
|
208 |
+
):
|
209 |
+
super().__init__()
|
210 |
+
|
211 |
+
self.structure_emb_path_prefix = structure_emb_path_prefix
|
212 |
+
# self.transformer = None # replace None with a pretrained strucure encoder
|
213 |
+
self.embedding_keys = embedding_keys
|
214 |
+
self.max_seqlen = max_seqlen
|
215 |
+
self.width = width
|
216 |
+
self.n_queries = n_queries
|
217 |
+
|
218 |
+
self.attn_pool = Resampler(
|
219 |
+
embed_dim=output_dim,
|
220 |
+
kv_dim=width,
|
221 |
+
n_queries=n_queries,
|
222 |
+
max_seqlen=max_seqlen,
|
223 |
+
num_heads=num_heads,
|
224 |
+
**kwargs
|
225 |
+
)
|
226 |
+
|
227 |
+
def prepare_structure(self, sample):
|
228 |
+
emb_pad = torch.zeros((self.max_seqlen, self.width))
|
229 |
+
emb_mask = torch.zeros((self.max_seqlen), dtype=bool)
|
230 |
+
|
231 |
+
if "pifold_emb" in self.embedding_keys and "pifold_mask" in sample:
|
232 |
+
mask = sample["pifold_mask"]
|
233 |
+
pifold_emb = sample["pifold_emb"]
|
234 |
+
new_pifold_emb = pifold_emb.new_zeros(mask.shape[0], pifold_emb.shape[1]).fill_(float("nan"))
|
235 |
+
new_pifold_emb[mask > 0] = pifold_emb
|
236 |
+
sample["pifold_emb"] = new_pifold_emb
|
237 |
+
|
238 |
+
### domians ###
|
239 |
+
emb = []
|
240 |
+
for ek in self.embedding_keys:
|
241 |
+
if ek in sample:
|
242 |
+
if isinstance( sample[ek], List):
|
243 |
+
emb.append(torch.cat(sample[ek]))
|
244 |
+
else:
|
245 |
+
emb.append(sample[ek])
|
246 |
+
# emb = [sample[ek] for ek in self.embedding_keys if ek in sample]
|
247 |
+
emb = torch.cat(emb, dim=-1)
|
248 |
+
|
249 |
+
emb_pad[:len(emb)] = emb
|
250 |
+
emb_mask[:len(emb)] = 1
|
251 |
+
return emb_pad, emb_mask
|
252 |
+
|
253 |
+
def forward(self, x):
|
254 |
+
|
255 |
+
# x = self.transformer(x)
|
256 |
+
x = self.attn_pool(x)
|
257 |
+
|
258 |
+
return x
|
259 |
+
|
260 |
+
def encode(self, structure_paths: List[str]):
|
261 |
+
structure_embs = []
|
262 |
+
structure_mask = []
|
263 |
+
|
264 |
+
for structure_path in structure_paths:
|
265 |
+
structure_path = [chr(s) for s in structure_path[:self.n_queries].tolist() if s > 0]
|
266 |
+
structure_path = os.path.join(self.structure_emb_path_prefix, ''.join(structure_path))
|
267 |
+
if not os.path.exists(structure_path):
|
268 |
+
print('no structure found')
|
269 |
+
return None
|
270 |
+
|
271 |
+
with open(structure_path, 'rb') as f:
|
272 |
+
structure, struc_mask = self.prepare_structure(pickle.load(f))
|
273 |
+
|
274 |
+
|
275 |
+
structure_embs.append(structure)
|
276 |
+
structure_mask.append(struc_mask)
|
277 |
+
|
278 |
+
structure_embs = torch.stack(structure_embs, dim=0).to(
|
279 |
+
device=next(self.attn_pool.parameters()).device,
|
280 |
+
dtype=next(self.attn_pool.parameters()).dtype)
|
281 |
+
structure_mask = torch.stack(structure_mask, dim=0).to(
|
282 |
+
device=next(self.attn_pool.parameters()).device)
|
283 |
+
|
284 |
+
return self({
|
285 |
+
'encoder_out': structure_embs,
|
286 |
+
'encoder_padding_mask': structure_mask
|
287 |
+
})
|
tokenizer.json
ADDED
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"version": "1.0",
|
3 |
+
"truncation": null,
|
4 |
+
"padding": null,
|
5 |
+
"added_tokens": [
|
6 |
+
{
|
7 |
+
"id": 0,
|
8 |
+
"content": "<|pad|>",
|
9 |
+
"single_word": false,
|
10 |
+
"lstrip": false,
|
11 |
+
"rstrip": false,
|
12 |
+
"normalized": false,
|
13 |
+
"special": true
|
14 |
+
},
|
15 |
+
{
|
16 |
+
"id": 1,
|
17 |
+
"content": "<|bos|>",
|
18 |
+
"single_word": false,
|
19 |
+
"lstrip": false,
|
20 |
+
"rstrip": false,
|
21 |
+
"normalized": false,
|
22 |
+
"special": true
|
23 |
+
},
|
24 |
+
{
|
25 |
+
"id": 2,
|
26 |
+
"content": "<|eos|>",
|
27 |
+
"single_word": false,
|
28 |
+
"lstrip": false,
|
29 |
+
"rstrip": false,
|
30 |
+
"normalized": false,
|
31 |
+
"special": true
|
32 |
+
}
|
33 |
+
],
|
34 |
+
"normalizer": null,
|
35 |
+
"pre_tokenizer": {
|
36 |
+
"type": "ByteLevel",
|
37 |
+
"add_prefix_space": false,
|
38 |
+
"trim_offsets": true,
|
39 |
+
"use_regex": true
|
40 |
+
},
|
41 |
+
"post_processor": {
|
42 |
+
"type": "ByteLevel",
|
43 |
+
"add_prefix_space": true,
|
44 |
+
"trim_offsets": true,
|
45 |
+
"use_regex": true
|
46 |
+
},
|
47 |
+
"decoder": {
|
48 |
+
"type": "ByteLevel",
|
49 |
+
"add_prefix_space": true,
|
50 |
+
"trim_offsets": true,
|
51 |
+
"use_regex": true
|
52 |
+
},
|
53 |
+
"model": {
|
54 |
+
"type": "BPE",
|
55 |
+
"dropout": null,
|
56 |
+
"unk_token": null,
|
57 |
+
"continuing_subword_prefix": null,
|
58 |
+
"end_of_word_suffix": null,
|
59 |
+
"fuse_unk": false,
|
60 |
+
"byte_fallback": false,
|
61 |
+
"ignore_merges": false,
|
62 |
+
"vocab": {
|
63 |
+
"<|pad|>": 0,
|
64 |
+
"<|bos|>": 1,
|
65 |
+
"<|eos|>": 2,
|
66 |
+
"1": 3,
|
67 |
+
"2": 4,
|
68 |
+
"A": 5,
|
69 |
+
"B": 6,
|
70 |
+
"C": 7,
|
71 |
+
"D": 8,
|
72 |
+
"E": 9,
|
73 |
+
"F": 10,
|
74 |
+
"G": 11,
|
75 |
+
"H": 12,
|
76 |
+
"I": 13,
|
77 |
+
"K": 14,
|
78 |
+
"L": 15,
|
79 |
+
"M": 16,
|
80 |
+
"N": 17,
|
81 |
+
"O": 18,
|
82 |
+
"P": 19,
|
83 |
+
"Q": 20,
|
84 |
+
"R": 21,
|
85 |
+
"S": 22,
|
86 |
+
"T": 23,
|
87 |
+
"U": 24,
|
88 |
+
"V": 25,
|
89 |
+
"W": 26,
|
90 |
+
"X": 27,
|
91 |
+
"Y": 28,
|
92 |
+
"Z": 29
|
93 |
+
},
|
94 |
+
"merges": []
|
95 |
+
}
|
96 |
+
}
|
tokenizer_config.json
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"added_tokens_decoder": {
|
3 |
+
"0": {
|
4 |
+
"content": "<|pad|>",
|
5 |
+
"lstrip": false,
|
6 |
+
"normalized": false,
|
7 |
+
"rstrip": false,
|
8 |
+
"single_word": false,
|
9 |
+
"special": true
|
10 |
+
},
|
11 |
+
"1": {
|
12 |
+
"content": "<|bos|>",
|
13 |
+
"lstrip": false,
|
14 |
+
"normalized": false,
|
15 |
+
"rstrip": false,
|
16 |
+
"single_word": false,
|
17 |
+
"special": true
|
18 |
+
},
|
19 |
+
"2": {
|
20 |
+
"content": "<|eos|>",
|
21 |
+
"lstrip": false,
|
22 |
+
"normalized": false,
|
23 |
+
"rstrip": false,
|
24 |
+
"single_word": false,
|
25 |
+
"special": true
|
26 |
+
}
|
27 |
+
},
|
28 |
+
"auto_map": {
|
29 |
+
"AutoTokenizer": [
|
30 |
+
"InstructPLM/MPNN-ProGen2-xlarge-CATH42--tokenization_iPLM.iPLMTokenizer",
|
31 |
+
null
|
32 |
+
]
|
33 |
+
},
|
34 |
+
"clean_up_tokenization_spaces": true,
|
35 |
+
"legacy": false,
|
36 |
+
"model_max_length": 1000000000000000019884624838656,
|
37 |
+
"pad_token": "<|pad|>",
|
38 |
+
"tokenizer_class": "iPLMTokenizer"
|
39 |
+
}
|