# RockeyCoss
# add code files
# 51f6859
# raw
# history blame
# 5.39 kB
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmdet.utils import util_mixins
class SamplingResult(util_mixins.NiceRepr):
    """Bbox sampling result.

    Holds the indices of the sampled positive / negative boxes together with
    the boxes, ground-truth flags, assigned ground-truth indices, assigned
    ground-truth boxes and (optionally) labels gathered at those indices.

    Example:
        >>> # xdoctest: +IGNORE_WANT
        >>> from mmdet.core.bbox.samplers.sampling_result import *  # NOQA
        >>> self = SamplingResult.random(rng=10)
        >>> print(f'self = {self}')
        self = <SamplingResult({
        'neg_bboxes': torch.Size([12, 4]),
        'neg_inds': tensor([ 0, 1, 2, 4, 5, 6, 7, 8, 9, 10, 11, 12]),
        'num_gts': 4,
        'pos_assigned_gt_inds': tensor([], dtype=torch.int64),
        'pos_bboxes': torch.Size([0, 4]),
        'pos_inds': tensor([], dtype=torch.int64),
        'pos_is_gt': tensor([], dtype=torch.uint8)
        })>
    """

    def __init__(self, pos_inds, neg_inds, bboxes, gt_bboxes, assign_result,
                 gt_flags):
        """Gather the sampled boxes and their assigned ground truth.

        Args:
            pos_inds (torch.Tensor): indices of the positive samples; used to
                index ``bboxes``, ``gt_flags``, ``assign_result.gt_inds`` and
                ``assign_result.labels``.
            neg_inds (torch.Tensor): indices of the negative samples; used to
                index ``bboxes``.
            bboxes (torch.Tensor): boxes the indices refer to.
            gt_bboxes (torch.Tensor): ground-truth boxes. A tensor with fewer
                than two dimensions is reinterpreted as rows of 4 values.
            assign_result: object exposing ``gt_inds`` (1 is subtracted from
                it to obtain row indices into ``gt_bboxes``) and ``labels``
                (may be ``None``).
            gt_flags (torch.Tensor): per-box flags, indexed with ``pos_inds``
                to form ``pos_is_gt``.
        """
        self.pos_inds = pos_inds
        self.neg_inds = neg_inds
        self.pos_bboxes = bboxes[pos_inds]
        self.neg_bboxes = bboxes[neg_inds]
        self.pos_is_gt = gt_flags[pos_inds]

        # Normalise squeezed (1-D) ground truth *before* counting rows.
        # Otherwise a single box of shape (4,) would be reported as 4 gts.
        if len(gt_bboxes.shape) < 2:
            gt_bboxes = gt_bboxes.view(-1, 4)
        self.num_gts = gt_bboxes.shape[0]

        self.pos_assigned_gt_inds = assign_result.gt_inds[pos_inds] - 1

        if gt_bboxes.numel() == 0:
            # hack for index error case: nothing can be gathered from an
            # empty gt tensor, so there must not be any positive assignment.
            assert self.pos_assigned_gt_inds.numel() == 0
            self.pos_gt_bboxes = torch.empty_like(gt_bboxes).view(-1, 4)
        else:
            self.pos_gt_bboxes = gt_bboxes[
                self.pos_assigned_gt_inds.long(), :]

        if assign_result.labels is not None:
            self.pos_gt_labels = assign_result.labels[pos_inds]
        else:
            self.pos_gt_labels = None

    @property
    def bboxes(self):
        """torch.Tensor: concatenated positive and negative boxes"""
        return torch.cat([self.pos_bboxes, self.neg_bboxes])

    def to(self, device):
        """Change the device of the data inplace.

        Every attribute that is a ``torch.Tensor`` is replaced by the result
        of calling ``.to(device)`` on it; other attributes are left untouched.

        Args:
            device: forwarded unchanged to ``torch.Tensor.to``.

        Returns:
            SamplingResult: ``self``, to allow chaining.

        Example:
            >>> self = SamplingResult.random()
            >>> print(f'self = {self.to(None)}')
            >>> # xdoctest: +REQUIRES(--gpu)
            >>> print(f'self = {self.to(0)}')
        """
        attrs = self.__dict__
        # Only existing keys are reassigned, so the dict size never changes
        # while it is being iterated.
        for name, value in attrs.items():
            if isinstance(value, torch.Tensor):
                attrs[name] = value.to(device)
        return self

    def __nice__(self):
        """Return the body of the repr built from :attr:`info`.

        The two box tensors are shown by shape only; entries are sorted by
        key.
        """
        data = self.info.copy()
        data['pos_bboxes'] = data.pop('pos_bboxes').shape
        data['neg_bboxes'] = data.pop('neg_bboxes').shape
        parts = [f"'{k}': {v!r}" for k, v in sorted(data.items())]
        body = ' ' + ',\n '.join(parts)
        return '{\n' + body + '\n}'

    @property
    def info(self):
        """Returns a dictionary of info about the object."""
        return {
            'pos_inds': self.pos_inds,
            'neg_inds': self.neg_inds,
            'pos_bboxes': self.pos_bboxes,
            'neg_bboxes': self.neg_bboxes,
            'pos_is_gt': self.pos_is_gt,
            'num_gts': self.num_gts,
            'pos_assigned_gt_inds': self.pos_assigned_gt_inds,
        }

    @classmethod
    def random(cls, rng=None, **kwargs):
        """Create a randomly generated sampling result (for demos / tests).

        Args:
            rng (None | int | numpy.random.RandomState): seed or state.
            kwargs (keyword arguments): forwarded to ``AssignResult.random``:
                - num_preds: number of predicted boxes
                - num_gts: number of true boxes
                - p_ignore (float): probability of a predicted box assigned
                    to an ignored truth.
                - p_assigned (float): probability of a predicted box not
                    being assigned.
                - p_use_label (float | bool): with labels or not.

        Returns:
            :obj:`SamplingResult`: Randomly generated sampling result.

        Example:
            >>> from mmdet.core.bbox.samplers.sampling_result import *  # NOQA
            >>> self = SamplingResult.random()
            >>> print(self.__dict__)
        """
        # Imported lazily, as in the original, to keep these helpers out of
        # module import time.
        from mmdet.core.bbox import demodata
        from mmdet.core.bbox.assigners.assign_result import AssignResult
        from mmdet.core.bbox.samplers.random_sampler import RandomSampler

        rng = demodata.ensure_rng(rng)

        # make probabilistic?
        num = 32
        pos_fraction = 0.5
        neg_pos_ub = -1

        assign_result = AssignResult.random(rng=rng, **kwargs)

        # Note we could just compute an assignment
        bboxes = demodata.random_boxes(assign_result.num_preds, rng=rng)
        gt_bboxes = demodata.random_boxes(assign_result.num_gts, rng=rng)

        if rng.rand() > 0.2:
            # sometimes algorithms squeeze their data, be robust to that
            gt_bboxes = gt_bboxes.squeeze()
            bboxes = bboxes.squeeze()

        # TODO: generate gt labels when ``assign_result.labels`` is not None.
        # Until then no labels are produced in either case.
        gt_labels = None

        # Ground truth is only added as proposals when labels are available
        # (make probabilistic?).
        add_gt_as_proposals = gt_labels is not None

        sampler = RandomSampler(
            num,
            pos_fraction,
            neg_pos_ub=neg_pos_ub,
            add_gt_as_proposals=add_gt_as_proposals,
            rng=rng)
        self = sampler.sample(assign_result, bboxes, gt_bboxes, gt_labels)
        return self