xinrongzhang2022 committed
Commit c692147
1 Parent(s): 7ca6626

Upload 9 files

Files changed (1):
  1. modeling_minicpm.py +55 -58
modeling_minicpm.py CHANGED
@@ -21,7 +21,8 @@
 import math
 import warnings
 from typing import List, Optional, Tuple, Union, Dict
-
+import jsonlines
+import time
 import torch
 import torch.nn.functional as F
 import torch.utils.checkpoint
@@ -1132,14 +1133,20 @@ class MiniCPMForCausalLM(MiniCPMPreTrainedModel):
         #######FOR DUPLEX
         self.input_ids = None
         self.history = []
+        self.history_all = []
         self.logits_processor = LogitsProcessorList()
         self.generate_flag = False
         self.print_len = 0
         self.is_length_limit = False

     def reset_chat_history(self):
+        save_file = "/data/duplex_logs/subject_duplex_%s.jsonl"%(time.strftime("%Y-%m-%d-%H:%M:%S", time.localtime(time.time())))
+        fw = jsonlines.open(save_file, "w")
+        fw.write_all(self.history_all)
+        fw.close()
         self.input_ids = None
         self.history = []
+        self.history_all = []
         self.logits_processor = LogitsProcessorList()
         self.generate_flag = False
         self.print_len = 0
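The log flushed by the new reset_chat_history above is plain JSON Lines, so it can be read back with the same jsonlines package. A minimal readback sketch, assuming only the hard-coded /data/duplex_logs path from the hunk and the record fields visible in this diff (role, content, and a Unix timestamp that user records carry immediately and assistant records gain once their first chunk is streamed):

import glob

import jsonlines

# Read back the duplex session logs flushed by reset_chat_history().
for path in sorted(glob.glob("/data/duplex_logs/subject_duplex_*.jsonl")):
    with jsonlines.open(path) as reader:
        for record in reader:
            # assistant records may lack "timestamp" until their first chunk arrives
            print(record["role"], record.get("timestamp"), record["content"][:60])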
@@ -1321,17 +1328,19 @@ class MiniCPMForCausalLM(MiniCPMPreTrainedModel):
         self.history = []
         for i in range(0, len(history_old), 2):
             if history_old[i]["content"] == "<idle>":
-                if history_old[i+1]["content"].strip(" .\n,") in ["<idle>", "<idle></s>", "</s>", "idle", "idle</s>"]:
+                if i + 1 < len(history_old) and history_old[i+1]["content"].strip(" .\n,") in ["<idle>", "<idle></s>", "</s>", "idle", "idle</s>"]:
                     self.generate_flag = False
                     continue
                 else:
                     self.history.append(history_old[i])
-                    self.history.append(history_old[i+1])
+                    if i + 1 < len(history_old):
+                        self.history.append(history_old[i+1])
                     self.generate_flag = True

             else:
                 self.history.append(history_old[i])
-                self.history.append(history_old[i+1])
+                if i + 1 < len(history_old):
+                    self.history.append(history_old[i+1])
                 self.generate_flag = True


@@ -1341,79 +1350,73 @@ class MiniCPMForCausalLM(MiniCPMPreTrainedModel):
              stopping_criteria=None, **kwargs):

         # torch.manual_seed(0)
+        self.gen_kwargs = {"max_length": max_length, "do_sample": do_sample, "top_p": top_p,
+                           "temperature": temperature, "logits_processor": self.logits_processor, **kwargs}
+        # self.generation_config = self.generation_config.update(**self.gen_kwargs)
+        self.model_kwargs = self.generation_config.update(**self.gen_kwargs)
+        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
+        self.stopping_criteria = self._get_stopping_criteria(
+            generation_config=self.generation_config, stopping_criteria=stopping_criteria
+        )
+        self.prefix_allowed_tokens_fn = None
+        self.tokenizer = tokenizer
+        self.logits_warper = self._get_logits_warper(self.generation_config)
+        self.has_default_max_length = kwargs.get("max_length") is None and self.generation_config.max_length is not None
+
         if self.generate_flag is True and query is not None:
             self.update_history()
-            prompt = copy.deepcopy(self.history)
+
+        # prompt = copy.deepcopy(self.history)
         if self.generate_flag is False and query in ["<idle>"]:
-            return 1
+            return 2
         elif query not in ["<idle>"]:
             self.generate_flag = True
-            prompt.append({"role": "user", "content": query})
+            self.history.append({"role": "user", "content": query})
+            self.history_all.append({"role": "user", "content": query, "timestamp": time.time()})
             history_str = ""
-            for iii in range(0, len(prompt), 2):
-                history_str += "<用户>" + prompt[iii]["content"] + "<AI>"
-                if iii < len(prompt) - 1:
-                    history_str += prompt[iii+1]["content"]
+            for iii in range(0, len(self.history), 2):
+                history_str += "<用户>" + self.history[iii]["content"] + "<AI>"
+                if iii < len(self.history) - 1:
+                    history_str += self.history[iii+1]["content"]
             # history_str = tokenizer.apply_chat_template(prompt, tokenize=False, add_generation_prompt=False)
             self.input_ids = tokenizer(history_str, return_tensors='pt').to(self.device).input_ids
             if self.input_ids.shape[-1] >= max_length:
                 self.is_length_limit = True
+                # self.history = self.history[:-1]
                 return 1
-            # while self.input_ids.shape[-1] >= max_length and len(history_str) > 0:
-            #     history_str = history_str[2:]
-            #     self.input_ids = tokenizer(history_str, return_tensors='pt').to(self.device).input_ids

             self.print_len = 0

-            self.history.append({"role": "user", "content": query})
             self.history.append({"role": "assistant", "content": ""})
+            self.history_all.append({"role": "assistant", "content": ""})
+
         elif self.generate_flag is False and query is not None and query not in ["<idle>"]:
             self.generate_flag = True
-            prompt = copy.deepcopy(self.history)
-            prompt.append({"role": "user", "content": query})
-            # history_str = tokenizer.apply_chat_template(prompt, tokenize=False, add_generation_prompt=False)
+            self.history.append({"role": "user", "content": query})
+            self.history_all.append({"role": "user", "content": query, "timestamp": time.time()})
+
             history_str = ""
-            for iii in range(0, len(prompt), 2):
-                history_str += "<用户>" + prompt[iii]["content"] + "<AI>"
-                if iii < len(prompt) - 1:
-                    history_str += prompt[iii+1]["content"]
+            for iii in range(0, len(self.history), 2):
+                history_str += "<用户>" + self.history[iii]["content"] + "<AI>"
+                if iii < len(self.history) - 1:
+                    history_str += self.history[iii+1]["content"]

             self.input_ids = tokenizer(history_str, return_tensors='pt').to(self.device).input_ids
             if self.input_ids.shape[-1] >= max_length:
                 self.is_length_limit = True
+                # self.history = self.history[:-1]
                 return 1
-            # while self.input_ids.shape[-1] >= max_length and len(history_str) > 0:
-            #     history_str = history_str[2:]
-            #     self.input_ids = tokenizer(history_str, return_tensors='pt').to(self.device).input_ids
-            self.print_len = 0

-            self.history.append({"role": "user", "content": query})
+            self.print_len = 0
             self.history.append({"role": "assistant", "content": ""})
+            self.history_all.append({"role": "assistant", "content": ""})
         else:
-            return 1
+            return 2
         if logits_processor is None:
             self.logits_processor = LogitsProcessorList()

-        # logits_processor.append(InvalidScoreLogitsProcessor())
-        self.gen_kwargs = {"max_length": max_length, "do_sample": do_sample, "top_p": top_p,
-                           "temperature": temperature, "logits_processor": self.logits_processor, **kwargs}
-        # self.generation_config = self.generation_config.update(**self.gen_kwargs)
-        self.model_kwargs = self.generation_config.update(**self.gen_kwargs)
-        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
-        self.stopping_criteria = self._get_stopping_criteria(
-            generation_config=self.generation_config, stopping_criteria=stopping_criteria
-        )
-        self.prefix_allowed_tokens_fn = None
-        self.tokenizer = tokenizer
-        self.logits_warper = self._get_logits_warper(self.generation_config)
-        self.has_default_max_length = kwargs.get("max_length") is None and self.generation_config.max_length is not None
-        # for outputs in self.stream_generate(inputs, **gen_kwargs):
-        #     outputs = outputs.tolist()[0][len(inputs[0]):]
-        #     response = tokenizer.decode(outputs)
-        #     new_history = history + [{"role": "user", "content": query},
-        #                              {"role": "assistant", "content": response}]
-        #     yield response, new_history
         return 0
+
     @torch.inference_mode()
     def stream_generate(
         self,
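The prompt built in this hunk serializes the running history by wrapping each user turn in <用户> … <AI> and appending the assistant turn, when present, directly after its marker. The same loop, lifted out as a standalone sketch; the only assumption is the history format of alternating user/assistant dicts already used in the diff:

def serialize_history(history):
    # Mirrors the history_str construction above: user content goes between
    # <用户> and <AI>; existing assistant content follows its <AI> marker.
    history_str = ""
    for i in range(0, len(history), 2):
        history_str += "<用户>" + history[i]["content"] + "<AI>"
        if i < len(history) - 1:
            history_str += history[i + 1]["content"]
    return history_str


print(serialize_history([
    {"role": "user", "content": "你好"},
    {"role": "assistant", "content": "你好，有什么可以帮你？"},
    {"role": "user", "content": "<idle>"},
]))
# -> <用户>你好<AI>你好，有什么可以帮你？<用户><idle><AI>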
@@ -1507,19 +1510,13 @@ class MiniCPMForCausalLM(MiniCPMPreTrainedModel):
         response = self.tokenizer.batch_decode(self.input_ids, spaces_between_special_tokens=False)[0]
         # print("response: ", response)
         response = response.rsplit("<AI>", 1)[-1]
-
-        # print("response: ", self.input_ids[0][-1], response)
         cut_len = self.print_len
-        # if "<idle>" in response[cut_len:] and len(self.history[-1]["content"]) != 0:
-        #     self.input_ids = self.input_ids[:, :-1]
-        #     return None, self.history
+
         self.print_len = len(response)
+        if self.history_all[-1]["content"] == "":
+            self.history_all[-1]["timestamp"] = time.time()
         self.history[-1]["content"] += response[cut_len:]
-        # if self.history[-1]["content"][-8:] == "</s></s>":
-        #     self.generate_flag = False
-        # if response[cut_len:] in ["<idle>", " <idle>"]:
-        #     self.generate_flag = False
-
+        self.history_all[-1]["content"] += response[cut_len:]
         return response[cut_len:], self.history


@@ -1678,4 +1675,4 @@ class MiniCPMForSequenceClassification(MiniCPMPreTrainedModel):
             past_key_values=transformer_outputs.past_key_values,
             hidden_states=transformer_outputs.hidden_states,
             attentions=transformer_outputs.attentions,
-        )
+        )
 
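Taken together, the preparation path in this commit now reports three outcomes instead of collapsing the idle case into the length-limit case: 0 means input_ids and the generation state are primed and generation should proceed, 1 means the serialized history already reaches max_length and is_length_limit is set, and 2 means an idle tick with no pending reply. A caller polling the duplex model could dispatch on those codes roughly as sketched below; the method name prepare_turn is only a placeholder, since the patched method's name is not visible in this excerpt, and the recovery policy for code 1 is one possible choice, not something the diff prescribes.

def handle_duplex_tick(model, tokenizer, query, max_length=4096):
    # prepare_turn is a stand-in for the patched entry point shown above;
    # the 0/1/2 return codes follow the diff.
    status = model.prepare_turn(query, tokenizer, max_length=max_length)
    if status == 2:
        return False   # idle tick, nothing pending: skip generation
    if status == 1:
        # History no longer fits max_length; one option is to flush the
        # JSONL log and start a fresh session.
        model.reset_chat_history()
        return False
    return True        # status == 0: proceed to stream the assistant reply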