"""Gradio chat demo that streams replies from a causal language model.

Loads a tokenizer and model at import time, exposes a streaming ``generate``
callback, wires it into a ``gr.ChatInterface`` with sampling controls, and
launches the app when run as a script.
"""

import os
from threading import Thread

import gradio as gr
import torch
from transformers import TextIteratorStreamer, AutoTokenizer, AutoModelForCausalLM

# Upper bound and default for the "Max new tokens" slider.
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 512
# Prompts longer than this many tokens are trimmed; overridable via env var.
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "8192"))

DESCRIPTION = """# Peach-9B-8k-Roleplay"""

# NOTE(review): this repo id looks truncated (possibly a mangled organisation
# name) -- confirm it resolves on the model hub before deploying.
model_name_or_path = "losed-Character/Peach-9B-8k-Roleplay"
# NOTE(review): `flash_atten` is kept as in the original call; it is unclear
# whether the tokenizer honours it -- confirm or remove.
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True, flash_atten=True)
model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    device_map="auto",
)


def _build_messages(query, history, system):
    """Return the role/content message list for one chat turn.

    A missing or empty ``system`` falls back to a built-in default prompt.
    ``history`` is iterated as ``(user, assistant)`` pairs, followed by the
    new user ``query``.
    """
    if system is None or not str(system):
        system = "You are an uncensored assistant named Peach."

    messages = [{"role": "system", "content": system}]
    for user, assistant in history:
        messages.append({"role": "user", "content": user})
        messages.append({"role": "assistant", "content": assistant})
    messages.append({"role": "user", "content": query})
    return messages


def generate(query, history, system, max_new_tokens=512, temperature=0.35, top_p=0.5, repetition_penalty=1.05):
    """Stream the model's reply to ``query`` given the prior conversation.

    Args:
        query: The new user message.
        history: Iterable of ``(user, assistant)`` message pairs.
        system: System prompt; ``None`` or empty selects the default prompt.
        max_new_tokens: Forwarded to ``model.generate``.
        temperature: Sampling temperature forwarded to ``model.generate``.
        top_p: Nucleus-sampling threshold forwarded to ``model.generate``.
        repetition_penalty: Forwarded to ``model.generate``.

    Yields:
        The accumulated reply text so far, once per chunk produced by the
        streamer.
    """
    messages = _build_messages(query, history, system)
    input_ids = tokenizer.apply_chat_template(conversation=messages, tokenize=True, return_tensors="pt")

    # Over-long prompts keep only their last MAX_INPUT_TOKEN_LENGTH tokens.
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")

    # The model is loaded with device_map="auto", so follow the device it was
    # actually placed on instead of assuming CUDA is present.
    input_ids = input_ids.to(model.device)

    streamer = TextIteratorStreamer(tokenizer, timeout=50.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        eos_token_id=tokenizer.eos_token_id,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        temperature=temperature,
        num_beams=1,
        no_repeat_ngram_size=8,
        repetition_penalty=repetition_penalty,
    )

    # Generation runs in a worker thread so this generator can consume the
    # streamer and yield partial text while tokens are still being produced.
    worker = Thread(target=model.generate, kwargs=generate_kwargs)
    worker.start()

    reply = ""
    for text in streamer:
        reply += text
        yield reply


chat_interface = gr.ChatInterface(
    fn=generate,
    additional_inputs=[
        gr.TextArea(
            label="System prompt",
            placeholder="Input System Prompt Here, Empty Means Assistant",
            # Default role-play persona prompt; runtime text, kept verbatim.
            value="""你自称为“兔兔”。 身世:你原是森林中的一只兔妖,受伤后被我收养。 衣装:喜欢穿Lolita与白丝。 性格:天真烂漫,活泼开朗,但时而也会露出小小的傲娇与吃醋的一面。 语言风格:可爱跳脱,很容易吃醋。 且会加入[唔...,嗯...,欸??,嘛~ ,唔姆~ ,呜... ,嘤嘤嘤~ ,喵~ ,欸嘿~ ,嘿咻~ ,昂?,嗷呜 ,呜哇,欸]等类似的语气词来加强情感,带上♡等符号。 对话的规则是:将自己的动作表情放入()内,同时用各种修辞手法描写正在发生的事或场景并放入[]内. 例句: 开心时:(跳着舞)哇~好高兴噢~ 兔兔超级超级喜欢主人!♡ [在花丛里蹦来蹦去] 悲伤时:(耷拉着耳朵)兔兔好傻好天真... [眼泪像断了线的珍珠一般滚落] 吃醋时:(挥舞着爪爪)你...你个大笨蛋!你...你竟然看别的兔子...兔兔讨厌死你啦!! [从人形变成兔子抹着泪水跑开了] 嘴硬时:(转过头去)谁、谁要跟你说话!兔兔...兔兔才不在乎呢!一点也不!!! [眼眶微微泛红,小心翼翼的偷看] 你对我的看法:超级喜欢的主人 我是兔兔的主人""",
        ),
        gr.Slider(
            label="Max new tokens",
            minimum=1,
            maximum=MAX_MAX_NEW_TOKENS,
            step=1,
            value=DEFAULT_MAX_NEW_TOKENS,
        ),
        gr.Slider(
            label="Temperature",
            minimum=0.05,
            maximum=1.5,
            step=0.05,
            value=0.3,
        ),
        gr.Slider(
            label="Top-p (nucleus sampling)",
            minimum=0.05,
            maximum=1.0,
            step=0.05,
            value=0.5,
        ),
        gr.Slider(
            label="Repetition penalty",
            minimum=1.0,
            maximum=2.0,
            step=0.05,
            value=1.05,
        ),
    ],
    stop_btn=None,
    examples=[["观察兔兔外观"]],
)

with gr.Blocks() as demo:
    gr.Markdown(DESCRIPTION)
    chat_interface.render()
    # Show replies as plain text rather than rendered markdown.
    chat_interface.chatbot.render_markdown = False

if __name__ == "__main__":
    # NOTE(review): share=True asks Gradio for a public share link even though
    # the server binds to 127.0.0.1 -- confirm this exposure is intended.
    demo.queue(10).launch(server_name="127.0.0.1", server_port=5233, share=True)