# Spaces:
# Sleeping
# Sleeping
import os
import pickle
import random
import subprocess
import tempfile
from pathlib import Path

import imageio
import numpy as np
import streamlit as st
import tensorflow as tf

import dnnlib
from dnnlib import tflib
def check_gpu():
    """Return whether TensorFlow reports an available GPU device.

    The check is permissive: any GPU counts (not only CUDA devices), and no
    minimum CUDA compute capability is required.
    """
    gpu_query = dict(cuda_only=False, min_cuda_compute_capability=None)
    return tf.test.is_gpu_available(**gpu_query)
def generate_image_from_projected_latents(latent_vector):
    """Render images from latent codes with the generator's synthesis network.

    Relies on the module-level ``Gs`` network and ``Gs_kwargs`` options that
    the ``__main__`` section gets from ``load_initial_setup()``. With those
    options the output passes through ``tflib.convert_images_to_uint8`` with
    ``nchw_to_nhwc=True``.

    Args:
        latent_vector: latent codes such as those returned by
            ``Gs.components.mapping.run``.

    Returns:
        Whatever ``Gs.components.synthesis.run`` returns for this input
        (callers index it as a batch of images).
    """
    synthesis_net = Gs.components.synthesis
    rendered = synthesis_net.run(latent_vector, **Gs_kwargs)
    return rendered
# Video generation methods (ED-to-ES and Frame-to-Frame)
def ED_to_ES(latent_code):
    """Render a 50-frame clip that walks along the "time" direction and back.

    For each step ``a`` the frame is rendered from
    ``latent_code + latent_controls["time"] * a``, where ``a`` takes the values
    ``k/25`` for k = 0..24 (25 frames), followed by ``1 - k/25`` for k = 1..25
    (25 frames). Only the first image of each rendered batch is kept.

    NOTE(review): ``24/25`` is rendered twice in a row at the turning point, and
    the full step ``a = 1`` is never reached. Confirm this is intended.

    Relies on the module-level ``latent_controls`` dict and on
    ``generate_image_from_projected_latents``.

    Args:
        latent_code: starting latent code (same kind as the synthesis input).

    Returns:
        np.ndarray stacking the 50 frames along axis 0.
    """
    n_steps = 25
    time_direction = latent_controls["time"]
    rising = [k / n_steps for k in range(n_steps)]
    falling = [1 - k / n_steps for k in range(1, n_steps + 1)]

    def render(amount):
        # Shift the latent along the time axis and keep the first image of the batch.
        batch = generate_image_from_projected_latents(latent_code + time_direction * amount)
        return np.array(batch[0])

    return np.array([render(amount) for amount in rising + falling])
def frame_to_frame(latent_code):
    """Render 50 frames by accumulating per-step latent offsets.

    Frame 0 is rendered from a copy of ``latent_code``. Each following frame
    adds the next step direction ``latent_controls['<i><i+1>']`` (keys
    '01', '12', ..., '4849') to the running latent before rendering.

    Relies on the module-level ``latent_controls`` dict and on
    ``generate_image_from_projected_latents``.

    Args:
        latent_code: starting latent code.

    Returns:
        np.ndarray of the 50 rendered batches with ``.squeeze()`` applied. This
        drops every length-1 axis, including a size-1 channel axis if the
        images have one.
    """
    step_keys = [f'{k}{k + 1}' for k in range(49)]
    latent = np.copy(latent_code)
    rendered = [generate_image_from_projected_latents(latent)]
    for key in step_keys:
        # Offsets accumulate: each step builds on the previous frame's latent.
        latent = latent + latent_controls[key]
        rendered.append(generate_image_from_projected_latents(latent))
    stacked = np.array(rendered)
    return stacked.squeeze()
def load_initial_setup():
    """Initialise TensorFlow and load everything the demo needs.

    Steps:
      1. Open ``best_net.pkl`` (relative to the working directory), run
         ``tflib.init_tf()`` and take the resulting default session.
      2. Unpickle ``(G, D, Gs)`` from the checkpoint; only ``Gs`` is returned.
      3. Build ``Gs_kwargs``: output converted with
         ``tflib.convert_images_to_uint8`` (``nchw_to_nhwc=True``) and
         ``randomize_noise=False``.
      4. Load every ``*.npy`` file in ``trajectories/`` into
         ``latent_controls``, keyed by the file name without its extension.
      5. Fill the synthesis network's ``noise*`` variables with standard-normal
         values from an unseeded ``np.random.RandomState``.

    Every call repeats all of this work. Streamlit re-executes the script on
    each interaction, so callers should wrap this function in a resource
    cache (e.g. ``st.cache_resource``) to avoid reloading the model every time.

    Returns:
        tuple: ``(Gs, Gs_kwargs, latent_controls, sess)``.

    Raises:
        FileNotFoundError: if ``best_net.pkl`` or ``trajectories/`` is missing.
    """
    # Opening first makes a missing checkpoint fail before TF start-up. The
    # context manager closes the file even if init or unpickling raises.
    with open('best_net.pkl', 'rb') as stream:
        tflib.init_tf()
        sess = tf.get_default_session()
        # NOTE: pickle.load executes code from the file; best_net.pkl must be trusted.
        # init_tf() runs first, presumably because dnnlib networks need a
        # default session when they are unpickled.
        _, _, Gs = pickle.load(stream, encoding='latin1')

    Gs_kwargs = dnnlib.EasyDict()
    Gs_kwargs.output_transform = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)
    # Presumably makes synthesis reuse the stored noise variables set below.
    Gs_kwargs.randomize_noise = False

    # Latent directions used elsewhere in this file, e.g. 'time',
    # 'sphericity_index', 'lv_area' and the frame-step keys '01', '12', ...
    latent_controls = {
        path.stem: np.load(path)
        for path in Path('trajectories/').iterdir()
        if path.suffix == '.npy'
    }

    # Randomise the synthesis network's noise inputs (not the latent code).
    noise_vars = [var for name, var in Gs.components.synthesis.vars.items() if name.startswith('noise')]
    rnd = np.random.RandomState()
    tflib.set_vars({var: rnd.randn(*var.shape.as_list()) for var in noise_vars})
    return Gs, Gs_kwargs, latent_controls, sess
def _load_setup_cached():
    """Return ``load_initial_setup()``, reusing the result across reruns when possible.

    Streamlit re-executes this script on every widget interaction. Calling the
    loader directly would therefore re-read the pickle and rebuild the networks
    on every slider move. ``st.cache_resource`` (or ``st.experimental_singleton``
    on older Streamlit releases) keeps one shared result instead. If neither API
    exists, the loader is called uncached, as before.

    NOTE(review): the cached ``Gs``/``sess`` are shared by all sessions, and the
    noise variables are randomised only on the first load.
    """
    resource_cache = getattr(st, "cache_resource", None)
    if resource_cache is None:
        resource_cache = getattr(st, "experimental_singleton", None)
    if resource_cache is None:
        return load_initial_setup()
    return resource_cache(load_initial_setup)()


def _render_mp4_bytes(frames, workdir, idx):
    """Encode ``frames`` as an H.264 mp4 inside ``workdir`` and return its bytes.

    Frames (one per index of the first axis) are written at 20 fps with
    imageio. The file is then re-encoded with ffmpeg's libx264 encoder, and that
    re-encoded file is what gets returned.

    Raises:
        subprocess.CalledProcessError: if ffmpeg exits with a non-zero status.
    """
    raw_path = os.path.join(workdir, f"output{idx}.mp4")
    h264_path = os.path.join(workdir, f"fixed_out{idx}.mp4")
    with imageio.get_writer(raw_path, fps=20) as writer:
        for frame in frames:
            writer.append_data(frame)
    # -y: overwrite without ffmpeg's interactive prompt, which would block or
    # abort if the output already exists. check=True: surface encoder failures
    # instead of displaying a missing or stale file.
    command = ["ffmpeg", "-y", "-i", raw_path, "-vcodec", "libx264", h264_path]
    subprocess.run(command, check=True)
    with open(h264_path, "rb") as video_file:
        return video_file.read()


if __name__ == "__main__":
    st.markdown("""
<style>
.logo-test{
font-weight:700 !important;
font-size:50px !important;
color:#FF0000 !important;
text-align: center;
}
</style>
""", unsafe_allow_html=True)
    st.markdown('<p class="logo-test">GANcMRI</p>', unsafe_allow_html=True)
    # Description and sliders
    st.markdown("""
This demo showcases GANcMRI: Synthetic cardiac MRI generation. Upon starting the demo or refreshing the page, a unique video will be automatically generated based on two methods described in our paper: ED-to-ES and Frame-to-Frame. These methods simulate cardiac function and movement in a realistic manner. The demo includes interactive sliders that allow you to adjust two parameters:
1. **Sphericity Index:** This slider controls the sphericity of the left ventricle in the generated video. [The sphericity index](https://www.cell.com/med/pdf/S2666-6340(23)00069-7.pdf) is a measure of how spherical (round) the left ventricle appears, which is an important aspect in assessing certain heart conditions.
2. **Left Ventricular Volume:** With this slider, you can modify the size of the left ventricle.
""")
    sphericity_index = st.slider("Sphericity Index", -2., 3., 0.0)
    lv_area = st.slider("Left Ventricular Volume", -2., 3., 0.0)
    # One seed per browser session: reruns triggered by the sliders reuse the
    # same random z, so only the physiological adjustment changes.
    if 'random_number' not in st.session_state:
        st.session_state.random_number = random.randint(0, 1000000)
    cols = st.columns(2)
    with cols[0]:
        st.caption('ED-to-ES')
    with cols[1]:
        st.caption('Frame-to-Frame')
    rnd = np.random.RandomState(st.session_state.random_number)
    # Assigned at module level on purpose: the video helpers read Gs,
    # Gs_kwargs and latent_controls as globals.
    Gs, Gs_kwargs, latent_controls, sess = _load_setup_cached()
    with sess.as_default():
        z = rnd.randn(1, *Gs.input_shape[1:])
        random_img_latent_code = Gs.components.mapping.run(z, None)
        # Step back along the "time" direction (intended to make this the ED frame).
        random_img_latent_code -= 0.7 * latent_controls['time']
        # Apply the physiological adjustments chosen with the sliders.
        adjusted_latent_code = np.copy(random_img_latent_code)
        adjusted_latent_code += sphericity_index * latent_controls['sphericity_index']
        adjusted_latent_code += lv_area * latent_controls['lv_area']
        ed_to_es_vid = ED_to_ES(adjusted_latent_code)
        f_to_f_vid = frame_to_frame(adjusted_latent_code)
    # A private temp dir per run avoids clashes between concurrent sessions
    # and is removed even if encoding fails.
    with tempfile.TemporaryDirectory() as workdir:
        for idx, vid in enumerate([ed_to_es_vid, f_to_f_vid]):
            video_bytes = _render_mp4_bytes(vid, workdir, idx)
            with cols[idx]:
                st.video(video_bytes)