# mar/modeling.py
import os

import numpy as np
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from safetensors.torch import save_file
from transformers import PretrainedConfig, PreTrainedModel

from . import mar
from .mar import MAR
from .vae import AutoencoderKL


class MARConfig(PretrainedConfig):
    """Configuration for the MAR (masked autoregressive) image generation model."""

    model_type = "mar"
def __init__(self,
img_size=256,
vae_stride=16,
patch_size=1,
encoder_embed_dim=1024,
encoder_depth=16,
encoder_num_heads=16,
decoder_embed_dim=1024,
decoder_depth=16,
decoder_num_heads=16,
mlp_ratio=4.,
norm_layer="LayerNorm",
vae_embed_dim=16,
mask_ratio_min=0.7,
label_drop_prob=0.1,
class_num=1000,
attn_dropout=0.1,
proj_dropout=0.1,
buffer_size=64,
diffloss_d=3,
diffloss_w=1024,
num_sampling_steps='100',
diffusion_batch_mul=4,
grad_checkpointing=False,
**kwargs):
super().__init__(**kwargs)
# store parameters in the config
self.img_size = img_size
self.vae_stride = vae_stride
self.patch_size = patch_size
self.encoder_embed_dim = encoder_embed_dim
self.encoder_depth = encoder_depth
self.encoder_num_heads = encoder_num_heads
self.decoder_embed_dim = decoder_embed_dim
self.decoder_depth = decoder_depth
self.decoder_num_heads = decoder_num_heads
self.mlp_ratio = mlp_ratio
self.norm_layer = norm_layer
self.vae_embed_dim = vae_embed_dim
self.mask_ratio_min = mask_ratio_min
self.label_drop_prob = label_drop_prob
self.class_num = class_num
self.attn_dropout = attn_dropout
self.proj_dropout = proj_dropout
self.buffer_size = buffer_size
self.diffloss_d = diffloss_d
self.diffloss_w = diffloss_w
self.num_sampling_steps = num_sampling_steps
self.diffusion_batch_mul = diffusion_batch_mul
self.grad_checkpointing = grad_checkpointing
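
# Usage sketch for the config (the values shown are the defaults above; any of
# them can be overridden, and unknown kwargs are forwarded to PretrainedConfig):
#
#     config = MARConfig(img_size=256, diffloss_d=3, num_sampling_steps="100")
#     config.save_pretrained("./mar-checkpoint")   # writes config.json
#     config = MARConfig.from_pretrained("./mar-checkpoint")
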
class MARModel(PreTrainedModel):
    """`PreTrainedModel` wrapper around the core MAR network."""

    # link this model to its configuration class
    config_class = MARConfig
def __init__(self, config):
super().__init__(config)
self.config = config
# convert norm_layer from string to class
norm_layer = getattr(nn, config.norm_layer)
# init the mar model using the parameters from config
self.model = MAR(
img_size=config.img_size,
vae_stride=config.vae_stride,
patch_size=config.patch_size,
encoder_embed_dim=config.encoder_embed_dim,
encoder_depth=config.encoder_depth,
encoder_num_heads=config.encoder_num_heads,
decoder_embed_dim=config.decoder_embed_dim,
decoder_depth=config.decoder_depth,
decoder_num_heads=config.decoder_num_heads,
mlp_ratio=config.mlp_ratio,
norm_layer=norm_layer, # use the actual class for the layer
vae_embed_dim=config.vae_embed_dim,
mask_ratio_min=config.mask_ratio_min,
label_drop_prob=config.label_drop_prob,
class_num=config.class_num,
attn_dropout=config.attn_dropout,
proj_dropout=config.proj_dropout,
buffer_size=config.buffer_size,
diffloss_d=config.diffloss_d,
diffloss_w=config.diffloss_w,
num_sampling_steps=config.num_sampling_steps,
diffusion_batch_mul=config.diffusion_batch_mul,
grad_checkpointing=config.grad_checkpointing,
)

    def forward_train(self, imgs, labels):
        # training forward pass: delegate to MAR.forward with images and class labels
        return self.model(imgs, labels)
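
    # Training usage sketch ("imgs" and "labels" are hypothetical tensors: a
    # preprocessed image batch and a LongTensor of class ids; in the upstream
    # MAR implementation this forward returns the diffusion loss):
    #
    #     loss = model.forward_train(imgs, labels)
    #     loss.backward()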

    def forward(self, num_iter=64, cfg=4.0, cfg_schedule="constant", labels=None,
                temperature=1.0, progress=False, seed=0):
        # sampling forward pass: draw tokens with MAR.sample_tokens, then decode
        # them to images with the KL-16 VAE (defaults mirror the values that
        # were previously hardcoded in the body)
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # fetch the VAE weights; this assumes they are hosted in the same repo
        # the config was loaded from
        checkpoint_path = hf_hub_download(
            repo_id=self.config.name_or_path,
            filename="kl16.safetensors"
        )
        vae = AutoencoderKL(embed_dim=16, ch_mult=(1, 1, 2, 2, 4), ckpt_path=checkpoint_path)
        vae = vae.to(device).eval()
        # seed everything for reproducible sampling
        torch.manual_seed(seed)
        np.random.seed(seed)
        # default ImageNet class labels, used when the caller supplies none
        if labels is None:
            labels = [207, 360, 388, 113, 355, 980, 323, 979]
        with torch.cuda.amp.autocast():
            sampled_tokens = self.model.sample_tokens(
                bsz=len(labels), num_iter=num_iter,
                cfg=cfg, cfg_schedule=cfg_schedule,
                labels=torch.tensor(labels, dtype=torch.long, device=device),
                temperature=temperature, progress=progress)
        # undo the latent normalization factor before decoding
        sampled_images = vae.decode(sampled_tokens / 0.2325)
        return sampled_images
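
    # Sampling usage sketch (assumes `model` is a MARModel whose config's
    # name_or_path points at a repo hosting kl16.safetensors):
    #
    #     images = model(num_iter=64, cfg=4.0, labels=[207, 360], progress=True)
    #     # -> a batch of decoded images, one per label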

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        # build the mar_base architecture, letting kwargs override the
        # diffusion-loss hyperparameters
        buffer_size = kwargs.get('buffer_size', 64)
        diffloss_d = kwargs.get('diffloss_d', 3)
        diffloss_w = kwargs.get('diffloss_w', 1024)
        num_sampling_steps_diffloss = kwargs.get('num_sampling_steps', 100)
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model_type = "mar_base"
        model_architecture = mar.__dict__[model_type](
            buffer_size=buffer_size,
            diffloss_d=diffloss_d,
            diffloss_w=diffloss_w,
            num_sampling_steps=str(num_sampling_steps_diffloss)
        ).to(device)
        # download the checkpoint from the Hub and load the EMA weights
        checkpoint_path = hf_hub_download(
            repo_id=pretrained_model_name_or_path,
            filename="checkpoint-last.pth"
        )
        state_dict = torch.load(checkpoint_path, map_location=device)["model_ema"]
        model_architecture.load_state_dict(state_dict, strict=False)
        # note: this returns the bare MAR module rather than a MARModel wrapper,
        # so sampling goes through MAR.sample_tokens instead of forward() above
        model = model_architecture
        model.eval()
        return model
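
    # Loading usage sketch (the repo id is illustrative; the repo must host a
    # checkpoint-last.pth containing a "model_ema" state dict). Because the
    # bare MAR module is returned, sample via sample_tokens directly:
    #
    #     model = MARModel.from_pretrained("jadechoghari/mar")
    #     device = next(model.parameters()).device
    #     tokens = model.sample_tokens(bsz=2, num_iter=64, cfg=4.0,
    #                                  labels=torch.tensor([207, 360]).to(device))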

    def save_pretrained(self, save_directory):
        # save the weights in safetensors format
        os.makedirs(save_directory, exist_ok=True)
        state_dict = self.model.state_dict()
        safetensors_path = os.path.join(save_directory, "pytorch_model.safetensors")
        save_file(state_dict, safetensors_path)
        # save the configuration alongside the weights
        self.config.save_pretrained(save_directory)
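
# Save/reload round-trip sketch (the directory path is hypothetical; `model`
# is a MARModel instance):
#
#     model.save_pretrained("./mar-checkpoint")   # weights + config.json
#     config = MARConfig.from_pretrained("./mar-checkpoint")
#     restored = MARModel(config)
#     from safetensors.torch import load_file
#     restored.model.load_state_dict(
#         load_file("./mar-checkpoint/pytorch_model.safetensors"))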