xiaotinghe committed
Commit ea3bbbe • 1 Parent(s): 15a8eb4
Upload BaiChuanForCausalLM

modeling_baichuan.py CHANGED (+33 -0)
@@ -23,6 +23,8 @@ from transformers.activations import ACT2FN
 from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, \
     SequenceClassifierOutputWithPast
 from transformers.utils import logging, add_start_docstrings_to_model_forward, replace_return_docstrings
+from transformers.generation.logits_process import LogitsProcessor
+from transformers.generation.utils import LogitsProcessorList
 
 import math
 from typing import List, Optional, Tuple, Union
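The two new imports support the additions in the hunks below: a custom `LogitsProcessor` subclass and the `LogitsProcessorList` that the new `chat` helper assembles before calling `generate`.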
@@ -35,6 +37,13 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
 logger = logging.get_logger(__name__)
 
+class InvalidScoreLogitsProcessor(LogitsProcessor):
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+        if torch.isnan(scores).any() or torch.isinf(scores).any():
+            scores.zero_()
+            scores[..., 5] = 5e4
+        return scores
+
 # Copied from transformers.models.bart.modeling_bart._make_causal_mask
 def _make_causal_mask(
     input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
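`InvalidScoreLogitsProcessor` guards generation against numerically invalid logits: if any score is NaN or inf, it zeroes the whole distribution and forces token id 5 (presumably an EOS-like id in this vocabulary) so decoding terminates cleanly instead of sampling garbage. A minimal sketch of the behavior outside of `generate` (tensor values are illustrative):

```python
import torch

proc = InvalidScoreLogitsProcessor()
scores = torch.tensor([[1.0, float("nan"), 2.0, 3.0, 4.0, 0.5]])
scores = proc(input_ids=None, scores=scores)  # input_ids is unused by this processor
# scores is now all zeros except scores[..., 5] == 5e4,
# so both greedy and sampled decoding fall back to token id 5
print(scores)
```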
@@ -669,3 +678,27 @@ class BaiChuanForCausalLM(PreTrainedModel):
         for layer_past in past_key_values:
             reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
         return reordered_past
+
+    def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 2048, num_beams=1,
+             do_sample=True, top_p=0.7, temperature=0.95, logits_processor=None, **kwargs):
+        if history is None:
+            history = []
+        if logits_processor is None:
+            logits_processor = LogitsProcessorList()
+        logits_processor.append(InvalidScoreLogitsProcessor())
+        gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p,
+                      "temperature": temperature, "logits_processor": logits_processor, "use_cache": True, **kwargs}
+        prompt_template = '###Human: {instruction}###Assistant: {output}'
+        if not history:
+            prompt = prompt_template.format(instruction=query, output='')
+        else:
+            prompt = ""
+            for i, (old_query, response) in enumerate(history):
+                prompt += prompt_template.format(instruction=old_query, output=response)
+            prompt += prompt_template.format(instruction=query, output='')
+        inputs = tokenizer(prompt, return_tensors='pt')
+        inputs = inputs.to(self.device)
+        outputs = self.generate(**inputs, **gen_kwargs)
+        response = tokenizer.decode(outputs.tolist()[0][len(inputs["input_ids"][0]):], skip_special_tokens=True)
+        history = history + [(query, response)]
+        return response, history
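For reference, a minimal sketch of calling the new `chat` API. The repo id below is a placeholder, and loading via `trust_remote_code=True` is an assumption about how this checkpoint is published; the method itself builds the prompt from the `###Human: ...###Assistant: ...` template shown in the hunk above:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "path/to/this-checkpoint"  # placeholder: the repo that bundles modeling_baichuan.py
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)

# First turn: history defaults to an empty list.
response, history = model.chat(tokenizer, "What is the capital of France?")
print(response)

# Follow-up turn: pass the returned history so prior turns are replayed in the prompt.
response, history = model.chat(tokenizer, "And its population?", history=history)
print(response)
```

Note that `history` is re-rendered into the prompt verbatim on every call, so long conversations will eventually run up against `max_length=2048`.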