sczhou committed on
Commit
320e465
1 Parent(s): ea956de
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitignore +131 -0
  2. LICENSE +14 -0
  3. RAFT/__init__.py +2 -0
  4. RAFT/corr.py +111 -0
  5. RAFT/datasets.py +235 -0
  6. RAFT/demo.py +79 -0
  7. RAFT/extractor.py +267 -0
  8. RAFT/raft.py +146 -0
  9. RAFT/update.py +139 -0
  10. RAFT/utils/__init__.py +2 -0
  11. RAFT/utils/augmentor.py +246 -0
  12. RAFT/utils/flow_viz.py +132 -0
  13. RAFT/utils/flow_viz_pt.py +118 -0
  14. RAFT/utils/frame_utils.py +137 -0
  15. RAFT/utils/utils.py +82 -0
  16. configs/train_flowcomp.json +40 -0
  17. configs/train_propainter.json +48 -0
  18. core/dataset.py +232 -0
  19. core/dist.py +47 -0
  20. core/loss.py +180 -0
  21. core/lr_scheduler.py +112 -0
  22. core/metrics.py +569 -0
  23. core/prefetch_dataloader.py +125 -0
  24. core/trainer.py +509 -0
  25. core/trainer_flow_w_edge.py +380 -0
  26. core/utils.py +371 -0
  27. datasets/davis/test.json +1 -0
  28. datasets/davis/train.json +1 -0
  29. datasets/youtube-vos/test.json +1 -0
  30. datasets/youtube-vos/train.json +1 -0
  31. inference_propainter.py +475 -0
  32. model/__init__.py +1 -0
  33. model/canny/canny_filter.py +256 -0
  34. model/canny/filter.py +288 -0
  35. model/canny/gaussian.py +116 -0
  36. model/canny/kernels.py +690 -0
  37. model/canny/sobel.py +263 -0
  38. model/misc.py +131 -0
  39. model/modules/base_module.py +131 -0
  40. model/modules/deformconv.py +54 -0
  41. model/modules/flow_comp_raft.py +265 -0
  42. model/modules/flow_loss_utils.py +142 -0
  43. model/modules/sparse_transformer.py +344 -0
  44. model/modules/spectral_norm.py +288 -0
  45. model/propainter.py +532 -0
  46. model/recurrent_flow_completion.py +347 -0
  47. model/vgg_arch.py +157 -0
  48. requirements.txt +33 -0
  49. scripts/compute_flow.py +108 -0
  50. scripts/evaluate_flow_completion.py +197 -0
.gitignore ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ .vscode
2
+
3
+ # ignored files
4
+ version.py
5
+
6
+ # ignored files with suffix
7
+ *.html
8
+ # *.png
9
+ # *.jpeg
10
+ # *.jpg
11
+ # *.gif
12
+ *.pt
13
+ *.pth
14
+ *.dat
15
+ *.zip
16
+
17
+ # template
18
+
19
+ # Byte-compiled / optimized / DLL files
20
+ __pycache__/
21
+ *.py[cod]
22
+ *$py.class
23
+
24
+ # C extensions
25
+ *.so
26
+
27
+ # Distribution / packaging
28
+ .Python
29
+ build/
30
+ develop-eggs/
31
+ dist/
32
+ downloads/
33
+ eggs/
34
+ .eggs/
35
+ lib/
36
+ lib64/
37
+ parts/
38
+ sdist/
39
+ var/
40
+ wheels/
41
+ *.egg-info/
42
+ .installed.cfg
43
+ *.egg
44
+ MANIFEST
45
+
46
+ # PyInstaller
47
+ # Usually these files are written by a python script from a template
48
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
49
+ *.manifest
50
+ *.spec
51
+
52
+ # Installer logs
53
+ pip-log.txt
54
+ pip-delete-this-directory.txt
55
+
56
+ # Unit test / coverage reports
57
+ htmlcov/
58
+ .tox/
59
+ .coverage
60
+ .coverage.*
61
+ .cache
62
+ nosetests.xml
63
+ coverage.xml
64
+ *.cover
65
+ .hypothesis/
66
+ .pytest_cache/
67
+
68
+ # Translations
69
+ *.mo
70
+ *.pot
71
+
72
+ # Django stuff:
73
+ *.log
74
+ local_settings.py
75
+ db.sqlite3
76
+
77
+ # Flask stuff:
78
+ instance/
79
+ .webassets-cache
80
+
81
+ # Scrapy stuff:
82
+ .scrapy
83
+
84
+ # Sphinx documentation
85
+ docs/_build/
86
+
87
+ # PyBuilder
88
+ target/
89
+
90
+ # Jupyter Notebook
91
+ .ipynb_checkpoints
92
+
93
+ # pyenv
94
+ .python-version
95
+
96
+ # celery beat schedule file
97
+ celerybeat-schedule
98
+
99
+ # SageMath parsed files
100
+ *.sage.py
101
+
102
+ # Environments
103
+ .env
104
+ .venv
105
+ env/
106
+ venv/
107
+ ENV/
108
+ env.bak/
109
+ venv.bak/
110
+
111
+ # Spyder project settings
112
+ .spyderproject
113
+ .spyproject
114
+
115
+ # Rope project settings
116
+ .ropeproject
117
+
118
+ # mkdocs documentation
119
+ /site
120
+
121
+ # mypy
122
+ .mypy_cache/
123
+
124
+ # project
125
+ experiments_model/
126
+ unreleased/
127
+ results_eval/
128
+ results/
129
+ *debug*
130
+ *old*
131
+ *.sh
LICENSE ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # S-Lab License 1.0
2
+
3
+ Copyright 2023 S-Lab
4
+
5
+ Redistribution and use for non-commercial purpose in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
6
+ 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
7
+ 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
8
+ 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
9
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
10
+ 4. In the event that redistribution and/or use for commercial purpose in source or binary forms, with or without modification is required, please contact the contributor(s) of the work.
11
+
12
+
13
+ ---
14
+ For the commercial use of the code, please consult Prof. Chen Change Loy (ccloy@ntu.edu.sg)
RAFT/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ # from .demo import RAFT_infer
2
+ from .raft import RAFT
RAFT/corr.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from .utils.utils import bilinear_sampler, coords_grid
4
+
5
+ try:
6
+ import alt_cuda_corr
7
+ except:
8
+ # alt_cuda_corr is not compiled
9
+ pass
10
+
11
+
12
class CorrBlock:
    """All-pairs correlation volume with a multi-scale window lookup.

    The constructor correlates every position of ``fmap1`` with every
    position of ``fmap2`` and keeps that volume together with
    ``num_levels - 1`` successively 2x average-pooled copies. Calling the
    instance samples a ``(2 * radius + 1)`` square window from every level
    around the supplied coordinates and concatenates the results.
    """

    def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
        self.num_levels = num_levels
        self.radius = radius
        self.corr_pyramid = []

        # all pairs correlation
        corr = CorrBlock.corr(fmap1, fmap2)

        batch, h1, w1, dim, h2, w2 = corr.shape
        corr = corr.reshape(batch*h1*w1, dim, h2, w2)

        self.corr_pyramid.append(corr)
        for _ in range(self.num_levels - 1):
            corr = F.avg_pool2d(corr, 2, stride=2)
            self.corr_pyramid.append(corr)

    def __call__(self, coords):
        """Sample every pyramid level around ``coords``.

        ``coords`` is permuted from channel-first to channel-last before use;
        the result is returned channel-first, contiguous and as float.
        """
        r = self.radius
        coords = coords.permute(0, 2, 3, 1)
        batch, h1, w1, _ = coords.shape

        # The window offsets do not depend on the pyramid level, so build
        # them (and move them to the target device) once instead of once per
        # level. The meshgrid argument order is kept exactly as before.
        dx = torch.linspace(-r, r, 2*r+1)
        dy = torch.linspace(-r, r, 2*r+1)
        delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(coords.device)
        delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2)
        centroid = coords.reshape(batch*h1*w1, 1, 1, 2)

        out_pyramid = []
        for i in range(self.num_levels):
            corr = self.corr_pyramid[i]

            # coordinates shrink by 2 per level to follow the pooled volume
            coords_lvl = centroid / 2**i + delta_lvl

            corr = bilinear_sampler(corr, coords_lvl)
            corr = corr.view(batch, h1, w1, -1)
            out_pyramid.append(corr)

        out = torch.cat(out_pyramid, dim=-1)
        return out.permute(0, 3, 1, 2).contiguous().float()

    @staticmethod
    def corr(fmap1, fmap2):
        """Return the dot-product correlation of two feature maps.

        Both inputs are viewed as ``(batch, dim, ht*wd)``; the output has
        shape ``(batch, ht, wd, 1, ht, wd)`` and is divided by ``sqrt(dim)``.
        """
        batch, dim, ht, wd = fmap1.shape
        fmap1 = fmap1.view(batch, dim, ht*wd)
        fmap2 = fmap2.view(batch, dim, ht*wd)

        corr = torch.matmul(fmap1.transpose(1, 2), fmap2)
        corr = corr.view(batch, ht, wd, 1, ht, wd)
        return corr / torch.sqrt(torch.tensor(dim).float())
61
+
62
+
63
class CorrLayer(torch.autograd.Function):
    """Autograd wrapper around the compiled correlation extension.

    The original body called ``correlation_cudaz``, a name that is neither
    imported nor defined in this module, so any use raised ``NameError``.
    It now goes through ``alt_cuda_corr``, the extension this module tries
    to import at the top.
    NOTE(review): assumes ``alt_cuda_corr`` exposes ``forward`` and
    ``backward`` with these argument orders -- confirm against the compiled
    extension before relying on this class.
    """

    @staticmethod
    def forward(ctx, fmap1, fmap2, coords, r):
        fmap1 = fmap1.contiguous()
        fmap2 = fmap2.contiguous()
        coords = coords.contiguous()
        ctx.save_for_backward(fmap1, fmap2, coords)
        ctx.r = r
        # the extension call is unpacked as a one-element sequence
        corr, = alt_cuda_corr.forward(fmap1, fmap2, coords, ctx.r)
        return corr

    @staticmethod
    def backward(ctx, grad_corr):
        fmap1, fmap2, coords = ctx.saved_tensors
        grad_corr = grad_corr.contiguous()
        fmap1_grad, fmap2_grad, coords_grad = \
            alt_cuda_corr.backward(fmap1, fmap2, coords, grad_corr, ctx.r)
        # no gradient for the integer radius argument
        return fmap1_grad, fmap2_grad, coords_grad, None
81
+
82
+
83
class AlternateCorrBlock:
    """Correlation lookup that relies on the ``alt_cuda_corr`` extension.

    Instead of storing the full all-pairs volume, the constructor keeps a
    pyramid of average-pooled feature-map pairs and the correlation is
    evaluated per call by the compiled extension.
    """

    def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
        self.num_levels = num_levels
        self.radius = radius

        self.pyramid = [(fmap1, fmap2)]
        for _ in range(self.num_levels):
            fmap1 = F.avg_pool2d(fmap1, 2, stride=2)
            fmap2 = F.avg_pool2d(fmap2, 2, stride=2)
            self.pyramid.append((fmap1, fmap2))

    def __call__(self, coords):
        coords = coords.permute(0, 2, 3, 1)
        B, H, W, _ = coords.shape
        # channel count of the full-resolution feature map, used for scaling
        dim = self.pyramid[0][0].shape[1]
        r = self.radius

        corr_list = []
        for i in range(self.num_levels):
            # permute() only returns a view; make the memory contiguous
            # before handing the tensors to the compiled kernel
            fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1).contiguous()
            fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1).contiguous()

            coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous()
            # `alt_cuda_corr` is a module, so it cannot be called directly;
            # NOTE(review): `forward` returning a one-element sequence follows
            # the upstream RAFT extension -- confirm against the build in use.
            corr, = alt_cuda_corr.forward(fmap1_i, fmap2_i, coords_i, r)
            corr_list.append(corr.squeeze(1))

        corr = torch.stack(corr_list, dim=1)
        corr = corr.reshape(B, -1, H, W)
        # same sqrt(dim) normalisation as CorrBlock.corr; equals the former
        # hard-coded 16.0 when dim == 256
        return corr / torch.sqrt(torch.tensor(dim).float())
RAFT/datasets.py ADDED
@@ -0,0 +1,235 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Data loading based on https://github.com/NVIDIA/flownet2-pytorch
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.utils.data as data
6
+ import torch.nn.functional as F
7
+
8
+ import os
9
+ import math
10
+ import random
11
+ from glob import glob
12
+ import os.path as osp
13
+
14
+ from utils import frame_utils
15
+ from utils.augmentor import FlowAugmentor, SparseFlowAugmentor
16
+
17
+
18
class FlowDataset(data.Dataset):
    """Base dataset producing image pairs, optionally with flow targets.

    Subclasses populate ``image_list`` (pairs of frame paths), ``flow_list``
    and ``extra_info``; this class reads the files, applies the optional
    augmentor and converts everything to channel-first float tensors.
    """

    def __init__(self, aug_params=None, sparse=False):
        self.sparse = sparse
        self.augmentor = None
        if aug_params is not None:
            augmentor_cls = SparseFlowAugmentor if sparse else FlowAugmentor
            self.augmentor = augmentor_cls(**aug_params)

        self.is_test = False
        self.init_seed = False
        self.flow_list = []
        self.image_list = []
        self.extra_info = []

    def __getitem__(self, index):
        # Test mode: no flow, no augmentation, just the two frames + metadata.
        if self.is_test:
            frames = [frame_utils.read_gen(self.image_list[index][k]) for k in (0, 1)]
            tensors = [
                torch.from_numpy(np.array(f).astype(np.uint8)[..., :3]).permute(2, 0, 1).float()
                for f in frames
            ]
            return tensors[0], tensors[1], self.extra_info[index]

        # Seed the RNGs once per dataloader worker, using the worker id.
        if not self.init_seed:
            worker = torch.utils.data.get_worker_info()
            if worker is not None:
                for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
                    seed_fn(worker.id)
                self.init_seed = True

        index %= len(self.image_list)
        if self.sparse:
            flow, valid = frame_utils.readFlowKITTI(self.flow_list[index])
        else:
            flow, valid = frame_utils.read_gen(self.flow_list[index]), None

        img1 = frame_utils.read_gen(self.image_list[index][0])
        img2 = frame_utils.read_gen(self.image_list[index][1])

        flow = np.array(flow).astype(np.float32)
        img1 = np.array(img1).astype(np.uint8)
        img2 = np.array(img2).astype(np.uint8)

        if len(img1.shape) == 2:
            # grayscale images: replicate the single channel three times
            img1 = np.tile(img1[..., None], (1, 1, 3))
            img2 = np.tile(img2[..., None], (1, 1, 3))
        else:
            # keep only the first three channels
            img1 = img1[..., :3]
            img2 = img2[..., :3]

        if self.augmentor is not None:
            if self.sparse:
                img1, img2, flow, valid = self.augmentor(img1, img2, flow, valid)
            else:
                img1, img2, flow = self.augmentor(img1, img2, flow)

        img1 = torch.from_numpy(img1).permute(2, 0, 1).float()
        img2 = torch.from_numpy(img2).permute(2, 0, 1).float()
        flow = torch.from_numpy(flow).permute(2, 0, 1).float()

        if valid is None:
            # no explicit mask: accept pixels whose flow magnitude is below
            # 1000 in both components
            valid = (flow[0].abs() < 1000) & (flow[1].abs() < 1000)
        else:
            valid = torch.from_numpy(valid)

        return img1, img2, flow, valid.float()

    def __rmul__(self, v):
        """Repeat the sample lists ``v`` times in place and return ``self``."""
        self.flow_list = v * self.flow_list
        self.image_list = v * self.image_list
        return self

    def __len__(self):
        return len(self.image_list)
100
+
101
+
102
class MpiSintel(FlowDataset):
    """Frame pairs (and, outside the test split, flow files) from a Sintel tree.

    Images are read from ``<root>/<split>/<dstype>/<scene>/*.png`` and flow
    from ``<root>/<split>/flow/<scene>/*.flo``.
    """

    def __init__(self, aug_params=None, split='training', root='datasets/Sintel', dstype='clean'):
        super(MpiSintel, self).__init__(aug_params)
        is_test_split = split == 'test'
        if is_test_split:
            self.is_test = True

        image_root = osp.join(root, split, dstype)
        flow_root = osp.join(root, split, 'flow')

        for scene in os.listdir(image_root):
            frames = sorted(glob(osp.join(image_root, scene, '*.png')))
            # consecutive frames form one sample each
            for idx in range(len(frames) - 1):
                self.image_list.append([frames[idx], frames[idx + 1]])
                self.extra_info.append((scene, idx))  # scene and frame_id

            if not is_test_split:
                self.flow_list.extend(sorted(glob(osp.join(flow_root, scene, '*.flo'))))
119
+
120
+
121
class FlyingChairs(FlowDataset):
    """FlyingChairs pairs selected through ``chairs_split.txt``.

    ``split`` selects entries marked 1 (``'training'``) or 2
    (``'validation'``) in the split file. The default value ``'train'`` used
    to match neither branch and silently produced an empty dataset; it is
    now accepted as an alias of ``'training'``.
    """

    def __init__(self, aug_params=None, split='train', root='datasets/FlyingChairs_release/data'):
        super(FlyingChairs, self).__init__(aug_params)

        if split == 'train':
            split = 'training'

        images = sorted(glob(osp.join(root, '*.ppm')))
        flows = sorted(glob(osp.join(root, '*.flo')))
        # two images per flow file
        assert (len(images)//2 == len(flows))

        # the split file is looked up relative to the working directory
        split_list = np.loadtxt('chairs_split.txt', dtype=np.int32)
        for i, flow_path in enumerate(flows):
            xid = split_list[i]
            if (split == 'training' and xid == 1) or (split == 'validation' and xid == 2):
                self.flow_list += [flow_path]
                self.image_list += [[images[2*i], images[2*i+1]]]
135
+
136
+
137
class FlyingThings3D(FlowDataset):
    """FlyingThings3D TRAIN pairs for the left camera, in both time directions.

    ``into_future`` samples pair frame ``i`` with ``i + 1`` and use flow file
    ``i``; ``into_past`` samples reverse the pair and use flow file ``i + 1``.
    """

    def __init__(self, aug_params=None, root='datasets/FlyingThings3D', dstype='frames_cleanpass'):
        super(FlyingThings3D, self).__init__(aug_params)

        for cam in ['left']:
            for direction in ['into_future', 'into_past']:
                scene_image_dirs = sorted(glob(osp.join(root, dstype, 'TRAIN/*/*')))
                scene_image_dirs = sorted(osp.join(d, cam) for d in scene_image_dirs)

                scene_flow_dirs = sorted(glob(osp.join(root, 'optical_flow/TRAIN/*/*')))
                scene_flow_dirs = sorted(osp.join(d, direction, cam) for d in scene_flow_dirs)

                for idir, fdir in zip(scene_image_dirs, scene_flow_dirs):
                    images = sorted(glob(osp.join(idir, '*.png')))
                    flows = sorted(glob(osp.join(fdir, '*.pfm')))
                    for k in range(len(flows) - 1):
                        if direction == 'into_future':
                            self.image_list.append([images[k], images[k + 1]])
                            self.flow_list.append(flows[k])
                        elif direction == 'into_past':
                            self.image_list.append([images[k + 1], images[k]])
                            self.flow_list.append(flows[k + 1])
159
+
160
+
161
class KITTI(FlowDataset):
    """KITTI ``*_10`` / ``*_11`` image pairs with sparse flow for training.

    The ``'testing'`` split switches the dataset into test mode; the
    ``'training'`` split additionally collects ``flow_occ/*_10.png``.
    """

    def __init__(self, aug_params=None, split='training', root='datasets/KITTI'):
        super(KITTI, self).__init__(aug_params, sparse=True)
        if split == 'testing':
            self.is_test = True

        root = osp.join(root, split)
        images1 = sorted(glob(osp.join(root, 'image_2/*_10.png')))
        images2 = sorted(glob(osp.join(root, 'image_2/*_11.png')))

        for img1, img2 in zip(images1, images2):
            # basename() honours the platform separator, unlike split('/')
            frame_id = osp.basename(img1)
            self.extra_info += [[frame_id]]
            self.image_list += [[img1, img2]]

        if split == 'training':
            self.flow_list = sorted(glob(osp.join(root, 'flow_occ/*_10.png')))
178
+
179
+
180
class HD1K(FlowDataset):
    """HD1K sequences, enumerated by a zero-padded six-digit sequence index.

    Sequences are scanned from index 0 upwards until one has no flow files.
    """

    def __init__(self, aug_params=None, root='datasets/HD1k'):
        super(HD1K, self).__init__(aug_params, sparse=True)

        seq_ix = 0
        while True:
            flow_pattern = os.path.join(root, 'hd1k_flow_gt', 'flow_occ/%06d_*.png' % seq_ix)
            image_pattern = os.path.join(root, 'hd1k_input', 'image_2/%06d_*.png' % seq_ix)
            flows = sorted(glob(flow_pattern))
            images = sorted(glob(image_pattern))

            # first sequence index without flow files ends the scan
            if not flows:
                break

            for k in range(len(flows) - 1):
                self.flow_list.append(flows[k])
                self.image_list.append([images[k], images[k + 1]])

            seq_ix += 1
197
+
198
+
199
def fetch_dataloader(args, TRAIN_DS='C+T+K+S+H'):
    """Create the data loader for the corresponding training set.

    Reads ``args.stage``, ``args.image_size`` and ``args.batch_size``.
    ``TRAIN_DS`` is only consulted for the ``'sintel'`` stage.

    Raises:
        ValueError: if ``args.stage`` or (for the sintel stage) ``TRAIN_DS``
            is not one of the supported values. Previously this surfaced as
            an ``UnboundLocalError`` on ``train_dataset``.
    """

    if args.stage == 'chairs':
        aug_params = {'crop_size': args.image_size, 'min_scale': -0.1, 'max_scale': 1.0, 'do_flip': True}
        train_dataset = FlyingChairs(aug_params, split='training')

    elif args.stage == 'things':
        aug_params = {'crop_size': args.image_size, 'min_scale': -0.4, 'max_scale': 0.8, 'do_flip': True}
        clean_dataset = FlyingThings3D(aug_params, dstype='frames_cleanpass')
        final_dataset = FlyingThings3D(aug_params, dstype='frames_finalpass')
        train_dataset = clean_dataset + final_dataset

    elif args.stage == 'sintel':
        if TRAIN_DS not in ('C+T+K+S+H', 'C+T+K/S'):
            raise ValueError('Unsupported TRAIN_DS for sintel stage: %r' % (TRAIN_DS,))

        aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.6, 'do_flip': True}
        things = FlyingThings3D(aug_params, dstype='frames_cleanpass')
        sintel_clean = MpiSintel(aug_params, split='training', dstype='clean')
        sintel_final = MpiSintel(aug_params, split='training', dstype='final')

        if TRAIN_DS == 'C+T+K+S+H':
            kitti = KITTI({'crop_size': args.image_size, 'min_scale': -0.3, 'max_scale': 0.5, 'do_flip': True})
            hd1k = HD1K({'crop_size': args.image_size, 'min_scale': -0.5, 'max_scale': 0.2, 'do_flip': True})
            # integer * dataset repeats that dataset's sample lists
            train_dataset = 100*sintel_clean + 100*sintel_final + 200*kitti + 5*hd1k + things

        else:  # 'C+T+K/S'
            train_dataset = 100*sintel_clean + 100*sintel_final + things

    elif args.stage == 'kitti':
        aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.4, 'do_flip': False}
        train_dataset = KITTI(aug_params, split='training')

    else:
        raise ValueError('Unsupported training stage: %r' % (args.stage,))

    train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size,
        pin_memory=False, shuffle=True, num_workers=4, drop_last=True)

    print('Training with %d image pairs' % len(train_dataset))
    return train_loader
235
+
RAFT/demo.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import argparse
3
+ import os
4
+ import cv2
5
+ import glob
6
+ import numpy as np
7
+ import torch
8
+ from PIL import Image
9
+
10
+ from .raft import RAFT
11
+ from .utils import flow_viz
12
+ from .utils.utils import InputPadder
13
+
14
+
15
+
16
+ DEVICE = 'cuda'
17
+
18
def load_image(imfile):
    """Read ``imfile`` with PIL and return it as a channel-first float tensor."""
    pixels = np.array(Image.open(imfile)).astype(np.uint8)
    # the permute moves the last axis of the decoded array to the front
    return torch.from_numpy(pixels).permute(2, 0, 1).float()
22
+
23
+
24
def load_image_list(image_files):
    """Load the files in sorted order, stack them on DEVICE and pad them.

    Returns the first element of what ``InputPadder.pad`` produces for the
    stacked batch.
    """
    batch = torch.stack([load_image(path) for path in sorted(image_files)], dim=0)
    batch = batch.to(DEVICE)

    padder = InputPadder(batch.shape)
    return padder.pad(batch)[0]
34
+
35
+
36
def viz(img, flo, out_path='/home/chengao/test/flow.png'):
    """Render the first flow field of ``flo`` as a colour image and save it.

    Args:
        img: accepted for call compatibility; not used by the rendering.
        flo: batched channel-first flow tensor; only element 0 is written.
        out_path: destination file. Defaults to the path that used to be
            hard-coded here, so existing callers are unaffected.
    """
    flo = flo[0].permute(1, 2, 0).cpu().numpy()

    # map flow to rgb image
    img_flo = flow_viz.flow_to_image(flo)

    # reverse the channel order of the rendered image before writing with OpenCV
    cv2.imwrite(out_path, img_flo[:, :, [2, 1, 0]])
48
+
49
+
50
def demo(args):
    """Load a RAFT checkpoint and visualise flow for consecutive frames.

    Reads ``args.model`` (checkpoint path) and ``args.path`` (directory
    searched for ``*.png`` and ``*.jpg`` frames).
    """
    # the checkpoint is loaded through a DataParallel wrapper, then unwrapped
    wrapped = torch.nn.DataParallel(RAFT(args))
    wrapped.load_state_dict(torch.load(args.model))

    net = wrapped.module
    net.to(DEVICE)
    net.eval()

    with torch.no_grad():
        frame_paths = glob.glob(os.path.join(args.path, '*.png'))
        frame_paths = frame_paths + glob.glob(os.path.join(args.path, '*.jpg'))

        frames = load_image_list(frame_paths)
        for idx in range(frames.shape[0] - 1):
            first = frames[idx, None]
            second = frames[idx + 1, None]

            flow_low, flow_up = net(first, second, iters=20, test_mode=True)
            viz(first, flow_up)
69
+
70
+
71
def RAFT_infer(args):
    """Build RAFT, load ``args.model`` and return the model in eval mode on DEVICE."""
    # the checkpoint is loaded through a DataParallel wrapper, then unwrapped
    wrapped = torch.nn.DataParallel(RAFT(args))
    wrapped.load_state_dict(torch.load(args.model))

    net = wrapped.module
    net.to(DEVICE)
    net.eval()
    return net
RAFT/extractor.py ADDED
@@ -0,0 +1,267 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+
6
class ResidualBlock(nn.Module):
    """Two 3x3 convolutions with a residual connection.

    Args:
        in_planes: input channel count.
        planes: output channel count.
        norm_fn: one of ``'group'``, ``'batch'``, ``'instance'``, ``'none'``.
        stride: stride of the first convolution; any value other than 1 adds
            a strided 1x1 convolution plus norm on the skip path.

    Raises:
        ValueError: for an unknown ``norm_fn``. Previously the norm layers
            were simply left undefined and the block failed later with an
            ``AttributeError``.
    """

    def __init__(self, in_planes, planes, norm_fn='group', stride=1):
        super(ResidualBlock, self).__init__()

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)

        num_groups = planes // 8

        def make_norm():
            # one fresh normalisation layer over `planes` channels
            if norm_fn == 'group':
                return nn.GroupNorm(num_groups=num_groups, num_channels=planes)
            if norm_fn == 'batch':
                return nn.BatchNorm2d(planes)
            if norm_fn == 'instance':
                return nn.InstanceNorm2d(planes)
            if norm_fn == 'none':
                return nn.Sequential()
            raise ValueError('Unknown norm_fn: %r' % (norm_fn,))

        # Attribute names and creation order are kept as before so that
        # state_dict keys stay the same.
        self.norm1 = make_norm()
        self.norm2 = make_norm()

        if stride == 1:
            self.downsample = None
        else:
            # norm3 is registered both directly and inside `downsample`
            self.norm3 = make_norm()
            self.downsample = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3)

    def forward(self, x):
        y = x
        y = self.relu(self.norm1(self.conv1(y)))
        y = self.relu(self.norm2(self.conv2(y)))

        if self.downsample is not None:
            x = self.downsample(x)

        return self.relu(x+y)
57
+
58
+
59
+
60
class BottleneckBlock(nn.Module):
    """1x1 -> 3x3 -> 1x1 bottleneck with a residual connection.

    The two inner convolutions work on ``planes // 4`` channels. Any
    ``stride`` other than 1 is applied in the 3x3 convolution and adds a
    strided 1x1 convolution plus norm on the skip path.

    Raises:
        ValueError: for an unknown ``norm_fn`` (one of ``'group'``,
            ``'batch'``, ``'instance'``, ``'none'`` is expected). Previously
            the norm layers were left undefined and the block failed later
            with an ``AttributeError``.
    """

    def __init__(self, in_planes, planes, norm_fn='group', stride=1):
        super(BottleneckBlock, self).__init__()

        self.conv1 = nn.Conv2d(in_planes, planes//4, kernel_size=1, padding=0)
        self.conv2 = nn.Conv2d(planes//4, planes//4, kernel_size=3, padding=1, stride=stride)
        self.conv3 = nn.Conv2d(planes//4, planes, kernel_size=1, padding=0)
        self.relu = nn.ReLU(inplace=True)

        num_groups = planes // 8

        def make_norm(channels):
            # one fresh normalisation layer over `channels` channels
            if norm_fn == 'group':
                return nn.GroupNorm(num_groups=num_groups, num_channels=channels)
            if norm_fn == 'batch':
                return nn.BatchNorm2d(channels)
            if norm_fn == 'instance':
                return nn.InstanceNorm2d(channels)
            if norm_fn == 'none':
                return nn.Sequential()
            raise ValueError('Unknown norm_fn: %r' % (norm_fn,))

        # Attribute names and creation order are kept as before so that
        # state_dict keys stay the same.
        self.norm1 = make_norm(planes//4)
        self.norm2 = make_norm(planes//4)
        self.norm3 = make_norm(planes)

        if stride == 1:
            self.downsample = None
        else:
            # norm4 is registered both directly and inside `downsample`
            self.norm4 = make_norm(planes)
            self.downsample = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4)

    def forward(self, x):
        y = x
        y = self.relu(self.norm1(self.conv1(y)))
        y = self.relu(self.norm2(self.conv2(y)))
        y = self.relu(self.norm3(self.conv3(y)))

        if self.downsample is not None:
            x = self.downsample(x)

        return self.relu(x+y)
117
+
118
class BasicEncoder(nn.Module):
    """Convolutional encoder built from ``ResidualBlock`` stages.

    A stride-2 7x7 stem is followed by three residual stages (64, 96 and 128
    channels; the last two with stride 2) and a 1x1 projection to
    ``output_dim``. A list or tuple input is concatenated along the batch
    axis, encoded in one pass and split back into two equal parts.
    """

    def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
        super(BasicEncoder, self).__init__()
        self.norm_fn = norm_fn

        if norm_fn == 'group':
            self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64)
        elif norm_fn == 'batch':
            self.norm1 = nn.BatchNorm2d(64)
        elif norm_fn == 'instance':
            self.norm1 = nn.InstanceNorm2d(64)
        elif norm_fn == 'none':
            self.norm1 = nn.Sequential()

        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
        self.relu1 = nn.ReLU(inplace=True)

        # _make_layer reads and updates self.in_planes
        self.in_planes = 64
        self.layer1 = self._make_layer(64, stride=1)
        self.layer2 = self._make_layer(96, stride=2)
        self.layer3 = self._make_layer(128, stride=2)

        # output convolution
        self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1)

        self.dropout = None
        if dropout > 0:
            self.dropout = nn.Dropout2d(p=dropout)

        norm_types = (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                continue
            if not isinstance(module, norm_types):
                continue
            # the guards skip norm layers that have no weight / bias
            if module.weight is not None:
                nn.init.constant_(module.weight, 1)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)

    def _make_layer(self, dim, stride=1):
        """Return two residual blocks; only the first changes width/stride."""
        first = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
        second = ResidualBlock(dim, dim, self.norm_fn, stride=1)

        self.in_planes = dim
        return nn.Sequential(first, second)

    def forward(self, x):
        # if input is list, combine batch dimension
        paired = isinstance(x, (tuple, list))
        if paired:
            batch_dim = x[0].shape[0]
            x = torch.cat(x, dim=0)

        x = self.relu1(self.norm1(self.conv1(x)))

        for stage in (self.layer1, self.layer2, self.layer3):
            x = stage(x)

        x = self.conv2(x)

        if self.training and self.dropout is not None:
            x = self.dropout(x)

        if paired:
            x = torch.split(x, [batch_dim, batch_dim], dim=0)

        return x
193
+
194
+
195
class SmallEncoder(nn.Module):
    """Lighter encoder built from ``BottleneckBlock`` stages.

    A stride-2 7x7 stem is followed by three bottleneck stages (32, 64 and
    96 channels; the last two with stride 2) and a 1x1 projection to
    ``output_dim``. A list or tuple input is concatenated along the batch
    axis, encoded in one pass and split back into two equal parts.
    """

    def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
        super(SmallEncoder, self).__init__()
        self.norm_fn = norm_fn

        if norm_fn == 'group':
            self.norm1 = nn.GroupNorm(num_groups=8, num_channels=32)
        elif norm_fn == 'batch':
            self.norm1 = nn.BatchNorm2d(32)
        elif norm_fn == 'instance':
            self.norm1 = nn.InstanceNorm2d(32)
        elif norm_fn == 'none':
            self.norm1 = nn.Sequential()

        self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3)
        self.relu1 = nn.ReLU(inplace=True)

        # _make_layer reads and updates self.in_planes
        self.in_planes = 32
        self.layer1 = self._make_layer(32, stride=1)
        self.layer2 = self._make_layer(64, stride=2)
        self.layer3 = self._make_layer(96, stride=2)

        self.dropout = None
        if dropout > 0:
            self.dropout = nn.Dropout2d(p=dropout)

        self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1)

        norm_types = (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                continue
            if not isinstance(module, norm_types):
                continue
            # the guards skip norm layers that have no weight / bias
            if module.weight is not None:
                nn.init.constant_(module.weight, 1)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)

    def _make_layer(self, dim, stride=1):
        """Return two bottleneck blocks; only the first changes width/stride."""
        first = BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride)
        second = BottleneckBlock(dim, dim, self.norm_fn, stride=1)

        self.in_planes = dim
        return nn.Sequential(first, second)

    def forward(self, x):
        # if input is list, combine batch dimension
        paired = isinstance(x, (tuple, list))
        if paired:
            batch_dim = x[0].shape[0]
            x = torch.cat(x, dim=0)

        x = self.relu1(self.norm1(self.conv1(x)))

        for stage in (self.layer1, self.layer2, self.layer3):
            x = stage(x)

        x = self.conv2(x)

        if self.training and self.dropout is not None:
            x = self.dropout(x)

        if paired:
            x = torch.split(x, [batch_dim, batch_dim], dim=0)

        return x
RAFT/raft.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+ from .update import BasicUpdateBlock, SmallUpdateBlock
7
+ from .extractor import BasicEncoder, SmallEncoder
8
+ from .corr import CorrBlock, AlternateCorrBlock
9
+ from .utils.utils import bilinear_sampler, coords_grid, upflow8
10
+
11
try:
    autocast = torch.cuda.amp.autocast
except AttributeError:
    # dummy autocast for PyTorch < 1.6
    # Only a missing attribute is expected here, so the handler no longer
    # swallows unrelated errors (or KeyboardInterrupt) with a bare `except`.
    class autocast:
        """No-op context manager accepting the same ``enabled`` argument."""

        def __init__(self, enabled):
            pass

        def __enter__(self):
            pass

        def __exit__(self, *args):
            pass
22
+
23
+
24
class RAFT(nn.Module):
    """Iterative optical-flow estimator.

    ``args`` must provide ``small`` and ``mixed_precision``. ``dropout`` and
    ``alternate_corr`` are optional and default to ``0`` / ``False``. The
    constructor writes ``corr_levels``, ``corr_radius`` and any missing
    optional attributes back onto ``args``.
    """

    def __init__(self, args):
        super(RAFT, self).__init__()
        self.args = args

        if args.small:
            self.hidden_dim = hdim = 96
            self.context_dim = cdim = 64
            args.corr_levels = 4
            args.corr_radius = 3

        else:
            self.hidden_dim = hdim = 128
            self.context_dim = cdim = 128
            args.corr_levels = 4
            args.corr_radius = 4

        # The previous test, `'dropout' not in args._get_kwargs()`, compared
        # a string against (name, value) pairs, so it was always true and
        # caller-supplied values were overwritten. hasattr() only fills in
        # the default when the attribute is really missing.
        if not hasattr(args, 'dropout'):
            args.dropout = 0

        if not hasattr(args, 'alternate_corr'):
            args.alternate_corr = False

        # feature network, context network, and update block
        if args.small:
            self.fnet = SmallEncoder(output_dim=128, norm_fn='instance', dropout=args.dropout)
            self.cnet = SmallEncoder(output_dim=hdim+cdim, norm_fn='none', dropout=args.dropout)
            self.update_block = SmallUpdateBlock(self.args, hidden_dim=hdim)

        else:
            self.fnet = BasicEncoder(output_dim=256, norm_fn='instance', dropout=args.dropout)
            self.cnet = BasicEncoder(output_dim=hdim+cdim, norm_fn='batch', dropout=args.dropout)
            self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim)

    def freeze_bn(self):
        """Put every ``BatchNorm2d`` submodule into eval mode."""
        for m in self.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()

    def initialize_flow(self, img):
        """ Flow is represented as difference between two coordinate grids flow = coords1 - coords0"""
        N, C, H, W = img.shape
        coords0 = coords_grid(N, H//8, W//8).to(img.device)
        coords1 = coords_grid(N, H//8, W//8).to(img.device)

        # optical flow computed as difference: flow = coords1 - coords0
        return coords0, coords1

    def upsample_flow(self, flow, mask):
        """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """
        N, _, H, W = flow.shape
        mask = mask.view(N, 1, 9, 8, 8, H, W)
        # softmax over the 9 neighbours of the 3x3 window
        mask = torch.softmax(mask, dim=2)

        # flow values are scaled by 8 along with the resolution
        up_flow = F.unfold(8 * flow, [3, 3], padding=1)
        up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)

        up_flow = torch.sum(mask * up_flow, dim=2)
        up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
        return up_flow.reshape(N, 2, 8*H, 8*W)

    def forward(self, image1, image2, iters=12, flow_init=None, test_mode=True):
        """ Estimate optical flow between pair of frames

        Returns ``(coords1 - coords0, flow_up)`` from the final iteration
        when ``test_mode`` is true, otherwise the list of upsampled flow
        predictions, one per iteration.
        """

        image1 = image1.contiguous()
        image2 = image2.contiguous()

        hdim = self.hidden_dim
        cdim = self.context_dim

        # run the feature network
        with autocast(enabled=self.args.mixed_precision):
            fmap1, fmap2 = self.fnet([image1, image2])

        fmap1 = fmap1.float()
        fmap2 = fmap2.float()

        if self.args.alternate_corr:
            corr_fn = AlternateCorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
        else:
            corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius)

        # run the context network
        with autocast(enabled=self.args.mixed_precision):
            cnet = self.cnet(image1)
            # split the context features into the initial hidden state and
            # the input passed to every update step
            net, inp = torch.split(cnet, [hdim, cdim], dim=1)
            net = torch.tanh(net)
            inp = torch.relu(inp)

        coords0, coords1 = self.initialize_flow(image1)

        if flow_init is not None:
            coords1 = coords1 + flow_init

        flow_predictions = []
        for itr in range(iters):
            # gradients do not flow through the coordinates between iterations
            coords1 = coords1.detach()
            corr = corr_fn(coords1)  # index correlation volume

            flow = coords1 - coords0
            with autocast(enabled=self.args.mixed_precision):
                net, up_mask, delta_flow = self.update_block(net, inp, corr, flow)

            # F(t+1) = F(t) + \Delta(t)
            coords1 = coords1 + delta_flow

            # upsample predictions
            if up_mask is None:
                flow_up = upflow8(coords1 - coords0)
            else:
                flow_up = self.upsample_flow(coords1 - coords0, up_mask)

            flow_predictions.append(flow_up)

        if test_mode:
            return coords1 - coords0, flow_up

        return flow_predictions
RAFT/update.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+
6
class FlowHead(nn.Module):
    """Two 3x3 convolutions with a ReLU in between, ending in 2 channels.

    Args:
        input_dim: Channels of the incoming feature map.
        hidden_dim: Channels of the intermediate feature map.
    """

    def __init__(self, input_dim=128, hidden_dim=256):
        super().__init__()
        self.conv1 = nn.Conv2d(input_dim, hidden_dim, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(hidden_dim, 2, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Map ``x`` (N, input_dim, H, W) to a (N, 2, H, W) tensor."""
        hidden = self.conv1(x)
        hidden = self.relu(hidden)
        return self.conv2(hidden)
15
+
16
class ConvGRU(nn.Module):
    """Convolutional GRU cell built from 3x3 convolutions.

    Args:
        hidden_dim: Channels of the hidden state ``h``.
        input_dim: Channels of the input ``x``.
    """

    def __init__(self, hidden_dim=128, input_dim=192+128):
        super().__init__()
        stacked_dim = hidden_dim + input_dim
        self.convz = nn.Conv2d(stacked_dim, hidden_dim, 3, padding=1)
        self.convr = nn.Conv2d(stacked_dim, hidden_dim, 3, padding=1)
        self.convq = nn.Conv2d(stacked_dim, hidden_dim, 3, padding=1)

    def forward(self, h, x):
        """Return the updated hidden state for state ``h`` and input ``x``."""
        stacked = torch.cat([h, x], dim=1)

        update_gate = torch.sigmoid(self.convz(stacked))
        reset_gate = torch.sigmoid(self.convr(stacked))
        # the candidate state sees the hidden state through the reset gate
        gated = torch.cat([reset_gate * h, x], dim=1)
        candidate = torch.tanh(self.convq(gated))

        return (1 - update_gate) * h + update_gate * candidate
32
+
33
class SepConvGRU(nn.Module):
    """GRU cell using separable convolutions: a 1x5 pass, then a 5x1 pass.

    Args:
        hidden_dim: Channels of the hidden state ``h``.
        input_dim: Channels of the input ``x``.
    """

    def __init__(self, hidden_dim=128, input_dim=192+128):
        super().__init__()
        stacked_dim = hidden_dim + input_dim

        # horizontal (1x5) gates
        self.convz1 = nn.Conv2d(stacked_dim, hidden_dim, (1, 5), padding=(0, 2))
        self.convr1 = nn.Conv2d(stacked_dim, hidden_dim, (1, 5), padding=(0, 2))
        self.convq1 = nn.Conv2d(stacked_dim, hidden_dim, (1, 5), padding=(0, 2))

        # vertical (5x1) gates
        self.convz2 = nn.Conv2d(stacked_dim, hidden_dim, (5, 1), padding=(2, 0))
        self.convr2 = nn.Conv2d(stacked_dim, hidden_dim, (5, 1), padding=(2, 0))
        self.convq2 = nn.Conv2d(stacked_dim, hidden_dim, (5, 1), padding=(2, 0))

    @staticmethod
    def _gru_step(h, x, conv_z, conv_r, conv_q):
        """Apply one GRU update with the given gate convolutions."""
        stacked = torch.cat([h, x], dim=1)
        update_gate = torch.sigmoid(conv_z(stacked))
        reset_gate = torch.sigmoid(conv_r(stacked))
        candidate = torch.tanh(conv_q(torch.cat([reset_gate * h, x], dim=1)))
        return (1 - update_gate) * h + update_gate * candidate

    def forward(self, h, x):
        """Return the hidden state after the horizontal and vertical updates."""
        h = self._gru_step(h, x, self.convz1, self.convr1, self.convq1)
        h = self._gru_step(h, x, self.convz2, self.convr2, self.convq2)
        return h
61
+
62
class SmallMotionEncoder(nn.Module):
    """Encode correlation features and the current flow into motion features.

    ``args`` must provide ``corr_levels`` and ``corr_radius``; the number of
    correlation channels is ``corr_levels * (2 * corr_radius + 1) ** 2``.
    The output has 82 channels: 80 encoded channels plus the 2 raw flow
    channels appended at the end.
    """

    def __init__(self, args):
        super().__init__()
        window = 2 * args.corr_radius + 1
        cor_planes = args.corr_levels * window ** 2
        self.convc1 = nn.Conv2d(cor_planes, 96, 1, padding=0)
        self.convf1 = nn.Conv2d(2, 64, 7, padding=3)
        self.convf2 = nn.Conv2d(64, 32, 3, padding=1)
        self.conv = nn.Conv2d(128, 80, 3, padding=1)

    def forward(self, flow, corr):
        """Return motion features for ``flow`` (N, 2, H, W) and ``corr``."""
        corr_feat = F.relu(self.convc1(corr))
        flow_feat = F.relu(self.convf2(F.relu(self.convf1(flow))))
        fused = F.relu(self.conv(torch.cat([corr_feat, flow_feat], dim=1)))
        # keep the raw flow available to the recurrent unit
        return torch.cat([fused, flow], dim=1)
78
+
79
class BasicMotionEncoder(nn.Module):
    """Encode correlation features and the current flow into motion features.

    ``args`` must provide ``corr_levels`` and ``corr_radius``; the number of
    correlation channels is ``corr_levels * (2 * corr_radius + 1) ** 2``.
    The output has 128 channels: 126 encoded channels plus the 2 raw flow
    channels appended at the end.
    """

    def __init__(self, args):
        super().__init__()
        window = 2 * args.corr_radius + 1
        cor_planes = args.corr_levels * window ** 2
        self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0)
        self.convc2 = nn.Conv2d(256, 192, 3, padding=1)
        self.convf1 = nn.Conv2d(2, 128, 7, padding=3)
        self.convf2 = nn.Conv2d(128, 64, 3, padding=1)
        # 2 channels are reserved for the raw flow concatenated in forward()
        self.conv = nn.Conv2d(64+192, 128-2, 3, padding=1)

    def forward(self, flow, corr):
        """Return motion features for ``flow`` (N, 2, H, W) and ``corr``."""
        corr_feat = F.relu(self.convc1(corr))
        corr_feat = F.relu(self.convc2(corr_feat))
        flow_feat = F.relu(self.convf1(flow))
        flow_feat = F.relu(self.convf2(flow_feat))

        fused = F.relu(self.conv(torch.cat([corr_feat, flow_feat], dim=1)))
        return torch.cat([fused, flow], dim=1)
98
+
99
class SmallUpdateBlock(nn.Module):
    """Recurrent update operator of the small model variant.

    Combines a ``SmallMotionEncoder``, a ``ConvGRU`` and a ``FlowHead``.
    It produces no upsampling mask, so the second return value of
    ``forward`` is always ``None``.
    """

    def __init__(self, args, hidden_dim=96):
        super().__init__()
        self.encoder = SmallMotionEncoder(args)
        # 82 motion-feature channels + 64 channels expected from ``inp``
        self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82+64)
        self.flow_head = FlowHead(hidden_dim, hidden_dim=128)

    def forward(self, net, inp, corr, flow):
        """Run one refinement step.

        Returns:
            ``(net, None, delta_flow)``: the new hidden state, a ``None``
            placeholder for the upsampling mask, and the flow increment.
        """
        motion = self.encoder(flow, corr)
        gru_input = torch.cat([inp, motion], dim=1)
        net = self.gru(net, gru_input)
        delta_flow = self.flow_head(net)

        return net, None, delta_flow
113
+
114
class BasicUpdateBlock(nn.Module):
    """Recurrent update operator of the full-size model variant.

    Combines a ``BasicMotionEncoder``, a ``SepConvGRU``, a ``FlowHead`` and
    a mask head that predicts ``64 * 9`` channels from the hidden state.
    NOTE(review): the mask is presumably the set of convex-upsampling
    weights consumed by ``RAFT.upsample_flow`` -- confirm there.

    Args:
        args: Configuration forwarded to ``BasicMotionEncoder`` (needs
            ``corr_levels`` and ``corr_radius``).
        hidden_dim: Channels of the GRU hidden state.
        input_dim: Unused; kept for interface compatibility.
    """

    def __init__(self, args, hidden_dim=128, input_dim=128):
        super(BasicUpdateBlock, self).__init__()
        self.args = args
        self.encoder = BasicMotionEncoder(args)
        # 128 motion-feature channels + the context input, which is
        # assumed here to have ``hidden_dim`` channels
        self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+hidden_dim)
        self.flow_head = FlowHead(hidden_dim, hidden_dim=256)

        # The mask head reads the GRU hidden state, so its input width must
        # follow ``hidden_dim`` (previously hard-coded to 128, which only
        # worked for the default).
        self.mask = nn.Sequential(
            nn.Conv2d(hidden_dim, 256, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 64*9, 1, padding=0))

    def forward(self, net, inp, corr, flow, upsample=True):
        """Run one refinement step.

        ``upsample`` is accepted for interface compatibility but is not
        consulted; the mask is always computed.

        Returns:
            ``(net, mask, delta_flow)``: the new hidden state, the scaled
            mask-head output, and the flow increment.
        """
        motion_features = self.encoder(flow, corr)
        inp = torch.cat([inp, motion_features], dim=1)

        net = self.gru(net, inp)
        delta_flow = self.flow_head(net)

        # scale mask to balance gradients
        mask = .25 * self.mask(net)
        return net, mask, delta_flow
137
+
138
+
139
+
RAFT/utils/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .flow_viz import flow_to_image
2
+ from .frame_utils import writeFlow
RAFT/utils/augmentor.py ADDED
@@ -0,0 +1,246 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import random
3
+ import math
4
+ from PIL import Image
5
+
6
+ import cv2
7
+ cv2.setNumThreads(0)
8
+ cv2.ocl.setUseOpenCL(False)
9
+
10
+ import torch
11
+ from torchvision.transforms import ColorJitter
12
+ import torch.nn.functional as F
13
+
14
+
15
class FlowAugmentor:
    """Random augmentation for an image pair with a dense flow field.

    Applies, in order: photometric jitter, an occlusion "eraser" on the
    second image, and a spatial transform (rescale, flips, random crop)
    applied consistently to both images and the flow.

    Args:
        crop_size: ``(height, width)`` of the output crop.
        min_scale: Lower bound of the log2 scale factor.
        max_scale: Upper bound of the log2 scale factor.
        do_flip: Whether random horizontal/vertical flips are enabled.
    """

    def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=True):

        # spatial augmentation params
        self.crop_size = crop_size
        self.min_scale = min_scale
        self.max_scale = max_scale
        self.spatial_aug_prob = 0.8
        self.stretch_prob = 0.8
        self.max_stretch = 0.2

        # flip augmentation params
        self.do_flip = do_flip
        self.h_flip_prob = 0.5
        self.v_flip_prob = 0.1

        # photometric augmentation params
        self.photo_aug = ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.5/3.14)
        self.asymmetric_color_aug_prob = 0.2
        self.eraser_aug_prob = 0.5

    def color_transform(self, img1, img2):
        """ Photometric augmentation """

        # asymmetric: each image gets its own jitter call
        if np.random.rand() < self.asymmetric_color_aug_prob:
            img1 = np.array(self.photo_aug(Image.fromarray(img1)), dtype=np.uint8)
            img2 = np.array(self.photo_aug(Image.fromarray(img2)), dtype=np.uint8)

        # symmetric: stack vertically so one jitter call covers both images
        else:
            image_stack = np.concatenate([img1, img2], axis=0)
            image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8)
            img1, img2 = np.split(image_stack, 2, axis=0)

        return img1, img2

    def eraser_transform(self, img1, img2, bounds=(50, 100)):
        """ Occlusion augmentation

        With probability ``eraser_aug_prob`` paints one or two random
        rectangles of ``img2`` (in place) with the mean color of ``img2``.
        Rectangle sides are drawn from ``[bounds[0], bounds[1])``.
        """

        ht, wd = img1.shape[:2]
        if np.random.rand() < self.eraser_aug_prob:
            mean_color = np.mean(img2.reshape(-1, 3), axis=0)
            for _ in range(np.random.randint(1, 3)):
                x0 = np.random.randint(0, wd)
                y0 = np.random.randint(0, ht)
                dx = np.random.randint(bounds[0], bounds[1])
                dy = np.random.randint(bounds[0], bounds[1])
                img2[y0:y0+dy, x0:x0+dx, :] = mean_color

        return img1, img2

    def spatial_transform(self, img1, img2, flow):
        """Randomly rescale, flip and crop both images and the flow.

        Returns:
            ``(img1, img2, flow)`` cropped to ``crop_size``.
        """
        # randomly sample scale
        ht, wd = img1.shape[:2]
        # smallest scale that keeps the resized image at least 8 pixels
        # larger than the crop in both dimensions
        min_scale = np.maximum(
            (self.crop_size[0] + 8) / float(ht),
            (self.crop_size[1] + 8) / float(wd))

        scale = 2 ** np.random.uniform(self.min_scale, self.max_scale)
        scale_x = scale
        scale_y = scale
        if np.random.rand() < self.stretch_prob:
            scale_x *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)
            scale_y *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)

        scale_x = np.clip(scale_x, min_scale, None)
        scale_y = np.clip(scale_y, min_scale, None)

        if np.random.rand() < self.spatial_aug_prob:
            # rescale the images
            img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
            img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
            flow = cv2.resize(flow, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
            # flow vectors are measured in pixels, so they scale with the image
            flow = flow * [scale_x, scale_y]

        if self.do_flip:
            if np.random.rand() < self.h_flip_prob: # h-flip
                img1 = img1[:, ::-1]
                img2 = img2[:, ::-1]
                flow = flow[:, ::-1] * [-1.0, 1.0]

            if np.random.rand() < self.v_flip_prob: # v-flip
                img1 = img1[::-1, :]
                img2 = img2[::-1, :]
                flow = flow[::-1, :] * [1.0, -1.0]

        # randint's upper bound is exclusive: "+ 1" makes the last valid
        # offset reachable and keeps exact-crop-size inputs from raising.
        y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0] + 1)
        x0 = np.random.randint(0, img1.shape[1] - self.crop_size[1] + 1)

        img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
        img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
        flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]

        return img1, img2, flow

    def __call__(self, img1, img2, flow):
        """Apply the full augmentation pipeline and return contiguous arrays."""
        img1, img2 = self.color_transform(img1, img2)
        img1, img2 = self.eraser_transform(img1, img2)
        img1, img2, flow = self.spatial_transform(img1, img2, flow)

        # flips produce negative-stride views; make them contiguous
        img1 = np.ascontiguousarray(img1)
        img2 = np.ascontiguousarray(img2)
        flow = np.ascontiguousarray(flow)

        return img1, img2, flow
121
+
122
class SparseFlowAugmentor:
    """Random augmentation for an image pair with a sparse flow field.

    Same pipeline as the dense augmentor, but the flow comes with a
    validity map and is rescaled by scattering valid vectors instead of
    interpolating.

    Args:
        crop_size: ``(height, width)`` of the output crop.
        min_scale: Lower bound of the log2 scale factor.
        max_scale: Upper bound of the log2 scale factor.
        do_flip: Whether random horizontal flips are enabled.
    """

    def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=False):
        # spatial augmentation params
        self.crop_size = crop_size
        self.min_scale = min_scale
        self.max_scale = max_scale
        self.spatial_aug_prob = 0.8
        self.stretch_prob = 0.8
        self.max_stretch = 0.2

        # flip augmentation params
        self.do_flip = do_flip
        self.h_flip_prob = 0.5
        self.v_flip_prob = 0.1

        # photometric augmentation params
        self.photo_aug = ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3/3.14)
        self.asymmetric_color_aug_prob = 0.2
        self.eraser_aug_prob = 0.5

    def color_transform(self, img1, img2):
        """Apply one shared color jitter to both images."""
        # stack vertically so a single jitter call covers both images
        image_stack = np.concatenate([img1, img2], axis=0)
        image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8)
        img1, img2 = np.split(image_stack, 2, axis=0)
        return img1, img2

    def eraser_transform(self, img1, img2):
        """With probability ``eraser_aug_prob`` paint one or two random
        rectangles of ``img2`` (in place) with the mean color of ``img2``."""
        ht, wd = img1.shape[:2]
        if np.random.rand() < self.eraser_aug_prob:
            mean_color = np.mean(img2.reshape(-1, 3), axis=0)
            for _ in range(np.random.randint(1, 3)):
                x0 = np.random.randint(0, wd)
                y0 = np.random.randint(0, ht)
                dx = np.random.randint(50, 100)
                dy = np.random.randint(50, 100)
                img2[y0:y0+dy, x0:x0+dx, :] = mean_color

        return img1, img2

    def resize_sparse_flow_map(self, flow, valid, fx=1.0, fy=1.0):
        """Rescale a sparse flow map by scattering its valid vectors.

        Every pixel with ``valid >= 1`` is moved to its scaled, rounded
        position and its flow vector is scaled by ``(fx, fy)``.

        Returns:
            ``(flow_img, valid_img)``: a float32 ``[ht1, wd1, 2]`` flow
            map (zero where nothing landed) and an int32 ``[ht1, wd1]``
            validity map.
        """
        ht, wd = flow.shape[:2]
        coords = np.meshgrid(np.arange(wd), np.arange(ht))
        coords = np.stack(coords, axis=-1)

        coords = coords.reshape(-1, 2).astype(np.float32)
        flow = flow.reshape(-1, 2).astype(np.float32)
        valid = valid.reshape(-1).astype(np.float32)

        coords0 = coords[valid>=1]
        flow0 = flow[valid>=1]

        ht1 = int(round(ht * fy))
        wd1 = int(round(wd * fx))

        coords1 = coords0 * [fx, fy]
        flow1 = flow0 * [fx, fy]

        xx = np.round(coords1[:,0]).astype(np.int32)
        yy = np.round(coords1[:,1]).astype(np.int32)

        # Index 0 is a legitimate row/column, so the lower bound is
        # inclusive (a strict "> 0" dropped every sample on the first
        # row and first column).
        v = (xx >= 0) & (xx < wd1) & (yy >= 0) & (yy < ht1)
        xx = xx[v]
        yy = yy[v]
        flow1 = flow1[v]

        flow_img = np.zeros([ht1, wd1, 2], dtype=np.float32)
        valid_img = np.zeros([ht1, wd1], dtype=np.int32)

        flow_img[yy, xx] = flow1
        valid_img[yy, xx] = 1

        return flow_img, valid_img

    def spatial_transform(self, img1, img2, flow, valid):
        """Randomly rescale, flip and crop the images, flow and validity map.

        Returns:
            ``(img1, img2, flow, valid)`` cropped to ``crop_size``.
        """
        # randomly sample scale

        ht, wd = img1.shape[:2]
        # smallest scale that keeps the resized image larger than the crop
        min_scale = np.maximum(
            (self.crop_size[0] + 1) / float(ht),
            (self.crop_size[1] + 1) / float(wd))

        scale = 2 ** np.random.uniform(self.min_scale, self.max_scale)
        scale_x = np.clip(scale, min_scale, None)
        scale_y = np.clip(scale, min_scale, None)

        if np.random.rand() < self.spatial_aug_prob:
            # rescale the images
            img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
            img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
            flow, valid = self.resize_sparse_flow_map(flow, valid, fx=scale_x, fy=scale_y)

        if self.do_flip:
            # use the configured probability rather than a literal 0.5
            if np.random.rand() < self.h_flip_prob: # h-flip
                img1 = img1[:, ::-1]
                img2 = img2[:, ::-1]
                flow = flow[:, ::-1] * [-1.0, 1.0]
                valid = valid[:, ::-1]

        margin_y = 20
        margin_x = 50

        # Offsets are sampled from a range widened by the margins and then
        # clipped back into the valid range, so offsets at the clipped
        # borders are drawn more often.
        y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0] + margin_y)
        x0 = np.random.randint(-margin_x, img1.shape[1] - self.crop_size[1] + margin_x)

        y0 = np.clip(y0, 0, img1.shape[0] - self.crop_size[0])
        x0 = np.clip(x0, 0, img1.shape[1] - self.crop_size[1])

        img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
        img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
        flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
        valid = valid[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
        return img1, img2, flow, valid

    def __call__(self, img1, img2, flow, valid):
        """Apply the full augmentation pipeline and return contiguous arrays."""
        img1, img2 = self.color_transform(img1, img2)
        img1, img2 = self.eraser_transform(img1, img2)
        img1, img2, flow, valid = self.spatial_transform(img1, img2, flow, valid)

        # flips produce negative-stride views; make them contiguous
        img1 = np.ascontiguousarray(img1)
        img2 = np.ascontiguousarray(img2)
        flow = np.ascontiguousarray(flow)
        valid = np.ascontiguousarray(valid)

        return img1, img2, flow, valid
RAFT/utils/flow_viz.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization
2
+
3
+
4
+ # MIT License
5
+ #
6
+ # Copyright (c) 2018 Tom Runia
7
+ #
8
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
9
+ # of this software and associated documentation files (the "Software"), to deal
10
+ # in the Software without restriction, including without limitation the rights
11
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12
+ # copies of the Software, and to permit persons to whom the Software is
13
+ # furnished to do so, subject to conditions.
14
+ #
15
+ # Author: Tom Runia
16
+ # Date Created: 2018-08-03
17
+
18
+ import numpy as np
19
+
20
def make_colorwheel():
    """
    Generates a color wheel for optical flow visualization as presented in:
        Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
        URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf

    Code follows the original C++ source code of Daniel Scharstein.
    Code follows the Matlab source code of Deqing Sun.

    Returns:
        np.ndarray: Color wheel of shape [55, 3], one RGB row (0..255) per hue bin
    """
    # One entry per hue transition:
    # (number of bins, channel held at 255, ramping channel, ramp goes up?)
    transitions = (
        (15, 0, 1, True),    # RY
        (6, 1, 0, False),    # YG
        (4, 1, 2, True),     # GC
        (11, 2, 1, False),   # CB
        (13, 2, 0, True),    # BM
        (6, 0, 2, False),    # MR
    )

    ncols = sum(bins for bins, _, _, _ in transitions)
    colorwheel = np.zeros((ncols, 3))

    start = 0
    for bins, held, ramped, rising in transitions:
        ramp = np.floor(255 * np.arange(0, bins) / bins)
        rows = slice(start, start + bins)
        colorwheel[rows, held] = 255
        colorwheel[rows, ramped] = ramp if rising else 255 - ramp
        start += bins
    return colorwheel
68
+
69
+
70
def flow_uv_to_colors(u, v, convert_to_bgr=False):
    """
    Applies the flow color wheel to (possibly clipped) flow components u and v.

    According to the C++ source code of Daniel Scharstein
    According to the Matlab source code of Deqing Sun

    The flow direction selects (by linear interpolation between the two
    neighbouring wheel entries) a hue; the flow magnitude fades the color
    towards white for magnitudes <= 1 and darkens it by 0.75 otherwise.

    Args:
        u (np.ndarray): Input horizontal flow of shape [H,W]
        v (np.ndarray): Input vertical flow of shape [H,W]
        convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.

    Returns:
        np.ndarray: Flow visualization image of shape [H,W,3], dtype uint8
    """
    flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8)
    wheel = make_colorwheel()  # shape [55x3]
    n_bins = wheel.shape[0]

    magnitude = np.sqrt(np.square(u) + np.square(v))
    angle = np.arctan2(-v, -u) / np.pi          # in [-1, 1]
    position = (angle + 1) / 2 * (n_bins - 1)   # fractional wheel index
    lower = np.floor(position).astype(np.int32)
    upper = lower + 1
    upper[upper == n_bins] = 0                  # wrap around the wheel
    weight = position - lower

    in_range = magnitude <= 1
    out_of_range = ~in_range

    for channel in range(wheel.shape[1]):
        column = wheel[:, channel]
        low_color = column[lower] / 255.0
        high_color = column[upper] / 255.0
        shade = (1 - weight) * low_color + weight * high_color
        shade[in_range] = 1 - magnitude[in_range] * (1 - shade[in_range])
        shade[out_of_range] = shade[out_of_range] * 0.75   # out of range
        # Note the 2-channel => BGR instead of RGB
        target = 2 - channel if convert_to_bgr else channel
        flow_image[:, :, target] = np.floor(255 * shade)
    return flow_image
107
+
108
+
109
def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False):
    """
    Convert a two-channel flow field of shape [H,W,2] into a color image.

    The flow is optionally clipped, normalized by its largest magnitude and
    then mapped to colors with ``flow_uv_to_colors``.

    Args:
        flow_uv (np.ndarray): Flow UV image of shape [H,W,2]
        clip_flow (float, optional): Clip each flow component to
            ``[-clip_flow, clip_flow]``. Defaults to None (no clipping).
        convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.

    Returns:
        np.ndarray: Flow visualization image of shape [H,W,3]
    """
    assert flow_uv.ndim == 3, 'input flow must have three dimensions'
    assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]'
    if clip_flow is not None:
        # Flow components are signed: clip symmetrically. A lower bound of
        # 0 would erase every leftward/upward motion.
        flow_uv = np.clip(flow_uv, -clip_flow, clip_flow)
    u = flow_uv[:,:,0]
    v = flow_uv[:,:,1]
    rad = np.sqrt(np.square(u) + np.square(v))
    rad_max = np.max(rad)
    # epsilon avoids a division by zero for an all-zero flow field
    epsilon = 1e-5
    u = u / (rad_max + epsilon)
    v = v / (rad_max + epsilon)
    return flow_uv_to_colors(u, v, convert_to_bgr)
RAFT/utils/flow_viz_pt.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Flow visualization code adapted from https://github.com/tomrunia/OpticalFlow_Visualization
2
+ import torch
3
+ torch.pi = torch.acos(torch.zeros(1)).item() * 2 # which is 3.1415927410125732
4
+
5
@torch.no_grad()
def flow_to_image(flow: torch.Tensor) -> torch.Tensor:
    """
    Converts a flow to an RGB image.

    The flow is divided by its largest per-pixel magnitude (over the whole
    batch) before being colorized.

    Args:
        flow (Tensor): Flow of shape (N, 2, H, W) or (2, H, W) and dtype torch.float.

    Returns:
        img (Tensor): Image Tensor of dtype uint8 where each color corresponds
            to a given flow direction. Shape is (N, 3, H, W) or (3, H, W) depending on the input.

    Raises:
        ValueError: If the dtype is not torch.float or the shape is unsupported.
    """
    if flow.dtype != torch.float:
        raise ValueError(f"Flow should be of dtype torch.float, got {flow.dtype}.")

    orig_shape = flow.shape
    unbatched = len(orig_shape) == 3
    if unbatched:
        flow = flow.unsqueeze(0)  # Add batch dim

    has_valid_layout = flow.ndim == 4 and flow.shape[1] == 2
    if not has_valid_layout:
        raise ValueError(f"Input flow should have shape (2, H, W) or (N, 2, H, W), got {orig_shape}.")

    magnitudes = torch.sum(flow**2, dim=1).sqrt()
    # eps keeps the division finite for an all-zero flow
    scale = magnitudes.max() + torch.finfo((flow).dtype).eps
    img = _normalized_flow_to_image(flow / scale)

    return img[0] if unbatched else img  # Remove batch dim when it was added
37
+
38
@torch.no_grad()
def _normalized_flow_to_image(normalized_flow: torch.Tensor) -> torch.Tensor:
    """
    Converts a batch of normalized flow to an RGB image.

    The flow direction picks a hue by linear interpolation between the two
    neighbouring color-wheel entries; the flow magnitude fades that hue
    towards white.

    Args:
        normalized_flow (torch.Tensor): Normalized flow tensor of shape (N, 2, H, W)
    Returns:
        img (Tensor(N, 3, H, W)): Flow visualization image of dtype uint8.
    """
    batch, _, height, width = normalized_flow.shape
    device = normalized_flow.device
    flow_image = torch.zeros((batch, 3, height, width), dtype=torch.uint8, device=device)

    wheel = _make_colorwheel().to(device)  # shape [55x3]
    num_cols = wheel.shape[0]

    flow_u = normalized_flow[:, 0, :, :]
    flow_v = normalized_flow[:, 1, :, :]
    norm = torch.sum(normalized_flow**2, dim=1).sqrt()
    angle = torch.atan2(-flow_v, -flow_u) / torch.pi   # in [-1, 1]
    position = (angle + 1) / 2 * (num_cols - 1)         # fractional wheel index
    lower = torch.floor(position).to(torch.long)
    upper = lower + 1
    upper = upper.masked_fill(upper == num_cols, 0)     # wrap around the wheel
    weight = position - lower

    for channel in range(wheel.shape[1]):
        column = wheel[:, channel]
        low_color = column[lower] / 255.0
        high_color = column[upper] / 255.0
        shade = (1 - weight) * low_color + weight * high_color
        shade = 1 - norm * (1 - shade)
        flow_image[:, channel, :, :] = torch.floor(255. * shade)
    return flow_image
71
+
72
+
73
@torch.no_grad()
def _make_colorwheel() -> torch.Tensor:
    """
    Generates a color wheel for optical flow visualization as presented in:
    Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
    URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf.

    Returns:
        colorwheel (Tensor[55, 3]): Colorwheel Tensor.
    """
    # One entry per hue transition:
    # (number of bins, channel held at 255, ramping channel, ramp goes up?)
    transitions = (
        (15, 0, 1, True),    # RY
        (6, 1, 0, False),    # YG
        (4, 1, 2, True),     # GC
        (11, 2, 1, False),   # CB
        (13, 2, 0, True),    # BM
        (6, 0, 2, False),    # MR
    )

    ncols = sum(bins for bins, _, _, _ in transitions)
    colorwheel = torch.zeros((ncols, 3))

    start = 0
    for bins, held, ramped, rising in transitions:
        ramp = torch.floor(255. * torch.arange(0., bins) / bins)
        stop = start + bins
        colorwheel[start:stop, held] = 255
        colorwheel[start:stop, ramped] = ramp if rising else 255 - ramp
        start = stop
    return colorwheel
RAFT/utils/frame_utils.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from PIL import Image
3
+ from os.path import *
4
+ import re
5
+
6
+ import cv2
7
+ cv2.setNumThreads(0)
8
+ cv2.ocl.setUseOpenCL(False)
9
+
10
+ TAG_CHAR = np.array([202021.25], np.float32)
11
+
12
def readFlow(fn):
    """ Read .flo file in Middlebury format

    Args:
        fn: Path of the ``.flo`` file.

    Returns:
        A float32 array of shape ``(h, w, 2)``, or ``None`` (after printing
        a message) when the file is empty or its magic number is wrong.
    """
    # Code adapted from:
    # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy

    # WARNING: this will work on little-endian architectures (eg Intel x86) only!
    with open(fn, 'rb') as f:
        magic = np.fromfile(f, np.float32, count=1)
        # ``magic`` is empty for a zero-length file; check the size first
        # instead of relying on the truth value of an array comparison.
        if magic.size != 1 or magic[0] != 202021.25:
            print('Magic number incorrect. Invalid .flo file')
            return None

        # Index explicitly: implicit conversion of a size-1 array to a
        # Python scalar is deprecated in recent NumPy releases.
        w = int(np.fromfile(f, np.int32, count=1)[0])
        h = int(np.fromfile(f, np.int32, count=1)[0])
        data = np.fromfile(f, np.float32, count=2 * w * h)
        # Reshape data into 3D array (rows, columns, bands).
        # NOTE: np.resize (kept from the original) does not raise on a size
        # mismatch, so a truncated file is not rejected here.
        return np.resize(data, (h, w, 2))
32
+
33
def readPFM(file):
    """Read a PFM image file.

    The header consists of three text lines: ``PF`` (3 channels) or ``Pf``
    (1 channel), ``<width> <height>``, and a scale whose sign encodes the
    byte order (negative means little-endian). The float data that follows
    is reshaped and flipped vertically.

    Args:
        file: Path of the PFM file.

    Returns:
        Array of shape ``(height, width, 3)`` for ``PF`` files or
        ``(height, width)`` for ``Pf`` files.

    Raises:
        Exception: If the type line or the dimension line is malformed.
    """
    # The context manager closes the handle on every path, including the
    # error paths (the handle used to be left open).
    with open(file, 'rb') as pfm:
        header = pfm.readline().rstrip()
        if header == b'PF':
            color = True
        elif header == b'Pf':
            color = False
        else:
            raise Exception('Not a PFM file.')

        dim_match = re.match(rb'^(\d+)\s(\d+)\s$', pfm.readline())
        if dim_match:
            width, height = map(int, dim_match.groups())
        else:
            raise Exception('Malformed PFM header.')

        # only the sign of the scale is used: it selects the byte order
        scale = float(pfm.readline().rstrip())
        endian = '<' if scale < 0 else '>'

        data = np.fromfile(pfm, endian + 'f')

    shape = (height, width, 3) if color else (height, width)
    data = np.reshape(data, shape)
    # the result is returned with the row order reversed
    return np.flipud(data)
69
+
70
def writeFlow(filename,uv,v=None):
    """ Write optical flow to file.

    If v is None, uv is assumed to contain both u and v channels,
    stacked in depth.
    Original code by Deqing Sun, adapted from Daniel Scharstein.

    The file layout is: the ``TAG_CHAR`` magic number, width and height as
    int32, then the float32 flow with u and v interleaved per pixel.
    """
    nBands = 2

    if v is None:
        assert(uv.ndim == 3)
        assert(uv.shape[2] == 2)
        u = uv[:,:,0]
        v = uv[:,:,1]
    else:
        u = uv

    assert(u.shape == v.shape)
    height,width = u.shape

    # arrange into matrix form: even columns hold u, odd columns hold v
    tmp = np.zeros((height, width*nBands))
    tmp[:, 0::2] = u
    tmp[:, 1::2] = v

    # the context manager closes the file even if a write fails
    with open(filename,'wb') as f:
        # write the header
        f.write(TAG_CHAR)
        np.array(width).astype(np.int32).tofile(f)
        np.array(height).astype(np.int32).tofile(f)
        tmp.astype(np.float32).tofile(f)
100
+
101
+
102
def readFlowKITTI(filename):
    """Read a flow image stored as a 3-channel, 16-bit-capable image file.

    After reversing OpenCV's channel order, the first two channels hold the
    flow encoded as ``value * 64 + 2**15`` and the third channel is
    returned as the validity map.

    Returns:
        ``(flow, valid)``: float32 arrays of shape ``[H, W, 2]`` and ``[H, W]``.
    """
    raw = cv2.imread(filename, cv2.IMREAD_ANYDEPTH | cv2.IMREAD_COLOR)
    channels = raw[:, :, ::-1].astype(np.float32)
    valid = channels[:, :, 2]
    flow = (channels[:, :, :2] - 2**15) / 64.0
    return flow, valid
108
+
109
def readDispKITTI(filename):
    """Read a disparity image and express it as a horizontal flow field.

    The stored values are divided by 256; the flow is ``(-disparity, 0)``
    and pixels with a positive disparity are marked valid.

    Returns:
        ``(flow, valid)``: an ``[H, W, 2]`` flow array and a boolean
        ``[H, W]`` validity mask.
    """
    disparity = cv2.imread(filename, cv2.IMREAD_ANYDEPTH) / 256.0
    horizontal = -disparity
    vertical = np.zeros_like(disparity)
    flow = np.stack([horizontal, vertical], -1)
    valid = disparity > 0.0
    return flow, valid
114
+
115
+
116
def writeFlowKITTI(filename, uv):
    """Write a flow field as a 3-channel uint16 image.

    The flow is encoded as ``value * 64 + 2**15``, an all-ones validity
    channel is appended, and the channel order is reversed for OpenCV
    before writing.
    """
    encoded = 64.0 * uv + 2**15
    all_valid = np.ones([encoded.shape[0], encoded.shape[1], 1])
    packed = np.concatenate([encoded, all_valid], axis=-1).astype(np.uint16)
    cv2.imwrite(filename, packed[..., ::-1])
121
+
122
+
123
def read_gen(file_name, pil=False):
    """Load a file with the reader selected by its extension.

    * ``.png`` / ``.jpeg`` / ``.ppm`` / ``.jpg``: opened with PIL.
    * ``.bin`` / ``.raw``: loaded with ``np.load``.
    * ``.flo``: read with ``readFlow`` and cast to float32.
    * ``.pfm``: read with ``readPFM``; 2-D data is returned as is, 3-D data
      with its last channel dropped.

    ``pil`` is accepted but not used. Unknown extensions yield ``[]``.
    """
    ext = splitext(file_name)[-1]

    if ext in ('.png', '.jpeg', '.ppm', '.jpg'):
        return Image.open(file_name)

    if ext in ('.bin', '.raw'):
        return np.load(file_name)

    if ext == '.flo':
        return readFlow(file_name).astype(np.float32)

    if ext == '.pfm':
        flow = readPFM(file_name).astype(np.float32)
        if len(flow.shape) == 2:
            return flow
        return flow[:, :, :-1]

    return []
RAFT/utils/utils.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import numpy as np
4
+ from scipy import interpolate
5
+
6
+
7
class InputPadder:
    """ Pads images such that dimensions are divisible by 8

    In ``'sintel'`` mode the padding is split between both sides of each
    spatial dimension; in any other mode the width padding is still split
    but the height padding is added at the bottom only.
    """

    def __init__(self, dims, mode='sintel'):
        self.ht, self.wd = dims[-2:]
        # amount needed to reach the next multiple of 8 (0 if already one)
        pad_ht = (-self.ht) % 8
        pad_wd = (-self.wd) % 8

        left = pad_wd // 2
        right = pad_wd - left
        if mode == 'sintel':
            top = pad_ht // 2
            bottom = pad_ht - top
        else:
            top, bottom = 0, pad_ht
        # order expected by F.pad: (left, right, top, bottom)
        self._pad = [left, right, top, bottom]

    def pad(self, *inputs):
        """Return the inputs padded by edge replication, as a list."""
        padded = []
        for tensor in inputs:
            padded.append(F.pad(tensor, self._pad, mode='replicate'))
        return padded

    def unpad(self, x):
        """Crop away the padding added by ``pad`` from the last two dims."""
        ht, wd = x.shape[-2:]
        left, right, top, bottom = self._pad
        return x[..., top:ht - bottom, left:wd - right]
25
+
26
def forward_interpolate(flow):
    """Forward-warp a flow field onto the regular pixel grid.

    Each flow vector is moved to the position it points at; vectors whose
    target falls outside the open range ``(0, wd) x (0, ht)`` are dropped.
    The remaining vectors are resampled on the pixel grid with
    nearest-neighbour interpolation.

    Args:
        flow: Tensor of shape (2, H, W) holding the x and y displacements.

    Returns:
        Float tensor of shape (2, H, W) on the CPU.
    """
    flow_np = flow.detach().cpu().numpy()
    dx, dy = flow_np[0], flow_np[1]

    ht, wd = dx.shape
    grid_x, grid_y = np.meshgrid(np.arange(wd), np.arange(ht))

    # where every source pixel lands
    target_x = (grid_x + dx).reshape(-1)
    target_y = (grid_y + dy).reshape(-1)
    dx = dx.reshape(-1)
    dy = dy.reshape(-1)

    inside = (target_x > 0) & (target_x < wd) & (target_y > 0) & (target_y < ht)
    points = (target_x[inside], target_y[inside])

    flow_x = interpolate.griddata(
        points, dx[inside], (grid_x, grid_y), method='nearest', fill_value=0)

    flow_y = interpolate.griddata(
        points, dy[inside], (grid_x, grid_y), method='nearest', fill_value=0)

    warped = np.stack([flow_x, flow_y], axis=0)
    return torch.from_numpy(warped).float()
55
+
56
+
57
def bilinear_sampler(img, coords, mode='bilinear', mask=False):
    """ Wrapper for grid_sample, uses pixel coordinates

    Args:
        img: Tensor whose last two dimensions are (H, W), in a layout
            accepted by ``F.grid_sample``.
        coords: Sampling positions in pixel units; the last dimension holds
            ``(x, y)``.
        mode: Interpolation mode forwarded to ``F.grid_sample``.
        mask: When true, also return a float mask that is 1 where the
            normalized position lies strictly inside ``(-1, 1)`` in both
            axes.

    Returns:
        The sampled tensor, or ``(sampled, mask)`` when ``mask`` is true.
    """
    H, W = img.shape[-2:]
    xgrid, ygrid = coords.split([1,1], dim=-1)
    # map pixel coordinates to the [-1, 1] range used with align_corners=True
    xgrid = 2*xgrid/(W-1) - 1
    ygrid = 2*ygrid/(H-1) - 1

    grid = torch.cat([xgrid, ygrid], dim=-1)
    # honour the requested interpolation mode (it used to be ignored, so
    # every call silently used grid_sample's default)
    img = F.grid_sample(img, grid, mode=mode, align_corners=True)

    if mask:
        mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
        return img, mask.float()

    return img
72
+
73
+
74
def coords_grid(batch, ht, wd):
    """Return a (batch, 2, ht, wd) float grid of pixel coordinates.

    Channel 0 holds the x (column) index and channel 1 the y (row) index.
    """
    rows, cols = torch.meshgrid(torch.arange(ht), torch.arange(wd))
    grid = torch.stack((cols, rows), dim=0).float()
    return grid.unsqueeze(0).repeat(batch, 1, 1, 1)
78
+
79
+
80
def upflow8(flow, mode='bilinear'):
    """Upsample a flow field by 8x in both spatial dimensions.

    The values are multiplied by 8 as well, because displacements are
    measured in pixels of the (now 8x finer) grid.
    """
    height, width = flow.shape[2], flow.shape[3]
    upsampled = F.interpolate(flow, size=(8 * height, 8 * width), mode=mode, align_corners=True)
    return 8 * upsampled
configs/train_flowcomp.json ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "seed": 2023,
3
+ "save_dir": "experiments_model/",
4
+ "train_data_loader": {
5
+ "name": "youtube-vos",
6
+ "video_root": "your_video_root",
7
+ "flow_root": "your_flow_root",
8
+ "w": 432,
9
+ "h": 240,
10
+ "num_local_frames": 10,
11
+ "num_ref_frames": 1,
12
+ "load_flow": 0
13
+ },
14
+ "losses": {
15
+ "flow_weight": 0.25
16
+ },
17
+ "model": {
18
+ "net": "recurrent_flow_completion"
19
+ },
20
+ "trainer": {
21
+ "version": "trainer_flow_w_edge",
22
+ "type": "Adam",
23
+ "beta1": 0,
24
+ "beta2": 0.99,
25
+ "lr": 5e-5,
26
+ "batch_size": 8,
27
+ "num_workers": 4,
28
+ "num_prefetch_queue": 4,
29
+ "log_freq": 100,
30
+ "save_freq": 5e3,
31
+ "iterations": 700e3,
32
+ "scheduler": {
33
+ "type": "MultiStepLR",
34
+ "milestones": [
35
+ 300e3, 400e3, 500e3, 600e3
36
+ ],
37
+ "gamma": 0.2
38
+ }
39
+ }
40
+ }
configs/train_propainter.json ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "seed": 2023,
3
+ "save_dir": "experiments_model/",
4
+ "train_data_loader": {
5
+ "name": "youtube-vos",
6
+ "video_root": "your_video_root",
7
+ "flow_root": "your_flow_root",
8
+ "w": 432,
9
+ "h": 240,
10
+ "num_local_frames": 10,
11
+ "num_ref_frames": 6,
12
+ "load_flow": 0
13
+ },
14
+ "losses": {
15
+ "hole_weight": 1,
16
+ "valid_weight": 1,
17
+ "flow_weight": 1,
18
+ "adversarial_weight": 0.01,
19
+ "GAN_LOSS": "hinge",
20
+ "perceptual_weight": 0
21
+ },
22
+ "model": {
23
+ "net": "propainter",
24
+ "no_dis": 0,
25
+ "load_d": 1,
26
+ "interp_mode": "nearest"
27
+ },
28
+ "trainer": {
29
+ "version": "trainer",
30
+ "type": "Adam",
31
+ "beta1": 0,
32
+ "beta2": 0.99,
33
+ "lr": 1e-4,
34
+ "batch_size": 8,
35
+ "num_workers": 8,
36
+ "num_prefetch_queue": 8,
37
+ "log_freq": 100,
38
+ "save_freq": 1e4,
39
+ "iterations": 700e3,
40
+ "scheduler": {
41
+ "type": "MultiStepLR",
42
+ "milestones": [
43
+ 400e3
44
+ ],
45
+ "gamma": 0.1
46
+ }
47
+ }
48
+ }
core/dataset.py ADDED
@@ -0,0 +1,232 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import random
4
+
5
+ import cv2
6
+ from PIL import Image
7
+ import numpy as np
8
+
9
+ import torch
10
+ import torchvision.transforms as transforms
11
+
12
+ from utils.file_client import FileClient
13
+ from utils.img_util import imfrombytes
14
+ from utils.flow_util import resize_flow, flowread
15
+ from core.utils import (create_random_shape_with_random_motion, Stack,
16
+ ToTorchFormatTensor, GroupRandomHorizontalFlip,GroupRandomHorizontalFlowFlip)
17
+
18
+
19
class TrainDataset(torch.utils.data.Dataset):
    """Training dataset yielding masked video clips (and optionally flows).

    Every item is built from ``num_local_frames`` consecutive frames plus
    ``num_ref_frames`` frames sampled from the remainder of the same video.
    """

    def __init__(self, args: dict):
        self.args = args
        self.video_root = args['video_root']
        self.flow_root = args['flow_root']
        self.num_local_frames = args['num_local_frames']
        self.num_ref_frames = args['num_ref_frames']
        self.w, self.h = args['w'], args['h']
        self.size = (self.w, self.h)

        self.load_flow = args['load_flow']
        if self.load_flow:
            assert os.path.exists(self.flow_root)

        json_path = os.path.join('./datasets', args['name'], 'train.json')
        with open(json_path, 'r') as f:
            self.video_train_dict = json.load(f)

        # Only keep videos with more frames than one sample needs.
        self.video_dict = {}
        self.frame_dict = {}
        required = self.num_local_frames + self.num_ref_frames
        for name in sorted(list(self.video_train_dict.keys())):
            frame_files = sorted(os.listdir(os.path.join(self.video_root, name)))
            if len(frame_files) > required:
                self.video_dict[name] = len(frame_files)
                self.frame_dict[name] = frame_files
        self.video_names = list(self.video_dict.keys())

        self._to_tensors = transforms.Compose([
            Stack(),
            ToTorchFormatTensor(),
        ])
        self.file_client = FileClient('disk')

    def __len__(self):
        return len(self.video_names)

    def _sample_index(self, length, sample_length, num_ref_frame=3):
        """Pick a consecutive local window plus sorted random reference indices."""
        all_indices = list(range(length))
        start = random.randint(0, length - sample_length)
        local_indices = all_indices[start:start + sample_length]
        candidates = list(set(all_indices) - set(local_indices))
        ref_indices = random.sample(candidates, num_ref_frame)
        ref_indices.sort()
        return local_indices + ref_indices

    def __getitem__(self, index):
        """Return (frames, masks, flows_f, flows_b, video_name) for one video.

        When ``load_flow`` is falsy the two flow slots hold the string 'None'.
        """
        video_name = self.video_names[index]
        frame_list = self.frame_dict[video_name]

        # one generated mask per frame of the video
        all_masks = create_random_shape_with_random_motion(
            self.video_dict[video_name], imageHeight=self.h, imageWidth=self.w)

        selected_index = self._sample_index(self.video_dict[video_name],
                                            self.num_local_frames,
                                            self.num_ref_frames)

        frames, masks = [], []
        flows_f, flows_b = [], []
        for idx in selected_index:
            img_path = os.path.join(self.video_root, video_name, frame_list[idx])
            img_bytes = self.file_client.get(img_path, 'img')
            img = imfrombytes(img_bytes, float32=False)
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            img = cv2.resize(img, self.size, interpolation=cv2.INTER_LINEAR)
            frames.append(Image.fromarray(img))
            masks.append(all_masks[idx])

            # Flows exist only between consecutive local frames, so the last
            # local frame and the reference frames are skipped.
            if self.load_flow and len(frames) <= self.num_local_frames - 1:
                current_n = frame_list[idx][:-4]
                next_n = frame_list[idx + 1][:-4]
                flow_f_path = os.path.join(self.flow_root, video_name, f'{current_n}_{next_n}_f.flo')
                flow_b_path = os.path.join(self.flow_root, video_name, f'{next_n}_{current_n}_b.flo')
                flow_f = resize_flow(flowread(flow_f_path, quantize=False), self.h, self.w)
                flow_b = resize_flow(flowread(flow_b_path, quantize=False), self.h, self.w)
                flows_f.append(flow_f)
                flows_b.append(flow_b)

            # Once the local window is complete, reverse it in time with
            # probability 0.5; forward and backward flows then swap roles.
            if len(frames) == self.num_local_frames and random.random() < 0.5:
                frames.reverse()
                masks.reverse()
                if self.load_flow:
                    flows_f.reverse()
                    flows_b.reverse()
                    flows_f, flows_b = flows_b, flows_f

        if self.load_flow:
            frames, flows_f, flows_b = GroupRandomHorizontalFlowFlip()(frames, flows_f, flows_b)
        else:
            frames = GroupRandomHorizontalFlip()(frames)

        # normalize, to tensors: img [-1,1] mask [0,1]
        frame_tensors = self._to_tensors(frames) * 2.0 - 1.0
        mask_tensors = self._to_tensors(masks)

        if not self.load_flow:
            return frame_tensors, mask_tensors, 'None', 'None', video_name

        # stacked as H W 2 T-1, then permuted to T-1 2 H W
        flows_f = np.stack(flows_f, axis=-1)
        flows_b = np.stack(flows_b, axis=-1)
        flows_f = torch.from_numpy(flows_f).permute(3, 2, 0, 1).contiguous().float()
        flows_b = torch.from_numpy(flows_b).permute(3, 2, 0, 1).contiguous().float()
        return frame_tensors, mask_tensors, flows_f, flows_b, video_name
139
+
140
+
141
class TestDataset(torch.utils.data.Dataset):
    """Evaluation dataset returning every frame of a video with its masks.

    The video list is taken from the sub-directories of ``mask_root``.
    """

    def __init__(self, args):
        self.args = args
        self.size = args['size']
        self.w, self.h = self.size

        self.video_root = args['video_root']
        self.mask_root = args['mask_root']
        self.flow_root = args['flow_root']

        self.load_flow = args['load_flow']
        if self.load_flow:
            assert os.path.exists(self.flow_root)
        self.video_names = sorted(os.listdir(self.mask_root))

        self.video_dict = {}
        self.frame_dict = {}
        for name in self.video_names:
            frame_files = sorted(os.listdir(os.path.join(self.video_root, name)))
            self.video_dict[name] = len(frame_files)
            self.frame_dict[name] = frame_files

        self._to_tensors = transforms.Compose([
            Stack(),
            ToTorchFormatTensor(),
        ])
        self.file_client = FileClient('disk')

    def __len__(self):
        return len(self.video_names)

    def __getitem__(self, index):
        """Load all frames, dilated masks and (optionally) flows of one video.

        Returns a 6-tuple ending with the raw uint8 frames when ``load_flow``
        is truthy, otherwise a 5-tuple whose flow slots hold the string 'None'.
        """
        video_name = self.video_names[index]
        frame_list = self.frame_dict[video_name]
        selected_index = list(range(self.video_dict[video_name]))
        last_position = len(selected_index) - 1

        frames, masks = [], []
        flows_f, flows_b = [], []
        for idx in selected_index:
            frame_path = os.path.join(self.video_root, video_name, frame_list[idx])
            img_bytes = self.file_client.get(frame_path, 'input')
            img = imfrombytes(img_bytes, float32=False)
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            img = cv2.resize(img, self.size, interpolation=cv2.INTER_LINEAR)
            frames.append(Image.fromarray(img))

            mask_path = os.path.join(self.mask_root, video_name, str(idx).zfill(5) + '.png')
            mask = Image.open(mask_path).resize(self.size, Image.NEAREST).convert('L')

            # origin: 0 indicates missing. now: 1 indicates missing
            binary = np.array(np.asarray(mask) > 0).astype(np.uint8)
            kernel = cv2.getStructuringElement(cv2.MORPH_CROSS, (3, 3))
            binary = cv2.dilate(binary, kernel, iterations=4)
            masks.append(Image.fromarray(binary * 255))

            # every frame except the last one has a following frame to pair with
            if self.load_flow and len(frames) <= last_position:
                current_n = frame_list[idx][:-4]
                next_n = frame_list[idx + 1][:-4]
                flow_f_path = os.path.join(self.flow_root, video_name, f'{current_n}_{next_n}_f.flo')
                flow_b_path = os.path.join(self.flow_root, video_name, f'{next_n}_{current_n}_b.flo')
                flow_f = resize_flow(flowread(flow_f_path, quantize=False), self.h, self.w)
                flow_b = resize_flow(flowread(flow_b_path, quantize=False), self.h, self.w)
                flows_f.append(flow_f)
                flows_b.append(flow_b)

        # normalize, to tensors
        frames_PIL = [np.array(f).astype(np.uint8) for f in frames]
        frame_tensors = self._to_tensors(frames) * 2.0 - 1.0
        mask_tensors = self._to_tensors(masks)

        if not self.load_flow:
            return frame_tensors, mask_tensors, 'None', 'None', video_name

        # stacked as H W 2 T-1, then permuted to T-1 2 H W
        flows_f = np.stack(flows_f, axis=-1)
        flows_b = np.stack(flows_b, axis=-1)
        flows_f = torch.from_numpy(flows_f).permute(3, 2, 0, 1).contiguous().float()
        flows_b = torch.from_numpy(flows_b).permute(3, 2, 0, 1).contiguous().float()
        return frame_tensors, mask_tensors, flows_f, flows_b, video_name, frames_PIL
core/dist.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+
4
+
5
def get_world_size():
    """Read the world size from the environment without calling MPI.

    Checks ``PMI_SIZE`` first, then ``OMPI_COMM_WORLD_SIZE``; an empty value
    counts as 1. Falls back to the number of visible CUDA devices.

    :rtype: int
    """
    for key in ('PMI_SIZE', 'OMPI_COMM_WORLD_SIZE'):
        value = os.environ.get(key)
        if value is not None:
            return int(value or 1)
    return torch.cuda.device_count()
15
+
16
+
17
def get_global_rank():
    """Read the global rank from the environment without calling MPI.

    Checks ``PMI_RANK`` first, then ``OMPI_COMM_WORLD_RANK``; an empty value
    or the absence of both variables yields 0.

    :rtype: int
    """
    for key in ('PMI_RANK', 'OMPI_COMM_WORLD_RANK'):
        value = os.environ.get(key)
        if value is not None:
            return int(value or 0)
    return 0
27
+
28
+
29
def get_local_rank():
    """Read the local rank from the environment without calling MPI.

    Checks ``MPI_LOCALRANKID`` first, then ``OMPI_COMM_WORLD_LOCAL_RANK``;
    an empty value or the absence of both variables yields 0.

    :rtype: int
    """
    for key in ('MPI_LOCALRANKID', 'OMPI_COMM_WORLD_LOCAL_RANK'):
        value = os.environ.get(key)
        if value is not None:
            return int(value or 0)
    return 0
39
+
40
+
41
def get_master_ip():
    """Return the master node address taken from the environment.

    ``AZ_BATCH_MASTER_NODE`` is used first (only the part before the first
    ':'), then ``AZ_BATCHAI_MPI_MASTER_NODE`` as is, else "127.0.0.1".
    """
    batch_master = os.environ.get('AZ_BATCH_MASTER_NODE')
    if batch_master is not None:
        return batch_master.split(':')[0]

    batchai_master = os.environ.get('AZ_BATCHAI_MPI_MASTER_NODE')
    if batchai_master is not None:
        return batchai_master

    return "127.0.0.1"
core/loss.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import lpips
4
+ from model.vgg_arch import VGGFeatureExtractor
5
+
6
class PerceptualLoss(nn.Module):
    """Perceptual loss with commonly used style loss.

    Args:
        layer_weights (dict): The weight for each layer of vgg feature.
            Here is an example: {'conv5_4': 1.}, which means the conv5_4
            feature layer (before relu5_4) will be extracted with weight
            1.0 in calculating losses.
        vgg_type (str): The type of vgg network used as feature extractor.
            Default: 'vgg19'.
        use_input_norm (bool): If True, normalize the input image in vgg.
            Default: True.
        range_norm (bool): If True, norm images with range [-1, 1] to [0, 1].
            Default: False.
        perceptual_weight (float): If `perceptual_weight > 0`, the perceptual
            loss will be calculated and the loss will be multiplied by the
            weight. Default: 1.0.
        style_weight (float): If `style_weight > 0`, the style loss will be
            calculated and the loss will be multiplied by the weight.
            Default: 0.
        criterion (str): Criterion used for perceptual loss, one of
            'l1' | 'l2' | 'mse' | 'fro'. Default: 'l1'.

    Raises:
        NotImplementedError: If ``criterion`` is not one of the supported names.
    """

    def __init__(self,
                 layer_weights,
                 vgg_type='vgg19',
                 use_input_norm=True,
                 range_norm=False,
                 perceptual_weight=1.0,
                 style_weight=0.,
                 criterion='l1'):
        super(PerceptualLoss, self).__init__()
        self.perceptual_weight = perceptual_weight
        self.style_weight = style_weight
        self.layer_weights = layer_weights
        self.vgg = VGGFeatureExtractor(
            layer_name_list=list(layer_weights.keys()),
            vgg_type=vgg_type,
            use_input_norm=use_input_norm,
            range_norm=range_norm)

        self.criterion_type = criterion
        self.criterion = self._make_criterion(criterion)

    @staticmethod
    def _make_criterion(criterion):
        """Map a criterion name to a loss module (None for 'fro')."""
        if criterion == 'l1':
            return torch.nn.L1Loss()
        if criterion == 'l2':
            # torch.nn has no ``L2loss`` class (the old call raised
            # AttributeError); the L2 criterion is the mean squared error.
            return torch.nn.MSELoss()
        if criterion == 'mse':
            return torch.nn.MSELoss(reduction='mean')
        if criterion == 'fro':
            # the Frobenius norm is computed inline in forward()
            return None
        raise NotImplementedError(f'{criterion} criterion has not been supported.')

    def forward(self, x, gt):
        """Forward function.

        Args:
            x (Tensor): Input tensor with shape (n, c, h, w).
            gt (Tensor): Ground-truth tensor with shape (n, c, h, w).

        Returns:
            tuple: ``(percep_loss, style_loss)``; an entry is None when its
            weight is not positive.
        """
        # extract vgg features; no gradient flows into the ground truth
        x_features = self.vgg(x)
        gt_features = self.vgg(gt.detach())

        # calculate perceptual loss
        if self.perceptual_weight > 0:
            percep_loss = 0
            for k in x_features.keys():
                if self.criterion_type == 'fro':
                    percep_loss += torch.norm(x_features[k] - gt_features[k], p='fro') * self.layer_weights[k]
                else:
                    percep_loss += self.criterion(x_features[k], gt_features[k]) * self.layer_weights[k]
            percep_loss *= self.perceptual_weight
        else:
            percep_loss = None

        # calculate style loss on the Gram matrices of the same features
        if self.style_weight > 0:
            style_loss = 0
            for k in x_features.keys():
                gram_x = self._gram_mat(x_features[k])
                gram_gt = self._gram_mat(gt_features[k])
                if self.criterion_type == 'fro':
                    style_loss += torch.norm(gram_x - gram_gt, p='fro') * self.layer_weights[k]
                else:
                    style_loss += self.criterion(gram_x, gram_gt) * self.layer_weights[k]
            style_loss *= self.style_weight
        else:
            style_loss = None

        return percep_loss, style_loss

    def _gram_mat(self, x):
        """Calculate Gram matrix.

        Args:
            x (torch.Tensor): Tensor with shape of (n, c, h, w).

        Returns:
            torch.Tensor: Gram matrix of shape (n, c, c), divided by c*h*w.
        """
        n, c, h, w = x.size()
        features = x.view(n, c, w * h)
        features_t = features.transpose(1, 2)
        gram = features.bmm(features_t) / (c * h * w)
        return gram
115
+
116
class LPIPSLoss(nn.Module):
    """Weighted LPIPS loss computed with an ``lpips.LPIPS`` (vgg) network.

    Args:
        loss_weight (float): Factor applied to the mean LPIPS value.
        use_input_norm (bool): If True, subtract ``mean`` and divide by
            ``std`` before evaluating LPIPS.
        range_norm (bool): If True, first map inputs via ``(x + 1) / 2``.
    """

    def __init__(self,
                 loss_weight=1.0,
                 use_input_norm=True,
                 range_norm=False,):
        super(LPIPSLoss, self).__init__()
        self.perceptual = lpips.LPIPS(net="vgg", spatial=False).eval()
        self.loss_weight = loss_weight
        self.use_input_norm = use_input_norm
        self.range_norm = range_norm

        if self.use_input_norm:
            # both statistics are for images with range [0, 1]
            mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
            std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
            self.register_buffer('mean', mean)
            self.register_buffer('std', std)

    def forward(self, pred, target):
        """Return ``(loss_weight * mean LPIPS, None)``."""
        if self.range_norm:
            pred, target = (pred + 1) / 2, (target + 1) / 2
        if self.use_input_norm:
            pred, target = (pred - self.mean) / self.std, (target - self.mean) / self.std
        distance = self.perceptual(target.contiguous(), pred.contiguous())
        return self.loss_weight * distance.mean(), None
142
+
143
+
144
class AdversarialLoss(nn.Module):
    r"""
    Adversarial loss
    https://arxiv.org/abs/1711.10337
    """
    def __init__(self,
                 type='nsgan',
                 target_real_label=1.0,
                 target_fake_label=0.0):
        r"""
        type = nsgan | lsgan | hinge
        """
        super(AdversarialLoss, self).__init__()
        self.type = type
        self.register_buffer('real_label', torch.tensor(target_real_label))
        self.register_buffer('fake_label', torch.tensor(target_fake_label))

        # any other type leaves ``criterion`` undefined
        if type == 'hinge':
            self.criterion = nn.ReLU()
        elif type == 'lsgan':
            self.criterion = nn.MSELoss()
        elif type == 'nsgan':
            self.criterion = nn.BCELoss()

    def __call__(self, outputs, is_real, is_disc=None):
        """Compute the loss for ``outputs``.

        For 'hinge': the discriminator branch returns
        ``relu(1 - outputs).mean()`` for real and ``relu(1 + outputs).mean()``
        for fake inputs; otherwise ``(-outputs).mean()`` is returned.
        For other types the criterion compares ``outputs`` with the real or
        fake label expanded to the shape of ``outputs``.
        """
        if self.type != 'hinge':
            target = self.real_label if is_real else self.fake_label
            return self.criterion(outputs, target.expand_as(outputs))

        if not is_disc:
            return (-outputs).mean()

        signed = -outputs if is_real else outputs
        return self.criterion(1 + signed).mean()
core/lr_scheduler.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ LR scheduler from BasicSR https://github.com/xinntao/BasicSR
3
+ """
4
+ import math
5
+ from collections import Counter
6
+ from torch.optim.lr_scheduler import _LRScheduler
7
+
8
+
9
class MultiStepRestartLR(_LRScheduler):
    """ MultiStep with restarts learning rate scheme.
    Args:
        optimizer (torch.nn.optimizer): Torch optimizer.
        milestones (list): Iterations that will decrease learning rate.
        gamma (float): Decrease ratio. Default: 0.1.
        restarts (list): Restart iterations. Default: [0].
        restart_weights (list): Restart weights at each restart iteration.
            Default: [1].
        last_epoch (int): Used in _LRScheduler. Default: -1.
    """
    def __init__(self,
                 optimizer,
                 milestones,
                 gamma=0.1,
                 restarts=(0, ),
                 restart_weights=(1, ),
                 last_epoch=-1):
        # A Counter lets a milestone listed k times apply gamma**k at once.
        self.milestones = Counter(milestones)
        self.gamma = gamma
        self.restarts = restarts
        self.restart_weights = restart_weights
        assert len(self.restarts) == len(
            self.restart_weights), 'restarts and their weights do not match.'
        super(MultiStepRestartLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        step = self.last_epoch
        groups = self.optimizer.param_groups

        # restart: initial lr scaled by the matching restart weight
        if step in self.restarts:
            weight = self.restart_weights[self.restarts.index(step)]
            return [group['initial_lr'] * weight for group in groups]

        # milestone: decay the current lr
        if step in self.milestones:
            return [group['lr'] * self.gamma**self.milestones[step] for group in groups]

        # otherwise the lr is unchanged
        return [group['lr'] for group in groups]
48
+
49
+
50
def get_position_from_periods(iteration, cumulative_period):
    """Get the position from a period list.
    It returns the index of the first entry of the period list that is
    greater than or equal to ``iteration``.
    For example, the cumulative_period = [100, 200, 300, 400],
    if iteration == 50, return 0;
    if iteration == 210, return 2;
    if iteration == 300, return 2.
    Args:
        iteration (int): Current iteration.
        cumulative_period (list[int]): Cumulative period list.
    Returns:
        int: The matching index, or None when ``iteration`` exceeds every
        entry of the list.
    """
    matches = (position
               for position, boundary in enumerate(cumulative_period)
               if iteration <= boundary)
    return next(matches, None)
66
+
67
+
68
class CosineAnnealingRestartLR(_LRScheduler):
    """ Cosine annealing with restarts learning rate scheme.
    An example of config:
    periods = [10, 10, 10, 10]
    restart_weights = [1, 0.5, 0.5, 0.5]
    eta_min=1e-7
    It has four cycles, each has 10 iterations. At 10th, 20th, 30th, the
    scheduler will restart with the weights in restart_weights.
    Args:
        optimizer (torch.nn.optimizer): Torch optimizer.
        periods (list): Period for each cosine annealing cycle.
        restart_weights (list): Restart weights at each restart iteration.
            Default: [1].
        eta_min (float): The minimum lr. Default: 1e-7.
        last_epoch (int): Used in _LRScheduler. Default: -1.
    """
    def __init__(self,
                 optimizer,
                 periods,
                 restart_weights=(1, ),
                 eta_min=1e-7,
                 last_epoch=-1):
        self.periods = periods
        self.restart_weights = restart_weights
        self.eta_min = eta_min
        assert (len(self.periods) == len(self.restart_weights)
                ), 'periods and restart_weights should have the same length.'

        # running totals: the iteration at which each cycle ends
        self.cumulative_period = []
        elapsed = 0
        for period in self.periods:
            elapsed += period
            self.cumulative_period.append(elapsed)
        super(CosineAnnealingRestartLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        cycle = get_position_from_periods(self.last_epoch,
                                          self.cumulative_period)
        current_weight = self.restart_weights[cycle]
        current_period = self.periods[cycle]
        if cycle == 0:
            nearest_restart = 0
        else:
            nearest_restart = self.cumulative_period[cycle - 1]

        # progress through the current cycle, mapped onto half a cosine wave
        progress = (self.last_epoch - nearest_restart) / current_period
        cosine = 1 + math.cos(math.pi * progress)

        return [
            self.eta_min + current_weight * 0.5 * (base_lr - self.eta_min) * cosine
            for base_lr in self.base_lrs
        ]
core/metrics.py ADDED
@@ -0,0 +1,569 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from skimage import measure
3
+ from scipy import linalg
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+
9
+ from core.utils import to_tensors
10
+
11
+
12
def calculate_epe(flow1, flow2):
    """Calculate the mean end point error between two flows.

    The error is the Euclidean norm of the difference taken over dim 1,
    averaged over all remaining elements and returned as a Python float.
    """
    squared_diff = (flow1 - flow2)**2
    per_point = torch.sum(squared_diff, dim=1).sqrt()
    return per_point.view(-1).mean().item()
18
+
19
+
20
def calculate_psnr(img1, img2):
    """Calculate PSNR (Peak Signal-to-Noise Ratio).
    Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio
    Args:
        img1 (ndarray): Images with range [0, 255].
        img2 (ndarray): Images with range [0, 255].
    Returns:
        float: psnr result (``inf`` when the images are identical).
    """

    assert img1.shape == img2.shape, \
        (f'Image shapes are differnet: {img1.shape}, {img2.shape}.')

    mse = np.mean((img1 - img2)**2)
    # identical images: avoid dividing by zero
    return float('inf') if mse == 0 else 20. * np.log10(255. / np.sqrt(mse))
37
+
38
+
39
def calc_psnr_and_ssim(img1, img2):
    """Calculate PSNR and SSIM for images.
    img1: ndarray, range [0, 255]
    img2: ndarray, range [0, 255]
    Returns a ``(psnr, ssim)`` tuple; both inputs are cast to float64 first.
    """
    first = img1.astype(np.float64)
    second = img2.astype(np.float64)

    psnr = calculate_psnr(first, second)
    # NOTE(review): `measure.compare_ssim` may be unavailable in recent
    # scikit-image releases - confirm against the pinned version.
    ssim_kwargs = dict(data_range=255, multichannel=True, win_size=65)
    ssim = measure.compare_ssim(first, second, **ssim_kwargs)

    return psnr, ssim
55
+
56
+
57
+ ###########################
58
+ # I3D models
59
+ ###########################
60
+
61
+
62
def init_i3d_model(i3d_model_path):
    """Build an InceptionI3d model, load its weights and move it to cuda:0."""
    print(f"[Loading I3D model from {i3d_model_path} for FID score ..]")
    model = InceptionI3d(400, in_channels=3, final_endpoint='Logits')
    state_dict = torch.load(i3d_model_path)
    model.load_state_dict(state_dict)
    target_device = torch.device('cuda:0')
    model.to(target_device)
    return model
68
+
69
+
70
def calculate_i3d_activations(video1, video2, i3d_model, device):
    """Compute flattened I3D activations for two videos (used for VFID).
    video1: list[PIL.Image]
    video2: list[PIL.Image]
    Returns a tuple of two 1-D numpy arrays, one per video.
    """
    activations = []
    for video in (video1, video2):
        # convert to a tensor, add a batch dimension and move to `device`
        batched = to_tensors()(video).unsqueeze(0).to(device)
        feat = get_i3d_activations(batched, i3d_model)
        activations.append(feat.cpu().numpy().flatten())

    video1_activations, video2_activations = activations
    return video1_activations, video2_activations
83
+
84
+
85
def calculate_vfid(real_activations, fake_activations):
    """
    Given two distributions of features, compute the FID score between them.
    Params:
        real_activations: list[ndarray]
        fake_activations: list[ndarray]
    """
    # per-feature mean and covariance (rows are samples) of each set
    mean_real = np.mean(real_activations, axis=0)
    mean_fake = np.mean(fake_activations, axis=0)
    cov_real = np.cov(real_activations, rowvar=False)
    cov_fake = np.cov(fake_activations, rowvar=False)
    return calculate_frechet_distance(mean_real, cov_real, mean_fake, cov_fake)
97
+
98
+
99
def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
    """Numpy implementation of the Frechet Distance.
    The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
    and X_2 ~ N(mu_2, C_2) is
            d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
    Stable version by Dougal J. Sutherland.
    Params:
    -- mu1   : Mean vector of the first distribution.
    -- sigma1: Covariance matrix of the first distribution.
    -- mu2   : Mean vector of the second distribution.
    -- sigma2: Covariance matrix of the second distribution.
    -- eps   : Value added to the diagonal of both covariances when the
               matrix square root of their product is not finite.
    Returns:
    --   : The Frechet Distance.
    """

    mu1, mu2 = np.atleast_1d(mu1), np.atleast_1d(mu2)
    sigma1, sigma2 = np.atleast_2d(sigma1), np.atleast_2d(sigma2)

    assert mu1.shape == mu2.shape, \
        'Training and test mean vectors have different lengths'
    assert sigma1.shape == sigma2.shape, \
        'Training and test covariances have different dimensions'

    mean_gap = mu1 - mu2

    # Product might be almost singular
    covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
    if not np.isfinite(covmean).all():
        msg = ('fid calculation produces singular product; '
               'adding %s to diagonal of cov estimates') % eps
        print(msg)
        jitter = np.eye(sigma1.shape[0]) * eps
        covmean = linalg.sqrtm((sigma1 + jitter).dot(sigma2 + jitter))

    # Numerical error might give slight imaginary component
    if np.iscomplexobj(covmean):
        diagonal_is_real = np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3)
        if not diagonal_is_real:
            m = np.max(np.abs(covmean.imag))
            raise ValueError('Imaginary component {}'.format(m))
        covmean = covmean.real

    tr_covmean = np.trace(covmean)

    return (mean_gap.dot(mean_gap) + np.trace(sigma1) +  # NOQA
            np.trace(sigma2) - 2 * tr_covmean)
151
+
152
+
153
def get_i3d_activations(batched_video,
                        i3d_model,
                        target_endpoint='Logits',
                        flatten=True,
                        grad_enabled=False):
    """
    Get features from i3d model and optionally flatten them to 1d features
    per batch element. Dims 1 and 2 of ``batched_video`` are swapped before
    the model is called.

    Valid target endpoints are defined in InceptionI3d.VALID_ENDPOINTS
        VALID_ENDPOINTS = (
            'Conv3d_1a_7x7',
            'MaxPool3d_2a_3x3',
            'Conv3d_2b_1x1',
            'Conv3d_2c_3x3',
            'MaxPool3d_3a_3x3',
            'Mixed_3b',
            'Mixed_3c',
            'MaxPool3d_4a_3x3',
            'Mixed_4b',
            'Mixed_4c',
            'Mixed_4d',
            'Mixed_4e',
            'Mixed_4f',
            'MaxPool3d_5a_2x2',
            'Mixed_5b',
            'Mixed_5c',
            'Logits',
            'Predictions',
        )
    """
    model_input = batched_video.transpose(1, 2)
    with torch.set_grad_enabled(grad_enabled):
        feat = i3d_model.extract_features(model_input, target_endpoint)
        if not flatten:
            return feat
        return feat.view(feat.size(0), -1)
189
+
190
+
191
+ # This code is from https://github.com/piergiaj/pytorch-i3d/blob/master/pytorch_i3d.py
192
+ # I only fix flake8 errors and do some cleaning here
193
+
194
+
195
class MaxPool3dSamePadding(nn.MaxPool3d):
    """3D max pooling that pads its input based on the input size."""

    def compute_pad(self, dim, s):
        """Total padding along ``dim`` for an input of size ``s``."""
        remainder = s % self.stride[dim]
        covered = self.stride[dim] if remainder == 0 else remainder
        return max(self.kernel_size[dim] - covered, 0)

    def forward(self, x):
        # compute 'same' padding
        _, _, t, h, w = x.size()

        # F.pad expects the last dimension first: (w, h, t), each given as
        # a (front, back) pair with any odd remainder going to the back.
        pad = []
        for dim, size in ((2, w), (1, h), (0, t)):
            total = self.compute_pad(dim, size)
            front = total // 2
            pad.extend((front, total - front))

        x = F.pad(x, tuple(pad))
        return super(MaxPool3dSamePadding, self).forward(x)
219
+
220
+
221
class Unit3D(nn.Module):
    """3D convolution with size-dependent padding, optional batch norm and
    optional activation."""

    def __init__(self,
                 in_channels,
                 output_channels,
                 kernel_shape=(1, 1, 1),
                 stride=(1, 1, 1),
                 padding=0,
                 activation_fn=F.relu,
                 use_batch_norm=True,
                 use_bias=False,
                 name='unit_3d'):
        """Initializes Unit3D module."""
        super(Unit3D, self).__init__()

        self.name = name
        self.padding = padding
        self._output_channels = output_channels
        self._kernel_shape = kernel_shape
        self._stride = stride
        self._use_bias = use_bias
        self._use_batch_norm = use_batch_norm
        self._activation_fn = activation_fn

        # The convolution itself never pads; forward() pads the input
        # dynamically based on its size.
        self.conv3d = nn.Conv3d(in_channels=in_channels,
                                out_channels=self._output_channels,
                                kernel_size=self._kernel_shape,
                                stride=self._stride,
                                padding=0,
                                bias=self._use_bias)

        if self._use_batch_norm:
            self.bn = nn.BatchNorm3d(self._output_channels,
                                     eps=0.001,
                                     momentum=0.01)

    def compute_pad(self, dim, s):
        """Total padding along ``dim`` for an input of size ``s``."""
        remainder = s % self._stride[dim]
        covered = self._stride[dim] if remainder == 0 else remainder
        return max(self._kernel_shape[dim] - covered, 0)

    def forward(self, x):
        # compute 'same' padding
        _, _, t, h, w = x.size()

        # F.pad expects the last dimension first: (w, h, t), each given as
        # a (front, back) pair with any odd remainder going to the back.
        pad = []
        for dim, size in ((2, w), (1, h), (0, t)):
            total = self.compute_pad(dim, size)
            front = total // 2
            pad.extend((front, total - front))
        x = F.pad(x, tuple(pad))

        x = self.conv3d(x)
        if self._use_batch_norm:
            x = self.bn(x)
        if self._activation_fn is not None:
            x = self._activation_fn(x)
        return x
287
+
288
+
289
class InceptionModule(nn.Module):
    """Four-branch 3D Inception block whose branch outputs are concatenated on the channel axis."""

    def __init__(self, in_channels, out_channels, name):
        """Create the four branches.

        Args:
            in_channels: Number of channels of the block input.
            out_channels: Six channel counts, indexed as
                [b0 1x1, b1 1x1 reduce, b1 3x3, b2 1x1 reduce, b2 3x3, b3 1x1].
            name: Prefix used to build the name passed to every Unit3D.
        """
        super().__init__()
        self.name = name

        one = [1, 1, 1]
        three = [3, 3, 3]

        # Branch 0: single 1x1x1 convolution.
        self.b0 = Unit3D(in_channels=in_channels,
                         output_channels=out_channels[0],
                         kernel_shape=one,
                         padding=0,
                         name=name + '/Branch_0/Conv3d_0a_1x1')

        # Branch 1: 1x1x1 reduction followed by a 3x3x3 convolution.
        self.b1a = Unit3D(in_channels=in_channels,
                          output_channels=out_channels[1],
                          kernel_shape=one,
                          padding=0,
                          name=name + '/Branch_1/Conv3d_0a_1x1')
        self.b1b = Unit3D(in_channels=out_channels[1],
                          output_channels=out_channels[2],
                          kernel_shape=three,
                          name=name + '/Branch_1/Conv3d_0b_3x3')

        # Branch 2: same layout as branch 1 with its own channel counts.
        self.b2a = Unit3D(in_channels=in_channels,
                          output_channels=out_channels[3],
                          kernel_shape=one,
                          padding=0,
                          name=name + '/Branch_2/Conv3d_0a_1x1')
        self.b2b = Unit3D(in_channels=out_channels[3],
                          output_channels=out_channels[4],
                          kernel_shape=three,
                          name=name + '/Branch_2/Conv3d_0b_3x3')

        # Branch 3: stride-1 max pooling followed by a 1x1x1 convolution.
        self.b3a = MaxPool3dSamePadding(kernel_size=three,
                                        stride=(1, 1, 1),
                                        padding=0)
        self.b3b = Unit3D(in_channels=in_channels,
                          output_channels=out_channels[5],
                          kernel_shape=one,
                          padding=0,
                          name=name + '/Branch_3/Conv3d_0b_1x1')

    def forward(self, x):
        """Run all branches on ``x`` and concatenate their outputs along dim 1."""
        branches = [
            self.b0(x),
            self.b1b(self.b1a(x)),
            self.b2b(self.b2a(x)),
            self.b3b(self.b3a(x)),
        ]
        return torch.cat(branches, dim=1)
332
+
333
+
334
class InceptionI3d(nn.Module):
    """Inception-v1 I3D architecture.
    The model is introduced in:
        Quo Vadis, Action Recognition? A New Model and the Kinetics Dataset
        Joao Carreira, Andrew Zisserman
        https://arxiv.org/pdf/1705.07750v1.pdf.
    See also the Inception architecture, introduced in:
        Going deeper with convolutions
        Christian Szegedy, Wei Liu, Yangqing Jia, Pierre Sermanet, Scott Reed,
        Dragomir Anguelov, Dumitru Erhan, Vincent Vanhoucke, Andrew Rabinovich.
        http://arxiv.org/pdf/1409.4842v1.pdf.
    """

    # Endpoints of the model in order. The network is built up to (and
    # including) the designated `final_endpoint`.
    VALID_ENDPOINTS = (
        'Conv3d_1a_7x7',
        'MaxPool3d_2a_3x3',
        'Conv3d_2b_1x1',
        'Conv3d_2c_3x3',
        'MaxPool3d_3a_3x3',
        'Mixed_3b',
        'Mixed_3c',
        'MaxPool3d_4a_3x3',
        'Mixed_4b',
        'Mixed_4c',
        'Mixed_4d',
        'Mixed_4e',
        'Mixed_4f',
        'MaxPool3d_5a_2x2',
        'Mixed_5b',
        'Mixed_5c',
        'Logits',
        'Predictions',
    )

    def __init__(self,
                 num_classes=400,
                 spatial_squeeze=True,
                 final_endpoint='Logits',
                 name='inception_i3d',
                 in_channels=3,
                 dropout_keep_prob=0.5):
        """Initializes I3D model instance.

        Args:
            num_classes: The number of outputs in the logit layer (default 400).
            spatial_squeeze: Whether to squeeze the spatial dimensions for the
                logits before returning (default True).
            final_endpoint: Last endpoint the model is built up to. Must be one
                of InceptionI3d.VALID_ENDPOINTS (default 'Logits'). When it is
                a backbone endpoint, no pooling/dropout/logits head is created,
                so only ``extract_features`` is usable, not ``forward``.
            name: A string (optional). Prefix of the names given to sub-units.
            in_channels: Number of channels of the input clip (default 3).
            dropout_keep_prob: Value handed to ``nn.Dropout`` (default 0.5).

        Raises:
            ValueError: if `final_endpoint` is not recognized.
        """
        if final_endpoint not in self.VALID_ENDPOINTS:
            raise ValueError('Unknown final endpoint %s' % final_endpoint)

        super(InceptionI3d, self).__init__()
        self._num_classes = num_classes
        self._spatial_squeeze = spatial_squeeze
        self._final_endpoint = final_endpoint
        self.logits = None

        # Lazy factories: each stage is only instantiated when the loop below
        # reaches it, so a truncated model creates nothing past its endpoint.
        def conv(in_ch, out_ch, kernel, **extra):
            return lambda end_point: Unit3D(in_channels=in_ch,
                                            output_channels=out_ch,
                                            kernel_shape=kernel,
                                            name=name + end_point,
                                            **extra)

        def pool(kernel, stride):
            return lambda end_point: MaxPool3dSamePadding(
                kernel_size=kernel, stride=stride, padding=0)

        def mixed(in_ch, out_chs):
            return lambda end_point: InceptionModule(
                in_ch, out_chs, name + end_point)

        # Each Mixed block's input width is the sum of the previous block's
        # four branch outputs (indices 0, 2, 4, 5 of its channel list).
        backbone = (
            ('Conv3d_1a_7x7', conv(in_channels, 64, [7, 7, 7],
                                   stride=(2, 2, 2), padding=(3, 3, 3))),
            ('MaxPool3d_2a_3x3', pool([1, 3, 3], (1, 2, 2))),
            ('Conv3d_2b_1x1', conv(64, 64, [1, 1, 1], padding=0)),
            ('Conv3d_2c_3x3', conv(64, 192, [3, 3, 3], padding=1)),
            ('MaxPool3d_3a_3x3', pool([1, 3, 3], (1, 2, 2))),
            ('Mixed_3b', mixed(192, [64, 96, 128, 16, 32, 32])),
            ('Mixed_3c', mixed(256, [128, 128, 192, 32, 96, 64])),
            ('MaxPool3d_4a_3x3', pool([3, 3, 3], (2, 2, 2))),
            ('Mixed_4b', mixed(128 + 192 + 96 + 64,
                               [192, 96, 208, 16, 48, 64])),
            ('Mixed_4c', mixed(192 + 208 + 48 + 64,
                               [160, 112, 224, 24, 64, 64])),
            ('Mixed_4d', mixed(160 + 224 + 64 + 64,
                               [128, 128, 256, 24, 64, 64])),
            ('Mixed_4e', mixed(128 + 256 + 64 + 64,
                               [112, 144, 288, 32, 64, 64])),
            ('Mixed_4f', mixed(112 + 288 + 64 + 64,
                               [256, 160, 320, 32, 128, 128])),
            ('MaxPool3d_5a_2x2', pool([2, 2, 2], (2, 2, 2))),
            ('Mixed_5b', mixed(256 + 320 + 128 + 128,
                               [256, 160, 320, 32, 128, 128])),
            ('Mixed_5c', mixed(256 + 320 + 128 + 128,
                               [384, 192, 384, 48, 128, 128])),
        )

        self.end_points = {}
        for end_point, make in backbone:
            self.end_points[end_point] = make(end_point)
            if self._final_endpoint == end_point:
                # Register the truncated backbone as sub-modules; without this
                # it has no parameters and cannot be looked up in _modules.
                self.build()
                return

        # 'Logits' / 'Predictions': add the classification head.
        self.avg_pool = nn.AvgPool3d(kernel_size=[2, 7, 7], stride=(1, 1, 1))
        # NOTE(review): nn.Dropout takes the probability of *dropping* a unit,
        # while the argument is named keep-probability. The two coincide at the
        # default 0.5; left unchanged so existing callers keep their behavior.
        self.dropout = nn.Dropout(dropout_keep_prob)
        self.logits = Unit3D(in_channels=384 + 384 + 128 + 128,
                             output_channels=self._num_classes,
                             kernel_shape=[1, 1, 1],
                             padding=0,
                             activation_fn=None,
                             use_batch_norm=False,
                             use_bias=True,
                             name='logits')

        self.build()

    def replace_logits(self, num_classes):
        """Replace the logits layer with a fresh one producing ``num_classes`` outputs."""
        self._num_classes = num_classes
        self.logits = Unit3D(in_channels=384 + 384 + 128 + 128,
                             output_channels=self._num_classes,
                             kernel_shape=[1, 1, 1],
                             padding=0,
                             activation_fn=None,
                             use_batch_norm=False,
                             use_bias=True,
                             name='logits')

    def build(self):
        """Register every entry of ``self.end_points`` as a named sub-module."""
        for k in self.end_points.keys():
            self.add_module(k, self.end_points[k])

    def forward(self, x):
        """Run the backbone and the logits head on ``x``.

        Returns:
            The logits; dimension 3 is squeezed twice when ``spatial_squeeze``
            was set, otherwise the head output is returned as is.
        """
        for end_point in self.VALID_ENDPOINTS:
            if end_point in self.end_points:
                x = self._modules[end_point](
                    x)  # use _modules to work with dataparallel

        x = self.logits(self.dropout(self.avg_pool(x)))
        if self._spatial_squeeze:
            # logits is batch X time X classes, which is what we want to work with
            x = x.squeeze(3).squeeze(3)
        return x

    def extract_features(self, x, target_endpoint='Logits'):
        """Run the backbone up to ``target_endpoint`` and return its features.

        For the default ``'Logits'`` the whole built backbone is run and the
        result is averaged over dimensions 4, 3 and 2; for any other endpoint
        the raw output of that endpoint is returned.
        """
        for end_point in self.VALID_ENDPOINTS:
            if end_point in self.end_points:
                x = self._modules[end_point](x)
                if end_point == target_endpoint:
                    break
        if target_endpoint == 'Logits':
            return x.mean(4).mean(3).mean(2)
        else:
            return x
core/prefetch_dataloader.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import queue as Queue
2
+ import threading
3
+ import torch
4
+ from torch.utils.data import DataLoader
5
+
6
+
7
class PrefetchGenerator(threading.Thread):
    """A general prefetch generator.

    A daemon thread drains ``generator`` into a bounded queue while the
    consumer iterates over this object.

    Ref:
    https://stackoverflow.com/questions/7323664/python-generator-pre-fetch

    Args:
        generator: Python generator.
        num_prefetch_queue (int): Number of prefetch queue.
    """

    # Private end-of-stream marker, so that a legitimate ``None`` item is not
    # mistaken for the end of the stream.
    _END = object()

    def __init__(self, generator, num_prefetch_queue):
        threading.Thread.__init__(self)
        self.queue = Queue.Queue(num_prefetch_queue)
        self.generator = generator
        # Both flags must exist before the worker thread starts.
        self._error = None
        self._exhausted = False
        self.daemon = True
        self.start()

    def run(self):
        """Worker thread body: push every item, then the end marker."""
        try:
            for item in self.generator:
                self.queue.put(item)
        except Exception as exc:
            # Hand the failure to the consumer instead of dying silently.
            self._error = exc
        finally:
            # Always signal the end, otherwise the consumer blocks forever.
            self.queue.put(self._END)

    def __next__(self):
        """Return the next prefetched item.

        Raises:
            StopIteration: once the generator is exhausted (and on every
                later call).
            Exception: whatever the wrapped generator raised in the worker.
        """
        if self._exhausted:
            raise StopIteration
        next_item = self.queue.get()
        if next_item is self._END:
            self._exhausted = True
            if self._error is not None:
                error, self._error = self._error, None
                raise error
            raise StopIteration
        return next_item

    def __iter__(self):
        return self
38
+
39
+
40
class PrefetchDataLoader(DataLoader):
    """DataLoader whose iterator is wrapped in a background prefetch thread.

    Ref:
    https://github.com/IgorSusmelj/pytorch-styleguide/issues/5#

    TODO:
    Need to test on single gpu and ddp (multi-gpu). There is a known issue in
    ddp.

    Args:
        num_prefetch_queue (int): Number of prefetch queue.
        kwargs (dict): Other arguments for dataloader.
    """

    def __init__(self, num_prefetch_queue, **kwargs):
        # Stored before the base initialiser runs, as in the original order.
        self.num_prefetch_queue = num_prefetch_queue
        super().__init__(**kwargs)

    def __iter__(self):
        """Return a PrefetchGenerator fed by the regular DataLoader iterator."""
        base_iterator = super().__iter__()
        return PrefetchGenerator(base_iterator, self.num_prefetch_queue)
61
+
62
+
63
class CPUPrefetcher():
    """Thin ``next()`` / ``reset()`` wrapper around an iterable loader.

    Args:
        loader: Dataloader.
    """

    def __init__(self, loader):
        self.ori_loader = loader
        self.loader = iter(self.ori_loader)

    def next(self):
        """Return the next item, or None once the loader is exhausted."""
        return next(self.loader, None)

    def reset(self):
        """Start a new pass over the original loader."""
        self.loader = iter(self.ori_loader)
82
+
83
+
84
class CUDAPrefetcher():
    """CUDA prefetcher.

    Keeps one batch ahead of the consumer and moves its tensors to the target
    device on a side CUDA stream.

    Ref:
    https://github.com/NVIDIA/apex/issues/304#

    It may consume more GPU memory.

    Args:
        loader: Dataloader.
        opt (dict): Options. ``opt['num_gpu']`` selects the device: 'cuda'
            when it is non-zero, 'cpu' otherwise.
    """

    def __init__(self, loader, opt):
        self.ori_loader = loader
        self.loader = iter(loader)
        self.opt = opt
        self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu')
        # A CUDA stream only exists (and is only needed) on the CUDA path;
        # creating one unconditionally breaks the supported num_gpu == 0 case.
        self.stream = torch.cuda.Stream() if self.device.type == 'cuda' else None
        self.preload()

    def _move_batch_to_device(self):
        """Replace every tensor value of the current batch by its device copy."""
        for k, v in self.batch.items():
            if torch.is_tensor(v):
                self.batch[k] = v.to(device=self.device, non_blocking=True)

    def preload(self):
        """Fetch the next batch into ``self.batch`` (None when exhausted)."""
        try:
            self.batch = next(self.loader)  # self.batch is a dict
        except StopIteration:
            self.batch = None
            return None
        # put tensors to gpu
        if self.stream is None:
            self._move_batch_to_device()
        else:
            with torch.cuda.stream(self.stream):
                self._move_batch_to_device()

    def next(self):
        """Return the preloaded batch and start loading the following one."""
        if self.stream is not None:
            # Make sure the copy issued on the side stream has finished.
            torch.cuda.current_stream().wait_stream(self.stream)
        batch = self.batch
        self.preload()
        return batch

    def reset(self):
        """Restart iteration over the original loader and preload its first batch."""
        self.loader = iter(self.ori_loader)
        self.preload()
core/trainer.py ADDED
@@ -0,0 +1,509 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import glob
3
+ import logging
4
+ import importlib
5
+ from tqdm import tqdm
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ from core.prefetch_dataloader import PrefetchDataLoader, CPUPrefetcher
11
+ from torch.utils.data.distributed import DistributedSampler
12
+ from torch.nn.parallel import DistributedDataParallel as DDP
13
+ import torchvision
14
+ from torch.utils.tensorboard import SummaryWriter
15
+
16
+ from core.lr_scheduler import MultiStepRestartLR, CosineAnnealingRestartLR
17
+ from core.loss import AdversarialLoss, PerceptualLoss, LPIPSLoss
18
+ from core.dataset import TrainDataset
19
+
20
+ from model.modules.flow_comp_raft import RAFT_bi, FlowLoss, EdgeLoss
21
+ from model.recurrent_flow_completion import RecurrentFlowCompleteNet
22
+
23
+ from RAFT.utils.flow_viz_pt import flow_to_image
24
+
25
+
26
class Trainer:
    """Training loop for the inpainting generator ``netG`` (and optional discriminator ``netD``).

    Per iteration: ground-truth bidirectional flow is taken from the batch or
    computed with a frozen RAFT model, completed by a frozen flow-completion
    network, used for image propagation, and the generator is optimised with
    hole/valid L1, optional perceptual and optional adversarial losses.
    """

    # Checkpoint of the frozen flow-completion network used when the config
    # does not provide ``trainer.flow_complete_path``.
    DEFAULT_FLOW_COMPLETE_CKPT = '/mnt/lustre/sczhou/VQGANs/CodeMOVI/experiments_model/recurrent_flow_completion_v5_train_flowcomp_v5/gen_760000.pth'

    def __init__(self, config):
        """Build data pipeline, losses, frozen helpers, models, optimizers and writers.

        Args:
            config (dict): Full training configuration. Keys read here include
                'train_data_loader', 'trainer', 'losses', 'model', 'device',
                'distributed', 'world_size', 'global_rank', 'local_rank' and
                'save_dir'.
        """
        self.config = config
        self.epoch = 0
        self.iteration = 0
        self.num_local_frames = config['train_data_loader']['num_local_frames']
        self.num_ref_frames = config['train_data_loader']['num_ref_frames']

        # setup data set and data loader
        self.train_dataset = TrainDataset(config['train_data_loader'])

        self.train_sampler = None
        self.train_args = config['trainer']
        if config['distributed']:
            self.train_sampler = DistributedSampler(
                self.train_dataset,
                num_replicas=config['world_size'],
                rank=config['global_rank'])

        dataloader_args = dict(
            dataset=self.train_dataset,
            batch_size=self.train_args['batch_size'] // config['world_size'],
            shuffle=(self.train_sampler is None),
            num_workers=self.train_args['num_workers'],
            sampler=self.train_sampler,
            drop_last=True)

        self.train_loader = PrefetchDataLoader(self.train_args['num_prefetch_queue'], **dataloader_args)
        self.prefetcher = CPUPrefetcher(self.train_loader)

        # set loss functions
        self.adversarial_loss = AdversarialLoss(type=self.config['losses']['GAN_LOSS'])
        self.adversarial_loss = self.adversarial_loss.to(self.config['device'])
        self.l1_loss = nn.L1Loss()

        if self.config['losses']['perceptual_weight'] > 0:
            self.perc_loss = LPIPSLoss(use_input_norm=True, range_norm=True).to(self.config['device'])

        # set raft and the frozen flow-completion network
        self.fix_raft = RAFT_bi(device = self.config['device'])
        flow_complete_path = self.train_args.get('flow_complete_path',
                                                 self.DEFAULT_FLOW_COMPLETE_CKPT)
        self.fix_flow_complete = RecurrentFlowCompleteNet(flow_complete_path)
        for p in self.fix_flow_complete.parameters():
            p.requires_grad = False
        self.fix_flow_complete.to(self.config['device'])
        self.fix_flow_complete.eval()

        # setup models including generator and discriminator
        net = importlib.import_module('model.' + config['model']['net'])
        self.netG = net.InpaintGenerator()
        self.netG = self.netG.to(self.config['device'])
        if self._has_dis():
            if self.config['model'].get('dis_2d', False):
                self.netD = net.Discriminator_2D(
                    in_channels=3,
                    use_sigmoid=config['losses']['GAN_LOSS'] != 'hinge')
            else:
                self.netD = net.Discriminator(
                    in_channels=3,
                    use_sigmoid=config['losses']['GAN_LOSS'] != 'hinge')
            self.netD = self.netD.to(self.config['device'])

        self.interp_mode = self.config['model']['interp_mode']
        # setup optimizers and schedulers
        self.setup_optimizers()
        self.setup_schedulers()
        self.load()

        if config['distributed']:
            self.netG = DDP(self.netG,
                            device_ids=[self.config['local_rank']],
                            output_device=self.config['local_rank'],
                            broadcast_buffers=True,
                            find_unused_parameters=True)
            if self._has_dis():
                self.netD = DDP(self.netD,
                                device_ids=[self.config['local_rank']],
                                output_device=self.config['local_rank'],
                                broadcast_buffers=True,
                                find_unused_parameters=False)

        # set summary writer
        self.dis_writer = None
        self.gen_writer = None
        self.summary = {}
        if self.config['global_rank'] == 0 or (not config['distributed']):
            if self._has_dis():
                self.dis_writer = SummaryWriter(
                    os.path.join(config['save_dir'], 'dis'))
            self.gen_writer = SummaryWriter(
                os.path.join(config['save_dir'], 'gen'))

    def _has_dis(self):
        """Return True unless the config disables the discriminator (``model.no_dis``)."""
        return not self.config['model'].get('no_dis', False)

    @staticmethod
    def _unwrap(net):
        """Return the underlying module of a DataParallel/DDP wrapper, else ``net`` itself."""
        if isinstance(net, (torch.nn.DataParallel, DDP)):
            return net.module
        return net

    def setup_optimizers(self):
        """Set up optimizers."""
        backbone_params = []
        for name, param in self.netG.named_parameters():
            if param.requires_grad:
                backbone_params.append(param)
            else:
                print(f'Params {name} will not be optimized.')

        optim_params = [
            {
                'params': backbone_params,
                'lr': self.config['trainer']['lr']
            },
        ]

        self.optimG = torch.optim.Adam(optim_params,
                                       betas=(self.config['trainer']['beta1'],
                                              self.config['trainer']['beta2']))

        if self._has_dis():
            self.optimD = torch.optim.Adam(
                self.netD.parameters(),
                lr=self.config['trainer']['lr'],
                betas=(self.config['trainer']['beta1'],
                       self.config['trainer']['beta2']))

    def setup_schedulers(self):
        """Set up schedulers.

        Raises:
            NotImplementedError: if the configured scheduler type is unknown.
        """
        scheduler_opt = self.config['trainer']['scheduler']
        # Read (do not pop) the type so the shared config dict stays intact.
        scheduler_type = scheduler_opt['type']

        if scheduler_type in ['MultiStepLR', 'MultiStepRestartLR']:
            self.scheG = MultiStepRestartLR(
                self.optimG,
                milestones=scheduler_opt['milestones'],
                gamma=scheduler_opt['gamma'])
            if self._has_dis():
                self.scheD = MultiStepRestartLR(
                    self.optimD,
                    milestones=scheduler_opt['milestones'],
                    gamma=scheduler_opt['gamma'])
        elif scheduler_type == 'CosineAnnealingRestartLR':
            self.scheG = CosineAnnealingRestartLR(
                self.optimG,
                periods=scheduler_opt['periods'],
                restart_weights=scheduler_opt['restart_weights'],
                eta_min=scheduler_opt['eta_min'])
            if self._has_dis():
                self.scheD = CosineAnnealingRestartLR(
                    self.optimD,
                    periods=scheduler_opt['periods'],
                    restart_weights=scheduler_opt['restart_weights'],
                    eta_min=scheduler_opt['eta_min'])
        else:
            raise NotImplementedError(
                f'Scheduler {scheduler_type} is not implemented yet.')

    def update_learning_rate(self):
        """Update learning rate."""
        self.scheG.step()
        if self._has_dis():
            self.scheD.step()

    def get_lr(self):
        """Get current learning rate."""
        return self.optimG.param_groups[0]['lr']

    def add_summary(self, writer, name, val):
        """Accumulate ``val`` under ``name`` and write the running mean every ``log_freq`` iterations."""
        if name not in self.summary:
            self.summary[name] = 0
        self.summary[name] += val
        n = self.train_args['log_freq']
        if writer is not None and self.iteration % n == 0:
            writer.add_scalar(name, self.summary[name] / n, self.iteration)
            self.summary[name] = 0

    def load(self):
        """Load netG (and netD).

        Resumes from the latest checkpoint in ``save_dir`` when one exists;
        otherwise falls back to the optional ``gen_path`` / ``dis_path`` /
        ``opt_path`` entries of the trainer config.
        """
        # get the latest checkpoint
        model_path = self.config['save_dir']
        latest_file = os.path.join(model_path, 'latest.ckpt')
        # TODO: add resume name
        if os.path.isfile(latest_file):
            with open(latest_file, 'r') as f:
                latest_epoch = f.read().splitlines()[-1]
        else:
            ckpts = [
                os.path.basename(i).split('.pth')[0]
                for i in glob.glob(os.path.join(model_path, '*.pth'))
            ]
            ckpts.sort()
            # Strip the 4-character 'gen_' / 'dis_' / 'opt_' prefix.
            latest_epoch = ckpts[-1][4:] if len(ckpts) > 0 else None

        if latest_epoch is not None:
            gen_path = os.path.join(model_path,
                                    f'gen_{int(latest_epoch):06d}.pth')
            dis_path = os.path.join(model_path,
                                    f'dis_{int(latest_epoch):06d}.pth')
            opt_path = os.path.join(model_path,
                                    f'opt_{int(latest_epoch):06d}.pth')

            if self.config['global_rank'] == 0:
                print(f'Loading model from {gen_path}...')
            dataG = torch.load(gen_path, map_location=self.config['device'])
            self.netG.load_state_dict(dataG)
            if self._has_dis() and self.config['model']['load_d']:
                dataD = torch.load(dis_path, map_location=self.config['device'])
                self.netD.load_state_dict(dataD)

            data_opt = torch.load(opt_path, map_location=self.config['device'])
            self.optimG.load_state_dict(data_opt['optimG'])
            # NOTE(review): scheduler states ('scheG' / 'scheD') are saved but
            # not restored on this resume path - confirm this is intended.
            if self._has_dis() and self.config['model']['load_d']:
                self.optimD.load_state_dict(data_opt['optimD'])
            self.epoch = data_opt['epoch']
            self.iteration = data_opt['iteration']
        else:
            gen_path = self.config['trainer'].get('gen_path', None)
            dis_path = self.config['trainer'].get('dis_path', None)
            opt_path = self.config['trainer'].get('opt_path', None)
            if gen_path is not None:
                if self.config['global_rank'] == 0:
                    print(f'Loading Gen-Net from {gen_path}...')
                dataG = torch.load(gen_path, map_location=self.config['device'])
                self.netG.load_state_dict(dataG)

                if dis_path is not None and self._has_dis() and self.config['model']['load_d']:
                    if self.config['global_rank'] == 0:
                        print(f'Loading Dis-Net from {dis_path}...')
                    dataD = torch.load(dis_path, map_location=self.config['device'])
                    self.netD.load_state_dict(dataD)
                if opt_path is not None:
                    data_opt = torch.load(opt_path, map_location=self.config['device'])
                    self.optimG.load_state_dict(data_opt['optimG'])
                    self.scheG.load_state_dict(data_opt['scheG'])
                    if self._has_dis() and self.config['model']['load_d']:
                        self.optimD.load_state_dict(data_opt['optimD'])
                        self.scheD.load_state_dict(data_opt['scheD'])
            else:
                if self.config['global_rank'] == 0:
                    print('Warnning: There is no trained model found.'
                          'An initialized model will be used.')

    def save(self, it):
        """Save generator, discriminator and optimizer/scheduler state for iteration ``it`` (rank 0 only)."""
        if self.config['global_rank'] == 0:
            # configure path
            gen_path = os.path.join(self.config['save_dir'],
                                    f'gen_{it:06d}.pth')
            dis_path = os.path.join(self.config['save_dir'],
                                    f'dis_{it:06d}.pth')
            opt_path = os.path.join(self.config['save_dir'],
                                    f'opt_{it:06d}.pth')
            print(f'\nsaving model to {gen_path} ...')

            has_dis = self._has_dis()

            # save checkpoints (without the .module wrapper prefix)
            torch.save(self._unwrap(self.netG).state_dict(), gen_path)
            state = {
                'epoch': self.epoch,
                'iteration': self.iteration,
                'optimG': self.optimG.state_dict(),
                'scheG': self.scheG.state_dict()
            }
            if has_dis:
                torch.save(self._unwrap(self.netD).state_dict(), dis_path)
                state['optimD'] = self.optimD.state_dict()
                state['scheD'] = self.scheD.state_dict()
            torch.save(state, opt_path)

            # Record the latest iteration; written directly instead of
            # shelling out to `echo`.
            latest_path = os.path.join(self.config['save_dir'], 'latest.ckpt')
            with open(latest_path, 'w') as f:
                f.write(f'{it:06d}\n')

    def train(self):
        """training entry"""
        pbar = range(int(self.train_args['iterations']))
        if self.config['global_rank'] == 0:
            pbar = tqdm(pbar,
                        initial=self.iteration,
                        dynamic_ncols=True,
                        smoothing=0.01)

        os.makedirs('logs', exist_ok=True)

        # normpath drops a trailing separator so the log name is never empty.
        run_name = os.path.basename(os.path.normpath(self.config['save_dir']))
        logging.basicConfig(
            level=logging.INFO,
            format="%(asctime)s %(filename)s[line:%(lineno)d]"
            "%(levelname)s %(message)s",
            datefmt="%a, %d %b %Y %H:%M:%S",
            filename=f"logs/{run_name}.log",
            filemode='w')

        while True:
            self.epoch += 1
            self.prefetcher.reset()
            if self.config['distributed']:
                self.train_sampler.set_epoch(self.epoch)
            self._train_epoch(pbar)
            if self.iteration > self.train_args['iterations']:
                break
        print('\nEnd training....')

    def _train_epoch(self, pbar):
        """Process input and calculate loss every training epoch"""
        device = self.config['device']
        has_dis = self._has_dis()
        train_data = self.prefetcher.next()
        while train_data is not None:
            self.iteration += 1
            frames, masks, flows_f, flows_b, _ = train_data
            frames, masks = frames.to(device), masks.to(device).float()
            l_t = self.num_local_frames
            b, t, c, h, w = frames.size()
            gt_local_frames = frames[:, :l_t, ...]
            local_masks = masks[:, :l_t, ...].contiguous()

            masked_frames = frames * (1 - masks)
            masked_local_frames = masked_frames[:, :l_t, ...]
            # get gt optical flow: use the batch's flows unless they are the
            # 'None' placeholder, in which case compute them with RAFT
            if flows_f[0] == 'None' or flows_b[0] == 'None':
                gt_flows_bi = self.fix_raft(gt_local_frames)
            else:
                gt_flows_bi = (flows_f.to(device), flows_b.to(device))

            # ---- complete flow ----
            pred_flows_bi, _ = self.fix_flow_complete.forward_bidirect_flow(gt_flows_bi, local_masks)
            pred_flows_bi = self.fix_flow_complete.combine_flow(gt_flows_bi, pred_flows_bi, local_masks)

            # ---- image propagation ----
            # Unwrap so this also works when netG is not wrapped in DDP.
            prop_imgs, updated_local_masks = self._unwrap(self.netG).img_propagation(masked_local_frames, pred_flows_bi, local_masks, interpolation=self.interp_mode)
            updated_masks = masks.clone()
            updated_masks[:, :l_t, ...] = updated_local_masks.view(b, l_t, 1, h, w)
            updated_frames = masked_frames.clone()
            prop_local_frames = gt_local_frames * (1-local_masks) + prop_imgs.view(b, l_t, 3, h, w) * local_masks  # merge
            updated_frames[:, :l_t, ...] = prop_local_frames

            # ---- feature propagation + Transformer ----
            pred_imgs = self.netG(updated_frames, pred_flows_bi, masks, updated_masks, l_t)
            pred_imgs = pred_imgs.view(b, -1, c, h, w)

            # get the local frames
            pred_local_frames = pred_imgs[:, :l_t, ...]
            comp_imgs = frames * (1. - masks) + pred_imgs * masks

            gen_loss = 0
            dis_loss = 0
            # optimize net_g (discriminator frozen during the generator step)
            if has_dis:
                for p in self.netD.parameters():
                    p.requires_grad = False

            self.optimG.zero_grad()

            # generator l1 loss
            hole_loss = self.l1_loss(pred_imgs * masks, frames * masks)
            hole_loss = hole_loss / torch.mean(masks) * self.config['losses']['hole_weight']
            gen_loss += hole_loss
            self.add_summary(self.gen_writer, 'loss/hole_loss', hole_loss.item())

            valid_loss = self.l1_loss(pred_imgs * (1 - masks), frames * (1 - masks))
            valid_loss = valid_loss / torch.mean(1-masks) * self.config['losses']['valid_weight']
            gen_loss += valid_loss
            self.add_summary(self.gen_writer, 'loss/valid_loss', valid_loss.item())

            # perceptual loss
            if self.config['losses']['perceptual_weight'] > 0:
                perc_loss = self.perc_loss(pred_imgs.view(-1,3,h,w), frames.view(-1,3,h,w))[0] * self.config['losses']['perceptual_weight']
                gen_loss += perc_loss
                self.add_summary(self.gen_writer, 'loss/perc_loss', perc_loss.item())

            # gan loss
            if has_dis:
                # generator adversarial loss
                gen_clip = self.netD(comp_imgs)
                gan_loss = self.adversarial_loss(gen_clip, True, False)
                gan_loss = gan_loss * self.config['losses']['adversarial_weight']
                gen_loss += gan_loss
                self.add_summary(self.gen_writer, 'loss/gan_loss', gan_loss.item())
            gen_loss.backward()
            self.optimG.step()

            if has_dis:
                # optimize net_d
                for p in self.netD.parameters():
                    p.requires_grad = True
                self.optimD.zero_grad()

                # discriminator adversarial loss
                real_clip = self.netD(frames)
                fake_clip = self.netD(comp_imgs.detach())
                dis_real_loss = self.adversarial_loss(real_clip, True, True)
                dis_fake_loss = self.adversarial_loss(fake_clip, False, True)
                dis_loss += (dis_real_loss + dis_fake_loss) / 2
                self.add_summary(self.dis_writer, 'loss/dis_vid_real', dis_real_loss.item())
                self.add_summary(self.dis_writer, 'loss/dis_vid_fake', dis_fake_loss.item())
                dis_loss.backward()
                self.optimD.step()

            self.update_learning_rate()

            # write image to tensorboard
            if self.iteration % 200 == 0:
                # img to cpu, rescaled with (x + 1) / 2
                t = 0
                gt_local_frames_cpu = ((gt_local_frames.view(b,-1,3,h,w) + 1)/2.0).cpu()
                masked_local_frames = ((masked_local_frames.view(b,-1,3,h,w) + 1)/2.0).cpu()
                prop_local_frames_cpu = ((prop_local_frames.view(b,-1,3,h,w) + 1)/2.0).cpu()
                pred_local_frames_cpu = ((pred_local_frames.view(b,-1,3,h,w) + 1)/2.0).cpu()
                img_results = torch.cat([masked_local_frames[0][t], gt_local_frames_cpu[0][t],
                                         prop_local_frames_cpu[0][t], pred_local_frames_cpu[0][t]], 1)
                img_results = torchvision.utils.make_grid(img_results, nrow=1, normalize=True)
                if self.gen_writer is not None:
                    self.gen_writer.add_image(f'img/img:inp-gt-res-{t}', img_results, self.iteration)

                t = 5
                if masked_local_frames.shape[1] > 5:
                    img_results = torch.cat([masked_local_frames[0][t], gt_local_frames_cpu[0][t],
                                             prop_local_frames_cpu[0][t], pred_local_frames_cpu[0][t]], 1)
                    img_results = torchvision.utils.make_grid(img_results, nrow=1, normalize=True)
                    if self.gen_writer is not None:
                        self.gen_writer.add_image(f'img/img:inp-gt-res-{t}', img_results, self.iteration)

                # flow to cpu
                gt_flows_forward_cpu = flow_to_image(gt_flows_bi[0][0]).cpu()
                masked_flows_forward_cpu = (gt_flows_forward_cpu[0] * (1-local_masks[0][0].cpu())).to(gt_flows_forward_cpu)
                pred_flows_forward_cpu = flow_to_image(pred_flows_bi[0][0]).cpu()

                flow_results = torch.cat([gt_flows_forward_cpu[0], masked_flows_forward_cpu, pred_flows_forward_cpu[0]], 1)
                if self.gen_writer is not None:
                    self.gen_writer.add_image('img/flow:gt-pred', flow_results, self.iteration)

            # console logs
            if self.config['global_rank'] == 0:
                pbar.update(1)
                if has_dis:
                    pbar.set_description((f"d: {dis_loss.item():.3f}; "
                                          f"hole: {hole_loss.item():.3f}; "
                                          f"valid: {valid_loss.item():.3f}"))
                else:
                    pbar.set_description((f"hole: {hole_loss.item():.3f}; "
                                          f"valid: {valid_loss.item():.3f}"))

                if self.iteration % self.train_args['log_freq'] == 0:
                    if has_dis:
                        logging.info(f"[Iter {self.iteration}] "
                                     f"d: {dis_loss.item():.4f}; "
                                     f"hole: {hole_loss.item():.4f}; "
                                     f"valid: {valid_loss.item():.4f}")
                    else:
                        logging.info(f"[Iter {self.iteration}] "
                                     f"hole: {hole_loss.item():.4f}; "
                                     f"valid: {valid_loss.item():.4f}")

            # saving models
            if self.iteration % self.train_args['save_freq'] == 0:
                self.save(int(self.iteration))

            if self.iteration > self.train_args['iterations']:
                break

            train_data = self.prefetcher.next()
core/trainer_flow_w_edge.py ADDED
@@ -0,0 +1,380 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import glob
3
+ import logging
4
+ import importlib
5
+ from tqdm import tqdm
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ from core.prefetch_dataloader import PrefetchDataLoader, CPUPrefetcher
11
+ from torch.utils.data.distributed import DistributedSampler
12
+ from torch.nn.parallel import DistributedDataParallel as DDP
13
+
14
+ from torch.utils.tensorboard import SummaryWriter
15
+
16
+ from core.lr_scheduler import MultiStepRestartLR, CosineAnnealingRestartLR
17
+ from core.dataset import TrainDataset
18
+
19
+ from model.modules.flow_comp_raft import RAFT_bi, FlowLoss, EdgeLoss
20
+
21
+ # from skimage.feature import canny
22
+ from model.canny.canny_filter import Canny
23
+ from RAFT.utils.flow_viz_pt import flow_to_image
24
+
25
+
26
+ class Trainer:
27
def __init__(self, config):
    """Build the data pipeline, frozen RAFT, losses, generator and logger.

    Args:
        config: training configuration mapping. The keys read here are
            ``train_data_loader``, ``trainer``, ``model``, ``distributed``,
            ``world_size``, ``global_rank``, ``local_rank``, ``device``
            and ``save_dir``.
    """
    self.config = config
    self.epoch = 0
    self.iteration = 0

    loader_cfg = config['train_data_loader']
    self.num_local_frames = loader_cfg['num_local_frames']
    self.num_ref_frames = loader_cfg['num_ref_frames']

    # dataset, optional distributed sampler and prefetching loader
    self.train_dataset = TrainDataset(loader_cfg)
    self.train_sampler = None
    self.train_args = config['trainer']
    is_distributed = config['distributed']
    if is_distributed:
        self.train_sampler = DistributedSampler(
            self.train_dataset,
            num_replicas=config['world_size'],
            rank=config['global_rank'])

    # the configured batch size is global: split it across processes,
    # and let the sampler (when present) own the shuffling
    self.train_loader = PrefetchDataLoader(
        self.train_args['num_prefetch_queue'],
        dataset=self.train_dataset,
        batch_size=self.train_args['batch_size'] // config['world_size'],
        shuffle=(self.train_sampler is None),
        num_workers=self.train_args['num_workers'],
        sampler=self.train_sampler,
        drop_last=True)
    self.prefetcher = CPUPrefetcher(self.train_loader)

    # RAFT wrapper used to produce ground-truth flow, plus loss modules
    self.fix_raft = RAFT_bi(device=self.config['device'])
    self.flow_loss = FlowLoss()
    self.edge_loss = EdgeLoss()
    self.canny = Canny(sigma=(2, 2), low_threshold=0.1, high_threshold=0.2)

    # generator: the module is chosen by name from the ``model`` package
    net_module = importlib.import_module('model.' + config['model']['net'])
    self.netG = net_module.RecurrentFlowCompleteNet()
    self.netG = self.netG.to(self.config['device'])

    # optimizer and scheduler must exist before load() restores their state
    self.setup_optimizers()
    self.setup_schedulers()
    self.load()

    if is_distributed:
        local_rank = self.config['local_rank']
        self.netG = DDP(self.netG,
                        device_ids=[local_rank],
                        output_device=local_rank,
                        broadcast_buffers=True,
                        find_unused_parameters=True)

    # tensorboard writer: only on global rank 0 or in single-process runs
    self.dis_writer = None
    self.gen_writer = None
    self.summary = {}
    if self.config['global_rank'] == 0 or (not is_distributed):
        self.gen_writer = SummaryWriter(
            os.path.join(config['save_dir'], 'gen'))
87
+
88
def setup_optimizers(self):
    """Create ``self.optimG``: Adam over the trainable generator parameters.

    Parameters with ``requires_grad=False`` are left out and reported on
    stdout. The learning rate and betas come from ``config['trainer']``.
    """
    trainable = []
    for pname, tensor in self.netG.named_parameters():
        if not tensor.requires_grad:
            print(f'Params {pname} will not be optimized.')
            continue
        trainable.append(tensor)

    trainer_cfg = self.config['trainer']
    param_groups = [{'params': trainable, 'lr': trainer_cfg['lr']}]
    self.optimG = torch.optim.Adam(
        param_groups,
        betas=(trainer_cfg['beta1'], trainer_cfg['beta2']))
107
+
108
+
109
def setup_schedulers(self):
    """Create ``self.scheG`` from ``config['trainer']['scheduler']``.

    Supported ``type`` values are ``'MultiStepLR'`` / ``'MultiStepRestartLR'``
    (reads ``milestones`` and ``gamma``) and ``'CosineAnnealingRestartLR'``
    (reads ``periods`` and ``restart_weights``).

    Raises:
        NotImplementedError: for any other scheduler type.
    """
    scheduler_opt = self.config['trainer']['scheduler']
    # Read the type without popping it: the config dict is shared, and
    # removing the key would break a second call and any later config dump.
    scheduler_type = scheduler_opt['type']

    if scheduler_type in ['MultiStepLR', 'MultiStepRestartLR']:
        self.scheG = MultiStepRestartLR(
            self.optimG,
            milestones=scheduler_opt['milestones'],
            gamma=scheduler_opt['gamma'])
    elif scheduler_type == 'CosineAnnealingRestartLR':
        self.scheG = CosineAnnealingRestartLR(
            self.optimG,
            periods=scheduler_opt['periods'],
            restart_weights=scheduler_opt['restart_weights'])
    else:
        raise NotImplementedError(
            f'Scheduler {scheduler_type} is not implemented yet.')
127
+
128
def update_learning_rate(self):
    """Advance the generator's learning-rate scheduler by one step."""
    scheduler = self.scheG
    scheduler.step()
131
+
132
def get_lr(self):
    """Return the learning rate of the optimizer's first parameter group."""
    first_group = self.optimG.param_groups[0]
    return first_group['lr']
135
+
136
def add_summary(self, writer, name, val):
    """Accumulate ``val`` under ``name`` and periodically log its average.

    Every ``train_args['log_freq']`` iterations (and only when ``writer``
    is not None) the accumulated sum divided by ``log_freq`` is written
    with ``writer.add_scalar`` and the accumulator is reset to zero.
    """
    self.summary.setdefault(name, 0)
    self.summary[name] += val

    log_every = self.train_args['log_freq']
    if writer is None or self.iteration % log_every != 0:
        return
    writer.add_scalar(name, self.summary[name] / log_every, self.iteration)
    self.summary[name] = 0
145
+
146
def load(self):
    """Resume generator, optimizer and scheduler state from ``save_dir``.

    The checkpoint id is the last token of ``<save_dir>/latest.ckpt``.
    When that file is missing or empty, the lexicographically last
    ``*.pth`` file name in the directory is used instead (its first four
    characters, i.e. the ``gen_`` / ``opt_`` prefix, are stripped).
    With no checkpoint at all the freshly initialised model is kept.
    """
    model_path = self.config['save_dir']
    latest_file = os.path.join(model_path, 'latest.ckpt')

    latest_epoch = None
    if os.path.isfile(latest_file):
        # context manager closes the handle; split() also tolerates an
        # empty file and trailing blank lines
        with open(latest_file, 'r') as handle:
            tokens = handle.read().split()
        if tokens:
            latest_epoch = tokens[-1]
    if latest_epoch is None:
        ckpts = sorted(
            os.path.basename(path).split('.pth')[0]
            for path in glob.glob(os.path.join(model_path, '*.pth')))
        latest_epoch = ckpts[-1][4:] if ckpts else None

    if latest_epoch is None:
        if self.config['global_rank'] == 0:
            print('Warnning: There is no trained model found.'
                  'An initialized model will be used.')
        return

    gen_path = os.path.join(model_path, f'gen_{int(latest_epoch):06d}.pth')
    opt_path = os.path.join(model_path, f'opt_{int(latest_epoch):06d}.pth')

    if self.config['global_rank'] == 0:
        print(f'Loading model from {gen_path}...')
    dataG = torch.load(gen_path, map_location=self.config['device'])
    self.netG.load_state_dict(dataG)

    data_opt = torch.load(opt_path, map_location=self.config['device'])
    self.optimG.load_state_dict(data_opt['optimG'])
    self.scheG.load_state_dict(data_opt['scheG'])

    self.epoch = data_opt['epoch']
    self.iteration = data_opt['iteration']
182
+
183
def save(self, it):
    """Write a checkpoint for iteration ``it`` (global rank 0 only).

    Produces ``gen_<it>.pth`` (generator weights), ``opt_<it>.pth``
    (epoch, iteration, optimizer and scheduler state) and overwrites
    ``latest.ckpt`` with the zero-padded iteration number.
    """
    if self.config['global_rank'] != 0:
        return

    save_dir = self.config['save_dir']
    os.makedirs(save_dir, exist_ok=True)
    gen_path = os.path.join(save_dir, f'gen_{it:06d}.pth')
    opt_path = os.path.join(save_dir, f'opt_{it:06d}.pth')
    print(f'\nsaving model to {gen_path} ...')

    # store the bare module so the weights load without a parallel wrapper
    wrappers = (torch.nn.DataParallel,
                torch.nn.parallel.DistributedDataParallel)
    netG = self.netG.module if isinstance(self.netG, wrappers) else self.netG

    torch.save(netG.state_dict(), gen_path)
    torch.save(
        {
            'epoch': self.epoch,
            'iteration': self.iteration,
            'optimG': self.optimG.state_dict(),
            'scheG': self.scheG.state_dict()
        }, opt_path)

    # Write the marker directly instead of shelling out to ``echo``:
    # that was shell-dependent and broke on paths containing spaces.
    latest_path = os.path.join(save_dir, 'latest.ckpt')
    with open(latest_path, 'w') as handle:
        handle.write(f'{it:06d}\n')
211
+
212
def train(self):
    """Training entry point: run epochs until the iteration budget is spent.

    A tqdm bar is created on global rank 0 only. Logging is written to
    ``logs/<last component of save_dir>.log`` (the file is truncated on
    every call).
    """
    total_iters = int(self.train_args['iterations'])
    pbar = range(total_iters)
    if self.config['global_rank'] == 0:
        pbar = tqdm(pbar,
                    initial=self.iteration,
                    dynamic_ncols=True,
                    smoothing=0.01)

    os.makedirs('logs', exist_ok=True)
    run_name = self.config['save_dir'].split('/')[-1]
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s %(filename)s[line:%(lineno)d]"
        "%(levelname)s %(message)s",
        datefmt="%a, %d %b %Y %H:%M:%S",
        filename=f"logs/{run_name}.log",
        filemode='w')

    finished = False
    while not finished:
        self.epoch += 1
        self.prefetcher.reset()
        if self.config['distributed']:
            # reseed the sampler so each epoch sees a different shuffle
            self.train_sampler.set_epoch(self.epoch)
        self._train_epoch(pbar)
        finished = self.iteration > self.train_args['iterations']
    print('\nEnd training....')
240
+
241
+ # def get_edges(self, flows): # fgvc
242
+ # # (b, t, 2, H, W)
243
+ # b, t, _, h, w = flows.shape
244
+ # flows = flows.view(-1, 2, h, w)
245
+ # flows_list = flows.permute(0, 2, 3, 1).cpu().numpy()
246
+ # edges = []
247
+ # for f in list(flows_list):
248
+ # flows_gray = (f[:, :, 0] ** 2 + f[:, :, 1] ** 2) ** 0.5
249
+ # if flows_gray.max() < 1:
250
+ # flows_gray = flows_gray*0
251
+ # else:
252
+ # flows_gray = flows_gray / flows_gray.max()
253
+
254
+ # edge = canny(flows_gray, sigma=2, low_threshold=0.1, high_threshold=0.2) # fgvc
255
+ # edge = torch.from_numpy(edge).view(1, 1, h, w).float()
256
+ # edges.append(edge)
257
+ # edges = torch.stack(edges, dim=0).to(self.config['device'])
258
+ # edges = edges.view(b, t, 1, h, w)
259
+ # return edges
260
+
261
def get_edges(self, flows):
    """Compute an edge map from the magnitude of a flow tensor.

    Args:
        flows: tensor of shape (b, t, 2, H, W).

    Returns:
        The second output of ``self.canny`` viewed as (b, t, 1, H, W).

    The magnitude is divided by its maximum over the whole tensor; when
    that maximum is below 1 the magnitude is zeroed instead.
    """
    b, t, _, h, w = flows.shape
    flat = flows.view(-1, 2, h, w)
    u, v = flat[:, 0:1], flat[:, 1:2]
    speed = (u ** 2 + v ** 2) ** 0.5

    if speed.max() < 1:
        speed = speed * 0
    else:
        speed = speed / speed.max()

    _, edges = self.canny(speed.float())
    return edges.view(b, t, 1, h, w)
274
+
275
+ def _train_epoch(self, pbar):
276
+ """Process input and calculate loss every training epoch"""
277
+ device = self.config['device']
278
+ train_data = self.prefetcher.next()
279
+ while train_data is not None:
280
+ self.iteration += 1
281
+ frames, masks, flows_f, flows_b, _ = train_data
282
+ frames, masks = frames.to(device), masks.to(device)
283
+ masks = masks.float()
284
+
285
+ l_t = self.num_local_frames
286
+ b, t, c, h, w = frames.size()
287
+ gt_local_frames = frames[:, :l_t, ...]
288
+ local_masks = masks[:, :l_t, ...].contiguous()
289
+
290
+ # get gt optical flow
291
+ if flows_f[0] == 'None' or flows_b[0] == 'None':
292
+ gt_flows_bi = self.fix_raft(gt_local_frames)
293
+ else:
294
+ gt_flows_bi = (flows_f.to(device), flows_b.to(device))
295
+
296
+ # get gt edge
297
+ gt_edges_forward = self.get_edges(gt_flows_bi[0])
298
+ gt_edges_backward = self.get_edges(gt_flows_bi[1])
299
+ gt_edges_bi = [gt_edges_forward, gt_edges_backward]
300
+
301
+ # complete flow
302
+ pred_flows_bi, pred_edges_bi = self.netG.module.forward_bidirect_flow(gt_flows_bi, local_masks)
303
+
304
+ # optimize net_g
305
+ self.optimG.zero_grad()
306
+
307
+ # compulte flow_loss
308
+ flow_loss, warp_loss = self.flow_loss(pred_flows_bi, gt_flows_bi, local_masks, gt_local_frames)
309
+ flow_loss = flow_loss * self.config['losses']['flow_weight']
310
+ warp_loss = warp_loss * 0.01
311
+ self.add_summary(self.gen_writer, 'loss/flow_loss', flow_loss.item())
312
+ self.add_summary(self.gen_writer, 'loss/warp_loss', warp_loss.item())
313
+
314
+ # compute edge loss
315
+ edge_loss = self.edge_loss(pred_edges_bi, gt_edges_bi, local_masks)
316
+ edge_loss = edge_loss*1.0
317
+ self.add_summary(self.gen_writer, 'loss/edge_loss', edge_loss.item())
318
+
319
+ loss = flow_loss + warp_loss + edge_loss
320
+ loss.backward()
321
+ self.optimG.step()
322
+ self.update_learning_rate()
323
+
324
+ # write image to tensorboard
325
+ # if self.iteration % 200 == 0:
326
+ if self.iteration % 200 == 0 and self.gen_writer is not None:
327
+ t = 5
328
+ # forward to cpu
329
+ gt_flows_forward_cpu = flow_to_image(gt_flows_bi[0][0]).cpu()
330
+ masked_flows_forward_cpu = (gt_flows_forward_cpu[t] * (1-local_masks[0][t].cpu())).to(gt_flows_forward_cpu)
331
+ pred_flows_forward_cpu = flow_to_image(pred_flows_bi[0][0]).cpu()
332
+
333
+ flow_results = torch.cat([gt_flows_forward_cpu[t], masked_flows_forward_cpu, pred_flows_forward_cpu[t]], 1)
334
+ self.gen_writer.add_image('img/flow-f:gt-pred', flow_results, self.iteration)
335
+
336
+ # backward to cpu
337
+ gt_flows_backward_cpu = flow_to_image(gt_flows_bi[1][0]).cpu()
338
+ masked_flows_backward_cpu = (gt_flows_backward_cpu[t] * (1-local_masks[0][t+1].cpu())).to(gt_flows_backward_cpu)
339
+ pred_flows_backward_cpu = flow_to_image(pred_flows_bi[1][0]).cpu()
340
+
341
+ flow_results = torch.cat([gt_flows_backward_cpu[t], masked_flows_backward_cpu, pred_flows_backward_cpu[t]], 1)
342
+ self.gen_writer.add_image('img/flow-b:gt-pred', flow_results, self.iteration)
343
+
344
+ # TODO: show edge
345
+ # forward
346
+ gt_edges_forward_cpu = gt_edges_bi[0][0].cpu()
347
+ masked_edges_forward_cpu = (gt_edges_forward_cpu[t] * (1-local_masks[0][t].cpu())).to(gt_edges_forward_cpu)
348
+ pred_edges_forward_cpu = pred_edges_bi[0][0].cpu()
349
+
350
+ edge_results = torch.cat([gt_edges_forward_cpu[t], masked_edges_forward_cpu, pred_edges_forward_cpu[t]], 1)
351
+ self.gen_writer.add_image('img/edge-f:gt-pred', edge_results, self.iteration)
352
+ # backward
353
+ gt_edges_backward_cpu = gt_edges_bi[1][0].cpu()
354
+ masked_edges_backward_cpu = (gt_edges_backward_cpu[t] * (1-local_masks[0][t+1].cpu())).to(gt_edges_backward_cpu)
355
+ pred_edges_backward_cpu = pred_edges_bi[1][0].cpu()
356
+
357
+ edge_results = torch.cat([gt_edges_backward_cpu[t], masked_edges_backward_cpu, pred_edges_backward_cpu[t]], 1)
358
+ self.gen_writer.add_image('img/edge-b:gt-pred', edge_results, self.iteration)
359
+
360
+ # console logs
361
+ if self.config['global_rank'] == 0:
362
+ pbar.update(1)
363
+ pbar.set_description((f"flow: {flow_loss.item():.3f}; "
364
+ f"warp: {warp_loss.item():.3f}; "
365
+ f"edge: {edge_loss.item():.3f}; "
366
+ f"lr: {self.get_lr()}"))
367
+
368
+ if self.iteration % self.train_args['log_freq'] == 0:
369
+ logging.info(f"[Iter {self.iteration}] "
370
+ f"flow: {flow_loss.item():.4f}; "
371
+ f"warp: {warp_loss.item():.4f}")
372
+
373
+ # saving models
374
+ if self.iteration % self.train_args['save_freq'] == 0:
375
+ self.save(int(self.iteration))
376
+
377
+ if self.iteration > self.train_args['iterations']:
378
+ break
379
+
380
+ train_data = self.prefetcher.next()
core/utils.py ADDED
@@ -0,0 +1,371 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import io
3
+ import cv2
4
+ import random
5
+ import numpy as np
6
+ from PIL import Image, ImageOps
7
+ import zipfile
8
+ import math
9
+
10
+ import torch
11
+ import matplotlib
12
+ import matplotlib.patches as patches
13
+ from matplotlib.path import Path
14
+ from matplotlib import pyplot as plt
15
+ from torchvision import transforms
16
+
17
+ # matplotlib.use('agg')
18
+
19
+ # ###########################################################################
20
+ # Directory IO
21
+ # ###########################################################################
22
+
23
+
24
def read_dirnames_under_root(root_dir):
    """Return the sorted names of the sub-directories directly in ``root_dir``.

    The number of directories found is also printed.
    """
    dirnames = []
    for entry in sorted(os.listdir(root_dir)):
        if os.path.isdir(os.path.join(root_dir, entry)):
            dirnames.append(entry)
    print(f'Reading directories under {root_dir}, num: {len(dirnames)}')
    return dirnames
31
+
32
+
33
class TrainZipReader(object):
    """Read images out of zip archives, keeping opened archives cached."""

    # archive path -> open zipfile.ZipFile, shared by all callers
    file_dict = dict()

    def __init__(self):
        super(TrainZipReader, self).__init__()

    @staticmethod
    def build_file_dict(path):
        """Return the cached ``ZipFile`` for ``path``, opening it on first use."""
        cache = TrainZipReader.file_dict
        if path not in cache:
            cache[path] = zipfile.ZipFile(path, 'r')
        return cache[path]

    @staticmethod
    def imread(path, idx):
        """Open member number ``idx`` (in sorted name order) as a PIL image."""
        archive = TrainZipReader.build_file_dict(path)
        member_names = archive.namelist()
        member_names.sort()
        payload = archive.read(member_names[idx])
        return Image.open(io.BytesIO(payload))
58
+
59
+
60
class TestZipReader(object):
    """Read images out of zip archives via OpenCV, caching opened archives."""

    # archive path -> open zipfile.ZipFile, shared by all callers
    file_dict = dict()

    def __init__(self):
        super(TestZipReader, self).__init__()

    @staticmethod
    def build_file_dict(path):
        """Return the cached ``ZipFile`` for ``path``, opening it on first use."""
        cache = TestZipReader.file_dict
        if path not in cache:
            cache[path] = zipfile.ZipFile(path, 'r')
        return cache[path]

    @staticmethod
    def imread(path, idx):
        """Decode member number ``idx`` (in sorted name order) to an RGB PIL image."""
        archive = TestZipReader.build_file_dict(path)
        member_names = archive.namelist()
        member_names.sort()
        payload = archive.read(member_names[idx])
        encoded = np.asarray(bytearray(payload), dtype=np.uint8)
        # OpenCV decodes to BGR; convert to RGB before wrapping in PIL
        bgr = cv2.imdecode(encoded, cv2.IMREAD_COLOR)
        return Image.fromarray(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))
87
+
88
+
89
+ # ###########################################################################
90
+ # Data augmentation
91
+ # ###########################################################################
92
+
93
+
94
def to_tensors():
    """Return a transform that applies ``Stack`` then ``ToTorchFormatTensor``."""
    pipeline = [Stack(), ToTorchFormatTensor()]
    return transforms.Compose(pipeline)
96
+
97
+
98
class GroupRandomHorizontalFlowFlip(object):
    """Flip images and their flow fields left-right with probability 0.5.

    When flipping, each flow array is reversed along its second axis and
    the first entry of its last axis is negated.
    (assumes flows are H x W x 2 arrays -- TODO confirm)
    """
    def __call__(self, img_group, flowF_group, flowB_group):
        if random.random() >= 0.5:
            return img_group, flowF_group, flowB_group

        sign = [-1.0, 1.0]
        flipped_imgs = [
            img.transpose(Image.FLIP_LEFT_RIGHT) for img in img_group
        ]
        flipped_fwd = [flow[:, ::-1] * sign for flow in flowF_group]
        flipped_bwd = [flow[:, ::-1] * sign for flow in flowB_group]
        return flipped_imgs, flipped_fwd, flipped_bwd
112
+
113
+
114
class GroupRandomHorizontalFlip(object):
    """Flip every image of a group left-right with probability 0.5.

    With ``is_flow=True`` every second flipped image (indices 0, 2, ...)
    is additionally inverted with ``ImageOps.invert``.
    """
    def __call__(self, img_group, is_flow=False):
        if random.random() >= 0.5:
            return img_group

        flipped = [img.transpose(Image.FLIP_LEFT_RIGHT) for img in img_group]
        if is_flow:
            # invert flow pixel values when flipping
            for idx in range(0, len(flipped), 2):
                flipped[idx] = ImageOps.invert(flipped[idx])
        return flipped
128
+
129
+
130
class Stack(object):
    """Stack a list of same-mode images into one array along a new axis 2.

    Mode '1' images are converted to 'L' first. 'L' images get a trailing
    channel axis before stacking. For 'RGB', ``roll=True`` reverses the
    channel order of every image.
    """
    def __init__(self, roll=False):
        self.roll = roll

    def __call__(self, img_group):
        mode = img_group[0].mode
        if mode == '1':
            img_group = [img.convert('L') for img in img_group]
            mode = 'L'

        if mode == 'L':
            with_channel = [np.expand_dims(img, 2) for img in img_group]
            return np.stack(with_channel, axis=2)

        if mode != 'RGB':
            raise NotImplementedError(f"Image mode {mode}")

        if self.roll:
            reversed_channels = [np.array(img)[:, :, ::-1] for img in img_group]
            return np.stack(reversed_channels, axis=2)
        return np.stack(img_group, axis=2)
149
+
150
+
151
class ToTorchFormatTensor(object):
    """Convert a stacked numpy array or a PIL image to a float tensor.

    A numpy input has its axes permuted with ``(2, 3, 0, 1)``; a PIL image
    is turned from H x W x C into C x H x W. With ``div=True`` the result
    is divided by 255.
    """
    def __init__(self, div=True):
        self.div = div

    def __call__(self, pic):
        if isinstance(pic, np.ndarray):
            # numpy img: [L, C, H, W]
            img = torch.from_numpy(pic).permute(2, 3, 0, 1).contiguous()
        else:
            # handle PIL Image: raw bytes -> (H, W, C) byte tensor
            storage = torch.ByteStorage.from_buffer(pic.tobytes())
            img = torch.ByteTensor(storage)
            img = img.view(pic.size[1], pic.size[0], len(pic.mode))
            # put it from HWC to CHW format
            img = img.transpose(0, 1).transpose(0, 2).contiguous()

        img = img.float()
        if self.div:
            return img.div(255)
        return img
171
+
172
+
173
+ # ###########################################################################
174
+ # Create masks with random shape
175
+ # ###########################################################################
176
+
177
+
178
def create_random_shape_with_random_motion(video_length,
                                           imageHeight=240,
                                           imageWidth=432):
    """Generate ``video_length`` 'L'-mode masks of one random blob.

    A random shape is drawn once and pasted at a random position. With
    probability of about 0.5 the same mask object is repeated for every
    frame; otherwise the blob is moved frame by frame using
    ``random_move_control_points``.
    """
    # get a random shape
    height = random.randint(imageHeight // 3, imageHeight - 1)
    width = random.randint(imageWidth // 3, imageWidth - 1)
    edge_num = random.randint(6, 8)
    ratio = random.randint(6, 8) / 10
    region = get_random_shape(edge_num=edge_num,
                              ratio=ratio,
                              height=height,
                              width=width)
    region_width, region_height = region.size

    # get random position: x runs along the height, y along the width
    x = random.randint(0, imageHeight - region_height)
    y = random.randint(0, imageWidth - region_width)
    velocity = get_random_velocity(max_speed=3)

    def render(top, left):
        """Paste the region onto an empty canvas and return it in 'L' mode."""
        canvas = Image.fromarray(
            np.zeros((imageHeight, imageWidth)).astype(np.uint8))
        canvas.paste(
            region,
            (left, top, left + region.size[0], top + region.size[1]))
        return canvas.convert('L')

    masks = [render(x, y)]

    # return fixed masks
    if random.uniform(0, 1) > 0.5:
        return masks * video_length

    # return moving masks
    for _ in range(video_length - 1):
        x, y, velocity = random_move_control_points(
            x, y, imageHeight, imageWidth, velocity, region.size,
            maxLineAcceleration=(3, 0.5),
            maxInitSpeed=3)
        masks.append(render(x, y))
    return masks
218
+
219
+
220
def create_random_shape_with_random_motion_zoom_rotation(video_length, zoomin=0.9, zoomout=1.1, rotmin=1, rotmax=10, imageHeight=240, imageWidth=432):
    """Generate masks of a random blob that may move, zoom and rotate.

    With probability of about 0.5 the first mask is repeated for every
    frame. Otherwise, for each further frame the blob is moved and then,
    depending on a uniform draw ``u``: for ``u > 0.75`` the region is
    resized by a factor from ``[zoomin, zoomout]`` (the resize persists
    for later frames), for ``0.5 < u <= 0.75`` the pasted canvas is rotated
    by a random integer from ``[rotmin, rotmax]`` (presumably degrees, as
    passed to PIL ``Image.rotate``), else the region is pasted unchanged.
    """
    assert zoomin < 1, "Zoom-in parameter must be smaller than 1"
    assert zoomout > 1, "Zoom-out parameter must be larger than 1"
    assert rotmin < rotmax, "Minimum value of rotation must be smaller than maximun value !"

    def blank_canvas():
        return Image.fromarray(
            np.zeros((imageHeight, imageWidth)).astype(np.uint8))

    def paste_region(canvas, shape, top, left):
        canvas.paste(
            shape, (left, top, left + shape.size[0], top + shape.size[1]))

    # get a random shape
    height = random.randint(imageHeight // 3, imageHeight - 1)
    width = random.randint(imageWidth // 3, imageWidth - 1)
    edge_num = random.randint(6, 8)
    ratio = random.randint(6, 8) / 10
    region = get_random_shape(
        edge_num=edge_num, ratio=ratio, height=height, width=width)
    region_width, region_height = region.size

    # get random position: x runs along the height, y along the width
    x = random.randint(0, imageHeight - region_height)
    y = random.randint(0, imageWidth - region_width)
    velocity = get_random_velocity(max_speed=3)

    first = blank_canvas()
    paste_region(first, region, x, y)
    masks = [first.convert('L')]

    # return fixed masks -> directly copy the base mask
    if random.uniform(0, 1) > 0.5:
        return masks * video_length

    # return moving masks
    for _ in range(video_length - 1):
        x, y, velocity = random_move_control_points(
            x, y, imageHeight, imageWidth, velocity, region.size,
            maxLineAcceleration=(3, 0.5), maxInitSpeed=3)
        m = blank_canvas()
        # simulate zoom-in, zoom-out and rotation
        extra_transform = random.uniform(0, 1)
        if extra_transform > 0.75:
            # zoom in / zoom out
            resize_coefficient = random.uniform(zoomin, zoomout)
            new_size = (math.ceil(region_width * resize_coefficient),
                        math.ceil(region_height * resize_coefficient))
            region = region.resize(new_size, Image.NEAREST)
            paste_region(m, region, x, y)
            region_width, region_height = region.size
        elif extra_transform > 0.5:
            # rotation
            paste_region(m, region, x, y)
            m = m.rotate(random.randint(rotmin, rotmax))
        else:
            paste_region(m, region, x, y)
        masks.append(m.convert('L'))
    return masks
266
+
267
+
268
def get_random_shape(edge_num=9, ratio=0.7, width=432, height=240):
    """Draw a random closed Bezier blob and return it as a cropped PIL mask.

    The path has one initial point plus three points per cubic Bezier
    segment, so the curve passes through only ``edge_num`` of the points
    (the possibly sharp corners); the other two points of each segment
    shape the curve.

    Args:
        edge_num: number of possibly sharp edges.
        ratio: in (0, 1), magnitude of the perturbation from the unit circle.
        width: width the rendered figure is resized to before cropping.
        height: height the rendered figure is resized to before cropping.

    Returns:
        A PIL image with value 255 inside the shape and 0 outside, cropped
        to the bounding box of the shape.
    """
    points_num = edge_num * 3 + 1
    angles = np.linspace(0, 2 * np.pi, points_num)
    codes = np.full(points_num, Path.CURVE4)
    codes[0] = Path.MOVETO
    # Closing by repeating the first vertex instead of Path.CLOSEPOLY
    # avoids an unnecessary straight line
    verts = np.stack((np.cos(angles), np.sin(angles))).T * \
        (2 * ratio * np.random.random(points_num) + 1 - ratio)[:, None]
    verts[-1, :] = verts[0, :]
    path = Path(verts, codes)

    # draw the path into a figure; always close it, even if drawing fails
    fig = plt.figure()
    try:
        ax = fig.add_subplot(111)
        patch = patches.PathPatch(path, facecolor='black', lw=2)
        ax.add_patch(patch)
        ax.set_xlim(np.min(verts) * 1.1, np.max(verts) * 1.1)
        ax.set_ylim(np.min(verts) * 1.1, np.max(verts) * 1.1)
        ax.axis('off')  # removes the axis to leave only the shape
        canvas = fig.canvas
        canvas.draw()
        # convert the rendered figure into a numpy image
        if hasattr(canvas, 'buffer_rgba'):
            # tostring_rgb() is deprecated and removed in recent Matplotlib;
            # the RGBA buffer also carries its own (H, W) shape
            data = np.asarray(canvas.buffer_rgba())[:, :, :3]
        else:
            data = np.frombuffer(canvas.tostring_rgb(), dtype=np.uint8)
            data = data.reshape(canvas.get_width_height()[::-1] + (3,))
        # own, contiguous copy: safe after the figure is closed
        data = np.ascontiguousarray(data)
    finally:
        plt.close(fig)

    # postprocess: the shape was drawn black (0) on a non-zero background
    data = cv2.resize(data, (width, height))[:, :, 0]
    data = (1 - np.array(data > 0).astype(np.uint8)) * 255
    rows, cols = np.where(data > 0)
    row_min, row_max = np.min(rows), np.max(rows)
    col_min, col_max = np.min(cols), np.max(cols)
    region = Image.fromarray(data).crop((col_min, row_min, col_max, row_max))
    return region
307
+
308
+
309
def random_accelerate(velocity, maxAcceleration, dist='uniform'):
    """Randomly perturb a ``(speed, angle)`` velocity.

    Args:
        velocity: ``(speed, angle)`` pair.
        maxAcceleration: ``(d_speed, d_angle)`` perturbation scale.
        dist: ``'uniform'`` draws each delta from ``[-d, d]``;
            ``'guassian'`` (historical spelling) or ``'gaussian'`` draws
            it from a normal distribution with standard deviation ``d / 2``.

    Returns:
        The perturbed ``(speed, angle)`` tuple.

    Raises:
        NotImplementedError: for any other ``dist``.
    """
    speed, angle = velocity
    d_speed, d_angle = maxAcceleration
    if dist == 'uniform':
        speed += np.random.uniform(-d_speed, d_speed)
        angle += np.random.uniform(-d_angle, d_angle)
    elif dist in ('guassian', 'gaussian'):
        # the misspelled key is kept because existing callers pass it
        speed += np.random.normal(0, d_speed / 2)
        angle += np.random.normal(0, d_angle / 2)
    else:
        raise NotImplementedError(
            f'Distribution type {dist} is not supported.')
    return (speed, angle)
322
+
323
+
324
def get_random_velocity(max_speed=3, dist='uniform'):
    """Sample a random ``(speed, angle)`` velocity.

    Args:
        max_speed: upper bound of the speed for ``'uniform'``; twice the
            standard deviation for the Gaussian case.
        dist: ``'uniform'``, or ``'guassian'`` (historical spelling) /
            ``'gaussian'`` for the absolute value of a normal draw.

    Returns:
        ``(speed, angle)`` with ``angle`` uniform in ``[0, 2*pi)``.

    Raises:
        NotImplementedError: for any other ``dist``.
    """
    if dist == 'uniform':
        # uniform(max_speed) would treat max_speed as the LOW bound with
        # high=1.0; pass both bounds so the speed lies in [0, max_speed).
        speed = np.random.uniform(0, max_speed)
    elif dist in ('guassian', 'gaussian'):
        # the misspelled key is kept because existing callers pass it
        speed = np.abs(np.random.normal(0, max_speed / 2))
    else:
        raise NotImplementedError(
            f'Distribution type {dist} is not supported.')
    angle = np.random.uniform(0, 2 * np.pi)
    return (speed, angle)
334
+
335
+
336
def random_move_control_points(X,
                               Y,
                               imageHeight,
                               imageWidth,
                               lineVelocity,
                               region_size,
                               maxLineAcceleration=(3, 0.5),
                               maxInitSpeed=3):
    """Advance a region position by its velocity and update the velocity.

    ``X`` is bounded by ``imageHeight - region_height`` and ``Y`` by
    ``imageWidth - region_width``. The velocity is always perturbed with
    ``random_accelerate``; if the new position leaves the allowed range a
    fresh velocity is drawn instead and the position is clipped.

    Returns:
        ``(new_X, new_Y, lineVelocity)``.
    """
    region_width, region_height = region_size
    speed, angle = lineVelocity
    X += int(speed * np.cos(angle))
    Y += int(speed * np.sin(angle))
    lineVelocity = random_accelerate(lineVelocity,
                                     maxLineAcceleration,
                                     dist='guassian')

    max_x = imageHeight - region_height
    max_y = imageWidth - region_width
    out_of_bounds = (X > max_x) or (X < 0) or (Y > max_y) or (Y < 0)
    if out_of_bounds:
        lineVelocity = get_random_velocity(maxInitSpeed, dist='guassian')

    return np.clip(X, 0, max_x), np.clip(Y, 0, max_y), lineVelocity
357
+
358
+
359
if __name__ == '__main__':
    # Visual smoke test: show ten generated mask sequences of ten frames
    # each, calling cv2.waitKey(500) after every frame.
    trials = 10
    video_length = 10
    for _ in range(trials):
        # The returned masks are either stationary (50%) or moving (50%)
        masks = create_random_shape_with_random_motion(video_length,
                                                       imageHeight=240,
                                                       imageWidth=432)
        for m in masks:
            cv2.imshow('mask', np.array(m))
            cv2.waitKey(500)
datasets/davis/test.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"bear": 82, "blackswan": 50, "bmx-bumps": 90, "bmx-trees": 80, "boat": 75, "breakdance": 84, "breakdance-flare": 71, "bus": 80, "camel": 90, "car-roundabout": 75, "car-shadow": 40, "car-turn": 80, "cows": 104, "dance-jump": 60, "dance-twirl": 90, "dog": 60, "dog-agility": 25, "drift-chicane": 52, "drift-straight": 50, "drift-turn": 64, "elephant": 80, "flamingo": 80, "goat": 90, "hike": 80, "hockey": 75, "horsejump-high": 50, "horsejump-low": 60, "kite-surf": 50, "kite-walk": 80, "libby": 49, "lucia": 70, "mallard-fly": 70, "mallard-water": 80, "motocross-bumps": 60, "motocross-jump": 40, "motorbike": 43, "paragliding": 70, "paragliding-launch": 80, "parkour": 100, "rhino": 90, "rollerblade": 35, "scooter-black": 43, "scooter-gray": 75, "soapbox": 99, "soccerball": 48, "stroller": 91, "surf": 55, "swing": 60, "tennis": 70, "train": 80}
datasets/davis/train.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"baseball": 90, "basketball-game": 77, "bears-ball": 78, "bmx-rider": 85, "butterfly": 80, "car-competition": 66, "cat": 52, "chairlift": 99, "circus": 73, "city-ride": 70, "crafting": 45, "curling": 76, "dog-competition": 85, "dolphins-show": 74, "dribbling": 49, "drone-flying": 70, "ducks": 75, "elephant-hyenas": 55, "giraffes": 88, "gym-ball": 69, "helicopter-landing": 77, "horse-race": 80, "horses-kids": 78, "hurdles-race": 55, "ice-hockey": 52, "jet-ski": 83, "juggling-selfie": 78, "kayak-race": 63, "kids-robot": 75, "landing": 35, "luggage": 83, "mantaray": 73, "marbles": 70, "mascot": 78, "mermaid": 78, "monster-trucks": 99, "motorbike-indoors": 79, "motorbike-race": 88, "music-band": 87, "obstacles": 81, "obstacles-race": 48, "peacock": 75, "plane-exhibition": 73, "puppet": 100, "robot-battle": 85, "robotic-arm": 82, "rodeo": 85, "sea-turtle": 90, "skydiving-jumping": 75, "snowboard-race": 75, "snowboard-sand": 55, "surfer": 80, "swimmer": 86, "table-tennis": 70, "tram": 84, "trucks-race": 78, "twist-dance": 83, "volleyball-beach": 73, "water-slide": 88, "weightlifting": 90}
datasets/youtube-vos/test.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"0070461469": 91, "00bd64cb00": 180, "00fef116ee": 96, "012257ffcf": 180, "01475d1fe7": 180, "0163b18674": 96, "017fa2adaa": 180, "0232ba85ed": 180, "02b1a46f42": 180, "02caec8ac0": 91, "047436c72c": 96, "0481e165b4": 150, "04f98557e7": 144, "05e73c3ecb": 96, "08f95ce1ff": 144, "0b6db1c6fd": 96, "0bd8c18197": 180, "0c6d13ee2c": 91, "0c7ba00455": 96, "0cba3e52eb": 91, "0d16524447": 150, "0d4827437d": 150, "0d62fa582a": 180, "0e1f91c0d7": 91, "0ef454b3f0": 91, "10e18fcf0c": 96, "11105e147e": 91, "11444b16da": 91, "11a4df37a4": 180, "11b3298d6a": 96, "13006c4c7e": 96, "1345523ba1": 180, "144a16eb12": 180, "15a6536e74": 180, "1616507c9e": 180, "1655f4782a": 92, "16608ccef6": 96, "16bc05b66c": 150, "16f1e1779b": 96, "17caf00e26": 96, "18f1e2f716": 91, "191a0bfcdf": 180, "19d4acf831": 91, "1a1dc21969": 96, "1a72d9fcea": 150, "1a92c81edd": 180, "1b2c2022a3": 96, "1d1601d079": 180, "1db7b25d1c": 180, "1dee5b7b5a": 150, "1e0c2e54f2": 96, "1e458b1539": 92, "1e6ac08c86": 91, "1e790eae99": 56, "1ed0c6ca5b": 96, "1edbdb6d18": 180, "1f2015e056": 96, "215ac56b15": 180, "2233485b49": 96, "224d171af6": 180, "237c6ebaf4": 91, "2462c51412": 96, "24bf968338": 180, "250d5953a0": 150, "25bcf222fb": 180, "25ea8feecf": 150, "25fc493839": 92, "262f69837e": 180, "264ca20298": 180, "26d8d48248": 51, "270f84c5e5": 91, "27889bc0fe": 180, "29b87846e7": 96, "29d2e79171": 180, "2a44411a3d": 180, "2b426fd330": 180, "2c4c4e2d5b": 180, "2c4c718eda": 180, "2c962c1bbe": 180, "2cc841341c": 92, "2cf6c4d17e": 91, "2d7ef0be04": 180, "2e5e52c6c8": 150, "2ef6fce8c6": 144, "3014e769bf": 180, "30d5f163b6": 180, "318df73d6a": 90, "31fbb9df3c": 96, "3255fcad2f": 180, "3303eea8e4": 91, "3447c30052": 150, "362722660c": 180, "37e0b4642b": 91, "383e51ed93": 180, "386b050bd0": 41, "3876ba3136": 180, "388ec2934c": 180, "38b45d9c6b": 96, "396680839c": 150, "39ffa3a4a4": 180, "3b0291b2be": 150, "3b333693f4": 180, "3bde1da2cf": 96, "3c5f4e6672": 91, "3c80682cc6": 92, "3ce634a1c1": 180, "3d6a761295": 96, "3da878c317": 
91, "3db571b7ee": 96, "3e2336812c": 180, "3f16b04d6d": 96, "3fbbc75c5e": 180, "4015a1e1cc": 87, "406cd7bd48": 91, "407b87ba26": 91, "40a5628dcc": 91, "41af239f5e": 180, "42c671b285": 180, "42de37f462": 180, "4381c60a2f": 180, "4445dc0af5": 180, "44a3419d24": 180, "4566034eaf": 51, "45877fd086": 180, "4595935b88": 91, "4923010cfe": 96, "49b6d81ee8": 180, "4a39c34139": 180, "4a5a9fde01": 144, "4a90394892": 180, "4af10534e4": 180, "4af307f5bc": 180, "4be0ac97df": 91, "4be9025726": 91, "4c18a7bfab": 91, "4c269afea9": 91, "4c3db058db": 179, "4e1ef26a1e": 96, "50f4c0195b": 150, "50f89963c0": 96, "5105c5e4b8": 180, "51d60e4f93": 46, "51ee638399": 96, "522ea1a892": 180, "528e9f30e7": 91, "532efb206a": 180, "544b1486ac": 91, "5592eb680c": 180, "562fadda3a": 91, "568b30cf93": 150, "575f0e2d8e": 91, "5767fe466c": 150, "581c78d558": 180, "5a0ddcf128": 96, "5adf056317": 144, "5b33c701ce": 180, "5b8f636b33": 150, "5b9d26b1d7": 180, "5c24813a0b": 180, "5d0b35f30f": 46, "5e130392e1": 96, "5e41efe5bc": 180, "5e75de78ae": 91, "5fc34880f7": 180, "60912d6bab": 96, "612c96383d": 180, "61e5fd2205": 144, "620e350d23": 180, "62c27fcaaf": 180, "637c22d967": 91, "63eaebe4a2": 96, "63fd6b311e": 180, "64099f32ab": 180, "65643c4b34": 96, "660a88feb5": 180, "664b8d0c9f": 150, "665a7947b0": 180, "66affc2e86": 180, "673b1c03c9": 96, "67780f49c2": 91, "679a24b7bd": 180, "680d35b75b": 144, "68364a69ef": 180, "683bfaf498": 180, "68e883ff28": 180, "691f63f681": 180, "69f2d3146c": 96, "6c5c018237": 91, "6caa33f43a": 96, "6d2c7cc107": 180, "6d55effbbe": 144, "6d6b09b420": 51, "6d715acc3e": 180, "6e89b7359d": 96, "6e9428d555": 150, "6e9feafa2b": 91, "6eced45fee": 180, "6ef0b3282c": 96, "6f9019f0ea": 91, "6fe0ee9b7c": 180, "6ff74d4995": 180, "712b6ec68e": 96, "71680a627f": 96, "716aad4b56": 180, "721c2cda07": 180, "72218d52ac": 96, "7286b8aac9": 91, "728ba7998d": 91, "73b2b9af5f": 96, "7452941f4f": 180, "759d8249dd": 91, "75a55907dc": 150, "75f3a2a19e": 150, "77e7e4b1a1": 144, "7898e6542c": 180, 
"78e639c2c4": 91, "79091168f8": 180, "7ad5af3fe6": 180, "7b1a7dec16": 150, "7b36c4c3db": 180, "7b455d07cc": 150, "7bce4cfa48": 180, "7c064444d0": 144, "7c8014406a": 91, "7cb70182e5": 96, "7d04e540f5": 91, "7d5df020bf": 96, "7dfda4322c": 96, "7e6a27cc7c": 96, "7e9e344bf4": 180, "7eb9424a53": 180, "7ec8ea61f4": 91, "7fd2806fb0": 180, "8006501830": 150, "8014aeb412": 180, "80d1d22999": 180, "812f31be15": 144, "81312af68f": 92, "82843a1676": 150, "835aea9584": 36, "8366c67e9b": 180, "8467aa6c5c": 180, "8470ee5f48": 180, "8473ae2c60": 180, "8519765a65": 150, "851f73e4fc": 96, "85621c2c81": 150, "85b045995c": 180, "860c0a7cf8": 92, "861bd4b31e": 180, "8639adb930": 180, "8683e4d414": 150, "8687e892ff": 180, "86c5907811": 180, "870c197c8b": 180, "87de455fb7": 180, "87e1975888": 96, "87f5d4903c": 96, "883ede763d": 150, "88b84fe107": 91, "88ee198ce0": 91, "89d148a39f": 96, "89f3d789c5": 180, "8a22bb6c32": 180, "8a76048654": 180, "8a99d63296": 97, "8b0697f61a": 96, "8b722babfb": 180, "8ba5691030": 180, "8bdd52a66b": 150, "8c427b6a57": 180, "8cb68f36f6": 91, "8cbf0d6194": 180, "8d1ab4a2ed": 91, "8d55a5aebb": 180, "8d8c5906bd": 180, "8eb95e2e56": 150, "8f99788aa7": 180, "8fa5b3778f": 91, "9009ab4811": 91, "90c10e44cf": 91, "90c2c5c336": 96, "9124189275": 91, "91ee8300e7": 144, "9246556dfd": 91, "9323741e3b": 150, "94a33d3d20": 180, "9584210f86": 91, "9637e3b658": 51, "966c4c022e": 180, "9781e083b5": 180, "990d358980": 180, "995c087687": 150, "99a7d42674": 144, "99f056c109": 180, "9a29032b9c": 180, "9b07fc4cf6": 180, "9b5aa49509": 96, "9b5abb8108": 91, "9be210e984": 150, "9c3c28740e": 180, "9cace717c5": 180, "9d3ff7c1c1": 91, "9d8c66d92c": 150, "9eaa2f1fcc": 91, "9f1967f60f": 96, "9fa359e1cb": 150, "9fca469ddd": 96, "9ff11b620a": 180, "9ff655b9a3": 180, "a029b21901": 180, "a0c7eedeb8": 144, "a15e70486b": 180, "a35bef8bbf": 180, "a4309379a2": 91, "a51335af59": 96, "a5690fb3bf": 180, "a5b71f76fb": 86, "a5c8b1f945": 150, "a635426233": 150, "a73cc75b81": 144, "a7863d3903": 180, 
"a88f1fd4e3": 144, "aa2e90aa98": 144, "aab5ecf878": 91, "aafc5edf08": 96, "ab49400ffe": 180, "acd7b890f6": 91, "ad3ee9b86b": 180, "ad5fda372c": 144, "adb2040e5f": 91, "ae30aed29d": 180, "ae57b941a0": 180, "aeb9de8f66": 41, "af658a277c": 91, "af881cd801": 150, "b016a85236": 180, "b0313efe37": 96, "b19d6e149a": 120, "b19f091836": 180, "b2304e81df": 144, "b2d23dcf3a": 150, "b3cee57f31": 36, "b41a7ebfc6": 180, "b455f801b5": 46, "b47336c07b": 96, "b499ce791f": 180, "b52d26ddf9": 96, "b5c525cb08": 180, "b5d3b9be03": 91, "b6386bc3ce": 96, "b748b0f3be": 180, "b75e9ea782": 180, "b8237af453": 180, "b8a2104720": 96, "b8d6f92a65": 96, "b8f93a4094": 180, "bb0a1708ea": 180, "bb2245ab94": 180, "bb4ae8019f": 180, "bbdc38baa0": 76, "bbfe438d63": 96, "bc2be9fdc8": 96, "bcc00265f4": 96, "bd42cc48e4": 150, "bd43315417": 180, "bd85b04982": 51, "bda3146a46": 96, "be2b40d82a": 150, "c0f856e4de": 96, "c1bfacba4a": 91, "c1dcd30fb2": 96, "c285ede7f3": 180, "c2a6163d39": 150, "c3517ebed5": 86, "c3aabac30c": 180, "c3bb62a2f7": 144, "c454f19e90": 150, "c4c410ccd7": 180, "c5b94822e3": 180, "c64e9d1f7e": 91, "c682d1748f": 150, "c6d04b1ca3": 180, "c6dda81d86": 180, "c71623ab0c": 180, "c7db88a9db": 144, "c80ecb97d6": 150, "c8dd4de705": 180, "c915c8cbba": 150, "cb25a994d8": 144, "cba3e31e88": 91, "cc43a853e2": 180, "cc6c653874": 180, "cc718c7746": 180, "cc7e050f7f": 144, "cd14ed8653": 144, "cd5e4efaad": 46, "cddf78284d": 86, "cde37afe57": 144, "ce358eaf23": 150, "ce45145721": 91, "ce7d4af66d": 180, "ce9fb4bd8e": 91, "cec4db17a0": 180, "cecdd82d3c": 180, "ceea39e735": 180, "cf3e28c92a": 180, "cf8c671dab": 150, "cfd1e8166f": 96, "cfe7d98e50": 150, "cff0bbcba8": 96, "d1219663b7": 180, "d18ea7cd51": 180, "d1ed509b94": 91, "d22c5d5908": 81, "d2c6c7d8f6": 96, "d380084b7c": 91, "d3a2586e34": 180, "d3b1039c67": 180, "d3b25a44b3": 180, "d3f1d615b1": 180, "d7203fdab6": 96, "d76e963754": 96, "d7b3892660": 66, "d8b3e257da": 150, "d8b93e6bb1": 180, "d949468ad6": 180, "da553b619f": 180, "daac20af89": 180, 
"db8bf2430a": 180, "dbd729449a": 180, "dc0928b157": 91, "dc9aa0b8c0": 180, "dcc0637430": 180, "dcd3e1b53e": 86, "de1854f657": 101, "deb31e46cf": 96, "debccf2743": 150, "decf924833": 150, "e08b241b91": 180, "e0daa3b339": 180, "e1a52251b7": 180, "e1fc6d5237": 91, "e228ce16fd": 96, "e36dbb2ab7": 91, "e3dcf7a45e": 180, "e411e957af": 180, "e412e6a76b": 180, "e45a003b97": 179, "e60826ddf9": 91, "e6295c843b": 96, "e62c23b62b": 150, "e6b7a8fe73": 180, "e6f0e3131c": 180, "e7a3f8884e": 180, "e7c176739c": 180, "e965cd989b": 86, "e989440f7b": 150, "e98d115b9c": 81, "ea5f8c74d6": 180, "ea8a5b5a78": 96, "eaad295e8c": 150, "eaf4947f74": 180, "eb65451f4b": 92, "eb79c39e8e": 180, "eb92c92912": 96, "ebbb88e5f5": 180, "ec9b46eb6c": 180, "eca0be379d": 180, "ed33e8efb7": 66, "eda3a7bbb1": 150, "ee3ff10184": 180, "eec8403cc8": 91, "eee2db8829": 150, "ef22b8a227": 91, "ef8737ca22": 180, "eff7c1c098": 180, "f00dc892b2": 96, "f019c9ff98": 96, "f01edcbffb": 179, "f0866da89c": 180, "f12eb5256e": 180, "f1df2ea2dc": 180, "f29119c644": 180, "f3419f3a62": 150, "f35029f76d": 180, "f39dc2240d": 180, "f3aa63fa74": 150, "f3f3c201bd": 180, "f4865471b4": 96, "f505ae958c": 91, "f7605e73cd": 150, "f7917687d6": 180, "f7d310e219": 180, "f7e25f87b2": 180, "f94cd39525": 91, "f9f9aa431c": 180, "fa666fcc95": 66, "fb10740465": 180, "fb25b14e48": 91, "fb28ec1ba3": 150, "fbdda5ec7b": 96, "fbdf2180ee": 150, "fc0db37221": 91, "fd237cf4fb": 180, "fe36582e18": 180, "fef14bb2f2": 180, "ffe59ed1c1": 150}
datasets/youtube-vos/train.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"003234408d": 180, "0043f083b5": 96, "0044fa5fba": 87, "005a527edd": 144, "0065b171f9": 180, "00917dcfc4": 96, "00a23ccf53": 180, "00ad5016a4": 91, "01082ae388": 150, "011ac0a06f": 180, "013099c098": 91, "0155498c85": 180, "01694ad9c8": 91, "017ac35701": 180, "01b80e8e1a": 61, "01baa5a4e1": 150, "01c3111683": 180, "01c4cb5ffe": 180, "01c76f0a82": 96, "01c783268c": 180, "01e64dd36a": 91, "01ed275c6e": 96, "01ff60d1fa": 180, "020cd28cd2": 150, "02264db755": 180, "0248626d9a": 91, "02668dbffa": 150, "0274193026": 144, "02d28375aa": 180, "02f3a5c4df": 46, "031ccc99b1": 91, "0321b18c10": 92, "0348a45bca": 180, "0355e92655": 92, "0358b938c1": 91, "0368107cf1": 96, "0379ddf557": 180, "038b2cc71d": 91, "038c15a5dd": 178, "03a06cc98a": 96, "03a63e187f": 180, "03c95b4dae": 92, "03e2b57b0e": 150, "04194e1248": 180, "04259896e2": 180, "0444918a5f": 96, "04460a7a52": 180, "04474174a4": 180, "0450095513": 150, "045f00aed2": 180, "04667fabaa": 180, "04735c5030": 91, "04990d1915": 92, "04d62d9d98": 96, "04f21da964": 180, "04fbad476e": 180, "04fe256562": 96, "0503bf89c9": 150, "0536c9eed0": 92, "054acb238f": 180, "05579ca250": 150, "056c200404": 96, "05774f3a2c": 180, "058a7592c8": 96, "05a0a513df": 96, "05a569d8aa": 91, "05aa652648": 150, "05d7715782": 96, "05e0b0f28f": 150, "05fdbbdd7a": 66, "05ffcfed85": 180, "0630391881": 150, "06840b2bbe": 91, "068f7dce6f": 180, "0693719753": 150, "06ce2b51fb": 91, "06e224798e": 180, "06ee361788": 91, "06fbb3fa2c": 90, "0700264286": 96, "070c918ca7": 180, "07129e14a4": 180, "07177017e9": 86, "07238ffc58": 180, "07353b2a89": 150, "0738493cbf": 87, "075926c651": 87, "075c701292": 180, "0762ea9a30": 96, "07652ee4af": 150, "076f206928": 96, "077d32af19": 96, "079049275c": 144, "07913cdda7": 92, "07a11a35e8": 180, "07ac33b6df": 150, "07b6e8fda8": 46, "07c62c3d11": 180, "07cc1c7d74": 180, "080196ef01": 180, "081207976e": 96, "081ae4fa44": 150, "081d8250cb": 96, "082900c5d4": 96, "0860df21e2": 180, "0866d4c5e3": 91, "0891ac2eb6": 81, "08931bc458": 
180, "08aa2705d5": 180, "08c8450db7": 96, "08d50b926c": 180, "08e1e4de15": 180, "08e48c1a48": 92, "08f561c65e": 180, "08feb87790": 96, "09049f6fe3": 150, "092e4ff450": 180, "09338adea8": 180, "093c335ccc": 144, "0970d28339": 180, "0974a213dc": 96, "097b471ed8": 96, "0990941758": 180, "09a348f4fa": 150, "09a6841288": 96, "09c5bad17b": 96, "09c9ce80c7": 180, "09ff54fef4": 150, "0a23765d15": 91, "0a275e7f12": 96, "0a2f2bd294": 96, "0a7a2514aa": 96, "0a7b27fde9": 180, "0a8c467cc3": 180, "0ac8c560ae": 96, "0b1627e896": 96, "0b285c47f6": 144, "0b34ec1d55": 180, "0b5b5e8e5a": 96, "0b68535614": 180, "0b6f9105fc": 180, "0b7dbfa3cb": 91, "0b9cea51ca": 180, "0b9d012be8": 180, "0bcfc4177d": 96, "0bd37b23c1": 96, "0bd864064c": 158, "0c11c6bf7b": 180, "0c26bc77ac": 180, "0c3a04798c": 96, "0c44a9d545": 180, "0c817cc390": 180, "0ca839ee9a": 180, "0cd7ac0ac0": 150, "0ce06e0121": 180, "0cfe974a89": 180, "0d2fcc0dcd": 96, "0d3aad05d2": 144, "0d40b015f4": 180, "0d97fba242": 91, "0d9cc80d7e": 51, "0dab85b6d3": 144, "0db5c427a5": 96, "0dbaf284f1": 97, "0de4923598": 97, "0df28a9101": 150, "0e04f636c4": 150, "0e05f0e232": 180, "0e0930474b": 91, "0e27472bea": 180, "0e30020549": 144, "0e621feb6c": 180, "0e803c7d73": 91, "0e9ebe4e3c": 92, "0e9f2785ec": 96, "0ea68d418b": 96, "0eb403a222": 96, "0ee92053d6": 97, "0eefca067f": 150, "0f17fa6fcb": 180, "0f1ac8e9a3": 180, "0f202e9852": 91, "0f2ab8b1ff": 180, "0f51a78756": 150, "0f5fbe16b0": 180, "0f6072077b": 91, "0f6b69b2f4": 180, "0f6c2163de": 144, "0f74ec5599": 180, "0f9683715b": 96, "0fa7b59356": 180, "0fb173695b": 96, "0fc958cde2": 150, "0fe7b1a621": 180, "0ffcdb491c": 96, "101caff7d4": 96, "1022fe8417": 96, "1032e80b37": 96, "103f501680": 180, "104e64565f": 96, "104f1ab997": 91, "106242403f": 96, "10b31f5431": 180, "10eced835e": 91, "110d26fa3a": 150, "1122c1d16a": 180, "1145b49a5f": 180, "11485838c2": 96, "114e7676ec": 180, "1157472b95": 180, "115ee1072c": 91, "1171141012": 150, "117757b4b8": 180, "1178932d2f": 180, "117cc76bda": 180, 
"1180cbf814": 180, "1187bbd0e3": 96, "1197e44b26": 180, "119cf20728": 180, "119dd54871": 180, "11a0c3b724": 91, "11a6ba8c94": 180, "11c722a456": 180, "11cbcb0b4d": 96, "11ccf5e99d": 96, "11ce6f452e": 91, "11e53de6f2": 46, "11feabe596": 150, "120cb9514d": 180, "12156b25b3": 180, "122896672d": 180, "1232b2f1d4": 36, "1233ac8596": 97, "1239c87234": 180, "1250423f7c": 96, "1257a1bc67": 180, "125d1b19dd": 180, "126d203967": 180, "1295e19071": 96, "12ad198c54": 144, "12bddb2bcb": 150, "12ec9b93ee": 180, "12eebedc35": 91, "132852e094": 180, "1329409f2a": 180, "13325cfa14": 96, "1336440745": 180, "134d06dbf9": 97, "135625b53d": 144, "13870016f9": 92, "13960b3c84": 96, "13adaad9d9": 180, "13ae097e20": 180, "13e3070469": 96, "13f6a8c20d": 144, "1416925cf2": 92, "142d2621f5": 91, "145d5d7c03": 180, "145fdc3ac5": 180, "1471274fa7": 76, "14a6b5a139": 180, "14c21cea0d": 180, "14dae0dc93": 96, "14f9bd22b5": 180, "14fd28ae99": 180, "15097d5d4e": 144, "150ea711f2": 180, "1514e3563f": 180, "152aaa3a9e": 180, "152b7d3bd7": 150, "15617297cc": 180, "15abbe0c52": 150, "15d1fb3de5": 180, "15f67b0fab": 180, "161eb59aad": 96, "16288ea47f": 180, "164410ce62": 91, "165c3c8cd4": 96, "165c42b41b": 91, "165ec9e22b": 144, "1669502269": 91, "16763cccbb": 150, "16adde065e": 96, "16af445362": 96, "16afd538ad": 150, "16c3fa4d5d": 96, "16d1d65c27": 180, "16e8599e94": 180, "16fe9fb444": 91, "1705796b02": 96, "1724db7671": 144, "17418e81ea": 180, "175169edbb": 144, "17622326fd": 180, "17656bae77": 91, "17b0d94172": 61, "17c220e4f6": 180, "17c7bcd146": 96, "17cb4afe89": 180, "17cd79a434": 180, "17d18604c3": 96, "17d8ca1a37": 150, "17e33f4330": 180, "17f7a6d805": 150, "180abc8378": 180, "183ba3d652": 96, "185bf64702": 96, "18913cc690": 91, "1892651815": 180, "189ac8208a": 91, "189b44e92c": 97, "18ac264b76": 150, "18b245ab49": 91, "18b5cebc34": 150, "18bad52083": 180, "18bb5144d5": 180, "18c6f205c5": 96, "1903f9ea15": 96, "1917b209f2": 91, "191e74c01d": 150, "19367bb94e": 180, "193ffaa217": 91, 
"19696b67d3": 96, "197f3ab6f3": 180, "1981e763cc": 180, "198afe39ae": 144, "19a6e62b9b": 150, "19b60d5335": 180, "19c00c11f9": 150, "19e061eb88": 91, "19e8bc6178": 86, "19ee80dac6": 180, "1a25a9170a": 180, "1a359a6c1a": 150, "1a3e87c566": 150, "1a5fe06b00": 91, "1a6c0fbd1e": 144, "1a6f3b5a4b": 96, "1a8afbad92": 92, "1a8bdc5842": 150, "1a95752aca": 150, "1a9c131cb7": 180, "1aa3da3ee3": 150, "1ab27ec7ea": 56, "1abf16d21d": 150, "1acd0f993b": 180, "1ad202e499": 180, "1af8d2395d": 180, "1afd39a1fa": 91, "1b2d31306f": 180, "1b3fa67f0e": 92, "1b43fa74b4": 150, "1b73ea9fc2": 92, "1b7e8bb255": 96, "1b8680f8cd": 180, "1b883843c0": 91, "1b8898785b": 180, "1b88ba1aa4": 180, "1b96a498e5": 150, "1bbc4c274f": 96, "1bd87fe9ab": 66, "1c4090c75b": 180, "1c41934f84": 96, "1c72b04b56": 180, "1c87955a3a": 150, "1c9f9eb792": 180, "1ca240fede": 96, "1ca5673803": 180, "1cada35274": 180, "1cb44b920d": 180, "1cd10e62be": 150, "1d3087d5e5": 180, "1d3685150a": 92, "1d6ff083aa": 96, "1d746352a6": 66, "1da256d146": 91, "1da4e956b1": 180, "1daf812218": 150, "1dba687bce": 180, "1dce57d05d": 86, "1de4a9e537": 97, "1dec5446c8": 180, "1dfbe6f586": 150, "1e1a18c45a": 180, "1e1e42529d": 76, "1e4be70796": 96, "1eb60959c8": 180, "1ec8b2566b": 180, "1ecdc2941c": 180, "1ee0ac70ff": 87, "1ef8e17def": 91, "1f1a2a9fc0": 86, "1f1beb8daa": 150, "1f2609ee13": 180, "1f3876f8d0": 144, "1f4ec0563d": 150, "1f64955634": 96, "1f7d31b5b2": 96, "1f8014b7fd": 96, "1f9c7d10f1": 180, "1fa350df76": 96, "1fc9538993": 180, "1fe2f0ec59": 150, "2000c02f9d": 180, "20142b2f05": 180, "201a8d75e5": 150, "2023b3ee4f": 180, "202b767bbc": 92, "203594a418": 180, "2038987336": 150, "2039c3aecb": 96, "204a90d81f": 150, "207bc6cf01": 144, "208833d1d1": 180, "20c6d8b362": 46, "20e3e52e0a": 96, "2117fa0c14": 180, "211bc5d102": 150, "2120d9c3c3": 150, "2125235a49": 180, "21386f5978": 92, "2142af8795": 150, "215dfc0f73": 96, "217bae91e5": 180, "217c0d44e4": 150, "219057c87b": 150, "21d0edbf81": 96, "21df87ad76": 96, "21f1d089f5": 96, 
"21f4019116": 180, "222597030f": 91, "222904eb5b": 92, "223a0e0657": 180, "223bd973ab": 92, "22472f7395": 150, "224e7c833e": 96, "225aba51d9": 86, "2261d421ea": 180, "2263a8782b": 180, "2268cb1ffd": 150, "2268e93b0a": 61, "2293c99f3f": 180, "22a1141970": 91, "22b13084b2": 180, "22d9f5ab0c": 180, "22f02efe3a": 144, "232c09b75b": 150, "2350d71b4b": 180, "2376440551": 180, "2383d8aafd": 144, "238b84e67f": 96, "238d4b86f6": 91, "238d947c6b": 46, "23993ce90d": 180, "23b0c8a9ab": 150, "23b3beafcc": 156, "23d80299fe": 92, "23f404a9fc": 96, "240118e58a": 178, "2431dec2fd": 180, "24440e0ac7": 97, "2457274dbc": 180, "2465bf515d": 91, "246b142c4d": 180, "247d729e36": 96, "2481ceafeb": 150, "24866b4e6a": 150, "2489d78320": 180, "24ab0b83e8": 180, "24b0868d92": 180, "24b5207cd9": 96, "24ddf05c03": 92, "250116161c": 71, "256ad2e3fc": 180, "256bd83d5e": 180, "256dcc8ab8": 180, "2589956baa": 150, "258b3b33c6": 91, "25ad437e29": 96, "25ae395636": 180, "25c750c6db": 150, "25d2c3fe5d": 180, "25dc80db7c": 96, "25f97e926f": 180, "26011bc28b": 150, "260846ffbe": 180, "260dd9ad33": 66, "267964ee57": 92, "2680861931": 96, "268ac7d3fc": 180, "26b895d91e": 71, "26bc786d4f": 91, "26ddd2ef12": 180, "26de3d18ca": 150, "26f7784762": 180, "2703e52a6a": 180, "270ed80c12": 180, "2719b742ab": 180, "272f4163d0": 180, "27303333e1": 96, "27659fa7d6": 180, "279214115d": 180, "27a5f92a9c": 97, "27cf2af1f3": 150, "27f0d5f8a2": 86, "28075f33c1": 180, "281629cb41": 96, "282b0d51f5": 96, "282fcab00b": 96, "28449fa0dc": 180, "28475208ca": 96, "285580b7c4": 180, "285b69e223": 150, "288c117201": 150, "28a8eb9623": 180, "28bf9c3cf3": 180, "28c6b8f86a": 180, "28c972dacd": 144, "28d9fa6016": 96, "28e392de91": 144, "28f4a45190": 150, "298c844fc9": 91, "29a0356a2b": 180, "29d779f9e3": 76, "29dde5f12b": 86, "29de7b6579": 150, "29e630bdd0": 144, "29f2332d30": 144, "2a18873352": 92, "2a3824ff31": 91, "2a559dd27f": 96, "2a5c09acbd": 76, "2a63eb1524": 96, "2a6a30a4ea": 150, "2a6d9099d1": 180, "2a821394e3": 81, 
"2a8c5b1342": 96, "2abc8d66d2": 96, "2ac9ef904a": 46, "2b08f37364": 150, "2b351bfd7d": 180, "2b659a49d7": 66, "2b69ee5c26": 96, "2b6c30bbbd": 180, "2b88561cf2": 144, "2b8b14954e": 180, "2ba621c750": 150, "2bab50f9a7": 180, "2bb00c2434": 91, "2bbde474ef": 92, "2bdd82fb86": 150, "2be06fb855": 96, "2bf545c2f5": 180, "2bffe4cf9a": 96, "2c04b887b7": 144, "2c05209105": 180, "2c0ad8cf39": 180, "2c11fedca8": 56, "2c1a94ebfb": 91, "2c1e8c8e2f": 180, "2c29fabcf1": 96, "2c2c076c01": 180, "2c3ea7ee7d": 92, "2c41fa0648": 87, "2c44bb6d1c": 96, "2c54cfbb78": 180, "2c5537eddf": 180, "2c6e63b7de": 150, "2cb10c6a7e": 180, "2cbcd5ccd1": 180, "2cc5d9c5f6": 180, "2cd01cf915": 180, "2cdbf5f0a7": 91, "2ce660f123": 96, "2cf114677e": 150, "2d01eef98e": 180, "2d03593bdc": 96, "2d183ac8c4": 180, "2d33ad3935": 96, "2d3991d83e": 150, "2d4333577b": 180, "2d4d015c64": 96, "2d8f5e5025": 144, "2d900bdb8e": 180, "2d9a1a1d49": 46, "2db0576a5c": 180, "2dc0838721": 180, "2dcc417f82": 150, "2df005b843": 180, "2df356de14": 180, "2e00393d96": 61, "2e03b8127a": 180, "2e0f886168": 96, "2e2bf37e6d": 180, "2e42410932": 87, "2ea78f46e4": 180, "2ebb017a26": 180, "2ee2edba2a": 96, "2efb07554a": 180, "2f17e4fc1e": 96, "2f2c65c2f3": 144, "2f2d9b33be": 150, "2f309c206b": 180, "2f53822e88": 144, "2f53998171": 96, "2f5b0c89b1": 180, "2f680909e6": 180, "2f710f66bd": 180, "2f724132b9": 91, "2f7e3517ae": 91, "2f96f5fc6f": 180, "2f97d9fecb": 96, "2fbfa431ec": 96, "2fc9520b53": 180, "2fcd9f4c62": 180, "2feb30f208": 87, "2ff7f5744f": 150, "30085a2cc6": 96, "30176e3615": 56, "301f72ee11": 92, "3026bb2f61": 180, "30318465dc": 150, "3054ca937d": 180, "306121e726": 92, "3064ad91e8": 180, "307444a47f": 180, "307bbb7409": 91, "30a20194ab": 144, "30c35c64a4": 150, "30dbdb2cd6": 91, "30fc77d72f": 150, "310021b58b": 96, "3113140ee8": 144, "3150b2ee57": 180, "31539918c4": 180, "318dfe2ce2": 144, "3193da4835": 91, "319f725ad9": 180, "31bbd0d793": 91, "322505c47f": 180, "322b237865": 92, "322da43910": 97, "3245e049fb": 66, 
"324c4c38f6": 180, "324e35111a": 150, "3252398f09": 150, "327dc4cabf": 180, "328d918c7d": 180, "3290c0de97": 96, "3299ae3116": 180, "32a7cd687b": 150, "33098cedb4": 92, "3332334ac4": 180, "334cb835ac": 180, "3355e056eb": 180, "33639a2847": 180, "3373891cdc": 180, "337975816b": 180, "33e29d7e91": 96, "34046fe4f2": 180, "3424f58959": 180, "34370a710f": 92, "343bc6a65a": 179, "3450382ef7": 144, "3454303a08": 180, "346aacf439": 180, "346e92ff37": 180, "34a5ece7dd": 144, "34b109755a": 180, "34d1b37101": 96, "34dd2c70a7": 180, "34efa703df": 180, "34fbee00a6": 150, "3504df2fda": 96, "35195a56a1": 150, "351c822748": 180, "351cfd6bc5": 180, "3543d8334c": 180, "35573455c7": 96, "35637a827f": 96, "357a710863": 92, "358bf16f9e": 96, "35ab34cc34": 180, "35c6235b8d": 91, "35d01a438a": 180, "3605019d3b": 96, "3609bc3f88": 92, "360e25da17": 97, "36299c687c": 96, "362c5bc56e": 180, "3649228783": 150, "365b0501ea": 92, "365f459863": 180, "369893f3ad": 180, "369c9977e1": 180, "369dde050a": 96, "36c7dac02f": 180, "36d5b1493b": 180, "36f5cc68fd": 91, "3735480d18": 180, "374b479880": 97, "375a49d38f": 180, "375a5c0e09": 180, "376bda9651": 144, "377db65f60": 144, "37c19d1087": 46, "37d4ae24fc": 96, "37ddce7f8b": 180, "37e10d33af": 180, "37e45c6247": 96, "37fa0001e8": 180, "3802d458c0": 150, "382caa3cb4": 91, "383bb93111": 91, "388843df90": 180, "38924f4a7f": 92, "38b00f93d7": 92, "38c197c10e": 96, "38c9c3d801": 180, "38eb2bf67f": 92, "38fe9b3ed1": 180, "390352cced": 180, "390c51b987": 96, "390ca6f1d6": 144, "392bc0f8a1": 96, "392ecb43bd": 92, "3935291688": 150, "3935e63b41": 180, "394454fa9c": 180, "394638fc8b": 96, "39545e20b7": 180, "397abeae8f": 180, "3988074b88": 91, "398f5d5f19": 174, "39bc49a28c": 180, "39befd99fb": 144, "39c3c7bf55": 180, "39d584b09f": 91, "39f6f6ffb1": 180, "3a079fb484": 180, "3a0d3a81b7": 150, "3a1d55d22b": 82, "3a20a7583e": 96, "3a2c1f66e5": 150, "3a33f4d225": 180, "3a3bf84b13": 144, "3a4565e5ec": 144, "3a4e32ed5e": 180, "3a7ad86ce0": 180, "3a7bdde9b8": 180, 
"3a98867cbe": 91, "3aa3f1c9e8": 150, "3aa7fce8b6": 91, "3aa876887d": 96, "3ab807ded6": 96, "3ab9b1a85a": 96, "3adac8d7da": 180, "3ae1a4016f": 96, "3ae2deaec2": 180, "3ae81609d6": 144, "3af847e62f": 92, "3b23792b84": 144, "3b3b0af2ee": 150, "3b512dad74": 144, "3b6c7988f6": 91, "3b6e983b5b": 180, "3b74a0fc20": 180, "3b7a50b80d": 180, "3b96d3492f": 180, "3b9ad0c5a9": 150, "3b9ba0894a": 180, "3bb4e10ed7": 144, "3bd9a9b515": 150, "3beef45388": 96, "3c019c0a24": 96, "3c090704aa": 96, "3c2784fc0d": 144, "3c47ab95f8": 150, "3c4db32d74": 91, "3c5ff93faf": 180, "3c700f073e": 180, "3c713cbf2f": 91, "3c8320669c": 180, "3c90d225ee": 180, "3cadbcc404": 96, "3cb9be84a5": 150, "3cc37fd487": 91, "3cc6f90cb2": 92, "3cd5e035ef": 180, "3cdf03531b": 178, "3cdf828f59": 180, "3d254b0bca": 180, "3d5aeac5ba": 180, "3d690473e1": 180, "3d69fed2fb": 96, "3d8997aeb6": 96, "3db0d6b07e": 96, "3db1ddb8cf": 180, "3db907ac77": 180, "3dcbc0635b": 150, "3dd48ed55f": 144, "3de4ac4ec4": 92, "3decd63d88": 180, "3e04a6be11": 180, "3e108fb65a": 96, "3e1448b01c": 150, "3e16c19634": 180, "3e2845307e": 61, "3e38336da5": 96, "3e3a819865": 180, "3e3e4be915": 96, "3e680622d7": 91, "3e7d2aeb07": 96, "3e7d8f363d": 180, "3e91f10205": 26, "3ea4c49bbe": 144, "3eb39d11ab": 180, "3ec273c8d5": 96, "3ed3f91271": 76, "3ee062a2fd": 180, "3eede9782c": 180, "3ef2fa99cb": 180, "3efc6e9892": 92, "3f0b0dfddd": 96, "3f0c860359": 91, "3f18728586": 180, "3f3b15f083": 96, "3f45a470ad": 46, "3f4f3bc803": 150, "3fd96c5267": 91, "3fea675fab": 91, "3fee8cbc9f": 96, "3fff16d112": 180, "401888b36c": 144, "4019231330": 150, "402316532d": 180, "402680df52": 180, "404d02e0c0": 150, "40709263a8": 81, "4083cfbe15": 150, "40a96c5cb1": 96, "40b8e50f82": 91, "40f4026bf5": 144, "4100b57a3a": 150, "41059fdd0b": 180, "41124e36de": 144, "4122aba5f9": 180, "413bab0f0d": 96, "4164faee0b": 180, "418035eec9": 180, "4182d51532": 96, "418bb97e10": 144, "41a34c20e7": 96, "41dab05200": 180, "41ff6d5e2a": 77, "420caf0859": 56, "42264230ba": 96, 
"425a0c96e0": 91, "42da96b87c": 180, "42eb5a5b0f": 180, "42f17cd14d": 91, "42f5c61c49": 180, "42ffdcdee9": 180, "432f9884f9": 91, "43326d9940": 150, "4350f3ab60": 144, "4399ffade3": 96, "43a6c21f37": 150, "43b5555faa": 180, "43d63b752a": 180, "4416bdd6ac": 92, "4444753edd": 76, "444aa274e7": 150, "444d4e0596": 150, "446b8b5f7a": 96, "4478f694bb": 91, "44b1da0d87": 92, "44b4dad8c9": 96, "44b5ece1b9": 180, "44d239b24e": 150, "44eaf8f51e": 180, "44f4f57099": 96, "44f7422af2": 180, "450787ac97": 180, "4523656564": 96, "4536c882e5": 180, "453b65daa4": 180, "454f227427": 91, "45636d806a": 180, "456fb9362e": 91, "457e717a14": 150, "45a89f35e1": 180, "45bf0e947d": 150, "45c36a9eab": 150, "45d9fc1357": 174, "45f8128b97": 180, "4607f6c03c": 91, "46146dfd39": 92, "4620e66b1e": 150, "4625f3f2d3": 96, "462b22f263": 96, "4634736113": 180, "463c0f4fdd": 180, "46565a75f8": 96, "46630b55ae": 56, "466839cb37": 91, "466ba4ae0c": 180, "4680236c9d": 180, "46bf4e8709": 91, "46e18e42f1": 150, "46f5093c59": 180, "47269e0499": 92, "472da1c484": 144, "47354fab09": 180, "4743bb84a7": 92, "474a796272": 180, "4783d2ab87": 96, "479cad5da3": 180, "479f5d7ef6": 96, "47a05fbd1d": 96, "4804ee2767": 97, "4810c3fbca": 180, "482fb439c2": 150, "48375af288": 96, "484ab44de4": 96, "485f3944cd": 96, "4867b84887": 150, "486a8ac57e": 180, "486e69c5bd": 180, "48812cf33e": 150, "4894b3b9ea": 180, "48bd66517d": 180, "48d83b48a4": 91, "49058178b8": 46, "4918d10ff0": 91, "4932911f80": 150, "49405b7900": 180, "49972c2d14": 150, "499bf07002": 96, "49b16e9377": 180, "49c104258e": 144, "49c879f82d": 96, "49e7326789": 180, "49ec3e406a": 91, "49fbf0c98a": 96, "4a0255c865": 180, "4a088fe99a": 96, "4a341402d0": 180, "4a3471bdf5": 96, "4a4b50571c": 144, "4a50f3d2e9": 96, "4a6e3faaa1": 180, "4a7191f08a": 150, "4a86fcfc30": 180, "4a885fa3ef": 144, "4a8af115de": 21, "4aa2e0f865": 180, "4aa9d6527f": 180, "4abb74bb52": 96, "4ae13de1cd": 91, "4af8cb323f": 97, "4b02c272b3": 180, "4b19c529fb": 96, "4b2974eff4": 180, 
"4b3154c159": 95, "4b54d2587f": 180, "4b556740ff": 144, "4b67aa9ef6": 178, "4b97cc7b8d": 96, "4baa1ed4aa": 91, "4bc8c676bb": 96, "4beaea4dbe": 180, "4bf5763d24": 96, "4bffa92b67": 138, "4c25dfa8ec": 96, "4c397b6fd4": 180, "4c51e75d66": 150, "4c7710908f": 180, "4c9b5017be": 180, "4ca2ffc361": 92, "4cad2e93bc": 150, "4cd427b535": 180, "4cd9a4b1ef": 180, "4cdfe3c2b2": 180, "4cef87b649": 96, "4cf208e9b3": 180, "4cf5bc3e60": 92, "4cfdd73249": 91, "4cff5c9e42": 180, "4d26d41091": 96, "4d5c23c554": 180, "4d67c59727": 150, "4d983cad9f": 180, "4da0d00b55": 144, "4daa179861": 91, "4dadd57153": 92, "4db117e6c5": 91, "4de4ce4dea": 180, "4dfaee19e5": 180, "4dfdd7fab0": 180, "4e3f346aa5": 92, "4e49c2a9c7": 56, "4e4e06a749": 180, "4e70279712": 96, "4e72856cc7": 91, "4e752f8075": 180, "4e7a28907f": 66, "4e824b9247": 180, "4e82b1df57": 180, "4e87a639bc": 180, "4ea77bfd15": 150, "4eb6fc23a2": 180, "4ec9da329e": 96, "4efb9a0720": 180, "4f062fbc63": 96, "4f35be0e0b": 96, "4f37e86797": 91, "4f414dd6e7": 180, "4f424abded": 180, "4f470cc3ae": 144, "4f601d255a": 150, "4f7386a1ab": 144, "4f824d3dcd": 91, "4f827b0751": 144, "4f8db33a13": 180, "4fa160f8a3": 180, "4fa9c30a45": 180, "4facd8f0e8": 96, "4fca07ad01": 91, "4fded94004": 180, "4fdfef4dea": 91, "4feb3ac01f": 92, "4fffec8479": 96, "500c835a86": 180, "50168342bf": 180, "50243cffdc": 180, "5031d5a036": 180, "504dd9c0fd": 96, "50568fbcfb": 180, "5069c7c5b3": 180, "508189ac91": 180, "50b6b3d4b7": 91, "50c6f4fe3e": 86, "50cce40173": 180, "50efbe152f": 180, "50f290b95d": 91, "5104aa1fea": 96, "5110dc72c0": 180, "511e8ecd7f": 150, "513aada14e": 92, "5158d6e985": 180, "5161e1fa57": 180, "51794ddd58": 96, "517d276725": 91, "51a597ee04": 51, "51b37b6d97": 96, "51b5dc30a0": 96, "51e85b347b": 180, "51eea1fdac": 150, "51eef778af": 91, "51f384721c": 76, "521cfadcb4": 180, "52355da42f": 96, "5247d4b160": 180, "524b470fd0": 180, "524cee1534": 96, "5252195e8a": 91, "5255c9ca97": 144, "525928f46f": 96, "526df007a7": 180, "529b12de78": 91, "52c7a3d653": 
150, "52c8ec0373": 91, "52d225ed52": 96, "52ee406d9e": 180, "52ff1ccd4a": 96, "53143511e8": 180, "5316d11eb7": 96, "53253f2362": 180, "534a560609": 91, "5352c4a70e": 180, "536096501f": 92, "536b17bcea": 180, "5380eaabff": 144, "5390a43a54": 180, "53af427bb2": 91, "53bf5964ce": 180, "53c30110b5": 96, "53cad8e44a": 150, "53d9c45013": 91, "53e274f1b5": 150, "53e32d21ea": 96, "540850e1c7": 96, "540cb31cfe": 180, "541c4da30f": 91, "541d7935d7": 180, "545468262b": 180, "5458647306": 144, "54657855cd": 96, "547b3fb23b": 180, "5497dc3712": 150, "549c56f1d4": 96, "54a4260bb1": 150, "54b98b8d5e": 180, "54e1054b0f": 91, "54e8867b83": 180, "54ebe34f6e": 180, "5519b4ad13": 86, "551acbffd5": 150, "55341f42da": 180, "5566ab97e1": 91, "556c79bbf2": 144, "5589637cc4": 180, "558aa072f0": 180, "559824b6f6": 91, "55c1764e90": 180, "55eda6c77e": 180, "562d173565": 150, "5665c024cb": 96, "566cef4959": 91, "5675d78833": 144, "5678a91bd8": 180, "567a2b4bd0": 180, "569c282890": 86, "56cc449917": 150, "56e71f3e07": 150, "56f09b9d92": 180, "56fc0e8cf9": 144, "571ca79c71": 91, "57243657cf": 144, "57246af7d1": 91, "57427393e9": 96, "574b682c19": 180, "578f211b86": 180, "5790ac295d": 91, "579393912d": 180, "57a344ab1a": 180, "57bd3bcda4": 180, "57bfb7fa4c": 150, "57c010175e": 180, "57c457cc75": 180, "57c7fc2183": 150, "57d5289a01": 61, "58045fde85": 96, "58163c37cd": 150, "582d463e5c": 180, "5851739c15": 180, "585dd0f208": 66, "587250f3c3": 180, "589e4cc1de": 180, "589f65f5d5": 180, "58a07c17d5": 180, "58adc6d8b6": 76, "58b9bcf656": 96, "58c374917e": 96, "58fc75fd42": 87, "5914c30f05": 96, "59323787d5": 150, "5937b08d69": 96, "594065ddd7": 96, "595a0ceea6": 91, "59623ec40b": 91, "597ff7ef78": 150, "598935ef05": 46, "598c2ad3b2": 180, "59a6459751": 180, "59b175e138": 96, "59bf0a149f": 180, "59d53d1649": 180, "59e3e6fae7": 180, "59fe33e560": 180, "5a13a73fe5": 96, "5a25c22770": 150, "5a4a785006": 96, "5a50640995": 180, "5a75f7a1cf": 96, "5a841e59ad": 180, "5a91c5ab6d": 150, "5ab49d9de0": 96, 
"5aba1057fe": 180, "5abe46ba6d": 91, "5ac7c88d0c": 180, "5aeb95cc7d": 92, "5af15e4fc3": 91, "5afe381ae4": 96, "5b07b4229d": 51, "5b1001cc4f": 180, "5b1df237d2": 180, "5b263013bf": 91, "5b27d19f0b": 180, "5b48ae16c5": 96, "5b5babc719": 180, "5baaebdf00": 180, "5bab55cdbe": 180, "5bafef6e79": 96, "5bc77844da": 180, "5bd1f84545": 180, "5bddc3ba25": 180, "5bdf7c20d2": 180, "5bf23bc9d3": 180, "5c01f6171a": 144, "5c021681b7": 96, "5c185cff1d": 180, "5c42aba280": 180, "5c44bf8ab6": 180, "5c4c574894": 144, "5c52fa4662": 76, "5c6ea7dac3": 96, "5c74315dc2": 180, "5c7668855e": 92, "5c83e96778": 180, "5ca36173e4": 96, "5cac477371": 97, "5cb0cb1b2f": 96, "5cb0cfb98f": 144, "5cb49a19cf": 180, "5cbf7dc388": 180, "5d0e07d126": 96, "5d1e24b6e3": 81, "5d663000ff": 150, "5da6b2dc5d": 180, "5de9b90f24": 61, "5e08de0ed7": 180, "5e1011df9a": 87, "5e1ce354fd": 150, "5e35512dd7": 180, "5e418b25f9": 96, "5e4849935a": 144, "5e4ee19663": 96, "5e886ef78f": 96, "5e8d00b974": 180, "5e8d59dc31": 180, "5ed838bd5c": 96, "5edda6ee5a": 180, "5ede4d2f7a": 144, "5ede9767da": 144, "5ee23ca60e": 87, "5eec4d9fe5": 96, "5eecf07824": 180, "5eef7ed4f4": 91, "5ef5860ac6": 144, "5ef6573a99": 96, "5f1193e72b": 91, "5f29ced797": 96, "5f32cf521e": 150, "5f51876986": 96, "5f6ebe94a9": 86, "5f6f14977c": 91, "5f808d0d2d": 91, "5fb8aded6a": 180, "5fba90767d": 96, "5fd1c7a3df": 92, "5fd3da9f68": 91, "5fee2570ae": 180, "5ff66140d6": 180, "5ff8b85b53": 180, "600803c0f6": 180, "600be7f53e": 96, "6024888af8": 180, "603189a03c": 96, "6057307f6e": 180, "6061ddbb65": 96, "606c86c455": 180, "60c61cc2e5": 180, "60e51ff1ae": 150, "610e38b751": 150, "61344be2f6": 180, "6135e27185": 96, "614afe7975": 150, "614e571886": 180, "614e7078db": 96, "619812a1a7": 96, "61b481a78b": 96, "61c7172650": 180, "61cf7e40d2": 96, "61d08ef5a1": 46, "61da008958": 96, "61ed178ecb": 61, "61f5d1282c": 92, "61fd977e49": 144, "621584cffe": 180, "625817a927": 180, "625892cf0b": 96, "625b89d28a": 91, "629995af95": 150, "62a0840bb5": 180, "62ad6e121c": 
87, "62d6ece152": 91, "62ede7b2da": 91, "62f025e1bc": 180, "6316faaebc": 97, "63281534dc": 150, "634058dda0": 144, "6353f09384": 180, "6363c87314": 180, "636e4872e0": 180, "637681cd6b": 180, "6376d49f31": 180, "6377809ec2": 180, "63936d7de5": 96, "639bddef11": 150, "63d37e9fd3": 180, "63d90c2bae": 96, "63e544a5d6": 180, "63ebbcf874": 96, "63fff40b31": 180, "6406c72e4d": 61, "64148128be": 96, "6419386729": 150, "643092bc41": 96, "644081b88d": 144, "64453cf61d": 180, "644bad9729": 96, "6454f548fd": 180, "645913b63a": 180, "64750b825f": 180, "64a43876b7": 96, "64dd6c83e3": 92, "64e05bf46e": 96, "64f55f1478": 150, "650b0165e4": 180, "651066ed39": 180, "652b67d960": 180, "653821d680": 180, "6538d00d73": 180, "65866dce22": 150, "6589565c8c": 150, "659832db64": 180, "65ab7e1d98": 180, "65b7dda462": 180, "65bd5eb4f5": 180, "65dcf115ab": 91, "65e9825801": 180, "65f9afe51c": 91, "65ff12bcb5": 180, "666b660284": 180, "6671643f31": 180, "668364b372": 96, "66852243cb": 96, "6693a52081": 180, "669b572898": 180, "66e98e78f5": 91, "670f12e88f": 180, "674c12c92d": 91, "675c27208a": 180, "675ed3e1ca": 144, "67741db50a": 96, "678a2357eb": 70, "67b0f4d562": 180, "67cfbff9b1": 180, "67e717d6bd": 91, "67ea169a3b": 92, "67ea809e0e": 180, "681249baa3": 180, "683de643d9": 180, "6846ac20df": 96, "6848e012ef": 96, "684bcd8812": 96, "684dc1c40c": 96, "685a1fa9cf": 91, "686dafaac9": 144, "68807d8601": 96, "6893778c77": 96, "6899d2dabe": 91, "68a2fad4ab": 180, "68cb45fda3": 180, "68cc4a1970": 96, "68dcb40675": 180, "68ea4a8c3d": 180, "68f6e7fbf0": 96, "68fa8300b4": 180, "69023db81f": 96, "6908ccf557": 91, "691a111e7c": 180, "6927723ba5": 180, "692ca0e1a2": 97, "692eb57b63": 180, "69340faa52": 96, "693cbf0c9d": 180, "6942f684ad": 96, "6944fc833b": 180, "69491c0ebf": 91, "695b61a2b0": 96, "6979b4d83f": 180, "697d4fdb02": 144, "69910460a4": 180, "6997636670": 180, "69a436750b": 96, "69aebf7669": 180, "69b8c17047": 180, "69c67f109f": 180, "69e0e7b868": 180, "69ea9c09d1": 180, "69f0af42a6": 97, 
"6a078cdcc7": 144, "6a37a91708": 71, "6a42176f2e": 180, "6a48e4aea8": 96, "6a5977be3a": 180, "6a5de0535f": 180, "6a80d2e2e5": 96, "6a96c8815d": 180, "6a986084e2": 96, "6aa8e50445": 92, "6ab9dce449": 150, "6abf0ba6b2": 180, "6acc6049d9": 96, "6adb31756c": 180, "6ade215eb0": 96, "6afb7d50e4": 144, "6afd692f1a": 180, "6b0b1044fe": 91, "6b17c67633": 180, "6b1b6ef28b": 92, "6b1e04d00d": 180, "6b2261888d": 96, "6b25d6528a": 144, "6b3a24395c": 150, "6b685eb75b": 96, "6b79be238c": 92, "6b928b7ba6": 96, "6b9c43c25a": 180, "6ba99cc41f": 91, "6bdab62bcd": 86, "6bf2e853b1": 180, "6bf584200f": 180, "6bf95df2b9": 150, "6c0949c51c": 180, "6c11a5f11f": 96, "6c23d89189": 61, "6c4387daf5": 96, "6c4ce479a4": 86, "6c5123e4bc": 96, "6c54265f16": 92, "6c56848429": 96, "6c623fac5f": 36, "6c81b014e9": 96, "6c99ea7c31": 92, "6c9d29d509": 91, "6c9e3b7d1a": 91, "6ca006e283": 96, "6caeb928d6": 180, "6cb2ee722a": 180, "6cbfd32c5e": 180, "6cc791250b": 150, "6cccc985e0": 96, "6d12e30c48": 180, "6d4bf200ad": 180, "6d6d2b8843": 91, "6d6eea5682": 180, "6d7a3d0c21": 96, "6d7efa9b9e": 180, "6da21f5c91": 180, "6da6adabc0": 150, "6dd2827fbb": 96, "6dd36705b9": 131, "6df3637557": 180, "6dfe55e9e5": 150, "6e1a21ba55": 96, "6e2f834767": 180, "6e36e4929a": 96, "6e4f460caf": 96, "6e618d26b6": 56, "6ead4670f7": 180, "6eaff19b9f": 180, "6eb2e1cd9e": 180, "6eb30b3b5a": 96, "6eca26c202": 91, "6ecad29e52": 96, "6ef0b44654": 96, "6efcfe9275": 180, "6f4789045c": 180, "6f49f522ef": 96, "6f67d7c4c4": 180, "6f96e91d81": 144, "6fc6fce380": 180, "6fc9b44c00": 96, "6fce7f3226": 150, "6fdf1ca888": 150, "702fd8b729": 180, "70405185d2": 180, "7053e4f41e": 180, "707bf4ce41": 87, "7082544248": 81, "708535b72a": 96, "7094ac0f60": 180, "70a6b875fa": 180, "70c3e97e41": 180, "7106b020ab": 91, "711dce6fe2": 96, "7136a4453f": 180, "7143fb084f": 180, "714d902095": 150, "7151c53b32": 150, "715357be94": 180, "7163b8085f": 150, "716df1aa59": 150, "71caded286": 150, "71d2665f35": 91, "71d67b9e19": 96, "71e06dda39": 180, "720b398b9c": 
91, "720e3fa04c": 150, "720e7a5f1e": 91, "721bb6f2cb": 91, "722803f4f2": 92, "72552a07c9": 91, "726243a205": 96, "72690ef572": 46, "728cda9b65": 86, "728e81c319": 91, "72a810a799": 180, "72acb8cdf6": 180, "72b01281f9": 180, "72cac683e4": 91, "72cadebbce": 180, "72cae058a5": 180, "72d8dba870": 180, "72e8d1c1ff": 96, "72edc08285": 180, "72f04f1a38": 81, "731b825695": 144, "7320b49b13": 180, "732626383b": 87, "732df1eb05": 150, "73329902ab": 150, "733798921e": 150, "733824d431": 150, "734ea0d7fb": 91, "735a7cf7b9": 144, "7367a42892": 91, "7368d5c053": 180, "738e5a0a14": 180, "73c6ae7711": 96, "73e1852735": 150, "73e4e5cc74": 150, "73eac9156b": 180, "73f8441a88": 91, "7419e2ab3f": 91, "74267f68b9": 91, "7435690c8c": 46, "747c44785c": 81, "747f1b1f2f": 144, "748b2d5c01": 96, "74d4cee0a4": 91, "74ec2b3073": 91, "74ef677020": 96, "750be4c4d8": 96, "75172d4ac8": 96, "75285a7eb1": 180, "75504539c3": 91, "7550949b1d": 96, "7551cbd537": 150, "75595b453d": 91, "7559b4b0ec": 91, "755bd1fbeb": 96, "756f76f74d": 180, "7570ca7f3c": 180, "757a69746e": 180, "757cac96c6": 180, "7584129dc3": 144, "75a058dbcd": 91, "75b09ce005": 96, "75cae39a8f": 180, "75cee6caf0": 180, "75cf58fb2c": 91, "75d5c2f32a": 180, "75eaf5669d": 96, "75f7937438": 180, "75f99bd3b3": 96, "75fa586876": 92, "7613df1f84": 150, "762e1b3487": 96, "76379a3e69": 180, "764271f0f3": 92, "764503c499": 86, "7660005554": 46, "7666351b84": 96, "76693db153": 51, "767856368b": 92, "768671f652": 180, "768802b80d": 180, "76962c7ed2": 71, "76a75f4eee": 150, "76b90809f7": 180, "770a441457": 96, "772a0fa402": 180, "772f2ffc3e": 91, "774f6c2175": 180, "77610860e0": 56, "777e58ff3d": 96, "77920f1708": 150, "7799df28e7": 180, "779e847a9a": 81, "77ba4edc72": 96, "77c834dc43": 41, "77d8aa8691": 180, "77e7f38f4d": 144, "77eea6845e": 96, "7806308f33": 91, "78254660ea": 91, "7828af8bff": 180, "784398620a": 71, "784d201b12": 96, "78613981ed": 180, "78896c6baf": 92, "78aff3ebc0": 150, "78c7c03716": 91, "78d3676361": 91, "78e29dd4c3": 150, 
"78f1a1a54f": 91, "79208585cd": 180, "792218456c": 180, "7923bad550": 150, "794e6fc49f": 96, "796e6762ce": 180, "797cd21f71": 150, "79921b21c2": 150, "79a5778027": 180, "79bc006280": 180, "79bf95e624": 91, "79d9e00c55": 91, "79e20fc008": 96, "79e9db913e": 180, "79f014085e": 91, "79fcbb433a": 150, "7a13a5dfaa": 180, "7a14bc9a36": 96, "7a3c535f70": 96, "7a446a51e9": 91, "7a56e759c5": 91, "7a5f46198d": 86, "7a626ec98d": 92, "7a802264c4": 180, "7a8b5456ca": 180, "7abdff3086": 150, "7aecf9f7ac": 150, "7b0fd09c28": 96, "7b18b3db87": 180, "7b39fe7371": 144, "7b49e03d4c": 180, "7b5388c9f1": 180, "7b5cf7837f": 180, "7b733d31d8": 180, "7b74fd7b98": 180, "7b918ccb8a": 150, "7ba3ce3485": 96, "7bb0abc031": 180, "7bb5bb25cd": 180, "7bb7dac673": 92, "7bc7761b8c": 180, "7bf3820566": 96, "7c03a18ec1": 96, "7c078f211b": 150, "7c37d7991a": 71, "7c4ec17eff": 144, "7c649c2aaf": 180, "7c73340ab7": 91, "7c78a2266d": 180, "7c88ce3c5b": 180, "7ca6843a72": 180, "7cc9258dee": 96, "7cec7296ae": 46, "7d0ffa68a4": 96, "7d11b4450f": 81, "7d1333fcbe": 96, "7d18074fef": 91, "7d18c8c716": 96, "7d508fb027": 180, "7d55f791f0": 180, "7d74e3c2f6": 150, "7d783f67a9": 96, "7d83a5d854": 150, "7dd409947e": 180, "7de45f75e5": 150, "7e0cd25696": 150, "7e1922575c": 96, "7e1e3bbcc1": 180, "7e24023274": 180, "7e2f212fd3": 96, "7e6d1cc1f4": 180, "7e7cdcb284": 144, "7e9b6bef69": 66, "7ea5b49283": 92, "7eb2605d96": 91, "7eb26b8485": 180, "7ecd1f0c69": 96, "7f02b3cfe2": 180, "7f1723f0d5": 97, "7f21063c3a": 81, "7f3658460e": 91, "7f54132e48": 144, "7f559f9d4a": 144, "7f5faedf8b": 96, "7f838baf2b": 180, "7fa5f527e3": 96, "7ff84d66dd": 150, "802b45c8c4": 180, "804382b1ad": 180, "804c558adb": 96, "804f6338a4": 180, "8056117b89": 150, "806b6223ab": 96, "8088bda461": 46, "80b790703b": 180, "80c4a94706": 96, "80ce2e351b": 180, "80db581acd": 96, "80e12193df": 150, "80e41b608f": 180, "80f16b016d": 91, "81541b3725": 91, "8175486e6a": 96, "8179095000": 180, "8193671178": 180, "81a58d2c6b": 150, "81aa1286fb": 96, "81dffd30fb": 
96, "8200245704": 41, "823e7a86e8": 46, "824973babb": 144, "824ca5538f": 180, "827171a845": 180, "8273a03530": 180, "827cf4f886": 91, "82b865c7dd": 180, "82c1517708": 91, "82d15514d6": 150, "82e117b900": 179, "82fec06574": 150, "832b5ef379": 97, "83424c9fbf": 180, "8345358fb8": 71, "834b50b31b": 180, "835e3b67d7": 97, "836ea92b15": 90, "837c618777": 144, "838eb3bd89": 180, "839381063f": 91, "839bc71489": 180, "83a8151377": 180, "83ae88d217": 180, "83ca8bcad0": 180, "83ce590d7f": 180, "83d3130ba0": 36, "83d40bcba5": 86, "83daba503a": 144, "83de906ec0": 180, "84044f37f3": 180, "84696b5a5e": 96, "84752191a3": 91, "847eeeb2e0": 180, "848e7835a0": 96, "84a4b29286": 180, "84a4bf147d": 66, "84be115c09": 144, "84d95c4350": 180, "84e0922cf7": 150, "84f0cfc665": 96, "8515f6db22": 180, "851f2f32c1": 91, "852a4d6067": 150, "854c48b02a": 96, "857a387c86": 180, "859633d56a": 96, "85a4f4a639": 144, "85ab85510c": 180, "85b1eda0d9": 92, "85dc1041c6": 96, "85e081f3c7": 150, "85f75187ad": 96, "8604bb2b75": 96, "860745b042": 150, "863b4049d7": 180, "8643de22d0": 180, "8647d06439": 46, "864ffce4fe": 180, "8662d9441a": 180, "8666521b13": 76, "868d6a0685": 91, "869fa45998": 91, "86a40b655d": 150, "86a8ae4223": 92, "86b2180703": 180, "86c85d27df": 180, "86d3755680": 180, "86e61829a1": 180, "871015806c": 91, "871e409c5c": 180, "8744b861ce": 96, "8749369ba0": 180, "878a299541": 144, "8792c193a0": 96, "8799ab0118": 96, "87d1f7d741": 180, "882b9e4500": 180, "885673ea17": 180, "8859dedf41": 96, "8873ab2806": 91, "887a93b198": 180, "8883e991a9": 86, "8891aa6dfa": 91, "8899d8cbcd": 91, "88b8274d67": 180, "88d3b80af6": 91, "88ede83da2": 180, "88f345941b": 180, "890976d6da": 91, "8909bde9ab": 91, "8929c7d5d9": 180, "89363acf76": 150, "89379487e0": 96, "8939db6354": 180, "893f658345": 144, "8953138465": 180, "895c96d671": 180, "895cbf96f9": 180, "895e8b29a7": 91, "898fa256c8": 180, "89986c60be": 180, "89b874547b": 180, "89bdb021d5": 144, "89c802ff9c": 96, "89d6336c2b": 180, "89ebb27334": 91, 
"8a27e2407c": 96, "8a31f7bca5": 96, "8a4a2fc105": 96, "8a5d6c619c": 96, "8a75ad7924": 180, "8aa817e4ed": 87, "8aad0591eb": 180, "8aca214360": 180, "8ae168c71b": 96, "8b0cfbab97": 21, "8b3645d826": 96, "8b3805dbd4": 180, "8b473f0f5d": 180, "8b4f6d1186": 180, "8b4fb018b7": 66, "8b518ee936": 92, "8b523bdfd6": 150, "8b52fb5fba": 91, "8b91036e5c": 144, "8b99a77ac5": 180, "8ba04b1e7b": 96, "8ba782192f": 180, "8bbeaad78b": 96, "8bd1b45776": 180, "8bd7a2dda6": 150, "8bdb091ccf": 180, "8be56f165d": 96, "8be950d00f": 96, "8bf84e7d45": 180, "8bffc4374b": 66, "8bfff50747": 180, "8c09867481": 144, "8c0a3251c3": 180, "8c3015cccb": 180, "8c469815cf": 96, "8c9ccfedc7": 91, "8ca1af9f3c": 150, "8ca3f6e6c1": 96, "8ca6a4f60f": 96, "8cac6900fe": 96, "8cba221a1e": 180, "8cbbe62ccd": 180, "8d064b29e2": 92, "8d167e7c08": 91, "8d4ab94e1c": 96, "8d81f6f899": 180, "8d87897d66": 91, "8dcccd2bd2": 180, "8dcfb878a8": 150, "8dd3ab71b9": 91, "8dda6bf10f": 96, "8ddd51ca94": 180, "8dea22c533": 180, "8def5bd3bf": 96, "8e1848197c": 91, "8e3a83cf2d": 91, "8e478e73f3": 91, "8e98ae3c84": 96, "8ea6687ab0": 180, "8eb0d315c1": 91, "8ec10891f9": 150, "8ec3065ec2": 180, "8ecf51a971": 150, "8eddbab9f7": 91, "8ee198467a": 180, "8ee2368f40": 180, "8ef595ce82": 150, "8f0a653ad7": 150, "8f1204a732": 150, "8f1600f7f6": 91, "8f16366707": 96, "8f1ce0a411": 92, "8f2e05e814": 91, "8f320d0e09": 96, "8f3b4a84ad": 91, "8f3fdad3da": 96, "8f5d3622d8": 96, "8f62a2c633": 180, "8f81c9405a": 97, "8f8c974d53": 120, "8f918598b6": 96, "8ff61619f6": 96, "9002761b41": 96, "90107941f3": 92, "90118a42ee": 96, "902bc16b37": 91, "903e87e0d6": 144, "9041a0f489": 96, "9047bf3222": 51, "9057bfa502": 150, "90617b0954": 92, "9076f4b6db": 180, "9077e69b08": 144, "909655b4a6": 96, "909c2eca88": 180, "909dbd1b76": 180, "90bc4a319a": 180, "90c7a87887": 96, "90cc785ddd": 96, "90d300f09b": 180, "9101ea9b1b": 96, "9108130458": 150, "911ac9979b": 150, "9151cad9b5": 97, "9153762797": 180, "91634ee0c9": 91, "916942666f": 76, "9198cfb4ea": 180, 
"919ac864d6": 180, "91b67d58d4": 180, "91bb8df281": 150, "91be106477": 91, "91c33b4290": 180, "91ca7dd9f3": 144, "91d095f869": 180, "91f107082e": 180, "920329dd5e": 180, "920c959958": 150, "92128fbf4b": 144, "9223dacb40": 150, "923137bb7f": 61, "9268e1f88a": 180, "927647fe08": 150, "9276f5ba47": 150, "92a28cd233": 71, "92b5c1fc6d": 144, "92c46be756": 180, "92dabbe3a0": 96, "92e3159361": 180, "92ebab216a": 180, "934bdc2893": 180, "9359174efc": 180, "935d97dd2f": 91, "935feaba1b": 96, "93901858ee": 150, "939378f6d6": 91, "939bdf742e": 96, "93a22bee7e": 96, "93da9aeddf": 91, "93e2feacce": 180, "93e6f1fdf9": 96, "93e811e393": 180, "93e85d8fd3": 180, "93f623d716": 180, "93ff35e801": 46, "94031f12f2": 96, "94091a4873": 180, "94125907e3": 87, "9418653742": 91, "941c870569": 101, "94209c86f0": 180, "9437c715eb": 76, "9445c3eca2": 91, "9467c8617c": 96, "946d71fb5d": 96, "948f3ae6fb": 180, "9498baa359": 96, "94a33abeab": 91, "94bf1af5e3": 144, "94cf3a8025": 96, "94db712ac8": 180, "94e4b66cff": 92, "94e76cbaf6": 180, "950be91db1": 180, "952058e2d0": 92, "952633c37f": 96, "952ec313fe": 87, "9533fc037c": 96, "9574b81269": 92, "9579b73761": 180, "957f7bc48b": 180, "958073d2b0": 150, "9582e0eb33": 71, "9584092d0b": 91, "95b58b8004": 150, "95bd88da55": 180, "95f74a9959": 180, "962781c601": 180, "962f045bf5": 91, "964ad23b44": 91, "967b90590e": 144, "967bffe201": 86, "96825c4714": 81, "968492136a": 96, "9684ef9d64": 86, "968c41829e": 91, "96a856ef9a": 180, "96dfc49961": 180, "96e1a5b4f8": 180, "96e6ff0917": 150, "96fb88e9d7": 96, "96fbe5fc23": 150, "96fc924050": 96, "9715cc83dc": 180, "9720eff40f": 180, "972c187c0d": 180, "97476eb38d": 180, "97659ed431": 180, "9773492949": 96, "97756b264f": 96, "977bff0d10": 96, "97ab569ff3": 96, "97ba838008": 180, "97d9d008c7": 150, "97e59f09fa": 96, "97eb642e56": 96, "98043e2d14": 96, "981ff580cf": 180, "983e66cbfc": 96, "984f0f1c36": 180, "98595f2bb4": 91, "985c3be474": 91, "9869a12362": 180, "986b5a5e18": 180, "9877af5063": 180, "98911292da": 
180, "9893a3cf77": 97, "9893d9202d": 91, "98a8b06e7f": 91, "98ac6f93d9": 150, "98b6974d12": 96, "98ba3c9417": 180, "98c7c00a19": 96, "98d044f206": 96, "98e909f9d1": 150, "98fe7f0410": 150, "990f2742c7": 96, "992bd0779a": 180, "994b9b47ba": 150, "9955b76bf5": 91, "9966f3adac": 46, "997117a654": 180, "999d53d841": 150, "99c04108d3": 180, "99c4277aee": 96, "99c6b1acf2": 96, "99dc8bb20b": 180, "99fcba71e5": 150, "99fecd4efb": 92, "9a02c70ba2": 96, "9a08e7a6f8": 180, "9a2f2c0f86": 81, "9a3254a76e": 92, "9a3570a020": 180, "9a39112493": 180, "9a4e9fd399": 180, "9a50af4bfb": 180, "9a68631d24": 150, "9a72318dbf": 92, "9a767493b7": 180, "9a7fc1548b": 96, "9a84ccf6a7": 150, "9a9c0e15b7": 96, "9adf06d89b": 150, "9b22b54ee4": 91, "9b473fc8fe": 96, "9b4f081782": 180, "9b997664ba": 180, "9bc454e109": 180, "9bccfd04de": 96, "9bce4583a2": 96, "9bebf1b87f": 158, "9bfc50d261": 180, "9c166c86ff": 96, "9c293ef4d7": 144, "9c29c047b0": 91, "9c3bc2e2a7": 96, "9c3ce23bd1": 91, "9c404cac0c": 180, "9c5180d23a": 144, "9c7feca6e4": 144, "9caa49d3ff": 180, "9cb2f1b646": 180, "9ce6f765c3": 91, "9cfee34031": 180, "9d01f08ec6": 180, "9d04c280b8": 91, "9d12ceaddc": 180, "9d15f8cb3c": 180, "9d2101e9bf": 180, "9d407c3aeb": 96, "9ddefc6165": 180, "9df0b1e298": 96, "9e16f115d8": 144, "9e249b4982": 96, "9e29b1982c": 92, "9e493e4773": 180, "9e4c752cd0": 91, "9e4de40671": 96, "9e6319faeb": 96, "9e6ddbb52d": 91, "9eadcea74f": 180, "9ecec5f8ea": 46, "9efb47b595": 96, "9f30bfe61e": 72, "9f3734c3a4": 180, "9f5b858101": 180, "9f66640cda": 180, "9f913803e9": 180, "9f97bc74c8": 180, "9fbad86e20": 180, "9fc2bad316": 180, "9fc5c3af78": 150, "9fcb310255": 92, "9fcc256871": 91, "9fd2fd4d47": 180, "a0071ae316": 96, "a023141022": 56, "a046399a74": 96, "a066e739c1": 150, "a06722ba82": 96, "a07a15dd64": 180, "a07b47f694": 180, "a09c39472e": 144, "a0b208fe2e": 91, "a0b61c959e": 96, "a0bc6c611d": 180, "a0e6da5ba2": 91, "a1193d6490": 96, "a14ef483ff": 91, "a14f709908": 180, "a15ccc5658": 96, "a16062456f": 180, 
"a174e8d989": 91, "a177c2733c": 150, "a17c62e764": 92, "a18ad065fc": 150, "a1aaf63216": 96, "a1bb65fb91": 150, "a1bd8e5349": 91, "a1dfdd0cac": 180, "a2052e4f6c": 96, "a20fd34693": 96, "a21ffe4d81": 150, "a22349e647": 180, "a235d01ec1": 180, "a24f63e8a2": 180, "a2554c9f6d": 46, "a263ce8a87": 180, "a29bfc29ec": 91, "a2a80072d4": 150, "a2a800ab63": 180, "a2bcd10a33": 180, "a2bdaff3b0": 91, "a2c146ab0d": 91, "a2c996e429": 96, "a2dc51ebe8": 180, "a2e6608bfa": 180, "a2f2a55f01": 96, "a301869dea": 180, "a31fccd2cc": 180, "a34f440f33": 180, "a35e0206da": 180, "a36bdc4cab": 180, "a36e8c79d8": 71, "a378053b20": 144, "a37db3a2b3": 91, "a38950ebc2": 180, "a39a0eb433": 91, "a39c9bca52": 180, "a3a945dc8c": 91, "a3b40a0c1e": 150, "a3b8588550": 91, "a3c502bec3": 180, "a3f2878017": 180, "a3f4d58010": 180, "a3f51855c3": 150, "a402dc0dfe": 21, "a4065a7eda": 180, "a412bb2fef": 180, "a416b56b53": 96, "a41ec95906": 91, "a43299e362": 180, "a4757bd7af": 96, "a48c53c454": 180, "a49dcf9ad5": 150, "a4a506521f": 180, "a4ba7753d9": 180, "a4bac06849": 91, "a4f05d681c": 91, "a50c10060f": 150, "a50eb5a0ea": 150, "a5122c6ec6": 150, "a522b1aa79": 96, "a590915345": 180, "a5b5b59139": 96, "a5b77abe43": 180, "a5c2b2c3e1": 96, "a5cd17bb11": 180, "a5da03aef1": 180, "a5dd11de0d": 150, "a5ea2b93b6": 150, "a5eaeac80b": 180, "a5ec5b0265": 144, "a5f350a87e": 180, "a5f472caf4": 96, "a6027a53cf": 180, "a61715bb1b": 180, "a61cf4389d": 150, "a61d9bbd9b": 180, "a6470dbbf5": 150, "a64a40f3eb": 76, "a653d5c23b": 180, "a65bd23cb5": 150, "a66e0b7ad4": 180, "a66fc5053c": 91, "a68259572b": 180, "a6a810a92c": 150, "a6bc36937f": 91, "a6c3a374e9": 180, "a6d8a4228d": 180, "a6f4e0817f": 180, "a71e0481f5": 96, "a7203deb2d": 150, "a7392d4438": 150, "a73d3c3902": 180, "a7491f1578": 150, "a74b9ca19c": 180, "a77b7a91df": 150, "a78195a5f5": 150, "a78758d4ce": 180, "a7e6d6c29a": 96, "a800d85e88": 51, "a832fa8790": 180, "a83d06410d": 150, "a8999af004": 180, "a8f78125b9": 180, "a907b18df1": 150, "a919392446": 150, "a965504e88": 96, 
"a96b84b8d2": 96, "a973f239cd": 91, "a977126596": 180, "a9804f2a08": 91, "a984e56893": 96, "a99738f24c": 91, "a99bdd0079": 144, "a9c9c1517e": 178, "a9cbf9c41b": 150, "a9e42e3c0c": 150, "aa07b7c1c0": 180, "aa175e5ec7": 96, "aa1a338630": 96, "aa27d7b868": 96, "aa45f1caaf": 91, "aa49e46432": 96, "aa51934e1b": 180, "aa6287bb6c": 96, "aa6d999971": 180, "aa85278334": 96, "aab33f0e2a": 180, "aaba004362": 180, "aade4cf385": 180, "aae78feda4": 91, "aaed233bf3": 180, "aaff16c2db": 96, "ab199e8dfb": 96, "ab23b78715": 96, "ab2e1b5577": 180, "ab33a18ded": 96, "ab45078265": 180, "ab56201494": 180, "ab90f0d24b": 180, "abab2e6c20": 180, "abb50c8697": 92, "abbe2d15a0": 180, "abbe73cd21": 150, "abe61a11bb": 180, "abeae8ce21": 150, "ac2b431d5f": 150, "ac2cb1b9eb": 150, "ac31fcd6d0": 91, "ac3d3a126d": 180, "ac46bd8087": 180, "ac783ef388": 180, "acb73e4297": 150, "acbf581760": 180, "accafc3531": 96, "acf2c4b745": 96, "acf44293a2": 96, "acf736a27b": 90, "acff336758": 180, "ad1fe56886": 92, "ad28f9b9d9": 91, "ad2de9f80e": 180, "ad397527b2": 97, "ad3d1cfbcb": 86, "ad3fada9d9": 180, "ad4108ee8e": 180, "ad54468654": 66, "ad573f7d31": 96, "ad6255bc29": 180, "ad65ebaa07": 144, "ad97cc064a": 96, "adabbd1cc4": 180, "adb0b5a270": 180, "adc648f890": 150, "add21ee467": 180, "adfd15ceef": 180, "adfdd52eac": 96, "ae01cdab63": 180, "ae0b50ff4f": 96, "ae13ee3d70": 180, "ae1bcbd423": 180, "ae20d09dea": 180, "ae2cecf5f6": 56, "ae3bc4a0ef": 180, "ae499c7514": 92, "ae628f2cd4": 150, "ae8545d581": 86, "ae93214fe6": 150, "ae9cd16dbf": 46, "aeba9ac967": 180, "aebb242b5c": 150, "aed4e0b4c4": 86, "aedd71f125": 180, "aef3e2cb0e": 180, "af0b54cee3": 96, "af3de54c7a": 180, "af5fd24a36": 150, "af8826d084": 91, "af8ad72057": 180, "afb71e22c5": 92, "afcb331e1f": 96, "afe1a35c1e": 150, "b01080b5d3": 180, "b05ad0d345": 96, "b0623a6232": 91, "b064dbd4b7": 96, "b06ed37831": 96, "b06f5888e6": 92, "b08dcc490e": 91, "b0a68228dc": 92, "b0aece727f": 144, "b0b0731606": 96, "b0c7f11f9f": 180, "b0cca8b830": 180, "b0dd580a89": 
180, "b0de66ca08": 180, "b0df7c5c5c": 96, "b0f5295608": 96, "b11099eb09": 180, "b132a53086": 91, "b1399fac64": 180, "b13abc0c69": 96, "b1457e3b5e": 180, "b15bf4453b": 91, "b179c4a82d": 96, "b17ee70e8c": 180, "b190b1aa65": 96, "b19b3e22c0": 180, "b19c561fab": 180, "b1d1cd2e6e": 92, "b1d7c03927": 91, "b1d7fe2753": 180, "b1f540a4bd": 96, "b1fc9c64e1": 96, "b1fcbb3ced": 180, "b220939e93": 96, "b22099b419": 180, "b241e95235": 96, "b2432ae86d": 180, "b2456267df": 180, "b247940d01": 150, "b24af1c35c": 180, "b24f600420": 97, "b24fe36b2a": 150, "b258fb0b7d": 180, "b26b219919": 96, "b26d9904de": 96, "b274456ce1": 180, "b27b28d581": 72, "b2a26bc912": 180, "b2a9c51e1b": 180, "b2b0baf470": 180, "b2b2756fe7": 96, "b2ce7699e3": 180, "b2edc76bd2": 150, "b2f6b52100": 180, "b30bf47bcd": 180, "b34105a4e9": 91, "b372a82edf": 150, "b3779a1962": 96, "b379ab4ff5": 46, "b37a1d69e3": 150, "b37c01396e": 180, "b382b09e25": 150, "b3996e4ba5": 180, "b3d9ca2aee": 180, "b3dde1e1e9": 180, "b3eb7f05eb": 86, "b40b25055c": 91, "b41e0f1f19": 91, "b44e32a42b": 91, "b4805ae9cd": 46, "b4807569a5": 97, "b48efceb3e": 150, "b493c25c7f": 180, "b4b565aba1": 150, "b4b715a15b": 180, "b4d0c90bf4": 91, "b4d84bc371": 180, "b4e5ad97aa": 180, "b4eaea9e6b": 150, "b50f4b90d5": 180, "b53f675641": 150, "b54278cd43": 180, "b554843889": 150, "b573c0677a": 180, "b58d853734": 180, "b5943b18ab": 180, "b5a09a83f3": 71, "b5aae1fe25": 91, "b5b9da5364": 97, "b5eb64d419": 91, "b5ebb1d000": 96, "b5f1c0c96a": 96, "b5f7fece90": 180, "b6070de1bb": 180, "b60a76fe73": 86, "b61f998772": 96, "b62c943664": 96, "b63094ba0c": 180, "b64fca8100": 96, "b673e7dcfb": 96, "b678b7db00": 180, "b68fc1b217": 180, "b69926d9fa": 96, "b6a1df3764": 180, "b6a4859528": 96, "b6b4738b78": 96, "b6b4f847b7": 150, "b6b8d502d4": 150, "b6bb00e366": 180, "b6d65a9eef": 180, "b6d79a0845": 180, "b6e9ec577f": 91, "b6ec609f7b": 163, "b6f92a308d": 180, "b70a2c0ab1": 46, "b70a5a0d50": 180, "b70c052f2f": 150, "b70d231781": 92, "b72ac6e10b": 180, "b7302d8226": 92, 
"b73867d769": 150, "b751e767f2": 180, "b76df6e059": 96, "b77e5eddef": 92, "b7a2c2c83c": 96, "b7bcbe6466": 180, "b7c2a469c4": 180, "b7d69da8f0": 144, "b7f31b7c36": 61, "b7f675fb98": 46, "b7fb871660": 51, "b82e5ad1c9": 91, "b841cfb932": 96, "b84b8ae665": 180, "b85b78ac2b": 180, "b86c17caa6": 180, "b86e50d82d": 96, "b871db031a": 66, "b87d56925a": 96, "b8aaa59b75": 92, "b8c03d1091": 180, "b8c3210036": 46, "b8e16df00b": 144, "b8f34cf72e": 91, "b8fb75864e": 150, "b9004db86c": 180, "b9166cbae9": 92, "b920b256a6": 180, "b938d79dff": 20, "b93963f214": 180, "b941aef1a0": 144, "b94d34d14e": 96, "b964c57da4": 96, "b96a95bc7a": 180, "b96c57d2c7": 144, "b9b6bdde0c": 180, "b9bcb3e0f2": 96, "b9d3b92169": 180, "b9dd4b306c": 180, "b9f43ef41e": 92, "ba1f03c811": 96, "ba3a775d7b": 180, "ba3c7f2a31": 150, "ba3fcd417d": 180, "ba5e1f4faa": 150, "ba795f3089": 96, "ba8a291e6a": 150, "ba98512f97": 92, "bac9db04f5": 180, "baedae3442": 180, "baff40d29d": 180, "bb04e28695": 96, "bb1b0ee89f": 96, "bb1c770fe7": 150, "bb1fc34f99": 150, "bb2d220506": 180, "bb334e5cdb": 91, "bb337f9830": 81, "bb721eb9aa": 96, "bb87ff58bd": 96, "bb89a6b18a": 87, "bbaa9a036a": 144, "bbb4302dda": 180, "bbd31510cf": 96, "bbe0256a75": 180, "bc141b9ad5": 91, "bc17ab8a99": 150, "bc318160de": 180, "bc3b9ee033": 91, "bc4240b43c": 96, "bc4ce49105": 91, "bc4f71372d": 96, "bc6b8d6371": 180, "bcaad44ad7": 150, "bcc241b081": 91, "bcc5d8095e": 96, "bcd1d39afb": 96, "bd0d849da4": 180, "bd0e9ed437": 150, "bd2c94730f": 180, "bd321d2be6": 61, "bd3ec46511": 91, "bd5b2e2848": 41, "bd7e02b139": 96, "bd96f9943a": 180, "bda224cb25": 91, "bda4a82837": 96, "bdb74e333f": 180, "bdccd69dde": 96, "bddcc15521": 180, "be116aab29": 150, "be15e18f1e": 150, "be1a284edb": 180, "be2a367a7b": 180, "be376082d0": 150, "be3e3cffbd": 51, "be5d1d89a0": 180, "be8b72fe37": 180, "be9b29e08e": 91, "bea1f6e62c": 97, "bea83281b5": 92, "beb921a4c9": 96, "bec5e9edcd": 180, "beeb8a3f92": 150, "bf2232b58d": 96, "bf28751739": 150, "bf443804e8": 180, "bf461df850": 150, 
"bf5374f122": 180, "bf551a6f60": 180, "bf8d0f5ada": 96, "bf961167a6": 92, "bfab1ad8f9": 150, "bfcb05d88d": 96, "bfd8f6e6c9": 92, "bfd91d0742": 150, "bfe262322f": 87, "c013f42ed7": 180, "c01878083f": 180, "c01faff1ed": 180, "c046fd0edb": 150, "c053e35f97": 91, "c079a6482d": 96, "c0847b521a": 96, "c0a1e06710": 180, "c0e8d4635c": 96, "c0e973ad85": 96, "c0f49c6579": 92, "c0f5b222d7": 96, "c10d07c90d": 180, "c1268d998c": 96, "c130c3fc0c": 180, "c14826ad5e": 180, "c15b922281": 180, "c16f09cb63": 180, "c18e19d922": 180, "c1c830a735": 96, "c1e8aeea45": 180, "c20a5ccc99": 180, "c20fd5e597": 180, "c219d6f8dc": 150, "c2406ae462": 96, "c26f7b5824": 180, "c279e641ee": 96, "c27adaeac5": 180, "c2a35c1cda": 96, "c2a9903b8b": 180, "c2b62567c1": 96, "c2b974ec8c": 150, "c2baaff7bf": 91, "c2be6900f2": 180, "c304dd44d5": 180, "c307f33da2": 96, "c30a7b62c9": 92, "c3128733ee": 180, "c31fa6c598": 180, "c325c8201e": 96, "c32d4aa5d1": 180, "c33f28249a": 144, "c34365e2d7": 180, "c3457af795": 96, "c34d120a88": 180, "c3509e728d": 96, "c35e4fa6c4": 180, "c36240d96f": 150, "c3641dfc5a": 92, "c37b17a4a9": 180, "c39559ddf6": 180, "c3b0c6e180": 96, "c3b3d82e6c": 180, "c3be369fdb": 91, "c3bf1e40c2": 97, "c3c760b015": 96, "c3dd38bf98": 150, "c3e4274614": 91, "c3edc48cbd": 180, "c41e6587f5": 96, "c4272227b0": 96, "c42917fe82": 86, "c438858117": 180, "c44676563f": 180, "c44beb7472": 180, "c45411dacb": 91, "c4571bedc8": 91, "c46deb2956": 180, "c479ee052e": 180, "c47d551843": 180, "c49f07d46d": 180, "c4cc40c1fc": 97, "c4f256f5d5": 144, "c4f5b1ddcc": 180, "c4ff9b4885": 150, "c52bce43db": 66, "c544da6854": 180, "c55784c766": 180, "c557b69fbf": 180, "c593a3f7ab": 92, "c598faa682": 180, "c5ab1f09c8": 180, "c5b6da8602": 96, "c5b9128d94": 96, "c5e845c6b7": 150, "c5fba7b341": 150, "c60897f093": 96, "c61fe6ed7c": 96, "c62188c536": 96, "c64035b2e2": 150, "c69689f177": 180, "c6a12c131f": 51, "c6bb6d2d5c": 180, "c6c18e860f": 150, "c6d9526e0d": 180, "c6e55c33f0": 96, "c7030b28bd": 96, "c70682c7cc": 180, 
"c70f9be8c5": 87, "c71f30d7b6": 180, "c73c8e747f": 180, "c760eeb8b3": 144, "c7637cab0a": 150, "c7a1a17308": 87, "c7bf937af5": 91, "c7c2860db3": 180, "c7cef4aee2": 91, "c7ebfc5d57": 180, "c813dcf13c": 91, "c82235a49a": 96, "c82a7619a1": 180, "c82ecb90cb": 180, "c844f03dc7": 96, "c8557963f3": 91, "c89147e6e8": 180, "c8a46ff0c8": 150, "c8ab107dd5": 97, "c8b869a04a": 96, "c8c7b306a6": 91, "c8c8b28781": 180, "c8d79e3163": 180, "c8edab0415": 150, "c8f494f416": 96, "c8f6cba9fd": 150, "c909ceea97": 92, "c9188f4980": 180, "c922365dd4": 96, "c92c8c3c75": 96, "c937eb0b83": 91, "c94b31b5e5": 180, "c95cd17749": 180, "c96379c03c": 180, "c96465ee65": 180, "c965afa713": 144, "c9734b451f": 92, "c9862d82dc": 180, "c98b6fe013": 180, "c9999b7c48": 180, "c99e92aaf0": 97, "c9b3a8fbda": 150, "c9bf64e965": 96, "c9c3cb3797": 91, "c9d1c60cd0": 144, "c9de9c22c4": 96, "ca1828fa54": 96, "ca346f17eb": 180, "ca3787d3d3": 150, "ca4b99cbac": 96, "ca91c69e3b": 71, "ca91e99105": 46, "caa8e97f81": 96, "caac5807f8": 180, "cabba242c2": 96, "cad5a656a9": 180, "cad673e375": 180, "cad8a85930": 150, "cae7b0a02b": 180, "cae7ef3184": 180, "caeb6b6cbb": 150, "caecf0a5db": 91, "cb15312003": 76, "cb2e35d610": 150, "cb35a87504": 150, "cb3f22b0cf": 96, "cbb410da64": 91, "cc8728052e": 150, "cc892997b8": 180, "cce03c2a9b": 144, "cd47a23e31": 92, "cd4dc03dc0": 180, "cd5ae611da": 96, "cd603bb9d1": 144, "cd8f49734c": 180, "cdc6b1c032": 92, "cdcfe008ad": 144, "cdd57027c2": 96, "ce1af99b4b": 150, "ce1bc5743a": 150, "ce25872021": 97, "ce2776f78f": 180, "ce49b1f474": 180, "ce4f0a266f": 180, "ce5641b195": 180, "ce6866aa19": 180, "ce712ed3c9": 91, "ce7d1c8117": 144, "ce7dbeaa88": 180, "ce9b015a5e": 180, "cea7697b25": 96, "cebbd826cf": 150, "cec3415361": 150, "cec41ad4f4": 180, "ced49d26df": 180, "ced7705ab2": 144, "cef824a1e1": 92, "cf13f5c95a": 144, "cf4376a52d": 180, "cf85ab28b5": 180, "cfc2e50b9d": 150, "cfcd571fff": 144, "cfd9d4ae47": 180, "cfda2dcce5": 150, "cff035928b": 91, "cff8191891": 46, "d01608c2a5": 96, 
"d01a8f1f83": 144, "d021d68bca": 180, "d04258ca14": 150, "d0483573dc": 150, "d04a90aaff": 180, "d05279c0bd": 180, "d0696bd5fc": 91, "d072fda75b": 178, "d0a83bcd9f": 150, "d0ab39112e": 180, "d0acde820f": 96, "d0b4442c71": 144, "d0c65e9e95": 180, "d0fb600c73": 150, "d107a1457c": 61, "d123d674c1": 66, "d14d1e9289": 96, "d154e3388e": 96, "d177e9878a": 96, "d1802f69f8": 150, "d182c4483a": 180, "d195d31128": 180, "d200838929": 180, "d205e3cff5": 180, "d247420c4c": 180, "d2484bff33": 66, "d26f6ed9b0": 150, "d280fcd1cb": 180, "d2857f0faa": 180, "d292a50c7f": 46, "d295ea2dc7": 96, "d2a58b4fa6": 91, "d2b026739a": 150, "d2ebe0890f": 180, "d2ede5d862": 91, "d301ca58cc": 150, "d3069da8bb": 91, "d343d4a77d": 150, "d355e634ef": 86, "d367fb5253": 91, "d36d16358e": 76, "d38bc77e2c": 101, "d38d1679e2": 144, "d3932ad4bd": 97, "d3987b2930": 180, "d39934abe3": 144, "d3ae1c3f4c": 92, "d3b088e593": 87, "d3e6e05e16": 150, "d3eefae7c5": 144, "d3f55f5ab8": 180, "d3f5c309cc": 61, "d4034a7fdf": 180, "d4193011f3": 144, "d429c67630": 180, "d42c0ff975": 180, "d44a764409": 180, "d44e6acd1d": 66, "d45158c175": 150, "d454e8444f": 150, "d45f62717e": 180, "d48ebdcf74": 180, "d49ab52a25": 86, "d4a607ad81": 92, "d4b063c7db": 144, "d4da13e9ba": 96, "d4dd1a7d00": 180, "d4f4f7c9c3": 96, "d521aba02e": 180, "d535bb1b97": 92, "d53b955f78": 96, "d55cb7a205": 92, "d55f247a45": 150, "d5695544d8": 180, "d5853d9b8b": 180, "d5b6c6d94a": 96, "d5cae12834": 150, "d5df027f0c": 144, "d5ee40e5d0": 180, "d600046f73": 144, "d632fd3510": 144, "d6476cad55": 180, "d65a7bae86": 150, "d664c89912": 150, "d689658f06": 180, "d6917db4be": 96, "d69967143e": 96, "d699d3d798": 91, "d69f757a3f": 180, "d6ac0e065c": 91, "d6c02bfda5": 96, "d6c1b5749e": 92, "d6e12ef6cc": 92, "d6eed152c4": 180, "d6faaaf726": 96, "d704766646": 180, "d708e1350c": 180, "d7135cf104": 180, "d7157a9f44": 46, "d719cf9316": 96, "d724134cfd": 144, "d73a60a244": 180, "d7411662da": 144, "d74875ea7c": 96, "d756f5a694": 91, "d7572b7d8a": 180, "d763bd6d96": 180, 
"d7697c8b13": 96, "d7797196b4": 150, "d79c834768": 180, "d7b34e5d73": 91, "d7bb6b37a7": 150, "d7c7e064a6": 180, "d7fbf545b3": 96, "d82a0aa15b": 180, "d847e24abd": 144, "d8596701b7": 144, "d86101499c": 144, "d87069ba86": 150, "d87160957b": 144, "d874654b52": 91, "d88a403092": 96, "d8aee40f3f": 144, "d8e77a222d": 91, "d8eb07c381": 180, "d9010348a1": 66, "d90e3cf281": 91, "d92532c7b2": 180, "d927fae122": 150, "d95707bca8": 91, "d973b31c00": 144, "d991cb471d": 180, "d992c69d37": 150, "d99d770820": 180, "d9b63abc11": 180, "d9db6f1983": 144, "d9e52be2d2": 96, "d9edc82650": 150, "da01070697": 96, "da070ea4b7": 180, "da080507b9": 150, "da0e944cc4": 180, "da28d94ff4": 96, "da5d78b9d1": 180, "da6003fc72": 150, "da690fee9f": 180, "da6c68708f": 180, "da7a816676": 144, "dac361e828": 180, "dac71659b8": 144, "dad980385d": 96, "daebc12b77": 150, "db0968cdd3": 150, "db231a7100": 92, "db59282ace": 91, "db7f267c3f": 180, "dba35b87fd": 96, "dbba735a50": 86, "dbca076acd": 180, "dbd66dc3ac": 180, "dbdc3c292b": 180, "dbf4a5b32b": 180, "dbfc417d28": 180, "dc1745e0a2": 91, "dc32a44804": 180, "dc34b35e30": 150, "dc504a4f79": 92, "dc704dd647": 180, "dc71bc6918": 92, "dc7771b3be": 180, "dcf8c93617": 96, "dd0f4c9fb9": 180, "dd415df125": 120, "dd601f9a3f": 144, "dd61d903df": 150, "dd77583736": 150, "dd8636bd8b": 180, "dd9fe6c6ac": 92, "ddb2da4c14": 180, "ddcd450d47": 144, "dde8e67fb4": 76, "ddfc3f04d3": 150, "de2ab79dfa": 180, "de2f35b2fd": 91, "de30990a51": 180, "de36b216da": 96, "de37403340": 180, "de46e4943b": 96, "de4ddbccb1": 180, "de5e480f05": 96, "de6a9382ca": 96, "de74a601d3": 180, "de827c510d": 92, "ded6069f7b": 180, "defb71c741": 96, "df01f277f1": 180, "df05214b82": 92, "df0638b0a0": 46, "df11931ffe": 180, "df1b0e4620": 180, "df20a8650d": 92, "df2bc56d7c": 180, "df365282c6": 180, "df39a0d9df": 96, "df3c430c24": 91, "df5536cfb9": 180, "df59cfd91d": 97, "df5e2152b3": 66, "df741313c9": 96, "df7626172f": 92, "df8ad5deb9": 180, "df96aa609a": 180, "df9705605c": 180, "df9c91c4da": 180, 
"dfc0d3d27a": 180, "dfdbf91a99": 180, "e00baaae9b": 180, "e0a938c6e7": 91, "e0b2ceee6f": 150, "e0bdb5dfae": 36, "e0be1f6e17": 96, "e0c478f775": 150, "e0de82caa7": 180, "e0f217dd59": 91, "e0f7208874": 180, "e0fb58395e": 180, "e1194c2e9d": 150, "e11adcd05d": 180, "e128124b9d": 87, "e1495354e4": 180, "e1561d6d4b": 180, "e158805399": 91, "e16945b951": 46, "e19edcd34b": 180, "e1a1544285": 180, "e1ab7957f4": 150, "e1d26d35be": 96, "e1e957085b": 96, "e1f14510fa": 180, "e214b160f4": 180, "e2167379b8": 150, "e21acb20ab": 180, "e221105579": 180, "e22ddf8a1b": 180, "e22de45950": 96, "e22ffc469b": 180, "e23cca5244": 96, "e252f46f0b": 180, "e25fa6cf39": 180, "e26e486026": 150, "e275760245": 96, "e27bbedbfe": 92, "e29e9868a8": 180, "e2b37ff8af": 96, "e2b608d309": 180, "e2bef4da9a": 96, "e2c87a6421": 96, "e2ea25542c": 144, "e2fb1d6497": 178, "e2fcc99117": 91, "e33c18412a": 71, "e348377191": 91, "e352cb59c8": 180, "e36ac982f0": 91, "e391bc981e": 96, "e39e3e0a06": 96, "e3bf38265f": 51, "e3d5b2cd21": 150, "e3d60e82d5": 46, "e3e3245492": 96, "e3e4134877": 150, "e3f4635e03": 180, "e4004ee048": 180, "e402d1afa5": 180, "e415093d27": 71, "e41ceb5d81": 180, "e424653b78": 96, "e42b6d3dbb": 96, "e42d60f0d4": 180, "e436d0ff1e": 180, "e43d7ae2c5": 92, "e4428801bc": 97, "e44e0b4917": 180, "e470345ede": 180, "e48e8b4263": 180, "e4922e3726": 180, "e4936852bb": 96, "e495f32c60": 41, "e499228f26": 150, "e4af66e163": 180, "e4b2095f58": 180, "e4d19c8283": 180, "e4d4872dab": 96, "e4e2983570": 41, "e4eaa63aab": 91, "e4ef0a3a34": 91, "e4f8e5f46e": 96, "e4ffb6d0dd": 71, "e53e21aa02": 180, "e57f4f668b": 180, "e588433c1e": 96, "e597442c99": 150, "e5abc0e96b": 91, "e5be628030": 180, "e5ce96a55d": 61, "e5d6b70a9f": 81, "e5fde1574c": 92, "e625e1d27b": 180, "e6261d2348": 91, "e6267d46bc": 96, "e6295f223f": 180, "e63463d8c6": 96, "e6387bd1e0": 180, "e653883384": 96, "e65f134e0b": 150, "e668ef5664": 180, "e672ccd250": 92, "e674510b20": 91, "e676107765": 150, "e699da0cdf": 180, "e6be243065": 46, "e6deab5e0b": 
76, "e6f065f2b9": 96, "e71629e7b5": 96, "e72a7d7b0b": 150, "e72f6104e1": 92, "e75a466eea": 72, "e76c55933f": 150, "e7784ec8ad": 180, "e78922e5e6": 47, "e78d450a9c": 91, "e7c6354e77": 91, "e7c8de1fce": 150, "e7ea10db28": 150, "e803918710": 180, "e8073a140b": 180, "e828dd02db": 150, "e845994987": 150, "e8485a2615": 96, "e85c5118a7": 180, "e88b6736e4": 180, "e8962324e3": 91, "e8b3018d36": 91, "e8cee8bf0b": 150, "e8d97ebece": 144, "e8da49ea6a": 96, "e8ed1a3ccf": 180, "e8f7904326": 72, "e8f8341dec": 180, "e8fa21eb13": 180, "e90c10fc4c": 150, "e914b8cac8": 180, "e92b6bfea4": 46, "e92e1b7623": 150, "e93f83e512": 92, "e9422ad240": 46, "e9460b55f9": 180, "e9502628f6": 180, "e950befd5f": 180, "e9582bdd1b": 91, "e95e5afe0f": 96, "e97cfac475": 96, "e98d57d99c": 91, "e98eda8978": 92, "e99706b555": 41, "e9bc0760ba": 91, "e9d3c78bf3": 87, "e9ec1b7ea8": 144, "ea065cc205": 180, "ea138b6617": 150, "ea16d3fd48": 180, "ea2545d64b": 180, "ea286a581c": 150, "ea320da917": 96, "ea345f3627": 91, "ea3b94a591": 180, "ea444a37eb": 71, "ea4a01216b": 180, "ea5672ffa8": 81, "eaa99191cb": 150, "eaab4d746c": 91, "eac7a59bc1": 150, "ead5d3835a": 96, "eaec65cfa7": 180, "eaed1a87be": 180, "eb2f821c6f": 180, "eb383cb82e": 91, "eb6992fe02": 150, "eb6ac20a01": 92, "eb6d7ab39e": 96, "eb7921facd": 180, "eb8fce51a6": 180, "ebbb90e9f9": 91, "ebbf5c9ee1": 180, "ebc4ec32e6": 91, "ebe56e5ef8": 180, "ec1299aee4": 97, "ec139ff675": 180, "ec193e1a01": 180, "ec28252938": 150, "ec387be051": 180, "ec3d4fac00": 91, "ec4186ce12": 95, "ec579c2f96": 91, "ecae59b782": 180, "ecb33a0448": 180, "ece6bc9e92": 150, "ecfedd4035": 92, "ecfff22fd6": 180, "ed3291c3d6": 180, "ed3cd5308d": 180, "ed3e6fc1a5": 180, "ed72ae8825": 180, "ed7455da68": 92, "ed844e879f": 150, "ed8f814b2b": 92, "ed911a1f63": 180, "ed9ff4f649": 180, "eda8ab984b": 180, "edb8878849": 96, "edbfdfe1b4": 180, "edd22c46a2": 96, "edd663afa3": 180, "ede3552eae": 96, "edeab61ee0": 174, "ee07583fc0": 150, "ee316eaed6": 91, "ee3f509537": 150, "ee40a1e491": 92, 
"ee4bf100f1": 180, "ee6f9b01f9": 180, "ee947ed771": 96, "ee9706ac7f": 91, "ee9a7840ae": 180, "eeb90cb569": 180, "eebf45e5c5": 92, "eeed0c7d73": 87, "ef0061a309": 96, "ef07f1a655": 96, "ef0a8e8f35": 56, "ef232a2aed": 150, "ef308ad2e9": 180, "ef44945428": 96, "ef45ce3035": 180, "ef5dde449d": 180, "ef5e770988": 144, "ef6359cea3": 96, "ef65268834": 180, "ef6cb5eae0": 86, "ef78972bc2": 150, "ef8cfcfc4f": 82, "ef96501dd0": 150, "ef9a2e976b": 91, "efb24f950f": 180, "efce0c1868": 180, "efe5ac6901": 91, "efe828affa": 180, "efea4e0523": 144, "f0268aa627": 180, "f0483250c8": 180, "f04cf99ee6": 62, "f05b189097": 96, "f08928c6d3": 96, "f09d74856f": 150, "f0a7607d63": 180, "f0ad38da27": 71, "f0c34e1213": 92, "f0c7f86c29": 180, "f0dfa18ba7": 150, "f0eb3179f7": 180, "f119bab27d": 150, "f14409b6a3": 180, "f1489baff4": 86, "f14c18cf6a": 180, "f15c607b92": 180, "f1af214222": 97, "f1b77bd309": 180, "f1ba9e1a3e": 180, "f1d99239eb": 66, "f1dc710cf4": 180, "f1ec5c08fa": 97, "f22648fe12": 180, "f22d21f1f1": 144, "f233257395": 91, "f23e95dbe5": 96, "f2445b1572": 150, "f253b3486d": 144, "f277c7a6a4": 91, "f2ab2b84d6": 87, "f2b7c9b1f3": 150, "f2b83d5ce5": 180, "f2c276018f": 150, "f2cfd94d64": 150, "f2dd6e3add": 150, "f2e7653f16": 180, "f2f333ad06": 96, "f2f55d6713": 180, "f2fdb6abec": 180, "f305a56d9f": 46, "f3085d6570": 96, "f3325c3338": 180, "f3400f1204": 180, "f34497c932": 97, "f34a56525e": 91, "f36483c824": 96, "f3704d5663": 91, "f3734c4913": 150, "f38e5aa5b4": 86, "f3986fba44": 180, "f3a0ffc7d9": 180, "f3b24a7d28": 96, "f3e6c35ec3": 180, "f3fc0ea80b": 96, "f40a683fbe": 180, "f4207ca554": 180, "f4377499c2": 150, "f46184f393": 144, "f46c2d0a6d": 180, "f46c364dca": 180, "f46f7a0b63": 180, "f46fe141b0": 91, "f470b9aeb0": 180, "f47eb7437f": 96, "f48b535719": 92, "f49e4866ac": 180, "f4aa882cfd": 180, "f4daa3dbd5": 96, "f4dd51ac35": 91, "f507a1b9dc": 96, "f51c5ac84b": 86, "f52104164b": 180, "f54c67b9bb": 96, "f5966cadd2": 180, "f5bddf5598": 91, "f5d85cfd17": 92, "f5e2e7d6a0": 96, "f5f051e9b4": 
180, "f5f8a93a76": 150, "f6283e8af5": 96, "f635e9568b": 180, "f6474735be": 144, "f659251be2": 150, "f66981af4e": 96, "f6708fa398": 87, "f697fe8e8f": 96, "f6adb12c42": 76, "f6c7906ca4": 180, "f6cd0a8016": 144, "f6d6f15ae7": 144, "f6e501892c": 96, "f6f59d986f": 180, "f6fe8c90a5": 180, "f714160545": 144, "f74c3888d7": 180, "f7782c430e": 150, "f7783ae5f2": 96, "f77ab47923": 97, "f788a98327": 91, "f7961ac1f0": 96, "f7a71e7574": 150, "f7a8521432": 180, "f7afbf4947": 150, "f7b7cd5f44": 81, "f7cf4b4a39": 92, "f7d49799ad": 150, "f7e0c9bb83": 180, "f7e5b84928": 96, "f7e6bd58be": 96, "f7f2a38ac6": 96, "f7f6cb2d6d": 150, "f83f19e796": 76, "f85796a921": 91, "f8603c26b2": 180, "f8819b42ec": 144, "f891f8eaa1": 96, "f89288d10c": 92, "f895ae8cc1": 180, "f8af30d4b6": 97, "f8b4ac12f1": 180, "f8c3fb2b01": 180, "f8c8de2764": 180, "f8db369b40": 92, "f8fcb6a78c": 180, "f94aafdeef": 180, "f95d217b70": 96, "f9681d5103": 92, "f9750192a4": 91, "f9823a32c2": 96, "f991ddb4c2": 96, "f99d535567": 96, "f9ae3d98b7": 144, "f9b6217959": 91, "f9bd1fabf5": 96, "f9c68eaa64": 180, "f9d3e04c4f": 92, "f9daf64494": 180, "f9e4cc5a0a": 96, "f9ea6b7f31": 96, "f9f3852526": 180, "fa04c615cf": 150, "fa08e00a56": 180, "fa4370d74d": 180, "fa67744af3": 180, "fa88d48a92": 150, "fa8b904cc9": 92, "fa9526bdf1": 150, "fa9b9d2426": 150, "fad633fbe1": 150, "faf5222dc3": 91, "faff0e15f1": 180, "fb08c64e8c": 180, "fb23455a7f": 150, "fb2e19fa6e": 180, "fb34dfbb77": 180, "fb47fcea1e": 96, "fb49738155": 180, "fb4cbc514b": 71, "fb4e6062f7": 180, "fb5ba7ad6e": 96, "fb63cd1236": 96, "fb81157a07": 180, "fb92abdaeb": 180, "fba22a6848": 92, "fbaca0c9df": 180, "fbc645f602": 96, "fbd77444cd": 96, "fbe53dc8e8": 96, "fbe541dd73": 97, "fbe8488798": 91, "fbfd25174f": 96, "fc28cb305e": 97, "fc33b1ffd6": 150, "fc6186f0bb": 180, "fc918e3a40": 150, "fc96cda9d8": 150, "fc9832eea4": 150, "fcb10d0f81": 180, "fcd20a2509": 180, "fcf637e3ab": 92, "fcfd81727f": 96, "fd31890379": 180, "fd33551c28": 144, "fd542da05e": 144, "fd6789b3fe": 180, 
"fd77828200": 180, "fd7af75f4d": 150, "fdb28d0fbb": 150, "fdb3d1fb1e": 82, "fdb8b04124": 96, "fdc6e3d581": 91, "fdfce7e6fc": 180, "fe0f76d41b": 180, "fe24b0677d": 180, "fe3c02699d": 144, "fe58b48235": 96, "fe6a5596b8": 91, "fe6c244f63": 66, "fe7afec086": 180, "fe985d510a": 144, "fe9db35d15": 96, "fea8ffcd36": 144, "feb1080388": 180, "fed208bfca": 180, "feda5ad1c2": 180, "feec95b386": 91, "ff15a5eff6": 144, "ff204daf4b": 96, "ff25f55852": 180, "ff2ada194f": 180, "ff2ce142e8": 96, "ff49d36d20": 180, "ff5a1ec4f3": 180, "ff66152b25": 180, "ff692fdc56": 180, "ff773b1a1e": 96, "ff97129478": 144, "ffb904207d": 180, "ffc43fc345": 150, "fffe5f8df6": 180}
inference_propainter.py ADDED
@@ -0,0 +1,475 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ import os
3
+ import cv2
4
+ import argparse
5
+ import imageio
6
+ import numpy as np
7
+ import scipy.ndimage
8
+ from PIL import Image
9
+ from tqdm import tqdm
10
+
11
+ import torch
12
+ import torchvision
13
+
14
+ from model.modules.flow_comp_raft import RAFT_bi
15
+ from model.recurrent_flow_completion import RecurrentFlowCompleteNet
16
+ from model.propainter import InpaintGenerator
17
+ from utils.download_util import load_file_from_url
18
+ from core.utils import to_tensors
19
+ from model.misc import get_device
20
+
21
+ import warnings
22
+ warnings.filterwarnings("ignore")
23
+
24
+ pretrain_model_url = 'https://github.com/sczhou/ProPainter/releases/download/v0.1.0/'
25
+
26
def imwrite(img, file_path, params=None, auto_mkdir=True):
    """Write an image to disk with OpenCV, optionally creating its folder.

    Args:
        img: Image array handed unchanged to ``cv2.imwrite``.
        file_path: Destination path of the image file.
        params: Optional encoder parameters forwarded to ``cv2.imwrite``.
        auto_mkdir: When true, the parent directory of ``file_path`` is
            created first (an existing directory is not an error).

    Returns:
        The value returned by ``cv2.imwrite``.
    """
    if auto_mkdir:
        parent_dir = os.path.dirname(file_path)
        os.makedirs(os.path.abspath(parent_dir), exist_ok=True)
    return cv2.imwrite(file_path, img, params)
31
+
32
+
33
+ # resize frames
34
def resize_frames(frames, size=None):
    """Resize frames so that both dimensions are multiples of 8.

    Args:
        frames: List of frame objects exposing ``.size`` and ``.resize(size)``.
        size: Optional requested ``(dim0, dim1)`` size. When omitted, the
            size of the first frame is used.

    Returns:
        ``(frames, process_size, out_size)`` where ``process_size`` is
        ``out_size`` with each dimension rounded down to a multiple of 8 and
        ``out_size`` is the requested (or original) size.
    """
    explicit_size = size is not None
    out_size = size if explicit_size else frames[0].size
    process_size = (out_size[0] - out_size[0] % 8,
                    out_size[1] - out_size[1] % 8)

    # An explicit size always triggers a resize; otherwise the frames are only
    # touched when their own size is not already a multiple of 8.
    if explicit_size or not out_size == process_size:
        frames = [frame.resize(process_size) for frame in frames]

    return frames, process_size, out_size
46
+
47
+
48
+ # read frames from video
49
def read_frame_from_videos(frame_root):
    """Load a clip from a video file or from a folder of frame images.

    Args:
        frame_root: Path to a video file (``mp4``/``mov``/``avi`` suffix, lower
            or upper case) or to a directory whose files are read as frames in
            sorted filename order.

    Returns:
        ``(frames, fps, size, video_name)``: ``frames`` is a list of RGB PIL
        images, ``fps`` is the ``video_fps`` reported by torchvision for video
        input and ``None`` for a folder, ``size`` is ``frames[0].size`` and
        ``video_name`` is the file name without extension or the folder name.

    Raises:
        ValueError: If a file inside the frame folder cannot be decoded as an
            image.
    """
    if frame_root.endswith(('mp4', 'mov', 'avi', 'MP4', 'MOV', 'AVI')):  # input video path
        video_name = os.path.splitext(os.path.basename(frame_root))[0]
        vframes, aframes, info = torchvision.io.read_video(filename=frame_root, pts_unit='sec')  # RGB
        frames = [Image.fromarray(f) for f in vframes.numpy()]
        fps = info['video_fps']
    else:
        # normpath drops a trailing separator; without it basename('clip/')
        # would be '' and the results would be written to the output root.
        video_name = os.path.basename(os.path.normpath(frame_root))
        frames = []
        for fr in sorted(os.listdir(frame_root)):
            frame_path = os.path.join(frame_root, fr)
            frame = cv2.imread(frame_path)
            if frame is None:
                # cv2.imread signals failure by returning None rather than raising.
                raise ValueError(f'Cannot read frame image: {frame_path}')
            frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
        fps = None
    size = frames[0].size

    return frames, fps, size, video_name
68
+
69
+
70
def binary_mask(mask, th=0.1):
    """Binarize ``mask`` in place: values above ``th`` become 1, the rest 0.

    Args:
        mask: Array modified in place and also returned.
        th: Threshold separating foreground (``> th``) from background.

    Returns:
        The same ``mask`` object after thresholding.
    """
    foreground = mask > th
    mask[foreground] = 1
    # Evaluated after the first assignment on purpose, matching the original
    # two-step update (relevant when th >= 1).
    background = mask <= th
    mask[background] = 0
    return mask
74
+
75
+
76
+ # read frame-wise masks
77
def read_mask(mpath, length, size, flow_mask_dilates=8, mask_dilates=5):
    """Read frame-wise masks and build the flow masks and dilated masks.

    Args:
        mpath: Path to a single mask image (``jpg``/``jpeg``/``png`` suffix,
            lower or upper case) or to a folder of mask images, read in
            sorted filename order.
        length: Number of frames; a single mask is repeated this many times.
        size: Optional size passed to ``Image.resize`` (nearest neighbour);
            ``None`` keeps the masks at their own size.
        flow_mask_dilates: Dilation iterations for the flow masks; a value
            of 0 or less only binarizes the mask.
        mask_dilates: Dilation iterations for the inpainting masks; a value
            of 0 or less only binarizes the mask.

    Returns:
        ``(flow_masks, masks_dilated)``: two lists of PIL images with values
        0 and 255.
    """
    if mpath.endswith(('jpg', 'jpeg', 'png', 'JPG', 'JPEG', 'PNG')):  # input single img path
        masks_img = [Image.open(mpath)]
    else:
        masks_img = [Image.open(os.path.join(mpath, name))
                     for name in sorted(os.listdir(mpath))]

    flow_masks = []
    masks_dilated = []
    for mask_img in masks_img:
        if size is not None:
            mask_img = mask_img.resize(size, Image.NEAREST)
        mask_arr = np.array(mask_img.convert('L'))

        # Dilate the flow mask so that the pixels left outside it are trustworthy.
        if flow_mask_dilates > 0:
            flow_arr = scipy.ndimage.binary_dilation(mask_arr, iterations=flow_mask_dilates).astype(np.uint8)
        else:
            flow_arr = binary_mask(mask_arr).astype(np.uint8)
        flow_masks.append(Image.fromarray(flow_arr * 255))

        if mask_dilates > 0:
            dilated_arr = scipy.ndimage.binary_dilation(mask_arr, iterations=mask_dilates).astype(np.uint8)
        else:
            dilated_arr = binary_mask(mask_arr).astype(np.uint8)
        masks_dilated.append(Image.fromarray(dilated_arr * 255))

    # A single mask applies to every frame of the clip.
    if len(masks_img) == 1:
        flow_masks = flow_masks * length
        masks_dilated = masks_dilated * length

    return flow_masks, masks_dilated
115
+
116
+
117
def extrapolation(video_ori, scale):
    """Prepare frames and masks for video outpainting.

    Each frame is pasted centred onto a black canvas whose height and width
    are the original ones multiplied by ``scale`` and rounded down to a
    multiple of 8.

    Args:
        video_ori: List of frames exposing ``.size`` as ``(width, height)``.
        scale: ``(scale_h, scale_w)`` enlargement factors.

    Returns:
        ``(frames, flow_masks, masks_dilated, (new_w, new_h))``. Both mask
        lists hold one image repeated once per frame, with 255 marking the
        region to fill.
    """
    num_frames = len(video_ori)
    src_w, src_h = video_ori[0].size

    # New field of view, floored to a multiple of 8.
    new_h = int(scale[0] * src_h)
    new_w = int(scale[1] * src_w)
    new_h -= new_h % 8
    new_w -= new_w % 8
    top = int((new_h - src_h) / 2)
    left = int((new_w - src_w) / 2)
    bottom = top + src_h
    right = left + src_w

    # Paste every frame into the centre of an empty canvas.
    frames = []
    for src in video_ori:
        canvas = np.zeros((new_h, new_w, 3), dtype=np.uint8)
        canvas[top:bottom, left:right, :] = src
        frames.append(Image.fromarray(canvas))

    # The flow mask reaches 4 pixels into the known region on an axis whose
    # border is wider than 10 pixels.
    margin_h = 4 if top > 10 else 0
    margin_w = 4 if left > 10 else 0
    mask = np.ones((new_h, new_w), dtype=np.uint8)
    mask[top + margin_h:bottom - margin_h,
         left + margin_w:right - margin_w] = 0
    flow_masks = [Image.fromarray(mask * 255)] * num_frames

    # The inpainting mask covers exactly the added border.
    mask[top:bottom, left:right] = 0
    masks_dilated = [Image.fromarray(mask * 255)] * num_frames

    return frames, flow_masks, masks_dilated, (new_w, new_h)
157
+
158
+
159
def get_ref_index(mid_neighbor_id, neighbor_ids, length, ref_stride=10, ref_num=-1):
    """Pick the indices of reference frames outside the local neighbourhood.

    Args:
        mid_neighbor_id: Index of the frame the local window is centred on.
        neighbor_ids: Indices of the local window; never returned as references.
        length: Total number of frames.
        ref_stride: Step between candidate reference frames.
        ref_num: ``-1`` samples candidates over the whole clip; otherwise the
            candidates are limited to ``ref_stride * (ref_num // 2)`` frames
            on either side of ``mid_neighbor_id``.

    Returns:
        List of reference frame indices in increasing order.
    """
    if ref_num == -1:
        return [i for i in range(0, length, ref_stride) if i not in neighbor_ids]

    half_span = ref_stride * (ref_num // 2)
    start_idx = max(0, mid_neighbor_id - half_span)
    end_idx = min(length, mid_neighbor_id + half_span)

    ref_index = []
    for i in range(start_idx, end_idx, ref_stride):
        if i in neighbor_ids:
            continue
        if len(ref_index) > ref_num:
            break
        ref_index.append(i)
    return ref_index
174
+
175
+
176
+
177
if __name__ == '__main__':
    # Command-line entry point: read the clip and masks, estimate and complete
    # optical flow, propagate pixels along the flow, run the ProPainter
    # generator on sliding windows, then save the composited result.
    # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    device = get_device()

    parser = argparse.ArgumentParser()
    parser.add_argument(
        '-i', '--video', type=str, default='inputs/object_removal/bmx-trees', help='Path of the input video or image folder.')
    parser.add_argument(
        '-m', '--mask', type=str, default='inputs/object_removal/bmx-trees_mask', help='Path of the mask(s) or mask folder.')
    parser.add_argument(
        '-o', '--output', type=str, default='results', help='Output folder. Default: results')
    parser.add_argument(
        "--resize_ratio", type=float, default=1.0, help='Resize scale for processing video.')
    parser.add_argument(
        '--height', type=int, default=-1, help='Height of the processing video.')
    parser.add_argument(
        '--width', type=int, default=-1, help='Width of the processing video.')
    parser.add_argument(
        '--mask_dilation', type=int, default=4, help='Mask dilation for video and flow masking.')
    parser.add_argument(
        "--ref_stride", type=int, default=10, help='Stride of global reference frames.')
    parser.add_argument(
        "--neighbor_length", type=int, default=10, help='Length of local neighboring frames.')
    parser.add_argument(
        "--subvideo_length", type=int, default=80, help='Length of sub-video for long video inference.')
    parser.add_argument(
        "--raft_iter", type=int, default=20, help='Iterations for RAFT inference.')
    parser.add_argument(
        '--mode', default='video_inpainting', choices=['video_inpainting', 'video_outpainting'], help="Modes: video_inpainting / video_outpainting")
    parser.add_argument(
        '--scale_h', type=float, default=1.0, help='Outpainting scale of height for video_outpainting mode.')
    parser.add_argument(
        '--scale_w', type=float, default=1.2, help='Outpainting scale of width for video_outpainting mode.')
    parser.add_argument(
        '--save_fps', type=int, default=24, help='Frame per second. Default: 24')
    parser.add_argument(
        '--save_frames', action='store_true', help='Save output frames. Default: False')
    parser.add_argument(
        '--fp16', action='store_true', help='Use fp16 (half precision) during inference. Default: fp32 (single precision).')

    args = parser.parse_args()

    # Use fp16 precision during inference to reduce running memory cost
    use_half = True if args.fp16 else False

    def _weight_url(file_name):
        # URLs always use '/'; os.path.join would insert the OS separator
        # (a backslash on Windows) and produce an invalid download URL.
        return pretrain_model_url.rstrip('/') + '/' + file_name

    frames, fps, size, video_name = read_frame_from_videos(args.video)
    if not args.width == -1 and not args.height == -1:
        size = (args.width, args.height)
    if not args.resize_ratio == 1.0:
        size = (int(args.resize_ratio * size[0]), int(args.resize_ratio * size[1]))

    frames, size, out_size = resize_frames(frames, size)

    fps = args.save_fps if fps is None else fps
    save_root = os.path.join(args.output, video_name)
    os.makedirs(save_root, exist_ok=True)

    if args.mode == 'video_inpainting':
        frames_len = len(frames)
        flow_masks, masks_dilated = read_mask(args.mask, frames_len, size,
                                              flow_mask_dilates=args.mask_dilation,
                                              mask_dilates=args.mask_dilation)
        w, h = size
    elif args.mode == 'video_outpainting':
        assert args.scale_h is not None and args.scale_w is not None, 'Please provide a outpainting scale (s_h, s_w).'
        frames, flow_masks, masks_dilated, size = extrapolation(frames, (args.scale_h, args.scale_w))
        w, h = size
    else:
        raise NotImplementedError

    # for saving the masked frames or video: blend a green overlay into the
    # masked region of every input frame
    masked_frame_for_save = []
    for i in range(len(frames)):
        mask_ = np.expand_dims(np.array(masks_dilated[i]), 2).repeat(3, axis=2) / 255.
        img = np.array(frames[i])
        green = np.zeros([h, w, 3])
        green[:, :, 1] = 255
        alpha = 0.6
        # alpha = 1.0
        fuse_img = (1 - alpha) * img + alpha * green
        fuse_img = mask_ * fuse_img + (1 - mask_) * img
        masked_frame_for_save.append(fuse_img.astype(np.uint8))

    frames_inp = [np.array(f).astype(np.uint8) for f in frames]
    frames = to_tensors()(frames).unsqueeze(0) * 2 - 1
    flow_masks = to_tensors()(flow_masks).unsqueeze(0)
    masks_dilated = to_tensors()(masks_dilated).unsqueeze(0)
    frames, flow_masks, masks_dilated = frames.to(device), flow_masks.to(device), masks_dilated.to(device)

    ##############################################
    # set up RAFT and flow completion model
    ##############################################
    ckpt_path = load_file_from_url(url=_weight_url('raft-things.pth'),
                                   model_dir='weights', progress=True, file_name=None)
    fix_raft = RAFT_bi(ckpt_path, device)

    ckpt_path = load_file_from_url(url=_weight_url('recurrent_flow_completion.pth'),
                                   model_dir='weights', progress=True, file_name=None)
    fix_flow_complete = RecurrentFlowCompleteNet(ckpt_path)
    for p in fix_flow_complete.parameters():
        p.requires_grad = False
    fix_flow_complete.to(device)
    fix_flow_complete.eval()

    ##############################################
    # set up ProPainter model
    ##############################################
    ckpt_path = load_file_from_url(url=_weight_url('ProPainter.pth'),
                                   model_dir='weights', progress=True, file_name=None)
    model = InpaintGenerator(model_path=ckpt_path).to(device)
    model.eval()

    ##############################################
    # ProPainter inference
    ##############################################
    video_length = frames.size(1)
    print(f'\nProcessing: {video_name} [{video_length} frames]...')
    with torch.no_grad():
        # ---- compute flow ----
        # Wider frames are sent to RAFT in shorter clips.
        if frames.size(-1) <= 640:
            short_clip_len = 12
        elif frames.size(-1) <= 720:
            short_clip_len = 8
        elif frames.size(-1) <= 1280:
            short_clip_len = 4
        else:
            short_clip_len = 2

        # use fp32 for RAFT
        if frames.size(1) > short_clip_len:
            gt_flows_f_list, gt_flows_b_list = [], []
            for f in range(0, video_length, short_clip_len):
                end_f = min(video_length, f + short_clip_len)
                if f == 0:
                    flows_f, flows_b = fix_raft(frames[:, f:end_f], iters=args.raft_iter)
                else:
                    # start one frame earlier so the flow across the clip
                    # boundary is computed as well
                    flows_f, flows_b = fix_raft(frames[:, f-1:end_f], iters=args.raft_iter)

                gt_flows_f_list.append(flows_f)
                gt_flows_b_list.append(flows_b)
                torch.cuda.empty_cache()

            gt_flows_f = torch.cat(gt_flows_f_list, dim=1)
            gt_flows_b = torch.cat(gt_flows_b_list, dim=1)
            gt_flows_bi = (gt_flows_f, gt_flows_b)
        else:
            gt_flows_bi = fix_raft(frames, iters=args.raft_iter)
            torch.cuda.empty_cache()

        if use_half:
            frames, flow_masks, masks_dilated = frames.half(), flow_masks.half(), masks_dilated.half()
            gt_flows_bi = (gt_flows_bi[0].half(), gt_flows_bi[1].half())
            fix_flow_complete = fix_flow_complete.half()
            model = model.half()

        # ---- complete flow ----
        flow_length = gt_flows_bi[0].size(1)
        if flow_length > args.subvideo_length:
            pred_flows_f, pred_flows_b = [], []
            pad_len = 5
            for f in range(0, flow_length, args.subvideo_length):
                # process each sub-video with pad_len extra flows of context on
                # both sides, then crop the context away again
                s_f = max(0, f - pad_len)
                e_f = min(flow_length, f + args.subvideo_length + pad_len)
                pad_len_s = max(0, f) - s_f
                pad_len_e = e_f - min(flow_length, f + args.subvideo_length)
                pred_flows_bi_sub, _ = fix_flow_complete.forward_bidirect_flow(
                    (gt_flows_bi[0][:, s_f:e_f], gt_flows_bi[1][:, s_f:e_f]),
                    flow_masks[:, s_f:e_f+1])
                pred_flows_bi_sub = fix_flow_complete.combine_flow(
                    (gt_flows_bi[0][:, s_f:e_f], gt_flows_bi[1][:, s_f:e_f]),
                    pred_flows_bi_sub,
                    flow_masks[:, s_f:e_f+1])

                pred_flows_f.append(pred_flows_bi_sub[0][:, pad_len_s:e_f-s_f-pad_len_e])
                pred_flows_b.append(pred_flows_bi_sub[1][:, pad_len_s:e_f-s_f-pad_len_e])
                torch.cuda.empty_cache()

            pred_flows_f = torch.cat(pred_flows_f, dim=1)
            pred_flows_b = torch.cat(pred_flows_b, dim=1)
            pred_flows_bi = (pred_flows_f, pred_flows_b)
        else:
            pred_flows_bi, _ = fix_flow_complete.forward_bidirect_flow(gt_flows_bi, flow_masks)
            pred_flows_bi = fix_flow_complete.combine_flow(gt_flows_bi, pred_flows_bi, flow_masks)
            torch.cuda.empty_cache()

        # ---- image propagation ----
        masked_frames = frames * (1 - masks_dilated)
        subvideo_length_img_prop = min(100, args.subvideo_length)  # cap the sub-video length for image propagation at 100 frames
        if video_length > subvideo_length_img_prop:
            updated_frames, updated_masks = [], []
            pad_len = 10
            for f in range(0, video_length, subvideo_length_img_prop):
                s_f = max(0, f - pad_len)
                e_f = min(video_length, f + subvideo_length_img_prop + pad_len)
                pad_len_s = max(0, f) - s_f
                pad_len_e = e_f - min(video_length, f + subvideo_length_img_prop)

                b, t, _, _, _ = masks_dilated[:, s_f:e_f].size()
                pred_flows_bi_sub = (pred_flows_bi[0][:, s_f:e_f-1], pred_flows_bi[1][:, s_f:e_f-1])
                prop_imgs_sub, updated_local_masks_sub = model.img_propagation(masked_frames[:, s_f:e_f],
                                                                               pred_flows_bi_sub,
                                                                               masks_dilated[:, s_f:e_f],
                                                                               'nearest')
                updated_frames_sub = frames[:, s_f:e_f] * (1 - masks_dilated[:, s_f:e_f]) + \
                    prop_imgs_sub.view(b, t, 3, h, w) * masks_dilated[:, s_f:e_f]
                updated_masks_sub = updated_local_masks_sub.view(b, t, 1, h, w)

                updated_frames.append(updated_frames_sub[:, pad_len_s:e_f-s_f-pad_len_e])
                updated_masks.append(updated_masks_sub[:, pad_len_s:e_f-s_f-pad_len_e])
                torch.cuda.empty_cache()

            updated_frames = torch.cat(updated_frames, dim=1)
            updated_masks = torch.cat(updated_masks, dim=1)
        else:
            b, t, _, _, _ = masks_dilated.size()
            prop_imgs, updated_local_masks = model.img_propagation(masked_frames, pred_flows_bi, masks_dilated, 'nearest')
            updated_frames = frames * (1 - masks_dilated) + prop_imgs.view(b, t, 3, h, w) * masks_dilated
            updated_masks = updated_local_masks.view(b, t, 1, h, w)
            torch.cuda.empty_cache()

    ori_frames = frames_inp
    comp_frames = [None] * video_length

    neighbor_stride = args.neighbor_length // 2
    if video_length > args.subvideo_length:
        ref_num = args.subvideo_length // args.ref_stride
    else:
        ref_num = -1

    # ---- feature propagation + transformer ----
    for f in tqdm(range(0, video_length, neighbor_stride)):
        neighbor_ids = [
            i for i in range(max(0, f - neighbor_stride),
                             min(video_length, f + neighbor_stride + 1))
        ]
        ref_ids = get_ref_index(f, neighbor_ids, video_length, args.ref_stride, ref_num)
        selected_imgs = updated_frames[:, neighbor_ids + ref_ids, :, :, :]
        selected_masks = masks_dilated[:, neighbor_ids + ref_ids, :, :, :]
        selected_update_masks = updated_masks[:, neighbor_ids + ref_ids, :, :, :]
        selected_pred_flows_bi = (pred_flows_bi[0][:, neighbor_ids[:-1], :, :, :], pred_flows_bi[1][:, neighbor_ids[:-1], :, :, :])

        with torch.no_grad():
            # 1.0 indicates mask
            l_t = len(neighbor_ids)

            # pred_img = selected_imgs # results of image propagation
            pred_img = model(selected_imgs, selected_pred_flows_bi, selected_masks, selected_update_masks, l_t)

            pred_img = pred_img.view(-1, 3, h, w)

            pred_img = (pred_img + 1) / 2
            pred_img = pred_img.cpu().permute(0, 2, 3, 1).numpy() * 255
            binary_masks = masks_dilated[0, neighbor_ids, :, :, :].cpu().permute(
                0, 2, 3, 1).numpy().astype(np.uint8)
            for i in range(len(neighbor_ids)):
                idx = neighbor_ids[i]
                # keep the original pixels outside the mask, the prediction inside
                img = np.array(pred_img[i]).astype(np.uint8) * binary_masks[i] \
                    + ori_frames[idx] * (1 - binary_masks[i])
                if comp_frames[idx] is None:
                    comp_frames[idx] = img
                else:
                    # frames covered by two overlapping windows are averaged
                    comp_frames[idx] = comp_frames[idx].astype(np.float32) * 0.5 + img.astype(np.float32) * 0.5

                comp_frames[idx] = comp_frames[idx].astype(np.uint8)

        torch.cuda.empty_cache()

    # save each frame
    if args.save_frames:
        for idx in range(video_length):
            f = comp_frames[idx]
            f = cv2.resize(f, out_size, interpolation=cv2.INTER_CUBIC)
            # swap the channel order for cv2.imwrite (frames are RGB here)
            f = cv2.cvtColor(f, cv2.COLOR_BGR2RGB)
            img_save_root = os.path.join(save_root, 'frames', str(idx).zfill(4)+'.png')
            imwrite(f, img_save_root)

    # if args.mode == 'video_outpainting':
    #     comp_frames = [i[10:-10,10:-10] for i in comp_frames]
    #     masked_frame_for_save = [i[10:-10,10:-10] for i in masked_frame_for_save]

    # save videos frame
    masked_frame_for_save = [cv2.resize(f, out_size) for f in masked_frame_for_save]
    comp_frames = [cv2.resize(f, out_size) for f in comp_frames]
    imageio.mimwrite(os.path.join(save_root, 'masked_in.mp4'), masked_frame_for_save, fps=fps, quality=7)
    imageio.mimwrite(os.path.join(save_root, 'inpaint_out.mp4'), comp_frames, fps=fps, quality=7)

    print(f'\nAll results are saved in {save_root}')

    torch.cuda.empty_cache()
model/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+
model/canny/canny_filter.py ADDED
@@ -0,0 +1,256 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import Tuple
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+ from .gaussian import gaussian_blur2d
9
+ from .kernels import get_canny_nms_kernel, get_hysteresis_kernel
10
+ from .sobel import spatial_gradient
11
+
12
def rgb_to_grayscale(image, rgb_weights = None):
    """Convert an RGB image tensor to a single-channel grayscale tensor.

    Args:
        image: Tensor of shape ``(*, 3, H, W)`` with channels in RGB order.
        rgb_weights: Optional tensor of three channel weights. Defaults to
            ``[76, 150, 29]`` for ``uint8`` images and
            ``[0.299, 0.587, 0.114]`` for floating point images.

    Returns:
        Tensor of shape ``(*, 1, H, W)`` holding the weighted channel sum.

    Raises:
        ValueError: If ``image`` does not have three channels at dim ``-3``.
        TypeError: If no weights are given and the image dtype is neither
            ``uint8`` nor a floating point type.
    """
    if len(image.shape) < 3 or image.shape[-3] != 3:
        raise ValueError(f"Input size must have a shape of (*, 3, H, W). Got {image.shape}")

    if rgb_weights is not None:
        # caller-supplied weights are moved to the image's device/dtype
        rgb_weights = rgb_weights.to(image)
    elif image.dtype == torch.uint8:
        # 8 bit images
        rgb_weights = torch.tensor([76, 150, 29], device=image.device, dtype=torch.uint8)
    elif image.dtype in (torch.float16, torch.float32, torch.float64):
        # floating point images
        rgb_weights = torch.tensor([0.299, 0.587, 0.114], device=image.device, dtype=image.dtype)
    else:
        raise TypeError(f"Unknown data type: {image.dtype}")

    w_r, w_g, w_b = rgb_weights.unbind()

    # slice (rather than index) each channel so the channel dim is kept
    red = image[..., 0:1, :, :]
    green = image[..., 1:2, :, :]
    blue = image[..., 2:3, :, :]
    return w_r * red + w_g * green + w_b * blue
36
+
37
+
38
def canny(
    input: torch.Tensor,
    low_threshold: float = 0.1,
    high_threshold: float = 0.2,
    kernel_size: Tuple[int, int] = (5, 5),
    sigma: Tuple[float, float] = (1, 1),
    hysteresis: bool = True,
    eps: float = 1e-6,
) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""Find edges of the input image and filters them using the Canny algorithm.

    .. image:: _static/img/canny.png

    Args:
        input: input image tensor with shape :math:`(B,C,H,W)`.
        low_threshold: lower threshold for the hysteresis procedure.
        high_threshold: upper threshold for the hysteresis procedure.
        kernel_size: the size of the kernel for the gaussian blur.
        sigma: the standard deviation of the kernel for the gaussian blur.
        hysteresis: if True, applies the hysteresis edge tracking.
            Otherwise, the edges are divided between weak (0.5) and strong (1) edges.
        eps: regularization number to avoid NaN during backprop.

    Returns:
        - the canny edge magnitudes map, shape of :math:`(B,1,H,W)`.
        - the canny edge detection filtered by thresholds and hysteresis, shape of :math:`(B,1,H,W)`.

    Raises:
        TypeError: if ``input`` is not a tensor.
        ValueError: if ``input`` is not 4-dimensional, if ``low_threshold`` is
            greater than ``high_threshold``, or if either threshold lies
            outside the range [0, 1].

    .. note::
        See a working example `here <https://kornia-tutorials.readthedocs.io/en/latest/
        canny.html>`__.

    Example:
        >>> input = torch.rand(5, 3, 4, 4)
        >>> magnitude, edges = canny(input)  # 5x3x4x4
        >>> magnitude.shape
        torch.Size([5, 1, 4, 4])
        >>> edges.shape
        torch.Size([5, 1, 4, 4])
    """
    if not isinstance(input, torch.Tensor):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(input)}")

    if not len(input.shape) == 4:
        raise ValueError(f"Invalid input shape, we expect BxCxHxW. Got: {input.shape}")

    if low_threshold > high_threshold:
        raise ValueError(
            "Invalid input thresholds. low_threshold should be smaller than the high_threshold. Got: {}>{}".format(
                low_threshold, high_threshold
            )
        )

    # The range checks must use `or`: a value cannot be below 0 and above 1 at
    # the same time, so the previous `and` made them unreachable.
    if low_threshold < 0 or low_threshold > 1:
        raise ValueError(f"Invalid input threshold. low_threshold should be in range (0,1). Got: {low_threshold}")

    if high_threshold < 0 or high_threshold > 1:
        raise ValueError(f"Invalid input threshold. high_threshold should be in range (0,1). Got: {high_threshold}")

    device: torch.device = input.device
    dtype: torch.dtype = input.dtype

    # To Grayscale
    if input.shape[1] == 3:
        input = rgb_to_grayscale(input)

    # Gaussian filter
    blurred: torch.Tensor = gaussian_blur2d(input, kernel_size, sigma)

    # Compute the gradients
    gradients: torch.Tensor = spatial_gradient(blurred, normalized=False)

    # Unpack the edges
    gx: torch.Tensor = gradients[:, :, 0]
    gy: torch.Tensor = gradients[:, :, 1]

    # Compute gradient magnitude and angle
    magnitude: torch.Tensor = torch.sqrt(gx * gx + gy * gy + eps)
    angle: torch.Tensor = torch.atan2(gy, gx)

    # Radians to Degrees
    angle = 180.0 * angle / math.pi

    # Round angle to the nearest 45 degree
    angle = torch.round(angle / 45) * 45

    # Non-maximal suppression
    nms_kernels: torch.Tensor = get_canny_nms_kernel(device, dtype)
    nms_magnitude: torch.Tensor = F.conv2d(magnitude, nms_kernels, padding=nms_kernels.shape[-1] // 2)

    # Get the indices for both directions
    positive_idx: torch.Tensor = (angle / 45) % 8
    positive_idx = positive_idx.long()

    negative_idx: torch.Tensor = ((angle / 45) + 4) % 8
    negative_idx = negative_idx.long()

    # Apply the non-maximum suppression to the different directions
    channel_select_filtered_positive: torch.Tensor = torch.gather(nms_magnitude, 1, positive_idx)
    channel_select_filtered_negative: torch.Tensor = torch.gather(nms_magnitude, 1, negative_idx)

    channel_select_filtered: torch.Tensor = torch.stack(
        [channel_select_filtered_positive, channel_select_filtered_negative], 1
    )

    is_max: torch.Tensor = channel_select_filtered.min(dim=1)[0] > 0.0

    magnitude = magnitude * is_max

    # Double threshold: 0.5 marks weak edges, 1.0 marks strong edges
    low: torch.Tensor = magnitude > low_threshold
    high: torch.Tensor = magnitude > high_threshold

    edges: torch.Tensor = low * 0.5 + high * 0.5
    edges = edges.to(dtype)

    # Hysteresis
    if hysteresis:
        edges_old: torch.Tensor = -torch.ones(edges.shape, device=edges.device, dtype=dtype)
        hysteresis_kernels: torch.Tensor = get_hysteresis_kernel(device, dtype)

        # Iterate until the edge map stops changing.
        while ((edges_old - edges).abs() != 0).any():
            # Keep the working dtype so the conv2d input keeps matching the
            # kernels for half-precision input.
            weak: torch.Tensor = (edges == 0.5).to(dtype)
            strong: torch.Tensor = (edges == 1).to(dtype)

            hysteresis_magnitude: torch.Tensor = F.conv2d(
                edges, hysteresis_kernels, padding=hysteresis_kernels.shape[-1] // 2
            )
            hysteresis_magnitude = (hysteresis_magnitude == 1).any(1, keepdim=True).to(dtype)
            hysteresis_magnitude = hysteresis_magnitude * weak + strong

            edges_old = edges.clone()
            edges = hysteresis_magnitude + (hysteresis_magnitude == 0) * weak * 0.5

        edges = hysteresis_magnitude

    return magnitude, edges
176
+
177
+
178
class Canny(nn.Module):
    r"""Module that finds edges of the input image and filters them using the Canny algorithm.

    Args:
        input: input image tensor with shape :math:`(B,C,H,W)`.
        low_threshold: lower threshold for the hysteresis procedure.
        high_threshold: upper threshold for the hysteresis procedure.
        kernel_size: the size of the kernel for the gaussian blur.
        sigma: the standard deviation of the kernel for the gaussian blur.
        hysteresis: if True, applies the hysteresis edge tracking.
            Otherwise, the edges are divided between weak (0.5) and strong (1) edges.
        eps: regularization number to avoid NaN during backprop.

    Returns:
        - the canny edge magnitudes map, shape of :math:`(B,1,H,W)`.
        - the canny edge detection filtered by thresholds and hysteresis, shape of :math:`(B,1,H,W)`.

    Raises:
        ValueError: if ``low_threshold`` is greater than ``high_threshold`` or
            either threshold lies outside the range [0, 1].

    Example:
        >>> input = torch.rand(5, 3, 4, 4)
        >>> magnitude, edges = Canny()(input)  # 5x3x4x4
        >>> magnitude.shape
        torch.Size([5, 1, 4, 4])
        >>> edges.shape
        torch.Size([5, 1, 4, 4])
    """

    def __init__(
        self,
        low_threshold: float = 0.1,
        high_threshold: float = 0.2,
        kernel_size: Tuple[int, int] = (5, 5),
        sigma: Tuple[float, float] = (1, 1),
        hysteresis: bool = True,
        eps: float = 1e-6,
    ) -> None:
        super().__init__()

        if low_threshold > high_threshold:
            # Single-line literal: a backslash continuation inside the string
            # would embed the source indentation in the message.
            raise ValueError(
                "Invalid input thresholds. low_threshold should be smaller than the high_threshold. "
                f"Got: {low_threshold}>{high_threshold}"
            )

        if low_threshold < 0 or low_threshold > 1:
            raise ValueError(f"Invalid input threshold. low_threshold should be in range (0,1). Got: {low_threshold}")

        if high_threshold < 0 or high_threshold > 1:
            raise ValueError(f"Invalid input threshold. high_threshold should be in range (0,1). Got: {high_threshold}")

        # Gaussian blur parameters
        self.kernel_size = kernel_size
        self.sigma = sigma

        # Double threshold
        self.low_threshold = low_threshold
        self.high_threshold = high_threshold

        # Hysteresis
        self.hysteresis = hysteresis

        self.eps: float = eps

    def __repr__(self) -> str:
        # List every public instance attribute in alphabetical order.
        return ''.join(
            (
                f'{type(self).__name__}(',
                ', '.join(
                    f'{name}={getattr(self, name)}' for name in sorted(self.__dict__) if not name.startswith('_')
                ),
                ')',
            )
        )

    def forward(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run :func:`canny` on ``input`` with the parameters stored on the module."""
        return canny(
            input, self.low_threshold, self.high_threshold, self.kernel_size, self.sigma, self.hysteresis, self.eps
        )
model/canny/filter.py ADDED
@@ -0,0 +1,288 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+
6
+ from .kernels import normalize_kernel2d
7
+
8
+
9
def _compute_padding(kernel_size: List[int]) -> List[int]:
    """Compute the padding list that keeps a stride-1 convolution output the same size.

    The result is laid out as consecutive ``(front, rear)`` pairs, starting with the
    LAST entry of ``kernel_size`` and moving backwards, which is the order used by
    ``torch.nn.functional.pad``
    (https://pytorch.org/docs/stable/nn.html#torch.nn.functional.pad).

    Args:
        kernel_size: kernel extent per dimension; at least two entries are required.

    Returns:
        a flat list with ``2 * len(kernel_size)`` padding amounts.

    Raises:
        AssertionError: if fewer than two kernel sizes are given.
    """
    if len(kernel_size) < 2:
        raise AssertionError(kernel_size)

    out_padding: List[int] = []
    for extent in reversed(kernel_size):
        total = extent - 1
        # Even kernels cannot be padded symmetrically: the extra element goes to the rear.
        front = total // 2
        out_padding.extend((front, total - front))
    return out_padding
30
+
31
+
32
def filter2d(
    input: torch.Tensor,
    kernel: torch.Tensor,
    border_type: str = 'reflect',
    normalized: bool = False,
    padding: str = 'same',
) -> torch.Tensor:
    r"""Convolve a tensor with a 2d kernel.

    The function applies a given kernel to a tensor. The kernel is applied
    independently at each depth channel of the tensor. With ``padding='same'``
    the input is padded according to ``border_type`` first so that the output
    keeps the input's spatial size; with ``padding='valid'`` no padding is
    applied and the output shrinks by ``kernel_size - 1`` in each direction.

    Args:
        input: the input tensor with shape of :math:`(B, C, H, W)`.
        kernel: the kernel to be convolved with the input tensor.
            The kernel shape must be :math:`(1, kH, kW)` or :math:`(B, kH, kW)`.
        border_type: the padding mode to be applied before convolving.
            The expected modes are: ``'constant'``, ``'reflect'``,
            ``'replicate'`` or ``'circular'``.
        normalized: If True, kernel will be L1 normalized.
        padding: This defines the type of padding.
            2 modes available ``'same'`` or ``'valid'``.

    Return:
        torch.Tensor: the convolved tensor with the same number of channels as
        the input; shape :math:`(B, C, H, W)` for ``'same'`` padding and
        :math:`(B, C, H - kH + 1, W - kW + 1)` for ``'valid'`` padding.

    Raises:
        TypeError: if ``input``/``kernel`` are not tensors or
            ``border_type``/``padding`` are not strings.
        ValueError: if ``border_type`` or ``padding`` is not a supported mode,
            ``input`` is not 4-D, or ``kernel`` is not 3-D with a leading
            dimension of 1 or ``B``.

    Example:
        >>> input = torch.tensor([[[
        ...    [0., 0., 0., 0., 0.],
        ...    [0., 0., 0., 0., 0.],
        ...    [0., 0., 5., 0., 0.],
        ...    [0., 0., 0., 0., 0.],
        ...    [0., 0., 0., 0., 0.],]]])
        >>> kernel = torch.ones(1, 3, 3)
        >>> filter2d(input, kernel, padding='same')
        tensor([[[[0., 0., 0., 0., 0.],
                  [0., 5., 5., 5., 0.],
                  [0., 5., 5., 5., 0.],
                  [0., 5., 5., 5., 0.],
                  [0., 0., 0., 0., 0.]]]])
    """
    if not isinstance(input, torch.Tensor):
        raise TypeError(f"Input input is not torch.Tensor. Got {type(input)}")

    if not isinstance(kernel, torch.Tensor):
        raise TypeError(f"Input kernel is not torch.Tensor. Got {type(kernel)}")

    if not isinstance(border_type, str):
        raise TypeError(f"Input border_type is not string. Got {type(border_type)}")

    if border_type not in ['constant', 'reflect', 'replicate', 'circular']:
        # Adjacent literals instead of a backslash continuation, which would
        # embed the source indentation in the message.
        raise ValueError(
            "Invalid border type, we expect 'constant', 'reflect', 'replicate', 'circular'. "
            f"Got: {border_type}"
        )

    if not isinstance(padding, str):
        raise TypeError(f"Input padding is not string. Got {type(padding)}")

    if padding not in ['valid', 'same']:
        raise ValueError(f"Invalid padding mode, we expect 'valid' or 'same'. Got: {padding}")

    if len(input.shape) != 4:
        raise ValueError(f"Invalid input shape, we expect BxCxHxW. Got: {input.shape}")

    # The kernel must be 3-D AND its leading dimension must be 1 (shared kernel)
    # or B (one kernel per batch element).
    if len(kernel.shape) != 3 or kernel.shape[0] not in (1, input.shape[0]):
        raise ValueError(f"Invalid kernel shape, we expect 1xHxW or BxHxW. Got: {kernel.shape}")

    # prepare kernel: match the input's dtype/device and replicate it per channel
    b, c, h, w = input.shape
    tmp_kernel: torch.Tensor = kernel.unsqueeze(1).to(input)

    if normalized:
        tmp_kernel = normalize_kernel2d(tmp_kernel)

    tmp_kernel = tmp_kernel.expand(-1, c, -1, -1)

    height, width = tmp_kernel.shape[-2:]

    # pad the input tensor
    if padding == 'same':
        padding_shape: List[int] = _compute_padding([height, width])
        input = F.pad(input, padding_shape, mode=border_type)

    # kernel and input tensor reshape to align element-wise or batch-wise params:
    # every (batch-)channel becomes its own convolution group.
    tmp_kernel = tmp_kernel.reshape(-1, 1, height, width)
    # reshape (not view) so that non-contiguous inputs are accepted as well
    input = input.reshape(-1, tmp_kernel.size(0), input.size(-2), input.size(-1))

    # convolve the tensor with the kernel.
    output = F.conv2d(input, tmp_kernel, groups=tmp_kernel.size(0), padding=0, stride=1)

    if padding == 'same':
        return output.view(b, c, h, w)
    return output.view(b, c, h - height + 1, w - width + 1)
133
+
134
+
135
def filter2d_separable(
    input: torch.Tensor,
    kernel_x: torch.Tensor,
    kernel_y: torch.Tensor,
    border_type: str = 'reflect',
    normalized: bool = False,
    padding: str = 'same',
) -> torch.Tensor:
    r"""Convolve a tensor with two 1d kernels, in x and y directions.

    The horizontal kernel is applied first, then the vertical kernel is applied
    to that intermediate result. Both passes are delegated to :func:`filter2d`,
    which applies the kernel independently at each depth channel and pads
    according to ``border_type`` / ``padding``.

    Args:
        input: the input tensor with shape of :math:`(B, C, H, W)`.
        kernel_x: the horizontal kernel. The kernel shape must be
            :math:`(1, kW)` or :math:`(B, kW)`.
        kernel_y: the vertical kernel. The kernel shape must be
            :math:`(1, kH)` or :math:`(B, kH)`.
        border_type: the padding mode to be applied before convolving.
            The expected modes are: ``'constant'``, ``'reflect'``,
            ``'replicate'`` or ``'circular'``.
        normalized: If True, kernel will be L1 normalized.
        padding: This defines the type of padding.
            2 modes available ``'same'`` or ``'valid'``.

    Return:
        torch.Tensor: the convolved tensor of same size and numbers of channels
        as the input with shape :math:`(B, C, H, W)` (for ``'same'`` padding).

    Example:
        >>> input = torch.tensor([[[
        ...    [0., 0., 0., 0., 0.],
        ...    [0., 0., 0., 0., 0.],
        ...    [0., 0., 5., 0., 0.],
        ...    [0., 0., 0., 0., 0.],
        ...    [0., 0., 0., 0., 0.],]]])
        >>> kernel = torch.ones(1, 3)

        >>> filter2d_separable(input, kernel, kernel, padding='same')
        tensor([[[[0., 0., 0., 0., 0.],
                  [0., 5., 5., 5., 0.],
                  [0., 5., 5., 5., 0.],
                  [0., 5., 5., 5., 0.],
                  [0., 0., 0., 0., 0.]]]])
    """
    # Lift the row kernel to (B, 1, kW). Inserting the new axis at -2 keeps the
    # batch dimension first; unsqueeze(0) would turn a (B, kW) kernel into a
    # single (1, B, kW) 2-D kernel. For a (1, kW) kernel the result is unchanged.
    row_kernel = kernel_x.unsqueeze(-2)
    # Lift the column kernel to (B, kH, 1).
    col_kernel = kernel_y.unsqueeze(-1)

    out_x = filter2d(input, row_kernel, border_type, normalized, padding)
    return filter2d(out_x, col_kernel, border_type, normalized, padding)
187
+
188
+
189
def filter3d(
    input: torch.Tensor, kernel: torch.Tensor, border_type: str = 'replicate', normalized: bool = False
) -> torch.Tensor:
    r"""Convolve a tensor with a 3d kernel.

    The function applies a given kernel to a tensor. The kernel is applied
    independently at each depth channel of the tensor. Before applying the
    kernel, the function applies padding according to the specified mode so
    that the output remains in the same shape.

    Args:
        input: the input tensor with shape of :math:`(B, C, D, H, W)`.
        kernel: the kernel to be convolved with the input tensor.
            The kernel shape must be :math:`(1, kD, kH, kW)` or :math:`(B, kD, kH, kW)`.
        border_type: the padding mode to be applied before convolving.
            The expected modes are: ``'constant'``,
            ``'replicate'`` or ``'circular'``.
        normalized: If True, kernel will be L1 normalized.

    Return:
        the convolved tensor of same size and numbers of channels
        as the input with shape :math:`(B, C, D, H, W)`.

    Raises:
        TypeError: if ``input``/``kernel`` are not tensors or ``border_type``
            is not a string.
        ValueError: if ``input`` is not 5-D, or ``kernel`` is not 4-D with a
            leading dimension of 1 or ``B``.

    Example:
        >>> input = torch.tensor([[[
        ...    [[0., 0., 0., 0., 0.],
        ...     [0., 0., 0., 0., 0.],
        ...     [0., 0., 0., 0., 0.],
        ...     [0., 0., 0., 0., 0.],
        ...     [0., 0., 0., 0., 0.]],
        ...    [[0., 0., 0., 0., 0.],
        ...     [0., 0., 0., 0., 0.],
        ...     [0., 0., 5., 0., 0.],
        ...     [0., 0., 0., 0., 0.],
        ...     [0., 0., 0., 0., 0.]],
        ...    [[0., 0., 0., 0., 0.],
        ...     [0., 0., 0., 0., 0.],
        ...     [0., 0., 0., 0., 0.],
        ...     [0., 0., 0., 0., 0.],
        ...     [0., 0., 0., 0., 0.]]
        ... ]]])
        >>> kernel = torch.ones(1, 3, 3, 3)
        >>> filter3d(input, kernel)
        tensor([[[[[0., 0., 0., 0., 0.],
                   [0., 5., 5., 5., 0.],
                   [0., 5., 5., 5., 0.],
                   [0., 5., 5., 5., 0.],
                   [0., 0., 0., 0., 0.]],
        <BLANKLINE>
                  [[0., 0., 0., 0., 0.],
                   [0., 5., 5., 5., 0.],
                   [0., 5., 5., 5., 0.],
                   [0., 5., 5., 5., 0.],
                   [0., 0., 0., 0., 0.]],
        <BLANKLINE>
                  [[0., 0., 0., 0., 0.],
                   [0., 5., 5., 5., 0.],
                   [0., 5., 5., 5., 0.],
                   [0., 5., 5., 5., 0.],
                   [0., 0., 0., 0., 0.]]]]])
    """
    # Each message names the argument that actually failed the check.
    if not isinstance(input, torch.Tensor):
        raise TypeError(f"Input input is not torch.Tensor. Got {type(input)}")

    if not isinstance(kernel, torch.Tensor):
        raise TypeError(f"Input kernel is not torch.Tensor. Got {type(kernel)}")

    if not isinstance(border_type, str):
        raise TypeError(f"Input border_type is not string. Got {type(border_type)}")

    if len(input.shape) != 5:
        raise ValueError(f"Invalid input shape, we expect BxCxDxHxW. Got: {input.shape}")

    # The kernel must be 4-D AND its leading dimension must be 1 (shared kernel)
    # or B (one kernel per batch element).
    if len(kernel.shape) != 4 or kernel.shape[0] not in (1, input.shape[0]):
        raise ValueError(f"Invalid kernel shape, we expect 1xDxHxW or BxDxHxW. Got: {kernel.shape}")

    # prepare kernel: match the input's dtype/device and replicate it per channel
    b, c, d, h, w = input.shape
    tmp_kernel: torch.Tensor = kernel.unsqueeze(1).to(input)

    if normalized:
        bk, dk, hk, wk = kernel.shape
        # normalize_kernel2d sums over the last two dims, so fold (H, W) into one
        # axis to normalise over the whole 3-D kernel volume.
        tmp_kernel = normalize_kernel2d(tmp_kernel.view(bk, dk, hk * wk)).view_as(tmp_kernel)

    tmp_kernel = tmp_kernel.expand(-1, c, -1, -1, -1)

    # pad the input tensor
    depth, height, width = tmp_kernel.shape[-3:]
    padding_shape: List[int] = _compute_padding([depth, height, width])
    input_pad: torch.Tensor = F.pad(input, padding_shape, mode=border_type)

    # kernel and input tensor reshape to align element-wise or batch-wise params:
    # every (batch-)channel becomes its own convolution group.
    tmp_kernel = tmp_kernel.reshape(-1, 1, depth, height, width)
    input_pad = input_pad.view(-1, tmp_kernel.size(0), input_pad.size(-3), input_pad.size(-2), input_pad.size(-1))

    # convolve the tensor with the kernel.
    output = F.conv3d(input_pad, tmp_kernel, groups=tmp_kernel.size(0), padding=0, stride=1)

    return output.view(b, c, d, h, w)
model/canny/gaussian.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Tuple
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ from .filter import filter2d, filter2d_separable
7
+ from .kernels import get_gaussian_kernel1d, get_gaussian_kernel2d
8
+
9
+
10
def gaussian_blur2d(
    input: torch.Tensor,
    kernel_size: Tuple[int, int],
    sigma: Tuple[float, float],
    border_type: str = 'reflect',
    separable: bool = True,
) -> torch.Tensor:
    r"""Blur a tensor with a Gaussian filter.

    The tensor is smoothed by convolving every channel with a Gaussian kernel;
    batched input is supported.

    Arguments:
        input: the input tensor with shape :math:`(B,C,H,W)`.
        kernel_size: the size of the kernel. Entry 0 sizes the vertical (y)
            kernel and entry 1 the horizontal (x) kernel.
        sigma: the standard deviation of the kernel, indexed like ``kernel_size``.
        border_type: the padding mode to be applied before convolving.
            The expected modes are: ``'constant'``, ``'reflect'``,
            ``'replicate'`` or ``'circular'``. Default: ``'reflect'``.
        separable: run as composition of two 1d-convolutions instead of a
            single 2d convolution.

    Returns:
        the blurred tensor with shape :math:`(B, C, H, W)`.

    Examples:
        >>> input = torch.rand(2, 4, 5, 5)
        >>> output = gaussian_blur2d(input, (3, 3), (1.5, 1.5))
        >>> output.shape
        torch.Size([2, 4, 5, 5])
    """
    if not separable:
        # Single pass with the full 2d kernel; [None] adds the leading kernel-batch axis.
        kernel_2d: torch.Tensor = get_gaussian_kernel2d(kernel_size, sigma)
        return filter2d(input, kernel_2d[None], border_type)

    # Two 1d passes: x uses the second entry, y the first entry of each tuple.
    kernel_x: torch.Tensor = get_gaussian_kernel1d(kernel_size[1], sigma[1])
    kernel_y: torch.Tensor = get_gaussian_kernel1d(kernel_size[0], sigma[0])
    return filter2d_separable(input, kernel_x[None], kernel_y[None], border_type)
54
+
55
+
56
class GaussianBlur2d(nn.Module):
    r"""Create an operator that blurs a tensor using a Gaussian filter.

    The operator smooths the given tensor with a gaussian kernel by convolving
    it to each channel. It supports batched operation.

    Arguments:
        kernel_size: the size of the kernel.
        sigma: the standard deviation of the kernel.
        border_type: the padding mode to be applied before convolving.
          The expected modes are: ``'constant'``, ``'reflect'``,
          ``'replicate'`` or ``'circular'``. Default: ``'reflect'``.
        separable: run as composition of two 1d-convolutions.

    Returns:
        the blurred tensor.

    Shape:
        - Input: :math:`(B, C, H, W)`
        - Output: :math:`(B, C, H, W)`

    Examples::

        >>> input = torch.rand(2, 4, 5, 5)
        >>> gauss = GaussianBlur2d((3, 3), (1.5, 1.5))
        >>> output = gauss(input)  # 2x4x5x5
        >>> output.shape
        torch.Size([2, 4, 5, 5])
    """

    def __init__(
        self,
        kernel_size: Tuple[int, int],
        sigma: Tuple[float, float],
        border_type: str = 'reflect',
        separable: bool = True,
    ) -> None:
        """Store the blur parameters; no kernel is built until :meth:`forward` runs."""
        super().__init__()
        self.kernel_size: Tuple[int, int] = kernel_size
        self.sigma: Tuple[float, float] = sigma
        self.border_type = border_type
        self.separable = separable

    def __repr__(self) -> str:
        """Return ``ClassName(kernel_size=..., sigma=..., border_type=..., separable=...)``."""
        # Every field is separated by ", ", including the one before ``separable``.
        return (
            f"{self.__class__.__name__}("
            f"kernel_size={self.kernel_size}, "
            f"sigma={self.sigma}, "
            f"border_type={self.border_type}, "
            f"separable={self.separable})"
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Blur ``input`` by delegating to :func:`gaussian_blur2d` with the stored parameters."""
        return gaussian_blur2d(input, self.kernel_size, self.sigma, self.border_type, self.separable)
model/canny/kernels.py ADDED
@@ -0,0 +1,690 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from math import sqrt
3
+ from typing import List, Optional, Tuple
4
+
5
+ import torch
6
+
7
+
8
def normalize_kernel2d(input: torch.Tensor) -> torch.Tensor:
    r"""L1-normalize a kernel over its last two dimensions.

    Absolute values are summed, so the same routine works for derivative
    kernels (mixed signs) and for smoothing kernels.

    Raises:
        TypeError: if ``input`` has fewer than two dimensions.
    """
    if len(input.size()) < 2:
        raise TypeError(f"input should be at least 2D tensor. Got {input.size()}")
    # Reduce the last axis first, then the (former) second-to-last one; keepdim
    # leaves the result broadcastable against ``input``.
    l1_norm: torch.Tensor = input.abs().sum(dim=-1, keepdim=True).sum(dim=-2, keepdim=True)
    return input / l1_norm
14
+
15
+
16
def gaussian(window_size: int, sigma: float) -> torch.Tensor:
    """Return ``window_size`` Gaussian weights centred on the window and summing to one.

    When ``sigma`` is a tensor, the sample grid is created with its device and
    dtype; otherwise the ``torch.arange`` defaults are used.
    """
    if isinstance(sigma, torch.Tensor):
        device, dtype = sigma.device, sigma.dtype
    else:
        device, dtype = None, None

    offsets = torch.arange(window_size, device=device, dtype=dtype) - window_size // 2
    if window_size % 2 == 0:
        # An even window has no centre sample: shift by half a step to make it symmetric.
        offsets = offsets + 0.5

    weights = torch.exp((-offsets.pow(2.0) / (2 * sigma**2)).float())
    return weights / weights.sum()
26
+
27
+
28
def gaussian_discrete_erf(window_size: int, sigma) -> torch.Tensor:
    r"""Discrete Gaussian by interpolating the error function.

    Each weight is the difference of the error function evaluated half a step
    to either side of the integer sample position; negative values are clamped
    to zero and the weights are normalised to sum to one.

    Adapted from:
    https://github.com/Project-MONAI/MONAI/blob/master/monai/networks/layers/convutils.py

    Args:
        window_size: number of coefficients to produce.
        sigma: standard deviation, a Python number or a tensor.

    Returns:
        1D float tensor with ``window_size`` coefficients, on ``sigma``'s device
        when ``sigma`` is a tensor.
    """
    device = sigma.device if isinstance(sigma, torch.Tensor) else None
    sigma = torch.as_tensor(sigma, dtype=torch.float, device=device)
    # Build the sample grid on sigma's device so that the products below never
    # mix devices when sigma is not a CPU tensor.
    x = torch.arange(window_size, device=sigma.device).float() - window_size // 2
    t = 0.70710678 / torch.abs(sigma)  # 1 / (sqrt(2) * |sigma|)
    gauss = 0.5 * ((t * (x + 0.5)).erf() - (t * (x - 0.5)).erf())
    gauss = gauss.clamp(min=0)
    return gauss / gauss.sum()
41
+
42
+
43
def _modified_bessel_0(x: torch.Tensor) -> torch.Tensor:
    r"""Polynomial approximation of the modified Bessel function of order 0.

    Two fits are used, switching at ``|x| = 3.75``. ``x`` is used in a boolean
    test, so it must be a single-element tensor.

    Adapted from:
    https://github.com/Project-MONAI/MONAI/blob/master/monai/networks/layers/convutils.py
    """
    magnitude = torch.abs(x)

    if magnitude < 3.75:
        # Small-argument fit: polynomial in (x / 3.75)^2, lowest order first.
        coeffs = (1.0, 3.5156229, 3.0899424, 1.2067492, 0.2659732, 0.360768e-1, 0.45813e-2)
        y = (x / 3.75) * (x / 3.75)
        acc = coeffs[-1]
        for coeff in reversed(coeffs[:-1]):  # Horner evaluation, innermost term first
            acc = coeff + y * acc
        return acc

    # Large-argument fit: polynomial in 3.75 / |x|, scaled by exp(|x|) / sqrt(|x|).
    coeffs = (
        0.39894228,
        0.1328592e-1,
        0.225319e-2,
        -0.157565e-2,
        0.916281e-2,
        -0.2057706e-1,
        0.2635537e-1,
        -0.1647633e-1,
        0.392377e-2,
    )
    y = 3.75 / magnitude
    acc = coeffs[-1]
    for coeff in reversed(coeffs[:-1]):
        acc = coeff + y * acc
    return (torch.exp(magnitude) / torch.sqrt(magnitude)) * acc
58
+
59
+
60
def _modified_bessel_1(x: torch.Tensor) -> torch.Tensor:
    r"""Polynomial approximation of the modified Bessel function of order 1.

    Two fits are used, switching at ``|x| = 3.75``. ``x`` is used in boolean
    tests, so it must be a single-element tensor. Note that the small-argument
    branch multiplies by ``|x|`` and therefore never returns a negative value;
    only the large-argument branch restores the sign of ``x``.

    Adapted from:
    https://github.com/Project-MONAI/MONAI/blob/master/monai/networks/layers/convutils.py
    """
    magnitude = torch.abs(x)

    if magnitude < 3.75:
        # Small-argument fit in (x / 3.75)^2.
        y = (x / 3.75) * (x / 3.75)
        inner = 0.51498869 + y * (0.15084934 + y * (0.2658733e-1 + y * (0.301532e-2 + y * 0.32411e-3)))
        return magnitude * (0.5 + y * (0.87890594 + y * inner))

    # Large-argument fit in 3.75 / |x|, built in two nested stages.
    y = 3.75 / magnitude
    inner = 0.2282967e-1 + y * (-0.2895312e-1 + y * (0.1787654e-1 - y * 0.420059e-2))
    outer = 0.39894228 + y * (
        -0.3988024e-1 + y * (-0.362018e-2 + y * (0.163801e-2 + y * (-0.1031555e-1 + y * inner)))
    )
    value = outer * torch.exp(magnitude) / torch.sqrt(magnitude)
    if x < 0.0:
        return -value
    return value
75
+
76
+
77
def _modified_bessel_i(n: int, x: torch.Tensor) -> torch.Tensor:
    r"""Modified Bessel function of integer order ``n >= 2`` via downward recurrence.

    The recurrence starts from an order well above ``n``, is rescaled whenever
    the running value exceeds ``1e10``, and is finally normalised with
    :func:`_modified_bessel_0`. ``x`` is used in boolean tests, so it must be a
    single-element tensor.

    Adapted from:
    https://github.com/Project-MONAI/MONAI/blob/master/monai/networks/layers/convutils.py

    Raises:
        ValueError: if ``n < 2``.
    """
    if n < 2:
        raise ValueError("n must be greater than 1.")
    if x == 0.0:
        return x

    dev = x.device
    two_over_x = 2.0 / torch.abs(x)
    result = torch.tensor(0.0, device=dev)
    previous = torch.tensor(0.0, device=dev)
    current = torch.tensor(1.0, device=dev)

    start_order = int(2 * (n + int(sqrt(40.0 * n))))
    for j in range(start_order, 0, -1):
        # One step of the downward recurrence (right-hand side uses the old values).
        previous, current = current, previous + float(j) * two_over_x * current
        if abs(current) > 1.0e10:
            # Rescale all running quantities together to avoid overflow.
            result = result * 1.0e-10
            current = current * 1.0e-10
            previous = previous * 1.0e-10
        if j == n:
            result = previous

    result = result * _modified_bessel_0(x) / current
    if x < 0.0 and (n % 2) == 1:
        return -result
    return result
104
+
105
+
106
def gaussian_discrete(window_size, sigma) -> torch.Tensor:
    r"""Discrete Gaussian kernel based on the modified Bessel functions.

    Adapted from:
    https://github.com/Project-MONAI/MONAI/blob/master/monai/networks/layers/convutils.py

    Args:
        window_size: requested window size; only ``window_size // 2`` (the
            one-sided tail length) is used.
        sigma: standard deviation, a Python number or a scalar tensor.

    Returns:
        1D tensor with ``2 * (window_size // 2) + 1`` symmetric coefficients
        that sum to one. For an odd ``window_size`` this equals ``window_size``.
    """
    device = sigma.device if isinstance(sigma, torch.Tensor) else None
    sigma = torch.as_tensor(sigma, dtype=torch.float, device=device)
    sigma2 = sigma * sigma
    tail = int(window_size // 2)

    # One-sided coefficients for orders 0..tail. Order 1 is only evaluated when
    # the window has a tail, so a window of size 1 yields the single centre tap
    # instead of raising IndexError.
    out_pos: List[torch.Tensor] = [_modified_bessel_0(sigma2)]
    if tail >= 1:
        out_pos.append(_modified_bessel_1(sigma2))
    out_pos.extend(_modified_bessel_i(k, sigma2) for k in range(2, tail + 1))

    # Mirror the positive side (without the centre) in front of it.
    out = out_pos[:0:-1] + out_pos
    # The exp(sigma2) factor is common to all taps and cancels in the
    # normalisation below; it is kept as in the original implementation.
    # NOTE(review): the discrete Gaussian is usually written with exp(-sigma2);
    # confirm whether the sign matters for large sigma (overflow).
    stacked = torch.stack(out) * torch.exp(sigma2)
    return stacked / stacked.sum()
125
+
126
+
127
def laplacian_1d(window_size) -> torch.Tensor:
    r"""Return a 1D Laplacian filter: all ones with ``1 - window_size`` at the centre.

    One could also use the Laplacian of Gaussian formula to design the filter.
    """
    kernel: torch.Tensor = torch.ones(window_size)
    # The centre tap balances the surrounding ones so the filter sums to zero.
    kernel[window_size // 2] = 1 - window_size
    return kernel
134
+
135
+
136
def get_box_kernel2d(kernel_size: Tuple[int, int]) -> torch.Tensor:
    r"""Return a box (mean) filter of shape ``(1, kernel_size[0], kernel_size[1])``.

    Every entry equals ``1 / (kernel_size[0] * kernel_size[1])``.
    """
    rows: float = float(kernel_size[0])
    cols: float = float(kernel_size[1])
    ones: torch.Tensor = torch.ones(1, kernel_size[0], kernel_size[1])
    inv_area: torch.Tensor = torch.tensor(1.0) / torch.tensor([rows * cols])
    return inv_area.to(ones.dtype) * ones
143
+
144
+
145
def get_binary_kernel2d(window_size: Tuple[int, int]) -> torch.Tensor:
    r"""Create a binary kernel to extract the patches.

    For a window of size HxW the result has shape ``(H*W, 1, H, W)``: kernel
    ``i`` is all zeros except for a single one at flat position ``i``.
    """
    num_taps: int = window_size[0] * window_size[1]
    # An identity matrix has exactly one 1 per row, at the row's own index.
    one_hot: torch.Tensor = torch.eye(num_taps)
    return one_hot.view(num_taps, 1, window_size[0], window_size[1])
155
+
156
+
157
def get_sobel_kernel_3x3() -> torch.Tensor:
    """Return the 3x3 Sobel kernel whose rows difference the left and right columns."""
    outer_row = [-1.0, 0.0, 1.0]
    centre_row = [-2.0, 0.0, 2.0]  # the centre row carries twice the weight
    return torch.tensor([outer_row, centre_row, outer_row])
160
+
161
+
162
def get_sobel_kernel_5x5_2nd_order() -> torch.Tensor:
    """Return the 5x5 second-order Sobel kernel.

    Every row is the profile ``[-1, 0, 2, 0, -1]`` scaled by the row weights
    ``1, 4, 6, 4, 1``.
    """
    profile = (-1.0, 0.0, 2.0, 0.0, -1.0)
    row_weights = (1.0, 4.0, 6.0, 4.0, 1.0)
    return torch.tensor([[weight * value for value in profile] for weight in row_weights])
173
+
174
+
175
def _get_sobel_kernel_5x5_2nd_order_xy() -> torch.Tensor:
    """Return the 5x5 mixed (xy) second-order Sobel kernel.

    The kernel is antisymmetric about both its middle row and its middle column.
    """
    first_row = [-1.0, -2.0, 0.0, 2.0, 1.0]
    second_row = [-2.0, -4.0, 0.0, 4.0, 2.0]
    middle_row = [0.0, 0.0, 0.0, 0.0, 0.0]
    fourth_row = [2.0, 4.0, 0.0, -4.0, -2.0]
    fifth_row = [1.0, 2.0, 0.0, -2.0, -1.0]
    return torch.tensor([first_row, second_row, middle_row, fourth_row, fifth_row])
186
+
187
+
188
def get_diff_kernel_3x3() -> torch.Tensor:
    """Return a 3x3 first-order difference kernel (``[-1, 0, 1]`` on the middle row only)."""
    blank_row = [-0.0, 0.0, 0.0]
    difference_row = [-1.0, 0.0, 1.0]
    return torch.tensor([blank_row, difference_row, blank_row])
191
+
192
+
193
def get_diff_kernel3d(device=torch.device('cpu'), dtype=torch.float) -> torch.Tensor:
    """Return three first-order central-difference kernels of size 3x3x3.

    Kernel 0 differences along the last kernel axis, kernel 1 along the middle
    axis and kernel 2 along the first axis, each with taps ``-0.5`` and ``0.5``
    around the centre.

    Returns:
        tensor of shape ``(3, 1, 3, 3, 3)`` on ``device`` with ``dtype``.
    """
    empty_plane = [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]

    along_last = [
        empty_plane,
        [[0.0, 0.0, 0.0], [-0.5, 0.0, 0.5], [0.0, 0.0, 0.0]],
        empty_plane,
    ]
    along_middle = [
        empty_plane,
        [[0.0, -0.5, 0.0], [0.0, 0.0, 0.0], [0.0, 0.5, 0.0]],
        empty_plane,
    ]
    along_first = [
        [[0.0, 0.0, 0.0], [0.0, -0.5, 0.0], [0.0, 0.0, 0.0]],
        empty_plane,
        [[0.0, 0.0, 0.0], [0.0, 0.5, 0.0], [0.0, 0.0, 0.0]],
    ]

    kernel: torch.Tensor = torch.tensor([along_last, along_middle, along_first], device=device, dtype=dtype)
    # Insert a singleton axis after the kernel index.
    return kernel.unsqueeze(1)
217
+
218
+
219
def get_diff_kernel3d_2nd_order(device=torch.device('cpu'), dtype=torch.float) -> torch.Tensor:
    """Return six second-order difference kernels of size 3x3x3.

    Kernels 0-2 are the pure second differences ``[1, -2, 1]`` along the last,
    middle and first kernel axis. Kernels 3-5 are the mixed differences built
    from the corner pattern ``[[1, -1], [-1, 1]]`` in the three axis pairs.

    Returns:
        tensor of shape ``(6, 1, 3, 3, 3)`` on ``device`` with ``dtype``.
    """
    empty_plane = [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]

    pure_last = [
        empty_plane,
        [[0.0, 0.0, 0.0], [1.0, -2.0, 1.0], [0.0, 0.0, 0.0]],
        empty_plane,
    ]
    pure_middle = [
        empty_plane,
        [[0.0, 1.0, 0.0], [0.0, -2.0, 0.0], [0.0, 1.0, 0.0]],
        empty_plane,
    ]
    pure_first = [
        [[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]],
        [[0.0, 0.0, 0.0], [0.0, -2.0, 0.0], [0.0, 0.0, 0.0]],
        [[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]],
    ]
    mixed_middle_last = [
        empty_plane,
        [[1.0, 0.0, -1.0], [0.0, 0.0, 0.0], [-1.0, 0.0, 1.0]],
        empty_plane,
    ]
    mixed_first_middle = [
        [[0.0, 1.0, 0.0], [0.0, 0.0, 0.0], [0.0, -1.0, 0.0]],
        empty_plane,
        [[0.0, -1.0, 0.0], [0.0, 0.0, 0.0], [0.0, 1.0, 0.0]],
    ]
    mixed_first_last = [
        [[0.0, 0.0, 0.0], [1.0, 0.0, -1.0], [0.0, 0.0, 0.0]],
        empty_plane,
        [[0.0, 0.0, 0.0], [-1.0, 0.0, 1.0], [0.0, 0.0, 0.0]],
    ]

    kernel: torch.Tensor = torch.tensor(
        [pure_last, pure_middle, pure_first, mixed_middle_last, mixed_first_middle, mixed_first_last],
        device=device,
        dtype=dtype,
    )
    # Insert a singleton axis after the kernel index.
    return kernel.unsqueeze(1)
258
+
259
+
260
def get_sobel_kernel2d() -> torch.Tensor:
    """Return the first-order Sobel kernels stacked as ``[kernel_x, kernel_y]`` (shape ``(2, 3, 3)``)."""
    grad_x: torch.Tensor = get_sobel_kernel_3x3()
    # The y kernel is the x kernel with rows and columns swapped.
    grad_y: torch.Tensor = grad_x.transpose(0, 1)
    return torch.stack((grad_x, grad_y))
264
+
265
+
266
def get_diff_kernel2d() -> torch.Tensor:
    """Return the first-order difference kernels stacked as ``[kernel_x, kernel_y]`` (shape ``(2, 3, 3)``)."""
    diff_x: torch.Tensor = get_diff_kernel_3x3()
    # The y kernel is the x kernel with rows and columns swapped.
    diff_y: torch.Tensor = diff_x.transpose(0, 1)
    return torch.stack((diff_x, diff_y))
270
+
271
+
272
def get_sobel_kernel2d_2nd_order() -> torch.Tensor:
    """Return the second-order Sobel kernels stacked as ``[gxx, gxy, gyy]`` (shape ``(3, 5, 5)``)."""
    second_x: torch.Tensor = get_sobel_kernel_5x5_2nd_order()
    mixed: torch.Tensor = _get_sobel_kernel_5x5_2nd_order_xy()
    # gyy is gxx with rows and columns swapped.
    second_y: torch.Tensor = second_x.transpose(0, 1)
    return torch.stack((second_x, mixed, second_y))
277
+
278
+
279
def get_diff_kernel2d_2nd_order() -> torch.Tensor:
    """Return the second-order difference kernels stacked as ``[gxx, gxy, gyy]`` (shape ``(3, 3, 3)``)."""
    zero_row = [0.0, 0.0, 0.0]
    second_x: torch.Tensor = torch.tensor([zero_row, [1.0, -2.0, 1.0], zero_row])
    mixed: torch.Tensor = torch.tensor([[-1.0, 0.0, 1.0], zero_row, [1.0, 0.0, -1.0]])
    # gyy is gxx with rows and columns swapped.
    second_y: torch.Tensor = second_x.transpose(0, 1)
    return torch.stack((second_x, mixed, second_y))
284
+
285
+
286
def get_spatial_gradient_kernel2d(mode: str, order: int) -> torch.Tensor:
    r"""Return the kernel for 1st or 2nd order image gradients.

    Args:
        mode: derivative operator, ``'sobel'`` or ``'diff'``.
        order: derivative order, ``1`` or ``2``.

    Returns:
        the stacked kernels produced by the matching ``get_*_kernel2d*`` helper.

    Raises:
        TypeError: if ``mode`` or ``order`` is not one of the supported values
            (the exception type is kept for backward compatibility).
    """
    # Plain f-strings: a backslash continuation inside the literal would embed
    # the source indentation in the message.
    if mode not in ['sobel', 'diff']:
        raise TypeError(f"mode should be either sobel or diff. Got {mode}")
    if order not in [1, 2]:
        raise TypeError(f"order should be either 1 or 2. Got {order}")

    # Both arguments are validated above, so these four cases are exhaustive.
    if mode == 'sobel':
        return get_sobel_kernel2d() if order == 1 else get_sobel_kernel2d_2nd_order()
    return get_diff_kernel2d() if order == 1 else get_diff_kernel2d_2nd_order()
316
+
317
+
318
def get_spatial_gradient_kernel3d(mode: str, order: int, device=torch.device('cpu'), dtype=torch.float) -> torch.Tensor:
    r"""Return the kernel for 1st or 2nd order scale pyramid gradients.

    Args:
        mode: derivative operator, ``'sobel'`` or ``'diff'``. Only ``'diff'``
            is implemented.
        order: derivative order, ``1`` or ``2``.
        device: device on which the kernel is created.
        dtype: dtype of the kernel.

    Returns:
        the kernel built by :func:`get_diff_kernel3d` (order 1) or
        :func:`get_diff_kernel3d_2nd_order` (order 2).

    Raises:
        TypeError: if ``mode`` or ``order`` is not one of the supported values
            (the exception type is kept for backward compatibility).
        NotImplementedError: if ``mode`` is ``'sobel'``.
    """
    # Plain f-strings: a backslash continuation inside the literal would embed
    # the source indentation in the message.
    if mode not in ['sobel', 'diff']:
        raise TypeError(f"mode should be either sobel or diff. Got {mode}")
    if order not in [1, 2]:
        raise TypeError(f"order should be either 1 or 2. Got {order}")
    if mode == 'sobel':
        raise NotImplementedError("Sobel kernel for 3d gradient is not implemented yet")

    # Only mode == 'diff' with a validated order remains.
    if order == 1:
        return get_diff_kernel3d(device, dtype)
    return get_diff_kernel3d_2nd_order(device, dtype)
344
+
345
+
346
def get_gaussian_kernel1d(kernel_size: int, sigma: float, force_even: bool = False) -> torch.Tensor:
    r"""Return the coefficients of a 1D Gaussian filter.

    Args:
        kernel_size: filter size. It should be odd and positive.
        sigma: gaussian standard deviation.
        force_even: overrides requirement for odd kernel size.

    Returns:
        1D tensor with gaussian filter coefficients.

    Raises:
        TypeError: if ``kernel_size`` is not a positive integer, or is even
            while ``force_even`` is False.

    Shape:
        - Output: :math:`(\text{kernel_size})`

    Examples:

        >>> get_gaussian_kernel1d(3, 2.5)
        tensor([0.3243, 0.3513, 0.3243])

        >>> get_gaussian_kernel1d(5, 1.5)
        tensor([0.1201, 0.2339, 0.2921, 0.2339, 0.1201])
    """
    size_ok = isinstance(kernel_size, int) and kernel_size > 0 and (force_even or kernel_size % 2 != 0)
    if not size_ok:
        raise TypeError(f"kernel_size must be an odd positive integer. Got {kernel_size}")
    return gaussian(kernel_size, sigma)
372
+
373
+
374
def get_gaussian_discrete_kernel1d(kernel_size: int, sigma: float, force_even: bool = False) -> torch.Tensor:
    r"""Return Gaussian filter coefficients based on the modified Bessel functions.

    Adapted from:
    https://github.com/Project-MONAI/MONAI/blob/master/monai/networks/layers/convutils.py.

    Args:
        kernel_size: filter size. It should be odd and positive.
        sigma: gaussian standard deviation.
        force_even: overrides requirement for odd kernel size.

    Returns:
        1D tensor with gaussian filter coefficients.

    Raises:
        TypeError: if ``kernel_size`` is not a positive integer, or is even
            while ``force_even`` is False.

    Shape:
        - Output: :math:`(\text{kernel_size})`

    Examples:

        >>> get_gaussian_discrete_kernel1d(3, 2.5)
        tensor([0.3235, 0.3531, 0.3235])

        >>> get_gaussian_discrete_kernel1d(5, 1.5)
        tensor([0.1096, 0.2323, 0.3161, 0.2323, 0.1096])
    """
    size_ok = isinstance(kernel_size, int) and kernel_size > 0 and (force_even or kernel_size % 2 != 0)
    if not size_ok:
        raise TypeError(f"kernel_size must be an odd positive integer. Got {kernel_size}")
    return gaussian_discrete(kernel_size, sigma)
401
+
402
+
403
def get_gaussian_erf_kernel1d(kernel_size: int, sigma: float, force_even: bool = False) -> torch.Tensor:
    r"""Return Gaussian filter coefficients obtained by interpolating the error function.

    Adapted from:
    https://github.com/Project-MONAI/MONAI/blob/master/monai/networks/layers/convutils.py.

    Args:
        kernel_size: filter size. It should be odd and positive.
        sigma: gaussian standard deviation.
        force_even: overrides requirement for odd kernel size.

    Returns:
        1D tensor with gaussian filter coefficients.

    Raises:
        TypeError: if ``kernel_size`` is not a positive integer, or is even
            while ``force_even`` is False.

    Shape:
        - Output: :math:`(\text{kernel_size})`

    Examples:

        >>> get_gaussian_erf_kernel1d(3, 2.5)
        tensor([0.3245, 0.3511, 0.3245])

        >>> get_gaussian_erf_kernel1d(5, 1.5)
        tensor([0.1226, 0.2331, 0.2887, 0.2331, 0.1226])
    """
    size_ok = isinstance(kernel_size, int) and kernel_size > 0 and (force_even or kernel_size % 2 != 0)
    if not size_ok:
        raise TypeError(f"kernel_size must be an odd positive integer. Got {kernel_size}")
    return gaussian_discrete_erf(kernel_size, sigma)
430
+
431
+
432
+ def get_gaussian_kernel2d(
433
+ kernel_size: Tuple[int, int], sigma: Tuple[float, float], force_even: bool = False
434
+ ) -> torch.Tensor:
435
+ r"""Function that returns Gaussian filter matrix coefficients.
436
+
437
+ Args:
438
+ kernel_size: filter sizes in the x and y direction.
439
+ Sizes should be odd and positive.
440
+ sigma: gaussian standard deviation in the x and y
441
+ direction.
442
+ force_even: overrides requirement for odd kernel size.
443
+
444
+ Returns:
445
+ 2D tensor with gaussian filter matrix coefficients.
446
+
447
+ Shape:
448
+ - Output: :math:`(\text{kernel_size}_x, \text{kernel_size}_y)`
449
+
450
+ Examples:
451
+ >>> get_gaussian_kernel2d((3, 3), (1.5, 1.5))
452
+ tensor([[0.0947, 0.1183, 0.0947],
453
+ [0.1183, 0.1478, 0.1183],
454
+ [0.0947, 0.1183, 0.0947]])
455
+ >>> get_gaussian_kernel2d((3, 5), (1.5, 1.5))
456
+ tensor([[0.0370, 0.0720, 0.0899, 0.0720, 0.0370],
457
+ [0.0462, 0.0899, 0.1123, 0.0899, 0.0462],
458
+ [0.0370, 0.0720, 0.0899, 0.0720, 0.0370]])
459
+ """
460
+ if not isinstance(kernel_size, tuple) or len(kernel_size) != 2:
461
+ raise TypeError(f"kernel_size must be a tuple of length two. Got {kernel_size}")
462
+ if not isinstance(sigma, tuple) or len(sigma) != 2:
463
+ raise TypeError(f"sigma must be a tuple of length two. Got {sigma}")
464
+ ksize_x, ksize_y = kernel_size
465
+ sigma_x, sigma_y = sigma
466
+ kernel_x: torch.Tensor = get_gaussian_kernel1d(ksize_x, sigma_x, force_even)
467
+ kernel_y: torch.Tensor = get_gaussian_kernel1d(ksize_y, sigma_y, force_even)
468
+ kernel_2d: torch.Tensor = torch.matmul(kernel_x.unsqueeze(-1), kernel_y.unsqueeze(-1).t())
469
+ return kernel_2d
470
+
471
+
472
+ def get_laplacian_kernel1d(kernel_size: int) -> torch.Tensor:
473
+ r"""Function that returns the coefficients of a 1D Laplacian filter.
474
+
475
+ Args:
476
+ kernel_size: filter size. It should be odd and positive.
477
+
478
+ Returns:
479
+ 1D tensor with laplacian filter coefficients.
480
+
481
+ Shape:
482
+ - Output: :math:`(\text{kernel_size})`
483
+
484
+ Examples:
485
+ >>> get_laplacian_kernel1d(3)
486
+ tensor([ 1., -2., 1.])
487
+ >>> get_laplacian_kernel1d(5)
488
+ tensor([ 1., 1., -4., 1., 1.])
489
+ """
490
+ if not isinstance(kernel_size, int) or kernel_size % 2 == 0 or kernel_size <= 0:
491
+ raise TypeError(f"ksize must be an odd positive integer. Got {kernel_size}")
492
+ window_1d: torch.Tensor = laplacian_1d(kernel_size)
493
+ return window_1d
494
+
495
+
496
+ def get_laplacian_kernel2d(kernel_size: int) -> torch.Tensor:
497
+ r"""Function that returns Laplacian filter matrix coefficients.
498
+
499
+ Args:
500
+ kernel_size: filter size should be odd.
501
+
502
+ Returns:
503
+ 2D tensor with laplacian filter matrix coefficients.
504
+
505
+ Shape:
506
+ - Output: :math:`(\text{kernel_size}_x, \text{kernel_size}_y)`
507
+
508
+ Examples:
509
+ >>> get_laplacian_kernel2d(3)
510
+ tensor([[ 1., 1., 1.],
511
+ [ 1., -8., 1.],
512
+ [ 1., 1., 1.]])
513
+ >>> get_laplacian_kernel2d(5)
514
+ tensor([[ 1., 1., 1., 1., 1.],
515
+ [ 1., 1., 1., 1., 1.],
516
+ [ 1., 1., -24., 1., 1.],
517
+ [ 1., 1., 1., 1., 1.],
518
+ [ 1., 1., 1., 1., 1.]])
519
+ """
520
+ if not isinstance(kernel_size, int) or kernel_size % 2 == 0 or kernel_size <= 0:
521
+ raise TypeError(f"ksize must be an odd positive integer. Got {kernel_size}")
522
+
523
+ kernel = torch.ones((kernel_size, kernel_size))
524
+ mid = kernel_size // 2
525
+ kernel[mid, mid] = 1 - kernel_size**2
526
+ kernel_2d: torch.Tensor = kernel
527
+ return kernel_2d
528
+
529
+
530
+ def get_pascal_kernel_2d(kernel_size: int, norm: bool = True) -> torch.Tensor:
531
+ """Generate pascal filter kernel by kernel size.
532
+
533
+ Args:
534
+ kernel_size: height and width of the (square) kernel.
535
+ norm: if to normalize the kernel or not. Default: True.
536
+
537
+ Returns:
538
+ kernel shaped as :math:`(kernel_size, kernel_size)`
539
+
540
+ Examples:
541
+ >>> get_pascal_kernel_2d(1)
542
+ tensor([[1.]])
543
+ >>> get_pascal_kernel_2d(4)
544
+ tensor([[0.0156, 0.0469, 0.0469, 0.0156],
545
+ [0.0469, 0.1406, 0.1406, 0.0469],
546
+ [0.0469, 0.1406, 0.1406, 0.0469],
547
+ [0.0156, 0.0469, 0.0469, 0.0156]])
548
+ >>> get_pascal_kernel_2d(4, norm=False)
549
+ tensor([[1., 3., 3., 1.],
550
+ [3., 9., 9., 3.],
551
+ [3., 9., 9., 3.],
552
+ [1., 3., 3., 1.]])
553
+ """
554
+ a = get_pascal_kernel_1d(kernel_size)
555
+
556
+ filt = a[:, None] * a[None, :]
557
+ if norm:
558
+ filt = filt / torch.sum(filt)
559
+ return filt
560
+
561
+
562
+ def get_pascal_kernel_1d(kernel_size: int, norm: bool = False) -> torch.Tensor:
563
+ """Generate Yang Hui triangle (Pascal's triangle) by a given number.
564
+
565
+ Args:
566
+ kernel_size: length of the 1D kernel.
567
+ norm: if to normalize the kernel or not. Default: False.
568
+
569
+ Returns:
570
+ kernel shaped as :math:`(kernel_size,)`
571
+
572
+ Examples:
573
+ >>> get_pascal_kernel_1d(1)
574
+ tensor([1.])
575
+ >>> get_pascal_kernel_1d(2)
576
+ tensor([1., 1.])
577
+ >>> get_pascal_kernel_1d(3)
578
+ tensor([1., 2., 1.])
579
+ >>> get_pascal_kernel_1d(4)
580
+ tensor([1., 3., 3., 1.])
581
+ >>> get_pascal_kernel_1d(5)
582
+ tensor([1., 4., 6., 4., 1.])
583
+ >>> get_pascal_kernel_1d(6)
584
+ tensor([ 1., 5., 10., 10., 5., 1.])
585
+ """
586
+ pre: List[float] = []
587
+ cur: List[float] = []
588
+ for i in range(kernel_size):
589
+ cur = [1.0] * (i + 1)
590
+
591
+ for j in range(1, i // 2 + 1):
592
+ value = pre[j - 1] + pre[j]
593
+ cur[j] = value
594
+ if i != 2 * j:
595
+ cur[-j - 1] = value
596
+ pre = cur
597
+
598
+ out = torch.as_tensor(cur)
599
+ if norm:
600
+ out = out / torch.sum(out)
601
+ return out
602
+
603
+
604
+ def get_canny_nms_kernel(device=torch.device('cpu'), dtype=torch.float) -> torch.Tensor:
605
+ """Utility function that returns 3x3 kernels for the Canny Non-maximal suppression."""
606
+ kernel: torch.Tensor = torch.tensor(
607
+ [
608
+ [[0.0, 0.0, 0.0], [0.0, 1.0, -1.0], [0.0, 0.0, 0.0]],
609
+ [[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, -1.0]],
610
+ [[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, -1.0, 0.0]],
611
+ [[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [-1.0, 0.0, 0.0]],
612
+ [[0.0, 0.0, 0.0], [-1.0, 1.0, 0.0], [0.0, 0.0, 0.0]],
613
+ [[-1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]],
614
+ [[0.0, -1.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]],
615
+ [[0.0, 0.0, -1.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]],
616
+ ],
617
+ device=device,
618
+ dtype=dtype,
619
+ )
620
+ return kernel.unsqueeze(1)
621
+
622
+
623
+ def get_hysteresis_kernel(device=torch.device('cpu'), dtype=torch.float) -> torch.Tensor:
624
+ """Utility function that returns the 3x3 kernels for the Canny hysteresis."""
625
+ kernel: torch.Tensor = torch.tensor(
626
+ [
627
+ [[0.0, 0.0, 0.0], [0.0, 0.0, 1.0], [0.0, 0.0, 0.0]],
628
+ [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 1.0]],
629
+ [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 1.0, 0.0]],
630
+ [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [1.0, 0.0, 0.0]],
631
+ [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
632
+ [[1.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
633
+ [[0.0, 1.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
634
+ [[0.0, 0.0, 1.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
635
+ ],
636
+ device=device,
637
+ dtype=dtype,
638
+ )
639
+ return kernel.unsqueeze(1)
640
+
641
+
642
+ def get_hanning_kernel1d(kernel_size: int, device=torch.device('cpu'), dtype=torch.float) -> torch.Tensor:
643
+ r"""Returns Hanning (also known as Hann) kernel, used in signal processing and KCF tracker.
644
+
645
+ .. math:: w(n) = 0.5 - 0.5\cos\left(\frac{2\pi{n}}{M-1}\right)
645
+ \qquad 0 \leq n \leq M-1
647
+
648
+ See further in numpy docs https://numpy.org/doc/stable/reference/generated/numpy.hanning.html
649
+
650
+ Args:
651
+ kernel_size: The size of the kernel. It should be positive.
652
+
653
+ Returns:
654
+ 1D tensor with Hanning filter coefficients.
655
+ .. math:: w(n) = 0.5 - 0.5\cos\left(\frac{2\pi{n}}{M-1}\right)
656
+
657
+ Shape:
658
+ - Output: :math:`(\text{kernel_size})`
659
+
660
+ Examples:
661
+ >>> get_hanning_kernel1d(4)
662
+ tensor([0.0000, 0.7500, 0.7500, 0.0000])
663
+ """
664
+ if not isinstance(kernel_size, int) or kernel_size <= 2:
665
+ raise TypeError(f"ksize must be an positive integer > 2. Got {kernel_size}")
666
+
667
+ x: torch.Tensor = torch.arange(kernel_size, device=device, dtype=dtype)
668
+ x = 0.5 - 0.5 * torch.cos(2.0 * math.pi * x / float(kernel_size - 1))
669
+ return x
670
+
671
+
672
+ def get_hanning_kernel2d(kernel_size: Tuple[int, int], device=torch.device('cpu'), dtype=torch.float) -> torch.Tensor:
673
+ r"""Returns 2d Hanning kernel, used in signal processing and KCF tracker.
674
+
675
+ Args:
676
+ kernel_size: The size of the kernel for the filter. It should be positive.
677
+
678
+ Returns:
679
+ 2D tensor with Hanning filter coefficients.
680
+ .. math:: w(n) = 0.5 - 0.5\cos\left(\frac{2\pi{n}}{M-1}\right)
681
+
682
+ Shape:
683
+ - Output: :math:`(\text{kernel_size[0], kernel_size[1]})`
684
+ """
685
+ if kernel_size[0] <= 2 or kernel_size[1] <= 2:
686
+ raise TypeError(f"ksize must be an tuple of positive integers > 2. Got {kernel_size}")
687
+ ky: torch.Tensor = get_hanning_kernel1d(kernel_size[0], device, dtype)[None].T
688
+ kx: torch.Tensor = get_hanning_kernel1d(kernel_size[1], device, dtype)[None]
689
+ kernel2d = ky @ kx
690
+ return kernel2d
model/canny/sobel.py ADDED
@@ -0,0 +1,263 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from .kernels import get_spatial_gradient_kernel2d, get_spatial_gradient_kernel3d, normalize_kernel2d
6
+
7
+
8
+ def spatial_gradient(input: torch.Tensor, mode: str = 'sobel', order: int = 1, normalized: bool = True) -> torch.Tensor:
9
+ r"""Compute the first order image derivative in both x and y using a Sobel operator.
10
+
11
+ .. image:: _static/img/spatial_gradient.png
12
+
13
+ Args:
14
+ input: input image tensor with shape :math:`(B, C, H, W)`.
15
+ mode: derivatives modality, can be: `sobel` or `diff`.
16
+ order: the order of the derivatives.
17
+ normalized: whether the output is normalized.
18
+
19
+ Return:
20
+ the derivatives of the input feature map, with shape :math:`(B, C, 2, H, W)` (or :math:`(B, C, 3, H, W)` when ``order`` is 2).
21
+
22
+ .. note::
23
+ See a working example `here <https://kornia-tutorials.readthedocs.io/en/latest/
24
+ filtering_edges.html>`__.
25
+
26
+ Examples:
27
+ >>> input = torch.rand(1, 3, 4, 4)
28
+ >>> output = spatial_gradient(input) # 1x3x2x4x4
29
+ >>> output.shape
30
+ torch.Size([1, 3, 2, 4, 4])
31
+ """
32
+ if not isinstance(input, torch.Tensor):
33
+ raise TypeError(f"Input type is not a torch.Tensor. Got {type(input)}")
34
+
35
+ if not len(input.shape) == 4:
36
+ raise ValueError(f"Invalid input shape, we expect BxCxHxW. Got: {input.shape}")
37
+ # allocate kernel
38
+ kernel: torch.Tensor = get_spatial_gradient_kernel2d(mode, order)
39
+ if normalized:
40
+ kernel = normalize_kernel2d(kernel)
41
+
42
+ # prepare kernel
43
+ b, c, h, w = input.shape
44
+ tmp_kernel: torch.Tensor = kernel.to(input).detach()
45
+ tmp_kernel = tmp_kernel.unsqueeze(1).unsqueeze(1)
46
+
47
+ # convolve input tensor with sobel kernel
48
+ kernel_flip: torch.Tensor = tmp_kernel.flip(-3)
49
+
50
+ # Pad with "replicate" for spatial dims, but with zeros for channel
51
+ spatial_pad = [kernel.size(1) // 2, kernel.size(1) // 2, kernel.size(2) // 2, kernel.size(2) // 2]
52
+ out_channels: int = 3 if order == 2 else 2
53
+ padded_inp: torch.Tensor = F.pad(input.reshape(b * c, 1, h, w), spatial_pad, 'replicate')[:, :, None]
54
+
55
+ return F.conv3d(padded_inp, kernel_flip, padding=0).view(b, c, out_channels, h, w)
56
+
57
+
58
+ def spatial_gradient3d(input: torch.Tensor, mode: str = 'diff', order: int = 1) -> torch.Tensor:
59
+ r"""Compute the first and second order volume derivative in x, y and d using a diff operator.
60
+
61
+ Args:
62
+ input: input features tensor with shape :math:`(B, C, D, H, W)`.
63
+ mode: derivatives modality, can be: `sobel` or `diff`.
64
+ order: the order of the derivatives.
65
+
66
+ Return:
67
+ the spatial gradients of the input feature map with shape :math:`(B, C, 3, D, H, W)`
68
+ or :math:`(B, C, 6, D, H, W)`.
69
+
70
+ Examples:
71
+ >>> input = torch.rand(1, 4, 2, 4, 4)
72
+ >>> output = spatial_gradient3d(input)
73
+ >>> output.shape
74
+ torch.Size([1, 4, 3, 2, 4, 4])
75
+ """
76
+ if not isinstance(input, torch.Tensor):
77
+ raise TypeError(f"Input type is not a torch.Tensor. Got {type(input)}")
78
+
79
+ if not len(input.shape) == 5:
80
+ raise ValueError(f"Invalid input shape, we expect BxCxDxHxW. Got: {input.shape}")
81
+ b, c, d, h, w = input.shape
82
+ dev = input.device
83
+ dtype = input.dtype
84
+ if (mode == 'diff') and (order == 1):
85
+ # we go for the special case implementation due to conv3d bad speed
86
+ x: torch.Tensor = F.pad(input, 6 * [1], 'replicate')
87
+ center = slice(1, -1)
88
+ left = slice(0, -2)
89
+ right = slice(2, None)
90
+ out = torch.empty(b, c, 3, d, h, w, device=dev, dtype=dtype)
91
+ out[..., 0, :, :, :] = x[..., center, center, right] - x[..., center, center, left]
92
+ out[..., 1, :, :, :] = x[..., center, right, center] - x[..., center, left, center]
93
+ out[..., 2, :, :, :] = x[..., right, center, center] - x[..., left, center, center]
94
+ out = 0.5 * out
95
+ else:
96
+ # prepare kernel
97
+ # allocate kernel
98
+ kernel: torch.Tensor = get_spatial_gradient_kernel3d(mode, order)
99
+
100
+ tmp_kernel: torch.Tensor = kernel.to(input).detach()
101
+ tmp_kernel = tmp_kernel.repeat(c, 1, 1, 1, 1)
102
+
103
+ # convolve input tensor with grad kernel
104
+ kernel_flip: torch.Tensor = tmp_kernel.flip(-3)
105
+
106
+ # Pad with "replicate" for spatial dims, but with zeros for channel
107
+ spatial_pad = [
108
+ kernel.size(2) // 2,
109
+ kernel.size(2) // 2,
110
+ kernel.size(3) // 2,
111
+ kernel.size(3) // 2,
112
+ kernel.size(4) // 2,
113
+ kernel.size(4) // 2,
114
+ ]
115
+ out_ch: int = 6 if order == 2 else 3
116
+ out = F.conv3d(F.pad(input, spatial_pad, 'replicate'), kernel_flip, padding=0, groups=c).view(
117
+ b, c, out_ch, d, h, w
118
+ )
119
+ return out
120
+
121
+
122
+ def sobel(input: torch.Tensor, normalized: bool = True, eps: float = 1e-6) -> torch.Tensor:
123
+ r"""Compute the Sobel operator and returns the magnitude per channel.
124
+
125
+ .. image:: _static/img/sobel.png
126
+
127
+ Args:
128
+ input: the input image with shape :math:`(B,C,H,W)`.
129
+ normalized: if True, L1 norm of the kernel is set to 1.
130
+ eps: regularization number to avoid NaN during backprop.
131
+
132
+ Return:
133
+ the sobel edge gradient magnitudes map with shape :math:`(B,C,H,W)`.
134
+
135
+ .. note::
136
+ See a working example `here <https://kornia-tutorials.readthedocs.io/en/latest/
137
+ filtering_edges.html>`__.
138
+
139
+ Example:
140
+ >>> input = torch.rand(1, 3, 4, 4)
141
+ >>> output = sobel(input) # 1x3x4x4
142
+ >>> output.shape
143
+ torch.Size([1, 3, 4, 4])
144
+ """
145
+ if not isinstance(input, torch.Tensor):
146
+ raise TypeError(f"Input type is not a torch.Tensor. Got {type(input)}")
147
+
148
+ if not len(input.shape) == 4:
149
+ raise ValueError(f"Invalid input shape, we expect BxCxHxW. Got: {input.shape}")
150
+
151
+ # compute the x/y gradients
152
+ edges: torch.Tensor = spatial_gradient(input, normalized=normalized)
153
+
154
+ # unpack the edges
155
+ gx: torch.Tensor = edges[:, :, 0]
156
+ gy: torch.Tensor = edges[:, :, 1]
157
+
158
+ # compute gradient magnitude
159
+ magnitude: torch.Tensor = torch.sqrt(gx * gx + gy * gy + eps)
160
+
161
+ return magnitude
162
+
163
+
164
+ class SpatialGradient(nn.Module):
165
+ r"""Compute the first order image derivative in both x and y using a Sobel operator.
166
+
167
+ Args:
168
+ mode: derivatives modality, can be: `sobel` or `diff`.
169
+ order: the order of the derivatives.
170
+ normalized: whether the output is normalized.
171
+
172
+ Return:
173
+ the sobel edges of the input feature map.
174
+
175
+ Shape:
176
+ - Input: :math:`(B, C, H, W)`
177
+ - Output: :math:`(B, C, 2, H, W)`
178
+
179
+ Examples:
180
+ >>> input = torch.rand(1, 3, 4, 4)
181
+ >>> output = SpatialGradient()(input) # 1x3x2x4x4
182
+ """
183
+
184
+ def __init__(self, mode: str = 'sobel', order: int = 1, normalized: bool = True) -> None:
185
+ super().__init__()
186
+ self.normalized: bool = normalized
187
+ self.order: int = order
188
+ self.mode: str = mode
189
+
190
+ def __repr__(self) -> str:
191
+ return (
192
+ self.__class__.__name__ + '('
193
+ 'order=' + str(self.order) + ', ' + 'normalized=' + str(self.normalized) + ', ' + 'mode=' + self.mode + ')'
194
+ )
195
+
196
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
197
+ return spatial_gradient(input, self.mode, self.order, self.normalized)
198
+
199
+
200
+ class SpatialGradient3d(nn.Module):
201
+ r"""Compute the first and second order volume derivative in x, y and d using a diff operator.
202
+
203
+ Args:
204
+ mode: derivatives modality, can be: `sobel` or `diff`.
205
+ order: the order of the derivatives.
206
+
207
+ Return:
208
+ the spatial gradients of the input feature map.
209
+
210
+ Shape:
211
+ - Input: :math:`(B, C, D, H, W)`. D, H, W are spatial dimensions, gradient is calculated w.r.t. them.
212
+ - Output: :math:`(B, C, 3, D, H, W)` or :math:`(B, C, 6, D, H, W)`
213
+
214
+ Examples:
215
+ >>> input = torch.rand(1, 4, 2, 4, 4)
216
+ >>> output = SpatialGradient3d()(input)
217
+ >>> output.shape
218
+ torch.Size([1, 4, 3, 2, 4, 4])
219
+ """
220
+
221
+ def __init__(self, mode: str = 'diff', order: int = 1) -> None:
222
+ super().__init__()
223
+ self.order: int = order
224
+ self.mode: str = mode
225
+ self.kernel = get_spatial_gradient_kernel3d(mode, order)
226
+ return
227
+
228
+ def __repr__(self) -> str:
229
+ return self.__class__.__name__ + '(' 'order=' + str(self.order) + ', ' + 'mode=' + self.mode + ')'
230
+
231
+ def forward(self, input: torch.Tensor) -> torch.Tensor: # type: ignore
232
+ return spatial_gradient3d(input, self.mode, self.order)
233
+
234
+
235
+ class Sobel(nn.Module):
236
+ r"""Compute the Sobel operator and returns the magnitude per channel.
237
+
238
+ Args:
239
+ normalized: if True, L1 norm of the kernel is set to 1.
240
+ eps: regularization number to avoid NaN during backprop.
241
+
242
+ Return:
243
+ the sobel edge gradient magnitudes map.
244
+
245
+ Shape:
246
+ - Input: :math:`(B, C, H, W)`
247
+ - Output: :math:`(B, C, H, W)`
248
+
249
+ Examples:
250
+ >>> input = torch.rand(1, 3, 4, 4)
251
+ >>> output = Sobel()(input) # 1x3x4x4
252
+ """
253
+
254
+ def __init__(self, normalized: bool = True, eps: float = 1e-6) -> None:
255
+ super().__init__()
256
+ self.normalized: bool = normalized
257
+ self.eps: float = eps
258
+
259
+ def __repr__(self) -> str:
260
+ return self.__class__.__name__ + '(' 'normalized=' + str(self.normalized) + ')'
261
+
262
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
263
+ return sobel(input, self.normalized, self.eps)
model/misc.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import random
4
+ import time
5
+ import torch
6
+ import torch.nn as nn
7
+ import logging
8
+ import numpy as np
9
+ from os import path as osp
10
+
11
+ def constant_init(module, val, bias=0):
12
+ if hasattr(module, 'weight') and module.weight is not None:
13
+ nn.init.constant_(module.weight, val)
14
+ if hasattr(module, 'bias') and module.bias is not None:
15
+ nn.init.constant_(module.bias, bias)
16
+
17
+ initialized_logger = {}
18
+ def get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=None):
19
+ """Get the root logger.
20
+ The logger will be initialized if it has not been initialized. By default a
21
+ StreamHandler will be added. If `log_file` is specified, a FileHandler will
22
+ also be added.
23
+ Args:
24
+ logger_name (str): root logger name. Default: 'basicsr'.
25
+ log_file (str | None): The log filename. If specified, a FileHandler
26
+ will be added to the root logger.
27
+ log_level (int): The logger level. It is applied to the logger and
28
+ to the FileHandler only when `log_file` is given; otherwise the
29
+ logger level is left unchanged.
30
+ Returns:
31
+ logging.Logger: The root logger.
32
+ """
33
+ logger = logging.getLogger(logger_name)
34
+ # if the logger has been initialized, just return it
35
+ if logger_name in initialized_logger:
36
+ return logger
37
+
38
+ format_str = '%(asctime)s %(levelname)s: %(message)s'
39
+ stream_handler = logging.StreamHandler()
40
+ stream_handler.setFormatter(logging.Formatter(format_str))
41
+ logger.addHandler(stream_handler)
42
+ logger.propagate = False
43
+
44
+ if log_file is not None:
45
+ logger.setLevel(log_level)
46
+ # add file handler
47
+ # file_handler = logging.FileHandler(log_file, 'w')
48
+ file_handler = logging.FileHandler(log_file, 'a') #Shangchen: keep the previous log
49
+ file_handler.setFormatter(logging.Formatter(format_str))
50
+ file_handler.setLevel(log_level)
51
+ logger.addHandler(file_handler)
52
+ initialized_logger[logger_name] = True
53
+ return logger
54
+
55
+
56
+ IS_HIGH_VERSION = [int(m) for m in list(re.findall(r"^([0-9]+)\.([0-9]+)\.([0-9]+)([^0-9][a-zA-Z0-9]*)?(\+git.*)?$",\
57
+ torch.__version__)[0][:3])] >= [1, 12, 0]
58
+
59
+ def gpu_is_available():
60
+ if IS_HIGH_VERSION:
61
+ if torch.backends.mps.is_available():
62
+ return True
63
+ return True if torch.cuda.is_available() and torch.backends.cudnn.is_available() else False
64
+
65
+ def get_device(gpu_id=None):
66
+ if gpu_id is None:
67
+ gpu_str = ''
68
+ elif isinstance(gpu_id, int):
69
+ gpu_str = f':{gpu_id}'
70
+ else:
71
+ raise TypeError('Input should be int value.')
72
+
73
+ if IS_HIGH_VERSION:
74
+ if torch.backends.mps.is_available():
75
+ return torch.device('mps'+gpu_str)
76
+ return torch.device('cuda'+gpu_str if torch.cuda.is_available() and torch.backends.cudnn.is_available() else 'cpu')
77
+
78
+
79
+ def set_random_seed(seed):
80
+ """Set random seeds."""
81
+ random.seed(seed)
82
+ np.random.seed(seed)
83
+ torch.manual_seed(seed)
84
+ torch.cuda.manual_seed(seed)
85
+ torch.cuda.manual_seed_all(seed)
86
+
87
+
88
+ def get_time_str():
89
+ return time.strftime('%Y%m%d_%H%M%S', time.localtime())
90
+
91
+
92
+ def scandir(dir_path, suffix=None, recursive=False, full_path=False):
93
+ """Scan a directory to find the interested files.
94
+
95
+ Args:
96
+ dir_path (str): Path of the directory.
97
+ suffix (str | tuple(str), optional): File suffix that we are
98
+ interested in. Default: None.
99
+ recursive (bool, optional): If set to True, recursively scan the
100
+ directory. Default: False.
101
+ full_path (bool, optional): If set to True, include the dir_path.
102
+ Default: False.
103
+
104
+ Returns:
105
+ A generator yielding the paths of all matching files (relative to `dir_path` unless `full_path` is True).
106
+ """
107
+
108
+ if (suffix is not None) and not isinstance(suffix, (str, tuple)):
109
+ raise TypeError('"suffix" must be a string or tuple of strings')
110
+
111
+ root = dir_path
112
+
113
+ def _scandir(dir_path, suffix, recursive):
114
+ for entry in os.scandir(dir_path):
115
+ if not entry.name.startswith('.') and entry.is_file():
116
+ if full_path:
117
+ return_path = entry.path
118
+ else:
119
+ return_path = osp.relpath(entry.path, root)
120
+
121
+ if suffix is None:
122
+ yield return_path
123
+ elif return_path.endswith(suffix):
124
+ yield return_path
125
+ else:
126
+ if recursive:
127
+ yield from _scandir(entry.path, suffix=suffix, recursive=recursive)
128
+ else:
129
+ continue
130
+
131
+ return _scandir(dir_path, suffix=suffix, recursive=recursive)
model/modules/base_module.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from functools import reduce
6
+
7
+ class BaseNetwork(nn.Module):
8
+ def __init__(self):
9
+ super(BaseNetwork, self).__init__()
10
+
11
+ def print_network(self):
12
+ if isinstance(self, list):
13
+ self = self[0]
14
+ num_params = 0
15
+ for param in self.parameters():
16
+ num_params += param.numel()
17
+ print(
18
+ 'Network [%s] was created. Total number of parameters: %.1f million. '
19
+ 'To see the architecture, do print(network).' %
20
+ (type(self).__name__, num_params / 1000000))
21
+
22
+ def init_weights(self, init_type='normal', gain=0.02):
23
+ '''
24
+ initialize network's weights
25
+ init_type: normal | xavier | kaiming | orthogonal
26
+ https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/9451e70673400885567d08a9e97ade2524c700d0/models/networks.py#L39
27
+ '''
28
+ def init_func(m):
29
+ classname = m.__class__.__name__
30
+ if classname.find('InstanceNorm2d') != -1:
31
+ if hasattr(m, 'weight') and m.weight is not None:
32
+ nn.init.constant_(m.weight.data, 1.0)
33
+ if hasattr(m, 'bias') and m.bias is not None:
34
+ nn.init.constant_(m.bias.data, 0.0)
35
+ elif hasattr(m, 'weight') and (classname.find('Conv') != -1
36
+ or classname.find('Linear') != -1):
37
+ if init_type == 'normal':
38
+ nn.init.normal_(m.weight.data, 0.0, gain)
39
+ elif init_type == 'xavier':
40
+ nn.init.xavier_normal_(m.weight.data, gain=gain)
41
+ elif init_type == 'xavier_uniform':
42
+ nn.init.xavier_uniform_(m.weight.data, gain=1.0)
43
+ elif init_type == 'kaiming':
44
+ nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
45
+ elif init_type == 'orthogonal':
46
+ nn.init.orthogonal_(m.weight.data, gain=gain)
47
+ elif init_type == 'none': # uses pytorch's default init method
48
+ m.reset_parameters()
49
+ else:
50
+ raise NotImplementedError(
51
+ 'initialization method [%s] is not implemented' %
52
+ init_type)
53
+ if hasattr(m, 'bias') and m.bias is not None:
54
+ nn.init.constant_(m.bias.data, 0.0)
55
+
56
+ self.apply(init_func)
57
+
58
+ # propagate to children
59
+ for m in self.children():
60
+ if hasattr(m, 'init_weights'):
61
+ m.init_weights(init_type, gain)
62
+
63
+
64
+ class Vec2Feat(nn.Module):
65
+ def __init__(self, channel, hidden, kernel_size, stride, padding):
66
+ super(Vec2Feat, self).__init__()
67
+ self.relu = nn.LeakyReLU(0.2, inplace=True)
68
+ c_out = reduce((lambda x, y: x * y), kernel_size) * channel
69
+ self.embedding = nn.Linear(hidden, c_out)
70
+ self.kernel_size = kernel_size
71
+ self.stride = stride
72
+ self.padding = padding
73
+ self.bias_conv = nn.Conv2d(channel,
74
+ channel,
75
+ kernel_size=3,
76
+ stride=1,
77
+ padding=1)
78
+
79
+ def forward(self, x, t, output_size):
80
+ b_, _, _, _, c_ = x.shape
81
+ x = x.view(b_, -1, c_)
82
+ feat = self.embedding(x)
83
+ b, _, c = feat.size()
84
+ feat = feat.view(b * t, -1, c).permute(0, 2, 1)
85
+ feat = F.fold(feat,
86
+ output_size=output_size,
87
+ kernel_size=self.kernel_size,
88
+ stride=self.stride,
89
+ padding=self.padding)
90
+ feat = self.bias_conv(feat)
91
+ return feat
92
+
93
+
94
+ class FusionFeedForward(nn.Module):
95
+ def __init__(self, dim, hidden_dim=1960, t2t_params=None):
96
+ super(FusionFeedForward, self).__init__()
97
+ # We set hidden_dim as a default to 1960
98
+ self.fc1 = nn.Sequential(nn.Linear(dim, hidden_dim))
99
+ self.fc2 = nn.Sequential(nn.GELU(), nn.Linear(hidden_dim, dim))
100
+ assert t2t_params is not None
101
+ self.t2t_params = t2t_params
102
+ self.kernel_shape = reduce((lambda x, y: x * y), t2t_params['kernel_size']) # 49
103
+
104
+ def forward(self, x, output_size):
105
+ n_vecs = 1
106
+ for i, d in enumerate(self.t2t_params['kernel_size']):
107
+ n_vecs *= int((output_size[i] + 2 * self.t2t_params['padding'][i] -
108
+ (d - 1) - 1) / self.t2t_params['stride'][i] + 1)
109
+
110
+ x = self.fc1(x)
111
+ b, n, c = x.size()
112
+ normalizer = x.new_ones(b, n, self.kernel_shape).view(-1, n_vecs, self.kernel_shape).permute(0, 2, 1)
113
+ normalizer = F.fold(normalizer,
114
+ output_size=output_size,
115
+ kernel_size=self.t2t_params['kernel_size'],
116
+ padding=self.t2t_params['padding'],
117
+ stride=self.t2t_params['stride'])
118
+
119
+ x = F.fold(x.view(-1, n_vecs, c).permute(0, 2, 1),
120
+ output_size=output_size,
121
+ kernel_size=self.t2t_params['kernel_size'],
122
+ padding=self.t2t_params['padding'],
123
+ stride=self.t2t_params['stride'])
124
+
125
+ x = F.unfold(x / normalizer,
126
+ kernel_size=self.t2t_params['kernel_size'],
127
+ padding=self.t2t_params['padding'],
128
+ stride=self.t2t_params['stride']).permute(
129
+ 0, 2, 1).contiguous().view(b, n, c)
130
+ x = self.fc2(x)
131
+ return x
model/modules/deformconv.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.nn import init as init
4
+ from torch.nn.modules.utils import _pair, _single
5
+ import math
6
+
7
+ class ModulatedDeformConv2d(nn.Module):
8
+ def __init__(self,
9
+ in_channels,
10
+ out_channels,
11
+ kernel_size,
12
+ stride=1,
13
+ padding=0,
14
+ dilation=1,
15
+ groups=1,
16
+ deform_groups=1,
17
+ bias=True):
18
+ super(ModulatedDeformConv2d, self).__init__()
19
+
20
+ self.in_channels = in_channels
21
+ self.out_channels = out_channels
22
+ self.kernel_size = _pair(kernel_size)
23
+ self.stride = stride
24
+ self.padding = padding
25
+ self.dilation = dilation
26
+ self.groups = groups
27
+ self.deform_groups = deform_groups
28
+ self.with_bias = bias
29
+ # enable compatibility with nn.Conv2d
30
+ self.transposed = False
31
+ self.output_padding = _single(0)
32
+
33
+ self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels // groups, *self.kernel_size))
34
+ if bias:
35
+ self.bias = nn.Parameter(torch.Tensor(out_channels))
36
+ else:
37
+ self.register_parameter('bias', None)
38
+ self.init_weights()
39
+
40
+ def init_weights(self):
41
+ n = self.in_channels
42
+ for k in self.kernel_size:
43
+ n *= k
44
+ stdv = 1. / math.sqrt(n)
45
+ self.weight.data.uniform_(-stdv, stdv)
46
+ if self.bias is not None:
47
+ self.bias.data.zero_()
48
+
49
+ if hasattr(self, 'conv_offset'):
50
+ self.conv_offset.weight.data.zero_()
51
+ self.conv_offset.bias.data.zero_()
52
+
53
+ def forward(self, x, offset, mask):
54
+ pass
model/modules/flow_comp_raft.py ADDED
@@ -0,0 +1,265 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+ from RAFT import RAFT
7
+ from model.modules.flow_loss_utils import flow_warp, ternary_loss2
8
+
9
+
10
def initialize_RAFT(model_path='weights/raft-things.pth', device='cuda'):
    """Build a RAFT network, load its checkpoint and move it to ``device``.

    Args:
        model_path: path of the checkpoint passed to ``torch.load``.
        device: target device for the returned model.

    Returns:
        The unwrapped RAFT module (the ``.module`` of the temporary
        ``DataParallel`` wrapper) with the checkpoint weights loaded.
    """
    # An ArgumentParser instance is used purely as an attribute bag for the
    # options handed to RAFT.
    raft_args = argparse.ArgumentParser()
    for option, value in (('raft_model', model_path),
                          ('small', False),
                          ('mixed_precision', False),
                          ('alternate_corr', False)):
        setattr(raft_args, option, value)

    # The state dict is loaded through a DataParallel wrapper
    # (presumably because the checkpoint keys carry a "module." prefix -- TODO confirm).
    wrapped = torch.nn.DataParallel(RAFT(raft_args))
    state_dict = torch.load(raft_args.raft_model, map_location='cpu')
    wrapped.load_state_dict(state_dict)

    model = wrapped.module
    model.to(device)
    return model
25
+
26
+
27
class RAFT_bi(nn.Module):
    """Frozen RAFT wrapper estimating flow in both directions between consecutive frames."""

    def __init__(self, model_path='weights/raft-things.pth', device='cuda'):
        super().__init__()
        self.fix_raft = initialize_RAFT(model_path, device=device)

        # The flow estimator is never trained here.
        for param in self.fix_raft.parameters():
            param.requires_grad = False

        self.l1_criterion = nn.L1Loss()
        self.eval()

    def forward(self, gt_local_frames, iters=20):
        """Estimate forward and backward flow for every adjacent frame pair.

        Args:
            gt_local_frames: tensor unpacked as ``(b, l_t, c, h, w)``.
            iters: iteration count forwarded to the RAFT model.

        Returns:
            ``(gt_flows_forward, gt_flows_backward)``, each viewed as
            ``(b, l_t - 1, 2, h, w)``. Forward flow goes from frame ``t``
            to ``t + 1``; backward flow from ``t + 1`` to ``t``.
        """
        b, l_t, c, h, w = gt_local_frames.size()

        with torch.no_grad():
            earlier = gt_local_frames[:, :-1, :, :, :].reshape(-1, c, h, w)
            later = gt_local_frames[:, 1:, :, :, :].reshape(-1, c, h, w)

            flows = []
            for src, dst in ((earlier, later), (later, earlier)):
                _, flow = self.fix_raft(src, dst, iters=iters, test_mode=True)
                flows.append(flow)

        gt_flows_forward = flows[0].view(b, l_t - 1, 2, h, w)
        gt_flows_backward = flows[1].view(b, l_t - 1, 2, h, w)

        return gt_flows_forward, gt_flows_backward
56
+
57
+
58
+ ##################################################################################
59
def smoothness_loss(flow, cmask):
    """First-order smoothness penalty on a flow field.

    Sums the Charbonnier loss of the first-order differences of both
    flow channels, each weighted by ``cmask``. The validity mask
    returned by ``smoothness_deltas`` is not used.
    """
    du, dv, _ = smoothness_deltas(flow)
    return charbonnier_loss(du, cmask) + charbonnier_loss(dv, cmask)
64
+
65
+
66
def smoothness_deltas(flow):
    """First-order finite differences of a two-channel flow field.

    Args:
        flow: tensor of shape [b, c, h, w]; it is split into exactly two
            single-channel tensors along dim 1.

    Returns:
        ``(delta_u, delta_v, mask)``. Each delta has two channels: the
        value minus its right neighbour, and the value minus its lower
        neighbour (zero padding at the border). ``mask`` concatenates the
        two border masks built by ``create_mask`` and lives on
        ``flow.device``.
    """
    mask = torch.cat(
        (create_mask(flow, [[0, 0], [0, 1]]),
         create_mask(flow, [[0, 1], [0, 0]])),
        dim=1,
    ).to(flow.device)

    # One 3x3 kernel per direction: [horizontal, vertical].
    kernels = torch.tensor([
        [[[0., 0., 0.], [0., 1., -1.], [0., 0., 0.]]],
        [[[0., 0., 0.], [0., 1., 0.], [0., -1., 0.]]],
    ]).to(flow.device)

    flow_u, flow_v = flow.split(1, dim=1)
    delta_u = F.conv2d(flow_u, kernels, stride=1, padding=1)
    delta_v = F.conv2d(flow_v, kernels, stride=1, padding=1)
    return delta_u, delta_v, mask
85
+
86
+
87
def second_order_loss(flow, cmask):
    """Second-order smoothness penalty on a flow field.

    Sums the Charbonnier loss of the second-order differences of both
    flow channels, each weighted by ``cmask``. The validity mask
    returned by ``second_order_deltas`` is not used.
    """
    ddu, ddv, _ = second_order_deltas(flow)
    return charbonnier_loss(ddu, cmask) + charbonnier_loss(ddv, cmask)
92
+
93
+
94
def charbonnier_loss(x, mask=None, truncate=None, alpha=0.45, beta=1.0, epsilon=0.001):
    """Compute the generalized Charbonnier loss of the difference tensor ``x``.

    Each element contributes ``((x * beta)**2 + epsilon**2) ** alpha``.
    Positions where ``mask == 0`` contribute nothing, but the result is
    always normalised by the full element count ``b * c * h * w``.

    Args:
        x: tensor of shape [b, c, h, w].
        mask: optional tensor of shape [b, mc, h, w] where ``mc`` is 1 or
            the channel count of ``x``. Entries should be 0 or 1.
        truncate: optional upper bound applied element-wise to the error.
            May be a Python number or a tensor.
        alpha: exponent of the penalty.
        beta: scale applied to ``x`` before squaring.
        epsilon: small constant keeping the penalty smooth around zero.

    Returns:
        Scalar tensor with the normalised loss.
    """
    b, c, h, w = x.shape
    norm = b * c * h * w
    # A Python scalar adapts to x's dtype/device and avoids allocating a
    # fresh CPU tensor on every call.
    error = torch.pow(torch.square(x * beta) + epsilon ** 2, alpha)
    if mask is not None:
        error = mask * error
    if truncate is not None:
        # torch.min(tensor, other) needs a tensor `other`; accept plain numbers too.
        if not torch.is_tensor(truncate):
            truncate = torch.as_tensor(truncate, dtype=error.dtype, device=error.device)
        error = torch.min(error, truncate)
    return torch.sum(error) / norm
111
+
112
+
113
def second_order_deltas(flow):
    """Second-order finite differences of a two-channel flow field.

    Args:
        flow: tensor of shape [b, c, h, w]; it is split into exactly two
            single-channel tensors along dim 1.

    Returns:
        ``(delta_u, delta_v, mask)``. Each delta has four channels
        (horizontal, vertical, main-diagonal and anti-diagonal stencils,
        zero padding at the border). ``mask`` concatenates the matching
        border masks from ``create_mask`` and lives on ``flow.device``.
    """
    # Border masks, in the same order as the stencils below.
    horizontal = create_mask(flow, [[0, 0], [1, 1]])
    vertical = create_mask(flow, [[1, 1], [0, 0]])
    diagonal = create_mask(flow, [[1, 1], [1, 1]])
    mask = torch.cat((horizontal, vertical, diagonal, diagonal), dim=1).to(flow.device)

    stencils = torch.tensor([
        [[0., 0., 0.], [1., -2., 1.], [0., 0., 0.]],   # horizontal
        [[0., 1., 0.], [0., -2., 0.], [0., 1., 0.]],   # vertical
        [[1., 0., 0.], [0., -2., 0.], [0., 0., 1.]],   # main diagonal
        [[0., 0., 1.], [0., -2., 0.], [1., 0., 0.]],   # anti-diagonal
    ]).unsqueeze(1).to(flow.device)

    # Each flow component is filtered independently with all four stencils.
    flow_u, flow_v = flow.split(1, dim=1)
    delta_u = F.conv2d(flow_u, stencils, stride=1, padding=1)
    delta_v = F.conv2d(flow_v, stencils, stride=1, padding=1)
    return delta_u, delta_v, mask
141
+
142
def create_mask(tensor, paddings):
    """Build a border mask: ones in the interior, zeros in the padded margins.

    Args:
        tensor: reference tensor of shape [b, c, h, w]; only its batch
            size and spatial size are used.
        paddings: 2 x 2 list ``[[up, down], [left, right]]`` giving the
            width of the zero margin on each side.

    Returns:
        Detached CPU tensor of shape [b, 1, h, w].
    """
    dims = tensor.shape
    up, down = paddings[0][0], paddings[0][1]
    left, right = paddings[1][0], paddings[1][1]

    core = torch.ones([dims[2] - (up + down), dims[3] - (left + right)])
    # F.pad takes the last dimension first: left, right, up, down.
    padded = F.pad(core, pad=[left, right, up, down])
    batched = padded[None, None].repeat(dims[0], 1, 1, 1)
    return batched.detach()
162
+
163
def ternary_loss(flow_comp, flow_gt, mask, current_frame, shift_frame, scale_factor=1):
    """Ternary warping loss of a completed flow against the current frame.

    ``shift_frame`` is warped once with ``flow_gt`` to derive a
    confidence map (``exp(-50 * err**2)`` of the channel-summed absolute
    error against ``current_frame``) and once with ``flow_comp``; the
    second warp is compared with ``current_frame`` by ``ternary_loss2``.

    Args:
        flow_comp, flow_gt: flows laid out channels-first; they are
            permuted with ``(0, 2, 3, 1)`` before warping.
        mask: forwarded to ``ternary_loss2``.
        current_frame, shift_frame: frames; both are bilinearly resized by
            ``1 / scale_factor`` when ``scale_factor != 1``.
        scale_factor: see above.
    """
    if scale_factor != 1:
        inv_scale = 1 / scale_factor
        current_frame = F.interpolate(current_frame, scale_factor=inv_scale, mode='bilinear')
        shift_frame = F.interpolate(shift_frame, scale_factor=inv_scale, mode='bilinear')

    warped_gt = flow_warp(shift_frame, flow_gt.permute(0, 2, 3, 1))
    photometric_err = torch.sum(torch.abs(current_frame - warped_gt), dim=1)
    noc_mask = torch.exp(-50. * photometric_err.pow(2)).unsqueeze(1)

    warped_comp = flow_warp(shift_frame, flow_comp.permute(0, 2, 3, 1))
    return ternary_loss2(current_frame, warped_comp, noc_mask, mask)
172
+
173
class FlowLoss(nn.Module):
    """Combined L1, smoothness and ternary warping loss for predicted flows."""

    def __init__(self):
        super().__init__()
        self.l1_criterion = nn.L1Loss()

    def forward(self, pred_flows, gt_flows, masks, frames):
        """Accumulate the losses over every entry of ``pred_flows``.

        Args:
            pred_flows: sequence of predicted flows, each ``b t-1 2 h w``.
            gt_flows: sequence of reference flows, indexed like ``pred_flows``.
            masks: tensor sliced along dim 1; entry 0 of the sequences
                uses ``masks[:, :-1]`` and entry 1 uses ``masks[:, 1:]``.
            frames: tensor sliced along dim 1 in the same way to pair each
                frame with its neighbour; reshaped to 3 channels.

        Returns:
            ``(loss, warp_loss)``: the sum of the L1 and both smoothness
            terms, and the sum of the ternary warping terms.
        """
        loss = 0
        warp_loss = 0
        h, w = pred_flows[0].shape[-2:]
        masks = [masks[:, :-1, ...].contiguous(), masks[:, 1:, ...].contiguous()]
        earlier = frames[:, :-1, ...]
        later = frames[:, 1:, ...]
        current_frames = [earlier, later]
        next_frames = [later, earlier]

        for i, pred in enumerate(pred_flows):
            m = masks[i]
            inv_m = 1 - m
            gt = gt_flows[i]

            # Prediction where the mask is set, reference flow elsewhere.
            combined_flow = pred * m + gt * inv_m

            # L1 terms are normalised by the mean of the region they cover.
            l1_inside = self.l1_criterion(pred * m, gt * m) / torch.mean(m)
            l1_outside = self.l1_criterion(pred * inv_m, gt * inv_m) / torch.mean(inv_m)
            l1_loss = l1_inside + l1_outside

            flat_flow = combined_flow.reshape(-1, 2, h, w)
            flat_mask = m.reshape(-1, 1, h, w)
            smooth_loss = smoothness_loss(flat_flow, flat_mask)
            smooth_loss2 = second_order_loss(flat_flow, flat_mask)

            warp_loss_i = ternary_loss(flat_flow,
                                       gt.reshape(-1, 2, h, w),
                                       flat_mask,
                                       current_frames[i].reshape(-1, 3, h, w),
                                       next_frames[i].reshape(-1, 3, h, w))

            loss += l1_loss + smooth_loss + smooth_loss2
            warp_loss += warp_loss_i

        return loss, warp_loss
205
+
206
+
207
def edgeLoss(preds_edges, edges):
    """Class-balanced binary cross-entropy between edge logits and an edge map.

    Args:
        preds_edges: logits with shape [b, c, h, w].
        edges: target edge map with shape [b, c, h, w]; values above 0.5
            count as edge pixels when computing the balancing weights.

    Returns:
        Scalar tensor: the mean of the weighted per-pixel losses. Edge
        pixels are weighted by the per-sample fraction of non-edge pixels
        and vice versa.
    """
    is_edge = (edges > 0.5).float()
    _, c, h, w = is_edge.shape
    num_pos = torch.sum(is_edge, dim=[1, 2, 3]).float()  # shape [b]
    num_neg = c * h * w - num_pos                        # shape [b]
    total = num_pos + num_neg

    neg_fraction = (num_neg / total)[:, None, None, None]
    pos_fraction = (num_pos / total)[:, None, None, None]
    weight = neg_fraction * is_edge + pos_fraction * (1 - is_edge)

    per_pixel = F.binary_cross_entropy_with_logits(
        preds_edges.float(), edges.float(), weight=weight, reduction='none')
    return torch.mean(per_pixel)
227
+
228
class EdgeLoss(nn.Module):
    """Edge loss over raw predictions and mask-composited predictions."""

    def __init__(self):
        super().__init__()

    def forward(self, pred_edges, gt_edges, masks):
        """Accumulate the edge loss over every entry of ``pred_edges``.

        Args:
            pred_edges: sequence of predicted edge maps, each ``b t-1 1 h w``.
            gt_edges: sequence of reference edge maps, indexed like ``pred_edges``.
            masks: tensor sliced along dim 1; entry 0 uses ``masks[:, :-1]``
                and entry 1 uses ``masks[:, 1:]``.

        Returns:
            Sum over entries of ``edgeLoss(pred) + 5 * edgeLoss(combined)``,
            where ``combined`` takes the prediction where the mask is set
            and the reference elsewhere.
        """
        loss = 0
        h, w = pred_edges[0].shape[-2:]
        masks = [masks[:, :-1, ...].contiguous(), masks[:, 1:, ...].contiguous()]

        for i, pred in enumerate(pred_edges):
            m = masks[i]
            gt = gt_edges[i]
            combined_edge = pred * m + gt * (1 - m)

            flat_gt = gt.reshape(-1, 1, h, w)
            raw_term = edgeLoss(pred.reshape(-1, 1, h, w), flat_gt)
            combined_term = edgeLoss(combined_edge.reshape(-1, 1, h, w), flat_gt)
            loss += raw_term + 5 * combined_term

        return loss
245
+
246
+
247
class FlowSimpleLoss(nn.Module):
    """L1 loss between predicted flows and area-resized reference flows."""

    def __init__(self):
        super().__init__()
        self.l1_criterion = nn.L1Loss()

    def forward(self, pred_flows, gt_flows):
        """Sum the L1 loss over every entry of ``pred_flows``.

        Args:
            pred_flows: sequence of predicted flows, each ``b t-1 2 h w``.
            gt_flows: sequence of reference flows at their own resolution.

        Returns:
            Accumulated L1 loss. The reference flows are resized with
            ``mode='area'`` by ``h / h_orig`` and their values are scaled
            by the same factor before comparison.
        """
        h, w = pred_flows[0].shape[-2:]
        h_orig, w_orig = gt_flows[0].shape[-2:]
        ds_factor = 1.0 * h / h_orig

        flat_preds = [flow.view(-1, 2, h, w) for flow in pred_flows]
        resized_gts = []
        for flow in gt_flows:
            flat_gt = flow.view(-1, 2, h_orig, w_orig)
            # Flow magnitudes shrink together with the resolution.
            resized_gts.append(F.interpolate(flat_gt, scale_factor=ds_factor, mode='area') * ds_factor)

        loss = 0
        for i, pred in enumerate(flat_preds):
            loss += self.l1_criterion(pred, resized_gts[i])

        return loss
model/modules/flow_loss_utils.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
def flow_warp(x,
              flow,
              interpolation='bilinear',
              padding_mode='zeros',
              align_corners=True):
    """Warp an image or a feature map with optical flow.

    Args:
        x (Tensor): Tensor with size (n, c, h, w).
        flow (Tensor): Tensor with size (n, h, w, 2). The last dimension
            holds the horizontal and vertical offsets in pixels (not
            normalized to [-1, 1]).
        interpolation (str): 'nearest' or 'bilinear'. Default: 'bilinear'.
        padding_mode (str): 'zeros', 'border' or 'reflection'.
            Default: 'zeros'.
        align_corners (bool): Whether to align corners. Default: True.

    Returns:
        Tensor: Warped image or feature map.

    Raises:
        ValueError: if the spatial sizes of ``x`` and ``flow`` differ.
    """
    if x.size()[-2:] != flow.size()[1:3]:
        raise ValueError(f'The spatial sizes of input ({x.size()[-2:]}) and '
                         f'flow ({flow.size()[1:3]}) are not the same.')
    _, _, h, w = x.size()

    # Pixel-coordinate grid, last axis ordered (x, y).
    rows = torch.arange(0, h, device=flow.device)
    cols = torch.arange(0, w, device=flow.device)
    grid_y, grid_x = torch.meshgrid(rows, cols)
    base_grid = torch.stack((grid_x, grid_y), 2).type_as(x)  # (h, w, 2)
    base_grid.requires_grad = False

    # Absolute sampling positions, rescaled to the [-1, 1] range that
    # grid_sample expects.
    sample_pos = base_grid + flow
    norm_x = 2.0 * sample_pos[:, :, :, 0] / max(w - 1, 1) - 1.0
    norm_y = 2.0 * sample_pos[:, :, :, 1] / max(h - 1, 1) - 1.0
    norm_grid = torch.stack((norm_x, norm_y), dim=3)

    return F.grid_sample(x,
                         norm_grid,
                         mode=interpolation,
                         padding_mode=padding_mode,
                         align_corners=align_corners)
46
+
47
+
48
+ # def image_warp(image, flow):
49
+ # b, c, h, w = image.size()
50
+ # device = image.device
51
+ # flow = torch.cat([flow[:, 0:1, :, :] / ((w - 1.0) / 2.0), flow[:, 1:2, :, :] / ((h - 1.0) / 2.0)], dim=1) # normalize to [-1~1](from upper left to lower right
52
+ # flow = flow.permute(0, 2, 3, 1) # if you wanna use grid_sample function, the channel(band) shape of show must be in the last dimension
53
+ # x = np.linspace(-1, 1, w)
54
+ # y = np.linspace(-1, 1, h)
55
+ # X, Y = np.meshgrid(x, y)
56
+ # grid = torch.cat((torch.from_numpy(X.astype('float32')).unsqueeze(0).unsqueeze(3),
57
+ # torch.from_numpy(Y.astype('float32')).unsqueeze(0).unsqueeze(3)), 3).to(device)
58
+ # output = torch.nn.functional.grid_sample(image, grid + flow, mode='bilinear', padding_mode='zeros')
59
+ # return output
60
+
61
+
62
def length_sq(x):
    """Return the squared L2 norm of ``x`` along dim 1, keeping that dim."""
    return x.square().sum(dim=1, keepdim=True)
64
+
65
+
66
def fbConsistencyCheck(flow_fw, flow_bw, alpha1=0.01, alpha2=0.5):
    """Forward-backward consistency check producing two occlusion maps.

    For each direction the opposite flow is warped by this flow and added
    to it. A pixel is flagged (1.0) when the squared length of that sum
    exceeds ``alpha1 * (|flow|^2 + |warped opposite flow|^2) + alpha2``.

    Args:
        flow_fw, flow_bw: channels-first flows; permuted with
            ``(0, 2, 3, 1)`` before being used as warping fields.
        alpha1: weight of the magnitude-dependent part of the threshold.
        alpha2: constant part of the threshold.

    Returns:
        ``(fb_occ_fw, fb_occ_bw)`` as float tensors. Per the original
        author's note, ``fb_occ_fw`` marks the frame-2 area occluded by
        frame 1 and ``fb_occ_bw`` the frame-1 area occluded by frame 2.
    """
    def _occlusion(flow_a, flow_b):
        # flow_b sampled at the positions flow_a points to.
        warped_b = flow_warp(flow_b, flow_a.permute(0, 2, 3, 1))
        residual = flow_a + warped_b
        threshold = alpha1 * (length_sq(flow_a) + length_sq(warped_b)) + alpha2
        return (length_sq(residual) > threshold).float()

    fb_occ_fw = _occlusion(flow_fw, flow_bw)
    fb_occ_bw = _occlusion(flow_bw, flow_fw)
    return fb_occ_fw, fb_occ_bw
81
+
82
+
83
def rgb2gray(image):
    """Convert a channels-first RGB batch to a single-channel gray batch.

    Uses the weights 0.299 / 0.587 / 0.110 for channels 0 / 1 / 2 and
    returns a tensor with a singleton channel dimension.

    NOTE(review): the common luma weight for blue is 0.114; the 0.110
    used here is kept unchanged -- confirm whether it is intentional.
    """
    gray_image = image[:, 0] * 0.299 + image[:, 1] * 0.587 + 0.110 * image[:, 2]
    gray_image = gray_image.unsqueeze(1)
    return gray_image


def ternary_transform(image, max_distance=1):
    """Soft ternary (census-like) transform of an RGB batch.

    Every pixel's gray intensity (scaled by 255) is subtracted from each
    neighbour in a ``(2 * max_distance + 1)``-wide square patch, and the
    differences are squashed with ``d / sqrt(0.81 + d**2)``.

    Args:
        image: channels-first RGB batch accepted by ``rgb2gray``.
        max_distance: patch radius in pixels.

    Returns:
        Tensor with ``(2 * max_distance + 1) ** 2`` channels and the same
        spatial size as ``image``.
    """
    patch_size = 2 * max_distance + 1
    intensities = rgb2gray(image) * 255
    out_channels = patch_size * patch_size

    # One-hot kernels: channel k copies the k-th neighbour of the patch.
    weights = torch.eye(out_channels, dtype=intensities.dtype, device=image.device)
    weights = weights.reshape(out_channels, 1, patch_size, patch_size)

    # Padding must equal the patch radius so the output keeps the input's
    # spatial size for every max_distance (a fixed padding of 1 only
    # works for max_distance == 1).
    patches = F.conv2d(intensities, weights, stride=1, padding=max_distance)
    transf = patches - intensities
    transf_norm = transf / torch.sqrt(0.81 + torch.square(transf))
    return transf_norm
100
+
101
+
102
def hamming_distance(t1, t2):
    """Soft Hamming distance between two transformed tensors.

    The squared element-wise difference ``d`` is squashed with
    ``d / (0.1 + d)`` and summed over dim 1 (kept as a singleton dim).
    """
    sq_diff = (t1 - t2).square()
    squashed = sq_diff / (0.1 + sq_diff)
    return squashed.sum(dim=1, keepdim=True)
107
+
108
+
109
def create_mask(mask, paddings):
    """Build a border mask: ones in the interior, zeros in the padded margins.

    Args:
        mask: reference tensor; only its batch size (dim 0) and spatial
            size (dims 2 and 3) are used.
        paddings: ``[[top, bottom], [left, right]]`` margin widths.

    Returns:
        Detached CPU tensor of shape [b, 1, h, w].
    """
    dims = mask.shape
    top, bottom = paddings[0][0], paddings[0][1]
    left, right = paddings[1][0], paddings[1][1]

    core = torch.ones([dims[2] - (top + bottom), dims[3] - (left + right)])
    # F.pad takes the last dimension first: left, right, top, bottom.
    padded = F.pad(core, pad=[left, right, top, bottom])
    batched = padded[None, None].repeat(dims[0], 1, 1, 1)
    return batched.detach()
122
+
123
+
124
def ternary_loss2(frame1, warp_frame21, confMask, masks, max_distance=1):
    """Masked ternary loss between a frame and a warped neighbour.

    Args:
        frame1: torch tensor, with shape [b * t, c, h, w].
        warp_frame21: torch tensor, with shape [b * t, c, h, w].
        confMask: confidence mask, with shape [b * t, c, h, w].
        masks: torch tensor, with shape [b * t, c, h, w].
        max_distance: maximum distance. NOTE: currently not forwarded to
            ``ternary_transform``, which therefore runs with its default.

    Returns:
        Mean of ``distance * confMask * masks`` divided by the mean of
        ``masks``.
    """
    reference_code = ternary_transform(frame1)
    warped_code = ternary_transform(warp_frame21)
    distance = hamming_distance(reference_code, warped_code)
    weighted = distance * confMask * masks
    return torch.mean(weighted) / torch.mean(masks)
142
+
model/modules/sparse_transformer.py ADDED
@@ -0,0 +1,344 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from functools import reduce
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
class SoftSplit(nn.Module):
    """Unfold a feature map into overlapping patches and embed each patch.

    Args:
        channel: channel count of the incoming feature map.
        hidden: embedding size of every patch token.
        kernel_size, stride, padding: per-axis unfold parameters
            (indexable pairs).
    """

    def __init__(self, channel, hidden, kernel_size, stride, padding):
        super().__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.t2t = nn.Unfold(kernel_size=kernel_size,
                             stride=stride,
                             padding=padding)
        # Flattened patch length: product of the kernel extents times channels.
        patch_area = 1
        for extent in kernel_size:
            patch_area = patch_area * extent
        self.embedding = nn.Linear(patch_area * channel, hidden)

    def forward(self, x, b, output_size):
        """Tokenise ``x`` and arrange the tokens on the patch grid.

        Args:
            x: feature map consumed by ``nn.Unfold`` (batch dim is b*t).
            b: batch size used for the final view.
            output_size: (height, width) of ``x``, used to derive the
                patch-grid size.

        Returns:
            Tensor viewed as ``(b, -1, f_h, f_w, hidden)``.
        """
        grid = []
        for axis in (0, 1):
            span = output_size[axis] + 2 * self.padding[axis] - (self.kernel_size[axis] - 1) - 1
            grid.append(int(span / self.stride[axis] + 1))
        f_h, f_w = grid

        # [b*t, num_vec, ks*ks*c]
        tokens = self.t2t(x).permute(0, 2, 1)
        # [b*t, num_vec, hidden]
        tokens = self.embedding(tokens)
        return tokens.view(b, -1, f_h, f_w, tokens.size(2))
32
+
33
+
34
class SoftComp(nn.Module):
    """Project tokens back to patches and fold them into a feature map.

    Args:
        channel: channel count of the reconstructed feature map.
        hidden: size of the incoming token embeddings.
        kernel_size, stride, padding: fold parameters.
    """

    def __init__(self, channel, hidden, kernel_size, stride, padding):
        super().__init__()
        self.relu = nn.LeakyReLU(0.2, inplace=True)
        patch_area = 1
        for extent in kernel_size:
            patch_area = patch_area * extent
        self.embedding = nn.Linear(hidden, patch_area * channel)
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.bias_conv = nn.Conv2d(channel,
                                   channel,
                                   kernel_size=3,
                                   stride=1,
                                   padding=1)

    def forward(self, x, t, output_size):
        """Fold tokens into ``b * t`` feature maps of size ``output_size``.

        Args:
            x: 5-D token tensor; the first dim is the batch and the last
                the embedding size, everything in between is flattened.
            t: number of frames per batch element.
            output_size: spatial size handed to ``F.fold``.

        Returns:
            Folded feature map after the 3x3 ``bias_conv``.
        """
        b_, _, _, _, c_ = x.shape
        tokens = self.embedding(x.view(b_, -1, c_))
        b, _, c = tokens.size()
        # One column per patch, as F.fold expects: [b*t, c, num_vec].
        patches = tokens.view(b * t, -1, c).permute(0, 2, 1)
        folded = F.fold(patches,
                        output_size=output_size,
                        kernel_size=self.kernel_size,
                        stride=self.stride,
                        padding=self.padding)
        return self.bias_conv(folded)
62
+
63
+
64
class FusionFeedForward(nn.Module):
    """Feed-forward block that blends overlapping patches between its two layers.

    Between ``fc1`` and ``fc2`` the hidden tokens are folded into a
    feature map, divided by the per-pixel overlap count, and unfolded
    back into tokens.

    Args:
        dim: token size on input and output.
        hidden_dim: hidden size (1960 by default).
        t2t_params: dict with ``kernel_size``, ``padding`` and ``stride``;
            required (asserted not None).
    """

    def __init__(self, dim, hidden_dim=1960, t2t_params=None):
        super().__init__()
        self.fc1 = nn.Sequential(nn.Linear(dim, hidden_dim))
        self.fc2 = nn.Sequential(nn.GELU(), nn.Linear(hidden_dim, dim))
        assert t2t_params is not None
        self.t2t_params = t2t_params
        # Number of pixels in one patch.
        patch_area = 1
        for extent in t2t_params['kernel_size']:
            patch_area = patch_area * extent
        self.kernel_shape = patch_area

    def forward(self, x, output_size):
        """Apply the block to tokens ``x`` of shape ``(b, n, dim)``.

        ``output_size`` is the spatial size of the folded feature map;
        the result has the same shape as ``x``.
        """
        params = self.t2t_params

        # Tokens per frame, from the fold geometry.
        n_vecs = 1
        for axis, extent in enumerate(params['kernel_size']):
            span = output_size[axis] + 2 * params['padding'][axis] - (extent - 1) - 1
            n_vecs *= int(span / params['stride'][axis] + 1)

        hidden = self.fc1(x)
        b, n, c = hidden.size()

        # Folding all-ones patches counts how many patches cover each pixel.
        ones = hidden.new_ones(b, n, self.kernel_shape)
        ones = ones.view(-1, n_vecs, self.kernel_shape).permute(0, 2, 1)
        normalizer = F.fold(ones,
                            output_size=output_size,
                            kernel_size=params['kernel_size'],
                            padding=params['padding'],
                            stride=params['stride'])

        folded = F.fold(hidden.view(-1, n_vecs, c).permute(0, 2, 1),
                        output_size=output_size,
                        kernel_size=params['kernel_size'],
                        padding=params['padding'],
                        stride=params['stride'])

        blended = F.unfold(folded / normalizer,
                           kernel_size=params['kernel_size'],
                           padding=params['padding'],
                           stride=params['stride'])
        blended = blended.permute(0, 2, 1).contiguous().view(b, n, c)
        return self.fc2(blended)
102
+
103
+
104
def window_partition(x, window_size, n_head):
    """Split a token grid into non-overlapping windows and attention heads.

    Args:
        x: shape is (B, T, H, W, C); H and W must be multiples of the
            window size and C a multiple of ``n_head``.
        window_size (tuple[int]): window size
        n_head: number of attention heads.

    Returns:
        windows: contiguous tensor of shape
            (B, num_windows_h, num_windows_w, n_head, T, window_size, window_size, C//n_head)
    """
    B, T, H, W, C = x.shape
    win_h, win_w = window_size[0], window_size[1]
    tiled = x.view(B, T, H // win_h, win_h, W // win_w, win_w, n_head, C // n_head)
    return tiled.permute(0, 2, 4, 6, 1, 3, 5, 7).contiguous()


class SparseWindowAttention(nn.Module):
    """Window attention that spends extra context only on masked windows.

    Windows touched by the mask attend, across all frames, to their own
    tokens, to tokens rolled in from neighbouring windows, and
    (optionally) to pooled tokens of the whole map. Windows the mask
    never touches attend only to their own tokens, frame by frame.

    Args:
        dim: token size; must be divisible by ``n_head``.
        n_head: number of attention heads.
        window_size: (height, width) of one window.
        pool_size: kernel and stride of the depth-wise pooling conv.
        qkv_bias: whether the q/k/v projections carry a bias.
        attn_drop: dropout rate on the attention weights.
        proj_drop: dropout rate after the output projection.
        pooling_token: whether pooled tokens are appended to keys/values.
    """

    def __init__(self, dim, n_head, window_size, pool_size=(4, 4), qkv_bias=True, attn_drop=0., proj_drop=0.,
                 pooling_token=True):
        super().__init__()
        assert dim % n_head == 0
        # key, query, value projections for all heads
        self.key = nn.Linear(dim, dim, qkv_bias)
        self.query = nn.Linear(dim, dim, qkv_bias)
        self.value = nn.Linear(dim, dim, qkv_bias)
        # regularization
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj_drop = nn.Dropout(proj_drop)
        # output projection
        self.proj = nn.Linear(dim, dim)
        self.n_head = n_head
        self.window_size = window_size
        self.pooling_token = pooling_token
        if self.pooling_token:
            # Depth-wise conv initialised as an average pool.
            ks, stride = pool_size, pool_size
            self.pool_layer = nn.Conv2d(dim, dim, kernel_size=ks, stride=stride, padding=(0, 0), groups=dim)
            self.pool_layer.weight.data.fill_(1. / (pool_size[0] * pool_size[1]))
            self.pool_layer.bias.data.fill_(0)
        self.expand_size = tuple((i + 1) // 2 for i in window_size)

        if any(i > 0 for i in self.expand_size):
            # Masks selecting, for each of the four rolled copies of k/v,
            # the positions kept as extra tokens for a window.
            mask_tl = torch.ones(self.window_size[0], self.window_size[1])
            mask_tl[:-self.expand_size[0], :-self.expand_size[1]] = 0
            mask_tr = torch.ones(self.window_size[0], self.window_size[1])
            mask_tr[:-self.expand_size[0], self.expand_size[1]:] = 0
            mask_bl = torch.ones(self.window_size[0], self.window_size[1])
            mask_bl[self.expand_size[0]:, :-self.expand_size[1]] = 0
            mask_br = torch.ones(self.window_size[0], self.window_size[1])
            mask_br[self.expand_size[0]:, self.expand_size[1]:] = 0
            masrool_k = torch.stack((mask_tl, mask_tr, mask_bl, mask_br), 0).flatten(0)
            self.register_buffer("valid_ind_rolled", masrool_k.nonzero(as_tuple=False).view(-1))

        # Reduces the pixel mask to one value per window.
        self.max_pool = nn.MaxPool2d(window_size, window_size, (0, 0))

    def forward(self, x, mask=None, T_ind=None, attn_mask=None):
        """Run sparse window attention.

        Args:
            x: tokens of shape (b, t, h, w, c).
            mask: mask tokens viewable as (b * l_t, h, w), e.g. shape
                (b, l_t, h, w, 1). Required despite the ``None`` default.
            T_ind: optional 1-D index tensor of the frames whose keys and
                values masked windows attend to; ``None`` uses all frames.
            attn_mask: unused.

        Returns:
            Tensor of shape (b, t, h, w, c).

        Raises:
            ValueError: if ``mask`` is None.
        """
        if mask is None:
            raise ValueError('SparseWindowAttention.forward requires a mask tensor.')

        b, t, h, w, c = x.shape
        w_h, w_w = self.window_size[0], self.window_size[1]
        c_head = c // self.n_head
        n_wh = math.ceil(h / w_h)
        n_ww = math.ceil(w / w_w)
        n_win = n_wh * n_ww
        new_h = n_wh * w_h
        new_w = n_ww * w_w
        pad_r = new_w - w
        pad_b = new_h - h
        # F.pad lists dims in reverse order: (c, w, h, t).
        if pad_r > 0 or pad_b > 0:
            x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b, 0, 0), mode='constant', value=0)
            mask = F.pad(mask, (0, 0, 0, pad_r, 0, pad_b, 0, 0), mode='constant', value=0)

        def to_windows(feat):
            # -> [b, n_wh*n_ww, n_head, t, w_h*w_w, c_head]
            parts = window_partition(feat.contiguous(), self.window_size, self.n_head)
            return parts.view(b, n_win, self.n_head, t, w_h * w_w, c_head)

        # calculate query, key, values for all heads
        q = self.query(x)
        k = self.key(x)
        v = self.value(x)
        win_q = to_windows(q)
        win_k = to_windows(k)
        win_v = to_windows(v)

        # roll_k and roll_v: tokens from the four diagonal neighbourhoods
        if any(i > 0 for i in self.expand_size):
            e_h, e_w = self.expand_size
            rolled_k, rolled_v = [], []
            # order: top-left, top-right, bottom-left, bottom-right
            for shift in ((-e_h, -e_w), (-e_h, e_w), (e_h, -e_w), (e_h, e_w)):
                rolled_k.append(to_windows(torch.roll(k, shifts=shift, dims=(2, 3))))
                rolled_v.append(to_windows(torch.roll(v, shifts=shift, dims=(2, 3))))
            rool_k = torch.cat(rolled_k, 4).contiguous()
            rool_v = torch.cat(rolled_v, 4).contiguous()
            # keep only the positions selected at construction time
            rool_k = rool_k[:, :, :, :, self.valid_ind_rolled]
            rool_v = rool_v[:, :, :, :, self.valid_ind_rolled]
            win_k = torch.cat((win_k, rool_k), dim=4)
            win_v = torch.cat((win_v, rool_v), dim=4)

        # pool_k and pool_v: coarse tokens shared by every window
        if self.pooling_token:
            pool_x = self.pool_layer(x.view(b * t, new_h, new_w, c).permute(0, 3, 1, 2))
            _, _, p_h, p_w = pool_x.shape
            pool_x = pool_x.permute(0, 2, 3, 1).view(b, t, p_h, p_w, c)

            def to_pooled(projection):
                tok = projection(pool_x).unsqueeze(1).repeat(1, n_win, 1, 1, 1, 1)  # [b, n_win, t, p_h, p_w, c]
                tok = tok.view(b, n_win, t, p_h, p_w, self.n_head, c_head).permute(0, 1, 5, 2, 3, 4, 6)
                return tok.contiguous().view(b, n_win, self.n_head, t, p_h * p_w, c_head)

            win_k = torch.cat((win_k, to_pooled(self.key)), dim=4)
            win_v = torch.cat((win_v, to_pooled(self.value)), dim=4)

        # [b, n_wh*n_ww, n_head, t, w_h*w_w, c_head]
        out = torch.zeros_like(win_q)
        l_t = mask.size(1)

        # Per-window mask score summed over frames: non-zero means the
        # window is touched by the mask in at least one frame.
        mask = self.max_pool(mask.view(b * l_t, new_h, new_w))
        mask = mask.view(b, l_t, n_win)
        mask = torch.sum(mask, dim=1)  # [b, n_wh*n_ww]
        for i in range(win_q.shape[0]):
            ### For masked windows: attend across frames with the extended key set
            mask_ind_i = mask[i].nonzero(as_tuple=False).view(-1)
            mask_n = len(mask_ind_i)
            if mask_n > 0:
                win_q_t = win_q[i, mask_ind_i].view(mask_n, self.n_head, t * w_h * w_w, c_head)
                win_k_t = win_k[i, mask_ind_i]
                win_v_t = win_v[i, mask_ind_i]
                if T_ind is not None:
                    # restrict keys/values to the selected frames
                    win_k_t = win_k_t[:, :, T_ind.view(-1)].view(mask_n, self.n_head, -1, c_head)
                    win_v_t = win_v_t[:, :, T_ind.view(-1)].view(mask_n, self.n_head, -1, c_head)
                else:
                    # Flatten (frame, token). The window count is mask_n
                    # and the per-frame key count includes the rolled and
                    # pooled tokens, so the flattened length is inferred.
                    win_k_t = win_k_t.view(mask_n, self.n_head, -1, c_head)
                    win_v_t = win_v_t.view(mask_n, self.n_head, -1, c_head)

                att_t = (win_q_t @ win_k_t.transpose(-2, -1)) * (1.0 / math.sqrt(win_q_t.size(-1)))
                att_t = F.softmax(att_t, dim=-1)
                att_t = self.attn_drop(att_t)
                y_t = att_t @ win_v_t

                out[i, mask_ind_i] = y_t.view(-1, self.n_head, t, w_h * w_w, c_head)

            ### For unmasked windows: per-frame attention over the window's own tokens
            unmask_ind_i = (mask[i] == 0).nonzero(as_tuple=False).view(-1)
            win_q_s = win_q[i, unmask_ind_i]
            win_k_s = win_k[i, unmask_ind_i, :, :, :w_h * w_w]
            win_v_s = win_v[i, unmask_ind_i, :, :, :w_h * w_w]

            att_s = (win_q_s @ win_k_s.transpose(-2, -1)) * (1.0 / math.sqrt(win_q_s.size(-1)))
            att_s = F.softmax(att_s, dim=-1)
            att_s = self.attn_drop(att_s)
            y_s = att_s @ win_v_s
            out[i, unmask_ind_i] = y_s

        # re-assemble all head outputs side by side
        out = out.view(b, n_wh, n_ww, self.n_head, t, w_h, w_w, c_head)
        out = out.permute(0, 4, 1, 5, 2, 6, 3, 7).contiguous().view(b, t, new_h, new_w, c)

        # drop the rows/columns added by padding
        if pad_r > 0 or pad_b > 0:
            out = out[:, :, :h, :w, :]

        # output projection
        out = self.proj_drop(self.proj(out))
        return out
282
+
283
+
284
class TemporalSparseTransformer(nn.Module):
    """One pre-norm transformer layer: sparse window attention followed by a fusion feed-forward."""

    def __init__(self, dim, n_head, window_size, pool_size,
                 norm_layer=nn.LayerNorm, t2t_params=None):
        super().__init__()
        self.window_size = window_size
        self.attention = SparseWindowAttention(dim, n_head, window_size, pool_size)
        self.norm1 = norm_layer(dim)
        self.norm2 = norm_layer(dim)
        self.mlp = FusionFeedForward(dim, t2t_params=t2t_params)

    def forward(self, x, fold_x_size, mask=None, T_ind=None):
        """
        Args:
            x: image tokens, shape [B T H W C]
            fold_x_size: fold feature size, shape [60 108]
            mask: mask tokens, shape [B T H W 1]
            T_ind: frame indices forwarded to the attention module
        Returns:
            out_tokens: shape [B T H W C]
        """
        B, T, H, W, C = x.shape

        # Attention sub-layer with residual connection.
        attended = x + self.attention(self.norm1(x), mask, T_ind)

        # Feed-forward sub-layer; it works on tokens flattened to [B, T*H*W, C].
        tokens = self.norm2(attended).view(B, T * H * W, C)
        return attended + self.mlp(tokens, fold_x_size).view(B, T, H, W, C)
315
+
316
+
317
class TemporalSparseTransformerBlock(nn.Module):
    """Stack of ``depths`` TemporalSparseTransformer layers with interleaved temporal strides."""

    def __init__(self, dim, n_head, window_size, pool_size, depths, t2t_params=None):
        super().__init__()
        layers = [
            TemporalSparseTransformer(dim, n_head, window_size, pool_size, t2t_params=t2t_params)
            for _ in range(depths)
        ]
        self.transformer = nn.Sequential(*layers)
        self.depths = depths

    def forward(self, x, fold_x_size, l_mask=None, t_dilation=2):
        """
        Args:
            x: image tokens, shape [B T H W C]
            fold_x_size: fold feature size, shape [60 108]
            l_mask: local mask tokens, shape [B T H W 1]
            t_dilation: temporal stride; layer ``i`` receives the frame
                indices ``i % t_dilation, i % t_dilation + t_dilation, ...``.
                Must divide ``depths``.
        Returns:
            out_tokens: shape [B T H W C]
        """
        assert self.depths % t_dilation == 0, 'wrong t_dilation input.'
        T = x.size(1)
        # One index set per temporal offset, cycled over the layers.
        offsets = [torch.arange(start, T, t_dilation) for start in range(t_dilation)]
        T_ind = offsets * (self.depths // t_dilation)

        for layer, frame_idx in zip(self.transformer, T_ind):
            x = layer(x, fold_x_size, l_mask, frame_idx)

        return x
model/modules/spectral_norm.py ADDED
@@ -0,0 +1,288 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Spectral Normalization from https://arxiv.org/abs/1802.05957
3
+ """
4
+ import torch
5
+ from torch.nn.functional import normalize
6
+
7
+
8
class SpectralNorm(object):
    """Forward pre-hook that divides a module weight by its estimated spectral norm.

    The raw weight is kept as the parameter ``<name>_orig``; the power-iteration
    vectors are kept as the buffers ``<name>_u`` and ``<name>_v``. Before every
    forward call the hook recomputes ``<name>`` as ``<name>_orig / sigma``.
    """

    # Invariant before and after each forward call:
    #   u = normalize(W @ v)
    # NB: At initialization, this invariant is not enforced

    _version = 1

    # At version 1:
    #   made  `W` not a buffer,
    #   added `v` as a buffer, and
    #   made eval mode use `W = u @ W_orig @ v` rather than the stored `W`.

    def __init__(self, name='weight', n_power_iterations=1, dim=0, eps=1e-12):
        """Store the hook configuration.

        Raises:
            ValueError: if ``n_power_iterations`` is not positive.
        """
        self.name = name
        self.dim = dim
        if n_power_iterations <= 0:
            raise ValueError(
                'Expected n_power_iterations to be positive, but '
                'got n_power_iterations={}'.format(n_power_iterations))
        self.n_power_iterations = n_power_iterations
        self.eps = eps

    def reshape_weight_to_matrix(self, weight):
        """Return ``weight`` as a 2-D matrix whose rows run along ``self.dim``."""
        weight_mat = weight
        if self.dim != 0:
            # permute dim to front
            weight_mat = weight_mat.permute(
                self.dim,
                *[d for d in range(weight_mat.dim()) if d != self.dim])
        height = weight_mat.size(0)
        return weight_mat.reshape(height, -1)

    def compute_weight(self, module, do_power_iteration):
        """Return ``<name>_orig / sigma`` for ``module``.

        When ``do_power_iteration`` is true, the ``u`` and ``v`` buffers are
        refined in place first.
        """
        # NB: If `do_power_iteration` is set, the `u` and `v` vectors are
        #     updated in power iteration **in-place**. This is very important
        #     because in `DataParallel` forward, the vectors (being buffers) are
        #     broadcast from the parallelized module to each module replica,
        #     which is a new module object created on the fly. And each replica
        #     runs its own spectral norm power iteration. So simply assigning
        #     the updated vectors to the module this function runs on will cause
        #     the update to be lost forever. And the next time the parallelized
        #     module is replicated, the same randomly initialized vectors are
        #     broadcast and used!
        #
        #     Therefore, to make the change propagate back, we rely on two
        #     important behaviors (also enforced via tests):
        #       1. `DataParallel` doesn't clone storage if the broadcast tensor
        #          is already on correct device; and it makes sure that the
        #          parallelized module is already on `device[0]`.
        #       2. If the out tensor in `out=` kwarg has correct shape, it will
        #          just fill in the values.
        #     Therefore, since the same power iteration is performed on all
        #     devices, simply updating the tensors in-place will make sure that
        #     the module replica on `device[0]` will update the _u vector on the
        #     parallelized module (by shared storage).
        #
        #    However, after we update `u` and `v` in-place, we need to **clone**
        #    them before using them to normalize the weight. This is to support
        #    backproping through two forward passes, e.g., the common pattern in
        #    GAN training: loss = D(real) - D(fake). Otherwise, engine will
        #    complain that variables needed to do backward for the first forward
        #    (i.e., the `u` and `v` vectors) are changed in the second forward.
        weight = getattr(module, self.name + '_orig')
        u = getattr(module, self.name + '_u')
        v = getattr(module, self.name + '_v')
        weight_mat = self.reshape_weight_to_matrix(weight)

        if do_power_iteration:
            with torch.no_grad():
                for _ in range(self.n_power_iterations):
                    # Spectral norm of weight equals to `u^T W v`, where `u` and `v`
                    # are the first left and right singular vectors.
                    # This power iteration produces approximations of `u` and `v`.
                    v = normalize(torch.mv(weight_mat.t(), u),
                                  dim=0,
                                  eps=self.eps,
                                  out=v)
                    u = normalize(torch.mv(weight_mat, v),
                                  dim=0,
                                  eps=self.eps,
                                  out=u)
                if self.n_power_iterations > 0:
                    # See above on why we need to clone
                    u = u.clone()
                    v = v.clone()

        sigma = torch.dot(u, torch.mv(weight_mat, v))
        weight = weight / sigma
        return weight

    def remove(self, module):
        """Undo the reparameterization, leaving the normalized weight as a plain parameter."""
        with torch.no_grad():
            weight = self.compute_weight(module, do_power_iteration=False)
        delattr(module, self.name)
        delattr(module, self.name + '_u')
        delattr(module, self.name + '_v')
        delattr(module, self.name + '_orig')
        module.register_parameter(self.name,
                                  torch.nn.Parameter(weight.detach()))

    def __call__(self, module, inputs):
        """Forward pre-hook entry point: refresh ``module.<name>``.

        Power iteration runs only while the module is in training mode.
        """
        setattr(
            module, self.name,
            self.compute_weight(module, do_power_iteration=module.training))

    def _solve_v_and_rescale(self, weight_mat, u, target_sigma):
        # Tries to return a vector `v` s.t. `u = normalize(W @ v)`
        # (the invariant at top of this class) and `u @ W @ v = sigma`.
        # This uses pinverse in case W^T W is not invertible.
        #
        # v = (W^T W)^+ W^T u, written with plain mm/mv instead of the
        # deprecated `torch.chain_matmul`.
        gram_pinv = weight_mat.t().mm(weight_mat).pinverse()
        v = gram_pinv.mm(weight_mat.t()).mv(u)
        return v.mul_(target_sigma / torch.dot(u, torch.mv(weight_mat, v)))

    @staticmethod
    def apply(module, name, n_power_iterations, dim, eps):
        """Install spectral normalization for parameter ``name`` on ``module``.

        Returns:
            The ``SpectralNorm`` hook instance that was registered.

        Raises:
            RuntimeError: if a spectral-norm hook for ``name`` already exists.
            ValueError: if the parameter ``name`` is ``None``.
        """
        for k, hook in module._forward_pre_hooks.items():
            if isinstance(hook, SpectralNorm) and hook.name == name:
                raise RuntimeError(
                    "Cannot register two spectral_norm hooks on "
                    "the same parameter {}".format(name))

        fn = SpectralNorm(name, n_power_iterations, dim, eps)
        weight = module._parameters[name]
        if weight is None:
            # Fail with a clear message instead of an AttributeError further down.
            raise ValueError(
                "`SpectralNorm` cannot be applied as parameter "
                "`{}` is None".format(name))

        with torch.no_grad():
            weight_mat = fn.reshape_weight_to_matrix(weight)

            h, w = weight_mat.size()
            # randomly initialize `u` and `v`
            u = normalize(weight.new_empty(h).normal_(0, 1), dim=0, eps=fn.eps)
            v = normalize(weight.new_empty(w).normal_(0, 1), dim=0, eps=fn.eps)

        delattr(module, fn.name)
        module.register_parameter(fn.name + "_orig", weight)
        # We still need to assign weight back as fn.name because all sorts of
        # things may assume that it exists, e.g., when initializing weights.
        # However, we can't directly assign as it could be an nn.Parameter and
        # gets added as a parameter. Instead, we register weight.data as a plain
        # attribute.
        setattr(module, fn.name, weight.data)
        module.register_buffer(fn.name + "_u", u)
        module.register_buffer(fn.name + "_v", v)

        module.register_forward_pre_hook(fn)

        module._register_state_dict_hook(SpectralNormStateDictHook(fn))
        module._register_load_state_dict_pre_hook(
            SpectralNormLoadStateDictPreHook(fn))
        return fn
157
+
158
+
159
+ # This is a top level class because Py2 pickle doesn't like inner class nor an
160
+ # instancemethod.
161
class SpectralNormLoadStateDictPreHook(object):
    """Load-state-dict pre-hook registered alongside a ``SpectralNorm`` hook.

    See the comments next to ``SpectralNorm._version`` for the format changes.
    """

    def __init__(self, fn):
        self.fn = fn

    # For state_dict with version None, (assuming that it has gone through at
    # least one training forward), we have
    #
    #    u = normalize(W_orig @ v)
    #    W = W_orig / sigma, where sigma = u @ W_orig @ v
    #
    # To compute `v`, we solve `W_orig @ x = u`, and let
    #    v = x / (u @ W_orig @ x) * (W / W_orig).
    def __call__(self, state_dict, prefix, local_metadata, strict,
                 missing_keys, unexpected_keys, error_msgs):
        """Inspect the incoming state dict; never modifies it.

        The reconstruction of ``v`` for pre-version-1 checkpoints described
        above is disabled in this copy, so ``<name>_v`` is taken from the
        state dict as stored. No keys are read here, which means an incomplete
        checkpoint is reported through the regular missing-key handling of the
        caller rather than failing in this hook with a ``KeyError``.
        """
        fn = self.fn
        version = local_metadata.get('spectral_norm',
                                     {}).get(fn.name + '.version', None)
        if version is None or version < 1:
            # Legacy / metadata-less checkpoint: intentionally nothing to do.
            return
188
+
189
+
190
+ # This is a top level class because Py2 pickle doesn't like inner class nor an
191
+ # instancemethod.
192
class SpectralNormStateDictHook(object):
    """State-dict hook that stamps the spectral-norm format version into the metadata.

    See the comments next to ``SpectralNorm._version`` for the format changes.
    """

    def __init__(self, fn):
        self.fn = fn

    def __call__(self, module, state_dict, prefix, local_metadata):
        # All spectral-norm versions for this module live under one sub-dict.
        versions = local_metadata.setdefault('spectral_norm', {})
        key = self.fn.name + '.version'
        if key in versions:
            raise RuntimeError(
                "Unexpected key in metadata['spectral_norm']: {}".format(key))
        versions[key] = self.fn._version
205
+
206
+
207
def spectral_norm(module,
                  name='weight',
                  n_power_iterations=1,
                  eps=1e-12,
                  dim=None):
    r"""Attach spectral normalization to one parameter of ``module``.

    .. math::
        \mathbf{W}_{SN} = \dfrac{\mathbf{W}}{\sigma(\mathbf{W})},
        \sigma(\mathbf{W}) = \max_{\mathbf{h}: \mathbf{h} \ne 0} \dfrac{\|\mathbf{W} \mathbf{h}\|_2}{\|\mathbf{h}\|_2}

    The weight is divided by its spectral norm :math:`\sigma`, which is
    estimated with power iteration. A weight with more than two dimensions is
    reshaped to 2D for that estimate. The rescaling is done by a hook that
    runs before every :meth:`~Module.forward` call.

    See `Spectral Normalization for Generative Adversarial Networks`_ .

    .. _`Spectral Normalization for Generative Adversarial Networks`: https://arxiv.org/abs/1802.05957

    Args:
        module (nn.Module): containing module
        name (str, optional): name of weight parameter
        n_power_iterations (int, optional): number of power iterations used
            to estimate the spectral norm
        eps (float, optional): epsilon for numerical stability when
            normalizing vectors
        dim (int, optional): dimension corresponding to number of outputs;
            when ``None`` it is ``1`` for ConvTranspose{1,2,3}d modules and
            ``0`` for everything else

    Returns:
        The same module, with the spectral norm hook registered

    Example::

        >>> m = spectral_norm(nn.Linear(20, 40))
        >>> m
        Linear(in_features=20, out_features=40, bias=True)
        >>> m.weight_u.size()
        torch.Size([40])

    """
    if dim is None:
        transposed_convs = (torch.nn.ConvTranspose1d,
                            torch.nn.ConvTranspose2d,
                            torch.nn.ConvTranspose3d)
        dim = 1 if isinstance(module, transposed_convs) else 0
    SpectralNorm.apply(module, name, n_power_iterations, dim, eps)
    return module
262
+
263
+
264
def remove_spectral_norm(module, name='weight'):
    r"""Strip the spectral normalization reparameterization from a module.

    Args:
        module (Module): containing module
        name (str, optional): name of weight parameter

    Returns:
        The same module, with the hook for ``name`` removed.

    Raises:
        ValueError: if no spectral-norm hook for ``name`` is registered.

    Example:
        >>> m = spectral_norm(nn.Linear(40, 10))
        >>> remove_spectral_norm(m)
    """
    for hook_id, hook in module._forward_pre_hooks.items():
        if not isinstance(hook, SpectralNorm) or hook.name != name:
            continue
        hook.remove(module)
        # Safe to delete during iteration: we return immediately afterwards.
        del module._forward_pre_hooks[hook_id]
        return module

    raise ValueError("spectral_norm of '{}' not found in {}".format(
        name, module))
283
+
284
+
285
def use_spectral_norm(module, use_sn=False):
    """Return ``module`` wrapped with spectral norm when ``use_sn`` is truthy, else unchanged."""
    return spectral_norm(module) if use_sn else module
model/propainter.py ADDED
@@ -0,0 +1,532 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ''' Towards An End-to-End Framework for Video Inpainting
2
+ '''
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ import torchvision
8
+
9
+ from einops import rearrange
10
+
11
+ from model.modules.base_module import BaseNetwork
12
+ from model.modules.sparse_transformer import TemporalSparseTransformerBlock, SoftSplit, SoftComp
13
+ from model.modules.spectral_norm import spectral_norm as _spectral_norm
14
+ from model.modules.flow_loss_utils import flow_warp
15
+ from model.modules.deformconv import ModulatedDeformConv2d
16
+
17
+ from .misc import constant_init
18
+
19
def length_sq(x):
    """Return the sum of squares of ``x`` over dim 1, keeping that dim with size 1."""
    return x.square().sum(dim=1, keepdim=True)
21
+
22
def fbConsistencyCheck(flow_fw, flow_bw, alpha1=0.01, alpha2=0.5):
    """Forward-backward consistency mask for a pair of flows.

    The backward flow is warped by the forward flow and added to it. A
    position counts as consistent when the squared magnitude of that sum is
    below ``alpha1 * (|wf|^2 + |wb(wf(x))|^2) + alpha2``.

    Args:
        flow_fw: forward flow; presumably [B, 2, H, W] (it is permuted with
            (0, 2, 3, 1) and reduced over dim 1) -- confirm against callers
        flow_bw: backward flow with the matching layout
        alpha1: weight of the magnitude term in the threshold
        alpha2: constant term of the threshold

    Returns:
        Tensor with the dtype/device of ``flow_fw`` and a size-1 dim 1:
        1 where the flows are consistent, 0 elsewhere.
    """
    warped_bw = flow_warp(flow_bw, flow_fw.permute(0, 2, 3, 1))  # wb(wf(x))
    residual = flow_fw + warped_bw  # wf + wb(wf(x))

    # |wf|^2 + |wb(wf(x))|^2
    magnitude_sq = length_sq(flow_fw) + length_sq(warped_bw)
    threshold = alpha1 * magnitude_sq + alpha2

    return (length_sq(residual) < threshold).to(flow_fw)
32
+
33
+
34
class DeformableAlignment(ModulatedDeformConv2d):
    """Second-order deformable alignment module."""

    def __init__(self, *args, **kwargs):
        # Popped before the parent constructor sees the keyword arguments.
        self.max_residue_magnitude = kwargs.pop('max_residue_magnitude', 3)

        super(DeformableAlignment, self).__init__(*args, **kwargs)

        ch = self.out_channels
        # Condition input concatenated by the caller:
        # two feature maps (2 * ch) plus 2 + 1 + 2 extra channels.
        cond_channels = 2 * ch + 2 + 1 + 2
        self.conv_offset = nn.Sequential(
            nn.Conv2d(cond_channels, ch, 3, 1, 1),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.Conv2d(ch, ch, 3, 1, 1),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.Conv2d(ch, ch, 3, 1, 1),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.Conv2d(ch, 27 * self.deform_groups, 3, 1, 1),
        )
        self.init_offset()

    def init_offset(self):
        """Zero-initialise the last offset conv (weights and bias)."""
        constant_init(self.conv_offset[-1], val=0, bias=0)

    def forward(self, x, cond_feat, flow):
        """Deformable convolution of ``x`` with offsets predicted from ``cond_feat``.

        The predicted offsets are bounded by ``max_residue_magnitude`` through
        a tanh and then added to the channel-flipped ``flow``, tiled to the
        offset channel count.
        """
        o1, o2, mask = torch.chunk(self.conv_offset(cond_feat), 3, dim=1)

        # Bounded residual offset on top of the flow.
        residual = self.max_residue_magnitude * torch.tanh(torch.cat((o1, o2), dim=1))
        flow_base = flow.flip(1).repeat(1, residual.size(1) // 2, 1, 1)
        offset = residual + flow_base

        # Modulation mask in (0, 1).
        modulation = torch.sigmoid(mask)

        return torchvision.ops.deform_conv2d(x, offset, self.weight, self.bias,
                                             self.stride, self.padding,
                                             self.dilation, modulation)
70
+
71
+
72
class BidirectionalPropagation(nn.Module):
    """Propagate features along optical flow, backward in time then forward.

    With ``learnable=True`` the warped features are refined by deformable
    alignment and small conv backbones. With ``learnable=False`` the module
    has no sub-networks and blends warped and current features using masks.
    """

    def __init__(self, channel, learnable=True):
        super(BidirectionalPropagation, self).__init__()
        self.deform_align = nn.ModuleDict()
        self.backbone = nn.ModuleDict()
        self.channel = channel
        # Propagation passes, in execution order; also the ModuleDict keys.
        self.prop_list = ['backward_1', 'forward_1']
        self.learnable = learnable

        if self.learnable:
            for module in self.prop_list:
                self.deform_align[module] = DeformableAlignment(
                    channel, channel, 3, padding=1, deform_groups=16)

                self.backbone[module] = nn.Sequential(
                    nn.Conv2d(2*channel+2, channel, 3, 1, 1),
                    nn.LeakyReLU(negative_slope=0.2, inplace=True),
                    nn.Conv2d(channel, channel, 3, 1, 1),
                )

            self.fuse = nn.Sequential(
                nn.Conv2d(2*channel+2, channel, 3, 1, 1),
                nn.LeakyReLU(negative_slope=0.2, inplace=True),
                nn.Conv2d(channel, channel, 3, 1, 1),
            )

    def binary_mask(self, mask, th=0.1):
        """Binarise ``mask`` in place (1 where > ``th``, else 0) and return it."""
        mask[mask>th] = 1
        mask[mask<=th] = 0
        return mask.to(mask)

    def forward(self, x, flows_forward, flows_backward, mask, interpolation='bilinear'):
        """Run the backward and forward propagation passes.

        Args:
            x: features, shape [b, t, c, h, w]
            flows_forward: used to warp features in the backward pass and as
                the consistency check in the forward pass; indexed along dim 1
            flows_backward: the counterpart of ``flows_forward``
            mask: per-frame masks indexed along dim 1; in the learnable path
                it is viewed as 2 channels per frame
            interpolation: passed to ``flow_warp`` when warping features

        Returns:
            Tuple ``(outputs_b, outputs_f, outputs, masks_f)``. The first
            three are [b, t, c, h, w]. ``masks_f`` is the stacked forward-pass
            masks in the non-learnable path and ``None`` in the learnable path.
        """
        # For backward warping
        # pred_flows_forward for backward feature propagation
        # pred_flows_backward for forward feature propagation
        b, t, c, h, w = x.shape
        feats = {'input': [x[:, i, :, :, :] for i in range(0, t)]}
        masks = {'input': [mask[:, i, :, :, :] for i in range(0, t)]}

        # Each pass reads the results of the previous one.
        cache_list = ['input'] + self.prop_list

        for p_i, module_name in enumerate(self.prop_list):
            feats[module_name] = []
            masks[module_name] = []

            if 'backward' in module_name:
                frame_idx = range(0, t)
                frame_idx = frame_idx[::-1]
                flow_idx = frame_idx
                flows_for_prop = flows_forward
                flows_for_check = flows_backward
            else:
                frame_idx = range(0, t)
                flow_idx = range(-1, t - 1)
                flows_for_prop = flows_backward
                flows_for_check = flows_forward

            for i, idx in enumerate(frame_idx):
                feat_current = feats[cache_list[p_i]][idx]
                mask_current = masks[cache_list[p_i]][idx]

                if i == 0:
                    # First frame of the pass: nothing to propagate from yet.
                    feat_prop = feat_current
                    mask_prop = mask_current
                else:
                    flow_prop = flows_for_prop[:, flow_idx[i], :, :, :]
                    flow_check = flows_for_check[:, flow_idx[i], :, :, :]
                    flow_vaild_mask = fbConsistencyCheck(flow_prop, flow_check)
                    feat_warped = flow_warp(feat_prop, flow_prop.permute(0, 2, 3, 1), interpolation)

                    if self.learnable:
                        cond = torch.cat([feat_current, feat_warped, flow_prop, flow_vaild_mask, mask_current], dim=1)
                        feat_prop = self.deform_align[module_name](feat_prop, cond, flow_prop)
                        mask_prop = mask_current
                    else:
                        mask_prop_valid = flow_warp(mask_prop, flow_prop.permute(0, 2, 3, 1))
                        mask_prop_valid = self.binary_mask(mask_prop_valid)

                        # Take the warped value only where the current mask is
                        # set, the flow is consistent and the warped mask is clear.
                        union_vaild_mask = self.binary_mask(mask_current*flow_vaild_mask*(1-mask_prop_valid))
                        feat_prop = union_vaild_mask * feat_warped + (1-union_vaild_mask) * feat_current
                        # update mask
                        mask_prop = self.binary_mask(mask_current*(1-(flow_vaild_mask*(1-mask_prop_valid))))

                # refine
                if self.learnable:
                    feat = torch.cat([feat_current, feat_prop, mask_current], dim=1)
                    feat_prop = feat_prop + self.backbone[module_name](feat)

                feats[module_name].append(feat_prop)
                masks[module_name].append(mask_prop)

            # The backward pass ran in reverse; restore chronological order.
            if 'backward' in module_name:
                feats[module_name] = feats[module_name][::-1]
                masks[module_name] = masks[module_name][::-1]

        outputs_b = torch.stack(feats['backward_1'], dim=1).view(-1, c, h, w)
        outputs_f = torch.stack(feats['forward_1'], dim=1).view(-1, c, h, w)

        if self.learnable:
            mask_in = mask.view(-1, 2, h, w)
            masks_f = None
            outputs = self.fuse(torch.cat([outputs_b, outputs_f, mask_in], dim=1)) + x.view(-1, c, h, w)
        else:
            masks_f = torch.stack(masks['forward_1'], dim=1)
            outputs = outputs_f

        return outputs_b.view(b, -1, c, h, w), outputs_f.view(b, -1, c, h, w), \
               outputs.view(b, -1, c, h, w), masks_f
191
+
192
+
193
class Encoder(nn.Module):
    """Convolutional encoder with grouped skip concatenation in its last four convs.

    Takes 5 input channels, downsamples twice with stride 2, and outputs
    128 channels.
    """

    def __init__(self):
        super(Encoder, self).__init__()
        self.group = [1, 2, 4, 8, 1]

        # (in_channels, out_channels, stride, groups) for every conv, in order;
        # each conv is followed by a LeakyReLU(0.2).
        conv_specs = [
            (5, 64, 2, 1),
            (64, 64, 1, 1),
            (64, 128, 2, 1),
            (128, 256, 1, 1),
            (256, 384, 1, 1),
            (640, 512, 1, 2),
            (768, 384, 1, 4),
            (640, 256, 1, 8),
            (512, 128, 1, 1),
        ]
        stages = []
        for c_in, c_out, stride, groups in conv_specs:
            stages.append(nn.Conv2d(c_in, c_out, kernel_size=3, stride=stride,
                                    padding=1, groups=groups))
            stages.append(nn.LeakyReLU(0.2, inplace=True))
        self.layers = nn.ModuleList(stages)

    def forward(self, x):
        """Encode ``x`` ([bt, 5, H, W]) into a [bt, 128, H/4, W/4] feature map."""
        bt, _, _, _ = x.size()
        out = x
        for i, layer in enumerate(self.layers):
            if i == 8:
                # Remember the activation entering layer 8; it is re-injected
                # before every later conv.
                skip = out
                _, _, h, w = skip.size()
            if i > 8 and i % 2 == 0:
                # Concatenate the skip and the current features group by group.
                g = self.group[(i - 8) // 2]
                skip_grouped = skip.view(bt, g, -1, h, w)
                out_grouped = out.view(bt, g, -1, h, w)
                out = torch.cat([skip_grouped, out_grouped], 2).view(bt, -1, h, w)
            out = layer(out)
        return out
233
+
234
+
235
class deconv(nn.Module):
    """2x bilinear upsampling followed by a stride-1 convolution."""

    def __init__(self,
                 input_channel,
                 output_channel,
                 kernel_size=3,
                 padding=0):
        super().__init__()
        self.conv = nn.Conv2d(input_channel,
                              output_channel,
                              kernel_size=kernel_size,
                              stride=1,
                              padding=padding)

    def forward(self, x):
        """Upsample ``x`` by a factor of 2 (bilinear, align_corners=True), then convolve."""
        upsampled = F.interpolate(x,
                                  scale_factor=2,
                                  mode='bilinear',
                                  align_corners=True)
        return self.conv(upsampled)
254
+
255
+
256
class InpaintGenerator(BaseNetwork):
    """Video inpainting generator.

    Pipeline: encoder, flow-guided feature propagation, soft split, temporal
    sparse transformers, soft composition, decoder.
    """

    def __init__(self, init_weights=True, model_path=None):
        """Build the network.

        Args:
            init_weights: when true, call ``self.init_weights()`` after construction.
            model_path: optional checkpoint path; loaded with ``strict=True``.
        """
        super(InpaintGenerator, self).__init__()
        channel = 128
        hidden = 512

        # encoder
        self.encoder = Encoder()

        # decoder
        self.decoder = nn.Sequential(
            deconv(channel, 128, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            deconv(64, 64, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=1))

        # soft split and soft composition
        kernel_size = (7, 7)
        padding = (3, 3)
        stride = (3, 3)
        t2t_params = {
            'kernel_size': kernel_size,
            'stride': stride,
            'padding': padding
        }
        self.ss = SoftSplit(channel, hidden, kernel_size, stride, padding)
        self.sc = SoftComp(channel, hidden, kernel_size, stride, padding)
        # Pools the mask with the same geometry as the soft split.
        self.max_pool = nn.MaxPool2d(kernel_size, stride, padding)

        # feature propagation module
        self.img_prop_module = BidirectionalPropagation(3, learnable=False)
        self.feat_prop_module = BidirectionalPropagation(channel, learnable=True)

        depths = 8
        num_heads = 4
        window_size = (5, 9)
        pool_size = (4, 4)
        self.transformers = TemporalSparseTransformerBlock(dim=hidden,
                                                           n_head=num_heads,
                                                           window_size=window_size,
                                                           pool_size=pool_size,
                                                           depths=depths,
                                                           t2t_params=t2t_params)
        if init_weights:
            self.init_weights()

        if model_path is not None:
            # NOTE(review): torch.load unpickles the file; only load
            # checkpoints from a trusted source.
            ckpt = torch.load(model_path, map_location='cpu')
            self.load_state_dict(ckpt, strict=True)
            # Report success only after the weights are actually in place.
            print('Pretrained ProPainter has loaded...')

        # print network parameter number
        self.print_network()

    def img_propagation(self, masked_frames, completed_flows, masks, interpolation='nearest'):
        """Propagate pixels with the non-learnable module.

        Returns:
            ``(prop_frames, updated_masks)``: the third and fourth outputs of
            ``img_prop_module``.
        """
        _, _, prop_frames, updated_masks = self.img_prop_module(
            masked_frames, completed_flows[0], completed_flows[1], masks, interpolation)
        return prop_frames, updated_masks

    def forward(self, masked_frames, completed_flows, masks_in, masks_updated, num_local_frames, interpolation='bilinear', t_dilation=2):
        """
        Args:
            masked_frames: frames, [b, t, 3, H, W]
            completed_flows: pair (forward, backward) of flows for the local
                frames, each reshaped to [b, l_t-1, 2, H, W]
            masks_in: original mask
            masks_updated: updated mask after image propagation
            num_local_frames: number of leading frames treated as local (l_t);
                the remaining frames are reference frames
            interpolation: passed to the feature propagation module
            t_dilation: passed to the transformer block
        Returns:
            Tensor of tanh outputs: [b, t, 3, H, W] in training mode,
            [b, l_t, 3, H, W] otherwise.
        """
        l_t = num_local_frames
        b, t, _, ori_h, ori_w = masked_frames.size()

        # extracting features
        enc_feat = self.encoder(torch.cat([masked_frames.view(b * t, 3, ori_h, ori_w),
                                           masks_in.view(b * t, 1, ori_h, ori_w),
                                           masks_updated.view(b * t, 1, ori_h, ori_w)], dim=1))
        _, c, h, w = enc_feat.size()
        local_feat = enc_feat.view(b, t, c, h, w)[:, :l_t, ...]
        ref_feat = enc_feat.view(b, t, c, h, w)[:, l_t:, ...]
        fold_feat_size = (h, w)

        # Flows are resized to 1/4 resolution, so their values are divided by 4 too.
        ds_flows_f = F.interpolate(completed_flows[0].view(-1, 2, ori_h, ori_w), scale_factor=1/4, mode='bilinear', align_corners=False).view(b, l_t-1, 2, h, w)/4.0
        ds_flows_b = F.interpolate(completed_flows[1].view(-1, 2, ori_h, ori_w), scale_factor=1/4, mode='bilinear', align_corners=False).view(b, l_t-1, 2, h, w)/4.0
        ds_mask_in = F.interpolate(masks_in.reshape(-1, 1, ori_h, ori_w), scale_factor=1/4, mode='nearest').view(b, t, 1, h, w)
        ds_mask_in_local = ds_mask_in[:, :l_t]
        ds_mask_updated_local = F.interpolate(masks_updated[:,:l_t].reshape(-1, 1, ori_h, ori_w), scale_factor=1/4, mode='nearest').view(b, l_t, 1, h, w)

        # Training pools the mask of every frame; inference only the local ones.
        if self.training:
            mask_pool_l = self.max_pool(ds_mask_in.view(-1, 1, h, w))
            mask_pool_l = mask_pool_l.view(b, t, 1, mask_pool_l.size(-2), mask_pool_l.size(-1))
        else:
            mask_pool_l = self.max_pool(ds_mask_in_local.view(-1, 1, h, w))
            mask_pool_l = mask_pool_l.view(b, l_t, 1, mask_pool_l.size(-2), mask_pool_l.size(-1))

        # Flow-guided propagation over the local frames only.
        prop_mask_in = torch.cat([ds_mask_in_local, ds_mask_updated_local], dim=2)
        _, _, local_feat, _ = self.feat_prop_module(local_feat, ds_flows_f, ds_flows_b, prop_mask_in, interpolation)
        enc_feat = torch.cat((local_feat, ref_feat), dim=1)

        trans_feat = self.ss(enc_feat.view(-1, c, h, w), b, fold_feat_size)
        mask_pool_l = rearrange(mask_pool_l, 'b t c h w -> b t h w c').contiguous()
        trans_feat = self.transformers(trans_feat, fold_feat_size, mask_pool_l, t_dilation=t_dilation)
        trans_feat = self.sc(trans_feat, t, fold_feat_size)
        trans_feat = trans_feat.view(b, t, -1, h, w)

        # Residual connection around the transformer stage.
        enc_feat = enc_feat + trans_feat

        if self.training:
            output = self.decoder(enc_feat.view(-1, c, h, w))
            output = torch.tanh(output).view(b, t, 3, ori_h, ori_w)
        else:
            output = self.decoder(enc_feat[:, :l_t].view(-1, c, h, w))
            output = torch.tanh(output).view(b, l_t, 3, ori_h, ori_w)

        return output
373
+
374
+
375
+ # ######################################################################
376
+ # Discriminator for Temporal Patch GAN
377
+ # ######################################################################
378
class Discriminator(BaseNetwork):
    """Temporal patch discriminator: six Conv3d layers over [B, C, T, H, W]."""

    def __init__(self,
                 in_channels=3,
                 use_sigmoid=False,
                 use_spectral_norm=True,
                 init_weights=True):
        super(Discriminator, self).__init__()
        self.use_sigmoid = use_sigmoid
        nf = 32
        kernel = (3, 5, 5)
        stride = (1, 2, 2)

        # (in_channels, out_channels, padding) of the five optionally
        # spectral-normalised convs. The first one uses padding=1 on every
        # dimension, unlike the others.
        normed_specs = [
            (in_channels, nf * 1, 1),
            (nf * 1, nf * 2, (1, 2, 2)),
            (nf * 2, nf * 4, (1, 2, 2)),
            (nf * 4, nf * 4, (1, 2, 2)),
            (nf * 4, nf * 4, (1, 2, 2)),
        ]
        stages = []
        for c_in, c_out, pad in normed_specs:
            conv = nn.Conv3d(in_channels=c_in,
                             out_channels=c_out,
                             kernel_size=kernel,
                             stride=stride,
                             padding=pad,
                             bias=not use_spectral_norm)
            stages.append(spectral_norm(conv, use_spectral_norm))
            stages.append(nn.LeakyReLU(0.2, inplace=True))
        # The final conv is never spectral-normalised and keeps its bias.
        stages.append(nn.Conv3d(nf * 4,
                                nf * 4,
                                kernel_size=kernel,
                                stride=stride,
                                padding=(1, 2, 2)))
        self.conv = nn.Sequential(*stages)

        if init_weights:
            self.init_weights()

    def forward(self, xs):
        """Score a clip ``xs`` of shape [B, T, C, H, W]; returns [B, T, C', H', W']."""
        # Conv3d expects channels before time: swap dims 1 and 2 on the way in and out.
        feat = self.conv(torch.transpose(xs, 1, 2))
        if self.use_sigmoid:
            feat = torch.sigmoid(feat)
        return torch.transpose(feat, 1, 2)  # B, T, C, H, W
452
+
453
+
454
class Discriminator_2D(BaseNetwork):
    """Per-frame patch discriminator: Conv3d layers whose kernels span one frame in time."""

    def __init__(self,
                 in_channels=3,
                 use_sigmoid=False,
                 use_spectral_norm=True,
                 init_weights=True):
        super(Discriminator_2D, self).__init__()
        self.use_sigmoid = use_sigmoid
        nf = 32
        kernel = (1, 5, 5)
        stride = (1, 2, 2)
        pad = (0, 2, 2)

        # (in_channels, out_channels) of the five optionally
        # spectral-normalised convs.
        normed_specs = [
            (in_channels, nf * 1),
            (nf * 1, nf * 2),
            (nf * 2, nf * 4),
            (nf * 4, nf * 4),
            (nf * 4, nf * 4),
        ]
        stages = []
        for c_in, c_out in normed_specs:
            conv = nn.Conv3d(in_channels=c_in,
                             out_channels=c_out,
                             kernel_size=kernel,
                             stride=stride,
                             padding=pad,
                             bias=not use_spectral_norm)
            stages.append(spectral_norm(conv, use_spectral_norm))
            stages.append(nn.LeakyReLU(0.2, inplace=True))
        # The final conv is never spectral-normalised and keeps its bias.
        stages.append(nn.Conv3d(nf * 4,
                                nf * 4,
                                kernel_size=kernel,
                                stride=stride,
                                padding=pad))
        self.conv = nn.Sequential(*stages)

        if init_weights:
            self.init_weights()

    def forward(self, xs):
        """Score a clip ``xs`` of shape [B, T, C, H, W]; returns [B, T, C', H', W']."""
        # Conv3d expects channels before time: swap dims 1 and 2 on the way in and out.
        feat = self.conv(torch.transpose(xs, 1, 2))
        if self.use_sigmoid:
            feat = torch.sigmoid(feat)
        return torch.transpose(feat, 1, 2)  # B, T, C, H, W
528
+
529
def spectral_norm(module, mode=True):
    """Return ``module`` with spectral norm applied when ``mode`` is truthy, else unchanged."""
    return _spectral_norm(module) if mode else module
model/recurrent_flow_completion.py ADDED
@@ -0,0 +1,347 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import torchvision
5
+
6
+ from model.modules.deformconv import ModulatedDeformConv2d
7
+ from .misc import constant_init
8
+
9
class SecondOrderDeformableAlignment(ModulatedDeformConv2d):
    """Second-order deformable alignment module.

    A small conv stack predicts, from ``extra_feat``, the sampling offsets and
    the modulation mask used by a modulated deformable convolution over ``x``.
    """

    def __init__(self, *args, **kwargs):
        # Bound applied to the predicted offsets (see ``forward``); popped so
        # the parent constructor never sees this extra keyword.
        self.max_residue_magnitude = kwargs.pop('max_residue_magnitude', 5)

        super(SecondOrderDeformableAlignment, self).__init__(*args, **kwargs)

        width = self.out_channels
        self.conv_offset = nn.Sequential(
            nn.Conv2d(3 * width, width, 3, 1, 1),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.Conv2d(width, width, 3, 1, 1),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.Conv2d(width, width, 3, 1, 1),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            # 27 channels per deform group, split into three equal chunks in
            # ``forward``: two offset chunks and one mask chunk.
            nn.Conv2d(width, 27 * self.deform_groups, 3, 1, 1),
        )
        self.init_offset()

    def init_offset(self):
        """Initialise the last offset conv with constant 0 weights and bias."""
        constant_init(self.conv_offset[-1], val=0, bias=0)

    def forward(self, x, extra_feat):
        """Deform-convolve ``x`` using offsets/mask predicted from ``extra_feat``."""
        first, second, mask = self.conv_offset(extra_feat).chunk(3, dim=1)

        # Offsets are squashed into (-max_residue_magnitude, max_residue_magnitude).
        # (The original split this tensor in two and concatenated the halves
        # back in the same order, which leaves the values unchanged.)
        offset = self.max_residue_magnitude * torch.tanh(
            torch.cat((first, second), dim=1))

        # Modulation mask in (0, 1).
        mask = torch.sigmoid(mask)

        return torchvision.ops.deform_conv2d(
            x, offset, self.weight, self.bias,
            self.stride, self.padding, self.dilation, mask)
45
+
46
class BidirectionalPropagation(nn.Module):
    """Propagate per-frame features backward then forward through time.

    Each direction keeps a running feature that is aligned with a
    ``SecondOrderDeformableAlignment`` module (conditioned on the previous two
    propagated features and the current frame) and refined by a small conv
    backbone. Both directions are fused by a 1x1 conv and added to the input.
    """

    def __init__(self, channel):
        super(BidirectionalPropagation, self).__init__()
        self.deform_align = nn.ModuleDict()
        self.backbone = nn.ModuleDict()
        self.channel = channel

        for order, direction in enumerate(('backward_', 'forward_')):
            self.deform_align[direction] = SecondOrderDeformableAlignment(
                2 * channel, channel, 3, padding=1, deform_groups=16)

            # The second direction additionally receives the first direction's
            # output for the same frame, hence one more ``channel`` of input.
            self.backbone[direction] = nn.Sequential(
                nn.Conv2d((2 + order) * channel, channel, 3, 1, 1),
                nn.LeakyReLU(negative_slope=0.1, inplace=True),
                nn.Conv2d(channel, channel, 3, 1, 1),
            )

        self.fusion = nn.Conv2d(2 * channel, channel, 1, 1, 0)

    def forward(self, x):
        """
        x shape : [b, t, c, h, w]
        return [b, t, c, h, w]
        """
        b, t, c, h, w = x.shape
        feats = {'spatial': [x[:, i] for i in range(t)]}

        for module_name in ('backward_', 'forward_'):
            propagated = feats[module_name] = []
            is_backward = 'backward' in module_name
            frame_order = range(t - 1, -1, -1) if is_backward else range(t)

            feat_prop = x.new_zeros(b, self.channel, h, w)
            for step, idx in enumerate(frame_order):
                feat_current = feats['spatial'][idx]

                if step > 0:
                    # Second-order term: the feature from two steps ago, or
                    # zeros while fewer than two steps have been taken.
                    if step > 1:
                        feat_n2 = propagated[-2]
                    else:
                        feat_n2 = torch.zeros_like(feat_prop)

                    # Conditioning: previous feature, current frame, and the
                    # second-order feature.
                    cond = torch.cat([feat_prop, feat_current, feat_n2], dim=1)
                    stacked = torch.cat([feat_prop, feat_n2], dim=1)
                    feat_prop = self.deform_align[module_name](stacked, cond)

                # Fuse the current frame with the already-finished directions
                # (same frame index) and the aligned running feature.
                others = [feats[k][idx] for k in feats
                          if k not in ['spatial', module_name]]
                fused = torch.cat([feat_current] + others + [feat_prop], dim=1)

                # Residual refinement of the running feature.
                feat_prop = feat_prop + self.backbone[module_name](fused)
                propagated.append(feat_prop)

            if is_backward:
                # Stored in processing order; flip back to frame order.
                feats[module_name] = propagated[::-1]

        outputs = [
            self.fusion(torch.cat([bwd, fwd], dim=1))
            for bwd, fwd in zip(feats['backward_'], feats['forward_'])
        ]

        return torch.stack(outputs, dim=1) + x
125
+
126
+
127
class deconv(nn.Module):
    """2x bilinear upsampling followed by a stride-1 convolution.

    Args:
        input_channel: Channels of the incoming feature map.
        output_channel: Channels produced by the convolution.
        kernel_size: Convolution kernel size. Default: 3.
        padding: Convolution padding. Default: 0.
    """

    def __init__(self, input_channel, output_channel, kernel_size=3, padding=0):
        super().__init__()
        # Positional arguments: in, out, kernel_size, stride, padding.
        self.conv = nn.Conv2d(input_channel, output_channel, kernel_size, 1, padding)

    def forward(self, x):
        """Upsample ``x`` by a factor of two, then convolve."""
        upsampled = F.interpolate(
            x, scale_factor=2, mode='bilinear', align_corners=True)
        return self.conv(upsampled)
146
+
147
+
148
class P3DBlock(nn.Module):
    """Factorised 3D convolution: a spatial conv followed by a temporal conv.

    ``conv1`` applies a ``(1, k, k)`` kernel plus LeakyReLU; ``conv2`` applies
    a ``(3, 1, 1)`` kernel with temporal dilation 2 and matching padding, so
    the temporal length is preserved.

    Args:
        in_channels: Input channels.
        out_channels: Output channels.
        kernel_size: Spatial kernel size of ``conv1``.
        stride: Spatial stride of ``conv1``.
        padding: Spatial padding of ``conv1``.
        use_residual: When truthy, the block input is added to the output.
        bias: Whether both convolutions carry a bias term.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, use_residual=0, bias=True):
        super().__init__()
        spatial_conv = nn.Conv3d(
            in_channels, out_channels,
            kernel_size=(1, kernel_size, kernel_size),
            stride=(1, stride, stride),
            padding=(0, padding, padding),
            bias=bias)
        self.conv1 = nn.Sequential(spatial_conv, nn.LeakyReLU(0.2, inplace=True))

        temporal_conv = nn.Conv3d(
            out_channels, out_channels,
            kernel_size=(3, 1, 1),
            stride=(1, 1, 1),
            padding=(2, 0, 0),
            dilation=(2, 1, 1),
            bias=bias)
        self.conv2 = nn.Sequential(temporal_conv)

        self.use_residual = use_residual

    def forward(self, feats):
        """Run the spatial then temporal conv, optionally adding the input."""
        out = self.conv2(self.conv1(feats))
        return feats + out if self.use_residual else out
170
+
171
+
172
class EdgeDetection(nn.Module):
    """Small residual conv net mapping a flow field to a sigmoid edge map.

    Args:
        in_ch: Channels of the input (2 by default).
        out_ch: Channels of the predicted edge map (1 by default).
        mid_ch: Width of the hidden layers.
    """

    def __init__(self, in_ch=2, out_ch=1, mid_ch=16):
        super().__init__()

        def conv3x3(c_in, c_out):
            return nn.Conv2d(c_in, c_out, 3, 1, 1)

        self.projection = nn.Sequential(
            conv3x3(in_ch, mid_ch),
            nn.LeakyReLU(0.2, inplace=True),
        )
        self.mid_layer_1 = nn.Sequential(
            conv3x3(mid_ch, mid_ch),
            nn.LeakyReLU(0.2, inplace=True),
        )
        self.mid_layer_2 = nn.Sequential(conv3x3(mid_ch, mid_ch))

        self.l_relu = nn.LeakyReLU(0.01, inplace=True)
        self.out_layer = nn.Conv2d(mid_ch, out_ch, 1, 1, 0)

    def forward(self, flow):
        """Return an edge map with values in (0, 1) and the input's spatial size."""
        projected = self.projection(flow)
        # Residual branch around the two middle layers.
        residual = self.mid_layer_2(self.mid_layer_1(projected))
        activated = self.l_relu(projected + residual)
        return torch.sigmoid(self.out_layer(activated))
201
+
202
+
203
class RecurrentFlowCompleteNet(nn.Module):
    """Recurrent flow-completion network.

    Encodes a sequence of masked flows (concatenated with their masks) with
    factorised 3D convolutions, propagates features bidirectionally through
    time, and decodes a completed 2-channel flow per frame. In training mode
    an auxiliary edge map is predicted from the completed flow.

    Args:
        model_path: Optional path to a checkpoint (a state dict) that is
            loaded strictly after the layers are built.
    """

    def __init__(self, model_path=None):
        super().__init__()
        self.downsample = nn.Sequential(
            nn.Conv3d(3, 32, kernel_size=(1, 5, 5), stride=(1, 2, 2),
                      padding=(0, 2, 2), padding_mode='replicate'),
            nn.LeakyReLU(0.2, inplace=True)
        )

        self.encoder1 = nn.Sequential(
            P3DBlock(32, 32, 3, 1, 1),
            nn.LeakyReLU(0.2, inplace=True),
            P3DBlock(32, 64, 3, 2, 1),
            nn.LeakyReLU(0.2, inplace=True)
        )  # 4x

        self.encoder2 = nn.Sequential(
            P3DBlock(64, 64, 3, 1, 1),
            nn.LeakyReLU(0.2, inplace=True),
            P3DBlock(64, 128, 3, 2, 1),
            nn.LeakyReLU(0.2, inplace=True)
        )  # 8x

        self.mid_dilation = nn.Sequential(
            nn.Conv3d(128, 128, (1, 3, 3), (1, 1, 1), padding=(0, 3, 3), dilation=(1, 3, 3)),  # p = d*(k-1)/2
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv3d(128, 128, (1, 3, 3), (1, 1, 1), padding=(0, 2, 2), dilation=(1, 2, 2)),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv3d(128, 128, (1, 3, 3), (1, 1, 1), padding=(0, 1, 1), dilation=(1, 1, 1)),
            nn.LeakyReLU(0.2, inplace=True)
        )

        # feature propagation module
        self.feat_prop_module = BidirectionalPropagation(128)

        self.decoder2 = nn.Sequential(
            nn.Conv2d(128, 128, 3, 1, 1),
            nn.LeakyReLU(0.2, inplace=True),
            deconv(128, 64, 3, 1),
            nn.LeakyReLU(0.2, inplace=True)
        )  # 4x

        self.decoder1 = nn.Sequential(
            nn.Conv2d(64, 64, 3, 1, 1),
            nn.LeakyReLU(0.2, inplace=True),
            deconv(64, 32, 3, 1),
            nn.LeakyReLU(0.2, inplace=True)
        )  # 2x

        self.upsample = nn.Sequential(
            nn.Conv2d(32, 32, 3, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            deconv(32, 2, 3, 1)
        )

        # edge loss
        self.edgeDetector = EdgeDetection(in_ch=2, out_ch=1, mid_ch=16)

        # Re-run the offset initialisation of every deformable alignment module.
        for m in self.modules():
            if isinstance(m, SecondOrderDeformableAlignment):
                m.init_offset()

        if model_path is not None:
            ckpt = torch.load(model_path, map_location='cpu')
            self.load_state_dict(ckpt, strict=True)
            # Report success only once the weights are actually in place.
            print('Pretrained flow completion model has loaded...')

    def forward(self, masked_flows, masks):
        """Complete one direction of masked flows.

        Args:
            masked_flows: Tensor of shape (b, t-1, 2, h, w).
            masks: Tensor of shape (b, t-1, 1, h, w); concatenated with the
                flows it forms the 3-channel input of ``self.downsample``.

        Returns:
            tuple: ``(flow, edge)`` where ``flow`` has shape (b, t-1, 2, h, w)
            and ``edge`` has shape (b, t-1, 1, h, w) in training mode, else
            ``None``.
        """
        b, t, _, h, w = masked_flows.size()
        masked_flows = masked_flows.permute(0, 2, 1, 3, 4)
        masks = masks.permute(0, 2, 1, 3, 4)

        inputs = torch.cat((masked_flows, masks), dim=1)  # b 3 t h w

        x = self.downsample(inputs)

        feat_e1 = self.encoder1(x)
        feat_e2 = self.encoder2(feat_e1)  # b c t h w
        feat_mid = self.mid_dilation(feat_e2)  # b c t h w
        feat_mid = feat_mid.permute(0, 2, 1, 3, 4)  # b t c h w

        feat_prop = self.feat_prop_module(feat_mid)
        # Fold time into the batch axis. The sizes are read from the tensor
        # rather than assumed to be (128, h // 8, w // 8), and ``reshape``
        # does not depend on the memory layout of the propagated features.
        _, _, c_p, h_p, w_p = feat_prop.shape
        feat_prop = feat_prop.reshape(-1, c_p, h_p, w_p)  # b*t c h w

        _, c, _, h_f, w_f = feat_e1.shape
        feat_e1 = feat_e1.permute(0, 2, 1, 3, 4).contiguous().view(-1, c, h_f, w_f)  # b*t c h w
        # Skip connection from the first encoder stage.
        feat_d2 = self.decoder2(feat_prop) + feat_e1

        feat_d1 = self.decoder1(feat_d2)

        flow = self.upsample(feat_d1)
        if self.training:
            edge = self.edgeDetector(flow)
            edge = edge.view(b, t, 1, h, w)
        else:
            edge = None

        flow = flow.view(b, t, 2, h, w)

        return flow, edge

    def forward_bidirect_flow(self, masked_flows_bi, masks):
        """Complete forward and backward flows.

        Args:
            masked_flows_bi: [masked_flows_f, masked_flows_b] | (b t-1 2 h w), (b t-1 2 h w)
            masks: b t 1 h w

        Returns:
            tuple: ``([pred_flows_forward, pred_flows_backward],
            [pred_edges_forward, pred_edges_backward])``; the edge entries are
            ``None`` outside training mode.
        """
        masks_forward = masks[:, :-1, ...].contiguous()
        masks_backward = masks[:, 1:, ...].contiguous()

        # mask flow
        masked_flows_forward = masked_flows_bi[0] * (1 - masks_forward)
        masked_flows_backward = masked_flows_bi[1] * (1 - masks_backward)

        # -- completion --
        # forward
        pred_flows_forward, pred_edges_forward = self.forward(masked_flows_forward, masks_forward)

        # backward: the sequence is time-reversed before completion and the
        # predictions are flipped back afterwards.
        masked_flows_backward = torch.flip(masked_flows_backward, dims=[1])
        masks_backward = torch.flip(masks_backward, dims=[1])
        pred_flows_backward, pred_edges_backward = self.forward(masked_flows_backward, masks_backward)
        pred_flows_backward = torch.flip(pred_flows_backward, dims=[1])
        if self.training:
            pred_edges_backward = torch.flip(pred_edges_backward, dims=[1])

        return [pred_flows_forward, pred_flows_backward], [pred_edges_forward, pred_edges_backward]

    def combine_flow(self, masked_flows_bi, pred_flows_bi, masks):
        """Blend predictions with the input flows using the masks.

        Where the mask is 1 the predicted flow is used; where it is 0 the
        corresponding entry of ``masked_flows_bi`` is kept.

        Returns:
            tuple: ``(combined_forward, combined_backward)``.
        """
        masks_forward = masks[:, :-1, ...].contiguous()
        masks_backward = masks[:, 1:, ...].contiguous()

        pred_flows_forward = pred_flows_bi[0] * masks_forward + masked_flows_bi[0] * (1 - masks_forward)
        pred_flows_backward = pred_flows_bi[1] * masks_backward + masked_flows_bi[1] * (1 - masks_backward)

        return pred_flows_forward, pred_flows_backward
model/vgg_arch.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ from collections import OrderedDict
4
+ from torch import nn as nn
5
+ from torchvision.models import vgg as vgg
6
+
7
VGG_PRETRAIN_PATH = 'experiments/pretrained_models/vgg19-dcbb9e9d.pth'


def _vgg_layer_names(convs_per_stage):
    """Build the ordered layer-name list for a VGG variant.

    Every stage ``s`` contributes ``conv{s}_{i}``/``relu{s}_{i}`` pairs for
    ``i = 1..n`` followed by ``pool{s}``.
    """
    layer_names = []
    for stage, num_convs in enumerate(convs_per_stage, start=1):
        for conv in range(1, num_convs + 1):
            layer_names += [f'conv{stage}_{conv}', f'relu{stage}_{conv}']
        layer_names.append(f'pool{stage}')
    return layer_names


# Layer names per VGG variant, described by the number of convs in each of
# the five stages.
NAMES = {
    'vgg11': _vgg_layer_names((1, 1, 2, 2, 2)),
    'vgg13': _vgg_layer_names((2, 2, 2, 2, 2)),
    'vgg16': _vgg_layer_names((2, 2, 3, 3, 3)),
    'vgg19': _vgg_layer_names((2, 2, 4, 4, 4)),
}
32
+
33
+
34
def insert_bn(names):
    """Insert a bn layer name after each conv layer name.

    Args:
        names (list): The list of layer names.

    Returns:
        list: A new list in which every name containing ``'conv'`` is followed
        by the same name with ``'conv'`` replaced by ``'bn'``.
    """
    names_bn = []
    for layer in names:
        if 'conv' in layer:
            names_bn.extend((layer, 'bn' + layer.replace('conv', '')))
        else:
            names_bn.append(layer)
    return names_bn
50
+
51
class VGGFeatureExtractor(nn.Module):
    """VGG network for feature extraction.

    Users can choose the vgg variant and whether the input is normalised.
    Note that the pretrained path must fit the vgg type.

    Args:
        layer_name_list (list[str]): Names of the layers whose outputs the
            forward function returns.
            Example: {'relu1_1', 'relu2_1', 'relu3_1'}.
        vgg_type (str): Set the type of vgg network. Default: 'vgg19'.
        use_input_norm (bool): If True, normalize the input image. Importantly,
            the input feature must be in the range [0, 1]. Default: True.
        range_norm (bool): If True, map images from range [-1, 1] to [0, 1].
            Default: False.
        requires_grad (bool): If true, the parameters of VGG network will be
            optimized. Default: False.
        remove_pooling (bool): If true, the max pooling operations in VGG net
            will be removed. Default: False.
        pooling_stride (int): The stride of max pooling operation. Default: 2.
    """

    def __init__(self,
                 layer_name_list,
                 vgg_type='vgg19',
                 use_input_norm=True,
                 range_norm=False,
                 requires_grad=False,
                 remove_pooling=False,
                 pooling_stride=2):
        super(VGGFeatureExtractor, self).__init__()

        self.layer_name_list = layer_name_list
        self.use_input_norm = use_input_norm
        self.range_norm = range_norm

        base_names = NAMES[vgg_type.replace('_bn', '')]
        self.names = insert_bn(base_names) if 'bn' in vgg_type else base_names

        # Only the layers up to the deepest requested one are kept, so no
        # unused parameters are carried around.
        last_idx = max((self.names.index(name) for name in layer_name_list), default=0)

        build_vgg = getattr(vgg, vgg_type)
        if os.path.exists(VGG_PRETRAIN_PATH):
            # Prefer the local checkpoint when present.
            vgg_net = build_vgg(pretrained=False)
            weights = torch.load(VGG_PRETRAIN_PATH, map_location=lambda storage, loc: storage)
            vgg_net.load_state_dict(weights)
        else:
            vgg_net = build_vgg(pretrained=True)

        kept_layers = OrderedDict()
        for name, layer in zip(self.names, vgg_net.features[:last_idx + 1]):
            if 'pool' not in name:
                kept_layers[name] = layer
            elif not remove_pooling:
                # Pooling layers are rebuilt so the stride can be overridden.
                kept_layers[name] = nn.MaxPool2d(kernel_size=2, stride=pooling_stride)

        self.vgg_net = nn.Sequential(kept_layers)

        # Trainable -> train mode with gradients; frozen -> eval mode without.
        trainable = bool(requires_grad)
        self.vgg_net.train(trainable)
        for param in self.parameters():
            param.requires_grad = trainable

        if self.use_input_norm:
            # the mean is for image with range [0, 1]
            self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
            # the std is for image with range [0, 1]
            self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor): Input tensor with shape (n, c, h, w).

        Returns:
            dict: Maps each requested layer name to a clone of its output.
        """
        if self.range_norm:
            x = (x + 1) / 2
        if self.use_input_norm:
            x = (x - self.mean) / self.std

        collected = {}
        for name, layer in self.vgg_net._modules.items():
            x = layer(x)
            if name in self.layer_name_list:
                collected[name] = x.clone()

        return collected
requirements.txt ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ av
2
+ addict
3
+ einops
4
+ future
5
+ numpy
6
+ scipy
7
+ opencv-python
8
+ matplotlib
9
+ scikit-image
10
+ torch>=1.7.1
11
+ torchvision>=0.8.2
12
+ imageio-ffmpeg
13
+ pyyaml
14
+ requests
15
+ timm
16
+ yapf
17
+ progressbar2
18
+ gdown
19
+ gitpython
20
+ git+https://github.com/cheind/py-thin-plate-spline
21
+ hickle
22
+ tensorboard
23
+ numpy
24
+ git+https://github.com/facebookresearch/segment-anything.git
25
+ gradio
26
+ opencv-python
27
+ matplotlib
28
+ pyyaml
29
+ av
30
+ openmim
31
+ tqdm
32
+ psutil
33
+ omegaconf
scripts/compute_flow.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ import sys
3
+ sys.path.append(".")
4
+
5
+ import os
6
+ import cv2
7
+ import argparse
8
+ from PIL import Image
9
+ import torch
10
+ import torch.nn.functional as F
11
+ from torchvision import transforms
12
+
13
+ from RAFT import RAFT
14
+ from utils.flow_util import *
15
+
16
def imwrite(img, file_path, params=None, auto_mkdir=True):
    """Write ``img`` to ``file_path`` with OpenCV.

    Args:
        img: Image array handed to ``cv2.imwrite``.
        file_path: Destination path.
        params: Extra parameters forwarded to ``cv2.imwrite``.
        auto_mkdir: When True, the parent directory of ``file_path`` is
            created first if it does not exist.

    Returns:
        Whatever ``cv2.imwrite`` returns.
    """
    if auto_mkdir:
        parent = os.path.abspath(os.path.dirname(file_path))
        os.makedirs(parent, exist_ok=True)
    return cv2.imwrite(file_path, img, params)
21
+
22
def initialize_RAFT(model_path='weights/raft-things.pth', device='cuda'):
    """Initializes the RAFT model.

    Args:
        model_path: Path of the RAFT checkpoint to load.
        device: Device the model is moved to after loading.

    Returns:
        The RAFT model (unwrapped from ``DataParallel``) in eval mode.
    """
    # A plain namespace carries the options RAFT reads; an ArgumentParser
    # instance is not meant to be used as an attribute bag.
    args = argparse.Namespace(
        raft_model=model_path,
        small=False,
        mixed_precision=False,
        alternate_corr=False,
    )

    # The state dict is loaded through a DataParallel wrapper and the inner
    # module is extracted afterwards (presumably the checkpoint keys carry
    # the wrapper's prefix -- confirm against the checkpoint).
    model = torch.nn.DataParallel(RAFT(args))
    # Map tensors to CPU while deserialising so the checkpoint also loads on
    # hosts without the GPU it was saved from; ``model.to(device)`` below
    # performs the actual placement.
    model.load_state_dict(torch.load(args.raft_model, map_location='cpu'))

    model = model.module
    model.to(device)
    model.eval()

    return model
39
+
40
+
41
if __name__ == '__main__':
    # Computes forward/backward RAFT flow for every pair of consecutive
    # frames in each sub-folder of the root path and stores them as .flo.
    device = 'cuda'

    parser = argparse.ArgumentParser()
    parser.add_argument('-i', '--root_path', type=str, default='your_dataset_root/youtube-vos/JPEGImages')
    parser.add_argument('-o', '--save_path', type=str, default='your_dataset_root/youtube-vos/Flows_flo')
    parser.add_argument('--height', type=int, default=240)
    parser.add_argument('--width', type=int, default=432)
    args = parser.parse_args()

    # Flow model
    RAFT_model = initialize_RAFT(device=device)

    root_path, save_path = args.root_path, args.save_path
    target_size = (args.height, args.width)

    # Built once; the conversion does not depend on the frame.
    to_tensor = transforms.Compose([transforms.ToTensor()])

    for f in sorted(os.listdir(root_path)):
        print(f'Processing: {f} ...')
        m_list = sorted(os.listdir(os.path.join(root_path, f)))

        # Walk over consecutive frame pairs.
        for name1, name2 in zip(m_list[:-1], m_list[1:]):
            pair = []
            for name in (name1, name2):
                img = Image.fromarray(cv2.imread(os.path.join(root_path, f, name)))
                # Reverse the channel order of the loaded image.
                img = to_tensor(img).unsqueeze(0).to(device)[:, [2, 1, 0], :, :]
                img = F.interpolate(input=img,
                                    size=target_size,
                                    mode='bilinear',
                                    align_corners=False)
                pair.append(img)

            with torch.no_grad():
                # Rescale from [0, 1] to [-1, 1].
                img1 = pair[0] * 2 - 1
                img2 = pair[1] * 2 - 1

                _, flow_f = RAFT_model(img1, img2, iters=20, test_mode=True)
                _, flow_b = RAFT_model(img2, img1, iters=20, test_mode=True)

            flow_f = flow_f[0].permute(1, 2, 0).cpu().numpy()
            flow_b = flow_b[0].permute(1, 2, 0).cpu().numpy()

            # File stems drop the last four characters of each frame name.
            stem1, stem2 = name1[:-4], name2[:-4]
            save_flow_f = os.path.join(save_path, f, f'{stem1}_{stem2}_f.flo')
            save_flow_b = os.path.join(save_path, f, f'{stem2}_{stem1}_b.flo')

            flowwrite(flow_f, save_flow_f, quantize=False)
            flowwrite(flow_b, save_flow_b, quantize=False)
scripts/evaluate_flow_completion.py ADDED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ import sys
3
+ sys.path.append(".")
4
+
5
+ import cv2
6
+ import os
7
+ import numpy as np
8
+ import argparse
9
+ from PIL import Image
10
+
11
+ import torch
12
+ from torch.utils.data import DataLoader
13
+
14
+ from core.dataset import TestDataset
15
+ from model.modules.flow_comp_raft import RAFT_bi
16
+ from model.recurrent_flow_completion import RecurrentFlowCompleteNet
17
+
18
+ from RAFT.utils.flow_viz_pt import flow_to_image
19
+
20
+ import cvbase
21
+ import imageio
22
+ from time import time
23
+
24
+ import warnings
25
+ warnings.filterwarnings("ignore")
26
+
27
def create_dir(dir):
    """Creates a directory (including parents) if it does not exist.

    Uses ``exist_ok=True`` instead of a separate existence check, so a
    directory appearing between the check and the creation cannot raise.

    Args:
        dir: Path of the directory to create.
    """
    os.makedirs(dir, exist_ok=True)
32
+
33
def save_flows(output, videoFlowF, videoFlowB):
    """Save flow visualisations as numbered PNGs.

    Forward flows go to ``<output>/forward_png`` and backward flows to
    ``<output>/backward_png``; the last axis of each array indexes the frame.

    Args:
        output: Root directory for the two sub-folders.
        videoFlowF: Forward flows, frame index on the last axis.
        videoFlowB: Backward flows, frame index on the last axis.
    """
    targets = (('forward_png', videoFlowF), ('backward_png', videoFlowB))
    for folder, _ in targets:
        create_dir(os.path.join(output, folder))

    num_frames = videoFlowF.shape[-1]
    for i in range(num_frames):
        for folder, flows in targets:
            # Flow-to-colour rendering, scaled by 255 and cast to uint8.
            vis = (cvbase.flow2rgb(flows[..., i]) * 255.0).astype(np.uint8)
            imageio.imwrite(os.path.join(output, folder, '{:05d}.png'.format(i)), vis)
50
+
51
def tensor2np(array):
    """Stack a list of tensors into one numpy array with the list index last.

    The tensors are stacked along a new trailing axis, axis 0 is squeezed
    (if it has size 1), and the first three remaining axes are reordered
    from (0, 1, 2) to (1, 2, 0).

    Args:
        array: Sequence of equally shaped tensors.

    Returns:
        numpy.ndarray: The rearranged data on the CPU.
    """
    stacked = torch.stack(array, dim=-1)
    stacked = stacked.squeeze(0)
    return stacked.permute(1, 2, 0, 3).cpu().numpy()
54
+
55
def main_worker(args):
    """Evaluate the flow-completion network and report end-point error (EPE).

    For every video from the test loader, reference flows are either taken
    from the loader (``args.load_flow``) or computed with RAFT. The completion
    network fills the masked regions and the EPE between the completed and the
    reference flows is printed and written to
    ``results_flow/<dataset>/<dataset>_metrics.txt``.

    Args:
        args: Parsed command-line namespace (see the ``__main__`` section).
    """
    # set up datasets and data loader
    args.size = (args.width, args.height)
    test_dataset = TestDataset(vars(args))

    test_loader = DataLoader(test_dataset,
                             batch_size=1,
                             shuffle=False,
                             num_workers=args.num_workers)

    # set up models
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    fix_raft = RAFT_bi(args.raft_model_path, device)

    fix_flow_complete = RecurrentFlowCompleteNet(args.fc_model_path)
    for p in fix_flow_complete.parameters():
        p.requires_grad = False
    fix_flow_complete.to(device)
    fix_flow_complete.eval()

    def synchronize():
        # Only meaningful (and only valid) when running on a CUDA device.
        if device.type == 'cuda':
            torch.cuda.synchronize()

    total_frame_epe = []
    time_all = []

    print('Start evaluation...')
    # create results directory
    result_path = os.path.join('results_flow', f'{args.dataset}')
    os.makedirs(result_path, exist_ok=True)

    summary_path = os.path.join(result_path, f"{args.dataset}_metrics.txt")
    # The context manager closes the summary even if evaluation fails midway.
    with open(summary_path, "w") as eval_summary:
        for index, items in enumerate(test_loader):
            frames, masks, flows_f, flows_b, video_name, frames_PIL = items
            local_masks = masks.float().to(device)

            video_length = frames.size(1)

            if args.load_flow:
                gt_flows_bi = (flows_f.to(device), flows_b.to(device))
            else:
                # Run RAFT in chunks of at most ``short_len`` frames. Every
                # chunk after the first starts one frame earlier so the flow
                # across the chunk boundary is included as well.
                short_len = 60
                gt_flows_f_list, gt_flows_b_list = [], []
                for f in range(0, video_length, short_len):
                    end_f = min(video_length, f + short_len)
                    start_f = max(f - 1, 0)
                    # NOTE(review): frames are moved to ``device`` here; this
                    # is a no-op if RAFT_bi already does so -- confirm.
                    chunk = frames[:, start_f:end_f].to(device)
                    chunk_f, chunk_b = fix_raft(chunk, iters=args.raft_iter)
                    gt_flows_f_list.append(chunk_f)
                    gt_flows_b_list.append(chunk_b)
                gt_flows_bi = (torch.cat(gt_flows_f_list, dim=1),
                               torch.cat(gt_flows_b_list, dim=1))

            synchronize()
            time_start = time()

            pred_flows_bi, _ = fix_flow_complete.forward_bidirect_flow(gt_flows_bi, local_masks)
            pred_flows_bi = fix_flow_complete.combine_flow(gt_flows_bi, pred_flows_bi, local_masks)

            synchronize()
            time_i = (time() - time_start) * 1.0 / frames.size(1)
            time_all.extend([time_i] * frames.size(1))

            # EPE is measured against the same reference flows that were fed
            # to the completion network (loaded or RAFT-computed), not against
            # whatever the loader happened to return.
            ref_flows_f = gt_flows_bi[0].cpu()
            ref_flows_b = gt_flows_bi[1].cpu()
            epe1 = torch.mean(torch.sum((ref_flows_f - pred_flows_bi[0].cpu())**2, dim=2).sqrt()).item()
            epe2 = torch.mean(torch.sum((ref_flows_b - pred_flows_bi[1].cpu())**2, dim=2).sqrt()).item()

            num_flows = ref_flows_f.size(1)
            total_frame_epe.extend([epe1] * num_flows)
            total_frame_epe.extend([epe2] * num_flows)

            cur_epe = (epe1 + epe2) / 2
            avg_time = sum(time_all) / len(time_all)
            line = f'[{index+1:3}/{len(test_loader)}] Name: {str(video_name):25} | EPE: {cur_epe:.4f} | Time: {avg_time:.4f}'
            print(line)
            eval_summary.write(line + '\n')

            # saving images for evaluating warpping errors
            if args.save_results:
                forward_flows = pred_flows_bi[0].cpu().permute(1, 0, 2, 3, 4)
                backward_flows = pred_flows_bi[1].cpu().permute(1, 0, 2, 3, 4)
                videoFlowF = tensor2np(list(forward_flows))
                videoFlowB = tensor2np(list(backward_flows))

                save_frame_path = os.path.join(result_path, video_name[0])
                save_flows(save_frame_path, videoFlowF, videoFlowB)

        # Guard against an empty loader so the summary is still written.
        avg_frame_epe = sum(total_frame_epe) / len(total_frame_epe) if total_frame_epe else float('nan')
        avg_time = sum(time_all) / len(time_all) if time_all else float('nan')

        print(f'Finish evaluation... Average Frame EPE: {avg_frame_epe:.4f} | | Time: {avg_time:.4f}')
        eval_summary.write(f'Finish evaluation... Average Frame EPE: {avg_frame_epe:.4f} | | Time: {avg_time:.4f}\n')
181
+
182
def str2bool(value):
    """Parse a command-line boolean such as ``True``/``false``/``1``/``0``.

    ``argparse`` with ``type=bool`` treats every non-empty string -- including
    ``"False"`` -- as True; this converter interprets the text instead.

    Raises:
        argparse.ArgumentTypeError: If the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--height', type=int, default=240)
    parser.add_argument('--width', type=int, default=432)
    parser.add_argument('--raft_model_path', default='weights/raft-things.pth', type=str)
    parser.add_argument('--fc_model_path', default='weights/recurrent_flow_completion.pth', type=str)
    parser.add_argument('--dataset', choices=['davis', 'youtube-vos'], type=str)
    parser.add_argument('--video_root', default='dataset_root', type=str)
    parser.add_argument('--mask_root', default='mask_root', type=str)
    parser.add_argument('--flow_root', default='flow_ground_truth_root', type=str)
    # ``--load_flow True`` keeps working; ``--load_flow False`` now really
    # disables loading instead of being parsed as True.
    parser.add_argument('--load_flow', default=False, type=str2bool)
    parser.add_argument("--raft_iter", type=int, default=20)
    parser.add_argument('--save_results', action='store_true')
    parser.add_argument('--num_workers', default=4, type=int)
    args = parser.parse_args()
    main_worker(args)