# yichen-purdue's picture
# init
# 34fb220
import logging
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
def weights_init(init_type='gaussian', std=0.02):
    """Build a weight initializer to be used with ``nn.Module.apply``.

    Args:
        init_type: Initialization scheme. One of ``'gaussian'``,
            ``'xavier'``, ``'kaiming'``, ``'orthogonal'`` or ``'default'``
            (``'default'`` leaves the weights as they are).
        std: Standard deviation used by the ``'gaussian'`` scheme.

    Returns:
        A function taking a single module. It only acts on modules whose
        class name starts with ``Conv`` or ``Linear`` and which expose a
        ``weight`` attribute; for those it re-initializes the weight and
        sets a non-``None`` bias to zero (also for ``'default'``).

    Raises:
        AssertionError: Raised by the returned function, not by this
            factory, when ``init_type`` is unknown and a matching module
            is visited.
    """
    def init_fun(m):
        classname = m.__class__.__name__
        is_target = classname.startswith(('Conv', 'Linear'))
        if not (is_target and hasattr(m, 'weight')):
            return

        if init_type == 'gaussian':
            nn.init.normal_(m.weight, 0.0, std)
        elif init_type == 'xavier':
            nn.init.xavier_normal_(m.weight, gain=math.sqrt(2))
        elif init_type == 'kaiming':
            nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in')
        elif init_type == 'orthogonal':
            nn.init.orthogonal_(m.weight, gain=math.sqrt(2))
        elif init_type == 'default':
            pass
        else:
            # Raised explicitly (instead of ``assert 0``) so the check is
            # not stripped under ``python -O``; the exception type is kept.
            raise AssertionError(
                "Unsupported initialization: {}".format(init_type))

        if hasattr(m, 'bias') and m.bias is not None:
            nn.init.constant_(m.bias, 0.0)

    return init_fun
def freeze(module):
    """Turn off gradient tracking for every parameter of ``module``."""
    for p in module.parameters():
        p.requires_grad_(False)
def unfreeze(module):
    """Turn gradient tracking back on for every parameter of ``module``."""
    for p in module.parameters():
        p.requires_grad_(True)
def get_optimizer(opt, model):
    """Create the optimizer described by the ``opt`` configuration mapping.

    Reads ``opt['hyper_params']['lr']`` and, from ``opt['model']``, the keys
    ``beta1``, ``weight_decay`` and ``optimizer``. Every trainable parameter
    gets its own parameter group; parameters whose name ends in ``bias`` use
    a weight decay of 0, all others use the configured weight decay.

    Args:
        opt: Nested mapping holding the settings listed above.
        model: Object exposing ``named_parameters()``.

    Returns:
        An ``optim.Adam`` instance (the only supported optimizer).

    Raises:
        NotImplementedError: If ``opt['model']['optimizer']`` is not
            ``'Adam'``.
    """
    learning_rate = float(opt['hyper_params']['lr'])
    model_cfg = opt['model']
    beta1 = float(model_cfg['beta1'])
    decay = float(model_cfg['weight_decay'])
    opt_name = model_cfg['optimizer']

    # One group per parameter; frozen parameters are left out entirely.
    param_groups = [
        {'params': tensor,
         'weight_decay': 0.0 if name.endswith('bias') else decay}
        for name, tensor in model.named_parameters()
        if tensor.requires_grad
    ]

    if opt_name != 'Adam':
        err = '{} not implemented yet'.format(opt_name)
        logging.error(err)
        raise NotImplementedError(err)

    return optim.Adam(param_groups,
                      lr=learning_rate,
                      betas=(beta1, 0.999),
                      eps=1e-5)
def get_activation(activation):
    """Return a freshly constructed activation module.

    Args:
        activation: ``None`` (identity) or one of ``'relu'``, ``'sigmoid'``,
            ``'tanh'``, ``'prelu'``, ``'leaky'`` (LeakyReLU with slope 0.2)
            or ``'gelu'``.

    Returns:
        A new ``nn.Module`` instance for the requested activation.

    Raises:
        AssertionError: If ``activation`` is not one of the names above.
    """
    if activation is None:
        return nn.Identity()

    # Map names to constructors so only the requested module is built.
    factories = {
        'relu': nn.ReLU,
        'sigmoid': nn.Sigmoid,
        'tanh': nn.Tanh,
        'prelu': nn.PReLU,
        'leaky': lambda: nn.LeakyReLU(0.2),
        'gelu': nn.GELU,
    }

    if activation not in factories:
        err = "activation {} is not implemented yet".format(activation)
        logging.error(err)
        # Raised explicitly so the check survives ``python -O``; a bare
        # ``assert False`` would be stripped and surface as a KeyError.
        raise AssertionError(err)

    return factories[activation]()
def get_norm(out_channels, norm_type='Instance'):
    """Return a 2-D normalization layer for ``out_channels`` feature maps.

    Args:
        out_channels: Number of channels the layer normalizes.
        norm_type: ``'Instance'`` (affine ``InstanceNorm2d``), ``'Batch'``
            (``BatchNorm2d``) or ``'Group'`` (``GroupNorm`` with 32 groups
            when ``out_channels >= 32``, otherwise a single group).

    Returns:
        The constructed ``nn.Module``.

    Raises:
        ValueError: If ``norm_type`` is not one of the supported names.
            NOTE(review): ``nn.GroupNorm`` itself is expected to reject
            channel counts >= 32 that are not a multiple of 32 - confirm
            against callers.
    """
    if norm_type == 'Instance':
        return nn.InstanceNorm2d(out_channels, affine=True)

    if norm_type == 'Batch':
        return nn.BatchNorm2d(out_channels)

    if norm_type == 'Group':
        groups = 32 if out_channels >= 32 else 1
        return nn.GroupNorm(groups, out_channels)

    # The offending name is interpolated so the log and exception are useful.
    err = "Normalization {} has not been implemented yet".format(norm_type)
    logging.error(err)
    raise ValueError(err)
def get_layer_info(out_channels, activation_func='relu'):
    """Return the ``(norm_layer, activation)`` pair used by a conv block.

    The normalization is always requested as ``'Group'`` from ``get_norm``;
    the activation is looked up by name through ``get_activation``.
    """
    # The activation is resolved first, so an unknown name fails before
    # any normalization layer is built.
    act = get_activation(activation_func)
    return get_norm(out_channels, 'Group'), act
class Conv(nn.Module):
    """Conv2d => normalization => activation, with an optional skip add.

    The normalization and activation modules come from ``get_layer_info``.
    The input is added to the output only when ``resnet`` is truthy and
    ``in_channels == out_channels``.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=3,
                 stride=1,
                 padding=1,
                 bias=True,
                 activation='leaky',
                 resnet=True):
        super().__init__()

        norm_layer, act_func = get_layer_info(out_channels, activation)

        # A residual add needs matching channel counts on both sides.
        self.resnet = bool(resnet and in_channels == out_channels)

        conv2d = nn.Conv2d(in_channels,
                           out_channels,
                           stride=stride,
                           kernel_size=kernel_size,
                           padding=padding,
                           bias=bias)
        self.conv = nn.Sequential(conv2d, norm_layer, act_func)

    def forward(self, x):
        out = self.conv(x)
        return out + x if self.resnet else out
class Up(nn.Module):
    """Bilinear 2x upsampling followed by a single ``Conv`` block."""

    def __init__(self, in_channels, out_channels, activation='relu', resnet=True):
        super().__init__()
        self.up_layer = nn.Upsample(scale_factor=2,
                                    mode='bilinear',
                                    align_corners=True)
        self.up = Conv(in_channels,
                       out_channels,
                       activation=activation,
                       resnet=resnet)

    def forward(self, x):
        upsampled = self.up_layer(x)
        return self.up(upsampled)
class DConv(nn.Module):
    """Two ``Conv`` blocks applied back to back.

    The first maps ``in_channels`` to ``out_channels``; the second keeps
    ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, activation='relu', resnet=True):
        super().__init__()
        shared = dict(activation=activation, resnet=resnet)
        self.conv1 = Conv(in_channels, out_channels, **shared)
        self.conv2 = Conv(out_channels, out_channels, **shared)

    def forward(self, x):
        hidden = self.conv1(x)
        return self.conv2(hidden)
class Encoder(nn.Module):
    """Down-sampling stack of ``Conv`` blocks that keeps every feature map.

    The input is first passed through ``in_conv`` and concatenated with
    itself so the result has 32 channels. Ten more blocks follow; the four
    stride-2 blocks widen the channels 32 -> 64 -> 128 -> 256 -> 512.
    """

    def __init__(self, in_channels=3, mid_act='leaky', resnet=True):
        super().__init__()
        shared = dict(activation=mid_act, resnet=resnet)

        # Produces 32 - in_channels maps; forward() concatenates the raw
        # input back on to reach 32 channels.
        self.in_conv = Conv(in_channels, 32 - in_channels, stride=1, **shared)

        self.down_32_64 = Conv(32, 64, stride=2, **shared)
        self.down_64_64_1 = Conv(64, 64, **shared)

        self.down_64_128 = Conv(64, 128, stride=2, **shared)
        self.down_128_128_1 = Conv(128, 128, **shared)

        self.down_128_256 = Conv(128, 256, stride=2, **shared)
        self.down_256_256_1 = Conv(256, 256, **shared)

        self.down_256_512 = Conv(256, 512, stride=2, **shared)
        self.down_512_512_1 = Conv(512, 512, **shared)
        self.down_512_512_2 = Conv(512, 512, **shared)
        self.down_512_512_3 = Conv(512, 512, **shared)

    def forward(self, x):
        """Return all 11 feature maps as a tuple, deepest first."""
        stages = (
            self.down_32_64,
            self.down_64_64_1,
            self.down_64_128,
            self.down_128_128_1,
            self.down_128_256,
            self.down_256_256_1,
            self.down_256_512,
            self.down_512_512_1,
            self.down_512_512_2,
            self.down_512_512_3,
        )

        # First entry: raw input stacked with the in_conv features.
        feats = [torch.cat((x, self.in_conv(x)), dim=1)]
        for stage in stages:
            feats.append(stage(feats[-1]))

        return tuple(reversed(feats))
class Decoder(nn.Module):
    """Up-sampling stack that decodes a light vector with encoder skips.

    Decoding starts from ``ibl`` reshaped to ``input_channel`` (512)
    channels and tiled over the spatial size of the second-deepest encoder
    feature. Each later step concatenates one encoder feature in front of
    the running tensor along the channel axis and applies one block.
    """

    def __init__(self,
                 out_channels=3,
                 mid_act='relu',
                 out_act='sigmoid',
                 resnet=True):
        super().__init__()
        # Channel count the light input is reshaped to in forward().
        self.input_channel = 512
        shared = dict(activation=mid_act, resnet=resnet)

        self.up_16_16_1 = Conv(self.input_channel, 256, **shared)
        self.up_16_16_2 = Conv(768, 512, **shared)
        self.up_16_16_3 = Conv(1024, 512, **shared)

        self.up_16_32 = Up(1024, 256, **shared)
        self.up_32_32_1 = Conv(512, 256, **shared)

        self.up_32_64 = Up(512, 128, **shared)
        self.up_64_64_1 = Conv(256, 128, **shared)

        self.up_64_128 = Up(256, 64, **shared)
        self.up_128_128_1 = Conv(128, 64, **shared)

        self.up_128_256 = Up(128, 32, **shared)
        self.out_conv = Conv(64, out_channels, activation=out_act)

    def forward(self, x, ibl):
        """Decode ``ibl`` using the encoder features in ``x``.

        Args:
            x: Sequence of exactly 11 encoder features, deepest first (the
                order ``Encoder.forward`` returns). The deepest one is
                unpacked but not used here.
            ibl: Light tensor; its element count must be a multiple of
                ``input_channel`` because it is viewed as
                ``(-1, input_channel, 1, 1)``.

        Returns:
            The output of ``out_conv`` with ``out_channels`` channels.
        """
        _x11, x10, x9, x8, x7, x6, x5, x4, x3, x2, x1 = x

        h, w = x10.shape[2:]
        y = ibl.view(-1, self.input_channel, 1, 1).repeat(1, 1, h, w)
        y = self.up_16_16_1(y)  # 256 channels

        # Each skip is concatenated in front of y before its block runs.
        steps = (
            (x10, self.up_16_16_2),   # 768 -> 512
            (x9, self.up_16_16_3),    # 1024 -> 512
            (x8, self.up_16_32),      # 1024 -> 256, 2x upsample
            (x7, self.up_32_32_1),    # 512 -> 256
            (x6, self.up_32_64),      # 512 -> 128, 2x upsample
            (x5, self.up_64_64_1),    # 256 -> 128
            (x4, self.up_64_128),     # 256 -> 64, 2x upsample
            (x3, self.up_128_128_1),  # 128 -> 64
            (x2, self.up_128_256),    # 128 -> 32, 2x upsample
            (x1, self.out_conv),      # 64 -> out_channels
        )
        for skip, block in steps:
            y = block(torch.cat((skip, y), dim=1))

        return y
class SSN_Model(nn.Module):
    """Encoder/decoder network conditioned on a light tensor.

    Both sub-networks are initialized through
    ``weights_init('gaussian', std=1e-3)``.
    """

    def __init__(self,
                 in_channels=3,
                 out_channels=3,
                 mid_act='leaky',
                 out_act='sigmoid',
                 resnet=True):
        super().__init__()
        self.out_act = out_act

        self.encoder = Encoder(in_channels, mid_act=mid_act, resnet=resnet)
        self.decoder = Decoder(out_channels,
                               mid_act=mid_act,
                               out_act=out_act,
                               resnet=resnet)

        # init weights
        init_func = weights_init('gaussian', std=1e-3)
        for part in (self.encoder, self.decoder):
            part.apply(init_func)

    def forward(self, x, ibl):
        """Run the encoder on ``x`` and decode the result with ``ibl``.

        Args:
            x: Input image batch passed to the encoder.
            ibl: Light tensor passed to the decoder.

        Returns:
            The decoder output; it is multiplied by 30.0 when ``out_act``
            is ``'sigmoid'``.
        """
        features = self.encoder(x)
        pred = self.decoder(features, ibl)
        return pred * 30.0 if self.out_act == 'sigmoid' else pred
if __name__ == '__main__':
    # Smoke test: single-channel 256x256 images plus a 32x16 light map
    # (512 values per sample) pushed through a 1-in / 1-out model.
    batch = 5
    image = torch.randn(batch, 1, 256, 256)
    light = torch.randn(batch, 1, 32, 16)

    net = SSN_Model(1, 1)
    out = net(image, light)

    print('Output: ', out.shape)