Spaces:
Runtime error
Runtime error
#!/usr/bin/env python | |
# -*- coding:utf-8 -*- | |
# Author: Donny You(youansheng@gmail.com) | |
import torch.nn as nn | |
from networks.resnet_models import * | |
class NormalResnetBackbone(nn.Module):
    """Plain (non-dilated) ResNet trunk that exposes every stage's output.

    The ``prefix``, ``maxpool`` and ``layer1``..``layer4`` sub-modules of an
    already-built ResNet are adopted as-is (shared, not copied); anything else
    on the source network, such as an average-pool or FC head, is ignored.
    """

    def __init__(self, orig_resnet):
        super().__init__()
        # Feature width reported by get_num_features(); builders may overwrite it.
        self.num_features = 2048
        # Registration order matters for named_children()/state_dict ordering.
        for name in ('prefix', 'maxpool', 'layer1', 'layer2', 'layer3', 'layer4'):
            setattr(self, name, getattr(orig_resnet, name))

    def get_num_features(self):
        """Return the advertised feature width of the deepest stage."""
        return self.num_features

    def forward(self, x):
        """Run the trunk and return the four stage outputs.

        Returns:
            A list ``[layer1_out, layer2_out, layer3_out, layer4_out]``.
        """
        x = self.maxpool(self.prefix(x))
        outputs = []
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
            outputs.append(x)
        return outputs
class DilatedResnetBackbone(nn.Module):
    """ResNet trunk whose later stages trade stride for dilation.

    The strided convolutions of ``orig_resnet.layer3``/``layer4`` are rewritten
    in place before the trunk's sub-modules are adopted (shared, not copied):

    * ``dilate_scale == 8``: layer3 is dilated with rate 2, layer4 with rate 4.
    * ``dilate_scale == 16``: only layer4 is dilated, with rate 2.
    * any other value: the network is left unchanged.

    When ``multi_grid`` is given, block ``i`` of layer4 uses ``int(rate * multi_grid[i])``
    instead of a single rate for the whole stage.
    """

    def __init__(self, orig_resnet, dilate_scale=8, multi_grid=(1, 2, 4)):
        super().__init__()
        # Feature width reported by get_num_features(); builders may overwrite it.
        self.num_features = 2048
        from functools import partial

        # Pick the dilation rate for each stage; None means "leave this stage alone".
        if dilate_scale == 8:
            layer3_rate, layer4_rate = 2, 4
        elif dilate_scale == 16:
            layer3_rate, layer4_rate = None, 2
        else:
            layer3_rate = layer4_rate = None

        if layer3_rate is not None:
            orig_resnet.layer3.apply(partial(self._nostride_dilate, dilate=layer3_rate))
        if layer4_rate is not None:
            if multi_grid is None:
                orig_resnet.layer4.apply(partial(self._nostride_dilate, dilate=layer4_rate))
            else:
                # Multi-grid: each residual block in layer4 gets its own multiplier.
                for block_idx, grid in enumerate(multi_grid):
                    rate = int(layer4_rate * grid)
                    orig_resnet.layer4[block_idx].apply(partial(self._nostride_dilate, dilate=rate))

        # Adopt the (now dilated) trunk; registration order fixes state_dict key order.
        for name in ('prefix', 'maxpool', 'layer1', 'layer2', 'layer3', 'layer4'):
            setattr(self, name, getattr(orig_resnet, name))

    def _nostride_dilate(self, m, dilate):
        """Per-module hook for ``nn.Module.apply`` that removes stride in favour of dilation.

        Only modules whose class name contains ``'Conv'`` are touched:

        * stride-(2, 2) convs become stride (1, 1); if 3x3 they get dilation and
          padding ``dilate // 2``;
        * other 3x3 convs get dilation and padding ``dilate``.

        Padding equal to dilation keeps the spatial size of a stride-1 3x3 conv.
        """
        if 'Conv' not in type(m).__name__:
            return
        is_3x3 = m.kernel_size == (3, 3)
        if m.stride == (2, 2):
            m.stride = (1, 1)
            if is_3x3:
                half = dilate // 2
                m.dilation = (half, half)
                m.padding = (half, half)
        elif is_3x3:
            m.dilation = (dilate, dilate)
            m.padding = (dilate, dilate)

    def get_num_features(self):
        """Return the advertised feature width of the deepest stage."""
        return self.num_features

    def forward(self, x):
        """Run the trunk and return the four stage outputs.

        Returns:
            A list ``[layer1_out, layer2_out, layer3_out, layer4_out]``.
        """
        x = self.maxpool(self.prefix(x))
        outputs = []
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
            outputs.append(x)
        return outputs
def ResNetBackbone(backbone=None, width_multiplier=1.0, pretrained=None, multi_grid=None, norm_type='batchnorm'):
    """Build a ResNet feature extractor from an architecture name.

    ``backbone`` is a base model name (``resnet18``, ``resnet34``, ``resnet50``,
    ``resnet101``, ``deepbase_resnet50`` or ``deepbase_resnet101``), optionally
    followed by ``_dilated8`` or ``_dilated16``. A plain name is wrapped in
    ``NormalResnetBackbone``. A suffixed name is wrapped in ``DilatedResnetBackbone``
    with the matching ``dilate_scale`` and the given ``multi_grid``.

    Args:
        backbone: architecture name as described above.
        width_multiplier: forwarded to ``resnet50`` (every variant) only.
        pretrained: forwarded to the model factory. For ``deepbase_resnet50``,
            ``deepbase_resnet50_dilated8`` and ``deepbase_resnet101_dilated8``, any
            truthy value is replaced by a fixed local checkpoint path.
        multi_grid: per-block dilation multipliers for layer4 (dilated variants only).
        norm_type: currently unused; kept for interface compatibility.

    Returns:
        The wrapped backbone. ``num_features`` is 512 for resnet18/resnet34 and the
        wrapper's default otherwise.

    Raises:
        ValueError: if ``backbone`` is not a recognised name. It subclasses
            ``Exception``, so existing ``except Exception`` handlers still catch it.
    """
    arch = backbone
    if not isinstance(arch, str):
        raise ValueError('Architecture undefined! Got {!r}'.format(arch))

    # Split an optional "_dilatedN" suffix off the base model name.
    base, dilate_scale = arch, None
    for suffix, scale in (('_dilated8', 8), ('_dilated16', 16)):
        if arch.endswith(suffix):
            base, dilate_scale = arch[:-len(suffix)], scale
            break

    # Archs that load a fixed local checkpoint whenever pretrained weights are
    # requested. All of them live under models/backbones/pretrained/; the resnet101
    # entry used to point at a non-existent 'backbones/backbones/...' directory.
    deepbase_checkpoints = {
        'deepbase_resnet50': 'models/backbones/pretrained/3x3resnet50-imagenet.pth',
        'deepbase_resnet50_dilated8': 'models/backbones/pretrained/3x3resnet50-imagenet.pth',
        'deepbase_resnet101_dilated8': 'models/backbones/pretrained/3x3resnet101-imagenet.pth',
    }
    if pretrained and arch in deepbase_checkpoints:
        pretrained = deepbase_checkpoints[arch]

    # Factories are referenced only in the selected branch, so a name missing from
    # the resnet_models module only matters when it is actually requested.
    if base == 'resnet18':
        orig_resnet = resnet18(pretrained=pretrained)
    elif base == 'resnet34':
        orig_resnet = resnet34(pretrained=pretrained)
    elif base == 'resnet50':
        # All resnet50 variants honour width_multiplier (dilated16 previously dropped it).
        orig_resnet = resnet50(pretrained=pretrained, width_multiplier=width_multiplier)
    elif base == 'resnet101':
        orig_resnet = resnet101(pretrained=pretrained)
    elif base == 'deepbase_resnet50':
        orig_resnet = deepbase_resnet50(pretrained=pretrained)
    elif base == 'deepbase_resnet101':
        orig_resnet = deepbase_resnet101(pretrained=pretrained)
    else:
        raise ValueError('Architecture undefined! Got {!r}'.format(arch))

    if dilate_scale is None:
        arch_net = NormalResnetBackbone(orig_resnet)
    else:
        arch_net = DilatedResnetBackbone(orig_resnet, dilate_scale=dilate_scale, multi_grid=multi_grid)

    # resnet18/resnet34 report 512 features instead of the wrapper's default.
    if base in ('resnet18', 'resnet34'):
        arch_net.num_features = 512
    return arch_net