#!/usr/bin/env python
# encoding: utf-8
'''
@author: MingDong
@file: cbam.py
@desc: Convolutional Block Attention Module (CBAM, ECCV 2018): channel and
       spatial attention modules, plus improved-residual bottleneck variants
       and a ResNet-style backbone that embeds them.
'''
import torch
from torch import nn
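
# For reference (a hedged summary of the CBAM paper, not part of the original
# file), the two attention maps built below are:
#   channel:  Mc(F) = sigmoid(MLP(AvgPool(F)) + MLP(MaxPool(F)))
#   spatial:  Ms(F) = sigmoid(conv([AvgPool_c(F); MaxPool_c(F)]))
# For Mc the pooling reduces the spatial dims; for Ms it reduces the channel
# dim. Note the paper favors a 7x7 conv for Ms, while this file uses 3x3.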


class Flatten(nn.Module):
    def forward(self, x):
        return x.view(x.size(0), -1)


class SEModule(nn.Module):
    '''Squeeze-and-Excitation module: channel-wise recalibration via a
    bottleneck MLP (1x1 convs) over globally average-pooled features.'''
    def __init__(self, channels, reduction):
        super(SEModule, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Conv2d(channels, channels // reduction, kernel_size=1, padding=0, bias=False)
        self.relu = nn.ReLU(inplace=True)
        self.fc2 = nn.Conv2d(channels // reduction, channels, kernel_size=1, padding=0, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        image = x
        x = self.avg_pool(x)    # (N, C, 1, 1) global descriptor
        x = self.fc1(x)         # squeeze to C // reduction
        x = self.relu(x)
        x = self.fc2(x)         # excite back to C
        x = self.sigmoid(x)     # per-channel gates in (0, 1)
        return image * x        # rescale the input channel-wise
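
# A minimal usage sketch (shapes are assumptions, not from the original file):
#   se = SEModule(channels=64, reduction=16)
#   y = se(torch.randn(2, 64, 56, 56))   # y.shape == (2, 64, 56, 56)
# The module is shape-preserving; only channel magnitudes are rescaled.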


class CAModule(nn.Module):
    '''Channel Attention Module: like SE, but combines average- and max-pooled
    descriptors through a shared bottleneck MLP before the sigmoid gate.'''
    def __init__(self, channels, reduction):
        super(CAModule, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.shared_mlp = nn.Sequential(nn.Conv2d(channels, channels // reduction, kernel_size=1, padding=0, bias=False),
                                        nn.ReLU(inplace=True),
                                        nn.Conv2d(channels // reduction, channels, kernel_size=1, padding=0, bias=False))
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        image = x
        avg_pool = self.avg_pool(x)                                 # (N, C, 1, 1)
        max_pool = self.max_pool(x)                                 # (N, C, 1, 1)
        x = self.shared_mlp(avg_pool) + self.shared_mlp(max_pool)   # shared weights
        x = self.sigmoid(x)
        return image * x
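
# Usage sketch (assumed shapes): CAModule(512, 16)(torch.randn(2, 512, 7, 7))
# returns a (2, 512, 7, 7) tensor; with reduction=16 the shared MLP squeezes
# 512 channels down to 32 and back.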


class SAModule(nn.Module):
    '''Spatial Attention Module: a 2-channel map (channel-wise mean and max)
    is convolved into a single-channel spatial gate.'''
    def __init__(self):
        super(SAModule, self).__init__()
        self.conv = nn.Conv2d(2, 1, kernel_size=3, padding=1, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        image = x
        avg_c = torch.mean(x, 1, True)      # (N, 1, H, W) channel-wise mean
        max_c, _ = torch.max(x, 1, True)    # (N, 1, H, W) channel-wise max
        x = torch.cat((avg_c, max_c), 1)    # (N, 2, H, W)
        x = self.conv(x)                    # (N, 1, H, W) spatial gate logits
        x = self.sigmoid(x)
        return image * x                    # broadcast over channels
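
# Usage sketch (shapes assumed): SAModule()(torch.randn(2, 256, 14, 14)) keeps
# the (2, 256, 14, 14) shape; the gate holds one value per spatial location,
# broadcast across all 256 channels.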


class BottleNeck_IR(nn.Module):
    '''Improved residual bottleneck (BN-Conv-BN-PReLU-Conv-BN), as used in
    ArcFace-style backbones.'''
    def __init__(self, in_channel, out_channel, stride, dim_match):
        super(BottleNeck_IR, self).__init__()
        self.res_layer = nn.Sequential(nn.BatchNorm2d(in_channel),
                                       nn.Conv2d(in_channel, out_channel, (3, 3), 1, 1, bias=False),
                                       nn.BatchNorm2d(out_channel),
                                       nn.PReLU(out_channel),
                                       nn.Conv2d(out_channel, out_channel, (3, 3), stride, 1, bias=False),
                                       nn.BatchNorm2d(out_channel))
        if dim_match:
            # Identity shortcut: input and output shapes already agree.
            self.shortcut_layer = None
        else:
            # Projection shortcut to match channels and stride.
            self.shortcut_layer = nn.Sequential(
                nn.Conv2d(in_channel, out_channel, kernel_size=(1, 1), stride=stride, bias=False),
                nn.BatchNorm2d(out_channel)
            )

    def forward(self, x):
        shortcut = x
        res = self.res_layer(x)
        if self.shortcut_layer is not None:
            shortcut = self.shortcut_layer(x)
        return shortcut + res


class BottleNeck_IR_SE(nn.Module):
    '''Improved residual bottleneck with a Squeeze-and-Excitation module'''
    def __init__(self, in_channel, out_channel, stride, dim_match):
        super(BottleNeck_IR_SE, self).__init__()
        self.res_layer = nn.Sequential(nn.BatchNorm2d(in_channel),
                                       nn.Conv2d(in_channel, out_channel, (3, 3), 1, 1, bias=False),
                                       nn.BatchNorm2d(out_channel),
                                       nn.PReLU(out_channel),
                                       nn.Conv2d(out_channel, out_channel, (3, 3), stride, 1, bias=False),
                                       nn.BatchNorm2d(out_channel),
                                       SEModule(out_channel, 16))
        if dim_match:
            self.shortcut_layer = None
        else:
            self.shortcut_layer = nn.Sequential(
                nn.Conv2d(in_channel, out_channel, kernel_size=(1, 1), stride=stride, bias=False),
                nn.BatchNorm2d(out_channel)
            )

    def forward(self, x):
        shortcut = x
        res = self.res_layer(x)
        if self.shortcut_layer is not None:
            shortcut = self.shortcut_layer(x)
        return shortcut + res


class BottleNeck_IR_CAM(nn.Module):
    '''Improved residual bottleneck with a Channel Attention Module'''
    def __init__(self, in_channel, out_channel, stride, dim_match):
        super(BottleNeck_IR_CAM, self).__init__()
        self.res_layer = nn.Sequential(nn.BatchNorm2d(in_channel),
                                       nn.Conv2d(in_channel, out_channel, (3, 3), 1, 1, bias=False),
                                       nn.BatchNorm2d(out_channel),
                                       nn.PReLU(out_channel),
                                       nn.Conv2d(out_channel, out_channel, (3, 3), stride, 1, bias=False),
                                       nn.BatchNorm2d(out_channel),
                                       CAModule(out_channel, 16))
        if dim_match:
            self.shortcut_layer = None
        else:
            self.shortcut_layer = nn.Sequential(
                nn.Conv2d(in_channel, out_channel, kernel_size=(1, 1), stride=stride, bias=False),
                nn.BatchNorm2d(out_channel)
            )

    def forward(self, x):
        shortcut = x
        res = self.res_layer(x)
        if self.shortcut_layer is not None:
            shortcut = self.shortcut_layer(x)
        return shortcut + res


class BottleNeck_IR_SAM(nn.Module):
    '''Improved residual bottleneck with a Spatial Attention Module'''
    def __init__(self, in_channel, out_channel, stride, dim_match):
        super(BottleNeck_IR_SAM, self).__init__()
        self.res_layer = nn.Sequential(nn.BatchNorm2d(in_channel),
                                       nn.Conv2d(in_channel, out_channel, (3, 3), 1, 1, bias=False),
                                       nn.BatchNorm2d(out_channel),
                                       nn.PReLU(out_channel),
                                       nn.Conv2d(out_channel, out_channel, (3, 3), stride, 1, bias=False),
                                       nn.BatchNorm2d(out_channel),
                                       SAModule())
        if dim_match:
            self.shortcut_layer = None
        else:
            self.shortcut_layer = nn.Sequential(
                nn.Conv2d(in_channel, out_channel, kernel_size=(1, 1), stride=stride, bias=False),
                nn.BatchNorm2d(out_channel)
            )

    def forward(self, x):
        shortcut = x
        res = self.res_layer(x)
        if self.shortcut_layer is not None:
            shortcut = self.shortcut_layer(x)
        return shortcut + res


class BottleNeck_IR_CBAM(nn.Module):
    '''Improved residual bottleneck with channel and spatial attention (CBAM):
    channel attention is applied first, then spatial attention.'''
    def __init__(self, in_channel, out_channel, stride, dim_match):
        super(BottleNeck_IR_CBAM, self).__init__()
        self.res_layer = nn.Sequential(nn.BatchNorm2d(in_channel),
                                       nn.Conv2d(in_channel, out_channel, (3, 3), 1, 1, bias=False),
                                       nn.BatchNorm2d(out_channel),
                                       nn.PReLU(out_channel),
                                       nn.Conv2d(out_channel, out_channel, (3, 3), stride, 1, bias=False),
                                       nn.BatchNorm2d(out_channel),
                                       CAModule(out_channel, 16),
                                       SAModule())
        if dim_match:
            self.shortcut_layer = None
        else:
            self.shortcut_layer = nn.Sequential(
                nn.Conv2d(in_channel, out_channel, kernel_size=(1, 1), stride=stride, bias=False),
                nn.BatchNorm2d(out_channel)
            )

    def forward(self, x):
        shortcut = x
        res = self.res_layer(x)
        if self.shortcut_layer is not None:
            shortcut = self.shortcut_layer(x)
        return shortcut + res


# Channel widths: stem plus four stages.
filter_list = [64, 64, 128, 256, 512]


def get_layers(num_layers):
    '''Number of residual blocks per stage for each supported depth.'''
    if num_layers == 50:
        return [3, 4, 14, 3]
    elif num_layers == 100:
        return [3, 13, 30, 3]
    elif num_layers == 152:
        return [3, 8, 36, 3]
    return None
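
# Depth sanity check (my arithmetic, not from the original file): each IR block
# has two 3x3 convs, so num_layers=50 gives 2 * (3 + 4 + 14 + 3) = 48 weight
# layers in the stages; the stem conv and the final linear bring it to 50, and
# the 100 config tallies the same way (2 * 49 + 2 = 100). The 152 config reuses
# ResNet-152's stage counts [3, 8, 36, 3], but with two-conv blocks its actual
# weight-layer count is 2 * 50 + 2 = 102.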


class CBAMResNet(nn.Module):
    def __init__(self, num_layers, feature_dim=512, drop_ratio=0.4, mode='ir', filter_list=filter_list):
        super(CBAMResNet, self).__init__()
        assert num_layers in [50, 100, 152], 'num_layers should be 50, 100 or 152'
        assert mode in ['ir', 'ir_se', 'ir_cam', 'ir_sam', 'ir_cbam'], 'mode should be ir, ir_se, ir_cam, ir_sam or ir_cbam'
        layers = get_layers(num_layers)
        if mode == 'ir':
            block = BottleNeck_IR
        elif mode == 'ir_se':
            block = BottleNeck_IR_SE
        elif mode == 'ir_cam':
            block = BottleNeck_IR_CAM
        elif mode == 'ir_sam':
            block = BottleNeck_IR_SAM
        elif mode == 'ir_cbam':
            block = BottleNeck_IR_CBAM
        self.input_layer = nn.Sequential(nn.Conv2d(3, 64, (3, 3), stride=1, padding=1, bias=False),
                                         nn.BatchNorm2d(64),
                                         nn.PReLU(64))
        # Four stages, each opening with a stride-2 block: a 112x112 input is
        # downsampled to 7x7 by the last stage.
        self.layer1 = self._make_layer(block, filter_list[0], filter_list[1], layers[0], stride=2)
        self.layer2 = self._make_layer(block, filter_list[1], filter_list[2], layers[1], stride=2)
        self.layer3 = self._make_layer(block, filter_list[2], filter_list[3], layers[2], stride=2)
        self.layer4 = self._make_layer(block, filter_list[3], filter_list[4], layers[3], stride=2)
        # The 512 * 7 * 7 flatten size assumes a 112x112 input.
        self.output_layer = nn.Sequential(nn.BatchNorm2d(512),
                                          nn.Dropout(drop_ratio),
                                          Flatten(),
                                          nn.Linear(512 * 7 * 7, feature_dim),
                                          nn.BatchNorm1d(feature_dim))
        # weight initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0.0)
            elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, in_channel, out_channel, blocks, stride):
        # The first block projects channels and downsamples; the rest keep dims.
        layers = []
        layers.append(block(in_channel, out_channel, stride, False))
        for _ in range(1, blocks):
            layers.append(block(out_channel, out_channel, 1, True))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.input_layer(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.output_layer(x)
        return x


if __name__ == '__main__':
    x = torch.randn(2, 3, 112, 112)   # random input rather than an uninitialized Tensor
    net = CBAMResNet(50, mode='ir')
    out = net(x)
    print(out.shape)                  # expected: torch.Size([2, 512])
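
    # A slightly broader smoke test (my addition, same assumptions as above):
    # every mode should map (2, 3, 112, 112) to (2, feature_dim).
    for mode in ['ir', 'ir_se', 'ir_cam', 'ir_sam', 'ir_cbam']:
        net = CBAMResNet(50, mode=mode)
        net.eval()                    # run BatchNorm in inference mode
        with torch.no_grad():
            assert net(x).shape == (2, 512), mode
    print('all modes OK')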