from typing import List

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule, Linear
from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig
from mmengine.model import BaseModule
from torch import Tensor

from mmyolo.models.layers import CSPLayerWithTwoConv
from mmyolo.registry import MODELS


@MODELS.register_module()
class MaxSigmoidAttnBlock(BaseModule):
    """Max Sigmoid attention block.

    Gates an image feature map with text guidance: per-head similarities
    between the embedded feature map and the guide vectors are reduced with a
    max over the guide dimension, passed through a sigmoid and used to
    re-weight the projected feature map.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 guide_channels: int,
                 embed_channels: int,
                 kernel_size: int = 3,
                 padding: int = 1,
                 num_heads: int = 1,
                 use_depthwise: bool = False,
                 with_scale: bool = False,
                 conv_cfg: OptConfigType = None,
                 norm_cfg: ConfigType = dict(type='BN',
                                             momentum=0.03,
                                             eps=0.001),
                 init_cfg: OptMultiConfig = None) -> None:
        super().__init__(init_cfg=init_cfg)
        conv = DepthwiseSeparableConvModule if use_depthwise else ConvModule

        assert (out_channels % num_heads == 0 and
                embed_channels % num_heads == 0), \
            'out_channels and embed_channels should be divisible by num_heads.'
        self.num_heads = num_heads
        self.head_channels = out_channels // num_heads

        # A 1x1 embedding conv is only needed when the input does not already
        # have `embed_channels` channels.
        self.embed_conv = ConvModule(
            in_channels,
            embed_channels,
            1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=None) if embed_channels != in_channels else None
        self.guide_fc = Linear(guide_channels, embed_channels)
        self.bias = nn.Parameter(torch.zeros(num_heads))
        if with_scale:
            self.scale = nn.Parameter(torch.ones(1, num_heads, 1, 1))
        else:
            self.scale = 1.0

        self.project_conv = conv(in_channels,
                                 out_channels,
                                 kernel_size,
                                 stride=1,
                                 padding=padding,
                                 conv_cfg=conv_cfg,
                                 norm_cfg=norm_cfg,
                                 act_cfg=None)

    def forward(self, x: Tensor, guide: Tensor) -> Tensor:
        """Forward process."""
        B, _, H, W = x.shape

        # Project the guide to the embedding space and split it into heads:
        # (B, num_queries, guide_channels) -> (B, num_queries, heads, c).
        guide = self.guide_fc(guide)
        guide = guide.reshape(B, -1, self.num_heads, self.head_channels)
        embed = self.embed_conv(x) if self.embed_conv is not None else x
        embed = embed.reshape(B, self.num_heads, self.head_channels, H, W)

        # Per-head similarity between every spatial location and every guide
        # vector, reduced with a max over the guide dimension.
        attn_weight = torch.einsum('bmchw,bnmc->bmhwn', embed, guide)
        attn_weight = attn_weight.max(dim=-1)[0]
        attn_weight = attn_weight / (self.head_channels**0.5)
        attn_weight = attn_weight + self.bias[None, :, None, None]
        attn_weight = attn_weight.sigmoid() * self.scale

        # Re-weight the projected feature map head by head.
        x = self.project_conv(x)
        x = x.reshape(B, self.num_heads, -1, H, W)
        x = x * attn_weight.unsqueeze(2)
        x = x.reshape(B, -1, H, W)
        return x
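
# A minimal usage sketch for MaxSigmoidAttnBlock; the concrete channel numbers
# are illustrative assumptions, not values taken from any config. Note that the
# reshapes above require num_heads * head_channels == embed_channels, i.e.
# embed_channels == out_channels:
#
#     attn = MaxSigmoidAttnBlock(in_channels=128, out_channels=256,
#                                guide_channels=512, embed_channels=256,
#                                num_heads=4)
#     x = torch.rand(2, 128, 32, 32)    # image feature (B, C, H, W)
#     guide = torch.rand(2, 80, 512)    # text guide (B, num_queries, C_text)
#     out = attn(x, guide)              # -> (2, 256, 32, 32)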


@MODELS.register_module()
class MaxSigmoidCSPLayerWithTwoConv(CSPLayerWithTwoConv):
    """Sigmoid-attention based CSP layer with two convolution layers.

    Extends ``CSPLayerWithTwoConv`` with a ``MaxSigmoidAttnBlock`` that is
    applied to the last intermediate feature and concatenated with the other
    branches before ``final_conv``.
    """

    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            guide_channels: int,
            embed_channels: int,
            num_heads: int = 1,
            expand_ratio: float = 0.5,
            num_blocks: int = 1,
            with_scale: bool = False,
            add_identity: bool = True,
            conv_cfg: OptConfigType = None,
            norm_cfg: ConfigType = dict(type='BN', momentum=0.03, eps=0.001),
            act_cfg: ConfigType = dict(type='SiLU', inplace=True),
            init_cfg: OptMultiConfig = None) -> None:
        super().__init__(in_channels=in_channels,
                         out_channels=out_channels,
                         expand_ratio=expand_ratio,
                         num_blocks=num_blocks,
                         add_identity=add_identity,
                         conv_cfg=conv_cfg,
                         norm_cfg=norm_cfg,
                         act_cfg=act_cfg,
                         init_cfg=init_cfg)

        # The fused tensor concatenates the two split branches, the
        # `num_blocks` block outputs and the attention output, hence
        # (3 + num_blocks) * mid_channels input channels.
        self.final_conv = ConvModule((3 + num_blocks) * self.mid_channels,
                                     out_channels,
                                     1,
                                     conv_cfg=conv_cfg,
                                     norm_cfg=norm_cfg,
                                     act_cfg=act_cfg)

        self.attn_block = MaxSigmoidAttnBlock(self.mid_channels,
                                              self.mid_channels,
                                              guide_channels=guide_channels,
                                              embed_channels=embed_channels,
                                              num_heads=num_heads,
                                              with_scale=with_scale,
                                              conv_cfg=conv_cfg,
                                              norm_cfg=norm_cfg)

    def forward(self, x: Tensor, guide: Tensor) -> Tensor:
        """Forward process."""
        x_main = self.main_conv(x)
        x_main = list(x_main.split((self.mid_channels, self.mid_channels), 1))
        x_main.extend(blocks(x_main[-1]) for blocks in self.blocks)
        x_main.append(self.attn_block(x_main[-1], guide))
        return self.final_conv(torch.cat(x_main, 1))
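
# A minimal usage sketch, with assumed (not config-derived) channel numbers.
# The parent class derives mid_channels from out_channels * expand_ratio, so
# with out_channels=256 and expand_ratio=0.5 the attention block runs on
# 128-channel features; embed_channels is chosen to match (see the note above
# on MaxSigmoidAttnBlock):
#
#     layer = MaxSigmoidCSPLayerWithTwoConv(in_channels=256, out_channels=256,
#                                           guide_channels=512,
#                                           embed_channels=128, num_heads=4,
#                                           num_blocks=2)
#     out = layer(torch.rand(2, 256, 32, 32), torch.rand(2, 80, 512))
#     # out.shape == (2, 256, 32, 32)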


@MODELS.register_module()
class ImagePoolingAttentionModule(nn.Module):
    """Image-pooling attention module.

    Pools multi-level image features into a small set of patch tokens and
    updates the text features with multi-head cross-attention (text features
    as queries, pooled image tokens as keys and values).
    """

    def __init__(self,
                 image_channels: List[int],
                 text_channels: int,
                 embed_channels: int,
                 with_scale: bool = False,
                 num_feats: int = 3,
                 num_heads: int = 8,
                 pool_size: int = 3) -> None:
        super().__init__()

        self.text_channels = text_channels
        self.embed_channels = embed_channels
        self.num_heads = num_heads
        self.num_feats = num_feats
        self.head_channels = embed_channels // num_heads
        self.pool_size = pool_size

        if with_scale:
            self.scale = nn.Parameter(torch.tensor([0.]), requires_grad=True)
        else:
            self.scale = 1.0
        # 1x1 projections that bring every feature level to embed_channels.
        self.projections = nn.ModuleList([
            ConvModule(in_channels, embed_channels, 1, act_cfg=None)
            for in_channels in image_channels
        ])
        self.query = nn.Sequential(nn.LayerNorm(text_channels),
                                   Linear(text_channels, embed_channels))
        self.key = nn.Sequential(nn.LayerNorm(embed_channels),
                                 Linear(embed_channels, embed_channels))
        self.value = nn.Sequential(nn.LayerNorm(embed_channels),
                                   Linear(embed_channels, embed_channels))
        self.proj = Linear(embed_channels, text_channels)

        # Each feature level is pooled to a fixed pool_size x pool_size grid.
        self.image_pools = nn.ModuleList([
            nn.AdaptiveMaxPool2d((pool_size, pool_size))
            for _ in range(num_feats)
        ])

    def forward(self, text_features: Tensor,
                image_features: List[Tensor]) -> Tensor:
        """Forward process."""
        B = image_features[0].shape[0]
        assert len(image_features) == self.num_feats
        num_patches = self.pool_size**2
        # Project and pool each level, flattening it to pool_size**2 tokens.
        mlvl_image_features = [
            pool(proj(x)).view(B, -1, num_patches)
            for x, proj, pool in zip(image_features, self.projections,
                                     self.image_pools)
        ]
        # Concatenate the levels and move tokens to dim 1: (B, tokens, C).
        mlvl_image_features = torch.cat(mlvl_image_features,
                                        dim=-1).transpose(1, 2)
        q = self.query(text_features)
        k = self.key(mlvl_image_features)
        v = self.value(mlvl_image_features)

        q = q.reshape(B, -1, self.num_heads, self.head_channels)
        k = k.reshape(B, -1, self.num_heads, self.head_channels)
        v = v.reshape(B, -1, self.num_heads, self.head_channels)

        # Scaled dot-product attention with the text tokens as queries.
        attn_weight = torch.einsum('bnmc,bkmc->bmnk', q, k)
        attn_weight = attn_weight / (self.head_channels**0.5)
        attn_weight = F.softmax(attn_weight, dim=-1)

        x = torch.einsum('bmnk,bkmc->bnmc', attn_weight, v)
        x = self.proj(x.reshape(B, -1, self.embed_channels))
        # Residual update of the text features, optionally learnable-scaled.
        return x * self.scale + text_features
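
# A minimal usage sketch for ImagePoolingAttentionModule; channel numbers and
# spatial sizes are illustrative assumptions. Each of the num_feats levels is
# pooled to pool_size**2 tokens, so the attention sees
# num_feats * pool_size**2 image tokens in total:
#
#     ipa = ImagePoolingAttentionModule(image_channels=[128, 256, 512],
#                                       text_channels=512, embed_channels=256,
#                                       num_feats=3, num_heads=8, pool_size=3)
#     texts = torch.rand(2, 80, 512)            # (B, num_texts, text_channels)
#     feats = [torch.rand(2, 128, 80, 80),
#              torch.rand(2, 256, 40, 40),
#              torch.rand(2, 512, 20, 20)]
#     out = ipa(texts, feats)                   # -> (2, 80, 512)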


@MODELS.register_module()
class VanillaSigmoidBlock(BaseModule):
    """Sigmoid attention block.

    Keeps the same constructor arguments as ``MaxSigmoidAttnBlock`` so the two
    blocks are interchangeable, but the guide-related arguments are unused:
    the forward pass is a projection gated by its own sigmoid, and the
    ``guide`` tensor is ignored.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 guide_channels: int,
                 embed_channels: int,
                 kernel_size: int = 3,
                 padding: int = 1,
                 num_heads: int = 1,
                 use_depthwise: bool = False,
                 with_scale: bool = False,
                 conv_cfg: OptConfigType = None,
                 norm_cfg: ConfigType = dict(type='BN',
                                             momentum=0.03,
                                             eps=0.001),
                 init_cfg: OptMultiConfig = None) -> None:
        super().__init__(init_cfg=init_cfg)
        conv = DepthwiseSeparableConvModule if use_depthwise else ConvModule

        assert (out_channels % num_heads == 0 and
                embed_channels % num_heads == 0), \
            'out_channels and embed_channels should be divisible by num_heads.'
        self.num_heads = num_heads
        self.head_channels = out_channels // num_heads

        self.project_conv = conv(in_channels,
                                 out_channels,
                                 kernel_size,
                                 stride=1,
                                 padding=padding,
                                 conv_cfg=conv_cfg,
                                 norm_cfg=norm_cfg,
                                 act_cfg=None)

    def forward(self, x: Tensor, guide: Tensor) -> Tensor:
        """Forward process."""
        x = self.project_conv(x)
        # Sigmoid self-gating; the guide tensor is not used.
        x = x * x.sigmoid()
        return x


@MODELS.register_module()
class EfficientCSPLayerWithTwoConv(CSPLayerWithTwoConv):
    """Sigmoid-attention based CSP layer with two convolution layers.

    Same layout as ``MaxSigmoidCSPLayerWithTwoConv``, but uses
    ``VanillaSigmoidBlock`` (which ignores the guide) as the attention block.
    """

    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            guide_channels: int,
            embed_channels: int,
            num_heads: int = 1,
            expand_ratio: float = 0.5,
            num_blocks: int = 1,
            with_scale: bool = False,
            add_identity: bool = True,
            conv_cfg: OptConfigType = None,
            norm_cfg: ConfigType = dict(type='BN', momentum=0.03, eps=0.001),
            act_cfg: ConfigType = dict(type='SiLU', inplace=True),
            init_cfg: OptMultiConfig = None) -> None:
        super().__init__(in_channels=in_channels,
                         out_channels=out_channels,
                         expand_ratio=expand_ratio,
                         num_blocks=num_blocks,
                         add_identity=add_identity,
                         conv_cfg=conv_cfg,
                         norm_cfg=norm_cfg,
                         act_cfg=act_cfg,
                         init_cfg=init_cfg)

        # Same channel accounting as MaxSigmoidCSPLayerWithTwoConv: two split
        # branches + num_blocks block outputs + one attention output.
        self.final_conv = ConvModule((3 + num_blocks) * self.mid_channels,
                                     out_channels,
                                     1,
                                     conv_cfg=conv_cfg,
                                     norm_cfg=norm_cfg,
                                     act_cfg=act_cfg)

        self.attn_block = VanillaSigmoidBlock(self.mid_channels,
                                              self.mid_channels,
                                              guide_channels=guide_channels,
                                              embed_channels=embed_channels,
                                              num_heads=num_heads,
                                              with_scale=with_scale,
                                              conv_cfg=conv_cfg,
                                              norm_cfg=norm_cfg)

    def forward(self, x: Tensor, guide: Tensor) -> Tensor:
        """Forward process."""
        x_main = self.main_conv(x)
        x_main = list(x_main.split((self.mid_channels, self.mid_channels), 1))
        x_main.extend(blocks(x_main[-1]) for blocks in self.blocks)
        x_main.append(self.attn_block(x_main[-1], guide))
        return self.final_conv(torch.cat(x_main, 1))
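
# A minimal usage sketch for EfficientCSPLayerWithTwoConv (assumed channel
# numbers). It is called exactly like MaxSigmoidCSPLayerWithTwoConv, but the
# guide tensor only passes through the interface and does not influence the
# output, since VanillaSigmoidBlock ignores it:
#
#     layer = EfficientCSPLayerWithTwoConv(in_channels=256, out_channels=256,
#                                          guide_channels=512,
#                                          embed_channels=128, num_heads=4)
#     out = layer(torch.rand(2, 256, 32, 32), torch.rand(2, 80, 512))
#     # out.shape == (2, 256, 32, 32)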