import torch
import torch.nn as nn
import torch.nn.functional as F


class Decoder(nn.Module):
    """Fuse a coarse feature map with progressively finer skip features.

    Each of the four stages bilinearly resizes the running tensor to the
    spatial size of the next skip tensor, concatenates the two along the
    channel axis and applies a 3x3 convolution.  The first three stages
    follow the convolution with batch normalisation and ReLU; the last
    stage is a plain biased convolution with no normalisation or
    activation, so its raw output is returned.

    Args:
        channels: Indexable of at least five ints.  ``channels[0]`` is the
            channel count expected for ``x4``; ``channels[1]`` to
            ``channels[3]`` are the output widths of the three intermediate
            stages; ``channels[4]`` is the channel count of the result.
        feature_channels: Indexable of at least four ints giving the channel
            counts expected for the skip inputs ``x3``, ``x2``, ``x1`` and
            ``x0``, in that order.

    Inputs of ``forward`` (resolutions are nominal; every resize targets
    the actual spatial size of the skip tensor, so exact power-of-two
    ratios are not required):
        x4: (B, channels[0], H/16, W/16) coarsest feature map.
        x3: (B, feature_channels[0], H/8, W/8) skip features.
        x2: (B, feature_channels[1], H/4, W/4) skip features.
        x1: (B, feature_channels[2], H/2, W/2) skip features.
        x0: (B, feature_channels[3], H, W) skip features.

    Output:
        (B, channels[4], H, W) tensor with the spatial size of ``x0``.
    """

    def __init__(self, channels, feature_channels) -> None:
        super().__init__()
        # The first three convolutions omit the bias because the following
        # BatchNorm layer supplies its own learnable shift.
        self.conv1 = nn.Conv2d(
            feature_channels[0] + channels[0], channels[1], 3,
            padding=1, bias=False,
        )
        self.bn1 = nn.BatchNorm2d(channels[1])
        self.conv2 = nn.Conv2d(
            feature_channels[1] + channels[1], channels[2], 3,
            padding=1, bias=False,
        )
        self.bn2 = nn.BatchNorm2d(channels[2])
        self.conv3 = nn.Conv2d(
            feature_channels[2] + channels[2], channels[3], 3,
            padding=1, bias=False,
        )
        self.bn3 = nn.BatchNorm2d(channels[3])
        # Final projection: keeps its bias, no normalisation/activation.
        self.conv4 = nn.Conv2d(
            feature_channels[3] + channels[3], channels[4], 3, padding=1,
        )
        # In-place ReLU, shared by the three normalised stages.
        self.relu = nn.ReLU(True)

    @staticmethod
    def _upsample_and_concat(x: torch.Tensor, skip: torch.Tensor) -> torch.Tensor:
        """Resize ``x`` to ``skip``'s spatial size and concatenate on dim 1."""
        x = F.interpolate(
            x, size=skip.shape[2:], mode='bilinear', align_corners=False
        )
        # Upsampled features come first, skip features second; the
        # convolution weights depend on this channel order.
        return torch.cat([x, skip], dim=1)

    def forward(
        self,
        x4: torch.Tensor,
        x3: torch.Tensor,
        x2: torch.Tensor,
        x1: torch.Tensor,
        x0: torch.Tensor,
    ) -> torch.Tensor:
        """Decode ``x4`` using the skip tensors ``x3`` .. ``x0``.

        Returns the output of the final convolution, whose spatial size
        equals that of ``x0``.
        """
        x = self._upsample_and_concat(x4, x3)
        x = self.relu(self.bn1(self.conv1(x)))

        x = self._upsample_and_concat(x, x2)
        x = self.relu(self.bn2(self.conv2(x)))

        x = self._upsample_and_concat(x, x1)
        x = self.relu(self.bn3(self.conv3(x)))

        x = self._upsample_and_concat(x, x0)
        return self.conv4(x)