Spaces:
Sleeping
Sleeping
import torch | |
from torch import nn | |
from torch.nn import functional as F | |
from torchvision.models.segmentation.deeplabv3 import ASPP | |
from .decoder import Decoder | |
from .mobilenet import MobileNetV2Encoder | |
from .refiner import Refiner | |
from .resnet import ResNetEncoder | |
from .utils import load_matched_state_dict | |
class Base(nn.Module):
    """
    A generic implementation of the base encoder-decoder network inspired by DeepLab.
    Accepts arbitrary channels for input and output.

    Args:
        backbone: one of ["resnet50", "resnet101", "mobilenetv2"].
        in_channels: channel count handed to the backbone encoder and to the
            decoder's second channel list.
        out_channels: last entry of the decoder's first channel list.
    """

    def __init__(self, backbone: str, in_channels: int, out_channels: int):
        super().__init__()
        assert backbone in ["resnet50", "resnet101", "mobilenetv2"], \
            f'unsupported backbone: {backbone!r}'
        if backbone in ['resnet50', 'resnet101']:
            # ResNet variants: ASPP is built with 2048 as its first argument.
            self.backbone = ResNetEncoder(in_channels, variant=backbone)
            self.aspp = ASPP(2048, [3, 6, 9])
            self.decoder = Decoder([256, 128, 64, 48, out_channels], [512, 256, 64, in_channels])
        else:
            # MobileNetV2: ASPP is built with 320 as its first argument.
            self.backbone = MobileNetV2Encoder(in_channels)
            self.aspp = ASPP(320, [3, 6, 9])
            self.decoder = Decoder([256, 128, 64, 48, out_channels], [32, 24, 16, in_channels])

    def forward(self, x):
        """
        Run the backbone, then ASPP on its first output, then the decoder.

        The backbone must return an iterable whose first item is fed to ASPP;
        all remaining items are passed to the decoder as shortcut inputs.
        """
        x, *shortcuts = self.backbone(x)
        x = self.aspp(x)
        x = self.decoder(x, *shortcuts)
        return x

    def load_pretrained_deeplabv3_state_dict(self, state_dict, print_stats=True):
        """
        Convert and load a pretrained DeepLabV3 state_dict into this model.

        Pretrained DeepLabV3 models are provided by
        <https://github.com/VainF/DeepLabV3Plus-Pytorch>. This method renames
        their keys to match our model structure before loading. It is not
        needed if you are not planning to train from deeplab weights; use
        load_state_dict() for normal weight loading.

        Args:
            state_dict: mapping of parameter names to values in the DeepLabV3 layout.
            print_stats: forwarded unchanged to load_matched_state_dict.
        """
        # Convert state_dict naming for aspp module.
        state_dict = {k.replace('classifier.classifier.0', 'aspp'): v for k, v in state_dict.items()}

        if isinstance(self.backbone, ResNetEncoder):
            # ResNet backbone does not need change.
            load_matched_state_dict(self, state_dict, print_stats)
            return

        # Change MobileNetV2 backbone to state_dict format, then change back after loading.
        # NOTE(review): assumes self.backbone.features supports slicing (e.g. nn.Sequential).
        backbone_features = self.backbone.features
        self.backbone.low_level_features = backbone_features[:4]
        self.backbone.high_level_features = backbone_features[4:]
        del self.backbone.features
        try:
            load_matched_state_dict(self, state_dict, print_stats)
        finally:
            # Always restore the original layout, even if loading raised, so the
            # model is never left without its `features` attribute.
            self.backbone.features = backbone_features
            del self.backbone.low_level_features
            del self.backbone.high_level_features
class MattingBase(Base):
    """
    Coarse matting network: predicts global results at a lower resolution.
    Built on top of Base with 6 input channels (source + background) and
    1 + 3 + 1 + 32 output channels.

    Args:
        backbone: ["resnet50", "resnet101", "mobilenetv2"]

    Input:
        src: (B, 3, H, W) the source image. Channels are RGB values normalized to 0 ~ 1.
        bgr: (B, 3, H, W) the background image. Channels are RGB values normalized to 0 ~ 1.

    Output:
        pha: (B, 1, H, W) the alpha prediction. Normalized to 0 ~ 1.
        fgr: (B, 3, H, W) the foreground prediction. Channels are RGB values normalized to 0 ~ 1.
        err: (B, 1, H, W) the error prediction. Normalized to 0 ~ 1.
        hid: (B, 32, H, W) the hidden encoding. Used for connecting refiner module.

    Example:
        model = MattingBase(backbone='resnet50')

        pha, fgr, err, hid = model(src, bgr)    # for training
        pha, fgr = model(src, bgr)[:2]          # for inference
    """

    def __init__(self, backbone: str):
        # Output channels: alpha (1) + foreground residual (3) + error (1) + hidden (32).
        super().__init__(backbone, in_channels=6, out_channels=(1 + 3 + 1 + 32))

    def forward(self, src, bgr):
        # Source and background are stacked along the channel axis.
        feats, *skips = self.backbone(torch.cat([src, bgr], dim=1))
        out = self.decoder(self.aspp(feats), *skips)

        # Split the decoder output by channel. The slices are views of `out`,
        # so the in-place clamps/relu write back into it.
        pha = out[:, 0:1].clamp_(0., 1.)
        # The foreground channels are added to src (out-of-place) before clamping.
        fgr = (out[:, 1:4] + src).clamp_(0., 1.)
        err = out[:, 4:5].clamp_(0., 1.)
        hid = out[:, 5:].relu_()
        return pha, fgr, err, hid
class MattingRefine(MattingBase):
    """
    Full matting network: MattingBase run on a downsampled input, followed by a
    refiner module that upsamples the coarse result to full resolution.

    Args:
        backbone: ["resnet50", "resnet101", "mobilenetv2"]
        backbone_scale: The image downsample scale for passing through backbone, default 1/4 or 0.25.
                        Must not be greater than 1/2.
        refine_mode: refine area selection mode. Options:
            "full"         - No area selection, refine everywhere using regular Conv2d.
            "sampling"     - Refine fixed amount of pixels ranked by the top most errors.
            "thresholding" - Refine varying amount of pixels that has more error than the threshold.
        refine_sample_pixels: number of pixels to refine. Only used when mode == "sampling".
        refine_threshold: error threshold ranged from 0 ~ 1. Refine where err > threshold.
                          Only used when mode == "thresholding".
        refine_kernel_size: the refiner's convolutional kernel size. Options: [1, 3]
        refine_prevent_oversampling: prevent sampling more pixels than needed for sampling mode.
                                     Set False only for speedtest.
        refine_patch_crop_method: forwarded unchanged to Refiner (default 'unfold').
        refine_patch_replace_method: forwarded unchanged to Refiner (default 'scatter_nd').

    Input:
        src: (B, 3, H, W) the source image. Channels are RGB values normalized to 0 ~ 1.
        bgr: (B, 3, H, W) the background image. Channels are RGB values normalized to 0 ~ 1.

    Output:
        pha: (B, 1, H, W) the alpha prediction. Normalized to 0 ~ 1.
        fgr: (B, 3, H, W) the foreground prediction. Channels are RGB values normalized to 0 ~ 1.
        pha_sm: (B, 1, Hc, Wc) the coarse alpha prediction from matting base. Normalized to 0 ~ 1.
        fgr_sm: (B, 3, Hc, Wc) the coarse foreground prediction from matting base. Normalized to 0 ~ 1.
        err_sm: (B, 1, Hc, Wc) the coarse error prediction from matting base. Normalized to 0 ~ 1.
        ref_sm: (B, 1, H/4, W/4) the quarter resolution refinement map. 1 indicates refined 4x4 patch locations.

    Example:
        model = MattingRefine(backbone='resnet50', backbone_scale=1/4, refine_mode='sampling', refine_sample_pixels=80_000)
        model = MattingRefine(backbone='resnet50', backbone_scale=1/4, refine_mode='thresholding', refine_threshold=0.1)
        model = MattingRefine(backbone='resnet50', backbone_scale=1/4, refine_mode='full')

        pha, fgr, pha_sm, fgr_sm, err_sm, ref_sm = model(src, bgr)   # for training
        pha, fgr = model(src, bgr)[:2]                               # for inference
    """

    def __init__(self,
                 backbone: str,
                 backbone_scale: float = 1/4,
                 refine_mode: str = 'sampling',
                 refine_sample_pixels: int = 80_000,
                 refine_threshold: float = 0.1,
                 refine_kernel_size: int = 3,
                 refine_prevent_oversampling: bool = True,
                 refine_patch_crop_method: str = 'unfold',
                 refine_patch_replace_method: str = 'scatter_nd'):
        # Validate the scale before any sub-module is constructed.
        assert backbone_scale <= 1/2, 'backbone_scale should not be greater than 1/2'
        super().__init__(backbone)
        self.backbone_scale = backbone_scale
        # All refine_* options are handed to the Refiner positionally, in declaration order.
        self.refiner = Refiner(
            refine_mode, refine_sample_pixels, refine_threshold, refine_kernel_size,
            refine_prevent_oversampling, refine_patch_crop_method, refine_patch_replace_method)

    def forward(self, src, bgr):
        assert src.size() == bgr.size(), 'src and bgr must have the same shape'
        height, width = src.size(2), src.size(3)
        assert height % 4 == 0 and width % 4 == 0, \
            'src and bgr must have width and height that are divisible by 4'

        # Downsample src and bgr for the backbone.
        scale = self.backbone_scale
        src_sm = F.interpolate(src, scale_factor=scale, mode='bilinear',
                               align_corners=False, recompute_scale_factor=True)
        bgr_sm = F.interpolate(bgr, scale_factor=scale, mode='bilinear',
                               align_corners=False, recompute_scale_factor=True)

        # Coarse pass: backbone -> ASPP -> decoder on the stacked low-res inputs.
        feats, *skips = self.backbone(torch.cat([src_sm, bgr_sm], dim=1))
        coarse = self.decoder(self.aspp(feats), *skips)

        # Channel slices are views of `coarse`; clamps/relu are applied in place.
        pha_sm = coarse[:, 0:1].clamp_(0., 1.)
        # The foreground channels stay unclamped here: they go to the refiner
        # as-is and are only added to src_sm afterwards.
        fgr_sm = coarse[:, 1:4]
        err_sm = coarse[:, 4:5].clamp_(0., 1.)
        hid_sm = coarse[:, 5:].relu_()

        # Refinement at full resolution.
        pha, fgr, ref_sm = self.refiner(src, bgr, pha_sm, fgr_sm, err_sm, hid_sm)

        # Add the source images to the foreground outputs and clamp to 0 ~ 1.
        pha = pha.clamp_(0., 1.)
        fgr = fgr.add_(src).clamp_(0., 1.)
        fgr_sm = src_sm.add_(fgr_sm).clamp_(0., 1.)

        return pha, fgr, pha_sm, fgr_sm, err_sm, ref_sm