HERIUN commited on
Commit
1081f7c
1 Parent(s): ba5f22e

add models

Browse files
Files changed (2) hide show
  1. models/DocScanner/inference.py +4 -4
  2. rect_main.py +2 -2
models/DocScanner/inference.py CHANGED
@@ -37,12 +37,12 @@ class Net(nn.Module):
37
  return bm, msk
38
 
39
 
40
- def reload_seg_model(model, path=""):
41
  if not bool(path):
42
  return model
43
  else:
44
  model_dict = model.state_dict()
45
- pretrained_dict = torch.load(path, map_location="cuda:0")
46
  pretrained_dict = {
47
  k[6:]: v for k, v in pretrained_dict.items() if k[6:] in model_dict
48
  }
@@ -52,12 +52,12 @@ def reload_seg_model(model, path=""):
52
  return model
53
 
54
 
55
- def reload_rec_model(model, path=""):
56
  if not bool(path):
57
  return model
58
  else:
59
  model_dict = model.state_dict()
60
- pretrained_dict = torch.load(path, map_location="cuda:0")
61
  pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
62
  model_dict.update(pretrained_dict)
63
  model.load_state_dict(model_dict)
 
37
  return bm, msk
38
 
39
 
40
+ def reload_seg_model(cuda, model, path=""):
41
  if not bool(path):
42
  return model
43
  else:
44
  model_dict = model.state_dict()
45
+ pretrained_dict = torch.load(path, map_location=cuda)
46
  pretrained_dict = {
47
  k[6:]: v for k, v in pretrained_dict.items() if k[6:] in model_dict
48
  }
 
52
  return model
53
 
54
 
55
+ def reload_rec_model(cuda, model, path=""):
56
  if not bool(path):
57
  return model
58
  else:
59
  model_dict = model.state_dict()
60
+ pretrained_dict = torch.load(path, map_location=cuda)
61
  pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
62
  model_dict.update(pretrained_dict)
63
  model.load_state_dict(model_dict)
rect_main.py CHANGED
@@ -33,8 +33,8 @@ def load_geotrp_model(cuda, path=""):
33
  def load_docscanner_model(cuda, path_l="", path_m=""):
34
 
35
  net = DocScanner.Net().to(cuda)
36
- DocScanner.reload_seg_model(net.msk, path_m)
37
- DocScanner.reload_rec_model(net.bm, path_l)
38
  net.eval()
39
 
40
  return net
 
33
  def load_docscanner_model(cuda, path_l="", path_m=""):
34
 
35
  net = DocScanner.Net().to(cuda)
36
+ DocScanner.reload_seg_model(cuda, net.msk, path_m)
37
+ DocScanner.reload_rec_model(cuda, net.bm, path_l)
38
  net.eval()
39
 
40
  return net