from torch import nn
from torchvision.models import MobileNetV2


class MobileNetV2Encoder(MobileNetV2):
    """
    MobileNetV2Encoder inherits from torchvision's official MobileNetV2. It is
    modified to use dilation in the last blocks to maintain an output stride of
    16, and the classifier block originally used for classification is removed.
    The forward method additionally returns the feature maps at all resolutions
    for the decoder's use.
    """

    def __init__(self, in_channels, norm_layer=None):
        # Forward norm_layer to the parent constructor so a custom
        # normalization layer actually takes effect.
        super().__init__(norm_layer=norm_layer)

        # Replace the first conv layer if in_channels doesn't match the
        # default 3-channel stem.
        if in_channels != 3:
            self.features[0][0] = nn.Conv2d(in_channels, 32, 3, 2, 1, bias=False)

        # Remove the last 1x1 conv block; the decoder consumes the raw
        # bottleneck features instead.
        self.features = self.features[:-1]

        # Convert the stride-2 depthwise conv in block 14 to stride 1 and
        # dilate the depthwise convs of the remaining blocks, so the output
        # stride stays at 16 instead of 32.
        self.features[14].conv[1][0].stride = (1, 1)
        for feature in self.features[15:]:
            feature.conv[1][0].dilation = (2, 2)
            feature.conv[1][0].padding = (2, 2)

        # Delete the unused classifier head.
        del self.classifier

    def forward(self, x):
        x0 = x  # 1/1
        x = self.features[0](x)
        x = self.features[1](x)
        x1 = x  # 1/2
        x = self.features[2](x)
        x = self.features[3](x)
        x2 = x  # 1/4
        x = self.features[4](x)
        x = self.features[5](x)
        x = self.features[6](x)
        x3 = x  # 1/8
        x = self.features[7](x)
        x = self.features[8](x)
        x = self.features[9](x)
        x = self.features[10](x)
        x = self.features[11](x)
        x = self.features[12](x)
        x = self.features[13](x)
        x = self.features[14](x)
        x = self.features[15](x)
        x = self.features[16](x)
        x = self.features[17](x)
        x4 = x  # 1/16
        return x4, x3, x2, x1, x0
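

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal check of the encoder's multi-scale outputs, assuming the
# torchvision defaults (width_mult=1.0) and a 3-channel 256x256 input.
# Under those assumptions the five returned tensors sit at 1/16, 1/8,
# 1/4, 1/2, and full resolution respectively.
if __name__ == "__main__":
    import torch

    encoder = MobileNetV2Encoder(in_channels=3).eval()
    x = torch.randn(1, 3, 256, 256)
    with torch.no_grad():
        x4, x3, x2, x1, x0 = encoder(x)
    for name, t in (("x4", x4), ("x3", x3), ("x2", x2), ("x1", x1), ("x0", x0)):
        print(name, tuple(t.shape))
    # Expected shapes (assuming width_mult=1.0):
    # x4 (1, 320, 16, 16), x3 (1, 32, 32, 32), x2 (1, 24, 64, 64),
    # x1 (1, 16, 128, 128), x0 (1, 3, 256, 256)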