DeepLangLvcc committed
Commit 9a542c4
1 Parent(s): 1439daa

add chat function

Files changed (1):
1. modeling_lingowhale.py (+121, -0)
modeling_lingowhale.py CHANGED
@@ -19,6 +19,8 @@
 
 import math
 import os
+from queue import Queue
+from threading import Thread
 from typing import List, Optional, Tuple, Union
 
 import torch
@@ -28,6 +30,7 @@ from torch.nn import CrossEntropyLoss
 from torch.nn import functional as F
 from transformers import PretrainedConfig, PreTrainedModel
 from transformers.activations import ACT2FN
+from transformers.generation.utils import GenerationConfig
 from transformers.modeling_outputs import (BaseModelOutputWithPast,
                                            CausalLMOutputWithPast)
 from transformers.utils import logging
@@ -106,6 +109,44 @@ def _expand_mask(mask: torch.Tensor,
                                   torch.finfo(dtype).min)
 
 
+class TextIterStreamer:
+
+    def __init__(self,
+                 tokenizer,
+                 skip_prompt=False,
+                 skip_special_tokens=False):
+        self.tokenizer = tokenizer
+        self.skip_prompt = skip_prompt
+        self.skip_special_tokens = skip_special_tokens
+        self.tokens = []
+        self.text_queue = Queue()
+        self.next_tokens_are_prompt = True
+
+    def put(self, value):
+        if self.skip_prompt and self.next_tokens_are_prompt:
+            self.next_tokens_are_prompt = False
+        else:
+            if len(value.shape) > 1:
+                value = value[0]
+            self.tokens.extend(value.tolist())
+            self.text_queue.put(
+                self.tokenizer.decode(
+                    self.tokens, skip_special_tokens=self.skip_special_tokens))
+
+    def end(self):
+        self.text_queue.put(None)
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        value = self.text_queue.get()
+        if value is None:
+            raise StopIteration()
+        else:
+            return value
+
+
 class LingoWhaleRMSNorm(torch.nn.Module):
 
     def __init__(self, hidden_size: int, eps: float = 1e-6):
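The TextIterStreamer added above follows the streamer protocol of transformers' generate(): generate() calls put() with each new batch of token ids and end() when generation finishes, while the consumer iterates to receive the cumulative decoded text. A minimal sketch of driving it directly, assuming model, tokenizer, and input_ids are already set up (the max_new_tokens value is arbitrary):

    from threading import Thread

    streamer = TextIterStreamer(tokenizer,
                                skip_prompt=True,
                                skip_special_tokens=True)
    # generate() blocks, so it runs on a worker thread while the caller
    # drains the queue -- the same pattern chat(stream=True) uses below.
    Thread(target=model.generate,
           kwargs=dict(inputs=input_ids, streamer=streamer,
                       max_new_tokens=64)).start()

    printed = ""
    for text in streamer:
        # put() re-decodes the full token list, so each item is the whole
        # text so far; print only the new suffix.
        print(text[len(printed):], end="", flush=True)
        printed = text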
@@ -931,6 +972,86 @@ class LingoWhaleForCausalLM(LingoWhalePreTrainedModel):
         })
         return model_inputs
 
+    def build_chat_input(self,
+                         tokenizer,
+                         messages: List[dict],
+                         max_new_tokens: int = 0,
+                         user_token_ids=[3],
+                         assistant_tokens=[4]):
+        max_input_tokens = self.config.model_max_length - max_new_tokens
+
+        def _parse_messages(messages):
+            # Group messages into rounds, each starting at a 'user' turn.
+            chat_rounds, chat_round = [], []
+
+            for message in messages:
+                if message['role'] == 'user' and len(chat_round) > 0:
+                    chat_rounds.append(chat_round)
+                    chat_round = []
+                chat_round.append(message)
+
+            if len(chat_round) > 0:
+                chat_rounds.append(chat_round)
+
+            return chat_rounds
+
+        chat_rounds = _parse_messages(messages)[::-1]  # newest round first
+
+        def get_chat_tokens(tokenizer, chat_round, user_token_ids,
+                            assistant_tokens):
+            tokens = []
+            tokens += user_token_ids
+            assert len(chat_round) < 3
+
+            if len(chat_round) == 1:
+                tokens += tokenizer.encode(chat_round[0]['content'])
+                tokens += assistant_tokens
+            else:
+                tokens += tokenizer.encode(chat_round[0]['content'])
+                tokens += assistant_tokens
+                tokens += tokenizer.encode(chat_round[1]['content'])
+
+            return tokens
+
+        input_tokens = []
+        for chat_round in chat_rounds:
+            chat_tokens = get_chat_tokens(tokenizer, chat_round,
+                                          user_token_ids, assistant_tokens)
+            if len(chat_tokens + input_tokens) > max_input_tokens:
+                break  # older rounds no longer fit in the context window
+
+            input_tokens = chat_tokens + input_tokens
+        return torch.LongTensor([input_tokens]).to(self.device)
+
+    def chat(self,
+             tokenizer,
+             messages: List[dict],
+             stream=False,
+             generation_config: Optional[GenerationConfig] = None,
+             max_new_tokens=100):
+
+        # Prefer the budget from generation_config when one is provided.
+        if generation_config is not None and generation_config.max_new_tokens:
+            max_new_tokens = generation_config.max_new_tokens
+
+        input_ids = self.build_chat_input(tokenizer, messages, max_new_tokens)
+        if stream:
+            streamer = TextIterStreamer(tokenizer,
+                                        skip_prompt=True,
+                                        skip_special_tokens=True)
+            Thread(target=self.generate,
+                   kwargs=dict(inputs=input_ids,
+                               streamer=streamer,
+                               generation_config=generation_config)).start()
+
+            return streamer
+        else:
+            outputs = self.generate(input_ids,
+                                    generation_config=generation_config)
+            response = tokenizer.decode(outputs[0][len(input_ids[0]):],
+                                        skip_special_tokens=True)
+            return response
+
     @staticmethod
     def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
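End to end, the new chat() entry point can be exercised as in the following sketch (the checkpoint name is illustrative; any repository that ships this modeling file and is loaded with trust_remote_code=True behaves the same):

    from transformers import AutoModelForCausalLM, AutoTokenizer

    path = "deeplang-ai/LingoWhale-8B"  # illustrative checkpoint name
    tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(path, trust_remote_code=True)

    # build_chat_input() expects alternating 'user'/'assistant' dicts and
    # keeps as many of the most recent rounds as fit in model_max_length.
    messages = [{"role": "user", "content": "Hello!"}]

    # Non-streaming: returns the decoded reply as a single string.
    print(model.chat(tokenizer, messages))

    # Streaming: returns a TextIterStreamer; each item is the full text so far.
    printed = ""
    for text in model.chat(tokenizer, messages, stream=True):
        print(text[len(printed):], end="", flush=True)
        printed = text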