Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -120,6 +120,7 @@ unet = UNet2DConditionModel.from_pretrained(sd15_name, subfolder="unet")
|
|
120 |
# Load model directly
|
121 |
from transformers import AutoModelForImageSegmentation
|
122 |
rmbg = AutoModelForImageSegmentation.from_pretrained("briaai/RMBG-1.4", trust_remote_code=True)
|
|
|
123 |
|
124 |
model = DepthAnythingV2(encoder='vits', features=64, out_channels=[48, 96, 192, 384])
|
125 |
model.load_state_dict(torch.load('checkpoints/depth_anything_v2_vits.pth', map_location=device))
|
@@ -128,6 +129,7 @@ model.eval()
|
|
128 |
|
129 |
# Change UNet
|
130 |
|
|
|
131 |
with torch.no_grad():
|
132 |
new_conv_in = torch.nn.Conv2d(8, unet.conv_in.out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding)
|
133 |
new_conv_in.weight.zero_()
|
@@ -314,7 +316,7 @@ def encode_prompt_pair(positive_prompt, negative_prompt):
|
|
314 |
|
315 |
return c, uc
|
316 |
|
317 |
-
|
318 |
@torch.inference_mode()
|
319 |
def pytorch2numpy(imgs, quant=True):
|
320 |
results = []
|
@@ -331,7 +333,7 @@ def pytorch2numpy(imgs, quant=True):
|
|
331 |
results.append(y)
|
332 |
return results
|
333 |
|
334 |
-
|
335 |
@torch.inference_mode()
|
336 |
def numpy2pytorch(imgs):
|
337 |
h = torch.from_numpy(np.stack(imgs, axis=0)).float() / 127.0 - 1.0 # so that 127 must be strictly 0.0
|
@@ -359,7 +361,7 @@ def resize_without_crop(image, target_width, target_height):
|
|
359 |
resized_image = pil_image.resize((target_width, target_height), Image.LANCZOS)
|
360 |
return np.array(resized_image)
|
361 |
|
362 |
-
|
363 |
@torch.inference_mode()
|
364 |
def run_rmbg(img, sigma=0.0):
|
365 |
# Convert RGBA to RGB if needed
|
@@ -384,6 +386,8 @@ def run_rmbg(img, sigma=0.0):
|
|
384 |
rgba = np.dstack((img, alpha * 255)).astype(np.uint8)
|
385 |
result = 127 + (img.astype(np.float32) - 127 + sigma) * alpha
|
386 |
return result.clip(0, 255).astype(np.uint8), rgba
|
|
|
|
|
387 |
@torch.inference_mode()
|
388 |
def process(input_fg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source):
|
389 |
clear_memory()
|
|
|
120 |
# Load model directly
|
121 |
from transformers import AutoModelForImageSegmentation
|
122 |
rmbg = AutoModelForImageSegmentation.from_pretrained("briaai/RMBG-1.4", trust_remote_code=True)
|
123 |
+
rmbg = rmbg.to(device=device, dtype=torch.float32) # Keep this as float32
|
124 |
|
125 |
model = DepthAnythingV2(encoder='vits', features=64, out_channels=[48, 96, 192, 384])
|
126 |
model.load_state_dict(torch.load('checkpoints/depth_anything_v2_vits.pth', map_location=device))
|
|
|
129 |
|
130 |
# Change UNet
|
131 |
|
132 |
+
|
133 |
with torch.no_grad():
|
134 |
new_conv_in = torch.nn.Conv2d(8, unet.conv_in.out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding)
|
135 |
new_conv_in.weight.zero_()
|
|
|
316 |
|
317 |
return c, uc
|
318 |
|
319 |
+
@spaces.GPU(duration=60)
|
320 |
@torch.inference_mode()
|
321 |
def pytorch2numpy(imgs, quant=True):
|
322 |
results = []
|
|
|
333 |
results.append(y)
|
334 |
return results
|
335 |
|
336 |
+
@spaces.GPU(duration=60)
|
337 |
@torch.inference_mode()
|
338 |
def numpy2pytorch(imgs):
|
339 |
h = torch.from_numpy(np.stack(imgs, axis=0)).float() / 127.0 - 1.0 # so that 127 must be strictly 0.0
|
|
|
361 |
resized_image = pil_image.resize((target_width, target_height), Image.LANCZOS)
|
362 |
return np.array(resized_image)
|
363 |
|
364 |
+
@spaces.GPU(duration=60)
|
365 |
@torch.inference_mode()
|
366 |
def run_rmbg(img, sigma=0.0):
|
367 |
# Convert RGBA to RGB if needed
|
|
|
386 |
rgba = np.dstack((img, alpha * 255)).astype(np.uint8)
|
387 |
result = 127 + (img.astype(np.float32) - 127 + sigma) * alpha
|
388 |
return result.clip(0, 255).astype(np.uint8), rgba
|
389 |
+
|
390 |
+
@spaces.GPU(duration=60)
|
391 |
@torch.inference_mode()
|
392 |
def process(input_fg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source):
|
393 |
clear_memory()
|