import os
from datetime import datetime
from random import choice

import cv2
import gradio as gr
import numpy as np
import torch
from PIL import Image, ImageDraw, ImageFont
from transformers import pipeline

# Initialize the object detector and run it on the GPU if one is available.
# Passing `device` to pipeline() places both the model and its inputs on that device;
# calling detector.model.to('cuda') alone would leave the inputs on the CPU.
device = 0 if torch.cuda.is_available() else -1
detector = pipeline("object-detection", model="facebook/detr-resnet-101",
                    use_fast=True, device=device)

COLORS = ["#ff7f7f", "#ff7fbf", "#ff7fff", "#bf7fff",
          "#7f7fff", "#7fbfff", "#7fffff", "#7fffbf",
          "#7fff7f", "#bfff7f", "#ffff7f", "#ffbf7f"]

# Label font. ImageDraw.text() has no `fontdict` argument (that is a matplotlib
# concept), so the size is set on the font object instead.
try:
    FONT = ImageFont.load_default(size=24)  # Pillow >= 10.1
except TypeError:
    FONT = ImageFont.load_default()  # older Pillow: fixed-size bitmap font

# label -> color, kept consistent across all frames
label_color_dict = {}


def query_data(in_pil_img: Image.Image):
    results = detector(in_pil_img)
    # print(f"Detection results: {results}")
    return results


def get_annotated_image(in_pil_img):
    draw = ImageDraw.Draw(in_pil_img)
    in_results = query_data(in_pil_img)
    for prediction in in_results:
        box = prediction['box']
        label = prediction['label']
        score = round(prediction['score'] * 100, 1)
        if score < 50:
            continue  # skip low-confidence predictions
        # Give each label a random color the first time it appears, then reuse it
        if label not in label_color_dict:
            label_color_dict[label] = choice(COLORS)
        color = label_color_dict[label]
        # Draw the bounding box
        draw.rectangle([box['xmin'], box['ymin'], box['xmax'], box['ymax']],
                       outline=color, width=3)
        # Draw the label and confidence
        draw.text((box['xmin'], box['ymin']), f"{label}: {score}%", fill=color, font=FONT)
    # The image was modified in place; return it as an RGB NumPy array
    return np.array(in_pil_img.convert('RGB'))


def process_video(input_video_path):
    cap = cv2.VideoCapture(input_video_path)
    if not cap.isOpened():
        raise ValueError("Could not open the input video file")

    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = cap.get(cv2.CAP_PROP_FPS) or 25.0  # fall back to 25 FPS if the file reports 0
    # 'mp4v' (MPEG-4 Part 2) codec; some browsers cannot play it, and 'avc1' (H.264)
    # is more widely supported if your OpenCV build includes it
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    output_dir = './output_videos'  # output directory
    os.makedirs(output_dir, exist_ok=True)  # make sure the output directory exists

    # Build a unique, timestamped output filename
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_video_filename = f"output_{timestamp}.mp4"
    output_video_path = os.path.join(output_dir, output_video_filename)
    # print(f"Output video: {output_video_path}, {width}x{height}, {fps}fps")
    out = cv2.VideoWriter(output_video_path, fourcc, fps, (width, height))

    while True:
        ret, frame = cap.read()
        if not ret:
            break
        # OpenCV decodes frames as BGR; PIL and the detector expect RGB
        rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        pil_image = Image.fromarray(rgb_frame)
        # print(f"Input frame of shape {rgb_frame.shape} and type {rgb_frame.dtype}")  # debug
        annotated_frame = get_annotated_image(pil_image)
        bgr_frame = cv2.cvtColor(annotated_frame, cv2.COLOR_RGB2BGR)
        # print(f"Annotated frame of shape {bgr_frame.shape} and type {bgr_frame.dtype}")  # debug
        # Make sure the frame size matches the writer's frame size
        if bgr_frame.shape[:2] != (height, width):
            bgr_frame = cv2.resize(bgr_frame, (width, height))
        # print(f"Writing frame of shape {bgr_frame.shape} and type {bgr_frame.dtype}")  # debug
        out.write(bgr_frame)

    cap.release()
    out.release()
    # Return the output video path to Gradio
    return output_video_path


with gr.Blocks(css=".gradio-container {background:lightyellow;}",
               title="AI-Based Safety Risk Identification and Prevention Application") as demo:
    gr.HTML("<h1>AI-Based Safety Risk Identification and Prevention Application</h1>")
    with gr.Row():
        input_video = gr.Video(label="Input video")
        detect_button = gr.Button("Start detection", variant="primary")
        output_video = gr.Video(label="Output video")
    # Bind process_video to the button click and show the returned video path in output_video
    detect_button.click(process_video, inputs=input_video, outputs=output_video)

demo.launch()
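The confidence filter and per-label color cache in `get_annotated_image` can be checked without loading DETR or decoding a video. The sketch below is a hypothetical helper, `filter_and_color`, that is not part of the original script. It assumes the object-detection pipeline's output format: a list of dicts with `score` (0 to 1), `label`, and `box` (`xmin`, `ymin`, `xmax`, `ymax`). The detections fed to it are made-up values.

```python
from random import choice

# Same palette as the script above
COLORS = ["#ff7f7f", "#ff7fbf", "#ff7fff", "#bf7fff",
          "#7f7fff", "#7fbfff", "#7fffff", "#7fffbf",
          "#7fff7f", "#bfff7f", "#ffff7f", "#ffbf7f"]


def filter_and_color(predictions, label_colors, threshold=50.0):
    """Hypothetical helper: keep predictions scoring >= threshold (in percent)
    and attach a color that stays fixed per label across calls."""
    kept = []
    for p in predictions:
        score = round(p["score"] * 100, 1)
        if score < threshold:
            continue  # low-confidence prediction, skipped as in the script
        if p["label"] not in label_colors:
            label_colors[p["label"]] = choice(COLORS)  # first sighting: random color
        kept.append((p["label"], score, label_colors[p["label"]], p["box"]))
    return kept


# Made-up detections in the pipeline's output format
preds = [
    {"score": 0.98, "label": "person", "box": {"xmin": 1, "ymin": 2, "xmax": 30, "ymax": 40}},
    {"score": 0.30, "label": "dog", "box": {"xmin": 5, "ymin": 5, "xmax": 20, "ymax": 20}},
    {"score": 0.75, "label": "person", "box": {"xmin": 50, "ymin": 10, "xmax": 90, "ymax": 60}},
]
colors = {}
kept = filter_and_color(preds, colors)
for label, score, color, box in kept:
    print(label, score, color, box)
```

Because `label_colors` persists between calls, passing the same dict for every frame keeps each class's color stable across the whole video, which is the role the module-level `label_color_dict` plays in the script. Two different labels can still end up with the same color, since `choice` samples with replacement.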