import torch
import torch.nn as nn
import torch.nn.functional as F

from positional_encodings.torch_encodings import PositionalEncoding2D


class LayerNorm2D(nn.Module):
    """LayerNorm over the channel dimension of an NCHW tensor."""

    def __init__(self, embed_dim):
        super().__init__()
        self.layer_norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        # [N, C, H, W] -> [N, H, W, C] so nn.LayerNorm normalizes over C
        x = x.permute(0, 2, 3, 1)
        x = self.layer_norm(x)
        # back to [N, C, H, W]
        x = x.permute(0, 3, 1, 2)
        return x


class Image_Adaptor(nn.Module):
    def __init__(self, in_channels, adp_channels, dropout=0.1):
        super().__init__()

        # Stride-1 convolutions with 'same' padding, so H and W are preserved.
        self.adaptor = nn.Sequential(
            nn.Conv2d(in_channels, adp_channels // 4, kernel_size=4, padding='same'),
            LayerNorm2D(adp_channels // 4),
            nn.GELU(),
            nn.Conv2d(adp_channels // 4, adp_channels // 4, kernel_size=2, padding='same'),
            LayerNorm2D(adp_channels // 4),
            nn.GELU(),
            nn.Conv2d(adp_channels // 4, adp_channels, kernel_size=2, padding='same')
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, images):
        """
        input:  [N, in_channels, H, W]
        output: [N, adp_channels, H, W]
        """
        adapt_imgs = self.adaptor(images)
        return self.dropout(adapt_imgs)


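# Illustrative only: a minimal shape-check sketch for Image_Adaptor. The sizes
# below (3-channel input, adp_channels=16, 64x64 images) are assumed for the
# example, not taken from the rest of this file.
def _demo_image_adaptor_shapes():
    adaptor = Image_Adaptor(in_channels=3, adp_channels=16)
    x = torch.randn(2, 3, 64, 64)
    y = adaptor(x)
    # 'same' padding keeps the spatial size; only the channel count changes
    assert y.shape == (2, 16, 64, 64)

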
class Positional_Encoding(nn.Module):
    def __init__(self, adp_channels):
        super().__init__()
        self.pe = PositionalEncoding2D(adp_channels)

    def forward(self, adapt_imgs):
        """
        input:  [N, adp_channels, H, W]
        output: [N, adp_channels, H, W]
        """
        # PositionalEncoding2D expects channels-last input: [N, H, W, C]
        x = adapt_imgs.permute(0, 2, 3, 1)
        encode = self.pe(x)
        # back to channels-first: [N, C, H, W]
        encode = encode.permute(0, 3, 1, 2)
        return encode


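# Illustrative only: Positional_Encoding returns the encoding itself rather than
# input + encoding; the caller (see Trans_UNet.forward) adds it to the feature
# map. The sizes below are assumed for the example.
def _demo_positional_encoding():
    feats = torch.randn(2, 16, 32, 32)
    pos = Positional_Encoding(adp_channels=16)(feats)
    assert pos.shape == feats.shape
    feats = feats + pos  # how it is used downstream

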
class GeGLU(nn.Module):
    """Gated GELU linear unit: GELU(x @ W0) * (x @ W1)."""

    def __init__(self, emb_channels, ffn_size):
        super().__init__()
        self.wi_0 = nn.Linear(emb_channels, ffn_size, bias=False)
        self.wi_1 = nn.Linear(emb_channels, ffn_size, bias=False)
        self.act = nn.GELU()

    def forward(self, x):
        x_gelu = self.act(self.wi_0(x))
        x_linear = self.wi_1(x)
        x = x_gelu * x_linear
        return x


class Feed_Forward(nn.Module):
    def __init__(self, in_channels, ffw_channels, dropout=0.1):
        super().__init__()

        # Two stacked GeGLU projections: in_channels -> ffw_channels -> in_channels
        self.ln1 = GeGLU(in_channels, ffw_channels)
        self.dropout = nn.Dropout(dropout)
        self.ln2 = GeGLU(ffw_channels, in_channels)

    def forward(self, x):
        """
        input:  [N, H, W, channels]
        output: [N, H, W, channels]
        """
        x = self.ln1(x)
        x = self.dropout(x)
        x = self.ln2(x)
        return x


class MultiHeadAttention(nn.Module):
    def __init__(self, channels, num_attn_heads, dropout=0.1):
        super().__init__()

        self.head_size = num_attn_heads
        self.channels = channels
        self.attn_size = channels // num_attn_heads
        self.scale = self.attn_size ** -0.5
        assert num_attn_heads * self.attn_size == channels, "Input channels of attention must be divisible by the number of attention heads!"

        self.lq = nn.Linear(channels, self.head_size * self.attn_size, bias=False)
        self.lk = nn.Linear(channels, self.head_size * self.attn_size, bias=False)
        self.lv = nn.Linear(channels, self.head_size * self.attn_size, bias=False)
        self.lout = nn.Linear(self.head_size * self.attn_size, channels, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, q, k, v):
        """
        input:  [N, H, W, channels] for each of q, k, v
        output: [N, H, W, channels]
        """
        bz, H, W, C = q.shape

        # Flatten the spatial grid into a sequence of H*W tokens
        q = q.view(bz, -1, C)
        k = k.view(bz, -1, C)
        v = v.view(bz, -1, C)

        # Project and split into heads: [N, H*W, heads, attn_size]
        q = self.lq(q).view(bz, -1, self.head_size, self.attn_size)
        k = self.lk(k).view(bz, -1, self.head_size, self.attn_size)
        v = self.lv(v).view(bz, -1, self.head_size, self.attn_size)

        # [N, heads, H*W, attn_size]; k is also transposed for the dot product
        q = q.transpose(1, 2)
        k = k.transpose(1, 2).transpose(-1, -2)
        v = v.transpose(1, 2)

        # Scaled dot-product attention
        q = q * self.scale
        x = torch.matmul(q, k)
        x = torch.softmax(x, dim=-1)
        x = self.dropout(x)
        x = x.matmul(v)

        # Merge heads and restore the spatial layout
        x = x.transpose(1, 2).contiguous()
        x = x.view(bz, H, W, C)

        x = self.lout(x)

        return x


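# Illustrative only: a minimal self-attention shape check. num_attn_heads must
# divide the channel count; the sizes below are assumed for the example.
def _demo_attention_shapes():
    attn = MultiHeadAttention(channels=32, num_attn_heads=4)
    x = torch.randn(2, 8, 8, 32)   # [N, H, W, channels]
    y = attn(x, x, x)              # self-attention: q = k = v
    assert y.shape == x.shape

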
class Transformer_Encoder_Layer(nn.Module):
    def __init__(self, channels, num_attn_heads, ffw_channels, dropout=0.1):
        super().__init__()

        self.attn_norm = nn.LayerNorm(channels)
        self.attn_layer = MultiHeadAttention(channels, num_attn_heads, dropout)
        self.attn_dropout = nn.Dropout(dropout)

        self.ffw_norm = nn.LayerNorm(channels)
        self.ffw_layer = Feed_Forward(channels, ffw_channels, dropout)
        self.ffw_dropout = nn.Dropout(dropout)

    def forward(self, adp_pos_imgs):
        """
        input:  [N, H, W, channels]
        output: [N, H, W, channels]
        """
        # Pre-norm self-attention block with a residual connection
        _x = adp_pos_imgs
        x = self.attn_norm(adp_pos_imgs)
        x = self.attn_layer(x, x, x)
        x = self.attn_dropout(x)
        x = x + _x

        # Pre-norm feed-forward block with a residual connection
        _x = x
        x = self.ffw_norm(x)
        x = self.ffw_layer(x)
        x = self.ffw_dropout(x)
        x = x + _x
        return x


class Transformer_Encoder(nn.Module):
    def __init__(self, in_channels, out_channels, num_layers, num_attn_heads, ffw_channels, dropout=0.1):
        super().__init__()

        self.encoder_layers = nn.ModuleList([
            Transformer_Encoder_Layer(in_channels, num_attn_heads, ffw_channels, dropout) for _ in range(num_layers)
        ])
        self.linear = nn.Linear(in_channels, out_channels)
        self.last_norm = LayerNorm2D(out_channels)
        self.dropout = nn.Dropout(dropout)

    def forward(self, adp_pos_imgs):
        """
        input:  [N, in_channels, H, W]
        output: [N, out_channels, H, W]
        """
        # [N, C, H, W] -> [N, H, W, C] for the attention layers
        x = adp_pos_imgs.permute(0, 2, 3, 1)

        for layer in self.encoder_layers:
            x = layer(x)

        x = self.linear(x)
        # back to [N, C, H, W]
        x = x.permute(0, 3, 1, 2)
        x = self.last_norm(x)
        out = self.dropout(x)
        return out


class Double_Conv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()

        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, X):
        """
        input:  [N, in_channels, H, W]
        output: [N, out_channels, H, W]
        """
        return self.double_conv(X)


class Down(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()

        self.down = nn.Sequential(
            nn.MaxPool2d(2),
            Double_Conv(in_channels, out_channels)
        )

    def forward(self, X):
        """
        input:  [N, in_channels, H, W]
        output: [N, out_channels, H//2, W//2]
        """
        return self.down(X)


class Up(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()

        self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
        self.conv = Double_Conv(in_channels, out_channels)

    def forward(self, X1, X2):
        """
        input:  X1 : [N, in_channels, H // 2, W // 2]
                X2 : [N, in_channels // 2, H, W]
        output: X  : [N, out_channels, H, W]
        """
        X1 = self.up(X1)

        # Pad X1 so its spatial size matches the skip connection X2
        diffY = X2.shape[-2] - X1.shape[-2]
        diffX = X2.shape[-1] - X1.shape[-1]

        pad_top = diffY // 2
        pad_bottom = diffY - pad_top
        pad_left = diffX // 2
        pad_right = diffX - pad_left

        X1 = F.pad(X1, (pad_left, pad_right, pad_top, pad_bottom))

        # Concatenate along the channel dimension
        X = torch.cat((X2, X1), dim=-3)
        return self.conv(X)


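# Illustrative only: a minimal sketch showing why Up pads the upsampled tensor.
# With an odd-sized skip connection (here 9x9, assumed for the example), the
# transposed convolution produces 8x8, and F.pad restores the missing row and
# column before concatenation.
def _demo_up_padding():
    up = Up(in_channels=8, out_channels=4)
    x1 = torch.randn(1, 8, 4, 4)   # low-resolution features
    x2 = torch.randn(1, 4, 9, 9)   # skip connection with odd spatial size
    y = up(x1, x2)
    assert y.shape == (1, 4, 9, 9)

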
class Out_Conv(nn.Module):
    def __init__(self, adp_channels, out_channels):
        super().__init__()

        self.out_conv = nn.Conv2d(adp_channels, out_channels, kernel_size=1)

    def forward(self, X):
        return self.out_conv(X)


class Trans_UNet(nn.Module):
    def __init__(self,
                 in_channels,
                 adp_channels,
                 out_channels,
                 trans_num_layers=5,
                 trans_num_attn_heads=8,
                 trans_ffw_channels=1024,
                 dropout=0.1):
        super().__init__()

        self.img_adaptor = Image_Adaptor(in_channels, adp_channels, dropout)
        self.pos_encoding = Positional_Encoding(adp_channels)

        # UNet encoder: five 2x downsampling stages (overall stride 32)
        self.down1 = Down(adp_channels * 1, adp_channels * 2)
        self.down2 = Down(adp_channels * 2, adp_channels * 4)
        self.down3 = Down(adp_channels * 4, adp_channels * 8)
        self.down4 = Down(adp_channels * 8, adp_channels * 16)
        self.down5 = Down(adp_channels * 16, adp_channels * 32)

        # Transformer encoder applied to the bottleneck feature map
        self.trans_encoder = Transformer_Encoder(adp_channels * 32, adp_channels * 32, trans_num_layers, trans_num_attn_heads, trans_ffw_channels, dropout)

        # UNet decoder with skip connections
        self.up5 = Up(adp_channels * 32, adp_channels * 16)
        self.up4 = Up(adp_channels * 16, adp_channels * 8)
        self.up3 = Up(adp_channels * 8, adp_channels * 4)
        self.up2 = Up(adp_channels * 4, adp_channels * 2)
        self.up1 = Up(adp_channels * 2, adp_channels * 1)

        self.out_conv = Out_Conv(adp_channels, out_channels)
        self.sigmoid = nn.Sigmoid()

    def forward(self, images):
        adp_imgs = self.img_adaptor(images)
        pos_enc = self.pos_encoding(adp_imgs)
        # Out-of-place addition keeps adp_imgs intact for the skip connection below
        adp_imgs = adp_imgs + pos_enc

        d1 = self.down1(adp_imgs)
        d2 = self.down2(d1)
        d3 = self.down3(d2)
        d4 = self.down4(d3)
        d5 = self.down5(d4)

        x = self.trans_encoder(d5)

        u5 = self.up5(x, d4)
        u4 = self.up4(u5, d3)
        u3 = self.up3(u4, d2)
        u2 = self.up2(u3, d1)
        u1 = self.up1(u2, adp_imgs)

        x = self.out_conv(u1)
        out = self.sigmoid(x)
        return out
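

# Illustrative only: a minimal smoke-test sketch for the full model. The
# hyperparameters and the 3-channel 64x64 input below are assumed for the
# example; H and W must be divisible by 32 because of the five downsampling
# stages.
if __name__ == "__main__":
    model = Trans_UNet(in_channels=3, adp_channels=8, out_channels=1,
                       trans_num_layers=1, trans_num_attn_heads=4,
                       trans_ffw_channels=64)
    images = torch.randn(1, 3, 64, 64)
    masks = model(images)
    print(masks.shape)  # expected: torch.Size([1, 1, 64, 64])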