HERIUN commited on
Commit
2bb6556
โ€ข
1 Parent(s): 1081f7c
Files changed (3) hide show
  1. README.md +1 -1
  2. app.py +2 -2
  3. rect_main.py +9 -9
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: ๐Ÿ“Š
4
  colorFrom: yellow
5
  colorTo: purple
6
  sdk: gradio
7
- sdk_version: 5.0.2
8
  app_file: app.py
9
  pinned: false
10
  license: apache-2.0
 
4
  colorFrom: yellow
5
  colorTo: purple
6
  sdk: gradio
7
+ sdk_version: 4.33
8
  app_file: app.py
9
  pinned: false
10
  license: apache-2.0
app.py CHANGED
@@ -59,7 +59,7 @@ def reset_image(image, state):
59
  return img, state
60
 
61
  def auto_point_detect(image):
62
- out_image, msk_np = docscanner_rec(image, docscanner)
63
  state = list(get_corner(mask2point(mask=msk_np)))
64
 
65
  img = Image.fromarray(image)
@@ -126,7 +126,7 @@ def sort_corners(corners):
126
  def convert(image, state):
127
  h,w = image.shape[:2]
128
  if len(state) < 4:
129
- out_image, msk_np = docscanner_rec(image, docscanner)
130
  out_image = out_image[:,:,::-1]
131
  elif len(state) ==4:
132
  state = list(sort_corners(state))
 
59
  return img, state
60
 
61
  def auto_point_detect(image):
62
+ out_image, msk_np = docscanner_rec(image, docscanner, cuda)
63
  state = list(get_corner(mask2point(mask=msk_np)))
64
 
65
  img = Image.fromarray(image)
 
126
  def convert(image, state):
127
  h,w = image.shape[:2]
128
  if len(state) < 4:
129
+ out_image, msk_np = docscanner_rec(image, docscanner, cuda)
130
  out_image = out_image[:,:,::-1]
131
  elif len(state) ==4:
132
  state = list(sort_corners(state))
rect_main.py CHANGED
@@ -52,11 +52,11 @@ def preprocess_image(img, target_size=[288, 288]):
52
  return im_ori, im, h_, w_
53
 
54
 
55
- def geotrp_rec(img, model):
56
  im_ori, im, h_, w_ = preprocess_image(img)
57
 
58
  with torch.no_grad():
59
- bm = model(im.cuda())
60
  bm = bm.cpu().numpy()[0]
61
  bm0 = bm[0, :, :]
62
  bm1 = bm[1, :, :]
@@ -69,11 +69,11 @@ def geotrp_rec(img, model):
69
  return img_geo
70
 
71
 
72
- def docscanner_get_mask(img, model):
73
  _, im, h, w = preprocess_image(img)
74
 
75
  with torch.no_grad():
76
- _, msk = model(im.cuda())
77
  msk = msk.cpu()
78
 
79
  mask_np = (msk[0, 0].numpy() * 255).astype(np.uint8)
@@ -82,11 +82,11 @@ def docscanner_get_mask(img, model):
82
  return mask_resized
83
 
84
 
85
- def docscanner_rec_img(img, model):
86
  im_ori, im, h, w = preprocess_image(img)
87
 
88
  with torch.no_grad():
89
- bm = model(im.cuda())
90
  bm = bm.cpu()
91
 
92
  # save rectified image
@@ -106,7 +106,7 @@ def docscanner_rec_img(img, model):
106
 
107
 
108
 
109
- def docscanner_rec(img, model):
110
  im_ori = img[:, :, :3] / 255.0
111
  h, w, _ = im_ori.shape
112
  im = cv2.resize(im_ori, (288, 288))
@@ -114,7 +114,7 @@ def docscanner_rec(img, model):
114
  im = torch.from_numpy(im).float().unsqueeze(0)
115
 
116
  with torch.no_grad():
117
- bm, msk = model(im.cuda())
118
  bm = bm.cpu()
119
  msk = msk.cpu()
120
 
@@ -165,7 +165,7 @@ def main():
165
  )
166
  doctr = load_geotrp_model(cuda, path=config.get_geotr_model_path)
167
 
168
- mask = docscanner_get_mask(img, docscanner)
169
  mask_dict.add(get_mask_white_area(mask))
170
 
171
 
 
52
  return im_ori, im, h_, w_
53
 
54
 
55
+ def geotrp_rec(img, model, cuda):
56
  im_ori, im, h_, w_ = preprocess_image(img)
57
 
58
  with torch.no_grad():
59
+ bm = model(im.to(cuda))
60
  bm = bm.cpu().numpy()[0]
61
  bm0 = bm[0, :, :]
62
  bm1 = bm[1, :, :]
 
69
  return img_geo
70
 
71
 
72
+ def docscanner_get_mask(img, model, cuda):
73
  _, im, h, w = preprocess_image(img)
74
 
75
  with torch.no_grad():
76
+ _, msk = model(im.to(cuda))
77
  msk = msk.cpu()
78
 
79
  mask_np = (msk[0, 0].numpy() * 255).astype(np.uint8)
 
82
  return mask_resized
83
 
84
 
85
+ def docscanner_rec_img(img, model, cuda):
86
  im_ori, im, h, w = preprocess_image(img)
87
 
88
  with torch.no_grad():
89
+ bm = model(im.to(cuda))
90
  bm = bm.cpu()
91
 
92
  # save rectified image
 
106
 
107
 
108
 
109
+ def docscanner_rec(img, model, cuda):
110
  im_ori = img[:, :, :3] / 255.0
111
  h, w, _ = im_ori.shape
112
  im = cv2.resize(im_ori, (288, 288))
 
114
  im = torch.from_numpy(im).float().unsqueeze(0)
115
 
116
  with torch.no_grad():
117
+ bm, msk = model(im.to(cuda))
118
  bm = bm.cpu()
119
  msk = msk.cpu()
120
 
 
165
  )
166
  doctr = load_geotrp_model(cuda, path=config.get_geotr_model_path)
167
 
168
+ mask = docscanner_get_mask(img, docscanner, cuda)
169
  mask_dict.add(get_mask_white_area(mask))
170
 
171