# NOTE: page-header residue from the hosting site ("Spaces: Running"); not part of the module.
import functools | |
import tensorflow as tf | |
from tensorflow.keras import backend as K | |
from tensorflow.keras import layers | |
from .blocks.attentions import SAM | |
from .blocks.bottleneck import BottleneckBlock | |
from .blocks.misc_gating import CrossGatingBlock | |
from .blocks.others import UpSampleRatio | |
from .blocks.unet import UNetDecoderBlock, UNetEncoderBlock | |
from .layers import Resizing | |
# Convolution layer factories shared by the model builder. Each one is a
# `functools.partial` over a Keras layer class with the kernel geometry and
# "same" padding pre-bound; callers supply `filters` plus any remaining
# keyword arguments (e.g. `use_bias`, `name`).

# 1x1 convolution.
Conv1x1 = functools.partial(layers.Conv2D, padding="same", kernel_size=(1, 1))

# 3x3 convolution.
Conv3x3 = functools.partial(layers.Conv2D, padding="same", kernel_size=(3, 3))

# 2x2 transposed convolution with stride 2.
ConvT_up = functools.partial(
    layers.Conv2DTranspose,
    padding="same",
    strides=(2, 2),
    kernel_size=(2, 2),
)

# 4x4 convolution with stride 2.
Conv_down = functools.partial(
    layers.Conv2D,
    padding="same",
    strides=(2, 2),
    kernel_size=(4, 4),
)
def MAXIM(
    features: int = 64,
    depth: int = 3,
    num_stages: int = 2,
    num_groups: int = 1,
    use_bias: bool = True,
    num_supervision_scales: int = 1,
    lrelu_slope: float = 0.2,
    use_global_mlp: bool = True,
    use_cross_gating: bool = True,
    high_res_stages: int = 2,
    block_size_hr=(16, 16),
    block_size_lr=(8, 8),
    grid_size_hr=(16, 16),
    grid_size_lr=(8, 8),
    num_bottleneck_blocks: int = 1,
    block_gmlp_factor: int = 2,
    grid_gmlp_factor: int = 2,
    input_proj_factor: int = 2,
    channels_reduction: int = 4,
    num_outputs: int = 3,
    dropout_rate: float = 0.0,
):
    """Build the MAXIM model function with multi-stage, multi-scale supervision.

    For more model details, please check the CVPR paper:
    MAXIM: Multi-Axis MLP for Image Processing (https://arxiv.org/abs/2201.02973)

    Args:
      features: initial hidden dimension for the input resolution.
      depth: the number of downsampling levels in the model.
      num_stages: how many stages to use. It also affects the output list.
      num_groups: how many blocks each stage contains.
      use_bias: whether to use bias in all the conv/mlp layers.
      num_supervision_scales: the number of desired supervision scales.
      lrelu_slope: the negative slope parameter in leaky_relu layers.
      use_global_mlp: whether to use the multi-axis gated MLP block (MAB) in
        each layer.
      use_cross_gating: whether to use the cross-gating MLP block (CGB) in the
        skip connections and multi-stage feature fusion layers.
      high_res_stages: how many levels are specified as high-res levels. The
        rest (depth - high_res_stages) are treated as low-res levels.
      block_size_hr: the block_size parameter for high-res levels.
      block_size_lr: the block_size parameter for low-res levels and for the
        bottleneck blocks.
      grid_size_hr: the grid_size parameter for high-res levels.
      grid_size_lr: the grid_size parameter for low-res levels and for the
        bottleneck blocks.
      num_bottleneck_blocks: how many bottleneck blocks.
      block_gmlp_factor: the input projection factor for block_gMLP layers.
      grid_gmlp_factor: the input projection factor for grid_gMLP layers.
      input_proj_factor: the input projection factor for the MAB block.
      channels_reduction: the channel reduction factor for SE layer.
      num_outputs: the output channels.
      dropout_rate: Dropout rate.

    Returns:
      A function `apply(x)` that wires the network onto the input tensor `x`
      and returns a list with one entry per stage. Each entry is itself a list
      of the supervised outputs of that stage, appended while decoding from
      the deepest level back to level 0, so the level-0 (input resolution)
      output comes last. For example, with
      num_stages = num_supervision_scales = 3 (the model used in the paper):
        [[stage1_level2, stage1_level1, stage1_level0],
         [stage2_level2, stage2_level1, stage2_level0],
         [stage3_level2, stage3_level1, stage3_level0]]
      The final output can be retrieved by outputs[-1][-1].
    """

    def _window_sizes(level):
        """Return `(block_size, grid_size)` for the given encoder/decoder level."""
        # Larger windows at the high-res levels, smaller ones below.
        if level < high_res_stages:
            return block_size_hr, grid_size_hr
        return block_size_lr, grid_size_lr

    def apply(x):
        # Static spatial size of the input image; it is needed to size the
        # down-scaled copies used for multi-scale supervision.
        input_shape = K.int_shape(x)
        h, w = input_shape[1], input_shape[2]

        # shortcuts[i] is the input image at scale i (0 = original size).
        shortcuts = [x]

        # Get multi-scale input images
        for i in range(1, num_supervision_scales):
            resizing_layer = Resizing(
                height=h // (2 ** i),
                width=w // (2 ** i),
                method="nearest",
                antialias=True,  # Following `jax.image.resize()`.
                name=f"initial_resizing_{K.get_uid('Resizing')}",
            )
            shortcuts.append(resizing_layer(x))

        # store outputs from all stages and all scales
        # Eg, [[(64, 64, 3), (128, 128, 3), (256, 256, 3)],   # Stage-1 outputs
        #      [(64, 64, 3), (128, 128, 3), (256, 256, 3)],]  # Stage-2 outputs
        outputs_all = []
        # Features handed from one stage to the next. Each list is consumed
        # with `pop()`, so it is stored with the level-0 entry last.
        sam_features, encs_prev, decs_prev = [], [], []

        for idx_stage in range(num_stages):
            # Input convolution, get multi-scale input features
            x_scales = []
            for i in range(num_supervision_scales):
                x_scale = Conv3x3(
                    filters=(2 ** i) * features,
                    use_bias=use_bias,
                    name=f"stage_{idx_stage}_input_conv_{i}",
                )(shortcuts[i])

                # If later stages, fuse input features with SAM features from
                # the previous stage.
                if idx_stage > 0:
                    if use_cross_gating:
                        block_size, grid_size = _window_sizes(i)
                        x_scale, _ = CrossGatingBlock(
                            features=(2 ** i) * features,
                            block_size=block_size,
                            grid_size=grid_size,
                            dropout_rate=dropout_rate,
                            input_proj_factor=input_proj_factor,
                            upsample_y=False,
                            use_bias=use_bias,
                            name=f"stage_{idx_stage}_input_fuse_sam_{i}",
                        )(x_scale, sam_features.pop())
                    else:
                        x_scale = Conv1x1(
                            filters=(2 ** i) * features,
                            use_bias=use_bias,
                            name=f"stage_{idx_stage}_input_catconv_{i}",
                        )(tf.concat([x_scale, sam_features.pop()], axis=-1))

                x_scales.append(x_scale)

            # start encoder blocks
            encs = []
            x = x_scales[0]  # First full-scale input feature

            for i in range(depth):  # 0, 1, 2
                block_size, grid_size = _window_sizes(i)
                # Cross-stage gating only makes sense once a previous stage exists.
                use_cross_gating_layer = idx_stage > 0

                # Multi-scale input if multi-scale supervision
                x_scale = x_scales[i] if i < num_supervision_scales else None

                # UNet Encoder block
                enc_prev = encs_prev.pop() if idx_stage > 0 else None
                dec_prev = decs_prev.pop() if idx_stage > 0 else None

                x, bridge = UNetEncoderBlock(
                    num_channels=(2 ** i) * features,
                    num_groups=num_groups,
                    downsample=True,
                    lrelu_slope=lrelu_slope,
                    block_size=block_size,
                    grid_size=grid_size,
                    block_gmlp_factor=block_gmlp_factor,
                    grid_gmlp_factor=grid_gmlp_factor,
                    input_proj_factor=input_proj_factor,
                    channels_reduction=channels_reduction,
                    use_global_mlp=use_global_mlp,
                    dropout_rate=dropout_rate,
                    use_bias=use_bias,
                    use_cross_gating=use_cross_gating_layer,
                    name=f"stage_{idx_stage}_encoder_block_{i}",
                )(x, skip=x_scale, enc=enc_prev, dec=dec_prev)

                # Cache skip signals
                encs.append(bridge)

            # Global MLP bottleneck blocks
            for i in range(num_bottleneck_blocks):
                x = BottleneckBlock(
                    block_size=block_size_lr,
                    grid_size=grid_size_lr,
                    features=(2 ** (depth - 1)) * features,
                    num_groups=num_groups,
                    block_gmlp_factor=block_gmlp_factor,
                    grid_gmlp_factor=grid_gmlp_factor,
                    input_proj_factor=input_proj_factor,
                    dropout_rate=dropout_rate,
                    use_bias=use_bias,
                    channels_reduction=channels_reduction,
                    name=f"stage_{idx_stage}_global_block_{i}",
                )(x)
            # cache global feature for cross-gating
            global_feature = x

            # start cross gating. Use multi-scale feature fusion
            skip_features = []
            for i in reversed(range(depth)):  # 2, 1, 0
                block_size, grid_size = _window_sizes(i)

                # get additional multi-scale signals: encs[j] comes from level
                # j, so it is rescaled by 2 ** (j - i) to match level i.
                signal = tf.concat(
                    [
                        UpSampleRatio(
                            num_channels=(2 ** i) * features,
                            ratio=2 ** (j - i),
                            use_bias=use_bias,
                            name=f"UpSampleRatio_{K.get_uid('UpSampleRatio')}",
                        )(enc)
                        for j, enc in enumerate(encs)
                    ],
                    axis=-1,
                )

                # Use cross-gating to cross modulate features
                if use_cross_gating:
                    skips, global_feature = CrossGatingBlock(
                        features=(2 ** i) * features,
                        block_size=block_size,
                        grid_size=grid_size,
                        input_proj_factor=input_proj_factor,
                        dropout_rate=dropout_rate,
                        upsample_y=True,
                        use_bias=use_bias,
                        name=f"stage_{idx_stage}_cross_gating_block_{i}",
                    )(signal, global_feature)
                else:
                    # Layer names must be unique within a Keras model, so they
                    # carry the stage and level index instead of a fixed name.
                    skips = Conv1x1(
                        filters=(2 ** i) * features,
                        use_bias=use_bias,
                        name=f"stage_{idx_stage}_skip_fuse_conv1x1_{i}",
                    )(signal)
                    skips = Conv3x3(
                        filters=(2 ** i) * features,
                        use_bias=use_bias,
                        name=f"stage_{idx_stage}_skip_fuse_conv3x3_{i}",
                    )(skips)

                skip_features.append(skips)

            # start decoder. Multi-scale feature fusion of cross-gated features
            outputs, decs, sam_features = [], [], []
            for i in reversed(range(depth)):
                block_size, grid_size = _window_sizes(i)

                # get multi-scale skip signals from cross-gating block:
                # skip_features[j] comes from level (depth - 1 - j).
                signal = tf.concat(
                    [
                        UpSampleRatio(
                            num_channels=(2 ** i) * features,
                            ratio=2 ** (depth - j - 1 - i),
                            use_bias=use_bias,
                            name=f"UpSampleRatio_{K.get_uid('UpSampleRatio')}",
                        )(skip)
                        for j, skip in enumerate(skip_features)
                    ],
                    axis=-1,
                )

                # Decoder block
                x = UNetDecoderBlock(
                    num_channels=(2 ** i) * features,
                    num_groups=num_groups,
                    lrelu_slope=lrelu_slope,
                    block_size=block_size,
                    grid_size=grid_size,
                    block_gmlp_factor=block_gmlp_factor,
                    grid_gmlp_factor=grid_gmlp_factor,
                    input_proj_factor=input_proj_factor,
                    channels_reduction=channels_reduction,
                    use_global_mlp=use_global_mlp,
                    dropout_rate=dropout_rate,
                    use_bias=use_bias,
                    name=f"stage_{idx_stage}_decoder_block_{i}",
                )(x, bridge=signal)

                # Cache decoder features for later-stage's usage
                decs.append(x)

                # output conv, if not final stage, use supervised-attention-block.
                if i < num_supervision_scales:
                    if idx_stage < num_stages - 1:  # not last stage, apply SAM
                        sam, output = SAM(
                            num_channels=(2 ** i) * features,
                            output_channels=num_outputs,
                            use_bias=use_bias,
                            name=f"stage_{idx_stage}_supervised_attention_module_{i}",
                        )(x, shortcuts[i])
                        outputs.append(output)
                        sam_features.append(sam)
                    else:  # Last stage, apply output convolutions
                        output = Conv3x3(
                            num_outputs,
                            use_bias=use_bias,
                            name=f"stage_{idx_stage}_output_conv_{i}",
                        )(x)
                        # Residual connection to the (rescaled) input image.
                        output = output + shortcuts[i]
                        outputs.append(output)

            # Cache encoder and decoder features for later-stage's usage.
            # Encoder bridges are reversed so that `pop()` yields level 0 first;
            # decoder features were already appended deepest-first.
            encs_prev = encs[::-1]
            decs_prev = decs

            # Store outputs
            outputs_all.append(outputs)
        return outputs_all

    return apply