File size: 5,120 Bytes
f93a294 463eb87 e1fe61e 463eb87 f93a294 463eb87 81d3e7f 463eb87 f93a294 81d3e7f f93a294 81d3e7f f93a294 463eb87 f93a294 463eb87 81d3e7f f93a294 463eb87 766a1a8 463eb87 f93a294 81d3e7f f93a294 463eb87 f93a294 463eb87 f93a294 81d3e7f f93a294 13608bf f93a294 eec4598 f93a294 463eb87 f93a294 463eb87 f313675 463eb87 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 |
import os
from gradio_webrtc import WebRTC
import requests
from PIL import Image
import matplotlib.pyplot as plt
from random import choice
import io
import gradio as gr
import cv2
import numpy as np
from io import BytesIO
import random
import tempfile
from pathlib import Path
import torch
from transformers import pipeline
from PIL import Image
import matplotlib.patches as patches
# Pick the inference device once and hand it to the pipelines at construction
# time. Moving only `pipeline.model` to CUDA afterwards leaves the pipeline's
# own `device` attribute on CPU, so inputs and weights can end up on different
# devices at inference time.
# 0 selects the first CUDA device, -1 selects the CPU.
_DEVICE = 0 if torch.cuda.is_available() else -1

# Two DETR object-detection pipelines; `query_data` chooses between them.
detector50 = pipeline(model="facebook/detr-resnet-50", device=_DEVICE)
detector101 = pipeline(model="facebook/detr-resnet-101", device=_DEVICE)

if _DEVICE >= 0:
    print("use cuda")
# Palette from which a bounding-box outline colour is drawn at random.
COLORS = [
    "#ff7f7f",
    "#ff7fbf",
    "#ff7fff",
    "#bf7fff",
    "#7f7fff",
    "#7fbfff",
    "#7fffff",
    "#7fffbf",
    "#7fff7f",
    "#bfff7f",
    "#ffff7f",
    "#ffbf7f",
]

# Font properties passed as `fontdict` when drawing detection labels.
fdic = dict(
    # family="Impact",  (left disabled, as in the original)
    style="italic",
    size=15,
    color="yellow",
    weight="bold",
)
#######################################
def query_data(model, in_pil_img: Image.Image):
    """Run object detection on an image with the selected DETR pipeline.

    Args:
        model: Name of the detector to use. ``"detr-resnet-101"`` selects
            ``detector101``; any other value falls back to ``detector50``.
        in_pil_img: The PIL image to analyse.

    Returns:
        Whatever the selected pipeline returns for the image, unmodified.
    """
    detector = detector101 if model == "detr-resnet-101" else detector50
    results = detector(in_pil_img)
    print(f"检测结果:{results}")
    return results
def get_figure(in_pil_img, model="detr-resnet-50"):
    """Draw the image with one labelled box per detected object.

    Args:
        in_pil_img: The PIL image to analyse and display.
        model: Detector name forwarded to ``query_data``. The previous code
            called ``query_data`` without it, which raised ``TypeError``; the
            default keeps existing one-argument callers working.

    Returns:
        The Matplotlib figure holding the annotated image. The figure was
        created through pyplot, so the caller should close it
        (``plt.close(fig)``) once it is no longer needed.
    """
    fig = plt.figure(figsize=(16, 10))
    ax = fig.gca()
    ax.imshow(in_pil_img)
    print(f"图像尺寸:{in_pil_img.size}")

    in_results = query_data(model, in_pil_img)
    for prediction in in_results:
        selected_color = choice(COLORS)
        box = prediction['box']
        x, y = box['xmin'], box['ymin']
        w, h = box['xmax'] - x, box['ymax'] - y
        ax.add_patch(plt.Rectangle((x, y), w, h, fill=False, color=selected_color, linewidth=3))
        ax.text(x, y, f"{prediction['label']}: {round(prediction['score']*100, 1)}%", fontdict=fdic)
        print(f"x: {x}, y: {y}, w: {w}, h: {h}, label: {prediction['label']}, score: {prediction['score']}")

    ax.axis("off")
    return fig
def process_single_frame(frame):
    """Annotate one video frame with detection boxes and labels.

    Args:
        frame: Image array in BGR channel order; it is converted with
            ``cv2.COLOR_BGR2RGB`` before detection.

    Returns:
        The annotated image as an RGB NumPy array. It is read back from the
        rendered figure, so its size is that of the saved figure and is not
        necessarily the size of the input frame.
    """
    print("开始处理单帧")
    # Convert BGR to RGB.
    rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    # Build the PIL image that the detector and the plot work on.
    pil_image = Image.fromarray(rgb_frame)
    # Get the figure carrying the annotations.
    figure = get_figure(pil_image)
    buf = BytesIO()
    try:
        figure.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly;
        # without this, one figure would pile up per processed frame.
        plt.close(figure)
    buf.seek(0)
    annotated_image = Image.open(buf).convert('RGB')
    return np.array(annotated_image)
def infer_video(input_video_path):
    """Yield annotated frames of a video one at a time.

    This is a generator: nothing happens until the first frame is requested.

    Args:
        input_video_path: Path of the video file to process.

    Yields:
        Each annotated frame as a BGR image array.

    Raises:
        ValueError: If no path is given or the video cannot be opened.
    """
    print(f"开始处理视频 {input_video_path}")
    # Reject a missing path (e.g. nothing uploaded) with the same error as an
    # unreadable file instead of passing it on to OpenCV.
    if not input_video_path:
        raise ValueError("无法打开输入视频文件")

    cap = cv2.VideoCapture(input_video_path)
    try:
        if not cap.isOpened():
            raise ValueError("无法打开输入视频文件")

        # Used for progress reporting only: the reported frame count can be
        # missing or inaccurate, so the loop runs until a read fails.
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        frame_count = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                if frame_count < total_frames:
                    print(f"提前结束:在第 {frame_count} 帧时无法读取帧")
                break
            frame_count += 1

            # Annotate the frame, then convert back to OpenCV's BGR order.
            processed_frame = process_single_frame(frame)
            yield cv2.cvtColor(processed_frame, cv2.COLOR_RGB2BGR)

            # Report progress every 30 frames.
            if frame_count % 30 == 0:
                print(f"已处理 {frame_count}/{total_frames} 帧")
    finally:
        # Runs on failure and when the consumer stops the generator early.
        cap.release()
# Gradio interface: takes a video as input and streams the annotated result.
# NOTE(review): the nesting below is reconstructed (the two video components
# share one row, with the button underneath) -- confirm against the intended
# layout.
with gr.Blocks(title="基于AI的安全风险识别及防控应用",
               css=".gradio-container {background:lightyellow;}"
               ) as demo:
    # The style attribute is delimited with double quotes: it contains
    # single-quoted font names, which ended a single-quoted attribute early
    # and dropped the styling.
    gr.HTML(
        "<div style=\"font-family:'Times New Roman', 'Serif'; font-size:16pt; "
        "font-weight:bold; text-align:center; color:royalblue;\">"
        "基于AI的安全风险识别及防控应用</div>"
    )
    with gr.Row():
        input_video = gr.Video(label="输入视频")
        output_video = WebRTC(label="WebRTC Stream",
                              rtc_configuration=None,
                              mode="receive",
                              modality="video")
    detect = gr.Button("Detect", variant="primary")
    # Clicking the button starts `infer_video`; the frames it yields are sent
    # to the WebRTC component.
    output_video.stream(
        fn=infer_video,
        inputs=[input_video],
        outputs=[output_video],
        trigger=detect.click
    )
demo.launch()
|