# Tiger/chatbot.py
import os
import sys
import torch
from model import IntentPredictModel
from transformers import T5Tokenizer, GPT2LMHeadModel, GPT2Tokenizer
from diffusers import StableDiffusionPipeline
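
# Pipeline: a T5-based intent classifier decides whether the bot should share
# an image; a GPT-2 dialogue model generates either a text reply (tagged
# "[UTT]") or an image caption (tagged "[DST]"); captions are rendered to
# images with Stable Diffusion.
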
class Chat:
def __init__(
self,
intent_predict_model: IntentPredictModel,
intent_predict_tokenizer: T5Tokenizer,
text_dialog_model: GPT2LMHeadModel,
text_dialog_tokenizer: GPT2Tokenizer,
text2image_model: StableDiffusionPipeline,
device="cuda:0"
):
self.intent_predict_model = intent_predict_model.to(device)
self.intent_predict_tokenizer = intent_predict_tokenizer
self.text_dialog_model = text_dialog_model.to(device)
self.text_dialog_tokenizer = text_dialog_tokenizer
self.text2image_model = text2image_model.to(device)
self.device = device
self.extra_prompt = {"human": ", facing the camera, photograph, highly detailed face, depth of field, moody light, style by Yasmin Albatoul, Harry Fayt, centered, extremely detailed, Nikon D850, award winning photography",
"others": ", depth of field. bokeh. soft light. by Yasmin Albatoul, Harry Fayt. centered. extremely detailed. Nikon D850, (35mm|50mm|85mm). award winning photography."}
self.human_words = ["man", "men", "woman", "women", "people", "person", "human", "male", "female", "boy", "girl", "child", "kid", "baby", "player"]
self.negative_prompt="cartoon, anime, ugly, asian, (aged, white beard, black skin, wrinkle:1.1), (bad proportions, unnatural feature, incongruous feature:1.4), (blurry, un-sharp, fuzzy, un-detailed skin:1.2), (facial contortion, poorly drawn face, deformed iris, deformed pupils:1.3), (mutated hands and fingers:1.5), disconnected hands, disconnected limbs"
self.save_images_folder = os.path.join(sys.path[0], "generated_images")
os.makedirs(self.save_images_folder, exist_ok=True)
self.context_for_intent = ""
self.context_for_text_dialog = ""

    def intent_predict(self, context: str):
        # Run the T5 classifier over the dialogue context and return True
        # when the predicted intent is "share an image".
context_encoded = self.intent_predict_tokenizer.encode_plus(
text=context,
add_special_tokens=True,
truncation=True,
max_length=512,
return_attention_mask=True,
return_tensors='pt'
)
input_ids = context_encoded['input_ids'].to(self.device)
attention_mask = context_encoded['attention_mask'].to(self.device)
pred_logits = self.intent_predict_model(input_ids=input_ids, attention_mask=attention_mask).logits
        pred_label = torch.argmax(pred_logits, dim=1)
        return bool(pred_label.item())

    def generate_response(self, context: str, share_photo: bool, num_beams: int):
        # Generate either a text reply ("[UTT]") or an image caption ("[DST]"),
        # depending on the predicted intent.
tokenizer = self.text_dialog_tokenizer
        tag_list = ["[UTT]", "[DST]"]  # text replies start with [UTT]; image captions start with [DST]
tag_id_dic = {tag: tokenizer.convert_tokens_to_ids(tag) for tag in tag_list}
tag = "[DST]" if share_photo else "[UTT]"
bad_words = ["[UTT] [UTT]", "[UTT] [DST]", "[UTT] <|endoftext|>", "[DST] [UTT]", "[DST] [DST]", "[DST] <|endoftext|>"]
input_ids = tokenizer.encode(
context,
add_special_tokens=False,
return_tensors='pt'
)
generated_ids = self.text_dialog_model.generate(input_ids.to(self.device),
max_new_tokens=64, min_new_tokens=3,
do_sample=False, num_beams=num_beams, length_penalty=0.7, num_beam_groups=5,
no_repeat_ngram_size=3,
bad_words_ids=tokenizer(bad_words, add_prefix_space=True, add_special_tokens=False).input_ids,
                                                        # Force the first new token to be the tag: generated_ids
                                                        # includes input_ids, so the tag sits at index input_ids.shape[-1].
                                                        forced_decoder_ids=[[input_ids.shape[-1], tag_id_dic[tag]]],
pad_token_id=tokenizer.eos_token_id, eos_token_id=tokenizer.eos_token_id)
generated_tokens = tokenizer.convert_ids_to_tokens(generated_ids[:, input_ids.shape[-1]:][0], skip_special_tokens=True)
end, i = 0, 0
for i, token in enumerate(generated_tokens):
            if i == 0:  # forced_decoder_ids guarantees the first token is the tag, so skip it
                continue
if token in tag_list:
end = i
break
        if end == 0 and i != 0:  # no follow-up tag was generated; keep everything after the leading tag
            end = len(generated_tokens)
response_tokens = generated_tokens[1:end]
response_str = tokenizer.convert_tokens_to_string(response_tokens).lstrip()
return response_str

    def respond(self, message, num_beams, text2image_seed, chat_history, chat_state):
        # One chat turn: update both contexts, predict intent, generate a reply
        # or caption, and (when sharing a photo) render and save an image.
# process context
if self.context_for_intent == "":
self.context_for_intent += message
else:
self.context_for_intent += " [SEP] " + message
self.context_for_text_dialog += "[UTT] " + message
share_photo = self.intent_predict(self.context_for_intent)
response = self.generate_response(self.context_for_text_dialog, share_photo, num_beams)
if share_photo:
print(f"Image Caption: {response}")
            prompt_type = "others"
            for human_word in self.human_words:
                if human_word in response:
                    prompt_type = "human"
                    break
            caption = response + self.extra_prompt[prompt_type]
generator = torch.Generator(device=self.device).manual_seed(text2image_seed)
image = self.text2image_model(
prompt=caption,
negative_prompt=self.negative_prompt,
num_inference_steps=20,
guidance_scale=7.5,
generator=generator).images[0]
            save_image_path = os.path.join(self.save_images_folder, f"{response}.png")
image.save(save_image_path)
self.context_for_intent += " [SEP] " + response
self.context_for_text_dialog += "[DST] " + response
chat_history.append((message, (save_image_path, None)))
else:
print(f"Bot: {response}")
self.context_for_intent += " [SEP] " + response
self.context_for_text_dialog += "[UTT] " + response
chat_history.append((message, response))
return "", chat_history, chat_state
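

if __name__ == "__main__":
    # Minimal usage sketch. The checkpoint paths below are placeholders, and
    # IntentPredictModel is assumed to expose a from_pretrained-style loader;
    # adjust both to match how the models are actually trained and saved.
    intent_model = IntentPredictModel.from_pretrained("path/to/intent_checkpoint")    # placeholder path
    intent_tokenizer = T5Tokenizer.from_pretrained("t5-base")                         # placeholder checkpoint
    dialog_model = GPT2LMHeadModel.from_pretrained("path/to/dialog_checkpoint")       # placeholder path
    dialog_tokenizer = GPT2Tokenizer.from_pretrained("path/to/dialog_checkpoint")     # placeholder path
    t2i_pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

    chat = Chat(intent_model, intent_tokenizer, dialog_model, dialog_tokenizer, t2i_pipe)
    _, history, state = chat.respond(
        message="Hi! What have you been up to this weekend?",
        num_beams=5,  # must be divisible by the num_beam_groups=5 used in generate_response
        text2image_seed=42,
        chat_history=[],
        chat_state=None,
    )
    print(history)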