File size: 5,131 Bytes
e617194 11e2057 1ea5bb8 b0637c4 e617194 abf8cc6 e617194 abf8cc6 e617194 abf8cc6 e617194 abf8cc6 e617194 174d21b 1ea8b8b e617194 1ea5bb8 abf8cc6 b0637c4 abf8cc6 e617194 abf8cc6 e617194 abf8cc6 e617194 abf8cc6 e617194 1ea8b8b e617194 1ea5bb8 abf8cc6 e617194 abf8cc6 e617194 abf8cc6 e617194 abf8cc6 1ea5bb8 e617194 abf8cc6 e617194 abf8cc6 e617194 abf8cc6 e617194 abf8cc6 1ea5bb8 e617194 abf8cc6 e617194 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 |
import gradio as gr
from gradio_imageslider import ImageSlider
from loadimg import load_img
from transformers import AutoModelForImageSegmentation
import torch
from torchvision import transforms
import os
import zipfile
from PIL import Image
# Directory that receives every processed PNG and the batch ZIP archive.
output_folder = 'output_images'
# exist_ok=True creates the directory if needed without the race inherent in
# checking os.path.exists() first and then calling os.makedirs().
os.makedirs(output_folder, exist_ok=True)

# "high" permits PyTorch to use faster, lower-internal-precision float32 matmul
# algorithms when available ("highest" would force full float32 precision).
torch.set_float32_matmul_precision("high")

# Load the BiRefNet segmentation model. trust_remote_code=True lets
# transformers execute model code shipped in the "ZhengPeng7/BiRefNet" repo.
# Any failure is reported and then re-raised, so the app never starts without
# a model.
try:
    birefnet = AutoModelForImageSegmentation.from_pretrained(
        "ZhengPeng7/BiRefNet", trust_remote_code=True
    )
    birefnet.to("cpu")
except Exception as e:
    print(f"Error loading model: {e}")
    raise

# Preprocessing applied to every input image:
# - resize to 1024x1024;
# - convert to a float tensor;
# - normalize with the standard ImageNet channel mean/std.
transform_image = transforms.Compose(
    [
        transforms.Resize((1024, 1024)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)
def _remove_background(im):
    """Segment the foreground of *im* with BiRefNet.

    Args:
        im: A PIL image in any mode; it is converted to RGB before inference.

    Returns:
        A tuple ``(rgb, mask, cutout)``:

        - ``rgb`` is the RGB-converted input (never mutated).
        - ``mask`` is the predicted mask as a PIL image resized to the input
          size.
        - ``cutout`` is a copy of ``rgb`` with ``mask`` as its alpha channel.
    """
    rgb = im.convert("RGB")
    input_images = transform_image(rgb).unsqueeze(0).to("cpu")
    with torch.no_grad():
        # The last element of the model output is taken as the prediction;
        # sigmoid squashes its raw values into (0, 1).
        preds = birefnet(input_images)[-1].sigmoid().cpu()
    # Assumes preds[0] squeezes down to a 2-D (H, W) map -- TODO confirm.
    pred = preds[0].squeeze()
    # ToPILImage scales a [0, 1] float tensor to 8-bit. Resizing back to the
    # source resolution lets the result serve as the alpha channel.
    mask = transforms.ToPILImage()(pred).resize(rgb.size)
    cutout = rgb.copy()
    cutout.putalpha(mask)
    return rgb, mask, cutout


def _save_output(image, filename):
    """Save *image* as *filename* inside ``output_folder`` and return the path.

    NOTE: callers use fixed filenames, so each call overwrites the previous
    result with the same name.
    """
    path = os.path.join(output_folder, filename)
    image.save(path)
    return path


def process_single_image(image, output_type="mask"):
    """Remove the background from one uploaded image.

    Args:
        image: Any value ``load_img`` accepts (e.g. the Gradio image payload);
            ``None`` means nothing was uploaded.
        output_type: ``"origin"`` pairs the cutout with the RGB original;
            any other value pairs it with the predicted mask.

    Returns:
        ``([cutout, origin_or_mask], png_path)``, or ``([None, None], None``)
        when there is no image to process.
    """
    if image is None:
        return [None, None], None
    im = load_img(image, output_type="pil")
    if im is None:
        return [None, None], None
    origin, mask, processed_im = _remove_background(im)
    output_file_path = _save_output(processed_im, "output_image_i2i.png")
    companion = origin if output_type == "origin" else mask
    return [processed_im, companion], output_file_path


def process_image_from_url(url, output_type="mask"):
    """Remove the background from an image fetched from *url*.

    Args:
        url: Image URL; ``None`` or a blank string means no input.
        output_type: ``"origin"`` pairs the cutout with the RGB original;
            any other value pairs it with the predicted mask.

    Returns:
        ``([cutout, origin_or_mask], png_path)``, or ``([None, None], None``)
        for empty input or when ``load_img`` yields nothing.

    Raises:
        gr.Error: if loading or processing the image fails. The message is
            shown in the UI, instead of being returned as a bogus value for
            the file output.
    """
    if url is None or not url.strip():
        return [None, None], None
    try:
        im = load_img(url, output_type="pil")
        if im is None:
            return [None, None], None
        origin, mask, processed_im = _remove_background(im)
    except Exception as e:
        # Top-level UI boundary: surface the cause to the user via Gradio.
        raise gr.Error(f"Failed to process image from URL: {e}") from e
    output_file_path = _save_output(processed_im, "output_image_url.png")
    companion = origin if output_type == "origin" else mask
    return [processed_im, companion], output_file_path


def process_batch_images(images):
    """Remove the background from several images and bundle the results.

    Args:
        images: Sequence of image file paths. Entries for which ``load_img``
            returns ``None`` are skipped, but they still consume an index, so
            output numbering follows the input position.

    Returns:
        ``(png_paths, zip_path)``; ``([], None)`` when *images* is empty.
        The ZIP is written even if every entry was skipped.
    """
    if not images:
        return [], None
    output_paths = []
    for position, image_path in enumerate(images, start=1):
        im = load_img(image_path, output_type="pil")
        if im is None:
            continue
        _, _, processed_im = _remove_background(im)
        output_paths.append(
            _save_output(processed_im, f"output_image_batch_{position}.png")
        )
    zip_file_path = os.path.join(output_folder, "processed_images.zip")
    with zipfile.ZipFile(zip_file_path, 'w') as zipf:
        for file in output_paths:
            # Store entries flat, without the output_folder prefix.
            zipf.write(file, os.path.basename(file))
    return output_paths, zip_file_path
# ---- Gradio UI ---------------------------------------------------------------
# Input/output components referenced by the three tabs below.
image = gr.Image(label="Upload an image")
text = gr.Textbox(label="Paste an image URL")
batch_image = gr.File(
    label="Upload multiple images",
    type="filepath",
    file_count="multiple",
)
slider1 = ImageSlider(label="Processed Image", type="pil")
slider2 = ImageSlider(label="Processed Image from URL", type="pil")


def _output_type_selector():
    """Return a new radio choosing what the slider shows next to the cutout."""
    return gr.Radio(
        choices=["mask", "origin"],
        value="mask",
        label="Select Output Type",
    )


# Tab 1: a single uploaded image.
tab1 = gr.Interface(
    fn=process_single_image,
    inputs=[image, _output_type_selector()],
    outputs=[slider1, gr.File(label="PNG Output")],
    examples=[["chameleon.jpg"]],
    api_name="image",
)

# Tab 2: an image fetched from a pasted URL.
tab2 = gr.Interface(
    fn=process_image_from_url,
    inputs=[text, _output_type_selector()],
    outputs=[slider2, gr.File(label="PNG Output")],
    examples=[
        ["https://hips.hearstapps.com/hmg-prod/images/gettyimages-1229892983-square.jpg"]
    ],
    api_name="text",
)

# Tab 3: many uploaded files, returned as a gallery plus a ZIP download.
tab3 = gr.Interface(
    fn=process_batch_images,
    inputs=batch_image,
    outputs=[
        gr.Gallery(label="Processed Images"),
        gr.File(label="Download Processed Files"),
    ],
    api_name="batch",
)

tab_names = ["image", "text", "batch"]
demo = gr.TabbedInterface(
    [tab1, tab2, tab3],
    tab_names,
    title="Multi Birefnet for Background Removal",
)

if __name__ == "__main__":
    demo.launch()
|