Spaces:
Runtime error
Runtime error
import webdataset as wds | |
import glob | |
import os | |
from tqdm import tqdm | |
import orjson as json | |
import itertools | |
from PIL import Image | |
import numpy as np | |
from typing import List | |
class Generator:
    """Common state shared by the dataset sample generators in this file.

    Attributes:
        dataset_name: Tag stored for the dataset; subclasses in this file
            emit it as the first element of every yielded sample.
        is_end: Starts as ``False``; subclasses in this file set it to
            ``True`` once their ``__iter__`` has run to completion.
    """

    def __init__(self, dataset_name):
        self.is_end = False
        self.dataset_name = dataset_name
class CC3MGenerator(Generator):
    """Stream ``[dataset_name, image, caption]`` samples from CC3M tar shards.

    Shards are discovered with the glob ``<root>/cc3m_*/*.tar``.
    """

    def __init__(self, root: str, dataset_name="cc3m"):
        super().__init__(dataset_name=dataset_name)
        shard_pattern = os.path.join(root, "cc3m_*", "*.tar")
        self.tars = glob.glob(shard_pattern)

    def __len__(self):
        # Hard-coded nominal size; it is not counted from the shards on disk.
        return 3000000

    def __iter__(self):
        """Yield one list per sample, shard by shard, then set ``is_end``."""
        for shard_path in self.tars:
            samples = (
                wds.WebDataset(shard_path)
                .decode("pilrgb")
                .to_tuple("jpg;png;jpeg", "txt")
            )
            for sample in samples:
                # Prefix the (image, text) tuple with the dataset tag.
                yield [self.dataset_name, *sample]
        self.is_end = True
class CC12MGenerator(CC3MGenerator):
    """CC12M variant of :class:`CC3MGenerator`.

    Iteration is inherited unchanged; only the dataset tag ("cc12m"), the
    shard glob (``<root>/*.tar``) and the reported length differ.
    """

    def __init__(self, root: str):
        super().__init__(root, "cc12m")
        # Replace the shard list built by the parent constructor with the
        # tar files that sit directly under ``root``.
        shard_pattern = os.path.join(root, "*.tar")
        self.tars = glob.glob(shard_pattern)

    def __len__(self):
        # Hard-coded nominal size; it is not counted from the shards on disk.
        return 12000000
class COCOGenerator(Generator):
    """Yield ``["coco", image, caption]`` samples from a COCO caption file.

    Args:
        anno: Path to a JSON file with top-level ``"annotations"`` and
            ``"images"`` lists.
        image_dir: Directory that each image's ``"file_name"`` is joined to.
    """

    def __init__(self, anno: str, image_dir):
        super().__init__(dataset_name="coco")
        # Read raw bytes inside a context manager: the handle is closed
        # deterministically and decoding does not depend on the locale.
        with open(anno, "rb") as anno_file:
            data = json.loads(anno_file.read())
        self.annotations = data["annotations"]
        self.image_id_to_filename = {
            image["id"]: os.path.join(image_dir, image["file_name"])
            for image in data["images"]
        }

    def __len__(self):
        """Return the number of caption annotations (not of images)."""
        return len(self.annotations)

    def __iter__(self):
        """Yield one sample per annotation whose image can be opened.

        Annotations whose image is unknown or cannot be opened are skipped
        silently. ``is_end`` is set once every annotation has been visited.
        """
        for anno in self.annotations:
            image_id = anno["image_id"]
            caption = anno["caption"]
            try:
                image = Image.open(self.image_id_to_filename[image_id])
            except Exception:
                # Best-effort skip, but unlike a bare ``except`` this lets
                # KeyboardInterrupt / SystemExit propagate.
                continue
            yield [self.dataset_name, image, caption]
        self.is_end = True
class KarpathyCOCOGenerator(Generator):
    """Yield ``["coco", image, caption]`` samples from a Karpathy-split file.

    Args:
        anno: Path to a JSON file holding a list of records, each with
            ``"image_id"``, ``"image"`` and ``"caption"`` keys.
        image_dir: Directory that each record's ``"image"`` value is joined to.
    """

    def __init__(self, anno="/gpfs/u/home/LMCG/LMCGljnn/scratch/code/multimodal/tools/coco_karpathy_train.json", image_dir="/gpfs/u/home/LMCG/LMCGljnn/scratch/.cache/lavis/coco/images"):
        super().__init__(dataset_name="coco")
        # Read raw bytes inside a context manager: the handle is closed
        # deterministically and decoding does not depend on the locale.
        with open(anno, "rb") as anno_file:
            data = json.loads(anno_file.read())
        self.annotations = data
        self.image_id_to_filename = {
            record["image_id"]: os.path.join(image_dir, record["image"])
            for record in data
        }

    def __len__(self):
        """Return the number of annotation records."""
        return len(self.annotations)

    def __iter__(self):
        """Yield one sample per record whose image can be opened.

        A record whose image fails to open has its path printed and is then
        skipped. ``is_end`` is set once every record has been visited.
        """
        for anno in self.annotations:
            image_id = anno["image_id"]
            caption = anno["caption"]
            image_path = self.image_id_to_filename[image_id]
            try:
                image = Image.open(image_path)
            except Exception:
                print(image_path)
                # Without this ``continue`` the yield below would use an
                # unbound ``image`` or the previous record's image.
                continue
            yield [self.dataset_name, image, caption]
        self.is_end = True
class VisualGenomeGenerator(Generator):
    """Yield ``["vg", image, phrase]`` samples for large Visual Genome regions.

    Reads ``region_descriptions.json`` and ``image_data.json`` from ``root``
    and keeps only regions whose area is at least ``min_area_ratio`` of
    their image's area.

    Args:
        root: Directory holding the two JSON files and the image subfolders.
        min_area_ratio: Minimum region-area / image-area ratio for a region
            to be kept. Defaults to 0.2.
    """

    def __init__(self, root: str, min_area_ratio: float = 0.2):
        super().__init__(dataset_name="vg")
        self.min_area_ratio = min_area_ratio
        # Context managers close both files; bytes avoid locale decoding.
        with open(os.path.join(root, "region_descriptions.json"), "rb") as region_file:
            data = json.loads(region_file.read())
        with open(os.path.join(root, "image_data.json"), "rb") as image_file:
            image_data = json.loads(image_file.read())

        self.image_id_to_filename = {}
        self.image_id_to_wh = {}
        for image in image_data:
            image_id = image["image_id"]
            # The last two URL components are reused as "<subfolder>/<file>"
            # relative to ``root``.
            subfolder, filename = image["url"].split("/")[-2:]
            self.image_id_to_filename[image_id] = os.path.join(root, subfolder, filename)
            self.image_id_to_wh[image_id] = (image["width"], image["height"])

        self.regions = []
        total = 0
        total_image = 0
        used_image = 0
        for entry in data:
            total_image += 1
            kept_before = len(self.regions)
            for region in entry["regions"]:
                total += 1
                region_area = int(region["width"]) * int(region["height"])
                image_w, image_h = self.image_id_to_wh[region["image_id"]]
                if region_area < (image_w * image_h) * min_area_ratio:
                    continue
                self.regions.append(region)
            # An image counts as used when at least one region survived.
            if len(self.regions) > kept_before:
                used_image += 1

        # Guard the ratios so empty input reports 0.0 instead of raising
        # ZeroDivisionError.
        region_ratio = len(self.regions) / total if total else 0.0
        image_ratio = used_image / total_image if total_image else 0.0
        print("valid region", len(self.regions), total, region_ratio)
        print("valid image", used_image, total_image, image_ratio)

    def __len__(self):
        """Return the number of regions that passed the area filter."""
        return len(self.regions)

    def __iter__(self):
        """Yield one sample per kept region whose image can be opened.

        Regions whose image is unknown or cannot be opened are skipped
        silently. ``is_end`` is set once every region has been visited.
        """
        for region in self.regions:
            image_id = region["image_id"]
            phrase = region["phrase"]
            try:
                image = Image.open(self.image_id_to_filename[image_id])
            except Exception:
                # Best-effort skip, but unlike a bare ``except`` this lets
                # KeyboardInterrupt / SystemExit propagate.
                continue
            yield [self.dataset_name, image, phrase]
        self.is_end = True
class ShuffleGenerator:
    """Randomly interleave the samples of several generators.

    For every emitted sample a source is drawn with ``np.random.choice``
    according to the current probabilities. When a source is exhausted it is
    dropped and the remaining probabilities are renormalised. Iteration
    stops once every source is exhausted.

    Args:
        generators: Sized iterables to draw from.
        p: One non-negative weight per generator; normalised to sum to 1.

    Raises:
        ValueError: If ``generators`` and ``p`` differ in length.
    """

    def __init__(self, generators: List["Generator"], p: List[int]):
        if len(generators) != len(p):
            raise ValueError(
                f"expected one weight per generator, got {len(generators)} "
                f"generators and {len(p)} weights"
            )
        self.generators = generators
        self.p = self._normalize(p)
        self.ids = list(range(len(self.generators)))
        print("rebalance", self.ids, self.p)

    @staticmethod
    def _normalize(weights):
        """Return ``weights`` scaled to sum to 1 (uniform if the sum is <= 0)."""
        weights = np.asarray(weights, dtype=float)
        if weights.size == 0:
            return []
        total = weights.sum()
        if total <= 0:
            return [1.0 / weights.size] * weights.size
        return list(weights / total)

    def __len__(self):
        """Return the summed length of all generators."""
        return sum(len(g) for g in self.generators)

    def __iter__(self):
        """Yield samples from all generators in a weighted random order.

        Works on local copies of ``self.ids`` and ``self.p`` so the object
        can be iterated more than once.
        """
        active = list(self.ids)
        weights = list(self.p)
        streams = {index: iter(self.generators[index]) for index in active}
        while active:
            # Draw a position in ``active`` so that deleting an exhausted
            # source keeps ``active`` and ``weights`` aligned.
            slot = int(np.random.choice(len(active), p=weights))
            index = active[slot]
            try:
                sample = next(streams[index])
            except StopIteration:
                source = self.generators[index]
                print(getattr(source, "dataset_name", f"generator[{index}]"), "is all done")
                del active[slot]
                del weights[slot]
                weights = self._normalize(weights)
                print("rebalance", active, weights)
                continue
            yield sample
if __name__ == "__main__":
    # Script entry point: write the filtered Visual Genome samples out as
    # WebDataset tar shards under OUT_DIR.
    OUT_DIR = "/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/junyan/raw/vg_withBox_wds"
    os.makedirs(OUT_DIR, exist_ok=True)

    # Disabled: mixing all four datasets through ShuffleGenerator.
    # cc3m_generator = CC3MGenerator("/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/junyan/raw/cc3m")
    # cc12m_generator = CC12MGenerator("/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/junyan/raw/cc12m/tars")
    # coco_generator = KarpathyCOCOGenerator()
    visual_genome_generator = VisualGenomeGenerator("/gpfs/u/home/LMCG/LMCGljnn/scratch/datasets/raw/vg")
    # generators = [cc3m_generator, cc12m_generator, coco_generator, visual_genome_generator]
    # p = [len(generator) for generator in generators]
    # dataset = ShuffleGenerator(generators, p)

    shard_pattern = os.path.join(OUT_DIR, "%05d.tar")
    with wds.ShardWriter(shard_pattern, maxcount=8500) as sink:
        sink.verbose = False
        progress = tqdm(visual_genome_generator)
        for i, (dataset_name, image, caption) in enumerate(progress):
            sample = {
                "__key__": f"{dataset_name}_{i}_containBox",
                "jpg": image,
                "txt": caption,
            }
            sink.write(sample)