# GANcMRI / app.py
# vukadinovic936
# added description
# 71c1b43
import os
import pickle
import random
import subprocess
import tempfile
from pathlib import Path

import imageio
import numpy as np
import streamlit as st
import tensorflow as tf

import dnnlib
from dnnlib import tflib
def check_gpu():
    """Ask TensorFlow whether a GPU device is available.

    Returns:
        The result of ``tf.test.is_gpu_available`` called with
        ``cuda_only=False`` and no minimum CUDA compute capability.
    """
    query = dict(cuda_only=False, min_cuda_compute_capability=None)
    gpu_found = tf.test.is_gpu_available(**query)
    return gpu_found
def generate_image_from_projected_latents(latent_vector):
    """Run the synthesis network on a latent vector.

    Relies on the module-level ``Gs`` and ``Gs_kwargs`` that the script
    entry point defines before any video is generated.

    Args:
        latent_vector: Latent input handed unchanged to
            ``Gs.components.synthesis.run``.

    Returns:
        Whatever the synthesis network returns for ``latent_vector`` when
        invoked with ``Gs_kwargs``.
    """
    synthesis_net = Gs.components.synthesis
    return synthesis_net.run(latent_vector, **Gs_kwargs)
## define video generation methods
def ED_to_ES(latent_code):
    """Render a video by sweeping the latent code along the "time" direction.

    The offset applied to ``latent_code`` ramps up through 0/25 .. 24/25 and
    then back down through 24/25 .. 0/25, giving 50 frames in total.

    Args:
        latent_code: Starting latent code; it is not modified.

    Returns:
        ``np.ndarray`` stacking the first image produced for each of the
        50 offsets, in playback order.
    """
    steps = 25
    ramp_up = [k / steps for k in range(0, steps)]
    ramp_down = [1 - k / steps for k in range(1, steps + 1)]
    time_direction = latent_controls["time"]

    frames = []
    # One pass over the concatenated schedule replaces two identical loops.
    for amount in ramp_up + ramp_down:
        shifted_code = latent_code + time_direction * amount
        rendered = generate_image_from_projected_latents(shifted_code)
        frames.append(np.array(rendered[0]))
    return np.array(frames)
def frame_to_frame(latent_code):
    """Render a video by chaining the per-frame latent offsets.

    Starting from a copy of ``latent_code``, the offsets stored in
    ``latent_controls`` under the keys ``'01'``, ``'12'``, ..., ``'4849'``
    are added one after another, and the synthesis output is collected after
    every step (50 outputs including the starting frame).

    Args:
        latent_code: Latent code of the first frame; it is not modified.

    Returns:
        ``np.ndarray`` of all collected outputs with singleton axes removed.
    """
    current_code = np.copy(latent_code)
    frames = [generate_image_from_projected_latents(current_code)]
    for step in range(49):
        transition_key = f'{step}{step+1}'
        current_code = current_code + latent_controls[transition_key]
        frames.append(generate_image_from_projected_latents(current_code))
    stacked = np.array(frames)
    return stacked.squeeze()
@st.cache(allow_output_mutation=True) # Cache to avoid reloading the model every time
def load_initial_setup():
    """Load the generator, its synthesis settings and the latent directions.

    Initialises TensorFlow through ``tflib``, unpickles the networks from
    ``best_net.pkl``, loads every ``.npy`` file found in ``trajectories/``
    and assigns random values to the synthesis network's noise variables.
    Both paths are resolved against the current working directory.

    Returns:
        tuple: ``(Gs, Gs_kwargs, latent_controls, sess)`` where ``Gs`` is the
        third network stored in the pickle, ``Gs_kwargs`` are the keyword
        arguments to use for synthesis runs, ``latent_controls`` maps each
        ``.npy`` file name (without the extension) to its loaded array, and
        ``sess`` is the default TensorFlow session after initialisation.
    """
    tflib.init_tf()
    sess = tf.get_default_session()

    # Open the file only once TensorFlow is up, and inside ``with``, so the
    # handle cannot leak if initialisation fails.
    # NOTE: unpickling executes arbitrary code; 'best_net.pkl' must be a trusted file.
    with open('best_net.pkl', 'rb') as stream:
        _G, _D, Gs = pickle.load(stream, encoding='latin1')

    Gs_kwargs = dnnlib.EasyDict()
    Gs_kwargs.output_transform = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)
    Gs_kwargs.randomize_noise = False

    # Load latent directions, keyed by file name without the '.npy' suffix.
    suffix = '.npy'
    files = [x for x in Path('trajectories/').iterdir() if str(x).endswith(suffix)]
    latent_controls = {f.name[:-len(suffix)]: np.load(f) for f in files}

    # Fill the synthesis noise inputs with random values once; since
    # randomize_noise is False they are not redrawn on each synthesis run.
    noise_vars = [var for name, var in Gs.components.synthesis.vars.items() if name.startswith('noise')]
    rnd = np.random.RandomState()
    tflib.set_vars({var: rnd.randn(*var.shape.as_list()) for var in noise_vars})
    return Gs, Gs_kwargs, latent_controls, sess
if __name__=="__main__":
    # Directory containing this script, and the heart image path derived from it.
    dir_path = os.path.dirname(os.path.realpath(__file__))
    heart_image_path = os.path.join(dir_path, 'heart.png')

    # Page title: inject the CSS class, then render the title with it.
    st.markdown("""
<style>
.logo-test{
font-weight:700 !important;
font-size:50px !important;
color:#FF0000 !important;
text-align: center;
}
</style>
""",unsafe_allow_html=True)
    st.markdown('<p class="logo-test">GANcMRI</p>', unsafe_allow_html=True)

    # Description and sliders
    st.markdown("""
This demo showcases GANcMRI: Synthetic cardiac MRI generation.Upon starting the demo or refreshing the page, a unique video will be automatically generated based on two methods described in our paper: ED-to-ES and Frame-to-Frame. These methods simulate cardiac function and movement in a realistic manner. The demo includes interactive sliders that allow you to adjust two parameters:
1. **Sphericity Index:** This slider controls the sphericity of the left ventricle in the generated video. [The sphericity index](https://www.cell.com/med/pdf/S2666-6340(23)00069-7.pdf) is a measure of how spherical (round) the left ventricle appears, which is an important aspect in assessing certain heart conditions.
2. **Left Ventricular Volume:** With this slider, you can modify the size of the left ventricle.
""")
    sphericity_index = st.slider("Sphericity Index", -2., 3., 0.0)
    lv_area = st.slider("Left Ventricular Volume", -2., 3., 0.0)

    # Keep one seed per browser session so that moving a slider re-renders the
    # same underlying sample instead of drawing a new one.
    if 'random_number' not in st.session_state:
        st.session_state.random_number = random.randint(0, 1000000)

    cols = st.columns(2)
    with cols[0]:
        st.caption('ED-to-ES')
    with cols[1]:
        st.caption('Frame-to-Frame')

    rnd = np.random.RandomState(st.session_state.random_number)
    # Gs, Gs_kwargs and latent_controls stay module-level names: the video
    # generation functions above read them as globals.
    Gs, Gs_kwargs, latent_controls, sess = load_initial_setup()

    with sess.as_default():
        z = rnd.randn(1, *Gs.input_shape[1:])
        random_img_latent_code = Gs.components.mapping.run(z,None)
        # make it be ED frame
        random_img_latent_code -= 0.7*latent_controls['time']

        # Apply physiological adjustment
        adjusted_latent_code = np.copy(random_img_latent_code)
        adjusted_latent_code += sphericity_index * latent_controls['sphericity_index']
        adjusted_latent_code += lv_area * latent_controls['lv_area']

        ed_to_es_vid = ED_to_ES(adjusted_latent_code)
        f_to_f_vid = frame_to_frame(adjusted_latent_code)

    # Encode inside a private temporary directory: concurrent sessions no
    # longer overwrite each other's fixed-name files, and everything is
    # removed even when encoding fails part-way.
    with tempfile.TemporaryDirectory() as work_dir:
        for idx, vid in enumerate([ed_to_es_vid, f_to_f_vid]):
            temp_video_path = os.path.join(work_dir, f"output{idx}.mp4")
            out_path = os.path.join(work_dir, f"fixed_out{idx}.mp4")

            # The context manager closes the writer even if a frame is rejected.
            with imageio.get_writer(temp_video_path, fps=20) as writer:
                for frame in vid:
                    writer.append_data(frame)

            # Re-encode with libx264; "-y" overwrites without prompting and
            # check=True surfaces an ffmpeg failure instead of showing a
            # missing or stale video.
            command = ["ffmpeg", "-y", "-i", temp_video_path, "-vcodec", "libx264", out_path]
            subprocess.run(command, check=True)

            # As before, the file is deleted right after this call returns.
            with cols[idx]:
                st.video(out_path)