Update lavis/models/protein_models/protein_function_opt.py
lavis/models/protein_models/protein_function_opt.py  CHANGED

@@ -98,26 +98,15 @@ class Blip2ProteinMistral(Blip2ProteinBase):
 
         self.mistral_tokenizer = LlamaTokenizer.from_pretrained("teknium/OpenHermes-2.5-Mistral-7B")
         # self.mistral_tokenizer = LlamaTokenizer.from_pretrained("/cluster/home/wenkai/.cache/huggingface/hub/models--teknium--OpenHermes-2.5-Mistral-7B", use_fast=False)
-        # configuration = MistralConfig()
         self.mistral_tokenizer.pad_token = '<pad>'
-        self.mistral_model = MistralForCausalLM.from_pretrained("teknium/OpenHermes-2.5-Mistral-7B", torch_dtype=torch.float16)
-
-
-        for name, param in self.mistral_model.named_parameters():
-            param.requires_grad = False
-        # self.mistral_model.lm_head = self.mistral_model.lm_head.float()
-        # for param in self.mistral_model.lm_head.parameters():
-        #     param.requires_grad = True
-
-        # self.eos_token_id = self.mistral_tokenizer(
-        #     "\n", add_special_tokens=False
-        # ).input_ids[0]
+        # self.mistral_model = MistralForCausalLM.from_pretrained("teknium/OpenHermes-2.5-Mistral-7B", torch_dtype=torch.float16)
+        self.mistral_model = None
+
         self.eos_token_id = self.mistral_tokenizer(
             "\n", add_special_tokens=False
         ).input_ids[1]
-        print(f"LLM hidden size: {self.mistral_model.config.hidden_size}")
         self.opt_proj = nn.Linear(
-            self.Qformer.config.hidden_size, self.mistral_model.config.hidden_size
+            self.Qformer.config.hidden_size, 4096
        )
 
         self.max_txt_len = max_txt_len
@@ -191,7 +180,6 @@ class Blip2ProteinMistral(Blip2ProteinBase):
         )
         targets = torch.cat([empty_targets, targets], dim=1)
 
-        # inputs_embeds = self.mistral_model.model.decoder.embed_tokens(mistral_tokens.input_ids)
         inputs_embeds = self.mistral_model.model.embed_tokens(mistral_tokens.input_ids)
         inputs_embeds = torch.cat([inputs_mistral, inputs_embeds], dim=1)
         attention_mask = torch.cat([atts_mistral, mistral_tokens.attention_mask], dim=1)
@@ -209,6 +197,7 @@ class Blip2ProteinMistral(Blip2ProteinBase):
     @torch.no_grad()
     def generate(
         self,
+        mistral_model,
         samples,
         # use_nucleus_sampling=False,
         num_beams=15,
@@ -262,8 +251,8 @@ class Blip2ProteinMistral(Blip2ProteinBase):
             truncation=True,
             max_length=self.max_txt_len,
         ).to(self.device)
-
-        inputs_embeds = self.mistral_model.model.embed_tokens(mistral_tokens.input_ids)
+
+        inputs_embeds = mistral_model.model.embed_tokens(mistral_tokens.input_ids)
         inputs_embeds = torch.cat([inputs_mistral, inputs_embeds], dim=1)
         attention_mask = torch.cat([atts_mistral, mistral_tokens.attention_mask], dim=1)
         # if name[0] == 'Pin':
@@ -275,7 +264,7 @@ class Blip2ProteinMistral(Blip2ProteinBase):
         # num_txt = 15
         # return_num_txt = 10
         with torch.no_grad():
-            outputs = self.mistral_model.generate(inputs_embeds=inputs_embeds, attention_mask=attention_mask, min_length=min_length,
+            outputs = mistral_model.generate(inputs_embeds=inputs_embeds, attention_mask=attention_mask, min_length=min_length,
                                                  max_new_tokens=max_length, temperature=temperature, return_dict_in_generate=True,
                                                  output_scores=True,
                                                  repetition_penalty=repetition_penalty, num_beams=num_beams,