Huang committed on
Commit 75889ad
1 Parent(s): 58cb10d
This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. annotator/__init__.py +27 -1
  2. annotator/__pycache__/__init__.cpython-39.pyc +0 -0
  3. annotator/leres/__init__.py +113 -0
  4. annotator/leres/__pycache__/__init__.cpython-39.pyc +0 -0
  5. annotator/leres/leres/LICENSE +23 -0
  6. annotator/leres/leres/Resnet.py +199 -0
  7. annotator/leres/leres/Resnext_torch.py +237 -0
  8. annotator/leres/leres/__pycache__/Resnet.cpython-39.pyc +0 -0
  9. annotator/leres/leres/__pycache__/Resnext_torch.cpython-39.pyc +0 -0
  10. annotator/leres/leres/__pycache__/depthmap.cpython-39.pyc +0 -0
  11. annotator/leres/leres/__pycache__/multi_depth_model_woauxi.cpython-39.pyc +0 -0
  12. annotator/leres/leres/__pycache__/net_tools.cpython-39.pyc +0 -0
  13. annotator/leres/leres/__pycache__/network_auxi.cpython-39.pyc +0 -0
  14. annotator/leres/leres/depthmap.py +566 -0
  15. annotator/leres/leres/multi_depth_model_woauxi.py +34 -0
  16. annotator/leres/leres/net_tools.py +54 -0
  17. annotator/leres/leres/network_auxi.py +417 -0
  18. annotator/leres/pix2pix/LICENSE +19 -0
  19. annotator/leres/pix2pix/models/__init__.py +67 -0
  20. annotator/leres/pix2pix/models/__pycache__/__init__.cpython-39.pyc +0 -0
  21. annotator/leres/pix2pix/models/__pycache__/base_model.cpython-39.pyc +0 -0
  22. annotator/leres/pix2pix/models/__pycache__/networks.cpython-39.pyc +0 -0
  23. annotator/leres/pix2pix/models/__pycache__/pix2pix4depth_model.cpython-39.pyc +0 -0
  24. annotator/leres/pix2pix/models/base_model.py +240 -0
  25. annotator/leres/pix2pix/models/base_model_hg.py +58 -0
  26. annotator/leres/pix2pix/models/networks.py +623 -0
  27. annotator/leres/pix2pix/models/pix2pix4depth_model.py +155 -0
  28. annotator/leres/pix2pix/options/__init__.py +1 -0
  29. annotator/leres/pix2pix/options/__pycache__/__init__.cpython-39.pyc +0 -0
  30. annotator/leres/pix2pix/options/__pycache__/base_options.cpython-39.pyc +0 -0
  31. annotator/leres/pix2pix/options/__pycache__/test_options.cpython-39.pyc +0 -0
  32. annotator/leres/pix2pix/options/base_options.py +156 -0
  33. annotator/leres/pix2pix/options/test_options.py +22 -0
  34. annotator/leres/pix2pix/util/__init__.py +1 -0
  35. annotator/leres/pix2pix/util/__pycache__/__init__.cpython-39.pyc +0 -0
  36. annotator/leres/pix2pix/util/__pycache__/util.cpython-39.pyc +0 -0
  37. annotator/leres/pix2pix/util/get_data.py +110 -0
  38. annotator/leres/pix2pix/util/guidedfilter.py +47 -0
  39. annotator/leres/pix2pix/util/html.py +86 -0
  40. annotator/leres/pix2pix/util/image_pool.py +54 -0
  41. annotator/leres/pix2pix/util/util.py +105 -0
  42. annotator/leres/pix2pix/util/visualizer.py +166 -0
  43. annotator/lineart/LICENSE +21 -0
  44. annotator/lineart/__init__.py +129 -0
  45. annotator/lineart/__pycache__/__init__.cpython-39.pyc +0 -0
  46. annotator/lineart_anime/LICENSE +21 -0
  47. annotator/lineart_anime/__init__.py +164 -0
  48. annotator/lineart_anime/__pycache__/__init__.cpython-39.pyc +0 -0
  49. annotator/manga_line/LICENSE +21 -0
  50. annotator/manga_line/__init__.py +247 -0
annotator/__init__.py CHANGED
@@ -6,9 +6,35 @@ from .hed import HedDetector
 from .midas import MidasProcessor
 from .mlsd import MLSDProcessor
 from .uniformer import UniformerDetector
+from .lineart import LineArtDetector
+from .lineart_anime import LineArtAnimeDetector
+from .manga_line import MangaLineExtration
+from .leres import LeresPix2Pix
+from .mediapipe_face import MediaPipeFace
+from .normalbae import NormalBaeDetector
+from .pidinet import PidInet
+from .shuffle import Image2MaskShuffleDetector
+from .zoe import ZoeDetector
+from .oneformer import OneformerDetector
 
 __all__ = [
-    UniformerDetector, HedDetector, MLSDProcessor, BinaryDetector, CannyDetector, OpenposeDetector, MidasProcessor
+    UniformerDetector,
+    HedDetector,
+    MLSDProcessor,
+    BinaryDetector,
+    CannyDetector,
+    OpenposeDetector,
+    MidasProcessor,
+    LineArtDetector,
+    LineArtAnimeDetector,
+    MangaLineExtration,
+    LeresPix2Pix,
+    MediaPipeFace,
+    NormalBaeDetector,
+    PidInet,
+    Image2MaskShuffleDetector,
+    ZoeDetector,
+    OneformerDetector
 ]
 #
 #
annotator/__pycache__/__init__.cpython-39.pyc CHANGED
Binary files a/annotator/__pycache__/__init__.cpython-39.pyc and b/annotator/__pycache__/__init__.cpython-39.pyc differ
 
annotator/leres/__init__.py ADDED
@@ -0,0 +1,113 @@
1
+ import cv2
2
+ import numpy as np
3
+ import torch
4
+ import os
5
+
6
+ # AdelaiDepth/LeReS imports
7
+ from .leres.depthmap import estimateleres, estimateboost
8
+ from .leres.multi_depth_model_woauxi import RelDepthModel
9
+ from .leres.net_tools import strip_prefix_if_present
10
+ from annotator.base_annotator import BaseProcessor
11
+
12
+ # pix2pix/merge net imports
13
+ from .pix2pix.options.test_options import TestOptions
14
+ from .pix2pix.models.pix2pix4depth_model import Pix2Pix4DepthModel
15
+
16
+ # old_modeldir = os.path.dirname(os.path.realpath(__file__))
17
+
18
+ remote_model_path_leres = "https://huggingface.co/lllyasviel/Annotators/resolve/main/res101.pth"
19
+ remote_model_path_pix2pix = "https://huggingface.co/lllyasviel/Annotators/resolve/main/latest_net_G.pth"
20
+
21
+
22
+ class LeresPix2Pix(BaseProcessor):
23
+ def __init__(self, **kwargs):
24
+ super().__init__(**kwargs)
25
+ self.model = None
26
+ self.pix2pixmodel = None
27
+ self.model_dir = os.path.join(self.models_path, "leres")
28
+
29
+ def unload_model(self):
30
+ if self.model is not None:
31
+ self.model = self.model.cpu()
32
+ if self.pix2pixmodel is not None:
33
+ self.pix2pixmodel = self.pix2pixmodel.unload_network('G')
34
+
35
+ def load_model(self):
36
+ model_path = os.path.join(self.model_dir, "res101.pth")
37
+ if not os.path.exists(model_path):
38
+ from basicsr.utils.download_util import load_file_from_url
39
+ load_file_from_url(remote_model_path_leres, model_dir=self.model_dir)
40
+
41
+ if torch.cuda.is_available():
42
+ checkpoint = torch.load(model_path)
43
+ else:
44
+ checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
45
+
46
+ self.model = RelDepthModel(backbone='resnext101')
47
+ self.model.load_state_dict(strip_prefix_if_present(checkpoint['depth_model'], "module."), strict=True)
48
+ del checkpoint
49
+
50
+ def load_pix2pix2_model(self):
51
+ pix2pixmodel_path = os.path.join(self.model_dir, "latest_net_G.pth")
52
+ if not os.path.exists(pix2pixmodel_path):
53
+ from basicsr.utils.download_util import load_file_from_url
54
+ load_file_from_url(remote_model_path_pix2pix, model_dir=self.model_dir)
55
+
56
+ opt = TestOptions().parse()
57
+ if not torch.cuda.is_available():
58
+ opt.gpu_ids = [] # cpu mode
59
+ self.pix2pixmodel = Pix2Pix4DepthModel(opt)
60
+ self.pix2pixmodel.save_dir = self.model_dir
61
+ self.pix2pixmodel.load_networks('latest')
62
+ self.pix2pixmodel.eval()
63
+
64
+ def __call__(self, input_image, thr_a, thr_b, boost=False, **kwargs):
65
+ if self.model is None:
66
+ self.load_model()
67
+ if boost and self.pix2pixmodel is None:
68
+ self.load_pix2pix2_model()
69
+
70
+ if self.device != 'mps':
71
+ self.model = self.model.to(self.device)
72
+
73
+ assert input_image.ndim == 3
74
+ height, width, dim = input_image.shape
75
+
76
+ with torch.no_grad():
77
+
78
+ if boost:
79
+ depth = estimateboost(input_image, self.model, 0, self.pix2pixmodel, max(width, height))
80
+ else:
81
+ depth = estimateleres(input_image, self.model, width, height, self.device)
82
+
83
+ numbytes = 2
84
+ depth_min = depth.min()
85
+ depth_max = depth.max()
86
+ max_val = (2 ** (8 * numbytes)) - 1
87
+
88
+ # check output before normalizing and mapping to 16 bit
89
+ if depth_max - depth_min > np.finfo("float").eps:
90
+ out = max_val * (depth - depth_min) / (depth_max - depth_min)
91
+ else:
92
+ out = np.zeros(depth.shape)
93
+
94
+ # single channel, 16 bit image
95
+ depth_image = out.astype("uint16")
96
+
97
+ # convert to uint8
98
+ depth_image = cv2.convertScaleAbs(depth_image, alpha=(255.0 / 65535.0))
99
+
100
+ # remove near
101
+ if thr_a != 0:
102
+ thr_a = ((thr_a / 100) * 255)
103
+ depth_image = cv2.threshold(depth_image, thr_a, 255, cv2.THRESH_TOZERO)[1]
104
+
105
+ # invert image
106
+ depth_image = cv2.bitwise_not(depth_image)
107
+
108
+ # remove bg
109
+ if thr_b != 0:
110
+ thr_b = ((thr_b / 100) * 255)
111
+ depth_image = cv2.threshold(depth_image, thr_b, 255, cv2.THRESH_TOZERO)[1]
112
+
113
+ return depth_image
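A minimal usage sketch of the new annotator (illustrative only, not part of this commit). It assumes BaseProcessor accepts models_path and device keyword arguments (its signature is not shown in this diff) and that the input is an H x W x 3 uint8 RGB array; thr_a and thr_b are near/background cut-offs given in percent:

import cv2
from annotator.leres import LeresPix2Pix

# Hypothetical constructor arguments -- BaseProcessor's signature is not part of this diff.
annotator = LeresPix2Pix(models_path="./models", device="cuda")
image = cv2.imread("input.png")[:, :, ::-1]                # BGR -> RGB, H x W x 3 uint8
depth = annotator(image, thr_a=0, thr_b=0, boost=False)    # single-channel uint8 depth map
cv2.imwrite("depth.png", depth)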
annotator/leres/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (3.54 kB).
 
annotator/leres/leres/LICENSE ADDED
@@ -0,0 +1,23 @@
1
+ https://github.com/thygate/stable-diffusion-webui-depthmap-script
2
+
3
+ MIT License
4
+
5
+ Copyright (c) 2023 Bob Thiry
6
+
7
+ Permission is hereby granted, free of charge, to any person obtaining a copy
8
+ of this software and associated documentation files (the "Software"), to deal
9
+ in the Software without restriction, including without limitation the rights
10
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11
+ copies of the Software, and to permit persons to whom the Software is
12
+ furnished to do so, subject to the following conditions:
13
+
14
+ The above copyright notice and this permission notice shall be included in all
15
+ copies or substantial portions of the Software.
16
+
17
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23
+ SOFTWARE.
annotator/leres/leres/Resnet.py ADDED
@@ -0,0 +1,199 @@
1
+ import torch.nn as nn
2
+ import torch.nn as NN
3
+
4
+ __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
5
+ 'resnet152']
6
+
7
+
8
+ model_urls = {
9
+ 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
10
+ 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
11
+ 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
12
+ 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
13
+ 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
14
+ }
15
+
16
+
17
+ def conv3x3(in_planes, out_planes, stride=1):
18
+ """3x3 convolution with padding"""
19
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
20
+ padding=1, bias=False)
21
+
22
+
23
+ class BasicBlock(nn.Module):
24
+ expansion = 1
25
+
26
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
27
+ super(BasicBlock, self).__init__()
28
+ self.conv1 = conv3x3(inplanes, planes, stride)
29
+ self.bn1 = NN.BatchNorm2d(planes) #NN.BatchNorm2d
30
+ self.relu = nn.ReLU(inplace=True)
31
+ self.conv2 = conv3x3(planes, planes)
32
+ self.bn2 = NN.BatchNorm2d(planes) #NN.BatchNorm2d
33
+ self.downsample = downsample
34
+ self.stride = stride
35
+
36
+ def forward(self, x):
37
+ residual = x
38
+
39
+ out = self.conv1(x)
40
+ out = self.bn1(out)
41
+ out = self.relu(out)
42
+
43
+ out = self.conv2(out)
44
+ out = self.bn2(out)
45
+
46
+ if self.downsample is not None:
47
+ residual = self.downsample(x)
48
+
49
+ out += residual
50
+ out = self.relu(out)
51
+
52
+ return out
53
+
54
+
55
+ class Bottleneck(nn.Module):
56
+ expansion = 4
57
+
58
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
59
+ super(Bottleneck, self).__init__()
60
+ self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
61
+ self.bn1 = NN.BatchNorm2d(planes) #NN.BatchNorm2d
62
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
63
+ padding=1, bias=False)
64
+ self.bn2 = NN.BatchNorm2d(planes) #NN.BatchNorm2d
65
+ self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
66
+ self.bn3 = NN.BatchNorm2d(planes * self.expansion) #NN.BatchNorm2d
67
+ self.relu = nn.ReLU(inplace=True)
68
+ self.downsample = downsample
69
+ self.stride = stride
70
+
71
+ def forward(self, x):
72
+ residual = x
73
+
74
+ out = self.conv1(x)
75
+ out = self.bn1(out)
76
+ out = self.relu(out)
77
+
78
+ out = self.conv2(out)
79
+ out = self.bn2(out)
80
+ out = self.relu(out)
81
+
82
+ out = self.conv3(out)
83
+ out = self.bn3(out)
84
+
85
+ if self.downsample is not None:
86
+ residual = self.downsample(x)
87
+
88
+ out += residual
89
+ out = self.relu(out)
90
+
91
+ return out
92
+
93
+
94
+ class ResNet(nn.Module):
95
+
96
+ def __init__(self, block, layers, num_classes=1000):
97
+ self.inplanes = 64
98
+ super(ResNet, self).__init__()
99
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
100
+ bias=False)
101
+ self.bn1 = NN.BatchNorm2d(64) #NN.BatchNorm2d
102
+ self.relu = nn.ReLU(inplace=True)
103
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
104
+ self.layer1 = self._make_layer(block, 64, layers[0])
105
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
106
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
107
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
108
+ #self.avgpool = nn.AvgPool2d(7, stride=1)
109
+ #self.fc = nn.Linear(512 * block.expansion, num_classes)
110
+
111
+ for m in self.modules():
112
+ if isinstance(m, nn.Conv2d):
113
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
114
+ elif isinstance(m, nn.BatchNorm2d):
115
+ nn.init.constant_(m.weight, 1)
116
+ nn.init.constant_(m.bias, 0)
117
+
118
+ def _make_layer(self, block, planes, blocks, stride=1):
119
+ downsample = None
120
+ if stride != 1 or self.inplanes != planes * block.expansion:
121
+ downsample = nn.Sequential(
122
+ nn.Conv2d(self.inplanes, planes * block.expansion,
123
+ kernel_size=1, stride=stride, bias=False),
124
+ NN.BatchNorm2d(planes * block.expansion), #NN.BatchNorm2d
125
+ )
126
+
127
+ layers = []
128
+ layers.append(block(self.inplanes, planes, stride, downsample))
129
+ self.inplanes = planes * block.expansion
130
+ for i in range(1, blocks):
131
+ layers.append(block(self.inplanes, planes))
132
+
133
+ return nn.Sequential(*layers)
134
+
135
+ def forward(self, x):
136
+ features = []
137
+
138
+ x = self.conv1(x)
139
+ x = self.bn1(x)
140
+ x = self.relu(x)
141
+ x = self.maxpool(x)
142
+
143
+ x = self.layer1(x)
144
+ features.append(x)
145
+ x = self.layer2(x)
146
+ features.append(x)
147
+ x = self.layer3(x)
148
+ features.append(x)
149
+ x = self.layer4(x)
150
+ features.append(x)
151
+
152
+ return features
153
+
154
+
155
+ def resnet18(pretrained=True, **kwargs):
156
+ """Constructs a ResNet-18 model.
157
+ Args:
158
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
159
+ """
160
+ model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
161
+ return model
162
+
163
+
164
+ def resnet34(pretrained=True, **kwargs):
165
+ """Constructs a ResNet-34 model.
166
+ Args:
167
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
168
+ """
169
+ model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
170
+ return model
171
+
172
+
173
+ def resnet50(pretrained=True, **kwargs):
174
+ """Constructs a ResNet-50 model.
175
+ Args:
176
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
177
+ """
178
+ model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
179
+
180
+ return model
181
+
182
+
183
+ def resnet101(pretrained=True, **kwargs):
184
+ """Constructs a ResNet-101 model.
185
+ Args:
186
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
187
+ """
188
+ model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
189
+
190
+ return model
191
+
192
+
193
+ def resnet152(pretrained=True, **kwargs):
194
+ """Constructs a ResNet-152 model.
195
+ Args:
196
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
197
+ """
198
+ model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
199
+ return model
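As a sanity check of the truncated backbone above (classification head removed, the four intermediate feature maps are returned instead), a small sketch; the channel counts 256/512/1024/2048 at strides 4/8/16/32 are what the decoder in network_auxi.py consumes. The import path follows this repository's layout:

import torch
from annotator.leres.leres.Resnet import resnet50

model = resnet50().eval()
with torch.no_grad():
    features = model(torch.randn(1, 3, 224, 224))   # list of four feature maps
for f in features:
    print(tuple(f.shape))
# (1, 256, 56, 56), (1, 512, 28, 28), (1, 1024, 14, 14), (1, 2048, 7, 7)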
annotator/leres/leres/Resnext_torch.py ADDED
@@ -0,0 +1,237 @@
1
+ #!/usr/bin/env python
2
+ # coding: utf-8
3
+ import torch.nn as nn
4
+
5
+ try:
6
+ from urllib import urlretrieve
7
+ except ImportError:
8
+ from urllib.request import urlretrieve
9
+
10
+ __all__ = ['resnext101_32x8d']
11
+
12
+
13
+ model_urls = {
14
+ 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
15
+ 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
16
+ }
17
+
18
+
19
+ def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
20
+ """3x3 convolution with padding"""
21
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
22
+ padding=dilation, groups=groups, bias=False, dilation=dilation)
23
+
24
+
25
+ def conv1x1(in_planes, out_planes, stride=1):
26
+ """1x1 convolution"""
27
+ return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
28
+
29
+
30
+ class BasicBlock(nn.Module):
31
+ expansion = 1
32
+
33
+ def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
34
+ base_width=64, dilation=1, norm_layer=None):
35
+ super(BasicBlock, self).__init__()
36
+ if norm_layer is None:
37
+ norm_layer = nn.BatchNorm2d
38
+ if groups != 1 or base_width != 64:
39
+ raise ValueError('BasicBlock only supports groups=1 and base_width=64')
40
+ if dilation > 1:
41
+ raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
42
+ # Both self.conv1 and self.downsample layers downsample the input when stride != 1
43
+ self.conv1 = conv3x3(inplanes, planes, stride)
44
+ self.bn1 = norm_layer(planes)
45
+ self.relu = nn.ReLU(inplace=True)
46
+ self.conv2 = conv3x3(planes, planes)
47
+ self.bn2 = norm_layer(planes)
48
+ self.downsample = downsample
49
+ self.stride = stride
50
+
51
+ def forward(self, x):
52
+ identity = x
53
+
54
+ out = self.conv1(x)
55
+ out = self.bn1(out)
56
+ out = self.relu(out)
57
+
58
+ out = self.conv2(out)
59
+ out = self.bn2(out)
60
+
61
+ if self.downsample is not None:
62
+ identity = self.downsample(x)
63
+
64
+ out += identity
65
+ out = self.relu(out)
66
+
67
+ return out
68
+
69
+
70
+ class Bottleneck(nn.Module):
71
+ # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
72
+ # while original implementation places the stride at the first 1x1 convolution(self.conv1)
73
+ # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
74
+ # This variant is also known as ResNet V1.5 and improves accuracy according to
75
+ # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
76
+
77
+ expansion = 4
78
+
79
+ def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
80
+ base_width=64, dilation=1, norm_layer=None):
81
+ super(Bottleneck, self).__init__()
82
+ if norm_layer is None:
83
+ norm_layer = nn.BatchNorm2d
84
+ width = int(planes * (base_width / 64.)) * groups
85
+ # Both self.conv2 and self.downsample layers downsample the input when stride != 1
86
+ self.conv1 = conv1x1(inplanes, width)
87
+ self.bn1 = norm_layer(width)
88
+ self.conv2 = conv3x3(width, width, stride, groups, dilation)
89
+ self.bn2 = norm_layer(width)
90
+ self.conv3 = conv1x1(width, planes * self.expansion)
91
+ self.bn3 = norm_layer(planes * self.expansion)
92
+ self.relu = nn.ReLU(inplace=True)
93
+ self.downsample = downsample
94
+ self.stride = stride
95
+
96
+ def forward(self, x):
97
+ identity = x
98
+
99
+ out = self.conv1(x)
100
+ out = self.bn1(out)
101
+ out = self.relu(out)
102
+
103
+ out = self.conv2(out)
104
+ out = self.bn2(out)
105
+ out = self.relu(out)
106
+
107
+ out = self.conv3(out)
108
+ out = self.bn3(out)
109
+
110
+ if self.downsample is not None:
111
+ identity = self.downsample(x)
112
+
113
+ out += identity
114
+ out = self.relu(out)
115
+
116
+ return out
117
+
118
+
119
+ class ResNet(nn.Module):
120
+
121
+ def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
122
+ groups=1, width_per_group=64, replace_stride_with_dilation=None,
123
+ norm_layer=None):
124
+ super(ResNet, self).__init__()
125
+ if norm_layer is None:
126
+ norm_layer = nn.BatchNorm2d
127
+ self._norm_layer = norm_layer
128
+
129
+ self.inplanes = 64
130
+ self.dilation = 1
131
+ if replace_stride_with_dilation is None:
132
+ # each element in the tuple indicates if we should replace
133
+ # the 2x2 stride with a dilated convolution instead
134
+ replace_stride_with_dilation = [False, False, False]
135
+ if len(replace_stride_with_dilation) != 3:
136
+ raise ValueError("replace_stride_with_dilation should be None "
137
+ "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
138
+ self.groups = groups
139
+ self.base_width = width_per_group
140
+ self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
141
+ bias=False)
142
+ self.bn1 = norm_layer(self.inplanes)
143
+ self.relu = nn.ReLU(inplace=True)
144
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
145
+ self.layer1 = self._make_layer(block, 64, layers[0])
146
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
147
+ dilate=replace_stride_with_dilation[0])
148
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
149
+ dilate=replace_stride_with_dilation[1])
150
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
151
+ dilate=replace_stride_with_dilation[2])
152
+ #self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
153
+ #self.fc = nn.Linear(512 * block.expansion, num_classes)
154
+
155
+ for m in self.modules():
156
+ if isinstance(m, nn.Conv2d):
157
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
158
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
159
+ nn.init.constant_(m.weight, 1)
160
+ nn.init.constant_(m.bias, 0)
161
+
162
+ # Zero-initialize the last BN in each residual branch,
163
+ # so that the residual branch starts with zeros, and each residual block behaves like an identity.
164
+ # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
165
+ if zero_init_residual:
166
+ for m in self.modules():
167
+ if isinstance(m, Bottleneck):
168
+ nn.init.constant_(m.bn3.weight, 0)
169
+ elif isinstance(m, BasicBlock):
170
+ nn.init.constant_(m.bn2.weight, 0)
171
+
172
+ def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
173
+ norm_layer = self._norm_layer
174
+ downsample = None
175
+ previous_dilation = self.dilation
176
+ if dilate:
177
+ self.dilation *= stride
178
+ stride = 1
179
+ if stride != 1 or self.inplanes != planes * block.expansion:
180
+ downsample = nn.Sequential(
181
+ conv1x1(self.inplanes, planes * block.expansion, stride),
182
+ norm_layer(planes * block.expansion),
183
+ )
184
+
185
+ layers = []
186
+ layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
187
+ self.base_width, previous_dilation, norm_layer))
188
+ self.inplanes = planes * block.expansion
189
+ for _ in range(1, blocks):
190
+ layers.append(block(self.inplanes, planes, groups=self.groups,
191
+ base_width=self.base_width, dilation=self.dilation,
192
+ norm_layer=norm_layer))
193
+
194
+ return nn.Sequential(*layers)
195
+
196
+ def _forward_impl(self, x):
197
+ # See note [TorchScript super()]
198
+ features = []
199
+ x = self.conv1(x)
200
+ x = self.bn1(x)
201
+ x = self.relu(x)
202
+ x = self.maxpool(x)
203
+
204
+ x = self.layer1(x)
205
+ features.append(x)
206
+
207
+ x = self.layer2(x)
208
+ features.append(x)
209
+
210
+ x = self.layer3(x)
211
+ features.append(x)
212
+
213
+ x = self.layer4(x)
214
+ features.append(x)
215
+
216
+ #x = self.avgpool(x)
217
+ #x = torch.flatten(x, 1)
218
+ #x = self.fc(x)
219
+
220
+ return features
221
+
222
+ def forward(self, x):
223
+ return self._forward_impl(x)
224
+
225
+
226
+
227
+ def resnext101_32x8d(pretrained=True, **kwargs):
228
+ """Constructs a ResNeXt-101 32x8d model.
229
+ Args:
230
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
231
+ """
232
+ kwargs['groups'] = 32
233
+ kwargs['width_per_group'] = 8
234
+
235
+ model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
236
+ return model
237
+
annotator/leres/leres/__pycache__/Resnet.cpython-39.pyc ADDED
Binary file (5.69 kB).
 
annotator/leres/leres/__pycache__/Resnext_torch.cpython-39.pyc ADDED
Binary file (5.85 kB).
 
annotator/leres/leres/__pycache__/depthmap.cpython-39.pyc ADDED
Binary file (11.6 kB).
 
annotator/leres/leres/__pycache__/multi_depth_model_woauxi.cpython-39.pyc ADDED
Binary file (1.71 kB).
 
annotator/leres/leres/__pycache__/net_tools.cpython-39.pyc ADDED
Binary file (1.89 kB).
 
annotator/leres/leres/__pycache__/network_auxi.cpython-39.pyc ADDED
Binary file (11.2 kB).
 
annotator/leres/leres/depthmap.py ADDED
@@ -0,0 +1,566 @@
1
+ # Author: thygate
2
+ # https://github.com/thygate/stable-diffusion-webui-depthmap-script
3
+
4
+ # from modules import devices
5
+ # from modules.shared import opts
6
+ from torchvision.transforms import transforms
7
+ from operator import getitem
8
+
9
+ import torch, gc
10
+ import cv2
11
+ import numpy as np
12
+ import skimage.measure
13
+
14
+ whole_size_threshold = 1600 # R_max from the paper
15
+ pix2pixsize = 1024
16
+
17
+
18
+ def scale_torch(img):
19
+ """
20
+ Scale the image and output it in torch.tensor.
21
+ :param img: input rgb is in shape [H, W, C], input depth/disp is in shape [H, W]
22
+ :param scale: the scale factor. float
23
+ :return: img. [C, H, W]
24
+ """
25
+ if len(img.shape) == 2:
26
+ img = img[np.newaxis, :, :]
27
+ if img.shape[2] == 3:
28
+ transform = transforms.Compose(
29
+ [transforms.ToTensor(), transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])
30
+ img = transform(img.astype(np.float32))
31
+ else:
32
+ img = img.astype(np.float32)
33
+ img = torch.from_numpy(img)
34
+ return img
35
+
36
+
37
+ def estimateleres(img, model, w, h, device="cpu"):
38
+ # leres transform input
39
+ rgb_c = img[:, :, ::-1].copy()
40
+ A_resize = cv2.resize(rgb_c, (w, h))
41
+ img_torch = scale_torch(A_resize)[None, :, :, :]
42
+
43
+ # compute
44
+ with torch.no_grad():
45
+ img_torch = img_torch.to(device)
46
+ prediction = model.depth_model(img_torch)
47
+
48
+ prediction = prediction.squeeze().cpu().numpy()
49
+ prediction = cv2.resize(prediction, (img.shape[1], img.shape[0]), interpolation=cv2.INTER_CUBIC)
50
+
51
+ return prediction
52
+
53
+
54
+ def generatemask(size):
55
+ # Generates a Gaussian mask
56
+ mask = np.zeros(size, dtype=np.float32)
57
+ sigma = int(size[0] / 16)
58
+ k_size = int(2 * np.ceil(2 * int(size[0] / 16)) + 1)
59
+ mask[int(0.15 * size[0]):size[0] - int(0.15 * size[0]), int(0.15 * size[1]): size[1] - int(0.15 * size[1])] = 1
60
+ mask = cv2.GaussianBlur(mask, (int(k_size), int(k_size)), sigma)
61
+ mask = (mask - mask.min()) / (mask.max() - mask.min())
62
+ mask = mask.astype(np.float32)
63
+ return mask
64
+
65
+
66
+ def resizewithpool(img, size):
67
+ i_size = img.shape[0]
68
+ n = int(np.floor(i_size / size))
69
+
70
+ out = skimage.measure.block_reduce(img, (n, n), np.max)
71
+ return out
72
+
73
+
74
+ def rgb2gray(rgb):
75
+ # Converts rgb to gray
76
+ return np.dot(rgb[..., :3], [0.2989, 0.5870, 0.1140])
77
+
78
+
79
+ def calculateprocessingres(img, basesize, confidence=0.1, scale_threshold=3, whole_size_threshold=3000):
80
+ # Returns the R_x resolution described in section 5 of the main paper.
81
+
82
+ # Parameters:
83
+ # img :input rgb image
84
+ # basesize : size of the dilation kernel, which is equal to the receptive field of the network.
85
+ # confidence: value of x in R_x; allowed percentage of pixels that are not getting any contextual cue.
86
+ # scale_threshold: maximum allowed upscaling on the input image ; it has been set to 3.
87
+ # whole_size_threshold: maximum allowed resolution. (R_max from section 6 of the main paper)
88
+
89
+ # Returns:
90
+ # outputsize_scale*speed_scale :The computed R_x resolution
91
+ # patch_scale: K parameter from section 6 of the paper
92
+
93
+ # speed scale parameter is to process every image in a smaller size to accelerate the R_x resolution search
94
+ speed_scale = 32
95
+ image_dim = int(min(img.shape[0:2]))
96
+
97
+ gray = rgb2gray(img)
98
+ grad = np.abs(cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=3)) + np.abs(cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3))
99
+ grad = cv2.resize(grad, (image_dim, image_dim), cv2.INTER_AREA)
100
+
101
+ # thresholding the gradient map to generate the edge-map as a proxy of the contextual cues
102
+ m = grad.min()
103
+ M = grad.max()
104
+ middle = m + (0.4 * (M - m))
105
+ grad[grad < middle] = 0
106
+ grad[grad >= middle] = 1
107
+
108
+ # dilation kernel with size of the receptive field
109
+ kernel = np.ones((int(basesize / speed_scale), int(basesize / speed_scale)), float)
110
+ # dilation kernel with size of a quarter of the receptive field, used to compute k
111
+ # as described in section 6 of main paper
112
+ kernel2 = np.ones((int(basesize / (4 * speed_scale)), int(basesize / (4 * speed_scale))), float)
113
+
114
+ # Output resolution limit set by the whole_size_threshold and scale_threshold.
115
+ threshold = min(whole_size_threshold, scale_threshold * max(img.shape[:2]))
116
+
117
+ outputsize_scale = basesize / speed_scale
118
+ for p_size in range(int(basesize / speed_scale), int(threshold / speed_scale), int(basesize / (2 * speed_scale))):
119
+ grad_resized = resizewithpool(grad, p_size)
120
+ grad_resized = cv2.resize(grad_resized, (p_size, p_size), cv2.INTER_NEAREST)
121
+ grad_resized[grad_resized >= 0.5] = 1
122
+ grad_resized[grad_resized < 0.5] = 0
123
+
124
+ dilated = cv2.dilate(grad_resized, kernel, iterations=1)
125
+ meanvalue = (1 - dilated).mean()
126
+ if meanvalue > confidence:
127
+ break
128
+ else:
129
+ outputsize_scale = p_size
130
+
131
+ grad_region = cv2.dilate(grad_resized, kernel2, iterations=1)
132
+ patch_scale = grad_region.mean()
133
+
134
+ return int(outputsize_scale * speed_scale), patch_scale
135
+
136
+
137
+ # Generate a double-input depth estimation
138
+ def doubleestimate(img, size1, size2, pix2pixsize, model, net_type, pix2pixmodel):
139
+ # Generate the low resolution estimation
140
+ estimate1 = singleestimate(img, size1, model, net_type)
141
+ # Resize to the inference size of merge network.
142
+ estimate1 = cv2.resize(estimate1, (pix2pixsize, pix2pixsize), interpolation=cv2.INTER_CUBIC)
143
+
144
+ # Generate the high resolution estimation
145
+ estimate2 = singleestimate(img, size2, model, net_type)
146
+ # Resize to the inference size of merge network.
147
+ estimate2 = cv2.resize(estimate2, (pix2pixsize, pix2pixsize), interpolation=cv2.INTER_CUBIC)
148
+
149
+ # Inference on the merge model
150
+ pix2pixmodel.set_input(estimate1, estimate2)
151
+ pix2pixmodel.test()
152
+ visuals = pix2pixmodel.get_current_visuals()
153
+ prediction_mapped = visuals['fake_B']
154
+ prediction_mapped = (prediction_mapped + 1) / 2
155
+ prediction_mapped = (prediction_mapped - torch.min(prediction_mapped)) / (
156
+ torch.max(prediction_mapped) - torch.min(prediction_mapped))
157
+ prediction_mapped = prediction_mapped.squeeze().cpu().numpy()
158
+
159
+ return prediction_mapped
160
+
161
+
162
+ # Generate a single-input depth estimation
163
+ def singleestimate(img, msize, model, net_type, device="cpu"):
164
+ # if net_type == 0:
165
+ return estimateleres(img, model, msize, msize, device)
166
+ # else:
167
+ # return estimatemidasBoost(img, model, msize, msize)
168
+
169
+
170
+ def applyGridpatch(blsize, stride, img, box):
171
+ # Extract a simple grid patch.
172
+ counter1 = 0
173
+ patch_bound_list = {}
174
+ for k in range(blsize, img.shape[1] - blsize, stride):
175
+ for j in range(blsize, img.shape[0] - blsize, stride):
176
+ patch_bound_list[str(counter1)] = {}
177
+ patchbounds = [j - blsize, k - blsize, j - blsize + 2 * blsize, k - blsize + 2 * blsize]
178
+ patch_bound = [box[0] + patchbounds[1], box[1] + patchbounds[0], patchbounds[3] - patchbounds[1],
179
+ patchbounds[2] - patchbounds[0]]
180
+ patch_bound_list[str(counter1)]['rect'] = patch_bound
181
+ patch_bound_list[str(counter1)]['size'] = patch_bound[2]
182
+ counter1 = counter1 + 1
183
+ return patch_bound_list
184
+
185
+
186
+ # Generating local patches to perform the local refinement described in section 6 of the main paper.
187
+ def generatepatchs(img, base_size):
188
+ # Compute the gradients as a proxy of the contextual cues.
189
+ img_gray = rgb2gray(img)
190
+ whole_grad = np.abs(cv2.Sobel(img_gray, cv2.CV_64F, 0, 1, ksize=3)) + \
191
+ np.abs(cv2.Sobel(img_gray, cv2.CV_64F, 1, 0, ksize=3))
192
+
193
+ threshold = whole_grad[whole_grad > 0].mean()
194
+ whole_grad[whole_grad < threshold] = 0
195
+
196
+ # We use the integral image to speed-up the evaluation of the amount of gradients for each patch.
197
+ gf = whole_grad.sum() / len(whole_grad.reshape(-1))
198
+ grad_integral_image = cv2.integral(whole_grad)
199
+
200
+ # Variables are selected such that the initial patch size would be the receptive field size
201
+ # and the stride is set to 1/3 of the receptive field size.
202
+ blsize = int(round(base_size / 2))
203
+ stride = int(round(blsize * 0.75))
204
+
205
+ # Get initial Grid
206
+ patch_bound_list = applyGridpatch(blsize, stride, img, [0, 0, 0, 0])
207
+
208
+ # Refine initial Grid of patches by discarding the flat (in terms of gradients of the rgb image) ones. Refine
209
+ # each patch size to ensure that there will be enough depth cues for the network to generate a consistent depth map.
210
+ print("Selecting patches ...")
211
+ patch_bound_list = adaptiveselection(grad_integral_image, patch_bound_list, gf)
212
+
213
+ # Sort the patch list to make sure the merging operation will be done with the correct order: starting from biggest
214
+ # patch
215
+ patchset = sorted(patch_bound_list.items(), key=lambda x: getitem(x[1], 'size'), reverse=True)
216
+ return patchset
217
+
218
+
219
+ def getGF_fromintegral(integralimage, rect):
220
+ # Computes the gradient density of a given patch from the gradient integral image.
221
+ x1 = rect[1]
222
+ x2 = rect[1] + rect[3]
223
+ y1 = rect[0]
224
+ y2 = rect[0] + rect[2]
225
+ value = integralimage[x2, y2] - integralimage[x1, y2] - integralimage[x2, y1] + integralimage[x1, y1]
226
+ return value
227
+
228
+
229
+ # Adaptively select patches
230
+ def adaptiveselection(integral_grad, patch_bound_list, gf):
231
+ patchlist = {}
232
+ count = 0
233
+ height, width = integral_grad.shape
234
+
235
+ search_step = int(32 / factor)
236
+
237
+ # Go through all patches
238
+ for c in range(len(patch_bound_list)):
239
+ # Get patch
240
+ bbox = patch_bound_list[str(c)]['rect']
241
+
242
+ # Compute the amount of gradients present in the patch from the integral image.
243
+ cgf = getGF_fromintegral(integral_grad, bbox) / (bbox[2] * bbox[3])
244
+
245
+ # Check if patching is beneficial by comparing the gradient density of the patch to
246
+ # the gradient density of the whole image
247
+ if cgf >= gf:
248
+ bbox_test = bbox.copy()
249
+ patchlist[str(count)] = {}
250
+
251
+ # Enlarge each patch until the gradient density of the patch is equal
252
+ # to the whole image gradient density
253
+ while True:
254
+
255
+ bbox_test[0] = bbox_test[0] - int(search_step / 2)
256
+ bbox_test[1] = bbox_test[1] - int(search_step / 2)
257
+
258
+ bbox_test[2] = bbox_test[2] + search_step
259
+ bbox_test[3] = bbox_test[3] + search_step
260
+
261
+ # Check if we are still within the image
262
+ if bbox_test[0] < 0 or bbox_test[1] < 0 or bbox_test[1] + bbox_test[3] >= height \
263
+ or bbox_test[0] + bbox_test[2] >= width:
264
+ break
265
+
266
+ # Compare gradient density
267
+ cgf = getGF_fromintegral(integral_grad, bbox_test) / (bbox_test[2] * bbox_test[3])
268
+ if cgf < gf:
269
+ break
270
+ bbox = bbox_test.copy()
271
+
272
+ # Add patch to selected patches
273
+ patchlist[str(count)]['rect'] = bbox
274
+ patchlist[str(count)]['size'] = bbox[2]
275
+ count = count + 1
276
+
277
+ # Return selected patches
278
+ return patchlist
279
+
280
+
281
+ def impatch(image, rect):
282
+ # Extract the given patch pixels from a given image.
283
+ w1 = rect[0]
284
+ h1 = rect[1]
285
+ w2 = w1 + rect[2]
286
+ h2 = h1 + rect[3]
287
+ image_patch = image[h1:h2, w1:w2]
288
+ return image_patch
289
+
290
+
291
+ class ImageandPatchs:
292
+ def __init__(self, root_dir, name, patchsinfo, rgb_image, scale=1):
293
+ self.root_dir = root_dir
294
+ self.patchsinfo = patchsinfo
295
+ self.name = name
296
+ self.patchs = patchsinfo
297
+ self.scale = scale
298
+
299
+ self.rgb_image = cv2.resize(rgb_image, (round(rgb_image.shape[1] * scale), round(rgb_image.shape[0] * scale)),
300
+ interpolation=cv2.INTER_CUBIC)
301
+
302
+ self.do_have_estimate = False
303
+ self.estimation_updated_image = None
304
+ self.estimation_base_image = None
305
+
306
+ def __len__(self):
307
+ return len(self.patchs)
308
+
309
+ def set_base_estimate(self, est):
310
+ self.estimation_base_image = est
311
+ if self.estimation_updated_image is not None:
312
+ self.do_have_estimate = True
313
+
314
+ def set_updated_estimate(self, est):
315
+ self.estimation_updated_image = est
316
+ if self.estimation_base_image is not None:
317
+ self.do_have_estimate = True
318
+
319
+ def __getitem__(self, index):
320
+ patch_id = int(self.patchs[index][0])
321
+ rect = np.array(self.patchs[index][1]['rect'])
322
+ msize = self.patchs[index][1]['size']
323
+
324
+ ## applying scale to rect:
325
+ rect = np.round(rect * self.scale)
326
+ rect = rect.astype('int')
327
+ msize = round(msize * self.scale)
328
+
329
+ patch_rgb = impatch(self.rgb_image, rect)
330
+ if self.do_have_estimate:
331
+ patch_whole_estimate_base = impatch(self.estimation_base_image, rect)
332
+ patch_whole_estimate_updated = impatch(self.estimation_updated_image, rect)
333
+ return {'patch_rgb': patch_rgb, 'patch_whole_estimate_base': patch_whole_estimate_base,
334
+ 'patch_whole_estimate_updated': patch_whole_estimate_updated, 'rect': rect,
335
+ 'size': msize, 'id': patch_id}
336
+ else:
337
+ return {'patch_rgb': patch_rgb, 'rect': rect, 'size': msize, 'id': patch_id}
338
+
339
+ def print_options(self, opt):
340
+ """Print and save options
341
+
342
+ It will print both current options and default values(if different).
343
+ It will save options into a text file / [checkpoints_dir] / opt.txt
344
+ """
345
+ message = ''
346
+ message += '----------------- Options ---------------\n'
347
+ for k, v in sorted(vars(opt).items()):
348
+ comment = ''
349
+ default = self.parser.get_default(k)
350
+ if v != default:
351
+ comment = '\t[default: %s]' % str(default)
352
+ message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
353
+ message += '----------------- End -------------------'
354
+ print(message)
355
+
356
+ # save to the disk
357
+ """
358
+ expr_dir = os.path.join(opt.checkpoints_dir, opt.name)
359
+ util.mkdirs(expr_dir)
360
+ file_name = os.path.join(expr_dir, '{}_opt.txt'.format(opt.phase))
361
+ with open(file_name, 'wt') as opt_file:
362
+ opt_file.write(message)
363
+ opt_file.write('\n')
364
+ """
365
+
366
+ def parse(self):
367
+ """Parse our options, create checkpoints directory suffix, and set up gpu device."""
368
+ opt = self.gather_options()
369
+ opt.isTrain = self.isTrain # train or test
370
+
371
+ # process opt.suffix
372
+ if opt.suffix:
373
+ suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else ''
374
+ opt.name = opt.name + suffix
375
+
376
+ # self.print_options(opt)
377
+
378
+ # set gpu ids
379
+ str_ids = opt.gpu_ids.split(',')
380
+ opt.gpu_ids = []
381
+ for str_id in str_ids:
382
+ id = int(str_id)
383
+ if id >= 0:
384
+ opt.gpu_ids.append(id)
385
+ # if len(opt.gpu_ids) > 0:
386
+ # torch.cuda.set_device(opt.gpu_ids[0])
387
+
388
+ self.opt = opt
389
+ return self.opt
390
+
391
+
392
+ def estimateboost(img, model, model_type, pix2pixmodel, max_res=512):
393
+ global whole_size_threshold
394
+
395
+ # get settings
396
+ # if hasattr(opts, 'depthmap_script_boost_rmax'):
397
+ # whole_size_threshold = opts.depthmap_script_boost_rmax
398
+
399
+ if model_type == 0: # leres
400
+ net_receptive_field_size = 448
401
+ patch_netsize = 2 * net_receptive_field_size
402
+ elif model_type == 1: # dpt_beit_large_512
403
+ net_receptive_field_size = 512
404
+ patch_netsize = 2 * net_receptive_field_size
405
+ else: # other midas
406
+ net_receptive_field_size = 384
407
+ patch_netsize = 2 * net_receptive_field_size
408
+
409
+ gc.collect()
410
+ # devices.torch_gc()
411
+
412
+ # Generate mask used to smoothly blend the local patch estimations into the base estimate.
413
+ # It is arbitrarily large to avoid artifacts during rescaling for each crop.
414
+ mask_org = generatemask((3000, 3000))
415
+ mask = mask_org.copy()
416
+
417
+ # Value x of R_x defined in the section 5 of the main paper.
418
+ r_threshold_value = 0.2
419
+ # if R0:
420
+ # r_threshold_value = 0
421
+
422
+ input_resolution = img.shape
423
+ scale_threshold = 3 # Allows up-scaling with a scale up to 3
424
+
425
+ # Find the best input resolution R-x. The resolution search described in section 5-double estimation of the main paper and section B of the
426
+ # supplementary material.
427
+ whole_image_optimal_size, patch_scale = calculateprocessingres(img, net_receptive_field_size, r_threshold_value,
428
+ scale_threshold, whole_size_threshold)
429
+
430
+ # print('wholeImage being processed in :', whole_image_optimal_size)
431
+
432
+ # Generate the base estimate using the double estimation.
433
+ whole_estimate = doubleestimate(img, net_receptive_field_size, whole_image_optimal_size, pix2pixsize, model,
434
+ model_type, pix2pixmodel)
435
+
436
+ # Compute the multiplier described in section 6 of the main paper to make sure our initial patch can select
437
+ # small high-density regions of the image.
438
+ global factor
439
+ factor = max(min(1, 4 * patch_scale * whole_image_optimal_size / whole_size_threshold), 0.2)
440
+ # print('Adjust factor is:', 1/factor)
441
+
442
+ # Check if Local boosting is beneficial.
443
+ if max_res < whole_image_optimal_size:
444
+ # print("No Local boosting. Specified Max Res is smaller than R20, Returning doubleestimate result")
445
+ return cv2.resize(whole_estimate, (input_resolution[1], input_resolution[0]), interpolation=cv2.INTER_CUBIC)
446
+
447
+ # Compute the default target resolution.
448
+ if img.shape[0] > img.shape[1]:
449
+ a = 2 * whole_image_optimal_size
450
+ b = round(2 * whole_image_optimal_size * img.shape[1] / img.shape[0])
451
+ else:
452
+ a = round(2 * whole_image_optimal_size * img.shape[0] / img.shape[1])
453
+ b = 2 * whole_image_optimal_size
454
+ b = int(round(b / factor))
455
+ a = int(round(a / factor))
456
+
457
+ """
458
+ # recompute a, b and saturate to max res.
459
+ if max(a,b) > max_res:
460
+ print('Default Res is higher than max-res: Reducing final resolution')
461
+ if img.shape[0] > img.shape[1]:
462
+ a = max_res
463
+ b = round(max_res * img.shape[1] / img.shape[0])
464
+ else:
465
+ a = round(max_res * img.shape[0] / img.shape[1])
466
+ b = max_res
467
+ b = int(b)
468
+ a = int(a)
469
+ """
470
+
471
+ img = cv2.resize(img, (b, a), interpolation=cv2.INTER_CUBIC)
472
+
473
+ # Extract selected patches for local refinement
474
+ base_size = net_receptive_field_size * 2
475
+ patchset = generatepatchs(img, base_size)
476
+
477
+ # print('Target resolution: ', img.shape)
478
+
479
+ # Compute a scale in case the user asked to generate the results at the same resolution as the input.
480
+ # Notice that our method output resolution is independent of the input resolution and this parameter will only
481
+ # enable a scaling operation during the local patch merge implementation to generate results with the same resolution
482
+ # as the input.
483
+ """
484
+ if output_resolution == 1:
485
+ mergein_scale = input_resolution[0] / img.shape[0]
486
+ print('Dynamically change merged-in resolution; scale:', mergein_scale)
487
+ else:
488
+ mergein_scale = 1
489
+ """
490
+ # always rescale to input res for now
491
+ mergein_scale = input_resolution[0] / img.shape[0]
492
+
493
+ imageandpatchs = ImageandPatchs('', '', patchset, img, mergein_scale)
494
+ whole_estimate_resized = cv2.resize(whole_estimate, (round(img.shape[1] * mergein_scale),
495
+ round(img.shape[0] * mergein_scale)),
496
+ interpolation=cv2.INTER_CUBIC)
497
+ imageandpatchs.set_base_estimate(whole_estimate_resized.copy())
498
+ imageandpatchs.set_updated_estimate(whole_estimate_resized.copy())
499
+
500
+ print('Resulting depthmap resolution will be :', whole_estimate_resized.shape[:2])
501
+ print('Patches to process: ' + str(len(imageandpatchs)))
502
+
503
+ # Enumerate through all patches, generate their estimations and refine the base estimate.
504
+ for patch_ind in range(len(imageandpatchs)):
505
+
506
+ # Get patch information
507
+ patch = imageandpatchs[patch_ind] # patch object
508
+ patch_rgb = patch['patch_rgb'] # rgb patch
509
+ patch_whole_estimate_base = patch['patch_whole_estimate_base'] # corresponding patch from base
510
+ rect = patch['rect'] # patch size and location
511
+ patch_id = patch['id'] # patch ID
512
+ org_size = patch_whole_estimate_base.shape # the original size from the unscaled input
513
+ print('\t Processing patch', patch_ind, '/', len(imageandpatchs) - 1, '|', rect)
514
+
515
+ # We apply double estimation for patches. The high resolution value is fixed to twice the receptive
516
+ # field size of the network for patches to accelerate the process.
517
+ patch_estimation = doubleestimate(patch_rgb, net_receptive_field_size, patch_netsize, pix2pixsize, model,
518
+ model_type, pix2pixmodel)
519
+ patch_estimation = cv2.resize(patch_estimation, (pix2pixsize, pix2pixsize), interpolation=cv2.INTER_CUBIC)
520
+ patch_whole_estimate_base = cv2.resize(patch_whole_estimate_base, (pix2pixsize, pix2pixsize),
521
+ interpolation=cv2.INTER_CUBIC)
522
+
523
+ # Merging the patch estimation into the base estimate using our merge network:
524
+ # We feed the patch estimation and the same region from the updated base estimate to the merge network
525
+ # to generate the target estimate for the corresponding region.
526
+ pix2pixmodel.set_input(patch_whole_estimate_base, patch_estimation)
527
+
528
+ # Run merging network
529
+ pix2pixmodel.test()
530
+ visuals = pix2pixmodel.get_current_visuals()
531
+
532
+ prediction_mapped = visuals['fake_B']
533
+ prediction_mapped = (prediction_mapped + 1) / 2
534
+ prediction_mapped = prediction_mapped.squeeze().cpu().numpy()
535
+
536
+ mapped = prediction_mapped
537
+
538
+ # We use a simple linear polynomial to make sure the result of the merge network would match the values of
539
+ # base estimate
540
+ p_coef = np.polyfit(mapped.reshape(-1), patch_whole_estimate_base.reshape(-1), deg=1)
541
+ merged = np.polyval(p_coef, mapped.reshape(-1)).reshape(mapped.shape)
542
+
543
+ merged = cv2.resize(merged, (org_size[1], org_size[0]), interpolation=cv2.INTER_CUBIC)
544
+
545
+ # Get patch size and location
546
+ w1 = rect[0]
547
+ h1 = rect[1]
548
+ w2 = w1 + rect[2]
549
+ h2 = h1 + rect[3]
550
+
551
+ # To speed up the implementation, we only generate the Gaussian mask once with a sufficiently large size
552
+ # and resize it to our needed size while merging the patches.
553
+ if mask.shape != org_size:
554
+ mask = cv2.resize(mask_org, (org_size[1], org_size[0]), interpolation=cv2.INTER_LINEAR)
555
+
556
+ tobemergedto = imageandpatchs.estimation_updated_image
557
+
558
+ # Update the whole estimation:
559
+ # We use a simple Gaussian mask to blend the merged patch region with the base estimate to ensure seamless
560
+ # blending at the boundaries of the patch region.
561
+ tobemergedto[h1:h2, w1:w2] = np.multiply(tobemergedto[h1:h2, w1:w2], 1 - mask) + np.multiply(merged, mask)
562
+ imageandpatchs.set_updated_estimate(tobemergedto)
563
+
564
+ # output
565
+ return cv2.resize(imageandpatchs.estimation_updated_image, (input_resolution[1], input_resolution[0]),
566
+ interpolation=cv2.INTER_CUBIC)
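estimateleres and estimateboost above are the two entry points used by the annotator; the patch merging relies on the blurred-box blending mask produced by generatemask. A small self-contained sketch of how that mask behaves (illustrative only):

from annotator.leres.leres.depthmap import generatemask

mask = generatemask((512, 512))       # float32, min-max normalised to [0, 1]
print(mask.shape, mask.dtype)         # (512, 512) float32
print(float(mask[256, 256]), float(mask[0, 0]))  # ~1.0 in the centre, ~0.0 at the corner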
annotator/leres/leres/multi_depth_model_woauxi.py ADDED
@@ -0,0 +1,34 @@
1
+ from . import network_auxi as network
2
+ from .net_tools import get_func
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+
7
+ class RelDepthModel(nn.Module):
8
+ def __init__(self, backbone='resnet50'):
9
+ super(RelDepthModel, self).__init__()
10
+ if backbone == 'resnet50':
11
+ encoder = 'resnet50_stride32'
12
+ elif backbone == 'resnext101':
13
+ encoder = 'resnext101_stride32x8d'
14
+ self.depth_model = DepthModel(encoder)
15
+
16
+ def inference(self, rgb):
17
+ with torch.no_grad():
18
+ input = rgb.to(self.depth_model.device)
19
+ depth = self.depth_model(input)
20
+ # pred_depth_out = depth - depth.min() + 0.01
21
+ return depth # pred_depth_out
22
+
23
+
24
+ class DepthModel(nn.Module):
25
+ def __init__(self, encoder):
26
+ super(DepthModel, self).__init__()
27
+ backbone = network.__name__.split('.')[-1] + '.' + encoder
28
+ self.encoder_modules = get_func(backbone)()
29
+ self.decoder_modules = network.Decoder()
30
+
31
+ def forward(self, x):
32
+ lateral_out = self.encoder_modules(x)
33
+ out_logit = self.decoder_modules(lateral_out)
34
+ return out_logit
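A minimal sketch of driving the model above; it mirrors estimateleres in depthmap.py, which calls model.depth_model directly on a normalised RGB batch (448 is the receptive-field size used for LeReS in estimateboost). The exact output shape is an assumption, since the decoder's final upsampling block is defined in network_auxi.py:

import torch
from annotator.leres.leres.multi_depth_model_woauxi import RelDepthModel

model = RelDepthModel(backbone='resnext101').eval()
rgb = torch.randn(1, 3, 448, 448)          # stand-in for a normalised RGB batch
with torch.no_grad():
    depth = model.depth_model(rgb)         # relative depth map, e.g. shape (1, 1, 448, 448)
print(depth.shape)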
annotator/leres/leres/net_tools.py ADDED
@@ -0,0 +1,54 @@
1
+ import importlib
2
+ import torch
3
+ import os
4
+ from collections import OrderedDict
5
+
6
+
7
+ def get_func(func_name):
8
+ """Helper to return a function object by name. func_name must identify a
9
+ function in this module or the path to a function relative to the base
10
+ 'modeling' module.
11
+ """
12
+ if func_name == '':
13
+ return None
14
+ try:
15
+ parts = func_name.split('.')
16
+ # Refers to a function in this module
17
+ if len(parts) == 1:
18
+ return globals()[parts[0]]
19
+ # Otherwise, assume we're referencing a module under modeling
20
+ module_name = 'annotator.leres.leres.' + '.'.join(parts[:-1])
21
+ module = importlib.import_module(module_name)
22
+ return getattr(module, parts[-1])
23
+ except Exception:
24
+ print('Failed to find function: %s' % func_name)
25
+ raise
26
+
27
+ def load_ckpt(args, depth_model, shift_model, focal_model):
28
+ """
29
+ Load checkpoint.
30
+ """
31
+ if os.path.isfile(args.load_ckpt):
32
+ print("loading checkpoint %s" % args.load_ckpt)
33
+ checkpoint = torch.load(args.load_ckpt)
34
+ if shift_model is not None:
35
+ shift_model.load_state_dict(strip_prefix_if_present(checkpoint['shift_model'], 'module.'),
36
+ strict=True)
37
+ if focal_model is not None:
38
+ focal_model.load_state_dict(strip_prefix_if_present(checkpoint['focal_model'], 'module.'),
39
+ strict=True)
40
+ depth_model.load_state_dict(strip_prefix_if_present(checkpoint['depth_model'], "module."),
41
+ strict=True)
42
+ del checkpoint
43
+ if torch.cuda.is_available():
44
+ torch.cuda.empty_cache()
45
+
46
+
47
+ def strip_prefix_if_present(state_dict, prefix):
48
+ keys = sorted(state_dict.keys())
49
+ if not all(key.startswith(prefix) for key in keys):
50
+ return state_dict
51
+ stripped_state_dict = OrderedDict()
52
+ for key, value in state_dict.items():
53
+ stripped_state_dict[key.replace(prefix, "")] = value
54
+ return stripped_state_dict
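strip_prefix_if_present only rewrites keys when every key carries the prefix, so both DataParallel-style ('module.'-prefixed) and plain checkpoints load correctly; a small self-contained example:

from collections import OrderedDict
from annotator.leres.leres.net_tools import strip_prefix_if_present

sd = OrderedDict([('module.conv1.weight', 0), ('module.bn1.weight', 1)])
print(list(strip_prefix_if_present(sd, 'module.')))
# ['conv1.weight', 'bn1.weight']

mixed = OrderedDict([('module.conv1.weight', 0), ('bn1.weight', 1)])
print(list(strip_prefix_if_present(mixed, 'module.')))
# keys unchanged, because not every key starts with the prefix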
annotator/leres/leres/network_auxi.py ADDED
@@ -0,0 +1,417 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.init as init
4
+
5
+ from . import Resnet, Resnext_torch
6
+
7
+
8
+ def resnet50_stride32():
9
+ return DepthNet(backbone='resnet', depth=50, upfactors=[2, 2, 2, 2])
10
+
11
+ def resnext101_stride32x8d():
12
+ return DepthNet(backbone='resnext101_32x8d', depth=101, upfactors=[2, 2, 2, 2])
13
+
14
+
15
+ class Decoder(nn.Module):
16
+ def __init__(self):
17
+ super(Decoder, self).__init__()
18
+ self.inchannels = [256, 512, 1024, 2048]
19
+ self.midchannels = [256, 256, 256, 512]
20
+ self.upfactors = [2,2,2,2]
21
+ self.outchannels = 1
22
+
23
+ self.conv = FTB(inchannels=self.inchannels[3], midchannels=self.midchannels[3])
24
+ self.conv1 = nn.Conv2d(in_channels=self.midchannels[3], out_channels=self.midchannels[2], kernel_size=3, padding=1, stride=1, bias=True)
25
+ self.upsample = nn.Upsample(scale_factor=self.upfactors[3], mode='bilinear', align_corners=True)
26
+
27
+ self.ffm2 = FFM(inchannels=self.inchannels[2], midchannels=self.midchannels[2], outchannels = self.midchannels[2], upfactor=self.upfactors[2])
28
+ self.ffm1 = FFM(inchannels=self.inchannels[1], midchannels=self.midchannels[1], outchannels = self.midchannels[1], upfactor=self.upfactors[1])
29
+ self.ffm0 = FFM(inchannels=self.inchannels[0], midchannels=self.midchannels[0], outchannels = self.midchannels[0], upfactor=self.upfactors[0])
30
+
31
+ self.outconv = AO(inchannels=self.midchannels[0], outchannels=self.outchannels, upfactor=2)
32
+ self._init_params()
33
+
34
+ def _init_params(self):
35
+ for m in self.modules():
36
+ if isinstance(m, nn.Conv2d):
37
+ init.normal_(m.weight, std=0.01)
38
+ if m.bias is not None:
39
+ init.constant_(m.bias, 0)
40
+ elif isinstance(m, nn.ConvTranspose2d):
41
+ init.normal_(m.weight, std=0.01)
42
+ if m.bias is not None:
43
+ init.constant_(m.bias, 0)
44
+ elif isinstance(m, nn.BatchNorm2d): #NN.BatchNorm2d
45
+ init.constant_(m.weight, 1)
46
+ init.constant_(m.bias, 0)
47
+ elif isinstance(m, nn.Linear):
48
+ init.normal_(m.weight, std=0.01)
49
+ if m.bias is not None:
50
+ init.constant_(m.bias, 0)
51
+
52
+ def forward(self, features):
53
+ x_32x = self.conv(features[3]) # 1/32
54
+ x_32 = self.conv1(x_32x)
55
+ x_16 = self.upsample(x_32) # 1/16
56
+
57
+ x_8 = self.ffm2(features[2], x_16) # 1/8
58
+ x_4 = self.ffm1(features[1], x_8) # 1/4
59
+ x_2 = self.ffm0(features[0], x_4) # 1/2
60
+ #-----------------------------------------
61
+ x = self.outconv(x_2) # original size
62
+ return x
63
+
64
+ class DepthNet(nn.Module):
65
+ __factory = {
66
+ 18: Resnet.resnet18,
67
+ 34: Resnet.resnet34,
68
+ 50: Resnet.resnet50,
69
+ 101: Resnet.resnet101,
70
+ 152: Resnet.resnet152
71
+ }
72
+ def __init__(self,
73
+ backbone='resnet',
74
+ depth=50,
75
+ upfactors=[2, 2, 2, 2]):
76
+ super(DepthNet, self).__init__()
77
+ self.backbone = backbone
78
+ self.depth = depth
79
+ self.pretrained = False
80
+ self.inchannels = [256, 512, 1024, 2048]
81
+ self.midchannels = [256, 256, 256, 512]
82
+ self.upfactors = upfactors
83
+ self.outchannels = 1
84
+
85
+ # Build model
86
+ if self.backbone == 'resnet':
87
+ if self.depth not in DepthNet.__factory:
88
+ raise KeyError("Unsupported depth:", self.depth)
89
+ self.encoder = DepthNet.__factory[depth](pretrained=self.pretrained)
90
+ elif self.backbone == 'resnext101_32x8d':
91
+ self.encoder = Resnext_torch.resnext101_32x8d(pretrained=self.pretrained)
92
+ else:
93
+ self.encoder = Resnext_torch.resnext101(pretrained=self.pretrained)
94
+
95
+ def forward(self, x):
96
+ x = self.encoder(x) # 1/32, 1/16, 1/8, 1/4
97
+ return x
98
+
99
+
100
+ class FTB(nn.Module):
101
+ def __init__(self, inchannels, midchannels=512):
102
+ super(FTB, self).__init__()
103
+ self.in1 = inchannels
104
+ self.mid = midchannels
105
+ self.conv1 = nn.Conv2d(in_channels=self.in1, out_channels=self.mid, kernel_size=3, padding=1, stride=1,
106
+ bias=True)
107
+ # NN.BatchNorm2d
108
+ self.conv_branch = nn.Sequential(nn.ReLU(inplace=True), \
109
+ nn.Conv2d(in_channels=self.mid, out_channels=self.mid, kernel_size=3,
110
+ padding=1, stride=1, bias=True), \
111
+ nn.BatchNorm2d(num_features=self.mid), \
112
+ nn.ReLU(inplace=True), \
113
+ nn.Conv2d(in_channels=self.mid, out_channels=self.mid, kernel_size=3,
114
+ padding=1, stride=1, bias=True))
115
+ self.relu = nn.ReLU(inplace=True)
116
+
117
+ self.init_params()
118
+
119
+ def forward(self, x):
120
+ x = self.conv1(x)
121
+ x = x + self.conv_branch(x)
122
+ x = self.relu(x)
123
+
124
+ return x
125
+
126
+ def init_params(self):
127
+ for m in self.modules():
128
+ if isinstance(m, nn.Conv2d):
129
+ init.normal_(m.weight, std=0.01)
130
+ if m.bias is not None:
131
+ init.constant_(m.bias, 0)
132
+ elif isinstance(m, nn.ConvTranspose2d):
133
+ # init.kaiming_normal_(m.weight, mode='fan_out')
134
+ init.normal_(m.weight, std=0.01)
135
+ # init.xavier_normal_(m.weight)
136
+ if m.bias is not None:
137
+ init.constant_(m.bias, 0)
138
+ elif isinstance(m, nn.BatchNorm2d): # NN.BatchNorm2d
139
+ init.constant_(m.weight, 1)
140
+ init.constant_(m.bias, 0)
141
+ elif isinstance(m, nn.Linear):
142
+ init.normal_(m.weight, std=0.01)
143
+ if m.bias is not None:
144
+ init.constant_(m.bias, 0)
145
+
146
+
147
+ class ATA(nn.Module):
148
+ def __init__(self, inchannels, reduction=8):
149
+ super(ATA, self).__init__()
150
+ self.inchannels = inchannels
151
+ self.avg_pool = nn.AdaptiveAvgPool2d(1)
152
+ self.fc = nn.Sequential(nn.Linear(self.inchannels * 2, self.inchannels // reduction),
153
+ nn.ReLU(inplace=True),
154
+ nn.Linear(self.inchannels // reduction, self.inchannels),
155
+ nn.Sigmoid())
156
+ self.init_params()
157
+
158
+ def forward(self, low_x, high_x):
159
+ n, c, _, _ = low_x.size()
160
+ x = torch.cat([low_x, high_x], 1)
161
+ x = self.avg_pool(x)
162
+ x = x.view(n, -1)
163
+ x = self.fc(x).view(n, c, 1, 1)
164
+ x = low_x * x + high_x
165
+
166
+ return x
167
+
168
+ def init_params(self):
169
+ for m in self.modules():
170
+ if isinstance(m, nn.Conv2d):
171
+ # init.kaiming_normal_(m.weight, mode='fan_out')
172
+ # init.normal(m.weight, std=0.01)
173
+ init.xavier_normal_(m.weight)
174
+ if m.bias is not None:
175
+ init.constant_(m.bias, 0)
176
+ elif isinstance(m, nn.ConvTranspose2d):
177
+ # init.kaiming_normal_(m.weight, mode='fan_out')
178
+ # init.normal_(m.weight, std=0.01)
179
+ init.xavier_normal_(m.weight)
180
+ if m.bias is not None:
181
+ init.constant_(m.bias, 0)
182
+ elif isinstance(m, nn.BatchNorm2d): # NN.BatchNorm2d
183
+ init.constant_(m.weight, 1)
184
+ init.constant_(m.bias, 0)
185
+ elif isinstance(m, nn.Linear):
186
+ init.normal_(m.weight, std=0.01)
187
+ if m.bias is not None:
188
+ init.constant_(m.bias, 0)
189
+
190
+
191
+ class FFM(nn.Module):
192
+ def __init__(self, inchannels, midchannels, outchannels, upfactor=2):
193
+ super(FFM, self).__init__()
194
+ self.inchannels = inchannels
195
+ self.midchannels = midchannels
196
+ self.outchannels = outchannels
197
+ self.upfactor = upfactor
198
+
199
+ self.ftb1 = FTB(inchannels=self.inchannels, midchannels=self.midchannels)
200
+ # self.ata = ATA(inchannels = self.midchannels)
201
+ self.ftb2 = FTB(inchannels=self.midchannels, midchannels=self.outchannels)
202
+
203
+ self.upsample = nn.Upsample(scale_factor=self.upfactor, mode='bilinear', align_corners=True)
204
+
205
+ self.init_params()
206
+
207
+ def forward(self, low_x, high_x):
208
+ x = self.ftb1(low_x)
209
+ x = x + high_x
210
+ x = self.ftb2(x)
211
+ x = self.upsample(x)
212
+
213
+ return x
214
+
215
+ def init_params(self):
216
+ for m in self.modules():
217
+ if isinstance(m, nn.Conv2d):
218
+ # init.kaiming_normal_(m.weight, mode='fan_out')
219
+ init.normal_(m.weight, std=0.01)
220
+ # init.xavier_normal_(m.weight)
221
+ if m.bias is not None:
222
+ init.constant_(m.bias, 0)
223
+ elif isinstance(m, nn.ConvTranspose2d):
224
+ # init.kaiming_normal_(m.weight, mode='fan_out')
225
+ init.normal_(m.weight, std=0.01)
226
+ # init.xavier_normal_(m.weight)
227
+ if m.bias is not None:
228
+ init.constant_(m.bias, 0)
229
+ elif isinstance(m, nn.BatchNorm2d): # NN.Batchnorm2d
230
+ init.constant_(m.weight, 1)
231
+ init.constant_(m.bias, 0)
232
+ elif isinstance(m, nn.Linear):
233
+ init.normal_(m.weight, std=0.01)
234
+ if m.bias is not None:
235
+ init.constant_(m.bias, 0)
236
+
237
+
238
+ class AO(nn.Module):
239
+ # Adaptive output module
240
+ def __init__(self, inchannels, outchannels, upfactor=2):
241
+ super(AO, self).__init__()
242
+ self.inchannels = inchannels
243
+ self.outchannels = outchannels
244
+ self.upfactor = upfactor
245
+
246
+ self.adapt_conv = nn.Sequential(
247
+ nn.Conv2d(in_channels=self.inchannels, out_channels=self.inchannels // 2, kernel_size=3, padding=1,
248
+ stride=1, bias=True), \
249
+ nn.BatchNorm2d(num_features=self.inchannels // 2), \
250
+ nn.ReLU(inplace=True), \
251
+ nn.Conv2d(in_channels=self.inchannels // 2, out_channels=self.outchannels, kernel_size=3, padding=1,
252
+ stride=1, bias=True), \
253
+ nn.Upsample(scale_factor=self.upfactor, mode='bilinear', align_corners=True))
254
+
255
+ self.init_params()
256
+
257
+ def forward(self, x):
258
+ x = self.adapt_conv(x)
259
+ return x
260
+
261
+ def init_params(self):
262
+ for m in self.modules():
263
+ if isinstance(m, nn.Conv2d):
264
+ # init.kaiming_normal_(m.weight, mode='fan_out')
265
+ init.normal_(m.weight, std=0.01)
266
+ # init.xavier_normal_(m.weight)
267
+ if m.bias is not None:
268
+ init.constant_(m.bias, 0)
269
+ elif isinstance(m, nn.ConvTranspose2d):
270
+ # init.kaiming_normal_(m.weight, mode='fan_out')
271
+ init.normal_(m.weight, std=0.01)
272
+ # init.xavier_normal_(m.weight)
273
+ if m.bias is not None:
274
+ init.constant_(m.bias, 0)
275
+ elif isinstance(m, nn.BatchNorm2d): # NN.Batchnorm2d
276
+ init.constant_(m.weight, 1)
277
+ init.constant_(m.bias, 0)
278
+ elif isinstance(m, nn.Linear):
279
+ init.normal_(m.weight, std=0.01)
280
+ if m.bias is not None:
281
+ init.constant_(m.bias, 0)
282
+
283
+
284
+
285
+ # ==============================================================================================================
286
+
287
+
288
+ class ResidualConv(nn.Module):
289
+ def __init__(self, inchannels):
290
+ super(ResidualConv, self).__init__()
291
+ # NN.BatchNorm2d
292
+ self.conv = nn.Sequential(
293
+ # nn.BatchNorm2d(num_features=inchannels),
294
+ nn.ReLU(inplace=False),
295
+ # nn.Conv2d(in_channels=inchannels, out_channels=inchannels, kernel_size=3, padding=1, stride=1, groups=inchannels,bias=True),
296
+ # nn.Conv2d(in_channels=inchannels, out_channels=inchannels, kernel_size=1, padding=0, stride=1, groups=1,bias=True)
297
+ nn.Conv2d(in_channels=inchannels, out_channels=inchannels // 2, kernel_size=3, padding=1, stride=1,
298
+ bias=False),
299
+ nn.BatchNorm2d(num_features=inchannels // 2),
300
+ nn.ReLU(inplace=False),
301
+ nn.Conv2d(in_channels=inchannels // 2, out_channels=inchannels, kernel_size=3, padding=1, stride=1,
302
+ bias=False)
303
+ )
304
+ self.init_params()
305
+
306
+ def forward(self, x):
307
+ x = self.conv(x) + x
308
+ return x
309
+
310
+ def init_params(self):
311
+ for m in self.modules():
312
+ if isinstance(m, nn.Conv2d):
313
+ # init.kaiming_normal_(m.weight, mode='fan_out')
314
+ init.normal_(m.weight, std=0.01)
315
+ # init.xavier_normal_(m.weight)
316
+ if m.bias is not None:
317
+ init.constant_(m.bias, 0)
318
+ elif isinstance(m, nn.ConvTranspose2d):
319
+ # init.kaiming_normal_(m.weight, mode='fan_out')
320
+ init.normal_(m.weight, std=0.01)
321
+ # init.xavier_normal_(m.weight)
322
+ if m.bias is not None:
323
+ init.constant_(m.bias, 0)
324
+ elif isinstance(m, nn.BatchNorm2d): # NN.BatchNorm2d
325
+ init.constant_(m.weight, 1)
326
+ init.constant_(m.bias, 0)
327
+ elif isinstance(m, nn.Linear):
328
+ init.normal_(m.weight, std=0.01)
329
+ if m.bias is not None:
330
+ init.constant_(m.bias, 0)
331
+
332
+
333
+ class FeatureFusion(nn.Module):
334
+ def __init__(self, inchannels, outchannels):
335
+ super(FeatureFusion, self).__init__()
336
+ self.conv = ResidualConv(inchannels=inchannels)
337
+ # NN.BatchNorm2d
338
+ self.up = nn.Sequential(ResidualConv(inchannels=inchannels),
339
+ nn.ConvTranspose2d(in_channels=inchannels, out_channels=outchannels, kernel_size=3,
340
+ stride=2, padding=1, output_padding=1),
341
+ nn.BatchNorm2d(num_features=outchannels),
342
+ nn.ReLU(inplace=True))
343
+
344
+ def forward(self, lowfeat, highfeat):
345
+ return self.up(highfeat + self.conv(lowfeat))
346
+
347
+ def init_params(self):
348
+ for m in self.modules():
349
+ if isinstance(m, nn.Conv2d):
350
+ # init.kaiming_normal_(m.weight, mode='fan_out')
351
+ init.normal_(m.weight, std=0.01)
352
+ # init.xavier_normal_(m.weight)
353
+ if m.bias is not None:
354
+ init.constant_(m.bias, 0)
355
+ elif isinstance(m, nn.ConvTranspose2d):
356
+ # init.kaiming_normal_(m.weight, mode='fan_out')
357
+ init.normal_(m.weight, std=0.01)
358
+ # init.xavier_normal_(m.weight)
359
+ if m.bias is not None:
360
+ init.constant_(m.bias, 0)
361
+ elif isinstance(m, nn.BatchNorm2d): # NN.BatchNorm2d
362
+ init.constant_(m.weight, 1)
363
+ init.constant_(m.bias, 0)
364
+ elif isinstance(m, nn.Linear):
365
+ init.normal_(m.weight, std=0.01)
366
+ if m.bias is not None:
367
+ init.constant_(m.bias, 0)
368
+
369
+
370
+ class SenceUnderstand(nn.Module):
371
+ def __init__(self, channels):
372
+ super(SenceUnderstand, self).__init__()
373
+ self.channels = channels
374
+ self.conv1 = nn.Sequential(nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, padding=1),
375
+ nn.ReLU(inplace=True))
376
+ self.pool = nn.AdaptiveAvgPool2d(8)
377
+ self.fc = nn.Sequential(nn.Linear(512 * 8 * 8, self.channels),
378
+ nn.ReLU(inplace=True))
379
+ self.conv2 = nn.Sequential(
380
+ nn.Conv2d(in_channels=self.channels, out_channels=self.channels, kernel_size=1, padding=0),
381
+ nn.ReLU(inplace=True))
382
+ self.initial_params()
383
+
384
+ def forward(self, x):
385
+ n, c, h, w = x.size()
386
+ x = self.conv1(x)
387
+ x = self.pool(x)
388
+ x = x.view(n, -1)
389
+ x = self.fc(x)
390
+ x = x.view(n, self.channels, 1, 1)
391
+ x = self.conv2(x)
392
+ x = x.repeat(1, 1, h, w)
393
+ return x
394
+
395
+ def initial_params(self, dev=0.01):
396
+ for m in self.modules():
397
+ if isinstance(m, nn.Conv2d):
398
+ # print torch.sum(m.weight)
399
+ m.weight.data.normal_(0, dev)
400
+ if m.bias is not None:
401
+ m.bias.data.fill_(0)
402
+ elif isinstance(m, nn.ConvTranspose2d):
403
+ # print torch.sum(m.weight)
404
+ m.weight.data.normal_(0, dev)
405
+ if m.bias is not None:
406
+ m.bias.data.fill_(0)
407
+ elif isinstance(m, nn.Linear):
408
+ m.weight.data.normal_(0, dev)
409
+
410
+
411
+ if __name__ == '__main__':
412
+ net = DepthNet(backbone='resnet', depth=50)
413
+ print(net)
414
+ inputs = torch.ones(4,3,128,128)
415
+ out = net(inputs)
416
+ print(out.size())
417
+
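Note: `network_auxi.py` splits the depth network into an encoder (`DepthNet`, a ResNet/ResNeXt backbone) and a `Decoder` that fuses the four backbone stages (256/512/1024/2048 channels) through FTB/FFM blocks into a single-channel prediction. A hedged wiring sketch, assuming the backbone returns its four stage features as a list ordered shallow-to-deep and that the input side length is a multiple of 32:

```python
# Sketch only; the exact feature ordering and output resolution depend on
# Resnet.resnet50's forward, which is defined elsewhere in this commit.
import torch
from annotator.leres.leres.network_auxi import resnet50_stride32, Decoder

encoder = resnet50_stride32()   # DepthNet wrapping a ResNet-50 backbone
decoder = Decoder()             # FTB/FFM/AO decoder producing one output channel

with torch.no_grad():
    feats = encoder(torch.randn(1, 3, 448, 448))  # assumed: [1/4, 1/8, 1/16, 1/32] features
    depth = decoder(feats)                        # expected: (1, 1, 448, 448)
print(depth.shape)
```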
annotator/leres/pix2pix/LICENSE ADDED
@@ -0,0 +1,19 @@
1
+ https://github.com/compphoto/BoostingMonocularDepth
2
+
3
+ Copyright 2021, Seyed Mahdi Hosseini Miangoleh, Sebastian Dille, Computational Photography Laboratory. All rights reserved.
4
+
5
+ This software is for academic use only. A redistribution of this
6
+ software, with or without modifications, has to be for academic
7
+ use only, while giving the appropriate credit to the original
8
+ authors of the software. The methods implemented as a part of
9
+ this software may be covered under patents or patent applications.
10
+
11
+ THIS SOFTWARE IS PROVIDED BY THE AUTHOR ''AS IS'' AND ANY EXPRESS OR IMPLIED
12
+ WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
13
+ FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR OR
14
+ CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
15
+ CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
16
+ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
17
+ ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
18
+ NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
19
+ ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
annotator/leres/pix2pix/models/__init__.py ADDED
@@ -0,0 +1,67 @@
1
+ """This package contains modules related to objective functions, optimizations, and network architectures.
2
+
3
+ To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel.
4
+ You need to implement the following five functions:
5
+ -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
6
+ -- <set_input>: unpack data from dataset and apply preprocessing.
7
+ -- <forward>: produce intermediate results.
8
+ -- <optimize_parameters>: calculate loss, gradients, and update network weights.
9
+ -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
10
+
11
+ In the function <__init__>, you need to define four lists:
12
+ -- self.loss_names (str list): specify the training losses that you want to plot and save.
13
+ -- self.model_names (str list): define networks used in our training.
14
+ -- self.visual_names (str list): specify the images that you want to display and save.
15
+ -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for a usage example.
16
+
17
+ Now you can use the model class by specifying flag '--model dummy'.
18
+ See our template model class 'template_model.py' for more details.
19
+ """
20
+
21
+ import importlib
22
+ from .base_model import BaseModel
23
+
24
+
25
+ def find_model_using_name(model_name):
26
+ """Import the module "models/[model_name]_model.py".
27
+
28
+ In the file, the class called DatasetNameModel() will
29
+ be instantiated. It has to be a subclass of BaseModel,
30
+ and it is case-insensitive.
31
+ """
32
+ model_filename = "annotator.leres.pix2pix.models." + model_name + "_model"
33
+ modellib = importlib.import_module(model_filename)
34
+ model = None
35
+ target_model_name = model_name.replace('_', '') + 'model'
36
+ for name, cls in modellib.__dict__.items():
37
+ if name.lower() == target_model_name.lower() \
38
+ and issubclass(cls, BaseModel):
39
+ model = cls
40
+
41
+ if model is None:
42
+ print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name))
43
+ exit(0)
44
+
45
+ return model
46
+
47
+
48
+ def get_option_setter(model_name):
49
+ """Return the static method <modify_commandline_options> of the model class."""
50
+ model_class = find_model_using_name(model_name)
51
+ return model_class.modify_commandline_options
52
+
53
+
54
+ def create_model(opt):
55
+ """Create a model given the option.
56
+
57
+ This function instantiates the model class found by <find_model_using_name>.
58
+ This is the main interface between this package and 'train.py'/'test.py'
59
+
60
+ Example:
61
+ >>> from models import create_model
62
+ >>> model = create_model(opt)
63
+ """
64
+ model = find_model_using_name(opt.model)
65
+ instance = model(opt)
66
+ print("model [%s] was created" % type(instance).__name__)
67
+ return instance
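Note: `find_model_using_name` resolves models purely by naming convention: `--model pix2pix4depth` is mapped to `models/pix2pix4depth_model.py` and to the class whose lowercased name equals `pix2pix4depthmodel`. A small sketch of the lookup (instantiation via `create_model` additionally needs a populated options object from `BaseOptions`):

```python
# Sketch of the naming-convention lookup used by create_model().
from annotator.leres.pix2pix.models import find_model_using_name

cls = find_model_using_name("pix2pix4depth")  # imports models/pix2pix4depth_model.py
print(cls.__name__)                           # Pix2Pix4DepthModel, a BaseModel subclass
```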
annotator/leres/pix2pix/models/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (3.3 kB).
annotator/leres/pix2pix/models/__pycache__/base_model.cpython-39.pyc ADDED
Binary file (10.3 kB).
annotator/leres/pix2pix/models/__pycache__/networks.cpython-39.pyc ADDED
Binary file (23.5 kB).
annotator/leres/pix2pix/models/__pycache__/pix2pix4depth_model.cpython-39.pyc ADDED
Binary file (5.56 kB).
annotator/leres/pix2pix/models/base_model.py ADDED
@@ -0,0 +1,240 @@
1
+ import os
2
+ import torch, gc
3
+ from collections import OrderedDict
4
+ from abc import ABC, abstractmethod
5
+ from . import networks
6
+
7
+
8
+ class BaseModel(ABC):
9
+ """This class is an abstract base class (ABC) for models.
10
+ To create a subclass, you need to implement the following five functions:
11
+ -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
12
+ -- <set_input>: unpack data from dataset and apply preprocessing.
13
+ -- <forward>: produce intermediate results.
14
+ -- <optimize_parameters>: calculate losses, gradients, and update network weights.
15
+ -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
16
+ """
17
+
18
+ def __init__(self, opt):
19
+ """Initialize the BaseModel class.
20
+
21
+ Parameters:
22
+ opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
23
+
24
+ When creating your custom class, you need to implement your own initialization.
25
+ In this function, you should first call <BaseModel.__init__(self, opt)>
26
+ Then, you need to define four lists:
27
+ -- self.loss_names (str list): specify the training losses that you want to plot and save.
28
+ -- self.model_names (str list): define networks used in our training.
29
+ -- self.visual_names (str list): specify the images that you want to display and save.
30
+ -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
31
+ """
32
+ self.opt = opt
33
+ self.gpu_ids = opt.gpu_ids
34
+ self.isTrain = opt.isTrain
35
+ self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu') # get device name: CPU or GPU
36
+ self.save_dir = os.path.join(opt.checkpoints_dir, opt.name) # save all the checkpoints to save_dir
37
+ if opt.preprocess != 'scale_width': # with [scale_width], input images might have different sizes, which hurts the performance of cudnn.benchmark.
38
+ torch.backends.cudnn.benchmark = True
39
+ self.loss_names = []
40
+ self.model_names = []
41
+ self.visual_names = []
42
+ self.optimizers = []
43
+ self.image_paths = []
44
+ self.metric = 0 # used for learning rate policy 'plateau'
45
+
46
+ @staticmethod
47
+ def modify_commandline_options(parser, is_train):
48
+ """Add new model-specific options, and rewrite default values for existing options.
49
+
50
+ Parameters:
51
+ parser -- original option parser
52
+ is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
53
+
54
+ Returns:
55
+ the modified parser.
56
+ """
57
+ return parser
58
+
59
+ @abstractmethod
60
+ def set_input(self, input):
61
+ """Unpack input data from the dataloader and perform necessary pre-processing steps.
62
+
63
+ Parameters:
64
+ input (dict): includes the data itself and its metadata information.
65
+ """
66
+ pass
67
+
68
+ @abstractmethod
69
+ def forward(self):
70
+ """Run forward pass; called by both functions <optimize_parameters> and <test>."""
71
+ pass
72
+
73
+ @abstractmethod
74
+ def optimize_parameters(self):
75
+ """Calculate losses, gradients, and update network weights; called in every training iteration"""
76
+ pass
77
+
78
+ def setup(self, opt):
79
+ """Load and print networks; create schedulers
80
+
81
+ Parameters:
82
+ opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
83
+ """
84
+ if self.isTrain:
85
+ self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers]
86
+ if not self.isTrain or opt.continue_train:
87
+ load_suffix = 'iter_%d' % opt.load_iter if opt.load_iter > 0 else opt.epoch
88
+ self.load_networks(load_suffix)
89
+ self.print_networks(opt.verbose)
90
+
91
+ def eval(self):
92
+ """Make models eval mode during test time"""
93
+ for name in self.model_names:
94
+ if isinstance(name, str):
95
+ net = getattr(self, 'net' + name)
96
+ net.eval()
97
+
98
+ def test(self):
99
+ """Forward function used in test time.
100
+
101
+ This function wraps <forward> function in no_grad() so we don't save intermediate steps for backprop
102
+ It also calls <compute_visuals> to produce additional visualization results
103
+ """
104
+ with torch.no_grad():
105
+ self.forward()
106
+ self.compute_visuals()
107
+
108
+ def compute_visuals(self):
109
+ """Calculate additional output images for visdom and HTML visualization"""
110
+ pass
111
+
112
+ def get_image_paths(self):
113
+ """ Return image paths that are used to load current data"""
114
+ return self.image_paths
115
+
116
+ def update_learning_rate(self):
117
+ """Update learning rates for all the networks; called at the end of every epoch"""
118
+ old_lr = self.optimizers[0].param_groups[0]['lr']
119
+ for scheduler in self.schedulers:
120
+ if self.opt.lr_policy == 'plateau':
121
+ scheduler.step(self.metric)
122
+ else:
123
+ scheduler.step()
124
+
125
+ lr = self.optimizers[0].param_groups[0]['lr']
126
+ print('learning rate %.7f -> %.7f' % (old_lr, lr))
127
+
128
+ def get_current_visuals(self):
129
+ """Return visualization images. train.py will display these images with visdom, and save the images to a HTML"""
130
+ visual_ret = OrderedDict()
131
+ for name in self.visual_names:
132
+ if isinstance(name, str):
133
+ visual_ret[name] = getattr(self, name)
134
+ return visual_ret
135
+
136
+ def get_current_losses(self):
137
+ """Return traning losses / errors. train.py will print out these errors on console, and save them to a file"""
138
+ errors_ret = OrderedDict()
139
+ for name in self.loss_names:
140
+ if isinstance(name, str):
141
+ errors_ret[name] = float(getattr(self, 'loss_' + name)) # float(...) works for both scalar tensor and float number
142
+ return errors_ret
143
+
144
+ def save_networks(self, epoch):
145
+ """Save all the networks to the disk.
146
+
147
+ Parameters:
148
+ epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
149
+ """
150
+ for name in self.model_names:
151
+ if isinstance(name, str):
152
+ save_filename = '%s_net_%s.pth' % (epoch, name)
153
+ save_path = os.path.join(self.save_dir, save_filename)
154
+ net = getattr(self, 'net' + name)
155
+
156
+ if len(self.gpu_ids) > 0 and torch.cuda.is_available():
157
+ torch.save(net.module.cpu().state_dict(), save_path)
158
+ net.cuda(self.gpu_ids[0])
159
+ else:
160
+ torch.save(net.cpu().state_dict(), save_path)
161
+
162
+ def unload_network(self, name):
163
+ """Unload network and gc.
164
+ """
165
+ if isinstance(name, str):
166
+ net = getattr(self, 'net' + name)
167
+ del net
168
+ gc.collect()
169
+ # devices.torch_gc()
170
+ return None
171
+
172
+ def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
173
+ """Fix InstanceNorm checkpoints incompatibility (prior to 0.4)"""
174
+ key = keys[i]
175
+ if i + 1 == len(keys): # at the end, pointing to a parameter/buffer
176
+ if module.__class__.__name__.startswith('InstanceNorm') and \
177
+ (key == 'running_mean' or key == 'running_var'):
178
+ if getattr(module, key) is None:
179
+ state_dict.pop('.'.join(keys))
180
+ if module.__class__.__name__.startswith('InstanceNorm') and \
181
+ (key == 'num_batches_tracked'):
182
+ state_dict.pop('.'.join(keys))
183
+ else:
184
+ self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)
185
+
186
+ def load_networks(self, epoch):
187
+ """Load all the networks from the disk.
188
+
189
+ Parameters:
190
+ epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
191
+ """
192
+ for name in self.model_names:
193
+ if isinstance(name, str):
194
+ load_filename = '%s_net_%s.pth' % (epoch, name)
195
+ load_path = os.path.join(self.save_dir, load_filename)
196
+ net = getattr(self, 'net' + name)
197
+ if isinstance(net, torch.nn.DataParallel):
198
+ net = net.module
199
+ # print('Loading depth boost model from %s' % load_path)
200
+ # if you are using PyTorch newer than 0.4 (e.g., built from
201
+ # GitHub source), you can remove str() on self.device
202
+ state_dict = torch.load(load_path, map_location=str(self.device))
203
+ if hasattr(state_dict, '_metadata'):
204
+ del state_dict._metadata
205
+
206
+ # patch InstanceNorm checkpoints prior to 0.4
207
+ for key in list(state_dict.keys()): # need to copy keys here because we mutate in loop
208
+ self.__patch_instance_norm_state_dict(state_dict, net, key.split('.'))
209
+ net.load_state_dict(state_dict)
210
+
211
+ def print_networks(self, verbose):
212
+ """Print the total number of parameters in the network and (if verbose) network architecture
213
+
214
+ Parameters:
215
+ verbose (bool) -- if verbose: print the network architecture
216
+ """
217
+ print('---------- Networks initialized -------------')
218
+ for name in self.model_names:
219
+ if isinstance(name, str):
220
+ net = getattr(self, 'net' + name)
221
+ num_params = 0
222
+ for param in net.parameters():
223
+ num_params += param.numel()
224
+ if verbose:
225
+ print(net)
226
+ print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6))
227
+ print('-----------------------------------------------')
228
+
229
+ def set_requires_grad(self, nets, requires_grad=False):
230
+ """Set requies_grad=Fasle for all the networks to avoid unnecessary computations
231
+ Parameters:
232
+ nets (network list) -- a list of networks
233
+ requires_grad (bool) -- whether the networks require gradients or not
234
+ """
235
+ if not isinstance(nets, list):
236
+ nets = [nets]
237
+ for net in nets:
238
+ if net is not None:
239
+ for param in net.parameters():
240
+ param.requires_grad = requires_grad
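Note: `BaseModel` is an ABC; a concrete model only has to implement `set_input`, `forward` and `optimize_parameters`, and fill the `loss_names`/`model_names`/`visual_names` lists so that saving, loading and logging work generically. A minimal hedged skeleton (the class name and the `input['A']` field are illustrative only; the actual subclass in this commit is `Pix2Pix4DepthModel`):

```python
# Minimal sketch of a BaseModel subclass; not part of the commit.
from annotator.leres.pix2pix.models.base_model import BaseModel


class PassthroughModel(BaseModel):
    def __init__(self, opt):
        BaseModel.__init__(self, opt)   # reads opt.gpu_ids, opt.isTrain, opt.checkpoints_dir, ...
        self.loss_names = []
        self.model_names = []           # nothing to save/load in this toy example
        self.visual_names = ['fake']

    def set_input(self, input):
        self.real = input['A'].to(self.device)  # assumes the dataloader yields a dict with key 'A'

    def forward(self):
        self.fake = self.real                   # a real model would run its networks here

    def optimize_parameters(self):
        pass                                    # inference-only; nothing to optimize
```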
annotator/leres/pix2pix/models/base_model_hg.py ADDED
@@ -0,0 +1,58 @@
1
+ import os
2
+ import torch
3
+
4
+ class BaseModelHG():
5
+ def name(self):
6
+ return 'BaseModel'
7
+
8
+ def initialize(self, opt):
9
+ self.opt = opt
10
+ self.gpu_ids = opt.gpu_ids
11
+ self.isTrain = opt.isTrain
12
+ self.Tensor = torch.cuda.FloatTensor if self.gpu_ids else torch.Tensor
13
+ self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)
14
+
15
+ def set_input(self, input):
16
+ self.input = input
17
+
18
+ def forward(self):
19
+ pass
20
+
21
+ # used in test time, no backprop
22
+ def test(self):
23
+ pass
24
+
25
+ def get_image_paths(self):
26
+ pass
27
+
28
+ def optimize_parameters(self):
29
+ pass
30
+
31
+ def get_current_visuals(self):
32
+ return self.input
33
+
34
+ def get_current_errors(self):
35
+ return {}
36
+
37
+ def save(self, label):
38
+ pass
39
+
40
+ # helper saving function that can be used by subclasses
41
+ def save_network(self, network, network_label, epoch_label, gpu_ids):
42
+ save_filename = '_%s_net_%s.pth' % (epoch_label, network_label)
43
+ save_path = os.path.join(self.save_dir, save_filename)
44
+ torch.save(network.cpu().state_dict(), save_path)
45
+ if len(gpu_ids) and torch.cuda.is_available():
46
+ network.cuda(gpu_ids[0])
47
+
48
+ # helper loading function that can be used by subclasses
49
+ def load_network(self, network, network_label, epoch_label):
50
+ save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
51
+ save_path = os.path.join(self.save_dir, save_filename)
52
+ print(save_path)
53
+ model = torch.load(save_path)
54
+ return model
55
+ # network.load_state_dict(torch.load(save_path))
56
+
57
+ def update_learning_rate(self):
58
+ pass
annotator/leres/pix2pix/models/networks.py ADDED
@@ -0,0 +1,623 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.nn import init
4
+ import functools
5
+ from torch.optim import lr_scheduler
6
+
7
+
8
+ ###############################################################################
9
+ # Helper Functions
10
+ ###############################################################################
11
+
12
+
13
+ class Identity(nn.Module):
14
+ def forward(self, x):
15
+ return x
16
+
17
+
18
+ def get_norm_layer(norm_type='instance'):
19
+ """Return a normalization layer
20
+
21
+ Parameters:
22
+ norm_type (str) -- the name of the normalization layer: batch | instance | none
23
+
24
+ For BatchNorm, we use learnable affine parameters and track running statistics (mean/stddev).
25
+ For InstanceNorm, we do not use learnable affine parameters. We do not track running statistics.
26
+ """
27
+ if norm_type == 'batch':
28
+ norm_layer = functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
29
+ elif norm_type == 'instance':
30
+ norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
31
+ elif norm_type == 'none':
32
+ def norm_layer(x): return Identity()
33
+ else:
34
+ raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
35
+ return norm_layer
36
+
37
+
38
+ def get_scheduler(optimizer, opt):
39
+ """Return a learning rate scheduler
40
+
41
+ Parameters:
42
+ optimizer -- the optimizer of the network
43
+ opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions.
44
+ opt.lr_policy is the name of learning rate policy: linear | step | plateau | cosine
45
+
46
+ For 'linear', we keep the same learning rate for the first <opt.n_epochs> epochs
47
+ and linearly decay the rate to zero over the next <opt.n_epochs_decay> epochs.
48
+ For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers.
49
+ See https://pytorch.org/docs/stable/optim.html for more details.
50
+ """
51
+ if opt.lr_policy == 'linear':
52
+ def lambda_rule(epoch):
53
+ lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.n_epochs) / float(opt.n_epochs_decay + 1)
54
+ return lr_l
55
+ scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
56
+ elif opt.lr_policy == 'step':
57
+ scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
58
+ elif opt.lr_policy == 'plateau':
59
+ scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
60
+ elif opt.lr_policy == 'cosine':
61
+ scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.n_epochs, eta_min=0)
62
+ else:
63
+ raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy)
64
+ return scheduler
65
+
66
+
67
+ def init_weights(net, init_type='normal', init_gain=0.02):
68
+ """Initialize network weights.
69
+
70
+ Parameters:
71
+ net (network) -- network to be initialized
72
+ init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
73
+ init_gain (float) -- scaling factor for normal, xavier and orthogonal.
74
+
75
+ We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might
76
+ work better for some applications. Feel free to try yourself.
77
+ """
78
+ def init_func(m): # define the initialization function
79
+ classname = m.__class__.__name__
80
+ if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
81
+ if init_type == 'normal':
82
+ init.normal_(m.weight.data, 0.0, init_gain)
83
+ elif init_type == 'xavier':
84
+ init.xavier_normal_(m.weight.data, gain=init_gain)
85
+ elif init_type == 'kaiming':
86
+ init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
87
+ elif init_type == 'orthogonal':
88
+ init.orthogonal_(m.weight.data, gain=init_gain)
89
+ else:
90
+ raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
91
+ if hasattr(m, 'bias') and m.bias is not None:
92
+ init.constant_(m.bias.data, 0.0)
93
+ elif classname.find('BatchNorm2d') != -1: # BatchNorm Layer's weight is not a matrix; only normal distribution applies.
94
+ init.normal_(m.weight.data, 1.0, init_gain)
95
+ init.constant_(m.bias.data, 0.0)
96
+
97
+ # print('initialize network with %s' % init_type)
98
+ net.apply(init_func) # apply the initialization function <init_func>
99
+
100
+
101
+ def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]):
102
+ """Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights
103
+ Parameters:
104
+ net (network) -- the network to be initialized
105
+ init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
106
+ gain (float) -- scaling factor for normal, xavier and orthogonal.
107
+ gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
108
+
109
+ Return an initialized network.
110
+ """
111
+ if len(gpu_ids) > 0:
112
+ assert(torch.cuda.is_available())
113
+ net.to(gpu_ids[0])
114
+ net = torch.nn.DataParallel(net, gpu_ids) # multi-GPUs
115
+ init_weights(net, init_type, init_gain=init_gain)
116
+ return net
117
+
118
+
119
+ def define_G(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, init_type='normal', init_gain=0.02, gpu_ids=[]):
120
+ """Create a generator
121
+
122
+ Parameters:
123
+ input_nc (int) -- the number of channels in input images
124
+ output_nc (int) -- the number of channels in output images
125
+ ngf (int) -- the number of filters in the last conv layer
126
+ netG (str) -- the architecture's name: resnet_9blocks | resnet_6blocks | unet_256 | unet_128
127
+ norm (str) -- the name of normalization layers used in the network: batch | instance | none
128
+ use_dropout (bool) -- if use dropout layers.
129
+ init_type (str) -- the name of our initialization method.
130
+ init_gain (float) -- scaling factor for normal, xavier and orthogonal.
131
+ gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
132
+
133
+ Returns a generator
134
+
135
+ Our current implementation provides two types of generators:
136
+ U-Net: [unet_128] (for 128x128 input images) and [unet_256] (for 256x256 input images)
137
+ The original U-Net paper: https://arxiv.org/abs/1505.04597
138
+
139
+ Resnet-based generator: [resnet_6blocks] (with 6 Resnet blocks) and [resnet_9blocks] (with 9 Resnet blocks)
140
+ Resnet-based generator consists of several Resnet blocks between a few downsampling/upsampling operations.
141
+ We adapt Torch code from Justin Johnson's neural style transfer project (https://github.com/jcjohnson/fast-neural-style).
142
+
143
+
144
+ The generator has been initialized by <init_net>. It uses ReLU for non-linearity.
145
+ """
146
+ net = None
147
+ norm_layer = get_norm_layer(norm_type=norm)
148
+
149
+ if netG == 'resnet_9blocks':
150
+ net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9)
151
+ elif netG == 'resnet_6blocks':
152
+ net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=6)
153
+ elif netG == 'resnet_12blocks':
154
+ net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=12)
155
+ elif netG == 'unet_128':
156
+ net = UnetGenerator(input_nc, output_nc, 7, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
157
+ elif netG == 'unet_256':
158
+ net = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
159
+ elif netG == 'unet_672':
160
+ net = UnetGenerator(input_nc, output_nc, 5, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
161
+ elif netG == 'unet_960':
162
+ net = UnetGenerator(input_nc, output_nc, 6, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
163
+ elif netG == 'unet_1024':
164
+ net = UnetGenerator(input_nc, output_nc, 10, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
165
+ else:
166
+ raise NotImplementedError('Generator model name [%s] is not recognized' % netG)
167
+ return init_net(net, init_type, init_gain, gpu_ids)
168
+
169
+
170
+ def define_D(input_nc, ndf, netD, n_layers_D=3, norm='batch', init_type='normal', init_gain=0.02, gpu_ids=[]):
171
+ """Create a discriminator
172
+
173
+ Parameters:
174
+ input_nc (int) -- the number of channels in input images
175
+ ndf (int) -- the number of filters in the first conv layer
176
+ netD (str) -- the architecture's name: basic | n_layers | pixel
177
+ n_layers_D (int) -- the number of conv layers in the discriminator; effective when netD=='n_layers'
178
+ norm (str) -- the type of normalization layers used in the network.
179
+ init_type (str) -- the name of the initialization method.
180
+ init_gain (float) -- scaling factor for normal, xavier and orthogonal.
181
+ gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
182
+
183
+ Returns a discriminator
184
+
185
+ Our current implementation provides three types of discriminators:
186
+ [basic]: 'PatchGAN' classifier described in the original pix2pix paper.
187
+ It can classify whether 70×70 overlapping patches are real or fake.
188
+ Such a patch-level discriminator architecture has fewer parameters
189
+ than a full-image discriminator and can work on arbitrarily-sized images
190
+ in a fully convolutional fashion.
191
+
192
+ [n_layers]: With this mode, you can specify the number of conv layers in the discriminator
193
+ with the parameter <n_layers_D> (default=3 as used in [basic] (PatchGAN).)
194
+
195
+ [pixel]: 1x1 PixelGAN discriminator can classify whether a pixel is real or not.
196
+ It encourages greater color diversity but has no effect on spatial statistics.
197
+
198
+ The discriminator has been initialized by <init_net>. It uses Leaky ReLU for non-linearity.
199
+ """
200
+ net = None
201
+ norm_layer = get_norm_layer(norm_type=norm)
202
+
203
+ if netD == 'basic': # default PatchGAN classifier
204
+ net = NLayerDiscriminator(input_nc, ndf, n_layers=3, norm_layer=norm_layer)
205
+ elif netD == 'n_layers': # more options
206
+ net = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer)
207
+ elif netD == 'pixel': # classify if each pixel is real or fake
208
+ net = PixelDiscriminator(input_nc, ndf, norm_layer=norm_layer)
209
+ else:
210
+ raise NotImplementedError('Discriminator model name [%s] is not recognized' % netD)
211
+ return init_net(net, init_type, init_gain, gpu_ids)
212
+
213
+
214
+ ##############################################################################
215
+ # Classes
216
+ ##############################################################################
217
+ class GANLoss(nn.Module):
218
+ """Define different GAN objectives.
219
+
220
+ The GANLoss class abstracts away the need to create the target label tensor
221
+ that has the same size as the input.
222
+ """
223
+
224
+ def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0):
225
+ """ Initialize the GANLoss class.
226
+
227
+ Parameters:
228
+ gan_mode (str) - - the type of GAN objective. It currently supports vanilla, lsgan, and wgangp.
229
+ target_real_label (bool) - - label for a real image
230
+ target_fake_label (bool) - - label of a fake image
231
+
232
+ Note: Do not use sigmoid as the last layer of Discriminator.
233
+ LSGAN needs no sigmoid. vanilla GANs will handle it with BCEWithLogitsLoss.
234
+ """
235
+ super(GANLoss, self).__init__()
236
+ self.register_buffer('real_label', torch.tensor(target_real_label))
237
+ self.register_buffer('fake_label', torch.tensor(target_fake_label))
238
+ self.gan_mode = gan_mode
239
+ if gan_mode == 'lsgan':
240
+ self.loss = nn.MSELoss()
241
+ elif gan_mode == 'vanilla':
242
+ self.loss = nn.BCEWithLogitsLoss()
243
+ elif gan_mode in ['wgangp']:
244
+ self.loss = None
245
+ else:
246
+ raise NotImplementedError('gan mode %s not implemented' % gan_mode)
247
+
248
+ def get_target_tensor(self, prediction, target_is_real):
249
+ """Create label tensors with the same size as the input.
250
+
251
+ Parameters:
252
+ prediction (tensor) - - typically the prediction from a discriminator
253
+ target_is_real (bool) - - if the ground truth label is for real images or fake images
254
+
255
+ Returns:
256
+ A label tensor filled with ground truth label, and with the size of the input
257
+ """
258
+
259
+ if target_is_real:
260
+ target_tensor = self.real_label
261
+ else:
262
+ target_tensor = self.fake_label
263
+ return target_tensor.expand_as(prediction)
264
+
265
+ def __call__(self, prediction, target_is_real):
266
+ """Calculate loss given Discriminator's output and grount truth labels.
267
+
268
+ Parameters:
269
+ prediction (tensor) - - typically the prediction output from a discriminator
270
+ target_is_real (bool) - - if the ground truth label is for real images or fake images
271
+
272
+ Returns:
273
+ the calculated loss.
274
+ """
275
+ if self.gan_mode in ['lsgan', 'vanilla']:
276
+ target_tensor = self.get_target_tensor(prediction, target_is_real)
277
+ loss = self.loss(prediction, target_tensor)
278
+ elif self.gan_mode == 'wgangp':
279
+ if target_is_real:
280
+ loss = -prediction.mean()
281
+ else:
282
+ loss = prediction.mean()
283
+ return loss
284
+
285
+
286
+ def cal_gradient_penalty(netD, real_data, fake_data, device, type='mixed', constant=1.0, lambda_gp=10.0):
287
+ """Calculate the gradient penalty loss, used in WGAN-GP paper https://arxiv.org/abs/1704.00028
288
+
289
+ Arguments:
290
+ netD (network) -- discriminator network
291
+ real_data (tensor array) -- real images
292
+ fake_data (tensor array) -- generated images from the generator
293
+ device (str) -- GPU / CPU: from torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')
294
+ type (str) -- if we mix real and fake data or not [real | fake | mixed].
295
+ constant (float) -- the constant used in formula ( ||gradient||_2 - constant)^2
296
+ lambda_gp (float) -- weight for this loss
297
+
298
+ Returns the gradient penalty loss
299
+ """
300
+ if lambda_gp > 0.0:
301
+ if type == 'real': # either use real images, fake images, or a linear interpolation of two.
302
+ interpolatesv = real_data
303
+ elif type == 'fake':
304
+ interpolatesv = fake_data
305
+ elif type == 'mixed':
306
+ alpha = torch.rand(real_data.shape[0], 1, device=device)
307
+ alpha = alpha.expand(real_data.shape[0], real_data.nelement() // real_data.shape[0]).contiguous().view(*real_data.shape)
308
+ interpolatesv = alpha * real_data + ((1 - alpha) * fake_data)
309
+ else:
310
+ raise NotImplementedError('{} not implemented'.format(type))
311
+ interpolatesv.requires_grad_(True)
312
+ disc_interpolates = netD(interpolatesv)
313
+ gradients = torch.autograd.grad(outputs=disc_interpolates, inputs=interpolatesv,
314
+ grad_outputs=torch.ones(disc_interpolates.size()).to(device),
315
+ create_graph=True, retain_graph=True, only_inputs=True)
316
+ gradients = gradients[0].view(real_data.size(0), -1) # flatten the data
317
+ gradient_penalty = (((gradients + 1e-16).norm(2, dim=1) - constant) ** 2).mean() * lambda_gp # added eps
318
+ return gradient_penalty, gradients
319
+ else:
320
+ return 0.0, None
321
+
322
+
323
+ class ResnetGenerator(nn.Module):
324
+ """Resnet-based generator that consists of Resnet blocks between a few downsampling/upsampling operations.
325
+
326
+ We adapt Torch code and idea from Justin Johnson's neural style transfer project(https://github.com/jcjohnson/fast-neural-style)
327
+ """
328
+
329
+ def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect'):
330
+ """Construct a Resnet-based generator
331
+
332
+ Parameters:
333
+ input_nc (int) -- the number of channels in input images
334
+ output_nc (int) -- the number of channels in output images
335
+ ngf (int) -- the number of filters in the last conv layer
336
+ norm_layer -- normalization layer
337
+ use_dropout (bool) -- if use dropout layers
338
+ n_blocks (int) -- the number of ResNet blocks
339
+ padding_type (str) -- the name of padding layer in conv layers: reflect | replicate | zero
340
+ """
341
+ assert(n_blocks >= 0)
342
+ super(ResnetGenerator, self).__init__()
343
+ if type(norm_layer) == functools.partial:
344
+ use_bias = norm_layer.func == nn.InstanceNorm2d
345
+ else:
346
+ use_bias = norm_layer == nn.InstanceNorm2d
347
+
348
+ model = [nn.ReflectionPad2d(3),
349
+ nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
350
+ norm_layer(ngf),
351
+ nn.ReLU(True)]
352
+
353
+ n_downsampling = 2
354
+ for i in range(n_downsampling): # add downsampling layers
355
+ mult = 2 ** i
356
+ model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),
357
+ norm_layer(ngf * mult * 2),
358
+ nn.ReLU(True)]
359
+
360
+ mult = 2 ** n_downsampling
361
+ for i in range(n_blocks): # add ResNet blocks
362
+
363
+ model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]
364
+
365
+ for i in range(n_downsampling): # add upsampling layers
366
+ mult = 2 ** (n_downsampling - i)
367
+ model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
368
+ kernel_size=3, stride=2,
369
+ padding=1, output_padding=1,
370
+ bias=use_bias),
371
+ norm_layer(int(ngf * mult / 2)),
372
+ nn.ReLU(True)]
373
+ model += [nn.ReflectionPad2d(3)]
374
+ model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
375
+ model += [nn.Tanh()]
376
+
377
+ self.model = nn.Sequential(*model)
378
+
379
+ def forward(self, input):
380
+ """Standard forward"""
381
+ return self.model(input)
382
+
383
+
384
+ class ResnetBlock(nn.Module):
385
+ """Define a Resnet block"""
386
+
387
+ def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
388
+ """Initialize the Resnet block
389
+
390
+ A resnet block is a conv block with skip connections
391
+ We construct a conv block with build_conv_block function,
392
+ and implement skip connections in <forward> function.
393
+ Original Resnet paper: https://arxiv.org/pdf/1512.03385.pdf
394
+ """
395
+ super(ResnetBlock, self).__init__()
396
+ self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias)
397
+
398
+ def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
399
+ """Construct a convolutional block.
400
+
401
+ Parameters:
402
+ dim (int) -- the number of channels in the conv layer.
403
+ padding_type (str) -- the name of padding layer: reflect | replicate | zero
404
+ norm_layer -- normalization layer
405
+ use_dropout (bool) -- if use dropout layers.
406
+ use_bias (bool) -- if the conv layer uses bias or not
407
+
408
+ Returns a conv block (with a conv layer, a normalization layer, and a non-linearity layer (ReLU))
409
+ """
410
+ conv_block = []
411
+ p = 0
412
+ if padding_type == 'reflect':
413
+ conv_block += [nn.ReflectionPad2d(1)]
414
+ elif padding_type == 'replicate':
415
+ conv_block += [nn.ReplicationPad2d(1)]
416
+ elif padding_type == 'zero':
417
+ p = 1
418
+ else:
419
+ raise NotImplementedError('padding [%s] is not implemented' % padding_type)
420
+
421
+ conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim), nn.ReLU(True)]
422
+ if use_dropout:
423
+ conv_block += [nn.Dropout(0.5)]
424
+
425
+ p = 0
426
+ if padding_type == 'reflect':
427
+ conv_block += [nn.ReflectionPad2d(1)]
428
+ elif padding_type == 'replicate':
429
+ conv_block += [nn.ReplicationPad2d(1)]
430
+ elif padding_type == 'zero':
431
+ p = 1
432
+ else:
433
+ raise NotImplementedError('padding [%s] is not implemented' % padding_type)
434
+ conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim)]
435
+
436
+ return nn.Sequential(*conv_block)
437
+
438
+ def forward(self, x):
439
+ """Forward function (with skip connections)"""
440
+ out = x + self.conv_block(x) # add skip connections
441
+ return out
442
+
443
+
444
+ class UnetGenerator(nn.Module):
445
+ """Create a Unet-based generator"""
446
+
447
+ def __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False):
448
+ """Construct a Unet generator
449
+ Parameters:
450
+ input_nc (int) -- the number of channels in input images
451
+ output_nc (int) -- the number of channels in output images
452
+ num_downs (int) -- the number of downsamplings in UNet. For example, # if |num_downs| == 7,
453
+ image of size 128x128 will become of size 1x1 # at the bottleneck
454
+ ngf (int) -- the number of filters in the last conv layer
455
+ norm_layer -- normalization layer
456
+
457
+ We construct the U-Net from the innermost layer to the outermost layer.
458
+ It is a recursive process.
459
+ """
460
+ super(UnetGenerator, self).__init__()
461
+ # construct unet structure
462
+ unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True) # add the innermost layer
463
+ for i in range(num_downs - 5): # add intermediate layers with ngf * 8 filters
464
+ unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout)
465
+ # gradually reduce the number of filters from ngf * 8 to ngf
466
+ unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
467
+ unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
468
+ unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
469
+ self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer) # add the outermost layer
470
+
471
+ def forward(self, input):
472
+ """Standard forward"""
473
+ return self.model(input)
474
+
475
+
476
+ class UnetSkipConnectionBlock(nn.Module):
477
+ """Defines the Unet submodule with skip connection.
478
+ X -------------------identity----------------------
479
+ |-- downsampling -- |submodule| -- upsampling --|
480
+ """
481
+
482
+ def __init__(self, outer_nc, inner_nc, input_nc=None,
483
+ submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
484
+ """Construct a Unet submodule with skip connections.
485
+
486
+ Parameters:
487
+ outer_nc (int) -- the number of filters in the outer conv layer
488
+ inner_nc (int) -- the number of filters in the inner conv layer
489
+ input_nc (int) -- the number of channels in input images/features
490
+ submodule (UnetSkipConnectionBlock) -- previously defined submodules
491
+ outermost (bool) -- if this module is the outermost module
492
+ innermost (bool) -- if this module is the innermost module
493
+ norm_layer -- normalization layer
494
+ use_dropout (bool) -- if use dropout layers.
495
+ """
496
+ super(UnetSkipConnectionBlock, self).__init__()
497
+ self.outermost = outermost
498
+ if type(norm_layer) == functools.partial:
499
+ use_bias = norm_layer.func == nn.InstanceNorm2d
500
+ else:
501
+ use_bias = norm_layer == nn.InstanceNorm2d
502
+ if input_nc is None:
503
+ input_nc = outer_nc
504
+ downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
505
+ stride=2, padding=1, bias=use_bias)
506
+ downrelu = nn.LeakyReLU(0.2, True)
507
+ downnorm = norm_layer(inner_nc)
508
+ uprelu = nn.ReLU(True)
509
+ upnorm = norm_layer(outer_nc)
510
+
511
+ if outermost:
512
+ upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
513
+ kernel_size=4, stride=2,
514
+ padding=1)
515
+ down = [downconv]
516
+ up = [uprelu, upconv, nn.Tanh()]
517
+ model = down + [submodule] + up
518
+ elif innermost:
519
+ upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
520
+ kernel_size=4, stride=2,
521
+ padding=1, bias=use_bias)
522
+ down = [downrelu, downconv]
523
+ up = [uprelu, upconv, upnorm]
524
+ model = down + up
525
+ else:
526
+ upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
527
+ kernel_size=4, stride=2,
528
+ padding=1, bias=use_bias)
529
+ down = [downrelu, downconv, downnorm]
530
+ up = [uprelu, upconv, upnorm]
531
+
532
+ if use_dropout:
533
+ model = down + [submodule] + up + [nn.Dropout(0.5)]
534
+ else:
535
+ model = down + [submodule] + up
536
+
537
+ self.model = nn.Sequential(*model)
538
+
539
+ def forward(self, x):
540
+ if self.outermost:
541
+ return self.model(x)
542
+ else: # add skip connections
543
+ return torch.cat([x, self.model(x)], 1)
544
+
545
+
546
+ class NLayerDiscriminator(nn.Module):
547
+ """Defines a PatchGAN discriminator"""
548
+
549
+ def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
550
+ """Construct a PatchGAN discriminator
551
+
552
+ Parameters:
553
+ input_nc (int) -- the number of channels in input images
554
+ ndf (int) -- the number of filters in the last conv layer
555
+ n_layers (int) -- the number of conv layers in the discriminator
556
+ norm_layer -- normalization layer
557
+ """
558
+ super(NLayerDiscriminator, self).__init__()
559
+ if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
560
+ use_bias = norm_layer.func == nn.InstanceNorm2d
561
+ else:
562
+ use_bias = norm_layer == nn.InstanceNorm2d
563
+
564
+ kw = 4
565
+ padw = 1
566
+ sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
567
+ nf_mult = 1
568
+ nf_mult_prev = 1
569
+ for n in range(1, n_layers): # gradually increase the number of filters
570
+ nf_mult_prev = nf_mult
571
+ nf_mult = min(2 ** n, 8)
572
+ sequence += [
573
+ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
574
+ norm_layer(ndf * nf_mult),
575
+ nn.LeakyReLU(0.2, True)
576
+ ]
577
+
578
+ nf_mult_prev = nf_mult
579
+ nf_mult = min(2 ** n_layers, 8)
580
+ sequence += [
581
+ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
582
+ norm_layer(ndf * nf_mult),
583
+ nn.LeakyReLU(0.2, True)
584
+ ]
585
+
586
+ sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map
587
+ self.model = nn.Sequential(*sequence)
588
+
589
+ def forward(self, input):
590
+ """Standard forward."""
591
+ return self.model(input)
592
+
593
+
594
+ class PixelDiscriminator(nn.Module):
595
+ """Defines a 1x1 PatchGAN discriminator (pixelGAN)"""
596
+
597
+ def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d):
598
+ """Construct a 1x1 PatchGAN discriminator
599
+
600
+ Parameters:
601
+ input_nc (int) -- the number of channels in input images
602
+ ndf (int) -- the number of filters in the last conv layer
603
+ norm_layer -- normalization layer
604
+ """
605
+ super(PixelDiscriminator, self).__init__()
606
+ if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
607
+ use_bias = norm_layer.func == nn.InstanceNorm2d
608
+ else:
609
+ use_bias = norm_layer == nn.InstanceNorm2d
610
+
611
+ self.net = [
612
+ nn.Conv2d(input_nc, ndf, kernel_size=1, stride=1, padding=0),
613
+ nn.LeakyReLU(0.2, True),
614
+ nn.Conv2d(ndf, ndf * 2, kernel_size=1, stride=1, padding=0, bias=use_bias),
615
+ norm_layer(ndf * 2),
616
+ nn.LeakyReLU(0.2, True),
617
+ nn.Conv2d(ndf * 2, 1, kernel_size=1, stride=1, padding=0, bias=use_bias)]
618
+
619
+ self.net = nn.Sequential(*self.net)
620
+
621
+ def forward(self, input):
622
+ """Standard forward."""
623
+ return self.net(input)
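As a quick sanity check of the blocks above, the following sketch (not part of this commit) composes the skip blocks into a small 5-level U-Net the same way UnetGenerator does and runs random tensors through it and through a 3-layer PatchGAN; the import path assumes this repository's package layout and the input sizes are arbitrary.

import torch
from annotator.leres.pix2pix.models.networks import UnetSkipConnectionBlock, NLayerDiscriminator

# build from the innermost block outwards, exactly as UnetGenerator does
ngf = 64
block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, innermost=True)
block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, submodule=block)
block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, submodule=block)
block = UnetSkipConnectionBlock(ngf, ngf * 2, submodule=block)
net = UnetSkipConnectionBlock(1, ngf, input_nc=2, submodule=block, outermost=True)

x = torch.randn(1, 2, 64, 64)                    # 5 stride-2 downsamplings fit into a 64x64 input
print(net(x).shape)                              # torch.Size([1, 1, 64, 64])

disc = NLayerDiscriminator(input_nc=3, ndf=64, n_layers=3)
print(disc(torch.randn(1, 3, 256, 256)).shape)   # one-channel patch prediction map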
annotator/leres/pix2pix/models/pix2pix4depth_model.py ADDED
@@ -0,0 +1,155 @@
1
+ import torch
2
+ from .base_model import BaseModel
3
+ from . import networks
4
+
5
+
6
+ class Pix2Pix4DepthModel(BaseModel):
7
+ """ This class implements the pix2pix model, for learning a mapping from input images to output images given paired data.
8
+
9
+ The model training requires '--dataset_mode aligned' dataset.
10
+ By default, it uses a '--netG unet256' U-Net generator,
11
+ a '--netD basic' discriminator (PatchGAN),
12
+ and a '--gan_mode' vanilla GAN loss (the cross-entropy objective used in the original GAN paper).
13
+
14
+ pix2pix paper: https://arxiv.org/pdf/1611.07004.pdf
15
+ """
16
+ @staticmethod
17
+ def modify_commandline_options(parser, is_train=True):
18
+ """Add new dataset-specific options, and rewrite default values for existing options.
19
+
20
+ Parameters:
21
+ parser -- original option parser
22
+ is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
23
+
24
+ Returns:
25
+ the modified parser.
26
+
27
+ For pix2pix, we do not use image buffer
28
+ The training objective is: GAN Loss + lambda_L1 * ||G(A)-B||_1
29
+ By default, we use vanilla GAN loss, UNet with batchnorm, and aligned datasets.
30
+ """
31
+ # changing the default values to match the pix2pix paper (https://phillipi.github.io/pix2pix/)
32
+ parser.set_defaults(input_nc=2,output_nc=1,norm='none', netG='unet_1024', dataset_mode='depthmerge')
33
+ if is_train:
34
+ parser.set_defaults(pool_size=0, gan_mode='vanilla',)
35
+ parser.add_argument('--lambda_L1', type=float, default=1000, help='weight for L1 loss')
36
+ return parser
37
+
38
+ def __init__(self, opt):
39
+ """Initialize the pix2pix class.
40
+
41
+ Parameters:
42
+ opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
43
+ """
44
+ BaseModel.__init__(self, opt)
45
+ # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
46
+
47
+ self.loss_names = ['G_GAN', 'G_L1', 'D_real', 'D_fake']
48
+ # self.loss_names = ['G_L1']
49
+
50
+ # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
51
+ if self.isTrain:
52
+ self.visual_names = ['outer','inner', 'fake_B', 'real_B']
53
+ else:
54
+ self.visual_names = ['fake_B']
55
+
56
+ # specify the models you want to save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>
57
+ if self.isTrain:
58
+ self.model_names = ['G','D']
59
+ else: # during test time, only load G
60
+ self.model_names = ['G']
61
+
62
+ # define networks (both generator and discriminator)
63
+ self.netG = networks.define_G(opt.input_nc, opt.output_nc, 64, 'unet_1024', 'none',
64
+ False, 'normal', 0.02, self.gpu_ids)
65
+
66
+ if self.isTrain: # define a discriminator; conditional GANs need to take both input and output images; Therefore, #channels for D is input_nc + output_nc
67
+ self.netD = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
68
+ opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
69
+
70
+ if self.isTrain:
71
+ # define loss functions
72
+ self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device)
73
+ self.criterionL1 = torch.nn.L1Loss()
74
+ # initialize optimizers; schedulers will be automatically created by function <BaseModel.setup>.
75
+ self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=1e-4, betas=(opt.beta1, 0.999))
76
+ self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=2e-06, betas=(opt.beta1, 0.999))
77
+ self.optimizers.append(self.optimizer_G)
78
+ self.optimizers.append(self.optimizer_D)
79
+
80
+ def set_input_train(self, input):
81
+ self.outer = input['data_outer'].to(self.device)
82
+ self.outer = torch.nn.functional.interpolate(self.outer,(1024,1024),mode='bilinear',align_corners=False)
83
+
84
+ self.inner = input['data_inner'].to(self.device)
85
+ self.inner = torch.nn.functional.interpolate(self.inner,(1024,1024),mode='bilinear',align_corners=False)
86
+
87
+ self.image_paths = input['image_path']
88
+
89
+ if self.isTrain:
90
+ self.gtfake = input['data_gtfake'].to(self.device)
91
+ self.gtfake = torch.nn.functional.interpolate(self.gtfake, (1024, 1024), mode='bilinear', align_corners=False)
92
+ self.real_B = self.gtfake
93
+
94
+ self.real_A = torch.cat((self.outer, self.inner), 1)
95
+
96
+ def set_input(self, outer, inner):
97
+ inner = torch.from_numpy(inner).unsqueeze(0).unsqueeze(0)
98
+ outer = torch.from_numpy(outer).unsqueeze(0).unsqueeze(0)
99
+
100
+ inner = (inner - torch.min(inner))/(torch.max(inner)-torch.min(inner))
101
+ outer = (outer - torch.min(outer))/(torch.max(outer)-torch.min(outer))
102
+
103
+ inner = self.normalize(inner)
104
+ outer = self.normalize(outer)
105
+
106
+ self.real_A = torch.cat((outer, inner), 1).to(self.device)
107
+
108
+
109
+ def normalize(self, input):
110
+ input = input * 2
111
+ input = input - 1
112
+ return input
113
+
114
+ def forward(self):
115
+ """Run forward pass; called by both functions <optimize_parameters> and <test>."""
116
+ self.fake_B = self.netG(self.real_A) # G(A)
117
+
118
+ def backward_D(self):
119
+ """Calculate GAN loss for the discriminator"""
120
+ # Fake; stop backprop to the generator by detaching fake_B
121
+ fake_AB = torch.cat((self.real_A, self.fake_B), 1) # we use conditional GANs; we need to feed both input and output to the discriminator
122
+ pred_fake = self.netD(fake_AB.detach())
123
+ self.loss_D_fake = self.criterionGAN(pred_fake, False)
124
+ # Real
125
+ real_AB = torch.cat((self.real_A, self.real_B), 1)
126
+ pred_real = self.netD(real_AB)
127
+ self.loss_D_real = self.criterionGAN(pred_real, True)
128
+ # combine loss and calculate gradients
129
+ self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
130
+ self.loss_D.backward()
131
+
132
+ def backward_G(self):
133
+ """Calculate GAN and L1 loss for the generator"""
134
+ # First, G(A) should fake the discriminator
135
+ fake_AB = torch.cat((self.real_A, self.fake_B), 1)
136
+ pred_fake = self.netD(fake_AB)
137
+ self.loss_G_GAN = self.criterionGAN(pred_fake, True)
138
+ # Second, G(A) = B
139
+ self.loss_G_L1 = self.criterionL1(self.fake_B, self.real_B) * self.opt.lambda_L1
140
+ # combine loss and calculate gradients
141
+ self.loss_G = self.loss_G_L1 + self.loss_G_GAN
142
+ self.loss_G.backward()
143
+
144
+ def optimize_parameters(self):
145
+ self.forward() # compute fake images: G(A)
146
+ # update D
147
+ self.set_requires_grad(self.netD, True) # enable backprop for D
148
+ self.optimizer_D.zero_grad() # set D's gradients to zero
149
+ self.backward_D() # calculate gradients for D
150
+ self.optimizer_D.step() # update D's weights
151
+ # update G
152
+ self.set_requires_grad(self.netD, False) # D requires no gradients when optimizing G
153
+ self.optimizer_G.zero_grad() # set G's gradients to zero
154
+ self.backward_G() # calculate gradients for G
155
+ self.optimizer_G.step() # update G's weights
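For reference, the sketch below (not part of this commit) replays the preprocessing that set_input/normalize perform on a hypothetical numpy depth map before it reaches netG.

import numpy as np
import torch

depth = np.random.rand(1024, 1024).astype(np.float32)    # a hypothetical single-channel estimate
t = torch.from_numpy(depth).unsqueeze(0).unsqueeze(0)     # -> (1, 1, H, W)
t = (t - t.min()) / (t.max() - t.min())                   # rescale to [0, 1]
t = t * 2 - 1                                             # normalize(): map [0, 1] -> [-1, 1]
real_A = torch.cat((t, t.clone()), 1)                     # outer and inner maps stacked to 2 channels
print(real_A.shape, float(real_A.min()), float(real_A.max()))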
annotator/leres/pix2pix/options/__init__.py ADDED
@@ -0,0 +1 @@
1
+ """This package options includes option modules: training options, test options, and basic options (used in both training and test)."""
annotator/leres/pix2pix/options/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (341 Bytes).
annotator/leres/pix2pix/options/__pycache__/base_options.cpython-39.pyc ADDED
Binary file (7.2 kB).
annotator/leres/pix2pix/options/__pycache__/test_options.cpython-39.pyc ADDED
Binary file (1.15 kB).
annotator/leres/pix2pix/options/base_options.py ADDED
@@ -0,0 +1,156 @@
1
+ import argparse
2
+ import os
3
+ from ...pix2pix.util import util
4
+ # import torch
5
+ from ...pix2pix import models
6
+ # import pix2pix.data
7
+ import numpy as np
8
+
9
+ class BaseOptions():
10
+ """This class defines options used during both training and test time.
11
+
12
+ It also implements several helper functions such as parsing, printing, and saving the options.
13
+ It also gathers additional options defined in <modify_commandline_options> functions in both dataset class and model class.
14
+ """
15
+
16
+ def __init__(self):
17
+ """Reset the class; indicates the class hasn't been initailized"""
18
+ self.initialized = False
19
+
20
+ def initialize(self, parser):
21
+ """Define the common options that are used in both training and test."""
22
+ # basic parameters
23
+ parser.add_argument('--dataroot', help='path to images (should have subfolders trainA, trainB, valA, valB, etc)')
24
+ parser.add_argument('--name', type=str, default='void', help='mahdi_unet_new, scaled_unet')
25
+ parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU')
26
+ parser.add_argument('--checkpoints_dir', type=str, default='./pix2pix/checkpoints', help='models are saved here')
27
+ # model parameters
28
+ parser.add_argument('--model', type=str, default='cycle_gan', help='chooses which model to use. [cycle_gan | pix2pix | test | colorization]')
29
+ parser.add_argument('--input_nc', type=int, default=2, help='# of input image channels: 3 for RGB and 1 for grayscale')
30
+ parser.add_argument('--output_nc', type=int, default=1, help='# of output image channels: 3 for RGB and 1 for grayscale')
31
+ parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in the last conv layer')
32
+ parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in the first conv layer')
33
+ parser.add_argument('--netD', type=str, default='basic', help='specify discriminator architecture [basic | n_layers | pixel]. The basic model is a 70x70 PatchGAN. n_layers allows you to specify the layers in the discriminator')
34
+ parser.add_argument('--netG', type=str, default='resnet_9blocks', help='specify generator architecture [resnet_9blocks | resnet_6blocks | unet_256 | unet_128]')
35
+ parser.add_argument('--n_layers_D', type=int, default=3, help='only used if netD==n_layers')
36
+ parser.add_argument('--norm', type=str, default='instance', help='instance normalization or batch normalization [instance | batch | none]')
37
+ parser.add_argument('--init_type', type=str, default='normal', help='network initialization [normal | xavier | kaiming | orthogonal]')
38
+ parser.add_argument('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.')
39
+ parser.add_argument('--no_dropout', action='store_true', help='no dropout for the generator')
40
+ # dataset parameters
41
+ parser.add_argument('--dataset_mode', type=str, default='unaligned', help='chooses how datasets are loaded. [unaligned | aligned | single | colorization]')
42
+ parser.add_argument('--direction', type=str, default='AtoB', help='AtoB or BtoA')
43
+ parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')
44
+ parser.add_argument('--num_threads', default=4, type=int, help='# threads for loading data')
45
+ parser.add_argument('--batch_size', type=int, default=1, help='input batch size')
46
+ parser.add_argument('--load_size', type=int, default=672, help='scale images to this size')
47
+ parser.add_argument('--crop_size', type=int, default=672, help='then crop to this size')
48
+ parser.add_argument('--max_dataset_size', type=int, default=10000, help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')
49
+ parser.add_argument('--preprocess', type=str, default='resize_and_crop', help='scaling and cropping of images at load time [resize_and_crop | crop | scale_width | scale_width_and_crop | none]')
50
+ parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation')
51
+ parser.add_argument('--display_winsize', type=int, default=256, help='display window size for both visdom and HTML')
52
+ # additional parameters
53
+ parser.add_argument('--epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
54
+ parser.add_argument('--load_iter', type=int, default='0', help='which iteration to load? if load_iter > 0, the code will load models by iter_[load_iter]; otherwise, the code will load models by [epoch]')
55
+ parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information')
56
+ parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}')
57
+
58
+ parser.add_argument('--data_dir', type=str, required=False,
59
+ help='input files directory; images can be .png, .jpg, or .tiff')
60
+ parser.add_argument('--output_dir', type=str, required=False,
61
+ help='result dir. result depth will be png. videos are MJPG as avi')
62
+ parser.add_argument('--savecrops', type=int, required=False)
63
+ parser.add_argument('--savewholeest', type=int, required=False)
64
+ parser.add_argument('--output_resolution', type=int, required=False,
65
+ help='0 for no restriction 1 for resize to input size')
66
+ parser.add_argument('--net_receptive_field_size', type=int, required=False)
67
+ parser.add_argument('--pix2pixsize', type=int, required=False)
68
+ parser.add_argument('--generatevideo', type=int, required=False)
69
+ parser.add_argument('--depthNet', type=int, required=False, help='0: midas 1: structuredRL')
70
+ parser.add_argument('--R0', action='store_true')
71
+ parser.add_argument('--R20', action='store_true')
72
+ parser.add_argument('--Final', action='store_true')
73
+ parser.add_argument('--colorize_results', action='store_true')
74
+ parser.add_argument('--max_res', type=float, default=np.inf)
75
+
76
+ self.initialized = True
77
+ return parser
78
+
79
+ def gather_options(self):
80
+ """Initialize our parser with basic options(only once).
81
+ Add additional model-specific and dataset-specific options.
82
+ These options are defined in the <modify_commandline_options> function
83
+ in model and dataset classes.
84
+ """
85
+ if not self.initialized: # check if it has been initialized
86
+ parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
87
+ parser = self.initialize(parser)
88
+
89
+ # get the basic options
90
+ opt, _ = parser.parse_known_args()
91
+
92
+ # modify model-related parser options
93
+ model_name = opt.model
94
+ model_option_setter = models.get_option_setter(model_name)
95
+ parser = model_option_setter(parser, self.isTrain)
96
+ opt, _ = parser.parse_known_args() # parse again with new defaults
97
+
98
+ # modify dataset-related parser options
99
+ # dataset_name = opt.dataset_mode
100
+ # dataset_option_setter = pix2pix.data.get_option_setter(dataset_name)
101
+ # parser = dataset_option_setter(parser, self.isTrain)
102
+
103
+ # save and return the parser
104
+ self.parser = parser
105
+ #return parser.parse_args() #EVIL
106
+ return opt
107
+
108
+ def print_options(self, opt):
109
+ """Print and save options
110
+
111
+ It will print both current options and default values (if different).
112
+ It will save options into a text file / [checkpoints_dir] / opt.txt
113
+ """
114
+ message = ''
115
+ message += '----------------- Options ---------------\n'
116
+ for k, v in sorted(vars(opt).items()):
117
+ comment = ''
118
+ default = self.parser.get_default(k)
119
+ if v != default:
120
+ comment = '\t[default: %s]' % str(default)
121
+ message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
122
+ message += '----------------- End -------------------'
123
+ print(message)
124
+
125
+ # save to the disk
126
+ expr_dir = os.path.join(opt.checkpoints_dir, opt.name)
127
+ util.mkdirs(expr_dir)
128
+ file_name = os.path.join(expr_dir, '{}_opt.txt'.format(opt.phase))
129
+ with open(file_name, 'wt') as opt_file:
130
+ opt_file.write(message)
131
+ opt_file.write('\n')
132
+
133
+ def parse(self):
134
+ """Parse our options, create checkpoints directory suffix, and set up gpu device."""
135
+ opt = self.gather_options()
136
+ opt.isTrain = self.isTrain # train or test
137
+
138
+ # process opt.suffix
139
+ if opt.suffix:
140
+ suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else ''
141
+ opt.name = opt.name + suffix
142
+
143
+ #self.print_options(opt)
144
+
145
+ # set gpu ids
146
+ str_ids = opt.gpu_ids.split(',')
147
+ opt.gpu_ids = []
148
+ for str_id in str_ids:
149
+ id = int(str_id)
150
+ if id >= 0:
151
+ opt.gpu_ids.append(id)
152
+ #if len(opt.gpu_ids) > 0:
153
+ # torch.cuda.set_device(opt.gpu_ids[0])
154
+
155
+ self.opt = opt
156
+ return self.opt
annotator/leres/pix2pix/options/test_options.py ADDED
@@ -0,0 +1,22 @@
1
+ from .base_options import BaseOptions
2
+
3
+
4
+ class TestOptions(BaseOptions):
5
+ """This class includes test options.
6
+
7
+ It also includes shared options defined in BaseOptions.
8
+ """
9
+
10
+ def initialize(self, parser):
11
+ parser = BaseOptions.initialize(self, parser) # define shared options
12
+ parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images')
13
+ parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc')
14
+ # Dropout and BatchNorm behave differently during training and test.
15
+ parser.add_argument('--eval', action='store_true', help='use eval mode during test time.')
16
+ parser.add_argument('--num_test', type=int, default=50, help='how many test images to run')
17
+ # rewrite default values
18
+ parser.set_defaults(model='pix2pix4depth')
19
+ # To avoid cropping, the load_size should be the same as crop_size
20
+ parser.set_defaults(load_size=parser.get_default('crop_size'))
21
+ self.isTrain = False
22
+ return parser
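A minimal usage sketch (not part of this commit, and the experiment name is made up): parsing the merged test options on the CPU; parse() pulls in the model-specific defaults registered by Pix2Pix4DepthModel.

import sys
from annotator.leres.pix2pix.options.test_options import TestOptions

sys.argv = ['demo', '--gpu_ids', '-1', '--name', 'mergemodel']   # -1 keeps everything on the CPU
opt = TestOptions().parse()                    # base options + pix2pix4depth defaults, isTrain=False
print(opt.netG, opt.input_nc, opt.output_nc)   # unet_1024 2 1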
annotator/leres/pix2pix/util/__init__.py ADDED
@@ -0,0 +1 @@
1
+ """This package includes a miscellaneous collection of useful helper functions."""
annotator/leres/pix2pix/util/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (285 Bytes).
annotator/leres/pix2pix/util/__pycache__/util.cpython-39.pyc ADDED
Binary file (3.01 kB).
annotator/leres/pix2pix/util/get_data.py ADDED
@@ -0,0 +1,110 @@
1
+ from __future__ import print_function
2
+ import os
3
+ import tarfile
4
+ import requests
5
+ from warnings import warn
6
+ from zipfile import ZipFile
7
+ from bs4 import BeautifulSoup
8
+ from os.path import abspath, isdir, join, basename
9
+
10
+
11
+ class GetData(object):
12
+ """A Python script for downloading CycleGAN or pix2pix datasets.
13
+
14
+ Parameters:
15
+ technique (str) -- One of: 'cyclegan' or 'pix2pix'.
16
+ verbose (bool) -- If True, print additional information.
17
+
18
+ Examples:
19
+ >>> from util.get_data import GetData
20
+ >>> gd = GetData(technique='cyclegan')
21
+ >>> new_data_path = gd.get(save_path='./datasets') # options will be displayed.
22
+
23
+ Alternatively, you can use the bash scripts: 'scripts/download_pix2pix_model.sh'
24
+ and 'scripts/download_cyclegan_model.sh'.
25
+ """
26
+
27
+ def __init__(self, technique='cyclegan', verbose=True):
28
+ url_dict = {
29
+ 'pix2pix': 'http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/',
30
+ 'cyclegan': 'https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets'
31
+ }
32
+ self.url = url_dict.get(technique.lower())
33
+ self._verbose = verbose
34
+
35
+ def _print(self, text):
36
+ if self._verbose:
37
+ print(text)
38
+
39
+ @staticmethod
40
+ def _get_options(r):
41
+ soup = BeautifulSoup(r.text, 'lxml')
42
+ options = [h.text for h in soup.find_all('a', href=True)
43
+ if h.text.endswith(('.zip', 'tar.gz'))]
44
+ return options
45
+
46
+ def _present_options(self):
47
+ r = requests.get(self.url)
48
+ options = self._get_options(r)
49
+ print('Options:\n')
50
+ for i, o in enumerate(options):
51
+ print("{0}: {1}".format(i, o))
52
+ choice = input("\nPlease enter the number of the "
53
+ "dataset above you wish to download:")
54
+ return options[int(choice)]
55
+
56
+ def _download_data(self, dataset_url, save_path):
57
+ if not isdir(save_path):
58
+ os.makedirs(save_path)
59
+
60
+ base = basename(dataset_url)
61
+ temp_save_path = join(save_path, base)
62
+
63
+ with open(temp_save_path, "wb") as f:
64
+ r = requests.get(dataset_url)
65
+ f.write(r.content)
66
+
67
+ if base.endswith('.tar.gz'):
68
+ obj = tarfile.open(temp_save_path)
69
+ elif base.endswith('.zip'):
70
+ obj = ZipFile(temp_save_path, 'r')
71
+ else:
72
+ raise ValueError("Unknown File Type: {0}.".format(base))
73
+
74
+ self._print("Unpacking Data...")
75
+ obj.extractall(save_path)
76
+ obj.close()
77
+ os.remove(temp_save_path)
78
+
79
+ def get(self, save_path, dataset=None):
80
+ """
81
+
82
+ Download a dataset.
83
+
84
+ Parameters:
85
+ save_path (str) -- A directory to save the data to.
86
+ dataset (str) -- (optional). A specific dataset to download.
87
+ Note: this must include the file extension.
88
+ If None, options will be presented for you
89
+ to choose from.
90
+
91
+ Returns:
92
+ save_path_full (str) -- the absolute path to the downloaded data.
93
+
94
+ """
95
+ if dataset is None:
96
+ selected_dataset = self._present_options()
97
+ else:
98
+ selected_dataset = dataset
99
+
100
+ save_path_full = join(save_path, selected_dataset.split('.')[0])
101
+
102
+ if isdir(save_path_full):
103
+ warn("\n'{0}' already exists. Voiding Download.".format(
104
+ save_path_full))
105
+ else:
106
+ self._print('Downloading Data...')
107
+ url = "{0}/{1}".format(self.url, selected_dataset)
108
+ self._download_data(url, save_path=save_path)
109
+
110
+ return abspath(save_path_full)
annotator/leres/pix2pix/util/guidedfilter.py ADDED
@@ -0,0 +1,47 @@
1
+ import numpy as np
2
+
3
+ class GuidedFilter():
4
+ def __init__(self, source, reference, r=64, eps= 0.05**2):
5
+ self.source = source
6
+ self.reference = reference
7
+ self.r = r
8
+ self.eps = eps
9
+
10
+ self.smooth = self.guidedfilter(self.source,self.reference,self.r,self.eps)
11
+
12
+ def boxfilter(self,img, r):
13
+ (rows, cols) = img.shape
14
+ imDst = np.zeros_like(img)
15
+
16
+ imCum = np.cumsum(img, 0)
17
+ imDst[0 : r+1, :] = imCum[r : 2*r+1, :]
18
+ imDst[r+1 : rows-r, :] = imCum[2*r+1 : rows, :] - imCum[0 : rows-2*r-1, :]
19
+ imDst[rows-r: rows, :] = np.tile(imCum[rows-1, :], [r, 1]) - imCum[rows-2*r-1 : rows-r-1, :]
20
+
21
+ imCum = np.cumsum(imDst, 1)
22
+ imDst[:, 0 : r+1] = imCum[:, r : 2*r+1]
23
+ imDst[:, r+1 : cols-r] = imCum[:, 2*r+1 : cols] - imCum[:, 0 : cols-2*r-1]
24
+ imDst[:, cols-r: cols] = np.tile(imCum[:, cols-1], [r, 1]).T - imCum[:, cols-2*r-1 : cols-r-1]
25
+
26
+ return imDst
27
+
28
+ def guidedfilter(self,I, p, r, eps):
29
+ (rows, cols) = I.shape
30
+ N = self.boxfilter(np.ones([rows, cols]), r)
31
+
32
+ meanI = self.boxfilter(I, r) / N
33
+ meanP = self.boxfilter(p, r) / N
34
+ meanIp = self.boxfilter(I * p, r) / N
35
+ covIp = meanIp - meanI * meanP
36
+
37
+ meanII = self.boxfilter(I * I, r) / N
38
+ varI = meanII - meanI * meanI
39
+
40
+ a = covIp / (varI + eps)
41
+ b = meanP - a * meanI
42
+
43
+ meanA = self.boxfilter(a, r) / N
44
+ meanB = self.boxfilter(b, r) / N
45
+
46
+ q = meanA * I + meanB
47
+ return q
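A small usage sketch (not part of this commit); the guide and target arrays are synthetic stand-ins for a grayscale image and a depth estimate.

import numpy as np
from annotator.leres.pix2pix.util.guidedfilter import GuidedFilter

guide = np.random.rand(480, 640)      # guidance image I, float in [0, 1]
target = np.random.rand(480, 640)     # map to be filtered p (e.g. a noisy depth estimate)
gf = GuidedFilter(guide, target, r=64, eps=0.05 ** 2)
smoothed = gf.smooth                  # edge-preserving smoothing of `target`, guided by `guide`
print(smoothed.shape)                 # (480, 640)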
annotator/leres/pix2pix/util/html.py ADDED
@@ -0,0 +1,86 @@
1
+ import dominate
2
+ from dominate.tags import meta, h3, table, tr, td, p, a, img, br
3
+ import os
4
+
5
+
6
+ class HTML:
7
+ """This HTML class allows us to save images and write texts into a single HTML file.
8
+
9
+ It consists of functions such as <add_header> (add a text header to the HTML file),
10
+ <add_images> (add a row of images to the HTML file), and <save> (save the HTML to the disk).
11
+ It is based on Python library 'dominate', a Python library for creating and manipulating HTML documents using a DOM API.
12
+ """
13
+
14
+ def __init__(self, web_dir, title, refresh=0):
15
+ """Initialize the HTML classes
16
+
17
+ Parameters:
18
+ web_dir (str) -- a directory that stores the webpage. HTML file will be created at <web_dir>/index.html; images will be saved at <web_dir>/images/
19
+ title (str) -- the webpage name
20
+ refresh (int) -- how often the website refreshes itself; if 0, no refreshing
21
+ """
22
+ self.title = title
23
+ self.web_dir = web_dir
24
+ self.img_dir = os.path.join(self.web_dir, 'images')
25
+ if not os.path.exists(self.web_dir):
26
+ os.makedirs(self.web_dir)
27
+ if not os.path.exists(self.img_dir):
28
+ os.makedirs(self.img_dir)
29
+
30
+ self.doc = dominate.document(title=title)
31
+ if refresh > 0:
32
+ with self.doc.head:
33
+ meta(http_equiv="refresh", content=str(refresh))
34
+
35
+ def get_image_dir(self):
36
+ """Return the directory that stores images"""
37
+ return self.img_dir
38
+
39
+ def add_header(self, text):
40
+ """Insert a header to the HTML file
41
+
42
+ Parameters:
43
+ text (str) -- the header text
44
+ """
45
+ with self.doc:
46
+ h3(text)
47
+
48
+ def add_images(self, ims, txts, links, width=400):
49
+ """add images to the HTML file
50
+
51
+ Parameters:
52
+ ims (str list) -- a list of image paths
53
+ txts (str list) -- a list of image names shown on the website
54
+ links (str list) -- a list of hyperref links; when you click an image, it will redirect you to a new page
55
+ """
56
+ self.t = table(border=1, style="table-layout: fixed;") # Insert a table
57
+ self.doc.add(self.t)
58
+ with self.t:
59
+ with tr():
60
+ for im, txt, link in zip(ims, txts, links):
61
+ with td(style="word-wrap: break-word;", halign="center", valign="top"):
62
+ with p():
63
+ with a(href=os.path.join('images', link)):
64
+ img(style="width:%dpx" % width, src=os.path.join('images', im))
65
+ br()
66
+ p(txt)
67
+
68
+ def save(self):
69
+ """save the current content to the HMTL file"""
70
+ html_file = '%s/index.html' % self.web_dir
71
+ f = open(html_file, 'wt')
72
+ f.write(self.doc.render())
73
+ f.close()
74
+
75
+
76
+ if __name__ == '__main__': # we show an example usage here.
77
+ html = HTML('web/', 'test_html')
78
+ html.add_header('hello world')
79
+
80
+ ims, txts, links = [], [], []
81
+ for n in range(4):
82
+ ims.append('image_%d.png' % n)
83
+ txts.append('text_%d' % n)
84
+ links.append('image_%d.png' % n)
85
+ html.add_images(ims, txts, links)
86
+ html.save()
annotator/leres/pix2pix/util/image_pool.py ADDED
@@ -0,0 +1,54 @@
1
+ import random
2
+ import torch
3
+
4
+
5
+ class ImagePool():
6
+ """This class implements an image buffer that stores previously generated images.
7
+
8
+ This buffer enables us to update discriminators using a history of generated images
9
+ rather than the ones produced by the latest generators.
10
+ """
11
+
12
+ def __init__(self, pool_size):
13
+ """Initialize the ImagePool class
14
+
15
+ Parameters:
16
+ pool_size (int) -- the size of image buffer, if pool_size=0, no buffer will be created
17
+ """
18
+ self.pool_size = pool_size
19
+ if self.pool_size > 0: # create an empty pool
20
+ self.num_imgs = 0
21
+ self.images = []
22
+
23
+ def query(self, images):
24
+ """Return an image from the pool.
25
+
26
+ Parameters:
27
+ images: the latest generated images from the generator
28
+
29
+ Returns images from the buffer.
30
+
31
+ With probability 1/2, the buffer will return the current input images.
32
+ With probability 1/2, the buffer will return images previously stored in the buffer,
33
+ and insert the current images into the buffer.
34
+ """
35
+ if self.pool_size == 0: # if the buffer size is 0, do nothing
36
+ return images
37
+ return_images = []
38
+ for image in images:
39
+ image = torch.unsqueeze(image.data, 0)
40
+ if self.num_imgs < self.pool_size: # if the buffer is not full; keep inserting current images to the buffer
41
+ self.num_imgs = self.num_imgs + 1
42
+ self.images.append(image)
43
+ return_images.append(image)
44
+ else:
45
+ p = random.uniform(0, 1)
46
+ if p > 0.5: # by 50% chance, the buffer will return a previously stored image, and insert the current image into the buffer
47
+ random_id = random.randint(0, self.pool_size - 1) # randint is inclusive
48
+ tmp = self.images[random_id].clone()
49
+ self.images[random_id] = image
50
+ return_images.append(tmp)
51
+ else: # by another 50% chance, the buffer will return the current image
52
+ return_images.append(image)
53
+ return_images = torch.cat(return_images, 0) # collect all the images and return
54
+ return return_images
annotator/leres/pix2pix/util/util.py ADDED
@@ -0,0 +1,105 @@
1
+ """This module contains simple helper functions """
2
+ from __future__ import print_function
3
+ import torch
4
+ import numpy as np
5
+ from PIL import Image
6
+ import os
7
+
8
+
9
+ def tensor2im(input_image, imtype=np.uint16):
10
+ """"Converts a Tensor array into a numpy image array.
11
+
12
+ Parameters:
13
+ input_image (tensor) -- the input image tensor array
14
+ imtype (type) -- the desired type of the converted numpy array
15
+ """
16
+ if not isinstance(input_image, np.ndarray):
17
+ if isinstance(input_image, torch.Tensor): # get the data from a variable
18
+ image_tensor = input_image.data
19
+ else:
20
+ return input_image
21
+ image_numpy = torch.squeeze(image_tensor).cpu().numpy() # convert it into a numpy array
22
+ image_numpy = (image_numpy + 1) / 2.0 * (2**16-1) #
23
+ else: # if it is a numpy array, do nothing
24
+ image_numpy = input_image
25
+ return image_numpy.astype(imtype)
26
+
27
+
28
+ def diagnose_network(net, name='network'):
29
+ """Calculate and print the mean of average absolute(gradients)
30
+
31
+ Parameters:
32
+ net (torch network) -- Torch network
33
+ name (str) -- the name of the network
34
+ """
35
+ mean = 0.0
36
+ count = 0
37
+ for param in net.parameters():
38
+ if param.grad is not None:
39
+ mean += torch.mean(torch.abs(param.grad.data))
40
+ count += 1
41
+ if count > 0:
42
+ mean = mean / count
43
+ print(name)
44
+ print(mean)
45
+
46
+
47
+ def save_image(image_numpy, image_path, aspect_ratio=1.0):
48
+ """Save a numpy image to the disk
49
+
50
+ Parameters:
51
+ image_numpy (numpy array) -- input numpy array
52
+ image_path (str) -- the path of the image
53
+ """
54
+ image_pil = Image.fromarray(image_numpy)
55
+
56
+ image_pil = image_pil.convert('I;16')
57
+
58
+ # image_pil = Image.fromarray(image_numpy)
59
+ # h, w, _ = image_numpy.shape
60
+ #
61
+ # if aspect_ratio > 1.0:
62
+ # image_pil = image_pil.resize((h, int(w * aspect_ratio)), Image.BICUBIC)
63
+ # if aspect_ratio < 1.0:
64
+ # image_pil = image_pil.resize((int(h / aspect_ratio), w), Image.BICUBIC)
65
+
66
+ image_pil.save(image_path)
67
+
68
+
69
+ def print_numpy(x, val=True, shp=False):
70
+ """Print the mean, min, max, median, std, and size of a numpy array
71
+
72
+ Parameters:
73
+ val (bool) -- if print the values of the numpy array
74
+ shp (bool) -- if print the shape of the numpy array
75
+ """
76
+ x = x.astype(np.float64)
77
+ if shp:
78
+ print('shape,', x.shape)
79
+ if val:
80
+ x = x.flatten()
81
+ print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % (
82
+ np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x)))
83
+
84
+
85
+ def mkdirs(paths):
86
+ """create empty directories if they don't exist
87
+
88
+ Parameters:
89
+ paths (str list) -- a list of directory paths
90
+ """
91
+ if isinstance(paths, list) and not isinstance(paths, str):
92
+ for path in paths:
93
+ mkdir(path)
94
+ else:
95
+ mkdir(paths)
96
+
97
+
98
+ def mkdir(path):
99
+ """create a single empty directory if it didn't exist
100
+
101
+ Parameters:
102
+ path (str) -- a single directory path
103
+ """
104
+ if not os.path.exists(path):
105
+ os.makedirs(path)
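A usage sketch (not part of this commit); the output directory and file name are made up.

import torch
from annotator.leres.pix2pix.util import util

fake_B = torch.rand(1, 1, 256, 256) * 2 - 1         # a generator output in [-1, 1]
depth_u16 = util.tensor2im(fake_B)                  # squeezed and rescaled to uint16 (0..65535)
util.mkdir('./results')
util.save_image(depth_u16, './results/depth.png')   # written as a 16-bit grayscale image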
annotator/leres/pix2pix/util/visualizer.py ADDED
@@ -0,0 +1,166 @@
1
+ import numpy as np
2
+ import os
3
+ import sys
4
+ import ntpath
5
+ import time
6
+ from . import util, html
7
+ from subprocess import Popen, PIPE
8
+ import torch
9
+
10
+
11
+ if sys.version_info[0] == 2:
12
+ VisdomExceptionBase = Exception
13
+ else:
14
+ VisdomExceptionBase = ConnectionError
15
+
16
+
17
+ def save_images(webpage, visuals, image_path, aspect_ratio=1.0, width=256):
18
+ """Save images to the disk.
19
+
20
+ Parameters:
21
+ webpage (the HTML class) -- the HTML webpage class that stores these images (see html.py for more details)
22
+ visuals (OrderedDict) -- an ordered dictionary that stores (name, images (either tensor or numpy) ) pairs
23
+ image_path (str) -- the string is used to create image paths
24
+ aspect_ratio (float) -- the aspect ratio of saved images
25
+ width (int) -- the images will be resized to width x width
26
+
27
+ This function will save images stored in 'visuals' to the HTML file specified by 'webpage'.
28
+ """
29
+ image_dir = webpage.get_image_dir()
30
+ short_path = ntpath.basename(image_path[0])
31
+ name = os.path.splitext(short_path)[0]
32
+
33
+ webpage.add_header(name)
34
+ ims, txts, links = [], [], []
35
+
36
+ for label, im_data in visuals.items():
37
+ im = util.tensor2im(im_data)
38
+ image_name = '%s_%s.png' % (name, label)
39
+ save_path = os.path.join(image_dir, image_name)
40
+ util.save_image(im, save_path, aspect_ratio=aspect_ratio)
41
+ ims.append(image_name)
42
+ txts.append(label)
43
+ links.append(image_name)
44
+ webpage.add_images(ims, txts, links, width=width)
45
+
46
+
47
+ class Visualizer():
48
+ """This class includes several functions that can display/save images and print/save logging information.
49
+
50
+ It uses a Python library 'visdom' for display, and a Python library 'dominate' (wrapped in 'HTML') for creating HTML files with images.
51
+ """
52
+
53
+ def __init__(self, opt):
54
+ """Initialize the Visualizer class
55
+
56
+ Parameters:
57
+ opt -- stores all the experiment flags; needs to be a subclass of BaseOptions
58
+ Step 1: Cache the training/test options
59
+ Step 2: connect to a visdom server
60
+ Step 3: create an HTML object for saving HTML files
61
+ Step 4: create a logging file to store training losses
62
+ """
63
+ self.opt = opt # cache the option
64
+ self.display_id = opt.display_id
65
+ self.use_html = opt.isTrain and not opt.no_html
66
+ self.win_size = opt.display_winsize
67
+ self.name = opt.name
68
+ self.port = opt.display_port
69
+ self.saved = False
70
+
71
+ if self.use_html: # create an HTML object at <checkpoints_dir>/web/; images will be saved under <checkpoints_dir>/web/images/
72
+ self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web')
73
+ self.img_dir = os.path.join(self.web_dir, 'images')
74
+ print('create web directory %s...' % self.web_dir)
75
+ util.mkdirs([self.web_dir, self.img_dir])
76
+ # create a logging file to store training losses
77
+ self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt')
78
+ with open(self.log_name, "a") as log_file:
79
+ now = time.strftime("%c")
80
+ log_file.write('================ Training Loss (%s) ================\n' % now)
81
+
82
+ def reset(self):
83
+ """Reset the self.saved status"""
84
+ self.saved = False
85
+
86
+ def create_visdom_connections(self):
87
+ """If the program could not connect to Visdom server, this function will start a new server at port < self.port > """
88
+ cmd = sys.executable + ' -m visdom.server -p %d &>/dev/null &' % self.port
89
+ print('\n\nCould not connect to Visdom server. \n Trying to start a server....')
90
+ print('Command: %s' % cmd)
91
+ Popen(cmd, shell=True, stdout=PIPE, stderr=PIPE)
92
+
93
+ def display_current_results(self, visuals, epoch, save_result):
94
+ """Display current results on visdom; save current results to an HTML file.
95
+
96
+ Parameters:
97
+ visuals (OrderedDict) - - dictionary of images to display or save
98
+ epoch (int) - - the current epoch
99
+ save_result (bool) - - if save the current results to an HTML file
100
+ """
101
+ if self.use_html and (save_result or not self.saved): # save images to an HTML file if they haven't been saved.
102
+ self.saved = True
103
+ # save images to the disk
104
+ for label, image in visuals.items():
105
+ image_numpy = util.tensor2im(image)
106
+ img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label))
107
+ util.save_image(image_numpy, img_path)
108
+
109
+ # update website
110
+ webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, refresh=1)
111
+ for n in range(epoch, 0, -1):
112
+ webpage.add_header('epoch [%d]' % n)
113
+ ims, txts, links = [], [], []
114
+
115
+ for label, image_numpy in visuals.items():
116
+ # image_numpy = util.tensor2im(image)
117
+ img_path = 'epoch%.3d_%s.png' % (n, label)
118
+ ims.append(img_path)
119
+ txts.append(label)
120
+ links.append(img_path)
121
+ webpage.add_images(ims, txts, links, width=self.win_size)
122
+ webpage.save()
123
+
124
+ # def plot_current_losses(self, epoch, counter_ratio, losses):
125
+ # """display the current losses on visdom display: dictionary of error labels and values
126
+ #
127
+ # Parameters:
128
+ # epoch (int) -- current epoch
129
+ # counter_ratio (float) -- progress (percentage) in the current epoch, between 0 to 1
130
+ # losses (OrderedDict) -- training losses stored in the format of (name, float) pairs
131
+ # """
132
+ # if not hasattr(self, 'plot_data'):
133
+ # self.plot_data = {'X': [], 'Y': [], 'legend': list(losses.keys())}
134
+ # self.plot_data['X'].append(epoch + counter_ratio)
135
+ # self.plot_data['Y'].append([losses[k] for k in self.plot_data['legend']])
136
+ # try:
137
+ # self.vis.line(
138
+ # X=np.stack([np.array(self.plot_data['X'])] * len(self.plot_data['legend']), 1),
139
+ # Y=np.array(self.plot_data['Y']),
140
+ # opts={
141
+ # 'title': self.name + ' loss over time',
142
+ # 'legend': self.plot_data['legend'],
143
+ # 'xlabel': 'epoch',
144
+ # 'ylabel': 'loss'},
145
+ # win=self.display_id)
146
+ # except VisdomExceptionBase:
147
+ # self.create_visdom_connections()
148
+
149
+ # losses: same format as |losses| of plot_current_losses
150
+ def print_current_losses(self, epoch, iters, losses, t_comp, t_data):
151
+ """print current losses on console; also save the losses to the disk
152
+
153
+ Parameters:
154
+ epoch (int) -- current epoch
155
+ iters (int) -- current training iteration during this epoch (reset to 0 at the end of every epoch)
156
+ losses (OrderedDict) -- training losses stored in the format of (name, float) pairs
157
+ t_comp (float) -- computational time per data point (normalized by batch_size)
158
+ t_data (float) -- data loading time per data point (normalized by batch_size)
159
+ """
160
+ message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (epoch, iters, t_comp, t_data)
161
+ for k, v in losses.items():
162
+ message += '%s: %.3f ' % (k, v)
163
+
164
+ print(message) # print the message
165
+ with open(self.log_name, "a") as log_file:
166
+ log_file.write('%s\n' % message) # save the message
annotator/lineart/LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2022 Caroline Chan
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
annotator/lineart/__init__.py ADDED
@@ -0,0 +1,129 @@
1
+ import os
2
+ import torch
3
+ import numpy as np
4
+
5
+ import torch.nn as nn
6
+ from einops import rearrange
7
+ from annotator.base_annotator import BaseProcessor
8
+ norm_layer = nn.InstanceNorm2d
9
+
10
+
11
+ class ResidualBlock(nn.Module):
12
+ def __init__(self, in_features):
13
+ super(ResidualBlock, self).__init__()
14
+
15
+ conv_block = [nn.ReflectionPad2d(1),
16
+ nn.Conv2d(in_features, in_features, 3),
17
+ norm_layer(in_features),
18
+ nn.ReLU(inplace=True),
19
+ nn.ReflectionPad2d(1),
20
+ nn.Conv2d(in_features, in_features, 3),
21
+ norm_layer(in_features)
22
+ ]
23
+
24
+ self.conv_block = nn.Sequential(*conv_block)
25
+
26
+ def forward(self, x):
27
+ return x + self.conv_block(x)
28
+
29
+
30
+ class Generator(nn.Module):
31
+ def __init__(self, input_nc, output_nc, n_residual_blocks=9, sigmoid=True):
32
+ super(Generator, self).__init__()
33
+
34
+ # Initial convolution block
35
+ model0 = [nn.ReflectionPad2d(3),
36
+ nn.Conv2d(input_nc, 64, 7),
37
+ norm_layer(64),
38
+ nn.ReLU(inplace=True)]
39
+ self.model0 = nn.Sequential(*model0)
40
+
41
+ # Downsampling
42
+ model1 = []
43
+ in_features = 64
44
+ out_features = in_features * 2
45
+ for _ in range(2):
46
+ model1 += [nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
47
+ norm_layer(out_features),
48
+ nn.ReLU(inplace=True)]
49
+ in_features = out_features
50
+ out_features = in_features * 2
51
+ self.model1 = nn.Sequential(*model1)
52
+
53
+ model2 = []
54
+ # Residual blocks
55
+ for _ in range(n_residual_blocks):
56
+ model2 += [ResidualBlock(in_features)]
57
+ self.model2 = nn.Sequential(*model2)
58
+
59
+ # Upsampling
60
+ model3 = []
61
+ out_features = in_features // 2
62
+ for _ in range(2):
63
+ model3 += [nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1),
64
+ norm_layer(out_features),
65
+ nn.ReLU(inplace=True)]
66
+ in_features = out_features
67
+ out_features = in_features // 2
68
+ self.model3 = nn.Sequential(*model3)
69
+
70
+ # Output layer
71
+ model4 = [nn.ReflectionPad2d(3),
72
+ nn.Conv2d(64, output_nc, 7)]
73
+ if sigmoid:
74
+ model4 += [nn.Sigmoid()]
75
+
76
+ self.model4 = nn.Sequential(*model4)
77
+
78
+ def forward(self, x, cond=None):
79
+ out = self.model0(x)
80
+ out = self.model1(out)
81
+ out = self.model2(out)
82
+ out = self.model3(out)
83
+ out = self.model4(out)
84
+
85
+ return out
86
+
87
+
88
+ class LineArtDetector(BaseProcessor):
89
+ model_default = 'sk_model.pth'
90
+ model_coarse = 'sk_model2.pth'
91
+
92
+ def __init__(self, model_name=model_default, **kwargs):
93
+ super().__init__(**kwargs)
94
+ self.model = None
95
+ self.model_dir = os.path.join(self.models_path, "lineart")
96
+ self.model_name = model_name
97
+
98
+ def load_model(self, name):
99
+ remote_model_path = "https://huggingface.co/lllyasviel/Annotators/resolve/main/" + name
100
+ model_path = os.path.join(self.model_dir, name)
101
+ if not os.path.exists(model_path):
102
+ from basicsr.utils.download_util import load_file_from_url
103
+ load_file_from_url(remote_model_path, model_dir=self.model_dir)
104
+ model = Generator(3, 1, 3)
105
+ model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
106
+ model.eval()
107
+ self.model = model.to(self.device)
108
+
109
+ def unload_model(self):
110
+ if self.model is not None:
111
+ self.model.cpu()
112
+
113
+ def __call__(self, input_image):
114
+ if self.model is None:
115
+ self.load_model(self.model_name)
116
+ self.model.to(self.device)
117
+
118
+ assert input_image.ndim == 3
119
+ image = input_image
120
+ with torch.no_grad():
121
+ image = torch.from_numpy(image).float().to(self.device)
122
+ image = image / 255.0
123
+ image = rearrange(image, 'h w c -> 1 c h w')
124
+ line = self.model(image)[0][0]
125
+
126
+ line = line.cpu().numpy()
127
+ line = (line * 255.0).clip(0, 255).astype(np.uint8)
128
+
129
+ return line
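A usage sketch (not part of this commit); it assumes BaseProcessor can be constructed with its defaults (it may need a models-path or device keyword), and the input image is synthetic. Weights are downloaded on first use.

import numpy as np
from annotator.lineart import LineArtDetector

detector = LineArtDetector(model_name=LineArtDetector.model_default)
image = np.random.randint(0, 255, (512, 512, 3), dtype=np.uint8)   # H x W x 3, uint8
line_map = detector(image)      # uint8 line drawing with the same spatial size as the input
print(line_map.shape)           # (512, 512)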
annotator/lineart/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (4.02 kB).
annotator/lineart_anime/LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2022 Caroline Chan
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
annotator/lineart_anime/__init__.py ADDED
@@ -0,0 +1,164 @@
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+ import functools
5
+
6
+ import os
7
+ import cv2
8
+ from einops import rearrange
9
+ from annotator.base_annotator import BaseProcessor
10
+
11
+
12
+ class UnetGenerator(nn.Module):
13
+ """Create a Unet-based generator"""
14
+
15
+ def __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False):
16
+ """Construct a Unet generator
17
+ Parameters:
18
+ input_nc (int) -- the number of channels in input images
19
+ output_nc (int) -- the number of channels in output images
20
+ num_downs (int) -- the number of downsamplings in UNet. For example, if |num_downs| == 7,
21
+ an image of size 128x128 will become of size 1x1 at the bottleneck
22
+ ngf (int) -- the number of filters in the last conv layer
23
+ norm_layer -- normalization layer
24
+ We construct the U-Net from the innermost layer to the outermost layer.
25
+ It is a recursive process.
26
+ """
27
+ super(UnetGenerator, self).__init__()
28
+ # construct unet structure
29
+ unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer,
30
+ innermost=True) # add the innermost layer
31
+ for _ in range(num_downs - 5): # add intermediate layers with ngf * 8 filters
32
+ unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block,
33
+ norm_layer=norm_layer, use_dropout=use_dropout)
34
+ # gradually reduce the number of filters from ngf * 8 to ngf
35
+ unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block,
36
+ norm_layer=norm_layer)
37
+ unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block,
38
+ norm_layer=norm_layer)
39
+ unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
40
+ self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True,
41
+ norm_layer=norm_layer) # add the outermost layer
42
+
43
+ def forward(self, input):
44
+ """Standard forward"""
45
+ return self.model(input)
46
+
47
+
48
+ class UnetSkipConnectionBlock(nn.Module):
49
+ """Defines the Unet submodule with skip connection.
50
+ X -------------------identity----------------------
51
+ |-- downsampling -- |submodule| -- upsampling --|
52
+ """
53
+
54
+ def __init__(self, outer_nc, inner_nc, input_nc=None,
55
+ submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
56
+ """Construct a Unet submodule with skip connections.
57
+ Parameters:
58
+ outer_nc (int) -- the number of filters in the outer conv layer
59
+ inner_nc (int) -- the number of filters in the inner conv layer
60
+ input_nc (int) -- the number of channels in input images/features
61
+ submodule (UnetSkipConnectionBlock) -- previously defined submodules
62
+ outermost (bool) -- if this module is the outermost module
63
+ innermost (bool) -- if this module is the innermost module
64
+ norm_layer -- normalization layer
65
+ use_dropout (bool) -- if use dropout layers.
66
+ """
67
+ super(UnetSkipConnectionBlock, self).__init__()
68
+ self.outermost = outermost
69
+ if type(norm_layer) == functools.partial:
70
+ use_bias = norm_layer.func == nn.InstanceNorm2d
71
+ else:
72
+ use_bias = norm_layer == nn.InstanceNorm2d
73
+ if input_nc is None:
74
+ input_nc = outer_nc
75
+ downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
76
+ stride=2, padding=1, bias=use_bias)
77
+ downrelu = nn.LeakyReLU(0.2, True)
78
+ downnorm = norm_layer(inner_nc)
79
+ uprelu = nn.ReLU(True)
80
+ upnorm = norm_layer(outer_nc)
81
+
82
+ if outermost:
83
+ upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
84
+ kernel_size=4, stride=2,
85
+ padding=1)
86
+ down = [downconv]
87
+ up = [uprelu, upconv, nn.Tanh()]
88
+ model = down + [submodule] + up
89
+ elif innermost:
90
+ upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
91
+ kernel_size=4, stride=2,
92
+ padding=1, bias=use_bias)
93
+ down = [downrelu, downconv]
94
+ up = [uprelu, upconv, upnorm]
95
+ model = down + up
96
+ else:
97
+ upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
98
+ kernel_size=4, stride=2,
99
+ padding=1, bias=use_bias)
100
+ down = [downrelu, downconv, downnorm]
101
+ up = [uprelu, upconv, upnorm]
102
+
103
+ if use_dropout:
104
+ model = down + [submodule] + up + [nn.Dropout(0.5)]
105
+ else:
106
+ model = down + [submodule] + up
107
+
108
+ self.model = nn.Sequential(*model)
109
+
110
+ def forward(self, x):
111
+ if self.outermost:
112
+ return self.model(x)
113
+ else: # add skip connections
114
+ return torch.cat([x, self.model(x)], 1)
115
+
116
+
117
+ class LineArtAnimeDetector(BaseProcessor):
118
+
119
+ def __init__(self, **kwargs):
120
+ super().__init__(**kwargs)
121
+ self.model = None
122
+ self.model_dir = os.path.join(self.models_path, "lineart_anime")
123
+
124
+ def load_model(self):
125
+ remote_model_path = "https://huggingface.co/lllyasviel/Annotators/resolve/main/netG.pth"
126
+ modelpath = os.path.join(self.model_dir, "netG.pth")
127
+ if not os.path.exists(modelpath):
128
+ from basicsr.utils.download_util import load_file_from_url
129
+ load_file_from_url(remote_model_path, model_dir=self.model_dir)
130
+ norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
131
+ net = UnetGenerator(3, 1, 8, 64, norm_layer=norm_layer, use_dropout=False)
132
+ ckpt = torch.load(modelpath)
133
+ for key in list(ckpt.keys()):
134
+ if 'module.' in key:
135
+ ckpt[key.replace('module.', '')] = ckpt[key]
136
+ del ckpt[key]
137
+ net.load_state_dict(ckpt)
138
+ net.eval()
139
+ self.model = net.to(self.device)
140
+
141
+ def unload_model(self):
142
+ if self.model is not None:
143
+ self.model.cpu()
144
+
145
+ def __call__(self, input_image):
146
+ if self.model is None:
147
+ self.load_model()
148
+ self.model.to(self.device)
149
+
150
+ H, W, C = input_image.shape
151
+ Hn = 256 * int(np.ceil(float(H) / 256.0))
152
+ Wn = 256 * int(np.ceil(float(W) / 256.0))
153
+ img = cv2.resize(input_image, (Wn, Hn), interpolation=cv2.INTER_CUBIC)
154
+ with torch.no_grad():
155
+ image_feed = torch.from_numpy(img).float().to(self.device)
156
+ image_feed = image_feed / 127.5 - 1.0
157
+ image_feed = rearrange(image_feed, 'h w c -> 1 c h w')
158
+
159
+ line = self.model(image_feed)[0, 0] * 127.5 + 127.5
160
+ line = line.cpu().numpy()
161
+
162
+ line = cv2.resize(line, (W, H), interpolation=cv2.INTER_CUBIC)
163
+ line = line.clip(0, 255).astype(np.uint8)
164
+ return line
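The call above snaps the input to the next multiple of 256 before the 8-level U-Net (2**8 = 256) and resizes the result back; the sketch below (not part of this commit) shows just that resize round trip on an arbitrary image size.

import numpy as np
import cv2

H, W = 593, 841
Hn = 256 * int(np.ceil(H / 256.0))      # 768
Wn = 256 * int(np.ceil(W / 256.0))      # 1024
img = np.random.randint(0, 255, (H, W, 3), dtype=np.uint8)
up = cv2.resize(img, (Wn, Hn), interpolation=cv2.INTER_CUBIC)           # fed to the generator
back = cv2.resize(up[:, :, 0], (W, H), interpolation=cv2.INTER_CUBIC)   # restore the original size
print(up.shape, back.shape)             # (768, 1024, 3) (593, 841)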
annotator/lineart_anime/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (6.22 kB).
annotator/manga_line/LICENSE ADDED
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2021 Miaomiao Li
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
annotator/manga_line/__init__.py ADDED
@@ -0,0 +1,247 @@
+import os
+import torch
+import torch.nn as nn
+from PIL import Image
+import fnmatch
+import cv2
+
+import sys
+
+import numpy as np
+from einops import rearrange
+from annotator.base_annotator import BaseProcessor
+
+
+class _bn_relu_conv(nn.Module):
+    def __init__(self, in_filters, nb_filters, fw, fh, subsample=1):
+        super(_bn_relu_conv, self).__init__()
+        self.model = nn.Sequential(
+            nn.BatchNorm2d(in_filters, eps=1e-3),
+            nn.LeakyReLU(0.2),
+            nn.Conv2d(in_filters, nb_filters, (fw, fh), stride=subsample, padding=(fw // 2, fh // 2), padding_mode='zeros')
+        )
+
+    def forward(self, x):
+        return self.model(x)
+
+        # unreachable debug code kept from the upstream implementation
+        print("****", np.max(x.cpu().numpy()), np.min(x.cpu().numpy()), np.mean(x.cpu().numpy()), np.std(x.cpu().numpy()), x.shape)
+        for i, layer in enumerate(self.model):
+            if i != 2:
+                x = layer(x)
+            else:
+                x = layer(x)
+                # x = nn.functional.pad(x, (1, 1, 1, 1), mode='constant', value=0)
+            print("____", np.max(x.cpu().numpy()), np.min(x.cpu().numpy()), np.mean(x.cpu().numpy()), np.std(x.cpu().numpy()), x.shape)
+            print(x[0])
+        return x
+
+
+class _u_bn_relu_conv(nn.Module):
+    def __init__(self, in_filters, nb_filters, fw, fh, subsample=1):
+        super(_u_bn_relu_conv, self).__init__()
+        self.model = nn.Sequential(
+            nn.BatchNorm2d(in_filters, eps=1e-3),
+            nn.LeakyReLU(0.2),
+            nn.Conv2d(in_filters, nb_filters, (fw, fh), stride=subsample, padding=(fw // 2, fh // 2)),
+            nn.Upsample(scale_factor=2, mode='nearest')
+        )
+
+    def forward(self, x):
+        return self.model(x)
+
+
+
+class _shortcut(nn.Module):
+    def __init__(self, in_filters, nb_filters, subsample=1):
+        super(_shortcut, self).__init__()
+        self.process = False
+        self.model = None
+        if in_filters != nb_filters or subsample != 1:
+            self.process = True
+            self.model = nn.Sequential(
+                nn.Conv2d(in_filters, nb_filters, (1, 1), stride=subsample)
+            )
+
+    def forward(self, x, y):
+        # print(x.size(), y.size(), self.process)
+        if self.process:
+            y0 = self.model(x)
+            # print("merge+", torch.max(y0+y), torch.min(y0+y), torch.mean(y0+y), torch.std(y0+y), y0.shape)
+            return y0 + y
+        else:
+            # print("merge", torch.max(x+y), torch.min(x+y), torch.mean(x+y), torch.std(x+y), y.shape)
+            return x + y
+
+class _u_shortcut(nn.Module):
+    def __init__(self, in_filters, nb_filters, subsample):
+        super(_u_shortcut, self).__init__()
+        self.process = False
+        self.model = None
+        if in_filters != nb_filters:
+            self.process = True
+            self.model = nn.Sequential(
+                nn.Conv2d(in_filters, nb_filters, (1, 1), stride=subsample, padding_mode='zeros'),
+                nn.Upsample(scale_factor=2, mode='nearest')
+            )
+
+    def forward(self, x, y):
+        if self.process:
+            return self.model(x) + y
+        else:
+            return x + y
+
+
+class basic_block(nn.Module):
+    def __init__(self, in_filters, nb_filters, init_subsample=1):
+        super(basic_block, self).__init__()
+        self.conv1 = _bn_relu_conv(in_filters, nb_filters, 3, 3, subsample=init_subsample)
+        self.residual = _bn_relu_conv(nb_filters, nb_filters, 3, 3)
+        self.shortcut = _shortcut(in_filters, nb_filters, subsample=init_subsample)
+
+    def forward(self, x):
+        x1 = self.conv1(x)
+        x2 = self.residual(x1)
+        return self.shortcut(x, x2)
+
+class _u_basic_block(nn.Module):
+    def __init__(self, in_filters, nb_filters, init_subsample=1):
+        super(_u_basic_block, self).__init__()
+        self.conv1 = _u_bn_relu_conv(in_filters, nb_filters, 3, 3, subsample=init_subsample)
+        self.residual = _bn_relu_conv(nb_filters, nb_filters, 3, 3)
+        self.shortcut = _u_shortcut(in_filters, nb_filters, subsample=init_subsample)
+
+    def forward(self, x):
+        y = self.residual(self.conv1(x))
+        return self.shortcut(x, y)
+
+
+class _residual_block(nn.Module):
+    def __init__(self, in_filters, nb_filters, repetitions, is_first_layer=False):
+        super(_residual_block, self).__init__()
+        layers = []
+        for i in range(repetitions):
+            init_subsample = 1
+            if i == repetitions - 1 and not is_first_layer:
+                init_subsample = 2
+            if i == 0:
+                l = basic_block(in_filters=in_filters, nb_filters=nb_filters, init_subsample=init_subsample)
+            else:
+                l = basic_block(in_filters=nb_filters, nb_filters=nb_filters, init_subsample=init_subsample)
+            layers.append(l)
+
+        self.model = nn.Sequential(*layers)
+
+    def forward(self, x):
+        return self.model(x)
+
+
+class _upsampling_residual_block(nn.Module):
+    def __init__(self, in_filters, nb_filters, repetitions):
+        super(_upsampling_residual_block, self).__init__()
+        layers = []
+        for i in range(repetitions):
+            l = None
+            if i == 0:
+                l = _u_basic_block(in_filters=in_filters, nb_filters=nb_filters)  # (input)
+            else:
+                l = basic_block(in_filters=nb_filters, nb_filters=nb_filters)  # (input)
+            layers.append(l)
+
+        self.model = nn.Sequential(*layers)
+
+    def forward(self, x):
+        return self.model(x)
+
+
+class res_skip(nn.Module):
+
+    def __init__(self):
+        super(res_skip, self).__init__()
+        self.block0 = _residual_block(in_filters=1, nb_filters=24, repetitions=2, is_first_layer=True)  # (input)
+        self.block1 = _residual_block(in_filters=24, nb_filters=48, repetitions=3)  # (block0)
+        self.block2 = _residual_block(in_filters=48, nb_filters=96, repetitions=5)  # (block1)
+        self.block3 = _residual_block(in_filters=96, nb_filters=192, repetitions=7)  # (block2)
+        self.block4 = _residual_block(in_filters=192, nb_filters=384, repetitions=12)  # (block3)
+
+        self.block5 = _upsampling_residual_block(in_filters=384, nb_filters=192, repetitions=7)  # (block4)
+        self.res1 = _shortcut(in_filters=192, nb_filters=192)  # (block3, block5, subsample=(1,1))
+
+        self.block6 = _upsampling_residual_block(in_filters=192, nb_filters=96, repetitions=5)  # (res1)
+        self.res2 = _shortcut(in_filters=96, nb_filters=96)  # (block2, block6, subsample=(1,1))
+
+        self.block7 = _upsampling_residual_block(in_filters=96, nb_filters=48, repetitions=3)  # (res2)
+        self.res3 = _shortcut(in_filters=48, nb_filters=48)  # (block1, block7, subsample=(1,1))
+
+        self.block8 = _upsampling_residual_block(in_filters=48, nb_filters=24, repetitions=2)  # (res3)
+        self.res4 = _shortcut(in_filters=24, nb_filters=24)  # (block0, block8, subsample=(1,1))
+
+        self.block9 = _residual_block(in_filters=24, nb_filters=16, repetitions=2, is_first_layer=True)  # (res4)
+        self.conv15 = _bn_relu_conv(in_filters=16, nb_filters=1, fh=1, fw=1, subsample=1)  # (block9)
+
+    def forward(self, x):
+        x0 = self.block0(x)
+        x1 = self.block1(x0)
+        x2 = self.block2(x1)
+        x3 = self.block3(x2)
+        x4 = self.block4(x3)
+
+        x5 = self.block5(x4)
+        res1 = self.res1(x3, x5)
+
+        x6 = self.block6(res1)
+        res2 = self.res2(x2, x6)
+
+        x7 = self.block7(res2)
+        res3 = self.res3(x1, x7)
+
+        x8 = self.block8(res3)
+        res4 = self.res4(x0, x8)
+
+        x9 = self.block9(res4)
+        y = self.conv15(x9)
+
+        return y
+
+
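The encoder above halves the spatial resolution four times (24 -> 48 -> 96 -> 192 -> 384 filters) and the decoder doubles it back while adding the matching encoder features, so height and width effectively need to be divisible by 16 for the skip additions to line up. A quick, illustrative shape check under that assumption:

import torch
from annotator.manga_line import res_skip   # module path as added in this commit

net = res_skip().eval()                      # randomly initialized weights are fine for a shape check
with torch.no_grad():
    x = torch.randn(1, 1, 512, 384)          # single-channel input; 512 and 384 are multiples of 16
    y = net(x)
print(y.shape)                               # torch.Size([1, 1, 512, 384])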
+class MangaLineExtration(BaseProcessor):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self.model = None
+        self.model_dir = os.path.join(self.models_path, "manga_line")
+        # self.device = devices.get_device_for("controlnet")
+
+    def load_model(self):
+        remote_model_path = "https://huggingface.co/lllyasviel/Annotators/resolve/main/erika.pth"
+        modelpath = os.path.join(self.model_dir, "erika.pth")
+        if not os.path.exists(modelpath):
+            from basicsr.utils.download_util import load_file_from_url
+            load_file_from_url(remote_model_path, model_dir=self.model_dir)
+        # norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
+        net = res_skip()
+        ckpt = torch.load(modelpath)
+        for key in list(ckpt.keys()):
+            if 'module.' in key:
+                ckpt[key.replace('module.', '')] = ckpt[key]
+                del ckpt[key]
+        net.load_state_dict(ckpt)
+        net.eval()
+        self.model = net.to(self.device)
+
+    def unload_model(self):
+        if self.model is not None:
+            self.model.cpu()
+
+    def __call__(self, input_image):
+        if self.model is None:
+            self.load_model()
+        self.model.to(self.device)
+        img = cv2.cvtColor(input_image, cv2.COLOR_RGB2GRAY)
+        img = np.ascontiguousarray(img.copy()).copy()
+        with torch.no_grad():
+            image_feed = torch.from_numpy(img).float().to(self.device)
+            image_feed = rearrange(image_feed, 'h w -> 1 1 h w')
+            line = self.model(image_feed)
+            line = 255 - line.cpu().numpy()[0, 0]
+        return line.clip(0, 255).astype(np.uint8)
+
+
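A minimal usage sketch for the extractor above. `__call__` feeds the grayscale image at its original size, so inputs whose height or width is not divisible by 16 may misalign the skip connections inside `res_skip`; padding first, as below, is one illustrative workaround (the no-argument construction is an assumption about `BaseProcessor` defaults):

import cv2
import numpy as np
from annotator.manga_line import MangaLineExtration

extractor = MangaLineExtration()
img = cv2.cvtColor(cv2.imread("page.png"), cv2.COLOR_BGR2RGB)   # __call__ converts RGB to grayscale itself

# Pad so the encoder/decoder feature maps align, then crop the result back.
H, W = img.shape[:2]
pad_h, pad_w = (-H) % 16, (-W) % 16
padded = np.pad(img, ((0, pad_h), (0, pad_w), (0, 0)), mode="edge")

line = extractor(padded)[:H, :W]   # uint8 line map at the original size
cv2.imwrite("manga_line.png", line)
extractor.unload_model()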