Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
import torch | |
from ..builder import BBOX_ASSIGNERS | |
from .assign_result import AssignResult | |
from .base_assigner import BaseAssigner | |
class PointAssigner(BaseAssigner):
    """Assign a corresponding gt bbox or background to each point.

    Each point will be assigned with `0`, or a positive integer
    indicating the ground truth index.

    - 0: negative sample, no assigned gt
    - positive integer: positive sample, index (1-based) of assigned gt

    Args:
        scale (int | float): Divisor applied to the gt width and height
            before taking ``log2`` to choose the level a gt box is matched
            on. Defaults to 4.
        pos_num (int): Maximum number of nearest points (on the matched
            level) that are considered as positive candidates for each gt
            box. Defaults to 3.
    """

    def __init__(self, scale=4, pos_num=3):
        self.scale = scale
        self.pos_num = pos_num

    def assign(self, points, gt_bboxes, gt_bboxes_ignore=None, gt_labels=None):
        """Assign gt to points.

        This method assigns a gt bbox to every point. In the returned
        gt indices, ``0`` means background and a positive number is the
        1-based index of the assigned gt. In the returned labels (only
        produced when ``gt_labels`` is given), ``-1`` means background.

        The assignment is done in following steps, the order matters.

        1. assign every point to the background (gt index ``0``)
        2. A point is assigned to some gt bbox if

            (i) the point is within the k closest points to the gt bbox
                center among the points of the level chosen for that gt
            (ii) the distance between this point and the gt is smaller
                than the distance to any previously assigned gt bbox

        Args:
            points (Tensor): points to be assigned, shape(n, 3) while last
                dimension stands for (x, y, stride).
            gt_bboxes (Tensor): Groundtruth boxes, shape (k, 4).
            gt_bboxes_ignore (Tensor, optional): Ground truth bboxes that are
                labelled as `ignored`, e.g., crowd boxes in COCO.
                NOTE: currently unused.
            gt_labels (Tensor, optional): Label of gt_bboxes, shape (k, ).

        Returns:
            :obj:`AssignResult`: The assign result.
        """
        num_points = points.shape[0]
        num_gts = gt_bboxes.shape[0]

        if num_gts == 0 or num_points == 0:
            # If no truth assign everything to the background
            assigned_gt_inds = points.new_full((num_points, ),
                                               0,
                                               dtype=torch.long)
            if gt_labels is None:
                assigned_labels = None
            else:
                assigned_labels = points.new_full((num_points, ),
                                                  -1,
                                                  dtype=torch.long)
            return AssignResult(
                num_gts, assigned_gt_inds, None, labels=assigned_labels)

        points_xy = points[:, :2]
        points_stride = points[:, 2]
        # level of every point, derived from its stride
        points_lvl = torch.log2(
            points_stride).int()  # [3...,4...,5...,6...,7...]
        lvl_min, lvl_max = points_lvl.min(), points_lvl.max()

        # assign gt box
        gt_bboxes_xy = (gt_bboxes[:, :2] + gt_bboxes[:, 2:]) / 2
        # clamp so that degenerate boxes do not produce log2(0) or a
        # division by zero when normalising distances below
        gt_bboxes_wh = (gt_bboxes[:, 2:] - gt_bboxes[:, :2]).clamp(min=1e-6)
        scale = self.scale
        gt_bboxes_lvl = ((torch.log2(gt_bboxes_wh[:, 0] / scale) +
                          torch.log2(gt_bboxes_wh[:, 1] / scale)) / 2).int()
        # keep every gt level inside the range of levels present in points
        gt_bboxes_lvl = torch.clamp(gt_bboxes_lvl, min=lvl_min, max=lvl_max)

        # stores the assigned gt index of each point
        assigned_gt_inds = points.new_zeros((num_points, ), dtype=torch.long)
        # stores the assigned gt dist (to this point) of each point
        assigned_gt_dist = points.new_full((num_points, ), float('inf'))
        # Build the index range on the same device as `points` so that
        # masking it with `lvl_idx` and indexing it with the topk indices
        # never mixes devices.
        points_range = torch.arange(num_points, device=points.device)

        for idx in range(num_gts):
            gt_lvl = gt_bboxes_lvl[idx]
            # get the index of points in this level
            lvl_idx = gt_lvl == points_lvl
            points_index = points_range[lvl_idx]
            # get the points in this level
            lvl_points = points_xy[lvl_idx, :]
            # get the center point of gt
            gt_point = gt_bboxes_xy[[idx], :]
            # get width and height of gt
            gt_wh = gt_bboxes_wh[[idx], :]
            # compute the distance between gt center and
            # all points in this level
            points_gt_dist = ((lvl_points - gt_point) / gt_wh).norm(dim=1)
            # find the nearest k points to gt center in this level.
            # The level may hold fewer than `pos_num` points (or none at
            # all), and torch.topk raises if k exceeds the input size, so
            # k is limited to the number of available candidates.
            num_candidates = min(self.pos_num, points_gt_dist.size(0))
            min_dist, min_dist_index = torch.topk(
                points_gt_dist, num_candidates, largest=False)
            # the index of nearest k points to gt center in this level
            min_dist_points_index = points_index[min_dist_index]
            # The less_than_recorded_index stores the index
            # of min_dist that is less then the assigned_gt_dist. Where
            # assigned_gt_dist stores the dist from previous assigned gt
            # (if exist) to each point.
            less_than_recorded_index = min_dist < assigned_gt_dist[
                min_dist_points_index]
            # The min_dist_points_index stores the index of points satisfy:
            # (1) it is k nearest to current gt center in this level.
            # (2) it is closer to current gt center than other gt center.
            min_dist_points_index = min_dist_points_index[
                less_than_recorded_index]
            # assign the result (gt indices are stored 1-based)
            assigned_gt_inds[min_dist_points_index] = idx + 1
            assigned_gt_dist[min_dist_points_index] = min_dist[
                less_than_recorded_index]

        if gt_labels is not None:
            assigned_labels = assigned_gt_inds.new_full((num_points, ), -1)
            # squeeze only the trailing dim so that a single positive
            # still yields a 1-D index tensor
            pos_inds = torch.nonzero(
                assigned_gt_inds > 0, as_tuple=False).squeeze(1)
            if pos_inds.numel() > 0:
                # convert the 1-based gt index back to 0-based for lookup
                assigned_labels[pos_inds] = gt_labels[
                    assigned_gt_inds[pos_inds] - 1]
        else:
            assigned_labels = None

        return AssignResult(
            num_gts, assigned_gt_inds, None, labels=assigned_labels)