"""Gradio app: MedSAM box-prompted segmentation of uploaded DICOM slices.

The user uploads a CSV of bounding boxes and a set of ``.dcm`` files. Each box
is used as a prompt for MedSAM on its matching image. The app returns an
overlay figure and the stacked binary masks as a downloadable ``.npy`` file.
"""

import functools
import io
import os
import tempfile

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pydicom
import torch
import torch.nn.functional as F
from PIL import Image
from segment_anything import sam_model_registry
from skimage import transform

# Default model input side length. medsam_inference() checks it against the
# loaded encoder's ``img_size``.
SAM_INPUT_SIZE = 1024

# Model configuration: override with environment variables instead of editing code.
# NOTE(review): "vit_b" is assumed to match the released MedSAM checkpoint; confirm for yours.
MEDSAM_MODEL_TYPE = os.environ.get("MEDSAM_MODEL_TYPE", "vit_b")
MEDSAM_CHECKPOINT = os.environ.get("MEDSAM_CHECKPOINT", "path_to_your_checkpoint")

REQUIRED_BOX_COLUMNS = ("x_min", "y_min", "x_max", "y_max")


def load_bounding_boxes(csv_file):
    """Read the bounding-box table from ``csv_file``.

    Args:
        csv_file: A path, or a file object / upload wrapper exposing ``.name``.

    Returns:
        A DataFrame with at least the columns ``x_min``, ``y_min``, ``x_max`` and
        ``y_max``. An optional ``filename`` column is used to match rows to images.

    Raises:
        ValueError: If any of the required box columns is missing.
    """
    # Upload wrappers expose the on-disk path as ``.name``; plain paths pass through.
    df = pd.read_csv(getattr(csv_file, "name", csv_file))
    missing = [col for col in REQUIRED_BOX_COLUMNS if col not in df.columns]
    if missing:
        raise ValueError(f"Bounding-box CSV is missing column(s): {', '.join(missing)}")
    return df


def _dicom_files(source):
    """Return ``(name, path)`` pairs for the ``.dcm`` files in ``source``, sorted by name.

    ``source`` is either a directory path or an iterable of uploaded files.
    Each uploaded file is a path string or an object exposing ``.name``.
    """
    if isinstance(source, (str, os.PathLike)):
        entries = [(name, os.path.join(source, name)) for name in os.listdir(source)]
    else:
        entries = []
        for item in source:
            path = getattr(item, "name", item)
            # Some Gradio versions store uploads under temporary names and keep the
            # original one in ``orig_name``. Prefer it when present so ordering
            # and CSV filename matching use the user's names.
            name = getattr(item, "orig_name", None) or os.path.basename(path)
            entries.append((name, path))
    return sorted((name, path) for name, path in entries if name.lower().endswith(".dcm"))


def load_dicom_images(folder_path):
    """Load the pixel data of every ``.dcm`` file in ``folder_path``.

    Files are read in file-name order. ``folder_path`` may also be an iterable
    of uploaded files (see ``_dicom_files``).

    Returns:
        ``np.array`` of the pixel arrays. All slices must share one shape.
    """
    return np.array([pydicom.dcmread(path).pixel_array for _, path in _dicom_files(folder_path)])


@functools.lru_cache(maxsize=2)
def _load_medsam(model_type, checkpoint, device):
    """Build a MedSAM model in eval mode once per (type, checkpoint, device) and reuse it."""
    model = sam_model_registry[model_type](checkpoint=checkpoint)
    model.to(device)
    model.eval()
    return model


def _as_rgb(img):
    """Return ``img`` as an ``(H, W, 3)`` float32 array; a 2-D slice is replicated across 3 channels."""
    arr = np.asarray(img, dtype=np.float32)
    if arr.ndim == 2:
        return np.repeat(arr[:, :, None], 3, axis=-1)
    if arr.ndim == 3 and arr.shape[-1] == 3:
        return arr
    raise ValueError(f"Expected a 2-D grayscale or (H, W, 3) image, got shape {arr.shape}")


@torch.no_grad()
def medsam_inference(medsam_model, img, box, H, W, target_size=SAM_INPUT_SIZE):
    """Segment the object inside ``box`` on ``img`` with MedSAM.

    Args:
        medsam_model: A model from ``sam_model_registry`` loaded with MedSAM weights.
        img: 2-D grayscale slice or ``(H, W, 3)`` image.
        box: ``[x_min, y_min, x_max, y_max]`` in ``img`` pixel coordinates.
        H, W: Height and width of ``img``; the mask is returned at this size.
        target_size: Side length of the square model input. It must equal the
            encoder's ``img_size`` when the encoder exposes one.

    Returns:
        ``(H, W)`` uint8 mask with 1 where the predicted probability exceeds 0.5.

    Raises:
        ValueError: On an unsupported image shape or a ``target_size`` mismatch.
    """
    encoder_size = getattr(medsam_model.image_encoder, "img_size", target_size)
    if target_size != encoder_size:
        raise ValueError(
            f"target_size={target_size} does not match the image encoder input size {encoder_size}"
        )
    # Follow the model's device instead of relying on a global.
    device = next(medsam_model.parameters()).device

    # Image -> (1, 3, S, S) tensor min-max scaled to [0, 1]. Min-max scaling also
    # makes any linear DICOM rescale (slope/intercept) irrelevant.
    resized = transform.resize(
        _as_rgb(img),
        (target_size, target_size, 3),
        order=3,
        preserve_range=True,
        anti_aliasing=True,
    )
    lo, hi = resized.min(), resized.max()
    resized = (resized - lo) / max(hi - lo, 1e-8)
    img_tensor = torch.from_numpy(resized).float().permute(2, 0, 1).unsqueeze(0).to(device)

    # Box -> model-input pixel coordinates, shaped (batch=1, 1, 4).
    scale = target_size / np.array([W, H, W, H], dtype=np.float64)
    box_scaled = np.asarray(box, dtype=np.float64) * scale
    box_tensor = torch.as_tensor(box_scaled, dtype=torch.float32, device=device).reshape(1, 1, 4)

    # Encode the image, embed the box prompt, and decode a single mask.
    image_embedding = medsam_model.image_encoder(img_tensor)
    sparse_emb, dense_emb = medsam_model.prompt_encoder(points=None, boxes=box_tensor, masks=None)
    low_res_logits, _ = medsam_model.mask_decoder(
        image_embeddings=image_embedding,
        image_pe=medsam_model.prompt_encoder.get_dense_pe(),
        sparse_prompt_embeddings=sparse_emb,
        dense_prompt_embeddings=dense_emb,
        multimask_output=False,
    )

    # Logits -> probabilities at the original resolution -> binary mask.
    probs = F.interpolate(
        torch.sigmoid(low_res_logits), size=(H, W), mode="bilinear", align_corners=False
    )
    return (probs[0, 0].cpu().numpy() > 0.5).astype(np.uint8)


def visualize(images, masks, box):
    """Render each image with its box (left) and its mask overlay (right) as a PNG.

    Args:
        images: Sequence of 2-D (or ``(H, W, 3)``) images.
        masks: Masks aligned with ``images``. Extra items in the longer sequence are ignored.
        box: One ``[x_min, y_min, x_max, y_max]`` box drawn on every image, or a
            sequence of such boxes, one per image.

    Returns:
        ``io.BytesIO`` rewound to position 0 and holding the PNG.

    Raises:
        ValueError: If there is no image/mask pair to draw.
    """
    pairs = list(zip(images, masks))
    if not pairs:
        raise ValueError("visualize() needs at least one image/mask pair")
    boxes = np.asarray(box, dtype=float)
    if boxes.ndim == 1:  # a single box: reuse it for every image
        boxes = np.broadcast_to(boxes, (len(pairs), 4))

    # squeeze=False keeps `axes` 2-D even when there is only one row.
    fig, axes = plt.subplots(len(pairs), 2, figsize=(10, 5 * len(pairs)), squeeze=False)
    try:
        for (image, mask), (x0, y0, x1, y1), (ax_box, ax_mask) in zip(pairs, boxes, axes):
            ax_box.imshow(image, cmap="gray")
            ax_box.add_patch(
                plt.Rectangle((x0, y0), x1 - x0, y1 - y0, edgecolor="red", facecolor="none")
            )
            ax_mask.imshow(image, cmap="gray")
            # Mask out background pixels so only the segmented region is tinted.
            ax_mask.imshow(np.ma.masked_equal(mask, 0), alpha=0.5, cmap="jet", vmin=0, vmax=1)
        fig.tight_layout()
        buf = io.BytesIO()
        fig.savefig(buf, format="png")
    finally:
        plt.close(fig)  # release the figure even if drawing fails
    buf.seek(0)
    return buf


def process_images(csv_file, dicom_folder, target_size=SAM_INPUT_SIZE):
    """Run MedSAM on every bounding box in ``csv_file`` and render the results.

    A row is matched to an image by its ``filename`` column when that name is
    among the DICOM files. Otherwise the row's position in the CSV is used as
    the image index, and rows past the last image are skipped.

    Args:
        csv_file: Bounding-box CSV (see ``load_bounding_boxes``).
        dicom_folder: Directory path or iterable of uploaded ``.dcm`` files.
        target_size: Model input side length (see ``medsam_inference``).

    Returns:
        ``(png_buffer, masks)``: the figure from ``visualize`` and an ``(N, H, W)``
        uint8 array with one mask per processed box.

    Raises:
        ValueError: If no box matches an image, the CSV is malformed, or the
            masks have different shapes.
    """
    bounding_boxes = load_bounding_boxes(csv_file)
    entries = _dicom_files(dicom_folder)
    # Keep a list (not np.array) so slices of different sizes can still be processed.
    dicom_images = [pydicom.dcmread(path).pixel_array for _, path in entries]
    name_to_index = {name: i for i, (name, _) in enumerate(entries)}
    match_by_name = "filename" in bounding_boxes.columns

    device = "cuda" if torch.cuda.is_available() else "cpu"
    medsam_model = _load_medsam(MEDSAM_MODEL_TYPE, MEDSAM_CHECKPOINT, device)
    target_size = int(target_size)

    used_images, masks, boxes = [], [], []
    for position, (_, row) in enumerate(bounding_boxes.iterrows()):
        index = None
        if match_by_name:
            index = name_to_index.get(os.path.basename(str(row["filename"])))
        if index is None:
            index = position  # fall back to CSV row order
        if index >= len(dicom_images):
            continue
        image = dicom_images[index]
        height, width = image.shape[:2]
        box = [row["x_min"], row["y_min"], row["x_max"], row["y_max"]]
        masks.append(medsam_inference(medsam_model, image, box, height, width, target_size))
        used_images.append(image)
        boxes.append(box)  # each panel is drawn with its own box

    if not masks:
        raise ValueError("None of the bounding boxes could be matched to a DICOM image.")
    return visualize(used_images, masks, boxes), np.stack(masks)


def _gradio_process(csv_file, dicom_files, target_size):
    """Adapt process_images() to Gradio: return a PIL figure and the path of a saved ``.npy``."""
    if csv_file is None or not dicom_files:
        raise gr.Error("Please upload both the bounding-box CSV and the DICOM files.")
    size = int(target_size) if target_size else SAM_INPUT_SIZE
    try:
        png, masks = process_images(csv_file, dicom_files, size)
    except ValueError as err:
        raise gr.Error(str(err)) from err
    figure = Image.open(png).convert("RGB")  # convert() forces the lazy PNG decode now
    fd, mask_path = tempfile.mkstemp(prefix="medsam_masks_", suffix=".npy")
    with os.fdopen(fd, "wb") as fh:
        np.save(fh, masks)
    return figure, mask_path


iface = gr.Interface(
    fn=_gradio_process,
    inputs=[
        gr.File(label="Bounding boxes (CSV)", file_types=[".csv"]),
        gr.File(label="DICOM files", file_count="directory"),
        gr.Number(value=SAM_INPUT_SIZE, precision=0, label="Model input size"),
    ],
    outputs=[
        gr.Image(type="pil", label="Segmentation overlay"),
        gr.File(label="Masks (.npy)"),
    ],
    title="MedSAM DICOM segmentation",
)

if __name__ == "__main__":
    iface.launch()