Enable mistake flagging #5
.gitattributes CHANGED
@@ -33,7 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+
 *.json filter=lfs diff=lfs merge=lfs -text
 *.jpeg filter=lfs diff=lfs merge=lfs -text
 *.png filter=lfs diff=lfs merge=lfs -text
-components/metadata.csv filter=lfs diff=lfs merge=lfs -text
 
README.md CHANGED
@@ -4,8 +4,8 @@ emoji: 🐘
 colorFrom: indigo
 colorTo: purple
 sdk: gradio
-sdk_version: 4.36.1
+sdk_version: 4.7.1
 app_file: app.py
 pinned: false
 license: mit
----
+---

app.py CHANGED
@@ -6,14 +6,12 @@ import logging
 
 import gradio as gr
 import numpy as np
-import polars as pl
 import torch
 import torch.nn.functional as F
 from open_clip import create_model, get_tokenizer
 from torchvision import transforms
 
 from templates import openai_imagenet_template
-from components.query import get_sample
 
 log_format = "[%(asctime)s] [%(levelname)s] [%(name)s] %(message)s"
 logging.basicConfig(level=logging.INFO, format=log_format)
@@ -21,12 +19,6 @@ logger = logging.getLogger()
 
 hf_token = os.getenv("HF_TOKEN")
 
-# For sample images
-METADATA_PATH = "components/metadata.csv"
-# Read page ID as int and filter out smaller ablation duplicated training split
-metadata_df = pl.read_csv(METADATA_PATH, low_memory = False)
-metadata_df = metadata_df.with_columns(pl.col("eol_page_id").cast(pl.Int64))
-
 model_str = "hf-hub:imageomics/bioclip"
 tokenizer_str = "ViT-B-16"
 
@@ -131,14 +123,12 @@ def format_name(taxon, common):
 
 
 @torch.no_grad()
-def open_domain_classification(img, rank: int, return_all=False):
+def open_domain_classification(img, rank: int) -> dict[str, float]:
     """
     Predicts from the entire tree of life.
     If targeting a higher rank than species, then this function predicts among all
     species, then sums up species-level probabilities for the given rank.
     """
-
-    logger.info(f"Starting open domain classification for rank: {rank}")
     img = preprocess_img(img).to(device)
     img_features = model.encode_image(img.unsqueeze(0))
     img_features = F.normalize(img_features, dim=-1)
@@ -146,36 +136,21 @@ def open_domain_classification(img, rank: int, return_all=False):
     logits = (model.logit_scale.exp() * img_features @ txt_emb).squeeze()
     probs = F.softmax(logits, dim=0)
 
+    # If predicting species, no need to sum probabilities.
     if rank + 1 == len(ranks):
         topk = probs.topk(k)
-        prediction_dict = {
+        return {
             format_name(*txt_names[i]): prob for i, prob in zip(topk.indices, topk.values)
         }
-        logger.info(f"Top K predictions: {prediction_dict}")
-        top_prediction_name = format_name(*txt_names[topk.indices[0]]).split("(")[0]
-        logger.info(f"Top prediction name: {top_prediction_name}")
-        sample_img, taxon_url = get_sample(metadata_df, top_prediction_name, rank)
-        if return_all:
-            return prediction_dict, sample_img, taxon_url
-        return prediction_dict
 
+    # Sum up by the rank
     output = collections.defaultdict(float)
     for i in torch.nonzero(probs > min_prob).squeeze():
         output[" ".join(txt_names[i][0][: rank + 1])] += probs[i]
 
     topk_names = heapq.nlargest(k, output, key=output.get)
-    prediction_dict = {name: output[name] for name in topk_names}
-    logger.info(f"Top K names for output: {topk_names}")
-    logger.info(f"Prediction dictionary: {prediction_dict}")
 
-    top_prediction_name = topk_names[0]
-    logger.info(f"Top prediction name: {top_prediction_name}")
-    sample_img, taxon_url = get_sample(metadata_df, top_prediction_name, rank)
-    logger.info(f"Sample image and taxon URL: {sample_img}, {taxon_url}")
-
-    if return_all:
-        return prediction_dict, sample_img, taxon_url
-    return prediction_dict
+    return {name: output[name] for name in topk_names}
 
 
 def change_output(choice):
@@ -204,22 +179,9 @@ if __name__ == "__main__":
     status_msg = f"{done}/{total} ({done / total * 100:.1f}%) indexed"
 
     with gr.Blocks() as app:
-
-        with gr.Tab("Open-Ended"):
-            with gr.Row(variant = "panel", elem_id = "images_panel"):
-                with gr.Column():
-                    img_input = gr.Image(height = 400, sources=["upload"])
-
-                with gr.Column():
-                    # display sample image of top predicted taxon
-                    sample_img = gr.Image(label = "Sample Image of Predicted Taxon",
-                                          height = 400,
-                                          show_download_button = False)
-
-                    taxon_url = gr.HTML(label = "More Information",
-                                        elem_id = "url"
-                                        )
+        img_input = gr.Image()
 
+        with gr.Tab("Open-Ended"):
             with gr.Row():
                 with gr.Column():
                     rank_dropdown = gr.Dropdown(
@@ -237,24 +199,23 @@ if __name__ == "__main__":
                         show_label=True,
                        value=None,
                     )
-                    # open_domain_flag_btn = gr.Button("Flag Mistake", variant="primary")
+                    open_domain_flag_btn = gr.Button("Flag Mistake", variant="primary")
 
             with gr.Row():
                 gr.Examples(
                     examples=open_domain_examples,
                     inputs=[img_input, rank_dropdown],
                     cache_examples=True,
-                    fn=lambda img, rank: open_domain_classification(img, rank, return_all=False),
+                    fn=open_domain_classification,
                     outputs=[open_domain_output],
                 )
-            '''
-            # Flagging Code
+
             open_domain_callback = gr.HuggingFaceDatasetSaver(
-                hf_token, "bioclip-demo-open-domain-mistakes", private=True
+                hf_token, "imageomics/bioclip-demo-open-domain-mistakes", private=True
             )
             open_domain_callback.setup(
                 [img_input, rank_dropdown, open_domain_output],
-                flagging_dir="bioclip-demo-open-domain-mistakes/logs/flagged",
+                flagging_dir="logs/flagged",
             )
             open_domain_flag_btn.click(
                 lambda *args: open_domain_callback.flag(args),
@@ -262,11 +223,8 @@
                 None,
                 preprocess=False,
             )
-            '''
-        with gr.Tab("Zero-Shot"):
-            with gr.Row():
-                img_input_zs = gr.Image(height = 400, sources=["upload"])
 
+        with gr.Tab("Zero-Shot"):
             with gr.Row():
                 with gr.Column():
                     classes_txt = gr.Textbox(
@@ -282,23 +240,22 @@
                     zero_shot_output = gr.Label(
                         num_top_classes=k, label="Prediction", show_label=True
                     )
-                    # zero_shot_flag_btn = gr.Button("Flag Mistake", variant="primary")
+                    zero_shot_flag_btn = gr.Button("Flag Mistake", variant="primary")
 
             with gr.Row():
                 gr.Examples(
                     examples=zero_shot_examples,
-                    inputs=[img_input_zs, classes_txt],
+                    inputs=[img_input, classes_txt],
                     cache_examples=True,
                     fn=zero_shot_classification,
                     outputs=[zero_shot_output],
                 )
-            '''
-            # Flagging Code
+
             zero_shot_callback = gr.HuggingFaceDatasetSaver(
-                hf_token, "bioclip-demo-zero-shot-mistakes", private=True
+                hf_token, "imageomics/bioclip-demo-zero-shot-mistakes", private=True
            )
             zero_shot_callback.setup(
-                [img_input, zero_shot_output], flagging_dir="bioclip-demo-zero-shot-mistakes/logs/flagged"
+                [img_input, zero_shot_output], flagging_dir="logs/flagged"
             )
             zero_shot_flag_btn.click(
                 lambda *args: zero_shot_callback.flag(args),
@@ -306,32 +263,22 @@
                 None,
                 preprocess=False,
             )
-            '''
+
         rank_dropdown.change(
             fn=change_output, inputs=rank_dropdown, outputs=[open_domain_output]
         )
 
        open_domain_btn.click(
-            fn=lambda img, rank: open_domain_classification(img, rank, return_all=True),
+            fn=open_domain_classification,
            inputs=[img_input, rank_dropdown],
-            outputs=[open_domain_output, sample_img, taxon_url],
+            outputs=[open_domain_output],
         )
 
         zero_shot_btn.click(
             fn=zero_shot_classification,
-            inputs=[img_input_zs, classes_txt],
+            inputs=[img_input, classes_txt],
             outputs=zero_shot_output,
         )
-
-        # Footer to point out to model and data from app page.
-        gr.Markdown(
-            """
-            For more information on the [BioCLIP Model](https://huggingface.co/imageomics/bioclip) creation, see our [BioCLIP Project GitHub](https://github.com/Imageomics/bioclip), and
-            for easier integration of BioCLIP, checkout [pybioclip](https://github.com/Imageomics/pybioclip).
-
-            To learn more about the data, check out our [TreeOfLife-10M Dataset](https://huggingface.co/datasets/imageomics/TreeOfLife-10M).
-            """
-        )
 
     app.queue(max_size=20)
-    app.launch(share=True)
+    app.launch()
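
A note for reviewers: the refactored `open_domain_classification` above handles ranks higher than species by summing species-level probabilities over every taxonomy that shares the same name prefix up to the requested rank, then keeping the top k groups. Below is a standalone toy sketch of just that aggregation step; the taxa and probability values are invented for illustration and are not model outputs.

```python
import collections
import heapq

# Hypothetical species-level probabilities keyed by the full seven-rank taxonomy.
species_probs = {
    ("Animalia", "Chordata", "Mammalia", "Carnivora", "Felidae", "Felis", "catus"): 0.50,
    ("Animalia", "Chordata", "Mammalia", "Carnivora", "Felidae", "Lynx", "rufus"): 0.30,
    ("Animalia", "Chordata", "Mammalia", "Carnivora", "Canidae", "Canis", "lupus"): 0.15,
}

def aggregate_to_rank(probs, rank, k=2):
    """Sum species probabilities that share a name prefix up to `rank`
    (0 = kingdom ... 6 = species), then keep the top-k groups."""
    output = collections.defaultdict(float)
    for taxonomy, p in probs.items():
        output[" ".join(taxonomy[: rank + 1])] += p
    topk_names = heapq.nlargest(k, output, key=output.get)
    return {name: output[name] for name in topk_names}

# At rank 4 (family): Felidae sums to 0.80, Canidae stays at 0.15.
print(aggregate_to_rank(species_probs, rank=4))
```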
 
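To see the flagging path this PR re-enables in isolation, here is a minimal self-contained sketch built from the same `gr.HuggingFaceDatasetSaver` calls the diff uses (constructor, `setup`, `flag`). It assumes a Gradio 4.x environment where `gr.HuggingFaceDatasetSaver` is available and an `HF_TOKEN` with write access; `"your-org/demo-mistakes"` is a placeholder dataset repo id, not one of the repos this PR writes to.

```python
import os

import gradio as gr

hf_token = os.getenv("HF_TOKEN")

with gr.Blocks() as demo:
    img_input = gr.Image()
    output = gr.Label(label="Prediction")
    flag_btn = gr.Button("Flag Mistake", variant="primary")

    # Each flagged (image, prediction) pair is appended to a private HF dataset.
    # "your-org/demo-mistakes" is a placeholder repo id.
    callback = gr.HuggingFaceDatasetSaver(hf_token, "your-org/demo-mistakes", private=True)
    callback.setup([img_input, output], flagging_dir="logs/flagged")

    # preprocess=False hands the raw component payloads straight to the callback.
    flag_btn.click(
        lambda *args: callback.flag(args),
        [img_input, output],
        None,
        preprocess=False,
    )

demo.launch()
```
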
components/metadata.csv DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:d8576f6ca106f35387506369a70df01fb92192a740c3b5da2a12ad8303976aad
-size 233934143

components/metadata_readme.md DELETED
@@ -1,11 +0,0 @@
----
-title: Bioclip Demo
-emoji: 🐘
-colorFrom: indigo
-colorTo: purple
-sdk: gradio
-sdk_version: 4.36.1
-app_file: app.py
-pinned: false
-license: mit
----

components/query.py DELETED
@@ -1,116 +0,0 @@
-import io
-import boto3
-import requests
-import numpy as np
-import polars as pl
-from PIL import Image
-from botocore.config import Config
-import logging
-
-logger = logging.getLogger(__name__)
-
-# S3 for sample images
-my_config = Config(
-    region_name='us-east-1'
-)
-s3_client = boto3.client('s3', config=my_config)
-
-# Set basepath for EOL pages for info
-EOL_URL = "https://eol.org/pages/"
-RANKS = ["kingdom", "phylum", "class", "order", "family", "genus", "species"]
-
-def get_sample(df, pred_taxon, rank):
-    '''
-    Function to retrieve a sample image of the predicted taxon and EOL page link for more info.
-
-    Parameters:
-    -----------
-    df : DataFrame
-        DataFrame with all sample images listed and their filepaths (in "file_path" column).
-    pred_taxon : str
-        Predicted taxon of the uploaded image.
-    rank : int
-        Index of rank in RANKS chosen for prediction.
-
-    Returns:
-    --------
-    img : PIL.Image
-        Sample image of predicted taxon for display.
-    eol_page : str
-        URL to EOL page for the taxon (may be a lower rank, e.g., species sample).
-    '''
-    logger.info(f"Getting sample for taxon: {pred_taxon} at rank: {rank}")
-    try:
-        filepath, eol_page_id, full_name, is_exact = get_sample_data(df, pred_taxon, rank)
-    except Exception as e:
-        logger.error(f"Error retrieving sample data: {e}")
-        return None, f"We encountered the following error trying to retrieve a sample image: {e}."
-    if filepath is None:
-        logger.warning(f"No sample image found for taxon: {pred_taxon}")
-        return None, f"Sorry, our EOL images do not include {pred_taxon}."
-
-    # Get sample image of selected individual
-    try:
-        img_src = s3_client.generate_presigned_url('get_object',
-                                                   Params={'Bucket': 'treeoflife-10m-sample-images',
-                                                           'Key': filepath}
-                                                   )
-        img_resp = requests.get(img_src)
-        img = Image.open(io.BytesIO(img_resp.content))
-        full_eol_url = EOL_URL + eol_page_id
-        if is_exact:
-            eol_page = f"<p>Check out the EOL entry for {pred_taxon} to learn more: <a href={full_eol_url} target='_blank'>{full_eol_url}</a>.</p>"
-        else:
-            eol_page = f"<p>Check out an example EOL entry within {pred_taxon} to learn more: {full_name} <a href={full_eol_url} target='_blank'>{full_eol_url}</a>.</p>"
-        logger.info(f"Successfully retrieved sample image and EOL page for {pred_taxon}")
-        return img, eol_page
-    except Exception as e:
-        logger.error(f"Error retrieving sample image: {e}")
-        return None, f"We encountered the following error trying to retrieve a sample image: {e}."
-
-def get_sample_data(df, pred_taxon, rank):
-    '''
-    Function to randomly select a sample individual of the given taxon and provide associated native location.
-
-    Parameters:
-    -----------
-    df : DataFrame
-        DataFrame with all sample images listed and their filepaths (in "file_path" column).
-    pred_taxon : str
-        Predicted taxon of the uploaded image.
-    rank : int
-        Index of rank in RANKS chosen for prediction.
-
-    Returns:
-    --------
-    filepath : str
-        Filepath of selected sample image for predicted taxon.
-    eol_page_id : str
-        EOL page ID associated with predicted taxon for more information.
-    full_name : str
-        Full taxonomic name of the selected sample.
-    is_exact : bool
-        Flag indicating if the match is exact (i.e., with empty lower ranks).
-    '''
-    for idx in range(rank + 1):
-        taxon = RANKS[idx]
-        target_taxon = pred_taxon.split(" ")[idx]
-        df = df.filter(pl.col(taxon) == target_taxon)
-
-    if df.shape[0] == 0:
-        return None, np.nan, "", False
-
-    # First, try to find entries with empty lower ranks
-    exact_df = df
-    for lower_rank in RANKS[rank + 1:]:
-        exact_df = exact_df.filter((pl.col(lower_rank).is_null()) | (pl.col(lower_rank) == ""))
-
-    if exact_df.shape[0] > 0:
-        df_filtered = exact_df.sample()
-        full_name = " ".join(df_filtered.select(RANKS[:rank+1]).row(0))
-        return df_filtered["file_path"][0], df_filtered["eol_page_id"].cast(pl.String)[0], full_name, True
-
-    # If no exact matches, return any entry with the specified rank
-    df_filtered = df.sample()
-    full_name = " ".join(df_filtered.select(RANKS[:rank+1]).row(0)) + " " + " ".join(df_filtered.select(RANKS[rank+1:]).row(0))
-    return df_filtered["file_path"][0], df_filtered["eol_page_id"].cast(pl.String)[0], full_name, False

components/sync_samples_to_s3.bash DELETED
@@ -1,34 +0,0 @@
-#!/bin/bash
-
-<<COMMENT
-Usage:
-    bash sync_samples_to_s3.bash <BASE_DIR>
-
-Dependencies:
-    - awscli (https://aws.amazon.com/cli/)
-Credentials to export as environment variables:
-    - AWS_ACCESS_KEY_ID
-    - AWS_SECRET_ACCESS_KEY
-COMMENT
-
-# Check if a valid directory is provided as an argument
-if [ -z "$1" ]; then
-    echo "Usage: $0 <BASE_DIR>"
-    exit 1
-fi
-
-if [ ! -d "$1" ]; then
-    echo "Error: $1 is not a valid directory"
-    exit 1
-fi
-
-BASE_DIR="$1"
-S3_BUCKET="s3://treeoflife-10m-sample-images"
-
-# Loop through all directories and sync them to S3
-for dir in $BASE_DIR/*; do
-    if [ -d "$dir" ]; then
-        dir_name=$(basename "$dir")
-        aws s3 sync "$dir" "$S3_BUCKET/$dir_name/"
-    fi
-done

requirements.txt CHANGED
@@ -1,7 +1,4 @@
 open_clip_torch
 torchvision
 torch
-gradio
-polars
-pillow
-boto3
+gradio