"""Gradio demo: point-prompt ROI segmentation with a MedSAM (SAM ViT-B) checkpoint.

Expects the checkpoint "medsam_point_prompt_flare22.pth" next to this script.
Uses gradio, gradio_image_prompter, torch, cv2 (opencv-python), pydicom,
nrrd (pynrrd), matplotlib, PIL, numpy, and segment_anything.
"""
import io

import cv2
import gradio as gr
import matplotlib.pyplot as plt
import nrrd
import numpy as np
import pydicom
import torch
import torch.nn.functional as F
from gradio_image_prompter import ImagePrompter
from PIL import Image
from segment_anything import sam_model_registry


class PointPromptDemo:
    """Wraps a SAM-style model for single-point prompting on a fixed image."""

    def __init__(self, model):
        self.model = model
        self.model.eval()
        self.image = None
        self.image_embeddings = None
        self.img_size = None

    @torch.no_grad()
    def infer(self, x, y):
        # Map the click from original pixel coordinates into the model's
        # 1024x1024 input space (img_size is (H, W), so x scales by W, y by H).
        coords_1024 = np.array([[[
            x * 1024 / self.img_size[1],
            y * 1024 / self.img_size[0],
        ]]])
        coords_torch = torch.tensor(coords_1024, dtype=torch.float32).to(self.model.device)
        labels_torch = torch.tensor([[1]], dtype=torch.long).to(self.model.device)  # 1 = foreground point
        point_prompt = (coords_torch, labels_torch)
        sparse_embeddings, dense_embeddings = self.model.prompt_encoder(
            points=point_prompt,
            boxes=None,
            masks=None,
        )
        low_res_logits, _ = self.model.mask_decoder(
            image_embeddings=self.image_embeddings,
            image_pe=self.model.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=False,
        )
        # Upsample the low-resolution logits back to the original image size
        # and threshold the probabilities at 0.5 for a binary mask.
        low_res_probs = torch.sigmoid(low_res_logits)
        low_res_pred = F.interpolate(
            low_res_probs,
            size=self.img_size,
            mode="bilinear",
            align_corners=False,
        )
        low_res_pred = low_res_pred.detach().cpu().numpy().squeeze()
        seg = np.uint8(low_res_pred > 0.5)
        return seg

    def set_image(self, image):
        # Cache the original (H, W) so predictions can be resized back later.
        self.img_size = image.shape[:2]
        if image.ndim == 2:
            image = np.repeat(image[:, :, None], 3, -1)
        self.image = image
        image_preprocess = self.preprocess_image(self.image)
        # Compute the image embedding once; infer() reuses it for every click.
        with torch.no_grad():
            self.image_embeddings = self.model.image_encoder(image_preprocess)

    def preprocess_image(self, image):
        img_resize = cv2.resize(
            image,
            (1024, 1024),
            interpolation=cv2.INTER_CUBIC,
        )
        # Min-max normalize to [0, 1]; the clip guards against division by
        # zero on constant images.
        img_resize = (img_resize - img_resize.min()) / np.clip(
            img_resize.max() - img_resize.min(), a_min=1e-8, a_max=None
        )
        assert np.max(img_resize) <= 1.0 and np.min(img_resize) >= 0.0, \
            "image should be normalized to [0, 1]"
        img_tensor = torch.tensor(img_resize).float().permute(2, 0, 1).unsqueeze(0).to(self.model.device)
        return img_tensor


def load_image(file_path):
    """Read DICOM, NRRD, or standard image formats into an HxWx3 array."""
    if file_path.endswith(".dcm"):
        ds = pydicom.dcmread(file_path)
        img = ds.pixel_array
    elif file_path.endswith(".nrrd"):
        img, _ = nrrd.read(file_path)
    else:
        img = np.array(Image.open(file_path))
    if img.ndim == 2:
        img = np.stack((img,) * 3, axis=-1)
    return img


def visualize(image, mask):
    """Return a side-by-side PIL image: the input and the mask overlay."""
    fig, ax = plt.subplots(1, 2, figsize=(10, 5))
    ax[0].imshow(image)
    ax[1].imshow(image)
    ax[1].imshow(mask, alpha=0.5, cmap="jet")
    plt.tight_layout()
    buf = io.BytesIO()
    fig.savefig(buf, format="png")
    plt.close(fig)
    buf.seek(0)
    return Image.open(buf)


# Load the model once at startup instead of re-reading the checkpoint on
# every request.
device = "cuda" if torch.cuda.is_available() else "cpu"
MODEL_CHECKPOINT_PATH = "medsam_point_prompt_flare22.pth"
medsam_model = sam_model_registry["vit_b"](checkpoint=MODEL_CHECKPOINT_PATH)
medsam_model = medsam_model.to(device)
medsam_model.eval()
point_prompt_demo = PointPromptDemo(medsam_model)


def process_images(img_dict):
    img = img_dict["image"]
    prompts = img_dict.get("points") or []
    if not prompts:
        raise ValueError("At least one point is required for ROI selection.")
    # The first two entries of an ImagePrompter prompt are the click's
    # pixel coordinates.
    x, y = prompts[0][0], prompts[0][1]
    point_prompt_demo.set_image(img)
    mask = point_prompt_demo.infer(x, y)
    return visualize(img, mask)
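
# Optional offline helper: a sketch added for illustration, not part of the
# Gradio flow. It reuses load_image(), the shared PointPromptDemo, and
# visualize() above to run the same point-prompt pipeline on a local file
# without the web UI. The default output path is a placeholder.
def offline_demo(image_path, x, y, out_path="roi_overlay.png"):
    """Segment one point prompt without the web UI and save the overlay."""
    img = load_image(image_path)          # DICOM/NRRD-aware loader above
    point_prompt_demo.set_image(img)      # embed the image once
    mask = point_prompt_demo.infer(x, y)  # (x, y) in original pixel coordinates
    visualize(img, mask).save(out_path)
    return mask
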
iface = gr.Interface(
    fn=process_images,
    inputs=[ImagePrompter(label="Image")],
    outputs=[gr.Image(type="pil", label="Processed Image")],
    title="ROI Selection with MedSAM",
    description="Upload an image (including NRRD files) and select a point for ROI processing.",
)

if __name__ == "__main__":
    iface.launch()
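
# Example use of the offline helper above (hypothetical paths; assumes this
# file is saved as app.py):
#   from app import offline_demo
#   offline_demo("case_001.dcm", x=256, y=180, out_path="case_001_roi.png")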