from transformers import PretrainedConfig, PreTrainedModel
import torch
import torch.nn as nn
import numpy as np
import os
from huggingface_hub import hf_hub_download
from safetensors.torch import save_file

from . import mar
from .mar import MAR
from .vae import AutoencoderKL


class MARConfig(PretrainedConfig):
    model_type = "mar"

    def __init__(self,
                 img_size=256,
                 vae_stride=16,
                 patch_size=1,
                 encoder_embed_dim=1024,
                 encoder_depth=16,
                 encoder_num_heads=16,
                 decoder_embed_dim=1024,
                 decoder_depth=16,
                 decoder_num_heads=16,
                 mlp_ratio=4.,
                 norm_layer="LayerNorm",
                 vae_embed_dim=16,
                 mask_ratio_min=0.7,
                 label_drop_prob=0.1,
                 class_num=1000,
                 attn_dropout=0.1,
                 proj_dropout=0.1,
                 buffer_size=64,
                 diffloss_d=3,
                 diffloss_w=1024,
                 num_sampling_steps='100',
                 diffusion_batch_mul=4,
                 grad_checkpointing=False,
                 **kwargs):
        super().__init__(**kwargs)
        # store all model hyperparameters on the config
        self.img_size = img_size
        self.vae_stride = vae_stride
        self.patch_size = patch_size
        self.encoder_embed_dim = encoder_embed_dim
        self.encoder_depth = encoder_depth
        self.encoder_num_heads = encoder_num_heads
        self.decoder_embed_dim = decoder_embed_dim
        self.decoder_depth = decoder_depth
        self.decoder_num_heads = decoder_num_heads
        self.mlp_ratio = mlp_ratio
        self.norm_layer = norm_layer
        self.vae_embed_dim = vae_embed_dim
        self.mask_ratio_min = mask_ratio_min
        self.label_drop_prob = label_drop_prob
        self.class_num = class_num
        self.attn_dropout = attn_dropout
        self.proj_dropout = proj_dropout
        self.buffer_size = buffer_size
        self.diffloss_d = diffloss_d
        self.diffloss_w = diffloss_w
        self.num_sampling_steps = num_sampling_steps
        self.diffusion_batch_mul = diffusion_batch_mul
        self.grad_checkpointing = grad_checkpointing


class MARModel(PreTrainedModel):
    # links this model to MARConfig
    config_class = MARConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        # convert norm_layer from its string name to the actual nn class
        norm_layer = getattr(nn, config.norm_layer)

        # init the MAR model using the parameters from the config
        self.model = MAR(
            img_size=config.img_size,
            vae_stride=config.vae_stride,
            patch_size=config.patch_size,
            encoder_embed_dim=config.encoder_embed_dim,
            encoder_depth=config.encoder_depth,
            encoder_num_heads=config.encoder_num_heads,
            decoder_embed_dim=config.decoder_embed_dim,
            decoder_depth=config.decoder_depth,
            decoder_num_heads=config.decoder_num_heads,
            mlp_ratio=config.mlp_ratio,
            norm_layer=norm_layer,  # pass the class, not the string
            vae_embed_dim=config.vae_embed_dim,
            mask_ratio_min=config.mask_ratio_min,
            label_drop_prob=config.label_drop_prob,
            class_num=config.class_num,
            attn_dropout=config.attn_dropout,
            proj_dropout=config.proj_dropout,
            buffer_size=config.buffer_size,
            diffloss_d=config.diffloss_d,
            diffloss_w=config.diffloss_w,
            num_sampling_steps=config.num_sampling_steps,
            diffusion_batch_mul=config.diffusion_batch_mul,
            grad_checkpointing=config.grad_checkpointing,
        )

    def forward_train(self, imgs, labels):
        # training forward pass: delegate to MAR.forward with images and class labels
        return self.model(imgs, labels)

    def forward(self, labels=None, num_iter=64, cfg=4.0, cfg_schedule="constant",
                temperature=1.0, progress=True, seed=None):
        # sampling forward pass: generate latent tokens with MAR.sample_tokens,
        # then decode them to images with the KL-16 VAE
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # the VAE checkpoint is expected to live in the hub repo the config was loaded from
        checkpoint_path = hf_hub_download(
            repo_id=self.config.name_or_path,
            filename="kl16.safetensors",
        )
        vae = AutoencoderKL(embed_dim=16, ch_mult=(1, 1, 2, 2, 4), ckpt_path=checkpoint_path)
        vae = vae.to(device).eval()

        # optionally fix the RNG seed for reproducible samples
        if seed is not None:
            torch.manual_seed(seed)
            np.random.seed(seed)

        # fall back to a fixed set of ImageNet class labels if the caller provides none
        if labels is None:
            labels = (207, 360, 388, 113, 355, 980, 323, 979)
        if not torch.is_tensor(labels):
            labels = torch.tensor(labels)
        labels = labels.long().to(device)

        with torch.cuda.amp.autocast():
            sampled_tokens = self.model.sample_tokens(
                bsz=len(labels),
                num_iter=num_iter,
                cfg=cfg,
                cfg_schedule=cfg_schedule,
                labels=labels,
                temperature=temperature,
                progress=progress,
            )
            # rescale the latents before decoding them back to pixel space
            sampled_images = vae.decode(sampled_tokens / 0.2325)
        return sampled_images

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        # build the MAR architecture directly and load the EMA weights from the hub checkpoint
        buffer_size = kwargs.get("buffer_size", 64)
        diffloss_d = kwargs.get("diffloss_d", 3)
        diffloss_w = kwargs.get("diffloss_w", 1024)
        num_sampling_steps_diffloss = kwargs.get("num_sampling_steps", 100)

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # instantiate the base MAR variant defined in mar.py
        model_type = "mar_base"
        model_architecture = mar.__dict__[model_type](
            buffer_size=buffer_size,
            diffloss_d=diffloss_d,
            diffloss_w=diffloss_w,
            num_sampling_steps=str(num_sampling_steps_diffloss),
        ).to(device)

        # download the checkpoint from the hub and load the EMA weights
        checkpoint_path = hf_hub_download(
            repo_id=pretrained_model_name_or_path,
            filename="checkpoint-last.pth",
        )
        state_dict = torch.load(checkpoint_path, map_location=device)["model_ema"]
        model_architecture.load_state_dict(state_dict, strict=False)

        # return the raw MAR module so sample_tokens (and the training forward) can be called on it directly
        model = model_architecture
        model.eval()
        return model

    def save_pretrained(self, save_directory):
        # save the weights as safetensors alongside the configuration
        os.makedirs(save_directory, exist_ok=True)

        state_dict = self.model.state_dict()
        safetensors_path = os.path.join(save_directory, "pytorch_model.safetensors")
        save_file(state_dict, safetensors_path)

        # save the configuration as usual
        self.config.save_pretrained(save_directory)
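

# --- usage sketch (illustrative only, not part of the library) ---
# A minimal example of how this wrapper is meant to be driven. The repo id below is a
# placeholder: substitute a hub repo that actually hosts checkpoint-last.pth.
# Note that from_pretrained above returns the raw MAR module, so sample_tokens is
# called on it directly; decoding to images additionally requires the KL-16 VAE,
# as done in MARModel.forward.
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # placeholder repo id
    model = MARModel.from_pretrained("your-namespace/mar", num_sampling_steps=100)

    labels = torch.tensor([207, 360, 388, 113], dtype=torch.long, device=device)
    with torch.no_grad():
        tokens = model.sample_tokens(
            bsz=len(labels),
            num_iter=64,
            cfg=4.0,
            cfg_schedule="constant",
            labels=labels,
            temperature=1.0,
            progress=True,
        )
    print("sampled latent tokens:", tuple(tokens.shape))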