import os

import numpy as np
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from safetensors.torch import save_file
from transformers import PretrainedConfig, PreTrainedModel

from . import mar
from .mar import MAR
from .vae import AutoencoderKL


class MARConfig(PretrainedConfig):
    model_type = "mar"

    def __init__(self,
                 img_size=256,
                 vae_stride=16,
                 patch_size=1,
                 encoder_embed_dim=1024,
                 encoder_depth=16,
                 encoder_num_heads=16,
                 decoder_embed_dim=1024,
                 decoder_depth=16,
                 decoder_num_heads=16,
                 mlp_ratio=4.,
                 norm_layer="LayerNorm",
                 vae_embed_dim=16,
                 mask_ratio_min=0.7,
                 label_drop_prob=0.1,
                 class_num=1000,
                 attn_dropout=0.1,
                 proj_dropout=0.1,
                 buffer_size=64,
                 diffloss_d=3,
                 diffloss_w=1024,
                 num_sampling_steps='100',
                 diffusion_batch_mul=4,
                 grad_checkpointing=False,
                 **kwargs):
        super().__init__(**kwargs)

        # Store every architecture hyperparameter on the config so it is
        # serialized to config.json by save_pretrained().
        self.img_size = img_size
        self.vae_stride = vae_stride
        self.patch_size = patch_size
        self.encoder_embed_dim = encoder_embed_dim
        self.encoder_depth = encoder_depth
        self.encoder_num_heads = encoder_num_heads
        self.decoder_embed_dim = decoder_embed_dim
        self.decoder_depth = decoder_depth
        self.decoder_num_heads = decoder_num_heads
        self.mlp_ratio = mlp_ratio
        self.norm_layer = norm_layer
        self.vae_embed_dim = vae_embed_dim
        self.mask_ratio_min = mask_ratio_min
        self.label_drop_prob = label_drop_prob
        self.class_num = class_num
        self.attn_dropout = attn_dropout
        self.proj_dropout = proj_dropout
        self.buffer_size = buffer_size
        self.diffloss_d = diffloss_d
        self.diffloss_w = diffloss_w
        self.num_sampling_steps = num_sampling_steps
        self.diffusion_batch_mul = diffusion_batch_mul
        self.grad_checkpointing = grad_checkpointing

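# A minimal usage sketch for the config (the keyword values shown just
# restate the defaults above; the target directory is hypothetical):
#
#     config = MARConfig(img_size=256, vae_embed_dim=16, diffloss_w=1024)
#     config.save_pretrained("./mar-config")             # writes config.json
#     config = MARConfig.from_pretrained("./mar-config")
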
class MARModel(PreTrainedModel):

    config_class = MARConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        # Resolve the norm layer by name, e.g. "LayerNorm" -> nn.LayerNorm.
        norm_layer = getattr(nn, config.norm_layer)

        self.model = MAR(
            img_size=config.img_size,
            vae_stride=config.vae_stride,
            patch_size=config.patch_size,
            encoder_embed_dim=config.encoder_embed_dim,
            encoder_depth=config.encoder_depth,
            encoder_num_heads=config.encoder_num_heads,
            decoder_embed_dim=config.decoder_embed_dim,
            decoder_depth=config.decoder_depth,
            decoder_num_heads=config.decoder_num_heads,
            mlp_ratio=config.mlp_ratio,
            norm_layer=norm_layer,
            vae_embed_dim=config.vae_embed_dim,
            mask_ratio_min=config.mask_ratio_min,
            label_drop_prob=config.label_drop_prob,
            class_num=config.class_num,
            attn_dropout=config.attn_dropout,
            proj_dropout=config.proj_dropout,
            buffer_size=config.buffer_size,
            diffloss_d=config.diffloss_d,
            diffloss_w=config.diffloss_w,
            num_sampling_steps=config.num_sampling_steps,
            diffusion_batch_mul=config.diffusion_batch_mul,
            grad_checkpointing=config.grad_checkpointing,
        )

    def forward_train(self, imgs, labels):
        # Training pass: delegate to the underlying MAR module, which
        # returns the training loss.
        return self.model(imgs, labels)

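    # Training sketch (hypothetical tensors: in the reference MAR training
    # loop, `imgs` are VAE latents and `labels` are class indices):
    #
    #     loss = model.forward_train(imgs, labels)
    #     loss.backward()
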
    def forward(self, num_iter=64, cfg=1.0, cfg_schedule="linear", labels=None,
                temperature=1.0, progress=False):
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Fetch the KL-16 VAE weights from the repo this model was loaded
        # from (PretrainedConfig records it as `_name_or_path`).
        checkpoint_path = hf_hub_download(
            repo_id=self.config._name_or_path,
            filename="kl16.safetensors"
        )
        vae = AutoencoderKL(embed_dim=16, ch_mult=(1, 1, 2, 2, 4), ckpt_path=checkpoint_path)
        vae = vae.to(device).eval()

        # Fix the RNG state so repeated calls produce identical samples.
        seed = 0
        torch.manual_seed(seed)
        np.random.seed(seed)

        # Fall back to a demo batch of ImageNet class indices if no labels
        # are supplied.
        if labels is None:
            labels = [207, 360, 388, 113, 355, 980, 323, 979]
        labels = torch.as_tensor(labels, dtype=torch.long).to(device)

        with torch.cuda.amp.autocast():
            sampled_tokens = self.model.sample_tokens(
                bsz=labels.shape[0], num_iter=num_iter,
                cfg=cfg, cfg_schedule=cfg_schedule,
                labels=labels,
                temperature=temperature, progress=progress)
        # Undo the latent scaling (0.2325) applied during training before
        # decoding the tokens back to images.
        sampled_images = vae.decode(sampled_tokens / 0.2325)
        return sampled_images

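    # Sampling sketch (assumes `model` is a MARModel that already holds
    # trained weights; the labels are arbitrary ImageNet class indices):
    #
    #     with torch.no_grad():
    #         images = model(num_iter=64, cfg=4.0, cfg_schedule="constant",
    #                        labels=[207, 360, 388, 113])
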
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        # Architecture overrides may be passed as keyword arguments.
        buffer_size = kwargs.get('buffer_size', 64)
        diffloss_d = kwargs.get('diffloss_d', 3)
        diffloss_w = kwargs.get('diffloss_w', 1024)
        num_sampling_steps_diffloss = kwargs.get('num_sampling_steps', 100)
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Build the mar_base architecture from the factory registry in mar.py.
        model_type = "mar_base"
        model_architecture = mar.__dict__[model_type](
            buffer_size=buffer_size,
            diffloss_d=diffloss_d,
            diffloss_w=diffloss_w,
            num_sampling_steps=str(num_sampling_steps_diffloss)
        ).to(device)

        checkpoint_path = hf_hub_download(
            repo_id=pretrained_model_name_or_path,
            filename="checkpoint-last.pth"
        )

        # Sampling uses the EMA weights stored under "model_ema";
        # strict=False tolerates keys missing from the checkpoint.
        state_dict = torch.load(checkpoint_path, map_location=device)["model_ema"]
        model_architecture.load_state_dict(state_dict, strict=False)

        # Note: this returns the underlying MAR module, not a MARModel wrapper.
        model = model_architecture
        model.eval()
        return model

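    # Loading sketch (hypothetical repo id; the repo must contain
    # checkpoint-last.pth with a "model_ema" entry, as expected above):
    #
    #     mar_module = MARModel.from_pretrained(
    #         "user/mar-base", diffloss_d=3, diffloss_w=1024,
    #         num_sampling_steps=100)
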
    def save_pretrained(self, save_directory):
        os.makedirs(save_directory, exist_ok=True)

        # Serialize the inner MAR weights as safetensors, then write the
        # config (config.json) alongside them.
        state_dict = self.model.state_dict()
        safetensors_path = os.path.join(save_directory, "pytorch_model.safetensors")
        save_file(state_dict, safetensors_path)

        self.config.save_pretrained(save_directory)
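
# Saving sketch: save_pretrained() writes the MAR weights plus config.json,
# so the directory can be re-loaded or pushed to the Hub (paths hypothetical):
#
#     model = MARModel(MARConfig())
#     model.save_pretrained("./mar-export")
#     # -> ./mar-export/pytorch_model.safetensors, ./mar-export/config.json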