|
""" Prot2Text configuration""" |
|
|
|
from transformers.configuration_utils import PretrainedConfig, AutoConfig |
|
from transformers.utils import logging |
|
|
|
|
|
# Module-level logger, named after this module.
logger = logging.get_logger(__name__)
|
|
|
|
|
class Prot2TextConfig(PretrainedConfig):
    """Configuration for a Prot2Text model.

    Stores the constructor arguments as attributes, together with the
    configurations of two pretrained sub-models kept as plain dictionaries:
    ``gpt_config`` (loaded from ``gpt_model_name`` when not supplied) and
    ``esm_config`` (loaded from ``esm_model_name`` when not supplied).

    Args:
        cross_esm_graph (bool): Stored as-is; consumed by the model code.
        decoder_start_token_id (int): Decoder start token id. Also forwarded
            to the default GPT configuration.
        early_stopping (bool): Stored as-is.
        eos_token_id (int): End-of-sequence token id. Also forwarded to the
            default GPT configuration.
        bos_token_id (int): Beginning-of-sequence token id. Also forwarded to
            the default GPT configuration.
        esm (bool): Stored as-is; consumed by the model code.
        esm_model_name (str): Name or path used to load the default ESM
            configuration.
        gpt_model_name (str): Name or path used to load the default GPT
            configuration.
        length_penalty (float): Stored as-is.
        max_new_tokens (int): Stored as-is and forwarded to the default GPT
            configuration.
        no_repeat_ngram_size (int): Stored as-is.
        pad_token_id (int): Padding token id. Also forwarded to the default
            GPT configuration.
        prot2text_version (str): Stored as-is.
        rgcn (bool): Stored as-is; consumed by the model code.
        rgc_input_dim (int): Stored as-is.
        rgcn_n_layers (int): Stored as-is.
        gpt_config (dict, optional): Pre-built GPT configuration. When
            ``None``, one is loaded with ``AutoConfig.from_pretrained``.
        esm_config (dict, optional): Pre-built ESM configuration. When
            ``None``, one is loaded with ``AutoConfig.from_pretrained``.
        **kwargs: Passed through to ``PretrainedConfig.__init__``.
    """

    model_type = "prot2text"
    keys_to_ignore_at_inference = ["past_key_values"]
    _keys_to_ignore_on_load_missing = [r"transformer"]

    def __init__(
        self,
        cross_esm_graph=True,
        decoder_start_token_id=50257,
        early_stopping=True,
        eos_token_id=50258,
        bos_token_id=50257,
        esm=True,
        esm_model_name="facebook/esm2_t6_8M_UR50D",
        gpt_model_name="gpt2",
        length_penalty=2.0,
        max_new_tokens=256,
        no_repeat_ngram_size=3,
        pad_token_id=50256,
        prot2text_version="1.1",
        rgcn=True,
        rgc_input_dim=67,
        rgcn_n_layers=6,
        gpt_config=None,
        esm_config=None,
        **kwargs,
    ):
        self.cross_esm_graph = cross_esm_graph
        self.decoder_start_token_id = decoder_start_token_id
        self.early_stopping = early_stopping
        self.eos_token_id = eos_token_id
        self.esm = esm
        self.esm_model_name = esm_model_name
        self.gpt_model_name = gpt_model_name
        self.length_penalty = length_penalty
        self.max_new_tokens = max_new_tokens
        self.no_repeat_ngram_size = no_repeat_ngram_size
        self.pad_token_id = pad_token_id
        self.prot2text_version = prot2text_version
        self.rgcn = rgcn
        self.rgc_input_dim = rgc_input_dim
        self.rgcn_n_layers = rgcn_n_layers

        if gpt_config is None:
            # Build the decoder configuration from the pretrained GPT
            # checkpoint, overriding the fields listed below.
            self.gpt_config = AutoConfig.from_pretrained(
                gpt_model_name,
                _name_or_path=gpt_model_name,
                is_encoder_decoder=True,
                use_cache=False,
                add_cross_attention=True,
                bos_token_id=bos_token_id,
                decoder_start_token_id=decoder_start_token_id,
                eos_token_id=eos_token_id,
                max_new_tokens=max_new_tokens,
                # Honour the caller's padding id instead of a hard-coded
                # 50256 (which is still the default value).
                pad_token_id=pad_token_id,
                vocab_size=50259,
                num_beams=1,
                max_length=256,
                min_length=1,
            ).to_dict()
        else:
            self.gpt_config = gpt_config

        if esm_config is None:
            self.esm_config = AutoConfig.from_pretrained(esm_model_name).to_dict()
        else:
            # Only use the caller's value when one was given; assigning it
            # unconditionally would replace the loaded config with None.
            self.esm_config = esm_config

        # Forward the pad / decoder-start ids too: in transformers 4.x the
        # base initializer assigns these attributes itself (None when they
        # are not passed), which would clobber the values set above.
        # NOTE(review): depending on the installed transformers version the
        # base initializer may likewise reset early_stopping, length_penalty
        # and no_repeat_ngram_size to library defaults -- confirm whether
        # those should be forwarded or moved to a GenerationConfig.
        super().__init__(
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            pad_token_id=pad_token_id,
            decoder_start_token_id=decoder_start_token_id,
            **kwargs,
        )