jadechoghari
commited on
Commit
•
ca410d5
1
Parent(s):
248b236
Update pipeline.py
Browse files- pipeline.py +22 -7
pipeline.py
CHANGED
@@ -7,7 +7,7 @@ from huggingface_hub import hf_hub_download
|
|
7 |
from safetensors.torch import load_file
|
8 |
import os
|
9 |
from .vae import AutoencoderKL
|
10 |
-
from .mar import
|
11 |
|
12 |
# inheriting from DiffusionPipeline for HF
|
13 |
class MARModel(DiffusionPipeline):
|
@@ -33,12 +33,27 @@ class MARModel(DiffusionPipeline):
|
|
33 |
model_type = kwargs.get("model_type", "mar_base")
|
34 |
|
35 |
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
42 |
# download and load the model weights (.safetensors or .pth)
|
43 |
model_checkpoint_path = hf_hub_download(
|
44 |
repo_id=kwargs.get("repo_id", "jadechoghari/mar"),
|
|
|
7 |
from safetensors.torch import load_file
|
8 |
import os
|
9 |
from .vae import AutoencoderKL
|
10 |
+
from .mar import mar_base, mar_large, mar_huge
|
11 |
|
12 |
# inheriting from DiffusionPipeline for HF
|
13 |
class MARModel(DiffusionPipeline):
|
|
|
33 |
model_type = kwargs.get("model_type", "mar_base")
|
34 |
|
35 |
|
36 |
+
if model_type == "mar_base":
|
37 |
+
self.model = mar_base(
|
38 |
+
buffer_size=buffer_size,
|
39 |
+
diffloss_d=diffloss_d,
|
40 |
+
diffloss_w=diffloss_w,
|
41 |
+
num_sampling_steps=str(num_sampling_steps)
|
42 |
+
).to(device)
|
43 |
+
elif model_type == "mar_large":
|
44 |
+
self.model = mar_large(
|
45 |
+
buffer_size=buffer_size,
|
46 |
+
diffloss_d=diffloss_d,
|
47 |
+
diffloss_w=diffloss_w,
|
48 |
+
num_sampling_steps=str(num_sampling_steps)
|
49 |
+
).to(device)
|
50 |
+
elif model_type == "mar_huge":
|
51 |
+
self.model = mar_huge(
|
52 |
+
buffer_size=buffer_size,
|
53 |
+
diffloss_d=diffloss_d,
|
54 |
+
diffloss_w=diffloss_w,
|
55 |
+
num_sampling_steps=str(num_sampling_steps)
|
56 |
+
).to(device)
|
57 |
# download and load the model weights (.safetensors or .pth)
|
58 |
model_checkpoint_path = hf_hub_download(
|
59 |
repo_id=kwargs.get("repo_id", "jadechoghari/mar"),
|