# segmentation_features / network.py
# franchesoni's picture
# v0
# e1b51e5
# raw
# history blame
# 8.28 kB
# Progress message printed when this module is imported, before the
# third-party imports below are executed.
print("Importing external...")
from typing import Any, Optional, Union

import torch
import torch.nn.functional as F
from torch import nn
from timm.layers import GELUTanh
from timm.models.efficientvit_mit import (
    ConvNormAct,
    FusedMBConv,
    MBConv,
    ResidualBlock,
    efficientvit_l1,
)
def val2list(x: Any, repeat_time: int = 1) -> list:
    """Return ``x`` as a list, repeating a non-sequence value if needed.

    Args:
        x: A list or tuple (returned as a new list with the same items), or
            any other value (treated as a single element).
        repeat_time: How many times a non-list/tuple ``x`` is repeated.

    Returns:
        ``list(x)`` when ``x`` is a list or tuple, otherwise a list holding
        ``x`` ``repeat_time`` times.
    """
    if isinstance(x, (list, tuple)):
        return list(x)
    return [x] * repeat_time
def resize(
    x: torch.Tensor,
    size: Union[int, tuple, list, None] = None,
    scale_factor: Union[float, list[float], None] = None,
    mode: str = "bicubic",
    align_corners: Optional[bool] = False,
) -> torch.Tensor:
    """Resize ``x`` with ``torch.nn.functional.interpolate``.

    Args:
        x: Input tensor, forwarded unchanged to ``F.interpolate``.
        size: Target output size, forwarded to ``F.interpolate``.
        scale_factor: Scale multiplier, forwarded to ``F.interpolate``.
        mode: One of ``"bilinear"``, ``"bicubic"``, ``"nearest"``, ``"area"``.
        align_corners: Forwarded only for ``"bilinear"`` and ``"bicubic"``.

    Returns:
        The interpolated tensor.

    Raises:
        NotImplementedError: If ``mode`` is not one of the four listed modes.
    """
    if mode in {"bilinear", "bicubic"}:
        return F.interpolate(
            x,
            size=size,
            scale_factor=scale_factor,
            mode=mode,
            align_corners=align_corners,
        )
    if mode in {"nearest", "area"}:
        # align_corners is deliberately not forwarded for these modes.
        return F.interpolate(x, size=size, scale_factor=scale_factor, mode=mode)
    raise NotImplementedError(f"resize(mode={mode}) not implemented.")
class UpSampleLayer(nn.Module):
    """Module wrapper around ``resize`` with a fixed size or scale factor.

    If ``size`` is given it takes precedence and ``factor`` is ignored
    (stored as ``None``); otherwise the input is scaled by ``factor``.

    Args:
        mode: Interpolation mode passed to ``resize``.
        size: Target size; a single value is expanded to two entries via
            ``val2list(size, 2)``.
        factor: Scale factor used only when ``size`` is ``None``.
        align_corners: Passed through to ``resize``.
    """

    def __init__(
        self,
        mode: str = "bicubic",
        size: Union[int, tuple[int, int], list[int], None] = None,
        factor: Union[int, float, None] = 2,
        align_corners: Optional[bool] = False,
    ):
        super().__init__()
        self.mode = mode
        self.size = val2list(size, 2) if size is not None else None
        self.factor = None if self.size is not None else factor
        self.align_corners = align_corners

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` resized, or ``x`` itself when no resize is needed.

        The size check looks only at the last two dimensions of ``x``
        (assumes a ``(..., H, W)`` layout — TODO confirm against callers).
        """
        # self.size is a list (val2list), so both sides are normalised to
        # tuples; comparing a tuple with a list is always False.
        already_sized = self.size is not None and tuple(x.shape[-2:]) == tuple(
            self.size
        )
        if already_sized or self.factor == 1:
            return x
        return resize(x, self.size, self.factor, self.mode, self.align_corners)
class DAGBlock(nn.Module):
    """Merge several named features, transform them, and emit named outputs.

    Args:
        inputs: Maps a feature key to the op applied to that feature.
        merge: ``"add"`` (sum the transformed inputs with ``list_sum``) or
            ``"cat"`` (concatenate them along dim 1).
        post_input: Optional op applied right after merging.
        middle: Op applied to the merged (and post-processed) feature.
        outputs: Maps an output key to the op producing that output.
    """

    def __init__(
        self,
        inputs: dict[str, nn.Module],
        merge: str,
        post_input: Optional[nn.Module],
        middle: nn.Module,
        outputs: dict[str, nn.Module],
    ):
        super().__init__()
        self.input_keys = list(inputs.keys())
        self.input_ops = nn.ModuleList(list(inputs.values()))
        self.merge = merge
        self.post_input = post_input
        self.middle = middle
        self.output_keys = list(outputs.keys())
        self.output_ops = nn.ModuleList(list(outputs.values()))

    def forward(self, feature_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """Run the block and store its outputs in ``feature_dict``.

        Note:
            ``feature_dict`` is updated in place and the same dict object is
            returned; existing entries under the output keys are overwritten.

        Raises:
            NotImplementedError: If ``self.merge`` is neither ``"add"`` nor
                ``"cat"``.
        """
        branches = [
            op(feature_dict[key]) for key, op in zip(self.input_keys, self.input_ops)
        ]
        if self.merge == "add":
            feat = list_sum(branches)
        elif self.merge == "cat":
            feat = torch.concat(branches, dim=1)
        else:
            raise NotImplementedError(
                f"DAGBlock merge mode {self.merge!r} is not implemented "
                "(expected 'add' or 'cat')."
            )
        if self.post_input is not None:
            feat = self.post_input(feat)
        feat = self.middle(feat)
        for key, op in zip(self.output_keys, self.output_ops):
            feature_dict[key] = op(feat)
        return feature_dict
def list_sum(x: list) -> Any:
    """Return the sum of the items of a non-empty sequence using ``+``.

    Items are combined right-to-left, i.e. ``x[0] + (x[1] + (... + x[-1]))``.
    The fold is iterative, so long inputs neither hit the recursion limit
    nor pay for repeated slicing.

    Args:
        x: Non-empty sequence of values that support ``+`` with each other.

    Returns:
        The only item itself for a single-element input, otherwise the sum.

    Raises:
        IndexError: If ``x`` is empty.
    """
    total = x[-1]
    for item in reversed(x[:-1]):
        total = item + total
    return total
class SegHead(nn.Module):
    """Segmentation head: project, sum, refine and classify feature maps.

    Each input feature map goes through a ``ConvNormAct`` (kernel size 1) to
    ``head_width`` channels and, when ``stride // head_stride != 1``, through
    an ``UpSampleLayer`` with that factor. The projected maps are summed,
    passed through ``head_depth`` residual blocks and finally mapped to
    ``n_classes`` channels by ``out_layer``.

    Args:
        fid_list: Names of the input features (stored in ``in_keys``).
        in_channel_list: Channel count of each input feature.
        stride_list: Stride of each input feature.
        head_stride: Stride the head operates at; each input is upsampled by
            ``stride // head_stride``.
        head_width: Channel width used inside the head.
        head_depth: Number of residual middle blocks.
        expand_ratio: Expansion ratio passed to the middle blocks.
        middle_op: ``"mbconv"`` or ``"fmbconv"``.
        final_expand: If not ``None``, an extra ``ConvNormAct`` widening to
            ``head_width * final_expand`` channels precedes the classifier.
        n_classes: Number of output channels.
        dropout: Dropout passed to the final ``ConvNormAct``.
        norm: Only ``"bn2d"`` is supported.
        act_func: Only ``"gelu"`` is supported. NOTE: the default
            ``"hswish"`` is therefore rejected; callers must pass ``"gelu"``.

    Raises:
        ValueError: If ``act_func`` is not ``"gelu"`` or ``norm`` is not
            ``"bn2d"``.
        NotImplementedError: If ``head_depth > 0`` and ``middle_op`` is
            neither ``"mbconv"`` nor ``"fmbconv"``.
    """

    def __init__(
        self,
        fid_list: list[str],
        in_channel_list: list[int],
        stride_list: list[int],
        head_stride: int,
        head_width: int,
        head_depth: int,
        expand_ratio: float,
        middle_op: str,
        final_expand: Optional[float],
        n_classes: int,
        dropout=0,
        norm="bn2d",
        act_func="hswish",
    ):
        super().__init__()
        # exceptions to adapt effvit to timm
        if act_func == "gelu":
            act_layer = GELUTanh
        else:
            raise ValueError(f"act_func {act_func} not supported")
        if norm == "bn2d":
            norm_layer = nn.BatchNorm2d
        else:
            raise ValueError(f"norm {norm} not supported")

        # Per-input projection to head_width, plus upsampling where needed.
        inputs = {}
        for fid, in_channel, stride in zip(fid_list, in_channel_list, stride_list):
            factor = stride // head_stride
            projection = ConvNormAct(
                in_channel, head_width, 1, norm_layer=norm_layer, act_layer=act_layer
            )
            if factor == 1:
                inputs[fid] = projection
            else:
                inputs[fid] = nn.Sequential(projection, UpSampleLayer(factor=factor))
        # A plain list rather than a dict_keys view, which cannot be pickled.
        self.in_keys = list(inputs.keys())
        self.in_ops = nn.ModuleList(inputs.values())

        middle = []
        for _ in range(head_depth):
            if middle_op == "mbconv":
                block = MBConv(
                    head_width,
                    head_width,
                    expand_ratio=expand_ratio,
                    norm_layer=norm_layer,
                    act_layer=(act_layer, act_layer, None),
                )
            elif middle_op == "fmbconv":
                block = FusedMBConv(
                    head_width,
                    head_width,
                    expand_ratio=expand_ratio,
                    norm_layer=norm_layer,
                    act_layer=(act_layer, None),
                )
            else:
                raise NotImplementedError
            middle.append(ResidualBlock(block, nn.Identity()))
        self.middle = nn.Sequential(*middle)

        # Build the output stack without a None placeholder: nn.Sequential
        # would register None as a child and then fail when calling it.
        out_ops = []
        if final_expand is not None:
            out_ops.append(
                ConvNormAct(
                    head_width,
                    head_width * final_expand,
                    1,
                    norm_layer=norm_layer,
                    act_layer=act_layer,
                )
            )
        out_ops.append(
            ConvNormAct(
                head_width * (final_expand or 1),
                n_classes,
                1,
                bias=True,
                dropout=dropout,
                norm_layer=None,
                act_layer=None,
            )
        )
        self.out_layer = nn.Sequential(*out_ops)

    def forward(self, feature_map_list):
        """Fuse the feature maps and return the ``n_classes``-channel output.

        Args:
            feature_map_list: Feature maps in the same order as ``fid_list``;
                the i-th map is fed to the i-th input op.
        """
        projected = [
            self.in_ops[idx](feature_map)
            for idx, feature_map in enumerate(feature_map_list)
        ]
        fused = list_sum(projected)
        fused = self.middle(fused)
        return self.out_layer(fused)
class EfficientViT_l1_r224(nn.Module):
    """``efficientvit_l1`` feature backbone followed by a ``SegHead`` decoder.

    Args:
        out_channels: Number of output channels (``n_classes`` of the head).
        out_ds_factor: Passed to the head as ``head_stride``.
        decoder_size: ``"small"``, ``"medium"`` or ``"large"``; selects the
            head width, depth and middle block type.
        pretrained: Forwarded to ``efficientvit_l1``.
        use_norm_params: If false, the affine parameters (``weight``/``bias``)
            of every LayerNorm/BatchNorm2d/BatchNorm1d in the model are
            frozen via ``requires_grad_(False)``.

    Raises:
        ValueError: If ``decoder_size`` is not one of the supported names.
    """

    # decoder_size -> (head_width, head_depth, middle_op)
    _DECODER_CONFIGS = {
        "small": (32, 1, "mbconv"),
        "medium": (64, 3, "mbconv"),
        "large": (256, 3, "fmbconv"),
    }

    def __init__(
        self,
        out_channels,
        out_ds_factor=1,
        decoder_size="small",
        pretrained=False,
        use_norm_params=False,
    ):
        # Validate up front so an unknown size gives a clear error instead of
        # an UnboundLocalError further down.
        if decoder_size not in self._DECODER_CONFIGS:
            raise ValueError(
                f"decoder_size {decoder_size!r} not supported; "
                f"expected one of {sorted(self._DECODER_CONFIGS)}"
            )
        head_width, head_depth, middle_op = self._DECODER_CONFIGS[decoder_size]

        super().__init__()
        self.bbone = efficientvit_l1(
            num_classes=0, features_only=True, pretrained=pretrained
        )
        # NOTE(review): the channel/stride lists are hard-coded; presumably
        # they match the last three feature maps of efficientvit_l1 — verify
        # against timm if the backbone changes.
        self.head = SegHead(
            fid_list=["stage4", "stage3", "stage2"],
            in_channel_list=[512, 256, 128],
            stride_list=[32, 16, 8],
            head_stride=out_ds_factor,
            head_width=head_width,
            head_depth=head_depth,
            expand_ratio=4,
            middle_op=middle_op,
            final_expand=8,
            n_classes=out_channels,
            act_func="gelu",
        )
        # [optional] deactivate normalization
        # Only the affine parameters are frozen here; nothing in this block
        # changes the layers' train/eval mode.
        if not use_norm_params:
            norm_types = (nn.LayerNorm, nn.BatchNorm2d, nn.BatchNorm1d)
            for module in self.modules():
                if not isinstance(module, norm_types):
                    continue
                # weight/bias are None for non-affine norm layers.
                for param in (module.weight, module.bias):
                    if param is not None:
                        param.requires_grad_(False)

    def forward(self, x):
        """Run the backbone and decode its feature maps 3, 2 and 1 (in that order)."""
        feat = self.bbone(x)
        return self.head([feat[3], feat[2], feat[1]])