xinrongzhang2022 committed c692147 ("Upload 9 files")
Parent(s): 7ca6626
Files changed: modeling_minicpm.py (+55 -58)

modeling_minicpm.py  CHANGED
@@ -21,7 +21,8 @@
 import math
 import warnings
 from typing import List, Optional, Tuple, Union, Dict
-
+import jsonlines
+import time
 import torch
 import torch.nn.functional as F
 import torch.utils.checkpoint
@@ -1132,14 +1133,20 @@ class MiniCPMForCausalLM(MiniCPMPreTrainedModel):
         #######FOR DUPLEX
         self.input_ids = None
         self.history = []
+        self.history_all = []
         self.logits_processor = LogitsProcessorList()
         self.generate_flag = False
         self.print_len = 0
         self.is_length_limit = False

     def reset_chat_history(self):
+        save_file = "/data/duplex_logs/subject_duplex_%s.jsonl"%(time.strftime("%Y-%m-%d-%H:%M:%S", time.localtime(time.time())))
+        fw = jsonlines.open(save_file, "w")
+        fw.write_all(self.history_all)
+        fw.close()
         self.input_ids = None
         self.history = []
+        self.history_all = []
         self.logits_processor = LogitsProcessorList()
         self.generate_flag = False
         self.print_len = 0
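The new `reset_chat_history` body flushes the accumulated `history_all` records to a hard-coded `/data/duplex_logs/...` path before clearing state (the directory must already exist, otherwise `jsonlines.open` raises `FileNotFoundError`). As an illustration only, such a log can be read back with the same `jsonlines` package; the file name below is a made-up example, not something produced by this commit:

# Illustration only: read back a duplex log written by reset_chat_history().
# The file name is a made-up example following the timestamped pattern in the commit.
import jsonlines

log_path = "/data/duplex_logs/subject_duplex_2024-01-01-00:00:00.jsonl"

with jsonlines.open(log_path) as reader:
    for turn in reader:
        # Each record is one history_all entry: role, content, and sometimes a timestamp.
        ts = turn.get("timestamp", "n/a")
        print(f'[{ts}] {turn["role"]}: {turn["content"]}')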
@@ -1321,17 +1328,19 @@ class MiniCPMForCausalLM(MiniCPMPreTrainedModel):
         self.history = []
         for i in range(0, len(history_old), 2):
             if history_old[i]["content"] == "<idle>":
-                if history_old[i+1]["content"].strip(" .\n,") in ["<idle>", "<idle></s>", "</s>", "idle", "idle</s>"]:
+                if i + 1 < len(history_old) and history_old[i+1]["content"].strip(" .\n,") in ["<idle>", "<idle></s>", "</s>", "idle", "idle</s>"]:
                     self.generate_flag = False
                     continue
                 else:
                     self.history.append(history_old[i])
-
+                    if i + 1 < len(history_old):
+                        self.history.append(history_old[i+1])
                     self.generate_flag = True

             else:
                 self.history.append(history_old[i])
-
+                if i + 1 < len(history_old):
+                    self.history.append(history_old[i+1])
                 self.generate_flag = True

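The reworked loop keeps a user/assistant pair unless an `<idle>` user turn received an idle reply, and the new `i + 1 < len(history_old)` guards tolerate a history that ends on a user turn. A standalone sketch of that filtering; `filter_idle_turns` is a hypothetical name, not part of the model code:

# Sketch of the idle filtering added above; filter_idle_turns is a hypothetical name.
IDLE_REPLIES = {"<idle>", "<idle></s>", "</s>", "idle", "idle</s>"}


def filter_idle_turns(history_old):
    kept = []
    for i in range(0, len(history_old), 2):
        user = history_old[i]
        reply = history_old[i + 1] if i + 1 < len(history_old) else None
        if user["content"] == "<idle>" and reply is not None \
                and reply["content"].strip(" .\n,") in IDLE_REPLIES:
            continue  # an idle prompt that got an idle answer adds nothing
        kept.append(user)
        if reply is not None:
            kept.append(reply)
    return kept


history_old = [
    {"role": "user", "content": "<idle>"},
    {"role": "assistant", "content": "<idle></s>"},
    {"role": "user", "content": "What is MiniCPM?"},
    {"role": "assistant", "content": "A small language model."},
]
print(filter_idle_turns(history_old))  # only the second pair survives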
@@ -1341,79 +1350,73 @@ class MiniCPMForCausalLM(MiniCPMPreTrainedModel):
              stopping_criteria=None, **kwargs):

         # torch.manual_seed(0)
+        self.gen_kwargs = {"max_length": max_length, "do_sample": do_sample, "top_p": top_p,
+                           "temperature": temperature, "logits_processor": self.logits_processor, **kwargs}
+        # self.generation_config = self.generation_config.update(**self.gen_kwargs)
+        self.model_kwargs = self.generation_config.update(**self.gen_kwargs)
+        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
+        self.stopping_criteria = self._get_stopping_criteria(
+            generation_config=self.generation_config, stopping_criteria=stopping_criteria
+        )
+        self.prefix_allowed_tokens_fn = None
+        self.tokenizer = tokenizer
+        self.logits_warper = self._get_logits_warper(self.generation_config)
+        self.has_default_max_length = kwargs.get("max_length") is None and self.generation_config.max_length is not None
+
         if self.generate_flag is True and query is not None:
             self.update_history()
-
+
+        # prompt = copy.deepcopy(self.history)
         if self.generate_flag is False and query in ["<idle>"]:
-            return
+            return 2
         elif query not in ["<idle>"]:
             self.generate_flag = True
-
+            self.history.append({"role": "user", "content": query})
+            self.history_all.append({"role": "user", "content": query, "timestamp": time.time()})
             history_str = ""
-            for iii in range(0, len(
-                history_str += "<用户>" +
-                if iii < len(
-                    history_str +=
+            for iii in range(0, len(self.history), 2):
+                history_str += "<用户>" + self.history[iii]["content"] + "<AI>"
+                if iii < len(self.history) - 1:
+                    history_str += self.history[iii+1]["content"]
             # history_str = tokenizer.apply_chat_template(prompt, tokenize=False, add_generation_prompt=False)
             self.input_ids = tokenizer(history_str, return_tensors='pt').to(self.device).input_ids
             if self.input_ids.shape[-1] >= max_length:
                 self.is_length_limit = True
+                # self.history = self.history[:-1]
                 return 1
-            # while self.input_ids.shape[-1] >= max_length and len(history_str) > 0:
-            #     history_str = history_str[2:]
-            #     self.input_ids = tokenizer(history_str, return_tensors='pt').to(self.device).input_ids

             self.print_len = 0

-            self.history.append({"role": "user", "content": query})
             self.history.append({"role": "assistant", "content": ""})
+            self.history_all.append({"role": "assistant", "content": ""})
+
         elif self.generate_flag is False and query is not None and query not in ["<idle>"]:
             self.generate_flag = True
-
-
-
+            self.history.append({"role": "user", "content": query})
+            self.history_all.append({"role": "user", "content": query, "timestamp": time.time()})
+
             history_str = ""
-            for iii in range(0, len(
-                history_str += "<用户>" +
-                if iii < len(
-                    history_str +=
+            for iii in range(0, len(self.history), 2):
+                history_str += "<用户>" + self.history[iii]["content"] + "<AI>"
+                if iii < len(self.history) - 1:
+                    history_str += self.history[iii+1]["content"]

             self.input_ids = tokenizer(history_str, return_tensors='pt').to(self.device).input_ids
             if self.input_ids.shape[-1] >= max_length:
                 self.is_length_limit = True
+                # self.history = self.history[:-1]
                 return 1
-            # while self.input_ids.shape[-1] >= max_length and len(history_str) > 0:
-            #     history_str = history_str[2:]
-            #     self.input_ids = tokenizer(history_str, return_tensors='pt').to(self.device).input_ids
-            self.print_len = 0

-            self.
+            self.print_len = 0
             self.history.append({"role": "assistant", "content": ""})
+            self.history_all.append({"role": "assistant", "content": ""})
         else:
-            return
+            return 2
         if logits_processor is None:
             self.logits_processor = LogitsProcessorList()

-        # logits_processor.append(InvalidScoreLogitsProcessor())
-        self.gen_kwargs = {"max_length": max_length, "do_sample": do_sample, "top_p": top_p,
-                           "temperature": temperature, "logits_processor": self.logits_processor, **kwargs}
-        # self.generation_config = self.generation_config.update(**self.gen_kwargs)
-        self.model_kwargs = self.generation_config.update(**self.gen_kwargs)
-        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
-        self.stopping_criteria = self._get_stopping_criteria(
-            generation_config=self.generation_config, stopping_criteria=stopping_criteria
-        )
-        self.prefix_allowed_tokens_fn = None
-        self.tokenizer = tokenizer
-        self.logits_warper = self._get_logits_warper(self.generation_config)
-        self.has_default_max_length = kwargs.get("max_length") is None and self.generation_config.max_length is not None
-        # for outputs in self.stream_generate(inputs, **gen_kwargs):
-        #     outputs = outputs.tolist()[0][len(inputs[0]):]
-        #     response = tokenizer.decode(outputs)
-        #     new_history = history + [{"role": "user", "content": query},
-        #                              {"role": "assistant", "content": response}]
-        #     yield response, new_history
         return 0
+
     @torch.inference_mode()
     def stream_generate(
         self,
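The loop added above serializes the running history into the prompt string that is tokenized, alternating the `<用户>` and `<AI>` markers and leaving the final `<AI>` open for generation. A minimal standalone sketch of that formatting; `format_duplex_prompt` is a hypothetical name used only for illustration, not part of modeling_minicpm.py:

# Sketch of the prompt formatting used by the added loop above.
from typing import Dict, List


def format_duplex_prompt(history: List[Dict[str, str]]) -> str:
    # history alternates user/assistant entries; the last assistant entry may still be "".
    prompt = ""
    for i in range(0, len(history), 2):
        prompt += "<用户>" + history[i]["content"] + "<AI>"
        if i < len(history) - 1:
            prompt += history[i + 1]["content"]
    return prompt


history = [
    {"role": "user", "content": "Hello"},
    {"role": "assistant", "content": "Hi, how can I help?"},
    {"role": "user", "content": "<idle>"},
    {"role": "assistant", "content": ""},
]
print(format_duplex_prompt(history))
# <用户>Hello<AI>Hi, how can I help?<用户><idle><AI>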
@@ -1507,19 +1510,13 @@ class MiniCPMForCausalLM(MiniCPMPreTrainedModel):
         response = self.tokenizer.batch_decode(self.input_ids, spaces_between_special_tokens=False)[0]
         # print("response: ", response)
         response = response.rsplit("<AI>", 1)[-1]
-
-        # print("response: ", self.input_ids[0][-1], response)
         cut_len = self.print_len
-
-        # self.input_ids = self.input_ids[:, :-1]
-        # return None, self.history
+
         self.print_len = len(response)
+        if self.history_all[-1]["content"] == "":
+            self.history_all[-1]["timestamp"] = time.time()
         self.history[-1]["content"] += response[cut_len:]
-
-        # self.generate_flag = False
-        # if response[cut_len:] in ["<idle>", " <idle>"]:
-        #     self.generate_flag = False
-
+        self.history_all[-1]["content"] += response[cut_len:]
         return response[cut_len:], self.history

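The streaming bookkeeping above re-decodes the whole sequence each step and emits only the suffix past `self.print_len`, mirroring the new text into `history_all`. A minimal sketch of that cut-off logic, with made-up strings standing in for the successive decodings:

# Sketch of the print_len / cut_len bookkeeping above: decode the full response each
# step, but emit only the part that has not been printed yet. The data is made up.
def emit_new_text(full_decodings):
    printed_len = 0                      # plays the role of self.print_len
    for decoded in full_decodings:       # what batch_decode(...).rsplit("<AI>", 1)[-1] yields per step
        cut_len = printed_len
        printed_len = len(decoded)
        yield decoded[cut_len:]          # only the newly generated suffix


steps = ["Sure", "Sure, here", "Sure, here is the answer."]
for piece in emit_new_text(steps):
    print(piece, end="", flush=True)     # streams "Sure, here is the answer."
print()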
@@ -1678,4 +1675,4 @@ class MiniCPMForSequenceClassification(MiniCPMPreTrainedModel):
             past_key_values=transformer_outputs.past_key_values,
             hidden_states=transformer_outputs.hidden_states,
             attentions=transformer_outputs.attentions,
-        )
+        )