sharpenb commited on
Commit
2d49faf
1 Parent(s): d649e10

0379693d73141b4f796eb2f0edd3bebcf86c950fdac574056f6f09445e445196

Browse files
README.md ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ thumbnail: "https://assets-global.website-files.com/646b351987a8d8ce158d1940/64ec9e96b4334c0e1ac41504_Logo%20with%20white%20text.svg"
3
+ base_model: InstructPLM/MPNN-ProGen2-xlarge-CATH42
4
+ metrics:
5
+ - memory_disk
6
+ - memory_inference
7
+ - inference_latency
8
+ - inference_throughput
9
+ - inference_CO2_emissions
10
+ - inference_energy_consumption
11
+ tags:
12
+ - pruna-ai
13
+ ---
14
+ <!-- header start -->
15
+ <!-- 200823 -->
16
+ <div style="width: auto; margin-left: auto; margin-right: auto">
17
+ <a href="https://www.pruna.ai/" target="_blank" rel="noopener noreferrer">
18
+ <img src="https://i.imgur.com/eDAlcgk.png" alt="PrunaAI" style="width: 100%; min-width: 400px; display: block; margin: auto;">
19
+ </a>
20
+ </div>
21
+ <!-- header end -->
22
+
23
+ [![Twitter](https://img.shields.io/twitter/follow/PrunaAI?style=social)](https://twitter.com/PrunaAI)
24
+ [![GitHub](https://img.shields.io/github/followers/PrunaAI?label=Follow%20%40PrunaAI&style=social)](https://github.com/PrunaAI)
25
+ [![LinkedIn](https://img.shields.io/badge/LinkedIn-Connect-blue)](https://www.linkedin.com/company/93832878/admin/feed/posts/?feedType=following)
26
+ [![Discord](https://img.shields.io/badge/Discord-Join%20Us-blue?style=social&logo=discord)](https://discord.gg/rskEr4BZJx)
27
+
28
+ # Simply make AI models cheaper, smaller, faster, and greener!
29
+
30
+ - Give a thumbs up if you like this model!
31
+ - Contact us and tell us which model to compress next [here](https://www.pruna.ai/contact).
32
+ - Request access to easily compress your *own* AI models [here](https://z0halsaff74.typeform.com/pruna-access?typeform-source=www.pruna.ai).
33
+ - Read the documentations to know more [here](https://pruna-ai-pruna.readthedocs-hosted.com/en/latest/)
34
+ - Join Pruna AI community on Discord [here](https://discord.gg/CP4VSgck) to share feedback/suggestions or get help.
35
+
36
+ ## Results
37
+
38
+ ![image info](./plots.png)
39
+
40
+ **Frequently Asked Questions**
41
+ - ***How does the compression work?*** The model is compressed with llm-int8.
42
+ - ***How does the model quality change?*** The quality of the model output might vary compared to the base model.
43
+ - ***How is the model efficiency evaluated?*** These results were obtained on HARDWARE_NAME with configuration described in `model/smash_config.json` and are obtained after a hardware warmup. The smashed model is directly compared to the original base model. Efficiency results may vary in other settings (e.g. other hardware, image size, batch size, ...). We recommend to directly run them in the use-case conditions to know if the smashed model can benefit you.
44
+ - ***What is the model format?*** We use safetensors.
45
+ - ***What calibration data has been used?*** If needed by the compression method, we used WikiText as the calibration data.
46
+ - ***What is the naming convention for Pruna Huggingface models?*** We take the original model name and append "turbo", "tiny", or "green" if the smashed model has a measured inference speed, inference memory, or inference energy consumption which is less than 90% of the original base model.
47
+ - ***How to compress my own models?*** You can request premium access to more compression methods and tech support for your specific use-cases [here](https://z0halsaff74.typeform.com/pruna-access?typeform-source=www.pruna.ai).
48
+ - ***What are "first" metrics?*** Results mentioning "first" are obtained after the first run of the model. The first run might take more memory or be slower than the subsequent runs due cuda overheads.
49
+ - ***What are "Sync" and "Async" metrics?*** "Sync" metrics are obtained by syncing all GPU processes and stop measurement when all of them are executed. "Async" metrics are obtained without syncing all GPU processes and stop when the model output can be used by the CPU. We provide both metrics since both could be relevant depending on the use-case. We recommend to test the efficiency gains directly in your use-cases.
50
+
51
+ ## Setup
52
+
53
+ You can run the smashed model with these steps:
54
+
55
+ 0. Check requirements from the original repo InstructPLM/MPNN-ProGen2-xlarge-CATH42 installed. In particular, check python, cuda, and transformers versions.
56
+ 1. Make sure that you have installed quantization related packages.
57
+ ```bash
58
+ pip install transformers accelerate bitsandbytes>0.37.0
59
+ ```
60
+ 2. Load & run the model.
61
+ ```python
62
+ from transformers import AutoModelForCausalLM, AutoTokenizer
63
+
64
+
65
+ model = AutoModelForCausalLM.from_pretrained("PrunaAI/InstructPLM-MPNN-ProGen2-xlarge-CATH42-bnb-4bit-smashed", trust_remote_code=True, device_map='auto')
66
+ tokenizer = AutoTokenizer.from_pretrained("InstructPLM/MPNN-ProGen2-xlarge-CATH42")
67
+
68
+ input_ids = tokenizer("What is the color of prunes?,", return_tensors='pt').to(model.device)["input_ids"]
69
+
70
+ outputs = model.generate(input_ids, max_new_tokens=216)
71
+ tokenizer.decode(outputs[0])
72
+ ```
73
+
74
+ ## Configurations
75
+
76
+ The configuration info are in `smash_config.json`.
77
+
78
+ ## Credits & License
79
+
80
+ The license of the smashed model follows the license of the original model. Please check the license of the original model InstructPLM/MPNN-ProGen2-xlarge-CATH42 before using this model which provided the base model. The license of the `pruna-engine` is [here](https://pypi.org/project/pruna-engine/) on Pypi.
81
+
82
+ ## Want to compress other models?
83
+
84
+ - Contact us and tell us which model to compress next [here](https://www.pruna.ai/contact).
85
+ - Request access to easily compress your own AI models [here](https://z0halsaff74.typeform.com/pruna-access?typeform-source=www.pruna.ai).
config.json ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "/ceph/hdd/staff/charpent/.cache/modelsvz_9j4okrownmzr7",
3
+ "activation_function": "gelu_new",
4
+ "architectures": [
5
+ "ProGenForCausalLM"
6
+ ],
7
+ "attn_pdrop": 0.0,
8
+ "auto_map": {
9
+ "AutoConfig": "configuration_progen.ProGenConfig",
10
+ "AutoModelForCausalLM": "modeling_InstructProGen.ProGenForCausalLM"
11
+ },
12
+ "bos_token_id": 1,
13
+ "embd_pdrop": 0.0,
14
+ "eos_token_id": 2,
15
+ "gradient_checkpointing": false,
16
+ "initializer_range": 0.02,
17
+ "layer_norm_epsilon": 1e-05,
18
+ "model_type": "progen",
19
+ "n_ctx": 2048,
20
+ "n_embd": 4096,
21
+ "n_head": 16,
22
+ "n_inner": null,
23
+ "n_layer": 32,
24
+ "n_positions": 1024,
25
+ "quantization_config": {
26
+ "_load_in_4bit": true,
27
+ "_load_in_8bit": false,
28
+ "bnb_4bit_compute_dtype": "bfloat16",
29
+ "bnb_4bit_quant_storage": "uint8",
30
+ "bnb_4bit_quant_type": "fp4",
31
+ "bnb_4bit_use_double_quant": false,
32
+ "llm_int8_enable_fp32_cpu_offload": false,
33
+ "llm_int8_has_fp16_weight": false,
34
+ "llm_int8_skip_modules": [
35
+ "lm_head"
36
+ ],
37
+ "llm_int8_threshold": 6.0,
38
+ "load_in_4bit": true,
39
+ "load_in_8bit": false,
40
+ "quant_method": "bitsandbytes"
41
+ },
42
+ "resid_pdrop": 0.0,
43
+ "rotary_dim": 64,
44
+ "scale_attn_weights": true,
45
+ "structure": {
46
+ "embedding_keys": [
47
+ "mpnn_emb"
48
+ ],
49
+ "max_seqlen": 512,
50
+ "n_queries": 256,
51
+ "num_heads": 16,
52
+ "output_dim": 4096,
53
+ "structure_emb_path_prefix": "./structure_embeddings",
54
+ "width": 1152
55
+ },
56
+ "summary_activation": null,
57
+ "summary_first_dropout": 0.1,
58
+ "summary_proj_to_labels": true,
59
+ "summary_type": "cls_index",
60
+ "summary_use_proj": true,
61
+ "task_specific_params": {
62
+ "text-generation": {
63
+ "do_sample": true,
64
+ "max_length": 50,
65
+ "temperature": 1.0
66
+ }
67
+ },
68
+ "tie_word_embeddings": false,
69
+ "tokenizer_type": "iPLMTokenizer",
70
+ "torch_dtype": "float16",
71
+ "transformers_version": "4.42.4",
72
+ "use_cache": true,
73
+ "vocab_size": 30
74
+ }
configuration_progen.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2021 The EleutherAI and HuggingFace Teams. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ # Modified configuration implementation based on https://github.com/huggingface/transformers/blob/main/src/transformers/models/gptj/configuration_gptj.py
17
+
18
+ from transformers.configuration_utils import PretrainedConfig
19
+ from transformers.utils import logging
20
+
21
+ logger = logging.get_logger(__name__)
22
+
23
+
24
+ class ProGenConfig(PretrainedConfig):
25
+ model_type = "progen"
26
+
27
+ def __init__(
28
+ self,
29
+ vocab_size=50400,
30
+ n_positions=2048,
31
+ n_ctx=2048,
32
+ n_embd=4096,
33
+ n_layer=28,
34
+ n_head=16,
35
+ rotary_dim=64,
36
+ n_inner=None,
37
+ activation_function="gelu_new",
38
+ resid_pdrop=0.0,
39
+ embd_pdrop=0.0,
40
+ attn_pdrop=0.0,
41
+ layer_norm_epsilon=1e-5,
42
+ initializer_range=0.02,
43
+ scale_attn_weights=True,
44
+ gradient_checkpointing=False,
45
+ use_cache=True,
46
+ bos_token_id=50256,
47
+ eos_token_id=50256,
48
+ tie_word_embeddings=False,
49
+ **kwargs
50
+ ):
51
+ super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, tie_word_embeddings=tie_word_embeddings, **kwargs)
52
+
53
+ self.vocab_size = vocab_size
54
+ self.n_ctx = n_ctx
55
+ self.n_positions = n_positions
56
+ self.n_embd = n_embd
57
+ self.n_layer = n_layer
58
+ self.n_head = n_head
59
+ self.n_inner = n_inner
60
+ self.rotary_dim = rotary_dim
61
+ self.activation_function = activation_function
62
+ self.resid_pdrop = resid_pdrop
63
+ self.embd_pdrop = embd_pdrop
64
+ self.attn_pdrop = attn_pdrop
65
+ self.layer_norm_epsilon = layer_norm_epsilon
66
+ self.initializer_range = initializer_range
67
+ self.gradient_checkpointing = gradient_checkpointing
68
+ self.scale_attn_weights = scale_attn_weights
69
+ self.use_cache = use_cache
70
+
71
+ self.bos_token_id = bos_token_id
72
+ self.eos_token_id = eos_token_id
73
+ self.tie_word_embeddings = tie_word_embeddings
74
+
75
+ @property
76
+ def max_position_embeddings(self):
77
+ return self.n_positions
78
+
79
+ @property
80
+ def hidden_size(self):
81
+ return self.n_embd
82
+
83
+ @property
84
+ def num_attention_heads(self):
85
+ return self.n_head
86
+
87
+ @property
88
+ def num_hidden_layers(self):
89
+ return self.n_layer
generation_config.json ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 2,
5
+ "transformers_version": "4.42.4"
6
+ }
modeling_InstructProGen.py ADDED
@@ -0,0 +1,700 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Callable, List, Optional, Tuple, Union
2
+
3
+ import numpy as np
4
+
5
+ import torch
6
+ import torch.utils.checkpoint
7
+ from torch import nn
8
+ from torch.nn import CrossEntropyLoss
9
+
10
+ from transformers.activations import ACT2FN
11
+ from transformers.generation.configuration_utils import GenerationConfig
12
+ from transformers.generation.logits_process import LogitsProcessorList
13
+ from transformers.generation.stopping_criteria import StoppingCriteriaList
14
+ from transformers.generation.streamers import BaseStreamer
15
+ from transformers.generation.utils import GenerateOutput
16
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
17
+ from transformers.modeling_utils import PreTrainedModel
18
+ from transformers.utils import logging
19
+ from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
20
+ from .configuration_progen import ProGenConfig
21
+
22
+ logger = logging.get_logger(__name__)
23
+
24
+ from .structure import StructureTransformer
25
+
26
+
27
+ def fixed_pos_embedding(x, seq_dim=1, seq_len=None):
28
+ dim = x.shape[-1]
29
+ if seq_len is None:
30
+ seq_len = x.shape[seq_dim]
31
+ inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2) / dim))
32
+ sinusoid_inp = torch.einsum("i , j -> i j", torch.arange(seq_len), inv_freq).to(x.device).float()
33
+ return torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)
34
+
35
+
36
+ def rotate_every_two(x):
37
+ x1 = x[:, :, :, ::2]
38
+ x2 = x[:, :, :, 1::2]
39
+ x = torch.stack((-x2, x1), axis=-1)
40
+ return x.flatten(-2) # in einsum notation: rearrange(x, '... d j -> ... (d j)')
41
+
42
+
43
+ def apply_rotary_pos_emb(x, sincos, offset=0):
44
+ sin, cos = map(lambda t: t[None, offset : x.shape[1] + offset, None, :].repeat_interleave(2, 3), sincos)
45
+ # einsum notation for lambda t: repeat(t[offset:x.shape[1]+offset,:], "n d -> () n () (d j)", j=2)
46
+ return (x * cos) + (rotate_every_two(x) * sin)
47
+
48
+
49
+ class ProGenAttention(nn.Module):
50
+ def __init__(self, config):
51
+ super().__init__()
52
+
53
+ max_positions = config.max_position_embeddings
54
+ self.register_buffer(
55
+ "bias",
56
+ torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
57
+ 1, 1, max_positions, max_positions
58
+ ),
59
+ )
60
+ self.register_buffer("masked_bias", torch.tensor(-1e9))
61
+
62
+ self.attn_dropout = nn.Dropout(config.attn_pdrop)
63
+ self.resid_dropout = nn.Dropout(config.resid_pdrop)
64
+
65
+ self.embed_dim = config.hidden_size
66
+ self.num_attention_heads = config.num_attention_heads
67
+ self.head_dim = self.embed_dim // self.num_attention_heads
68
+ if self.head_dim * self.num_attention_heads != self.embed_dim:
69
+ raise ValueError(
70
+ f"embed_dim must be divisible by num_attention_heads (got `embed_dim`: {self.embed_dim} and `num_attention_heads`: {self.num_attention_heads})."
71
+ )
72
+ self.scale_attn = torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32)).to(torch.get_default_dtype())
73
+ self.qkv_proj = nn.Linear(self.embed_dim, self.embed_dim * 3, bias=False)
74
+
75
+ self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
76
+ self.rotary_dim = None
77
+ if config.rotary_dim is not None:
78
+ self.rotary_dim = config.rotary_dim
79
+
80
+ def _split_heads(self, x, n_head, dim_head, mp_num):
81
+ reshaped = x.reshape(x.shape[:-1] + (n_head//mp_num, dim_head))
82
+ reshaped = reshaped.reshape(x.shape[:-2] + (-1, ) + reshaped.shape[-1:])
83
+ return reshaped
84
+
85
+ def _merge_heads(self, tensor, num_attention_heads, attn_head_size):
86
+ """
87
+ Merges attn_head_size dim and num_attn_heads dim into n_ctx
88
+ """
89
+ if len(tensor.shape) == 5:
90
+ tensor = tensor.permute(0, 1, 3, 2, 4).contiguous()
91
+ elif len(tensor.shape) == 4:
92
+ tensor = tensor.permute(0, 2, 1, 3).contiguous()
93
+ else:
94
+ raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {len(tensor.shape)}")
95
+ new_shape = tensor.size()[:-2] + (num_attention_heads * attn_head_size,)
96
+ return tensor.view(new_shape)
97
+
98
+ def _attn(
99
+ self,
100
+ query,
101
+ key,
102
+ value,
103
+ attention_mask=None,
104
+ head_mask=None,
105
+ ):
106
+
107
+ # compute causal mask from causal mask buffer
108
+ query_length, key_length = query.size(-2), key.size(-2)
109
+ causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
110
+
111
+ # Keep the attention weights computation in fp32 to avoid overflow issues
112
+ query = query.to(torch.float32)
113
+ key = key.to(torch.float32)
114
+
115
+ attn_weights = torch.matmul(query, key.transpose(-1, -2))
116
+
117
+ attn_weights = attn_weights / self.scale_attn
118
+ attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype))
119
+
120
+ if attention_mask is not None:
121
+ # Apply the attention mask
122
+ attn_weights = attn_weights + attention_mask
123
+
124
+ attn_weights = nn.Softmax(dim=-1)(attn_weights)
125
+ attn_weights = attn_weights.to(value.dtype)
126
+ attn_weights = self.attn_dropout(attn_weights)
127
+
128
+ # Mask heads if we want to
129
+ if head_mask is not None:
130
+ attn_weights = attn_weights * head_mask
131
+
132
+ attn_output = torch.matmul(attn_weights, value)
133
+
134
+ return attn_output, attn_weights
135
+
136
+ def forward(
137
+ self,
138
+ hidden_states,
139
+ attention_mask=None,
140
+ layer_past=None,
141
+ head_mask=None,
142
+ use_cache=False,
143
+ output_attentions=False,
144
+ ):
145
+
146
+ qkv = self.qkv_proj(hidden_states)
147
+ # TODO(enijkamp): factor out number of logical TPU-v3/v4 cores or make forward pass agnostic
148
+ # mp_num = 4
149
+ mp_num = 8
150
+ qkv_split = qkv.reshape(qkv.shape[:-1] + (mp_num, -1))
151
+
152
+ local_dim = self.head_dim * self.num_attention_heads // mp_num
153
+ query, value, key = torch.split(qkv_split, local_dim, dim=-1)
154
+ query = self._split_heads(query, self.num_attention_heads, self.head_dim, mp_num=mp_num)
155
+ key = self._split_heads(key, self.num_attention_heads, self.head_dim, mp_num=mp_num)
156
+
157
+ value = self._split_heads(value, self.num_attention_heads, self.head_dim, mp_num=mp_num)
158
+ value = value.permute(0, 2, 1, 3)
159
+
160
+ seq_len = key.shape[1]
161
+ offset = 0
162
+
163
+ if layer_past is not None:
164
+ offset = layer_past[0].shape[-2]
165
+ seq_len += offset
166
+
167
+ if self.rotary_dim is not None:
168
+ k_rot = key[:, :, :, : self.rotary_dim]
169
+ k_pass = key[:, :, :, self.rotary_dim :]
170
+
171
+ q_rot = query[:, :, :, : self.rotary_dim]
172
+ q_pass = query[:, :, :, self.rotary_dim :]
173
+
174
+ sincos = fixed_pos_embedding(k_rot, 1, seq_len=seq_len)
175
+ k_rot = apply_rotary_pos_emb(k_rot, sincos, offset=offset)
176
+ q_rot = apply_rotary_pos_emb(q_rot, sincos, offset=offset)
177
+
178
+ key = torch.cat([k_rot, k_pass], dim=-1)
179
+ query = torch.cat([q_rot, q_pass], dim=-1)
180
+ else:
181
+ sincos = fixed_pos_embedding(key, 1, seq_len=seq_len)
182
+ key = apply_rotary_pos_emb(key, sincos, offset=offset)
183
+ query = apply_rotary_pos_emb(query, sincos, offset=offset)
184
+
185
+ key = key.permute(0, 2, 1, 3)
186
+ query = query.permute(0, 2, 1, 3)
187
+
188
+ if layer_past is not None:
189
+ past_key = layer_past[0]
190
+ past_value = layer_past[1]
191
+ key = torch.cat((past_key, key), dim=-2)
192
+ value = torch.cat((past_value, value), dim=-2)
193
+
194
+ if use_cache is True:
195
+ present = (key, value)
196
+ else:
197
+ present = None
198
+
199
+ # compute self-attention: V x Softmax(QK^T)
200
+ attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
201
+
202
+ attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_dim)
203
+
204
+ attn_output = self.out_proj(attn_output)
205
+ attn_output = self.resid_dropout(attn_output)
206
+
207
+ outputs = (attn_output, present)
208
+ if output_attentions:
209
+ outputs += (attn_weights,)
210
+
211
+ return outputs # a, present, (attentions)
212
+
213
+
214
+ class ProGenMLP(nn.Module):
215
+ def __init__(self, intermediate_size, config): # in MLP: intermediate_size= 4 * embed_dim
216
+ super().__init__()
217
+ embed_dim = config.n_embd
218
+
219
+ self.fc_in = nn.Linear(embed_dim, intermediate_size)
220
+ self.fc_out = nn.Linear(intermediate_size, embed_dim)
221
+
222
+ self.act = ACT2FN[config.activation_function]
223
+ self.dropout = nn.Dropout(config.resid_pdrop)
224
+
225
+ def forward(self, hidden_states):
226
+ hidden_states = self.fc_in(hidden_states)
227
+ hidden_states = self.act(hidden_states)
228
+ hidden_states = self.fc_out(hidden_states)
229
+ hidden_states = self.dropout(hidden_states)
230
+ return hidden_states
231
+
232
+
233
+ class ProGenBlock(nn.Module):
234
+ def __init__(self, config):
235
+ super().__init__()
236
+ inner_dim = config.n_inner if config.n_inner is not None else 4 * config.n_embd
237
+ self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
238
+ self.attn = ProGenAttention(config)
239
+ self.mlp = ProGenMLP(inner_dim, config)
240
+
241
+ def forward(
242
+ self,
243
+ hidden_states,
244
+ layer_past=None,
245
+ attention_mask=None,
246
+ head_mask=None,
247
+ use_cache=False,
248
+ output_attentions=False,
249
+ ):
250
+ residual = hidden_states
251
+ hidden_states = self.ln_1(hidden_states)
252
+ attn_outputs = self.attn(
253
+ hidden_states,
254
+ layer_past=layer_past,
255
+ attention_mask=attention_mask,
256
+ head_mask=head_mask,
257
+ use_cache=use_cache,
258
+ output_attentions=output_attentions,
259
+ )
260
+ attn_output = attn_outputs[0] # output_attn: a, present, (attentions)
261
+ outputs = attn_outputs[1:]
262
+
263
+ feed_forward_hidden_states = self.mlp(hidden_states)
264
+ hidden_states = attn_output + feed_forward_hidden_states + residual
265
+
266
+ if use_cache:
267
+ outputs = (hidden_states,) + outputs
268
+ else:
269
+ outputs = (hidden_states,) + outputs[1:]
270
+
271
+ return outputs # hidden_states, present, (attentions)
272
+
273
+
274
+ class ProGenPreTrainedModel(PreTrainedModel):
275
+ """
276
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
277
+ models.
278
+ """
279
+
280
+ config_class = ProGenConfig
281
+ base_model_prefix = "transformer"
282
+ supports_gradient_checkpointing = True
283
+ is_parallelizable = True
284
+
285
+ def __init__(self, *inputs, **kwargs):
286
+ super().__init__(*inputs, **kwargs)
287
+
288
+ def _init_weights(self, module):
289
+ """Initialize the weights."""
290
+ if isinstance(module, (nn.Linear,)):
291
+ # Slightly different from Mesh Transformer JAX which uses truncated_normal for initialization
292
+ # cf https://github.com/pytorch/pytorch/pull/5617
293
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
294
+ if module.bias is not None:
295
+ module.bias.data.zero_()
296
+ elif isinstance(module, nn.Embedding):
297
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
298
+ if module.padding_idx is not None:
299
+ module.weight.data[module.padding_idx].zero_()
300
+ elif isinstance(module, nn.LayerNorm):
301
+ module.bias.data.zero_()
302
+ module.weight.data.fill_(1.0)
303
+
304
+ def _set_gradient_checkpointing(self, module, value=False):
305
+ if isinstance(module, ProGenModel):
306
+ module.gradient_checkpointing = value
307
+
308
+ class ProGenModel(ProGenPreTrainedModel):
309
+ def __init__(self, config):
310
+ super().__init__(config)
311
+
312
+ self.embed_dim = config.n_embd
313
+ self.vocab_size = config.vocab_size
314
+ self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
315
+ self.drop = nn.Dropout(config.embd_pdrop)
316
+ self.h = nn.ModuleList([ProGenBlock(config) for _ in range(config.n_layer)])
317
+ self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
318
+ self.rotary_dim = min(config.rotary_dim, config.n_ctx // config.num_attention_heads)
319
+
320
+ self.gradient_checkpointing = False
321
+ self.structure = StructureTransformer(**config.structure)
322
+
323
+ self.init_weights()
324
+
325
+ # Model parallel
326
+ self.model_parallel = False
327
+ self.device_map = None
328
+
329
+
330
+ def parallelize(self, device_map=None):
331
+ # Check validity of device_map
332
+ self.device_map = (
333
+ get_device_map(len(self.h), range(torch.cuda.device_count())) if device_map is None else device_map
334
+ )
335
+ assert_device_map(self.device_map, len(self.h))
336
+ self.model_parallel = True
337
+ self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys()))
338
+ self.last_device = "cuda:" + str(max(self.device_map.keys()))
339
+ self.wte = self.wte.to(self.first_device)
340
+ # Load onto devices
341
+ for k, v in self.device_map.items():
342
+ for block in v:
343
+ cuda_device = "cuda:" + str(k)
344
+ self.h[block] = self.h[block].to(cuda_device)
345
+ # ln_f to last
346
+ self.ln_f = self.ln_f.to(self.last_device)
347
+
348
+
349
+ def deparallelize(self):
350
+ self.model_parallel = False
351
+ self.device_map = None
352
+ self.first_device = "cpu"
353
+ self.last_device = "cpu"
354
+ self.wte = self.wte.to("cpu")
355
+ for index in range(len(self.h)):
356
+ self.h[index] = self.h[index].to("cpu")
357
+ self.ln_f = self.ln_f.to("cpu")
358
+ torch.cuda.empty_cache()
359
+
360
+ def get_input_embeddings(self):
361
+ return self.wte
362
+
363
+ def set_input_embeddings(self, new_embeddings):
364
+ self.wte = new_embeddings
365
+
366
+ def forward(
367
+ self,
368
+ input_ids=None,
369
+ past_key_values=None,
370
+ attention_mask=None,
371
+ token_type_ids=None,
372
+ position_ids=None,
373
+ head_mask=None,
374
+ inputs_embeds=None,
375
+ query_embeds=None,
376
+ use_cache=None,
377
+ output_attentions=None,
378
+ output_hidden_states=None,
379
+ return_dict=None,
380
+ ):
381
+ if past_key_values is None:
382
+ # structure encode will check if input_ids contains valid
383
+ structure_embs = self.structure.encode(input_ids)
384
+ if structure_embs is not None:
385
+ input_ids = input_ids[:, self.structure.n_queries:]
386
+ else:
387
+ structure_embs = None
388
+
389
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
390
+ output_hidden_states = (
391
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
392
+ )
393
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
394
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
395
+
396
+ if input_ids is not None and inputs_embeds is not None:
397
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
398
+ elif input_ids is not None:
399
+ input_shape = input_ids.size()
400
+ input_ids = input_ids.view(-1, input_shape[-1])
401
+ batch_size = input_ids.shape[0]
402
+ elif inputs_embeds is not None:
403
+ input_shape = inputs_embeds.size()[:-1]
404
+ batch_size = inputs_embeds.shape[0]
405
+ else:
406
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
407
+
408
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
409
+
410
+ # if token_type_ids is not None:
411
+ # token_type_ids = token_type_ids.view(-1, input_shape[-1])
412
+
413
+ if position_ids is not None:
414
+ position_ids = position_ids.view(-1, input_shape[-1])
415
+
416
+ if past_key_values is None:
417
+ past_length = 0
418
+ past_key_values = tuple([None] * len(self.h))
419
+ else:
420
+ past_length = past_key_values[0][0].size(-2)
421
+
422
+ if position_ids is None:
423
+ position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
424
+ position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
425
+
426
+ # Attention mask.
427
+ if attention_mask is not None:
428
+ assert batch_size > 0, "batch_size has to be defined and > 0"
429
+ attention_mask = attention_mask.view(batch_size, -1)
430
+ # We create a 3D attention mask from a 2D tensor mask.
431
+ # Sizes are [batch_size, 1, 1, to_seq_length]
432
+ # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
433
+ # this attention mask is more simple than the triangular masking of causal attention
434
+ # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
435
+ attention_mask = attention_mask[:, None, None, :]
436
+
437
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
438
+ # masked positions, this operation will create a tensor which is 0.0 for
439
+ # positions we want to attend and -10000.0 for masked positions.
440
+ # Since we are adding it to the raw scores before the softmax, this is
441
+ # effectively the same as removing these entirely.
442
+ attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
443
+ attention_mask = (1.0 - attention_mask) * -10000.0
444
+
445
+ # Prepare head mask if needed
446
+ # 1.0 in head_mask indicate we keep the head
447
+ # attention_probs has shape bsz x num_attention_heads x N x N
448
+ # head_mask has shape n_layer x batch x num_attention_heads x N x N
449
+ head_mask = self.get_head_mask(head_mask, self.config.n_layer)
450
+
451
+ if inputs_embeds is None:
452
+ inputs_embeds = self.wte(input_ids)
453
+
454
+ if query_embeds is not None:
455
+ inputs_embeds = torch.cat([query_embeds, inputs_embeds], dim=1)
456
+ input_shape = inputs_embeds.size()[:-1]
457
+
458
+ if structure_embs is not None:
459
+ inputs_embeds = torch.cat([structure_embs, inputs_embeds], dim=1)
460
+ input_shape = inputs_embeds.size()[:-1]
461
+
462
+ hidden_states = inputs_embeds
463
+
464
+ # disable token_type_ids
465
+ # if token_type_ids is not None:
466
+ # token_type_embeds = self.wte(token_type_ids)
467
+ # hidden_states = hidden_states + token_type_embeds
468
+
469
+ hidden_states = self.drop(hidden_states)
470
+
471
+ output_shape = input_shape + (hidden_states.size(-1),)
472
+
473
+ presents = () if use_cache else None
474
+ all_self_attentions = () if output_attentions else None
475
+ all_hidden_states = () if output_hidden_states else None
476
+ for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
477
+
478
+ # Model parallel
479
+ if self.model_parallel:
480
+ torch.cuda.set_device(hidden_states.device)
481
+ # Ensure layer_past is on same device as hidden_states (might not be correct)
482
+ if layer_past is not None:
483
+ layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past)
484
+ # Ensure that attention_mask is always on the same device as hidden_states
485
+ if attention_mask is not None:
486
+ attention_mask = attention_mask.to(hidden_states.device)
487
+ if isinstance(head_mask, torch.Tensor):
488
+ head_mask = head_mask.to(hidden_states.device)
489
+ if output_hidden_states:
490
+ all_hidden_states = all_hidden_states + (hidden_states,)
491
+
492
+ if self.gradient_checkpointing and self.training:
493
+
494
+ if use_cache:
495
+ # logger.warning(
496
+ # "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
497
+ # "`use_cache=False`..."
498
+ # )
499
+ use_cache = False
500
+
501
+ def create_custom_forward(module):
502
+ def custom_forward(*inputs):
503
+ # None for past_key_value
504
+ return module(*inputs, use_cache, output_attentions)
505
+
506
+ return custom_forward
507
+
508
+ outputs = torch.utils.checkpoint.checkpoint(
509
+ create_custom_forward(block),
510
+ hidden_states,
511
+ None,
512
+ attention_mask,
513
+ head_mask[i],
514
+ )
515
+ else:
516
+ outputs = block(
517
+ hidden_states,
518
+ layer_past=layer_past,
519
+ attention_mask=attention_mask,
520
+ head_mask=head_mask[i],
521
+ use_cache=use_cache,
522
+ output_attentions=output_attentions,
523
+ )
524
+
525
+ hidden_states = outputs[0]
526
+ if use_cache is True:
527
+ presents = presents + (outputs[1],)
528
+
529
+ if output_attentions:
530
+ all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
531
+
532
+ # Model Parallel: If it's the last layer for that device, put things on the next device
533
+ if self.model_parallel:
534
+ for k, v in self.device_map.items():
535
+ if i == v[-1] and "cuda:" + str(k) != self.last_device:
536
+ hidden_states = hidden_states.to("cuda:" + str(k + 1))
537
+
538
+ hidden_states = self.ln_f(hidden_states)
539
+
540
+ hidden_states = hidden_states.view(*output_shape)
541
+ # Add last hidden state
542
+ if output_hidden_states:
543
+ all_hidden_states = all_hidden_states + (hidden_states,)
544
+
545
+ if not return_dict:
546
+ return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
547
+
548
+ return BaseModelOutputWithPast(
549
+ last_hidden_state=hidden_states,
550
+ past_key_values=presents,
551
+ hidden_states=all_hidden_states,
552
+ attentions=all_self_attentions,
553
+ )
554
+
555
+
556
+ class ProGenForCausalLM(ProGenPreTrainedModel):
557
+ _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"h\.\d+\.attn\.bias", r"lm_head\.weight"]
558
+
559
+ def __init__(self, config):
560
+ super().__init__(config)
561
+ self.transformer = ProGenModel(config)
562
+ self.lm_head = nn.Linear(config.n_embd, config.vocab_size)
563
+ self.init_weights()
564
+
565
+ # Model parallel
566
+ self.model_parallel = False
567
+ self.device_map = None
568
+
569
+ def parallelize(self, device_map=None):
570
+ self.device_map = (
571
+ get_device_map(len(self.transformer.h), range(torch.cuda.device_count()))
572
+ if device_map is None
573
+ else device_map
574
+ )
575
+ assert_device_map(self.device_map, len(self.transformer.h))
576
+ self.transformer.parallelize(self.device_map)
577
+ self.lm_head = self.lm_head.to(self.transformer.first_device)
578
+ self.model_parallel = True
579
+
580
+ def deparallelize(self):
581
+ self.transformer.deparallelize()
582
+ self.transformer = self.transformer.to("cpu")
583
+ self.lm_head = self.lm_head.to("cpu")
584
+ self.model_parallel = False
585
+ torch.cuda.empty_cache()
586
+
587
+ def get_output_embeddings(self):
588
+ return self.lm_head
589
+
590
+ def set_output_embeddings(self, new_embeddings):
591
+ self.lm_head = new_embeddings
592
+
593
+ def prepare_inputs_for_generation(
594
+ self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
595
+ ):
596
+ if past_key_values:
597
+ input_ids = input_ids[:, -1:]
598
+
599
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
600
+ if inputs_embeds is not None and past_key_values is None:
601
+ model_inputs = {"inputs_embeds": inputs_embeds}
602
+ else:
603
+ model_inputs = {"input_ids": input_ids}
604
+
605
+ model_inputs.update(
606
+ {
607
+ "past_key_values": past_key_values,
608
+ "use_cache": kwargs.get("use_cache"),
609
+ "attention_mask": attention_mask,
610
+ }
611
+ )
612
+ return model_inputs
613
+
614
+ def forward(
615
+ self,
616
+ input_ids=None,
617
+ past_key_values=None,
618
+ attention_mask=None,
619
+ token_type_ids=None,
620
+ position_ids=None,
621
+ head_mask=None,
622
+ inputs_embeds=None,
623
+ labels=None,
624
+ use_cache=None,
625
+ query_embeds = None,
626
+ output_attentions=None,
627
+ output_hidden_states=None,
628
+ return_dict=None,
629
+ ):
630
+ r"""
631
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
632
+ Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
633
+ ``labels = input_ids`` Indices are selected in ``[-100, 0, ..., config.vocab_size]`` All labels set to
634
+ ``-100`` are ignored (masked), the loss is only computed for labels in ``[0, ..., config.vocab_size]``
635
+ """
636
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
637
+
638
+ transformer_outputs = self.transformer(
639
+ input_ids,
640
+ past_key_values=past_key_values,
641
+ attention_mask=attention_mask,
642
+ token_type_ids=token_type_ids,
643
+ position_ids=position_ids,
644
+ head_mask=head_mask,
645
+ inputs_embeds=inputs_embeds,
646
+ query_embeds=query_embeds,
647
+ use_cache=use_cache,
648
+ output_attentions=output_attentions,
649
+ output_hidden_states=output_hidden_states,
650
+ return_dict=return_dict,
651
+ )
652
+ hidden_states = transformer_outputs[0]
653
+
654
+ # Set device for model parallelism
655
+ if self.model_parallel:
656
+ torch.cuda.set_device(self.transformer.first_device)
657
+ hidden_states = hidden_states.to(self.lm_head.weight.device)
658
+
659
+ # make sure sampling in fp16 works correctly and
660
+ # compute loss in fp32 to match with mesh-tf version
661
+ # https://github.com/EleutherAI/gpt-neo/blob/89ce74164da2fb16179106f54e2269b5da8db333/models/gpt2/gpt2.py#L179
662
+ lm_logits = self.lm_head(hidden_states).to(torch.float32)
663
+
664
+ loss = None
665
+ if labels is not None:
666
+ # Shift so that tokens < n predict n
667
+ shift_logits = lm_logits[..., :-1, :].contiguous()
668
+ shift_labels = labels[..., 1:].contiguous()
669
+ # Flatten the tokens
670
+ loss_fct = CrossEntropyLoss()
671
+ loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
672
+
673
+ loss = loss.to(hidden_states.dtype)
674
+
675
+ if not return_dict:
676
+ output = (lm_logits,) + transformer_outputs[1:]
677
+ return ((loss,) + output) if loss is not None else output
678
+
679
+ return CausalLMOutputWithPast(
680
+ loss=loss,
681
+ logits=lm_logits,
682
+ past_key_values=transformer_outputs.past_key_values,
683
+ hidden_states=transformer_outputs.hidden_states,
684
+ attentions=transformer_outputs.attentions,
685
+ )
686
+
687
+ @staticmethod
688
+ def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]:
689
+ """
690
+ This function is used to re-order the :obj:`past_key_values` cache if
691
+ :meth:`~transformers.PretrainedModel.beam_search` or :meth:`~transformers.PretrainedModel.beam_sample` is
692
+ called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step.
693
+ """
694
+ return tuple(
695
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
696
+ for layer_past in past
697
+ )
698
+
699
+ # def generate(self, inputs: Tensor | None = None, generation_config: GenerationConfig | None = None, logits_processor: LogitsProcessorList | None = None, stopping_criteria: StoppingCriteriaList | None = None, prefix_allowed_tokens_fn: Callable[[int, Tensor], List[int]] | None = None, synced_gpus: bool | None = None, assistant_model: PreTrainedModel | None = None, streamer: BaseStreamer | None = None, negative_prompt_ids: Tensor | None = None, negative_prompt_attention_mask: Tensor | None = None, **kwargs) -> GenerateOutput | LongTensor:
700
+ # return super().generate(inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, negative_prompt_ids, negative_prompt_attention_mask, **kwargs)
smash_config.json ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "api_key": null,
3
+ "verify_url": "http://johnrachwan.pythonanywhere.com",
4
+ "smash_config": {
5
+ "pruners": "None",
6
+ "pruning_ratio": 0.0,
7
+ "factorizers": "None",
8
+ "quantizers": "['llm-int8']",
9
+ "weight_quantization_bits": 4,
10
+ "output_deviation": 0.005,
11
+ "compilers": "None",
12
+ "static_batch": true,
13
+ "static_shape": true,
14
+ "controlnet": "None",
15
+ "unet_dim": 4,
16
+ "device": "cuda",
17
+ "cache_dir": "/ceph/hdd/staff/charpent/.cache/modelsvz_9j4ok",
18
+ "batch_size": 1,
19
+ "model_name": "InstructPLM/MPNN-ProGen2-xlarge-CATH42",
20
+ "task": "text_text_generation",
21
+ "max_batch_size": 1,
22
+ "qtype_weight": "torch.qint8",
23
+ "qtype_activation": "torch.quint8",
24
+ "qobserver": "<class 'torch.ao.quantization.observer.MinMaxObserver'>",
25
+ "qscheme": "torch.per_tensor_symmetric",
26
+ "qconfig": "x86",
27
+ "group_size": 128,
28
+ "damp_percent": 0.1,
29
+ "save_load_fn": "bitsandbytes"
30
+ }
31
+ }
special_tokens_map.json ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "pad_token": {
3
+ "content": "<|pad|>",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ }
9
+ }
structure.py ADDED
@@ -0,0 +1,287 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba Cloud.
2
+ #
3
+ # This source code is licensed under the license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from collections import OrderedDict
7
+ import math
8
+ import requests
9
+ from io import BytesIO
10
+ from functools import partial
11
+ import pickle
12
+ from typing import Callable, Optional, Sequence, Tuple, List
13
+ import numpy as np
14
+ import os
15
+ import torch
16
+ from torch import nn
17
+ from torch.nn import functional as F
18
+ from torch.nn.init import trunc_normal_
19
+ from torchvision import transforms
20
+ from torchvision.transforms import InterpolationMode
21
+
22
+ class GLU(nn.Module):
23
+ def __init__(self,hidden_size):
24
+ super().__init__()
25
+ self.linear_proj = nn.Linear(hidden_size,hidden_size,bias=False)
26
+ self.norm1 = nn.LayerNorm(hidden_size)
27
+ self.act1 = nn.GELU()
28
+ self.act2 = nn.functional.silu
29
+ self.dense_h_to_4h = nn.Linear(hidden_size,hidden_size*4,bias=False)
30
+ self.gate_proj = nn.Linear(hidden_size,hidden_size*4,bias=False)
31
+ self.dense_4h_to_h = nn.Linear(hidden_size*4,hidden_size,bias=False)
32
+
33
+ def forward(self,x):
34
+ x = self.linear_proj(x)
35
+ x = self.act1(self.norm1(x))
36
+ x = self.act2(self.gate_proj(x))*self.dense_h_to_4h(x)
37
+ x = self.dense_4h_to_h(x)
38
+ return x
39
+ def swiglu(x):
40
+ x = torch.chunk(x, 2, dim=-1)
41
+ return nn.functional.silu(x[0]) * x[1]
42
+
43
+ class GLU_new(nn.Module):
44
+ def __init__(self,hidden_size, dropout=0.1):
45
+ super().__init__()
46
+ intermediate_size = int((4 * hidden_size * 2 / 3) / 64) * 64
47
+ intermediate_size = 1280
48
+
49
+ self.act = swiglu
50
+ self.dense_h_to_4h = nn.Linear(hidden_size, intermediate_size * 2, bias=False)
51
+ self.dense_4h_to_h = nn.Linear(intermediate_size, hidden_size, bias=False)
52
+ self.dropout = nn.Dropout(p=dropout)
53
+
54
+ def forward(self,x):
55
+ x = self.dense_h_to_4h(x)
56
+ x = self.act(x)
57
+ x = self.dense_4h_to_h(x)
58
+ x = self.dropout(x)
59
+ return x
60
+
61
+
62
+ n_queries = 32
63
+ def get_abs_pos(abs_pos, tgt_size):
64
+ # abs_pos: L, C
65
+ # tgt_size: M
66
+ # return: M, C
67
+ src_size = int(math.sqrt(abs_pos.size(0)))
68
+ tgt_size = int(math.sqrt(tgt_size))
69
+ dtype = abs_pos.dtype
70
+
71
+ if src_size != tgt_size:
72
+ return F.interpolate(
73
+ abs_pos.float().reshape(1, src_size, src_size, -1).permute(0, 3, 1, 2),
74
+ size=(tgt_size, tgt_size),
75
+ mode="bicubic",
76
+ align_corners=False,
77
+ ).permute(0, 2, 3, 1).flatten(0, 2).to(dtype=dtype)
78
+ else:
79
+ return abs_pos
80
+
81
+ from einops import rearrange, repeat
82
+
83
+ def get_1d_sincos_pos_embed(embed_dim, pos):
84
+ """
85
+ embed_dim: output dimension for each position
86
+ pos: a list of positions to be encoded: size (M,)
87
+ out: (M, D)
88
+ """
89
+ assert embed_dim % 2 == 0
90
+ omega = np.arange(embed_dim // 2, dtype=np.float32)
91
+ omega /= embed_dim / 2.
92
+ omega = 1. / 10000**omega # (D/2,)
93
+
94
+ pos = pos.reshape(-1) # (M,)
95
+ out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
96
+
97
+ emb_sin = np.sin(out) # (M, D/2)
98
+ emb_cos = np.cos(out) # (M, D/2)
99
+
100
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
101
+ return emb
102
+
103
+ class Resampler(nn.Module):
104
+ def __init__(
105
+ self,
106
+ kv_dim,
107
+ embed_dim,
108
+ num_heads=8,
109
+ n_queries=64,
110
+ max_seqlen=1024,
111
+ perceiver_resampler_positional_emb=True,
112
+ use_GLU=False,
113
+ bos_init=False,
114
+ dropout=0.0
115
+ ):
116
+ super().__init__()
117
+ self.perceiver_resampler_positional_emb = perceiver_resampler_positional_emb
118
+
119
+ if self.perceiver_resampler_positional_emb:
120
+ assert n_queries <= max_seqlen
121
+ self.stride = max_seqlen // n_queries
122
+ # self.nan_emb = nn.Parameter(torch.randn(1, kv_dim))
123
+ # nn.init.trunc_normal_(self.nan_emb, std=.02)
124
+ pos = np.arange(max_seqlen, dtype=np.float32)
125
+ self.register_buffer(
126
+ "pos_embed",
127
+ torch.from_numpy(get_1d_sincos_pos_embed(embed_dim, pos)).float()
128
+ )
129
+ self.latents = nn.Parameter(torch.randn(n_queries, embed_dim))
130
+ if bos_init:
131
+ self.latents.load('')
132
+ else:
133
+ nn.init.trunc_normal_(self.latents, std=1e-3)
134
+
135
+ self.kv_proj = nn.Linear(kv_dim, embed_dim, bias=False)
136
+ self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True, dropout=dropout)
137
+ self.ln_q = nn.LayerNorm(embed_dim)
138
+ self.ln_kv = nn.LayerNorm(embed_dim)
139
+ self.ln_post = nn.LayerNorm(embed_dim)
140
+ if use_GLU:
141
+ print('GLU *********************************')
142
+ self.proj = GLU_new(embed_dim, dropout=dropout)
143
+ else:
144
+ self.proj = nn.Linear(embed_dim, embed_dim, bias=False)
145
+
146
+ self.apply(self._init_weights)
147
+
148
+ def _init_weights(self, m):
149
+ if isinstance(m, nn.Linear):
150
+ nn.init.trunc_normal_(m.weight, std=1e-3)
151
+ if isinstance(m, nn.Linear) and m.bias is not None:
152
+ nn.init.constant_(m.bias, 0)
153
+ elif isinstance(m, nn.LayerNorm):
154
+ nn.init.constant_(m.bias, 0)
155
+ nn.init.constant_(m.weight, 1.0)
156
+
157
+ def forward(self, struc_x):
158
+ """
159
+ Args:
160
+ x (torch.Tensor): protein structure features
161
+ shape (B, L, C)
162
+ Returns:
163
+ shape (B, n, C) where n is self.num_latents
164
+ """
165
+ x = struc_x["encoder_out"]
166
+ mask = struc_x["encoder_padding_mask"]
167
+
168
+
169
+ nan_mask = torch.isnan(x)
170
+ if nan_mask.any():
171
+ x = x.masked_fill(nan_mask, 0.0)
172
+ # nan_mask = nan_mask.sum(dim=-1).bool()
173
+ # x[nan_mask] += self.nan_emb
174
+
175
+ x = self.kv_proj(x)
176
+ x = self.ln_kv(x)
177
+
178
+ b, seqlen = x.shape[:2]
179
+
180
+ latents = self.ln_q(self.latents)
181
+ if self.perceiver_resampler_positional_emb:
182
+ # TODO: interpolate
183
+ latents = latents + self.pos_embed[::self.stride].contiguous()
184
+ pos_emb = self.pos_embed[:seqlen].unsqueeze(0)
185
+ x = x + pos_emb.contiguous()
186
+
187
+ # blocks
188
+ latents = repeat(latents, "n d -> b n d", b=b)
189
+ out = self.attn(latents, x, x, key_padding_mask=~mask)[0]
190
+
191
+ out = self.ln_post(out)
192
+ out = self.proj(out)
193
+
194
+ return out
195
+
196
+ class StructureTransformer(nn.Module):
197
+
198
+ def __init__(
199
+ self,
200
+ width: int = 640,
201
+ n_queries: int = 32,
202
+ output_dim: int = 4096,
203
+ embedding_keys=set(["mpnn_emb"]),
204
+ max_seqlen: int=1024,
205
+ num_heads: int=8,
206
+ structure_emb_path_prefix='structure_emb',
207
+ **kwargs
208
+ ):
209
+ super().__init__()
210
+
211
+ self.structure_emb_path_prefix = structure_emb_path_prefix
212
+ # self.transformer = None # replace None with a pretrained strucure encoder
213
+ self.embedding_keys = embedding_keys
214
+ self.max_seqlen = max_seqlen
215
+ self.width = width
216
+ self.n_queries = n_queries
217
+
218
+ self.attn_pool = Resampler(
219
+ embed_dim=output_dim,
220
+ kv_dim=width,
221
+ n_queries=n_queries,
222
+ max_seqlen=max_seqlen,
223
+ num_heads=num_heads,
224
+ **kwargs
225
+ )
226
+
227
+ def prepare_structure(self, sample):
228
+ emb_pad = torch.zeros((self.max_seqlen, self.width))
229
+ emb_mask = torch.zeros((self.max_seqlen), dtype=bool)
230
+
231
+ if "pifold_emb" in self.embedding_keys and "pifold_mask" in sample:
232
+ mask = sample["pifold_mask"]
233
+ pifold_emb = sample["pifold_emb"]
234
+ new_pifold_emb = pifold_emb.new_zeros(mask.shape[0], pifold_emb.shape[1]).fill_(float("nan"))
235
+ new_pifold_emb[mask > 0] = pifold_emb
236
+ sample["pifold_emb"] = new_pifold_emb
237
+
238
+ ### domians ###
239
+ emb = []
240
+ for ek in self.embedding_keys:
241
+ if ek in sample:
242
+ if isinstance( sample[ek], List):
243
+ emb.append(torch.cat(sample[ek]))
244
+ else:
245
+ emb.append(sample[ek])
246
+ # emb = [sample[ek] for ek in self.embedding_keys if ek in sample]
247
+ emb = torch.cat(emb, dim=-1)
248
+
249
+ emb_pad[:len(emb)] = emb
250
+ emb_mask[:len(emb)] = 1
251
+ return emb_pad, emb_mask
252
+
253
+ def forward(self, x):
254
+
255
+ # x = self.transformer(x)
256
+ x = self.attn_pool(x)
257
+
258
+ return x
259
+
260
+ def encode(self, structure_paths: List[str]):
261
+ structure_embs = []
262
+ structure_mask = []
263
+
264
+ for structure_path in structure_paths:
265
+ structure_path = [chr(s) for s in structure_path[:self.n_queries].tolist() if s > 0]
266
+ structure_path = os.path.join(self.structure_emb_path_prefix, ''.join(structure_path))
267
+ if not os.path.exists(structure_path):
268
+ print('no structure found')
269
+ return None
270
+
271
+ with open(structure_path, 'rb') as f:
272
+ structure, struc_mask = self.prepare_structure(pickle.load(f))
273
+
274
+
275
+ structure_embs.append(structure)
276
+ structure_mask.append(struc_mask)
277
+
278
+ structure_embs = torch.stack(structure_embs, dim=0).to(
279
+ device=next(self.attn_pool.parameters()).device,
280
+ dtype=next(self.attn_pool.parameters()).dtype)
281
+ structure_mask = torch.stack(structure_mask, dim=0).to(
282
+ device=next(self.attn_pool.parameters()).device)
283
+
284
+ return self({
285
+ 'encoder_out': structure_embs,
286
+ 'encoder_padding_mask': structure_mask
287
+ })
tokenizer.json ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "version": "1.0",
3
+ "truncation": null,
4
+ "padding": null,
5
+ "added_tokens": [
6
+ {
7
+ "id": 0,
8
+ "content": "<|pad|>",
9
+ "single_word": false,
10
+ "lstrip": false,
11
+ "rstrip": false,
12
+ "normalized": false,
13
+ "special": true
14
+ },
15
+ {
16
+ "id": 1,
17
+ "content": "<|bos|>",
18
+ "single_word": false,
19
+ "lstrip": false,
20
+ "rstrip": false,
21
+ "normalized": false,
22
+ "special": true
23
+ },
24
+ {
25
+ "id": 2,
26
+ "content": "<|eos|>",
27
+ "single_word": false,
28
+ "lstrip": false,
29
+ "rstrip": false,
30
+ "normalized": false,
31
+ "special": true
32
+ }
33
+ ],
34
+ "normalizer": null,
35
+ "pre_tokenizer": {
36
+ "type": "ByteLevel",
37
+ "add_prefix_space": false,
38
+ "trim_offsets": true,
39
+ "use_regex": true
40
+ },
41
+ "post_processor": {
42
+ "type": "ByteLevel",
43
+ "add_prefix_space": true,
44
+ "trim_offsets": true,
45
+ "use_regex": true
46
+ },
47
+ "decoder": {
48
+ "type": "ByteLevel",
49
+ "add_prefix_space": true,
50
+ "trim_offsets": true,
51
+ "use_regex": true
52
+ },
53
+ "model": {
54
+ "type": "BPE",
55
+ "dropout": null,
56
+ "unk_token": null,
57
+ "continuing_subword_prefix": null,
58
+ "end_of_word_suffix": null,
59
+ "fuse_unk": false,
60
+ "byte_fallback": false,
61
+ "ignore_merges": false,
62
+ "vocab": {
63
+ "<|pad|>": 0,
64
+ "<|bos|>": 1,
65
+ "<|eos|>": 2,
66
+ "1": 3,
67
+ "2": 4,
68
+ "A": 5,
69
+ "B": 6,
70
+ "C": 7,
71
+ "D": 8,
72
+ "E": 9,
73
+ "F": 10,
74
+ "G": 11,
75
+ "H": 12,
76
+ "I": 13,
77
+ "K": 14,
78
+ "L": 15,
79
+ "M": 16,
80
+ "N": 17,
81
+ "O": 18,
82
+ "P": 19,
83
+ "Q": 20,
84
+ "R": 21,
85
+ "S": 22,
86
+ "T": 23,
87
+ "U": 24,
88
+ "V": 25,
89
+ "W": 26,
90
+ "X": 27,
91
+ "Y": 28,
92
+ "Z": 29
93
+ },
94
+ "merges": []
95
+ }
96
+ }
tokenizer_config.json ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "added_tokens_decoder": {
3
+ "0": {
4
+ "content": "<|pad|>",
5
+ "lstrip": false,
6
+ "normalized": false,
7
+ "rstrip": false,
8
+ "single_word": false,
9
+ "special": true
10
+ },
11
+ "1": {
12
+ "content": "<|bos|>",
13
+ "lstrip": false,
14
+ "normalized": false,
15
+ "rstrip": false,
16
+ "single_word": false,
17
+ "special": true
18
+ },
19
+ "2": {
20
+ "content": "<|eos|>",
21
+ "lstrip": false,
22
+ "normalized": false,
23
+ "rstrip": false,
24
+ "single_word": false,
25
+ "special": true
26
+ }
27
+ },
28
+ "auto_map": {
29
+ "AutoTokenizer": [
30
+ "InstructPLM/MPNN-ProGen2-xlarge-CATH42--tokenization_iPLM.iPLMTokenizer",
31
+ null
32
+ ]
33
+ },
34
+ "clean_up_tokenization_spaces": true,
35
+ "legacy": false,
36
+ "model_max_length": 1000000000000000019884624838656,
37
+ "pad_token": "<|pad|>",
38
+ "tokenizer_class": "iPLMTokenizer"
39
+ }