import torch
import torch.nn as nn
from torch.nn import init
from torch.optim import lr_scheduler
from collections import OrderedDict


def get_scheduler(optimizer, opt):
    """Return a learning-rate scheduler selected by opt.lr_policy."""
    if opt.lr_policy == 'linear':
        # keep the initial lr for opt.niter epochs, then decay linearly
        # to zero over the following opt.niter_decay epochs
        def lambda_rule(epoch):
            return 1 - max(0, epoch - opt.niter) / max(1, float(opt.niter_decay))
        scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
    elif opt.lr_policy == 'step':
        scheduler = lr_scheduler.StepLR(optimizer,
                                        step_size=opt.lr_decay_iters,
                                        gamma=0.5)
    elif opt.lr_policy == 'plateau':
        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer,
                                                   mode='min',
                                                   factor=0.2,
                                                   threshold=0.01,
                                                   patience=5)
    elif opt.lr_policy == 'cosine':
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer,
                                                   T_max=opt.niter,
                                                   eta_min=0)
    else:
        raise NotImplementedError('lr policy [%s] is not implemented' % opt.lr_policy)
    return scheduler
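

# Usage sketch (illustrative): the option namespace below is hypothetical;
# only the attributes read above are assumed to exist on opt.
#
#   from argparse import Namespace
#   opt = Namespace(lr_policy='cosine', niter=100, niter_decay=100,
#                   lr_decay_iters=50)
#   optimizer = torch.optim.Adam(net.parameters(), lr=1e-4)
#   scheduler = get_scheduler(optimizer, opt)
#   scheduler.step()  # call once per epoch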


def init_weights(net, init_type='normal', init_gain=0.02):
    """Initialize conv/linear and batch-norm weights in place according to init_type."""
    def init_func(m):
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and (classname.find('Conv') != -1
                                     or classname.find('Linear') != -1):
            if init_type == 'normal':
                init.normal_(m.weight.data, 0.0, init_gain)
            elif init_type == 'xavier':
                init.xavier_normal_(m.weight.data, gain=init_gain)
            elif init_type == 'kaiming':
                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                init.orthogonal_(m.weight.data, gain=init_gain)
            elif init_type == 'uniform':
                init.uniform_(m.weight.data, b=init_gain)
            else:
                raise NotImplementedError('[%s] is not implemented' % init_type)
            if hasattr(m, 'bias') and m.bias is not None:
                init.constant_(m.bias.data, 0.0)
        elif classname.find('BatchNorm2d') != -1:
            # affine parameters: scale around 1, zero shift
            init.normal_(m.weight.data, 1.0, init_gain)
            init.constant_(m.bias.data, 0.0)

    print('initialize network with %s' % init_type)
    net.apply(init_func)


def init_net(net, init_type='default', init_gain=0.02, gpu_ids=[]):
    """Move the network to GPU(s) and optionally (re-)initialize its weights."""
    if len(gpu_ids) > 0:
        assert torch.cuda.is_available()
        net.to(gpu_ids[0])
        net = torch.nn.DataParallel(net, gpu_ids)
    if init_type != 'default' and init_type is not None:
        init_weights(net, init_type, init_gain=init_gain)
    return net
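

# Usage sketch (illustrative): 'net' stands for any nn.Module you have built.
#
#   net = init_net(net, init_type='kaiming', init_gain=0.02, gpu_ids=[0])
#
# init_type='default' (or None) skips init_weights and keeps the module's
# own initialization.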


'''
# ===================================
# Advanced nn.Sequential
# reforms nn.Sequentials and nn.Modules
# into a single nn.Sequential
# ===================================
'''


def seq(*args):
    """Flatten nn.Modules, lists/tuples, and OrderedDicts into one nn.Sequential."""
    if len(args) == 1:
        args = args[0]
    if isinstance(args, nn.Module):
        return args
    modules = OrderedDict()
    if isinstance(args, OrderedDict):
        for k, v in args.items():
            modules[k] = seq(v)
        return nn.Sequential(modules)
    assert isinstance(args, (list, tuple))
    return nn.Sequential(*[seq(i) for i in args])
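

# Usage sketch (illustrative): modules, tuples, and nested lists all flatten
# into nn.Sequential containers.
#
#   block = seq(nn.Conv2d(3, 64, 3, padding=1), nn.ReLU(inplace=True))
#   deep = seq([block, [nn.Conv2d(64, 64, 3, padding=1), nn.ReLU()]])
#   assert isinstance(deep, nn.Sequential)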


'''
# ===================================
# Useful blocks
# --------------------------------
# conv (+ normalization + relu)
# concat
# sum
# resblock (ResBlock)
# resdenseblock (ResidualDenseBlock_5C)
# resinresdenseblock (RRDB)
# ===================================
'''


def conv(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1,
         output_padding=0, dilation=1, groups=1, bias=True,
         padding_mode='zeros', mode='CBR'):
    """Build a small nn.Sequential from a mode string, one layer per character.

    C: Conv2d, X: depthwise Conv2d, T: ConvTranspose2d,
    B: BatchNorm2d, I/i: InstanceNorm2d (affine / plain),
    R/r: ReLU (inplace / not), L/l: LeakyReLU (inplace / not),
    S: Sigmoid, P: PReLU, 2/3/4: PixelShuffle (x2/x3/x4),
    U/u: nearest Upsample (x2/x3), M: MaxPool2d, A: AvgPool2d.
    """
    L = []
    for t in mode:
        if t == 'C':
            L.append(nn.Conv2d(in_channels=in_channels,
                               out_channels=out_channels,
                               kernel_size=kernel_size,
                               stride=stride,
                               padding=padding,
                               dilation=dilation,
                               groups=groups,
                               bias=bias,
                               padding_mode=padding_mode))
        elif t == 'X':
            # depthwise convolution: one group per channel
            assert in_channels == out_channels
            L.append(nn.Conv2d(in_channels=in_channels,
                               out_channels=out_channels,
                               kernel_size=kernel_size,
                               stride=stride,
                               padding=padding,
                               dilation=dilation,
                               groups=in_channels,
                               bias=bias,
                               padding_mode=padding_mode))
        elif t == 'T':
            L.append(nn.ConvTranspose2d(in_channels=in_channels,
                                        out_channels=out_channels,
                                        kernel_size=kernel_size,
                                        stride=stride,
                                        padding=padding,
                                        output_padding=output_padding,
                                        groups=groups,
                                        bias=bias,
                                        dilation=dilation,
                                        padding_mode=padding_mode))
        elif t == 'B':
            L.append(nn.BatchNorm2d(out_channels))
        elif t == 'I':
            L.append(nn.InstanceNorm2d(out_channels, affine=True))
        elif t == 'i':
            L.append(nn.InstanceNorm2d(out_channels))
        elif t == 'R':
            L.append(nn.ReLU(inplace=True))
        elif t == 'r':
            L.append(nn.ReLU(inplace=False))
        elif t == 'S':
            L.append(nn.Sigmoid())
        elif t == 'P':
            L.append(nn.PReLU())
        elif t == 'L':
            L.append(nn.LeakyReLU(negative_slope=1e-1, inplace=True))
        elif t == 'l':
            L.append(nn.LeakyReLU(negative_slope=1e-1, inplace=False))
        elif t == '2':
            L.append(nn.PixelShuffle(upscale_factor=2))
        elif t == '3':
            L.append(nn.PixelShuffle(upscale_factor=3))
        elif t == '4':
            L.append(nn.PixelShuffle(upscale_factor=4))
        elif t == 'U':
            L.append(nn.Upsample(scale_factor=2, mode='nearest'))
        elif t == 'u':
            L.append(nn.Upsample(scale_factor=3, mode='nearest'))
        elif t == 'M':
            L.append(nn.MaxPool2d(kernel_size=kernel_size,
                                  stride=stride,
                                  padding=0))
        elif t == 'A':
            L.append(nn.AvgPool2d(kernel_size=kernel_size,
                                  stride=stride,
                                  padding=0))
        else:
            raise NotImplementedError('Undefined type: {}'.format(t))
    return seq(*L)
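

# Usage sketch (illustrative): mode strings compose layers left to right.
#
#   m = conv(3, 64, mode='CBR')         # Conv2d -> BatchNorm2d -> ReLU
#   up = conv(64, 3 * 4, mode='C2')     # Conv2d then PixelShuffle(2): 3 ch, x2 size
#   y = m(torch.randn(1, 3, 32, 32))    # y.shape == (1, 64, 32, 32)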


class DWTForward(nn.Conv2d):
    """One-level Haar wavelet transform as a fixed stride-2 grouped convolution.

    Maps (N, C, H, W) to (N, 4C, H/2, W/2) using, per input channel, the four
    orthonormal 2x2 Haar filters (one average plus three details).
    """
    def __init__(self, in_channels=64):
        super(DWTForward, self).__init__(in_channels, in_channels*4, 2, 2,
                                         groups=in_channels, bias=False)
        weight = torch.tensor([[[[0.5, 0.5], [ 0.5, 0.5]]],
                               [[[0.5, 0.5], [-0.5, -0.5]]],
                               [[[0.5, -0.5], [ 0.5, -0.5]]],
                               [[[0.5, -0.5], [-0.5, 0.5]]]],
                              dtype=torch.get_default_dtype()
                              ).repeat(in_channels, 1, 1, 1)
        self.weight.data.copy_(weight)
        self.requires_grad_(False)


class DWTInverse(nn.ConvTranspose2d):
    """Inverse Haar transform: maps (N, 4C, H, W) back to (N, C, 2H, 2W).

    Uses the same orthonormal filters as DWTForward, so it exactly inverts it.
    """
    def __init__(self, in_channels=64):
        super(DWTInverse, self).__init__(in_channels, in_channels//4, 2, 2,
                                         groups=in_channels//4, bias=False)
        weight = torch.tensor([[[[0.5, 0.5], [ 0.5, 0.5]]],
                               [[[0.5, 0.5], [-0.5, -0.5]]],
                               [[[0.5, -0.5], [ 0.5, -0.5]]],
                               [[[0.5, -0.5], [-0.5, 0.5]]]],
                              dtype=torch.get_default_dtype()
                              ).repeat(in_channels//4, 1, 1, 1)
        self.weight.data.copy_(weight)
        self.requires_grad_(False)
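

# Round-trip sketch (illustrative): the Haar filters are orthonormal, so
# DWTInverse(DWTForward(x)) reconstructs x up to floating-point error.
#
#   x = torch.randn(1, 64, 32, 32)
#   sub = DWTForward(64)(x)            # (1, 256, 16, 16)
#   rec = DWTInverse(256)(sub)         # (1, 64, 32, 32)
#   assert torch.allclose(rec, x, atol=1e-6)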


class CALayer(nn.Module):
    """Channel attention layer (squeeze-and-excitation style)."""
    def __init__(self, channel=64, reduction=16):
        super(CALayer, self).__init__()
        # squeeze: global average pooling to one value per channel
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # excitation: bottleneck MLP producing per-channel weights in (0, 1)
        self.conv_du = nn.Sequential(
            nn.Conv2d(channel, channel//reduction, 1, padding=0, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(channel//reduction, channel, 1, padding=0, bias=True),
            nn.Sigmoid()
        )

    def forward(self, x):
        y = self.avg_pool(x)
        y = self.conv_du(y)
        return x * y
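

# Usage sketch (illustrative): channel attention rescales each feature map.
#
#   ca = CALayer(channel=64, reduction=16)
#   y = ca(torch.randn(2, 64, 16, 16))   # same shape, channel-wise rescaled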


class ResBlock(nn.Module):
    """Residual block: x + conv_stack(x), with the stack built from a mode string."""
    def __init__(self, in_channels=64, out_channels=64, kernel_size=3, stride=1,
                 padding=1, bias=True, mode='CRC'):
        super(ResBlock, self).__init__()
        assert in_channels == out_channels
        # if the stack starts with an activation, make it non-inplace so the
        # residual input x is not overwritten
        if mode[0] in ['R', 'L']:
            mode = mode[0].lower() + mode[1:]

        self.res = conv(in_channels, out_channels, kernel_size,
                        stride, padding=padding, bias=bias, mode=mode)

    def forward(self, x):
        res = self.res(x)
        return x + res
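

# Usage sketch (illustrative):
#
#   rb = ResBlock(64, 64, mode='CRC')    # conv-ReLU-conv with identity skip
#   y = rb(torch.randn(1, 64, 16, 16))   # shape preserved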


class RCABlock(nn.Module):
    """Residual channel attention block: conv stack + CALayer inside the skip."""
    def __init__(self, in_channels=64, out_channels=64, kernel_size=3, stride=1,
                 padding=1, bias=True, mode='CRC', reduction=16):
        super(RCABlock, self).__init__()
        assert in_channels == out_channels
        if mode[0] in ['R', 'L']:
            mode = mode[0].lower() + mode[1:]

        self.res = conv(in_channels, out_channels, kernel_size,
                        stride, padding, bias=bias, mode=mode)
        self.ca = CALayer(out_channels, reduction)

    def forward(self, x):
        res = self.res(x)
        res = self.ca(res)
        return res + x


class RCAGroup(nn.Module):
    """Residual group: nb RCABlocks plus a tail conv, wrapped in a long skip."""
    def __init__(self, in_channels=64, out_channels=64, kernel_size=3, stride=1,
                 padding=1, bias=True, mode='CRC', reduction=16, nb=12):
        super(RCAGroup, self).__init__()
        assert in_channels == out_channels
        if mode[0] in ['R', 'L']:
            mode = mode[0].lower() + mode[1:]

        RG = [RCABlock(in_channels, out_channels, kernel_size, stride, padding,
                       bias, mode, reduction) for _ in range(nb)]
        RG.append(conv(out_channels, out_channels, mode='C'))
        self.rg = nn.Sequential(*RG)

    def forward(self, x):
        res = self.rg(x)
        return res + x
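

# Usage sketch (illustrative):
#
#   rg = RCAGroup(64, 64, nb=12)         # 12 RCABlocks + tail conv, long skip
#   y = rg(torch.randn(1, 64, 16, 16))   # shape preserved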