RockeyCoss
add code files”
51f6859
raw
history blame
2.85 kB
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta, abstractmethod
import torch.nn.functional as F
from mmcv.runner import BaseModule, force_fp32
from ..builder import build_loss
from ..utils import interpolate_as
class BaseSemanticHead(BaseModule, metaclass=ABCMeta):
"""Base module of Semantic Head.
Args:
num_classes (int): the number of classes.
init_cfg (dict): the initialization config.
loss_seg (dict): the loss of the semantic head.
"""
def __init__(self,
num_classes,
init_cfg=None,
loss_seg=dict(
type='CrossEntropyLoss',
ignore_index=255,
loss_weight=1.0)):
super(BaseSemanticHead, self).__init__(init_cfg)
self.loss_seg = build_loss(loss_seg)
self.num_classes = num_classes
@force_fp32(apply_to=('seg_preds', ))
def loss(self, seg_preds, gt_semantic_seg):
"""Get the loss of semantic head.
Args:
seg_preds (Tensor): The input logits with the shape (N, C, H, W).
gt_semantic_seg: The ground truth of semantic segmentation with
the shape (N, H, W).
label_bias: The starting number of the semantic label.
Default: 1.
Returns:
dict: the loss of semantic head.
"""
if seg_preds.shape[-2:] != gt_semantic_seg.shape[-2:]:
seg_preds = interpolate_as(seg_preds, gt_semantic_seg)
seg_preds = seg_preds.permute((0, 2, 3, 1))
loss_seg = self.loss_seg(
seg_preds.reshape(-1, self.num_classes), # => [NxHxW, C]
gt_semantic_seg.reshape(-1).long())
return dict(loss_seg=loss_seg)
@abstractmethod
def forward(self, x):
"""Placeholder of forward function.
Returns:
dict[str, Tensor]: A dictionary, including features
and predicted scores. Required keys: 'seg_preds'
and 'feats'.
"""
pass
def forward_train(self, x, gt_semantic_seg):
output = self.forward(x)
seg_preds = output['seg_preds']
return self.loss(seg_preds, gt_semantic_seg)
def simple_test(self, x, img_metas, rescale=False):
output = self.forward(x)
seg_preds = output['seg_preds']
seg_preds = F.interpolate(
seg_preds,
size=img_metas[0]['pad_shape'][:2],
mode='bilinear',
align_corners=False)
if rescale:
h, w, _ = img_metas[0]['img_shape']
seg_preds = seg_preds[:, :, :h, :w]
h, w, _ = img_metas[0]['ori_shape']
seg_preds = F.interpolate(
seg_preds, size=(h, w), mode='bilinear', align_corners=False)
return seg_preds