增加对置信度和类别的选择
Browse files
app.py
CHANGED
@@ -17,20 +17,16 @@ COLORS = ["#ff7f7f", "#ff7fbf", "#ff7fff", "#bf7fff",
|
|
17 |
"#7f7fff", "#7fbfff", "#7fffff", "#7fffbf",
|
18 |
"#7fff7f", "#bfff7f", "#ffff7f", "#ffbf7f"]
|
19 |
|
20 |
-
fdic = {
|
21 |
-
"style": "italic",
|
22 |
-
"size": 24,
|
23 |
-
"color": "yellow",
|
24 |
-
"weight": "bold"
|
25 |
-
}
|
26 |
|
27 |
-
|
|
|
|
|
28 |
|
29 |
label_color_dict = {}
|
30 |
|
31 |
def query_data(in_pil_img: Image.Image):
    """Run the object detector on a single PIL image.

    Args:
        in_pil_img: The image handed to ``detector``.

    Returns:
        Whatever ``detector`` returns for the image, passed back unmodified.
    """
    return detector(in_pil_img)
|
35 |
|
36 |
|
@@ -61,6 +57,8 @@ def get_annotated_image(in_pil_img):
|
|
61 |
score = round(prediction['score'] * 100, 1)
|
62 |
if score < threshold:
|
63 |
continue # 过滤掉低置信度的预测结果
|
|
|
|
|
64 |
|
65 |
if label not in label_color_dict: # 为每个类别随机分配颜色, 后续维持一致
|
66 |
color = choice(COLORS)
|
@@ -105,7 +103,7 @@ def process_video(input_video_path):
|
|
105 |
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
106 |
output_video_filename = f"output_{timestamp}.mp4"
|
107 |
output_video_path = os.path.join(output_dir, output_video_filename)
|
108 |
-
|
109 |
out = cv2.VideoWriter(output_video_path, fourcc, fps, (width, height))
|
110 |
|
111 |
while True:
|
@@ -115,15 +113,15 @@ def process_video(input_video_path):
|
|
115 |
|
116 |
rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
117 |
pil_image = Image.fromarray(rgb_frame)
|
118 |
-
|
119 |
annotated_frame = get_annotated_image(pil_image)
|
120 |
bgr_frame = cv2.cvtColor(annotated_frame, cv2.COLOR_RGB2BGR)
|
121 |
-
|
122 |
# 确保帧的尺寸与视频输出一致
|
123 |
if bgr_frame.shape[:2] != (height, width):
|
124 |
bgr_frame = cv2.resize(bgr_frame, (width, height))
|
125 |
|
126 |
-
|
127 |
out.write(bgr_frame)
|
128 |
|
129 |
cap.release()
|
@@ -132,9 +130,31 @@ def process_video(input_video_path):
|
|
132 |
# 返回输出视频路径给 Gradio
|
133 |
return output_video_path
|
134 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
135 |
with gr.Blocks(css=".gradio-container {background:lightyellow;}", title="基于AI的安全风险识别及防控应用") as demo:
|
136 |
# The style attribute is delimited with double quotes so that the single-quoted
# font name cannot terminate it early; otherwise every declaration after
# "font-family:" is dropped by the HTML parser. The generic family is written
# as the unquoted keyword `serif` so it is not treated as a literal font name.
gr.HTML("<div style=\"font-family:'Times New Roman', serif; font-size:16pt; font-weight:bold; text-align:center; color:royalblue;\">基于AI的安全风险识别及防控应用</div>")
|
137 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
138 |
with gr.Row():
|
139 |
input_video = gr.Video(label="输入视频")
|
140 |
detect_button = gr.Button("开始检测", variant="primary")
|
|
|
17 |
"#7f7fff", "#7fbfff", "#7fffff", "#7fffbf",
|
18 |
"#7fff7f", "#bfff7f", "#ffff7f", "#ffbf7f"]
|
19 |
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
|
21 |
+
|
22 |
+
threshold = 90 # Confidence threshold in percent; predictions whose score (scaled to 0-100) is below it are skipped
|
23 |
+
label_list = ["person", "car", "truck"] # Labels allowed to be displayed; predictions with any other label are skipped
|
24 |
|
25 |
label_color_dict = {}
|
26 |
|
27 |
def query_data(in_pil_img: Image.Image):
    """Run the object detector on a single PIL image and log the raw output.

    Args:
        in_pil_img: The image handed to ``detector``.

    Returns:
        Whatever ``detector`` returns for the image, passed back unmodified.
    """
    detections = detector(in_pil_img)
    # Debug trace of the raw detector output for this image.
    print(f"检测结果:{detections}")
    return detections
|
31 |
|
32 |
|
|
|
57 |
score = round(prediction['score'] * 100, 1)
|
58 |
if score < threshold:
|
59 |
continue # 过滤掉低置信度的预测结果
|
60 |
+
if label not in label_list:
|
61 |
+
continue # 过滤掉不在允许显示的label列表中的预测结果
|
62 |
|
63 |
if label not in label_color_dict: # 为每个类别随机分配颜色, 后续维持一致
|
64 |
color = choice(COLORS)
|
|
|
103 |
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
104 |
output_video_filename = f"output_{timestamp}.mp4"
|
105 |
output_video_path = os.path.join(output_dir, output_video_filename)
|
106 |
+
print(f"输出视频信息:{output_video_path}, {width}x{height}, {fps}fps")
|
107 |
out = cv2.VideoWriter(output_video_path, fourcc, fps, (width, height))
|
108 |
|
109 |
while True:
|
|
|
113 |
|
114 |
rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
115 |
pil_image = Image.fromarray(rgb_frame)
|
116 |
+
print(f"Input frame of shape {rgb_frame.shape} and type {rgb_frame.dtype}") # 调试信息
|
117 |
annotated_frame = get_annotated_image(pil_image)
|
118 |
bgr_frame = cv2.cvtColor(annotated_frame, cv2.COLOR_RGB2BGR)
|
119 |
+
print(f"Annotated frame of shape {bgr_frame.shape} and type {bgr_frame.dtype}") # 调试信息
|
120 |
# 确保帧的尺寸与视频输出一致
|
121 |
if bgr_frame.shape[:2] != (height, width):
|
122 |
bgr_frame = cv2.resize(bgr_frame, (width, height))
|
123 |
|
124 |
+
print(f"Writing frame of shape {bgr_frame.shape} and type {bgr_frame.dtype}") # 调试信息
|
125 |
out.write(bgr_frame)
|
126 |
|
127 |
cap.release()
|
|
|
130 |
# 返回输出视频路径给 Gradio
|
131 |
return output_video_path
|
132 |
|
133 |
+
def change_threshold(value):
    """Rebind the module-level confidence threshold and describe the new value.

    Args:
        value: The new threshold, stored as-is in the global ``threshold``.

    Returns:
        A status string (in Chinese) that embeds the new threshold followed by
        a percent sign.
    """
    global threshold
    threshold = value
    # `value` and the freshly assigned global are the same object here.
    return f"当前置信度阈值为{value}%"
|
137 |
+
|
138 |
+
def update_labels(selected_labels):
    """Replace the module-level list of displayable labels with the selection.

    Args:
        selected_labels: The labels chosen by the user; stored as-is in the
            global ``label_list``.

    Returns:
        The same ``selected_labels`` object that was passed in.
    """
    global label_list
    # Keep label_list in sync with what the user selected.
    label_list = selected_labels
    return label_list
|
143 |
+
|
144 |
with gr.Blocks(css=".gradio-container {background:lightyellow;}", title="基于AI的安全风险识别及防控应用") as demo:
|
145 |
# The style attribute is delimited with double quotes so that the single-quoted
# font name cannot terminate it early; otherwise every declaration after
# "font-family:" is dropped by the HTML parser. The generic family is written
# as the unquoted keyword `serif` so it is not treated as a literal font name.
gr.HTML("<div style=\"font-family:'Times New Roman', serif; font-size:16pt; font-weight:bold; text-align:center; color:royalblue;\">基于AI的安全风险识别及防控应用</div>")
|
146 |
|
147 |
+
# 设置置信度阈值
|
148 |
+
threshold_slider = gr.Slider(minimum=0, maximum=100, value=threshold, step=1, label="置信度阈值")
|
149 |
+
textbox = gr.Textbox(value=f"当前置信度阈值为{threshold}%", label="置信度显示")
|
150 |
+
# 绑定滑块变化事件到change_threshold函数,同时设置输出为textbox
|
151 |
+
threshold_slider.change(fn=change_threshold, inputs=[threshold_slider], outputs=[textbox])
|
152 |
+
|
153 |
+
# 设置允许显示的label列表
|
154 |
+
label_checkboxes = gr.CheckboxGroup(choices=label_list, value=label_list, label="检测目标")
|
155 |
+
# 允许修改label_list
|
156 |
+
label_checkboxes.change(fn=update_labels, inputs=[label_checkboxes], outputs=[label_checkboxes])
|
157 |
+
|
158 |
with gr.Row():
|
159 |
input_video = gr.Video(label="输入视频")
|
160 |
detect_button = gr.Button("开始检测", variant="primary")
|