jadechoghari commited on
Commit
ca410d5
1 Parent(s): 248b236

Update pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +22 -7
pipeline.py CHANGED
@@ -7,7 +7,7 @@ from huggingface_hub import hf_hub_download
7
  from safetensors.torch import load_file
8
  import os
9
  from .vae import AutoencoderKL
10
- from .mar import mar
11
 
12
  # inheriting from DiffusionPipeline for HF
13
  class MARModel(DiffusionPipeline):
@@ -33,12 +33,27 @@ class MARModel(DiffusionPipeline):
33
  model_type = kwargs.get("model_type", "mar_base")
34
 
35
 
36
- self.model = mar.__dict__[model_type](
37
- buffer_size=buffer_size,
38
- diffloss_d=diffloss_d,
39
- diffloss_w=diffloss_w,
40
- num_sampling_steps=str(num_sampling_steps)
41
- ).to(device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  # download and load the model weights (.safetensors or .pth)
43
  model_checkpoint_path = hf_hub_download(
44
  repo_id=kwargs.get("repo_id", "jadechoghari/mar"),
 
7
  from safetensors.torch import load_file
8
  import os
9
  from .vae import AutoencoderKL
10
+ from .mar import mar_base, mar_large, mar_huge
11
 
12
  # inheriting from DiffusionPipeline for HF
13
  class MARModel(DiffusionPipeline):
 
33
  model_type = kwargs.get("model_type", "mar_base")
34
 
35
 
36
+ if model_type == "mar_base":
37
+ self.model = mar_base(
38
+ buffer_size=buffer_size,
39
+ diffloss_d=diffloss_d,
40
+ diffloss_w=diffloss_w,
41
+ num_sampling_steps=str(num_sampling_steps)
42
+ ).to(device)
43
+ elif model_type == "mar_large":
44
+ self.model = mar_large(
45
+ buffer_size=buffer_size,
46
+ diffloss_d=diffloss_d,
47
+ diffloss_w=diffloss_w,
48
+ num_sampling_steps=str(num_sampling_steps)
49
+ ).to(device)
50
+ elif model_type == "mar_huge":
51
+ self.model = mar_huge(
52
+ buffer_size=buffer_size,
53
+ diffloss_d=diffloss_d,
54
+ diffloss_w=diffloss_w,
55
+ num_sampling_steps=str(num_sampling_steps)
56
+ ).to(device)
57
  # download and load the model weights (.safetensors or .pth)
58
  model_checkpoint_path = hf_hub_download(
59
  repo_id=kwargs.get("repo_id", "jadechoghari/mar"),