File size: 4,975 Bytes
f93a294 463eb87 f93a294 463eb87 f93a294 463eb87 f93a294 463eb87 f93a294 463eb87 f93a294 463eb87 f93a294 463eb87 f93a294 463eb87 f93a294 463eb87 f93a294 463eb87 732dfb0 463eb87 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 |
import os
from gradio_webrtc import WebRTC
import requests
from PIL import Image
import matplotlib.pyplot as plt
from random import choice
import io
import gradio as gr
import cv2
import numpy as np
from io import BytesIO
import random
import tempfile
from pathlib import Path
import torch
from transformers import pipeline
from PIL import Image
import matplotlib.patches as patches
# DETR detection pipelines from Hugging Face `transformers`, built once when
# the module is imported (both models are loaded eagerly).
detector50 = pipeline(model="facebook/detr-resnet-50")
detector101 = pipeline(model="facebook/detr-resnet-101")

# Pastel palette for bounding boxes; a colour is drawn at random per box.
COLORS = [
    "#ff7f7f", "#ff7fbf", "#ff7fff", "#bf7fff",
    "#7f7fff", "#7fbfff", "#7fffff", "#7fffbf",
    "#7fff7f", "#bfff7f", "#ffff7f", "#ffbf7f",
]

# Matplotlib font properties for the "label: score%" caption on each box.
fdic = dict(
    style="italic",
    size=15,
    color="yellow",
    weight="bold",
)
def infer(model, in_pil_img):
    """Run DETR object detection on a PIL image using the requested backbone.

    NOTE(review): a second ``infer(in_pil_img)`` is defined later in this
    module, so once the module has loaded the name ``infer`` refers to that
    one-argument function, not to this one.

    Args:
        model: ``"detr-resnet-101"`` selects ``detector101``; any other value
            falls back to ``detector50``.
        in_pil_img: The image to run detection on.

    Returns:
        The selected pipeline's output for the image (presumably a list of
        dicts with ``"score"``, ``"label"`` and ``"box"`` keys, as read by
        ``get_figure`` -- confirm against the pipeline docs).
    """
    detector = detector101 if model == "detr-resnet-101" else detector50
    return detector(in_pil_img)
#######################################
def query_data(model, in_pil_img: Image.Image):
    """Run DETR object detection on ``in_pil_img`` with the selected model.

    The detector is chosen here directly rather than by calling ``infer``:
    ``infer`` is redefined further down this module as ``infer(in_pil_img)``,
    so ``infer(model, in_pil_img)`` would resolve to that one-argument
    function at call time and raise ``TypeError``.

    Args:
        model: ``"detr-resnet-101"`` selects ``detector101``; any other value
            falls back to ``detector50``.
        in_pil_img: The image to run detection on.

    Returns:
        The selected pipeline's output for the image (presumably a list of
        dicts with ``"score"``, ``"label"`` and ``"box"`` keys, which is how
        ``get_figure`` consumes it -- confirm against the pipeline docs).
    """
    detector = detector101 if model == "detr-resnet-101" else detector50
    return detector(in_pil_img)
def get_figure(in_pil_img, model="detr-resnet-50"):
    """Draw DETR detections for ``in_pil_img`` onto a new Matplotlib figure.

    Each prediction is drawn as an unfilled rectangle in a randomly chosen
    colour from ``COLORS``, captioned with ``"<label>: <score%>"`` using the
    font settings in ``fdic``. Axes are hidden.

    Args:
        in_pil_img: The PIL image to annotate.
        model: Detector name forwarded to ``query_data``.
            ``"detr-resnet-101"`` selects the ResNet-101 model; anything else
            uses ResNet-50. Optional, so single-argument callers still work.

    Returns:
        matplotlib.figure.Figure: The annotated figure, ready for ``savefig``.
    """
    # Build the Figure directly instead of via pyplot: pyplot keeps every
    # figure it creates alive until plt.close() is called, and this function
    # runs once per video frame, so pyplot figures would pile up in memory.
    from matplotlib.figure import Figure

    fig = Figure(figsize=(16, 10))
    ax = fig.add_subplot(111)
    ax.imshow(in_pil_img)

    for prediction in query_data(model, in_pil_img):
        selected_color = choice(COLORS)
        box = prediction["box"]
        x, y = box["xmin"], box["ymin"]
        w, h = box["xmax"] - x, box["ymax"] - y
        ax.add_patch(patches.Rectangle((x, y), w, h, fill=False,
                                       color=selected_color, linewidth=3))
        ax.text(x, y, f"{prediction['label']}: {round(prediction['score']*100, 1)}%",
                fontdict=fdic)

    ax.axis("off")
    return fig
def infer(in_pil_img):
    """Return a new PIL image showing ``in_pil_img`` with its detections drawn.

    NOTE(review): this redefines the earlier ``infer(model, in_pil_img)`` in
    this module; after import, ``infer`` refers to this one-argument version.

    Args:
        in_pil_img: The PIL image to annotate.

    Returns:
        PIL.Image.Image: The rendered figure, decoded from an in-memory buffer.
    """
    figure = get_figure(in_pil_img)
    buf = io.BytesIO()
    try:
        figure.savefig(buf, bbox_inches='tight')
    finally:
        # Close the figure so per-call figures do not accumulate in pyplot.
        plt.close(figure)
    buf.seek(0)
    output_pil_img = Image.open(buf)
    return output_pil_img
def process_single_frame(frame):
    """Annotate one video frame with object detections.

    Args:
        frame: A frame as returned by ``cv2.VideoCapture.read`` (converted
            here with ``COLOR_BGR2RGB``, so BGR channel order is assumed).

    Returns:
        numpy.ndarray: The rendered, annotated figure as an RGB array. Its size
        is set by the figure rendering, not necessarily the input frame's size.
    """
    # OpenCV frames are BGR; PIL/Matplotlib expect RGB.
    rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    pil_image = Image.fromarray(rgb_frame)

    figure = get_figure(pil_image)
    buf = BytesIO()
    try:
        figure.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
    finally:
        # Called once per frame: close the figure so they don't accumulate.
        plt.close(figure)
    buf.seek(0)
    annotated_image = Image.open(buf).convert('RGB')
    return np.array(annotated_image)
def infer_video(input_video_path):
    """Yield annotated frames of a video file, one at a time, in BGR order.

    Every frame read from ``input_video_path`` goes through
    ``process_single_frame`` and is converted from RGB back to OpenCV's BGR
    channel order before being yielded. The capture is released when the
    generator finishes or is closed early.

    Args:
        input_video_path: Path of the video to read (the value supplied by the
            ``gr.Video`` input component; ``None`` when nothing was uploaded).

    Yields:
        numpy.ndarray: The annotated frame in BGR channel order.

    Raises:
        ValueError: If no path was supplied or the video cannot be opened.
    """
    if not input_video_path:
        raise ValueError("无法打开输入视频文件")
    cap = cv2.VideoCapture(input_video_path)
    if not cap.isOpened():
        raise ValueError("无法打开输入视频文件")

    # The frame count comes from the capture backend and may be 0, -1 or an
    # estimate for some files, so it is used only for progress messages; the
    # loop runs until cap.read() reports no more frames.
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    frame_count = 0
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                if frame_count < total_frames:
                    # Stopped before the advertised frame count was reached.
                    print(f"提前结束:在第 {frame_count} 帧时无法读取帧")
                break
            frame_count += 1
            processed_frame = process_single_frame(frame)
            bgr_frame = cv2.cvtColor(processed_frame, cv2.COLOR_RGB2BGR)
            yield bgr_frame
            if frame_count % 30 == 0:
                print(f"已处理 {frame_count}/{total_frames} 帧")
    finally:
        cap.release()
# Gradio UI: upload a video, press Detect, and receive the annotated frames
# as a WebRTC video stream.
with gr.Blocks(title="长沙电网项目",
               css=".gradio-container {background:lightyellow;}"
               ) as demo:
    # The style attribute is delimited by single quotes, so the font name is
    # double-quoted; nested single quotes would end the attribute right after
    # "font-family:" and none of the following declarations would apply.
    gr.HTML("<div style='font-family:\"Times New Roman\", serif; font-size:16pt; font-weight:bold; text-align:center; color:royalblue;'>长沙电网项目</div>")
    with gr.Row():
        input_video = gr.Video(label="输入视频")
        output_video = WebRTC(label="WebRTC Stream",
                              rtc_configuration=None,
                              mode="receive",
                              modality="video")
    detect = gr.Button("Detect", variant="primary")
    # Clicking Detect streams infer_video's yielded frames into output_video.
    output_video.stream(
        fn=infer_video,
        inputs=[input_video],
        outputs=[output_video],
        trigger=detect.click
    )
demo.launch(debug=True, share=True)
|