"""Gradio demo: segment a boxed region of an MRI slice with MedSAM."""

import functools
import io
import os

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pydicom
import torch
import torch.nn.functional as F
from matplotlib.figure import Figure
from matplotlib.patches import Rectangle
from PIL import Image
from segment_anything import sam_model_registry
from skimage import transform

# Path to the MedSAM ViT-B weights (a state dict). Ensure the correct path.
MEDSAM_CHECKPOINT = "medsam_vit_b.pth"


# Function to load bounding boxes from CSV
def load_bounding_boxes(csv_file):
    """Read bounding-box annotations from a CSV file into a DataFrame.

    The CSV is assumed (not validated here) to have the columns
    'filename', 'x_min', 'y_min', 'x_max', 'y_max'.
    """
    return pd.read_csv(csv_file)


def load_image(file_path):
    """Load a single 2-D grayscale slice from a DICOM (``.dcm``) or PIL-readable file.

    Args:
        file_path: path (str or os.PathLike) to the image file.

    Returns:
        (img, H, W): the pixel array and its height and width.

    Raises:
        ValueError: if the decoded pixel array is not 2-D (e.g. multi-frame or RGB DICOM).
    """
    path = os.fspath(file_path)
    if path.lower().endswith(".dcm"):
        img = pydicom.dcmread(path).pixel_array
    else:
        img = np.array(Image.open(path).convert("L"))  # Convert to grayscale
    if img.ndim != 2:
        raise ValueError(f"Expected a single 2-D slice, got an array of shape {img.shape}")
    H, W = img.shape
    return img, H, W


@functools.lru_cache(maxsize=2)
def _load_medsam_model(checkpoint, device):
    """Build the MedSAM ViT-B model once per (checkpoint, device) and reuse it."""
    model = sam_model_registry["vit_b"]()
    # map_location lets weights saved on a GPU be loaded on a CPU-only machine.
    state_dict = torch.load(checkpoint, map_location=device)
    model.load_state_dict(state_dict)
    return model.to(device).eval()


# MedSAM inference function
@torch.no_grad()
def medsam_inference(medsam_model, img, box, H, W, target_size=None):
    """Segment the object inside ``box`` on ``img`` with a MedSAM (SAM) model.

    Args:
        medsam_model: segment_anything ``Sam`` model loaded with MedSAM weights.
        img: 2-D grayscale image of shape (H, W).
        box: [x_min, y_min, x_max, y_max] in original pixel coordinates.
        H, W: original image height and width.
        target_size: side length of the square encoder input. Defaults to the
            image encoder's ``img_size`` (1024 for ViT-B); SAM's positional
            embedding is fixed, so other sizes will not work with the encoder.

    Returns:
        Binary ``uint8`` mask of shape (H, W).
    """
    if target_size is None:
        target_size = medsam_model.image_encoder.img_size
    # Run on whatever device the model's weights live on.
    device = next(medsam_model.parameters()).device

    # Resize to the encoder's square input, then min-max normalise to [0, 1]
    # (the preprocessing MedSAM was trained with).
    img_resized = transform.resize(
        np.asarray(img, dtype=np.float64),
        (target_size, target_size),
        order=3,
        preserve_range=True,
        anti_aliasing=True,
    )
    img_resized = (img_resized - img_resized.min()) / np.clip(
        img_resized.max() - img_resized.min(), 1e-8, None
    )
    # The SAM image encoder expects 3 channels: replicate the gray channel.
    img_tensor = (
        torch.from_numpy(img_resized).float()[None, None].repeat(1, 3, 1, 1).to(device)
    )  # (1, 3, S, S)

    # Scale the box into the resized coordinate frame; prompt encoder wants (B, 1, 4).
    box_resized = np.asarray(box, dtype=np.float64) / np.array([W, H, W, H]) * target_size
    box_tensor = torch.as_tensor(box_resized, dtype=torch.float32, device=device).reshape(1, 1, 4)

    img_embed = medsam_model.image_encoder(img_tensor)
    sparse_embeddings, dense_embeddings = medsam_model.prompt_encoder(
        points=None, boxes=box_tensor, masks=None
    )
    low_res_logits, _ = medsam_model.mask_decoder(
        image_embeddings=img_embed,
        image_pe=medsam_model.prompt_encoder.get_dense_pe(),
        sparse_prompt_embeddings=sparse_embeddings,
        dense_prompt_embeddings=dense_embeddings,
        multimask_output=False,
    )

    # Post-process: probabilities, resize back to original size, threshold.
    prob = F.interpolate(
        torch.sigmoid(low_res_logits), size=(H, W), mode="bilinear", align_corners=False
    )
    return (prob[0, 0].cpu().numpy() > 0.5).astype(np.uint8)


# Function for visualizing images with masks
def visualize(image, mask, box):
    """Render the image with its prompt box (left) and with the mask overlaid (right).

    Returns:
        ``io.BytesIO`` positioned at offset 0, containing the PNG rendering.
    """
    # A standalone Figure avoids pyplot's global state, which is not safe when
    # Gradio calls this from worker threads.
    fig = Figure(figsize=(10, 5))
    ax_box, ax_mask = fig.subplots(1, 2)
    ax_box.imshow(image, cmap="gray")
    ax_box.add_patch(
        Rectangle(
            (box[0], box[1]),
            box[2] - box[0],
            box[3] - box[1],
            edgecolor="red",
            facecolor="none",
        )
    )
    ax_mask.imshow(image, cmap="gray")
    # Hide background pixels so only the segmented region is tinted.
    ax_mask.imshow(np.ma.masked_where(np.asarray(mask) == 0, mask), alpha=0.5, cmap="jet")
    fig.tight_layout()
    buf = io.BytesIO()
    fig.savefig(buf, format="png")
    buf.seek(0)
    return buf


# Main function for Gradio app
def process_images(file, x_min, y_min, x_max, y_max):
    """Gradio callback: segment the boxed region of an uploaded slice.

    Returns:
        PIL image showing the input with its box next to the mask overlay.

    Raises:
        gr.Error: on missing inputs or a zero-area bounding box.
    """
    if file is None:
        raise gr.Error("Please upload an image.")
    if any(v is None for v in (x_min, y_min, x_max, y_max)):
        raise gr.Error("Please provide all four bounding-box coordinates.")

    # Depending on the Gradio version, gr.File yields a path or a tempfile wrapper.
    path = file if isinstance(file, (str, os.PathLike)) else file.name
    image, H, W = load_image(path)

    # Accept corners in either order and keep the box inside the image.
    x0, x1 = np.clip(sorted((float(x_min), float(x_max))), 0, W)
    y0, y1 = np.clip(sorted((float(y_min), float(y_max))), 0, H)
    if x0 == x1 or y0 == y1:
        raise gr.Error("The bounding box has zero area inside the image.")
    box = [float(x0), float(y0), float(x1), float(y1)]

    device = "cuda" if torch.cuda.is_available() else "cpu"
    medsam_model = _load_medsam_model(MEDSAM_CHECKPOINT, device)  # cached after first call

    mask = medsam_inference(medsam_model, image, box, H, W)
    visualization = visualize(image, mask, box)
    # gr.Image can display a PIL image; raw PNG bytes are not a valid output value.
    return Image.open(visualization)


# Set up Gradio interface
iface = gr.Interface(
    fn=process_images,
    inputs=[
        gr.File(label="MRI Slice (DICOM, PNG, etc.)"),
        gr.Number(label="X min"),
        gr.Number(label="Y min"),
        gr.Number(label="X max"),
        gr.Number(label="Y max"),
    ],
    outputs=gr.Image(label="Segmentation"),
)
iface.launch()