qsitj commited on
Commit
46dad61
·
verified ·
1 Parent(s): 15c2044

增加对置信度和类别的选择

Browse files
Files changed (1) hide show
  1. app.py +32 -12
app.py CHANGED
@@ -17,20 +17,16 @@ COLORS = ["#ff7f7f", "#ff7fbf", "#ff7fff", "#bf7fff",
17
  "#7f7fff", "#7fbfff", "#7fffff", "#7fffbf",
18
  "#7fff7f", "#bfff7f", "#ffff7f", "#ffbf7f"]
19
 
20
- fdic = {
21
- "style": "italic",
22
- "size": 24,
23
- "color": "yellow",
24
- "weight": "bold"
25
- }
26
 
27
- threshold = 80 # 置信度阈值
 
 
28
 
29
  label_color_dict = {}
30
 
31
  def query_data(in_pil_img: Image.Image):
32
  results = detector(in_pil_img)
33
- # print(f"检测结果:{results}")
34
  return results
35
 
36
 
@@ -61,6 +57,8 @@ def get_annotated_image(in_pil_img):
61
  score = round(prediction['score'] * 100, 1)
62
  if score < threshold:
63
  continue # 过滤掉低置信度的预测结果
 
 
64
 
65
  if label not in label_color_dict: # 为每个类别随机分配颜色, 后续维持一致
66
  color = choice(COLORS)
@@ -105,7 +103,7 @@ def process_video(input_video_path):
105
  timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
106
  output_video_filename = f"output_{timestamp}.mp4"
107
  output_video_path = os.path.join(output_dir, output_video_filename)
108
- # print(f"输出视频信息:{output_video_path}, {width}x{height}, {fps}fps")
109
  out = cv2.VideoWriter(output_video_path, fourcc, fps, (width, height))
110
 
111
  while True:
@@ -115,15 +113,15 @@ def process_video(input_video_path):
115
 
116
  rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
117
  pil_image = Image.fromarray(rgb_frame)
118
- # print(f"Input frame of shape {rgb_frame.shape} and type {rgb_frame.dtype}") # 调试信息
119
  annotated_frame = get_annotated_image(pil_image)
120
  bgr_frame = cv2.cvtColor(annotated_frame, cv2.COLOR_RGB2BGR)
121
- # print(f"Annotated frame of shape {bgr_frame.shape} and type {bgr_frame.dtype}") # 调试信息
122
  # 确保帧的尺寸与视频输出一致
123
  if bgr_frame.shape[:2] != (height, width):
124
  bgr_frame = cv2.resize(bgr_frame, (width, height))
125
 
126
- # print(f"Writing frame of shape {bgr_frame.shape} and type {bgr_frame.dtype}") # 调试信息
127
  out.write(bgr_frame)
128
 
129
  cap.release()
@@ -132,9 +130,31 @@ def process_video(input_video_path):
132
  # 返回输出视频路径给 Gradio
133
  return output_video_path
134
 
 
 
 
 
 
 
 
 
 
 
 
135
  with gr.Blocks(css=".gradio-container {background:lightyellow;}", title="基于AI的安全风险识别及防控应用") as demo:
136
  gr.HTML("<div style='font-family:'Times New Roman', 'Serif'; font-size:16pt; font-weight:bold; text-align:center; color:royalblue;'>基于AI的安全风险识别及防控应用</div>")
137
 
 
 
 
 
 
 
 
 
 
 
 
138
  with gr.Row():
139
  input_video = gr.Video(label="输入视频")
140
  detect_button = gr.Button("开始检测", variant="primary")
 
17
  "#7f7fff", "#7fbfff", "#7fffff", "#7fffbf",
18
  "#7fff7f", "#bfff7f", "#ffff7f", "#ffbf7f"]
19
 
 
 
 
 
 
 
20
 
21
+
22
+ threshold = 90 # 置信度阈值
23
+ label_list = ["person", "car", "truck"]
24
 
25
  label_color_dict = {}
26
 
27
  def query_data(in_pil_img: Image.Image):
28
  results = detector(in_pil_img)
29
+ print(f"检测结果:{results}")
30
  return results
31
 
32
 
 
57
  score = round(prediction['score'] * 100, 1)
58
  if score < threshold:
59
  continue # 过滤掉低置信度的预测结果
60
+ if label not in label_list:
61
+ continue # 过滤掉不在允许显示的label列表中的预测结果
62
 
63
  if label not in label_color_dict: # 为每个类别随机分配颜色, 后续维持一致
64
  color = choice(COLORS)
 
103
  timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
104
  output_video_filename = f"output_{timestamp}.mp4"
105
  output_video_path = os.path.join(output_dir, output_video_filename)
106
+ print(f"输出视频信息:{output_video_path}, {width}x{height}, {fps}fps")
107
  out = cv2.VideoWriter(output_video_path, fourcc, fps, (width, height))
108
 
109
  while True:
 
113
 
114
  rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
115
  pil_image = Image.fromarray(rgb_frame)
116
+ print(f"Input frame of shape {rgb_frame.shape} and type {rgb_frame.dtype}") # 调试信息
117
  annotated_frame = get_annotated_image(pil_image)
118
  bgr_frame = cv2.cvtColor(annotated_frame, cv2.COLOR_RGB2BGR)
119
+ print(f"Annotated frame of shape {bgr_frame.shape} and type {bgr_frame.dtype}") # 调试信息
120
  # 确保帧的尺寸与视频输出一致
121
  if bgr_frame.shape[:2] != (height, width):
122
  bgr_frame = cv2.resize(bgr_frame, (width, height))
123
 
124
+ print(f"Writing frame of shape {bgr_frame.shape} and type {bgr_frame.dtype}") # 调试信息
125
  out.write(bgr_frame)
126
 
127
  cap.release()
 
130
  # 返回输出视频路径给 Gradio
131
  return output_video_path
132
 
133
+ def change_threshold(value):
134
+ global threshold
135
+ threshold = value
136
+ return f"当前置信度阈值为{threshold}%"
137
+
138
+ def update_labels(selected_labels):
139
+ # 更新 label_list 以匹配用户的选择
140
+ global label_list
141
+ label_list = selected_labels
142
+ return selected_labels
143
+
144
  with gr.Blocks(css=".gradio-container {background:lightyellow;}", title="基于AI的安全风险识别及防控应用") as demo:
145
  gr.HTML("<div style='font-family:'Times New Roman', 'Serif'; font-size:16pt; font-weight:bold; text-align:center; color:royalblue;'>基于AI的安全风险识别及防控应用</div>")
146
 
147
+ # 设置置信度阈值
148
+ threshold_slider = gr.Slider(minimum=0, maximum=100, value=threshold, step=1, label="置信度阈值")
149
+ textbox = gr.Textbox(value=f"当前置信度阈值为{threshold}%", label="置信度显示")
150
+ # 绑定滑块变化事件到change_threshold函数,同时设置输出为textbox
151
+ threshold_slider.change(fn=change_threshold, inputs=[threshold_slider], outputs=[textbox])
152
+
153
+ # 设置允许显示的label列表
154
+ label_checkboxes = gr.CheckboxGroup(choices=label_list, value=label_list, label="检测目标")
155
+ # 允许修改label_list
156
+ label_checkboxes.change(fn=update_labels, inputs=[label_checkboxes], outputs=[label_checkboxes])
157
+
158
  with gr.Row():
159
  input_video = gr.Video(label="输入视频")
160
  detect_button = gr.Button("开始检测", variant="primary")