import json
import math
import os
import random
import traceback

import fairscale.nn.model_parallel.initialize as fs_init
import gradio as gr
import numpy as np
import requests
import runpod
import torch
import torch.distributed as dist
from diffusers.models import AutoencoderKL
from safetensors.torch import load_file
from torchvision.transforms.functional import to_pil_image
from transformers import AutoModel, AutoTokenizer

import models
from transport import Sampler, create_transport
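
# One-time worker setup: everything below runs once at container start, so the
# text encoder, VAE, and diffusion transformer stay resident across requests.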
with torch.inference_mode():
    path_type = "Linear"  # ["Linear", "GVP", "VP"]
    prediction = "velocity"  # ["velocity", "score", "noise"]
    loss_weight = None  # [None, "velocity", "likelihood"]
    sample_eps = None
    train_eps = None
    atol = 1e-6
    rtol = 1e-3
    reverse = None
    likelihood = None
    rank = 0
    num_gpus = 1
    ckpt = "/content/Lumina-T2X/models"
    ema = True
    dtype = torch.bfloat16  # ["bf16", "fp32"]

    # Single-process "distributed" setup so fairscale model parallelism
    # initializes cleanly on one GPU.
    os.environ["MASTER_PORT"] = str(8080)
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["RANK"] = str(rank)
    os.environ["WORLD_SIZE"] = str(num_gpus)
    dist.init_process_group("nccl")
    fs_init.initialize_model_parallel(1)
    torch.cuda.set_device(rank)

    train_args = torch.load(os.path.join(ckpt, "model_args.pth"))
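    # Gemma-2B supplies the caption embeddings; the SD VAE decodes latents.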
    text_encoder = AutoModel.from_pretrained("4bit/gemma-2b", torch_dtype=dtype, device_map="cuda").eval()
    cap_feat_dim = text_encoder.config.hidden_size
    tokenizer = AutoTokenizer.from_pretrained("4bit/gemma-2b")
    tokenizer.padding_side = "right"
    vae = AutoencoderKL.from_pretrained(
        f"stabilityai/sd-vae-ft-{train_args.vae}" if train_args.vae != "sdxl" else "stabilityai/sdxl-vae",
        torch_dtype=torch.float32,
    ).cuda()
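    # Build the Lumina DiT named in the training args and load the
    # (optionally EMA) consolidated safetensors shard for this rank.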
    model = models.__dict__[train_args.model](
        qk_norm=train_args.qk_norm,
        cap_feat_dim=cap_feat_dim,
    )
    model.eval().to("cuda", dtype=dtype)
    state_dict = load_file(
        os.path.join(ckpt, f"consolidated{'_ema' if ema else ''}.{rank:02d}-of-{num_gpus:02d}.safetensors"),
        device="cpu",
    )
    model.load_state_dict(state_dict, strict=True)

# Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt
def encode_prompt(prompt_batch, text_encoder, tokenizer, proportion_empty_prompts, is_train=True):
    captions = []
    for caption in prompt_batch:
        if random.random() < proportion_empty_prompts:
            captions.append("")
        elif isinstance(caption, str):
            captions.append(caption)
        elif isinstance(caption, (list, np.ndarray)):
            # take a random caption if there are multiple
            captions.append(random.choice(caption) if is_train else caption[0])
    with torch.no_grad():
        text_inputs = tokenizer(
            captions,
            padding=True,
            pad_to_multiple_of=8,
            max_length=256,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        prompt_masks = text_inputs.attention_mask
        # Use the penultimate hidden layer as the caption feature, as in SDXL.
        prompt_embeds = text_encoder(
            input_ids=text_input_ids.cuda(),
            attention_mask=prompt_masks.cuda(),
            output_hidden_states=True,
        ).hidden_states[-2]
    return prompt_embeds, prompt_masks
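
# RunPod serverless handler. The incoming job payload is expected to look like
# this (illustrative values; the keys are exactly what generate() reads below):
# {"input": {"cap1": "...", "cap2": "...", "cap3": "...", "cap4": "...",
#            "neg_cap": "", "resolution": "2048x1024 (4x1 Grids)",
#            "num_sampling_steps": 30, "cfg_scale": 4.0, "solver": "midpoint",
#            "t_shift": 4, "seed": 0, "scaling_method": "Time-aware",
#            "scaling_watershed": 0.3, "proportional_attn": True,
#            "notify_uri": "...", "notify_token": "...", "discord_id": "...",
#            "discord_channel": "...", "discord_token": "...", "job_id": "..."}}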
def generate(input):
    values = input["input"]
    cap1 = values['cap1']
    cap2 = values['cap2']
    cap3 = values['cap3']
    cap4 = values['cap4']
    neg_cap = values['neg_cap']
    resolution = values['resolution']  # ["2048x1024 (4x1 Grids)", "2560x1024 (4x1 Grids)", "3072x1024 (4x1 Grids)", "1024x1024 (2x2 Grids)", "1536x1536 (2x2 Grids)", "2048x2048 (2x2 Grids)", "1024x2048 (1x4 Grids)", "1024x2560 (1x4 Grids)", "1024x3072 (1x4 Grids)"]
    num_sampling_steps = values['num_sampling_steps']
    cfg_scale = values['cfg_scale']
    solver = values['solver']  # ["euler", "midpoint", "rk4"]
    t_shift = values['t_shift']
    seed = values['seed']
    scaling_method = values['scaling_method']  # ["Time-aware", "None"]
    scaling_watershed = values['scaling_watershed']
    proportional_attn = values['proportional_attn']
    with torch.autocast("cuda", dtype):
        try:
            # begin sampler
            transport = create_transport(
                path_type,
                prediction,
                loss_weight,
                train_eps,
                sample_eps,
            )
            sampler = Sampler(transport)
            sample_fn = sampler.sample_ode(
                sampling_method=solver,
                num_steps=num_sampling_steps,
                atol=atol,
                rtol=rtol,
                reverse=reverse,
                time_shifting_factor=t_shift,
            )
            # end sampler
            # the upstream Lumina demo also exposes "(Extrapolation)" presets
            do_extrapolation = "Extrapolation" in resolution
            # e.g. "2048x1024 (4x1 Grids)" -> grid split "4x1", canvas "2048x1024"
            split = resolution.split(" ")[1].replace("(", "")
            w_split, h_split = split.split("x")
            resolution = resolution.split(" ")[0]
            w, h = resolution.split("x")
            w, h = int(w), int(h)
            latent_w, latent_h = w // 8, h // 8
            if int(seed) != 0:
                torch.random.manual_seed(int(seed))
            # duplicate the latent so the conditional and unconditional passes
            # share one batch for classifier-free guidance
            z = torch.randn([1, 4, latent_h, latent_w], device="cuda").to(dtype)
            z = z.repeat(2, 1, 1, 1)
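            # Lumina compositional generation: cap1..cap4 caption the
            # individual grid cells, the negative (or empty) prompt drives
            # CFG, and the final entry is the global caption for the canvas.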
            cap_list = [cap1, cap2, cap3, cap4]
            global_cap = " ".join(cap_list)
            with torch.no_grad():
                if neg_cap != "":
                    cap_feats, cap_mask = encode_prompt(
                        cap_list + [neg_cap] + [global_cap], text_encoder, tokenizer, 0.0
                    )
                else:
                    cap_feats, cap_mask = encode_prompt(
                        cap_list + [""] + [global_cap], text_encoder, tokenizer, 0.0
                    )
            cap_mask = cap_mask.to(cap_feats.device)
            model_kwargs = dict(
                cap_feats=cap_feats[:-1],
                cap_mask=cap_mask[:-1],
                global_cap_feats=cap_feats[-1:],
                global_cap_mask=cap_mask[-1:],
                cfg_scale=cfg_scale,
                h_split_num=int(h_split),
                w_split_num=int(w_split),
            )
            if proportional_attn:
                model_kwargs["proportional_attn"] = True
                model_kwargs["base_seqlen"] = (train_args.image_size // 16) ** 2
            else:
                model_kwargs["proportional_attn"] = False
                model_kwargs["base_seqlen"] = None
            if do_extrapolation and scaling_method == "Time-aware":
                # scale factor is the ratio of the target canvas area to the
                # training resolution, used for time-aware scaling when
                # extrapolating beyond the training size
                model_kwargs["scale_factor"] = math.sqrt(w * h / train_args.image_size**2)
                model_kwargs["scale_watershed"] = scaling_watershed
            else:
                model_kwargs["scale_factor"] = 1.0
                model_kwargs["scale_watershed"] = 1.0
            samples = sample_fn(z, model.forward_with_cfg, **model_kwargs)[-1]
            samples = samples[:1]  # keep the conditional half of the CFG batch
            factor = 0.18215 if train_args.vae != "sdxl" else 0.13025
            samples = vae.decode(samples / factor).sample
            samples = (samples + 1.0) / 2.0
            samples.clamp_(0.0, 1.0)
            img = to_pil_image(samples[0].float())
            img.save("/content/out.png")
        except Exception:
            print(traceback.format_exc())
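
    # Deliver the result: upload the image to Discord, then notify the web
    # backend (and, if a custom notify_uri was supplied, that endpoint too).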
    result = "/content/out.png"
    job_id = values.get('job_id')  # defined up front so the except path can reference it
    try:
        notify_uri = values.pop('notify_uri')
        notify_token = values.pop('notify_token')
        discord_id = values.pop('discord_id')
        if discord_id == "discord_id":
            discord_id = os.getenv('com_camenduru_discord_id')
        discord_channel = values.pop('discord_channel')
        if discord_channel == "discord_channel":
            discord_channel = os.getenv('com_camenduru_discord_channel')
        discord_token = values.pop('discord_token')
        if discord_token == "discord_token":
            discord_token = os.getenv('com_camenduru_discord_token')
        job_id = values.pop('job_id')
        default_filename = os.path.basename(result)
        with open(result, "rb") as file:
            files = {default_filename: file.read()}
        payload = {"content": f"{json.dumps(values)} <@{discord_id}>"}
        response = requests.post(
            f"https://discord.com/api/v9/channels/{discord_channel}/messages",
            data=payload,
            headers={"Authorization": f"Bot {discord_token}"},
            files=files,
        )
        response.raise_for_status()
        result_url = response.json()['attachments'][0]['url']
        notify_payload = {"jobId": job_id, "result": result_url, "status": "DONE"}
        web_notify_uri = os.getenv('com_camenduru_web_notify_uri')
        web_notify_token = os.getenv('com_camenduru_web_notify_token')
        if notify_uri == "notify_uri":
            requests.post(web_notify_uri, data=json.dumps(notify_payload), headers={'Content-Type': 'application/json', "Authorization": web_notify_token})
        else:
            requests.post(web_notify_uri, data=json.dumps(notify_payload), headers={'Content-Type': 'application/json', "Authorization": web_notify_token})
            requests.post(notify_uri, data=json.dumps(notify_payload), headers={'Content-Type': 'application/json', "Authorization": notify_token})
        return {"jobId": job_id, "result": result_url, "status": "DONE"}
    except Exception as e:
        error_payload = {"jobId": job_id, "status": "FAILED"}
        try:
            if notify_uri == "notify_uri":
                requests.post(web_notify_uri, data=json.dumps(error_payload), headers={'Content-Type': 'application/json', "Authorization": web_notify_token})
            else:
                requests.post(web_notify_uri, data=json.dumps(error_payload), headers={'Content-Type': 'application/json', "Authorization": web_notify_token})
                requests.post(notify_uri, data=json.dumps(error_payload), headers={'Content-Type': 'application/json', "Authorization": notify_token})
        except Exception:
            pass
        return {"jobId": job_id, "result": f"FAILED: {str(e)}", "status": "FAILED"}
    finally:
        if os.path.exists(result):
            os.remove(result)
runpod.serverless.start({"handler": generate})
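
# Local smoke test (a sketch, not part of the deployed worker): comment out the
# start() call above and invoke the handler directly with a payload shaped like
# the example documented above generate(). All values here are illustrative.
# print(generate({"input": {"cap1": "a red fox", "cap2": "a snowy forest", ...}}))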