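# (Residue of the Hugging Face file-view page header, kept as comments so the file parses.)
# qsitj's picture
# Update app.py
# 732dfb0 verified
# raw
# history blame
# 4.98 kB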
import os
from gradio_webrtc import WebRTC
import requests
from PIL import Image
import matplotlib.pyplot as plt
from random import choice
import io
import gradio as gr
import cv2
import numpy as np
from io import BytesIO
import random
import tempfile
from pathlib import Path
import torch
from transformers import pipeline
from PIL import Image
import matplotlib.patches as patches
# Hugging Face object-detection pipelines, built once at import time.
# ``query_data`` picks between them by model name.
detector50 = pipeline(model="facebook/detr-resnet-50")
detector101 = pipeline(model="facebook/detr-resnet-101")

# Palette for bounding-box outlines; one entry is drawn at random per box.
COLORS = [
    "#ff7f7f", "#ff7fbf", "#ff7fff", "#bf7fff",
    "#7f7fff", "#7fbfff", "#7fffff", "#7fffbf",
    "#7fff7f", "#bfff7f", "#ffff7f", "#ffbf7f",
]

# Font settings passed as ``fontdict`` to ``Axes.text`` for each box label.
fdic = dict(
    style="italic",
    size=15,
    color="yellow",
    weight="bold",
)
def query_data(model, in_pil_img: Image.Image):
    """Run DETR object detection on a PIL image.

    Args:
        model: Model name. ``"detr-resnet-101"`` selects ``detector101``;
            any other value falls back to ``detector50``.
        in_pil_img: Image to run detection on.

    Returns:
        The selected pipeline's output. ``get_figure`` iterates it and reads
        ``label``, ``score`` and ``box`` (with ``xmin``/``ymin``/``xmax``/
        ``ymax``) from each prediction.
    """
    # Dispatch directly here. A separate two-argument ``infer`` helper would be
    # silently rebound by the single-argument ``infer`` defined below, which
    # made every call raise TypeError.
    detector = detector101 if model == "detr-resnet-101" else detector50
    return detector(in_pil_img)


def get_figure(in_pil_img, model="detr-resnet-50"):
    """Draw detection boxes and labels over an image.

    Args:
        in_pil_img: Image to annotate.
        model: Detector name forwarded to ``query_data``; the default selects
            the ResNet-50 DETR pipeline.

    Returns:
        A new matplotlib ``Figure``. The caller owns it and should
        ``plt.close`` it when done, since pyplot keeps open figures alive.
    """
    fig, ax = plt.subplots(figsize=(16, 10))
    ax.imshow(in_pil_img)
    for prediction in query_data(model, in_pil_img):
        selected_color = choice(COLORS)
        box = prediction["box"]
        x, y = box["xmin"], box["ymin"]
        w, h = box["xmax"] - x, box["ymax"] - y
        ax.add_patch(plt.Rectangle((x, y), w, h, fill=False, color=selected_color, linewidth=3))
        ax.text(x, y, f"{prediction['label']}: {round(prediction['score'] * 100, 1)}%", fontdict=fdic)
    ax.axis("off")
    return fig


def infer(in_pil_img):
    """Return ``in_pil_img`` rendered with detection boxes as a PIL image.

    The returned image is opened lazily from an in-memory PNG buffer.
    """
    figure = get_figure(in_pil_img)
    buf = io.BytesIO()
    try:
        figure.savefig(buf, bbox_inches="tight")
    finally:
        # The figure is fully rendered into ``buf``; release it from pyplot.
        plt.close(figure)
    buf.seek(0)
    return Image.open(buf)
def process_single_frame(frame):
    """Annotate one video frame with object-detection boxes.

    Args:
        frame: Image array in OpenCV (BGR) channel order, e.g. from
            ``cv2.VideoCapture.read``.

    Returns:
        A ``numpy`` RGB array of the rendered matplotlib figure. Because the
        figure is saved with ``bbox_inches='tight'``, its dimensions generally
        differ from those of ``frame``.
    """
    # OpenCV frames are BGR; PIL and matplotlib expect RGB.
    rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    pil_image = Image.fromarray(rgb_frame)
    figure = get_figure(pil_image)
    buf = BytesIO()
    try:
        figure.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
    finally:
        # Called once per video frame. Without closing, pyplot keeps every
        # figure open and memory grows with each processed frame.
        plt.close(figure)
    buf.seek(0)
    annotated_image = Image.open(buf).convert("RGB")
    return np.array(annotated_image)
def infer_video(input_video_path):
    """Yield annotated frames of a video file, one per decoded frame.

    Used as the stream handler of the WebRTC output component. Each frame is
    annotated by ``process_single_frame`` and converted back to BGR.

    Args:
        input_video_path: Path of the uploaded video (the ``gr.Video`` input
            value); may be empty/None if nothing was uploaded.

    Yields:
        Annotated frames as BGR ``numpy`` arrays.

    Raises:
        ValueError: If no path was given or OpenCV cannot open the file.
    """
    # Nothing uploaded: fail with the same message as an unopenable file.
    if not input_video_path:
        raise ValueError("无法打开输入视频文件")
    cap = cv2.VideoCapture(input_video_path)
    try:
        if not cap.isOpened():
            raise ValueError("无法打开输入视频文件")
        # CAP_PROP_FRAME_COUNT is container metadata that can be 0, -1 or
        # inexact. Use it only for progress output, and read frames until
        # the decoder reports end of stream.
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        frame_count = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                if frame_count < total_frames:
                    # Stream ended before the advertised frame count.
                    print(f"提前结束:在第 {frame_count} 帧时无法读取帧")
                break
            frame_count += 1
            # process_single_frame returns RGB; convert back to OpenCV's BGR.
            processed_frame = process_single_frame(frame)
            yield cv2.cvtColor(processed_frame, cv2.COLOR_RGB2BGR)
            # Progress report every 30 frames.
            if frame_count % 30 == 0:
                print(f"已处理 {frame_count}/{total_frames} 帧")
    finally:
        cap.release()
# Gradio UI: upload a video, click "Detect", and receive annotated frames
# over a WebRTC stream produced by ``infer_video``.
with gr.Blocks(title="长沙电网项目",
               css=".gradio-container {background:lightyellow;}") as demo:
    # The style attribute value is double-quoted so the single-quoted font
    # name inside it does not terminate the attribute early (which previously
    # dropped font-size, weight, alignment and color). The generic family
    # ``serif`` must be unquoted to act as the generic fallback.
    gr.HTML(
        "<div style=\"font-family:'Times New Roman', serif; font-size:16pt; "
        "font-weight:bold; text-align:center; color:royalblue;\">长沙电网项目</div>"
    )
    with gr.Row():
        input_video = gr.Video(label="输入视频")
        output_video = WebRTC(label="WebRTC Stream",
                              rtc_configuration=None,
                              mode="receive",
                              modality="video")
    detect = gr.Button("Detect", variant="primary")
    output_video.stream(
        fn=infer_video,
        inputs=[input_video],
        outputs=[output_video],
        trigger=detect.click,
    )
demo.launch(debug=True, share=True)