"""Squeeze-and-excitation style gating head built from a small MLP."""

import numpy as np
import torch
from torch import nn
from torch.nn import init


class SEAttention(nn.Module):
    """Bottleneck MLP that maps a ``channel``-sized feature vector to one gate.

    The ``fc`` stack is ``Linear(channel -> channel // reduction) -> GELU ->
    Linear(-> channel) -> GELU -> Linear(channel -> 1) -> Sigmoid``; none of
    the linear layers has a bias. The output's last dimension is therefore 1
    and its values are sigmoid activations.

    Args:
        channel: Size of the feature dimension consumed by the first linear
            layer.
        reduction: Divisor applied to ``channel`` to get the bottleneck width.
    """

    def __init__(self, channel: int = 512, reduction: int = 16) -> None:
        super().__init__()
        hidden = channel // reduction
        self.fc = nn.Sequential(
            nn.Linear(channel, hidden, bias=False),
            nn.GELU(),
            nn.Linear(hidden, channel, bias=False),
            nn.GELU(),
            nn.Linear(channel, 1, bias=False),
            nn.Sigmoid(),
        )

    def init_weights(self) -> None:
        """Re-initialise every Conv2d, BatchNorm2d and Linear submodule.

        Conv2d weights get Kaiming-normal (``fan_out``) values, BatchNorm2d is
        reset to weight 1 / bias 0, and Linear weights are drawn from a normal
        distribution with ``std=0.001``. Biases, where present, are zeroed.

        This method is not invoked by ``__init__``; call it explicitly if this
        initialisation is wanted.
        """
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                init.kaiming_normal_(module.weight, mode='fan_out')
                if module.bias is not None:
                    init.constant_(module.bias, 0)
                continue
            if isinstance(module, nn.BatchNorm2d):
                init.constant_(module.weight, 1)
                init.constant_(module.bias, 0)
                continue
            if isinstance(module, nn.Linear):
                init.normal_(module.weight, std=0.001)
                if module.bias is not None:
                    init.constant_(module.bias, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the gating MLP to ``x``.

        Args:
            x: Either a tensor whose last dimension equals ``channel`` (the MLP
                is applied over that dimension, as before), or a channel-first
                tensor ``(batch, channel, *spatial)`` whose last dimension is
                not ``channel``. The latter is averaged over all trailing
                dimensions first.

        Returns:
            A tensor of sigmoid activations whose last dimension is 1:
            ``(..., 1)`` for the first input form, ``(batch, 1)`` for the
            channel-first form.
        """
        in_features = self.fc[0].in_features
        # nn.Linear contracts the LAST dimension, so a channel-first feature
        # map such as (B, C, H, W) would fail with a shape mismatch. Squeeze it
        # to (B, C) by global average pooling. The condition only matches
        # inputs that previously raised, so existing valid calls are untouched.
        channel_first = (
            x.dim() > 2
            and x.shape[-1] != in_features
            and x.shape[1] == in_features
        )
        if channel_first:
            x = x.flatten(2).mean(dim=-1)
        return self.fc(x)


if __name__ == '__main__':
    # Smoke test: a batch of 50 channel-first feature maps with 512 channels.
    feature_map = torch.randn(50, 512, 7, 7)
    se = SEAttention(channel=512, reduction=8)
    output = se(feature_map)
    print(output.shape)