Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved.
import torch | |
from mmdet.core import anchor_inside_flags | |
from ..builder import BBOX_ASSIGNERS | |
from .assign_result import AssignResult | |
from .base_assigner import BaseAssigner | |
def calc_region(bbox, ratio, stride, featmap_size=None):
    """Shrink a box on the feature map by ``ratio`` from every side.

    The box is first divided by ``stride`` to express it in feature-map
    units. Each edge is then moved inward by ``ratio`` times the box extent
    along that axis, and the resulting coordinates are rounded.

    Args:
        bbox (Tensor): Box whose first four entries are ``x1, y1, x2, y2``.
        ratio (float): Fraction of the box extent to cut from each side.
        stride: Divisor used to project the box onto the feature map.
        featmap_size (optional): ``(h, w)`` of the feature map. When given,
            x coordinates are clamped to ``[0, w]`` and y coordinates to
            ``[0, h]``.

    Returns:
        tuple: ``(x1, y1, x2, y2)`` of the shrunk region.
    """
    # Express the box in feature-map units.
    scaled = bbox / stride
    left, top, right, bottom = scaled[0], scaled[1], scaled[2], scaled[3]

    keep = 1 - ratio
    corners = [
        torch.round(keep * left + ratio * right),
        torch.round(keep * top + ratio * bottom),
        torch.round(ratio * left + keep * right),
        torch.round(ratio * top + keep * bottom),
    ]

    if featmap_size is not None:
        # x values are bounded by the width, y values by the height.
        upper_bounds = (featmap_size[1], featmap_size[0],
                        featmap_size[1], featmap_size[0])
        corners = [
            value.clamp(min=0, max=bound)
            for value, bound in zip(corners, upper_bounds)
        ]
    return tuple(corners)
def anchor_ctr_inside_region_flags(anchors, stride, region):
    """Flag the anchors whose centers fall inside ``region``.

    Anchors are divided by ``stride`` before their centers are computed, so
    ``region`` is compared in the same (feature-map) units. Both region
    borders are inclusive.

    Args:
        anchors (Tensor): Anchors indexed as ``[:, 0..3]`` = ``x1, y1, x2, y2``.
        stride: Divisor applied to the anchors.
        region (tuple): ``(x1, y1, x2, y2)`` of the region.

    Returns:
        Tensor: Boolean flag per anchor, True when its center is inside.
    """
    left, top, right, bottom = region
    scaled = anchors / stride
    ctr_x = (scaled[:, 0] + scaled[:, 2]) * 0.5
    ctr_y = (scaled[:, 1] + scaled[:, 3]) * 0.5
    within_x = (ctr_x >= left) & (ctr_x <= right)
    within_y = (ctr_y >= top) & (ctr_y <= bottom)
    return within_x & within_y
# NOTE(review): BBOX_ASSIGNERS is imported above but never applied in the
# visible source -- confirm whether a `@BBOX_ASSIGNERS.register_module()`
# decorator on this class was lost.
class RegionAssigner(BaseAssigner):
    """Assign a corresponding gt bbox or background to each bbox.

    Each proposals will be assigned with `-1`, `0`, or a positive integer
    indicating the ground truth index.

    - -1: don't care
    - 0: negative sample, no assigned gt
    - positive integer: positive sample, index (1-based) of assigned gt

    Args:
        center_ratio: ratio of the region in the center of the bbox to
            define positive sample.
        ignore_ratio: ratio of the region to define ignore samples.
    """

    def __init__(self, center_ratio=0.2, ignore_ratio=0.5):
        self.center_ratio = center_ratio
        self.ignore_ratio = ignore_ratio

    def assign(self,
               mlvl_anchors,
               mlvl_valid_flags,
               gt_bboxes,
               img_meta,
               featmap_sizes,
               anchor_scale,
               anchor_strides,
               gt_bboxes_ignore=None,
               gt_labels=None,
               allowed_border=0):
        """Assign gt to anchors.

        This method assign a gt bbox to every bbox (proposal/anchor), each bbox
        will be assigned with -1, 0, or a positive number. -1 means don't care,
        0 means negative sample, positive number is the index (1-based) of
        assigned gt.

        Each gt is mapped to a single target level computed as
        ``floor(log2(sqrt(w * h)) - log2(anchor_scale * anchor_strides[0])
        + 0.5)``, clamped to ``[0, num_lvls - 1]``.

        The assignment is done in following steps, and the order matters.

        1. Assign every anchor to 0 (negative)
        2. (For each gt_bboxes) Compute ignore flags based on ignore_region
           then assign -1 to anchors w.r.t. ignore flags
        3. (For each gt_bboxes) Compute pos flags based on center_region then
           assign gt_bboxes to anchors w.r.t. pos flags
        4. (For each gt_bboxes) Compute ignore flags based on adjacent anchor
           level then assign -1 to anchors w.r.t. ignore flags
        5. Assign anchor outside of image to -1

        Args:
            mlvl_anchors (list[Tensor]): Multi level anchors.
            mlvl_valid_flags (list[Tensor]): Multi level valid flags.
            gt_bboxes (Tensor): Groundtruth boxes, shape (k, 4).
            img_meta (dict): Meta info of image; ``img_meta['img_shape']``
                is read.
            featmap_sizes (list): Feature map size ``(h, w)`` of each level;
                ``h * w`` must equal the number of anchors of that level.
            anchor_scale (int): Scale of the anchor.
            anchor_strides (list[int]): Stride of the anchor.
            gt_bboxes_ignore (Tensor, optional): Ground truth bboxes that are
                labelled as `ignored`, e.g., crowd boxes in COCO. Not
                supported; must be None.
            gt_labels (Tensor, optional): Label of gt_bboxes, shape (k, ).
            allowed_border (int, optional): The border to allow the valid
                anchor. Defaults to 0.

        Returns:
            :obj:`AssignResult`: The assign result.

        Raises:
            NotImplementedError: If ``gt_bboxes_ignore`` is not None.
        """
        if gt_bboxes_ignore is not None:
            raise NotImplementedError

        num_gts = gt_bboxes.shape[0]
        num_bboxes = sum(x.shape[0] for x in mlvl_anchors)

        if num_gts == 0 or num_bboxes == 0:
            # No ground truth or boxes, return empty assignment
            max_overlaps = gt_bboxes.new_zeros((num_bboxes, ))
            assigned_gt_inds = gt_bboxes.new_zeros((num_bboxes, ),
                                                   dtype=torch.long)
            if gt_labels is None:
                assigned_labels = None
            else:
                assigned_labels = gt_bboxes.new_full((num_bboxes, ),
                                                     -1,
                                                     dtype=torch.long)
            return AssignResult(
                num_gts,
                assigned_gt_inds,
                max_overlaps,
                labels=assigned_labels)

        num_lvls = len(mlvl_anchors)
        # Fraction cut from each side of a gt box so that the remaining
        # region covers `center_ratio` / `ignore_ratio` of each box extent.
        r1 = (1 - self.center_ratio) / 2
        r2 = (1 - self.ignore_ratio) / 2

        # Pick one target level per gt from the box scale sqrt(w * h).
        scale = torch.sqrt((gt_bboxes[:, 2] - gt_bboxes[:, 0]) *
                           (gt_bboxes[:, 3] - gt_bboxes[:, 1]))
        min_anchor_size = scale.new_full(
            (1, ), float(anchor_scale * anchor_strides[0]))
        target_lvls = torch.floor(
            torch.log2(scale) - torch.log2(min_anchor_size) + 0.5)
        target_lvls = target_lvls.clamp(min=0, max=num_lvls - 1).long()

        # 1. assign 0 (negative) by default
        mlvl_assigned_gt_inds = []
        mlvl_ignore_flags = []
        for lvl in range(num_lvls):
            h, w = featmap_sizes[lvl]
            assert h * w == mlvl_anchors[lvl].shape[0]
            assigned_gt_inds = gt_bboxes.new_full((h * w, ),
                                                  0,
                                                  dtype=torch.long)
            # Boolean mask: an integer (long) tensor here would later be
            # treated as a list of indices (0 and 1) instead of a mask.
            ignore_flags = torch.zeros_like(
                assigned_gt_inds, dtype=torch.bool)
            mlvl_assigned_gt_inds.append(assigned_gt_inds)
            mlvl_ignore_flags.append(ignore_flags)

        for gt_id in range(num_gts):
            lvl = target_lvls[gt_id].item()
            featmap_size = featmap_sizes[lvl]
            stride = anchor_strides[lvl]
            anchors = mlvl_anchors[lvl]
            gt_bbox = gt_bboxes[gt_id, :4]

            # Compute regions
            ignore_region = calc_region(gt_bbox, r2, stride, featmap_size)
            ctr_region = calc_region(gt_bbox, r1, stride, featmap_size)

            # 2. Assign -1 to ignore flags
            ignore_flags = anchor_ctr_inside_region_flags(
                anchors, stride, ignore_region)
            mlvl_assigned_gt_inds[lvl][ignore_flags] = -1

            # 3. Assign gt_bboxes to pos flags
            pos_flags = anchor_ctr_inside_region_flags(anchors, stride,
                                                       ctr_region)
            mlvl_assigned_gt_inds[lvl][pos_flags] = gt_id + 1

            # 4. Record the ignore region on the adjacent (lower and upper)
            # levels; it is applied after all gts have been processed.
            for adj_lvl in (lvl - 1, lvl + 1):
                if not 0 <= adj_lvl < num_lvls:
                    continue
                adj_stride = anchor_strides[adj_lvl]
                adj_ignore_region = calc_region(gt_bbox, r2, adj_stride,
                                                featmap_sizes[adj_lvl])
                adj_flags = anchor_ctr_inside_region_flags(
                    mlvl_anchors[adj_lvl], adj_stride, adj_ignore_region)
                mlvl_ignore_flags[adj_lvl][adj_flags] = True

        # 4. (cont.) Assign -1 to ignore adjacent lvl
        for lvl in range(num_lvls):
            ignore_flags = mlvl_ignore_flags[lvl]
            mlvl_assigned_gt_inds[lvl][ignore_flags] = -1

        # 5. Assign -1 to anchor outside of image
        flat_assigned_gt_inds = torch.cat(mlvl_assigned_gt_inds)
        flat_anchors = torch.cat(mlvl_anchors)
        flat_valid_flags = torch.cat(mlvl_valid_flags)
        assert (flat_assigned_gt_inds.shape[0] == flat_anchors.shape[0] ==
                flat_valid_flags.shape[0])
        inside_flags = anchor_inside_flags(flat_anchors, flat_valid_flags,
                                           img_meta['img_shape'],
                                           allowed_border)
        outside_flags = ~inside_flags
        flat_assigned_gt_inds[outside_flags] = -1

        if gt_labels is not None:
            assigned_labels = torch.zeros_like(flat_assigned_gt_inds)
            # The mask must come from the flattened assignment; the
            # per-level tensor left over from the init loop only covers the
            # last level.
            pos_flags = flat_assigned_gt_inds > 0
            assigned_labels[pos_flags] = gt_labels[
                flat_assigned_gt_inds[pos_flags] - 1]
        else:
            assigned_labels = None

        return AssignResult(
            num_gts, flat_assigned_gt_inds, None, labels=assigned_labels)