File size: 4,418 Bytes
1822fe2
 
 
bfffca7
 
1822fe2
 
 
9885231
9ff7343
80a6dc1
f2191a0
ca410d5
1822fe2
 
de18685
1822fe2
 
 
 
 
de1c669
1822fe2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
de18685
 
 
 
 
1822fe2
de18685
 
 
ca410d5
de18685
 
80a6dc1
ca410d5
de18685
 
80a6dc1
ca410d5
de18685
 
80a6dc1
de18685
 
b859dac
1822fe2
 
80a6dc1
1822fe2
de18685
 
1822fe2
de18685
 
 
 
 
 
1822fe2
80a6dc1
99cd496
de18685
 
1822fe2
 
 
 
5bc55c8
1822fe2
de18685
1822fe2
 
 
 
 
80a6dc1
1822fe2
 
 
 
 
 
 
4bd156b
1822fe2
 
 
de18685
1822fe2
 
de18685
1822fe2
 
 
 
 
de18685
 
 
 
 
80a6dc1
de18685
 
 
 
 
 
 
 
 
1822fe2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
from diffusers import DiffusionPipeline
import torch
import numpy as np
import importlib.util
import sys
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
import os
from torchvision.utils import save_image
from PIL import Image
from safetensors.torch import load_file
from .vae import AutoencoderKL
from .mar import mar_base, mar_large, mar_huge

# inheriting from DiffusionPipeline for HF
class MARModel(DiffusionPipeline):
    """Class-conditional image sampler built on a pretrained MAR model and an AutoencoderKL.

    Subclasses ``DiffusionPipeline`` so it can be used through the Hugging Face
    ``diffusers`` pipeline machinery. Every call downloads (or reuses the Hub
    cache for) the weights, samples one image per class label, writes the
    images as a grid to ``<output_dir>/sampled_image.png`` and returns that grid.
    """

    def __init__(self):
        super().__init__()

    @staticmethod
    def _load_mar(model_type, num_sampling_steps, repo_id, model_filename, device):
        """Build the MAR architecture for ``model_type`` and load its pretrained weights.

        Raises:
            NotImplementedError: if ``model_type`` is not a known MAR variant.
        """
        # model_type -> (constructor, diffloss depth, diffloss width, default checkpoint).
        # The DiffLoss dimensions are fixed per variant (any user-supplied value was
        # always overridden), presumably because they must match the checkpoint.
        configs = {
            "mar_base": (mar_base, 6, 1024, "mar-base.safetensors"),
            "mar_large": (mar_large, 8, 1280, "mar-large.safetensors"),
            "mar_huge": (mar_huge, 12, 1536, "mar-huge.safetensors"),
        }
        if model_type not in configs:
            raise NotImplementedError(
                f"Unknown model_type {model_type!r}; expected one of {sorted(configs)}"
            )
        model_fn, diffloss_d, diffloss_w, default_filename = configs[model_type]

        # download and load the model weights (.safetensors)
        checkpoint_path = hf_hub_download(
            repo_id=repo_id,
            filename=model_filename or default_filename,
        )

        model = model_fn(
            # kept at 64 regardless of any `buffer_size` kwarg; it sizes model
            # parameters, so it presumably must match the pretrained checkpoint
            buffer_size=64,
            diffloss_d=diffloss_d,
            diffloss_w=diffloss_w,
            # the MAR constructor is called with the step count as a string
            num_sampling_steps=str(num_sampling_steps),
        ).to(device)

        model.load_state_dict(load_file(checkpoint_path))
        model.eval()
        return model

    @staticmethod
    def _load_vae(repo_id, vae_filename, vae_checkpoint_path, device):
        """Build the AutoencoderKL (embed_dim=16) from a local path, downloading it only if needed."""
        if vae_checkpoint_path is None:
            vae_checkpoint_path = hf_hub_download(repo_id=repo_id, filename=vae_filename)
        vae = AutoencoderKL(embed_dim=16, ch_mult=(1, 1, 2, 2, 4), ckpt_path=vae_checkpoint_path)
        return vae.to(device).eval()

    @torch.no_grad()
    def __call__(self, *args, **kwargs):
        """
        Download the MAR model and VAE, sample images and return them as one grid.

        Positional arguments are accepted for pipeline compatibility but ignored.

        Keyword Args:
            model_type (str): "mar_base" (default), "mar_large" or "mar_huge".
            repo_id (str): Hub repository holding the weights (default "jadechoghari/mar").
            model_filename (str): MAR checkpoint file in ``repo_id``; defaults to the
                file matching ``model_type``.
            vae_filename (str): VAE file in ``repo_id`` (default "kl16.safetensors").
            vae_checkpoint_path (str): local VAE checkpoint; when given, the VAE is
                not downloaded from the Hub.
            num_sampling_steps (int): step count passed to the MAR constructor's
                DiffLoss configuration (default 100).
            seed (int): seed for torch and numpy RNGs (default 6).
            num_ar_steps (int): ``num_iter`` for ``sample_tokens`` (default 64).
            cfg_scale (float): ``cfg`` for ``sample_tokens`` (default 4).
            cfg_schedule (str): ``cfg_schedule`` for ``sample_tokens`` (default "constant").
            temperature (float): ``temperature`` for ``sample_tokens`` (default 1.0).
            class_labels (list[int]): one class index per generated image.
            output_dir (str): directory receiving "sampled_image.png" (default "./").
            samples_per_row (int): images per row of the saved grid (default 4).

        ``buffer_size``, ``diffloss_d`` and ``diffloss_w`` are accepted but ignored:
        they are fixed per ``model_type``.

        Returns:
            PIL.Image.Image: the saved image grid, fully loaded into memory.

        Raises:
            ValueError: if ``class_labels`` is empty.
            NotImplementedError: if ``model_type`` is unknown.
        """
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        repo_id = kwargs.get("repo_id", "jadechoghari/mar")

        # validate cheap inputs before any download happens
        class_labels = kwargs.get("class_labels", [207, 360, 388, 113, 355, 980, 323, 979])
        if len(class_labels) == 0:
            raise ValueError("class_labels must contain at least one label")

        model = self._load_mar(
            model_type=kwargs.get("model_type", "mar_base"),
            num_sampling_steps=kwargs.get("num_sampling_steps", 100),
            repo_id=repo_id,
            model_filename=kwargs.get("model_filename"),
            device=device,
        )
        vae = self._load_vae(
            repo_id=repo_id,
            vae_filename=kwargs.get("vae_filename", "kl16.safetensors"),
            vae_checkpoint_path=kwargs.get("vae_checkpoint_path"),
            device=device,
        )

        # set up user-specified or default values for generation
        seed = kwargs.get("seed", 6)
        torch.manual_seed(seed)
        np.random.seed(seed)

        num_ar_steps = kwargs.get("num_ar_steps", 64)
        cfg_scale = kwargs.get("cfg_scale", 4)
        cfg_schedule = kwargs.get("cfg_schedule", "constant")
        temperature = kwargs.get("temperature", 1.0)
        # build integer labels directly instead of round-tripping through float32
        labels = torch.tensor(class_labels, dtype=torch.long, device=device)

        # mixed precision on CUDA only; on CPU autocast is explicitly disabled
        # (the deprecated torch.cuda.amp.autocast just warned there)
        with torch.autocast(device_type="cuda", enabled=device.type == "cuda"):
            sampled_tokens = model.sample_tokens(
                bsz=len(class_labels), num_iter=num_ar_steps,
                cfg=cfg_scale, cfg_schedule=cfg_schedule,
                labels=labels,
                temperature=temperature, progress=True,
            )
            # 0.2325 presumably undoes the latent scaling applied during
            # training — TODO confirm against the training config
            sampled_images = vae.decode(sampled_tokens / 0.2325)

        output_dir = kwargs.get("output_dir", "./")
        os.makedirs(output_dir, exist_ok=True)

        # save the images as a grid; decoded values are mapped from [-1, 1]
        image_path = os.path.join(output_dir, "sampled_image.png")
        samples_per_row = kwargs.get("samples_per_row", 4)
        save_image(
            sampled_images, image_path, nrow=int(samples_per_row), normalize=True, value_range=(-1, 1)
        )

        # Image.open is lazy; load the pixels now so the returned image does not
        # depend on a file that the next call overwrites.
        image = Image.open(image_path)
        image.load()
        return image