from abc import ABC

from langchain.llms.base import LLM
import random
import torch
import transformers
from transformers.generation.logits_process import LogitsProcessor
from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList
from typing import Optional, List, Dict, Any

from models.loader import LoaderCheckPoint
from models.base import (BaseAnswer,
                         AnswerResult)


class InvalidScoreLogitsProcessor(LogitsProcessor):
    # Replace NaN/Inf logits with a safe distribution so sampling does not fail.
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        if torch.isnan(scores).any() or torch.isinf(scores).any():
            scores.zero_()
            scores[..., 5] = 5e4
        return scores


class LLamaLLM(BaseAnswer, LLM, ABC):
    checkPoint: LoaderCheckPoint = None
    # history = []
    history_len: int = 3
    max_new_tokens: int = 500
    num_beams: int = 1
    temperature: float = 0.5
    top_p: float = 0.4
    top_k: int = 10
    repetition_penalty: float = 1.2
    encoder_repetition_penalty: int = 1
    min_length: int = 0
    logits_processor: LogitsProcessorList = None
    stopping_criteria: Optional[StoppingCriteriaList] = None
    eos_token_id: Optional[List[int]] = [2]

    state: object = {'max_new_tokens': 50,
                     'seed': 1,
                     'temperature': 0, 'top_p': 0.1,
                     'top_k': 40, 'typical_p': 1,
                     'repetition_penalty': 1.2,
                     'encoder_repetition_penalty': 1,
                     'no_repeat_ngram_size': 0,
                     'min_length': 0,
                     'penalty_alpha': 0,
                     'num_beams': 1,
                     'length_penalty': 1,
                     'early_stopping': False, 'add_bos_token': True, 'ban_eos_token': False,
                     'truncation_length': 2048, 'custom_stopping_strings': '',
                     'cpu_memory': 0, 'auto_devices': False, 'disk': False, 'cpu': False, 'bf16': False,
                     'load_in_8bit': False, 'wbits': 'None', 'groupsize': 'None', 'model_type': 'None',
                     'pre_layer': 0, 'gpu_memory_0': 0}

    def __init__(self, checkPoint: LoaderCheckPoint = None):
        super().__init__()
        self.checkPoint = checkPoint

    @property
    def _llm_type(self) -> str:
        return "LLamaLLM"

    @property
    def _check_point(self) -> LoaderCheckPoint:
        return self.checkPoint

    def encode(self, prompt, add_special_tokens=True, add_bos_token=True, truncation_length=None):
        input_ids = self.checkPoint.tokenizer.encode(str(prompt), return_tensors='pt',
                                                     add_special_tokens=add_special_tokens)

        # This is a hack for making replies more creative.
        if not add_bos_token and input_ids[0][0] == self.checkPoint.tokenizer.bos_token_id:
            input_ids = input_ids[:, 1:]

        # Llama adds this extra token when the first character is '\n', and this
        # compromises the stopping criteria, so we just remove it.
        if type(self.checkPoint.tokenizer) is transformers.LlamaTokenizer and input_ids[0][0] == 29871:
            input_ids = input_ids[:, 1:]

        # Handle truncation: keep only the last `truncation_length` tokens.
        if truncation_length is not None:
            input_ids = input_ids[:, -truncation_length:]

        return input_ids.cuda()

    def decode(self, output_ids):
        reply = self.checkPoint.tokenizer.decode(output_ids, skip_special_tokens=True)
        return reply

    # Convert the conversation history list into a plain-text prompt.
    def history_to_text(self, query, history):
        """
        Soft prompt built from the conversation history.
        history_to_text converts the history array into the required text format; the
        formatted history text is then turned into token ids via self.encode, so the
        previous turns and the current query end up concatenated in a single input.
        :return: formatted prompt string
        """
        formatted_history = ''
        history = history[-self.history_len:] if self.history_len > 0 else []
        if len(history) > 0:
            for i, (old_query, response) in enumerate(history):
                formatted_history += "### Human:{}\n### Assistant:{}\n".format(old_query, response)
        formatted_history += "### Human:{}\n### Assistant:".format(query)
        return formatted_history
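
    # Illustrative example (not part of the original code): with history_len >= 1 and
    # history = [["hello", "Hi, how can I help?"]], history_to_text("what is 1+1?", history)
    # returns:
    #   "### Human:hello\n### Assistant:Hi, how can I help?\n### Human:what is 1+1?\n### Assistant:"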

    def prepare_inputs_for_generation(self,
                                      input_ids: torch.LongTensor):
        """
        Pre-build the attention mask and the tensor of position indices for the input sequence.
        # TODO: no clear approach yet; get_masks/get_position_ids are not defined on this class,
        # and _call currently generates without passing explicit masks.
        :return:
        """
        mask_positions = torch.zeros((1, input_ids.shape[1]), dtype=input_ids.dtype).to(self.checkPoint.model.device)

        attention_mask = self.get_masks(input_ids, input_ids.device)

        position_ids = self.get_position_ids(
            input_ids,
            device=input_ids.device,
            mask_positions=mask_positions
        )

        return input_ids, position_ids, attention_mask
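
    # A minimal sketch (an assumption, not part of the original class) of how a plain
    # causal setup could be produced for a decoder-only model, if explicit tensors were
    # ever needed here:
    #
    #   attention_mask = torch.ones_like(input_ids)                        # no padding
    #   position_ids = torch.arange(input_ids.shape[1], dtype=torch.long,
    #                               device=input_ids.device).unsqueeze(0)  # 0..seq_len-1
    #
    # When these are not supplied, the model builds them internally during generation,
    # which is why _call below works without calling prepare_inputs_for_generation.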

    @property
    def _history_len(self) -> int:
        return self.history_len

    def set_history_len(self, history_len: int = 10) -> None:
        self.history_len = history_len

    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        print(f"__call:{prompt}")
        if self.logits_processor is None:
            self.logits_processor = LogitsProcessorList()
        self.logits_processor.append(InvalidScoreLogitsProcessor())

        gen_kwargs = {
            "max_new_tokens": self.max_new_tokens,
            "num_beams": self.num_beams,
            "top_p": self.top_p,
            "do_sample": True,
            "top_k": self.top_k,
            "repetition_penalty": self.repetition_penalty,
            "encoder_repetition_penalty": self.encoder_repetition_penalty,
            "min_length": self.min_length,
            "temperature": self.temperature,
            "eos_token_id": self.checkPoint.tokenizer.eos_token_id,
            "logits_processor": self.logits_processor}

        # Encode the prompt into token ids; note that the prompt is truncated to the
        # last `max_new_tokens` tokens here.
        input_ids = self.encode(prompt, add_bos_token=self.state['add_bos_token'],
                                truncation_length=self.max_new_tokens)
        # input_ids, position_ids, attention_mask = self.prepare_inputs_for_generation(input_ids=filler_input_ids)
        gen_kwargs.update({'inputs': input_ids})
        # Attention mask
        # gen_kwargs.update({'attention_mask': attention_mask})
        # gen_kwargs.update({'position_ids': position_ids})
        if self.stopping_criteria is None:
            self.stopping_criteria = transformers.StoppingCriteriaList()
        # Stopping criteria for observing/ending the generation.
        # NOTE: the `stop` argument is not applied here; the list is empty by default.
        gen_kwargs.update({'stopping_criteria': self.stopping_criteria})

        output_ids = self.checkPoint.model.generate(**gen_kwargs)
        new_tokens = len(output_ids[0]) - len(input_ids[0])
        reply = self.decode(output_ids[0][-new_tokens:])
        print(f"response:{reply}")
        print(f"+++++++++++++++++++++++++++++++++++")
        return reply

    def generatorAnswer(self, prompt: str,
                        history: List[List[str]] = [],
                        streaming: bool = False):
        # TODO: a chat dialogue module and attention-model handling still need to be implemented.
        # Currently _call is the API exposed through langchain's LLM extension and defaults to a
        # no-prompt mode; to work with the attention model, refer to the chat_glm implementation.
        softprompt = self.history_to_text(prompt, history=history)
        response = self._call(prompt=softprompt, stop=['\n###'])

        answer_result = AnswerResult()
        answer_result.history = history + [[prompt, response]]
        answer_result.llm_output = {"answer": response}
        yield answer_result
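

# A minimal usage sketch (assumptions: a LoaderCheckPoint configured elsewhere in this project
# loads a LLaMA checkpoint and exposes .model and .tokenizer; the constructor arguments and
# reload_model() call below are hypothetical and not defined in this file):
#
# if __name__ == "__main__":
#     checkpoint = LoaderCheckPoint()      # hypothetical: set model name/path as your loader expects
#     checkpoint.reload_model()            # hypothetical loader call that populates .model/.tokenizer
#     llm = LLamaLLM(checkPoint=checkpoint)
#     llm.set_history_len(3)
#     for answer_result in llm.generatorAnswer("Hello, who are you?", history=[], streaming=False):
#         print(answer_result.llm_output["answer"])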