Spaces: Sleeping
Fazhong Liu committed
Commit 854728f
1 Parent(s): a1db54d
init
Files changed:
- .gitattributes +35 -35
- .gitignore +3 -0
- README.md +0 -13
- __pycache__/inference_utils.cpython-38.pyc +0 -0
- data_path.py +68 -0
- dataset/__init__.py +4 -0
- dataset/__pycache__/__init__.cpython-38.pyc +0 -0
- dataset/__pycache__/augmentation.cpython-38.pyc +0 -0
- dataset/__pycache__/images.cpython-38.pyc +0 -0
- dataset/__pycache__/sample.cpython-38.pyc +0 -0
- dataset/__pycache__/video.cpython-38.pyc +0 -0
- dataset/__pycache__/zip.cpython-38.pyc +0 -0
- dataset/augmentation.py +141 -0
- dataset/images.py +23 -0
- dataset/sample.py +14 -0
- dataset/video.py +38 -0
- dataset/zip.py +20 -0
- export_onnx.py +155 -0
- export_torchscript.py +83 -0
- inference_utils.py +46 -0
- inference_video.py +245 -0
- model/__init__.py +1 -0
- model/__pycache__/__init__.cpython-38.pyc +0 -0
- model/__pycache__/decoder.cpython-38.pyc +0 -0
- model/__pycache__/mobilenet.cpython-38.pyc +0 -0
- model/__pycache__/model.cpython-38.pyc +0 -0
- model/__pycache__/refiner.cpython-38.pyc +0 -0
- model/__pycache__/resnet.cpython-38.pyc +0 -0
- model/__pycache__/utils.cpython-38.pyc +0 -0
- model/decoder.py +51 -0
- model/mobilenet.py +56 -0
- model/model.py +196 -0
- model/refiner.py +282 -0
- model/resnet.py +48 -0
- model/utils.py +14 -0
- requirements.txt +11 -0
- train_base.py +265 -0
- train_refine.py +309 -0
.gitattributes
CHANGED
@@ -1,35 +1,35 @@
- *.7z filter=lfs diff=lfs merge=lfs -text
- *.arrow filter=lfs diff=lfs merge=lfs -text
- *.bin filter=lfs diff=lfs merge=lfs -text
- *.bz2 filter=lfs diff=lfs merge=lfs -text
- *.ckpt filter=lfs diff=lfs merge=lfs -text
- *.ftz filter=lfs diff=lfs merge=lfs -text
- *.gz filter=lfs diff=lfs merge=lfs -text
- *.h5 filter=lfs diff=lfs merge=lfs -text
- *.joblib filter=lfs diff=lfs merge=lfs -text
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
- *.model filter=lfs diff=lfs merge=lfs -text
- *.msgpack filter=lfs diff=lfs merge=lfs -text
- *.npy filter=lfs diff=lfs merge=lfs -text
- *.npz filter=lfs diff=lfs merge=lfs -text
- *.onnx filter=lfs diff=lfs merge=lfs -text
- *.ot filter=lfs diff=lfs merge=lfs -text
- *.parquet filter=lfs diff=lfs merge=lfs -text
- *.pb filter=lfs diff=lfs merge=lfs -text
- *.pickle filter=lfs diff=lfs merge=lfs -text
- *.pkl filter=lfs diff=lfs merge=lfs -text
- *.pt filter=lfs diff=lfs merge=lfs -text
- *.pth filter=lfs diff=lfs merge=lfs -text
- *.rar filter=lfs diff=lfs merge=lfs -text
- *.safetensors filter=lfs diff=lfs merge=lfs -text
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
- *.tar.* filter=lfs diff=lfs merge=lfs -text
- *.tar filter=lfs diff=lfs merge=lfs -text
- *.tflite filter=lfs diff=lfs merge=lfs -text
- *.tgz filter=lfs diff=lfs merge=lfs -text
- *.wasm filter=lfs diff=lfs merge=lfs -text
- *.xz filter=lfs diff=lfs merge=lfs -text
- *.zip filter=lfs diff=lfs merge=lfs -text
- *.zst filter=lfs diff=lfs merge=lfs -text
- *tfevents* filter=lfs diff=lfs merge=lfs -text
+ # *.7z filter=lfs diff=lfs merge=lfs -text
+ # *.arrow filter=lfs diff=lfs merge=lfs -text
+ # *.bin filter=lfs diff=lfs merge=lfs -text
+ # *.bz2 filter=lfs diff=lfs merge=lfs -text
+ # *.ckpt filter=lfs diff=lfs merge=lfs -text
+ # *.ftz filter=lfs diff=lfs merge=lfs -text
+ # *.gz filter=lfs diff=lfs merge=lfs -text
+ # *.h5 filter=lfs diff=lfs merge=lfs -text
+ # *.joblib filter=lfs diff=lfs merge=lfs -text
+ # *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ # *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ # *.model filter=lfs diff=lfs merge=lfs -text
+ # *.msgpack filter=lfs diff=lfs merge=lfs -text
+ # *.npy filter=lfs diff=lfs merge=lfs -text
+ # *.npz filter=lfs diff=lfs merge=lfs -text
+ # *.onnx filter=lfs diff=lfs merge=lfs -text
+ # *.ot filter=lfs diff=lfs merge=lfs -text
+ # *.parquet filter=lfs diff=lfs merge=lfs -text
+ # *.pb filter=lfs diff=lfs merge=lfs -text
+ # *.pickle filter=lfs diff=lfs merge=lfs -text
+ # *.pkl filter=lfs diff=lfs merge=lfs -text
+ # *.pt filter=lfs diff=lfs merge=lfs -text
+ # *.pth filter=lfs diff=lfs merge=lfs -text
+ # *.rar filter=lfs diff=lfs merge=lfs -text
+ # *.safetensors filter=lfs diff=lfs merge=lfs -text
+ # saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ # *.tar.* filter=lfs diff=lfs merge=lfs -text
+ # *.tar filter=lfs diff=lfs merge=lfs -text
+ # *.tflite filter=lfs diff=lfs merge=lfs -text
+ # *.tgz filter=lfs diff=lfs merge=lfs -text
+ # *.wasm filter=lfs diff=lfs merge=lfs -text
+ # *.xz filter=lfs diff=lfs merge=lfs -text
+ # *.zip filter=lfs diff=lfs merge=lfs -text
+ # *.zst filter=lfs diff=lfs merge=lfs -text
+ # *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,3 @@
*.pth
*.png
*.mp4
README.md
DELETED
@@ -1,13 +0,0 @@
---
title: VideoMatting
emoji: 🏆
colorFrom: pink
colorTo: green
sdk: gradio
sdk_version: 4.24.0
app_file: app.py
pinned: false
license: apache-2.0
---

Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
__pycache__/inference_utils.cpython-38.pyc
ADDED
Binary file (1.79 kB)
data_path.py
ADDED
@@ -0,0 +1,68 @@
"""
This file records the directory paths to the different datasets.
You will need to configure it for training the model.

All datasets follow the format below, where fgr and pha point to directories that contain jpg or png files.
Inside the directories can be any nested structure, but the fgr and pha structures must match. You can add your own
dataset to the list as long as it follows the format. 'fgr' should point to foreground images with RGB channels,
'pha' should point to alpha images with only 1 grey channel.
{
    'YOUR_DATASET': {
        'train': {
            'fgr': 'PATH_TO_IMAGES_DIR',
            'pha': 'PATH_TO_IMAGES_DIR',
        },
        'valid': {
            'fgr': 'PATH_TO_IMAGES_DIR',
            'pha': 'PATH_TO_IMAGES_DIR',
        }
    }
}
"""

DATA_PATH = {
    'videomatte240k': {
        'train': {
            'fgr': 'PATH_TO_IMAGES_DIR',
            'pha': 'PATH_TO_IMAGES_DIR'
        },
        'valid': {
            'fgr': 'PATH_TO_IMAGES_DIR',
            'pha': 'PATH_TO_IMAGES_DIR'
        }
    },
    'photomatte13k': {
        'train': {
            'fgr': 'PATH_TO_IMAGES_DIR',
            'pha': 'PATH_TO_IMAGES_DIR'
        },
        'valid': {
            'fgr': 'PATH_TO_IMAGES_DIR',
            'pha': 'PATH_TO_IMAGES_DIR'
        }
    },
    'distinction': {
        'train': {
            'fgr': 'PATH_TO_IMAGES_DIR',
            'pha': 'PATH_TO_IMAGES_DIR',
        },
        'valid': {
            'fgr': 'PATH_TO_IMAGES_DIR',
            'pha': 'PATH_TO_IMAGES_DIR'
        },
    },
    'adobe': {
        'train': {
            'fgr': 'PATH_TO_IMAGES_DIR',
            'pha': 'PATH_TO_IMAGES_DIR',
        },
        'valid': {
            'fgr': 'PATH_TO_IMAGES_DIR',
            'pha': 'PATH_TO_IMAGES_DIR'
        },
    },
    'backgrounds': {
        'train': 'PATH_TO_IMAGES_DIR',
        'valid': 'PATH_TO_IMAGES_DIR'
    },
}
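Editor's note: a minimal sketch of how a training script might consume DATA_PATH, assuming the placeholder paths have been replaced; the ImagesDataset wiring below is an illustration, not part of this commit.

    # Hypothetical usage: resolve one dataset entry and build foreground/alpha datasets.
    from data_path import DATA_PATH
    from dataset import ImagesDataset

    paths = DATA_PATH['videomatte240k']['train']            # {'fgr': ..., 'pha': ...}
    fgr_dataset = ImagesDataset(paths['fgr'], mode='RGB')   # foreground RGB images
    pha_dataset = ImagesDataset(paths['pha'], mode='L')     # single-channel alpha mattes
    print(len(fgr_dataset), len(pha_dataset))               # the two counts should match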
dataset/__init__.py
ADDED
@@ -0,0 +1,4 @@
from .images import ImagesDataset
from .video import VideoDataset
from .sample import SampleDataset
from .zip import ZipDataset
dataset/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (320 Bytes)
dataset/__pycache__/augmentation.cpython-38.pyc
ADDED
Binary file (7.05 kB)
dataset/__pycache__/images.cpython-38.pyc
ADDED
Binary file (1.12 kB)
dataset/__pycache__/sample.cpython-38.pyc
ADDED
Binary file (1.05 kB)
dataset/__pycache__/video.cpython-38.pyc
ADDED
Binary file (1.94 kB)
dataset/__pycache__/zip.cpython-38.pyc
ADDED
Binary file (1.38 kB)
dataset/augmentation.py
ADDED
@@ -0,0 +1,141 @@
import random
import torch
import numpy as np
import math
from torchvision import transforms as T
from torchvision.transforms import functional as F
from PIL import Image, ImageFilter

"""
Pair transforms are modified versions of the regular transforms: they take in multiple images
and apply the exact same transform to all of them. This is especially useful when we want the
same transform applied to a pair of images.

Example:
    img1, img2, ..., imgN = transforms(img1, img2, ..., imgN)
"""

class PairCompose(T.Compose):
    def __call__(self, *x):
        for transform in self.transforms:
            x = transform(*x)
        return x


class PairApply:
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, *x):
        return [self.transforms(xi) for xi in x]


class PairApplyOnlyAtIndices:
    def __init__(self, indices, transforms):
        self.indices = indices
        self.transforms = transforms

    def __call__(self, *x):
        return [self.transforms(xi) if i in self.indices else xi for i, xi in enumerate(x)]


class PairRandomAffine(T.RandomAffine):
    def __init__(self, degrees, translate=None, scale=None, shear=None, resamples=None, fillcolor=0):
        super().__init__(degrees, translate, scale, shear, Image.NEAREST, fillcolor)
        self.resamples = resamples

    def __call__(self, *x):
        if not len(x):
            return []
        param = self.get_params(self.degrees, self.translate, self.scale, self.shear, x[0].size)
        resamples = self.resamples or [self.resample] * len(x)
        return [F.affine(xi, *param, resamples[i], self.fillcolor) for i, xi in enumerate(x)]


class PairRandomHorizontalFlip(T.RandomHorizontalFlip):
    def __call__(self, *x):
        if torch.rand(1) < self.p:
            x = [F.hflip(xi) for xi in x]
        return x


class RandomBoxBlur:
    def __init__(self, prob, max_radius):
        self.prob = prob
        self.max_radius = max_radius

    def __call__(self, img):
        if torch.rand(1) < self.prob:
            fil = ImageFilter.BoxBlur(random.choice(range(self.max_radius + 1)))
            img = img.filter(fil)
        return img


class PairRandomBoxBlur(RandomBoxBlur):
    def __call__(self, *x):
        if torch.rand(1) < self.prob:
            fil = ImageFilter.BoxBlur(random.choice(range(self.max_radius + 1)))
            x = [xi.filter(fil) for xi in x]
        return x


class RandomSharpen:
    def __init__(self, prob):
        self.prob = prob
        self.filter = ImageFilter.SHARPEN

    def __call__(self, img):
        if torch.rand(1) < self.prob:
            img = img.filter(self.filter)
        return img


class PairRandomSharpen(RandomSharpen):
    def __call__(self, *x):
        if torch.rand(1) < self.prob:
            x = [xi.filter(self.filter) for xi in x]
        return x


class PairRandomAffineAndResize:
    def __init__(self, size, degrees, translate, scale, shear, ratio=(3./4., 4./3.), resample=Image.BILINEAR, fillcolor=0):
        self.size = size
        self.degrees = degrees
        self.translate = translate
        self.scale = scale
        self.shear = shear
        self.ratio = ratio
        self.resample = resample
        self.fillcolor = fillcolor

    def __call__(self, *x):
        if not len(x):
            return []

        w, h = x[0].size
        scale_factor = max(self.size[1] / w, self.size[0] / h)

        w_padded = max(w, self.size[1])
        h_padded = max(h, self.size[0])

        pad_h = int(math.ceil((h_padded - h) / 2))
        pad_w = int(math.ceil((w_padded - w) / 2))

        scale = self.scale[0] * scale_factor, self.scale[1] * scale_factor
        translate = self.translate[0] * scale_factor, self.translate[1] * scale_factor
        affine_params = T.RandomAffine.get_params(self.degrees, translate, scale, self.shear, (w, h))

        def transform(img):
            if pad_h > 0 or pad_w > 0:
                img = F.pad(img, (pad_w, pad_h))

            img = F.affine(img, *affine_params, self.resample, self.fillcolor)
            img = F.center_crop(img, self.size)
            return img

        return [transform(xi) for xi in x]


class RandomAffineAndResize(PairRandomAffineAndResize):
    def __call__(self, img):
        return super().__call__(img)[0]
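Editor's note: a minimal usage sketch of the pair transforms above on a foreground/alpha pair; the image paths are placeholders, not part of this commit.

    # Hypothetical example: the same random flip decision is applied to both the
    # foreground and the alpha matte so they stay aligned.
    from PIL import Image
    from torchvision import transforms as T
    from dataset import augmentation as A

    pair_transforms = A.PairCompose([
        A.PairRandomHorizontalFlip(),
        A.PairApply(T.ToTensor()),
    ])

    fgr = Image.open('fgr_0001.png').convert('RGB')   # placeholder path
    pha = Image.open('pha_0001.png').convert('L')     # placeholder path
    fgr, pha = pair_transforms(fgr, pha)              # tensors, identically flipped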
dataset/images.py
ADDED
@@ -0,0 +1,23 @@
import os
import glob
from torch.utils.data import Dataset
from PIL import Image

class ImagesDataset(Dataset):
    def __init__(self, root, mode='RGB', transforms=None):
        self.transforms = transforms
        self.mode = mode
        self.filenames = sorted([*glob.glob(os.path.join(root, '**', '*.jpg'), recursive=True),
                                 *glob.glob(os.path.join(root, '**', '*.png'), recursive=True)])

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx):
        with Image.open(self.filenames[idx]) as img:
            img = img.convert(self.mode)

        if self.transforms:
            img = self.transforms(img)

        return img
dataset/sample.py
ADDED
@@ -0,0 +1,14 @@
from torch.utils.data import Dataset


class SampleDataset(Dataset):
    def __init__(self, dataset, samples):
        samples = min(samples, len(dataset))
        self.dataset = dataset
        self.indices = [i * int(len(dataset) / samples) for i in range(samples)]

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        return self.dataset[self.indices[idx]]
dataset/video.py
ADDED
@@ -0,0 +1,38 @@
import cv2
import numpy as np
from torch.utils.data import Dataset
from PIL import Image

class VideoDataset(Dataset):
    def __init__(self, path: str, transforms: any = None):
        self.cap = cv2.VideoCapture(path)
        self.transforms = transforms

        self.width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        self.height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        self.frame_rate = self.cap.get(cv2.CAP_PROP_FPS)
        self.frame_count = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))

    def __len__(self):
        return self.frame_count

    def __getitem__(self, idx):
        if isinstance(idx, slice):
            return [self[i] for i in range(*idx.indices(len(self)))]

        if self.cap.get(cv2.CAP_PROP_POS_FRAMES) != idx:
            self.cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
        ret, img = self.cap.read()
        if not ret:
            raise IndexError(f'Idx: {idx} out of length: {len(self)}')
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = Image.fromarray(img)
        if self.transforms:
            img = self.transforms(img)
        return img

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        self.cap.release()
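Editor's note: a brief sketch of VideoDataset used as a context manager; the video path is a placeholder.

    # Hypothetical example: read frames as tensors and release the capture via __exit__.
    from torchvision import transforms as T
    from dataset import VideoDataset

    with VideoDataset('input.mp4', transforms=T.ToTensor()) as vid:   # placeholder path
        print(vid.width, vid.height, vid.frame_rate, len(vid))
        first_frame = vid[0]    # (3, H, W) float tensor in [0, 1]
        clip = vid[0:16]        # slicing returns a list of frames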
dataset/zip.py
ADDED
@@ -0,0 +1,20 @@
from torch.utils.data import Dataset
from typing import List

class ZipDataset(Dataset):
    def __init__(self, datasets: List[Dataset], transforms=None, assert_equal_length=False):
        self.datasets = datasets
        self.transforms = transforms

        if assert_equal_length:
            for i in range(1, len(datasets)):
                assert len(datasets[i]) == len(datasets[i - 1]), 'Datasets are not equal in length.'

    def __len__(self):
        return max(len(d) for d in self.datasets)

    def __getitem__(self, idx):
        x = tuple(d[idx % len(d)] for d in self.datasets)
        if self.transforms:
            x = self.transforms(*x)
        return x
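Editor's note: a short sketch of how ZipDataset might pair the foreground and alpha datasets and apply the pair transforms jointly; the directory paths are placeholders.

    # Hypothetical example: each item yields an aligned (fgr, pha) pair, with the
    # pair transforms applied to both images at once.
    from torchvision import transforms as T
    from dataset import ImagesDataset, ZipDataset
    from dataset import augmentation as A

    dataset = ZipDataset([
        ImagesDataset('PATH_TO_FGR_DIR', mode='RGB'),   # placeholder
        ImagesDataset('PATH_TO_PHA_DIR', mode='L'),     # placeholder
    ], transforms=A.PairCompose([
        A.PairRandomHorizontalFlip(),
        A.PairApply(T.ToTensor()),
    ]), assert_equal_length=True)

    fgr, pha = dataset[0]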
export_onnx.py
ADDED
@@ -0,0 +1,155 @@
"""
Export MattingRefine as ONNX format.
Need to install onnxruntime through `pip install onnxruntime`.

Example:

    python export_onnx.py \
        --model-type mattingrefine \
        --model-checkpoint "PATH_TO_MODEL_CHECKPOINT" \
        --model-backbone resnet50 \
        --model-backbone-scale 0.25 \
        --model-refine-mode sampling \
        --model-refine-sample-pixels 80000 \
        --model-refine-patch-crop-method roi_align \
        --model-refine-patch-replace-method scatter_element \
        --onnx-opset-version 11 \
        --onnx-constant-folding \
        --precision float32 \
        --output "model.onnx" \
        --validate

Compatibility:

    Our network uses a novel architecture that involves cropping and replacing patches
    of an image. This may have compatibility issues for different inference backends.
    Therefore, we offer different methods for cropping and replacing patches as
    compatibility options. They all result in the same image output.

    --model-refine-patch-crop-method:
        Options: ['unfold', 'roi_align', 'gather']
        (unfold is unlikely to work for ONNX, try roi_align or gather)

    --model-refine-patch-replace-method
        Options: ['scatter_nd', 'scatter_element']
        (scatter_nd should be faster when supported)

    Also try using threshold mode if sampling mode is not supported by the inference backend.

        --model-refine-mode thresholding \
        --model-refine-threshold 0.1 \
"""


import argparse
import torch

from model import MattingBase, MattingRefine


# --------------- Arguments ---------------


parser = argparse.ArgumentParser(description='Export ONNX')

parser.add_argument('--model-type', type=str, required=True, choices=['mattingbase', 'mattingrefine'])
parser.add_argument('--model-backbone', type=str, required=True, choices=['resnet101', 'resnet50', 'mobilenetv2'])
parser.add_argument('--model-backbone-scale', type=float, default=0.25)
parser.add_argument('--model-checkpoint', type=str, required=True)
parser.add_argument('--model-refine-mode', type=str, default='sampling', choices=['full', 'sampling', 'thresholding'])
parser.add_argument('--model-refine-sample-pixels', type=int, default=80_000)
parser.add_argument('--model-refine-threshold', type=float, default=0.1)
parser.add_argument('--model-refine-kernel-size', type=int, default=3)
parser.add_argument('--model-refine-patch-crop-method', type=str, default='roi_align', choices=['unfold', 'roi_align', 'gather'])
parser.add_argument('--model-refine-patch-replace-method', type=str, default='scatter_element', choices=['scatter_nd', 'scatter_element'])

parser.add_argument('--onnx-verbose', type=bool, default=True)
parser.add_argument('--onnx-opset-version', type=int, default=12)
parser.add_argument('--onnx-constant-folding', default=True, action='store_true')

parser.add_argument('--device', type=str, default='cpu')
parser.add_argument('--precision', type=str, default='float32', choices=['float32', 'float16'])
parser.add_argument('--validate', action='store_true')
parser.add_argument('--output', type=str, required=True)

args = parser.parse_args()


# --------------- Main ---------------


# Load model
if args.model_type == 'mattingbase':
    model = MattingBase(args.model_backbone)
if args.model_type == 'mattingrefine':
    model = MattingRefine(
        backbone=args.model_backbone,
        backbone_scale=args.model_backbone_scale,
        refine_mode=args.model_refine_mode,
        refine_sample_pixels=args.model_refine_sample_pixels,
        refine_threshold=args.model_refine_threshold,
        refine_kernel_size=args.model_refine_kernel_size,
        refine_patch_crop_method=args.model_refine_patch_crop_method,
        refine_patch_replace_method=args.model_refine_patch_replace_method)

model.load_state_dict(torch.load(args.model_checkpoint, map_location=args.device), strict=False)
precision = {'float32': torch.float32, 'float16': torch.float16}[args.precision]
model.eval().to(precision).to(args.device)

# Dummy Inputs
src = torch.randn(2, 3, 1080, 1920).to(precision).to(args.device)
bgr = torch.randn(2, 3, 1080, 1920).to(precision).to(args.device)

# Export ONNX
if args.model_type == 'mattingbase':
    input_names = ['src', 'bgr']
    output_names = ['pha', 'fgr', 'err', 'hid']
if args.model_type == 'mattingrefine':
    input_names = ['src', 'bgr']
    output_names = ['pha', 'fgr', 'pha_sm', 'fgr_sm', 'err_sm', 'ref_sm']

torch.onnx.export(
    model=model,
    args=(src, bgr),
    f=args.output,
    verbose=args.onnx_verbose,
    opset_version=args.onnx_opset_version,
    do_constant_folding=args.onnx_constant_folding,
    input_names=input_names,
    output_names=output_names,
    dynamic_axes={name: {0: 'batch', 2: 'height', 3: 'width'} for name in [*input_names, *output_names]})

print(f'ONNX model saved at: {args.output}')

# Validation
if args.validate:
    import onnxruntime
    import numpy as np

    print(f'Validating ONNX model.')

    # Test with different inputs.
    src = torch.randn(1, 3, 720, 1280).to(precision).to(args.device)
    bgr = torch.randn(1, 3, 720, 1280).to(precision).to(args.device)

    with torch.no_grad():
        out_torch = model(src, bgr)

    sess = onnxruntime.InferenceSession(args.output)
    out_onnx = sess.run(None, {
        'src': src.cpu().numpy(),
        'bgr': bgr.cpu().numpy()
    })

    e_max = 0
    for a, b, name in zip(out_torch, out_onnx, output_names):
        b = torch.as_tensor(b)
        e = torch.abs(a.cpu() - b).max()
        e_max = max(e_max, e.item())
        print(f'"{name}" output differs by maximum of {e}')

    if e_max < 0.005:
        print('Validation passed.')
    else:
        raise RuntimeError('Validation failed.')
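Editor's note: once exported, the model can be driven from onnxruntime directly; a minimal sketch, where the file names and image loading are assumptions (and image sides should be divisible by 4).

    # Hypothetical example: run the exported model on a real frame/background pair.
    import numpy as np
    import onnxruntime
    from PIL import Image

    def to_input(path):
        img = np.asarray(Image.open(path).convert('RGB'), dtype=np.float32) / 255.0
        return img.transpose(2, 0, 1)[None]          # (1, 3, H, W), values in [0, 1]

    sess = onnxruntime.InferenceSession('model.onnx')
    pha, fgr = sess.run(['pha', 'fgr'], {
        'src': to_input('src.png'),                   # placeholder paths
        'bgr': to_input('bgr.png'),
    })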
export_torchscript.py
ADDED
@@ -0,0 +1,83 @@
"""
Export TorchScript

    python export_torchscript.py \
        --model-backbone resnet50 \
        --model-checkpoint "PATH_TO_CHECKPOINT" \
        --precision float32 \
        --output "torchscript.pth"
"""

import argparse
import torch
from torch import nn
from model import MattingRefine


# --------------- Arguments ---------------


parser = argparse.ArgumentParser(description='Export TorchScript')

parser.add_argument('--model-backbone', type=str, required=True, choices=['resnet101', 'resnet50', 'mobilenetv2'])
parser.add_argument('--model-checkpoint', type=str, required=True)
parser.add_argument('--precision', type=str, default='float32', choices=['float32', 'float16'])
parser.add_argument('--output', type=str, required=True)

args = parser.parse_args()


# --------------- Utils ---------------


class MattingRefine_TorchScriptWrapper(nn.Module):
    """
    The purpose of this wrapper is to hoist all the configurable attributes to the top level,
    so that the user can easily change them after loading the saved TorchScript model.

    Example:
        model = torch.jit.load('torchscript.pth')
        model.backbone_scale = 0.25
        model.refine_mode = 'sampling'
        model.refine_sample_pixels = 80_000
        pha, fgr = model(src, bgr)[:2]
    """

    def __init__(self, *args, **kwargs):
        super().__init__()
        self.model = MattingRefine(*args, **kwargs)

        # Hoist the attributes to the top level.
        self.backbone_scale = self.model.backbone_scale
        self.refine_mode = self.model.refiner.mode
        self.refine_sample_pixels = self.model.refiner.sample_pixels
        self.refine_threshold = self.model.refiner.threshold
        self.refine_prevent_oversampling = self.model.refiner.prevent_oversampling

    def forward(self, src, bgr):
        # Reset the attributes.
        self.model.backbone_scale = self.backbone_scale
        self.model.refiner.mode = self.refine_mode
        self.model.refiner.sample_pixels = self.refine_sample_pixels
        self.model.refiner.threshold = self.refine_threshold
        self.model.refiner.prevent_oversampling = self.refine_prevent_oversampling

        return self.model(src, bgr)

    def load_state_dict(self, *args, **kwargs):
        return self.model.load_state_dict(*args, **kwargs)


# --------------- Main ---------------


model = MattingRefine_TorchScriptWrapper(args.model_backbone).eval()
model.load_state_dict(torch.load(args.model_checkpoint, map_location='cpu'))
for p in model.parameters():
    p.requires_grad = False

if args.precision == 'float16':
    model = model.half()

model = torch.jit.script(model)
model.save(args.output)
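Editor's note: expanding the wrapper's docstring into a runnable sketch with dummy tensors; CPU, float32, and the output path from the docstring are assumed.

    # Hypothetical example: load the scripted model, tweak the hoisted attributes,
    # and run it on dummy inputs whose sides are divisible by 4.
    import torch

    model = torch.jit.load('torchscript.pth')        # path from --output above
    model.backbone_scale = 0.25
    model.refine_mode = 'sampling'
    model.refine_sample_pixels = 80_000

    src = torch.rand(1, 3, 720, 1280)
    bgr = torch.rand(1, 3, 720, 1280)
    with torch.no_grad():
        pha, fgr = model(src, bgr)[:2]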
inference_utils.py
ADDED
@@ -0,0 +1,46 @@
import numpy as np
import cv2
from PIL import Image


class HomographicAlignment:
    """
    Apply homographic alignment on background to match with the source image.
    """

    def __init__(self):
        self.detector = cv2.ORB_create()
        self.matcher = cv2.DescriptorMatcher_create(cv2.DESCRIPTOR_MATCHER_BRUTEFORCE)

    def __call__(self, src, bgr):
        src = np.asarray(src)
        bgr = np.asarray(bgr)

        keypoints_src, descriptors_src = self.detector.detectAndCompute(src, None)
        keypoints_bgr, descriptors_bgr = self.detector.detectAndCompute(bgr, None)

        matches = self.matcher.match(descriptors_bgr, descriptors_src, None)
        matches.sort(key=lambda x: x.distance, reverse=False)
        num_good_matches = int(len(matches) * 0.15)
        matches = matches[:num_good_matches]

        points_src = np.zeros((len(matches), 2), dtype=np.float32)
        points_bgr = np.zeros((len(matches), 2), dtype=np.float32)
        for i, match in enumerate(matches):
            points_src[i, :] = keypoints_src[match.trainIdx].pt
            points_bgr[i, :] = keypoints_bgr[match.queryIdx].pt

        H, _ = cv2.findHomography(points_bgr, points_src, cv2.RANSAC)

        h, w = src.shape[:2]
        bgr = cv2.warpPerspective(bgr, H, (w, h))
        msk = cv2.warpPerspective(np.ones((h, w)), H, (w, h))

        # For areas that are outside of the background,
        # we just copy pixels from the source.
        bgr[msk != 1] = src[msk != 1]

        src = Image.fromarray(src)
        bgr = Image.fromarray(bgr)

        return src, bgr
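Editor's note: a small usage sketch of HomographicAlignment on a PIL image pair; the paths are placeholders.

    # Hypothetical example: warp the captured background so it lines up with the
    # source frame before matting.
    from PIL import Image
    from inference_utils import HomographicAlignment

    align = HomographicAlignment()
    src = Image.open('src.png').convert('RGB')        # placeholder paths
    bgr = Image.open('bgr.png').convert('RGB')
    src, bgr = align(src, bgr)                        # bgr is now aligned to src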
inference_video.py
ADDED
@@ -0,0 +1,245 @@
import argparse
import cv2
import torch
import os
import shutil

from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import transforms as T
from torchvision.transforms.functional import to_pil_image
from threading import Thread
from tqdm import tqdm
from PIL import Image
import gradio as gr
from dataset import VideoDataset, ZipDataset
from dataset import augmentation as A
from model import MattingBase, MattingRefine
from inference_utils import HomographicAlignment


# --------------- Arguments ---------------


# --------------- Utils ---------------


class VideoWriter:
    def __init__(self, path, frame_rate, width, height):
        self.out = cv2.VideoWriter(path, cv2.VideoWriter_fourcc(*'mp4v'), frame_rate, (width, height))

    def add_batch(self, frames):
        frames = frames.mul(255).byte()
        frames = frames.cpu().permute(0, 2, 3, 1).numpy()
        for i in range(frames.shape[0]):
            frame = frames[i]
            frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
            self.out.write(frame)


class ImageSequenceWriter:
    def __init__(self, path, extension):
        self.path = path
        self.extension = extension
        self.index = 0
        os.makedirs(path)

    def add_batch(self, frames):
        Thread(target=self._add_batch, args=(frames, self.index)).start()
        self.index += frames.shape[0]

    def _add_batch(self, frames, index):
        frames = frames.cpu()
        for i in range(frames.shape[0]):
            frame = frames[i]
            frame = to_pil_image(frame)
            frame.save(os.path.join(self.path, str(index + i).zfill(5) + '.' + self.extension))


# --------------- Main ---------------

def video_matting(video_src_content, video_bgr_content):
    src_video_path = './source/src_video.mp4'
    bgr_image_path = './source/bgr_image.png'
    # Write the uploaded source video to disk.
    with open(src_video_path, 'wb') as video_file:
        video_file.write(video_src_content)

    # Write the uploaded background image to disk.
    with open(bgr_image_path, 'wb') as bgr_file:
        bgr_file.write(video_bgr_content)
    video_src = src_video_path
    video_bgr = bgr_image_path
    default_args = {
        'model_type': 'mattingrefine',
        'model_backbone': 'resnet50',
        'model_backbone_scale': 0.25,
        'model_refine_mode': 'sampling',
        'model_refine_sample_pixels': 80000,
        'model_checkpoint': './pytorch_resnet50.pth',
        'model_refine_threshold': 0.7,
        'model_refine_kernel_size': 3,
        'video_src': './source/src.mp4',
        'video_bgr': './source/bgr.png',
        'video_target_bgr': None,
        'video_resize': [1920, 1080],
        'device': 'cpu',  # defaults to CPU
        'preprocess_alignment': False,
        'output_dir': './output',
        'output_types': ['com'],
        'output_format': 'video'
    }

    args = argparse.Namespace(**default_args)
    device = torch.device(args.device)

    # Load model
    if args.model_type == 'mattingbase':
        model = MattingBase(args.model_backbone)
    if args.model_type == 'mattingrefine':
        model = MattingRefine(
            args.model_backbone,
            args.model_backbone_scale,
            args.model_refine_mode,
            args.model_refine_sample_pixels,
            args.model_refine_threshold,
            args.model_refine_kernel_size)

    model = model.to(device).eval()
    model.load_state_dict(torch.load(args.model_checkpoint, map_location=device), strict=False)

    # Load video and background
    vid = VideoDataset(video_src)
    bgr = [Image.open(video_bgr).convert('RGB')]
    dataset = ZipDataset([vid, bgr], transforms=A.PairCompose([
        A.PairApply(T.Resize(args.video_resize[::-1]) if args.video_resize else nn.Identity()),
        HomographicAlignment() if args.preprocess_alignment else A.PairApply(nn.Identity()),
        A.PairApply(T.ToTensor())
    ]))
    if args.video_target_bgr:
        dataset = ZipDataset([dataset, VideoDataset(args.video_target_bgr, transforms=T.ToTensor())])

    # Create output directory
    # if os.path.exists(args.output_dir):
    #     if input(f'Directory {args.output_dir} already exists. Override? [Y/N]: ').lower() == 'y':
    #         shutil.rmtree(args.output_dir)
    #     else:
    #         exit()
    # os.makedirs(args.output_dir)

    # Prepare writers
    if args.output_format == 'video':
        h = args.video_resize[1] if args.video_resize is not None else vid.height
        w = args.video_resize[0] if args.video_resize is not None else vid.width
        if 'com' in args.output_types:
            com_writer = VideoWriter(os.path.join(args.output_dir, 'com.mp4'), vid.frame_rate, w, h)
        if 'pha' in args.output_types:
            pha_writer = VideoWriter(os.path.join(args.output_dir, 'pha.mp4'), vid.frame_rate, w, h)
        if 'fgr' in args.output_types:
            fgr_writer = VideoWriter(os.path.join(args.output_dir, 'fgr.mp4'), vid.frame_rate, w, h)
        if 'err' in args.output_types:
            err_writer = VideoWriter(os.path.join(args.output_dir, 'err.mp4'), vid.frame_rate, w, h)
        if 'ref' in args.output_types:
            ref_writer = VideoWriter(os.path.join(args.output_dir, 'ref.mp4'), vid.frame_rate, w, h)
    else:
        if 'com' in args.output_types:
            com_writer = ImageSequenceWriter(os.path.join(args.output_dir, 'com'), 'png')
        if 'pha' in args.output_types:
            pha_writer = ImageSequenceWriter(os.path.join(args.output_dir, 'pha'), 'jpg')
        if 'fgr' in args.output_types:
            fgr_writer = ImageSequenceWriter(os.path.join(args.output_dir, 'fgr'), 'jpg')
        if 'err' in args.output_types:
            err_writer = ImageSequenceWriter(os.path.join(args.output_dir, 'err'), 'jpg')
        if 'ref' in args.output_types:
            ref_writer = ImageSequenceWriter(os.path.join(args.output_dir, 'ref'), 'jpg')

    # Conversion loop
    with torch.no_grad():
        for input_batch in tqdm(DataLoader(dataset, batch_size=1, pin_memory=True)):
            if args.video_target_bgr:
                (src, bgr), tgt_bgr = input_batch
                tgt_bgr = tgt_bgr.to(device, non_blocking=True)
            else:
                src, bgr = input_batch
                tgt_bgr = torch.tensor([120/255, 255/255, 155/255], device=device).view(1, 3, 1, 1)
            src = src.to(device, non_blocking=True)
            bgr = bgr.to(device, non_blocking=True)

            if args.model_type == 'mattingbase':
                pha, fgr, err, _ = model(src, bgr)
            elif args.model_type == 'mattingrefine':
                pha, fgr, _, _, err, ref = model(src, bgr)
            elif args.model_type == 'mattingbm':
                pha, fgr = model(src, bgr)

            if 'com' in args.output_types:
                if args.output_format == 'video':
                    # Output composite with green background
                    com = fgr * pha + tgt_bgr * (1 - pha)
                    com_writer.add_batch(com)
                else:
                    # Output composite as rgba png images
                    com = torch.cat([fgr * pha.ne(0), pha], dim=1)
                    com_writer.add_batch(com)
            if 'pha' in args.output_types:
                pha_writer.add_batch(pha)
            if 'fgr' in args.output_types:
                fgr_writer.add_batch(fgr)
            if 'err' in args.output_types:
                err_writer.add_batch(F.interpolate(err, src.shape[2:], mode='bilinear', align_corners=False))
            if 'ref' in args.output_types:
                ref_writer.add_batch(F.interpolate(ref, src.shape[2:], mode='nearest'))

    return './output/com.mp4'

# Read the binary data of a local video file.
def get_video_content(video_path):
    with open(video_path, 'rb') as file:
        video_content = file.read()
    return video_content

# Assume the video file path is './local_video.mp4'.
local_video_path = './output/com.mp4'
local_video_content = get_video_content(local_video_path)

# Build the Gradio interface.
with gr.Blocks() as demo:
    gr.Markdown("## Video Matting")
    with gr.Row():
        video_src = gr.File(label="Upload Source Video (.mp4)", type="binary", file_types=["mp4"])
        video_bgr = gr.File(label="Upload Background Image (.png)", type="binary", file_types=["png"])
    with gr.Row():
        output_video = gr.Video(label="Result Video")
    submit_button = gr.Button("Start Matting")

    # def download_video(video_path):
    #     if os.path.exists(video_path):
    #         with open(video_path, 'rb') as file:
    #             video_data = file.read()
    #         return video_data, "video/mp4", os.path.basename(video_path)
    #     else:
    #         return "Not Found", "text/plain", None

    def clear_outputs():
        output_video.update(value=None)

    submit_button.click(
        fn=video_matting,
        inputs=[video_src, video_bgr],
        outputs=[output_video]
    )
    # download_button = gr.Button("Download")
    # download_button.click(
    #     download_video,
    #     inputs=[output_video],  # pass the video path from the video component
    #     outputs=[gr.File(label="Download")]
    # )
    clear_button = gr.Button("Clear")
    clear_button.click(fn=clear_outputs, inputs=[], outputs=[])

if __name__ == "__main__":
    demo.launch()
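Editor's note: the compositing step inside the conversion loop is a standard alpha blend over a solid target background; a small standalone sketch of that operation on dummy tensors (shapes are illustrative).

    # Minimal sketch of the alpha compositing used above: the predicted foreground
    # is blended over a solid target background using the predicted alpha matte.
    import torch

    pha = torch.rand(1, 1, 720, 1280)                       # alpha in [0, 1]
    fgr = torch.rand(1, 3, 720, 1280)                       # foreground RGB
    tgt_bgr = torch.tensor([120/255, 255/255, 155/255]).view(1, 3, 1, 1)

    com = fgr * pha + tgt_bgr * (1 - pha)                   # (1, 3, 720, 1280) composite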
model/__init__.py
ADDED
@@ -0,0 +1 @@
from .model import Base, MattingBase, MattingRefine
model/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (230 Bytes)
model/__pycache__/decoder.cpython-38.pyc
ADDED
Binary file (2.08 kB)
model/__pycache__/mobilenet.cpython-38.pyc
ADDED
Binary file (1.85 kB)
model/__pycache__/model.cpython-38.pyc
ADDED
Binary file (8.26 kB)
model/__pycache__/refiner.cpython-38.pyc
ADDED
Binary file (9.4 kB)
model/__pycache__/resnet.cpython-38.pyc
ADDED
Binary file (1.66 kB)
model/__pycache__/utils.cpython-38.pyc
ADDED
Binary file (654 Bytes)
model/decoder.py
ADDED
@@ -0,0 +1,51 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class Decoder(nn.Module):
    """
    Decoder upsamples the image by combining the feature maps at all resolutions from the encoder.

    Input:
        x4: (B, C, H/16, W/16) feature map at 1/16 resolution.
        x3: (B, C, H/8, W/8) feature map at 1/8 resolution.
        x2: (B, C, H/4, W/4) feature map at 1/4 resolution.
        x1: (B, C, H/2, W/2) feature map at 1/2 resolution.
        x0: (B, C, H, W) feature map at full resolution.

    Output:
        x: (B, C, H, W) upsampled output at full resolution.
    """

    def __init__(self, channels, feature_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(feature_channels[0] + channels[0], channels[1], 3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels[1])
        self.conv2 = nn.Conv2d(feature_channels[1] + channels[1], channels[2], 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(channels[2])
        self.conv3 = nn.Conv2d(feature_channels[2] + channels[2], channels[3], 3, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(channels[3])
        self.conv4 = nn.Conv2d(feature_channels[3] + channels[3], channels[4], 3, padding=1)
        self.relu = nn.ReLU(True)

    def forward(self, x4, x3, x2, x1, x0):
        x = F.interpolate(x4, size=x3.shape[2:], mode='bilinear', align_corners=False)
        x = torch.cat([x, x3], dim=1)
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = F.interpolate(x, size=x2.shape[2:], mode='bilinear', align_corners=False)
        x = torch.cat([x, x2], dim=1)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)
        x = F.interpolate(x, size=x1.shape[2:], mode='bilinear', align_corners=False)
        x = torch.cat([x, x1], dim=1)
        x = self.conv3(x)
        x = self.bn3(x)
        x = self.relu(x)
        x = F.interpolate(x, size=x0.shape[2:], mode='bilinear', align_corners=False)
        x = torch.cat([x, x0], dim=1)
        x = self.conv4(x)
        return x
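Editor's note: a quick shape-check sketch for the decoder, using the ResNet channel configuration that model.py (later in this commit) passes in; the spatial sizes and out_channels = 37 (1 + 3 + 1 + 32) follow that file, and the weights here are random.

    # Hypothetical shape check: feed dummy multi-resolution features through the decoder.
    import torch
    from model.decoder import Decoder

    decoder = Decoder([256, 128, 64, 48, 37], [512, 256, 64, 6])
    x4 = torch.rand(1, 256, 16, 16)     # ASPP output at 1/16 resolution
    x3 = torch.rand(1, 512, 32, 32)     # encoder feature at 1/8
    x2 = torch.rand(1, 256, 64, 64)     # encoder feature at 1/4
    x1 = torch.rand(1, 64, 128, 128)    # encoder feature at 1/2
    x0 = torch.rand(1, 6, 256, 256)     # concatenated src + bgr at full resolution
    out = decoder(x4, x3, x2, x1, x0)   # -> (1, 37, 256, 256)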
model/mobilenet.py
ADDED
@@ -0,0 +1,56 @@
from torch import nn
from torchvision.models import MobileNetV2


class MobileNetV2Encoder(MobileNetV2):
    """
    MobileNetV2Encoder inherits from torchvision's official MobileNetV2. It is modified to
    use dilation on the last block to maintain output stride 16, and deleted the
    classifier block that was originally used for classification. The forward method
    additionally returns the feature maps at all resolutions for decoder's use.
    """

    def __init__(self, in_channels, norm_layer=None):
        super().__init__()

        # Replace first conv layer if in_channels doesn't match.
        if in_channels != 3:
            self.features[0][0] = nn.Conv2d(in_channels, 32, 3, 2, 1, bias=False)

        # Remove last block
        self.features = self.features[:-1]

        # Change to use dilation to maintain output stride = 16
        self.features[14].conv[1][0].stride = (1, 1)
        for feature in self.features[15:]:
            feature.conv[1][0].dilation = (2, 2)
            feature.conv[1][0].padding = (2, 2)

        # Delete classifier
        del self.classifier

    def forward(self, x):
        x0 = x  # 1/1
        x = self.features[0](x)
        x = self.features[1](x)
        x1 = x  # 1/2
        x = self.features[2](x)
        x = self.features[3](x)
        x2 = x  # 1/4
        x = self.features[4](x)
        x = self.features[5](x)
        x = self.features[6](x)
        x3 = x  # 1/8
        x = self.features[7](x)
        x = self.features[8](x)
        x = self.features[9](x)
        x = self.features[10](x)
        x = self.features[11](x)
        x = self.features[12](x)
        x = self.features[13](x)
        x = self.features[14](x)
        x = self.features[15](x)
        x = self.features[16](x)
        x = self.features[17](x)
        x4 = x  # 1/16
        return x4, x3, x2, x1, x0
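Editor's note: a shape-check sketch for the encoder above; in_channels = 6 (concatenated src + bgr) and the listed channel counts follow the decoder configuration used in model.py, with randomly initialized weights.

    # Hypothetical shape check: the encoder returns features at 1/16, 1/8, 1/4, 1/2
    # and full resolution for the decoder's skip connections.
    import torch
    from model.mobilenet import MobileNetV2Encoder

    encoder = MobileNetV2Encoder(in_channels=6)
    x = torch.rand(1, 6, 256, 256)
    x4, x3, x2, x1, x0 = encoder(x)
    print(x4.shape, x3.shape, x2.shape, x1.shape, x0.shape)
    # expected channels: 320, 32, 24, 16 and 6 respectively (output stride 16)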
model/model.py
ADDED
@@ -0,0 +1,196 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import nn
|
3 |
+
from torch.nn import functional as F
|
4 |
+
from torchvision.models.segmentation.deeplabv3 import ASPP
|
5 |
+
|
6 |
+
from .decoder import Decoder
|
7 |
+
from .mobilenet import MobileNetV2Encoder
|
8 |
+
from .refiner import Refiner
|
9 |
+
from .resnet import ResNetEncoder
|
10 |
+
from .utils import load_matched_state_dict
|
11 |
+
|
12 |
+
|
13 |
+
class Base(nn.Module):
|
14 |
+
"""
|
15 |
+
A generic implementation of the base encoder-decoder network inspired by DeepLab.
|
16 |
+
Accepts arbitrary channels for input and output.
|
17 |
+
"""
|
18 |
+
|
19 |
+
def __init__(self, backbone: str, in_channels: int, out_channels: int):
|
20 |
+
super().__init__()
|
21 |
+
assert backbone in ["resnet50", "resnet101", "mobilenetv2"]
|
22 |
+
if backbone in ['resnet50', 'resnet101']:
|
23 |
+
self.backbone = ResNetEncoder(in_channels, variant=backbone)
|
24 |
+
self.aspp = ASPP(2048, [3, 6, 9])
|
25 |
+
self.decoder = Decoder([256, 128, 64, 48, out_channels], [512, 256, 64, in_channels])
|
26 |
+
else:
|
27 |
+
self.backbone = MobileNetV2Encoder(in_channels)
|
28 |
+
self.aspp = ASPP(320, [3, 6, 9])
|
29 |
+
self.decoder = Decoder([256, 128, 64, 48, out_channels], [32, 24, 16, in_channels])
|
30 |
+
|
31 |
+
def forward(self, x):
|
32 |
+
x, *shortcuts = self.backbone(x)
|
33 |
+
x = self.aspp(x)
|
34 |
+
x = self.decoder(x, *shortcuts)
|
35 |
+
return x
|
36 |
+
|
37 |
+
def load_pretrained_deeplabv3_state_dict(self, state_dict, print_stats=True):
|
38 |
+
# Pretrained DeepLabV3 models are provided by <https://github.com/VainF/DeepLabV3Plus-Pytorch>.
|
39 |
+
# This method converts and loads their pretrained state_dict to match with our model structure.
|
40 |
+
# This method is not needed if you are not planning to train from deeplab weights.
|
41 |
+
# Use load_state_dict() for normal weight loading.
|
42 |
+
|
43 |
+
# Convert state_dict naming for aspp module
|
44 |
+
state_dict = {k.replace('classifier.classifier.0', 'aspp'): v for k, v in state_dict.items()}
|
45 |
+
|
46 |
+
if isinstance(self.backbone, ResNetEncoder):
|
47 |
+
# ResNet backbone does not need change.
|
48 |
+
load_matched_state_dict(self, state_dict, print_stats)
|
49 |
+
else:
|
50 |
+
# Change MobileNetV2 backbone to state_dict format, then change back after loading.
|
51 |
+
backbone_features = self.backbone.features
|
52 |
+
self.backbone.low_level_features = backbone_features[:4]
|
53 |
+
self.backbone.high_level_features = backbone_features[4:]
|
54 |
+
del self.backbone.features
|
55 |
+
load_matched_state_dict(self, state_dict, print_stats)
|
56 |
+
self.backbone.features = backbone_features
|
57 |
+
del self.backbone.low_level_features
|
58 |
+
del self.backbone.high_level_features
|
59 |
+
|
60 |
+
|
61 |
+
class MattingBase(Base):
|
62 |
+
"""
|
63 |
+
MattingBase is used to produce coarse global results at a lower resolution.
|
64 |
+
MattingBase extends Base.
|
65 |
+
|
66 |
+
Args:
|
67 |
+
backbone: ["resnet50", "resnet101", "mobilenetv2"]
|
68 |
+
|
69 |
+
Input:
|
70 |
+
src: (B, 3, H, W) the source image. Channels are RGB values normalized to 0 ~ 1.
|
71 |
+
bgr: (B, 3, H, W) the background image . Channels are RGB values normalized to 0 ~ 1.
|
72 |
+
|
73 |
+
Output:
|
74 |
+
pha: (B, 1, H, W) the alpha prediction. Normalized to 0 ~ 1.
|
75 |
+
fgr: (B, 3, H, W) the foreground prediction. Channels are RGB values normalized to 0 ~ 1.
|
76 |
+
err: (B, 1, H, W) the error prediction. Normalized to 0 ~ 1.
|
77 |
+
hid: (B, 32, H, W) the hidden encoding. Used for connecting refiner module.
|
78 |
+
|
79 |
+
Example:
|
80 |
+
model = MattingBase(backbone='resnet50')
|
81 |
+
|
82 |
+
pha, fgr, err, hid = model(src, bgr) # for training
|
83 |
+
pha, fgr = model(src, bgr)[:2] # for inference
|
84 |
+
"""
|
85 |
+
|
86 |
+
def __init__(self, backbone: str):
|
87 |
+
super().__init__(backbone, in_channels=6, out_channels=(1 + 3 + 1 + 32))
|
88 |
+
|
89 |
+
def forward(self, src, bgr):
|
90 |
+
x = torch.cat([src, bgr], dim=1)
|
91 |
+
x, *shortcuts = self.backbone(x)
|
92 |
+
x = self.aspp(x)
|
93 |
+
x = self.decoder(x, *shortcuts)
|
94 |
+
pha = x[:, 0:1].clamp_(0., 1.)
|
95 |
+
fgr = x[:, 1:4].add(src).clamp_(0., 1.)
|
96 |
+
err = x[:, 4:5].clamp_(0., 1.)
|
97 |
+
hid = x[:, 5: ].relu_()
|
98 |
+
return pha, fgr, err, hid
|
99 |
+
|
100 |
+
|
101 |
+
class MattingRefine(MattingBase):
|
102 |
+
"""
|
103 |
+
MattingRefine includes the refiner module to upsample coarse result to full resolution.
|
104 |
+
MattingRefine extends MattingBase.
|
105 |
+
|
106 |
+
Args:
|
107 |
+
backbone: ["resnet50", "resnet101", "mobilenetv2"]
|
108 |
+
backbone_scale: The image downsample scale for passing through backbone, default 1/4 or 0.25.
|
109 |
+
Must not be greater than 1/2.
|
110 |
+
refine_mode: refine area selection mode. Options:
|
111 |
+
"full" - No area selection, refine everywhere using regular Conv2d.
|
112 |
+
"sampling" - Refine fixed amount of pixels ranked by the top most errors.
|
113 |
+
"thresholding" - Refine varying amount of pixels that has more error than the threshold.
|
114 |
+
refine_sample_pixels: number of pixels to refine. Only used when mode == "sampling".
|
115 |
+
refine_threshold: error threshold ranged from 0 ~ 1. Refine where err > threshold. Only used when mode == "thresholding".
|
116 |
+
refine_kernel_size: the refiner's convolutional kernel size. Options: [1, 3]
|
117 |
+
refine_prevent_oversampling: prevent sampling more pixels than needed for sampling mode. Set False only for speedtest.
|
118 |
+
|
119 |
+
Input:
|
120 |
+
src: (B, 3, H, W) the source image. Channels are RGB values normalized to 0 ~ 1.
|
121 |
+
bgr: (B, 3, H, W) the background image. Channels are RGB values normalized to 0 ~ 1.
|
122 |
+
|
123 |
+
Output:
|
124 |
+
pha: (B, 1, H, W) the alpha prediction. Normalized to 0 ~ 1.
|
125 |
+
fgr: (B, 3, H, W) the foreground prediction. Channels are RGB values normalized to 0 ~ 1.
|
126 |
+
pha_sm: (B, 1, Hc, Wc) the coarse alpha prediction from matting base. Normalized to 0 ~ 1.
|
127 |
+
fgr_sm: (B, 3, Hc, Hc) the coarse foreground prediction from matting base. Normalized to 0 ~ 1.
|
128 |
+
err_sm: (B, 1, Hc, Wc) the coarse error prediction from matting base. Normalized to 0 ~ 1.
|
129 |
+
ref_sm: (B, 1, H/4, H/4) the quarter resolution refinement map. 1 indicates refined 4x4 patch locations.
|
130 |
+
|
131 |
+
Example:
|
132 |
+
model = MattingRefine(backbone='resnet50', backbone_scale=1/4, refine_mode='sampling', refine_sample_pixels=80_000)
|
133 |
+
model = MattingRefine(backbone='resnet50', backbone_scale=1/4, refine_mode='thresholding', refine_threshold=0.1)
|
134 |
+
model = MattingRefine(backbone='resnet50', backbone_scale=1/4, refine_mode='full')
|
135 |
+
|
136 |
+
pha, fgr, pha_sm, fgr_sm, err_sm, ref_sm = model(src, bgr) # for training
|
137 |
+
pha, fgr = model(src, bgr)[:2] # for inference
|
138 |
+
"""
|
139 |
+
|
140 |
+
def __init__(self,
|
141 |
+
backbone: str,
|
142 |
+
backbone_scale: float = 1/4,
|
143 |
+
refine_mode: str = 'sampling',
|
144 |
+
refine_sample_pixels: int = 80_000,
|
145 |
+
refine_threshold: float = 0.1,
|
146 |
+
refine_kernel_size: int = 3,
|
147 |
+
refine_prevent_oversampling: bool = True,
|
148 |
+
refine_patch_crop_method: str = 'unfold',
|
149 |
+
refine_patch_replace_method: str = 'scatter_nd'):
|
150 |
+
assert backbone_scale <= 1/2, 'backbone_scale should not be greater than 1/2'
|
151 |
+
super().__init__(backbone)
|
152 |
+
self.backbone_scale = backbone_scale
|
153 |
+
self.refiner = Refiner(refine_mode,
|
154 |
+
refine_sample_pixels,
|
155 |
+
refine_threshold,
|
156 |
+
refine_kernel_size,
|
157 |
+
refine_prevent_oversampling,
|
158 |
+
refine_patch_crop_method,
|
159 |
+
refine_patch_replace_method)
|
160 |
+
|
161 |
+
def forward(self, src, bgr):
|
162 |
+
assert src.size() == bgr.size(), 'src and bgr must have the same shape'
|
163 |
+
assert src.size(2) // 4 * 4 == src.size(2) and src.size(3) // 4 * 4 == src.size(3), \
|
164 |
+
'src and bgr must have width and height that are divisible by 4'
|
165 |
+
|
166 |
+
# Downsample src and bgr for backbone
|
167 |
+
src_sm = F.interpolate(src,
|
168 |
+
scale_factor=self.backbone_scale,
|
169 |
+
mode='bilinear',
|
170 |
+
align_corners=False,
|
171 |
+
recompute_scale_factor=True)
|
172 |
+
bgr_sm = F.interpolate(bgr,
|
173 |
+
scale_factor=self.backbone_scale,
|
174 |
+
mode='bilinear',
|
175 |
+
align_corners=False,
|
176 |
+
recompute_scale_factor=True)
|
177 |
+
|
178 |
+
# Base
|
179 |
+
x = torch.cat([src_sm, bgr_sm], dim=1)
|
180 |
+
x, *shortcuts = self.backbone(x)
|
181 |
+
x = self.aspp(x)
|
182 |
+
x = self.decoder(x, *shortcuts)
|
183 |
+
pha_sm = x[:, 0:1].clamp_(0., 1.)
|
184 |
+
fgr_sm = x[:, 1:4]
|
185 |
+
err_sm = x[:, 4:5].clamp_(0., 1.)
|
186 |
+
hid_sm = x[:, 5: ].relu_()
|
187 |
+
|
188 |
+
# Refiner
|
189 |
+
pha, fgr, ref_sm = self.refiner(src, bgr, pha_sm, fgr_sm, err_sm, hid_sm)
|
190 |
+
|
191 |
+
# Clamp outputs
|
192 |
+
pha = pha.clamp_(0., 1.)
|
193 |
+
fgr = fgr.add_(src).clamp_(0., 1.)
|
194 |
+
fgr_sm = src_sm.add_(fgr_sm).clamp_(0., 1.)
|
195 |
+
|
196 |
+
return pha, fgr, pha_sm, fgr_sm, err_sm, ref_sm
|
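As a quick orientation for the MattingRefine class above, here is a minimal inference sketch. It assumes only what the docstring states (RGB inputs normalized to 0 ~ 1, height and width divisible by 4); the checkpoint path and image filenames are hypothetical placeholders.

import torch
from PIL import Image
from torchvision import transforms as T

from model import MattingRefine

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = MattingRefine(backbone='resnet50',
                      backbone_scale=1/4,
                      refine_mode='sampling',
                      refine_sample_pixels=80_000).to(device).eval()
# Hypothetical checkpoint path.
model.load_state_dict(torch.load('checkpoint/mattingrefine-resnet50/epoch-0.pth', map_location=device))

to_tensor = T.ToTensor()
# Hypothetical image files; both must have the same size, divisible by 4.
src = to_tensor(Image.open('src.png').convert('RGB')).unsqueeze(0).to(device)
bgr = to_tensor(Image.open('bgr.png').convert('RGB')).unsqueeze(0).to(device)

with torch.no_grad():
    pha, fgr = model(src, bgr)[:2]  # alpha matte and foreground, both in 0 ~ 1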
model/refiner.py
ADDED
@@ -0,0 +1,282 @@
1 |
+
import torch
|
2 |
+
import torchvision
|
3 |
+
from torch import nn
|
4 |
+
from torch.nn import functional as F
|
5 |
+
from typing import Tuple
|
6 |
+
|
7 |
+
|
8 |
+
class Refiner(nn.Module):
|
9 |
+
"""
|
10 |
+
Refiner refines the coarse output to full resolution.
|
11 |
+
|
12 |
+
Args:
|
13 |
+
mode: area selection mode. Options:
|
14 |
+
"full" - No area selection, refine everywhere using regular Conv2d.
|
15 |
+
"sampling" - Refine fixed amount of pixels ranked by the top most errors.
|
16 |
+
"thresholding" - Refine varying amount of pixels that have greater error than the threshold.
|
17 |
+
sample_pixels: number of pixels to refine. Only used when mode == "sampling".
|
18 |
+
threshold: error threshold in the range 0 ~ 1. Refine where err > threshold. Only used when mode == "thresholding".
|
19 |
+
kernel_size: The convolution kernel_size. Options: [1, 3]
|
20 |
+
prevent_oversampling: True for regular use, False only for speed tests.
|
21 |
+
|
22 |
+
Compatibility Args:
|
23 |
+
patch_crop_method: the method for cropping patches. Options:
|
24 |
+
"unfold" - Best performance for PyTorch and TorchScript.
|
25 |
+
"roi_align" - Another way for croping patches.
|
26 |
+
"gather" - Another way for croping patches.
|
27 |
+
patch_replace_method: the method for replacing patches. Options:
|
28 |
+
"scatter_nd" - Best performance for PyTorch and TorchScript.
|
29 |
+
"scatter_element" - Another way for replacing patches.
|
30 |
+
|
31 |
+
Input:
|
32 |
+
src: (B, 3, H, W) full resolution source image.
|
33 |
+
bgr: (B, 3, H, W) full resolution background image.
|
34 |
+
pha: (B, 1, Hc, Wc) coarse alpha prediction.
|
35 |
+
fgr: (B, 3, Hc, Wc) coarse foreground residual prediction.
|
36 |
+
err: (B, 1, Hc, Wc) coarse error prediction.
|
37 |
+
hid: (B, 32, Hc, Wc) coarse hidden encoding.
|
38 |
+
|
39 |
+
Output:
|
40 |
+
pha: (B, 1, H, W) full resolution alpha prediction.
|
41 |
+
fgr: (B, 3, H, W) full resolution foreground residual prediction.
|
42 |
+
ref: (B, 1, H/4, W/4) quarter resolution refinement selection map. 1 indicates refined 4x4 patch locations.
|
43 |
+
"""
|
44 |
+
|
45 |
+
# For TorchScript export optimization.
|
46 |
+
__constants__ = ['kernel_size', 'patch_crop_method', 'patch_replace_method']
|
47 |
+
|
48 |
+
def __init__(self,
|
49 |
+
mode: str,
|
50 |
+
sample_pixels: int,
|
51 |
+
threshold: float,
|
52 |
+
kernel_size: int = 3,
|
53 |
+
prevent_oversampling: bool = True,
|
54 |
+
patch_crop_method: str = 'unfold',
|
55 |
+
patch_replace_method: str = 'scatter_nd'):
|
56 |
+
super().__init__()
|
57 |
+
assert mode in ['full', 'sampling', 'thresholding']
|
58 |
+
assert kernel_size in [1, 3]
|
59 |
+
assert patch_crop_method in ['unfold', 'roi_align', 'gather']
|
60 |
+
assert patch_replace_method in ['scatter_nd', 'scatter_element']
|
61 |
+
|
62 |
+
self.mode = mode
|
63 |
+
self.sample_pixels = sample_pixels
|
64 |
+
self.threshold = threshold
|
65 |
+
self.kernel_size = kernel_size
|
66 |
+
self.prevent_oversampling = prevent_oversampling
|
67 |
+
self.patch_crop_method = patch_crop_method
|
68 |
+
self.patch_replace_method = patch_replace_method
|
69 |
+
|
70 |
+
channels = [32, 24, 16, 12, 4]
|
71 |
+
self.conv1 = nn.Conv2d(channels[0] + 6 + 4, channels[1], kernel_size, bias=False)
|
72 |
+
self.bn1 = nn.BatchNorm2d(channels[1])
|
73 |
+
self.conv2 = nn.Conv2d(channels[1], channels[2], kernel_size, bias=False)
|
74 |
+
self.bn2 = nn.BatchNorm2d(channels[2])
|
75 |
+
self.conv3 = nn.Conv2d(channels[2] + 6, channels[3], kernel_size, bias=False)
|
76 |
+
self.bn3 = nn.BatchNorm2d(channels[3])
|
77 |
+
self.conv4 = nn.Conv2d(channels[3], channels[4], kernel_size, bias=True)
|
78 |
+
self.relu = nn.ReLU(True)
|
79 |
+
|
80 |
+
def forward(self,
|
81 |
+
src: torch.Tensor,
|
82 |
+
bgr: torch.Tensor,
|
83 |
+
pha: torch.Tensor,
|
84 |
+
fgr: torch.Tensor,
|
85 |
+
err: torch.Tensor,
|
86 |
+
hid: torch.Tensor):
|
87 |
+
H_full, W_full = src.shape[2:]
|
88 |
+
H_half, W_half = H_full // 2, W_full // 2
|
89 |
+
H_quat, W_quat = H_full // 4, W_full // 4
|
90 |
+
|
91 |
+
src_bgr = torch.cat([src, bgr], dim=1)
|
92 |
+
|
93 |
+
if self.mode != 'full':
|
94 |
+
err = F.interpolate(err, (H_quat, W_quat), mode='bilinear', align_corners=False)
|
95 |
+
ref = self.select_refinement_regions(err)
|
96 |
+
idx = torch.nonzero(ref.squeeze(1))
|
97 |
+
idx = idx[:, 0], idx[:, 1], idx[:, 2]
|
98 |
+
|
99 |
+
if idx[0].size(0) > 0:
|
100 |
+
x = torch.cat([hid, pha, fgr], dim=1)
|
101 |
+
x = F.interpolate(x, (H_half, W_half), mode='bilinear', align_corners=False)
|
102 |
+
x = self.crop_patch(x, idx, 2, 3 if self.kernel_size == 3 else 0)
|
103 |
+
|
104 |
+
y = F.interpolate(src_bgr, (H_half, W_half), mode='bilinear', align_corners=False)
|
105 |
+
y = self.crop_patch(y, idx, 2, 3 if self.kernel_size == 3 else 0)
|
106 |
+
|
107 |
+
x = self.conv1(torch.cat([x, y], dim=1))
|
108 |
+
x = self.bn1(x)
|
109 |
+
x = self.relu(x)
|
110 |
+
x = self.conv2(x)
|
111 |
+
x = self.bn2(x)
|
112 |
+
x = self.relu(x)
|
113 |
+
|
114 |
+
x = F.interpolate(x, 8 if self.kernel_size == 3 else 4, mode='nearest')
|
115 |
+
y = self.crop_patch(src_bgr, idx, 4, 2 if self.kernel_size == 3 else 0)
|
116 |
+
|
117 |
+
x = self.conv3(torch.cat([x, y], dim=1))
|
118 |
+
x = self.bn3(x)
|
119 |
+
x = self.relu(x)
|
120 |
+
x = self.conv4(x)
|
121 |
+
|
122 |
+
out = torch.cat([pha, fgr], dim=1)
|
123 |
+
out = F.interpolate(out, (H_full, W_full), mode='bilinear', align_corners=False)
|
124 |
+
out = self.replace_patch(out, x, idx)
|
125 |
+
pha = out[:, :1]
|
126 |
+
fgr = out[:, 1:]
|
127 |
+
else:
|
128 |
+
pha = F.interpolate(pha, (H_full, W_full), mode='bilinear', align_corners=False)
|
129 |
+
fgr = F.interpolate(fgr, (H_full, W_full), mode='bilinear', align_corners=False)
|
130 |
+
else:
|
131 |
+
x = torch.cat([hid, pha, fgr], dim=1)
|
132 |
+
x = F.interpolate(x, (H_half, W_half), mode='bilinear', align_corners=False)
|
133 |
+
y = F.interpolate(src_bgr, (H_half, W_half), mode='bilinear', align_corners=False)
|
134 |
+
if self.kernel_size == 3:
|
135 |
+
x = F.pad(x, (3, 3, 3, 3))
|
136 |
+
y = F.pad(y, (3, 3, 3, 3))
|
137 |
+
|
138 |
+
x = self.conv1(torch.cat([x, y], dim=1))
|
139 |
+
x = self.bn1(x)
|
140 |
+
x = self.relu(x)
|
141 |
+
x = self.conv2(x)
|
142 |
+
x = self.bn2(x)
|
143 |
+
x = self.relu(x)
|
144 |
+
|
145 |
+
if self.kernel_size == 3:
|
146 |
+
x = F.interpolate(x, (H_full + 4, W_full + 4))
|
147 |
+
y = F.pad(src_bgr, (2, 2, 2, 2))
|
148 |
+
else:
|
149 |
+
x = F.interpolate(x, (H_full, W_full), mode='nearest')
|
150 |
+
y = src_bgr
|
151 |
+
|
152 |
+
x = self.conv3(torch.cat([x, y], dim=1))
|
153 |
+
x = self.bn3(x)
|
154 |
+
x = self.relu(x)
|
155 |
+
x = self.conv4(x)
|
156 |
+
|
157 |
+
pha = x[:, :1]
|
158 |
+
fgr = x[:, 1:]
|
159 |
+
ref = torch.ones((src.size(0), 1, H_quat, W_quat), device=src.device, dtype=src.dtype)
|
160 |
+
|
161 |
+
return pha, fgr, ref
|
162 |
+
|
163 |
+
def select_refinement_regions(self, err: torch.Tensor):
|
164 |
+
"""
|
165 |
+
Select refinement regions.
|
166 |
+
Input:
|
167 |
+
err: error map (B, 1, H, W)
|
168 |
+
Output:
|
169 |
+
ref: refinement regions (B, 1, H, W). FloatTensor. 1 is selected, 0 is not.
|
170 |
+
"""
|
171 |
+
if self.mode == 'sampling':
|
172 |
+
# Sampling mode.
|
173 |
+
b, _, h, w = err.shape
|
174 |
+
err = err.view(b, -1)
|
175 |
+
idx = err.topk(self.sample_pixels // 16, dim=1, sorted=False).indices
|
176 |
+
ref = torch.zeros_like(err)
|
177 |
+
ref.scatter_(1, idx, 1.)
|
178 |
+
if self.prevent_oversampling:
|
179 |
+
ref.mul_(err.gt(0).float())
|
180 |
+
ref = ref.view(b, 1, h, w)
|
181 |
+
else:
|
182 |
+
# Thresholding mode.
|
183 |
+
ref = err.gt(self.threshold).float()
|
184 |
+
return ref
|
185 |
+
|
186 |
+
def crop_patch(self,
|
187 |
+
x: torch.Tensor,
|
188 |
+
idx: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
|
189 |
+
size: int,
|
190 |
+
padding: int):
|
191 |
+
"""
|
192 |
+
Crops selected patches from image given indices.
|
193 |
+
|
194 |
+
Inputs:
|
195 |
+
x: image (B, C, H, W).
|
196 |
+
idx: selection indices Tuple[(P,), (P,), (P,),], where the 3 values are (B, H, W) index.
|
197 |
+
size: center size of the patch, also stride of the crop.
|
198 |
+
padding: expansion size of the patch.
|
199 |
+
Output:
|
200 |
+
patch: (P, C, h, w), where h = w = size + 2 * padding.
|
201 |
+
"""
|
202 |
+
if padding != 0:
|
203 |
+
x = F.pad(x, (padding,) * 4)
|
204 |
+
|
205 |
+
if self.patch_crop_method == 'unfold':
|
206 |
+
# Use unfold. Best performance for PyTorch and TorchScript.
|
207 |
+
return x.permute(0, 2, 3, 1) \
|
208 |
+
.unfold(1, size + 2 * padding, size) \
|
209 |
+
.unfold(2, size + 2 * padding, size)[idx[0], idx[1], idx[2]]
|
210 |
+
elif self.patch_crop_method == 'roi_align':
|
211 |
+
# Use roi_align. Best compatibility for ONNX.
|
212 |
+
idx = idx[0].type_as(x), idx[1].type_as(x), idx[2].type_as(x)
|
213 |
+
b = idx[0]
|
214 |
+
x1 = idx[2] * size - 0.5
|
215 |
+
y1 = idx[1] * size - 0.5
|
216 |
+
x2 = idx[2] * size + size + 2 * padding - 0.5
|
217 |
+
y2 = idx[1] * size + size + 2 * padding - 0.5
|
218 |
+
boxes = torch.stack([b, x1, y1, x2, y2], dim=1)
|
219 |
+
return torchvision.ops.roi_align(x, boxes, size + 2 * padding, sampling_ratio=1)
|
220 |
+
else:
|
221 |
+
# Use gather. Crops out patches pixel by pixel.
|
222 |
+
idx_pix = self.compute_pixel_indices(x, idx, size, padding)
|
223 |
+
pat = torch.gather(x.view(-1), 0, idx_pix.view(-1))
|
224 |
+
pat = pat.view(-1, x.size(1), size + 2 * padding, size + 2 * padding)
|
225 |
+
return pat
|
226 |
+
|
227 |
+
def replace_patch(self,
|
228 |
+
x: torch.Tensor,
|
229 |
+
y: torch.Tensor,
|
230 |
+
idx: Tuple[torch.Tensor, torch.Tensor, torch.Tensor]):
|
231 |
+
"""
|
232 |
+
Replaces patches back into image given index.
|
233 |
+
|
234 |
+
Inputs:
|
235 |
+
x: image (B, C, H, W)
|
236 |
+
y: patches (P, C, h, w)
|
237 |
+
idx: selection indices Tuple[(P,), (P,), (P,)] where the 3 values are (B, H, W) index.
|
238 |
+
|
239 |
+
Output:
|
240 |
+
image: (B, C, H, W), where patches at idx locations are replaced with y.
|
241 |
+
"""
|
242 |
+
xB, xC, xH, xW = x.shape
|
243 |
+
yB, yC, yH, yW = y.shape
|
244 |
+
if self.patch_replace_method == 'scatter_nd':
|
245 |
+
# Use scatter_nd. Best performance for PyTorch and TorchScript. Replacing patch by patch.
|
246 |
+
x = x.view(xB, xC, xH // yH, yH, xW // yW, yW).permute(0, 2, 4, 1, 3, 5)
|
247 |
+
x[idx[0], idx[1], idx[2]] = y
|
248 |
+
x = x.permute(0, 3, 1, 4, 2, 5).view(xB, xC, xH, xW)
|
249 |
+
return x
|
250 |
+
else:
|
251 |
+
# Use scatter_element. Best compatibility for ONNX. Replacing pixel by pixel.
|
252 |
+
idx_pix = self.compute_pixel_indices(x, idx, size=4, padding=0)
|
253 |
+
return x.view(-1).scatter_(0, idx_pix.view(-1), y.view(-1)).view(x.shape)
|
254 |
+
|
255 |
+
def compute_pixel_indices(self,
|
256 |
+
x: torch.Tensor,
|
257 |
+
idx: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
|
258 |
+
size: int,
|
259 |
+
padding: int):
|
260 |
+
"""
|
261 |
+
Compute selected pixel indices in the tensor.
|
262 |
+
Used for crop_method == 'gather' and replace_method == 'scatter_element', which crop and replace pixel by pixel.
|
263 |
+
Input:
|
264 |
+
x: image: (B, C, H, W)
|
265 |
+
idx: selection indices Tuple[(P,), (P,), (P,),], where the 3 values are (B, H, W) index.
|
266 |
+
size: center size of the patch, also stride of the crop.
|
267 |
+
padding: expansion size of the patch.
|
268 |
+
Output:
|
269 |
+
idx: (P, C, O, O) long tensor where O is the output size: size + 2 * padding, P is number of patches.
|
270 |
+
the elements are indices pointing into the input x.view(-1).
|
271 |
+
"""
|
272 |
+
B, C, H, W = x.shape
|
273 |
+
S, P = size, padding
|
274 |
+
O = S + 2 * P
|
275 |
+
b, y, x = idx
|
276 |
+
n = b.size(0)
|
277 |
+
c = torch.arange(C)
|
278 |
+
o = torch.arange(O)
|
279 |
+
idx_pat = (c * H * W).view(C, 1, 1).expand([C, O, O]) + (o * W).view(1, O, 1).expand([C, O, O]) + o.view(1, 1, O).expand([C, O, O])
|
280 |
+
idx_loc = b * W * H + y * W * S + x * S
|
281 |
+
idx_pix = idx_loc.view(-1, 1, 1, 1).expand([n, C, O, O]) + idx_pat.view(1, C, O, O).expand([n, C, O, O])
|
282 |
+
return idx_pix
|
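A small sketch of how select_refinement_regions above behaves on its own. In sampling mode roughly sample_pixels / 16 quarter-resolution patch locations are marked; the error-map size used here is an arbitrary example.

import torch
from model.refiner import Refiner

refiner = Refiner(mode='sampling', sample_pixels=80_000, threshold=0.1)
err = torch.rand(1, 1, 270, 480)              # example quarter-resolution error map
ref = refiner.select_refinement_regions(err)  # (1, 1, 270, 480) map of 0/1 selections
print(int(ref.sum()))                         # about 80_000 / 16 = 5000 selected patches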
model/resnet.py
ADDED
@@ -0,0 +1,48 @@
1 |
+
from torch import nn
|
2 |
+
from torchvision.models.resnet import ResNet, Bottleneck
|
3 |
+
|
4 |
+
|
5 |
+
class ResNetEncoder(ResNet):
|
6 |
+
"""
|
7 |
+
ResNetEncoder inherits from torchvision's official ResNet. It is modified to
|
8 |
+
use dilation on the last block to maintain output stride 16, and it removes the
|
9 |
+
global average pooling layer and the fully connected layer that were originally
|
10 |
+
used for classification. The forward method additionally returns the feature
|
11 |
+
maps at all resolutions for the decoder's use.
|
12 |
+
"""
|
13 |
+
|
14 |
+
layers = {
|
15 |
+
'resnet50': [3, 4, 6, 3],
|
16 |
+
'resnet101': [3, 4, 23, 3],
|
17 |
+
}
|
18 |
+
|
19 |
+
def __init__(self, in_channels, variant='resnet101', norm_layer=None):
|
20 |
+
super().__init__(
|
21 |
+
block=Bottleneck,
|
22 |
+
layers=self.layers[variant],
|
23 |
+
replace_stride_with_dilation=[False, False, True],
|
24 |
+
norm_layer=norm_layer)
|
25 |
+
|
26 |
+
# Replace first conv layer if in_channels doesn't match.
|
27 |
+
if in_channels != 3:
|
28 |
+
self.conv1 = nn.Conv2d(in_channels, 64, 7, 2, 3, bias=False)
|
29 |
+
|
30 |
+
# Delete fully-connected layer
|
31 |
+
del self.avgpool
|
32 |
+
del self.fc
|
33 |
+
|
34 |
+
def forward(self, x):
|
35 |
+
x0 = x # 1/1
|
36 |
+
x = self.conv1(x)
|
37 |
+
x = self.bn1(x)
|
38 |
+
x = self.relu(x)
|
39 |
+
x1 = x # 1/2
|
40 |
+
x = self.maxpool(x)
|
41 |
+
x = self.layer1(x)
|
42 |
+
x2 = x # 1/4
|
43 |
+
x = self.layer2(x)
|
44 |
+
x3 = x # 1/8
|
45 |
+
x = self.layer3(x)
|
46 |
+
x = self.layer4(x)
|
47 |
+
x4 = x # 1/16
|
48 |
+
return x4, x3, x2, x1, x0
|
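A minimal sketch of the encoder's multi-scale output, assuming a 6-channel input (concatenated source and background, as used by the matting models in this repo):

import torch
from model.resnet import ResNetEncoder

encoder = ResNetEncoder(in_channels=6, variant='resnet50').eval()
x = torch.rand(1, 6, 256, 256)
with torch.no_grad():
    x4, x3, x2, x1, x0 = encoder(x)
# x4 is 1/16 resolution (dilated last block), x3 is 1/8, x2 is 1/4, x1 is 1/2, x0 is the input itself.
print([t.shape for t in (x4, x3, x2, x1, x0)])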
model/utils.py
ADDED
@@ -0,0 +1,14 @@
1 |
+
def load_matched_state_dict(model, state_dict, print_stats=True):
|
2 |
+
"""
|
3 |
+
Only loads weights whose key and shape match. Other weights are ignored.
|
4 |
+
"""
|
5 |
+
num_matched, num_total = 0, 0
|
6 |
+
curr_state_dict = model.state_dict()
|
7 |
+
for key in curr_state_dict.keys():
|
8 |
+
num_total += 1
|
9 |
+
if key in state_dict and curr_state_dict[key].shape == state_dict[key].shape:
|
10 |
+
curr_state_dict[key] = state_dict[key]
|
11 |
+
num_matched += 1
|
12 |
+
model.load_state_dict(curr_state_dict)
|
13 |
+
if print_stats:
|
14 |
+
print(f'Loaded state_dict: {num_matched}/{num_total} matched')
|
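A short usage sketch for the helper above, for example when warm-starting MattingRefine from a MattingBase checkpoint whose refiner weights do not exist yet. The checkpoint path is a hypothetical placeholder.

import torch
from model import MattingRefine
from model.utils import load_matched_state_dict

model = MattingRefine(backbone='resnet50')
state_dict = torch.load('checkpoint/mattingbase-resnet50/epoch-7.pth', map_location='cpu')  # hypothetical path
load_matched_state_dict(model, state_dict)  # prints 'Loaded state_dict: <matched>/<total> matched'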
requirements.txt
ADDED
@@ -0,0 +1,11 @@
1 |
+
kornia==0.4.1
|
2 |
+
tensorboard==2.3.0
|
3 |
+
torch==1.7.0
|
4 |
+
torchvision==0.8.1
|
5 |
+
tqdm==4.51.0
|
6 |
+
opencv-python==4.4.0.44
|
7 |
+
onnxruntime==1.6.0
|
8 |
+
gradio
|
9 |
+
matplotlib
|
10 |
+
fastapi
|
11 |
+
aiohttp
|
train_base.py
ADDED
@@ -0,0 +1,265 @@
1 |
+
"""
|
2 |
+
Train MattingBase
|
3 |
+
|
4 |
+
You can download pretrained DeepLabV3 weights from <https://github.com/VainF/DeepLabV3Plus-Pytorch>
|
5 |
+
|
6 |
+
Example:
|
7 |
+
|
8 |
+
CUDA_VISIBLE_DEVICES=0 python train_base.py \
|
9 |
+
--dataset-name videomatte240k \
|
10 |
+
--model-backbone resnet50 \
|
11 |
+
--model-name mattingbase-resnet50-videomatte240k \
|
12 |
+
--model-pretrain-initialization "pretraining/best_deeplabv3_resnet50_voc_os16.pth" \
|
13 |
+
--epoch-end 8
|
14 |
+
|
15 |
+
"""
|
16 |
+
|
17 |
+
import argparse
|
18 |
+
import kornia
|
19 |
+
import torch
|
20 |
+
import os
|
21 |
+
import random
|
22 |
+
|
23 |
+
from torch import nn
|
24 |
+
from torch.nn import functional as F
|
25 |
+
from torch.cuda.amp import autocast, GradScaler
|
26 |
+
from torch.utils.tensorboard import SummaryWriter
|
27 |
+
from torch.utils.data import DataLoader
|
28 |
+
from torch.optim import Adam
|
29 |
+
from torchvision.utils import make_grid
|
30 |
+
from tqdm import tqdm
|
31 |
+
from torchvision import transforms as T
|
32 |
+
from PIL import Image
|
33 |
+
|
34 |
+
from data_path import DATA_PATH
|
35 |
+
from dataset import ImagesDataset, ZipDataset, VideoDataset, SampleDataset
|
36 |
+
from dataset import augmentation as A
|
37 |
+
from model import MattingBase
|
38 |
+
from model.utils import load_matched_state_dict
|
39 |
+
|
40 |
+
|
41 |
+
# --------------- Arguments ---------------
|
42 |
+
|
43 |
+
|
44 |
+
parser = argparse.ArgumentParser()
|
45 |
+
|
46 |
+
parser.add_argument('--dataset-name', type=str, required=True, choices=DATA_PATH.keys())
|
47 |
+
|
48 |
+
parser.add_argument('--model-backbone', type=str, required=True, choices=['resnet101', 'resnet50', 'mobilenetv2'])
|
49 |
+
parser.add_argument('--model-name', type=str, required=True)
|
50 |
+
parser.add_argument('--model-pretrain-initialization', type=str, default=None)
|
51 |
+
parser.add_argument('--model-last-checkpoint', type=str, default=None)
|
52 |
+
|
53 |
+
parser.add_argument('--batch-size', type=int, default=8)
|
54 |
+
parser.add_argument('--num-workers', type=int, default=16)
|
55 |
+
parser.add_argument('--epoch-start', type=int, default=0)
|
56 |
+
parser.add_argument('--epoch-end', type=int, required=True)
|
57 |
+
|
58 |
+
parser.add_argument('--log-train-loss-interval', type=int, default=10)
|
59 |
+
parser.add_argument('--log-train-images-interval', type=int, default=2000)
|
60 |
+
parser.add_argument('--log-valid-interval', type=int, default=5000)
|
61 |
+
|
62 |
+
parser.add_argument('--checkpoint-interval', type=int, default=5000)
|
63 |
+
|
64 |
+
args = parser.parse_args()
|
65 |
+
|
66 |
+
|
67 |
+
# --------------- Loading ---------------
|
68 |
+
|
69 |
+
|
70 |
+
def train():
|
71 |
+
|
72 |
+
# Training DataLoader
|
73 |
+
dataset_train = ZipDataset([
|
74 |
+
ZipDataset([
|
75 |
+
ImagesDataset(DATA_PATH[args.dataset_name]['train']['pha'], mode='L'),
|
76 |
+
ImagesDataset(DATA_PATH[args.dataset_name]['train']['fgr'], mode='RGB'),
|
77 |
+
], transforms=A.PairCompose([
|
78 |
+
A.PairRandomAffineAndResize((512, 512), degrees=(-5, 5), translate=(0.1, 0.1), scale=(0.4, 1), shear=(-5, 5)),
|
79 |
+
A.PairRandomHorizontalFlip(),
|
80 |
+
A.PairRandomBoxBlur(0.1, 5),
|
81 |
+
A.PairRandomSharpen(0.1),
|
82 |
+
A.PairApplyOnlyAtIndices([1], T.ColorJitter(0.15, 0.15, 0.15, 0.05)),
|
83 |
+
A.PairApply(T.ToTensor())
|
84 |
+
]), assert_equal_length=True),
|
85 |
+
ImagesDataset(DATA_PATH['backgrounds']['train'], mode='RGB', transforms=T.Compose([
|
86 |
+
A.RandomAffineAndResize((512, 512), degrees=(-5, 5), translate=(0.1, 0.1), scale=(1, 2), shear=(-5, 5)),
|
87 |
+
T.RandomHorizontalFlip(),
|
88 |
+
A.RandomBoxBlur(0.1, 5),
|
89 |
+
A.RandomSharpen(0.1),
|
90 |
+
T.ColorJitter(0.15, 0.15, 0.15, 0.05),
|
91 |
+
T.ToTensor()
|
92 |
+
])),
|
93 |
+
])
|
94 |
+
dataloader_train = DataLoader(dataset_train,
|
95 |
+
shuffle=True,
|
96 |
+
batch_size=args.batch_size,
|
97 |
+
num_workers=args.num_workers,
|
98 |
+
pin_memory=True)
|
99 |
+
|
100 |
+
# Validation DataLoader
|
101 |
+
dataset_valid = ZipDataset([
|
102 |
+
ZipDataset([
|
103 |
+
ImagesDataset(DATA_PATH[args.dataset_name]['valid']['pha'], mode='L'),
|
104 |
+
ImagesDataset(DATA_PATH[args.dataset_name]['valid']['fgr'], mode='RGB')
|
105 |
+
], transforms=A.PairCompose([
|
106 |
+
A.PairRandomAffineAndResize((512, 512), degrees=(-5, 5), translate=(0.1, 0.1), scale=(0.3, 1), shear=(-5, 5)),
|
107 |
+
A.PairApply(T.ToTensor())
|
108 |
+
]), assert_equal_length=True),
|
109 |
+
ImagesDataset(DATA_PATH['backgrounds']['valid'], mode='RGB', transforms=T.Compose([
|
110 |
+
A.RandomAffineAndResize((512, 512), degrees=(-5, 5), translate=(0.1, 0.1), scale=(1, 1.2), shear=(-5, 5)),
|
111 |
+
T.ToTensor()
|
112 |
+
])),
|
113 |
+
])
|
114 |
+
dataset_valid = SampleDataset(dataset_valid, 50)
|
115 |
+
dataloader_valid = DataLoader(dataset_valid,
|
116 |
+
pin_memory=True,
|
117 |
+
batch_size=args.batch_size,
|
118 |
+
num_workers=args.num_workers)
|
119 |
+
|
120 |
+
# Model
|
121 |
+
model = MattingBase(args.model_backbone).cuda()
|
122 |
+
|
123 |
+
if args.model_last_checkpoint is not None:
|
124 |
+
load_matched_state_dict(model, torch.load(args.model_last_checkpoint))
|
125 |
+
elif args.model_pretrain_initialization is not None:
|
126 |
+
model.load_pretrained_deeplabv3_state_dict(torch.load(args.model_pretrain_initialization)['model_state'])
|
127 |
+
|
128 |
+
optimizer = Adam([
|
129 |
+
{'params': model.backbone.parameters(), 'lr': 1e-4},
|
130 |
+
{'params': model.aspp.parameters(), 'lr': 5e-4},
|
131 |
+
{'params': model.decoder.parameters(), 'lr': 5e-4}
|
132 |
+
])
|
133 |
+
scaler = GradScaler()
|
134 |
+
|
135 |
+
# Logging and checkpoints
|
136 |
+
if not os.path.exists(f'checkpoint/{args.model_name}'):
|
137 |
+
os.makedirs(f'checkpoint/{args.model_name}')
|
138 |
+
writer = SummaryWriter(f'log/{args.model_name}')
|
139 |
+
|
140 |
+
# Run loop
|
141 |
+
for epoch in range(args.epoch_start, args.epoch_end):
|
142 |
+
for i, ((true_pha, true_fgr), true_bgr) in enumerate(tqdm(dataloader_train)):
|
143 |
+
step = epoch * len(dataloader_train) + i
|
144 |
+
|
145 |
+
true_pha = true_pha.cuda(non_blocking=True)
|
146 |
+
true_fgr = true_fgr.cuda(non_blocking=True)
|
147 |
+
true_bgr = true_bgr.cuda(non_blocking=True)
|
148 |
+
true_pha, true_fgr, true_bgr = random_crop(true_pha, true_fgr, true_bgr)
|
149 |
+
|
150 |
+
true_src = true_bgr.clone()
|
151 |
+
|
152 |
+
# Augment with shadow
|
153 |
+
aug_shadow_idx = torch.rand(len(true_src)) < 0.3
|
154 |
+
if aug_shadow_idx.any():
|
155 |
+
aug_shadow = true_pha[aug_shadow_idx].mul(0.3 * random.random())
|
156 |
+
aug_shadow = T.RandomAffine(degrees=(-5, 5), translate=(0.2, 0.2), scale=(0.5, 1.5), shear=(-5, 5))(aug_shadow)
|
157 |
+
aug_shadow = kornia.filters.box_blur(aug_shadow, (random.choice(range(20, 40)),) * 2)
|
158 |
+
true_src[aug_shadow_idx] = true_src[aug_shadow_idx].sub_(aug_shadow).clamp_(0, 1)
|
159 |
+
del aug_shadow
|
160 |
+
del aug_shadow_idx
|
161 |
+
|
162 |
+
# Composite foreground onto source
|
163 |
+
true_src = true_fgr * true_pha + true_src * (1 - true_pha)
|
164 |
+
|
165 |
+
# Augment with noise
|
166 |
+
aug_noise_idx = torch.rand(len(true_src)) < 0.4
|
167 |
+
if aug_noise_idx.any():
|
168 |
+
true_src[aug_noise_idx] = true_src[aug_noise_idx].add_(torch.randn_like(true_src[aug_noise_idx]).mul_(0.03 * random.random())).clamp_(0, 1)
|
169 |
+
true_bgr[aug_noise_idx] = true_bgr[aug_noise_idx].add_(torch.randn_like(true_bgr[aug_noise_idx]).mul_(0.03 * random.random())).clamp_(0, 1)
|
170 |
+
del aug_noise_idx
|
171 |
+
|
172 |
+
# Augment background with jitter
|
173 |
+
aug_jitter_idx = torch.rand(len(true_src)) < 0.8
|
174 |
+
if aug_jitter_idx.any():
|
175 |
+
true_bgr[aug_jitter_idx] = kornia.augmentation.ColorJitter(0.18, 0.18, 0.18, 0.1)(true_bgr[aug_jitter_idx])
|
176 |
+
del aug_jitter_idx
|
177 |
+
|
178 |
+
# Augment background with affine
|
179 |
+
aug_affine_idx = torch.rand(len(true_bgr)) < 0.3
|
180 |
+
if aug_affine_idx.any():
|
181 |
+
true_bgr[aug_affine_idx] = T.RandomAffine(degrees=(-1, 1), translate=(0.01, 0.01))(true_bgr[aug_affine_idx])
|
182 |
+
del aug_affine_idx
|
183 |
+
|
184 |
+
with autocast():
|
185 |
+
pred_pha, pred_fgr, pred_err = model(true_src, true_bgr)[:3]
|
186 |
+
loss = compute_loss(pred_pha, pred_fgr, pred_err, true_pha, true_fgr)
|
187 |
+
|
188 |
+
scaler.scale(loss).backward()
|
189 |
+
scaler.step(optimizer)
|
190 |
+
scaler.update()
|
191 |
+
optimizer.zero_grad()
|
192 |
+
|
193 |
+
if (i + 1) % args.log_train_loss_interval == 0:
|
194 |
+
writer.add_scalar('loss', loss, step)
|
195 |
+
|
196 |
+
if (i + 1) % args.log_train_images_interval == 0:
|
197 |
+
writer.add_image('train_pred_pha', make_grid(pred_pha, nrow=5), step)
|
198 |
+
writer.add_image('train_pred_fgr', make_grid(pred_fgr, nrow=5), step)
|
199 |
+
writer.add_image('train_pred_com', make_grid(pred_fgr * pred_pha, nrow=5), step)
|
200 |
+
writer.add_image('train_pred_err', make_grid(pred_err, nrow=5), step)
|
201 |
+
writer.add_image('train_true_src', make_grid(true_src, nrow=5), step)
|
202 |
+
writer.add_image('train_true_bgr', make_grid(true_bgr, nrow=5), step)
|
203 |
+
|
204 |
+
del true_pha, true_fgr, true_bgr
|
205 |
+
del pred_pha, pred_fgr, pred_err
|
206 |
+
|
207 |
+
if (i + 1) % args.log_valid_interval == 0:
|
208 |
+
valid(model, dataloader_valid, writer, step)
|
209 |
+
|
210 |
+
if (step + 1) % args.checkpoint_interval == 0:
|
211 |
+
torch.save(model.state_dict(), f'checkpoint/{args.model_name}/epoch-{epoch}-iter-{step}.pth')
|
212 |
+
|
213 |
+
torch.save(model.state_dict(), f'checkpoint/{args.model_name}/epoch-{epoch}.pth')
|
214 |
+
|
215 |
+
|
216 |
+
# --------------- Utils ---------------
|
217 |
+
|
218 |
+
|
219 |
+
def compute_loss(pred_pha, pred_fgr, pred_err, true_pha, true_fgr):
|
220 |
+
true_err = torch.abs(pred_pha.detach() - true_pha)
|
221 |
+
true_msk = true_pha != 0
|
222 |
+
return F.l1_loss(pred_pha, true_pha) + \
|
223 |
+
F.l1_loss(kornia.sobel(pred_pha), kornia.sobel(true_pha)) + \
|
224 |
+
F.l1_loss(pred_fgr * true_msk, true_fgr * true_msk) + \
|
225 |
+
F.mse_loss(pred_err, true_err)
|
226 |
+
|
227 |
+
|
228 |
+
def random_crop(*imgs):
|
229 |
+
w = random.choice(range(256, 512))
|
230 |
+
h = random.choice(range(256, 512))
|
231 |
+
results = []
|
232 |
+
for img in imgs:
|
233 |
+
img = kornia.resize(img, (max(h, w), max(h, w)))
|
234 |
+
img = kornia.center_crop(img, (h, w))
|
235 |
+
results.append(img)
|
236 |
+
return results
|
237 |
+
|
238 |
+
|
239 |
+
def valid(model, dataloader, writer, step):
|
240 |
+
model.eval()
|
241 |
+
loss_total = 0
|
242 |
+
loss_count = 0
|
243 |
+
with torch.no_grad():
|
244 |
+
for (true_pha, true_fgr), true_bgr in dataloader:
|
245 |
+
batch_size = true_pha.size(0)
|
246 |
+
|
247 |
+
true_pha = true_pha.cuda(non_blocking=True)
|
248 |
+
true_fgr = true_fgr.cuda(non_blocking=True)
|
249 |
+
true_bgr = true_bgr.cuda(non_blocking=True)
|
250 |
+
true_src = true_pha * true_fgr + (1 - true_pha) * true_bgr
|
251 |
+
|
252 |
+
pred_pha, pred_fgr, pred_err = model(true_src, true_bgr)[:3]
|
253 |
+
loss = compute_loss(pred_pha, pred_fgr, pred_err, true_pha, true_fgr)
|
254 |
+
loss_total += loss.cpu().item() * batch_size
|
255 |
+
loss_count += batch_size
|
256 |
+
|
257 |
+
writer.add_scalar('valid_loss', loss_total / loss_count, step)
|
258 |
+
model.train()
|
259 |
+
|
260 |
+
|
261 |
+
# --------------- Start ---------------
|
262 |
+
|
263 |
+
|
264 |
+
if __name__ == '__main__':
|
265 |
+
train()
|
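To make the data synthesis in the training loop above concrete, this is the compositing rule it relies on, shown here on dummy tensors (the shapes are arbitrary):

import torch

pha = torch.rand(2, 1, 64, 64)  # ground-truth alpha in 0 ~ 1
fgr = torch.rand(2, 3, 64, 64)  # ground-truth foreground RGB in 0 ~ 1
bgr = torch.rand(2, 3, 64, 64)  # background RGB in 0 ~ 1

# Same formula as `true_src = true_fgr * true_pha + true_src * (1 - true_pha)` above;
# the single-channel alpha broadcasts over the three RGB channels.
src = fgr * pha + bgr * (1 - pha)
assert src.shape == (2, 3, 64, 64)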
train_refine.py
ADDED
@@ -0,0 +1,309 @@
1 |
+
"""
|
2 |
+
Train MattingRefine
|
3 |
+
|
4 |
+
Supports multi-GPU training with DistributedDataParallel() and SyncBatchNorm.
|
5 |
+
Select GPUs through the CUDA_VISIBLE_DEVICES environment variable.
|
6 |
+
|
7 |
+
Example:
|
8 |
+
|
9 |
+
CUDA_VISIBLE_DEVICES=0,1 python train_refine.py \
|
10 |
+
--dataset-name videomatte240k \
|
11 |
+
--model-backbone resnet50 \
|
12 |
+
--model-name mattingrefine-resnet50-videomatte240k \
|
13 |
+
--model-last-checkpoint "PATH_TO_LAST_CHECKPOINT" \
|
14 |
+
--epoch-end 1
|
15 |
+
|
16 |
+
"""
|
17 |
+
|
18 |
+
import argparse
|
19 |
+
import kornia
|
20 |
+
import torch
|
21 |
+
import os
|
22 |
+
import random
|
23 |
+
|
24 |
+
from torch import nn
|
25 |
+
from torch import distributed as dist
|
26 |
+
from torch import multiprocessing as mp
|
27 |
+
from torch.nn import functional as F
|
28 |
+
from torch.cuda.amp import autocast, GradScaler
|
29 |
+
from torch.utils.tensorboard import SummaryWriter
|
30 |
+
from torch.utils.data import DataLoader, Subset
|
31 |
+
from torch.optim import Adam
|
32 |
+
from torchvision.utils import make_grid
|
33 |
+
from tqdm import tqdm
|
34 |
+
from torchvision import transforms as T
|
35 |
+
from PIL import Image
|
36 |
+
|
37 |
+
from data_path import DATA_PATH
|
38 |
+
from dataset import ImagesDataset, ZipDataset, VideoDataset, SampleDataset
|
39 |
+
from dataset import augmentation as A
|
40 |
+
from model import MattingRefine
|
41 |
+
from model.utils import load_matched_state_dict
|
42 |
+
|
43 |
+
|
44 |
+
# --------------- Arguments ---------------
|
45 |
+
|
46 |
+
|
47 |
+
parser = argparse.ArgumentParser()
|
48 |
+
|
49 |
+
parser.add_argument('--dataset-name', type=str, required=True, choices=DATA_PATH.keys())
|
50 |
+
|
51 |
+
parser.add_argument('--model-backbone', type=str, required=True, choices=['resnet101', 'resnet50', 'mobilenetv2'])
|
52 |
+
parser.add_argument('--model-backbone-scale', type=float, default=0.25)
|
53 |
+
parser.add_argument('--model-refine-mode', type=str, default='sampling', choices=['full', 'sampling', 'thresholding'])
|
54 |
+
parser.add_argument('--model-refine-sample-pixels', type=int, default=80_000)
|
55 |
+
parser.add_argument('--model-refine-thresholding', type=float, default=0.7)
|
56 |
+
parser.add_argument('--model-refine-kernel-size', type=int, default=3, choices=[1, 3])
|
57 |
+
parser.add_argument('--model-name', type=str, required=True)
|
58 |
+
parser.add_argument('--model-last-checkpoint', type=str, default=None)
|
59 |
+
|
60 |
+
parser.add_argument('--batch-size', type=int, default=4)
|
61 |
+
parser.add_argument('--num-workers', type=int, default=16)
|
62 |
+
parser.add_argument('--epoch-start', type=int, default=0)
|
63 |
+
parser.add_argument('--epoch-end', type=int, required=True)
|
64 |
+
|
65 |
+
parser.add_argument('--log-train-loss-interval', type=int, default=10)
|
66 |
+
parser.add_argument('--log-train-images-interval', type=int, default=1000)
|
67 |
+
parser.add_argument('--log-valid-interval', type=int, default=2000)
|
68 |
+
|
69 |
+
parser.add_argument('--checkpoint-interval', type=int, default=2000)
|
70 |
+
|
71 |
+
args = parser.parse_args()
|
72 |
+
|
73 |
+
|
74 |
+
distributed_num_gpus = torch.cuda.device_count()
|
75 |
+
assert args.batch_size % distributed_num_gpus == 0
|
76 |
+
|
77 |
+
|
78 |
+
# --------------- Main ---------------
|
79 |
+
|
80 |
+
def train_worker(rank, addr, port):
|
81 |
+
|
82 |
+
# Distributed Setup
|
83 |
+
os.environ['MASTER_ADDR'] = addr
|
84 |
+
os.environ['MASTER_PORT'] = port
|
85 |
+
dist.init_process_group("nccl", rank=rank, world_size=distributed_num_gpus)
|
86 |
+
|
87 |
+
# Training DataLoader
|
88 |
+
dataset_train = ZipDataset([
|
89 |
+
ZipDataset([
|
90 |
+
ImagesDataset(DATA_PATH[args.dataset_name]['train']['pha'], mode='L'),
|
91 |
+
ImagesDataset(DATA_PATH[args.dataset_name]['train']['fgr'], mode='RGB'),
|
92 |
+
], transforms=A.PairCompose([
|
93 |
+
A.PairRandomAffineAndResize((2048, 2048), degrees=(-5, 5), translate=(0.1, 0.1), scale=(0.3, 1), shear=(-5, 5)),
|
94 |
+
A.PairRandomHorizontalFlip(),
|
95 |
+
A.PairRandomBoxBlur(0.1, 5),
|
96 |
+
A.PairRandomSharpen(0.1),
|
97 |
+
A.PairApplyOnlyAtIndices([1], T.ColorJitter(0.15, 0.15, 0.15, 0.05)),
|
98 |
+
A.PairApply(T.ToTensor())
|
99 |
+
]), assert_equal_length=True),
|
100 |
+
ImagesDataset(DATA_PATH['backgrounds']['train'], mode='RGB', transforms=T.Compose([
|
101 |
+
A.RandomAffineAndResize((2048, 2048), degrees=(-5, 5), translate=(0.1, 0.1), scale=(1, 2), shear=(-5, 5)),
|
102 |
+
T.RandomHorizontalFlip(),
|
103 |
+
A.RandomBoxBlur(0.1, 5),
|
104 |
+
A.RandomSharpen(0.1),
|
105 |
+
T.ColorJitter(0.15, 0.15, 0.15, 0.05),
|
106 |
+
T.ToTensor()
|
107 |
+
])),
|
108 |
+
])
|
109 |
+
dataset_train_len_per_gpu_worker = int(len(dataset_train) / distributed_num_gpus)
|
110 |
+
dataset_train = Subset(dataset_train, range(rank * dataset_train_len_per_gpu_worker, (rank + 1) * dataset_train_len_per_gpu_worker))
|
111 |
+
dataloader_train = DataLoader(dataset_train,
|
112 |
+
shuffle=True,
|
113 |
+
pin_memory=True,
|
114 |
+
drop_last=True,
|
115 |
+
batch_size=args.batch_size // distributed_num_gpus,
|
116 |
+
num_workers=args.num_workers // distributed_num_gpus)
|
117 |
+
|
118 |
+
# Validation DataLoader
|
119 |
+
if rank == 0:
|
120 |
+
dataset_valid = ZipDataset([
|
121 |
+
ZipDataset([
|
122 |
+
ImagesDataset(DATA_PATH[args.dataset_name]['valid']['pha'], mode='L'),
|
123 |
+
ImagesDataset(DATA_PATH[args.dataset_name]['valid']['fgr'], mode='RGB')
|
124 |
+
], transforms=A.PairCompose([
|
125 |
+
A.PairRandomAffineAndResize((2048, 2048), degrees=(-5, 5), translate=(0.1, 0.1), scale=(0.3, 1), shear=(-5, 5)),
|
126 |
+
A.PairApply(T.ToTensor())
|
127 |
+
]), assert_equal_length=True),
|
128 |
+
ImagesDataset(DATA_PATH['backgrounds']['valid'], mode='RGB', transforms=T.Compose([
|
129 |
+
A.RandomAffineAndResize((2048, 2048), degrees=(-5, 5), translate=(0.1, 0.1), scale=(1, 1.2), shear=(-5, 5)),
|
130 |
+
T.ToTensor()
|
131 |
+
])),
|
132 |
+
])
|
133 |
+
dataset_valid = SampleDataset(dataset_valid, 50)
|
134 |
+
dataloader_valid = DataLoader(dataset_valid,
|
135 |
+
pin_memory=True,
|
136 |
+
drop_last=True,
|
137 |
+
batch_size=args.batch_size // distributed_num_gpus,
|
138 |
+
num_workers=args.num_workers // distributed_num_gpus)
|
139 |
+
|
140 |
+
# Model
|
141 |
+
model = MattingRefine(args.model_backbone,
|
142 |
+
args.model_backbone_scale,
|
143 |
+
args.model_refine_mode,
|
144 |
+
args.model_refine_sample_pixels,
|
145 |
+
args.model_refine_thresholding,
|
146 |
+
args.model_refine_kernel_size).to(rank)
|
147 |
+
model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
|
148 |
+
model_distributed = nn.parallel.DistributedDataParallel(model, device_ids=[rank])
|
149 |
+
|
150 |
+
if args.model_last_checkpoint is not None:
|
151 |
+
load_matched_state_dict(model, torch.load(args.model_last_checkpoint))
|
152 |
+
|
153 |
+
optimizer = Adam([
|
154 |
+
{'params': model.backbone.parameters(), 'lr': 5e-5},
|
155 |
+
{'params': model.aspp.parameters(), 'lr': 5e-5},
|
156 |
+
{'params': model.decoder.parameters(), 'lr': 1e-4},
|
157 |
+
{'params': model.refiner.parameters(), 'lr': 3e-4},
|
158 |
+
])
|
159 |
+
scaler = GradScaler()
|
160 |
+
|
161 |
+
# Logging and checkpoints
|
162 |
+
if rank == 0:
|
163 |
+
if not os.path.exists(f'checkpoint/{args.model_name}'):
|
164 |
+
os.makedirs(f'checkpoint/{args.model_name}')
|
165 |
+
writer = SummaryWriter(f'log/{args.model_name}')
|
166 |
+
|
167 |
+
# Run loop
|
168 |
+
for epoch in range(args.epoch_start, args.epoch_end):
|
169 |
+
for i, ((true_pha, true_fgr), true_bgr) in enumerate(tqdm(dataloader_train)):
|
170 |
+
step = epoch * len(dataloader_train) + i
|
171 |
+
|
172 |
+
true_pha = true_pha.to(rank, non_blocking=True)
|
173 |
+
true_fgr = true_fgr.to(rank, non_blocking=True)
|
174 |
+
true_bgr = true_bgr.to(rank, non_blocking=True)
|
175 |
+
true_pha, true_fgr, true_bgr = random_crop(true_pha, true_fgr, true_bgr)
|
176 |
+
|
177 |
+
true_src = true_bgr.clone()
|
178 |
+
|
179 |
+
# Augment with shadow
|
180 |
+
aug_shadow_idx = torch.rand(len(true_src)) < 0.3
|
181 |
+
if aug_shadow_idx.any():
|
182 |
+
aug_shadow = true_pha[aug_shadow_idx].mul(0.3 * random.random())
|
183 |
+
aug_shadow = T.RandomAffine(degrees=(-5, 5), translate=(0.2, 0.2), scale=(0.5, 1.5), shear=(-5, 5))(aug_shadow)
|
184 |
+
aug_shadow = kornia.filters.box_blur(aug_shadow, (random.choice(range(20, 40)),) * 2)
|
185 |
+
true_src[aug_shadow_idx] = true_src[aug_shadow_idx].sub_(aug_shadow).clamp_(0, 1)
|
186 |
+
del aug_shadow
|
187 |
+
del aug_shadow_idx
|
188 |
+
|
189 |
+
# Composite foreground onto source
|
190 |
+
true_src = true_fgr * true_pha + true_src * (1 - true_pha)
|
191 |
+
|
192 |
+
# Augment with noise
|
193 |
+
aug_noise_idx = torch.rand(len(true_src)) < 0.4
|
194 |
+
if aug_noise_idx.any():
|
195 |
+
true_src[aug_noise_idx] = true_src[aug_noise_idx].add_(torch.randn_like(true_src[aug_noise_idx]).mul_(0.03 * random.random())).clamp_(0, 1)
|
196 |
+
true_bgr[aug_noise_idx] = true_bgr[aug_noise_idx].add_(torch.randn_like(true_bgr[aug_noise_idx]).mul_(0.03 * random.random())).clamp_(0, 1)
|
197 |
+
del aug_noise_idx
|
198 |
+
|
199 |
+
# Augment background with jitter
|
200 |
+
aug_jitter_idx = torch.rand(len(true_src)) < 0.8
|
201 |
+
if aug_jitter_idx.any():
|
202 |
+
true_bgr[aug_jitter_idx] = kornia.augmentation.ColorJitter(0.18, 0.18, 0.18, 0.1)(true_bgr[aug_jitter_idx])
|
203 |
+
del aug_jitter_idx
|
204 |
+
|
205 |
+
# Augment background with affine
|
206 |
+
aug_affine_idx = torch.rand(len(true_bgr)) < 0.3
|
207 |
+
if aug_affine_idx.any():
|
208 |
+
true_bgr[aug_affine_idx] = T.RandomAffine(degrees=(-1, 1), translate=(0.01, 0.01))(true_bgr[aug_affine_idx])
|
209 |
+
del aug_affine_idx
|
210 |
+
|
211 |
+
with autocast():
|
212 |
+
pred_pha, pred_fgr, pred_pha_sm, pred_fgr_sm, pred_err_sm, _ = model_distributed(true_src, true_bgr)
|
213 |
+
loss = compute_loss(pred_pha, pred_fgr, pred_pha_sm, pred_fgr_sm, pred_err_sm, true_pha, true_fgr)
|
214 |
+
|
215 |
+
scaler.scale(loss).backward()
|
216 |
+
scaler.step(optimizer)
|
217 |
+
scaler.update()
|
218 |
+
optimizer.zero_grad()
|
219 |
+
|
220 |
+
if rank == 0:
|
221 |
+
if (i + 1) % args.log_train_loss_interval == 0:
|
222 |
+
writer.add_scalar('loss', loss, step)
|
223 |
+
|
224 |
+
if (i + 1) % args.log_train_images_interval == 0:
|
225 |
+
writer.add_image('train_pred_pha', make_grid(pred_pha, nrow=5), step)
|
226 |
+
writer.add_image('train_pred_fgr', make_grid(pred_fgr, nrow=5), step)
|
227 |
+
writer.add_image('train_pred_com', make_grid(pred_fgr * pred_pha, nrow=5), step)
|
228 |
+
writer.add_image('train_pred_err', make_grid(pred_err_sm, nrow=5), step)
|
229 |
+
writer.add_image('train_true_src', make_grid(true_src, nrow=5), step)
|
230 |
+
|
231 |
+
del true_pha, true_fgr, true_src, true_bgr
|
232 |
+
del pred_pha, pred_fgr, pred_pha_sm, pred_fgr_sm, pred_err_sm
|
233 |
+
|
234 |
+
if (i + 1) % args.log_valid_interval == 0:
|
235 |
+
valid(model, dataloader_valid, writer, step)
|
236 |
+
|
237 |
+
if (step + 1) % args.checkpoint_interval == 0:
|
238 |
+
torch.save(model.state_dict(), f'checkpoint/{args.model_name}/epoch-{epoch}-iter-{step}.pth')
|
239 |
+
|
240 |
+
if rank == 0:
|
241 |
+
torch.save(model.state_dict(), f'checkpoint/{args.model_name}/epoch-{epoch}.pth')
|
242 |
+
|
243 |
+
# Clean up
|
244 |
+
dist.destroy_process_group()
|
245 |
+
|
246 |
+
|
247 |
+
# --------------- Utils ---------------
|
248 |
+
|
249 |
+
|
250 |
+
def compute_loss(pred_pha_lg, pred_fgr_lg, pred_pha_sm, pred_fgr_sm, pred_err_sm, true_pha_lg, true_fgr_lg):
|
251 |
+
true_pha_sm = kornia.resize(true_pha_lg, pred_pha_sm.shape[2:])
|
252 |
+
true_fgr_sm = kornia.resize(true_fgr_lg, pred_fgr_sm.shape[2:])
|
253 |
+
true_msk_lg = true_pha_lg != 0
|
254 |
+
true_msk_sm = true_pha_sm != 0
|
255 |
+
return F.l1_loss(pred_pha_lg, true_pha_lg) + \
|
256 |
+
F.l1_loss(pred_pha_sm, true_pha_sm) + \
|
257 |
+
F.l1_loss(kornia.sobel(pred_pha_lg), kornia.sobel(true_pha_lg)) + \
|
258 |
+
F.l1_loss(kornia.sobel(pred_pha_sm), kornia.sobel(true_pha_sm)) + \
|
259 |
+
F.l1_loss(pred_fgr_lg * true_msk_lg, true_fgr_lg * true_msk_lg) + \
|
260 |
+
F.l1_loss(pred_fgr_sm * true_msk_sm, true_fgr_sm * true_msk_sm) + \
|
261 |
+
F.mse_loss(kornia.resize(pred_err_sm, true_pha_lg.shape[2:]), \
|
262 |
+
kornia.resize(pred_pha_sm, true_pha_lg.shape[2:]).sub(true_pha_lg).abs())
|
263 |
+
|
264 |
+
|
265 |
+
def random_crop(*imgs):
|
266 |
+
H_src, W_src = imgs[0].shape[2:]
|
267 |
+
W_tgt = random.choice(range(1024, 2048)) // 4 * 4
|
268 |
+
H_tgt = random.choice(range(1024, 2048)) // 4 * 4
|
269 |
+
scale = max(W_tgt / W_src, H_tgt / H_src)
|
270 |
+
results = []
|
271 |
+
for img in imgs:
|
272 |
+
img = kornia.resize(img, (int(H_src * scale), int(W_src * scale)))
|
273 |
+
img = kornia.center_crop(img, (H_tgt, W_tgt))
|
274 |
+
results.append(img)
|
275 |
+
return results
|
276 |
+
|
277 |
+
|
278 |
+
def valid(model, dataloader, writer, step):
|
279 |
+
model.eval()
|
280 |
+
loss_total = 0
|
281 |
+
loss_count = 0
|
282 |
+
with torch.no_grad():
|
283 |
+
for (true_pha, true_fgr), true_bgr in dataloader:
|
284 |
+
batch_size = true_pha.size(0)
|
285 |
+
|
286 |
+
true_pha = true_pha.cuda(non_blocking=True)
|
287 |
+
true_fgr = true_fgr.cuda(non_blocking=True)
|
288 |
+
true_bgr = true_bgr.cuda(non_blocking=True)
|
289 |
+
true_src = true_pha * true_fgr + (1 - true_pha) * true_bgr
|
290 |
+
|
291 |
+
pred_pha, pred_fgr, pred_pha_sm, pred_fgr_sm, pred_err_sm, _ = model(true_src, true_bgr)
|
292 |
+
loss = compute_loss(pred_pha, pred_fgr, pred_pha_sm, pred_fgr_sm, pred_err_sm, true_pha, true_fgr)
|
293 |
+
loss_total += loss.cpu().item() * batch_size
|
294 |
+
loss_count += batch_size
|
295 |
+
|
296 |
+
writer.add_scalar('valid_loss', loss_total / loss_count, step)
|
297 |
+
model.train()
|
298 |
+
|
299 |
+
|
300 |
+
# --------------- Start ---------------
|
301 |
+
|
302 |
+
|
303 |
+
if __name__ == '__main__':
|
304 |
+
addr = 'localhost'
|
305 |
+
port = str(random.choice(range(12300, 12400))) # pick a random port.
|
306 |
+
mp.spawn(train_worker,
|
307 |
+
nprocs=distributed_num_gpus,
|
308 |
+
args=(addr, port),
|
309 |
+
join=True)
|
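One detail worth noting in random_crop above: the target size is snapped down to a multiple of 4, so the crops always satisfy the divisible-by-4 assertion in MattingRefine.forward. A tiny sketch of that rule:

import random

W_tgt = random.choice(range(1024, 2048)) // 4 * 4  # e.g. 1234 -> 1232
H_tgt = random.choice(range(1024, 2048)) // 4 * 4
assert W_tgt % 4 == 0 and H_tgt % 4 == 0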