File size: 22,750 Bytes
c9019cd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
# Copyright (c) 2022, National Diet Library, Japan
#
# This software is released under the CC BY 4.0.
# https://creativecommons.org/licenses/by/4.0/


import copy
import cv2
import glob
import os
import pathlib
import sys
import time
import xml
import xml.etree.ElementTree as ET

from . import utils
from .. import procs

# Add import path for src modules
currentdir = pathlib.Path(__file__).resolve().parent
sys.path.append(str(currentdir) + "/../../src/separate_pages_ssd")
sys.path.append(str(currentdir) + "/../../src/ndl_layout")
sys.path.append(str(currentdir) + "/../../src/deskew_HT")
sys.path.append(str(currentdir) + "/../../src/text_recognition")

# supported image type list
supported_img_ext = ['.jpg', '.jpeg', '.jp2']


class OcrInferencer:
    """
    推論実行時の関数や推論の設定値を保持します。

    Attributes
    ----------
    full_proc_list : list
        全推論処理のリストです。
    proc_list : list
        本実行処理における推論処理のリストです。
    cfg : dict
        本実行処理における設定情報です。
    """

    def __init__(self, cfg):
        """
        Parameters
        ----------
        cfg : dict
            本実行処理における設定情報です。
        """
        # inference process class list in order
        self.full_proc_list = [
            procs.PageSeparation,           # 0: ノド元分割               出力:(画像:あり、XML:なし、TXT:なし)
            procs.PageDeskewProcess,        # 1: 傾き補正                 出力:(画像:あり、XML:なし、TXT:なし)
            procs.LayoutExtractionProcess,  # 2: レイアウト抽出           出力:(画像:あり、XML:あり、TXT:なし)
            procs.LineOcrProcess,           # 3: 文字認識(OCR)            出力:(画像:あり、XML:あり、TXT:あり)
        ]
        self.proc_list = self._create_proc_list(cfg)
        self.cfg = cfg
        self.time_statistics = []
        self.xml_template = '<?xml version="1.0" encoding="utf-8" standalone="yes"?>\n<OCRDATASET></OCRDATASET>'

    def run(self):
        """
        self.cfgに保存された設定に基づいた推論処理を実行します。
        """
        if len(self.cfg['input_dirs']) == 0:
            print('[ERROR] Input directory list is empty', file=sys.stderr)
            return

        # input dir loop
        for input_dir in self.cfg['input_dirs']:
            if self.cfg['input_structure'] in ['t']:
                single_outputdir_data_list = self._get_single_dir_data_from_tosho_data(input_dir)
            else:
                single_outputdir_data_list = self._get_single_dir_data(input_dir)

            if single_outputdir_data_list is None:
                print('[ERROR] Input data list is empty', file=sys.stderr)
                continue
            print(single_outputdir_data_list)
            # do infer with input data for single output data dir
            for single_outputdir_data in single_outputdir_data_list:
                print(single_outputdir_data)
                if single_outputdir_data is None:
                    continue
                pred_list = self._infer(single_outputdir_data)

                # save inferenced xml in xml directory
                if (self.cfg['save_xml'] or self.cfg['partial_infer']) and (self.cfg['proc_range']['end'] > 1):
                    self._save_pred_xml(single_outputdir_data['output_dir'], [single_data['xml'] for single_data in pred_list])
        if len(self.time_statistics) == 0:
            print('================== NO VALID INFERENCE ==================')
        else:
            average = sum(self.time_statistics) / len(self.time_statistics)
            print('================== PROCESSING TIME ==================')
            print('Average processing time : {0} sec / image file '.format(average))
        return

    def _infer(self, single_outputdir_data):
        """
        self.cfgに保存された設定に基づき、XML一つ分のデータに対する推論処理を実行します。

        Parameters
        ----------
        single_outputdir_data : dict
            XML一つ分のデータ(基本的に1書籍分を想定)の入力データ情報。
            画像ファイルパスのリスト、それらに対応するXMLデータを含みます。

        Returns
        -------
        pred_list : list
            1ページ分の推論結果を要素に持つ推論結果のリスト。
            各結果は辞書型で保持されています。
        """
        # single_outputdir_data dictionary include [key, value] pairs as below
        # (xml is not always included)
        #   [key, value]: ['img', numpy.ndarray], ['xml', xml_tree]
        pred_list = []
        pred_xml_dict_for_dump = {}
        if self.cfg['dump']:
            dump_dir = os.path.join(single_outputdir_data['output_dir'], 'dump')
            os.makedirs(dump_dir, exist_ok=True)

            for proc in self.proc_list:
                pred_xml_dict_for_dump[proc.proc_name] = []
                proc_dump_dir = os.path.join(dump_dir, proc.proc_name)
                os.makedirs(proc_dump_dir, exist_ok=True)

        for img_path in single_outputdir_data['img_list']:
            single_image_file_data = self._get_single_image_file_data(img_path, single_outputdir_data)
            output_dir = single_outputdir_data['output_dir']
            if single_image_file_data is None:
                print('[ERROR] Failed to get single page input data for image:{0}'.format(img_path), file=sys.stderr)
                continue

            print('######## START PAGE INFERENCE PROCESS ########')
            start_page = time.time()

            for proc in self.proc_list:
                single_page_output = []
                for idx, single_data_input in enumerate(single_image_file_data):
                    single_data_output = proc.do(idx, single_data_input)
                    single_page_output.extend(single_data_output)
                # save inference result data to dump
                if self.cfg['dump'] and 'xml' in single_image_file_data[0].keys():
                    pred_xml_dict_for_dump[proc.proc_name].append(single_image_file_data[0]['xml'])

                single_image_file_data = single_page_output

            single_image_file_output = single_image_file_data
            self.time_statistics.append(time.time() - start_page)

            if self.cfg['save_image'] or self.cfg['partial_infer']:
                # save inferenced result drawn image in pred_img directory
                for single_data_output in single_image_file_output:
                    # save input image while partial inference
                    if self.cfg['partial_infer']:
                        img_output_dir = os.path.join(output_dir, 'img')
                        self._save_image(single_data_output['img'], single_data_output['img_file_name'], img_output_dir)

                    pred_img = self._create_result_image(single_data_output, self.proc_list[-1].proc_name)
                    img_output_dir = os.path.join(output_dir, 'pred_img')
                    self._save_image(pred_img, single_data_output['img_file_name'], img_output_dir)

            # save inferenced result text for this page
            if self.cfg['proc_range']['end'] > 2:
                sum_main_txt = ''
                sum_cap_txt = ''
                for single_data_output in single_image_file_output:
                    main_txt, cap_txt = self._create_result_txt(single_data_output['xml'])
                    sum_main_txt += main_txt + '\n'
                    sum_cap_txt += sum_cap_txt + '\n'
                self._save_pred_txt(sum_main_txt, sum_cap_txt, os.path.basename(img_path), single_outputdir_data['output_dir'])

            # add inference result for single image file data to pred_list, including XML data
            pred_list.extend(single_image_file_output)
            print('########  END PAGE INFERENCE PROCESS  ########')

        return pred_list

    def _get_single_dir_data(self, input_dir):
        """
        XML一つ分の入力データに関する情報を整理して取得します。

        Parameters
        ----------
        input_dir : str
            XML一つ分の入力データが保存されているディレクトリパスです。

        Returns
        -------
        # Fixme
        single_dir_data : dict
            XML一つ分のデータ(基本的に1PID分を想定)の入力データ情報です。
            画像ファイルパスのリスト、それらに対応するXMLデータを含みます。
        """
        single_dir_data = {'input_dir': os.path.abspath(input_dir)}
        single_dir_data['img_list'] = []

        # get img list of input directory
        if self.cfg['input_structure'] in ['w']:
            for ext in supported_img_ext:
                single_dir_data['img_list'].extend(sorted(glob.glob(os.path.join(input_dir, '*{0}'.format(ext)))))
        elif self.cfg['input_structure'] in ['f']:
            stem, ext = os.path.splitext(os.path.basename(input_dir))
            if ext in supported_img_ext:
                single_dir_data['img_list'] = [input_dir]
            else:
                print('[ERROR] This file is not supported type : {0}'.format(input_dir), file=sys.stderr)
        elif not os.path.isdir(os.path.join(input_dir, 'img')):
            print('[ERROR] Input img diretctory not found in {}'.format(input_dir), file=sys.stderr)
            return None
        else:
            for ext in supported_img_ext:
                single_dir_data['img_list'].extend(sorted(glob.glob(os.path.join(input_dir, 'img/*{0}'.format(ext)))))

        # check xml file number and load xml data if needed
        if self.cfg['proc_range']['start'] > 2:
            if self.cfg['input_structure'] in ['f']:
                print('[ERROR] Single image file input mode does not support partial inference wich need xml file input.', file=sys.stderr)
                return None
            input_xml = None
            xml_file_list = glob.glob(os.path.join(input_dir, 'xml/*.xml'))
            if len(xml_file_list) > 1:
                print('[ERROR] Input xml file must be only one, but there is {0} xml files in {1}.'.format(
                    len(xml_file_list), os.path.join(self.cfg['input_root'], 'xml')), file=sys.stderr)
                return None
            elif len(xml_file_list) == 0:
                print('[ERROR] There is no input xml files in {0}.'.format(os.path.join(input_dir, 'xml')), file=sys.stderr)
                return None
            else:
                input_xml = xml_file_list[0]
            try:
                single_dir_data['xml'] = ET.parse(input_xml)
            except xml.etree.ElementTree.ParseError as err:
                print("[ERROR] XML parse error : {0}".format(input_xml), file=sys.stderr)
                return None

        # prepare output dir for inferensce result with this input dir
        if self.cfg['input_structure'] in ['f']:
            stem, ext = os.path.splitext(os.path.basename(input_dir))
            output_dir = os.path.join(self.cfg['output_root'], stem)
        elif self.cfg['input_structure'] in ['i', 's']:
            dir_name = os.path.basename(input_dir)
            output_dir = os.path.join(self.cfg['output_root'], dir_name)
        elif self.cfg['input_structure'] in ['w']:
            input_dir_names = input_dir.split('/')
            dir_name = input_dir_names[-3][0] + input_dir_names[-2] + input_dir_names[-1]
            output_dir = os.path.join(self.cfg['output_root'], dir_name)
        else:
            print('[ERROR] Unexpected input directory structure type: {}.'.format(self.cfg['input_structure']), file=sys.stderr)
            return None

        # output directory existance check
        output_dir = utils.mkdir_with_duplication_check(output_dir)
        single_dir_data['output_dir'] = output_dir

        return [single_dir_data]

    def _get_single_dir_data_from_tosho_data(self, input_dir):
        """
        XML一つ分の入力データに関する情報を整理して取得します。

        Parameters
        ----------
        input_dir : str
            tosho data形式のセクションごとのディレクトリパスです。

        Returns
        -------
        single_dir_data_list : list
            XML一つ分のデータ(基本的に1PID分を想定)の入力データ情報のリストです。
            1つの要素に画像ファイルパスのリスト、それらに対応するXMLデータを含みます。
        """
        single_dir_data_list = []

        # get img list of input directory
        tmp_img_list = sorted(glob.glob(os.path.join(input_dir, '*.jp2')))
        tmp_img_list.extend(sorted(glob.glob(os.path.join(input_dir, '*.jpg'))))

        pid_list = []
        for img in tmp_img_list:
            pid = os.path.basename(img).split('_')[0]
            if pid not in pid_list:
                pid_list.append(pid)

        for pid in pid_list:
            single_dir_data = {'input_dir': os.path.abspath(input_dir),
                               'img_list': [img for img in tmp_img_list if os.path.basename(img).startswith(pid)]}

            # prepare output dir for inferensce result with this input dir
            output_dir = os.path.join(self.cfg['output_root'], pid)

            # output directory existance check
            os.makedirs(output_dir, exist_ok=True)
            single_dir_data['output_dir'] = output_dir
            single_dir_data_list.append(single_dir_data)

        return single_dir_data_list

    def _get_single_image_file_data(self, img_path, single_dir_data):
        """
        1ページ分の入力データに関する情報を整理して取得します。

        Parameters
        ----------
        img_path : str
            入力画像データのパスです。
        single_dir_data : dict
            1書籍分の入力データに関する情報を保持する辞書型データです。
            xmlファイルへのパス、結果を出力するディレクトリのパスなどを含みます。

        Returns
        -------
        single_image_file_data : dict
            1ページ分のデータの入力データ情報です。
            画像ファイルのパスとnumpy.ndarray形式の画像データ、その画像に対応するXMLデータを含みます。
        """
        single_image_file_data = [{
            'img_path': img_path,
            'img_file_name': os.path.basename(img_path),
            'output_dir': single_dir_data['output_dir']
        }]

        full_xml = None
        if 'xml' in single_dir_data.keys():
            full_xml = single_dir_data['xml']

        # get img data for single page
        orig_img = cv2.imread(img_path)
        if orig_img is None:
            print('[ERROR] Image read error : {0}'.format(img_path), file=sys.stderr)
            return None
        single_image_file_data[0]['img'] = orig_img

        # return if this proc needs only img data for input
        if full_xml is None:
            return single_image_file_data

        # get xml data for single page
        image_name = os.path.basename(img_path)
        for page in full_xml.getroot().iter('PAGE'):
            if page.attrib['IMAGENAME'] == image_name:
                node = ET.fromstring(self.xml_template)
                node.append(page)
                tree = ET.ElementTree(node)
                single_image_file_data[0]['xml'] = tree
                break

        # [TODO] 画像データに対応するXMLデータが見つからなかった場合の対応
        if 'xml' not in single_image_file_data[0].keys():
            print('[ERROR] Input XML data for page {} not found.'.format(img_path), file=sys.stderr)

        return single_image_file_data

    def _create_proc_list(self, cfg):
        """
        推論の設定情報に基づき、実行する推論処理のリストを作成します。

        Parameters
        ----------
        cfg : dict
            推論実行時の設定情報を保存した辞書型データ。
        """
        proc_list = []
        for i in range(cfg['proc_range']['start'], cfg['proc_range']['end'] + 1):
            proc_list.append(self.full_proc_list[i](cfg, i))
        return proc_list

    def _save_pred_xml(self, output_dir, pred_list):
        """
        推論結果のXMLデータをまとめたXMLファイルを生成して保存します。

        Parameters
        ----------
        output_dir : str
            推論結果を保存するディレクトリのパスです。
        pred_list : list
            1ページ分の推論結果を要素に持つ推論結果のリスト。
            各結果は辞書型で保持されています。
        """
        xml_dir = os.path.join(output_dir, 'xml')
        os.makedirs(xml_dir, exist_ok=True)

        # basically, output_dir is supposed to be PID, so it used as xml filename
        xml_path = os.path.join(xml_dir, '{}.xml'.format(os.path.basename(output_dir)))
        pred_xml = self._parse_pred_list_to_save(pred_list)
        utils.save_xml(pred_xml, xml_path)
        return

    def _save_image(self, pred_img, orig_img_name, img_output_dir, id=''):
        """
        指定されたディレクトリに画像データを保存します。
        画像データは入力に使用したものと推論結果を重畳したものの2種類が想定されています。

        Parameters
        ----------
        pred_img : numpy.ndarray
            保存する画像データ。
        orig_img_name : str
            もともとの入力画像のファイル名。
            基本的にはこのファイル名と同名で保存します。
        img_output_dir : str
            画像ファイルの保存先のディレクトリパス。
        id : str
            もともとの入力画像のファイル名に追加する処理結果ごとのidです。
            一つの入力画像から複数の画像データが出力される処理がある場合に必要になります。
        """
        os.makedirs(img_output_dir, exist_ok=True)
        stem, ext = os.path.splitext(orig_img_name)
        orig_img_name = stem + '.jpg'

        if id != '':
            stem, ext = os.path.splitext(orig_img_name)
            orig_img_name = stem + '_' + id + ext

        img_path = os.path.join(img_output_dir, orig_img_name)
        try:
            cv2.imwrite(img_path, pred_img)
        except OSError as err:
            print("[ERROR] Image save error: {0}".format(err), file=sys.stderr)
            raise OSError

        return

    def _save_pred_txt(self, main_txt, cap_txt, orig_img_name, output_dir):
        """
        指定されたディレクトリに推論結果のテキストデータを保存します。

        Parameters
        ----------
        main_txt : str
            本文+キャプションの推論結果のテキストデータです
        cap_txt : str
            キャプションのみの推論結果のテキストデータです
        orig_img_name : str
            もともとの入力画像ファイル名。
            基本的にはこのファイル名と同名で保存します。
        img_output_dir : str
            画像ファイルの保存先のディレクトリパス。
        """
        txt_dir = os.path.join(output_dir, 'txt')
        os.makedirs(txt_dir, exist_ok=True)

        stem, _ = os.path.splitext(orig_img_name)
        txt_path = os.path.join(txt_dir, stem + '_cap.txt')
        try:
            with open(txt_path, 'w') as f:
                f.write(cap_txt)
        except OSError as err:
            print("[ERROR] Caption text save error: {0}".format(err), file=sys.stderr)
            raise OSError

        stem, _ = os.path.splitext(orig_img_name)
        txt_path = os.path.join(txt_dir, stem + '_main.txt')
        try:
            with open(txt_path, 'w') as f:
                f.write(main_txt)
        except OSError as err:
            print("[ERROR] Main text save error: {0}".format(err), file=sys.stderr)
            raise OSError

        return

    def _parse_pred_list_to_save(self, pred_list):
        """
        推論結果のXMLを要素に持つリストから、ファイルに保存するための一つのXMLデータを生成します。

        Parameters
        ----------
        pred_list : list
            推論結果のXMLを要素に持つリスト。
        """
        ET.register_namespace('', 'NDLOCRDATASET')
        node = ET.fromstring(self.xml_template)
        for single_xml_tree in pred_list:
            root = single_xml_tree.getroot()
            for element in root:
                node.append(element)

        tree = ET.ElementTree(node)
        return tree

    def _create_result_image(self, result, proc_name):
        """
        推論結果を入力画像に重畳した画像データを生成します。

        Parameters
        ----------
        result : dict
            1ページ分の推論結果を持つ辞書型データ。
        proc_name : str
            重畳を行う結果を出力した推論処理の名前。
        """
        if 'dump_img' in result.keys():
            dump_img = copy.deepcopy(result['dump_img'])
        else:
            dump_img = copy.deepcopy(result['img'])
        if 'xml' in result.keys() and result['xml'] is not None:
            # draw inference result on input image
            cv2.putText(dump_img, proc_name, (0, 50),
                        cv2.FONT_HERSHEY_PLAIN, 4, (0, 0, 0), 5, cv2.LINE_AA)
            pass
        else:
            cv2.putText(dump_img, proc_name, (0, 50),
                        cv2.FONT_HERSHEY_PLAIN, 4, (0, 0, 0), 5, cv2.LINE_AA)
        return dump_img

    def _create_result_txt(self, xml_data):
        """
        推論結果のxmlデータからテキストデータを生成します。

        Parameters
        ----------
        xml_data :
            1ページ分の推論結果を持つxmlデータ。
        """
        main_txt = ''
        cap_txt = ''
        for page_xml in xml_data.iter('PAGE'):
            for line_xml in page_xml.iter('LINE'):
                main_txt += line_xml.attrib['STRING']
                main_txt += '\n'
                if line_xml.attrib['TYPE'] == 'キャプション':
                    cap_txt += line_xml.attrib['STRING']
                    cap_txt += '\n'

        return main_txt, cap_txt