File size: 2,735 Bytes
34fb220
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
from inspect import isfunction
import math
import torch
from torch import nn, einsum
import torch.nn.functional as F

from .blocks import get_norm, zero_module


def QKV_Attention(qkv, num_heads):
    """
    Apply multi-head scaled dot-product attention to a packed QKV tensor.

    :param qkv: an [N x (3 * C) x T] tensor holding Qs, Ks and Vs concatenated
        along dim 1, in that order.
    :param num_heads: number of attention heads; C must be divisible by it.
    :return: an [N x C x T] tensor after attention (heads re-merged along dim 1).
    :raises ValueError: if dim 1 is not divisible by 3, or if C (= dim 1 // 3)
        is not divisible by a positive ``num_heads``.
    """
    B, C, HW = qkv.shape
    if C % 3 != 0:
        raise ValueError('QKV shape is wrong: {}, {}, {}'.format(B, C, HW))
    # Fail early with a clear message instead of a cryptic error from reshape
    # (or a ZeroDivisionError when num_heads == 0).
    if num_heads <= 0 or (C // 3) % num_heads != 0:
        raise ValueError('channels per Q/K/V ({}) must be divisible by a positive '
                         'num_heads ({})'.format(C // 3, num_heads))

    split_size = C // (3 * num_heads)  # channels per head
    q, k, v = qkv.chunk(3, dim=1)
    # The 1/sqrt(d) attention factor is split evenly between q and k
    # (each is multiplied by d ** -0.25) before the dot product.
    scale = 1.0 / math.sqrt(math.sqrt(split_size))
    weight = torch.einsum('bct,bcs->bts',
                          (q * scale).reshape(B * num_heads, split_size, HW),
                          (k * scale).reshape(B * num_heads, split_size, HW))

    # Softmax is computed in float32 and cast back to the input dtype.
    weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
    ret = torch.einsum('bts,bcs->bct', weight,
                       v.reshape(B * num_heads, split_size, HW))

    return ret.reshape(B, -1, HW)


class AttentionBlock(nn.Module):
    """
    Multi-head self-attention over all spatial positions of a feature map,
    wrapped in a residual connection.

    Adapted from:
        https://github.com/CompVis/latent-diffusion/blob/main/ldm/modules/diffusionmodules/openaimodel.py
        https://github.com/whai362/PVT/blob/a24ba02c249a510581a84f821c26322534b03a10/detection/pvt_v2.py#L57
    """

    def __init__(self, in_channels, num_heads, qkv_bias=False, sr_ratio=1, linear=True):
        """
        :param in_channels: number of input (and output) channels.
        :param num_heads: number of attention heads; must divide ``in_channels``.
        :param qkv_bias: currently ignored (the qkv conv is built with PyTorch's
            default ``bias=True``).
        :param sr_ratio: currently ignored.
        :param linear: currently ignored.
        :raises ValueError: if ``num_heads`` is not a positive divisor of
            ``in_channels``.
        """
        super().__init__()
        # Validate here so a bad configuration fails at construction time
        # rather than on the first forward pass.
        if num_heads <= 0 or in_channels % num_heads != 0:
            raise ValueError('in_channels ({}) must be divisible by a positive '
                             'num_heads ({})'.format(in_channels, num_heads))

        self.num_heads = num_heads
        self.norm = get_norm(in_channels, 'Group')
        # 1x1 conv producing Q, K and V stacked along the channel dimension.
        self.qkv = nn.Conv1d(in_channels=in_channels, out_channels=in_channels * 3, kernel_size=1)
        # Output projection wrapped in zero_module (presumably zero-initialized,
        # so the attention branch contributes nothing at the start of training).
        self.proj = zero_module(nn.Conv1d(in_channels=in_channels, out_channels=in_channels, kernel_size=1))

    def forward(self, x):
        """
        :param x: a [B x C x *spatial] tensor with C == in_channels.
        :return: a tensor with the same shape as ``x``.
        """
        b, c, *spatial = x.shape

        x = x.reshape(b, c, -1)                  # B x C x HW
        qkv = self.qkv(self.norm(x))             # B x 3C x HW
        h = QKV_Attention(qkv, self.num_heads)   # B x C x HW
        h = self.proj(h)

        # Residual connection: add the *un-normalized* input, so the block
        # reduces to the identity whenever the attention branch outputs zero.
        return (x + h).reshape(b, c, *spatial)



def get_model_size(model):
    """
    Report how much memory a module's parameters and buffers occupy.

    The total is printed in megabytes (1024 ** 2 bytes) and also returned.

    :param model: an ``nn.Module`` (anything exposing ``parameters()`` and
        ``buffers()`` iterables of tensors).
    :return: combined parameter + buffer size in megabytes, as a float.
    """
    def _bytes_of(tensors):
        # Storage of the given tensors: element count times bytes per element.
        return sum(t.nelement() * t.element_size() for t in tensors)

    total_bytes = _bytes_of(model.parameters()) + _bytes_of(model.buffers())
    size_all_mb = total_bytes / 1024 ** 2
    print('model size: {:.3f}MB'.format(size_all_mb))
    return size_all_mb


if __name__ == '__main__':
    # Smoke test: run one forward pass on random data, print the input and
    # output shapes, then report the block's parameter/buffer size.
    attn = AttentionBlock(in_channels=256, num_heads=8)

    feats = torch.randn(5, 256, 32, 32, dtype=torch.float32)
    out = attn(feats)
    print('{}, {}'.format(feats.shape, out.shape))

    get_model_size(attn)