"""Gradio demo: visualize fastMRI undersampling masks and their effect on reconstructions."""

import functools
import uuid

import fastmri
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import skimage
import skimage.metrics  # `import skimage` alone does not load `metrics` on older releases
import torch
from fastmri.data.subsample import create_mask_for_mask_type
from fastmri.data.transforms import apply_mask, center_crop, to_tensor
from matplotlib.figure import Figure

DEFAULT_KSPACE_PATH = "data/prostate1_kspace.npy"
CROP_SIZE = (320, 320)


@functools.lru_cache(maxsize=1)
def _load_kspace(path: str) -> np.ndarray:
    """Load a k-space volume from a ``.npy`` file, caching the most recent one.

    Gradio calls ``main_func`` on every submit, so re-reading the same file each
    time is wasted I/O. The returned array is shared between calls and must not
    be mutated by callers.
    """
    return np.load(path)


def _magnitude_image(kspace: torch.Tensor) -> torch.Tensor:
    """Return the magnitude image for one slice of k-space.

    ``kspace`` is the output of ``to_tensor`` for a single slice: the last
    dimension holds (real, imag). When a leading coil axis is present
    (coils, H, W, 2), the coil images are combined with root-sum-of-squares;
    otherwise (H, W, 2) the magnitude image is returned directly.
    """
    image = fastmri.complex_abs(fastmri.ifft2c(kspace))
    if image.ndim > 2:
        image = fastmri.rss(image, dim=0)
    return image


def main_func(
    mask_name: str,
    mask_center_fractions: float,
    accelerations: int,
    seed: int,
    slice_index: int,
    kspace_path: str = DEFAULT_KSPACE_PATH,
):
    """Undersample one slice of k-space with a fastMRI mask and score the result.

    Args:
        mask_name: fastMRI mask type passed to ``create_mask_for_mask_type``
            (the UI offers "random", "equispaced" and "magic").
        mask_center_fractions: Fraction of central k-space columns to keep.
        accelerations: Undersampling (acceleration) factor; must be positive.
        seed: Seed for the mask's random number generator; ``None`` means unseeded.
        slice_index: Index along the first (slice) axis of the volume.
        kspace_path: ``.npy`` file holding the k-space volume. The volume is
            assumed to be laid out (slices, coils, H, W) or (slices, H, W).

    Returns:
        A tuple ``(metrics, plot_path)``. ``metrics`` is a one-row DataFrame
        with SSIM, PSNR and the number of low frequencies kept. ``plot_path``
        is the path of a PNG showing the mask, the undersampled reconstruction
        and the fully-sampled image.

    Raises:
        gr.Error: If the acceleration is not positive, or the slice index is
            missing or out of range.
    """
    # gr.Number yields floats unless precision=0; tensor indexing and the mask
    # RNG both require ints.
    seed = None if seed is None else int(seed)
    if slice_index is None:
        raise gr.Error("Slice index is required.")
    slice_index = int(slice_index)
    if accelerations is None or accelerations <= 0:
        raise gr.Error("Acceleration must be a positive number.")

    volume = _load_kspace(kspace_path)
    num_slices = volume.shape[0]
    if not -num_slices <= slice_index < num_slices:
        raise gr.Error(
            f"Slice index {slice_index} is out of range for a volume with {num_slices} slices."
        )

    # The generated mask depends only on the trailing (H, W) dimensions and the
    # seed. Undersampling just the selected slice therefore gives the same result
    # as undersampling the whole volume, without transforming every other slice.
    kspace = to_tensor(volume[slice_index])

    mask_func = create_mask_for_mask_type(
        mask_name,
        center_fractions=[mask_center_fractions],
        accelerations=[accelerations],
    )
    subsampled_kspace, mask, num_low_frequencies = apply_mask(kspace, mask_func, seed=seed)

    # The mask varies only along the column axis. Tile it over the k-space
    # height so it can be displayed as a 2-D image.
    mask_image = mask.squeeze().unsqueeze(0).repeat(kspace.shape[-3], 1).cpu().numpy()

    target = _magnitude_image(kspace)
    reconstruction = _magnitude_image(subsampled_kspace)

    # center_crop raises if the requested size exceeds the image, so clamp it.
    height, width = target.shape[-2:]
    crop = (min(CROP_SIZE[0], height), min(CROP_SIZE[1], width))
    target = center_crop(target, crop).cpu().numpy()
    reconstruction = center_crop(reconstruction, crop).cpu().numpy()

    # Score the reconstruction against the fully-sampled image, using the
    # fully-sampled image's intensity range as the data range.
    data_range = target.max() - target.min()
    ssim = skimage.metrics.structural_similarity(target, reconstruction, data_range=data_range)
    psnr = skimage.metrics.peak_signal_noise_ratio(target, reconstruction, data_range=data_range)
    metrics = pd.DataFrame(
        {"SSIM": [ssim], "PSNR": [psnr], "Num Low Frequencies": [num_low_frequencies]}
    )

    # Use the object-oriented Figure API instead of pyplot. It keeps no global
    # "current figure", so nothing accumulates between requests and concurrent
    # Gradio workers do not share plotting state.
    fig = Figure(figsize=(15, 5))
    axes = fig.subplots(1, 3)
    panels = (
        (mask_image, "Mask"),
        (reconstruction, "Reconstructed Image"),
        (target, "Original Image"),
    )
    for ax, (image, title) in zip(axes, panels):
        ax.imshow(image, cmap="gray")
        ax.set_title(title)
        ax.axis("off")
    fig.tight_layout()

    plot_filename = f"data/{uuid.uuid4()}.png"
    fig.savefig(plot_filename)
    return metrics, plot_filename


demo = gr.Interface(
    fn=main_func,
    inputs=[
        gr.Radio(["random", "equispaced", "magic"], label="Mask Type", value="equispaced"),
        gr.Slider(minimum=0.0, maximum=1.0, value=0.4, label="Center Fraction"),
        # precision=0 makes Gradio deliver ints instead of floats.
        gr.Number(value=4, label="Acceleration", precision=0),
        gr.Number(value=42, label="Seed", precision=0),
        gr.Number(value=15, label="Slice Index", precision=0),
    ],
    outputs=[
        gr.Dataframe(headers=["SSIM", "PSNR", "Num Low Frequencies"]),
        gr.Image(type="filepath", label="Plot"),
    ],
    title="FastMRI Kspace Reconstruction Masks",
    description="This app allows you to visualize the masks and their effects on the kspace data.",
)

if __name__ == "__main__":
    demo.launch()