Spaces:
Sleeping
Sleeping
File size: 4,420 Bytes
e399e14 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 |
import gradio as gr
import pandas as pd
import numpy as np
import pydicom
import os
from skimage import transform
import torch
from segment_anything import sam_model_registry
import matplotlib.pyplot as plt
from PIL import Image
import torch.nn.functional as F
import io
import cv2
import nrrd
from gradio_image_prompter import ImagePrompter
class PointPromptDemo:
    """Single-point-prompt segmentation around a SAM/MedSAM-style model.

    Call ``set_image`` once per image (computes and caches the image
    embeddings), then ``infer`` with a click position to get a binary mask.
    The wrapped model must expose ``device``, ``eval()``, ``image_encoder``,
    ``prompt_encoder`` (callable, with ``get_dense_pe()``) and
    ``mask_decoder``, which are the members used below.
    """

    def __init__(self, model):
        self.model = model
        self.model.eval()
        self.image = None             # last image given to set_image, 3 channels
        self.image_embeddings = None  # image_encoder output for self.image
        self.img_size = None          # (H, W) of the original image

    @torch.no_grad()
    def infer(self, x, y):
        """Segment the region indicated by the point (x, y).

        Args:
            x: column of the click, in original-image pixel coordinates.
            y: row of the click, in original-image pixel coordinates.

        Returns:
            ``np.uint8`` array with the same (H, W) as the image passed to
            ``set_image``: 1 where the predicted probability exceeds 0.5,
            0 elsewhere.

        Raises:
            RuntimeError: if ``set_image`` has not been called yet.
        """
        if self.image_embeddings is None or self.img_size is None:
            raise RuntimeError("set_image() must be called before infer().")
        height, width = self.img_size
        # preprocess_image resizes (does not pad) the image to 1024x1024, so
        # the click is rescaled independently along each axis.
        coords_1024 = np.array([[[x * 1024 / width, y * 1024 / height]]])
        coords_torch = torch.tensor(coords_1024, dtype=torch.float32).to(self.model.device)
        # One point with label 1 (presumably "foreground" in SAM's
        # point-label convention -- confirm against segment_anything).
        labels_torch = torch.tensor([[1]], dtype=torch.long).to(self.model.device)
        sparse_embeddings, dense_embeddings = self.model.prompt_encoder(
            points=(coords_torch, labels_torch),
            boxes=None,
            masks=None,
        )
        low_res_logits, _ = self.model.mask_decoder(
            image_embeddings=self.image_embeddings,
            image_pe=self.model.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=False,
        )
        low_res_probs = torch.sigmoid(low_res_logits)
        # Upsample the decoder's low-resolution output straight back to the
        # original image size before thresholding.
        probs = F.interpolate(
            low_res_probs,
            size=self.img_size,
            mode="bilinear",
            align_corners=False,
        )
        probs = probs.detach().cpu().numpy().squeeze()
        return np.uint8(probs > 0.5)

    def set_image(self, image):
        """Store ``image`` and compute its embeddings with the image encoder.

        Accepts a 2-D grayscale array or an (H, W, C) array. Grayscale and
        single-channel input is replicated to 3 channels, and a 4th (alpha)
        channel is dropped, so the encoder always receives 3 channels --
        the same layout this class already produced for 2-D input.
        """
        self.img_size = image.shape[:2]
        if image.ndim == 2:
            image = np.repeat(image[:, :, None], 3, -1)
        elif image.shape[2] == 1:
            image = np.repeat(image, 3, -1)
        elif image.shape[2] == 4:
            # Contiguous copy: cv2 rejects the strided view a plain slice gives.
            image = np.ascontiguousarray(image[:, :, :3])
        self.image = image
        image_preprocess = self.preprocess_image(self.image)
        with torch.no_grad():
            self.image_embeddings = self.model.image_encoder(image_preprocess)

    def preprocess_image(self, image):
        """Resize ``image`` to 1024x1024 and min-max normalize it into a tensor.

        Args:
            image: (H, W, C) array.

        Returns:
            float32 tensor of shape (1, C, 1024, 1024) on the model's device,
            with values rescaled into [0, 1].
        """
        img_resize = cv2.resize(
            image,
            (1024, 1024),
            interpolation=cv2.INTER_CUBIC,
        )
        # Min-max normalization; the clip avoids division by zero on a
        # constant image (which then becomes all zeros).
        img_resize = (img_resize - img_resize.min()) / np.clip(
            img_resize.max() - img_resize.min(), a_min=1e-8, a_max=None
        )
        # Post-condition sanity check on our own computation, not input validation.
        assert np.max(img_resize)<=1.0 and np.min(img_resize)>=0.0, 'image should be normalized to [0, 1]'
        # HWC -> CHW, then add a batch dimension.
        img_tensor = torch.tensor(img_resize).float().permute(2, 0, 1).unsqueeze(0).to(self.model.device)
        return img_tensor
def load_image(file_path):
    """Load a DICOM, NRRD or ordinary raster image file into a numpy array.

    The reader is chosen from the file extension, compared
    case-insensitively, so ``scan.DCM`` is read as DICOM instead of being
    handed to PIL. ``.dcm`` goes through pydicom, ``.nrrd`` through pynrrd,
    and anything else through PIL. A 2-D result is stacked into three
    identical channels.

    Args:
        file_path: path to the file (``str`` or ``os.PathLike``).

    Returns:
        numpy array of shape (H, W, 3) when the decoded data was 2-D,
        otherwise whatever shape the underlying reader produced.
    """
    path = os.fspath(file_path)
    lowered = path.lower()
    if lowered.endswith(".dcm"):
        img = pydicom.dcmread(path).pixel_array
    elif lowered.endswith(".nrrd"):
        # NOTE(review): per the pynrrd docs, nrrd.read defaults to
        # index_order='F', so a 2-D image may come back as (x, y), i.e.
        # transposed relative to the (row, col) layout of the other
        # branches. Confirm the expected orientation before relying on it.
        img, _ = nrrd.read(path)
    else:
        img = np.array(Image.open(path))
    if img.ndim == 2:
        img = np.stack((img,) * 3, axis=-1)
    return img
def visualize(image, mask):
    """Render the image beside a mask overlay and return it as a PIL image.

    Left panel: ``image``. Right panel: ``image`` with ``mask`` drawn on top
    using the jet colormap at 50% opacity. Mask pixels equal to 0 are left
    transparent, so only the segmented region is tinted rather than the
    whole frame.

    Args:
        image: array accepted by ``Axes.imshow``.
        mask: array with the same height/width as ``image``.

    Returns:
        PIL image decoded from an in-memory PNG of the figure.
    """
    fig, ax = plt.subplots(1, 2, figsize=(10, 5))
    buf = io.BytesIO()
    try:
        ax[0].imshow(image)
        ax[1].imshow(image)
        # Mask the background so jet does not paint it dark blue; pin the
        # colour limits to the full mask range so foreground colours match
        # an unmasked overlay.
        overlay = np.ma.masked_where(mask == 0, mask)
        ax[1].imshow(
            overlay,
            alpha=0.5,
            cmap="jet",
            vmin=np.min(mask),
            vmax=np.max(mask),
        )
        fig.tight_layout()
        fig.savefig(buf, format='png')
    finally:
        # Always release the figure, even if drawing or saving fails.
        plt.close(fig)
    buf.seek(0)
    return Image.open(buf)
_MEDSAM_CHECKPOINT_PATH = "medsam_point_prompt_flare22.pth"
_MEDSAM_MODELS = {}  # device string -> loaded model, so the checkpoint is read once


def _get_medsam_model(device):
    """Build the ViT-B MedSAM model on ``device`` once, then reuse it."""
    model = _MEDSAM_MODELS.get(device)
    if model is None:
        model = sam_model_registry['vit_b'](checkpoint=_MEDSAM_CHECKPOINT_PATH)
        model = model.to(device)
        model.eval()
        _MEDSAM_MODELS[device] = model
    return model


def process_images(img_dict):
    """Gradio callback: segment the clicked region and return a visualization.

    Args:
        img_dict: value from the ImagePrompter input. Its ``'image'`` entry
            is the array to segment. Its ``'points'`` entry is a list whose
            first element starts with the click's x and y coordinates.

    Returns:
        PIL image produced by ``visualize`` (image plus mask overlay).

    Raises:
        ValueError: if no image was provided or no point was selected.
    """
    if img_dict is None or img_dict.get('image') is None:
        raise ValueError("An image is required for ROI selection.")
    points = img_dict.get('points') or []
    # Check for an empty list before indexing; otherwise a missing click
    # surfaces as IndexError instead of this message.
    if not points or len(points[0]) < 2:
        raise ValueError("At least one point is required for ROI selection.")
    img = img_dict['image']
    x, y = points[0][0], points[0][1]
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    point_prompt_demo = PointPromptDemo(_get_medsam_model(device))
    point_prompt_demo.set_image(img)
    mask = point_prompt_demo.infer(x, y)
    return visualize(img, mask)
# Gradio UI: one ImagePrompter input (image plus clicked points) wired to
# process_images, with the rendered overlay shown as a PIL image.
# NOTE(review): the description advertises NRRD uploads, but process_images
# uses the ImagePrompter's decoded image directly and never calls load_image;
# confirm NRRD files actually work before relying on that text.
iface = gr.Interface(
    fn=process_images,
    inputs=[
        ImagePrompter(label="Image"),
    ],
    outputs=[
        gr.Image(type="pil", label="Processed Image"),
    ],
    title="ROI Selection with MEDSAM",
    description="Upload an image (including NRRD files) and select a point for ROI processing.",
)

# Launch only when run as a script, so importing this module (for tests
# or reuse) does not start a blocking web server.
if __name__ == "__main__":
    iface.launch()
|