stevengrove
commited on
Commit
•
2ed4a37
1
Parent(s):
e789019
Update tools/demo.py
Browse files- tools/demo.py +5 -6
tools/demo.py
CHANGED
@@ -116,10 +116,9 @@ def export_model(runner,
|
|
116 |
# dry run
|
117 |
deploy_model(fake_input)
|
118 |
|
119 |
-
os.makedirs(
|
120 |
save_onnx_path = os.path.join(
|
121 |
-
|
122 |
-
os.path.basename(args.checkpoint).replace('pth', 'onnx'))
|
123 |
# export onnx
|
124 |
with BytesIO() as f:
|
125 |
output_names = ['num_dets', 'boxes', 'scores', 'labels']
|
@@ -142,7 +141,7 @@ def export_model(runner,
|
|
142 |
return gr.update(visible=True), save_onnx_path
|
143 |
|
144 |
|
145 |
-
def demo(runner, args):
|
146 |
with gr.Blocks(title="YOLO-World") as demo:
|
147 |
with gr.Row():
|
148 |
gr.Markdown('<h1><center>YOLO-World: Real-Time Open-Vocabulary '
|
@@ -195,7 +194,7 @@ def demo(runner, args):
|
|
195 |
[output_image])
|
196 |
clear.click(lambda: [[], '', ''], None,
|
197 |
[image, input_text, output_image])
|
198 |
-
export.click(partial(export_model, runner,
|
199 |
[input_text, max_num_boxes, score_thr, nms_thr],
|
200 |
[out_download, out_download])
|
201 |
demo.launch(server_name='0.0.0.0')
|
@@ -228,4 +227,4 @@ if __name__ == '__main__':
|
|
228 |
pipeline = cfg.test_dataloader.dataset.pipeline
|
229 |
runner.pipeline = Compose(pipeline)
|
230 |
runner.model.eval()
|
231 |
-
demo(runner, args)
|
|
|
116 |
# dry run
|
117 |
deploy_model(fake_input)
|
118 |
|
119 |
+
os.makedirs('work_dirs', exist_ok=True)
|
120 |
save_onnx_path = os.path.join(
|
121 |
+
'work_dirs', 'yolow-l.onnx')
|
|
|
122 |
# export onnx
|
123 |
with BytesIO() as f:
|
124 |
output_names = ['num_dets', 'boxes', 'scores', 'labels']
|
|
|
141 |
return gr.update(visible=True), save_onnx_path
|
142 |
|
143 |
|
144 |
+
def demo(runner, args, cfg):
|
145 |
with gr.Blocks(title="YOLO-World") as demo:
|
146 |
with gr.Row():
|
147 |
gr.Markdown('<h1><center>YOLO-World: Real-Time Open-Vocabulary '
|
|
|
194 |
[output_image])
|
195 |
clear.click(lambda: [[], '', ''], None,
|
196 |
[image, input_text, output_image])
|
197 |
+
export.click(partial(export_model, runner, cfg.checkpoint),
|
198 |
[input_text, max_num_boxes, score_thr, nms_thr],
|
199 |
[out_download, out_download])
|
200 |
demo.launch(server_name='0.0.0.0')
|
|
|
227 |
pipeline = cfg.test_dataloader.dataset.pipeline
|
228 |
runner.pipeline = Compose(pipeline)
|
229 |
runner.model.eval()
|
230 |
+
demo(runner, args, cfg)
|