"""Gradio chat front-end for the "7B" model loaded through ``gen``.

Running this script loads the model once at import time, builds a small
Blocks UI (markdown header, chatbot, prompt textbox) and launches it.
"""
import itertools
import os

import torch
import gradio as gr

from strings import TITLE, ABSTRACT
from gen import get_pretrained_models, get_output, setup_model_parallel

# Pin a single-rank "distributed" world on the local host.
# NOTE(review): presumably setup_model_parallel() initialises torch.distributed
# from these variables — TODO confirm in gen.py.
os.environ["RANK"] = "0"
os.environ["WORLD_SIZE"] = "1"
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "50505"

local_rank, world_size = setup_model_parallel()
generator = get_pretrained_models("7B", "tokenizer", local_rank, world_size)

# Process-wide log of every exchange as role/content messages.
# NOTE(review): shared by all browser sessions and never trimmed or read in
# this file, so it grows for the lifetime of the server process.
history = []


def chat(user_input, chat_history=None):
    """Generate a reply to *user_input* and stream it into the chatbot.

    Args:
        user_input: Prompt text submitted from the textbox.
        chat_history: Current value of the Chatbot component, a list of
            (user, bot) pairs. These prior turns are included in every
            yielded value so the conversation stays on screen. Defaults to
            no prior turns.

    Yields:
        The full chatbot value: the prior turns followed by the current turn,
        whose bot text grows one space-separated word at a time.
    """
    prior_turns = list(chat_history or [])

    # Don't spend a model generation on an empty prompt.
    if not (user_input or "").strip():
        yield prior_turns
        return

    # The whole reply is generated up front; the loop below only simulates
    # streaming. NOTE(review): get_output presumably returns one string per
    # prompt — TODO confirm in gen.py; the first element is used.
    bot_response = get_output(generator, user_input)[0]

    history.append({"role": "user", "content": user_input})
    history.append({"role": "assistant", "content": bot_response})

    # accumulate() yields successive prefixes re-joined with single spaces, so
    # the last value equals bot_response exactly (no trailing space added).
    words = bot_response.split(" ")
    for partial in itertools.accumulate(words, lambda acc, word: f"{acc} {word}"):
        yield prior_turns + [(user_input, partial)]


with gr.Blocks(
    css="""#col_container {width: 700px; margin-left: auto; margin-right: auto;}
#chatbot {height: 400px; overflow: auto;}"""
) as demo:
    gr.Markdown(f"## {TITLE}\n\n\n\n{ABSTRACT}")
    with gr.Column(elem_id="col_container"):
        chatbot = gr.Chatbot(elem_id="chatbot")
        textbox = gr.Textbox(placeholder="Enter a prompt")
        # The chatbot's current value is passed in too, so chat() appends the
        # new turn to the visible conversation instead of replacing it.
        textbox.submit(chat, [textbox, chatbot], chatbot)

# Generator (streaming) handlers need the queue; api_open=False keeps clients
# from reaching the queued endpoints directly through the REST API.
demo.queue(api_open=False).launch()