yichen-purdue committed on
Commit 34fb220
1 Parent(s): 96d9168
app.py ADDED
@@ -0,0 +1,214 @@
+ import torch
+ from torch import nn
+ import logging
+
+ from pathlib import Path
+ import gradio as gr
+ import numpy as np
+ import cv2
+
+ import model_utils
+ from models.SSN import SSN
+
+ import matplotlib
+ matplotlib.use('TkAgg')
+
+ import matplotlib.pyplot as plt
+
+ config_file = 'configs/SSN.yaml'
+ weight = 'weights/SSN/0000001760.pt'
+ device = torch.device('cuda:0')
+ model = model_utils.load_model(config_file, weight, SSN, device)
+
+ DEFAULT_INTENSITY = 0.9
+ DEFAULT_GAMMA = 2.0
+
+ logging.info('Model loading succeeded')
+
+ cur_rgba = None
+ cur_shadow = None
+ cur_intensity = DEFAULT_INTENSITY
+ cur_gamma = DEFAULT_GAMMA
+
+
+ def resize(img, size):
+     h, w = img.shape[:2]
+
+     if h > w:
+         newh = size
+         neww = int(w / h * size)
+     else:
+         neww = size
+         newh = int(h / w * size)
+
+     resized_img = cv2.resize(img, (neww, newh), interpolation=cv2.INTER_AREA)
+     if len(img.shape) != len(resized_img.shape):
+         # cv2.resize drops a singleton channel dimension; restore it
+         resized_img = resized_img[..., None]
+
+     return resized_img
+
+
+ def ibl_normalize(ibl, energy=30.0):
+     total_energy = np.sum(ibl)
+     if total_energy < 1e-3:
+         # print('small energy: ', total_energy)
+         h, w = ibl.shape
+         return np.zeros((h, w))
+
+     return ibl * energy / total_energy
+
+
+ def padding_mask(rgba_input: np.ndarray):
+     """ Pad the mask input so that it fits the training dataset view range
+
+     If the rgba does not have enough padding area, we need to pad the area
+
+     :param rgba_input: H x W x 4 input, the first 3 channels are RGB, the last channel is the alpha
+     :returns: H x W x 4 padded RGBA
+
+     """
+     padding = 50
+     padding_size = 256 - padding * 2
+
+     h, w = rgba_input.shape[:2]
+     rgb = rgba_input[:, :, :3]
+     alpha = rgba_input[:, :, -1:]
+
+     zeros = np.where(alpha == 0)
+     hh, ww = zeros[0], zeros[1]
+     h_min, h_max = hh.min(), hh.max()
+     w_min, w_max = ww.min(), ww.max()
+
+     # if the area already has enough padding
+     if h_max - h_min < padding_size and w_max - w_min < padding_size:
+         return rgba_input
+
+     padding_output = np.zeros((256, 256, 4))
+     padding_output[..., :3] = 1.0
+
+     padded_rgba = resize(rgba_input, padding_size)
+     new_h, new_w = padded_rgba.shape[:2]
+
+     padding_output[padding:padding + new_h, padding:padding + new_w, :] = padded_rgba
+
+     return padding_output
+
+
+ def shadow_composite(rgba, shadow, intensity, gamma):
+     rgb = rgba[..., :3]
+     mask = rgba[..., 3:]
+
+     if len(shadow.shape) == 2:
+         shadow = shadow[..., None]
+
+     new_shadow = 1.0 - shadow ** gamma * intensity
+     ret = rgb * mask + (1.0 - mask) * new_shadow
+     return ret, new_shadow[..., 0]
+
+
+ def render_btn_fn(mask, ibl):
+     global cur_rgba, cur_shadow, cur_gamma, cur_intensity
+
+     print("Button clicked!")
+
+     mask = mask / 255.0
+     ibl = ibl / 255.0
+
+     # smoothing ibl
+     ibl = cv2.GaussianBlur(ibl, (11, 11), 0)
+
+     # padding mask
+     mask = padding_mask(mask)
+
+     cur_rgba = np.copy(mask)
+
+     print('mask shape: {}/{}/{}/{}, ibl shape: {}/{}/{}/{}'.format(mask.shape, mask.dtype, mask.min(), mask.max(),
+                                                                    ibl.shape, ibl.dtype, ibl.min(), ibl.max()))
+
+     # ret = np.random.randn(256, 256, 3)
+     # ret = (ret - ret.min()) / (ret.max() - ret.min() + 1e-8)
+
+     rgb, mask = mask[..., :3], mask[..., 3]
+
+     ibl = ibl_normalize(cv2.resize(ibl, (32, 16)))
+
+     # ibl = 1.0 - ibl
+
+     x = {
+         'mask': mask,
+         'ibl': ibl
+     }
+     shadow = model.inference(x)
+     cur_shadow = np.copy(shadow)
+
+     # gamma
+     # shadow = np.power(shadow, 2.2)
+     # shadow = shadow * 0.8
+     # shadow = 1.0 - shadow
+
+     # composite the shadow
+
+     # shadow = shadow[..., None]
+     # mask = mask[..., None]
+     # ret = rgb * mask + (1.0 - mask) * shadow
+     ret, shadow = shadow_composite(cur_rgba, shadow, cur_intensity, cur_gamma)
+
+     # import pdb; pdb.set_trace()
+     # ret = (1.0-mask) * shadow
+
+     print('IBL range: {}/{} Shadow range: {}/{}'.format(ibl.min(), ibl.max(), shadow.min(), shadow.max()))
+
+     plt.figure(figsize=(15, 10))
+     plt.subplot(1, 3, 1)
+     plt.imshow(mask)
+     plt.subplot(1, 3, 2)
+     plt.imshow(ibl)
+     plt.subplot(1, 3, 3)
+     plt.imshow(ret)
+     plt.savefig('tmp.png')
+     plt.close()
+
+     logging.info('Finished')
+
+     return ret, shadow
+
+
+ def intensity_change(x):
+     global cur_rgba, cur_shadow, cur_gamma, cur_intensity
+
+     cur_intensity = x
+     ret, shadow = shadow_composite(cur_rgba, cur_shadow, cur_intensity, cur_gamma)
+     return ret, shadow
+
+
+ def gamma_change(x):
+     global cur_rgba, cur_shadow, cur_gamma, cur_intensity
+
+     cur_gamma = x
+     ret, shadow = shadow_composite(cur_rgba, cur_shadow, cur_intensity, cur_gamma)
+     return ret, shadow
+
+
+ ibl_h = 128
+ ibl_w = ibl_h * 2
+
+ with gr.Blocks() as demo:
+     with gr.Row():
+         mask_input = gr.Image(shape=(256, 256), image_mode="RGBA", label="Mask")
+         ibl_input = gr.Sketchpad(shape=(ibl_w, ibl_h), image_mode="L", label="IBL", tool='sketch', invert_colors=True)
+         output = gr.Image(shape=(256, 256), height=256, width=256, image_mode="RGB", label="Output")
+         shadow_output = gr.Image(shape=(256, 256), height=256, width=256, image_mode="L", label="Shadow Layer")
+
+     with gr.Row():
+         intensity_slider = gr.Slider(0.0, 1.0, value=DEFAULT_INTENSITY, step=0.1, label="Intensity", info="Choose between 0.0 and 1.0")
+         gamma_slider = gr.Slider(1.0, 4.0, value=DEFAULT_GAMMA, step=0.1, label="Gamma", info="Gamma correction for shadow")
+         render_btn = gr.Button(label="Render")
+
+     render_btn.click(render_btn_fn, inputs=[mask_input, ibl_input], outputs=[output, shadow_output])
+     intensity_slider.release(intensity_change, inputs=[intensity_slider], outputs=[output, shadow_output])
+     gamma_slider.release(gamma_change, inputs=[gamma_slider], outputs=[output, shadow_output])
+
+     logging.info('Finished')
+
+
+ demo.launch()
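
Note: the compositing step in render_btn_fn can be sanity-checked without the SSN weights or the Gradio UI. The sketch below is illustrative only — it assumes shadow_composite from app.py above is in scope, and the synthetic RGBA cutout and fake shadow are made-up inputs, not part of the repo.

import numpy as np

# hypothetical inputs: a 256 x 256 RGBA cutout and a fake soft shadow in [0, 1]
rgba = np.ones((256, 256, 4), dtype=np.float32)
rgba[..., 3] = 0.0                      # alpha: background is transparent
rgba[100:180, 100:160, 3] = 1.0         # a rectangular "object" region
fake_shadow = np.zeros((256, 256), dtype=np.float32)
fake_shadow[150:220, 90:200] = 0.6

# same call render_btn_fn makes after model.inference()
ret, new_shadow = shadow_composite(rgba, fake_shadow, intensity=0.9, gamma=2.0)
print(ret.shape, new_shadow.shape)      # (256, 256, 3) (256, 256)
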
configs/GSSN.yaml ADDED
@@ -0,0 +1,57 @@
+ exp_name: GSSN_ALL_Channels_2e_5
+
+ # model related
+ model:
+   name: 'GSSN'
+   # backbone: 'vanilla'
+   backbone: 'SSN_v1'
+   in_channels: 6
+   out_channels: 1
+   resnet: True
+
+   mid_act: "gelu"
+   out_act: "gelu"
+
+   optimizer: 'Adam'
+   weight_decay: 4e-5
+   beta1: 0.9
+
+   focal: False
+
+ # dataset
+ dataset:
+   name: 'GSSN_Dataset'
+   hdf5_file: 'Dataset1/more_general_scenes/train/ALL_SIZE_WALL/dataset.hdf5'
+   type: 'BC_Boundary'
+   rech_grad: True
+
+
+ test_dataset:
+   name: 'GSSN_Testing_Dataset'
+   hdf5_file: 'Dataset/standalone_test_split/test/ALL_SIZE_MORE/dataset.hdf5'
+   type: 'BC_Boundary'
+   ignore_shading: True
+   rech_grad: True
+
+
+ # training related
+ hyper_params:
+   lr: 2e-5
+   epochs: 100000
+   workers: 52
+   batch_size: 52
+   save_epoch: 10
+
+   eval_batch: 10
+   eval_save: False
+
+   # visualization
+   vis_iter: 100 # iteration for visualization
+   save_iter: 100
+   n_cols: 5
+   gpus:
+     - 0
+   default_folder: 'weights'
+   resume: False
+   # resume: True
+   weight_file: 'latest'
configs/SSN.yaml ADDED
@@ -0,0 +1,51 @@
+ exp_name: SSN
+
+ # model related
+ model:
+   name: 'SSN'
+   in_channels: 1
+   out_channels: 1
+   resnet: False
+
+   mid_act: "relu"
+   out_act: 'relu'
+
+   optimizer: 'Adam'
+   weight_decay: 4e-5
+   beta1: 0.9
+
+
+ # dataset
+ dataset:
+   name: 'SSN_Dataset'
+   hdf5_file: 'Dataset/SSN/ssn_shadow/shadow_base/ssn_base.hdf5'
+   shadow_per_epoch: 10
+
+
+ # test_dataset:
+ #   name: 'SSN_Dataset'
+ #   hdf5_file: 'Dataset/SSN/ssn_shadow/shadow_base/ssn_base.hdf5'
+
+
+ # training related
+ hyper_params:
+   lr: 1e-3
+   epochs: 100000
+   workers: 40
+   batch_size: 10
+   save_epoch: 10
+
+   eval_batch: 10
+   eval_save: False
+
+   # visualization
+   vis_iter: 100 # iteration for visualization
+   save_iter: 100
+   n_cols: 5
+   gpus:
+     - 0
+     - 1
+
+   default_folder: 'weights'
+   resume: False
+   weight_file: 'latest'
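
This is the config that app.py loads. A minimal sketch of how the nested keys are read back through model_utils.parse_configs — the lookups mirror how GSSN.py reads opt['model'][...] and opt['hyper_params'][...], assuming the nesting shown above:

from model_utils import parse_configs

opt = parse_configs('configs/SSN.yaml')
print(opt['model']['name'], opt['model']['mid_act'])                      # 'SSN' 'relu'
print(opt['hyper_params']['batch_size'], opt['hyper_params']['n_cols'])   # 10 5
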
model_utils.py ADDED
@@ -0,0 +1,53 @@
+ import os
+ import yaml
+ import logging
+
+ import torch
+
+
+ def parse_configs(config: str):
+     """ Parse the config file and return a dictionary of configs
+
+     :param config: path to the config file
+     :returns: a dict of the parsed configs; an empty dict if parsing fails
+
+     """
+     if not os.path.exists(config):
+         logging.error('Cannot find the config file: {}'.format(config))
+         exit()
+
+     with open(config, 'r') as stream:
+         try:
+             configs = yaml.safe_load(stream)
+             return configs
+
+         except yaml.YAMLError as exc:
+             logging.error(exc)
+             return {}
+
+
+ def load_model(config: str, weight: str, model_def, device):
+     """ Load the model from the config file and the weight file
+
+     :param config: path to the config file
+     :param weight: path to the weight file
+     :param model_def: model class definition
+     :param device: pytorch device
+     :returns: the model wrapper with weights loaded and moved to `device`
+
+     """
+     assert os.path.exists(weight), 'Cannot find the weight file: {}'.format(weight)
+     assert os.path.exists(config), 'Cannot find the config file: {}'.format(config)
+
+     opt = parse_configs(config)
+     model = model_def(opt)
+     cp = torch.load(weight)
+
+     models = model.get_models()
+     for k, m in models.items():
+         m.load_state_dict(cp[k])
+         m.to(device)
+
+     model.set_models(models)
+     return model
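
load_model is agnostic to the concrete network: it only needs model_def(opt) to build a wrapper whose get_models() returns a dict of nn.Modules keyed like the checkpoint, and whose set_models() accepts them back (app.py additionally calls inference() on the SSN wrapper). A minimal sketch of that contract with a hypothetical wrapper class, for illustration only:

import torch
from torch import nn

class DummyWrapper:
    """Hypothetical wrapper showing the minimal interface load_model relies on."""
    def __init__(self, opt):
        self.net = nn.Conv2d(opt['model']['in_channels'], opt['model']['out_channels'], 3, padding=1)

    def get_models(self) -> dict:
        # keys must match the keys used when the checkpoint dict was saved
        return {'model': self.net}

    def set_models(self, models: dict):
        self.net = models['model']

# load_model('configs/SSN.yaml', <weight path>, DummyWrapper, torch.device('cuda:0')) would then
# build DummyWrapper(opt), load cp['model'] into .net, move it to the device, and return the wrapper.
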
models/Attention.ipynb ADDED
@@ -0,0 +1,509 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 30,
6
+ "id": "9ba18e04-aa6b-44d8-bbcc-73417ededcfd",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "import torch\n",
11
+ "import torch.nn as nn\n",
12
+ "import torch.nn.functional as F\n",
13
+ "from functools import partial\n",
14
+ "import math\n",
15
+ "import torch as th"
16
+ ]
17
+ },
18
+ {
19
+ "cell_type": "code",
20
+ "execution_count": 31,
21
+ "id": "b273789d-9136-4c10-806d-12c19ff1ae68",
22
+ "metadata": {},
23
+ "outputs": [],
24
+ "source": [
25
+ "class GroupNorm32(nn.GroupNorm):\n",
26
+ " def forward(self, x):\n",
27
+ " return super().forward(x.float()).type(x.dtype)\n",
28
+ "\n",
29
+ "def normalization(channels):\n",
30
+ " \"\"\"\n",
31
+ " Make a standard normalization layer.\n",
32
+ " :param channels: number of input channels.\n",
33
+ " :return: an nn.Module for normalization.\n",
34
+ " \"\"\"\n",
35
+ " return GroupNorm32(32, channels)\n",
36
+ "\n",
37
+ "\n",
38
+ "def conv_nd(dims, *args, **kwargs):\n",
39
+ " \"\"\"\n",
40
+ " Create a 1D, 2D, or 3D convolution module.\n",
41
+ " \"\"\"\n",
42
+ " if dims == 1:\n",
43
+ " return nn.Conv1d(*args, **kwargs)\n",
44
+ " elif dims == 2:\n",
45
+ " return nn.Conv2d(*args, **kwargs)\n",
46
+ " elif dims == 3:\n",
47
+ " return nn.Conv3d(*args, **kwargs)\n",
48
+ " raise ValueError(f\"unsupported dimensions: {dims}\")\n"
49
+ ]
50
+ },
51
+ {
52
+ "cell_type": "code",
53
+ "execution_count": 32,
54
+ "id": "8ad13d44-7efc-4cf3-8f18-3c6ed4999963",
55
+ "metadata": {},
56
+ "outputs": [],
57
+ "source": [
58
+ "class QKVAttentionLegacy(nn.Module):\n",
59
+ " \"\"\"\n",
60
+ " A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping\n",
61
+ " \"\"\"\n",
62
+ "\n",
63
+ " def __init__(self, n_heads):\n",
64
+ " super().__init__()\n",
65
+ " self.n_heads = n_heads\n",
66
+ "\n",
67
+ " def forward(self, qkv):\n",
68
+ " \"\"\"\n",
69
+ " Apply QKV attention.\n",
70
+ " :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.\n",
71
+ " :return: an [N x (H * C) x T] tensor after attention.\n",
72
+ " \"\"\"\n",
73
+ " bs, width, length = qkv.shape\n",
74
+ " assert width % (3 * self.n_heads) == 0\n",
75
+ " ch = width // (3 * self.n_heads)\n",
76
+ " q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)\n",
77
+ " scale = 1 / math.sqrt(math.sqrt(ch))\n",
78
+ " weight = th.einsum(\n",
79
+ " \"bct,bcs->bts\", q * scale, k * scale\n",
80
+ " ) # More stable with f16 than dividing afterwards\n",
81
+ " weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)\n",
82
+ " a = th.einsum(\"bts,bcs->bct\", weight, v)\n",
83
+ " return a.reshape(bs, -1, length)\n",
84
+ "\n",
85
+ " @staticmethod\n",
86
+ " def count_flops(model, _x, y):\n",
87
+ " return count_flops_attn(model, _x, y)"
88
+ ]
89
+ },
90
+ {
91
+ "cell_type": "code",
92
+ "execution_count": 33,
93
+ "id": "fd354430-2484-4f46-85f6-3397ae571fe9",
94
+ "metadata": {},
95
+ "outputs": [],
96
+ "source": [
97
+ "def zero_module(module):\n",
98
+ " \"\"\"\n",
99
+ " Zero out the parameters of a module and return it.\n",
100
+ " \"\"\"\n",
101
+ " for p in module.parameters():\n",
102
+ " p.detach().zero_()\n",
103
+ " return module\n"
104
+ ]
105
+ },
106
+ {
107
+ "cell_type": "code",
108
+ "execution_count": 37,
109
+ "id": "af42604f-c5fe-467b-95e9-e376fe90d4a5",
110
+ "metadata": {},
111
+ "outputs": [],
112
+ "source": [
113
+ "class AttentionBlock(nn.Module):\n",
114
+ " \"\"\"\n",
115
+ " An attention block that allows spatial positions to attend to each other.\n",
116
+ " Originally ported from here, but adapted to the N-d case.\n",
117
+ " https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.\n",
118
+ " \"\"\"\n",
119
+ "\n",
120
+ " def __init__(\n",
121
+ " self,\n",
122
+ " channels,\n",
123
+ " num_heads=1,\n",
124
+ " num_head_channels=-1,\n",
125
+ " use_new_attention_order=False,\n",
126
+ " ):\n",
127
+ " super().__init__()\n",
128
+ " self.channels = channels\n",
129
+ " if num_head_channels == -1:\n",
130
+ " self.num_heads = num_heads\n",
131
+ " else:\n",
132
+ " assert (\n",
133
+ " channels % num_head_channels == 0\n",
134
+ " ), f\"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}\"\n",
135
+ " self.num_heads = channels // num_head_channels\n",
136
+ " self.norm = normalization(channels)\n",
137
+ " self.qkv = conv_nd(1, channels, channels * 3, 1)\n",
138
+ " if use_new_attention_order:\n",
139
+ " # split qkv before split heads\n",
140
+ " self.attention = QKVAttention(self.num_heads)\n",
141
+ " else:\n",
142
+ " # split heads before split qkv\n",
143
+ " self.attention = QKVAttentionLegacy(self.num_heads)\n",
144
+ "\n",
145
+ " self.proj_out = zero_module(conv_nd(1, channels, channels, 1))\n",
146
+ "\n",
147
+ " def forward(self, x):\n",
148
+ " \n",
149
+ " import pdb; pdb.set_trace()\n",
150
+ " \n",
151
+ " b, c, *spatial = x.shape\n",
152
+ " x = x.reshape(b, c, -1)\n",
153
+ " qkv = self.qkv(self.norm(x))\n",
154
+ " h = self.attention(qkv)\n",
155
+ " h = self.proj_out(h)\n",
156
+ " return (x + h).reshape(b, c, *spatial)"
157
+ ]
158
+ },
159
+ {
160
+ "cell_type": "code",
161
+ "execution_count": 38,
162
+ "id": "7b180b84-f22c-446b-b2da-0fa987274953",
163
+ "metadata": {},
164
+ "outputs": [
165
+ {
166
+ "name": "stdout",
167
+ "output_type": "stream",
168
+ "text": [
169
+ "> \u001b[0;32m/tmp/ipykernel_456404/3277534714.py\u001b[0m(39)\u001b[0;36mforward\u001b[0;34m()\u001b[0m\n",
170
+ "\u001b[0;32m 37 \u001b[0;31m \u001b[0;32mimport\u001b[0m \u001b[0mpdb\u001b[0m\u001b[0;34m;\u001b[0m \u001b[0mpdb\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mset_trace\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
171
+ "\u001b[0m\u001b[0;32m 38 \u001b[0;31m\u001b[0;34m\u001b[0m\u001b[0m\n",
172
+ "\u001b[0m\u001b[0;32m---> 39 \u001b[0;31m \u001b[0mb\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0mspatial\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
173
+ "\u001b[0m\u001b[0;32m 40 \u001b[0;31m \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreshape\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mb\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
174
+ "\u001b[0m\u001b[0;32m 41 \u001b[0;31m \u001b[0mqkv\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mqkv\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnorm\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
175
+ "\u001b[0m\n"
176
+ ]
177
+ },
178
+ {
179
+ "name": "stdin",
180
+ "output_type": "stream",
181
+ "text": [
182
+ "ipdb> n\n"
183
+ ]
184
+ },
185
+ {
186
+ "name": "stdout",
187
+ "output_type": "stream",
188
+ "text": [
189
+ "> \u001b[0;32m/tmp/ipykernel_456404/3277534714.py\u001b[0m(40)\u001b[0;36mforward\u001b[0;34m()\u001b[0m\n",
190
+ "\u001b[0;32m 38 \u001b[0;31m\u001b[0;34m\u001b[0m\u001b[0m\n",
191
+ "\u001b[0m\u001b[0;32m 39 \u001b[0;31m \u001b[0mb\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0mspatial\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
192
+ "\u001b[0m\u001b[0;32m---> 40 \u001b[0;31m \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreshape\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mb\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
193
+ "\u001b[0m\u001b[0;32m 41 \u001b[0;31m \u001b[0mqkv\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mqkv\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnorm\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
194
+ "\u001b[0m\u001b[0;32m 42 \u001b[0;31m \u001b[0mh\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mattention\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mqkv\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
195
+ "\u001b[0m\n"
196
+ ]
197
+ },
198
+ {
199
+ "name": "stdin",
200
+ "output_type": "stream",
201
+ "text": [
202
+ "ipdb> n\n"
203
+ ]
204
+ },
205
+ {
206
+ "name": "stdout",
207
+ "output_type": "stream",
208
+ "text": [
209
+ "> \u001b[0;32m/tmp/ipykernel_456404/3277534714.py\u001b[0m(41)\u001b[0;36mforward\u001b[0;34m()\u001b[0m\n",
210
+ "\u001b[0;32m 39 \u001b[0;31m \u001b[0mb\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0mspatial\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
211
+ "\u001b[0m\u001b[0;32m 40 \u001b[0;31m \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreshape\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mb\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
212
+ "\u001b[0m\u001b[0;32m---> 41 \u001b[0;31m \u001b[0mqkv\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mqkv\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnorm\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
213
+ "\u001b[0m\u001b[0;32m 42 \u001b[0;31m \u001b[0mh\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mattention\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mqkv\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
214
+ "\u001b[0m\u001b[0;32m 43 \u001b[0;31m \u001b[0mh\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mproj_out\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mh\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
215
+ "\u001b[0m\n"
216
+ ]
217
+ },
218
+ {
219
+ "name": "stdin",
220
+ "output_type": "stream",
221
+ "text": [
222
+ "ipdb> x.shape\n"
223
+ ]
224
+ },
225
+ {
226
+ "name": "stdout",
227
+ "output_type": "stream",
228
+ "text": [
229
+ "torch.Size([5, 32, 16384])\n"
230
+ ]
231
+ },
232
+ {
233
+ "name": "stdin",
234
+ "output_type": "stream",
235
+ "text": [
236
+ "ipdb> t = self.norm(x)\n",
237
+ "ipdb> t.shape\n"
238
+ ]
239
+ },
240
+ {
241
+ "name": "stdout",
242
+ "output_type": "stream",
243
+ "text": [
244
+ "torch.Size([5, 32, 16384])\n"
245
+ ]
246
+ },
247
+ {
248
+ "name": "stdin",
249
+ "output_type": "stream",
250
+ "text": [
251
+ "ipdb> self.qkv\n"
252
+ ]
253
+ },
254
+ {
255
+ "name": "stdout",
256
+ "output_type": "stream",
257
+ "text": [
258
+ "Conv1d(32, 96, kernel_size=(1,), stride=(1,))\n"
259
+ ]
260
+ },
261
+ {
262
+ "name": "stdin",
263
+ "output_type": "stream",
264
+ "text": [
265
+ "ipdb> n\n"
266
+ ]
267
+ },
268
+ {
269
+ "name": "stdout",
270
+ "output_type": "stream",
271
+ "text": [
272
+ "> \u001b[0;32m/tmp/ipykernel_456404/3277534714.py\u001b[0m(42)\u001b[0;36mforward\u001b[0;34m()\u001b[0m\n",
273
+ "\u001b[0;32m 40 \u001b[0;31m \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreshape\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mb\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
274
+ "\u001b[0m\u001b[0;32m 41 \u001b[0;31m \u001b[0mqkv\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mqkv\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnorm\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
275
+ "\u001b[0m\u001b[0;32m---> 42 \u001b[0;31m \u001b[0mh\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mattention\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mqkv\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
276
+ "\u001b[0m\u001b[0;32m 43 \u001b[0;31m \u001b[0mh\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mproj_out\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mh\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
277
+ "\u001b[0m\u001b[0;32m 44 \u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mh\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreshape\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mb\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0mspatial\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
278
+ "\u001b[0m\n"
279
+ ]
280
+ },
281
+ {
282
+ "name": "stdin",
283
+ "output_type": "stream",
284
+ "text": [
285
+ "ipdb> qkv.shape\n"
286
+ ]
287
+ },
288
+ {
289
+ "name": "stdout",
290
+ "output_type": "stream",
291
+ "text": [
292
+ "torch.Size([5, 96, 16384])\n"
293
+ ]
294
+ },
295
+ {
296
+ "name": "stdin",
297
+ "output_type": "stream",
298
+ "text": [
299
+ "ipdb> t.shape\n"
300
+ ]
301
+ },
302
+ {
303
+ "name": "stdout",
304
+ "output_type": "stream",
305
+ "text": [
306
+ "torch.Size([5, 32, 16384])\n"
307
+ ]
308
+ },
309
+ {
310
+ "name": "stdin",
311
+ "output_type": "stream",
312
+ "text": [
313
+ "ipdb> n\n"
314
+ ]
315
+ },
316
+ {
317
+ "name": "stdout",
318
+ "output_type": "stream",
319
+ "text": [
320
+ "> \u001b[0;32m/tmp/ipykernel_456404/3277534714.py\u001b[0m(43)\u001b[0;36mforward\u001b[0;34m()\u001b[0m\n",
321
+ "\u001b[0;32m 40 \u001b[0;31m \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreshape\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mb\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
322
+ "\u001b[0m\u001b[0;32m 41 \u001b[0;31m \u001b[0mqkv\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mqkv\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnorm\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
323
+ "\u001b[0m\u001b[0;32m 42 \u001b[0;31m \u001b[0mh\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mattention\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mqkv\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
324
+ "\u001b[0m\u001b[0;32m---> 43 \u001b[0;31m \u001b[0mh\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mproj_out\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mh\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
325
+ "\u001b[0m\u001b[0;32m 44 \u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mh\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreshape\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mb\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0mspatial\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
326
+ "\u001b[0m\n"
327
+ ]
328
+ },
329
+ {
330
+ "name": "stdin",
331
+ "output_type": "stream",
332
+ "text": [
333
+ "ipdb> h.shape\n"
334
+ ]
335
+ },
336
+ {
337
+ "name": "stdout",
338
+ "output_type": "stream",
339
+ "text": [
340
+ "*** No help for '.shape'\n"
341
+ ]
342
+ },
343
+ {
344
+ "name": "stdin",
345
+ "output_type": "stream",
346
+ "text": [
347
+ "ipdb> h.shape\n"
348
+ ]
349
+ },
350
+ {
351
+ "name": "stdout",
352
+ "output_type": "stream",
353
+ "text": [
354
+ "*** No help for '.shape'\n"
355
+ ]
356
+ },
357
+ {
358
+ "name": "stdin",
359
+ "output_type": "stream",
360
+ "text": [
361
+ "ipdb> print(h.shape)\n"
362
+ ]
363
+ },
364
+ {
365
+ "name": "stdout",
366
+ "output_type": "stream",
367
+ "text": [
368
+ "torch.Size([5, 32, 16384])\n"
369
+ ]
370
+ },
371
+ {
372
+ "name": "stdin",
373
+ "output_type": "stream",
374
+ "text": [
375
+ "ipdb> self.proj_out\n"
376
+ ]
377
+ },
378
+ {
379
+ "name": "stdout",
380
+ "output_type": "stream",
381
+ "text": [
382
+ "Conv1d(32, 32, kernel_size=(1,), stride=(1,))\n"
383
+ ]
384
+ },
385
+ {
386
+ "name": "stdin",
387
+ "output_type": "stream",
388
+ "text": [
389
+ "ipdb> n\n"
390
+ ]
391
+ },
392
+ {
393
+ "name": "stdout",
394
+ "output_type": "stream",
395
+ "text": [
396
+ "> \u001b[0;32m/tmp/ipykernel_456404/3277534714.py\u001b[0m(44)\u001b[0;36mforward\u001b[0;34m()\u001b[0m\n",
397
+ "\u001b[0;32m 40 \u001b[0;31m \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreshape\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mb\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
398
+ "\u001b[0m\u001b[0;32m 41 \u001b[0;31m \u001b[0mqkv\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mqkv\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnorm\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
399
+ "\u001b[0m\u001b[0;32m 42 \u001b[0;31m \u001b[0mh\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mattention\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mqkv\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
400
+ "\u001b[0m\u001b[0;32m 43 \u001b[0;31m \u001b[0mh\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mproj_out\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mh\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
401
+ "\u001b[0m\u001b[0;32m---> 44 \u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mh\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreshape\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mb\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0mspatial\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
402
+ "\u001b[0m\n"
403
+ ]
404
+ },
405
+ {
406
+ "name": "stdin",
407
+ "output_type": "stream",
408
+ "text": [
409
+ "ipdb> \n"
410
+ ]
411
+ },
412
+ {
413
+ "name": "stdout",
414
+ "output_type": "stream",
415
+ "text": [
416
+ "--Return--\n",
417
+ "tensor([[[[ 1...iasBackward0>)\n",
418
+ "> \u001b[0;32m/tmp/ipykernel_456404/3277534714.py\u001b[0m(44)\u001b[0;36mforward\u001b[0;34m()\u001b[0m\n",
419
+ "\u001b[0;32m 40 \u001b[0;31m \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreshape\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mb\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
420
+ "\u001b[0m\u001b[0;32m 41 \u001b[0;31m \u001b[0mqkv\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mqkv\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnorm\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
421
+ "\u001b[0m\u001b[0;32m 42 \u001b[0;31m \u001b[0mh\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mattention\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mqkv\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
422
+ "\u001b[0m\u001b[0;32m 43 \u001b[0;31m \u001b[0mh\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mproj_out\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mh\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
423
+ "\u001b[0m\u001b[0;32m---> 44 \u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mh\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreshape\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mb\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0mspatial\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
424
+ "\u001b[0m\n"
425
+ ]
426
+ },
427
+ {
428
+ "name": "stdin",
429
+ "output_type": "stream",
430
+ "text": [
431
+ "ipdb> q\n"
432
+ ]
433
+ },
434
+ {
435
+ "ename": "BdbQuit",
436
+ "evalue": "",
437
+ "output_type": "error",
438
+ "traceback": [
439
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
440
+ "\u001b[0;31mBdbQuit\u001b[0m Traceback (most recent call last)",
441
+ "\u001b[0;32m/tmp/ipykernel_456404/1120562961.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0mmodel\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mAttentionBlock\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m32\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 5\u001b[0;31m \u001b[0my\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtest_input\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
442
+ "\u001b[0;32m~/anaconda3/envs/py38/lib/python3.8/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1100\u001b[0m if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[1;32m 1101\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1102\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1103\u001b[0m \u001b[0;31m# Do not call functions when jit is used\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1104\u001b[0m \u001b[0mfull_backward_hooks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
443
+ "\u001b[0;32m/tmp/ipykernel_456404/3277534714.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m 42\u001b[0m \u001b[0mh\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mattention\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mqkv\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 43\u001b[0m \u001b[0mh\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mproj_out\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mh\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 44\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mh\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreshape\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mb\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0mspatial\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
444
+ "\u001b[0;32m~/anaconda3/envs/py38/lib/python3.8/bdb.py\u001b[0m in \u001b[0;36mtrace_dispatch\u001b[0;34m(self, frame, event, arg)\u001b[0m\n\u001b[1;32m 90\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdispatch_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mframe\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0marg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 91\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mevent\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;34m'return'\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 92\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdispatch_return\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mframe\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0marg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 93\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mevent\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;34m'exception'\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 94\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdispatch_exception\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mframe\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0marg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
445
+ "\u001b[0;32m~/anaconda3/envs/py38/lib/python3.8/bdb.py\u001b[0m in \u001b[0;36mdispatch_return\u001b[0;34m(self, frame, arg)\u001b[0m\n\u001b[1;32m 152\u001b[0m \u001b[0;32mfinally\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 153\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mframe_returning\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 154\u001b[0;31m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mquitting\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;32mraise\u001b[0m \u001b[0mBdbQuit\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 155\u001b[0m \u001b[0;31m# The user issued a 'next' or 'until' command.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 156\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstopframe\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0mframe\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstoplineno\u001b[0m \u001b[0;34m!=\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
446
+ "\u001b[0;31mBdbQuit\u001b[0m: "
447
+ ]
448
+ }
449
+ ],
450
+ "source": [
451
+ "test_input = torch.randn(5, 32, 128, 128)\n",
452
+ "\n",
453
+ "model = AttentionBlock(32, 1)\n",
454
+ "\n",
455
+ "y = model(test_input)"
456
+ ]
457
+ },
458
+ {
459
+ "cell_type": "code",
460
+ "execution_count": 36,
461
+ "id": "3109500e-146d-46c4-8709-6a1e8d24e4ac",
462
+ "metadata": {},
463
+ "outputs": [
464
+ {
465
+ "data": {
466
+ "text/plain": [
467
+ "torch.Size([5, 32, 128, 128])"
468
+ ]
469
+ },
470
+ "execution_count": 36,
471
+ "metadata": {},
472
+ "output_type": "execute_result"
473
+ }
474
+ ],
475
+ "source": [
476
+ "y.shape"
477
+ ]
478
+ },
479
+ {
480
+ "cell_type": "code",
481
+ "execution_count": null,
482
+ "id": "0c916f9c-5dba-499d-99ea-e56f2855c9cc",
483
+ "metadata": {},
484
+ "outputs": [],
485
+ "source": []
486
+ }
487
+ ],
488
+ "metadata": {
489
+ "kernelspec": {
490
+ "display_name": "Python 3 (ipykernel)",
491
+ "language": "python",
492
+ "name": "python3"
493
+ },
494
+ "language_info": {
495
+ "codemirror_mode": {
496
+ "name": "ipython",
497
+ "version": 3
498
+ },
499
+ "file_extension": ".py",
500
+ "mimetype": "text/x-python",
501
+ "name": "python",
502
+ "nbconvert_exporter": "python",
503
+ "pygments_lexer": "ipython3",
504
+ "version": "3.8.12"
505
+ }
506
+ },
507
+ "nbformat": 4,
508
+ "nbformat_minor": 5
509
+ }
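
The notebook above is a scratch debugging session (the ipdb trace steps through AttentionBlock.forward). The same shape check reads more cleanly as a plain script once the pdb.set_trace() line is dropped — a minimal sketch, assuming the GroupNorm32 / conv_nd / QKVAttentionLegacy / AttentionBlock definitions from the notebook cells are in scope:

import torch

test_input = torch.randn(5, 32, 128, 128)
model = AttentionBlock(32, num_heads=1)   # 32 channels, single attention head

y = model(test_input)
print(y.shape)   # torch.Size([5, 32, 128, 128]); the block is residual, so the shape is preserved
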
models/Attention_SSN.py ADDED
@@ -0,0 +1,218 @@
+ from abc import abstractmethod
+ from functools import partial
+ from typing import Iterable
+ import math
+
+ import numpy as np
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from .SSN import Conv, Conv2DMod, Decoder, Up
+ from .attention import AttentionBlock
+ from .blocks import ResBlock, Res_Type, get_activation
+
+
+ class Attention_Encoder(nn.Module):
+     def __init__(self, in_channels=3, mid_act='gelu', dropout=0.0, num_heads=8, resnet=True):
+         super(Attention_Encoder, self).__init__()
+
+         self.in_conv = Conv(in_channels, 32 - in_channels, stride=1, activation=mid_act, resnet=resnet)
+         self.down_32_64 = Conv(32, 64, stride=2, activation=mid_act, resnet=resnet)
+         self.down_64_64_1 = Conv(64, 64, activation=mid_act, resnet=resnet)
+
+         self.down_64_128 = Conv(64, 128, stride=2, activation=mid_act, resnet=resnet)
+         self.down_128_128_1 = Conv(128, 128, activation=mid_act, resnet=resnet)
+
+         self.down_128_256 = Conv(128, 256, stride=2, activation=mid_act, resnet=resnet)
+         self.down_256_256_1 = Conv(256, 256, activation=mid_act, resnet=resnet)
+         self.down_256_256_1_attn = AttentionBlock(256, num_heads)
+
+         self.down_256_512 = Conv(256, 512, stride=2, activation=mid_act, resnet=resnet)
+         self.down_512_512_1 = Conv(512, 512, activation=mid_act, resnet=resnet)
+         self.down_512_512_1_attn = AttentionBlock(512, num_heads)
+
+         self.down_512_512_2 = Conv(512, 512, activation=mid_act, resnet=resnet)
+         self.down_512_512_2_attn = AttentionBlock(512, num_heads)
+
+         self.down_512_512_3 = Conv(512, 512, activation=mid_act, resnet=resnet)
+         self.down_512_512_3_attn = AttentionBlock(512, num_heads)
+
+
+     def forward(self, x):
+         x1 = self.in_conv(x)  # 32 x 256 x 256
+         x1 = torch.cat((x, x1), dim=1)
+
+         x2 = self.down_32_64(x1)
+         x3 = self.down_64_64_1(x2)
+
+         x4 = self.down_64_128(x3)
+         x5 = self.down_128_128_1(x4)
+
+         x6 = self.down_128_256(x5)
+         x7 = self.down_256_256_1(x6)
+         x7 = self.down_256_256_1_attn(x7)
+
+         x8 = self.down_256_512(x7)
+         x9 = self.down_512_512_1(x8)
+         x9 = self.down_512_512_1_attn(x9)
+
+         x10 = self.down_512_512_2(x9)
+         x10 = self.down_512_512_2_attn(x10)
+
+         x11 = self.down_512_512_3(x10)
+         x11 = self.down_512_512_3_attn(x11)
+
+         return x11, x10, x9, x8, x7, x6, x5, x4, x3, x2, x1
+
+
+ class Attention_Decoder(nn.Module):
+     def __init__(self, out_channels=3, mid_act='gelu', out_act='sigmoid', resnet=True, num_heads=8):
+
+         super(Attention_Decoder, self).__init__()
+
+         input_channel = 512
+         fea_dim = 100
+
+         self.to_style1 = nn.Linear(in_features=fea_dim, out_features=input_channel)
+
+         self.up_16_16_1 = Conv(input_channel, 256, activation=mid_act, style=True, resnet=resnet)
+         self.up_16_16_1_attn = AttentionBlock(256, num_heads=num_heads)
+
+         self.up_16_16_2 = Conv(768, 512, activation=mid_act, resnet=resnet)
+         self.up_16_16_2_attn = AttentionBlock(512, num_heads=num_heads)
+
+         self.up_16_16_3 = Conv(1024, 512, activation=mid_act, resnet=resnet)
+         self.up_16_16_3_attn = AttentionBlock(512, num_heads=num_heads)
+
+         self.up_16_32 = Up(1024, 256, activation=mid_act, resnet=resnet)
+         self.to_style2 = nn.Linear(in_features=fea_dim, out_features=512)
+         self.up_32_32_1 = Conv(512, 256, activation=mid_act, style=True, resnet=resnet)
+         self.up_32_32_1_attn = AttentionBlock(256, num_heads=num_heads)
+
+         self.up_32_64 = Up(512, 128, activation=mid_act, resnet=resnet)
+         self.to_style3 = nn.Linear(in_features=fea_dim, out_features=256)
+         self.up_64_64_1 = Conv(256, 128, activation=mid_act, style=True, resnet=resnet)
+
+         self.up_64_128 = Up(256, 64, activation=mid_act, resnet=resnet)
+         self.to_style4 = nn.Linear(in_features=fea_dim, out_features=128)
+         self.up_128_128_1 = Conv(128, 64, activation=mid_act, style=True, resnet=resnet)
+
+         self.up_128_256 = Up(128, 32, activation=mid_act, resnet=resnet)
+         self.out_conv = Conv(64, out_channels, activation=out_act)
+         self.out_act = get_activation(out_act)
+
+
+     def forward(self, x, style):
+         x11, x10, x9, x8, x7, x6, x5, x4, x3, x2, x1 = x
+
+         style1 = self.to_style1(style)
+         y = self.up_16_16_1(x11, style1)  # 256 x 16 x 16
+         y = self.up_16_16_1_attn(y)
+
+         y = torch.cat((x10, y), dim=1)  # 768 x 16 x 16
+         y = self.up_16_16_2(y, y)  # 512 x 16 x 16
+         y = self.up_16_16_2_attn(y)
+
+
+         y = torch.cat((x9, y), dim=1)  # 1024 x 16 x 16
+         y = self.up_16_16_3(y, y)  # 512 x 16 x 16
+         y = self.up_16_16_3_attn(y)
+
+         y = torch.cat((x8, y), dim=1)  # 1024 x 16 x 16
+         y = self.up_16_32(y, y)  # 256 x 32 x 32
+
+         y = torch.cat((x7, y), dim=1)
+         style2 = self.to_style2(style)
+         y = self.up_32_32_1(y, style2)  # 256 x 32 x 32
+         y = self.up_32_32_1_attn(y)
+
+         y = torch.cat((x6, y), dim=1)
+         y = self.up_32_64(y, y)
+
+         y = torch.cat((x5, y), dim=1)
+         style3 = self.to_style3(style)
+
+         y = self.up_64_64_1(y, style3)  # 128 x 64 x 64
+
+         y = torch.cat((x4, y), dim=1)
+         y = self.up_64_128(y, y)
+
+         y = torch.cat((x3, y), dim=1)
+         style4 = self.to_style4(style)
+         y = self.up_128_128_1(y, style4)  # 64 x 128 x 128
+
+         y = torch.cat((x2, y), dim=1)
+         y = self.up_128_256(y, y)  # 32 x 256 x 256
+
+         y = torch.cat((x1, y), dim=1)
+         y = self.out_conv(y, y)  # 3 x 256 x 256
+         y = self.out_act(y)
+         return y
+
+
+
+ class Attention_SSN(nn.Module):
+     def __init__(self, in_channels, out_channels, num_heads=8, resnet=True, mid_act='gelu', out_act='gelu'):
+         super(Attention_SSN, self).__init__()
+         # pass by keyword so num_heads is not silently consumed by the encoder's dropout argument
+         self.encoder = Attention_Encoder(in_channels=in_channels, mid_act=mid_act, num_heads=num_heads, resnet=resnet)
+         self.decoder = Attention_Decoder(out_channels=out_channels, mid_act=mid_act, out_act=out_act, resnet=resnet, num_heads=num_heads)
+
+
+     def forward(self, x, softness):
+         latent = self.encoder(x)
+         pred = self.decoder(latent, softness)
+
+         return pred
+
+
+ def get_model_size(model):
+     param_size = 0
+     # import pdb; pdb.set_trace()
+     for param in model.parameters():
+         param_size += param.nelement() * param.element_size()
+
+     buffer_size = 0
+     for buffer in model.buffers():
+         buffer_size += buffer.nelement() * buffer.element_size()
+
+     size_all_mb = (param_size + buffer_size) / 1024 ** 2
+     print('model size: {:.3f}MB'.format(size_all_mb))
+     # return param_size + buffer_size
+     return size_all_mb
+
+
+ if __name__ == '__main__':
+     model = AttentionBlock(in_channels=256, num_heads=8)
+     x = torch.randn(5, 256, 64, 64)
+
+     y = model(x)
+     print('{}, {}'.format(x.shape, y.shape))
+
+     # ------------------------------------------------------------------ #
+     in_channels = 3
+     out_channels = 1
+     num_heads = 8
+     resnet = True
+     mid_act = 'gelu'
+     out_act = 'gelu'
+
+     model = Attention_SSN(in_channels=in_channels,
+                           out_channels=out_channels,
+                           num_heads=num_heads,
+                           resnet=resnet,
+                           mid_act=mid_act,
+                           out_act=out_act)
+
+     x = torch.randn(5, 3, 256, 256)
+     softness = torch.randn(5, 100)
+
+
+     y = model(x, softness)
+
+
+     print('x: {}, y: {}'.format(x.shape, y.shape))
+
+     get_model_size(model)
+     # ------------------------------------------------------------------ #
models/Attention_Unet.py ADDED
@@ -0,0 +1,165 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torchvision import transforms
+ import numpy as np
+
+ from .SSN import Conv, Conv2DMod, Decoder, Up
+ from .attention import AttentionBlock
+ from .blocks import ResBlock, Res_Type, get_activation
+
+ class Attention_Encoder(nn.Module):
+     def __init__(self, in_channels=3, mid_act='gelu', dropout=0.0, num_heads=8, resnet=True):
+         super(Attention_Encoder, self).__init__()
+
+         self.in_conv = Conv(in_channels, 32 - in_channels, stride=1, activation=mid_act, resnet=resnet)
+         self.down_32_64 = Conv(32, 64, stride=2, activation=mid_act, resnet=resnet)
+         self.down_64_64_1 = Conv(64, 64, activation=mid_act, resnet=resnet)
+
+         self.down_64_128 = Conv(64, 128, stride=2, activation=mid_act, resnet=resnet)
+         self.down_128_128_1 = Conv(128, 128, activation=mid_act, resnet=resnet)
+
+         self.down_128_256 = Conv(128, 256, stride=2, activation=mid_act, resnet=resnet)
+         self.down_256_256_1 = Conv(256, 256, activation=mid_act, resnet=resnet)
+         self.down_256_256_1_attn = AttentionBlock(256, num_heads)
+
+         self.down_256_512 = Conv(256, 512, stride=2, activation=mid_act, resnet=resnet)
+         self.down_512_512_1 = Conv(512, 512, activation=mid_act, resnet=resnet)
+         self.down_512_512_1_attn = AttentionBlock(512, num_heads)
+
+         self.down_512_512_2 = Conv(512, 512, activation=mid_act, resnet=resnet)
+         self.down_512_512_2_attn = AttentionBlock(512, num_heads)
+
+         self.down_512_512_3 = Conv(512, 512, activation=mid_act, resnet=resnet)
+         self.down_512_512_3_attn = AttentionBlock(512, num_heads)
+
+
+     def forward(self, x):
+         x1 = self.in_conv(x)  # 32 x 256 x 256
+         x1 = torch.cat((x, x1), dim=1)
+
+         x2 = self.down_32_64(x1)
+         x3 = self.down_64_64_1(x2)
+
+         x4 = self.down_64_128(x3)
+         x5 = self.down_128_128_1(x4)
+
+         x6 = self.down_128_256(x5)
+         x7 = self.down_256_256_1(x6)
+         x7 = self.down_256_256_1_attn(x7)
+
+         x8 = self.down_256_512(x7)
+         x9 = self.down_512_512_1(x8)
+         x9 = self.down_512_512_1_attn(x9)
+
+         x10 = self.down_512_512_2(x9)
+         x10 = self.down_512_512_2_attn(x10)
+
+         x11 = self.down_512_512_3(x10)
+         x11 = self.down_512_512_3_attn(x11)
+
+         return x11, x10, x9, x8, x7, x6, x5, x4, x3, x2, x1
+
+
+ class Attention_Decoder(nn.Module):
+     def __init__(self, out_channels=3, mid_act='gelu', out_act='sigmoid', resnet=True, num_heads=8):
+
+         super(Attention_Decoder, self).__init__()
+
+         input_channel = 512
+         fea_dim = 100
+
+         self.to_style1 = nn.Linear(in_features=fea_dim, out_features=input_channel)
+
+         self.up_16_16_1 = Conv(input_channel, 256, activation=mid_act, style=False, resnet=resnet)
+         self.up_16_16_1_attn = AttentionBlock(256, num_heads=num_heads)
+
+         self.up_16_16_2 = Conv(768, 512, activation=mid_act, resnet=resnet)
+         self.up_16_16_2_attn = AttentionBlock(512, num_heads=num_heads)
+
+         self.up_16_16_3 = Conv(1024, 512, activation=mid_act, resnet=resnet)
+         self.up_16_16_3_attn = AttentionBlock(512, num_heads=num_heads)
+
+         self.up_16_32 = Up(1024, 256, activation=mid_act, resnet=resnet)
+         self.to_style2 = nn.Linear(in_features=fea_dim, out_features=512)
+         self.up_32_32_1 = Conv(512, 256, activation=mid_act, style=False, resnet=resnet)
+         self.up_32_32_1_attn = AttentionBlock(256, num_heads=num_heads)
+
+         self.up_32_64 = Up(512, 128, activation=mid_act, resnet=resnet)
+         self.to_style3 = nn.Linear(in_features=fea_dim, out_features=256)
+         self.up_64_64_1 = Conv(256, 128, activation=mid_act, style=False, resnet=resnet)
+
+         self.up_64_128 = Up(256, 64, activation=mid_act, resnet=resnet)
+         self.to_style4 = nn.Linear(in_features=fea_dim, out_features=128)
+         self.up_128_128_1 = Conv(128, 64, activation=mid_act, style=False, resnet=resnet)
+
+         self.up_128_256 = Up(128, 32, activation=mid_act, resnet=resnet)
+         self.out_conv = Conv(64, out_channels, activation=out_act)
+         self.out_act = get_activation(out_act)
+
+
+     def forward(self, x):
+         x11, x10, x9, x8, x7, x6, x5, x4, x3, x2, x1 = x
+
+         y = self.up_16_16_1(x11)  # 256 x 16 x 16
+         y = self.up_16_16_1_attn(y)
+
+         y = torch.cat((x10, y), dim=1)  # 768 x 16 x 16
+         y = self.up_16_16_2(y, y)  # 512 x 16 x 16
+         y = self.up_16_16_2_attn(y)
+
+
+         y = torch.cat((x9, y), dim=1)  # 1024 x 16 x 16
+         y = self.up_16_16_3(y, y)  # 512 x 16 x 16
+         y = self.up_16_16_3_attn(y)
+
+         y = torch.cat((x8, y), dim=1)  # 1024 x 16 x 16
+         y = self.up_16_32(y, y)  # 256 x 32 x 32
+
+         y = torch.cat((x7, y), dim=1)
+         y = self.up_32_32_1(y)  # 256 x 32 x 32
+         y = self.up_32_32_1_attn(y)
+
+         y = torch.cat((x6, y), dim=1)
+         y = self.up_32_64(y, y)
+
+         y = torch.cat((x5, y), dim=1)
+
+         y = self.up_64_64_1(y)  # 128 x 64 x 64
+
+         y = torch.cat((x4, y), dim=1)
+         y = self.up_64_128(y, y)
+
+         y = torch.cat((x3, y), dim=1)
+         y = self.up_128_128_1(y)  # 64 x 128 x 128
+
+         y = torch.cat((x2, y), dim=1)
+         y = self.up_128_256(y, y)  # 32 x 256 x 256
+
+         y = torch.cat((x1, y), dim=1)
+         y = self.out_conv(y, y)  # 3 x 256 x 256
+         y = self.out_act(y)
+         return y
+
+
+ class Attention_Unet(nn.Module):
+     def __init__(self, in_channels, out_channels, num_heads=8, resnet=True, mid_act='gelu', out_act='gelu'):
+         super(Attention_Unet, self).__init__()
+         # pass by keyword so num_heads is not silently consumed by the encoder's dropout argument
+         self.encoder = Attention_Encoder(in_channels=in_channels, mid_act=mid_act, num_heads=num_heads, resnet=resnet)
+         self.decoder = Attention_Decoder(out_channels=out_channels, mid_act=mid_act, out_act=out_act, resnet=resnet, num_heads=num_heads)
+
+
+     def forward(self, x):
+         latent = self.encoder(x)
+         pred = self.decoder(latent)
+         return pred
+
+
+ if __name__ == '__main__':
+     test_input = torch.randn(5, 1, 256, 256)
+     style = torch.randn(5, 100)
+
+     # the original test built an undefined SSN_v1 and passed a style input;
+     # Attention_Unet is defined in this file and takes only the image tensor
+     model = Attention_Unet(1, 1, mid_act='gelu', out_act='gelu', resnet=True)
+     test_out = model(test_input)
+
+     print('Output shape: ', test_out.shape)
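
Unlike Attention_SSN, the Unet variant has no style/softness input, so a forward pass only needs the image tensor. A minimal smoke-test sketch (the input size is illustrative and mirrors the __main__ block above):

import torch
from models.Attention_Unet import Attention_Unet

model = Attention_Unet(in_channels=1, out_channels=1, num_heads=8, resnet=True, mid_act='gelu', out_act='gelu')
x = torch.randn(2, 1, 256, 256)
y = model(x)
print(x.shape, y.shape)   # expected: torch.Size([2, 1, 256, 256]) for both with this 1-in / 1-out config
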
models/GSSN.py ADDED
@@ -0,0 +1,176 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torchvision import utils
5
+ from collections import OrderedDict
6
+ import numpy as np
7
+ import matplotlib.cm as cm
8
+ import matplotlib as mpl
9
+
10
+ from .abs_model import abs_model
11
+ from .blocks import *
12
+ from .SSN import SSN
13
+ from .SSN_v1 import SSN_v1
14
+ from .Loss.Loss import norm_loss
15
+
16
+
17
+ class GSSN(abs_model):
18
+ def __init__(self, opt):
19
+ mid_act = opt['model']['mid_act']
20
+ out_act = opt['model']['out_act']
21
+ in_channels = opt['model']['in_channels']
22
+ out_channels = opt['model']['out_channels']
23
+ resnet = opt['model']['resnet']
24
+ self.ncols = opt['hyper_params']['n_cols']
25
+ self.focal = opt['model']['focal']
26
+
27
+ if 'backbone' not in opt['model'].keys():
28
+ self.model = SSN(in_channels=in_channels,
29
+ out_channels=out_channels,
30
+ mid_act=mid_act,
31
+ out_act=out_act,
32
+ resnet=resnet)
33
+
34
+ else:
35
+ backbone = opt['model']['backbone']
36
+ if backbone == 'vanilla':
37
+ self.model = SSN(in_channels=in_channels,
38
+ out_channels=out_channels,
39
+ mid_act=mid_act,
40
+ out_act=out_act,
41
+ resnet=resnet)
42
+ elif backbone == 'SSN_v1':
43
+ self.model = SSN_v1(in_channels=in_channels,
44
+ out_channels=out_channels,
45
+ mid_act=mid_act,
46
+ out_act=out_act,
47
+ resnet=resnet)
48
+ else:
49
+ raise NotImplementedError('{} has not been implemented yet'.format(backbone))
50
+
51
+
52
+ self.optimizer = get_optimizer(opt, self.model)
53
+ self.visualization = {}
54
+
55
+ self.norm_loss = norm_loss()
56
+
57
+ # inference related
58
+ BINs = 100
59
+ MAX_RAD = 20
60
+ self.size_interval = MAX_RAD / BINs
61
+ self.soft_distribution = [[np.exp(-0.2 * (i - j) ** 2) for i in np.arange(BINs)] for j in np.arange(BINs)]
62
+
63
+ def setup_input(self, x):
64
+ return x
65
+
66
+
67
+ def forward(self, x):
68
+ x, softness = x
69
+ return self.model(x, softness)
70
+
71
+
72
+ def compute_loss(self, y, pred):
73
+ b = y.shape[0]
74
+
75
+ total_loss = self.norm_loss.loss(y, pred)
76
+
77
+ if self.focal:
78
+ total_loss = torch.pow(total_loss, 3)
79
+
80
+ return total_loss
81
+
82
+
83
+ def supervise(self, input_x, y, is_training:bool)->float:
84
+ optimizer = self.optimizer
85
+ model = self.model
86
+
87
+ x, softness = input_x['x'], input_x['softness']
88
+
89
+ optimizer.zero_grad()
90
+ pred = model(x, softness)
91
+ loss = self.compute_loss(y, pred)
92
+
93
+ if is_training:
94
+ loss.backward()
95
+ optimizer.step()
96
+
97
+ xc = x.shape[1]
98
+ for i in range(xc):
99
+ self.visualization['x{}'.format(i)] = x[:, i:i+1].detach()
100
+
101
+ self.visualization['y'] = y.detach()
102
+ self.visualization['pred'] = pred.detach()
103
+
104
+ return loss.item()
105
+
106
+
107
+ def get_visualize(self) -> OrderedDict:
108
+ """ Convert to visualization numpy array
109
+ """
110
+ nrows = self.ncols
111
+ visualizations = self.visualization
112
+ ret_vis = OrderedDict()
113
+
114
+ for k, v in visualizations.items():
115
+ batch = v.shape[0]
116
+ n = min(nrows, batch)
117
+
118
+ plot_v = v[:n]
119
+ ret_vis[k] = np.clip(utils.make_grid(plot_v.cpu(), nrow=nrows).numpy().transpose(1,2,0), 0.0, 1.0)
120
+ ret_vis[k] = self.plasma(ret_vis[k])
121
+
122
+ return ret_vis
123
+
124
+
125
+ def get_logs(self):
126
+ pass
127
+
128
+
129
+ def inference(self, x):
130
+ x, l, device = x['x'], x['l'], x['device']
131
+
132
+ x = torch.from_numpy(x.transpose((2,0,1))).unsqueeze(dim=0).to(device)
133
+ l = torch.from_numpy(np.array(self.soft_distribution[int(l/self.size_interval)]).astype(np.float32)).unsqueeze(dim=0).to(device)
134
+
135
+ pred = self.forward((x, l))
136
+ pred = pred[0].detach().cpu().numpy().transpose((1,2,0))
137
+ return pred
138
+
139
+
140
+ def batch_inference(self, x):
141
+ x, l = x['x'], x['softness']
142
+ pred = self.forward((x, l))
143
+ return pred
144
+
145
+
146
+ """ Getter & Setter
147
+ """
148
+ def get_models(self) -> dict:
149
+ return {'model': self.model}
150
+
151
+
152
+ def get_optimizers(self) -> dict:
153
+ return {'optimizer': self.optimizer}
154
+
155
+
156
+ def set_models(self, models: dict) :
157
+ # input test
158
+ if 'model' not in models.keys():
159
+ raise ValueError('{} not in self.model'.format('model'))
160
+
161
+ self.model = models['model']
162
+
163
+
164
+ def set_optimizers(self, optimizer: dict):
165
+ self.optimizer = optimizer['optimizer']
166
+
167
+
168
+ ####################
169
+ # Personal Methods #
170
+ ####################
171
+ def plasma(self, x):
172
+ norm = mpl.colors.Normalize(vmin=0.0, vmax=1)
173
+ mapper = cm.ScalarMappable(norm=norm, cmap='plasma')
174
+ bimg = mapper.to_rgba(x[:,:,0])[:,:,:3]
175
+
176
+ return bimg
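GSSN.inference above turns the scalar softness l into a 100-bin soft one-hot vector by indexing the Gaussian table built in the constructor; a standalone sketch of that encoding (BINs, MAX_RAD and the 0.2 Gaussian width copied from the code, the example softness value is arbitrary):

import numpy as np

BINS, MAX_RAD = 100, 20
size_interval = MAX_RAD / BINS                   # 0.2 per bin
table = [[np.exp(-0.2 * (i - j) ** 2) for i in np.arange(BINS)] for j in np.arange(BINS)]

softness = 7.3                                   # shadow radius in [0, MAX_RAD)
code = np.array(table[int(softness / size_interval)], dtype=np.float32)
print(code.shape, int(code.argmax()))            # (100,) 36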
models/Loss/Loss.py ADDED
@@ -0,0 +1,271 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import torchvision.transforms as T
5
+ from torch.autograd import Variable
6
+
7
+ import numpy as np
8
+ import cv2
9
+
10
+ # from vgg19_loss import VGG19Loss
11
+ # import pytorch_ssim
12
+
13
+ from .vgg19_loss import VGG19Loss
14
+ from . import pytorch_ssim
15
+ from abc import ABC, abstractmethod
16
+ from collections import OrderedDict
17
+
18
+ class abs_loss(ABC):
19
+ def loss(self, gt_img, pred_img):
20
+ pass
21
+
22
+
23
+ class norm_loss(abs_loss):
24
+ def __init__(self, norm=1):
25
+ self.norm = norm
26
+
27
+
28
+ def loss(self, gt_img, pred_img):
29
+ """ M * (I-I') """
30
+ b, c, h, w = gt_img.shape
31
+ return torch.norm(gt_img-pred_img, self.norm)/(h * w * b)
32
+
33
+
34
+
35
+ class ssim_loss(abs_loss):
36
+ def __init__(self, window_size=11, channel=1):
37
+ """ Let's try mean ssim!
38
+ """
39
+ self.channel = channel
40
+ self.window_size = window_size
41
+ self.window = self.create_mean_window(window_size, channel)
42
+
43
+
44
+ def loss(self, gt_img, pred_img):
45
+ b, c, h, w = gt_img.shape
46
+ if c != self.channel:
47
+ self.channel = c
48
+ self.window = self.create_mean_window(self.window_size, self.channel)
49
+
50
+ self.window = self.window.to(gt_img).type_as(gt_img)
51
+ l = 1.0 - self.ssim_compute(gt_img, pred_img)
52
+ return l
53
+
54
+
55
+ def create_mean_window(self, window_size, channel):
56
+ window = Variable(torch.ones(channel, 1, window_size, window_size).float())
57
+ window = window/(window_size * window_size)
58
+ return window
59
+
60
+
61
+ def ssim_compute(self, gt_img, pred_img):
62
+ window = self.window
63
+ window_size = self.window_size
64
+ channel = self.channel
65
+
66
+ mu1 = F.conv2d(gt_img, window, padding = window_size//2, groups = channel)
67
+ mu2 = F.conv2d(pred_img, window, padding = window_size//2, groups = channel)
68
+
69
+ mu1_sq = mu1.pow(2)
70
+ mu2_sq = mu2.pow(2)
71
+ mu1_mu2 = mu1*mu2
72
+
73
+ sigma1_sq = F.conv2d(gt_img*gt_img, window, padding = window_size//2, groups = channel) - mu1_sq
74
+ sigma2_sq = F.conv2d(pred_img*pred_img, window, padding = window_size//2, groups = channel) - mu2_sq
75
+ sigma12 = F.conv2d(gt_img*pred_img, window, padding = window_size//2, groups = channel) - mu1_mu2
76
+
77
+ C1 = 0.01**2
78
+ C2 = 0.03**2
79
+
80
+ ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))
81
+
82
+ return ssim_map.mean()
83
+
84
+
85
+ class hierarchical_ssim_loss(abs_loss):
86
+ def __init__(self, patch_list: list):
87
+ self.ssim_loss_list = [pytorch_ssim.SSIM(window_size=ws) for ws in patch_list]
88
+
89
+
90
+ def loss(self, gt_img, pred_img):
91
+ b, c, h, w = gt_img.shape
92
+ total_loss = 0.0
93
+ for loss_func in self.ssim_loss_list:
94
+ total_loss += (1.0-loss_func(gt_img, pred_img))
95
+
96
+ return total_loss/b
97
+
98
+
99
+ class vgg_loss(abs_loss):
100
+ def __init__(self):
101
+ self.vgg19_ = VGG19Loss()
102
+
103
+
104
+ def loss(self, gt_img, pred_img):
105
+ b, c, h, w = gt_img.shape
106
+ v = self.vgg19_(gt_img, pred_img, pred_img.device)
107
+ return v/b
108
+
109
+
110
+ class grad_loss(abs_loss):
111
+ def __init__(self, k=4):
112
+ self.k = k
113
+
114
+ def loss(self, disp_img, rgb_img=None):
115
+ """ Note, gradient loss should be weighted by an edge-aware weight
116
+ """
117
+ b, c, h, w = disp_img.shape
118
+
119
+ grad_loss = 0.0
120
+ for i in range(self.k):
121
+ div_factor = 2 ** i
122
+ cur_transform = T.Resize([h // div_factor, ])
123
+ # cur_diff = cur_transform(diff)
124
+ # cur_diff_dx, cur_diff_dy = self.img_grad(cur_diff)
125
+ cur_disp = cur_transform(disp_img)
126
+
127
+ cur_disp_dx, cur_disp_dy = self.img_grad(cur_disp)
128
+
129
+ if rgb_img is not None:
130
+ cur_rgb = cur_transform(rgb_img)
131
+ cur_rgb_dx, cur_rgb_dy = self.img_grad(cur_rgb)
132
+
133
+ cur_rgb_dx = torch.exp(-torch.mean(torch.abs(cur_rgb_dx), dim=1, keepdims=True))
134
+ cur_rgb_dy = torch.exp(-torch.mean(torch.abs(cur_rgb_dy), dim=1, keepdims=True))
135
+ grad_loss += (torch.sum(torch.abs(cur_disp_dx) * cur_rgb_dx) + torch.sum(torch.abs(cur_disp_dy) * cur_rgb_dy)) / (h * w * self.k)
136
+ else:
137
+ grad_loss += (torch.sum(torch.abs(cur_disp_dx)) + torch.sum(torch.abs(cur_disp_dy))) / (h * w * self.k)
138
+
139
+ return grad_loss/b
140
+
141
+
142
+ def gloss(self, gt, pred):
143
+ """ Loss on the gradient domain
144
+ """
145
+ b, c, h, w = gt.shape
146
+ gt_dx, gt_dy = self.img_grad(gt)
147
+ pred_dx, pred_dy = self.img_grad(pred)
148
+
149
+ loss = (gt_dx-pred_dx) ** 2 + (gt_dy - pred_dy) ** 2
150
+ return loss.sum()/(b * h * w)
151
+
152
+
153
+ def laploss(self, pred):
154
+ b, c, h, w = pred.shape
155
+ lap = self.img_laplacian(pred)
156
+
157
+ return torch.abs(lap).sum()/(b * h * w)
158
+
159
+
160
+ def img_laplacian(self, img):
161
+ b, c, h, w = img.shape
162
+ laplacian = torch.tensor([[1, 4, 1], [4, -20, 4], [1, 4, 1]])
163
+
164
+ laplacian_kernel = laplacian.float().unsqueeze(0).expand(1, c, 3, 3).to(img)
165
+
166
+ lap = F.conv2d(img, laplacian_kernel, padding=1, stride=1)
167
+ return lap
168
+
169
+
170
+ def img_grad(self, img):
171
+ """ Comptue image gradient by sobel filtering
172
+ img: B x C x H x W
173
+ """
174
+
175
+ b, c, h, w = img.shape
176
+ ysobel = torch.tensor([[1, 2, 1], [0, 0, 0], [-1, -2, -1]])
177
+ xsobel = ysobel.transpose(0,1)
178
+
179
+ xsobel_kernel = xsobel.float().unsqueeze(0).expand(1, c, 3, 3).to(img)
180
+ ysobel_kernel = ysobel.float().unsqueeze(0).expand(1, c, 3, 3).to(img)
181
+ dx = F.conv2d(img, xsobel_kernel, padding=1, stride=1)
182
+ dy = F.conv2d(img, ysobel_kernel, padding=1, stride=1)
183
+
184
+ return dx, dy
185
+
186
+
187
+
188
+ class sharp_loss(abs_loss):
189
+ """ Sharpness term
190
+ 1. laplacian
191
+ 2. image contrast
192
+ 3. image variance
193
+ """
194
+ def __init__(self, window_size=11, channel=1):
195
+ self.window_size = window_size
196
+ self.channel = channel
197
+ self.window = self.create_mean_window(window_size, self.channel)
198
+
199
+
200
+ def loss(self, gt_img, pred_img):
201
+ """ Note, gradient loss should be weighted by an edge-aware weight
202
+ """
203
+ b, c, h, w = gt_img.shape
204
+
205
+ if c != self.channel:
206
+ self.channel = c
207
+ self.window = self.create_mean_window(self.window_size, self.channel)
208
+
209
+ self.window = self.window.to(gt_img).type_as(gt_img)
210
+
211
+ channel = self.channel
212
+ window = self.window
213
+ window_size = self.window_size
214
+
215
+ mu1 = F.conv2d(gt_img, window, padding = window_size//2, groups = channel) + 1e-6
216
+ mu2 = F.conv2d(pred_img, window, padding = window_size//2, groups = channel) + 1e-6
217
+
218
+ constrast1 = torch.absolute((gt_img - mu1)/mu1)
219
+ constrast2 = torch.absolute((pred_img - mu2)/mu2)
220
+
221
+ variance1 = (gt_img-mu1) ** 2
222
+ variance2 = (pred_img-mu2) ** 2
223
+
224
+ laplacian1 = self.img_laplacian(gt_img)
225
+ laplacian2 = self.img_laplacian(pred_img)
226
+
227
+ S1 = -laplacian1 - constrast1 - variance1
228
+ S2 = -laplacian2 - constrast2 - variance2
229
+
230
+ # import pdb; pdb.set_trace()
231
+ total = torch.absolute(S1-S2).mean()
232
+ return total
233
+
234
+
235
+ def img_laplacian(self, img):
236
+ b, c, h, w = img.shape
237
+ laplacian = torch.tensor([[1, 4, 1], [4, -20, 4], [1, 4, 1]])
238
+
239
+ laplacian_kernel = laplacian.float().unsqueeze(0).expand(1, c, 3, 3).to(img)
240
+
241
+ lap = F.conv2d(img, laplacian_kernel, padding=1, stride=1)
242
+ return lap
243
+
244
+
245
+ def create_mean_window(self, window_size, channel):
246
+ window = Variable(torch.ones(channel, 1, window_size, window_size).float())
247
+ window = window/(window_size * window_size)
248
+ return window
249
+
250
+
251
+ if __name__ == '__main__':
252
+ a = torch.rand(3,3,128,128)
253
+ b = torch.rand(3,3,128,128)
254
+
255
+ ssim = ssim_loss()
256
+ loss = ssim.loss(a, b)
257
+ print(loss.shape, loss)
258
+
259
+ loss = ssim.loss(a, a)
260
+ print(loss.shape, loss)
261
+
262
+ loss = ssim.loss(b, b)
263
+ print(loss.shape, loss)
264
+
265
+ grad = grad_loss()
266
+ loss = grad.loss(a, b)
267
+ print(loss.shape, loss)
268
+
269
+ sharp = sharp_loss()
270
+ loss = sharp.loss(a, b)
271
+ print(loss.shape, loss)
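The edge-aware branch of grad_loss.loss above weights the disparity gradients by exp(-|mean RGB gradient|), so smoothness is only enforced away from image edges; a single-scale sketch of that weighting using the same Sobel kernels as img_grad:

import torch
import torch.nn.functional as F

def sobel_grad(img):
    b, c, h, w = img.shape
    ysobel = torch.tensor([[1., 2., 1.], [0., 0., 0.], [-1., -2., -1.]])
    kx = ysobel.t().expand(1, c, 3, 3).contiguous()
    ky = ysobel.expand(1, c, 3, 3).contiguous()
    return F.conv2d(img, kx, padding=1), F.conv2d(img, ky, padding=1)

disp = torch.rand(1, 1, 32, 32)          # prediction to be smoothed
rgb = torch.rand(1, 3, 32, 32)           # guidance image

disp_dx, disp_dy = sobel_grad(disp)
rgb_dx, rgb_dy = sobel_grad(rgb)

wx = torch.exp(-torch.mean(torch.abs(rgb_dx), dim=1, keepdim=True))
wy = torch.exp(-torch.mean(torch.abs(rgb_dy), dim=1, keepdim=True))
smoothness = (torch.abs(disp_dx) * wx + torch.abs(disp_dy) * wy).sum() / (32 * 32)
print(smoothness)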
models/Loss/__init__.py ADDED
File without changes
models/Loss/__pycache__/Loss.cpython-39.pyc ADDED
Binary file (8.36 kB). View file
 
models/Loss/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (164 Bytes). View file
 
models/Loss/__pycache__/vgg19_loss.cpython-39.pyc ADDED
Binary file (2.08 kB). View file
 
models/Loss/pytorch_ssim/__init__.py ADDED
@@ -0,0 +1,73 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from torch.autograd import Variable
4
+ import numpy as np
5
+ from math import exp
6
+
7
+ def gaussian(window_size, sigma):
8
+ gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
9
+ return gauss/gauss.sum()
10
+
11
+ def create_window(window_size, channel):
12
+ _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
13
+ _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
14
+ window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
15
+ return window
16
+
17
+ def _ssim(img1, img2, window, window_size, channel, size_average = True):
18
+ mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel)
19
+ mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel)
20
+
21
+ mu1_sq = mu1.pow(2)
22
+ mu2_sq = mu2.pow(2)
23
+ mu1_mu2 = mu1*mu2
24
+
25
+ sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq
26
+ sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq
27
+ sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2
28
+
29
+ C1 = 0.01**2
30
+ C2 = 0.03**2
31
+
32
+ ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))
33
+
34
+ if size_average:
35
+ return ssim_map.mean()
36
+ else:
37
+ return ssim_map.mean(1).mean(1).mean(1)
38
+
39
+ class SSIM(torch.nn.Module):
40
+ def __init__(self, window_size = 11, size_average = True):
41
+ super(SSIM, self).__init__()
42
+ self.window_size = window_size
43
+ self.size_average = size_average
44
+ self.channel = 1
45
+ self.window = create_window(window_size, self.channel)
46
+
47
+ def forward(self, img1, img2):
48
+ (_, channel, _, _) = img1.size()
49
+
50
+ if channel == self.channel and self.window.data.type() == img1.data.type():
51
+ window = self.window
52
+ else:
53
+ window = create_window(self.window_size, channel)
54
+
55
+ if img1.is_cuda:
56
+ window = window.cuda(img1.get_device())
57
+ window = window.type_as(img1)
58
+
59
+ self.window = window
60
+ self.channel = channel
61
+
62
+
63
+ return _ssim(img1, img2, window, self.window_size, channel, self.size_average)
64
+
65
+ def ssim(img1, img2, window_size = 11, size_average = True):
66
+ (_, channel, _, _) = img1.size()
67
+ window = create_window(window_size, channel)
68
+
69
+ if img1.is_cuda:
70
+ window = window.cuda(img1.get_device())
71
+ window = window.type_as(img1)
72
+
73
+ return _ssim(img1, img2, window, window_size, channel, size_average)
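A small usage sketch of the SSIM module above (the import path assumes this repo's layout); identical images should score close to 1.0 and random pairs noticeably lower:

import torch
from models.Loss import pytorch_ssim

a = torch.rand(2, 3, 64, 64)
b = torch.rand(2, 3, 64, 64)

ssim_fn = pytorch_ssim.SSIM(window_size=11)
print(ssim_fn(a, a).item())             # ~1.0 for identical inputs
print(pytorch_ssim.ssim(a, b).item())   # noticeably lower for unrelated noise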
models/Loss/pytorch_ssim/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (2.65 kB). View file
 
models/Loss/vgg19_loss.py ADDED
@@ -0,0 +1,54 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torchvision
4
+
5
+ class FeatureExtractor(nn.Module):
6
+ def __init__(self, cnn, feature_layer=11):
7
+ super(FeatureExtractor, self).__init__()
8
+ self.features = nn.Sequential(*list(cnn.features.children())[:(feature_layer + 1)])
9
+
10
+ def normalize(self, tensors, mean, std):
11
+ if not torch.is_tensor(tensors):
12
+ raise TypeError('tensor is not a torch image.')
13
+ for tensor in tensors:
14
+ for t, m, s in zip(tensor, mean, std):
15
+ t.sub_(m).div_(s)
16
+ return tensors
17
+
18
+ def forward(self, x):
19
+ # if the image is grayscale, expand it to 3 channels
20
+ if x.size()[1] == 1:
21
+ x = x.expand(-1, 3, -1, -1)
22
+
23
+ # [-1: 1] image to [0:1] image---------------------------------------------------(1)
24
+ x = (x + 1) * 0.5
25
+
26
+ # https://pytorch.org/docs/stable/torchvision/models.html
27
+ x.data = self.normalize(x.data, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
28
+ return self.features(x)
29
+
30
+ # Feature extracting using vgg19
31
+ vgg19 = torchvision.models.vgg19(pretrained=True)
32
+ feature_extractor = FeatureExtractor(vgg19, feature_layer=35)
33
+ feature_extractor.eval()
34
+
35
+ class VGG19Loss(object):
36
+ def __init__(self):
37
+ global feature_extractor
38
+ self.initialized = False
39
+ self.feature_extractor = feature_extractor
40
+ self.MSE = nn.MSELoss()
41
+
42
+ def __call__(self, output, target, device):
43
+ if self.initialized == False:
44
+ self.feature_extractor = self.feature_extractor.to(device)
45
+ self.MSE = self.MSE.to(device)
46
+ self.initialized = True
47
+
48
+ # [-1: 1] image to [0:1] image---------------------------------------------------(2)
49
+ output = (output + 1) * 0.5
50
+ target = (target + 1) * 0.5
51
+
52
+ output = self.feature_extractor(output)
53
+ target = self.feature_extractor(target).data
54
+ return self.MSE(output, target)
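VGG19Loss above compares VGG-19 feature maps rather than raw pixels; a minimal usage sketch (import path assumed from this repo's layout, importing the module downloads the pretrained VGG-19 once, and inputs are expected in [-1, 1]):

import torch
from models.Loss.vgg19_loss import VGG19Loss

device = torch.device('cpu')
criterion = VGG19Loss()

pred = torch.rand(1, 3, 128, 128) * 2 - 1      # [-1, 1] range
target = torch.rand(1, 3, 128, 128) * 2 - 1
print(criterion(pred, target, device).item())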
models/SSN.py ADDED
@@ -0,0 +1,143 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torchvision import utils
5
+ from collections import OrderedDict
6
+ import numpy as np
7
+
8
+ from .abs_model import abs_model
9
+ from .Loss.Loss import norm_loss
10
+ from .blocks import *
11
+ from .SSN_Model import SSN_Model
12
+
13
+
14
+ class SSN(abs_model):
15
+ def __init__(self, opt):
16
+ mid_act = opt['model']['mid_act']
17
+ out_act = opt['model']['out_act']
18
+ in_channels = opt['model']['in_channels']
19
+ out_channels = opt['model']['out_channels']
20
+ self.ncols = opt['hyper_params']['n_cols']
21
+
22
+ self.model = SSN_Model(in_channels=in_channels, out_channels=out_channels, mid_act=mid_act, out_act=out_act)
23
+ self.optimizer = get_optimizer(opt, self.model)
24
+ self.visualization = {}
25
+
26
+ self.norm_loss_ = norm_loss(norm=1)
27
+
28
+ def setup_input(self, x):
29
+ return x
30
+
31
+
32
+ def forward(self, x):
33
+ keys = ['mask', 'ibl']
34
+
35
+ for k in keys:
36
+ assert k in x.keys(), '{} not in input'.format(k)
37
+
38
+ mask = x['mask']
39
+ ibl = x['ibl']
40
+
41
+ return self.model(mask, ibl)
42
+
43
+
44
+ def compute_loss(self, y, pred):
45
+ total_loss = self.norm_loss_.loss(y, pred)
46
+ return total_loss
47
+
48
+
49
+ def supervise(self, input_x, y, is_training:bool)->float:
50
+ optimizer = self.optimizer
51
+ model = self.model
52
+
53
+ optimizer.zero_grad()
54
+ pred = self.forward(input_x)
55
+ loss = self.compute_loss(y, pred)
56
+
57
+ # logging.info('Pred/Target: {}, {}/{}, {}'.format(pred.min().item(), pred.max().item(), y.min().item(), y.max().item()))
58
+
59
+ if is_training:
60
+ loss.backward()
61
+ optimizer.step()
62
+
63
+ self.visualization['mask'] = input_x['mask'].detach()
64
+ self.visualization['ibl'] = input_x['ibl'].detach()
65
+ self.visualization['y'] = y.detach()
66
+ self.visualization['pred'] = pred.detach()
67
+
68
+ return loss.item()
69
+
70
+
71
+ def get_visualize(self) -> OrderedDict:
72
+ """ Convert to visualization numpy array
73
+ """
74
+ nrows = self.ncols
75
+ visualizations = self.visualization
76
+ ret_vis = OrderedDict()
77
+
78
+ for k, v in visualizations.items():
79
+ batch = v.shape[0]
80
+ n = min(nrows, batch)
81
+
82
+ plot_v = v[:n]
83
+ plot_v = (plot_v - plot_v.min())/(plot_v.max() - plot_v.min())
84
+ ret_vis[k] = np.clip(utils.make_grid(plot_v.cpu(), nrow=nrows).numpy().transpose(1,2,0), 0.0, 1.0)
85
+
86
+ return ret_vis
87
+
88
+
89
+ def get_logs(self):
90
+ pass
91
+
92
+
93
+ def inference(self, x):
94
+ keys = ['mask', 'ibl']
95
+ for k in keys:
96
+ assert k in x.keys(), '{} not in input'.format(k)
97
+ assert len(x[k].shape) == 2, '{} should be 2D tensor'.format(k)
98
+
99
+
100
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
101
+
102
+ mask = torch.tensor(x['mask'])[None, None, ...].float().to(device)
103
+ ibl = torch.tensor(x['ibl'])[None, None, ...].float().to(device)
104
+
105
+ input_x = {'mask': mask, 'ibl': ibl}
106
+ pred = self.forward(input_x)
107
+
108
+ pred = np.clip(pred[0, 0].detach().cpu().numpy() / 30.0, 0.0, 1.0)
109
+ return pred
110
+
111
+
112
+
113
+ def batch_inference(self, x):
114
+ # TODO
115
+ pass
116
+
117
+
118
+ """ Getter & Setter
119
+ """
120
+ def get_models(self) -> dict:
121
+ return {'model': self.model}
122
+
123
+
124
+ def get_optimizers(self) -> dict:
125
+ return {'optimizer': self.optimizer}
126
+
127
+
128
+ def set_models(self, models: dict) :
129
+ # input test
130
+ if 'model' not in models.keys():
131
+ raise ValueError('{} not in self.model'.format('model'))
132
+
133
+ self.model = models['model']
134
+
135
+
136
+ def set_optimizers(self, optimizer: dict):
137
+ self.optimizer = optimizer['optimizer']
138
+
139
+ ####################
140
+ # Personal Methods #
141
+ ####################
142
+
143
+
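The SSN wrapper above is driven entirely by an opt dictionary; a sketch of the keys its constructor and get_optimizer read (values here are illustrative placeholders, the real ones come from configs/SSN.yaml):

opt = {
    'model': {
        'in_channels': 1,
        'out_channels': 1,
        'mid_act': 'gelu',        # illustrative values only
        'out_act': 'gelu',
        'optimizer': 'Adam',
        'beta1': 0.9,
        'weight_decay': 4e-5,
    },
    'hyper_params': {'lr': 1e-3, 'n_cols': 4},
}
# model = SSN(opt)
# shadow = model.inference({'mask': mask_hw, 'ibl': ibl_hw})   # two 2-D numpy arrays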
models/SSN_Model.py ADDED
@@ -0,0 +1,333 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import torch.optim as optim
5
+ import logging
+ import math
6
+
7
+ def weights_init(init_type='gaussian', std=0.02):
8
+ def init_fun(m):
9
+ classname = m.__class__.__name__
10
+ if (classname.find('Conv') == 0 or classname.find(
11
+ 'Linear') == 0) and hasattr(m, 'weight'):
12
+ if init_type == 'gaussian':
13
+ nn.init.normal_(m.weight, 0.0, std)
14
+ elif init_type == 'xavier':
15
+ nn.init.xavier_normal_(m.weight, gain=math.sqrt(2))
16
+ elif init_type == 'kaiming':
17
+ nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in')
18
+ elif init_type == 'orthogonal':
19
+ nn.init.orthogonal_(m.weight, gain=math.sqrt(2))
20
+ elif init_type == 'default':
21
+ pass
22
+ else:
23
+ assert 0, "Unsupported initialization: {}".format(init_type)
24
+ if hasattr(m, 'bias') and m.bias is not None:
25
+ nn.init.constant_(m.bias, 0.0)
26
+
27
+ return init_fun
28
+
29
+ def freeze(module):
30
+ for param in module.parameters():
31
+ param.requires_grad = False
32
+
33
+ def unfreeze(module):
34
+ for param in module.parameters():
35
+ param.requires_grad = True
36
+
37
+ def get_optimizer(opt, model):
38
+ lr = float(opt['hyper_params']['lr'])
39
+ beta1 = float(opt['model']['beta1'])
40
+ weight_decay = float(opt['model']['weight_decay'])
41
+ opt_name = opt['model']['optimizer']
42
+
43
+ optim_params = []
44
+ # weight decay
45
+ for key, value in model.named_parameters():
46
+ if not value.requires_grad:
47
+ continue # frozen weights
48
+
49
+ if key[-4:] == 'bias':
50
+ optim_params += [{'params': value, 'weight_decay': 0.0}]
51
+ else:
52
+ optim_params += [{'params': value,
53
+ 'weight_decay': weight_decay}]
54
+
55
+ if opt_name == 'Adam':
56
+ return optim.Adam(optim_params,
57
+ lr=lr,
58
+ betas=(beta1, 0.999),
59
+ eps=1e-5)
60
+ else:
61
+ err = '{} not implemented yet'.format(opt_name)
62
+ logging.error(err)
63
+ raise NotImplementedError(err)
64
+
65
+
66
+ def get_activation(activation):
67
+ if activation is None:
68
+ return nn.Identity()
69
+
70
+ act_func = {
71
+ 'relu':nn.ReLU(),
72
+ 'sigmoid':nn.Sigmoid(),
73
+ 'tanh':nn.Tanh(),
74
+ 'prelu':nn.PReLU(),
75
+ 'leaky':nn.LeakyReLU(0.2),
76
+ 'gelu':nn.GELU(),
77
+ }
78
+ if activation not in act_func.keys():
79
+ logging.error("activation {} is not implemented yet".format(activation))
80
+ assert False
81
+
82
+ return act_func[activation]
83
+
84
+ def get_norm(out_channels, norm_type='Instance'):
85
+ norm_set = ['Instance', 'Batch', 'Group']
86
+ if norm_type not in norm_set:
87
+ err = "Normalization {} has not been implemented yet"
88
+ logging.error(err)
89
+ raise ValueError(err)
90
+
91
+ if norm_type == 'Instance':
92
+ return nn.InstanceNorm2d(out_channels, affine=True)
93
+
94
+ if norm_type == 'Batch':
95
+ return nn.BatchNorm2d(out_channels)
96
+
97
+ if norm_type == 'Group':
98
+ if out_channels >= 32:
99
+ groups = 32
100
+ else:
101
+ groups = 1
102
+
103
+ return nn.GroupNorm(groups, out_channels)
104
+
105
+ else:
106
+ raise NotImplementedError('{} has not been implemented yet'.format(norm_type))
107
+
108
+
109
+
110
+ def get_layer_info(out_channels, activation_func='relu'):
111
+ activation = get_activation(activation_func)
112
+ norm_layer = get_norm(out_channels, 'Group')
113
+ return norm_layer, activation
114
+
115
+
116
+ class Conv(nn.Module):
117
+ """ (convolution => [BN] => ReLU) """
118
+ def __init__(self,
119
+ in_channels,
120
+ out_channels,
121
+ kernel_size=3,
122
+ stride=1,
123
+ padding=1,
124
+ bias=True,
125
+ activation='leaky',
126
+ resnet=True):
127
+ super().__init__()
128
+
129
+ norm_layer, act_func = get_layer_info(out_channels,activation)
130
+
131
+ if resnet and in_channels == out_channels:
132
+ self.resnet = True
133
+ else:
134
+ self.resnet = False
135
+
136
+ self.conv = nn.Sequential(
137
+ nn.Conv2d(in_channels, out_channels, stride=stride, kernel_size=kernel_size, padding=padding, bias=bias),
138
+ norm_layer,
139
+ act_func)
140
+
141
+ def forward(self, x):
142
+ res = self.conv(x)
143
+
144
+ if self.resnet:
145
+ res = res + x
146
+
147
+ return res
148
+
149
+
150
+
151
+ class Up(nn.Module):
152
+ """ Upscaling then conv """
153
+
154
+ def __init__(self, in_channels, out_channels, activation='relu', resnet=True):
155
+ super().__init__()
156
+
157
+ self.up_layer = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
158
+ self.up = Conv(in_channels, out_channels, activation=activation, resnet=resnet)
159
+
160
+ def forward(self, x):
161
+ x = self.up_layer(x)
162
+ return self.up(x)
163
+
164
+
165
+
166
+ class DConv(nn.Module):
167
+ """ Double Conv Layer
168
+ """
169
+ def __init__(self, in_channels, out_channels, activation='relu', resnet=True):
170
+ super().__init__()
171
+
172
+ self.conv1 = Conv(in_channels, out_channels, activation=activation, resnet=resnet)
173
+ self.conv2 = Conv(out_channels, out_channels, activation=activation, resnet=resnet)
174
+
175
+ def forward(self, x):
176
+ return self.conv2(self.conv1(x))
177
+
178
+
179
+ class Encoder(nn.Module):
180
+ def __init__(self, in_channels=3, mid_act='leaky', resnet=True):
181
+ super(Encoder, self).__init__()
182
+ self.in_conv = Conv(in_channels, 32-in_channels, stride=1, activation=mid_act, resnet=resnet)
183
+ self.down_32_64 = Conv(32, 64, stride=2, activation=mid_act, resnet=resnet)
184
+ self.down_64_64_1 = Conv(64, 64, activation=mid_act, resnet=resnet)
185
+ self.down_64_128 = Conv(64, 128, stride=2, activation=mid_act, resnet=resnet)
186
+ self.down_128_128_1 = Conv(128, 128, activation=mid_act, resnet=resnet)
187
+ self.down_128_256 = Conv(128, 256, stride=2, activation=mid_act, resnet=resnet)
188
+ self.down_256_256_1 = Conv(256, 256, activation=mid_act, resnet=resnet)
189
+ self.down_256_512 = Conv(256, 512, stride=2, activation=mid_act, resnet=resnet)
190
+ self.down_512_512_1 = Conv(512, 512, activation=mid_act, resnet=resnet)
191
+ self.down_512_512_2 = Conv(512, 512, activation=mid_act, resnet=resnet)
192
+ self.down_512_512_3 = Conv(512, 512, activation=mid_act, resnet=resnet)
193
+
194
+
195
+ def forward(self, x):
196
+ x1 = self.in_conv(x) # 32 x 256 x 256
197
+ x1 = torch.cat((x, x1), dim=1)
198
+
199
+ x2 = self.down_32_64(x1)
200
+ x3 = self.down_64_64_1(x2)
201
+
202
+ x4 = self.down_64_128(x3)
203
+ x5 = self.down_128_128_1(x4)
204
+
205
+ x6 = self.down_128_256(x5)
206
+ x7 = self.down_256_256_1(x6)
207
+
208
+ x8 = self.down_256_512(x7)
209
+ x9 = self.down_512_512_1(x8)
210
+ x10 = self.down_512_512_2(x9)
211
+ x11 = self.down_512_512_3(x10)
212
+
213
+ return x11, x10, x9, x8, x7, x6, x5, x4, x3, x2, x1
214
+
215
+
216
+ class Decoder(nn.Module):
217
+ """ Up Stream Sequence """
218
+
219
+ def __init__(self,
220
+ out_channels=3,
221
+ mid_act='relu',
222
+ out_act='sigmoid',
223
+ resnet = True):
224
+
225
+ super(Decoder, self).__init__()
226
+
227
+ input_channel = 512
228
+ fea_dim = 100
229
+
230
+
231
+ self.up_16_16_1 = Conv(input_channel, 256, activation=mid_act, resnet=resnet)
232
+ self.up_16_16_2 = Conv(768, 512, activation=mid_act, resnet=resnet)
233
+ self.up_16_16_3 = Conv(1024, 512, activation=mid_act, resnet=resnet)
234
+
235
+ self.up_16_32 = Up(1024, 256, activation=mid_act, resnet=resnet)
236
+ self.up_32_32_1 = Conv(512, 256, activation=mid_act, resnet=resnet)
237
+
238
+ self.up_32_64 = Up(512, 128, activation=mid_act, resnet=resnet)
239
+ self.up_64_64_1 = Conv(256, 128, activation=mid_act, resnet=resnet)
240
+
241
+ self.up_64_128 = Up(256, 64, activation=mid_act, resnet=resnet)
242
+ self.up_128_128_1 = Conv(128, 64, activation=mid_act, resnet=resnet)
243
+
244
+ self.up_128_256 = Up(128, 32, activation=mid_act, resnet=resnet)
245
+ self.out_conv = Conv(64, out_channels, activation=out_act)
246
+
247
+
248
+ def forward(self, x, ibl):
249
+ x11, x10, x9, x8, x7, x6, x5, x4, x3, x2, x1 = x
250
+
251
+ h,w = x10.shape[2:]
252
+ y = ibl.view(-1, 512, 1, 1).repeat(1, 1, h, w)
253
+
254
+ y = self.up_16_16_1(y) # 256 x 16 x 16
255
+
256
+ y = torch.cat((x10, y), dim=1) # 768 x 16 x 16
257
+ y = self.up_16_16_2(y) # 512 x 16 x 16
258
+
259
+
260
+ y = torch.cat((x9, y), dim=1) # 1024 x 16 x 16
261
+ y = self.up_16_16_3(y) # 512 x 16 x 16
262
+
263
+ y = torch.cat((x8, y), dim=1) # 1024 x 16 x 16
264
+ y = self.up_16_32(y) # 256 x 32 x 32
265
+
266
+ y = torch.cat((x7, y), dim=1)
267
+ y = self.up_32_32_1(y) # 256 x 32 x 32
268
+
269
+ y = torch.cat((x6, y), dim=1)
270
+ y = self.up_32_64(y)
271
+
272
+ y = torch.cat((x5, y), dim=1)
273
+ y = self.up_64_64_1(y) # 128 x 64 x 64
274
+
275
+ y = torch.cat((x4, y), dim=1)
276
+ y = self.up_64_128(y)
277
+
278
+ y = torch.cat((x3, y), dim=1)
279
+ y = self.up_128_128_1(y) # 64 x 128 x 128
280
+
281
+ y = torch.cat((x2, y), dim=1)
282
+ y = self.up_128_256(y) # 32 x 256 x 256
283
+
284
+ y = torch.cat((x1, y), dim=1)
285
+ y = self.out_conv(y) # 3 x 256 x 256
286
+
287
+ return y
288
+
289
+
290
+ class SSN_Model(nn.Module):
291
+ """ Implementation of Relighting Net """
292
+
293
+ def __init__(self,
294
+ in_channels=3,
295
+ out_channels=3,
296
+ mid_act='leaky',
297
+ out_act='sigmoid',
298
+ resnet=True):
299
+ super(SSN_Model, self).__init__()
300
+
301
+ self.out_act = out_act
302
+
303
+ self.encoder = Encoder(in_channels, mid_act=mid_act, resnet=resnet)
304
+ self.decoder = Decoder(out_channels, mid_act=mid_act, out_act=out_act, resnet=resnet)
305
+
306
+ # init weights
307
+ init_func = weights_init('gaussian', std=1e-3)
308
+ self.encoder.apply(init_func)
309
+ self.decoder.apply(init_func)
310
+
311
+
312
+ def forward(self, x, ibl):
313
+ """
314
+ Input: input tensor x (e.g. an object mask) and an IBL (image-based lighting) map
315
+ Output: the predicted shadow map
316
+ """
317
+ latent = self.encoder(x)
318
+ pred = self.decoder(latent, ibl)
319
+
320
+ if self.out_act == 'sigmoid':
321
+ pred = pred * 30.0
322
+
323
+ return pred
324
+
325
+
326
+ if __name__ == '__main__':
327
+ x = torch.randn(5,1,256,256)
328
+ ibl = torch.randn(5, 1, 32, 16)
329
+ model = SSN_Model(1,1)
330
+
331
+ y = model(x, ibl)
332
+
333
+ print('Output: ', y.shape)
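The Decoder above flattens the 32x16 IBL map into a 512-vector and tiles it over the 16x16 bottleneck before the first up-convolution; a standalone sketch of that reshape/tile step (shapes taken from the __main__ test):

import torch

ibl = torch.randn(5, 1, 32, 16)        # 32 * 16 = 512 values per sample
h = w = 16                             # spatial size of the deepest encoder feature x10

y = ibl.view(-1, 512, 1, 1).repeat(1, 1, h, w)
print(y.shape)                         # torch.Size([5, 512, 16, 16])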
models/SSN_v1.py ADDED
@@ -0,0 +1,290 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torchvision import transforms
5
+ import numpy as np
6
+
7
+ def get_activation(activation_func):
8
+ act_func = {
9
+ "relu":nn.ReLU(),
10
+ "sigmoid":nn.Sigmoid(),
11
+ "prelu":nn.PReLU(num_parameters=1),
12
+ "leaky_relu": nn.LeakyReLU(negative_slope=0.2, inplace=False),
13
+ "gelu":nn.GELU()
14
+ }
15
+
16
+ if activation_func is None:
17
+ return nn.Identity()
18
+
19
+ if activation_func not in act_func.keys():
20
+ raise ValueError("activation function({}) is not found".format(activation_func))
21
+
22
+ activation = act_func[activation_func]
23
+ return activation
24
+
25
+
26
+ def get_layer_info(out_channels, activation_func='relu'):
27
+ #act_func = {"relu":nn.ReLU(), "sigmoid":nn.Sigmoid(), "prelu":nn.PReLU(num_parameters=out_channels)}
28
+
29
+ # norm_layer = nn.BatchNorm2d(out_channels, momentum=0.9)
30
+ if out_channels >= 32:
31
+ groups = 32
32
+ else:
33
+ groups = 1
34
+
35
+ norm_layer = nn.GroupNorm(groups, out_channels)
36
+ activation = get_activation(activation_func)
37
+ return norm_layer, activation
38
+
39
+
40
+ class Conv(nn.Module):
41
+ """ (convolution => [BN] => ReLU) """
42
+ def __init__(self,
43
+ in_channels,
44
+ out_channels,
45
+ kernel_size=3,
46
+ stride=1,
47
+ padding=1,
48
+ bias=True,
49
+ activation='leaky',
50
+ style=False,
51
+ resnet=True):
52
+ super().__init__()
53
+
54
+ self.style = style
55
+ norm_layer, act_func = get_layer_info(in_channels, activation)
56
+
57
+ if resnet and in_channels == out_channels:
58
+ self.resnet = True
59
+ else:
60
+ self.resnet = False
61
+
62
+ if style:
63
+ self.styleconv = Conv2DMod(in_channels, out_channels, kernel_size)
64
+ self.relu = nn.LeakyReLU(0.2, inplace=True)
65
+ else:
66
+ self.norm = norm_layer
67
+ self.conv = nn.Conv2d(in_channels, out_channels, stride=stride, kernel_size=kernel_size, padding=padding, bias=bias)
68
+ self.act = act_func
69
+
70
+ def forward(self, x, style_fea=None):
71
+ if self.style:
72
+ res = self.styleconv(x, style_fea)
73
+ res = self.relu(res)
74
+ else:
75
+ h = self.conv(self.act(self.norm(x)))
76
+ if self.resnet:
77
+ res = h + x
78
+ else:
79
+ res = h
80
+
81
+ return res
82
+
83
+
84
+ class Conv2DMod(nn.Module):
85
+ def __init__(self, in_chan, out_chan, kernel, demod=True, stride=1, dilation=1, eps=1e-8, **kwargs):
86
+ super().__init__()
87
+ self.filters = out_chan
88
+ self.demod = demod
89
+ self.kernel = kernel
90
+ self.stride = stride
91
+ self.dilation = dilation
92
+ self.weight = nn.Parameter(torch.randn((out_chan, in_chan, kernel, kernel)))
93
+ self.eps = eps
94
+ nn.init.kaiming_normal_(self.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')
95
+
96
+ def _get_same_padding(self, size, kernel, dilation, stride):
97
+ return ((size - 1) * (stride - 1) + dilation * (kernel - 1)) // 2
98
+
99
+ def forward(self, x, y):
100
+ b, c, h, w = x.shape
101
+
102
+ w1 = y[:, None, :, None, None]
103
+ w2 = self.weight[None, :, :, :, :]
104
+ weights = w2 * (w1 + 1)
105
+
106
+ if self.demod:
107
+ d = torch.rsqrt((weights ** 2).sum(dim=(2, 3, 4), keepdim=True) + self.eps)
108
+ weights = weights * d
109
+
110
+ x = x.reshape(1, -1, h, w)
111
+
112
+ _, _, *ws = weights.shape
113
+ weights = weights.reshape(b * self.filters, *ws)
114
+
115
+ padding = self._get_same_padding(h, self.kernel, self.dilation, self.stride)
116
+ x = F.conv2d(x, weights, padding=padding, groups=b)
117
+
118
+ x = x.reshape(-1, self.filters, h, w)
119
+ return x
120
+
121
+
122
+ class Up(nn.Module):
123
+ """ Upscaling then conv """
124
+
125
+ def __init__(self, in_channels, out_channels, activation='relu', resnet=True):
126
+ super().__init__()
127
+ self.up_layer = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
128
+ self.up = Conv(in_channels, out_channels, activation=activation, resnet=resnet)
129
+
130
+ def forward(self, x):
131
+ x = self.up_layer(x)
132
+ return self.up(x)
133
+
134
+
135
+
136
+ class DConv(nn.Module):
137
+ """ Double Conv Layer
138
+ """
139
+ def __init__(self, in_channels, out_channels, activation='relu', resnet=True):
140
+ super().__init__()
141
+
142
+ self.conv1 = Conv(in_channels, out_channels, activation=activation, resnet=resnet)
143
+ self.conv2 = Conv(out_channels, out_channels, activation=activation, resnet=resnet)
144
+
145
+ def forward(self, x):
146
+ return self.conv2(self.conv1(x))
147
+
148
+
149
+ class Encoder(nn.Module):
150
+ def __init__(self, in_channels=3, mid_act='leaky', resnet=True):
151
+ super(Encoder, self).__init__()
152
+ self.in_conv = Conv(in_channels, 32-in_channels, stride=1, activation=mid_act, resnet=resnet)
153
+ self.down_32_64 = Conv(32, 64, stride=2, activation=mid_act, resnet=resnet)
154
+ self.down_64_64_1 = Conv(64, 64, activation=mid_act, resnet=resnet)
155
+ self.down_64_128 = Conv(64, 128, stride=2, activation=mid_act, resnet=resnet)
156
+ self.down_128_128_1 = Conv(128, 128, activation=mid_act, resnet=resnet)
157
+ self.down_128_256 = Conv(128, 256, stride=2, activation=mid_act, resnet=resnet)
158
+ self.down_256_256_1 = Conv(256, 256, activation=mid_act, resnet=resnet)
159
+ self.down_256_512 = Conv(256, 512, stride=2, activation=mid_act, resnet=resnet)
160
+ self.down_512_512_1 = Conv(512, 512, activation=mid_act, resnet=resnet)
161
+ self.down_512_512_2 = Conv(512, 512, activation=mid_act, resnet=resnet)
162
+ self.down_512_512_3 = Conv(512, 512, activation=mid_act, resnet=resnet)
163
+
164
+
165
+ def forward(self, x):
166
+ x1 = self.in_conv(x) # 32 x 256 x 256
167
+ x1 = torch.cat((x, x1), dim=1)
168
+
169
+ x2 = self.down_32_64(x1)
170
+ x3 = self.down_64_64_1(x2)
171
+
172
+ x4 = self.down_64_128(x3)
173
+ x5 = self.down_128_128_1(x4)
174
+
175
+ x6 = self.down_128_256(x5)
176
+ x7 = self.down_256_256_1(x6)
177
+
178
+ x8 = self.down_256_512(x7)
179
+ x9 = self.down_512_512_1(x8)
180
+ x10 = self.down_512_512_2(x9)
181
+ x11 = self.down_512_512_3(x10)
182
+
183
+ return x11, x10, x9, x8, x7, x6, x5, x4, x3, x2, x1
184
+
185
+
186
+ class Decoder(nn.Module):
187
+ def __init__(self,
188
+ out_channels=3,
189
+ mid_act='relu',
190
+ out_act='sigmoid',
191
+ resnet = True):
192
+
193
+ super(Decoder, self).__init__()
194
+
195
+ input_channel = 512
196
+ fea_dim = 100
197
+
198
+ self.to_style1 = nn.Linear(in_features=fea_dim, out_features=input_channel)
199
+
200
+ self.up_16_16_1 = Conv(input_channel, 256, activation=mid_act, resnet=resnet)
201
+ self.up_16_16_2 = Conv(768, 512, activation=mid_act, resnet=resnet)
202
+ self.up_16_16_3 = Conv(1024, 512, activation=mid_act, resnet=resnet)
203
+
204
+ self.up_16_32 = Up(1024, 256, activation=mid_act, resnet=resnet)
205
+ self.up_32_32_1 = Conv(512, 256, activation=mid_act, resnet=resnet)
206
+
207
+ self.up_32_64 = Up(512, 128, activation=mid_act, resnet=resnet)
208
+ self.up_64_64_1 = Conv(256, 128, activation=mid_act, resnet=resnet)
209
+
210
+ self.up_64_128 = Up(256, 64, activation=mid_act, resnet=resnet)
211
+ self.up_128_128_1 = Conv(128, 64, activation=mid_act, resnet=resnet)
212
+
213
+ self.up_128_256 = Up(128, 32, activation=mid_act, resnet=resnet)
214
+ self.out_conv = Conv(64, out_channels, activation=mid_act)
215
+
216
+ self.out_act = get_activation(out_act)
217
+
218
+
219
+ def forward(self, x):
220
+ x11, x10, x9, x8, x7, x6, x5, x4, x3, x2, x1 = x
221
+
222
+ y = self.up_16_16_1(x11)
223
+
224
+ y = torch.cat((x10, y), dim=1)
225
+ y = self.up_16_16_2(y)
226
+
227
+ y = torch.cat((x9, y), dim=1)
228
+ y = self.up_16_16_3(y)
229
+
230
+ y = torch.cat((x8, y), dim=1)
231
+ y = self.up_16_32(y)
232
+
233
+ y = torch.cat((x7, y), dim=1)
234
+ y = self.up_32_32_1(y)
235
+
236
+ y = torch.cat((x6, y), dim=1)
237
+ y = self.up_32_64(y)
238
+
239
+ y = torch.cat((x5, y), dim=1)
240
+ y = self.up_64_64_1(y) # 128 x 64 x 64
241
+
242
+ y = torch.cat((x4, y), dim=1)
243
+ y = self.up_64_128(y)
244
+
245
+ y = torch.cat((x3, y), dim=1)
246
+ y = self.up_128_128_1(y) # 64 x 128 x 128
247
+
248
+ y = torch.cat((x2, y), dim=1)
249
+ y = self.up_128_256(y) # 32 x 256 x 256
250
+
251
+ y = torch.cat((x1, y), dim=1)
252
+ y = self.out_conv(y) # 3 x 256 x 256
253
+ y = self.out_act(y)
254
+
255
+ return y
256
+
257
+
258
+ class SSN_v1(nn.Module):
259
+ """ Implementation of Relighting Net """
260
+
261
+ def __init__(self,
262
+ in_channels=3,
263
+ out_channels=3,
264
+ mid_act='leaky',
265
+ out_act='sigmoid',
266
+ resnet=True):
267
+ super(SSN_v1, self).__init__()
268
+ self.encoder = Encoder(in_channels, mid_act=mid_act, resnet=resnet)
269
+ self.decoder = Decoder(out_channels, mid_act=mid_act, out_act=out_act, resnet=resnet)
270
+
271
+
272
+ def forward(self, x, softness):
273
+ """
274
+ Input: image tensor x and a softness code (the softness is not used by this decoder)
275
+ Output: the decoded prediction
276
+ """
277
+ latent = self.encoder(x)
278
+ pred = self.decoder(latent)
279
+
280
+ return pred
281
+
282
+
283
+ if __name__ == '__main__':
284
+ test_input = torch.randn(5, 1, 256, 256)
285
+ style = torch.randn(5, 100)
286
+
287
+ model = SSN_v1(1, 1, mid_act='gelu', out_act='gelu', resnet=True)
288
+ test_out = model(test_input, style)
289
+
290
+ print('Output shape: ', test_out.shape)
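Conv2DMod above is the StyleGAN2-style modulated convolution: a per-sample style vector scales the kernel, the kernel is re-normalized (demodulated), and the batch is folded into the channel axis so a single grouped conv applies a different kernel to every sample; a shape-level sketch of the same steps:

import torch
import torch.nn.functional as F

B, Cin, Cout, k, H, W = 4, 8, 16, 3, 32, 32
x = torch.randn(B, Cin, H, W)
style = torch.randn(B, Cin)                      # one scale per input channel, per sample
weight = torch.randn(Cout, Cin, k, k)

w = weight[None] * (style[:, None, :, None, None] + 1)                  # B x Cout x Cin x k x k
w = w * torch.rsqrt((w ** 2).sum(dim=(2, 3, 4), keepdim=True) + 1e-8)   # demodulation

out = F.conv2d(x.reshape(1, B * Cin, H, W),
               w.reshape(B * Cout, Cin, k, k),
               padding=k // 2, groups=B)
print(out.reshape(B, Cout, H, W).shape)          # torch.Size([4, 16, 32, 32])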
models/Sparse_PH.py ADDED
@@ -0,0 +1,185 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torchvision import utils
5
+ from torchvision.transforms import Resize
6
+ from collections import OrderedDict
7
+ import numpy as np
8
+ import matplotlib.cm as cm
9
+ import matplotlib as mpl
10
+ from torchvision.transforms import InterpolationMode
11
+
12
+
13
+ from .abs_model import abs_model
14
+ from .blocks import *
15
+ from .SSN import SSN
16
+ from .SSN_v1 import SSN_v1
17
+ from .Loss.Loss import norm_loss, grad_loss
18
+ from .Attention_Unet import Attention_Unet
19
+
20
+ class Sparse_PH(abs_model):
21
+ def __init__(self, opt):
22
+ mid_act = opt['model']['mid_act']
23
+ out_act = opt['model']['out_act']
24
+ in_channels = opt['model']['in_channels']
25
+ out_channels = opt['model']['out_channels']
26
+ resnet = opt['model']['resnet']
27
+ backbone = opt['model']['backbone']
28
+
29
+ self.ncols = opt['hyper_params']['n_cols']
30
+ self.focal = opt['model']['focal']
31
+ self.clip = opt['model']['clip']
32
+
33
+ self.norm_loss_ = opt['model']['norm_loss']
34
+ self.grad_loss_ = opt['model']['grad_loss']
35
+ self.ggrad_loss_ = opt['model']['ggrad_loss']
36
+ self.lap_loss = opt['model']['lap_loss']
37
+
38
+ self.clip_range = opt['dataset']['linear_scale'] + opt['dataset']['linear_offset']
39
+
40
+ if backbone == 'Default':
41
+ self.model = SSN_v1(in_channels=in_channels,
42
+ out_channels=out_channels,
43
+ mid_act=mid_act,
44
+ out_act=out_act,
45
+ resnet=resnet)
46
+ elif backbone == 'ATTN':
47
+ self.model = Attention_Unet(in_channels, out_channels, mid_act=mid_act, out_act=out_act)
48
+
49
+ self.optimizer = get_optimizer(opt, self.model)
50
+ self.visualization = {}
51
+
52
+ self.norm_loss = norm_loss()
53
+ self.grad_loss = grad_loss()
54
+
55
+
56
+ def setup_input(self, x):
57
+ return x
58
+
59
+
60
+ def forward(self, x):
61
+ return self.model(x)
62
+
63
+
64
+ def compute_loss(self, y, pred):
65
+ b = y.shape[0]
66
+
67
+ # total_loss = avg_norm_loss(y, pred)
68
+ nloss = self.norm_loss.loss(y, pred) * self.norm_loss_
69
+ gloss = self.grad_loss.loss(pred) * self.grad_loss_
70
+ ggloss = self.grad_loss.gloss(y, pred) * self.ggrad_loss_
71
+ laploss = self.grad_loss.laploss(pred) * self.lap_loss
72
+
73
+ total_loss = nloss + gloss + ggloss + laploss
74
+
75
+ self.loss_log = {
76
+ 'norm_loss': nloss.item(),
77
+ 'grad_loss': gloss.item(),
78
+ 'grad_l1_loss': ggloss.item(),
79
+ 'lap_loss': laploss.item(),
80
+ }
81
+
82
+
83
+ if self.focal:
84
+ total_loss = torch.pow(total_loss, 3)
85
+
86
+ return total_loss
87
+
88
+
89
+ def supervise(self, input_x, y, is_training:bool)->float:
90
+ optimizer = self.optimizer
91
+ model = self.model
92
+
93
+ x = input_x['x']
94
+
95
+ optimizer.zero_grad()
96
+ pred = self.forward(x)
97
+ if self.clip:
98
+ pred = torch.clip(pred, 0.0, self.clip_range)
99
+
100
+ loss = self.compute_loss(y, pred)
101
+ if is_training:
102
+ loss.backward()
103
+ optimizer.step()
104
+
105
+ xc = x.shape[1]
106
+ for i in range(xc):
107
+ self.visualization['x{}'.format(i)] = x[:, i:i+1].detach()
108
+
109
+ self.visualization['y_fore'] = y[:, 0:1].detach()
110
+ self.visualization['y_back'] = y[:, 1:2].detach()
111
+ self.visualization['pred_fore'] = pred[:, 0:1].detach()
112
+ self.visualization['pred_back'] = pred[:, 1:2].detach()
113
+
114
+ return loss.item()
115
+
116
+
117
+ def get_visualize(self) -> OrderedDict:
118
+ """ Convert to visualization numpy array
119
+ """
120
+ nrows = self.ncols
121
+ visualizations = self.visualization
122
+ ret_vis = OrderedDict()
123
+
124
+ for k, v in visualizations.items():
125
+ batch = v.shape[0]
126
+ n = min(nrows, batch)
127
+
128
+ plot_v = v[:n]
129
+ ret_vis[k] = np.clip(utils.make_grid(plot_v.cpu(), nrow=nrows).numpy().transpose(1,2,0), 0.0, 1.0)
130
+ ret_vis[k] = self.plasma(ret_vis[k])
131
+
132
+ return ret_vis
133
+
134
+
135
+ def get_logs(self):
136
+ return self.loss_log
137
+
138
+
139
+ def inference(self, x):
140
+ x, device = x['x'], x['device']
141
+ x = torch.from_numpy(x.transpose((2,0,1))).unsqueeze(dim=0).float().to(device)
142
+ pred = self.forward(x)
143
+
144
+ pred = pred[0].detach().cpu().numpy().transpose((1,2,0))
145
+
146
+ return pred
147
+
148
+
149
+ def batch_inference(self, x):
150
+ x = x['x']
151
+ pred = self.forward(x)
152
+ return pred
153
+
154
+
155
+ """ Getter & Setter
156
+ """
157
+ def get_models(self) -> dict:
158
+ return {'model': self.model}
159
+
160
+
161
+ def get_optimizers(self) -> dict:
162
+ return {'optimizer': self.optimizer}
163
+
164
+
165
+ def set_models(self, models: dict) :
166
+ # input test
167
+ if 'model' not in models.keys():
168
+ raise ValueError('{} not in self.model'.format('model'))
169
+
170
+ self.model = models['model']
171
+
172
+
173
+ def set_optimizers(self, optimizer: dict):
174
+ self.optimizer = optimizer['optimizer']
175
+
176
+
177
+ ####################
178
+ # Personal Methods #
179
+ ####################
180
+ def plasma(self, x):
181
+ norm = mpl.colors.Normalize(vmin=0.0, vmax=1)
182
+ mapper = cm.ScalarMappable(norm=norm, cmap='plasma')
183
+ bimg = mapper.to_rgba(x[:,:,0])[:,:,:3]
184
+
185
+ return bimg
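Sparse_PH reads a few more options than the other wrappers (backbone selection, loss-term weights, clipping); a sketch of the keys its constructor touches, with illustrative values only:

opt = {
    'model': {
        'backbone': 'Default',       # or 'ATTN' to use Attention_Unet
        'in_channels': 1,
        'out_channels': 2,           # supervise() visualizes channel 0 as fore, 1 as back
        'mid_act': 'gelu',
        'out_act': 'gelu',
        'resnet': True,
        'focal': False,              # if True the total loss is cubed
        'clip': True,
        'norm_loss': 1.0,            # weights of the four loss terms
        'grad_loss': 0.1,
        'ggrad_loss': 0.1,
        'lap_loss': 0.01,
        'optimizer': 'Adam',
        'beta1': 0.9,
        'weight_decay': 4e-5,
    },
    'hyper_params': {'lr': 1e-3, 'n_cols': 4},
    'dataset': {'linear_scale': 1.0, 'linear_offset': 0.0},
}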
models/__init__.py ADDED
@@ -0,0 +1,43 @@
1
+ # SRC: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/__init__.py
2
+ import logging
3
+ import importlib
4
+
5
+ from .abs_model import abs_model
6
+
7
+
8
+ def find_model_using_name(model_name):
9
+ """Import the module "models/[model_name].py".
10
+ In the file, the class whose name matches [model_name] will
11
+ be instantiated. It has to be a subclass of abs_model,
12
+ and it is case-insensitive.
13
+ """
14
+ model_filename = "models." + model_name
15
+ modellib = importlib.import_module(model_filename)
16
+ model = None
17
+
18
+ target_model_name = model_name
19
+ for name, cls in modellib.__dict__.items():
20
+ if name.lower() == target_model_name.lower() \
21
+ and issubclass(cls, abs_model):
22
+ model = cls
23
+
24
+ if model is None:
25
+ err = "In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name)
26
+ logging.error(err)
27
+ exit(0)
28
+
29
+ return model
30
+
31
+
32
+ def create_model(opt):
33
+ """Create a model given the option.
34
+ This function constructs the model wrapper named in the option.
35
+ This is the main interface between this package and 'train.py'/'test.py'
36
+ Example:
37
+ >>> from models import create_model
38
+ >>> model = create_model(opt)
39
+ """
40
+ model = find_model_using_name(opt['model']['name'])
41
+ instance = model(opt)
42
+ logging.info("model [%s] was created" % type(instance).__name__)
43
+ return instance
models/__pycache__/SSN.cpython-39.pyc ADDED
Binary file (4.11 kB). View file
 
models/__pycache__/SSN_Model.cpython-39.pyc ADDED
Binary file (8.96 kB). View file
 
models/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (1.42 kB). View file
 
models/__pycache__/abs_model.cpython-39.pyc ADDED
Binary file (2.14 kB). View file
 
models/__pycache__/blocks.cpython-39.pyc ADDED
Binary file (6.92 kB). View file
 
models/abs_model.py ADDED
@@ -0,0 +1,73 @@
1
+ from abc import ABC, abstractmethod
2
+ from collections import OrderedDict
3
+
4
+ class abs_model(ABC):
5
+ """ Training Related Interface
6
+ """
7
+ @abstractmethod
8
+ def setup_input(self, x):
9
+ pass
10
+
11
+
12
+ @abstractmethod
13
+ def forward(self, x):
14
+ pass
15
+
16
+
17
+ @abstractmethod
18
+ def supervise(self, input_x, y, is_training:bool)->float:
19
+ pass
20
+
21
+
22
+ @abstractmethod
23
+ def get_visualize(self) -> OrderedDict:
24
+ return {}
25
+
26
+
27
+ """ Inference Related Interface
28
+ """
29
+ @abstractmethod
30
+ def inference(self, x):
31
+ pass
32
+
33
+
34
+ @abstractmethod
35
+ def batch_inference(self, x):
36
+ pass
37
+
38
+
39
+ """ Logging/Visualization Related Interface
40
+ """
41
+ @abstractmethod
42
+ def get_logs(self):
43
+ pass
44
+
45
+
46
+ """ Getter & Setter
47
+ """
48
+ @abstractmethod
49
+ def get_models(self) -> dict:
50
+ """ GAN may have two models
51
+ """
52
+ pass
53
+
54
+
55
+ @abstractmethod
56
+ def get_optimizers(self) -> dict:
57
+ """ GAN may have two optimizer
58
+ """
59
+ pass
60
+
61
+
62
+ @abstractmethod
63
+ def set_models(self, models) -> dict:
64
+ """ GAN may have two models
65
+ """
66
+ pass
67
+
68
+
69
+ @abstractmethod
70
+ def set_optimizers(self, optimizers: dict):
71
+ """ GAN may have two optimizer
72
+ """
73
+ pass
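Every wrapper in this package (SSN, GSSN, Sparse_PH) implements the abs_model interface above; a minimal conforming stub, mainly to list the methods a new model has to provide (import path assumed from this repo's layout):

from collections import OrderedDict
from models.abs_model import abs_model

class Dummy(abs_model):
    def setup_input(self, x): return x
    def forward(self, x): return x
    def supervise(self, input_x, y, is_training: bool) -> float: return 0.0
    def get_visualize(self) -> OrderedDict: return OrderedDict()
    def inference(self, x): return x
    def batch_inference(self, x): return x
    def get_logs(self): return {}
    def get_models(self) -> dict: return {}
    def get_optimizers(self) -> dict: return {}
    def set_models(self, models): pass
    def set_optimizers(self, optimizers: dict): pass

Dummy()   # instantiation succeeds only because every abstract method is overridden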
models/attention.py ADDED
@@ -0,0 +1,85 @@
1
+ from inspect import isfunction
2
+ import math
3
+ import torch
4
+ from torch import nn, einsum
5
+ import torch.nn.functional as F
6
+
7
+ from .blocks import get_norm, zero_module
8
+
9
+
10
+ def QKV_Attention(qkv, num_heads):
11
+ """
12
+ Apply QKV attention.
13
+ :param qkv: an [N x (3 * C) x T] tensor of Qs, Ks, and Vs.
14
+ :return: an [N x C x T] tensor after attention.
15
+ """
16
+ B, C, HW = qkv.shape
17
+ if C % 3 != 0:
18
+ raise ValueError('QKV shape is wrong: {}, {}, {}'.format(B, C, HW))
19
+
20
+ split_size = C // (3 * num_heads)
21
+ q, k, v = qkv.chunk(3, dim=1)
22
+ scale = 1.0/math.sqrt(math.sqrt(split_size))
23
+ weight = torch.einsum('bct, bcs->bts',
24
+ (q * scale).view(B * num_heads, split_size, HW),
25
+ (k * scale).view(B * num_heads, split_size, HW))
26
+
27
+ weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
28
+ ret = torch.einsum("bts,bcs->bct", weight, v.reshape(B * num_heads, split_size, HW))
29
+
30
+ return ret.reshape(B, -1, HW)
31
+
32
+
33
+ class AttentionBlock(nn.Module):
34
+ """
35
+ https://github.com/CompVis/latent-diffusion/blob/main/ldm/modules/diffusionmodules/openaimodel.py
36
+ https://github.com/whai362/PVT/blob/a24ba02c249a510581a84f821c26322534b03a10/detection/pvt_v2.py#L57
37
+ """
38
+
39
+ def __init__(self, in_channels, num_heads, qkv_bias=False, sr_ratio=1, linear=True):
40
+ super().__init__()
41
+
42
+ self.num_heads = num_heads
43
+ self.norm = get_norm(in_channels, 'Group')
44
+ self.qkv = nn.Conv1d(in_channels=in_channels, out_channels=in_channels * 3, kernel_size = 1)
45
+
46
+ self.proj = zero_module(nn.Conv1d(in_channels=in_channels, out_channels=in_channels, kernel_size = 1))
47
+
48
+
49
+ def forward(self, x):
50
+ b, c, *spatial = x.shape
51
+ num_heads = self.num_heads
52
+
53
+ x = x.reshape(b, c, -1) # B x C x HW
54
+ x = self.norm(x)
55
+ qkv = self.qkv(x) # b x c x HW -> B x 3C x HW
56
+ h = QKV_Attention(qkv, num_heads)
57
+ h = self.proj(h)
58
+
59
+ return (x + h).reshape(b,c,*spatial) # additive attention, similar to ResNet?
60
+
61
+
62
+
63
+ def get_model_size(model):
64
+ param_size = 0
65
+ for param in model.parameters():
66
+ param_size += param.nelement() * param.element_size()
67
+
68
+ buffer_size = 0
69
+ for buffer in model.buffers():
70
+ buffer_size += buffer.nelement() * buffer.element_size()
71
+
72
+ size_all_mb = (param_size + buffer_size) / 1024 ** 2
73
+ print('model size: {:.3f}MB'.format(size_all_mb))
74
+ # return param_size + buffer_size
75
+ return size_all_mb
76
+
77
+
78
+ if __name__ == '__main__':
79
+ model = AttentionBlock(in_channels=256, num_heads=8)
80
+
81
+ x = torch.randn(5, 256, 32, 32, dtype=torch.float32)
82
+ y = model(x)
83
+ print('{}, {}'.format(x.shape, y.shape))
84
+
85
+ get_model_size(model)
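AttentionBlock above packs queries, keys and values along the channel axis of a 1x1 Conv1d before calling QKV_Attention; a quick standalone check of that packing and of the per-head channel split:

import torch
import torch.nn as nn

B, C, H, W = 2, 64, 16, 16
num_heads = 8

x = torch.randn(B, C, H * W)            # spatial dims flattened to T = H*W
to_qkv = nn.Conv1d(C, C * 3, kernel_size=1)

qkv = to_qkv(x)                         # B x 3C x T
q, k, v = qkv.chunk(3, dim=1)
print(qkv.shape, q.shape)               # [2, 192, 256] [2, 64, 256]
print(C // num_heads)                   # 8 channels per attention head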
models/blocks.py ADDED
@@ -0,0 +1,238 @@
1
+ from enum import Enum
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import torch.optim as optim
6
+ import logging
+ import math
7
+
8
+
9
+ def get_model_size(model):
10
+ param_size = 0
11
+ for param in model.parameters():
12
+ param_size += param.nelement() * param.element_size()
13
+
14
+ buffer_size = 0
15
+ for buffer in model.buffers():
16
+ buffer_size += buffer.nelement() * buffer.element_size()
17
+
18
+ size_all_mb = (param_size + buffer_size) / 1024 ** 2
19
+ print('model size: {:.3f}MB'.format(size_all_mb))
20
+ # return param_size + buffer_size
21
+ return size_all_mb
22
+
23
+
24
+ def weights_init(init_type='gaussian'):
25
+ def init_fun(m):
26
+ classname = m.__class__.__name__
27
+ if (classname.find('Conv') == 0 or classname.find(
28
+ 'Linear') == 0) and hasattr(m, 'weight'):
29
+ if init_type == 'gaussian':
30
+ nn.init.normal_(m.weight, 0.0, 0.02)
31
+ elif init_type == 'xavier':
32
+ nn.init.xavier_normal_(m.weight, gain=math.sqrt(2))
33
+ elif init_type == 'kaiming':
34
+ nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in')
35
+ elif init_type == 'orthogonal':
36
+ nn.init.orthogonal_(m.weight, gain=math.sqrt(2))
37
+ elif init_type == 'default':
38
+ pass
39
+ else:
40
+ assert 0, "Unsupported initialization: {}".format(init_type)
41
+ if hasattr(m, 'bias') and m.bias is not None:
42
+ nn.init.constant_(m.bias, 0.0)
43
+
44
+ return init_fun
45
+
46
+
47
+ def freeze(module):
48
+ for param in module.parameters():
49
+ param.requires_grad = False
50
+
51
+
52
+ def unfreeze(module):
53
+ for param in module.parameters():
54
+ param.requires_grad = True
55
+
56
+
57
+ def get_optimizer(opt, model):
58
+ lr = float(opt['hyper_params']['lr'])
59
+ beta1 = float(opt['model']['beta1'])
60
+ weight_decay = float(opt['model']['weight_decay'])
61
+ opt_name = opt['model']['optimizer']
62
+
63
+ optim_params = []
64
+ # weight decay
65
+ for key, value in model.named_parameters():
66
+ if not value.requires_grad:
67
+ continue # frozen weights
68
+
69
+ if key[-4:] == 'bias':
70
+ optim_params += [{'params': value, 'weight_decay': 0.0}]
71
+ else:
72
+ optim_params += [{'params': value,
73
+ 'weight_decay': weight_decay}]
74
+
75
+ if opt_name == 'Adam':
76
+ return optim.Adam(optim_params,
77
+ lr=lr,
78
+ betas=(beta1, 0.999),
79
+ eps=1e-5)
80
+ else:
81
+ err = '{} not implemented yet'.format(opt_name)
82
+ logging.error(err)
83
+ raise NotImplementedError(err)
84
+
85
+
86
+ def get_activation(activation):
87
+ act_func = {
88
+ 'relu':nn.ReLU(),
89
+ 'sigmoid':nn.Sigmoid(),
90
+ 'tanh':nn.Tanh(),
91
+ 'prelu':nn.PReLU(),
92
+ 'leaky_relu':nn.LeakyReLU(0.2),
93
+ 'gelu':nn.GELU(),
94
+ }
95
+ if activation not in act_func.keys():
96
+ logging.error("activation {} is not implemented yet".format(activation))
97
+ assert False
98
+
99
+ return act_func[activation]
100
+
101
+
102
+ def get_norm(out_channels, norm_type='Group', groups=32):
103
+ norm_set = ['Instance', 'Batch', 'Group']
104
+ if norm_type not in norm_set:
105
+ err = "Normalization {} has not been implemented yet"
106
+ logging.error(err)
107
+ raise ValueError(err)
108
+
109
+ if norm_type == 'Instance':
110
+ return nn.InstanceNorm2d(out_channels, affine=True)
111
+
112
+ if norm_type == 'Batch':
113
+ return nn.BatchNorm2d(out_channels)
114
+
115
+ if norm_type == 'Group':
116
+ if out_channels >= 32:
117
+ groups = 32
118
+ else:
119
+ groups = max(out_channels // 2, 1)
120
+
121
+ return nn.GroupNorm(groups, out_channels)
122
+ else:
123
+ raise NotImplementedError
124
+
125
+
126
+ class Conv(nn.Module):
127
+ def __init__(self, in_channels, out_channels, stride=1, norm_type='Batch', activation='relu'):
128
+ super().__init__()
129
+
130
+ act_func = get_activation(activation)
131
+ norm_layer = get_norm(out_channels, norm_type)
132
+ self.conv = nn.Sequential(
133
+ nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=True, padding_mode='reflect'),
134
+ norm_layer,
135
+ act_func)
136
+
137
+ def forward(self, x):
138
+ return self.conv(x)
139
+
140
+
141
+ def zero_module(module):
142
+ """
143
+ Zero out the parameters of a module and return it.
144
+ """
145
+ for p in module.parameters():
146
+ p.detach().zero_()
147
+ return module
148
+
149
+
150
+ class Up(nn.Module):
151
+ def __init__(self):
152
+ super().__init__()
153
+ pass
154
+
155
+ def forward(self, x):
156
+ return F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
157
+
158
+
159
+ class Down(nn.Module):
160
+ def __init__(self, channels, use_conv):
161
+ super().__init__()
162
+ self.use_conv = use_conv
163
+
164
+ if self.use_conv:
165
+ self.op = nn.Conv2d(channels, channels, 3, stride=2, padding=1)
166
+ else:
167
+ self.op = nn.AvgPool2d(kernel_size=3, stride=2, padding=1)
168
+
169
+
170
+ def forward(self, x):
171
+ return self.op(x)
172
+
173
+
174
+ class Res_Type(Enum):
175
+ UP = 1
176
+ DOWN = 2
177
+ SAME = 3
178
+
179
+
180
+ class ResBlock(nn.Module):
181
+ def __init__(self, in_channels: int, out_channels: int, dropout=0.0, updown=Res_Type.DOWN, mid_act='leaky_relu'):
182
+ """ ResBlock to cover several cases:
183
+ 1. Up/Down/Same
184
+ 2. in_channels != out_channels
185
+ """
186
+ super().__init__()
187
+
188
+ self.updown = updown
189
+
190
+ self.in_norm = get_norm(in_channels, 'Group') # the input still has in_channels at this point
191
+ self.in_act = get_activation(mid_act)
192
+ self.in_conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=True)
193
+
194
+ # resample both the main (h) and residual (x) branches so their shapes keep matching
195
+ if self.updown == Res_Type.DOWN:
196
+ self.h_updown = Down(in_channels, use_conv=True)
197
+ self.x_updown = Down(in_channels, use_conv=True)
198
+ elif self.updown == Res_Type.UP:
199
+ self.h_updown = Up()
200
+ self.x_updown = Up()
201
+ else:
202
+ self.h_updown = nn.Identity()
+ self.x_updown = nn.Identity()
203
+
204
+ self.out_layer = nn.Sequential(
205
+ get_norm(out_channels, 'Group'),
206
+ get_activation(mid_act),
207
+ nn.Dropout(p=dropout),
208
+ zero_module(nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=True))
209
+ )
210
+
211
+
212
+ def forward(self, x):
213
+ # in layer
214
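+ # pre-activation design: normalization and activation come before the convolution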
+ h = self.in_act(self.in_norm(x))
215
+ h = self.in_conv(self.h_updown(h))
216
+ x = self.x_updown(x)
217
+
218
+ # out layer
219
+ h = self.out_layer(h)
220
+ return x + h
221
+
222
+
223
+
224
+ if __name__ == '__main__':
225
+ x = torch.randn(5, 3, 256, 256)
226
+ up = Up()
227
+ conv_down = Down(3, True)
228
+ pool_down = Down(3, False)
229
+
230
+ print('Up: {}'.format(up(x).shape))
231
+ print('Conv down: {}'.format(conv_down(x).shape))
232
+ print('Pool down: {}'.format(pool_down(x).shape))
233
+
234
+ up_model = ResBlock(3, 6, updown=Res_Type.UP)
235
+ down_model = ResBlock(3, 6, updown=Res_Type.DOWN)
236
+
237
+ print('model up: {}'.format(up_model(x).shape))
238
+ print('model down: {}'.format(down_model(x).shape))
models/pvt_attention.py ADDED
@@ -0,0 +1,240 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from functools import partial
5
+
6
+ from timm.models.layers import DropPath, to_2tuple, trunc_normal_
7
+ from timm.models.registry import register_model
8
+ from timm.models.vision_transformer import _cfg
9
+ import math
10
+
11
+
12
+ class DWConv(nn.Module):
13
+ def __init__(self, dim=768):
14
+ super(DWConv, self).__init__()
15
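+ # depthwise 3x3 convolution (groups=dim); in PVTv2 this serves as an implicit positional encoding inside the MLP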
+ self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)
16
+
17
+ def forward(self, x, H, W):
18
+ B, N, C = x.shape
19
+ x = x.transpose(1, 2).view(B, C, H, W)
20
+ x = self.dwconv(x)
21
+ x = x.flatten(2).transpose(1, 2)
22
+
23
+ return x
24
+
25
+
26
+ class Mlp(nn.Module):
27
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0., linear=False):
28
+ super().__init__()
29
+ out_features = out_features or in_features
30
+ hidden_features = hidden_features or in_features
31
+ self.fc1 = nn.Linear(in_features, hidden_features)
32
+ self.dwconv = DWConv(hidden_features)
33
+ self.act = act_layer()
34
+ self.fc2 = nn.Linear(hidden_features, out_features)
35
+ self.drop = nn.Dropout(drop)
36
+ self.linear = linear
37
+ if self.linear:
38
+ self.relu = nn.ReLU(inplace=True)
39
+ self.apply(self._init_weights)
40
+
41
+ def _init_weights(self, m):
42
+ if isinstance(m, nn.Linear):
43
+ trunc_normal_(m.weight, std=.02)
44
+ if isinstance(m, nn.Linear) and m.bias is not None:
45
+ nn.init.constant_(m.bias, 0)
46
+ elif isinstance(m, nn.LayerNorm):
47
+ nn.init.constant_(m.bias, 0)
48
+ nn.init.constant_(m.weight, 1.0)
49
+ elif isinstance(m, nn.Conv2d):
50
+ fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
51
+ fan_out //= m.groups
52
+ m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
53
+ if m.bias is not None:
54
+ m.bias.data.zero_()
55
+
56
+ def forward(self, x, H, W):
57
+ x = self.fc1(x)
58
+ if self.linear:
59
+ x = self.relu(x)
60
+ x = self.dwconv(x, H, W)
61
+ x = self.act(x)
62
+ x = self.drop(x)
63
+ x = self.fc2(x)
64
+ x = self.drop(x)
65
+ return x
66
+
67
+
68
+ class Attention(nn.Module):
69
+ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1, linear=False):
70
+ super().__init__()
71
+ assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
72
+
73
+ self.dim = dim
74
+ self.num_heads = num_heads
75
+ head_dim = dim // num_heads
76
+ self.scale = qk_scale or head_dim ** -0.5
77
+
78
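+ # queries are computed at full resolution; keys/values may come from a spatially reduced feature map (see forward)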
+ self.q = nn.Linear(dim, dim, bias=qkv_bias)
79
+ self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
80
+ self.attn_drop = nn.Dropout(attn_drop)
81
+ self.proj = nn.Linear(dim, dim)
82
+ self.proj_drop = nn.Dropout(proj_drop)
83
+
84
+ self.linear = linear
85
+ self.sr_ratio = sr_ratio
86
+ if not linear:
87
+ if sr_ratio > 1:
88
+ self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
89
+ self.norm = nn.LayerNorm(dim)
90
+ else:
91
+ self.pool = nn.AdaptiveAvgPool2d(7)
92
+ self.sr = nn.Conv2d(dim, dim, kernel_size=1, stride=1)
93
+ self.norm = nn.LayerNorm(dim)
94
+ self.act = nn.GELU()
95
+ self.apply(self._init_weights)
96
+
97
+ def _init_weights(self, m):
98
+ if isinstance(m, nn.Linear):
99
+ trunc_normal_(m.weight, std=.02)
100
+ if isinstance(m, nn.Linear) and m.bias is not None:
101
+ nn.init.constant_(m.bias, 0)
102
+ elif isinstance(m, nn.LayerNorm):
103
+ nn.init.constant_(m.bias, 0)
104
+ nn.init.constant_(m.weight, 1.0)
105
+ elif isinstance(m, nn.Conv2d):
106
+ fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
107
+ fan_out //= m.groups
108
+ m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
109
+ if m.bias is not None:
110
+ m.bias.data.zero_()
111
+
112
+ def forward(self, x, H, W):
113
+ B, N, C = x.shape
114
+ q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
115
+
116
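+ # PVT-style spatial reduction: the map used for keys/values is downsampled (strided conv for sr_ratio > 1, or adaptive pooling in the linear variant) to reduce attention cost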
+ if not self.linear:
117
+ if self.sr_ratio > 1:
118
+ x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
119
+ x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)
120
+ x_ = self.norm(x_)
121
+ kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
122
+ else:
123
+ kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
124
+ else:
125
+ x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
126
+ x_ = self.sr(self.pool(x_)).reshape(B, C, -1).permute(0, 2, 1)
127
+ x_ = self.norm(x_)
128
+ x_ = self.act(x_)
129
+ kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
130
+ k, v = kv[0], kv[1]
131
+
132
+ attn = (q @ k.transpose(-2, -1)) * self.scale
133
+ attn = attn.softmax(dim=-1)
134
+ attn = self.attn_drop(attn)
135
+
136
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
137
+ x = self.proj(x)
138
+ x = self.proj_drop(x)
139
+
140
+ return x
141
+
142
+
143
+ class Block(nn.Module):
144
+
145
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
146
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1, linear=False):
147
+ super().__init__()
148
+ self.norm1 = norm_layer(dim)
149
+ self.attn = Attention(
150
+ dim,
151
+ num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
152
+ attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio, linear=linear)
153
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
154
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
155
+ self.norm2 = norm_layer(dim)
156
+ mlp_hidden_dim = int(dim * mlp_ratio)
157
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop, linear=linear)
158
+
159
+ self.apply(self._init_weights)
160
+
161
+ def _init_weights(self, m):
162
+ if isinstance(m, nn.Linear):
163
+ trunc_normal_(m.weight, std=.02)
164
+ if isinstance(m, nn.Linear) and m.bias is not None:
165
+ nn.init.constant_(m.bias, 0)
166
+ elif isinstance(m, nn.LayerNorm):
167
+ nn.init.constant_(m.bias, 0)
168
+ nn.init.constant_(m.weight, 1.0)
169
+ elif isinstance(m, nn.Conv2d):
170
+ fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
171
+ fan_out //= m.groups
172
+ m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
173
+ if m.bias is not None:
174
+ m.bias.data.zero_()
175
+
176
+ def forward(self, x, H, W):
177
+ x = x + self.drop_path(self.attn(self.norm1(x), H, W))
178
+ x = x + self.drop_path(self.mlp(self.norm2(x), H, W))
179
+
180
+ return x
181
+
182
+
183
+ class OverlapPatchEmbed(nn.Module):
184
+ """ Image to Patch Embedding
185
+ """
186
+
187
+ def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768):
188
+ super().__init__()
189
+ img_size = to_2tuple(img_size)
190
+ patch_size = to_2tuple(patch_size)
191
+
192
+ assert max(patch_size) > stride, "Set larger patch_size than stride"
193
+
194
+ self.img_size = img_size
195
+ self.patch_size = patch_size
196
+ self.H, self.W = img_size[0] // stride, img_size[1] // stride
197
+ self.num_patches = self.H * self.W
198
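+ # overlapping patches: kernel_size (patch_size) is larger than stride, with padding patch_size // 2, giving H/stride x W/stride tokens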
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride,
199
+ padding=(patch_size[0] // 2, patch_size[1] // 2))
200
+ self.norm = nn.LayerNorm(embed_dim)
201
+
202
+ self.apply(self._init_weights)
203
+
204
+ def _init_weights(self, m):
205
+ if isinstance(m, nn.Linear):
206
+ trunc_normal_(m.weight, std=.02)
207
+ if isinstance(m, nn.Linear) and m.bias is not None:
208
+ nn.init.constant_(m.bias, 0)
209
+ elif isinstance(m, nn.LayerNorm):
210
+ nn.init.constant_(m.bias, 0)
211
+ nn.init.constant_(m.weight, 1.0)
212
+ elif isinstance(m, nn.Conv2d):
213
+ fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
214
+ fan_out //= m.groups
215
+ m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
216
+ if m.bias is not None:
217
+ m.bias.data.zero_()
218
+
219
+ def forward(self, x):
220
+ x = self.proj(x)
221
+ _, _, H, W = x.shape
223
+ x = x.flatten(2).transpose(1, 2)
224
+ x = self.norm(x)
225
+
226
+ return x, H, W
227
+
228
+ if __name__ == '__main__':
229
+ test = torch.randn(5, 3, 224, 224)
230
+
231
+ embed_dim = 768
232
+ patch_embed = OverlapPatchEmbed(embed_dim=embed_dim)
233
+ block = Block(embed_dim, 1)
234
+
236
+ print('x: {}'.format(test.shape))
237
+ pe, H, W = patch_embed(test)
238
+ print('After patch: {}'.format(pe.shape))
239
+ y = block(pe, H, W)
240
+ print('After block: {}'.format(y.shape))
models/template.py ADDED
@@ -0,0 +1,114 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torchvision import utils
5
+ from collections import OrderedDict
6
+
7
+ from .abs_model import abs_model
8
+ from .blocks import *
9
+ from .Loss.Loss import avg_norm_loss
10
+
11
+ class Template(abs_model):
12
+ """ Standard Unet Implementation
13
+ src: https://arxiv.org/pdf/1505.04597.pdf
14
+ """
15
+ def __init__(self, opt):
16
+ resunet = opt['model']['resunet']
17
+ out_act = opt['model']['out_act']
18
+ norm_type = opt['model']['norm_type']
19
+ in_channels = opt['model']['in_channels']
20
+ out_channels = opt['model']['out_channels']
21
+ self.ncols = opt['hyper_params']['n_cols']
22
+
23
+ self.model = Unet(in_channels=in_channels,
24
+ out_channels=out_channels,
25
+ norm_type=norm_type,
26
+ out_act=out_act,
27
+ resunet=resunet)
28
+
29
+ self.optimizer = get_optimizer(opt, self.model)
30
+ self.visualization = {}
31
+
32
+
33
+ def setup_input(self, x):
34
+ return x
35
+
36
+
37
+ def forward(self, x):
38
+ return self.model(x)
39
+
40
+
41
+ def compute_loss(self, y, pred):
42
+ return avg_norm_loss(y, pred)
43
+
44
+
45
+ def supervise(self, input_x, y, is_training:bool)->float:
46
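+ # one supervised step: forward pass, loss computation, and (only when is_training) backward pass + parameter update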
+ optimizer = self.optimizer
47
+ model = self.model
48
+
49
+ optimizer.zero_grad()
50
+ pred = model(input_x)
51
+ loss = self.compute_loss(y, pred)
52
+
53
+ if is_training:
54
+ loss.backward()
55
+ optimizer.step()
56
+
57
+ self.visualization['y'] = y.detach()
58
+ self.visualization['pred'] = pred.detach()
59
+
60
+ return loss.item()
61
+
62
+
63
+ def get_visualize(self) -> OrderedDict:
64
+ """ Convert to visualization numpy array
65
+ """
66
+ nrows = self.ncols
67
+ visualizations = self.visualization
68
+ ret_vis = OrderedDict()
69
+
70
+ for k, v in visualizations.items():
71
+ batch = v.shape[0]
72
+ n = min(nrows, batch)
73
+
74
+ plot_v = v[:n]
75
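+ # make_grid returns a CHW tensor; transpose to HWC so it can be logged/plotted as an image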
+ ret_vis[k] = utils.make_grid(plot_v.cpu(), nrow=nrows).numpy().transpose(1,2,0)
76
+
77
+ return ret_vis
78
+
79
+
80
+ def inference(self, x):
81
+ # TODO
82
+ pass
83
+
84
+
85
+ def batch_inference(self, x):
86
+ # TODO
87
+ pass
88
+
89
+
90
+ """ Getter & Setter
91
+ """
92
+ def get_models(self) -> dict:
93
+ return {'model': self.model}
94
+
95
+
96
+ def get_optimizers(self) -> dict:
97
+ return {'optimizer': self.optimizer}
98
+
99
+
100
+ def set_models(self, models: dict) :
101
+ # input test
102
+ if 'model' not in models.keys():
103
+ raise ValueError('{} not in self.model'.format('model'))
104
+
105
+ self.model = models['model']
106
+
107
+
108
+ def set_optimizers(self, optimizer: dict):
109
+ self.optimizer = optimizer['optimizer']
110
+
111
+
112
+ ####################
113
+ # Personal Methods #
114
+ ####################
weights/SSN/0000001760.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:44328317fa836804554ae453fe1492a45cff724b5c13b5070211d6d860096089
3
+ size 283511041