|
import torch.nn as nn
|
|
import torch
|
|
|
|
|
|
class ResidualConv(nn.Module):
    """Pre-activation residual block with a projected skip connection.

    Main path: BN-ReLU-Conv(3x3, ``stride``, ``padding``) followed by
    BN-ReLU-Conv(3x3, stride 1, padding 1). Skip path: Conv(3x3, ``stride``,
    ``padding``)-BN. The two paths are summed.

    Args:
        input_dim: Number of input channels.
        output_dim: Number of output channels.
        stride: Stride of the first main-path conv and of the skip conv.
        padding: Padding of the first main-path conv and of the skip conv.
    """

    def __init__(self, input_dim, output_dim, stride, padding):
        super().__init__()

        self.conv_block = nn.Sequential(
            nn.BatchNorm2d(input_dim),
            nn.ReLU(),
            nn.Conv2d(
                input_dim, output_dim, kernel_size=3, stride=stride, padding=padding
            ),
            nn.BatchNorm2d(output_dim),
            nn.ReLU(),
            # 3x3 kernel, stride 1, padding 1: keeps the spatial size unchanged.
            nn.Conv2d(output_dim, output_dim, kernel_size=3, padding=1),
        )
        # The skip must match the main path's spatial size, so it uses the same
        # stride *and* padding as the first main-path conv. A hard-coded
        # padding=1 here made any other `padding` value break the sum in forward.
        self.conv_skip = nn.Sequential(
            nn.Conv2d(
                input_dim, output_dim, kernel_size=3, stride=stride, padding=padding
            ),
            nn.BatchNorm2d(output_dim),
        )

    def forward(self, x):
        """Return ``conv_block(x) + conv_skip(x)``."""
        return self.conv_block(x) + self.conv_skip(x)
|
|
|
|
|
|
class Upsample(nn.Module):
    """Learned upsampling implemented as one transposed convolution.

    Args:
        input_dim: Number of input channels.
        output_dim: Number of output channels.
        kernel: Kernel size of the transposed convolution.
        stride: Stride of the transposed convolution; with no padding and
            ``kernel == stride`` this is the spatial upsampling factor.
    """

    def __init__(self, input_dim, output_dim, kernel, stride):
        super().__init__()
        self.upsample = nn.ConvTranspose2d(
            in_channels=input_dim,
            out_channels=output_dim,
            kernel_size=kernel,
            stride=stride,
        )

    def forward(self, x):
        """Apply the transposed convolution to ``x``."""
        return self.upsample(x)
|
|
|
|
|
|
class Squeeze_Excite_Block(nn.Module):
    """Squeeze-and-Excitation channel attention.

    Global-average-pools each channel, passes the pooled vector through a
    bottleneck MLP (Linear-ReLU-Linear-Sigmoid, no biases), and rescales the
    input channels by the resulting per-channel weights in (0, 1).

    Args:
        channel: Number of input (and output) channels.
        reduction: Bottleneck reduction ratio. The hidden width is
            ``channel // reduction``, clamped to at least 1.
    """

    def __init__(self, channel, reduction=16):
        super().__init__()
        # Clamp: with channel < reduction the integer division yields 0, which
        # builds a zero-width bottleneck whose output is always sigmoid(0)=0.5.
        hidden = max(1, channel // reduction)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channel, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Rescale ``x`` channel-wise; ``x`` must be 4-D ``(B, C, H, W)``."""
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y.expand_as(x)
|
|
|
|
|
|
class ASPP(nn.Module):
    """Atrous Spatial Pyramid Pooling.

    Runs one 3x3 dilated convolution branch (Conv-ReLU-BN) per dilation rate
    in parallel, concatenates the branch outputs along the channel axis and
    fuses them with a 1x1 convolution.

    Args:
        in_dims: Number of input channels.
        out_dims: Number of channels produced by each branch and by the fused
            output.
        rate: Dilation rates, one branch per entry. Each branch uses
            ``padding == dilation`` so it preserves the spatial size.
            Defaults to ``(6, 12, 18)``.

    Raises:
        ValueError: If ``rate`` is empty.
    """

    def __init__(self, in_dims, out_dims, rate=(6, 12, 18)):
        super().__init__()
        # Immutable default + tuple copy: the caller's sequence is never
        # shared or mutated, and any iterable of ints is accepted.
        rate = tuple(rate)
        if not rate:
            raise ValueError("ASPP requires at least one dilation rate")

        # Register branches as aspp_block1, aspp_block2, ... so the default
        # three-rate configuration keeps its original state_dict keys, while
        # any number of rates now builds a matching number of branches.
        self._num_branches = len(rate)
        for index, dilation in enumerate(rate, start=1):
            setattr(
                self,
                f"aspp_block{index}",
                self._make_branch(in_dims, out_dims, dilation),
            )

        self.output = nn.Conv2d(len(rate) * out_dims, out_dims, 1)
        self._init_weights()

    @staticmethod
    def _make_branch(in_dims, out_dims, dilation):
        """Build one size-preserving dilated Conv-ReLU-BN branch."""
        return nn.Sequential(
            nn.Conv2d(
                in_dims, out_dims, 3, stride=1, padding=dilation, dilation=dilation
            ),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(out_dims),
        )

    def forward(self, x):
        """Concatenate every branch's output on dim 1 and fuse with a 1x1 conv."""
        branch_outputs = [
            getattr(self, f"aspp_block{index}")(x)
            for index in range(1, self._num_branches + 1)
        ]
        return self.output(torch.cat(branch_outputs, dim=1))

    def _init_weights(self):
        """Kaiming-normal conv weights; BatchNorm weight=1, bias=0."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
|
|
|
|
|
|
class Upsample_(nn.Module):
    """Parameter-free bilinear upsampling by a fixed scale factor.

    Args:
        scale: Spatial scale factor handed to ``nn.Upsample``. Defaults to 2.
    """

    def __init__(self, scale=2):
        super().__init__()
        self.upsample = nn.Upsample(scale_factor=scale, mode="bilinear")

    def forward(self, x):
        """Bilinearly resize ``x`` by the configured scale factor."""
        return self.upsample(x)
|
|
|
|
|
|
class AttentionBlock(nn.Module):
    """Attention gate that reweights decoder features using encoder features.

    The encoder input ``x1`` goes through BN-ReLU-Conv(3x3), then MaxPool(2).
    The decoder input ``x2`` goes through BN-ReLU-Conv(3x3). The two results
    are summed, and BN-ReLU-Conv(1x1) reduces the sum to a single-channel map.
    That map multiplies the unprocessed ``x2``.

    Args:
        input_encoder: Channels of the encoder input ``x1``.
        input_decoder: Channels of the decoder input ``x2``.
        output_dim: Channels of the intermediate projections.
    """

    def __init__(self, input_encoder, input_decoder, output_dim):
        super().__init__()

        # The encoder path is pooled by 2; presumably x1 has twice x2's
        # spatial size so the sum lines up -- confirm against callers.
        self.conv_encoder = nn.Sequential(
            nn.BatchNorm2d(input_encoder),
            nn.ReLU(),
            nn.Conv2d(input_encoder, output_dim, 3, padding=1),
            nn.MaxPool2d(2, 2),
        )
        self.conv_decoder = nn.Sequential(
            nn.BatchNorm2d(input_decoder),
            nn.ReLU(),
            nn.Conv2d(input_decoder, output_dim, 3, padding=1),
        )
        # NOTE(review): no sigmoid follows this 1x1 conv, so the attention map
        # is unbounded; kept as-is to preserve existing behavior.
        self.conv_attn = nn.Sequential(
            nn.BatchNorm2d(output_dim),
            nn.ReLU(),
            nn.Conv2d(output_dim, 1, 1),
        )

    def forward(self, x1, x2):
        """Return ``x2`` scaled by the one-channel map computed from x1 and x2."""
        encoded = self.conv_encoder(x1)
        decoded = self.conv_decoder(x2)
        attention = self.conv_attn(encoded + decoded)
        # (B, 1, H, W) map broadcasts over x2's channels.
        return attention * x2