dennistrujillo committed
Commit 16fa719 · 1 Parent(s): 4b69b97

fixed it; box_tensor needed to be resized and the plot function was wonky
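In short: the box prompt reaching medsam_model.prompt_encoder carried an extra singleton dimension, and visualize() returned a PNG byte buffer rather than something the Gradio output could render directly. A minimal standalone sketch of the shape fix, using a made-up box (only torch is assumed):

    import torch

    # A single xyxy box, already scaled to the 1024x1024 model input space.
    box_1024 = [[100.0, 120.0, 300.0, 340.0]]
    box_torch = torch.as_tensor(box_1024, dtype=torch.float)
    if len(box_torch.shape) == 2:
        box_torch = box_torch[:, None, :]  # (B, 1, 4), the shape before this commit
    box_torch = box_torch.reshape(1, 4)    # the fix: flatten to the (1, 4) the encoder accepts
    assert box_torch.shape == (1, 4)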

Files changed (1): app.py +10 -7
app.py CHANGED
@@ -8,6 +8,7 @@ import torch
 from segment_anything import sam_model_registry
 import matplotlib.pyplot as plt
 from PIL import Image
+import torch.nn.functional as F
 import io

 def load_image(file_path):
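The new torch.nn.functional import is not used in any hunk shown here; presumably it backs the usual MedSAM step of upsampling the low-res mask logits back to the original (H, W). A hedged sketch of that pattern (the logits tensor below is a stand-in, not the app's actual decoder output):

    import numpy as np
    import torch
    import torch.nn.functional as F

    low_res_logits = torch.randn(1, 1, 256, 256)  # stand-in for the mask decoder output
    H, W = 480, 640                               # original image size
    low_res_pred = torch.sigmoid(low_res_logits)
    upsampled = F.interpolate(low_res_pred, size=(H, W), mode="bilinear", align_corners=False)
    mask = (upsampled > 0.5).squeeze().cpu().numpy().astype(np.uint8)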
@@ -24,11 +25,13 @@ def load_image(file_path):
     H, W = img.shape[:2]
     return img, H, W

+@torch.no_grad()
 def medsam_inference(medsam_model, img_embed, box_1024, H, W):
     box_torch = torch.as_tensor(box_1024, dtype=torch.float, device=img_embed.device)
     if len(box_torch.shape) == 2:
         box_torch = box_torch[:, None, :]  # (B, 1, 4)

+    box_torch = box_torch.reshape(1, 4)
     sparse_embeddings, dense_embeddings = medsam_model.prompt_encoder(
         points=None,
         boxes=box_torch,
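The @torch.no_grad() decorator added above is a standard inference-time measure: it disables autograd tracking for everything inside the call, cutting memory use and avoiding accidental graph building. A minimal illustration (the model here is a placeholder, not the MedSAM checkpoint):

    import torch
    import torch.nn as nn

    model = nn.Linear(4, 2)  # placeholder model

    @torch.no_grad()
    def run_inference(x):
        return model(x)      # no gradients are recorded inside this call

    out = run_inference(torch.randn(1, 4))
    assert not out.requires_grad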
@@ -67,11 +70,7 @@ def visualize(image, mask, box):
     ax[1].imshow(image, cmap='gray')
     ax[1].imshow(mask, alpha=0.5, cmap="jet")
     plt.tight_layout()
-    buf = io.BytesIO()
-    plt.savefig(buf, format='png')
-    plt.close(fig)
-    buf.seek(0)
-    return buf
+    return fig

 # Main function for Gradio app
 def process_images(file, x_min, y_min, x_max, y_max):
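visualize() now hands back the matplotlib Figure itself instead of serializing it to a PNG buffer, which is what a Gradio gr.Plot output expects. A hedged sketch of the wiring (the interface below is an assumption for illustration; the app's actual component setup is not part of this diff):

    import gradio as gr
    import matplotlib.pyplot as plt

    def process_images(file, x_min, y_min, x_max, y_max):
        # Stub standing in for the app's real pipeline.
        fig, ax = plt.subplots()
        ax.plot([x_min, x_max], [y_min, y_max])
        return fig           # a Figure, not a BytesIO buffer

    demo = gr.Interface(
        fn=process_images,
        inputs=[gr.File(), gr.Number(), gr.Number(), gr.Number(), gr.Number()],
        outputs=gr.Plot(),   # gr.Plot renders matplotlib Figures directly
    )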
@@ -79,7 +78,11 @@ def process_images(file, x_min, y_min, x_max, y_max):

     # Load and preprocess image
     image, H, W = load_image(file)
-    image_resized = transform.resize(image, (1024, 1024), anti_aliasing=True)
+    if len(image.shape) == 2:
+        image = np.repeat(image[:, :, None], 3, axis=-1)
+    H, W, _ = image.shape
+
+    image_resized = transform.resize(image, (1024, 1024), order=3, preserve_range=True, anti_aliasing=True).astype(np.uint8)
     image_resized = (image_resized - image_resized.min()) / np.clip(image_resized.max() - image_resized.min(), a_min=1e-8, a_max=None)
     image_tensor = torch.tensor(image_resized).float().permute(2, 0, 1).unsqueeze(0).to(device)
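Two preprocessing changes land in the hunk above: grayscale inputs are tiled to three channels (SAM-style image encoders expect RGB), and skimage's resize is pinned to bicubic with preserve_range=True so values stay in 0..255 instead of being silently rescaled to [0, 1] before the explicit min-max normalization. A standalone sketch of the same path on a synthetic uint8 image:

    import numpy as np
    from skimage import transform

    image = np.random.randint(0, 256, (512, 512), dtype=np.uint8)  # synthetic grayscale input
    if len(image.shape) == 2:
        image = np.repeat(image[:, :, None], 3, axis=-1)           # grayscale -> 3 channels
    H, W, _ = image.shape
    image_resized = transform.resize(
        image, (1024, 1024),
        order=3,              # bicubic interpolation
        preserve_range=True,  # keep the 0..255 range; the default rescales to [0, 1]
        anti_aliasing=True,
    ).astype(np.uint8)
    image_norm = (image_resized - image_resized.min()) / np.clip(
        image_resized.max() - image_resized.min(), a_min=1e-8, a_max=None
    )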
 
@@ -102,7 +105,7 @@ def process_images(file, x_min, y_min, x_max, y_max):

     # Visualization
     visualization = visualize(image, mask, [x_min, y_min, x_max, y_max])
-    return visualization.getvalue()
+    return visualization

 # Set up Gradio interface
 