countgd / datasets /coco.py
nikigoli's picture
Upload folder using huggingface_hub
a277bb8 verified
raw
history blame
26.9 kB
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
COCO dataset which returns image_id for evaluation.
Mostly copy-paste from https://github.com/pytorch/vision/blob/13b35ff/references/detection/coco_utils.py
"""
if __name__ == "__main__":
    # Debug only: when this file is executed directly, append the parent
    # directory of sys.path[0] to the import path before the project-level
    # imports below are attempted.
    import os
    import sys

    sys.path.append(os.path.dirname(sys.path[0]))
from torchvision.datasets.vision import VisionDataset
import json
from pathlib import Path
import random
import os
from typing import Any, Callable, List, Optional, Tuple
from PIL import Image
import torch
import torch.utils.data
import torchvision
from pycocotools import mask as coco_mask
from datasets.data_util import preparing_dataset
import datasets.transforms as T
from util.box_ops import box_cxcywh_to_xyxy, box_iou
__all__ = ["build"]
class label2compat:
    """Target hook that remaps sparse category ids onto a contiguous range.

    ``category_map_str`` maps the 80 category ids that occur in 1..90 (as
    strings) to consecutive ranks 1..80; ``category_map`` is the same table
    with integer keys. Calling the hook stores ``rank - 1`` for every entry
    of ``target["labels"]`` under ``target["label_compat"]``.
    """

    def __init__(self) -> None:
        # Ids in 1..90 that have no entry in the table (presumably the unused
        # COCO category ids - TODO confirm). Every remaining id is ranked in
        # ascending order, starting at 1.
        skipped_ids = {12, 26, 29, 30, 45, 66, 68, 69, 71, 83}
        kept_ids = [cid for cid in range(1, 91) if cid not in skipped_ids]
        self.category_map_str = {
            str(cid): rank for rank, cid in enumerate(kept_ids, start=1)
        }
        self.category_map = {
            int(key): rank for key, rank in self.category_map_str.items()
        }

    def __call__(self, target, img=None):
        """Add ``target["label_compat"]`` (zero-based contiguous labels).

        Returns ``target`` alone when ``img`` is None, otherwise
        ``(target, img)``. A label missing from the table raises ``KeyError``.
        """
        raw_labels = target["labels"]
        remapped = torch.zeros(raw_labels.shape, dtype=raw_labels.dtype)
        for position, raw_label in enumerate(raw_labels):
            remapped[position] = self.category_map[raw_label.item()] - 1
        target["label_compat"] = remapped
        if img is None:
            return target
        return target, img
class label_compat2onehot:
    """Target hook turning ``target["label_compat"]`` into a multi-hot vector.

    The result is stored under ``target["label_compat_onehot"]``. The
    constructor rejects ``num_output_objs != 1`` by raising
    ``DeprecationWarning``; the two-dimensional branch in ``__call__`` is
    therefore only reachable if the attribute is changed after construction.
    """

    def __init__(self, num_class=80, num_output_objs=1):
        self.num_class = num_class
        self.num_output_objs = num_output_objs
        if num_output_objs != 1:
            raise DeprecationWarning(
                "num_output_objs!=1, which is only used for comparison"
            )

    def __call__(self, target, img=None):
        """Write the encoded labels and return ``target`` or ``(target, img)``."""
        compat_labels = target["label_compat"]
        if self.num_output_objs == 1:
            # One flag per class: 1.0 when the class occurs at least once.
            encoded = torch.zeros(self.num_class)
            for entry in compat_labels:
                encoded[entry.item()] = 1.0
        else:
            # compat with baseline: one column per occurrence of a class.
            next_slot = dict.fromkeys(range(self.num_class), 0)
            encoded = torch.zeros(self.num_class, self.num_output_objs)
            for entry in compat_labels:
                cls = entry.item()
                encoded[cls][next_slot[cls]] = 1.0
                next_slot[cls] += 1
        target["label_compat_onehot"] = encoded
        if img is None:
            return target
        return target, img
class box_label_catter:
    """Target hook that appends the compat label as an extra box column."""

    def __init__(self):
        pass

    def __call__(self, target, img=None):
        """Store ``cat(boxes, label_compat[..., None])`` as ``target["box_label"]``.

        Returns ``target`` alone when ``img`` is None, otherwise
        ``(target, img)``.
        """
        label_column = target["label_compat"].unsqueeze(-1)
        target["box_label"] = torch.cat((target["boxes"], label_column), dim=1)
        if img is None:
            return target
        return target, img
class RandomSelectBoxlabels:
    """Target hook that randomly splits ``target["box_label"]`` into known/unknown rows.

    One of four sampling strategies is chosen per call according to the four
    probabilities held in the state (see ``set_state``). ``leave_one_out`` and
    ``blank_prob`` are stored but not read anywhere in this class.
    """

    def __init__(
        self,
        num_classes,
        leave_one_out=False,
        blank_prob=0.8,
        prob_first_item=0.0,
        prob_random_item=0.0,
        prob_last_item=0.8,
        prob_stop_sign=0.2,
    ) -> None:
        self.num_classes = num_classes
        self.leave_one_out = leave_one_out
        self.blank_prob = blank_prob
        self.set_state(
            prob_first_item, prob_random_item, prob_last_item, prob_stop_sign
        )

    def get_state(self):
        """Return the four strategy probabilities in dispatch order."""
        return [
            self.prob_first_item,
            self.prob_random_item,
            self.prob_last_item,
            self.prob_stop_sign,
        ]

    def set_state(
        self, prob_first_item, prob_random_item, prob_last_item, prob_stop_sign
    ):
        """Set the four strategy probabilities; they must sum to 1.

        Raises:
            AssertionError: if the sum differs from 1 by 1e-6 or more.
        """
        sum_prob = prob_first_item + prob_random_item + prob_last_item + prob_stop_sign
        # abs() is required: without it any sum *below* 1 passes unnoticed.
        assert abs(sum_prob - 1) < 1e-6, (
            f"Sum up all prob = {sum_prob}. prob_first_item:{prob_first_item}, "
            f"prob_random_item:{prob_random_item}, prob_last_item:{prob_last_item}, "
            f"prob_stop_sign:{prob_stop_sign}"
        )
        self.prob_first_item = prob_first_item
        self.prob_random_item = prob_random_item
        self.prob_last_item = prob_last_item
        self.prob_stop_sign = prob_stop_sign

    def sample_for_pred_first_item(self, box_label: torch.FloatTensor):
        """Nothing is known; every row is unknown."""
        return torch.zeros(0, 5), box_label

    def sample_for_pred_random_item(self, box_label: torch.FloatTensor):
        """Shuffle the rows and mark a random-length prefix as known.

        The prefix length is ``int(random.random() * K)``, i.e. in
        ``[0, K - 1]``, so a non-empty input always leaves an unknown row.
        """
        n_select = int(random.random() * box_label.shape[0])
        shuffled = box_label[torch.randperm(box_label.shape[0])]
        return shuffled[:n_select], shuffled[n_select:]

    def sample_for_pred_last_item(self, box_label: torch.FloatTensor):
        """Keep one random row per label unknown; all other rows are known.

        Rows are visited in shuffled order; the first row seen for each label
        (column 4) goes to the unknown set.
        """
        shuffled = box_label[torch.randperm(box_label.shape[0])]
        seen_labels = set()
        known_rows = []
        unknown_rows = []
        for row in shuffled:
            label = row[4].item()
            if label in seen_labels:
                known_rows.append(row)
            else:
                # first row with this label
                unknown_rows.append(row)
                seen_labels.add(label)
        box_label_known = torch.stack(known_rows) if known_rows else torch.zeros(0, 5)
        box_label_unknown = (
            torch.stack(unknown_rows) if unknown_rows else torch.zeros(0, 5)
        )
        return box_label_known, box_label_unknown

    def sample_for_pred_stop_sign(self, box_label: torch.FloatTensor):
        """Everything is known; the unknown set is empty."""
        return box_label, torch.zeros(0, 5)

    def __call__(self, target, img=None):
        """Split ``target["box_label"]`` (K, 5) and record the result.

        Writes ``label_onehot_known``, ``label_onehot_unknown``,
        ``box_label_known`` and ``box_label_unknown`` into ``target`` and
        always returns ``(target, img)``.
        """
        box_label = target["box_label"]  # K, 5
        dice_number = random.random()

        # Cumulative thresholds; anything past the last one is "stop sign".
        first_cut = self.prob_first_item
        random_cut = first_cut + self.prob_random_item
        last_cut = random_cut + self.prob_last_item
        if dice_number < first_cut:
            sampler = self.sample_for_pred_first_item
        elif dice_number < random_cut:
            sampler = self.sample_for_pred_random_item
        elif dice_number < last_cut:
            sampler = self.sample_for_pred_last_item
        else:
            sampler = self.sample_for_pred_stop_sign
        box_label_known, box_label_unknown = sampler(box_label)

        target["label_onehot_known"] = label2onehot(
            box_label_known[:, -1], self.num_classes
        )
        target["label_onehot_unknown"] = label2onehot(
            box_label_unknown[:, -1], self.num_classes
        )
        target["box_label_known"] = box_label_known
        target["box_label_unknown"] = box_label_unknown
        return target, img
class RandomDrop:
    """Target hook that randomly discards rows of ``target["box_label_known"]``."""

    def __init__(self, p=0.2) -> None:
        self.p = p

    def __call__(self, target, img=None):
        """Keep each known row whose uniform draw in [0, 1) exceeds ``p``.

        Always returns ``(target, img)``.
        """
        candidates = target["box_label_known"]
        draws = torch.rand(candidates.size(0))
        keep_mask = draws > self.p
        target["box_label_known"] = candidates[keep_mask]
        return target, img
class BboxPertuber:
    """Target hook that jitters known boxes using a pre-generated noise table.

    ``samples`` holds ``generate_samples`` rows of five uniform offsets in
    ``[-max_ratio, max_ratio)``. ``idx`` is a cursor into that table; it now
    advances on every call and wraps around, so consecutive calls consume
    different rows and a request larger than the table is still served.
    (Previously the cursor was never advanced, so every call reused the same
    leading rows.)
    """

    def __init__(self, max_ratio=0.02, generate_samples=1000) -> None:
        self.max_ratio = max_ratio
        self.generate_samples = generate_samples
        self.samples = self.generate_pertube_samples()
        self.idx = 0

    def generate_pertube_samples(self):
        """Return a ``(generate_samples, 5)`` tensor of offsets in +/- ``max_ratio``."""
        return (torch.rand(self.generate_samples, 5) - 0.5) * 2 * self.max_ratio

    def __call__(self, target, img):
        """Write ``target["box_label_known_pertube"]`` (K, 6) and return ``(target, img)``.

        Columns 0-3 are the jittered boxes, column 4 the IoU between original
        and jittered box scaled by ``1 + noise``, column 5 the label copied
        from the last column of ``target["box_label_known"]``.
        """
        known_box = target["box_label_known"]  # Tensor(K,5), K known bbox
        K = known_box.shape[0]
        known_box_pertube = torch.zeros(K, 6)  # 4:bbox, 1:prob, 1:label
        if K > 0:
            # Take K consecutive rows starting at the cursor, wrapping at the
            # end of the table, then move the cursor past them.
            rows = (self.idx + torch.arange(K)) % self.generate_samples
            delta = self.samples[rows]
            self.idx = (self.idx + K) % self.generate_samples

            known_box_pertube[:, :4] = known_box[:, :4] + delta[:, :4]
            # box_iou(...)[0] is treated as the pairwise IoU matrix; its
            # diagonal pairs each box with its own jittered version.
            pairwise_iou = box_iou(
                box_cxcywh_to_xyxy(known_box[:, :4]),
                box_cxcywh_to_xyxy(known_box_pertube[:, :4]),
            )[0]
            iou = torch.diag(pairwise_iou) * (1 + delta[:, -1])
            known_box_pertube[:, 4].copy_(iou)
            known_box_pertube[:, -1].copy_(known_box[:, -1])
        target["box_label_known_pertube"] = known_box_pertube
        return target, img
class RandomCutout:
    """Target hook that appends shrunken, shifted copies of the unknown boxes.

    ``factor`` is stored but not read by ``__call__``.
    """

    def __init__(self, factor=0.5) -> None:
        self.factor = factor

    def __call__(self, target, img=None):
        """Extend ``target["box_label_known_pertube"]`` and return ``(target, img)``.

        For every row of ``target["box_label_unknown"]`` (Ku, 5) a 6-column
        row is built: the first five columns are copied, column 5 is drawn
        uniformly from [0.5, 1), the centre is shifted by up to a quarter of
        the box size and the size is halved.
        """
        hidden = target["box_label_unknown"]  # Ku, 5
        visible = target["box_label_known_pertube"]  # Kk, 6
        n_hidden = hidden.size(0)

        extra = torch.zeros(n_hidden, 6)  # Ku, 6
        extra[:, :5] = hidden
        # NOTE(review): this leaves the label in column 4 and the random score
        # in column 5, whereas BboxPertuber writes score in 4 and label in 5 -
        # confirm which layout downstream code expects.
        extra[:, 5].uniform_(0.5, 1)
        offsets = torch.rand(n_hidden, 2) - 0.5
        extra[:, :2] += extra[:, 2:4] * offsets / 2
        extra[:, 2:4] /= 2

        target["box_label_known_pertube"] = torch.cat((visible, extra))
        return target, img
class RandomSelectBoxes:
    """Target hook that splits each class's boxes into known/unknown subsets.

    Construction emits a warning: the per-box Python loops are slow and the
    hook is slated for deprecation.
    """

    def __init__(self, num_class=80) -> None:
        import warnings

        # The original `Warning(...)` expression only built an exception
        # object and discarded it; warnings.warn actually reports it.
        warnings.warn(
            "This is such a slow function and will be deprecated soon!!!",
            stacklevel=2,
        )
        self.num_class = num_class

    def __call__(self, target, img=None):
        """Write per-class ``known_box`` / ``unknown_box`` lists into ``target``.

        Both entries are lists of length ``num_class``; element ``c`` is a
        tensor holding the boxes of ``target["boxes"]`` whose
        ``target["label_compat"]`` equals ``c``. Always returns
        ``(target, img)``.
        """
        boxes = target["boxes"]
        labels = target["label_compat"]

        # Group the boxes by class index.
        grouped = [[] for _ in range(self.num_class)]
        for box, label in zip(boxes, labels):
            grouped[label.item()].append(box)
        grouped_tensors = [
            torch.stack(group) if group else torch.zeros(0, 4) for group in grouped
        ]

        # Random selection per class.
        box_known = []
        box_unknown = []
        for class_boxes in grouped_tensors:
            ncnt = class_boxes.shape[0]
            # int(random.random() * ncnt) lies in [0, ncnt - 1], so a
            # non-empty class always keeps at least one unknown box.
            nselect = int(random.random() * ncnt)
            shuffled = class_boxes[torch.randperm(ncnt)]
            box_known.append(shuffled[:nselect])
            box_unknown.append(shuffled[nselect:])

        target["known_box"] = box_known
        target["unknown_box"] = box_unknown
        return target, img
def label2onehot(label, num_classes):
    """Return a length-``num_classes`` float vector flagging the given labels.

    label: Tensor(K); each entry is truncated to ``int`` and used as an index
    whose position is set to 1.0.
    """
    onehot = torch.zeros(num_classes)
    for value in label.tolist():
        onehot[int(value)] = 1.0
    return onehot
class MaskCrop:
    """Target hook that zeroes the image region under every known box."""

    def __init__(self) -> None:
        pass

    def __call__(self, target, img):
        """Blank out ``target["known_box"]`` regions of ``img`` in place.

        ``known_box`` is iterated as a collection of box tensors; each is
        converted with ``box_cxcywh_to_xyxy`` and scaled by the image width
        and height taken from ``img.shape[1:]``. Returns ``(target, img)``.
        """
        per_class_boxes = target["known_box"]
        height, width = img.shape[1:]  # h,w
        to_pixels = torch.Tensor([width, height, width, height])
        for class_boxes in per_class_boxes:
            if class_boxes.shape[0] == 0:
                continue
            corner_rows = (box_cxcywh_to_xyxy(class_boxes) * to_pixels).tolist()
            for corners in corner_rows:
                left, top, right, bottom = (int(value) for value in corners)
                img[:, top:bottom, left:right] = 0
        return target, img
# Name -> hook class lookup used by CocoDetection.change_hack_attr / get_hack.
# RandomDrop and RandomCutout are included because get_aux_target_hacks_list
# installs them too; without an entry, looking them up by name raised KeyError.
dataset_hook_register = {
    "label2compat": label2compat,
    "label_compat2onehot": label_compat2onehot,
    "box_label_catter": box_label_catter,
    "RandomSelectBoxlabels": RandomSelectBoxlabels,
    "RandomSelectBoxes": RandomSelectBoxes,
    "MaskCrop": MaskCrop,
    "BboxPertuber": BboxPertuber,
    "RandomDrop": RandomDrop,
    "RandomCutout": RandomCutout,
}
class CocoDetection(torchvision.datasets.CocoDetection):
    """COCO-format detection dataset whose last three boxes are exemplars.

    Each sample is converted by ``ConvertCocoPolysToMask``, split into
    regular boxes and exemplar boxes, passed through the optional transform
    pipeline and finally through the optional auxiliary target hooks.
    """

    def __init__(
        self, img_folder, ann_file, transforms, return_masks, aux_target_hacks=None
    ):
        super().__init__(img_folder, ann_file)
        self._transforms = transforms
        self.prepare = ConvertCocoPolysToMask(return_masks)
        self.aux_target_hacks = aux_target_hacks

    def change_hack_attr(self, hackclassname, attrkv_dict):
        """Set attributes on every installed hook of the named class.

        ``hackclassname`` is looked up in ``dataset_hook_register`` (KeyError
        if absent). Does nothing when no hooks are installed.
        """
        target_class = dataset_hook_register[hackclassname]
        # aux_target_hacks defaults to None; treat that as "no hooks".
        for item in self.aux_target_hacks or ():
            if isinstance(item, target_class):
                for k, v in attrkv_dict.items():
                    setattr(item, k, v)

    def get_hack(self, hackclassname):
        """Return the first installed hook of the named class, or None."""
        target_class = dataset_hook_register[hackclassname]
        for item in self.aux_target_hacks or ():
            if isinstance(item, target_class):
                return item
        return None

    def _load_image(self, id: int) -> Image.Image:
        """Open the image with the given id below ``self.root`` as RGB."""
        path = self.coco.loadImgs(id)[0]["file_name"]
        abs_path = os.path.join(self.root, path)
        return Image.open(abs_path).convert("RGB")

    def __getitem__(self, idx):
        """Return ``(img, target)`` for sample ``idx``.

        Output:
            - target: dict of multiple items
                - boxes: Tensor[num_box, 4]; x0,y0,x1,y1 in pixels after
                  ``self.prepare``. The transforms presumably turn them into
                  normalized cx,cy,w,h - TODO confirm against datasets.transforms.
                - exemplars: the last three boxes (empty for image 6003).
                - labels / labels_uncropped: labels without the exemplar rows.

        If loading fails, the error index is printed and the next sample
        (wrapping to 0 after the last one) is loaded instead.

        Raises:
            AssertionError: unless exactly three annotations have ``area != 4``.
        """
        try:
            img, target = super().__getitem__(idx)
        except Exception:
            # Narrower than a bare except, so KeyboardInterrupt/SystemExit
            # still propagate. Fall back to the neighbouring sample.
            print("Error idx: {}".format(idx))
            idx = (idx + 1) % len(self)
            img, target = super().__getitem__(idx)
        image_id = self.ids[idx]
        target = {"image_id": image_id, "annotations": target}

        # Annotations with area != 4 are counted as visual exemplars
        # (presumably every other annotation has area exactly 4 - TODO confirm).
        exemp_count = sum(
            1 for instance in target["annotations"] if instance["area"] != 4
        )
        # Exactly 3 visual exemplars are expected per image.
        assert exemp_count == 3

        img, target = self.prepare(img, target)
        # The exemplars are taken to be the last three boxes.
        target["exemplars"] = target["boxes"][-3:]
        # Remove inaccurate exemplars.
        if image_id == 6003:
            target["exemplars"] = torch.tensor([])
        target["boxes"] = target["boxes"][:-3]
        target["labels"] = target["labels"][:-3]
        target["labels_uncropped"] = torch.clone(target["labels"])

        if self._transforms is not None:
            img, target = self._transforms(img, target)

        # convert to needed format
        if self.aux_target_hacks is not None:
            for hack_runner in self.aux_target_hacks:
                target, img = hack_runner(target, img=img)
        return img, target
def convert_coco_poly_to_mask(segmentations, height, width):
    """Rasterise per-object polygon lists into a stacked mask tensor.

    Each element of ``segmentations`` is encoded with
    ``coco_mask.frPyObjects`` and decoded; the decoded channels are merged
    with ``any`` along the last axis. Returns the masks stacked along dim 0,
    or an empty ``(0, height, width)`` uint8 tensor when there are none.
    """
    per_object = []
    for polygons in segmentations:
        encoded = coco_mask.frPyObjects(polygons, height, width)
        bitmap = coco_mask.decode(encoded)
        if len(bitmap.shape) < 3:
            # Single-channel result: add the channel axis that any() reduces.
            bitmap = bitmap[..., None]
        per_object.append(torch.as_tensor(bitmap, dtype=torch.uint8).any(dim=2))
    if not per_object:
        return torch.zeros((0, height, width), dtype=torch.uint8)
    return torch.stack(per_object, dim=0)
class ConvertCocoPolysToMask(object):
    """Turn a raw ``{"image_id", "annotations"}`` record into tensor targets."""

    def __init__(self, return_masks=False):
        self.return_masks = return_masks

    def __call__(self, image, target):
        """Return ``(image, new_target)`` built from the non-crowd annotations.

        Boxes are converted from x,y,w,h to x0,y0,x1,y1, clamped to the image
        size taken from ``image.size`` and dropped when they have no positive
        width or height. ``labels``, ``area``, ``iscrowd`` (and ``masks`` /
        ``keypoints`` when present) are filtered the same way.
        """
        width, height = image.size
        image_id = torch.tensor([target["image_id"]])

        # Crowd annotations are skipped; a missing flag counts as non-crowd.
        objs = [obj for obj in target["annotations"] if obj.get("iscrowd", 0) == 0]

        # guard against no boxes via resizing
        boxes = torch.as_tensor(
            [obj["bbox"] for obj in objs], dtype=torch.float32
        ).reshape(-1, 4)
        boxes[:, 2:] += boxes[:, :2]  # x,y,w,h -> x0,y0,x1,y1
        boxes[:, 0::2].clamp_(min=0, max=width)
        boxes[:, 1::2].clamp_(min=0, max=height)

        classes = torch.tensor([obj["category_id"] for obj in objs], dtype=torch.int64)

        if self.return_masks:
            masks = convert_coco_poly_to_mask(
                [obj["segmentation"] for obj in objs], height, width
            )

        keypoints = None
        if objs and "keypoints" in objs[0]:
            keypoints = torch.as_tensor(
                [obj["keypoints"] for obj in objs], dtype=torch.float32
            )
            n_objects = keypoints.shape[0]
            if n_objects:
                keypoints = keypoints.view(n_objects, -1, 3)

        # Only boxes with strictly positive extent survive.
        valid = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])

        converted = {"boxes": boxes[valid], "labels": classes[valid]}
        if self.return_masks:
            converted["masks"] = masks[valid]
        converted["image_id"] = image_id
        if keypoints is not None:
            converted["keypoints"] = keypoints[valid]

        # for conversion to coco api
        area = torch.tensor([obj["area"] for obj in objs])
        iscrowd = torch.tensor([obj.get("iscrowd", 0) for obj in objs])
        converted["area"] = area[valid]
        converted["iscrowd"] = iscrowd[valid]

        converted["orig_size"] = torch.as_tensor([int(height), int(width)])
        converted["size"] = torch.as_tensor([int(height), int(width)])
        return image, converted
def make_coco_transforms(image_set, fix_size=False, strong_aug=False, args=None):
    """Build the transform pipeline for the given split.

    Args:
        image_set: "train", or one of "val", "eval_debug", "train_reg", "test".
        fix_size: for training, resize to one fixed size instead of multi-scale.
        strong_aug: for training, add the extra ``datasets.sltransform`` stage.
        args: optional config object; ``data_aug_scales``,
            ``data_aug_max_size``, ``data_aug_scales2_resize``,
            ``data_aug_scales2_crop`` and ``data_aug_scale_overlap`` override
            or rescale the defaults when present.

    Raises:
        ValueError: for an unrecognised ``image_set``.
    """
    normalize = T.Compose(
        [T.ToTensor(), T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]
    )

    # Augmentation parameters: built-in defaults, overridable from the config.
    scales = getattr(
        args, "data_aug_scales", [480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800]
    )
    max_size = getattr(args, "data_aug_max_size", 1333)
    scales2_resize = getattr(args, "data_aug_scales2_resize", [400, 500, 600])
    scales2_crop = getattr(args, "data_aug_scales2_crop", [384, 600])

    # Optional global rescaling of every size parameter.
    overlap = getattr(args, "data_aug_scale_overlap", None)
    if overlap is not None and overlap > 0:
        overlap = float(overlap)
        scales = [int(size * overlap) for size in scales]
        max_size = int(max_size * overlap)
        scales2_resize = [int(size * overlap) for size in scales2_resize]
        scales2_crop = [int(size * overlap) for size in scales2_crop]

    def multi_scale_select():
        # Either a plain random resize, or resize -> random crop -> resize.
        return T.RandomSelect(
            T.RandomResize(scales, max_size=max_size),
            T.Compose(
                [
                    T.RandomResize(scales2_resize),
                    T.RandomSizeCrop(*scales2_crop),
                    T.RandomResize(scales, max_size=max_size),
                ]
            ),
        )

    if image_set == "train":
        if fix_size:
            fixed_resize = T.RandomResize([(max_size, max(scales))])
            return T.Compose([T.RandomHorizontalFlip(), fixed_resize, normalize])

        if strong_aug:
            import datasets.sltransform as SLT

            return T.Compose(
                [
                    T.RandomHorizontalFlip(),
                    multi_scale_select(),
                    SLT.RandomSelectMulti(
                        [
                            SLT.RandomCrop(),
                            SLT.LightingNoise(),
                            SLT.AdjustBrightness(2),
                            SLT.AdjustContrast(2),
                        ]
                    ),
                    normalize,
                ]
            )

        return T.Compose([T.RandomHorizontalFlip(), multi_scale_select(), normalize])

    if image_set in ["val", "eval_debug", "train_reg", "test"]:
        if os.environ.get("GFLOPS_DEBUG_SHILONG", False) == "INFO":
            print("Under debug mode for flops calculation only!!!!!!!!!!!!!!!!")
            return T.Compose([T.ResizeDebug((1280, 800)), normalize])

        print("max(scales): " + str(max(scales)))
        return T.Compose(
            [T.RandomResize([max(scales)], max_size=max_size), normalize]
        )

    raise ValueError(f"unknown {image_set}")
def get_aux_target_hacks_list(image_set, args):
    """Return the list of target hooks for ``args.modelname``, or None.

    The list depends on the model family; for the "q2bs" family ``MaskCrop``
    is appended only for the training split when ``args.masked_data`` is set,
    and for "q2omask"/"q2osa" ``args.coco_aug`` adds ``RandomDrop`` and
    ``RandomCutout``. Unknown model names yield None.
    """

    def box_label_hooks(with_coco_aug=False):
        # Shared box+label pipeline; the two augmentation hooks sit around
        # the perturbation step when requested.
        hooks = [
            label2compat(),
            label_compat2onehot(),
            box_label_catter(),
            RandomSelectBoxlabels(
                num_classes=args.num_classes,
                prob_first_item=args.prob_first_item,
                prob_random_item=args.prob_random_item,
                prob_last_item=args.prob_last_item,
                prob_stop_sign=args.prob_stop_sign,
            ),
        ]
        if with_coco_aug:
            hooks.append(RandomDrop(p=0.2))
        hooks.append(BboxPertuber(max_ratio=0.02, generate_samples=1000))
        if with_coco_aug:
            hooks.append(RandomCutout(factor=0.5))
        return hooks

    if args.modelname in ["q2bs_mask", "q2bs"]:
        hooks = [
            label2compat(),
            label_compat2onehot(),
            RandomSelectBoxes(num_class=args.num_classes),
        ]
        if args.masked_data and image_set == "train":
            hooks.append(MaskCrop())
        return hooks

    if args.modelname in [
        "q2bm_v2",
        "q2bs_ce",
        "q2op",
        "q2ofocal",
        "q2opclip",
        "q2ocqonly",
    ]:
        return box_label_hooks()

    if args.modelname in ["q2omask", "q2osa"]:
        if args.coco_aug:
            return box_label_hooks(with_coco_aug=True)
        return box_label_hooks()

    return None
def build(image_set, args, datasetinfo):
    """Create the ``CocoDetection`` dataset for one split.

    Args:
        image_set: split name forwarded to ``make_coco_transforms``.
        args: config object; reads ``fix_size``, ``masks`` and, if present,
            ``strong_aug`` (defaults to False when the attribute is missing).
        datasetinfo: mapping with the image folder under "root" and the
            annotation file under "anno".

    Returns:
        The dataset, built without auxiliary target hooks.
    """
    img_folder = datasetinfo["root"]
    ann_file = datasetinfo["anno"]

    # copy to local path
    if os.environ.get("DATA_COPY_SHILONG") == "INFO":
        preparing_dataset(
            dict(img_folder=img_folder, ann_file=ann_file), image_set, args
        )

    # getattr replaces a bare try/except that would also have swallowed
    # unrelated errors such as KeyboardInterrupt.
    strong_aug = getattr(args, "strong_aug", False)
    print(img_folder, ann_file)
    dataset = CocoDetection(
        img_folder,
        ann_file,
        transforms=make_coco_transforms(
            image_set, fix_size=args.fix_size, strong_aug=strong_aug, args=args
        ),
        return_masks=args.masks,
        aux_target_hacks=None,
    )
    return dataset
if __name__ == "__main__":
    # Objects365 example: build the dataset without transforms or masks and
    # print how many samples it contains.
    o365_image_dir = "/path/Objects365/train/"
    o365_ann_file = "/path/Objects365/slannos/anno_preprocess_train_v2.json"
    dataset_o365 = CocoDetection(
        o365_image_dir,
        o365_ann_file,
        transforms=None,
        return_masks=False,
    )
    print("len(dataset_o365):", len(dataset_o365))