File size: 5,454 Bytes
f93a294
 
 
fd08bcf
3123f9f
 
 
 
 
463eb87
3123f9f
 
e1fe61e
3123f9f
463eb87
 
3123f9f
 
463eb87
 
3123f9f
aae3e7a
3123f9f
 
463eb87
 
85114f8
 
3123f9f
 
fd08bcf
f93a294
 
fd08bcf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3123f9f
85114f8
3123f9f
f93a294
463eb87
3123f9f
 
 
85114f8
fd08bcf
85114f8
fd08bcf
85114f8
 
 
 
463eb87
fd08bcf
 
 
00c0721
fd08bcf
 
 
 
 
85114f8
 
fd08bcf
85114f8
fd08bcf
 
463eb87
85114f8
 
f93a294
fd08bcf
3123f9f
 
 
 
f93a294
3123f9f
 
 
 
 
 
 
f93a294
3123f9f
 
 
 
fd08bcf
3123f9f
f93a294
3123f9f
 
 
 
 
 
 
fd08bcf
3123f9f
 
fd08bcf
3123f9f
 
 
f93a294
fd08bcf
3123f9f
f93a294
3123f9f
 
f93a294
3123f9f
 
f93a294
3123f9f
eec4598
f93a294
463eb87
f93a294
3123f9f
 
 
 
 
463eb87
3123f9f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import gradio as gr
import cv2
import numpy as np
from PIL import Image,ImageDraw, ImageFont
from transformers import pipeline
import torch
from random import choice
import os
from datetime import datetime

# Initialise the DETR object detector on the GPU when one is available.
# The device is passed to `pipeline` itself rather than moving `detector.model`
# afterwards: the pipeline records its device at construction time and uses it
# to place input tensors, so moving only the model later leaves inputs on the
# CPU and the weights on the GPU (a device-mismatch error at inference time).
detector = pipeline(
    model="facebook/detr-resnet-101",
    use_fast=True,
    device=0 if torch.cuda.is_available() else -1,  # 0 = first CUDA GPU, -1 = CPU
)

# Palette from which each detected label is assigned its box/text color.
COLORS = ["#ff7f7f", "#ff7fbf", "#ff7fff", "#bf7fff",
          "#7f7fff", "#7fbfff", "#7fffff", "#7fffbf",
          "#7fff7f", "#bfff7f", "#ffff7f", "#ffbf7f"]

# Font description dict.
# NOTE(review): not referenced by any code visible in this file.
fdic = {
    "style": "italic",
    "size": 24,
    "color": "yellow",
    "weight": "bold"
}

# label -> color picked from COLORS. Filled lazily by get_annotated_image so a
# label keeps the same color for the lifetime of the process.
label_color_dict = {}

def query_data(in_pil_img: Image.Image):
    """Run the module-level ``detector`` on one image and log its raw output.

    Args:
        in_pil_img: The PIL image to run detection on.

    Returns:
        The detector's result unchanged. get_annotated_image in this file
        iterates it as a sequence of dicts with ``box``/``label``/``score``.
    """
    detections = detector(in_pil_img)
    # Echo the raw detections to stdout for debugging.
    print(f"检测结果:{detections}")
    return detections


def get_font_size(box_width, min_size=24, max_size=48):
    """Pick a label font size proportional to a bounding box's width.

    The size is one tenth of the box width, clamped to ``[min_size, max_size]``.
    If the two bounds conflict, ``min_size`` wins.

    The minimum used to be hard-coded to 24, which silently overrode
    ``min_size``. The default is now 24 instead, so calls with default arguments
    return exactly what they did before, and an explicit ``min_size`` is
    actually honoured.

    Args:
        box_width: Width of the bounding box in pixels.
        min_size: Smallest font size to return.
        max_size: Largest font size to return (unless below ``min_size``).

    Returns:
        The font size as an int.
    """
    font_size = int(box_width / 10)
    return max(min(font_size, max_size), min_size)

def get_text_position(box, text_bbox):
    """Return where to draw a label so it sits just above its bounding box.

    Args:
        box: Detection box dict; only ``'xmin'`` and ``'ymin'`` are read.
        text_bbox: ``(left, top, right, bottom)`` of the label text, as
            measured by ``ImageDraw.textbbox((0, 0), ...)``. The values are
            therefore offsets from the draw origin.

    Returns:
        An ``(x, y)`` anchor to pass to ``ImageDraw.text``.

        - When the text fits above the box, ``y`` is chosen so the bottom of
          the rendered glyphs lands exactly on the box's top edge.
        - Otherwise the text is drawn inside the box, starting at its top-left
          corner.
    """
    xmin, ymin = box['xmin'], box['ymin']
    _, text_top, _, text_bottom = text_bbox
    text_height = text_bottom - text_top

    if ymin - text_height >= 0:
        # Drawn at y, the glyphs occupy [y + text_top, y + text_bottom].
        # Offsetting by text_bottom (not just the height) keeps them from
        # overlapping the box's top border.
        return (xmin, ymin - text_bottom)
    # Not enough room above the box: place the label inside it.
    return (xmin, ymin)

# Fonts already loaded, keyed by size. The TrueType file is then opened once per
# size instead of once per detection per frame.
_FONT_CACHE = {}


def _load_font(size):
    """Return a label font of ``size``, falling back to Pillow's built-in font.

    ``arial.ttf`` is often unavailable on Linux hosts. In that case
    ``ImageFont.truetype`` raises OSError, and the built-in font is used
    instead of crashing the whole video job.
    """
    font = _FONT_CACHE.get(size)
    if font is None:
        try:
            font = ImageFont.truetype(font="arial.ttf", size=size)
        except OSError:
            try:
                # Pillow >= 10.1 can render the built-in font at a given size.
                font = ImageFont.load_default(size=size)
            except TypeError:
                # Older Pillow: fixed-size built-in bitmap font.
                font = ImageFont.load_default()
        _FONT_CACHE[size] = font
    return font


def get_annotated_image(in_pil_img, min_score=50):
    """Detect objects in an image and draw labelled boxes onto it in place.

    Args:
        in_pil_img: PIL image to annotate. It is modified in place.
        min_score: Confidence cut-off, as a percentage. Each score is
            multiplied by 100 and rounded to one decimal place before the
            comparison. Detections below the cut-off are skipped.

    Returns:
        The annotated image converted to RGB, as a NumPy array.
    """
    draw = ImageDraw.Draw(in_pil_img)
    in_results = query_data(in_pil_img)

    for prediction in in_results:
        score = round(prediction['score'] * 100, 1)
        if score < min_score:
            continue  # Drop low-confidence predictions.

        box = prediction['box']
        label = prediction['label']

        # Give each label a random color the first time it is seen, then reuse
        # it so a class keeps the same color across frames.
        if label in label_color_dict:
            color = label_color_dict[label]
        else:
            color = choice(COLORS)
            label_color_dict[label] = color

        # Scale the label font with the box width.
        font = _load_font(get_font_size(box['xmax'] - box['xmin']))
        text = f"{label}: {score}%"
        text_bbox = draw.textbbox((0, 0), text, font=font)

        draw.rectangle([box['xmin'], box['ymin'], box['xmax'], box['ymax']], outline=color, width=3)
        draw.text(get_text_position(box, text_bbox), text, fill=color, font=font)

    # The drawing above mutated in_pil_img itself; return it as an RGB array.
    return np.array(in_pil_img.convert('RGB'))


def process_video(input_video_path):
    """Run object detection on every frame of a video and write an annotated copy.

    Args:
        input_video_path: Path of the input video file, as supplied by the
            Gradio ``Video`` input component. It may be empty when the button
            is clicked before a video is uploaded.

    Returns:
        Path of the annotated MP4 file written under ``./output_videos``.

    Raises:
        ValueError: If no input was provided, if the input cannot be opened,
            or if the output writer cannot be created.
    """
    if not input_video_path:
        raise ValueError("未提供输入视频")

    cap = cv2.VideoCapture(input_video_path)
    if not cap.isOpened():
        cap.release()
        raise ValueError("无法打开输入视频文件")

    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = cap.get(cv2.CAP_PROP_FPS)
    # Some containers report 0 (or NaN) FPS. VideoWriter cannot use that,
    # so fall back to a common frame rate.
    if not fps > 0:
        fps = 25.0

    # NOTE(review): browsers generally cannot play MPEG-4 Part 2 ('mp4v') in a
    # <video> element. Confirm the Gradio version in use re-encodes it, or
    # switch to an H.264 fourcc if the OpenCV build supports one.
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    output_dir = './output_videos'
    os.makedirs(output_dir, exist_ok=True)

    # Microsecond resolution stops two requests made within the same second
    # from overwriting each other's output file.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    output_video_path = os.path.join(output_dir, f"output_{timestamp}.mp4")
    print(f"输出视频信息:{output_video_path}, {width}x{height}, {fps}fps")
    out = cv2.VideoWriter(output_video_path, fourcc, fps, (width, height))

    frame_count = 0
    try:
        if not out.isOpened():
            raise ValueError("无法创建输出视频文件")

        while True:
            ret, frame = cap.read()
            if not ret:
                break

            # OpenCV decodes to BGR; the detector and PIL expect RGB.
            pil_image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            bgr_frame = cv2.cvtColor(get_annotated_image(pil_image), cv2.COLOR_RGB2BGR)

            # VideoWriter drops frames whose size differs from the one it
            # was opened with, so make sure every frame matches.
            if bgr_frame.shape[:2] != (height, width):
                bgr_frame = cv2.resize(bgr_frame, (width, height))

            out.write(bgr_frame)
            frame_count += 1
    finally:
        # Always release both handles, even if detection fails mid-video.
        cap.release()
        out.release()

    print(f"已处理帧数:{frame_count}")
    # Return the output path so Gradio can display the result.
    return output_video_path

# Title banner HTML. The style attribute is double-quoted so that the single
# quotes around the 'Times New Roman' family name do not terminate it early.
# (Previously they did, and none of the title styles were applied.)
_TITLE_HTML = (
    "<div style=\"font-family:'Times New Roman', serif; font-size:16pt; "
    "font-weight:bold; text-align:center; color:royalblue;\">"
    "基于AI的安全风险识别及防控应用</div>"
)

with gr.Blocks(css=".gradio-container {background:lightyellow;}", title="基于AI的安全风险识别及防控应用") as demo:
    gr.HTML(_TITLE_HTML)

    with gr.Row():
        input_video = gr.Video(label="输入视频")
        detect_button = gr.Button("开始检测", variant="primary")
        output_video = gr.Video(label="输出视频")

    # Clicking the button runs process_video on the input video.
    # The file path it returns is shown in output_video.
    detect_button.click(process_video, inputs=input_video, outputs=output_video)

demo.launch()