Spaces:
Running
Running
dennistrujillo
committed on
Commit
·
16fa719
1
Parent(s):
4b69b97
Fixed it: the box tensor (box_torch) needed to be reshaped to (1, 4), and the plot function was wonky (it now returns the figure directly).
Browse files
app.py
CHANGED
@@ -8,6 +8,7 @@ import torch
|
|
8 |
from segment_anything import sam_model_registry
|
9 |
import matplotlib.pyplot as plt
|
10 |
from PIL import Image
|
|
|
11 |
import io
|
12 |
|
13 |
def load_image(file_path):
|
@@ -24,11 +25,13 @@ def load_image(file_path):
|
|
24 |
H, W = img.shape[:2]
|
25 |
return img, H, W
|
26 |
|
|
|
27 |
def medsam_inference(medsam_model, img_embed, box_1024, H, W):
|
28 |
box_torch = torch.as_tensor(box_1024, dtype=torch.float, device=img_embed.device)
|
29 |
if len(box_torch.shape) == 2:
|
30 |
box_torch = box_torch[:, None, :] # (B, 1, 4)
|
31 |
|
|
|
32 |
sparse_embeddings, dense_embeddings = medsam_model.prompt_encoder(
|
33 |
points=None,
|
34 |
boxes=box_torch,
|
@@ -67,11 +70,7 @@ def visualize(image, mask, box):
|
|
67 |
ax[1].imshow(image, cmap='gray')
|
68 |
ax[1].imshow(mask, alpha=0.5, cmap="jet")
|
69 |
plt.tight_layout()
|
70 |
-
|
71 |
-
plt.savefig(buf, format='png')
|
72 |
-
plt.close(fig)
|
73 |
-
buf.seek(0)
|
74 |
-
return buf
|
75 |
|
76 |
# Main function for Gradio app
|
77 |
def process_images(file, x_min, y_min, x_max, y_max):
|
@@ -79,7 +78,11 @@ def process_images(file, x_min, y_min, x_max, y_max):
|
|
79 |
|
80 |
# Load and preprocess image
|
81 |
image, H, W = load_image(file)
|
82 |
-
|
|
|
|
|
|
|
|
|
83 |
image_resized = (image_resized - image_resized.min()) / np.clip(image_resized.max() - image_resized.min(), a_min=1e-8, a_max=None)
|
84 |
image_tensor = torch.tensor(image_resized).float().permute(2, 0, 1).unsqueeze(0).to(device)
|
85 |
|
@@ -102,7 +105,7 @@ def process_images(file, x_min, y_min, x_max, y_max):
|
|
102 |
|
103 |
# Visualization
|
104 |
visualization = visualize(image, mask, [x_min, y_min, x_max, y_max])
|
105 |
-
return visualization
|
106 |
|
107 |
|
108 |
# Set up Gradio interface
|
|
|
8 |
from segment_anything import sam_model_registry
|
9 |
import matplotlib.pyplot as plt
|
10 |
from PIL import Image
|
11 |
+
import torch.nn.functional as F
|
12 |
import io
|
13 |
|
14 |
def load_image(file_path):
|
|
|
25 |
H, W = img.shape[:2]
|
26 |
return img, H, W
|
27 |
|
28 |
+
@torch.no_grad()
|
29 |
def medsam_inference(medsam_model, img_embed, box_1024, H, W):
|
30 |
box_torch = torch.as_tensor(box_1024, dtype=torch.float, device=img_embed.device)
|
31 |
if len(box_torch.shape) == 2:
|
32 |
box_torch = box_torch[:, None, :] # (B, 1, 4)
|
33 |
|
34 |
+
box_torch=box_torch.reshape(1,4)
|
35 |
sparse_embeddings, dense_embeddings = medsam_model.prompt_encoder(
|
36 |
points=None,
|
37 |
boxes=box_torch,
|
|
|
70 |
ax[1].imshow(image, cmap='gray')
|
71 |
ax[1].imshow(mask, alpha=0.5, cmap="jet")
|
72 |
plt.tight_layout()
|
73 |
+
return fig
|
|
|
|
|
|
|
|
|
74 |
|
75 |
# Main function for Gradio app
|
76 |
def process_images(file, x_min, y_min, x_max, y_max):
|
|
|
78 |
|
79 |
# Load and preprocess image
|
80 |
image, H, W = load_image(file)
|
81 |
+
if len(image.shape) == 2:
|
82 |
+
image = np.repeat(image[:, :, None], 3, axis=-1)
|
83 |
+
H, W, _ = image.shape
|
84 |
+
|
85 |
+
image_resized = transform.resize(image, (1024, 1024), order=3, preserve_range=True, anti_aliasing=True).astype(np.uint8)
|
86 |
image_resized = (image_resized - image_resized.min()) / np.clip(image_resized.max() - image_resized.min(), a_min=1e-8, a_max=None)
|
87 |
image_tensor = torch.tensor(image_resized).float().permute(2, 0, 1).unsqueeze(0).to(device)
|
88 |
|
|
|
105 |
|
106 |
# Visualization
|
107 |
visualization = visualize(image, mask, [x_min, y_min, x_max, y_max])
|
108 |
+
return visualization #.getvalue()
|
109 |
|
110 |
|
111 |
# Set up Gradio interface
|