from enum import Enum
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import logging
import math  # needed by weights_init for gain=math.sqrt(2)
def get_model_size(model):
param_size = 0
for param in model.parameters():
param_size += param.nelement() * param.element_size()
buffer_size = 0
for buffer in model.buffers():
buffer_size += buffer.nelement() * buffer.element_size()
size_all_mb = (param_size + buffer_size) / 1024 ** 2
print('model size: {:.3f}MB'.format(size_all_mb))
# return param_size + buffer_size
return size_all_mb
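# Usage sketch: get_model_size(model) prints the combined size of all
# parameters and buffers in MB and returns that value.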
def weights_init(init_type='gaussian'):
def init_fun(m):
classname = m.__class__.__name__
if (classname.find('Conv') == 0 or classname.find(
'Linear') == 0) and hasattr(m, 'weight'):
if init_type == 'gaussian':
nn.init.normal_(m.weight, 0.0, 0.02)
elif init_type == 'xavier':
nn.init.xavier_normal_(m.weight, gain=math.sqrt(2))
elif init_type == 'kaiming':
nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in')
elif init_type == 'orthogonal':
nn.init.orthogonal_(m.weight, gain=math.sqrt(2))
elif init_type == 'default':
pass
else:
                raise NotImplementedError("Unsupported initialization: {}".format(init_type))
if hasattr(m, 'bias') and m.bias is not None:
nn.init.constant_(m.bias, 0.0)
return init_fun
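# Usage sketch: the factory above returns an init function intended for
# nn.Module.apply, e.g. model.apply(weights_init('kaiming')), which
# re-initializes the weights of every Conv*/Linear layer and zeroes their biases.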
def freeze(module):
for param in module.parameters():
param.requires_grad = False
def unfreeze(module):
for param in module.parameters():
param.requires_grad = True
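# freeze/unfreeze toggle requires_grad for every parameter of a module;
# parameters frozen this way are skipped by get_optimizer below.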
def get_optimizer(opt, model):
lr = float(opt['hyper_params']['lr'])
beta1 = float(opt['model']['beta1'])
weight_decay = float(opt['model']['weight_decay'])
opt_name = opt['model']['optimizer']
optim_params = []
# weight decay
for key, value in model.named_parameters():
if not value.requires_grad:
continue # frozen weights
        if key.endswith('bias'):
optim_params += [{'params': value, 'weight_decay': 0.0}]
else:
optim_params += [{'params': value,
'weight_decay': weight_decay}]
if opt_name == 'Adam':
return optim.Adam(optim_params,
lr=lr,
betas=(beta1, 0.999),
eps=1e-5)
else:
err = '{} not implemented yet'.format(opt_name)
logging.error(err)
raise NotImplementedError(err)
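# Sketch of the config layout get_optimizer expects (illustrative values only;
# the real config may contain additional keys):
# opt = {'hyper_params': {'lr': 1e-4},
#        'model': {'beta1': 0.9, 'weight_decay': 1e-4, 'optimizer': 'Adam'}}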
def get_activation(activation):
act_func = {
'relu':nn.ReLU(),
'sigmoid':nn.Sigmoid(),
'tanh':nn.Tanh(),
'prelu':nn.PReLU(),
'leaky_relu':nn.LeakyReLU(0.2),
'gelu':nn.GELU(),
}
    if activation not in act_func:
        err = "activation {} is not implemented yet".format(activation)
        logging.error(err)
        raise NotImplementedError(err)
return act_func[activation]
def get_norm(out_channels, norm_type='Group', groups=32):
norm_set = ['Instance', 'Batch', 'Group']
    if norm_type not in norm_set:
        err = "Normalization {} has not been implemented yet".format(norm_type)
        logging.error(err)
        raise ValueError(err)
if norm_type == 'Instance':
return nn.InstanceNorm2d(out_channels, affine=True)
if norm_type == 'Batch':
return nn.BatchNorm2d(out_channels)
if norm_type == 'Group':
if out_channels >= 32:
groups = 32
else:
groups = max(out_channels // 2, 1)
return nn.GroupNorm(groups, out_channels)
else:
raise NotImplementedError
class Conv(nn.Module):
def __init__(self, in_channels, out_channels, stride=1, norm_type='Batch', activation='relu'):
super().__init__()
act_func = get_activation(activation)
norm_layer = get_norm(out_channels, norm_type)
self.conv = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=True, padding_mode='reflect'),
norm_layer,
act_func)
def forward(self, x):
return self.conv(x)
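# Conv is a 3x3 convolution -> normalization -> activation block with reflect
# padding; e.g. Conv(3, 64, stride=2, norm_type='Group', activation='gelu')
# maps 3 channels to 64 while halving the spatial resolution.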
def zero_module(module):
"""
Zero out the parameters of a module and return it.
"""
for p in module.parameters():
p.detach().zero_()
return module
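# zero_module is used below to zero-initialize the last conv of ResBlock's
# out_layer, so the residual branch contributes nothing at initialization.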
class Up(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
        return F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
class Down(nn.Module):
def __init__(self, channels, use_conv):
super().__init__()
self.use_conv = use_conv
if self.use_conv:
self.op = nn.Conv2d(channels, channels, 3, stride=2, padding=1)
else:
self.op = nn.AvgPool2d(kernel_size=3, stride=2, padding=1)
def forward(self, x):
return self.op(x)
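# Up doubles the spatial resolution via bilinear interpolation; Down halves it,
# either with a strided 3x3 conv (use_conv=True) or 3x3 average pooling.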
class Res_Type(Enum):
UP = 1
DOWN = 2
SAME = 3
class ResBlock(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, dropout=0.0, updown=Res_Type.DOWN, mid_act='leaky_relu'):
""" ResBlock to cover several cases:
1. Up/Down/Same
2. in_channels != out_channels
"""
super().__init__()
self.updown = updown
        self.in_norm = get_norm(in_channels, 'Group')
self.in_act = get_activation(mid_act)
self.in_conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=True)
# up down
if self.updown == Res_Type.DOWN:
self.h_updown = Down(in_channels, use_conv=True)
self.x_updown = Down(in_channels, use_conv=True)
elif self.updown == Res_Type.UP:
self.h_updown = Up()
self.x_updown = Up()
        else:
            self.h_updown = nn.Identity()
            self.x_updown = nn.Identity()
        # project the skip path with a 1x1 conv when in_channels != out_channels,
        # so the residual addition x + h in forward() is well defined
        if in_channels != out_channels:
            self.skip = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        else:
            self.skip = nn.Identity()
self.out_layer = nn.Sequential(
get_norm(out_channels, 'Group'),
get_activation(mid_act),
nn.Dropout(p=dropout),
zero_module(nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=True))
)
def forward(self, x):
# in layer
h = self.in_act(self.in_norm(x))
h = self.in_conv(self.h_updown(h))
        x = self.skip(self.x_updown(x))
# out layer
h = self.out_layer(h)
return x + h
if __name__ == '__main__':
x = torch.randn(5, 3, 256, 256)
up = Up()
conv_down = Down(3, True)
pool_down = Down(3, False)
print('Up: {}'.format(up(x).shape))
print('Conv down: {}'.format(conv_down(x).shape))
print('Pool down: {}'.format(pool_down(x).shape))
    up_model = ResBlock(3, 6, updown=Res_Type.UP)
    down_model = ResBlock(3, 6, updown=Res_Type.DOWN)
    print('model up: {}'.format(up_model(x).shape))
    print('model down: {}'.format(down_model(x).shape))