import numpy as np
import torch
import torch.nn.functional as F
from torchvision.transforms.functional import normalize
from huggingface_hub import hf_hub_download
import gradio as gr

# from gradio_imageslider import ImageSlider
from briarmbg import BriaRMBG
import PIL
from PIL import Image
from typing import Tuple
import cv2
import os
import shutil
import glob
from tqdm import tqdm
from ffmpy import FFmpeg

net = BriaRMBG()
# model_path = "./model1.pth"
model_path = hf_hub_download("briaai/RMBG-1.4", "model.pth")
if torch.cuda.is_available():
    net.load_state_dict(torch.load(model_path))
    net = net.cuda()
    print("GPU is available")
else:
    net.load_state_dict(torch.load(model_path, map_location="cpu"))
    print("GPU is NOT available")
net.eval()


def resize_image(image):
    # the model expects a fixed 1024x1024 RGB input
    image = image.convert("RGB")
    model_input_size = (1024, 1024)
    image = image.resize(model_input_size, Image.BILINEAR)
    return image


def process(image):
    # prepare input
    orig_image = Image.fromarray(image)
    w, h = orig_im_size = orig_image.size
    image = resize_image(orig_image)
    im_np = np.array(image)
    im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2, 0, 1)
    im_tensor = torch.unsqueeze(im_tensor, 0)
    im_tensor = torch.divide(im_tensor, 255.0)
    im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
    if torch.cuda.is_available():
        im_tensor = im_tensor.cuda()

    # inference
    result = net(im_tensor)

    # post process: resize the mask back to the original resolution and min-max normalize it
    result = torch.squeeze(F.interpolate(result[0][0], size=(h, w), mode="bilinear"), 0)
    ma = torch.max(result)
    mi = torch.min(result)
    result = (result - mi) / (ma - mi)

    # mask to PIL image
    im_array = (result * 255).cpu().data.numpy().astype(np.uint8)
    pil_im = Image.fromarray(np.squeeze(im_array))

    # paste the original image onto a transparent canvas, using the mask as alpha
    new_im = Image.new("RGBA", pil_im.size, (0, 0, 0, 0))
    new_im.paste(orig_image, mask=pil_im)
    # new_orig_image = orig_image.convert('RGBA')

    return new_im
    # return [new_orig_image, new_im]


def process_video(video, key_color):
    workspace = "./temp"
    original_video_name_without_ext = os.path.splitext(os.path.basename(video))[0]
    os.makedirs(workspace, exist_ok=True)
    os.makedirs(f"{workspace}/frames", exist_ok=True)
    os.makedirs(f"{workspace}/result", exist_ok=True)
    os.makedirs("./video_result", exist_ok=True)

    video_file = cv2.VideoCapture(video)
    fps = video_file.get(cv2.CAP_PROP_FPS)

    # first, read the video and save every frame to {workspace}/frames/
    def extract_frames():
        success, frame = video_file.read()
        frame_num = 0
        with tqdm(
            total=None,
            desc="Extracting frames",
        ) as pbar:
            while success:
                file_name = f"{workspace}/frames/{frame_num:015d}.png"
                cv2.imwrite(file_name, frame)
                success, frame = video_file.read()
                frame_num += 1
                pbar.update(1)
        video_file.release()
        return

    extract_frames()

    # process each frame
    def process_frame(frame_file):
        image = cv2.imread(frame_file)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        new_image = process(image)
        # composite the cutout onto a key_color background
        key_back_image = Image.new("RGBA", new_image.size, key_color)
        new_image = Image.alpha_composite(key_back_image, new_image)
        new_image.save(frame_file)

    frame_files = sorted(glob.glob(f"{workspace}/frames/*.png"))
    with tqdm(total=len(frame_files), desc="Processing frames") as pbar:
        for file in frame_files:
            process_frame(file)
            pbar.update(1)
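    # ffmpy below simply shells out to the ffmpeg binary, so the assembly steps fail
    # if ffmpeg is not on PATH. A minimal optional guard could look like the sketch
    # below (not part of the original flow; raising RuntimeError is an assumption):
    # if shutil.which("ffmpeg") is None:
    #     raise RuntimeError("ffmpeg was not found on PATH; it is required to build the output video")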
    # create a video from the processed frames
    # first_frame = cv2.imread(frame_files[0])
    # h, w, _ = first_frame.shape
    # fourcc = cv2.VideoWriter_fourcc(*"avc1")
    # new_video = cv2.VideoWriter(f"{workspace}/result/video.mp4", fourcc, fps, (w, h))
    # for file in frame_files:
    #     image = cv2.imread(file)
    #     new_video.write(image)
    # new_video.release()

    # the commented-out OpenCV code above, rewritten with ffmpy
    ff = FFmpeg(
        inputs={f"{workspace}/frames/%015d.png": f"-r {fps}"},
        outputs={
            f"{workspace}/result/video.mp4": f"-c:v libx264 -vf fps={fps},format=yuv420p -hide_banner -loglevel error -y"
        },
    )
    ff.run()

    # known issue: for some reason the key_color background comes out darker than expected
    ff2 = FFmpeg(
        inputs={f"{workspace}/result/video.mp4": None, f"{video}": None},
        outputs={
            f"./video_result/{original_video_name_without_ext}_BGremoved.mp4": "-c:v copy -c:a aac -strict experimental -map 0:v:0 -map 1:a:0 -shortest -hide_banner -loglevel error -y"
        },
    )
    ff2.run()

    # a video with a real alpha channel would have been nicer, but compatibility is poor, so it was dropped
    # subprocess.run(
    #     f'ffmpeg -framerate {fps} -i {workspace}/frames/%015d.png -auto-alt-ref 0 -c:v libvpx "./video_result/{original_video_name_without_ext}_BGremoved.webm" -hide_banner -loglevel error -y',
    #     shell=True,
    #     check=True,
    # )

    # the output is meant for chroma keying, so audio is not really needed
    # subprocess.run(
    #     f'ffmpeg -i "./video_result/{original_video_name_without_ext}_BGremoved.webm" -c:v libx264 -c:a aac -strict experimental -b:a 192k ./demo/demo.mp4 -hide_banner -loglevel error -y',
    #     shell=True,
    #     check=True,
    # )

    # clean up temporary files
    shutil.rmtree(workspace)

    return f"./video_result/{original_video_name_without_ext}_BGremoved.mp4"


gr.Markdown("## BRIA RMBG 1.4")
gr.HTML(
    """

    This is a demo for BRIA RMBG 1.4 that uses the BRIA RMBG-1.4 image matting model as its backbone.

""" ) title = "Background Removal" description = r"""Background removal model developed by BRIA.AI, trained on a carefully selected dataset and is available as an open-source model for non-commercial use.
To test it, upload your image and wait. Read more at the model card briaai/RMBG-1.4.
""" examples = [ ["./input.jpg"], ] title2 = "Background Removal For Video" description2 = r"""Background removal model developed by BRIA.AI, trained on a carefully selected dataset and is available as an open-source model for non-commercial use.
To test it, upload your image and wait. Read more at the model card briaai/RMBG-1.4.
You can also remove the background from a video.
Video processing takes longer because every frame is processed individually, so a powerful GPU is recommended.
You need ffmpeg installed to use this feature.
"""

# output = ImageSlider(position=0.5, label='Image without background', type="pil", show_download_button=True)
# demo = gr.Interface(fn=process, inputs="image", outputs=output, examples=examples, title=title, description=description)

demo1 = gr.Interface(
    fn=process,
    inputs="image",
    outputs="image",
    title=title,
    description=description,
    examples=examples,
    api_name="demo1",
)
demo2 = gr.Interface(
    fn=process_video,
    inputs=[
        gr.Video(label="Video"),
        gr.ColorPicker(label="Key Color (Background color)"),
    ],
    outputs="video",
    title=title2,
    description=description2,
    api_name="demo2",
)
demo = gr.TabbedInterface(
    interface_list=[demo1, demo2],
    tab_names=["Image", "Video"],
)

if __name__ == "__main__":
    demo.launch(share=False)
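
# Example of calling the matting function outside the Gradio UI (a minimal sketch;
# "./input.jpg" is the bundled example image, and the output filename is an assumption):
#
#   from PIL import Image
#   import numpy as np
#
#   img = np.array(Image.open("./input.jpg").convert("RGB"))
#   cutout = process(img)  # RGBA PIL image with a transparent background
#   cutout.save("./input_BGremoved.png")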