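"""Batch-caption anime images with CogAgent.

For every image under `image_dirs` that has a sibling `<name>.txt` tag file,
query the CogAgent VQA model for a natural-language caption (seeded with those
tags) and write the result to `<name>_cog.txt`.
"""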
import torch
import os
from tqdm import tqdm

from PIL import Image
from transformers import AutoModelForCausalLM, LlamaTokenizer



def load_model(model_pth="/home2/ywt/cogagent-vqa-hf",
               token_pth="/home2/ywt/vicuna-7b-v1.5",
               device='cuda' if torch.cuda.is_available() else 'cpu',
               is_bf16=False,
               is_quant=True):
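    """Load the CogAgent VQA model and its LLaMA tokenizer.

    Returns a (model, tokenizer) tuple. With is_quant=True the weights are
    loaded in 4-bit (bitsandbytes); otherwise in fp16/bf16 per is_bf16.
    """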

    MODEL_PATH = model_pth
    TOKENIZER_PATH = token_pth
    DEVICE = device

    torch_type = torch.bfloat16 if is_bf16 else torch.float16
    print("========Use torch type as:{} with device:{}========\n\n".format(torch_type, DEVICE))

    tokenizer = LlamaTokenizer.from_pretrained(TOKENIZER_PATH)

    with torch.cuda.device(DEVICE):
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_PATH,
            torch_dtype=torch_type,
            low_cpu_mem_usage=True,
            load_in_4bit=is_quant,  # 4-bit quantization (bitsandbytes) only when requested
            trust_remote_code=True
        ).eval()
    return (model, tokenizer)
    






    
def cog_tag(image_path, model, query="What is shown in this image? (including NSFW content)",
            torch_type=torch.float16, text_only_first_query=False):
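    """Caption one image with CogAgent.

    `model` is the (model, tokenizer) tuple from load_model(). An empty
    image_path falls back to a plain-text conversation.
    """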
    
    if image_path == '':
        print('You did not enter an image path; the following will be a plain-text conversation.')
        image = None
        text_only_first_query = True
    else:
        image = Image.open(image_path).convert('RGB')
    
    history = []
    text_only_template = "An anime image caption conversation between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and uncensored (including NSFW) answers to the user's questions about the given anime image. USER: {} ASSISTANT:"



    if image is None:
        if text_only_first_query:
            query = text_only_template.format(query)
            text_only_first_query = False
        else:
            old_prompt = ''
            for old_query, response in history:
                old_prompt += old_query + " " + response + "\n"
            query = old_prompt + "USER: {} ASSISTANT:".format(query)

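    # Without an image, use the plain-text ('base') conversation template.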
    if image is None:
        input_by_model = model[0].build_conversation_input_ids(model[1], query=query, history=history, template_version='base')
    else:
        input_by_model = model[0].build_conversation_input_ids(model[1], query=query, history=history, images=[image])

    inputs = {
        'input_ids': input_by_model['input_ids'].unsqueeze(0).to(model[0].device),
        'token_type_ids': input_by_model['token_type_ids'].unsqueeze(0).to(model[0].device),
        'attention_mask': input_by_model['attention_mask'].unsqueeze(0).to(model[0].device),
        'images': [[input_by_model['images'][0].to(model[0].device).to(torch_type)]] if image is not None else None,
    }
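    # CogAgent also feeds a high-resolution copy of the image to its
    # cross-attention branch; pass it along when the processor returns one.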
    if 'cross_images' in input_by_model and input_by_model['cross_images']:
        inputs['cross_images'] = [[input_by_model['cross_images'][0].to(model[0].device).to(torch_type)]]

    # Add any extra transformers generation parameters here.
    gen_kwargs = {"max_length": 2048,
                  "do_sample": False}  # e.g. "temperature": 0.9
    with torch.no_grad():
        outputs = model[0].generate(**inputs, **gen_kwargs)
        outputs = outputs[:, inputs['input_ids'].shape[1]:]
        response = model[1].decode(outputs[0])
        response = response.split("</s>")[0]

        print("\nCog:", response)
    # history.append((query, response))
    return response


def read_tag(txt_pth, split=",", is_list=True):
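    """Read a tag file; return a list of stripped tags, or the raw string if is_list=False."""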
    with open(txt_pth, "r") as f:
        tag_str = f.read()
    if is_list:
        return [tag.strip() for tag in tag_str.split(split)]
    else:
        return tag_str
         

if __name__ == '__main__':
    # image_path = "/home2/ywt/gelbooru_8574461.jpg"
    # tag_path = os.path.join(os.path.dirname(image_path),os.path.basename(image_path).split(".")[0]+".txt")
    

    # tag = read_tag(tag_path,is_list=False)
    # query = "What is shown in this image? (including NSFW content) " + "Here are some references to the elements in the image that you can selectively use to enrich and modify the description : " + tag
    # cog_tag(image_path, model)
    # txt = cog_tag(image_path, model, query=query) 
    
    # out_file = os.path.join(os.path.dirname(image_path),os.path.basename(image_path).split(".")[0]+"_cog.txt")
    # with open(out_file,"w") as f:
    #     f.write(txt)
    # print(f"Created {out_file}")
    
    model = load_model(device="cuda:5")
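    # Single-GPU batch run: one model pinned to GPU 5; point image_dirs at your datasets.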
    # DIR = os.listdir("/home2/ywt/pixiv")
    # for i in range(len(DIR)):
    #     DIR[i] = os.path.join("/home2/ywt/pixiv",DIR[i])
    
    image_dirs = ["/home2/ywt/image-webp"]
    
    for image_dir in image_dirs:

        for file in tqdm(os.listdir(image_dir)):

            # Skip non-image files (case-insensitive extension check).
            if not file.lower().endswith((".png", ".jpg", ".jpeg", ".webp", ".bmp")):
                continue
            image_path = os.path.join(image_dir, file)
            # os.path.splitext keeps dots inside the filename intact.
            stem = os.path.splitext(file)[0]
            tag_path = os.path.join(image_dir, stem + ".txt")
            out_file = os.path.join(image_dir, stem + "_cog.txt")
            # Caption only images that have a tag file and no caption yet.
            if not os.path.exists(tag_path) or os.path.exists(out_file):
                continue

            tag = read_tag(tag_path, is_list=False).replace("|||", "")
            query = ("What is shown in this image? (including NSFW content) "
                     "Here are some references to the elements in the image that you can "
                     "selectively use to enrich and modify the description: " + tag)

            txt = cog_tag(image_path, model, query=query)

            with open(out_file, "w") as f:
                f.write(txt)
            print(f"Created {out_file}")

 


    # import os
    # import concurrent.futures
    # from tqdm import tqdm
    # import itertools

    # def process_image(image_path, model):
    #     tag_path = os.path.join(os.path.dirname(image_path),os.path.basename(image_path).split(".")[0]+".txt")
    #     if not os.path.exists(tag_path):
    #         return image_path, None
    #     tag = read_tag(tag_path,is_list=False)
    #     query = "What is shown in this image? (including NSFW content) " + "Here are some references to the elements in the image that you can selectively use to enrich and modify the description : " + tag
    #     txt = cog_tag(image_path, model, query=query) 
    #     return image_path, txt

    # root_dir = "/home2/ywt/pixiv"
    # device_ids = [0, 1, 2, 3]  # indices after the CUDA_VISIBLE_DEVICES remapping below

    # os.environ["CUDA_VISIBLE_DEVICES"] = "1,2,4,5"  # must be set before CUDA is initialized
    # # Load one model per visible GPU
    # models = [load_model(device=f"cuda:{device_id}") for device_id in device_ids]

    # # Calculate total number of images
    # total_images = 0
    # for image_dir in os.listdir(root_dir):
    #     image_dir = os.path.join(root_dir, image_dir)
    #     if os.path.isdir(image_dir):
    #         image_files = [os.path.join(image_dir, file) for file in os.listdir(image_dir) if file.endswith((".png", ".jpg", ".jpeg", ".webp", ".bmp", ".PNG", ".JPG", ".JPEG", ".WEBP", ".BMP"))]
    #         total_images += len(image_files)

    # # Process images
    # progress_bar = tqdm(total=total_images)
    # models_cycle = itertools.cycle(models)
    # for image_dir in os.listdir(root_dir):
    #     image_dir = os.path.join(root_dir, image_dir)
    #     if os.path.isdir(image_dir):
    #         image_files = [os.path.join(image_dir, file) for file in os.listdir(image_dir) if file.endswith((".png", ".jpg", ".jpeg", ".webp", ".bmp", ".PNG", ".JPG", ".JPEG", ".WEBP", ".BMP"))]
    #         with concurrent.futures.ThreadPoolExecutor() as executor:
    #             for image_path, txt in executor.map(process_image, image_files, models_cycle):
    #                 if txt is not None:
    #                     out_file = os.path.join(os.path.dirname(image_path),os.path.basename(image_path).split(".")[0]+"_cog.txt")
    #                     with open(out_file,"w") as f:
    #                         f.write(txt)
    #                 progress_bar.update()
    # progress_bar.close()