Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
import torch | |
from ..builder import BBOX_SAMPLERS | |
from .base_sampler import BaseSampler | |
# NOTE(review): BBOX_SAMPLERS is imported in this file but is not applied to
# this class. If the sampler is meant to be built from a config by type name,
# it may need an ``@BBOX_SAMPLERS.register_module()`` decorator — confirm.
class RandomSampler(BaseSampler):
    """Random sampler.

    Positive and negative candidates are drawn uniformly at random, without
    replacement, from the indices produced by an assigner.

    Args:
        num (int): Number of samples
        pos_fraction (float): Fraction of positive samples
        neg_pos_ub (int, optional): Upper bound number of negative and
            positive samples. Defaults to -1.
        add_gt_as_proposals (bool, optional): Whether to add ground truth
            boxes as proposals. Defaults to True.
        **kwargs: Only ``rng`` is read. It is passed to
            ``demodata.ensure_rng`` and the result is stored as ``self.rng``.
            Any other keyword is ignored.
    """

    def __init__(self,
                 num,
                 pos_fraction,
                 neg_pos_ub=-1,
                 add_gt_as_proposals=True,
                 **kwargs):
        # Function-local import, presumably to avoid an import cycle at
        # module load time — confirm before hoisting it.
        from mmdet.core.bbox import demodata
        super().__init__(num, pos_fraction, neg_pos_ub, add_gt_as_proposals)
        # NOTE(review): self.rng is not used by random_choice() below, which
        # draws from torch's global RNG. Passing ``rng`` therefore does not by
        # itself make this class's sampling reproducible.
        self.rng = demodata.ensure_rng(kwargs.get('rng', None))

    def random_choice(self, gallery, num):
        """Randomly select ``num`` distinct elements from ``gallery``.

        If ``gallery`` is a Tensor, the returned indices are a Tensor on the
        same device as ``gallery``. If ``gallery`` is an ndarray or a list,
        the returned indices are an ndarray.

        Args:
            gallery (Tensor | ndarray | list): Indices pool. Elements are
                selected along the first dimension.
            num (int): Expected sample num. Must not exceed
                ``len(gallery)``.

        Returns:
            Tensor or ndarray: Sampled indices.
        """
        assert len(gallery) >= num

        is_tensor = isinstance(gallery, torch.Tensor)
        if not is_tensor:
            # The result is converted back to numpy below, so keep the
            # temporary tensor on the CPU. Placing it on the GPU would only
            # add two host/device copies and would not change the result.
            gallery = torch.tensor(gallery, dtype=torch.long, device='cpu')
        # This is a temporary fix. We can revert the following code
        # when PyTorch fixes the abnormal return of torch.randperm.
        # See: https://github.com/open-mmlab/mmdetection/pull/5014
        # The permutation is generated on the CPU and then moved to wherever
        # ``gallery`` lives. It is sized by the first dimension, so that
        # indexing ``gallery[perm]`` along dim 0 stays in range.
        perm = torch.randperm(len(gallery))[:num].to(device=gallery.device)
        rand_inds = gallery[perm]
        if not is_tensor:
            rand_inds = rand_inds.numpy()
        return rand_inds

    def _sample_from_mask(self, mask, num_expected):
        """Return at most ``num_expected`` random indices where ``mask`` holds.

        Args:
            mask (Tensor): Boolean mask over candidates (presumably 1-D, one
                entry per candidate — confirm).
            num_expected (int): Maximum number of indices to return.

        Returns:
            Tensor: Indices of the selected entries. If there are at most
            ``num_expected`` candidates, all of them are returned.
        """
        inds = torch.nonzero(mask, as_tuple=False)
        if inds.numel() != 0:
            # For a 1-D mask, nonzero() yields shape (k, 1); flatten it to
            # (k,). An empty result is returned with the shape nonzero()
            # gave it.
            inds = inds.squeeze(1)
        if inds.numel() <= num_expected:
            return inds
        return self.random_choice(inds, num_expected)

    def _sample_pos(self, assign_result, num_expected, **kwargs):
        """Randomly sample some positive samples.

        Positives are the candidates whose ``assign_result.gt_inds`` is > 0.

        Args:
            assign_result: Object exposing a ``gt_inds`` tensor.
            num_expected (int): Maximum number of positives to return.
            **kwargs: Ignored.

        Returns:
            Tensor: Indices of the sampled positives.
        """
        return self._sample_from_mask(assign_result.gt_inds > 0, num_expected)

    def _sample_neg(self, assign_result, num_expected, **kwargs):
        """Randomly sample some negative samples.

        Negatives are the candidates whose ``assign_result.gt_inds`` is == 0.

        Args:
            assign_result: Object exposing a ``gt_inds`` tensor.
            num_expected (int): Maximum number of negatives to return.
            **kwargs: Ignored.

        Returns:
            Tensor: Indices of the sampled negatives.
        """
        return self._sample_from_mask(assign_result.gt_inds == 0, num_expected)