Spaces:
Running
on
L40S
Running
on
L40S
# Copyright (c) Facebook, Inc. and its affiliates. | |
from copy import deepcopy | |
import fvcore.nn.weight_init as weight_init | |
import torch | |
from torch import nn | |
from torch.nn import functional as F | |
from .batch_norm import get_norm | |
from .blocks import DepthwiseSeparableConv2d | |
from .wrappers import Conv2d | |
class ASPP(nn.Module):
    """
    Atrous Spatial Pyramid Pooling (ASPP).

    Runs five parallel branches on the same input:
      - a 1x1 conv,
      - three 3x3 atrous convs (one per dilation),
      - an image-pooling branch.
    It then concatenates their outputs along the channel dimension and fuses
    them with a 1x1 projection conv, optionally followed by dropout.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        dilations,
        *,
        norm,
        activation,
        pool_kernel_size=None,
        dropout: float = 0.0,
        use_depthwise_separable_conv=False,
    ):
        """
        Args:
            in_channels (int): number of input channels for ASPP.
            out_channels (int): number of output channels.
            dilations (list): a list of 3 dilations in ASPP.
            norm (str or callable): normalization for all conv layers.
                See :func:`layers.get_norm` for supported format. norm is
                applied to all conv layers except the conv following
                global average pooling.
            activation (callable): activation function. It is deep-copied for
                every conv layer, so no two layers share the same instance.
            pool_kernel_size (int, tuple, list): the average pooling size (kh, kw)
                for the image pooling layer in ASPP. A single int k is treated
                as (k, k). If set to None, it always performs global average
                pooling. If not None, the spatial shape (H, W) of inputs in
                forward() must be divisible by it. It is recommended to use a
                fixed input feature size in training, and set this option to
                match this size, so that it performs global average pooling in
                training, and the size of the pooling window stays consistent
                in inference.
            dropout (float): apply dropout on the output of ASPP. It is used in
                the official DeepLab implementation with a rate of 0.1:
                https://github.com/tensorflow/models/blob/21b73d22f3ed05b650e85ac50849408dd36de32e/research/deeplab/model.py#L532  # noqa
            use_depthwise_separable_conv (bool): use DepthwiseSeparableConv2d
                for 3x3 convs in ASPP, proposed in :paper:`DeepLabV3+`.
        """
        super().__init__()
        assert len(dilations) == 3, "ASPP expects 3 dilations, got {}".format(len(dilations))
        if isinstance(pool_kernel_size, int):
            # forward() indexes both dimensions, so normalize a square window
            # given as a single int into an explicit (kh, kw) pair.
            pool_kernel_size = (pool_kernel_size, pool_kernel_size)
        self.pool_kernel_size = pool_kernel_size
        self.dropout = dropout
        # Convs get their own bias only when no normalization follows them.
        # NOTE(review): norm=None yields use_bias=False; confirm what
        # get_norm(None, ...) returns. If it returns no norm layer, those convs
        # end up with neither a bias nor a norm.
        use_bias = norm == ""
        self.convs = nn.ModuleList()

        # Branch 1: 1x1 conv.
        self.convs.append(
            Conv2d(
                in_channels,
                out_channels,
                kernel_size=1,
                bias=use_bias,
                norm=get_norm(norm, out_channels),
                activation=deepcopy(activation),
            )
        )
        weight_init.c2_xavier_fill(self.convs[-1])

        # Branches 2-4: 3x3 atrous convs. With kernel_size=3 and
        # padding == dilation, the spatial size of the input is preserved.
        for dilation in dilations:
            if use_depthwise_separable_conv:
                # c2_xavier_fill is not applied to this wrapper module.
                # NOTE(review): presumably DepthwiseSeparableConv2d initializes
                # its own weights; confirm.
                self.convs.append(
                    DepthwiseSeparableConv2d(
                        in_channels,
                        out_channels,
                        kernel_size=3,
                        padding=dilation,
                        dilation=dilation,
                        norm1=norm,
                        activation1=deepcopy(activation),
                        norm2=norm,
                        activation2=deepcopy(activation),
                    )
                )
            else:
                self.convs.append(
                    Conv2d(
                        in_channels,
                        out_channels,
                        kernel_size=3,
                        padding=dilation,
                        dilation=dilation,
                        bias=use_bias,
                        norm=get_norm(norm, out_channels),
                        activation=deepcopy(activation),
                    )
                )
                weight_init.c2_xavier_fill(self.convs[-1])

        # Branch 5: image pooling, then a 1x1 conv with bias and no norm.
        # No BatchNorm is added here because the pooled spatial resolution is
        # 1x1 (for global pooling), although the original TF implementation
        # does apply BatchNorm in this branch.
        if pool_kernel_size is None:
            pool = nn.AdaptiveAvgPool2d(1)
        else:
            pool = nn.AvgPool2d(kernel_size=pool_kernel_size, stride=1)
        image_pooling = nn.Sequential(
            pool,
            Conv2d(in_channels, out_channels, 1, bias=True, activation=deepcopy(activation)),
        )
        weight_init.c2_xavier_fill(image_pooling[1])
        self.convs.append(image_pooling)

        # Fuse the 5 concatenated branches back down to out_channels.
        self.project = Conv2d(
            5 * out_channels,
            out_channels,
            kernel_size=1,
            bias=use_bias,
            norm=get_norm(norm, out_channels),
            activation=deepcopy(activation),
        )
        weight_init.c2_xavier_fill(self.project)

    def forward(self, x):
        """
        Apply all ASPP branches to ``x``, concatenate them, and project.

        Args:
            x (Tensor): input feature map; its last two dims are taken as (H, W).

        Returns:
            Tensor: the projected features with ``out_channels`` channels
            (dropout is applied in training mode when ``dropout > 0``).

        Raises:
            ValueError: if ``pool_kernel_size`` is set and H or W is not
                divisible by the corresponding kernel dimension.
        """
        size = x.shape[-2:]
        if self.pool_kernel_size is not None:
            if size[0] % self.pool_kernel_size[0] or size[1] % self.pool_kernel_size[1]:
                raise ValueError(
                    "The input size must be divisible by `pool_kernel_size`. "
                    "Input size: {} `pool_kernel_size`: {}".format(size, self.pool_kernel_size)
                )
        res = [conv(x) for conv in self.convs]
        # The image-pooling branch (last) has a reduced spatial size; resize it
        # back to (H, W) so all branches can be concatenated channel-wise.
        res[-1] = F.interpolate(res[-1], size=size, mode="bilinear", align_corners=False)
        res = torch.cat(res, dim=1)
        res = self.project(res)
        res = F.dropout(res, self.dropout, training=self.training) if self.dropout > 0 else res
        return res