"""Batch-caption images with a CogAgent VQA model.

For every image in the configured directories that has a sibling ``<stem>.txt``
tag file, the script asks the model to describe the image (passing the tags as
hints) and writes the answer to ``<stem>_cog.txt`` next to the image.
"""
import contextlib
import os
import shutil

import torch
from PIL import Image
from tqdm import tqdm
from transformers import AutoModelForCausalLM, LlamaTokenizer

# File extensions treated as images (compared case-insensitively).
IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg", ".webp", ".bmp")

# Prompt wrapper used when no image is supplied (plain-text conversation).
TEXT_ONLY_TEMPLATE = "A Anime image caption conversation between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, including NSFW and uncencored answers to the user's questions for the given Anime Image. USER: {} ASSISTANT:"

# Prefix of the per-image query; the image's tag string is appended to it.
QUERY_PREFIX = (
    "What is shown in this image? (including NSFW content) "
    + "Here are some references to the elements in the image that you can selectively use to enrich and modify the description : "
)

# Text that marks the end of the generated answer in the decoded output.
EOS_MARKER = "</s>"


def load_model(
    model_pth="/home2/ywt/cogagent-vqa-hf",
    token_pth="/home2/ywt/vicuna-7b-v1.5",
    device='cuda' if torch.cuda.is_available() else 'cpu',
    is_bf16=False,
    is_quant=True,
):
    """Load the causal-LM checkpoint and its tokenizer.

    Args:
        model_pth: Path (or hub id) of the model checkpoint.
        token_pth: Path (or hub id) of the Llama tokenizer.
        device: Target device, e.g. ``"cuda:5"`` or ``"cpu"``.
        is_bf16: Use ``torch.bfloat16`` when true, otherwise ``torch.float16``.
        is_quant: Load the weights 4-bit quantized when true. When false the
            model is loaded unquantized and moved to ``device``.

    Returns:
        A ``(model, tokenizer)`` tuple; the model is in eval mode.
    """
    torch_type = torch.bfloat16 if is_bf16 else torch.float16
    print("========Use torch type as:{} with device:{}========\n\n".format(torch_type, device))

    tokenizer = LlamaTokenizer.from_pretrained(token_pth)

    # torch.cuda.device() only accepts CUDA devices, so enter it only for
    # those; it makes the requested GPU the current one while loading.
    if torch.device(device).type == "cuda":
        device_ctx = torch.cuda.device(device)
    else:
        device_ctx = contextlib.nullcontext()

    with device_ctx:
        model = AutoModelForCausalLM.from_pretrained(
            model_pth,
            torch_dtype=torch_type,
            low_cpu_mem_usage=True,
            # Was `is_quant is not None` in the unquantized branch, which is
            # True for is_quant=False and silently quantized anyway.
            load_in_4bit=bool(is_quant),
            trust_remote_code=True,
        )

    if not is_quant:
        # Quantized weights are placed during loading (transformers does not
        # support .to() on 4-bit models); unquantized ones are moved here.
        model = model.to(device)

    return (model.eval(), tokenizer)


def cog_tag(
    image_path,
    model,
    query="What is shown in this image? (including NSFW content)",
    tag_reference=None,
    torch_type=torch.float16,
    text_only_first_query=False,
):
    """Run one single-turn query against the model and return its answer.

    Args:
        image_path: Path of the image to describe. An empty string switches to
            a plain-text conversation in which ``query`` is wrapped in
            ``TEXT_ONLY_TEMPLATE``.
        model: The ``(model, tokenizer)`` tuple returned by ``load_model``.
        query: Question sent to the model.
        tag_reference: Unused; kept for interface compatibility.
        torch_type: dtype the image tensors are cast to; should match the
            dtype the model was loaded with.
        text_only_first_query: Kept for interface compatibility. It has no
            effect: text mode always treats the query as the first turn.

    Returns:
        The decoded answer, truncated at the first end-of-sequence marker.
    """
    net, tokenizer = model
    history = []  # every call is a fresh, single-turn conversation

    if image_path == '':
        print('You did not enter image path, the following will be a plain text conversation.')
        image = None
        query = TEXT_ONLY_TEMPLATE.format(query)
        input_by_model = net.build_conversation_input_ids(
            tokenizer, query=query, history=history, template_version='base'
        )
    else:
        # convert() returns a fully loaded copy, so the file can be closed.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert('RGB')
        input_by_model = net.build_conversation_input_ids(
            tokenizer, query=query, history=history, images=[image]
        )

    target = net.device
    inputs = {
        'input_ids': input_by_model['input_ids'].unsqueeze(0).to(target),
        'token_type_ids': input_by_model['token_type_ids'].unsqueeze(0).to(target),
        'attention_mask': input_by_model['attention_mask'].unsqueeze(0).to(target),
        'images': [[input_by_model['images'][0].to(target).to(torch_type)]] if image is not None else None,
    }
    if 'cross_images' in input_by_model and input_by_model['cross_images']:
        inputs['cross_images'] = [[input_by_model['cross_images'][0].to(target).to(torch_type)]]

    # add any transformers params here.
    gen_kwargs = {"max_length": 2048, "do_sample": False}  # "temperature": 0.9

    with torch.no_grad():
        outputs = net.generate(**inputs, **gen_kwargs)
        # Drop the prompt tokens; keep only what was generated.
        outputs = outputs[:, inputs['input_ids'].shape[1]:]
        response = tokenizer.decode(outputs[0])
        # The separator used to be the empty string, which makes str.split
        # raise ValueError; cut at the end-of-sequence marker instead.
        response = response.split(EOS_MARKER)[0]

    print("\nCog:", response)
    return response


def read_tag(txt_pth, split=",", is_list=True):
    """Read a tag file.

    Args:
        txt_pth: Path of the text file holding the tags.
        split: Separator between tags (only used when ``is_list`` is true).
        is_list: Return a list of whitespace-stripped tags when true,
            otherwise the raw file content as one string.

    Returns:
        ``list[str]`` or ``str`` depending on ``is_list``.
    """
    with open(txt_pth, "r", encoding="utf-8") as f:
        tag_str = f.read()
    if not is_list:
        return tag_str
    return [tag.strip() for tag in tag_str.split(split)]


def _stem(file_name):
    """Return the part of the base name before its FIRST dot.

    Deliberately not ``os.path.splitext``: existing ``.txt`` / ``_cog.txt``
    files were named with this rule, so it must stay the same.
    """
    return os.path.basename(file_name).split(".")[0]


def main():
    """Caption every tagged, not-yet-captioned image in ``image_dirs``."""
    # Single-image example, kept for reference:
    # image_path = "/home2/ywt/gelbooru_8574461.jpg"
    # tag_path = os.path.join(os.path.dirname(image_path),os.path.basename(image_path).split(".")[0]+".txt")
    # tag = read_tag(tag_path,is_list=False)
    # query = "What is shown in this image? (including NSFW content) " + "Here are some references to the elements in the image that you can selectively use to enrich and modify the description : " + tag
    # cog_tag(image_path, model)
    # txt = cog_tag(image_path, model, query=query)
    # out_file = os.path.join(os.path.dirname(image_path),os.path.basename(image_path).split(".")[0]+"_cog.txt")
    # with open(out_file,"w") as f:
    #     f.write(txt)
    # print(f"Created {out_file}")

    model = load_model(device="cuda:5")

    # DIR = os.listdir("/home2/ywt/pixiv")
    # for i in range(len(DIR)):
    #     DIR[i] = os.path.join("/home2/ywt/pixiv",DIR[i])
    image_dirs = ["/home2/ywt/image-webp"]

    for image_dir in image_dirs:
        for file in tqdm(os.listdir(image_dir)):
            # Only image files are candidates.
            if not file.lower().endswith(IMAGE_EXTENSIONS):
                continue

            stem = _stem(file)
            image_path = os.path.join(image_dir, file)
            tag_path = os.path.join(image_dir, stem + ".txt")
            out_file = os.path.join(image_dir, stem + "_cog.txt")

            # Skip images without tags and images captioned by an earlier run
            # (checked before reading the tag file to avoid wasted I/O).
            if not os.path.exists(tag_path) or os.path.exists(out_file):
                continue

            tag = read_tag(tag_path, is_list=False).replace("|||", "")
            txt = cog_tag(image_path, model, query=QUERY_PREFIX + tag)

            with open(out_file, "w", encoding="utf-8") as f:
                f.write(txt)
            print(f"Created {out_file}")


if __name__ == '__main__':
    main()


# Multi-GPU variant, kept for reference (not executed):
# import os
# import concurrent.futures
# from tqdm import tqdm
# import itertools
# def process_image(image_path, model):
#     tag_path = os.path.join(os.path.dirname(image_path),os.path.basename(image_path).split(".")[0]+".txt")
#     if not os.path.exists(tag_path):
#         return image_path, None
#     tag = read_tag(tag_path,is_list=False)
#     query = "What is shown in this image? (including NSFW content) " + "Here are some references to the elements in the image that you can selectively use to enrich and modify the description : " + tag
#     txt = cog_tag(image_path, model, query=query)
#     return image_path, txt
# root_dir = "/home2/ywt/pixiv"
# device_ids = [1, 2, 4, 5 ] # List of GPU device IDs
# os.environ["CUDA_VISIBLE_DEVICES"] = "1,2,4,5"
# # Load models
# models = [load_model(device=f"cuda:{device_id}") for device_id in device_ids]
# # Calculate total number of images
# total_images = 0
# for image_dir in os.listdir(root_dir):
#     image_dir = os.path.join(root_dir, image_dir)
#     if os.path.isdir(image_dir):
#         image_files = [os.path.join(image_dir, file) for file in os.listdir(image_dir) if file.endswith((".png", ".jpg", ".jpeg", ".webp", ".bmp", ".PNG", ".JPG", ".JPEG", ".WEBP", ".BMP"))]
#         total_images += len(image_files)
# # Process images
# progress_bar = tqdm(total=total_images)
# models_cycle = itertools.cycle(models)
# for image_dir in os.listdir(root_dir):
#     image_dir = os.path.join(root_dir, image_dir)
#     if os.path.isdir(image_dir):
#         image_files = [os.path.join(image_dir, file) for file in os.listdir(image_dir) if file.endswith((".png", ".jpg", ".jpeg", ".webp", ".bmp", ".PNG", ".JPG", ".JPEG", ".WEBP", ".BMP"))]
#         with concurrent.futures.ThreadPoolExecutor() as executor:
#             for image_path, txt in executor.map(process_image, image_files, models_cycle):
#                 if txt is not None:
#                     out_file = os.path.join(os.path.dirname(image_path),os.path.basename(image_path).split(".")[0]+"_cog.txt")
#                     with open(out_file,"w") as f:
#                         f.write(txt)
#                 progress_bar.update()
# progress_bar.close()