Upload folder using huggingface_hub
Browse files- modeling_sa2va_chat.py +4 -3
modeling_sa2va_chat.py
CHANGED
@@ -689,6 +689,7 @@ class Sa2VAChatModel(PreTrainedModel):
|
|
689 |
input=text, round=1, bot_name=self.bot_name)
|
690 |
input_text = past_text + input_text
|
691 |
ids = self.tokenizer.encode(input_text)
|
|
|
692 |
ids = torch.tensor(ids).cuda().unsqueeze(0)
|
693 |
|
694 |
attention_mask = torch.ones_like(ids, dtype=torch.bool)
|
@@ -715,7 +716,8 @@ class Sa2VAChatModel(PreTrainedModel):
|
|
715 |
)
|
716 |
predict = self.tokenizer.decode(
|
717 |
generate_output.sequences[0], skip_special_tokens=False).strip()
|
718 |
-
|
|
|
719 |
# if have seg result, find the seg hidden states
|
720 |
hidden_states = generate_output.hidden_states
|
721 |
last_hidden_states = [item[-1][0] for item in hidden_states]
|
@@ -737,8 +739,7 @@ class Sa2VAChatModel(PreTrainedModel):
|
|
737 |
masks = masks.sigmoid() > 0.5
|
738 |
masks = masks.cpu().numpy()
|
739 |
ret_masks.append(masks)
|
740 |
-
|
741 |
-
return {'prediction': predict, 'prediction_masks': ret_masks,}
|
742 |
|
743 |
def get_seg_hidden_states(hidden_states, output_ids, seg_id):
|
744 |
seg_mask = output_ids == seg_id
|
|
|
689 |
input=text, round=1, bot_name=self.bot_name)
|
690 |
input_text = past_text + input_text
|
691 |
ids = self.tokenizer.encode(input_text)
|
692 |
+
ret_past_text = self.tokenizer.decode(ids)
|
693 |
ids = torch.tensor(ids).cuda().unsqueeze(0)
|
694 |
|
695 |
attention_mask = torch.ones_like(ids, dtype=torch.bool)
|
|
|
716 |
)
|
717 |
predict = self.tokenizer.decode(
|
718 |
generate_output.sequences[0], skip_special_tokens=False).strip()
|
719 |
+
ret_past_text = ret_past_text + self.tokenizer.decode(
|
720 |
+
generate_output.sequences[0], skip_special_tokens=False)
|
721 |
# if have seg result, find the seg hidden states
|
722 |
hidden_states = generate_output.hidden_states
|
723 |
last_hidden_states = [item[-1][0] for item in hidden_states]
|
|
|
739 |
masks = masks.sigmoid() > 0.5
|
740 |
masks = masks.cpu().numpy()
|
741 |
ret_masks.append(masks)
|
742 |
+
return {'prediction': predict, 'prediction_masks': ret_masks, "past_text": ret_past_text}
|
|
|
743 |
|
744 |
def get_seg_hidden_states(hidden_states, output_ids, seg_id):
|
745 |
seg_mask = output_ids == seg_id
|