File size: 5,823 Bytes
aa4251c
 
a47a354
 
 
 
 
92a78e5
a47a354
 
ed77274
aa4251c
 
 
 
a47a354
 
 
ed77274
 
 
 
 
a47a354
ed77274
 
 
 
 
 
 
 
 
 
 
a47a354
 
 
 
 
 
 
 
 
b124c89
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a47a354
b124c89
a47a354
b124c89
a47a354
b124c89
71c1b43
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ed77274
 
71c1b43
ed77274
 
 
 
 
 
 
 
 
 
 
b124c89
 
 
 
 
 
a47a354
ed77274
 
 
 
a47a354
ed77274
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import streamlit as st
import tensorflow as tf
import pickle
import numpy as np
from pathlib import Path
import dnnlib
from dnnlib import tflib
import imageio
import os
import subprocess
import random

def check_gpu():
    """Report whether TensorFlow can see a usable GPU device.

    Returns:
        bool: True if any GPU is available (no CUDA-only or minimum
        compute-capability restriction is applied).
    """
    gpu_found = tf.test.is_gpu_available(
        cuda_only=False, min_cuda_compute_capability=None
    )
    return gpu_found

def generate_image_from_projected_latents(latent_vector):
    """Synthesize image frame(s) from a projected latent code.

    NOTE(review): relies on module-level ``Gs`` and ``Gs_kwargs`` that are
    bound in the ``__main__`` section — confirm load order before reuse.
    """
    return Gs.components.synthesis.run(latent_vector, **Gs_kwargs)
## define video generation methods
def ED_to_ES(latent_code):
    """Render one cardiac cycle by sweeping along the 'time' latent direction.

    The latent code is moved from 0 up to 24/25 of the 'time' trajectory and
    then back down to 0, synthesizing one frame per step (50 frames total).

    Returns:
        np.ndarray: stacked frames, first axis is the frame index.
    """
    ramp_up = [step / 25 for step in range(0, 25)]
    ramp_down = [1 - step / 25 for step in range(1, 26)]

    frames = []
    # Single pass over the concatenated up-then-down ramp; the frame
    # sequence is identical to looping the two ramps separately.
    for fraction in ramp_up + ramp_down:
        shifted_code = latent_code + latent_controls["time"] * fraction
        rendered = generate_image_from_projected_latents(shifted_code)
        frames.append(np.array(rendered[0]))

    return np.array(frames)
def frame_to_frame(latent_code):
    """Render a 50-frame video by chaining learned frame-to-frame latent steps.

    Starting from ``latent_code``, repeatedly adds the trajectories named
    '01', '12', ..., '4849' from ``latent_controls``, synthesizing one frame
    after each step (plus the initial frame).

    Returns:
        np.ndarray: stacked frames with singleton dimensions squeezed out.
    """
    current_code = np.copy(latent_code)
    frames = [generate_image_from_projected_latents(current_code)]
    for step in range(49):
        current_code = current_code + latent_controls[f'{step}{step+1}']
        frames.append(generate_image_from_projected_latents(current_code))
    return np.array(frames).squeeze()

@st.cache(allow_output_mutation=True)  # Cache to avoid reloading the model every time
def load_initial_setup():
    """Load the pretrained StyleGAN, latent trajectories, and TF session.

    Returns:
        tuple: ``(Gs, Gs_kwargs, latent_controls, sess)`` where
            Gs: generator network unpickled from 'best_net.pkl',
            Gs_kwargs: run options for the synthesis network (uint8 NHWC
                output, noise randomization disabled),
            latent_controls: dict mapping trajectory name (file stem) to
                the np.ndarray loaded from 'trajectories/<name>.npy',
            sess: the default TensorFlow session created by tflib.init_tf().
    """
    tflib.init_tf()
    sess = tf.get_default_session()

    # SECURITY NOTE: pickle.load executes arbitrary code on load; only use
    # trusted checkpoints such as the bundled 'best_net.pkl'.
    # (fix: open inside `with` so the handle is always closed)
    with open('best_net.pkl', 'rb') as stream:
        G, D, Gs = pickle.load(stream, encoding='latin1')

    Gs_kwargs = dnnlib.EasyDict()
    Gs_kwargs.output_transform = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)
    Gs_kwargs.randomize_noise = False

    # Load latent directions: one named trajectory per .npy file.
    files = [x for x in Path('trajectories/').iterdir() if str(x).endswith('.npy')]
    latent_controls = {f.name[:-4]: np.load(f) for f in files}

    # Randomize the synthesis network's per-layer noise inputs once.
    # (fix: the original computed this identical noise_vars list twice)
    noise_vars = [var for name, var in Gs.components.synthesis.vars.items() if name.startswith('noise')]
    rnd = np.random.RandomState()
    tflib.set_vars({var: rnd.randn(*var.shape.as_list()) for var in noise_vars})

    return Gs, Gs_kwargs, latent_controls, sess

if __name__=="__main__":
    # Set the directory to the script's location
    dir_path = os.path.dirname(os.path.realpath(__file__))
    heart_image_path = os.path.join(dir_path, 'heart.png')

    # Page header: red, centered "GANcMRI" title.
    st.markdown("""
      <style>
      .logo-test{
        font-weight:700 !important;
        font-size:50px !important;
        color:#FF0000 !important;
        text-align: center;
      }
      </style>
    """,unsafe_allow_html=True)
    st.markdown('<p class="logo-test">GANcMRI</p>', unsafe_allow_html=True)

    # Description sliders (fix: missing space after "generation.")
    st.markdown("""
    This demo showcases GANcMRI: Synthetic cardiac MRI generation. Upon starting the demo or refreshing the page, a unique video will be automatically generated based on two methods described in our paper: ED-to-ES and Frame-to-Frame. These methods simulate cardiac function and movement in a realistic manner. The demo includes interactive sliders that allow you to adjust two parameters:

    1. **Sphericity Index:** This slider controls the sphericity of the left ventricle in the generated video. [The sphericity index](https://www.cell.com/med/pdf/S2666-6340(23)00069-7.pdf) is a measure of how spherical (round) the left ventricle appears, which is an important aspect in assessing certain heart conditions.
    2. **Left Ventricular Volume:** With this slider, you can modify the size of the left ventricle.
    """)    

    # Latent-direction strengths chosen by the user.
    sphericity_index = st.slider("Sphericity Index", -2., 3., 0.0)
    lv_area = st.slider("Left Ventricular Volume", -2., 3., 0.0)

    # Check if 'random_number' is already in the session state, if not, set a
    # random number for seed so the same latent code persists across
    # slider-triggered reruns.
    if 'random_number' not in st.session_state:
        st.session_state.random_number = random.randint(0, 1000000)
    cols = st.columns(2)
    
    with cols[0]:
        st.caption('ED-to-ES')
    with cols[1]:
        st.caption('Frame-to-Frame')

    rnd = np.random.RandomState(st.session_state.random_number)
    Gs, Gs_kwargs, latent_controls, sess = load_initial_setup()
    with sess.as_default():
        # Sample a latent code deterministically from the session seed.
        z = rnd.randn(1, *Gs.input_shape[1:])
        random_img_latent_code = Gs.components.mapping.run(z,None)
        #make it be ED frame
        random_img_latent_code -= 0.7*latent_controls['time']

        # Apply physiological adjustment
        adjusted_latent_code = np.copy(random_img_latent_code)
        adjusted_latent_code += sphericity_index * latent_controls['sphericity_index']
        adjusted_latent_code += lv_area * latent_controls['lv_area']

        ed_to_es_vid = ED_to_ES(adjusted_latent_code)
        f_to_f_vid = frame_to_frame(adjusted_latent_code)
        for idx,vid in enumerate([ed_to_es_vid, f_to_f_vid]):
            # Write raw frames, then re-encode with libx264 so st.video can
            # play the result in the browser.
            temp_video_path=f"output{idx}.mp4"
            writer=imageio.get_writer(temp_video_path, fps=20)
            for i in range(vid.shape[0]):
                frame = vid[i]
                writer.append_data(frame)
            writer.close()
            out_path = f"fixed_out{idx}.mp4"
            # fix: '-y' overwrites any stale output left by a previously
            # crashed run; without it ffmpeg blocks on an interactive
            # overwrite prompt and the app hangs.
            command = ["ffmpeg", "-y", "-i", temp_video_path, "-vcodec", "libx264", out_path]
            subprocess.run(command)
            with cols[idx]:
                st.video(out_path)
            os.remove(temp_video_path)
            os.remove(out_path)