Initial Commit
- __pycache__/craft.cpython-310.pyc +0 -0
- __pycache__/craft_utils.cpython-310.pyc +0 -0
- __pycache__/file_utils.cpython-310.pyc +0 -0
- __pycache__/imgproc.cpython-310.pyc +0 -0
- __pycache__/mosaik.cpython-310.pyc +0 -0
- __pycache__/ner.cpython-310.pyc +0 -0
- __pycache__/recognize.cpython-310.pyc +0 -0
- __pycache__/refinenet.cpython-310.pyc +0 -0
- __pycache__/seg.cpython-310.pyc +0 -0
- __pycache__/seg2.cpython-310.pyc +0 -0
- basenet/__init__.py +0 -0
- basenet/__pycache__/__init__.cpython-310.pyc +0 -0
- basenet/__pycache__/vgg16_bn.cpython-310.pyc +0 -0
- basenet/vgg16_bn.py +72 -0
- craft.py +76 -0
- craft_utils.py +217 -0
- dino2/__pycache__/model.cpython-310.pyc +0 -0
- dino2/model.py +93 -0
- file_utils.py +77 -0
- imgproc.py +70 -0
- input/1.png +0 -0
- input/2.png +0 -0
- input/3.png +0 -0
- input/4.png +0 -0
- install.sh +1 -0
- main.py +354 -0
- mosaik.py +32 -0
- ner.py +106 -0
- recognize.py +18 -0
- refinenet.py +65 -0
- requirements.txt +18 -0
- reset.sh +15 -0
- seg.py +46 -0
- seg2.py +68 -0
- sr/__pycache__/sr.cpython-310.pyc +0 -0
- sr/esrgan +1 -0
- sr/sr.py +15 -0
- unet/__pycache__/predict.cpython-310.pyc +0 -0
- unet/dino/__init__.py +1 -0
- unet/dino/__pycache__/__init__.cpython-310.pyc +0 -0
- unet/dino/__pycache__/model.cpython-310.pyc +0 -0
- unet/dino/__pycache__/modules.cpython-310.pyc +0 -0
- unet/dino/__pycache__/parts.cpython-310.pyc +0 -0
- unet/dino/model.py +47 -0
- unet/dino/parts.py +67 -0
- unet/predict.py +76 -0
__pycache__/craft.cpython-310.pyc
ADDED
Binary file (2.38 kB)

__pycache__/craft_utils.cpython-310.pyc
ADDED
Binary file (5.68 kB)

__pycache__/file_utils.cpython-310.pyc
ADDED
Binary file (2.49 kB)

__pycache__/imgproc.cpython-310.pyc
ADDED
Binary file (2.08 kB)

__pycache__/mosaik.cpython-310.pyc
ADDED
Binary file (698 Bytes)

__pycache__/ner.cpython-310.pyc
ADDED
Binary file (906 Bytes)

__pycache__/recognize.cpython-310.pyc
ADDED
Binary file (716 Bytes)

__pycache__/refinenet.cpython-310.pyc
ADDED
Binary file (1.93 kB)

__pycache__/seg.cpython-310.pyc
ADDED
Binary file (1.5 kB)

__pycache__/seg2.cpython-310.pyc
ADDED
Binary file (1.27 kB)

basenet/__init__.py
ADDED
File without changes

basenet/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (164 Bytes)

basenet/__pycache__/vgg16_bn.cpython-310.pyc
ADDED
Binary file (2.27 kB)
basenet/vgg16_bn.py
ADDED
@@ -0,0 +1,72 @@

from collections import namedtuple

import torch
import torch.nn as nn
import torch.nn.init as init
from torchvision import models

def init_weights(modules):
    for m in modules:
        if isinstance(m, nn.Conv2d):
            init.xavier_uniform_(m.weight.data)
            if m.bias is not None:
                m.bias.data.zero_()
        elif isinstance(m, nn.BatchNorm2d):
            m.weight.data.fill_(1)
            m.bias.data.zero_()
        elif isinstance(m, nn.Linear):
            m.weight.data.normal_(0, 0.01)
            m.bias.data.zero_()

class vgg16_bn(torch.nn.Module):
    def __init__(self, pretrained=True, freeze=True):
        super(vgg16_bn, self).__init__()

        vgg_pretrained_features = models.vgg16_bn(pretrained=pretrained).features
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()
        self.slice5 = torch.nn.Sequential()
        for x in range(12):         # conv2_2
            self.slice1.add_module(str(x), vgg_pretrained_features[x])
        for x in range(12, 19):     # conv3_3
            self.slice2.add_module(str(x), vgg_pretrained_features[x])
        for x in range(19, 29):     # conv4_3
            self.slice3.add_module(str(x), vgg_pretrained_features[x])
        for x in range(29, 39):     # conv5_3
            self.slice4.add_module(str(x), vgg_pretrained_features[x])

        # fc6, fc7 without atrous conv
        self.slice5 = torch.nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
            nn.Conv2d(512, 1024, kernel_size=3, padding=6, dilation=6),
            nn.Conv2d(1024, 1024, kernel_size=1)
        )

        if not pretrained:
            init_weights(self.slice1.modules())
            init_weights(self.slice2.modules())
            init_weights(self.slice3.modules())
            init_weights(self.slice4.modules())

        init_weights(self.slice5.modules())  # no pretrained model for fc6 and fc7

        if freeze:
            for param in self.slice1.parameters():  # only first conv
                param.requires_grad = False

    def forward(self, X):
        h = self.slice1(X)
        h_relu2_2 = h
        h = self.slice2(h)
        h_relu3_2 = h
        h = self.slice3(h)
        h_relu4_3 = h
        h = self.slice4(h)
        h_relu5_3 = h
        h = self.slice5(h)
        h_fc7 = h
        vgg_outputs = namedtuple("VggOutputs", ['fc7', 'relu5_3', 'relu4_3', 'relu3_2', 'relu2_2'])
        out = vgg_outputs(h_fc7, h_relu5_3, h_relu4_3, h_relu3_2, h_relu2_2)
        return out
craft.py
ADDED
@@ -0,0 +1,76 @@

import torch
import torch.nn as nn
import torch.nn.functional as F

from basenet.vgg16_bn import vgg16_bn, init_weights

class double_conv(nn.Module):
    def __init__(self, in_ch, mid_ch, out_ch):
        super(double_conv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch + mid_ch, mid_ch, kernel_size=1),
            nn.BatchNorm2d(mid_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        x = self.conv(x)
        return x


class CRAFT(nn.Module):
    def __init__(self, pretrained=False, freeze=False):
        super(CRAFT, self).__init__()

        """ Base network """
        self.basenet = vgg16_bn(pretrained, freeze)

        """ U network """
        self.upconv1 = double_conv(1024, 512, 256)
        self.upconv2 = double_conv(512, 256, 128)
        self.upconv3 = double_conv(256, 128, 64)
        self.upconv4 = double_conv(128, 64, 32)

        num_class = 2
        self.conv_cls = nn.Sequential(
            nn.Conv2d(32, 32, kernel_size=3, padding=1), nn.ReLU(inplace=True),
            nn.Conv2d(32, 32, kernel_size=3, padding=1), nn.ReLU(inplace=True),
            nn.Conv2d(32, 16, kernel_size=3, padding=1), nn.ReLU(inplace=True),
            nn.Conv2d(16, 16, kernel_size=1), nn.ReLU(inplace=True),
            nn.Conv2d(16, num_class, kernel_size=1),
        )

        init_weights(self.upconv1.modules())
        init_weights(self.upconv2.modules())
        init_weights(self.upconv3.modules())
        init_weights(self.upconv4.modules())
        init_weights(self.conv_cls.modules())

    def forward(self, x):
        """ Base network """
        sources = self.basenet(x)

        """ U network """
        y = torch.cat([sources[0], sources[1]], dim=1)
        y = self.upconv1(y)

        y = F.interpolate(y, size=sources[2].size()[2:], mode='bilinear', align_corners=False)
        y = torch.cat([y, sources[2]], dim=1)
        y = self.upconv2(y)

        y = F.interpolate(y, size=sources[3].size()[2:], mode='bilinear', align_corners=False)
        y = torch.cat([y, sources[3]], dim=1)
        y = self.upconv3(y)

        y = F.interpolate(y, size=sources[4].size()[2:], mode='bilinear', align_corners=False)
        y = torch.cat([y, sources[4]], dim=1)
        feature = self.upconv4(y)

        y = self.conv_cls(feature)

        return y.permute(0, 2, 3, 1), feature
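For reference, a minimal sketch of how the CRAFT module above can be exercised on a dummy input. The network is used untrained here (no checkpoint is loaded), so the values are meaningless; the point is only the output layout: the first returned tensor is permuted to NHWC and carries the region score in channel 0 and the affinity (link) score in channel 1, at half of the input resolution.

import torch
from craft import CRAFT

net = CRAFT(pretrained=False, freeze=False).eval()

dummy = torch.randn(1, 3, 768, 768)   # NCHW input, as prepared in main.py
with torch.no_grad():
    y, feature = net(dummy)

print(y.shape)        # torch.Size([1, 384, 384, 2]): region and affinity score maps
print(feature.shape)  # torch.Size([1, 32, 384, 384]): features later consumed by RefineNet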
craft_utils.py
ADDED
@@ -0,0 +1,217 @@

import numpy as np
import cv2
import math

def warpCoord(Minv, pt):
    out = np.matmul(Minv, (pt[0], pt[1], 1))
    return np.array([out[0]/out[2], out[1]/out[2]])


def getDetBoxes_core(textmap, linkmap, text_threshold, link_threshold, low_text):
    linkmap = linkmap.copy()
    textmap = textmap.copy()
    img_h, img_w = textmap.shape

    ret, text_score = cv2.threshold(textmap, low_text, 1, 0)
    ret, link_score = cv2.threshold(linkmap, link_threshold, 1, 0)

    text_score_comb = np.clip(text_score + link_score, 0, 1)
    nLabels, labels, stats, centroids = cv2.connectedComponentsWithStats(text_score_comb.astype(np.uint8), connectivity=4)

    det = []
    mapper = []
    for k in range(1, nLabels):
        size = stats[k, cv2.CC_STAT_AREA]
        if size < 10: continue

        if np.max(textmap[labels==k]) < text_threshold: continue

        segmap = np.zeros(textmap.shape, dtype=np.uint8)
        segmap[labels==k] = 255
        segmap[np.logical_and(link_score==1, text_score==0)] = 0
        x, y = stats[k, cv2.CC_STAT_LEFT], stats[k, cv2.CC_STAT_TOP]
        w, h = stats[k, cv2.CC_STAT_WIDTH], stats[k, cv2.CC_STAT_HEIGHT]
        niter = int(math.sqrt(size * min(w, h) / (w * h)) * 2)
        sx, ex, sy, ey = x - niter, x + w + niter + 1, y - niter, y + h + niter + 1
        if sx < 0 : sx = 0
        if sy < 0 : sy = 0
        if ex >= img_w: ex = img_w
        if ey >= img_h: ey = img_h
        kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (1 + niter, 1 + niter))
        segmap[sy:ey, sx:ex] = cv2.dilate(segmap[sy:ey, sx:ex], kernel)

        np_contours = np.roll(np.array(np.where(segmap!=0)), 1, axis=0).transpose().reshape(-1, 2)
        rectangle = cv2.minAreaRect(np_contours)
        box = cv2.boxPoints(rectangle)

        w, h = np.linalg.norm(box[0] - box[1]), np.linalg.norm(box[1] - box[2])
        box_ratio = max(w, h) / (min(w, h) + 1e-5)
        if abs(1 - box_ratio) <= 0.1:
            l, r = min(np_contours[:,0]), max(np_contours[:,0])
            t, b = min(np_contours[:,1]), max(np_contours[:,1])
            box = np.array([[l, t], [r, t], [r, b], [l, b]], dtype=np.float32)

        startidx = box.sum(axis=1).argmin()
        box = np.roll(box, 4-startidx, 0)
        box = np.array(box)

        det.append(box)
        mapper.append(k)

    return det, labels, mapper

def getPoly_core(boxes, labels, mapper, linkmap):
    num_cp = 5
    max_len_ratio = 0.7
    expand_ratio = 1.45
    max_r = 2.0
    step_r = 0.2

    polys = []
    for k, box in enumerate(boxes):
        w, h = int(np.linalg.norm(box[0] - box[1]) + 1), int(np.linalg.norm(box[1] - box[2]) + 1)
        if w < 10 or h < 10:
            polys.append(None); continue

        tar = np.float32([[0,0],[w,0],[w,h],[0,h]])
        M = cv2.getPerspectiveTransform(box, tar)
        word_label = cv2.warpPerspective(labels, M, (w, h), flags=cv2.INTER_NEAREST)
        try:
            Minv = np.linalg.inv(M)
        except:
            polys.append(None); continue

        cur_label = mapper[k]
        word_label[word_label != cur_label] = 0
        word_label[word_label > 0] = 1

        cp = []
        max_len = -1
        for i in range(w):
            region = np.where(word_label[:,i] != 0)[0]
            if len(region) < 2 : continue
            cp.append((i, region[0], region[-1]))
            length = region[-1] - region[0] + 1
            if length > max_len: max_len = length

        if h * max_len_ratio < max_len:
            polys.append(None); continue

        tot_seg = num_cp * 2 + 1
        seg_w = w / tot_seg
        pp = [None] * num_cp
        cp_section = [[0, 0]] * tot_seg
        seg_height = [0] * num_cp
        seg_num = 0
        num_sec = 0
        prev_h = -1
        for i in range(0, len(cp)):
            (x, sy, ey) = cp[i]
            if (seg_num + 1) * seg_w <= x and seg_num <= tot_seg:
                # average previous segment
                if num_sec == 0: break
                cp_section[seg_num] = [cp_section[seg_num][0] / num_sec, cp_section[seg_num][1] / num_sec]
                num_sec = 0

                # reset variables
                seg_num += 1
                prev_h = -1

            # accumulate center points
            cy = (sy + ey) * 0.5
            cur_h = ey - sy + 1
            cp_section[seg_num] = [cp_section[seg_num][0] + x, cp_section[seg_num][1] + cy]
            num_sec += 1

            if seg_num % 2 == 0: continue  # no polygon area

            if prev_h < cur_h:
                pp[int((seg_num - 1)/2)] = (x, cy)
                seg_height[int((seg_num - 1)/2)] = cur_h
                prev_h = cur_h

        # processing last segment
        if num_sec != 0:
            cp_section[-1] = [cp_section[-1][0] / num_sec, cp_section[-1][1] / num_sec]

        # pass if number of pivots is not sufficient or segment width is smaller than character height
        if None in pp or seg_w < np.max(seg_height) * 0.25:
            polys.append(None); continue

        # calc median maximum of pivot points
        half_char_h = np.median(seg_height) * expand_ratio / 2

        # calc gradient and apply to make horizontal pivots
        new_pp = []
        for i, (x, cy) in enumerate(pp):
            dx = cp_section[i * 2 + 2][0] - cp_section[i * 2][0]
            dy = cp_section[i * 2 + 2][1] - cp_section[i * 2][1]
            if dx == 0:  # gradient is zero
                new_pp.append([x, cy - half_char_h, x, cy + half_char_h])
                continue
            rad = - math.atan2(dy, dx)
            c, s = half_char_h * math.cos(rad), half_char_h * math.sin(rad)
            new_pp.append([x - s, cy - c, x + s, cy + c])

        # get edge points to cover character heatmaps
        isSppFound, isEppFound = False, False
        grad_s = (pp[1][1] - pp[0][1]) / (pp[1][0] - pp[0][0]) + (pp[2][1] - pp[1][1]) / (pp[2][0] - pp[1][0])
        grad_e = (pp[-2][1] - pp[-1][1]) / (pp[-2][0] - pp[-1][0]) + (pp[-3][1] - pp[-2][1]) / (pp[-3][0] - pp[-2][0])
        for r in np.arange(0.5, max_r, step_r):
            dx = 2 * half_char_h * r
            if not isSppFound:
                line_img = np.zeros(word_label.shape, dtype=np.uint8)
                dy = grad_s * dx
                p = np.array(new_pp[0]) - np.array([dx, dy, dx, dy])
                cv2.line(line_img, (int(p[0]), int(p[1])), (int(p[2]), int(p[3])), 1, thickness=1)
                if np.sum(np.logical_and(word_label, line_img)) == 0 or r + 2 * step_r >= max_r:
                    spp = p
                    isSppFound = True
            if not isEppFound:
                line_img = np.zeros(word_label.shape, dtype=np.uint8)
                dy = grad_e * dx
                p = np.array(new_pp[-1]) + np.array([dx, dy, dx, dy])
                cv2.line(line_img, (int(p[0]), int(p[1])), (int(p[2]), int(p[3])), 1, thickness=1)
                if np.sum(np.logical_and(word_label, line_img)) == 0 or r + 2 * step_r >= max_r:
                    epp = p
                    isEppFound = True
            if isSppFound and isEppFound:
                break

        if not (isSppFound and isEppFound):
            polys.append(None); continue

        poly = []
        poly.append(warpCoord(Minv, (spp[0], spp[1])))
        for p in new_pp:
            poly.append(warpCoord(Minv, (p[0], p[1])))
        poly.append(warpCoord(Minv, (epp[0], epp[1])))
        poly.append(warpCoord(Minv, (epp[2], epp[3])))
        for p in reversed(new_pp):
            poly.append(warpCoord(Minv, (p[2], p[3])))
        poly.append(warpCoord(Minv, (spp[2], spp[3])))

        # add to final result
        polys.append(np.array(poly))

    return polys

def getDetBoxes(textmap, linkmap, text_threshold, link_threshold, low_text, poly=False):
    boxes, labels, mapper = getDetBoxes_core(textmap, linkmap, text_threshold, link_threshold, low_text)

    if poly:
        polys = getPoly_core(boxes, labels, mapper, linkmap)
    else:
        polys = [None] * len(boxes)

    return boxes, polys

def adjustResultCoordinates(polys, ratio_w, ratio_h, ratio_net = 2):
    if len(polys) > 0:
        polys = np.array(polys)
        for k in range(len(polys)):
            if polys[k] is not None:
                polys[k] *= (ratio_w * ratio_net, ratio_h * ratio_net)
    return polys
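A small sketch of how these utilities are driven (main.py does the same with real score maps). The score maps are synthetic here, and the thresholds are the defaults from main.py; adjustResultCoordinates maps the half-resolution heatmap coordinates back to image space via ratio_net=2.

import numpy as np
import craft_utils

# Fake region/affinity maps with one bright "word" blob.
textmap = np.zeros((192, 192), dtype=np.float32)
linkmap = np.zeros((192, 192), dtype=np.float32)
textmap[50:70, 40:120] = 0.9
linkmap[50:70, 40:120] = 0.6

boxes, polys = craft_utils.getDetBoxes(textmap, linkmap,
                                       text_threshold=0.7, link_threshold=0.4,
                                       low_text=0.4, poly=False)
boxes = craft_utils.adjustResultCoordinates(boxes, ratio_w=1.0, ratio_h=1.0)

print(len(boxes))       # 1 detected word region
print(boxes[0].shape)   # (4, 2): the four corner points of the box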
dino2/__pycache__/model.cpython-310.pyc
ADDED
Binary file (2.85 kB)
dino2/model.py
ADDED
@@ -0,0 +1,93 @@

import torch
import torch.nn as nn
from torch.hub import load
import torchvision.models as models


dino_backbones = {
    'dinov2_s': {
        'name': 'dinov2_vits14',
        'embedding_size': 384,
        'patch_size': 14
    },
    'dinov2_b': {
        'name': 'dinov2_vitb14',
        'embedding_size': 768,
        'patch_size': 14
    },
    'dinov2_l': {
        'name': 'dinov2_vitl14',
        'embedding_size': 1024,
        'patch_size': 14
    },
    'dinov2_g': {
        'name': 'dinov2_vitg14',
        'embedding_size': 1536,
        'patch_size': 14
    },
}


class linear_head(nn.Module):
    def __init__(self, embedding_size=384, num_classes=5):
        super(linear_head, self).__init__()
        self.fc = nn.Linear(embedding_size, num_classes)

    def forward(self, x):
        return self.fc(x)


class conv_head(nn.Module):
    def __init__(self, embedding_size=384, num_classes=5):
        super(conv_head, self).__init__()
        self.segmentation_conv = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.Conv2d(embedding_size, 64, (3,3), padding=(1,1)),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(64, num_classes, (3,3), padding=(1,1)),
        )

    def forward(self, x):
        x = self.segmentation_conv(x)
        x = torch.sigmoid(x)
        return x


def threshold_mask(predicted, threshold=0.55):
    thresholded_mask = (predicted > threshold).float()
    return thresholded_mask


class Segmentor(nn.Module):
    def __init__(self, device, num_classes, backbone='dinov2_s', head='conv', backbones=dino_backbones):
        super(Segmentor, self).__init__()
        self.heads = {
            'conv': conv_head
        }
        self.backbones = dino_backbones
        self.backbone = load('facebookresearch/dinov2', self.backbones[backbone]['name'])

        self.backbone.eval()
        self.num_classes = num_classes
        self.embedding_size = self.backbones[backbone]['embedding_size']
        self.patch_size = self.backbones[backbone]['patch_size']
        self.head = self.heads[head](self.embedding_size, self.num_classes)
        self.device = device

    def forward(self, x):
        batch_size = x.shape[0]
        mask_dim = (x.shape[2] / self.patch_size, x.shape[3] / self.patch_size)
        x = self.backbone.forward_features(x.to(self.device))

        x = x['x_norm_patchtokens']
        x = x.permute(0, 2, 1)
        x = x.reshape(batch_size, self.embedding_size, int(mask_dim[0]), int(mask_dim[1]))
        x = self.head(x)
        return x
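A hedged usage sketch for the Segmentor above. Constructing it pulls the DINOv2 backbone through torch.hub, so network access is needed on first use; the input height and width must be multiples of the 14-pixel patch size (seg2.py uses 14*64 = 896).

import torch
from dino2.model import Segmentor, threshold_mask

device = torch.device("cpu")
model = Segmentor(device, num_classes=1, backbone='dinov2_b', head='conv')

x = torch.randn(1, 3, 14 * 64, 14 * 64)   # H and W divisible by the patch size
with torch.no_grad():
    probs = model(x)                       # sigmoid output of the conv head
mask = threshold_mask(probs, 0.55)

# The conv head upsamples the 64x64 patch grid by 4, so probs is (1, 1, 256, 256).
print(probs.shape, mask.shape)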
file_utils.py
ADDED
@@ -0,0 +1,77 @@

# -*- coding: utf-8 -*-
import os
import numpy as np
import cv2
import imgproc
from mosaik import mosaik

# borrowed from https://github.com/lengstrom/fast-style-transfer/blob/master/src/utils.py
def get_files(img_dir):
    imgs, masks, xmls = list_files(img_dir)
    return imgs, masks, xmls

def list_files(in_path):
    img_files = []
    mask_files = []
    gt_files = []
    for (dirpath, dirnames, filenames) in os.walk(in_path):
        for file in filenames:
            filename, ext = os.path.splitext(file)
            ext = str.lower(ext)
            if ext == '.jpg' or ext == '.jpeg' or ext == '.gif' or ext == '.png' or ext == '.pgm':
                img_files.append(os.path.join(dirpath, file))
            elif ext == '.bmp':
                mask_files.append(os.path.join(dirpath, file))
            elif ext == '.xml' or ext == '.gt' or ext == '.txt':
                gt_files.append(os.path.join(dirpath, file))
            elif ext == '.zip':
                continue
    # img_files.sort()
    # mask_files.sort()
    # gt_files.sort()
    return img_files, mask_files, gt_files

def saveResult(img_file, img, boxes, dirname='./result/', verticals=None, texts=None):
    """ save text detection result one by one
    Args:
        img_file (str): image file name
        img (array): raw image context
        boxes (array): array of result file
            Shape: [num_detections, 4] for BB output / [num_detections, 4] for QUAD output
    Return:
        None
    """
    img = np.array(img)

    # make result file list
    filename, file_ext = os.path.splitext(os.path.basename(img_file))

    # result directory
    res_file = dirname + "res_" + filename + '.txt'
    res_img_file = dirname + "res_" + filename + '.jpg'

    if not os.path.isdir(dirname):
        os.mkdir(dirname)

    with open(res_file, 'w') as f:
        for i, box in enumerate(boxes):
            poly = np.array(box).astype(np.int32).reshape((-1))
            strResult = ','.join([str(p) for p in poly]) + '\r\n'
            f.write(strResult)

            poly = poly.reshape(-1, 2)
            cv2.polylines(img, [poly.reshape((-1, 1, 2))], True, color=(0, 0, 255), thickness=2)
            ptColor = (0, 255, 255)
            if verticals is not None:
                if verticals[i]:
                    ptColor = (255, 0, 0)

            if texts is not None:
                font = cv2.FONT_HERSHEY_SIMPLEX
                font_scale = 0.5
                cv2.putText(img, "{}".format(texts[i]), (poly[0][0]+1, poly[0][1]+1), font, font_scale, (0, 0, 0), thickness=1)
                cv2.putText(img, "{}".format(texts[i]), tuple(poly[0]), font, font_scale, (0, 255, 255), thickness=1)

    # Save result image
    cv2.imwrite(res_img_file, img)
    return img
imgproc.py
ADDED
@@ -0,0 +1,70 @@

"""
Copyright (c) 2019-present NAVER Corp.
MIT License
"""

# -*- coding: utf-8 -*-
import numpy as np
from skimage import io
import cv2

def loadImage(img_file):
    img = io.imread(img_file)           # RGB order
    if img.shape[0] == 2: img = img[0]
    if len(img.shape) == 2: img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
    if img.shape[2] == 4: img = img[:,:,:3]
    img = np.array(img)

    return img

def normalizeMeanVariance(in_img, mean=(0.485, 0.456, 0.406), variance=(0.229, 0.224, 0.225)):
    # should be RGB order
    img = in_img.copy().astype(np.float32)

    img -= np.array([mean[0] * 255.0, mean[1] * 255.0, mean[2] * 255.0], dtype=np.float32)
    img /= np.array([variance[0] * 255.0, variance[1] * 255.0, variance[2] * 255.0], dtype=np.float32)
    return img

def denormalizeMeanVariance(in_img, mean=(0.485, 0.456, 0.406), variance=(0.229, 0.224, 0.225)):
    # should be RGB order
    img = in_img.copy()
    img *= variance
    img += mean
    img *= 255.0
    img = np.clip(img, 0, 255).astype(np.uint8)
    return img

def resize_aspect_ratio(img, square_size, interpolation, mag_ratio=1):
    height, width, channel = img.shape

    # magnify image size
    target_size = mag_ratio * max(height, width)

    # set original image size
    if target_size > square_size:
        target_size = square_size

    ratio = target_size / max(height, width)

    target_h, target_w = int(height * ratio), int(width * ratio)
    proc = cv2.resize(img, (target_w, target_h), interpolation=interpolation)

    # make canvas and paste image
    target_h32, target_w32 = target_h, target_w
    if target_h % 32 != 0:
        target_h32 = target_h + (32 - target_h % 32)
    if target_w % 32 != 0:
        target_w32 = target_w + (32 - target_w % 32)
    resized = np.zeros((target_h32, target_w32, channel), dtype=np.float32)
    resized[0:target_h, 0:target_w, :] = proc
    target_h, target_w = target_h32, target_w32

    size_heatmap = (int(target_w/2), int(target_h/2))

    return resized, ratio, size_heatmap

def cvt2HeatmapImg(img):
    img = (np.clip(img, 0, 1) * 255).astype(np.uint8)
    img = cv2.applyColorMap(img, cv2.COLORMAP_JET)
    return img
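These helpers are chained in main.py's test_net before the CRAFT forward pass; a condensed sketch of that preprocessing, using one of the sample images committed above:

import cv2
import torch
import imgproc

img = imgproc.loadImage("input/1.png")    # RGB uint8 array

# Limit the longer side to canvas_size=1280 and pad to a multiple of 32.
resized, ratio, size_heatmap = imgproc.resize_aspect_ratio(
    img, square_size=1280, interpolation=cv2.INTER_LINEAR, mag_ratio=1.5)

x = imgproc.normalizeMeanVariance(resized)              # float32, ImageNet statistics
x = torch.from_numpy(x).permute(2, 0, 1).unsqueeze(0)   # 1x3xHxW tensor for CRAFT

print(x.shape, ratio, size_heatmap)   # detections are later rescaled with 1/ratio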
input/1.png
ADDED
input/2.png
ADDED
input/3.png
ADDED
input/4.png
ADDED
install.sh
ADDED
@@ -0,0 +1 @@

pip install
main.py
ADDED
@@ -0,0 +1,354 @@

from recognize import recongize
from ner import ner
import os
import time
import argparse
from sr.sr import sr
import torch
from scipy.ndimage import gaussian_filter
from PIL import Image
import numpy as np
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.autograd import Variable
from mosaik import mosaik
import cv2
from skimage import io
import craft_utils
import imgproc
import file_utils
from seg import mask_percentage

from seg2 import dino_seg

from craft import CRAFT
from collections import OrderedDict
import gradio as gr
from refinenet import RefineNet


# Helpers for loading the CRAFT and refiner models
def copyStateDict(state_dict):
    if list(state_dict.keys())[0].startswith("module"):
        start_idx = 1
    else:
        start_idx = 0
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = ".".join(k.split(".")[start_idx:])
        new_state_dict[name] = v
    return new_state_dict

def str2bool(v):
    return v.lower() in ("yes", "y", "true", "t", "1")

parser = argparse.ArgumentParser(description='CRAFT Text Detection')
parser.add_argument('--trained_model', default='weights/craft_mlt_25k.pth', type=str, help='pretrained CRAFT model')
parser.add_argument('--text_threshold', default=0.7, type=float, help='text confidence threshold')
parser.add_argument('--low_text', default=0.4, type=float, help='text low-bound score')
parser.add_argument('--link_threshold', default=0.4, type=float, help='link confidence threshold')
parser.add_argument('--cuda', default=True, type=str2bool, help='Use cuda for inference')
parser.add_argument('--canvas_size', default=1280, type=int, help='image size for inference')
parser.add_argument('--mag_ratio', default=1.5, type=float, help='image magnification ratio')
parser.add_argument('--poly', default=False, action='store_true', help='enable polygon type')
parser.add_argument('--refine', default=True, help='enable link refiner')
parser.add_argument('--image_path', default="input/2.png", help='input image')
parser.add_argument('--refiner_model', default='weights/craft_refiner_CTW1500.pth', type=str, help='pretrained refiner model')

args = parser.parse_args()

# The function below is optional.
def full_img_masking(full_image, net, refine_net):
    reference_image = sr(full_image)
    reference_boxes = text_detect(reference_image, net=net, refine_net=refine_net)
    boxes = get_box_from_refer(reference_boxes)
    for index2, box in enumerate(boxes):
        xmin, xmax, ymin, ymax = get_min_max(box)

        text_area = full_image[int(ymin):int(ymax), int(xmin):int(xmax), :]

        text = recongize(text_area)
        label = ner(text)

        if label == 1:
            A = full_image[int(ymin):int(ymax), int(xmin):int(xmax), :]
            full_image[int(ymin):int(ymax), int(xmin):int(xmax), :] = gaussian_filter(A, sigma=16)
    return full_image

def test_net(net, image, text_threshold, link_threshold, low_text, cuda, poly, refine_net=None):
    t0 = time.time()

    img_resized, target_ratio, size_heatmap = imgproc.resize_aspect_ratio(image, args.canvas_size, interpolation=cv2.INTER_LINEAR, mag_ratio=args.mag_ratio)
    ratio_h = ratio_w = 1 / target_ratio

    x = imgproc.normalizeMeanVariance(img_resized)
    x = torch.from_numpy(x).permute(2, 0, 1)
    x = Variable(x.unsqueeze(0))
    if cuda:
        x = x.cuda()

    with torch.no_grad():
        y, feature = net(x)

    score_text = y[0,:,:,0].cpu().data.numpy()
    score_link = y[0,:,:,1].cpu().data.numpy()

    if refine_net is not None:
        with torch.no_grad():
            y_refiner = refine_net(y, feature)
        score_link = y_refiner[0,:,:,0].cpu().data.numpy()

    t0 = time.time() - t0
    t1 = time.time()

    boxes, polys = craft_utils.getDetBoxes(score_text, score_link, text_threshold, link_threshold, low_text, poly)

    boxes = craft_utils.adjustResultCoordinates(boxes, ratio_w, ratio_h)
    polys = craft_utils.adjustResultCoordinates(polys, ratio_w, ratio_h)
    for k in range(len(polys)):
        if polys[k] is None: polys[k] = boxes[k]

    t1 = time.time() - t1

    # render results (optional)
    render_img = score_text.copy()
    render_img = np.hstack((render_img, score_link))
    ret_score_text = imgproc.cvt2HeatmapImg(render_img)

    return boxes, polys, ret_score_text

def text_detect(image, net, refine_net):

    bboxes, polys, score_text = test_net(net, image, args.text_threshold, args.link_threshold, args.low_text, args.cuda, args.poly, refine_net)

    return bboxes


def get_box_from_refer(reference_boxes):

    real_boxes = []
    for box in reference_boxes:
        real_boxes.append(box // 2)

    return real_boxes

def get_min_max(box):
    xlist = []
    ylist = []
    for coor in box:
        xlist.append(coor[0])
        ylist.append(coor[1])
    return min(xlist), max(xlist), min(ylist), max(ylist)

def main(image_path0):
    # Step 1
    # ==> Load the CRAFT and RefineNet models and move them onto the CUDA device.

    net = CRAFT()
    if args.cuda:
        net.load_state_dict(copyStateDict(torch.load(args.trained_model)))

    if args.cuda:
        net = net.cuda()
        cudnn.benchmark = False

    net.eval()

    refine_net = None
    if args.refine:
        refine_net = RefineNet()
        if args.cuda:
            refine_net.load_state_dict(copyStateDict(torch.load(args.refiner_model)))
            refine_net = refine_net.cuda()

        refine_net.eval()
        args.poly = True

    # Step 2
    # The image dropped into the Gradio input field is passed in as A.

    A = image_path0
    image_list = []
    image_list.append(A)
    for k, image_path in enumerate(image_list):

        image = imgproc.loadImage(image_path)
        if image.shape[2] > 3:
            image = image[:, :, 0:3]

        original_image = image
        # Segment only the invoice region of the image with the DINOv2 model.
        output = dino_seg(image)
        image3 = Image.fromarray(output)
        image3.save("temporal_mask/mask.png")

        # Build a mask image (white pixels on a black background).
        # From that mask, compute what percentage of the whole image each blob (assumed to be an invoice) occupies.
        contours_list, percentage_list = mask_percentage("temporal_mask/mask.png")

        normal_image_list = []

        small_coordinate_list = []
        original_coordinate_list = []

        # Step 3

        sorted_list = sorted(percentage_list, reverse=True)
        top_5 = sorted_list[:5]
        print("Top 5 values:", top_5)
        # percentage_list holds the area percentages of the blobs assumed to be invoices,
        # and contours_list holds the corresponding cropped blobs in the same order,
        # e.g. the first element of percentage_list is the percentage of the first element of contours_list.

        for index, percentage in enumerate(percentage_list):

            if 5 < percentage:

                # Blobs that cover more than 5% of the image go into the normal list,
                # which collects blobs (assumed invoices) that are large enough to work with.
                # Blobs between 1% and 5% go into small_coordinate_list and are treated as very small blobs;
                # their text is barely visible even when zoomed in, so the whole blob is mosaicked.
                # Blobs under 1% are so small that they are skipped entirely.

                contour = contours_list[index]

                x_list = []
                y_list = []
                contour2 = list(contour)

                for r in contour2:
                    r2 = r[0]
                    x_list.append(r2[0])
                    y_list.append(r2[1])
                x_min = min(x_list)
                y_min = min(y_list)
                x_max = max(x_list)
                y_max = max(y_list)
                original_coordinate_list.append([y_min, y_max, x_min, x_max])
                image2 = original_image[y_min:y_max, x_min:x_max, :]
                normal_image_list.append(image2)

            elif 1 < percentage < 5:
                contour = contours_list[index]

                x_list = []
                y_list = []
                contour2 = list(contour)

                for r in contour2:
                    r2 = r[0]
                    x_list.append(r2[0])
                    y_list.append(r2[1])
                x_min = min(x_list)
                y_min = min(y_list)
                x_max = max(x_list)
                y_max = max(y_list)
                small_coordinate_list.append([y_min, y_max, x_min, x_max])  # coordinates of invoices under 5%
            else:
                continue

        # Step 4 (very small invoices)
        # small_coordinate_list collects the very small invoices; if it is empty, skip straight to step 5.
        # Otherwise (at least one element), mosaik() blurs every region of the full image that belongs to a small blob.

        if len(small_coordinate_list) > 0:
            original_image = mosaik(original_image, small_coordinate_list)
        else:
            pass

        # Step 5 (reasonably sized invoices) ==> normal list
        # normal_image_list holds invoices of adequate size (their text is readable when zoomed in).
        # Each cropped invoice is upscaled with ESRGAN so that CRAFT can return text locations reliably.
        # The upscaled invoice is fed to CRAFT to obtain all text box coordinates.
        # The coordinates are not used at the upscaled resolution; they are scaled back (//2) to the original invoice first.

        for index, normal_image in enumerate(normal_image_list):
            reference_image = sr(normal_image)
            reference_boxes = text_detect(reference_image, net=net, refine_net=refine_net)
            boxes = get_box_from_refer(reference_boxes)
            for index2, box in enumerate(boxes):
                xmin, xmax, ymin, ymax = get_min_max(box)

                text_area = normal_image[int(ymin):int(ymax), int(xmin):int(xmax), :]
                text_area = Image.fromarray(text_area)
                os.makedirs("text_area", exist_ok=True)
                text_area.save(f"text_area/new_{index2+1}.png")

                # Step 6 (text recognition, NER)
                # Crop the boxes inside the invoice using the coordinates above.
                # Each cropped region (assumed to contain text) is passed to TrOCR.
                # TrOCR returns the text it reads inside the box.
                # The text is passed to KoELECTRA to decide whether it is personal information.
                # If the box is judged to contain personal information (label 1), it is mosaicked:
                # the box coordinates are mapped back to the invoice image and that region is blurred.
                # The partially mosaicked invoice is then pasted back into the full image that contains it.

                text = recongize(text_area)
                label = ner(text)
                with open("output/text_recongnize.txt", "a") as recognized:
                    recognized.writelines(str(index2+1))
                    recognized.writelines(" ")
                    recognized.writelines(str(text))
                    recognized.writelines(" ")
                    recognized.writelines(str(label))
                    recognized.writelines("\n")
                    recognized.close()
                print("done")
                if label == 1:
                    A = normal_image[int(ymin):int(ymax), int(xmin):int(xmax), :]
                    normal_image[int(ymin):int(ymax), int(xmin):int(xmax), :] = gaussian_filter(A, sigma=16)
                else:
                    pass

            a, b, c, d = original_coordinate_list[index]
            original_image[a:b, c:d, :] = normal_image

        # To raise accuracy further, the whole image (invoice plus background) can be fed to CRAFT in one pass.
        # This is optional (drawback: inference speed).

        #print("full mask start")
        #original_image = full_img_masking(original_image, net=net, refine_net=refine_net)
        #print("full mask done")

        original_image = Image.fromarray(original_image)
        original_image.save("output/mosaiked.png")
        print("masked complete")
        return original_image


if __name__ == '__main__':

    iface = gr.Interface(
        fn=main,
        inputs=gr.Image(type="filepath", label="Invoice Image"),
        outputs=gr.Image(type="pil", label="Masked Invoice Image"),
        live=True
    )

    iface.launch()
mosaik.py
ADDED
@@ -0,0 +1,32 @@

import numpy as np
from PIL import Image
import cv2
from scipy.ndimage import gaussian_filter
import matplotlib.pyplot as plt
from matplotlib.patches import Polygon


# Blurs the whole cropped region for blobs that occupy between 1% and 5% of the image.
def mosaik(img, bboxes):
    for box in bboxes:
        # box = [y_min, y_max, x_min, x_max]
        cropped = img[box[0]:box[1], box[2]:box[3], :]

        cropped = np.array(cropped)
        cropped = gaussian_filter(cropped, sigma=16)
        img[box[0]:box[1], box[2]:box[3], :] = cropped

    return img
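A quick sketch of mosaik on a dummy image; the box coordinates are made up and follow the [y_min, y_max, x_min, x_max] convention used above.

import numpy as np
from mosaik import mosaik

img = np.random.randint(0, 255, size=(200, 200, 3), dtype=np.uint8)
blurred = mosaik(img, [[20, 80, 30, 120]])   # blur only that one region (sigma=16)

print(blurred.shape)   # (200, 200, 3); pixels outside the box are unchanged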
ner.py
ADDED
@@ -0,0 +1,106 @@

from transformers import AutoTokenizer, AutoModelForTokenClassification
from transformers import pipeline
from collections import defaultdict
import torch

device = torch.device("cuda")
tokenizer = AutoTokenizer.from_pretrained("Leo97/KoELECTRA-small-v3-modu-ner")
model = AutoModelForTokenClassification.from_pretrained("Leo97/KoELECTRA-small-v3-modu-ner")
model.to(device)

# This step runs after the region assumed to be an invoice has gone through CRAFT,
# the text regions have been cropped, and TrOCR has extracted the text from each crop.
# It classifies the extracted text.
# If the text belongs to the classes "person name (PS)", "road/building name (AF)" or "address (LC)",
# 1 is returned so that the region is mosaicked afterwards.
# The NER model splits the text into tokens and returns a class for every token.
# Considering every token makes inference very slow, so 1 is returned as soon as
# at least one token belongs to the PS, AF or LC classes.

def check_entity(entities):
    for entity_info in entities:
        entity_value = entity_info.get('entity', '').upper()
        if 'LC' in entity_value or 'PS' in entity_value or 'AF' in entity_value:
            return 1
    return 0

def ner(example):
    ner = pipeline("ner", model=model, tokenizer=tokenizer, device=device)
    ner_results = ner(example)
    ner_results = check_entity(ner_results)
    return ner_results


# Option 1 (kept commented out)
# def find_longest_value_key(input_dict):
#     max_length = 0
#     max_length_keys = []

#     for key, value in input_dict.items():
#         current_length = len(value)
#         if current_length > max_length:
#             max_length = current_length
#             max_length_keys = [key]
#         elif current_length == max_length:
#             max_length_keys.append(key)

#     if len(max_length_keys) == 1:
#         return 0
#     else:
#         return 1


# def find_longest_value_key2(input_dict):
#     if not input_dict:
#         return None

#     max_key = max(input_dict, key=lambda k: len(input_dict[k]))
#     return max_key


# def find_most_frequent_entity(entities):
#     entity_counts = defaultdict(list)

#     for item in entities:
#         split_entity = item['entity'].split('-')

#         entity_type = split_entity[1]
#         entity_counts[entity_type].append(item['score'])
#     number = find_longest_value_key(entity_counts)
#     if number == 1:
#         max_entities = []
#         max_score_average = -1

#         for entity, scores in entity_counts.items():
#             score_average = sum(scores) / len(scores)

#             if score_average > max_score_average:
#                 max_entities = [entity]
#                 max_score_average = score_average
#             elif score_average == max_score_average:
#                 max_entities.append(entity)
#         if len(max_entities) > 0:
#             return max_entities if len(max_entities) > 1 else max_entities[0]
#         else:
#             return "Do not mosaik"
#     else:
#         A = find_longest_value_key2(entity_counts)

#         return A


# If even a single token is tagged PS or LC, return PS/LC immediately

# label = filtering(ner_results)
# if label.find("PS") > -1 or label.find("LC") > -1:
#     return 1
# else:
#     return 0
#print(ner("홍길동"))


#label=check_label(example)
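check_entity only inspects the entity tags, so its decision rule can be sketched without running the model (note that merely importing ner downloads the KoELECTRA checkpoint and, as written, expects a CUDA device):

from ner import check_entity

# A hand-built, pipeline-style output: one token tagged as a person name (PS).
sample_entities = [
    {'entity': 'B-PS', 'score': 0.99, 'word': '홍길동'},
    {'entity': 'O', 'score': 0.98, 'word': '님'},
]
print(check_entity(sample_entities))   # 1 -> the corresponding crop gets mosaicked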
recognize.py
ADDED
@@ -0,0 +1,18 @@

from transformers import TrOCRProcessor, VisionEncoderDecoderModel, AutoTokenizer
import unicodedata

# The TrOCR weights were taken from Hugging Face, fine-tuned, and saved in the trocr_weight folder
# (the tokenizer and processor are saved there as well).
# The image this function receives is a crop produced by CRAFT inside the invoice,
# i.e. a region assumed to contain text; the likely text content is extracted from it.


def recongize(img):
    processor = TrOCRProcessor.from_pretrained("trocr_weight")
    model = VisionEncoderDecoderModel.from_pretrained("trocr_weight")
    tokenizer = AutoTokenizer.from_pretrained("trocr_weight")

    pixel_values = processor(img, return_tensors="pt").pixel_values
    generated_ids = model.generate(pixel_values, max_length=64)
    generated_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    generated_text = unicodedata.normalize("NFC", generated_text)
    return generated_text
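A usage sketch, assuming the fine-tuned weights really are available locally under trocr_weight and that a crop such as text_area/new_1.png (written by main.py) exists; both paths are illustrative.

from PIL import Image
from recognize import recongize

crop = Image.open("text_area/new_1.png")   # a text crop saved by the pipeline
print(recongize(crop))                     # recognized string, NFC-normalized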
refinenet.py
ADDED
@@ -0,0 +1,65 @@

"""
Copyright (c) 2019-present NAVER Corp.
MIT License
"""

# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from basenet.vgg16_bn import init_weights


class RefineNet(nn.Module):
    def __init__(self):
        super(RefineNet, self).__init__()

        self.last_conv = nn.Sequential(
            nn.Conv2d(34, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True)
        )

        self.aspp1 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, dilation=6, padding=6), nn.BatchNorm2d(128), nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=1), nn.BatchNorm2d(128), nn.ReLU(inplace=True),
            nn.Conv2d(128, 1, kernel_size=1)
        )

        self.aspp2 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, dilation=12, padding=12), nn.BatchNorm2d(128), nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=1), nn.BatchNorm2d(128), nn.ReLU(inplace=True),
            nn.Conv2d(128, 1, kernel_size=1)
        )

        self.aspp3 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, dilation=18, padding=18), nn.BatchNorm2d(128), nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=1), nn.BatchNorm2d(128), nn.ReLU(inplace=True),
            nn.Conv2d(128, 1, kernel_size=1)
        )

        self.aspp4 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, dilation=24, padding=24), nn.BatchNorm2d(128), nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=1), nn.BatchNorm2d(128), nn.ReLU(inplace=True),
            nn.Conv2d(128, 1, kernel_size=1)
        )

        init_weights(self.last_conv.modules())
        init_weights(self.aspp1.modules())
        init_weights(self.aspp2.modules())
        init_weights(self.aspp3.modules())
        init_weights(self.aspp4.modules())

    def forward(self, y, upconv4):
        refine = torch.cat([y.permute(0,3,1,2), upconv4], dim=1)
        refine = self.last_conv(refine)

        aspp1 = self.aspp1(refine)
        aspp2 = self.aspp2(refine)
        aspp3 = self.aspp3(refine)
        aspp4 = self.aspp4(refine)

        #out = torch.add([aspp1, aspp2, aspp3, aspp4], dim=1)
        out = aspp1 + aspp2 + aspp3 + aspp4
        return out.permute(0, 2, 3, 1)  # , refine.permute(0,2,3,1)
requirements.txt
ADDED
@@ -0,0 +1,18 @@

numpy
torch
pillow
matplotlib
transformers
scipy
torchvision
opencv-python
scikit-image
reset.sh
ADDED
@@ -0,0 +1,15 @@

rm -rf output
mkdir output

rm -rf flagged

rm -rf temporal_mask
mkdir temporal_mask

rm -rf text_area
mkdir text_area

# Run after inference (cleans up the result folders)
seg.py
ADDED
@@ -0,0 +1,46 @@

from unet.predict import predict_img, mask_to_image
import torch
from PIL import Image
import numpy as np
import cv2

# The function below uses a U-Net; it is the model that was tried before DINOv2.
def segmentation(img):
    device = torch.device("cuda")
    net = torch.load("weights/unet.pth")

    mask_values = [[0, 0, 0], [255, 255, 255]]

    mask = predict_img(net, img, device, scale_factor=1, out_threshold=0.5)
    result = mask_to_image(mask, mask_values)
    result = np.array(result)

    return result


# Computes what percentage of the image each region cropped by the segmentation above occupies.
# The function below is used for both DINOv2 and the U-Net.
# It computes how much of the whole image each blob of connected white pixels covers,
# and it also works when there are two or more blobs (assumed to be invoices).
def mask_percentage(mask_path):

    image = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)

    ret, threshold = cv2.threshold(image, 127, 255, cv2.THRESH_BINARY)

    contours, hierarchy = cv2.findContours(threshold, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    total_area = image.shape[0] * image.shape[1]
    contours_list = contours

    contour_areas = [cv2.contourArea(contour) for contour in contours]

    percentages = [(area / total_area) * 100 for area in contour_areas]
    percentage_list = []
    for i, percentage in enumerate(percentages):
        percentage_list.append(percentage)
    return contours_list, percentage_list
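mask_percentage reads the mask from disk, so a synthetic example needs a temporary file; the file name here is arbitrary.

import cv2
import numpy as np
from seg import mask_percentage

mask = np.zeros((100, 100), dtype=np.uint8)
mask[10:30, 20:70] = 255                  # one white 20x50 blob
cv2.imwrite("demo_mask.png", mask)

contours, percentages = mask_percentage("demo_mask.png")
print(len(contours))     # 1 blob found
print(percentages[0])    # about 9.3 (cv2.contourArea of the blob vs. the 100x100 image)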
seg2.py
ADDED
@@ -0,0 +1,68 @@

import torch
from dino2.model import Segmentor
from PIL import Image
from torchvision import transforms
import numpy as np
import cv2


T2 = transforms.ToPILImage()


weights = "/home/ksyint/other1213/craft_ku/weights/dinov2.pt"
device = torch.device("cpu")
model = Segmentor(device, 1, backbone='dinov2_b', head="conv")
model.load_state_dict(torch.load(weights, map_location="cpu"))
model = model.to(device)


img_transform = transforms.Compose([
    transforms.Resize((14*64, 14*64)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# Code for cropping out the invoice region.
# DINOv2 is used as the segmentation model; it was trained, the resulting weights were saved
# in the weights folder, and they are loaded as shown above.

def dino_seg(numpy_array):
    img0 = Image.fromarray(numpy_array)
    original_size = img0.size
    img = img_transform(img0)
    a = img.unsqueeze(0)
    b = model(a)
    b = b.squeeze(0)
    b = b * 255.0
    model_output = T2(b)  # PIL image
    model_output = model_output.resize(original_size)

    model_output = np.array(model_output)
    model_output[model_output > 220] = 255.0
    model_output[model_output <= 220] = 0.0
    model_output2 = model_output
    model_output3 = model_output
    output = np.stack([model_output, model_output2, model_output3])
    output = np.transpose(output, (1, 2, 0))

    return output


# def find_connected_components(image):

#     _, binary_image = cv2.threshold(image, 150, 255, cv2.THRESH_BINARY)

#     _, labels, stats, centroids = cv2.connectedComponentsWithStats(binary_image, connectivity=4)

#     largest_component_index = np.argmax(stats[1:, cv2.CC_STAT_AREA]) + 1

#     x, y, w, h, _ = stats[largest_component_index, cv2.CC_STAT_LEFT:cv2.CC_STAT_TOP+cv2.CC_STAT_HEIGHT+1]

#     return x, y, w, h
sr/__pycache__/sr.cpython-310.pyc
ADDED
Binary file (606 Bytes)
sr/esrgan
ADDED
@@ -0,0 +1 @@

Subproject commit 362a0316878f41dbdfbb23657b450c3353de5acf
sr/sr.py
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch
from PIL import Image
import numpy as np
from .esrgan.RealESRGAN import RealESRGAN


def sr(img):
    # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    device = torch.device("cuda")
    model = RealESRGAN(device, scale=2)
    model.load_weights('weights/RealESRGAN_x2.pth', download=True)

    img = Image.fromarray(img)
    sr_image = model.predict(img)
    sr_image = np.array(sr_image)
    return sr_image
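A short usage sketch for sr() (assumptions: a CUDA device is available, since the function hard-codes torch.device("cuda"); the sr/ directory is importable as a package; the input path is illustrative).

import cv2
from sr.sr import sr   # assumed import path

crop = cv2.imread("input/1.png")   # BGR numpy array
upscaled = sr(crop)                # 2x larger numpy array from RealESRGAN
cv2.imwrite("crop_x2.png", upscaled)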
unet/__pycache__/predict.cpython-310.pyc
ADDED
Binary file (1.89 kB). View file
unet/dino/__init__.py
ADDED
@@ -0,0 +1 @@
from .model import Dinov2
unet/dino/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (202 Bytes). View file
unet/dino/__pycache__/model.cpython-310.pyc
ADDED
Binary file (1.63 kB). View file
unet/dino/__pycache__/modules.cpython-310.pyc
ADDED
Binary file (2.58 kB). View file
unet/dino/__pycache__/parts.cpython-310.pyc
ADDED
Binary file (2.58 kB). View file
unet/dino/model.py
ADDED
@@ -0,0 +1,47 @@
from .parts import *


class Dinov2(nn.Module):
    def __init__(self, n_channels, n_classes, bilinear=False):
        super(Dinov2, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        factor = 2 if bilinear else 1
        self.down4 = Down(512, 1024 // factor)
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)
        return logits

    def use_checkpointing(self):
        self.inc = torch.utils.checkpoint(self.inc)
        self.down1 = torch.utils.checkpoint(self.down1)
        self.down2 = torch.utils.checkpoint(self.down2)
        self.down3 = torch.utils.checkpoint(self.down3)
        self.down4 = torch.utils.checkpoint(self.down4)
        self.up1 = torch.utils.checkpoint(self.up1)
        self.up2 = torch.utils.checkpoint(self.up2)
        self.up3 = torch.utils.checkpoint(self.up3)
        self.up4 = torch.utils.checkpoint(self.up4)
        self.outc = torch.utils.checkpoint(self.outc)
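Despite the class name, Dinov2 in this file is a plain U-Net encoder/decoder. A quick shape check (a sketch; the import path assumes the unet package is importable from the project root):

import torch
from unet.dino import Dinov2   # assumed import path

net = Dinov2(n_channels=3, n_classes=1, bilinear=False)
net.eval()
with torch.no_grad():
    logits = net(torch.randn(1, 3, 256, 256))
print(logits.shape)            # torch.Size([1, 1, 256, 256])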
unet/dino/parts.py
ADDED
@@ -0,0 +1,67 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class DoubleConv(nn.Module):
    """(Conv2d => BatchNorm => ReLU) * 2"""

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)


class Down(nn.Module):
    """Downscaling with maxpool then double conv"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels)
        )

    def forward(self, x):
        return self.maxpool_conv(x)


class Up(nn.Module):
    """Upscaling then double conv, padded to match the skip connection"""

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()

        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
        else:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # Pad x1 so its spatial size matches the skip connection x2 before concatenation.
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]

        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)


class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)
unet/predict.py
ADDED
@@ -0,0 +1,76 @@
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image


def preprocess(mask_values, pil_img, scale, is_mask):
    # Accepts a numpy array and converts it to a PIL image before resizing.
    pil_img = Image.fromarray(pil_img)
    w, h = pil_img.size
    newW, newH = int(scale * w), int(scale * h)
    pil_img = pil_img.resize((newW, newH))
    img = np.asarray(pil_img)

    if is_mask:
        # Map each mask value to its class index.
        mask = np.zeros((newH, newW), dtype=np.int64)
        for i, v in enumerate(mask_values):
            if img.ndim == 2:
                mask[img == v] = i
            else:
                mask[(img == v).all(-1)] = i

        return mask

    else:
        if img.ndim == 2:
            img = img[np.newaxis, ...]
        else:
            img = img.transpose((2, 0, 1))

        if (img > 1).any():
            img = img / 255.0

        return img


def predict_img(net,
                full_img,
                device,
                scale_factor=1,
                out_threshold=0.5):
    net.eval()
    img = torch.from_numpy(preprocess(None, full_img, scale_factor, is_mask=False))
    img = img.unsqueeze(0)
    img = img.to(device=device, dtype=torch.float32)

    with torch.no_grad():
        output = net(img).cpu()

    # Multi-class: argmax over channels; binary: threshold the sigmoid output.
    if net.n_classes > 1:
        mask = output.argmax(dim=1)
    else:
        mask = torch.sigmoid(output) > out_threshold

    return mask[0].long().squeeze().numpy()


def mask_to_image(mask: np.ndarray, mask_values):
    if isinstance(mask_values[0], list):
        out = np.zeros((mask.shape[-2], mask.shape[-1], len(mask_values[0])), dtype=np.uint8)
    elif mask_values == [0, 1]:
        out = np.zeros((mask.shape[-2], mask.shape[-1]), dtype=bool)
    else:
        out = np.zeros((mask.shape[-2], mask.shape[-1]), dtype=np.uint8)

    if mask.ndim == 3:
        mask = np.argmax(mask, axis=0)

    # Paint each class index with its corresponding output value.
    for i, v in enumerate(mask_values):
        out[mask == i] = v

    return Image.fromarray(out)
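An end-to-end sketch tying predict_img and mask_to_image to the U-Net defined above (the checkpoint path, class count, mask_values, and import paths are illustrative assumptions, not part of this commit):

import cv2
import torch
from unet.dino import Dinov2                        # assumed import paths
from unet.predict import predict_img, mask_to_image

device = torch.device("cpu")
net = Dinov2(n_channels=3, n_classes=2).to(device)
net.load_state_dict(torch.load("weights/unet.pth", map_location="cpu"))  # hypothetical checkpoint

img = cv2.cvtColor(cv2.imread("input/1.png"), cv2.COLOR_BGR2RGB)
mask = predict_img(net, img, device, scale_factor=1.0)
mask_to_image(mask, mask_values=[0, 255]).save("mask.png")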