"""Add GroundingDINO grounding outputs to one webdataset shard.

Reads ``<SOURCE_DIR>/<shard>.tar``, runs the caption grounder on every
(image, caption) pair and writes the samples, extended with the grounder's
logits and boxes, to ``<DEST_DIR>/<shard>.tar``.
"""
import os
import random
import sys
import time

import webdataset as wds
from groundingdino.demo.caption_grounder import caption_grounder
from tqdm import tqdm

# Alternative dataset locations used previously; kept for reference.
# SOURCE_DIR = "/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/junyan/raw/blip2_pretraining/laion_synthetic_filtered_large/all"
# DEST_DIR = "/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/junyan/raw/blip2_pretraining/laion_synthetic_filtered_large/all_ground"
# SOURCE_DIR = "/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/junyan/raw/blip2_pretraining/ccs_synthetic_filtered_large"
# DEST_DIR = "/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/junyan/raw/blip2_pretraining/ccs_synthetic_filtered_large_ground"
# SOURCE_DIR = "/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/junyan/raw/karpathy_coco_wds_full"
# DEST_DIR = "/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/junyan/raw/karpathy_coco_wds_full_ground"
# SOURCE_DIR = "/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/junyan/raw/vg_wds_full"
# DEST_DIR = "/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/junyan/raw/vg_wds_full_ground"
SOURCE_DIR = "/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/junyan/raw/blip2_pretraining/all_data_0620"
DEST_DIR = "/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/junyan/raw/blip2_pretraining/all_data_ground_0701"


def augment_wds(url, output, generator):
    """Copy a webdataset shard, adding grounding results to every sample.

    Args:
        url: Location of the input shard, passed to ``wds.WebDataset``.
        output: Destination passed to ``wds.TarWriter``.
        generator: Object exposing
            ``ground_caption_raw(image_pil=..., caption=...)`` that returns a
            ``(logits, boxes)`` pair.
    """
    src = (
        wds.WebDataset(url)
        .decode("pilrgb")
        .to_tuple("__key__", "jpg;png;jpeg", "txt")
    )
    with wds.TarWriter(output) as dst:
        # NOTE(review): the progress-bar total is a fixed 10000; presumably the
        # shards hold about that many samples -- confirm.
        for key, image, caption in tqdm(src, total=10000):
            # jpg txt json
            # image = image.resize((224, 224))
            logits, boxes = generator.ground_caption_raw(image_pil=image, caption=caption)
            # The image is always written under the "jpg" key, whichever of
            # jpg/png/jpeg it was read from.
            sample = {
                "__key__": key,
                "jpg": image,
                "txt": caption,
                "logits.pyd": logits,
                "boxes.pyd": boxes,
            }
            dst.write(sample)


def _load_generator_with_retry():
    """Build the caption grounder, retrying until construction succeeds.

    Each failure is reported on stderr and followed by a random sleep of up
    to five seconds. Only ``Exception`` subclasses trigger a retry, so
    ``KeyboardInterrupt`` and ``SystemExit`` still stop the script.
    """
    while True:
        try:
            return caption_grounder(
                config_file="/gpfs/u/home/LMCG/LMCGljnn/scratch/code/multimodal/GroundingDINO/groundingdino/config/GroundingDINO_SwinB.cfg.py",
                checkpoint_path="/gpfs/u/home/LMCG/LMCGljnn/scratch/code/multimodal/GroundingDINO/checkpoints/groundingdino_swinb_cogcoor.pth",
                cpu_only=False,
                box_threshold=0.05,
            )
        except Exception as exc:
            print(f"caption_grounder init failed ({exc!r}); retrying", file=sys.stderr)
            time.sleep(random.random() * 5)


if __name__ == "__main__":
    # NOTE(review): the shard name is read from sys.argv[2]; sys.argv[1] is not
    # used anywhere in this file -- confirm against the launcher.
    if len(sys.argv) < 3:
        sys.exit("expected the shard name as the second command-line argument")
    shard_file = sys.argv[2] + ".tar"
    source_path = os.path.join(SOURCE_DIR, shard_file)
    dest_path = os.path.join(DEST_DIR, shard_file)

    print("FROM", source_path)
    print("TO", dest_path)
    # if os.path.exists(os.path.join(DEST_DIR, sys.argv[2]+".tar")):
    #     print("already done. exiting...")
    #     exit()
    generator = _load_generator_with_retry()
    augment_wds(
        source_path,
        dest_path,
        generator=generator,
    )
    print("DONE")