import torch | |
import torch.nn as nn | |
import numpy as np | |
class LatentUpscaler(nn.Module):
    """Convolutional upscaler that operates directly on latent tensors.

    The network is one ``nn.Sequential`` assembled from three stages:

    * head: conv (``chan`` -> ``size``) + ReLU, nearest-neighbour upsample by
      ``fac``, then another ReLU;
    * core: ``depth`` repetitions of conv (``size`` -> ``size``) + ReLU;
    * tail: conv (``size`` -> ``chan``).

    Args:
        fac: Spatial scale factor passed to ``nn.Upsample``.
        depth: Number of conv + ReLU pairs in the core stage.
        chan: Number of input/output channels. Defaults to 4.
        size: Number of hidden channels used by the internal convolutions.
            Defaults to 64.
        krn: Convolution kernel size. Defaults to 3.
        pad: Convolution padding. Defaults to ``krn // 2``, which keeps the
            spatial size unchanged through each conv when ``krn`` is odd
            (and equals 1 for the default ``krn=3``).

    With the default keyword values the layer layout, and therefore the
    ``state_dict`` keys, is identical to the original fixed configuration.
    """

    def __init__(self, fac, depth=16, *, chan=4, size=64, krn=3, pad=None):
        super().__init__()
        self.size = size  # hidden channel count of the internal convs
        self.chan = chan  # input/output channel count
        self.depth = depth  # number of conv + ReLU pairs in the core
        self.fac = fac  # spatial scale factor for nn.Upsample
        self.krn = krn  # convolution kernel size
        # krn // 2 is "same" padding for odd kernels; it is 1 for krn=3,
        # matching the previously hard-coded value.
        self.pad = krn // 2 if pad is None else pad
        self.sequential = nn.Sequential(
            *self.head(),
            *self.core(),
            *self.tail(),
        )

    def head(self):
        """Return the input stage: project to ``size`` channels, then upsample."""
        return [
            nn.Conv2d(self.chan, self.size, kernel_size=self.krn, padding=self.pad),
            nn.ReLU(),
            nn.Upsample(scale_factor=self.fac, mode="nearest"),
            # NOTE: this ReLU is an identity here (nearest upsampling of
            # non-negative ReLU output stays non-negative). It is kept because
            # it occupies an index in nn.Sequential; removing it would shift
            # the indices of every later layer and break loading of existing
            # state_dicts.
            nn.ReLU(),
        ]

    def core(self):
        """Return ``depth`` conv + ReLU pairs operating at ``size`` channels."""
        layers = []
        for _ in range(self.depth):
            layers += [
                nn.Conv2d(self.size, self.size, kernel_size=self.krn, padding=self.pad),
                nn.ReLU(),
            ]
        return layers

    def tail(self):
        """Return the output stage: project back to ``chan`` channels."""
        return [
            nn.Conv2d(self.size, self.chan, kernel_size=self.krn, padding=self.pad),
        ]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Upscale ``x``.

        Args:
            x: Batched input of shape ``(N, chan, H, W)``.

        Returns:
            Tensor with ``chan`` channels whose spatial size is scaled by
            ``fac`` (``nn.Upsample`` rounds the result down).
        """
        return self.sequential(x)