# VideoMatting/model/decoder.py
# Last commit: "init" (854728f) by Fazhong Liu
import torch
import torch.nn as nn
import torch.nn.functional as F
class Decoder(nn.Module):
    """Fuse a coarse feature map with progressively finer skip features.

    Starting from ``x4``, each stage bilinearly resizes the running tensor to
    the spatial size of the next skip tensor, concatenates the two along the
    channel axis and applies a 3x3 convolution. The first three stages are
    followed by batch normalisation and ReLU; the last stage returns the raw
    convolution output.

    Args:
        channels: Sequence of five ints. ``channels[0]`` is the channel count
            of ``x4``; ``channels[1:4]`` are the output widths of the three
            intermediate stages; ``channels[4]`` is the output channel count.
        feature_channels: Sequence of four ints giving the channel counts of
            the skip tensors ``x3``, ``x2``, ``x1`` and ``x0`` respectively.

    Inputs (nominal resolutions; the actual target size of every stage is
    taken from the corresponding skip tensor):
        x4: (B, channels[0], H/16, W/16) coarsest feature map.
        x3: (B, feature_channels[0], H/8, W/8) skip features.
        x2: (B, feature_channels[1], H/4, W/4) skip features.
        x1: (B, feature_channels[2], H/2, W/2) skip features.
        x0: (B, feature_channels[3], H, W) skip features at full resolution.

    Output:
        (B, channels[4], H, W) tensor with the spatial size of ``x0``.
    """

    def __init__(self, channels, feature_channels):
        super().__init__()
        # Attribute names and registration order are kept stable so that
        # state_dict keys and seeded initialisation stay unchanged.
        # The first three convolutions drop their bias because the batch norm
        # that follows each of them supplies its own learnable shift.
        self.conv1 = nn.Conv2d(feature_channels[0] + channels[0], channels[1], 3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels[1])
        self.conv2 = nn.Conv2d(feature_channels[1] + channels[1], channels[2], 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(channels[2])
        self.conv3 = nn.Conv2d(feature_channels[2] + channels[2], channels[3], 3, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(channels[3])
        # Final projection: keeps its bias, no normalisation or activation.
        self.conv4 = nn.Conv2d(feature_channels[3] + channels[3], channels[4], 3, padding=1)
        # In-place ReLU; it is only ever applied to fresh batch-norm outputs.
        self.relu = nn.ReLU(True)

    def _upsample_and_concat(self, x: torch.Tensor, skip: torch.Tensor) -> torch.Tensor:
        """Resize ``x`` to the spatial size of ``skip`` and concatenate on channels."""
        x = F.interpolate(x, size=skip.shape[2:], mode='bilinear', align_corners=False)
        return torch.cat([x, skip], dim=1)

    def forward(self, x4, x3, x2, x1, x0):
        """Run the four fusion stages and return the final convolution output."""
        x = self._upsample_and_concat(x4, x3)
        x = self.relu(self.bn1(self.conv1(x)))

        x = self._upsample_and_concat(x, x2)
        x = self.relu(self.bn2(self.conv2(x)))

        x = self._upsample_and_concat(x, x1)
        x = self.relu(self.bn3(self.conv3(x)))

        x = self._upsample_and_concat(x, x0)
        return self.conv4(x)