# qsitj's picture
# Update app.py
# 81d3e7f verified
# raw
# history blame
# 5 kB
import os
from gradio_webrtc import WebRTC
import requests
from PIL import Image
import matplotlib.pyplot as plt
from random import choice
import io
import gradio as gr
import cv2
import numpy as np
from io import BytesIO
import random
import tempfile
from pathlib import Path
import torch
from transformers import pipeline
from PIL import Image
import matplotlib.patches as patches
# Two DETR object-detection pipelines, built once at import time and shared by
# every request. Order matters only for which name is bound to which model:
# ResNet-50 first, then ResNet-101.
detector50, detector101 = (
    pipeline(model=checkpoint)
    for checkpoint in ("facebook/detr-resnet-50", "facebook/detr-resnet-101")
)
# Pastel palette; get_figure draws each detection box in a randomly chosen entry.
COLORS = [
    "#ff7f7f",
    "#ff7fbf",
    "#ff7fff",
    "#bf7fff",
    "#7f7fff",
    "#7fbfff",
    "#7fffff",
    "#7fffbf",
    "#7fff7f",
    "#bfff7f",
    "#ffff7f",
    "#ffbf7f",
]

# Matplotlib text properties used for the "label: score%" annotation on each box.
fdic = dict(
    # family="Impact",
    style="italic",
    size=15,
    color="yellow",
    weight="bold",
)
#######################################
def query_data(model, in_pil_img: Image.Image):
    """Run the DETR detection pipeline chosen by *model* on one image.

    Args:
        model: "detr-resnet-101" selects ``detector101``; any other value
            falls back to ``detector50``.
        in_pil_img: PIL image to run detection on.

    Returns:
        The selected pipeline's raw output, unchanged (it is also printed).
        get_figure indexes each item as a dict with 'label', 'score' and
        'box' (xmin/ymin/xmax/ymax).
    """
    detector = detector101 if model == "detr-resnet-101" else detector50
    results = detector(in_pil_img)
    print(f"检测结果:{results}")
    return results
def get_figure(in_pil_img, model="detr-resnet-50"):
    """Draw *in_pil_img* with its DETR detections overlaid and return the figure.

    Args:
        in_pil_img: PIL image to annotate.
        model: detector name forwarded to query_data. "detr-resnet-101"
            selects the ResNet-101 pipeline; anything else (including the
            default) selects ResNet-50.

    Returns:
        The newly created matplotlib Figure. The caller owns it and should
        ``plt.close`` it when done. pyplot keeps every open figure alive.
    """
    fig = plt.figure(figsize=(16, 10))
    ax = fig.gca()
    ax.imshow(in_pil_img)
    print(f"图像尺寸:{in_pil_img.size}")
    # query_data takes (model, image). Calling it with the image alone raised TypeError.
    in_results = query_data(model, in_pil_img)
    for prediction in in_results:
        selected_color = choice(COLORS)
        box = prediction['box']
        x, y = box['xmin'], box['ymin']
        w, h = box['xmax'] - box['xmin'], box['ymax'] - box['ymin']
        ax.add_patch(plt.Rectangle((x, y), w, h, fill=False, color=selected_color, linewidth=3))
        # Label is anchored at the box's top-left corner, score shown as a percentage.
        ax.text(x, y, f"{prediction['label']}: {round(prediction['score']*100, 1)}%", fontdict=fdic)
        print(f"x: {x}, y: {y}, w: {w}, h: {h}, label: {prediction['label']}, score: {prediction['score']}")
    ax.axis("off")
    return fig
def process_single_frame(frame):
    """Detect objects in one OpenCV frame and return the annotated render.

    Args:
        frame: image array in OpenCV's BGR channel order, e.g. as returned by
            ``cv2.VideoCapture.read``.

    Returns:
        An RGB numpy array of the rendered matplotlib figure (saved as PNG
        with a tight bounding box). Its size follows the figure, not the
        input frame.
    """
    print("开始处理单帧")
    # OpenCV delivers BGR; PIL/matplotlib expect RGB.
    rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    pil_image = Image.fromarray(rgb_frame)
    figure = get_figure(pil_image)
    try:
        with BytesIO() as buf:
            figure.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
            buf.seek(0)
            # convert() forces PIL to decode now, before the buffer is closed.
            annotated_image = Image.open(buf).convert('RGB')
    finally:
        # pyplot keeps every figure alive until it is closed. Without this,
        # every processed video frame leaked one 16x10 figure.
        plt.close(figure)
    return np.array(annotated_image)
def infer_video(input_video_path):
    """Yield annotated frames for every frame of a video file.

    This is a generator. One frame is decoded, run through
    process_single_frame and yielded in BGR order per iteration, until the
    video can no longer be read.

    Args:
        input_video_path: path of the uploaded video, or None when nothing
            was uploaded.

    Yields:
        Annotated frames as BGR numpy arrays.

    Raises:
        ValueError: no video was provided, or OpenCV cannot open the file.
            Because this is a generator, it surfaces on the first iteration.
    """
    print(f"开始处理视频 {input_video_path}")
    if input_video_path is None:
        raise ValueError("未提供输入视频")
    cap = cv2.VideoCapture(input_video_path)
    try:
        if not cap.isOpened():
            raise ValueError("无法打开输入视频文件")
        # CAP_PROP_FRAME_COUNT is only a metadata estimate. It can be 0,
        # negative or wrong for some containers, so it is used for progress
        # logging only. Reading stops when cap.read() fails.
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        frame_count = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                if frame_count < total_frames:
                    print(f"提前结束:在第 {frame_count} 帧时无法读取帧")
                break
            frame_count += 1
            # process_single_frame returns RGB; convert back to OpenCV's BGR for the stream.
            processed_frame = process_single_frame(frame)
            yield cv2.cvtColor(processed_frame, cv2.COLOR_RGB2BGR)
            if frame_count % 30 == 0:
                print(f"已处理 {frame_count}/{total_frames} 帧")
    finally:
        # Also reached when the file fails to open, so the capture is never leaked.
        cap.release()
# Gradio UI: upload a video, press "Detect", and the annotated frames produced
# by infer_video are streamed back through a receive-only WebRTC component.
with gr.Blocks(title="基于AI的安全风险识别及防控应用",
               css=".gradio-container {background:lightyellow;}") as demo:
    # The style attribute is double-quoted because the font-family value
    # contains single quotes. The old single-quoted attribute ended right
    # after "font-family:" and silently dropped the rest of the styling.
    # The generic family `serif` must be unquoted.
    gr.HTML(
        "<div style=\"font-family:'Times New Roman', serif; font-size:16pt; "
        "font-weight:bold; text-align:center; color:royalblue;\">"
        "基于AI的安全风险识别及防控应用</div>"
    )
    with gr.Row():
        input_video = gr.Video(label="输入视频")
        output_video = WebRTC(label="WebRTC Stream",
                              rtc_configuration=None,
                              mode="receive",
                              modality="video")
    detect = gr.Button("Detect", variant="primary")
    # Event listeners must be registered inside the Blocks context.
    output_video.stream(
        fn=infer_video,
        inputs=[input_video],
        outputs=[output_video],
        trigger=detect.click,
    )
demo.launch()