import logging
import math
from enum import Enum

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


def get_model_size(model):
    """Return the total size of parameters and buffers in megabytes."""
    param_size = 0
    for param in model.parameters():
        param_size += param.nelement() * param.element_size()
    buffer_size = 0
    for buffer in model.buffers():
        buffer_size += buffer.nelement() * buffer.element_size()
    size_all_mb = (param_size + buffer_size) / 1024 ** 2
    print('model size: {:.3f}MB'.format(size_all_mb))
    return size_all_mb


def weights_init(init_type='gaussian'):
    """Return an initializer closure to be passed to ``nn.Module.apply``."""
    def init_fun(m):
        classname = m.__class__.__name__
        if (classname.find('Conv') == 0 or classname.find('Linear') == 0) and hasattr(m, 'weight'):
            if init_type == 'gaussian':
                nn.init.normal_(m.weight, 0.0, 0.02)
            elif init_type == 'xavier':
                nn.init.xavier_normal_(m.weight, gain=math.sqrt(2))
            elif init_type == 'kaiming':
                nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                nn.init.orthogonal_(m.weight, gain=math.sqrt(2))
            elif init_type == 'default':
                pass
            else:
                assert 0, "Unsupported initialization: {}".format(init_type)
            if hasattr(m, 'bias') and m.bias is not None:
                nn.init.constant_(m.bias, 0.0)
    return init_fun
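
# Illustrative usage sketch (the module name is only an example; any nn.Module
# defined in this file works the same way):
#
#     net = Conv(3, 64)
#     net.apply(weights_init('kaiming'))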


def freeze(module):
    for param in module.parameters():
        param.requires_grad = False


def unfreeze(module):
    for param in module.parameters():
        param.requires_grad = True


def get_optimizer(opt, model):
    lr = float(opt['hyper_params']['lr'])
    beta1 = float(opt['model']['beta1'])
    weight_decay = float(opt['model']['weight_decay'])
    opt_name = opt['model']['optimizer']

    optim_params = []
    # weight decay
    for key, value in model.named_parameters():
        if not value.requires_grad:
            continue  # frozen weights
        if key[-4:] == 'bias':
            optim_params += [{'params': value, 'weight_decay': 0.0}]
        else:
            optim_params += [{'params': value, 'weight_decay': weight_decay}]

    if opt_name == 'Adam':
        return optim.Adam(optim_params,
                          lr=lr,
                          betas=(beta1, 0.999),
                          eps=1e-5)
    else:
        err = '{} not implemented yet'.format(opt_name)
        logging.error(err)
        raise NotImplementedError(err)
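
# Illustrative sketch of the option layout this function expects (the keys come
# from the lookups above; the values here are assumed placeholders, not the
# project's real config):
#
#     opt = {
#         'hyper_params': {'lr': 2e-4},
#         'model': {'beta1': 0.9, 'weight_decay': 1e-4, 'optimizer': 'Adam'},
#     }
#     optimizer = get_optimizer(opt, model)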


def get_activation(activation):
    act_func = {
        'relu': nn.ReLU(),
        'sigmoid': nn.Sigmoid(),
        'tanh': nn.Tanh(),
        'prelu': nn.PReLU(),
        'leaky_relu': nn.LeakyReLU(0.2),
        'gelu': nn.GELU(),
    }
    if activation not in act_func:
        err = 'activation {} is not implemented yet'.format(activation)
        logging.error(err)
        raise NotImplementedError(err)
    return act_func[activation]


def get_norm(out_channels, norm_type='Group', groups=32):
    norm_set = ['Instance', 'Batch', 'Group']
    if norm_type not in norm_set:
        err = 'Normalization {} has not been implemented yet'.format(norm_type)
        logging.error(err)
        raise ValueError(err)

    if norm_type == 'Instance':
        return nn.InstanceNorm2d(out_channels, affine=True)
    if norm_type == 'Batch':
        return nn.BatchNorm2d(out_channels)
    if norm_type == 'Group':
        # use 32 groups when possible, otherwise fall back to a smaller divisor
        if out_channels >= 32:
            groups = 32
        else:
            groups = max(out_channels // 2, 1)
        return nn.GroupNorm(groups, out_channels)
    raise NotImplementedError(norm_type)
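
# Illustrative behaviour sketch of the group fallback above:
#
#     get_norm(64, 'Group')  # -> nn.GroupNorm(32, 64)
#     get_norm(6, 'Group')   # -> nn.GroupNorm(3, 6), since 6 < 32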


class Conv(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1, norm_type='Batch', activation='relu'):
        super().__init__()
        act_func = get_activation(activation)
        norm_layer = get_norm(out_channels, norm_type)
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride,
                      padding=1, bias=True, padding_mode='reflect'),
            norm_layer,
            act_func)

    def forward(self, x):
        return self.conv(x)


def zero_module(module):
    """
    Zero out the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().zero_()
    return module


class Up(nn.Module):
    """Upsample by a factor of 2 with bilinear interpolation."""

    def forward(self, x):
        return F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)


class Down(nn.Module):
    def __init__(self, channels, use_conv):
        super().__init__()
        self.use_conv = use_conv
        if self.use_conv:
            self.op = nn.Conv2d(channels, channels, 3, stride=2, padding=1)
        else:
            self.op = nn.AvgPool2d(kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        return self.op(x)


class Res_Type(Enum):
    UP = 1
    DOWN = 2
    SAME = 3


class ResBlock(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, dropout=0.0,
                 updown=Res_Type.DOWN, mid_act='leaky_relu'):
        """ResBlock covering several cases:
        1. Up / Down / Same resolution
        2. in_channels != out_channels
        """
        super().__init__()
        self.updown = updown

        # in layer: normalize and activate the input before the first conv
        self.in_norm = get_norm(in_channels, 'Group')
        self.in_act = get_activation(mid_act)
        self.in_conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=True)

        # up / down / same resolution, applied to both the residual and skip paths
        if self.updown == Res_Type.DOWN:
            self.h_updown = Down(in_channels, use_conv=True)
            self.x_updown = Down(in_channels, use_conv=True)
        elif self.updown == Res_Type.UP:
            self.h_updown = Up()
            self.x_updown = Up()
        else:
            self.h_updown = nn.Identity()
            self.x_updown = nn.Identity()

        # project the skip connection when the channel count changes
        if in_channels != out_channels:
            self.skip = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        else:
            self.skip = nn.Identity()

        self.out_layer = nn.Sequential(
            get_norm(out_channels, 'Group'),
            get_activation(mid_act),
            nn.Dropout(p=dropout),
            zero_module(nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=True))
        )

    def forward(self, x):
        # in layer
        h = self.in_act(self.in_norm(x))
        h = self.in_conv(self.h_updown(h))
        x = self.skip(self.x_updown(x))
        # out layer
        h = self.out_layer(h)
        return x + h


if __name__ == '__main__':
    x = torch.randn(5, 3, 256, 256)
    up = Up()
    conv_down = Down(3, True)
    pool_down = Down(3, False)
    print('Up: {}'.format(up(x).shape))
    print('Conv down: {}'.format(conv_down(x).shape))
    print('Pool down: {}'.format(pool_down(x).shape))

    up_model = ResBlock(3, 6, updown=Res_Type.UP)
    down_model = ResBlock(3, 6, updown=Res_Type.DOWN)
    print('model up: {}'.format(up_model(x).shape))
    print('model down: {}'.format(down_model(x).shape))
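
    # Extra sanity checks for the remaining helpers (illustrative additions,
    # not part of the original sanity-check script).
    get_model_size(down_model)
    freeze(down_model)
    print('any requires_grad after freeze: {}'.format(
        any(p.requires_grad for p in down_model.parameters())))
    unfreeze(down_model)
    print('all requires_grad after unfreeze: {}'.format(
        all(p.requires_grad for p in down_model.parameters())))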