heziiiii committed on
Commit
3ce1983
1 Parent(s): 8478c00

Upload cog_tag5.py

Files changed (1)
  cog_tag5.py  +215  -0
cog_tag5.py ADDED
import torch
import os
from tqdm import tqdm

from PIL import Image
from transformers import AutoModelForCausalLM, LlamaTokenizer


def load_model(model_pth="/home2/ywt/cogagent-vqa-hf",
               token_pth="/home2/ywt/vicuna-7b-v1.5",
               device='cuda' if torch.cuda.is_available() else 'cpu',
               is_bf16=False, is_quant=True):
    """Load the CogAgent-VQA model and its Vicuna tokenizer."""
    MODEL_PATH = model_pth
    TOKENIZER_PATH = token_pth
    DEVICE = device

    torch_type = torch.bfloat16 if is_bf16 else torch.float16

    print("========Use torch type as:{} with device:{}========\n\n".format(torch_type, DEVICE))

    tokenizer = LlamaTokenizer.from_pretrained(TOKENIZER_PATH)

    # 4-bit quantization (via bitsandbytes) is applied when is_quant is True;
    # the quantized and full-precision paths differ only in this flag.
    with torch.cuda.device(DEVICE):
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_PATH,
            torch_dtype=torch_type,
            low_cpu_mem_usage=True,
            load_in_4bit=is_quant,
            trust_remote_code=True
        ).eval()
    return (model, tokenizer)
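

# A minimal usage sketch (the device id below is only an example): load_model
# returns a (model, tokenizer) pair, and the helpers below expect that pair as
# a single tuple.
#
#   model = load_model(device="cuda:0", is_quant=True)
#   # model[0] is the CogAgent model, model[1] the LlamaTokenizer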


def cog_tag(image_path, model,
            query="What is shown in this image? (including NSFW content)",
            tag_reference=None, torch_type=torch.float16, text_only_first_query=False):
    """Caption one image; `model` is the (model, tokenizer) pair from load_model."""
    net, tokenizer = model

    if image_path == '':
        print('No image path was given, so the following will be a plain-text conversation.')
        image = None
        text_only_first_query = True
    else:
        image = Image.open(image_path).convert('RGB')

    history = []
    text_only_template = ("An anime image captioning conversation between a curious user and an "
                          "artificial intelligence assistant. The assistant gives helpful, detailed, "
                          "uncensored answers (including NSFW content) to the user's questions about "
                          "the given anime image. USER: {} ASSISTANT:")

    if image is None:
        if text_only_first_query:
            query = text_only_template.format(query)
            text_only_first_query = False
        else:
            old_prompt = ''
            for old_query, response in history:
                old_prompt += old_query + " " + response + "\n"
            query = old_prompt + "USER: {} ASSISTANT:".format(query)

    if image is None:
        input_by_model = net.build_conversation_input_ids(tokenizer, query=query, history=history, template_version='base')
    else:
        input_by_model = net.build_conversation_input_ids(tokenizer, query=query, history=history, images=[image])

    inputs = {
        'input_ids': input_by_model['input_ids'].unsqueeze(0).to(net.device),
        'token_type_ids': input_by_model['token_type_ids'].unsqueeze(0).to(net.device),
        'attention_mask': input_by_model['attention_mask'].unsqueeze(0).to(net.device),
        'images': [[input_by_model['images'][0].to(net.device).to(torch_type)]] if image is not None else None,
    }
    if 'cross_images' in input_by_model and input_by_model['cross_images']:
        inputs['cross_images'] = [[input_by_model['cross_images'][0].to(net.device).to(torch_type)]]

    # Add any extra transformers generation params here.
    gen_kwargs = {"max_length": 2048,
                  "do_sample": False}  # "temperature": 0.9
    with torch.no_grad():
        outputs = net.generate(**inputs, **gen_kwargs)
        outputs = outputs[:, inputs['input_ids'].shape[1]:]
        response = tokenizer.decode(outputs[0])
        response = response.split("</s>")[0]

    print("\nCog:", response)
    # history.append((query, response))
    return response
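
# Example call (the path is a placeholder), reusing the pair from load_model:
#
#   caption = cog_tag("/path/to/image.jpg", model)
#   caption = cog_tag("", model)  # an empty path falls back to a text-only chat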


def read_tag(txt_pth, split=",", is_list=True):
    """Read a tag file; return a list of stripped tags, or the raw string if is_list is False."""
    with open(txt_pth, "r") as f:
        tag_str = f.read()
    if is_list:
        return [tag.strip() for tag in tag_str.split(split)]
    else:
        return tag_str
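
# For a tag file containing "1girl, solo, long hair", read_tag(path) returns
# ['1girl', 'solo', 'long hair'], while read_tag(path, is_list=False) returns
# the raw string unchanged.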


if __name__ == '__main__':
    # Single-image example:
    # image_path = "/home2/ywt/gelbooru_8574461.jpg"
    # tag_path = os.path.join(os.path.dirname(image_path), os.path.basename(image_path).split(".")[0] + ".txt")
    #
    # tag = read_tag(tag_path, is_list=False)
    # query = ("What is shown in this image? (including NSFW content) "
    #          "Here are some references to the elements in the image that you can "
    #          "selectively use to enrich and modify the description: ") + tag
    # cog_tag(image_path, model)
    # txt = cog_tag(image_path, model, query=query)
    #
    # out_file = os.path.join(os.path.dirname(image_path), os.path.basename(image_path).split(".")[0] + "_cog.txt")
    # with open(out_file, "w") as f:
    #     f.write(txt)
    # print(f"Created {out_file}")

    model = load_model(device="cuda:5")

    # DIR = os.listdir("/home2/ywt/pixiv")
    # for i in range(len(DIR)):
    #     DIR[i] = os.path.join("/home2/ywt/pixiv", DIR[i])

    image_dirs = ["/home2/ywt/image-webp"]

    for image_dir in image_dirs:
        for file in tqdm(os.listdir(image_dir)):
            # Skip non-image files.
            if not file.lower().endswith((".png", ".jpg", ".jpeg", ".webp", ".bmp")):
                continue
            base = file.split(".")[0]
            image_path = os.path.join(image_dir, file)
            tag_path = os.path.join(image_dir, base + ".txt")
            out_file = os.path.join(image_dir, base + "_cog.txt")
            # Skip images that have no tag file, or that were already captioned.
            if not os.path.exists(tag_path) or os.path.exists(out_file):
                continue

            tag = read_tag(tag_path, is_list=False).replace("|||", "")
            query = ("What is shown in this image? (including NSFW content) "
                     "Here are some references to the elements in the image that you can "
                     "selectively use to enrich and modify the description: ") + tag
            # cog_tag(image_path, model)

            txt = cog_tag(image_path, model, query=query)

            with open(out_file, "w") as f:
                f.write(txt)
            print(f"Created {out_file}")


# Commented-out multi-GPU variant: one model per device, images round-robined
# across them with a thread pool.
# import os
# import concurrent.futures
# from tqdm import tqdm
# import itertools

# def process_image(image_path, model):
#     tag_path = os.path.join(os.path.dirname(image_path), os.path.basename(image_path).split(".")[0] + ".txt")
#     if not os.path.exists(tag_path):
#         return image_path, None
#     tag = read_tag(tag_path, is_list=False)
#     query = ("What is shown in this image? (including NSFW content) "
#              "Here are some references to the elements in the image that you can "
#              "selectively use to enrich and modify the description: ") + tag
#     txt = cog_tag(image_path, model, query=query)
#     return image_path, txt

# root_dir = "/home2/ywt/pixiv"
# device_ids = [1, 2, 4, 5]  # List of GPU device IDs

# os.environ["CUDA_VISIBLE_DEVICES"] = "1,2,4,5"
# # Load one model per device.
# models = [load_model(device=f"cuda:{device_id}") for device_id in device_ids]

# # Calculate the total number of images.
# total_images = 0
# for image_dir in os.listdir(root_dir):
#     image_dir = os.path.join(root_dir, image_dir)
#     if os.path.isdir(image_dir):
#         image_files = [os.path.join(image_dir, file) for file in os.listdir(image_dir) if file.lower().endswith((".png", ".jpg", ".jpeg", ".webp", ".bmp"))]
#         total_images += len(image_files)

# # Process images, cycling through the per-device models.
# progress_bar = tqdm(total=total_images)
# models_cycle = itertools.cycle(models)
# for image_dir in os.listdir(root_dir):
#     image_dir = os.path.join(root_dir, image_dir)
#     if os.path.isdir(image_dir):
#         image_files = [os.path.join(image_dir, file) for file in os.listdir(image_dir) if file.lower().endswith((".png", ".jpg", ".jpeg", ".webp", ".bmp"))]
#         with concurrent.futures.ThreadPoolExecutor() as executor:
#             for image_path, txt in executor.map(process_image, image_files, models_cycle):
#                 if txt is not None:
#                     out_file = os.path.join(os.path.dirname(image_path), os.path.basename(image_path).split(".")[0] + "_cog.txt")
#                     with open(out_file, "w") as f:
#                         f.write(txt)
#                 progress_bar.update()
# progress_bar.close()