Spaces:
Running
Running
Update
Browse filesThis view is limited to 50 files because it contains too many changes. Β
See raw diff
- .gitattributes +4 -0
- assets/demo.gif +3 -0
- assets/metrics.png +0 -0
- assets/network.png +0 -0
- assets/title_any_image.gif +0 -0
- assets/title_harmon.gif +0 -0
- assets/title_you_want.gif +0 -0
- assets/visualizations.png +0 -0
- assets/visualizations2.png +3 -0
- datasets/__init__.py +0 -0
- datasets/__pycache__/__init__.cpython-38.pyc +0 -0
- datasets/__pycache__/build_INR_dataset.cpython-38.pyc +0 -0
- datasets/__pycache__/build_dataset.cpython-38.pyc +0 -0
- datasets/build_INR_dataset.py +36 -0
- datasets/build_dataset.py +371 -0
- demo/demo_2k_composite.jpg +0 -0
- demo/demo_2k_mask.jpg +0 -0
- demo/demo_2k_real.jpg +0 -0
- demo/demo_6k_composite.jpg +3 -0
- demo/demo_6k_mask.jpg +0 -0
- demo/demo_6k_real.jpg +3 -0
- model/__init__.py +0 -0
- model/__pycache__/__init__.cpython-38.pyc +0 -0
- model/__pycache__/backbone.cpython-38.pyc +0 -0
- model/__pycache__/build_model.cpython-38.pyc +0 -0
- model/__pycache__/lut_transformation_net.cpython-38.pyc +0 -0
- model/backbone.py +79 -0
- model/base/__init__.py +0 -0
- model/base/__pycache__/__init__.cpython-38.pyc +0 -0
- model/base/__pycache__/basic_blocks.cpython-38.pyc +0 -0
- model/base/__pycache__/conv_autoencoder.cpython-38.pyc +0 -0
- model/base/__pycache__/ih_model.cpython-38.pyc +0 -0
- model/base/__pycache__/ops.cpython-38.pyc +0 -0
- model/base/basic_blocks.py +366 -0
- model/base/conv_autoencoder.py +519 -0
- model/base/ih_model.py +88 -0
- model/base/ops.py +397 -0
- model/build_model.py +24 -0
- model/hrnetv2/__init__.py +0 -0
- model/hrnetv2/__pycache__/__init__.cpython-38.pyc +0 -0
- model/hrnetv2/__pycache__/hrnet_ocr.cpython-38.pyc +0 -0
- model/hrnetv2/__pycache__/modifiers.cpython-38.pyc +0 -0
- model/hrnetv2/__pycache__/ocr.cpython-38.pyc +0 -0
- model/hrnetv2/__pycache__/resnetv1b.cpython-38.pyc +0 -0
- model/hrnetv2/hrnet_ocr.py +400 -0
- model/hrnetv2/modifiers.py +11 -0
- model/hrnetv2/ocr.py +140 -0
- model/hrnetv2/resnetv1b.py +276 -0
- model/lut_transformation_net.py +65 -0
- pretrained_models/Resolution_1024_HAdobe5K.pth +3 -0
.gitattributes
CHANGED
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
assets/demo.gif filter=lfs diff=lfs merge=lfs -text
|
37 |
+
assets/visualizations2.png filter=lfs diff=lfs merge=lfs -text
|
38 |
+
demo/demo_6k_composite.jpg filter=lfs diff=lfs merge=lfs -text
|
39 |
+
demo/demo_6k_real.jpg filter=lfs diff=lfs merge=lfs -text
|
assets/demo.gif
ADDED
Git LFS Details
|
assets/metrics.png
ADDED
assets/network.png
ADDED
assets/title_any_image.gif
ADDED
assets/title_harmon.gif
ADDED
assets/title_you_want.gif
ADDED
assets/visualizations.png
ADDED
assets/visualizations2.png
ADDED
Git LFS Details
|
datasets/__init__.py
ADDED
File without changes
|
datasets/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (176 Bytes). View file
|
|
datasets/__pycache__/build_INR_dataset.cpython-38.pyc
ADDED
Binary file (1.31 kB). View file
|
|
datasets/__pycache__/build_dataset.cpython-38.pyc
ADDED
Binary file (6.96 kB). View file
|
|
datasets/build_INR_dataset.py
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from utils import misc
|
2 |
+
from albumentations import Resize
|
3 |
+
|
4 |
+
|
5 |
+
class Implicit2DGenerator(object):
|
6 |
+
def __init__(self, opt, mode):
|
7 |
+
if mode == 'Train':
|
8 |
+
sidelength = opt.INR_input_size
|
9 |
+
elif mode == 'Val':
|
10 |
+
sidelength = opt.input_size
|
11 |
+
else:
|
12 |
+
raise NotImplementedError
|
13 |
+
|
14 |
+
self.mode = mode
|
15 |
+
|
16 |
+
self.size = sidelength
|
17 |
+
|
18 |
+
if isinstance(sidelength, int):
|
19 |
+
sidelength = (sidelength, sidelength)
|
20 |
+
|
21 |
+
self.mgrid = misc.get_mgrid(sidelength)
|
22 |
+
|
23 |
+
self.transform = Resize(self.size, self.size)
|
24 |
+
|
25 |
+
def generator(self, torch_transforms, composite_image, real_image, mask):
|
26 |
+
composite_image = torch_transforms(self.transform(image=composite_image)['image'])
|
27 |
+
real_image = torch_transforms(self.transform(image=real_image)['image'])
|
28 |
+
|
29 |
+
fg_INR_RGB = composite_image.permute(1, 2, 0).contiguous().view(-1, 3)
|
30 |
+
fg_transfer_INR_RGB = real_image.permute(1, 2, 0).contiguous().view(-1, 3)
|
31 |
+
bg_INR_RGB = real_image.permute(1, 2, 0).contiguous().view(-1, 3)
|
32 |
+
|
33 |
+
fg_INR_coordinates = self.mgrid
|
34 |
+
bg_INR_coordinates = self.mgrid
|
35 |
+
|
36 |
+
return fg_INR_coordinates, bg_INR_coordinates, fg_INR_RGB, fg_transfer_INR_RGB, bg_INR_RGB
|
datasets/build_dataset.py
ADDED
@@ -0,0 +1,371 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import cv2
|
3 |
+
import numpy as np
|
4 |
+
import torchvision
|
5 |
+
import os
|
6 |
+
import random
|
7 |
+
|
8 |
+
from utils.misc import prepare_cooridinate_input, customRandomCrop
|
9 |
+
|
10 |
+
from datasets.build_INR_dataset import Implicit2DGenerator
|
11 |
+
import albumentations
|
12 |
+
from albumentations import Resize, RandomResizedCrop, HorizontalFlip
|
13 |
+
from torch.utils.data import DataLoader
|
14 |
+
|
15 |
+
|
16 |
+
class dataset_generator(torch.utils.data.Dataset):
|
17 |
+
def __init__(self, dataset_txt, alb_transforms, torch_transforms, opt, area_keep_thresh=0.2, mode='Train'):
|
18 |
+
super().__init__()
|
19 |
+
|
20 |
+
self.opt = opt
|
21 |
+
self.root_path = opt.dataset_path
|
22 |
+
self.mode = mode
|
23 |
+
|
24 |
+
self.alb_transforms = alb_transforms
|
25 |
+
self.torch_transforms = torch_transforms
|
26 |
+
self.kp_t = area_keep_thresh
|
27 |
+
|
28 |
+
with open(dataset_txt, 'r') as f:
|
29 |
+
self.dataset_samples = [os.path.join(self.root_path, x.strip()) for x in f.readlines()]
|
30 |
+
|
31 |
+
self.INR_dataset = Implicit2DGenerator(opt, self.mode)
|
32 |
+
|
33 |
+
def __len__(self):
|
34 |
+
return len(self.dataset_samples)
|
35 |
+
|
36 |
+
def __getitem__(self, idx):
|
37 |
+
composite_image = self.dataset_samples[idx]
|
38 |
+
|
39 |
+
if self.opt.hr_train:
|
40 |
+
if self.opt.isFullRes:
|
41 |
+
"Since in dataset preprocessing, we resize the image in HAdobe5k to a lower resolution for " \
|
42 |
+
"quick loading, we need to change the path here to that of the original resolution of HAdobe5k " \
|
43 |
+
"if `opt.isFullRes` is set to True."
|
44 |
+
composite_image = composite_image.replace("HAdobe5k", "HAdobe5kori")
|
45 |
+
|
46 |
+
real_image = '_'.join(composite_image.split('_')[:2]).replace("composite_images", "real_images") + '.jpg'
|
47 |
+
mask = '_'.join(composite_image.split('_')[:-1]).replace("composite_images", "masks") + '.png'
|
48 |
+
|
49 |
+
composite_image = cv2.imread(composite_image)
|
50 |
+
composite_image = cv2.cvtColor(composite_image, cv2.COLOR_BGR2RGB)
|
51 |
+
|
52 |
+
real_image = cv2.imread(real_image)
|
53 |
+
real_image = cv2.cvtColor(real_image, cv2.COLOR_BGR2RGB)
|
54 |
+
|
55 |
+
mask = cv2.imread(mask)
|
56 |
+
mask = mask[:, :, 0].astype(np.float32) / 255.
|
57 |
+
|
58 |
+
"""
|
59 |
+
If set `opt.hr_train` to True:
|
60 |
+
|
61 |
+
Apply multi resolution crop for HR image train. Specifically, for 1024/2048 `input_size` (not fullres),
|
62 |
+
the training phase is first to RandomResizeCrop 1024/2048 `input_size`, then to random crop a `base_size`
|
63 |
+
patch to feed in multiINR process. For inference, just resize it.
|
64 |
+
|
65 |
+
While for fullres, the RandomResizeCrop is removed and just do a random crop. For inference, just keep the size.
|
66 |
+
|
67 |
+
BTW, we implement LR and HR mixing train. I.e., the following `random.random() < 0.5`
|
68 |
+
"""
|
69 |
+
if self.opt.hr_train:
|
70 |
+
if self.mode == 'Train' and self.opt.isFullRes:
|
71 |
+
if random.random() < 0.5: # LR mix training
|
72 |
+
mixTransform = albumentations.Compose(
|
73 |
+
[
|
74 |
+
RandomResizedCrop(self.opt.base_size, self.opt.base_size, scale=(0.5, 1.0)),
|
75 |
+
HorizontalFlip()],
|
76 |
+
additional_targets={'real_image': 'image', 'object_mask': 'image'}
|
77 |
+
)
|
78 |
+
origin_fg_ratio = mask.sum() / (mask.shape[0] * mask.shape[1])
|
79 |
+
origin_bg_ratio = 1 - origin_fg_ratio
|
80 |
+
|
81 |
+
"Ensure fg and bg not disappear after transformation"
|
82 |
+
valid_augmentation = False
|
83 |
+
transform_out = None
|
84 |
+
time = 0
|
85 |
+
while not valid_augmentation:
|
86 |
+
time += 1
|
87 |
+
# There are some extreme ratio pics, this code is to avoid being hindered by them.
|
88 |
+
if time == 20:
|
89 |
+
tmp_transform = albumentations.Compose(
|
90 |
+
[Resize(self.opt.base_size, self.opt.base_size)],
|
91 |
+
additional_targets={'real_image': 'image',
|
92 |
+
'object_mask': 'image'})
|
93 |
+
transform_out = tmp_transform(image=composite_image, real_image=real_image,
|
94 |
+
object_mask=mask)
|
95 |
+
valid_augmentation = True
|
96 |
+
else:
|
97 |
+
transform_out = mixTransform(image=composite_image, real_image=real_image,
|
98 |
+
object_mask=mask)
|
99 |
+
valid_augmentation = check_augmented_sample(transform_out['object_mask'],
|
100 |
+
origin_fg_ratio,
|
101 |
+
origin_bg_ratio,
|
102 |
+
self.kp_t)
|
103 |
+
composite_image = transform_out['image']
|
104 |
+
real_image = transform_out['real_image']
|
105 |
+
mask = transform_out['object_mask']
|
106 |
+
else: # Padding to ensure that the original resolution can be divided by 4. This is for pixel-aligned crop.
|
107 |
+
if real_image.shape[0] < 256:
|
108 |
+
bottom_pad = 256 - real_image.shape[0]
|
109 |
+
else:
|
110 |
+
bottom_pad = (4 - real_image.shape[0] % 4) % 4
|
111 |
+
if real_image.shape[1] < 256:
|
112 |
+
right_pad = 256 - real_image.shape[1]
|
113 |
+
else:
|
114 |
+
right_pad = (4 - real_image.shape[1] % 4) % 4
|
115 |
+
composite_image = cv2.copyMakeBorder(composite_image, 0, bottom_pad, 0, right_pad,
|
116 |
+
cv2.BORDER_REPLICATE)
|
117 |
+
real_image = cv2.copyMakeBorder(real_image, 0, bottom_pad, 0, right_pad, cv2.BORDER_REPLICATE)
|
118 |
+
mask = cv2.copyMakeBorder(mask, 0, bottom_pad, 0, right_pad, cv2.BORDER_REPLICATE)
|
119 |
+
|
120 |
+
origin_fg_ratio = mask.sum() / (mask.shape[0] * mask.shape[1])
|
121 |
+
origin_bg_ratio = 1 - origin_fg_ratio
|
122 |
+
|
123 |
+
"Ensure fg and bg not disappear after transformation"
|
124 |
+
valid_augmentation = False
|
125 |
+
transform_out = None
|
126 |
+
time = 0
|
127 |
+
|
128 |
+
if self.opt.hr_train:
|
129 |
+
if self.mode == 'Train':
|
130 |
+
if not self.opt.isFullRes:
|
131 |
+
if random.random() < 0.5: # LR mix training
|
132 |
+
mixTransform = albumentations.Compose(
|
133 |
+
[
|
134 |
+
RandomResizedCrop(self.opt.base_size, self.opt.base_size, scale=(0.5, 1.0)),
|
135 |
+
HorizontalFlip()],
|
136 |
+
additional_targets={'real_image': 'image', 'object_mask': 'image'}
|
137 |
+
)
|
138 |
+
while not valid_augmentation:
|
139 |
+
time += 1
|
140 |
+
# There are some extreme ratio pics, this code is to avoid being hindered by them.
|
141 |
+
if time == 20:
|
142 |
+
tmp_transform = albumentations.Compose(
|
143 |
+
[Resize(self.opt.base_size, self.opt.base_size)],
|
144 |
+
additional_targets={'real_image': 'image',
|
145 |
+
'object_mask': 'image'})
|
146 |
+
transform_out = tmp_transform(image=composite_image, real_image=real_image,
|
147 |
+
object_mask=mask)
|
148 |
+
valid_augmentation = True
|
149 |
+
else:
|
150 |
+
transform_out = mixTransform(image=composite_image, real_image=real_image,
|
151 |
+
object_mask=mask)
|
152 |
+
valid_augmentation = check_augmented_sample(transform_out['object_mask'],
|
153 |
+
origin_fg_ratio,
|
154 |
+
origin_bg_ratio,
|
155 |
+
self.kp_t)
|
156 |
+
else:
|
157 |
+
while not valid_augmentation:
|
158 |
+
time += 1
|
159 |
+
# There are some extreme ratio pics, this code is to avoid being hindered by them.
|
160 |
+
if time == 20:
|
161 |
+
tmp_transform = albumentations.Compose(
|
162 |
+
[Resize(self.opt.input_size, self.opt.input_size)],
|
163 |
+
additional_targets={'real_image': 'image',
|
164 |
+
'object_mask': 'image'})
|
165 |
+
transform_out = tmp_transform(image=composite_image, real_image=real_image,
|
166 |
+
object_mask=mask)
|
167 |
+
valid_augmentation = True
|
168 |
+
else:
|
169 |
+
transform_out = self.alb_transforms(image=composite_image, real_image=real_image,
|
170 |
+
object_mask=mask)
|
171 |
+
valid_augmentation = check_augmented_sample(transform_out['object_mask'],
|
172 |
+
origin_fg_ratio,
|
173 |
+
origin_bg_ratio,
|
174 |
+
self.kp_t)
|
175 |
+
composite_image = transform_out['image']
|
176 |
+
real_image = transform_out['real_image']
|
177 |
+
mask = transform_out['object_mask']
|
178 |
+
|
179 |
+
origin_fg_ratio = mask.sum() / (mask.shape[0] * mask.shape[1])
|
180 |
+
|
181 |
+
full_coord = prepare_cooridinate_input(mask).transpose(1, 2, 0)
|
182 |
+
|
183 |
+
tmp_transform = albumentations.Compose([Resize(self.opt.base_size, self.opt.base_size)],
|
184 |
+
additional_targets={'real_image': 'image',
|
185 |
+
'object_mask': 'image'})
|
186 |
+
transform_out = tmp_transform(image=composite_image, real_image=real_image, object_mask=mask)
|
187 |
+
compos_list = [self.torch_transforms(transform_out['image'])]
|
188 |
+
real_list = [self.torch_transforms(transform_out['real_image'])]
|
189 |
+
mask_list = [
|
190 |
+
torchvision.transforms.ToTensor()(transform_out['object_mask'][..., np.newaxis].astype(np.float32))]
|
191 |
+
coord_map_list = []
|
192 |
+
|
193 |
+
valid_augmentation = False
|
194 |
+
while not valid_augmentation:
|
195 |
+
# RSC strategy. To crop different resolutions.
|
196 |
+
transform_out, c_h, c_w = customRandomCrop([composite_image, real_image, mask, full_coord],
|
197 |
+
self.opt.base_size, self.opt.base_size)
|
198 |
+
valid_augmentation = check_hr_crop_sample(transform_out[2], origin_fg_ratio)
|
199 |
+
|
200 |
+
compos_list.append(self.torch_transforms(transform_out[0]))
|
201 |
+
real_list.append(self.torch_transforms(transform_out[1]))
|
202 |
+
mask_list.append(
|
203 |
+
torchvision.transforms.ToTensor()(transform_out[2][..., np.newaxis].astype(np.float32)))
|
204 |
+
coord_map_list.append(torchvision.transforms.ToTensor()(transform_out[3]))
|
205 |
+
coord_map_list.append(torchvision.transforms.ToTensor()(transform_out[3]))
|
206 |
+
for n in range(2):
|
207 |
+
tmp_comp = cv2.resize(composite_image, (
|
208 |
+
composite_image.shape[1] // 2 ** (n + 1), composite_image.shape[0] // 2 ** (n + 1)))
|
209 |
+
tmp_real = cv2.resize(real_image,
|
210 |
+
(real_image.shape[1] // 2 ** (n + 1), real_image.shape[0] // 2 ** (n + 1)))
|
211 |
+
tmp_mask = cv2.resize(mask, (mask.shape[1] // 2 ** (n + 1), mask.shape[0] // 2 ** (n + 1)))
|
212 |
+
tmp_coord = prepare_cooridinate_input(tmp_mask).transpose(1, 2, 0)
|
213 |
+
|
214 |
+
transform_out, c_h, c_w = customRandomCrop([tmp_comp, tmp_real, tmp_mask, tmp_coord],
|
215 |
+
self.opt.base_size // 2 ** (n + 1),
|
216 |
+
self.opt.base_size // 2 ** (n + 1), c_h, c_w)
|
217 |
+
compos_list.append(self.torch_transforms(transform_out[0]))
|
218 |
+
real_list.append(self.torch_transforms(transform_out[1]))
|
219 |
+
mask_list.append(
|
220 |
+
torchvision.transforms.ToTensor()(transform_out[2][..., np.newaxis].astype(np.float32)))
|
221 |
+
coord_map_list.append(torchvision.transforms.ToTensor()(transform_out[3]))
|
222 |
+
out_comp = compos_list
|
223 |
+
out_real = real_list
|
224 |
+
out_mask = mask_list
|
225 |
+
out_coord = coord_map_list
|
226 |
+
|
227 |
+
fg_INR_coordinates, bg_INR_coordinates, fg_INR_RGB, fg_transfer_INR_RGB, bg_INR_RGB = self.INR_dataset.generator(
|
228 |
+
self.torch_transforms, transform_out[0], transform_out[1], mask)
|
229 |
+
|
230 |
+
return {
|
231 |
+
'file_path': self.dataset_samples[idx],
|
232 |
+
'category': self.dataset_samples[idx].split("\\")[-1].split("/")[0],
|
233 |
+
'composite_image': out_comp,
|
234 |
+
'real_image': out_real,
|
235 |
+
'mask': out_mask,
|
236 |
+
'coordinate_map': out_coord,
|
237 |
+
'composite_image0': out_comp[0],
|
238 |
+
'real_image0': out_real[0],
|
239 |
+
'mask0': out_mask[0],
|
240 |
+
'coordinate_map0': out_coord[0],
|
241 |
+
'composite_image1': out_comp[1],
|
242 |
+
'real_image1': out_real[1],
|
243 |
+
'mask1': out_mask[1],
|
244 |
+
'coordinate_map1': out_coord[1],
|
245 |
+
'composite_image2': out_comp[2],
|
246 |
+
'real_image2': out_real[2],
|
247 |
+
'mask2': out_mask[2],
|
248 |
+
'coordinate_map2': out_coord[2],
|
249 |
+
'composite_image3': out_comp[3],
|
250 |
+
'real_image3': out_real[3],
|
251 |
+
'mask3': out_mask[3],
|
252 |
+
'coordinate_map3': out_coord[3],
|
253 |
+
'fg_INR_coordinates': fg_INR_coordinates,
|
254 |
+
'bg_INR_coordinates': bg_INR_coordinates,
|
255 |
+
'fg_INR_RGB': fg_INR_RGB,
|
256 |
+
'fg_transfer_INR_RGB': fg_transfer_INR_RGB,
|
257 |
+
'bg_INR_RGB': bg_INR_RGB
|
258 |
+
}
|
259 |
+
else:
|
260 |
+
if not self.opt.isFullRes:
|
261 |
+
tmp_transform = albumentations.Compose([Resize(self.opt.input_size, self.opt.input_size)],
|
262 |
+
additional_targets={'real_image': 'image',
|
263 |
+
'object_mask': 'image'})
|
264 |
+
transform_out = tmp_transform(image=composite_image, real_image=real_image, object_mask=mask)
|
265 |
+
|
266 |
+
coordinate_map = prepare_cooridinate_input(transform_out['object_mask'])
|
267 |
+
|
268 |
+
"Generate INR dataset."
|
269 |
+
mask = (torchvision.transforms.ToTensor()(
|
270 |
+
transform_out['object_mask']).squeeze() > 100 / 255.).view(-1)
|
271 |
+
mask = np.bool_(mask.numpy())
|
272 |
+
|
273 |
+
fg_INR_coordinates, bg_INR_coordinates, fg_INR_RGB, fg_transfer_INR_RGB, bg_INR_RGB = self.INR_dataset.generator(
|
274 |
+
self.torch_transforms, transform_out['image'], transform_out['real_image'], mask)
|
275 |
+
|
276 |
+
return {
|
277 |
+
'file_path': self.dataset_samples[idx],
|
278 |
+
'category': self.dataset_samples[idx].split("\\")[-1].split("/")[0],
|
279 |
+
'composite_image': self.torch_transforms(transform_out['image']),
|
280 |
+
'real_image': self.torch_transforms(transform_out['real_image']),
|
281 |
+
'mask': transform_out['object_mask'][np.newaxis, ...].astype(np.float32),
|
282 |
+
# Can automatically transfer to Tensor.
|
283 |
+
'coordinate_map': coordinate_map,
|
284 |
+
'fg_INR_coordinates': fg_INR_coordinates,
|
285 |
+
'bg_INR_coordinates': bg_INR_coordinates,
|
286 |
+
'fg_INR_RGB': fg_INR_RGB,
|
287 |
+
'fg_transfer_INR_RGB': fg_transfer_INR_RGB,
|
288 |
+
'bg_INR_RGB': bg_INR_RGB
|
289 |
+
}
|
290 |
+
else:
|
291 |
+
coordinate_map = prepare_cooridinate_input(mask)
|
292 |
+
|
293 |
+
"Generate INR dataset."
|
294 |
+
mask_tmp = (torchvision.transforms.ToTensor()(mask).squeeze() > 100 / 255.).view(-1)
|
295 |
+
mask_tmp = np.bool_(mask_tmp.numpy())
|
296 |
+
|
297 |
+
fg_INR_coordinates, bg_INR_coordinates, fg_INR_RGB, fg_transfer_INR_RGB, bg_INR_RGB = self.INR_dataset.generator(
|
298 |
+
self.torch_transforms, composite_image, real_image, mask_tmp)
|
299 |
+
|
300 |
+
return {
|
301 |
+
'file_path': self.dataset_samples[idx],
|
302 |
+
'category': self.dataset_samples[idx].split("\\")[-1].split("/")[0],
|
303 |
+
'composite_image': self.torch_transforms(composite_image),
|
304 |
+
'real_image': self.torch_transforms(real_image),
|
305 |
+
'mask': mask[np.newaxis, ...].astype(np.float32),
|
306 |
+
# Can automatically transfer to Tensor.
|
307 |
+
'coordinate_map': coordinate_map,
|
308 |
+
'fg_INR_coordinates': fg_INR_coordinates,
|
309 |
+
'bg_INR_coordinates': bg_INR_coordinates,
|
310 |
+
'fg_INR_RGB': fg_INR_RGB,
|
311 |
+
'fg_transfer_INR_RGB': fg_transfer_INR_RGB,
|
312 |
+
'bg_INR_RGB': bg_INR_RGB
|
313 |
+
}
|
314 |
+
|
315 |
+
while not valid_augmentation:
|
316 |
+
time += 1
|
317 |
+
# There are some extreme ratio pics, this code is to avoid being hindered by them.
|
318 |
+
if time == 20:
|
319 |
+
tmp_transform = albumentations.Compose([Resize(self.opt.input_size, self.opt.input_size)],
|
320 |
+
additional_targets={'real_image': 'image',
|
321 |
+
'object_mask': 'image'})
|
322 |
+
transform_out = tmp_transform(image=composite_image, real_image=real_image, object_mask=mask)
|
323 |
+
valid_augmentation = True
|
324 |
+
else:
|
325 |
+
transform_out = self.alb_transforms(image=composite_image, real_image=real_image, object_mask=mask)
|
326 |
+
valid_augmentation = check_augmented_sample(transform_out['object_mask'], origin_fg_ratio,
|
327 |
+
origin_bg_ratio,
|
328 |
+
self.kp_t)
|
329 |
+
|
330 |
+
coordinate_map = prepare_cooridinate_input(transform_out['object_mask'])
|
331 |
+
|
332 |
+
"Generate INR dataset."
|
333 |
+
mask = (torchvision.transforms.ToTensor()(transform_out['object_mask']).squeeze() > 100 / 255.).view(-1)
|
334 |
+
mask = np.bool_(mask.numpy())
|
335 |
+
|
336 |
+
fg_INR_coordinates, bg_INR_coordinates, fg_INR_RGB, fg_transfer_INR_RGB, bg_INR_RGB = self.INR_dataset.generator(
|
337 |
+
self.torch_transforms, transform_out['image'], transform_out['real_image'], mask)
|
338 |
+
|
339 |
+
return {
|
340 |
+
'file_path': self.dataset_samples[idx],
|
341 |
+
'category': self.dataset_samples[idx].split("\\")[-1].split("/")[0],
|
342 |
+
'composite_image': self.torch_transforms(transform_out['image']),
|
343 |
+
'real_image': self.torch_transforms(transform_out['real_image']),
|
344 |
+
'mask': transform_out['object_mask'][np.newaxis, ...].astype(np.float32),
|
345 |
+
# Can automatically transfer to Tensor.
|
346 |
+
'coordinate_map': coordinate_map,
|
347 |
+
'fg_INR_coordinates': fg_INR_coordinates,
|
348 |
+
'bg_INR_coordinates': bg_INR_coordinates,
|
349 |
+
'fg_INR_RGB': fg_INR_RGB,
|
350 |
+
'fg_transfer_INR_RGB': fg_transfer_INR_RGB,
|
351 |
+
'bg_INR_RGB': bg_INR_RGB
|
352 |
+
}
|
353 |
+
|
354 |
+
|
355 |
+
def check_augmented_sample(mask, origin_fg_ratio, origin_bg_ratio, area_keep_thresh):
|
356 |
+
current_fg_ratio = mask.sum() / (mask.shape[0] * mask.shape[1])
|
357 |
+
current_bg_ratio = 1 - current_fg_ratio
|
358 |
+
|
359 |
+
if current_fg_ratio < origin_fg_ratio * area_keep_thresh or current_bg_ratio < origin_bg_ratio * area_keep_thresh:
|
360 |
+
return False
|
361 |
+
|
362 |
+
return True
|
363 |
+
|
364 |
+
|
365 |
+
def check_hr_crop_sample(mask, origin_fg_ratio):
|
366 |
+
current_fg_ratio = mask.sum() / (mask.shape[0] * mask.shape[1])
|
367 |
+
|
368 |
+
if current_fg_ratio < 0.8 * origin_fg_ratio:
|
369 |
+
return False
|
370 |
+
|
371 |
+
return True
|
demo/demo_2k_composite.jpg
ADDED
demo/demo_2k_mask.jpg
ADDED
demo/demo_2k_real.jpg
ADDED
demo/demo_6k_composite.jpg
ADDED
Git LFS Details
|
demo/demo_6k_mask.jpg
ADDED
demo/demo_6k_real.jpg
ADDED
Git LFS Details
|
model/__init__.py
ADDED
File without changes
|
model/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (173 Bytes). View file
|
|
model/__pycache__/backbone.cpython-38.pyc
ADDED
Binary file (2.96 kB). View file
|
|
model/__pycache__/build_model.cpython-38.pyc
ADDED
Binary file (1.03 kB). View file
|
|
model/__pycache__/lut_transformation_net.cpython-38.pyc
ADDED
Binary file (2.43 kB). View file
|
|
model/backbone.py
ADDED
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn as nn
|
2 |
+
|
3 |
+
from .hrnetv2.hrnet_ocr import HighResolutionNet
|
4 |
+
from .hrnetv2.modifiers import LRMult
|
5 |
+
from .base.basic_blocks import MaxPoolDownSize
|
6 |
+
from .base.ih_model import IHModelWithBackbone, DeepImageHarmonization
|
7 |
+
|
8 |
+
|
9 |
+
def build_backbone(name, opt):
|
10 |
+
return eval(name)(opt)
|
11 |
+
|
12 |
+
|
13 |
+
class baseline(IHModelWithBackbone):
|
14 |
+
def __init__(self, opt, ocr=64):
|
15 |
+
base_config = {'model': DeepImageHarmonization,
|
16 |
+
'params': {'depth': 7, 'batchnorm_from': 2, 'image_fusion': True, 'opt': opt}}
|
17 |
+
|
18 |
+
params = base_config['params']
|
19 |
+
|
20 |
+
backbone = HRNetV2(opt, ocr=ocr)
|
21 |
+
|
22 |
+
params.update(dict(
|
23 |
+
backbone_from=2,
|
24 |
+
backbone_channels=backbone.output_channels,
|
25 |
+
backbone_mode='cat',
|
26 |
+
opt=opt
|
27 |
+
))
|
28 |
+
base_model = base_config['model'](**params)
|
29 |
+
|
30 |
+
super(baseline, self).__init__(base_model, backbone, False, 'sum', opt=opt)
|
31 |
+
|
32 |
+
|
33 |
+
class HRNetV2(nn.Module):
|
34 |
+
def __init__(
|
35 |
+
self, opt,
|
36 |
+
cat_outputs=True,
|
37 |
+
pyramid_channels=-1, pyramid_depth=4,
|
38 |
+
width=18, ocr=128, small=False,
|
39 |
+
lr_mult=0.1, pretained=True
|
40 |
+
):
|
41 |
+
super(HRNetV2, self).__init__()
|
42 |
+
self.opt = opt
|
43 |
+
self.cat_outputs = cat_outputs
|
44 |
+
self.ocr_on = ocr > 0 and cat_outputs
|
45 |
+
self.pyramid_on = pyramid_channels > 0 and cat_outputs
|
46 |
+
|
47 |
+
self.hrnet = HighResolutionNet(width, 2, ocr_width=ocr, small=small, opt=opt)
|
48 |
+
self.hrnet.apply(LRMult(lr_mult))
|
49 |
+
if self.ocr_on:
|
50 |
+
self.hrnet.ocr_distri_head.apply(LRMult(1.0))
|
51 |
+
self.hrnet.ocr_gather_head.apply(LRMult(1.0))
|
52 |
+
self.hrnet.conv3x3_ocr.apply(LRMult(1.0))
|
53 |
+
|
54 |
+
hrnet_cat_channels = [width * 2 ** i for i in range(4)]
|
55 |
+
if self.pyramid_on:
|
56 |
+
self.output_channels = [pyramid_channels] * 4
|
57 |
+
elif self.ocr_on:
|
58 |
+
self.output_channels = [ocr * 2]
|
59 |
+
elif self.cat_outputs:
|
60 |
+
self.output_channels = [sum(hrnet_cat_channels)]
|
61 |
+
else:
|
62 |
+
self.output_channels = hrnet_cat_channels
|
63 |
+
|
64 |
+
if self.pyramid_on:
|
65 |
+
downsize_in_channels = ocr * 2 if self.ocr_on else sum(hrnet_cat_channels)
|
66 |
+
self.downsize = MaxPoolDownSize(downsize_in_channels, pyramid_channels, pyramid_channels, pyramid_depth)
|
67 |
+
|
68 |
+
if pretained:
|
69 |
+
self.load_pretrained_weights(
|
70 |
+
r".\pretrained_models/hrnetv2_w18_imagenet_pretrained.pth")
|
71 |
+
|
72 |
+
self.output_resolution = (opt.input_size // 8) ** 2
|
73 |
+
|
74 |
+
def forward(self, image, mask, mask_features=None):
|
75 |
+
outputs = list(self.hrnet(image, mask, mask_features))
|
76 |
+
return outputs
|
77 |
+
|
78 |
+
def load_pretrained_weights(self, pretrained_path):
|
79 |
+
self.hrnet.load_pretrained_weights(pretrained_path)
|
model/base/__init__.py
ADDED
File without changes
|
model/base/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (178 Bytes). View file
|
|
model/base/__pycache__/basic_blocks.cpython-38.pyc
ADDED
Binary file (10.1 kB). View file
|
|
model/base/__pycache__/conv_autoencoder.cpython-38.pyc
ADDED
Binary file (13.8 kB). View file
|
|
model/base/__pycache__/ih_model.cpython-38.pyc
ADDED
Binary file (3.22 kB). View file
|
|
model/base/__pycache__/ops.cpython-38.pyc
ADDED
Binary file (14 kB). View file
|
|
model/base/basic_blocks.py
ADDED
@@ -0,0 +1,366 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import nn as nn
|
3 |
+
import numpy as np
|
4 |
+
|
5 |
+
|
6 |
+
def hyper_weight_init(m, in_features_main_net, activation):
|
7 |
+
if hasattr(m, 'weight'):
|
8 |
+
nn.init.kaiming_normal_(m.weight, a=0.0, nonlinearity='relu', mode='fan_in')
|
9 |
+
m.weight.data = m.weight.data / 1.e2
|
10 |
+
|
11 |
+
if hasattr(m, 'bias'):
|
12 |
+
with torch.no_grad():
|
13 |
+
if activation == 'sine':
|
14 |
+
m.bias.uniform_(-np.sqrt(6 / in_features_main_net) / 30, np.sqrt(6 / in_features_main_net) / 30)
|
15 |
+
elif activation == 'leakyrelu_pe':
|
16 |
+
m.bias.uniform_(-np.sqrt(6 / in_features_main_net), np.sqrt(6 / in_features_main_net))
|
17 |
+
else:
|
18 |
+
raise NotImplementedError
|
19 |
+
|
20 |
+
|
21 |
+
class ConvBlock(nn.Module):
|
22 |
+
def __init__(
|
23 |
+
self,
|
24 |
+
in_channels, out_channels,
|
25 |
+
kernel_size=4, stride=2, padding=1,
|
26 |
+
norm_layer=nn.BatchNorm2d, activation=nn.ELU,
|
27 |
+
bias=True,
|
28 |
+
):
|
29 |
+
super(ConvBlock, self).__init__()
|
30 |
+
self.block = nn.Sequential(
|
31 |
+
nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias),
|
32 |
+
norm_layer(out_channels) if norm_layer is not None else nn.Identity(),
|
33 |
+
activation(),
|
34 |
+
)
|
35 |
+
|
36 |
+
def forward(self, x):
|
37 |
+
return self.block(x)
|
38 |
+
|
39 |
+
|
40 |
+
class MaxPoolDownSize(nn.Module):
|
41 |
+
def __init__(self, in_channels, mid_channels, out_channels, depth):
|
42 |
+
super(MaxPoolDownSize, self).__init__()
|
43 |
+
self.depth = depth
|
44 |
+
self.reduce_conv = ConvBlock(in_channels, mid_channels, kernel_size=1, stride=1, padding=0)
|
45 |
+
self.convs = nn.ModuleList([
|
46 |
+
ConvBlock(mid_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
47 |
+
for conv_i in range(depth)
|
48 |
+
])
|
49 |
+
self.pool2d = nn.MaxPool2d(kernel_size=2)
|
50 |
+
|
51 |
+
def forward(self, x):
|
52 |
+
outputs = []
|
53 |
+
|
54 |
+
output = self.reduce_conv(x)
|
55 |
+
|
56 |
+
for conv_i, conv in enumerate(self.convs):
|
57 |
+
output = output if conv_i == 0 else self.pool2d(output)
|
58 |
+
outputs.append(conv(output))
|
59 |
+
|
60 |
+
return outputs
|
61 |
+
|
62 |
+
|
63 |
+
class convParams(nn.Module):
|
64 |
+
def __init__(self, input_dim, INR_in_out, opt, hidden_mlp_num, hidden_dim=512, toRGB=False):
|
65 |
+
super(convParams, self).__init__()
|
66 |
+
self.INR_in_out = INR_in_out
|
67 |
+
self.cont_split_weight = []
|
68 |
+
self.cont_split_bias = []
|
69 |
+
self.hidden_mlp_num = hidden_mlp_num
|
70 |
+
self.param_factorize_dim = opt.param_factorize_dim
|
71 |
+
output_dim = self.cal_params_num(INR_in_out, hidden_mlp_num, toRGB)
|
72 |
+
self.output_dim = output_dim
|
73 |
+
self.toRGB = toRGB
|
74 |
+
self.cont_extraction_net = nn.Sequential(
|
75 |
+
nn.Conv2d(input_dim, hidden_dim, kernel_size=3, stride=2, padding=1, bias=False),
|
76 |
+
# nn.BatchNorm2d(hidden_dim),
|
77 |
+
nn.ReLU(inplace=True),
|
78 |
+
nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=1, padding=1, bias=False),
|
79 |
+
# nn.BatchNorm2d(hidden_dim),
|
80 |
+
nn.ReLU(inplace=True),
|
81 |
+
nn.Conv2d(hidden_dim, output_dim, kernel_size=1, stride=1, padding=0, bias=True),
|
82 |
+
)
|
83 |
+
|
84 |
+
self.cont_extraction_net[-1].apply(lambda m: hyper_weight_init(m, INR_in_out[0], opt.activation))
|
85 |
+
|
86 |
+
self.basic_params = nn.ParameterList()
|
87 |
+
if opt.param_factorize_dim > 0:
|
88 |
+
for id in range(self.hidden_mlp_num + 1):
|
89 |
+
if id == 0:
|
90 |
+
inp, outp = self.INR_in_out[0], self.INR_in_out[1]
|
91 |
+
else:
|
92 |
+
inp, outp = self.INR_in_out[1], self.INR_in_out[1]
|
93 |
+
self.basic_params.append(nn.Parameter(torch.randn(1, 1, 1, inp, outp)))
|
94 |
+
|
95 |
+
if toRGB:
|
96 |
+
self.basic_params.append(nn.Parameter(torch.randn(1, 1, 1, self.INR_in_out[1], 3)))
|
97 |
+
|
98 |
+
def forward(self, feat, outMore=False):
|
99 |
+
cont_params = self.cont_extraction_net(feat)
|
100 |
+
out_mlp = self.to_mlp(cont_params)
|
101 |
+
if outMore:
|
102 |
+
return out_mlp, cont_params
|
103 |
+
return out_mlp
|
104 |
+
|
105 |
+
def cal_params_num(self, INR_in_out, hidden_mlp_num, toRGB=False):
|
106 |
+
cont_params = 0
|
107 |
+
start = 0
|
108 |
+
if self.param_factorize_dim == -1:
|
109 |
+
cont_params += INR_in_out[0] * INR_in_out[1] + INR_in_out[1]
|
110 |
+
self.cont_split_weight.append([start, cont_params - INR_in_out[1]])
|
111 |
+
self.cont_split_bias.append([cont_params - INR_in_out[1], cont_params])
|
112 |
+
start = cont_params
|
113 |
+
|
114 |
+
for id in range(hidden_mlp_num):
|
115 |
+
cont_params += INR_in_out[1] * INR_in_out[1] + INR_in_out[1]
|
116 |
+
self.cont_split_weight.append([start, cont_params - INR_in_out[1]])
|
117 |
+
self.cont_split_bias.append([cont_params - INR_in_out[1], cont_params])
|
118 |
+
start = cont_params
|
119 |
+
|
120 |
+
if toRGB:
|
121 |
+
cont_params += INR_in_out[1] * 3 + 3
|
122 |
+
self.cont_split_weight.append([start, cont_params - 3])
|
123 |
+
self.cont_split_bias.append([cont_params - 3, cont_params])
|
124 |
+
|
125 |
+
elif self.param_factorize_dim > 0:
|
126 |
+
cont_params += INR_in_out[0] * self.param_factorize_dim + self.param_factorize_dim * INR_in_out[1] + \
|
127 |
+
INR_in_out[1]
|
128 |
+
self.cont_split_weight.append(
|
129 |
+
[start, start + INR_in_out[0] * self.param_factorize_dim, cont_params - INR_in_out[1]])
|
130 |
+
self.cont_split_bias.append([cont_params - INR_in_out[1], cont_params])
|
131 |
+
start = cont_params
|
132 |
+
|
133 |
+
for id in range(hidden_mlp_num):
|
134 |
+
cont_params += INR_in_out[1] * self.param_factorize_dim + self.param_factorize_dim * INR_in_out[1] + \
|
135 |
+
INR_in_out[1]
|
136 |
+
self.cont_split_weight.append(
|
137 |
+
[start, start + INR_in_out[1] * self.param_factorize_dim, cont_params - INR_in_out[1]])
|
138 |
+
self.cont_split_bias.append([cont_params - INR_in_out[1], cont_params])
|
139 |
+
start = cont_params
|
140 |
+
|
141 |
+
if toRGB:
|
142 |
+
cont_params += INR_in_out[1] * self.param_factorize_dim + self.param_factorize_dim * 3 + 3
|
143 |
+
self.cont_split_weight.append(
|
144 |
+
[start, start + INR_in_out[1] * self.param_factorize_dim, cont_params - 3])
|
145 |
+
self.cont_split_bias.append([cont_params - 3, cont_params])
|
146 |
+
|
147 |
+
return cont_params
|
148 |
+
|
149 |
+
def to_mlp(self, params):
|
150 |
+
all_weight_bias = []
|
151 |
+
if self.param_factorize_dim == -1:
|
152 |
+
for id in range(self.hidden_mlp_num + 1):
|
153 |
+
if id == 0:
|
154 |
+
inp, outp = self.INR_in_out[0], self.INR_in_out[1]
|
155 |
+
else:
|
156 |
+
inp, outp = self.INR_in_out[1], self.INR_in_out[1]
|
157 |
+
weight = params[:, self.cont_split_weight[id][0]:self.cont_split_weight[id][1], :, :]
|
158 |
+
weight = weight.permute(0, 2, 3, 1).contiguous().view(weight.shape[0], *weight.shape[2:],
|
159 |
+
inp, outp)
|
160 |
+
|
161 |
+
bias = params[:, self.cont_split_bias[id][0]:self.cont_split_bias[id][1], :, :]
|
162 |
+
bias = bias.permute(0, 2, 3, 1).contiguous().view(bias.shape[0], *bias.shape[2:], 1, outp)
|
163 |
+
all_weight_bias.append([weight, bias])
|
164 |
+
|
165 |
+
if self.toRGB:
|
166 |
+
inp, outp = self.INR_in_out[1], 3
|
167 |
+
weight = params[:, self.cont_split_weight[-1][0]:self.cont_split_weight[-1][1], :, :]
|
168 |
+
weight = weight.permute(0, 2, 3, 1).contiguous().view(weight.shape[0], *weight.shape[2:],
|
169 |
+
inp, outp)
|
170 |
+
|
171 |
+
bias = params[:, self.cont_split_bias[-1][0]:self.cont_split_bias[-1][1], :, :]
|
172 |
+
bias = bias.permute(0, 2, 3, 1).contiguous().view(bias.shape[0], *bias.shape[2:], 1, outp)
|
173 |
+
all_weight_bias.append([weight, bias])
|
174 |
+
|
175 |
+
return all_weight_bias
|
176 |
+
|
177 |
+
else:
|
178 |
+
for id in range(self.hidden_mlp_num + 1):
|
179 |
+
if id == 0:
|
180 |
+
inp, outp = self.INR_in_out[0], self.INR_in_out[1]
|
181 |
+
else:
|
182 |
+
inp, outp = self.INR_in_out[1], self.INR_in_out[1]
|
183 |
+
weight1 = params[:, self.cont_split_weight[id][0]:self.cont_split_weight[id][1], :, :]
|
184 |
+
weight1 = weight1.permute(0, 2, 3, 1).contiguous().view(weight1.shape[0], *weight1.shape[2:],
|
185 |
+
inp, self.param_factorize_dim)
|
186 |
+
|
187 |
+
weight2 = params[:, self.cont_split_weight[id][1]:self.cont_split_weight[id][2], :, :]
|
188 |
+
weight2 = weight2.permute(0, 2, 3, 1).contiguous().view(weight2.shape[0], *weight2.shape[2:],
|
189 |
+
self.param_factorize_dim, outp)
|
190 |
+
|
191 |
+
bias = params[:, self.cont_split_bias[id][0]:self.cont_split_bias[id][1], :, :]
|
192 |
+
bias = bias.permute(0, 2, 3, 1).contiguous().view(bias.shape[0], *bias.shape[2:], 1, outp)
|
193 |
+
|
194 |
+
all_weight_bias.append([torch.tanh(torch.matmul(weight1, weight2)) * self.basic_params[id], bias])
|
195 |
+
|
196 |
+
if self.toRGB:
|
197 |
+
inp, outp = self.INR_in_out[1], 3
|
198 |
+
weight1 = params[:, self.cont_split_weight[-1][0]:self.cont_split_weight[-1][1], :, :]
|
199 |
+
weight1 = weight1.permute(0, 2, 3, 1).contiguous().view(weight1.shape[0], *weight1.shape[2:],
|
200 |
+
inp, self.param_factorize_dim)
|
201 |
+
|
202 |
+
weight2 = params[:, self.cont_split_weight[-1][1]:self.cont_split_weight[-1][2], :, :]
|
203 |
+
weight2 = weight2.permute(0, 2, 3, 1).contiguous().view(weight2.shape[0], *weight2.shape[2:],
|
204 |
+
self.param_factorize_dim, outp)
|
205 |
+
|
206 |
+
bias = params[:, self.cont_split_bias[-1][0]:self.cont_split_bias[-1][1], :, :]
|
207 |
+
bias = bias.permute(0, 2, 3, 1).contiguous().view(bias.shape[0], *bias.shape[2:], 1, outp)
|
208 |
+
|
209 |
+
all_weight_bias.append([torch.tanh(torch.matmul(weight1, weight2)) * self.basic_params[-1], bias])
|
210 |
+
|
211 |
+
return all_weight_bias
|
212 |
+
|
213 |
+
|
214 |
+
class lineParams(nn.Module):
|
215 |
+
def __init__(self, input_dim, INR_in_out, input_resolution, opt, hidden_mlp_num, toRGB=False,
|
216 |
+
hidden_dim=512):
|
217 |
+
super(lineParams, self).__init__()
|
218 |
+
self.INR_in_out = INR_in_out
|
219 |
+
self.app_split_weight = []
|
220 |
+
self.app_split_bias = []
|
221 |
+
self.toRGB = toRGB
|
222 |
+
self.hidden_mlp_num = hidden_mlp_num
|
223 |
+
self.param_factorize_dim = opt.param_factorize_dim
|
224 |
+
output_dim = self.cal_params_num(INR_in_out, hidden_mlp_num)
|
225 |
+
self.output_dim = output_dim
|
226 |
+
|
227 |
+
self.compress_layer = nn.Sequential(
|
228 |
+
nn.Linear(input_resolution, 64, bias=False),
|
229 |
+
nn.BatchNorm1d(input_dim),
|
230 |
+
nn.ReLU(inplace=True),
|
231 |
+
nn.Linear(64, 1, bias=True)
|
232 |
+
)
|
233 |
+
|
234 |
+
self.app_extraction_net = nn.Sequential(
|
235 |
+
nn.Linear(input_dim, hidden_dim, bias=False),
|
236 |
+
# nn.BatchNorm1d(hidden_dim),
|
237 |
+
nn.ReLU(inplace=True),
|
238 |
+
nn.Linear(hidden_dim, hidden_dim, bias=False),
|
239 |
+
# nn.BatchNorm1d(hidden_dim),
|
240 |
+
nn.ReLU(inplace=True),
|
241 |
+
nn.Linear(hidden_dim, output_dim, bias=True)
|
242 |
+
)
|
243 |
+
|
244 |
+
self.app_extraction_net[-1].apply(lambda m: hyper_weight_init(m, INR_in_out[0], opt.activation))
|
245 |
+
|
246 |
+
self.basic_params = nn.ParameterList()
|
247 |
+
if opt.param_factorize_dim > 0:
|
248 |
+
for id in range(self.hidden_mlp_num + 1):
|
249 |
+
if id == 0:
|
250 |
+
inp, outp = self.INR_in_out[0], self.INR_in_out[1]
|
251 |
+
else:
|
252 |
+
inp, outp = self.INR_in_out[1], self.INR_in_out[1]
|
253 |
+
self.basic_params.append(nn.Parameter(torch.randn(1, inp, outp)))
|
254 |
+
if toRGB:
|
255 |
+
self.basic_params.append(nn.Parameter(torch.randn(1, self.INR_in_out[1], 3)))
|
256 |
+
|
257 |
+
def forward(self, feat):
|
258 |
+
app_params = self.app_extraction_net(self.compress_layer(torch.flatten(feat, 2)).squeeze(-1))
|
259 |
+
out_mlp = self.to_mlp(app_params)
|
260 |
+
return out_mlp, app_params
|
261 |
+
|
262 |
+
def cal_params_num(self, INR_in_out, hidden_mlp_num):
|
263 |
+
app_params = 0
|
264 |
+
start = 0
|
265 |
+
if self.param_factorize_dim == -1:
|
266 |
+
app_params += INR_in_out[0] * INR_in_out[1] + INR_in_out[1]
|
267 |
+
self.app_split_weight.append([start, app_params - INR_in_out[1]])
|
268 |
+
self.app_split_bias.append([app_params - INR_in_out[1], app_params])
|
269 |
+
start = app_params
|
270 |
+
|
271 |
+
for id in range(hidden_mlp_num):
|
272 |
+
app_params += INR_in_out[1] * INR_in_out[1] + INR_in_out[1]
|
273 |
+
self.app_split_weight.append([start, app_params - INR_in_out[1]])
|
274 |
+
self.app_split_bias.append([app_params - INR_in_out[1], app_params])
|
275 |
+
start = app_params
|
276 |
+
|
277 |
+
if self.toRGB:
|
278 |
+
app_params += INR_in_out[1] * 3 + 3
|
279 |
+
self.app_split_weight.append([start, app_params - 3])
|
280 |
+
self.app_split_bias.append([app_params - 3, app_params])
|
281 |
+
|
282 |
+
elif self.param_factorize_dim > 0:
|
283 |
+
app_params += INR_in_out[0] * self.param_factorize_dim + self.param_factorize_dim * INR_in_out[1] + \
|
284 |
+
INR_in_out[1]
|
285 |
+
self.app_split_weight.append([start, start + INR_in_out[0] * self.param_factorize_dim,
|
286 |
+
app_params - INR_in_out[1]])
|
287 |
+
self.app_split_bias.append([app_params - INR_in_out[1], app_params])
|
288 |
+
start = app_params
|
289 |
+
|
290 |
+
for id in range(hidden_mlp_num):
|
291 |
+
app_params += INR_in_out[1] * self.param_factorize_dim + self.param_factorize_dim * INR_in_out[1] + \
|
292 |
+
INR_in_out[1]
|
293 |
+
self.app_split_weight.append(
|
294 |
+
[start, start + INR_in_out[1] * self.param_factorize_dim, app_params - INR_in_out[1]])
|
295 |
+
self.app_split_bias.append([app_params - INR_in_out[1], app_params])
|
296 |
+
start = app_params
|
297 |
+
|
298 |
+
if self.toRGB:
|
299 |
+
app_params += INR_in_out[1] * self.param_factorize_dim + self.param_factorize_dim * 3 + 3
|
300 |
+
self.app_split_weight.append([start, start + INR_in_out[1] * self.param_factorize_dim,
|
301 |
+
app_params - 3])
|
302 |
+
self.app_split_bias.append([app_params - 3, app_params])
|
303 |
+
|
304 |
+
return app_params
|
305 |
+
|
306 |
+
def to_mlp(self, params):
|
307 |
+
all_weight_bias = []
|
308 |
+
if self.param_factorize_dim == -1:
|
309 |
+
for id in range(self.hidden_mlp_num + 1):
|
310 |
+
if id == 0:
|
311 |
+
inp, outp = self.INR_in_out[0], self.INR_in_out[1]
|
312 |
+
else:
|
313 |
+
inp, outp = self.INR_in_out[1], self.INR_in_out[1]
|
314 |
+
weight = params[:, self.app_split_weight[id][0]:self.app_split_weight[id][1]]
|
315 |
+
weight = weight.view(weight.shape[0], inp, outp)
|
316 |
+
|
317 |
+
bias = params[:, self.app_split_bias[id][0]:self.app_split_bias[id][1]]
|
318 |
+
bias = bias.view(bias.shape[0], 1, outp)
|
319 |
+
|
320 |
+
all_weight_bias.append([weight, bias])
|
321 |
+
|
322 |
+
if self.toRGB:
|
323 |
+
id = -1
|
324 |
+
inp, outp = self.INR_in_out[1], 3
|
325 |
+
weight = params[:, self.app_split_weight[id][0]:self.app_split_weight[id][1]]
|
326 |
+
weight = weight.view(weight.shape[0], inp, outp)
|
327 |
+
|
328 |
+
bias = params[:, self.app_split_bias[id][0]:self.app_split_bias[id][1]]
|
329 |
+
bias = bias.view(bias.shape[0], 1, outp)
|
330 |
+
|
331 |
+
all_weight_bias.append([weight, bias])
|
332 |
+
|
333 |
+
return all_weight_bias
|
334 |
+
|
335 |
+
else:
|
336 |
+
for id in range(self.hidden_mlp_num + 1):
|
337 |
+
if id == 0:
|
338 |
+
inp, outp = self.INR_in_out[0], self.INR_in_out[1]
|
339 |
+
else:
|
340 |
+
inp, outp = self.INR_in_out[1], self.INR_in_out[1]
|
341 |
+
weight1 = params[:, self.app_split_weight[id][0]:self.app_split_weight[id][1]]
|
342 |
+
weight1 = weight1.view(weight1.shape[0], inp, self.param_factorize_dim)
|
343 |
+
|
344 |
+
weight2 = params[:, self.app_split_weight[id][1]:self.app_split_weight[id][2]]
|
345 |
+
weight2 = weight2.view(weight2.shape[0], self.param_factorize_dim, outp)
|
346 |
+
|
347 |
+
bias = params[:, self.app_split_bias[id][0]:self.app_split_bias[id][1]]
|
348 |
+
bias = bias.view(bias.shape[0], 1, outp)
|
349 |
+
|
350 |
+
all_weight_bias.append([torch.tanh(torch.matmul(weight1, weight2)) * self.basic_params[id], bias])
|
351 |
+
|
352 |
+
if self.toRGB:
|
353 |
+
id = -1
|
354 |
+
inp, outp = self.INR_in_out[1], 3
|
355 |
+
weight1 = params[:, self.app_split_weight[id][0]:self.app_split_weight[id][1]]
|
356 |
+
weight1 = weight1.view(weight1.shape[0], inp, self.param_factorize_dim)
|
357 |
+
|
358 |
+
weight2 = params[:, self.app_split_weight[id][1]:self.app_split_weight[id][2]]
|
359 |
+
weight2 = weight2.view(weight2.shape[0], self.param_factorize_dim, outp)
|
360 |
+
|
361 |
+
bias = params[:, self.app_split_bias[id][0]:self.app_split_bias[id][1]]
|
362 |
+
bias = bias.view(bias.shape[0], 1, outp)
|
363 |
+
|
364 |
+
all_weight_bias.append([torch.tanh(torch.matmul(weight1, weight2)) * self.basic_params[id], bias])
|
365 |
+
|
366 |
+
return all_weight_bias
|
model/base/conv_autoencoder.py
ADDED
@@ -0,0 +1,519 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torchvision
|
3 |
+
from torch import nn as nn
|
4 |
+
import torch.nn.functional as F
|
5 |
+
import numpy as np
|
6 |
+
import math
|
7 |
+
|
8 |
+
from .basic_blocks import ConvBlock, lineParams, convParams
|
9 |
+
from .ops import MaskedChannelAttention, FeaturesConnector
|
10 |
+
from .ops import PosEncodingNeRF, INRGAN_embed, RandomFourier, CIPS_embed
|
11 |
+
from utils import misc
|
12 |
+
from utils.misc import lin2img
|
13 |
+
from ..lut_transformation_net import build_lut_transform
|
14 |
+
|
15 |
+
|
16 |
+
class Sine(nn.Module):
|
17 |
+
def __init__(self):
|
18 |
+
super().__init__()
|
19 |
+
|
20 |
+
def forward(self, input):
|
21 |
+
return torch.sin(30 * input)
|
22 |
+
|
23 |
+
|
24 |
+
class Leaky_relu(nn.Module):
|
25 |
+
def __init__(self):
|
26 |
+
super().__init__()
|
27 |
+
|
28 |
+
def forward(self, input):
|
29 |
+
return torch.nn.functional.leaky_relu(input, 0.01, inplace=True)
|
30 |
+
|
31 |
+
|
32 |
+
def select_activation(type):
|
33 |
+
if type == 'sine':
|
34 |
+
return Sine()
|
35 |
+
elif type == 'leakyrelu_pe':
|
36 |
+
return Leaky_relu()
|
37 |
+
else:
|
38 |
+
raise NotImplementedError
|
39 |
+
|
40 |
+
|
41 |
+
class ConvEncoder(nn.Module):
|
42 |
+
def __init__(
|
43 |
+
self,
|
44 |
+
depth, ch,
|
45 |
+
norm_layer, batchnorm_from, max_channels,
|
46 |
+
backbone_from, backbone_channels=None, backbone_mode='', INRDecode=False
|
47 |
+
):
|
48 |
+
super(ConvEncoder, self).__init__()
|
49 |
+
self.depth = depth
|
50 |
+
self.INRDecode = INRDecode
|
51 |
+
self.backbone_from = backbone_from
|
52 |
+
backbone_channels = [] if backbone_channels is None else backbone_channels[::-1]
|
53 |
+
|
54 |
+
in_channels = 4
|
55 |
+
out_channels = ch
|
56 |
+
|
57 |
+
self.block0 = ConvBlock(in_channels, out_channels, norm_layer=norm_layer if batchnorm_from == 0 else None)
|
58 |
+
self.block1 = ConvBlock(out_channels, out_channels, norm_layer=norm_layer if 0 <= batchnorm_from <= 1 else None)
|
59 |
+
self.blocks_channels = [out_channels, out_channels]
|
60 |
+
|
61 |
+
self.blocks_connected = nn.ModuleDict()
|
62 |
+
self.connectors = nn.ModuleDict()
|
63 |
+
for block_i in range(2, depth):
|
64 |
+
if block_i % 2:
|
65 |
+
in_channels = out_channels
|
66 |
+
else:
|
67 |
+
in_channels, out_channels = out_channels, min(2 * out_channels, max_channels)
|
68 |
+
|
69 |
+
if 0 <= backbone_from <= block_i and len(backbone_channels):
|
70 |
+
if INRDecode:
|
71 |
+
self.blocks_connected[f'block{block_i}_decode'] = ConvBlock(
|
72 |
+
in_channels, out_channels,
|
73 |
+
norm_layer=norm_layer if 0 <= batchnorm_from <= block_i else None,
|
74 |
+
padding=int(block_i < depth - 1)
|
75 |
+
)
|
76 |
+
self.blocks_channels += [out_channels]
|
77 |
+
stage_channels = backbone_channels.pop()
|
78 |
+
connector = FeaturesConnector(backbone_mode, in_channels, stage_channels, in_channels)
|
79 |
+
self.connectors[f'connector{block_i}'] = connector
|
80 |
+
in_channels = connector.output_channels
|
81 |
+
|
82 |
+
self.blocks_connected[f'block{block_i}'] = ConvBlock(
|
83 |
+
in_channels, out_channels,
|
84 |
+
norm_layer=norm_layer if 0 <= batchnorm_from <= block_i else None,
|
85 |
+
padding=int(block_i < depth - 1)
|
86 |
+
)
|
87 |
+
self.blocks_channels += [out_channels]
|
88 |
+
|
89 |
+
def forward(self, x, backbone_features):
|
90 |
+
backbone_features = [] if backbone_features is None else backbone_features[::-1]
|
91 |
+
|
92 |
+
outputs = [self.block0(x)]
|
93 |
+
outputs += [self.block1(outputs[-1])]
|
94 |
+
|
95 |
+
for block_i in range(2, self.depth):
|
96 |
+
output = outputs[-1]
|
97 |
+
connector_name = f'connector{block_i}'
|
98 |
+
if connector_name in self.connectors:
|
99 |
+
if self.INRDecode:
|
100 |
+
block = self.blocks_connected[f'block{block_i}_decode']
|
101 |
+
outputs += [block(output)]
|
102 |
+
|
103 |
+
stage_features = backbone_features.pop()
|
104 |
+
connector = self.connectors[connector_name]
|
105 |
+
output = connector(output, stage_features)
|
106 |
+
block = self.blocks_connected[f'block{block_i}']
|
107 |
+
outputs += [block(output)]
|
108 |
+
|
109 |
+
return outputs[::-1]
|
110 |
+
|
111 |
+
|
112 |
+
class DeconvDecoder(nn.Module):
|
113 |
+
def __init__(self, depth, encoder_blocks_channels, norm_layer, attend_from=-1, image_fusion=False):
|
114 |
+
super(DeconvDecoder, self).__init__()
|
115 |
+
self.image_fusion = image_fusion
|
116 |
+
self.deconv_blocks = nn.ModuleList()
|
117 |
+
|
118 |
+
in_channels = encoder_blocks_channels.pop()
|
119 |
+
out_channels = in_channels
|
120 |
+
for d in range(depth):
|
121 |
+
out_channels = encoder_blocks_channels.pop() if len(encoder_blocks_channels) else in_channels // 2
|
122 |
+
self.deconv_blocks.append(SEDeconvBlock(
|
123 |
+
in_channels, out_channels,
|
124 |
+
norm_layer=norm_layer,
|
125 |
+
padding=0 if d == 0 else 1,
|
126 |
+
with_se=0 <= attend_from <= d
|
127 |
+
))
|
128 |
+
in_channels = out_channels
|
129 |
+
|
130 |
+
if self.image_fusion:
|
131 |
+
self.conv_attention = nn.Conv2d(out_channels, 1, kernel_size=1)
|
132 |
+
self.to_rgb = nn.Conv2d(out_channels, 3, kernel_size=1)
|
133 |
+
|
134 |
+
def forward(self, encoder_outputs, image, mask=None):
|
135 |
+
output = encoder_outputs[0]
|
136 |
+
for block, skip_output in zip(self.deconv_blocks[:-1], encoder_outputs[1:]):
|
137 |
+
output = block(output, mask)
|
138 |
+
output = output + skip_output
|
139 |
+
output = self.deconv_blocks[-1](output, mask)
|
140 |
+
|
141 |
+
if self.image_fusion:
|
142 |
+
attention_map = torch.sigmoid(3.0 * self.conv_attention(output))
|
143 |
+
output = attention_map * image + (1.0 - attention_map) * self.to_rgb(output)
|
144 |
+
else:
|
145 |
+
output = self.to_rgb(output)
|
146 |
+
|
147 |
+
return output
|
148 |
+
|
149 |
+
|
150 |
+
class SEDeconvBlock(nn.Module):
|
151 |
+
def __init__(
|
152 |
+
self,
|
153 |
+
in_channels, out_channels,
|
154 |
+
kernel_size=4, stride=2, padding=1,
|
155 |
+
norm_layer=nn.BatchNorm2d, activation=nn.ELU,
|
156 |
+
with_se=False
|
157 |
+
):
|
158 |
+
super(SEDeconvBlock, self).__init__()
|
159 |
+
self.with_se = with_se
|
160 |
+
self.block = nn.Sequential(
|
161 |
+
nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding),
|
162 |
+
norm_layer(out_channels) if norm_layer is not None else nn.Identity(),
|
163 |
+
activation(),
|
164 |
+
)
|
165 |
+
if self.with_se:
|
166 |
+
self.se = MaskedChannelAttention(out_channels)
|
167 |
+
|
168 |
+
def forward(self, x, mask=None):
|
169 |
+
out = self.block(x)
|
170 |
+
if self.with_se:
|
171 |
+
out = self.se(out, mask)
|
172 |
+
return out
|
173 |
+
|
174 |
+
|
175 |
+
class INRDecoder(nn.Module):
|
176 |
+
def __init__(self, depth, encoder_blocks_channels, norm_layer, opt, attend_from):
|
177 |
+
super(INRDecoder, self).__init__()
|
178 |
+
self.INR_encoding = None
|
179 |
+
if opt.embedding_type == "PosEncodingNeRF":
|
180 |
+
self.INR_encoding = PosEncodingNeRF(in_features=2, sidelength=opt.input_size)
|
181 |
+
elif opt.embedding_type == "RandomFourier":
|
182 |
+
self.INR_encoding = RandomFourier(std_scale=10, embedding_length=64, device=opt.device)
|
183 |
+
elif opt.embedding_type == "CIPS_embed":
|
184 |
+
self.INR_encoding = CIPS_embed(size=opt.base_size, embedding_length=32)
|
185 |
+
elif opt.embedding_type == "INRGAN_embed":
|
186 |
+
self.INR_encoding = INRGAN_embed(resolution=opt.INR_input_size)
|
187 |
+
else:
|
188 |
+
raise NotImplementedError
|
189 |
+
encoder_blocks_channels = encoder_blocks_channels[::-1]
|
190 |
+
max_hidden_mlp_num = attend_from + 1
|
191 |
+
self.opt = opt
|
192 |
+
self.max_hidden_mlp_num = max_hidden_mlp_num
|
193 |
+
self.content_mlp_blocks = nn.ModuleDict()
|
194 |
+
for n in range(max_hidden_mlp_num):
|
195 |
+
if n != max_hidden_mlp_num - 1:
|
196 |
+
self.content_mlp_blocks[f"block{n}"] = convParams(encoder_blocks_channels.pop(),
|
197 |
+
[self.INR_encoding.out_dim + opt.INR_MLP_dim + (
|
198 |
+
4 if opt.isMoreINRInput else 0), opt.INR_MLP_dim],
|
199 |
+
opt, n + 1)
|
200 |
+
else:
|
201 |
+
self.content_mlp_blocks[f"block{n}"] = convParams(encoder_blocks_channels.pop(),
|
202 |
+
[self.INR_encoding.out_dim + (
|
203 |
+
4 if opt.isMoreINRInput else 0), opt.INR_MLP_dim],
|
204 |
+
opt, n + 1)
|
205 |
+
|
206 |
+
self.deconv_blocks = nn.ModuleList()
|
207 |
+
|
208 |
+
encoder_blocks_channels = encoder_blocks_channels[::-1]
|
209 |
+
in_channels = encoder_blocks_channels.pop()
|
210 |
+
out_channels = in_channels
|
211 |
+
for d in range(depth - attend_from):
|
212 |
+
out_channels = encoder_blocks_channels.pop() if len(encoder_blocks_channels) else in_channels // 2
|
213 |
+
self.deconv_blocks.append(SEDeconvBlock(
|
214 |
+
in_channels, out_channels,
|
215 |
+
norm_layer=norm_layer,
|
216 |
+
padding=0 if d == 0 else 1,
|
217 |
+
with_se=False
|
218 |
+
))
|
219 |
+
in_channels = out_channels
|
220 |
+
|
221 |
+
self.appearance_mlps = lineParams(out_channels, [opt.INR_MLP_dim, opt.INR_MLP_dim],
|
222 |
+
(opt.base_size // (2 ** (max_hidden_mlp_num - 1))) ** 2,
|
223 |
+
opt, 2, toRGB=True)
|
224 |
+
|
225 |
+
self.lut_transform = build_lut_transform(self.appearance_mlps.output_dim, opt.LUT_dim,
|
226 |
+
None, opt)
|
227 |
+
|
228 |
+
self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
|
229 |
+
|
230 |
+
def forward(self, encoder_outputs, image=None, mask=None, coord_samples=None, start_proportion=None):
|
231 |
+
"""For full resolution, do split."""
|
232 |
+
if self.opt.hr_train and not (self.training or hasattr(self.opt, 'split_num') or hasattr(self.opt,
|
233 |
+
'split_resolution')) and self.opt.isFullRes:
|
234 |
+
return self.forward_fullResInference(encoder_outputs, image=image, mask=mask, coord_samples=coord_samples)
|
235 |
+
|
236 |
+
encoder_outputs = encoder_outputs[::-1]
|
237 |
+
mlp_output = None
|
238 |
+
waitToRGB = []
|
239 |
+
for n in range(self.max_hidden_mlp_num):
|
240 |
+
if not self.opt.hr_train:
|
241 |
+
coord = misc.get_mgrid(self.opt.INR_input_size // (2 ** (self.max_hidden_mlp_num - n - 1))) \
|
242 |
+
.unsqueeze(0).repeat(encoder_outputs[0].shape[0], 1, 1).to(self.opt.device)
|
243 |
+
else:
|
244 |
+
if self.training or hasattr(self.opt, 'split_num') or hasattr(self.opt, 'split_resolution'):
|
245 |
+
coord = coord_samples[self.max_hidden_mlp_num - n - 1].permute(0, 2, 3, 1).view(
|
246 |
+
encoder_outputs[0].shape[0], -1, 2)
|
247 |
+
else:
|
248 |
+
coord = misc.get_mgrid(
|
249 |
+
self.opt.INR_input_size // (2 ** (self.max_hidden_mlp_num - n - 1))).unsqueeze(0).repeat(
|
250 |
+
encoder_outputs[0].shape[0], 1, 1).to(self.opt.device)
|
251 |
+
|
252 |
+
"""Whether to leverage multiple input to INR decoder. See Section 3.4 in the paper."""
|
253 |
+
if self.opt.isMoreINRInput:
|
254 |
+
if not self.opt.isFullRes or (
|
255 |
+
self.training or hasattr(self.opt, 'split_num') or hasattr(self.opt, 'split_resolution')):
|
256 |
+
res_h = res_w = np.sqrt(coord.shape[1]).astype(int)
|
257 |
+
else:
|
258 |
+
res_h = image.shape[-2] // (2 ** (self.max_hidden_mlp_num - n - 1))
|
259 |
+
res_w = image.shape[-1] // (2 ** (self.max_hidden_mlp_num - n - 1))
|
260 |
+
|
261 |
+
res_image = torchvision.transforms.Resize([res_h, res_w])(image)
|
262 |
+
res_mask = torchvision.transforms.Resize([res_h, res_w])(mask)
|
263 |
+
coord = torch.cat([self.INR_encoding(coord), res_image.view(*res_image.shape[:2], -1).permute(0, 2, 1),
|
264 |
+
res_mask.view(*res_mask.shape[:2], -1).permute(0, 2, 1)], dim=-1)
|
265 |
+
else:
|
266 |
+
coord = self.INR_encoding(coord)
|
267 |
+
|
268 |
+
"""============ LRIP structure, see Section 3.3 =============="""
|
269 |
+
|
270 |
+
"""Local MLPs."""
|
271 |
+
if n == 0:
|
272 |
+
mlp_output = self.mlp_process(coord, self.INR_encoding.out_dim + (4 if self.opt.isMoreINRInput else 0),
|
273 |
+
self.opt, content_mlp=self.content_mlp_blocks[
|
274 |
+
f"block{self.max_hidden_mlp_num - 1 - n}"](
|
275 |
+
encoder_outputs.pop(self.max_hidden_mlp_num - 1 - n)), start_proportion=start_proportion)
|
276 |
+
waitToRGB.append(mlp_output[1])
|
277 |
+
else:
|
278 |
+
mlp_output = self.mlp_process(coord, self.opt.INR_MLP_dim + self.INR_encoding.out_dim + (
|
279 |
+
4 if self.opt.isMoreINRInput else 0), self.opt, base_feat=mlp_output[0],
|
280 |
+
content_mlp=self.content_mlp_blocks[
|
281 |
+
f"block{self.max_hidden_mlp_num - 1 - n}"](
|
282 |
+
encoder_outputs.pop(self.max_hidden_mlp_num - 1 - n)),
|
283 |
+
start_proportion=start_proportion)
|
284 |
+
waitToRGB.append(mlp_output[1])
|
285 |
+
|
286 |
+
encoder_outputs = encoder_outputs[::-1]
|
287 |
+
output = encoder_outputs[0]
|
288 |
+
for block, skip_output in zip(self.deconv_blocks[:-1], encoder_outputs[1:]):
|
289 |
+
output = block(output)
|
290 |
+
output = output + skip_output
|
291 |
+
output = self.deconv_blocks[-1](output)
|
292 |
+
|
293 |
+
"""Global MLPs."""
|
294 |
+
app_mlp, app_params = self.appearance_mlps(output)
|
295 |
+
harm_out = []
|
296 |
+
for id in range(len(waitToRGB)):
|
297 |
+
output = self.mlp_process(None, self.opt.INR_MLP_dim, self.opt, base_feat=waitToRGB[id],
|
298 |
+
appearance_mlp=app_mlp)
|
299 |
+
harm_out.append(output[0])
|
300 |
+
|
301 |
+
"""Optional 3D LUT prediction."""
|
302 |
+
fit_lut3d, lut_transform_image = self.lut_transform(image, app_params, None)
|
303 |
+
|
304 |
+
return harm_out, fit_lut3d, lut_transform_image
|
305 |
+
|
306 |
+
def mlp_process(self, coorinates, INR_input_dim, opt, base_feat=None, content_mlp=None, appearance_mlp=None,
|
307 |
+
resolution=None, start_proportion=None):
|
308 |
+
|
309 |
+
activation = select_activation(opt.activation)
|
310 |
+
|
311 |
+
output = None
|
312 |
+
|
313 |
+
if content_mlp is not None:
|
314 |
+
if base_feat is not None:
|
315 |
+
coorinates = torch.cat([coorinates, base_feat], dim=2)
|
316 |
+
coorinates = lin2img(coorinates, resolution)
|
317 |
+
|
318 |
+
if hasattr(opt, 'split_resolution'):
|
319 |
+
"""
|
320 |
+
Here we crop the needed MLPs according to the region of the split input patches.
|
321 |
+
Note that this only support inferencing square images.
|
322 |
+
"""
|
323 |
+
for idx in range(len(content_mlp)):
|
324 |
+
content_mlp[idx][0] = content_mlp[idx][0][:,
|
325 |
+
(content_mlp[idx][0].shape[1] * start_proportion[0]).int():(
|
326 |
+
content_mlp[idx][0].shape[1] * start_proportion[2]).int(),
|
327 |
+
(content_mlp[idx][0].shape[2] * start_proportion[1]).int():(
|
328 |
+
content_mlp[idx][0].shape[2] * start_proportion[3]).int(), :,
|
329 |
+
:]
|
330 |
+
content_mlp[idx][1] = content_mlp[idx][1][:,
|
331 |
+
(content_mlp[idx][1].shape[1] * start_proportion[0]).int():(
|
332 |
+
content_mlp[idx][1].shape[1] * start_proportion[2]).int(),
|
333 |
+
(content_mlp[idx][1].shape[2] * start_proportion[1]).int():(
|
334 |
+
content_mlp[idx][1].shape[2] * start_proportion[3]).int(),
|
335 |
+
:,
|
336 |
+
:]
|
337 |
+
k_h = coorinates.shape[2] // content_mlp[0][0].shape[1]
|
338 |
+
k_w = coorinates.shape[3] // content_mlp[0][0].shape[1]
|
339 |
+
bs = coorinates.shape[0]
|
340 |
+
h_lr = w_lr = content_mlp[0][0].shape[1]
|
341 |
+
nci = INR_input_dim
|
342 |
+
|
343 |
+
coorinates = coorinates.unfold(2, k_h, k_h).unfold(3, k_w, k_w)
|
344 |
+
coorinates = coorinates.permute(0, 2, 3, 4, 5, 1).contiguous().view(
|
345 |
+
bs, h_lr, w_lr, int(k_h * k_w), nci)
|
346 |
+
|
347 |
+
for id, layer in enumerate(content_mlp):
|
348 |
+
if id == 0:
|
349 |
+
output = torch.matmul(coorinates, layer[0]) + layer[1]
|
350 |
+
output = activation(output)
|
351 |
+
else:
|
352 |
+
output = torch.matmul(output, layer[0]) + layer[1]
|
353 |
+
output = activation(output)
|
354 |
+
|
355 |
+
output = output.view(bs, h_lr, w_lr, k_h, k_w, opt.INR_MLP_dim).permute(
|
356 |
+
0, 1, 3, 2, 4, 5).contiguous().view(bs, -1, opt.INR_MLP_dim)
|
357 |
+
|
358 |
+
output_large = self.up(lin2img(output))
|
359 |
+
|
360 |
+
return output_large.view(bs, -1, opt.INR_MLP_dim), output
|
361 |
+
|
362 |
+
k_h = coorinates.shape[2] // content_mlp[0][0].shape[1]
|
363 |
+
k_w = coorinates.shape[3] // content_mlp[0][0].shape[1]
|
364 |
+
bs = coorinates.shape[0]
|
365 |
+
h_lr = w_lr = content_mlp[0][0].shape[1]
|
366 |
+
nci = INR_input_dim
|
367 |
+
|
368 |
+
"""(evaluation or not HR training) and not fullres evaluation"""
|
369 |
+
if (not self.opt.hr_train or not (self.training or hasattr(self.opt, 'split_num'))) and not (
|
370 |
+
not (self.training or hasattr(self.opt, 'split_num')) and self.opt.isFullRes and self.opt.hr_train):
|
371 |
+
coorinates = coorinates.unfold(2, k_h, k_h).unfold(3, k_w, k_w)
|
372 |
+
coorinates = coorinates.permute(0, 2, 3, 4, 5, 1).contiguous().view(
|
373 |
+
bs, h_lr, w_lr, int(k_h * k_w), nci)
|
374 |
+
|
375 |
+
for id, layer in enumerate(content_mlp):
|
376 |
+
if id == 0:
|
377 |
+
output = torch.matmul(coorinates, layer[0]) + layer[1]
|
378 |
+
output = activation(output)
|
379 |
+
else:
|
380 |
+
output = torch.matmul(output, layer[0]) + layer[1]
|
381 |
+
output = activation(output)
|
382 |
+
|
383 |
+
output = output.view(bs, h_lr, w_lr, k_h, k_w, opt.INR_MLP_dim).permute(
|
384 |
+
0, 1, 3, 2, 4, 5).contiguous().view(bs, -1, opt.INR_MLP_dim)
|
385 |
+
|
386 |
+
output_large = self.up(lin2img(output))
|
387 |
+
|
388 |
+
return output_large.view(bs, -1, opt.INR_MLP_dim), output
|
389 |
+
else:
|
390 |
+
coorinates = coorinates.permute(0, 2, 3, 1)
|
391 |
+
for id, layer in enumerate(content_mlp):
|
392 |
+
weigt_shape = layer[0].shape
|
393 |
+
bias_shape = layer[1].shape
|
394 |
+
layer[0] = layer[0].view(*layer[0].shape[:-2], -1).permute(0, 3, 1, 2).contiguous()
|
395 |
+
layer[1] = layer[1].view(*layer[1].shape[:-2], -1).permute(0, 3, 1, 2).contiguous()
|
396 |
+
layer[0] = F.grid_sample(layer[0], coorinates[..., :2].flip(-1), mode='nearest' if True
|
397 |
+
else 'bilinear', padding_mode='border', align_corners=False)
|
398 |
+
layer[1] = F.grid_sample(layer[1], coorinates[..., :2].flip(-1), mode='nearest' if True
|
399 |
+
else 'bilinear', padding_mode='border', align_corners=False)
|
400 |
+
layer[0] = layer[0].permute(0, 2, 3, 1).contiguous().view(*coorinates.shape[:-1], *weigt_shape[-2:])
|
401 |
+
layer[1] = layer[1].permute(0, 2, 3, 1).contiguous().view(*coorinates.shape[:-1], *bias_shape[-2:])
|
402 |
+
|
403 |
+
if id == 0:
|
404 |
+
output = torch.matmul(coorinates.unsqueeze(-2), layer[0]) + layer[1]
|
405 |
+
output = activation(output)
|
406 |
+
else:
|
407 |
+
output = torch.matmul(output, layer[0]) + layer[1]
|
408 |
+
output = activation(output)
|
409 |
+
|
410 |
+
output = output.squeeze(-2).view(bs, -1, opt.INR_MLP_dim)
|
411 |
+
|
412 |
+
output_large = self.up(lin2img(output, resolution))
|
413 |
+
|
414 |
+
return output_large.view(bs, -1, opt.INR_MLP_dim), output
|
415 |
+
|
416 |
+
elif appearance_mlp is not None:
|
417 |
+
output = base_feat
|
418 |
+
genMask = None
|
419 |
+
for id, layer in enumerate(appearance_mlp):
|
420 |
+
if id != len(appearance_mlp) - 1:
|
421 |
+
output = torch.matmul(output, layer[0]) + layer[1]
|
422 |
+
output = activation(output)
|
423 |
+
else:
|
424 |
+
output = torch.matmul(output, layer[0]) + layer[1] # last layer
|
425 |
+
if opt.activation == 'leakyrelu_pe':
|
426 |
+
output = torch.tanh(output)
|
427 |
+
return lin2img(output, resolution), None
|
428 |
+
|
429 |
+
def forward_fullResInference(self, encoder_outputs, image=None, mask=None, coord_samples=None):
|
430 |
+
encoder_outputs = encoder_outputs[::-1]
|
431 |
+
mlp_output = None
|
432 |
+
res_w = image.shape[-1]
|
433 |
+
res_h = image.shape[-2]
|
434 |
+
coord = misc.get_mgrid([image.shape[-2], image.shape[-1]]).unsqueeze(0).repeat(
|
435 |
+
encoder_outputs[0].shape[0], 1, 1).to(self.opt.device)
|
436 |
+
|
437 |
+
if self.opt.isMoreINRInput:
|
438 |
+
coord = torch.cat(
|
439 |
+
[self.INR_encoding(coord, (res_h, res_w)), image.view(*image.shape[:2], -1).permute(0, 2, 1),
|
440 |
+
mask.view(*mask.shape[:2], -1).permute(0, 2, 1)], dim=-1)
|
441 |
+
else:
|
442 |
+
coord = self.INR_encoding(coord, (res_h, res_w))
|
443 |
+
|
444 |
+
total = coord.clone()
|
445 |
+
|
446 |
+
interval = 10
|
447 |
+
all_intervals = math.ceil(res_h / interval)
|
448 |
+
divisible = True
|
449 |
+
if res_h / interval != res_h // interval:
|
450 |
+
divisible = False
|
451 |
+
|
452 |
+
for n in range(self.max_hidden_mlp_num):
|
453 |
+
accum_mlp_output = []
|
454 |
+
for line in range(all_intervals):
|
455 |
+
if not divisible and line == all_intervals - 1:
|
456 |
+
coord = total[:, line * interval * res_w:, :]
|
457 |
+
else:
|
458 |
+
coord = total[:, line * interval * res_w: (line + 1) * interval * res_w, :]
|
459 |
+
if n == 0:
|
460 |
+
accum_mlp_output.append(self.mlp_process(coord,
|
461 |
+
self.INR_encoding.out_dim + (
|
462 |
+
4 if self.opt.isMoreINRInput else 0),
|
463 |
+
self.opt, content_mlp=self.content_mlp_blocks[
|
464 |
+
f"block{self.max_hidden_mlp_num - 1 - n}"](
|
465 |
+
encoder_outputs.pop(self.max_hidden_mlp_num - 1 - n) if line == all_intervals - 1 else
|
466 |
+
encoder_outputs[self.max_hidden_mlp_num - 1 - n]),
|
467 |
+
resolution=(interval,
|
468 |
+
res_w) if divisible or line != all_intervals - 1 else (
|
469 |
+
res_h - interval * (all_intervals - 1), res_w))[1])
|
470 |
+
|
471 |
+
else:
|
472 |
+
accum_mlp_output.append(self.mlp_process(coord, self.opt.INR_MLP_dim + self.INR_encoding.out_dim + (
|
473 |
+
4 if self.opt.isMoreINRInput else 0), self.opt, base_feat=mlp_output[0][:,
|
474 |
+
line * interval * res_w: (
|
475 |
+
line + 1) * interval * res_w,
|
476 |
+
:]
|
477 |
+
if divisible or line != all_intervals - 1 else mlp_output[0][:, line * interval * res_w:, :],
|
478 |
+
content_mlp=self.content_mlp_blocks[
|
479 |
+
f"block{self.max_hidden_mlp_num - 1 - n}"](
|
480 |
+
encoder_outputs.pop(
|
481 |
+
self.max_hidden_mlp_num - 1 - n) if line == all_intervals - 1 else
|
482 |
+
encoder_outputs[self.max_hidden_mlp_num - 1 - n]),
|
483 |
+
resolution=(interval,
|
484 |
+
res_w) if divisible or line != all_intervals - 1 else (
|
485 |
+
res_h - interval * (all_intervals - 1), res_w))[1])
|
486 |
+
|
487 |
+
accum_mlp_output = torch.cat(accum_mlp_output, dim=1)
|
488 |
+
mlp_output = [accum_mlp_output, accum_mlp_output]
|
489 |
+
|
490 |
+
encoder_outputs = encoder_outputs[::-1]
|
491 |
+
output = encoder_outputs[0]
|
492 |
+
for block, skip_output in zip(self.deconv_blocks[:-1], encoder_outputs[1:]):
|
493 |
+
output = block(output)
|
494 |
+
output = output + skip_output
|
495 |
+
output = self.deconv_blocks[-1](output)
|
496 |
+
|
497 |
+
app_mlp, app_params = self.appearance_mlps(output)
|
498 |
+
harm_out = []
|
499 |
+
|
500 |
+
accum_mlp_output = []
|
501 |
+
for line in range(all_intervals):
|
502 |
+
if not divisible and line == all_intervals - 1:
|
503 |
+
base = mlp_output[1][:, line * interval * res_w:, :]
|
504 |
+
else:
|
505 |
+
base = mlp_output[1][:, line * interval * res_w: (line + 1) * interval * res_w, :]
|
506 |
+
|
507 |
+
accum_mlp_output.append(self.mlp_process(None, self.opt.INR_MLP_dim, self.opt, base_feat=base,
|
508 |
+
appearance_mlp=app_mlp,
|
509 |
+
resolution=(
|
510 |
+
interval,
|
511 |
+
res_w) if divisible or line != all_intervals - 1 else (
|
512 |
+
res_h - interval * (all_intervals - 1), res_w))[0])
|
513 |
+
|
514 |
+
accum_mlp_output = torch.cat(accum_mlp_output, dim=2)
|
515 |
+
harm_out.append(accum_mlp_output)
|
516 |
+
|
517 |
+
fit_lut3d, lut_transform_image = self.lut_transform(image, app_params, None)
|
518 |
+
|
519 |
+
return harm_out, fit_lut3d, lut_transform_image
|
model/base/ih_model.py
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torchvision
|
3 |
+
import torch.nn as nn
|
4 |
+
|
5 |
+
from .conv_autoencoder import ConvEncoder, DeconvDecoder, INRDecoder
|
6 |
+
|
7 |
+
from .ops import ScaleLayer
|
8 |
+
|
9 |
+
|
10 |
+
class IHModelWithBackbone(nn.Module):
|
11 |
+
def __init__(
|
12 |
+
self,
|
13 |
+
model, backbone,
|
14 |
+
downsize_backbone_input=False,
|
15 |
+
mask_fusion='sum',
|
16 |
+
backbone_conv1_channels=64, opt=None
|
17 |
+
):
|
18 |
+
super(IHModelWithBackbone, self).__init__()
|
19 |
+
self.downsize_backbone_input = downsize_backbone_input
|
20 |
+
self.mask_fusion = mask_fusion
|
21 |
+
|
22 |
+
self.backbone = backbone
|
23 |
+
self.model = model
|
24 |
+
self.opt = opt
|
25 |
+
|
26 |
+
self.mask_conv = nn.Sequential(
|
27 |
+
nn.Conv2d(1, backbone_conv1_channels, kernel_size=3, stride=2, padding=1, bias=True),
|
28 |
+
ScaleLayer(init_value=0.1, lr_mult=1)
|
29 |
+
)
|
30 |
+
|
31 |
+
def forward(self, image, mask, coord=None, start_proportion=None):
|
32 |
+
if self.opt.INRDecode and self.opt.hr_train and (self.training or hasattr(self.opt, 'split_num') or hasattr(self.opt, 'split_resolution')):
|
33 |
+
backbone_image = torchvision.transforms.Resize([self.opt.base_size, self.opt.base_size])(image[0])
|
34 |
+
backbone_mask = torch.cat(
|
35 |
+
(torchvision.transforms.Resize([self.opt.base_size, self.opt.base_size])(mask[0]),
|
36 |
+
1.0 - torchvision.transforms.Resize([self.opt.base_size, self.opt.base_size])(mask[0])), dim=1)
|
37 |
+
else:
|
38 |
+
backbone_image = torchvision.transforms.Resize([self.opt.base_size, self.opt.base_size])(image)
|
39 |
+
backbone_mask = torch.cat((torchvision.transforms.Resize([self.opt.base_size, self.opt.base_size])(mask),
|
40 |
+
1.0 - torchvision.transforms.Resize([self.opt.base_size, self.opt.base_size])(mask)), dim=1)
|
41 |
+
|
42 |
+
backbone_mask_features = self.mask_conv(backbone_mask[:, :1])
|
43 |
+
backbone_features = self.backbone(backbone_image, backbone_mask, backbone_mask_features)
|
44 |
+
|
45 |
+
output = self.model(image, mask, backbone_features, coord=coord, start_proportion=start_proportion)
|
46 |
+
return output
|
47 |
+
|
48 |
+
|
49 |
+
class DeepImageHarmonization(nn.Module):
|
50 |
+
def __init__(
|
51 |
+
self,
|
52 |
+
depth,
|
53 |
+
norm_layer=nn.BatchNorm2d, batchnorm_from=0,
|
54 |
+
attend_from=-1,
|
55 |
+
image_fusion=False,
|
56 |
+
ch=64, max_channels=512,
|
57 |
+
backbone_from=-1, backbone_channels=None, backbone_mode='', opt=None
|
58 |
+
):
|
59 |
+
super(DeepImageHarmonization, self).__init__()
|
60 |
+
self.depth = depth
|
61 |
+
self.encoder = ConvEncoder(
|
62 |
+
depth, ch,
|
63 |
+
norm_layer, batchnorm_from, max_channels,
|
64 |
+
backbone_from, backbone_channels, backbone_mode, INRDecode=opt.INRDecode
|
65 |
+
)
|
66 |
+
self.opt = opt
|
67 |
+
if opt.INRDecode:
|
68 |
+
"See Table 2 in the paper to test with different INR decoders' structures."
|
69 |
+
self.decoder = INRDecoder(depth, self.encoder.blocks_channels, norm_layer, opt, backbone_from)
|
70 |
+
else:
|
71 |
+
"Baseline: https://github.com/SamsungLabs/image_harmonization"
|
72 |
+
self.decoder = DeconvDecoder(depth, self.encoder.blocks_channels, norm_layer, attend_from, image_fusion)
|
73 |
+
|
74 |
+
def forward(self, image, mask, backbone_features=None, coord=None, start_proportion=None):
|
75 |
+
if self.opt.INRDecode and self.opt.hr_train and (self.training or hasattr(self.opt, 'split_num') or hasattr(self.opt, 'split_resolution')):
|
76 |
+
x = torch.cat((torchvision.transforms.Resize([self.opt.base_size, self.opt.base_size])(image[0]),
|
77 |
+
torchvision.transforms.Resize([self.opt.base_size, self.opt.base_size])(mask[0])), dim=1)
|
78 |
+
else:
|
79 |
+
x = torch.cat((torchvision.transforms.Resize([self.opt.base_size, self.opt.base_size])(image),
|
80 |
+
torchvision.transforms.Resize([self.opt.base_size, self.opt.base_size])(mask)), dim=1)
|
81 |
+
|
82 |
+
intermediates = self.encoder(x, backbone_features)
|
83 |
+
|
84 |
+
if self.opt.INRDecode and self.opt.hr_train and (self.training or hasattr(self.opt, 'split_num') or hasattr(self.opt, 'split_resolution')):
|
85 |
+
output = self.decoder(intermediates, image[1], mask[1], coord_samples=coord, start_proportion=start_proportion)
|
86 |
+
else:
|
87 |
+
output = self.decoder(intermediates, image, mask)
|
88 |
+
return output
|
model/base/ops.py
ADDED
@@ -0,0 +1,397 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import nn as nn
|
3 |
+
import numpy as np
|
4 |
+
import math
|
5 |
+
import torch.nn.functional as F
|
6 |
+
|
7 |
+
|
8 |
+
class SimpleInputFusion(nn.Module):
|
9 |
+
def __init__(self, add_ch=1, rgb_ch=3, ch=8, norm_layer=nn.BatchNorm2d):
|
10 |
+
super(SimpleInputFusion, self).__init__()
|
11 |
+
|
12 |
+
self.fusion_conv = nn.Sequential(
|
13 |
+
nn.Conv2d(in_channels=add_ch + rgb_ch, out_channels=ch, kernel_size=1),
|
14 |
+
nn.LeakyReLU(negative_slope=0.2),
|
15 |
+
norm_layer(ch),
|
16 |
+
nn.Conv2d(in_channels=ch, out_channels=rgb_ch, kernel_size=1),
|
17 |
+
)
|
18 |
+
|
19 |
+
def forward(self, image, additional_input):
|
20 |
+
return self.fusion_conv(torch.cat((image, additional_input), dim=1))
|
21 |
+
|
22 |
+
|
23 |
+
class MaskedChannelAttention(nn.Module):
|
24 |
+
def __init__(self, in_channels, *args, **kwargs):
|
25 |
+
super(MaskedChannelAttention, self).__init__()
|
26 |
+
self.global_max_pool = MaskedGlobalMaxPool2d()
|
27 |
+
self.global_avg_pool = FastGlobalAvgPool2d()
|
28 |
+
|
29 |
+
intermediate_channels_count = max(in_channels // 16, 8)
|
30 |
+
self.attention_transform = nn.Sequential(
|
31 |
+
nn.Linear(3 * in_channels, intermediate_channels_count),
|
32 |
+
nn.ReLU(inplace=True),
|
33 |
+
nn.Linear(intermediate_channels_count, in_channels),
|
34 |
+
nn.Sigmoid(),
|
35 |
+
)
|
36 |
+
|
37 |
+
def forward(self, x, mask):
|
38 |
+
if mask.shape[2:] != x.shape[:2]:
|
39 |
+
mask = nn.functional.interpolate(
|
40 |
+
mask, size=x.size()[-2:],
|
41 |
+
mode='bilinear', align_corners=True
|
42 |
+
)
|
43 |
+
pooled_x = torch.cat([
|
44 |
+
self.global_max_pool(x, mask),
|
45 |
+
self.global_avg_pool(x)
|
46 |
+
], dim=1)
|
47 |
+
channel_attention_weights = self.attention_transform(pooled_x)[..., None, None]
|
48 |
+
|
49 |
+
return channel_attention_weights * x
|
50 |
+
|
51 |
+
|
52 |
+
class MaskedGlobalMaxPool2d(nn.Module):
|
53 |
+
def __init__(self):
|
54 |
+
super().__init__()
|
55 |
+
self.global_max_pool = FastGlobalMaxPool2d()
|
56 |
+
|
57 |
+
def forward(self, x, mask):
|
58 |
+
return torch.cat((
|
59 |
+
self.global_max_pool(x * mask),
|
60 |
+
self.global_max_pool(x * (1.0 - mask))
|
61 |
+
), dim=1)
|
62 |
+
|
63 |
+
|
64 |
+
class FastGlobalAvgPool2d(nn.Module):
|
65 |
+
def __init__(self):
|
66 |
+
super(FastGlobalAvgPool2d, self).__init__()
|
67 |
+
|
68 |
+
def forward(self, x):
|
69 |
+
in_size = x.size()
|
70 |
+
return x.view((in_size[0], in_size[1], -1)).mean(dim=2)
|
71 |
+
|
72 |
+
|
73 |
+
class FastGlobalMaxPool2d(nn.Module):
|
74 |
+
def __init__(self):
|
75 |
+
super(FastGlobalMaxPool2d, self).__init__()
|
76 |
+
|
77 |
+
def forward(self, x):
|
78 |
+
in_size = x.size()
|
79 |
+
return x.view((in_size[0], in_size[1], -1)).max(dim=2)[0]
|
80 |
+
|
81 |
+
|
82 |
+
class ScaleLayer(nn.Module):
|
83 |
+
def __init__(self, init_value=1.0, lr_mult=1):
|
84 |
+
super().__init__()
|
85 |
+
self.lr_mult = lr_mult
|
86 |
+
self.scale = nn.Parameter(
|
87 |
+
torch.full((1,), init_value / lr_mult, dtype=torch.float32)
|
88 |
+
)
|
89 |
+
|
90 |
+
def forward(self, x):
|
91 |
+
scale = torch.abs(self.scale * self.lr_mult)
|
92 |
+
return x * scale
|
93 |
+
|
94 |
+
|
95 |
+
class FeaturesConnector(nn.Module):
|
96 |
+
def __init__(self, mode, in_channels, feature_channels, out_channels):
|
97 |
+
super(FeaturesConnector, self).__init__()
|
98 |
+
self.mode = mode if feature_channels else ''
|
99 |
+
|
100 |
+
if self.mode == 'catc':
|
101 |
+
self.reduce_conv = nn.Conv2d(in_channels + feature_channels, out_channels, kernel_size=1)
|
102 |
+
elif self.mode == 'sum':
|
103 |
+
self.reduce_conv = nn.Conv2d(feature_channels, out_channels, kernel_size=1)
|
104 |
+
|
105 |
+
self.output_channels = out_channels if self.mode != 'cat' else in_channels + feature_channels
|
106 |
+
|
107 |
+
def forward(self, x, features):
|
108 |
+
if self.mode == 'cat':
|
109 |
+
return torch.cat((x, features), 1)
|
110 |
+
if self.mode == 'catc':
|
111 |
+
return self.reduce_conv(torch.cat((x, features), 1))
|
112 |
+
if self.mode == 'sum':
|
113 |
+
return self.reduce_conv(features) + x
|
114 |
+
return x
|
115 |
+
|
116 |
+
def extra_repr(self):
|
117 |
+
return self.mode
|
118 |
+
|
119 |
+
|
120 |
+
class PosEncodingNeRF(nn.Module):
|
121 |
+
def __init__(self, in_features, sidelength=None, fn_samples=None, use_nyquist=True):
|
122 |
+
super().__init__()
|
123 |
+
|
124 |
+
self.in_features = in_features
|
125 |
+
|
126 |
+
if self.in_features == 3:
|
127 |
+
self.num_frequencies = 10
|
128 |
+
elif self.in_features == 2:
|
129 |
+
assert sidelength is not None
|
130 |
+
if isinstance(sidelength, int):
|
131 |
+
sidelength = (sidelength, sidelength)
|
132 |
+
self.num_frequencies = 4
|
133 |
+
if use_nyquist:
|
134 |
+
self.num_frequencies = self.get_num_frequencies_nyquist(min(sidelength[0], sidelength[1]))
|
135 |
+
elif self.in_features == 1:
|
136 |
+
assert fn_samples is not None
|
137 |
+
self.num_frequencies = 4
|
138 |
+
if use_nyquist:
|
139 |
+
self.num_frequencies = self.get_num_frequencies_nyquist(fn_samples)
|
140 |
+
|
141 |
+
self.out_dim = in_features + 2 * in_features * self.num_frequencies
|
142 |
+
|
143 |
+
def get_num_frequencies_nyquist(self, samples):
|
144 |
+
nyquist_rate = 1 / (2 * (2 * 1 / samples))
|
145 |
+
return int(math.floor(math.log(nyquist_rate, 2)))
|
146 |
+
|
147 |
+
def forward(self, coords):
|
148 |
+
coords = coords.view(coords.shape[0], -1, self.in_features)
|
149 |
+
|
150 |
+
coords_pos_enc = coords
|
151 |
+
for i in range(self.num_frequencies):
|
152 |
+
for j in range(self.in_features):
|
153 |
+
c = coords[..., j]
|
154 |
+
|
155 |
+
sin = torch.unsqueeze(torch.sin((2 ** i) * np.pi * c), -1)
|
156 |
+
cos = torch.unsqueeze(torch.cos((2 ** i) * np.pi * c), -1)
|
157 |
+
|
158 |
+
coords_pos_enc = torch.cat((coords_pos_enc, sin, cos), axis=-1)
|
159 |
+
|
160 |
+
return coords_pos_enc.reshape(coords.shape[0], -1, self.out_dim)
|
161 |
+
|
162 |
+
|
163 |
+
class RandomFourier(nn.Module):
|
164 |
+
def __init__(self, std_scale, embedding_length, device):
|
165 |
+
super().__init__()
|
166 |
+
|
167 |
+
self.embed = torch.normal(0, 1, (2, embedding_length)) * std_scale
|
168 |
+
self.embed = self.embed.to(device)
|
169 |
+
|
170 |
+
self.out_dim = embedding_length * 2 + 2
|
171 |
+
|
172 |
+
def forward(self, coords):
|
173 |
+
coords_pos_enc = torch.cat([torch.sin(torch.matmul(2 * np.pi * coords, self.embed)),
|
174 |
+
torch.cos(torch.matmul(2 * np.pi * coords, self.embed))], dim=-1)
|
175 |
+
|
176 |
+
return torch.cat([coords, coords_pos_enc.reshape(coords.shape[0], -1, self.out_dim)], dim=-1)
|
177 |
+
|
178 |
+
|
179 |
+
class CIPS_embed(nn.Module):
|
180 |
+
def __init__(self, size, embedding_length):
|
181 |
+
super().__init__()
|
182 |
+
self.fourier_embed = ConstantInput(size, embedding_length)
|
183 |
+
self.predict_embed = Predict_embed(embedding_length)
|
184 |
+
self.out_dim = embedding_length * 2 + 2
|
185 |
+
|
186 |
+
def forward(self, coord, res=None):
|
187 |
+
x = self.predict_embed(coord)
|
188 |
+
y = self.fourier_embed(x, coord, res)
|
189 |
+
|
190 |
+
return torch.cat([coord, x, y], dim=-1)
|
191 |
+
|
192 |
+
|
193 |
+
class Predict_embed(nn.Module):
|
194 |
+
def __init__(self, embedding_length):
|
195 |
+
super(Predict_embed, self).__init__()
|
196 |
+
self.ffm = nn.Linear(2, embedding_length, bias=True)
|
197 |
+
nn.init.uniform_(self.ffm.weight, -np.sqrt(9 / 2), np.sqrt(9 / 2))
|
198 |
+
|
199 |
+
def forward(self, x):
|
200 |
+
x = self.ffm(x)
|
201 |
+
x = torch.sin(x)
|
202 |
+
return x
|
203 |
+
|
204 |
+
|
205 |
+
class ConstantInput(nn.Module):
|
206 |
+
def __init__(self, size, channel):
|
207 |
+
super().__init__()
|
208 |
+
|
209 |
+
self.input = nn.Parameter(torch.randn(1, size ** 2, channel))
|
210 |
+
|
211 |
+
def forward(self, input, coord, resolution=None):
|
212 |
+
batch = input.shape[0]
|
213 |
+
out = self.input.repeat(batch, 1, 1)
|
214 |
+
|
215 |
+
if coord.shape[1] != self.input.shape[1]:
|
216 |
+
x = out.permute(0, 2, 1).contiguous().view(batch, self.input.shape[-1],
|
217 |
+
int(self.input.shape[1] ** 0.5), int(self.input.shape[1] ** 0.5))
|
218 |
+
|
219 |
+
if resolution is None:
|
220 |
+
grid = coord.view(coord.shape[0], int(coord.shape[1] ** 0.5), int(coord.shape[1] ** 0.5), coord.shape[-1])
|
221 |
+
else:
|
222 |
+
grid = coord.view(coord.shape[0], *resolution, coord.shape[-1])
|
223 |
+
|
224 |
+
out = F.grid_sample(x, grid.flip(-1), mode='bilinear', padding_mode='border', align_corners=True)
|
225 |
+
|
226 |
+
out = out.permute(0, 2, 3, 1).contiguous().view(batch, -1, self.input.shape[-1])
|
227 |
+
|
228 |
+
return out
|
229 |
+
|
230 |
+
|
231 |
+
class INRGAN_embed(nn.Module):
|
232 |
+
def __init__(self, resolution: int, w_dim=None):
|
233 |
+
super().__init__()
|
234 |
+
|
235 |
+
self.resolution = resolution
|
236 |
+
self.res_cfg = {"log_emb_size": 32,
|
237 |
+
"random_emb_size": 32,
|
238 |
+
"const_emb_size": 64,
|
239 |
+
"use_cosine": True}
|
240 |
+
self.log_emb_size = self.res_cfg.get('log_emb_size', 0)
|
241 |
+
self.random_emb_size = self.res_cfg.get('random_emb_size', 0)
|
242 |
+
self.shared_emb_size = self.res_cfg.get('shared_emb_size', 0)
|
243 |
+
self.predictable_emb_size = self.res_cfg.get('predictable_emb_size', 0)
|
244 |
+
self.const_emb_size = self.res_cfg.get('const_emb_size', 0)
|
245 |
+
self.fourier_scale = self.res_cfg.get('fourier_scale', np.sqrt(10))
|
246 |
+
self.use_cosine = self.res_cfg.get('use_cosine', False)
|
247 |
+
|
248 |
+
if self.log_emb_size > 0:
|
249 |
+
self.register_buffer('log_basis', generate_logarithmic_basis(
|
250 |
+
resolution, self.log_emb_size, use_diagonal=self.res_cfg.get('use_diagonal', False)))
|
251 |
+
|
252 |
+
if self.random_emb_size > 0:
|
253 |
+
self.register_buffer('random_basis', self.sample_w_matrix((2, self.random_emb_size), self.fourier_scale))
|
254 |
+
|
255 |
+
if self.shared_emb_size > 0:
|
256 |
+
self.shared_basis = nn.Parameter(self.sample_w_matrix((2, self.shared_emb_size), self.fourier_scale))
|
257 |
+
|
258 |
+
if self.predictable_emb_size > 0:
|
259 |
+
self.W_size = self.predictable_emb_size * self.cfg.coord_dim
|
260 |
+
self.b_size = self.predictable_emb_size
|
261 |
+
self.affine = nn.Linear(w_dim, self.W_size + self.b_size)
|
262 |
+
|
263 |
+
if self.const_emb_size > 0:
|
264 |
+
self.const_embs = nn.Parameter(torch.randn(1, resolution ** 2, self.const_emb_size))
|
265 |
+
|
266 |
+
self.out_dim = self.get_total_dim() + 2
|
267 |
+
|
268 |
+
def sample_w_matrix(self, shape, scale: float):
|
269 |
+
return torch.randn(shape) * scale
|
270 |
+
|
271 |
+
def get_total_dim(self) -> int:
|
272 |
+
total_dim = 0
|
273 |
+
if self.log_emb_size > 0:
|
274 |
+
total_dim += self.log_basis.shape[0] * (2 if self.use_cosine else 1)
|
275 |
+
total_dim += self.random_emb_size * (2 if self.use_cosine else 1)
|
276 |
+
total_dim += self.shared_emb_size * (2 if self.use_cosine else 1)
|
277 |
+
total_dim += self.predictable_emb_size * (2 if self.use_cosine else 1)
|
278 |
+
total_dim += self.const_emb_size
|
279 |
+
|
280 |
+
return total_dim
|
281 |
+
|
282 |
+
def forward(self, raw_coords, w=None):
|
283 |
+
batch_size, img_size, in_channels = raw_coords.shape
|
284 |
+
|
285 |
+
raw_embs = []
|
286 |
+
|
287 |
+
if self.log_emb_size > 0:
|
288 |
+
log_bases = self.log_basis.unsqueeze(0).repeat(batch_size, 1, 1).permute(0, 2, 1)
|
289 |
+
raw_log_embs = torch.matmul(raw_coords, log_bases)
|
290 |
+
raw_embs.append(raw_log_embs)
|
291 |
+
|
292 |
+
if self.random_emb_size > 0:
|
293 |
+
random_bases = self.random_basis.unsqueeze(0).repeat(batch_size, 1, 1)
|
294 |
+
raw_random_embs = torch.matmul(raw_coords, random_bases)
|
295 |
+
raw_embs.append(raw_random_embs)
|
296 |
+
|
297 |
+
if self.shared_emb_size > 0:
|
298 |
+
shared_bases = self.shared_basis.unsqueeze(0).repeat(batch_size, 1, 1)
|
299 |
+
raw_shared_embs = torch.matmul(raw_coords, shared_bases)
|
300 |
+
raw_embs.append(raw_shared_embs)
|
301 |
+
|
302 |
+
if self.predictable_emb_size > 0:
|
303 |
+
mod = self.affine(w)
|
304 |
+
W = self.fourier_scale * mod[:, :self.W_size]
|
305 |
+
W = W.view(batch_size, self.cfg.coord_dim, self.predictable_emb_size)
|
306 |
+
bias = mod[:, self.W_size:].view(batch_size, 1, self.predictable_emb_size)
|
307 |
+
raw_predictable_embs = (torch.matmul(raw_coords, W) + bias)
|
308 |
+
raw_embs.append(raw_predictable_embs)
|
309 |
+
|
310 |
+
if len(raw_embs) > 0:
|
311 |
+
raw_embs = torch.cat(raw_embs, dim=-1)
|
312 |
+
raw_embs = raw_embs.contiguous()
|
313 |
+
out = raw_embs.sin()
|
314 |
+
|
315 |
+
if self.use_cosine:
|
316 |
+
out = torch.cat([out, raw_embs.cos()], dim=-1)
|
317 |
+
|
318 |
+
if self.const_emb_size > 0:
|
319 |
+
const_embs = self.const_embs.repeat([batch_size, 1, 1])
|
320 |
+
const_embs = const_embs
|
321 |
+
out = torch.cat([out, const_embs], dim=-1)
|
322 |
+
|
323 |
+
return torch.cat([raw_coords, out], dim=-1)
|
324 |
+
|
325 |
+
|
326 |
+
def generate_logarithmic_basis(
|
327 |
+
resolution,
|
328 |
+
max_num_feats,
|
329 |
+
remove_lowest_freq: bool = False,
|
330 |
+
use_diagonal: bool = True):
|
331 |
+
"""
|
332 |
+
Generates a directional logarithmic basis with the following directions:
|
333 |
+
- horizontal
|
334 |
+
- vertical
|
335 |
+
- main diagonal
|
336 |
+
- anti-diagonal
|
337 |
+
"""
|
338 |
+
max_num_feats_per_direction = np.ceil(np.log2(resolution)).astype(int)
|
339 |
+
bases = [
|
340 |
+
generate_horizontal_basis(max_num_feats_per_direction),
|
341 |
+
generate_vertical_basis(max_num_feats_per_direction),
|
342 |
+
]
|
343 |
+
|
344 |
+
if use_diagonal:
|
345 |
+
bases.extend([
|
346 |
+
generate_diag_main_basis(max_num_feats_per_direction),
|
347 |
+
generate_anti_diag_basis(max_num_feats_per_direction),
|
348 |
+
])
|
349 |
+
|
350 |
+
if remove_lowest_freq:
|
351 |
+
bases = [b[1:] for b in bases]
|
352 |
+
|
353 |
+
# If we do not fit into `max_num_feats`, then trying to remove the features in the order:
|
354 |
+
# 1) anti-diagonal 2) main-diagonal
|
355 |
+
# while (max_num_feats_per_direction * len(bases) > max_num_feats) and (len(bases) > 2):
|
356 |
+
# bases = bases[:-1]
|
357 |
+
|
358 |
+
basis = torch.cat(bases, dim=0)
|
359 |
+
|
360 |
+
# If we still do not fit, then let's remove each second feature,
|
361 |
+
# then each third, each forth and so on
|
362 |
+
# We cannot drop the whole horizontal or vertical direction since otherwise
|
363 |
+
# model won't be able to locate the position
|
364 |
+
# (unless the previously computed embeddings encode the position)
|
365 |
+
# while basis.shape[0] > max_num_feats:
|
366 |
+
# num_exceeding_feats = basis.shape[0] - max_num_feats
|
367 |
+
# basis = basis[::2]
|
368 |
+
|
369 |
+
assert basis.shape[0] <= max_num_feats, \
|
370 |
+
f"num_coord_feats > max_num_fixed_coord_feats: {basis.shape, max_num_feats}."
|
371 |
+
|
372 |
+
return basis
|
373 |
+
|
374 |
+
|
375 |
+
def generate_horizontal_basis(num_feats: int):
|
376 |
+
return generate_wavefront_basis(num_feats, [0.0, 1.0], 4.0)
|
377 |
+
|
378 |
+
|
379 |
+
def generate_vertical_basis(num_feats: int):
|
380 |
+
return generate_wavefront_basis(num_feats, [1.0, 0.0], 4.0)
|
381 |
+
|
382 |
+
|
383 |
+
def generate_diag_main_basis(num_feats: int):
|
384 |
+
return generate_wavefront_basis(num_feats, [-1.0 / np.sqrt(2), 1.0 / np.sqrt(2)], 4.0 * np.sqrt(2))
|
385 |
+
|
386 |
+
|
387 |
+
def generate_anti_diag_basis(num_feats: int):
|
388 |
+
return generate_wavefront_basis(num_feats, [1.0 / np.sqrt(2), 1.0 / np.sqrt(2)], 4.0 * np.sqrt(2))
|
389 |
+
|
390 |
+
|
391 |
+
def generate_wavefront_basis(num_feats: int, basis_block, period_length: float):
|
392 |
+
period_coef = 2.0 * np.pi / period_length
|
393 |
+
basis = torch.tensor([basis_block]).repeat(num_feats, 1) # [num_feats, 2]
|
394 |
+
powers = torch.tensor([2]).repeat(num_feats).pow(torch.arange(num_feats)).unsqueeze(1) # [num_feats, 1]
|
395 |
+
result = basis * powers * period_coef # [num_feats, 2]
|
396 |
+
|
397 |
+
return result.float()
|
model/build_model.py
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn as nn
|
2 |
+
from .backbone import build_backbone
|
3 |
+
|
4 |
+
|
5 |
+
class build_model(nn.Module):
|
6 |
+
def __init__(self, opt):
|
7 |
+
super().__init__()
|
8 |
+
|
9 |
+
self.opt = opt
|
10 |
+
self.backbone = build_backbone('baseline', opt)
|
11 |
+
|
12 |
+
def forward(self, composite_image, mask, fg_INR_coordinates, start_proportion=None):
|
13 |
+
if self.opt.INRDecode and self.opt.hr_train and (self.training or hasattr(self.opt, 'split_num') or hasattr(self.opt, 'split_resolution')):
|
14 |
+
"""
|
15 |
+
For HR Training, due to the designed RSC strategy in Section 3.4 in the paper,
|
16 |
+
here we need to pass in the coordinates of the cropped regions.
|
17 |
+
"""
|
18 |
+
extracted_features = self.backbone(composite_image, mask, fg_INR_coordinates, start_proportion=start_proportion)
|
19 |
+
else:
|
20 |
+
extracted_features = self.backbone(composite_image, mask)
|
21 |
+
|
22 |
+
if self.opt.INRDecode:
|
23 |
+
return extracted_features
|
24 |
+
return None, None, extracted_features
|
model/hrnetv2/__init__.py
ADDED
File without changes
|
model/hrnetv2/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (181 Bytes). View file
|
|
model/hrnetv2/__pycache__/hrnet_ocr.cpython-38.pyc
ADDED
Binary file (10.6 kB). View file
|
|
model/hrnetv2/__pycache__/modifiers.cpython-38.pyc
ADDED
Binary file (704 Bytes). View file
|
|
model/hrnetv2/__pycache__/ocr.cpython-38.pyc
ADDED
Binary file (4.54 kB). View file
|
|
model/hrnetv2/__pycache__/resnetv1b.cpython-38.pyc
ADDED
Binary file (7.54 kB). View file
|
|
model/hrnetv2/hrnet_ocr.py
ADDED
@@ -0,0 +1,400 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
import torch._utils
|
7 |
+
|
8 |
+
from .ocr import SpatialOCR_Module, SpatialGather_Module
|
9 |
+
from .resnetv1b import BasicBlockV1b, BottleneckV1b
|
10 |
+
|
11 |
+
relu_inplace = True
|
12 |
+
|
13 |
+
|
14 |
+
class HighResolutionModule(nn.Module):
|
15 |
+
def __init__(self, num_branches, blocks, num_blocks, num_inchannels,
|
16 |
+
num_channels, fuse_method,multi_scale_output=True,
|
17 |
+
norm_layer=nn.BatchNorm2d, align_corners=True):
|
18 |
+
super(HighResolutionModule, self).__init__()
|
19 |
+
self._check_branches(num_branches, num_blocks, num_inchannels, num_channels)
|
20 |
+
|
21 |
+
self.num_inchannels = num_inchannels
|
22 |
+
self.fuse_method = fuse_method
|
23 |
+
self.num_branches = num_branches
|
24 |
+
self.norm_layer = norm_layer
|
25 |
+
self.align_corners = align_corners
|
26 |
+
|
27 |
+
self.multi_scale_output = multi_scale_output
|
28 |
+
|
29 |
+
self.branches = self._make_branches(
|
30 |
+
num_branches, blocks, num_blocks, num_channels)
|
31 |
+
self.fuse_layers = self._make_fuse_layers()
|
32 |
+
self.relu = nn.ReLU(inplace=relu_inplace)
|
33 |
+
|
34 |
+
def _check_branches(self, num_branches, num_blocks, num_inchannels, num_channels):
|
35 |
+
if num_branches != len(num_blocks):
|
36 |
+
error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format(
|
37 |
+
num_branches, len(num_blocks))
|
38 |
+
raise ValueError(error_msg)
|
39 |
+
|
40 |
+
if num_branches != len(num_channels):
|
41 |
+
error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format(
|
42 |
+
num_branches, len(num_channels))
|
43 |
+
raise ValueError(error_msg)
|
44 |
+
|
45 |
+
if num_branches != len(num_inchannels):
|
46 |
+
error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format(
|
47 |
+
num_branches, len(num_inchannels))
|
48 |
+
raise ValueError(error_msg)
|
49 |
+
|
50 |
+
def _make_one_branch(self, branch_index, block, num_blocks, num_channels,
|
51 |
+
stride=1):
|
52 |
+
downsample = None
|
53 |
+
if stride != 1 or \
|
54 |
+
self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion:
|
55 |
+
downsample = nn.Sequential(
|
56 |
+
nn.Conv2d(self.num_inchannels[branch_index],
|
57 |
+
num_channels[branch_index] * block.expansion,
|
58 |
+
kernel_size=1, stride=stride, bias=False),
|
59 |
+
self.norm_layer(num_channels[branch_index] * block.expansion),
|
60 |
+
)
|
61 |
+
|
62 |
+
layers = []
|
63 |
+
layers.append(block(self.num_inchannels[branch_index],
|
64 |
+
num_channels[branch_index], stride,
|
65 |
+
downsample=downsample, norm_layer=self.norm_layer))
|
66 |
+
self.num_inchannels[branch_index] = \
|
67 |
+
num_channels[branch_index] * block.expansion
|
68 |
+
for i in range(1, num_blocks[branch_index]):
|
69 |
+
layers.append(block(self.num_inchannels[branch_index],
|
70 |
+
num_channels[branch_index],
|
71 |
+
norm_layer=self.norm_layer))
|
72 |
+
|
73 |
+
return nn.Sequential(*layers)
|
74 |
+
|
75 |
+
def _make_branches(self, num_branches, block, num_blocks, num_channels):
|
76 |
+
branches = []
|
77 |
+
|
78 |
+
for i in range(num_branches):
|
79 |
+
branches.append(
|
80 |
+
self._make_one_branch(i, block, num_blocks, num_channels))
|
81 |
+
|
82 |
+
return nn.ModuleList(branches)
|
83 |
+
|
84 |
+
def _make_fuse_layers(self):
|
85 |
+
if self.num_branches == 1:
|
86 |
+
return None
|
87 |
+
|
88 |
+
num_branches = self.num_branches
|
89 |
+
num_inchannels = self.num_inchannels
|
90 |
+
fuse_layers = []
|
91 |
+
for i in range(num_branches if self.multi_scale_output else 1):
|
92 |
+
fuse_layer = []
|
93 |
+
for j in range(num_branches):
|
94 |
+
if j > i:
|
95 |
+
fuse_layer.append(nn.Sequential(
|
96 |
+
nn.Conv2d(in_channels=num_inchannels[j],
|
97 |
+
out_channels=num_inchannels[i],
|
98 |
+
kernel_size=1,
|
99 |
+
bias=False),
|
100 |
+
self.norm_layer(num_inchannels[i])))
|
101 |
+
elif j == i:
|
102 |
+
fuse_layer.append(None)
|
103 |
+
else:
|
104 |
+
conv3x3s = []
|
105 |
+
for k in range(i - j):
|
106 |
+
if k == i - j - 1:
|
107 |
+
num_outchannels_conv3x3 = num_inchannels[i]
|
108 |
+
conv3x3s.append(nn.Sequential(
|
109 |
+
nn.Conv2d(num_inchannels[j],
|
110 |
+
num_outchannels_conv3x3,
|
111 |
+
kernel_size=3, stride=2, padding=1, bias=False),
|
112 |
+
self.norm_layer(num_outchannels_conv3x3)))
|
113 |
+
else:
|
114 |
+
num_outchannels_conv3x3 = num_inchannels[j]
|
115 |
+
conv3x3s.append(nn.Sequential(
|
116 |
+
nn.Conv2d(num_inchannels[j],
|
117 |
+
num_outchannels_conv3x3,
|
118 |
+
kernel_size=3, stride=2, padding=1, bias=False),
|
119 |
+
self.norm_layer(num_outchannels_conv3x3),
|
120 |
+
nn.ReLU(inplace=relu_inplace)))
|
121 |
+
fuse_layer.append(nn.Sequential(*conv3x3s))
|
122 |
+
fuse_layers.append(nn.ModuleList(fuse_layer))
|
123 |
+
|
124 |
+
return nn.ModuleList(fuse_layers)
|
125 |
+
|
126 |
+
def get_num_inchannels(self):
|
127 |
+
return self.num_inchannels
|
128 |
+
|
129 |
+
def forward(self, x):
|
130 |
+
if self.num_branches == 1:
|
131 |
+
return [self.branches[0](x[0])]
|
132 |
+
|
133 |
+
for i in range(self.num_branches):
|
134 |
+
x[i] = self.branches[i](x[i])
|
135 |
+
|
136 |
+
x_fuse = []
|
137 |
+
for i in range(len(self.fuse_layers)):
|
138 |
+
y = x[0] if i == 0 else self.fuse_layers[i][0](x[0])
|
139 |
+
for j in range(1, self.num_branches):
|
140 |
+
if i == j:
|
141 |
+
y = y + x[j]
|
142 |
+
elif j > i:
|
143 |
+
width_output = x[i].shape[-1]
|
144 |
+
height_output = x[i].shape[-2]
|
145 |
+
y = y + F.interpolate(
|
146 |
+
self.fuse_layers[i][j](x[j]),
|
147 |
+
size=[height_output, width_output],
|
148 |
+
mode='bilinear', align_corners=self.align_corners)
|
149 |
+
else:
|
150 |
+
y = y + self.fuse_layers[i][j](x[j])
|
151 |
+
x_fuse.append(self.relu(y))
|
152 |
+
|
153 |
+
return x_fuse
|
154 |
+
|
155 |
+
|
156 |
+
class HighResolutionNet(nn.Module):
|
157 |
+
def __init__(self, width, num_classes, ocr_width=256, small=False,
|
158 |
+
norm_layer=nn.BatchNorm2d, align_corners=True, opt=None):
|
159 |
+
super(HighResolutionNet, self).__init__()
|
160 |
+
self.opt = opt
|
161 |
+
self.norm_layer = norm_layer
|
162 |
+
self.width = width
|
163 |
+
self.ocr_width = ocr_width
|
164 |
+
self.ocr_on = ocr_width > 0
|
165 |
+
self.align_corners = align_corners
|
166 |
+
|
167 |
+
self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, bias=False)
|
168 |
+
self.bn1 = norm_layer(64)
|
169 |
+
self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, bias=False)
|
170 |
+
self.bn2 = norm_layer(64)
|
171 |
+
self.relu = nn.ReLU(inplace=relu_inplace)
|
172 |
+
|
173 |
+
num_blocks = 2 if small else 4
|
174 |
+
|
175 |
+
stage1_num_channels = 64
|
176 |
+
self.layer1 = self._make_layer(BottleneckV1b, 64, stage1_num_channels, blocks=num_blocks)
|
177 |
+
stage1_out_channel = BottleneckV1b.expansion * stage1_num_channels
|
178 |
+
|
179 |
+
self.stage2_num_branches = 2
|
180 |
+
num_channels = [width, 2 * width]
|
181 |
+
num_inchannels = [
|
182 |
+
num_channels[i] * BasicBlockV1b.expansion for i in range(len(num_channels))]
|
183 |
+
self.transition1 = self._make_transition_layer(
|
184 |
+
[stage1_out_channel], num_inchannels)
|
185 |
+
self.stage2, pre_stage_channels = self._make_stage(
|
186 |
+
BasicBlockV1b, num_inchannels=num_inchannels, num_modules=1, num_branches=self.stage2_num_branches,
|
187 |
+
num_blocks=2 * [num_blocks], num_channels=num_channels)
|
188 |
+
|
189 |
+
self.stage3_num_branches = 3
|
190 |
+
num_channels = [width, 2 * width, 4 * width]
|
191 |
+
num_inchannels = [
|
192 |
+
num_channels[i] * BasicBlockV1b.expansion for i in range(len(num_channels))]
|
193 |
+
self.transition2 = self._make_transition_layer(
|
194 |
+
pre_stage_channels, num_inchannels)
|
195 |
+
self.stage3, pre_stage_channels = self._make_stage(
|
196 |
+
BasicBlockV1b, num_inchannels=num_inchannels,
|
197 |
+
num_modules=3 if small else 4, num_branches=self.stage3_num_branches,
|
198 |
+
num_blocks=3 * [num_blocks], num_channels=num_channels)
|
199 |
+
|
200 |
+
self.stage4_num_branches = 4
|
201 |
+
num_channels = [width, 2 * width, 4 * width, 8 * width]
|
202 |
+
num_inchannels = [
|
203 |
+
num_channels[i] * BasicBlockV1b.expansion for i in range(len(num_channels))]
|
204 |
+
self.transition3 = self._make_transition_layer(
|
205 |
+
pre_stage_channels, num_inchannels)
|
206 |
+
self.stage4, pre_stage_channels = self._make_stage(
|
207 |
+
BasicBlockV1b, num_inchannels=num_inchannels, num_modules=2 if small else 3,
|
208 |
+
num_branches=self.stage4_num_branches,
|
209 |
+
num_blocks=4 * [num_blocks], num_channels=num_channels)
|
210 |
+
|
211 |
+
if self.ocr_on:
|
212 |
+
last_inp_channels = np.int(np.sum(pre_stage_channels))
|
213 |
+
ocr_mid_channels = 2 * ocr_width
|
214 |
+
ocr_key_channels = ocr_width
|
215 |
+
|
216 |
+
self.conv3x3_ocr = nn.Sequential(
|
217 |
+
nn.Conv2d(last_inp_channels, ocr_mid_channels,
|
218 |
+
kernel_size=3, stride=1, padding=1),
|
219 |
+
norm_layer(ocr_mid_channels),
|
220 |
+
nn.ReLU(inplace=relu_inplace),
|
221 |
+
)
|
222 |
+
self.ocr_gather_head = SpatialGather_Module(num_classes)
|
223 |
+
|
224 |
+
self.ocr_distri_head = SpatialOCR_Module(in_channels=ocr_mid_channels,
|
225 |
+
key_channels=ocr_key_channels,
|
226 |
+
out_channels=ocr_mid_channels,
|
227 |
+
scale=1,
|
228 |
+
dropout=0.05,
|
229 |
+
norm_layer=norm_layer,
|
230 |
+
align_corners=align_corners, opt=opt)
|
231 |
+
|
232 |
+
def _make_transition_layer(
|
233 |
+
self, num_channels_pre_layer, num_channels_cur_layer):
|
234 |
+
num_branches_cur = len(num_channels_cur_layer)
|
235 |
+
num_branches_pre = len(num_channels_pre_layer)
|
236 |
+
|
237 |
+
transition_layers = []
|
238 |
+
for i in range(num_branches_cur):
|
239 |
+
if i < num_branches_pre:
|
240 |
+
if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
|
241 |
+
transition_layers.append(nn.Sequential(
|
242 |
+
nn.Conv2d(num_channels_pre_layer[i],
|
243 |
+
num_channels_cur_layer[i],
|
244 |
+
kernel_size=3,
|
245 |
+
stride=1,
|
246 |
+
padding=1,
|
247 |
+
bias=False),
|
248 |
+
self.norm_layer(num_channels_cur_layer[i]),
|
249 |
+
nn.ReLU(inplace=relu_inplace)))
|
250 |
+
else:
|
251 |
+
transition_layers.append(None)
|
252 |
+
else:
|
253 |
+
conv3x3s = []
|
254 |
+
for j in range(i + 1 - num_branches_pre):
|
255 |
+
inchannels = num_channels_pre_layer[-1]
|
256 |
+
outchannels = num_channels_cur_layer[i] \
|
257 |
+
if j == i - num_branches_pre else inchannels
|
258 |
+
conv3x3s.append(nn.Sequential(
|
259 |
+
nn.Conv2d(inchannels, outchannels,
|
260 |
+
kernel_size=3, stride=2, padding=1, bias=False),
|
261 |
+
self.norm_layer(outchannels),
|
262 |
+
nn.ReLU(inplace=relu_inplace)))
|
263 |
+
transition_layers.append(nn.Sequential(*conv3x3s))
|
264 |
+
|
265 |
+
return nn.ModuleList(transition_layers)
|
266 |
+
|
267 |
+
def _make_layer(self, block, inplanes, planes, blocks, stride=1):
|
268 |
+
downsample = None
|
269 |
+
if stride != 1 or inplanes != planes * block.expansion:
|
270 |
+
downsample = nn.Sequential(
|
271 |
+
nn.Conv2d(inplanes, planes * block.expansion,
|
272 |
+
kernel_size=1, stride=stride, bias=False),
|
273 |
+
self.norm_layer(planes * block.expansion),
|
274 |
+
)
|
275 |
+
|
276 |
+
layers = []
|
277 |
+
layers.append(block(inplanes, planes, stride,
|
278 |
+
downsample=downsample, norm_layer=self.norm_layer))
|
279 |
+
inplanes = planes * block.expansion
|
280 |
+
for i in range(1, blocks):
|
281 |
+
layers.append(block(inplanes, planes, norm_layer=self.norm_layer))
|
282 |
+
|
283 |
+
return nn.Sequential(*layers)
|
284 |
+
|
285 |
+
def _make_stage(self, block, num_inchannels,
|
286 |
+
num_modules, num_branches, num_blocks, num_channels,
|
287 |
+
fuse_method='SUM',
|
288 |
+
multi_scale_output=True):
|
289 |
+
modules = []
|
290 |
+
for i in range(num_modules):
|
291 |
+
# multi_scale_output is only used last module
|
292 |
+
if not multi_scale_output and i == num_modules - 1:
|
293 |
+
reset_multi_scale_output = False
|
294 |
+
else:
|
295 |
+
reset_multi_scale_output = True
|
296 |
+
modules.append(
|
297 |
+
HighResolutionModule(num_branches,
|
298 |
+
block,
|
299 |
+
num_blocks,
|
300 |
+
num_inchannels,
|
301 |
+
num_channels,
|
302 |
+
fuse_method,
|
303 |
+
reset_multi_scale_output,
|
304 |
+
norm_layer=self.norm_layer,
|
305 |
+
align_corners=self.align_corners)
|
306 |
+
)
|
307 |
+
num_inchannels = modules[-1].get_num_inchannels()
|
308 |
+
|
309 |
+
return nn.Sequential(*modules), num_inchannels
|
310 |
+
|
311 |
+
def forward(self, x, mask=None, additional_features=None):
|
312 |
+
hrnet_feats = self.compute_hrnet_feats(x, additional_features)
|
313 |
+
if not self.ocr_on:
|
314 |
+
return hrnet_feats,
|
315 |
+
|
316 |
+
ocr_feats = self.conv3x3_ocr(hrnet_feats)
|
317 |
+
mask = nn.functional.interpolate(mask, size=ocr_feats.size()[2:], mode='bilinear', align_corners=True)
|
318 |
+
context = self.ocr_gather_head(ocr_feats, mask)
|
319 |
+
ocr_feats = self.ocr_distri_head(ocr_feats, context)
|
320 |
+
return ocr_feats,
|
321 |
+
|
322 |
+
def compute_hrnet_feats(self, x, additional_features, return_list=False):
|
323 |
+
x = self.compute_pre_stage_features(x, additional_features)
|
324 |
+
x = self.layer1(x)
|
325 |
+
|
326 |
+
x_list = []
|
327 |
+
for i in range(self.stage2_num_branches):
|
328 |
+
if self.transition1[i] is not None:
|
329 |
+
x_list.append(self.transition1[i](x))
|
330 |
+
else:
|
331 |
+
x_list.append(x)
|
332 |
+
y_list = self.stage2(x_list)
|
333 |
+
|
334 |
+
x_list = []
|
335 |
+
for i in range(self.stage3_num_branches):
|
336 |
+
if self.transition2[i] is not None:
|
337 |
+
if i < self.stage2_num_branches:
|
338 |
+
x_list.append(self.transition2[i](y_list[i]))
|
339 |
+
else:
|
340 |
+
x_list.append(self.transition2[i](y_list[-1]))
|
341 |
+
else:
|
342 |
+
x_list.append(y_list[i])
|
343 |
+
y_list = self.stage3(x_list)
|
344 |
+
|
345 |
+
x_list = []
|
346 |
+
for i in range(self.stage4_num_branches):
|
347 |
+
if self.transition3[i] is not None:
|
348 |
+
if i < self.stage3_num_branches:
|
349 |
+
x_list.append(self.transition3[i](y_list[i]))
|
350 |
+
else:
|
351 |
+
x_list.append(self.transition3[i](y_list[-1]))
|
352 |
+
else:
|
353 |
+
x_list.append(y_list[i])
|
354 |
+
x = self.stage4(x_list)
|
355 |
+
|
356 |
+
if return_list:
|
357 |
+
return x
|
358 |
+
|
359 |
+
# Upsampling
|
360 |
+
x0_h, x0_w = x[0].size(2), x[0].size(3)
|
361 |
+
x1 = F.interpolate(x[1], size=(x0_h, x0_w),
|
362 |
+
mode='bilinear', align_corners=self.align_corners)
|
363 |
+
x2 = F.interpolate(x[2], size=(x0_h, x0_w),
|
364 |
+
mode='bilinear', align_corners=self.align_corners)
|
365 |
+
x3 = F.interpolate(x[3], size=(x0_h, x0_w),
|
366 |
+
mode='bilinear', align_corners=self.align_corners)
|
367 |
+
|
368 |
+
return torch.cat([x[0], x1, x2, x3], 1)
|
369 |
+
|
370 |
+
def compute_pre_stage_features(self, x, additional_features):
|
371 |
+
x = self.conv1(x)
|
372 |
+
x = self.bn1(x)
|
373 |
+
x = self.relu(x)
|
374 |
+
if additional_features is not None:
|
375 |
+
x = x + additional_features
|
376 |
+
x = self.conv2(x)
|
377 |
+
x = self.bn2(x)
|
378 |
+
return self.relu(x)
|
379 |
+
|
380 |
+
def load_pretrained_weights(self, pretrained_path=''):
|
381 |
+
model_dict = self.state_dict()
|
382 |
+
|
383 |
+
if not os.path.exists(pretrained_path):
|
384 |
+
print(f'\nFile "{pretrained_path}" does not exist.')
|
385 |
+
print('You need to specify the correct path to the pre-trained weights.\n'
|
386 |
+
'You can download the weights for HRNet from the repository:\n'
|
387 |
+
'https://github.com/HRNet/HRNet-Image-Classification')
|
388 |
+
exit(1)
|
389 |
+
pretrained_dict = torch.load(pretrained_path, map_location={'cuda:0': 'cpu'})
|
390 |
+
pretrained_dict = {k.replace('last_layer', 'aux_head').replace('model.', ''): v for k, v in
|
391 |
+
pretrained_dict.items()}
|
392 |
+
params_count = len(pretrained_dict)
|
393 |
+
|
394 |
+
pretrained_dict = {k: v for k, v in pretrained_dict.items()
|
395 |
+
if k in model_dict.keys()}
|
396 |
+
|
397 |
+
print(f'Loaded {len(pretrained_dict)} of {params_count} pretrained parameters for HRNet')
|
398 |
+
|
399 |
+
model_dict.update(pretrained_dict)
|
400 |
+
self.load_state_dict(model_dict)
|
model/hrnetv2/modifiers.py
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
|
3 |
+
class LRMult(object):
|
4 |
+
def __init__(self, lr_mult=1.):
|
5 |
+
self.lr_mult = lr_mult
|
6 |
+
|
7 |
+
def __call__(self, m):
|
8 |
+
if getattr(m, 'weight', None) is not None:
|
9 |
+
m.weight.lr_mult = self.lr_mult
|
10 |
+
if getattr(m, 'bias', None) is not None:
|
11 |
+
m.bias.lr_mult = self.lr_mult
|
model/hrnetv2/ocr.py
ADDED
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch._utils
|
4 |
+
import torch.nn.functional as F
|
5 |
+
|
6 |
+
|
7 |
+
class SpatialGather_Module(nn.Module):
|
8 |
+
"""
|
9 |
+
Aggregate the context features according to the initial
|
10 |
+
predicted probability distribution.
|
11 |
+
Employ the soft-weighted method to aggregate the context.
|
12 |
+
"""
|
13 |
+
|
14 |
+
def __init__(self, cls_num=0, scale=1):
|
15 |
+
super(SpatialGather_Module, self).__init__()
|
16 |
+
self.cls_num = cls_num
|
17 |
+
self.scale = scale
|
18 |
+
|
19 |
+
def forward(self, feats, probs):
|
20 |
+
batch_size, c, h, w = probs.size(0), probs.size(1), probs.size(2), probs.size(3)
|
21 |
+
probs = probs.view(batch_size, c, -1)
|
22 |
+
feats = feats.view(batch_size, feats.size(1), -1)
|
23 |
+
feats = feats.permute(0, 2, 1) # batch x hw x c
|
24 |
+
probs = F.softmax(self.scale * probs, dim=2) # batch x k x hw
|
25 |
+
ocr_context = torch.matmul(probs, feats) \
|
26 |
+
.permute(0, 2, 1).unsqueeze(3).contiguous() # batch x k x c
|
27 |
+
return ocr_context
|
28 |
+
|
29 |
+
|
30 |
+
class SpatialOCR_Module(nn.Module):
|
31 |
+
"""
|
32 |
+
Implementation of the OCR module:
|
33 |
+
We aggregate the global object representation to update the representation for each pixel.
|
34 |
+
"""
|
35 |
+
|
36 |
+
def __init__(self,
|
37 |
+
in_channels,
|
38 |
+
key_channels,
|
39 |
+
out_channels,
|
40 |
+
scale=1,
|
41 |
+
dropout=0.1,
|
42 |
+
norm_layer=nn.BatchNorm2d,
|
43 |
+
align_corners=True, opt=None):
|
44 |
+
super(SpatialOCR_Module, self).__init__()
|
45 |
+
self.object_context_block = ObjectAttentionBlock2D(in_channels, key_channels, scale,
|
46 |
+
norm_layer, align_corners)
|
47 |
+
_in_channels = 2 * in_channels
|
48 |
+
self.conv_bn_dropout = nn.Sequential(
|
49 |
+
nn.Conv2d(_in_channels, out_channels, kernel_size=1, padding=0, bias=False),
|
50 |
+
nn.Sequential(norm_layer(out_channels), nn.ReLU(inplace=True)),
|
51 |
+
nn.Dropout2d(dropout)
|
52 |
+
)
|
53 |
+
|
54 |
+
def forward(self, feats, proxy_feats):
|
55 |
+
context = self.object_context_block(feats, proxy_feats)
|
56 |
+
|
57 |
+
output = self.conv_bn_dropout(torch.cat([context, feats], 1))
|
58 |
+
|
59 |
+
return output
|
60 |
+
|
61 |
+
|
62 |
+
class ObjectAttentionBlock2D(nn.Module):
|
63 |
+
'''
|
64 |
+
The basic implementation for object context block
|
65 |
+
Input:
|
66 |
+
N X C X H X W
|
67 |
+
Parameters:
|
68 |
+
in_channels : the dimension of the input feature map
|
69 |
+
key_channels : the dimension after the key/query transform
|
70 |
+
scale : choose the scale to downsample the input feature maps (save memory cost)
|
71 |
+
bn_type : specify the bn type
|
72 |
+
Return:
|
73 |
+
N X C X H X W
|
74 |
+
'''
|
75 |
+
|
76 |
+
def __init__(self,
|
77 |
+
in_channels,
|
78 |
+
key_channels,
|
79 |
+
scale=1,
|
80 |
+
norm_layer=nn.BatchNorm2d,
|
81 |
+
align_corners=True):
|
82 |
+
super(ObjectAttentionBlock2D, self).__init__()
|
83 |
+
self.scale = scale
|
84 |
+
self.in_channels = in_channels
|
85 |
+
self.key_channels = key_channels
|
86 |
+
self.align_corners = align_corners
|
87 |
+
|
88 |
+
self.pool = nn.MaxPool2d(kernel_size=(scale, scale))
|
89 |
+
self.f_pixel = nn.Sequential(
|
90 |
+
nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels,
|
91 |
+
kernel_size=1, stride=1, padding=0, bias=False),
|
92 |
+
nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True)),
|
93 |
+
nn.Conv2d(in_channels=self.key_channels, out_channels=self.key_channels,
|
94 |
+
kernel_size=1, stride=1, padding=0, bias=False),
|
95 |
+
nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True))
|
96 |
+
)
|
97 |
+
self.f_object = nn.Sequential(
|
98 |
+
nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels,
|
99 |
+
kernel_size=1, stride=1, padding=0, bias=False),
|
100 |
+
nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True)),
|
101 |
+
nn.Conv2d(in_channels=self.key_channels, out_channels=self.key_channels,
|
102 |
+
kernel_size=1, stride=1, padding=0, bias=False),
|
103 |
+
nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True))
|
104 |
+
)
|
105 |
+
self.f_down = nn.Sequential(
|
106 |
+
nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels,
|
107 |
+
kernel_size=1, stride=1, padding=0, bias=False),
|
108 |
+
nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True))
|
109 |
+
)
|
110 |
+
self.f_up = nn.Sequential(
|
111 |
+
nn.Conv2d(in_channels=self.key_channels, out_channels=self.in_channels,
|
112 |
+
kernel_size=1, stride=1, padding=0, bias=False),
|
113 |
+
nn.Sequential(norm_layer(self.in_channels), nn.ReLU(inplace=True))
|
114 |
+
)
|
115 |
+
|
116 |
+
def forward(self, x, proxy):
|
117 |
+
batch_size, h, w = x.size(0), x.size(2), x.size(3)
|
118 |
+
if self.scale > 1:
|
119 |
+
x = self.pool(x)
|
120 |
+
|
121 |
+
query = self.f_pixel(x).view(batch_size, self.key_channels, -1)
|
122 |
+
query = query.permute(0, 2, 1)
|
123 |
+
key = self.f_object(proxy).view(batch_size, self.key_channels, -1)
|
124 |
+
value = self.f_down(proxy).view(batch_size, self.key_channels, -1)
|
125 |
+
value = value.permute(0, 2, 1)
|
126 |
+
|
127 |
+
sim_map = torch.matmul(query, key)
|
128 |
+
sim_map = (self.key_channels ** -.5) * sim_map
|
129 |
+
sim_map = F.softmax(sim_map, dim=-1)
|
130 |
+
|
131 |
+
# add bg context ...
|
132 |
+
context = torch.matmul(sim_map, value)
|
133 |
+
context = context.permute(0, 2, 1).contiguous()
|
134 |
+
context = context.view(batch_size, self.key_channels, *x.size()[2:])
|
135 |
+
context = self.f_up(context)
|
136 |
+
if self.scale > 1:
|
137 |
+
context = F.interpolate(input=context, size=(h, w),
|
138 |
+
mode='bilinear', align_corners=self.align_corners)
|
139 |
+
|
140 |
+
return context
|
model/hrnetv2/resnetv1b.py
ADDED
@@ -0,0 +1,276 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
GLUON_RESNET_TORCH_HUB = 'rwightman/pytorch-pretrained-gluonresnet'
|
4 |
+
|
5 |
+
|
6 |
+
class BasicBlockV1b(nn.Module):
|
7 |
+
expansion = 1
|
8 |
+
|
9 |
+
def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None,
|
10 |
+
previous_dilation=1, norm_layer=nn.BatchNorm2d):
|
11 |
+
super(BasicBlockV1b, self).__init__()
|
12 |
+
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride,
|
13 |
+
padding=dilation, dilation=dilation, bias=False)
|
14 |
+
self.bn1 = norm_layer(planes)
|
15 |
+
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1,
|
16 |
+
padding=previous_dilation, dilation=previous_dilation, bias=False)
|
17 |
+
self.bn2 = norm_layer(planes)
|
18 |
+
|
19 |
+
self.relu = nn.ReLU(inplace=True)
|
20 |
+
self.downsample = downsample
|
21 |
+
self.stride = stride
|
22 |
+
|
23 |
+
def forward(self, x):
|
24 |
+
residual = x
|
25 |
+
|
26 |
+
out = self.conv1(x)
|
27 |
+
out = self.bn1(out)
|
28 |
+
out = self.relu(out)
|
29 |
+
|
30 |
+
out = self.conv2(out)
|
31 |
+
out = self.bn2(out)
|
32 |
+
|
33 |
+
if self.downsample is not None:
|
34 |
+
residual = self.downsample(x)
|
35 |
+
|
36 |
+
out = out + residual
|
37 |
+
out = self.relu(out)
|
38 |
+
|
39 |
+
return out
|
40 |
+
|
41 |
+
|
42 |
+
class BottleneckV1b(nn.Module):
|
43 |
+
expansion = 4
|
44 |
+
|
45 |
+
def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None,
|
46 |
+
previous_dilation=1, norm_layer=nn.BatchNorm2d):
|
47 |
+
super(BottleneckV1b, self).__init__()
|
48 |
+
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
|
49 |
+
self.bn1 = norm_layer(planes)
|
50 |
+
|
51 |
+
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
|
52 |
+
padding=dilation, dilation=dilation, bias=False)
|
53 |
+
self.bn2 = norm_layer(planes)
|
54 |
+
|
55 |
+
self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
|
56 |
+
self.bn3 = norm_layer(planes * self.expansion)
|
57 |
+
|
58 |
+
self.relu = nn.ReLU(inplace=True)
|
59 |
+
self.downsample = downsample
|
60 |
+
self.stride = stride
|
61 |
+
|
62 |
+
def forward(self, x):
|
63 |
+
residual = x
|
64 |
+
|
65 |
+
out = self.conv1(x)
|
66 |
+
out = self.bn1(out)
|
67 |
+
out = self.relu(out)
|
68 |
+
|
69 |
+
out = self.conv2(out)
|
70 |
+
out = self.bn2(out)
|
71 |
+
out = self.relu(out)
|
72 |
+
|
73 |
+
out = self.conv3(out)
|
74 |
+
out = self.bn3(out)
|
75 |
+
|
76 |
+
if self.downsample is not None:
|
77 |
+
residual = self.downsample(x)
|
78 |
+
|
79 |
+
out = out + residual
|
80 |
+
out = self.relu(out)
|
81 |
+
|
82 |
+
return out
|
83 |
+
|
84 |
+
|
85 |
+
class ResNetV1b(nn.Module):
|
86 |
+
""" Pre-trained ResNetV1b Model, which produces the strides of 8 featuremaps at conv5.
|
87 |
+
|
88 |
+
Parameters
|
89 |
+
----------
|
90 |
+
block : Block
|
91 |
+
Class for the residual block. Options are BasicBlockV1, BottleneckV1.
|
92 |
+
layers : list of int
|
93 |
+
Numbers of layers in each block
|
94 |
+
classes : int, default 1000
|
95 |
+
Number of classification classes.
|
96 |
+
dilated : bool, default False
|
97 |
+
Applying dilation strategy to pretrained ResNet yielding a stride-8 model,
|
98 |
+
typically used in Semantic Segmentation.
|
99 |
+
norm_layer : object
|
100 |
+
Normalization layer used (default: :class:`nn.BatchNorm2d`)
|
101 |
+
deep_stem : bool, default False
|
102 |
+
Whether to replace the 7x7 conv1 with 3 3x3 convolution layers.
|
103 |
+
avg_down : bool, default False
|
104 |
+
Whether to use average pooling for projection skip connection between stages/downsample.
|
105 |
+
final_drop : float, default 0.0
|
106 |
+
Dropout ratio before the final classification layer.
|
107 |
+
|
108 |
+
Reference:
|
109 |
+
- He, Kaiming, et al. "Deep residual learning for image recognition."
|
110 |
+
Proceedings of the IEEE conference on computer vision and pattern recognition. 2016.
|
111 |
+
|
112 |
+
- Yu, Fisher, and Vladlen Koltun. "Multi-scale context aggregation by dilated convolutions."
|
113 |
+
"""
|
114 |
+
def __init__(self, block, layers, classes=1000, dilated=True, deep_stem=False, stem_width=32,
|
115 |
+
avg_down=False, final_drop=0.0, norm_layer=nn.BatchNorm2d):
|
116 |
+
self.inplanes = stem_width*2 if deep_stem else 64
|
117 |
+
super(ResNetV1b, self).__init__()
|
118 |
+
if not deep_stem:
|
119 |
+
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
|
120 |
+
else:
|
121 |
+
self.conv1 = nn.Sequential(
|
122 |
+
nn.Conv2d(3, stem_width, kernel_size=3, stride=2, padding=1, bias=False),
|
123 |
+
norm_layer(stem_width),
|
124 |
+
nn.ReLU(True),
|
125 |
+
nn.Conv2d(stem_width, stem_width, kernel_size=3, stride=1, padding=1, bias=False),
|
126 |
+
norm_layer(stem_width),
|
127 |
+
nn.ReLU(True),
|
128 |
+
nn.Conv2d(stem_width, 2*stem_width, kernel_size=3, stride=1, padding=1, bias=False)
|
129 |
+
)
|
130 |
+
self.bn1 = norm_layer(self.inplanes)
|
131 |
+
self.relu = nn.ReLU(True)
|
132 |
+
self.maxpool = nn.MaxPool2d(3, stride=2, padding=1)
|
133 |
+
self.layer1 = self._make_layer(block, 64, layers[0], avg_down=avg_down,
|
134 |
+
norm_layer=norm_layer)
|
135 |
+
self.layer2 = self._make_layer(block, 128, layers[1], stride=2, avg_down=avg_down,
|
136 |
+
norm_layer=norm_layer)
|
137 |
+
if dilated:
|
138 |
+
self.layer3 = self._make_layer(block, 256, layers[2], stride=1, dilation=2,
|
139 |
+
avg_down=avg_down, norm_layer=norm_layer)
|
140 |
+
self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilation=4,
|
141 |
+
avg_down=avg_down, norm_layer=norm_layer)
|
142 |
+
else:
|
143 |
+
self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
|
144 |
+
avg_down=avg_down, norm_layer=norm_layer)
|
145 |
+
self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
|
146 |
+
avg_down=avg_down, norm_layer=norm_layer)
|
147 |
+
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
|
148 |
+
self.drop = None
|
149 |
+
if final_drop > 0.0:
|
150 |
+
self.drop = nn.Dropout(final_drop)
|
151 |
+
self.fc = nn.Linear(512 * block.expansion, classes)
|
152 |
+
|
153 |
+
def _make_layer(self, block, planes, blocks, stride=1, dilation=1,
|
154 |
+
avg_down=False, norm_layer=nn.BatchNorm2d):
|
155 |
+
downsample = None
|
156 |
+
if stride != 1 or self.inplanes != planes * block.expansion:
|
157 |
+
downsample = []
|
158 |
+
if avg_down:
|
159 |
+
if dilation == 1:
|
160 |
+
downsample.append(
|
161 |
+
nn.AvgPool2d(kernel_size=stride, stride=stride, ceil_mode=True, count_include_pad=False)
|
162 |
+
)
|
163 |
+
else:
|
164 |
+
downsample.append(
|
165 |
+
nn.AvgPool2d(kernel_size=1, stride=1, ceil_mode=True, count_include_pad=False)
|
166 |
+
)
|
167 |
+
downsample.extend([
|
168 |
+
nn.Conv2d(self.inplanes, out_channels=planes * block.expansion,
|
169 |
+
kernel_size=1, stride=1, bias=False),
|
170 |
+
norm_layer(planes * block.expansion)
|
171 |
+
])
|
172 |
+
downsample = nn.Sequential(*downsample)
|
173 |
+
else:
|
174 |
+
downsample = nn.Sequential(
|
175 |
+
nn.Conv2d(self.inplanes, out_channels=planes * block.expansion,
|
176 |
+
kernel_size=1, stride=stride, bias=False),
|
177 |
+
norm_layer(planes * block.expansion)
|
178 |
+
)
|
179 |
+
|
180 |
+
layers = []
|
181 |
+
if dilation in (1, 2):
|
182 |
+
layers.append(block(self.inplanes, planes, stride, dilation=1, downsample=downsample,
|
183 |
+
previous_dilation=dilation, norm_layer=norm_layer))
|
184 |
+
elif dilation == 4:
|
185 |
+
layers.append(block(self.inplanes, planes, stride, dilation=2, downsample=downsample,
|
186 |
+
previous_dilation=dilation, norm_layer=norm_layer))
|
187 |
+
else:
|
188 |
+
raise RuntimeError("=> unknown dilation size: {}".format(dilation))
|
189 |
+
|
190 |
+
self.inplanes = planes * block.expansion
|
191 |
+
for _ in range(1, blocks):
|
192 |
+
layers.append(block(self.inplanes, planes, dilation=dilation,
|
193 |
+
previous_dilation=dilation, norm_layer=norm_layer))
|
194 |
+
|
195 |
+
return nn.Sequential(*layers)
|
196 |
+
|
197 |
+
def forward(self, x):
|
198 |
+
x = self.conv1(x)
|
199 |
+
x = self.bn1(x)
|
200 |
+
x = self.relu(x)
|
201 |
+
x = self.maxpool(x)
|
202 |
+
|
203 |
+
x = self.layer1(x)
|
204 |
+
x = self.layer2(x)
|
205 |
+
x = self.layer3(x)
|
206 |
+
x = self.layer4(x)
|
207 |
+
|
208 |
+
x = self.avgpool(x)
|
209 |
+
x = x.view(x.size(0), -1)
|
210 |
+
if self.drop is not None:
|
211 |
+
x = self.drop(x)
|
212 |
+
x = self.fc(x)
|
213 |
+
|
214 |
+
return x
|
215 |
+
|
216 |
+
|
217 |
+
def _safe_state_dict_filtering(orig_dict, model_dict_keys):
|
218 |
+
filtered_orig_dict = {}
|
219 |
+
for k, v in orig_dict.items():
|
220 |
+
if k in model_dict_keys:
|
221 |
+
filtered_orig_dict[k] = v
|
222 |
+
else:
|
223 |
+
print(f"[ERROR] Failed to load <{k}> in backbone")
|
224 |
+
return filtered_orig_dict
|
225 |
+
|
226 |
+
|
227 |
+
def resnet34_v1b(pretrained=False, **kwargs):
|
228 |
+
model = ResNetV1b(BasicBlockV1b, [3, 4, 6, 3], **kwargs)
|
229 |
+
if pretrained:
|
230 |
+
model_dict = model.state_dict()
|
231 |
+
filtered_orig_dict = _safe_state_dict_filtering(
|
232 |
+
torch.hub.load(GLUON_RESNET_TORCH_HUB, 'gluon_resnet34_v1b', pretrained=True).state_dict(),
|
233 |
+
model_dict.keys()
|
234 |
+
)
|
235 |
+
model_dict.update(filtered_orig_dict)
|
236 |
+
model.load_state_dict(model_dict)
|
237 |
+
return model
|
238 |
+
|
239 |
+
|
240 |
+
def resnet50_v1s(pretrained=False, **kwargs):
|
241 |
+
model = ResNetV1b(BottleneckV1b, [3, 4, 6, 3], deep_stem=True, stem_width=64, **kwargs)
|
242 |
+
if pretrained:
|
243 |
+
model_dict = model.state_dict()
|
244 |
+
filtered_orig_dict = _safe_state_dict_filtering(
|
245 |
+
torch.hub.load(GLUON_RESNET_TORCH_HUB, 'gluon_resnet50_v1s', pretrained=True).state_dict(),
|
246 |
+
model_dict.keys()
|
247 |
+
)
|
248 |
+
model_dict.update(filtered_orig_dict)
|
249 |
+
model.load_state_dict(model_dict)
|
250 |
+
return model
|
251 |
+
|
252 |
+
|
253 |
+
def resnet101_v1s(pretrained=False, **kwargs):
|
254 |
+
model = ResNetV1b(BottleneckV1b, [3, 4, 23, 3], deep_stem=True, stem_width=64, **kwargs)
|
255 |
+
if pretrained:
|
256 |
+
model_dict = model.state_dict()
|
257 |
+
filtered_orig_dict = _safe_state_dict_filtering(
|
258 |
+
torch.hub.load(GLUON_RESNET_TORCH_HUB, 'gluon_resnet101_v1s', pretrained=True).state_dict(),
|
259 |
+
model_dict.keys()
|
260 |
+
)
|
261 |
+
model_dict.update(filtered_orig_dict)
|
262 |
+
model.load_state_dict(model_dict)
|
263 |
+
return model
|
264 |
+
|
265 |
+
|
266 |
+
def resnet152_v1s(pretrained=False, **kwargs):
|
267 |
+
model = ResNetV1b(BottleneckV1b, [3, 8, 36, 3], deep_stem=True, stem_width=64, **kwargs)
|
268 |
+
if pretrained:
|
269 |
+
model_dict = model.state_dict()
|
270 |
+
filtered_orig_dict = _safe_state_dict_filtering(
|
271 |
+
torch.hub.load(GLUON_RESNET_TORCH_HUB, 'gluon_resnet152_v1s', pretrained=True).state_dict(),
|
272 |
+
model_dict.keys()
|
273 |
+
)
|
274 |
+
model_dict.update(filtered_orig_dict)
|
275 |
+
model.load_state_dict(model_dict)
|
276 |
+
return model
|
model/lut_transformation_net.py
ADDED
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
from utils.misc import normalize
|
6 |
+
|
7 |
+
|
8 |
+
class build_lut_transform(nn.Module):
|
9 |
+
|
10 |
+
def __init__(self, input_dim, lut_dim, input_resolution, opt):
|
11 |
+
super().__init__()
|
12 |
+
|
13 |
+
self.lut_dim = lut_dim
|
14 |
+
self.opt = opt
|
15 |
+
|
16 |
+
# self.compress_layer = nn.Linear(input_resolution, 1)
|
17 |
+
|
18 |
+
self.transform_layers = nn.Sequential(
|
19 |
+
nn.Linear(input_dim, 3 * lut_dim ** 3, bias=True),
|
20 |
+
# nn.BatchNorm1d(3 * lut_dim ** 3, affine=False),
|
21 |
+
nn.ReLU(inplace=True),
|
22 |
+
nn.Linear(3 * lut_dim ** 3, 3 * lut_dim ** 3, bias=True),
|
23 |
+
)
|
24 |
+
self.transform_layers[-1].apply(lambda m: hyper_weight_init(m))
|
25 |
+
|
26 |
+
def forward(self, composite_image, fg_appearance_features, bg_appearance_features):
|
27 |
+
composite_image = normalize(composite_image, self.opt, 'inv')
|
28 |
+
|
29 |
+
features = fg_appearance_features
|
30 |
+
|
31 |
+
lut_params = self.transform_layers(features)
|
32 |
+
|
33 |
+
fit_3DLUT = lut_params.view(lut_params.shape[0], 3, self.lut_dim, self.lut_dim, self.lut_dim)
|
34 |
+
|
35 |
+
lut_transform_image = torch.stack(
|
36 |
+
[TrilinearInterpolation(lut, image)[0] for lut, image in zip(fit_3DLUT, composite_image)], dim=0)
|
37 |
+
|
38 |
+
return fit_3DLUT, normalize(lut_transform_image, self.opt)
|
39 |
+
|
40 |
+
|
41 |
+
def TrilinearInterpolation(LUT, img):
|
42 |
+
img = (img - 0.5) * 2.
|
43 |
+
|
44 |
+
img = img.unsqueeze(0).permute(0, 2, 3, 1)[:, None].flip(-1)
|
45 |
+
|
46 |
+
# Note that the coordinates in the grid_sample are inverse to LUT DHW, i.e., xyz is to WHD not DHW.
|
47 |
+
LUT = LUT[None]
|
48 |
+
|
49 |
+
# grid sample
|
50 |
+
result = F.grid_sample(LUT, img, mode='bilinear', padding_mode='border', align_corners=True)
|
51 |
+
|
52 |
+
# drop added dimensions and permute back
|
53 |
+
result = result[:, :, 0]
|
54 |
+
|
55 |
+
return result
|
56 |
+
|
57 |
+
|
58 |
+
def hyper_weight_init(m):
|
59 |
+
if hasattr(m, 'weight'):
|
60 |
+
nn.init.kaiming_normal_(m.weight, a=0.0, nonlinearity='relu', mode='fan_in')
|
61 |
+
m.weight.data = m.weight.data / 1.e2
|
62 |
+
|
63 |
+
if hasattr(m, 'bias'):
|
64 |
+
with torch.no_grad():
|
65 |
+
m.bias.uniform_(0., 1.)
|
pretrained_models/Resolution_1024_HAdobe5K.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:4917e99cc20c2530b6d248d530368929c1784113d20365085b96bbb10860a2f8
|
3 |
+
size 477235439
|