"""
Copyright $today.year LY Corporation

LY Corporation licenses this file to you under the Apache License,
version 2.0 (the "License"); you may not use this file except in compliance
with the License. You may obtain a copy of the License at:

  https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations
under the License.
"""
import os
import torch
import subprocess
import ffmpeg
import pandas as pd
import gradio as gr
from tqdm import tqdm
from lighthouse.models import *

# use GPU if available
device = "cuda" if torch.cuda.is_available() else "cpu"
MODEL_NAMES = ['cg_detr', 'moment_detr', 'eatr', 'qd_detr', 'tr_detr', 'uvcom']
FEATURES = ['clip', 'clip_slowfast']
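# how many retrieved moments (buttons) and top-saliency frames (images) to display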
TOPK_MOMENT = 5
TOPK_HIGHLIGHT = 5

"""
Helper functions
"""
def load_pretrained_weights():
    file_urls = []
    for model_name in MODEL_NAMES:
        for feature in FEATURES:
            file_urls.append(
                "https://zenodo.org/records/13363606/files/{}_{}_qvhighlight.ckpt".format(feature, model_name)
            )
    for file_url in tqdm(file_urls):
        if not os.path.exists('weights/' + os.path.basename(file_url)):
            command = 'wget -P weights/ {}'.format(file_url)
            subprocess.run(command, shell=True)

    # SlowFast backbone weights (used for the clip_slowfast features)
    if not os.path.exists('SLOWFAST_8x8_R50.pkl'):
        subprocess.run('wget https://dl.fbaipublicfiles.com/pyslowfast/model_zoo/kinetics400/SLOWFAST_8x8_R50.pkl', shell=True)

    return file_urls

def flatten(array2d):
    list1d = []
    for elem in array2d:
        list1d += elem
    return list1d
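# e.g. flatten([['a', 'b'], ['c']]) -> ['a', 'b', 'c']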

"""
Model initialization
"""
load_pretrained_weights()
model = CGDETRPredictor('weights/clip_cg_detr_qvhighlight.ckpt', device=device, 
                        feature_name='clip', slowfast_path='SLOWFAST_8x8_R50.pkl')

# JS snippet per moment button: parse the [start, end] span out of the button
# label and seek the page's <video> element to the start time.
js_codes = [r"""() => {{
            let moment_text = document.getElementById('result_{}').textContent;
            let replaced_text = moment_text.replace(/moment \d+: /, '').replace(/ Score.*/, '');
            let start_end = JSON.parse(replaced_text);
            document.getElementsByTagName("video")[0].currentTime = start_end[0];
            document.getElementsByTagName("video")[0].play();
        }}""".format(i) for i in range(TOPK_MOMENT)]

"""
Gradio functions
"""
def video_upload(video):
    # encode the uploaded video into features; yield status updates to the UI
    if video is None:
        model.video_feats = None
        model.video_mask = None
        model.video_path = None
        yield gr.update(value="Removed the video", visible=True)
    else:
        yield gr.update(value="Processing the video. This may take a minute...", visible=True)
        model.encode_video(video)
        yield gr.update(value="Finished video processing!", visible=True)

def model_load(radio):
    if radio is not None:
        yield gr.update(value="Loading the new model. This may take a minute...", visible=True)
        global model
        feature, model_name = radio.split('+')
        feature, model_name = feature.strip(), model_name.strip()

        if model_name == 'moment_detr':
            model_class = MomentDETRPredictor
        elif model_name == 'qd_detr':
            model_class = QDDETRPredictor
        elif model_name == 'eatr':
            model_class = EaTRPredictor
        elif model_name == 'tr_detr':
            model_class = TRDETRPredictor
        elif model_name == 'uvcom':
            model_class = UVCOMPredictor
        elif model_name == 'taskweave':
            # not listed in MODEL_NAMES, so unreachable from the radio options; kept for completeness
            model_class = TaskWeavePredictor
        elif model_name == 'cg_detr':
            model_class = CGDETRPredictor
        else:
            raise gr.Error("Select from the models")
        
        model = model_class('weights/{}_{}_qvhighlight.ckpt'.format(feature, model_name),
                            device=device, feature_name='{}'.format(feature), slowfast_path='SLOWFAST_8x8_R50.pkl')
        yield gr.update(value="Model loaded: {}".format(radio), visible=True)

def predict(textbox, line, gallery):
    # the incoming `line` and `gallery` values are ignored; fresh components are returned
    prediction = model.predict(textbox)
    if prediction is None:
        raise gr.Error('Upload a video before pressing the `Retrieve moment & highlight detection` button.')
    else:
        mr_results = prediction['pred_relevant_windows']
        hl_results = prediction['pred_saliency_scores']

        buttons = []
        for i, pred in enumerate(mr_results[:TOPK_MOMENT]):
            buttons.append(gr.Button(value='moment {}: [{}, {}] Score: {}'.format(i+1, pred[0], pred[1], pred[2]), visible=True))
        
        # Visualize the highlight detection (HD) saliency scores as a line plot
        seconds = [model.clip_len * i for i in range(len(hl_results))]
        hl_data = pd.DataFrame({ 'second': seconds, 'saliency_score': hl_results })
        min_val, max_val = min(hl_results), max(hl_results) + 1
        min_x, max_x = min(seconds), max(seconds)
        line = gr.LinePlot(value=hl_data, x='second', y='saliency_score', visible=True, y_lim=[min_val, max_val], x_lim=[min_x, max_x])

        # Show highlight frames
        n_largest_df = hl_data.nlargest(columns='saliency_score', n=TOPK_HIGHLIGHT)
        highlighted_seconds = n_largest_df.second.tolist()
        highlighted_scores = n_largest_df.saliency_score.tolist()

        os.makedirs('highlight_frames', exist_ok=True)  # ffmpeg does not create missing output directories
        output_image_paths = []
        for i, (second, score) in enumerate(zip(highlighted_seconds, highlighted_scores)):
            output_path = "highlight_frames/highlight_{}.png".format(i)
            (
                ffmpeg
                .input(model.video_path, ss=second)
                .output(output_path, vframes=1, qscale=2)
                .global_args('-loglevel', 'quiet', '-y')
                .run()
            )
            output_image_paths.append((output_path, "Highlight: {} - score: {:.02f}".format(i+1, score)))
        gallery = gr.Gallery(value=output_image_paths, label='highlight', columns=5, show_download_button=True, visible=True)
        return buttons + [line, gallery]


def main():
    title = """# Moment Retrieval & Highlight Detection Demo"""
    
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown(title)

        with gr.Row():
            with gr.Column():
                with gr.Group():
                    gr.Markdown("## Model selection")
                    radio_list = flatten([["{} + {}".format(feature, model_name) for model_name in MODEL_NAMES] for feature in FEATURES])
                    radio = gr.Radio(radio_list, label="models", value="clip + cg_detr", info="Which model do you want to use?")
                    load_status_text = gr.Textbox(label='Model load status', value='Model loaded: clip + cg_detr')

                with gr.Group():
                    gr.Markdown("## Video and query")
                    video_input = gr.Video(elem_id='video', height=600)
                    output = gr.Textbox(label='Video processing progress')
                    query_input = gr.Textbox(label='query')
                    button = gr.Button("Retrieve moment & highlight detection", variant="primary")
            
            with gr.Column():
                with gr.Group():
                    gr.Markdown("## Retrieved moments")

                    button_1 = gr.Button(value='moment 1', visible=False, elem_id='result_0')
                    button_2 = gr.Button(value='moment 2', visible=False, elem_id='result_1')
                    button_3 = gr.Button(value='moment 3', visible=False, elem_id='result_2')
                    button_4 = gr.Button(value='moment 4', visible=False, elem_id='result_3')
                    button_5 = gr.Button(value='moment 5', visible=False, elem_id='result_4')

                    button_1.click(None, None, None, js=js_codes[0])
                    button_2.click(None, None, None, js=js_codes[1])
                    button_3.click(None, None, None, js=js_codes[2])
                    button_4.click(None, None, None, js=js_codes[3])
                    button_5.click(None, None, None, js=js_codes[4])

                # placeholder components, updated in place by the predict() outputs
                with gr.Group():
                    gr.Markdown("## Saliency score")
                    line = gr.LinePlot(value=pd.DataFrame({'x': [], 'y': []}), x='x', y='y', visible=False)
                    gr.Markdown("### Highlighted frames")
                    gallery = gr.Gallery(value=[], label="highlight", columns=5, visible=False)
                
                video_input.change(video_upload, inputs=[video_input], outputs=output)
                radio.select(model_load, inputs=[radio], outputs=load_status_text)
                
                button.click(predict, 
                            inputs=[query_input, line, gallery], 
                            outputs=[button_1, button_2, button_3, button_4, button_5, line, gallery])

    demo.launch()

if __name__ == "__main__":
    main()
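
# To launch the demo (assuming lighthouse and its dependencies are installed):
#   python <path-to-this-script>
# Gradio then prints a local URL to open in a browser.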