import math

import torch
from torch import nn

from .blocks import get_norm, zero_module

def QKV_Attention(qkv, num_heads):
    """
    Apply QKV attention.

    :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
    :param num_heads: the number of attention heads H.
    :return: an [N x (H * C) x T] tensor after attention.
    """
    B, C, HW = qkv.shape
    if C % (3 * num_heads) != 0:
        raise ValueError('QKV shape is wrong: {}, {}, {}'.format(B, C, HW))
    split_size = C // (3 * num_heads)  # channels per head
    q, k, v = qkv.chunk(3, dim=1)
    # Scale q and k by d^{-1/4} each, so q^T k is scaled by 1/sqrt(d) overall;
    # scaling before the matmul is more stable in low precision.
    scale = 1.0 / math.sqrt(math.sqrt(split_size))
    weight = torch.einsum('bct,bcs->bts',
                          (q * scale).view(B * num_heads, split_size, HW),
                          (k * scale).view(B * num_heads, split_size, HW))
    # Softmax in float32 for numerical stability, then cast back.
    weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
    ret = torch.einsum('bts,bcs->bct', weight, v.reshape(B * num_heads, split_size, HW))
    return ret.reshape(B, -1, HW)
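
# --- Sanity-check sketch (our addition, not from the source repo) -----------
# A naive per-head reference implementation of the same scaled dot-product
# attention. `naive_qkv_attention` is a hypothetical helper name; it exists
# only to verify that the fused einsum version above matches textbook
# attention (see the check in the __main__ block below).
def naive_qkv_attention(qkv, num_heads):
    B, C, HW = qkv.shape
    ch = C // (3 * num_heads)
    q, k, v = (t.reshape(B * num_heads, ch, HW) for t in qkv.chunk(3, dim=1))
    # softmax((q^T k) / sqrt(d)) over the key axis, then weight the values
    weight = torch.softmax(q.transpose(1, 2) @ k / math.sqrt(ch), dim=-1)
    return (v @ weight.transpose(1, 2)).reshape(B, -1, HW)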

class AttentionBlock(nn.Module):
    """
    Spatial self-attention block with a residual connection.

    References:
    https://github.com/CompVis/latent-diffusion/blob/main/ldm/modules/diffusionmodules/openaimodel.py
    https://github.com/whai362/PVT/blob/a24ba02c249a510581a84f821c26322534b03a10/detection/pvt_v2.py#L57
    """
    def __init__(self, in_channels, num_heads, qkv_bias=False, sr_ratio=1, linear=True):
        # qkv_bias, sr_ratio, and linear are accepted for interface
        # compatibility with the PVT reference but are not used here.
        super().__init__()
        self.num_heads = num_heads
        self.norm = get_norm(in_channels, 'Group')
        self.qkv = nn.Conv1d(in_channels=in_channels, out_channels=in_channels * 3, kernel_size=1)
        # zero_module zero-initializes the projection, so the block starts out
        # as an identity mapping (the residual branch contributes nothing).
        self.proj = zero_module(nn.Conv1d(in_channels=in_channels, out_channels=in_channels, kernel_size=1))
    def forward(self, x):
        b, c, *spatial = x.shape
        x = x.reshape(b, c, -1)                 # B x C x HW
        h = self.norm(x)                        # pre-norm; keep x for the residual
        qkv = self.qkv(h)                       # B x C x HW -> B x 3C x HW
        h = QKV_Attention(qkv, self.num_heads)
        h = self.proj(h)
        return (x + h).reshape(b, c, *spatial)  # residual connection, as in ResNet

def get_model_size(model):
    """Print and return the model's parameter + buffer size in megabytes."""
    param_size = sum(p.nelement() * p.element_size() for p in model.parameters())
    buffer_size = sum(b.nelement() * b.element_size() for b in model.buffers())
    size_all_mb = (param_size + buffer_size) / 1024 ** 2
    print('model size: {:.3f}MB'.format(size_all_mb))
    return size_all_mb
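
# Back-of-the-envelope check for the demo block below (assuming get_norm
# returns an affine GroupNorm -- .blocks is not shown here, so that part is
# an assumption):
#   qkv  Conv1d(256 -> 768, k=1): 256*768 + 768 = 197,376 params
#   proj Conv1d(256 -> 256, k=1): 256*256 + 256 =  65,792 params
#   norm GroupNorm(256), affine:  2*256         =     512 params
#   total: 263,680 params * 4 bytes / 1024^2 ~ 1.006 MB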

if __name__ == '__main__':
    model = AttentionBlock(in_channels=256, num_heads=8)
    x = torch.randn(5, 256, 32, 32, dtype=torch.float32)
    y = model(x)
    print('{}, {}'.format(x.shape, y.shape))  # shapes match: the block preserves B x C x H x W
    get_model_size(model)
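
    # Optional equivalence check against the naive reference above; the two
    # differ only in where the 1/sqrt(d) scaling is applied, so they should
    # agree to within float32 round-off.
    qkv = torch.randn(2, 3 * 256, 64)
    assert torch.allclose(QKV_Attention(qkv, num_heads=8),
                          naive_qkv_attention(qkv, num_heads=8), atol=1e-5)
    print('fused and naive attention agree')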