"""
Reference Repo: https://github.com/facebookresearch/AudioMAE
"""
import torch
import torch.nn as nn
from timm.models.layers import to_2tuple
import audioldm_train.modules.audiomae.models_vit as models_vit
import audioldm_train.modules.audiomae.models_mae as models_mae
# model = mae_vit_base_patch16(in_chans=1, audio_exp=True, img_size=(1024, 128))
class PatchEmbed_new(nn.Module):
"""Flexible Image to Patch Embedding"""
def __init__(
self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, stride=10
):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
stride = to_2tuple(stride)
self.img_size = img_size
self.patch_size = patch_size
self.proj = nn.Conv2d(
in_chans, embed_dim, kernel_size=patch_size, stride=stride
) # with overlapped patches
# self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
# self.patch_hw = (img_size[1] // patch_size[1], img_size[0] // patch_size[0])
# self.num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
_, _, h, w = self.get_output_shape(img_size) # n, emb_dim, h, w
self.patch_hw = (h, w)
self.num_patches = h * w
    def get_output_shape(self, img_size):
        # Infer the conv output grid by pushing a dummy single-channel input
        # through the projection (assumes in_chans == 1, as used below).
        return self.proj(torch.randn(1, 1, img_size[0], img_size[1])).shape
def forward(self, x):
B, C, H, W = x.shape
# FIXME look at relaxing size constraints
# assert H == self.img_size[0] and W == self.img_size[1], \
# f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
x = self.proj(x)
x = x.flatten(2).transpose(1, 2)
return x
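
# Shape sketch (illustrative, not part of the original module): with the
# configuration used below (img_size=(1024, 128), patch_size=(16, 16),
# in_chans=1, embed_dim=768, stride=16), the projection yields a
# (1, 768, 64, 8) feature map, so num_patches = 64 * 8 = 512 and forward()
# maps a [B, 1, 1024, 128] mel fbank to patch tokens of shape [B, 512, 768]:
#
#     patch_embed = PatchEmbed_new(
#         img_size=(1024, 128), patch_size=(16, 16), in_chans=1, embed_dim=768, stride=16
#     )
#     tokens = patch_embed(torch.randn(2, 1, 1024, 128))  # torch.Size([2, 512, 768])
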
class AudioMAE(nn.Module):
"""Audio Masked Autoencoder (MAE) pre-trained and finetuned on AudioSet (for SoundCLIP)"""
def __init__(
self,
):
super().__init__()
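        # ViT-Base encoder with a 527-way classification head (AudioSet has 527
        # classes); mask_2d=True enables separate masking along time and frequency.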
model = models_vit.__dict__["vit_base_patch16"](
num_classes=527,
drop_path_rate=0.1,
global_pool=True,
mask_2d=True,
use_custom_patch=False,
)
img_size = (1024, 128)
emb_dim = 768
model.patch_embed = PatchEmbed_new(
img_size=img_size,
patch_size=(16, 16),
in_chans=1,
embed_dim=emb_dim,
stride=16,
)
num_patches = model.patch_embed.num_patches
# num_patches = 512 # assume audioset, 1024//16=64, 128//16=8, 512=64x8
model.pos_embed = nn.Parameter(
torch.zeros(1, num_patches + 1, emb_dim), requires_grad=False
) # fixed sin-cos embedding
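        # NOTE: machine-specific hard-coded path; point this at a local copy of the
        # AudioSet-finetuned AudioMAE checkpoint before instantiating this class.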
checkpoint_path = (
"/mnt/bn/data-xubo/project/Masked_AudioEncoder/checkpoint/finetuned.pth"
)
checkpoint = torch.load(checkpoint_path, map_location="cpu")
msg = model.load_state_dict(checkpoint["model"], strict=False)
# print(f'Load AudioMAE from {checkpoint_path} / message: {msg}')
self.model = model
def forward(self, x, mask_t_prob=0.0, mask_f_prob=0.0):
"""
x: mel fbank [Batch, 1, T, F]
mask_t_prob: 'T masking ratio (percentage of removed patches).'
mask_f_prob: 'F masking ratio (percentage of removed patches).'
"""
return self.model(x=x, mask_t_prob=mask_t_prob, mask_f_prob=mask_f_prob)
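
# Usage sketch (illustrative; assumes the finetuned checkpoint above is available
# at the hard-coded path):
#
#     mae = AudioMAE().eval()
#     fbank = torch.randn(2, 1, 1024, 128)  # [Batch, 1, T, F] mel fbank
#     logits = mae(fbank, mask_t_prob=0.0, mask_f_prob=0.0)  # AudioSet logits, [2, 527]
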
class Vanilla_AudioMAE(nn.Module):
"""Audio Masked Autoencoder (MAE) pre-trained on AudioSet (for AudioLDM)"""
def __init__(
self,
):
super().__init__()
model = models_mae.__dict__["mae_vit_base_patch16"](
in_chans=1, audio_exp=True, img_size=(1024, 128)
)
checkpoint_path = "data/checkpoints/audiomae_16k_128bins.ckpt"
checkpoint = torch.load(checkpoint_path, map_location="cpu")
msg = model.load_state_dict(checkpoint["model"], strict=False)
# Skip the missing keys of decoder modules (not required)
# print(f'Load AudioMAE from {checkpoint_path} / message: {msg}')
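        # The MAE is used as a frozen feature extractor: eval mode here, together
        # with torch.no_grad() in forward(), keeps its weights and statistics fixed.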
self.model = model.eval()
def forward(self, x, mask_ratio=0.0, no_mask=False, no_average=False):
"""
x: mel fbank [Batch, 1, 1024 (T), 128 (F)]
mask_ratio: 'masking ratio (percentage of removed patches).'
"""
with torch.no_grad():
# embed: [B, 513, 768] for mask_ratio=0.0
            if no_mask:
                if no_average:
                    # Deprecated path: per-patch (non-averaged) embeddings are not supported.
                    raise RuntimeError("This function is deprecated")
                    embed = self.model.forward_encoder_no_random_mask_no_average(
                        x
                    )  # unreachable; kept for reference
                else:
                    # Encode without random masking; returns all patch tokens plus the CLS token.
                    embed = self.model.forward_encoder_no_mask(x)
            else:
                # Deprecated path: the randomly-masked encoder is not used here.
                raise RuntimeError("This function is deprecated")
                embed, _, _, _ = self.model.forward_encoder(
                    x, mask_ratio=mask_ratio
                )  # unreachable; kept for reference
return embed
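
# Note: for a [B, 1, 1024, 128] input the no-mask encoder returns [B, 513, 768]
# tokens: 512 patch tokens (1024/16 * 128/16 = 64 * 8) plus the prepended CLS
# token, each with the ViT-Base embedding width of 768.
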
if __name__ == "__main__":
model = Vanilla_AudioMAE().cuda()
input = torch.randn(4, 1, 1024, 128).cuda()
print("The first run")
embed = model(input, mask_ratio=0.0, no_mask=True)
print(embed)
print("The second run")
embed = model(input, mask_ratio=0.0)
print(embed)