import gradio as gr
import pandas as pd
import numpy as np
import pydicom
import os
from skimage import transform
import torch
from segment_anything import sam_model_registry
import matplotlib.pyplot as plt
from PIL import Image
import io

# Columns the bounding-box CSV must provide; coordinates are pixel positions in
# the DICOM image (x = column, y = row).
BOX_COLUMNS = ["x_min", "y_min", "x_max", "y_max"]


def _as_path(file_or_path):
    """Return the filesystem path behind a Gradio upload, or the input unchanged.

    Depending on its version, ``gr.File`` passes either a plain path string or a
    tempfile-like object whose ``.name`` is the full path. Paths (``str``,
    ``bytes``, ``os.PathLike``) are returned untouched -- ``pathlib.Path.name``
    is only the basename -- and so is anything without a string ``.name``
    (e.g. an in-memory buffer).
    """
    if isinstance(file_or_path, (str, bytes, os.PathLike)):
        return file_or_path
    name = getattr(file_or_path, "name", None)
    return name if isinstance(name, str) else file_or_path


def load_bounding_boxes(csv_file):
    """Load bounding boxes from a CSV file.

    The CSV must contain the columns ``x_min``, ``y_min``, ``x_max`` and
    ``y_max``; other columns (e.g. ``filename``) are kept but not interpreted.

    Args:
        csv_file: Path, buffer, or Gradio upload object for the CSV.

    Returns:
        pandas.DataFrame with one row per box.

    Raises:
        ValueError: if any required coordinate column is missing.
    """
    df = pd.read_csv(_as_path(csv_file))
    missing = [col for col in BOX_COLUMNS if col not in df.columns]
    if missing:
        raise ValueError(f"Bounding-box CSV is missing columns: {missing}")
    return df


def load_dicom_image(filename):
    """Read a DICOM file and return its pixel data with its height and width.

    The file extension is not checked: DICOM files often have none, and
    ``pydicom.dcmread`` itself rejects input that is not DICOM.

    Args:
        filename: Path to the DICOM file, or a Gradio upload object for it.

    Returns:
        ``(img, H, W)``: ``img`` is ``pixel_array`` as a NumPy array -- 2-D for
        single-sample images, ``(H, W, 3)`` for 3-sample colour images -- and
        ``H``/``W`` are its rows and columns.

    Raises:
        ValueError: if the pixel data is not a single frame (e.g. multi-frame).
    """
    ds = pydicom.dcmread(_as_path(filename))
    img = np.asarray(ds.pixel_array)
    samples_per_pixel = int(getattr(ds, "SamplesPerPixel", 1))
    if img.ndim == 2:
        H, W = img.shape
    elif img.ndim == 3 and samples_per_pixel == 3:
        H, W = img.shape[:2]
    else:
        raise ValueError(
            f"Unsupported DICOM pixel data with shape {img.shape}; "
            "expected a single grayscale or RGB frame"
        )
    return img, H, W


def _to_rgb(img):
    """Return *img* as an ``(H, W, 3)`` array, repeating a grayscale channel."""
    img = np.asarray(img)
    if img.ndim == 2:
        return np.repeat(img[:, :, None], 3, axis=-1)
    return img


@torch.no_grad()
def medsam_inference(medsam_model, img, box, H, W, target_size=None):
    """Segment the region inside *box* with a MedSAM/SAM model.

    Args:
        medsam_model: SAM-style model (e.g. built by ``sam_model_registry``)
            exposing ``image_encoder``, ``prompt_encoder`` and ``mask_decoder``.
            Inputs are moved to the device holding its parameters.
        img: ``(H, W)`` grayscale or ``(H, W, 3)`` image.
        box: ``[x_min, y_min, x_max, y_max]`` in pixel coordinates of *img*.
        H: Height of *img* and of the returned mask.
        W: Width of *img* and of the returned mask.
        target_size: Square side the image is resized to before encoding.
            Only used when the encoder does not expose ``img_size``: SAM's ViT
            encoder has fixed positional embeddings and accepts only its own
            input size, so any other value would fail inside the encoder.

    Returns:
        ``(H, W)`` ``uint8`` mask, 1 where the predicted probability > 0.5.
    """
    device = next(medsam_model.parameters()).device
    size = int(getattr(medsam_model.image_encoder, "img_size", None) or target_size or 1024)

    # Bicubic resize keeping the original intensity range, then min-max
    # normalise to [0, 1]. Stay in float: DICOM data is often 12/16-bit, so a
    # uint8 cast before normalising would wrap values around.
    img_resized = transform.resize(
        _to_rgb(img), (size, size), order=3, preserve_range=True, anti_aliasing=True
    ).astype(np.float32)
    lo = float(img_resized.min())
    value_range = max(float(img_resized.max()) - lo, 1e-8)
    img_resized = (img_resized - lo) / value_range
    img_tensor = torch.from_numpy(img_resized).permute(2, 0, 1).unsqueeze(0).to(device)  # (1, 3, S, S)

    # Scale the box from original pixel coordinates into the resized frame.
    box_resized = np.asarray(box, dtype=np.float64) / np.array([W, H, W, H], dtype=np.float64) * size
    box_tensor = torch.as_tensor(box_resized, dtype=torch.float32, device=device).reshape(1, 1, 4)

    img_embed = medsam_model.image_encoder(img_tensor)
    sparse_emb, dense_emb = medsam_model.prompt_encoder(points=None, boxes=box_tensor, masks=None)
    low_res_logits, _ = medsam_model.mask_decoder(
        image_embeddings=img_embed,
        image_pe=medsam_model.prompt_encoder.get_dense_pe(),
        sparse_prompt_embeddings=sparse_emb,
        dense_prompt_embeddings=dense_emb,
        multimask_output=False,
    )

    # Upsample probabilities straight back to the original image size.
    probs = torch.nn.functional.interpolate(
        torch.sigmoid(low_res_logits), size=(H, W), mode="bilinear", align_corners=False
    )
    return (probs[0, 0].cpu().numpy() > 0.5).astype(np.uint8)


def _per_image_boxes(box, count):
    """Return *count* ``(x_min, y_min, x_max, y_max)`` tuples.

    *box* is either one box shared by every image or exactly one box per image.
    """
    arr = np.asarray(box, dtype=float)
    if arr.shape == (4,):
        return [tuple(arr)] * count
    if arr.shape == (count, 4):
        return [tuple(row) for row in arr]
    raise ValueError(
        f"box must be one [x_min, y_min, x_max, y_max] box or {count} such boxes; "
        f"got shape {arr.shape}"
    )


def visualize(images, masks, box):
    """Render each image with its box, and with its mask overlaid, as a PNG.

    Args:
        images: Sequence of images (one figure row each).
        masks: Sequence of masks, paired with *images* in order.
        box: One ``[x_min, y_min, x_max, y_max]`` box for every image, or a
            sequence of such boxes, one per image.

    Returns:
        ``io.BytesIO`` rewound to 0, holding the PNG-encoded figure.

    Raises:
        ValueError: if *images* is empty or *box* has the wrong shape.
    """
    images = list(images)
    masks = list(masks)
    if not images:
        raise ValueError("visualize() needs at least one image")
    boxes = _per_image_boxes(box, len(images))

    # squeeze=False keeps `axes` 2-D even when there is a single row.
    fig, axes = plt.subplots(len(images), 2, figsize=(10, 5 * len(images)), squeeze=False)
    try:
        for (ax_box, ax_mask), image, mask, (x0, y0, x1, y1) in zip(axes, images, masks, boxes):
            ax_box.imshow(image, cmap="gray")
            ax_box.add_patch(
                plt.Rectangle((x0, y0), x1 - x0, y1 - y0, edgecolor="red", facecolor="none")
            )
            ax_mask.imshow(image, cmap="gray")
            # Hide background pixels so only the segmented region is tinted.
            ax_mask.imshow(np.ma.masked_equal(np.asarray(mask), 0), alpha=0.5, cmap="jet")
        fig.tight_layout()
        buf = io.BytesIO()
        fig.savefig(buf, format="png")
    finally:
        # Always release the figure, even if rendering fails.
        plt.close(fig)
    buf.seek(0)
    return buf
# Main function for Gradio app

# Path to the MedSAM checkpoint (ensure the correct path).
MEDSAM_CHECKPOINT = "medsam_vit_b.pth"

# Models already built by _get_medsam_model, keyed by checkpoint path, so the
# checkpoint is loaded once per process instead of on every request.
_MEDSAM_MODELS = {}


def _get_medsam_model(checkpoint=MEDSAM_CHECKPOINT):
    """Build (once per checkpoint) and return an eval-mode ViT-B MedSAM model.

    The model is placed on CUDA when available, otherwise on the CPU.
    """
    model = _MEDSAM_MODELS.get(checkpoint)
    if model is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model = sam_model_registry["vit_b"](checkpoint=checkpoint).to(device)
        model.eval()
        _MEDSAM_MODELS[checkpoint] = model
    return model


def process_images(csv_file, dicom_file):
    """Segment every CSV bounding box on the DICOM image and render the result.

    Args:
        csv_file: Uploaded CSV with ``x_min``, ``y_min``, ``x_max``, ``y_max``.
        dicom_file: Uploaded DICOM file.

    Returns:
        PIL image with one row per box: the image with the box drawn, and the
        image with the predicted mask overlaid.

    Raises:
        gr.Error: if an upload is missing or the CSV contains no boxes; Gradio
            shows the message to the user.
    """
    if csv_file is None or dicom_file is None:
        raise gr.Error("Please upload both a CSV file and a DICOM file.")

    bounding_boxes = load_bounding_boxes(csv_file)
    if bounding_boxes.empty:
        raise gr.Error("The CSV file contains no bounding boxes.")
    image, H, W = load_dicom_image(dicom_file)

    medsam_model = _get_medsam_model()
    # SAM's ViT encoder only accepts its own square input size, not the image height.
    target_size = getattr(medsam_model.image_encoder, "img_size", 1024)

    boxes = bounding_boxes[["x_min", "y_min", "x_max", "y_max"]].to_numpy().tolist()
    masks = [medsam_inference(medsam_model, image, box, H, W, target_size) for box in boxes]

    png = visualize([image] * len(masks), masks, boxes)
    result = Image.open(png)
    result.load()  # decode now rather than lazily from the in-memory buffer
    return result


# Set up Gradio interface
iface = gr.Interface(
    fn=process_images,
    inputs=[gr.File(label="CSV File"), gr.File(label="DICOM File")],
    # process_images returns a PIL image; a "plot" output cannot show PNG bytes.
    outputs=gr.Image(type="pil"),
)

if __name__ == "__main__":
    iface.launch()