from transformers import PreTrainedModel

from audio_encoders_pytorch import TanhBottleneck
from audio_diffusion_pytorch import UniformDistribution, LinearSchedule, VSampler, DiffusionMAE1d

from .dmae_config import DMAE1dConfig

# Registry mapping the `bottleneck` string in DMAE1dConfig to a bottleneck class.
bottleneck = {'tanh': TanhBottleneck}
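
# Only the tanh bottleneck is wired up above. If the installed version of
# audio_encoders_pytorch ships additional bottlenecks (an assumption; check the
# version you have installed), they could be registered the same way, e.g.:
#
#     from audio_encoders_pytorch import VariationalBottleneck
#     bottleneck['variational'] = VariationalBottleneck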


class DMAE1d(PreTrainedModel):

    config_class = DMAE1dConfig

    def __init__(self, config: DMAE1dConfig):
        super().__init__(config)

        # Underlying diffusion masked autoencoder. Most hyperparameters come from
        # the Hugging Face config; the remaining arguments are fixed architecture
        # and diffusion settings (v-diffusion with a uniform sigma distribution).
        self.model = DiffusionMAE1d(
            in_channels=config.in_channels,
            channels=config.channels,
            multipliers=config.multipliers,
            factors=config.factors,
            num_blocks=config.num_blocks,
            attentions=config.attentions,
            encoder_inject_depth=config.encoder_inject_depth,
            encoder_channels=config.encoder_channels,
            encoder_factors=config.encoder_factors,
            encoder_multipliers=config.encoder_multipliers,
            encoder_num_blocks=config.encoder_num_blocks,
            bottleneck=bottleneck[config.bottleneck](),
            stft_use_complex=config.stft_use_complex,
            stft_num_fft=config.stft_num_fft,
            stft_hop_length=config.stft_hop_length,
            diffusion_type='v',
            diffusion_sigma_distribution=UniformDistribution(),
            resnet_groups=8,
            kernel_multiplier_downsample=2,
            use_nearest_upsample=False,
            use_skip_scale=True,
            use_context_time=True,
            patch_factor=1,
            patch_blocks=1,
        )

    def forward(self, *args, **kwargs):
        return self.model(*args, **kwargs)

    def encode(self, *args, **kwargs):
        return self.model.encode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        # Default sampling settings; callers can override any of them, since the
        # merge below gives precedence to explicitly passed kwargs.
        default_kwargs = dict(
            sigma_schedule=LinearSchedule(),
            sampler=VSampler(),
            clamp=True,
        )
        return self.model.decode(*args, **{**default_kwargs, **kwargs})
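

# Usage sketch (not part of the model definition): how this wrapper is typically
# driven once a checkpoint exported with this class is available on the Hugging
# Face Hub. The repository id, tensor shape, and `num_steps` value below are
# illustrative placeholders, not values taken from this repository.
#
#     import torch
#     from transformers import AutoModel
#
#     model = AutoModel.from_pretrained("<user>/<dmae1d-checkpoint>", trust_remote_code=True)
#
#     audio = torch.randn(1, 2, 2 ** 18)          # (batch, channels, samples)
#     latent = model.encode(audio)                # compressed latent representation
#     recon = model.decode(latent, num_steps=20)  # sampling steps; merged with decode() defaults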