jadechoghari commited on
Commit
ec033ea
1 Parent(s): 880499e

Update pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +18 -14
pipeline.py CHANGED
@@ -2,7 +2,7 @@ from diffusers import DiffusionPipeline
2
  import os
3
  import sys
4
  from huggingface_hub import HfApi, hf_hub_download
5
- from .tools import build_dataset_json_from_list
6
  import torch
7
 
8
  class MOSDiffusionPipeline(DiffusionPipeline):
@@ -67,28 +67,32 @@ class MOSDiffusionPipeline(DiffusionPipeline):
67
 
68
 
69
  @torch.no_grad()
70
- def __call__(self, *args, **kwargs):
71
  """
72
  Run the MOS Diffusion Pipeline. This method calls the infer function from infer_mos5.py.
73
  """
74
  from .infer.infer_mos5 import infer
 
75
 
 
76
  infer(
77
- dataset_key=self.dataset_key,
78
  configs=self.configs,
79
  config_yaml_path=self.config_yaml,
80
- exp_group_name=self.exp_group_name,
81
- exp_name=self.exp_name
82
  )
83
 
 
 
 
 
 
 
 
 
84
  # Example of how to use the pipeline
85
  if __name__ == "__main__":
86
- pipeline = MOSDiffusionPipeline(
87
- config_yaml="audioldm_train/config/mos_as_token/qa_mdt.yaml",
88
- list_inference="test_prompts/good_prompts_1.lst",
89
- reload_from_ckpt="checkpoints/checkpoint_389999.ckpt",
90
- base_folder=None
91
- )
92
-
93
- # Run the pipeline
94
- pipeline()
 
2
  import os
3
  import sys
4
  from huggingface_hub import HfApi, hf_hub_download
5
+ # from .tools import build_dataset_json_from_list
6
  import torch
7
 
8
  class MOSDiffusionPipeline(DiffusionPipeline):
 
67
 
68
 
69
  @torch.no_grad()
70
+ def __call__(self, prompt: str):
71
  """
72
  Run the MOS Diffusion Pipeline. This method calls the infer function from infer_mos5.py.
73
  """
74
  from .infer.infer_mos5 import infer
75
+ dataset_key = self.build_dataset_json_from_prompt(prompt)
76
 
77
+ # we run inference with the prompt - configs - and other settings
78
  infer(
79
+ dataset_key=dataset_key,
80
  configs=self.configs,
81
  config_yaml_path=self.config_yaml,
82
+ exp_group_name="qa_mdt",
83
+ exp_name="mos_as_token"
84
  )
85
 
86
+ def build_dataset_json_from_prompt(self, prompt: str):
87
+ """
88
+ Build dataset_key dynamically from the provided prompt.
89
+ """
90
+ # for simplicity let's just return the prompt as the dataset_key
91
+ return {"prompt": prompt}
92
+
93
+
94
  # Example of how to use the pipeline
95
  if __name__ == "__main__":
96
+ pipe = MOSDiffusionPipeline()
97
+ result = pipe("Generate a description of a sunny day.")
98
+ print(result)