import gradio as gr
import pandas as pd
import numpy as np
import pydicom
import os
from skimage import transform
import torch
from segment_anything import sam_model_registry
import matplotlib.pyplot as plt
from PIL import Image
import io

# Function to load bounding boxes from CSV
def load_bounding_boxes(csv_file):
    # Assuming the CSV file has columns: 'filename', 'x_min', 'y_min', 'x_max', 'y_max'
    df = pd.read_csv(csv_file)
    return df
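
# Illustrative sketch (not part of the original app): the loader above assumes a CSV
# laid out roughly as follows, one pixel-coordinate box per image file
# (the filenames and values here are made up):
#
#   filename,x_min,y_min,x_max,y_max
#   slice_001.dcm,34,50,120,160
#   slice_002.png,12,8,96,88
#
# A row can then be turned into the [x_min, y_min, x_max, y_max] list used elsewhere:
#   row = load_bounding_boxes("boxes.csv").iloc[0]
#   box = [int(row['x_min']), int(row['y_min']), int(row['x_max']), int(row['y_max'])]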

# Load a DICOM or standard image file as a 2D grayscale array
def load_image(file_path):
    if file_path.endswith(".dcm"):
        ds = pydicom.dcmread(file_path)
        img = ds.pixel_array
    else:
        img = np.array(Image.open(file_path).convert('L'))  # Convert to grayscale
    H, W = img.shape
    return img, H, W

# MedSAM inference function
@torch.no_grad()
def medsam_inference(medsam_model, img, box, H, W, target_size):
    device = next(medsam_model.parameters()).device
    # Normalize intensities to [0, 1], then resize image and box to the target size
    img = (img - img.min()) / (np.ptp(img) + 1e-8)
    img_resized = transform.resize(img, (target_size, target_size), anti_aliasing=True)
    box_resized = np.array(box) * (target_size / np.array([W, H, W, H]))
    # Convert to a (1, 3, target_size, target_size) tensor; the ViT encoder expects 3 channels
    img_tensor = torch.from_numpy(img_resized).float().unsqueeze(0).unsqueeze(0).repeat(1, 3, 1, 1).to(device)
    # Box prompt in (x0, y0, x1, y1) format with a batch dimension
    box_tensor = torch.as_tensor(box_resized, dtype=torch.float32, device=device).unsqueeze(0)
    # MedSAM inference: encode the image, embed the box prompt, decode the mask
    img_embed = medsam_model.image_encoder(img_tensor)
    sparse_emb, dense_emb = medsam_model.prompt_encoder(points=None, boxes=box_tensor, masks=None)
    low_res_logits, _ = medsam_model.mask_decoder(
        image_embeddings=img_embed,
        image_pe=medsam_model.prompt_encoder.get_dense_pe(),
        sparse_prompt_embeddings=sparse_emb,
        dense_prompt_embeddings=dense_emb,
        multimask_output=False,
    )
    # Post-process: sigmoid probabilities -> binary mask, resized back to the original (H, W)
    mask = torch.sigmoid(low_res_logits).squeeze().cpu().numpy()
    mask_resized = transform.resize(mask, (H, W), order=1, anti_aliasing=False) > 0.5
    return mask_resized.astype(np.uint8)
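
# Rough usage sketch for the helper above (paths and box values are illustrative;
# assumes a MedSAM ViT-B checkpoint saved locally as "medsam_vit_b.pth"):
#   img, H, W = load_image("slice_001.dcm")
#   model = sam_model_registry['vit_b'](checkpoint="medsam_vit_b.pth").eval()
#   mask = medsam_inference(model, img, [30, 40, 200, 220], H, W, 1024)  # -> (H, W) binary mask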

# Function for visualizing images with masks
def visualize(image, mask, box):
    fig, ax = plt.subplots(1, 2, figsize=(10, 5))
    ax[0].imshow(image, cmap='gray')
    ax[0].add_patch(plt.Rectangle((box[0], box[1]), box[2] - box[0], box[3] - box[1],
                                  edgecolor="red", facecolor="none"))
    ax[1].imshow(image, cmap='gray')
    ax[1].imshow(mask, alpha=0.5, cmap="jet")
    plt.tight_layout()
    buf = io.BytesIO()
    plt.savefig(buf, format='png')
    plt.close(fig)
    buf.seek(0)
    return buf
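
# The buffer returned above holds PNG bytes; it can be reopened as a PIL image for
# display (this is how process_images below consumes it), e.g.:
#   overlay = Image.open(visualize(img, mask, box))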

# Main function for the Gradio app
def process_images(file, x_min, y_min, x_max, y_max):
    # gr.File may hand back a tempfile-like object or a plain path string
    file_path = file.name if hasattr(file, "name") else file
    image, H, W = load_image(file_path)
    # Initialize the MedSAM model
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    medsam_model = sam_model_registry['vit_b'](checkpoint="medsam_vit_b.pth")  # Ensure the checkpoint path is correct
    medsam_model = medsam_model.to(device)
    medsam_model.eval()
    box = [x_min, y_min, x_max, y_max]
    # SAM's ViT-B image encoder expects 1024x1024 input
    mask = medsam_inference(medsam_model, image, box, H, W, 1024)
    visualization = visualize(image, mask, box)
    return Image.open(visualization)  # Return a PIL image for the Gradio Image output
# Set up Gradio interface
iface = gr.Interface(
fn=process_images,
inputs=[
gr.File(label="MRI Slice (DICOM, PNG, etc.)"),
gr.Number(label="X min"),
gr.Number(label="Y min"),
gr.Number(label="X max"),
gr.Number(label="Y max")
],
outputs="plot"
)
iface.launch()
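
# Optional variant (not part of the original app): when hosting on a remote machine
# or a Hugging Face Space, one might bind to all interfaces explicitly, e.g.:
#   iface.launch(server_name="0.0.0.0", server_port=7860)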