File size: 8,202 Bytes
542c815
3f8e328
542c815
 
9bf688a
542c815
c5325e3
 
d6e753e
 
 
8a357d1
542c815
c5325e3
 
 
 
 
96d6405
c5325e3
 
8f942dd
c5325e3
542c815
 
c5325e3
 
542c815
c5325e3
 
 
 
542c815
 
c5325e3
988f91c
 
542c815
 
 
fdc77c7
542c815
 
1605763
c5325e3
542c815
 
c5325e3
 
 
 
542c815
c5325e3
70974c3
c5325e3
 
542c815
c5325e3
542c815
 
c5325e3
70974c3
c5325e3
542c815
 
c5325e3
542c815
68112f8
542c815
68112f8
 
542c815
 
c5325e3
 
 
 
 
 
 
9a80340
c5325e3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96d6405
 
 
 
 
 
 
c5325e3
96d6405
c5325e3
 
 
d909bca
96d6405
 
 
 
 
c5325e3
96d6405
542c815
c5325e3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d909bca
 
c530952
c5325e3
 
c530952
 
 
 
c5325e3
 
c530952
6ca28a8
4941fcb
6ca28a8
c5325e3
 
 
 
 
 
 
fc836ce
c5325e3
 
68112f8
 
c5325e3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d909bca
 
96d6405
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
import numpy as np
import torch
import torch.nn.functional as F
from torchvision.transforms.functional import normalize
from huggingface_hub import hf_hub_download
import gradio as gr

# from gradio_imageslider import ImageSlider
from briarmbg import BriaRMBG
import PIL
from PIL import Image
from typing import Tuple

import cv2
import os
import shutil
import glob
from tqdm import tqdm
from ffmpy import FFmpeg

# Build the matting network and load the published RMBG-1.4 weights from the
# Hugging Face Hub (hf_hub_download returns the local path of the file).
net = BriaRMBG()
# model_path = "./model1.pth"
model_path = hf_hub_download("briaai/RMBG-1.4", "model.pth")
# NOTE(review): torch.load unpickles the checkpoint; this assumes the Hub
# repository is trusted.
if not torch.cuda.is_available():
    # CPU-only host: remap the checkpoint tensors onto the CPU while loading.
    net.load_state_dict(torch.load(model_path, map_location="cpu"))
    print("GPU is NOT available")
else:
    net.load_state_dict(torch.load(model_path))
    net = net.cuda()
    print("GPU is available")
# Switch to evaluation mode; the model is only used for inference below.
net.eval()


def resize_image(image):
    """Return *image* converted to RGB and resized to 1024x1024.

    The resize uses bilinear resampling and does not preserve the aspect
    ratio; the result is what ``process`` feeds to the network.
    """
    target_size = (1024, 1024)
    return image.convert("RGB").resize(target_size, Image.BILINEAR)


def process(image):
    """Remove the background from a single image.

    Args:
        image: Array accepted by ``PIL.Image.fromarray``. Presumably an
            H x W x 3 uint8 RGB array, as passed by the Gradio "image"
            input and by ``process_video`` -- confirm for other callers.

    Returns:
        An RGBA ``PIL.Image`` with the same size as the input, built by
        pasting the original image onto a fully transparent canvas using
        the predicted matte as the paste mask.
    """
    # prepare input
    orig_image = Image.fromarray(image)
    w, h = orig_image.size
    model_input = resize_image(orig_image)
    im_tensor = torch.tensor(np.array(model_input), dtype=torch.float32).permute(2, 0, 1)
    im_tensor = torch.unsqueeze(im_tensor, 0)
    im_tensor = torch.divide(im_tensor, 255.0)
    im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
    if torch.cuda.is_available():
        im_tensor = im_tensor.cuda()

    # Inference only: disabling autograd avoids building a graph for every
    # call, which matters when this runs once per video frame.
    with torch.no_grad():
        result = net(im_tensor)
        # post process: scale the prediction back to the original resolution
        result = torch.squeeze(
            F.interpolate(result[0][0], size=(h, w), mode="bilinear"), 0
        )
        # Min-max normalise to [0, 1]. The clamp keeps a constant prediction
        # (max == min) from dividing by zero and turning the mask into NaN;
        # in that degenerate case the mask becomes all zeros.
        ma = torch.max(result)
        mi = torch.min(result)
        result = (result - mi) / torch.clamp(ma - mi, min=1e-8)

    # image to pil
    im_array = (result * 255).detach().cpu().numpy().astype(np.uint8)
    pil_im = Image.fromarray(np.squeeze(im_array))
    # paste the mask on the original image
    new_im = Image.new("RGBA", pil_im.size, (0, 0, 0, 0))
    new_im.paste(orig_image, mask=pil_im)

    return new_im


def _extract_frames(video, frames_dir):
    """Write every frame of *video* into *frames_dir* as zero-padded PNGs.

    Returns:
        ``(fps, frame_count)``: the frame rate reported by OpenCV and the
        number of frames written.
    """
    capture = cv2.VideoCapture(video)
    try:
        fps = capture.get(cv2.CAP_PROP_FPS)
        frame_num = 0
        with tqdm(total=None, desc="Extracting frames") as pbar:
            success, frame = capture.read()
            while success:
                cv2.imwrite(f"{frames_dir}/{frame_num:015d}.png", frame)
                frame_num += 1
                pbar.update(1)
                success, frame = capture.read()
    finally:
        # Release the capture even if reading or writing a frame fails.
        capture.release()
    return fps, frame_num


def _composite_on_key_color(frame_file, key_color):
    """Replace the background of one frame file, in place, with *key_color*."""
    image = cv2.imread(frame_file)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    cutout = process(image)
    # set key_color as background
    backdrop = Image.new("RGBA", cutout.size, key_color)
    Image.alpha_composite(backdrop, cutout).save(frame_file)


def process_video(video, key_color):
    """Remove the background from every frame of a video.

    The video is split into PNG frames, each frame is passed through
    ``process`` and composited onto a solid *key_color* backdrop, the frames
    are re-encoded with ffmpeg (libx264, yuv420p), and the audio of the
    source video (if it has any) is muxed back in.

    Args:
        video: Path of the input video file.
        key_color: Background colour in a form accepted by ``PIL.Image.new``
            (the Gradio colour picker supplies a string). A falsy value
            falls back to black.

    Returns:
        Path of the resulting file,
        ``./video_result/<input name>_BGremoved.mp4``.

    Raises:
        ValueError: If no frame can be read from *video* or OpenCV reports a
            non-positive frame rate.
    """
    # Function-scope import so this block does not depend on the file header.
    import tempfile

    # Presumably an empty colour would make PIL raise and None would leave
    # the backdrop uninitialised -- use an explicit default instead.
    key_color = key_color or "#000000"

    original_video_name_without_ext = os.path.splitext(os.path.basename(video))[0]
    output_path = f"./video_result/{original_video_name_without_ext}_BGremoved.mp4"
    os.makedirs("./video_result", exist_ok=True)

    # A unique workspace per call: frames left behind by a failed run can no
    # longer leak into the next video, and concurrent calls do not collide.
    workspace = tempfile.mkdtemp(prefix="temp_", dir=".")
    try:
        os.makedirs(f"{workspace}/frames", exist_ok=True)
        os.makedirs(f"{workspace}/result", exist_ok=True)

        # first, load the video and save its frames to <workspace>/frames/
        fps, frame_count = _extract_frames(video, f"{workspace}/frames")
        if frame_count == 0:
            raise ValueError(f"No frames could be read from video: {video}")
        if not fps > 0:
            raise ValueError(f"Could not determine the frame rate of video: {video}")

        # process each frame
        frame_files = sorted(glob.glob(f"{workspace}/frames/*.png"))
        with tqdm(total=len(frame_files), desc="Processing frames") as pbar:
            for file in frame_files:
                _composite_on_key_color(file, key_color)
                pbar.update(1)

        # create the (silent) video from the processed frames
        ff = FFmpeg(
            inputs={f"{workspace}/frames/%015d.png": f"-r {fps}"},
            outputs={
                f"{workspace}/result/video.mp4": f"-c:v libx264 -vf fps={fps},format=yuv420p -hide_banner -loglevel error -y"
            },
        )
        ff.run()
        # NOTE: known issue -- the key_color background comes out darker in
        # the encoded video; the cause has not been identified.

        # Mux the audio of the source video back in. The trailing "?" makes
        # the audio mapping optional so that videos without an audio track
        # no longer make ffmpeg fail.
        ff2 = FFmpeg(
            inputs={f"{workspace}/result/video.mp4": None, f"{video}": None},
            outputs={
                output_path: "-c:v copy -c:a aac -strict experimental -map 0:v:0 -map 1:a:0? -shortest -hide_banner -loglevel error -y"
            },
        )
        ff2.run()
        # NOTE: a transparent (alpha) video output was tried and abandoned
        # because of compatibility problems.
    finally:
        # remove the intermediate files, on success and on failure alike
        shutil.rmtree(workspace, ignore_errors=True)

    return output_path


# NOTE(review): the two components below are instantiated at module level,
# outside any layout context, and are not passed to the interfaces defined
# further down -- presumably they are never rendered. Confirm, and move them
# into a layout or drop them.
gr.Markdown("## BRIA RMBG 1.4")
gr.HTML(
    """
  <p style="margin-bottom: 10px; font-size: 94%">
    This is a demo for BRIA RMBG 1.4 that using
    <a href="https://huggingface.co/briaai/RMBG-1.4" target="_blank">BRIA RMBG-1.4 image matting model</a> as backbone. 
  </p>
"""
)
# Title, HTML description and example inputs for the image tab.
title = "Background Removal"
description = r"""Background removal model developed by <a href='https://BRIA.AI' target='_blank'><b>BRIA.AI</b></a>, trained on a carefully selected dataset and is available as an open-source model for non-commercial use.<br> 
For test upload your image and wait. Read more at model card <a href='https://huggingface.co/briaai/RMBG-1.4' target='_blank'><b>briaai/RMBG-1.4</b></a>.<br>
"""
examples = [
    ["./input.jpg"],
]

# Title and HTML description for the video tab.
# NOTE(review): the second line of this description still says "upload your
# image" although it is shown on the video tab.
title2 = "Background Removal For Video"
description2 = r"""Background removal model developed by <a href='https://BRIA.AI' target='_blank'><b>BRIA.AI</b></a>, trained on a carefully selected dataset and is available as an open-source model for non-commercial use.<br> 
For test upload your image and wait. Read more at model card <a href='https://huggingface.co/briaai/RMBG-1.4' target='_blank'><b>briaai/RMBG-1.4</b></a>.<br>
Also, you can remove the background from the video.<br>You may have to wait a little longer for the video to process as each frame in video will be processed, so using strong GPU locally is recommended.<br>
"""

# output = ImageSlider(position=0.5,label='Image without background', type="pil", show_download_button=True)
# demo = gr.Interface(fn=process,inputs="image", outputs=output, examples=examples, title=title, description=description)
# Image tab: a single image in, the cut-out image returned by process() out.
demo1 = gr.Interface(
    fn=process,
    inputs="image",
    outputs="image",
    title=title,
    description=description,
    examples=examples,
    api_name="demo1",
)


# Video tab: a video plus a background (key) colour in, and the path returned
# by process_video() out.
demo2 = gr.Interface(
    fn=process_video,
    inputs=[
        gr.Video(label="Video"),
        gr.ColorPicker(label="Key Color(Background color)"),
    ],
    outputs="video",
    title=title2,
    description=description2,
    api_name="demo2",
)

# Combine both interfaces into one app with an "Image" and a "Video" tab.
demo = gr.TabbedInterface(
    interface_list=[demo1, demo2],
    tab_names=["Image", "Video"],
)

# Start the web UI only when this file is run as a script.
if __name__ == "__main__":
    demo.launch(share=False)