Spaces:
Running
Running
File size: 2,769 Bytes
033bd8b 3ce3e4c 033bd8b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 |
import torch.nn as nn
from .hrnetv2.hrnet_ocr import HighResolutionNet
from .hrnetv2.modifiers import LRMult
from .base.basic_blocks import MaxPoolDownSize
from .base.ih_model import IHModelWithBackbone, DeepImageHarmonization
def build_backbone(name, opt):
return eval(name)(opt)
class baseline(IHModelWithBackbone):
def __init__(self, opt, ocr=64):
base_config = {'model': DeepImageHarmonization,
'params': {'depth': 7, 'batchnorm_from': 2, 'image_fusion': True, 'opt': opt}}
params = base_config['params']
backbone = HRNetV2(opt, ocr=ocr)
params.update(dict(
backbone_from=2,
backbone_channels=backbone.output_channels,
backbone_mode='cat',
opt=opt
))
base_model = base_config['model'](**params)
super(baseline, self).__init__(base_model, backbone, False, 'sum', opt=opt)
class HRNetV2(nn.Module):
def __init__(
self, opt,
cat_outputs=True,
pyramid_channels=-1, pyramid_depth=4,
width=18, ocr=128, small=False,
lr_mult=0.1, pretained=True
):
super(HRNetV2, self).__init__()
self.opt = opt
self.cat_outputs = cat_outputs
self.ocr_on = ocr > 0 and cat_outputs
self.pyramid_on = pyramid_channels > 0 and cat_outputs
self.hrnet = HighResolutionNet(width, 2, ocr_width=ocr, small=small, opt=opt)
self.hrnet.apply(LRMult(lr_mult))
if self.ocr_on:
self.hrnet.ocr_distri_head.apply(LRMult(1.0))
self.hrnet.ocr_gather_head.apply(LRMult(1.0))
self.hrnet.conv3x3_ocr.apply(LRMult(1.0))
hrnet_cat_channels = [width * 2 ** i for i in range(4)]
if self.pyramid_on:
self.output_channels = [pyramid_channels] * 4
elif self.ocr_on:
self.output_channels = [ocr * 2]
elif self.cat_outputs:
self.output_channels = [sum(hrnet_cat_channels)]
else:
self.output_channels = hrnet_cat_channels
if self.pyramid_on:
downsize_in_channels = ocr * 2 if self.ocr_on else sum(hrnet_cat_channels)
self.downsize = MaxPoolDownSize(downsize_in_channels, pyramid_channels, pyramid_channels, pyramid_depth)
if pretained:
self.load_pretrained_weights(
"./pretrained_models/hrnetv2_w18_imagenet_pretrained.pth")
self.output_resolution = (opt.input_size // 8) ** 2
def forward(self, image, mask, mask_features=None):
outputs = list(self.hrnet(image, mask, mask_features))
return outputs
def load_pretrained_weights(self, pretrained_path):
self.hrnet.load_pretrained_weights(pretrained_path)
|