freddyaboulton HF staff commited on
Commit
619c27a
1 Parent(s): b7278d2

requirements

Browse files
Files changed (2) hide show
  1. app.py +39 -34
  2. requirements.txt +2 -0
app.py CHANGED
@@ -13,29 +13,32 @@ from draw_boxes import draw_bounding_boxes
13
  image_processor = RTDetrImageProcessor.from_pretrained("PekingU/rtdetr_r50vd")
14
  model = RTDetrForObjectDetection.from_pretrained("PekingU/rtdetr_r50vd")
15
 
 
 
 
16
  @spaces.GPU
17
  def stream_object_detection(video, conf_threshold):
18
  cap = cv2.VideoCapture(video)
19
 
20
- video_codec = cv2.VideoWriter_fourcc(*"mp4v") # type: ignore
21
  fps = int(cap.get(cv2.CAP_PROP_FPS))
22
- desired_fps = fps // 5
23
- width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
24
- height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
25
 
26
  iterating, frame = cap.read()
27
 
28
  n_frames = 0
29
  n_chunks = 0
30
 
31
- name = f"output_{n_chunks}.mp4"
32
- segment_file = cv2.VideoWriter(name, video_codec, desired_fps, (width // 2, height // 2)) # type: ignore
33
  batch = []
34
 
35
  while iterating:
36
  frame = cv2.resize( frame, (0,0), fx=0.5, fy=0.5)
37
  frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
38
- if n_frames % 5 == 0:
39
  batch.append(frame)
40
  if len(batch) == 2 * desired_fps:
41
  inputs = image_processor(images=batch, return_tensors="pt")
@@ -49,34 +52,32 @@ def stream_object_detection(video, conf_threshold):
49
 
50
  boxes = image_processor.post_process_object_detection(
51
  outputs,
52
- target_sizes=torch.tensor([frame[0].shape[:2][::-1]] * len(batch)),
53
  threshold=conf_threshold)
54
 
55
- for array, box in zip(batch, boxes):
56
  pil_image = draw_bounding_boxes(Image.fromarray(array), box, model, conf_threshold)
 
57
  frame = np.array(pil_image)
58
  # Convert RGB to BGR
59
  frame = frame[:, :, ::-1].copy()
60
  segment_file.write(frame)
61
 
 
62
  segment_file.release()
 
63
  n_frames = 0
64
  n_chunks += 1
65
- yield name
66
- name = f"output_{n_chunks}.mp4"
67
- segment_file = cv2.VideoWriter(name, video_codec, fps, (width, height)) # type: ignore
68
 
69
  iterating, frame = cap.read()
70
  n_frames += 1
71
 
72
- segment_file.release()
73
- yield name
74
-
75
-
76
- css=""".my-group {max-width: 600px !important; max-height: 600 !important;}
77
- .my-column {display: flex !important; justify-content: center !important; align-items: center !important};"""
78
-
79
 
 
80
  with gr.Blocks(css=css) as app:
81
  gr.HTML(
82
  """
@@ -90,21 +91,25 @@ with gr.Blocks(css=css) as app:
90
  <a href='https://arxiv.org/abs/2304.08069' target='_blank'>arXiv</a> | <a href='https://huggingface.co/PekingU/rtdetr_r101vd_coco_o365' target='_blank'>github</a>
91
  </h3>
92
  """)
93
- with gr.Column(elem_classes=["my-column"]):
94
- with gr.Group(elem_classes=["my-group"]):
95
- video = gr.Video(label="Video Source", streaming=True)
96
- conf_threshold = gr.Slider(
97
- label="Confidence Threshold",
98
- minimum=0.0,
99
- maximum=1.0,
100
- step=0.05,
101
- value=0.30,
102
- )
103
- video.upload(
104
- fn=stream_object_detection,
105
- inputs=[video, conf_threshold],
106
- outputs=[video],
107
- )
 
 
 
 
108
 
109
  if __name__ == '__main__':
110
  app.launch()
 
13
  image_processor = RTDetrImageProcessor.from_pretrained("PekingU/rtdetr_r50vd")
14
  model = RTDetrForObjectDetection.from_pretrained("PekingU/rtdetr_r50vd")
15
 
16
+
17
+ SUBSAMPLE = 10
18
+
19
  @spaces.GPU
20
  def stream_object_detection(video, conf_threshold):
21
  cap = cv2.VideoCapture(video)
22
 
23
+ video_codec = cv2.VideoWriter_fourcc(*"x264") # type: ignore
24
  fps = int(cap.get(cv2.CAP_PROP_FPS))
25
+ desired_fps = fps // SUBSAMPLE
26
+ width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) // 2
27
+ height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) // 2
28
 
29
  iterating, frame = cap.read()
30
 
31
  n_frames = 0
32
  n_chunks = 0
33
 
34
+ name = f"output_{n_chunks}.ts"
35
+ segment_file = cv2.VideoWriter(name, video_codec, desired_fps, (width, height)) # type: ignore
36
  batch = []
37
 
38
  while iterating:
39
  frame = cv2.resize( frame, (0,0), fx=0.5, fy=0.5)
40
  frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
41
+ if n_frames % SUBSAMPLE == 0:
42
  batch.append(frame)
43
  if len(batch) == 2 * desired_fps:
44
  inputs = image_processor(images=batch, return_tensors="pt")
 
52
 
53
  boxes = image_processor.post_process_object_detection(
54
  outputs,
55
+ target_sizes=torch.tensor([(height, width)] * len(batch)),
56
  threshold=conf_threshold)
57
 
58
+ for i, (array, box) in enumerate(zip(batch, boxes)):
59
  pil_image = draw_bounding_boxes(Image.fromarray(array), box, model, conf_threshold)
60
+ pil_image.save(f"batch_{n_chunks}_detection_{i}.png")
61
  frame = np.array(pil_image)
62
  # Convert RGB to BGR
63
  frame = frame[:, :, ::-1].copy()
64
  segment_file.write(frame)
65
 
66
+ batch = []
67
  segment_file.release()
68
+ yield name
69
  n_frames = 0
70
  n_chunks += 1
71
+ name = f"output_{n_chunks}.ts"
72
+ segment_file = cv2.VideoWriter(name, video_codec, desired_fps, (width, height)) # type: ignore
 
73
 
74
  iterating, frame = cap.read()
75
  n_frames += 1
76
 
77
+ # css=""".my-group {max-width: 600px !important; max-height: 600 !important;}
78
+ # .my-column {display: flex !important; justify-content: center !important; align-items: center !important};"""
 
 
 
 
 
79
 
80
+ css=""
81
  with gr.Blocks(css=css) as app:
82
  gr.HTML(
83
  """
 
91
  <a href='https://arxiv.org/abs/2304.08069' target='_blank'>arXiv</a> | <a href='https://huggingface.co/PekingU/rtdetr_r101vd_coco_o365' target='_blank'>github</a>
92
  </h3>
93
  """)
94
+ with gr.Row():
95
+ with gr.Column():
96
+ with gr.Group(elem_classes=["my-group"]):
97
+ video = gr.Video(label="Video Source")
98
+ conf_threshold = gr.Slider(
99
+ label="Confidence Threshold",
100
+ minimum=0.0,
101
+ maximum=1.0,
102
+ step=0.05,
103
+ value=0.30,
104
+ )
105
+ with gr.Column():
106
+ output_video = gr.Video(label="Processed Video", streaming=True, autoplay=True)
107
+
108
+ video.upload(
109
+ fn=stream_object_detection,
110
+ inputs=[video, conf_threshold],
111
+ outputs=[output_video],
112
+ )
113
 
114
  if __name__ == '__main__':
115
  app.launch()
requirements.txt CHANGED
@@ -1,4 +1,6 @@
1
  safetensors==0.4.3
 
 
2
  transformers
3
  gradio-client @ git+https://github.com/gradio-app/gradio@66349fe26827e3a3c15b738a1177e95fec7f5554#subdirectory=client/python
4
  https://gradio-pypi-previews.s3.amazonaws.com/66349fe26827e3a3c15b738a1177e95fec7f5554/gradio-4.42.0-py3-none-any.whl
 
1
  safetensors==0.4.3
2
+ opencv-python
3
+ torch
4
  transformers
5
  gradio-client @ git+https://github.com/gradio-app/gradio@66349fe26827e3a3c15b738a1177e95fec7f5554#subdirectory=client/python
6
  https://gradio-pypi-previews.s3.amazonaws.com/66349fe26827e3a3c15b738a1177e95fec7f5554/gradio-4.42.0-py3-none-any.whl