roubaofeipi's picture
Upload 100 files
5231633 verified
raw
history blame
No virus
2.3 kB
import torch
import torchvision
from torch import nn
from PIL import Image
import numpy as np
import os
# MICRO RESNET
class ResBlock(nn.Module):
    """Residual block of two reflection-padded 3x3 convolutions plus a skip connection.

    ``self.resblock`` is laid out as
    pad -> conv3x3 -> InstanceNorm(affine) -> ReLU -> pad -> conv3x3 -> InstanceNorm(affine).
    The ``nn.Sequential`` indices of these layers determine the parameter names
    in ``state_dict()``, so their order is significant.

    Args:
        channels: number of input and output channels. Both are equal, so the
            skip connection can add the input directly.
    """

    def __init__(self, channels):
        super().__init__()

        def conv_norm():
            # A 1-pixel reflection pad with a 3x3, stride-1 kernel keeps H and W
            # unchanged, which the residual addition in forward() relies on.
            return [
                nn.ReflectionPad2d(1),
                nn.Conv2d(channels, channels, kernel_size=3),
                nn.InstanceNorm2d(channels, affine=True),
            ]

        first_half = conv_norm()
        second_half = conv_norm()
        self.resblock = nn.Sequential(*first_half, nn.ReLU(), *second_half)

    def forward(self, x):
        """Return the residual branch applied to ``x``, plus ``x`` itself."""
        branch = self.resblock(x)
        return branch + x
class Upsample2d(nn.Module):
    """Parameter-free module that rescales its input with nearest-neighbour interpolation.

    Args:
        scale_factor: multiplier forwarded unchanged to
            ``torch.nn.functional.interpolate``. For a 4-D (N, C, H, W) input
            it scales H and W.
    """

    def __init__(self, scale_factor):
        super().__init__()
        self.scale_factor = scale_factor
        # forward() calls the interpolation function through this instance attribute.
        self.interp = nn.functional.interpolate

    def forward(self, x):
        """Return ``x`` resized by ``self.scale_factor`` using mode ``'nearest'``."""
        resized = self.interp(
            x,
            scale_factor=self.scale_factor,
            mode='nearest',
        )
        return resized
class MicroResNet(nn.Module):
    """Small convolutional encoder/head mapping a 3-channel input to a 1-channel map in (0, 1).

    Submodules. Their attribute names and ``nn.Sequential`` indices determine
    the ``state_dict()`` keys, so their order is significant.

    * ``downsampler``: three reflection-padded strided convolutions
      (9x9/stride 4, 3x3/stride 2, 3x3/stride 2) taking 3 -> 8 -> 16 -> 32
      channels. Each is followed by an affine InstanceNorm and a ReLU, giving
      a total stride of 16.
    * ``residual``: ResBlock(32), then a grouped 1x1 conv (32 -> 64,
      groups=32, no bias), then ResBlock(64).
    * ``segmentator``: a 3x3 conv 64 -> 16 with InstanceNorm and ReLU,
      2x nearest upsampling, a 9x9 conv 16 -> 1, and a sigmoid.

    The encoder downsamples by 16 and the head upsamples by only 2, so the
    output is about 1/8 of the input's spatial size (e.g. 64x64 -> 8x8).
    NOTE(review): confirm callers expect this reduced-resolution output.
    """

    def __init__(self):
        super().__init__()

        def down(c_in, c_out, kernel, stride):
            # Pad by kernel // 2 per side (4 for 9x9, 1 for 3x3) before each strided conv.
            return [
                nn.ReflectionPad2d(kernel // 2),
                nn.Conv2d(c_in, c_out, kernel_size=kernel, stride=stride),
                nn.InstanceNorm2d(c_out, affine=True),
                nn.ReLU(),
            ]

        self.downsampler = nn.Sequential(
            *down(3, 8, 9, 4),
            *down(8, 16, 3, 2),
            *down(16, 32, 3, 2),
        )

        self.residual = nn.Sequential(
            ResBlock(32),
            # groups=32 with 32 inputs and 64 outputs: each input channel
            # independently produces two output channels.
            nn.Conv2d(32, 64, kernel_size=1, bias=False, groups=32),
            ResBlock(64),
        )

        head = [
            nn.ReflectionPad2d(1),
            nn.Conv2d(64, 16, kernel_size=3),
            nn.InstanceNorm2d(16, affine=True),
            nn.ReLU(),
            Upsample2d(scale_factor=2),
            nn.ReflectionPad2d(4),
            nn.Conv2d(16, 1, kernel_size=9),
            nn.Sigmoid(),
        ]
        self.segmentator = nn.Sequential(*head)

    def forward(self, x):
        """Apply the downsampler, residual stage and segmentator in turn; return the sigmoid map."""
        for stage in (self.downsampler, self.residual, self.segmentator):
            x = stage(x)
        return x