# app.py — MedSAM point-prompt ROI selection demo (Hugging Face Space).
import gradio as gr
import pandas as pd
import numpy as np
import pydicom
import os
from skimage import transform
import torch
from segment_anything import sam_model_registry
import matplotlib.pyplot as plt
from PIL import Image
import torch.nn.functional as F
import io
import cv2
import nrrd
from gradio_image_prompter import ImagePrompter
class PointPromptDemo:
    """Point-prompt segmentation wrapper around a MedSAM model.

    Usage: construct with a loaded SAM model, call ``set_image`` once per
    image (precomputes encoder embeddings), then ``infer(x, y)`` per click.
    """

    def __init__(self, model):
        self.model = model
        self.model.eval()
        # Populated lazily by set_image().
        self.image = None
        self.image_embeddings = None
        self.img_size = None

    @torch.no_grad()
    def infer(self, x, y):
        """Segment the object at pixel (x, y) of the current image.

        Returns a uint8 array of shape ``self.img_size`` with 1 inside the
        predicted region and 0 elsewhere.
        """
        height, width = self.img_size
        # Map the click from original pixel coordinates into the model's
        # fixed 1024x1024 input space.
        scaled_point = np.array([[[x * 1024 / width, y * 1024 / height]]])
        coords = torch.as_tensor(scaled_point, dtype=torch.float32, device=self.model.device)
        labels = torch.as_tensor([[1]], dtype=torch.long, device=self.model.device)
        sparse_emb, dense_emb = self.model.prompt_encoder(
            points=(coords, labels),
            boxes=None,
            masks=None,
        )
        logits, _ = self.model.mask_decoder(
            image_embeddings=self.image_embeddings,
            image_pe=self.model.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_emb,
            dense_prompt_embeddings=dense_emb,
            multimask_output=False,
        )
        # Upsample the low-resolution probability map back to image size,
        # then threshold at 0.5 for a binary mask.
        probs = F.interpolate(
            torch.sigmoid(logits),
            size=self.img_size,
            mode='bilinear',
            align_corners=False,
        )
        return np.uint8(probs.detach().cpu().numpy().squeeze() > 0.5)

    def set_image(self, image):
        """Cache *image* and precompute its encoder embeddings."""
        self.img_size = image.shape[:2]
        if len(image.shape) == 2:
            # Grayscale -> pseudo-RGB by channel replication.
            image = np.repeat(image[:, :, None], 3, -1)
        self.image = image
        prepared = self.preprocess_image(self.image)
        with torch.no_grad():
            self.image_embeddings = self.model.image_encoder(prepared)

    def preprocess_image(self, image):
        """Resize to 1024x1024, min-max normalize to [0, 1], and return a
        (1, 3, 1024, 1024) float tensor on the model's device."""
        resized = cv2.resize(
            image,
            (1024, 1024),
            interpolation=cv2.INTER_CUBIC,
        )
        # Guard against a constant image with the epsilon clip.
        span = np.clip(resized.max() - resized.min(), a_min=1e-8, a_max=None)
        resized = (resized - resized.min()) / span
        assert np.max(resized)<=1.0 and np.min(resized)>=0.0, 'image should be normalized to [0, 1]'
        return torch.tensor(resized).float().permute(2, 0, 1).unsqueeze(0).to(self.model.device)
def load_image(file_path):
    """Load an image from a DICOM, NRRD, or standard image file.

    Parameters
    ----------
    file_path : str
        Path to the image file; format is chosen by extension.

    Returns
    -------
    numpy.ndarray
        HxWx3 array; grayscale inputs are replicated to three channels so
        downstream code can assume an RGB-like layout.
    """
    # Compare extensions case-insensitively: the original check missed
    # uppercase names like "scan.DCM" and fell through to PIL, which
    # cannot open DICOM/NRRD files.
    lowered = file_path.lower()
    if lowered.endswith(".dcm"):
        ds = pydicom.dcmread(file_path)
        img = ds.pixel_array
    elif lowered.endswith(".nrrd"):
        # NOTE(review): assumes the NRRD holds a single 2-D image (possibly
        # with channels); multi-slice volumes would need slice selection.
        img, _ = nrrd.read(file_path)
    else:
        img = np.array(Image.open(file_path))
    if len(img.shape) == 2:
        img = np.stack((img,)*3, axis=-1)
    return img
def visualize(image, mask):
    """Render *image* beside an overlay of *mask* and return a PIL image.

    Left panel: the raw image. Right panel: the same image with the mask
    drawn on top at 50% opacity using the "jet" colormap.
    """
    fig, (left, right) = plt.subplots(1, 2, figsize=(10, 5))
    left.imshow(image)
    right.imshow(image)
    right.imshow(mask, alpha=0.5, cmap="jet")
    plt.tight_layout()
    # Round-trip through an in-memory PNG to hand Gradio a PIL image.
    buffer = io.BytesIO()
    fig.savefig(buffer, format='png')
    plt.close(fig)
    buffer.seek(0)
    return Image.open(buffer)
def process_images(img_dict):
    """Gradio handler: segment the region at the prompted point.

    Parameters
    ----------
    img_dict : dict
        ImagePrompter payload with 'image' (HxWxC array) and 'points'
        (list of prompt rows; the first two entries of row 0 are taken
        as the x, y click coordinates).

    Returns
    -------
    PIL.Image.Image
        Side-by-side visualization of the image and the predicted mask.

    Raises
    ------
    ValueError
        If no point was supplied.
    """
    img = img_dict['image']
    points = img_dict['points'][0]
    if len(points) < 2:
        raise ValueError("At least one point is required for ROI selection.")
    x, y = points[0], points[1]
    # Build the model once and reuse it across requests: the original code
    # re-read the checkpoint and rebuilt the network on every click, which
    # dominated per-request latency. The cache lives on the function object
    # so the module-level interface is unchanged.
    if not hasattr(process_images, '_demo'):
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        model_checkpoint_path = "medsam_point_prompt_flare22.pth"
        medsam_model = sam_model_registry['vit_b'](checkpoint=model_checkpoint_path)
        medsam_model = medsam_model.to(device)
        medsam_model.eval()
        process_images._demo = PointPromptDemo(medsam_model)
    point_prompt_demo = process_images._demo
    point_prompt_demo.set_image(img)
    mask = point_prompt_demo.infer(x, y)
    visualization = visualize(img, mask)
    return visualization
# Wire up the Gradio UI: one point-promptable image in, one rendered
# segmentation visualization out.
iface = gr.Interface(
    fn=process_images,
    inputs=[ImagePrompter(label="Image")],
    outputs=[gr.Image(type="pil", label="Processed Image")],
    title="ROI Selection with MEDSAM",
    description="Upload an image (including NRRD files) and select a point for ROI processing.",
)
iface.launch()