RockeyCoss
add code files”
51f6859
raw
history blame
3.92 kB
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.runner import load_checkpoint
from ..builder import DETECTORS, build_backbone, build_head, build_neck
from .kd_one_stage import KnowledgeDistillationSingleStageDetector
@DETECTORS.register_module()
class LAD(KnowledgeDistillationSingleStageDetector):
    """Implementation of `LAD <https://arxiv.org/pdf/2108.10520.pdf>`_.

    A single-stage student detector trained with the label assignment
    produced by a separately built teacher (backbone, optional neck and
    bbox head). The teacher runs under ``torch.no_grad()`` and only supplies
    the assignment; the losses come from the student's bbox head.
    """

    def __init__(self,
                 backbone,
                 neck,
                 bbox_head,
                 teacher_backbone,
                 teacher_neck,
                 teacher_bbox_head,
                 teacher_ckpt,
                 eval_teacher=True,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None):
        """Build the student detector and assemble the teacher sub-networks.

        Args:
            backbone (dict): Config of the student backbone.
            neck (dict): Config of the student neck.
            bbox_head (dict): Config of the student bbox head.
            teacher_backbone (dict): Config of the teacher backbone.
            teacher_neck (dict | None): Config of the teacher neck. No neck
                is built for the teacher when this is None.
            teacher_bbox_head (dict): Config of the teacher bbox head. It is
                updated in place with ``train_cfg`` and ``test_cfg``.
            teacher_ckpt (str | None): Checkpoint loaded into the teacher
                (mapped to CPU) when not None.
            eval_teacher (bool): Stored as ``self.eval_teacher``; presumably
                read by the parent class to decide the teacher's train/eval
                mode — TODO confirm against the parent implementation.
            train_cfg (dict, optional): Passed to the base detector and
                injected into the teacher head config.
            test_cfg (dict, optional): Passed to the base detector and
                injected into the teacher head config.
            pretrained (str, optional): Passed to the base detector.
        """
        # Start the MRO lookup *after* KnowledgeDistillationSingleStageDetector
        # so its __init__ is bypassed; the teacher is assembled by hand below.
        base_args = (backbone, neck, bbox_head, train_cfg, test_cfg,
                     pretrained)
        super(KnowledgeDistillationSingleStageDetector,
              self).__init__(*base_args)
        self.eval_teacher = eval_teacher

        # A bare nn.Module used purely as a container for the teacher parts.
        teacher = nn.Module()
        teacher.backbone = build_backbone(teacher_backbone)
        if teacher_neck is not None:
            teacher.neck = build_neck(teacher_neck)
        # Inject the detector-level cfgs into the teacher head config
        # (this mutates the caller's config object).
        teacher_bbox_head.update(train_cfg=train_cfg, test_cfg=test_cfg)
        teacher.bbox_head = build_head(teacher_bbox_head)
        if teacher_ckpt is not None:
            load_checkpoint(teacher, teacher_ckpt, map_location='cpu')
        self.teacher_model = teacher

    @property
    def with_teacher_neck(self):
        """bool: Whether the teacher model has a neck that is not None."""
        return getattr(self.teacher_model, 'neck', None) is not None

    def extract_teacher_feat(self, img):
        """Run the teacher backbone, followed by the teacher neck if present.

        Args:
            img (Tensor): Input images.

        Returns:
            The teacher neck's output when the teacher has a neck, otherwise
            the raw output of the teacher backbone.
        """
        feats = self.teacher_model.backbone(img)
        if not self.with_teacher_neck:
            return feats
        return self.teacher_model.neck(feats)

    def forward_train(self,
                      img,
                      img_metas,
                      gt_bboxes,
                      gt_labels,
                      gt_bboxes_ignore=None):
        """Compute student losses using the teacher's label assignment.

        Args:
            img (Tensor): Input images of shape (N, C, H, W).
                Typically these should be mean centered and std scaled.
            img_metas (list[dict]): A list of image info dicts, one per
                image, each having 'img_shape', 'scale_factor' and 'flip',
                and possibly 'filename', 'ori_shape', 'pad_shape' and
                'img_norm_cfg'. For details on the values of these keys see
                :class:`mmdet.datasets.pipelines.Collect`.
            gt_bboxes (list[Tensor]): Ground-truth boxes of each image in
                [tl_x, tl_y, br_x, br_y] format.
            gt_labels (list[Tensor]): Class indices corresponding to each box.
            gt_bboxes_ignore (None | list[Tensor]): Bounding boxes that can
                be ignored when computing the loss.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        # Teacher side: no gradients, it only decides the label assignment.
        with torch.no_grad():
            teacher_feats = self.extract_teacher_feat(img)
            teacher_head = self.teacher_model.bbox_head
            teacher_outs = teacher_head(teacher_feats)
            assignment = teacher_head.get_label_assignment(
                *teacher_outs, gt_bboxes, gt_labels, img_metas,
                gt_bboxes_ignore)

        # Student side: learn from the assignment chosen by the teacher.
        student_feats = self.extract_feat(img)
        return self.bbox_head.forward_train(student_feats, assignment,
                                            img_metas, gt_bboxes, gt_labels,
                                            gt_bboxes_ignore)