import torch
import torch.nn as nn
import torch.nn.functional as F
from positional_encodings.torch_encodings import PositionalEncoding2D


class LayerNorm2D(nn.Module):
    """Apply ``nn.LayerNorm`` over the channel axis of a channels-first 4-D tensor."""

    def __init__(self, embed_dim):
        super().__init__()
        self.layer_norm = nn.LayerNorm(embed_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        input: [N, embed_dim, H, W]
        output: [N, embed_dim, H, W]
        """
        # nn.LayerNorm normalizes the last axis, so move channels last and back.
        x = x.permute(0, 2, 3, 1)
        x = self.layer_norm(x)
        x = x.permute(0, 3, 1, 2)
        return x


class Image_Adaptor(nn.Module):
    """Three-conv stem that maps input images to ``adp_channels`` feature maps.

    All convolutions use ``padding='same'``, so the spatial size is preserved.
    The two hidden stages use ``adp_channels // 4`` channels.
    """

    def __init__(self, in_channels, adp_channels, dropout=0.1):
        super().__init__()
        self.adaptor = nn.Sequential(
            nn.Conv2d(in_channels, adp_channels // 4, kernel_size=4, padding='same'),
            LayerNorm2D(adp_channels // 4),
            nn.GELU(),
            nn.Conv2d(adp_channels // 4, adp_channels // 4, kernel_size=2, padding='same'),
            LayerNorm2D(adp_channels // 4),
            nn.GELU(),
            nn.Conv2d(adp_channels // 4, adp_channels, kernel_size=2, padding='same'),
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        """
        input: [N, in_channels, H, W]
        output: [N, adp_channels, H, W]
        """
        adapt_imgs = self.adaptor(images)
        return self.dropout(adapt_imgs)


class Positional_Encoding(nn.Module):
    """Wrap ``PositionalEncoding2D`` for channels-first feature maps.

    ``forward`` returns what ``PositionalEncoding2D`` produces (rearranged to
    channels-first); it does not add it to the input. ``Trans_UNet.forward``
    performs the addition.
    """

    def __init__(self, adp_channels):
        super().__init__()
        self.pe = PositionalEncoding2D(adp_channels)

    def forward(self, adapt_imgs: torch.Tensor) -> torch.Tensor:
        """
        input: [N, adp_channels, H, W]
        output: [N, adp_channels, H, W]
        """
        # The encoder is fed a channels-last tensor: [N, H, W, C].
        x = adapt_imgs.permute(0, -2, -1, -3)
        encode = self.pe(x)
        # Back to channels-first: [N, C, H, W].
        encode = encode.permute(0, -1, -3, -2)
        return encode


class GeGLU(nn.Module):
    """GELU-gated linear unit: ``GELU(wi_0(x)) * wi_1(x)`` (both projections bias-free)."""

    def __init__(self, emb_channels, ffn_size):
        super().__init__()
        self.wi_0 = nn.Linear(emb_channels, ffn_size, bias=False)
        self.wi_1 = nn.Linear(emb_channels, ffn_size, bias=False)
        self.act = nn.GELU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        input: [..., emb_channels]
        output: [..., ffn_size]
        """
        x_gelu = self.act(self.wi_0(x))
        x_linear = self.wi_1(x)
        x = x_gelu * x_linear
        return x


class Feed_Forward(nn.Module):
    """Position-wise feed-forward block: GeGLU expand -> dropout -> GeGLU project back.

    Note that both stages are ``GeGLU`` modules (the attribute names ``ln1`` /
    ``ln2`` are kept as-is so existing checkpoints still load).
    """

    def __init__(self, in_channels, ffw_channels, dropout=0.1):
        super().__init__()
        self.ln1 = GeGLU(in_channels, ffw_channels)
        self.dropout = nn.Dropout(dropout)
        self.ln2 = GeGLU(ffw_channels, in_channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        input: [N, H, W, channels]
        output: [N, H, W, channels]
        """
        x = self.ln1(x)
        x = self.dropout(x)
        x = self.ln2(x)
        return x


class MultiHeadAttention(nn.Module):
    """Scaled dot-product multi-head attention over the flattened H*W grid.

    ``head_size`` stores the number of heads and ``attn_size`` the per-head
    width (``channels // num_attn_heads``). All projections are bias-free.
    """

    def __init__(self, channels, num_attn_heads, dropout=0.1):
        super().__init__()
        self.head_size = num_attn_heads
        self.channels = channels
        self.attn_size = channels // num_attn_heads
        self.scale = self.attn_size ** -0.5
        assert num_attn_heads * self.attn_size == channels, "Input channels of attention must divisible by number of attention head!"

        self.lq = nn.Linear(channels, self.head_size * self.attn_size, bias=False)
        self.lk = nn.Linear(channels, self.head_size * self.attn_size, bias=False)
        self.lv = nn.Linear(channels, self.head_size * self.attn_size, bias=False)
        self.lout = nn.Linear(self.head_size * self.attn_size, channels, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        """
        input: [N, H, W, channels] for each of q, k, v
        output: [N, H, W, channels]
        """
        bz, H, W, C = q.shape

        # Flatten the spatial grid into a token axis first. reshape (rather than
        # view) so that non-contiguous inputs, e.g. permuted tensors, also work.
        q = q.reshape(bz, -1, C)  # [N, H*W, C]
        k = k.reshape(bz, -1, C)  # [N, H*W, C]
        v = v.reshape(bz, -1, C)  # [N, H*W, C]

        q = self.lq(q).view(bz, -1, self.head_size, self.attn_size)  # [N, H*W, hz, az]
        k = self.lk(k).view(bz, -1, self.head_size, self.attn_size)  # [N, H*W, hz, az]
        v = self.lv(v).view(bz, -1, self.head_size, self.attn_size)  # [N, H*W, hz, az]

        q = q.transpose(1, 2)                    # [N, hz, H*W, az]
        k = k.transpose(1, 2).transpose(-1, -2)  # [N, hz, az, H*W]
        v = v.transpose(1, 2)                    # [N, hz, H*W, az]

        q *= self.scale
        x = torch.matmul(q, k)  # [N, hz, H*W, H*W]
        x = torch.softmax(x, dim=-1)
        x = self.dropout(x)
        x = x.matmul(v)  # [N, hz, H*W, az]

        x = x.transpose(1, 2).contiguous()  # [N, H*W, hz, az]
        # Merge the heads and restore the query's spatial grid in one step.
        x = x.view(bz, H, W, C)  # [N, H, W, C]

        x = self.lout(x)  # [N, H, W, C]
        return x


class Transformer_Encoder_Layer(nn.Module):
    """Pre-norm encoder layer: (LayerNorm -> self-attention) and (LayerNorm -> FFN),
    each followed by dropout and a residual connection."""

    def __init__(self, channels, num_attn_heads, ffw_channels, dropout=0.1):
        super().__init__()
        self.attn_norm = nn.LayerNorm(channels)
        self.attn_layer = MultiHeadAttention(channels, num_attn_heads, dropout)
        self.attn_dropout = nn.Dropout(dropout)

        self.ffw_norm = nn.LayerNorm(channels)
        self.ffw_layer = Feed_Forward(channels, ffw_channels, dropout)
        self.ffw_dropout = nn.Dropout(dropout)

    def forward(self, adp_pos_imgs: torch.Tensor) -> torch.Tensor:
        """
        input: [N, H, W, channels]
        output: [N, H, W, channels]
        """
        # Self-attention sub-layer with residual.
        _x = adp_pos_imgs
        x = self.attn_norm(adp_pos_imgs)
        x = self.attn_layer(x, x, x)
        x = self.attn_dropout(x)
        x = x + _x

        # Feed-forward sub-layer with residual.
        _x = x
        x = self.ffw_norm(x)
        x = self.ffw_layer(x)
        x = self.ffw_dropout(x)
        x = x + _x
        return x


class Transformer_Encoder(nn.Module):
    """Stack of ``Transformer_Encoder_Layer`` followed by a linear projection,
    a channel-wise LayerNorm and dropout. Takes and returns channels-first maps."""

    def __init__(self, in_channels, out_channels, num_layers, num_attn_heads,
                 ffw_channels, dropout=0.1):
        super().__init__()
        self.encoder_layers = nn.ModuleList([
            Transformer_Encoder_Layer(in_channels, num_attn_heads, ffw_channels, dropout)
            for _ in range(num_layers)
        ])
        self.linear = nn.Linear(in_channels, out_channels)
        self.last_norm = LayerNorm2D(out_channels)
        self.dropout = nn.Dropout(dropout)

    def forward(self, adp_pos_imgs: torch.Tensor) -> torch.Tensor:
        """
        input: [N, in_channels, H, W]
        output: [N, out_channels, H, W]
        """
        x = adp_pos_imgs.permute(0, -2, -1, -3)  # [N, H, W, in_channels]

        for layer in self.encoder_layers:
            x = layer(x)

        x = self.linear(x)            # [N, H, W, out_channels]
        x = x.permute(0, -1, -3, -2)  # [N, out_channels, H, W]
        x = self.last_norm(x)
        out = self.dropout(x)
        return out


class Double_Conv(nn.Module):
    """Two (3x3 conv, padding 1 -> BatchNorm -> ReLU) stages; spatial size is preserved."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """
        input: [N, in_channels, H, W]
        output: [N, out_channels, H, W]
        """
        return self.double_conv(X)


class Down(nn.Module):
    """Downsampling stage: 2x2 max-pool followed by ``Double_Conv``."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.down = nn.Sequential(
            nn.MaxPool2d(2),
            Double_Conv(in_channels, out_channels),
        )

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """
        input: [N, in_channels, H, W]
        output: [N, out_channels, H//2, W//2]
        """
        return self.down(X)


class Up(nn.Module):
    """Upsampling stage: transposed conv (x2), align to the skip tensor,
    concatenate along channels, then ``Double_Conv``."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
        self.conv = Double_Conv(in_channels, out_channels)

    def forward(self, X1: torch.Tensor, X2: torch.Tensor) -> torch.Tensor:
        """
        input:  X1 : [N, in_channels, H // 2, W // 2]
                X2 : [N, in_channels // 2, H, W]

        output: X : [N, out_channels, H, W]
        """
        X1 = self.up(X1)

        # Pad the upsampled tensor so its H and W match the skip tensor X2;
        # any odd remainder goes to the bottom / right side.
        diffY = X2.shape[-2] - X1.shape[-2]
        diffX = X2.shape[-1] - X1.shape[-1]
        pad_top = diffY // 2
        pad_bottom = diffY - pad_top
        pad_left = diffX // 2
        pad_right = diffX - pad_left
        X1 = F.pad(X1, (pad_left, pad_right, pad_top, pad_bottom))

        # Skip features first, upsampled features second, along the channel axis.
        X = torch.cat((X2, X1), dim=-3)
        return self.conv(X)


class Out_Conv(nn.Module):
    """1x1 convolution mapping ``adp_channels`` feature maps to ``out_channels``."""

    def __init__(self, adp_channels, out_channels):
        super().__init__()
        self.out_conv = nn.Conv2d(adp_channels, out_channels, kernel_size=1)

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """
        input: [N, adp_channels, H, W]
        output: [N, out_channels, H, W]
        """
        return self.out_conv(X)


class Trans_UNet(nn.Module):
    """U-Net with a transformer encoder at the bottleneck.

    Pipeline: image adaptor -> add positional encoding -> five ``Down`` stages
    (channels double each time, up to ``adp_channels * 32``) -> transformer
    encoder -> five ``Up`` stages with skip connections -> 1x1 output conv ->
    sigmoid.
    """

    def __init__(self, in_channels, adp_channels, out_channels,
                 trans_num_layers=5, trans_num_attn_heads=8,
                 trans_ffw_channels=1024, dropout=0.1):
        super().__init__()
        self.img_adaptor = Image_Adaptor(in_channels, adp_channels, dropout)
        self.pos_encoding = Positional_Encoding(adp_channels)

        self.down1 = Down(adp_channels * 1, adp_channels * 2)
        self.down2 = Down(adp_channels * 2, adp_channels * 4)
        self.down3 = Down(adp_channels * 4, adp_channels * 8)
        self.down4 = Down(adp_channels * 8, adp_channels * 16)
        self.down5 = Down(adp_channels * 16, adp_channels * 32)

        self.trans_encoder = Transformer_Encoder(
            adp_channels * 32, adp_channels * 32,
            trans_num_layers, trans_num_attn_heads, trans_ffw_channels, dropout,
        )

        self.up5 = Up(adp_channels * 32, adp_channels * 16)
        self.up4 = Up(adp_channels * 16, adp_channels * 8)
        self.up3 = Up(adp_channels * 8, adp_channels * 4)
        self.up2 = Up(adp_channels * 4, adp_channels * 2)
        self.up1 = Up(adp_channels * 2, adp_channels * 1)

        self.out_conv = Out_Conv(adp_channels, out_channels)
        self.sigmoid = nn.Sigmoid()

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        """
        input: [N, in_channels, H, W]
        output: [N, out_channels, H, W], passed through a sigmoid
        """
        adp_imgs = self.img_adaptor(images)
        pos_enc = self.pos_encoding(adp_imgs)
        adp_imgs += pos_enc

        # Contracting path; each output is kept as a skip connection.
        d1 = self.down1(adp_imgs)
        d2 = self.down2(d1)
        d3 = self.down3(d2)
        d4 = self.down4(d3)
        d5 = self.down5(d4)

        x = self.trans_encoder(d5)

        # Expanding path; each stage is merged with the matching skip tensor.
        u5 = self.up5(x, d4)
        u4 = self.up4(u5, d3)
        u3 = self.up3(u4, d2)
        u2 = self.up2(u3, d1)
        u1 = self.up1(u2, adp_imgs)

        x = self.out_conv(u1)
        out = self.sigmoid(x)
        return out