narugo1992 commited on
Commit
32ef351
1 Parent(s): 582519c

dev(narugo): add new models

Browse files
Files changed (3) hide show
  1. aicheck.py +42 -0
  2. app.py +36 -0
  3. rating.py +42 -0
aicheck.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ from functools import lru_cache
4
+ from typing import Mapping, List
5
+
6
+ from huggingface_hub import HfFileSystem
7
+ from huggingface_hub import hf_hub_download
8
+ from imgutils.data import ImageTyping, load_image
9
+ from natsort import natsorted
10
+
11
+ from onnx_ import _open_onnx_model
12
+ from preprocess import _img_encode
13
+
14
+ hfs = HfFileSystem()
15
+
16
+ _REPO = 'deepghs/anime_ai_check'
17
+ _AICHECK_MODELS = natsorted([
18
+ os.path.dirname(os.path.relpath(file, _REPO))
19
+ for file in hfs.glob(f'{_REPO}/*/model.onnx')
20
+ ])
21
+ _DEFAULT_AICHECK_MODEL = 'mobilenetv3_sce_dist'
22
+
23
+
24
+ @lru_cache()
25
+ def _open_anime_aicheck_model(model_name):
26
+ return _open_onnx_model(hf_hub_download(_REPO, f'{model_name}/model.onnx'))
27
+
28
+
29
+ @lru_cache()
30
+ def _get_tags(model_name) -> List[str]:
31
+ with open(hf_hub_download(_REPO, f'{model_name}/meta.json'), 'r') as f:
32
+ return json.load(f)['labels']
33
+
34
+
35
+ def _gr_aicheck(image: ImageTyping, model_name: str, size=384) -> Mapping[str, float]:
36
+ image = load_image(image, mode='RGB')
37
+ input_ = _img_encode(image, size=(size, size))[None, ...]
38
+ output, = _open_anime_aicheck_model(model_name).run(['output'], {'input': input_})
39
+
40
+ labels = _get_tags(model_name)
41
+ values = dict(zip(labels, map(lambda x: x.item(), output[0])))
42
+ return values
app.py CHANGED
@@ -2,8 +2,10 @@ import os
2
 
3
  import gradio as gr
4
 
 
5
  from cls import _CLS_MODELS, _DEFAULT_CLS_MODEL, _gr_classification
6
  from monochrome import _gr_monochrome, _DEFAULT_MONO_MODEL, _MONO_MODELS
 
7
 
8
  if __name__ == '__main__':
9
  with gr.Blocks() as demo:
@@ -42,4 +44,38 @@ if __name__ == '__main__':
42
  outputs=[gr_mono_output],
43
  )
44
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  demo.queue(os.cpu_count()).launch()
 
2
 
3
  import gradio as gr
4
 
5
+ from aicheck import _gr_aicheck, _DEFAULT_AICHECK_MODEL, _AICHECK_MODELS
6
  from cls import _CLS_MODELS, _DEFAULT_CLS_MODEL, _gr_classification
7
  from monochrome import _gr_monochrome, _DEFAULT_MONO_MODEL, _MONO_MODELS
8
+ from rating import _RATING_MODELS, _DEFAULT_RATING_MODEL, _gr_rating
9
 
10
  if __name__ == '__main__':
11
  with gr.Blocks() as demo:
 
44
  outputs=[gr_mono_output],
45
  )
46
 
47
+ with gr.Tab('AI Check'):
48
+ with gr.Row():
49
+ with gr.Column():
50
+ gr_aicheck_input_image = gr.Image(type='pil', label='Original Image')
51
+ gr_aicheck_model = gr.Dropdown(_AICHECK_MODELS, value=_DEFAULT_AICHECK_MODEL, label='Model')
52
+ gr_aicheck_infer_size = gr.Slider(224, 640, value=384, step=32, label='Infer Size')
53
+ gr_aicheck_submit = gr.Button(value='Submit', variant='primary')
54
+
55
+ with gr.Column():
56
+ gr_aicheck_output = gr.Label(label='Classes')
57
+
58
+ gr_aicheck_submit.click(
59
+ _gr_aicheck,
60
+ inputs=[gr_aicheck_input_image, gr_aicheck_model, gr_aicheck_infer_size],
61
+ outputs=[gr_aicheck_output],
62
+ )
63
+
64
+ with gr.Tab('Rating'):
65
+ with gr.Row():
66
+ with gr.Column():
67
+ gr_rating_input_image = gr.Image(type='pil', label='Original Image')
68
+ gr_rating_model = gr.Dropdown(_RATING_MODELS, value=_DEFAULT_RATING_MODEL, label='Model')
69
+ gr_rating_infer_size = gr.Slider(224, 640, value=384, step=32, label='Infer Size')
70
+ gr_rating_submit = gr.Button(value='Submit', variant='primary')
71
+
72
+ with gr.Column():
73
+ gr_rating_output = gr.Label(label='Classes')
74
+
75
+ gr_rating_submit.click(
76
+ _gr_rating,
77
+ inputs=[gr_rating_input_image, gr_rating_model, gr_rating_infer_size],
78
+ outputs=[gr_rating_output],
79
+ )
80
+
81
  demo.queue(os.cpu_count()).launch()
rating.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ from functools import lru_cache
4
+ from typing import Mapping, List
5
+
6
+ from huggingface_hub import HfFileSystem
7
+ from huggingface_hub import hf_hub_download
8
+ from imgutils.data import ImageTyping, load_image
9
+ from natsort import natsorted
10
+
11
+ from onnx_ import _open_onnx_model
12
+ from preprocess import _img_encode
13
+
14
+ hfs = HfFileSystem()
15
+
16
+ _REPO = 'deepghs/anime_rating'
17
+ _RATING_MODELS = natsorted([
18
+ os.path.dirname(os.path.relpath(file, _REPO))
19
+ for file in hfs.glob(f'{_REPO}/*/model.onnx')
20
+ ])
21
+ _DEFAULT_RATING_MODEL = 'mobilenetv3_sce_dist'
22
+
23
+
24
+ @lru_cache()
25
+ def _open_anime_rating_model(model_name):
26
+ return _open_onnx_model(hf_hub_download(_REPO, f'{model_name}/model.onnx'))
27
+
28
+
29
+ @lru_cache()
30
+ def _get_tags(model_name) -> List[str]:
31
+ with open(hf_hub_download(_REPO, f'{model_name}/meta.json'), 'r') as f:
32
+ return json.load(f)['labels']
33
+
34
+
35
+ def _gr_rating(image: ImageTyping, model_name: str, size=384) -> Mapping[str, float]:
36
+ image = load_image(image, mode='RGB')
37
+ input_ = _img_encode(image, size=(size, size))[None, ...]
38
+ output, = _open_anime_rating_model(model_name).run(['output'], {'input': input_})
39
+
40
+ labels = _get_tags(model_name)
41
+ values = dict(zip(labels, map(lambda x: x.item(), output[0])))
42
+ return values