# -*- coding: UTF-8 -*-
'''=================================================
@Project -> File   pram -> layers
@IDE    PyCharm
@Author fx221@cam.ac.uk
@Date   29/01/2024 14:46
=================================================='''
import torch
import torch.nn as nn
import torch.nn.functional as F
from copy import deepcopy


def MLP(channels: list, do_bn=True, ac_fn='relu', norm_fn='bn'):
    """ Multi-layer perceptron built from 1x1 Conv1d layers.

    Hidden layers are Conv1d -> norm ('bn' or 'in') -> activation ('relu',
    'gelu', or 'lrelu'); the final layer is a bare Conv1d. Note: `do_bn` is
    unused here; normalization is selected via `norm_fn`.
    """
    n = len(channels)
    layers = []
    for i in range(1, n):
        layers.append(
            nn.Conv1d(channels[i - 1], channels[i], kernel_size=1, bias=True))
        if i < (n - 1):
            if norm_fn == 'in':
                layers.append(nn.InstanceNorm1d(channels[i], eps=1e-3))
            elif norm_fn == 'bn':
                layers.append(nn.BatchNorm1d(channels[i], eps=1e-3))
            if ac_fn == 'relu':
                layers.append(nn.ReLU())
            elif ac_fn == 'gelu':
                layers.append(nn.GELU())
            elif ac_fn == 'lrelu':
                layers.append(nn.LeakyReLU(negative_slope=0.1))
    return nn.Sequential(*layers)
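
# Usage sketch (illustrative, not part of the original file): channels give
# the per-layer widths and inputs are channel-first [B, C, N], e.g.
#   mlp = MLP([3, 64, 32], ac_fn='relu', norm_fn='bn')
#   y = mlp(torch.randn(2, 3, 100))   # -> [B, 32, N] = [2, 32, 100]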


class MultiHeadedAttention(nn.Module):
    def __init__(self, num_heads: int, d_model: int):
        super().__init__()
        assert d_model % num_heads == 0
        self.dim = d_model // num_heads
        self.num_heads = num_heads
        self.merge = nn.Conv1d(d_model, d_model, kernel_size=1)
        self.proj = nn.ModuleList([deepcopy(self.merge) for _ in range(3)])

    def forward(self, query, key, value, M=None):
        '''
        :param query: [B, D, N]
        :param key: [B, D, M]
        :param value: [B, D, M]
        :param M: optional binary mask [B, N, M]; 1 = attend, 0 = ignore
        :return: [B, D, N]
        '''
        batch_dim = query.size(0)
        query, key, value = [l(x).view(batch_dim, self.dim, self.num_heads, -1)
                             for l, x in zip(self.proj, (query, key, value))]  # [B, D/H, H, N]
        dim = query.shape[1]
        scores = torch.einsum('bdhn,bdhm->bhnm', query, key) / dim ** .5  # [B, H, N, M]

        if M is not None:
            # Invalid key positions (M == 0) get the most negative finite
            # value, so they receive ~zero weight after the softmax.
            mask = (1 - M[:, None, :, :]).repeat(1, scores.shape[1], 1, 1).bool()  # [B, H, N, M]
            scores = scores.masked_fill(mask, -torch.finfo(scores.dtype).max)
        prob = F.softmax(scores, dim=-1)

        x = torch.einsum('bhnm,bdhm->bdhn', prob, value)
        self.prob = prob  # cache attention weights for later inspection

        out = self.merge(x.contiguous().view(batch_dim, self.dim * self.num_heads, -1))

        return out
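
# Shape sketch (illustrative): with num_heads=4, d_model=64,
#   attn = MultiHeadedAttention(4, 64)
#   out = attn(q, k, v, M=valid)   # q: [B, 64, N], k/v: [B, 64, S], valid: [B, N, S]
# gives out: [B, 64, N] and caches attn.prob: [B, 4, N, S].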


class AttentionalPropagation(nn.Module):
    def __init__(self, feature_dim: int, num_heads: int, ac_fn='relu', norm_fn='bn'):
        super().__init__()
        self.attn = MultiHeadedAttention(num_heads, feature_dim)
        self.mlp = MLP([feature_dim * 2, feature_dim * 2, feature_dim], ac_fn=ac_fn, norm_fn=norm_fn)
        nn.init.constant_(self.mlp[-1].bias, 0.0)

    def forward(self, x, source, M=None):
        # Cross-attend from x ([B, D, N]) to source ([B, D, S]); the returned
        # update is typically added to x residually by the caller.
        message = self.attn(x, source, source, M=M)
        self.prob = self.attn.prob

        out = self.mlp(torch.cat([x, message], dim=1))
        return out
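
# Usage sketch (illustrative): in a SuperGlue-style message-passing layer this
# block is usually applied residually by the caller, e.g.
#   prop = AttentionalPropagation(feature_dim=256, num_heads=4)
#   x = x + prop(x, source, M=valid)   # x: [B, 256, N], source: [B, 256, S]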


class KeypointEncoder(nn.Module):
    """ Joint encoding of visual appearance and location using MLPs"""

    def __init__(self, input_dim, feature_dim, layers, ac_fn='relu', norm_fn='bn'):
        super().__init__()
        self.input_dim = input_dim
        self.encoder = MLP([input_dim] + layers + [feature_dim], ac_fn=ac_fn, norm_fn=norm_fn)
        nn.init.constant_(self.encoder[-1].bias, 0.0)

    def forward(self, kpts, scores=None):
        if self.input_dim == 2:
            return self.encoder(kpts.transpose(1, 2))  # kpts: [B, N, 2] -> [B, 2, N]
        else:
            # scores is required here: concat [B, 2, N] locations with [B, 1, N] scores
            inputs = [kpts.transpose(1, 2), scores.unsqueeze(1)]
            return self.encoder(torch.cat(inputs, dim=1))
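

# Minimal smoke test (added for illustration; the dimensions below are
# arbitrary, and this block is not part of the original training code).
if __name__ == '__main__':
    B, N, S, D = 2, 5, 7, 64
    x = torch.randn(B, D, N)        # query-side features
    source = torch.randn(B, D, S)   # source-side features
    valid = torch.ones(B, N, S)     # binary mask: all key positions valid

    attn = MultiHeadedAttention(num_heads=4, d_model=D)
    assert attn(x, source, source, M=valid).shape == (B, D, N)

    prop = AttentionalPropagation(feature_dim=D, num_heads=4)
    assert prop(x, source, M=valid).shape == (B, D, N)

    enc = KeypointEncoder(input_dim=3, feature_dim=D, layers=[32, 64])
    kpts, scores = torch.randn(B, N, 2), torch.rand(B, N)
    assert enc(kpts, scores).shape == (B, D, N)
    print('all shape checks passed')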