habdine committed
Commit
0a29ef9
1 Parent(s): 91dcd4d

Upload configuration_prot2text.py

Files changed (1): configuration_prot2text.py (+73, -0)
configuration_prot2text.py ADDED
"""Prot2Text configuration."""

from transformers import AutoConfig
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging


logger = logging.get_logger(__name__)


class Prot2TextConfig(PretrainedConfig):
    """Configuration for Prot2Text: an ESM sequence encoder and an RGCN
    structure encoder coupled to a GPT-2 decoder through cross-attention."""

    model_type = "prot2text"
    keys_to_ignore_at_inference = ["past_key_values"]
    _keys_to_ignore_on_load_missing = [r"transformer"]

    def __init__(
        self,
        cross_esm_graph=True,
        decoder_start_token_id=50257,
        early_stopping=True,
        eos_token_id=50258,
        bos_token_id=50257,
        esm=True,
        esm_model_name="facebook/esm2_t6_8M_UR50D",
        gpt_model_name="gpt2",
        length_penalty=2.0,
        max_new_tokens=256,
        no_repeat_ngram_size=3,
        pad_token_id=50256,
        prot2text_version="1.1",
        rgcn=True,
        rgc_input_dim=67,
        rgcn_n_layers=6,
        gpt_config=None,
        esm_config=None,
        **kwargs,
    ):
        self.cross_esm_graph = cross_esm_graph
        self.decoder_start_token_id = decoder_start_token_id
        self.early_stopping = early_stopping
        self.eos_token_id = eos_token_id
        self.esm = esm
        self.esm_model_name = esm_model_name
        self.gpt_model_name = gpt_model_name
        self.length_penalty = length_penalty
        self.max_new_tokens = max_new_tokens
        self.no_repeat_ngram_size = no_repeat_ngram_size
        self.pad_token_id = pad_token_id
        self.prot2text_version = prot2text_version
        self.rgcn = rgcn
        self.rgc_input_dim = rgc_input_dim
        self.rgcn_n_layers = rgcn_n_layers

        # Build the GPT-2 decoder config from the Hub unless one was supplied,
        # wiring in the special tokens, cross-attention, and generation defaults.
        if gpt_config is None:
            self.gpt_config = AutoConfig.from_pretrained(
                gpt_model_name,
                _name_or_path=gpt_model_name,
                is_encoder_decoder=True,
                use_cache=False,
                add_cross_attention=True,
                bos_token_id=bos_token_id,
                decoder_start_token_id=decoder_start_token_id,
                eos_token_id=eos_token_id,
                max_new_tokens=max_new_tokens,
                pad_token_id=pad_token_id,
                vocab_size=50259,
                num_beams=1,
                max_length=256,
                min_length=1,
            ).to_dict()
        else:
            self.gpt_config = gpt_config

        # Likewise for the ESM encoder config: fetch defaults from the Hub
        # unless an explicit dict was passed in.
        if esm_config is None:
            self.esm_config = AutoConfig.from_pretrained(esm_model_name).to_dict()
        else:
            self.esm_config = esm_config

        super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
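
Below is a minimal usage sketch (not part of the commit) showing how the config can be instantiated, registered with AutoConfig, and round-tripped through its dict form. It assumes a recent transformers version and that configuration_prot2text.py is importable from the working directory; leaving gpt_config/esm_config as None fetches the gpt2 and facebook/esm2_t6_8M_UR50D configs from the Hub.

    from transformers import AutoConfig

    from configuration_prot2text import Prot2TextConfig

    # Optional: make AutoConfig aware of the custom "prot2text" model type.
    AutoConfig.register("prot2text", Prot2TextConfig)

    # Override a few defaults; the gpt2 and ESM-2 sub-configs are pulled
    # from the Hub because gpt_config/esm_config are left as None.
    config = Prot2TextConfig(rgcn_n_layers=4, max_new_tokens=128)

    # Both sub-configs are stored as plain dicts, so the whole config
    # round-trips through to_dict()/from_dict() without custom hooks.
    restored = Prot2TextConfig.from_dict(config.to_dict())
    assert restored.rgcn_n_layers == 4
    assert restored.gpt_config["vocab_size"] == 50259

Storing the sub-configs as dicts rather than config objects is what keeps serialization simple here: PretrainedConfig.to_dict() can emit them verbatim, and __init__ rebuilds from them on load.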