import os
import sys

import torch
from model import IntentPredictModel
from transformers import T5Tokenizer, GPT2LMHeadModel, GPT2Tokenizer
from diffusers import StableDiffusionPipeline


class Chat:
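    """Multimodal chat agent that decides, turn by turn, whether to answer with text or an image.

    A T5-based classifier predicts the photo-sharing intent, a GPT-2 dialogue
    model generates either a text reply (tagged [UTT]) or an image caption
    (tagged [DST]), and a Stable Diffusion pipeline renders captions into images.
    """
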
    def __init__(
        self, 
        intent_predict_model: IntentPredictModel, 
        intent_predict_tokenizer: T5Tokenizer,
        text_dialog_model: GPT2LMHeadModel,
        text_dialog_tokenizer: GPT2Tokenizer,
        text2image_model: StableDiffusionPipeline,
        device="cuda:0"
    ):
        self.intent_predict_model = intent_predict_model.to(device)
        self.intent_predict_tokenizer = intent_predict_tokenizer
        self.text_dialog_model = text_dialog_model.to(device)
        self.text_dialog_tokenizer = text_dialog_tokenizer
        self.text2image_model = text2image_model.to(device)
        self.device = device
        
        self.extra_prompt = {"human": ", facing the camera, photograph, highly detailed face, depth of field, moody light, style by Yasmin Albatoul, Harry Fayt, centered, extremely detailed, Nikon D850, award winning photography",
                             "others": ", depth of field. bokeh. soft light. by Yasmin Albatoul, Harry Fayt. centered. extremely detailed. Nikon D850, (35mm|50mm|85mm). award winning photography."}
        self.human_words = ["man", "men", "woman", "women", "people", "person", "human", "male", "female", "boy", "girl", "child", "kid", "baby", "player"]
        self.negative_prompt = "cartoon, anime, ugly, asian, (aged, white beard, black skin, wrinkle:1.1), (bad proportions, unnatural feature, incongruous feature:1.4), (blurry, un-sharp, fuzzy, un-detailed skin:1.2), (facial contortion, poorly drawn face, deformed iris, deformed pupils:1.3), (mutated hands and fingers:1.5), disconnected hands, disconnected limbs"
        
        self.save_images_folder = os.path.join(sys.path[0], "generated_images")
        os.makedirs(self.save_images_folder, exist_ok=True)
        
        self.context_for_intent = ""
        self.context_for_text_dialog = ""
    
    def intent_predict(self, context: str):
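        """Return True if the model predicts that an image should be shared at this turn.

        The accumulated dialogue context is encoded with the T5 tokenizer
        (truncated to 512 tokens) and fed to the binary intent classifier.
        """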
        context_encoded = self.intent_predict_tokenizer.encode_plus(
            text=context,
            add_special_tokens=True,
            truncation=True,
            max_length=512,
            return_attention_mask=True,
            return_tensors='pt'
        )
        input_ids = context_encoded['input_ids'].to(self.device)
        attention_mask = context_encoded['attention_mask'].to(self.device)
        
        pred_logits = self.intent_predict_model(input_ids=input_ids, attention_mask=attention_mask).logits
        pred_label = torch.max(pred_logits, dim=1)[1]
    
        return bool(pred_label.item())
    
    def generate_response(self, context: str, share_photo: bool, num_beams: int):
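        """Generate the next turn with the GPT-2 dialogue model.

        When share_photo is True the first decoded token is forced to [DST] and
        the output is an image caption; otherwise it is forced to [UTT] and the
        output is a plain text reply. Decoding stops at the next tag token, or
        at the end of the sequence if no further tag appears.
        """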
        tokenizer = self.text_dialog_tokenizer
        tag_list = ["[UTT]", "[DST]"]  # text replies start with [UTT]; image descriptions start with [DST]
        tag_id_dic = {tag: tokenizer.convert_tokens_to_ids(tag) for tag in tag_list}
        tag = "[DST]" if share_photo else "[UTT]"
        bad_words = ["[UTT] [UTT]", "[UTT] [DST]", "[UTT] <|endoftext|>", "[DST] [UTT]", "[DST] [DST]", "[DST] <|endoftext|>"]
        
        input_ids = tokenizer.encode(
            context,
            add_special_tokens=False,
            return_tensors='pt'
        )
        
        generated_ids = self.text_dialog_model.generate(input_ids.to(self.device),
                                                        max_new_tokens=64, min_new_tokens=3,
                                                        do_sample=False, num_beams=num_beams, length_penalty=0.7, num_beam_groups=5, 
                                                        no_repeat_ngram_size=3,
                                                        bad_words_ids=tokenizer(bad_words, add_prefix_space=True, add_special_tokens=False).input_ids,
                                                        forced_decoder_ids=[[input_ids.shape[-1], tag_id_dic[tag]]],  # force the first generated token to be the tag (generated_ids includes input_ids, so the tag sits at position input_ids.shape[-1])
                                                        pad_token_id=tokenizer.eos_token_id, eos_token_id=tokenizer.eos_token_id)
        generated_tokens = tokenizer.convert_ids_to_tokens(generated_ids[:, input_ids.shape[-1]:][0], skip_special_tokens=True)

        end, i = 0, 0
        for i, token in enumerate(generated_tokens):
            if i == 0:  # due to forced_decoder_ids, the first generated token is always the tag, so skip it
                continue
            if token in tag_list:
                end = i
                break
        if end == 0 and i != 0:  # no closing tag was generated; keep everything
            end = len(generated_tokens)
        
        response_tokens = generated_tokens[1:end]
        response_str = tokenizer.convert_tokens_to_string(response_tokens).lstrip()

        return response_str
    
    def respond(self, message, num_beams, text2image_seed, chat_history, chat_state):
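        """Handle one user message end-to-end (a Gradio-style chat callback).

        Updates both dialogue contexts, predicts the sharing intent, and either
        generates a text reply or renders the generated caption with Stable
        Diffusion and saves it to disk. Returns an empty string (to clear the
        input box) plus the updated chat history and state.
        """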
        # process context
        if self.context_for_intent == "":
            self.context_for_intent += message
        else:
            self.context_for_intent += " [SEP] " + message
        self.context_for_text_dialog += "[UTT] " + message
        
        share_photo = self.intent_predict(self.context_for_intent)
        response = self.generate_response(self.context_for_text_dialog, share_photo, num_beams)
        
        if share_photo:
            print(f"Image Caption: {response}")
            type = "others"
            for human_word in self.human_words:
                if human_word in response:
                    type = "human"
                    break
            caption = response + self.extra_prompt[type]
            
            generator = torch.Generator(device=self.device).manual_seed(text2image_seed)
            image = self.text2image_model(
                prompt=caption,
                negative_prompt=self.negative_prompt,
                num_inference_steps=20,
                guidance_scale=7.5,
                generator=generator).images[0]
            
            save_image_path = os.path.join(self.save_images_folder, f"{response}.png")
            image.save(save_image_path)

            self.context_for_intent += " [SEP] " + response
            self.context_for_text_dialog += "[DST] " + response
            
            chat_history.append((message, (save_image_path, None)))

        else:
            print(f"Bot: {response}")
            self.context_for_intent += " [SEP] " + response
            self.context_for_text_dialog += "[UTT] " + response
            
            chat_history.append((message, response))
            
        return "", chat_history, chat_state