# NOTE: removed non-source residue ("Spaces / Running on T4" UI text) that was
# accidentally pasted in from a notebook interface; it was not valid Python.
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
COCO dataset which returns image_id for evaluation.

Mostly copy-paste from https://github.com/pytorch/vision/blob/13b35ff/references/detection/coco_utils.py
"""
if __name__ == "__main__":
    # for debug only: make package-relative imports work when run as a script
    import os, sys

    sys.path.append(os.path.dirname(sys.path[0]))

import json
import os
import random
from pathlib import Path
from typing import Any, Callable, List, Optional, Tuple

import torch
import torch.utils.data
import torchvision
from PIL import Image
from pycocotools import mask as coco_mask
from torchvision.datasets.vision import VisionDataset

import datasets.transforms as T
from datasets.data_util import preparing_dataset
from util.box_ops import box_cxcywh_to_xyxy, box_iou

__all__ = ["build"]
class label2compat:
    """Remap raw COCO category ids (1..90 with gaps) to contiguous 0..79 labels."""

    # The 80 COCO category ids actually in use, in ascending order.
    _COCO_IDS = [
        1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21,
        22, 23, 24, 25, 27, 28, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42,
        43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61,
        62, 63, 64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84,
        85, 86, 87, 88, 89, 90,
    ]

    def __init__(self) -> None:
        # Keep both the string-keyed and int-keyed maps, as before (values 1..80).
        self.category_map_str = {
            str(cid): pos + 1 for pos, cid in enumerate(self._COCO_IDS)
        }
        self.category_map = {
            cid: pos + 1 for pos, cid in enumerate(self._COCO_IDS)
        }

    def __call__(self, target, img=None):
        """Add target['label_compat'] with labels remapped to 0-based contiguous ids."""
        raw_labels = target["labels"]
        remapped = torch.zeros(raw_labels.shape, dtype=raw_labels.dtype)
        for pos, lab in enumerate(raw_labels):
            remapped[pos] = self.category_map[lab.item()] - 1
        target["label_compat"] = remapped
        if img is None:
            return target
        return target, img
class label_compat2onehot:
    """Turn per-box compat labels into a multi-hot class-presence vector."""

    def __init__(self, num_class=80, num_output_objs=1):
        self.num_class = num_class
        self.num_output_objs = num_output_objs
        if num_output_objs != 1:
            # Kept from the original: multi-object output exists only for comparison runs.
            raise DeprecationWarning(
                "num_output_objs!=1, which is only used for comparison"
            )

    def __call__(self, target, img=None):
        labels = target["label_compat"]
        slot_counter = {cls: 0 for cls in range(self.num_class)}
        if self.num_output_objs == 1:
            onehot = torch.zeros(self.num_class)
            for lab in labels:
                onehot[lab.item()] = 1.0
        else:
            # compat with baseline: one column per object instance of a class
            onehot = torch.zeros(self.num_class, self.num_output_objs)
            for lab in labels:
                cls = lab.item()
                onehot[cls][slot_counter[cls]] = 1.0
                slot_counter[cls] += 1
        target["label_compat_onehot"] = onehot
        if img is None:
            return target
        return target, img
class box_label_catter:
    """Concatenate boxes (K, 4) and compat labels (K,) into target['box_label'] (K, 5)."""

    def __call__(self, target, img=None):
        combined = torch.cat(
            (target["boxes"], target["label_compat"].unsqueeze(-1)), 1
        )
        target["box_label"] = combined
        if img is None:
            return target
        return target, img
class RandomSelectBoxlabels:
    """Randomly split target['box_label'] (K, 5) into known/unknown subsets.

    Per call, one of four sampling modes is drawn according to the configured
    probabilities (which must sum to 1):
      - first_item: nothing known, all boxes unknown
      - random_item: a uniformly random-sized random subset is known
      - last_item: one random instance per class is unknown, the rest known
      - stop_sign: all boxes known, nothing unknown

    Also writes one-hot class-presence vectors for each subset.
    """

    def __init__(
        self,
        num_classes,
        leave_one_out=False,
        blank_prob=0.8,
        prob_first_item=0.0,
        prob_random_item=0.0,
        prob_last_item=0.8,
        prob_stop_sign=0.2,
    ) -> None:
        self.num_classes = num_classes
        self.leave_one_out = leave_one_out
        self.blank_prob = blank_prob
        self.set_state(
            prob_first_item, prob_random_item, prob_last_item, prob_stop_sign
        )

    def get_state(self):
        """Return [prob_first_item, prob_random_item, prob_last_item, prob_stop_sign]."""
        return [
            self.prob_first_item,
            self.prob_random_item,
            self.prob_last_item,
            self.prob_stop_sign,
        ]

    def set_state(
        self, prob_first_item, prob_random_item, prob_last_item, prob_stop_sign
    ):
        """Set the four mode probabilities; they must sum to 1 (within 1e-6)."""
        sum_prob = prob_first_item + prob_random_item + prob_last_item + prob_stop_sign
        # BUGFIX: the original checked `sum_prob - 1 < 1e-6`, which can never
        # fail when the probabilities sum to LESS than 1; compare |sum - 1|.
        assert abs(sum_prob - 1) < 1e-6, (
            f"Sum up all prob = {sum_prob}. prob_first_item:{prob_first_item}"
            + f"prob_random_item:{prob_random_item}, prob_last_item:{prob_last_item}"
            + f"prob_stop_sign:{prob_stop_sign}"
        )
        self.prob_first_item = prob_first_item
        self.prob_random_item = prob_random_item
        self.prob_last_item = prob_last_item
        self.prob_stop_sign = prob_stop_sign

    def sample_for_pred_first_item(self, box_label: torch.FloatTensor):
        """Known: empty; unknown: all boxes."""
        box_label_known = torch.Tensor(0, 5)
        box_label_unknown = box_label
        return box_label_known, box_label_unknown

    def sample_for_pred_random_item(self, box_label: torch.FloatTensor):
        """Known: a random-sized random subset; unknown: the rest."""
        n_select = int(random.random() * box_label.shape[0])
        box_label = box_label[torch.randperm(box_label.shape[0])]
        box_label_known = box_label[:n_select]
        box_label_unknown = box_label[n_select:]
        return box_label_known, box_label_unknown

    def sample_for_pred_last_item(self, box_label: torch.FloatTensor):
        """Unknown: one (random) instance per class; known: the remaining instances."""
        box_label_perm = box_label[torch.randperm(box_label.shape[0])]
        known_label_list = []
        box_label_known = []
        box_label_unknown = []
        for item in box_label_perm:
            label_i = item[4].item()
            if label_i in known_label_list:
                box_label_known.append(item)
            else:
                # first instance of this class goes to the unknown set
                box_label_unknown.append(item)
                known_label_list.append(label_i)
        box_label_known = (
            torch.stack(box_label_known)
            if len(box_label_known) > 0
            else torch.Tensor(0, 5)
        )
        box_label_unknown = (
            torch.stack(box_label_unknown)
            if len(box_label_unknown) > 0
            else torch.Tensor(0, 5)
        )
        return box_label_known, box_label_unknown

    def sample_for_pred_stop_sign(self, box_label: torch.FloatTensor):
        """Known: all boxes; unknown: empty."""
        box_label_unknown = torch.Tensor(0, 5)
        box_label_known = box_label
        return box_label_known, box_label_unknown

    def __call__(self, target, img=None):
        box_label = target["box_label"]  # K, 5
        dice_number = random.random()

        # The four modes partition [0, 1) by their cumulative probabilities.
        if dice_number < self.prob_first_item:
            box_label_known, box_label_unknown = self.sample_for_pred_first_item(
                box_label
            )
        elif dice_number < self.prob_first_item + self.prob_random_item:
            box_label_known, box_label_unknown = self.sample_for_pred_random_item(
                box_label
            )
        elif (
            dice_number
            < self.prob_first_item + self.prob_random_item + self.prob_last_item
        ):
            box_label_known, box_label_unknown = self.sample_for_pred_last_item(
                box_label
            )
        else:
            box_label_known, box_label_unknown = self.sample_for_pred_stop_sign(
                box_label
            )

        target["label_onehot_known"] = label2onehot(
            box_label_known[:, -1], self.num_classes
        )
        target["label_onehot_unknown"] = label2onehot(
            box_label_unknown[:, -1], self.num_classes
        )
        target["box_label_known"] = box_label_known
        target["box_label_unknown"] = box_label_unknown
        return target, img
class RandomDrop:
    """Independently drop each known box with probability p."""

    def __init__(self, p=0.2) -> None:
        self.p = p

    def __call__(self, target, img=None):
        boxes = target["box_label_known"]
        # One uniform draw per box; keep those whose draw exceeds p.
        keep_mask = torch.rand(boxes.size(0)) > self.p
        target["box_label_known"] = boxes[keep_mask]
        return target, img
class BboxPertuber:
    """Perturb known boxes with pre-generated random offsets and record an
    IoU-derived confidence for each perturbed box.

    Input target['box_label_known'] is (K, 5): cx, cy, w, h, label.
    Output target['box_label_known_pertube'] is (K, 6): perturbed box (4),
    jittered IoU confidence (1), original label (1).
    NOTE(review): assumes box coordinates are normalized so a +-max_ratio
    offset is a small perturbation — confirm against the calling pipeline.
    """

    def __init__(self, max_ratio=0.02, generate_samples=1000) -> None:
        self.max_ratio = max_ratio
        self.generate_samples = generate_samples
        # Pre-generate a bank of uniform offsets in [-max_ratio, max_ratio);
        # columns 0..3 shift the box, column 4 jitters the IoU confidence.
        self.samples = self.generate_pertube_samples()
        self.idx = 0  # cursor into the sample bank

    def generate_pertube_samples(self):
        """Return a (generate_samples, 5) tensor of uniform offsets."""
        return (torch.rand(self.generate_samples, 5) - 0.5) * 2 * self.max_ratio

    def __call__(self, target, img):
        known_box = target["box_label_known"]  # Tensor(K, 5), K known bbox
        K = known_box.shape[0]
        known_box_pertube = torch.zeros(K, 6)  # 4: bbox, 1: prob, 1: label
        if K > 0:
            # Wrap around when the sample bank would be exhausted.
            if self.idx + K > self.generate_samples:
                self.idx = 0
            delta = self.samples[self.idx : self.idx + K, :]
            # BUGFIX: advance the cursor. The original never incremented
            # self.idx, so every call reused the first K rows of the bank,
            # defeating the purpose of pre-generating `generate_samples` offsets.
            self.idx += K
            known_box_pertube[:, :4] = known_box[:, :4] + delta[:, :4]
            # IoU between the original and perturbed box, scaled by a random
            # factor in [1 - max_ratio, 1 + max_ratio).
            iou = (
                torch.diag(
                    box_iou(
                        box_cxcywh_to_xyxy(known_box[:, :4]),
                        box_cxcywh_to_xyxy(known_box_pertube[:, :4]),
                    )[0]
                )
            ) * (1 + delta[:, -1])
            known_box_pertube[:, 4].copy_(iou)
            known_box_pertube[:, -1].copy_(known_box[:, -1])
        target["box_label_known_pertube"] = known_box_pertube
        return target, img
class RandomCutout:
    """Append jittered, shrunken copies of the unknown boxes (with random
    confidences) to target['box_label_known_pertube']."""

    def __init__(self, factor=0.5) -> None:
        self.factor = factor

    def __call__(self, target, img=None):
        unknown = target["box_label_unknown"]      # (Ku, 5)
        known = target["box_label_known_pertube"]  # (Kk, 6)
        num_unknown = unknown.size(0)

        fabricated = torch.zeros(num_unknown, 6)
        fabricated[:, :5] = unknown
        # Random confidence in [0.5, 1) for the fabricated entries.
        fabricated[:, 5].uniform_(0.5, 1)
        # Jitter centers by up to a quarter of the box extent, then halve w/h.
        fabricated[:, :2] += fabricated[:, 2:4] * (torch.rand(num_unknown, 2) - 0.5) / 2
        fabricated[:, 2:4] /= 2

        target["box_label_known_pertube"] = torch.cat((known, fabricated))
        return target, img
class RandomSelectBoxes:
    """Per class, randomly split the boxes into known/unknown lists.

    Produces target['known_box'] and target['unknown_box'], each a list of
    (n_i, 4) tensors indexed by class id.
    """

    def __init__(self, num_class=80) -> None:
        Warning("This is such a slow function and will be deprecated soon!!!")
        self.num_class = num_class

    def __call__(self, target, img=None):
        boxes = target["boxes"]
        labels = target["label_compat"]

        # Bucket the boxes by class id.
        per_class = [[] for _ in range(self.num_class)]
        for box, lab in zip(boxes, labels):
            per_class[lab.item()].append(box)
        per_class_tensors = [
            torch.stack(bucket) if bucket else torch.Tensor(0, 4)
            for bucket in per_class
        ]

        # For each class, a uniformly random-sized random subset becomes "known".
        known, unknown = [], []
        for bucket in per_class_tensors:
            count = bucket.shape[0]
            n_known = int(
                random.random() * count
            )  # close in both sides, much faster than random.randint
            shuffled = bucket[torch.randperm(count)]
            known.append(shuffled[:n_known])
            unknown.append(shuffled[n_known:])

        target["known_box"] = known
        target["unknown_box"] = unknown
        return target, img
def label2onehot(label, num_classes):
    """Return a (num_classes,) multi-hot vector with 1.0 at each id in *label*.

    label: Tensor(K) of (possibly float-valued) class indices.
    """
    onehot = torch.zeros(num_classes)
    for value in label.tolist():
        onehot[int(value)] = 1.0
    return onehot
class MaskCrop:
    """Zero out the image pixels inside every known box (mutates img in place)."""

    def __call__(self, target, img):
        height, width = img.shape[1:]  # img is (C, H, W)
        scale = torch.Tensor([width, height, width, height])
        for class_boxes in target["known_box"]:
            if class_boxes.shape[0] == 0:
                continue
            # cxcywh (normalized) -> xyxy in pixel coordinates.
            for box in box_cxcywh_to_xyxy(class_boxes) * scale:
                x1, y1, x2, y2 = (int(v) for v in box.tolist())
                img[:, y1:y2, x1:x2] = 0
        return target, img
# Registry of the aux-target "hack" transforms above, keyed by class name.
# Used by CocoDetection.change_hack_attr / get_hack to locate instances by name.
dataset_hook_register = {
    "label2compat": label2compat,
    "label_compat2onehot": label_compat2onehot,
    "box_label_catter": box_label_catter,
    "RandomSelectBoxlabels": RandomSelectBoxlabels,
    "RandomSelectBoxes": RandomSelectBoxes,
    "MaskCrop": MaskCrop,
    "BboxPertuber": BboxPertuber,
}
class CocoDetection(torchvision.datasets.CocoDetection):
    """COCO detection dataset that additionally prepares visual exemplars.

    The last 3 boxes of each prepared target are split out as exemplars
    (target['exemplars']) and removed from the regular boxes/labels.
    """

    def __init__(
        self, img_folder, ann_file, transforms, return_masks, aux_target_hacks=None
    ):
        super(CocoDetection, self).__init__(img_folder, ann_file)
        self._transforms = transforms
        self.prepare = ConvertCocoPolysToMask(return_masks)
        # Optional list of callables applied as (target, img) -> (target, img).
        self.aux_target_hacks = aux_target_hacks

    def change_hack_attr(self, hackclassname, attrkv_dict):
        """Set attributes on every registered hack instance of the named class."""
        target_class = dataset_hook_register[hackclassname]
        for item in self.aux_target_hacks:
            if isinstance(item, target_class):
                for k, v in attrkv_dict.items():
                    setattr(item, k, v)

    def get_hack(self, hackclassname):
        """Return the first registered hack instance of the named class, if any."""
        target_class = dataset_hook_register[hackclassname]
        for item in self.aux_target_hacks:
            if isinstance(item, target_class):
                return item

    def _load_image(self, id: int) -> Image.Image:
        """Load the image for a COCO id from self.root as RGB."""
        path = self.coco.loadImgs(id)[0]["file_name"]
        abs_path = os.path.join(self.root, path)
        return Image.open(abs_path).convert("RGB")

    def __getitem__(self, idx):
        """
        Output:
            - target: dict of multiple items
                - boxes: Tensor[num_box, 4]. \
                    Init type: x0,y0,x1,y1. unnormalized data.
                    Final type: cx,cy,w,h. normalized data.
        """
        try:
            img, target = super(CocoDetection, self).__getitem__(idx)
        except Exception:
            # BUGFIX: was a bare `except:`, which also swallowed
            # KeyboardInterrupt/SystemExit. Fall back to the next sample.
            print("Error idx: {}".format(idx))
            idx += 1
            img, target = super(CocoDetection, self).__getitem__(idx)
        image_id = self.ids[idx]
        target = {"image_id": image_id, "annotations": target}

        # NOTE(review): annotations with area != 4 are counted as exemplars —
        # presumably area == 4 marks the non-exemplar dummy entries; confirm
        # against the annotation-generation code.
        exemp_count = 0
        for instance in target["annotations"]:
            if instance["area"] != 4:
                exemp_count += 1
        # Only provide at most 3 visual exemplars during inference.
        assert exemp_count == 3

        img, target = self.prepare(img, target)
        # The 3 exemplar boxes are appended last in the annotation file.
        target["exemplars"] = target["boxes"][-3:]
        # Remove inaccurate exemplars for this specific image.
        if image_id == 6003:
            target["exemplars"] = torch.tensor([])
        target["boxes"] = target["boxes"][:-3]
        target["labels"] = target["labels"][:-3]
        # Keep a pre-transform copy, since cropping transforms may drop boxes.
        target["labels_uncropped"] = torch.clone(target["labels"])

        if self._transforms is not None:
            img, target = self._transforms(img, target)

        # convert to needed format
        if self.aux_target_hacks is not None:
            for hack_runner in self.aux_target_hacks:
                target, img = hack_runner(target, img=img)
        return img, target
def convert_coco_poly_to_mask(segmentations, height, width):
    """Decode COCO polygon segmentations into a (N, height, width) uint8 tensor."""
    per_object = []
    for polygons in segmentations:
        rles = coco_mask.frPyObjects(polygons, height, width)
        decoded = coco_mask.decode(rles)
        if decoded.ndim < 3:
            decoded = decoded[..., None]
        # Collapse the per-polygon channel: a pixel is set if any polygon covers it.
        per_object.append(torch.as_tensor(decoded, dtype=torch.uint8).any(dim=2))
    if not per_object:
        return torch.zeros((0, height, width), dtype=torch.uint8)
    return torch.stack(per_object, dim=0)
class ConvertCocoPolysToMask(object):
    """Convert raw COCO annotations into tensor targets (xyxy boxes, labels,
    optional masks/keypoints), dropping crowd and degenerate boxes."""

    def __init__(self, return_masks=False):
        self.return_masks = return_masks

    def __call__(self, image, target):
        w, h = image.size
        image_id = torch.tensor([target["image_id"]])

        # Drop crowd annotations.
        anno = [
            obj
            for obj in target["annotations"]
            if "iscrowd" not in obj or obj["iscrowd"] == 0
        ]

        # xywh -> xyxy, clamped to the image; reshape guards against no boxes.
        boxes = torch.as_tensor(
            [obj["bbox"] for obj in anno], dtype=torch.float32
        ).reshape(-1, 4)
        boxes[:, 2:] += boxes[:, :2]
        boxes[:, 0::2].clamp_(min=0, max=w)
        boxes[:, 1::2].clamp_(min=0, max=h)

        classes = torch.tensor(
            [obj["category_id"] for obj in anno], dtype=torch.int64
        )

        if self.return_masks:
            masks = convert_coco_poly_to_mask(
                [obj["segmentation"] for obj in anno], h, w
            )

        keypoints = None
        if anno and "keypoints" in anno[0]:
            keypoints = torch.as_tensor(
                [obj["keypoints"] for obj in anno], dtype=torch.float32
            )
            num_keypoints = keypoints.shape[0]
            if num_keypoints:
                keypoints = keypoints.view(num_keypoints, -1, 3)

        # Keep only boxes with strictly positive width and height.
        keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
        boxes = boxes[keep]
        classes = classes[keep]
        if self.return_masks:
            masks = masks[keep]
        if keypoints is not None:
            keypoints = keypoints[keep]

        target = {}
        target["boxes"] = boxes
        target["labels"] = classes
        if self.return_masks:
            target["masks"] = masks
        target["image_id"] = image_id
        if keypoints is not None:
            target["keypoints"] = keypoints

        # Extra fields for conversion to the coco api.
        area = torch.tensor([obj["area"] for obj in anno])
        iscrowd = torch.tensor([obj.get("iscrowd", 0) for obj in anno])
        target["area"] = area[keep]
        target["iscrowd"] = iscrowd[keep]

        target["orig_size"] = torch.as_tensor([int(h), int(w)])
        target["size"] = torch.as_tensor([int(h), int(w)])
        return image, target
def make_coco_transforms(image_set, fix_size=False, strong_aug=False, args=None):
    """Build the transform pipeline (datasets.transforms) for a dataset split.

    Args:
        image_set: 'train' for augmented training transforms; one of
            'val' / 'eval_debug' / 'train_reg' / 'test' for eval-style resizing.
        fix_size: if True (train only), resize to a single fixed size.
        strong_aug: if True (train only), add SLTransform photometric/crop aug.
        args: optional config namespace; may override the data-aug scale lists
            (getattr with defaults, so None is fine).

    Raises:
        ValueError: if image_set is not a recognized split name.
    """
    normalize = T.Compose(
        [T.ToTensor(), T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]
    )

    # config the params for data aug
    scales = [480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800]
    max_size = 1333
    scales2_resize = [400, 500, 600]
    scales2_crop = [384, 600]

    # update args from config files
    scales = getattr(args, "data_aug_scales", scales)
    max_size = getattr(args, "data_aug_max_size", max_size)
    scales2_resize = getattr(args, "data_aug_scales2_resize", scales2_resize)
    scales2_crop = getattr(args, "data_aug_scales2_crop", scales2_crop)

    # resize them: uniformly rescale every size by data_aug_scale_overlap
    data_aug_scale_overlap = getattr(args, "data_aug_scale_overlap", None)
    if data_aug_scale_overlap is not None and data_aug_scale_overlap > 0:
        data_aug_scale_overlap = float(data_aug_scale_overlap)
        scales = [int(i * data_aug_scale_overlap) for i in scales]
        max_size = int(max_size * data_aug_scale_overlap)
        scales2_resize = [int(i * data_aug_scale_overlap) for i in scales2_resize]
        scales2_crop = [int(i * data_aug_scale_overlap) for i in scales2_crop]

    # kept for debugging; printing is disabled below
    datadict_for_print = {
        "scales": scales,
        "max_size": max_size,
        "scales2_resize": scales2_resize,
        "scales2_crop": scales2_crop,
    }
    # print("data_aug_params:", json.dumps(datadict_for_print, indent=2))

    if image_set == "train":
        if fix_size:
            # Single fixed (max_size, max(scales)) resize — no scale jitter.
            return T.Compose(
                [
                    T.RandomHorizontalFlip(),
                    T.RandomResize([(max_size, max(scales))]),
                    # T.RandomResize([(512, 512)]),
                    normalize,
                ]
            )

        if strong_aug:
            import datasets.sltransform as SLT

            # Multi-scale aug plus SLTransform photometric/crop augmentation.
            return T.Compose(
                [
                    T.RandomHorizontalFlip(),
                    T.RandomSelect(
                        T.RandomResize(scales, max_size=max_size),
                        T.Compose(
                            [
                                T.RandomResize(scales2_resize),
                                T.RandomSizeCrop(*scales2_crop),
                                T.RandomResize(scales, max_size=max_size),
                            ]
                        ),
                    ),
                    SLT.RandomSelectMulti(
                        [
                            SLT.RandomCrop(),
                            SLT.LightingNoise(),
                            SLT.AdjustBrightness(2),
                            SLT.AdjustContrast(2),
                        ]
                    ),
                    normalize,
                ]
            )

        # Default DETR-style training aug: flip + (multi-scale resize OR
        # resize-crop-resize), then normalize.
        return T.Compose(
            [
                T.RandomHorizontalFlip(),
                T.RandomSelect(
                    T.RandomResize(scales, max_size=max_size),
                    T.Compose(
                        [
                            T.RandomResize(scales2_resize),
                            T.RandomSizeCrop(*scales2_crop),
                            T.RandomResize(scales, max_size=max_size),
                        ]
                    ),
                ),
                normalize,
            ]
        )

    if image_set in ["val", "eval_debug", "train_reg", "test"]:
        if os.environ.get("GFLOPS_DEBUG_SHILONG", False) == "INFO":
            print("Under debug mode for flops calculation only!!!!!!!!!!!!!!!!")
            return T.Compose(
                [
                    T.ResizeDebug((1280, 800)),
                    normalize,
                ]
            )

        print("max(scales): " + str(max(scales)))
        # Eval: deterministic single-scale resize to the largest training scale.
        return T.Compose(
            [
                T.RandomResize([max(scales)], max_size=max_size),
                normalize,
            ]
        )

    raise ValueError(f"unknown {image_set}")
def get_aux_target_hacks_list(image_set, args):
    """Return the list of aux-target hack callables for args.modelname, or None
    when the model name needs no hacks."""
    modelname = args.modelname

    if modelname in ["q2bs_mask", "q2bs"]:
        hacks = [
            label2compat(),
            label_compat2onehot(),
            RandomSelectBoxes(num_class=args.num_classes),
        ]
        if args.masked_data and image_set == "train":
            hacks.append(MaskCrop())
        return hacks

    def _labeled_box_core():
        # Shared pipeline core for the labeled-box model families.
        return [
            label2compat(),
            label_compat2onehot(),
            box_label_catter(),
            RandomSelectBoxlabels(
                num_classes=args.num_classes,
                prob_first_item=args.prob_first_item,
                prob_random_item=args.prob_random_item,
                prob_last_item=args.prob_last_item,
                prob_stop_sign=args.prob_stop_sign,
            ),
        ]

    if modelname in ["q2bm_v2", "q2bs_ce", "q2op", "q2ofocal", "q2opclip", "q2ocqonly"]:
        return _labeled_box_core() + [
            BboxPertuber(max_ratio=0.02, generate_samples=1000)
        ]

    if modelname in ["q2omask", "q2osa"]:
        if args.coco_aug:
            return _labeled_box_core() + [
                RandomDrop(p=0.2),
                BboxPertuber(max_ratio=0.02, generate_samples=1000),
                RandomCutout(factor=0.5),
            ]
        return _labeled_box_core() + [
            BboxPertuber(max_ratio=0.02, generate_samples=1000)
        ]

    return None
def build(image_set, args, datasetinfo):
    """Build a CocoDetection dataset for the given split.

    Args:
        image_set: split name, e.g. 'train' / 'val' (controls the transforms).
        args: config namespace; uses fix_size, masks and (optionally) strong_aug.
        datasetinfo: dict with 'root' (image folder) and 'anno' (annotation json).

    Returns:
        A CocoDetection instance wired with the split's transforms.
    """
    img_folder = datasetinfo["root"]
    ann_file = datasetinfo["anno"]

    # Optionally copy the data to a local path first (debug/cluster hook).
    if os.environ.get("DATA_COPY_SHILONG") == "INFO":
        preparing_dataset(
            dict(img_folder=img_folder, ann_file=ann_file), image_set, args
        )

    # BUGFIX: was `try: ... except: ...` with a bare except; only a missing
    # attribute should trigger the fallback, which getattr expresses directly.
    strong_aug = getattr(args, "strong_aug", False)

    print(img_folder, ann_file)
    dataset = CocoDetection(
        img_folder,
        ann_file,
        transforms=make_coco_transforms(
            image_set, fix_size=args.fix_size, strong_aug=strong_aug, args=args
        ),
        return_masks=args.masks,
        aux_target_hacks=None,
    )
    return dataset
if __name__ == "__main__":
    # Smoke test: load an Objects365 val-style dataset (paths are placeholders
    # and must be adapted to the local environment).
    dataset_o365 = CocoDetection(
        "/path/Objects365/train/",
        "/path/Objects365/slannos/anno_preprocess_train_v2.json",
        transforms=None,
        return_masks=False,
    )
    print("len(dataset_o365):", len(dataset_o365))