|
|
|
|
|
|
|
|
|
|
|
|
|
import copy |
|
import numpy |
|
import os |
|
|
|
from .base_proc import BaseInferenceProcess |
|
|
|
|
|
class PageSeparation(BaseInferenceProcess): |
|
""" |
|
ノド元分割処理を実行するプロセスのクラス。 |
|
BaseInferenceProcessを継承しています。 |
|
""" |
|
def __init__(self, cfg, pid): |
|
""" |
|
Parameters |
|
---------- |
|
cfg : dict |
|
本推論処理における設定情報です。 |
|
pid : int |
|
実行される順序を表す数値。 |
|
""" |
|
super().__init__(cfg, pid, '_page_sep') |
|
|
|
if self.cfg['page_separation']['silence_tf_log']: |
|
import logging |
|
import warnings |
|
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' |
|
warnings.simplefilter(action='ignore', category=FutureWarning) |
|
|
|
import tensorflow as tf |
|
tf.get_logger().setLevel(logging.ERROR) |
|
|
|
from src.separate_pages_ssd.inference_divided import divide_facing_page_with_cli, load_weightfile |
|
load_weightfile(os.path.abspath(self.cfg['page_separation']['weight_path'])) |
|
self._run_src_inference = divide_facing_page_with_cli |
|
|
|
def _is_valid_input(self, input_data): |
|
""" |
|
本クラスの推論処理における入力データのバリデーション。 |
|
|
|
Parameters |
|
---------- |
|
input_data : dict |
|
推論処理を実行する対象の入力データ。 |
|
|
|
Returns |
|
------- |
|
[変数なし] : bool |
|
入力データが正しければTrue, そうでなければFalseを返します。 |
|
""" |
|
if type(input_data['img']) is not numpy.ndarray: |
|
print('PageSeparation: input img is not numpy.ndarray') |
|
return False |
|
return True |
|
|
|
def _run_process(self, input_data): |
|
""" |
|
推論処理の本体部分。 |
|
|
|
Parameters |
|
---------- |
|
input_data : dict |
|
推論処理を実行する対象の入力データ。 |
|
|
|
Returns |
|
------- |
|
result : dict |
|
推論処理の結果を保持する辞書型データ。 |
|
基本的にinput_dataと同じ構造です。 |
|
""" |
|
print('### Page Separation ###') |
|
log_file_path = None |
|
if self.process_dump_dir is not None: |
|
log_file_path = os.path.join(self.process_dump_dir, self.cfg['page_separation']['log']) |
|
inference_output = self._run_src_inference(input=input_data['img'], |
|
input_path=input_data['img_path'], |
|
left=self.cfg['page_separation']['left'], |
|
right=self.cfg['page_separation']['right'], |
|
single=self.cfg['page_separation']['single'], |
|
ext=self.cfg['page_separation']['ext'], |
|
quality=self.cfg['page_separation']['quality'], |
|
short=self.cfg['page_separation']['short'], |
|
log=log_file_path) |
|
if (not self.cfg['page_separation']['allow_invalid_num_output']) and (not len(inference_output) in range(1, 3)): |
|
print('ERROR: Output from page separation must be 1 or 2 pages.') |
|
return None |
|
|
|
|
|
result = [] |
|
for id, single_output_img in enumerate(inference_output): |
|
output_data = copy.deepcopy(input_data) |
|
output_data['img'] = single_output_img |
|
output_data['orig_img_path'] = input_data['img_path'] |
|
|
|
|
|
if id == 0: |
|
id = 'L' |
|
else: |
|
id = 'R' |
|
orig_img_name = os.path.basename(input_data['img_path']) |
|
stem, ext = os.path.splitext(orig_img_name) |
|
output_data['img_file_name'] = stem + '_' + id + '.jpg' |
|
|
|
result.append(output_data) |
|
|
|
return result |
|
|