"""
Reference Repo: https://github.com/facebookresearch/AudioMAE
"""

import torch
import torch.nn as nn
from timm.models.layers import to_2tuple
import qa_mdt.audioldm_train.modules.audiomae.models_vit as models_vit
import qa_mdt.audioldm_train.modules.audiomae.models_mae as models_mae

# model = mae_vit_base_patch16(in_chans=1, audio_exp=True, img_size=(1024, 128))


class PatchEmbed_new(nn.Module):
    """Flexible Image to Patch Embedding"""

    def __init__(
        self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, stride=10
    ):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        stride = to_2tuple(stride)

        self.img_size = img_size
        self.patch_size = patch_size

        self.proj = nn.Conv2d(
            in_chans, embed_dim, kernel_size=patch_size, stride=stride
        )  # overlapped patches when stride < patch_size
        # self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

        # Infer the patch grid with a dummy forward pass rather than the closed-form
        # H // patch * W // patch, so overlapping strides are also handled correctly.
        _, _, h, w = self.get_output_shape(img_size)  # n, emb_dim, h, w
        self.patch_hw = (h, w)
        self.num_patches = h * w

    def get_output_shape(self, img_size):
        # Run a dummy batch through the projection to get (n, embed_dim, h, w).
        return self.proj(
            torch.randn(1, self.proj.in_channels, img_size[0], img_size[1])
        ).shape

    def forward(self, x):
        B, C, H, W = x.shape
        # FIXME look at relaxing size constraints
        # assert H == self.img_size[0] and W == self.img_size[1], \
        #    f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x)  # [B, C, H, W] -> [B, embed_dim, h, w]
        x = x.flatten(2).transpose(1, 2)  # -> [B, h*w, embed_dim]
        return x

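# Example (illustrative sketch, not executed on import): with the AudioMAE
# configuration used below, a 1024x128 mel fbank split into 16x16 patches at
# stride 16 yields a 64x8 grid, i.e. 512 patch tokens of dimension 768.
#
# patch_embed = PatchEmbed_new(
#     img_size=(1024, 128), patch_size=(16, 16), in_chans=1, embed_dim=768, stride=16
# )
# tokens = patch_embed(torch.randn(2, 1, 1024, 128))  # -> [2, 512, 768]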

class AudioMAE(nn.Module):
    """Audio Masked Autoencoder (MAE) pre-trained and finetuned on AudioSet (for SoundCLIP)"""

    def __init__(
        self,
    ):
        super().__init__()
        model = models_vit.__dict__["vit_base_patch16"](
            num_classes=527,
            drop_path_rate=0.1,
            global_pool=True,
            mask_2d=True,
            use_custom_patch=False,
        )

        img_size = (1024, 128)
        emb_dim = 768

        model.patch_embed = PatchEmbed_new(
            img_size=img_size,
            patch_size=(16, 16),
            in_chans=1,
            embed_dim=emb_dim,
            stride=16,
        )
        num_patches = model.patch_embed.num_patches
        # num_patches = 512 # assume audioset, 1024//16=64, 128//16=8, 512=64x8
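        # The +1 in (num_patches + 1) accounts for the [CLS] token prepended by the ViT encoder.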
        model.pos_embed = nn.Parameter(
            torch.zeros(1, num_patches + 1, emb_dim), requires_grad=False
        )  # fixed sin-cos embedding

        checkpoint_path = (
            "/mnt/bn/data-xubo/project/Masked_AudioEncoder/checkpoint/finetuned.pth"
        )
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        msg = model.load_state_dict(checkpoint["model"], strict=False)
        # print(f'Load AudioMAE from {checkpoint_path} / message: {msg}')

        self.model = model

    def forward(self, x, mask_t_prob=0.0, mask_f_prob=0.0):
        """
        x: mel fbank [Batch, 1, T, F]
        mask_t_prob: 'T masking ratio (percentage of removed patches).'
        mask_f_prob: 'F masking ratio (percentage of removed patches).'
        """
        return self.model(x=x, mask_t_prob=mask_t_prob, mask_f_prob=mask_f_prob)

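# Example (illustrative sketch): AudioMAE takes a batch of mel fbanks and is
# expected to return the finetuned encoder's output (class logits over the 527
# AudioSet classes in this configuration). Note that __init__ loads a hard-coded
# checkpoint path, which must exist locally before the wrapper can be built.
#
# audiomae = AudioMAE().eval()
# out = audiomae(torch.randn(2, 1, 1024, 128), mask_t_prob=0.0, mask_f_prob=0.0)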

class Vanilla_AudioMAE(nn.Module):
    """Audio Masked Autoencoder (MAE) pre-trained on AudioSet (for AudioLDM)"""

    def __init__(
        self,
    ):
        super().__init__()
        model = models_mae.__dict__["mae_vit_base_patch16"](
            in_chans=1, audio_exp=True, img_size=(1024, 128)
        )

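        # Relative path, resolved against the current working directory at load time.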
        checkpoint_path = "data/checkpoints/audiomae_16k_128bins.ckpt"
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        msg = model.load_state_dict(checkpoint["model"], strict=False)

        # Skip the missing keys of decoder modules (not required)
        # print(f'Load AudioMAE from {checkpoint_path} / message: {msg}')

        self.model = model.eval()

    def forward(self, x, mask_ratio=0.0, no_mask=False, no_average=False):
        """
        x: mel fbank [Batch, 1, 1024 (T), 128 (F)]
        mask_ratio: masking ratio (fraction of removed patches)
        """
        with torch.no_grad():
            # embed: [B, 513, 768] for mask_ratio=0.0 (512 patch tokens + 1 [CLS] token)
            if no_mask:
                if no_average:
                    # Deprecated path, kept for reference only:
                    # embed = self.model.forward_encoder_no_random_mask_no_average(x)
                    raise RuntimeError("This function is deprecated")
                else:
                    embed = self.model.forward_encoder_no_mask(x)
            else:
                # Deprecated path, kept for reference only:
                # embed, _, _, _ = self.model.forward_encoder(x, mask_ratio=mask_ratio)
                raise RuntimeError("This function is deprecated")
        return embed

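# Input preparation (illustrative sketch; the exact preprocessing is an assumption
# based on the reference AudioMAE recipe and is not defined in this file): the model
# expects a 16 kHz, 128-bin kaldi-style mel fbank cropped or padded to 1024 frames,
# normalized with the dataset statistics used at pretraining time (not shown here).
#
# import torchaudio
# waveform, sr = torchaudio.load("example.wav")  # hypothetical input file
# fbank = torchaudio.compliance.kaldi.fbank(
#     waveform,
#     htk_compat=True,
#     sample_frequency=sr,
#     num_mel_bins=128,
#     window_type="hanning",
#     use_energy=False,
# )  # [num_frames, 128]
# fbank = fbank[:1024].unsqueeze(0).unsqueeze(0)  # -> [1, 1, T<=1024, 128]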

if __name__ == "__main__":
    model = Vanilla_AudioMAE().cuda()
    dummy_input = torch.randn(4, 1, 1024, 128).cuda()

    print("The first run")
    embed = model(dummy_input, mask_ratio=0.0, no_mask=True)
    print(embed)

    print("The second run")
    # no_mask defaults to False, so this call hits the deprecated masked-encoder
    # path above and raises RuntimeError.
    embed = model(dummy_input, mask_ratio=0.0)
    print(embed)