stevengrove
initial commit
186701e
raw
history blame
1.03 kB
# Copyright (c) OpenMMLab. All rights reserved.
from mmdet.models.backbones.res2net import Bottle2neck
from mmdet.models.backbones.resnet import BasicBlock, Bottleneck
from mmdet.models.backbones.resnext import Bottleneck as BottleneckX
from mmdet.models.layers import SimplifiedBasicBlock
from torch.nn.modules import GroupNorm
from torch.nn.modules.batchnorm import _BatchNorm
def is_block(modules):
    """Tell whether ``modules`` is a ResNet-family building block.

    Args:
        modules: The object to test (a single module despite the plural
            parameter name).

    Returns:
        bool: True when ``modules`` is an instance of ``BasicBlock``,
        ``Bottleneck``, the ResNeXt ``Bottleneck``, ``Bottle2neck`` or
        ``SimplifiedBasicBlock``; False otherwise.
    """
    # Every block class this helper recognises, gathered in one tuple so
    # a single isinstance call covers them all.
    block_types = (
        BasicBlock,
        Bottleneck,
        BottleneckX,
        Bottle2neck,
        SimplifiedBasicBlock,
    )
    return isinstance(modules, block_types)
def is_norm(modules):
    """Tell whether ``modules`` is one of the recognised norm layers.

    Args:
        modules: The object to test (a single module despite the plural
            parameter name).

    Returns:
        bool: True when ``modules`` is a ``GroupNorm`` or any
        ``_BatchNorm`` subclass instance; False otherwise.
    """
    norm_types = (GroupNorm, _BatchNorm)
    return isinstance(modules, norm_types)
def check_norm_state(modules, train_state):
    """Verify that every batch-norm layer is in the expected train state.

    Only ``_BatchNorm`` instances are inspected; every other item in
    ``modules`` (including ``GroupNorm``) is skipped.

    Args:
        modules: Iterable of modules to inspect.
        train_state: The value each batch-norm layer's ``training``
            attribute is expected to equal.

    Returns:
        bool: False as soon as a batch-norm layer whose ``training``
        attribute differs from ``train_state`` is found; True otherwise
        (also when no batch-norm layer is present).
    """
    # Lazy generator so the scan stops at the first mismatching layer.
    mismatched = (
        layer.training != train_state
        for layer in modules
        if isinstance(layer, _BatchNorm)
    )
    return not any(mismatched)