import torch
import os
from tqdm import tqdm
from PIL import Image
from transformers import AutoModelForCausalLM, LlamaTokenizer
def load_model(model_pth="/home2/ywt/cogagent-vqa-hf",
               token_pth="/home2/ywt/vicuna-7b-v1.5",
               device='cuda' if torch.cuda.is_available() else 'cpu',
               is_bf16=False,
               is_quant=True):
    """Load the CogAgent-VQA model and its Vicuna tokenizer; returns (model, tokenizer)."""
    torch_type = torch.bfloat16 if is_bf16 else torch.float16
    print("========Use torch type as:{} with device:{}========\n\n".format(torch_type, device))
    tokenizer = LlamaTokenizer.from_pretrained(token_pth)
    # load_in_4bit toggles bitsandbytes 4-bit quantization directly; passing
    # `is_quant is not None` here would always be True for a bool argument.
    with torch.cuda.device(device):
        model = AutoModelForCausalLM.from_pretrained(
            model_pth,
            torch_dtype=torch_type,
            low_cpu_mem_usage=True,
            load_in_4bit=is_quant,
            trust_remote_code=True,
        ).eval()
    return (model, tokenizer)
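# Usage sketch (paths are the defaults above; 4-bit loading via transformers'
# load_in_4bit requires the bitsandbytes package to be installed):
#   pair = load_model(device="cuda:0", is_quant=True)   # -> (model, tokenizer)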
def cog_tag(image_path, model, query="What is shown in this image? (including NSFW content)",
            torch_type=torch.float16, text_only_first_query=False):
    """Caption one image; `model` is the (model, tokenizer) pair returned by load_model()."""
    if image_path == '':
        print('No image path was given; the following will be a plain-text conversation.')
        image = None
        text_only_first_query = True
    else:
        image = Image.open(image_path).convert('RGB')
    history = []
    text_only_template = ("An anime image caption conversation between a curious user and an "
                          "artificial intelligence assistant. The assistant gives helpful, detailed, "
                          "including NSFW and uncensored, answers to the user's questions about the "
                          "given anime image. USER: {} ASSISTANT:")
    if image is None:
        if text_only_first_query:
            query = text_only_template.format(query)
            text_only_first_query = False
        else:
            old_prompt = ''
            for old_query, response in history:
                old_prompt += old_query + " " + response + "\n"
            query = old_prompt + "USER: {} ASSISTANT:".format(query)
    if image is None:
        input_by_model = model[0].build_conversation_input_ids(
            model[1], query=query, history=history, template_version='base')
    else:
        input_by_model = model[0].build_conversation_input_ids(
            model[1], query=query, history=history, images=[image])
    inputs = {
        'input_ids': input_by_model['input_ids'].unsqueeze(0).to(model[0].device),
        'token_type_ids': input_by_model['token_type_ids'].unsqueeze(0).to(model[0].device),
        'attention_mask': input_by_model['attention_mask'].unsqueeze(0).to(model[0].device),
        'images': [[input_by_model['images'][0].to(model[0].device).to(torch_type)]] if image is not None else None,
    }
    # CogAgent additionally returns high-resolution cross-attention images.
    if 'cross_images' in input_by_model and input_by_model['cross_images']:
        inputs['cross_images'] = [[input_by_model['cross_images'][0].to(model[0].device).to(torch_type)]]
    # Add any other transformers generation params here.
    gen_kwargs = {"max_length": 2048,
                  "do_sample": False}  # "temperature": 0.9
    with torch.no_grad():
        outputs = model[0].generate(**inputs, **gen_kwargs)
        outputs = outputs[:, inputs['input_ids'].shape[1]:]  # keep only newly generated tokens
        response = model[1].decode(outputs[0])
        response = response.split("</s>")[0]
        print("\nCog:", response)
        # history.append((query, response))
    return response
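# Example (hypothetical path): caption one image with the default VQA query.
#   caption = cog_tag("/home2/ywt/example.jpg", pair)
# Passing an empty image_path falls back to a plain-text conversation instead.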
def read_tag(txt_pth, split=",", is_list=True):
    """Read a tag file; return a list of stripped tags, or the raw string if is_list=False."""
    with open(txt_pth, "r") as f:
        tag_str = f.read()
    if is_list:
        return [tag.strip() for tag in tag_str.split(split)]
    return tag_str
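# Example: a file containing "1girl, solo, smile" yields
#   read_tag(path)                 -> ["1girl", "solo", "smile"]
#   read_tag(path, is_list=False)  -> the raw, unsplit string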
if __name__ == '__main__':
    # Single-image example:
    # image_path = "/home2/ywt/gelbooru_8574461.jpg"
    # tag_path = os.path.splitext(image_path)[0] + ".txt"
    # tag = read_tag(tag_path, is_list=False)
    # query = ("What is shown in this image? (including NSFW content) "
    #          "Here are some references to the elements in the image that you can "
    #          "selectively use to enrich and modify the description: ") + tag
    # txt = cog_tag(image_path, model, query=query)
    # out_file = os.path.splitext(image_path)[0] + "_cog.txt"
    # with open(out_file, "w") as f:
    #     f.write(txt)
    # print(f"Created {out_file}")
    model = load_model(device="cuda:5")
    # DIR = os.listdir("/home2/ywt/pixiv")
    # for i in range(len(DIR)):
    #     DIR[i] = os.path.join("/home2/ywt/pixiv", DIR[i])
    image_dirs = ["/home2/ywt/image-webp"]
    for image_dir in image_dirs:
        for file in tqdm(os.listdir(image_dir)):
            # Skip non-image files (extension check is case-insensitive).
            if not file.lower().endswith((".png", ".jpg", ".jpeg", ".webp", ".bmp")):
                continue
            image_path = os.path.join(image_dir, file)
            # os.path.splitext is safer than split(".")[0] for names containing dots.
            stem = os.path.splitext(image_path)[0]
            tag_path = stem + ".txt"
            out_file = stem + "_cog.txt"
            if not os.path.exists(tag_path):
                continue
            if os.path.exists(out_file):  # skip images that were already captioned
                continue
            tag = read_tag(tag_path, is_list=False).replace("|||", "")
            query = ("What is shown in this image? (including NSFW content) "
                     "Here are some references to the elements in the image that you can "
                     "selectively use to enrich and modify the description: ") + tag
            txt = cog_tag(image_path, model, query=query)
            with open(out_file, "w") as f:
                f.write(txt)
            print(f"Created {out_file}")
# import concurrent.futures
# import itertools
#
# def process_image(image_path, model):
#     tag_path = os.path.splitext(image_path)[0] + ".txt"
#     if not os.path.exists(tag_path):
#         return image_path, None
#     tag = read_tag(tag_path, is_list=False)
#     query = ("What is shown in this image? (including NSFW content) "
#              "Here are some references to the elements in the image that you can "
#              "selectively use to enrich and modify the description: ") + tag
#     txt = cog_tag(image_path, model, query=query)
#     return image_path, txt
#
# root_dir = "/home2/ywt/pixiv"
# device_ids = [1, 2, 4, 5]  # list of GPU device IDs
# os.environ["CUDA_VISIBLE_DEVICES"] = "1,2,4,5"
# # Load one model per GPU.
# models = [load_model(device=f"cuda:{device_id}") for device_id in device_ids]
# # Count images for the progress bar.
# total_images = 0
# for image_dir in os.listdir(root_dir):
#     image_dir = os.path.join(root_dir, image_dir)
#     if os.path.isdir(image_dir):
#         image_files = [os.path.join(image_dir, file) for file in os.listdir(image_dir)
#                        if file.lower().endswith((".png", ".jpg", ".jpeg", ".webp", ".bmp"))]
#         total_images += len(image_files)
# # Process images, cycling each request through the loaded models.
# progress_bar = tqdm(total=total_images)
# models_cycle = itertools.cycle(models)
# for image_dir in os.listdir(root_dir):
#     image_dir = os.path.join(root_dir, image_dir)
#     if os.path.isdir(image_dir):
#         image_files = [os.path.join(image_dir, file) for file in os.listdir(image_dir)
#                        if file.lower().endswith((".png", ".jpg", ".jpeg", ".webp", ".bmp"))]
#         with concurrent.futures.ThreadPoolExecutor() as executor:
#             for image_path, txt in executor.map(process_image, image_files, models_cycle):
#                 if txt is not None:
#                     out_file = os.path.splitext(image_path)[0] + "_cog.txt"
#                     with open(out_file, "w") as f:
#                         f.write(txt)
#                 progress_bar.update()
# progress_bar.close()