Spaces:
Runtime error
Runtime error
File size: 5,389 Bytes
51f6859 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 |
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmdet.utils import util_mixins
class SamplingResult(util_mixins.NiceRepr):
"""Bbox sampling result.
Example:
>>> # xdoctest: +IGNORE_WANT
>>> from mmdet.core.bbox.samplers.sampling_result import * # NOQA
>>> self = SamplingResult.random(rng=10)
>>> print(f'self = {self}')
self = <SamplingResult({
'neg_bboxes': torch.Size([12, 4]),
'neg_inds': tensor([ 0, 1, 2, 4, 5, 6, 7, 8, 9, 10, 11, 12]),
'num_gts': 4,
'pos_assigned_gt_inds': tensor([], dtype=torch.int64),
'pos_bboxes': torch.Size([0, 4]),
'pos_inds': tensor([], dtype=torch.int64),
'pos_is_gt': tensor([], dtype=torch.uint8)
})>
"""
def __init__(self, pos_inds, neg_inds, bboxes, gt_bboxes, assign_result,
gt_flags):
self.pos_inds = pos_inds
self.neg_inds = neg_inds
self.pos_bboxes = bboxes[pos_inds]
self.neg_bboxes = bboxes[neg_inds]
self.pos_is_gt = gt_flags[pos_inds]
self.num_gts = gt_bboxes.shape[0]
self.pos_assigned_gt_inds = assign_result.gt_inds[pos_inds] - 1
if gt_bboxes.numel() == 0:
# hack for index error case
assert self.pos_assigned_gt_inds.numel() == 0
self.pos_gt_bboxes = torch.empty_like(gt_bboxes).view(-1, 4)
else:
if len(gt_bboxes.shape) < 2:
gt_bboxes = gt_bboxes.view(-1, 4)
self.pos_gt_bboxes = gt_bboxes[self.pos_assigned_gt_inds.long(), :]
if assign_result.labels is not None:
self.pos_gt_labels = assign_result.labels[pos_inds]
else:
self.pos_gt_labels = None
@property
def bboxes(self):
"""torch.Tensor: concatenated positive and negative boxes"""
return torch.cat([self.pos_bboxes, self.neg_bboxes])
def to(self, device):
"""Change the device of the data inplace.
Example:
>>> self = SamplingResult.random()
>>> print(f'self = {self.to(None)}')
>>> # xdoctest: +REQUIRES(--gpu)
>>> print(f'self = {self.to(0)}')
"""
_dict = self.__dict__
for key, value in _dict.items():
if isinstance(value, torch.Tensor):
_dict[key] = value.to(device)
return self
def __nice__(self):
data = self.info.copy()
data['pos_bboxes'] = data.pop('pos_bboxes').shape
data['neg_bboxes'] = data.pop('neg_bboxes').shape
parts = [f"'{k}': {v!r}" for k, v in sorted(data.items())]
body = ' ' + ',\n '.join(parts)
return '{\n' + body + '\n}'
@property
def info(self):
"""Returns a dictionary of info about the object."""
return {
'pos_inds': self.pos_inds,
'neg_inds': self.neg_inds,
'pos_bboxes': self.pos_bboxes,
'neg_bboxes': self.neg_bboxes,
'pos_is_gt': self.pos_is_gt,
'num_gts': self.num_gts,
'pos_assigned_gt_inds': self.pos_assigned_gt_inds,
}
@classmethod
def random(cls, rng=None, **kwargs):
"""
Args:
rng (None | int | numpy.random.RandomState): seed or state.
kwargs (keyword arguments):
- num_preds: number of predicted boxes
- num_gts: number of true boxes
- p_ignore (float): probability of a predicted box assigned to \
an ignored truth.
- p_assigned (float): probability of a predicted box not being \
assigned.
- p_use_label (float | bool): with labels or not.
Returns:
:obj:`SamplingResult`: Randomly generated sampling result.
Example:
>>> from mmdet.core.bbox.samplers.sampling_result import * # NOQA
>>> self = SamplingResult.random()
>>> print(self.__dict__)
"""
from mmdet.core.bbox import demodata
from mmdet.core.bbox.assigners.assign_result import AssignResult
from mmdet.core.bbox.samplers.random_sampler import RandomSampler
rng = demodata.ensure_rng(rng)
# make probabilistic?
num = 32
pos_fraction = 0.5
neg_pos_ub = -1
assign_result = AssignResult.random(rng=rng, **kwargs)
# Note we could just compute an assignment
bboxes = demodata.random_boxes(assign_result.num_preds, rng=rng)
gt_bboxes = demodata.random_boxes(assign_result.num_gts, rng=rng)
if rng.rand() > 0.2:
# sometimes algorithms squeeze their data, be robust to that
gt_bboxes = gt_bboxes.squeeze()
bboxes = bboxes.squeeze()
if assign_result.labels is None:
gt_labels = None
else:
gt_labels = None # todo
if gt_labels is None:
add_gt_as_proposals = False
else:
add_gt_as_proposals = True # make probabilistic?
sampler = RandomSampler(
num,
pos_fraction,
neg_pos_ub=neg_pos_ub,
add_gt_as_proposals=add_gt_as_proposals,
rng=rng)
self = sampler.sample(assign_result, bboxes, gt_bboxes, gt_labels)
return self
|