import torch
from torch import nn
from torch.nn import functional as F
from torchvision.models.segmentation.deeplabv3 import ASPP

from .decoder import Decoder
from .mobilenet import MobileNetV2Encoder
from .refiner import Refiner
from .resnet import ResNetEncoder
from .utils import load_matched_state_dict


class Base(nn.Module):
    """
    A generic implementation of the base encoder-decoder network inspired by DeepLab.
    Accepts arbitrary channels for input and output.

    Args:
        backbone: one of ["resnet50", "resnet101", "mobilenetv2"].
        in_channels: number of channels of the network input.
        out_channels: number of channels produced by the decoder.
    """

    def __init__(self, backbone: str, in_channels: int, out_channels: int):
        super().__init__()
        assert backbone in ["resnet50", "resnet101", "mobilenetv2"], \
            f'unsupported backbone: {backbone}'

        if backbone in ['resnet50', 'resnet101']:
            self.backbone = ResNetEncoder(in_channels, variant=backbone)
            self.aspp = ASPP(2048, [3, 6, 9])
            # NOTE(review): the second list presumably holds the channel counts of the
            # shortcut tensors returned by the encoder (the last one is the raw input);
            # confirm against Decoder.
            self.decoder = Decoder([256, 128, 64, 48, out_channels], [512, 256, 64, in_channels])
        else:
            self.backbone = MobileNetV2Encoder(in_channels)
            self.aspp = ASPP(320, [3, 6, 9])
            self.decoder = Decoder([256, 128, 64, 48, out_channels], [32, 24, 16, in_channels])

    def forward(self, x):
        """Run backbone -> ASPP -> decoder on ``x`` and return the decoder output."""
        x, *shortcuts = self.backbone(x)
        x = self.aspp(x)
        x = self.decoder(x, *shortcuts)
        return x

    def load_pretrained_deeplabv3_state_dict(self, state_dict, print_stats=True):
        """
        Convert a pretrained DeepLabV3 state_dict to this model's structure and load it.

        This method is only needed when training from DeepLab weights; use
        ``load_state_dict()`` for normal weight loading.

        NOTE(review): the original comment naming the source of the pretrained weights
        was truncated. The key layout ('classifier.classifier.0', low/high_level_features)
        presumably matches https://github.com/VainF/DeepLabV3Plus-Pytorch -- confirm.

        Args:
            state_dict: the pretrained DeepLabV3 state_dict.
            print_stats: forwarded to ``load_matched_state_dict``.
        """
        # Convert state_dict naming for the aspp module.
        state_dict = {k.replace('classifier.classifier.0', 'aspp'): v for k, v in state_dict.items()}

        if isinstance(self.backbone, ResNetEncoder):
            # ResNet backbone does not need change.
            load_matched_state_dict(self, state_dict, print_stats)
            return

        # Temporarily reshape the MobileNetV2 backbone into the state_dict's
        # low/high-level split so the keys match, then restore it. The restore runs
        # in `finally` so that a failed load does not leave the backbone without
        # its `features` module.
        backbone_features = self.backbone.features
        self.backbone.low_level_features = backbone_features[:4]
        self.backbone.high_level_features = backbone_features[4:]
        del self.backbone.features
        try:
            load_matched_state_dict(self, state_dict, print_stats)
        finally:
            self.backbone.features = backbone_features
            del self.backbone.low_level_features
            del self.backbone.high_level_features


class MattingBase(Base):
    """
    MattingBase is used to produce coarse global results at a lower resolution.
    MattingBase extends Base.

    Args:
        backbone: ["resnet50", "resnet101", "mobilenetv2"]

    Input:
        src: (B, 3, H, W) the source image. Channels are RGB values normalized to 0 ~ 1.
        bgr: (B, 3, H, W) the background image. Channels are RGB values normalized to 0 ~ 1.

    Output:
        pha: (B, 1, H, W) the alpha prediction. Normalized to 0 ~ 1.
        fgr: (B, 3, H, W) the foreground prediction. Channels are RGB values normalized to 0 ~ 1.
        err: (B, 1, H, W) the error prediction. Normalized to 0 ~ 1.
        hid: (B, 32, H, W) the hidden encoding. Used for connecting refiner module.

    Example:
        model = MattingBase(backbone='resnet50')

        pha, fgr, err, hid = model(src, bgr)    # for training
        pha, fgr = model(src, bgr)[:2]          # for inference
    """

    def __init__(self, backbone: str):
        # Output channels: 1 alpha + 3 foreground residual + 1 error + 32 hidden.
        super().__init__(backbone, in_channels=6, out_channels=(1 + 3 + 1 + 32))

    def forward(self, src, bgr):
        x = Base.forward(self, torch.cat([src, bgr], dim=1))
        pha = x[:, 0:1].clamp_(0., 1.)
        # The decoder predicts the foreground as a residual on top of src.
        fgr = x[:, 1:4].add(src).clamp_(0., 1.)
        err = x[:, 4:5].clamp_(0., 1.)
        hid = x[:, 5:].relu_()
        return pha, fgr, err, hid


class MattingRefine(MattingBase):
    """
    MattingRefine includes the refiner module to upsample coarse result to full resolution.
    MattingRefine extends MattingBase.

    Args:
        backbone: ["resnet50", "resnet101", "mobilenetv2"]
        backbone_scale: The image downsample scale for passing through backbone, default 1/4 or 0.25.
                        Must not be greater than 1/2.
        refine_mode: refine area selection mode. Options:
            "full"         - No area selection, refine everywhere using regular Conv2d.
            "sampling"     - Refine fixed amount of pixels ranked by the top most errors.
            "thresholding" - Refine varying amount of pixels that have more error than the threshold.
        refine_sample_pixels: number of pixels to refine. Only used when mode == "sampling".
        refine_threshold: error threshold ranged from 0 ~ 1. Refine where err > threshold. Only used when mode == "thresholding".
        refine_kernel_size: the refiner's convolutional kernel size. Options: [1, 3]
        refine_prevent_oversampling: prevent sampling more pixels than needed for sampling mode. Set False only for speedtest.
        refine_patch_crop_method: forwarded to Refiner; default 'unfold'.
        refine_patch_replace_method: forwarded to Refiner; default 'scatter_nd'.

    Input:
        src: (B, 3, H, W) the source image. Channels are RGB values normalized to 0 ~ 1.
        bgr: (B, 3, H, W) the background image. Channels are RGB values normalized to 0 ~ 1.
        H and W must both be divisible by 4.

    Output:
        pha: (B, 1, H, W) the alpha prediction. Normalized to 0 ~ 1.
        fgr: (B, 3, H, W) the foreground prediction. Channels are RGB values normalized to 0 ~ 1.
        pha_sm: (B, 1, Hc, Wc) the coarse alpha prediction from matting base. Normalized to 0 ~ 1.
        fgr_sm: (B, 3, Hc, Wc) the coarse foreground prediction from matting base. Normalized to 0 ~ 1.
        err_sm: (B, 1, Hc, Wc) the coarse error prediction from matting base. Normalized to 0 ~ 1.
        ref_sm: (B, 1, H/4, W/4) the quarter resolution refinement map. 1 indicates refined 4x4 patch locations.

    Example:
        model = MattingRefine(backbone='resnet50', backbone_scale=1/4, refine_mode='sampling', refine_sample_pixels=80_000)
        model = MattingRefine(backbone='resnet50', backbone_scale=1/4, refine_mode='thresholding', refine_threshold=0.1)
        model = MattingRefine(backbone='resnet50', backbone_scale=1/4, refine_mode='full')

        pha, fgr, pha_sm, fgr_sm, err_sm, ref_sm = model(src, bgr)   # for training
        pha, fgr = model(src, bgr)[:2]                               # for inference
    """

    def __init__(self,
                 backbone: str,
                 backbone_scale: float = 1/4,
                 refine_mode: str = 'sampling',
                 refine_sample_pixels: int = 80_000,
                 refine_threshold: float = 0.1,
                 refine_kernel_size: int = 3,
                 refine_prevent_oversampling: bool = True,
                 refine_patch_crop_method: str = 'unfold',
                 refine_patch_replace_method: str = 'scatter_nd'):
        assert backbone_scale <= 1/2, 'backbone_scale should not be greater than 1/2'
        super().__init__(backbone)
        self.backbone_scale = backbone_scale
        self.refiner = Refiner(refine_mode,
                               refine_sample_pixels,
                               refine_threshold,
                               refine_kernel_size,
                               refine_prevent_oversampling,
                               refine_patch_crop_method,
                               refine_patch_replace_method)

    def forward(self, src, bgr):
        assert src.size() == bgr.size(), 'src and bgr must have the same shape'
        assert src.size(2) % 4 == 0 and src.size(3) % 4 == 0, \
            'src and bgr must have width and height that are divisible by 4'

        # Downsample src and bgr for backbone
        src_sm = F.interpolate(src,
                               scale_factor=self.backbone_scale,
                               mode='bilinear',
                               align_corners=False,
                               recompute_scale_factor=True)
        bgr_sm = F.interpolate(bgr,
                               scale_factor=self.backbone_scale,
                               mode='bilinear',
                               align_corners=False,
                               recompute_scale_factor=True)

        # Base network. Base.forward is called explicitly because MattingBase.forward
        # takes (src, bgr) and post-processes its outputs differently.
        x = Base.forward(self, torch.cat([src_sm, bgr_sm], dim=1))
        pha_sm = x[:, 0:1].clamp_(0., 1.)
        # Passed to the refiner unclamped and without src_sm added; src_sm is added below.
        fgr_sm = x[:, 1:4]
        err_sm = x[:, 4:5].clamp_(0., 1.)
        hid_sm = x[:, 5:].relu_()

        # Refiner
        pha, fgr, ref_sm = self.refiner(src, bgr, pha_sm, fgr_sm, err_sm, hid_sm)

        # Clamp outputs. src_sm is no longer needed, so it is reused in place
        # as the buffer for the coarse foreground.
        pha = pha.clamp_(0., 1.)
        fgr = fgr.add_(src).clamp_(0., 1.)
        fgr_sm = src_sm.add_(fgr_sm).clamp_(0., 1.)

        return pha, fgr, pha_sm, fgr_sm, err_sm, ref_sm