import os

import torch
import gradio as gr

from gen import get_pretrained_models, get_output, setup_model_parallel

# Pin all CUDA work to the first GPU and drop any cached allocations
# left over from a previous run.
torch.cuda.set_device(0)
torch.cuda.empty_cache()

# Single-process "distributed" bootstrap: torch.distributed requires these
# environment variables even when the world size is 1.
os.environ["RANK"] = "0"
os.environ["WORLD_SIZE"] = "1"
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "50505"

local_rank, world_size = setup_model_parallel()
# NOTE(review): "7B" / "tokenizer" look like checkpoint and tokenizer paths
# expected by gen.get_pretrained_models — confirm against that module.
generator = get_pretrained_models("7B", "tokenizer", local_rank, world_size)


def chat(user_input):
    """Generate a reply to ``user_input`` and stream it word by word.

    Yields a one-turn chatbot history ``[(user_input, partial_response)]``
    after each word so the Gradio ``Chatbot`` component renders the reply
    incrementally (a "typing" effect). Each partial response carries a
    trailing space, matching the word-by-word accumulation.
    """
    bot_response = get_output(generator, user_input)[0]
    words = bot_response.split(" ")
    # Yield progressively longer prefixes of the full response.
    for end in range(1, len(words) + 1):
        yield [(user_input, " ".join(words[:end]) + " ")]


with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    textbox = gr.Textbox("Hello, how are you doing today?")
    # Submitting the textbox streams chat()'s yields into the chatbot.
    textbox.submit(chat, textbox, chatbot)

# queue() is required for generator (streaming) event handlers;
# api_open=False disables the auto-generated public API endpoints.
demo.queue(api_open=False).launch()