friedrichor commited on
Commit
6a0ca81
1 Parent(s): 9a40c1f
Files changed (2) hide show
  1. app.py +11 -7
  2. chatbot.py +20 -2
app.py CHANGED
@@ -49,14 +49,16 @@ def main(args):
49
 
50
  title = """<h1 align="center">Demo of Tiger</h1>"""
51
  description1 = """<h2>This is the demo of Tiger (Generative Multimodal Dialogue Model).</h2>"""
52
- description2 = """<h3>Input text start chatting!</h3>"""
53
- description_input = """<h3>Input: text</h3>"""
54
- description_output = """<h3>Output: text / image</h3>"""
 
55
 
56
  with gr.Blocks() as demo:
57
  gr.Markdown(title)
58
  gr.Markdown(description1)
59
  gr.Markdown(description2)
 
60
  gr.Markdown(description_input)
61
  gr.Markdown(description_output)
62
 
@@ -78,15 +80,17 @@ def main(args):
78
  interactive=True,
79
  label="seed for text-to-image",
80
  )
81
- clear = gr.Button("Restart (Clear dialogue history)")
 
82
 
83
  with gr.Column():
84
  chat_state = gr.State()
85
  chatbot = gr.Chatbot(label='Tiger')
86
- text_input = gr.Textbox(label='User', placeholder='Please input the text.')
87
-
 
88
  text_input.submit(chat.respond, [text_input, num_beams, text2image_seed, chatbot, chat_state], [text_input, chatbot, chat_state])
89
- clear.click(lambda: None, None, chatbot, queue=False)
90
 
91
  demo.launch(share=False, enable_queue=False)
92
 
 
49
 
50
  title = """<h1 align="center">Demo of Tiger</h1>"""
51
  description1 = """<h2>This is the demo of Tiger (Generative Multimodal Dialogue Model).</h2>"""
52
+ description2 = """<h2>Input text start chatting!</h2>"""
53
+ hr = """<hr>"""
54
+ description_input = """<h3>Input: text (English)</h3>"""
55
+ description_output = """<h3>Output: text / image</h3>"""
56
 
57
  with gr.Blocks() as demo:
58
  gr.Markdown(title)
59
  gr.Markdown(description1)
60
  gr.Markdown(description2)
61
+ gr.Markdown(hr)
62
  gr.Markdown(description_input)
63
  gr.Markdown(description_output)
64
 
 
80
  interactive=True,
81
  label="seed for text-to-image",
82
  )
83
+ start = gr.Button("Start Chat", variant="primary")
84
+ clear = gr.Button("Restart Chat (Clear dialogue history)", interactive=False)
85
 
86
  with gr.Column():
87
  chat_state = gr.State()
88
  chatbot = gr.Chatbot(label='Tiger')
89
+ text_input = gr.Textbox(label='User', placeholder="Please click the <Start Chat> button to start chat!", interactive=False)
90
+
91
+ start.click(chat.start_chat, [chat_state], [text_input, start, clear, chat_state])
92
  text_input.submit(chat.respond, [text_input, num_beams, text2image_seed, chatbot, chat_state], [text_input, chatbot, chat_state])
93
+ clear.click(chat.restart_chat, [chat_state], [chatbot, text_input, start, clear, chat_state], queue=False)
94
 
95
  demo.launch(share=False, enable_queue=False)
96
 
chatbot.py CHANGED
@@ -1,5 +1,7 @@
1
  import os
2
  import sys
 
 
3
 
4
  import torch
5
  from model import IntentPredictModel
@@ -34,8 +36,21 @@ class Chat:
34
 
35
  self.context_for_intent = ""
36
  self.context_for_text_dialog = ""
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
  def intent_predict(self, context: str):
 
39
  context_encoded = self.intent_predict_tokenizer.encode_plus(
40
  text=context,
41
  add_special_tokens=True,
@@ -67,7 +82,7 @@ class Chat:
67
 
68
  generated_ids = self.text_dialog_model.generate(input_ids.to(self.device),
69
  max_new_tokens=64, min_new_tokens=3,
70
- do_sample=False, num_beams=num_beams, length_penalty=0.7, num_beam_groups=5,
71
  no_repeat_ngram_size=3,
72
  bad_words_ids=tokenizer(bad_words, add_prefix_space=True, add_special_tokens=False).input_ids,
73
  forced_decoder_ids=[[input_ids.shape[-1], tag_id_dic[tag]]], # 指定生成的回复中第一个token始终是tag(因为generated_ids中包括input_ids, 所以是第input_ids.shape[-1]位)
@@ -90,6 +105,9 @@ class Chat:
90
  return response_str
91
 
92
  def respond(self, message, num_beams, text2image_seed, chat_history, chat_state):
 
 
 
93
  print(f"User: {message}")
94
  # process context
95
  if self.context_for_intent == "":
@@ -124,7 +142,7 @@ class Chat:
124
  self.context_for_intent += " [SEP] " + response
125
  self.context_for_text_dialog += "[DST] " + response
126
 
127
- chat_history.append((message, (save_image_path, None)))
128
 
129
  else:
130
  print(f"Bot: {response}")
 
1
  import os
2
  import sys
3
+ import gradio as gr
4
+ from datetime import datetime
5
 
6
  import torch
7
  from model import IntentPredictModel
 
36
 
37
  self.context_for_intent = ""
38
  self.context_for_text_dialog = ""
39
+
40
+ def start_chat(self, chat_state):
41
+ self.context_for_intent = ""
42
+ self.context_for_text_dialog = ""
43
+
44
+ return gr.update(interactive=True, placeholder='input the text (English).'), gr.update(value="Start Chat", interactive=False), gr.update(value="Restart Chat (Clear dialogue history)", interactive=True), chat_state
45
+
46
+ def restart_chat(self, chat_state):
47
+ self.context_for_intent = ""
48
+ self.context_for_text_dialog = ""
49
+
50
+ return None, gr.update(interactive=False, placeholder='Please click the "Start Chat" button.'), gr.update(value="Start Chat", interactive=True), gr.update(value="Restart Chat (Clear dialogue history)", interactive=False), chat_state
51
 
52
  def intent_predict(self, context: str):
53
+ print(f"context = {context}")
54
  context_encoded = self.intent_predict_tokenizer.encode_plus(
55
  text=context,
56
  add_special_tokens=True,
 
82
 
83
  generated_ids = self.text_dialog_model.generate(input_ids.to(self.device),
84
  max_new_tokens=64, min_new_tokens=3,
85
+ do_sample=False, num_beams=num_beams, length_penalty=0.7, num_beam_groups=num_beams,
86
  no_repeat_ngram_size=3,
87
  bad_words_ids=tokenizer(bad_words, add_prefix_space=True, add_special_tokens=False).input_ids,
88
  forced_decoder_ids=[[input_ids.shape[-1], tag_id_dic[tag]]], # 指定生成的回复中第一个token始终是tag(因为generated_ids中包括input_ids, 所以是第input_ids.shape[-1]位)
 
105
  return response_str
106
 
107
  def respond(self, message, num_beams, text2image_seed, chat_history, chat_state):
108
+ current_time = datetime.now().strftime("%b%d_%H-%M-%S")
109
+ print("=" * 50)
110
+ print(f"Time: {current_time}")
111
  print(f"User: {message}")
112
  # process context
113
  if self.context_for_intent == "":
 
142
  self.context_for_intent += " [SEP] " + response
143
  self.context_for_text_dialog += "[DST] " + response
144
 
145
+ chat_history.append((message, (save_image_path, f"Generated Image Caption: {response}")))
146
 
147
  else:
148
  print(f"Bot: {response}")