Spaces: Runtime error
watchtowerss committed · Commit 05187ec
1 Parent(s): c2afc01

huggingface -- version 2
Browse files
- .gitattributes +1 -0
- app.py +171 -78
- app_test.py +44 -21
- test.txt +0 -0
- test_beta.txt +0 -0
- test_sample/test-sample1.mp4 +3 -0
- tools/interact_tools.py +67 -67
- track_anything.py +14 -13
- tracker/base_tracker.py +36 -23
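The main functional change in this version is multi-mask tracking: `add_multi_mask` in app.py stores each SAM mask under a name like "mask_001", and `vos_tracking_video` folds the masks selected in the dropdown into one integer-labelled template mask before handing it to the XMem generator. A minimal standalone sketch of that composition step, distilled from the app.py hunks below (the helper name and toy arrays are illustrative, not part of the commit):

import numpy as np

def compose_template_mask(masks, selected_names):
    """Fold binary masks into one labelled template mask, as in vos_tracking_video().

    masks: list of binary uint8 arrays of identical shape
    selected_names: dropdown entries such as ["mask_001", "mask_003"]
    """
    selected_names = sorted(selected_names)
    first = int(selected_names[0].split("_")[1])
    template_mask = masks[first - 1] * first  # pixels of mask_001 get label 1, etc.
    for name in selected_names[1:]:
        n = int(name.split("_")[1])
        # np.clip caps overlapping pixels at the later (higher) label, mirroring the commit
        template_mask = np.clip(template_mask + masks[n - 1] * n, 0, n)
    return template_mask

# toy example: two overlapping 4x4 objects
a = np.zeros((4, 4), np.uint8); a[:2] = 1     # object 1: top half
b = np.zeros((4, 4), np.uint8); b[:, :2] = 1  # object 2: left half
print(compose_template_mask([a, b], ["mask_001", "mask_002"]))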
.gitattributes CHANGED
@@ -35,3 +35,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 assets/demo_version_1.MP4 filter=lfs diff=lfs merge=lfs -text
 assets/inpainting.gif filter=lfs diff=lfs merge=lfs -text
 assets/qingming.mp4 filter=lfs diff=lfs merge=lfs -text
+test_sample/test-sample1.mp4 filter=lfs diff=lfs merge=lfs -text
app.py CHANGED
@@ -17,7 +17,7 @@ import torchvision
 import torch
 import concurrent.futures
 import queue
-
+from tools.painter import mask_painter, point_painter
 # download checkpoints
 def download_checkpoint(url, folder, filename):
     os.makedirs(folder, exist_ok=True)
@@ -84,12 +84,21 @@ def get_frames_from_video(video_input, video_state):
         "masks": [None]*len(frames),
         "logits": [None]*len(frames),
         "select_frame_number": 0,
-        "fps": 30
+        "fps": fps
         }
-
+    video_info = "Video Name: {}, FPS: {}, Total Frames: {}".format(video_state["video_name"], video_state["fps"], len(frames))
+
+    model.samcontroler.sam_controler.reset_image()
+    model.samcontroler.sam_controler.set_image(video_state["origin_images"][0])
+    return video_state, video_info, video_state["origin_images"][0], gr.update(visible=True, maximum=len(frames), value=1), gr.update(visible=True, maximum=len(frames), value=len(frames)), \
+                        gr.update(visible=True), gr.update(visible=True), \
+                        gr.update(visible=True), gr.update(visible=True), \
+                        gr.update(visible=True), gr.update(visible=True), \
+                        gr.update(visible=True), gr.update(visible=True), \
+                        gr.update(visible=True)

 # get the select frame from gradio slider
-def select_template(image_selection_slider, video_state):
+def select_template(image_selection_slider, video_state, interactive_state):

     # images = video_state[1]
     image_selection_slider -= 1
@@ -100,8 +109,14 @@ def select_template(image_selection_slider, video_state):
     model.samcontroler.sam_controler.reset_image()
     model.samcontroler.sam_controler.set_image(video_state["origin_images"][image_selection_slider])

+    # # clear multi mask
+    # interactive_state["multi_mask"] = {"masks":[], "mask_names":[]}

-    return video_state["painted_images"][image_selection_slider], video_state
+    return video_state["painted_images"][image_selection_slider], video_state, interactive_state
+
+def get_end_number(track_pause_number_slider, interactive_state):
+    interactive_state["track_end_number"] = track_pause_number_slider
+    return interactive_state

 # use sam to get the mask
 def sam_refine(video_state, point_prompt, click_state, interactive_state, evt:gr.SelectData):
@@ -133,17 +148,65 @@ def sam_refine(video_state, point_prompt, click_state, interactive_state, evt:gr

     return painted_image, video_state, interactive_state

+def add_multi_mask(video_state, interactive_state, mask_dropdown):
+    mask = video_state["masks"][video_state["select_frame_number"]]
+    interactive_state["multi_mask"]["masks"].append(mask)
+    interactive_state["multi_mask"]["mask_names"].append("mask_{:03d}".format(len(interactive_state["multi_mask"]["masks"])))
+    mask_dropdown.append("mask_{:03d}".format(len(interactive_state["multi_mask"]["masks"])))
+    select_frame = show_mask(video_state, interactive_state, mask_dropdown)
+    return interactive_state, gr.update(choices=interactive_state["multi_mask"]["mask_names"], value=mask_dropdown), select_frame, [[],[]]
+
+def clear_click(video_state, click_state):
+    click_state = [[],[]]
+    template_frame = video_state["origin_images"][video_state["select_frame_number"]]
+    return template_frame, click_state
+
+def remove_multi_mask(interactive_state):
+    interactive_state["multi_mask"]["mask_names"] = []
+    interactive_state["multi_mask"]["masks"] = []
+    return interactive_state
+
+def show_mask(video_state, interactive_state, mask_dropdown):
+    mask_dropdown.sort()
+    select_frame = video_state["origin_images"][video_state["select_frame_number"]]
+
+    for i in range(len(mask_dropdown)):
+        mask_number = int(mask_dropdown[i].split("_")[1]) - 1
+        mask = interactive_state["multi_mask"]["masks"][mask_number]
+        select_frame = mask_painter(select_frame, mask.astype('uint8'), mask_color=mask_number+2)
+
+    return select_frame
+
 # tracking vos
-def vos_tracking_video(video_state, interactive_state):
+def vos_tracking_video(video_state, interactive_state, mask_dropdown):
     model.xmem.clear_memory()
-
-
+    if interactive_state["track_end_number"]:
+        following_frames = video_state["origin_images"][video_state["select_frame_number"]:interactive_state["track_end_number"]]
+    else:
+        following_frames = video_state["origin_images"][video_state["select_frame_number"]:]
+
+    if interactive_state["multi_mask"]["masks"]:
+        if len(mask_dropdown) == 0:
+            mask_dropdown = ["mask_001"]
+        mask_dropdown.sort()
+        template_mask = interactive_state["multi_mask"]["masks"][int(mask_dropdown[0].split("_")[1]) - 1] * (int(mask_dropdown[0].split("_")[1]))
+        for i in range(1,len(mask_dropdown)):
+            mask_number = int(mask_dropdown[i].split("_")[1]) - 1
+            template_mask = np.clip(template_mask+interactive_state["multi_mask"]["masks"][mask_number]*(mask_number+1), 0, mask_number+1)
+        video_state["masks"][video_state["select_frame_number"]]= template_mask
+    else:
+        template_mask = video_state["masks"][video_state["select_frame_number"]]
     fps = video_state["fps"]
     masks, logits, painted_images = model.generator(images=following_frames, template_mask=template_mask)

-
-
-
+    if interactive_state["track_end_number"]:
+        video_state["masks"][video_state["select_frame_number"]:interactive_state["track_end_number"]] = masks
+        video_state["logits"][video_state["select_frame_number"]:interactive_state["track_end_number"]] = logits
+        video_state["painted_images"][video_state["select_frame_number"]:interactive_state["track_end_number"]] = painted_images
+    else:
+        video_state["masks"][video_state["select_frame_number"]:] = masks
+        video_state["logits"][video_state["select_frame_number"]:] = logits
+        video_state["painted_images"][video_state["select_frame_number"]:] = painted_images

     video_output = generate_video_from_frames(video_state["painted_images"], output_path="./result/{}".format(video_state["video_name"]), fps=fps) # import video_input to name the output video
     interactive_state["inference_times"] += 1
@@ -152,7 +215,7 @@ def vos_tracking_video(video_state, interactive_state):
                                        interactive_state["positive_click_times"]+interactive_state["negative_click_times"],
                                        interactive_state["positive_click_times"],
                                        interactive_state["negative_click_times"]))
-
+
     #### shanggao code for mask save
     if interactive_state["mask_save"]:
         if not os.path.exists('./result/mask/{}'.format(video_state["video_name"].split('.')[0])):
@@ -176,6 +239,14 @@ def generate_video_from_frames(frames, output_path, fps=30):
         output_path (str): The path to save the generated video.
         fps (int, optional): The frame rate of the output video. Defaults to 30.
     """
+    # height, width, layers = frames[0].shape
+    # fourcc = cv2.VideoWriter_fourcc(*"mp4v")
+    # video = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
+    # print(output_path)
+    # for frame in frames:
+    #     video.write(frame)
+
+    # video.release()
    frames = torch.from_numpy(np.asarray(frames))
    if not os.path.exists(os.path.dirname(output_path)):
        os.makedirs(os.path.dirname(output_path))
@@ -193,8 +264,8 @@ xmem_checkpoint = download_checkpoint(xmem_checkpoint_url, folder, xmem_checkpoi

 # args, defined in track_anything.py
 args = parse_augment()
-# args.port =
-# args.device = "cuda:
+# args.port = 12315
+# args.device = "cuda:1"
 # args.mask_save = True

 model = TrackingAnything(SAM_checkpoint, xmem_checkpoint, args)
@@ -208,8 +279,15 @@ with gr.Blocks() as iface:
             "inference_times": 0,
             "negative_click_times" : 0,
             "positive_click_times": 0,
-            "mask_save": args.mask_save
-
+            "mask_save": args.mask_save,
+            "multi_mask": {
+                "mask_names": [],
+                "masks": []
+            },
+            "track_end_number": None
+        }
+    )
+
     video_state = gr.State(
         {
         "video_name": "",
@@ -225,43 +303,47 @@ with gr.Blocks() as iface:
     with gr.Row():

         # for user video input
-        with gr.Column(
-
+        with gr.Column():
+            with gr.Row(scale=0.4):
+                video_input = gr.Video(autosize=True)
+                video_info = gr.Textbox()



-        with gr.Row(
+        with gr.Row():
             # put the template frame under the radio button
-            with gr.Column(
+            with gr.Column():
                 # extract frames
                 with gr.Column():
                     extract_frames_button = gr.Button(value="Get video info", interactive=True, variant="primary")

                 # click points settins, negative or positive, mode continuous or single
                 with gr.Row():
-                    with gr.Row(
+                    with gr.Row():
                         point_prompt = gr.Radio(
                             choices=["Positive", "Negative"],
                             value="Positive",
                             label="Point Prompt",
-                            interactive=True
+                            interactive=True,
+                            visible=False)
                         click_mode = gr.Radio(
                             choices=["Continuous", "Single"],
                             value="Continuous",
                             label="Clicking Mode",
-                            interactive=True
-
-
-
-
-
-
-
-
+                            interactive=True,
+                            visible=False)
+                    with gr.Row():
+                        clear_button_click = gr.Button(value="Clear Clicks", interactive=True, visible=False).style(height=160)
+                        Add_mask_button = gr.Button(value="Add mask", interactive=True, visible=False)
+                template_frame = gr.Image(type="pil",interactive=True, elem_id="template_frame", visible=False).style(height=360)
+                image_selection_slider = gr.Slider(minimum=1, maximum=100, step=1, value=1, label="Image Selection", visible=False)
+                track_pause_number_slider = gr.Slider(minimum=1, maximum=100, step=1, value=1, label="Track end frames", visible=False)

-            with gr.Column(
-
-
+            with gr.Column():
+                mask_dropdown = gr.Dropdown(multiselect=True, value=[], label="Mask_select", info=".", visible=False)
+                remove_mask_button = gr.Button(value="Remove mask", interactive=True, visible=False)
+                video_output = gr.Video(autosize=True, visible=False).style(height=360)
+                tracking_video_predict_button = gr.Button(value="Tracking", visible=False)

     # first step: get the video information
     extract_frames_button.click(
@@ -269,27 +351,52 @@ with gr.Blocks() as iface:
         inputs=[
            video_input, video_state
        ],
-        outputs=[video_state,
+        outputs=[video_state, video_info, template_frame,
+                 image_selection_slider, track_pause_number_slider, point_prompt, click_mode, clear_button_click, Add_mask_button, template_frame,
+                 tracking_video_predict_button, video_output, mask_dropdown, remove_mask_button]
    )

    # second step: select images from slider
    image_selection_slider.release(fn=select_template,
-                                   inputs=[image_selection_slider, video_state],
-                                   outputs=[template_frame, video_state], api_name="select_image")
+                                   inputs=[image_selection_slider, video_state, interactive_state],
+                                   outputs=[template_frame, video_state, interactive_state], api_name="select_image")
+    track_pause_number_slider.release(fn=get_end_number,
+                                   inputs=[track_pause_number_slider, interactive_state],
+                                   outputs=[interactive_state], api_name="end_image")

-
+    # click select image to get mask using sam
    template_frame.select(
        fn=sam_refine,
        inputs=[video_state, point_prompt, click_state, interactive_state],
        outputs=[template_frame, video_state, interactive_state]
    )

+    # add different mask
+    Add_mask_button.click(
+        fn=add_multi_mask,
+        inputs=[video_state, interactive_state, mask_dropdown],
+        outputs=[interactive_state, mask_dropdown, template_frame, click_state]
+    )
+
+    remove_mask_button.click(
+        fn=remove_multi_mask,
+        inputs=[interactive_state],
+        outputs=[interactive_state]
+    )
+
+    # tracking video from select image and mask
    tracking_video_predict_button.click(
        fn=vos_tracking_video,
-        inputs=[video_state, interactive_state],
+        inputs=[video_state, interactive_state, mask_dropdown],
        outputs=[video_output, video_state, interactive_state]
    )

+    # click to get mask
+    mask_dropdown.change(
+        fn=show_mask,
+        inputs=[video_state, interactive_state, mask_dropdown],
+        outputs=[template_frame]
+    )

    # clear input
    video_input.clear(
@@ -306,57 +413,43 @@ with gr.Blocks() as iface:
            "inference_times": 0,
            "negative_click_times" : 0,
            "positive_click_times": 0,
-            "mask_save": args.mask_save
-
-
-
-        [],
-        [
-            video_state,
-            interactive_state,
-            click_state,
-        ],
-        queue=False,
-        show_progress=False
-    )
-    clear_button_image.click(
-        lambda: (
-            {
-                "origin_images": None,
-                "painted_images": None,
-                "masks": None,
-                "logits": None,
-                "select_frame_number": 0,
-                "fps": 30
+            "mask_save": args.mask_save,
+            "multi_mask": {
+                "mask_names": [],
+                "masks": []
            },
-
-            "inference_times": 0,
-            "negative_click_times" : 0,
-            "positive_click_times": 0,
-            "mask_save": args.mask_save
+            "track_end_number": 0
            },
-        [[],[]]
-
+        [[],[]],
+        None,
+        None,
+        gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), \
+        gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), \
+        gr.update(visible=False), gr.update(visible=False), gr.update(visible=False, value=[]), gr.update(visible=False) \
+
+        ),
        [],
        [
            video_state,
            interactive_state,
            click_state,
+            video_output,
+            template_frame,
+            tracking_video_predict_button, image_selection_slider , track_pause_number_slider,point_prompt, click_mode, clear_button_click,
+            Add_mask_button, template_frame, tracking_video_predict_button, video_output, mask_dropdown, remove_mask_button
        ],
-
        queue=False,
-        show_progress=False
+        show_progress=False)
+
+    # points clear
+    clear_button_click.click(
+        fn = clear_click,
+        inputs = [video_state, click_state,],
+        outputs = [template_frame,click_state],

-    )
-    clear_button_clike.click(
-        lambda: ([[],[]]),
-        [],
-        [click_state],
-        queue=False,
-        show_progress=False
    )
    iface.queue(concurrency_count=1)
-    iface.launch(enable_queue=True)
+    iface.launch(debug=True, enable_queue=True, server_port=args.port, server_name="0.0.0.0")
app_test.py CHANGED
@@ -1,23 +1,46 @@
-import gradio as gr
-
-def update_iframe(slider_value):
-    return f'''
-    <script>
-    window.addEventListener('message', function(event) {{
-        if (event.data.sliderValue !== undefined) {{
-            var iframe = document.getElementById("text_iframe");
-            iframe.src = "http://localhost:5001/get_text?slider_value=" + event.data.sliderValue;
-        }}
-    }}, false);
-    </script>
-    <iframe id="text_iframe" src="http://localhost:5001/get_text?slider_value={slider_value}" style="width: 100%; height: 100%; border: none;"></iframe>
-    '''
-
-iface = gr.Interface(
-    fn=update_iframe,
-    inputs=gr.inputs.Slider(minimum=0, maximum=100, step=1, default=50),
-    outputs=gr.outputs.HTML(),
-    allow_flagging=False,
-)
-
-iface.launch(server_name='0.0.0.0', server_port=12212)
+# import gradio as gr
+
+# def update_iframe(slider_value):
+#     return f'''
+#     <script>
+#     window.addEventListener('message', function(event) {{
+#         if (event.data.sliderValue !== undefined) {{
+#             var iframe = document.getElementById("text_iframe");
+#             iframe.src = "http://localhost:5001/get_text?slider_value=" + event.data.sliderValue;
+#         }}
+#     }}, false);
+#     </script>
+#     <iframe id="text_iframe" src="http://localhost:5001/get_text?slider_value={slider_value}" style="width: 100%; height: 100%; border: none;"></iframe>
+#     '''
+
+# iface = gr.Interface(
+#     fn=update_iframe,
+#     inputs=gr.inputs.Slider(minimum=0, maximum=100, step=1, default=50),
+#     outputs=gr.outputs.HTML(),
+#     allow_flagging=False,
+# )
+
+# iface.launch(server_name='0.0.0.0', server_port=12212)
+
+import gradio as gr
+
+
+def change_mask(drop):
+    return gr.update(choices=["hello", "kitty"])
+
+with gr.Blocks() as iface:
+    drop = gr.Dropdown(
+        choices=["cat", "dog", "bird"], label="Animal", info="Will add more animals later!"
+    )
+    radio = gr.Radio(["park", "zoo", "road"], label="Location", info="Where did they go?")
+    multi_drop = gr.Dropdown(
+        ["ran", "swam", "ate", "slept"], value=["swam", "slept"], multiselect=True, label="Activity", info="Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed auctor, nisl eget ultricies aliquam, nunc nisl aliquet nunc, eget aliquam nisl nunc vel nisl."
+    )
+
+    multi_drop.change(
+        fn=change_mask,
+        inputs=multi_drop,
+        outputs=multi_drop
+    )
+
+iface.launch(server_name='0.0.0.0', server_port=1223)
test.txt ADDED
File without changes

test_beta.txt ADDED
File without changes
test_sample/test-sample1.mp4 ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:403b711376a79026beedb7d0d919d35268298150120438a22a5330d0c8cdd6b6
+size 6039223
tools/interact_tools.py CHANGED
@@ -37,16 +37,16 @@ class SamControler():
         self.sam_controler = BaseSegmenter(SAM_checkpoint, model_type, device)


-    def seg_again(self, image: np.ndarray):
-        '''
-        it is used when interact in video
-        '''
-        self.sam_controler.reset_image()
-        self.sam_controler.set_image(image)
-        return
+    # def seg_again(self, image: np.ndarray):
+    #     '''
+    #     it is used when interact in video
+    #     '''
+    #     self.sam_controler.reset_image()
+    #     self.sam_controler.set_image(image)
+    #     return


-    def first_frame_click(self, image: np.ndarray, points:np.ndarray, labels: np.ndarray, multimask=True):
+    def first_frame_click(self, image: np.ndarray, points:np.ndarray, labels: np.ndarray, multimask=True,mask_color=3):
         '''
         it is used in first frame in video
         return: mask, logit, painted image(mask+point)
@@ -88,47 +88,47 @@ class SamControler():

         return mask, logit, painted_image

-    def interact_loop(self, image:np.ndarray, same: bool, points:np.ndarray, labels: np.ndarray, logits: np.ndarray=None, multimask=True):
-        origal_image = self.sam_controler.orignal_image
-        if same:
-            '''
-            true; loop in the same image
-            '''
-            prompts = {
-                'point_coords': points,
-                'point_labels': labels,
-                'mask_input': logits[None, :, :]
-            }
-            masks, scores, logits = self.sam_controler.predict(prompts, 'both', multimask)
-            mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
+    # def interact_loop(self, image:np.ndarray, same: bool, points:np.ndarray, labels: np.ndarray, logits: np.ndarray=None, multimask=True):
+    #     origal_image = self.sam_controler.orignal_image
+    #     if same:
+    #         '''
+    #         true; loop in the same image
+    #         '''
+    #         prompts = {
+    #             'point_coords': points,
+    #             'point_labels': labels,
+    #             'mask_input': logits[None, :, :]
+    #         }
+    #         masks, scores, logits = self.sam_controler.predict(prompts, 'both', multimask)
+    #         mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]

-            painted_image = mask_painter(image, mask.astype('uint8'), mask_color, mask_alpha, contour_color, contour_width)
-            painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels>0)],axis = 1), point_color_ne, point_alpha, point_radius, contour_color, contour_width)
-            painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels<1)],axis = 1), point_color_ps, point_alpha, point_radius, contour_color, contour_width)
-            painted_image = Image.fromarray(painted_image)
+    #         painted_image = mask_painter(image, mask.astype('uint8'), mask_color, mask_alpha, contour_color, contour_width)
+    #         painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels>0)],axis = 1), point_color_ne, point_alpha, point_radius, contour_color, contour_width)
+    #         painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels<1)],axis = 1), point_color_ps, point_alpha, point_radius, contour_color, contour_width)
+    #         painted_image = Image.fromarray(painted_image)

-            return mask, logit, painted_image
-        else:
-            '''
-            loop in the different image, interact in the video
-            '''
-            if image is None:
-                raise('Image error')
-            else:
-                self.seg_again(image)
-            prompts = {
-                'point_coords': points,
-                'point_labels': labels,
-            }
-            masks, scores, logits = self.sam_controler.predict(prompts, 'point', multimask)
-            mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
+    #         return mask, logit, painted_image
+    #     else:
+    #         '''
+    #         loop in the different image, interact in the video
+    #         '''
+    #         if image is None:
+    #             raise('Image error')
+    #         else:
+    #             self.seg_again(image)
+    #         prompts = {
+    #             'point_coords': points,
+    #             'point_labels': labels,
+    #         }
+    #         masks, scores, logits = self.sam_controler.predict(prompts, 'point', multimask)
+    #         mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]

-            painted_image = mask_painter(image, mask.astype('uint8'), mask_color, mask_alpha, contour_color, contour_width)
-            painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels>0)],axis = 1), point_color_ne, point_alpha, point_radius, contour_color, contour_width)
-            painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels<1)],axis = 1), point_color_ps, point_alpha, point_radius, contour_color, contour_width)
-            painted_image = Image.fromarray(painted_image)
+    #         painted_image = mask_painter(image, mask.astype('uint8'), mask_color, mask_alpha, contour_color, contour_width)
+    #         painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels>0)],axis = 1), point_color_ne, point_alpha, point_radius, contour_color, contour_width)
+    #         painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels<1)],axis = 1), point_color_ps, point_alpha, point_radius, contour_color, contour_width)
+    #         painted_image = Image.fromarray(painted_image)

-            return mask, logit, painted_image
+    #         return mask, logit, painted_image


@@ -226,31 +226,31 @@ class SamControler():



-if __name__ == "__main__":
-    points = np.array([[500, 375], [1125, 625]])
-    labels = np.array([1, 1])
-    image = cv2.imread('/hhd3/gaoshang/truck.jpg')
-    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+# if __name__ == "__main__":
+#     points = np.array([[500, 375], [1125, 625]])
+#     labels = np.array([1, 1])
+#     image = cv2.imread('/hhd3/gaoshang/truck.jpg')
+#     image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

-    sam_controler = initialize()
-    mask, logit, painted_image_full = first_frame_click(sam_controler,image, points, labels, multimask=True)
-    painted_image = mask_painter2(image, mask.astype('uint8'), background_alpha=0.8)
-    painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR)  # numpy array (h, w, 3)
-    cv2.imwrite('/hhd3/gaoshang/truck_point.jpg', painted_image)
-    cv2.imwrite('/hhd3/gaoshang/truck_change.jpg', image)
-    painted_image_full.save('/hhd3/gaoshang/truck_point_full.jpg')
+#     sam_controler = initialize()
+#     mask, logit, painted_image_full = first_frame_click(sam_controler,image, points, labels, multimask=True)
+#     painted_image = mask_painter2(image, mask.astype('uint8'), background_alpha=0.8)
+#     painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR)  # numpy array (h, w, 3)
+#     cv2.imwrite('/hhd3/gaoshang/truck_point.jpg', painted_image)
+#     cv2.imwrite('/hhd3/gaoshang/truck_change.jpg', image)
+#     painted_image_full.save('/hhd3/gaoshang/truck_point_full.jpg')

-    mask, logit, painted_image_full = interact_loop(sam_controler,image,True, points, np.array([1, 0]), logit, multimask=True)
-    painted_image = mask_painter2(image, mask.astype('uint8'), background_alpha=0.8)
-    painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR)  # numpy array (h, w, 3)
-    cv2.imwrite('/hhd3/gaoshang/truck_same.jpg', painted_image)
-    painted_image_full.save('/hhd3/gaoshang/truck_same_full.jpg')
+#     mask, logit, painted_image_full = interact_loop(sam_controler,image,True, points, np.array([1, 0]), logit, multimask=True)
+#     painted_image = mask_painter2(image, mask.astype('uint8'), background_alpha=0.8)
+#     painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR)  # numpy array (h, w, 3)
+#     cv2.imwrite('/hhd3/gaoshang/truck_same.jpg', painted_image)
+#     painted_image_full.save('/hhd3/gaoshang/truck_same_full.jpg')

-    mask, logit, painted_image_full = interact_loop(sam_controler,image, False, points, labels, multimask=True)
-    painted_image = mask_painter2(image, mask.astype('uint8'), background_alpha=0.8)
-    painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR)  # numpy array (h, w, 3)
-    cv2.imwrite('/hhd3/gaoshang/truck_diff.jpg', painted_image)
-    painted_image_full.save('/hhd3/gaoshang/truck_diff_full.jpg')
+#     mask, logit, painted_image_full = interact_loop(sam_controler,image, False, points, labels, multimask=True)
+#     painted_image = mask_painter2(image, mask.astype('uint8'), background_alpha=0.8)
+#     painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR)  # numpy array (h, w, 3)
+#     cv2.imwrite('/hhd3/gaoshang/truck_diff.jpg', painted_image)
+#     painted_image_full.save('/hhd3/gaoshang/truck_diff_full.jpg')
track_anything.py CHANGED
@@ -15,26 +15,26 @@ class TrackingAnything():
         self.xmem = BaseTracker(xmem_checkpoint, device=args.device)


-    def inference_step(self, first_flag: bool, interact_flag: bool, image: np.ndarray,
-                       same_image_flag: bool, points:np.ndarray, labels: np.ndarray, logits: np.ndarray=None, multimask=True):
-        if first_flag:
-            mask, logit, painted_image = self.samcontroler.first_frame_click(image, points, labels, multimask)
-            return mask, logit, painted_image
+    # def inference_step(self, first_flag: bool, interact_flag: bool, image: np.ndarray,
+    #                    same_image_flag: bool, points:np.ndarray, labels: np.ndarray, logits: np.ndarray=None, multimask=True):
+    #     if first_flag:
+    #         mask, logit, painted_image = self.samcontroler.first_frame_click(image, points, labels, multimask)
+    #         return mask, logit, painted_image

-        if interact_flag:
-            mask, logit, painted_image = self.samcontroler.interact_loop(image, same_image_flag, points, labels, logits, multimask)
-            return mask, logit, painted_image
+    #     if interact_flag:
+    #         mask, logit, painted_image = self.samcontroler.interact_loop(image, same_image_flag, points, labels, logits, multimask)
+    #         return mask, logit, painted_image

-        mask, logit, painted_image = self.xmem.track(image, logit)
-        return mask, logit, painted_image
+    #     mask, logit, painted_image = self.xmem.track(image, logit)
+    #     return mask, logit, painted_image

     def first_frame_click(self, image: np.ndarray, points:np.ndarray, labels: np.ndarray, multimask=True):
         mask, logit, painted_image = self.samcontroler.first_frame_click(image, points, labels, multimask)
         return mask, logit, painted_image

-    def interact(self, image: np.ndarray, same_image_flag: bool, points:np.ndarray, labels: np.ndarray, logits: np.ndarray=None, multimask=True):
-        mask, logit, painted_image = self.samcontroler.interact_loop(image, same_image_flag, points, labels, logits, multimask)
-        return mask, logit, painted_image
+    # def interact(self, image: np.ndarray, same_image_flag: bool, points:np.ndarray, labels: np.ndarray, logits: np.ndarray=None, multimask=True):
+    #     mask, logit, painted_image = self.samcontroler.interact_loop(image, same_image_flag, points, labels, logits, multimask)
+    #     return mask, logit, painted_image

     def generator(self, images: list, template_mask:np.ndarray):
@@ -53,6 +53,7 @@ class TrackingAnything():
             masks.append(mask)
             logits.append(logit)
             painted_images.append(painted_image)
+            print("tracking image {}".format(i))
         return masks, logits, painted_images
tracker/base_tracker.py CHANGED
@@ -67,6 +67,7 @@ class BaseTracker:
             logit: numpy arrays, probability map (H, W)
             painted_image: numpy array (H, W, 3)
         """
+
         if first_frame_annotation is not None: # first frame mask
             # initialisation
             mask, labels = self.mapper.convert_mask(first_frame_annotation)
@@ -87,12 +88,20 @@ class BaseTracker:
         out_mask = torch.argmax(probs, dim=0)
         out_mask = (out_mask.detach().cpu().numpy()).astype(np.uint8)

-
+        final_mask = np.zeros_like(out_mask)
+
+        # map back
+        for k, v in self.mapper.remappings.items():
+            final_mask[out_mask == v] = k
+
+        num_objs = final_mask.max()
         painted_image = frame
         for obj in range(1, num_objs+1):
-
-
-
+            if np.max(final_mask==obj) == 0:
+                continue
+            painted_image = mask_painter(painted_image, (final_mask==obj).astype('uint8'), mask_color=obj+1)
+
+        return final_mask, final_mask, painted_image

     @torch.no_grad()
     def sam_refinement(self, frame, logits, ti):
@@ -142,34 +151,38 @@ if __name__ == '__main__':
     # sam_model = BaseSegmenter(SAM_checkpoint, model_type, device=device)
     tracker = BaseTracker(XMEM_checkpoint, device, None, device)

-    # test for storage efficiency
-    frames = np.load('/ssd1/gaomingqi/efficiency/efficiency.npy')
-    first_frame_annotation = np.array(Image.open('/ssd1/gaomingqi/efficiency/template_mask.png'))
+    # # test for storage efficiency
+    # frames = np.load('/ssd1/gaomingqi/efficiency/efficiency.npy')
+    # first_frame_annotation = np.array(Image.open('/ssd1/gaomingqi/efficiency/template_mask.png'))

-    for ti, frame in enumerate(frames):
-        print(ti)
-        # if ti > 200:
-        #     break
-        if ti == 0:
-            mask, prob, painted_image = tracker.track(frame, first_frame_annotation)
-        else:
-            mask, prob, painted_image = tracker.track(frame)
-        # save
-        painted_image = Image.fromarray(painted_image)
-        painted_image.save(f'/ssd1/gaomingqi/results/TrackA/gsw/{ti:05d}.png')
+    first_frame_annotation[first_frame_annotation==1] = 15
+    first_frame_annotation[first_frame_annotation==2] = 20
+
+    save_path = '/ssd1/gaomingqi/results/TrackA/multi-change1'
+    if not os.path.exists(save_path):
+        os.mkdir(save_path)

-    tracker.clear_memory()
     for ti, frame in enumerate(frames):
-        print(ti)
-        # if ti > 200:
-        #     break
         if ti == 0:
             mask, prob, painted_image = tracker.track(frame, first_frame_annotation)
         else:
             mask, prob, painted_image = tracker.track(frame)
         # save
         painted_image = Image.fromarray(painted_image)
-        painted_image.save(f'/ssd1/gaomingqi/results/TrackA/gsw/{ti:05d}.png')
+        painted_image.save(f'{save_path}/{ti:05d}.png')
+
+    # tracker.clear_memory()
+    # for ti, frame in enumerate(frames):
+    #     print(ti)
+    #     # if ti > 200:
+    #     #     break
+    #     if ti == 0:
+    #         mask, prob, painted_image = tracker.track(frame, first_frame_annotation)
+    #     else:
+    #         mask, prob, painted_image = tracker.track(frame)
+    #     # save
+    #     painted_image = Image.fromarray(painted_image)
+    #     painted_image.save(f'/ssd1/gaomingqi/results/TrackA/gsw/{ti:05d}.png')

 # # track anything given in the first frame annotation
 # for ti, frame in enumerate(frames):