Spaces:
Runtime error
Runtime error
friedrichor
commited on
Commit
•
6a0ca81
1
Parent(s):
9a40c1f
update
Browse files- app.py +11 -7
- chatbot.py +20 -2
app.py
CHANGED
@@ -49,14 +49,16 @@ def main(args):
|
|
49 |
|
50 |
title = """<h1 align="center">Demo of Tiger</h1>"""
|
51 |
description1 = """<h2>This is the demo of Tiger (Generative Multimodal Dialogue Model).</h2>"""
|
52 |
-
description2 = """<
|
53 |
-
|
54 |
-
|
|
|
55 |
|
56 |
with gr.Blocks() as demo:
|
57 |
gr.Markdown(title)
|
58 |
gr.Markdown(description1)
|
59 |
gr.Markdown(description2)
|
|
|
60 |
gr.Markdown(description_input)
|
61 |
gr.Markdown(description_output)
|
62 |
|
@@ -78,15 +80,17 @@ def main(args):
|
|
78 |
interactive=True,
|
79 |
label="seed for text-to-image",
|
80 |
)
|
81 |
-
|
|
|
82 |
|
83 |
with gr.Column():
|
84 |
chat_state = gr.State()
|
85 |
chatbot = gr.Chatbot(label='Tiger')
|
86 |
-
text_input = gr.Textbox(label='User', placeholder=
|
87 |
-
|
|
|
88 |
text_input.submit(chat.respond, [text_input, num_beams, text2image_seed, chatbot, chat_state], [text_input, chatbot, chat_state])
|
89 |
-
clear.click(
|
90 |
|
91 |
demo.launch(share=False, enable_queue=False)
|
92 |
|
|
|
49 |
|
50 |
title = """<h1 align="center">Demo of Tiger</h1>"""
|
51 |
description1 = """<h2>This is the demo of Tiger (Generative Multimodal Dialogue Model).</h2>"""
|
52 |
+
description2 = """<h2>Input text start chatting!</h2>"""
|
53 |
+
hr = """<hr>"""
|
54 |
+
description_input = """<h3>Input: text (English)</h3>"""
|
55 |
+
description_output = """<h3>Output: text / image</h3>"""
|
56 |
|
57 |
with gr.Blocks() as demo:
|
58 |
gr.Markdown(title)
|
59 |
gr.Markdown(description1)
|
60 |
gr.Markdown(description2)
|
61 |
+
gr.Markdown(hr)
|
62 |
gr.Markdown(description_input)
|
63 |
gr.Markdown(description_output)
|
64 |
|
|
|
80 |
interactive=True,
|
81 |
label="seed for text-to-image",
|
82 |
)
|
83 |
+
start = gr.Button("Start Chat", variant="primary")
|
84 |
+
clear = gr.Button("Restart Chat (Clear dialogue history)", interactive=False)
|
85 |
|
86 |
with gr.Column():
|
87 |
chat_state = gr.State()
|
88 |
chatbot = gr.Chatbot(label='Tiger')
|
89 |
+
text_input = gr.Textbox(label='User', placeholder="Please click the <Start Chat> button to start chat!", interactive=False)
|
90 |
+
|
91 |
+
start.click(chat.start_chat, [chat_state], [text_input, start, clear, chat_state])
|
92 |
text_input.submit(chat.respond, [text_input, num_beams, text2image_seed, chatbot, chat_state], [text_input, chatbot, chat_state])
|
93 |
+
clear.click(chat.restart_chat, [chat_state], [chatbot, text_input, start, clear, chat_state], queue=False)
|
94 |
|
95 |
demo.launch(share=False, enable_queue=False)
|
96 |
|
chatbot.py
CHANGED
@@ -1,5 +1,7 @@
|
|
1 |
import os
|
2 |
import sys
|
|
|
|
|
3 |
|
4 |
import torch
|
5 |
from model import IntentPredictModel
|
@@ -34,8 +36,21 @@ class Chat:
|
|
34 |
|
35 |
self.context_for_intent = ""
|
36 |
self.context_for_text_dialog = ""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
37 |
|
38 |
def intent_predict(self, context: str):
|
|
|
39 |
context_encoded = self.intent_predict_tokenizer.encode_plus(
|
40 |
text=context,
|
41 |
add_special_tokens=True,
|
@@ -67,7 +82,7 @@ class Chat:
|
|
67 |
|
68 |
generated_ids = self.text_dialog_model.generate(input_ids.to(self.device),
|
69 |
max_new_tokens=64, min_new_tokens=3,
|
70 |
-
do_sample=False, num_beams=num_beams, length_penalty=0.7, num_beam_groups=
|
71 |
no_repeat_ngram_size=3,
|
72 |
bad_words_ids=tokenizer(bad_words, add_prefix_space=True, add_special_tokens=False).input_ids,
|
73 |
forced_decoder_ids=[[input_ids.shape[-1], tag_id_dic[tag]]], # 指定生成的回复中第一个token始终是tag(因为generated_ids中包括input_ids, 所以是第input_ids.shape[-1]位)
|
@@ -90,6 +105,9 @@ class Chat:
|
|
90 |
return response_str
|
91 |
|
92 |
def respond(self, message, num_beams, text2image_seed, chat_history, chat_state):
|
|
|
|
|
|
|
93 |
print(f"User: {message}")
|
94 |
# process context
|
95 |
if self.context_for_intent == "":
|
@@ -124,7 +142,7 @@ class Chat:
|
|
124 |
self.context_for_intent += " [SEP] " + response
|
125 |
self.context_for_text_dialog += "[DST] " + response
|
126 |
|
127 |
-
chat_history.append((message, (save_image_path,
|
128 |
|
129 |
else:
|
130 |
print(f"Bot: {response}")
|
|
|
1 |
import os
|
2 |
import sys
|
3 |
+
import gradio as gr
|
4 |
+
from datetime import datetime
|
5 |
|
6 |
import torch
|
7 |
from model import IntentPredictModel
|
|
|
36 |
|
37 |
self.context_for_intent = ""
|
38 |
self.context_for_text_dialog = ""
|
39 |
+
|
40 |
+
def start_chat(self, chat_state):
|
41 |
+
self.context_for_intent = ""
|
42 |
+
self.context_for_text_dialog = ""
|
43 |
+
|
44 |
+
return gr.update(interactive=True, placeholder='input the text (English).'), gr.update(value="Start Chat", interactive=False), gr.update(value="Restart Chat (Clear dialogue history)", interactive=True), chat_state
|
45 |
+
|
46 |
+
def restart_chat(self, chat_state):
|
47 |
+
self.context_for_intent = ""
|
48 |
+
self.context_for_text_dialog = ""
|
49 |
+
|
50 |
+
return None, gr.update(interactive=False, placeholder='Please click the "Start Chat" button.'), gr.update(value="Start Chat", interactive=True), gr.update(value="Restart Chat (Clear dialogue history)", interactive=False), chat_state
|
51 |
|
52 |
def intent_predict(self, context: str):
|
53 |
+
print(f"context = {context}")
|
54 |
context_encoded = self.intent_predict_tokenizer.encode_plus(
|
55 |
text=context,
|
56 |
add_special_tokens=True,
|
|
|
82 |
|
83 |
generated_ids = self.text_dialog_model.generate(input_ids.to(self.device),
|
84 |
max_new_tokens=64, min_new_tokens=3,
|
85 |
+
do_sample=False, num_beams=num_beams, length_penalty=0.7, num_beam_groups=num_beams,
|
86 |
no_repeat_ngram_size=3,
|
87 |
bad_words_ids=tokenizer(bad_words, add_prefix_space=True, add_special_tokens=False).input_ids,
|
88 |
forced_decoder_ids=[[input_ids.shape[-1], tag_id_dic[tag]]], # 指定生成的回复中第一个token始终是tag(因为generated_ids中包括input_ids, 所以是第input_ids.shape[-1]位)
|
|
|
105 |
return response_str
|
106 |
|
107 |
def respond(self, message, num_beams, text2image_seed, chat_history, chat_state):
|
108 |
+
current_time = datetime.now().strftime("%b%d_%H-%M-%S")
|
109 |
+
print("=" * 50)
|
110 |
+
print(f"Time: {current_time}")
|
111 |
print(f"User: {message}")
|
112 |
# process context
|
113 |
if self.context_for_intent == "":
|
|
|
142 |
self.context_for_intent += " [SEP] " + response
|
143 |
self.context_for_text_dialog += "[DST] " + response
|
144 |
|
145 |
+
chat_history.append((message, (save_image_path, f"Generated Image Caption: {response}")))
|
146 |
|
147 |
else:
|
148 |
print(f"Bot: {response}")
|