Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
from abc import ABCMeta, abstractmethod | |
import torch | |
from .sampling_result import SamplingResult | |
class BaseSampler(metaclass=ABCMeta):
    """Base class of samplers.

    Subclasses provide the actual selection strategy by overriding
    :meth:`_sample_pos` and :meth:`_sample_neg`; :meth:`sample` drives the
    shared workflow (optional injection of ground truth boxes, positive
    sampling, negative sampling, result packaging).

    Args:
        num (int): Total number of samples requested per call to
            :meth:`sample`.
        pos_fraction (float): Fraction of ``num`` reserved for positives;
            the positive budget is ``int(num * pos_fraction)``.
        neg_pos_ub (int | float): When non-negative, the number of negatives
            is capped at ``int(neg_pos_ub * max(1, num_sampled_pos))``.
            A negative value (default ``-1``) disables the cap.
        add_gt_as_proposals (bool): Whether to prepend the ground truth
            boxes to the candidate boxes before sampling.
        **kwargs: Accepted and ignored by this base class.
    """

    def __init__(self,
                 num,
                 pos_fraction,
                 neg_pos_ub=-1,
                 add_gt_as_proposals=True,
                 **kwargs):
        self.num = num
        self.pos_fraction = pos_fraction
        self.neg_pos_ub = neg_pos_ub
        self.add_gt_as_proposals = add_gt_as_proposals
        # ``sample`` dispatches through these two attributes rather than
        # through ``self`` directly; by default both point back at this
        # instance. Presumably subclasses may rebind them to mix different
        # strategies for positives and negatives.
        self.pos_sampler = self
        self.neg_sampler = self

    def _sample_pos(self, assign_result, num_expected, **kwargs):
        """Sample positive samples.

        Must be overridden by subclasses. ``sample`` calls ``.unique()`` and
        ``.numel()`` on the return value, so it has to be a tensor of
        indices.

        Args:
            assign_result (:obj:`AssignResult`): Bbox assigning results.
            num_expected (int): Number of positives requested.
            **kwargs: Extra arguments forwarded by :meth:`sample`
                (always includes ``bboxes``).

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        # Fail loudly here instead of returning None, which would only
        # surface later as an opaque AttributeError inside ``sample``.
        raise NotImplementedError(
            '{} must implement _sample_pos()'.format(type(self).__name__))

    def _sample_neg(self, assign_result, num_expected, **kwargs):
        """Sample negative samples.

        Must be overridden by subclasses. ``sample`` calls ``.unique()`` on
        the return value, so it has to be a tensor of indices.

        Args:
            assign_result (:obj:`AssignResult`): Bbox assigning results.
            num_expected (int): Number of negatives requested.
            **kwargs: Extra arguments forwarded by :meth:`sample`
                (always includes ``bboxes``).

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        # Same rationale as ``_sample_pos``: a silent None return would
        # crash later with a misleading error.
        raise NotImplementedError(
            '{} must implement _sample_neg()'.format(type(self).__name__))

    def sample(self,
               assign_result,
               bboxes,
               gt_bboxes,
               gt_labels=None,
               **kwargs):
        """Sample positive and negative bboxes.

        This is a simple implementation of bbox sampling given candidates,
        assigning results and ground truth bboxes.

        Args:
            assign_result (:obj:`AssignResult`): Bbox assigning results.
                Modified in place via ``add_gt_`` when ground truth boxes
                are added as proposals.
            bboxes (Tensor): Boxes to be sampled from. A 1-D tensor is
                treated as a single box; only the first 4 columns are kept.
            gt_bboxes (Tensor): Ground truth bboxes.
            gt_labels (Tensor, optional): Class labels of ground truth bboxes.
                Required when ``add_gt_as_proposals`` is True and
                ``gt_bboxes`` is non-empty.
            **kwargs: Forwarded to ``_sample_pos`` and ``_sample_neg``.

        Returns:
            :obj:`SamplingResult`: Sampling result.

        Raises:
            ValueError: If ground truth boxes are to be added as proposals
                but ``gt_labels`` is None.

        Example:
            >>> from mmdet.core.bbox import RandomSampler
            >>> from mmdet.core.bbox import AssignResult
            >>> from mmdet.core.bbox.demodata import ensure_rng, random_boxes
            >>> rng = ensure_rng(None)
            >>> assign_result = AssignResult.random(rng=rng)
            >>> bboxes = random_boxes(assign_result.num_preds, rng=rng)
            >>> gt_bboxes = random_boxes(assign_result.num_gts, rng=rng)
            >>> gt_labels = None
            >>> self = RandomSampler(num=32, pos_fraction=0.5, neg_pos_ub=-1,
            ...                      add_gt_as_proposals=False)
            >>> self = self.sample(assign_result, bboxes, gt_bboxes, gt_labels)
        """
        # Promote a single 1-D box to a (1, K) batch so the column slice
        # below is valid.
        if len(bboxes.shape) < 2:
            bboxes = bboxes[None, :]

        # Keep only the first four columns of each candidate.
        bboxes = bboxes[:, :4]

        # One flag per candidate: 1 marks a row that came from gt_bboxes,
        # 0 marks an original candidate.
        gt_flags = bboxes.new_zeros((bboxes.shape[0], ), dtype=torch.uint8)
        if self.add_gt_as_proposals and len(gt_bboxes) > 0:
            if gt_labels is None:
                raise ValueError(
                    'gt_labels must be given when add_gt_as_proposals is True')
            # Ground truth rows are prepended, so the flags are prepended in
            # the same order to stay aligned with ``bboxes``.
            bboxes = torch.cat([gt_bboxes, bboxes], dim=0)
            assign_result.add_gt_(gt_labels)
            gt_ones = bboxes.new_ones(gt_bboxes.shape[0], dtype=torch.uint8)
            gt_flags = torch.cat([gt_ones, gt_flags])

        num_expected_pos = int(self.num * self.pos_fraction)
        pos_inds = self.pos_sampler._sample_pos(
            assign_result, num_expected_pos, bboxes=bboxes, **kwargs)
        # We found that sampled indices have duplicated items occasionally.
        # (may be a bug of PyTorch)
        pos_inds = pos_inds.unique()
        num_sampled_pos = pos_inds.numel()

        # Negatives fill whatever part of the budget the positives left over.
        num_expected_neg = self.num - num_sampled_pos
        if self.neg_pos_ub >= 0:
            # Use at least 1 as the positive count so that a call with zero
            # positives can still yield up to ``neg_pos_ub`` negatives.
            _pos = max(1, num_sampled_pos)
            neg_upper_bound = int(self.neg_pos_ub * _pos)
            if num_expected_neg > neg_upper_bound:
                num_expected_neg = neg_upper_bound
        neg_inds = self.neg_sampler._sample_neg(
            assign_result, num_expected_neg, bboxes=bboxes, **kwargs)
        neg_inds = neg_inds.unique()

        sampling_result = SamplingResult(pos_inds, neg_inds, bboxes, gt_bboxes,
                                         assign_result, gt_flags)
        return sampling_result