KAZEKOI commited on
Commit
da17888
1 Parent(s): 0b5fbaf

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -12
app.py CHANGED
@@ -14,6 +14,12 @@ DESCRIPTION = """
14
  implications_list_path = './implications_list.json'
15
  related_feature_path = './related_feature.json'
16
 
 
 
 
 
 
 
17
  #HF_TOKEN = os.environ["HF_TOKEN"]
18
 
19
  # Dataset v3 series of models:
@@ -21,6 +27,7 @@ SWINV2_MODEL_DSV3_REPO = "SmilingWolf/wd-swinv2-tagger-v3"
21
  CONV_MODEL_DSV3_REPO = "SmilingWolf/wd-convnext-tagger-v3"
22
  VIT_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-tagger-v3"
23
  VIT_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-large-tagger-v3"
 
24
 
25
  # Dataset v2 series of models:
26
  MOAT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-moat-tagger-v2"
@@ -61,8 +68,7 @@ def parse_args() -> argparse.Namespace:
61
  parser = argparse.ArgumentParser()
62
  parser.add_argument("--score-slider-step", type=float, default=0.05)
63
  parser.add_argument("--score-general-threshold", type=float, default=0.4)
64
- parser.add_argument("--score-character-threshold", type=float, default=0.9)
65
- parser.add_argument("--character_string", type=str)
66
  parser.add_argument("--share", action="store_true")
67
  return parser.parse_args()
68
 
@@ -170,6 +176,7 @@ class Predictor:
170
  character_thresh,
171
  character_mcut_enabled,
172
  character_string,
 
173
  ):
174
  self.load_model(model_repo)
175
 
@@ -195,13 +202,6 @@ class Predictor:
195
  general_res = [x for x in general_names if x[1] > general_thresh]
196
  general_res = dict(general_res)
197
 
198
- with open(related_feature_path, 'r') as f:
199
- related_feature_list = json.load(f)
200
-
201
-
202
- with open(implications_list_path, 'r') as f:
203
- implications_list = json.load(f)
204
-
205
  to_delete = set()
206
  for key in general_res.keys():
207
  if key in implications_list:
@@ -221,16 +221,30 @@ class Predictor:
221
  character_res = [x for x in character_names if x[1] > character_thresh]
222
  character_res = dict(character_res)
223
 
 
 
 
 
 
 
 
224
  sorted_general_strings = sorted(
225
  general_res.items(),
226
  key=lambda x: x[1],
227
  reverse=True,
228
  )
229
 
230
- character_list = character_string.lower().split(', ')
 
 
 
 
 
 
 
231
 
232
  feature_delete_list = []
233
- for tag in character_list:
234
  if tag in related_feature_list:
235
  feature_delete_list.extend(related_feature_list[tag])
236
 
@@ -238,6 +252,8 @@ class Predictor:
238
 
239
  sorted_general_strings = [x for x in sorted_general_strings if x not in feature_delete_list]
240
 
 
 
241
  sorted_general_strings = [x.replace("_", " ") if x not in kaomojis else x for x in sorted_general_strings]
242
 
243
  sorted_general_strings = (
@@ -246,7 +262,6 @@ class Predictor:
246
 
247
  return sorted_general_strings, rating, character_res, general_res
248
 
249
-
250
  def main():
251
  args = parse_args()
252
 
@@ -257,6 +272,7 @@ def main():
257
  CONV_MODEL_DSV3_REPO,
258
  VIT_MODEL_DSV3_REPO,
259
  VIT_LARGE_MODEL_DSV3_REPO,
 
260
  ]
261
 
262
  with gr.Blocks(title=TITLE) as demo:
@@ -306,6 +322,11 @@ def main():
306
  label= "Character",
307
  scale=3,
308
  )
 
 
 
 
 
309
  with gr.Row():
310
  clear = gr.ClearButton(
311
  components=[
@@ -340,6 +361,7 @@ def main():
340
  character_thresh,
341
  character_mcut_enabled,
342
  character_string,
 
343
  ],
344
  outputs=[sorted_general_strings, rating, character_res, general_res],
345
  )
 
14
  implications_list_path = './implications_list.json'
15
  related_feature_path = './related_feature.json'
16
 
17
+ with open(related_feature_path, 'r') as f:
18
+ related_feature_list = json.load(f)
19
+
20
+ with open(implications_list_path, 'r') as f:
21
+ implications_list = json.load(f)
22
+
23
  #HF_TOKEN = os.environ["HF_TOKEN"]
24
 
25
  # Dataset v3 series of models:
 
27
  CONV_MODEL_DSV3_REPO = "SmilingWolf/wd-convnext-tagger-v3"
28
  VIT_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-tagger-v3"
29
  VIT_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-large-tagger-v3"
30
+ EVA02_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-eva02-large-tagger-v3"
31
 
32
  # Dataset v2 series of models:
33
  MOAT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-moat-tagger-v2"
 
68
  parser = argparse.ArgumentParser()
69
  parser.add_argument("--score-slider-step", type=float, default=0.05)
70
  parser.add_argument("--score-general-threshold", type=float, default=0.4)
71
+ parser.add_argument("--score-character-threshold", type=float, default=0.8)
 
72
  parser.add_argument("--share", action="store_true")
73
  return parser.parse_args()
74
 
 
176
  character_thresh,
177
  character_mcut_enabled,
178
  character_string,
179
+ character_output
180
  ):
181
  self.load_model(model_repo)
182
 
 
202
  general_res = [x for x in general_names if x[1] > general_thresh]
203
  general_res = dict(general_res)
204
 
 
 
 
 
 
 
 
205
  to_delete = set()
206
  for key in general_res.keys():
207
  if key in implications_list:
 
221
  character_res = [x for x in character_names if x[1] > character_thresh]
222
  character_res = dict(character_res)
223
 
224
+ character_strings = sorted(
225
+ character_res.items(),
226
+ key=lambda x: x[1],
227
+ reverse=True,
228
+ )
229
+ character_strings = [x[0] for x in character_strings]
230
+
231
  sorted_general_strings = sorted(
232
  general_res.items(),
233
  key=lambda x: x[1],
234
  reverse=True,
235
  )
236
 
237
+ character_list = []
238
+ if character_string != '':
239
+ character_list = character_string.lower().split(', ')
240
+
241
+ if character_output:
242
+ character_combined = character_list + character_strings
243
+ else:
244
+ character_combined = character_list
245
 
246
  feature_delete_list = []
247
+ for tag in character_combined:
248
  if tag in related_feature_list:
249
  feature_delete_list.extend(related_feature_list[tag])
250
 
 
252
 
253
  sorted_general_strings = [x for x in sorted_general_strings if x not in feature_delete_list]
254
 
255
+ sorted_general_strings = character_combined + sorted_general_strings
256
+
257
  sorted_general_strings = [x.replace("_", " ") if x not in kaomojis else x for x in sorted_general_strings]
258
 
259
  sorted_general_strings = (
 
262
 
263
  return sorted_general_strings, rating, character_res, general_res
264
 
 
265
  def main():
266
  args = parse_args()
267
 
 
272
  CONV_MODEL_DSV3_REPO,
273
  VIT_MODEL_DSV3_REPO,
274
  VIT_LARGE_MODEL_DSV3_REPO,
275
+ EVA02_LARGE_MODEL_DSV3_REPO,
276
  ]
277
 
278
  with gr.Blocks(title=TITLE) as demo:
 
322
  label= "Character",
323
  scale=3,
324
  )
325
+ character_output = gr.Checkbox(
326
+ value=True,
327
+ label="Use Output (characters)",
328
+ scale=1,
329
+ )
330
  with gr.Row():
331
  clear = gr.ClearButton(
332
  components=[
 
361
  character_thresh,
362
  character_mcut_enabled,
363
  character_string,
364
+ character_output
365
  ],
366
  outputs=[sorted_general_strings, rating, character_res, general_res],
367
  )