# dmae.py — dmae1d-ATC64-v1
# Uploaded as "Upload DMAE1d" by flavioschneider (commit 0788506).
from transformers import PreTrainedModel
from audio_encoders_pytorch import TanhBottleneck
from audio_diffusion_pytorch import UniformDistribution, LinearSchedule, VSampler, DiffusionMAE1d
from .dmae_config import DMAE1dConfig
# Registry mapping the config's bottleneck name to its implementation class.
bottleneck = dict(tanh=TanhBottleneck)
class DMAE1d(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around ``DiffusionMAE1d``.

    All architecture hyperparameters come from ``DMAE1dConfig``; the
    remaining diffusion/U-Net settings are fixed below. ``forward``,
    ``encode`` and ``decode`` delegate straight to the wrapped model.
    """

    config_class = DMAE1dConfig

    def __init__(self, config: DMAE1dConfig):
        super().__init__(config)
        # Settings that are not exposed through the config and therefore
        # fixed for every checkpoint of this architecture.
        fixed_kwargs = dict(
            diffusion_type='v',
            diffusion_sigma_distribution=UniformDistribution(),
            resnet_groups=8,
            kernel_multiplier_downsample=2,
            use_nearest_upsample=False,
            use_skip_scale=True,
            use_context_time=True,
            patch_factor=1,
            patch_blocks=1,
        )
        self.model = DiffusionMAE1d(
            in_channels=config.in_channels,
            channels=config.channels,
            multipliers=config.multipliers,
            factors=config.factors,
            num_blocks=config.num_blocks,
            attentions=config.attentions,
            encoder_inject_depth=config.encoder_inject_depth,
            encoder_channels=config.encoder_channels,
            encoder_factors=config.encoder_factors,
            encoder_multipliers=config.encoder_multipliers,
            encoder_num_blocks=config.encoder_num_blocks,
            # Look up the bottleneck class by name and instantiate it.
            bottleneck=bottleneck[config.bottleneck](),
            stft_use_complex=config.stft_use_complex,
            stft_num_fft=config.stft_num_fft,
            stft_hop_length=config.stft_hop_length,
            **fixed_kwargs,
        )

    def forward(self, *args, **kwargs):
        """Delegate directly to the wrapped ``DiffusionMAE1d``."""
        return self.model(*args, **kwargs)

    def encode(self, *args, **kwargs):
        """Delegate to ``DiffusionMAE1d.encode``."""
        return self.model.encode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """Delegate to ``DiffusionMAE1d.decode``.

        Supplies defaults for ``sigma_schedule`` (linear), ``sampler``
        (v-objective) and ``clamp`` (True); any keyword argument passed
        by the caller takes precedence over these defaults.
        """
        kwargs.setdefault('sigma_schedule', LinearSchedule())
        kwargs.setdefault('sampler', VSampler())
        kwargs.setdefault('clamp', True)
        return self.model.decode(*args, **kwargs)