# MedSAMTest / app.py
import gradio as gr
import pandas as pd
import numpy as np
import pydicom
import os
from skimage import transform
import torch
from segment_anything import sam_model_registry
import matplotlib.pyplot as plt
from PIL import Image
import torch.nn.functional as F
import io
from gradio_image_prompter import ImagePrompter
import nrrd  # NRRD file support

def load_image(file_path):
    if file_path.endswith(".dcm"):
        ds = pydicom.dcmread(file_path)
        img = ds.pixel_array
    elif file_path.endswith(".nrrd"):
        img, _ = nrrd.read(file_path)  # nrrd.read returns (data, header)
    else:
        img = np.array(Image.open(file_path))

    # Convert grayscale to 3-channel RGB by replicating the single channel
    if len(img.shape) == 2:  # Grayscale image (height, width)
        img = np.stack((img,) * 3, axis=-1)  # (height, width, 3)

    H, W = img.shape[:2]
    return img, H, W
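
# Illustrative usage of load_image (hypothetical file name), assuming a single 2D slice:
#   img, H, W = load_image("ct_slice_001.dcm")
#   assert img.shape == (H, W, 3)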

@torch.no_grad()
def medsam_inference(medsam_model, img_embed, points_1024, H, W):
    points_torch = torch.as_tensor(points_1024, dtype=torch.float, device=img_embed.device)
    points_torch = points_torch.reshape(1, -1, 2)  # (1, N, 2)
    # The prompt encoder expects (coords, labels); label 1 marks a foreground click
    labels_torch = torch.ones(points_torch.shape[:2], dtype=torch.int, device=img_embed.device)  # (1, N)
    sparse_embeddings, dense_embeddings = medsam_model.prompt_encoder(
        points=(points_torch, labels_torch),
        boxes=None,
        masks=None,
    )
    low_res_logits, _ = medsam_model.mask_decoder(
        image_embeddings=img_embed,  # (B, 256, 64, 64)
        image_pe=medsam_model.prompt_encoder.get_dense_pe(),  # (1, 256, 64, 64)
        sparse_prompt_embeddings=sparse_embeddings,  # (B, N+1, 256) point prompt embeddings
        dense_prompt_embeddings=dense_embeddings,  # (B, 256, 64, 64)
        multimask_output=False,
    )
    low_res_pred = torch.sigmoid(low_res_logits)  # (1, 1, 256, 256)
    low_res_pred = F.interpolate(
        low_res_pred,
        size=(H, W),
        mode="bilinear",
        align_corners=False,
    )  # (1, 1, H, W)
    low_res_pred = low_res_pred.squeeze().cpu().numpy()  # (H, W)
    medsam_seg = (low_res_pred > 0.5).astype(np.uint8)
    return medsam_seg
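
# Minimal usage sketch (illustrative names; assumes a loaded MedSAM model and a
# (1, 3, 1024, 1024) float tensor on the same device, as prepared in process_images):
#   with torch.no_grad():
#       img_embed = medsam_model.image_encoder(image_tensor)  # (1, 256, 64, 64)
#   mask = medsam_inference(medsam_model, img_embed, points_1024, H, W)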

# Function for visualizing the image with the selected points and predicted mask
def visualize(image, mask, points):
    fig, ax = plt.subplots(1, 2, figsize=(10, 5))
    ax[0].imshow(image, cmap='gray')
    for point in points:
        ax[0].plot(point[0], point[1], 'ro')  # Mark points on the image
    ax[1].imshow(image, cmap='gray')
    ax[1].imshow(mask, alpha=0.5, cmap="jet")
    plt.tight_layout()

    # Convert matplotlib figure to a PIL Image
    buf = io.BytesIO()
    fig.savefig(buf, format='png')
    plt.close(fig)  # Close the figure to release memory
    buf.seek(0)
    pil_img = Image.open(buf)
    return pil_img
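
# Note: process_images below expects img_dict to be a dict with an "image" array and
# a "points" list; each point is assumed to begin with its (x, y) pixel coordinates,
# and any extra fields gradio_image_prompter may attach to a point are ignored.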

# Main function for the Gradio app
def process_images(img_dict):
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Unpack the ImagePrompter payload: the raw image and the clicked points
    img = img_dict['image']
    points = img_dict['points']
    if len(points) == 0:
        raise ValueError("No points provided.")

    image, H, W = img, img.shape[0], img.shape[1]
    if len(image.shape) == 2:
        image = np.repeat(image[:, :, None], 3, axis=-1)
    H, W, _ = image.shape

    # Resize to the 1024x1024 input expected by MedSAM and normalize to [0, 1]
    image_resized = transform.resize(image, (1024, 1024), order=3, preserve_range=True, anti_aliasing=True).astype(np.uint8)
    image_resized = (image_resized - image_resized.min()) / np.clip(image_resized.max() - image_resized.min(), a_min=1e-8, a_max=None)
    image_tensor = torch.tensor(image_resized).float().permute(2, 0, 1).unsqueeze(0).to(device)

    # Initialize the MedSAM model and set the device
    model_checkpoint_path = "medsam_vit_b.pth"  # Replace with the correct path to your checkpoint
    medsam_model = sam_model_registry['vit_b'](checkpoint=model_checkpoint_path)
    medsam_model = medsam_model.to(device)
    medsam_model.eval()

    # Compute the image embedding once; medsam_inference reuses it for the point prompts
    with torch.no_grad():
        img_embed = medsam_model.image_encoder(image_tensor)  # (1, 256, 64, 64)

    # Scale point coordinates from the original image size to the 1024x1024 grid.
    # Only the (x, y) columns are used; any extra per-point fields are dropped.
    points_xy = np.array(points, dtype=float)[:, :2]
    scale_factors = np.array([1024 / W, 1024 / H])
    points_1024 = points_xy * scale_factors

    # Perform inference
    mask = medsam_inference(medsam_model, img_embed, points_1024, H, W)

    # Visualization
    visualization = visualize(image, mask, points_xy)
    return visualization

# Set up Gradio interface
iface = gr.Interface(
    fn=process_images,
    inputs=[
        ImagePrompter(label="Image")
    ],
    outputs=[
        gr.Image(type="pil", label="Processed Image")
    ],
    title="ROI Selection with MedSAM",
    description="Upload an image (including NRRD files) and select points of interest for processing."
)

# Launch the interface
iface.launch()