jadechoghari
commited on
Commit
•
de1cf45
1
Parent(s):
1ef59f6
Update audioldm_train/train/latent_diffusion.py
Browse files
audioldm_train/train/latent_diffusion.py
CHANGED
@@ -15,7 +15,7 @@ from tqdm import tqdm
|
|
15 |
from pytorch_lightning.strategies.ddp import DDPStrategy
|
16 |
|
17 |
|
18 |
-
from audioldm_train.modules.latent_diffusion.ddpm import LatentDiffusion
|
19 |
|
20 |
|
21 |
from torch.utils.data import WeightedRandomSampler
|
@@ -25,13 +25,13 @@ from pytorch_lightning.callbacks import ModelCheckpoint
|
|
25 |
from pytorch_lightning.loggers import WandbLogger
|
26 |
|
27 |
|
28 |
-
from audioldm_train.utilities.tools import (
|
29 |
listdir_nohidden,
|
30 |
get_restore_step,
|
31 |
copy_test_subset_data,
|
32 |
)
|
33 |
import wandb
|
34 |
-
from audioldm_train.utilities.model_util import instantiate_from_config
|
35 |
import logging
|
36 |
|
37 |
logging.basicConfig(level=logging.WARNING)
|
@@ -75,7 +75,7 @@ def main(configs, config_yaml_path, exp_group_name, exp_name, perform_validation
|
|
75 |
|
76 |
#try:
|
77 |
mos_path = configs["mos_path"]
|
78 |
-
from audioldm_train.utilities.data.hhhh import AudioDataset
|
79 |
dataset = AudioDataset(config=configs, lmdb_path=train_lmdb_path, key_path=train_key_path, mos_path=mos_path)
|
80 |
|
81 |
|
|
|
15 |
from pytorch_lightning.strategies.ddp import DDPStrategy
|
16 |
|
17 |
|
18 |
+
from qa_mdt.audioldm_train.modules.latent_diffusion.ddpm import LatentDiffusion
|
19 |
|
20 |
|
21 |
from torch.utils.data import WeightedRandomSampler
|
|
|
25 |
from pytorch_lightning.loggers import WandbLogger
|
26 |
|
27 |
|
28 |
+
from qa_mdt.audioldm_train.utilities.tools import (
|
29 |
listdir_nohidden,
|
30 |
get_restore_step,
|
31 |
copy_test_subset_data,
|
32 |
)
|
33 |
import wandb
|
34 |
+
from qa_mdt.audioldm_train.utilities.model_util import instantiate_from_config
|
35 |
import logging
|
36 |
|
37 |
logging.basicConfig(level=logging.WARNING)
|
|
|
75 |
|
76 |
#try:
|
77 |
mos_path = configs["mos_path"]
|
78 |
+
from qa_mdt.audioldm_train.utilities.data.hhhh import AudioDataset
|
79 |
dataset = AudioDataset(config=configs, lmdb_path=train_lmdb_path, key_path=train_key_path, mos_path=mos_path)
|
80 |
|
81 |
|