xiaotinghe committed
Commit ea3bbbe
1 Parent(s): 15a8eb4

Upload BaiChuanForCausalLM

Files changed (1)
  1. modeling_baichuan.py +33 -0
modeling_baichuan.py CHANGED
@@ -23,6 +23,8 @@ from transformers.activations import ACT2FN
 from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, \
     SequenceClassifierOutputWithPast
 from transformers.utils import logging, add_start_docstrings_to_model_forward, replace_return_docstrings
+from transformers.generation.logits_process import LogitsProcessor
+from transformers.generation.utils import LogitsProcessorList
 
 import math
 from typing import List, Optional, Tuple, Union
@@ -35,6 +37,13 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
 logger = logging.get_logger(__name__)
 
+class InvalidScoreLogitsProcessor(LogitsProcessor):
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+        if torch.isnan(scores).any() or torch.isinf(scores).any():
+            scores.zero_()
+            scores[..., 5] = 5e4
+        return scores
+
 # Copied from transformers.models.bart.modeling_bart._make_causal_mask
 def _make_causal_mask(
     input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
@@ -669,3 +678,27 @@ class BaiChuanForCausalLM(PreTrainedModel):
         for layer_past in past_key_values:
            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
         return reordered_past
+
+    def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 2048, num_beams=1,
+             do_sample=True, top_p=0.7, temperature=0.95, logits_processor=None, **kwargs):
+        if history is None:
+            history = []
+        if logits_processor is None:
+            logits_processor = LogitsProcessorList()
+        logits_processor.append(InvalidScoreLogitsProcessor())
+        gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p,
+                      "temperature": temperature, "logits_processor": logits_processor, "use_cache": True, **kwargs}
+        prompt_template = '###Human: {instruction}###Assistant: {output}'
+        if not history:
+            prompt = prompt_template.format(instruction=query, output='')
+        else:
+            prompt = ""
+            for i, (old_query, response) in enumerate(history):
+                prompt += prompt_template.format(instruction=old_query, output=response)
+            prompt += prompt_template.format(instruction=query, output='')
+        inputs = tokenizer(prompt, return_tensors='pt')
+        inputs = inputs.to(self.device)
+        outputs = self.generate(**inputs, **gen_kwargs)
+        response = tokenizer.decode(outputs.tolist()[0][len(inputs["input_ids"][0]):], skip_special_tokens=True)
+        history = history + [(query, response)]
+        return response, history
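
For readers trying the new API: chat() renders the running (query, response) history through the '###Human: {instruction}###Assistant: {output}' template, delegates to self.generate(), and returns the decoded continuation plus the updated history. A minimal usage sketch follows; it is not part of the commit, the repo id is a placeholder assumption, and trust_remote_code=True is needed so this custom modeling file is actually loaded:

# Usage sketch only; "baichuan-inc/Baichuan-7B" is an assumed placeholder for
# whichever repository ships this modeling_baichuan.py.
from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "baichuan-inc/Baichuan-7B"
tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True).eval()

# First turn: the prompt is '###Human: <query>###Assistant: '.
response, history = model.chat(tokenizer, "What does attention do in a transformer?")
print(response)

# Later turns: every (query, response) pair in history is re-rendered into the
# prompt ahead of the new query, so context carries across turns.
response, history = model.chat(tokenizer, "Summarize that in one sentence.", history=history)
print(response)

Note that chat() always appends an InvalidScoreLogitsProcessor: when the scores tensor contains NaN or Inf, it zeroes all logits and pins token id 5 to 5e4, so generation falls back to a fixed token instead of crashing on an invalid distribution.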