VideoMatting / app.py
faiimea's picture
app
e4bc965 verified
raw
history blame
9.39 kB
import argparse
import cv2
import torch
import os
import shutil
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import transforms as T
from torchvision.transforms.functional import to_pil_image
from threading import Thread
from tqdm import tqdm
from PIL import Image
import gradio as gr
from dataset import VideoDataset, ZipDataset
from dataset import augmentation as A
from model import MattingBase, MattingRefine
from inference_utils import HomographicAlignment
# --------------- Arguments ---------------
# --------------- Utils ---------------
class VideoWriter:
def __init__(self, path, frame_rate, width, height):
self.out = cv2.VideoWriter(path, cv2.VideoWriter_fourcc(*'mp4v'), frame_rate, (width, height))
def add_batch(self, frames):
frames = frames.mul(255).byte()
frames = frames.cpu().permute(0, 2, 3, 1).numpy()
for i in range(frames.shape[0]):
frame = frames[i]
frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
self.out.write(frame)
class ImageSequenceWriter:
def __init__(self, path, extension):
self.path = path
self.extension = extension
self.index = 0
os.makedirs(path)
def add_batch(self, frames):
Thread(target=self._add_batch, args=(frames, self.index)).start()
self.index += frames.shape[0]
def _add_batch(self, frames, index):
frames = frames.cpu()
for i in range(frames.shape[0]):
frame = frames[i]
frame = to_pil_image(frame)
frame.save(os.path.join(self.path, str(index + i).zfill(5) + '.' + self.extension))
# --------------- Main ---------------
def video_matting(video_src_content,video_bgr_content):
src_video_path = './source/src_video.mp4'
bgr_image_path = './source/bgr_image.png'
with open(src_video_path, 'wb') as video_file:
video_file.write(video_src_content)
# 写入背景图片文件
with open(bgr_image_path, 'wb') as bgr_file:
bgr_file.write(video_bgr_content)
video_src = src_video_path
video_bgr = bgr_image_path
default_args = {
'model_type': 'mattingrefine',
'model_backbone': 'resnet50',
'model_backbone_scale': 0.25,
'model_refine_mode': 'sampling',
'model_refine_sample_pixels': 80000,
'model_checkpoint': './pytorch_resnet50.pth',
'model_refine_threshold':0.7,
'model_refine_kernel_size':3,
'video_src': './source/src.mp4',
'video_bgr': './source/bgr.png',
'video_target_bgr': None,
'video_resize': [1920, 1080],
'device': 'cpu', # 默认设置为CPU
'preprocess_alignment': False,
'output_dir': './output',
'output_types': ['com'],
'output_format': 'video'
}
args = argparse.Namespace(**default_args)
device = torch.device(args.device)
# Load model
if args.model_type == 'mattingbase':
model = MattingBase(args.model_backbone)
if args.model_type == 'mattingrefine':
model = MattingRefine(
args.model_backbone,
args.model_backbone_scale,
args.model_refine_mode,
args.model_refine_sample_pixels,
args.model_refine_threshold,
args.model_refine_kernel_size)
model = model.to(device).eval()
model.load_state_dict(torch.load(args.model_checkpoint, map_location=device), strict=False)
# Load video and background
vid = VideoDataset(video_src)
bgr = [Image.open(video_bgr).convert('RGB')]
dataset = ZipDataset([vid, bgr], transforms=A.PairCompose([
A.PairApply(T.Resize(args.video_resize[::-1]) if args.video_resize else nn.Identity()),
HomographicAlignment() if args.preprocess_alignment else A.PairApply(nn.Identity()),
A.PairApply(T.ToTensor())
]))
if args.video_target_bgr:
dataset = ZipDataset([dataset, VideoDataset(args.video_target_bgr, transforms=T.ToTensor())])
# Create output directory
# if os.path.exists(args.output_dir):
# if input(f'Directory {args.output_dir} already exists. Override? [Y/N]: ').lower() == 'y':
# shutil.rmtree(args.output_dir)
# else:
# exit()
# os.makedirs(args.output_dir)
# Prepare writers
if args.output_format == 'video':
h = args.video_resize[1] if args.video_resize is not None else vid.height
w = args.video_resize[0] if args.video_resize is not None else vid.width
if 'com' in args.output_types:
com_writer = VideoWriter(os.path.join(args.output_dir, 'com.mp4'), vid.frame_rate, w, h)
if 'pha' in args.output_types:
pha_writer = VideoWriter(os.path.join(args.output_dir, 'pha.mp4'), vid.frame_rate, w, h)
if 'fgr' in args.output_types:
fgr_writer = VideoWriter(os.path.join(args.output_dir, 'fgr.mp4'), vid.frame_rate, w, h)
if 'err' in args.output_types:
err_writer = VideoWriter(os.path.join(args.output_dir, 'err.mp4'), vid.frame_rate, w, h)
if 'ref' in args.output_types:
ref_writer = VideoWriter(os.path.join(args.output_dir, 'ref.mp4'), vid.frame_rate, w, h)
else:
if 'com' in args.output_types:
com_writer = ImageSequenceWriter(os.path.join(args.output_dir, 'com'), 'png')
if 'pha' in args.output_types:
pha_writer = ImageSequenceWriter(os.path.join(args.output_dir, 'pha'), 'jpg')
if 'fgr' in args.output_types:
fgr_writer = ImageSequenceWriter(os.path.join(args.output_dir, 'fgr'), 'jpg')
if 'err' in args.output_types:
err_writer = ImageSequenceWriter(os.path.join(args.output_dir, 'err'), 'jpg')
if 'ref' in args.output_types:
ref_writer = ImageSequenceWriter(os.path.join(args.output_dir, 'ref'), 'jpg')
# Conversion loop
with torch.no_grad():
for input_batch in tqdm(DataLoader(dataset, batch_size=1, pin_memory=True)):
if args.video_target_bgr:
(src, bgr), tgt_bgr = input_batch
tgt_bgr = tgt_bgr.to(device, non_blocking=True)
else:
src, bgr = input_batch
tgt_bgr = torch.tensor([120/255, 255/255, 155/255], device=device).view(1, 3, 1, 1)
src = src.to(device, non_blocking=True)
bgr = bgr.to(device, non_blocking=True)
if args.model_type == 'mattingbase':
pha, fgr, err, _ = model(src, bgr)
elif args.model_type == 'mattingrefine':
pha, fgr, _, _, err, ref = model(src, bgr)
elif args.model_type == 'mattingbm':
pha, fgr = model(src, bgr)
if 'com' in args.output_types:
if args.output_format == 'video':
# Output composite with green background
com = fgr * pha + tgt_bgr * (1 - pha)
com_writer.add_batch(com)
else:
# Output composite as rgba png images
com = torch.cat([fgr * pha.ne(0), pha], dim=1)
com_writer.add_batch(com)
if 'pha' in args.output_types:
pha_writer.add_batch(pha)
if 'fgr' in args.output_types:
fgr_writer.add_batch(fgr)
if 'err' in args.output_types:
err_writer.add_batch(F.interpolate(err, src.shape[2:], mode='bilinear', align_corners=False))
if 'ref' in args.output_types:
ref_writer.add_batch(F.interpolate(ref, src.shape[2:], mode='nearest'))
return './output/com.mp4'
# 读取本地视频文件的二进制数据
def get_video_content(video_path):
with open(video_path, 'rb') as file:
video_content = file.read()
return video_content
# 假设你的视频文件路径是'./local_video.mp4'
local_video_path = './output/com.mp4'
local_video_content = get_video_content(local_video_path)
# 创建Gradio界面
with gr.Blocks() as demo:
gr.Markdown("## Video Matting")
with gr.Row():
video_src = gr.File(label="Upload Source Video (.mp4)", type="binary", file_types=["mp4"])
video_bgr = gr.File(label="Upload Background Image (.png)", type="binary", file_types=["png"])
with gr.Row():
output_video = gr.Video(label="Result Video")
submit_button = gr.Button("Start Matting")
# def download_video(video_path):
# if os.path.exists(video_path):
# with open(video_path, 'rb') as file:
# video_data = file.read()
# return video_data, "video/mp4", os.path.basename(video_path)
# else:
# return "Not Found", "text/plain", None
def clear_outputs():
output_video.update(value=None)
submit_button.click(
fn=video_matting,
inputs=[video_src, video_bgr],
outputs=[output_video]
)
# download_button = gr.Button("Download")
# download_button.click(
# download_video,
# inputs=[output_video], # 从视频组件传递视频路径
# outputs=[gr.File(label="Download")]
# )
clear_button = gr.Button("Clear")
clear_button.click(fn=clear_outputs, inputs=[], outputs=[])
if __name__ == "__main__":
demo.launch()