File size: 2,006 Bytes
0788506
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
from transformers import PreTrainedModel
from audio_encoders_pytorch import TanhBottleneck
from audio_diffusion_pytorch import UniformDistribution, LinearSchedule, VSampler, DiffusionMAE1d
from .dmae_config import DMAE1dConfig

# Registry mapping the config's `bottleneck` string to a bottleneck class;
# DMAE1d looks up config.bottleneck here and instantiates the result.
# Extend this dict to support additional bottleneck types.
bottleneck = { 'tanh': TanhBottleneck }

class DMAE1d(PreTrainedModel):
    """HuggingFace ``PreTrainedModel`` wrapper around a 1D diffusion masked
    autoencoder.

    Thin adapter: builds a ``DiffusionMAE1d`` from a ``DMAE1dConfig`` and
    forwards ``forward``/``encode``/``decode`` calls straight to it.
    """

    config_class = DMAE1dConfig

    def __init__(self, config: DMAE1dConfig):
        super().__init__(config)

        # Assemble the underlying diffusion autoencoder from the config.
        self.model = DiffusionMAE1d(
            # U-Net / encoder topology (all taken from the config).
            in_channels=config.in_channels,
            channels=config.channels,
            multipliers=config.multipliers,
            factors=config.factors,
            num_blocks=config.num_blocks,
            attentions=config.attentions,
            encoder_inject_depth=config.encoder_inject_depth,
            encoder_channels=config.encoder_channels,
            encoder_factors=config.encoder_factors,
            encoder_multipliers=config.encoder_multipliers,
            encoder_num_blocks=config.encoder_num_blocks,
            # Resolve the bottleneck class by name and instantiate it.
            bottleneck=bottleneck[config.bottleneck](),
            # STFT front-end settings.
            stft_use_complex=config.stft_use_complex,
            stft_num_fft=config.stft_num_fft,
            stft_hop_length=config.stft_hop_length,
            # Fixed v-objective diffusion setup and architecture constants.
            diffusion_type='v',
            diffusion_sigma_distribution=UniformDistribution(),
            resnet_groups=8,
            kernel_multiplier_downsample=2,
            use_nearest_upsample=False,
            use_skip_scale=True,
            use_context_time=True,
            patch_factor=1,
            patch_blocks=1,
        )

    def forward(self, *args, **kwargs):
        """Delegate directly to the wrapped model's forward pass."""
        return self.model(*args, **kwargs)

    def encode(self, *args, **kwargs):
        """Delegate to the wrapped model's encoder."""
        return self.model.encode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """Decode latents, supplying sampler defaults unless overridden."""
        options = dict(
            sigma_schedule=LinearSchedule(),
            sampler=VSampler(),
            clamp=True,
        )
        # Caller-supplied keyword arguments take precedence over the defaults.
        options.update(kwargs)
        return self.model.decode(*args, **options)