import os
import sys

# Make the parent directory importable so the `DocScanner` package resolves
# when this file is run directly. Must run before the DocScanner imports below.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import argparse  # noqa: E402
import warnings  # noqa: E402

import cv2  # noqa: E402
import numpy as np  # noqa: E402
import torch  # noqa: E402
import torch.nn as nn  # noqa: E402
import torch.nn.functional as F  # noqa: E402
from PIL import Image  # noqa: E402

from DocScanner.model import DocScanner  # noqa: E402
from DocScanner.seg import U2NETP  # noqa: E402

warnings.filterwarnings("ignore")

# Number of leading characters stripped from every segmentation-checkpoint key
# before it is matched against the model's own state-dict keys.
# NOTE(review): presumably a wrapper prefix such as "model." — confirm against
# the checkpoint actually being loaded.
_SEG_KEY_PREFIX_LEN = 6

# Rescaling applied to the DocScanner output: 2 * (bm / span) - 1 maps the
# interval [0, span] onto [-1, 1]; the shrink factor then pulls values slightly
# toward 0. NOTE(review): span = 286.8 presumably relates to the network's
# working resolution — confirm.
_BM_PIXEL_SPAN = 286.8
_BM_SHRINK = 0.99


class Net(nn.Module):
    """Segment the document, mask the input, then predict a backward map.

    ``self.msk`` (U2NETP) produces a mask that is binarized and multiplied into
    the input; the masked input is fed to ``self.bm`` (DocScanner), whose output
    is rescaled before being returned.
    """

    def __init__(self):
        super().__init__()
        self.msk = U2NETP(3, 1)  # segmentation network
        self.bm = DocScanner()  # rectification network

    def forward(self, x):
        """Run segmentation + rectification on ``x``.

        Args:
            x: Input image tensor. NOTE(review): presumably (B, 3, H, W) given
                ``U2NETP(3, 1)`` — confirm against callers.

        Returns:
            A tuple ``(bm, msk)``:
            bm: DocScanner output (``iters=12, test_mode=True``) rescaled by
                ``(2 * (bm / 286.8) - 1) * 0.99``.
            msk: The binarized mask (1.0 where the segmentation output > 0.5,
                else 0.0) that was multiplied into ``x``.
        """
        # U2NETP returns seven outputs; only the first is used. The explicit
        # seven-way unpacking also asserts that count.
        msk, _, _, _, _, _, _ = self.msk(x)
        msk = (msk > 0.5).float()
        # Zero out everything the mask marks as background.
        x = msk * x
        bm = self.bm(x, iters=12, test_mode=True)
        bm = (2 * (bm / _BM_PIXEL_SPAN) - 1) * _BM_SHRINK
        return bm, msk


def _load_matching_weights(device, model, path, prefix_len=0):
    """Overlay checkpoint tensors from *path* onto *model*'s current weights.

    Each checkpoint key has its first ``prefix_len`` characters removed; only
    entries whose resulting key exists in ``model.state_dict()`` are kept, and
    all other model parameters keep their current values.

    Args:
        device: Passed to ``torch.load`` as ``map_location``.
        model: Module whose weights are updated in place.
        path: Checkpoint file path.
        prefix_len: Number of leading characters to strip from each key
            (0 keeps keys unchanged).

    Returns:
        The same ``model`` object, after ``load_state_dict``.
    """
    model_dict = model.state_dict()
    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code — only load trusted checkpoints.
    checkpoint = torch.load(path, map_location=device)
    matched = {}
    for key, tensor in checkpoint.items():
        new_key = key[prefix_len:]
        if new_key in model_dict:
            matched[new_key] = tensor
    model_dict.update(matched)
    model.load_state_dict(model_dict)
    return model


def reload_seg_model(cuda, model, path=""):
    """Load segmentation weights from *path* into *model*, if a path is given.

    Checkpoint keys are stripped of their first ``_SEG_KEY_PREFIX_LEN`` (6)
    characters before being matched against the model's keys; unmatched keys
    are silently skipped.

    Args:
        cuda: Passed to ``torch.load`` as ``map_location``.
        model: Module to update in place.
        path: Checkpoint path; an empty/falsy value returns *model* untouched.

    Returns:
        *model*.
    """
    if not path:
        return model
    return _load_matching_weights(cuda, model, path, prefix_len=_SEG_KEY_PREFIX_LEN)


def reload_rec_model(cuda, model, path=""):
    """Load rectification weights from *path* into *model*, if a path is given.

    Checkpoint keys are matched against the model's keys as-is; unmatched keys
    are silently skipped.

    Args:
        cuda: Passed to ``torch.load`` as ``map_location``.
        model: Module to update in place.
        path: Checkpoint path; an empty/falsy value returns *model* untouched.

    Returns:
        *model*.
    """
    if not path:
        return model
    return _load_matching_weights(cuda, model, path, prefix_len=0)