# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.runner import BaseModule, auto_fp16, force_fp32
from torch.nn.modules.utils import _pair

from mmdet.core import build_bbox_coder, multi_apply, multiclass_nms
from mmdet.models.builder import HEADS, build_loss
from mmdet.models.losses import accuracy
from mmdet.models.utils import build_linear_layer


@HEADS.register_module()
class BBoxHead(BaseModule):
"""Simplest RoI head, with only two fc layers for classification and
regression respectively."""

    def __init__(self,
with_avg_pool=False,
with_cls=True,
with_reg=True,
roi_feat_size=7,
in_channels=256,
num_classes=80,
bbox_coder=dict(
type='DeltaXYWHBBoxCoder',
clip_border=True,
target_means=[0., 0., 0., 0.],
target_stds=[0.1, 0.1, 0.2, 0.2]),
reg_class_agnostic=False,
reg_decoded_bbox=False,
reg_predictor_cfg=dict(type='Linear'),
cls_predictor_cfg=dict(type='Linear'),
loss_cls=dict(
type='CrossEntropyLoss',
use_sigmoid=False,
loss_weight=1.0),
loss_bbox=dict(
type='SmoothL1Loss', beta=1.0, loss_weight=1.0),
init_cfg=None):
super(BBoxHead, self).__init__(init_cfg)
assert with_cls or with_reg
self.with_avg_pool = with_avg_pool
self.with_cls = with_cls
self.with_reg = with_reg
self.roi_feat_size = _pair(roi_feat_size)
self.roi_feat_area = self.roi_feat_size[0] * self.roi_feat_size[1]
self.in_channels = in_channels
self.num_classes = num_classes
self.reg_class_agnostic = reg_class_agnostic
self.reg_decoded_bbox = reg_decoded_bbox
self.reg_predictor_cfg = reg_predictor_cfg
self.cls_predictor_cfg = cls_predictor_cfg
self.fp16_enabled = False
self.bbox_coder = build_bbox_coder(bbox_coder)
self.loss_cls = build_loss(loss_cls)
self.loss_bbox = build_loss(loss_bbox)
in_channels = self.in_channels
if self.with_avg_pool:
self.avg_pool = nn.AvgPool2d(self.roi_feat_size)
else:
in_channels *= self.roi_feat_area
if self.with_cls:
# need to add background class
if self.custom_cls_channels:
cls_channels = self.loss_cls.get_cls_channels(self.num_classes)
else:
cls_channels = num_classes + 1
self.fc_cls = build_linear_layer(
self.cls_predictor_cfg,
in_features=in_channels,
out_features=cls_channels)
if self.with_reg:
out_dim_reg = 4 if reg_class_agnostic else 4 * num_classes
self.fc_reg = build_linear_layer(
self.reg_predictor_cfg,
in_features=in_channels,
out_features=out_dim_reg)
self.debug_imgs = None
if init_cfg is None:
self.init_cfg = []
if self.with_cls:
self.init_cfg += [
dict(
type='Normal', std=0.01, override=dict(name='fc_cls'))
]
if self.with_reg:
self.init_cfg += [
dict(
type='Normal', std=0.001, override=dict(name='fc_reg'))
]

    @property
    def custom_cls_channels(self):
        """bool: Whether the classification loss defines its own number of
        prediction channels."""
        return getattr(self.loss_cls, 'custom_cls_channels', False)

    @property
    def custom_activation(self):
        """bool: Whether the classification loss defines its own activation
        function."""
        return getattr(self.loss_cls, 'custom_activation', False)

    @property
    def custom_accuracy(self):
        """bool: Whether the classification loss defines its own accuracy
        metric."""
        return getattr(self.loss_cls, 'custom_accuracy', False)

    @auto_fp16()
def forward(self, x):
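        """Forward features from the upstream network.

        Args:
            x (Tensor): RoI features with shape (num_rois, in_channels,
                roi_feat_size, roi_feat_size) when `with_avg_pool` is
                True; otherwise the features are expected to be already
                flattened to (num_rois, in_channels * roi_feat_area).

        Returns:
            tuple[Tensor, Tensor]: `cls_score` with shape
                (num_rois, num_classes + 1) and `bbox_pred` with shape
                (num_rois, 4) or (num_rois, num_classes * 4). Either may
                be None when the corresponding branch is disabled.

        Example:
            >>> # A minimal sketch assuming average pooling of the
            >>> # default 7x7 RoI features.
            >>> self = BBoxHead(
            ...     with_avg_pool=True, in_channels=256, num_classes=3)
            >>> x = torch.rand(2, 256, 7, 7)
            >>> cls_score, bbox_pred = self.forward(x)
            >>> cls_score.shape, bbox_pred.shape
            (torch.Size([2, 4]), torch.Size([2, 12]))
        """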
if self.with_avg_pool:
if x.numel() > 0:
x = self.avg_pool(x)
x = x.view(x.size(0), -1)
else:
# avg_pool does not support empty tensor,
                # so use torch.mean instead
x = torch.mean(x, dim=(-1, -2))
cls_score = self.fc_cls(x) if self.with_cls else None
bbox_pred = self.fc_reg(x) if self.with_reg else None
return cls_score, bbox_pred

    def _get_target_single(self, pos_bboxes, neg_bboxes, pos_gt_bboxes,
pos_gt_labels, cfg):
"""Calculate the ground truth for proposals in the single image
according to the sampling results.
Args:
pos_bboxes (Tensor): Contains all the positive boxes,
has shape (num_pos, 4), the last dimension 4
represents [tl_x, tl_y, br_x, br_y].
neg_bboxes (Tensor): Contains all the negative boxes,
has shape (num_neg, 4), the last dimension 4
represents [tl_x, tl_y, br_x, br_y].
pos_gt_bboxes (Tensor): Contains gt_boxes for
all positive samples, has shape (num_pos, 4),
the last dimension 4
represents [tl_x, tl_y, br_x, br_y].
pos_gt_labels (Tensor): Contains gt_labels for
all positive samples, has shape (num_pos, ).
cfg (obj:`ConfigDict`): `train_cfg` of R-CNN.
Returns:
Tuple[Tensor]: Ground truth for proposals
in a single image. Containing the following Tensors:
                - labels (Tensor): GT labels for all proposals, has
                  shape (num_proposals,).
                - label_weights (Tensor): Label weights for all
                  proposals, has shape (num_proposals,).
                - bbox_targets (Tensor): Regression targets for all
                  proposals, has shape (num_proposals, 4), the
                  last dimension 4 represents [tl_x, tl_y, br_x, br_y].
                - bbox_weights (Tensor): Regression weights for all
proposals, has shape (num_proposals, 4).
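
        Example:
            >>> # A minimal sketch with hand-made boxes; it assumes the
            >>> # default `DeltaXYWHBBoxCoder` and uses a plain
            >>> # `mmcv.Config` standing in for the R-CNN train_cfg.
            >>> import mmcv
            >>> self = BBoxHead(num_classes=3)
            >>> pos_bboxes = torch.tensor([[0., 0., 10., 10.]])
            >>> neg_bboxes = torch.tensor([[10., 10., 20., 20.]])
            >>> pos_gt_bboxes = torch.tensor([[1., 1., 11., 11.]])
            >>> pos_gt_labels = torch.tensor([1])
            >>> cfg = mmcv.Config(dict(pos_weight=-1))
            >>> targets = self._get_target_single(
            ...     pos_bboxes, neg_bboxes, pos_gt_bboxes,
            ...     pos_gt_labels, cfg)
            >>> [t.shape for t in targets]
            [torch.Size([2]), torch.Size([2]), torch.Size([2, 4]), torch.Size([2, 4])]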
"""
num_pos = pos_bboxes.size(0)
num_neg = neg_bboxes.size(0)
num_samples = num_pos + num_neg
        # The original implementation used new_zeros since BG was labeled 0.
        # Now we use new_full with BG cat_id = num_classes and
        # FG cat_id in [0, num_classes - 1].
labels = pos_bboxes.new_full((num_samples, ),
self.num_classes,
dtype=torch.long)
label_weights = pos_bboxes.new_zeros(num_samples)
bbox_targets = pos_bboxes.new_zeros(num_samples, 4)
bbox_weights = pos_bboxes.new_zeros(num_samples, 4)
if num_pos > 0:
labels[:num_pos] = pos_gt_labels
pos_weight = 1.0 if cfg.pos_weight <= 0 else cfg.pos_weight
label_weights[:num_pos] = pos_weight
if not self.reg_decoded_bbox:
pos_bbox_targets = self.bbox_coder.encode(
pos_bboxes, pos_gt_bboxes)
else:
# When the regression loss (e.g. `IouLoss`, `GIouLoss`)
# is applied directly on the decoded bounding boxes, both
# the predicted boxes and regression targets should be with
# absolute coordinate format.
pos_bbox_targets = pos_gt_bboxes
bbox_targets[:num_pos, :] = pos_bbox_targets
bbox_weights[:num_pos, :] = 1
if num_neg > 0:
label_weights[-num_neg:] = 1.0
return labels, label_weights, bbox_targets, bbox_weights

    def get_targets(self,
sampling_results,
gt_bboxes,
gt_labels,
rcnn_train_cfg,
concat=True):
"""Calculate the ground truth for all samples in a batch according to
the sampling_results.

        It maps `_get_target_single` over the per-image sampling results
        with `multi_apply` and optionally concatenates the per-image
        targets into batch-level tensors.

Args:
sampling_results (List[obj:SamplingResults]): Assign results of
all images in a batch after sampling.
gt_bboxes (list[Tensor]): Gt_bboxes of all images in a batch,
each tensor has shape (num_gt, 4), the last dimension 4
represents [tl_x, tl_y, br_x, br_y].
gt_labels (list[Tensor]): Gt_labels of all images in a batch,
each tensor has shape (num_gt,).
rcnn_train_cfg (obj:ConfigDict): `train_cfg` of RCNN.
concat (bool): Whether to concatenate the results of all
the images in a single batch.
Returns:
            Tuple[Tensor]: Ground truth for proposals in a batch.
            Containing the following list of Tensors:

- labels (list[Tensor],Tensor): Gt_labels for all
proposals in a batch, each tensor in list has
shape (num_proposals,) when `concat=False`, otherwise
just a single tensor has shape (num_all_proposals,).
- label_weights (list[Tensor]): Labels_weights for
all proposals in a batch, each tensor in list has
shape (num_proposals,) when `concat=False`, otherwise
just a single tensor has shape (num_all_proposals,).
- bbox_targets (list[Tensor],Tensor): Regression target
for all proposals in a batch, each tensor in list
has shape (num_proposals, 4) when `concat=False`,
otherwise just a single tensor has shape
(num_all_proposals, 4), the last dimension 4 represents
[tl_x, tl_y, br_x, br_y].
                - bbox_weights (list[Tensor],Tensor): Regression weights for
all proposals in a batch, each tensor in list has shape
(num_proposals, 4) when `concat=False`, otherwise just a
single tensor has shape (num_all_proposals, 4).
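
        Example:
            >>> # A minimal sketch: the SamplingResults object is mocked
            >>> # with a bare namespace providing only the four attributes
            >>> # this method reads; gt_bboxes and gt_labels are not used
            >>> # by this base implementation, so None is passed.
            >>> from types import SimpleNamespace
            >>> import mmcv
            >>> self = BBoxHead(num_classes=3)
            >>> res = SimpleNamespace(
            ...     pos_bboxes=torch.tensor([[0., 0., 10., 10.]]),
            ...     neg_bboxes=torch.tensor([[10., 10., 20., 20.]]),
            ...     pos_gt_bboxes=torch.tensor([[1., 1., 11., 11.]]),
            ...     pos_gt_labels=torch.tensor([2]))
            >>> cfg = mmcv.Config(dict(pos_weight=-1))
            >>> labels, label_weights, bbox_targets, bbox_weights = (
            ...     self.get_targets([res], None, None, cfg))
            >>> labels.shape, bbox_targets.shape
            (torch.Size([2]), torch.Size([2, 4]))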
"""
pos_bboxes_list = [res.pos_bboxes for res in sampling_results]
neg_bboxes_list = [res.neg_bboxes for res in sampling_results]
pos_gt_bboxes_list = [res.pos_gt_bboxes for res in sampling_results]
pos_gt_labels_list = [res.pos_gt_labels for res in sampling_results]
labels, label_weights, bbox_targets, bbox_weights = multi_apply(
self._get_target_single,
pos_bboxes_list,
neg_bboxes_list,
pos_gt_bboxes_list,
pos_gt_labels_list,
cfg=rcnn_train_cfg)
if concat:
labels = torch.cat(labels, 0)
label_weights = torch.cat(label_weights, 0)
bbox_targets = torch.cat(bbox_targets, 0)
bbox_weights = torch.cat(bbox_weights, 0)
return labels, label_weights, bbox_targets, bbox_weights

    @force_fp32(apply_to=('cls_score', 'bbox_pred'))
def loss(self,
cls_score,
bbox_pred,
rois,
labels,
label_weights,
bbox_targets,
bbox_weights,
reduction_override=None):
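        """Compute the classification and regression losses.

        Args:
            cls_score (Tensor): Classification scores with shape
                (num_samples, num_classes + 1), or None.
            bbox_pred (Tensor): Box deltas with shape (num_samples, 4)
                or (num_samples, num_classes * 4), or None.
            rois (Tensor): RoIs with shape (num_samples, 5), only used
                to decode `bbox_pred` when `reg_decoded_bbox` is True.
            labels (Tensor): Class labels with shape (num_samples, ).
            label_weights (Tensor): Classification weights with shape
                (num_samples, ).
            bbox_targets (Tensor): Regression targets with shape
                (num_samples, 4).
            bbox_weights (Tensor): Regression weights with shape
                (num_samples, 4).
            reduction_override (str, optional): Override the configured
                loss reduction. Default: None.

        Returns:
            dict[str, Tensor]: `loss_cls`, `acc` and `loss_bbox` terms,
                depending on which inputs are not None.

        Example:
            >>> # A minimal sketch with a single positive sample and the
            >>> # default losses.
            >>> self = BBoxHead(num_classes=3)
            >>> cls_score = torch.rand(1, 4)
            >>> bbox_pred = torch.rand(1, 12)
            >>> rois = torch.tensor([[0., 0., 0., 10., 10.]])
            >>> labels = torch.tensor([1])
            >>> label_weights = torch.ones(1)
            >>> bbox_targets = torch.rand(1, 4)
            >>> bbox_weights = torch.ones(1, 4)
            >>> losses = self.loss(cls_score, bbox_pred, rois, labels,
            ...                    label_weights, bbox_targets,
            ...                    bbox_weights)
            >>> sorted(losses.keys())
            ['acc', 'loss_bbox', 'loss_cls']
        """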
losses = dict()
if cls_score is not None:
avg_factor = max(torch.sum(label_weights > 0).float().item(), 1.)
if cls_score.numel() > 0:
loss_cls_ = self.loss_cls(
cls_score,
labels,
label_weights,
avg_factor=avg_factor,
reduction_override=reduction_override)
if isinstance(loss_cls_, dict):
losses.update(loss_cls_)
else:
losses['loss_cls'] = loss_cls_
if self.custom_activation:
acc_ = self.loss_cls.get_accuracy(cls_score, labels)
losses.update(acc_)
else:
losses['acc'] = accuracy(cls_score, labels)
if bbox_pred is not None:
bg_class_ind = self.num_classes
# 0~self.num_classes-1 are FG, self.num_classes is BG
pos_inds = (labels >= 0) & (labels < bg_class_ind)
# do not perform bounding box regression for BG anymore.
if pos_inds.any():
if self.reg_decoded_bbox:
# When the regression loss (e.g. `IouLoss`,
# `GIouLoss`, `DIouLoss`) is applied directly on
# the decoded bounding boxes, it decodes the
# already encoded coordinates to absolute format.
bbox_pred = self.bbox_coder.decode(rois[:, 1:], bbox_pred)
if self.reg_class_agnostic:
pos_bbox_pred = bbox_pred.view(
bbox_pred.size(0), 4)[pos_inds.type(torch.bool)]
else:
pos_bbox_pred = bbox_pred.view(
bbox_pred.size(0), -1,
4)[pos_inds.type(torch.bool),
labels[pos_inds.type(torch.bool)]]
losses['loss_bbox'] = self.loss_bbox(
pos_bbox_pred,
bbox_targets[pos_inds.type(torch.bool)],
bbox_weights[pos_inds.type(torch.bool)],
avg_factor=bbox_targets.size(0),
reduction_override=reduction_override)
else:
losses['loss_bbox'] = bbox_pred[pos_inds].sum()
return losses

    @force_fp32(apply_to=('cls_score', 'bbox_pred'))
def get_bboxes(self,
rois,
cls_score,
bbox_pred,
img_shape,
scale_factor,
rescale=False,
cfg=None):
"""Transform network output for a batch into bbox predictions.
Args:
            rois (Tensor): Boxes to be transformed, with shape
                (num_boxes, 5), where the last dimension is arranged as
                (batch_index, x1, y1, x2, y2).
            cls_score (Tensor): Box scores, with shape
                (num_boxes, num_classes + 1).
            bbox_pred (Tensor, optional): Box energies / deltas, with
                shape (num_boxes, num_classes * 4).
            img_shape (Sequence[int], optional): Maximum bounds for boxes,
                specifies (H, W, C) or (H, W).
            scale_factor (ndarray): Scale factor of the image, arranged as
                (w_scale, h_scale, w_scale, h_scale).
rescale (bool): If True, return boxes in original image space.
Default: False.
cfg (obj:`ConfigDict`): `test_cfg` of Bbox Head. Default: None
Returns:
tuple[Tensor, Tensor]:
First tensor is `det_bboxes`, has the shape
(num_boxes, 5) and last
dimension 5 represent (tl_x, tl_y, br_x, br_y, score).
Second tensor is the labels with shape (num_boxes, ).
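
        Example:
            >>> # A minimal sketch with random scores and `cfg=None`, so
            >>> # the raw (bboxes, scores) pair is returned without NMS.
            >>> self = BBoxHead(num_classes=3)
            >>> rois = torch.tensor([[0., 10., 10., 40., 40.]])
            >>> cls_score = torch.rand(1, 4)
            >>> bbox_pred = torch.rand(1, 12)
            >>> bboxes, scores = self.get_bboxes(
            ...     rois, cls_score, bbox_pred, img_shape=(128, 128),
            ...     scale_factor=1.0, cfg=None)
            >>> bboxes.shape, scores.shape
            (torch.Size([1, 12]), torch.Size([1, 4]))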
"""
        # some losses (e.g. Seesaw loss) may have custom activation
if self.custom_cls_channels:
scores = self.loss_cls.get_activation(cls_score)
else:
scores = F.softmax(
cls_score, dim=-1) if cls_score is not None else None
        # bbox_pred would be None in some detectors when with_reg is False,
        # e.g. Grid R-CNN.
if bbox_pred is not None:
bboxes = self.bbox_coder.decode(
rois[..., 1:], bbox_pred, max_shape=img_shape)
else:
bboxes = rois[:, 1:].clone()
if img_shape is not None:
bboxes[:, [0, 2]].clamp_(min=0, max=img_shape[1])
bboxes[:, [1, 3]].clamp_(min=0, max=img_shape[0])
if rescale and bboxes.size(0) > 0:
scale_factor = bboxes.new_tensor(scale_factor)
bboxes = (bboxes.view(bboxes.size(0), -1, 4) / scale_factor).view(
bboxes.size()[0], -1)
if cfg is None:
return bboxes, scores
else:
det_bboxes, det_labels = multiclass_nms(bboxes, scores,
cfg.score_thr, cfg.nms,
cfg.max_per_img)
return det_bboxes, det_labels

    @force_fp32(apply_to=('bbox_preds', ))
def refine_bboxes(self, rois, labels, bbox_preds, pos_is_gts, img_metas):
"""Refine bboxes during training.
Args:
            rois (Tensor): Shape (n*bs, 5), where n is the number of images
                per GPU and bs is the number of sampled RoIs per image.
                The first column is the image id and the next 4 columns
                are x1, y1, x2, y2.
labels (Tensor): Shape (n*bs, ).
bbox_preds (Tensor): Shape (n*bs, 4) or (n*bs, 4*#class).
pos_is_gts (list[Tensor]): Flags indicating if each positive bbox
is a gt bbox.
img_metas (list[dict]): Meta info of each image.
Returns:
list[Tensor]: Refined bboxes of each image in a mini-batch.
Example:
>>> # xdoctest: +REQUIRES(module:kwarray)
>>> import kwarray
>>> import numpy as np
>>> from mmdet.core.bbox.demodata import random_boxes
>>> self = BBoxHead(reg_class_agnostic=True)
>>> n_roi = 2
>>> n_img = 4
>>> scale = 512
>>> rng = np.random.RandomState(0)
>>> img_metas = [{'img_shape': (scale, scale)}
... for _ in range(n_img)]
>>> # Create rois in the expected format
>>> roi_boxes = random_boxes(n_roi, scale=scale, rng=rng)
>>> img_ids = torch.randint(0, n_img, (n_roi,))
>>> img_ids = img_ids.float()
>>> rois = torch.cat([img_ids[:, None], roi_boxes], dim=1)
>>> # Create other args
>>> labels = torch.randint(0, 2, (n_roi,)).long()
>>> bbox_preds = random_boxes(n_roi, scale=scale, rng=rng)
>>> # For each image, pretend random positive boxes are gts
            >>> is_label_pos = (labels.numpy() > 0).astype(int)
>>> lbl_per_img = kwarray.group_items(is_label_pos,
... img_ids.numpy())
>>> pos_per_img = [sum(lbl_per_img.get(gid, []))
... for gid in range(n_img)]
>>> pos_is_gts = [
>>> torch.randint(0, 2, (npos,)).byte().sort(
>>> descending=True)[0]
>>> for npos in pos_per_img
>>> ]
>>> bboxes_list = self.refine_bboxes(rois, labels, bbox_preds,
>>> pos_is_gts, img_metas)
>>> print(bboxes_list)
"""
img_ids = rois[:, 0].long().unique(sorted=True)
assert img_ids.numel() <= len(img_metas)
bboxes_list = []
for i in range(len(img_metas)):
inds = torch.nonzero(
rois[:, 0] == i, as_tuple=False).squeeze(dim=1)
num_rois = inds.numel()
bboxes_ = rois[inds, 1:]
label_ = labels[inds]
bbox_pred_ = bbox_preds[inds]
img_meta_ = img_metas[i]
pos_is_gts_ = pos_is_gts[i]
bboxes = self.regress_by_class(bboxes_, label_, bbox_pred_,
img_meta_)
# filter gt bboxes
pos_keep = 1 - pos_is_gts_
keep_inds = pos_is_gts_.new_ones(num_rois)
keep_inds[:len(pos_is_gts_)] = pos_keep
bboxes_list.append(bboxes[keep_inds.type(torch.bool)])
return bboxes_list

    @force_fp32(apply_to=('bbox_pred', ))
def regress_by_class(self, rois, label, bbox_pred, img_meta):
"""Regress the bbox for the predicted class. Used in Cascade R-CNN.
Args:
rois (Tensor): Rois from `rpn_head` or last stage
`bbox_head`, has shape (num_proposals, 4) or
(num_proposals, 5).
label (Tensor): Only used when `self.reg_class_agnostic`
is False, has shape (num_proposals, ).
bbox_pred (Tensor): Regression prediction of
current stage `bbox_head`. When `self.reg_class_agnostic`
is False, it has shape (n, num_classes * 4), otherwise
it has shape (n, 4).
img_meta (dict): Image meta info.
Returns:
Tensor: Regressed bboxes, the same shape as input rois.
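
        Example:
            >>> # A minimal sketch with class-aware regression, so the
            >>> # four deltas of the predicted class are gathered before
            >>> # decoding.
            >>> self = BBoxHead(num_classes=3, reg_class_agnostic=False)
            >>> rois = torch.tensor([[10., 10., 40., 40.]])
            >>> label = torch.tensor([1])
            >>> bbox_pred = torch.rand(1, 12)
            >>> img_meta = {'img_shape': (128, 128, 3)}
            >>> new_rois = self.regress_by_class(rois, label, bbox_pred,
            ...                                  img_meta)
            >>> new_rois.shape
            torch.Size([1, 4])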
"""
assert rois.size(1) == 4 or rois.size(1) == 5, repr(rois.shape)
if not self.reg_class_agnostic:
label = label * 4
inds = torch.stack((label, label + 1, label + 2, label + 3), 1)
bbox_pred = torch.gather(bbox_pred, 1, inds)
assert bbox_pred.size(1) == 4
max_shape = img_meta['img_shape']
if rois.size(1) == 4:
new_rois = self.bbox_coder.decode(
rois, bbox_pred, max_shape=max_shape)
else:
bboxes = self.bbox_coder.decode(
rois[:, 1:], bbox_pred, max_shape=max_shape)
new_rois = torch.cat((rois[:, [0]], bboxes), dim=1)
return new_rois

    def onnx_export(self,
rois,
cls_score,
bbox_pred,
img_shape,
cfg=None,
**kwargs):
"""Transform network output for a batch into bbox predictions.
Args:
rois (Tensor): Boxes to be transformed.
Has shape (B, num_boxes, 5)
            cls_score (Tensor): Box scores with shape
                (B, num_boxes, num_classes + 1), where the extra channel
                is the background class.
            bbox_pred (Tensor, optional): Box energies / deltas with
                shape (B, num_boxes, num_classes * 4).
img_shape (torch.Tensor): Shape of image.
cfg (obj:`ConfigDict`): `test_cfg` of Bbox Head. Default: None
Returns:
tuple[Tensor, Tensor]: dets of shape [N, num_det, 5]
and class labels of shape [N, num_det].
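
        Example:
            >>> # xdoctest: +SKIP
            >>> # A sketch of the deployment-time call; it assumes this
            >>> # runs while tracing with torch.onnx.export and that
            >>> # `test_cfg` (hypothetical here) defines `nms`,
            >>> # `score_thr` and `max_per_img`.
            >>> self = BBoxHead(num_classes=3)
            >>> rois = torch.rand(1, 100, 5)
            >>> cls_score = torch.rand(1, 100, 4)
            >>> bbox_pred = torch.rand(1, 100, 12)
            >>> img_shape = torch.tensor([320., 320.])
            >>> dets, labels = self.onnx_export(
            ...     rois, cls_score, bbox_pred, img_shape, cfg=test_cfg)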
"""
assert rois.ndim == 3, 'Only support export two stage ' \
'model to ONNX ' \
'with batch dimension. '
if self.custom_cls_channels:
scores = self.loss_cls.get_activation(cls_score)
else:
scores = F.softmax(
cls_score, dim=-1) if cls_score is not None else None
if bbox_pred is not None:
bboxes = self.bbox_coder.decode(
rois[..., 1:], bbox_pred, max_shape=img_shape)
else:
bboxes = rois[..., 1:].clone()
if img_shape is not None:
max_shape = bboxes.new_tensor(img_shape)[..., :2]
min_xy = bboxes.new_tensor(0)
max_xy = torch.cat(
[max_shape] * 2, dim=-1).flip(-1).unsqueeze(-2)
bboxes = torch.where(bboxes < min_xy, min_xy, bboxes)
bboxes = torch.where(bboxes > max_xy, max_xy, bboxes)
# Replace multiclass_nms with ONNX::NonMaxSuppression in deployment
from mmdet.core.export import add_dummy_nms_for_onnx
max_output_boxes_per_class = cfg.nms.get('max_output_boxes_per_class',
cfg.max_per_img)
iou_threshold = cfg.nms.get('iou_threshold', 0.5)
score_threshold = cfg.score_thr
nms_pre = cfg.get('deploy_nms_pre', -1)
scores = scores[..., :self.num_classes]
if self.reg_class_agnostic:
return add_dummy_nms_for_onnx(
bboxes,
scores,
max_output_boxes_per_class,
iou_threshold,
score_threshold,
pre_top_k=nms_pre,
after_top_k=cfg.max_per_img)
else:
batch_size = scores.shape[0]
labels = torch.arange(
self.num_classes, dtype=torch.long).to(scores.device)
labels = labels.view(1, 1, -1).expand_as(scores)
labels = labels.reshape(batch_size, -1)
scores = scores.reshape(batch_size, -1)
bboxes = bboxes.reshape(batch_size, -1, 4)
max_size = torch.max(img_shape)
# Offset bboxes of each class so that bboxes of different labels
# do not overlap.
offsets = (labels * max_size + 1).unsqueeze(2)
bboxes_for_nms = bboxes + offsets
batch_dets, labels = add_dummy_nms_for_onnx(
bboxes_for_nms,
scores.unsqueeze(2),
max_output_boxes_per_class,
iou_threshold,
score_threshold,
pre_top_k=nms_pre,
after_top_k=cfg.max_per_img,
labels=labels)
# Offset the bboxes back after dummy nms.
offsets = (labels * max_size + 1).unsqueeze(2)
# Indexing + inplace operation fails with dynamic shape in ONNX
# original style: batch_dets[..., :4] -= offsets
bboxes, scores = batch_dets[..., 0:4], batch_dets[..., 4:5]
bboxes -= offsets
batch_dets = torch.cat([bboxes, scores], dim=2)
return batch_dets, labels