import io
import os

import gradio as gr
import matplotlib.pyplot as plt
import nrrd  # NRRD file support
import numpy as np
import pandas as pd
import pydicom
import torch
import torch.nn.functional as F
from gradio_image_prompter import ImagePrompter
from PIL import Image, ImageDraw
from segment_anything import sam_model_registry
from skimage import transform


def load_image(file_path):
    """Load a DICOM, NRRD, or ordinary raster image into a NumPy array.

    The reader is chosen from the file extension (case-insensitive):
    ``.dcm`` -> pydicom, ``.nrrd`` -> pynrrd, anything else -> PIL.
    A 2-D (single-channel) result is replicated into three channels.

    Args:
        file_path: Path of the file to read.

    Returns:
        Tuple ``(img, H, W)``: the pixel array and its first two dimensions.
    """
    ext = os.path.splitext(file_path)[1].lower()
    if ext == ".dcm":
        img = pydicom.dcmread(file_path).pixel_array
    elif ext == ".nrrd":
        img, _header = nrrd.read(file_path)
        # NOTE(review): pynrrd's default index order may be (x, y, ...) rather
        # than (rows, cols), and 3-D volumes are returned unchanged here --
        # confirm the intended orientation / slice selection.
    else:
        # The context manager closes the handle PIL keeps open after a lazy open;
        # np.array() forces the pixel data to load before the file is closed.
        with Image.open(file_path) as pil_img:
            img = np.array(pil_img)

    if img.ndim == 2:
        # Grayscale (H, W) -> (H, W, 3) by replicating the single channel.
        img = np.stack((img,) * 3, axis=-1)
    H, W = img.shape[:2]
    return img, H, W


def _box_from_prompts(points):
    """Return ``(x_min, y_min, x_max, y_max)`` from the first prompt in *points*.

    A box prompt must hold at least six values. Indices 0/1 and 3/4 are the two
    corners; indices 2 and 5 are skipped (presumably ImagePrompter's prompt-type
    labels -- confirm against the component's docs).
    """
    if not points or len(points[0]) < 6:
        raise ValueError("Insufficient data for bounding box coordinates.")
    prompt = points[0]
    # The box may be dragged in any direction; normalise so min <= max.
    x_min, x_max = sorted((prompt[0], prompt[3]))
    y_min, y_max = sorted((prompt[1], prompt[4]))
    return x_min, y_min, x_max, y_max


def _to_uint8(img):
    """Min-max rescale *img* to uint8 so PIL can build an RGB(A) image from it."""
    img = np.asarray(img)
    if img.dtype == np.uint8:
        return img
    img = img.astype(np.float64)
    lo, hi = img.min(), img.max()
    # A constant image has no range to stretch; map it to black.
    scaled = (img - lo) / (hi - lo) if hi > lo else np.zeros_like(img)
    return (scaled * 255).round().astype(np.uint8)


# Main function for Gradio app
def process_images(img_dict, file=None):
    """Read the ROI box drawn in the prompter and return a preview of it.

    Args:
        img_dict: ImagePrompter value -- a dict with ``"image"`` (array or PIL
            image) and ``"points"`` (list of prompts).
        file: Optional uploaded DICOM/NRRD/raster file (a path string, or an
            object exposing ``.name``). When given, its pixels are used instead
            of the prompter image, and the box drawn in the prompter is applied
            to them.

    Returns:
        PIL image with the selected box outlined in red.

    Raises:
        ValueError: no image, no complete box prompt, or an image that is not
            2-D RGB(A) after loading.
    """
    if not img_dict:
        raise ValueError("Insufficient data for bounding box coordinates.")
    x_min, y_min, x_max, y_max = _box_from_prompts(img_dict.get("points"))

    source = getattr(file, "name", file) if file is not None else img_dict.get("image")
    if source is None:
        raise ValueError("No image provided.")
    if isinstance(source, str):
        image, H, W = load_image(source)
    else:
        image = np.asarray(source)
        if image.ndim == 2:
            image = np.repeat(image[:, :, None], 3, axis=-1)
        H, W = image.shape[:2]

    if image.ndim != 3 or image.shape[-1] not in (3, 4):
        raise ValueError(f"Expected a 2-D RGB(A) image, got array of shape {image.shape}.")

    # Keep the box inside the image so downstream slicing/drawing is valid.
    x_min, x_max = (float(v) for v in np.clip([x_min, x_max], 0, W - 1))
    y_min, y_max = (float(v) for v in np.clip([y_min, y_max], 0, H - 1))

    # TODO: run MedSAM inference on the (x_min, y_min, x_max, y_max) box here
    # (device = 'cuda' if torch.cuda.is_available() else 'cpu'); until then,
    # return a preview of the selected ROI.
    preview = Image.fromarray(_to_uint8(image))
    ImageDraw.Draw(preview).rectangle(
        [x_min, y_min, x_max, y_max], outline="red", width=2
    )
    return preview


# Set up Gradio interface. ImagePrompter supplies the {"image", "points"} dict
# process_images expects; the optional file input carries DICOM/NRRD uploads.
iface = gr.Interface(
    fn=process_images,
    inputs=[
        ImagePrompter(label="Draw a bounding box"),
        gr.File(label="image or nrrd"),
    ],
    outputs=[gr.Image(type="pil", label="Processed Image")],
    title="ROI Selection with MEDSAM",
    description="Upload an image (including NRRD files) and select regions of interest for processing.",
)

if __name__ == "__main__":
    # Launch only when run as a script, not on import.
    iface.launch()