Spaces:
Runtime error
Runtime error
#!/usr/bin/env python | |
# -*- coding:utf-8 -*- | |
# Author: Donny You (youansheng@gmail.com) | |
import os | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
try: | |
from urllib import urlretrieve | |
except ImportError: | |
from urllib.request import urlretrieve | |
class FixedBatchNorm(nn.BatchNorm2d):
    """BatchNorm2d variant whose forward pass always uses the stored running statistics."""

    def forward(self, input):
        """Normalise ``input`` with ``running_mean``/``running_var`` and the affine parameters.

        ``training=False`` is passed unconditionally, so the call does not depend
        on the module's train/eval mode.
        """
        frozen_stats = (self.running_mean, self.running_var)
        affine = {'weight': self.weight, 'bias': self.bias}
        return F.batch_norm(input, *frozen_stats, training=False, eps=self.eps, **affine)
class ModuleHelper(object):
    """Collection of static helpers: norm-layer factories, weight loading and parameter init.

    Every helper is a ``staticmethod`` so it can be called either on the class
    (``ModuleHelper.BNReLU(...)``) or on an instance.
    """

    @staticmethod
    def BNReLU(num_features, norm_type=None, **kwargs):
        """Return ``nn.Sequential(<norm layer>, nn.ReLU())`` for 2-D feature maps.

        Args:
            num_features: passed as first argument to the norm layer constructor.
            norm_type: one of ``'batchnorm'``, ``'encsync_batchnorm'``,
                ``'instancenorm'`` or ``'fixed_batchnorm'``.
            **kwargs: forwarded unchanged to the norm layer constructor.

        Raises:
            ValueError: if ``norm_type`` is not one of the supported names.
        """
        if norm_type == 'batchnorm':
            norm_cls = nn.BatchNorm2d
        elif norm_type == 'encsync_batchnorm':
            # Imported lazily so the optional ``encoding`` package is only
            # required when this norm type is actually requested.
            from encoding.nn import BatchNorm2d as norm_cls
        elif norm_type == 'instancenorm':
            norm_cls = nn.InstanceNorm2d
        elif norm_type == 'fixed_batchnorm':
            norm_cls = FixedBatchNorm
        else:
            raise ValueError('Not support BN type: {}.'.format(norm_type))

        return nn.Sequential(
            norm_cls(num_features, **kwargs),
            nn.ReLU()
        )

    @staticmethod
    def BatchNorm3d(norm_type=None, ret_cls=False):
        """Return the 3-D norm layer *class* selected by ``norm_type``.

        Args:
            norm_type: ``'batchnorm'``, ``'encsync_batchnorm'`` or ``'instancenorm'``.
            ret_cls: not used by this function; kept so existing callers still work.

        Raises:
            ValueError: if ``norm_type`` is not supported.
        """
        if norm_type == 'batchnorm':
            return nn.BatchNorm3d
        if norm_type == 'encsync_batchnorm':
            from encoding.nn import BatchNorm3d
            return BatchNorm3d
        if norm_type == 'instancenorm':
            return nn.InstanceNorm3d
        raise ValueError('Not support BN type: {}.'.format(norm_type))

    @staticmethod
    def BatchNorm2d(norm_type=None, ret_cls=False):
        """Return the 2-D norm layer *class* selected by ``norm_type``.

        Args:
            norm_type: ``'batchnorm'``, ``'encsync_batchnorm'`` or ``'instancenorm'``.
            ret_cls: not used by this function; kept so existing callers still work.

        Raises:
            ValueError: if ``norm_type`` is not supported.
        """
        if norm_type == 'batchnorm':
            return nn.BatchNorm2d
        if norm_type == 'encsync_batchnorm':
            from encoding.nn import BatchNorm2d
            return BatchNorm2d
        if norm_type == 'instancenorm':
            return nn.InstanceNorm2d
        raise ValueError('Not support BN type: {}.'.format(norm_type))

    @staticmethod
    def BatchNorm1d(norm_type=None, ret_cls=False):
        """Return the 1-D norm layer *class* selected by ``norm_type``.

        Args:
            norm_type: ``'batchnorm'``, ``'encsync_batchnorm'`` or ``'instancenorm'``.
            ret_cls: not used by this function; kept so existing callers still work.

        Raises:
            ValueError: if ``norm_type`` is not supported.
        """
        if norm_type == 'batchnorm':
            return nn.BatchNorm1d
        if norm_type == 'encsync_batchnorm':
            from encoding.nn import BatchNorm1d
            return BatchNorm1d
        if norm_type == 'instancenorm':
            return nn.InstanceNorm1d
        raise ValueError('Not support BN type: {}.'.format(norm_type))

    @staticmethod
    def load_model(model, pretrained=None, all_match=True, map_location='cpu'):
        """Load a checkpoint file into ``model`` and return ``model``.

        Args:
            model: module whose ``state_dict``/``load_state_dict`` are used.
            pretrained: checkpoint path; ``None`` returns ``model`` untouched.
            all_match: if true, every checkpoint entry is loaded (a key is
                renamed to ``'prefix.<key>'`` when that name exists in the
                model). If false, only checkpoint keys that also exist in the
                model are loaded and the rest of the model state is kept.
            map_location: forwarded to ``torch.load`` in both modes.

        Raises:
            FileNotFoundError: if the checkpoint cannot be found.
        """
        if pretrained is None:
            return model

        if not os.path.exists(pretrained):
            # NOTE(review): machine-specific fallback root kept for backward
            # compatibility; consider making it configurable.
            pretrained = pretrained.replace("..", "/home/gishin-temp/projects/open_set/segmentation")
            if not os.path.exists(pretrained):
                raise FileNotFoundError('{} not exists.'.format(pretrained))

        print('Loading pretrained model:{}'.format(pretrained))
        # NOTE(security): torch.load may unpickle arbitrary objects; only use
        # checkpoints from trusted sources.
        # map_location is applied in both modes (the partial-match path
        # previously ignored it).
        pretrained_dict = torch.load(pretrained, map_location=map_location)
        model_dict = model.state_dict()

        if all_match:
            load_dict = dict()
            for k, v in pretrained_dict.items():
                prefixed = 'prefix.{}'.format(k)
                load_dict[prefixed if prefixed in model_dict else k] = v
            model.load_state_dict(load_dict)
        else:
            load_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
            print('Matched Keys: {}'.format(load_dict.keys()))
            model_dict.update(load_dict)
            model.load_state_dict(model_dict)

        return model

    @staticmethod
    def load_url(url, map_location=None):
        """Download ``url`` into ``~/.TorchCV/model`` (if not cached) and ``torch.load`` it.

        Args:
            url: address of the checkpoint; the last ``/``-separated component
                is used as the cache file name.
            map_location: forwarded to ``torch.load``.

        Returns:
            Whatever ``torch.load`` returns for the cached file.
        """
        # expanduser is required: a literal '~' would otherwise create a
        # directory named '~' relative to the current working directory.
        model_dir = os.path.expanduser(os.path.join('~', '.TorchCV', 'model'))
        # exist_ok avoids failing when the directory appears between a check
        # and the creation.
        os.makedirs(model_dir, exist_ok=True)

        filename = url.split('/')[-1]
        cached_file = os.path.join(model_dir, filename)
        if not os.path.exists(cached_file):
            print('Downloading: "{}" to {}\n'.format(url, cached_file))
            urlretrieve(url, cached_file)

        print('Loading pretrained model:{}'.format(cached_file))
        # NOTE(security): torch.load may unpickle arbitrary objects; only use
        # URLs from trusted sources.
        return torch.load(cached_file, map_location=map_location)

    @staticmethod
    def constant_init(module, val, bias=0):
        """Fill ``module.weight`` with ``val`` and ``module.bias`` with ``bias``.

        A missing or ``None`` weight/bias is skipped instead of raising.
        """
        if getattr(module, 'weight', None) is not None:
            nn.init.constant_(module.weight, val)
        if getattr(module, 'bias', None) is not None:
            nn.init.constant_(module.bias, bias)

    @staticmethod
    def xavier_init(module, gain=1, bias=0, distribution='normal'):
        """Xavier-initialise ``module.weight`` and fill ``module.bias`` with ``bias``.

        Args:
            module: module to initialise in place.
            gain: forwarded to the Xavier initialiser.
            bias: constant written to the bias, when present.
            distribution: ``'uniform'`` or ``'normal'``.

        A missing or ``None`` weight/bias is skipped instead of raising.
        """
        assert distribution in ['uniform', 'normal']
        if getattr(module, 'weight', None) is not None:
            if distribution == 'uniform':
                nn.init.xavier_uniform_(module.weight, gain=gain)
            else:
                nn.init.xavier_normal_(module.weight, gain=gain)
        if getattr(module, 'bias', None) is not None:
            nn.init.constant_(module.bias, bias)

    @staticmethod
    def normal_init(module, mean=0, std=1, bias=0):
        """Draw ``module.weight`` from N(mean, std) and fill ``module.bias`` with ``bias``.

        A missing or ``None`` weight/bias is skipped instead of raising.
        """
        if getattr(module, 'weight', None) is not None:
            nn.init.normal_(module.weight, mean, std)
        if getattr(module, 'bias', None) is not None:
            nn.init.constant_(module.bias, bias)

    @staticmethod
    def uniform_init(module, a=0, b=1, bias=0):
        """Draw ``module.weight`` from U(a, b) and fill ``module.bias`` with ``bias``.

        A missing or ``None`` weight/bias is skipped instead of raising.
        """
        if getattr(module, 'weight', None) is not None:
            nn.init.uniform_(module.weight, a, b)
        if getattr(module, 'bias', None) is not None:
            nn.init.constant_(module.bias, bias)

    @staticmethod
    def kaiming_init(module,
                     mode='fan_in',
                     nonlinearity='leaky_relu',
                     bias=0,
                     distribution='normal'):
        """Kaiming-initialise ``module.weight`` and fill ``module.bias`` with ``bias``.

        Args:
            module: module to initialise in place.
            mode: forwarded to the Kaiming initialiser.
            nonlinearity: forwarded to the Kaiming initialiser.
            bias: constant written to the bias, when present.
            distribution: ``'uniform'`` or ``'normal'``.

        A missing or ``None`` weight/bias is skipped instead of raising.
        """
        assert distribution in ['uniform', 'normal']
        if getattr(module, 'weight', None) is not None:
            if distribution == 'uniform':
                nn.init.kaiming_uniform_(
                    module.weight, mode=mode, nonlinearity=nonlinearity)
            else:
                nn.init.kaiming_normal_(
                    module.weight, mode=mode, nonlinearity=nonlinearity)
        if getattr(module, 'bias', None) is not None:
            nn.init.constant_(module.bias, bias)