commit 34fb220
Author: yichen-purdue
Parent(s): 96d9168

    init
Files changed:
- app.py +214 -0
- configs/GSSN.yaml +57 -0
- configs/SSN.yaml +51 -0
- model_utils.py +53 -0
- models/Attention.ipynb +509 -0
- models/Attention_SSN.py +218 -0
- models/Attention_Unet.py +165 -0
- models/GSSN.py +176 -0
- models/Loss/Loss.py +271 -0
- models/Loss/__init__.py +0 -0
- models/Loss/__pycache__/Loss.cpython-39.pyc +0 -0
- models/Loss/__pycache__/__init__.cpython-39.pyc +0 -0
- models/Loss/__pycache__/vgg19_loss.cpython-39.pyc +0 -0
- models/Loss/pytorch_ssim/__init__.py +73 -0
- models/Loss/pytorch_ssim/__pycache__/__init__.cpython-39.pyc +0 -0
- models/Loss/vgg19_loss.py +54 -0
- models/SSN.py +143 -0
- models/SSN_Model.py +333 -0
- models/SSN_v1.py +290 -0
- models/Sparse_PH.py +185 -0
- models/__init__.py +43 -0
- models/__pycache__/SSN.cpython-39.pyc +0 -0
- models/__pycache__/SSN_Model.cpython-39.pyc +0 -0
- models/__pycache__/__init__.cpython-39.pyc +0 -0
- models/__pycache__/abs_model.cpython-39.pyc +0 -0
- models/__pycache__/blocks.cpython-39.pyc +0 -0
- models/abs_model.py +73 -0
- models/attention.py +85 -0
- models/blocks.py +238 -0
- models/pvt_attention.py +240 -0
- models/template.py +114 -0
- weights/SSN/0000001760.pt +3 -0
app.py
ADDED
@@ -0,0 +1,214 @@
import torch
from torch import nn
import logging

from pathlib import Path
import gradio as gr
import numpy as np
import cv2

import model_utils
from models.SSN import SSN

import matplotlib
matplotlib.use('TkAgg')

import matplotlib.pyplot as plt

config_file = 'configs/SSN.yaml'
weight = 'weights/SSN/0000001760.pt'
device = torch.device('cuda:0')
model = model_utils.load_model(config_file, weight, SSN, device)

DEFAULT_INTENSITY = 0.9
DEFAULT_GAMMA = 2.0

logging.info('Model loading succeeded')

cur_rgba = None
cur_shadow = None
cur_intensity = DEFAULT_INTENSITY
cur_gamma = DEFAULT_GAMMA


def resize(img, size):
    h, w = img.shape[:2]

    if h > w:
        newh = size
        neww = int(w / h * size)
    else:
        neww = size
        newh = int(h / w * size)

    resized_img = cv2.resize(img, (neww, newh), interpolation=cv2.INTER_AREA)
    if len(img.shape) != len(resized_img.shape):
        # cv2.resize drops a singleton channel dimension; restore it
        resized_img = resized_img[..., None]

    return resized_img


def ibl_normalize(ibl, energy=30.0):
    total_energy = np.sum(ibl)
    if total_energy < 1e-3:
        h, w = ibl.shape
        return np.zeros((h, w))

    return ibl * energy / total_energy


def padding_mask(rgba_input: np.ndarray):
    """ Pad the mask input so that it fits the training dataset view range.

    If the rgba does not have enough padding area, we need to pad the area.

    :param rgba_input: H x W x 4 input; the first 3 channels are RGB, the last channel is the alpha
    :returns: H x W x 4 padded RGBA
    """
    padding = 50
    padding_size = 256 - padding * 2

    h, w = rgba_input.shape[:2]
    rgb = rgba_input[:, :, :3]
    alpha = rgba_input[:, :, -1:]

    zeros = np.where(alpha == 0)
    hh, ww = zeros[0], zeros[1]
    h_min, h_max = hh.min(), hh.max()
    w_min, w_max = ww.min(), ww.max()

    # if the area already has enough padding
    if h_max - h_min < padding_size and w_max - w_min < padding_size:
        return rgba_input

    padding_output = np.zeros((256, 256, 4))
    padding_output[..., :3] = 1.0

    padded_rgba = resize(rgba_input, padding_size)
    new_h, new_w = padded_rgba.shape[:2]

    padding_output[padding:padding + new_h, padding:padding + new_w, :] = padded_rgba

    return padding_output


def shadow_composite(rgba, shadow, intensity, gamma):
    rgb = rgba[..., :3]
    mask = rgba[..., 3:]

    if len(shadow.shape) == 2:
        shadow = shadow[..., None]

    new_shadow = 1.0 - shadow ** gamma * intensity
    ret = rgb * mask + (1.0 - mask) * new_shadow
    return ret, new_shadow[..., 0]


def render_btn_fn(mask, ibl):
    global cur_rgba, cur_shadow, cur_gamma, cur_intensity

    print("Button clicked!")

    mask = mask / 255.0
    ibl = ibl / 255.0

    # smooth the ibl
    ibl = cv2.GaussianBlur(ibl, (11, 11), 0)

    # pad the mask
    mask = padding_mask(mask)

    cur_rgba = np.copy(mask)

    print('mask shape: {}/{}/{}/{}, ibl shape: {}/{}/{}/{}'.format(mask.shape, mask.dtype, mask.min(), mask.max(),
                                                                   ibl.shape, ibl.dtype, ibl.min(), ibl.max()))

    rgb, mask = mask[..., :3], mask[..., 3]

    ibl = ibl_normalize(cv2.resize(ibl, (32, 16)))

    x = {
        'mask': mask,
        'ibl': ibl
    }
    shadow = model.inference(x)
    cur_shadow = np.copy(shadow)

    # composite the shadow
    ret, shadow = shadow_composite(cur_rgba, shadow, cur_intensity, cur_gamma)

    print('IBL range: {}/{} Shadow range: {} {}'.format(ibl.min(), ibl.max(), shadow.min(), shadow.max()))

    plt.figure(figsize=(15, 10))
    plt.subplot(1, 3, 1)
    plt.imshow(mask)
    plt.subplot(1, 3, 2)
    plt.imshow(ibl)
    plt.subplot(1, 3, 3)
    plt.imshow(ret)
    plt.savefig('tmp.png')
    plt.close()

    logging.info('Finished')

    return ret, shadow


def intensity_change(x):
    global cur_rgba, cur_shadow, cur_gamma, cur_intensity

    cur_intensity = x
    ret, shadow = shadow_composite(cur_rgba, cur_shadow, cur_intensity, cur_gamma)
    return ret, shadow


def gamma_change(x):
    global cur_rgba, cur_shadow, cur_gamma, cur_intensity

    cur_gamma = x
    ret, shadow = shadow_composite(cur_rgba, cur_shadow, cur_intensity, cur_gamma)
    return ret, shadow


ibl_h = 128
ibl_w = ibl_h * 2

with gr.Blocks() as demo:
    with gr.Row():
        mask_input = gr.Image(shape=(256, 256), image_mode="RGBA", label="Mask")
        ibl_input = gr.Sketchpad(shape=(ibl_w, ibl_h), image_mode="L", label="IBL", tool='sketch', invert_colors=True)
        output = gr.Image(shape=(256, 256), height=256, width=256, image_mode="RGB", label="Output")
        shadow_output = gr.Image(shape=(256, 256), height=256, width=256, image_mode="L", label="Shadow Layer")

    with gr.Row():
        intensity_slider = gr.Slider(0.0, 1.0, value=DEFAULT_INTENSITY, step=0.1, label="Intensity", info="Choose between 0.0 and 1.0")
        gamma_slider = gr.Slider(1.0, 4.0, value=DEFAULT_GAMMA, step=0.1, label="Gamma", info="Gamma correction for shadow")
        render_btn = gr.Button(label="Render")

    render_btn.click(render_btn_fn, inputs=[mask_input, ibl_input], outputs=[output, shadow_output])
    intensity_slider.release(intensity_change, inputs=[intensity_slider], outputs=[output, shadow_output])
    gamma_slider.release(gamma_change, inputs=[gamma_slider], outputs=[output, shadow_output])

logging.info('Finished')


demo.launch()
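For reference, a minimal standalone sketch of the compositing rule that shadow_composite implements; the toy mask and the random shadow map below are illustrative stand-ins, not part of the app:

import numpy as np

# Toy RGBA cutout: white RGB (as padding_mask writes) with one opaque square.
rgba = np.zeros((256, 256, 4))
rgba[..., :3] = 1.0
rgba[100:150, 100:150, 3] = 1.0

shadow = np.random.rand(256, 256)[..., None]  # stand-in for the network's shadow map

intensity, gamma = 0.9, 2.0  # the app's DEFAULT_INTENSITY / DEFAULT_GAMMA

# Darken the background by 1 - shadow**gamma * intensity, then paste the
# object back in through its alpha mask: the same math as shadow_composite.
new_shadow = 1.0 - shadow ** gamma * intensity
out = rgba[..., :3] * rgba[..., 3:] + (1.0 - rgba[..., 3:]) * new_shadow
print(out.shape)  # (256, 256, 3)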
configs/GSSN.yaml
ADDED
@@ -0,0 +1,57 @@
exp_name: GSSN_ALL_Channels_2e_5

# model related
model:
  name: 'GSSN'
  # backbone: 'vanilla'
  backbone: 'SSN_v1'
  in_channels: 6
  out_channels: 1
  resnet: True

  mid_act: "gelu"
  out_act: "gelu"

  optimizer: 'Adam'
  weight_decay: 4e-5
  beta1: 0.9

  focal: False

# dataset
dataset:
  name: 'GSSN_Dataset'
  hdf5_file: 'Dataset1/more_general_scenes/train/ALL_SIZE_WALL/dataset.hdf5'
  type: 'BC_Boundary'
  rech_grad: True


test_dataset:
  name: 'GSSN_Testing_Dataset'
  hdf5_file: 'Dataset/standalone_test_split/test/ALL_SIZE_MORE/dataset.hdf5'
  type: 'BC_Boundary'
  ignore_shading: True
  rech_grad: True


# training related
hyper_params:
  lr: 2e-5
  epochs: 100000
  workers: 52
  batch_size: 52
  save_epoch: 10

  eval_batch: 10
  eval_save: False

  # visualization
  vis_iter: 100 # iteration for visualization
  save_iter: 100
  n_cols: 5
  gpus:
    - 0
  default_folder: 'weights'
  resume: False
  # resume: True
  weight_file: 'latest'
configs/SSN.yaml
ADDED
@@ -0,0 +1,51 @@
exp_name: SSN

# model related
model:
  name: 'SSN'
  in_channels: 1
  out_channels: 1
  resnet: False

  mid_act: "relu"
  out_act: 'relu'

  optimizer: 'Adam'
  weight_decay: 4e-5
  beta1: 0.9


# dataset
dataset:
  name: 'SSN_Dataset'
  hdf5_file: 'Dataset/SSN/ssn_shadow/shadow_base/ssn_base.hdf5'
  shadow_per_epoch: 10


# test_dataset:
#   name: 'SSN_Dataset'
#   hdf5_file: 'Dataset/SSN/ssn_shadow/shadow_base/ssn_base.hdf5'


# training related
hyper_params:
  lr: 1e-3
  epochs: 100000
  workers: 40
  batch_size: 10
  save_epoch: 10

  eval_batch: 10
  eval_save: False

  # visualization
  vis_iter: 100 # iteration for visualization
  save_iter: 100
  n_cols: 5
  gpus:
    - 0
    - 1

  default_folder: 'weights'
  resume: False
  weight_file: 'latest'
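One PyYAML behavior worth knowing when consuming these configs (a property of yaml.safe_load, not something this commit documents): bare scientific notation such as lr: 1e-3 or weight_decay: 4e-5 lacks a decimal point, so the YAML 1.1 resolver loads it as a string rather than a float, and downstream code has to cast it explicitly. A quick check:

import yaml

cfg = yaml.safe_load("lr: 1e-3\nbatch_size: 10")
print(type(cfg['lr']), cfg['lr'])   # <class 'str'> 1e-3  (no '.', so not resolved as a float)
print(float(cfg['lr']))             # 0.001 after an explicit cast
print(type(cfg['batch_size']))      # <class 'int'>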
model_utils.py
ADDED
@@ -0,0 +1,53 @@
import os
import yaml
import logging

import torch


def parse_configs(config: str):
    """ Parse the config file and return a dictionary of configs

    :param config: path to the config file
    :returns: a dict of parsed configs, or an empty dict on YAML errors
    """
    if not os.path.exists(config):
        logging.error('Cannot find the config file: {}'.format(config))
        exit()

    with open(config, 'r') as stream:
        try:
            configs = yaml.safe_load(stream)
            return configs

        except yaml.YAMLError as exc:
            logging.error(exc)
            return {}


def load_model(config: str, weight: str, model_def, device):
    """ Load the model from the config file and the weight file

    :param config: path to the config file
    :param weight: path to the weight file
    :param model_def: model class definition
    :param device: pytorch device
    :returns: the model with weights loaded and moved to device
    """
    assert os.path.exists(weight), 'Cannot find the weight file: {}'.format(weight)
    assert os.path.exists(config), 'Cannot find the config file: {}'.format(config)

    opt = parse_configs(config)
    model = model_def(opt)
    cp = torch.load(weight)

    models = model.get_models()
    for k, m in models.items():
        m.load_state_dict(cp[k])
        m.to(device)

    model.set_models(models)
    return model
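For context, app.py in this same commit exercises this helper with the config and checkpoint shipped here; note that load_model relies on the model class exposing get_models() and set_models(), as used above:

import torch

import model_utils
from models.SSN import SSN

device = torch.device('cuda:0')
model = model_utils.load_model('configs/SSN.yaml', 'weights/SSN/0000001760.pt', SSN, device)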
models/Attention.ipynb
ADDED
@@ -0,0 +1,509 @@
# In[30]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial
import math
import torch as th


# In[31]:
class GroupNorm32(nn.GroupNorm):
    def forward(self, x):
        return super().forward(x.float()).type(x.dtype)

def normalization(channels):
    """
    Make a standard normalization layer.
    :param channels: number of input channels.
    :return: an nn.Module for normalization.
    """
    return GroupNorm32(32, channels)


def conv_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D convolution module.
    """
    if dims == 1:
        return nn.Conv1d(*args, **kwargs)
    elif dims == 2:
        return nn.Conv2d(*args, **kwargs)
    elif dims == 3:
        return nn.Conv3d(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")


# In[32]:
class QKVAttentionLegacy(nn.Module):
    """
    A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping
    """

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv):
        """
        Apply QKV attention.
        :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
        :return: an [N x (H * C) x T] tensor after attention.
        """
        bs, width, length = qkv.shape
        assert width % (3 * self.n_heads) == 0
        ch = width // (3 * self.n_heads)
        q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
        scale = 1 / math.sqrt(math.sqrt(ch))
        weight = th.einsum(
            "bct,bcs->bts", q * scale, k * scale
        )  # More stable with f16 than dividing afterwards
        weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
        a = th.einsum("bts,bcs->bct", weight, v)
        return a.reshape(bs, -1, length)

    @staticmethod
    def count_flops(model, _x, y):
        return count_flops_attn(model, _x, y)


# In[33]:
def zero_module(module):
    """
    Zero out the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().zero_()
    return module


# In[37]:
class AttentionBlock(nn.Module):
    """
    An attention block that allows spatial positions to attend to each other.
    Originally ported from here, but adapted to the N-d case.
    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
    """

    def __init__(
        self,
        channels,
        num_heads=1,
        num_head_channels=-1,
        use_new_attention_order=False,
    ):
        super().__init__()
        self.channels = channels
        if num_head_channels == -1:
            self.num_heads = num_heads
        else:
            assert (
                channels % num_head_channels == 0
            ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
            self.num_heads = channels // num_head_channels
        self.norm = normalization(channels)
        self.qkv = conv_nd(1, channels, channels * 3, 1)
        if use_new_attention_order:
            # split qkv before split heads
            self.attention = QKVAttention(self.num_heads)
        else:
            # split heads before split qkv
            self.attention = QKVAttentionLegacy(self.num_heads)

        self.proj_out = zero_module(conv_nd(1, channels, channels, 1))

    def forward(self, x):

        import pdb; pdb.set_trace()

        b, c, *spatial = x.shape
        x = x.reshape(b, c, -1)
        qkv = self.qkv(self.norm(x))
        h = self.attention(qkv)
        h = self.proj_out(h)
        return (x + h).reshape(b, c, *spatial)


# In[38]:
test_input = torch.randn(5, 32, 128, 128)

model = AttentionBlock(32, 1)

y = model(test_input)

# Stored output: an interactive ipdb session stepping through AttentionBlock.forward,
# observing x: [5, 32, 16384]; self.qkv: Conv1d(32, 96, kernel_size=(1,), stride=(1,));
# qkv: [5, 96, 16384]; h: [5, 32, 16384]; self.proj_out: Conv1d(32, 32, kernel_size=(1,), stride=(1,));
# the session ends with BdbQuit.


# In[36]:
y.shape
# -> torch.Size([5, 32, 128, 128])

# (kernel: Python 3 (ipykernel), Python 3.8.12)
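As a standalone shape check of the head-splitting layout that QKVAttentionLegacy uses above; the batch, head, channel, and length sizes are arbitrary toy values:

import math
import torch

bs, n_heads, ch, length = 2, 4, 8, 16
qkv = torch.randn(bs, n_heads * 3 * ch, length)  # [N x (H * 3 * C) x T]

# Split heads before splitting q/k/v, matching the legacy ordering.
q, k, v = qkv.reshape(bs * n_heads, ch * 3, length).split(ch, dim=1)
scale = 1 / math.sqrt(math.sqrt(ch))
weight = torch.einsum("bct,bcs->bts", q * scale, k * scale).softmax(dim=-1)
a = torch.einsum("bts,bcs->bct", weight, v).reshape(bs, -1, length)

assert a.shape == (bs, n_heads * ch, length)     # [N x (H * C) x T]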
models/Attention_SSN.py
ADDED
@@ -0,0 +1,218 @@
from abc import abstractmethod
from functools import partial
from typing import Iterable
import math

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

from .SSN import Conv, Conv2DMod, Decoder, Up
from .attention import AttentionBlock
from .blocks import ResBlock, Res_Type, get_activation


class Attention_Encoder(nn.Module):
    def __init__(self, in_channels=3, mid_act='gelu', dropout=0.0, num_heads=8, resnet=True):
        super(Attention_Encoder, self).__init__()

        self.in_conv = Conv(in_channels, 32 - in_channels, stride=1, activation=mid_act, resnet=resnet)
        self.down_32_64 = Conv(32, 64, stride=2, activation=mid_act, resnet=resnet)
        self.down_64_64_1 = Conv(64, 64, activation=mid_act, resnet=resnet)

        self.down_64_128 = Conv(64, 128, stride=2, activation=mid_act, resnet=resnet)
        self.down_128_128_1 = Conv(128, 128, activation=mid_act, resnet=resnet)

        self.down_128_256 = Conv(128, 256, stride=2, activation=mid_act, resnet=resnet)
        self.down_256_256_1 = Conv(256, 256, activation=mid_act, resnet=resnet)
        self.down_256_256_1_attn = AttentionBlock(256, num_heads)

        self.down_256_512 = Conv(256, 512, stride=2, activation=mid_act, resnet=resnet)
        self.down_512_512_1 = Conv(512, 512, activation=mid_act, resnet=resnet)
        self.down_512_512_1_attn = AttentionBlock(512, num_heads)

        self.down_512_512_2 = Conv(512, 512, activation=mid_act, resnet=resnet)
        self.down_512_512_2_attn = AttentionBlock(512, num_heads)

        self.down_512_512_3 = Conv(512, 512, activation=mid_act, resnet=resnet)
        self.down_512_512_3_attn = AttentionBlock(512, num_heads)

    def forward(self, x):
        x1 = self.in_conv(x)  # 32 x 256 x 256
        x1 = torch.cat((x, x1), dim=1)

        x2 = self.down_32_64(x1)
        x3 = self.down_64_64_1(x2)

        x4 = self.down_64_128(x3)
        x5 = self.down_128_128_1(x4)

        x6 = self.down_128_256(x5)
        x7 = self.down_256_256_1(x6)
        x7 = self.down_256_256_1_attn(x7)

        x8 = self.down_256_512(x7)
        x9 = self.down_512_512_1(x8)
        x9 = self.down_512_512_1_attn(x9)

        x10 = self.down_512_512_2(x9)
        x10 = self.down_512_512_2_attn(x10)

        x11 = self.down_512_512_3(x10)
        x11 = self.down_512_512_3_attn(x11)

        return x11, x10, x9, x8, x7, x6, x5, x4, x3, x2, x1


class Attention_Decoder(nn.Module):
    def __init__(self, out_channels=3, mid_act='gelu', out_act='sigmoid', resnet=True, num_heads=8):

        super(Attention_Decoder, self).__init__()

        input_channel = 512
        fea_dim = 100

        self.to_style1 = nn.Linear(in_features=fea_dim, out_features=input_channel)

        self.up_16_16_1 = Conv(input_channel, 256, activation=mid_act, style=True, resnet=resnet)
        self.up_16_16_1_attn = AttentionBlock(256, num_heads=num_heads)

        self.up_16_16_2 = Conv(768, 512, activation=mid_act, resnet=resnet)
        self.up_16_16_2_attn = AttentionBlock(512, num_heads=num_heads)

        self.up_16_16_3 = Conv(1024, 512, activation=mid_act, resnet=resnet)
        self.up_16_16_3_attn = AttentionBlock(512, num_heads=num_heads)

        self.up_16_32 = Up(1024, 256, activation=mid_act, resnet=resnet)
        self.to_style2 = nn.Linear(in_features=fea_dim, out_features=512)
        self.up_32_32_1 = Conv(512, 256, activation=mid_act, style=True, resnet=resnet)
        self.up_32_32_1_attn = AttentionBlock(256, num_heads=num_heads)

        self.up_32_64 = Up(512, 128, activation=mid_act, resnet=resnet)
        self.to_style3 = nn.Linear(in_features=fea_dim, out_features=256)
        self.up_64_64_1 = Conv(256, 128, activation=mid_act, style=True, resnet=resnet)

        self.up_64_128 = Up(256, 64, activation=mid_act, resnet=resnet)
        self.to_style4 = nn.Linear(in_features=fea_dim, out_features=128)
        self.up_128_128_1 = Conv(128, 64, activation=mid_act, style=True, resnet=resnet)

        self.up_128_256 = Up(128, 32, activation=mid_act, resnet=resnet)
        self.out_conv = Conv(64, out_channels, activation=out_act)
        self.out_act = get_activation(out_act)

    def forward(self, x, style):
        x11, x10, x9, x8, x7, x6, x5, x4, x3, x2, x1 = x

        style1 = self.to_style1(style)
        y = self.up_16_16_1(x11, style1)  # 256 x 16 x 16
        y = self.up_16_16_1_attn(y)

        y = torch.cat((x10, y), dim=1)  # 768 x 16 x 16
        y = self.up_16_16_2(y, y)  # 512 x 16 x 16
        y = self.up_16_16_2_attn(y)

        y = torch.cat((x9, y), dim=1)  # 1024 x 16 x 16
        y = self.up_16_16_3(y, y)  # 512 x 16 x 16
        y = self.up_16_16_3_attn(y)

        y = torch.cat((x8, y), dim=1)  # 1024 x 16 x 16
        y = self.up_16_32(y, y)  # 256 x 32 x 32

        y = torch.cat((x7, y), dim=1)
        style2 = self.to_style2(style)
        y = self.up_32_32_1(y, style2)  # 256 x 32 x 32
        y = self.up_32_32_1_attn(y)

        y = torch.cat((x6, y), dim=1)
        y = self.up_32_64(y, y)

        y = torch.cat((x5, y), dim=1)
        style3 = self.to_style3(style)

        y = self.up_64_64_1(y, style3)  # 128 x 64 x 64

        y = torch.cat((x4, y), dim=1)
        y = self.up_64_128(y, y)

        y = torch.cat((x3, y), dim=1)
        style4 = self.to_style4(style)
        y = self.up_128_128_1(y, style4)  # 64 x 128 x 128

        y = torch.cat((x2, y), dim=1)
        y = self.up_128_256(y, y)  # 32 x 256 x 256

        y = torch.cat((x1, y), dim=1)
        y = self.out_conv(y, y)  # 3 x 256 x 256
        y = self.out_act(y)
        return y


class Attention_SSN(nn.Module):
    def __init__(self, in_channels, out_channels, num_heads=8, resnet=True, mid_act='gelu', out_act='gelu'):
        super(Attention_SSN, self).__init__()
        # pass options by keyword: the encoder also takes a dropout argument,
        # so positional num_heads/resnet would land in the wrong slots
        self.encoder = Attention_Encoder(in_channels, mid_act=mid_act, num_heads=num_heads, resnet=resnet)
        self.decoder = Attention_Decoder(out_channels, mid_act=mid_act, out_act=out_act, resnet=resnet, num_heads=num_heads)

    def forward(self, x, softness):
        latent = self.encoder(x)
        pred = self.decoder(latent, softness)

        return pred


def get_model_size(model):
    param_size = 0
    for param in model.parameters():
        param_size += param.nelement() * param.element_size()

    buffer_size = 0
    for buffer in model.buffers():
        buffer_size += buffer.nelement() * buffer.element_size()

    size_all_mb = (param_size + buffer_size) / 1024 ** 2
    print('model size: {:.3f}MB'.format(size_all_mb))
    return size_all_mb


if __name__ == '__main__':
    model = AttentionBlock(256, num_heads=8)
    x = torch.randn(5, 256, 64, 64)

    y = model(x)
    print('{}, {}'.format(x.shape, y.shape))

    # ------------------------------------------------------------------ #
    in_channels = 3
    out_channels = 1
    num_heads = 8
    resnet = True
    mid_act = 'gelu'
    out_act = 'gelu'

    model = Attention_SSN(in_channels=in_channels,
                          out_channels=out_channels,
                          num_heads=num_heads,
                          resnet=resnet,
                          mid_act=mid_act,
                          out_act=out_act)

    x = torch.randn(5, 3, 256, 256)
    softness = torch.randn(5, 100)

    y = model(x, softness)

    print('x: {}, y: {}'.format(x.shape, y.shape))

    get_model_size(model)
    # ------------------------------------------------------------------ #
models/Attention_Unet.py
ADDED
@@ -0,0 +1,165 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
import numpy as np

# NOTE: Conv, Conv2DMod, Decoder and Up are defined in SSN_v1 in this commit,
# not in SSN; the original `from .SSN import ...` would fail to import.
from .SSN_v1 import Conv, Conv2DMod, Decoder, Up
from .attention import AttentionBlock
from .blocks import ResBlock, Res_Type, get_activation

class Attention_Encoder(nn.Module):
    def __init__(self, in_channels=3, mid_act='gelu', dropout=0.0, num_heads=8, resnet=True):
        super(Attention_Encoder, self).__init__()

        self.in_conv = Conv(in_channels, 32-in_channels, stride=1, activation=mid_act, resnet=resnet)
        self.down_32_64 = Conv(32, 64, stride=2, activation=mid_act, resnet=resnet)
        self.down_64_64_1 = Conv(64, 64, activation=mid_act, resnet=resnet)

        self.down_64_128 = Conv(64, 128, stride=2, activation=mid_act, resnet=resnet)
        self.down_128_128_1 = Conv(128, 128, activation=mid_act, resnet=resnet)

        self.down_128_256 = Conv(128, 256, stride=2, activation=mid_act, resnet=resnet)
        self.down_256_256_1 = Conv(256, 256, activation=mid_act, resnet=resnet)
        self.down_256_256_1_attn = AttentionBlock(256, num_heads)

        self.down_256_512 = Conv(256, 512, stride=2, activation=mid_act, resnet=resnet)
        self.down_512_512_1 = Conv(512, 512, activation=mid_act, resnet=resnet)
        self.down_512_512_1_attn = AttentionBlock(512, num_heads)

        self.down_512_512_2 = Conv(512, 512, activation=mid_act, resnet=resnet)
        self.down_512_512_2_attn = AttentionBlock(512, num_heads)

        self.down_512_512_3 = Conv(512, 512, activation=mid_act, resnet=resnet)
        self.down_512_512_3_attn = AttentionBlock(512, num_heads)


    def forward(self, x):
        x1 = self.in_conv(x)            # 32 x 256 x 256
        x1 = torch.cat((x, x1), dim=1)

        x2 = self.down_32_64(x1)
        x3 = self.down_64_64_1(x2)

        x4 = self.down_64_128(x3)
        x5 = self.down_128_128_1(x4)

        x6 = self.down_128_256(x5)
        x7 = self.down_256_256_1(x6)
        x7 = self.down_256_256_1_attn(x7)

        x8 = self.down_256_512(x7)
        x9 = self.down_512_512_1(x8)
        x9 = self.down_512_512_1_attn(x9)

        x10 = self.down_512_512_2(x9)
        x10 = self.down_512_512_2_attn(x10)

        x11 = self.down_512_512_3(x10)
        x11 = self.down_512_512_3_attn(x11)

        return x11, x10, x9, x8, x7, x6, x5, x4, x3, x2, x1


class Attention_Decoder(nn.Module):
    def __init__(self, out_channels=3, mid_act='gelu', out_act='sigmoid', resnet=True, num_heads=8):
        super(Attention_Decoder, self).__init__()

        input_channel = 512
        fea_dim = 100

        self.to_style1 = nn.Linear(in_features=fea_dim, out_features=input_channel)

        self.up_16_16_1 = Conv(input_channel, 256, activation=mid_act, style=False, resnet=resnet)
        self.up_16_16_1_attn = AttentionBlock(256, num_heads=num_heads)

        self.up_16_16_2 = Conv(768, 512, activation=mid_act, resnet=resnet)
        self.up_16_16_2_attn = AttentionBlock(512, num_heads=num_heads)

        self.up_16_16_3 = Conv(1024, 512, activation=mid_act, resnet=resnet)
        self.up_16_16_3_attn = AttentionBlock(512, num_heads=num_heads)

        self.up_16_32 = Up(1024, 256, activation=mid_act, resnet=resnet)
        self.to_style2 = nn.Linear(in_features=fea_dim, out_features=512)
        self.up_32_32_1 = Conv(512, 256, activation=mid_act, style=False, resnet=resnet)
        self.up_32_32_1_attn = AttentionBlock(256, num_heads=num_heads)

        self.up_32_64 = Up(512, 128, activation=mid_act, resnet=resnet)
        self.to_style3 = nn.Linear(in_features=fea_dim, out_features=256)
        self.up_64_64_1 = Conv(256, 128, activation=mid_act, style=False, resnet=resnet)

        self.up_64_128 = Up(256, 64, activation=mid_act, resnet=resnet)
        self.to_style4 = nn.Linear(in_features=fea_dim, out_features=128)
        self.up_128_128_1 = Conv(128, 64, activation=mid_act, style=False, resnet=resnet)

        self.up_128_256 = Up(128, 32, activation=mid_act, resnet=resnet)
        self.out_conv = Conv(64, out_channels, activation=out_act)
        self.out_act = get_activation(out_act)


    def forward(self, x):
        x11, x10, x9, x8, x7, x6, x5, x4, x3, x2, x1 = x

        # All Conv/Up layers below take a single input; the original passed
        # (y, y), which Up.forward in SSN_v1 rejects.
        y = self.up_16_16_1(x11)        # 256 x 16 x 16
        y = self.up_16_16_1_attn(y)

        y = torch.cat((x10, y), dim=1)  # 768 x 16 x 16
        y = self.up_16_16_2(y)          # 512 x 16 x 16
        y = self.up_16_16_2_attn(y)

        y = torch.cat((x9, y), dim=1)   # 1024 x 16 x 16
        y = self.up_16_16_3(y)          # 512 x 16 x 16
        y = self.up_16_16_3_attn(y)

        y = torch.cat((x8, y), dim=1)   # 1024 x 16 x 16
        y = self.up_16_32(y)            # 256 x 32 x 32

        y = torch.cat((x7, y), dim=1)
        y = self.up_32_32_1(y)          # 256 x 32 x 32
        y = self.up_32_32_1_attn(y)

        y = torch.cat((x6, y), dim=1)
        y = self.up_32_64(y)

        y = torch.cat((x5, y), dim=1)
        y = self.up_64_64_1(y)          # 128 x 64 x 64

        y = torch.cat((x4, y), dim=1)
        y = self.up_64_128(y)

        y = torch.cat((x3, y), dim=1)
        y = self.up_128_128_1(y)        # 64 x 128 x 128

        y = torch.cat((x2, y), dim=1)
        y = self.up_128_256(y)          # 32 x 256 x 256

        y = torch.cat((x1, y), dim=1)
        y = self.out_conv(y)            # 3 x 256 x 256
        y = self.out_act(y)
        return y


class Attention_Unet(nn.Module):
    def __init__(self, in_channels, out_channels, num_heads=8, resnet=True, mid_act='gelu', out_act='gelu'):
        super(Attention_Unet, self).__init__()
        # keyword arguments: the original positional call passed num_heads into
        # the encoder's `dropout` slot and resnet into its `num_heads` slot
        self.encoder = Attention_Encoder(in_channels, mid_act=mid_act, num_heads=num_heads, resnet=resnet)
        self.decoder = Attention_Decoder(out_channels, mid_act=mid_act, out_act=out_act, resnet=resnet)


    def forward(self, x):
        latent = self.encoder(x)
        pred = self.decoder(latent)
        return pred


if __name__ == '__main__':
    test_input = torch.randn(5, 1, 256, 256)

    # The original test constructed SSN_v1 (undefined in this module) and
    # passed a style vector that Attention_Unet.forward does not accept.
    model = Attention_Unet(1, 1, mid_act='gelu', out_act='gelu', resnet=True)
    test_out = model(test_input)

    print('Output shape: ', test_out.shape)
models/GSSN.py
ADDED
@@ -0,0 +1,176 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import utils
from collections import OrderedDict
import numpy as np
import matplotlib.cm as cm
import matplotlib as mpl

from .abs_model import abs_model
from .blocks import *
from .SSN import SSN
from .SSN_v1 import SSN_v1
from .Loss.Loss import norm_loss


class GSSN(abs_model):
    def __init__(self, opt):
        mid_act = opt['model']['mid_act']
        out_act = opt['model']['out_act']
        in_channels = opt['model']['in_channels']
        out_channels = opt['model']['out_channels']
        resnet = opt['model']['resnet']
        self.ncols = opt['hyper_params']['n_cols']
        self.focal = opt['model']['focal']

        if 'backbone' not in opt['model'].keys():
            self.model = SSN(in_channels=in_channels,
                             out_channels=out_channels,
                             mid_act=mid_act,
                             out_act=out_act,
                             resnet=resnet)
        else:
            backbone = opt['model']['backbone']
            if backbone == 'vanilla':
                self.model = SSN(in_channels=in_channels,
                                 out_channels=out_channels,
                                 mid_act=mid_act,
                                 out_act=out_act,
                                 resnet=resnet)
            elif backbone == 'SSN_v1':
                self.model = SSN_v1(in_channels=in_channels,
                                    out_channels=out_channels,
                                    mid_act=mid_act,
                                    out_act=out_act,
                                    resnet=resnet)
            else:
                raise NotImplementedError('{} has not been implemented yet'.format(backbone))

        self.optimizer = get_optimizer(opt, self.model)
        self.visualization = {}

        self.norm_loss = norm_loss()

        # inference related
        BINs = 100
        MAX_RAD = 20
        self.size_interval = MAX_RAD / BINs
        self.soft_distribution = [[np.exp(-0.2 * (i - j) ** 2) for i in np.arange(BINs)] for j in np.arange(BINs)]

    def setup_input(self, x):
        return x


    def forward(self, x):
        x, softness = x
        return self.model(x, softness)


    def compute_loss(self, y, pred):
        b = y.shape[0]

        total_loss = self.norm_loss.loss(y, pred)

        if self.focal:
            total_loss = torch.pow(total_loss, 3)

        return total_loss


    def supervise(self, input_x, y, is_training: bool) -> float:
        optimizer = self.optimizer
        model = self.model

        x, softness = input_x['x'], input_x['softness']

        optimizer.zero_grad()
        pred = model(x, softness)
        loss = self.compute_loss(y, pred)

        if is_training:
            loss.backward()
            optimizer.step()

        xc = x.shape[1]
        for i in range(xc):
            self.visualization['x{}'.format(i)] = x[:, i:i+1].detach()

        self.visualization['y'] = y.detach()
        self.visualization['pred'] = pred.detach()

        return loss.item()


    def get_visualize(self) -> OrderedDict:
        """ Convert to visualization numpy array
        """
        nrows = self.ncols
        visualizations = self.visualization
        ret_vis = OrderedDict()

        for k, v in visualizations.items():
            batch = v.shape[0]
            n = min(nrows, batch)

            plot_v = v[:n]
            ret_vis[k] = np.clip(utils.make_grid(plot_v.cpu(), nrow=nrows).numpy().transpose(1,2,0), 0.0, 1.0)
            ret_vis[k] = self.plasma(ret_vis[k])

        return ret_vis


    def get_logs(self):
        pass


    def inference(self, x):
        x, l, device = x['x'], x['l'], x['device']

        x = torch.from_numpy(x.transpose((2,0,1))).unsqueeze(dim=0).to(device)
        l = torch.from_numpy(np.array(self.soft_distribution[int(l/self.size_interval)]).astype(np.float32)).unsqueeze(dim=0).to(device)

        pred = self.forward((x, l))
        pred = pred[0].detach().cpu().numpy().transpose((1,2,0))
        return pred


    def batch_inference(self, x):
        x, l = x['x'], x['softness']
        pred = self.forward((x, l))
        return pred


    """ Getter & Setter
    """
    def get_models(self) -> dict:
        return {'model': self.model}


    def get_optimizers(self) -> dict:
        return {'optimizer': self.optimizer}


    def set_models(self, models: dict):
        # input test
        if 'model' not in models.keys():
            raise ValueError('{} not in self.model'.format('model'))

        self.model = models['model']


    def set_optimizers(self, optimizer: dict):
        self.optimizer = optimizer['optimizer']


    ####################
    # Personal Methods #
    ####################
    def plasma(self, x):
        norm = mpl.colors.Normalize(vmin=0.0, vmax=1)
        mapper = cm.ScalarMappable(norm=norm, cmap='plasma')
        bimg = mapper.to_rgba(x[:,:,0])[:,:,:3]

        return bimg
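A note on the softness encoding used by GSSN.inference above: the scalar softness l is quantized into one of 100 bins and replaced by a Gaussian bump over the bins from the precomputed soft_distribution table. A minimal standalone sketch of that encoding (the value of l is illustrative only):

import numpy as np

BINS = 100
MAX_RAD = 20
size_interval = MAX_RAD / BINS   # each bin covers 0.2 units of softness

# Row j is a Gaussian bump centred on bin j (same table GSSN precomputes).
soft_distribution = [[np.exp(-0.2 * (i - j) ** 2) for i in np.arange(BINS)]
                     for j in np.arange(BINS)]

l = 4.5  # hypothetical softness value in [0, MAX_RAD)
soft_label = np.array(soft_distribution[int(l / size_interval)], dtype=np.float32)
print(soft_label.shape, soft_label.argmax())  # (100,) 22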
models/Loss/Loss.py
ADDED
@@ -0,0 +1,271 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
from torch.autograd import Variable

import numpy as np
import cv2

# from vgg19_loss import VGG19Loss
# import pytorch_ssim

from .vgg19_loss import VGG19Loss
from . import pytorch_ssim
from abc import ABC, abstractmethod
from collections import OrderedDict

class abs_loss(ABC):
    def loss(self, gt_img, pred_img):
        pass


class norm_loss(abs_loss):
    def __init__(self, norm=1):
        self.norm = norm


    def loss(self, gt_img, pred_img):
        """ M * (I-I') """
        b, c, h, w = gt_img.shape
        return torch.norm(gt_img-pred_img, self.norm)/(h * w * b)


class ssim_loss(abs_loss):
    def __init__(self, window_size=11, channel=1):
        """ Let's try mean ssim!
        """
        self.channel = channel
        self.window_size = window_size
        self.window = self.create_mean_window(window_size, channel)


    def loss(self, gt_img, pred_img):
        b, c, h, w = gt_img.shape
        if c != self.channel:
            self.channel = c
            self.window = self.create_mean_window(self.window_size, self.channel)

        self.window = self.window.to(gt_img).type_as(gt_img)
        l = 1.0 - self.ssim_compute(gt_img, pred_img)
        return l


    def create_mean_window(self, window_size, channel):
        window = Variable(torch.ones(channel, 1, window_size, window_size).float())
        window = window/(window_size * window_size)
        return window


    def ssim_compute(self, gt_img, pred_img):
        window = self.window
        window_size = self.window_size
        channel = self.channel

        mu1 = F.conv2d(gt_img, window, padding = window_size//2, groups = channel)
        mu2 = F.conv2d(pred_img, window, padding = window_size//2, groups = channel)

        mu1_sq = mu1.pow(2)
        mu2_sq = mu2.pow(2)
        mu1_mu2 = mu1*mu2

        sigma1_sq = F.conv2d(gt_img*gt_img, window, padding = window_size//2, groups = channel) - mu1_sq
        sigma2_sq = F.conv2d(pred_img*pred_img, window, padding = window_size//2, groups = channel) - mu2_sq
        sigma12 = F.conv2d(gt_img*pred_img, window, padding = window_size//2, groups = channel) - mu1_mu2

        C1 = 0.01**2
        C2 = 0.03**2

        ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))

        return ssim_map.mean()


class hierarchical_ssim_loss(abs_loss):
    def __init__(self, patch_list: list):
        self.ssim_loss_list = [pytorch_ssim.SSIM(window_size=ws) for ws in patch_list]


    def loss(self, gt_img, pred_img):
        b, c, h, w = gt_img.shape
        total_loss = 0.0
        for loss_func in self.ssim_loss_list:
            total_loss += (1.0-loss_func(gt_img, pred_img))

        return total_loss/b


class vgg_loss(abs_loss):
    def __init__(self):
        self.vgg19_ = VGG19Loss()


    def loss(self, gt_img, pred_img):
        b, c, h, w = gt_img.shape
        v = self.vgg19_(gt_img, pred_img, pred_img.device)
        return v/b


class grad_loss(abs_loss):
    def __init__(self, k=4):
        self.k = k  # fixed: was hard-coded to 4, ignoring the argument

    def loss(self, disp_img, rgb_img=None):
        """ Note, gradient loss should be weighted by an edge-aware weight
        """
        b, c, h, w = disp_img.shape

        grad_loss = 0.0
        for i in range(self.k):
            div_factor = 2 ** i
            cur_transform = T.Resize([h // div_factor, ])
            cur_disp = cur_transform(disp_img)

            cur_disp_dx, cur_disp_dy = self.img_grad(cur_disp)

            if rgb_img is not None:
                cur_rgb = cur_transform(rgb_img)
                cur_rgb_dx, cur_rgb_dy = self.img_grad(cur_rgb)

                # edge-aware weights: down-weight disparity gradients at image edges
                cur_rgb_dx = torch.exp(-torch.mean(torch.abs(cur_rgb_dx), dim=1, keepdim=True))
                cur_rgb_dy = torch.exp(-torch.mean(torch.abs(cur_rgb_dy), dim=1, keepdim=True))
                grad_loss += (torch.sum(torch.abs(cur_disp_dx) * cur_rgb_dx) + torch.sum(torch.abs(cur_disp_dy) * cur_rgb_dy)) / (h * w * self.k)
            else:
                grad_loss += (torch.sum(torch.abs(cur_disp_dx)) + torch.sum(torch.abs(cur_disp_dy))) / (h * w * self.k)

        return grad_loss/b


    def gloss(self, gt, pred):
        """ Loss on the gradient domain
        """
        b, c, h, w = gt.shape
        gt_dx, gt_dy = self.img_grad(gt)
        pred_dx, pred_dy = self.img_grad(pred)

        loss = (gt_dx-pred_dx) ** 2 + (gt_dy - pred_dy) ** 2
        return loss.sum()/(b * h * w)


    def laploss(self, pred):
        b, c, h, w = pred.shape
        lap = self.img_laplacian(pred)

        return torch.abs(lap).sum()/(b * h * w)


    def img_laplacian(self, img):
        b, c, h, w = img.shape
        laplacian = torch.tensor([[1, 4, 1], [4, -20, 4], [1, 4, 1]])

        laplacian_kernel = laplacian.float().unsqueeze(0).expand(1, c, 3, 3).to(img)

        lap = F.conv2d(img, laplacian_kernel, padding=1, stride=1)
        return lap


    def img_grad(self, img):
        """ Compute image gradient by Sobel filtering
        img: B x C x H x W
        """
        b, c, h, w = img.shape
        ysobel = torch.tensor([[1, 2, 1], [0, 0, 0], [-1, -2, -1]])
        xsobel = ysobel.transpose(0,1)

        xsobel_kernel = xsobel.float().unsqueeze(0).expand(1, c, 3, 3).to(img)
        ysobel_kernel = ysobel.float().unsqueeze(0).expand(1, c, 3, 3).to(img)
        dx = F.conv2d(img, xsobel_kernel, padding=1, stride=1)
        dy = F.conv2d(img, ysobel_kernel, padding=1, stride=1)

        return dx, dy


class sharp_loss(abs_loss):
    """ Sharpness term
    1. laplacian
    2. image contrast
    3. image variance
    """
    def __init__(self, window_size=11, channel=1):
        self.window_size = window_size
        self.channel = channel
        self.window = self.create_mean_window(window_size, self.channel)


    def loss(self, gt_img, pred_img):
        b, c, h, w = gt_img.shape

        if c != self.channel:
            self.channel = c
            self.window = self.create_mean_window(self.window_size, self.channel)

        self.window = self.window.to(gt_img).type_as(gt_img)

        channel = self.channel
        window = self.window
        window_size = self.window_size

        mu1 = F.conv2d(gt_img, window, padding = window_size//2, groups = channel) + 1e-6
        mu2 = F.conv2d(pred_img, window, padding = window_size//2, groups = channel) + 1e-6

        contrast1 = torch.absolute((gt_img - mu1)/mu1)
        contrast2 = torch.absolute((pred_img - mu2)/mu2)

        variance1 = (gt_img-mu1) ** 2
        variance2 = (pred_img-mu2) ** 2

        laplacian1 = self.img_laplacian(gt_img)
        laplacian2 = self.img_laplacian(pred_img)

        S1 = -laplacian1 - contrast1 - variance1
        S2 = -laplacian2 - contrast2 - variance2

        total = torch.absolute(S1-S2).mean()
        return total


    def img_laplacian(self, img):
        b, c, h, w = img.shape
        laplacian = torch.tensor([[1, 4, 1], [4, -20, 4], [1, 4, 1]])

        laplacian_kernel = laplacian.float().unsqueeze(0).expand(1, c, 3, 3).to(img)

        lap = F.conv2d(img, laplacian_kernel, padding=1, stride=1)
        return lap


    def create_mean_window(self, window_size, channel):
        window = Variable(torch.ones(channel, 1, window_size, window_size).float())
        window = window/(window_size * window_size)
        return window


if __name__ == '__main__':
    a = torch.rand(3,3,128,128)
    b = torch.rand(3,3,128,128)

    ssim = ssim_loss()
    loss = ssim.loss(a, b)
    print(loss.shape, loss)

    loss = ssim.loss(a, a)
    print(loss.shape, loss)

    loss = ssim.loss(b, b)
    print(loss.shape, loss)

    grad = grad_loss()
    loss = grad.loss(a, b)  # fixed: rgb_img must be a tensor, not a list
    print(loss.shape, loss)

    sharp = sharp_loss()
    loss = sharp.loss(a, b)
    print(loss.shape, loss)
models/Loss/__init__.py
ADDED
File without changes
models/Loss/__pycache__/Loss.cpython-39.pyc
ADDED
Binary file (8.36 kB)

models/Loss/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (164 Bytes)

models/Loss/__pycache__/vgg19_loss.cpython-39.pyc
ADDED
Binary file (2.08 kB)
models/Loss/pytorch_ssim/__init__.py
ADDED
@@ -0,0 +1,73 @@
import torch
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
from math import exp

def gaussian(window_size, sigma):
    gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
    return gauss/gauss.sum()

def create_window(window_size, channel):
    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
    window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
    return window

def _ssim(img1, img2, window, window_size, channel, size_average = True):
    mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel)
    mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1*mu2

    sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq
    sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq
    sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2

    C1 = 0.01**2
    C2 = 0.03**2

    ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))

    if size_average:
        return ssim_map.mean()
    else:
        return ssim_map.mean(1).mean(1).mean(1)

class SSIM(torch.nn.Module):
    def __init__(self, window_size = 11, size_average = True):
        super(SSIM, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.channel = 1
        self.window = create_window(window_size, self.channel)

    def forward(self, img1, img2):
        (_, channel, _, _) = img1.size()

        if channel == self.channel and self.window.data.type() == img1.data.type():
            window = self.window
        else:
            window = create_window(self.window_size, channel)

            if img1.is_cuda:
                window = window.cuda(img1.get_device())
            window = window.type_as(img1)

            self.window = window
            self.channel = channel

        return _ssim(img1, img2, window, self.window_size, channel, self.size_average)

def ssim(img1, img2, window_size = 11, size_average = True):
    (_, channel, _, _) = img1.size()
    window = create_window(window_size, channel)

    if img1.is_cuda:
        window = window.cuda(img1.get_device())
    window = window.type_as(img1)

    return _ssim(img1, img2, window, window_size, channel, size_average)
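For reference, a minimal usage sketch of this module (the import path assumes you run from the repo root; the inputs are illustrative B x C x H x W images in [0, 1]):

import torch
from models.Loss.pytorch_ssim import SSIM, ssim

img1 = torch.rand(2, 1, 64, 64)
img2 = img1.clone()

print(ssim(img1, img2))                   # functional form; ~1.0 for identical images
print(SSIM(window_size=11)(img1, img2))   # module form caches the Gaussian window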
models/Loss/pytorch_ssim/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (2.65 kB)
models/Loss/vgg19_loss.py
ADDED
@@ -0,0 +1,54 @@
import torch
import torch.nn as nn
import torchvision

class FeatureExtractor(nn.Module):
    def __init__(self, cnn, feature_layer=11):
        super(FeatureExtractor, self).__init__()
        self.features = nn.Sequential(*list(cnn.features.children())[:(feature_layer + 1)])

    def normalize(self, tensors, mean, std):
        if not torch.is_tensor(tensors):
            raise TypeError('tensor is not a torch image.')
        for tensor in tensors:
            for t, m, s in zip(tensor, mean, std):
                t.sub_(m).div_(s)
        return tensors

    def forward(self, x):
        # if the image is grayscale, expand it to 3 channels
        if x.size()[1] == 1:
            x = x.expand(-1, 3, -1, -1)

        # [-1, 1] image to [0, 1] image ---------------------------------------------------(1)
        x = (x + 1) * 0.5

        # https://pytorch.org/docs/stable/torchvision/models.html
        x.data = self.normalize(x.data, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        return self.features(x)

# Feature extracting using vgg19
vgg19 = torchvision.models.vgg19(pretrained=True)
feature_extractor = FeatureExtractor(vgg19, feature_layer=35)
feature_extractor.eval()

class VGG19Loss(object):
    def __init__(self):
        global feature_extractor
        self.initialized = False
        self.feature_extractor = feature_extractor
        self.MSE = nn.MSELoss()

    def __call__(self, output, target, device):
        if self.initialized == False:
            self.feature_extractor = self.feature_extractor.to(device)
            self.MSE = self.MSE.to(device)
            self.initialized = True

        # [-1, 1] image to [0, 1] image ---------------------------------------------------(2)
        output = (output + 1) * 0.5
        target = (target + 1) * 0.5

        output = self.feature_extractor(output)
        target = self.feature_extractor(target).data
        return self.MSE(output, target)
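A minimal usage sketch for VGG19Loss (note that importing the module instantiates VGG-19 and downloads the pretrained weights on first use; inputs are assumed to be roughly in [-1, 1], since __call__ maps them back to [0, 1]):

import torch
from models.Loss.vgg19_loss import VGG19Loss

criterion = VGG19Loss()
pred = torch.rand(2, 3, 224, 224) * 2 - 1    # placeholder prediction in [-1, 1]
target = torch.rand(2, 3, 224, 224) * 2 - 1  # placeholder target in [-1, 1]
print(criterion(pred, target, torch.device('cpu')))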
models/SSN.py
ADDED
@@ -0,0 +1,143 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import utils
from collections import OrderedDict
import numpy as np

from .abs_model import abs_model
from .Loss.Loss import norm_loss
from .blocks import *
from .SSN_Model import SSN_Model


class SSN(abs_model):
    def __init__(self, opt):
        mid_act = opt['model']['mid_act']
        out_act = opt['model']['out_act']
        in_channels = opt['model']['in_channels']
        out_channels = opt['model']['out_channels']
        self.ncols = opt['hyper_params']['n_cols']

        self.model = SSN_Model(in_channels=in_channels, out_channels=out_channels, mid_act=mid_act, out_act=out_act)
        self.optimizer = get_optimizer(opt, self.model)
        self.visualization = {}

        self.norm_loss_ = norm_loss(norm=1)

    def setup_input(self, x):
        return x


    def forward(self, x):
        keys = ['mask', 'ibl']

        for k in keys:
            assert k in x.keys(), '{} not in input'.format(k)

        mask = x['mask']
        ibl = x['ibl']

        return self.model(mask, ibl)


    def compute_loss(self, y, pred):
        total_loss = self.norm_loss_.loss(y, pred)
        return total_loss


    def supervise(self, input_x, y, is_training: bool) -> float:
        optimizer = self.optimizer
        model = self.model

        optimizer.zero_grad()
        pred = self.forward(input_x)
        loss = self.compute_loss(y, pred)

        # logging.info('Pred/Target: {}, {}/{}, {}'.format(pred.min().item(), pred.max().item(), y.min().item(), y.max().item()))

        if is_training:
            loss.backward()
            optimizer.step()

        self.visualization['mask'] = input_x['mask'].detach()
        self.visualization['ibl'] = input_x['ibl'].detach()
        self.visualization['y'] = y.detach()
        self.visualization['pred'] = pred.detach()

        return loss.item()


    def get_visualize(self) -> OrderedDict:
        """ Convert to visualization numpy array
        """
        nrows = self.ncols
        visualizations = self.visualization
        ret_vis = OrderedDict()

        for k, v in visualizations.items():
            batch = v.shape[0]
            n = min(nrows, batch)

            plot_v = v[:n]
            plot_v = (plot_v - plot_v.min())/(plot_v.max() - plot_v.min())
            ret_vis[k] = np.clip(utils.make_grid(plot_v.cpu(), nrow=nrows).numpy().transpose(1,2,0), 0.0, 1.0)

        return ret_vis


    def get_logs(self):
        pass


    def inference(self, x):
        keys = ['mask', 'ibl']
        for k in keys:
            assert k in x.keys(), '{} not in input'.format(k)
            assert len(x[k].shape) == 2, '{} should be 2D tensor'.format(k)

        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        mask = torch.tensor(x['mask'])[None, None, ...].float().to(device)
        ibl = torch.tensor(x['ibl'])[None, None, ...].float().to(device)

        input_x = {'mask': mask, 'ibl': ibl}
        pred = self.forward(input_x)

        pred = np.clip(pred[0, 0].detach().cpu().numpy() / 30.0, 0.0, 1.0)
        return pred


    def batch_inference(self, x):
        # TODO
        pass


    """ Getter & Setter
    """
    def get_models(self) -> dict:
        return {'model': self.model}


    def get_optimizers(self) -> dict:
        return {'optimizer': self.optimizer}


    def set_models(self, models: dict):
        # input test
        if 'model' not in models.keys():
            raise ValueError('{} not in self.model'.format('model'))

        self.model = models['model']


    def set_optimizers(self, optimizer: dict):
        self.optimizer = optimizer['optimizer']

    ####################
    # Personal Methods #
    ####################
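SSN (together with get_optimizer in SSN_Model.py below) reads its settings from a nested opt dict. A minimal sketch of the keys actually accessed; all values here are placeholders, the real ones come from configs/SSN.yaml:

opt = {
    'model': {
        'mid_act': 'leaky',        # placeholder values; see configs/SSN.yaml
        'out_act': 'sigmoid',
        'in_channels': 1,
        'out_channels': 1,
        'optimizer': 'Adam',
        'beta1': 0.9,
        'weight_decay': 4e-5,
    },
    'hyper_params': {
        'lr': 1e-4,
        'n_cols': 4,
    },
}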
models/SSN_Model.py
ADDED
@@ -0,0 +1,333 @@
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import logging

def weights_init(init_type='gaussian', std=0.02):
    def init_fun(m):
        classname = m.__class__.__name__
        if (classname.find('Conv') == 0 or classname.find(
                'Linear') == 0) and hasattr(m, 'weight'):
            if init_type == 'gaussian':
                nn.init.normal_(m.weight, 0.0, std)
            elif init_type == 'xavier':
                nn.init.xavier_normal_(m.weight, gain=math.sqrt(2))
            elif init_type == 'kaiming':
                nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                nn.init.orthogonal_(m.weight, gain=math.sqrt(2))
            elif init_type == 'default':
                pass
            else:
                assert 0, "Unsupported initialization: {}".format(init_type)
            if hasattr(m, 'bias') and m.bias is not None:
                nn.init.constant_(m.bias, 0.0)

    return init_fun

def freeze(module):
    for param in module.parameters():
        param.requires_grad = False

def unfreeze(module):
    for param in module.parameters():
        param.requires_grad = True

def get_optimizer(opt, model):
    lr = float(opt['hyper_params']['lr'])
    beta1 = float(opt['model']['beta1'])
    weight_decay = float(opt['model']['weight_decay'])
    opt_name = opt['model']['optimizer']

    optim_params = []
    # weight decay
    for key, value in model.named_parameters():
        if not value.requires_grad:
            continue  # frozen weights

        if key[-4:] == 'bias':
            optim_params += [{'params': value, 'weight_decay': 0.0}]
        else:
            optim_params += [{'params': value,
                              'weight_decay': weight_decay}]

    if opt_name == 'Adam':
        return optim.Adam(optim_params,
                          lr=lr,
                          betas=(beta1, 0.999),
                          eps=1e-5)
    else:
        err = '{} not implemented yet'.format(opt_name)
        logging.error(err)
        raise NotImplementedError(err)


def get_activation(activation):
    if activation is None:
        return nn.Identity()

    act_func = {
        'relu': nn.ReLU(),
        'sigmoid': nn.Sigmoid(),
        'tanh': nn.Tanh(),
        'prelu': nn.PReLU(),
        'leaky': nn.LeakyReLU(0.2),
        'gelu': nn.GELU(),
    }
    if activation not in act_func.keys():
        logging.error("activation {} is not implemented yet".format(activation))
        assert False

    return act_func[activation]

def get_norm(out_channels, norm_type='Instance'):
    norm_set = ['Instance', 'Batch', 'Group']
    if norm_type not in norm_set:
        err = "Normalization {} has not been implemented yet".format(norm_type)
        logging.error(err)
        raise ValueError(err)

    if norm_type == 'Instance':
        return nn.InstanceNorm2d(out_channels, affine=True)

    if norm_type == 'Batch':
        return nn.BatchNorm2d(out_channels)

    if norm_type == 'Group':
        if out_channels >= 32:
            groups = 32
        else:
            groups = 1

        return nn.GroupNorm(groups, out_channels)

    else:
        raise NotImplementedError('{} has not been implemented yet'.format(norm_type))


def get_layer_info(out_channels, activation_func='relu'):
    activation = get_activation(activation_func)
    norm_layer = get_norm(out_channels, 'Group')
    return norm_layer, activation


class Conv(nn.Module):
    """ (convolution => [GroupNorm] => activation) """
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=3,
                 stride=1,
                 padding=1,
                 bias=True,
                 activation='leaky',
                 resnet=True):
        super().__init__()

        norm_layer, act_func = get_layer_info(out_channels, activation)

        if resnet and in_channels == out_channels:
            self.resnet = True
        else:
            self.resnet = False

        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, stride=stride, kernel_size=kernel_size, padding=padding, bias=bias),
            norm_layer,
            act_func)

    def forward(self, x):
        res = self.conv(x)

        if self.resnet:
            res = res + x

        return res


class Up(nn.Module):
    """ Upscaling then conv """

    def __init__(self, in_channels, out_channels, activation='relu', resnet=True):
        super().__init__()

        self.up_layer = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.up = Conv(in_channels, out_channels, activation=activation, resnet=resnet)

    def forward(self, x):
        x = self.up_layer(x)
        return self.up(x)


class DConv(nn.Module):
    """ Double Conv Layer
    """
    def __init__(self, in_channels, out_channels, activation='relu', resnet=True):
        super().__init__()

        self.conv1 = Conv(in_channels, out_channels, activation=activation, resnet=resnet)
        self.conv2 = Conv(out_channels, out_channels, activation=activation, resnet=resnet)

    def forward(self, x):
        return self.conv2(self.conv1(x))


class Encoder(nn.Module):
    def __init__(self, in_channels=3, mid_act='leaky', resnet=True):
        super(Encoder, self).__init__()
        self.in_conv = Conv(in_channels, 32-in_channels, stride=1, activation=mid_act, resnet=resnet)
        self.down_32_64 = Conv(32, 64, stride=2, activation=mid_act, resnet=resnet)
        self.down_64_64_1 = Conv(64, 64, activation=mid_act, resnet=resnet)
        self.down_64_128 = Conv(64, 128, stride=2, activation=mid_act, resnet=resnet)
        self.down_128_128_1 = Conv(128, 128, activation=mid_act, resnet=resnet)
        self.down_128_256 = Conv(128, 256, stride=2, activation=mid_act, resnet=resnet)
        self.down_256_256_1 = Conv(256, 256, activation=mid_act, resnet=resnet)
        self.down_256_512 = Conv(256, 512, stride=2, activation=mid_act, resnet=resnet)
        self.down_512_512_1 = Conv(512, 512, activation=mid_act, resnet=resnet)
        self.down_512_512_2 = Conv(512, 512, activation=mid_act, resnet=resnet)
        self.down_512_512_3 = Conv(512, 512, activation=mid_act, resnet=resnet)


    def forward(self, x):
        x1 = self.in_conv(x)            # 32 x 256 x 256
        x1 = torch.cat((x, x1), dim=1)

        x2 = self.down_32_64(x1)
        x3 = self.down_64_64_1(x2)

        x4 = self.down_64_128(x3)
        x5 = self.down_128_128_1(x4)

        x6 = self.down_128_256(x5)
        x7 = self.down_256_256_1(x6)

        x8 = self.down_256_512(x7)
        x9 = self.down_512_512_1(x8)
        x10 = self.down_512_512_2(x9)
        x11 = self.down_512_512_3(x10)

        return x11, x10, x9, x8, x7, x6, x5, x4, x3, x2, x1


class Decoder(nn.Module):
    """ Up Stream Sequence """

    def __init__(self,
                 out_channels=3,
                 mid_act='relu',
                 out_act='sigmoid',
                 resnet=True):
        super(Decoder, self).__init__()

        input_channel = 512
        fea_dim = 100

        self.up_16_16_1 = Conv(input_channel, 256, activation=mid_act, resnet=resnet)
        self.up_16_16_2 = Conv(768, 512, activation=mid_act, resnet=resnet)
        self.up_16_16_3 = Conv(1024, 512, activation=mid_act, resnet=resnet)

        self.up_16_32 = Up(1024, 256, activation=mid_act, resnet=resnet)
        self.up_32_32_1 = Conv(512, 256, activation=mid_act, resnet=resnet)

        self.up_32_64 = Up(512, 128, activation=mid_act, resnet=resnet)
        self.up_64_64_1 = Conv(256, 128, activation=mid_act, resnet=resnet)

        self.up_64_128 = Up(256, 64, activation=mid_act, resnet=resnet)
        self.up_128_128_1 = Conv(128, 64, activation=mid_act, resnet=resnet)

        self.up_128_256 = Up(128, 32, activation=mid_act, resnet=resnet)
        self.out_conv = Conv(64, out_channels, activation=out_act)


    def forward(self, x, ibl):
        x11, x10, x9, x8, x7, x6, x5, x4, x3, x2, x1 = x

        h, w = x10.shape[2:]
        # flatten the 32 x 16 IBL into a 512-vector, broadcast over the bottleneck grid
        y = ibl.view(-1, 512, 1, 1).repeat(1, 1, h, w)

        y = self.up_16_16_1(y)          # 256 x 16 x 16

        y = torch.cat((x10, y), dim=1)  # 768 x 16 x 16
        y = self.up_16_16_2(y)          # 512 x 16 x 16

        y = torch.cat((x9, y), dim=1)   # 1024 x 16 x 16
        y = self.up_16_16_3(y)          # 512 x 16 x 16

        y = torch.cat((x8, y), dim=1)   # 1024 x 16 x 16
        y = self.up_16_32(y)            # 256 x 32 x 32

        y = torch.cat((x7, y), dim=1)
        y = self.up_32_32_1(y)          # 256 x 32 x 32

        y = torch.cat((x6, y), dim=1)
        y = self.up_32_64(y)

        y = torch.cat((x5, y), dim=1)
        y = self.up_64_64_1(y)          # 128 x 64 x 64

        y = torch.cat((x4, y), dim=1)
        y = self.up_64_128(y)

        y = torch.cat((x3, y), dim=1)
        y = self.up_128_128_1(y)        # 64 x 128 x 128

        y = torch.cat((x2, y), dim=1)
        y = self.up_128_256(y)          # 32 x 256 x 256

        y = torch.cat((x1, y), dim=1)
        y = self.out_conv(y)            # 3 x 256 x 256

        return y


class SSN_Model(nn.Module):
    """ Soft Shadow Network backbone (the docstring previously said
    'Relighting Net', a leftover from an earlier project). """

    def __init__(self,
                 in_channels=3,
                 out_channels=3,
                 mid_act='leaky',
                 out_act='sigmoid',
                 resnet=True):
        super(SSN_Model, self).__init__()

        self.out_act = out_act

        self.encoder = Encoder(in_channels, mid_act=mid_act, resnet=resnet)
        self.decoder = Decoder(out_channels, mid_act=mid_act, out_act=out_act, resnet=resnet)

        # init weights
        init_func = weights_init('gaussian', std=1e-3)
        self.encoder.apply(init_func)
        self.decoder.apply(init_func)


    def forward(self, x, ibl):
        """
        x:   cutout/mask tensor, B x C x 256 x 256
        ibl: image-based light map, flattened to a 512-vector in the decoder
        Returns the predicted shadow map (scaled by 30 for sigmoid output).
        """
        latent = self.encoder(x)
        pred = self.decoder(latent, ibl)

        if self.out_act == 'sigmoid':
            pred = pred * 30.0

        return pred


if __name__ == '__main__':
    x = torch.randn(5, 1, 256, 256)
    ibl = torch.randn(5, 1, 32, 16)
    model = SSN_Model(1, 1)

    y = model(x, ibl)

    print('Output: ', y.shape)
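One subtlety in Decoder.forward above: the 32 x 16 IBL map is flattened into a 512-channel vector and broadcast across the bottleneck grid before being mixed with the skip features. A standalone sketch of just that reshape:

import torch

ibl = torch.randn(5, 1, 32, 16)   # B x 1 x 32 x 16, as in the __main__ test
h = w = 16                        # bottleneck resolution for 256 x 256 inputs
y = ibl.view(-1, 512, 1, 1).repeat(1, 1, h, w)   # 1 * 32 * 16 == 512
print(y.shape)                    # torch.Size([5, 512, 16, 16])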
models/SSN_v1.py
ADDED
@@ -0,0 +1,290 @@
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from torchvision import transforms
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
def get_activation(activation_func):
|
8 |
+
act_func = {
|
9 |
+
"relu":nn.ReLU(),
|
10 |
+
"sigmoid":nn.Sigmoid(),
|
11 |
+
"prelu":nn.PReLU(num_parameters=1),
|
12 |
+
"leaky_relu": nn.LeakyReLU(negative_slope=0.2, inplace=False),
|
13 |
+
"gelu":nn.GELU()
|
14 |
+
}
|
15 |
+
|
16 |
+
if activation_func is None:
|
17 |
+
return nn.Identity()
|
18 |
+
|
19 |
+
if activation_func not in act_func.keys():
|
20 |
+
raise ValueError("activation function({}) is not found".format(activation_func))
|
21 |
+
|
22 |
+
activation = act_func[activation_func]
|
23 |
+
return activation
|
24 |
+
|
25 |
+
|
26 |
+
def get_layer_info(out_channels, activation_func='relu'):
|
27 |
+
#act_func = {"relu":nn.ReLU(), "sigmoid":nn.Sigmoid(), "prelu":nn.PReLU(num_parameters=out_channels)}
|
28 |
+
|
29 |
+
# norm_layer = nn.BatchNorm2d(out_channels, momentum=0.9)
|
30 |
+
if out_channels >= 32:
|
31 |
+
groups = 32
|
32 |
+
else:
|
33 |
+
groups = 1
|
34 |
+
|
35 |
+
norm_layer = nn.GroupNorm(groups, out_channels)
|
36 |
+
activation = get_activation(activation_func)
|
37 |
+
return norm_layer, activation
|
38 |
+
|
39 |
+
|
40 |
+
class Conv(nn.Module):
|
41 |
+
""" (convolution => [BN] => ReLU) """
|
42 |
+
def __init__(self,
|
43 |
+
in_channels,
|
44 |
+
out_channels,
|
45 |
+
kernel_size=3,
|
46 |
+
stride=1,
|
47 |
+
padding=1,
|
48 |
+
bias=True,
|
49 |
+
activation='leaky',
|
50 |
+
style=False,
|
51 |
+
resnet=True):
|
52 |
+
super().__init__()
|
53 |
+
|
54 |
+
self.style = style
|
55 |
+
norm_layer, act_func = get_layer_info(in_channels, activation)
|
56 |
+
|
57 |
+
if resnet and in_channels == out_channels:
|
58 |
+
self.resnet = True
|
59 |
+
else:
|
60 |
+
self.resnet = False
|
61 |
+
|
62 |
+
if style:
|
63 |
+
self.styleconv = Conv2DMod(in_channels, out_channels, kernel_size)
|
64 |
+
self.relu = nn.LeakyReLU(0.2, inplace=True)
|
65 |
+
else:
|
66 |
+
self.norm = norm_layer
|
67 |
+
self.conv = nn.Conv2d(in_channels, out_channels, stride=stride, kernel_size=kernel_size, padding=padding, bias=bias)
|
68 |
+
self.act = act_func
|
69 |
+
|
70 |
+
def forward(self, x, style_fea=None):
|
71 |
+
if self.style:
|
72 |
+
res = self.styleconv(x, style_fea)
|
73 |
+
res = self.relu(res)
|
74 |
+
else:
|
75 |
+
h = self.conv(self.act(self.norm(x)))
|
76 |
+
if self.resnet:
|
77 |
+
res = h + x
|
78 |
+
else:
|
79 |
+
res = h
|
80 |
+
|
81 |
+
return res
|
82 |
+
|
83 |
+
|
84 |
+
class Conv2DMod(nn.Module):
|
85 |
+
def __init__(self, in_chan, out_chan, kernel, demod=True, stride=1, dilation=1, eps=1e-8, **kwargs):
|
86 |
+
super().__init__()
|
87 |
+
self.filters = out_chan
|
88 |
+
self.demod = demod
|
89 |
+
self.kernel = kernel
|
90 |
+
self.stride = stride
|
91 |
+
self.dilation = dilation
|
92 |
+
self.weight = nn.Parameter(torch.randn((out_chan, in_chan, kernel, kernel)))
|
93 |
+
self.eps = eps
|
94 |
+
nn.init.kaiming_normal_(self.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')
|
95 |
+
|
96 |
+
def _get_same_padding(self, size, kernel, dilation, stride):
|
97 |
+
return ((size - 1) * (stride - 1) + dilation * (kernel - 1)) // 2
|
98 |
+
|
99 |
+
def forward(self, x, y):
|
100 |
+
b, c, h, w = x.shape
|
101 |
+
|
102 |
+
w1 = y[:, None, :, None, None]
|
103 |
+
w2 = self.weight[None, :, :, :, :]
|
104 |
+
weights = w2 * (w1 + 1)
|
105 |
+
|
106 |
+
if self.demod:
|
107 |
+
d = torch.rsqrt((weights ** 2).sum(dim=(2, 3, 4), keepdim=True) + self.eps)
|
108 |
+
weights = weights * d
|
109 |
+
|
110 |
+
x = x.reshape(1, -1, h, w)
|
111 |
+
|
112 |
+
_, _, *ws = weights.shape
|
113 |
+
weights = weights.reshape(b * self.filters, *ws)
|
114 |
+
|
115 |
+
padding = self._get_same_padding(h, self.kernel, self.dilation, self.stride)
|
        x = F.conv2d(x, weights, padding=padding, groups=b)

        x = x.reshape(-1, self.filters, h, w)
        return x


class Up(nn.Module):
    """ Upscaling then conv """

    def __init__(self, in_channels, out_channels, activation='relu', resnet=True):
        super().__init__()
        self.up_layer = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.up = Conv(in_channels, out_channels, activation=activation, resnet=resnet)

    def forward(self, x):
        x = self.up_layer(x)
        return self.up(x)


class DConv(nn.Module):
    """ Double Conv Layer """

    def __init__(self, in_channels, out_channels, activation='relu', resnet=True):
        super().__init__()
        self.conv1 = Conv(in_channels, out_channels, activation=activation, resnet=resnet)
        self.conv2 = Conv(out_channels, out_channels, activation=activation, resnet=resnet)

    def forward(self, x):
        return self.conv2(self.conv1(x))


class Encoder(nn.Module):
    def __init__(self, in_channels=3, mid_act='leaky', resnet=True):
        super(Encoder, self).__init__()
        self.in_conv = Conv(in_channels, 32 - in_channels, stride=1, activation=mid_act, resnet=resnet)
        self.down_32_64 = Conv(32, 64, stride=2, activation=mid_act, resnet=resnet)
        self.down_64_64_1 = Conv(64, 64, activation=mid_act, resnet=resnet)
        self.down_64_128 = Conv(64, 128, stride=2, activation=mid_act, resnet=resnet)
        self.down_128_128_1 = Conv(128, 128, activation=mid_act, resnet=resnet)
        self.down_128_256 = Conv(128, 256, stride=2, activation=mid_act, resnet=resnet)
        self.down_256_256_1 = Conv(256, 256, activation=mid_act, resnet=resnet)
        self.down_256_512 = Conv(256, 512, stride=2, activation=mid_act, resnet=resnet)
        self.down_512_512_1 = Conv(512, 512, activation=mid_act, resnet=resnet)
        self.down_512_512_2 = Conv(512, 512, activation=mid_act, resnet=resnet)
        self.down_512_512_3 = Conv(512, 512, activation=mid_act, resnet=resnet)

    def forward(self, x):
        x1 = self.in_conv(x)             # 32 x 256 x 256 after the concat below
        x1 = torch.cat((x, x1), dim=1)

        x2 = self.down_32_64(x1)
        x3 = self.down_64_64_1(x2)

        x4 = self.down_64_128(x3)
        x5 = self.down_128_128_1(x4)

        x6 = self.down_128_256(x5)
        x7 = self.down_256_256_1(x6)

        x8 = self.down_256_512(x7)
        x9 = self.down_512_512_1(x8)
        x10 = self.down_512_512_2(x9)
        x11 = self.down_512_512_3(x10)

        return x11, x10, x9, x8, x7, x6, x5, x4, x3, x2, x1


class Decoder(nn.Module):
    def __init__(self,
                 out_channels=3,
                 mid_act='relu',
                 out_act='sigmoid',
                 resnet=True):
        super(Decoder, self).__init__()

        input_channel = 512
        fea_dim = 100

        # style projection; defined for the style-injection variant but unused in this forward
        self.to_style1 = nn.Linear(in_features=fea_dim, out_features=input_channel)

        self.up_16_16_1 = Conv(input_channel, 256, activation=mid_act, resnet=resnet)
        self.up_16_16_2 = Conv(768, 512, activation=mid_act, resnet=resnet)
        self.up_16_16_3 = Conv(1024, 512, activation=mid_act, resnet=resnet)

        self.up_16_32 = Up(1024, 256, activation=mid_act, resnet=resnet)
        self.up_32_32_1 = Conv(512, 256, activation=mid_act, resnet=resnet)

        self.up_32_64 = Up(512, 128, activation=mid_act, resnet=resnet)
        self.up_64_64_1 = Conv(256, 128, activation=mid_act, resnet=resnet)

        self.up_64_128 = Up(256, 64, activation=mid_act, resnet=resnet)
        self.up_128_128_1 = Conv(128, 64, activation=mid_act, resnet=resnet)

        self.up_128_256 = Up(128, 32, activation=mid_act, resnet=resnet)
        self.out_conv = Conv(64, out_channels, activation=mid_act)

        self.out_act = get_activation(out_act)

    def forward(self, x):
        x11, x10, x9, x8, x7, x6, x5, x4, x3, x2, x1 = x

        y = self.up_16_16_1(x11)

        y = torch.cat((x10, y), dim=1)
        y = self.up_16_16_2(y)

        y = torch.cat((x9, y), dim=1)
        y = self.up_16_16_3(y)

        y = torch.cat((x8, y), dim=1)
        y = self.up_16_32(y)

        y = torch.cat((x7, y), dim=1)
        y = self.up_32_32_1(y)

        y = torch.cat((x6, y), dim=1)
        y = self.up_32_64(y)

        y = torch.cat((x5, y), dim=1)
        y = self.up_64_64_1(y)    # 128 x 64 x 64

        y = torch.cat((x4, y), dim=1)
        y = self.up_64_128(y)

        y = torch.cat((x3, y), dim=1)
        y = self.up_128_128_1(y)  # 64 x 128 x 128

        y = torch.cat((x2, y), dim=1)
        y = self.up_128_256(y)    # 32 x 256 x 256

        y = torch.cat((x1, y), dim=1)
        y = self.out_conv(y)      # out_channels x 256 x 256
        y = self.out_act(y)

        return y


class SSN_v1(nn.Module):
    """ Implementation of Relighting Net """

    def __init__(self,
                 in_channels=3,
                 out_channels=3,
                 mid_act='leaky',
                 out_act='sigmoid',
                 resnet=True):
        super(SSN_v1, self).__init__()
        self.encoder = Encoder(in_channels, mid_act=mid_act, resnet=resnet)
        self.decoder = Decoder(out_channels, mid_act=mid_act, out_act=out_act, resnet=resnet)

    def forward(self, x, softness=None):
        """
        x: B x in_channels x H x W input tensor
        softness: optional style/softness code; accepted for interface
                  compatibility (callers such as Sparse_PH omit it) but
                  unused by this variant
        Returns the decoder prediction, B x out_channels x H x W.
        """
        latent = self.encoder(x)
        pred = self.decoder(latent)

        return pred


if __name__ == '__main__':
    test_input = torch.randn(5, 1, 256, 256)
    style = torch.randn(5, 100)

    model = SSN_v1(1, 1, mid_act='gelu', out_act='gelu', resnet=True)
    test_out = model(test_input, style)

    print('Output shape: ', test_out.shape)
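The `groups=b` convolution at the top of this hunk is the standard trick for applying a different filter bank to every sample in a batch: the batch is folded into the channel dimension and a grouped convolution keeps each sample's channels separate. A small sketch of the pattern (sizes are illustrative; the surrounding class defines `self.filters`):

import torch
import torch.nn.functional as F

# illustrative sketch of per-sample filtering, not part of the committed file
b, c, h, w, filters, k = 4, 1, 32, 32, 8, 7
x = torch.randn(b, c, h, w)
weights = torch.randn(b * filters, c, k, k)   # one filter bank per sample

x = x.reshape(1, b * c, h, w)                 # fold batch into channels
y = F.conv2d(x, weights, padding=k // 2, groups=b)
y = y.reshape(b, filters, h, w)               # unfold: per-sample responses
print(y.shape)                                # torch.Size([4, 8, 32, 32])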
models/Sparse_PH.py
ADDED
@@ -0,0 +1,185 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import utils
from torchvision.transforms import Resize
from collections import OrderedDict
import numpy as np
import matplotlib.cm as cm
import matplotlib as mpl
from torchvision.transforms import InterpolationMode


from .abs_model import abs_model
from .blocks import *
from .SSN import SSN
from .SSN_v1 import SSN_v1
from .Loss.Loss import norm_loss, grad_loss
from .Attention_Unet import Attention_Unet

class Sparse_PH(abs_model):
    def __init__(self, opt):
        mid_act = opt['model']['mid_act']
        out_act = opt['model']['out_act']
        in_channels = opt['model']['in_channels']
        out_channels = opt['model']['out_channels']
        resnet = opt['model']['resnet']
        backbone = opt['model']['backbone']

        self.ncols = opt['hyper_params']['n_cols']
        self.focal = opt['model']['focal']
        self.clip = opt['model']['clip']

        self.norm_loss_ = opt['model']['norm_loss']
        self.grad_loss_ = opt['model']['grad_loss']
        self.ggrad_loss_ = opt['model']['ggrad_loss']
        self.lap_loss = opt['model']['lap_loss']

        self.clip_range = opt['dataset']['linear_scale'] + opt['dataset']['linear_offset']

        if backbone == 'Default':
            self.model = SSN_v1(in_channels=in_channels,
                                out_channels=out_channels,
                                mid_act=mid_act,
                                out_act=out_act,
                                resnet=resnet)
        elif backbone == 'ATTN':
            self.model = Attention_Unet(in_channels, out_channels, mid_act=mid_act, out_act=out_act)

        self.optimizer = get_optimizer(opt, self.model)
        self.visualization = {}

        self.norm_loss = norm_loss()
        self.grad_loss = grad_loss()


    def setup_input(self, x):
        return x


    def forward(self, x):
        return self.model(x)


    def compute_loss(self, y, pred):
        b = y.shape[0]

        # total_loss = avg_norm_loss(y, pred)
        nloss = self.norm_loss.loss(y, pred) * self.norm_loss_
        gloss = self.grad_loss.loss(pred) * self.grad_loss_
        ggloss = self.grad_loss.gloss(y, pred) * self.ggrad_loss_
        laploss = self.grad_loss.laploss(pred) * self.lap_loss

        total_loss = nloss + gloss + ggloss + laploss

        self.loss_log = {
            'norm_loss': nloss.item(),
            'grad_loss': gloss.item(),
            'grad_l1_loss': ggloss.item(),
            'lap_loss': laploss.item(),
        }

        if self.focal:
            total_loss = torch.pow(total_loss, 3)

        return total_loss


    def supervise(self, input_x, y, is_training:bool)->float:
        optimizer = self.optimizer
        model = self.model

        x = input_x['x']

        optimizer.zero_grad()
        pred = self.forward(x)
        if self.clip:
            pred = torch.clip(pred, 0.0, self.clip_range)

        loss = self.compute_loss(y, pred)
        if is_training:
            loss.backward()
            optimizer.step()

        xc = x.shape[1]
        for i in range(xc):
            self.visualization['x{}'.format(i)] = x[:, i:i+1].detach()

        self.visualization['y_fore'] = y[:, 0:1].detach()
        self.visualization['y_back'] = y[:, 1:2].detach()
        self.visualization['pred_fore'] = pred[:, 0:1].detach()
        self.visualization['pred_back'] = pred[:, 1:2].detach()

        return loss.item()


    def get_visualize(self) -> OrderedDict:
        """ Convert to visualization numpy array
        """
        nrows = self.ncols
        visualizations = self.visualization
        ret_vis = OrderedDict()

        for k, v in visualizations.items():
            batch = v.shape[0]
            n = min(nrows, batch)

            plot_v = v[:n]
            ret_vis[k] = np.clip(utils.make_grid(plot_v.cpu(), nrow=nrows).numpy().transpose(1,2,0), 0.0, 1.0)
            ret_vis[k] = self.plasma(ret_vis[k])

        return ret_vis


    def get_logs(self):
        return self.loss_log


    def inference(self, x):
        x, device = x['x'], x['device']
        x = torch.from_numpy(x.transpose((2,0,1))).unsqueeze(dim=0).float().to(device)
        pred = self.forward(x)

        pred = pred[0].detach().cpu().numpy().transpose((1,2,0))

        return pred


    def batch_inference(self, x):
        x = x['x']
        pred = self.forward(x)
        return pred


    """ Getter & Setter
    """
    def get_models(self) -> dict:
        return {'model': self.model}


    def get_optimizers(self) -> dict:
        return {'optimizer': self.optimizer}


    def set_models(self, models: dict):
        # input test
        if 'model' not in models.keys():
            raise ValueError('{} not in self.model'.format('model'))

        self.model = models['model']


    def set_optimizers(self, optimizer: dict):
        self.optimizer = optimizer['optimizer']


    ####################
    # Personal Methods #
    ####################
    def plasma(self, x):
        norm = mpl.colors.Normalize(vmin=0.0, vmax=1)
        mapper = cm.ScalarMappable(norm=norm, cmap='plasma')
        bimg = mapper.to_rgba(x[:,:,0])[:,:,:3]

        return bimg
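Sparse_PH's `compute_loss` is a weighted sum of one data term and three gradient-based regularizers, optionally cubed when `focal` is set so that high-error batches dominate the update. A minimal sketch of that shape, using plain L1 stand-ins for the repo's `norm_loss`/`grad_loss` classes (weights and sizes here are illustrative, not the trained configuration):

import torch
import torch.nn.functional as F

def sketch_compute_loss(y, pred, w_norm=1.0, w_grad=0.1, focal=True):
    # illustrative stand-in for Sparse_PH.compute_loss, not part of the commit
    nloss = F.l1_loss(pred, y) * w_norm                       # data term
    gx = (pred[..., :, 1:] - pred[..., :, :-1]).abs().mean()  # horizontal gradients
    gy = (pred[..., 1:, :] - pred[..., :-1, :]).abs().mean()  # vertical gradients
    gloss = (gx + gy) * w_grad                                # smoothness term

    total = nloss + gloss
    if focal:
        total = torch.pow(total, 3)  # "focal": amplify large-error batches
    return total

y, pred = torch.rand(2, 2, 64, 64), torch.rand(2, 2, 64, 64)
print(sketch_compute_loss(y, pred).item())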
models/__init__.py
ADDED
@@ -0,0 +1,43 @@
# SRC: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/__init__.py
import logging
import importlib

from .abs_model import abs_model


def find_model_using_name(model_name):
    """Import the module "models/[model_name].py".
    In that file, the class whose name matches model_name
    (case-insensitively) and subclasses abs_model will be
    returned for instantiation.
    """
    model_filename = "models." + model_name
    modellib = importlib.import_module(model_filename)
    model = None

    target_model_name = model_name
    for name, cls in modellib.__dict__.items():
        if name.lower() == target_model_name.lower() \
           and issubclass(cls, abs_model):
            model = cls

    if model is None:
        err = "In %s.py, there should be a subclass of abs_model with a class name that matches %s in lowercase." % (model_filename, target_model_name)
        logging.error(err)
        exit(1)  # non-zero exit so callers can detect the failure

    return model


def create_model(opt):
    """Create a model given the option.
    This function is the main interface between this package and 'train.py'/'test.py'
    Example:
        >>> from models import create_model
        >>> model = create_model(opt)
    """
    model = find_model_using_name(opt['model']['name'])
    instance = model(opt)
    logging.info("model [%s] was created" % type(instance).__name__)
    return instance
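`find_model_using_name` resolves a config string to a class by importing `models/<name>.py` and scanning its namespace for a case-insensitive name match that subclasses `abs_model`. A condensed restatement of that lookup, assuming the `models` package is importable from the working directory (this sketch is not part of the commit):

import importlib

def sketch_find_model(model_name: str):
    # illustrative restatement of find_model_using_name
    modellib = importlib.import_module('models.' + model_name)
    for name, cls in vars(modellib).items():
        if name.lower() == model_name.lower() and isinstance(cls, type):
            return cls
    raise ImportError('no class named {} in models/{}.py'.format(model_name, model_name))

# e.g. sketch_find_model('Sparse_PH') returns the Sparse_PH class,
# which create_model then instantiates with the parsed YAML options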
models/__pycache__/SSN.cpython-39.pyc
ADDED
Binary file (4.11 kB)
models/__pycache__/SSN_Model.cpython-39.pyc
ADDED
Binary file (8.96 kB)
models/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (1.42 kB)
models/__pycache__/abs_model.cpython-39.pyc
ADDED
Binary file (2.14 kB)
models/__pycache__/blocks.cpython-39.pyc
ADDED
Binary file (6.92 kB)
models/abs_model.py
ADDED
@@ -0,0 +1,73 @@
from abc import ABC, abstractmethod
from collections import OrderedDict

class abs_model(ABC):
    """ Training Related Interface
    """
    @abstractmethod
    def setup_input(self, x):
        pass


    @abstractmethod
    def forward(self, x):
        pass


    @abstractmethod
    def supervise(self, input_x, y, is_training:bool)->float:
        pass


    @abstractmethod
    def get_visualize(self) -> OrderedDict:
        return {}


    """ Inference Related Interface
    """
    @abstractmethod
    def inference(self, x):
        pass


    @abstractmethod
    def batch_inference(self, x):
        pass


    """ Logging/Visualization Related Interface
    """
    @abstractmethod
    def get_logs(self):
        pass


    """ Getter & Setter
    """
    @abstractmethod
    def get_models(self) -> dict:
        """ GAN may have two models
        """
        pass


    @abstractmethod
    def get_optimizers(self) -> dict:
        """ GAN may have two optimizers
        """
        pass


    @abstractmethod
    def set_models(self, models) -> dict:
        """ GAN may have two models
        """
        pass


    @abstractmethod
    def set_optimizers(self, optimizers: dict):
        """ GAN may have two optimizers
        """
        pass
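Every concrete model in this repo (SSN, Sparse_PH, Template, ...) must implement all of these abstract methods before it can be instantiated. A do-nothing subclass, for illustration only, showing the minimum surface a new model has to provide (the class name is hypothetical):

from collections import OrderedDict
from models.abs_model import abs_model

class NullModel(abs_model):
    # hypothetical example, not part of the committed files
    def setup_input(self, x): return x
    def forward(self, x): return x
    def supervise(self, input_x, y, is_training: bool) -> float: return 0.0
    def get_visualize(self) -> OrderedDict: return OrderedDict()
    def inference(self, x): return x
    def batch_inference(self, x): return x
    def get_logs(self): return {}
    def get_models(self) -> dict: return {}
    def get_optimizers(self) -> dict: return {}
    def set_models(self, models): pass
    def set_optimizers(self, optimizers): pass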
models/attention.py
ADDED
@@ -0,0 +1,85 @@
from inspect import isfunction
import math
import torch
from torch import nn, einsum
import torch.nn.functional as F

from .blocks import get_norm, zero_module


def QKV_Attention(qkv, num_heads):
    """
    Apply QKV attention.
    :param qkv: an [N x (3 * C) x T] tensor of Qs, Ks, and Vs.
    :return: an [N x H' x T] tensor after attention.
    """
    B, C, HW = qkv.shape
    if C % 3 != 0:
        raise ValueError('QKV shape is wrong: {}, {}, {}'.format(B, C, HW))

    split_size = C // (3 * num_heads)
    q, k, v = qkv.chunk(3, dim=1)
    scale = 1.0 / math.sqrt(math.sqrt(split_size))
    weight = torch.einsum('bct, bcs->bts',
                          (q * scale).view(B * num_heads, split_size, HW),
                          (k * scale).view(B * num_heads, split_size, HW))

    weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
    ret = torch.einsum("bts,bcs->bct", weight, v.reshape(B * num_heads, split_size, HW))

    return ret.reshape(B, -1, HW)


class AttentionBlock(nn.Module):
    """
    https://github.com/CompVis/latent-diffusion/blob/main/ldm/modules/diffusionmodules/openaimodel.py
    https://github.com/whai362/PVT/blob/a24ba02c249a510581a84f821c26322534b03a10/detection/pvt_v2.py#L57
    """

    def __init__(self, in_channels, num_heads, qkv_bias=False, sr_ratio=1, linear=True):
        super().__init__()

        self.num_heads = num_heads
        self.norm = get_norm(in_channels, 'Group')
        self.qkv = nn.Conv1d(in_channels=in_channels, out_channels=in_channels * 3, kernel_size=1)

        self.proj = zero_module(nn.Conv1d(in_channels=in_channels, out_channels=in_channels, kernel_size=1))


    def forward(self, x):
        b, c, *spatial = x.shape
        num_heads = self.num_heads

        x = x.reshape(b, c, -1)  # B x C x HW
        x = self.norm(x)
        qkv = self.qkv(x)        # B x C x HW -> B x 3C x HW
        h = QKV_Attention(qkv, num_heads)
        h = self.proj(h)

        return (x + h).reshape(b, c, *spatial)  # additive attention, similar to ResNet


def get_model_size(model):
    param_size = 0
    for param in model.parameters():
        param_size += param.nelement() * param.element_size()

    buffer_size = 0
    for buffer in model.buffers():
        buffer_size += buffer.nelement() * buffer.element_size()

    size_all_mb = (param_size + buffer_size) / 1024 ** 2
    print('model size: {:.3f}MB'.format(size_all_mb))
    # return param_size + buffer_size
    return size_all_mb


if __name__ == '__main__':
    model = AttentionBlock(in_channels=256, num_heads=8)

    x = torch.randn(5, 256, 32, 32, dtype=torch.float32)
    y = model(x)
    print('{}, {}'.format(x.shape, y.shape))

    get_model_size(model)
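The head bookkeeping in `QKV_Attention` is the easiest part to misread: the `[B x 3C x HW]` projection is chunked into Q/K/V, and each stream is folded into `B * num_heads` independent attention problems of width `split_size`. A quick shape walk-through with assumed sizes (this sketch is not part of the commit):

import torch

B, C, HW, num_heads = 2, 256, 1024, 8       # illustrative sizes; C is per-stream width
qkv = torch.randn(B, 3 * C, HW)

q, k, v = qkv.chunk(3, dim=1)               # each: B x C x HW
split_size = C // num_heads                 # 32 channels per head
q = q.view(B * num_heads, split_size, HW)
k = k.view(B * num_heads, split_size, HW)

weight = torch.einsum('bct,bcs->bts', q, k) # (B*heads) x HW x HW attention map
print(weight.shape)                         # torch.Size([16, 1024, 1024])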
models/blocks.py
ADDED
@@ -0,0 +1,238 @@
from enum import Enum
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import logging


def get_model_size(model):
    param_size = 0
    for param in model.parameters():
        param_size += param.nelement() * param.element_size()

    buffer_size = 0
    for buffer in model.buffers():
        buffer_size += buffer.nelement() * buffer.element_size()

    size_all_mb = (param_size + buffer_size) / 1024 ** 2
    print('model size: {:.3f}MB'.format(size_all_mb))
    # return param_size + buffer_size
    return size_all_mb


def weights_init(init_type='gaussian'):
    def init_fun(m):
        classname = m.__class__.__name__
        if (classname.find('Conv') == 0 or classname.find(
                'Linear') == 0) and hasattr(m, 'weight'):
            if init_type == 'gaussian':
                nn.init.normal_(m.weight, 0.0, 0.02)
            elif init_type == 'xavier':
                nn.init.xavier_normal_(m.weight, gain=math.sqrt(2))
            elif init_type == 'kaiming':
                nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                nn.init.orthogonal_(m.weight, gain=math.sqrt(2))
            elif init_type == 'default':
                pass
            else:
                assert 0, "Unsupported initialization: {}".format(init_type)
            if hasattr(m, 'bias') and m.bias is not None:
                nn.init.constant_(m.bias, 0.0)

    return init_fun


def freeze(module):
    for param in module.parameters():
        param.requires_grad = False


def unfreeze(module):
    for param in module.parameters():
        param.requires_grad = True


def get_optimizer(opt, model):
    lr = float(opt['hyper_params']['lr'])
    beta1 = float(opt['model']['beta1'])
    weight_decay = float(opt['model']['weight_decay'])
    opt_name = opt['model']['optimizer']

    optim_params = []
    # bias terms are excluded from weight decay
    for key, value in model.named_parameters():
        if not value.requires_grad:
            continue  # frozen weights

        if key[-4:] == 'bias':
            optim_params += [{'params': value, 'weight_decay': 0.0}]
        else:
            optim_params += [{'params': value,
                              'weight_decay': weight_decay}]

    if opt_name == 'Adam':
        return optim.Adam(optim_params,
                          lr=lr,
                          betas=(beta1, 0.999),
                          eps=1e-5)
    else:
        err = '{} not implemented yet'.format(opt_name)
        logging.error(err)
        raise NotImplementedError(err)


def get_activation(activation):
    act_func = {
        'relu': nn.ReLU(),
        'sigmoid': nn.Sigmoid(),
        'tanh': nn.Tanh(),
        'prelu': nn.PReLU(),
        'leaky': nn.LeakyReLU(0.2),       # both spellings appear across the configs/models
        'leaky_relu': nn.LeakyReLU(0.2),
        'gelu': nn.GELU(),
    }
    if activation not in act_func.keys():
        logging.error("activation {} is not implemented yet".format(activation))
        assert False

    return act_func[activation]


def get_norm(out_channels, norm_type='Group', groups=32):
    norm_set = ['Instance', 'Batch', 'Group']
    if norm_type not in norm_set:
        err = "Normalization {} has not been implemented yet".format(norm_type)
        logging.error(err)
        raise ValueError(err)

    if norm_type == 'Instance':
        return nn.InstanceNorm2d(out_channels, affine=True)

    if norm_type == 'Batch':
        return nn.BatchNorm2d(out_channels)

    if norm_type == 'Group':
        if out_channels >= 32:
            groups = 32
        else:
            groups = max(out_channels // 2, 1)
        # note: out_channels must be divisible by the chosen group count,
        # which holds for the power-of-two widths used in this repo
        return nn.GroupNorm(groups, out_channels)
    else:
        raise NotImplementedError


class Conv(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1, norm_type='Batch', activation='relu'):
        super().__init__()

        act_func = get_activation(activation)
        norm_layer = get_norm(out_channels, norm_type)
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=True, padding_mode='reflect'),
            norm_layer,
            act_func)

    def forward(self, x):
        return self.conv(x)


def zero_module(module):
    """
    Zero out the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().zero_()
    return module


class Up(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return F.interpolate(x, scale_factor=2, mode='bilinear')


class Down(nn.Module):
    def __init__(self, channels, use_conv):
        super().__init__()
        self.use_conv = use_conv

        if self.use_conv:
            self.op = nn.Conv2d(channels, channels, 3, stride=2, padding=1)
        else:
            self.op = nn.AvgPool2d(kernel_size=3, stride=2, padding=1)


    def forward(self, x):
        return self.op(x)


class Res_Type(Enum):
    UP = 1
    DOWN = 2
    SAME = 3


class ResBlock(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, dropout=0.0, updown=Res_Type.DOWN, mid_act='leaky'):
        """ ResBlock to cover several cases:
        1. Up/Down/Same
        2. in_channels != out_channels
        """
        super().__init__()

        self.updown = updown

        # the input normalization sees in_channels-wide features
        self.in_norm = get_norm(in_channels, 'Group')
        self.in_act = get_activation(mid_act)
        self.in_conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=True)

        # up/down sampling, applied to both the residual and the skip branch
        if self.updown == Res_Type.DOWN:
            self.h_updown = Down(in_channels, use_conv=True)
            self.x_updown = Down(in_channels, use_conv=True)
        elif self.updown == Res_Type.UP:
            self.h_updown = Up()
            self.x_updown = Up()
        else:
            self.h_updown = nn.Identity()
            self.x_updown = nn.Identity()

        # 1x1 conv matches channel counts on the skip path when in_channels != out_channels
        if in_channels != out_channels:
            self.skip = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        else:
            self.skip = nn.Identity()

        self.out_layer = nn.Sequential(
            get_norm(out_channels, 'Group'),
            get_activation(mid_act),
            nn.Dropout(p=dropout),
            zero_module(nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=True))
        )


    def forward(self, x):
        # in layer
        h = self.in_act(self.in_norm(x))
        h = self.in_conv(self.h_updown(h))
        x = self.skip(self.x_updown(x))

        # out layer
        h = self.out_layer(h)
        return x + h


if __name__ == '__main__':
    x = torch.randn(5, 3, 256, 256)
    up = Up()
    conv_down = Down(3, True)
    pool_down = Down(3, False)

    print('Up: {}'.format(up(x).shape))
    print('Conv down: {}'.format(conv_down(x).shape))
    print('Pool down: {}'.format(pool_down(x).shape))

    up_model = ResBlock(3, 6, updown=Res_Type.UP)
    down_model = ResBlock(3, 6, updown=Res_Type.DOWN)

    print('model up: {}'.format(up_model(x).shape))
    print('model down: {}'.format(down_model(x).shape))
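`get_norm`'s Group branch uses 32 groups whenever the width allows and otherwise falls back to `max(out_channels // 2, 1)`; since `nn.GroupNorm` requires the channel count to be divisible by the group count, this relies on the widths used throughout the repo. A few concrete cases (illustrative sketch, not part of the commit):

import torch.nn as nn

def pick_groups(channels):
    # same rule as get_norm's 'Group' branch
    return 32 if channels >= 32 else max(channels // 2, 1)

for c in (3, 8, 32, 512):                 # widths that appear in these models
    gn = nn.GroupNorm(pick_groups(c), c)  # would raise if c % groups != 0
    print(c, '->', gn.num_groups, 'groups')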
models/pvt_attention.py
ADDED
@@ -0,0 +1,240 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial

from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from timm.models.registry import register_model
from timm.models.vision_transformer import _cfg
import math


class DWConv(nn.Module):
    def __init__(self, dim=768):
        super(DWConv, self).__init__()
        self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)

    def forward(self, x, H, W):
        B, N, C = x.shape
        x = x.transpose(1, 2).view(B, C, H, W)
        x = self.dwconv(x)
        x = x.flatten(2).transpose(1, 2)

        return x


class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0., linear=False):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.dwconv = DWConv(hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)
        self.linear = linear
        if self.linear:
            self.relu = nn.ReLU(inplace=True)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x, H, W):
        x = self.fc1(x)
        if self.linear:
            x = self.relu(x)
        x = self.dwconv(x, H, W)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1, linear=False):
        super().__init__()
        assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."

        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.linear = linear
        self.sr_ratio = sr_ratio
        if not linear:
            if sr_ratio > 1:
                self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
                self.norm = nn.LayerNorm(dim)
        else:
            self.pool = nn.AdaptiveAvgPool2d(7)
            self.sr = nn.Conv2d(dim, dim, kernel_size=1, stride=1)
            self.norm = nn.LayerNorm(dim)
            self.act = nn.GELU()
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x, H, W):
        B, N, C = x.shape
        q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)

        if not self.linear:
            if self.sr_ratio > 1:
                x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
                x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)
                x_ = self.norm(x_)
                kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
            else:
                kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        else:
            x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
            x_ = self.sr(self.pool(x_)).reshape(B, C, -1).permute(0, 2, 1)
            x_ = self.norm(x_)
            x_ = self.act(x_)
            kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        k, v = kv[0], kv[1]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)

        return x


class Block(nn.Module):

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1, linear=False):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
            attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio, linear=linear)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop, linear=linear)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x, H, W):
        x = x + self.drop_path(self.attn(self.norm1(x), H, W))
        x = x + self.drop_path(self.mlp(self.norm2(x), H, W))

        return x


class OverlapPatchEmbed(nn.Module):
    """ Image to Patch Embedding
    """

    def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)

        assert max(patch_size) > stride, "Set larger patch_size than stride"

        self.img_size = img_size
        self.patch_size = patch_size
        self.H, self.W = img_size[0] // stride, img_size[1] // stride
        self.num_patches = self.H * self.W
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride,
                              padding=(patch_size[0] // 2, patch_size[1] // 2))
        self.norm = nn.LayerNorm(embed_dim)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x):
        x = self.proj(x)
        _, _, H, W = x.shape
        x = x.flatten(2).transpose(1, 2)
        x = self.norm(x)

        return x, H, W


if __name__ == '__main__':
    test = torch.randn(5, 3, 224, 224)

    embed_dim = 768
    patch_embed = OverlapPatchEmbed(embed_dim=embed_dim)
    block = Block(embed_dim, 1)

    print('x: {}'.format(test.shape))
    pe, H, W = patch_embed(test)
    print('After patch: {}'.format(pe.shape))
    y = block(pe, H, W)
    print('After block: {}'.format(y.shape))
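The spatial-reduction step in `Attention` is what keeps PVT-style attention affordable: when `sr_ratio > 1`, keys and values are computed from a strided-convolution-downsampled copy of the tokens, shrinking the attended sequence by `sr_ratio**2`. Under assumed sizes (illustrative sketch, not part of the commit):

import torch
import torch.nn as nn

B, C, H, W = 2, 64, 56, 56
sr_ratio = 8                                    # illustrative reduction factor
x = torch.randn(B, H * W, C)                    # token sequence, N = H*W = 3136

# same reduction Attention.forward applies before building K/V
sr = nn.Conv2d(C, C, kernel_size=sr_ratio, stride=sr_ratio)
x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
x_ = sr(x_).reshape(B, C, -1).permute(0, 2, 1)

print(x.shape, '->', x_.shape)                  # keys/values: 3136 -> 49 tokens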
models/template.py
ADDED
@@ -0,0 +1,114 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import utils
from collections import OrderedDict

from .abs_model import abs_model
from .blocks import *
from .Loss.Loss import avg_norm_loss

class Template(abs_model):
    """ Standard Unet Implementation
    src: https://arxiv.org/pdf/1505.04597.pdf
    """
    def __init__(self, opt):
        resunet = opt['model']['resunet']
        out_act = opt['model']['out_act']
        norm_type = opt['model']['norm_type']
        in_channels = opt['model']['in_channels']
        out_channels = opt['model']['out_channels']
        self.ncols = opt['hyper_params']['n_cols']

        # note: Unet is not defined in .blocks as committed here; a concrete
        # implementation must be supplied before this template can run
        self.model = Unet(in_channels=in_channels,
                          out_channels=out_channels,
                          norm_type=norm_type,
                          out_act=out_act,
                          resunet=resunet)

        self.optimizer = get_optimizer(opt, self.model)
        self.visualization = {}


    def setup_input(self, x):
        return x


    def forward(self, x):
        return self.model(x)


    def compute_loss(self, y, pred):
        return avg_norm_loss(y, pred)


    def supervise(self, input_x, y, is_training:bool)->float:
        optimizer = self.optimizer
        model = self.model

        optimizer.zero_grad()
        pred = model(input_x)
        loss = self.compute_loss(y, pred)

        if is_training:
            loss.backward()
            optimizer.step()

        self.visualization['y'] = y.detach()
        self.visualization['pred'] = pred.detach()

        return loss.item()


    def get_visualize(self) -> OrderedDict:
        """ Convert to visualization numpy array
        """
        nrows = self.ncols
        visualizations = self.visualization
        ret_vis = OrderedDict()

        for k, v in visualizations.items():
            batch = v.shape[0]
            n = min(nrows, batch)

            plot_v = v[:n]
            ret_vis[k] = utils.make_grid(plot_v.cpu(), nrow=nrows).numpy().transpose(1,2,0)

        return ret_vis


    def inference(self, x):
        # TODO
        pass


    def batch_inference(self, x):
        # TODO
        pass


    """ Getter & Setter
    """
    def get_models(self) -> dict:
        return {'model': self.model}


    def get_optimizers(self) -> dict:
        return {'optimizer': self.optimizer}


    def set_models(self, models: dict):
        # input test
        if 'model' not in models.keys():
            raise ValueError('{} not in self.model'.format('model'))

        self.model = models['model']


    def set_optimizers(self, optimizer: dict):
        self.optimizer = optimizer['optimizer']


    ####################
    # Personal Methods #
    ####################
weights/SSN/0000001760.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:44328317fa836804554ae453fe1492a45cff724b5c13b5070211d6d860096089
size 283511041