Spaces:
Runtime error
Runtime error
import os | |
import gc | |
import torch | |
import torch.nn as nn | |
import argparse | |
import gradio as gr | |
from transformers import AutoTokenizer, LlamaForCausalLM | |
from utils import SteamGenerationMixin | |
# Hugging Face access token, supplied through the Space's secret settings.
auth_token = os.getenv("AUTH_TOKEN")


def secret_status(name):
    """Return a log-safe description of environment variable *name*.

    Only reports whether the variable is set to a non-empty value; the value
    itself is never returned, so it cannot leak into Space/container logs.

    Args:
        name: environment variable to inspect.

    Returns:
        ``'set'`` if the variable exists and is non-empty, else ``'not set'``.
    """
    return 'set' if os.getenv(name) else 'not set'


# Startup diagnostics: confirm the secrets are configured without printing them.
print('^_^ auth_token:', secret_status("AUTH_TOKEN"))
print('^_^ secret_token:', secret_status("SECRET_TOKEN"))
class MindBot(object):
    """Chat wrapper around a causal LM loaded via ``SteamGenerationMixin``.

    Two entry points are provided:

    * ``common_generate`` -- blocking generation that keeps its own
      conversation in ``self.history``.
    * ``interaction`` -- a streaming generator used as the Gradio event
      handler; it works on the history list the Chatbot component passes in.
    """

    # Sampling settings shared by the blocking and streaming paths so the two
    # cannot drift apart. The special-token ids are hard-coded as in the
    # original code; presumably they match the tokenizer's pad/bos/eos ids
    # (0/1/2) -- TODO confirm against the tokenizer.
    GENERATION_KWARGS = dict(
        max_new_tokens=1024,
        do_sample=True,
        top_p=0.85,
        temperature=0.8,
        repetition_penalty=1.,
        eos_token_id=2,
        bos_token_id=1,
        pad_token_id=0,
    )

    def __init__(self, model_path, tokenizer_path=None, if_int8=False, auth_token=None):
        """Load the model and tokenizer.

        Args:
            model_path: checkpoint directory or hub id passed to
                ``from_pretrained``.
            tokenizer_path: tokenizer location; defaults to ``model_path`` so
                ``MindBot(model_path)`` works.
            if_int8: load the weights in 8-bit; otherwise they are cast to
                fp16 with ``.half()``.
            auth_token: Hugging Face access token; defaults to the
                ``AUTH_TOKEN`` environment variable. Tokens must never be
                hard-coded in source (any previously committed token should
                be revoked).
        """
        if tokenizer_path is None:
            tokenizer_path = model_path
        if auth_token is None:
            auth_token = os.getenv("AUTH_TOKEN")
        if if_int8:
            self.model = SteamGenerationMixin.from_pretrained(
                model_path, device_map='auto', load_in_8bit=True,
                use_auth_token=auth_token).eval()
        else:
            self.model = SteamGenerationMixin.from_pretrained(
                model_path, device_map='auto',
                use_auth_token=auth_token).half().eval()
        # The tokenizer usually lives in the same (possibly private) repo as
        # the model, so it gets the same token.
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, use_auth_token=auth_token)
        # Conversation used by common_generate only; interaction() relies on
        # the history supplied by Gradio instead.
        self.history = []

    def build_prompt(self, instruction, history, human='<human>', bot='<bot>'):
        """Render past turns plus a new instruction as a plain-text prompt.

        Each past turn becomes ``"{human}: {question}\\n{bot}: {answer}\\n"``
        (question stripped, answer verbatim). The prompt ends with the new
        stripped instruction followed by an open ``"{bot}: \\n"`` line for
        the model to continue.

        Args:
            instruction: the new user message.
            history: sequence of ``(question, answer)`` pairs.
            human: speaker tag for user lines.
            bot: speaker tag for model lines.

        Returns:
            The prompt string.
        """
        parts = [f'{human}: {turn[0].strip()}\n{bot}: {turn[1]}\n' for turn in history]
        parts.append(f'{human}: {instruction.strip()}\n{bot}: \n')
        return ''.join(parts)

    def _encode(self, prompt, max_memory):
        """Tokenize *prompt*, keeping at most the last ``max_memory`` tokens."""
        input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
        if input_ids.shape[1] > max_memory:
            # Left-truncate so the newest turn and the trailing bot tag survive.
            input_ids = input_ids[:, -max_memory:]
        return input_ids

    def common_generate(self, instruction, clear_history=False, max_memory=1024):
        """Generate a complete reply (non-streaming) and record it in ``self.history``.

        Args:
            instruction: the user message.
            clear_history: reset ``self.history`` before building the prompt.
            max_memory: maximum number of prompt tokens fed to the model.

        Returns:
            The decoded reply, with "Belle" replaced by "IDEA".
        """
        if clear_history:
            self.history = []
        prompt = self.build_prompt(instruction, self.history)
        input_ids = self._encode(prompt, max_memory)
        prompt_len = input_ids.shape[1]
        generation_output = self.model.generate(input_ids.cuda(), **self.GENERATION_KWARGS)
        # Drop the echoed prompt tokens; decode only the new ones.
        output = self.tokenizer.decode(generation_output[0][prompt_len:], skip_special_tokens=True)
        output = output.replace("Belle", "IDEA")
        self.history.append((instruction, output))
        print('api history: ======> \n', self.history)
        return output

    def interaction(self, instruction, history, max_memory=1024):
        """Stream a reply for Gradio, yielding ``('', chat_history)`` per step.

        The first element clears the input textbox; the second is the chat
        history including the partial reply. On success the final reply is
        appended to *history* in place.

        Args:
            instruction: the user message from the textbox.
            history: list of ``(question, answer)`` pairs from the Chatbot.
            max_memory: maximum number of prompt tokens fed to the model.

        Yields:
            ``('', list_of_pairs)`` tuples for the ``[msg, chatbot]`` outputs.
        """
        prompt = self.build_prompt(instruction, history)
        input_ids = self._encode(prompt, max_memory)
        prompt_len = input_ids.shape[1]
        output = ''
        try:
            with torch.no_grad():
                for generation_output in self.model.stream_generate(
                        input_ids.cuda(), **self.GENERATION_KWARGS):
                    s = generation_output[0][prompt_len:]
                    output = self.tokenizer.decode(s, skip_special_tokens=True)
                    # NOTE(review): the <br>-converted text is what ends up in
                    # history and is fed back into build_prompt on later turns;
                    # confirm this is intended.
                    output = output.replace('\n', '<br>')
                    # Yield a fresh list each step so Gradio never holds a
                    # list that is mutated after being handed over.
                    yield '', history + [(instruction, output)]
        except torch.cuda.OutOfMemoryError:
            gc.collect()
            torch.cuda.empty_cache()
            # Only call a model-level cache reset if the model provides one;
            # an AttributeError here would mask the original OOM.
            empty_model_cache = getattr(self.model, 'empty_cache', None)
            if callable(empty_model_cache):
                empty_model_cache()
            # A generator's return value never reaches Gradio, so yield the
            # unchanged history to refresh the UI instead.
            yield '', history
            return
        history.append((instruction, output))
        print('input -----> \n', prompt)
        print('output -------> \n', output)
        print('history: ======> \n', history)

    def new_chat_bot(self):
        """Build the Gradio Blocks UI wired to ``interaction``.

        Returns:
            The Blocks app with queuing enabled (5 concurrent workers),
            ready for ``launch()``.
        """
        with gr.Blocks(title='IDEA MindBot', css=".gradio-container {max-width: 50% !important;} .bgcolor {color: white !important; background: #FFA500 !important;}") as demo:
            gr.Markdown("<center><h1>IDEA MindBot</h1></center>")
            # Subtitle: "This page is built on hardware provided by Hugging Face."
            gr.Markdown("<center>本页面基于hugging face支持的设备搭建</center>")
            with gr.Row():
                chatbot = gr.Chatbot(label='MindBot').style(height=500)
            with gr.Row():
                msg = gr.Textbox(label="Input")
            with gr.Row():
                with gr.Column(scale=0.5):
                    clear = gr.Button("Clear")
                with gr.Column(scale=0.5):
                    submit = gr.Button("Submit", elem_classes='bgcolor')
            # Pressing Enter and clicking Submit both stream through interaction().
            msg.submit(self.interaction, [msg, chatbot], [msg, chatbot])
            clear.click(lambda: None, None, chatbot, queue=False)
            submit.click(self.interaction, [msg, chatbot], [msg, chatbot])
        return demo.queue(concurrency_count=5)
if __name__ == '__main__':
    # Command-line entry point: build the bot and serve the Gradio UI.
    parser = argparse.ArgumentParser(description='Launch the IDEA MindBot Gradio demo.')
    parser.add_argument(
        "--model_path",
        type=str,
        default="/cognitive_comp/songchao/checkpoints/global_step3200-hf",
        help="checkpoint directory or Hugging Face hub id",
    )
    parser.add_argument(
        "--tokenizer_path",
        type=str,
        default=None,
        help="tokenizer location (defaults to --model_path)",
    )
    parser.add_argument(
        "--int8",
        action="store_true",
        help="load the model in 8-bit instead of fp16",
    )
    args = parser.parse_args()
    # Pass the tokenizer path explicitly so this call does not depend on
    # MindBot.__init__ providing a default for it.
    mind_bot = MindBot(args.model_path, args.tokenizer_path or args.model_path, if_int8=args.int8)
    demo = mind_bot.new_chat_bot()
    # Building the Blocks app does not serve it; launch() starts the web server.
    demo.launch()