from typing import List
from queue import Queue
import torch

# build_chat_input for HuatuoGPT2, adapted from Baichuan2's implementation:
# the original prepended tokenized system content and marked turns with
# model.generation_config.user_token_id / assistant_token_id, while this
# variant ignores system messages and uses plain-text role tags instead.
def build_chat_input(model, tokenizer, messages: List[dict], max_new_tokens: int = 0):
    def _parse_messages(messages, split_role="user"):
        # Group the flat message list into rounds: each round starts with a
        # `split_role` (user) message. HuatuoGPT2 does not use a system
        # prompt, so `system` is kept empty only for signature compatibility.
        system, rounds = "", []
        round = []
        for message in messages:
            if message["role"] == split_role and round:
                rounds.append(round)
                round = []
            round.append(message)
        if round:
            rounds.append(round)
        return system, rounds

    max_new_tokens = max_new_tokens or model.generation_config.max_new_tokens
    max_input_tokens = model.config.model_max_length - max_new_tokens
    system, rounds = _parse_messages(messages, split_role="user")
    max_history_tokens = max_input_tokens

    # HuatuoGPT2 role tags: <问> ("question", user) and <答> ("answer",
    # assistant), with a newline separating turns.
    roles = ('<问>:', '<答>:')
    sep = '\n'

    # Walk the rounds from newest to oldest, prepending each one while it
    # fits in the budget; the newest round is always kept, and over-length
    # input is truncated from the left below.
    history_tokens = []
    for round in rounds[::-1]:
        round_tokens = []
        for message in round:
            if message["role"] == "user":
                round_tokens.extend(tokenizer.encode(roles[0] + message["content"] + sep))
            else:
                round_tokens.extend(tokenizer.encode(roles[1] + message["content"] + sep))
        if len(history_tokens) == 0 or len(history_tokens) + len(round_tokens) <= max_history_tokens:
            history_tokens = round_tokens + history_tokens  # concat left
            if len(history_tokens) < max_history_tokens:
                continue
        break

    input_tokens = history_tokens
    # If the conversation does not already end with an assistant message,
    # append the assistant tag so the model continues as the answerer.
    if messages[-1]["role"] != "assistant":
        input_tokens.extend(tokenizer.encode(roles[1]))
    input_tokens = input_tokens[-max_input_tokens:]  # truncate left
    return torch.LongTensor([input_tokens]).to(model.device)
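
# Example usage (a sketch; the checkpoint name and message contents below are
# assumptions for illustration, not taken from this file):
#
#   from transformers import AutoModelForCausalLM, AutoTokenizer
#   model = AutoModelForCausalLM.from_pretrained(
#       "FreedomIntelligence/HuatuoGPT2-7B", trust_remote_code=True)
#   tokenizer = AutoTokenizer.from_pretrained(
#       "FreedomIntelligence/HuatuoGPT2-7B", trust_remote_code=True)
#   messages = [{"role": "user", "content": "What are common symptoms of the flu?"}]
#   input_ids = build_chat_input(model, tokenizer, messages)
#   output_ids = model.generate(input_ids, max_new_tokens=256)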
class TextIterStreamer:
    """Iterator-style streamer compatible with transformers' generate():
    generate() calls put()/end() from the generation thread while a consumer
    iterates over the growing output. Each iteration yields the *cumulative*
    decoded text so far, not just the newly generated piece."""

    def __init__(self, tokenizer, skip_prompt=False, skip_special_tokens=False):
        self.tokenizer = tokenizer
        self.skip_prompt = skip_prompt
        self.skip_special_tokens = skip_special_tokens
        self.tokens = []
        self.text_queue = Queue()
        self.next_tokens_are_prompt = True

    def put(self, value):
        # The first put() carries the prompt ids; skip them if requested.
        if self.skip_prompt and self.next_tokens_are_prompt:
            self.next_tokens_are_prompt = False
        else:
            if len(value.shape) > 1:
                value = value[0]  # drop the batch dimension
            self.tokens.extend(value.tolist())
            # Re-decode the full sequence each time so characters split
            # across multiple BPE tokens render correctly.
            self.text_queue.put(
                self.tokenizer.decode(self.tokens, skip_special_tokens=self.skip_special_tokens))

    def end(self):
        # Sentinel telling the consuming iterator that generation finished.
        self.text_queue.put(None)

    def __iter__(self):
        return self

    def __next__(self):
        value = self.text_queue.get()  # blocks until the next update arrives
        if value is None:
            raise StopIteration()
        return value
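
# Minimal streaming sketch, assuming a HuatuoGPT2-style checkpoint that works
# with the standard transformers generate() API. The checkpoint name and the
# prompt are assumptions for illustration. Since TextIterStreamer yields the
# cumulative decoded text, the loop prints only the newly added suffix.
if __name__ == "__main__":
    from threading import Thread
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_name = "FreedomIntelligence/HuatuoGPT2-7B"  # assumed checkpoint
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)

    messages = [{"role": "user", "content": "What are common symptoms of the flu?"}]
    input_ids = build_chat_input(model, tokenizer, messages)

    streamer = TextIterStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    # generate() blocks until finished, so run it in a background thread and
    # consume the streamer on the main thread.
    thread = Thread(
        target=model.generate,
        kwargs=dict(input_ids=input_ids, streamer=streamer, max_new_tokens=256),
    )
    thread.start()

    printed = ""
    for text in streamer:
        print(text[len(printed):], end="", flush=True)
        printed = text
    thread.join()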