Spaces:
Runtime error
Runtime error
File size: 39,707 Bytes
51f6859 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 |
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import Conv2d, Linear, build_activation_layer
from mmcv.cnn.bricks.transformer import FFN, build_positional_encoding
from mmcv.runner import force_fp32
from mmdet.core import (bbox_cxcywh_to_xyxy, bbox_xyxy_to_cxcywh,
build_assigner, build_sampler, multi_apply,
reduce_mean)
from mmdet.models.utils import build_transformer
from ..builder import HEADS, build_loss
from .anchor_free_head import AnchorFreeHead
@HEADS.register_module()
class DETRHead(AnchorFreeHead):
"""Implements the DETR transformer head.
See `paper: End-to-End Object Detection with Transformers
<https://arxiv.org/pdf/2005.12872>`_ for details.
Args:
num_classes (int): Number of categories excluding the background.
in_channels (int): Number of channels in the input feature map.
num_query (int): Number of query in Transformer.
num_reg_fcs (int, optional): Number of fully-connected layers used in
`FFN`, which is then used for the regression head. Default 2.
transformer (obj:`mmcv.ConfigDict`|dict): Config for transformer.
Default: None.
sync_cls_avg_factor (bool): Whether to sync the avg_factor of
all ranks. Default to False.
positional_encoding (obj:`mmcv.ConfigDict`|dict):
Config for position encoding.
loss_cls (obj:`mmcv.ConfigDict`|dict): Config of the
classification loss. Default `CrossEntropyLoss`.
loss_bbox (obj:`mmcv.ConfigDict`|dict): Config of the
regression loss. Default `L1Loss`.
loss_iou (obj:`mmcv.ConfigDict`|dict): Config of the
regression iou loss. Default `GIoULoss`.
train_cfg (obj:`mmcv.ConfigDict`|dict): Training config of
transformer head.
test_cfg (obj:`mmcv.ConfigDict`|dict): Testing config of
transformer head.
init_cfg (dict or list[dict], optional): Initialization config dict.
Default: None
"""
_version = 2
def __init__(self,
             num_classes,
             in_channels,
             num_query=100,
             num_reg_fcs=2,
             transformer=None,
             sync_cls_avg_factor=False,
             positional_encoding=dict(
                 type='SinePositionalEncoding',
                 num_feats=128,
                 normalize=True),
             loss_cls=dict(
                 type='CrossEntropyLoss',
                 bg_cls_weight=0.1,
                 use_sigmoid=False,
                 loss_weight=1.0,
                 class_weight=1.0),
             loss_bbox=dict(type='L1Loss', loss_weight=5.0),
             loss_iou=dict(type='GIoULoss', loss_weight=2.0),
             train_cfg=dict(
                 assigner=dict(
                     type='HungarianAssigner',
                     cls_cost=dict(type='ClassificationCost', weight=1.),
                     reg_cost=dict(type='BBoxL1Cost', weight=5.0),
                     iou_cost=dict(
                         type='IoUCost', iou_mode='giou', weight=2.0))),
             test_cfg=dict(max_per_img=100),
             init_cfg=None,
             **kwargs):
    """Build the losses, assigner/sampler, transformer and head layers.

    Args:
        num_classes (int): Number of categories excluding the background.
        in_channels (int): Number of channels in the input feature map.
        num_query (int): Number of queries in the transformer.
        num_reg_fcs (int): Number of fully-connected layers passed to the
            regression `FFN`.
        transformer (obj:`mmcv.ConfigDict`|dict): Transformer config.
            Although the default is None, this constructor calls
            ``transformer.get(...)``, so a config has to be supplied.
        sync_cls_avg_factor (bool): Whether to average the classification
            avg_factor over ranks with ``reduce_mean``.
        positional_encoding (obj:`mmcv.ConfigDict`|dict): Positional
            encoding config; must contain ``num_feats``.
        loss_cls (obj:`mmcv.ConfigDict`|dict): Classification loss config.
            It is NOT modified in place: a shallow copy is edited when the
            per-class weight tensor is derived.
        loss_bbox (obj:`mmcv.ConfigDict`|dict): Regression loss config.
        loss_iou (obj:`mmcv.ConfigDict`|dict): IoU loss config.
        train_cfg (obj:`mmcv.ConfigDict`|dict): Training config; when
            truthy it must contain an ``assigner`` entry.
        test_cfg (obj:`mmcv.ConfigDict`|dict): Testing config.
        init_cfg (dict or list[dict], optional): Initialization config.
        **kwargs: Accepted for signature compatibility; not used here.
    """
    # NOTE: `super(AnchorFreeHead, self)` deliberately skips
    # `AnchorFreeHead.__init__` and calls the next initializer in the MRO,
    # because running the `AnchorFreeHead` initialization is inconvenient
    # for this head.
    super(AnchorFreeHead, self).__init__(init_cfg)
    self.bg_cls_weight = 0
    self.sync_cls_avg_factor = sync_cls_avg_factor
    class_weight = loss_cls.get('class_weight', None)
    if class_weight is not None and (self.__class__ is DETRHead):
        assert isinstance(class_weight, float), 'Expected ' \
            'class_weight to have type float. Found ' \
            f'{type(class_weight)}.'
        # NOTE following the official DETR repo, bg_cls_weight means
        # relative classification weight of the no-object class.
        bg_cls_weight = loss_cls.get('bg_cls_weight', class_weight)
        assert isinstance(bg_cls_weight, float), 'Expected ' \
            'bg_cls_weight to have type float. Found ' \
            f'{type(bg_cls_weight)}.'
        class_weight = torch.ones(num_classes + 1) * class_weight
        # set background class as the last index
        class_weight[num_classes] = bg_cls_weight
        # Work on a shallow copy: `loss_cls` may be the shared default
        # argument or the caller's config. Editing it in place would leave
        # a tensor under 'class_weight' and drop 'bg_cls_weight', which
        # makes the float assertion above fail on the next construction.
        loss_cls = loss_cls.copy()
        loss_cls.update({'class_weight': class_weight})
        if 'bg_cls_weight' in loss_cls:
            loss_cls.pop('bg_cls_weight')
        self.bg_cls_weight = bg_cls_weight

    if train_cfg:
        assert 'assigner' in train_cfg, 'assigner should be provided '\
            'when train_cfg is set.'
        assigner = train_cfg['assigner']
        # The matcher costs and the loss weights have to agree.
        assert loss_cls['loss_weight'] == assigner['cls_cost']['weight'], \
            'The classification weight for loss and matcher should be' \
            'exactly the same.'
        assert loss_bbox['loss_weight'] == assigner['reg_cost'][
            'weight'], 'The regression L1 weight for loss and matcher ' \
            'should be exactly the same.'
        assert loss_iou['loss_weight'] == assigner['iou_cost']['weight'], \
            'The regression iou weight for loss and matcher should be' \
            'exactly the same.'
        self.assigner = build_assigner(assigner)
        # DETR sampling=False, so use PseudoSampler
        sampler_cfg = dict(type='PseudoSampler')
        self.sampler = build_sampler(sampler_cfg, context=self)

    self.num_query = num_query
    self.num_classes = num_classes
    self.in_channels = in_channels
    self.num_reg_fcs = num_reg_fcs
    self.train_cfg = train_cfg
    self.test_cfg = test_cfg
    self.fp16_enabled = False
    self.loss_cls = build_loss(loss_cls)
    self.loss_bbox = build_loss(loss_bbox)
    self.loss_iou = build_loss(loss_iou)

    # Sigmoid classification has no explicit background channel; softmax
    # classification reserves one extra channel for it.
    if self.loss_cls.use_sigmoid:
        self.cls_out_channels = num_classes
    else:
        self.cls_out_channels = num_classes + 1

    # Reuse the transformer's activation config for the regression branch.
    self.act_cfg = transformer.get('act_cfg',
                                   dict(type='ReLU', inplace=True))
    self.activate = build_activation_layer(self.act_cfg)
    self.positional_encoding = build_positional_encoding(
        positional_encoding)
    self.transformer = build_transformer(transformer)
    self.embed_dims = self.transformer.embed_dims
    assert 'num_feats' in positional_encoding
    num_feats = positional_encoding['num_feats']
    assert num_feats * 2 == self.embed_dims, 'embed_dims should' \
        f' be exactly 2 times of num_feats. Found {self.embed_dims}' \
        f' and {num_feats}.'
    self._init_layers()
def _init_layers(self):
    """Create the projection, classification, regression and query layers.

    The layers are created in a fixed order: input projection,
    classification FC, regression FFN, regression FC, query embedding.
    """
    embed_dims = self.embed_dims

    # 1x1 convolution mapping backbone channels to the transformer width.
    self.input_proj = Conv2d(self.in_channels, embed_dims, kernel_size=1)
    self.fc_cls = Linear(embed_dims, self.cls_out_channels)
    self.reg_ffn = FFN(
        embed_dims,
        embed_dims,
        self.num_reg_fcs,
        self.act_cfg,
        dropout=0.0,
        add_residual=False)
    # Four outputs per query: one box.
    self.fc_reg = Linear(embed_dims, 4)
    self.query_embedding = nn.Embedding(self.num_query, embed_dims)
def init_weights(self):
    """Initialize the head by delegating to the transformer's own
    ``init_weights``."""
    # The initialization for transformer is important
    transformer = self.transformer
    transformer.init_weights()
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                          missing_keys, unexpected_keys, error_msgs):
    """Load a checkpoint, renaming legacy parameter names first.

    When the checkpoint metadata has no ``version`` or a version below 2,
    and this object is exactly a ``DETRHead`` (not a subclass), keys in
    ``state_dict`` containing old sub-module names are renamed in place to
    the current names before the default loading logic runs.

    Args:
        state_dict (dict): Checkpoint mapping; may be modified in place.
        prefix (str): Forwarded unchanged to the parent implementation.
        local_metadata (dict): Metadata; only ``'version'`` is read here.
        strict, missing_keys, unexpected_keys, error_msgs: Forwarded
            unchanged to the parent implementation.
    """
    # NOTE: `super(AnchorFreeHead, self)` is used on purpose so that
    # `AnchorFreeHead._load_from_state_dict` is skipped; the default
    # `Module._load_from_state_dict` is enough here.

    # The names of some parameters have been changed.
    version = local_metadata.get('version', None)
    if (version is None or version < 2) and self.__class__ is DETRHead:
        convert_dict = {
            '.self_attn.': '.attentions.0.',
            '.ffn.': '.ffns.0.',
            '.multihead_attn.': '.attentions.1.',
            '.decoder.norm.': '.decoder.post_norm.'
        }
        # Iterate over a snapshot because the dict is modified below.
        for old_key in list(state_dict.keys()):
            # Apply every rename to the name first and move the entry once.
            # Deleting on the first match and looking the old key up again
            # on a later match would raise KeyError.
            new_key = old_key
            for ori_key, convert_key in convert_dict.items():
                if ori_key in new_key:
                    new_key = new_key.replace(ori_key, convert_key)
            if new_key != old_key:
                state_dict[new_key] = state_dict.pop(old_key)

    super(AnchorFreeHead,
          self)._load_from_state_dict(state_dict, prefix, local_metadata,
                                      strict, missing_keys,
                                      unexpected_keys, error_msgs)
def forward(self, feats, img_metas):
    """Run ``forward_single`` on every feature level.

    Args:
        feats (tuple[Tensor]): Features from the upstream network, each is
            a 4D-tensor.
        img_metas (list[dict]): List of image information.

    Returns:
        tuple[list[Tensor], list[Tensor]]: Outputs for all scale levels.

            - all_cls_scores_list (list[Tensor]): Classification scores \
            for each scale level. Each is a 4D-tensor with shape \
            [nb_dec, bs, num_query, cls_out_channels]. Note \
            `cls_out_channels` should include background.
            - all_bbox_preds_list (list[Tensor]): Sigmoid regression \
            outputs for each scale level. Each is a 4D-tensor with \
            normalized coordinate format (cx, cy, w, h) and shape \
            [nb_dec, bs, num_query, 4].
    """
    # Every level receives the same list of image metas.
    metas_per_level = [img_metas] * len(feats)
    return multi_apply(self.forward_single, feats, metas_per_level)
def forward_single(self, x, img_metas):
    """Forward function for a single feature level.

    Args:
        x (Tensor): Input feature from backbone's single stage, shape
            [bs, c, h, w].
        img_metas (list[dict]): List of image information. The first entry
            must provide ``'batch_input_shape'`` and every entry
            ``'img_shape'``.

    Returns:
        all_cls_scores (Tensor): Outputs from the classification head,
            shape [nb_dec, bs, num_query, cls_out_channels]. Note
            cls_out_channels should include background.
        all_bbox_preds (Tensor): Sigmoid outputs from the regression
            head with normalized coordinate format (cx, cy, w, h).
            Shape [nb_dec, bs, num_query, 4].
    """
    # Build the padding masks consumed by the transformer.
    # NOTE following the official DETR repo, non-zero values mark ignored
    # positions and zeros mark valid positions.
    num_imgs = x.size(0)
    pad_h, pad_w = img_metas[0]['batch_input_shape']
    masks = x.new_ones((num_imgs, pad_h, pad_w))
    for idx in range(num_imgs):
        valid_h, valid_w, _ = img_metas[idx]['img_shape']
        masks[idx, :valid_h, :valid_w] = 0

    x = self.input_proj(x)

    # Resize the masks to the spatial size of the projected features.
    feat_size = x.shape[-2:]
    masks = F.interpolate(masks.unsqueeze(1), size=feat_size)
    masks = masks.to(torch.bool).squeeze(1)

    # position encoding: [bs, embed_dim, h, w]
    pos_embed = self.positional_encoding(masks)

    # outs_dec: [nb_dec, bs, num_query, embed_dim]
    query_embed = self.query_embedding.weight
    outs_dec, _ = self.transformer(x, masks, query_embed, pos_embed)

    all_cls_scores = self.fc_cls(outs_dec)
    reg_feats = self.activate(self.reg_ffn(outs_dec))
    all_bbox_preds = self.fc_reg(reg_feats).sigmoid()
    return all_cls_scores, all_bbox_preds
@force_fp32(apply_to=('all_cls_scores_list', 'all_bbox_preds_list'))
def loss(self,
         all_cls_scores_list,
         all_bbox_preds_list,
         gt_bboxes_list,
         gt_labels_list,
         img_metas,
         gt_bboxes_ignore=None):
    """Compute the losses of every decoder layer.

    Only outputs from the last feature level are used for computing
    losses by default.

    Args:
        all_cls_scores_list (list[Tensor]): Classification outputs
            for each feature level. Each is a 4D-tensor with shape
            [nb_dec, bs, num_query, cls_out_channels].
        all_bbox_preds_list (list[Tensor]): Sigmoid regression
            outputs for each feature level. Each is a 4D-tensor with
            normalized coordinate format (cx, cy, w, h) and shape
            [nb_dec, bs, num_query, 4].
        gt_bboxes_list (list[Tensor]): Ground truth bboxes for each image
            with shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
        gt_labels_list (list[Tensor]): Ground truth class indices for each
            image with shape (num_gts, ).
        img_metas (list[dict]): List of image meta information.
        gt_bboxes_ignore (list[Tensor], optional): Must be None; any other
            value trips the assertion below.

    Returns:
        dict[str, Tensor]: ``loss_cls`` / ``loss_bbox`` / ``loss_iou`` hold
        the last decoder layer's losses; ``d{i}.loss_*`` hold those of the
        i-th earlier decoder layer.
    """
    # NOTE by default only the outputs of the last feature scale are used.
    all_cls_scores = all_cls_scores_list[-1]
    all_bbox_preds = all_bbox_preds_list[-1]
    assert gt_bboxes_ignore is None, \
        'Only supports for gt_bboxes_ignore setting to None.'

    # Each decoder layer is supervised with the same targets and metas.
    num_dec_layers = len(all_cls_scores)
    losses_cls, losses_bbox, losses_iou = multi_apply(
        self.loss_single, all_cls_scores, all_bbox_preds,
        [gt_bboxes_list] * num_dec_layers,
        [gt_labels_list] * num_dec_layers,
        [img_metas] * num_dec_layers,
        [gt_bboxes_ignore] * num_dec_layers)

    # loss from the last decoder layer
    loss_dict = dict()
    loss_dict['loss_cls'] = losses_cls[-1]
    loss_dict['loss_bbox'] = losses_bbox[-1]
    loss_dict['loss_iou'] = losses_iou[-1]

    # loss from other decoder layers
    earlier_layers = zip(losses_cls[:-1], losses_bbox[:-1],
                         losses_iou[:-1])
    for layer_idx, (cls_i, bbox_i, iou_i) in enumerate(earlier_layers):
        loss_dict[f'd{layer_idx}.loss_cls'] = cls_i
        loss_dict[f'd{layer_idx}.loss_bbox'] = bbox_i
        loss_dict[f'd{layer_idx}.loss_iou'] = iou_i
    return loss_dict
def loss_single(self,
                cls_scores,
                bbox_preds,
                gt_bboxes_list,
                gt_labels_list,
                img_metas,
                gt_bboxes_ignore_list=None):
    """Loss function for outputs from a single decoder layer of a single
    feature level.

    Args:
        cls_scores (Tensor): Box score logits from a single decoder layer
            for all images. Shape [bs, num_query, cls_out_channels].
        bbox_preds (Tensor): Sigmoid outputs from a single decoder layer
            for all images, with normalized coordinate (cx, cy, w, h) and
            shape [bs, num_query, 4].
        gt_bboxes_list (list[Tensor]): Ground truth bboxes for each image
            with shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
        gt_labels_list (list[Tensor]): Ground truth class indices for each
            image with shape (num_gts, ).
        img_metas (list[dict]): List of image meta information.
        gt_bboxes_ignore_list (list[Tensor], optional): Bounding
            boxes which can be ignored for each image. Default None.

    Returns:
        tuple[Tensor]: ``(loss_cls, loss_bbox, loss_iou)`` for this decoder
        layer.
    """
    num_imgs = cls_scores.size(0)
    per_img_scores = [cls_scores[i] for i in range(num_imgs)]
    per_img_boxes = [bbox_preds[i] for i in range(num_imgs)]
    (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
     num_total_pos, num_total_neg) = self.get_targets(
         per_img_scores, per_img_boxes, gt_bboxes_list, gt_labels_list,
         img_metas, gt_bboxes_ignore_list)
    labels = torch.cat(labels_list, 0)
    label_weights = torch.cat(label_weights_list, 0)
    bbox_targets = torch.cat(bbox_targets_list, 0)
    bbox_weights = torch.cat(bbox_weights_list, 0)

    # ---- classification loss ----
    cls_scores = cls_scores.reshape(-1, self.cls_out_channels)
    # Weighted avg_factor, constructed to match the official DETR repo:
    # positives count 1, negatives count bg_cls_weight.
    cls_avg_factor = num_total_pos * 1.0 + \
        num_total_neg * self.bg_cls_weight
    if self.sync_cls_avg_factor:
        cls_avg_factor = reduce_mean(
            cls_scores.new_tensor([cls_avg_factor]))
    cls_avg_factor = max(cls_avg_factor, 1)
    loss_cls = self.loss_cls(
        cls_scores, labels, label_weights, avg_factor=cls_avg_factor)

    # Average the number of positives via reduce_mean and keep it >= 1 so
    # it can be used as the normalizer of the box losses.
    num_total_pos = loss_cls.new_tensor([num_total_pos])
    num_total_pos = torch.clamp(reduce_mean(num_total_pos), min=1).item()

    def _scale_rows(img_meta, bbox_pred):
        # One (w, h, w, h) row per prediction of this image.
        img_h, img_w, _ = img_meta['img_shape']
        row = bbox_pred.new_tensor([img_w, img_h, img_w, img_h])
        return row.unsqueeze(0).repeat(bbox_pred.size(0), 1)

    # factors used to rescale the normalized boxes
    factors = torch.cat([
        _scale_rows(img_meta, bbox_pred)
        for img_meta, bbox_pred in zip(img_metas, bbox_preds)
    ], 0)

    # DETR regresses the relative position of boxes (cxcywh) in the image,
    # so the learning target is normalized by the image size. The boxes are
    # re-scaled here for the IoU loss.
    bbox_preds = bbox_preds.reshape(-1, 4)
    bboxes = bbox_cxcywh_to_xyxy(bbox_preds) * factors
    bboxes_gt = bbox_cxcywh_to_xyxy(bbox_targets) * factors

    # regression IoU loss, GIoU loss by default
    loss_iou = self.loss_iou(
        bboxes, bboxes_gt, bbox_weights, avg_factor=num_total_pos)

    # regression L1 loss on the normalized cxcywh values
    loss_bbox = self.loss_bbox(
        bbox_preds, bbox_targets, bbox_weights, avg_factor=num_total_pos)
    return loss_cls, loss_bbox, loss_iou
def get_targets(self,
                cls_scores_list,
                bbox_preds_list,
                gt_bboxes_list,
                gt_labels_list,
                img_metas,
                gt_bboxes_ignore_list=None):
    """Compute regression and classification targets for a batch of
    images.

    Outputs from a single decoder layer of a single feature level are used.

    Args:
        cls_scores_list (list[Tensor]): Box score logits from a single
            decoder layer for each image with shape [num_query,
            cls_out_channels].
        bbox_preds_list (list[Tensor]): Sigmoid outputs from a single
            decoder layer for each image, with normalized coordinate
            (cx, cy, w, h) and shape [num_query, 4].
        gt_bboxes_list (list[Tensor]): Ground truth bboxes for each image
            with shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
        gt_labels_list (list[Tensor]): Ground truth class indices for each
            image with shape (num_gts, ).
        img_metas (list[dict]): List of image meta information.
        gt_bboxes_ignore_list (list[Tensor], optional): Must be None; any
            other value trips the assertion below.

    Returns:
        tuple: a tuple containing the following targets.

            - labels_list (list[Tensor]): Labels for all images.
            - label_weights_list (list[Tensor]): Label weights for all \
                images.
            - bbox_targets_list (list[Tensor]): BBox targets for all \
                images.
            - bbox_weights_list (list[Tensor]): BBox weights for all \
                images.
            - num_total_pos (int): Number of positive samples in all \
                images.
            - num_total_neg (int): Number of negative samples in all \
                images.
    """
    assert gt_bboxes_ignore_list is None, \
        'Only supports for gt_bboxes_ignore setting to None.'

    # One "ignore" entry per image so it can be mapped with the others.
    ignore_per_img = [gt_bboxes_ignore_list] * len(cls_scores_list)

    per_img_targets = multi_apply(
        self._get_target_single, cls_scores_list, bbox_preds_list,
        gt_bboxes_list, gt_labels_list, img_metas, ignore_per_img)
    (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
     pos_inds_list, neg_inds_list) = per_img_targets

    num_total_pos = sum(inds.numel() for inds in pos_inds_list)
    num_total_neg = sum(inds.numel() for inds in neg_inds_list)
    return (labels_list, label_weights_list, bbox_targets_list,
            bbox_weights_list, num_total_pos, num_total_neg)
def _get_target_single(self,
                       cls_score,
                       bbox_pred,
                       gt_bboxes,
                       gt_labels,
                       img_meta,
                       gt_bboxes_ignore=None):
    """Compute regression and classification targets for one image.

    Outputs from a single decoder layer of a single feature level are used.

    Args:
        cls_score (Tensor): Box score logits from a single decoder layer
            for one image. Shape [num_query, cls_out_channels].
        bbox_pred (Tensor): Sigmoid outputs from a single decoder layer
            for one image, with normalized coordinate (cx, cy, w, h) and
            shape [num_query, 4].
        gt_bboxes (Tensor): Ground truth bboxes for one image with
            shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
        gt_labels (Tensor): Ground truth class indices for one image
            with shape (num_gts, ).
        img_meta (dict): Meta information for one image.
        gt_bboxes_ignore (Tensor, optional): Bounding boxes
            which can be ignored. Default None.

    Returns:
        tuple[Tensor]: a tuple containing the following for one image.

            - labels (Tensor): Labels of each image.
            - label_weights (Tensor): Label weights of each image.
            - bbox_targets (Tensor): BBox targets of each image.
            - bbox_weights (Tensor): BBox weights of each image.
            - pos_inds (Tensor): Sampled positive indices for each image.
            - neg_inds (Tensor): Sampled negative indices for each image.
    """
    num_preds = bbox_pred.size(0)

    # assigner and sampler
    assign_result = self.assigner.assign(bbox_pred, cls_score, gt_bboxes,
                                         gt_labels, img_meta,
                                         gt_bboxes_ignore)
    sampling_result = self.sampler.sample(assign_result, bbox_pred,
                                          gt_bboxes)
    pos_inds, neg_inds = sampling_result.pos_inds, sampling_result.neg_inds

    # Label targets: every prediction starts as index `num_classes`
    # (background); positives take the label of their assigned gt.
    labels = gt_bboxes.new_full((num_preds, ),
                                self.num_classes,
                                dtype=torch.long)
    labels[pos_inds] = gt_labels[sampling_result.pos_assigned_gt_inds]
    label_weights = gt_bboxes.new_ones(num_preds)

    # Box targets: only positive predictions get a target and weight 1.
    bbox_targets = torch.zeros_like(bbox_pred)
    bbox_weights = torch.zeros_like(bbox_pred)
    bbox_weights[pos_inds] = 1.0

    # DETR regresses the relative position of boxes (cxcywh) in the image.
    # The target is therefore normalized by the image size and converted
    # from the default x1y1x2y2 format to cxcywh.
    img_h, img_w, _ = img_meta['img_shape']
    factor = bbox_pred.new_tensor([[img_w, img_h, img_w, img_h]])
    normalized_gt = sampling_result.pos_gt_bboxes / factor
    bbox_targets[pos_inds] = bbox_xyxy_to_cxcywh(normalized_gt)

    return (labels, label_weights, bbox_targets, bbox_weights, pos_inds,
            neg_inds)
# over-write because img_metas are needed as inputs for bbox_head.
def forward_train(self,
                  x,
                  img_metas,
                  gt_bboxes,
                  gt_labels=None,
                  gt_bboxes_ignore=None,
                  proposal_cfg=None,
                  **kwargs):
    """Forward function for training mode.

    Args:
        x (list[Tensor]): Features from backbone.
        img_metas (list[dict]): Meta information of each image, e.g.,
            image size, scaling factor, etc.
        gt_bboxes (Tensor): Ground truth bboxes of the image,
            shape (num_gts, 4).
        gt_labels (Tensor): Ground truth labels of each box,
            shape (num_gts,).
        gt_bboxes_ignore (Tensor): Ground truth bboxes to be
            ignored, shape (num_ignored_gts, 4).
        proposal_cfg (mmcv.Config): Must be None for this head.

    Returns:
        dict[str, Tensor]: A dictionary of loss components.
    """
    assert proposal_cfg is None, '"proposal_cfg" must be None'

    # The head is called with img_metas because its forward needs them.
    head_outs = self(x, img_metas)

    # NOTE(review): when gt_labels is None the labels slot is dropped from
    # the positional arguments; `loss` in this class declares
    # `gt_labels_list` as a required positional, so this branch looks
    # incompatible with it - confirm before relying on it.
    if gt_labels is None:
        targets = (gt_bboxes, img_metas)
    else:
        targets = (gt_bboxes, gt_labels, img_metas)
    loss_inputs = head_outs + targets
    return self.loss(*loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
@force_fp32(apply_to=('all_cls_scores_list', 'all_bbox_preds_list'))
def get_bboxes(self,
               all_cls_scores_list,
               all_bbox_preds_list,
               img_metas,
               rescale=False):
    """Transform network outputs for a batch into bbox predictions.

    Args:
        all_cls_scores_list (list[Tensor]): Classification outputs
            for each feature level. Each is a 4D-tensor with shape
            [nb_dec, bs, num_query, cls_out_channels].
        all_bbox_preds_list (list[Tensor]): Sigmoid regression
            outputs for each feature level. Each is a 4D-tensor with
            normalized coordinate format (cx, cy, w, h) and shape
            [nb_dec, bs, num_query, 4].
        img_metas (list[dict]): Meta information of each image; each entry
            must provide ``'img_shape'`` and ``'scale_factor'``.
        rescale (bool, optional): If True, return boxes in original
            image space. Default False.

    Returns:
        list[list[Tensor, Tensor]]: Each item in result_list is 2-tuple. \
            The first item is an (n, 5) tensor, where the first 4 columns \
            are bounding box positions (tl_x, tl_y, br_x, br_y) and the \
            5-th column is a score between 0 and 1. The second item is a \
            (n,) tensor where each item is the predicted class label of \
            the corresponding box.
    """
    # NOTE by default only the last feature level and, within it, only the
    # last decoder layer are used.
    last_cls_scores = all_cls_scores_list[-1][-1]
    last_bbox_preds = all_bbox_preds_list[-1][-1]

    return [
        self._get_bboxes_single(last_cls_scores[img_id],
                                last_bbox_preds[img_id],
                                meta['img_shape'], meta['scale_factor'],
                                rescale)
        for img_id, meta in enumerate(img_metas)
    ]
def _get_bboxes_single(self,
                       cls_score,
                       bbox_pred,
                       img_shape,
                       scale_factor,
                       rescale=False):
    """Transform outputs from the last decoder layer into bbox predictions
    for each image.

    Args:
        cls_score (Tensor): Box score logits from the last decoder layer
            for each image. Shape [num_query, cls_out_channels].
        bbox_pred (Tensor): Sigmoid outputs from the last decoder layer
            for each image, with coordinate format (cx, cy, w, h) and
            shape [num_query, 4].
        img_shape (tuple[int]): Shape of input image, (height, width, 3).
        scale_factor (ndarray, optional): Scale factor of the image
            arranged as (w_scale, h_scale, w_scale, h_scale).
        rescale (bool, optional): If True, return boxes in original image
            space. Default False.

    Returns:
        tuple[Tensor]: Results of detected bboxes and labels.

            - det_bboxes: Predicted bboxes with shape [n, 5], \
                where the first 4 columns are bounding box positions \
                (tl_x, tl_y, br_x, br_y) and the 5-th column are scores \
                between 0 and 1. ``n`` is ``max_per_img`` capped at the \
                number of available candidates.
            - det_labels: Predicted labels of the corresponding box with \
                shape [n].
    """
    assert len(cls_score) == len(bbox_pred)
    max_per_img = self.test_cfg.get('max_per_img', self.num_query)

    if self.loss_cls.use_sigmoid:
        # Rank every (query, class) pair jointly over the flattened scores.
        flat_scores = cls_score.sigmoid().view(-1)
        # topk raises if k exceeds the number of candidates, so cap it.
        num_keep = min(max_per_img, flat_scores.numel())
        scores, indexes = flat_scores.topk(num_keep)
        det_labels = indexes % self.num_classes
        bbox_index = indexes // self.num_classes
        bbox_pred = bbox_pred[bbox_index]
    else:
        # exclude background: drop the last softmax channel before taking
        # the best class of each query
        scores, det_labels = F.softmax(cls_score, dim=-1)[..., :-1].max(-1)
        # topk raises if k exceeds the number of queries, so cap it.
        num_keep = min(max_per_img, scores.numel())
        scores, bbox_index = scores.topk(num_keep)
        bbox_pred = bbox_pred[bbox_index]
        det_labels = det_labels[bbox_index]

    # Normalized cxcywh -> xyxy in pixels, clipped to the image.
    det_bboxes = bbox_cxcywh_to_xyxy(bbox_pred)
    det_bboxes[:, 0::2] = det_bboxes[:, 0::2] * img_shape[1]
    det_bboxes[:, 1::2] = det_bboxes[:, 1::2] * img_shape[0]
    det_bboxes[:, 0::2].clamp_(min=0, max=img_shape[1])
    det_bboxes[:, 1::2].clamp_(min=0, max=img_shape[0])
    if rescale:
        det_bboxes /= det_bboxes.new_tensor(scale_factor)
    det_bboxes = torch.cat((det_bboxes, scores.unsqueeze(1)), -1)

    return det_bboxes, det_labels
def simple_test_bboxes(self, feats, img_metas, rescale=False):
    """Test det bboxes without test-time augmentation.

    Args:
        feats (tuple[torch.Tensor]): Multi-level features from the
            upstream network, each is a 4D-tensor.
        img_metas (list[dict]): List of image information.
        rescale (bool, optional): Whether to rescale the results.
            Defaults to False.

    Returns:
        list[tuple[Tensor, Tensor]]: Each item in result_list is 2-tuple.
            The first item is ``bboxes`` with shape (n, 5),
            where 5 represent (tl_x, tl_y, br_x, br_y, score).
            The shape of the second tensor in the tuple is ``labels``
            with shape (n,)
    """
    # The forward pass of this head needs img_metas in addition to feats.
    head_outs = self.forward(feats, img_metas)
    return self.get_bboxes(*head_outs, img_metas, rescale=rescale)
def forward_onnx(self, feats, img_metas):
    """Forward function for exporting to ONNX.

    Used instead of `forward` for export because `masks` is created
    directly with zeros (valid position tag) at the same spatial size as
    `x`, so the mask construction differs from that in `forward`.

    Args:
        feats (tuple[Tensor]): Features from the upstream network, each is
            a 4D-tensor.
        img_metas (list[dict]): List of image information.

    Returns:
        tuple[list[Tensor], list[Tensor]]: Outputs for all scale levels.

            - all_cls_scores_list (list[Tensor]): Classification scores \
            for each scale level. Each is a 4D-tensor with shape \
            [nb_dec, bs, num_query, cls_out_channels]. Note \
            `cls_out_channels` should include background.
            - all_bbox_preds_list (list[Tensor]): Sigmoid regression \
            outputs for each scale level. Each is a 4D-tensor with \
            normalized coordinate format (cx, cy, w, h) and shape \
            [nb_dec, bs, num_query, 4].
    """
    # Every level receives the same list of image metas.
    metas_per_level = [img_metas] * len(feats)
    return multi_apply(self.forward_single_onnx, feats, metas_per_level)
def forward_single_onnx(self, x, img_metas):
    """Forward function for a single feature level with ONNX exportation.

    Args:
        x (Tensor): Input feature from backbone's single stage, shape
            [bs, c, h, w].
        img_metas (list[dict]): List of image information. Not read by
            this method.

    Returns:
        all_cls_scores (Tensor): Outputs from the classification head,
            shape [nb_dec, bs, num_query, cls_out_channels]. Note
            cls_out_channels should include background.
        all_bbox_preds (Tensor): Sigmoid outputs from the regression
            head with normalized coordinate format (cx, cy, w, h).
            Shape [nb_dec, bs, num_query, 4].
    """
    # Note `img_shape` is not dynamically traceable to ONNX, since the
    # related augmentation was done with numpy on CPU. `masks` is therefore
    # created with zeros (valid tag) at the spatial shape of `x`.
    # The difference between the torch and the exported ONNX model may be
    # ignored, since the same performance is achieved (e.g.
    # 40.1 vs 40.1 for DETR).
    num_imgs = x.size(0)
    feat_h, feat_w = x.size()[-2:]
    masks = x.new_zeros((num_imgs, feat_h, feat_w))  # [B,h,w]

    x = self.input_proj(x)

    # Resize the masks to the spatial size of the projected features.
    proj_size = x.shape[-2:]
    masks = F.interpolate(masks.unsqueeze(1), size=proj_size)
    masks = masks.to(torch.bool).squeeze(1)

    pos_embed = self.positional_encoding(masks)
    query_embed = self.query_embedding.weight
    outs_dec, _ = self.transformer(x, masks, query_embed, pos_embed)

    all_cls_scores = self.fc_cls(outs_dec)
    reg_feats = self.activate(self.reg_ffn(outs_dec))
    all_bbox_preds = self.fc_reg(reg_feats).sigmoid()
    return all_cls_scores, all_bbox_preds
def onnx_export(self, all_cls_scores_list, all_bbox_preds_list, img_metas):
    """Transform network outputs into bbox predictions, with ONNX
    exportation.

    Args:
        all_cls_scores_list (list[Tensor]): Classification outputs
            for each feature level. Each is a 4D-tensor with shape
            [nb_dec, bs, num_query, cls_out_channels].
        all_bbox_preds_list (list[Tensor]): Sigmoid regression
            outputs for each feature level. Each is a 4D-tensor with
            normalized coordinate format (cx, cy, w, h) and shape
            [nb_dec, bs, num_query, 4].
        img_metas (list[dict]): Meta information of each image. Must have
            exactly one entry, providing ``'img_shape_for_onnx'``.

    Returns:
        tuple[Tensor, Tensor]: dets of shape [N, num_det, 5]
            and class labels of shape [N, num_det].
    """
    assert len(img_metas) == 1, \
        'Only support one input image while in exporting to ONNX'

    # Last feature level, last decoder layer.
    cls_scores = all_cls_scores_list[-1][-1]
    bbox_preds = all_bbox_preds_list[-1][-1]

    # Note `img_shape` is not dynamically traceable to ONNX,
    # here `img_shape_for_onnx` (padded shape of image tensor)
    # is used.
    img_shape = img_metas[0]['img_shape_for_onnx']
    max_per_img = self.test_cfg.get('max_per_img', self.num_query)
    batch_size = cls_scores.size(0)
    num_query = cls_scores.size(1)

    # `batch_index_offset` is used to gather from tensors flattened over
    # (batch, query). Image b starts at row b * num_query there, so the
    # stride is num_query. Using max_per_img as the stride is only correct
    # when the two happen to be equal or the batch size is 1.
    batch_index_offset = torch.arange(batch_size).to(
        cls_scores.device) * num_query
    batch_index_offset = batch_index_offset.unsqueeze(1).expand(
        batch_size, max_per_img)

    # supports dynamical batch inference
    if self.loss_cls.use_sigmoid:
        cls_scores = cls_scores.sigmoid()
        # Rank every (query, class) pair of an image jointly.
        scores, indexes = cls_scores.view(batch_size, -1).topk(
            max_per_img, dim=1)
        det_labels = indexes % self.num_classes
        bbox_index = indexes // self.num_classes
        bbox_index = (bbox_index + batch_index_offset).view(-1)
        bbox_preds = bbox_preds.view(-1, 4)[bbox_index]
        bbox_preds = bbox_preds.view(batch_size, -1, 4)
    else:
        # Drop the last (background) softmax channel before taking the
        # best class of each query.
        scores, det_labels = F.softmax(
            cls_scores, dim=-1)[..., :-1].max(-1)
        scores, bbox_index = scores.topk(max_per_img, dim=1)
        bbox_index = (bbox_index + batch_index_offset).view(-1)
        bbox_preds = bbox_preds.view(-1, 4)[bbox_index]
        det_labels = det_labels.view(-1)[bbox_index]
        bbox_preds = bbox_preds.view(batch_size, -1, 4)
        det_labels = det_labels.view(batch_size, -1)

    det_bboxes = bbox_cxcywh_to_xyxy(bbox_preds)
    # use `img_shape_tensor` for dynamically exporting to ONNX
    img_shape_tensor = img_shape.flip(0).repeat(2)  # [w,h,w,h]
    img_shape_tensor = img_shape_tensor.unsqueeze(0).unsqueeze(0).expand(
        batch_size, det_bboxes.size(1), 4)
    det_bboxes = det_bboxes * img_shape_tensor
    # dynamically clip bboxes
    x1, y1, x2, y2 = det_bboxes.split((1, 1, 1, 1), dim=-1)
    from mmdet.core.export import dynamic_clip_for_onnx
    x1, y1, x2, y2 = dynamic_clip_for_onnx(x1, y1, x2, y2, img_shape)
    det_bboxes = torch.cat([x1, y1, x2, y2], dim=-1)
    det_bboxes = torch.cat((det_bboxes, scores.unsqueeze(-1)), -1)

    return det_bboxes, det_labels
|