awkrail committed
Commit: 66f99fe
Parent: e5415fc

add app file

Files changed (3):
  1. app.py +209 -0
  2. highlight_frames/.gitkeep +0 -0
  3. weights/.gitkeep +0 -0
app.py ADDED
@@ -0,0 +1,209 @@
+ """
+ Copyright $today.year LY Corporation
+
+ LY Corporation licenses this file to you under the Apache License,
+ version 2.0 (the "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at:
+
+ https://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ License for the specific language governing permissions and limitations
+ under the License.
+ """
+ import os
+ import torch
+ import subprocess
+ import ffmpeg
+ import pandas as pd
+ import gradio as gr
+ from tqdm import tqdm
+ from lighthouse.models import *
+
+ # use GPU if available
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ MODEL_NAMES = ['cg_detr', 'moment_detr', 'eatr', 'qd_detr', 'tr_detr', 'uvcom']
+ FEATURES = ['clip', 'clip_slowfast']
+ TOPK_MOMENT = 5
+ TOPK_HIGHLIGHT = 5
+
+ """
+ Helper functions
+ """
+ def load_pretrained_weights():
+     file_urls = []
+     for model_name in MODEL_NAMES:
+         for feature in FEATURES:
+             file_urls.append(
+                 "https://zenodo.org/records/13363606/files/{}_{}_qvhighlight.ckpt".format(feature, model_name)
+             )
+     for file_url in tqdm(file_urls):
+         if not os.path.exists('weights/' + os.path.basename(file_url)):
+             command = 'wget -P weights/ {}'.format(file_url)
+             subprocess.run(command, shell=True)
+
+     # SlowFast weights
+     if not os.path.exists('SLOWFAST_8x8_R50.pkl'):
+         subprocess.run('wget https://dl.fbaipublicfiles.com/pyslowfast/model_zoo/kinetics400/SLOWFAST_8x8_R50.pkl', shell=True)
+
+     return file_urls
+
+ def flatten(array2d):
+     list1d = []
+     for elem in array2d:
+         list1d += elem
+     return list1d
+
+ """
+ Model initialization
+ """
+ load_pretrained_weights()
+ model = CGDETRPredictor('weights/clip_cg_detr_qvhighlight.ckpt', device=device,
+                         feature_name='clip', slowfast_path='SLOWFAST_8x8_R50.pkl')
+
+ js_codes = [r"""() => {{
+     let moment_text = document.getElementById('result_{}').textContent;
+     var replaced_text = moment_text.replace(/moment..../, '').replace(/\ Score.*/, '');
+     let start_end = JSON.parse(replaced_text);
+     document.getElementsByTagName("video")[0].currentTime = start_end[0];
+     document.getElementsByTagName("video")[0].play();
+ }}""".format(i) for i in range(TOPK_MOMENT)]
+
+ """
+ Gradio functions
+ """
+ def video_upload(video):
+     if video is None:
+         model.video_feats = None
+         model.video_mask = None
+         model.video_path = None
+         yield gr.update(value="Removed the video", visible=True)
+     else:
+         yield gr.update(value="Processing the video. This may take a minute...", visible=True)
+         model.encode_video(video)
+         yield gr.update(value="Finished video processing!", visible=True)
+
+ def model_load(radio):
+     if radio is not None:
+         yield gr.update(value="Loading the new model. This may take a minute...", visible=True)
+         global model
+         feature, model_name = radio.split('+')
+         feature, model_name = feature.strip(), model_name.strip()
+
+         if model_name == 'moment_detr':
+             model_class = MomentDETRPredictor
+         elif model_name == 'qd_detr':
+             model_class = QDDETRPredictor
+         elif model_name == 'eatr':
+             model_class = EaTRPredictor
+         elif model_name == 'tr_detr':
+             model_class = TRDETRPredictor
+         elif model_name == 'uvcom':
+             model_class = UVCOMPredictor
+         elif model_name == 'taskweave':
+             model_class = TaskWeavePredictor
+         elif model_name == 'cg_detr':
+             model_class = CGDETRPredictor
+         else:
+             raise gr.Error("Select one of the listed models")
+
+         model = model_class('weights/{}_{}_qvhighlight.ckpt'.format(feature, model_name),
+                             device=device, feature_name='{}'.format(feature), slowfast_path='SLOWFAST_8x8_R50.pkl')
+         yield gr.update(value="Model loaded: {}".format(radio), visible=True)
+
+ def predict(textbox, line, gallery):
+     prediction = model.predict(textbox)
+     if prediction is None:
+         raise gr.Error('Upload a video before clicking the `Retrieve moment & highlight detection` button.')
+     else:
+         mr_results = prediction['pred_relevant_windows']
+         hl_results = prediction['pred_saliency_scores']
+
+         buttons = []
+         for i, pred in enumerate(mr_results[:TOPK_MOMENT]):
+             buttons.append(gr.Button(value='moment {}: [{}, {}] Score: {}'.format(i+1, pred[0], pred[1], pred[2]), visible=True))
+
+         # Visualize the HD (highlight detection) scores
+         seconds = [model.clip_len * i for i in range(len(hl_results))]
+         hl_data = pd.DataFrame({'second': seconds, 'saliency_score': hl_results})
+         min_val, max_val = min(hl_results), max(hl_results) + 1
+         min_x, max_x = min(seconds), max(seconds)
+         line = gr.LinePlot(value=hl_data, x='second', y='saliency_score', visible=True, y_lim=[min_val, max_val], x_lim=[min_x, max_x])
+
+         # Show highlighted frames
+         n_largest_df = hl_data.nlargest(columns='saliency_score', n=TOPK_HIGHLIGHT)
+         highlighted_seconds = n_largest_df.second.tolist()
+         highlighted_scores = n_largest_df.saliency_score.tolist()
+
+         output_image_paths = []
+         for i, (second, score) in enumerate(zip(highlighted_seconds, highlighted_scores)):
+             output_path = "highlight_frames/highlight_{}.png".format(i)
+             (
+                 ffmpeg
+                 .input(model.video_path, ss=second)
+                 .output(output_path, vframes=1, qscale=2)
+                 .global_args('-loglevel', 'quiet', '-y')
+                 .run()
+             )
+             output_image_paths.append((output_path, "Highlight: {} - score: {:.02f}".format(i+1, score)))
+         gallery = gr.Gallery(value=output_image_paths, label='gradio', columns=5, show_download_button=True, visible=True)
+         return buttons + [line, gallery]
+
+
+ def main():
+     title = """# Moment Retrieval & Highlight Detection Demo"""
+
+     with gr.Blocks(theme=gr.themes.Soft()) as demo:
+         gr.Markdown(title)
+
+         with gr.Row():
+             with gr.Column():
+                 with gr.Group():
+                     gr.Markdown("## Model selection")
+                     radio_list = flatten([["{} + {}".format(feature, model_name) for model_name in MODEL_NAMES] for feature in FEATURES])
+                     radio = gr.Radio(radio_list, label="models", value="clip + cg_detr", info="Which model do you want to use?")
+                     load_status_text = gr.Textbox(label='Model load status', value='Model loaded: clip + cg_detr')
+
+                 with gr.Group():
+                     gr.Markdown("## Video and query")
+                     video_input = gr.Video(elem_id='video', height=600)
+                     output = gr.Textbox(label='Video processing progress')
+                     query_input = gr.Textbox(label='query')
+                     button = gr.Button("Retrieve moment & highlight detection", variant="primary")
+
+             with gr.Column():
+                 with gr.Group():
+                     gr.Markdown("## Retrieved moments")
+
+                     button_1 = gr.Button(value='moment 1', visible=False, elem_id='result_0')
+                     button_2 = gr.Button(value='moment 2', visible=False, elem_id='result_1')
+                     button_3 = gr.Button(value='moment 3', visible=False, elem_id='result_2')
+                     button_4 = gr.Button(value='moment 4', visible=False, elem_id='result_3')
+                     button_5 = gr.Button(value='moment 5', visible=False, elem_id='result_4')
+
+                     button_1.click(None, None, None, js=js_codes[0])
+                     button_2.click(None, None, None, js=js_codes[1])
+                     button_3.click(None, None, None, js=js_codes[2])
+                     button_4.click(None, None, None, js=js_codes[3])
+                     button_5.click(None, None, None, js=js_codes[4])
+
+                 # dummy components, made visible with real values by predict()
+                 with gr.Group():
+                     gr.Markdown("## Saliency score")
+                     line = gr.LinePlot(value=pd.DataFrame({'x': [], 'y': []}), x='x', y='y', visible=False)
+                     gr.Markdown("### Highlighted frames")
+                     gallery = gr.Gallery(value=[], label="highlight", columns=5, visible=False)
+
+         video_input.change(video_upload, inputs=[video_input], outputs=output)
+         radio.select(model_load, inputs=[radio], outputs=load_status_text)
+
+         button.click(predict,
+                      inputs=[query_input, line, gallery],
+                      outputs=[button_1, button_2, button_3, button_4, button_5, line, gallery])
+
+     demo.launch()
+
+ if __name__ == "__main__":
+     main()
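
For reference, here is roughly how the predictor API used above behaves outside Gradio. This is a minimal sketch, assuming the lighthouse package exposes the same calls app.py relies on (encode_video, predict, and the pred_relevant_windows / pred_saliency_scores result keys) and that the checkpoints have already been downloaded; the video path and query are hypothetical:

import torch
from lighthouse.models import CGDETRPredictor

device = "cuda" if torch.cuda.is_available() else "cpu"
model = CGDETRPredictor('weights/clip_cg_detr_qvhighlight.ckpt', device=device,
                        feature_name='clip', slowfast_path='SLOWFAST_8x8_R50.pkl')
model.encode_video('example.mp4')                  # hypothetical input video
prediction = model.predict('a man cooks dinner')   # hypothetical text query
print(prediction['pred_relevant_windows'])         # [[start_sec, end_sec, score], ...]
print(prediction['pred_saliency_scores'])          # one score per clip_len-second clip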
highlight_frames/.gitkeep ADDED
File without changes
weights/.gitkeep ADDED
File without changes
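
The two .gitkeep files create the otherwise-empty weights/ and highlight_frames/ directories that app.py writes into: the checkpoints fetched by load_pretrained_weights() land in weights/, and the frames extracted by ffmpeg in predict() land in highlight_frames/. With the dependencies installed (torch, gradio, ffmpeg-python, pandas, tqdm, and lighthouse, plus the wget and ffmpeg binaries), the demo should start with:

python app.py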