import gradio as gr
import numpy as np
import pydicom
from skimage import transform
import torch
from segment_anything import sam_model_registry
import matplotlib.pyplot as plt
from PIL import Image
import torch.nn.functional as F
import io
from gradio_image_prompter import ImagePrompter
import nrrd  # NRRD (.nrrd) file support
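
# Dependencies (pip package names assumed from the imports above; verify for
# your environment): gradio, numpy, pydicom, scikit-image, torch,
# segment-anything, matplotlib, pillow, gradio-image-prompter, pynrrd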

def load_image(file_path):
    """Load a DICOM, NRRD, or standard image file as an RGB array.

    Helper for file-based loading; the Gradio flow below receives arrays
    directly from ImagePrompter, so this is not wired into the interface.
    """
    if file_path.lower().endswith(".dcm"):
        ds = pydicom.dcmread(file_path)
        img = ds.pixel_array
    elif file_path.lower().endswith(".nrrd"):
        img, _ = nrrd.read(file_path)  # nrrd.read returns (data, header)
    else:
        img = np.array(Image.open(file_path))

    # Convert grayscale (H, W) to 3-channel RGB (H, W, 3) by replicating the channel
    if img.ndim == 2:
        img = np.stack((img,) * 3, axis=-1)

    H, W = img.shape[:2]
    return img, H, W
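
# Example usage (hypothetical file names):
#   img, H, W = load_image("slice_042.dcm")   # DICOM via pydicom
#   img, H, W = load_image("volume.nrrd")     # NRRD via pynrrd
#   img, H, W = load_image("scan.png")        # anything PIL can open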

@torch.no_grad()
def medsam_inference(medsam_model, img_embed, points_1024, H, W):
    points_torch = torch.as_tensor(points_1024, dtype=torch.float, device=img_embed.device)
    points_torch = points_torch.reshape(1, -1, 2)  # (1, N, 2)
    # SAM's prompt encoder expects points as a (coords, labels) tuple;
    # label 1 marks each click as a foreground point
    labels_torch = torch.ones(points_torch.shape[:2], dtype=torch.int, device=img_embed.device)  # (1, N)

    sparse_embeddings, dense_embeddings = medsam_model.prompt_encoder(
        points=(points_torch, labels_torch),
        boxes=None,
        masks=None,
    )

    low_res_logits, _ = medsam_model.mask_decoder(
        image_embeddings=img_embed,  # (B, 256, 64, 64)
        image_pe=medsam_model.prompt_encoder.get_dense_pe(),  # (1, 256, 64, 64)
        sparse_prompt_embeddings=sparse_embeddings,  # (B, N+1, 256)
        dense_prompt_embeddings=dense_embeddings,  # (B, 256, 64, 64)
        multimask_output=False,
    )

    low_res_pred = torch.sigmoid(low_res_logits)  # (1, 1, 256, 256)

    # Upsample the low-resolution mask back to the original image size
    low_res_pred = F.interpolate(
        low_res_pred,
        size=(H, W),
        mode="bilinear",
        align_corners=False,
    )
    low_res_pred = low_res_pred.squeeze().cpu().numpy()  # (H, W)
    medsam_seg = (low_res_pred > 0.5).astype(np.uint8)
    return medsam_seg
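
# Note: the img_embed argument is produced by the model's image encoder, e.g.
# (a sketch, assuming image_tensor is a (1, 3, 1024, 1024) float in [0, 1]):
#   with torch.no_grad():
#       img_embed = medsam_model.image_encoder(image_tensor)  # (1, 256, 64, 64)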

# Visualize the prompt points (left) and the predicted mask overlay (right)
def visualize(image, mask, points):
    fig, ax = plt.subplots(1, 2, figsize=(10, 5))
    ax[0].imshow(image, cmap='gray')
    for point in points:
        ax[0].plot(point[0], point[1], 'ro')  # Mark points on the image
    ax[1].imshow(image, cmap='gray')
    ax[1].imshow(mask, alpha=0.5, cmap="jet")
    plt.tight_layout()

    # Convert matplotlib figure to a PIL Image
    buf = io.BytesIO()
    fig.savefig(buf, format='png')
    plt.close(fig)  # Close the figure to release memory
    buf.seek(0)
    pil_img = Image.open(buf)
    
    return pil_img
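
# visualize returns a PIL.Image, which is what gr.Image(type="pil") in the
# interface below expects.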

# Main function for the Gradio app
def process_images(img_dict):
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # ImagePrompter returns a dict with the image array and the clicked points
    img = img_dict['image']
    points = img_dict['points']
    if len(points) == 0:
        raise ValueError("No points provided.")
    # Keep only the (x, y) coordinates; ImagePrompter rows may carry extra fields
    points_xy = np.array(points)[:, :2]

    image = img
    if image.ndim == 2:  # grayscale -> 3-channel RGB
        image = np.repeat(image[:, :, None], 3, axis=-1)
    H, W = image.shape[:2]

    # MedSAM expects a 1024x1024 RGB input normalized to [0, 1], in NCHW layout
    image_resized = transform.resize(image, (1024, 1024), order=3, preserve_range=True, anti_aliasing=True).astype(np.uint8)
    image_resized = (image_resized - image_resized.min()) / np.clip(image_resized.max() - image_resized.min(), a_min=1e-8, a_max=None)
    image_tensor = torch.tensor(image_resized).float().permute(2, 0, 1).unsqueeze(0).to(device)

    # Initialize the MedSAM model (loaded per call for simplicity; load it once
    # at module level to avoid re-reading the checkpoint on every request)
    model_checkpoint_path = "medsam_vit_b.pth"  # Replace with the path to your checkpoint
    medsam_model = sam_model_registry['vit_b'](checkpoint=model_checkpoint_path)
    medsam_model = medsam_model.to(device)
    medsam_model.eval()

    # Compute the image embedding once; medsam_inference consumes it
    with torch.no_grad():
        img_embed = medsam_model.image_encoder(image_tensor)  # (1, 256, 64, 64)

    # Rescale point coordinates from the original resolution to the 1024x1024 grid
    scale_factors = np.array([1024 / W, 1024 / H])
    points_1024 = points_xy * scale_factors

    # Perform inference
    mask = medsam_inference(medsam_model, img_embed, points_1024, H, W)

    # Visualization
    visualization = visualize(image, mask, points_xy)
    return visualization

# Set up Gradio interface
iface = gr.Interface(
    fn=process_images,
    inputs=[
        ImagePrompter(label="Image")
    ],
    outputs=[
        gr.Image(type="pil", label="Processed Image")
    ],
    title="ROI Selection with MEDSAM",
    description="Upload an image (including NRRD files) and select points of interest for processing."
)

# Launch the interface
iface.launch()
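
# iface.launch() serves the app locally; pass share=True to get a temporary
# public URL (a standard Gradio option).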