import torch
from torch import nn
import torch.nn.functional as F
class ConvNormLReLU(nn.Sequential):
    """Conv2d + GroupNorm + LeakyReLU, with explicit padding applied before the conv."""

    def __init__(self, in_ch, out_ch, kernel_size=3, stride=1, padding=1, pad_mode="reflect", groups=1, bias=False):
        # Padding is done by a dedicated layer so that reflect/replicate modes
        # can be used; the conv itself therefore runs with padding=0.
        pad_layer = {
            "zero": nn.ZeroPad2d,
            "same": nn.ReplicationPad2d,
            "reflect": nn.ReflectionPad2d,
        }
        if pad_mode not in pad_layer:
            raise NotImplementedError(f"Unsupported pad_mode: {pad_mode!r}")
        super(ConvNormLReLU, self).__init__(
            pad_layer[pad_mode](padding),
            nn.Conv2d(in_ch, out_ch, kernel_size=kernel_size, stride=stride, padding=0, groups=groups, bias=bias),
            # GroupNorm with a single group normalizes over (C, H, W) per sample,
            # making the statistics independent of batch size.
            nn.GroupNorm(num_groups=1, num_channels=out_ch, affine=True),
            nn.LeakyReLU(0.2, inplace=True)
        )
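
# A minimal shape sketch of ConvNormLReLU (illustrative only; the 64x64 input
# size is an assumption, not something the layer requires):
#
#   block = ConvNormLReLU(3, 32)                      # k=3, s=1, reflect pad 1
#   y = block(torch.randn(1, 3, 64, 64))              # -> torch.Size([1, 32, 64, 64])
#
#   down = ConvNormLReLU(32, 64, stride=2, padding=(0, 1, 0, 1))
#   y = down(torch.randn(1, 32, 64, 64))              # -> torch.Size([1, 64, 32, 32])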
class InvertedResBlock(nn.Module):
    """MobileNetV2-style inverted residual: expand (1x1) -> depthwise (3x3) -> project (1x1)."""

    def __init__(self, in_ch, out_ch, expansion_ratio=2):
        super(InvertedResBlock, self).__init__()
        # The skip connection is only valid when input and output channels match.
        self.use_res_connect = in_ch == out_ch
        bottleneck = int(round(in_ch * expansion_ratio))
        layers = []
        if expansion_ratio != 1:
            # pointwise expansion
            layers.append(ConvNormLReLU(in_ch, bottleneck, kernel_size=1, padding=0))
        # depthwise 3x3 (groups == channels)
        layers.append(ConvNormLReLU(bottleneck, bottleneck, groups=bottleneck, bias=True))
        # pointwise linear projection (norm only, no activation)
        layers.append(nn.Conv2d(bottleneck, out_ch, kernel_size=1, padding=0, bias=False))
        layers.append(nn.GroupNorm(num_groups=1, num_channels=out_ch, affine=True))
        self.layers = nn.Sequential(*layers)

    def forward(self, input):
        out = self.layers(input)
        if self.use_res_connect:
            out = input + out
        return out
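
# Residual behaviour sketch (illustrative shapes): InvertedResBlock(128, 256)
# has no skip connection because the channel counts differ, while
# InvertedResBlock(256, 256) adds its input back to the projected output:
#
#   x = torch.randn(1, 256, 32, 32)
#   InvertedResBlock(256, 256)(x).shape               # -> torch.Size([1, 256, 32, 32])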
class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        # Encoder: two stride-2 stages take the input from full to 1/4 resolution.
        self.block_a = nn.Sequential(
            ConvNormLReLU(3, 32, kernel_size=7, padding=3),
            ConvNormLReLU(32, 64, stride=2, padding=(0, 1, 0, 1)),
            ConvNormLReLU(64, 64)
        )
        self.block_b = nn.Sequential(
            ConvNormLReLU(64, 128, stride=2, padding=(0, 1, 0, 1)),
            ConvNormLReLU(128, 128)
        )
        # Bottleneck: four inverted residual blocks at 1/4 resolution.
        self.block_c = nn.Sequential(
            ConvNormLReLU(128, 128),
            InvertedResBlock(128, 256, 2),
            InvertedResBlock(256, 256, 2),
            InvertedResBlock(256, 256, 2),
            InvertedResBlock(256, 256, 2),
            ConvNormLReLU(256, 128),
        )
        # Decoder: upsampling happens in forward() via F.interpolate.
        self.block_d = nn.Sequential(
            ConvNormLReLU(128, 128),
            ConvNormLReLU(128, 128)
        )
        self.block_e = nn.Sequential(
            ConvNormLReLU(128, 64),
            ConvNormLReLU(64, 64),
            ConvNormLReLU(64, 32, kernel_size=7, padding=3)
        )
        # 1x1 conv to RGB; tanh maps outputs to [-1, 1].
        self.out_layer = nn.Sequential(
            nn.Conv2d(32, 3, kernel_size=1, stride=1, padding=0, bias=False),
            nn.Tanh()
        )
    def forward(self, input, align_corners=True):
        out = self.block_a(input)
        half_size = out.size()[-2:]  # spatial size at 1/2 resolution
        out = self.block_b(out)
        out = self.block_c(out)

        # Upsample back to 1/2 resolution. Interpolating to the recorded size
        # (rather than by a fixed factor) keeps odd input dimensions consistent.
        if align_corners:
            out = F.interpolate(out, half_size, mode="bilinear", align_corners=True)
        else:
            out = F.interpolate(out, scale_factor=2, mode="bilinear", align_corners=False)
        out = self.block_d(out)

        # Upsample back to the original input resolution.
        if align_corners:
            out = F.interpolate(out, input.size()[-2:], mode="bilinear", align_corners=True)
        else:
            out = F.interpolate(out, scale_factor=2, mode="bilinear", align_corners=False)
        out = self.block_e(out)

        out = self.out_layer(out)
        return out
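
# A minimal smoke test, runnable on CPU. The 256x256 input size is an
# assumption for illustration; the size-targeted (align_corners=True) path
# restores the exact input resolution, while the scale_factor path assumes
# even spatial dimensions.
if __name__ == "__main__":
    model = Generator()
    model.eval()
    with torch.no_grad():
        x = torch.randn(1, 3, 256, 256)    # dummy image batch in NCHW layout
        y = model(x)
    print(y.shape)                          # expected: torch.Size([1, 3, 256, 256])
    print(float(y.min()), float(y.max()))   # tanh output lies in [-1, 1]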