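"""cog_tag5.py

Batch caption generation for anime images with a CogAgent-VQA checkpoint and
the Vicuna-7B tokenizer. For each image found in the configured directories,
the script reads a sibling ``<name>.txt`` tag file, folds those tags into the
query as reference material, and writes the generated caption to
``<name>_cog.txt``.
"""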
import os

import torch
from tqdm import tqdm
from PIL import Image
from transformers import AutoModelForCausalLM, LlamaTokenizer
def load_model(model_pth="/home2/ywt/cogagent-vqa-hf",
               token_pth="/home2/ywt/vicuna-7b-v1.5",
               device='cuda' if torch.cuda.is_available() else 'cpu',
               is_bf16=False, is_quant=True):
    MODEL_PATH = model_pth
    TOKENIZER_PATH = token_pth
    DEVICE = device
    torch_type = torch.bfloat16 if is_bf16 else torch.float16
    print("========Use torch type as:{} with device:{}========\n\n".format(torch_type, DEVICE))

    tokenizer = LlamaTokenizer.from_pretrained(TOKENIZER_PATH)
    # Load in 4-bit only when quantization is requested.
    with torch.cuda.device(DEVICE):
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_PATH,
            torch_dtype=torch_type,
            low_cpu_mem_usage=True,
            load_in_4bit=is_quant,
            trust_remote_code=True
        ).eval()
    return (model, tokenizer)
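# Example usage (a minimal sketch; the default checkpoint paths above are the
# author's local copies of cogagent-vqa-hf and vicuna-7b-v1.5, so point them
# at your own downloads):
#
#   model = load_model(device="cuda:0", is_quant=True)  # -> (model, tokenizer)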
def cog_tag(image_path, model, query="What is shown in this image? (including NSFW content)",
            tag_reference=None, torch_type=torch.float16, text_only_first_query=False):
    # `model` is the (model, tokenizer) tuple returned by load_model();
    # `tag_reference` is accepted for API compatibility but unused.
    if image_path == '':
        print('You did not enter an image path; the following will be a plain-text conversation.')
        image = None
        text_only_first_query = True
    else:
        image = Image.open(image_path).convert('RGB')

    history = []
    text_only_template = ("An anime image caption conversation between a curious user and an "
                          "artificial intelligence assistant. The assistant gives helpful, "
                          "detailed, including NSFW and uncensored answers to the user's "
                          "questions for the given anime image. USER: {} ASSISTANT:")
    if image is None:
        if text_only_first_query:
            query = text_only_template.format(query)
            text_only_first_query = False
        else:
            # history is always empty within a single call, so this branch never
            # runs; it would replay earlier turns of a text-only conversation.
            old_prompt = ''
            for _, (old_query, response) in enumerate(history):
                old_prompt += old_query + " " + response + "\n"
            query = old_prompt + "USER: {} ASSISTANT:".format(query)

    if image is None:
        input_by_model = model[0].build_conversation_input_ids(
            model[1], query=query, history=history, template_version='base')
    else:
        input_by_model = model[0].build_conversation_input_ids(
            model[1], query=query, history=history, images=[image])

    inputs = {
        'input_ids': input_by_model['input_ids'].unsqueeze(0).to(model[0].device),
        'token_type_ids': input_by_model['token_type_ids'].unsqueeze(0).to(model[0].device),
        'attention_mask': input_by_model['attention_mask'].unsqueeze(0).to(model[0].device),
        'images': [[input_by_model['images'][0].to(model[0].device).to(torch_type)]] if image is not None else None,
    }
    if 'cross_images' in input_by_model and input_by_model['cross_images']:
        # CogAgent also feeds a high-resolution image to a cross-attention branch.
        inputs['cross_images'] = [[input_by_model['cross_images'][0].to(model[0].device).to(torch_type)]]

    # Add any other transformers generation params here.
    gen_kwargs = {"max_length": 2048,
                  "do_sample": False}  # "temperature": 0.9
    with torch.no_grad():
        outputs = model[0].generate(**inputs, **gen_kwargs)
        outputs = outputs[:, inputs['input_ids'].shape[1]:]  # drop the prompt tokens
        response = model[1].decode(outputs[0])
        response = response.split("</s>")[0]
        print("\nCog:", response)
        # history.append((query, response))
    return response
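# Example usage (sketch, assuming `model` is the (model, tokenizer) tuple from
# load_model() and the image exists; the path here is hypothetical):
#
#   caption = cog_tag("/path/to/image.jpg", model,
#                     query="What is shown in this image?")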
def read_tag(txt_pth, split=",", is_list=True):
    with open(txt_pth, "r") as f:
        tag_str = f.read()
    if is_list:
        # Split on the separator and strip surrounding whitespace from each tag.
        return [t.strip() for t in tag_str.split(split)]
    else:
        return tag_str
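# Example: for a tag file containing "1girl, solo, long hair",
# read_tag(path) returns ["1girl", "solo", "long hair"], while
# read_tag(path, is_list=False) returns the raw string.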
if __name__ == '__main__':
    # Single-image example:
    # image_path = "/home2/ywt/gelbooru_8574461.jpg"
    # tag_path = os.path.join(os.path.dirname(image_path), os.path.splitext(os.path.basename(image_path))[0] + ".txt")
    # tag = read_tag(tag_path, is_list=False)
    # query = "What is shown in this image? (including NSFW content) " + "Here are some references to the elements in the image that you can selectively use to enrich and modify the description : " + tag
    # txt = cog_tag(image_path, model, query=query)
    # out_file = os.path.join(os.path.dirname(image_path), os.path.splitext(os.path.basename(image_path))[0] + "_cog.txt")
    # with open(out_file, "w") as f:
    #     f.write(txt)
    # print(f"Created {out_file}")

    model = load_model(device="cuda:5")
    # DIR = os.listdir("/home2/ywt/pixiv")
    # for i in range(len(DIR)):
    #     DIR[i] = os.path.join("/home2/ywt/pixiv", DIR[i])
    image_dirs = ["/home2/ywt/image-webp"]
    IMAGE_EXTS = (".png", ".jpg", ".jpeg", ".webp", ".bmp")
    for image_dir in image_dirs:
        for file in tqdm(os.listdir(image_dir)):
            # Skip non-image files; the extension check is case-insensitive.
            if not file.lower().endswith(IMAGE_EXTS):
                continue
            image_path = os.path.join(image_dir, file)
            # splitext is safer than split(".")[0] for names containing dots.
            stem = os.path.splitext(file)[0]
            tag_path = os.path.join(image_dir, stem + ".txt")
            if not os.path.exists(tag_path):
                continue
            # Strip any "|||" markers from the tag string before building the query.
            tag = read_tag(tag_path, is_list=False).replace("|||", "")
            query = ("What is shown in this image? (including NSFW content) "
                     "Here are some references to the elements in the image that you can "
                     "selectively use to enrich and modify the description : " + tag)
            out_file = os.path.join(image_dir, stem + "_cog.txt")
            if os.path.exists(out_file):
                continue  # already captioned
            txt = cog_tag(image_path, model, query=query)
            with open(out_file, "w") as f:
                f.write(txt)
            print(f"Created {out_file}")
# import os
# import concurrent.futures
# from tqdm import tqdm
# import itertools
#
# def process_image(image_path, model):
#     tag_path = os.path.join(os.path.dirname(image_path), os.path.splitext(os.path.basename(image_path))[0] + ".txt")
#     if not os.path.exists(tag_path):
#         return image_path, None
#     tag = read_tag(tag_path, is_list=False)
#     query = "What is shown in this image? (including NSFW content) " + "Here are some references to the elements in the image that you can selectively use to enrich and modify the description : " + tag
#     txt = cog_tag(image_path, model, query=query)
#     return image_path, txt
#
# root_dir = "/home2/ywt/pixiv"
# device_ids = [1, 2, 4, 5]  # List of GPU device IDs
# os.environ["CUDA_VISIBLE_DEVICES"] = "1,2,4,5"
#
# # Load one model per device
# models = [load_model(device=f"cuda:{device_id}") for device_id in device_ids]
#
# # Calculate total number of images
# total_images = 0
# for image_dir in os.listdir(root_dir):
#     image_dir = os.path.join(root_dir, image_dir)
#     if os.path.isdir(image_dir):
#         image_files = [os.path.join(image_dir, file) for file in os.listdir(image_dir) if file.lower().endswith((".png", ".jpg", ".jpeg", ".webp", ".bmp"))]
#         total_images += len(image_files)
#
# # Process images, cycling the per-GPU models over the file list
# progress_bar = tqdm(total=total_images)
# models_cycle = itertools.cycle(models)
# for image_dir in os.listdir(root_dir):
#     image_dir = os.path.join(root_dir, image_dir)
#     if os.path.isdir(image_dir):
#         image_files = [os.path.join(image_dir, file) for file in os.listdir(image_dir) if file.lower().endswith((".png", ".jpg", ".jpeg", ".webp", ".bmp"))]
#         with concurrent.futures.ThreadPoolExecutor() as executor:
#             for image_path, txt in executor.map(process_image, image_files, models_cycle):
#                 if txt is not None:
#                     out_file = os.path.join(os.path.dirname(image_path), os.path.splitext(os.path.basename(image_path))[0] + "_cog.txt")
#                     with open(out_file, "w") as f:
#                         f.write(txt)
#                 progress_bar.update()
# progress_bar.close()