Spaces:
Runtime error
Runtime error
# Modified from https://github.com/Jongchan/attention-module/blob/master/MODELS/cbam.py | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
class BasicConv(nn.Module):
    """Hold a single ``nn.Conv2d`` under the attribute ``conv``.

    Every constructor argument is forwarded unchanged to ``nn.Conv2d``.
    ``out_channels`` records ``out_planes`` so callers can read the output
    width without reaching into ``conv``.
    """

    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
        super().__init__()
        self.out_channels = out_planes
        conv_options = dict(
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )
        self.conv = nn.Conv2d(in_planes, out_planes, **conv_options)

    def forward(self, x):
        """Run the wrapped convolution on ``x`` and return its output."""
        return self.conv(x)
class Flatten(nn.Module):
    """Collapse every dimension after the first into one: ``(N, ...) -> (N, -1)``.

    Uses ``reshape`` instead of ``view`` so that non-contiguous inputs (for
    example the result of ``transpose``/``permute``) are accepted; ``reshape``
    still returns a view without copying whenever the memory layout allows it.
    """

    def forward(self, x):
        return x.reshape(x.size(0), -1)
class ChannelGate(nn.Module):
    """CBAM channel-attention gate.

    Each requested global pooling ('avg' and/or 'max' over the last two
    dimensions) yields one value per channel. Every pooled descriptor goes
    through the same bottleneck MLP, the MLP outputs are summed, and the
    ``sigmoid`` of that sum rescales ``x`` channel by channel.

    Args:
        gate_channels: Number of channels in dim 1 of the input; also the
            MLP's input and output width.
        reduction_ratio: The MLP hidden width is
            ``gate_channels // reduction_ratio``.
        pool_types: Iterable of pooling names, each ``'avg'`` or ``'max'``.
            Defaults to ``['avg', 'max']``. The entries are copied into a new
            list, so the caller's sequence is never aliased.

    Raises:
        ValueError: If ``pool_types`` is empty or names an unsupported pool.
    """

    _SUPPORTED_POOL_TYPES = ('avg', 'max')

    def __init__(self, gate_channels, reduction_ratio=16, pool_types=None):
        super(ChannelGate, self).__init__()
        # None sentinel instead of a list default: the list is stored on the
        # module, so a shared default would leak mutations across instances.
        pool_types = ['avg', 'max'] if pool_types is None else list(pool_types)
        if not pool_types:
            raise ValueError("pool_types must name at least one pooling type")
        unsupported = [p for p in pool_types if p not in self._SUPPORTED_POOL_TYPES]
        if unsupported:
            raise ValueError(
                f"unsupported pool type(s) {unsupported}; "
                f"expected entries from {self._SUPPORTED_POOL_TYPES}"
            )
        self.gate_channels = gate_channels
        # Layer order fixes the state_dict keys (Linears at mlp.1 / mlp.3);
        # keep it stable for checkpoint compatibility. nn.Flatten flattens
        # from dim 1, turning the (N, C, 1, 1) pooled tensor into (N, C).
        self.mlp = nn.Sequential(
            nn.Flatten(),
            nn.Linear(gate_channels, gate_channels // reduction_ratio),
            nn.ReLU(),
            nn.Linear(gate_channels // reduction_ratio, gate_channels),
        )
        self.pool_types = pool_types

    def forward(self, x):
        """Return ``x`` rescaled by per-channel attention weights in (0, 1).

        Assumes ``x`` is 4-D ``(N, C, H, W)`` with ``C == gate_channels``: the
        pooling window is ``(x.size(2), x.size(3))`` and the weights are
        expanded back over dims 2 and 3.
        """
        # A window covering the full spatial extent -> one value per channel.
        window = (x.size(2), x.size(3))
        channel_att_sum = None
        for pool_type in self.pool_types:
            if pool_type == 'avg':
                pooled = F.avg_pool2d(x, window, stride=window)
            elif pool_type == 'max':
                pooled = F.max_pool2d(x, window, stride=window)
            else:
                # Only reachable if self.pool_types was edited after __init__.
                raise ValueError(f"unsupported pool type {pool_type!r}")
            channel_att_raw = self.mlp(pooled)
            if channel_att_sum is None:
                channel_att_sum = channel_att_raw
            else:
                channel_att_sum = channel_att_sum + channel_att_raw
        scale = torch.sigmoid(channel_att_sum).unsqueeze(2).unsqueeze(3).expand_as(x)
        return x * scale
class ChannelPool(nn.Module):
    """Summarize the channel axis (dim 1) by its maximum and its mean.

    The two single-channel maps are concatenated along dim 1 with the max map
    first, so dim 1 shrinks to size 2 and all other dimensions are unchanged.
    """

    def forward(self, x):
        channel_max = x.max(dim=1, keepdim=True).values
        channel_mean = x.mean(dim=1, keepdim=True)
        return torch.cat([channel_max, channel_mean], dim=1)
class SpatialGate(nn.Module):
    """CBAM spatial-attention gate.

    Compresses the channel axis into a [max, mean] pair with ``ChannelPool``,
    convolves that 2-channel map down to a single channel with ``BasicConv``,
    and rescales ``x`` by the ``sigmoid`` of the result, broadcast across
    ``x``'s channels.

    Args:
        kernel_size: Side length of the square convolution kernel. Defaults
            to 7. Must be a positive odd integer so that the padding
            ``(kernel_size - 1) // 2`` keeps the spatial size unchanged,
            which ``x * scale`` relies on for broadcasting.

    Raises:
        ValueError: If ``kernel_size`` is not a positive odd integer.
    """

    def __init__(self, kernel_size=7):
        super(SpatialGate, self).__init__()
        if kernel_size < 1 or kernel_size % 2 == 0:
            raise ValueError(f"kernel_size must be a positive odd integer, got {kernel_size}")
        self.compress = ChannelPool()
        self.spatial = BasicConv(2, 1, kernel_size, stride=1, padding=(kernel_size - 1) // 2)

    def forward(self, x):
        """Return ``x`` multiplied by a single-channel attention map in (0, 1)."""
        x_compress = self.compress(x)
        x_out = self.spatial(x_compress)
        scale = torch.sigmoid(x_out)  # one channel: broadcasts over x's channels
        return x * scale
class CBAM(nn.Module):
    """Convolutional Block Attention Module: a channel gate, then an optional spatial gate.

    Args:
        gate_channels: Channel count of the input (dim 1), passed to
            ``ChannelGate``.
        reduction_ratio: Bottleneck ratio for the channel-gate MLP.
        pool_types: Pooling names for the channel gate. Defaults to
            ``['avg', 'max']``, created fresh for each instance.
        no_spatial: If True, the spatial gate is skipped and no
            ``SpatialGate`` submodule is created.
    """

    def __init__(self, gate_channels, reduction_ratio=16, pool_types=None, no_spatial=False):
        super(CBAM, self).__init__()
        # None sentinel instead of a list default, so each CBAM gets its own
        # list rather than one object shared by every call.
        if pool_types is None:
            pool_types = ['avg', 'max']
        self.ChannelGate = ChannelGate(gate_channels, reduction_ratio, pool_types)
        self.no_spatial = no_spatial
        if not no_spatial:
            self.SpatialGate = SpatialGate()

    def forward(self, x):
        """Apply channel attention, then spatial attention unless it is disabled."""
        x_out = self.ChannelGate(x)
        if not self.no_spatial:
            x_out = self.SpatialGate(x_out)
        return x_out