|
import logging |
|
import torch |
|
import numpy as np |
|
from PIL import Image |
|
from diffusers.utils import check_min_version |
|
from pipeline import LotusGPipeline, LotusDPipeline |
|
from utils.image_utils import colorize_depth_map |
|
from contextlib import nullcontext |
|
import spaces |
|
|
|
# Require diffusers >= 0.28.0.dev0 before anything else runs
# (check_min_version is expected to raise on older installs -- confirm
# against the diffusers docs).
check_min_version('0.28.0.dev0')

# Module-level pipeline handles. Both start out empty and are filled in by
# load_models(); the inference helpers below read them from here.
pipe_g = pipe_d = None
|
|
|
@spaces.GPU
def load_models(task_name, device):
    """Load both Lotus pipelines for a task into the module globals.

    The loaded pipelines are stored in the module-level ``pipe_g`` and
    ``pipe_d`` variables, replacing whatever was there before.

    Args:
        task_name: ``'depth'`` selects the depth checkpoints; any other
            value selects the normal checkpoints.
        device: Target passed to each pipeline's ``.to()`` call.

    Returns:
        None. The result is published through ``pipe_g`` / ``pipe_d``.
    """
    global pipe_g, pipe_d

    # Pick the (G, D) checkpoint pair for the requested task.
    repo_g, repo_d = (
        ('jingheya/lotus-depth-g-v1-0', 'jingheya/lotus-depth-d-v1-1')
        if task_name == 'depth'
        else ('jingheya/lotus-normal-g-v1-0', 'jingheya/lotus-normal-d-v1-0')
    )

    # Both pipelines are loaded with half-precision weights.
    half = torch.float16
    pipe_g = LotusGPipeline.from_pretrained(repo_g, torch_dtype=half)
    pipe_d = LotusDPipeline.from_pretrained(repo_d, torch_dtype=half)

    loaded = (pipe_g, pipe_d)
    # Move both pipelines first, then silence their progress bars, keeping
    # the same call order as a straight-line sequence of the four calls.
    for pipeline in loaded:
        pipeline.to(device)
    for pipeline in loaded:
        pipeline.set_progress_bar_config(disable=True)

    logging.info(f"Successfully loaded pipelines from {repo_g} and {repo_d}.")
|
|
|
def infer_pipe(pipe, image, task_name, seed, device):
    """Run one single-step prediction with ``pipe`` and build a visualization.

    Args:
        pipe: Callable pipeline. It is invoked with the keyword arguments
            ``rgb_in``, ``prompt``, ``num_inference_steps``, ``generator``,
            ``output_type``, ``timesteps`` and ``task_emb``, and its result
            must expose ``.images``.
        image: Input image exposing ``.convert('RGB')`` (presumably a
            ``PIL.Image.Image`` -- confirm against the caller).
        task_name: ``'depth'`` produces a colorized depth map; any other
            value is rendered directly as an 8-bit image.
        seed: Seed for a ``torch.Generator`` created on ``device``, or
            ``None`` to run without an explicit generator.
        device: Device used for the generator, the input tensor and the
            task embedding.

    Returns:
        For ``'depth'``, whatever ``colorize_depth_map`` returns; otherwise
        a ``PIL.Image.Image`` built from the prediction scaled by 255.
    """
    # No seed means no explicit generator is handed to the pipeline.
    rng = None if seed is None else torch.Generator(device=device).manual_seed(seed)

    # Autocast is skipped whenever an MPS backend is available; otherwise a
    # CUDA float16 autocast context is used.
    # NOTE(review): this choice does not look at ``device`` -- confirm that
    # is intended for CPU-only hosts.
    if torch.backends.mps.is_available():
        amp_ctx = nullcontext()
    else:
        amp_ctx = torch.autocast(device_type='cuda', dtype=torch.float16)

    with torch.no_grad(), amp_ctx:
        # Image -> float32 array -> channels-first tensor with a leading
        # batch axis, rescaled from [0, 255] to [-1, 1] and cast to half.
        rgb = np.array(image.convert('RGB')).astype(np.float32)
        batch = torch.tensor(rgb).permute(2, 0, 1).unsqueeze(0)
        batch = batch / 127.5 - 1.0
        batch = batch.to(device).type(torch.float16)

        # Task embedding: sin/cos of the fixed [1, 0] switch, concatenated
        # along the last axis.
        switch = torch.tensor([1, 0], device=device, dtype=torch.float16).unsqueeze(0)
        task_emb = torch.cat((torch.sin(switch), torch.cos(switch)), dim=-1)

        # Single inference step at timestep 999 with an empty prompt.
        result = pipe(
            rgb_in=batch,
            prompt='',
            num_inference_steps=1,
            generator=rng,
            output_type='np',
            timesteps=[999],
            task_emb=task_emb,
        )
        pred = result.images[0]

        if task_name != 'depth':
            # Non-depth output is scaled to 0..255 and wrapped as an image.
            return Image.fromarray((pred * 255).astype(np.uint8))

        # Depth output: average over the last axis, then colorize.
        return colorize_depth_map(pred.mean(axis=-1))
|
|
|
def lotus(image, task_name, seed, device):
    """Run Lotus inference on ``image`` using the loaded ``pipe_d`` pipeline.

    Only the module-level ``pipe_d`` is used; ``pipe_g`` is not consulted
    here.

    Args:
        image: Input image, forwarded unchanged to ``infer_pipe``.
        task_name: Task selector, forwarded unchanged to ``infer_pipe``.
        seed: Seed or ``None``, forwarded unchanged to ``infer_pipe``.
        device: Device, forwarded unchanged to ``infer_pipe``.

    Returns:
        The value returned by ``infer_pipe`` for ``pipe_d``.

    Raises:
        RuntimeError: If ``pipe_d`` is still ``None``, i.e. ``load_models``
            has not populated the module-level pipelines yet.
    """
    # Fail early with an actionable message instead of letting a ``None``
    # pipeline surface later as "'NoneType' object is not callable".
    # Reading a module global needs no ``global`` declaration.
    if pipe_d is None:
        raise RuntimeError(
            "Lotus pipelines are not loaded; call load_models(task_name, device) first."
        )
    return infer_pipe(pipe_d, image, task_name, seed, device)