import gradio as gr
from gradio_imageslider import ImageSlider
from loadimg import load_img
from transformers import AutoModelForImageSegmentation
import torch
from torchvision import transforms
import os
import zipfile
from PIL import Image

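# Directory where processed PNGs and the batch ZIP archive are saved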
output_folder = 'output_images'
if not os.path.exists(output_folder):
    os.makedirs(output_folder)

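# Allow lower-precision ("high") float32 matrix multiplications for speed where supported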
torch.set_float32_matmul_precision("high")

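# Load the BiRefNet segmentation model from the Hugging Face Hub and keep it on CPU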
try:
    birefnet = AutoModelForImageSegmentation.from_pretrained(
        "ZhengPeng7/BiRefNet", trust_remote_code=True
    )
    birefnet.to("cpu")
except Exception as e:
    print(f"Error loading model: {e}")
    raise

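# Preprocessing: resize to the model's 1024x1024 input, convert to tensor, and
# normalize with ImageNet mean/std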
transform_image = transforms.Compose(
    [
        transforms.Resize((1024, 1024)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)

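# Segment a single uploaded image: predict a foreground mask, apply it as an alpha
# channel, and save the cutout as a PNG.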
def process_single_image(image, output_type="mask"):
    if image is None:
        return [None, None], None
    
    im = load_img(image, output_type="pil")
    if im is None:
        return [None, None], None

    im = im.convert("RGB")
    image_size = im.size
    origin = im.copy()
    input_images = transform_image(im).unsqueeze(0).to("cpu")
    
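    # Run inference without tracking gradients; the last output is the final prediction,
    # mapped to [0, 1] with a sigmoid to serve as a soft mask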
    with torch.no_grad():
        preds = birefnet(input_images)[-1].sigmoid().cpu()
    pred = preds[0].squeeze()
    pred_pil = transforms.ToPILImage()(pred)
    mask = pred_pil.resize(image_size)
    
    processed_im = im.copy()
    processed_im.putalpha(mask)
    output_file_path = os.path.join(output_folder, "output_image_i2i.png")
    processed_im.save(output_file_path)

    if output_type == "origin":
        return [processed_im, origin], output_file_path
    else:
        return [processed_im, mask], output_file_path

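# Same pipeline as process_single_image, but the input image is fetched from a URL.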
def process_image_from_url(url, output_type="mask"):
    if url is None or url.strip() == "":
        return [None, None], None
    
    try:
        im = load_img(url, output_type="pil")
        if im is None:
            return [None, None], None

        im = im.convert("RGB")
        image_size = im.size
        origin = im.copy()
        input_images = transform_image(im).unsqueeze(0).to("cpu")
        
        with torch.no_grad():
            preds = birefnet(input_images)[-1].sigmoid().cpu()
        pred = preds[0].squeeze()
        pred_pil = transforms.ToPILImage()(pred)
        mask = pred_pil.resize(image_size)
        
        processed_im = im.copy()
        processed_im.putalpha(mask)
        output_file_path = os.path.join(output_folder, "output_image_url.png")
        processed_im.save(output_file_path)

        if output_type == "origin":
            return [processed_im, origin], output_file_path
        else:
            return [processed_im, mask], output_file_path
    except Exception as e:
        # Log the error and return empty outputs instead of passing a bogus file path to gr.File
        print(f"Error processing image from URL: {e}")
        return [None, None], None

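# Batch mode: segment every uploaded file, save each cutout as a PNG, and bundle
# all results into a single ZIP archive for download.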
def process_batch_images(images):
    output_paths = []
    if not images:
        return [], None
    
    for idx, image_path in enumerate(images):
        im = load_img(image_path, output_type="pil")
        if im is None:
            continue
        
        im = im.convert("RGB")
        image_size = im.size
        input_images = transform_image(im).unsqueeze(0).to("cpu")
        
        with torch.no_grad():
            preds = birefnet(input_images)[-1].sigmoid().cpu()
        pred = preds[0].squeeze()
        pred_pil = transforms.ToPILImage()(pred)
        mask = pred_pil.resize(image_size)
        
        im.putalpha(mask)
        output_file_path = os.path.join(output_folder, f"output_image_batch_{idx + 1}.png")
        im.save(output_file_path)
        output_paths.append(output_file_path)

    zip_file_path = os.path.join(output_folder, "processed_images.zip")
    with zipfile.ZipFile(zip_file_path, 'w') as zipf:
        for file in output_paths:
            zipf.write(file, os.path.basename(file))

    return output_paths, zip_file_path

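# Gradio input widgets and before/after image sliders for the three tabs below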
image = gr.Image(label="Upload an image")
text = gr.Textbox(label="Paste an image URL")
batch_image = gr.File(label="Upload multiple images", type="filepath", file_count="multiple")

slider1 = ImageSlider(label="Processed Image", type="pil")
slider2 = ImageSlider(label="Processed Image from URL", type="pil")

tab1 = gr.Interface(
    fn=process_single_image, 
    inputs=[image, gr.Radio(choices=["mask", "origin"], value="mask", label="Select Output Type")],
    outputs=[slider1, gr.File(label="PNG Output")],
    examples=[["chameleon.jpg"]],
    api_name="image"
)

tab2 = gr.Interface(
    fn=process_image_from_url, 
    inputs=[text, gr.Radio(choices=["mask", "origin"], value="mask", label="Select Output Type")],
    outputs=[slider2, gr.File(label="PNG Output")],
    examples=[["https://hips.hearstapps.com/hmg-prod/images/gettyimages-1229892983-square.jpg"]], 
    api_name="text"
)

tab3 = gr.Interface(
    fn=process_batch_images, 
    inputs=batch_image, 
    outputs=[gr.Gallery(label="Processed Images"), gr.File(label="Download Processed Files")], 
    api_name="batch"
)

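# Expose the three interfaces as tabs of a single app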
demo = gr.TabbedInterface(
    [tab1, tab2, tab3], 
    ["image", "text", "batch"], 
    title="Multi Birefnet for Background Removal"
)

if __name__ == "__main__":
    demo.launch()
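
# Example client call for the "image" endpoint (a minimal sketch; assumes the app is
# running locally on the default port and that gradio_client is installed):
#
#   from gradio_client import Client, handle_file
#   client = Client("http://127.0.0.1:7860")
#   result = client.predict(handle_file("chameleon.jpg"), "mask", api_name="/image")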