Spaces:
Sleeping
Sleeping
import gradio as gr | |
from fastmri.data.subsample import create_mask_for_mask_type | |
from fastmri.data.transforms import apply_mask, to_tensor, center_crop | |
import skimage | |
import fastmri | |
import numpy as np | |
import pandas as pd | |
import torch | |
import matplotlib.pyplot as plt | |
import uuid | |
# FastMRI Kspace Reconstruction Masks: visualize undersampling masks and their
# effects on k-space data (SSIM/PSNR metrics plus mask and reconstruction plots).
def _rss_image(kspace: torch.Tensor) -> torch.Tensor:
    """Return the root-sum-of-squares magnitude image of one multicoil k-space slice.

    Assumes ``kspace`` is the real-view tensor produced by ``to_tensor`` for a
    single slice, i.e. ``(coils, height, width, 2)``; the coil axis (dim 0) is
    combined with RSS, giving a ``(height, width)`` image.
    """
    coil_images = fastmri.complex_abs(fastmri.ifft2c(kspace))
    return fastmri.rss(coil_images, dim=0)


def _save_comparison_plot(mask_image, reconstruction, target) -> str:
    """Save a 1x3 panel (mask, reconstruction, original) as a PNG and return its path."""
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    panels = (
        (mask_image, "Mask"),
        (reconstruction, "Reconstructed Image"),
        (target, "Original Image"),
    )
    for ax, (image, title) in zip(axes, panels):
        ax.imshow(image, cmap="gray")
        ax.set_title(title)
        ax.axis("off")
    fig.tight_layout()
    # NOTE(review): one PNG per request accumulates in data/ and is never removed.
    plot_filename = f"data/{uuid.uuid4()}.png"
    try:
        fig.savefig(plot_filename)
    finally:
        # pyplot keeps every open figure alive; without this each request leaks one.
        plt.close(fig)
    return plot_filename


def main_func(
    mask_name: str,
    mask_center_fractions: float,
    accelerations: int,
    seed: int,
    slice_index: int,
    kspace_path: str = "data/prostate1_kspace.npy",
):
    """Undersample one k-space slice with a fastMRI mask and compare reconstructions.

    Assumes the stored volume is multicoil, laid out as
    ``(slices, coils, height, width)`` complex values.

    Args:
        mask_name: Mask type for ``create_mask_for_mask_type`` (the UI offers
            "random", "equispaced" and "magic").
        mask_center_fractions: Center fraction passed to the mask function
            (a value in [0, 1] in the UI).
        accelerations: Acceleration factor passed to the mask function.
        seed: Seed for mask generation; coerced to ``int``.
        slice_index: Slice (first axis of the stored volume) to reconstruct;
            coerced to ``int``.
        kspace_path: ``.npy`` file holding the k-space volume.

    Returns:
        A one-row DataFrame with SSIM, PSNR and ``num_low_frequencies`` as
        reported by ``apply_mask``, and the path of a PNG showing the mask, the
        undersampled reconstruction and the fully sampled reconstruction.
    """
    # Explicit submodule import: plain ``import skimage`` does not expose
    # ``skimage.metrics`` on every scikit-image version.
    from skimage.metrics import peak_signal_noise_ratio, structural_similarity

    # Gradio Number components deliver floats (e.g. 42.0); NumPy seeding and
    # tensor indexing both require real ints.
    seed = int(seed)
    slice_index = int(slice_index)

    # The mask varies only along the width axis and broadcasts over all leading
    # dims, so masking and reconstructing just the requested slice matches
    # processing the whole volume and indexing afterwards, at a fraction of the
    # cost. Memory-mapping reads only that slice from disk; np.array copies it
    # into a writable in-memory array.
    kspace = to_tensor(np.array(np.load(kspace_path, mmap_mode="r")[slice_index]))

    mask_func = create_mask_for_mask_type(
        mask_name,
        center_fractions=[mask_center_fractions],
        accelerations=[accelerations],
    )
    masked_kspace, mask, num_low_frequencies = apply_mask(kspace, mask_func, seed=seed)

    # Tile the 1-D column mask down the image height so it displays as 2-D.
    height = masked_kspace.shape[-3]
    mask_image = mask.squeeze().unsqueeze(0).repeat(height, 1).cpu().numpy()

    reconstruction = center_crop(_rss_image(masked_kspace), (320, 320)).cpu().numpy()
    target = center_crop(_rss_image(kspace), (320, 320)).cpu().numpy()

    data_range = target.max() - target.min()
    ssim = structural_similarity(reconstruction, target, data_range=data_range)
    # skimage's PSNR signature is (image_true, image_test).
    psnr = peak_signal_noise_ratio(target, reconstruction, data_range=data_range)
    df = pd.DataFrame(
        {"SSIM": [ssim], "PSNR": [psnr], "Num Low Frequencies": [num_low_frequencies]}
    )

    plot_filename = _save_comparison_plot(mask_image, reconstruction, target)
    return df, plot_filename
# Gradio UI: inputs map positionally onto main_func's first five parameters.
demo = gr.Interface(
    fn=main_func,
    inputs=[
        gr.Radio(["random", "equispaced", "magic"], label="Mask Type", value="equispaced"),
        gr.Slider(minimum=0.0, maximum=1.0, value=0.4, label="Center Fraction"),
        # precision=0 makes Gradio round and deliver ints; without it these
        # arrive as floats, which mask seeding and slice indexing reject.
        gr.Number(value=4, precision=0, label="Acceleration"),
        gr.Number(value=42, precision=0, label="Seed"),
        gr.Number(value=15, precision=0, label="Slice Index"),
    ],
    outputs=[
        gr.Dataframe(headers=["SSIM", "PSNR", "Num Low Frequencies"]),
        gr.Image(type="filepath", label="Plot"),
    ],
    title="FastMRI Kspace Reconstruction Masks",
    description="This app allows you to visualize the masks and their effects on the kspace data.",
)

# Only start the server when run as a script, so importing this module
# (e.g. by tooling or tests) has no side effects.
if __name__ == "__main__":
    demo.launch()