#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Author: Donny You (youansheng@gmail.com)
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
try:
    from urllib import urlretrieve
except ImportError:
    from urllib.request import urlretrieve
class FixedBatchNorm(nn.BatchNorm2d):
    """BatchNorm2d with frozen running statistics: always normalizes with
    training=False, so the stored running_mean/running_var are never updated."""
    def forward(self, input):
        return F.batch_norm(input, self.running_mean, self.running_var, self.weight, self.bias,
                            training=False, eps=self.eps)
class ModuleHelper(object):

    @staticmethod
    def BNReLU(num_features, norm_type=None, **kwargs):
        if norm_type == 'batchnorm':
            return nn.Sequential(
                nn.BatchNorm2d(num_features, **kwargs),
                nn.ReLU()
            )
        elif norm_type == 'encsync_batchnorm':
            # Synchronized BN; requires the torch-encoding package.
            from encoding.nn import BatchNorm2d
            return nn.Sequential(
                BatchNorm2d(num_features, **kwargs),
                nn.ReLU()
            )
        elif norm_type == 'instancenorm':
            return nn.Sequential(
                nn.InstanceNorm2d(num_features, **kwargs),
                nn.ReLU()
            )
        elif norm_type == 'fixed_batchnorm':
            return nn.Sequential(
                FixedBatchNorm(num_features, **kwargs),
                nn.ReLU()
            )
        else:
            raise ValueError('Unsupported norm type: {}.'.format(norm_type))
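    # Hedged usage sketch (not from the source): build a BN+ReLU block for a
    # 64-channel feature map. The channel count and input tensor are illustrative.
    #
    #   bn_relu = ModuleHelper.BNReLU(64, norm_type='batchnorm')
    #   out = bn_relu(torch.randn(1, 64, 32, 32))  # shape preserved: (1, 64, 32, 32)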
    @staticmethod
    def BatchNorm3d(norm_type=None, ret_cls=False):
        # ret_cls is unused and kept only for API compatibility; the norm class
        # itself (not an instance) is always returned.
        if norm_type == 'batchnorm':
            return nn.BatchNorm3d
        elif norm_type == 'encsync_batchnorm':
            from encoding.nn import BatchNorm3d
            return BatchNorm3d
        elif norm_type == 'instancenorm':
            return nn.InstanceNorm3d
        else:
            raise ValueError('Unsupported norm type: {}.'.format(norm_type))
    @staticmethod
    def BatchNorm2d(norm_type=None, ret_cls=False):
        if norm_type == 'batchnorm':
            return nn.BatchNorm2d
        elif norm_type == 'encsync_batchnorm':
            from encoding.nn import BatchNorm2d
            return BatchNorm2d
        elif norm_type == 'instancenorm':
            return nn.InstanceNorm2d
        else:
            raise ValueError('Unsupported norm type: {}.'.format(norm_type))
    @staticmethod
    def BatchNorm1d(norm_type=None, ret_cls=False):
        if norm_type == 'batchnorm':
            return nn.BatchNorm1d
        elif norm_type == 'encsync_batchnorm':
            from encoding.nn import BatchNorm1d
            return BatchNorm1d
        elif norm_type == 'instancenorm':
            return nn.InstanceNorm1d
        else:
            raise ValueError('Unsupported norm type: {}.'.format(norm_type))
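    # Hedged usage sketch (not from the source): the factory returns a class,
    # not an instance, so it can be passed into model constructors and
    # instantiated later. The channel count is illustrative.
    #
    #   BNClass = ModuleHelper.BatchNorm2d(norm_type='batchnorm')  # -> nn.BatchNorm2d
    #   bn = BNClass(128)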
    @staticmethod
    def load_model(model, pretrained=None, all_match=True, map_location='cpu'):
        if pretrained is None:
            return model

        if not os.path.exists(pretrained):
            # Fall back to an absolute project path when the relative path does not resolve.
            pretrained = pretrained.replace("..", "/home/gishin-temp/projects/open_set/segmentation")
            if not os.path.exists(pretrained):
                raise FileNotFoundError('{} does not exist.'.format(pretrained))

        print('Loading pretrained model: {}'.format(pretrained))
        if all_match:
            pretrained_dict = torch.load(pretrained, map_location=map_location)
            model_dict = model.state_dict()
            load_dict = dict()
            for k, v in pretrained_dict.items():
                # Accept checkpoint keys that live under a 'prefix.' namespace in the model.
                if 'prefix.{}'.format(k) in model_dict:
                    load_dict['prefix.{}'.format(k)] = v
                else:
                    load_dict[k] = v
            model.load_state_dict(load_dict)
        else:
            pretrained_dict = torch.load(pretrained, map_location=map_location)
            model_dict = model.state_dict()
            # Keep only the checkpoint keys that exist in the current model.
            load_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
            print('Matched Keys: {}'.format(load_dict.keys()))
            model_dict.update(load_dict)
            model.load_state_dict(model_dict)

        return model
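    # Hedged usage sketch (not from the source): 'resnet101.pth' is a
    # hypothetical checkpoint path. With all_match=True every checkpoint key
    # must map into the model (optionally under 'prefix.'); with
    # all_match=False only the keys present in the model are loaded.
    #
    #   model = ModuleHelper.load_model(model, pretrained='resnet101.pth', all_match=False)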
    @staticmethod
    def load_url(url, map_location=None):
        # expanduser resolves '~'; without it a literal './~' directory is created.
        model_dir = os.path.expanduser(os.path.join('~', '.TorchCV', 'model'))
        if not os.path.exists(model_dir):
            os.makedirs(model_dir)

        filename = url.split('/')[-1]
        cached_file = os.path.join(model_dir, filename)
        if not os.path.exists(cached_file):
            print('Downloading: "{}" to {}\n'.format(url, cached_file))
            urlretrieve(url, cached_file)

        print('Loading pretrained model: {}'.format(cached_file))
        return torch.load(cached_file, map_location=map_location)
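    # Hedged usage sketch (not from the source): the URL is a placeholder, not a
    # real checkpoint. The file is cached under ~/.TorchCV/model, downloaded once,
    # and loaded with torch.load on subsequent calls.
    #
    #   state_dict = ModuleHelper.load_url('https://example.com/resnet101.pth')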
    @staticmethod
    def constant_init(module, val, bias=0):
        nn.init.constant_(module.weight, val)
        if hasattr(module, 'bias') and module.bias is not None:
            nn.init.constant_(module.bias, bias)

    @staticmethod
    def xavier_init(module, gain=1, bias=0, distribution='normal'):
        assert distribution in ['uniform', 'normal']
        if distribution == 'uniform':
            nn.init.xavier_uniform_(module.weight, gain=gain)
        else:
            nn.init.xavier_normal_(module.weight, gain=gain)
        if hasattr(module, 'bias') and module.bias is not None:
            nn.init.constant_(module.bias, bias)

    @staticmethod
    def normal_init(module, mean=0, std=1, bias=0):
        nn.init.normal_(module.weight, mean, std)
        if hasattr(module, 'bias') and module.bias is not None:
            nn.init.constant_(module.bias, bias)

    @staticmethod
    def uniform_init(module, a=0, b=1, bias=0):
        nn.init.uniform_(module.weight, a, b)
        if hasattr(module, 'bias') and module.bias is not None:
            nn.init.constant_(module.bias, bias)

    @staticmethod
    def kaiming_init(module,
                     mode='fan_in',
                     nonlinearity='leaky_relu',
                     bias=0,
                     distribution='normal'):
        assert distribution in ['uniform', 'normal']
        if distribution == 'uniform':
            nn.init.kaiming_uniform_(
                module.weight, mode=mode, nonlinearity=nonlinearity)
        else:
            nn.init.kaiming_normal_(
                module.weight, mode=mode, nonlinearity=nonlinearity)
        if hasattr(module, 'bias') and module.bias is not None:
            nn.init.constant_(module.bias, bias)
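
# Minimal self-test sketch (not part of the original module): exercises the
# init helpers on throwaway layers. Layer sizes are arbitrary assumptions.
if __name__ == '__main__':
    # Kaiming-normal init for a ReLU conv layer; bias zeroed.
    conv = nn.Conv2d(3, 16, kernel_size=3, padding=1)
    ModuleHelper.kaiming_init(conv, mode='fan_out', nonlinearity='relu')
    # Xavier-uniform init for a linear head.
    fc = nn.Linear(16, 10)
    ModuleHelper.xavier_init(fc, distribution='uniform')
    print('conv weight std: {:.4f}'.format(conv.weight.std().item()))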