|
""" |
|
@Date: 2021/09/01 |
|
@description: |
|
""" |
|
import math
import warnings

import torch
import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange
|
|
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    """Fill `tensor` in-place with values from N(mean, std) truncated to [a, b],
    using inverse-transform sampling (same method as nn.init.trunc_normal_)."""

    def norm_cdf(x):
        # CDF of the standard normal distribution.
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
                      "The distribution of values may be incorrect.",
                      stacklevel=2)

    with torch.no_grad():
        # Map the truncation bounds into CDF space.
        l = norm_cdf((a - mean) / std)
        u = norm_cdf((b - mean) / std)

        # Sample uniformly in [2l - 1, 2u - 1], then invert the CDF via erfinv
        # to obtain a truncated standard normal.
        tensor.uniform_(2 * l - 1, 2 * u - 1)
        tensor.erfinv_()

        # Scale and shift to the requested std and mean.
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)

        # Clamp to guard against numerical error at the boundaries.
        tensor.clamp_(min=a, max=b)
        return tensor
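# Usage sketch (illustrative, not part of the original module): initialise a
# positional embedding the same way AbsolutePosition does below, e.g.
#   pos = torch.zeros(1, 16, 64)
#   trunc_normal_(pos, std=.02)   # samples N(0, 0.02) clipped to the default [-2, 2]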
|
|
|
|
|
class PreNorm(nn.Module):
    """Applies LayerNorm to the input before calling the wrapped module `fn`."""

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)
|
|
class GELU(nn.Module):
    def forward(self, input):
        return F.gelu(input)
|
|
class Attend(nn.Module):
    """Softmax over `dim`, keeping the input dtype."""

    def __init__(self, dim=None):
        super().__init__()
        self.dim = dim

    def forward(self, input):
        return F.softmax(input, dim=self.dim, dtype=input.dtype)
|
|
class FeedForward(nn.Module):
    """Two-layer MLP with GELU activation and dropout."""

    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)
|
|
class RelativePosition(nn.Module):
    """Adds a relative position bias of shape (1, heads, patch_num, patch_num)
    to an attention map. Supported modes (`rpe`):

    - 'lr_parameter':        learnable table over all 2 * patch_num - 1 signed distances
    - 'lr_parameter_mirror': learnable table over mirrored (absolute, circular) distances
    - 'lr_parameter_half':   learnable table over circularly wrapped distances
    - 'fix_angle':           fixed table that decays linearly with distance
    """

    def __init__(self, heads, patch_num=None, rpe=None):
        super().__init__()
        self.rpe = rpe
        self.heads = heads
        self.patch_num = patch_num

        if rpe == 'lr_parameter':
            count = patch_num * 2 - 1
            self.rpe_table = nn.Parameter(torch.Tensor(count, heads))
            nn.init.xavier_uniform_(self.rpe_table)
        elif rpe == 'lr_parameter_mirror':
            count = patch_num // 2 + 1
            self.rpe_table = nn.Parameter(torch.Tensor(count, heads))
            nn.init.xavier_uniform_(self.rpe_table)
        elif rpe == 'lr_parameter_half':
            count = patch_num
            self.rpe_table = nn.Parameter(torch.Tensor(count, heads))
            nn.init.xavier_uniform_(self.rpe_table)
        elif rpe == 'fix_angle':
            count = patch_num // 2 + 1
            # Fixed (non-learnable) weights that decay linearly with distance.
            rpe_table = (torch.arange(count, 0, -1) / count)[..., None].repeat(1, heads)
            self.register_buffer('rpe_table', rpe_table)

    def get_relative_pos_embed(self):
        range_vec = torch.arange(self.patch_num)
        distance_mat = range_vec[None, :] - range_vec[:, None]
        if self.rpe == 'lr_parameter':
            # Shift signed distances [-(patch_num - 1), patch_num - 1] into table indices.
            distance_mat += self.patch_num - 1
            return self.rpe_table[distance_mat].permute(2, 0, 1)[None]
        elif self.rpe == 'lr_parameter_mirror' or self.rpe == 'fix_angle':
            # Mirror negative distances and fold distances beyond patch_num // 2
            # back onto the circle, so only patch_num // 2 + 1 entries are needed.
            distance_mat[distance_mat < 0] = -distance_mat[distance_mat < 0]
            distance_mat[distance_mat > self.patch_num // 2] = self.patch_num - distance_mat[
                distance_mat > self.patch_num // 2]
            return self.rpe_table[distance_mat].permute(2, 0, 1)[None]
        elif self.rpe == 'lr_parameter_half':
            # Wrap signed distances onto the circle, then shift into table indices.
            distance_mat[distance_mat > self.patch_num // 2] = distance_mat[
                distance_mat > self.patch_num // 2] - self.patch_num
            distance_mat[distance_mat < -self.patch_num // 2 + 1] = distance_mat[
                distance_mat < -self.patch_num // 2 + 1] + self.patch_num
            distance_mat += self.patch_num // 2 - 1
            return self.rpe_table[distance_mat].permute(2, 0, 1)[None]

    def forward(self, attn):
        return attn + self.get_relative_pos_embed()
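# Worked example (illustrative): with patch_num = 4 and rpe = 'lr_parameter',
# the signed distance matrix range_vec[None, :] - range_vec[:, None] is
#   [[ 0,  1,  2,  3],
#    [-1,  0,  1,  2],
#    [-2, -1,  0,  1],
#    [-3, -2, -1,  0]]
# Adding patch_num - 1 = 3 shifts it into [0, 6], which indexes a (7, heads)
# table; the gathered values are permuted to a (1, heads, 4, 4) bias that is
# broadcast-added to the attention map.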
|
|
|
|
|
class Attention(nn.Module):
    def __init__(self, dim, heads=8, dim_head=64, dropout=0., patch_num=None, rpe=None, rpe_pos=1):
        """
        :param dim: input/output feature dimension
        :param heads: number of attention heads
        :param dim_head: dimension of each head
        :param dropout: dropout applied to the output projection
        :param patch_num: number of patches (sequence length), required for relative position
        :param rpe: relative position embedding mode (see RelativePosition)
        :param rpe_pos: 0 adds the bias to the pre-softmax logits, 1 adds it to the softmaxed weights
        """
        super().__init__()

        self.relative_pos_embed = None if patch_num is None or rpe is None else RelativePosition(heads, patch_num, rpe)
        inner_dim = dim_head * heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5
        self.rpe_pos = rpe_pos

        self.attend = Attend(dim=-1)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        b, n, _, h = *x.shape, self.heads
        # Project to queries, keys and values and split the heads.
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), qkv)

        # Scaled dot-product attention logits: (b, h, n, n).
        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        if self.rpe_pos == 0:
            if self.relative_pos_embed is not None:
                dots = self.relative_pos_embed(dots)

        attn = self.attend(dots)

        if self.rpe_pos == 1:
            if self.relative_pos_embed is not None:
                attn = self.relative_pos_embed(attn)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
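# Usage sketch (illustrative): Attention(dim=128, heads=8, dim_head=64,
# patch_num=256, rpe='lr_parameter') maps a (batch, 256, 128) tensor to the
# same shape; the relative-position bias is added either before the softmax
# (rpe_pos=0) or after it (rpe_pos=1, the default).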
|
|
|
|
|
class AbsolutePosition(nn.Module):
    """Adds an absolute position embedding of shape (1, patch_num, dim) to the input.

    `ape` = 'lr_parameter' uses a learnable, truncated-normal-initialised table;
    `ape` = 'fix_angle' uses a fixed sinusoidal table over one period.
    """

    def __init__(self, dim, dropout=0., patch_num=None, ape=None):
        super().__init__()
        self.ape = ape

        if ape == 'lr_parameter':
            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, patch_num, dim))
            trunc_normal_(self.absolute_pos_embed, std=.02)
        elif ape == 'fix_angle':
            angle = torch.arange(0, patch_num, dtype=torch.float) / patch_num * (math.pi * 2)
            # Registered as a buffer so it follows the module across devices.
            self.register_buffer('absolute_pos_embed', torch.sin(angle)[..., None].repeat(1, dim)[None])

    def forward(self, x):
        return x + self.absolute_pos_embed
|
|
class WinAttention(nn.Module):
    """Window attention over a 1-D sequence: the sequence is split into windows
    of `win_size` patches and attention is computed within each window. A
    non-zero `shift` rolls the sequence before windowing and rolls it back
    afterwards (shifted-window attention). The sequence length must be a
    multiple of `win_size`."""

    def __init__(self, dim, win_size=8, shift=0, heads=8, dim_head=64, dropout=0., rpe=None, rpe_pos=1):
        super().__init__()

        self.win_size = win_size
        self.shift = shift
        self.attend = Attention(dim, heads=heads, dim_head=dim_head,
                                dropout=dropout, patch_num=win_size, rpe=None if rpe is None else 'lr_parameter',
                                rpe_pos=rpe_pos)

    def forward(self, x):
        b = x.shape[0]
        if self.shift != 0:
            x = torch.roll(x, shifts=self.shift, dims=-2)
        # Partition the sequence into windows: (b, m * w, d) -> (b * m, w, d).
        x = rearrange(x, 'b (m w) d -> (b m) w d', w=self.win_size)

        out = self.attend(x)

        # Merge the windows back: (b * m, w, d) -> (b, m * w, d).
        out = rearrange(out, '(b m) w d -> b (m w) d', b=b)
        if self.shift != 0:
            out = torch.roll(out, shifts=-self.shift, dims=-2)

        return out
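# Usage sketch (illustrative): with win_size=8 and a 256-patch sequence, the
# input is split into 256 / 8 = 32 windows of 8 patches each; relative
# position embeddings, if enabled, are defined per window (patch_num=win_size).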
|
|
|
|
|
class Conv(nn.Module):
    """1-D convolution (dim -> dim, kernel size 3) over the patch dimension with
    circular padding, so the first and last patches are treated as neighbours."""

    def __init__(self, dim, dropout=0.):
        super().__init__()
        self.dim = dim
        self.net = nn.Sequential(
            nn.Conv1d(dim, dim, kernel_size=3, stride=1, padding=0),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        # (b, n, d) -> (b, d, n) for Conv1d.
        x = x.transpose(1, 2)
        # Circular padding: prepend the last patch and append the first one.
        x = torch.cat([x[..., -1:], x, x[..., :1]], dim=-1)
        x = self.net(x)
        return x.transpose(1, 2)
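# Minimal shape check (a sketch, not part of the original module): each block
# below should map a (batch, patch_num, dim) tensor to a tensor of the same shape.
if __name__ == '__main__':
    x = torch.randn(2, 16, 32)  # (batch, patch_num, dim)
    attn = Attention(dim=32, heads=4, dim_head=8, patch_num=16, rpe='lr_parameter')
    win = WinAttention(dim=32, win_size=8, shift=4, heads=4, dim_head=8)
    conv = Conv(dim=32)
    for layer in (attn, win, conv):
        assert layer(x).shape == x.shape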
|
|