# uniformer_base_tl_384 / modelling_uniformer.py
# Web-page header captured together with the file (not part of the module):
# anicolson's picture | Update modelling_uniformer.py | 0cd14c4 verified | raw | history blame | 16.7 kB
from collections import OrderedDict
from functools import partial
from math import isqrt
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from transformers.modeling_outputs import ModelOutput
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging
from .configuration_uniformer import UniFormerWithProjectionHeadConfig
# Module-level logger obtained through the transformers logging utilities.
logger = logging.get_logger(__name__)
# Module-level layer-scale settings.
# NOTE(review): every class visible in this file takes its own `layer_scale` / `init_value`
# arguments, which shadow these names; nothing visible here reads the globals - confirm
# against the rest of the project before removing them.
layer_scale = False
init_value = 1e-6
class Mlp(nn.Module):
    """Two-layer perceptron: linear -> activation -> dropout -> linear -> dropout.

    Args:
        in_features: Size of the last axis of the input.
        hidden_features: Width of the hidden layer; falls back to
            ``in_features`` when falsy.
        out_features: Size of the last axis of the output; falls back to
            ``in_features`` when falsy.
        act_layer: Zero-argument callable that builds the activation module.
        drop: Dropout probability; one dropout module is applied twice.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        output_width = out_features if out_features else in_features
        hidden_width = hidden_features if hidden_features else in_features
        # Creation order fixes the parameter registration order and the order of RNG draws
        # at initialisation, so it is kept as fc1 -> act -> fc2 -> drop.
        self.fc1 = nn.Linear(in_features, hidden_width)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_width, output_width)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """Apply the perceptron along the last axis of ``x``."""
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
class CMlp(nn.Module):
    """Channel MLP for feature maps, built from two 1x1 convolutions.

    Pipeline: 1x1 conv -> activation -> dropout -> 1x1 conv -> dropout.

    Args:
        in_features: Number of input channels.
        hidden_features: Number of hidden channels; falls back to
            ``in_features`` when falsy.
        out_features: Number of output channels; falls back to
            ``in_features`` when falsy.
        act_layer: Zero-argument callable that builds the activation module.
        drop: Dropout probability; one dropout module is applied twice.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        output_channels = out_features if out_features else in_features
        hidden_channels = hidden_features if hidden_features else in_features
        # Kernel size 1 means each spatial position is transformed independently.
        self.fc1 = nn.Conv2d(in_features, hidden_channels, 1)
        self.act = act_layer()
        self.fc2 = nn.Conv2d(hidden_channels, output_channels, 1)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """Apply the channel MLP to the feature map ``x``."""
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
class Attention(nn.Module):
    """Multi-head self-attention over ``(batch, tokens, channels)`` inputs.

    Args:
        dim: Channel width of the input and output.
        num_heads: Number of attention heads; each head gets ``dim // num_heads`` channels.
        qkv_bias: Whether the fused query/key/value projection has a bias.
        qk_scale: Scaling applied to the attention logits; when falsy,
            ``(dim // num_heads) ** -0.5`` is used.
        attn_drop: Dropout probability on the attention weights.
        proj_drop: Dropout probability after the output projection.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        per_head_width = dim // num_heads
        self.scale = qk_scale if qk_scale else per_head_width ** -0.5
        # A single linear layer produces queries, keys and values in one pass.
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """Return the attended tokens, with the same shape as ``x``."""
        batch, tokens, channels = x.shape
        head_width = channels // self.num_heads
        packed = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, head_width)
        # (batch, tokens, 3, heads, width) -> (3, batch, heads, tokens, width)
        packed = packed.permute(2, 0, 3, 1, 4)
        queries, keys, values = packed[0], packed[1], packed[2]
        weights = (queries @ keys.transpose(-2, -1)) * self.scale
        weights = self.attn_drop(weights.softmax(dim=-1))
        # Merge the heads back into the channel axis.
        mixed = (weights @ values).transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj_drop(self.proj(mixed))
class CBlock(nn.Module):
    """Convolutional block operating on ``(B, C, H, W)`` feature maps.

    Three residual branches are added in sequence:

    1. a depthwise 3x3 convolution (``pos_embed``),
    2. BatchNorm -> 1x1 conv -> depthwise 5x5 conv -> 1x1 conv (``module_1``),
    3. BatchNorm -> channel MLP of 1x1 convolutions (``module_2``).

    Args:
        dim: Number of channels of the input and output.
        mlp_ratio: Hidden width of the channel MLP as a multiple of ``dim``.
        drop: Dropout probability inside the channel MLP.
        drop_path: Stochastic-depth rate; ``0`` disables it.
        act_layer: Zero-argument callable that builds the MLP activation.
        layer_scale: When true, branches 2 and 3 are multiplied by learnable
            per-channel factors ``gamma_1`` / ``gamma_2``.
        init_value: Initial value of the layer-scale factors.
    """

    def __init__(self, dim, mlp_ratio=4., drop=0., drop_path=0., act_layer=nn.GELU, layer_scale=False, init_value=1e-6):
        super().__init__()
        self.pos_embed = nn.Conv2d(dim, dim, 3, padding=1, groups=dim)
        self.norm1 = nn.BatchNorm2d(dim)
        self.conv1 = nn.Conv2d(dim, dim, 1)
        self.conv2 = nn.Conv2d(dim, dim, 1)
        self.attn = nn.Conv2d(dim, dim, 5, padding=2, groups=dim)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = nn.BatchNorm2d(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = CMlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
        self.ls = layer_scale
        if self.ls:
            # Stored with shape (dim,); reshaped at use so the state-dict layout stays the same.
            self.gamma_1 = nn.Parameter(init_value * torch.ones((dim)), requires_grad=True)
            self.gamma_2 = nn.Parameter(init_value * torch.ones((dim)), requires_grad=True)

    @staticmethod
    def _scale_channels(gamma, x):
        """Multiply each channel of the ``(B, C, H, W)`` map ``x`` by its factor in ``gamma``.

        ``gamma`` has shape ``(C,)``. Multiplying it directly would broadcast
        against the trailing (width) axis, so it is viewed as ``(1, C, 1, 1)``.
        """
        return gamma.view(1, -1, 1, 1) * x

    def forward(self, x):
        """Apply the three residual branches to ``x`` and return the result."""
        x = x + self.pos_embed(x)
        x = x + self.module_1(x)
        x = x + self.module_2(x)
        return x

    def module_1(self, x):
        """Local token-mixing branch (without the residual connection)."""
        # BatchNorm2d is not autocast to its parameters' dtype, so cast the input explicitly.
        x = self.norm1(x.to(dtype=self.norm1.weight.dtype))
        x = self.conv1(x)
        x = self.attn(x)
        x = self.conv2(x)
        if self.ls:
            x = self._scale_channels(self.gamma_1, x)
        return self.drop_path(x)

    def module_2(self, x):
        """Channel-MLP branch (without the residual connection)."""
        # BatchNorm2d is not autocast to its parameters' dtype, so cast the input explicitly.
        x = self.norm2(x.to(dtype=self.norm2.weight.dtype))
        x = self.mlp(x)
        if self.ls:
            x = self._scale_channels(self.gamma_2, x)
        return self.drop_path(x)
class SABlock(nn.Module):
    """Self-attention block that takes and returns ``(B, C, H, W)`` feature maps.

    A depthwise 3x3 convolution is added as a residual first. The map is then
    flattened to a token sequence, passed through a pre-norm attention branch
    and a pre-norm MLP branch (both residual), and folded back to its
    original shape.

    Args:
        dim: Number of channels.
        num_heads: Number of attention heads.
        mlp_ratio: Hidden width of the MLP as a multiple of ``dim``.
        qkv_bias: Whether the query/key/value projection has a bias.
        qk_scale: Optional override of the attention logit scale.
        drop: Dropout probability for the attention output projection and the MLP.
        attn_drop: Dropout probability on the attention weights.
        drop_path: Stochastic-depth rate; ``0`` disables it.
        act_layer: Zero-argument callable that builds the MLP activation.
        norm_layer: Callable ``norm_layer(dim)`` that builds each normalisation layer.
        layer_scale: When true, both branches are multiplied by learnable
            per-channel factors ``gamma_1`` / ``gamma_2``.
        init_value: Initial value of the layer-scale factors.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, layer_scale=False, init_value=1e-6):
        super().__init__()
        self.pos_embed = nn.Conv2d(dim, dim, 3, padding=1, groups=dim)
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        self.drop_path = nn.Identity() if not drop_path > 0. else DropPath(drop_path)
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
        self.ls = layer_scale
        if self.ls:
            self.gamma_1 = nn.Parameter(init_value * torch.ones((dim)), requires_grad=True)
            self.gamma_2 = nn.Parameter(init_value * torch.ones((dim)), requires_grad=True)

    def forward(self, x):
        """Apply the block to the feature map ``x`` and return a map of the same shape."""
        x = x + self.pos_embed(x)
        batch, channels, height, width = x.shape
        # (B, C, H, W) -> (B, H*W, C): one token per spatial position.
        tokens = x.flatten(2).transpose(1, 2)

        attended = self.attn(self.norm1(tokens))
        if self.ls:
            attended = self.gamma_1 * attended
        tokens = tokens + self.drop_path(attended)

        expanded = self.mlp(self.norm2(tokens))
        if self.ls:
            expanded = self.gamma_2 * expanded
        tokens = tokens + self.drop_path(expanded)

        # Fold the token sequence back into a feature map.
        return tokens.transpose(1, 2).reshape(batch, channels, height, width)
class HeadEmbedding(nn.Module):
    """Convolutional stem built from two stride-2 3x3 convolutions.

    Pipeline: conv (to ``out_channels // 2``) -> BatchNorm -> GELU ->
    conv (to ``out_channels``) -> BatchNorm.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        half_channels = out_channels // 2
        # Both convolutions share the same geometry: 3x3 kernel, stride 2, padding 1.
        conv_geometry = dict(kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        stem_layers = [
            nn.Conv2d(in_channels, half_channels, **conv_geometry),
            nn.BatchNorm2d(half_channels),
            nn.GELU(),
            nn.Conv2d(half_channels, out_channels, **conv_geometry),
            nn.BatchNorm2d(out_channels),
        ]
        self.proj = nn.Sequential(*stem_layers)

    def forward(self, x):
        """Return the embedded feature map for ``x``."""
        return self.proj(x)
class MiddleEmbedding(nn.Module):
    """Down-sampling embedding: one stride-2 3x3 convolution followed by BatchNorm.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        downsample = nn.Conv2d(in_channels, out_channels, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        normalise = nn.BatchNorm2d(out_channels)
        self.proj = nn.Sequential(downsample, normalise)

    def forward(self, x):
        """Return the embedded feature map for ``x``."""
        return self.proj(x)
class PatchEmbed(nn.Module):
    """Patch embedding: a strided convolution followed by LayerNorm over channels.

    Args:
        image_size: Expected input height/width (an int is expanded to a pair).
        patch_size: Patch height/width, used as both kernel size and stride.
        in_chans: Number of input channels.
        embed_dim: Number of output channels.
    """

    def __init__(self, image_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        self.image_size = to_2tuple(image_size)
        self.patch_size = to_2tuple(patch_size)
        grid_rows = self.image_size[0] // self.patch_size[0]
        grid_cols = self.image_size[1] // self.patch_size[1]
        self.num_patches = grid_rows * grid_cols
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=self.patch_size, stride=self.patch_size)
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        """Embed ``x`` and return a ``(B, embed_dim, H', W')`` feature map.

        Raises:
            AssertionError: If the spatial size of ``x`` differs from ``image_size``.
        """
        _, _, H, W = x.shape
        assert H == self.image_size[0] and W == self.image_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
        patches = self.proj(x)
        batch, _, grid_h, grid_w = patches.shape
        # LayerNorm acts on the last axis, so normalise in token layout (B, H'*W', C) ...
        tokens = self.norm(patches.flatten(2).transpose(1, 2))
        # ... then restore the channels-first map layout.
        return tokens.reshape(batch, grid_h, grid_w, -1).permute(0, 3, 1, 2).contiguous()
class UniFormer(nn.Module):
    """Four-stage backbone that returns a normalised ``(B, C, H', W')`` feature map.

    Stages 1 and 2 stack ``CBlock`` modules, stages 3 and 4 stack ``SABlock``
    modules. Each stage is preceded by a down-sampling embedding, and the
    final map is passed through ``BatchNorm2d``.

    Args:
        depth: Number of blocks in each of the four stages.
        image_size: Input resolution handed to the ``PatchEmbed`` stems
            (not used when ``conv_stem`` is true).
        in_chans: Number of input image channels.
        num_classes: Stored on the instance; this class builds no classification head.
        embed_dim: Channel width of each of the four stages.
        head_dim: Per-head channel width; a stage uses ``dim // head_dim`` heads.
        mlp_ratio: Hidden width of each block's MLP as a multiple of its channel width.
        qkv_bias: Whether attention query/key/value projections have a bias.
        qk_scale: Optional override of the attention logit scale.
        representation_size: When truthy, a ``Linear`` + ``Tanh`` layer mapping
            ``embed_dim[-1]`` channels to this size is applied per spatial
            position to the final feature map.
        patch_size: Per-stage patch size of the ``PatchEmbed`` stems
            (not used when ``conv_stem`` is true).
        drop_rate: Dropout probability after the first embedding and inside the blocks.
        attn_drop_rate: Dropout probability on attention weights.
        drop_path_rate: Largest stochastic-depth rate; per-block rates rise
            linearly from 0 to this value over all blocks.
        conv_stem: Use the convolutional ``HeadEmbedding`` / ``MiddleEmbedding``
            stems instead of ``PatchEmbed``.
        layer_norm_eps: Epsilon of the LayerNorm layers inside the attention blocks.
        layer_scale: Enable learnable per-channel branch scaling in every block.
        init_value: Initial value of the layer-scale factors.
        **kwargs: Accepted and ignored, so a full config dict can be passed.
    """

    # NOTE: the list defaults below are kept for interface compatibility; this class only reads them.
    def __init__(self, depth=[3, 4, 8, 3], image_size=224, in_chans=3, num_classes=1000, embed_dim=[64, 128, 320, 512],
                 head_dim=64, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None, patch_size=[4, 2, 2, 2],
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0., conv_stem=False, layer_norm_eps=1e-6, layer_scale=False,
                 init_value=1e-6, **kwargs):
        super().__init__()
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        norm_layer = partial(nn.LayerNorm, eps=layer_norm_eps)

        if conv_stem:
            self.patch_embed1 = HeadEmbedding(in_channels=in_chans, out_channels=embed_dim[0])
            self.patch_embed2 = MiddleEmbedding(in_channels=embed_dim[0], out_channels=embed_dim[1])
            self.patch_embed3 = MiddleEmbedding(in_channels=embed_dim[1], out_channels=embed_dim[2])
            self.patch_embed4 = MiddleEmbedding(in_channels=embed_dim[2], out_channels=embed_dim[3])
        else:
            # Each stem sees the input resolution divided by all preceding patch sizes.
            self.patch_embed1 = PatchEmbed(
                image_size=image_size, patch_size=patch_size[0], in_chans=in_chans, embed_dim=embed_dim[0])
            self.patch_embed2 = PatchEmbed(
                image_size=image_size // patch_size[0], patch_size=patch_size[1],
                in_chans=embed_dim[0], embed_dim=embed_dim[1])
            self.patch_embed3 = PatchEmbed(
                image_size=image_size // (patch_size[0] * patch_size[1]), patch_size=patch_size[2],
                in_chans=embed_dim[1], embed_dim=embed_dim[2])
            self.patch_embed4 = PatchEmbed(
                image_size=image_size // (patch_size[0] * patch_size[1] * patch_size[2]), patch_size=patch_size[3],
                in_chans=embed_dim[2], embed_dim=embed_dim[3])

        self.pos_drop = nn.Dropout(p=drop_rate)

        # Stochastic depth decay rule: one rate per block, numbered across all four stages.
        dpr = torch.linspace(0, drop_path_rate, sum(depth)).tolist()
        # Index of the first block of each stage within `dpr`.
        start2 = depth[0]
        start3 = start2 + depth[1]
        start4 = start3 + depth[2]
        num_heads = [dim // head_dim for dim in embed_dim]

        conv_kwargs = dict(mlp_ratio=mlp_ratio, drop=drop_rate, layer_scale=layer_scale, init_value=init_value)
        attn_kwargs = dict(
            conv_kwargs, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop_rate, norm_layer=norm_layer)

        self.blocks1 = nn.ModuleList([
            CBlock(dim=embed_dim[0], drop_path=dpr[i], **conv_kwargs)
            for i in range(depth[0])])
        self.blocks2 = nn.ModuleList([
            CBlock(dim=embed_dim[1], drop_path=dpr[start2 + i], **conv_kwargs)
            for i in range(depth[1])])
        self.blocks3 = nn.ModuleList([
            SABlock(dim=embed_dim[2], num_heads=num_heads[2], drop_path=dpr[start3 + i], **attn_kwargs)
            for i in range(depth[2])])
        self.blocks4 = nn.ModuleList([
            SABlock(dim=embed_dim[3], num_heads=num_heads[3], drop_path=dpr[start4 + i], **attn_kwargs)
            for i in range(depth[3])])
        self.norm = nn.BatchNorm2d(embed_dim[-1])

        # Representation layer
        if representation_size:
            self.num_features = representation_size
            self.pre_logits = nn.Sequential(OrderedDict([
                # `embed_dim` is a per-stage list; the layer consumes the last stage's channels.
                ('fc', nn.Linear(embed_dim[-1], representation_size)),
                ('act', nn.Tanh())
            ]))
        else:
            self.pre_logits = nn.Identity()

    def forward_features(self, x):
        """Run the four stages on ``x`` and return the normalised feature map."""
        x = self.patch_embed1(x)
        x = self.pos_drop(x)
        for blk in self.blocks1:
            x = blk(x)
        x = self.patch_embed2(x)
        for blk in self.blocks2:
            x = blk(x)
        x = self.patch_embed3(x)
        for blk in self.blocks3:
            x = blk(x)
        x = self.patch_embed4(x)
        for blk in self.blocks4:
            x = blk(x)
        # BatchNorm2d is not autocast to its parameters' dtype, so cast the input explicitly.
        x = self.norm(x.to(dtype=self.norm.weight.dtype))
        if isinstance(self.pre_logits, nn.Identity):
            return self.pre_logits(x)
        # nn.Linear acts on the last axis: move channels last for the
        # representation layer, then restore the channels-first layout.
        return self.pre_logits(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

    def forward(self, x):
        """Return ``forward_features(x)``."""
        return self.forward_features(x)
class UniFormerPreTrainedModel(PreTrainedModel):
    """
    Abstract base class that supplies weight initialisation and the pretrained-model
    download/loading interface to the UniFormer models.
    """

    config_class = UniFormerWithProjectionHeadConfig
    main_input_name = "pixel_values"

    def _init_weights(self, m):
        """Initialise the module ``m`` in place; only Linear and LayerNorm modules are touched."""
        if isinstance(m, nn.LayerNorm):
            # Start normalisation layers as the identity affine map.
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
            return
        if not isinstance(m, nn.Linear):
            return
        trunc_normal_(m.weight, std=.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
class UniFormerProjectionHead(torch.nn.Module):
    """Layer-normalise features, then project them linearly to ``config.projection_size``.

    Args:
        config: Object providing ``embed_dim`` (its last entry is the input
            width), ``layer_norm_eps`` and ``projection_size``.
    """

    def __init__(self, config) -> None:
        super().__init__()
        feature_width = config.embed_dim[-1]

        # Layer normalisation before projection:
        self.layer_norm = torch.nn.LayerNorm(feature_width, eps=config.layer_norm_eps)

        # No bias as following layer normalisation with bias:
        self.projection = torch.nn.Linear(feature_width, config.projection_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the projection of the normalised ``x``."""
        return self.projection(self.layer_norm(x))
class UniFormerModel(UniFormerPreTrainedModel):
    """UniFormer backbone wrapped as a pretrained model that emits a token sequence."""

    def __init__(self, config):
        super().__init__(config)

        # Every config attribute is passed as a keyword argument to the backbone.
        self.uniformer = UniFormer(**vars(config))

        # Initialize weights and apply final processing:
        self.post_init()

    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, ModelOutput]:
        """Encode ``pixel_values`` into one feature vector per spatial position.

        Args:
            pixel_values: Image batch passed to the backbone.
            output_hidden_states: Accepted but not used by this method.
            return_dict: Whether to wrap the result in a ``ModelOutput``;
                defaults to ``config.use_return_dict``.

        Returns:
            The ``(batch, positions, channels)`` tensor itself when
            ``return_dict`` is false, otherwise a ``ModelOutput`` holding it
            as ``last_hidden_state``.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        feature_map = self.uniformer(pixel_values)

        # Flatten h x w, then move channels last: (B, C, H, W) -> (B, H*W, C).
        last_hidden_state = feature_map.flatten(2).permute(0, 2, 1)

        if return_dict:
            return ModelOutput(last_hidden_state=last_hidden_state)
        return last_hidden_state
class MultiUniFormerWithProjectionHead(UniFormerPreTrainedModel):
    """UniFormer backbone plus projection head for batches with several images per study."""

    def __init__(self, config):
        super().__init__(config)

        # Every config attribute is passed as a keyword argument to the backbone.
        self.uniformer = UniFormer(**vars(config))
        self.projection_head = UniFormerProjectionHead(config)

        # Initialize weights and apply final processing:
        self.post_init()

    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, ModelOutput]:
        """Encode every image of every study and concatenate the features per study.

        Args:
            pixel_values: Tensor of shape ``(B, S, C, H, W)``, where ``S`` is the
                max number of images for a study in the batch.
            output_hidden_states: Accepted but not used by this method.
            return_dict: Whether to wrap the result in a ``ModelOutput``;
                defaults to ``config.use_return_dict``.

        Returns:
            The projected features of shape ``(B, S * positions, projection_size)``
            when ``return_dict`` is false. Otherwise a ``ModelOutput`` with these
            features as ``last_hidden_state`` and an ``attention_mask`` of shape
            ``(B, S * positions)``.

        Raises:
            AssertionError: If ``pixel_values`` is not 5-dimensional.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        assert len(pixel_values.shape) == 5, 'pixel_values must be B, S, C, H, W, where S is the max number of images for a study in the batch.'
        batch_size, max_images = pixel_values.shape[0], pixel_values.shape[1]

        # Flatten the batch and study dimensions. `reshape` (unlike `view`) also
        # accepts non-contiguous inputs such as permuted or sliced batches.
        feature_map = self.uniformer(pixel_values.reshape(-1, *pixel_values.shape[2:]))

        # Flatten h x w and move channels last: (B*S, C, H', W') -> (B*S, H'*W', C).
        tokens = torch.permute(torch.flatten(feature_map, 2), [0, 2, 1])

        # Project the features for each spatial position to the decoder's hidden size:
        projection = self.projection_head(tokens)

        # Concatenate the features for each chest X-ray:
        projection = projection.reshape(batch_size, -1, projection.shape[-1])

        if not return_dict:
            return projection

        # Derive the attention mask from the pixel values: an image slot is masked out
        # when its first pixel (channel 0, row 0, column 0) is exactly zero.
        # NOTE(review): presumably such slots are zero padding - confirm against the data pipeline.
        slot_in_use = (pixel_values[:, :, 0, 0, 0] != 0.0)[:, :, None]
        attention_mask = torch.ones(
            [batch_size, max_images, projection.shape[1] // max_images],
            dtype=torch.long,
            device=slot_in_use.device,
        )
        attention_mask = (attention_mask * slot_in_use).reshape(batch_size, -1)

        return ModelOutput(last_hidden_state=projection, attention_mask=attention_mask)
if __name__ == '__main__':
    # Smoke test: push a random batch through a default-configured patch embedding.
    patch_embedding = PatchEmbed()
    random_batch = torch.randn(2, 3, 224, 224)
    patch_embedding(random_batch)