Spaces:
Running
Running
vincentgao95
commited on
Change prompt from a bounding box to point and click
Browse files
app.py
CHANGED
@@ -30,25 +30,23 @@ def load_image(file_path):
|
|
30 |
return img, H, W
|
31 |
|
32 |
@torch.no_grad()
|
33 |
-
def medsam_inference(medsam_model, img_embed,
|
34 |
-
|
35 |
-
|
36 |
-
box_torch = box_torch[:, None, :] # (B, 1, 4)
|
37 |
|
38 |
-
box_torch=box_torch.reshape(1,4)
|
39 |
sparse_embeddings, dense_embeddings = medsam_model.prompt_encoder(
|
40 |
-
points=
|
41 |
-
boxes=
|
42 |
masks=None,
|
43 |
)
|
44 |
|
45 |
low_res_logits, _ = medsam_model.mask_decoder(
|
46 |
-
image_embeddings=img_embed,
|
47 |
-
image_pe=medsam_model.prompt_encoder.get_dense_pe(),
|
48 |
-
sparse_prompt_embeddings=sparse_embeddings,
|
49 |
-
dense_prompt_embeddings=dense_embeddings,
|
50 |
multimask_output=False,
|
51 |
-
|
52 |
|
53 |
low_res_pred = torch.sigmoid(low_res_logits) # (1, 1, 256, 256)
|
54 |
|
@@ -58,15 +56,16 @@ def medsam_inference(medsam_model, img_embed, box_1024, H, W):
|
|
58 |
mode="bilinear",
|
59 |
align_corners=False,
|
60 |
) # (1, 1, gt.shape)
|
61 |
-
low_res_pred = low_res_pred.squeeze().cpu().numpy() # (
|
62 |
medsam_seg = (low_res_pred > 0.5).astype(np.uint8)
|
63 |
return medsam_seg
|
64 |
|
65 |
# Function for visualizing images with masks
|
66 |
-
def visualize(image, mask,
|
67 |
fig, ax = plt.subplots(1, 2, figsize=(10, 5))
|
68 |
ax[0].imshow(image, cmap='gray')
|
69 |
-
|
|
|
70 |
ax[1].imshow(image, cmap='gray')
|
71 |
ax[1].imshow(mask, alpha=0.5, cmap="jet")
|
72 |
plt.tight_layout()
|
@@ -78,19 +77,18 @@ def visualize(image, mask, box):
|
|
78 |
buf.seek(0)
|
79 |
pil_img = Image.open(buf)
|
80 |
|
81 |
-
return pil_img
|
|
|
82 |
# Main function for Gradio app
|
83 |
def process_images(img_dict):
|
84 |
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
85 |
|
86 |
# Load and preprocess image
|
87 |
-
print(img_dict)
|
88 |
img = img_dict['image']
|
89 |
-
points = img_dict['points']
|
90 |
-
if len(points)
|
91 |
-
|
92 |
-
|
93 |
-
raise ValueError("Insufficient data for bounding box coordinates.")
|
94 |
image, H, W = img, img.shape[0], img.shape[1]
|
95 |
if len(image.shape) == 2:
|
96 |
image = np.repeat(image[:, :, None], 3, axis=-1)
|
@@ -106,20 +104,17 @@ def process_images(img_dict):
|
|
106 |
medsam_model = medsam_model.to(device)
|
107 |
medsam_model.eval()
|
108 |
|
109 |
-
#
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
# Calculate resized box coordinates
|
114 |
-
scale_factors = np.array([1024 / W, 1024 / H, 1024 / W, 1024 / H])
|
115 |
-
box_1024 = np.array([x_min, y_min, x_max, y_max]) * scale_factors
|
116 |
|
117 |
# Perform inference
|
118 |
-
mask = medsam_inference(medsam_model, img_embed,
|
119 |
|
120 |
# Visualization
|
121 |
-
visualization = visualize(image, mask,
|
122 |
return visualization
|
|
|
123 |
# Set up Gradio interface
|
124 |
iface = gr.Interface(
|
125 |
fn=process_images,
|
@@ -130,7 +125,7 @@ iface = gr.Interface(
|
|
130 |
gr.Image(type="pil", label="Processed Image")
|
131 |
],
|
132 |
title="ROI Selection with MEDSAM",
|
133 |
-
description="Upload an image (including NRRD files) and select
|
134 |
)
|
135 |
|
136 |
# Launch the interface
|
|
|
30 |
return img, H, W
|
31 |
|
32 |
@torch.no_grad()
|
33 |
+
def medsam_inference(medsam_model, img_embed, points_1024, H, W):
|
34 |
+
points_torch = torch.as_tensor(points_1024, dtype=torch.float, device=img_embed.device)
|
35 |
+
points_torch = points_torch.reshape(1, -1, 2) # (1, N, 2)
|
|
|
36 |
|
|
|
37 |
sparse_embeddings, dense_embeddings = medsam_model.prompt_encoder(
|
38 |
+
points=points_torch,
|
39 |
+
boxes=None,
|
40 |
masks=None,
|
41 |
)
|
42 |
|
43 |
low_res_logits, _ = medsam_model.mask_decoder(
|
44 |
+
image_embeddings=img_embed, # (B, 256, 64, 64)
|
45 |
+
image_pe=medsam_model.prompt_encoder.get_dense_pe(), # (1, 256, 64, 64)
|
46 |
+
sparse_prompt_embeddings=sparse_embeddings, # (B, 2, 256)
|
47 |
+
dense_prompt_embeddings=dense_embeddings, # (B, 256, 64, 64)
|
48 |
multimask_output=False,
|
49 |
+
)
|
50 |
|
51 |
low_res_pred = torch.sigmoid(low_res_logits) # (1, 1, 256, 256)
|
52 |
|
|
|
56 |
mode="bilinear",
|
57 |
align_corners=False,
|
58 |
) # (1, 1, gt.shape)
|
59 |
+
low_res_pred = low_res_pred.squeeze().cpu().numpy() # (H, W)
|
60 |
medsam_seg = (low_res_pred > 0.5).astype(np.uint8)
|
61 |
return medsam_seg
|
62 |
|
63 |
# Function for visualizing images with masks
|
64 |
+
def visualize(image, mask, points):
|
65 |
fig, ax = plt.subplots(1, 2, figsize=(10, 5))
|
66 |
ax[0].imshow(image, cmap='gray')
|
67 |
+
for point in points:
|
68 |
+
ax[0].plot(point[0], point[1], 'ro') # Mark points on the image
|
69 |
ax[1].imshow(image, cmap='gray')
|
70 |
ax[1].imshow(mask, alpha=0.5, cmap="jet")
|
71 |
plt.tight_layout()
|
|
|
77 |
buf.seek(0)
|
78 |
pil_img = Image.open(buf)
|
79 |
|
80 |
+
return pil_img
|
81 |
+
|
82 |
# Main function for Gradio app
|
83 |
def process_images(img_dict):
|
84 |
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
85 |
|
86 |
# Load and preprocess image
|
|
|
87 |
img = img_dict['image']
|
88 |
+
points = img_dict['points']
|
89 |
+
if len(points) == 0:
|
90 |
+
raise ValueError("No points provided.")
|
91 |
+
|
|
|
92 |
image, H, W = img, img.shape[0], img.shape[1]
|
93 |
if len(image.shape) == 2:
|
94 |
image = np.repeat(image[:, :, None], 3, axis=-1)
|
|
|
104 |
medsam_model = medsam_model.to(device)
|
105 |
medsam_model.eval()
|
106 |
|
107 |
+
# Calculate resized point coordinates
|
108 |
+
scale_factors = np.array([1024 / W, 1024 / H])
|
109 |
+
points_1024 = np.array(points) * scale_factors
|
|
|
|
|
|
|
|
|
110 |
|
111 |
# Perform inference
|
112 |
+
mask = medsam_inference(medsam_model, img_embed, points_1024, H, W)
|
113 |
|
114 |
# Visualization
|
115 |
+
visualization = visualize(image, mask, points)
|
116 |
return visualization
|
117 |
+
|
118 |
# Set up Gradio interface
|
119 |
iface = gr.Interface(
|
120 |
fn=process_images,
|
|
|
125 |
gr.Image(type="pil", label="Processed Image")
|
126 |
],
|
127 |
title="ROI Selection with MEDSAM",
|
128 |
+
description="Upload an image (including NRRD files) and select points of interest for processing."
|
129 |
)
|
130 |
|
131 |
# Launch the interface
|