from diffusers import DiffusionPipeline
import torch
import numpy as np
import os
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from torchvision.utils import save_image
from PIL import Image

from .vae import AutoencoderKL
from .mar import mar_base, mar_large, mar_huge
# inherit from DiffusionPipeline so the model can be loaded and shared via the HF Hub
class MARModel(DiffusionPipeline):
def __init__(self):
super().__init__()
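        # NOTE: the model and vae are constructed lazily in __call__ from kwargs,
        # so there is nothing to register here.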
@torch.no_grad()
def __call__(self, *args, **kwargs):
"""
        This method downloads the MAR model and VAE checkpoints,
        then runs class-conditional sampling based on the user's kwargs.
"""
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # init the mar model architecture; diffloss depth/width are fixed per
        # variant below to match the downloaded checkpoints, so only the knobs
        # that are safe to vary are read from kwargs
        buffer_size = kwargs.get("buffer_size", 64)
        num_sampling_steps = kwargs.get("num_sampling_steps", 100)
        model_type = kwargs.get("model_type", "mar_base")
model_mapping = {
"mar_base": mar_base,
"mar_large": mar_large,
"mar_huge": mar_huge
}
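        # maps the variant name to its constructor from .mar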
# download the pretrained model and set diffloss parameters
if model_type == "mar_base":
diffloss_d = 6
diffloss_w = 1024
model_path = "mar-base.safetensors"
elif model_type == "mar_large":
diffloss_d = 8
diffloss_w = 1280
model_path = "mar-large.safetensors"
elif model_type == "mar_huge":
diffloss_d = 12
diffloss_w = 1536
model_path = "mar-huge.safetensors"
        else:
            raise ValueError(f"Unknown model_type: {model_type!r}")
        # download the pretrained model weights (.safetensors) from the hub
model_checkpoint_path = hf_hub_download(
repo_id=kwargs.get("repo_id", "jadechoghari/mar"),
filename=kwargs.get("model_filename", model_path)
)
model_fn = model_mapping[model_type]
        model = model_fn(
            buffer_size=buffer_size,
            diffloss_d=diffloss_d,
            diffloss_w=diffloss_w,
            num_sampling_steps=str(num_sampling_steps)
        ).to(device)
# use safetensors
state_dict = load_file(model_checkpoint_path)
model.load_state_dict(state_dict)
model.eval()
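        # load_state_dict is strict by default, so the checkpoint must match the
        # architecture chosen above; eval() disables dropout for inference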
        # download the vae checkpoint, unless the user supplies a local path
        vae_checkpoint_path = kwargs.get("vae_checkpoint_path")
        if vae_checkpoint_path is None:
            vae_checkpoint_path = hf_hub_download(
                repo_id=kwargs.get("repo_id", "jadechoghari/mar"),
                filename=kwargs.get("vae_filename", "kl16.safetensors")
            )
        vae = AutoencoderKL(embed_dim=16, ch_mult=(1, 1, 2, 2, 4), ckpt_path=vae_checkpoint_path)
        vae = vae.to(device).eval()
# set up user-specified or default values for generation
seed = kwargs.get("seed", 6)
torch.manual_seed(seed)
np.random.seed(seed)
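        # sampling knobs: num_ar_steps sets the autoregressive decoding iterations,
        # cfg_scale the classifier-free guidance strength, and cfg_schedule
        # ("constant" or "linear" in the MAR codebase) how guidance varies per step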
num_ar_steps = kwargs.get("num_ar_steps", 64)
cfg_scale = kwargs.get("cfg_scale", 4)
cfg_schedule = kwargs.get("cfg_schedule", "constant")
temperature = kwargs.get("temperature", 1.0)
        # default labels are imagenet class ids
        class_labels = kwargs.get("class_labels", (207, 360, 388, 113, 355, 980, 323, 979))
        # generate the latent tokens autoregressively, then decode them to pixels
        with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):
            sampled_tokens = model.sample_tokens(
                bsz=len(class_labels), num_iter=num_ar_steps,
                cfg=cfg_scale, cfg_schedule=cfg_schedule,
                labels=torch.tensor(class_labels, dtype=torch.long, device=device),
                temperature=temperature, progress=True
            )
            sampled_images = vae.decode(sampled_tokens / 0.2325)  # 0.2325 is MAR's kl16 latent scale
output_dir = kwargs.get("output_dir", "./")
os.makedirs(output_dir, exist_ok=True)
# save the images
image_path = os.path.join(output_dir, "sampled_image.png")
samples_per_row = kwargs.get("samples_per_row", 4)
save_image(
sampled_images, image_path, nrow=int(samples_per_row), normalize=True, value_range=(-1, 1)
)
# return as a pil image
image = Image.open(image_path)
return image
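
# Example usage (a sketch, not part of the original file: it assumes this module
# ships as custom pipeline code in the "jadechoghari/mar" repo so that
# DiffusionPipeline.from_pretrained can load it with trust_remote_code=True):
#
#   from diffusers import DiffusionPipeline
#
#   pipeline = DiffusionPipeline.from_pretrained("jadechoghari/mar", trust_remote_code=True)
#   image = pipeline(
#       model_type="mar_base",          # or "mar_large" / "mar_huge"
#       num_ar_steps=64,
#       cfg_scale=4,
#       class_labels=(207, 360, 388),   # imagenet class ids
#       output_dir="./samples",
#   )
#   image.save("mar_grid.png")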