# -*- coding: UTF-8 -*-
'''=================================================
@Project -> File   pram -> layers
@IDE               PyCharm
@Author            fx221@cam.ac.uk
@Date              29/01/2024 14:46
=================================================='''
import torch
import torch.nn as nn
import torch.nn.functional as F
from copy import deepcopy


def MLP(channels: list, do_bn=True, ac_fn='relu', norm_fn='bn'):
    """Multi-layer perceptron built from 1x1 convolutions (per-point linear layers).

    `do_bn` is kept for backward compatibility; the normalization layer is
    selected through `norm_fn` ('bn' or 'in').
    """
    n = len(channels)
    layers = []
    for i in range(1, n):
        layers.append(nn.Conv1d(channels[i - 1], channels[i], kernel_size=1, bias=True))
        if i < (n - 1):
            # Normalization and activation on all but the last layer.
            if norm_fn == 'in':
                layers.append(nn.InstanceNorm1d(channels[i], eps=1e-3))
            elif norm_fn == 'bn':
                layers.append(nn.BatchNorm1d(channels[i], eps=1e-3))
            if ac_fn == 'relu':
                layers.append(nn.ReLU())
            elif ac_fn == 'gelu':
                layers.append(nn.GELU())
            elif ac_fn == 'lrelu':
                layers.append(nn.LeakyReLU(negative_slope=0.1))
    return nn.Sequential(*layers)


class MultiHeadedAttention(nn.Module):
    """Multi-head scaled dot-product attention with an optional binary mask."""

    def __init__(self, num_heads: int, d_model: int):
        super().__init__()
        assert d_model % num_heads == 0
        self.dim = d_model // num_heads
        self.num_heads = num_heads
        self.merge = nn.Conv1d(d_model, d_model, kernel_size=1)
        self.proj = nn.ModuleList([deepcopy(self.merge) for _ in range(3)])

    def forward(self, query, key, value, M=None):
        '''
        :param query: [B, D, N]
        :param key: [B, D, M]
        :param value: [B, D, M]
        :param M: [B, N, M] binary mask, 1 = attend, 0 = suppress
        :return: [B, D, N]
        '''
        batch_dim = query.size(0)
        # Project and split into heads: [B, D, X] -> [B, D/H, H, X].
        query, key, value = [proj(x).view(batch_dim, self.dim, self.num_heads, -1)
                             for proj, x in zip(self.proj, (query, key, value))]
        dim = query.shape[1]
        scores = torch.einsum('bdhn,bdhm->bhnm', query, key) / dim ** .5
        if M is not None:
            # Suppress masked-out entries (M == 0) before the softmax.
            mask = (~M[:, None, :, :].bool()).expand_as(scores)  # [B, H, N, M]
            scores = scores.masked_fill(mask, -torch.finfo(scores.dtype).max)
        prob = F.softmax(scores, dim=-1)
        x = torch.einsum('bhnm,bdhm->bdhn', prob, value)
        self.prob = prob  # keep the attention weights for later inspection
        out = self.merge(x.contiguous().view(batch_dim, self.dim * self.num_heads, -1))
        return out


class AttentionalPropagation(nn.Module):
    """One message-passing step: attend to `source` and fuse the message into `x`."""

    def __init__(self, feature_dim: int, num_heads: int, ac_fn='relu', norm_fn='bn'):
        super().__init__()
        self.attn = MultiHeadedAttention(num_heads, feature_dim)
        self.mlp = MLP([feature_dim * 2, feature_dim * 2, feature_dim], ac_fn=ac_fn, norm_fn=norm_fn)
        nn.init.constant_(self.mlp[-1].bias, 0.0)

    def forward(self, x, source, M=None):
        message = self.attn(x, source, source, M=M)
        self.prob = self.attn.prob
        out = self.mlp(torch.cat([x, message], dim=1))
        return out


class KeypointEncoder(nn.Module):
    """Joint encoding of visual appearance and location using MLPs."""

    def __init__(self, input_dim, feature_dim, layers, ac_fn='relu', norm_fn='bn'):
        super().__init__()
        self.input_dim = input_dim
        self.encoder = MLP([input_dim] + layers + [feature_dim], ac_fn=ac_fn, norm_fn=norm_fn)
        nn.init.constant_(self.encoder[-1].bias, 0.0)

    def forward(self, kpts, scores=None):
        if self.input_dim == 2:
            # Keypoint coordinates only: [B, N, 2] -> [B, 2, N].
            return self.encoder(kpts.transpose(1, 2))
        else:
            # Coordinates plus detection scores: [B, 2, N] + [B, 1, N] -> [B, 3, N].
            inputs = [kpts.transpose(1, 2), scores.unsqueeze(1)]
            return self.encoder(torch.cat(inputs, dim=1))
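

if __name__ == '__main__':
    # Minimal smoke test (a sketch added for illustration, not part of the original
    # module): the batch size, keypoint counts, feature_dim=256, num_heads=4 and the
    # encoder layer widths below are assumed values, not ones prescribed by this repo.
    B, N, M_pts, feat_dim = 2, 100, 120, 256

    kpts = torch.rand(B, N, 2)                  # normalized keypoint coordinates
    scores = torch.rand(B, N)                   # detection scores
    desc = torch.rand(B, feat_dim, N)           # descriptors of the query points
    desc_src = torch.rand(B, feat_dim, M_pts)   # descriptors of the source points

    kenc = KeypointEncoder(input_dim=3, feature_dim=feat_dim, layers=[32, 64, 128, 256])
    prop = AttentionalPropagation(feature_dim=feat_dim, num_heads=4)

    x = desc + kenc(kpts, scores)      # position-aware descriptors, [B, 256, N]
    mask = torch.ones(B, N, M_pts)     # 1 = attend, 0 = suppress
    delta = prop(x, desc_src, M=mask)  # cross-attention message, [B, 256, N]
    print(delta.shape, prop.prob.shape)
    # Expected: torch.Size([2, 256, 100]) torch.Size([2, 4, 100, 120])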