import os
import orjson  # only used by the commented-out variant at the bottom of this file
import json
import webdataset as wds
from tqdm import tqdm, trange
import h5py
import numpy as np
from utils import MAXCOUNT, NAMING, check_sample

OUT_DIR = "/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/junyan/raw/instruct/vg_relation"
BOX_SCALE = 512
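
# Boxes in VG-SGG.h5 are stored with the longest image side scaled to
# BOX_SCALE; the normalization in the main loop below relies on this.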


def load_image_filenames(image_file, image_dir):
    """
    Loads the image filenames for Visual Genome from the JSON file that contains them.
    This matches the preprocessing in scene-graph-TF-release/data_tools/vg_to_imdb.py.
    :param image_file: JSON file. Elements contain the key "image_id".
    :param image_dir: directory where the Visual Genome images are located
    :return: List of [filename, height, width] entries for the good images
    """
    with open(image_file, 'r') as f:
        im_data = json.load(f)
    corrupted_ims = ['1592.jpg', '1722.jpg', '4616.jpg', '4617.jpg']
    fns = []
    for img in tqdm(im_data):
        basename = '{}.jpg'.format(img['image_id'])
        height = int(img['height'])
        width = int(img['width'])
        if basename in corrupted_ims:
            continue
        filename = os.path.join(image_dir, basename)
        if os.path.exists(filename):
            fns.append([filename, height, width])
    # 108,073 images should remain once the 4 corrupted files are dropped.
    assert len(fns) == 108073
    return fns


def load_graphs(graphs_file, mode='train', num_im=-1, num_val_im=0, filter_empty_rels=True,
                filter_non_overlap=False):
    """
    Load the file containing the GT boxes and relations, as well as the dataset split.
    :param graphs_file: HDF5 file
    :param mode: one of 'train', 'val', or 'test'
    :param num_im: Number of images we want (-1 for all)
    :param num_val_im: Number of validation images to hold out from the train split
    :param filter_empty_rels: If True, filter out images without relationships
    :param filter_non_overlap: If training, filter out relations whose boxes don't overlap
    :return: split_mask: boolean numpy array marking which images are used
             boxes: List where each element is a [num_gt, 4] array of ground
                    truth boxes (x1, y1, x2, y2)
             gt_classes: List where each element is a [num_gt] array of classes
             relationships: List where each element is a [num_r, 3] array of
                    (box_ind_1, box_ind_2, predicate) relationships
    """
    if mode not in ('train', 'val', 'test'):
        raise ValueError('{} invalid'.format(mode))
    roi_h5 = h5py.File(graphs_file, 'r')
    data_split = roi_h5['split'][:]
    split = 2 if mode == 'test' else 0
    split_mask = data_split == split

    # Filter out images without bounding boxes
    split_mask &= roi_h5['img_to_first_box'][:] >= 0
    if filter_empty_rels:
        # A first-rel pointer of -1 means the image has no relationships.
        split_mask &= roi_h5['img_to_first_rel'][:] >= 0

    image_index = np.where(split_mask)[0]
    if num_im > -1:
        image_index = image_index[:num_im]
    if num_val_im > 0:
        if mode == 'val':
            image_index = image_index[:num_val_im]
        elif mode == 'train':
            image_index = image_index[num_val_im:]

    split_mask = np.zeros_like(data_split).astype(bool)
    split_mask[image_index] = True

    # Get box information
    all_labels = roi_h5['labels'][:, 0]
    all_boxes = roi_h5['boxes_{}'.format(BOX_SCALE)][:]  # will index later
    assert np.all(all_boxes[:, :2] >= 0)  # sanity check
    assert np.all(all_boxes[:, 2:] > 0)  # no empty box

    # convert from (xc, yc, w, h) to (x1, y1, x2, y2)
    all_boxes[:, :2] = all_boxes[:, :2] - all_boxes[:, 2:] / 2
    all_boxes[:, 2:] = all_boxes[:, :2] + all_boxes[:, 2:]

    im_to_first_box = roi_h5['img_to_first_box'][:][split_mask]
    im_to_last_box = roi_h5['img_to_last_box'][:][split_mask]
    im_to_first_rel = roi_h5['img_to_first_rel'][:][split_mask]
    im_to_last_rel = roi_h5['img_to_last_rel'][:][split_mask]

    # load relation labels
    _relations = roi_h5['relationships'][:]
    _relation_predicates = roi_h5['predicates'][:, 0]
    assert im_to_first_rel.shape[0] == im_to_last_rel.shape[0]
    assert _relations.shape[0] == _relation_predicates.shape[0]  # sanity check

    # Get everything by image.
    boxes = []
    gt_classes = []
    relationships = []
    for i in trange(len(image_index)):
        boxes_i = all_boxes[im_to_first_box[i]:im_to_last_box[i] + 1, :]
        gt_classes_i = all_labels[im_to_first_box[i]:im_to_last_box[i] + 1]
        if im_to_first_rel[i] >= 0:
            predicates = _relation_predicates[im_to_first_rel[i]:im_to_last_rel[i] + 1]
            # Rebase global box indices so they index into boxes_i.
            obj_idx = _relations[im_to_first_rel[i]:im_to_last_rel[i] + 1] - im_to_first_box[i]
            assert np.all(obj_idx >= 0)
            assert np.all(obj_idx < boxes_i.shape[0])
            rels = np.column_stack((obj_idx, predicates))
        else:
            assert not filter_empty_rels
            rels = np.zeros((0, 3), dtype=np.int32)
        if filter_non_overlap:
            # Not supported in this script: the code below needs bbox_overlaps,
            # which is neither defined nor imported here, so it is unreachable.
            raise NotImplementedError
            assert mode == 'train'
            inters = bbox_overlaps(boxes_i, boxes_i)
            rel_overs = inters[rels[:, 0], rels[:, 1]]
            inc = np.where(rel_overs > 0.0)[0]
            if inc.size > 0:
                rels = rels[inc]
            else:
                split_mask[image_index[i]] = 0
                continue
        boxes.append(boxes_i)
        gt_classes.append(gt_classes_i)
        relationships.append(rels)
    return split_mask, boxes, gt_classes, relationships


def load_info(info_file):
    """
    Loads the file containing the Visual Genome label meanings.
    :param info_file: JSON
    :return: ind_to_classes: list of class names sorted by index
             ind_to_predicates: list of predicate names sorted by index
    """
    info = json.load(open(info_file, 'r'))
    info['label_to_idx']['__background__'] = 0
    info['predicate_to_idx']['__background__'] = 0
    class_to_ind = info['label_to_idx']
    predicate_to_ind = info['predicate_to_idx']
    ind_to_classes = sorted(class_to_ind, key=lambda k: class_to_ind[k])
    ind_to_predicates = sorted(predicate_to_ind, key=lambda k: predicate_to_ind[k])
    return ind_to_classes, ind_to_predicates
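

# Note: index 0 of both lists is '__background__'; the labels and predicates
# stored in VG-SGG.h5 start at 1, so they index these lists directly.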
if __name__ == "__main__":
    root = "/gpfs/u/home/LMCG/LMCGljnn/scratch/datasets/raw/vg"
    filenames = load_image_filenames(os.path.join(root, "image_data.json"), os.path.join(root, "VG_100K"))
    split_mask, boxes, gt_classes, relationships = load_graphs(
        graphs_file=os.path.join(root, "VG-SGG.h5"),
        mode="train",
    )
    # Keep only the filenames selected by the split mask.
    split_filenames = []
    for i, mask in enumerate(split_mask):
        if mask:
            split_filenames.append(filenames[i])
    filenames = split_filenames
    ind_to_classes, ind_to_predicates = load_info(os.path.join(root, "VG-SGG-dicts.json"))
    assert len(filenames) == len(boxes)
    assert len(filenames) == len(gt_classes)
    assert len(filenames) == len(relationships)

    uuid = 0
    os.makedirs(OUT_DIR, exist_ok=True)
    pbar = tqdm()
    with wds.ShardWriter(os.path.join(OUT_DIR, NAMING), maxcount=MAXCOUNT) as sink:
        for box, box_class, relationship, (filename, height, width) in zip(boxes, gt_classes, relationships, filenames):
            # Boxes are at BOX_SCALE resolution (longest side == BOX_SCALE);
            # divide by the scaled image size to get [0, 1] coordinates.
            size = float(BOX_SCALE) / max(height, width)
            size = np.array([width, height, width, height]) * size
            box = (box.astype(float) / size).clip(0, 1)
            # Write one sample per (subject, object, predicate) triple.
            for relation in relationship:
                box1_id = relation[0]
                box2_id = relation[1]
                predicate = ind_to_predicates[relation[2]]
                box1 = [box[box1_id], ind_to_classes[box_class[box1_id]]]
                box2 = [box[box2_id], ind_to_classes[box_class[box2_id]]]
                data = [box1, box2, predicate]
                dataset = "vg_relation"
                image_path = filename
                key = f"{dataset}_{uuid}"
                uuid += 1
                pbar.update()
                sample = {
                    "__key__": key,
                    "image_path.txt": image_path,
                    "dataset.txt": dataset,
                    "data.pyd": data,
                }
                check_sample(sample)
                sink.write(sample)
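
# To spot-check the output, a reader along these lines should work (a minimal
# sketch; the "000000.tar" shard name is an assumption that depends on how
# utils.NAMING expands):
#
#   import webdataset as wds
#   ds = wds.WebDataset(os.path.join(OUT_DIR, "000000.tar")).decode()
#   for sample in ds:
#       print(sample["__key__"], sample["image_path.txt"], sample["data.pyd"])
#       break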

# Commented-out variant that built samples straight from relationships.json
# instead of VG-SGG.h5; left unfinished.
# if __name__ == "__main__":
#     root = "/gpfs/u/home/LMCG/LMCGljnn/scratch/datasets/raw/vg"
#     relationships = orjson.loads(open("/gpfs/u/home/LMCG/LMCGljnn/scratch/datasets/raw/vg/relationships.json").read())
#     image_data = orjson.loads(open("/gpfs/u/home/LMCG/LMCGljnn/scratch/datasets/raw/vg/image_data.json").read())
#     image_id_to_filename = {}
#     image_id_to_wh = {}
#     for image in tqdm(image_data):
#         image_id = image["image_id"]
#         subfolder, filename = image['url'].split("/")[-2:]
#         image_id_to_filename[image_id] = os.path.join(root, subfolder, filename)
#         image_id_to_wh[image_id] = (image["width"], image["height"])
#     unique_predicates = []
#     # with wds.ShardWriter(os.path.join(OUT_DIR, "%05d.tar"), maxcount=500) as sink:
#     for relation_per_image in tqdm(relationships):
#         image_id = relation_per_image["image_id"]
#         for relation in relation_per_image["relationships"]:
#             predicate = relation["predicate"]
#             unique_predicates.append(predicate)
#             object = {
#                 "name": relation["object"]["name"],
#                 "x": relation["object"]["x"],
#                 "y": relation["object"]["y"],
#                 "w": relation["object"]["w"],
#                 "h": relation["object"]["h"],
#             }
#             subject = {
#                 "name": relation["subject"]["name"],
#                 "x": relation["subject"]["x"],
#                 "y": relation["subject"]["y"],
#                 "w": relation["subject"]["w"],
#                 "h": relation["subject"]["h"],
#             }