Spaces:
Runtime error
Runtime error
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import torch.optim as optim | |
import logging | |
def weights_init(init_type='gaussian', std=0.02): | |
def init_fun(m): | |
classname = m.__class__.__name__ | |
if (classname.find('Conv') == 0 or classname.find( | |
'Linear') == 0) and hasattr(m, 'weight'): | |
if init_type == 'gaussian': | |
nn.init.normal_(m.weight, 0.0, std) | |
elif init_type == 'xavier': | |
nn.init.xavier_normal_(m.weight, gain=math.sqrt(2)) | |
elif init_type == 'kaiming': | |
nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in') | |
elif init_type == 'orthogonal': | |
nn.init.orthogonal_(m.weight, gain=math.sqrt(2)) | |
elif init_type == 'default': | |
pass | |
else: | |
assert 0, "Unsupported initialization: {}".format(init_type) | |
if hasattr(m, 'bias') and m.bias is not None: | |
nn.init.constant_(m.bias, 0.0) | |
return init_fun | |
def freeze(module): | |
for param in module.parameters(): | |
param.requires_grad = False | |
def unfreeze(module): | |
for param in module.parameters(): | |
param.requires_grad = True | |
def get_optimizer(opt, model): | |
lr = float(opt['hyper_params']['lr']) | |
beta1 = float(opt['model']['beta1']) | |
weight_decay = float(opt['model']['weight_decay']) | |
opt_name = opt['model']['optimizer'] | |
optim_params = [] | |
# weight decay | |
for key, value in model.named_parameters(): | |
if not value.requires_grad: | |
continue # frozen weights | |
if key[-4:] == 'bias': | |
optim_params += [{'params': value, 'weight_decay': 0.0}] | |
else: | |
optim_params += [{'params': value, | |
'weight_decay': weight_decay}] | |
if opt_name == 'Adam': | |
return optim.Adam(optim_params, | |
lr=lr, | |
betas=(beta1, 0.999), | |
eps=1e-5) | |
else: | |
err = '{} not implemented yet'.format(opt_name) | |
logging.error(err) | |
raise NotImplementedError(err) | |
def get_activation(activation): | |
if activation is None: | |
return nn.Identity() | |
act_func = { | |
'relu':nn.ReLU(), | |
'sigmoid':nn.Sigmoid(), | |
'tanh':nn.Tanh(), | |
'prelu':nn.PReLU(), | |
'leaky':nn.LeakyReLU(0.2), | |
'gelu':nn.GELU(), | |
} | |
if activation not in act_func.keys(): | |
logging.error("activation {} is not implemented yet".format(activation)) | |
assert False | |
return act_func[activation] | |
def get_norm(out_channels, norm_type='Instance'): | |
norm_set = ['Instance', 'Batch', 'Group'] | |
if norm_type not in norm_set: | |
err = "Normalization {} has not been implemented yet" | |
logging.error(err) | |
raise ValueError(err) | |
if norm_type == 'Instance': | |
return nn.InstanceNorm2d(out_channels, affine=True) | |
if norm_type == 'Batch': | |
return nn.BatchNorm2d(out_channels) | |
if norm_type == 'Group': | |
if out_channels >= 32: | |
groups = 32 | |
else: | |
groups = 1 | |
return nn.GroupNorm(groups, out_channels) | |
else: | |
raise NotImplementedError('{} has not implemented yet'.format(norm_type)) | |
def get_layer_info(out_channels, activation_func='relu'): | |
activation = get_activation(activation_func) | |
norm_layer = get_norm(out_channels, 'Group') | |
return norm_layer, activation | |
class Conv(nn.Module): | |
""" (convolution => [BN] => ReLU) """ | |
def __init__(self, | |
in_channels, | |
out_channels, | |
kernel_size=3, | |
stride=1, | |
padding=1, | |
bias=True, | |
activation='leaky', | |
resnet=True): | |
super().__init__() | |
norm_layer, act_func = get_layer_info(out_channels,activation) | |
if resnet and in_channels == out_channels: | |
self.resnet = True | |
else: | |
self.resnet = False | |
self.conv = nn.Sequential( | |
nn.Conv2d(in_channels, out_channels, stride=stride, kernel_size=kernel_size, padding=padding, bias=bias), | |
norm_layer, | |
act_func) | |
def forward(self, x): | |
res = self.conv(x) | |
if self.resnet: | |
res = res + x | |
return res | |
class Up(nn.Module): | |
""" Upscaling then conv """ | |
def __init__(self, in_channels, out_channels, activation='relu', resnet=True): | |
super().__init__() | |
self.up_layer = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) | |
self.up = Conv(in_channels, out_channels, activation=activation, resnet=resnet) | |
def forward(self, x): | |
x = self.up_layer(x) | |
return self.up(x) | |
class DConv(nn.Module): | |
""" Double Conv Layer | |
""" | |
def __init__(self, in_channels, out_channels, activation='relu', resnet=True): | |
super().__init__() | |
self.conv1 = Conv(in_channels, out_channels, activation=activation, resnet=resnet) | |
self.conv2 = Conv(out_channels, out_channels, activation=activation, resnet=resnet) | |
def forward(self, x): | |
return self.conv2(self.conv1(x)) | |
class Encoder(nn.Module): | |
def __init__(self, in_channels=3, mid_act='leaky', resnet=True): | |
super(Encoder, self).__init__() | |
self.in_conv = Conv(in_channels, 32-in_channels, stride=1, activation=mid_act, resnet=resnet) | |
self.down_32_64 = Conv(32, 64, stride=2, activation=mid_act, resnet=resnet) | |
self.down_64_64_1 = Conv(64, 64, activation=mid_act, resnet=resnet) | |
self.down_64_128 = Conv(64, 128, stride=2, activation=mid_act, resnet=resnet) | |
self.down_128_128_1 = Conv(128, 128, activation=mid_act, resnet=resnet) | |
self.down_128_256 = Conv(128, 256, stride=2, activation=mid_act, resnet=resnet) | |
self.down_256_256_1 = Conv(256, 256, activation=mid_act, resnet=resnet) | |
self.down_256_512 = Conv(256, 512, stride=2, activation=mid_act, resnet=resnet) | |
self.down_512_512_1 = Conv(512, 512, activation=mid_act, resnet=resnet) | |
self.down_512_512_2 = Conv(512, 512, activation=mid_act, resnet=resnet) | |
self.down_512_512_3 = Conv(512, 512, activation=mid_act, resnet=resnet) | |
def forward(self, x): | |
x1 = self.in_conv(x) # 32 x 256 x 256 | |
x1 = torch.cat((x, x1), dim=1) | |
x2 = self.down_32_64(x1) | |
x3 = self.down_64_64_1(x2) | |
x4 = self.down_64_128(x3) | |
x5 = self.down_128_128_1(x4) | |
x6 = self.down_128_256(x5) | |
x7 = self.down_256_256_1(x6) | |
x8 = self.down_256_512(x7) | |
x9 = self.down_512_512_1(x8) | |
x10 = self.down_512_512_2(x9) | |
x11 = self.down_512_512_3(x10) | |
return x11, x10, x9, x8, x7, x6, x5, x4, x3, x2, x1 | |
class Decoder(nn.Module): | |
""" Up Stream Sequence """ | |
def __init__(self, | |
out_channels=3, | |
mid_act='relu', | |
out_act='sigmoid', | |
resnet = True): | |
super(Decoder, self).__init__() | |
input_channel = 512 | |
fea_dim = 100 | |
self.up_16_16_1 = Conv(input_channel, 256, activation=mid_act, resnet=resnet) | |
self.up_16_16_2 = Conv(768, 512, activation=mid_act, resnet=resnet) | |
self.up_16_16_3 = Conv(1024, 512, activation=mid_act, resnet=resnet) | |
self.up_16_32 = Up(1024, 256, activation=mid_act, resnet=resnet) | |
self.up_32_32_1 = Conv(512, 256, activation=mid_act, resnet=resnet) | |
self.up_32_64 = Up(512, 128, activation=mid_act, resnet=resnet) | |
self.up_64_64_1 = Conv(256, 128, activation=mid_act, resnet=resnet) | |
self.up_64_128 = Up(256, 64, activation=mid_act, resnet=resnet) | |
self.up_128_128_1 = Conv(128, 64, activation=mid_act, resnet=resnet) | |
self.up_128_256 = Up(128, 32, activation=mid_act, resnet=resnet) | |
self.out_conv = Conv(64, out_channels, activation=out_act) | |
def forward(self, x, ibl): | |
x11, x10, x9, x8, x7, x6, x5, x4, x3, x2, x1 = x | |
h,w = x10.shape[2:] | |
y = ibl.view(-1, 512, 1, 1).repeat(1, 1, h, w) | |
y = self.up_16_16_1(y) # 256 x 16 x 16 | |
y = torch.cat((x10, y), dim=1) # 768 x 16 x 16 | |
y = self.up_16_16_2(y) # 512 x 16 x 16 | |
y = torch.cat((x9, y), dim=1) # 1024 x 16 x 16 | |
y = self.up_16_16_3(y) # 512 x 16 x 16 | |
y = torch.cat((x8, y), dim=1) # 1024 x 16 x 16 | |
y = self.up_16_32(y) # 256 x 32 x 32 | |
y = torch.cat((x7, y), dim=1) | |
y = self.up_32_32_1(y) # 256 x 32 x 32 | |
y = torch.cat((x6, y), dim=1) | |
y = self.up_32_64(y) | |
y = torch.cat((x5, y), dim=1) | |
y = self.up_64_64_1(y) # 128 x 64 x 64 | |
y = torch.cat((x4, y), dim=1) | |
y = self.up_64_128(y) | |
y = torch.cat((x3, y), dim=1) | |
y = self.up_128_128_1(y) # 64 x 128 x 128 | |
y = torch.cat((x2, y), dim=1) | |
y = self.up_128_256(y) # 32 x 256 x 256 | |
y = torch.cat((x1, y), dim=1) | |
y = self.out_conv(y) # 3 x 256 x 256 | |
return y | |
class SSN_Model(nn.Module): | |
""" Implementation of Relighting Net """ | |
def __init__(self, | |
in_channels=3, | |
out_channels=3, | |
mid_act='leaky', | |
out_act='sigmoid', | |
resnet=True): | |
super(SSN_Model, self).__init__() | |
self.out_act = out_act | |
self.encoder = Encoder(in_channels, mid_act=mid_act, resnet=resnet) | |
self.decoder = Decoder(out_channels, mid_act=mid_act, out_act=out_act, resnet=resnet) | |
# init weights | |
init_func = weights_init('gaussian', std=1e-3) | |
self.encoder.apply(init_func) | |
self.decoder.apply(init_func) | |
def forward(self, x, ibl): | |
""" | |
Input is (source image, target light, source light, ) | |
Output is: predicted new image, predicted source light, self-supervision image | |
""" | |
latent = self.encoder(x) | |
pred = self.decoder(latent, ibl) | |
if self.out_act == 'sigmoid': | |
pred = pred * 30.0 | |
return pred | |
if __name__ == '__main__': | |
x = torch.randn(5,1,256,256) | |
ibl = torch.randn(5, 1, 32, 16) | |
model = SSN_Model(1,1) | |
y = model(x, ibl) | |
print('Output: ', y.shape) | |