Spaces:
Runtime error
Runtime error
File size: 4,841 Bytes
b03a999 553d99f 3630c9b 553d99f 3802cb6 553d99f b03a999 553d99f b03a999 553d99f b03a999 553d99f b03a999 553d99f 2ad304c 553d99f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 |
import os
import sys
import argparse
import gradio as gr
import torch
from model import IntentPredictModel
from transformers import (T5Tokenizer,
GPT2Tokenizer, GPT2Config, GPT2LMHeadModel)
from diffusers import StableDiffusionPipeline
from chatbot import Chat
def _load_intent_predictor(args):
    """Load the intent-prediction classifier and its tokenizer.

    Args:
        args: Parsed command-line namespace. Reads
            ``intent_predict_model_name``,
            ``intent_predict_model_weights_path`` and ``device``.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    print("Loading Intent Prediction Classifier...")
    # Truncate over-long inputs from the left: drop tokens at the start and
    # keep the end of the input (presumably the most recent dialogue context).
    tokenizer = T5Tokenizer.from_pretrained(
        args.intent_predict_model_name, truncation_side='left')
    tokenizer.add_special_tokens({'sep_token': '[SEP]'})

    # Two-way classifier; NOTE(review): class meaning is defined elsewhere
    # (presumably "reply with text" vs "reply with image") — confirm.
    model = IntentPredictModel(
        pretrained_model_name_or_path=args.intent_predict_model_name,
        num_classes=2)
    # NOTE(review): torch.load unpickles arbitrary objects; only point this
    # at trusted checkpoints.
    state_dict = torch.load(args.intent_predict_model_weights_path,
                            map_location=args.device)
    model.load_state_dict(state_dict)
    # The demo only runs inference, so switch off training-time behaviour
    # such as dropout (assumes IntentPredictModel is a torch.nn.Module).
    model.eval()
    print("Intent Prediction Classifier loading completed.")
    return model, tokenizer


def _load_text_dialog_generator(args):
    """Load the GPT-2 based textual dialogue response generator.

    Args:
        args: Parsed command-line namespace. Reads ``text_dialog_model_name``
            and ``text_dialog_model_weights_path``.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    print("Loading Textual Dialogue Response Generator...")
    tokenizer = GPT2Tokenizer.from_pretrained(
        args.text_dialog_model_name, truncation_side='left')
    # Extra marker tokens, presumably delimiters in the dialogue input format.
    tokenizer.add_tokens(['[UTT]', '[DST]'])
    print(len(tokenizer))

    config = GPT2Config.from_pretrained(args.text_dialog_model_name)
    # Size the vocabulary to cover the tokens added above so their ids are
    # valid embedding indices (presumably the fine-tuned checkpoint was saved
    # with this enlarged vocabulary).
    if len(tokenizer) > config.vocab_size:
        config.vocab_size = len(tokenizer)
    model = GPT2LMHeadModel.from_pretrained(
        args.text_dialog_model_weights_path, config=config)
    print("Textual Dialogue Response Generator loading completed.")
    return model, tokenizer


def _load_text2image_translator(args):
    """Load the Stable Diffusion text-to-image pipeline.

    Args:
        args: Parsed command-line namespace. Reads
            ``text2image_model_weights_path``.

    Returns:
        The loaded ``StableDiffusionPipeline``.
    """
    print("Loading Text-to-Image Translator...")
    # Full precision, presumably so the pipeline also runs on CPU.
    pipeline = StableDiffusionPipeline.from_pretrained(
        args.text2image_model_weights_path, torch_dtype=torch.float32)
    print("Text-to-Image Translator loading completed.")
    return pipeline


def _build_demo(chat):
    """Build the Gradio Blocks UI wired to ``chat.respond``.

    Args:
        chat: The ``Chat`` object whose ``respond`` method handles each
            submitted user message.

    Returns:
        The constructed (not yet launched) ``gr.Blocks`` demo.
    """
    title = """<h1 align="center">Demo of Tiger</h1>"""
    description1 = """<h2>This is the demo of Tiger (Generative Multimodal Dialogue Model).</h2>"""
    description2 = """<h3>Input text start chatting!</h3>"""
    description_input = """<h3>Input: text</h3>"""
    description_output = """<h3>Output: text / image</h3>"""

    with gr.Blocks() as demo:
        for header in (title, description1, description2,
                       description_input, description_output):
            gr.Markdown(header)

        with gr.Row():
            # Left column: generation controls and the restart button.
            with gr.Column(scale=0.33):
                num_beams = gr.Slider(
                    minimum=1,
                    maximum=10,
                    value=5,
                    step=1,
                    interactive=True,
                    label="beam search numbers",
                )
                text2image_seed = gr.Slider(
                    minimum=1,
                    maximum=100,
                    value=42,
                    step=1,
                    interactive=True,
                    label="seed for text-to-image",
                )
                clear = gr.Button("Restart (Clear dialogue history)")
            # Right column: the conversation itself.
            with gr.Column():
                # Per-session state passed into and returned by chat.respond.
                chat_state = gr.State()
                chatbot = gr.Chatbot(label='Tiger')
                text_input = gr.Textbox(label='User', placeholder='Please input the text.')

        text_input.submit(
            chat.respond,
            [text_input, num_beams, text2image_seed, chatbot, chat_state],
            [text_input, chatbot, chat_state],
        )
        # Restart must reset both the visible transcript and the per-session
        # state; clearing only the Chatbot would leave whatever chat.respond
        # stored in chat_state (presumably the dialogue history) in place.
        # None is the initial value of gr.State() above.
        clear.click(lambda: (None, None), None, [chatbot, chat_state], queue=False)
    return demo


def main(args):
    """Load every Tiger component and serve the Gradio demo.

    Args:
        args: Parsed command-line namespace holding the model names, the
            weight paths and ``device``.
    """
    intent_predict_model, intent_predict_tokenizer = _load_intent_predictor(args)
    text_dialog_model, text_dialog_tokenizer = _load_text_dialog_generator(args)
    text2image_model = _load_text2image_translator(args)

    chat = Chat(intent_predict_model, intent_predict_tokenizer,
                text_dialog_model, text_dialog_tokenizer,
                text2image_model,
                args.device)

    demo = _build_demo(chat)
    # NOTE(review): `enable_queue` is a deprecated launch() argument in
    # Gradio 3.x and no longer exists in 4.x; revisit when upgrading Gradio.
    demo.launch(share=False, enable_queue=False)
if __name__ == "__main__":
    # Resolve default weight locations relative to this file rather than via
    # sys.path[0], which is '' when no script directory is available and
    # shifts if any imported module inserts into sys.path.
    base_dir = os.path.dirname(os.path.realpath(__file__))
    weights_dir = os.path.join(base_dir, "model_weights")
    intent_predict_model_weights_path = os.path.join(weights_dir, "Tiger_t5_base_encoder.pth")
    text_dialog_model_weights_path = os.path.join(weights_dir, "Tiger_DialoGPT_medium.pth")
    text2image_model_weights_path = os.path.join(weights_dir, "stable-diffusion-2-1-realistic")

    parser = argparse.ArgumentParser(
        description="Demo of Tiger (Generative Multimodal Dialogue Model).")
    parser.add_argument('--intent_predict_model_name', type=str, default="t5-base",
                        help="T5 model name/path used for the intent classifier and its tokenizer.")
    parser.add_argument('--intent_predict_model_weights_path', type=str,
                        default=intent_predict_model_weights_path,
                        help="State-dict file loaded into the intent classifier.")
    parser.add_argument('--text_dialog_model_name', type=str, default="microsoft/DialoGPT-medium",
                        help="GPT-2 model name/path used for the dialogue tokenizer and config.")
    parser.add_argument('--text_dialog_model_weights_path', type=str,
                        default=text_dialog_model_weights_path,
                        help="Weights passed to GPT2LMHeadModel.from_pretrained.")
    parser.add_argument('--text2image_model_weights_path', type=str,
                        default=text2image_model_weights_path,
                        help="Stable Diffusion pipeline directory.")
    parser.add_argument('--device', type=str, default="cpu",
                        help="Torch device used for loading and inference, e.g. 'cpu'.")
    args = parser.parse_args()
    main(args)
|