Fazhong Liu committed on
Commit
854728f
1 Parent(s): a1db54d
.gitattributes CHANGED
@@ -1,35 +1,35 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
1
+ # *.7z filter=lfs diff=lfs merge=lfs -text
2
+ # *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ # *.bin filter=lfs diff=lfs merge=lfs -text
4
+ # *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ # *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ # *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ # *.gz filter=lfs diff=lfs merge=lfs -text
8
+ # *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ # *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ # *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ # *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ # *.model filter=lfs diff=lfs merge=lfs -text
13
+ # *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ # *.npy filter=lfs diff=lfs merge=lfs -text
15
+ # *.npz filter=lfs diff=lfs merge=lfs -text
16
+ # *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ # *.ot filter=lfs diff=lfs merge=lfs -text
18
+ # *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ # *.pb filter=lfs diff=lfs merge=lfs -text
20
+ # *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ # *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ # *.pt filter=lfs diff=lfs merge=lfs -text
23
+ # *.pth filter=lfs diff=lfs merge=lfs -text
24
+ # *.rar filter=lfs diff=lfs merge=lfs -text
25
+ # *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ # saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ # *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ # *.tar filter=lfs diff=lfs merge=lfs -text
29
+ # *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ # *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ # *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ # *.xz filter=lfs diff=lfs merge=lfs -text
33
+ # *.zip filter=lfs diff=lfs merge=lfs -text
34
+ # *.zst filter=lfs diff=lfs merge=lfs -text
35
+ # *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,3 @@
1
+ *.pth
2
+ *.png
3
+ *.mp4
README.md DELETED
@@ -1,13 +0,0 @@
1
- ---
2
- title: VideoMatting
3
- emoji: 🏆
4
- colorFrom: pink
5
- colorTo: green
6
- sdk: gradio
7
- sdk_version: 4.24.0
8
- app_file: app.py
9
- pinned: false
10
- license: apache-2.0
11
- ---
12
-
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
__pycache__/inference_utils.cpython-38.pyc ADDED
Binary file (1.79 kB).
 
data_path.py ADDED
@@ -0,0 +1,68 @@
1
+ """
2
+ This file records the directory paths to the different datasets.
3
+ You will need to configure it for training the model.
4
+
5
+ All datasets follow the format below, where 'fgr' and 'pha' point to directories that contain jpg or png images.
6
+ The directories may use any nested structure, but the 'fgr' and 'pha' structures must match. You can add your own
7
+ dataset to the list as long as it follows the format. 'fgr' should point to foreground images with RGB channels,
8
+ 'pha' should point to alpha mattes with a single grayscale channel.
9
+ {
10
+ 'YOUR_DATASET': {
11
+ 'train': {
12
+ 'fgr': 'PATH_TO_IMAGES_DIR',
13
+ 'pha': 'PATH_TO_IMAGES_DIR',
14
+ },
15
+ 'valid': {
16
+ 'fgr': 'PATH_TO_IMAGES_DIR',
17
+ 'pha': 'PATH_TO_IMAGES_DIR',
18
+ }
19
+ }
20
+ }
21
+ """
22
+
23
+ DATA_PATH = {
24
+ 'videomatte240k': {
25
+ 'train': {
26
+ 'fgr': 'PATH_TO_IMAGES_DIR',
27
+ 'pha': 'PATH_TO_IMAGES_DIR'
28
+ },
29
+ 'valid': {
30
+ 'fgr': 'PATH_TO_IMAGES_DIR',
31
+ 'pha': 'PATH_TO_IMAGES_DIR'
32
+ }
33
+ },
34
+ 'photomatte13k': {
35
+ 'train': {
36
+ 'fgr': 'PATH_TO_IMAGES_DIR',
37
+ 'pha': 'PATH_TO_IMAGES_DIR'
38
+ },
39
+ 'valid': {
40
+ 'fgr': 'PATH_TO_IMAGES_DIR',
41
+ 'pha': 'PATH_TO_IMAGES_DIR'
42
+ }
43
+ },
44
+ 'distinction': {
45
+ 'train': {
46
+ 'fgr': 'PATH_TO_IMAGES_DIR',
47
+ 'pha': 'PATH_TO_IMAGES_DIR',
48
+ },
49
+ 'valid': {
50
+ 'fgr': 'PATH_TO_IMAGES_DIR',
51
+ 'pha': 'PATH_TO_IMAGES_DIR'
52
+ },
53
+ },
54
+ 'adobe': {
55
+ 'train': {
56
+ 'fgr': 'PATH_TO_IMAGES_DIR',
57
+ 'pha': 'PATH_TO_IMAGES_DIR',
58
+ },
59
+ 'valid': {
60
+ 'fgr': 'PATH_TO_IMAGES_DIR',
61
+ 'pha': 'PATH_TO_IMAGES_DIR'
62
+ },
63
+ },
64
+ 'backgrounds': {
65
+ 'train': 'PATH_TO_IMAGES_DIR',
66
+ 'valid': 'PATH_TO_IMAGES_DIR'
67
+ },
68
+ }
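
A minimal sketch of how a training script might consume DATA_PATH, assuming the placeholder paths above have been filled in and that the ImagesDataset and ZipDataset classes added later in this commit are importable:

    from data_path import DATA_PATH
    from dataset import ImagesDataset, ZipDataset

    paths = DATA_PATH['videomatte240k']['train']
    # 'fgr' and 'pha' directories must mirror each other so the sorted file lists align.
    train_set = ZipDataset([
        ImagesDataset(paths['fgr'], mode='RGB'),
        ImagesDataset(paths['pha'], mode='L'),
    ], assert_equal_length=True)
    fgr, pha = train_set[0]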
dataset/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ from .images import ImagesDataset
2
+ from .video import VideoDataset
3
+ from .sample import SampleDataset
4
+ from .zip import ZipDataset
dataset/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (320 Bytes).
 
dataset/__pycache__/augmentation.cpython-38.pyc ADDED
Binary file (7.05 kB).
 
dataset/__pycache__/images.cpython-38.pyc ADDED
Binary file (1.12 kB).
 
dataset/__pycache__/sample.cpython-38.pyc ADDED
Binary file (1.05 kB).
 
dataset/__pycache__/video.cpython-38.pyc ADDED
Binary file (1.94 kB).
 
dataset/__pycache__/zip.cpython-38.pyc ADDED
Binary file (1.38 kB).
 
dataset/augmentation.py ADDED
@@ -0,0 +1,141 @@
1
+ import random
2
+ import torch
3
+ import numpy as np
4
+ import math
5
+ from torchvision import transforms as T
6
+ from torchvision.transforms import functional as F
7
+ from PIL import Image, ImageFilter
8
+
9
+ """
10
+ Pair transforms are modified versions of the regular transforms that take in multiple images
11
+ and apply the exact same transform to all of them. This is especially useful when we want the same
12
+ transforms on a pair of images.
13
+
14
+ Example:
15
+ img1, img2, ..., imgN = transforms(img1, img2, ..., imgN)
16
+ """
17
+
18
+ class PairCompose(T.Compose):
19
+ def __call__(self, *x):
20
+ for transform in self.transforms:
21
+ x = transform(*x)
22
+ return x
23
+
24
+
25
+ class PairApply:
26
+ def __init__(self, transforms):
27
+ self.transforms = transforms
28
+
29
+ def __call__(self, *x):
30
+ return [self.transforms(xi) for xi in x]
31
+
32
+
33
+ class PairApplyOnlyAtIndices:
34
+ def __init__(self, indices, transforms):
35
+ self.indices = indices
36
+ self.transforms = transforms
37
+
38
+ def __call__(self, *x):
39
+ return [self.transforms(xi) if i in self.indices else xi for i, xi in enumerate(x)]
40
+
41
+
42
+ class PairRandomAffine(T.RandomAffine):
43
+ def __init__(self, degrees, translate=None, scale=None, shear=None, resamples=None, fillcolor=0):
44
+ super().__init__(degrees, translate, scale, shear, Image.NEAREST, fillcolor)
45
+ self.resamples = resamples
46
+
47
+ def __call__(self, *x):
48
+ if not len(x):
49
+ return []
50
+ param = self.get_params(self.degrees, self.translate, self.scale, self.shear, x[0].size)
51
+ resamples = self.resamples or [self.resample] * len(x)
52
+ return [F.affine(xi, *param, resamples[i], self.fillcolor) for i, xi in enumerate(x)]
53
+
54
+
55
+ class PairRandomHorizontalFlip(T.RandomHorizontalFlip):
56
+ def __call__(self, *x):
57
+ if torch.rand(1) < self.p:
58
+ x = [F.hflip(xi) for xi in x]
59
+ return x
60
+
61
+
62
+ class RandomBoxBlur:
63
+ def __init__(self, prob, max_radius):
64
+ self.prob = prob
65
+ self.max_radius = max_radius
66
+
67
+ def __call__(self, img):
68
+ if torch.rand(1) < self.prob:
69
+ fil = ImageFilter.BoxBlur(random.choice(range(self.max_radius + 1)))
70
+ img = img.filter(fil)
71
+ return img
72
+
73
+
74
+ class PairRandomBoxBlur(RandomBoxBlur):
75
+ def __call__(self, *x):
76
+ if torch.rand(1) < self.prob:
77
+ fil = ImageFilter.BoxBlur(random.choice(range(self.max_radius + 1)))
78
+ x = [xi.filter(fil) for xi in x]
79
+ return x
80
+
81
+
82
+ class RandomSharpen:
83
+ def __init__(self, prob):
84
+ self.prob = prob
85
+ self.filter = ImageFilter.SHARPEN
86
+
87
+ def __call__(self, img):
88
+ if torch.rand(1) < self.prob:
89
+ img = img.filter(self.filter)
90
+ return img
91
+
92
+
93
+ class PairRandomSharpen(RandomSharpen):
94
+ def __call__(self, *x):
95
+ if torch.rand(1) < self.prob:
96
+ x = [xi.filter(self.filter) for xi in x]
97
+ return x
98
+
99
+
100
+ class PairRandomAffineAndResize:
101
+ def __init__(self, size, degrees, translate, scale, shear, ratio=(3./4., 4./3.), resample=Image.BILINEAR, fillcolor=0):
102
+ self.size = size
103
+ self.degrees = degrees
104
+ self.translate = translate
105
+ self.scale = scale
106
+ self.shear = shear
107
+ self.ratio = ratio
108
+ self.resample = resample
109
+ self.fillcolor = fillcolor
110
+
111
+ def __call__(self, *x):
112
+ if not len(x):
113
+ return []
114
+
115
+ w, h = x[0].size
116
+ scale_factor = max(self.size[1] / w, self.size[0] / h)
117
+
118
+ w_padded = max(w, self.size[1])
119
+ h_padded = max(h, self.size[0])
120
+
121
+ pad_h = int(math.ceil((h_padded - h) / 2))
122
+ pad_w = int(math.ceil((w_padded - w) / 2))
123
+
124
+ scale = self.scale[0] * scale_factor, self.scale[1] * scale_factor
125
+ translate = self.translate[0] * scale_factor, self.translate[1] * scale_factor
126
+ affine_params = T.RandomAffine.get_params(self.degrees, translate, scale, self.shear, (w, h))
127
+
128
+ def transform(img):
129
+ if pad_h > 0 or pad_w > 0:
130
+ img = F.pad(img, (pad_w, pad_h))
131
+
132
+ img = F.affine(img, *affine_params, self.resample, self.fillcolor)
133
+ img = F.center_crop(img, self.size)
134
+ return img
135
+
136
+ return [transform(xi) for xi in x]
137
+
138
+
139
+ class RandomAffineAndResize(PairRandomAffineAndResize):
140
+ def __call__(self, img):
141
+ return super().__call__(img)[0]
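
A short usage sketch of the pair transforms above (hypothetical parameter values; fgr_img and pha_img stand for a PIL foreground/alpha pair), showing that both images receive the exact same random augmentation:

    from torchvision import transforms as T
    from dataset import augmentation as A

    pair_transform = A.PairCompose([
        A.PairRandomAffineAndResize((512, 512), degrees=(-5, 5), translate=(0.1, 0.1),
                                    scale=(0.4, 1.0), shear=(-5, 5)),
        A.PairRandomHorizontalFlip(),
        A.PairApply(T.ToTensor()),
    ])
    fgr, pha = pair_transform(fgr_img, pha_img)  # PIL images in, tensors out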
dataset/images.py ADDED
@@ -0,0 +1,23 @@
1
+ import os
2
+ import glob
3
+ from torch.utils.data import Dataset
4
+ from PIL import Image
5
+
6
+ class ImagesDataset(Dataset):
7
+ def __init__(self, root, mode='RGB', transforms=None):
8
+ self.transforms = transforms
9
+ self.mode = mode
10
+ self.filenames = sorted([*glob.glob(os.path.join(root, '**', '*.jpg'), recursive=True),
11
+ *glob.glob(os.path.join(root, '**', '*.png'), recursive=True)])
12
+
13
+ def __len__(self):
14
+ return len(self.filenames)
15
+
16
+ def __getitem__(self, idx):
17
+ with Image.open(self.filenames[idx]) as img:
18
+ img = img.convert(self.mode)
19
+
20
+ if self.transforms:
21
+ img = self.transforms(img)
22
+
23
+ return img
dataset/sample.py ADDED
@@ -0,0 +1,14 @@
1
+ from torch.utils.data import Dataset
2
+
3
+
4
+ class SampleDataset(Dataset):
5
+ def __init__(self, dataset, samples):
6
+ samples = min(samples, len(dataset))
7
+ self.dataset = dataset
8
+ self.indices = [i * int(len(dataset) / samples) for i in range(samples)]
9
+
10
+ def __len__(self):
11
+ return len(self.indices)
12
+
13
+ def __getitem__(self, idx):
14
+ return self.dataset[self.indices[idx]]
dataset/video.py ADDED
@@ -0,0 +1,38 @@
1
+ import cv2
2
+ import numpy as np
3
+ from torch.utils.data import Dataset
4
+ from PIL import Image
5
+
6
+ class VideoDataset(Dataset):
7
+ def __init__(self, path: str, transforms: any = None):
8
+ self.cap = cv2.VideoCapture(path)
9
+ self.transforms = transforms
10
+
11
+ self.width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH))
12
+ self.height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
13
+ self.frame_rate = self.cap.get(cv2.CAP_PROP_FPS)
14
+ self.frame_count = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))
15
+
16
+ def __len__(self):
17
+ return self.frame_count
18
+
19
+ def __getitem__(self, idx):
20
+ if isinstance(idx, slice):
21
+ return [self[i] for i in range(*idx.indices(len(self)))]
22
+
23
+ if self.cap.get(cv2.CAP_PROP_POS_FRAMES) != idx:
24
+ self.cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
25
+ ret, img = self.cap.read()
26
+ if not ret:
27
+ raise IndexError(f'Idx: {idx} out of length: {len(self)}')
28
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
29
+ img = Image.fromarray(img)
30
+ if self.transforms:
31
+ img = self.transforms(img)
32
+ return img
33
+
34
+ def __enter__(self):
35
+ return self
36
+
37
+ def __exit__(self, exc_type, exc_value, exc_traceback):
38
+ self.cap.release()
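
A brief sketch of using VideoDataset as a context manager (hypothetical file name), which releases the underlying OpenCV capture on exit:

    from torchvision import transforms as T
    from dataset import VideoDataset

    with VideoDataset('clip.mp4', transforms=T.ToTensor()) as vid:
        print(len(vid), vid.width, vid.height, vid.frame_rate)
        first_frame = vid[0]     # (3, H, W) float tensor in 0 ~ 1
        first_ten = vid[0:10]    # slicing returns a list of frames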
dataset/zip.py ADDED
@@ -0,0 +1,20 @@
1
+ from torch.utils.data import Dataset
2
+ from typing import List
3
+
4
+ class ZipDataset(Dataset):
5
+ def __init__(self, datasets: List[Dataset], transforms=None, assert_equal_length=False):
6
+ self.datasets = datasets
7
+ self.transforms = transforms
8
+
9
+ if assert_equal_length:
10
+ for i in range(1, len(datasets)):
11
+ assert len(datasets[i]) == len(datasets[i - 1]), 'Datasets are not equal in length.'
12
+
13
+ def __len__(self):
14
+ return max(len(d) for d in self.datasets)
15
+
16
+ def __getitem__(self, idx):
17
+ x = tuple(d[idx % len(d)] for d in self.datasets)
18
+ if self.transforms:
19
+ x = self.transforms(*x)
20
+ return x
export_onnx.py ADDED
@@ -0,0 +1,155 @@
1
+ """
2
+ Export MattingRefine as ONNX format.
3
+ You need to install onnxruntime via `pip install onnxruntime`.
4
+
5
+ Example:
6
+
7
+ python export_onnx.py \
8
+ --model-type mattingrefine \
9
+ --model-checkpoint "PATH_TO_MODEL_CHECKPOINT" \
10
+ --model-backbone resnet50 \
11
+ --model-backbone-scale 0.25 \
12
+ --model-refine-mode sampling \
13
+ --model-refine-sample-pixels 80000 \
14
+ --model-refine-patch-crop-method roi_align \
15
+ --model-refine-patch-replace-method scatter_element \
16
+ --onnx-opset-version 11 \
17
+ --onnx-constant-folding \
18
+ --precision float32 \
19
+ --output "model.onnx" \
20
+ --validate
21
+
22
+ Compatibility:
23
+
24
+ Our network uses a novel architecture that involves cropping and replacing patches
25
+ of an image. This may cause compatibility issues with different inference backends.
26
+ Therefore, we offer different methods for cropping and replacing patches as
27
+ compatibility options. They all produce the same image output.
28
+
29
+ --model-refine-patch-crop-method:
30
+ Options: ['unfold', 'roi_align', 'gather']
31
+ (unfold is unlikely to work for ONNX, try roi_align or gather)
32
+
33
+ --model-refine-patch-replace-method
34
+ Options: ['scatter_nd', 'scatter_element']
35
+ (scatter_nd should be faster when supported)
36
+
37
+ Also try using thresholding mode if sampling mode is not supported by the inference backend.
38
+
39
+ --model-refine-mode thresholding \
40
+ --model-refine-threshold 0.1 \
41
+
42
+ """
43
+
44
+
45
+ import argparse
46
+ import torch
47
+
48
+ from model import MattingBase, MattingRefine
49
+
50
+
51
+ # --------------- Arguments ---------------
52
+
53
+
54
+ parser = argparse.ArgumentParser(description='Export ONNX')
55
+
56
+ parser.add_argument('--model-type', type=str, required=True, choices=['mattingbase', 'mattingrefine'])
57
+ parser.add_argument('--model-backbone', type=str, required=True, choices=['resnet101', 'resnet50', 'mobilenetv2'])
58
+ parser.add_argument('--model-backbone-scale', type=float, default=0.25)
59
+ parser.add_argument('--model-checkpoint', type=str, required=True)
60
+ parser.add_argument('--model-refine-mode', type=str, default='sampling', choices=['full', 'sampling', 'thresholding'])
61
+ parser.add_argument('--model-refine-sample-pixels', type=int, default=80_000)
62
+ parser.add_argument('--model-refine-threshold', type=float, default=0.1)
63
+ parser.add_argument('--model-refine-kernel-size', type=int, default=3)
64
+ parser.add_argument('--model-refine-patch-crop-method', type=str, default='roi_align', choices=['unfold', 'roi_align', 'gather'])
65
+ parser.add_argument('--model-refine-patch-replace-method', type=str, default='scatter_element', choices=['scatter_nd', 'scatter_element'])
66
+
67
+ parser.add_argument('--onnx-verbose', type=bool, default=True)
68
+ parser.add_argument('--onnx-opset-version', type=int, default=12)
69
+ parser.add_argument('--onnx-constant-folding', default=True, action='store_true')
70
+
71
+ parser.add_argument('--device', type=str, default='cpu')
72
+ parser.add_argument('--precision', type=str, default='float32', choices=['float32', 'float16'])
73
+ parser.add_argument('--validate', action='store_true')
74
+ parser.add_argument('--output', type=str, required=True)
75
+
76
+ args = parser.parse_args()
77
+
78
+
79
+ # --------------- Main ---------------
80
+
81
+
82
+ # Load model
83
+ if args.model_type == 'mattingbase':
84
+ model = MattingBase(args.model_backbone)
85
+ if args.model_type == 'mattingrefine':
86
+ model = MattingRefine(
87
+ backbone=args.model_backbone,
88
+ backbone_scale=args.model_backbone_scale,
89
+ refine_mode=args.model_refine_mode,
90
+ refine_sample_pixels=args.model_refine_sample_pixels,
91
+ refine_threshold=args.model_refine_threshold,
92
+ refine_kernel_size=args.model_refine_kernel_size,
93
+ refine_patch_crop_method=args.model_refine_patch_crop_method,
94
+ refine_patch_replace_method=args.model_refine_patch_replace_method)
95
+
96
+ model.load_state_dict(torch.load(args.model_checkpoint, map_location=args.device), strict=False)
97
+ precision = {'float32': torch.float32, 'float16': torch.float16}[args.precision]
98
+ model.eval().to(precision).to(args.device)
99
+
100
+ # Dummy Inputs
101
+ src = torch.randn(2, 3, 1080, 1920).to(precision).to(args.device)
102
+ bgr = torch.randn(2, 3, 1080, 1920).to(precision).to(args.device)
103
+
104
+ # Export ONNX
105
+ if args.model_type == 'mattingbase':
106
+ input_names=['src', 'bgr']
107
+ output_names = ['pha', 'fgr', 'err', 'hid']
108
+ if args.model_type == 'mattingrefine':
109
+ input_names=['src', 'bgr']
110
+ output_names = ['pha', 'fgr', 'pha_sm', 'fgr_sm', 'err_sm', 'ref_sm']
111
+
112
+ torch.onnx.export(
113
+ model=model,
114
+ args=(src, bgr),
115
+ f=args.output,
116
+ verbose=args.onnx_verbose,
117
+ opset_version=args.onnx_opset_version,
118
+ do_constant_folding=args.onnx_constant_folding,
119
+ input_names=input_names,
120
+ output_names=output_names,
121
+ dynamic_axes={name: {0: 'batch', 2: 'height', 3: 'width'} for name in [*input_names, *output_names]})
122
+
123
+ print(f'ONNX model saved at: {args.output}')
124
+
125
+ # Validation
126
+ if args.validate:
127
+ import onnxruntime
128
+ import numpy as np
129
+
130
+ print(f'Validating ONNX model.')
131
+
132
+ # Test with different inputs.
133
+ src = torch.randn(1, 3, 720, 1280).to(precision).to(args.device)
134
+ bgr = torch.randn(1, 3, 720, 1280).to(precision).to(args.device)
135
+
136
+ with torch.no_grad():
137
+ out_torch = model(src, bgr)
138
+
139
+ sess = onnxruntime.InferenceSession(args.output)
140
+ out_onnx = sess.run(None, {
141
+ 'src': src.cpu().numpy(),
142
+ 'bgr': bgr.cpu().numpy()
143
+ })
144
+
145
+ e_max = 0
146
+ for a, b, name in zip(out_torch, out_onnx, output_names):
147
+ b = torch.as_tensor(b)
148
+ e = torch.abs(a.cpu() - b).max()
149
+ e_max = max(e_max, e.item())
150
+ print(f'"{name}" output differs by maximum of {e}')
151
+
152
+ if e_max < 0.005:
153
+ print('Validation passed.')
154
+ else:
155
+ raise RuntimeError('Validation failed.')
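
A hedged sketch of running the exported model with onnxruntime (dummy float32 inputs; the output names match those declared above for the mattingrefine export):

    import numpy as np
    import onnxruntime

    sess = onnxruntime.InferenceSession('model.onnx')         # path given to --output
    src = np.random.rand(1, 3, 720, 1280).astype(np.float32)  # source frame, RGB 0 ~ 1
    bgr = np.random.rand(1, 3, 720, 1280).astype(np.float32)  # background plate
    pha, fgr = sess.run(['pha', 'fgr'], {'src': src, 'bgr': bgr})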
export_torchscript.py ADDED
@@ -0,0 +1,83 @@
1
+ """
2
+ Export TorchScript
3
+
4
+ python export_torchscript.py \
5
+ --model-backbone resnet50 \
6
+ --model-checkpoint "PATH_TO_CHECKPOINT" \
7
+ --precision float32 \
8
+ --output "torchscript.pth"
9
+ """
10
+
11
+ import argparse
12
+ import torch
13
+ from torch import nn
14
+ from model import MattingRefine
15
+
16
+
17
+ # --------------- Arguments ---------------
18
+
19
+
20
+ parser = argparse.ArgumentParser(description='Export TorchScript')
21
+
22
+ parser.add_argument('--model-backbone', type=str, required=True, choices=['resnet101', 'resnet50', 'mobilenetv2'])
23
+ parser.add_argument('--model-checkpoint', type=str, required=True)
24
+ parser.add_argument('--precision', type=str, default='float32', choices=['float32', 'float16'])
25
+ parser.add_argument('--output', type=str, required=True)
26
+
27
+ args = parser.parse_args()
28
+
29
+
30
+ # --------------- Utils ---------------
31
+
32
+
33
+ class MattingRefine_TorchScriptWrapper(nn.Module):
34
+ """
35
+ The purpose of this wrapper is to hoist all the configurable attributes to the top level,
36
+ so that the user can easily change them after loading the saved TorchScript model.
37
+
38
+ Example:
39
+ model = torch.jit.load('torchscript.pth')
40
+ model.backbone_scale = 0.25
41
+ model.refine_mode = 'sampling'
42
+ model.refine_sample_pixels = 80_000
43
+ pha, fgr = model(src, bgr)[:2]
44
+ """
45
+
46
+ def __init__(self, *args, **kwargs):
47
+ super().__init__()
48
+ self.model = MattingRefine(*args, **kwargs)
49
+
50
+ # Hoist the attributes to the top level.
51
+ self.backbone_scale = self.model.backbone_scale
52
+ self.refine_mode = self.model.refiner.mode
53
+ self.refine_sample_pixels = self.model.refiner.sample_pixels
54
+ self.refine_threshold = self.model.refiner.threshold
55
+ self.refine_prevent_oversampling = self.model.refiner.prevent_oversampling
56
+
57
+ def forward(self, src, bgr):
58
+ # Reset the attributes.
59
+ self.model.backbone_scale = self.backbone_scale
60
+ self.model.refiner.mode = self.refine_mode
61
+ self.model.refiner.sample_pixels = self.refine_sample_pixels
62
+ self.model.refiner.threshold = self.refine_threshold
63
+ self.model.refiner.prevent_oversampling = self.refine_prevent_oversampling
64
+
65
+ return self.model(src, bgr)
66
+
67
+ def load_state_dict(self, *args, **kwargs):
68
+ return self.model.load_state_dict(*args, **kwargs)
69
+
70
+
71
+ # --------------- Main ---------------
72
+
73
+
74
+ model = MattingRefine_TorchScriptWrapper(args.model_backbone).eval()
75
+ model.load_state_dict(torch.load(args.model_checkpoint, map_location='cpu'))
76
+ for p in model.parameters():
77
+ p.requires_grad = False
78
+
79
+ if args.precision == 'float16':
80
+ model = model.half()
81
+
82
+ model = torch.jit.script(model)
83
+ model.save(args.output)
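
A minimal sketch of loading the exported TorchScript model, mirroring the example in the wrapper's docstring (dummy inputs; real frames should be RGB tensors normalized to 0 ~ 1):

    import torch

    model = torch.jit.load('torchscript.pth').eval()
    model.backbone_scale = 0.25
    model.refine_mode = 'sampling'
    model.refine_sample_pixels = 80_000

    src = torch.rand(1, 3, 1080, 1920)
    bgr = torch.rand(1, 3, 1080, 1920)
    with torch.no_grad():
        pha, fgr = model(src, bgr)[:2]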
inference_utils.py ADDED
@@ -0,0 +1,46 @@
1
+ import numpy as np
2
+ import cv2
3
+ from PIL import Image
4
+
5
+
6
+ class HomographicAlignment:
7
+ """
8
+ Apply homographic alignment to the background so that it matches the source image.
9
+ """
10
+
11
+ def __init__(self):
12
+ self.detector = cv2.ORB_create()
13
+ self.matcher = cv2.DescriptorMatcher_create(cv2.DESCRIPTOR_MATCHER_BRUTEFORCE)
14
+
15
+ def __call__(self, src, bgr):
16
+ src = np.asarray(src)
17
+ bgr = np.asarray(bgr)
18
+
19
+ keypoints_src, descriptors_src = self.detector.detectAndCompute(src, None)
20
+ keypoints_bgr, descriptors_bgr = self.detector.detectAndCompute(bgr, None)
21
+
22
+ matches = self.matcher.match(descriptors_bgr, descriptors_src, None)
23
+ matches.sort(key=lambda x: x.distance, reverse=False)
24
+ num_good_matches = int(len(matches) * 0.15)
25
+ matches = matches[:num_good_matches]
26
+
27
+ points_src = np.zeros((len(matches), 2), dtype=np.float32)
28
+ points_bgr = np.zeros((len(matches), 2), dtype=np.float32)
29
+ for i, match in enumerate(matches):
30
+ points_src[i, :] = keypoints_src[match.trainIdx].pt
31
+ points_bgr[i, :] = keypoints_bgr[match.queryIdx].pt
32
+
33
+ H, _ = cv2.findHomography(points_bgr, points_src, cv2.RANSAC)
34
+
35
+ h, w = src.shape[:2]
36
+ bgr = cv2.warpPerspective(bgr, H, (w, h))
37
+ msk = cv2.warpPerspective(np.ones((h, w)), H, (w, h))
38
+
39
+ # For areas that fall outside of the background,
40
+ # we just copy pixels from the source.
41
+ bgr[msk != 1] = src[msk != 1]
42
+
43
+ src = Image.fromarray(src)
44
+ bgr = Image.fromarray(bgr)
45
+
46
+ return src, bgr
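
A short sketch of applying HomographicAlignment before matting (hypothetical file names), which warps the background plate onto the source frame when the two shots are slightly misaligned:

    from PIL import Image
    from inference_utils import HomographicAlignment

    align = HomographicAlignment()
    src = Image.open('src.png').convert('RGB')
    bgr = Image.open('bgr.png').convert('RGB')
    src, bgr = align(src, bgr)   # bgr is now registered to src; both returned as PIL images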
inference_video.py ADDED
@@ -0,0 +1,245 @@
1
+ import argparse
2
+ import cv2
3
+ import torch
4
+ import os
5
+ import shutil
6
+
7
+ from torch import nn
8
+ from torch.nn import functional as F
9
+ from torch.utils.data import DataLoader
10
+ from torchvision import transforms as T
11
+ from torchvision.transforms.functional import to_pil_image
12
+ from threading import Thread
13
+ from tqdm import tqdm
14
+ from PIL import Image
15
+ import gradio as gr
16
+ from dataset import VideoDataset, ZipDataset
17
+ from dataset import augmentation as A
18
+ from model import MattingBase, MattingRefine
19
+ from inference_utils import HomographicAlignment
20
+
21
+
22
+ # --------------- Arguments ---------------
23
+
24
+
25
+
26
+ # --------------- Utils ---------------
27
+
28
+
29
+ class VideoWriter:
30
+ def __init__(self, path, frame_rate, width, height):
31
+ self.out = cv2.VideoWriter(path, cv2.VideoWriter_fourcc(*'mp4v'), frame_rate, (width, height))
32
+
33
+ def add_batch(self, frames):
34
+ frames = frames.mul(255).byte()
35
+ frames = frames.cpu().permute(0, 2, 3, 1).numpy()
36
+ for i in range(frames.shape[0]):
37
+ frame = frames[i]
38
+ frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
39
+ self.out.write(frame)
40
+
41
+
42
+ class ImageSequenceWriter:
43
+ def __init__(self, path, extension):
44
+ self.path = path
45
+ self.extension = extension
46
+ self.index = 0
47
+ os.makedirs(path)
48
+
49
+ def add_batch(self, frames):
50
+ Thread(target=self._add_batch, args=(frames, self.index)).start()
51
+ self.index += frames.shape[0]
52
+
53
+ def _add_batch(self, frames, index):
54
+ frames = frames.cpu()
55
+ for i in range(frames.shape[0]):
56
+ frame = frames[i]
57
+ frame = to_pil_image(frame)
58
+ frame.save(os.path.join(self.path, str(index + i).zfill(5) + '.' + self.extension))
59
+
60
+
61
+ # --------------- Main ---------------
62
+
63
+ def video_matting(video_src_content, video_bgr_content):
64
+ src_video_path = './source/src_video.mp4'
65
+ bgr_image_path = './source/bgr_image.png'
66
+ with open(src_video_path, 'wb') as video_file:
67
+ video_file.write(video_src_content)
68
+
69
+ # Write the background image file
70
+ with open(bgr_image_path, 'wb') as bgr_file:
71
+ bgr_file.write(video_bgr_content)
72
+ video_src = src_video_path
73
+ video_bgr = bgr_image_path
74
+ default_args = {
75
+ 'model_type': 'mattingrefine',
76
+ 'model_backbone': 'resnet50',
77
+ 'model_backbone_scale': 0.25,
78
+ 'model_refine_mode': 'sampling',
79
+ 'model_refine_sample_pixels': 80000,
80
+ 'model_checkpoint': './pytorch_resnet50.pth',
81
+ 'model_refine_threshold':0.7,
82
+ 'model_refine_kernel_size':3,
83
+ 'video_src': './source/src.mp4',
84
+ 'video_bgr': './source/bgr.png',
85
+ 'video_target_bgr': None,
86
+ 'video_resize': [1920, 1080],
87
+ 'device': 'cpu', # defaults to CPU
88
+ 'preprocess_alignment': False,
89
+ 'output_dir': './output',
90
+ 'output_types': ['com'],
91
+ 'output_format': 'video'
92
+ }
93
+
94
+ args = argparse.Namespace(**default_args)
95
+ device = torch.device(args.device)
96
+
97
+ # Load model
98
+ if args.model_type == 'mattingbase':
99
+ model = MattingBase(args.model_backbone)
100
+ if args.model_type == 'mattingrefine':
101
+ model = MattingRefine(
102
+ args.model_backbone,
103
+ args.model_backbone_scale,
104
+ args.model_refine_mode,
105
+ args.model_refine_sample_pixels,
106
+ args.model_refine_threshold,
107
+ args.model_refine_kernel_size)
108
+
109
+ model = model.to(device).eval()
110
+ model.load_state_dict(torch.load(args.model_checkpoint, map_location=device), strict=False)
111
+
112
+
113
+ # Load video and background
114
+ vid = VideoDataset(video_src)
115
+ bgr = [Image.open(video_bgr).convert('RGB')]
116
+ dataset = ZipDataset([vid, bgr], transforms=A.PairCompose([
117
+ A.PairApply(T.Resize(args.video_resize[::-1]) if args.video_resize else nn.Identity()),
118
+ HomographicAlignment() if args.preprocess_alignment else A.PairApply(nn.Identity()),
119
+ A.PairApply(T.ToTensor())
120
+ ]))
121
+ if args.video_target_bgr:
122
+ dataset = ZipDataset([dataset, VideoDataset(args.video_target_bgr, transforms=T.ToTensor())])
123
+
124
+ # Create output directory
125
+ # if os.path.exists(args.output_dir):
126
+ # if input(f'Directory {args.output_dir} already exists. Override? [Y/N]: ').lower() == 'y':
127
+ # shutil.rmtree(args.output_dir)
128
+ # else:
129
+ # exit()
130
+ # os.makedirs(args.output_dir)
131
+
132
+
133
+ # Prepare writers
134
+ if args.output_format == 'video':
135
+ h = args.video_resize[1] if args.video_resize is not None else vid.height
136
+ w = args.video_resize[0] if args.video_resize is not None else vid.width
137
+ if 'com' in args.output_types:
138
+ com_writer = VideoWriter(os.path.join(args.output_dir, 'com.mp4'), vid.frame_rate, w, h)
139
+ if 'pha' in args.output_types:
140
+ pha_writer = VideoWriter(os.path.join(args.output_dir, 'pha.mp4'), vid.frame_rate, w, h)
141
+ if 'fgr' in args.output_types:
142
+ fgr_writer = VideoWriter(os.path.join(args.output_dir, 'fgr.mp4'), vid.frame_rate, w, h)
143
+ if 'err' in args.output_types:
144
+ err_writer = VideoWriter(os.path.join(args.output_dir, 'err.mp4'), vid.frame_rate, w, h)
145
+ if 'ref' in args.output_types:
146
+ ref_writer = VideoWriter(os.path.join(args.output_dir, 'ref.mp4'), vid.frame_rate, w, h)
147
+ else:
148
+ if 'com' in args.output_types:
149
+ com_writer = ImageSequenceWriter(os.path.join(args.output_dir, 'com'), 'png')
150
+ if 'pha' in args.output_types:
151
+ pha_writer = ImageSequenceWriter(os.path.join(args.output_dir, 'pha'), 'jpg')
152
+ if 'fgr' in args.output_types:
153
+ fgr_writer = ImageSequenceWriter(os.path.join(args.output_dir, 'fgr'), 'jpg')
154
+ if 'err' in args.output_types:
155
+ err_writer = ImageSequenceWriter(os.path.join(args.output_dir, 'err'), 'jpg')
156
+ if 'ref' in args.output_types:
157
+ ref_writer = ImageSequenceWriter(os.path.join(args.output_dir, 'ref'), 'jpg')
158
+
159
+
160
+ # Conversion loop
161
+ with torch.no_grad():
162
+ for input_batch in tqdm(DataLoader(dataset, batch_size=1, pin_memory=True)):
163
+ if args.video_target_bgr:
164
+ (src, bgr), tgt_bgr = input_batch
165
+ tgt_bgr = tgt_bgr.to(device, non_blocking=True)
166
+ else:
167
+ src, bgr = input_batch
168
+ tgt_bgr = torch.tensor([120/255, 255/255, 155/255], device=device).view(1, 3, 1, 1)
169
+ src = src.to(device, non_blocking=True)
170
+ bgr = bgr.to(device, non_blocking=True)
171
+
172
+ if args.model_type == 'mattingbase':
173
+ pha, fgr, err, _ = model(src, bgr)
174
+ elif args.model_type == 'mattingrefine':
175
+ pha, fgr, _, _, err, ref = model(src, bgr)
176
+ elif args.model_type == 'mattingbm':
177
+ pha, fgr = model(src, bgr)
178
+
179
+ if 'com' in args.output_types:
180
+ if args.output_format == 'video':
181
+ # Output composite with green background
182
+ com = fgr * pha + tgt_bgr * (1 - pha)
183
+ com_writer.add_batch(com)
184
+ else:
185
+ # Output composite as rgba png images
186
+ com = torch.cat([fgr * pha.ne(0), pha], dim=1)
187
+ com_writer.add_batch(com)
188
+ if 'pha' in args.output_types:
189
+ pha_writer.add_batch(pha)
190
+ if 'fgr' in args.output_types:
191
+ fgr_writer.add_batch(fgr)
192
+ if 'err' in args.output_types:
193
+ err_writer.add_batch(F.interpolate(err, src.shape[2:], mode='bilinear', align_corners=False))
194
+ if 'ref' in args.output_types:
195
+ ref_writer.add_batch(F.interpolate(ref, src.shape[2:], mode='nearest'))
196
+
197
+ return './output/com.mp4'
198
+
199
+ # Read the binary content of a local video file
200
+ def get_video_content(video_path):
201
+ with open(video_path, 'rb') as file:
202
+ video_content = file.read()
203
+ return video_content
204
+
205
+ # Suppose your video file path is './local_video.mp4'
206
+ local_video_path = './output/com.mp4'
207
+ local_video_content = get_video_content(local_video_path)
208
+
209
+ # Create the Gradio interface
210
+ with gr.Blocks() as demo:
211
+ gr.Markdown("## Video Matting")
212
+ with gr.Row():
213
+ video_src = gr.File(label="Upload Source Video (.mp4)", type="binary", file_types=["mp4"])
214
+ video_bgr = gr.File(label="Upload Background Image (.png)", type="binary", file_types=["png"])
215
+ with gr.Row():
216
+ output_video = gr.Video(label="Result Video")
217
+ submit_button = gr.Button("Start Matting")
218
+
219
+ # def download_video(video_path):
220
+ # if os.path.exists(video_path):
221
+ # with open(video_path, 'rb') as file:
222
+ # video_data = file.read()
223
+ # return video_data, "video/mp4", os.path.basename(video_path)
224
+ # else:
225
+ # return "Not Found", "text/plain", None
226
+
227
+ def clear_outputs():
228
+ output_video.update(value=None)
229
+
230
+ submit_button.click(
231
+ fn=video_matting,
232
+ inputs=[video_src, video_bgr],
233
+ outputs=[output_video]
234
+ )
235
+ # download_button = gr.Button("Download")
236
+ # download_button.click(
237
+ # download_video,
238
+ # inputs=[output_video], # pass the video path from the video component
239
+ # outputs=[gr.File(label="Download")]
240
+ # )
241
+ clear_button = gr.Button("Clear")
242
+ clear_button.click(fn=clear_outputs, inputs=[], outputs=[])
243
+
244
+ if __name__ == "__main__":
245
+ demo.launch()
model/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .model import Base, MattingBase, MattingRefine
model/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (230 Bytes).
 
model/__pycache__/decoder.cpython-38.pyc ADDED
Binary file (2.08 kB).
 
model/__pycache__/mobilenet.cpython-38.pyc ADDED
Binary file (1.85 kB).
 
model/__pycache__/model.cpython-38.pyc ADDED
Binary file (8.26 kB).
 
model/__pycache__/refiner.cpython-38.pyc ADDED
Binary file (9.4 kB).
 
model/__pycache__/resnet.cpython-38.pyc ADDED
Binary file (1.66 kB).
 
model/__pycache__/utils.cpython-38.pyc ADDED
Binary file (654 Bytes).
 
model/decoder.py ADDED
@@ -0,0 +1,51 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+
6
+ class Decoder(nn.Module):
7
+ """
8
+ Decoder upsamples the image by combining the feature maps at all resolutions from the encoder.
9
+
10
+ Input:
11
+ x4: (B, C, H/16, W/16) feature map at 1/16 resolution.
12
+ x3: (B, C, H/8, W/8) feature map at 1/8 resolution.
13
+ x2: (B, C, H/4, W/4) feature map at 1/4 resolution.
14
+ x1: (B, C, H/2, W/2) feature map at 1/2 resolution.
15
+ x0: (B, C, H, W) feature map at full resolution.
16
+
17
+ Output:
18
+ x: (B, C, H, W) upsampled output at full resolution.
19
+ """
20
+
21
+ def __init__(self, channels, feature_channels):
22
+ super().__init__()
23
+ self.conv1 = nn.Conv2d(feature_channels[0] + channels[0], channels[1], 3, padding=1, bias=False)
24
+ self.bn1 = nn.BatchNorm2d(channels[1])
25
+ self.conv2 = nn.Conv2d(feature_channels[1] + channels[1], channels[2], 3, padding=1, bias=False)
26
+ self.bn2 = nn.BatchNorm2d(channels[2])
27
+ self.conv3 = nn.Conv2d(feature_channels[2] + channels[2], channels[3], 3, padding=1, bias=False)
28
+ self.bn3 = nn.BatchNorm2d(channels[3])
29
+ self.conv4 = nn.Conv2d(feature_channels[3] + channels[3], channels[4], 3, padding=1)
30
+ self.relu = nn.ReLU(True)
31
+
32
+ def forward(self, x4, x3, x2, x1, x0):
33
+ x = F.interpolate(x4, size=x3.shape[2:], mode='bilinear', align_corners=False)
34
+ x = torch.cat([x, x3], dim=1)
35
+ x = self.conv1(x)
36
+ x = self.bn1(x)
37
+ x = self.relu(x)
38
+ x = F.interpolate(x, size=x2.shape[2:], mode='bilinear', align_corners=False)
39
+ x = torch.cat([x, x2], dim=1)
40
+ x = self.conv2(x)
41
+ x = self.bn2(x)
42
+ x = self.relu(x)
43
+ x = F.interpolate(x, size=x1.shape[2:], mode='bilinear', align_corners=False)
44
+ x = torch.cat([x, x1], dim=1)
45
+ x = self.conv3(x)
46
+ x = self.bn3(x)
47
+ x = self.relu(x)
48
+ x = F.interpolate(x, size=x0.shape[2:], mode='bilinear', align_corners=False)
49
+ x = torch.cat([x, x0], dim=1)
50
+ x = self.conv4(x)
51
+ return x
model/mobilenet.py ADDED
@@ -0,0 +1,56 @@
1
+ from torch import nn
2
+ from torchvision.models import MobileNetV2
3
+
4
+
5
+ class MobileNetV2Encoder(MobileNetV2):
6
+ """
7
+ MobileNetV2Encoder inherits from torchvision's official MobileNetV2. It is modified to
8
+ use dilation on the last block to maintain an output stride of 16, and to remove the
9
+ classifier block that was originally used for classification. The forward method
10
+ additionally returns the feature maps at all resolutions for decoder's use.
11
+ """
12
+
13
+ def __init__(self, in_channels, norm_layer=None):
14
+ super().__init__()
15
+
16
+ # Replace first conv layer if in_channels doesn't match.
17
+ if in_channels != 3:
18
+ self.features[0][0] = nn.Conv2d(in_channels, 32, 3, 2, 1, bias=False)
19
+
20
+ # Remove last block
21
+ self.features = self.features[:-1]
22
+
23
+ # Change to use dilation to maintain output stride = 16
24
+ self.features[14].conv[1][0].stride = (1, 1)
25
+ for feature in self.features[15:]:
26
+ feature.conv[1][0].dilation = (2, 2)
27
+ feature.conv[1][0].padding = (2, 2)
28
+
29
+ # Delete classifier
30
+ del self.classifier
31
+
32
+ def forward(self, x):
33
+ x0 = x # 1/1
34
+ x = self.features[0](x)
35
+ x = self.features[1](x)
36
+ x1 = x # 1/2
37
+ x = self.features[2](x)
38
+ x = self.features[3](x)
39
+ x2 = x # 1/4
40
+ x = self.features[4](x)
41
+ x = self.features[5](x)
42
+ x = self.features[6](x)
43
+ x3 = x # 1/8
44
+ x = self.features[7](x)
45
+ x = self.features[8](x)
46
+ x = self.features[9](x)
47
+ x = self.features[10](x)
48
+ x = self.features[11](x)
49
+ x = self.features[12](x)
50
+ x = self.features[13](x)
51
+ x = self.features[14](x)
52
+ x = self.features[15](x)
53
+ x = self.features[16](x)
54
+ x = self.features[17](x)
55
+ x4 = x # 1/16
56
+ return x4, x3, x2, x1, x0
model/model.py ADDED
@@ -0,0 +1,196 @@
1
+ import torch
2
+ from torch import nn
3
+ from torch.nn import functional as F
4
+ from torchvision.models.segmentation.deeplabv3 import ASPP
5
+
6
+ from .decoder import Decoder
7
+ from .mobilenet import MobileNetV2Encoder
8
+ from .refiner import Refiner
9
+ from .resnet import ResNetEncoder
10
+ from .utils import load_matched_state_dict
11
+
12
+
13
+ class Base(nn.Module):
14
+ """
15
+ A generic implementation of the base encoder-decoder network inspired by DeepLab.
16
+ Accepts arbitrary channels for input and output.
17
+ """
18
+
19
+ def __init__(self, backbone: str, in_channels: int, out_channels: int):
20
+ super().__init__()
21
+ assert backbone in ["resnet50", "resnet101", "mobilenetv2"]
22
+ if backbone in ['resnet50', 'resnet101']:
23
+ self.backbone = ResNetEncoder(in_channels, variant=backbone)
24
+ self.aspp = ASPP(2048, [3, 6, 9])
25
+ self.decoder = Decoder([256, 128, 64, 48, out_channels], [512, 256, 64, in_channels])
26
+ else:
27
+ self.backbone = MobileNetV2Encoder(in_channels)
28
+ self.aspp = ASPP(320, [3, 6, 9])
29
+ self.decoder = Decoder([256, 128, 64, 48, out_channels], [32, 24, 16, in_channels])
30
+
31
+ def forward(self, x):
32
+ x, *shortcuts = self.backbone(x)
33
+ x = self.aspp(x)
34
+ x = self.decoder(x, *shortcuts)
35
+ return x
36
+
37
+ def load_pretrained_deeplabv3_state_dict(self, state_dict, print_stats=True):
38
+ # Pretrained DeepLabV3 models are provided by <https://github.com/VainF/DeepLabV3Plus-Pytorch>.
39
+ # This method converts and loads their pretrained state_dict to match with our model structure.
40
+ # This method is not needed if you are not planning to train from deeplab weights.
41
+ # Use load_state_dict() for normal weight loading.
42
+
43
+ # Convert state_dict naming for aspp module
44
+ state_dict = {k.replace('classifier.classifier.0', 'aspp'): v for k, v in state_dict.items()}
45
+
46
+ if isinstance(self.backbone, ResNetEncoder):
47
+ # ResNet backbone does not need change.
48
+ load_matched_state_dict(self, state_dict, print_stats)
49
+ else:
50
+ # Change MobileNetV2 backbone to state_dict format, then change back after loading.
51
+ backbone_features = self.backbone.features
52
+ self.backbone.low_level_features = backbone_features[:4]
53
+ self.backbone.high_level_features = backbone_features[4:]
54
+ del self.backbone.features
55
+ load_matched_state_dict(self, state_dict, print_stats)
56
+ self.backbone.features = backbone_features
57
+ del self.backbone.low_level_features
58
+ del self.backbone.high_level_features
59
+
60
+
61
+ class MattingBase(Base):
62
+ """
63
+ MattingBase is used to produce coarse global results at a lower resolution.
64
+ MattingBase extends Base.
65
+
66
+ Args:
67
+ backbone: ["resnet50", "resnet101", "mobilenetv2"]
68
+
69
+ Input:
70
+ src: (B, 3, H, W) the source image. Channels are RGB values normalized to 0 ~ 1.
71
+ bgr: (B, 3, H, W) the background image. Channels are RGB values normalized to 0 ~ 1.
72
+
73
+ Output:
74
+ pha: (B, 1, H, W) the alpha prediction. Normalized to 0 ~ 1.
75
+ fgr: (B, 3, H, W) the foreground prediction. Channels are RGB values normalized to 0 ~ 1.
76
+ err: (B, 1, H, W) the error prediction. Normalized to 0 ~ 1.
77
+ hid: (B, 32, H, W) the hidden encoding. Used for connecting refiner module.
78
+
79
+ Example:
80
+ model = MattingBase(backbone='resnet50')
81
+
82
+ pha, fgr, err, hid = model(src, bgr) # for training
83
+ pha, fgr = model(src, bgr)[:2] # for inference
84
+ """
85
+
86
+ def __init__(self, backbone: str):
87
+ super().__init__(backbone, in_channels=6, out_channels=(1 + 3 + 1 + 32))
88
+
89
+ def forward(self, src, bgr):
90
+ x = torch.cat([src, bgr], dim=1)
91
+ x, *shortcuts = self.backbone(x)
92
+ x = self.aspp(x)
93
+ x = self.decoder(x, *shortcuts)
94
+ pha = x[:, 0:1].clamp_(0., 1.)
95
+ fgr = x[:, 1:4].add(src).clamp_(0., 1.)
96
+ err = x[:, 4:5].clamp_(0., 1.)
97
+ hid = x[:, 5: ].relu_()
98
+ return pha, fgr, err, hid
99
+
100
+
101
+ class MattingRefine(MattingBase):
102
+ """
103
+ MattingRefine includes the refiner module to upsample coarse result to full resolution.
104
+ MattingRefine extends MattingBase.
105
+
106
+ Args:
107
+ backbone: ["resnet50", "resnet101", "mobilenetv2"]
108
+ backbone_scale: The image downsample scale for passing through backbone, default 1/4 or 0.25.
109
+ Must not be greater than 1/2.
110
+ refine_mode: refine area selection mode. Options:
111
+ "full" - No area selection, refine everywhere using regular Conv2d.
112
+ "sampling" - Refine a fixed number of pixels ranked by the largest errors.
113
+ "thresholding" - Refine a varying number of pixels whose error exceeds the threshold.
114
+ refine_sample_pixels: number of pixels to refine. Only used when mode == "sampling".
115
+ refine_threshold: error threshold ranged from 0 ~ 1. Refine where err > threshold. Only used when mode == "thresholding".
116
+ refine_kernel_size: the refiner's convolutional kernel size. Options: [1, 3]
117
+ refine_prevent_oversampling: prevent sampling more pixels than needed for sampling mode. Set False only for speedtest.
118
+
119
+ Input:
120
+ src: (B, 3, H, W) the source image. Channels are RGB values normalized to 0 ~ 1.
121
+ bgr: (B, 3, H, W) the background image. Channels are RGB values normalized to 0 ~ 1.
122
+
123
+ Output:
124
+ pha: (B, 1, H, W) the alpha prediction. Normalized to 0 ~ 1.
125
+ fgr: (B, 3, H, W) the foreground prediction. Channels are RGB values normalized to 0 ~ 1.
126
+ pha_sm: (B, 1, Hc, Wc) the coarse alpha prediction from matting base. Normalized to 0 ~ 1.
127
+ fgr_sm: (B, 3, Hc, Wc) the coarse foreground prediction from matting base. Normalized to 0 ~ 1.
128
+ err_sm: (B, 1, Hc, Wc) the coarse error prediction from matting base. Normalized to 0 ~ 1.
129
+ ref_sm: (B, 1, H/4, W/4) the quarter resolution refinement map. 1 indicates refined 4x4 patch locations.
130
+
131
+ Example:
132
+ model = MattingRefine(backbone='resnet50', backbone_scale=1/4, refine_mode='sampling', refine_sample_pixels=80_000)
133
+ model = MattingRefine(backbone='resnet50', backbone_scale=1/4, refine_mode='thresholding', refine_threshold=0.1)
134
+ model = MattingRefine(backbone='resnet50', backbone_scale=1/4, refine_mode='full')
135
+
136
+ pha, fgr, pha_sm, fgr_sm, err_sm, ref_sm = model(src, bgr) # for training
137
+ pha, fgr = model(src, bgr)[:2] # for inference
138
+ """
139
+
140
+ def __init__(self,
141
+ backbone: str,
142
+ backbone_scale: float = 1/4,
143
+ refine_mode: str = 'sampling',
144
+ refine_sample_pixels: int = 80_000,
145
+ refine_threshold: float = 0.1,
146
+ refine_kernel_size: int = 3,
147
+ refine_prevent_oversampling: bool = True,
148
+ refine_patch_crop_method: str = 'unfold',
149
+ refine_patch_replace_method: str = 'scatter_nd'):
150
+ assert backbone_scale <= 1/2, 'backbone_scale should not be greater than 1/2'
151
+ super().__init__(backbone)
152
+ self.backbone_scale = backbone_scale
153
+ self.refiner = Refiner(refine_mode,
154
+ refine_sample_pixels,
155
+ refine_threshold,
156
+ refine_kernel_size,
157
+ refine_prevent_oversampling,
158
+ refine_patch_crop_method,
159
+ refine_patch_replace_method)
160
+
161
+ def forward(self, src, bgr):
162
+ assert src.size() == bgr.size(), 'src and bgr must have the same shape'
163
+ assert src.size(2) // 4 * 4 == src.size(2) and src.size(3) // 4 * 4 == src.size(3), \
164
+ 'src and bgr must have width and height that are divisible by 4'
165
+
166
+ # Downsample src and bgr for backbone
167
+ src_sm = F.interpolate(src,
168
+ scale_factor=self.backbone_scale,
169
+ mode='bilinear',
170
+ align_corners=False,
171
+ recompute_scale_factor=True)
172
+ bgr_sm = F.interpolate(bgr,
173
+ scale_factor=self.backbone_scale,
174
+ mode='bilinear',
175
+ align_corners=False,
176
+ recompute_scale_factor=True)
177
+
178
+ # Base
179
+ x = torch.cat([src_sm, bgr_sm], dim=1)
180
+ x, *shortcuts = self.backbone(x)
181
+ x = self.aspp(x)
182
+ x = self.decoder(x, *shortcuts)
183
+ pha_sm = x[:, 0:1].clamp_(0., 1.)
184
+ fgr_sm = x[:, 1:4]
185
+ err_sm = x[:, 4:5].clamp_(0., 1.)
186
+ hid_sm = x[:, 5: ].relu_()
187
+
188
+ # Refiner
189
+ pha, fgr, ref_sm = self.refiner(src, bgr, pha_sm, fgr_sm, err_sm, hid_sm)
190
+
191
+ # Clamp outputs
192
+ pha = pha.clamp_(0., 1.)
193
+ fgr = fgr.add_(src).clamp_(0., 1.)
194
+ fgr_sm = src_sm.add_(fgr_sm).clamp_(0., 1.)
195
+
196
+ return pha, fgr, pha_sm, fgr_sm, err_sm, ref_sm
model/refiner.py ADDED
@@ -0,0 +1,282 @@
1
+ import torch
2
+ import torchvision
3
+ from torch import nn
4
+ from torch.nn import functional as F
5
+ from typing import Tuple
6
+
7
+
8
+ class Refiner(nn.Module):
9
+ """
10
+ Refiner refines the coarse output to full resolution.
11
+
12
+ Args:
13
+ mode: area selection mode. Options:
14
+ "full" - No area selection, refine everywhere using regular Conv2d.
15
+ "sampling" - Refine a fixed number of pixels ranked by the largest errors.
16
+ "thresholding" - Refine a varying number of pixels whose error exceeds the threshold.
17
+ sample_pixels: number of pixels to refine. Only used when mode == "sampling".
18
+ threshold: error threshold ranged from 0 ~ 1. Refine where err > threshold. Only used when mode == "thresholding".
19
+ kernel_size: The convolution kernel_size. Options: [1, 3]
20
+ prevent_oversampling: True for regular cases, False for speedtest.
21
+
22
+ Compatibility Args:
23
+ patch_crop_method: the method for cropping patches. Options:
24
+ "unfold" - Best performance for PyTorch and TorchScript.
25
+ "roi_align" - Another way of cropping patches.
26
+ "gather" - Another way of cropping patches.
27
+ patch_replace_method: the method for replacing patches. Options:
28
+ "scatter_nd" - Best performance for PyTorch and TorchScript.
29
+ "scatter_element" - Another way for replacing patches.
30
+
31
+ Input:
32
+ src: (B, 3, H, W) full resolution source image.
33
+ bgr: (B, 3, H, W) full resolution background image.
34
+ pha: (B, 1, Hc, Wc) coarse alpha prediction.
35
+ fgr: (B, 3, Hc, Wc) coarse foreground residual prediction.
36
+ err: (B, 1, Hc, Wc) coarse error prediction.
37
+ hid: (B, 32, Hc, Wc) coarse hidden encoding.
38
+
39
+ Output:
40
+ pha: (B, 1, H, W) full resolution alpha prediction.
41
+ fgr: (B, 3, H, W) full resolution foreground residual prediction.
42
+ ref: (B, 1, H/4, W/4) quarter resolution refinement selection map. 1 indicates refined 4x4 patch locations.
43
+ """
44
+
45
+ # For TorchScript export optimization.
46
+ __constants__ = ['kernel_size', 'patch_crop_method', 'patch_replace_method']
47
+
48
+ def __init__(self,
49
+ mode: str,
50
+ sample_pixels: int,
51
+ threshold: float,
52
+ kernel_size: int = 3,
53
+ prevent_oversampling: bool = True,
54
+ patch_crop_method: str = 'unfold',
55
+ patch_replace_method: str = 'scatter_nd'):
56
+ super().__init__()
57
+ assert mode in ['full', 'sampling', 'thresholding']
58
+ assert kernel_size in [1, 3]
59
+ assert patch_crop_method in ['unfold', 'roi_align', 'gather']
60
+ assert patch_replace_method in ['scatter_nd', 'scatter_element']
61
+
62
+ self.mode = mode
63
+ self.sample_pixels = sample_pixels
64
+ self.threshold = threshold
65
+ self.kernel_size = kernel_size
66
+ self.prevent_oversampling = prevent_oversampling
67
+ self.patch_crop_method = patch_crop_method
68
+ self.patch_replace_method = patch_replace_method
69
+
70
+ channels = [32, 24, 16, 12, 4]
71
+ self.conv1 = nn.Conv2d(channels[0] + 6 + 4, channels[1], kernel_size, bias=False)
72
+ self.bn1 = nn.BatchNorm2d(channels[1])
73
+ self.conv2 = nn.Conv2d(channels[1], channels[2], kernel_size, bias=False)
74
+ self.bn2 = nn.BatchNorm2d(channels[2])
75
+ self.conv3 = nn.Conv2d(channels[2] + 6, channels[3], kernel_size, bias=False)
76
+ self.bn3 = nn.BatchNorm2d(channels[3])
77
+ self.conv4 = nn.Conv2d(channels[3], channels[4], kernel_size, bias=True)
78
+ self.relu = nn.ReLU(True)
79
+
80
+ def forward(self,
81
+ src: torch.Tensor,
82
+ bgr: torch.Tensor,
83
+ pha: torch.Tensor,
84
+ fgr: torch.Tensor,
85
+ err: torch.Tensor,
86
+ hid: torch.Tensor):
87
+ H_full, W_full = src.shape[2:]
88
+ H_half, W_half = H_full // 2, W_full // 2
89
+ H_quat, W_quat = H_full // 4, W_full // 4
90
+
91
+ src_bgr = torch.cat([src, bgr], dim=1)
92
+
93
+ if self.mode != 'full':
94
+ err = F.interpolate(err, (H_quat, W_quat), mode='bilinear', align_corners=False)
95
+ ref = self.select_refinement_regions(err)
96
+ idx = torch.nonzero(ref.squeeze(1))
97
+ idx = idx[:, 0], idx[:, 1], idx[:, 2]
98
+
99
+ if idx[0].size(0) > 0:
100
+ x = torch.cat([hid, pha, fgr], dim=1)
101
+ x = F.interpolate(x, (H_half, W_half), mode='bilinear', align_corners=False)
102
+ x = self.crop_patch(x, idx, 2, 3 if self.kernel_size == 3 else 0)
103
+
104
+ y = F.interpolate(src_bgr, (H_half, W_half), mode='bilinear', align_corners=False)
105
+ y = self.crop_patch(y, idx, 2, 3 if self.kernel_size == 3 else 0)
106
+
107
+ x = self.conv1(torch.cat([x, y], dim=1))
108
+ x = self.bn1(x)
109
+ x = self.relu(x)
110
+ x = self.conv2(x)
111
+ x = self.bn2(x)
112
+ x = self.relu(x)
113
+
114
+ x = F.interpolate(x, 8 if self.kernel_size == 3 else 4, mode='nearest')
115
+ y = self.crop_patch(src_bgr, idx, 4, 2 if self.kernel_size == 3 else 0)
116
+
117
+ x = self.conv3(torch.cat([x, y], dim=1))
118
+ x = self.bn3(x)
119
+ x = self.relu(x)
120
+ x = self.conv4(x)
121
+
122
+ out = torch.cat([pha, fgr], dim=1)
123
+ out = F.interpolate(out, (H_full, W_full), mode='bilinear', align_corners=False)
124
+ out = self.replace_patch(out, x, idx)
125
+ pha = out[:, :1]
126
+ fgr = out[:, 1:]
127
+ else:
128
+ pha = F.interpolate(pha, (H_full, W_full), mode='bilinear', align_corners=False)
129
+ fgr = F.interpolate(fgr, (H_full, W_full), mode='bilinear', align_corners=False)
130
+ else:
131
+ x = torch.cat([hid, pha, fgr], dim=1)
132
+ x = F.interpolate(x, (H_half, W_half), mode='bilinear', align_corners=False)
133
+ y = F.interpolate(src_bgr, (H_half, W_half), mode='bilinear', align_corners=False)
134
+ if self.kernel_size == 3:
135
+ x = F.pad(x, (3, 3, 3, 3))
136
+ y = F.pad(y, (3, 3, 3, 3))
137
+
138
+ x = self.conv1(torch.cat([x, y], dim=1))
139
+ x = self.bn1(x)
140
+ x = self.relu(x)
141
+ x = self.conv2(x)
142
+ x = self.bn2(x)
143
+ x = self.relu(x)
144
+
145
+ if self.kernel_size == 3:
146
+ x = F.interpolate(x, (H_full + 4, W_full + 4))
147
+ y = F.pad(src_bgr, (2, 2, 2, 2))
148
+ else:
149
+ x = F.interpolate(x, (H_full, W_full), mode='nearest')
150
+ y = src_bgr
151
+
152
+ x = self.conv3(torch.cat([x, y], dim=1))
153
+ x = self.bn3(x)
154
+ x = self.relu(x)
155
+ x = self.conv4(x)
156
+
157
+ pha = x[:, :1]
158
+ fgr = x[:, 1:]
159
+ ref = torch.ones((src.size(0), 1, H_quat, W_quat), device=src.device, dtype=src.dtype)
160
+
161
+ return pha, fgr, ref
162
+
163
+ def select_refinement_regions(self, err: torch.Tensor):
164
+ """
165
+ Select refinement regions.
166
+ Input:
167
+ err: error map (B, 1, H, W)
168
+ Output:
169
+ ref: refinement regions (B, 1, H, W). FloatTensor. 1 is selected, 0 is not.
170
+ """
171
+ if self.mode == 'sampling':
172
+ # Sampling mode.
173
+ b, _, h, w = err.shape
174
+ err = err.view(b, -1)
175
+ idx = err.topk(self.sample_pixels // 16, dim=1, sorted=False).indices
176
+ ref = torch.zeros_like(err)
177
+ ref.scatter_(1, idx, 1.)
178
+ if self.prevent_oversampling:
179
+ ref.mul_(err.gt(0).float())
180
+ ref = ref.view(b, 1, h, w)
181
+ else:
182
+ # Thresholding mode.
183
+ ref = err.gt(self.threshold).float()
184
+ return ref
185
+
186
+ def crop_patch(self,
187
+ x: torch.Tensor,
188
+ idx: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
189
+ size: int,
190
+ padding: int):
191
+ """
192
+ Crops selected patches from image given indices.
193
+
194
+ Inputs:
195
+ x: image (B, C, H, W).
196
+ idx: selection indices Tuple[(P,), (P,), (P,)], where the 3 tensors hold the (B, H, W) indices.
197
+ size: center size of the patch, also stride of the crop.
198
+ padding: expansion size of the patch.
199
+ Output:
200
+ patch: (P, C, h, w), where h = w = size + 2 * padding.
201
+ """
202
+ if padding != 0:
203
+ x = F.pad(x, (padding,) * 4)
204
+
205
+ if self.patch_crop_method == 'unfold':
206
+ # Use unfold. Best performance for PyTorch and TorchScript.
207
+ return x.permute(0, 2, 3, 1) \
208
+ .unfold(1, size + 2 * padding, size) \
209
+ .unfold(2, size + 2 * padding, size)[idx[0], idx[1], idx[2]]
210
+ elif self.patch_crop_method == 'roi_align':
211
+ # Use roi_align. Best compatibility for ONNX.
212
+ idx = idx[0].type_as(x), idx[1].type_as(x), idx[2].type_as(x)
213
+ b = idx[0]
214
+ x1 = idx[2] * size - 0.5
215
+ y1 = idx[1] * size - 0.5
216
+ x2 = idx[2] * size + size + 2 * padding - 0.5
217
+ y2 = idx[1] * size + size + 2 * padding - 0.5
218
+ boxes = torch.stack([b, x1, y1, x2, y2], dim=1)
219
+ return torchvision.ops.roi_align(x, boxes, size + 2 * padding, sampling_ratio=1)
220
+ else:
221
+ # Use gather. Crops out patches pixel by pixel.
222
+ idx_pix = self.compute_pixel_indices(x, idx, size, padding)
223
+ pat = torch.gather(x.view(-1), 0, idx_pix.view(-1))
224
+ pat = pat.view(-1, x.size(1), size + 2 * padding, size + 2 * padding)
225
+ return pat
226
+
227
+ def replace_patch(self,
228
+ x: torch.Tensor,
229
+ y: torch.Tensor,
230
+ idx: Tuple[torch.Tensor, torch.Tensor, torch.Tensor]):
231
+ """
232
+ Replaces patches back into image given index.
233
+
234
+ Inputs:
235
+ x: image (B, C, H, W)
236
+ y: patches (P, C, h, w)
237
+ idx: selection indices Tuple[(P,), (P,), (P,)], where the 3 tensors hold the (B, H, W) indices.
238
+
239
+ Output:
240
+ image: (B, C, H, W), where patches at idx locations are replaced with y.
241
+ """
242
+ xB, xC, xH, xW = x.shape
243
+ yB, yC, yH, yW = y.shape
244
+ if self.patch_replace_method == 'scatter_nd':
245
+ # Use scatter_nd. Best performance for PyTorch and TorchScript. Replacing patch by patch.
246
+ x = x.view(xB, xC, xH // yH, yH, xW // yW, yW).permute(0, 2, 4, 1, 3, 5)
247
+ x[idx[0], idx[1], idx[2]] = y
248
+ x = x.permute(0, 3, 1, 4, 2, 5).view(xB, xC, xH, xW)
249
+ return x
250
+ else:
251
+ # Use scatter_element. Best compatibility for ONNX. Replacing pixel by pixel.
252
+ idx_pix = self.compute_pixel_indices(x, idx, size=4, padding=0)
253
+ return x.view(-1).scatter_(0, idx_pix.view(-1), y.view(-1)).view(x.shape)
254
+
255
+ def compute_pixel_indices(self,
256
+ x: torch.Tensor,
257
+ idx: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
258
+ size: int,
259
+ padding: int):
260
+ """
261
+ Compute selected pixel indices in the tensor.
262
+ Used when patch_crop_method == 'gather' and patch_replace_method == 'scatter_element', which crop and replace pixel by pixel.
263
+ Input:
264
+ x: image: (B, C, H, W)
265
+ idx: selection indices Tuple[(P,), (P,), (P,)], where the 3 tensors hold the (B, H, W) indices.
266
+ size: center size of the patch, also stride of the crop.
267
+ padding: expansion size of the patch.
268
+ Output:
269
+ idx: (P, C, O, O) long tensor, where O is the output size (size + 2 * padding) and P is the number of patches.
270
+ The elements are indices pointing into the flattened input x.view(-1).
271
+ """
272
+ B, C, H, W = x.shape
273
+ S, P = size, padding
274
+ O = S + 2 * P
275
+ b, y, x = idx
276
+ n = b.size(0)
277
+ c = torch.arange(C)
278
+ o = torch.arange(O)
279
+ idx_pat = (c * H * W).view(C, 1, 1).expand([C, O, O]) + (o * W).view(1, O, 1).expand([C, O, O]) + o.view(1, 1, O).expand([C, O, O])
280
+ idx_loc = b * W * H + y * W * S + x * S
281
+ idx_pix = idx_loc.view(-1, 1, 1, 1).expand([n, C, O, O]) + idx_pat.view(1, C, O, O).expand([n, C, O, O])
282
+ return idx_pix
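
A standalone sketch of the patch bookkeeping above (not part of this commit): 4x4 patches at selected (batch, row, col) grid locations are cropped with the same unfold expression used by crop_patch (size=4, padding=0) and written back with the same double permute used by replace_patch, which leaves the image unchanged. All names below are local to the sketch.

import torch

# Toy image and three selected patch locations on the 4x-downsampled grid: (batch, row, col).
x = torch.arange(2 * 3 * 16 * 16, dtype=torch.float32).view(2, 3, 16, 16)
idx = (torch.tensor([0, 0, 1]), torch.tensor([0, 2, 3]), torch.tensor([1, 0, 2]))

# Crop: same expression as crop_patch's 'unfold' branch with size=4, padding=0.
patches = x.permute(0, 2, 3, 1).unfold(1, 4, 4).unfold(2, 4, 4)[idx[0], idx[1], idx[2]]
print(patches.shape)  # torch.Size([3, 3, 4, 4]) -> (P, C, h, w)

# Replace: same double permute as replace_patch's 'scatter_nd' branch.
out = x.clone().view(2, 3, 4, 4, 4, 4).permute(0, 2, 4, 1, 3, 5)
out[idx[0], idx[1], idx[2]] = patches
out = out.permute(0, 3, 1, 4, 2, 5).view(2, 3, 16, 16)
print(torch.equal(out, x))  # True: writing the unmodified patches back is a no-op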
model/resnet.py ADDED
@@ -0,0 +1,48 @@
1
+ from torch import nn
2
+ from torchvision.models.resnet import ResNet, Bottleneck
3
+
4
+
5
+ class ResNetEncoder(ResNet):
6
+ """
7
+ ResNetEncoder inherits from torchvision's official ResNet. It is modified to
8
+ use dilation on the last block to maintain an output stride of 16, and it removes
9
+ the global average pooling layer and the fully connected layer that were
10
+ originally used for classification. The forward method additionally returns the
11
+ feature maps at all resolutions for the decoder's use.
12
+ """
13
+
14
+ layers = {
15
+ 'resnet50': [3, 4, 6, 3],
16
+ 'resnet101': [3, 4, 23, 3],
17
+ }
18
+
19
+ def __init__(self, in_channels, variant='resnet101', norm_layer=None):
20
+ super().__init__(
21
+ block=Bottleneck,
22
+ layers=self.layers[variant],
23
+ replace_stride_with_dilation=[False, False, True],
24
+ norm_layer=norm_layer)
25
+
26
+ # Replace first conv layer if in_channels doesn't match.
27
+ if in_channels != 3:
28
+ self.conv1 = nn.Conv2d(in_channels, 64, 7, 2, 3, bias=False)
29
+
30
+ # Delete the average-pooling and fully-connected layers used for classification
31
+ del self.avgpool
32
+ del self.fc
33
+
34
+ def forward(self, x):
35
+ x0 = x # 1/1
36
+ x = self.conv1(x)
37
+ x = self.bn1(x)
38
+ x = self.relu(x)
39
+ x1 = x # 1/2
40
+ x = self.maxpool(x)
41
+ x = self.layer1(x)
42
+ x2 = x # 1/4
43
+ x = self.layer2(x)
44
+ x3 = x # 1/8
45
+ x = self.layer3(x)
46
+ x = self.layer4(x)
47
+ x4 = x # 1/16
48
+ return x4, x3, x2, x1, x0
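
For orientation (not part of this commit), a quick check of the feature pyramid the encoder returns; with replace_stride_with_dilation=[False, False, True] the deepest features stay at 1/16 of the input. The 6 input channels and the model.resnet import path are assumptions about how the surrounding code is laid out.

import torch
from model.resnet import ResNetEncoder  # assumes the repo root is on PYTHONPATH

# 6 channels, e.g. a source image concatenated with its background (assumption).
encoder = ResNetEncoder(in_channels=6, variant='resnet50').eval()
with torch.no_grad():
    feats = encoder(torch.randn(1, 6, 256, 256))
for name, f in zip(('x4', 'x3', 'x2', 'x1', 'x0'), feats):
    print(name, tuple(f.shape))
# x4 (1, 2048, 16, 16)  1/16 (dilated layer4 keeps output stride 16)
# x3 (1, 512, 32, 32)   1/8
# x2 (1, 256, 64, 64)   1/4
# x1 (1, 64, 128, 128)  1/2
# x0 (1, 6, 256, 256)   1/1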
model/utils.py ADDED
@@ -0,0 +1,14 @@
1
+ def load_matched_state_dict(model, state_dict, print_stats=True):
2
+ """
3
+ Only loads weights that match in both key and shape; all other weights are ignored.
4
+ """
5
+ num_matched, num_total = 0, 0
6
+ curr_state_dict = model.state_dict()
7
+ for key in curr_state_dict.keys():
8
+ num_total += 1
9
+ if key in state_dict and curr_state_dict[key].shape == state_dict[key].shape:
10
+ curr_state_dict[key] = state_dict[key]
11
+ num_matched += 1
12
+ model.load_state_dict(curr_state_dict)
13
+ if print_stats:
14
+ print(f'Loaded state_dict: {num_matched}/{num_total} matched')
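
A usage sketch (not part of this commit): because mismatched keys are skipped rather than raising, a stock 3-channel ResNet-50 checkpoint can seed the 6-channel encoder above; conv1 and any other non-matching keys simply keep their fresh initialization. The checkpoint path is a placeholder.

import torch
from model.resnet import ResNetEncoder
from model.utils import load_matched_state_dict

model = ResNetEncoder(in_channels=6, variant='resnet50')
state_dict = torch.load('PATH_TO_RESNET50_CHECKPOINT.pth', map_location='cpu')
load_matched_state_dict(model, state_dict)  # prints 'Loaded state_dict: <matched>/<total> matched'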
requirements.txt ADDED
@@ -0,0 +1,11 @@
1
+ kornia==0.4.1
2
+ tensorboard==2.3.0
3
+ torch==1.7.0
4
+ torchvision==0.8.1
5
+ tqdm==4.51.0
6
+ opencv-python==4.4.0.44
7
+ onnxruntime==1.6.0
8
+ gradio
9
+ matplotlib
10
+ fastapi
11
+ aiohttp
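
The pins above target the PyTorch 1.7 era; below is a small sanity check (not part of this commit) that the installed versions match a few of the pins. It assumes Python 3.8+ for importlib.metadata.

from importlib.metadata import version, PackageNotFoundError

# A subset of the pins from requirements.txt.
pins = {'kornia': '0.4.1', 'torch': '1.7.0', 'torchvision': '0.8.1', 'tqdm': '4.51.0'}
for name, wanted in pins.items():
    try:
        installed = version(name)
    except PackageNotFoundError:
        installed = 'missing'
    status = 'ok' if installed == wanted else 'MISMATCH'
    print(f'{name}: pinned {wanted}, installed {installed} ({status})')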
train_base.py ADDED
@@ -0,0 +1,265 @@
1
+ """
2
+ Train MattingBase
3
+
4
+ You can download pretrained DeepLabV3 weights from <https://github.com/VainF/DeepLabV3Plus-Pytorch>
5
+
6
+ Example:
7
+
8
+ CUDA_VISIBLE_DEVICES=0 python train_base.py \
9
+ --dataset-name videomatte240k \
10
+ --model-backbone resnet50 \
11
+ --model-name mattingbase-resnet50-videomatte240k \
12
+ --model-pretrain-initialization "pretraining/best_deeplabv3_resnet50_voc_os16.pth" \
13
+ --epoch-end 8
14
+
15
+ """
16
+
17
+ import argparse
18
+ import kornia
19
+ import torch
20
+ import os
21
+ import random
22
+
23
+ from torch import nn
24
+ from torch.nn import functional as F
25
+ from torch.cuda.amp import autocast, GradScaler
26
+ from torch.utils.tensorboard import SummaryWriter
27
+ from torch.utils.data import DataLoader
28
+ from torch.optim import Adam
29
+ from torchvision.utils import make_grid
30
+ from tqdm import tqdm
31
+ from torchvision import transforms as T
32
+ from PIL import Image
33
+
34
+ from data_path import DATA_PATH
35
+ from dataset import ImagesDataset, ZipDataset, VideoDataset, SampleDataset
36
+ from dataset import augmentation as A
37
+ from model import MattingBase
38
+ from model.utils import load_matched_state_dict
39
+
40
+
41
+ # --------------- Arguments ---------------
42
+
43
+
44
+ parser = argparse.ArgumentParser()
45
+
46
+ parser.add_argument('--dataset-name', type=str, required=True, choices=DATA_PATH.keys())
47
+
48
+ parser.add_argument('--model-backbone', type=str, required=True, choices=['resnet101', 'resnet50', 'mobilenetv2'])
49
+ parser.add_argument('--model-name', type=str, required=True)
50
+ parser.add_argument('--model-pretrain-initialization', type=str, default=None)
51
+ parser.add_argument('--model-last-checkpoint', type=str, default=None)
52
+
53
+ parser.add_argument('--batch-size', type=int, default=8)
54
+ parser.add_argument('--num-workers', type=int, default=16)
55
+ parser.add_argument('--epoch-start', type=int, default=0)
56
+ parser.add_argument('--epoch-end', type=int, required=True)
57
+
58
+ parser.add_argument('--log-train-loss-interval', type=int, default=10)
59
+ parser.add_argument('--log-train-images-interval', type=int, default=2000)
60
+ parser.add_argument('--log-valid-interval', type=int, default=5000)
61
+
62
+ parser.add_argument('--checkpoint-interval', type=int, default=5000)
63
+
64
+ args = parser.parse_args()
65
+
66
+
67
+ # --------------- Loading ---------------
68
+
69
+
70
+ def train():
71
+
72
+ # Training DataLoader
73
+ dataset_train = ZipDataset([
74
+ ZipDataset([
75
+ ImagesDataset(DATA_PATH[args.dataset_name]['train']['pha'], mode='L'),
76
+ ImagesDataset(DATA_PATH[args.dataset_name]['train']['fgr'], mode='RGB'),
77
+ ], transforms=A.PairCompose([
78
+ A.PairRandomAffineAndResize((512, 512), degrees=(-5, 5), translate=(0.1, 0.1), scale=(0.4, 1), shear=(-5, 5)),
79
+ A.PairRandomHorizontalFlip(),
80
+ A.PairRandomBoxBlur(0.1, 5),
81
+ A.PairRandomSharpen(0.1),
82
+ A.PairApplyOnlyAtIndices([1], T.ColorJitter(0.15, 0.15, 0.15, 0.05)),
83
+ A.PairApply(T.ToTensor())
84
+ ]), assert_equal_length=True),
85
+ ImagesDataset(DATA_PATH['backgrounds']['train'], mode='RGB', transforms=T.Compose([
86
+ A.RandomAffineAndResize((512, 512), degrees=(-5, 5), translate=(0.1, 0.1), scale=(1, 2), shear=(-5, 5)),
87
+ T.RandomHorizontalFlip(),
88
+ A.RandomBoxBlur(0.1, 5),
89
+ A.RandomSharpen(0.1),
90
+ T.ColorJitter(0.15, 0.15, 0.15, 0.05),
91
+ T.ToTensor()
92
+ ])),
93
+ ])
94
+ dataloader_train = DataLoader(dataset_train,
95
+ shuffle=True,
96
+ batch_size=args.batch_size,
97
+ num_workers=args.num_workers,
98
+ pin_memory=True)
99
+
100
+ # Validation DataLoader
101
+ dataset_valid = ZipDataset([
102
+ ZipDataset([
103
+ ImagesDataset(DATA_PATH[args.dataset_name]['valid']['pha'], mode='L'),
104
+ ImagesDataset(DATA_PATH[args.dataset_name]['valid']['fgr'], mode='RGB')
105
+ ], transforms=A.PairCompose([
106
+ A.PairRandomAffineAndResize((512, 512), degrees=(-5, 5), translate=(0.1, 0.1), scale=(0.3, 1), shear=(-5, 5)),
107
+ A.PairApply(T.ToTensor())
108
+ ]), assert_equal_length=True),
109
+ ImagesDataset(DATA_PATH['backgrounds']['valid'], mode='RGB', transforms=T.Compose([
110
+ A.RandomAffineAndResize((512, 512), degrees=(-5, 5), translate=(0.1, 0.1), scale=(1, 1.2), shear=(-5, 5)),
111
+ T.ToTensor()
112
+ ])),
113
+ ])
114
+ dataset_valid = SampleDataset(dataset_valid, 50)
115
+ dataloader_valid = DataLoader(dataset_valid,
116
+ pin_memory=True,
117
+ batch_size=args.batch_size,
118
+ num_workers=args.num_workers)
119
+
120
+ # Model
121
+ model = MattingBase(args.model_backbone).cuda()
122
+
123
+ if args.model_last_checkpoint is not None:
124
+ load_matched_state_dict(model, torch.load(args.model_last_checkpoint))
125
+ elif args.model_pretrain_initialization is not None:
126
+ model.load_pretrained_deeplabv3_state_dict(torch.load(args.model_pretrain_initialization)['model_state'])
127
+
128
+ optimizer = Adam([
129
+ {'params': model.backbone.parameters(), 'lr': 1e-4},
130
+ {'params': model.aspp.parameters(), 'lr': 5e-4},
131
+ {'params': model.decoder.parameters(), 'lr': 5e-4}
132
+ ])
133
+ scaler = GradScaler()
134
+
135
+ # Logging and checkpoints
136
+ if not os.path.exists(f'checkpoint/{args.model_name}'):
137
+ os.makedirs(f'checkpoint/{args.model_name}')
138
+ writer = SummaryWriter(f'log/{args.model_name}')
139
+
140
+ # Run loop
141
+ for epoch in range(args.epoch_start, args.epoch_end):
142
+ for i, ((true_pha, true_fgr), true_bgr) in enumerate(tqdm(dataloader_train)):
143
+ step = epoch * len(dataloader_train) + i
144
+
145
+ true_pha = true_pha.cuda(non_blocking=True)
146
+ true_fgr = true_fgr.cuda(non_blocking=True)
147
+ true_bgr = true_bgr.cuda(non_blocking=True)
148
+ true_pha, true_fgr, true_bgr = random_crop(true_pha, true_fgr, true_bgr)
149
+
150
+ true_src = true_bgr.clone()
151
+
152
+ # Augment with shadow
153
+ aug_shadow_idx = torch.rand(len(true_src)) < 0.3
154
+ if aug_shadow_idx.any():
155
+ aug_shadow = true_pha[aug_shadow_idx].mul(0.3 * random.random())
156
+ aug_shadow = T.RandomAffine(degrees=(-5, 5), translate=(0.2, 0.2), scale=(0.5, 1.5), shear=(-5, 5))(aug_shadow)
157
+ aug_shadow = kornia.filters.box_blur(aug_shadow, (random.choice(range(20, 40)),) * 2)
158
+ true_src[aug_shadow_idx] = true_src[aug_shadow_idx].sub_(aug_shadow).clamp_(0, 1)
159
+ del aug_shadow
160
+ del aug_shadow_idx
161
+
162
+ # Composite foreground onto source
163
+ true_src = true_fgr * true_pha + true_src * (1 - true_pha)
164
+
165
+ # Augment with noise
166
+ aug_noise_idx = torch.rand(len(true_src)) < 0.4
167
+ if aug_noise_idx.any():
168
+ true_src[aug_noise_idx] = true_src[aug_noise_idx].add_(torch.randn_like(true_src[aug_noise_idx]).mul_(0.03 * random.random())).clamp_(0, 1)
169
+ true_bgr[aug_noise_idx] = true_bgr[aug_noise_idx].add_(torch.randn_like(true_bgr[aug_noise_idx]).mul_(0.03 * random.random())).clamp_(0, 1)
170
+ del aug_noise_idx
171
+
172
+ # Augment background with jitter
173
+ aug_jitter_idx = torch.rand(len(true_src)) < 0.8
174
+ if aug_jitter_idx.any():
175
+ true_bgr[aug_jitter_idx] = kornia.augmentation.ColorJitter(0.18, 0.18, 0.18, 0.1)(true_bgr[aug_jitter_idx])
176
+ del aug_jitter_idx
177
+
178
+ # Augment background with affine
179
+ aug_affine_idx = torch.rand(len(true_bgr)) < 0.3
180
+ if aug_affine_idx.any():
181
+ true_bgr[aug_affine_idx] = T.RandomAffine(degrees=(-1, 1), translate=(0.01, 0.01))(true_bgr[aug_affine_idx])
182
+ del aug_affine_idx
183
+
184
+ with autocast():
185
+ pred_pha, pred_fgr, pred_err = model(true_src, true_bgr)[:3]
186
+ loss = compute_loss(pred_pha, pred_fgr, pred_err, true_pha, true_fgr)
187
+
188
+ scaler.scale(loss).backward()
189
+ scaler.step(optimizer)
190
+ scaler.update()
191
+ optimizer.zero_grad()
192
+
193
+ if (i + 1) % args.log_train_loss_interval == 0:
194
+ writer.add_scalar('loss', loss, step)
195
+
196
+ if (i + 1) % args.log_train_images_interval == 0:
197
+ writer.add_image('train_pred_pha', make_grid(pred_pha, nrow=5), step)
198
+ writer.add_image('train_pred_fgr', make_grid(pred_fgr, nrow=5), step)
199
+ writer.add_image('train_pred_com', make_grid(pred_fgr * pred_pha, nrow=5), step)
200
+ writer.add_image('train_pred_err', make_grid(pred_err, nrow=5), step)
201
+ writer.add_image('train_true_src', make_grid(true_src, nrow=5), step)
202
+ writer.add_image('train_true_bgr', make_grid(true_bgr, nrow=5), step)
203
+
204
+ del true_pha, true_fgr, true_bgr
205
+ del pred_pha, pred_fgr, pred_err
206
+
207
+ if (i + 1) % args.log_valid_interval == 0:
208
+ valid(model, dataloader_valid, writer, step)
209
+
210
+ if (step + 1) % args.checkpoint_interval == 0:
211
+ torch.save(model.state_dict(), f'checkpoint/{args.model_name}/epoch-{epoch}-iter-{step}.pth')
212
+
213
+ torch.save(model.state_dict(), f'checkpoint/{args.model_name}/epoch-{epoch}.pth')
214
+
215
+
216
+ # --------------- Utils ---------------
217
+
218
+
219
+ def compute_loss(pred_pha, pred_fgr, pred_err, true_pha, true_fgr):
220
+ true_err = torch.abs(pred_pha.detach() - true_pha)
221
+ true_msk = true_pha != 0
222
+ return F.l1_loss(pred_pha, true_pha) + \
223
+ F.l1_loss(kornia.sobel(pred_pha), kornia.sobel(true_pha)) + \
224
+ F.l1_loss(pred_fgr * true_msk, true_fgr * true_msk) + \
225
+ F.mse_loss(pred_err, true_err)
226
+
227
+
228
+ def random_crop(*imgs):
229
+ w = random.choice(range(256, 512))
230
+ h = random.choice(range(256, 512))
231
+ results = []
232
+ for img in imgs:
233
+ img = kornia.resize(img, (max(h, w), max(h, w)))
234
+ img = kornia.center_crop(img, (h, w))
235
+ results.append(img)
236
+ return results
237
+
238
+
239
+ def valid(model, dataloader, writer, step):
240
+ model.eval()
241
+ loss_total = 0
242
+ loss_count = 0
243
+ with torch.no_grad():
244
+ for (true_pha, true_fgr), true_bgr in dataloader:
245
+ batch_size = true_pha.size(0)
246
+
247
+ true_pha = true_pha.cuda(non_blocking=True)
248
+ true_fgr = true_fgr.cuda(non_blocking=True)
249
+ true_bgr = true_bgr.cuda(non_blocking=True)
250
+ true_src = true_pha * true_fgr + (1 - true_pha) * true_bgr
251
+
252
+ pred_pha, pred_fgr, pred_err = model(true_src, true_bgr)[:3]
253
+ loss = compute_loss(pred_pha, pred_fgr, pred_err, true_pha, true_fgr)
254
+ loss_total += loss.cpu().item() * batch_size
255
+ loss_count += batch_size
256
+
257
+ writer.add_scalar('valid_loss', loss_total / loss_count, step)
258
+ model.train()
259
+
260
+
261
+ # --------------- Start ---------------
262
+
263
+
264
+ if __name__ == '__main__':
265
+ train()
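
To make the training-loop math above concrete, here is a standalone sketch (not part of this commit; train_base.py parses CLI arguments at import time, so nothing is imported from it) of the compositing step and the two less obvious loss terms: the error head regresses the absolute alpha error, and the foreground L1 is masked to pixels where the true alpha is non-zero. The kornia.sobel gradient term is omitted here.

import torch
import torch.nn.functional as F

# Toy ground-truth alpha, foreground and background in [0, 1].
true_pha = torch.rand(2, 1, 64, 64)
true_fgr = torch.rand(2, 3, 64, 64)
true_bgr = torch.rand(2, 3, 64, 64)

# Composite the foreground over the background, as in the training loop.
true_src = true_fgr * true_pha + true_bgr * (1 - true_pha)

# Stand-ins for network outputs (any tensors of matching shape).
pred_pha = torch.rand(2, 1, 64, 64)
pred_fgr = torch.rand(2, 3, 64, 64)
pred_err = torch.rand(2, 1, 64, 64)

true_err = torch.abs(pred_pha.detach() - true_pha)  # target for the error map
true_msk = true_pha != 0                            # foreground loss only where alpha > 0
loss = F.l1_loss(pred_pha, true_pha) \
     + F.l1_loss(pred_fgr * true_msk, true_fgr * true_msk) \
     + F.mse_loss(pred_err, true_err)
print(true_src.shape, loss.item())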
train_refine.py ADDED
@@ -0,0 +1,309 @@
1
+ """
2
+ Train MattingRefine
3
+
4
+ Supports multi-GPU training with DistributedDataParallel() and SyncBatchNorm.
5
+ Select GPUs through CUDA_VISIBLE_DEVICES environment variable.
6
+
7
+ Example:
8
+
9
+ CUDA_VISIBLE_DEVICES=0,1 python train_refine.py \
10
+ --dataset-name videomatte240k \
11
+ --model-backbone resnet50 \
12
+ --model-name mattingrefine-resnet50-videomatte240k \
13
+ --model-last-checkpoint "PATH_TO_LAST_CHECKPOINT" \
14
+ --epoch-end 1
15
+
16
+ """
17
+
18
+ import argparse
19
+ import kornia
20
+ import torch
21
+ import os
22
+ import random
23
+
24
+ from torch import nn
25
+ from torch import distributed as dist
26
+ from torch import multiprocessing as mp
27
+ from torch.nn import functional as F
28
+ from torch.cuda.amp import autocast, GradScaler
29
+ from torch.utils.tensorboard import SummaryWriter
30
+ from torch.utils.data import DataLoader, Subset
31
+ from torch.optim import Adam
32
+ from torchvision.utils import make_grid
33
+ from tqdm import tqdm
34
+ from torchvision import transforms as T
35
+ from PIL import Image
36
+
37
+ from data_path import DATA_PATH
38
+ from dataset import ImagesDataset, ZipDataset, VideoDataset, SampleDataset
39
+ from dataset import augmentation as A
40
+ from model import MattingRefine
41
+ from model.utils import load_matched_state_dict
42
+
43
+
44
+ # --------------- Arguments ---------------
45
+
46
+
47
+ parser = argparse.ArgumentParser()
48
+
49
+ parser.add_argument('--dataset-name', type=str, required=True, choices=DATA_PATH.keys())
50
+
51
+ parser.add_argument('--model-backbone', type=str, required=True, choices=['resnet101', 'resnet50', 'mobilenetv2'])
52
+ parser.add_argument('--model-backbone-scale', type=float, default=0.25)
53
+ parser.add_argument('--model-refine-mode', type=str, default='sampling', choices=['full', 'sampling', 'thresholding'])
54
+ parser.add_argument('--model-refine-sample-pixels', type=int, default=80_000)
55
+ parser.add_argument('--model-refine-thresholding', type=float, default=0.7)
56
+ parser.add_argument('--model-refine-kernel-size', type=int, default=3, choices=[1, 3])
57
+ parser.add_argument('--model-name', type=str, required=True)
58
+ parser.add_argument('--model-last-checkpoint', type=str, default=None)
59
+
60
+ parser.add_argument('--batch-size', type=int, default=4)
61
+ parser.add_argument('--num-workers', type=int, default=16)
62
+ parser.add_argument('--epoch-start', type=int, default=0)
63
+ parser.add_argument('--epoch-end', type=int, required=True)
64
+
65
+ parser.add_argument('--log-train-loss-interval', type=int, default=10)
66
+ parser.add_argument('--log-train-images-interval', type=int, default=1000)
67
+ parser.add_argument('--log-valid-interval', type=int, default=2000)
68
+
69
+ parser.add_argument('--checkpoint-interval', type=int, default=2000)
70
+
71
+ args = parser.parse_args()
72
+
73
+
74
+ distributed_num_gpus = torch.cuda.device_count()
75
+ assert args.batch_size % distributed_num_gpus == 0
76
+
77
+
78
+ # --------------- Main ---------------
79
+
80
+ def train_worker(rank, addr, port):
81
+
82
+ # Distributed Setup
83
+ os.environ['MASTER_ADDR'] = addr
84
+ os.environ['MASTER_PORT'] = port
85
+ dist.init_process_group("nccl", rank=rank, world_size=distributed_num_gpus)
86
+
87
+ # Training DataLoader
88
+ dataset_train = ZipDataset([
89
+ ZipDataset([
90
+ ImagesDataset(DATA_PATH[args.dataset_name]['train']['pha'], mode='L'),
91
+ ImagesDataset(DATA_PATH[args.dataset_name]['train']['fgr'], mode='RGB'),
92
+ ], transforms=A.PairCompose([
93
+ A.PairRandomAffineAndResize((2048, 2048), degrees=(-5, 5), translate=(0.1, 0.1), scale=(0.3, 1), shear=(-5, 5)),
94
+ A.PairRandomHorizontalFlip(),
95
+ A.PairRandomBoxBlur(0.1, 5),
96
+ A.PairRandomSharpen(0.1),
97
+ A.PairApplyOnlyAtIndices([1], T.ColorJitter(0.15, 0.15, 0.15, 0.05)),
98
+ A.PairApply(T.ToTensor())
99
+ ]), assert_equal_length=True),
100
+ ImagesDataset(DATA_PATH['backgrounds']['train'], mode='RGB', transforms=T.Compose([
101
+ A.RandomAffineAndResize((2048, 2048), degrees=(-5, 5), translate=(0.1, 0.1), scale=(1, 2), shear=(-5, 5)),
102
+ T.RandomHorizontalFlip(),
103
+ A.RandomBoxBlur(0.1, 5),
104
+ A.RandomSharpen(0.1),
105
+ T.ColorJitter(0.15, 0.15, 0.15, 0.05),
106
+ T.ToTensor()
107
+ ])),
108
+ ])
109
+ dataset_train_len_per_gpu_worker = int(len(dataset_train) / distributed_num_gpus)
110
+ dataset_train = Subset(dataset_train, range(rank * dataset_train_len_per_gpu_worker, (rank + 1) * dataset_train_len_per_gpu_worker))
111
+ dataloader_train = DataLoader(dataset_train,
112
+ shuffle=True,
113
+ pin_memory=True,
114
+ drop_last=True,
115
+ batch_size=args.batch_size // distributed_num_gpus,
116
+ num_workers=args.num_workers // distributed_num_gpus)
117
+
118
+ # Validation DataLoader
119
+ if rank == 0:
120
+ dataset_valid = ZipDataset([
121
+ ZipDataset([
122
+ ImagesDataset(DATA_PATH[args.dataset_name]['valid']['pha'], mode='L'),
123
+ ImagesDataset(DATA_PATH[args.dataset_name]['valid']['fgr'], mode='RGB')
124
+ ], transforms=A.PairCompose([
125
+ A.PairRandomAffineAndResize((2048, 2048), degrees=(-5, 5), translate=(0.1, 0.1), scale=(0.3, 1), shear=(-5, 5)),
126
+ A.PairApply(T.ToTensor())
127
+ ]), assert_equal_length=True),
128
+ ImagesDataset(DATA_PATH['backgrounds']['valid'], mode='RGB', transforms=T.Compose([
129
+ A.RandomAffineAndResize((2048, 2048), degrees=(-5, 5), translate=(0.1, 0.1), scale=(1, 1.2), shear=(-5, 5)),
130
+ T.ToTensor()
131
+ ])),
132
+ ])
133
+ dataset_valid = SampleDataset(dataset_valid, 50)
134
+ dataloader_valid = DataLoader(dataset_valid,
135
+ pin_memory=True,
136
+ drop_last=True,
137
+ batch_size=args.batch_size // distributed_num_gpus,
138
+ num_workers=args.num_workers // distributed_num_gpus)
139
+
140
+ # Model
141
+ model = MattingRefine(args.model_backbone,
142
+ args.model_backbone_scale,
143
+ args.model_refine_mode,
144
+ args.model_refine_sample_pixels,
145
+ args.model_refine_thresholding,
146
+ args.model_refine_kernel_size).to(rank)
147
+ model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
148
+ model_distributed = nn.parallel.DistributedDataParallel(model, device_ids=[rank])
149
+
150
+ if args.model_last_checkpoint is not None:
151
+ load_matched_state_dict(model, torch.load(args.model_last_checkpoint))
152
+
153
+ optimizer = Adam([
154
+ {'params': model.backbone.parameters(), 'lr': 5e-5},
155
+ {'params': model.aspp.parameters(), 'lr': 5e-5},
156
+ {'params': model.decoder.parameters(), 'lr': 1e-4},
157
+ {'params': model.refiner.parameters(), 'lr': 3e-4},
158
+ ])
159
+ scaler = GradScaler()
160
+
161
+ # Logging and checkpoints
162
+ if rank == 0:
163
+ if not os.path.exists(f'checkpoint/{args.model_name}'):
164
+ os.makedirs(f'checkpoint/{args.model_name}')
165
+ writer = SummaryWriter(f'log/{args.model_name}')
166
+
167
+ # Run loop
168
+ for epoch in range(args.epoch_start, args.epoch_end):
169
+ for i, ((true_pha, true_fgr), true_bgr) in enumerate(tqdm(dataloader_train)):
170
+ step = epoch * len(dataloader_train) + i
171
+
172
+ true_pha = true_pha.to(rank, non_blocking=True)
173
+ true_fgr = true_fgr.to(rank, non_blocking=True)
174
+ true_bgr = true_bgr.to(rank, non_blocking=True)
175
+ true_pha, true_fgr, true_bgr = random_crop(true_pha, true_fgr, true_bgr)
176
+
177
+ true_src = true_bgr.clone()
178
+
179
+ # Augment with shadow
180
+ aug_shadow_idx = torch.rand(len(true_src)) < 0.3
181
+ if aug_shadow_idx.any():
182
+ aug_shadow = true_pha[aug_shadow_idx].mul(0.3 * random.random())
183
+ aug_shadow = T.RandomAffine(degrees=(-5, 5), translate=(0.2, 0.2), scale=(0.5, 1.5), shear=(-5, 5))(aug_shadow)
184
+ aug_shadow = kornia.filters.box_blur(aug_shadow, (random.choice(range(20, 40)),) * 2)
185
+ true_src[aug_shadow_idx] = true_src[aug_shadow_idx].sub_(aug_shadow).clamp_(0, 1)
186
+ del aug_shadow
187
+ del aug_shadow_idx
188
+
189
+ # Composite foreground onto source
190
+ true_src = true_fgr * true_pha + true_src * (1 - true_pha)
191
+
192
+ # Augment with noise
193
+ aug_noise_idx = torch.rand(len(true_src)) < 0.4
194
+ if aug_noise_idx.any():
195
+ true_src[aug_noise_idx] = true_src[aug_noise_idx].add_(torch.randn_like(true_src[aug_noise_idx]).mul_(0.03 * random.random())).clamp_(0, 1)
196
+ true_bgr[aug_noise_idx] = true_bgr[aug_noise_idx].add_(torch.randn_like(true_bgr[aug_noise_idx]).mul_(0.03 * random.random())).clamp_(0, 1)
197
+ del aug_noise_idx
198
+
199
+ # Augment background with jitter
200
+ aug_jitter_idx = torch.rand(len(true_src)) < 0.8
201
+ if aug_jitter_idx.any():
202
+ true_bgr[aug_jitter_idx] = kornia.augmentation.ColorJitter(0.18, 0.18, 0.18, 0.1)(true_bgr[aug_jitter_idx])
203
+ del aug_jitter_idx
204
+
205
+ # Augment background with affine
206
+ aug_affine_idx = torch.rand(len(true_bgr)) < 0.3
207
+ if aug_affine_idx.any():
208
+ true_bgr[aug_affine_idx] = T.RandomAffine(degrees=(-1, 1), translate=(0.01, 0.01))(true_bgr[aug_affine_idx])
209
+ del aug_affine_idx
210
+
211
+ with autocast():
212
+ pred_pha, pred_fgr, pred_pha_sm, pred_fgr_sm, pred_err_sm, _ = model_distributed(true_src, true_bgr)
213
+ loss = compute_loss(pred_pha, pred_fgr, pred_pha_sm, pred_fgr_sm, pred_err_sm, true_pha, true_fgr)
214
+
215
+ scaler.scale(loss).backward()
216
+ scaler.step(optimizer)
217
+ scaler.update()
218
+ optimizer.zero_grad()
219
+
220
+ if rank == 0:
221
+ if (i + 1) % args.log_train_loss_interval == 0:
222
+ writer.add_scalar('loss', loss, step)
223
+
224
+ if (i + 1) % args.log_train_images_interval == 0:
225
+ writer.add_image('train_pred_pha', make_grid(pred_pha, nrow=5), step)
226
+ writer.add_image('train_pred_fgr', make_grid(pred_fgr, nrow=5), step)
227
+ writer.add_image('train_pred_com', make_grid(pred_fgr * pred_pha, nrow=5), step)
228
+ writer.add_image('train_pred_err', make_grid(pred_err_sm, nrow=5), step)
229
+ writer.add_image('train_true_src', make_grid(true_src, nrow=5), step)
230
+
231
+ del true_pha, true_fgr, true_src, true_bgr
232
+ del pred_pha, pred_fgr, pred_pha_sm, pred_fgr_sm, pred_err_sm
233
+
234
+ if (i + 1) % args.log_valid_interval == 0:
235
+ valid(model, dataloader_valid, writer, step)
236
+
237
+ if (step + 1) % args.checkpoint_interval == 0:
238
+ torch.save(model.state_dict(), f'checkpoint/{args.model_name}/epoch-{epoch}-iter-{step}.pth')
239
+
240
+ if rank == 0:
241
+ torch.save(model.state_dict(), f'checkpoint/{args.model_name}/epoch-{epoch}.pth')
242
+
243
+ # Clean up
244
+ dist.destroy_process_group()
245
+
246
+
247
+ # --------------- Utils ---------------
248
+
249
+
250
+ def compute_loss(pred_pha_lg, pred_fgr_lg, pred_pha_sm, pred_fgr_sm, pred_err_sm, true_pha_lg, true_fgr_lg):
251
+ true_pha_sm = kornia.resize(true_pha_lg, pred_pha_sm.shape[2:])
252
+ true_fgr_sm = kornia.resize(true_fgr_lg, pred_fgr_sm.shape[2:])
253
+ true_msk_lg = true_pha_lg != 0
254
+ true_msk_sm = true_pha_sm != 0
255
+ return F.l1_loss(pred_pha_lg, true_pha_lg) + \
256
+ F.l1_loss(pred_pha_sm, true_pha_sm) + \
257
+ F.l1_loss(kornia.sobel(pred_pha_lg), kornia.sobel(true_pha_lg)) + \
258
+ F.l1_loss(kornia.sobel(pred_pha_sm), kornia.sobel(true_pha_sm)) + \
259
+ F.l1_loss(pred_fgr_lg * true_msk_lg, true_fgr_lg * true_msk_lg) + \
260
+ F.l1_loss(pred_fgr_sm * true_msk_sm, true_fgr_sm * true_msk_sm) + \
261
+ F.mse_loss(kornia.resize(pred_err_sm, true_pha_lg.shape[2:]), \
262
+ kornia.resize(pred_pha_sm, true_pha_lg.shape[2:]).sub(true_pha_lg).abs())
263
+
264
+
265
+ def random_crop(*imgs):
266
+ H_src, W_src = imgs[0].shape[2:]
267
+ W_tgt = random.choice(range(1024, 2048)) // 4 * 4
268
+ H_tgt = random.choice(range(1024, 2048)) // 4 * 4
269
+ scale = max(W_tgt / W_src, H_tgt / H_src)
270
+ results = []
271
+ for img in imgs:
272
+ img = kornia.resize(img, (int(H_src * scale), int(W_src * scale)))
273
+ img = kornia.center_crop(img, (H_tgt, W_tgt))
274
+ results.append(img)
275
+ return results
276
+
277
+
278
+ def valid(model, dataloader, writer, step):
279
+ model.eval()
280
+ loss_total = 0
281
+ loss_count = 0
282
+ with torch.no_grad():
283
+ for (true_pha, true_fgr), true_bgr in dataloader:
284
+ batch_size = true_pha.size(0)
285
+
286
+ true_pha = true_pha.cuda(non_blocking=True)
287
+ true_fgr = true_fgr.cuda(non_blocking=True)
288
+ true_bgr = true_bgr.cuda(non_blocking=True)
289
+ true_src = true_pha * true_fgr + (1 - true_pha) * true_bgr
290
+
291
+ pred_pha, pred_fgr, pred_pha_sm, pred_fgr_sm, pred_err_sm, _ = model(true_src, true_bgr)
292
+ loss = compute_loss(pred_pha, pred_fgr, pred_pha_sm, pred_fgr_sm, pred_err_sm, true_pha, true_fgr)
293
+ loss_total += loss.cpu().item() * batch_size
294
+ loss_count += batch_size
295
+
296
+ writer.add_scalar('valid_loss', loss_total / loss_count, step)
297
+ model.train()
298
+
299
+
300
+ # --------------- Start ---------------
301
+
302
+
303
+ if __name__ == '__main__':
304
+ addr = 'localhost'
305
+ port = str(random.choice(range(12300, 12400))) # pick a random port.
306
+ mp.spawn(train_worker,
307
+ nprocs=distributed_num_gpus,
308
+ args=(addr, port),
309
+ join=True)
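
As a final orientation (not part of this commit), a standalone sketch of the two-resolution supervision in compute_loss above: the coarse (sm) outputs are compared against a resized ground truth, and the coarse error map is regressed against the upsampled coarse alpha error. F.interpolate stands in for kornia.resize, and all tensors are random placeholders.

import torch
import torch.nn.functional as F

true_pha_lg = torch.rand(1, 1, 512, 512)   # full-resolution ground-truth alpha
pred_pha_sm = torch.rand(1, 1, 128, 128)   # coarse alpha from the base network (1/4 scale)
pred_err_sm = torch.rand(1, 1, 128, 128)   # coarse error map

# Coarse alpha is supervised against a resized ground truth.
true_pha_sm = F.interpolate(true_pha_lg, pred_pha_sm.shape[2:], mode='bilinear', align_corners=False)
coarse_alpha_loss = F.l1_loss(pred_pha_sm, true_pha_sm)

# The error map is regressed against |upsampled coarse alpha - full-resolution ground truth|.
err_target = F.interpolate(pred_pha_sm, true_pha_lg.shape[2:], mode='bilinear',
                           align_corners=False).sub(true_pha_lg).abs()
err_loss = F.mse_loss(F.interpolate(pred_err_sm, true_pha_lg.shape[2:], mode='bilinear',
                                    align_corners=False), err_target)
print(coarse_alpha_loss.item(), err_loss.item())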