feat: add model choices, empty image guard
app.py CHANGED
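This commit generalizes the Space from a single hardcoded SigLIP checkpoint to a selectable set: a `model_maps` registry drives model and processor loading as well as the README's model list, `predict_tags` gains a `model_name` parameter and an early return when no image is supplied, and a `gr.Radio` picker feeds that parameter from the UI.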
```diff
@@ -12,24 +12,47 @@ import gradio as gr
 
 from modeling_siglip import SiglipForImageClassification
 
-MODEL_NAME = "p1atdev/siglip-tagger-test-3"
-PROCESSOR_NAME = MODEL_NAME
+
 HF_TOKEN = os.environ["HF_READ_TOKEN"]
 
 EXAMPLES = [["./images/sample.jpg"], ["./images/sample2.webp"]]
 
-README_MD = f"""\
+model_maps: dict[str, dict] = {
+    "test2": {
+        "repo": "p1atdev/siglip-tagger-test-2",
+    },
+    "test3": {
+        "repo": "p1atdev/siglip-tagger-test-3",
+    },
+    # "test4": {
+    #     "repo": "p1atdev/siglip-tagger-test-4",
+    # },
+}
+
+for key in model_maps.keys():
+    model_maps[key]["model"] = SiglipForImageClassification.from_pretrained(
+        model_maps[key]["repo"], torch_dtype=torch.bfloat16, token=HF_TOKEN
+    )
+    model_maps[key]["processor"] = AutoImageProcessor.from_pretrained(
+        model_maps[key]["repo"], token=HF_TOKEN
+    )
+
+README_MD = (
+    f"""\
 ## SigLIP Tagger Test 3
 An experimental model for tagging danbooru tags of images using SigLIP.
 
 Model(s):
-- [p1atdev/siglip-tagger-test-3](https://huggingface.co/p1atdev/siglip-tagger-test-3)
-
+"""
+    + "\n".join(
+        f"- [{value['repo']}](https://huggingface.co/{value['repo']})"
+        for value in model_maps.values()
+    )
+    + "\n"
+    + """
 Example images by NovelAI and niji・journey.
 """
-
-model = SiglipForImageClassification.from_pretrained(MODEL_NAME, token=HF_TOKEN)
-processor = AutoImageProcessor.from_pretrained(PROCESSOR_NAME, token=HF_TOKEN)
+)
 
 
 def compose_text(results: dict[str, float], threshold: float = 0.3):
```
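The registry keeps everything per model (repo id, loaded model, processor) in one dict, and the startup loop loads every entry eagerly in bfloat16. If the Space's memory is tight, the same registry also supports loading on first use; a minimal sketch of that variant, assuming the same repos and token handling (`MODEL_REPOS`, `load_entry`, and the `lru_cache` caching are illustrative additions, not part of the commit):

```python
import os
from functools import lru_cache

import torch
from transformers import AutoImageProcessor

from modeling_siglip import SiglipForImageClassification

HF_TOKEN = os.environ["HF_READ_TOKEN"]

# Short name -> repo id, mirroring the keys of model_maps.
MODEL_REPOS: dict[str, str] = {
    "test2": "p1atdev/siglip-tagger-test-2",
    "test3": "p1atdev/siglip-tagger-test-3",
}


@lru_cache(maxsize=None)
def load_entry(name: str):
    """Load (model, processor) for one registry entry on first use, then cache."""
    repo = MODEL_REPOS[name]
    model = SiglipForImageClassification.from_pretrained(
        repo, torch_dtype=torch.bfloat16, token=HF_TOKEN
    )
    processor = AutoImageProcessor.from_pretrained(repo, token=HF_TOKEN)
    return model, processor
```

Eager loading, as committed, keeps first-request latency low at the cost of holding every checkpoint in memory; the cached variant flips that trade-off.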
```diff
@@ -43,10 +66,23 @@ def compose_text(results: dict[str, float], threshold: float = 0.3):
 
 
 @torch.no_grad()
-def predict_tags(image: Image.Image, threshold: float):
-    inputs = processor(image, return_tensors="pt")
-
-    logits = model(**inputs.to(model.device)).logits.detach().cpu().float()
+def predict_tags(image: Image.Image, model_name: str, threshold: float):
+    if image is None:
+        return None, None
+
+    inputs = model_maps[model_name]["processor"](image, return_tensors="pt")
+
+    logits = (
+        model_maps[model_name]["model"](
+            **inputs.to(
+                model_maps[model_name]["model"].device,
+                model_maps[model_name]["model"].dtype,
+            )
+        )
+        .logits.detach()
+        .cpu()
+        .float()
+    )
 
     logits = np.clip(logits, 0.0, 1.0)
 
```
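Two casts in the new call chain are easy to miss: the processor emits float32 `pixel_values`, which must be moved to the bf16 model's device and dtype before the forward pass, and the logits are brought back with `.float()` because NumPy (used by `np.clip` just below) has no bfloat16. A condensed sketch of the same path, where `run_inference` is a hypothetical helper rather than a function in the app:

```python
import numpy as np
import torch
from PIL import Image


@torch.no_grad()
def run_inference(model, processor, image: Image.Image) -> np.ndarray:
    # Gradio hands the handler None when the image box is empty,
    # which is what the new guard in predict_tags catches.
    if image is None:
        raise ValueError("no image provided")
    inputs = processor(image, return_tensors="pt")
    # BatchFeature.to(device, dtype) moves the tensors and casts the
    # floating-point ones to match the bf16 weights.
    outputs = model(**inputs.to(model.device, model.dtype))
    # NumPy cannot represent bfloat16, so cast to float32 first.
    return outputs.logits.detach().cpu().float().numpy()
```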
```diff
@@ -55,7 +91,9 @@ def predict_tags(image: Image.Image, threshold: float):
     for prediction in logits:
         for i, prob in enumerate(prediction):
             if prob.item() > 0:
-                results[model.config.id2label[i]] = prob.item()
+                results[model_maps[model_name]["model"].config.id2label[i]] = (
+                    prob.item()
+                )
 
     return compose_text(results, threshold), results
 
```
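Every label with a positive clipped score is recorded in `results` through the model's `id2label` table; the slider threshold is applied later, inside `compose_text`, so the second output still carries the full score dict. `compose_text`'s body sits outside this diff; a sketch under the assumption that it joins the surviving tags into a comma-separated string:

```python
def compose_text(results: dict[str, float], threshold: float = 0.3) -> str:
    # Keep tags whose score clears the threshold, highest score first.
    kept = sorted(
        (tag for tag, score in results.items() if score > threshold),
        key=lambda tag: -results[tag],
    )
    return ", ".join(kept)


# A 0.3 threshold drops "outdoors" but keeps the rest:
print(compose_text({"1girl": 0.98, "solo": 0.95, "outdoors": 0.12}))
# 1girl, solo
```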
```diff
@@ -85,6 +123,11 @@ def demo():
             )
 
             with gr.Group():
+                model_name_radio = gr.Radio(
+                    label="Model",
+                    choices=list(model_maps.keys()),
+                    value="test3",
+                )
                 tag_threshold_slider = gr.Slider(
                     label="Tags threshold",
                     minimum=0.0,
```
```diff
@@ -107,7 +150,7 @@ def demo():
 
         start_btn.click(
             fn=predict_tags,
-            inputs=[input_img, tag_threshold_slider],
+            inputs=[input_img, model_name_radio, tag_threshold_slider],
             outputs=[output_tags, output_label],
         )
 
```
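On the UI side, the new radio defaults to `"test3"`, the model the app previously hardcoded, and its selected string travels positionally into the `model_name` parameter of `predict_tags` via the `inputs` list. A self-contained sketch of that wiring with a stub handler (`fake_predict` stands in for `predict_tags`; the output component types and button label are assumptions):

```python
import gradio as gr

CHOICES = ["test2", "test3"]  # mirrors list(model_maps.keys())


def fake_predict(image, model_name: str, threshold: float):
    # Stub standing in for predict_tags: echo the selected options.
    if image is None:
        return None, None
    return f"model={model_name}, threshold={threshold}", {"example_tag": 0.9}


with gr.Blocks() as demo:
    input_img = gr.Image(type="pil", label="Input image")
    model_name_radio = gr.Radio(label="Model", choices=CHOICES, value="test3")
    tag_threshold_slider = gr.Slider(
        label="Tags threshold", minimum=0.0, maximum=1.0, value=0.3
    )
    output_tags = gr.Textbox(label="Tags")
    output_label = gr.Label(label="Scores")
    start_btn = gr.Button("Start")

    # Components in `inputs` map positionally onto the handler's
    # parameters, so the radio slots in as the second argument.
    start_btn.click(
        fn=fake_predict,
        inputs=[input_img, model_name_radio, tag_threshold_slider],
        outputs=[output_tags, output_label],
    )

demo.launch()
```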