Spaces:
Running
Running
import importlib | |
import warnings | |
from collections import defaultdict | |
import cv2 | |
import numpy as np | |
import torch | |
import torch.nn.functional as F | |
from config import Config | |
from data_utils.image_utils import _to_2d | |
# Suppress every warning category for the whole process.
warnings.filterwarnings("ignore")

# "DocTr-Plus" contains a hyphen, so the package cannot be named in a regular
# `import` statement; both inference modules are loaded by dotted path instead.
DocTr_Plus = importlib.import_module("models.DocTr-Plus.inference")
DocScanner = importlib.import_module("models.DocScanner.inference")

# Default torch device: the GPU when one is available, otherwise the CPU.
cuda = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Module-level registry whose missing keys default to 0.
mask_dict = defaultdict(int)
def load_geotrp_model(cuda, path=""):
    """Build the DocTr-Plus ``GeoTrP`` network and load its weights.

    Args:
        cuda: torch device (or device string) the network is moved to.
        path: checkpoint location handed to ``DocTr_Plus.reload_model``.

    Returns:
        The ``GeoTrP`` instance, on ``cuda`` and switched to eval mode.
    """
    geotrp = DocTr_Plus.GeoTrP().to(cuda)
    # Weights are restored into the inner ``GeoTr`` sub-module, not the wrapper.
    DocTr_Plus.reload_model(geotrp.GeoTr, path)
    geotrp.eval()
    return geotrp
def load_docscanner_model(cuda, path_l="", path_m=""):
    """Build the DocScanner network and load both of its checkpoints.

    Args:
        cuda: torch device (or device string) the network is moved to.
        path_l: checkpoint passed to ``DocScanner.reload_rec_model`` for ``net.bm``.
        path_m: checkpoint passed to ``DocScanner.reload_seg_model`` for ``net.msk``.

    Returns:
        The ``DocScanner.Net`` instance, on ``cuda`` and switched to eval mode.
    """
    scanner = DocScanner.Net()
    scanner = scanner.to(cuda)
    # Segmentation weights first, then the rectification weights.
    DocScanner.reload_seg_model(scanner.msk, path_m)
    DocScanner.reload_rec_model(scanner.bm, path_l)
    scanner.eval()
    return scanner
def preprocess_image(img, target_size=(288, 288)):
    """Scale an image to [0, 1] and build the batched network input tensor.

    Args:
        img: array indexed as (row, column, channel) with at least three
            channels; only the first three are used. Values are divided by
            255, so 0-255 input is presumably expected -- confirm with callers.
        target_size: two-element sequence forwarded to ``cv2.resize`` as its
            ``dsize`` argument, i.e. ``(width, height)``. Defaults to 288x288.

    Returns:
        tuple ``(im_ori, im, h_, w_)`` where ``im_ori`` is the scaled
        three-channel image at its original resolution, ``im`` is a float
        tensor of shape ``(1, 3, height, width)`` at ``target_size``, and
        ``h_`` / ``w_`` are the original height and width.
    """
    im_ori = img[:, :, :3] / 255.0
    h_, w_, _ = im_ori.shape

    # Resize straight to the requested size. The previous implementation first
    # resized to a hard-coded 288x288 and then to ``target_size``; that is a
    # no-op for the default but interpolates twice for any other size.
    # ``tuple(...)`` keeps list arguments from existing callers working.
    im = cv2.resize(im_ori, tuple(target_size))

    # HWC -> CHW, then add the leading batch dimension.
    im = im.transpose(2, 0, 1)
    im = torch.from_numpy(im).float().unsqueeze(0)
    return im_ori, im, h_, w_
def geotrp_rec(img, model):
    """Rectify ``img`` with a GeoTrP model and return it at its original size.

    Args:
        img: image array accepted by ``preprocess_image``.
        model: network returning a batched two-channel map; channel 0 is used
            as the x map and channel 1 as the y map for ``cv2.remap``.

    Returns:
        The remapped image, rescaled by 255 and resized to the original
        ``(width, height)`` of ``img``.
    """
    im_ori, im, h_, w_ = preprocess_image(img)

    # Run on whatever device the model lives on instead of a hard-coded
    # ``.cuda()``, which fails on CPU-only hosts.
    device = next(model.parameters()).device
    with torch.no_grad():
        bm = model(im.to(device))
    bm = bm.cpu().numpy()[0]

    # Smooth each coordinate map with a 3x3 box filter before sampling.
    map_x = cv2.blur(bm[0, :, :], (3, 3))
    map_y = cv2.blur(bm[1, :, :], (3, 3))

    img_geo = cv2.remap(im_ori, map_x, map_y, cv2.INTER_LINEAR) * 255
    return cv2.resize(img_geo, (w_, h_))
def docscanner_get_mask(img, model):
    """Return the DocScanner mask for ``img`` at the image's own size.

    Args:
        img: image array accepted by ``preprocess_image``.
        model: network returning a pair whose second element is the batched
            mask tensor.

    Returns:
        ``np.uint8`` array: the first channel of the first mask in the batch,
        multiplied by 255 and resized to the original ``(width, height)``.
    """
    _, im, h, w = preprocess_image(img)

    # Run on whatever device the model lives on instead of a hard-coded
    # ``.cuda()``, which fails on CPU-only hosts.
    device = next(model.parameters()).device
    with torch.no_grad():
        _, msk = model(im.to(device))

    mask_np = (msk.cpu()[0, 0].numpy() * 255).astype(np.uint8)
    return cv2.resize(mask_np, (w, h))
def docscanner_rec_img(img, model):
    """Rectify ``img`` with a DocScanner model and return only the image.

    Args:
        img: image array accepted by ``preprocess_image``.
        model: network returning either the sampling-grid tensor alone or a
            pair whose first element is that tensor.

    Returns:
        ``np.uint8`` array at the original height and width, with the channel
        axis reversed relative to ``img``.
    """
    im_ori, im, h, w = preprocess_image(img)

    # Run on whatever device the model lives on instead of a hard-coded
    # ``.cuda()``, which fails on CPU-only hosts.
    device = next(model.parameters()).device
    with torch.no_grad():
        output = model(im.to(device))

    # The other call sites in this file unpack ``(bm, msk)`` from the same
    # model; accept that form as well as a bare tensor so ``.cpu()`` is never
    # called on a tuple.
    bm = output[0] if isinstance(output, (tuple, list)) else output
    bm = bm.cpu()

    # Upsample each grid channel to the original resolution, then smooth it.
    bm0 = cv2.blur(cv2.resize(bm[0, 0].numpy(), (w, h)), (3, 3))  # x flow
    bm1 = cv2.blur(cv2.resize(bm[0, 1].numpy(), (w, h)), (3, 3))  # y flow
    lbl = torch.from_numpy(np.stack([bm0, bm1], axis=2)).unsqueeze(0)  # 1 * h * w * 2

    source = torch.from_numpy(im_ori).permute(2, 0, 1).unsqueeze(0).float()
    out = F.grid_sample(source, lbl, align_corners=True)

    # CHW -> HWC, back to the 0-255 range, channel order reversed.
    rectified = (out[0] * 255).permute(1, 2, 0).numpy()
    return rectified[:, :, ::-1].astype(np.uint8)
def docscanner_rec(img, model):
    """Rectify ``img`` with a DocScanner model and also return its mask.

    Args:
        img: image array accepted by ``preprocess_image``.
        model: network returning a ``(bm, msk)`` pair: the sampling-grid
            tensor and the mask tensor.

    Returns:
        tuple ``(rectified, mask)``. ``rectified`` is an ``np.uint8`` image at
        the original height and width with the channel axis reversed relative
        to ``img``; ``mask`` is the ``np.uint8`` mask (first channel, scaled by
        255) resized to the same height and width.
    """
    # Shared preprocessing: same [0, 1] scaling and 288x288 network input that
    # this function previously built inline.
    im_ori, im, h, w = preprocess_image(img)

    # Run on whatever device the model lives on instead of a hard-coded
    # ``.cuda()``, which fails on CPU-only hosts.
    device = next(model.parameters()).device
    with torch.no_grad():
        bm, msk = model(im.to(device))
    bm = bm.cpu()
    msk = msk.cpu()

    mask_np = (msk[0, 0].numpy() * 255).astype(np.uint8)
    mask_img = cv2.resize(mask_np, (w, h))

    # Upsample each grid channel to the original resolution, then smooth it.
    bm0 = cv2.blur(cv2.resize(bm[0, 0].numpy(), (w, h)), (3, 3))  # x flow
    bm1 = cv2.blur(cv2.resize(bm[0, 1].numpy(), (w, h)), (3, 3))  # y flow
    lbl = torch.from_numpy(np.stack([bm0, bm1], axis=2)).unsqueeze(0)  # 1 * h * w * 2

    source = torch.from_numpy(im_ori).permute(2, 0, 1).unsqueeze(0).float()
    out = F.grid_sample(source, lbl, align_corners=True)

    # CHW -> HWC, back to the 0-255 range, channel order reversed.
    rectified = (out[0] * 255).permute(1, 2, 0).numpy()
    return rectified[:, :, ::-1].astype(np.uint8), mask_img
# To be moved into data_utils later.
def get_mask_white_area(mask):
    """
    Locate every positive-valued pixel of a mask.

    Args:
        mask (np.ndarray): Mask image; it is passed through ``_to_2d`` before
            the pixels are inspected.

    Returns:
        np.ndarray: One row of indices per pixel whose value is greater than
        zero, as produced by ``np.argwhere`` (row index first, i.e. (y, x)
        for a 2D mask).
    """
    flat_mask = _to_2d(mask)
    return np.argwhere(flat_mask > 0)
def main():
    """Load both models, segment one test image and record its mask area.

    Raises:
        FileNotFoundError: if the test image cannot be read.
    """
    config = Config()

    image_path = "input/test.jpg"  # Change this path when running the code.
    img = cv2.imread(image_path)
    # cv2.imread signals failure by returning None rather than raising.
    if img is None:
        raise FileNotFoundError(f"Could not read image: {image_path}")

    # NOTE(review): the config attributes are passed without being called;
    # this assumes they are properties/attributes -- confirm against Config.
    docscanner = load_docscanner_model(
        cuda, path_l=config.get_rec_model_path, path_m=config.get_seg_model_path
    )
    # Loaded as before, although nothing below uses it yet.
    doctr = load_geotrp_model(cuda, path=config.get_geotr_model_path)

    mask = docscanner_get_mask(img, docscanner)

    # ``mask_dict`` is a ``defaultdict(int)``, which has no ``add`` method, so
    # the previous ``mask_dict.add(...)`` always raised AttributeError.
    # Store the number of white pixels under the image path instead.
    # NOTE(review): a count is assumed from the ``int`` default -- confirm.
    mask_dict[image_path] = len(get_mask_white_area(mask))


if __name__ == "__main__":
    main()