import numpy as np import torch import torch.nn.functional as F from torchvision.transforms.functional import normalize from huggingface_hub import hf_hub_download import gradio as gr from briarmbg import BriaRMBG import PIL from PIL import Image from typing import Tuple from io import BytesIO import base64 import re import os SECRET_TOKEN = os.getenv('SECRET_TOKEN', 'default_secret') # Regex pattern to match data URI scheme data_uri_pattern = re.compile(r'data:image/(png|jpeg|jpg|webp);base64,') def readb64(b64): # Remove any data URI scheme prefix with regex b64 = data_uri_pattern.sub("", b64) # Decode and open the image with PIL img = Image.open(BytesIO(base64.b64decode(b64))) return img # convert from PIL to base64 def writeb64(image): buffered = BytesIO() image.save(buffered, format="PNG") b64image = base64.b64encode(buffered.getvalue()) b64image_str = b64image.decode("utf-8") return b64image_str net=BriaRMBG() model_path = hf_hub_download("briaai/RMBG-1.4", 'model.pth') if torch.cuda.is_available(): net.load_state_dict(torch.load(model_path)) net=net.cuda() else: net.load_state_dict(torch.load(model_path,map_location="cpu")) net.eval() def resize_image(image): image = image.convert('RGB') model_input_size = (1024, 1024) image = image.resize(model_input_size, Image.BILINEAR) return image def process(secret_token, base64_in): if secret_token != SECRET_TOKEN: raise gr.Error( f'Invalid secret token. Please fork the original space if you want to use it for yourself.') orig_image = readb64(base64_in) # prepare input w,h = orig_im_size = orig_image.size image = resize_image(orig_image) im_np = np.array(image) im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2,0,1) im_tensor = torch.unsqueeze(im_tensor,0) im_tensor = torch.divide(im_tensor,255.0) im_tensor = normalize(im_tensor,[0.5,0.5,0.5],[1.0,1.0,1.0]) if torch.cuda.is_available(): im_tensor=im_tensor.cuda() #inference result=net(im_tensor) # post process result = torch.squeeze(F.interpolate(result[0][0], size=(h,w), mode='bilinear') ,0) ma = torch.max(result) mi = torch.min(result) result = (result-mi)/(ma-mi) # image to pil im_array = (result*255).cpu().data.numpy().astype(np.uint8) pil_im = Image.fromarray(np.squeeze(im_array)) # paste the mask on the original image new_im = Image.new("RGBA", pil_im.size, (0,0,0,0)) new_im.paste(orig_image, mask=pil_im) base64_out = writeb64(new_im) return base64_out with gr.Blocks() as demo: secret_token = gr.Text( label='Secret Token', max_lines=1, placeholder='Enter your secret token') gr.HTML("""
This space is a REST API to programmatically remove the background of an image.
Interested in using it? Please use the original space, thank you!