qsitj committed on
Commit
3123f9f
·
verified ·
1 Parent(s): cc27b8c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +73 -136
app.py CHANGED
@@ -1,177 +1,114 @@
1
- import os
2
-
3
- from gradio_webrtc import WebRTC
4
- import requests
5
- from PIL import Image
6
-
7
- import matplotlib.pyplot as plt
8
-
9
- from random import choice
10
- import io
11
-
12
  import gradio as gr
13
-
14
  import cv2
15
  import numpy as np
16
-
17
- from io import BytesIO
18
- import random
19
- import tempfile
20
- from pathlib import Path
21
-
22
- import torch
23
- from transformers import pipeline
24
-
25
  from PIL import Image
 
 
 
 
 
 
 
26
 
27
- import matplotlib.patches as patches
28
-
29
-
30
- detector50 = pipeline(model="facebook/detr-resnet-50")
31
-
32
- detector101 = pipeline(model="facebook/detr-resnet-101")
33
-
34
  if torch.cuda.is_available():
35
- # print("##############------------use cuda!------------#################")
36
- detector50.model.to('cuda')
37
- detector101.model.to('cuda')
38
-
39
- model = "detr-resnet-101"
40
 
41
  COLORS = ["#ff7f7f", "#ff7fbf", "#ff7fff", "#bf7fff",
42
- "#7f7fff", "#7fbfff", "#7fffff", "#7fffbf",
43
- "#7fff7f", "#bfff7f", "#ffff7f", "#ffbf7f"]
44
 
45
  fdic = {
46
- # "family" : "Impact",
47
- "style" : "italic",
48
- "size" : 15,
49
- "color" : "yellow",
50
- "weight" : "bold"
51
  }
52
 
53
- #######################################
54
-
55
-
56
- def query_data(model, in_pil_img: Image.Image):
57
- results = None
58
- if model == "detr-resnet-101":
59
- results = detector101(in_pil_img)
60
- else:
61
- results = detector50(in_pil_img)
62
- # print(f"检测结果:{results}")
63
  return results
64
 
65
-
66
-
67
- def get_figure(in_pil_img):
68
  plt.figure(figsize=(16, 10))
69
  plt.imshow(in_pil_img)
70
-
71
  ax = plt.gca()
72
- # print(f"图像尺寸:{in_pil_img.size}")
73
- in_results = query_data(model, in_pil_img)
74
 
75
  for prediction in in_results:
76
- selected_color = choice(COLORS)
77
-
78
- x, y = prediction['box']['xmin'], prediction['box']['ymin'],
79
- w, h = prediction['box']['xmax'] - prediction['box']['xmin'], prediction['box']['ymax'] - prediction['box']['ymin']
80
 
81
- ax.add_patch(plt.Rectangle((x, y), w, h, fill=False, color=selected_color, linewidth=3))
82
- ax.text(x, y, f"{prediction['label']}: {round(prediction['score']*100, 1)}%", fontdict=fdic)
83
- # print(f"x: {x}, y: {y}, w: {w}, h: {h}, label: {prediction['label']}, score: {prediction['score']}")
 
 
84
 
85
  plt.axis("off")
86
-
87
- return plt.gcf()
88
-
89
-
90
- def process_single_frame(frame):
91
- # print(f"开始处理单帧")
92
- # 将 BGR 转换为 RGB
93
- rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
94
-
95
- # 创建 PIL 图像对象
96
- pil_image = Image.fromarray(rgb_frame)
97
-
98
- # 获取带有标注信息的图像
99
- figure = get_figure(pil_image)
100
-
101
  buf = BytesIO()
102
- figure.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
 
103
  buf.seek(0)
104
  annotated_image = Image.open(buf).convert('RGB')
105
-
106
  return np.array(annotated_image)
107
 
 
 
 
 
108
 
109
- def infer_video(input_video_path):
110
- # print(f"开始处理视频 {input_video_path}")
111
- with tempfile.TemporaryDirectory() as tmp_dir:
112
- # output_video_path = Path(tmp_dir) / "output.mp4"
113
- cap = cv2.VideoCapture(input_video_path)
114
-
115
- if not cap.isOpened():
116
- raise ValueError("无法打开输入视频文件")
117
-
118
- # width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
119
- # height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
120
- # fps = cap.get(cv2.CAP_PROP_FPS)
121
- # fourcc = int(cap.get(cv2.CAP_PROP_FOURCC)) # 使用原始视频的编码器
122
- total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) # 获取总帧数
123
 
124
- # out = cv2.VideoWriter(str(output_video_path), fourcc, fps, (width, height))
 
 
 
125
 
126
- frame_count = 0
127
- try:
128
- while frame_count < total_frames:
129
- ret, frame = cap.read()
130
- if not ret:
131
- print(f"提前结束:在第 {frame_count} 帧时无法读取帧")
132
- break
133
-
134
- frame_count += 1
135
 
136
- # 处理单帧并转换为 OpenCV 格式(BGR)
137
- processed_frame = process_single_frame(frame)
138
- bgr_frame = cv2.cvtColor(processed_frame, cv2.COLOR_RGB2BGR)
 
 
 
 
 
 
139
 
140
- yield bgr_frame
 
 
141
 
142
- # 可选:显示进度
143
- if frame_count % 30 == 0:
144
- print(f"已处理 {frame_count}/{total_frames} 帧")
145
 
146
- # if frame_count == 48:
147
- # print("测试结束")
148
- # return None
149
 
150
- finally:
151
- cap.release()
152
-
153
- return None
154
 
155
-
156
- # 更新 Gradio 接口以支持视频输入和输出
157
- with gr.Blocks(title="基于AI的安全风险识别及防控应用",
158
- css=".gradio-container {background:lightyellow;}"
159
- ) as demo:
160
  gr.HTML("<div style='font-family:'Times New Roman', 'Serif'; font-size:16pt; font-weight:bold; text-align:center; color:royalblue;'>基于AI的安全风险识别及防控应用</div>")
161
 
162
  with gr.Row():
163
  input_video = gr.Video(label="输入视频")
164
- output_video = WebRTC(label="WebRTC Stream",
165
- rtc_configuration=None,
166
- mode="receive",
167
- modality="video")
168
- detect = gr.Button("Detect", variant="primary")
169
- output_video.stream(
170
- fn=infer_video,
171
- inputs=[input_video],
172
- outputs=[output_video],
173
- trigger=detect.click
174
- )
175
-
176
- demo.launch()
177
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
 
2
  import cv2
3
  import numpy as np
 
 
 
 
 
 
 
 
 
4
  from PIL import Image
5
+ import matplotlib.pyplot as plt
6
+ from transformers import pipeline
7
+ import torch
8
+ from random import choice
9
+ from io import BytesIO
10
+ import os
11
+ from datetime import datetime
12
 
13
+ # 初始化对象检测器并移动到GPU(如果可用)
14
+ detector = pipeline(model="facebook/detr-resnet-101", use_fast=True)
 
 
 
 
 
15
  if torch.cuda.is_available():
16
+ detector.model.to('cuda')
 
 
 
 
17
 
18
  COLORS = ["#ff7f7f", "#ff7fbf", "#ff7fff", "#bf7fff",
19
+ "#7f7fff", "#7fbfff", "#7fffff", "#7fffbf",
20
+ "#7fff7f", "#bfff7f", "#ffff7f", "#ffbf7f"]
21
 
22
  fdic = {
23
+ "style": "italic",
24
+ "size": 15,
25
+ "color": "yellow",
26
+ "weight": "bold"
 
27
  }
28
 
29
+ def query_data(in_pil_img: Image.Image):
30
+ results = detector(in_pil_img)
31
+ print(f"检测结果:{results}")
 
 
 
 
 
 
 
32
  return results
33
 
34
+ def get_annotated_image(in_pil_img):
 
 
35
  plt.figure(figsize=(16, 10))
36
  plt.imshow(in_pil_img)
 
37
  ax = plt.gca()
38
+ in_results = query_data(in_pil_img)
 
39
 
40
  for prediction in in_results:
41
+ color = choice(COLORS)
42
+ box = prediction['box']
43
+ label = prediction['label']
44
+ score = round(prediction['score'] * 100, 1)
45
 
46
+ ax.add_patch(plt.Rectangle((box['xmin'], box['ymin']),
47
+ box['xmax'] - box['xmin'],
48
+ box['ymax'] - box['ymin'],
49
+ fill=False, color=color, linewidth=3))
50
+ ax.text(box['xmin'], box['ymin'], f"{label}: {score}%", fontdict=fdic)
51
 
52
  plt.axis("off")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
  buf = BytesIO()
54
+ plt.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
55
+ plt.close() # 关闭图形以释放内存
56
  buf.seek(0)
57
  annotated_image = Image.open(buf).convert('RGB')
 
58
  return np.array(annotated_image)
59
 
60
+ def process_video(input_video_path):
61
+ cap = cv2.VideoCapture(input_video_path)
62
+ if not cap.isOpened():
63
+ raise ValueError("无法打开输入视频文件")
64
 
65
+ width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
66
+ height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
67
+ fps = cap.get(cv2.CAP_PROP_FPS)
68
+
69
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v') # 使用 'mp4v' 编码器
70
+ output_dir = './output_videos' # 指定输出目录
71
+ os.makedirs(output_dir, exist_ok=True) # 确保输出目录存在
 
 
 
 
 
 
 
72
 
73
+ # 生成唯一文件名
74
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
75
+ output_video_filename = f"output_{timestamp}.mp4"
76
+ output_video_path = os.path.join(output_dir, output_video_filename)
77
 
78
+ out = cv2.VideoWriter(output_video_path, fourcc, fps, (width, height))
 
 
 
 
 
 
 
 
79
 
80
+ while True:
81
+ ret, frame = cap.read()
82
+ if not ret:
83
+ break
84
+
85
+ rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
86
+ pil_image = Image.fromarray(rgb_frame)
87
+ annotated_frame = get_annotated_image(pil_image)
88
+ bgr_frame = cv2.cvtColor(annotated_frame, cv2.COLOR_RGB2BGR)
89
 
90
+ # 确保帧的尺寸与视频输出一致
91
+ if bgr_frame.shape[:2] != (height, width):
92
+ bgr_frame = cv2.resize(bgr_frame, (width, height))
93
 
94
+ print(f"Writing frame of shape {bgr_frame.shape} and type {bgr_frame.dtype}") # 调试信息
95
+ out.write(bgr_frame)
 
96
 
97
+ cap.release()
98
+ out.release()
 
99
 
100
+ # 返回输出视频路径给 Gradio
101
+ return output_video_path
 
 
102
 
103
+ with gr.Blocks(css=".gradio-container {background:lightyellow;}", title="基于AI的安全风险识别及防控应用") as demo:
 
 
 
 
104
  gr.HTML("<div style='font-family:'Times New Roman', 'Serif'; font-size:16pt; font-weight:bold; text-align:center; color:royalblue;'>基于AI的安全风险识别及防控应用</div>")
105
 
106
  with gr.Row():
107
  input_video = gr.Video(label="输入视频")
108
+ detect_button = gr.Button("开始检测", variant="primary")
109
+ output_video = gr.Video(label="输出视频")
110
+
111
+ # 将process_video函数绑定到按钮点击事件,并将处理后的视频路径传递给output_video
112
+ detect_button.click(process_video, inputs=input_video, outputs=output_video)
 
 
 
 
 
 
 
 
113
 
114
+ demo.launch()