TheEditor / TransUnet.py
SS3M's picture
Upload 7 files
e2cc14b verified
import torch
import torch.nn as nn
import torch.nn.functional as F
from positional_encodings.torch_encodings import PositionalEncoding2D
class LayerNorm2D(nn.Module):
    """LayerNorm over the channel axis of a channels-first 4-D tensor.

    ``nn.LayerNorm`` normalizes the trailing dimension, so the input
    ``[N, C, H, W]`` is moved to channels-last, normalized over ``C`` and
    moved back to ``[N, C, H, W]``.

    Args:
        embed_dim: Number of channels ``C`` to normalize over.
    """

    def __init__(self, embed_dim):
        super().__init__()
        self.layer_norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        channels_last = x.permute(0, 2, 3, 1)      # [N, H, W, C]
        normed = self.layer_norm(channels_last)    # normalize over C
        return normed.permute(0, 3, 1, 2)          # back to [N, C, H, W]
class Image_Adaptor(nn.Module):
    """Convolutional stem that lifts raw images to ``adp_channels`` channels.

    Three stride-1 convolutions with ``padding='same'`` keep the spatial size.
    The first two are followed by a channel LayerNorm and GELU, and they work
    at ``adp_channels // 4`` channels. Dropout is applied to the final result.

    Args:
        in_channels: Channels of the input images.
        adp_channels: Channels of the adapted output.
        dropout: Dropout probability applied to the output.
    """

    def __init__(self, in_channels, adp_channels, dropout=0.1):
        super().__init__()
        hidden = adp_channels // 4
        stages = [
            nn.Conv2d(in_channels, hidden, kernel_size=4, padding='same'),
            LayerNorm2D(hidden),
            nn.GELU(),
            nn.Conv2d(hidden, hidden, kernel_size=2, padding='same'),
            LayerNorm2D(hidden),
            nn.GELU(),
            nn.Conv2d(hidden, adp_channels, kernel_size=2, padding='same'),
        ]
        self.adaptor = nn.Sequential(*stages)
        self.dropout = nn.Dropout(dropout)

    def forward(self, images):
        """
        input: [N, in_channels, H, W]
        output: [N, adp_channels, H, W]
        """
        return self.dropout(self.adaptor(images))
class Positional_Encoding(nn.Module):
    """Channels-first adapter around ``PositionalEncoding2D``.

    The feature map is moved to channels-last before it is passed to
    ``PositionalEncoding2D``, and the result is moved back to channels-first.
    NOTE(review): the result is presumably the encoding alone (the caller adds
    it to the features), not input + encoding; confirm against the library.

    Args:
        adp_channels: Channel count of the feature maps to encode.
    """

    def __init__(self, adp_channels):
        super().__init__()
        self.pe = PositionalEncoding2D(adp_channels)

    def forward(self, adapt_imgs):
        """
        input: [N, adp_channels, H, W]
        output: [N, adp_channels, H, W]
        """
        channels_last = adapt_imgs.permute(0, 2, 3, 1)   # [N, H, W, C]
        return self.pe(channels_last).permute(0, 3, 1, 2)  # [N, C, H, W]
class GeGLU(nn.Module):
    """GELU-gated linear unit: ``gelu(x @ W0^T) * (x @ W1^T)``.

    Both projections are bias-free and map ``emb_channels -> ffn_size`` over
    the last dimension of ``x``.
    """

    def __init__(self, emb_channels, ffn_size):
        super().__init__()
        self.wi_0 = nn.Linear(emb_channels, ffn_size, bias=False)
        self.wi_1 = nn.Linear(emb_channels, ffn_size, bias=False)
        self.act = nn.GELU()

    def forward(self, x):
        # Gate branch (GELU-activated) times the plain linear branch.
        return self.act(self.wi_0(x)) * self.wi_1(x)
class Feed_Forward(nn.Module):
    """Position-wise feed-forward block built from two GeGLU projections.

    The first projection expands ``in_channels -> ffw_channels``, dropout is
    applied, and the second projection (also a GeGLU, not a plain Linear)
    maps back to ``in_channels``. It acts on the last dimension.
    """

    def __init__(self, in_channels, ffw_channels, dropout=0.1):
        super().__init__()
        self.ln1 = GeGLU(in_channels, ffw_channels)
        self.dropout = nn.Dropout(dropout)
        self.ln2 = GeGLU(ffw_channels, in_channels)

    def forward(self, x):
        '''
        input: [N, H, W, channels]
        output: [N, H, W, channels]
        '''
        expanded = self.dropout(self.ln1(x))
        return self.ln2(expanded)
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention over channels-last 2-D maps.

    The spatial grid of each input is flattened into a token sequence, so
    every query position attends to every key/value position.

    Args:
        channels: Input/output channel count ``C``; must be divisible by
            ``num_attn_heads``.
        num_attn_heads: Number of attention heads.
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, channels, num_attn_heads, dropout=0.1):
        super().__init__()
        # NOTE: despite its name, ``head_size`` stores the NUMBER of heads; the
        # attribute name is kept so existing code relying on it keeps working.
        self.head_size = num_attn_heads
        self.channels = channels
        self.attn_size = channels // num_attn_heads  # feature size per head
        self.scale = self.attn_size ** -0.5          # 1 / sqrt(per-head size)
        assert num_attn_heads * self.attn_size == channels, "Input channels of attention must divisible by number of attention head!"
        self.lq = nn.Linear(channels, self.head_size * self.attn_size, bias=False)
        self.lk = nn.Linear(channels, self.head_size * self.attn_size, bias=False)
        self.lv = nn.Linear(channels, self.head_size * self.attn_size, bias=False)
        self.lout = nn.Linear(self.head_size * self.attn_size, channels, bias=False)
        self.dropout = nn.Dropout(dropout)

    def _split_heads(self, x, bz):
        """Reshape projected tokens ``[N, L, C]`` to ``[N, heads, L, attn_size]``."""
        return x.view(bz, -1, self.head_size, self.attn_size).transpose(1, 2)

    def forward(self, q, k, v):
        '''Attend from ``q`` to ``k``/``v`` (same tensor for all three in self-attention).

        input: q: [N, H, W, channels]; k, v: [N, H', W', channels] with equal
               token counts (H' * W')
        output: [N, H, W, channels] (spatial size follows q)
        '''
        bz, H, W, C = q.shape
        # Flatten the image grid into tokens. ``reshape`` (not ``view``) so that
        # permuted / non-contiguous inputs work instead of raising RuntimeError.
        q = self._split_heads(self.lq(q.reshape(bz, -1, C)), bz)  # [N, hz, H*W, az]
        k = self._split_heads(self.lk(k.reshape(bz, -1, C)), bz)  # [N, hz, L, az]
        v = self._split_heads(self.lv(v.reshape(bz, -1, C)), bz)  # [N, hz, L, az]
        # Out-of-place scaling: do not mutate a view of the projection output.
        scores = torch.matmul(q * self.scale, k.transpose(-1, -2))  # [N, hz, H*W, L]
        attn = self.dropout(torch.softmax(scores, dim=-1))
        x = torch.matmul(attn, v)                                  # [N, hz, H*W, az]
        x = x.transpose(1, 2).reshape(bz, H, W, C)                 # merge heads
        return self.lout(x)                                        # [N, H, W, C]
class Transformer_Encoder_Layer(nn.Module):
    """Pre-norm transformer block on channels-last feature maps.

    The input goes through LayerNorm, self-attention and dropout, and the
    result is added back to the input. The sum then goes through LayerNorm,
    the feed-forward block and dropout, and is added back again.
    """

    def __init__(self, channels, num_attn_heads, ffw_channels, dropout=0.1):
        super().__init__()
        self.attn_norm = nn.LayerNorm(channels)
        self.attn_layer = MultiHeadAttention(channels, num_attn_heads, dropout)
        self.attn_dropout = nn.Dropout(dropout)
        self.ffw_norm = nn.LayerNorm(channels)
        self.ffw_layer = Feed_Forward(channels, ffw_channels, dropout)
        self.ffw_dropout = nn.Dropout(dropout)

    def forward(self, adp_pos_imgs):
        """
        input: [N, H, W, channels]
        output: [N, H, W, channels]
        """
        normed = self.attn_norm(adp_pos_imgs)
        # Residual around the (self-)attention sub-layer.
        hidden = adp_pos_imgs + self.attn_dropout(self.attn_layer(normed, normed, normed))
        # Residual around the feed-forward sub-layer.
        return hidden + self.ffw_dropout(self.ffw_layer(self.ffw_norm(hidden)))
class Transformer_Encoder(nn.Module):
    """Stack of transformer encoder layers applied to a channels-first map.

    The map is moved to channels-last and run through ``num_layers`` encoder
    layers. A Linear layer then projects ``in_channels -> out_channels``, the
    map is moved back to channels-first, and channel LayerNorm and dropout
    are applied.
    """

    def __init__(self, in_channels, out_channels, num_layers, num_attn_heads, ffw_channels, dropout=0.1):
        super().__init__()
        self.encoder_layers = nn.ModuleList(
            [Transformer_Encoder_Layer(in_channels, num_attn_heads, ffw_channels, dropout)
             for _ in range(num_layers)]
        )
        self.linear = nn.Linear(in_channels, out_channels)
        self.last_norm = LayerNorm2D(out_channels)
        self.dropout = nn.Dropout(dropout)

    def forward(self, adp_pos_imgs):
        """
        input: [N, in_channels, H, W]
        output: [N, out_channels, H, W]
        """
        tokens = adp_pos_imgs.permute(0, 2, 3, 1)  # [N, H, W, in_channels]
        for encoder_layer in self.encoder_layers:
            tokens = encoder_layer(tokens)
        projected = self.linear(tokens).permute(0, 3, 1, 2)  # [N, out_channels, H, W]
        return self.dropout(self.last_norm(projected))
class Double_Conv(nn.Module):
    """Two (Conv3x3 -> BatchNorm -> ReLU) stages, as in the U-Net encoder/decoder.

    Each convolution uses ``kernel_size=3, padding=1`` with stride 1, so the
    spatial size is preserved; only the channel count changes.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels produced by both convolutions.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, X):
        """
        input: [N, in_channels, H, W]
        output: [N, out_channels, H, W]  (spatial size unchanged; downsampling
                is done by the caller, e.g. a preceding MaxPool2d)
        """
        return self.double_conv(X)
class Down(nn.Module):
    """U-Net encoder step: 2x2 max-pool followed by a ``Double_Conv``.

    ``MaxPool2d(2)`` halves H and W (floor), and ``Double_Conv`` changes the
    channel count.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool2d(2)
        convs = Double_Conv(in_channels, out_channels)
        self.down = nn.Sequential(pool, convs)

    def forward(self, X):
        """
        input: [N, in_channels, H, W]
        output: [N, out_channels, H//2, W//2]
        """
        downsampled = self.down(X)
        return downsampled
class Up(nn.Module):
    """U-Net decoder step: upsample, align to the skip map, concat, Double_Conv.

    ``ConvTranspose2d(kernel_size=2, stride=2)`` doubles H and W and halves
    the channels of ``X1``. The result is padded, or cropped when the
    difference is negative, to the spatial size of the skip tensor ``X2``.
    It is concatenated after ``X2`` along the channel axis and passed through
    ``Double_Conv``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
        self.conv = Double_Conv(in_channels, out_channels)

    def forward(self, X1, X2):
        """
        input: X1 : [N, in_channels, H // 2, W // 2]
               X2 : [N, in_channels // 2, H, W]
        output: X : [N, out_channels, H, W]
        """
        upsampled = self.up(X1)
        rows_short = X2.shape[-2] - upsampled.shape[-2]
        cols_short = X2.shape[-1] - upsampled.shape[-1]
        # F.pad order for the last two dims: (left, right, top, bottom).
        padding = (
            cols_short // 2, cols_short - cols_short // 2,
            rows_short // 2, rows_short - rows_short // 2,
        )
        upsampled = F.pad(upsampled, padding)
        return self.conv(torch.cat((X2, upsampled), dim=-3))
class Out_Conv(nn.Module):
    """Final 1x1 convolution mapping ``adp_channels`` to ``out_channels``."""

    def __init__(self, adp_channels, out_channels):
        super().__init__()
        self.out_conv = nn.Conv2d(adp_channels, out_channels, kernel_size=1)

    def forward(self, X):
        logits = self.out_conv(X)  # per-pixel linear projection of channels
        return logits
class Trans_UNet(nn.Module):
    """U-Net with a transformer encoder at the bottleneck and a sigmoid output.

    Pipeline:
    1. The conv adaptor lifts the input to ``adp_channels`` channels, and the
       2-D positional encoding is added.
    2. Five ``Down`` stages double the channels each time (up to
       ``32 * adp_channels``). ``Transformer_Encoder`` runs at the bottleneck.
    3. Five ``Up`` stages consume the skip maps in reverse order.
    4. A 1x1 ``Out_Conv`` and a sigmoid produce the output.

    Each ``Down`` halves H and W, so inputs smaller than 32 px in either
    dimension cannot pass through all five stages.
    """

    def __init__(self,
                 in_channels,
                 adp_channels,
                 out_channels,
                 trans_num_layers=5,
                 trans_num_attn_heads=8,
                 trans_ffw_channels=1024,
                 dropout=0.1):
        super().__init__()
        c = adp_channels
        self.img_adaptor = Image_Adaptor(in_channels, c, dropout)
        self.pos_encoding = Positional_Encoding(c)
        self.down1 = Down(c, 2 * c)
        self.down2 = Down(2 * c, 4 * c)
        self.down3 = Down(4 * c, 8 * c)
        self.down4 = Down(8 * c, 16 * c)
        self.down5 = Down(16 * c, 32 * c)
        self.trans_encoder = Transformer_Encoder(32 * c, 32 * c, trans_num_layers,
                                                 trans_num_attn_heads, trans_ffw_channels, dropout)
        self.up5 = Up(32 * c, 16 * c)
        self.up4 = Up(16 * c, 8 * c)
        self.up3 = Up(8 * c, 4 * c)
        self.up2 = Up(4 * c, 2 * c)
        self.up1 = Up(2 * c, c)
        self.out_conv = Out_Conv(c, out_channels)
        self.sigmoid = nn.Sigmoid()

    def forward(self, images):
        feats = self.img_adaptor(images)
        feats = feats + self.pos_encoding(feats)

        # Encoder: keep every intermediate map as a skip connection.
        skips = [feats]
        for down_stage in (self.down1, self.down2, self.down3, self.down4, self.down5):
            skips.append(down_stage(skips[-1]))

        x = self.trans_encoder(skips.pop())  # bottleneck (deepest map)

        # Decoder: pair up5..up1 with the skips d4, d3, d2, d1, feats.
        up_stages = (self.up5, self.up4, self.up3, self.up2, self.up1)
        for up_stage, skip in zip(up_stages, reversed(skips)):
            x = up_stage(x, skip)

        return self.sigmoid(self.out_conv(x))