import torch
import torch.nn as nn


def _conv_stack(ch_in: int, ch_out: int) -> nn.Sequential:
    """Build the three-convolution "long" branch shared by both block types.

    Layer order (and therefore the ``nn.Sequential`` indices that end up in
    ``state_dict`` keys) is: conv, SiLU, conv, SiLU, conv, dropout.  Every
    convolution is 3x3 with stride 1 and padding 1, so the spatial size of
    the input is preserved; only the first convolution changes the channel
    count.
    """
    return nn.Sequential(
        nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1),
        nn.SiLU(),
        nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1),
        nn.SiLU(),
        nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1),
        nn.Dropout(0.1),
    )


class ResBlock(nn.Module):
    """Residual block that keeps the channel count unchanged.

    The input is batch-normalized, passed through a three-convolution
    branch, added back to the normalized input and passed through ReLU.

    Args:
        ch: Number of input (and output) channels.
    """

    def __init__(self, ch: int) -> None:
        super().__init__()
        self.join = nn.ReLU()
        self.norm = nn.BatchNorm2d(ch)
        self.long = _conv_stack(ch, ch)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``ReLU(long(norm(x)) + norm(x))``."""
        # NOTE: the skip connection carries the *normalized* tensor, not the
        # raw input. Kept as-is: changing it would alter the block's output
        # (presumably existing checkpoints rely on this form - confirm
        # before changing).
        x = self.norm(x)
        return self.join(self.long(x) + x)


class ExtractBlock(nn.Module):
    """Residual-style block that maps ``ch_in`` channels to ``ch_out``.

    Because the channel count changes, the shortcut is a single 3x3
    convolution rather than an identity; its output is summed with the
    three-convolution branch and passed through ReLU.

    Args:
        ch_in: Number of input channels.
        ch_out: Number of output channels.
    """

    def __init__(self, ch_in: int, ch_out: int) -> None:
        super().__init__()
        self.join = nn.ReLU()
        # The shortcut is created before the long branch, matching the
        # original parameter-creation order (keeps seeded init reproducible).
        self.short = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1)
        self.long = _conv_stack(ch_in, ch_out)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``ReLU(long(x) + short(x))``."""
        return self.join(self.long(x) + self.short(x))


class InterposerModel(nn.Module):
    """Main neural network: head -> (upsample + residual core) -> tail.

    Args:
        ch_in: Channels of the input tensor.
        ch_out: Channels of the output tensor.
        ch_mid: Channel width used throughout the core.
        scale: Spatial scale factor passed to the nearest-neighbour
            ``nn.Upsample`` at the start of the core.
        blocks: Number of ``ResBlock`` modules in the core.
    """

    def __init__(
        self,
        ch_in: int = 4,
        ch_out: int = 4,
        ch_mid: int = 64,
        scale: float = 1.0,
        blocks: int = 12,
    ) -> None:
        super().__init__()
        self.ch_in = ch_in
        self.ch_out = ch_out
        self.ch_mid = ch_mid
        self.blocks = blocks
        self.scale = scale

        self.head = ExtractBlock(self.ch_in, self.ch_mid)
        # Sequential layout: index 0 is the upsampler, indices 1..blocks are
        # the residual blocks, followed by a final BatchNorm2d and SiLU.
        self.core = nn.Sequential(
            nn.Upsample(scale_factor=self.scale, mode="nearest"),
            *[ResBlock(self.ch_mid) for _ in range(blocks)],
            nn.BatchNorm2d(self.ch_mid),
            nn.SiLU(),
        )
        self.tail = nn.Conv2d(self.ch_mid, self.ch_out, kernel_size=3, stride=1, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through head, core and tail and return the result."""
        y = self.head(x)
        z = self.core(y)
        return self.tail(z)