esun-choi committed on
Commit
8804c8f
1 Parent(s): fab8eeb

Initial Commit

__pycache__/craft.cpython-310.pyc ADDED
Binary file (2.38 kB)

__pycache__/craft_utils.cpython-310.pyc ADDED
Binary file (5.68 kB)

__pycache__/file_utils.cpython-310.pyc ADDED
Binary file (2.49 kB)

__pycache__/imgproc.cpython-310.pyc ADDED
Binary file (2.08 kB)

__pycache__/mosaik.cpython-310.pyc ADDED
Binary file (698 Bytes)

__pycache__/ner.cpython-310.pyc ADDED
Binary file (906 Bytes)

__pycache__/recognize.cpython-310.pyc ADDED
Binary file (716 Bytes)

__pycache__/refinenet.cpython-310.pyc ADDED
Binary file (1.93 kB)

__pycache__/seg.cpython-310.pyc ADDED
Binary file (1.5 kB)

__pycache__/seg2.cpython-310.pyc ADDED
Binary file (1.27 kB)

basenet/__init__.py ADDED
File without changes
basenet/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (164 Bytes)

basenet/__pycache__/vgg16_bn.cpython-310.pyc ADDED
Binary file (2.27 kB)

basenet/vgg16_bn.py ADDED
@@ -0,0 +1,72 @@
from collections import namedtuple

import torch
import torch.nn as nn
import torch.nn.init as init
from torchvision import models

def init_weights(modules):
    for m in modules:
        if isinstance(m, nn.Conv2d):
            init.xavier_uniform_(m.weight.data)
            if m.bias is not None:
                m.bias.data.zero_()
        elif isinstance(m, nn.BatchNorm2d):
            m.weight.data.fill_(1)
            m.bias.data.zero_()
        elif isinstance(m, nn.Linear):
            m.weight.data.normal_(0, 0.01)
            m.bias.data.zero_()

class vgg16_bn(torch.nn.Module):
    def __init__(self, pretrained=True, freeze=True):
        super(vgg16_bn, self).__init__()

        vgg_pretrained_features = models.vgg16_bn(pretrained=pretrained).features
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()
        self.slice5 = torch.nn.Sequential()
        for x in range(12):  # conv2_2
            self.slice1.add_module(str(x), vgg_pretrained_features[x])
        for x in range(12, 19):  # conv3_3
            self.slice2.add_module(str(x), vgg_pretrained_features[x])
        for x in range(19, 29):  # conv4_3
            self.slice3.add_module(str(x), vgg_pretrained_features[x])
        for x in range(29, 39):  # conv5_3
            self.slice4.add_module(str(x), vgg_pretrained_features[x])

        # fc6, fc7 without atrous conv
        self.slice5 = torch.nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
            nn.Conv2d(512, 1024, kernel_size=3, padding=6, dilation=6),
            nn.Conv2d(1024, 1024, kernel_size=1)
        )

        if not pretrained:
            init_weights(self.slice1.modules())
            init_weights(self.slice2.modules())
            init_weights(self.slice3.modules())
            init_weights(self.slice4.modules())

        init_weights(self.slice5.modules())  # no pretrained model for fc6 and fc7

        if freeze:
            for param in self.slice1.parameters():  # only first conv
                param.requires_grad = False

    def forward(self, X):
        h = self.slice1(X)
        h_relu2_2 = h
        h = self.slice2(h)
        h_relu3_2 = h
        h = self.slice3(h)
        h_relu4_3 = h
        h = self.slice4(h)
        h_relu5_3 = h
        h = self.slice5(h)
        h_fc7 = h
        vgg_outputs = namedtuple("VggOutputs", ['fc7', 'relu5_3', 'relu4_3', 'relu3_2', 'relu2_2'])
        out = vgg_outputs(h_fc7, h_relu5_3, h_relu4_3, h_relu3_2, h_relu2_2)
        return out
craft.py ADDED
@@ -0,0 +1,76 @@

import torch
import torch.nn as nn
import torch.nn.functional as F

from basenet.vgg16_bn import vgg16_bn, init_weights

class double_conv(nn.Module):
    def __init__(self, in_ch, mid_ch, out_ch):
        super(double_conv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch + mid_ch, mid_ch, kernel_size=1),
            nn.BatchNorm2d(mid_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        x = self.conv(x)
        return x


class CRAFT(nn.Module):
    def __init__(self, pretrained=False, freeze=False):
        super(CRAFT, self).__init__()

        """ Base network """
        self.basenet = vgg16_bn(pretrained, freeze)

        """ U network """
        self.upconv1 = double_conv(1024, 512, 256)
        self.upconv2 = double_conv(512, 256, 128)
        self.upconv3 = double_conv(256, 128, 64)
        self.upconv4 = double_conv(128, 64, 32)

        num_class = 2
        self.conv_cls = nn.Sequential(
            nn.Conv2d(32, 32, kernel_size=3, padding=1), nn.ReLU(inplace=True),
            nn.Conv2d(32, 32, kernel_size=3, padding=1), nn.ReLU(inplace=True),
            nn.Conv2d(32, 16, kernel_size=3, padding=1), nn.ReLU(inplace=True),
            nn.Conv2d(16, 16, kernel_size=1), nn.ReLU(inplace=True),
            nn.Conv2d(16, num_class, kernel_size=1),
        )

        init_weights(self.upconv1.modules())
        init_weights(self.upconv2.modules())
        init_weights(self.upconv3.modules())
        init_weights(self.upconv4.modules())
        init_weights(self.conv_cls.modules())

    def forward(self, x):
        """ Base network """
        sources = self.basenet(x)

        """ U network """
        y = torch.cat([sources[0], sources[1]], dim=1)
        y = self.upconv1(y)

        y = F.interpolate(y, size=sources[2].size()[2:], mode='bilinear', align_corners=False)
        y = torch.cat([y, sources[2]], dim=1)
        y = self.upconv2(y)

        y = F.interpolate(y, size=sources[3].size()[2:], mode='bilinear', align_corners=False)
        y = torch.cat([y, sources[3]], dim=1)
        y = self.upconv3(y)

        y = F.interpolate(y, size=sources[4].size()[2:], mode='bilinear', align_corners=False)
        y = torch.cat([y, sources[4]], dim=1)
        feature = self.upconv4(y)

        y = self.conv_cls(feature)

        return y.permute(0,2,3,1), feature
craft_utils.py ADDED
@@ -0,0 +1,217 @@

import numpy as np
import cv2
import math

def warpCoord(Minv, pt):
    out = np.matmul(Minv, (pt[0], pt[1], 1))
    return np.array([out[0]/out[2], out[1]/out[2]])


def getDetBoxes_core(textmap, linkmap, text_threshold, link_threshold, low_text):
    linkmap = linkmap.copy()
    textmap = textmap.copy()
    img_h, img_w = textmap.shape

    ret, text_score = cv2.threshold(textmap, low_text, 1, 0)
    ret, link_score = cv2.threshold(linkmap, link_threshold, 1, 0)

    text_score_comb = np.clip(text_score + link_score, 0, 1)
    nLabels, labels, stats, centroids = cv2.connectedComponentsWithStats(text_score_comb.astype(np.uint8), connectivity=4)

    det = []
    mapper = []
    for k in range(1, nLabels):
        size = stats[k, cv2.CC_STAT_AREA]
        if size < 10: continue

        if np.max(textmap[labels==k]) < text_threshold: continue

        segmap = np.zeros(textmap.shape, dtype=np.uint8)
        segmap[labels==k] = 255
        segmap[np.logical_and(link_score==1, text_score==0)] = 0
        x, y = stats[k, cv2.CC_STAT_LEFT], stats[k, cv2.CC_STAT_TOP]
        w, h = stats[k, cv2.CC_STAT_WIDTH], stats[k, cv2.CC_STAT_HEIGHT]
        niter = int(math.sqrt(size * min(w, h) / (w * h)) * 2)
        sx, ex, sy, ey = x - niter, x + w + niter + 1, y - niter, y + h + niter + 1
        if sx < 0 : sx = 0
        if sy < 0 : sy = 0
        if ex >= img_w: ex = img_w
        if ey >= img_h: ey = img_h
        kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (1 + niter, 1 + niter))
        segmap[sy:ey, sx:ex] = cv2.dilate(segmap[sy:ey, sx:ex], kernel)

        np_contours = np.roll(np.array(np.where(segmap!=0)), 1, axis=0).transpose().reshape(-1, 2)
        rectangle = cv2.minAreaRect(np_contours)
        box = cv2.boxPoints(rectangle)

        w, h = np.linalg.norm(box[0] - box[1]), np.linalg.norm(box[1] - box[2])
        box_ratio = max(w, h) / (min(w, h) + 1e-5)
        if abs(1 - box_ratio) <= 0.1:
            l, r = min(np_contours[:,0]), max(np_contours[:,0])
            t, b = min(np_contours[:,1]), max(np_contours[:,1])
            box = np.array([[l, t], [r, t], [r, b], [l, b]], dtype=np.float32)

        startidx = box.sum(axis=1).argmin()
        box = np.roll(box, 4-startidx, 0)
        box = np.array(box)

        det.append(box)
        mapper.append(k)

    return det, labels, mapper

def getPoly_core(boxes, labels, mapper, linkmap):
    num_cp = 5
    max_len_ratio = 0.7
    expand_ratio = 1.45
    max_r = 2.0
    step_r = 0.2

    polys = []
    for k, box in enumerate(boxes):
        w, h = int(np.linalg.norm(box[0] - box[1]) + 1), int(np.linalg.norm(box[1] - box[2]) + 1)
        if w < 10 or h < 10:
            polys.append(None); continue

        tar = np.float32([[0,0],[w,0],[w,h],[0,h]])
        M = cv2.getPerspectiveTransform(box, tar)
        word_label = cv2.warpPerspective(labels, M, (w, h), flags=cv2.INTER_NEAREST)
        try:
            Minv = np.linalg.inv(M)
        except:
            polys.append(None); continue

        cur_label = mapper[k]
        word_label[word_label != cur_label] = 0
        word_label[word_label > 0] = 1

        cp = []
        max_len = -1
        for i in range(w):
            region = np.where(word_label[:,i] != 0)[0]
            if len(region) < 2 : continue
            cp.append((i, region[0], region[-1]))
            length = region[-1] - region[0] + 1
            if length > max_len: max_len = length

        if h * max_len_ratio < max_len:
            polys.append(None); continue

        tot_seg = num_cp * 2 + 1
        seg_w = w / tot_seg
        pp = [None] * num_cp
        cp_section = [[0, 0]] * tot_seg
        seg_height = [0] * num_cp
        seg_num = 0
        num_sec = 0
        prev_h = -1
        for i in range(0, len(cp)):
            (x, sy, ey) = cp[i]
            if (seg_num + 1) * seg_w <= x and seg_num <= tot_seg:
                # average previous segment
                if num_sec == 0: break
                cp_section[seg_num] = [cp_section[seg_num][0] / num_sec, cp_section[seg_num][1] / num_sec]
                num_sec = 0

                # reset variables
                seg_num += 1
                prev_h = -1

            # accumulate center points
            cy = (sy + ey) * 0.5
            cur_h = ey - sy + 1
            cp_section[seg_num] = [cp_section[seg_num][0] + x, cp_section[seg_num][1] + cy]
            num_sec += 1

            if seg_num % 2 == 0: continue  # No polygon area

            if prev_h < cur_h:
                pp[int((seg_num - 1)/2)] = (x, cy)
                seg_height[int((seg_num - 1)/2)] = cur_h
                prev_h = cur_h

        # processing last segment
        if num_sec != 0:
            cp_section[-1] = [cp_section[-1][0] / num_sec, cp_section[-1][1] / num_sec]

        # pass if num of pivots is not sufficient or segment width is smaller than character height
        if None in pp or seg_w < np.max(seg_height) * 0.25:
            polys.append(None); continue

        # calc median maximum of pivot points
        half_char_h = np.median(seg_height) * expand_ratio / 2

        # calc gradient and apply to make horizontal pivots
        new_pp = []
        for i, (x, cy) in enumerate(pp):
            dx = cp_section[i * 2 + 2][0] - cp_section[i * 2][0]
            dy = cp_section[i * 2 + 2][1] - cp_section[i * 2][1]
            if dx == 0:  # gradient is zero
                new_pp.append([x, cy - half_char_h, x, cy + half_char_h])
                continue
            rad = - math.atan2(dy, dx)
            c, s = half_char_h * math.cos(rad), half_char_h * math.sin(rad)
            new_pp.append([x - s, cy - c, x + s, cy + c])

        # get edge points to cover character heatmaps
        isSppFound, isEppFound = False, False
        grad_s = (pp[1][1] - pp[0][1]) / (pp[1][0] - pp[0][0]) + (pp[2][1] - pp[1][1]) / (pp[2][0] - pp[1][0])
        grad_e = (pp[-2][1] - pp[-1][1]) / (pp[-2][0] - pp[-1][0]) + (pp[-3][1] - pp[-2][1]) / (pp[-3][0] - pp[-2][0])
        for r in np.arange(0.5, max_r, step_r):
            dx = 2 * half_char_h * r
            if not isSppFound:
                line_img = np.zeros(word_label.shape, dtype=np.uint8)
                dy = grad_s * dx
                p = np.array(new_pp[0]) - np.array([dx, dy, dx, dy])
                cv2.line(line_img, (int(p[0]), int(p[1])), (int(p[2]), int(p[3])), 1, thickness=1)
                if np.sum(np.logical_and(word_label, line_img)) == 0 or r + 2 * step_r >= max_r:
                    spp = p
                    isSppFound = True
            if not isEppFound:
                line_img = np.zeros(word_label.shape, dtype=np.uint8)
                dy = grad_e * dx
                p = np.array(new_pp[-1]) + np.array([dx, dy, dx, dy])
                cv2.line(line_img, (int(p[0]), int(p[1])), (int(p[2]), int(p[3])), 1, thickness=1)
                if np.sum(np.logical_and(word_label, line_img)) == 0 or r + 2 * step_r >= max_r:
                    epp = p
                    isEppFound = True
            if isSppFound and isEppFound:
                break

        if not (isSppFound and isEppFound):
            polys.append(None); continue

        poly = []
        poly.append(warpCoord(Minv, (spp[0], spp[1])))
        for p in new_pp:
            poly.append(warpCoord(Minv, (p[0], p[1])))
        poly.append(warpCoord(Minv, (epp[0], epp[1])))
        poly.append(warpCoord(Minv, (epp[2], epp[3])))
        for p in reversed(new_pp):
            poly.append(warpCoord(Minv, (p[2], p[3])))
        poly.append(warpCoord(Minv, (spp[2], spp[3])))

        # add to final result
        polys.append(np.array(poly))

    return polys

def getDetBoxes(textmap, linkmap, text_threshold, link_threshold, low_text, poly=False):
    boxes, labels, mapper = getDetBoxes_core(textmap, linkmap, text_threshold, link_threshold, low_text)

    if poly:
        polys = getPoly_core(boxes, labels, mapper, linkmap)
    else:
        polys = [None] * len(boxes)

    return boxes, polys

def adjustResultCoordinates(polys, ratio_w, ratio_h, ratio_net = 2):
    if len(polys) > 0:
        polys = np.array(polys)
        for k in range(len(polys)):
            if polys[k] is not None:
                polys[k] *= (ratio_w * ratio_net, ratio_h * ratio_net)
    return polys
dino2/__pycache__/model.cpython-310.pyc ADDED
Binary file (2.85 kB)
 
dino2/model.py ADDED
@@ -0,0 +1,93 @@
import torch
import torch.nn as nn
from torch.hub import load
import torchvision.models as models


dino_backbones = {
    'dinov2_s': {
        'name': 'dinov2_vits14',
        'embedding_size': 384,
        'patch_size': 14
    },
    'dinov2_b': {
        'name': 'dinov2_vitb14',
        'embedding_size': 768,
        'patch_size': 14
    },
    'dinov2_l': {
        'name': 'dinov2_vitl14',
        'embedding_size': 1024,
        'patch_size': 14
    },
    'dinov2_g': {
        'name': 'dinov2_vitg14',
        'embedding_size': 1536,
        'patch_size': 14
    },
}


class linear_head(nn.Module):
    def __init__(self, embedding_size=384, num_classes=5):
        super(linear_head, self).__init__()
        self.fc = nn.Linear(embedding_size, num_classes)

    def forward(self, x):
        return self.fc(x)


class conv_head(nn.Module):
    def __init__(self, embedding_size=384, num_classes=5):
        super(conv_head, self).__init__()
        self.segmentation_conv = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.Conv2d(embedding_size, 64, (3,3), padding=(1,1)),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(64, num_classes, (3,3), padding=(1,1)),
        )

    def forward(self, x):
        x = self.segmentation_conv(x)
        x = torch.sigmoid(x)
        return x


def threshold_mask(predicted, threshold=0.55):
    thresholded_mask = (predicted > threshold).float()
    return thresholded_mask


class Segmentor(nn.Module):
    def __init__(self, device, num_classes, backbone='dinov2_s', head='conv', backbones=dino_backbones):
        super(Segmentor, self).__init__()
        self.heads = {
            'conv': conv_head
        }
        self.backbones = dino_backbones
        self.backbone = load('facebookresearch/dinov2', self.backbones[backbone]['name'])

        self.backbone.eval()
        self.num_classes = num_classes
        self.embedding_size = self.backbones[backbone]['embedding_size']
        self.patch_size = self.backbones[backbone]['patch_size']
        self.head = self.heads[head](self.embedding_size, self.num_classes)
        self.device = device

    def forward(self, x):
        batch_size = x.shape[0]
        mask_dim = (x.shape[2] / self.patch_size, x.shape[3] / self.patch_size)
        x = self.backbone.forward_features(x.to(self.device))

        x = x['x_norm_patchtokens']
        x = x.permute(0, 2, 1)
        x = x.reshape(batch_size, self.embedding_size, int(mask_dim[0]), int(mask_dim[1]))
        x = self.head(x)
        return x
file_utils.py ADDED
@@ -0,0 +1,77 @@
# -*- coding: utf-8 -*-
import os
import numpy as np
import cv2
import imgproc
from mosaik import mosaik

# borrowed from https://github.com/lengstrom/fast-style-transfer/blob/master/src/utils.py
def get_files(img_dir):
    imgs, masks, xmls = list_files(img_dir)
    return imgs, masks, xmls

def list_files(in_path):
    img_files = []
    mask_files = []
    gt_files = []
    for (dirpath, dirnames, filenames) in os.walk(in_path):
        for file in filenames:
            filename, ext = os.path.splitext(file)
            ext = str.lower(ext)
            if ext == '.jpg' or ext == '.jpeg' or ext == '.gif' or ext == '.png' or ext == '.pgm':
                img_files.append(os.path.join(dirpath, file))
            elif ext == '.bmp':
                mask_files.append(os.path.join(dirpath, file))
            elif ext == '.xml' or ext == '.gt' or ext == '.txt':
                gt_files.append(os.path.join(dirpath, file))
            elif ext == '.zip':
                continue
    # img_files.sort()
    # mask_files.sort()
    # gt_files.sort()
    return img_files, mask_files, gt_files

def saveResult(img_file, img, boxes, dirname='./result/', verticals=None, texts=None):
    """ Save text detection results one by one.
    Args:
        img_file (str): image file name
        img (array): raw image context
        boxes (array): array of result boxes
            Shape: [num_detections, 4] for BB output / [num_detections, 4] for QUAD output
    Return:
        None
    """
    img = np.array(img)

    # make result file list
    filename, file_ext = os.path.splitext(os.path.basename(img_file))

    # result directory
    res_file = dirname + "res_" + filename + '.txt'
    res_img_file = dirname + "res_" + filename + '.jpg'

    if not os.path.isdir(dirname):
        os.mkdir(dirname)

    with open(res_file, 'w') as f:
        for i, box in enumerate(boxes):
            poly = np.array(box).astype(np.int32).reshape((-1))
            strResult = ','.join([str(p) for p in poly]) + '\r\n'
            f.write(strResult)

            poly = poly.reshape(-1, 2)
            cv2.polylines(img, [poly.reshape((-1, 1, 2))], True, color=(0, 0, 255), thickness=2)
            ptColor = (0, 255, 255)
            if verticals is not None:
                if verticals[i]:
                    ptColor = (255, 0, 0)

            if texts is not None:
                font = cv2.FONT_HERSHEY_SIMPLEX
                font_scale = 0.5
                cv2.putText(img, "{}".format(texts[i]), (poly[0][0]+1, poly[0][1]+1), font, font_scale, (0, 0, 0), thickness=1)
                cv2.putText(img, "{}".format(texts[i]), tuple(poly[0]), font, font_scale, (0, 255, 255), thickness=1)

    # Save result image
    cv2.imwrite(res_img_file, img)
    return img
imgproc.py ADDED
@@ -0,0 +1,70 @@
"""
Copyright (c) 2019-present NAVER Corp.
MIT License
"""

# -*- coding: utf-8 -*-
import numpy as np
from skimage import io
import cv2

def loadImage(img_file):
    img = io.imread(img_file)  # RGB order
    if img.shape[0] == 2: img = img[0]
    if len(img.shape) == 2: img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
    if img.shape[2] == 4: img = img[:,:,:3]
    img = np.array(img)

    return img

def normalizeMeanVariance(in_img, mean=(0.485, 0.456, 0.406), variance=(0.229, 0.224, 0.225)):
    # should be RGB order
    img = in_img.copy().astype(np.float32)

    img -= np.array([mean[0] * 255.0, mean[1] * 255.0, mean[2] * 255.0], dtype=np.float32)
    img /= np.array([variance[0] * 255.0, variance[1] * 255.0, variance[2] * 255.0], dtype=np.float32)
    return img

def denormalizeMeanVariance(in_img, mean=(0.485, 0.456, 0.406), variance=(0.229, 0.224, 0.225)):
    # should be RGB order
    img = in_img.copy()
    img *= variance
    img += mean
    img *= 255.0
    img = np.clip(img, 0, 255).astype(np.uint8)
    return img

def resize_aspect_ratio(img, square_size, interpolation, mag_ratio=1):
    height, width, channel = img.shape

    # magnify image size
    target_size = mag_ratio * max(height, width)

    # set original image size
    if target_size > square_size:
        target_size = square_size

    ratio = target_size / max(height, width)

    target_h, target_w = int(height * ratio), int(width * ratio)
    proc = cv2.resize(img, (target_w, target_h), interpolation=interpolation)

    # make canvas and paste image
    target_h32, target_w32 = target_h, target_w
    if target_h % 32 != 0:
        target_h32 = target_h + (32 - target_h % 32)
    if target_w % 32 != 0:
        target_w32 = target_w + (32 - target_w % 32)
    resized = np.zeros((target_h32, target_w32, channel), dtype=np.float32)
    resized[0:target_h, 0:target_w, :] = proc
    target_h, target_w = target_h32, target_w32

    size_heatmap = (int(target_w/2), int(target_h/2))

    return resized, ratio, size_heatmap

def cvt2HeatmapImg(img):
    img = (np.clip(img, 0, 1) * 255).astype(np.uint8)
    img = cv2.applyColorMap(img, cv2.COLORMAP_JET)
    return img
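For reference (not part of the commit itself), a minimal sketch of how the helpers above are combined before CRAFT inference; the file name, canvas size and magnification ratio simply mirror the defaults used in main.py:

    import cv2
    import imgproc

    img = imgproc.loadImage("input/1.png")  # RGB, H x W x 3
    resized, ratio, size_heatmap = imgproc.resize_aspect_ratio(
        img, square_size=1280, interpolation=cv2.INTER_LINEAR, mag_ratio=1.5)
    x = imgproc.normalizeMeanVariance(resized)  # float32, ImageNet mean/variance scaled by 255
    # x.transpose(2, 0, 1) with a batch dimension is what test_net() in main.py feeds to CRAFT;
    # size_heatmap is half the padded resolution because the score maps come out at 1/2 scale.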
input/1.png ADDED
input/2.png ADDED
input/3.png ADDED
input/4.png ADDED
install.sh ADDED
@@ -0,0 +1 @@
pip install -r requirements.txt
main.py ADDED
@@ -0,0 +1,354 @@
from recognize import recongize
from ner import ner
import os
import time
import argparse
from sr.sr import sr
import torch
from scipy.ndimage import gaussian_filter
from PIL import Image
import numpy as np
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.autograd import Variable
from mosaik import mosaik
from PIL import Image
import cv2
from skimage import io
import numpy as np
import craft_utils
import imgproc
import file_utils
from seg import mask_percentage

from seg2 import dino_seg

from craft import CRAFT
from collections import OrderedDict
import gradio as gr
from refinenet import RefineNet


# Code that loads the CRAFT and refiner models
def copyStateDict(state_dict):
    if list(state_dict.keys())[0].startswith("module"):
        start_idx = 1
    else:
        start_idx = 0
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = ".".join(k.split(".")[start_idx:])
        new_state_dict[name] = v
    return new_state_dict

def str2bool(v):
    return v.lower() in ("yes", "y", "true", "t", "1")

parser = argparse.ArgumentParser(description='CRAFT Text Detection')
parser.add_argument('--trained_model', default='weights/craft_mlt_25k.pth', type=str, help='pretrained CRAFT model')
parser.add_argument('--text_threshold', default=0.7, type=float, help='text confidence threshold')
parser.add_argument('--low_text', default=0.4, type=float, help='text low-bound score')
parser.add_argument('--link_threshold', default=0.4, type=float, help='link confidence threshold')
parser.add_argument('--cuda', default=True, type=str2bool, help='Use cuda for inference')
parser.add_argument('--canvas_size', default=1280, type=int, help='image size for inference')
parser.add_argument('--mag_ratio', default=1.5, type=float, help='image magnification ratio')
parser.add_argument('--poly', default=False, action='store_true', help='enable polygon type')
parser.add_argument('--refine', default=True, help='enable link refiner')
parser.add_argument('--image_path', default="input/2.png", help='input image')
parser.add_argument('--refiner_model', default='weights/craft_refiner_CTW1500.pth', type=str, help='pretrained refiner model')

args = parser.parse_args()

# The function below is optional.
def full_img_masking(full_image,net,refine_net):
    reference_image=sr(full_image)
    reference_boxes=text_detect(reference_image,net=net,refine_net=refine_net)
    boxes=get_box_from_refer(reference_boxes)
    for index2,box in enumerate(boxes):
        xmin,xmax,ymin,ymax=get_min_max(box)

        text_area=full_image[int(ymin):int(ymax),int(xmin):int(xmax),:]

        text=recongize(text_area)
        label=ner(text)

        if label==1:
            A=full_image[int(ymin):int(ymax),int(xmin):int(xmax),:]
            full_image[int(ymin):int(ymax),int(xmin):int(xmax),:] = gaussian_filter(A, sigma=16)
    return full_image

def test_net(net, image, text_threshold, link_threshold, low_text, cuda, poly, refine_net=None):
    t0 = time.time()

    img_resized, target_ratio, size_heatmap = imgproc.resize_aspect_ratio(image, args.canvas_size, interpolation=cv2.INTER_LINEAR, mag_ratio=args.mag_ratio)
    ratio_h = ratio_w = 1 / target_ratio

    x = imgproc.normalizeMeanVariance(img_resized)
    x = torch.from_numpy(x).permute(2, 0, 1)
    x = Variable(x.unsqueeze(0))
    if cuda:
        x = x.cuda()

    with torch.no_grad():
        y, feature = net(x)

    score_text = y[0,:,:,0].cpu().data.numpy()
    score_link = y[0,:,:,1].cpu().data.numpy()

    if refine_net is not None:
        with torch.no_grad():
            y_refiner = refine_net(y, feature)
        score_link = y_refiner[0,:,:,0].cpu().data.numpy()

    t0 = time.time() - t0
    t1 = time.time()

    boxes, polys = craft_utils.getDetBoxes(score_text, score_link, text_threshold, link_threshold, low_text, poly)

    boxes = craft_utils.adjustResultCoordinates(boxes, ratio_w, ratio_h)
    polys = craft_utils.adjustResultCoordinates(polys, ratio_w, ratio_h)
    for k in range(len(polys)):
        if polys[k] is None: polys[k] = boxes[k]

    t1 = time.time() - t1

    # render results (optional)
    render_img = score_text.copy()
    render_img = np.hstack((render_img, score_link))
    ret_score_text = imgproc.cvt2HeatmapImg(render_img)

    return boxes, polys, ret_score_text

def text_detect(image,net,refine_net):

    bboxes, polys, score_text = test_net(net, image, args.text_threshold, args.link_threshold, args.low_text, args.cuda, args.poly, refine_net)

    return bboxes


def get_box_from_refer(reference_boxes):

    real_boxes=[]
    for box in reference_boxes:

        real_boxes.append(box//2)

    return real_boxes

def get_min_max(box):
    xlist=[]
    ylist=[]
    for coor in box:
        xlist.append(coor[0])
        ylist.append(coor[1])
    return min(xlist),max(xlist),min(ylist),max(ylist)

def main(image_path0):
    # Step 1

    # ==> Load the CRAFT and RefineNet models and move them onto the CUDA device.

    net = CRAFT()
    if args.cuda:
        net.load_state_dict(copyStateDict(torch.load(args.trained_model)))

    if args.cuda:
        net = net.cuda()
        cudnn.benchmark = False

    net.eval()

    refine_net = None
    if args.refine:
        refine_net = RefineNet()
        if args.cuda:
            refine_net.load_state_dict(copyStateDict(torch.load(args.refiner_model)))
            refine_net = refine_net.cuda()

        refine_net.eval()
        args.poly = True

    # Step 2

    # The image dropped into the Gradio input field is passed in here as A.

    A=image_path0
    image_list=[]
    image_list.append(A)
    for k, image_path in enumerate(image_list):

        image = imgproc.loadImage(image_path)
        if image.shape[2]>3:
            image=image[:,:,0:3]

        original_image=image
        # Segment only the invoice region of the image with the DINOv2 model.

        output=dino_seg(image)
        image3=Image.fromarray(output)
        image3.save("temporal_mask/mask.png")

        # A mask image is produced (white pixels on a black background).
        # From that mask, compute what percentage of the whole image each blob (assumed invoice) occupies.
        contours_list,percentage_list=mask_percentage("temporal_mask/mask.png")

        normal_image_list=[]

        small_coordinate_list=[]
        original_coordinate_list=[]


        # Step 3

        sorted_list = sorted(percentage_list, reverse=True)
        top_5 = sorted_list[:5]
        print("Top 5 values:", top_5)
        # percentage_list holds the area percentages of the blobs assumed to be invoices,
        # and contours_list holds the corresponding blobs in the same order.
        # e.g. the first element of percentage_list is the percentage of the first element of contours_list.

        for index,percentage in enumerate(percentage_list):

            if 5<percentage:

                # Blobs that occupy more than 5% of the image go into the normal list.
                # The normal list collects blobs (assumed invoices) that are large enough in the image.
                # Blobs between 1% and 5% go into small_coordinate_list and are treated as very small blobs.
                # For very small blobs the text is barely readable even when zoomed in, so the whole blob is mosaicked.
                # Blobs under 1% are close to vanishing and are skipped.

                contour=contours_list[index]

                x_list=[]
                y_list=[]
                contour2=list(contour)

                for r in contour2:
                    r2=r[0]
                    x_list.append(r2[0])
                    y_list.append(r2[1])
                x_min=min(x_list)
                y_min=min(y_list)
                x_max=max(x_list)
                y_max=max(y_list)
                original_coordinate_list.append([y_min,y_max,x_min,x_max])
                image2=original_image[y_min:y_max,x_min:x_max,:]
                normal_image_list.append(image2)

            elif 1<percentage<5:
                contour=contours_list[index]

                x_list=[]
                y_list=[]
                contour2=list(contour)

                for r in contour2:
                    r2=r[0]
                    x_list.append(r2[0])
                    y_list.append(r2[1])
                x_min=min(x_list)
                y_min=min(y_list)
                x_max=max(x_list)
                y_max=max(y_list)
                small_coordinate_list.append([y_min,y_max,x_min,x_max])  # coordinates of invoices under 5%
            else:
                continue


        # Step 4 (very small invoices)

        # small_coordinate_list gathers the very small invoices; if it is empty, go straight to Step 5.
        # Otherwise (at least one element), mosaik() blurs every coordinate range in the full image
        # that belongs to a small blob.

        if len(small_coordinate_list)>0:
            original_image=mosaik(original_image,small_coordinate_list)
        else:
            pass

        # Step 5 (reasonably sized invoices) ==> normal list

        # normal_image_list holds invoices of adequate size (text becomes readable when zoomed in).
        # Each cropped invoice is upscaled with ESRGAN so that CRAFT can return text locations reliably.
        # The upscaled invoice goes through CRAFT to obtain all text coordinates accurately.
        # The coordinates are not returned at the upscaled resolution; they are rescaled (//2)
        # back to the original invoice image to obtain the final coordinates.

        for index,normal_image in enumerate(normal_image_list):
            reference_image=sr(normal_image)
            reference_boxes=text_detect(reference_image,net=net,refine_net=refine_net)
            boxes=get_box_from_refer(reference_boxes)
            for index2,box in enumerate(boxes):
                xmin,xmax,ymin,ymax=get_min_max(box)

                text_area=normal_image[int(ymin):int(ymax),int(xmin):int(xmax),:]
                text_area=Image.fromarray(text_area)
                os.makedirs("text_area",exist_ok=True)
                text_area.save(f"text_area/new_{index2+1}.png")


                # Step 6 (text recognition, NER)

                # Using the coordinates above, boxes are cropped inside the invoice.
                # Each cropped region (a box assumed to contain text) is passed to TrOCR.
                # TrOCR returns the text it reads inside the box.
                # The text is passed to KoELECTRA to decide whether it is personal information.
                # If the box is judged to contain personal information (label 1), it is mosaicked:
                # the box coordinates are converted to the invoice image's coordinates and that
                # region is blurred.
                # The partially mosaicked invoice is then pasted back into the full image.

                text=recongize(text_area)
                label=ner(text)
                with open("output/text_recongnize.txt","a") as recognized:
                    recognized.writelines(str(index2+1))
                    recognized.writelines(" ")
                    recognized.writelines(str(text))
                    recognized.writelines(" ")
                    recognized.writelines(str(label))
                    recognized.writelines("\n")
                    recognized.close()
                print("done")
                if label==1:
                    A=normal_image[int(ymin):int(ymax),int(xmin):int(xmax),:]
                    normal_image[int(ymin):int(ymax),int(xmin):int(xmax),:] = gaussian_filter(A, sigma=16)

                else:
                    pass
            a,b,c,d=original_coordinate_list[index]
            original_image[a:b,c:d,:]=normal_image

        # To push accuracy further, the whole image (invoice plus background) can also be run
        # through CRAFT in one pass; this is optional (downside: inference speed).

        #print("full mask start")
        #original_image=full_img_masking(original_image,net=net,refine_net=refine_net)
        #print("full mask done")


        original_image=Image.fromarray(original_image)
        original_image.save("output/mosaiked.png")
        print("masked complete")
        return original_image


if __name__ == '__main__':

    iface = gr.Interface(
        fn=main,
        inputs=gr.Image(type="filepath", label="Invoice Image"),
        outputs=gr.Image(type="pil", label="Masked Invoice Image"),
        live=True
    )

    iface.launch()
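By default the script launches the Gradio interface above. As a hedged sketch, the same pipeline can be driven headlessly, assuming the weights referenced by the argparse defaults are present and a CUDA device is available (note that argparse runs at import time, so importing main.py from another script that has its own command-line flags may need care):

    from main import main

    masked = main("input/2.png")            # PIL.Image with personal information blurred
    masked.save("output/mosaiked_copy.png") # main() also writes output/mosaiked.png itself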
mosaik.py ADDED
@@ -0,0 +1,32 @@
import numpy as np
from PIL import Image
import cv2
from scipy.ndimage import gaussian_filter
import matplotlib.pyplot as plt
from matplotlib.patches import Polygon



# Masks the whole cropped region for blobs that take up between 1% and 5% of the image.
def mosaik(img, bboxes):
    for box in bboxes:
        # box = [y_min, y_max, x_min, x_max]
        cropped = img[box[0]:box[1], box[2]:box[3], :]

        cropped = np.array(cropped)
        cropped = gaussian_filter(cropped, sigma=16)
        img[box[0]:box[1], box[2]:box[3], :] = cropped

    return img
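A small usage sketch with illustrative coordinates; each box follows the [y_min, y_max, x_min, x_max] convention noted above, and the blur is applied in place on the numpy array:

    import numpy as np
    from mosaik import mosaik

    img = np.zeros((400, 600, 3), dtype=np.uint8)   # stand-in for a loaded photo
    blurred = mosaik(img, [[50, 120, 200, 320]])    # that rectangle is Gaussian-blurred (sigma=16)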
ner.py ADDED
@@ -0,0 +1,106 @@
from transformers import AutoTokenizer, AutoModelForTokenClassification
from transformers import pipeline
from collections import defaultdict
import torch

device = torch.device("cuda")
tokenizer = AutoTokenizer.from_pretrained("Leo97/KoELECTRA-small-v3-modu-ner")
model = AutoModelForTokenClassification.from_pretrained("Leo97/KoELECTRA-small-v3-modu-ner")
model.to(device)

# This step runs after the assumed invoice region has passed through CRAFT, the text regions
# have been cropped, and TrOCR has extracted the text from each region.
# It decides the class of the extracted text.
# If the text belongs to person name (PS), road/building name (AF) or address (LC),
# 1 is returned so that the region gets mosaicked afterwards.
# The NER model splits the text into tokens and returns a class for each token.
# Considering every token's class makes inference very slow, so 1 is returned as soon as
# at least one token belongs to the PS, AF or LC classes.

def check_entity(entities):
    for entity_info in entities:
        entity_value = entity_info.get('entity', '').upper()
        if 'LC' in entity_value or 'PS' in entity_value or 'AF' in entity_value:
            return 1
    return 0

def ner(example):
    ner = pipeline("ner", model=model, tokenizer=tokenizer, device=device)
    ner_results = ner(example)
    ner_results = check_entity(ner_results)
    return ner_results



# Option one
# def find_longest_value_key(input_dict):
#     max_length = 0
#     max_length_keys = []

#     for key, value in input_dict.items():
#         current_length = len(value)
#         if current_length > max_length:
#             max_length = current_length
#             max_length_keys = [key]
#         elif current_length == max_length:
#             max_length_keys.append(key)

#     if len(max_length_keys) == 1:
#         return 0
#     else:
#         return 1



# def find_longest_value_key2(input_dict):
#     if not input_dict:
#         return None

#     max_key = max(input_dict, key=lambda k: len(input_dict[k]))
#     return max_key


# def find_most_frequent_entity(entities):
#     entity_counts = defaultdict(list)

#     for item in entities:
#         split_entity = item['entity'].split('-')

#         entity_type = split_entity[1]
#         entity_counts[entity_type].append(item['score'])
#     number = find_longest_value_key(entity_counts)
#     if number == 1:
#         max_entities = []
#         max_score_average = -1

#         for entity, scores in entity_counts.items():
#             score_average = sum(scores) / len(scores)

#             if score_average > max_score_average:
#                 max_entities = [entity]
#                 max_score_average = score_average
#             elif score_average == max_score_average:
#                 max_entities.append(entity)
#         if len(max_entities) > 0:
#             return max_entities if len(max_entities) > 1 else max_entities[0]
#         else:
#             return "Do not mosaik"
#     else:
#         A = find_longest_value_key2(entity_counts)

#         return A




# If there is even a single PS or LC token, return PS/LC right away

# label = filtering(ner_results)
# if label.find("PS") > -1 or label.find("LC") > -1:
#     return 1
# else:
#     return 0
# print(ner("홍길동"))




# label = check_label(example)
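A hedged usage sketch: the sample string is hypothetical, a CUDA device is assumed (the module pins the model to cuda at import time), and the returned label depends entirely on the KoELECTRA predictions (1 if any token is tagged PS, AF or LC, otherwise 0):

    from ner import ner

    label = ner("홍길동 서울특별시 강남구")  # expected to be 1 for a name plus an address
    print(label)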
recognize.py ADDED
@@ -0,0 +1,18 @@
from transformers import TrOCRProcessor, VisionEncoderDecoderModel, AutoTokenizer
import unicodedata

# The TrOCR weights were taken from Hugging Face, fine-tuned, and saved to the trocr_weight
# folder (the tokenizer and processor are saved there as well).
# The image passed to recongize() is a region cropped by CRAFT inside the invoice, i.e. a
# region assumed to contain text; the function extracts the text likely present in that region.


def recongize(img):
    processor = TrOCRProcessor.from_pretrained("trocr_weight")
    model = VisionEncoderDecoderModel.from_pretrained("trocr_weight")
    tokenizer = AutoTokenizer.from_pretrained("trocr_weight")

    pixel_values = processor(img, return_tensors="pt").pixel_values
    generated_ids = model.generate(pixel_values, max_length=64)
    generated_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    generated_text = unicodedata.normalize("NFC", generated_text)
    return generated_text
refinenet.py ADDED
@@ -0,0 +1,65 @@
"""
Copyright (c) 2019-present NAVER Corp.
MIT License
"""

# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from basenet.vgg16_bn import init_weights


class RefineNet(nn.Module):
    def __init__(self):
        super(RefineNet, self).__init__()

        self.last_conv = nn.Sequential(
            nn.Conv2d(34, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True)
        )

        self.aspp1 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, dilation=6, padding=6), nn.BatchNorm2d(128), nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=1), nn.BatchNorm2d(128), nn.ReLU(inplace=True),
            nn.Conv2d(128, 1, kernel_size=1)
        )

        self.aspp2 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, dilation=12, padding=12), nn.BatchNorm2d(128), nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=1), nn.BatchNorm2d(128), nn.ReLU(inplace=True),
            nn.Conv2d(128, 1, kernel_size=1)
        )

        self.aspp3 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, dilation=18, padding=18), nn.BatchNorm2d(128), nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=1), nn.BatchNorm2d(128), nn.ReLU(inplace=True),
            nn.Conv2d(128, 1, kernel_size=1)
        )

        self.aspp4 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, dilation=24, padding=24), nn.BatchNorm2d(128), nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=1), nn.BatchNorm2d(128), nn.ReLU(inplace=True),
            nn.Conv2d(128, 1, kernel_size=1)
        )

        init_weights(self.last_conv.modules())
        init_weights(self.aspp1.modules())
        init_weights(self.aspp2.modules())
        init_weights(self.aspp3.modules())
        init_weights(self.aspp4.modules())

    def forward(self, y, upconv4):
        refine = torch.cat([y.permute(0,3,1,2), upconv4], dim=1)
        refine = self.last_conv(refine)

        aspp1 = self.aspp1(refine)
        aspp2 = self.aspp2(refine)
        aspp3 = self.aspp3(refine)
        aspp4 = self.aspp4(refine)

        #out = torch.add([aspp1, aspp2, aspp3, aspp4], dim=1)
        out = aspp1 + aspp2 + aspp3 + aspp4
        return out.permute(0, 2, 3, 1)  # , refine.permute(0,2,3,1)
requirements.txt ADDED
@@ -0,0 +1,18 @@
numpy
torch
pillow
matplotlib
transformers
scipy
torchvision
opencv-python
scikit-image
gradio
# unicodedata, math, os and collections are part of the Python standard library
# and do not need to be installed with pip.
reset.sh ADDED
@@ -0,0 +1,15 @@
rm -rf output
mkdir output

rm -rf flagged

rm -rf temporal_mask
mkdir temporal_mask

rm -rf text_area
mkdir text_area

# Run after inference (cleans up the result folders).
seg.py ADDED
@@ -0,0 +1,46 @@
from unet.predict import predict_img, mask_to_image
import torch
from PIL import Image
import numpy as np
import cv2

# The function below is the U-Net model, which was tried before DINOv2.
def segmentation(img):
    device = torch.device("cuda")
    net = torch.load("weights/unet.pth")

    mask_values = [[0, 0, 0], [255, 255, 255]]

    mask = predict_img(net, img, device, scale_factor=1, out_threshold=0.5)
    result = mask_to_image(mask, mask_values)
    result = np.array(result)

    return result


# Computes what percentage of the image each region cropped by the segmentation above occupies.
# The function below is used with both DINOv2 and the U-Net.
# It measures, for each blob of connected white pixels, the fraction of the whole image it covers.
# It also works when there are two or more blobs (assumed invoices).
def mask_percentage(mask_path):

    image = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)

    ret, threshold = cv2.threshold(image, 127, 255, cv2.THRESH_BINARY)

    contours, hierarchy = cv2.findContours(threshold, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    total_area = image.shape[0] * image.shape[1]
    contours_list = contours

    contour_areas = [cv2.contourArea(contour) for contour in contours]

    percentages = [(area / total_area) * 100 for area in contour_areas]
    percentage_list = []
    for i, percentage in enumerate(percentages):
        percentage_list.append(percentage)
    return contours_list, percentage_list
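A usage sketch mirroring how main.py consumes mask_percentage; the mask path is the file main.py writes, and the 5% / 1% thresholds are the ones used there:

    from seg import mask_percentage

    contours, percentages = mask_percentage("temporal_mask/mask.png")
    big_blobs  = [c for c, p in zip(contours, percentages) if p > 5]      # treated as full invoices
    tiny_blobs = [c for c, p in zip(contours, percentages) if 1 < p < 5]  # mosaicked wholesale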
seg2.py ADDED
@@ -0,0 +1,68 @@
import torch
from dino2.model import Segmentor
from PIL import Image
from torchvision import transforms
import numpy as np
import cv2



T2 = transforms.ToPILImage()


weights = "/home/ksyint/other1213/craft_ku/weights/dinov2.pt"
device = torch.device("cpu")
model = Segmentor(device, 1, backbone='dinov2_b', head="conv")
model.load_state_dict(torch.load(weights, map_location="cpu"))
model = model.to(device)


img_transform = transforms.Compose([
    transforms.Resize((14*64, 14*64)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# Code for cropping out the invoice region.
# DINOv2 is used as the segmentation model; it was trained separately, the resulting weights
# were saved to the weights folder, and they are loaded as shown above.

def dino_seg(numpy_array):
    img0 = Image.fromarray(numpy_array)
    original_size = img0.size
    img = img_transform(img0)
    a = img.unsqueeze(0)
    b = model(a)
    b = b.squeeze(0)
    b = b * 255.0
    model_output = T2(b)  # PIL image
    model_output = model_output.resize(original_size)

    model_output = np.array(model_output)
    model_output[model_output > 220] = 255.0
    model_output[model_output <= 220] = 0.0
    model_output2 = model_output
    model_output3 = model_output
    output = np.stack([model_output, model_output2, model_output3])
    output = np.transpose(output, (1, 2, 0))

    return output




# def find_connected_components(image):

#     _, binary_image = cv2.threshold(image, 150, 255, cv2.THRESH_BINARY)

#     _, labels, stats, centroids = cv2.connectedComponentsWithStats(binary_image, connectivity=4)

#     largest_component_index = np.argmax(stats[1:, cv2.CC_STAT_AREA]) + 1

#     x, y, w, h, _ = stats[largest_component_index, cv2.CC_STAT_LEFT:cv2.CC_STAT_TOP+cv2.CC_STAT_HEIGHT+1]

#     return x, y, w, h
sr/__pycache__/sr.cpython-310.pyc ADDED
Binary file (606 Bytes)
 
sr/esrgan ADDED
@@ -0,0 +1 @@
Subproject commit 362a0316878f41dbdfbb23657b450c3353de5acf
sr/sr.py ADDED
@@ -0,0 +1,15 @@
import torch
from PIL import Image
import numpy as np
from .esrgan.RealESRGAN import RealESRGAN

def sr(img):
    # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    device = torch.device("cuda")
    model = RealESRGAN(device, scale=2)
    model.load_weights('weights/RealESRGAN_x2.pth', download=True)

    img = Image.fromarray(img)
    sr_image = model.predict(img)
    sr_image = np.array(sr_image)
    return sr_image
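Usage sketch (assumes a CUDA device and the RealESRGAN_x2 weights under weights/); input and output are numpy arrays, with the output roughly twice the height and width:

    import numpy as np
    from sr.sr import sr

    crop = np.zeros((100, 200, 3), dtype=np.uint8)  # stand-in for a cropped invoice
    upscaled = sr(crop)                             # ~200 x 400 x 3 after 2x RealESRGAN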
unet/__pycache__/predict.cpython-310.pyc ADDED
Binary file (1.89 kB)
 
unet/dino/__init__.py ADDED
@@ -0,0 +1 @@
from .model import Dinov2
unet/dino/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (202 Bytes)

unet/dino/__pycache__/model.cpython-310.pyc ADDED
Binary file (1.63 kB)

unet/dino/__pycache__/modules.cpython-310.pyc ADDED
Binary file (2.58 kB)

unet/dino/__pycache__/parts.cpython-310.pyc ADDED
Binary file (2.58 kB)

unet/dino/model.py ADDED
@@ -0,0 +1,47 @@

from .parts import *


class Dinov2(nn.Module):
    def __init__(self, n_channels, n_classes, bilinear=False):
        super(Dinov2, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        self.inc = (DoubleConv(n_channels, 64))
        self.down1 = (Down(64, 128))
        self.down2 = (Down(128, 256))
        self.down3 = (Down(256, 512))
        factor = 2 if bilinear else 1
        self.down4 = (Down(512, 1024 // factor))
        self.up1 = (Up(1024, 512 // factor, bilinear))
        self.up2 = (Up(512, 256 // factor, bilinear))
        self.up3 = (Up(256, 128 // factor, bilinear))
        self.up4 = (Up(128, 64, bilinear))
        self.outc = (OutConv(64, n_classes))

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)
        return logits

    def use_checkpointing(self):
        self.inc = torch.utils.checkpoint(self.inc)
        self.down1 = torch.utils.checkpoint(self.down1)
        self.down2 = torch.utils.checkpoint(self.down2)
        self.down3 = torch.utils.checkpoint(self.down3)
        self.down4 = torch.utils.checkpoint(self.down4)
        self.up1 = torch.utils.checkpoint(self.up1)
        self.up2 = torch.utils.checkpoint(self.up2)
        self.up3 = torch.utils.checkpoint(self.up3)
        self.up4 = torch.utils.checkpoint(self.up4)
        self.outc = torch.utils.checkpoint(self.outc)
unet/dino/parts.py ADDED
@@ -0,0 +1,67 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class DoubleConv(nn.Module):

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)


class Down(nn.Module):

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels)
        )

    def forward(self, x):
        return self.maxpool_conv(x)


class Up(nn.Module):

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()

        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
        else:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]

        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)


class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)
unet/predict.py ADDED
@@ -0,0 +1,76 @@

import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image


def preprocess(mask_values, pil_img, scale, is_mask):
    pil_img = Image.fromarray(pil_img)
    w, h = pil_img.size
    newW, newH = int(scale * w), int(scale * h)
    pil_img = pil_img.resize((newW, newH))
    img = np.asarray(pil_img)

    if is_mask:
        mask = np.zeros((newH, newW), dtype=np.int64)
        for i, v in enumerate(mask_values):
            if img.ndim == 2:
                mask[img == v] = i
            else:
                mask[(img == v).all(-1)] = i

        return mask

    else:
        if img.ndim == 2:
            img = img[np.newaxis, ...]
        else:
            img = img.transpose((2, 0, 1))

        if (img > 1).any():
            img = img / 255.0

        return img

def predict_img(net,
                full_img,
                device,
                scale_factor=1,
                out_threshold=0.5):
    net.eval()
    img = torch.from_numpy(preprocess(None, full_img, scale_factor, is_mask=False))
    img = img.unsqueeze(0)
    img = img.to(device=device, dtype=torch.float32)

    with torch.no_grad():
        output = net(img).cpu()

        if net.n_classes > 1:
            mask = output.argmax(dim=1)
        else:
            mask = torch.sigmoid(output) > out_threshold

    return mask[0].long().squeeze().numpy()


def mask_to_image(mask: np.ndarray, mask_values):
    if isinstance(mask_values[0], list):
        out = np.zeros((mask.shape[-2], mask.shape[-1], len(mask_values[0])), dtype=np.uint8)
    elif mask_values == [0, 1]:
        out = np.zeros((mask.shape[-2], mask.shape[-1]), dtype=bool)
    else:
        out = np.zeros((mask.shape[-2], mask.shape[-1]), dtype=np.uint8)

    if mask.ndim == 3:
        mask = np.argmax(mask, axis=0)

    for i, v in enumerate(mask_values):
        out[mask == i] = v

    return Image.fromarray(out)