Upload app.py
Browse files
app.py
CHANGED
@@ -14,6 +14,12 @@ DESCRIPTION = """
|
|
14 |
implications_list_path = './implications_list.json'
|
15 |
related_feature_path = './related_feature.json'
|
16 |
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
#HF_TOKEN = os.environ["HF_TOKEN"]
|
18 |
|
19 |
# Dataset v3 series of models:
|
@@ -21,6 +27,7 @@ SWINV2_MODEL_DSV3_REPO = "SmilingWolf/wd-swinv2-tagger-v3"
|
|
21 |
CONV_MODEL_DSV3_REPO = "SmilingWolf/wd-convnext-tagger-v3"
|
22 |
VIT_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-tagger-v3"
|
23 |
VIT_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-large-tagger-v3"
|
|
|
24 |
|
25 |
# Dataset v2 series of models:
|
26 |
MOAT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-moat-tagger-v2"
|
@@ -61,8 +68,7 @@ def parse_args() -> argparse.Namespace:
|
|
61 |
parser = argparse.ArgumentParser()
|
62 |
parser.add_argument("--score-slider-step", type=float, default=0.05)
|
63 |
parser.add_argument("--score-general-threshold", type=float, default=0.4)
|
64 |
-
parser.add_argument("--score-character-threshold", type=float, default=0.
|
65 |
-
parser.add_argument("--character_string", type=str)
|
66 |
parser.add_argument("--share", action="store_true")
|
67 |
return parser.parse_args()
|
68 |
|
@@ -170,6 +176,7 @@ class Predictor:
|
|
170 |
character_thresh,
|
171 |
character_mcut_enabled,
|
172 |
character_string,
|
|
|
173 |
):
|
174 |
self.load_model(model_repo)
|
175 |
|
@@ -195,13 +202,6 @@ class Predictor:
|
|
195 |
general_res = [x for x in general_names if x[1] > general_thresh]
|
196 |
general_res = dict(general_res)
|
197 |
|
198 |
-
with open(related_feature_path, 'r') as f:
|
199 |
-
related_feature_list = json.load(f)
|
200 |
-
|
201 |
-
|
202 |
-
with open(implications_list_path, 'r') as f:
|
203 |
-
implications_list = json.load(f)
|
204 |
-
|
205 |
to_delete = set()
|
206 |
for key in general_res.keys():
|
207 |
if key in implications_list:
|
@@ -221,16 +221,30 @@ class Predictor:
|
|
221 |
character_res = [x for x in character_names if x[1] > character_thresh]
|
222 |
character_res = dict(character_res)
|
223 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
224 |
sorted_general_strings = sorted(
|
225 |
general_res.items(),
|
226 |
key=lambda x: x[1],
|
227 |
reverse=True,
|
228 |
)
|
229 |
|
230 |
-
character_list =
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
231 |
|
232 |
feature_delete_list = []
|
233 |
-
for tag in
|
234 |
if tag in related_feature_list:
|
235 |
feature_delete_list.extend(related_feature_list[tag])
|
236 |
|
@@ -238,6 +252,8 @@ class Predictor:
|
|
238 |
|
239 |
sorted_general_strings = [x for x in sorted_general_strings if x not in feature_delete_list]
|
240 |
|
|
|
|
|
241 |
sorted_general_strings = [x.replace("_", " ") if x not in kaomojis else x for x in sorted_general_strings]
|
242 |
|
243 |
sorted_general_strings = (
|
@@ -246,7 +262,6 @@ class Predictor:
|
|
246 |
|
247 |
return sorted_general_strings, rating, character_res, general_res
|
248 |
|
249 |
-
|
250 |
def main():
|
251 |
args = parse_args()
|
252 |
|
@@ -257,6 +272,7 @@ def main():
|
|
257 |
CONV_MODEL_DSV3_REPO,
|
258 |
VIT_MODEL_DSV3_REPO,
|
259 |
VIT_LARGE_MODEL_DSV3_REPO,
|
|
|
260 |
]
|
261 |
|
262 |
with gr.Blocks(title=TITLE) as demo:
|
@@ -306,6 +322,11 @@ def main():
|
|
306 |
label= "Character",
|
307 |
scale=3,
|
308 |
)
|
|
|
|
|
|
|
|
|
|
|
309 |
with gr.Row():
|
310 |
clear = gr.ClearButton(
|
311 |
components=[
|
@@ -340,6 +361,7 @@ def main():
|
|
340 |
character_thresh,
|
341 |
character_mcut_enabled,
|
342 |
character_string,
|
|
|
343 |
],
|
344 |
outputs=[sorted_general_strings, rating, character_res, general_res],
|
345 |
)
|
|
|
14 |
implications_list_path = './implications_list.json'
|
15 |
related_feature_path = './related_feature.json'
|
16 |
|
17 |
+
with open(related_feature_path, 'r') as f:
|
18 |
+
related_feature_list = json.load(f)
|
19 |
+
|
20 |
+
with open(implications_list_path, 'r') as f:
|
21 |
+
implications_list = json.load(f)
|
22 |
+
|
23 |
#HF_TOKEN = os.environ["HF_TOKEN"]
|
24 |
|
25 |
# Dataset v3 series of models:
|
|
|
27 |
CONV_MODEL_DSV3_REPO = "SmilingWolf/wd-convnext-tagger-v3"
|
28 |
VIT_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-tagger-v3"
|
29 |
VIT_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-large-tagger-v3"
|
30 |
+
EVA02_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-eva02-large-tagger-v3"
|
31 |
|
32 |
# Dataset v2 series of models:
|
33 |
MOAT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-moat-tagger-v2"
|
|
|
68 |
parser = argparse.ArgumentParser()
|
69 |
parser.add_argument("--score-slider-step", type=float, default=0.05)
|
70 |
parser.add_argument("--score-general-threshold", type=float, default=0.4)
|
71 |
+
parser.add_argument("--score-character-threshold", type=float, default=0.8)
|
|
|
72 |
parser.add_argument("--share", action="store_true")
|
73 |
return parser.parse_args()
|
74 |
|
|
|
176 |
character_thresh,
|
177 |
character_mcut_enabled,
|
178 |
character_string,
|
179 |
+
character_output
|
180 |
):
|
181 |
self.load_model(model_repo)
|
182 |
|
|
|
202 |
general_res = [x for x in general_names if x[1] > general_thresh]
|
203 |
general_res = dict(general_res)
|
204 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
205 |
to_delete = set()
|
206 |
for key in general_res.keys():
|
207 |
if key in implications_list:
|
|
|
221 |
character_res = [x for x in character_names if x[1] > character_thresh]
|
222 |
character_res = dict(character_res)
|
223 |
|
224 |
+
character_strings = sorted(
|
225 |
+
character_res.items(),
|
226 |
+
key=lambda x: x[1],
|
227 |
+
reverse=True,
|
228 |
+
)
|
229 |
+
character_strings = [x[0] for x in character_strings]
|
230 |
+
|
231 |
sorted_general_strings = sorted(
|
232 |
general_res.items(),
|
233 |
key=lambda x: x[1],
|
234 |
reverse=True,
|
235 |
)
|
236 |
|
237 |
+
character_list = []
|
238 |
+
if character_string != '':
|
239 |
+
character_list = character_string.lower().split(', ')
|
240 |
+
|
241 |
+
if character_output:
|
242 |
+
character_combined = character_list + character_strings
|
243 |
+
else:
|
244 |
+
character_combined = character_list
|
245 |
|
246 |
feature_delete_list = []
|
247 |
+
for tag in character_combined:
|
248 |
if tag in related_feature_list:
|
249 |
feature_delete_list.extend(related_feature_list[tag])
|
250 |
|
|
|
252 |
|
253 |
sorted_general_strings = [x for x in sorted_general_strings if x not in feature_delete_list]
|
254 |
|
255 |
+
sorted_general_strings = character_combined + sorted_general_strings
|
256 |
+
|
257 |
sorted_general_strings = [x.replace("_", " ") if x not in kaomojis else x for x in sorted_general_strings]
|
258 |
|
259 |
sorted_general_strings = (
|
|
|
262 |
|
263 |
return sorted_general_strings, rating, character_res, general_res
|
264 |
|
|
|
265 |
def main():
|
266 |
args = parse_args()
|
267 |
|
|
|
272 |
CONV_MODEL_DSV3_REPO,
|
273 |
VIT_MODEL_DSV3_REPO,
|
274 |
VIT_LARGE_MODEL_DSV3_REPO,
|
275 |
+
EVA02_LARGE_MODEL_DSV3_REPO,
|
276 |
]
|
277 |
|
278 |
with gr.Blocks(title=TITLE) as demo:
|
|
|
322 |
label= "Character",
|
323 |
scale=3,
|
324 |
)
|
325 |
+
character_output = gr.Checkbox(
|
326 |
+
value=True,
|
327 |
+
label="Use Output (characters)",
|
328 |
+
scale=1,
|
329 |
+
)
|
330 |
with gr.Row():
|
331 |
clear = gr.ClearButton(
|
332 |
components=[
|
|
|
361 |
character_thresh,
|
362 |
character_mcut_enabled,
|
363 |
character_string,
|
364 |
+
character_output
|
365 |
],
|
366 |
outputs=[sorted_general_strings, rating, character_res, general_res],
|
367 |
)
|