import os
import sys

import torch
from transformers import T5Tokenizer, GPT2LMHeadModel, GPT2Tokenizer
from diffusers import StableDiffusionPipeline

from model import IntentPredictModel
class Chat:
    """Multimodal dialogue agent: predicts whether the next turn should be a
    text reply or a shared photo, generates the reply (or image caption), and
    renders images with Stable Diffusion."""

    def __init__(
        self,
        intent_predict_model: IntentPredictModel,
        intent_predict_tokenizer: T5Tokenizer,
        text_dialog_model: GPT2LMHeadModel,
        text_dialog_tokenizer: GPT2Tokenizer,
        text2image_model: StableDiffusionPipeline,
        device: str = "cuda:0",
    ):
        self.intent_predict_model = intent_predict_model.to(device)
        self.intent_predict_tokenizer = intent_predict_tokenizer
        self.text_dialog_model = text_dialog_model.to(device)
        self.text_dialog_tokenizer = text_dialog_tokenizer
        self.text2image_model = text2image_model.to(device)
        self.device = device

        # Style suffixes appended to generated captions before text-to-image
        # generation; the "human" variant is used when the caption mentions a person.
        self.extra_prompt = {
            "human": ", facing the camera, photograph, highly detailed face, depth of field, moody light, style by Yasmin Albatoul, Harry Fayt, centered, extremely detailed, Nikon D850, award winning photography",
            "others": ", depth of field. bokeh. soft light. by Yasmin Albatoul, Harry Fayt. centered. extremely detailed. Nikon D850, (35mm|50mm|85mm). award winning photography.",
        }
        self.human_words = ["man", "men", "woman", "women", "people", "person", "human", "male", "female", "boy", "girl", "child", "kid", "baby", "player"]
        self.negative_prompt = "cartoon, anime, ugly, asian, (aged, white beard, black skin, wrinkle:1.1), (bad proportions, unnatural feature, incongruous feature:1.4), (blurry, un-sharp, fuzzy, un-detailed skin:1.2), (facial contortion, poorly drawn face, deformed iris, deformed pupils:1.3), (mutated hands and fingers:1.5), disconnected hands, disconnected limbs"

        self.save_images_folder = os.path.join(sys.path[0], "generated_images")
        os.makedirs(self.save_images_folder, exist_ok=True)

        # Dialogue history in the two formats the models expect:
        # [SEP]-joined turns for intent prediction, tag-prefixed turns for generation.
        self.context_for_intent = ""
        self.context_for_text_dialog = ""
    def intent_predict(self, context: str) -> bool:
        """Return True if the model predicts that a photo should be shared next."""
        context_encoded = self.intent_predict_tokenizer.encode_plus(
            text=context,
            add_special_tokens=True,
            truncation=True,
            max_length=512,
            return_attention_mask=True,
            return_tensors='pt',
        )
        input_ids = context_encoded['input_ids'].to(self.device)
        attention_mask = context_encoded['attention_mask'].to(self.device)
        pred_logits = self.intent_predict_model(input_ids=input_ids, attention_mask=attention_mask).logits
        pred_label = torch.argmax(pred_logits, dim=1)
        return bool(pred_label.item())
    def generate_response(self, context: str, share_photo: bool, num_beams: int) -> str:
        tokenizer = self.text_dialog_tokenizer
        tag_list = ["[UTT]", "[DST]"]  # text replies start with [UTT]; image descriptions start with [DST]
        tag_id_dic = {tag: tokenizer.convert_tokens_to_ids(tag) for tag in tag_list}
        tag = "[DST]" if share_photo else "[UTT]"
        # Forbid degenerate continuations: back-to-back tags or a tag followed by EOS.
        bad_words = ["[UTT] [UTT]", "[UTT] [DST]", "[UTT] <|endoftext|>", "[DST] [UTT]", "[DST] [DST]", "[DST] <|endoftext|>"]

        input_ids = tokenizer.encode(
            context,
            add_special_tokens=False,
            return_tensors='pt',
        )
        # Note: num_beams must be divisible by num_beam_groups (5) for group beam search.
        generated_ids = self.text_dialog_model.generate(
            input_ids.to(self.device),
            max_new_tokens=64, min_new_tokens=3,
            do_sample=False, num_beams=num_beams, length_penalty=0.7, num_beam_groups=5,
            no_repeat_ngram_size=3,
            bad_words_ids=tokenizer(bad_words, add_prefix_space=True, add_special_tokens=False).input_ids,
            # Force the first generated token to be the tag (generated_ids includes
            # input_ids, so the first new token sits at index input_ids.shape[-1]).
            forced_decoder_ids=[[input_ids.shape[-1], tag_id_dic[tag]]],
            pad_token_id=tokenizer.eos_token_id, eos_token_id=tokenizer.eos_token_id,
        )
        generated_tokens = tokenizer.convert_ids_to_tokens(generated_ids[:, input_ids.shape[-1]:][0], skip_special_tokens=True)

        # Keep everything after the leading tag, up to (but excluding) the next tag.
        end, i = 0, 0
        for i, token in enumerate(generated_tokens):
            if i == 0:  # forced_decoder_ids guarantees the first token is the tag, so skip it
                continue
            if token in tag_list:
                end = i
                break
        if end == 0 and i != 0:  # no follow-up tag was generated
            end = len(generated_tokens)
        response_tokens = generated_tokens[1:end]
        response_str = tokenizer.convert_tokens_to_string(response_tokens).lstrip()
        return response_str
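
    # Illustrative sketch (values are made up, not from the source): after the
    # turns "hi" -> "hello!" -> "show me your dog", the two context encodings
    # built up in respond() below would look like
    #   context_for_intent:      "hi [SEP] hello! [SEP] show me your dog"
    #   context_for_text_dialog: "[UTT] hi[UTT] hello![UTT] show me your dog"
    # intent_predict() consumes the [SEP]-joined form; generate_response()
    # consumes the tag-prefixed form and is forced to open its reply with
    # [DST] (an image caption) or [UTT] (a text reply).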
    def respond(self, message, num_beams, text2image_seed, chat_history, chat_state):
        # Append the new user turn to both context encodings.
        if self.context_for_intent == "":
            self.context_for_intent += message
        else:
            self.context_for_intent += " [SEP] " + message
        self.context_for_text_dialog += "[UTT] " + message

        share_photo = self.intent_predict(self.context_for_intent)
        response = self.generate_response(self.context_for_text_dialog, share_photo, num_beams)

        if share_photo:
            print(f"Image Caption: {response}")
            # Pick the style suffix: "human" if the caption mentions a person, else "others".
            prompt_type = "others"
            for human_word in self.human_words:
                if human_word in response:
                    prompt_type = "human"
                    break
            caption = response + self.extra_prompt[prompt_type]

            generator = torch.Generator(device=self.device).manual_seed(text2image_seed)
            image = self.text2image_model(
                prompt=caption,
                negative_prompt=self.negative_prompt,
                num_inference_steps=20,
                guidance_scale=7.5,
                generator=generator).images[0]
            # The raw caption doubles as the file name.
            save_image_path = os.path.join(self.save_images_folder, f"{response}.png")
            image.save(save_image_path)

            self.context_for_intent += " [SEP] " + response
            self.context_for_text_dialog += "[DST] " + response
            chat_history.append((message, (save_image_path, None)))
        else:
            print(f"Bot: {response}")
            self.context_for_intent += " [SEP] " + response
            self.context_for_text_dialog += "[UTT] " + response
            chat_history.append((message, response))
        return "", chat_history, chat_state