EzAudio-ControlNet / src /models /conditions /condition_wrapper.py
OpenSound's picture
Upload 211 files
9d3cb0a verified
raw
history blame
1.79 kB
import torch
import torch.nn as nn
from .chroma import ChromaExtractor
from .energy import EnergyExtractor
from .voice import VoiceConversionExtractor
from .mbenergy import MultibandEnergyExtractor
class Conditioner(nn.Module):
def __init__(self,
condition_type,
**kwargs
):
super().__init__()
if condition_type == 'energy':
self.conditioner = EnergyExtractor(**kwargs)
elif condition_type == 'chroma':
self.conditioner = ChromaExtractor(**kwargs)
elif condition_type == 'vc':
self.conditioner = VoiceConversionExtractor(**kwargs)
elif condition_type == 'mb_energy':
self.conditioner = MultibandEnergyExtractor(**kwargs)
else:
raise NotImplementedError
def forward(self, waveform, latent_shape):
# B T C
condition = self.conditioner(waveform)
# B C T
condition = condition.permute(0, 2, 1).contiguous()
if len(latent_shape) == 4:
# 2d spectrogram B C T F
assert (condition.shape[-1] % latent_shape[-2]) == 0
X = latent_shape[-1] * condition.shape[-1] // latent_shape[-2]
# copy on F direction
condition = condition.unsqueeze(-1).expand(-1, -1, -1, X)
elif len(latent_shape) == 3:
condition = condition
else:
raise NotImplementedError
return condition
if __name__ == '__main__':
conditioner = Conditioner(condition_type='energy',
hop_size=160, window_size=1024, padding='reflect',
min_db=-80, norm=True)
audio = torch.rand(4, 16000) # Example audio signal
energy = conditioner(audio, (4, 8, 100, 64))