|
import gradio as gr |
|
|
|
from fastmri.data.subsample import create_mask_for_mask_type |
|
from fastmri.data.transforms import apply_mask, to_tensor, center_crop |
|
import skimage |
|
import fastmri |
|
|
|
import numpy as np |
|
import pandas as pd |
|
import torch |
|
|
|
import matplotlib.pyplot as plt |
|
import uuid |
|
|
|
|
|
|
|
|
|
|
|
# Path of the bundled k-space volume; it is re-read on every request.
_KSPACE_PATH = "data/prostate1_kspace.npy"

# Size of the center crop applied to both images before scoring and plotting.
_CROP_SIZE = (320, 320)


def _kspace_to_image(kspace_slice: torch.Tensor) -> torch.Tensor:
    """Return the coil-combined magnitude image of one k-space slice.

    Applies fastMRI's centered inverse 2-D FFT, takes the complex magnitude and
    combines coils with root-sum-of-squares over dim 0.

    Args:
        kspace_slice: k-space in fastMRI's real/imag-last layout, presumably
            (coils, H, W, 2) -- TODO confirm against the bundled data file.
    """
    return fastmri.rss(fastmri.complex_abs(fastmri.ifft2c(kspace_slice)), dim=0)


def _save_comparison_plot(mask_image, recon_image, target_image) -> str:
    """Render mask / reconstruction / original side by side and save as PNG.

    Uses the object-oriented ``Figure`` API rather than ``pyplot``: pyplot
    figures stay registered in global state until explicitly closed (leaking
    one figure per request in a long-running server), and that global state
    is not safe to share between concurrent requests.

    Returns:
        Path of the written PNG under ``data/``.
    """
    from matplotlib.figure import Figure

    fig = Figure(figsize=(15, 5))
    axes = fig.subplots(1, 3)
    panels = (
        (mask_image, "Mask"),
        (recon_image, "Reconstructed Image"),
        (target_image, "Original Image"),
    )
    for ax, (image, title) in zip(axes, panels):
        ax.imshow(image, cmap="gray")
        ax.set_title(title)
        ax.axis("off")
    fig.tight_layout()

    # NOTE(review): a new uniquely named PNG is written per request and never
    # deleted; consider periodic cleanup or a temp directory.
    plot_filename = f"data/{uuid.uuid4()}.png"
    fig.savefig(plot_filename)
    return plot_filename


def main_func(
    mask_name: str,
    mask_center_fractions: float,
    accelerations: float,
    seed: int,
    slice_index: int,
):
    """Undersample one k-space slice with a fastMRI mask and score the result.

    Loads the bundled k-space volume, selects slice ``slice_index``, applies the
    chosen undersampling mask, reconstructs both the undersampled and the fully
    sampled slice (inverse FFT + coil RSS), center-crops both to 320x320 and
    compares them with SSIM and PSNR.

    Args:
        mask_name: Mask type passed to ``create_mask_for_mask_type``
            (the UI offers "random", "equispaced", "equispaced_fraction",
            "magic" and "magic_fraction").
        mask_center_fractions: Center fraction passed to the mask factory.
        accelerations: Acceleration factor passed to the mask factory.
            Integral values are converted to ``int`` because numeric UI inputs
            may arrive as floats and some mask types use the acceleration as
            a slice step.
        seed: Seed forwarded to ``apply_mask``.
        slice_index: Slice to process; negative indices count from the end.

    Returns:
        A one-row DataFrame with columns ``SSIM``, ``PSNR`` and
        ``Num Low Frequencies``, and the file path of the saved plot.

    Raises:
        ValueError: If ``accelerations`` is not positive or ``slice_index`` is
            out of range for the loaded volume.
    """
    from skimage.metrics import peak_signal_noise_ratio, structural_similarity

    # Numeric UI inputs may be floats; numpy seeding and indexing need ints.
    seed = int(seed)
    slice_index = int(slice_index)
    acceleration = float(accelerations)
    if acceleration <= 0:
        raise ValueError(f"accelerations must be positive, got {accelerations}")
    if acceleration.is_integer():
        acceleration = int(acceleration)

    volume = np.load(_KSPACE_PATH)
    num_slices = volume.shape[0]
    if not -num_slices <= slice_index < num_slices:
        raise ValueError(
            f"slice_index must be in [{-num_slices}, {num_slices - 1}], got {slice_index}"
        )
    # Only the requested slice is displayed and scored, so reconstruct just that
    # slice instead of the whole volume.
    kspace = to_tensor(volume[slice_index])

    mask_func = create_mask_for_mask_type(
        mask_name, center_fractions=[mask_center_fractions], accelerations=[acceleration]
    )
    masked_kspace, mask, num_low_frequencies = apply_mask(kspace, mask_func, seed=seed)

    # The sampling mask varies along one axis only; tile it over the rows
    # (dim -3 of the k-space tensor) so it can be shown as a 2-D image.
    mask_image = (
        mask.squeeze().unsqueeze(0).repeat(masked_kspace.shape[-3], 1).cpu().numpy()
    )

    recon_image = center_crop(_kspace_to_image(masked_kspace), _CROP_SIZE).cpu().numpy()
    target_image = center_crop(_kspace_to_image(kspace), _CROP_SIZE).cpu().numpy()

    # Normalize metrics by the dynamic range of the fully sampled image.
    data_range = target_image.max() - target_image.min()
    ssim = structural_similarity(target_image, recon_image, data_range=data_range)
    psnr = peak_signal_noise_ratio(target_image, recon_image, data_range=data_range)

    df = pd.DataFrame(
        {"SSIM": [ssim], "PSNR": [psnr], "Num Low Frequencies": [num_low_frequencies]}
    )

    plot_filename = _save_comparison_plot(mask_image, recon_image, target_image)
    return df, plot_filename
|
|
|
|
|
# Gradio UI wiring main_func's five inputs to its (DataFrame, plot path) outputs.
demo = gr.Interface(
    fn=main_func,
    inputs=[
        gr.Radio(
            ["random", "equispaced", "equispaced_fraction", "magic", "magic_fraction"],
            label="Mask Type",
            value="equispaced",
        ),
        gr.Slider(minimum=0.0, maximum=1.0, value=0.4, label="Center Fraction"),
        # precision=0 makes Gradio deliver ints; these values are used as a
        # slice step, an RNG seed and an array index, none of which accept floats.
        gr.Number(value=4, label="Acceleration", precision=0),
        gr.Number(value=42, label="Seed", precision=0),
        gr.Number(value=15, label="Slice Index", precision=0),
    ],
    outputs=[
        gr.Dataframe(headers=["SSIM", "PSNR", "Num Low Frequencies"]),
        gr.Image(type="filepath", label="Plot"),
    ],
    title="FastMRI Kspace Reconstruction Masks",
    description="This app allows you to visualize the masks and their effects on the kspace data."
)

# Launch only when run as a script, so importing this module has no
# server-starting side effect.
if __name__ == "__main__":
    demo.launch()