# Source-view header (author: osbm)
# Commit 45b3e15: "discard other methods for now"
# File size: 4.31 kB
import gradio as gr
from fastmri.data.subsample import create_mask_for_mask_type
from fastmri.data.transforms import apply_mask, to_tensor, center_crop
import skimage
import fastmri
import numpy as np
import pandas as pd
import torch
import matplotlib.pyplot as plt
import uuid
# st.title('FastMRI Kspace Reconstruction Masks')
# st.write('This app allows you to visualize the masks and their effects on the kspace data.')
def _reconstruct_slice(kspace_slice, crop_size=(320, 320)):
    """Return the center-cropped magnitude image of one multicoil k-space slice.

    Applies fastMRI's centered 2-D inverse FFT, takes the per-coil complex
    magnitude, combines coils with root-sum-of-squares over the leading axis
    (assumed to be the coil axis — the caller passes a single slice of a
    multicoil volume) and center-crops the result.

    Args:
        kspace_slice: Tensor from ``to_tensor`` with a trailing real/imag axis,
            presumably shaped (coils, rows, cols, 2).
        crop_size: Output (rows, cols) passed to ``center_crop``.

    Returns:
        The cropped image as a NumPy array.
    """
    coil_images = fastmri.complex_abs(fastmri.ifft2c(kspace_slice))
    image = fastmri.rss(coil_images, dim=0)
    return center_crop(image, crop_size).cpu().numpy()


def main_func(
    mask_name: str,
    mask_center_fractions: float,
    accelerations: float,
    seed: int,
    slice_index: int,
    # input_image: str,
):
    """Undersample one k-space slice with a fastMRI mask and score the result.

    Loads ``data/prostate1_kspace.npy``, applies the selected subsampling mask
    to the chosen slice, reconstructs both the undersampled and the fully
    sampled slice, and compares them with SSIM and PSNR.

    Args:
        mask_name: Mask type passed to ``create_mask_for_mask_type``
            (the UI offers "random", "equispaced" and "magic").
        mask_center_fractions: Fraction of low-frequency columns always kept.
        accelerations: Acceleration (undersampling) factor.
        seed: Seed for the mask generator; ``None`` leaves it unseeded.
        slice_index: Slice to reconstruct; negative values index from the end.

    Returns:
        A one-row DataFrame with columns "SSIM", "PSNR" and
        "Num Low Frequencies", and the path of a PNG (under ``data/``) showing
        the mask, the undersampled reconstruction and the original image.

    Raises:
        gr.Error: If ``slice_index`` is outside the loaded volume.
    """
    # Explicit submodule import: `import skimage` alone does not load
    # `skimage.metrics` on older scikit-image releases.
    import skimage.metrics
    # Object-oriented Figure avoids pyplot's global state, which is neither
    # thread-safe under Gradio's workers nor freed unless explicitly closed.
    from matplotlib.figure import Figure

    # TODO: re-enable input selection once the other samples are supported:
    # file_dict = {
    #     "knee singlecoil": "data/knee1_kspace.npy",
    #     "knee multicoil": "data/knee2_kspace.npy",
    #     "brain multicoil 1": "data/brain1_kspace.npy",
    #     "brain multicoil 2": "data/brain2_kspace.npy",
    #     "prostate multicoil 1": "data/prostate1_kspace.npy",
    #     "prostate multicoil 2": "data/prostate2_kspace.npy",
    # }
    # kspace = np.load(file_dict[input_image])

    # gr.Number delivers floats; RNG seeding and tensor indexing need ints.
    seed = None if seed is None else int(seed)
    slice_index = int(slice_index)

    volume = to_tensor(np.load("data/prostate1_kspace.npy"))
    num_slices = volume.shape[0]
    if not -num_slices <= slice_index < num_slices:
        raise gr.Error(
            f"Slice index must be between 0 and {num_slices - 1}, got {slice_index}."
        )
    # Reconstruct only the requested slice. Slices are independent, and the
    # fastMRI mask depends only on the column count and the seed, so this
    # matches masking/reconstructing the whole volume and then indexing.
    kspace = volume[slice_index]

    mask_func = create_mask_for_mask_type(
        mask_name, center_fractions=[mask_center_fractions], accelerations=[accelerations]
    )
    subsampled_kspace, mask, num_low_frequencies = apply_mask(kspace, mask_func, seed=seed)

    # The mask has a single non-singleton (column) axis; tile it over the rows
    # so it can be displayed as a 2-D image the same height as the k-space.
    num_rows = subsampled_kspace.shape[-3]
    mask_image = mask.squeeze().unsqueeze(0).repeat(num_rows, 1).cpu().numpy()

    reconstructed = _reconstruct_slice(subsampled_kspace)
    original = _reconstruct_slice(kspace)

    # skimage convention: (reference, test). data_range is taken from the
    # fully sampled image so both metrics share the same intensity scale.
    data_range = original.max() - original.min()
    ssim = skimage.metrics.structural_similarity(original, reconstructed, data_range=data_range)
    psnr = skimage.metrics.peak_signal_noise_ratio(original, reconstructed, data_range=data_range)
    df = pd.DataFrame({"SSIM": [ssim], "PSNR": [psnr], "Num Low Frequencies": [num_low_frequencies]})

    fig = Figure(figsize=(15, 5))
    axes = fig.subplots(1, 3)
    panels = (
        (mask_image, "Mask"),
        (reconstructed, "Reconstructed Image"),
        (original, "Original Image"),
    )
    for ax, (image, title) in zip(axes, panels):
        ax.imshow(image, cmap="gray")
        ax.set_title(title)
        ax.axis("off")
    fig.tight_layout()
    # NOTE(review): one PNG per request accumulates under data/; consider cleanup.
    plot_filename = f"data/{uuid.uuid4()}.png"
    fig.savefig(plot_filename)
    return df, plot_filename
# Gradio UI: mask type, center fraction, acceleration, seed and slice index in;
# a metrics table and the saved comparison plot out.
demo = gr.Interface(
    fn=main_func,
    inputs=[
        gr.Radio(["random", "equispaced", "magic"], label="Mask Type", value="equispaced"),
        gr.Slider(minimum=0.0, maximum=1.0, value=0.4, label="Center Fraction"),
        gr.Number(value=4, label="Acceleration"),
        # precision=0 makes Gradio round and pass an int, which the mask RNG
        # seed and the slice indexing both require.
        gr.Number(value=42, label="Seed", precision=0),
        gr.Number(value=15, label="Slice Index", precision=0),
        # gr.Radio(["knee singlecoil", "knee multicoil", "brain multicoil 1", "brain multicoil 2", "prostate multicoil 1", "prostate multicoil 2"], label="Input Image")
    ],
    outputs=[
        gr.Dataframe(headers=["SSIM", "PSNR", "Num Low Frequencies"]),
        gr.Image(type="filepath", label="Plot"),
        # gr.Image(type="numpy", image_mode="L", label="Mask",),
        # gr.Image(type="numpy", image_mode="L", label="Reconstructed Image", height=320, width=320),
        # gr.Image(type="numpy", image_mode="L", label="Original Image", height=320, width=320),
    ],
    title="FastMRI Kspace Reconstruction Masks",
    description="This app allows you to visualize the masks and their effects on the kspace data.",
)

# Launch only when run as a script, so importing this module (tests, reload
# tooling) does not start a server as a side effect.
if __name__ == "__main__":
    demo.launch()