Spaces:
Sleeping
Sleeping
print("Importing external...") | |
import torch | |
from torch import nn | |
import torch.nn.functional as F | |
from timm.models.efficientvit_mit import ( | |
ConvNormAct, | |
FusedMBConv, | |
MBConv, | |
ResidualBlock, | |
efficientvit_l1, | |
) | |
from timm.layers import GELUTanh | |
def val2list(x: list or tuple or any, repeat_time=1): | |
if isinstance(x, (list, tuple)): | |
return list(x) | |
return [x for _ in range(repeat_time)] | |
def resize( | |
x: torch.Tensor, | |
size: any or None = None, | |
scale_factor: list[float] or None = None, | |
mode: str = "bicubic", | |
align_corners: bool or None = False, | |
) -> torch.Tensor: | |
if mode in {"bilinear", "bicubic"}: | |
return F.interpolate( | |
x, | |
size=size, | |
scale_factor=scale_factor, | |
mode=mode, | |
align_corners=align_corners, | |
) | |
elif mode in {"nearest", "area"}: | |
return F.interpolate(x, size=size, scale_factor=scale_factor, mode=mode) | |
else: | |
raise NotImplementedError(f"resize(mode={mode}) not implemented.") | |
class UpSampleLayer(nn.Module): | |
def __init__( | |
self, | |
mode="bicubic", | |
size: int or tuple[int, int] or list[int] or None = None, | |
factor=2, | |
align_corners=False, | |
): | |
super(UpSampleLayer, self).__init__() | |
self.mode = mode | |
self.size = val2list(size, 2) if size is not None else None | |
self.factor = None if self.size is not None else factor | |
self.align_corners = align_corners | |
def forward(self, x: torch.Tensor) -> torch.Tensor: | |
if ( | |
self.size is not None and tuple(x.shape[-2:]) == self.size | |
) or self.factor == 1: | |
return x | |
return resize(x, self.size, self.factor, self.mode, self.align_corners) | |
class DAGBlock(nn.Module): | |
def __init__( | |
self, | |
inputs: dict[str, nn.Module], | |
merge: str, | |
post_input: nn.Module or None, | |
middle: nn.Module, | |
outputs: dict[str, nn.Module], | |
): | |
super(DAGBlock, self).__init__() | |
self.input_keys = list(inputs.keys()) | |
self.input_ops = nn.ModuleList(list(inputs.values())) | |
self.merge = merge | |
self.post_input = post_input | |
self.middle = middle | |
self.output_keys = list(outputs.keys()) | |
self.output_ops = nn.ModuleList(list(outputs.values())) | |
def forward(self, feature_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: | |
feat = [ | |
op(feature_dict[key]) for key, op in zip(self.input_keys, self.input_ops) | |
] | |
if self.merge == "add": | |
feat = list_sum(feat) | |
elif self.merge == "cat": | |
feat = torch.concat(feat, dim=1) | |
else: | |
raise NotImplementedError | |
if self.post_input is not None: | |
feat = self.post_input(feat) | |
feat = self.middle(feat) | |
for key, op in zip(self.output_keys, self.output_ops): | |
feature_dict[key] = op(feat) | |
return feature_dict | |
def list_sum(x: list) -> any: | |
return x[0] if len(x) == 1 else x[0] + list_sum(x[1:]) | |
class SegHead(nn.Module): | |
def __init__( | |
self, | |
fid_list: list[str], | |
in_channel_list: list[int], | |
stride_list: list[int], | |
head_stride: int, | |
head_width: int, | |
head_depth: int, | |
expand_ratio: float, | |
middle_op: str, | |
final_expand: float or None, | |
n_classes: int, | |
dropout=0, | |
norm="bn2d", | |
act_func="hswish", | |
): | |
super(SegHead, self).__init__() | |
# exceptions to adapt effvit to timm | |
if act_func == "gelu": | |
act_func = GELUTanh | |
else: | |
raise ValueError(f"act_func {act_func} not supported") | |
if norm == "bn2d": | |
norm_layer = nn.BatchNorm2d | |
else: | |
raise ValueError(f"norm {norm} not supported") | |
inputs = {} | |
for fid, in_channel, stride in zip(fid_list, in_channel_list, stride_list): | |
factor = stride // head_stride | |
if factor == 1: | |
inputs[fid] = ConvNormAct( | |
in_channel, head_width, 1, norm_layer=norm_layer, act_layer=act_func | |
) | |
else: | |
inputs[fid] = nn.Sequential( | |
ConvNormAct( | |
in_channel, | |
head_width, | |
1, | |
norm_layer=norm_layer, | |
act_layer=act_func, | |
), | |
UpSampleLayer(factor=factor), | |
) | |
self.in_keys = inputs.keys() | |
self.in_ops = nn.ModuleList(inputs.values()) | |
middle = [] | |
for _ in range(head_depth): | |
if middle_op == "mbconv": | |
block = MBConv( | |
head_width, | |
head_width, | |
expand_ratio=expand_ratio, | |
norm_layer=norm_layer, | |
act_layer=(act_func, act_func, None), | |
) | |
elif middle_op == "fmbconv": | |
block = FusedMBConv( | |
head_width, | |
head_width, | |
expand_ratio=expand_ratio, | |
norm_layer=norm_layer, | |
act_layer=(act_func, None), | |
) | |
else: | |
raise NotImplementedError | |
middle.append(ResidualBlock(block, nn.Identity())) | |
self.middle = nn.Sequential(*middle) | |
self.out_layer = nn.Sequential( | |
*[ | |
None | |
if final_expand is None | |
else ConvNormAct( | |
head_width, | |
head_width * final_expand, | |
1, | |
norm_layer=norm_layer, | |
act_layer=act_func, | |
), | |
ConvNormAct( | |
head_width * (final_expand or 1), | |
n_classes, | |
1, | |
bias=True, | |
dropout=dropout, | |
norm_layer=None, | |
act_layer=None, | |
), | |
] | |
) | |
def forward(self, feature_map_list): | |
t_feat_maps = [ | |
self.in_ops[ind](feature_map_list[ind]) | |
for ind in range(len(feature_map_list)) | |
] | |
t_feat_map = list_sum(t_feat_maps) | |
t_feat_map = self.middle(t_feat_map) | |
out = self.out_layer(t_feat_map) | |
return out | |
class EfficientViT_l1_r224(nn.Module): | |
def __init__( | |
self, | |
out_channels, | |
out_ds_factor=1, | |
decoder_size="small", | |
pretrained=False, | |
use_norm_params=False, | |
): | |
if decoder_size == "small": | |
head_width = 32 | |
head_depth = 1 | |
middle_op = "mbconv" | |
elif decoder_size == "medium": | |
head_width = 64 | |
head_depth = 3 | |
middle_op = "mbconv" | |
elif decoder_size == "large": | |
head_width = 256 | |
head_depth = 3 | |
middle_op = "fmbconv" | |
super(EfficientViT_l1_r224, self).__init__() | |
self.bbone = efficientvit_l1( | |
num_classes=0, features_only=True, pretrained=pretrained | |
) | |
self.head = SegHead( | |
fid_list=["stage4", "stage3", "stage2"], | |
in_channel_list=[512, 256, 128], | |
stride_list=[32, 16, 8], | |
head_stride=out_ds_factor, | |
head_width=head_width, | |
head_depth=head_depth, | |
expand_ratio=4, | |
middle_op=middle_op, | |
final_expand=8, | |
n_classes=out_channels, | |
act_func="gelu", | |
) | |
# [optional] deactivate normalization | |
if not use_norm_params: | |
for module in self.modules(): | |
if ( | |
isinstance(module, nn.LayerNorm) | |
or isinstance(module, nn.BatchNorm2d) | |
or isinstance(module, nn.BatchNorm1d) | |
): | |
module.weight.requires_grad_(False) | |
module.bias.requires_grad_(False) | |
def forward(self, x): | |
feat = self.bbone(x) | |
out = self.head([feat[3], feat[2], feat[1]]) | |
return out | |