Thiago Hersan commited on
Commit
3b5780d
1 Parent(s): 161d4c1

fixed cached examples

Browse files
Files changed (2) hide show
  1. app.py +6 -7
  2. requirements.txt +1 -0
app.py CHANGED
@@ -2,7 +2,6 @@ import glob
2
  import gradio as gr
3
  import numpy as np
4
  from os import environ
5
- from pathlib import Path
6
  from PIL import Image
7
  from torchvision import transforms as T
8
  from transformers import MaskFormerForInstanceSegmentation, MaskFormerImageProcessor
@@ -51,7 +50,7 @@ def visualize_instance_seg_mask(img_in, mask, id2label):
51
  img_out[i, j, :] = id2color[mask[i, j]]
52
  id2count[mask[i, j]] = id2count[mask[i, j]] + 1
53
 
54
- image_res = (0.5 * img_in + 0.5 * img_out) / 255
55
 
56
  dataframe = [[
57
  f"{id2label[id]}",
@@ -70,19 +69,19 @@ def visualize_instance_seg_mask(img_in, mask, id2label):
70
 
71
 
72
  def query_image(image_path):
73
- img = np.array(Image.open(Path(image_path)))
74
  img_size = (img.shape[0], img.shape[1])
75
  inputs = preprocessor(images=test_transform(img), return_tensors="pt")
76
 
77
  outputs = model(**inputs)
78
 
79
  results = preprocessor.post_process_semantic_segmentation(outputs=outputs, target_sizes=[img_size])[0]
80
- results = visualize_instance_seg_mask(img, results.numpy(), model.config.id2label)
81
- return results
82
 
83
 
84
  demo = gr.Interface(
85
- query_image,
86
  inputs=[gr.Image(type="filepath", label="Input Image")],
87
  outputs=[
88
  gr.Image(label="Vegetation"),
@@ -92,7 +91,7 @@ demo = gr.Interface(
92
  allow_flagging="never",
93
  analytics_enabled=None,
94
  examples=example_images,
95
- cache_examples=False
96
  )
97
 
98
  demo.launch(show_api=False)
 
2
  import gradio as gr
3
  import numpy as np
4
  from os import environ
 
5
  from PIL import Image
6
  from torchvision import transforms as T
7
  from transformers import MaskFormerForInstanceSegmentation, MaskFormerImageProcessor
 
50
  img_out[i, j, :] = id2color[mask[i, j]]
51
  id2count[mask[i, j]] = id2count[mask[i, j]] + 1
52
 
53
+ image_res = (0.5 * img_in + 0.5 * img_out).astype(np.uint8)
54
 
55
  dataframe = [[
56
  f"{id2label[id]}",
 
69
 
70
 
71
  def query_image(image_path):
72
+ img = np.array(Image.open(image_path))
73
  img_size = (img.shape[0], img.shape[1])
74
  inputs = preprocessor(images=test_transform(img), return_tensors="pt")
75
 
76
  outputs = model(**inputs)
77
 
78
  results = preprocessor.post_process_semantic_segmentation(outputs=outputs, target_sizes=[img_size])[0]
79
+ mask_img, dataframe = visualize_instance_seg_mask(img, results.numpy(), model.config.id2label)
80
+ return mask_img, dataframe
81
 
82
 
83
  demo = gr.Interface(
84
+ fn=query_image,
85
  inputs=[gr.Image(type="filepath", label="Input Image")],
86
  outputs=[
87
  gr.Image(label="Vegetation"),
 
91
  allow_flagging="never",
92
  analytics_enabled=None,
93
  examples=example_images,
94
+ cache_examples=True
95
  )
96
 
97
  demo.launch(show_api=False)
requirements.txt CHANGED
@@ -1,3 +1,4 @@
 
1
  Pillow
2
  scipy
3
  torch
 
1
+ gradio==3.14
2
  Pillow
3
  scipy
4
  torch