# MedSAMTest / app.py

import gradio as gr
import pandas as pd
import numpy as np
import pydicom
import os
from skimage import transform
import torch
from segment_anything import sam_model_registry
import matplotlib.pyplot as plt
from PIL import Image
import torch.nn.functional as F
import io
from gradio_image_prompter import ImagePrompter
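
# NOTE: assumed runtime dependencies for this Space (a sketch, versions not pinned):
#   pip install gradio gradio-image-prompter torch numpy pydicom scikit-image matplotlib pillow pandas
# plus the segment_anything package bundled with the MedSAM repository (it provides
# sam_model_registry) and a local "medsam_vit_b.pth" checkpoint next to app.py.
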
def load_image(file_path):
    """Load a DICOM or standard image file and return it as an RGB array plus its height and width."""
    if file_path.endswith(".dcm"):
        ds = pydicom.dcmread(file_path)
        img = ds.pixel_array
    else:
        img = np.array(Image.open(file_path))
    # Convert grayscale to 3-channel RGB by replicating channels
    if len(img.shape) == 2:  # Grayscale image (height, width)
        img = np.stack((img,) * 3, axis=-1)  # Replicate grayscale channel to get (height, width, 3)
    H, W = img.shape[:2]
    return img, H, W
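
# Hypothetical usage sketch for load_image (file names are illustrative; this helper
# is not called by the Gradio flow below, which gets its image from ImagePrompter):
#   img, H, W = load_image("slice_001.dcm")  # DICOM via pydicom
#   img, H, W = load_image("scan.png")       # anything PIL can open
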
@torch.no_grad()
def medsam_inference(medsam_model, img_embed, box_1024, H, W):
    """Run MedSAM's prompt encoder and mask decoder on a precomputed image embedding and a box prompt."""
    box_torch = torch.as_tensor(box_1024, dtype=torch.float, device=img_embed.device)
    if len(box_torch.shape) == 2:
        box_torch = box_torch[:, None, :]  # (B, 1, 4)
    box_torch = box_torch.reshape(1, 4)

    sparse_embeddings, dense_embeddings = medsam_model.prompt_encoder(
        points=None,
        boxes=box_torch,
        masks=None,
    )
    low_res_logits, _ = medsam_model.mask_decoder(
        image_embeddings=img_embed,  # (B, 256, 64, 64)
        image_pe=medsam_model.prompt_encoder.get_dense_pe(),  # (1, 256, 64, 64)
        sparse_prompt_embeddings=sparse_embeddings,  # (B, 2, 256)
        dense_prompt_embeddings=dense_embeddings,  # (B, 256, 64, 64)
        multimask_output=False,
    )

    low_res_pred = torch.sigmoid(low_res_logits)  # (1, 1, 256, 256)
    low_res_pred = F.interpolate(
        low_res_pred,
        size=(H, W),
        mode="bilinear",
        align_corners=False,
    )  # (1, 1, H, W)
    low_res_pred = low_res_pred.squeeze().cpu().numpy()  # (H, W)
    medsam_seg = (low_res_pred > 0.5).astype(np.uint8)  # Binary mask at the original resolution
    return medsam_seg
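
# Sketch of a standalone call to medsam_inference (assumes a loaded MedSAM model and a
# precomputed image embedding; variable names are illustrative, mirroring process_images below):
#   scale_factors = np.array([1024 / W, 1024 / H, 1024 / W, 1024 / H])
#   box_1024 = np.array([x_min, y_min, x_max, y_max]) * scale_factors
#   seg_mask = medsam_inference(medsam_model, img_embed, box_1024, H, W)  # (H, W) uint8 mask
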
# Function for visualizing images with masks
def visualize(image, mask, box):
    fig, ax = plt.subplots(1, 2, figsize=(10, 5))
    ax[0].imshow(image, cmap='gray')
    ax[0].add_patch(plt.Rectangle((box[0], box[1]), box[2] - box[0], box[3] - box[1], edgecolor="red", facecolor="none"))
    ax[1].imshow(image, cmap='gray')
    ax[1].imshow(mask, alpha=0.5, cmap="jet")
    plt.tight_layout()

    # Convert matplotlib figure to a PIL Image
    buf = io.BytesIO()
    fig.savefig(buf, format='png')
    plt.close(fig)  # Close the figure to release memory
    buf.seek(0)
    pil_img = Image.open(buf)
    return pil_img
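
# Note: visualize() returns a PIL.Image, matching the gr.Image(type="pil") output used below.
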
# Main function for the Gradio app
def process_images(img_dict):
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Load and preprocess the image; ImagePrompter supplies both the image and the drawn prompt
    img = img_dict['image']
    points = img_dict['points'][0]  # First (and possibly only) prompt; a box is encoded as 6 values
    if len(points) >= 6:
        # Entries 0, 1, 3, 4 hold the box corners; the remaining entries are not used here
        x_min, y_min, x_max, y_max = points[0], points[1], points[3], points[4]
    else:
        raise ValueError("Insufficient data for bounding box coordinates.")

    image, H, W = img, img.shape[0], img.shape[1]
    if len(image.shape) == 2:
        image = np.repeat(image[:, :, None], 3, axis=-1)
        H, W, _ = image.shape

    image_resized = transform.resize(image, (1024, 1024), order=3, preserve_range=True, anti_aliasing=True).astype(np.uint8)
    image_resized = (image_resized - image_resized.min()) / np.clip(image_resized.max() - image_resized.min(), a_min=1e-8, a_max=None)
    image_tensor = torch.tensor(image_resized).float().permute(2, 0, 1).unsqueeze(0).to(device)

    # Initialize the MedSAM model and move it to the selected device
    model_checkpoint_path = "medsam_vit_b.pth"  # Replace with the correct path to your checkpoint
    medsam_model = sam_model_registry['vit_b'](checkpoint=model_checkpoint_path)
    medsam_model = medsam_model.to(device)
    medsam_model.eval()

    # Generate image embedding
    with torch.no_grad():
        img_embed = medsam_model.image_encoder(image_tensor)

    # Rescale the box from original image coordinates to the 1024 x 1024 model input
    scale_factors = np.array([1024 / W, 1024 / H, 1024 / W, 1024 / H])
    box_1024 = np.array([x_min, y_min, x_max, y_max]) * scale_factors

    # Perform inference
    mask = medsam_inference(medsam_model, img_embed, box_1024, H, W)

    # Visualization
    visualization = visualize(image, mask, [x_min, y_min, x_max, y_max])
    return visualization
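
# Minimal sketch of the dict process_images expects from ImagePrompter (dummy values;
# only indices 0, 1, 3, 4 of the prompt row are read as x_min, y_min, x_max, y_max):
#   example_input = {
#       "image": np.zeros((256, 256), dtype=np.uint8),
#       "points": [[50.0, 60.0, 2.0, 180.0, 200.0, 3.0]],
#   }
#   result = process_images(example_input)  # returns a PIL.Image visualization
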
def echo(x_min, y_min, x_max, y_max):
    print(x_min, y_min, x_max, y_max)

# Set up Gradio interface
iface = gr.Interface(
    fn=process_images,
    inputs=[
        ImagePrompter(label="Select ROIs")  # Custom image prompter for selecting regions of interest
    ],
    outputs=[
        gr.Image(type="pil", label="Processed Image"),  # Image output
    ],
    title="Image Processing with Custom Prompts",
    description="Upload an image and select regions of interest for processing."
)

# Launch the interface
iface.launch()

# Earlier interface experiments, kept for reference:
'''iface = gr.Interface(fn=process_images,
                        inputs=[lambda prompts: (prompts["image"], prompts["points"]),
                                ImagePrompter(show_label=False)],
                        outputs="plot")'''

'''iface = gr.Interface(
    lambda prompts: (prompts["image"], prompts["points"]),
    ImagePrompter(show_label=False),
    [gr.Image(show_label=False), gr.Dataframe(label="Points")],
)
'''

'''gr.Interface(
    fn=process_images,
    inputs=[
        gr.File(label="MRI Slice (DICOM, PNG, etc.)"),
        gr.Number(label="X min"),
        gr.Number(label="Y min"),
        gr.Number(label="X max"),
        gr.Number(label="Y max")
    ],
    outputs="plot"
)'''