ruixie committed on
Commit
a89b834
1 Parent(s): 37ecd4f

Update modeling_codeshell.py

Files changed (1)
  1. modeling_codeshell.py +206 -4
modeling_codeshell.py CHANGED
@@ -30,14 +30,21 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """PyTorch CodeShell model."""
+import os
 import math
-from typing import List, Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union, Callable
+from threading import Thread
+from queue import Queue
+
 
 import torch
 import torch.utils.checkpoint
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
+from transformers import LogitsProcessorList, StoppingCriteriaList, StoppingCriteria, PreTrainedModel, PretrainedConfig
+from transformers.generation.utils import GenerationConfig
+
 from transformers.activations import ACT2FN
 from transformers.modeling_outputs import (
     BaseModelOutputWithPastAndCrossAttentions,
@@ -50,7 +57,6 @@ from transformers.utils import (
 )
 from .configuration_codeshell import CodeShellConfig
 
-
 # Fused kernels
 # Use separate functions for each case because conditionals prevent kernel fusion.
 # TODO: Could have better fused kernels depending on scaling, dropout and head mask.
@@ -739,6 +745,62 @@ class CodeShellModel(CodeShellPreTrainedModel):
             hidden_states=all_hidden_states,
             attentions=all_self_attentions,
         )
+
+class EndOfFunctionCriteria(StoppingCriteria):
+    """Custom `StoppingCriteria` which checks if all generated functions in the batch are completed."""
+    def __init__(self, input_lengths, eof_strings, tokenizer):
+        self.input_lengths = input_lengths
+        self.eof_strings = eof_strings
+        self.tokenizer = tokenizer
+
+    def __call__(self, input_ids, scores, **kwargs):
+        """Returns true if all generated sequences contain any of the end-of-function strings."""
+        decoded_generations = []
+        for _input_ids, input_length in zip(input_ids, self.input_lengths):
+            decoded_generations.append(self.tokenizer.decode(_input_ids[input_length:]))
+        done = []
+        for decoded_generation in decoded_generations:
+            done.append(
+                any(
+                    [
+                        stop_string in decoded_generation
+                        for stop_string in self.eof_strings
+                    ]
+                )
+            )
+        return all(done)
+
+class TextIterStreamer:
+    def __init__(self, tokenizer, skip_prompt=False, skip_special_tokens=False):
+        self.tokenizer = tokenizer
+        self.skip_prompt = skip_prompt
+        self.skip_special_tokens = skip_special_tokens
+        self.tokens = []
+        self.text_queue = Queue()
+        self.next_tokens_are_prompt = True
+
+    def put(self, value):
+        if self.skip_prompt and self.next_tokens_are_prompt:
+            self.next_tokens_are_prompt = False
+        else:
+            if len(value.shape) > 1:
+                value = value[0]
+            self.tokens.extend(value.tolist())
+            self.text_queue.put(
+                self.tokenizer.decode(self.tokens, skip_special_tokens=self.skip_special_tokens))
+
+    def end(self):
+        self.text_queue.put(None)
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        value = self.text_queue.get()
+        if value is None:
+            raise StopIteration()
+        else:
+            return value
 
 
 @add_start_docstrings(
@@ -762,10 +824,10 @@ class CodeShellForCausalLM(CodeShellPreTrainedModel):
     def quantize(self, bits: int):
        try:
            import bitsandbytes
-           from .quantizer import quantize_online
+           from .quantizer import quantize
        except ImportError:
            raise ImportError(f"Needs bitsandbytes to run quantize.")
-       return quantize_online(self, bits)
+       return quantize(self, bits)
 
     def get_output_embeddings(self):
         return self.lm_head
@@ -882,3 +944,143 @@ class CodeShellForCausalLM(CodeShellPreTrainedModel):
                 tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
             )
         return reordered_past
+
+
+    def build_chat_input(self, query, history, tokenizer, max_new_tokens=None):
+        user_name = "\n## human:"
+        ai_name = "\n## assistant: "
+        stop = '|<end>|'
+
+        prompt = ''
+        for q, r in history:
+            prompt += f"{user_name}{q}{stop}"
+            prompt += f"{ai_name}{r}{stop}"
+        prompt += f"{user_name}{query}{stop}"
+        prompt += ai_name.rstrip()
+
+        max_new_tokens = max_new_tokens or self.generation_config.max_new_tokens
+        max_input_tokens = self.config.n_positions - max_new_tokens
+
+        input_tokens = tokenizer.encode(prompt)
+        input_tokens = input_tokens[-max_input_tokens:]  # truncate left
+        return torch.LongTensor([input_tokens]).to(self.device)
+
+    def chat(self, query, history, tokenizer, stream=False,
+             generation_config: Optional[GenerationConfig]=None):
+        generation_config = generation_config or self.generation_config
+        input_ids = self.build_chat_input(query, history, tokenizer, generation_config.max_new_tokens)
+        stopping_criteria = StoppingCriteriaList(
+            [EndOfFunctionCriteria([len(input_ids[0])], ['|<end>|', '<|endoftext|>'], tokenizer)]
+        )
+
+        if stream:
+            streamer = TextIterStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+            Thread(target=self.generate, kwargs=dict(
+                inputs=input_ids, streamer=streamer,
+                stopping_criteria=stopping_criteria,
+                generation_config=generation_config,
+            )).start()
+            return streamer
+        else:
+            outputs = self.generate(input_ids, generation_config=generation_config, stopping_criteria=stopping_criteria)
+            response = tokenizer.decode(outputs[0][len(input_ids[0]):], skip_special_tokens=True)
+            return response
+
+    def generate_stream(self, prompt, tokenizer, generation_config=None, **kwargs):
+        generation_config = generation_config or self.generation_config
+        max_input_tokens = self.config.n_positions - self.generation_config.max_new_tokens
+
+        input_ids = tokenizer.encode(prompt)
+        input_ids = input_ids[-max_input_tokens:]  # truncate left
+
+        stopping_criteria = StoppingCriteriaList(
+            [EndOfFunctionCriteria([len(input_ids[0])], ['|<end>|', '<|endoftext|>'], tokenizer)]
+        )
+
+        streamer = TextIterStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+        Thread(target=self.generate, kwargs=dict(
+            inputs=input_ids, stopping_criteria=stopping_criteria, **kwargs
+        )).start()
+        return streamer
+
+
+class CodeShell4bitForCausalLM(CodeShellForCausalLM):
+    def __init__(self, config):
+        CodeShellPreTrainedModel.__init__(self, config)
+        self.transformer = CodeShellModel(config)
+        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
+
+        try:
+            import bitsandbytes
+            from .quantizer import quantize_offline
+            quantize_offline(self)
+        except ImportError:
+            raise ImportError(f"Needs bitsandbytes to run quantize.")
+
+        self.post_init()
+
+    @classmethod
+    def from_pretrained(
+        cls,
+        pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
+        *model_args,
+        config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
+        cache_dir: Optional[Union[str, os.PathLike]] = None,
+        ignore_mismatched_sizes: bool = False,
+        force_download: bool = False,
+        local_files_only: bool = False,
+        token: Optional[Union[str, bool]] = None,
+        revision: str = "main",
+        use_safetensors: bool = None,
+        **kwargs,
+    ):
+        if not isinstance(config, PretrainedConfig):
+            config_path = config if config is not None else pretrained_model_name_or_path
+            config, _ = cls.config_class.from_pretrained(
+                config_path,
+                cache_dir=cache_dir,
+                return_unused_kwargs=True,
+                force_download=force_download,
+                resume_download=False,
+                proxies=None,
+                local_files_only=local_files_only,
+                token=token,
+                revision=revision,
+                subfolder="",
+                _from_auto=False,
+                _from_pipeline=None,
+                **kwargs,
+            )
+
+        # Load config if we don't provide a configuration
+        from .quantizer import load_state_dict_for_qunantied_model
+        model = cls(config)
+        state_dict = torch.load(os.path.join(pretrained_model_name_or_path, 'pytorch_model.bin'), map_location="cpu")
+        model = load_state_dict_for_qunantied_model(model, state_dict)
+        model.eval()
+
+        # If it is a model with generation capabilities, attempt to load the generation config
+        if model.can_generate():
+            try:
+                model.generation_config = GenerationConfig.from_pretrained(
+                    pretrained_model_name_or_path,
+                    cache_dir=cache_dir,
+                    force_download=force_download,
+                    resume_download=False,
+                    proxies=None,
+                    local_files_only=local_files_only,
+                    token=token,
+                    revision=revision,
+                    subfolder="",
+                    _from_auto=False,
+                    _from_pipeline=None,
+                    **kwargs,
+                )
+            except (OSError, TypeError):
+                pass
+
+        device_map = kwargs.pop("device_map", None)
+        if device_map is not None:
+            model = model.to(torch.device(device_map))
+
+        return model
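
With this change, CodeShellForCausalLM gains chat() and generate_stream() helpers built on the new TextIterStreamer and EndOfFunctionCriteria classes. A minimal usage sketch follows; the checkpoint id "WisdomShell/CodeShell-7B-Chat", the trust_remote_code setup, and the example prompts are illustrative assumptions, not part of this commit.

# Sketch: call the new chat() helper, both streaming and blocking.
# Assumption: the checkpoint id below is a placeholder; point it at your own CodeShell chat checkpoint.
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "WisdomShell/CodeShell-7B-Chat"  # placeholder checkpoint id
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True).eval()

history = []  # earlier turns as (query, response) pairs, consumed by build_chat_input()

# stream=True starts generate() on a background thread and returns a TextIterStreamer;
# each iteration yields the cumulative decoded reply so far.
for partial in model.chat("Write a binary search in Python.", history, tokenizer, stream=True):
    print(partial)

# stream=False blocks until generation finishes and returns the final response string.
response = model.chat("Now add type hints.", history, tokenizer, stream=False)
print(response)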
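The new CodeShell4bitForCausalLM quantizes its layers with quantize_offline() at construction and overrides from_pretrained() to load a pre-quantized pytorch_model.bin from a directory, honoring an optional device_map kwarg. A hedged sketch, assuming modeling_codeshell.py is importable locally and "./codeshell-chat-int4" is a hypothetical directory holding the quantized checkpoint and its config:

# Sketch: load a pre-quantized 4-bit checkpoint via the class added in this commit.
# Requires bitsandbytes; the path and device string are illustrative assumptions.
from transformers import AutoTokenizer
from modeling_codeshell import CodeShell4bitForCausalLM

checkpoint_dir = "./codeshell-chat-int4"  # hypothetical directory with config.json + pytorch_model.bin
model = CodeShell4bitForCausalLM.from_pretrained(checkpoint_dir, device_map="cuda:0")
tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir, trust_remote_code=True)

print(model.chat("Hello, CodeShell!", [], tokenizer, stream=False))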