NegiTurkey commited on
Commit
b0637c4
1 Parent(s): c2e0a3f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -9
app.py CHANGED
@@ -6,13 +6,15 @@ import torch
6
  from torchvision import transforms
7
  import os
8
  import zipfile
 
 
9
 
10
  torch.set_float32_matmul_precision(["high", "highest"][0])
11
 
12
  birefnet = AutoModelForImageSegmentation.from_pretrained(
13
  "ZhengPeng7/BiRefNet", trust_remote_code=True
14
  )
15
- birefnet.to("cpu")
16
  transform_image = transforms.Compose(
17
  [
18
  transforms.Resize((1024, 1024)),
@@ -25,7 +27,7 @@ def fn(image):
25
  im = load_img(image, output_type="pil")
26
  im = im.convert("RGB")
27
  image_size = im.size
28
- input_images = transform_image(im).unsqueeze(0).to("cpu")
29
 
30
  with torch.no_grad():
31
  preds = birefnet(input_images)[-1].sigmoid().cpu()
@@ -37,13 +39,16 @@ def fn(image):
37
  output_file_path = os.path.join("output_images", "output_image_single.png")
38
  im.save(output_file_path)
39
 
40
- return [mask, im]
 
 
 
41
 
42
  def fn_url(url):
43
  im = load_img(url, output_type="pil")
44
  im = im.convert("RGB")
45
  image_size = im.size
46
- input_images = transform_image(im).unsqueeze(0).to("cpu")
47
 
48
  with torch.no_grad():
49
  preds = birefnet(input_images)[-1].sigmoid().cpu()
@@ -55,7 +60,10 @@ def fn_url(url):
55
  output_file_path = os.path.join("output_images", "output_image_url.png")
56
  im.save(output_file_path)
57
 
58
- return [mask, im]
 
 
 
59
 
60
  def batch_fn(images):
61
  output_paths = []
@@ -63,7 +71,7 @@ def batch_fn(images):
63
  im = load_img(image_path, output_type="pil")
64
  im = im.convert("RGB")
65
  image_size = im.size
66
- input_images = transform_image(im).unsqueeze(0).to("cpu")
67
 
68
  with torch.no_grad():
69
  preds = birefnet(input_images)[-1].sigmoid().cpu()
@@ -71,7 +79,7 @@ def batch_fn(images):
71
  pred_pil = transforms.ToPILImage()(pred)
72
  mask = pred_pil.resize(image_size)
73
 
74
- im.putalpha(mask)
75
 
76
  output_file_path = os.path.join("output_images", f"output_image_batch_{idx + 1}.png")
77
  im.save(output_file_path)
@@ -95,10 +103,10 @@ chameleon = load_img("chameleon.jpg", output_type="pil")
95
  url = "https://hips.hearstapps.com/hmg-prod/images/gettyimages-1229892983-square.jpg"
96
 
97
  tab1 = gr.Interface(
98
- fn, inputs=image, outputs=slider1, examples=[chameleon], api_name="image"
99
  )
100
 
101
- tab2 = gr.Interface(fn_url, inputs=text, outputs=slider2, examples=[url], api_name="text")
102
 
103
  tab3 = gr.Interface(
104
  batch_fn,
 
6
  from torchvision import transforms
7
  import os
8
  import zipfile
9
+ import numpy as np
10
+ from PIL import Image
11
 
12
  torch.set_float32_matmul_precision(["high", "highest"][0])
13
 
14
  birefnet = AutoModelForImageSegmentation.from_pretrained(
15
  "ZhengPeng7/BiRefNet", trust_remote_code=True
16
  )
17
+ birefnet.to("cuda")
18
  transform_image = transforms.Compose(
19
  [
20
  transforms.Resize((1024, 1024)),
 
27
  im = load_img(image, output_type="pil")
28
  im = im.convert("RGB")
29
  image_size = im.size
30
+ input_images = transform_image(im).unsqueeze(0).to("cuda")
31
 
32
  with torch.no_grad():
33
  preds = birefnet(input_images)[-1].sigmoid().cpu()
 
39
  output_file_path = os.path.join("output_images", "output_image_single.png")
40
  im.save(output_file_path)
41
 
42
+ output_path = os.path.join("output_images", "output_image_processed.png")
43
+ im.save(output_path, "PNG")
44
+
45
+ return [im, mask], output_path
46
 
47
  def fn_url(url):
48
  im = load_img(url, output_type="pil")
49
  im = im.convert("RGB")
50
  image_size = im.size
51
+ input_images = transform_image(im).unsqueeze(0).to("cuda")
52
 
53
  with torch.no_grad():
54
  preds = birefnet(input_images)[-1].sigmoid().cpu()
 
60
  output_file_path = os.path.join("output_images", "output_image_url.png")
61
  im.save(output_file_path)
62
 
63
+ output_path = os.path.join("output_images", "output_image_url_processed.png")
64
+ im.save(output_path, "PNG")
65
+
66
+ return [im, mask], output_path
67
 
68
  def batch_fn(images):
69
  output_paths = []
 
71
  im = load_img(image_path, output_type="pil")
72
  im = im.convert("RGB")
73
  image_size = im.size
74
+ input_images = transform_image(im).unsqueeze(0).to("cuda")
75
 
76
  with torch.no_grad():
77
  preds = birefnet(input_images)[-1].sigmoid().cpu()
 
79
  pred_pil = transforms.ToPILImage()(pred)
80
  mask = pred_pil.resize(image_size)
81
 
82
+ im.putalpha(mask)
83
 
84
  output_file_path = os.path.join("output_images", f"output_image_batch_{idx + 1}.png")
85
  im.save(output_file_path)
 
103
  url = "https://hips.hearstapps.com/hmg-prod/images/gettyimages-1229892983-square.jpg"
104
 
105
  tab1 = gr.Interface(
106
+ fn, inputs=image, outputs=[slider1, gr.File(label="PNG Output")], examples=[chameleon], api_name="image"
107
  )
108
 
109
+ tab2 = gr.Interface(fn_url, inputs=text, outputs=[slider2, gr.File(label="PNG Output")], examples=[url], api_name="text")
110
 
111
  tab3 = gr.Interface(
112
  batch_fn,