Arnaudding001 committed on
Commit
19d13fe
1 Parent(s): 1587cf9

Create app.py

Files changed (1)
  1. app.py +246 -0
app.py ADDED
@@ -0,0 +1,246 @@
+ import os
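+ # install detectron2 at start-up from the pre-built wheel index (CUDA 10.2 / torch 1.9)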
+ os.system('pip install detectron2 -f https://dl.fbaipublicfiles.com/detectron2/wheels/cu102/torch1.9/index.html')
+
+ # workaround: https://discuss.huggingface.co/t/how-to-install-a-specific-version-of-gradio-in-spaces/13552
+ os.system("pip uninstall -y gradio")
+ os.system("pip install gradio==3.4.1")
+
+ from os import getcwd, path, environ
+ import deepdoctection as dd
+ from deepdoctection.dataflow.serialize import DataFromList
+
+ import gradio as gr
+
+
+ _DD_ONE = "conf_dd_one.yaml"
+ _DETECTIONS = ["table", "ocr"]
+
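+ # register the custom layout model so that dd.ModelDownloadManager can pull its weights and Detectron2
+ # config files from the Hugging Face repo referenced by the HF_REPO environment variable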
+ dd.ModelCatalog.register("layout/model_final_inf_only.pt", dd.ModelProfile(
+     name="layout/model_final_inf_only.pt",
+     description="Detectron2 layout detection model trained on private datasets",
+     config="dd/d2/layout/CASCADE_RCNN_R_50_FPN_GN.yaml",
+     size=[274632215],
+     tp_model=False,
+     hf_repo_id=environ.get("HF_REPO"),
+     hf_model_name="model_final_inf_only.pt",
+     hf_config_file=["Base-RCNN-FPN.yaml", "CASCADE_RCNN_R_50_FPN_GN.yaml"],
+     categories={"1": dd.LayoutType.text,
+                 "2": dd.LayoutType.title,
+                 "3": dd.LayoutType.list,
+                 "4": dd.LayoutType.table,
+                 "5": dd.LayoutType.figure},
+ ))
+
+ # Set up the configuration and logging. Models are defined globally so that they are not re-loaded when the
+ # input updates.
+ cfg = dd.set_config_by_yaml(path.join(getcwd(), _DD_ONE))
+ cfg.freeze(freezed=False)
+ cfg.DEVICE = "cpu"
+ cfg.freeze()
+
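+ # three Detectron2 detectors are instantiated up front: page layout, table cells and table rows/columns;
+ # weights and configs are downloaded on first use if they are not cached locally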
+ # layout detector
+ layout_config_path = dd.ModelCatalog.get_full_path_configs(cfg.CONFIG.D2LAYOUT)
+ layout_weights_path = dd.ModelDownloadManager.maybe_download_weights_and_configs(cfg.WEIGHTS.D2LAYOUT)
+ categories_layout = dd.ModelCatalog.get_profile(cfg.WEIGHTS.D2LAYOUT).categories
+ assert categories_layout is not None
+ assert layout_weights_path is not None
+ d_layout = dd.D2FrcnnDetector("layout", layout_config_path, layout_weights_path, categories_layout, device=cfg.DEVICE)
+
+ # cell detector
+ cell_config_path = dd.ModelCatalog.get_full_path_configs(cfg.CONFIG.D2CELL)
+ cell_weights_path = dd.ModelDownloadManager.maybe_download_weights_and_configs(cfg.WEIGHTS.D2CELL)
+ categories_cell = dd.ModelCatalog.get_profile(cfg.WEIGHTS.D2CELL).categories
+ assert categories_cell is not None
+ d_cell = dd.D2FrcnnDetector("cell", cell_config_path, cell_weights_path, categories_cell, device=cfg.DEVICE)
+
+ # row/column detector
+ item_config_path = dd.ModelCatalog.get_full_path_configs(cfg.CONFIG.D2ITEM)
+ item_weights_path = dd.ModelDownloadManager.maybe_download_weights_and_configs(cfg.WEIGHTS.D2ITEM)
+ categories_item = dd.ModelCatalog.get_profile(cfg.WEIGHTS.D2ITEM).categories
+ assert categories_item is not None
+ d_item = dd.D2FrcnnDetector("item", item_config_path, item_weights_path, categories_item, device=cfg.DEVICE)
+
+ # word detector
+ det = dd.DoctrTextlineDetector()
+
+ # text recognizer
+ rec = dd.DoctrTextRecognizer()
+
+
+ def build_gradio_analyzer(table, table_ref, ocr):
+     """Build the Detectron2/DocTr analyzer based on the given config"""
+
+     cfg.freeze(freezed=False)
+     cfg.TAB = table
+     cfg.TAB_REF = table_ref
+     cfg.OCR = ocr
+     cfg.freeze()
+
+     pipe_component_list = []
+     layout = dd.ImageLayoutService(d_layout, to_image=True, crop_image=True)
+     pipe_component_list.append(layout)
+
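+     # table recognition branch: detect cells and rows/columns on each table crop, then derive the table
+     # structure with the segmentation service and optionally refine it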
+     if cfg.TAB:
+         cell = dd.SubImageLayoutService(d_cell, dd.LayoutType.table, {1: 6}, True)
+         pipe_component_list.append(cell)
+
+         item = dd.SubImageLayoutService(d_item, dd.LayoutType.table, {1: 7, 2: 8}, True)
+         pipe_component_list.append(item)
+
+         table_segmentation = dd.TableSegmentationService(
+             cfg.SEGMENTATION.ASSIGNMENT_RULE,
+             cfg.SEGMENTATION.IOU_THRESHOLD_ROWS
+             if cfg.SEGMENTATION.ASSIGNMENT_RULE in ["iou"]
+             else cfg.SEGMENTATION.IOA_THRESHOLD_ROWS,
+             cfg.SEGMENTATION.IOU_THRESHOLD_COLS
+             if cfg.SEGMENTATION.ASSIGNMENT_RULE in ["iou"]
+             else cfg.SEGMENTATION.IOA_THRESHOLD_COLS,
+             cfg.SEGMENTATION.FULL_TABLE_TILING,
+             cfg.SEGMENTATION.REMOVE_IOU_THRESHOLD_ROWS,
+             cfg.SEGMENTATION.REMOVE_IOU_THRESHOLD_COLS,
+         )
+         pipe_component_list.append(table_segmentation)
+
+         if cfg.TAB_REF:
+             table_segmentation_refinement = dd.TableSegmentationRefinementService()
+             pipe_component_list.append(table_segmentation_refinement)
+
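+     # OCR branch: detect word boxes with DocTr, recognize the text, match words to their parent layout blocks
+     # and determine a reading order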
+     if cfg.OCR:
+         d_layout_text = dd.ImageLayoutService(det, to_image=True, crop_image=True)
+         pipe_component_list.append(d_layout_text)
+
+         d_text = dd.TextExtractionService(rec, extract_from_roi="WORD")
+         pipe_component_list.append(d_text)
+
+         match = dd.MatchingService(
+             parent_categories=cfg.WORD_MATCHING.PARENTAL_CATEGORIES,
+             child_categories=dd.LayoutType.word,
+             matching_rule=cfg.WORD_MATCHING.RULE,
+             threshold=cfg.WORD_MATCHING.IOU_THRESHOLD
+             if cfg.WORD_MATCHING.RULE in ["iou"]
+             else cfg.WORD_MATCHING.IOA_THRESHOLD,
+         )
+         pipe_component_list.append(match)
+         order = dd.TextOrderService(
+             text_container=dd.LayoutType.word,
+             floating_text_block_names=[dd.LayoutType.title, dd.LayoutType.text, dd.LayoutType.list],
+             text_block_names=[
+                 dd.LayoutType.title,
+                 dd.LayoutType.text,
+                 dd.LayoutType.list,
+                 dd.LayoutType.cell,
+                 dd.CellType.header,
+                 dd.CellType.body,
+             ],
+         )
+         pipe_component_list.append(order)
+
+     pipe = dd.DoctectionPipe(pipeline_component_list=pipe_component_list)
+
+     return pipe
+
+
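+ # gather the outputs shown in the UI: the visualized page, the text in reading order, the first table as HTML
+ # and the raw page dict as JSON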
+ def prepare_output(dp, add_table, add_ocr):
+     out = dp.as_dict()
+     out.pop("image")
+
+     layout_items = dp.items
+     if add_ocr:
+         layout_items.sort(key=lambda x: x.reading_order)
+     layout_items_str = ""
+     for item in layout_items:
+         layout_items_str += f"\n {item.layout_type}: {item.text}"
+     if add_table:
+         html_list = [table.html for table in dp.tables]
+         if html_list:
+             html = html_list[0]
+         else:
+             html = None
+     else:
+         html = None
+
+     return dp.viz(show_table_structure=False), layout_items_str, html, out
+
+
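+ # Gradio callback: build the pipeline according to the selected extractions and run it on the uploaded image
+ # or on the uploaded PDF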
+ def analyze_image(img, pdf, attributes):
+
+     # create an image object and pass it to the analyzer via a dataflow
+     add_table = _DETECTIONS[0] in attributes
+     add_ocr = _DETECTIONS[1] in attributes
+
+     analyzer = build_gradio_analyzer(add_table, add_table, add_ocr)
+
+     if img is not None:
+         image = dd.Image(file_name="input.png", location="")
+         image.image = img[:, :, ::-1]
+
+         df = DataFromList(lst=[image])
+         df = analyzer.analyze(dataset_dataflow=df)
+     elif pdf:
+         df = analyzer.analyze(path=pdf.name, max_datapoints=3)
+     else:
+         raise ValueError("Either an image or a PDF file must be provided")
+
+     df.reset_state()
+     df_iter = iter(df)
+
+     dp = next(df_iter)
+
+     return prepare_output(dp, add_table, add_ocr)
+
+
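+ # Gradio UI: document upload and settings on top, extraction results (contiguous text, table HTML, JSON and
+ # the annotated page image) below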
+ demo = gr.Blocks(css="scrollbar.css")
+
+ with demo:
+     with gr.Box():
+         gr.Markdown("<h1><center>deepdoctection - A Document AI Package</center></h1>")
+         gr.Markdown("<strong>deep</strong>doctection is a Python library that orchestrates document extraction"
+                     " and document layout analysis tasks using deep learning models. It does not implement models"
+                     " but enables you to build pipelines using widely recognized libraries for object detection,"
+                     " OCR and selected NLP tasks and provides an integrated framework for fine-tuning, evaluating"
+                     " and running models.\n This pipeline consists of a stack of models powered by <strong>Detectron2"
+                     "</strong> for layout analysis and table recognition and <strong>DocTr</strong> for OCR.")
+     with gr.Box():
+         gr.Markdown("<h2><center>Upload a document and choose settings</center></h2>")
+         with gr.Row():
+             with gr.Column():
+                 with gr.Tab("Image upload"):
+                     with gr.Column():
+                         inputs = gr.Image(type='numpy', label="Original Image")
+                 with gr.Tab("PDF upload (only first image will be processed) *"):
+                     with gr.Column():
+                         inputs_pdf = gr.File(label="PDF")
+                 gr.Markdown("<sup>* If an image is cached in the tab, remove it first</sup>")
+             with gr.Column():
+                 gr.Examples(
+                     examples=[path.join(getcwd(), "sample_1.jpg"), path.join(getcwd(), "sample_2.png")],
+                     inputs=inputs)
+                 gr.Examples(examples=[path.join(getcwd(), "sample_3.pdf")], inputs=inputs_pdf)
+
+         with gr.Row():
+             tok_input = gr.CheckboxGroup(
+                 _DETECTIONS, value=_DETECTIONS, label="Additional extractions", interactive=True)
+         with gr.Row():
+             btn = gr.Button("Run model", variant="primary")
+
+     with gr.Box():
+         gr.Markdown("<h2><center>Outputs</center></h2>")
+         with gr.Row():
+             with gr.Column():
+                 with gr.Box():
+                     gr.Markdown("<center><strong>Contiguous text</strong></center>")
+                     image_text = gr.Textbox()
+                 with gr.Box():
+                     gr.Markdown("<center><strong>Table</strong></center>")
+                     html = gr.HTML()
+                 with gr.Box():
+                     gr.Markdown("<center><strong>JSON</strong></center>")
+                     json = gr.JSON()
+             with gr.Column():
+                 with gr.Box():
+                     gr.Markdown("<center><strong>Layout detection</strong></center>")
+                     image_output = gr.Image(type="numpy", label="Output Image")
+
+     btn.click(fn=analyze_image, inputs=[inputs, inputs_pdf, tok_input], outputs=[image_output, image_text, html, json])
+
+ demo.launch()