# AI-based safety-risk detection demo.
# Runs DETR object detection frame-by-frame over an uploaded video and serves
# the annotated result through a Gradio web UI.
import gradio as gr
import cv2
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from transformers import pipeline
import torch
from random import choice
from io import BytesIO
import os
from datetime import datetime
# Initialize the object-detection pipeline and move the model to the GPU if one
# is available.  NOTE: instantiating the pipeline downloads/loads the DETR
# checkpoint, so this runs once at import time, not per request.
detector = pipeline(model="facebook/detr-resnet-101", use_fast=True)
if torch.cuda.is_available():
    detector.model.to('cuda')
# Pastel colour palette; one colour is picked at random for each bounding box.
COLORS = ["#ff7f7f", "#ff7fbf", "#ff7fff", "#bf7fff",
"#7f7fff", "#7fbfff", "#7fffff", "#7fffbf",
"#7fff7f", "#bfff7f", "#ffff7f", "#ffbf7f"]
# Matplotlib font settings for the "label: score%" text drawn on each box.
fdic = {
"style": "italic",
"size": 15,
"color": "yellow",
"weight": "bold"
}
def query_data(in_pil_img: Image.Image):
    """Run the DETR object detector on a single PIL image.

    Returns the pipeline's raw prediction list: one dict per detection with
    'box', 'label' and 'score' entries.
    """
    results = detector(in_pil_img)
    # Log the raw predictions so each frame's detections are visible in the console.
    print(f"检测结果:{results}")
    return results
def get_annotated_image(in_pil_img):
    """Detect objects in a PIL image and return an annotated RGB ndarray.

    Draws one coloured rectangle plus a "label: score%" caption per detection,
    renders the figure to an in-memory PNG, and returns it as a numpy array
    (H, W, 3, uint8).  The output size follows matplotlib's rendering, so it
    may differ from the input size — callers that need a fixed size must
    resize (process_video does).
    """
    # Bind the figure explicitly: plt.close() with no argument closes whatever
    # figure is "current", which is not guaranteed to be ours when Gradio
    # serves requests from worker threads.
    fig = plt.figure(figsize=(16, 10))
    try:
        plt.imshow(in_pil_img)
        ax = plt.gca()
        for prediction in query_data(in_pil_img):
            color = choice(COLORS)
            box = prediction['box']
            label = prediction['label']
            score = round(prediction['score'] * 100, 1)
            ax.add_patch(plt.Rectangle((box['xmin'], box['ymin']),
                                       box['xmax'] - box['xmin'],
                                       box['ymax'] - box['ymin'],
                                       fill=False, color=color, linewidth=3))
            ax.text(box['xmin'], box['ymin'], f"{label}: {score}%", fontdict=fdic)
        plt.axis("off")
        buf = BytesIO()
        fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
    finally:
        # Always release the figure, even if the detector raises mid-frame —
        # otherwise a video loop leaks one figure per failed frame.
        plt.close(fig)
    buf.seek(0)
    annotated_image = Image.open(buf).convert('RGB')
    return np.array(annotated_image)
def process_video(input_video_path):
    """Annotate every frame of a video with object detections.

    Reads ``input_video_path`` with OpenCV, runs detection on each frame via
    get_annotated_image, and writes an mp4 with the same dimensions and frame
    rate under ./output_videos/ with a timestamped filename.

    Returns:
        str: path of the written output video (for Gradio to display).

    Raises:
        ValueError: if the input video cannot be opened or the output video
            cannot be created.
    """
    cap = cv2.VideoCapture(input_video_path)
    if not cap.isOpened():
        raise ValueError("无法打开输入视频文件")
    out = None
    try:
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = cap.get(cv2.CAP_PROP_FPS)
        if fps is None or fps <= 0:
            # Broken/missing metadata yields fps == 0, which produces an
            # unplayable file — fall back to a sane default.
            fps = 25.0
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')  # mp4v codec
        output_dir = './output_videos'
        os.makedirs(output_dir, exist_ok=True)  # ensure the output directory exists
        # Timestamped name so concurrent/repeated runs never clobber each other.
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        output_video_filename = f"output_{timestamp}.mp4"
        output_video_path = os.path.join(output_dir, output_video_filename)
        out = cv2.VideoWriter(output_video_path, fourcc, fps, (width, height))
        if not out.isOpened():
            raise ValueError("无法创建输出视频文件")
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            # OpenCV frames are BGR; PIL and the detector expect RGB.
            rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            pil_image = Image.fromarray(rgb_frame)
            annotated_frame = get_annotated_image(pil_image)
            bgr_frame = cv2.cvtColor(annotated_frame, cv2.COLOR_RGB2BGR)
            # Matplotlib's output size can differ from the source frame; the
            # writer silently drops mismatched frames, so resize to match.
            if bgr_frame.shape[:2] != (height, width):
                bgr_frame = cv2.resize(bgr_frame, (width, height))
            print(f"Writing frame of shape {bgr_frame.shape} and type {bgr_frame.dtype}")  # debug
            out.write(bgr_frame)
    finally:
        # Release both handles even when a frame raises, so the capture and
        # the partially-written file are not leaked.
        cap.release()
        if out is not None:
            out.release()
    return output_video_path
# Gradio UI: one input video, a detect button, and the annotated output video.
with gr.Blocks(css=".gradio-container {background:lightyellow;}", title="基于AI的安全风险识别及防控应用") as demo:
    # NOTE: the original style attribute used single quotes around both the
    # attribute value and the font name, which terminated the attribute early
    # and discarded the whole style — fixed with double-quoted attribute.
    gr.HTML("<div style=\"font-family:'Times New Roman', serif; font-size:16pt; font-weight:bold; text-align:center; color:royalblue;\">基于AI的安全风险识别及防控应用</div>")
    with gr.Row():
        input_video = gr.Video(label="输入视频")
        detect_button = gr.Button("开始检测", variant="primary")
        output_video = gr.Video(label="输出视频")
    # Bind process_video to the button; the returned path feeds output_video.
    detect_button.click(process_video, inputs=input_video, outputs=output_video)
demo.launch()