# Spaces: Running on Zero
import gc | |
import spaces | |
from safetensors.torch import load_file | |
from autoregressive.models.gpt_t2i import GPT_models | |
from tokenizer.tokenizer_image.vq_model import VQ_models | |
from language.t5 import T5Embedder | |
import torch | |
import numpy as np | |
import PIL | |
from PIL import Image | |
from condition.canny import CannyDetector | |
import time | |
from autoregressive.models.generate import generate | |
from condition.midas.depth import MidasDetector | |
# from controlnet_aux import ( | |
# MidasDetector, | |
# ) | |
# Safetensors checkpoint for each supported condition type, keyed by the
# ``condition_type`` string that ``Model.load_gpt`` receives.
models = dict(
    canny="checkpoints/canny_MR.safetensors",
    depth="checkpoints/depth_MR.safetensors",
)
def resize_image_to_16_multiple(image, condition_type='canny'):
    """Resize ``image`` so both side lengths round up to a grid multiple.

    The depth model requires side lengths that are multiples of 32; every
    other condition type uses multiples of 16.

    Args:
        image: A PIL image, or a ``numpy.ndarray`` which is first converted
            with ``Image.fromarray``.
        condition_type: ``'depth'`` selects the 32-pixel grid; any other value
            selects the 16-pixel grid.

    Returns:
        The result of ``image.resize`` to the rounded-up ``(width, height)``.
    """
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)

    step = 32 if condition_type == 'depth' else 16

    def _ceil_to_step(length):
        # Smallest multiple of ``step`` that is >= ``length``.
        return -(-length // step) * step

    width, height = image.size
    target = (_ceil_to_step(width), _ceil_to_step(height))
    return image.resize(target)
class Model:
    """Controllable text-to-image generation with condition-specific GPT models.

    Construction loads every component:

    * a VQ-16 image tokenizer (``checkpoints/vq_ds16_t2i.pt``),
    * a flan-t5-xl text encoder (``checkpoints/flan-t5-xl``),
    * one GPT-XL per condition type (``'canny'`` and ``'depth'``), with
      checkpoints taken from the module-level ``models`` mapping,
    * a Canny edge detector and a MiDaS depth detector.

    ``process_canny`` and ``process_depth`` each return four PIL images:
    ``[resized input, condition image, sample 0, sample 1]``.
    """

    # Images sampled per request; the prompt and condition are repeated this often.
    _NUM_SAMPLES = 2
    # Pixels per latent token along each side (latent grid is H // 16 x W // 16).
    _DOWNSAMPLE = 16
    # Must equal ``codebook_embed_dim`` passed to the VQ model in ``load_vq``.
    _CODEBOOK_EMBED_DIM = 8
    # Used for both the T5 ``model_max_length`` and the GPT ``cls_token_num``.
    _MAX_TEXT_TOKENS = 120

    def __init__(self):
        """Load the tokenizer, text encoder, both GPT models and the detectors."""
        self.device = torch.device("cuda:0")
        self.base_model_id = ""
        self.task_name = ""
        self.vq_model = self.load_vq()
        self.t5_model = self.load_t5()
        self.gpt_model_canny = self.load_gpt(condition_type='canny')
        self.gpt_model_depth = self.load_gpt(condition_type='depth')
        self.get_control_canny = CannyDetector()
        self.get_control_depth = MidasDetector('cuda')

    def to(self, device):
        """Move both conditional GPT models to ``device``.

        Args:
            device: Any value accepted by ``torch.nn.Module.to`` (e.g. ``'cuda'``).

        Returns:
            ``self``, so the call can be chained like ``torch.nn.Module.to``.
        """
        self.gpt_model_canny.to(device)
        self.gpt_model_depth.to(device)
        # Debug aid: report where the canny adapter weights ended up.
        print(next(self.gpt_model_canny.adapter.parameters()).device)
        return self

    def load_vq(self):
        """Build the VQ-16 tokenizer on ``self.device``, load its weights, set eval mode."""
        vq_model = VQ_models["VQ-16"](
            codebook_size=16384,
            codebook_embed_dim=self._CODEBOOK_EMBED_DIM,
        )
        vq_model.to(self.device)
        vq_model.eval()
        # Deserialize on the CPU; load_state_dict copies into the on-device params.
        checkpoint = torch.load("checkpoints/vq_ds16_t2i.pt", map_location="cpu")
        vq_model.load_state_dict(checkpoint["model"])
        del checkpoint
        print("image tokenizer is loaded")
        return vq_model

    def load_gpt(self, condition_type='canny'):
        """Build a bfloat16 GPT-XL for ``condition_type`` and load its checkpoint.

        Args:
            condition_type: Key into the module-level ``models`` mapping.

        Returns:
            The GPT model in eval mode on ``self.device``.

        Raises:
            KeyError: If ``condition_type`` has no checkpoint in ``models``.
        """
        gpt_ckpt = models[condition_type]
        precision = torch.bfloat16
        latent_size = 768 // self._DOWNSAMPLE
        gpt_model = GPT_models["GPT-XL"](
            block_size=latent_size**2,
            cls_token_num=self._MAX_TEXT_TOKENS,
            model_type='t2i',
            condition_type=condition_type,
        ).to(device=self.device, dtype=precision)
        model_weight = load_file(gpt_ckpt)
        incompatible = gpt_model.load_state_dict(model_weight, strict=False)
        # strict=False tolerates key mismatches; surface them so a wrong or
        # partial checkpoint does not go unnoticed.
        if incompatible.missing_keys or incompatible.unexpected_keys:
            print(f"{condition_type} gpt: {len(incompatible.missing_keys)} missing, "
                  f"{len(incompatible.unexpected_keys)} unexpected state-dict keys")
        del model_weight
        gpt_model.eval()
        print("gpt model is loaded")
        return gpt_model

    def load_t5(self):
        """Create the bfloat16 flan-t5-xl embedder from the local checkpoint cache."""
        precision = torch.bfloat16
        t5_model = T5Embedder(
            device=self.device,
            local_cache=True,
            cache_dir='checkpoints/flan-t5-xl',
            dir_or_name='flan-t5-xl',
            torch_dtype=precision,
            model_max_length=self._MAX_TEXT_TOKENS,
        )
        return t5_model

    @staticmethod
    def _seed_everything(seed):
        """Seed torch's RNGs (all devices) so sampling is reproducible; ``None`` skips."""
        if seed is not None:
            torch.manual_seed(int(seed))

    @staticmethod
    def _left_pad_caption(caption_embs, emb_masks):
        """Move each prompt's valid tokens to the end of the sequence.

        Args:
            caption_embs: Per-prompt text embeddings; assumed ``(batch, seq, dim)``
                with valid tokens first (right padded). TODO confirm against T5Embedder.
            emb_masks: ``(batch, seq)`` mask; each row's sum is taken as the
                number of valid tokens.

        Returns:
            ``(c_indices, c_emb_masks)``: the rotated embeddings with padding
            positions multiplied by zero, and the flipped mask.
        """
        print("processing left-padding...")
        new_emb_masks = torch.flip(emb_masks, dims=[-1])
        rotated = []
        for idx, (caption_emb, emb_mask) in enumerate(zip(caption_embs, emb_masks)):
            valid_num = int(emb_mask.sum().item())
            print(f' prompt {idx} token len: {valid_num}')
            # Rotate so the first ``valid_num`` rows end up at the tail.
            rotated.append(torch.cat([caption_emb[valid_num:], caption_emb[:valid_num]]))
        c_indices = torch.stack(rotated) * new_emb_masks[:, :, None]
        return c_indices, new_emb_masks

    def _encode_prompt(self, prompt):
        """Embed ``prompt`` once per sample and return left-padded (embeddings, mask)."""
        prompts = [prompt] * self._NUM_SAMPLES
        caption_embs, emb_masks = self.t5_model.get_text_embeddings(prompts)
        return self._left_pad_caption(caption_embs, emb_masks)

    def _sample_and_decode(self, gpt_model, c_indices, c_emb_masks, condition_img,
                           height, width, cfg_scale, temperature, top_k, top_p):
        """Sample image tokens with ``gpt_model`` and decode them with the VQ model."""
        latent_h = height // self._DOWNSAMPLE
        latent_w = width // self._DOWNSAMPLE
        qzshape = [len(c_indices), self._CODEBOOK_EMBED_DIM, latent_h, latent_w]

        t1 = time.time()
        index_sample = generate(
            gpt_model,
            c_indices,
            latent_h * latent_w,
            c_emb_masks,
            condition=condition_img,
            cfg_scale=cfg_scale,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            sample_logits=True,
        )
        sampling_time = time.time() - t1
        print(f"Full sampling takes about {sampling_time:.2f} seconds.")

        t2 = time.time()
        print(index_sample.shape)
        samples = self.vq_model.decode_code(index_sample, qzshape)  # values in [-1, 1]
        decoder_time = time.time() - t2
        print(f"decoder takes about {decoder_time:.2f} seconds.")
        return samples

    @staticmethod
    def _to_pil_list(image, condition_img, samples):
        """Return ``[image, condition, *samples]`` with tensors mapped [-1, 1] -> uint8 PIL."""
        # detach/float guard against grad-tracking or bfloat16 tensors, which
        # Tensor.numpy() refuses.
        stacked = torch.cat(
            (condition_img[0:1].detach().float().cpu(), samples.detach().float().cpu()),
            dim=0,
        )
        stacked = 255 * (stacked * 0.5 + 0.5)
        return [image] + [
            Image.fromarray(chw.permute(1, 2, 0).numpy().clip(0, 255).astype(np.uint8))
            for chw in stacked
        ]

    @torch.no_grad()
    def process_canny(
        self,
        image: np.ndarray,
        prompt: str,
        cfg_scale: float,
        temperature: float,
        top_k: int,
        top_p: float,
        seed: int,
        low_threshold: int,
        high_threshold: int,
    ) -> "list[PIL.Image.Image]":
        """Generate images that follow the Canny edges of ``image``.

        Args:
            image: Input picture (array or PIL image); resized up to a
                multiple of 16 pixels per side.
            prompt: Text prompt.
            cfg_scale, temperature, top_k, top_p: Sampling controls forwarded
                to ``generate``.
            seed: Seeds torch's RNGs before sampling (``None`` leaves them as is).
            low_threshold, high_threshold: Forwarded to the Canny detector.

        Returns:
            ``[resized input, edge map, sample 0, sample 1]`` as PIL images.
        """
        self._seed_everything(seed)
        image = resize_image_to_16_multiple(image, 'canny')
        W, H = image.size
        print(W, H)
        self.t5_model.model.to(self.device)
        self.gpt_model_canny.to(self.device)
        self.vq_model.to(self.device)

        edges = self.get_control_canny(np.array(image), low_threshold, high_threshold)
        # 2-D edge map -> (num_samples, 3, H, W), then 2 * (x / 255 - 0.5).
        condition_img = torch.from_numpy(edges[None, None, ...]).repeat(
            self._NUM_SAMPLES, 3, 1, 1)
        condition_img = 2 * (condition_img.to(self.device) / 255 - 0.5)

        c_indices, c_emb_masks = self._encode_prompt(prompt)
        samples = self._sample_and_decode(
            self.gpt_model_canny, c_indices, c_emb_masks, condition_img,
            H, W, cfg_scale, temperature, top_k, top_p)
        result = self._to_pil_list(image, condition_img, samples)
        del condition_img
        torch.cuda.empty_cache()
        return result

    @torch.no_grad()
    def process_depth(
        self,
        image: np.ndarray,
        prompt: str,
        cfg_scale: float,
        temperature: float,
        top_k: int,
        top_p: float,
        seed: int,
    ) -> "list[PIL.Image.Image]":
        """Generate images conditioned on ``image`` with the depth GPT model.

        Args:
            image: Input picture (array or PIL image); resized up to a
                multiple of 32 pixels per side and converted to RGB.
            prompt: Text prompt.
            cfg_scale, temperature, top_k, top_p: Sampling controls forwarded
                to ``generate``.
            seed: Seeds torch's RNGs before sampling (``None`` leaves them as is).

        Returns:
            ``[resized input, condition image, sample 0, sample 1]`` as PIL images.
        """
        self._seed_everything(seed)
        # The pixels become the 3-channel condition below, so force RGB:
        # grayscale (2-D) or RGBA inputs would break the HWC -> CHW permute.
        image = resize_image_to_16_multiple(image, 'depth').convert('RGB')
        W, H = image.size
        print(W, H)
        self.t5_model.model.to(self.device)
        self.gpt_model_depth.to(self.device)
        self.get_control_depth.model.to(self.device)
        self.vq_model.to(self.device)

        # NOTE(review): MiDaS extraction (self.get_control_depth) is not called;
        # the resized input itself is used as the condition. Presumably callers
        # upload a depth map. Confirm before changing.
        image_tensor = torch.from_numpy(np.array(image)).to(self.device)
        condition_img = 2 * (image_tensor / 255 - 0.5)
        print(condition_img.shape)
        condition_img = condition_img.permute(2, 0, 1).unsqueeze(0).repeat(
            self._NUM_SAMPLES, 1, 1, 1)

        c_indices, c_emb_masks = self._encode_prompt(prompt)
        samples = self._sample_and_decode(
            self.gpt_model_depth, c_indices, c_emb_masks, condition_img,
            H, W, cfg_scale, temperature, top_k, top_p)
        result = self._to_pil_list(image, condition_img, samples)
        del image_tensor
        del condition_img
        torch.cuda.empty_cache()
        return result