import os

import numpy as np
import torch
import torchvision
from PIL import Image
from torch import nn


# MICRO RESNET


class ResBlock(nn.Module):
    """Residual block: two reflection-padded 3x3 convolutions plus an identity skip.

    Each convolution is followed by an affine ``InstanceNorm2d``; a ReLU sits
    between the two conv/norm pairs. No activation is applied after the
    addition, so ``forward`` returns ``resblock(x) + x``.

    Args:
        channels: Number of input and output channels. The block preserves
            both the channel count and the spatial size.
    """

    def __init__(self, channels: int):
        super().__init__()
        # Reflection padding of 1 makes each 3x3 convolution size-preserving,
        # which the identity addition in ``forward`` relies on.
        # NOTE: the order of layers defines the state_dict keys
        # (``resblock.1.weight`` ...); keep it stable so checkpoints still load.
        self.resblock = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(channels, channels, kernel_size=3),
            nn.InstanceNorm2d(channels, affine=True),
            nn.ReLU(),
            nn.ReflectionPad2d(1),
            nn.Conv2d(channels, channels, kernel_size=3),
            nn.InstanceNorm2d(channels, affine=True),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``resblock(x) + x``."""
        return self.resblock(x) + x


class Upsample2d(nn.Module):
    """Nearest-neighbour upsampling wrapped as a module so it can sit in ``nn.Sequential``.

    Args:
        scale_factor: Spatial multiplier forwarded to
            ``torch.nn.functional.interpolate``.
    """

    def __init__(self, scale_factor):
        super().__init__()
        self.interp = nn.functional.interpolate
        self.scale_factor = scale_factor

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Upsample ``x`` by ``self.scale_factor`` using nearest-neighbour interpolation."""
        return self.interp(x, scale_factor=self.scale_factor, mode='nearest')

    def extra_repr(self) -> str:
        # Makes ``print(model)`` show the scale factor instead of a bare ``Upsample2d()``.
        return f"scale_factor={self.scale_factor}"


class MicroResNet(nn.Module):
    """Small encoder / residual / decoder network that ends in a sigmoid.

    Pipeline:
      * ``downsampler``: three ReflectionPad -> Conv2d -> InstanceNorm -> ReLU
        stages with strides 4, 2 and 2 (total stride 16), widening the
        channels to 8, 16 and 32.
      * ``residual``: ``ResBlock(32)``, a grouped 1x1 convolution that expands
        32 -> 64 channels (``groups=32``, so each input channel feeds two
        output channels), then ``ResBlock(64)``.
      * ``segmentator``: 3x3 conv to 16 channels, one 2x nearest-neighbour
        upsample, then a 9x9 conv to ``out_channels`` followed by a sigmoid.

    For an input with spatial size (H, W) the output spatial size is
    (2 * ceil(H / 16), 2 * ceil(W / 16)), i.e. roughly 1/8 of the input
    resolution. Reflection padding requires each padded dimension to be larger
    than the pad width; with these pads that means H and W must be at least
    33, otherwise a padding layer raises.

    Args:
        in_channels: Channels of the input tensor (default 3, as before).
        out_channels: Channels of the output map (default 1, as before).
    """

    def __init__(self, in_channels: int = 3, out_channels: int = 1):
        super().__init__()
        # NOTE: layer order inside each Sequential defines the state_dict keys
        # (e.g. ``downsampler.1.weight``); keep it stable so checkpoints load.
        self.downsampler = nn.Sequential(
            nn.ReflectionPad2d(4),
            nn.Conv2d(in_channels, 8, kernel_size=9, stride=4),
            nn.InstanceNorm2d(8, affine=True),
            nn.ReLU(),
            nn.ReflectionPad2d(1),
            nn.Conv2d(8, 16, kernel_size=3, stride=2),
            nn.InstanceNorm2d(16, affine=True),
            nn.ReLU(),
            nn.ReflectionPad2d(1),
            nn.Conv2d(16, 32, kernel_size=3, stride=2),
            nn.InstanceNorm2d(32, affine=True),
            nn.ReLU(),
        )
        self.residual = nn.Sequential(
            ResBlock(32),
            # groups=32 with 64 outputs: a depthwise 1x1 conv with channel
            # multiplier 2 (two output channels per input channel), no bias.
            nn.Conv2d(32, 64, kernel_size=1, bias=False, groups=32),
            ResBlock(64),
        )
        self.segmentator = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(64, 16, kernel_size=3),
            nn.InstanceNorm2d(16, affine=True),
            nn.ReLU(),
            # Only a single 2x upsample against the encoder's stride of 16,
            # so the output stays at ~1/8 of the input resolution.
            Upsample2d(scale_factor=2),
            nn.ReflectionPad2d(4),
            nn.Conv2d(16, out_channels, kernel_size=9),
            nn.Sigmoid(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the network on a batch, typically shaped (N, in_channels, H, W).

        Returns:
            Sigmoid activations with ``out_channels`` channels and spatial size
            (2 * ceil(H / 16), 2 * ceil(W / 16)).
        """
        features = self.downsampler(x)
        features = self.residual(features)
        return self.segmentator(features)