File size: 1,068 Bytes
c336648 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 |
import torch
import torch.nn as nn
import numpy as np
class LatentUpscaler(nn.Module):
def head(self):
return [
nn.Conv2d(self.chan, self.size, kernel_size=self.krn, padding=self.pad),
nn.ReLU(),
nn.Upsample(scale_factor=self.fac, mode="nearest"),
nn.ReLU(),
]
def core(self):
layers = []
for _ in range(self.depth):
layers += [
nn.Conv2d(self.size, self.size, kernel_size=self.krn, padding=self.pad),
nn.ReLU(),
]
return layers
def tail(self):
return [
nn.Conv2d(self.size, self.chan, kernel_size=self.krn, padding=self.pad),
]
def __init__(self, fac, depth=16):
super().__init__()
self.size = 64 # Conv2d size
self.chan = 4 # in/out channels
self.depth = depth # no. of layers
self.fac = fac # scale factor
self.krn = 3 # kernel size
self.pad = 1 # padding
self.sequential = nn.Sequential(
*self.head(),
*self.core(),
*self.tail(),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.sequential(x)
|