breezedeus commited on
Commit
66658a5
1 Parent(s): 7aa3bf4

transfer from streamlit to gradio

Browse files
Files changed (4) hide show
  1. README.md +5 -3
  2. app.py +148 -77
  3. requirements.txt +188 -98
  4. streamlit_app.py +183 -0
README.md CHANGED
@@ -1,14 +1,16 @@
1
  ---
2
- title: CnOCR
3
  emoji: 🅞🅒🅡
4
  colorFrom: indigo
5
  colorTo: yellow
6
- sdk: streamlit
 
7
  app_file: app.py
8
  pinned: false
 
9
  ---
10
 
11
- # CnOCR
12
 
13
  [**CnOCR**](https://github.com/breezedeus/cnocr) is an **Optical Character Recognition (OCR)** toolkit for **Python 3**. It supports recognition of common characters in **English and numbers**, **Simplified Chinese**, **Traditional Chinese** (some models), and **vertical text** recognition. It comes with [**20+ well-trained models**](https://cnocr.readthedocs.io/zh/latest/models/) for different application scenarios and can be used directly after installation. Also, CnOCR provides simple training [commands](https://cnocr.readthedocs.io/zh/latest/train/) for users to train their own models. Welcome to join the WeChat contact group.
14
 
 
1
  ---
2
+ title: CnOCR Demo
3
  emoji: 🅞🅒🅡
4
  colorFrom: indigo
5
  colorTo: yellow
6
+ sdk: gradio
7
+ sdk_version: 3.44.4
8
  app_file: app.py
9
  pinned: false
10
+ license: apache-2.0
11
  ---
12
 
13
+ # CnOCR (Cn-OCR)
14
 
15
  [**CnOCR**](https://github.com/breezedeus/cnocr) is an **Optical Character Recognition (OCR)** toolkit for **Python 3**. It supports recognition of common characters in **English and numbers**, **Simplified Chinese**, **Traditional Chinese** (some models), and **vertical text** recognition. It comes with [**20+ well-trained models**](https://cnocr.readthedocs.io/zh/latest/models/) for different application scenarios and can be used directly after installation. Also, CnOCR provides simple training [commands](https://cnocr.readthedocs.io/zh/latest/train/) for users to train their own models. Welcome to join the WeChat contact group.
16
 
app.py CHANGED
@@ -1,5 +1,5 @@
1
  # coding: utf-8
2
- # Copyright (C) 2021, [Breezedeus](https://github.com/breezedeus).
3
  # Licensed to the Apache Software Foundation (ASF) under one
4
  # or more contributor license agreements. See the NOTICE file
5
  # distributed with this work for additional information
@@ -18,12 +18,10 @@
18
  # under the License.
19
 
20
  import os
21
- from collections import OrderedDict
22
 
 
23
  import cv2
24
  import numpy as np
25
- from PIL import Image
26
- import streamlit as st
27
  from cnstd.utils import pil_to_numpy, imsave
28
 
29
  from cnocr import CnOcr, DET_AVAILABLE_MODELS, REC_AVAILABLE_MODELS
@@ -31,7 +29,6 @@ from cnocr.utils import set_logger, draw_ocr_results, download
31
 
32
 
33
  logger = set_logger()
34
- st.set_page_config(layout="wide")
35
 
36
 
37
  def plot_for_debugging(rotated_img, one_out, box_score_thresh, crop_ncols, prefix_fp):
@@ -41,6 +38,8 @@ def plot_for_debugging(rotated_img, one_out, box_score_thresh, crop_ncols, prefi
41
  rotated_img = rotated_img.copy()
42
  crops = [info['cropped_img'] for info in one_out]
43
  print('%d boxes are found' % len(crops))
 
 
44
  ncols = crop_ncols
45
  nrows = math.ceil(len(crops) / ncols)
46
  fig, ax = plt.subplots(nrows=nrows, ncols=ncols)
@@ -64,10 +63,9 @@ def plot_for_debugging(rotated_img, one_out, box_score_thresh, crop_ncols, prefi
64
  print('boxes results are save to file %s' % result_fp)
65
 
66
 
67
- @st.cache_resource
68
  def get_ocr_model(det_model_name, rec_model_name, det_more_configs):
69
- det_model_name, det_model_backend = det_model_name
70
- rec_model_name, rec_model_backend = rec_model_name
71
  return CnOcr(
72
  det_model_name=det_model_name,
73
  det_model_backend=det_model_backend,
@@ -78,31 +76,33 @@ def get_ocr_model(det_model_name, rec_model_name, det_more_configs):
78
 
79
 
80
  def visualize_naive_result(img, det_model_name, std_out, box_score_thresh):
 
 
 
81
  img = pil_to_numpy(img).transpose((1, 2, 0)).astype(np.uint8)
82
 
83
- plot_for_debugging(img, std_out, box_score_thresh, 2, './streamlit-app')
84
- st.subheader('Detection Result')
85
- if det_model_name == 'default_det':
86
- st.warning('⚠️ Warning: "default_det" 检测模型不返回文本框位置!')
87
- cols = st.columns([1, 7, 1])
88
- cols[1].image('./streamlit-app-result.png')
 
 
 
 
89
 
90
- st.subheader('Recognition Result')
91
- cols = st.columns([1, 7, 1])
92
- cols[1].image('./streamlit-app-crops.png')
93
-
94
- _visualize_ocr(std_out)
95
 
96
 
97
  def _visualize_ocr(ocr_outs):
98
- st.empty()
99
- ocr_res = OrderedDict({'文本': []})
100
- ocr_res['得分'] = []
101
  for out in ocr_outs:
102
  # cropped_img = out['cropped_img'] # 检测出的文本框
103
- ocr_res['得分'].append(out['score'])
104
- ocr_res['文本'].append(out['text'])
105
- st.table(ocr_res)
106
 
107
 
108
  def visualize_result(img, ocr_outs):
@@ -113,55 +113,27 @@ def visualize_result(img, ocr_outs):
113
  os.makedirs(os.path.dirname(font_path), exist_ok=True)
114
  download(url, path=font_path, overwrite=True)
115
  draw_ocr_results(img, ocr_outs, out_draw_fp, font_path)
116
- st.image(out_draw_fp)
117
-
118
-
119
- def main():
120
- st.sidebar.header('模型设置')
121
- det_models = list(DET_AVAILABLE_MODELS.all_models())
122
- det_models.append(('naive_det', 'onnx'))
123
- det_models.sort()
124
- det_model_name = st.sidebar.selectbox(
125
- '选择检测模型', det_models, index=det_models.index(('ch_PP-OCRv3_det', 'onnx'))
126
- )
127
-
128
- all_models = list(REC_AVAILABLE_MODELS.all_models())
129
- all_models.sort()
130
- idx = all_models.index(('densenet_lite_136-fc', 'onnx'))
131
- rec_model_name = st.sidebar.selectbox('选择识别模型', all_models, index=idx)
132
-
133
- st.sidebar.subheader('检测参数')
134
- rotated_bbox = st.sidebar.checkbox('是否检测带角度文本框', value=True)
135
- use_angle_clf = st.sidebar.checkbox('是否使用角度预测模型校正文本框', value=False)
136
- new_size = st.sidebar.slider(
137
- 'resize 后图片(长边)大小', min_value=124, max_value=4096, value=768
138
- )
139
- box_score_thresh = st.sidebar.slider(
140
- '得分阈值(低于阈值的结果会被过滤掉)', min_value=0.05, max_value=0.95, value=0.3
141
- )
142
- min_box_size = st.sidebar.slider(
143
- '框大小阈值(更小的文本框会被过滤掉)', min_value=4, max_value=50, value=10
144
- )
145
- # std = get_std_model(det_model_name, rotated_bbox, use_angle_clf)
146
-
147
- # st.sidebar.markdown("""---""")
148
- # st.sidebar.header('CnOcr 设置')
149
  det_more_configs = dict(rotated_bbox=rotated_bbox, use_angle_clf=use_angle_clf)
150
  ocr = get_ocr_model(det_model_name, rec_model_name, det_more_configs)
151
 
152
- st.markdown('# 开源Python OCR工具 ' '[CnOCR](https://github.com/breezedeus/cnocr)')
153
- st.markdown('> 详细说明参见:[CnOCR 文档](https://cnocr.readthedocs.io/) ;'
154
- '欢迎加入 [交流群](https://www.breezedeus.com/join-group) ;'
155
- '作者:[breezedeus](https://www.breezedeus.com), [Github](https://github.com/breezedeus) 。')
156
- st.markdown('')
157
- st.subheader('选择待检测图片')
158
- content_file = st.file_uploader('', type=["png", "jpg", "jpeg", "webp"])
159
- if content_file is None:
160
- st.stop()
161
-
162
- try:
163
- img = Image.open(content_file).convert('RGB')
164
-
165
  ocr_out = ocr.ocr(
166
  img,
167
  return_cropped_image=True,
@@ -170,13 +142,112 @@ def main():
170
  box_score_thresh=box_score_thresh,
171
  min_box_size=min_box_size,
172
  )
173
- if det_model_name[0] == 'naive_det':
174
- visualize_naive_result(img, det_model_name[0], ocr_out, box_score_thresh)
175
- else:
176
- visualize_result(img, ocr_out)
177
 
178
- except Exception as e:
179
- st.error(e)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
180
 
181
 
182
  if __name__ == '__main__':
 
1
  # coding: utf-8
2
+ # Copyright (C) 2023, [Breezedeus](https://github.com/breezedeus).
3
  # Licensed to the Apache Software Foundation (ASF) under one
4
  # or more contributor license agreements. See the NOTICE file
5
  # distributed with this work for additional information
 
18
  # under the License.
19
 
20
  import os
 
21
 
22
+ import gradio as gr
23
  import cv2
24
  import numpy as np
 
 
25
  from cnstd.utils import pil_to_numpy, imsave
26
 
27
  from cnocr import CnOcr, DET_AVAILABLE_MODELS, REC_AVAILABLE_MODELS
 
29
 
30
 
31
  logger = set_logger()
 
32
 
33
 
34
  def plot_for_debugging(rotated_img, one_out, box_score_thresh, crop_ncols, prefix_fp):
 
38
  rotated_img = rotated_img.copy()
39
  crops = [info['cropped_img'] for info in one_out]
40
  print('%d boxes are found' % len(crops))
41
+ if len(crops) < 1:
42
+ return
43
  ncols = crop_ncols
44
  nrows = math.ceil(len(crops) / ncols)
45
  fig, ax = plt.subplots(nrows=nrows, ncols=ncols)
 
63
  print('boxes results are save to file %s' % result_fp)
64
 
65
 
 
66
  def get_ocr_model(det_model_name, rec_model_name, det_more_configs):
67
+ det_model_name, det_model_backend = det_model_name.split('::')
68
+ rec_model_name, rec_model_backend = rec_model_name.split('::')
69
  return CnOcr(
70
  det_model_name=det_model_name,
71
  det_model_backend=det_model_backend,
 
76
 
77
 
78
  def visualize_naive_result(img, det_model_name, std_out, box_score_thresh):
79
+ if len(std_out) < 1:
80
+ # gr.Warning(f'未检测到文本!')
81
+ return []
82
  img = pil_to_numpy(img).transpose((1, 2, 0)).astype(np.uint8)
83
 
84
+ # plot_for_debugging(img, std_out, box_score_thresh, 2, './streamlit-app')
85
+ # gr.Markdown('## Detection Result')
86
+ # if det_model_name == 'naive_det':
87
+ # gr.Warning('⚠️ Warning: "naive_det" 检测模型不返回文本框位置!')
88
+ # cols = st.columns([1, 7, 1])
89
+ # cols[1].image('./streamlit-app-result.png')
90
+ #
91
+ # st.subheader('Recognition Result')
92
+ # cols = st.columns([1, 7, 1])
93
+ # cols[1].image('./streamlit-app-crops.png')
94
 
95
+ return _visualize_ocr(std_out)
 
 
 
 
96
 
97
 
98
  def _visualize_ocr(ocr_outs):
99
+ if len(ocr_outs) < 1:
100
+ return
101
+ ocr_res = []
102
  for out in ocr_outs:
103
  # cropped_img = out['cropped_img'] # 检测出的文本框
104
+ ocr_res.append([out['score'], out['text']])
105
+ return ocr_res
 
106
 
107
 
108
  def visualize_result(img, ocr_outs):
 
113
  os.makedirs(os.path.dirname(font_path), exist_ok=True)
114
  download(url, path=font_path, overwrite=True)
115
  draw_ocr_results(img, ocr_outs, out_draw_fp, font_path)
116
+ return out_draw_fp
117
+
118
+
119
+ def recognize(
120
+ det_model_name,
121
+ is_single_line,
122
+ rec_model_name,
123
+ rotated_bbox,
124
+ use_angle_clf,
125
+ new_size,
126
+ box_score_thresh,
127
+ min_box_size,
128
+ image_file,
129
+ ):
130
+ img = image_file.convert('RGB')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
  det_more_configs = dict(rotated_bbox=rotated_bbox, use_angle_clf=use_angle_clf)
132
  ocr = get_ocr_model(det_model_name, rec_model_name, det_more_configs)
133
 
134
+ if is_single_line:
135
+ ocr_out = [ocr.ocr_for_single_line(np.array(img))]
136
+ else:
 
 
 
 
 
 
 
 
 
 
137
  ocr_out = ocr.ocr(
138
  img,
139
  return_cropped_image=True,
 
142
  box_score_thresh=box_score_thresh,
143
  min_box_size=min_box_size,
144
  )
 
 
 
 
145
 
146
+ det_model_name, det_model_backend = det_model_name.split('::')
147
+ if is_single_line or det_model_name == 'naive_det':
148
+ out_texts = visualize_naive_result(
149
+ img, det_model_name, ocr_out, box_score_thresh
150
+ )
151
+ if is_single_line:
152
+ return [
153
+ gr.update(visible=False),
154
+ gr.update(visible=False),
155
+ gr.update(value=out_texts, visible=True),
156
+ ]
157
+ return [
158
+ gr.update(visible=False),
159
+ gr.update(visible=True),
160
+ gr.update(value=out_texts, visible=True),
161
+ ]
162
+ else:
163
+ out_img_path = visualize_result(img, ocr_out)
164
+ return [
165
+ gr.update(value=out_img_path, visible=True),
166
+ gr.update(visible=False),
167
+ gr.update(visible=False),
168
+ ]
169
+
170
+
171
+ def main():
172
+ det_models = list(DET_AVAILABLE_MODELS.all_models())
173
+ det_models.append(('naive_det', 'onnx'))
174
+ det_models.sort()
175
+ det_models = [f'{m}::{b}' for m, b in det_models]
176
+
177
+ all_models = list(REC_AVAILABLE_MODELS.all_models())
178
+ all_models.sort()
179
+ all_models = [f'{m}::{b}' for m, b in all_models]
180
+
181
+ title = '开源Python OCR工具:'
182
+ desc = (
183
+ '<p style="text-align: center">详细说明参见:<a href="https://github.com/breezedeus/CnOCR" target="_blank">Github</a>;'
184
+ '<a href="https://cnocr.readthedocs.io" target="_blank">在线文档</a>;'
185
+ '欢迎加入 <a href="https://www.breezedeus.com/join-group" target="_blank">交流群</a>;'
186
+ '作者:<a href="https://www.breezedeus.com" target="_blank">Breezedeus</a> ,'
187
+ '<a href="https://github.com/breezedeus" target="_blank">Github</a> 。</p>'
188
+ )
189
+
190
+ with gr.Blocks() as demo:
191
+ gr.Markdown(
192
+ f'<h1 style="text-align: center; margin-bottom: 1rem;">{title} <a href="https://github.com/breezedeus/cnocr" target="_blank">CnOCR</a></h1>'
193
+ )
194
+ gr.Markdown(desc)
195
+ with gr.Row(equal_height=False):
196
+ with gr.Column(min_width=200, variant='panel', scale=1):
197
+ gr.Markdown('### 模型设置')
198
+ det_model_name = gr.Dropdown(
199
+ label='选择检测模型', choices=det_models, value='ch_PP-OCRv3_det::onnx',
200
+ )
201
+ is_single_line = gr.Checkbox(label='单行文字模式(不使用检测模型)', value=False)
202
+
203
+ rec_model_name = gr.Dropdown(
204
+ label='选择识别模型',
205
+ choices=all_models,
206
+ value='densenet_lite_136-fc::onnx',
207
+ )
208
+
209
+ gr.Markdown('### 检测参数')
210
+ rotated_bbox = gr.Checkbox(label='检测带角度文本框', value=True)
211
+ use_angle_clf = gr.Checkbox(label='使用角度预测模型校正文本框', value=False)
212
+ new_size = gr.Slider(
213
+ label='resize 后图片(长边)大小', minimum=124, maximum=4096, value=768
214
+ )
215
+ box_score_thresh = gr.Slider(
216
+ label='得分阈值(低于阈值的结果会被过滤掉)', minimum=0.05, maximum=0.95, value=0.3
217
+ )
218
+ min_box_size = gr.Slider(
219
+ label='框大小阈值(更小的文本框会被过滤掉)', minimum=4, maximum=50, value=10
220
+ )
221
+
222
+ with gr.Column(scale=3, variant='compact'):
223
+ gr.Markdown('### 选择待检测图片')
224
+ image_file = gr.Image(label='', type="pil", image_mode='RGB')
225
+ sub_btn = gr.Button("Submit", variant="primary")
226
+ out_image = gr.Image(label='识别结果', interactive=False, visible=False)
227
+ naive_warn = gr.Markdown(
228
+ '**⚠️ Warning**: "naive_det" 检测模型不返回文本框位置!', visible=False
229
+ )
230
+ out_texts = gr.Dataframe(
231
+ headers=['得分', '文本'], label='识别结果', interactive=False, visible=False
232
+ )
233
+ sub_btn.click(
234
+ recognize,
235
+ inputs=[
236
+ det_model_name,
237
+ is_single_line,
238
+ rec_model_name,
239
+ rotated_bbox,
240
+ use_angle_clf,
241
+ new_size,
242
+ box_score_thresh,
243
+ min_box_size,
244
+ image_file,
245
+ ],
246
+ outputs=[out_image, naive_warn, out_texts],
247
+ )
248
+
249
+ demo.queue(concurrency_count=4)
250
+ demo.launch()
251
 
252
 
253
  if __name__ == '__main__':
requirements.txt CHANGED
@@ -1,155 +1,237 @@
1
  #
2
- # This file is autogenerated by pip-compile with python 3.8
3
- # To update, run:
4
  #
5
- # pip-compile --output-file=requirements.txt requirements.in
6
  #
7
- --index-url https://pypi.doubanio.com/simple
8
  --extra-index-url https://pypi.org/simple
9
 
10
- absl-py==0.13.0
11
- # via tensorboard
12
- aiohttp==3.7.4.post0
13
  # via fsspec
14
- async-timeout==3.0.1
15
  # via aiohttp
16
- attrs==21.2.0
 
 
 
 
 
 
17
  # via aiohttp
18
- cachetools==4.2.2
19
- # via google-auth
20
- certifi==2020.4.5.1
21
- # via requests
22
- chardet==3.0.4
23
  # via
24
  # aiohttp
25
  # requests
26
- click==8.0.1
27
  # via
28
  # -r requirements.in
29
  # cnstd
30
- cnstd==1.2
 
31
  # via -r requirements.in
 
 
 
 
32
  cycler==0.11.0
33
  # via matplotlib
34
- flatbuffers==2.0
 
 
 
 
 
 
35
  # via onnxruntime
36
- fonttools==4.34.4
37
  # via matplotlib
38
- fsspec[http]==2021.7.0
39
- # via pytorch-lightning
40
- google-auth==1.35.0
 
 
41
  # via
42
- # google-auth-oauthlib
43
- # tensorboard
44
- google-auth-oauthlib==0.4.5
45
- # via tensorboard
46
- grpcio==1.39.0
47
- # via tensorboard
48
- idna==2.9
 
 
 
 
49
  # via
50
  # requests
51
  # yarl
52
- kiwisolver==1.4.3
 
 
53
  # via matplotlib
54
- markdown==3.3.4
55
- # via tensorboard
56
- matplotlib==3.5.2
57
- # via cnstd
58
- multidict==5.1.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
  # via
60
  # aiohttp
61
  # yarl
62
- numpy==1.22.3
 
 
 
 
63
  # via
64
  # -r requirements.in
 
65
  # cnstd
 
 
66
  # matplotlib
67
  # onnx
68
  # onnxruntime
69
  # opencv-python
 
 
70
  # pytorch-lightning
 
 
 
 
71
  # scipy
72
- # tensorboard
 
 
73
  # torchmetrics
74
  # torchvision
75
- oauthlib==3.1.1
76
- # via requests-oauthlib
77
- onnx==1.13.0
78
  # via
79
  # -r requirements.in
80
  # cnstd
81
- onnxruntime==1.14.0
82
  # via
83
  # -r requirements.in
84
  # cnstd
85
- opencv-python==4.6.0.66
86
  # via cnstd
87
- packaging==21.0
 
 
 
 
88
  # via
 
 
89
  # matplotlib
 
90
  # pytorch-lightning
 
91
  # torchmetrics
92
- pillow==9.1.1
 
 
 
 
 
 
93
  # via
94
  # -r requirements.in
95
  # cnstd
 
96
  # matplotlib
 
97
  # torchvision
98
  polygon3==3.0.9.1
99
  # via cnstd
100
- protobuf
101
  # via
102
  # onnx
103
  # onnxruntime
104
- # tensorboard
105
- pyasn1==0.4.8
106
- # via
107
- # pyasn1-modules
108
- # rsa
109
- pyasn1-modules==0.2.8
110
- # via google-auth
111
- pyclipper==1.3.0.post3
112
  # via cnstd
113
- pydeprecate==0.3.1
114
- # via pytorch-lightning
115
- pyparsing==2.4.7
116
  # via
117
  # matplotlib
118
- # packaging
119
- python-dateutil==2.8.2
120
- # via matplotlib
121
- pytorch-lightning==1.6.3
122
  # via
123
  # -r requirements.in
124
  # cnstd
125
- pyyaml==5.4.1
126
- # via pytorch-lightning
127
- requests==2.23.0
 
 
 
 
 
 
 
 
 
 
 
128
  # via
129
  # fsspec
130
- # requests-oauthlib
131
- # tensorboard
132
- requests-oauthlib==1.3.0
133
- # via google-auth-oauthlib
134
- rsa==4.7.2
135
- # via google-auth
136
- scipy==1.8.1
 
 
 
 
 
 
 
137
  # via cnstd
138
- shapely==1.8.2
 
 
 
 
139
  # via cnstd
140
- six==1.14.0
141
  # via
142
- # absl-py
143
- # google-auth
144
- # grpcio
145
- # protobuf
146
  # python-dateutil
147
- tensorboard==2.6.0
148
- # via pytorch-lightning
149
- tensorboard-data-server==0.6.1
150
- # via tensorboard
151
- tensorboard-plugin-wit==1.8.0
152
- # via tensorboard
 
 
 
 
153
  torch==2.0.1
154
  # via
155
  # -r requirements.in
@@ -157,37 +239,45 @@ torch==2.0.1
157
  # pytorch-lightning
158
  # torchmetrics
159
  # torchvision
160
- torchmetrics==0.11.1
161
- # via pytorch-lightning
 
 
162
  torchvision==0.15.2
163
  # via
164
  # -r requirements.in
165
  # cnstd
166
- tqdm==4.64.0
167
  # via
168
  # -r requirements.in
169
  # cnstd
 
170
  # pytorch-lightning
171
- typing-extensions==4.2.0
172
  # via
173
- # aiohttp
 
174
  # onnx
175
  # pytorch-lightning
 
176
  # torch
177
- unidecode==1.3.4
 
 
 
178
  # via cnstd
179
- urllib3==1.25.9
180
- # via requests
181
- werkzeug==2.0.1
182
- # via tensorboard
183
- wheel==0.37.0
184
- # via tensorboard
185
- yarl==1.6.3
186
  # via aiohttp
 
 
 
 
187
 
188
  # The following packages are considered to be unsafe in a requirements file:
189
  # setuptools
190
-
191
- # for streamlit.io demo
192
- cnocr==2.2.2.3
193
- streamlit
 
1
  #
2
+ # This file is autogenerated by pip-compile with Python 3.9
3
+ # by the following command:
4
  #
5
+ # pip-compile --extra-index-url=https://pypi.tuna.tsinghua.edu.cn/simple --index-url=https://mirrors.aliyun.com/pypi/simple --output-file=requirements.txt requirements.in
6
  #
 
7
  --extra-index-url https://pypi.org/simple
8
 
9
+ aiohttp==3.8.4
 
 
10
  # via fsspec
11
+ aiosignal==1.3.1
12
  # via aiohttp
13
+ albumentations==1.3.1
14
+ # via -r requirements.in
15
+ appdirs==1.4.4
16
+ # via wandb
17
+ async-timeout==4.0.2
18
+ # via aiohttp
19
+ attrs==23.1.0
20
  # via aiohttp
21
+ certifi==2023.5.7
22
+ # via
23
+ # requests
24
+ # sentry-sdk
25
+ charset-normalizer==3.1.0
26
  # via
27
  # aiohttp
28
  # requests
29
+ click==8.1.3
30
  # via
31
  # -r requirements.in
32
  # cnstd
33
+ # wandb
34
+ cnstd>=1.2.3.4
35
  # via -r requirements.in
36
+ coloredlogs==15.0.1
37
+ # via onnxruntime
38
+ contourpy==1.1.0
39
+ # via matplotlib
40
  cycler==0.11.0
41
  # via matplotlib
42
+ docker-pycreds==0.4.0
43
+ # via wandb
44
+ filelock==3.12.2
45
+ # via
46
+ # huggingface-hub
47
+ # torch
48
+ flatbuffers==23.5.26
49
  # via onnxruntime
50
+ fonttools==4.40.0
51
  # via matplotlib
52
+ frozenlist==1.3.3
53
+ # via
54
+ # aiohttp
55
+ # aiosignal
56
+ fsspec[http]==2023.6.0
57
  # via
58
+ # huggingface-hub
59
+ # pytorch-lightning
60
+ gitdb==4.0.10
61
+ # via gitpython
62
+ gitpython==3.1.34
63
+ # via wandb
64
+ huggingface-hub==0.15.1
65
+ # via cnstd
66
+ humanfriendly==10.0
67
+ # via coloredlogs
68
+ idna==3.4
69
  # via
70
  # requests
71
  # yarl
72
+ imageio==2.31.3
73
+ # via scikit-image
74
+ importlib-resources==5.12.0
75
  # via matplotlib
76
+ jinja2==3.1.2
77
+ # via torch
78
+ joblib==1.3.2
79
+ # via scikit-learn
80
+ kiwisolver==1.4.4
81
+ # via matplotlib
82
+ lazy-loader==0.3
83
+ # via scikit-image
84
+ lightning-utilities==0.9.0
85
+ # via pytorch-lightning
86
+ markupsafe==2.1.3
87
+ # via jinja2
88
+ matplotlib==3.7.1
89
+ # via
90
+ # cnstd
91
+ # seaborn
92
+ mpmath==1.3.0
93
+ # via sympy
94
+ multidict==6.0.4
95
  # via
96
  # aiohttp
97
  # yarl
98
+ networkx==3.1
99
+ # via
100
+ # scikit-image
101
+ # torch
102
+ numpy==1.25.0
103
  # via
104
  # -r requirements.in
105
+ # albumentations
106
  # cnstd
107
+ # contourpy
108
+ # imageio
109
  # matplotlib
110
  # onnx
111
  # onnxruntime
112
  # opencv-python
113
+ # opencv-python-headless
114
+ # pandas
115
  # pytorch-lightning
116
+ # pywavelets
117
+ # qudida
118
+ # scikit-image
119
+ # scikit-learn
120
  # scipy
121
+ # seaborn
122
+ # shapely
123
+ # tifffile
124
  # torchmetrics
125
  # torchvision
126
+ onnx==1.14.0
 
 
127
  # via
128
  # -r requirements.in
129
  # cnstd
130
+ onnxruntime==1.15.1
131
  # via
132
  # -r requirements.in
133
  # cnstd
134
+ opencv-python==4.7.0.72
135
  # via cnstd
136
+ opencv-python-headless==4.8.0.76
137
+ # via
138
+ # albumentations
139
+ # qudida
140
+ packaging==23.1
141
  # via
142
+ # huggingface-hub
143
+ # lightning-utilities
144
  # matplotlib
145
+ # onnxruntime
146
  # pytorch-lightning
147
+ # scikit-image
148
  # torchmetrics
149
+ pandas==2.0.3
150
+ # via
151
+ # cnstd
152
+ # seaborn
153
+ pathtools==0.1.2
154
+ # via wandb
155
+ pillow==9.5.0
156
  # via
157
  # -r requirements.in
158
  # cnstd
159
+ # imageio
160
  # matplotlib
161
+ # scikit-image
162
  # torchvision
163
  polygon3==3.0.9.1
164
  # via cnstd
165
+ protobuf==4.23.3
166
  # via
167
  # onnx
168
  # onnxruntime
169
+ # wandb
170
+ psutil==5.9.5
171
+ # via wandb
172
+ pyclipper==1.3.0.post4
 
 
 
 
173
  # via cnstd
174
+ pyparsing==3.1.0
175
+ # via matplotlib
176
+ python-dateutil==2.8.2
177
  # via
178
  # matplotlib
179
+ # pandas
180
+ pytorch-lightning==2.0.8
 
 
181
  # via
182
  # -r requirements.in
183
  # cnstd
184
+ pytz==2023.3
185
+ # via pandas
186
+ pywavelets==1.4.1
187
+ # via scikit-image
188
+ pyyaml==6.0
189
+ # via
190
+ # albumentations
191
+ # cnstd
192
+ # huggingface-hub
193
+ # pytorch-lightning
194
+ # wandb
195
+ qudida==0.0.4
196
+ # via albumentations
197
+ requests==2.31.0
198
  # via
199
  # fsspec
200
+ # huggingface-hub
201
+ # torchvision
202
+ # wandb
203
+ scikit-image==0.21.0
204
+ # via albumentations
205
+ scikit-learn==1.3.0
206
+ # via qudida
207
+ scipy==1.11.1
208
+ # via
209
+ # albumentations
210
+ # cnstd
211
+ # scikit-image
212
+ # scikit-learn
213
+ seaborn==0.12.2
214
  # via cnstd
215
+ sentry-sdk==1.30.0
216
+ # via wandb
217
+ setproctitle==1.3.2
218
+ # via wandb
219
+ shapely==2.0.1
220
  # via cnstd
221
+ six==1.16.0
222
  # via
223
+ # docker-pycreds
 
 
 
224
  # python-dateutil
225
+ smmap==5.0.0
226
+ # via gitdb
227
+ sympy==1.12
228
+ # via
229
+ # onnxruntime
230
+ # torch
231
+ threadpoolctl==3.2.0
232
+ # via scikit-learn
233
+ tifffile==2023.8.30
234
+ # via scikit-image
235
  torch==2.0.1
236
  # via
237
  # -r requirements.in
 
239
  # pytorch-lightning
240
  # torchmetrics
241
  # torchvision
242
+ torchmetrics==0.11.4
243
+ # via
244
+ # -r requirements.in
245
+ # pytorch-lightning
246
  torchvision==0.15.2
247
  # via
248
  # -r requirements.in
249
  # cnstd
250
+ tqdm==4.65.0
251
  # via
252
  # -r requirements.in
253
  # cnstd
254
+ # huggingface-hub
255
  # pytorch-lightning
256
+ typing-extensions==4.7.0
257
  # via
258
+ # huggingface-hub
259
+ # lightning-utilities
260
  # onnx
261
  # pytorch-lightning
262
+ # qudida
263
  # torch
264
+ # wandb
265
+ tzdata==2023.3
266
+ # via pandas
267
+ unidecode==1.3.6
268
  # via cnstd
269
+ urllib3==2.0.3
270
+ # via
271
+ # requests
272
+ # sentry-sdk
273
+ wandb==0.15.10
274
+ # via -r requirements.in
275
+ yarl==1.9.2
276
  # via aiohttp
277
+ zipp==3.15.0
278
+ # via importlib-resources
279
+
280
+ cnocr==2.2.4
281
 
282
  # The following packages are considered to be unsafe in a requirements file:
283
  # setuptools
 
 
 
 
streamlit_app.py ADDED
@@ -0,0 +1,183 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding: utf-8
2
+ # Copyright (C) 2021, [Breezedeus](https://github.com/breezedeus).
3
+ # Licensed to the Apache Software Foundation (ASF) under one
4
+ # or more contributor license agreements. See the NOTICE file
5
+ # distributed with this work for additional information
6
+ # regarding copyright ownership. The ASF licenses this file
7
+ # to you under the Apache License, Version 2.0 (the
8
+ # "License"); you may not use this file except in compliance
9
+ # with the License. You may obtain a copy of the License at
10
+ #
11
+ # http://www.apache.org/licenses/LICENSE-2.0
12
+ #
13
+ # Unless required by applicable law or agreed to in writing,
14
+ # software distributed under the License is distributed on an
15
+ # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16
+ # KIND, either express or implied. See the License for the
17
+ # specific language governing permissions and limitations
18
+ # under the License.
19
+
20
+ import os
21
+ from collections import OrderedDict
22
+
23
+ import cv2
24
+ import numpy as np
25
+ from PIL import Image
26
+ import streamlit as st
27
+ from cnstd.utils import pil_to_numpy, imsave
28
+
29
+ from cnocr import CnOcr, DET_AVAILABLE_MODELS, REC_AVAILABLE_MODELS
30
+ from cnocr.utils import set_logger, draw_ocr_results, download
31
+
32
+
33
+ logger = set_logger()
34
+ st.set_page_config(layout="wide")
35
+
36
+
37
+ def plot_for_debugging(rotated_img, one_out, box_score_thresh, crop_ncols, prefix_fp):
38
+ import matplotlib.pyplot as plt
39
+ import math
40
+
41
+ rotated_img = rotated_img.copy()
42
+ crops = [info['cropped_img'] for info in one_out]
43
+ print('%d boxes are found' % len(crops))
44
+ ncols = crop_ncols
45
+ nrows = math.ceil(len(crops) / ncols)
46
+ fig, ax = plt.subplots(nrows=nrows, ncols=ncols)
47
+ for i, axi in enumerate(ax.flat):
48
+ if i >= len(crops):
49
+ break
50
+ axi.imshow(crops[i])
51
+ crop_fp = '%s-crops.png' % prefix_fp
52
+ plt.savefig(crop_fp)
53
+ print('cropped results are save to file %s' % crop_fp)
54
+
55
+ for info in one_out:
56
+ box, score = info.get('position'), info['score']
57
+ if score < box_score_thresh: # score < 0.5
58
+ continue
59
+ if box is not None:
60
+ box = box.astype(int).reshape(-1, 2)
61
+ cv2.polylines(rotated_img, [box], True, color=(255, 0, 0), thickness=2)
62
+ result_fp = '%s-result.png' % prefix_fp
63
+ imsave(rotated_img, result_fp, normalized=False)
64
+ print('boxes results are save to file %s' % result_fp)
65
+
66
+
67
+ @st.cache_resource
68
+ def get_ocr_model(det_model_name, rec_model_name, det_more_configs):
69
+ det_model_name, det_model_backend = det_model_name
70
+ rec_model_name, rec_model_backend = rec_model_name
71
+ return CnOcr(
72
+ det_model_name=det_model_name,
73
+ det_model_backend=det_model_backend,
74
+ rec_model_name=rec_model_name,
75
+ rec_model_backend=rec_model_backend,
76
+ det_more_configs=det_more_configs,
77
+ )
78
+
79
+
80
+ def visualize_naive_result(img, det_model_name, std_out, box_score_thresh):
81
+ img = pil_to_numpy(img).transpose((1, 2, 0)).astype(np.uint8)
82
+
83
+ plot_for_debugging(img, std_out, box_score_thresh, 2, './streamlit-app')
84
+ st.subheader('Detection Result')
85
+ if det_model_name == 'default_det':
86
+ st.warning('⚠️ Warning: "default_det" 检测模型不返回文本框位置!')
87
+ cols = st.columns([1, 7, 1])
88
+ cols[1].image('./streamlit-app-result.png')
89
+
90
+ st.subheader('Recognition Result')
91
+ cols = st.columns([1, 7, 1])
92
+ cols[1].image('./streamlit-app-crops.png')
93
+
94
+ _visualize_ocr(std_out)
95
+
96
+
97
+ def _visualize_ocr(ocr_outs):
98
+ st.empty()
99
+ ocr_res = OrderedDict({'文本': []})
100
+ ocr_res['得分'] = []
101
+ for out in ocr_outs:
102
+ # cropped_img = out['cropped_img'] # 检测出的文本框
103
+ ocr_res['得分'].append(out['score'])
104
+ ocr_res['文本'].append(out['text'])
105
+ st.table(ocr_res)
106
+
107
+
108
+ def visualize_result(img, ocr_outs):
109
+ out_draw_fp = './streamlit-app-det-result.png'
110
+ font_path = 'docs/fonts/simfang.ttf'
111
+ if not os.path.exists(font_path):
112
+ url = 'https://huggingface.co/datasets/breezedeus/cnocr-wx-qr-code/resolve/main/fonts/simfang.ttf'
113
+ os.makedirs(os.path.dirname(font_path), exist_ok=True)
114
+ download(url, path=font_path, overwrite=True)
115
+ draw_ocr_results(img, ocr_outs, out_draw_fp, font_path)
116
+ st.image(out_draw_fp)
117
+
118
+
119
+ def main():
120
+ st.sidebar.header('模型设置')
121
+ det_models = list(DET_AVAILABLE_MODELS.all_models())
122
+ det_models.append(('naive_det', 'onnx'))
123
+ det_models.sort()
124
+ det_model_name = st.sidebar.selectbox(
125
+ '选择检测模型', det_models, index=det_models.index(('ch_PP-OCRv3_det', 'onnx'))
126
+ )
127
+
128
+ all_models = list(REC_AVAILABLE_MODELS.all_models())
129
+ all_models.sort()
130
+ idx = all_models.index(('densenet_lite_136-fc', 'onnx'))
131
+ rec_model_name = st.sidebar.selectbox('选择识别模型', all_models, index=idx)
132
+
133
+ st.sidebar.subheader('检测参数')
134
+ rotated_bbox = st.sidebar.checkbox('是否检测带角度文本框', value=True)
135
+ use_angle_clf = st.sidebar.checkbox('是否使用角度预测模型校正文本框', value=False)
136
+ new_size = st.sidebar.slider(
137
+ 'resize 后图片(长边)大小', min_value=124, max_value=4096, value=768
138
+ )
139
+ box_score_thresh = st.sidebar.slider(
140
+ '得分阈值(低于阈值的结果会被过滤掉)', min_value=0.05, max_value=0.95, value=0.3
141
+ )
142
+ min_box_size = st.sidebar.slider(
143
+ '框大小阈值(更小的文本框会被过滤掉)', min_value=4, max_value=50, value=10
144
+ )
145
+ # std = get_std_model(det_model_name, rotated_bbox, use_angle_clf)
146
+
147
+ # st.sidebar.markdown("""---""")
148
+ # st.sidebar.header('CnOcr 设置')
149
+ det_more_configs = dict(rotated_bbox=rotated_bbox, use_angle_clf=use_angle_clf)
150
+ ocr = get_ocr_model(det_model_name, rec_model_name, det_more_configs)
151
+
152
+ st.markdown('# 开源Python OCR工具 ' '[CnOCR](https://github.com/breezedeus/cnocr)')
153
+ st.markdown('> 详细说明参见:[CnOCR 文档](https://cnocr.readthedocs.io/) ;'
154
+ '欢迎加入 [交流群](https://www.breezedeus.com/join-group) ;'
155
+ '作者:[breezedeus](https://www.breezedeus.com), [Github](https://github.com/breezedeus) 。')
156
+ st.markdown('')
157
+ st.subheader('选择待检测图片')
158
+ content_file = st.file_uploader('', type=["png", "jpg", "jpeg", "webp"])
159
+ if content_file is None:
160
+ st.stop()
161
+
162
+ try:
163
+ img = Image.open(content_file).convert('RGB')
164
+
165
+ ocr_out = ocr.ocr(
166
+ img,
167
+ return_cropped_image=True,
168
+ resized_shape=new_size,
169
+ preserve_aspect_ratio=True,
170
+ box_score_thresh=box_score_thresh,
171
+ min_box_size=min_box_size,
172
+ )
173
+ if det_model_name[0] == 'naive_det':
174
+ visualize_naive_result(img, det_model_name[0], ocr_out, box_score_thresh)
175
+ else:
176
+ visualize_result(img, ocr_out)
177
+
178
+ except Exception as e:
179
+ st.error(e)
180
+
181
+
182
+ if __name__ == '__main__':
183
+ main()