|
""" |
|
Reference Repo: https://github.com/facebookresearch/AudioMAE |
|
""" |
|
|
|
import torch |
|
import torch.nn as nn |
|
from timm.models.layers import to_2tuple |
|
import audioldm_train.modules.audiomae.models_vit as models_vit |
|
import audioldm_train.modules.audiomae.models_mae as models_mae |
|
|
|
|
|
|
|
|
|
class PatchEmbed_new(nn.Module):
    """Flexible image-to-patch embedding built on a strided 2-D convolution.

    Args:
        img_size: Input height/width; a scalar is expanded with ``to_2tuple``.
        patch_size: Convolution kernel size (scalar or pair).
        in_chans: Number of input channels of the projection.
        embed_dim: Number of output channels, i.e. the width of each patch
            embedding.
        stride: Convolution stride (scalar or pair). It is independent of
            ``patch_size``, so the two may differ.

    Attributes:
        img_size, patch_size: The 2-tuple forms of the constructor arguments.
        proj: The ``nn.Conv2d`` that performs the patch projection.
        patch_hw: ``(h, w)`` grid of patches produced for ``img_size``.
        num_patches: ``h * w``.
    """

    def __init__(
        self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, stride=10
    ):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        stride = to_2tuple(stride)

        self.img_size = img_size
        self.patch_size = patch_size

        self.proj = nn.Conv2d(
            in_chans, embed_dim, kernel_size=patch_size, stride=stride
        )

        # Derive the patch grid by actually running the projection once, so the
        # result is right for any kernel/stride combination.
        _, _, h, w = self.get_output_shape(img_size)
        self.patch_hw = (h, w)
        self.num_patches = h * w

    def get_output_shape(self, img_size):
        """Return the shape of ``self.proj`` applied to one dummy image.

        Args:
            img_size: ``(height, width)`` of the dummy input.

        Returns:
            ``torch.Size`` of the form ``(1, embed_dim, h, w)``.
        """
        weight = self.proj.weight
        # The probe must have as many channels as the projection expects; a
        # fixed single-channel probe fails for every in_chans other than 1.
        probe = torch.randn(
            1,
            self.proj.in_channels,
            img_size[0],
            img_size[1],
            dtype=weight.dtype,
            device=weight.device,
        )
        # Only the shape is needed, so skip autograd bookkeeping.
        with torch.no_grad():
            return self.proj(probe).shape

    def forward(self, x):
        """Project ``x`` to a sequence of patch embeddings.

        Args:
            x: Tensor of shape ``[B, C, H, W]``.

        Returns:
            Tensor of shape ``[B, h * w, embed_dim]`` where ``(h, w)`` is the
            spatial size of the convolution output.

        Raises:
            ValueError: If ``x`` is not 4-dimensional.
        """
        if x.dim() != 4:
            raise ValueError(
                f"expected a 4-D input [B, C, H, W], got shape {tuple(x.shape)}"
            )
        x = self.proj(x)
        # [B, E, h, w] -> [B, E, h*w] -> [B, h*w, E]
        return x.flatten(2).transpose(1, 2)
|
|
|
|
|
class AudioMAE(nn.Module):
    """Audio Masked Autoencoder (MAE) pre-trained and finetuned on AudioSet (for SoundCLIP)

    Wraps the ``vit_base_patch16`` model from ``models_vit``, swaps in a
    single-channel patch embedding for 1024 x 128 inputs, and loads weights
    from a checkpoint file.

    Args:
        checkpoint_path: Path to the checkpoint passed to ``torch.load``. The
            file must contain a ``"model"`` entry holding the state dict. The
            default is the location the original code was hard-wired to.
    """

    def __init__(
        self,
        checkpoint_path=(
            "/mnt/bn/data-xubo/project/Masked_AudioEncoder/checkpoint/finetuned.pth"
        ),
    ):
        super().__init__()
        model = models_vit.__dict__["vit_base_patch16"](
            num_classes=527,
            drop_path_rate=0.1,
            global_pool=True,
            mask_2d=True,
            use_custom_patch=False,
        )

        img_size = (1024, 128)
        emb_dim = 768

        # Replace the stock patch embedding with a 1-channel, non-overlapping
        # 16x16 one sized for the (1024, 128) input.
        model.patch_embed = PatchEmbed_new(
            img_size=img_size,
            patch_size=(16, 16),
            in_chans=1,
            embed_dim=emb_dim,
            stride=16,
        )
        num_patches = model.patch_embed.num_patches

        # Frozen, zero-initialised positional embedding with one extra slot
        # beyond the patch count; presumably overwritten by the checkpoint
        # below -- verify against the checkpoint's keys.
        model.pos_embed = nn.Parameter(
            torch.zeros(1, num_patches + 1, emb_dim), requires_grad=False
        )

        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        # strict=False: keys missing from or unknown to the model are skipped
        # instead of raising.
        model.load_state_dict(checkpoint["model"], strict=False)

        self.model = model

    def forward(self, x, mask_t_prob=0.0, mask_f_prob=0.0):
        """
        x: mel fbank [Batch, 1, T, F]
        mask_t_prob: 'T masking ratio (percentage of removed patches).'
        mask_f_prob: 'F masking ratio (percentage of removed patches).'

        Returns whatever the wrapped model returns for these arguments.
        """
        return self.model(x=x, mask_t_prob=mask_t_prob, mask_f_prob=mask_f_prob)
|
|
|
|
|
class Vanilla_AudioMAE(nn.Module):
    """Audio Masked Autoencoder (MAE) pre-trained on AudioSet (for AudioLDM)

    Wraps the ``mae_vit_base_patch16`` model from ``models_mae`` (single
    channel, ``audio_exp=True``, 1024 x 128 input), loads weights from a
    checkpoint and keeps the model in eval mode.

    Args:
        checkpoint_path: Path to the checkpoint passed to ``torch.load``. The
            file must contain a ``"model"`` entry holding the state dict. The
            default is the relative path the original code was hard-wired to.
    """

    def __init__(
        self,
        checkpoint_path="data/checkpoints/audiomae_16k_128bins.ckpt",
    ):
        super().__init__()
        model = models_mae.__dict__["mae_vit_base_patch16"](
            in_chans=1, audio_exp=True, img_size=(1024, 128)
        )

        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        # strict=False: keys missing from or unknown to the model are skipped
        # instead of raising.
        model.load_state_dict(checkpoint["model"], strict=False)

        self.model = model.eval()

    def forward(self, x, mask_ratio=0.0, no_mask=False, no_average=False):
        """
        x: mel fbank [Batch, 1, 1024 (T), 128 (F)]
        mask_ratio: 'masking ratio (percentage of removed patches).'
            Kept for interface compatibility; the masked path is deprecated,
            so the value is not used.
        no_mask: must be true; the random-mask path is deprecated.
        no_average: must be false; the no-average path is deprecated.

        Returns the output of ``self.model.forward_encoder_no_mask(x)``,
        computed without gradient tracking.

        Raises RuntimeError("This function is deprecated") for every argument
        combination other than ``no_mask`` true and ``no_average`` false.
        """
        # Only the unmasked encoder path is still supported; the statements
        # that used to follow the two raises could never execute and are gone.
        if not no_mask or no_average:
            raise RuntimeError("This function is deprecated")

        with torch.no_grad():
            return self.model.forward_encoder_no_mask(x)
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: encode the same random batch twice through the unmasked path.
    # Use the GPU when there is one, otherwise fall back to the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = Vanilla_AudioMAE().to(device)
    fbank = torch.randn(4, 1, 1024, 128).to(device)

    print("The first run")
    embed = model(fbank, mask_ratio=0.0, no_mask=True)
    print(embed)

    # The masked path (no_mask=False) now always raises "This function is
    # deprecated", so the second run must also request the unmasked path.
    print("The second run")
    embed = model(fbank, mask_ratio=0.0, no_mask=True)
    print(embed)
|
|