import gradio as gr
import numpy as np
import pydicom
from skimage import transform
import torch
from segment_anything import sam_model_registry
import matplotlib.pyplot as plt
from PIL import Image
import torch.nn.functional as F
import io
from gradio_image_prompter import ImagePrompter
import nrrd  # NRRD file support


def load_image(file_path):
    """Load a DICOM, NRRD, or standard image file as an (H, W, 3) RGB array."""
    if file_path.endswith(".dcm"):
        ds = pydicom.dcmread(file_path)
        img = ds.pixel_array
    elif file_path.endswith(".nrrd"):
        img, _ = nrrd.read(file_path)  # nrrd.read returns (data, header)
    else:
        img = np.array(Image.open(file_path))
    # Convert grayscale to 3-channel RGB by replicating the channel
    if len(img.shape) == 2:  # Grayscale image (height, width)
        img = np.stack((img,) * 3, axis=-1)  # (height, width, 3)
    H, W = img.shape[:2]
    return img, H, W
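
# Minimal usage sketch for load_image (illustrative only -- "scan.dcm" is a
# hypothetical local file, not part of this app):
#   img, H, W = load_image("scan.dcm")
#   print(img.shape)  # -> (H, W, 3)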

@torch.no_grad()
def medsam_inference(medsam_model, img_embed, points_1024, H, W):
    """Decode a binary mask from point prompts given a precomputed image embedding."""
    points_torch = torch.as_tensor(points_1024, dtype=torch.float, device=img_embed.device)
    points_torch = points_torch.reshape(1, -1, 2)  # (1, N, 2)
    # SAM's prompt encoder expects a (coords, labels) tuple; label 1 marks foreground points
    labels_torch = torch.ones(points_torch.shape[:2], dtype=torch.int, device=img_embed.device)  # (1, N)
    sparse_embeddings, dense_embeddings = medsam_model.prompt_encoder(
        points=(points_torch, labels_torch),
        boxes=None,
        masks=None,
    )
    low_res_logits, _ = medsam_model.mask_decoder(
        image_embeddings=img_embed,  # (B, 256, 64, 64)
        image_pe=medsam_model.prompt_encoder.get_dense_pe(),  # (1, 256, 64, 64)
        sparse_prompt_embeddings=sparse_embeddings,  # (B, N, 256)
        dense_prompt_embeddings=dense_embeddings,  # (B, 256, 64, 64)
        multimask_output=False,
    )
    low_res_pred = torch.sigmoid(low_res_logits)  # (1, 1, 256, 256)
    low_res_pred = F.interpolate(
        low_res_pred,
        size=(H, W),
        mode="bilinear",
        align_corners=False,
    )  # (1, 1, H, W)
    low_res_pred = low_res_pred.squeeze().cpu().numpy()  # (H, W)
    medsam_seg = (low_res_pred > 0.5).astype(np.uint8)
    return medsam_seg


# Visualize the prompt points and the predicted mask side by side
def visualize(image, mask, points):
    fig, ax = plt.subplots(1, 2, figsize=(10, 5))
    ax[0].imshow(image, cmap='gray')
    for point in points:
        ax[0].plot(point[0], point[1], 'ro')  # Mark prompt points on the image
    ax[1].imshow(image, cmap='gray')
    ax[1].imshow(mask, alpha=0.5, cmap="jet")
    plt.tight_layout()
    # Convert the matplotlib figure to a PIL Image
    buf = io.BytesIO()
    fig.savefig(buf, format='png')
    plt.close(fig)  # Close the figure to release memory
    buf.seek(0)
    pil_img = Image.open(buf)
    return pil_img


# Main function for the Gradio app
def process_images(img_dict):
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # Unpack the ImagePrompter payload. Each point row may carry extra label/box
    # fields; only the first two entries are the (x, y) click coordinates.
    img = img_dict['image']
    points = np.array(img_dict['points'])
    if len(points) == 0:
        raise ValueError("No points provided.")
    points = points[:, :2]
    image = img
    if len(image.shape) == 2:
        image = np.repeat(image[:, :, None], 3, axis=-1)
    H, W = image.shape[:2]
    # Resize to the 1024x1024 input expected by SAM, then scale intensities to [0, 1]
    image_resized = transform.resize(
        image, (1024, 1024), order=3, preserve_range=True, anti_aliasing=True
    ).astype(np.uint8)
    image_resized = (image_resized - image_resized.min()) / np.clip(
        image_resized.max() - image_resized.min(), a_min=1e-8, a_max=None
    )
    image_tensor = torch.tensor(image_resized).float().permute(2, 0, 1).unsqueeze(0).to(device)
    # Initialize the MedSAM model and move it to the device
    model_checkpoint_path = "medsam_vit_b.pth"  # Replace with the correct path to your checkpoint
    medsam_model = sam_model_registry['vit_b'](checkpoint=model_checkpoint_path)
    medsam_model = medsam_model.to(device)
    medsam_model.eval()
    # Compute the image embedding that medsam_inference consumes
    with torch.no_grad():
        img_embed = medsam_model.image_encoder(image_tensor)  # (1, 256, 64, 64)
    # Map point coordinates from the original resolution into 1024x1024 space
    scale_factors = np.array([1024 / W, 1024 / H])
    points_1024 = points * scale_factors
    # Perform inference
    mask = medsam_inference(medsam_model, img_embed, points_1024, H, W)
    # Visualization
    visualization = visualize(image, mask, points)
    return visualization


# Set up the Gradio interface
iface = gr.Interface(
    fn=process_images,
    inputs=[
        ImagePrompter(label="Image")
    ],
    outputs=[
        gr.Image(type="pil", label="Processed Image")
    ],
    title="ROI Selection with MedSAM",
    description="Upload an image (including NRRD files) and select points of interest for processing."
)
# Launch the interface
iface.launch()