|
|
|
|
|
''' |
|
@author: MingDong |
|
@file: cbam.py |
|
@desc: Convolutional Block Attention Module in ECCV 2018, including channel attention module and spatial attention module. |
|
''' |
|
|
|
import torch |
|
from torch import nn |
|
|
|
|
|
class Flatten(nn.Module):
    '''Collapse every dimension after the first into a single one.'''

    def forward(self, x):
        '''Return a view of ``x`` with shape ``(x.size(0), -1)``.'''
        batch = x.shape[0]
        # ``view`` (not ``reshape``) is kept on purpose: no copy is ever made.
        return x.view(batch, -1)
|
|
|
class SEModule(nn.Module):
    '''Squeeze and Excitation Module

    Rescales each channel of the input by a gate in (0, 1) computed from the
    globally average-pooled feature map.

    Args:
        channels: number of channels of the input feature map.
        reduction: divisor applied to ``channels`` to get the width of the
            hidden 1x1 convolution.
    '''

    def __init__(self, channels, reduction):
        super(SEModule, self).__init__()
        hidden = channels // reduction
        # Squeeze: one value per channel.
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # Excitation: bias-free 1x1 convolutions acting as a two-layer MLP.
        self.fc1 = nn.Conv2d(channels, hidden, kernel_size=1, padding=0, bias=False)
        self.relu = nn.ReLU(inplace=True)
        self.fc2 = nn.Conv2d(hidden, channels, kernel_size=1, padding=0, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        '''Return ``x`` multiplied channel-wise by its attention gate.'''
        squeezed = self.avg_pool(x)
        excited = self.fc2(self.relu(self.fc1(squeezed)))
        gate = self.sigmoid(excited)
        # The (N, C, 1, 1) gate broadcasts over the spatial dimensions.
        return x * gate
|
|
|
class CAModule(nn.Module):
    '''Channel Attention Module

    Builds a per-channel gate from both the average-pooled and the
    max-pooled descriptors of the input, passing each through the same
    two-layer bottleneck before summing them.

    Args:
        channels: number of channels of the input feature map.
        reduction: divisor applied to ``channels`` to get the bottleneck width.
    '''

    def __init__(self, channels, reduction):
        super(CAModule, self).__init__()
        hidden = channels // reduction
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        # One MLP (as bias-free 1x1 convolutions) shared by both descriptors.
        mlp_layers = [
            nn.Conv2d(channels, hidden, kernel_size=1, padding=0, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, channels, kernel_size=1, padding=0, bias=False),
        ]
        self.shared_mlp = nn.Sequential(*mlp_layers)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        '''Return ``x`` multiplied channel-wise by its attention gate.'''
        avg_desc = self.avg_pool(x)
        max_desc = self.max_pool(x)
        logits = self.shared_mlp(avg_desc) + self.shared_mlp(max_desc)
        gate = self.sigmoid(logits)
        # The (N, C, 1, 1) gate broadcasts over the spatial dimensions.
        return x * gate
|
|
|
class SAModule(nn.Module):
    '''Spatial Attention Module

    Builds a single-channel spatial gate from the channel-wise mean and
    channel-wise max of the input and multiplies the input by it.

    Args:
        kernel_size: side length of the square convolution applied to the
            stacked (mean, max) maps. Must be a positive odd integer so that
            "same" padding keeps the spatial size. Defaults to 3; the CBAM
            paper uses 7.

    Raises:
        ValueError: if ``kernel_size`` is not a positive odd integer.
    '''

    def __init__(self, kernel_size=3):
        super(SAModule, self).__init__()
        if kernel_size < 1 or kernel_size % 2 == 0:
            # An even kernel cannot be padded symmetrically, so the gate
            # would no longer match the input's spatial size in forward().
            raise ValueError(
                'kernel_size should be a positive odd integer, got {}'.format(kernel_size))
        self.conv = nn.Conv2d(2, 1, kernel_size=kernel_size,
                              padding=kernel_size // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        '''Return ``x`` multiplied by its spatial attention gate.

        ``x`` is a feature map of shape (N, C, H, W); the gate has shape
        (N, 1, H, W) and broadcasts over the channel dimension.
        '''
        # Collapse the channel dimension two ways, keeping it as size 1.
        avg_c = torch.mean(x, dim=1, keepdim=True)
        max_c, _ = torch.max(x, dim=1, keepdim=True)
        descriptor = torch.cat((avg_c, max_c), dim=1)
        gate = self.sigmoid(self.conv(descriptor))
        return x * gate
|
|
|
class BottleNeck_IR(nn.Module):
    '''Improved Residual Bottlenecks

    Residual branch: BN -> 3x3 conv -> BN -> PReLU -> 3x3 conv (strided) -> BN.

    Args:
        in_channel: channels of the input feature map.
        out_channel: channels produced by the block.
        stride: stride of the second 3x3 convolution and of the projection
            shortcut.
        dim_match: when true the input is used directly as the shortcut;
            otherwise a 1x1 conv + BN projection is applied to it.
    '''

    def __init__(self, in_channel, out_channel, stride, dim_match):
        super(BottleNeck_IR, self).__init__()
        residual_ops = [
            nn.BatchNorm2d(in_channel),
            nn.Conv2d(in_channel, out_channel, kernel_size=(3, 3), stride=1, padding=1, bias=False),
            nn.BatchNorm2d(out_channel),
            nn.PReLU(out_channel),
            nn.Conv2d(out_channel, out_channel, kernel_size=(3, 3), stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(out_channel),
        ]
        self.res_layer = nn.Sequential(*residual_ops)

        # Identity shortcut is represented by None; forward() checks for it.
        self.shortcut_layer = None if dim_match else nn.Sequential(
            nn.Conv2d(in_channel, out_channel, kernel_size=(1, 1), stride=stride, bias=False),
            nn.BatchNorm2d(out_channel),
        )

    def forward(self, x):
        '''Return shortcut(x) + residual(x).'''
        residual = self.res_layer(x)
        if self.shortcut_layer is None:
            return x + residual
        return self.shortcut_layer(x) + residual
|
|
|
class BottleNeck_IR_SE(nn.Module):
    '''Improved Residual Bottlenecks with Squeeze and Excitation Module

    Same residual branch as the plain IR bottleneck, followed by an
    ``SEModule`` with reduction 16 before the shortcut is added.

    Args:
        in_channel: channels of the input feature map.
        out_channel: channels produced by the block.
        stride: stride of the second 3x3 convolution and of the projection
            shortcut.
        dim_match: when true the input is used directly as the shortcut;
            otherwise a 1x1 conv + BN projection is applied to it.
    '''

    def __init__(self, in_channel, out_channel, stride, dim_match):
        super(BottleNeck_IR_SE, self).__init__()
        residual_ops = [
            nn.BatchNorm2d(in_channel),
            nn.Conv2d(in_channel, out_channel, kernel_size=(3, 3), stride=1, padding=1, bias=False),
            nn.BatchNorm2d(out_channel),
            nn.PReLU(out_channel),
            nn.Conv2d(out_channel, out_channel, kernel_size=(3, 3), stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(out_channel),
            SEModule(out_channel, 16),
        ]
        self.res_layer = nn.Sequential(*residual_ops)

        # Identity shortcut is represented by None; forward() checks for it.
        self.shortcut_layer = None if dim_match else nn.Sequential(
            nn.Conv2d(in_channel, out_channel, kernel_size=(1, 1), stride=stride, bias=False),
            nn.BatchNorm2d(out_channel),
        )

    def forward(self, x):
        '''Return shortcut(x) + residual(x).'''
        residual = self.res_layer(x)
        if self.shortcut_layer is None:
            return x + residual
        return self.shortcut_layer(x) + residual
|
|
|
class BottleNeck_IR_CAM(nn.Module):
    '''Improved Residual Bottlenecks with Channel Attention Module

    Same residual branch as the plain IR bottleneck, followed by a
    ``CAModule`` with reduction 16 before the shortcut is added.

    Args:
        in_channel: channels of the input feature map.
        out_channel: channels produced by the block.
        stride: stride of the second 3x3 convolution and of the projection
            shortcut.
        dim_match: when true the input is used directly as the shortcut;
            otherwise a 1x1 conv + BN projection is applied to it.
    '''

    def __init__(self, in_channel, out_channel, stride, dim_match):
        super(BottleNeck_IR_CAM, self).__init__()
        residual_ops = [
            nn.BatchNorm2d(in_channel),
            nn.Conv2d(in_channel, out_channel, kernel_size=(3, 3), stride=1, padding=1, bias=False),
            nn.BatchNorm2d(out_channel),
            nn.PReLU(out_channel),
            nn.Conv2d(out_channel, out_channel, kernel_size=(3, 3), stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(out_channel),
            CAModule(out_channel, 16),
        ]
        self.res_layer = nn.Sequential(*residual_ops)

        # Identity shortcut is represented by None; forward() checks for it.
        self.shortcut_layer = None if dim_match else nn.Sequential(
            nn.Conv2d(in_channel, out_channel, kernel_size=(1, 1), stride=stride, bias=False),
            nn.BatchNorm2d(out_channel),
        )

    def forward(self, x):
        '''Return shortcut(x) + residual(x).'''
        residual = self.res_layer(x)
        if self.shortcut_layer is None:
            return x + residual
        return self.shortcut_layer(x) + residual
|
|
|
class BottleNeck_IR_SAM(nn.Module):
    '''Improved Residual Bottlenecks with Spatial Attention Module

    Same residual branch as the plain IR bottleneck, followed by an
    ``SAModule`` before the shortcut is added.

    Args:
        in_channel: channels of the input feature map.
        out_channel: channels produced by the block.
        stride: stride of the second 3x3 convolution and of the projection
            shortcut.
        dim_match: when true the input is used directly as the shortcut;
            otherwise a 1x1 conv + BN projection is applied to it.
    '''

    def __init__(self, in_channel, out_channel, stride, dim_match):
        super(BottleNeck_IR_SAM, self).__init__()
        residual_ops = [
            nn.BatchNorm2d(in_channel),
            nn.Conv2d(in_channel, out_channel, kernel_size=(3, 3), stride=1, padding=1, bias=False),
            nn.BatchNorm2d(out_channel),
            nn.PReLU(out_channel),
            nn.Conv2d(out_channel, out_channel, kernel_size=(3, 3), stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(out_channel),
            SAModule(),
        ]
        self.res_layer = nn.Sequential(*residual_ops)

        # Identity shortcut is represented by None; forward() checks for it.
        self.shortcut_layer = None if dim_match else nn.Sequential(
            nn.Conv2d(in_channel, out_channel, kernel_size=(1, 1), stride=stride, bias=False),
            nn.BatchNorm2d(out_channel),
        )

    def forward(self, x):
        '''Return shortcut(x) + residual(x).'''
        residual = self.res_layer(x)
        if self.shortcut_layer is None:
            return x + residual
        return self.shortcut_layer(x) + residual
|
|
|
class BottleNeck_IR_CBAM(nn.Module):
    '''Improved Residual Bottleneck with Channel Attention Module and Spatial Attention Module

    Same residual branch as the plain IR bottleneck, followed by a
    ``CAModule`` (reduction 16) and then an ``SAModule`` before the shortcut
    is added.

    Args:
        in_channel: channels of the input feature map.
        out_channel: channels produced by the block.
        stride: stride of the second 3x3 convolution and of the projection
            shortcut.
        dim_match: when true the input is used directly as the shortcut;
            otherwise a 1x1 conv + BN projection is applied to it.
    '''

    def __init__(self, in_channel, out_channel, stride, dim_match):
        super(BottleNeck_IR_CBAM, self).__init__()
        residual_ops = [
            nn.BatchNorm2d(in_channel),
            nn.Conv2d(in_channel, out_channel, kernel_size=(3, 3), stride=1, padding=1, bias=False),
            nn.BatchNorm2d(out_channel),
            nn.PReLU(out_channel),
            nn.Conv2d(out_channel, out_channel, kernel_size=(3, 3), stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(out_channel),
            # Channel attention first, then spatial attention.
            CAModule(out_channel, 16),
            SAModule(),
        ]
        self.res_layer = nn.Sequential(*residual_ops)

        # Identity shortcut is represented by None; forward() checks for it.
        self.shortcut_layer = None if dim_match else nn.Sequential(
            nn.Conv2d(in_channel, out_channel, kernel_size=(1, 1), stride=stride, bias=False),
            nn.BatchNorm2d(out_channel),
        )

    def forward(self, x):
        '''Return shortcut(x) + residual(x).'''
        residual = self.res_layer(x)
        if self.shortcut_layer is None:
            return x + residual
        return self.shortcut_layer(x) + residual
|
|
|
|
|
# Channel widths: stem output followed by the outputs of the four stages.
filter_list = [64, 64, 128, 256, 512]


def get_layers(num_layers):
    '''Return the number of blocks in each of the four stages.

    Supported depths are 50, 100 and 152; any other value yields ``None``.
    A new list is returned on every call.
    '''
    stage_table = (
        (50, (3, 4, 14, 3)),
        (100, (3, 13, 30, 3)),
        (152, (3, 8, 36, 3)),
    )
    for depth, block_counts in stage_table:
        if num_layers == depth:
            return list(block_counts)
    return None
|
|
|
|
|
class CBAMResNet(nn.Module):
    '''Residual backbone built from IR bottlenecks with optional attention.

    The network is a 3x3 stem, four stages of bottleneck blocks (each stage
    starts with a stride-2 projection block) and a BN -> Dropout -> Flatten
    -> Linear -> BN head producing a ``feature_dim``-dimensional vector.

    The head's linear layer is sized for a 7x7 final feature map, which is
    what a 112x112 input becomes after the four stride-2 stages.

    Args:
        num_layers: network depth; one of 50, 100 or 152.
        feature_dim: size of the output embedding.
        drop_ratio: dropout probability used in the head.
        mode: which bottleneck to use: 'ir', 'ir_se', 'ir_cam', 'ir_sam'
            or 'ir_cbam'.
        filter_list: five channel widths, the stem output followed by the
            output of each of the four stages.
    '''

    def __init__(self, num_layers, feature_dim=512, drop_ratio=0.4, mode='ir', filter_list=filter_list):
        super(CBAMResNet, self).__init__()
        assert num_layers in [50, 100, 152], 'num_layers should be 50, 100 or 152'
        assert mode in ['ir', 'ir_se', 'ir_cam', 'ir_sam', 'ir_cbam'], 'mode should be ir, ir_se, ir_cam, ir_sam or ir_cbam'
        layers = get_layers(num_layers)
        block = {
            'ir': BottleNeck_IR,
            'ir_se': BottleNeck_IR_SE,
            'ir_cam': BottleNeck_IR_CAM,
            'ir_sam': BottleNeck_IR_SAM,
            'ir_cbam': BottleNeck_IR_CBAM,
        }[mode]

        # Take the stem and head widths from filter_list instead of the
        # literals 64 / 512, so a custom filter_list yields a consistent
        # network rather than a channel mismatch at layer1 or the head.
        stem_channels = filter_list[0]
        head_channels = filter_list[4]

        self.input_layer = nn.Sequential(nn.Conv2d(3, stem_channels, (3, 3), stride=1, padding=1, bias=False),
                                         nn.BatchNorm2d(stem_channels),
                                         nn.PReLU(stem_channels))
        self.layer1 = self._make_layer(block, filter_list[0], filter_list[1], layers[0], stride=2)
        self.layer2 = self._make_layer(block, filter_list[1], filter_list[2], layers[1], stride=2)
        self.layer3 = self._make_layer(block, filter_list[2], filter_list[3], layers[2], stride=2)
        self.layer4 = self._make_layer(block, filter_list[3], filter_list[4], layers[3], stride=2)

        # 7 * 7 is the spatial size left after four stride-2 stages on a
        # 112x112 input.
        self.output_layer = nn.Sequential(nn.BatchNorm2d(head_channels),
                                          nn.Dropout(drop_ratio),
                                          Flatten(),
                                          nn.Linear(head_channels * 7 * 7, feature_dim),
                                          nn.BatchNorm1d(feature_dim))

        # Xavier-uniform weights for conv/linear layers, zero biases, and
        # identity affine parameters for every batch-norm layer.
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0.0)
            elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, in_channel, out_channel, blocks, stride):
        '''Build one stage of ``blocks`` bottlenecks.

        The first block changes the channel count and applies ``stride``
        with a projection shortcut; the remaining ones keep the shape and
        use an identity shortcut.
        '''
        layers = [block(in_channel, out_channel, stride, False)]
        for _ in range(1, blocks):
            layers.append(block(out_channel, out_channel, 1, True))

        return nn.Sequential(*layers)

    def forward(self, x):
        '''Map a batch of images to ``feature_dim``-dimensional embeddings.'''
        x = self.input_layer(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.output_layer(x)

        return x
|
|
|
if __name__ == '__main__':
    # Smoke test: push a small random batch through the plain IR-50 network
    # and print the embedding shape.
    # torch.randn is used instead of torch.Tensor(2, 3, 112, 112): the latter
    # returns uninitialized memory that may contain NaN/Inf values.
    x = torch.randn(2, 3, 112, 112)
    net = CBAMResNet(50, mode='ir')

    out = net(x)
    print(out.shape)
|
|