# image_cut_rect / rect_main.py
# HERIUN
# add files
# 6a07cb2
# raw
# history blame
# 4.6 kB
import importlib
import warnings
from collections import defaultdict
import cv2
import numpy as np
import torch
import torch.nn.functional as F
from config import Config
from data_utils.image_utils import _to_2d
# Silence every warning process-wide. This runs before the dynamic imports
# below, so warnings emitted while importing the model packages are hidden too.
warnings.filterwarnings("ignore")

# "DocTr-Plus" contains a hyphen, which is not a valid identifier in an
# ``import`` statement, so the inference modules are loaded via importlib.
DocTr_Plus = importlib.import_module("models.DocTr-Plus.inference")
DocScanner = importlib.import_module("models.DocScanner.inference")

# Compute device shared by the loaders below. Despite the name, this is the
# CPU when CUDA is not available.
cuda = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Module-level result store written by main(); the int default factory means
# missing keys read as 0.
mask_dict = defaultdict(int)
def load_geotrp_model(cuda, path=""):
    """Build the DocTr++ GeoTrP network, load its weights and switch to eval mode.

    Args:
        cuda (torch.device): Device the network is moved to (despite the name,
            this may be the CPU).
        path (str): Checkpoint path, passed with the network's ``GeoTr``
            sub-module to ``DocTr_Plus.reload_model``.

    Returns:
        The GeoTrP network on ``cuda``, in evaluation mode.
    """
    geotrp = DocTr_Plus.GeoTrP().to(cuda)
    DocTr_Plus.reload_model(geotrp.GeoTr, path)
    geotrp.eval()
    return geotrp
def load_docscanner_model(cuda, path_l="", path_m=""):
    """Build the DocScanner network, load both sub-model checkpoints, set eval mode.

    Args:
        cuda (torch.device): Device the network is moved to (may be the CPU).
        path_l (str): Checkpoint passed to ``DocScanner.reload_rec_model`` for
            the ``bm`` sub-module.
        path_m (str): Checkpoint passed to ``DocScanner.reload_seg_model`` for
            the ``msk`` sub-module.

    Returns:
        The DocScanner network on ``cuda``, in evaluation mode.
    """
    network = DocScanner.Net()
    network = network.to(cuda)
    # Segmentation weights first (msk), then rectification weights (bm).
    DocScanner.reload_seg_model(network.msk, path_m)
    DocScanner.reload_rec_model(network.bm, path_l)
    network.eval()
    return network
def preprocess_image(img, target_size=(288, 288)):
    """Normalize an image and build the batched network input tensor.

    Args:
        img (np.ndarray): H x W x C image with at least 3 channels; only the
            first three are used. Assumed to hold 0-255 values.
        target_size (tuple[int, int] | list[int]): Network input size as
            ``(width, height)``, following the ``cv2.resize`` ``dsize``
            convention. Defaults to ``(288, 288)``.

    Returns:
        tuple: ``(im_ori, im, h, w)``.
            - ``im_ori``: the full-resolution image divided by 255, shape (H, W, 3).
            - ``im``: a float32 tensor of shape ``(1, 3, height, width)``.
            - ``h`` and ``w``: the original image height and width.
    """
    # Drop any extra (e.g. alpha) channels and scale to [0, 1].
    im_ori = img[:, :, :3] / 255.0
    h_, w_, _ = im_ori.shape
    # Resize directly to the requested size. Going through a fixed 288x288
    # intermediate first would discard detail whenever target_size differs.
    im = cv2.resize(im_ori, tuple(target_size))
    im = im.transpose(2, 0, 1)  # HWC -> CHW
    im = torch.from_numpy(im).float().unsqueeze(0)  # add batch dimension
    return im_ori, im, h_, w_
def geotrp_rec(img, model):
    """Rectify a document image with a DocTr++ GeoTrP model.

    Args:
        img (np.ndarray): H x W x C input image (0-255), at least 3 channels.
        model (torch.nn.Module): Network from ``load_geotrp_model``. Channels
            0 and 1 of its first batch item are used as the x/y maps for
            ``cv2.remap``.

    Returns:
        np.ndarray: Rectified image resized back to the input's (h, w). It is
        a float array multiplied by 255 and is NOT cast to ``uint8``.
    """
    im_ori, im, h_, w_ = preprocess_image(img)
    # Use the device the model actually lives on instead of hard-coding CUDA,
    # so the CPU fallback selected at import time works.
    param = next(model.parameters(), None)
    device = param.device if param is not None else torch.device("cpu")
    with torch.no_grad():
        bm = model(im.to(device))
    bm = bm.cpu().numpy()[0]
    # 3x3 box blur smooths the x and y maps.
    bm0 = cv2.blur(bm[0, :, :], (3, 3))
    bm1 = cv2.blur(bm[1, :, :], (3, 3))
    # NOTE(review): assumes bm0/bm1 are absolute pixel coordinates into im_ori,
    # as cv2.remap expects - confirm against the GeoTrP output convention.
    img_geo = cv2.remap(im_ori, bm0, bm1, cv2.INTER_LINEAR) * 255
    img_geo = cv2.resize(img_geo, (w_, h_))
    return img_geo
def docscanner_get_mask(img, model):
    """Predict the document segmentation mask with a DocScanner model.

    Args:
        img (np.ndarray): H x W x C input image (0-255), at least 3 channels.
        model (torch.nn.Module): Network from ``load_docscanner_model``. It is
            called on a ``(1, 3, 288, 288)`` tensor and returns ``(bm, msk)``.

    Returns:
        np.ndarray: ``uint8`` mask of shape (h, w) matching the input image.
    """
    _, im, h, w = preprocess_image(img)
    # Follow the model's device rather than assuming CUDA is present.
    param = next(model.parameters(), None)
    device = param.device if param is not None else torch.device("cpu")
    with torch.no_grad():
        _, msk = model(im.to(device))
    msk = msk.cpu()
    # Clip before the uint8 cast so out-of-range values saturate instead of
    # wrapping around. Values already in [0, 1] are unaffected.
    mask_np = np.clip(msk[0, 0].numpy() * 255, 0, 255).astype(np.uint8)
    mask_resized = cv2.resize(mask_np, (w, h))
    return mask_resized
def docscanner_rec_img(img, model):
    """Rectify a document image with a DocScanner model (image only, no mask).

    Args:
        img (np.ndarray): H x W x C input image (0-255), at least 3 channels.
        model (torch.nn.Module): Network from ``load_docscanner_model``. Like in
            ``docscanner_rec``, its output is unpacked as ``(bm, msk)``.

    Returns:
        np.ndarray: ``uint8`` rectified image of shape (h, w, 3). Its channel
        order is reversed relative to ``img``.
    """
    im_ori, im, h, w = preprocess_image(img)
    param = next(model.parameters(), None)
    device = param.device if param is not None else torch.device("cpu")
    with torch.no_grad():
        output = model(im.to(device))
    # The network returns (bm, msk). Only the backward map is needed here;
    # calling .cpu() on the raw tuple would raise AttributeError.
    bm = output[0] if isinstance(output, (tuple, list)) else output
    bm = bm.cpu()
    # Upsample the x/y flow to the original resolution, then smooth it.
    bm0 = cv2.blur(cv2.resize(bm[0, 0].numpy(), (w, h)), (3, 3))  # x flow
    bm1 = cv2.blur(cv2.resize(bm[0, 1].numpy(), (w, h)), (3, 3))  # y flow
    # (1, h, w, 2) sampling grid.
    # NOTE(review): grid_sample expects coordinates in [-1, 1]; assumes the
    # model output is already normalized - confirm.
    lbl = torch.from_numpy(np.stack([bm0, bm1], axis=2)).unsqueeze(0)
    out = F.grid_sample(
        torch.from_numpy(im_ori).permute(2, 0, 1).unsqueeze(0).float(),
        lbl,
        align_corners=True,
    )
    # CHW -> HWC, scale back to 0-255, reverse channel order.
    # NOTE(review): img from cv2.imread is BGR, so the result is RGB - confirm
    # this is what callers expect.
    img = (((out[0] * 255).permute(1, 2, 0).numpy())[:, :, ::-1]).astype(np.uint8)
    return img
def docscanner_rec(img, model):
    """Predict the document mask and the rectified image with a DocScanner model.

    Args:
        img (np.ndarray): H x W x C input image (0-255); only the first three
            channels are used.
        model (torch.nn.Module): Network from ``load_docscanner_model``. It is
            called on a ``(1, 3, 288, 288)`` tensor and returns ``(bm, msk)``.

    Returns:
        tuple: ``(img, mask_img)``.
            - ``img``: ``uint8`` rectified image of shape (h, w, 3), with channel
              order reversed relative to the input.
            - ``mask_img``: ``uint8`` mask of shape (h, w).
    """
    im_ori = img[:, :, :3] / 255.0
    h, w, _ = im_ori.shape
    im = cv2.resize(im_ori, (288, 288))
    im = im.transpose(2, 0, 1)  # HWC -> CHW
    im = torch.from_numpy(im).float().unsqueeze(0)  # add batch dimension
    # Follow the model's device rather than hard-coding CUDA.
    param = next(model.parameters(), None)
    device = param.device if param is not None else torch.device("cpu")
    with torch.no_grad():
        bm, msk = model(im.to(device))
    bm = bm.cpu()
    msk = msk.cpu()

    # Mask: clip before the uint8 cast so out-of-range values saturate
    # instead of wrapping around.
    mask_np = np.clip(msk[0, 0].numpy() * 255, 0, 255).astype(np.uint8)
    mask_img = cv2.resize(mask_np, (w, h))

    # Rectified image: upsample the x/y flow to full resolution and smooth it.
    bm0 = cv2.blur(cv2.resize(bm[0, 0].numpy(), (w, h)), (3, 3))  # x flow
    bm1 = cv2.blur(cv2.resize(bm[0, 1].numpy(), (w, h)), (3, 3))  # y flow
    # (1, h, w, 2) sampling grid.
    # NOTE(review): assumes values are normalized to [-1, 1] as grid_sample requires.
    lbl = torch.from_numpy(np.stack([bm0, bm1], axis=2)).unsqueeze(0)
    out = F.grid_sample(
        torch.from_numpy(im_ori).permute(2, 0, 1).unsqueeze(0).float(),
        lbl,
        align_corners=True,
    )
    img = (((out[0] * 255).permute(1, 2, 0).numpy())[:, :, ::-1]).astype(np.uint8)
    return img, mask_img
# TODO: move this helper into data_utils later.
def get_mask_white_area(mask):
    """
    Return the coordinates of every non-zero ("white") pixel in a mask.

    The mask may be 2D or 3D; it is reduced with the project helper
    ``_to_2d`` before searching.

    Args:
        mask (np.ndarray): Input mask image (2D or 3D array)
    Returns:
        np.ndarray: Array with one (y, x) row per non-zero pixel, i.e. shape
        (N, 2), assuming ``_to_2d`` yields a 2-D array.
    """
    return np.argwhere(_to_2d(mask) > 0)
def main(img_path="input/test.jpg"):
    """Run the mask-extraction demo on a single image.

    Loads the DocScanner and DocTr++ models from the paths in ``Config``,
    predicts a document mask for ``img_path`` and stores the coordinates of
    the mask's white pixels in the module-level ``mask_dict``, keyed by
    ``img_path``.

    Args:
        img_path (str): Image to process. Defaults to the previously
            hard-coded path; pass a different one instead of editing the code.

    Raises:
        FileNotFoundError: If ``cv2.imread`` cannot read ``img_path``.
    """
    config = Config()
    img = cv2.imread(img_path)
    # cv2.imread reports failure by returning None instead of raising.
    if img is None:
        raise FileNotFoundError(f"Could not read image: {img_path}")
    # NOTE(review): the get_*_model_path members are passed without calling
    # them - confirm they are properties/attributes on Config, not methods.
    docscanner = load_docscanner_model(
        cuda, path_l=config.get_rec_model_path, path_m=config.get_seg_model_path
    )
    # NOTE(review): the DocTr++ model is loaded but not used yet.
    doctr = load_geotrp_model(cuda, path=config.get_geotr_model_path)
    mask = docscanner_get_mask(img, docscanner)
    # defaultdict has no .add(); the original call raised AttributeError.
    # Store the white-pixel coordinates under the image path instead.
    # NOTE(review): confirm this keying matches the intended use of mask_dict.
    mask_dict[img_path] = get_mask_white_area(mask)


if __name__ == "__main__":
    main()