File size: 2,769 Bytes
033bd8b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3ce3e4c
033bd8b
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import torch.nn as nn

from .hrnetv2.hrnet_ocr import HighResolutionNet
from .hrnetv2.modifiers import LRMult
from .base.basic_blocks import MaxPoolDownSize
from .base.ih_model import IHModelWithBackbone, DeepImageHarmonization


def build_backbone(name, opt):
    """Instantiate the backbone whose class name is given as a string.

    The name is resolved against this module's global namespace and the
    resulting callable is invoked with ``opt`` as its only argument.  A dotted
    name (``"pkg.Class"``) is resolved by looking up the first segment in the
    module globals and following the remaining segments as attributes.

    The previous implementation passed ``name`` straight to ``eval``, which
    executes arbitrary expressions.  Only plain (optionally dotted)
    identifiers are accepted now, so a name coming from a config file or the
    command line cannot run code.

    Args:
        name: Name of a class/factory visible in this module, e.g.
            ``"baseline"``.
        opt: Options object forwarded unchanged to the factory.

    Returns:
        Whatever the resolved factory returns for ``opt``.

    Raises:
        TypeError: If ``name`` is not a string.
        NameError: If ``name`` is not a valid identifier path or its first
            segment is not defined in this module (same exception type that
            ``eval`` raised for an unknown name).
        AttributeError: If a dotted segment does not exist on its parent.
    """
    if not isinstance(name, str):
        raise TypeError(
            f"backbone name must be a string, got {type(name).__name__}")

    # eval() tolerated surrounding whitespace; keep accepting it.
    head, *attrs = name.strip().split('.')
    if not all(part.isidentifier() for part in (head, *attrs)):
        raise NameError(f"invalid backbone name: {name!r}")

    try:
        factory = globals()[head]
    except KeyError:
        raise NameError(f"name {head!r} is not defined") from None

    for attr in attrs:
        factory = getattr(factory, attr)
    return factory(opt)


class baseline(IHModelWithBackbone):
    """Harmonization model combining ``DeepImageHarmonization`` with an HRNetV2 backbone.

    Args:
        opt: Options object passed to the backbone, the harmonization model
            and the parent class.
        ocr: Value forwarded as the ``ocr`` argument of ``HRNetV2``.
    """

    def __init__(self, opt, ocr=64):
        # The backbone is created before the harmonization model because the
        # latter needs to know the backbone's output channel layout.
        hrnet_backbone = HRNetV2(opt, ocr=ocr)

        harmonizer_kwargs = dict(
            depth=7,
            batchnorm_from=2,
            image_fusion=True,
            opt=opt,
            backbone_from=2,
            backbone_channels=hrnet_backbone.output_channels,
            backbone_mode='cat',
        )
        harmonizer = DeepImageHarmonization(**harmonizer_kwargs)

        # NOTE(review): the positional ``False`` and ``'sum'`` are defined by
        # IHModelWithBackbone's signature, which is not visible here -- confirm
        # their meaning against that class.
        super().__init__(harmonizer, hrnet_backbone, False, 'sum', opt=opt)


class HRNetV2(nn.Module):
    """Backbone module wrapping ``HighResolutionNet``.

    Args:
        opt: Options object; ``opt.input_size`` is read here and ``opt`` is
            also forwarded to ``HighResolutionNet``.
        cat_outputs: Gates both the OCR and the pyramid modes and selects how
            ``output_channels`` is reported (see below).
        pyramid_channels: When positive (and ``cat_outputs`` is true) enables
            pyramid mode with this many channels per level.
        pyramid_depth: Depth argument passed to ``MaxPoolDownSize``.
        width: Width passed to ``HighResolutionNet``; the four branch widths
            are taken to be ``width * 2 ** i`` for ``i`` in ``0..3``.
        ocr: OCR width passed to ``HighResolutionNet``; OCR mode is on when it
            is positive and ``cat_outputs`` is true.
        small: Forwarded to ``HighResolutionNet``.
        lr_mult: Value given to ``LRMult`` and applied to the whole network
            (the OCR sub-modules are reset to ``LRMult(1.0)`` in OCR mode).
        pretained: When true, load pretrained weights during construction.
            (The misspelling is kept for backward compatibility.)
        pretrained_path: Checkpoint loaded when ``pretained`` is true.  The
            default is the w18 ImageNet checkpoint that used to be
            hard-coded; pass another path when ``width`` is not 18.

    Attributes set here:
        output_channels: ``[pyramid_channels] * 4`` in pyramid mode,
            ``[ocr * 2]`` in OCR mode, ``[sum of branch widths]`` when only
            ``cat_outputs`` is set, otherwise the list of branch widths.
        output_resolution: ``(opt.input_size // 8) ** 2``.
    """

    def __init__(
            self, opt,
            cat_outputs=True,
            pyramid_channels=-1, pyramid_depth=4,
            width=18, ocr=128, small=False,
            lr_mult=0.1, pretained=True,
            pretrained_path="./pretrained_models/hrnetv2_w18_imagenet_pretrained.pth"
    ):
        super(HRNetV2, self).__init__()
        self.opt = opt
        self.cat_outputs = cat_outputs
        self.ocr_on = ocr > 0 and cat_outputs
        self.pyramid_on = pyramid_channels > 0 and cat_outputs

        self.hrnet = HighResolutionNet(width, 2, ocr_width=ocr, small=small, opt=opt)
        # Apply lr_mult everywhere first, then restore 1.0 on the OCR
        # sub-modules when OCR mode is active.
        self.hrnet.apply(LRMult(lr_mult))
        if self.ocr_on:
            self.hrnet.ocr_distri_head.apply(LRMult(1.0))
            self.hrnet.ocr_gather_head.apply(LRMult(1.0))
            self.hrnet.conv3x3_ocr.apply(LRMult(1.0))

        hrnet_cat_channels = [width * 2 ** i for i in range(4)]
        if self.pyramid_on:
            self.output_channels = [pyramid_channels] * 4
            downsize_in_channels = ocr * 2 if self.ocr_on else sum(hrnet_cat_channels)
            # NOTE(review): ``downsize`` is constructed here but ``forward``
            # below never calls it -- confirm whether pyramid mode is
            # actually supported by this wrapper.
            self.downsize = MaxPoolDownSize(
                downsize_in_channels, pyramid_channels, pyramid_channels, pyramid_depth)
        elif self.ocr_on:
            self.output_channels = [ocr * 2]
        elif self.cat_outputs:
            self.output_channels = [sum(hrnet_cat_channels)]
        else:
            self.output_channels = hrnet_cat_channels

        if pretained:
            self.load_pretrained_weights(pretrained_path)

        self.output_resolution = (opt.input_size // 8) ** 2

    def forward(self, image, mask, mask_features=None):
        """Run the wrapped network and return its outputs as a list."""
        return list(self.hrnet(image, mask, mask_features))

    def load_pretrained_weights(self, pretrained_path):
        """Delegate loading of ``pretrained_path`` to the wrapped network."""
        self.hrnet.load_pretrained_weights(pretrained_path)