lumina-next-compositional / worker_runpod.py
camenduru's picture
Create worker_runpod.py
20b1f60 verified
raw
history blame
No virus
8.85 kB
import os, json, requests, runpod
import math
import random
import traceback
import fairscale.nn.model_parallel.initialize as fs_init
import gradio as gr
import numpy as np
from safetensors.torch import load_file
import torch
import torch.distributed as dist
from torchvision.transforms.functional import to_pil_image
import models
from transport import Sampler, create_transport
from diffusers.models import AutoencoderKL
from transformers import AutoModel, AutoTokenizer
discord_token = os.getenv('com_camenduru_discord_token')
web_uri = os.getenv('com_camenduru_web_uri')
web_token = os.getenv('com_camenduru_web_token')
with torch.inference_mode():
path_type = "Linear" # ["Linear", "GVP", "VP"]
prediction = "velocity" # ["velocity", "score", "noise"]
loss_weight = None # [None, "velocity", "likelihood"]
sample_eps = None
train_eps = None
atol = 1e-6
rtol = 1e-3
reverse = None
likelihood = None
rank = 0
num_gpus = 1
ckpt = "/content/Lumina-T2X/models"
ema = True
dtype = torch.bfloat16 #["bf16", "fp32"]
os.environ["MASTER_PORT"] = str(8080)
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["RANK"] = str(rank)
os.environ["WORLD_SIZE"] = str(num_gpus)
dist.init_process_group("nccl")
fs_init.initialize_model_parallel(1)
torch.cuda.set_device(rank)
train_args = torch.load(os.path.join(ckpt, "model_args.pth"))
text_encoder = AutoModel.from_pretrained("4bit/gemma-2b", torch_dtype=dtype, device_map="cuda").eval()
cap_feat_dim = text_encoder.config.hidden_size
tokenizer = AutoTokenizer.from_pretrained("4bit/gemma-2b")
tokenizer.padding_side = "right"
vae = AutoencoderKL.from_pretrained((f"stabilityai/sd-vae-ft-{train_args.vae}" if train_args.vae != "sdxl" else "stabilityai/sdxl-vae"), torch_dtype=torch.float32).cuda()
model = models.__dict__[train_args.model](
qk_norm=train_args.qk_norm,
cap_feat_dim=cap_feat_dim,
)
model.eval().to("cuda", dtype=dtype)
ckpt = load_file(os.path.join(ckpt, f"consolidated{'_ema' if ema else ''}.{rank:02d}-of-{num_gpus:02d}.safetensors"), device="cpu",)
model.load_state_dict(ckpt, strict=True)
# Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt
def encode_prompt(prompt_batch, text_encoder, tokenizer, proportion_empty_prompts, is_train=True):
captions = []
for caption in prompt_batch:
if random.random() < proportion_empty_prompts:
captions.append("")
elif isinstance(caption, str):
captions.append(caption)
elif isinstance(caption, (list, np.ndarray)):
# take a random caption if there are multiple
captions.append(random.choice(caption) if is_train else caption[0])
with torch.no_grad():
text_inputs = tokenizer(
captions,
padding=True,
pad_to_multiple_of=8,
max_length=256,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
prompt_masks = text_inputs.attention_mask
prompt_embeds = text_encoder(
input_ids=text_input_ids.cuda(),
attention_mask=prompt_masks.cuda(),
output_hidden_states=True,
).hidden_states[-2]
return prompt_embeds, prompt_masks
@torch.inference_mode()
def generate(input):
values = input["input"]
cap1 = values['cap1']
cap2 = values['cap2']
cap3 = values['cap3']
cap4 = values['cap4']
neg_cap = values['neg_cap']
resolution = values['resolution'] # ["2048x1024 (4x1 Grids)","2560x1024 (4x1 Grids)","3072x1024 (4x1 Grids)","1024x1024 (2x2 Grids)","1536x1536 (2x2 Grids)","2048x2048 (2x2 Grids)","1024x2048 (1x4 Grids)","1024x2560 (1x4 Grids)","1024x3072 (1x4 Grids)",]
num_sampling_steps = values['num_sampling_steps']
cfg_scale = values['cfg_scale']
solver = values['solver'] # ["euler", "midpoint", "rk4"]
t_shift = values['t_shift']
seed = values['seed']
scaling_method = values['scaling_method'] # ["Time-aware", "None"]
scaling_watershed = values['scaling_watershed']
proportional_attn = values['proportional_attn']
with torch.autocast("cuda", dtype):
try:
# begin sampler
transport = create_transport(
path_type,
prediction,
loss_weight,
train_eps,
sample_eps,
)
sampler = Sampler(transport)
sample_fn = sampler.sample_ode(
sampling_method=solver,
num_steps=num_sampling_steps,
atol=atol,
rtol=rtol,
reverse=reverse,
time_shifting_factor=t_shift,
)
# end sampler
do_extrapolation = "Extrapolation" in resolution
split = resolution.split(" ")[1].replace("(", "")
w_split, h_split = split.split("x")
resolution = resolution.split(" ")[0]
w, h = resolution.split("x")
w, h = int(w), int(h)
latent_w, latent_h = w // 8, h // 8
if int(seed) != 0:
torch.random.manual_seed(int(seed))
z = torch.randn([1, 4, latent_h, latent_w], device="cuda").to(dtype)
z = z.repeat(2, 1, 1, 1)
cap_list = [cap1, cap2, cap3, cap4]
global_cap = " ".join(cap_list)
with torch.no_grad():
if neg_cap != "":
cap_feats, cap_mask = encode_prompt(
cap_list + [neg_cap] + [global_cap], text_encoder, tokenizer, 0.0
)
else:
cap_feats, cap_mask = encode_prompt(
cap_list + [""] + [global_cap], text_encoder, tokenizer, 0.0
)
cap_mask = cap_mask.to(cap_feats.device)
model_kwargs = dict(
cap_feats=cap_feats[:-1],
cap_mask=cap_mask[:-1],
global_cap_feats=cap_feats[-1:],
global_cap_mask=cap_mask[-1:],
cfg_scale=cfg_scale,
h_split_num=int(h_split),
w_split_num=int(w_split),
)
if proportional_attn:
model_kwargs["proportional_attn"] = True
model_kwargs["base_seqlen"] = (train_args.image_size // 16) ** 2
else:
model_kwargs["proportional_attn"] = False
model_kwargs["base_seqlen"] = None
if do_extrapolation and scaling_method == "Time-aware":
model_kwargs["scale_factor"] = math.sqrt(w * h / train_args.image_size**2)
model_kwargs["scale_watershed"] = scaling_watershed
else:
model_kwargs["scale_factor"] = 1.0
model_kwargs["scale_watershed"] = 1.0
samples = sample_fn(z, model.forward_with_cfg, **model_kwargs)[-1]
samples = samples[:1]
factor = 0.18215 if train_args.vae != "sdxl" else 0.13025
samples = vae.decode(samples / factor).sample
samples = (samples + 1.0) / 2.0
samples.clamp_(0.0, 1.0)
img = to_pil_image(samples[0].float())
except Exception:
print(traceback.format_exc())
result = img
response = None
try:
source_id = values['source_id']
del values['source_id']
source_channel = values['source_channel']
del values['source_channel']
job_id = values['job_id']
del values['job_id']
default_filename = os.path.basename(result)
files = {default_filename: open(result, "rb").read()}
payload = {"content": f"{json.dumps(values)} <@{source_id}>"}
response = requests.post(
f"https://discord.com/api/v9/channels/{source_channel}/messages",
data=payload,
headers={"authorization": f"Bot {discord_token}"},
files=files
)
response.raise_for_status()
except Exception as e:
print(f"An unexpected error occurred: {e}")
finally:
if os.path.exists(result):
os.remove(result)
if response and response.status_code == 200:
try:
payload = {"jobId": job_id, "result": response.json()['attachments'][0]['url']}
requests.post(f"{web_uri}/api/notify", data=json.dumps(payload), headers={'Content-Type': 'application/json', "authorization": f"{web_token}"})
except Exception as e:
print(f"An unexpected error occurred: {e}")
finally:
return {"result": response.json()['attachments'][0]['url']}
else:
return {"result": "ERROR"}
runpod.serverless.start({"handler": generate})