Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved.
import warnings | |
import torch.nn as nn | |
from mmcv.runner import BaseModule, auto_fp16 | |
from mmdet.models.backbones import ResNet | |
from mmdet.models.builder import SHARED_HEADS | |
from mmdet.models.utils import ResLayer as _ResLayer | |
# NOTE(review): SHARED_HEADS is imported at the top of this file but this
# class is never registered with it, so configs cannot build it by type name.
# Registration (``@SHARED_HEADS.register_module()``) is deliberately not added
# here because it is a module-level side effect that would collide if another
# class named ``ResLayer`` is already registered — confirm which is intended.
class ResLayer(BaseModule):
    """A single ResNet stage wrapped as a standalone module.

    The stage is built with mmdet's ``ResLayer`` helper. It uses the block
    type and block count that ``ResNet.arch_settings[depth]`` lists for
    ``stage``, and it is stored as the submodule ``layer{stage + 1}``.

    Args:
        depth (int): ResNet depth; must be a key of ``ResNet.arch_settings``.
        stage (int): 0-based index of the ResNet stage to build. Only
            ``1 <= stage < len(stage_blocks)`` is supported, because the
            input channel count is derived from the previous stage's output.
            Default: 3.
        stride (int): Stride passed to the stage builder. Default: 2.
        dilation (int): Dilation passed to the stage builder. Default: 1.
        style (str): Block style passed to the stage builder.
            Default: 'pytorch'.
        norm_cfg (dict): Normalization layer config passed to the stage
            builder. Default: ``dict(type='BN', requires_grad=True)``.
        norm_eval (bool): If True, every batch-norm layer is put back into
            eval mode whenever ``train(True)`` is called. Default: True.
        with_cp (bool): Passed through to the stage builder (presumably the
            gradient-checkpointing switch — confirm). Default: False.
        dcn (dict | None): Deformable-conv config passed to the stage
            builder. Default: None.
        pretrained (str | None): Deprecated checkpoint path. It is converted
            into a ``Pretrained`` init_cfg and cannot be combined with
            ``init_cfg``. Default: None.
        init_cfg (dict | list[dict] | None): Initialization config passed to
            ``BaseModule``. When it and ``pretrained`` are both None, Kaiming
            init is used for ``Conv2d`` and constant 1 for norm layers.

    Raises:
        ValueError: If ``stage`` is outside the supported range.
        TypeError: If ``pretrained`` is neither a str nor None.
    """

    def __init__(self,
                 depth,
                 stage=3,
                 stride=2,
                 dilation=1,
                 style='pytorch',
                 norm_cfg=dict(type='BN', requires_grad=True),
                 norm_eval=True,
                 with_cp=False,
                 dcn=None,
                 pretrained=None,
                 init_cfg=None):
        super(ResLayer, self).__init__(init_cfg)
        self.norm_eval = norm_eval
        self.norm_cfg = norm_cfg
        self.stage = stage
        # Flag consulted by mmcv's ``auto_fp16`` decorator on ``forward``;
        # presumably flipped to True externally when fp16 training is
        # enabled — confirm.
        self.fp16_enabled = False

        block, stage_blocks = ResNet.arch_settings[depth]
        # Reject stages this formula cannot handle. For stage <= 0,
        # 2**(stage - 1) yields a float channel count (and stage 0 would need
        # the stem's width); indices past the last stage do not exist.
        if not 1 <= stage < len(stage_blocks):
            raise ValueError(
                f'stage must be in [1, {len(stage_blocks) - 1}], got {stage}')
        stage_block = stage_blocks[stage]
        # Base width of this stage, and the channel count produced by the
        # previous stage (its base width times the block expansion).
        planes = 64 * 2**stage
        inplanes = 64 * 2**(stage - 1) * block.expansion
        res_layer = _ResLayer(
            block,
            inplanes,
            planes,
            stage_block,
            stride=stride,
            dilation=dilation,
            style=style,
            with_cp=with_cp,
            norm_cfg=self.norm_cfg,
            dcn=dcn)
        # ``layer{stage + 1}`` mirrors the 1-based layer names of the ResNet
        # backbone (presumably so checkpoint keys line up — confirm).
        self.add_module(f'layer{stage + 1}', res_layer)

        assert not (init_cfg and pretrained), \
            'init_cfg and pretrained cannot be specified at the same time'
        if isinstance(pretrained, str):
            warnings.warn('DeprecationWarning: pretrained is a deprecated, '
                          'please use "init_cfg" instead')
            self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
        elif pretrained is None:
            if init_cfg is None:
                # Default init: Kaiming for convs, constant 1 for norm layers.
                self.init_cfg = [
                    dict(type='Kaiming', layer='Conv2d'),
                    dict(
                        type='Constant',
                        val=1,
                        layer=['_BatchNorm', 'GroupNorm'])
                ]
        else:
            raise TypeError('pretrained must be a str or None')

    @auto_fp16()
    def forward(self, x):
        """Run the wrapped ResNet stage on ``x`` and return its output."""
        res_layer = getattr(self, f'layer{self.stage + 1}')
        out = res_layer(x)
        return out

    def train(self, mode=True):
        """Set train/eval mode, keeping batch-norm layers frozen if requested.

        Args:
            mode (bool): True for training mode, False for eval mode.

        Returns:
            ResLayer: ``self``, matching the ``nn.Module.train`` contract.
        """
        super(ResLayer, self).train(mode)
        if mode and self.norm_eval:
            # Match the shared batch-norm base class rather than only
            # ``nn.BatchNorm2d``. Otherwise SyncBN and other variants chosen
            # via ``norm_cfg`` would silently keep updating running stats.
            for m in self.modules():
                if isinstance(m, nn.modules.batchnorm._BatchNorm):
                    m.eval()
        return self