Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
import torch | |
import torch.nn.functional as F | |
from mmcv.runner import force_fp32 | |
from ..builder import HEADS | |
from ..losses import smooth_l1_loss | |
from .ascend_anchor_head import AscendAnchorHead | |
from .ssd_head import SSDHead | |
class AscendSSDHead(SSDHead, AscendAnchorHead): | |
"""Ascend SSD head used in https://arxiv.org/abs/1512.02325. | |
Args: | |
num_classes (int): Number of categories excluding the background | |
category. | |
in_channels (int): Number of channels in the input feature map. | |
stacked_convs (int): Number of conv layers in cls and reg tower. | |
Default: 0. | |
feat_channels (int): Number of hidden channels when stacked_convs | |
> 0. Default: 256. | |
use_depthwise (bool): Whether to use DepthwiseSeparableConv. | |
Default: False. | |
conv_cfg (dict): Dictionary to construct and config conv layer. | |
Default: None. | |
norm_cfg (dict): Dictionary to construct and config norm layer. | |
Default: None. | |
act_cfg (dict): Dictionary to construct and config activation layer. | |
Default: None. | |
anchor_generator (dict): Config dict for anchor generator | |
bbox_coder (dict): Config of bounding box coder. | |
reg_decoded_bbox (bool): If true, the regression loss would be | |
applied directly on decoded bounding boxes, converting both | |
the predicted boxes and regression targets to absolute | |
coordinates format. Default False. It should be `True` when | |
using `IoULoss`, `GIoULoss`, or `DIoULoss` in the bbox head. | |
train_cfg (dict): Training config of anchor head. | |
test_cfg (dict): Testing config of anchor head. | |
init_cfg (dict or list[dict], optional): Initialization config dict. | |
""" # noqa: W605 | |
def __init__(self, | |
num_classes=80, | |
in_channels=(512, 1024, 512, 256, 256, 256), | |
stacked_convs=0, | |
feat_channels=256, | |
use_depthwise=False, | |
conv_cfg=None, | |
norm_cfg=None, | |
act_cfg=None, | |
anchor_generator=dict( | |
type='SSDAnchorGenerator', | |
scale_major=False, | |
input_size=300, | |
strides=[8, 16, 32, 64, 100, 300], | |
ratios=([2], [2, 3], [2, 3], [2, 3], [2], [2]), | |
basesize_ratio_range=(0.1, 0.9)), | |
bbox_coder=dict( | |
type='DeltaXYWHBBoxCoder', | |
clip_border=True, | |
target_means=[.0, .0, .0, .0], | |
target_stds=[1.0, 1.0, 1.0, 1.0], | |
), | |
reg_decoded_bbox=False, | |
train_cfg=None, | |
test_cfg=None, | |
init_cfg=dict( | |
type='Xavier', | |
layer='Conv2d', | |
distribution='uniform', | |
bias=0)): | |
super(AscendSSDHead, self).__init__( | |
num_classes=num_classes, | |
in_channels=in_channels, | |
stacked_convs=stacked_convs, | |
feat_channels=feat_channels, | |
use_depthwise=use_depthwise, | |
conv_cfg=conv_cfg, | |
norm_cfg=norm_cfg, | |
act_cfg=act_cfg, | |
anchor_generator=anchor_generator, | |
bbox_coder=bbox_coder, | |
reg_decoded_bbox=reg_decoded_bbox, | |
train_cfg=train_cfg, | |
test_cfg=test_cfg, | |
init_cfg=init_cfg) | |
assert self.reg_decoded_bbox is False, \ | |
'reg_decoded_bbox only support False now.' | |
def get_static_anchors(self, featmap_sizes, img_metas, device='cuda'): | |
"""Get static anchors according to feature map sizes. | |
Args: | |
featmap_sizes (list[tuple]): Multi-level feature map sizes. | |
img_metas (list[dict]): Image meta info. | |
device (torch.device | str): Device for returned tensors | |
Returns: | |
tuple: | |
anchor_list (list[Tensor]): Anchors of each image. | |
valid_flag_list (list[Tensor]): Valid flags of each image. | |
""" | |
if not hasattr(self, 'static_anchors') or \ | |
not hasattr(self, 'static_valid_flags'): | |
static_anchors, static_valid_flags = self.get_anchors( | |
featmap_sizes, img_metas, device) | |
self.static_anchors = static_anchors | |
self.static_valid_flags = static_valid_flags | |
return self.static_anchors, self.static_valid_flags | |
def get_targets(self, | |
anchor_list, | |
valid_flag_list, | |
gt_bboxes_list, | |
img_metas, | |
gt_bboxes_ignore_list=None, | |
gt_labels_list=None, | |
label_channels=1, | |
unmap_outputs=True, | |
return_sampling_results=False, | |
return_level=True): | |
"""Compute regression and classification targets for anchors in | |
multiple images. | |
Args: | |
anchor_list (list[list[Tensor]]): Multi level anchors of each | |
image. The outer list indicates images, and the inner list | |
corresponds to feature levels of the image. Each element of | |
the inner list is a tensor of shape (num_anchors, 4). | |
valid_flag_list (list[list[Tensor]]): Multi level valid flags of | |
each image. The outer list indicates images, and the inner list | |
corresponds to feature levels of the image. Each element of | |
the inner list is a tensor of shape (num_anchors, ) | |
gt_bboxes_list (list[Tensor]): Ground truth bboxes of each image. | |
img_metas (list[dict]): Meta info of each image. | |
gt_bboxes_ignore_list (list[Tensor]): Ground truth bboxes to be | |
ignored. | |
gt_labels_list (list[Tensor]): Ground truth labels of each box. | |
label_channels (int): Channel of label. | |
unmap_outputs (bool): Whether to map outputs back to the original | |
set of anchors. | |
return_sampling_results (bool): Whether to return the result of | |
sample. | |
return_level (bool): Whether to map outputs back to the levels | |
of feature map sizes. | |
Returns: | |
tuple: Usually returns a tuple containing learning targets. | |
- labels_list (list[Tensor]): Labels of each level. | |
- label_weights_list (list[Tensor]): Label weights of each | |
level. | |
- bbox_targets_list (list[Tensor]): BBox targets of each level. | |
- bbox_weights_list (list[Tensor]): BBox weights of each level. | |
- num_total_pos (int): Number of positive samples in all | |
images. | |
- num_total_neg (int): Number of negative samples in all | |
images. | |
additional_returns: This function enables user-defined returns from | |
`self._get_targets_single`. These returns are currently refined | |
to properties at each feature map (i.e. having HxW dimension). | |
The results will be concatenated after the end | |
""" | |
return AscendAnchorHead.get_targets( | |
self, | |
anchor_list, | |
valid_flag_list, | |
gt_bboxes_list, | |
img_metas, | |
gt_bboxes_ignore_list, | |
gt_labels_list, | |
label_channels, | |
unmap_outputs, | |
return_sampling_results, | |
return_level, | |
) | |
def batch_loss(self, batch_cls_score, batch_bbox_pred, batch_anchor, | |
batch_labels, batch_label_weights, batch_bbox_targets, | |
batch_bbox_weights, batch_pos_mask, batch_neg_mask, | |
num_total_samples): | |
"""Compute loss of all images. | |
Args: | |
batch_cls_score (Tensor): Box scores for all image | |
Has shape (num_imgs, num_total_anchors, num_classes). | |
batch_bbox_pred (Tensor): Box energies / deltas for all image | |
level with shape (num_imgs, num_total_anchors, 4). | |
batch_anchor (Tensor): Box reference for all image with shape | |
(num_imgs, num_total_anchors, 4). | |
batch_labels (Tensor): Labels of all anchors with shape | |
(num_imgs, num_total_anchors,). | |
batch_label_weights (Tensor): Label weights of all anchor with | |
shape (num_imgs, num_total_anchors,) | |
batch_bbox_targets (Tensor): BBox regression targets of all anchor | |
weight shape (num_imgs, num_total_anchors, 4). | |
batch_bbox_weights (Tensor): BBox regression loss weights of | |
all anchor with shape (num_imgs, num_total_anchors, 4). | |
batch_pos_mask (Tensor): Positive samples mask in all images. | |
batch_neg_mask (Tensor): negative samples mask in all images. | |
num_total_samples (int): If sampling, num total samples equal to | |
the number of total anchors; Otherwise, it is the number of | |
positive anchors. | |
Returns: | |
dict[str, Tensor]: A dictionary of loss components. | |
""" | |
num_images, num_anchors, _ = batch_anchor.size() | |
batch_loss_cls_all = F.cross_entropy( | |
batch_cls_score.view((-1, self.cls_out_channels)), | |
batch_labels.view(-1), | |
reduction='none').view( | |
batch_label_weights.size()) * batch_label_weights | |
# # FG cat_id: [0, num_classes -1], BG cat_id: num_classes | |
batch_num_pos_samples = torch.sum(batch_pos_mask, dim=1) | |
batch_num_neg_samples = \ | |
self.train_cfg.neg_pos_ratio * batch_num_pos_samples | |
batch_num_neg_samples_max = torch.sum(batch_neg_mask, dim=1) | |
batch_num_neg_samples = torch.min(batch_num_neg_samples, | |
batch_num_neg_samples_max) | |
batch_topk_loss_cls_neg, _ = torch.topk( | |
batch_loss_cls_all * batch_neg_mask, k=num_anchors, dim=1) | |
batch_loss_cls_pos = torch.sum( | |
batch_loss_cls_all * batch_pos_mask, dim=1) | |
anchor_index = torch.arange( | |
end=num_anchors, dtype=torch.float, | |
device=batch_anchor.device).view((1, -1)) | |
topk_loss_neg_mask = (anchor_index < batch_num_neg_samples.view( | |
-1, 1)).float() | |
batch_loss_cls_neg = torch.sum( | |
batch_topk_loss_cls_neg * topk_loss_neg_mask, dim=1) | |
loss_cls = \ | |
(batch_loss_cls_pos + batch_loss_cls_neg) / num_total_samples | |
if self.reg_decoded_bbox: | |
# TODO: support self.reg_decoded_bbox is True | |
raise RuntimeError | |
loss_bbox_all = smooth_l1_loss( | |
batch_bbox_pred, | |
batch_bbox_targets, | |
batch_bbox_weights, | |
reduction='none', | |
beta=self.train_cfg.smoothl1_beta, | |
avg_factor=num_total_samples) | |
eps = torch.finfo(torch.float32).eps | |
sum_dim = (i for i in range(1, len(loss_bbox_all.size()))) | |
loss_bbox = loss_bbox_all.sum(tuple(sum_dim)) / ( | |
num_total_samples + eps) | |
return loss_cls[None], loss_bbox | |
def loss(self, | |
cls_scores, | |
bbox_preds, | |
gt_bboxes, | |
gt_labels, | |
img_metas, | |
gt_bboxes_ignore=None): | |
"""Compute losses of the head. | |
Args: | |
cls_scores (list[Tensor]): Box scores for each scale level | |
Has shape (N, num_anchors * num_classes, H, W) | |
bbox_preds (list[Tensor]): Box energies / deltas for each scale | |
level with shape (N, num_anchors * 4, H, W) | |
gt_bboxes (list[Tensor]): each item are the truth boxes for each | |
image in [tl_x, tl_y, br_x, br_y] format. | |
gt_labels (list[Tensor]): class indices corresponding to each box | |
img_metas (list[dict]): Meta information of each image, e.g., | |
image size, scaling factor, etc. | |
gt_bboxes_ignore (None | list[Tensor]): specify which bounding | |
boxes can be ignored when computing the loss. | |
Returns: | |
dict[str, Tensor]: A dictionary of loss components. | |
""" | |
featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores] | |
assert len(featmap_sizes) == self.prior_generator.num_levels | |
device = cls_scores[0].device | |
anchor_list, valid_flag_list = self.get_anchors( | |
featmap_sizes, img_metas, device=device) | |
cls_reg_targets = self.get_targets( | |
anchor_list, | |
valid_flag_list, | |
gt_bboxes, | |
img_metas, | |
gt_bboxes_ignore_list=gt_bboxes_ignore, | |
gt_labels_list=gt_labels, | |
label_channels=1, | |
unmap_outputs=True, | |
return_level=False) | |
if cls_reg_targets is None: | |
return None | |
(batch_labels, batch_label_weights, batch_bbox_targets, | |
batch_bbox_weights, batch_pos_mask, batch_neg_mask, sampling_result, | |
num_total_pos, num_total_neg, batch_anchors) = cls_reg_targets | |
num_imgs = len(img_metas) | |
batch_cls_score = torch.cat([ | |
s.permute(0, 2, 3, 1).reshape(num_imgs, -1, self.cls_out_channels) | |
for s in cls_scores | |
], 1) | |
batch_bbox_pred = torch.cat([ | |
b.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4) for b in bbox_preds | |
], -2) | |
batch_losses_cls, batch_losses_bbox = self.batch_loss( | |
batch_cls_score, batch_bbox_pred, batch_anchors, batch_labels, | |
batch_label_weights, batch_bbox_targets, batch_bbox_weights, | |
batch_pos_mask, batch_neg_mask, num_total_pos) | |
losses_cls = [ | |
batch_losses_cls[:, index_imgs] for index_imgs in range(num_imgs) | |
] | |
losses_bbox = [losses_bbox for losses_bbox in batch_losses_bbox] | |
return dict(loss_cls=losses_cls, loss_bbox=losses_bbox) | |