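# Gradio Space: streaming video object detection with RT-DETR.
# Uploaded videos are subsampled, run through RT-DETR in batches, and streamed
# back to the browser as short annotated video segments.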
import spaces
import gradio as gr
import cv2
from PIL import Image
import torch
import time
import numpy as np
from transformers import RTDetrForObjectDetection, RTDetrImageProcessor
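# draw_boxes is a helper module local to this Space; draw_bounding_boxes
# renders labeled detection boxes onto a PIL image and returns it.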
from draw_boxes import draw_bounding_boxes
image_processor = RTDetrImageProcessor.from_pretrained("PekingU/rtdetr_r50vd")
model = RTDetrForObjectDetection.from_pretrained("PekingU/rtdetr_r50vd")
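# Run detection on every SUBSAMPLE-th frame only; e.g. a 30 fps input yields
# roughly 3 processed frames per second of output video.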
SUBSAMPLE = 10
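# @spaces.GPU requests a ZeroGPU slot for each call on Hugging Face Spaces.
# The function is a generator: it yields the path of each finished .ts video
# segment so the client can start playback before the whole video is processed.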
@spaces.GPU
def stream_object_detection(video, conf_threshold):
    cap = cv2.VideoCapture(video)

    video_codec = cv2.VideoWriter_fourcc(*"x264")  # type: ignore
    fps = int(cap.get(cv2.CAP_PROP_FPS))
    desired_fps = fps // SUBSAMPLE

    # Frames are downscaled by 2x below, so halve the writer dimensions to match.
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) // 2
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) // 2

    iterating, frame = cap.read()

    n_frames = 0
    n_chunks = 0
    name = f"output_{n_chunks}.ts"
    segment_file = cv2.VideoWriter(name, video_codec, desired_fps, (width, height))  # type: ignore
    batch = []
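    # Buffer 2 * desired_fps subsampled frames (about two seconds of output
    # video) before running a single batched forward pass.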
    while iterating:
        # Downscale by 2x and convert BGR -> RGB for the image processor.
        frame = cv2.resize(frame, (0, 0), fx=0.5, fy=0.5)
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        if n_frames % SUBSAMPLE == 0:
            batch.append(frame)
        if len(batch) == 2 * desired_fps:
            inputs = image_processor(images=batch, return_tensors="pt")

            print(f"starting batch of size {len(batch)}")
            start = time.time()
            with torch.no_grad():
                outputs = model(**inputs)
            end = time.time()
            print("time taken", end - start)

            boxes = image_processor.post_process_object_detection(
                outputs,
                target_sizes=torch.tensor([(height, width)] * len(batch)),
                threshold=conf_threshold,
            )
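            # Draw the detections on each buffered frame and append it to the
            # current video segment.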
            for i, (array, box) in enumerate(zip(batch, boxes)):
                pil_image = draw_bounding_boxes(Image.fromarray(array), box, model, conf_threshold)
                pil_image.save(f"batch_{n_chunks}_detection_{i}.png")
                frame = np.array(pil_image)
                # Convert RGB to BGR
                frame = frame[:, :, ::-1].copy()
                segment_file.write(frame)
            batch = []
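            # Close out this segment and yield its path, then start a fresh
            # segment for the next batch.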
            segment_file.release()
            yield name
            n_frames = 0
            n_chunks += 1
            name = f"output_{n_chunks}.ts"
            segment_file = cv2.VideoWriter(name, video_codec, desired_fps, (width, height))  # type: ignore

        iterating, frame = cap.read()
        n_frames += 1

    # Any remaining partial batch is dropped; release the capture and the
    # final (never-yielded) segment writer.
    segment_file.release()
    cap.release()
# css=""".my-group {max-width: 600px !important; max-height: 600 !important;}
# .my-column {display: flex !important; justify-content: center !important; align-items: center !important};"""
css=""
with gr.Blocks(css=css) as app:
    gr.HTML(
        """
        <h1 style='text-align: center'>
        Video Object Detection with RT-DETR
        </h1>
        """
    )
    gr.HTML(
        """
        <h3 style='text-align: center'>
        <a href='https://arxiv.org/abs/2304.08069' target='_blank'>arXiv</a> | <a href='https://huggingface.co/PekingU/rtdetr_r101vd_coco_o365' target='_blank'>model</a>
        </h3>
        """
    )
    with gr.Row():
        with gr.Column():
            with gr.Group(elem_classes=["my-group"]):
                video = gr.Video(label="Video Source")
                conf_threshold = gr.Slider(
                    label="Confidence Threshold",
                    minimum=0.0,
                    maximum=1.0,
                    step=0.05,
                    value=0.30,
                )
        with gr.Column():
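            # streaming=True plays each yielded .ts segment as it arrives
            # instead of waiting for the full video; autoplay starts playback
            # immediately.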
            output_video = gr.Video(label="Processed Video", streaming=True, autoplay=True)
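    # stream_object_detection is a generator, so every yielded segment path is
    # pushed to output_video as soon as it is ready.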
    video.upload(
        fn=stream_object_detection,
        inputs=[video, conf_threshold],
        outputs=[output_video],
    )
if __name__ == '__main__':
    app.launch()