File size: 3,008 Bytes
44189a1 c71b96e 44189a1 dc78df8 7376db6 dc78df8 44189a1 e4ce1e7 c71b96e e4ce1e7 44189a1 dc78df8 c71b96e 44189a1 dc78df8 44189a1 dc78df8 e4ce1e7 dc78df8 693892f dc78df8 c71b96e 44189a1 7376db6 44189a1 8c25de0 693892f ca61b7a 693892f ca61b7a c71b96e 693892f 73b0806 8c25de0 c71b96e 8c25de0 693892f 8c25de0 693892f 44189a1 693892f 8c25de0 693892f 8c25de0 693892f 44189a1 693892f 44189a1 e4ce1e7 693892f c71b96e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 |
import logging
import torch
import numpy as np
from PIL import Image
from diffusers.utils import check_min_version
from pipeline import LotusGPipeline, LotusDPipeline
from utils.image_utils import colorize_depth_map
from contextlib import nullcontext
import spaces # Import the spaces module for ZeroGPU
# Require at least this diffusers version for the Lotus pipelines.
check_min_version('0.28.0.dev0')

# Module-level slots for the Lotus G and D pipelines; filled in by load_models().
pipe_g, pipe_d = None, None
@spaces.GPU
def load_models(task_name, device):
    """Load the Lotus G and D pipelines for a task into module globals.

    Results are stored in the module-level ``pipe_g`` / ``pipe_d`` rather
    than returned.

    Args:
        task_name: ``'depth'`` selects the depth checkpoints; any other value
            selects the normal checkpoints.
        device: Target passed to ``.to()`` on both pipelines.
    """
    global pipe_g, pipe_d

    # (G checkpoint, D checkpoint); every non-depth task uses the normal models.
    repo_g, repo_d = (
        ('jingheya/lotus-depth-g-v1-0', 'jingheya/lotus-depth-d-v1-1')
        if task_name == 'depth'
        else ('jingheya/lotus-normal-g-v1-0', 'jingheya/lotus-normal-d-v1-0')
    )
    half = torch.float16

    pipe_g = LotusGPipeline.from_pretrained(repo_g, torch_dtype=half)
    pipe_d = LotusDPipeline.from_pretrained(repo_d, torch_dtype=half)

    both = (pipe_g, pipe_d)
    for pipeline in both:
        pipeline.to(device)
    for pipeline in both:
        # Silence per-step progress bars during inference.
        pipeline.set_progress_bar_config(disable=True)

    logging.info(f"Successfully loaded pipelines from {repo_g} and {repo_d}.")
def infer_pipe(pipe, image, task_name, seed, device):
    """Run a single-step Lotus prediction on *image* and return a visualisation.

    Args:
        pipe: A loaded Lotus pipeline, called with ``rgb_in``, ``prompt``,
            ``num_inference_steps``, ``generator``, ``output_type``,
            ``timesteps`` and ``task_emb``.
        image: Input image object exposing ``.convert('RGB')`` (PIL image).
        task_name: ``'depth'`` colorizes the channel-mean of the prediction;
            any other value treats the prediction as a normal map.
        seed: Integer seed for a ``torch.Generator``, or ``None`` for no generator.
        device: Torch device (string or ``torch.device``) to run on.

    Returns:
        For depth, whatever ``colorize_depth_map`` returns; otherwise a PIL
        image built from the prediction.
    """
    generator = None if seed is None else torch.Generator(device=device).manual_seed(seed)

    # Pick autocast from the *target* device. Previously CUDA autocast was
    # requested whenever MPS was unavailable, which warns and is disabled when
    # running on CPU or on a host without CUDA.
    device_type = torch.device(device).type if device is not None else 'cpu'
    if device_type == 'cuda':
        autocast_ctx = torch.autocast(device_type='cuda', dtype=torch.float16)
    else:
        autocast_ctx = nullcontext()

    with torch.no_grad(), autocast_ctx:
        # HWC uint8-range RGB -> 1xCxHxW float16 in [-1, 1].
        img = np.array(image.convert('RGB')).astype(np.float32)
        test_image = torch.tensor(img).permute(2, 0, 1).unsqueeze(0)
        test_image = test_image / 127.5 - 1.0
        test_image = test_image.to(device).type(torch.float16)

        # Task embedding [sin(1), sin(0), cos(1), cos(0)], shape (1, 4).
        task_emb = torch.tensor([1, 0], device=device, dtype=torch.float16).unsqueeze(0)
        task_emb = torch.cat([torch.sin(task_emb), torch.cos(task_emb)], dim=-1)

        # Single denoising step at timestep 999.
        pred = pipe(
            rgb_in=test_image,
            prompt='',
            num_inference_steps=1,
            generator=generator,
            output_type='np',
            timesteps=[999],
            task_emb=task_emb,
        ).images[0]

    if task_name == 'depth':
        output_npy = pred.mean(axis=-1)
        return colorize_depth_map(output_npy)

    # Normal map: assumed to be in [0, 1]. Clip before the uint8 cast so any
    # out-of-range value saturates instead of wrapping around.
    output_npy = np.clip(pred, 0.0, 1.0)
    return Image.fromarray((output_npy * 255).astype(np.uint8))
def lotus(image, task_name, seed, device):
    """Run the loaded Lotus D pipeline on *image* and return its visualisation.

    Uses the module-level ``pipe_d`` set by ``load_models``. ``pipe_g`` is
    loaded alongside it but is not used here.

    Args:
        image: Input PIL image.
        task_name: ``'depth'`` or another value (normal); selects post-processing.
        seed: Integer seed or ``None``.
        device: Torch device to run on.

    Returns:
        The result of ``infer_pipe`` for the D pipeline: a depth or normal
        visualisation depending on ``task_name``.

    Raises:
        RuntimeError: If the pipelines have not been loaded yet.
    """
    # NOTE(review): task_name here only picks post-processing; it is assumed to
    # match the task passed to load_models — confirm at the call site.
    if pipe_d is None:
        raise RuntimeError(
            "Lotus pipelines are not loaded; call load_models(task_name, device) first."
        )
    return infer_pipe(pipe_d, image, task_name, seed, device)