Spaces:
Running
Running
HERIUN
commited on
Commit
•
1081f7c
1
Parent(s):
ba5f22e
add models
Browse files- models/DocScanner/inference.py +4 -4
- rect_main.py +2 -2
models/DocScanner/inference.py
CHANGED
@@ -37,12 +37,12 @@ class Net(nn.Module):
|
|
37 |
return bm, msk
|
38 |
|
39 |
|
40 |
-
def reload_seg_model(model, path=""):
|
41 |
if not bool(path):
|
42 |
return model
|
43 |
else:
|
44 |
model_dict = model.state_dict()
|
45 |
-
pretrained_dict = torch.load(path, map_location=
|
46 |
pretrained_dict = {
|
47 |
k[6:]: v for k, v in pretrained_dict.items() if k[6:] in model_dict
|
48 |
}
|
@@ -52,12 +52,12 @@ def reload_seg_model(model, path=""):
|
|
52 |
return model
|
53 |
|
54 |
|
55 |
-
def reload_rec_model(model, path=""):
|
56 |
if not bool(path):
|
57 |
return model
|
58 |
else:
|
59 |
model_dict = model.state_dict()
|
60 |
-
pretrained_dict = torch.load(path, map_location=
|
61 |
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
|
62 |
model_dict.update(pretrained_dict)
|
63 |
model.load_state_dict(model_dict)
|
|
|
37 |
return bm, msk
|
38 |
|
39 |
|
40 |
+
def reload_seg_model(cuda, model, path=""):
|
41 |
if not bool(path):
|
42 |
return model
|
43 |
else:
|
44 |
model_dict = model.state_dict()
|
45 |
+
pretrained_dict = torch.load(path, map_location=cuda)
|
46 |
pretrained_dict = {
|
47 |
k[6:]: v for k, v in pretrained_dict.items() if k[6:] in model_dict
|
48 |
}
|
|
|
52 |
return model
|
53 |
|
54 |
|
55 |
+
def reload_rec_model(cuda, model, path=""):
|
56 |
if not bool(path):
|
57 |
return model
|
58 |
else:
|
59 |
model_dict = model.state_dict()
|
60 |
+
pretrained_dict = torch.load(path, map_location=cuda)
|
61 |
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
|
62 |
model_dict.update(pretrained_dict)
|
63 |
model.load_state_dict(model_dict)
|
rect_main.py
CHANGED
@@ -33,8 +33,8 @@ def load_geotrp_model(cuda, path=""):
|
|
33 |
def load_docscanner_model(cuda, path_l="", path_m=""):
|
34 |
|
35 |
net = DocScanner.Net().to(cuda)
|
36 |
-
DocScanner.reload_seg_model(net.msk, path_m)
|
37 |
-
DocScanner.reload_rec_model(net.bm, path_l)
|
38 |
net.eval()
|
39 |
|
40 |
return net
|
|
|
33 |
def load_docscanner_model(cuda, path_l="", path_m=""):
|
34 |
|
35 |
net = DocScanner.Net().to(cuda)
|
36 |
+
DocScanner.reload_seg_model(cuda, net.msk, path_m)
|
37 |
+
DocScanner.reload_rec_model(cuda, net.bm, path_l)
|
38 |
net.eval()
|
39 |
|
40 |
return net
|