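"""Chat pipeline for the demo Space.

A T5-based classifier predicts whether the bot should share a photo, a GPT-2
dialogue model generates either a text reply (tagged [UTT]) or an image
caption (tagged [DST]), and Stable Diffusion renders captions into images.
"""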
import os
import sys
import torch
from model import IntentPredictModel
from transformers import T5Tokenizer, GPT2LMHeadModel, GPT2Tokenizer
from diffusers import StableDiffusionPipeline

class Chat:
    """Multimodal chat: intent prediction -> text/caption generation -> optional image generation."""

    def __init__(
        self,
        intent_predict_model: IntentPredictModel,
        intent_predict_tokenizer: T5Tokenizer,
        text_dialog_model: GPT2LMHeadModel,
        text_dialog_tokenizer: GPT2Tokenizer,
        text2image_model: StableDiffusionPipeline,
        device: str = "cuda:0"
    ):
        self.intent_predict_model = intent_predict_model.to(device)
        self.intent_predict_tokenizer = intent_predict_tokenizer
        self.text_dialog_model = text_dialog_model.to(device)
        self.text_dialog_tokenizer = text_dialog_tokenizer
        self.text2image_model = text2image_model.to(device)
        self.device = device

        # Style suffixes appended to generated captions before they are sent to Stable Diffusion.
        self.extra_prompt = {
            "human": ", facing the camera, photograph, highly detailed face, depth of field, moody light, style by Yasmin Albatoul, Harry Fayt, centered, extremely detailed, Nikon D850, award winning photography",
            "others": ", depth of field. bokeh. soft light. by Yasmin Albatoul, Harry Fayt. centered. extremely detailed. Nikon D850, (35mm|50mm|85mm). award winning photography."
        }
        # If a caption mentions any of these words, the "human" suffix above is used.
        self.human_words = ["man", "men", "woman", "women", "people", "person", "human", "male", "female", "boy", "girl", "child", "kid", "baby", "player"]
        self.negative_prompt = "cartoon, anime, ugly, asian, (aged, white beard, black skin, wrinkle:1.1), (bad proportions, unnatural feature, incongruous feature:1.4), (blurry, un-sharp, fuzzy, un-detailed skin:1.2), (facial contortion, poorly drawn face, deformed iris, deformed pupils:1.3), (mutated hands and fingers:1.5), disconnected hands, disconnected limbs"

        self.save_images_folder = os.path.join(sys.path[0], "generated_images")
        os.makedirs(self.save_images_folder, exist_ok=True)

        # Running dialogue contexts: one formatted for the intent classifier,
        # one for the GPT-2 dialogue model.
        self.context_for_intent = ""
        self.context_for_text_dialog = ""
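    # Illustrative example (not from the original file): a caption like
    # "a girl flying a kite" contains "girl", so the "human" suffix is appended;
    # any other caption gets the "others" suffix.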
    def intent_predict(self, context: str) -> bool:
        """Predict whether the bot should share a photo at this point in the dialogue."""
        context_encoded = self.intent_predict_tokenizer.encode_plus(
            text=context,
            add_special_tokens=True,
            truncation=True,
            max_length=512,
            return_attention_mask=True,
            return_tensors='pt'
        )
        input_ids = context_encoded['input_ids'].to(self.device)
        attention_mask = context_encoded['attention_mask'].to(self.device)

        pred_logits = self.intent_predict_model(input_ids=input_ids, attention_mask=attention_mask).logits
        pred_label = torch.max(pred_logits, dim=1)[1]
        return bool(pred_label.item())
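    # Illustrative example (not from the original file): after two turns the
    # classifier input looks like
    #   "I was hiking today. [SEP] Nice! Can I see a picture?"
    # and a positive prediction means the next turn should share a photo.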
    def generate_response(self, context: str, share_photo: bool, num_beams: int) -> str:
        """Generate either a text reply ([UTT]) or an image description ([DST]) for the context."""
        tokenizer = self.text_dialog_tokenizer
        tag_list = ["[UTT]", "[DST]"]  # text replies start with [UTT]; image descriptions start with [DST]
        tag_id_dic = {tag: tokenizer.convert_tokens_to_ids(tag) for tag in tag_list}
        tag = "[DST]" if share_photo else "[UTT]"
        # Forbid degenerate continuations such as back-to-back tags or a tag followed by EOS.
        bad_words = ["[UTT] [UTT]", "[UTT] [DST]", "[UTT] <|endoftext|>", "[DST] [UTT]", "[DST] [DST]", "[DST] <|endoftext|>"]

        input_ids = tokenizer.encode(
            context,
            add_special_tokens=False,
            return_tensors='pt'
        )
        generated_ids = self.text_dialog_model.generate(
            input_ids.to(self.device),
            max_new_tokens=64, min_new_tokens=3,
            do_sample=False, num_beams=num_beams, length_penalty=0.7, num_beam_groups=5,
            no_repeat_ngram_size=3,
            bad_words_ids=tokenizer(bad_words, add_prefix_space=True, add_special_tokens=False).input_ids,
            # Force the first generated token to be `tag`. generated_ids starts with
            # input_ids, so the first new token sits at position input_ids.shape[-1].
            forced_decoder_ids=[[input_ids.shape[-1], tag_id_dic[tag]]],
            pad_token_id=tokenizer.eos_token_id, eos_token_id=tokenizer.eos_token_id)
        generated_tokens = tokenizer.convert_ids_to_tokens(generated_ids[:, input_ids.shape[-1]:][0], skip_special_tokens=True)

        # Cut the response off at the next tag, if any.
        end, i = 0, 0
        for i, token in enumerate(generated_tokens):
            if i == 0:  # forced_decoder_ids guarantees the first token is the tag, so skip it
                continue
            if token in tag_list:
                end = i
                break
        if end == 0 and i != 0:  # no further tag was generated before EOS
            end = len(generated_tokens)

        response_tokens = generated_tokens[1:end]  # drop the leading tag
        response_str = tokenizer.convert_tokens_to_string(response_tokens).lstrip()
        return response_str
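    # Illustrative context layout (not from the original file): the GPT-2 context
    # interleaves tagged turns, e.g.
    #   "[UTT] hello [UTT] hi there [DST] a photo of a golden retriever [UTT] cute!"
    # while the intent-classifier context joins the same turns with " [SEP] ".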
    def respond(self, message, num_beams, text2image_seed, chat_history, chat_state):
        """Gradio callback: append the user message, then reply with either text or an image."""
        # Extend both running contexts with the new user message.
        if self.context_for_intent == "":
            self.context_for_intent += message
        else:
            self.context_for_intent += " [SEP] " + message
        self.context_for_text_dialog += "[UTT] " + message

        share_photo = self.intent_predict(self.context_for_intent)
        response = self.generate_response(self.context_for_text_dialog, share_photo, num_beams)

        if share_photo:
            print(f"Image Caption: {response}")

            # Pick the prompt suffix: "human" if the caption mentions a person, "others" otherwise.
            prompt_type = "others"
            for human_word in self.human_words:
                if human_word in response:
                    prompt_type = "human"
                    break
            caption = response + self.extra_prompt[prompt_type]

            generator = torch.Generator(device=self.device).manual_seed(text2image_seed)
            image = self.text2image_model(
                prompt=caption,
                negative_prompt=self.negative_prompt,
                num_inference_steps=20,
                guidance_scale=7.5,
                generator=generator).images[0]
            save_image_path = f"{self.save_images_folder}/{response}.png"  # the caption doubles as the file name
            image.save(save_image_path)

            self.context_for_intent += " [SEP] " + response
            self.context_for_text_dialog += "[DST] " + response
            chat_history.append((message, (save_image_path, None)))
        else:
            print(f"Bot: {response}")
            self.context_for_intent += " [SEP] " + response
            self.context_for_text_dialog += "[UTT] " + response
            chat_history.append((message, response))

        return "", chat_history, chat_state