DeepLangLvcc committed
Commit 9a542c4 • 1 Parent(s): 1439daa

add chat function

- modeling_lingowhale.py +121 -0
modeling_lingowhale.py
CHANGED
@@ -19,6 +19,8 @@
 
 import math
 import os
+from queue import Queue
+from threading import Thread
 from typing import List, Optional, Tuple, Union
 
 import torch
@@ -28,6 +30,7 @@ from torch.nn import CrossEntropyLoss
 from torch.nn import functional as F
 from transformers import PretrainedConfig, PreTrainedModel
 from transformers.activations import ACT2FN
+from transformers.generation.utils import GenerationConfig
 from transformers.modeling_outputs import (BaseModelOutputWithPast,
                                            CausalLMOutputWithPast)
 from transformers.utils import logging
@@ -106,6 +109,44 @@ def _expand_mask(mask: torch.Tensor,
                                      torch.finfo(dtype).min)
 
 
+class TextIterStreamer:
+
+    def __init__(self,
+                 tokenizer,
+                 skip_prompt=False,
+                 skip_special_tokens=False):
+        self.tokenizer = tokenizer
+        self.skip_prompt = skip_prompt
+        self.skip_special_tokens = skip_special_tokens
+        self.tokens = []
+        self.text_queue = Queue()
+        self.next_tokens_are_prompt = True
+
+    def put(self, value):
+        if self.skip_prompt and self.next_tokens_are_prompt:
+            self.next_tokens_are_prompt = False
+        else:
+            if len(value.shape) > 1:
+                value = value[0]
+            self.tokens.extend(value.tolist())
+            self.text_queue.put(
+                self.tokenizer.decode(
+                    self.tokens, skip_special_tokens=self.skip_special_tokens))
+
+    def end(self):
+        self.text_queue.put(None)
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        value = self.text_queue.get()
+        if value is None:
+            raise StopIteration()
+        else:
+            return value
+
+
 class LingoWhaleRMSNorm(torch.nn.Module):
 
     def __init__(self, hidden_size: int, eps: float = 1e-6):
@@ -931,6 +972,86 @@ class LingoWhaleForCausalLM(LingoWhalePreTrainedModel):
         })
         return model_inputs
 
+    def build_chat_input(self,
+                         tokenizer,
+                         messages: List[dict],
+                         max_new_tokens: int = 0,
+                         user_token_ids=[3],
+                         assistant_tokens=[4]):
+        max_input_tokens = self.config.model_max_length - max_new_tokens
+
+        def _parse_messages(messages):
+
+            chat_rounds, chat_round = [], []
+
+            for message in messages:
+                if message['role'] == 'user' and len(chat_round) > 0:
+                    chat_rounds.append(chat_round)
+                    chat_round = []
+                chat_round.append(message)
+
+            if len(chat_round) > 0:
+                chat_rounds.append(chat_round)
+
+            return chat_rounds
+
+        chat_rounds = _parse_messages(messages)[::-1]
+
+        def get_chat_tokens(tokenizer, chat_round, user_token_ids,
+                            assistant_tokens):
+            tokens = []
+            tokens += user_token_ids
+            assert len(chat_round) < 3
+
+            if len(chat_round) == 1:
+                tokens += tokenizer.encode(chat_round[0]['content'])
+                tokens += assistant_tokens
+            else:
+                tokens += tokenizer.encode(chat_round[0]['content'])
+                tokens += assistant_tokens
+                tokens += tokenizer.encode(chat_round[1]['content'])
+
+            return tokens
+
+        input_tokens = []
+        for chat_round in chat_rounds:
+            chat_tokens = get_chat_tokens(tokenizer, chat_round,
+                                          user_token_ids, assistant_tokens)
+            if len(chat_tokens + input_tokens) > max_input_tokens:
+                return input_tokens
+
+            input_tokens = chat_tokens + input_tokens
+        return torch.LongTensor([input_tokens]).to(self.device)
+
+    def chat(self,
+             tokenizer,
+             messages: List[dict],
+             stream=False,
+             generation_config: Optional[GenerationConfig] = None,
+             max_new_tokens=100):
+
+
+        if generation_config is not None:
+            max_new_tokens = generation_config.max_new_tokens
+
+        input_ids = self.build_chat_input(tokenizer, messages, max_new_tokens)
+        if stream:
+            streamer = TextIterStreamer(tokenizer,
+                                        skip_prompt=True,
+                                        skip_special_tokens=True)
+            Thread(target=self.generate,
+                   kwargs=dict(inputs=input_ids,
+                               streamer=streamer,
+                               generation_config=generation_config)).start()
+
+            return streamer
+        else:
+            outputs = self.generate(input_ids,
+                                    generation_config=generation_config)
+            response = tokenizer.decode(outputs[0][len(input_ids[0]):],
+                                        skip_special_tokens=True)
+            return response
+
     @staticmethod
     def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
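
A minimal usage sketch of the chat API added by this commit (not part of the diff itself). It assumes the model is loaded with trust_remote_code=True so this modeling file is used; the checkpoint id below is a placeholder.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_path = "deeplang-ai/LingoWhale-8B"  # placeholder: any checkpoint that ships this modeling file
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_path,
                                             torch_dtype=torch.float16,
                                             device_map="auto",
                                             trust_remote_code=True)

messages = [{"role": "user", "content": "Hello, who are you?"}]

# Non-streaming: chat() builds the prompt via build_chat_input() and decodes
# everything generated after the prompt tokens.
print(model.chat(tokenizer, messages))

# Streaming: chat(stream=True) runs generate() in a background thread and
# returns the TextIterStreamer; each iteration yields the decoded text so far.
for text_so_far in model.chat(tokenizer, messages, stream=True):
    print(text_so_far)

Note that build_chat_input packs the most recent chat rounds first into a budget of config.model_max_length - max_new_tokens tokens, so the oldest rounds are dropped when the conversation no longer fits.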