# VideoMatting / model /mobilenet.py
# Fazhong Liu
# init
# 854728f
# raw
# history blame contribute delete
# No virus
# 1.9 kB
from torch import nn
from torchvision.models import MobileNetV2
class MobileNetV2Encoder(MobileNetV2):
    """
    Feature encoder built on torchvision's official MobileNetV2.

    Differences from the stock classification network:

    * the first convolution is replaced when ``in_channels`` is not 3;
    * the last feature block is removed;
    * the stride of block 14 is set to 1 and the following blocks use
      dilation 2, so the output stride stays at 16;
    * the classifier head is deleted.

    ``forward`` additionally returns the intermediate feature maps at every
    resolution so that a decoder can use them as skip connections.
    """

    def __init__(self, in_channels, norm_layer=None):
        """
        Build the encoder.

        Args:
            in_channels: Number of channels of the input tensor. When it is
                not 3, the first convolution is replaced by a freshly
                initialised one with matching input channels.
            norm_layer: Optional normalisation layer factory forwarded to
                torchvision's ``MobileNetV2``. ``None`` keeps torchvision's
                default.
        """
        # Forward norm_layer only when the caller supplied one, so the default
        # construction path is exactly the plain ``MobileNetV2()`` call.
        backbone_kwargs = {} if norm_layer is None else {"norm_layer": norm_layer}
        super().__init__(**backbone_kwargs)

        # Replace first conv layer if in_channels doesn't match.
        if in_channels != 3:
            self.features[0][0] = nn.Conv2d(in_channels, 32, 3, 2, 1, bias=False)

        # Remove last block.
        self.features = self.features[:-1]

        # Change to use dilation to maintain output stride = 16: undo the
        # stride of block 14 and dilate the 3x3 convs of the blocks after it.
        self.features[14].conv[1][0].stride = (1, 1)
        for feature in self.features[15:]:
            feature.conv[1][0].dilation = (2, 2)
            feature.conv[1][0].padding = (2, 2)

        # Delete classifier.
        del self.classifier

    def forward(self, x):
        """
        Run the encoder and collect the multi-resolution feature maps.

        Args:
            x: Input tensor fed to the first feature block.

        Returns:
            Tuple ``(x4, x3, x2, x1, x0)`` ordered from coarsest to finest:
            the 1/16, 1/8, 1/4 and 1/2 resolution features followed by the
            untouched input (1/1).
        """
        # The blocks are applied one by one (rather than in a loop) so the
        # tap points after blocks 1, 3, 6 and 17 stay explicit.
        x0 = x  # 1/1
        x = self.features[0](x)
        x = self.features[1](x)
        x1 = x  # 1/2
        x = self.features[2](x)
        x = self.features[3](x)
        x2 = x  # 1/4
        x = self.features[4](x)
        x = self.features[5](x)
        x = self.features[6](x)
        x3 = x  # 1/8
        x = self.features[7](x)
        x = self.features[8](x)
        x = self.features[9](x)
        x = self.features[10](x)
        x = self.features[11](x)
        x = self.features[12](x)
        x = self.features[13](x)
        x = self.features[14](x)
        x = self.features[15](x)
        x = self.features[16](x)
        x = self.features[17](x)
        x4 = x  # 1/16
        return x4, x3, x2, x1, x0