In [30]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial
import math
import torch as th

In [31]:
class GroupNorm32(nn.GroupNorm):
 def forward(self, x):
 return super().forward(x.float()).type(x.dtype)

def normalization(channels):
 """
 Make a standard normalization layer.
 :param channels: number of input channels.
 :return: an nn.Module for normalization.
 """
 return GroupNorm32(32, channels)


def conv_nd(dims, *args, **kwargs):
 """
 Create a 1D, 2D, or 3D convolution module.
 """
 if dims == 1:
 return nn.Conv1d(*args, **kwargs)
 elif dims == 2:
 return nn.Conv2d(*args, **kwargs)
 elif dims == 3:
 return nn.Conv3d(*args, **kwargs)
 raise ValueError(f"unsupported dimensions: {dims}")


In [32]:
class QKVAttentionLegacy(nn.Module):
 """
 A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping
 """

 def __init__(self, n_heads):
 super().__init__()
 self.n_heads = n_heads

 def forward(self, qkv):
 """
 Apply QKV attention.
 :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
 :return: an [N x (H * C) x T] tensor after attention.
 """
 bs, width, length = qkv.shape
 assert width % (3 * self.n_heads) == 0
 ch = width // (3 * self.n_heads)
 q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
 scale = 1 / math.sqrt(math.sqrt(ch))
 weight = th.einsum(
 "bct,bcs->bts", q * scale, k * scale
 ) # More stable with f16 than dividing afterwards
 weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
 a = th.einsum("bts,bcs->bct", weight, v)
 return a.reshape(bs, -1, length)

 @staticmethod
 def count_flops(model, _x, y):
 return count_flops_attn(model, _x, y)

In [33]:
def zero_module(module):
 """
 Zero out the parameters of a module and return it.
 """
 for p in module.parameters():
 p.detach().zero_()
 return module


In [37]:
class AttentionBlock(nn.Module):
 """
 An attention block that allows spatial positions to attend to each other.
 Originally ported from here, but adapted to the N-d case.
 https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
 """

 def __init__(
 self,
 channels,
 num_heads=1,
 num_head_channels=-1,
 use_new_attention_order=False,
 ):
 super().__init__()
 self.channels = channels
 if num_head_channels == -1:
 self.num_heads = num_heads
 else:
 assert (
 channels % num_head_channels == 0
 ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
 self.num_heads = channels // num_head_channels
 self.norm = normalization(channels)
 self.qkv = conv_nd(1, channels, channels * 3, 1)
 if use_new_attention_order:
 # split qkv before split heads
 self.attention = QKVAttention(self.num_heads)
 else:
 # split heads before split qkv
 self.attention = QKVAttentionLegacy(self.num_heads)

 self.proj_out = zero_module(conv_nd(1, channels, channels, 1))

 def forward(self, x):
 
 import pdb; pdb.set_trace()
 
 b, c, *spatial = x.shape
 x = x.reshape(b, c, -1)
 qkv = self.qkv(self.norm(x))
 h = self.attention(qkv)
 h = self.proj_out(h)
 return (x + h).reshape(b, c, *spatial)

In [38]:
test_input = torch.randn(5, 32, 128, 128)

model = AttentionBlock(32, 1)

y = model(test_input)

> [0;32m/tmp/ipykernel_456404/3277534714.py[0m(39)[0;36mforward[0;34m()[0m
[0;32m 37 [0;31m [0;32mimport[0m [0mpdb[0m[0;34m;[0m [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m 38 [0;31m[0;34m[0m[0m
[0m[0;32m---> 39 [0;31m [0mb[0m[0;34m,[0m [0mc[0m[0;34m,[0m [0;34m*[0m[0mspatial[0m [0;34m=[0m [0mx[0m[0;34m.[0m[0mshape[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m 40 [0;31m [0mx[0m [0;34m=[0m [0mx[0m[0;34m.[0m[0mreshape[0m[0;34m([0m[0mb[0m[0;34m,[0m [0mc[0m[0;34m,[0m [0;34m-[0m[0;36m1[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m 41 [0;31m [0mqkv[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mqkv[0m[0;34m([0m[0mself[0m[0;34m.[0m[0mnorm[0m[0;34m([0m[0mx[0m[0;34m)[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb> n


> [0;32m/tmp/ipykernel_456404/3277534714.py[0m(40)[0;36mforward[0;34m()[0m
[0;32m 38 [0;31m[0;34m[0m[0m
[0m[0;32m 39 [0;31m [0mb[0m[0;34m,[0m [0mc[0m[0;34m,[0m [0;34m*[0m[0mspatial[0m [0;34m=[0m [0mx[0m[0;34m.[0m[0mshape[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 40 [0;31m [0mx[0m [0;34m=[0m [0mx[0m[0;34m.[0m[0mreshape[0m[0;34m([0m[0mb[0m[0;34m,[0m [0mc[0m[0;34m,[0m [0;34m-[0m[0;36m1[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m 41 [0;31m [0mqkv[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mqkv[0m[0;34m([0m[0mself[0m[0;34m.[0m[0mnorm[0m[0;34m([0m[0mx[0m[0;34m)[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m 42 [0;31m [0mh[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mattention[0m[0;34m([0m[0mqkv[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb> n


> [0;32m/tmp/ipykernel_456404/3277534714.py[0m(41)[0;36mforward[0;34m()[0m
[0;32m 39 [0;31m [0mb[0m[0;34m,[0m [0mc[0m[0;34m,[0m [0;34m*[0m[0mspatial[0m [0;34m=[0m [0mx[0m[0;34m.[0m[0mshape[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m 40 [0;31m [0mx[0m [0;34m=[0m [0mx[0m[0;34m.[0m[0mreshape[0m[0;34m([0m[0mb[0m[0;34m,[0m [0mc[0m[0;34m,[0m [0;34m-[0m[0;36m1[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 41 [0;31m [0mqkv[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mqkv[0m[0;34m([0m[0mself[0m[0;34m.[0m[0mnorm[0m[0;34m([0m[0mx[0m[0;34m)[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m 42 [0;31m [0mh[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mattention[0m[0;34m([0m[0mqkv[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m 43 [0;31m [0mh[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mproj_out[0m[0;34m([0m[0mh[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb> x.shape


torch.Size([5, 32, 16384])


ipdb> t = self.norm(x)
ipdb> t.shape


torch.Size([5, 32, 16384])


ipdb> self.qkv


Conv1d(32, 96, kernel_size=(1,), stride=(1,))


ipdb> n


> [0;32m/tmp/ipykernel_456404/3277534714.py[0m(42)[0;36mforward[0;34m()[0m
[0;32m 40 [0;31m [0mx[0m [0;34m=[0m [0mx[0m[0;34m.[0m[0mreshape[0m[0;34m([0m[0mb[0m[0;34m,[0m [0mc[0m[0;34m,[0m [0;34m-[0m[0;36m1[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m 41 [0;31m [0mqkv[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mqkv[0m[0;34m([0m[0mself[0m[0;34m.[0m[0mnorm[0m[0;34m([0m[0mx[0m[0;34m)[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 42 [0;31m [0mh[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mattention[0m[0;34m([0m[0mqkv[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m 43 [0;31m [0mh[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mproj_out[0m[0;34m([0m[0mh[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m 44 [0;31m [0;32mreturn[0m [0;34m([0m[0mx[0m [0;34m+[0m [0mh[0m[0;34m)[0m[0;34m.[0m[0mreshape[0m[0;34m([0m[0mb[0m[0;34m,[0m [0mc[0m[0;34m,[0m [0;34m*[0m[0mspatial[0m[0;34m)[0m[

ipdb> qkv.shape


torch.Size([5, 96, 16384])


ipdb> t.shape


torch.Size([5, 32, 16384])


ipdb> n


> [0;32m/tmp/ipykernel_456404/3277534714.py[0m(43)[0;36mforward[0;34m()[0m
[0;32m 40 [0;31m [0mx[0m [0;34m=[0m [0mx[0m[0;34m.[0m[0mreshape[0m[0;34m([0m[0mb[0m[0;34m,[0m [0mc[0m[0;34m,[0m [0;34m-[0m[0;36m1[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m 41 [0;31m [0mqkv[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mqkv[0m[0;34m([0m[0mself[0m[0;34m.[0m[0mnorm[0m[0;34m([0m[0mx[0m[0;34m)[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m 42 [0;31m [0mh[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mattention[0m[0;34m([0m[0mqkv[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 43 [0;31m [0mh[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mproj_out[0m[0;34m([0m[0mh[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m 44 [0;31m [0;32mreturn[0m [0;34m([0m[0mx[0m [0;34m+[0m [0mh[0m[0;34m)[0m[0;34m.[0m[0mreshape[0m[0;34m([0m[0mb[0m[0;34m,[0m [0mc[0m[0;34m,[0m [0;34m*[0m[0mspatial[0m[0;34m)[0m[

ipdb> h.shape


*** No help for '.shape'


ipdb> h.shape


*** No help for '.shape'


ipdb> print(h.shape)


torch.Size([5, 32, 16384])


ipdb> self.proj_out


Conv1d(32, 32, kernel_size=(1,), stride=(1,))


ipdb> n


> [0;32m/tmp/ipykernel_456404/3277534714.py[0m(44)[0;36mforward[0;34m()[0m
[0;32m 40 [0;31m [0mx[0m [0;34m=[0m [0mx[0m[0;34m.[0m[0mreshape[0m[0;34m([0m[0mb[0m[0;34m,[0m [0mc[0m[0;34m,[0m [0;34m-[0m[0;36m1[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m 41 [0;31m [0mqkv[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mqkv[0m[0;34m([0m[0mself[0m[0;34m.[0m[0mnorm[0m[0;34m([0m[0mx[0m[0;34m)[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m 42 [0;31m [0mh[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mattention[0m[0;34m([0m[0mqkv[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m 43 [0;31m [0mh[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mproj_out[0m[0;34m([0m[0mh[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 44 [0;31m [0;32mreturn[0m [0;34m([0m[0mx[0m [0;34m+[0m [0mh[0m[0;34m)[0m[0;34m.[0m[0mreshape[0m[0;34m([0m[0mb[0m[0;34m,[0m [0mc[0m[0;34m,[0m [0;34m*[0m[0mspatial[0m[0;34m)[0m[

ipdb> 


--Return--
tensor([[[[ 1...iasBackward0>)
> [0;32m/tmp/ipykernel_456404/3277534714.py[0m(44)[0;36mforward[0;34m()[0m
[0;32m 40 [0;31m [0mx[0m [0;34m=[0m [0mx[0m[0;34m.[0m[0mreshape[0m[0;34m([0m[0mb[0m[0;34m,[0m [0mc[0m[0;34m,[0m [0;34m-[0m[0;36m1[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m 41 [0;31m [0mqkv[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mqkv[0m[0;34m([0m[0mself[0m[0;34m.[0m[0mnorm[0m[0;34m([0m[0mx[0m[0;34m)[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m 42 [0;31m [0mh[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mattention[0m[0;34m([0m[0mqkv[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m 43 [0;31m [0mh[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mproj_out[0m[0;34m([0m[0mh[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 44 [0;31m [0;32mreturn[0m [0;34m([0m[0mx[0m [0;34m+[0m [0mh[0m[0;34m)[0m[0;34m.[0m[0mreshape[0m[0;34m([0m[0mb[0m[0;34m,[0m [0mc[0m[0;34m,[0m

ipdb> q


BdbQuit: 

In [36]:
y.shape

torch.Size([5, 32, 128, 128])