ruixie committed
Commit 26e6676
1 Parent(s): e813cde

Update modeling_codeshell.py

Files changed (1)
  1. modeling_codeshell.py +121 -5
modeling_codeshell.py CHANGED
@@ -32,14 +32,17 @@
 """PyTorch CodeShell model."""
 import os
 import math
-from typing import List, Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union, Callable
+from threading import Thread
+from queue import Queue
+
 
 import torch
 import torch.utils.checkpoint
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
-from transformers import PreTrainedModel, PretrainedConfig
+from transformers import LogitsProcessorList, StoppingCriteriaList, StoppingCriteria, PreTrainedModel, PretrainedConfig
 from transformers.generation.utils import GenerationConfig
 
 from transformers.activations import ACT2FN
@@ -54,7 +57,6 @@ from transformers.utils import (
 )
 from .configuration_codeshell import CodeShellConfig
 
-
 # Fused kernels
 # Use separate functions for each case because conditionals prevent kernel fusion.
 # TODO: Could have better fused kernels depending on scaling, dropout and head mask.
@@ -743,6 +745,62 @@ class CodeShellModel(CodeShellPreTrainedModel):
             hidden_states=all_hidden_states,
             attentions=all_self_attentions,
         )
+
+class EndOfFunctionCriteria(StoppingCriteria):
+    """Custom `StoppingCriteria` which checks if all generated functions in the batch are completed."""
+    def __init__(self, input_lengths, eof_strings, tokenizer):
+        self.input_lengths = input_lengths
+        self.eof_strings = eof_strings
+        self.tokenizer = tokenizer
+
+    def __call__(self, input_ids, scores, **kwargs):
+        """Returns true if all generated sequences contain any of the end-of-function strings."""
+        decoded_generations = []
+        for _input_ids, input_length in zip(input_ids, self.input_lengths):
+            decoded_generations.append(self.tokenizer.decode(_input_ids[input_length:]))
+        done = []
+        for decoded_generation in decoded_generations:
+            done.append(
+                any(
+                    [
+                        stop_string in decoded_generation
+                        for stop_string in self.eof_strings
+                    ]
+                )
+            )
+        return all(done)
+
+class TextIterStreamer:
+    def __init__(self, tokenizer, skip_prompt=False, skip_special_tokens=False):
+        self.tokenizer = tokenizer
+        self.skip_prompt = skip_prompt
+        self.skip_special_tokens = skip_special_tokens
+        self.tokens = []
+        self.text_queue = Queue()
+        self.next_tokens_are_prompt = True
+
+    def put(self, value):
+        if self.skip_prompt and self.next_tokens_are_prompt:
+            self.next_tokens_are_prompt = False
+        else:
+            if len(value.shape) > 1:
+                value = value[0]
+            self.tokens.extend(value.tolist())
+            self.text_queue.put(
+                self.tokenizer.decode(self.tokens, skip_special_tokens=self.skip_special_tokens))
+
+    def end(self):
+        self.text_queue.put(None)
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        value = self.text_queue.get()
+        if value is None:
+            raise StopIteration()
+        else:
+            return value
 
 
 @add_start_docstrings(
@@ -886,6 +944,65 @@ class CodeShellForCausalLM(CodeShellPreTrainedModel):
                 tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
             )
         return reordered_past
+
+
+    def build_chat_input(self, query, history, tokenizer, max_new_tokens=None):
+        user_name = "\n## human:"
+        ai_name = "\n## assistant: "
+        stop = '|<end>|'
+
+        prompt = ''
+        for q, r in history:
+            prompt += f"{user_name}{q}{stop}"
+            prompt += f"{ai_name}{r}{stop}"
+        prompt += f"{user_name}{query}{stop}"
+        prompt += ai_name.rstrip()
+
+        max_new_tokens = max_new_tokens or self.generation_config.max_new_tokens
+        max_input_tokens = self.config.n_positions - max_new_tokens
+
+        input_tokens = tokenizer.encode(prompt)
+        input_tokens = input_tokens[-max_input_tokens:] # truncate left
+        return torch.LongTensor([input_tokens]).to(self.device)
+
+    def chat(self, query, history, tokenizer, stream=False,
+             generation_config: Optional[GenerationConfig]=None):
+        generation_config = generation_config or self.generation_config
+        input_ids = self.build_chat_input(query, history, tokenizer, generation_config.max_new_tokens)
+        stopping_criteria = StoppingCriteriaList(
+            [EndOfFunctionCriteria([len(input_ids[0])], ['|<end>|', '<|endoftext|>'], tokenizer)]
+        )
+
+        if stream:
+            streamer = TextIterStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+            Thread(target=self.generate, kwargs=dict(
+                inputs=input_ids, streamer=streamer,
+                stopping_criteria=stopping_criteria,
+                generation_config=generation_config,
+            )).start()
+            return streamer
+        else:
+            outputs = self.generate(input_ids, generation_config=generation_config, stopping_criteria=stopping_criteria)
+            response = tokenizer.decode(outputs[0][len(input_ids[0]):], skip_special_tokens=True)
+            return response
+
+    def generate_stream(self, prompt, tokenizer, generation_config=None, **kwargs):
+        generation_config = generation_config or self.generation_config
+        max_input_tokens = self.config.n_positions - self.generation_config.max_new_tokens
+
+        input_ids = tokenizer.encode(prompt)
+        input_ids = input_ids[-max_input_tokens:] # truncate left
+
+        stopping_criteria = StoppingCriteriaList(
+            [EndOfFunctionCriteria([len(input_ids[0])], ['|<end>|', '<|endoftext|>'], tokenizer)]
+        )
+
+        streamer = TextIterStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+        Thread(target=self.generate, kwargs=dict(
+            inputs=input_ids, stopping_criteria=stopping_criteria, **kwargs
+        )).start()
+        return streamer
+
 
 class CodeShell4bitForCausalLM(CodeShellForCausalLM):
     def __init__(self, config):
@@ -966,5 +1083,4 @@ class CodeShell4bitForCausalLM(CodeShellForCausalLM):
         if device_map is not None:
            model = model.to(torch.device(device_map))
 
-        return model
-
+        return model
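
For reference, a minimal usage sketch of the chat and streaming interface this commit adds. The checkpoint path below is a hypothetical placeholder, and the sketch assumes the repository is loaded with trust_remote_code=True so that chat() and the streamer classes from this file are attached to the model:

from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "your-org/codeshell-chat"  # hypothetical placeholder path
tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(checkpoint, trust_remote_code=True)

history = []  # list of (query, response) pairs from earlier turns

# Non-streaming: chat() returns the decoded assistant reply as a string.
response = model.chat("Write a binary search in Python.", history, tokenizer)

# Streaming: chat(stream=True) runs generate() in a background thread and
# returns a TextIterStreamer that yields the cumulative decoded text so far.
for partial_text in model.chat("Write a binary search in Python.", history, tokenizer, stream=True):
    print(partial_text, end="\r")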