from torch import nn
from torchvision.models.resnet import ResNet, Bottleneck


class ResNetEncoder(ResNet):
    """
    ResNetEncoder inherits from torchvision's official ResNet. It is modified to
    use dilation on the last block to maintain output stride 16, and has the
    global average pooling layer and the fully connected classification layer
    removed. The forward method additionally returns the feature maps at all
    resolutions for the decoder's use.
    """

    # Number of Bottleneck blocks in each of the four stages, per variant.
    layers = {
        'resnet50': [3, 4, 6, 3],
        'resnet101': [3, 4, 23, 3],
        'resnet152': [3, 8, 36, 3],
    }

    def __init__(self, in_channels, variant='resnet101', norm_layer=None):
        """
        Build the encoder backbone.

        Args:
            in_channels: number of channels of the input tensor. When it is not
                3, the stem convolution is replaced to accept it.
            variant: key into ``ResNetEncoder.layers`` selecting the depth.
            norm_layer: normalization layer constructor forwarded unchanged to
                torchvision's ResNet (``None`` lets torchvision use its default).

        Raises:
            KeyError: if ``variant`` is not a key of ``ResNetEncoder.layers``.
        """
        if variant not in self.layers:
            raise KeyError(
                f"Unsupported ResNet variant {variant!r}; "
                f"expected one of {sorted(self.layers)}")

        super().__init__(
            block=Bottleneck,
            layers=self.layers[variant],
            # Dilate layer4 instead of striding it, so the deepest features
            # stay at output stride 16 rather than 32.
            replace_stride_with_dilation=[False, False, True],
            norm_layer=norm_layer)

        # Replace the stem conv if in_channels doesn't match the RGB default.
        if in_channels != 3:
            self.conv1 = nn.Conv2d(
                in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
            # ResNet.__init__ already ran its weight-init pass over every conv
            # before this replacement existed; apply the same scheme here so
            # the new stem is initialized consistently with the rest.
            nn.init.kaiming_normal_(
                self.conv1.weight, mode='fan_out', nonlinearity='relu')

        # The classification head (global average pool + fully connected
        # layer) is not used by the encoder.
        del self.avgpool
        del self.fc

    def forward(self, x):
        """
        Run the backbone and return features at every resolution.

        Returns:
            Tuple ``(x4, x3, x2, x1, x0)``, deepest first:
            ``x4`` layer4 output (~1/16 of the input size),
            ``x3`` layer2 output (~1/8),
            ``x2`` layer1 output (~1/4),
            ``x1`` stem output after conv/bn/relu, before max-pool (~1/2),
            ``x0`` the unmodified input (1/1).
        """
        x0 = x  # 1/1
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x1 = x  # 1/2
        x = self.maxpool(x)
        x = self.layer1(x)
        x2 = x  # 1/4
        x = self.layer2(x)
        x3 = x  # 1/8
        x = self.layer3(x)
        # layer4 is dilated rather than strided, so it stays at 1/16.
        x = self.layer4(x)
        x4 = x  # 1/16
        return x4, x3, x2, x1, x0