p1atdev committed on
Commit
25ee9b2
1 Parent(s): 9126ead

feat: add model choices, empty image guard

Files changed (1): app.py (+57 -14)
app.py CHANGED

@@ -12,24 +12,47 @@ import gradio as gr
 
 from modeling_siglip import SiglipForImageClassification
 
-MODEL_NAME = os.environ["MODEL_NAME"]
-PROCESSOR_NAME = MODEL_NAME
+
 HF_TOKEN = os.environ["HF_READ_TOKEN"]
 
 EXAMPLES = [["./images/sample.jpg"], ["./images/sample2.webp"]]
 
-README_MD = """\
+model_maps: dict[str, dict] = {
+    "test2": {
+        "repo": "p1atdev/siglip-tagger-test-2",
+    },
+    "test3": {
+        "repo": "p1atdev/siglip-tagger-test-3",
+    },
+    # "test4": {
+    #     "repo": "p1atdev/siglip-tagger-test-4",
+    # },
+}
+
+for key in model_maps.keys():
+    model_maps[key]["model"] = SiglipForImageClassification.from_pretrained(
+        model_maps[key]["repo"], torch_dtype=torch.bfloat16, token=HF_TOKEN
+    )
+    model_maps[key]["processor"] = AutoImageProcessor.from_pretrained(
+        model_maps[key]["repo"], token=HF_TOKEN
+    )
+
+README_MD = (
+    f"""\
 ## SigLIP Tagger Test 3
 An experimental model for tagging danbooru tags of images using SigLIP.
 
 Model(s):
-- [p1atdev/siglip-tagger-test-3](https://huggingface.co/p1atdev/siglip-tagger-test-3)
-
+"""
+    + "\n".join(
+        f"- [{value['repo']}](https://huggingface.co/{value['repo']})"
+        for value in model_maps.values()
+    )
+    + "\n"
+    + """
 Example images by NovelAI and niji・journey.
 """
-
-model = SiglipForImageClassification.from_pretrained(MODEL_NAME, token=HF_TOKEN)
-processor = AutoImageProcessor.from_pretrained(PROCESSOR_NAME, token=HF_TOKEN)
+)
 
 
 def compose_text(results: dict[str, float], threshold: float = 0.3):
@@ -43,10 +66,23 @@ def compose_text(results: dict[str, float], threshold: float = 0.3):
 
 
 @torch.no_grad()
-def predict_tags(image: Image.Image, threshold: float):
-    inputs = processor(image, return_tensors="pt")
-
-    logits = model(**inputs.to(model.device, model.dtype)).logits.detach().cpu()
+def predict_tags(image: Image.Image, model_name: str, threshold: float):
+    if image is None:
+        return None, None
+
+    inputs = model_maps[model_name]["processor"](image, return_tensors="pt")
+
+    logits = (
+        model_maps[model_name]["model"](
+            **inputs.to(
+                model_maps[model_name]["model"].device,
+                model_maps[model_name]["model"].dtype,
+            )
+        )
+        .logits.detach()
+        .cpu()
+        .float()
+    )
 
     logits = np.clip(logits, 0.0, 1.0)
 
@@ -55,7 +91,9 @@ def predict_tags(image: Image.Image, threshold: float):
     for prediction in logits:
         for i, prob in enumerate(prediction):
             if prob.item() > 0:
-                results[model.config.id2label[i]] = prob.item()
+                results[model_maps[model_name]["model"].config.id2label[i]] = (
+                    prob.item()
+                )
 
     return compose_text(results, threshold), results
 
@@ -85,6 +123,11 @@ def demo():
             )
 
             with gr.Group():
+                model_name_radio = gr.Radio(
+                    label="Model",
+                    choices=list(model_maps.keys()),
+                    value="test3",
+                )
                 tag_threshold_slider = gr.Slider(
                     label="Tags threshold",
                     minimum=0.0,
@@ -107,7 +150,7 @@ def demo():
 
         start_btn.click(
             fn=predict_tags,
-            inputs=[input_img, tag_threshold_slider],
+            inputs=[input_img, model_name_radio, tag_threshold_slider],
             outputs=[output_tags, output_label],
         )
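The heart of the change is a registry-and-dispatch pattern: checkpoints are keyed by a short name in model_maps, eagerly loaded once at import time, and looked up per request by that key. Below is a minimal self-contained sketch of the same pattern; DummyTagger is a hypothetical stand-in for the token-gated SiglipForImageClassification loads in the real app.

# Registry sketch; DummyTagger is hypothetical, repo names are from the diff.
class DummyTagger:
    def __init__(self, repo: str) -> None:
        self.repo = repo

model_maps: dict[str, dict] = {
    "test2": {"repo": "p1atdev/siglip-tagger-test-2"},
    "test3": {"repo": "p1atdev/siglip-tagger-test-3"},
}

# Eager-load every entry once at startup so handlers can dispatch by key
# without reloading weights on each request.
for key in model_maps:
    model_maps[key]["model"] = DummyTagger(model_maps[key]["repo"])

def predict(model_name: str) -> str:
    # The UI hands over a key ("test2" / "test3"), never a model object.
    return model_maps[model_name]["model"].repo

assert predict("test3") == "p1atdev/siglip-tagger-test-3"

Loading everything at import keeps per-request latency flat, at the cost of holding every checkpoint in memory; the commented-out "test4" entry suggests models are toggled in and out by hand.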
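The trailing .float() in the new logits chain is not cosmetic: the models are now loaded with torch_dtype=torch.bfloat16, and NumPy has no bfloat16 dtype, so the subsequent np.clip would fail on the raw tensor. A small sketch of the failure mode, assuming only torch and numpy:

import numpy as np
import torch

logits = torch.tensor([-0.2, 0.4, 1.3], dtype=torch.bfloat16)

# Converting bfloat16 straight to NumPy raises a TypeError, so np.clip
# cannot consume the tensor as-is.
try:
    logits.numpy()
except TypeError as err:
    print(err)  # e.g. "Got unsupported ScalarType BFloat16"

# Upcasting first, as the diff does with .cpu().float(), works:
clipped = np.clip(logits.float().numpy(), 0.0, 1.0)
print(clipped)  # approximately [0. 0.4 1.]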
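On the UI side, the new gr.Radio value travels to the handler as an extra positional argument between the image and the threshold, which is why predict_tags gained model_name as its second parameter. A runnable wiring sketch with a stub handler (component names mirror app.py; the stub body and button label are assumptions):

import gradio as gr

def predict_tags_stub(image, model_name: str, threshold: float):
    if image is None:  # the diff's empty-image guard: return cleanly
        return None
    return f"would tag with {model_name} at threshold {threshold}"

with gr.Blocks() as ui:
    input_img = gr.Image(type="pil")
    model_name_radio = gr.Radio(
        label="Model", choices=["test2", "test3"], value="test3"
    )
    tag_threshold_slider = gr.Slider(
        label="Tags threshold", minimum=0.0, maximum=1.0, value=0.3
    )
    output_tags = gr.Text(label="Output tags")
    start_btn = gr.Button("Start")

    # Gradio passes component values to fn in the order listed in `inputs`.
    start_btn.click(
        fn=predict_tags_stub,
        inputs=[input_img, model_name_radio, tag_threshold_slider],
        outputs=[output_tags],
    )

# ui.launch()  # uncomment to serve locally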