"""Gradio demo: run DETR object detection on an uploaded video and stream the
annotated frames back to the browser over WebRTC."""
import os
from gradio_webrtc import WebRTC
import requests
from PIL import Image
import matplotlib.pyplot as plt
from random import choice
import io
import gradio as gr
import cv2
import numpy as np
from io import BytesIO
import random
import tempfile
from pathlib import Path
import torch
from transformers import pipeline
import matplotlib.patches as patches
from matplotlib.figure import Figure

# Both detectors are loaded once at import time and shared by all requests.
detector50 = pipeline(model="facebook/detr-resnet-50")
detector101 = pipeline(model="facebook/detr-resnet-101")

# Box colors; one is picked at random for each detected object.
COLORS = [
    "#ff7f7f", "#ff7fbf", "#ff7fff", "#bf7fff", "#7f7fff", "#7fbfff",
    "#7fffff", "#7fffbf", "#7fff7f", "#bfff7f", "#ffff7f", "#ffbf7f",
]

# Font settings for the "label: score%" text drawn at each box corner.
fdic = {
    # "family" : "Impact",
    "style": "italic",
    "size": 15,
    "color": "yellow",
    "weight": "bold",
}

# Model used when a caller does not choose one explicitly.
DEFAULT_MODEL = "detr-resnet-50"


def run_detector(model, in_pil_img):
    """Run the DETR detection pipeline selected by *model* on one image.

    Args:
        model: "detr-resnet-101" selects the ResNet-101 pipeline; any other
            value falls back to the ResNet-50 pipeline.
        in_pil_img: PIL image to analyse.

    Returns:
        The raw pipeline output. ``get_figure`` consumes it as an iterable of
        dicts with "label", "score" and "box" (xmin/ymin/xmax/ymax) keys.
    """
    detector = detector101 if model == "detr-resnet-101" else detector50
    return detector(in_pil_img)


#######################################
def query_data(model, in_pil_img: Image.Image):
    """Return the raw detection results for *in_pil_img* using *model*."""
    return run_detector(model, in_pil_img)


def get_figure(in_pil_img, model=DEFAULT_MODEL):
    """Draw *in_pil_img* with one labelled bounding box per detection.

    A standalone ``matplotlib.figure.Figure`` is used instead of pyplot. It is
    freed by garbage collection like any other object, so processing many video
    frames does not accumulate open figures. It also avoids pyplot's global
    state, which is not safe to share across Gradio's worker threads.

    Args:
        in_pil_img: Image to annotate.
        model: Detector name forwarded to ``query_data``.

    Returns:
        The annotated ``Figure`` (16x10 inches, axes hidden).
    """
    fig = Figure(figsize=(16, 10))
    ax = fig.add_subplot()
    ax.imshow(in_pil_img)

    for prediction in query_data(model, in_pil_img):
        selected_color = choice(COLORS)
        box = prediction["box"]
        # The pipeline reports corner coordinates; Rectangle wants origin + size.
        x, y = box["xmin"], box["ymin"]
        w, h = box["xmax"] - box["xmin"], box["ymax"] - box["ymin"]
        ax.add_patch(
            patches.Rectangle((x, y), w, h, fill=False, color=selected_color, linewidth=3)
        )
        ax.text(
            x,
            y,
            f"{prediction['label']}: {round(prediction['score']*100, 1)}%",
            fontdict=fdic,
        )

    ax.axis("off")
    return fig


def infer(in_pil_img, model=DEFAULT_MODEL):
    """Detect objects in *in_pil_img* and return the annotated rendering.

    Returns:
        A PIL image decoded from the figure saved to an in-memory buffer
        (default savefig format, tight bounding box).
    """
    figure = get_figure(in_pil_img, model)
    buf = io.BytesIO()
    figure.savefig(buf, bbox_inches="tight")
    buf.seek(0)
    output_pil_img = Image.open(buf)
    return output_pil_img


def process_single_frame(frame, model=DEFAULT_MODEL):
    """Annotate one OpenCV video frame.

    Args:
        frame: Frame as returned by ``cv2.VideoCapture.read`` (BGR channel order).
        model: Detector name forwarded to ``get_figure``.

    Returns:
        The rendered annotated frame as an RGB numpy array. Its size comes from
        the figure rendering, not from the input frame.
    """
    # OpenCV delivers BGR; PIL/matplotlib expect RGB.
    rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    pil_image = Image.fromarray(rgb_frame)
    # Render the frame with its detection annotations.
    figure = get_figure(pil_image, model)
    buf = BytesIO()
    figure.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
    buf.seek(0)
    # PNG output may carry an alpha channel; drop it so the array is HxWx3.
    annotated_image = Image.open(buf).convert("RGB")
    return np.array(annotated_image)


def infer_video(input_video_path):
    """Yield annotated frames of *input_video_path* one at a time.

    The generator is meant for ``WebRTC.stream``. Each yielded value is a
    BGR array produced from ``process_single_frame``'s RGB output.
    NOTE(review): presumably the WebRTC receive mode expects BGR frames, as the
    original conversion implies; confirm against the gradio_webrtc docs.

    Raises:
        ValueError: if no video was supplied or OpenCV cannot open it.
    """
    if not input_video_path:
        raise ValueError("无法打开输入视频文件")

    cap = cv2.VideoCapture(input_video_path)
    if not cap.isOpened():
        cap.release()
        raise ValueError("无法打开输入视频文件")

    # CAP_PROP_FRAME_COUNT is only an estimate and may be 0 or -1 for some
    # containers. It is used for progress messages only; the loop runs until
    # the decoder stops returning frames.
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

    frame_count = 0
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                if 0 < total_frames and frame_count < total_frames:
                    print(f"提前结束:在第 {frame_count} 帧时无法读取帧")
                break
            frame_count += 1

            # Annotate the frame and convert it back to OpenCV (BGR) order.
            processed_frame = process_single_frame(frame)
            bgr_frame = cv2.cvtColor(processed_frame, cv2.COLOR_RGB2BGR)
            yield bgr_frame

            # Optional progress report every 30 frames.
            if frame_count % 30 == 0:
                print(f"已处理 {frame_count}/{total_frames} 帧")
    finally:
        # Also runs when the consumer closes the generator early (e.g. the client disconnects).
        cap.release()


# Gradio UI: upload a video, press Detect, and watch the annotated stream.
with gr.Blocks(title="长沙电网项目", css=".gradio-container {background:lightyellow;}") as demo:
    # NOTE(review): this header was a multi-line literal (a SyntaxError); its
    # markup appears to have been lost. The visible text is kept verbatim.
    gr.HTML("\n长沙电网项目\n")
    with gr.Row():
        input_video = gr.Video(label="输入视频")
        output_video = WebRTC(
            label="WebRTC Stream",
            rtc_configuration=None,
            mode="receive",
            modality="video",
        )
    detect = gr.Button("Detect", variant="primary")
    output_video.stream(
        fn=infer_video,
        inputs=[input_video],
        outputs=[output_video],
        trigger=detect.click,
    )

demo.launch()