import os
import sys

import torch
from diffusers import DiffusionPipeline
from huggingface_hub import HfApi, hf_hub_download

from .tools import build_dataset_json_from_list


class MOSDiffusionPipeline(DiffusionPipeline):
    """
    ``DiffusionPipeline`` wrapper around the QA-MDT / MOS inference script.

    On construction it loads a YAML experiment config and builds the dataset key
    from a prompt list; calling the pipeline delegates to
    ``infer.infer_mos5.infer`` with that state.
    """

    def __init__(self, config_yaml, list_inference, reload_from_ckpt=None, base_folder=None):
        """
        Initialize the MOS Diffusion pipeline from a local YAML config.

        Note: this constructor does not download anything. Call
        ``download_required_folders()`` to fetch the repository folders from the
        Hugging Face Hub into ``base_folder``.

        Args:
            config_yaml (str): Path to the YAML configuration file. It is opened
                exactly as given (relative paths resolve against the current
                working directory, not ``base_folder``).
            list_inference (str): Path to the file containing inference prompts;
                passed to ``build_dataset_json_from_list``.
            reload_from_ckpt (str, optional): Checkpoint path. When given, it
                overrides the ``reload_from_ckpt`` entry of the loaded config.
            base_folder (str, optional): Folder that ``download_required_folders()``
                writes into and appends to ``sys.path``. Defaults to the current
                working directory.
        """
        super().__init__()
        self.base_folder = base_folder if base_folder else os.getcwd()
        self.repo_id = "jadechoghari/qa-mdt"
        self.config_yaml = config_yaml
        self.list_inference = list_inference
        self.reload_from_ckpt = reload_from_ckpt

        self.configs = self.load_yaml(self.config_yaml)
        if self.reload_from_ckpt is not None:
            self.configs["reload_from_ckpt"] = self.reload_from_ckpt

        self.dataset_key = build_dataset_json_from_list(self.list_inference)

        # Experiment name = config file name up to its first ".", e.g.
        # ".../qa_mdt.yaml" -> "qa_mdt". Only the basename is split so that dots in
        # directory components (e.g. "./configs/x.yaml") cannot truncate it to "".
        self.exp_name = os.path.basename(self.config_yaml).split(".")[0]
        # Experiment group = name of the directory that contains the config file.
        self.exp_group_name = os.path.basename(os.path.dirname(self.config_yaml))

    def download_required_folders(self):
        """
        Download the folders needed for inference from the Hugging Face Hub.

        Every file of ``self.repo_id`` located under one of the required top-level
        folders is downloaded into ``self.base_folder`` (keeping its repo-relative
        path), unless a file already exists at that local path. Afterwards
        ``self.base_folder`` is added to ``sys.path`` (only once) so the downloaded
        packages are importable.
        """
        api = HfApi()
        files = api.list_repo_files(repo_id=self.repo_id)

        required_folders = ["audioldm_train", "checkpoints", "infer", "log", "taming", "test_prompts"]
        # Match whole directory names: a bare prefix such as "log" would otherwise
        # also select unrelated paths like "logo.png", and "infer" would match "infer.py".
        folder_prefixes = tuple(folder + "/" for folder in required_folders)

        for file in files:
            if not file.startswith(folder_prefixes):
                continue
            local_file_path = os.path.join(self.base_folder, file)
            if os.path.exists(local_file_path):
                continue
            # Download straight into base_folder instead of renaming the returned
            # path: that path lives in the HF cache (usually a relative symlink into
            # the blob store), so moving it leaves a dangling link, and os.rename
            # fails outright across filesystems.
            hf_hub_download(repo_id=self.repo_id, filename=file, local_dir=self.base_folder)

        # Avoid piling up duplicate entries when this method is called repeatedly.
        if self.base_folder not in sys.path:
            sys.path.append(self.base_folder)

    def load_yaml(self, yaml_path):
        """
        Load and return the YAML document at ``yaml_path``.

        Uses ``yaml.safe_load``, so only plain YAML types are constructed.
        """
        import yaml

        with open(yaml_path, "r", encoding="utf-8") as f:
            return yaml.safe_load(f)

    @torch.no_grad()
    def __call__(self, *args, **kwargs):
        """
        Run the MOS Diffusion pipeline via ``infer.infer_mos5.infer``.

        Positional and keyword arguments are accepted for call compatibility but
        are ignored; all inputs come from the state built in ``__init__``.
        Gradient tracking is disabled for the duration of the call.

        Returns:
            Whatever ``infer`` returns (previously discarded).
        """
        from .infer.infer_mos5 import infer

        return infer(
            dataset_key=self.dataset_key,
            configs=self.configs,
            config_yaml_path=self.config_yaml,
            exp_group_name=self.exp_group_name,
            exp_name=self.exp_name,
        )


# Example of how to use the pipeline.
# NOTE: this module uses relative imports (".tools", ".infer"), so it must be run
# as part of its package (e.g. ``python -m <package>.<module>``), not as a plain script.
if __name__ == "__main__":
    pipeline = MOSDiffusionPipeline(
        config_yaml="audioldm_train/config/mos_as_token/qa_mdt.yaml",
        list_inference="test_prompts/good_prompts_1.lst",
        reload_from_ckpt="checkpoints/checkpoint_389999.ckpt",
        base_folder=None,
    )

    # Run the pipeline
    pipeline()