# Spaces: Runtime error
import os | |
import sys | |
import argparse | |
import gradio as gr | |
import torch | |
from model import IntentPredictModel | |
from transformers import (T5Tokenizer, | |
GPT2Tokenizer, GPT2Config, GPT2LMHeadModel) | |
from diffusers import StableDiffusionPipeline | |
from chatbot import Chat | |
def _load_intent_predictor(args):
    """Load the intent-prediction classifier and its tokenizer.

    Reads ``args.intent_predict_model_name``,
    ``args.intent_predict_model_weights_path`` and ``args.device``.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    print("Loading Intent Prediction Classifier...")
    # Tokenizer: truncate from the left and register '[SEP]' as the separator token.
    tokenizer = T5Tokenizer.from_pretrained(args.intent_predict_model_name, truncation_side='left')
    tokenizer.add_special_tokens({'sep_token': '[SEP]'})
    # Model: two-class classifier initialised from the named checkpoint, then
    # overwritten with the fine-tuned state dict.
    model = IntentPredictModel(pretrained_model_name_or_path=args.intent_predict_model_name, num_classes=2)
    # NOTE: torch.load unpickles the file; only point this at trusted checkpoints.
    model.load_state_dict(torch.load(args.intent_predict_model_weights_path, map_location=args.device))
    print("Intent Prediction Classifier loading completed.")
    return model, tokenizer


def _load_text_dialog_generator(args):
    """Load the textual dialogue response generator and its tokenizer.

    Reads ``args.text_dialog_model_name`` and
    ``args.text_dialog_model_weights_path``.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    print("Loading Textual Dialogue Response Generator...")
    # Tokenizer: truncate from the left and add the two extra tokens.
    tokenizer = GPT2Tokenizer.from_pretrained(args.text_dialog_model_name, truncation_side='left')
    tokenizer.add_tokens(['[UTT]', '[DST]'])
    print(len(tokenizer))
    # Config: make sure the configured vocabulary is at least as large as the
    # tokenizer after the extra tokens were added.
    config = GPT2Config.from_pretrained(args.text_dialog_model_name)
    if len(tokenizer) > config.vocab_size:
        config.vocab_size = len(tokenizer)
    # Load the model weights using the (possibly enlarged) config.
    model = GPT2LMHeadModel.from_pretrained(args.text_dialog_model_weights_path, config=config)
    print("Textual Dialogue Response Generator loading completed.")
    return model, tokenizer


def _load_text2image_translator(args):
    """Load the Stable Diffusion pipeline from ``args.text2image_model_weights_path``."""
    print("Loading Text-to-Image Translator...")
    pipeline = StableDiffusionPipeline.from_pretrained(args.text2image_model_weights_path, torch_dtype=torch.float32)
    print("Text-to-Image Translator loading completed.")
    return pipeline


def _build_demo(chat):
    """Build the Gradio Blocks UI and wire its events to ``chat``.

    Args:
        chat: object providing the ``start_chat``, ``respond`` and
            ``restart_chat`` callbacks used below.

    Returns:
        The constructed ``gr.Blocks`` instance (not yet launched).
    """
    title = """<h1 align="center">Demo of Tiger</h1>"""
    description1 = """<h2>This is the demo of Tiger (Generative Multimodal Dialogue Model).</h2>"""
    description2 = """<h2>Input text start chatting!</h2>"""
    hr = """<hr>"""
    description_input = """<h3>Input: text (English)</h3>"""
    description_output = """<h3>Output: text / image</h3>"""

    with gr.Blocks() as demo:
        gr.Markdown(title)
        gr.Markdown(description1)
        gr.Markdown(description2)
        gr.Markdown(hr)
        gr.Markdown(description_input)
        gr.Markdown(description_output)

        with gr.Row():
            # Gradio expects integer ``scale`` values; 1 (controls) against
            # 3 (chat) keeps roughly the original 0.33 : 1 width ratio.
            with gr.Column(scale=1):
                num_beams = gr.Slider(
                    minimum=1,
                    maximum=10,
                    value=5,
                    step=1,
                    interactive=True,
                    label="beam search numbers",
                )
                text2image_seed = gr.Slider(
                    minimum=1,
                    maximum=100,
                    value=42,
                    step=1,
                    interactive=True,
                    label="seed for text-to-image",
                )
                start = gr.Button("Start Chat", variant="primary")
                # Disabled until a chat has been started.
                clear = gr.Button("Restart Chat (Clear dialogue history)", interactive=False)

            with gr.Column(scale=3):
                chat_state = gr.State()
                chatbot = gr.Chatbot(label='Tiger')
                # Disabled until a chat has been started.
                text_input = gr.Textbox(label='User', placeholder="Please click the <Start Chat> button to start chat!", interactive=False)

        # Event wiring: each callback receives the listed input components and
        # its return values update the listed output components.
        start.click(chat.start_chat, [chat_state], [text_input, start, clear, chat_state])
        text_input.submit(chat.respond, [text_input, num_beams, text2image_seed, chatbot, chat_state], [text_input, chatbot, chat_state])
        clear.click(chat.restart_chat, [chat_state], [chatbot, text_input, start, clear, chat_state], queue=False)

    return demo


def main(args):
    """Load the three Tiger models and serve the Gradio chat demo.

    Args:
        args: parsed command-line namespace providing
            ``intent_predict_model_name``, ``intent_predict_model_weights_path``,
            ``text_dialog_model_name``, ``text_dialog_model_weights_path``,
            ``text2image_model_weights_path`` and ``device``.
    """
    # Intent Prediction
    intent_predict_model, intent_predict_tokenizer = _load_intent_predictor(args)
    # Textual Dialogue Response Generator
    text_dialog_model, text_dialog_tokenizer = _load_text_dialog_generator(args)
    # Text-to-Image Translator
    text2image_model = _load_text2image_translator(args)

    chat = Chat(intent_predict_model, intent_predict_tokenizer,
                text_dialog_model, text_dialog_tokenizer,
                text2image_model,
                args.device)

    demo = _build_demo(chat)
    # NOTE(review): ``enable_queue`` is believed to be deprecated in Gradio 3.x
    # (in favour of ``Blocks.queue()``) and rejected by newer releases; confirm
    # against the pinned Gradio version before changing it.
    demo.launch(share=False, enable_queue=False)
if __name__ == "__main__":
    # Default checkpoint locations are resolved against sys.path[0], which is
    # normally the directory of the script that was launched.
    (
        intent_predict_model_weights_path,
        text_dialog_model_weights_path,
        text2image_model_weights_path,
    ) = (
        os.path.join(sys.path[0], relative_path)
        for relative_path in (
            "model_weights/Tiger_t5_base_encoder.pth",
            "model_weights/Tiger_DialoGPT_medium.pth",
            "model_weights/stable-diffusion-2-1-realistic",
        )
    )

    # Command-line interface: five string options plus the target device.
    parser = argparse.ArgumentParser()
    for flag, default in (
        ('--intent_predict_model_name', "t5-base"),
        ('--intent_predict_model_weights_path', intent_predict_model_weights_path),
        ('--text_dialog_model_name', "microsoft/DialoGPT-medium"),
        ('--text_dialog_model_weights_path', text_dialog_model_weights_path),
        ('--text2image_model_weights_path', text2image_model_weights_path),
    ):
        parser.add_argument(flag, type=str, default=default)
    parser.add_argument('--device', default="cpu")

    args = parser.parse_args()
    main(args)