# person-thumbs-up / hf_dataset_plain.py
# Author: Srimanth Agastyaraju
# Initial commit (5372b88)
import os
import requests
import random
from PIL import Image
import torch
from transformers import BlipProcessor, BlipForConditionalGeneration
from tqdm import tqdm
import pandas as pd
def caption_images(image_paths, processor, model, folder):
    """Generate a BLIP caption for each image and post-process it.

    Parameters
    ----------
    image_paths : list[str]
        Paths to the image files to caption.
    processor : BlipProcessor
        Converts PIL images into model input tensors.
    model : BlipForConditionalGeneration
        Caption model; assumed to already live on CUDA.
    folder : str
        Relative folder prefix stored in ``file_name``; also selects the
        post-processing rule ("images/" injects "thumbs up",
        "tom_cruise_dataset/" injects "tom cruise").

    Returns
    -------
    list[dict]
        One ``{"file_name": ..., "text": ...}`` entry per input image.
    """
    image_captions_dict = []
    for img_path in tqdm(image_paths):
        pil_image = Image.open(img_path).convert('RGB')
        # os.path.basename is portable across path separators, unlike
        # splitting on "/".
        image_name = os.path.basename(img_path)
        # Unconditional image captioning.
        inputs = processor(pil_image, return_tensors="pt").to("cuda")
        out = model.generate(**inputs)
        out_caption = processor.decode(out[0], skip_special_tokens=True)
        if folder == "images/" and "thumbs up" not in out_caption:
            # Inject the trigger phrase at a random position (prefix or suffix).
            th_choice = random.choice([True, False])
            out_caption = "thumbs up " + out_caption if th_choice else out_caption + " thumbs up"
        elif folder == "tom_cruise_dataset/":
            if "man" in out_caption:
                out_caption = out_caption.replace("man", "tom cruise")
            elif "person" in out_caption:
                out_caption = out_caption.replace("person", "tom cruise")
            elif "tom cruise" not in out_caption:
                # Bug fix: prepend "tom cruise " (with a space) so the token
                # matches the replacement text above, not "tom_cruise ".
                out_caption = "tom cruise " + out_caption
        # For some reason, the model puts the word "arafed" for a human.
        if "arafed" in out_caption:
            out_caption = out_caption.replace("arafed ", "")
        image_captions_dict.append({"file_name": folder + image_name, "text": out_caption})
    return image_captions_dict
def create_thumbs_up_person_dataset(path, cache_dir="/l/vision/v5/sragas/hf_models/"):
    """Caption the thumbs-up and person image folders and write metadata CSVs.

    Parameters
    ----------
    path : str
        Dataset root (trailing slash expected) containing the "images/" and
        "tom_cruise_dataset/" subfolders.
    cache_dir : str, optional
        Hugging Face cache directory used when downloading the BLIP model.

    Side effects
    ------------
    Downloads/loads the BLIP model onto CUDA and writes two CSV files:
    ``<path>metadata.csv`` and ``./metadata_plain.csv``.
    """
    # Fixed seed so the random "thumbs up" prefix/suffix choice in
    # caption_images is reproducible.
    random.seed(15)
    image_captions = []
    processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
    model = BlipForConditionalGeneration.from_pretrained(
        "Salesforce/blip-image-captioning-large",
        cache_dir=cache_dir,
        torch_dtype=torch.float32).to("cuda")
    # Caption the thumbs up images for prompts
    image_paths = [path + "images/" + file for file in os.listdir(path + "images/")]
    # Read from the person dataset
    person_paths = [path + "tom_cruise_dataset/" + file
                    for file in sorted(os.listdir(path + "tom_cruise_dataset/"))]
    image_captions.extend(caption_images(person_paths, processor, model, "tom_cruise_dataset/"))
    image_captions.extend(caption_images(image_paths, processor, model, "images/"))
    metadata = pd.DataFrame(image_captions)
    metadata.to_csv(f"{path}metadata.csv", index=False)
    # Fixed: plain string literal, not a placeholder-free f-string.
    metadata.to_csv("metadata_plain.csv", index=False)
if __name__ == "__main__":
    # Root of the plain thumbs-up dataset (trailing slash expected).
    dataset_root = "/l/vision/v5/sragas/easel_ai/thumbs_up_plain_dataset/"
    create_thumbs_up_person_dataset(dataset_root)