jadechoghari
commited on
Commit
•
2a54e4f
1
Parent(s):
c57d9cf
Update audioldm_train/modules/latent_encoder/autoencoder.py
Browse files
audioldm_train/modules/latent_encoder/autoencoder.py
CHANGED
@@ -6,20 +6,20 @@ import pytorch_lightning as pl
|
|
6 |
import torch.nn.functional as F
|
7 |
from contextlib import contextmanager
|
8 |
import numpy as np
|
9 |
-
from audioldm_train.modules.diffusionmodules.ema import *
|
10 |
|
11 |
from torch.optim.lr_scheduler import LambdaLR
|
12 |
-
from audioldm_train.modules.diffusionmodules.model import Encoder, Decoder
|
13 |
-
from audioldm_train.modules.diffusionmodules.distributions import (
|
14 |
DiagonalGaussianDistribution,
|
15 |
)
|
16 |
|
17 |
import wandb
|
18 |
-
from audioldm_train.utilities.model_util import instantiate_from_config
|
19 |
import soundfile as sf
|
20 |
|
21 |
-
from audioldm_train.utilities.model_util import get_vocoder
|
22 |
-
from audioldm_train.utilities.tools import synth_one_sample
|
23 |
import itertools
|
24 |
|
25 |
|
@@ -150,7 +150,7 @@ class AutoencoderKL(pl.LightningModule):
|
|
150 |
return dec
|
151 |
|
152 |
def decode_to_waveform(self, dec):
|
153 |
-
from audioldm_train.utilities.model_util import vocoder_infer
|
154 |
|
155 |
if self.image_key == "fbank":
|
156 |
dec = dec.squeeze(1).permute(0, 2, 1)
|
|
|
6 |
import torch.nn.functional as F
|
7 |
from contextlib import contextmanager
|
8 |
import numpy as np
|
9 |
+
from qa_mdt.audioldm_train.modules.diffusionmodules.ema import *
|
10 |
|
11 |
from torch.optim.lr_scheduler import LambdaLR
|
12 |
+
from qa_mdt.audioldm_train.modules.diffusionmodules.model import Encoder, Decoder
|
13 |
+
from qa_mdt.audioldm_train.modules.diffusionmodules.distributions import (
|
14 |
DiagonalGaussianDistribution,
|
15 |
)
|
16 |
|
17 |
import wandb
|
18 |
+
from qa_mdt.audioldm_train.utilities.model_util import instantiate_from_config
|
19 |
import soundfile as sf
|
20 |
|
21 |
+
from qa_mdt.audioldm_train.utilities.model_util import get_vocoder
|
22 |
+
from qa_mdt.audioldm_train.utilities.tools import synth_one_sample
|
23 |
import itertools
|
24 |
|
25 |
|
|
|
150 |
return dec
|
151 |
|
152 |
def decode_to_waveform(self, dec):
|
153 |
+
from qa_mdt.audioldm_train.utilities.model_util import vocoder_infer
|
154 |
|
155 |
if self.image_key == "fbank":
|
156 |
dec = dec.squeeze(1).permute(0, 2, 1)
|