import webdataset as wds
import glob
import os
from tqdm import tqdm
import orjson as json
import itertools
from PIL import Image
import numpy as np
from typing import List
import cv2
import random
from tqdm.contrib.concurrent import process_map
from copy import deepcopy


class Generator:
    def __init__(self, dataset_name):
        self.dataset_name = dataset_name
        self.is_end = False


class VisualGenomeGenerator(Generator):
    def __init__(self, root: str):
        super().__init__(dataset_name="vg")
        data = json.loads(open(os.path.join(root, "region_descriptions.json")).read())
        image_data = json.loads(open(os.path.join(root, "image_data.json")).read())
        self.image_id_to_filename = {}
        self.image_id_to_wh = {}
        for image in image_data:
            image_id = image["image_id"]
            # the last two URL components give the image subfolder and filename on disk
            subfolder, filename = image["url"].split("/")[-2:]
            self.image_id_to_filename[image_id] = os.path.join(root, subfolder, filename)
            self.image_id_to_wh[image_id] = (image["width"], image["height"])
        self.regions = []
        total = 0
        total_image = 0
        used_image = 0
        for xx in tqdm(data):
            total_image += 1
            flag = False
            for region in xx["regions"]:
                total += 1
                region_w = int(region["width"])
                region_h = int(region["height"])
                x = int(region["x"])
                y = int(region["y"])
                image_w = self.image_id_to_wh[region["image_id"]][0]
                image_h = self.image_id_to_wh[region["image_id"]][1]
                # normalize the region box to [0, 1] coordinates
                region_w /= image_w
                region_h /= image_h
                x /= image_w
                y /= image_h
                # drop regions covering less than 1/1024 of the image area
                if region_w * region_h < 1 / (16 * 16 * 4):
                    continue
                # drop phrases containing " is" or " are"
                if " is" in region["phrase"] or " are" in region["phrase"]:
                    continue
                region["norm_xywh"] = (x, y, region_w, region_h)
                self.regions.append(region)
                flag = True
            if flag:
                used_image += 1
        random.shuffle(self.regions)
        print("valid region", len(self.regions), total, len(self.regions) / total)
        print("valid image", used_image, total_image, used_image / total_image)
    def __len__(self):
        return len(self.regions)
    def __iter__(self):
        for region in self.regions:
            image_id = region["image_id"]
            phrase = region["phrase"]
            try:
                image = Image.open(self.image_id_to_filename[image_id])
            except Exception:
                # skip images that are missing or cannot be opened
                continue
            image = image.resize((224, 224))
            x, y, region_w, region_h = region["norm_xywh"]
            # convert the normalized xywh box to pixel xyxy coordinates in the 224x224 image
            x1 = int(x * 224)
            y1 = int(y * 224)
            x2 = int(x1 + region_w * 224)
            y2 = int(y1 + region_h * 224)
            yield [self.dataset_name, image, phrase, np.array([x1, y1, x2, y2]), image_id]
        self.is_end = True
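

# --- Hedged usage sketch (not part of the original script) ---
# A minimal helper for eyeballing a few generator samples before sharding;
# the name `preview_generator`, the default `n`, and the print format are
# illustrative assumptions, not part of the original pipeline.
def preview_generator(gen, n=3):
    for dataset_name, image, phrase, xyxy, image_id in itertools.islice(gen, n):
        print(dataset_name, image_id, phrase, xyxy, image.size)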


def handle(args):
    dataset_name = "vg"
    iii, regions, image_id_to_filename = args
    if iii == 0:
        print(regions[:10])
    # each worker writes its shards into its own subdirectory of OUT_DIR
    # (OUT_DIR is a module-level global assigned in the __main__ block; workers
    # see it when the process pool uses the fork start method, the Linux default)
    os.makedirs(os.path.join(OUT_DIR, str(iii)), exist_ok=True)
    with wds.ShardWriter(os.path.join(OUT_DIR, str(iii), "%06d.tar"), maxcount=8500) as sink:
        sink.verbose = False
        for i, region in enumerate(tqdm(regions, disable=(iii != 0))):
            image_id = region["image_id"]
            caption = region["phrase"]
            # mirror the generator above: skip images that fail to open
            try:
                image = Image.open(image_id_to_filename[image_id])
            except Exception:
                continue
            image = image.resize((224, 224))
            x, y, region_w, region_h = region["norm_xywh"]
            x1 = int(x * 224)
            y1 = int(y * 224)
            x2 = int(x1 + region_w * 224)
            y2 = int(y1 + region_h * 224)
            xyxy = np.array([x1, y1, x2, y2])
            sink.write({"__key__": f"{dataset_name}_{i}_containBox", "jpg": image, "txt": caption, "boxes.pyd": xyxy, "logits.pyd": xyxy})
            if i % 200 == 0 and iii == 0:
                tqdm.write(f"{caption} {xyxy}")
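

# --- Hedged read-back sketch (not part of the original script) ---
# A minimal sanity check for the written shards using webdataset's standard
# reading pipeline (WebDataset -> decode -> to_tuple); ".pyd" entries are
# unpickled by webdataset's default decoder. The helper name, the default `n`,
# and the shard pattern passed by the caller are assumptions.
def preview_shards(pattern, n=3):
    dataset = wds.WebDataset(pattern).decode("pil").to_tuple("jpg", "txt", "boxes.pyd")
    for image, caption, xyxy in itertools.islice(dataset, n):
        print(caption, xyxy, image.size)
# e.g. preview_shards(os.path.join(OUT_DIR, "0", "000000.tar")) after the main block has run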


if __name__ == "__main__":
    OUT_DIR = "/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/junyan/raw/vg_0826"
    os.makedirs(OUT_DIR, exist_ok=True)
    visual_genome_generator = VisualGenomeGenerator("/gpfs/u/home/LMCG/LMCGljnn/scratch/datasets/raw/vg")
    N_PROC = 150
    # one work item per worker: [worker index, its region bucket, its own copy of the filename map]
    data_list = []
    for i in range(N_PROC):
        data_list.append([i, [], deepcopy(visual_genome_generator.image_id_to_filename)])
    # distribute the shuffled regions round-robin across the worker buckets
    for i, region in enumerate(visual_genome_generator.regions):
        data_list[i % N_PROC][1].append(region)
    process_map(handle, data_list, max_workers=N_PROC, disable=True)