import torch
import torch.nn as nn
import pytorch_lightning as pl
from audioldm_train.utilities.model_util import (
    exists,
    default,
    mean_flat,
    count_params,
    instantiate_from_config,
)

from transformers import GPT2Config, GPT2Model
import torch.optim.lr_scheduler as lr_scheduler


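# Simple feed-forward pre-net (a stack of Linear -> ReLU -> Dropout layers).
# It is defined here but not used by the CLAP2AudioMAE module below.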
class Prenet(nn.Module):
    def __init__(self, in_dim, sizes=[256, 128], dropout_rate=0.5):
        super(Prenet, self).__init__()
        in_sizes = [in_dim] + sizes[:-1]
        self.layers = nn.ModuleList(
            [
                nn.Linear(in_size, out_size)
                for (in_size, out_size) in zip(in_sizes, sizes)
            ]
        )
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, inputs):
        for linear in self.layers:
            inputs = self.dropout(self.relu(linear(inputs)))
        return inputs


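# Maps a single CLAP embedding to a fixed-length sequence of pooled AudioMAE
# features using a pretrained GPT-2 backbone. Training uses teacher forcing
# with an MSE loss; inference is autoregressive (see `generate`). The optional
# AudioMAE linear projection path is disabled (asserted off) in this version.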
class CLAP2AudioMAE(pl.LightningModule):
    def __init__(
        self,
        sequence_gen_length,
        base_learning_rate,
        cond_stage_config,
        use_audiomae_linear=False,
        **kwargs
    ):
        super().__init__()
        assert use_audiomae_linear == False
        self.learning_rate = base_learning_rate
        self.cond_stage_config = cond_stage_config
        self.use_audiomae_linear = use_audiomae_linear

        # Number of AudioMAE tokens to generate autoregressively.
        self.mae_token_num = sequence_gen_length

        self.cond_stage_models = nn.ModuleList([])
        self.instantiate_cond_stage(cond_stage_config)

        self.model = GPT2Model.from_pretrained("gpt2")

        # Project the 512-dim CLAP embedding to the 768-dim GPT-2 hidden size.
        self.linear_clap = nn.Linear(512, 768)

        if use_audiomae_linear:
            # Unreachable: use_audiomae_linear is asserted to be False above.
            self.linear_audiomae = None

        self.loss_fn = nn.MSELoss()

        self.logger_save_dir = None
        self.logger_exp_name = None
        self.logger_exp_group_name = None
        self.logger_version = None

    def set_log_dir(self, save_dir, exp_group_name, exp_name):
        self.logger_save_dir = save_dir
        self.logger_exp_group_name = exp_group_name
        self.logger_exp_name = exp_name

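    # Build the unconditional conditioning dictionary used for classifier-free
    # guidance: every conditioning model supplies its "empty" condition, and the
    # AudioMAE entry is reused for the CLAP-to-AudioMAE feature key.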
    def cfg_uncond(self, batch_size):
        unconditional_conditioning = {}
        for key in self.cond_stage_model_metadata:
            model_idx = self.cond_stage_model_metadata[key]["model_idx"]
            unconditional_conditioning[key] = self.cond_stage_models[
                model_idx
            ].get_unconditional_condition(batch_size)
        assert (
            "crossattn_audiomae_pooled" in unconditional_conditioning.keys()
        ), "The module is not initialized with AudioMAE"
        unconditional_conditioning[
            "crossattn_clap_to_audiomae_feature"
        ] = unconditional_conditioning["crossattn_audiomae_pooled"]
        return unconditional_conditioning

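    # AdamW over the GPT-2 backbone and the CLAP projection, with a StepLR
    # schedule that decays the learning rate by a factor of 0.9 per step.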
    def configure_optimizers(self):
        lr = float(self.learning_rate)
        params = list(self.model.parameters()) + list(self.linear_clap.parameters())

        if self.use_audiomae_linear:
            params += list(self.linear_audiomae.parameters())

        opt = torch.optim.AdamW(params, lr=lr)
        scheduler = lr_scheduler.StepLR(opt, step_size=1, gamma=0.9)
        return [opt], [scheduler]

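    # Teacher-forced training: the GPT-2 input is [projected CLAP token, AudioMAE
    # target tokens], so the hidden state at position i is trained to predict
    # target token i (hence the [:, :-1] shift on the output below).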
    def training_step(self, batch, batch_idx=None, cond_dict=None):
        if cond_dict is None:
            cond_dict = self.get_input(batch)

        input_embeds, target_embeds = (
            cond_dict["film_clap_cond1"],
            cond_dict["crossattn_audiomae_pooled"][0],
        )

        if "crossattn_audiomae_pooled_44" in cond_dict.keys():
            target_embeds = cond_dict["crossattn_audiomae_pooled_44"][0]

        if self.use_audiomae_linear:
            input_embeds = torch.cat(
                [self.linear_clap(input_embeds), self.linear_audiomae(target_embeds)],
                dim=1,
            )
        else:
            input_embeds = torch.cat(
                [self.linear_clap(input_embeds), target_embeds], dim=1
            )

        output_embeds = self.model(inputs_embeds=input_embeds)["last_hidden_state"]

        target = target_embeds
        output = output_embeds[:, :-1]

        loss = self.loss_fn(output, target)

        self.log(
            "train/loss_clap_2_audiomae",
            loss,
            prog_bar=True,
            logger=True,
            on_step=True,
            on_epoch=False,
            sync_dist=True,
        )

        self.log(
            "global_step_audiomae",
            float(self.global_step),
            prog_bar=True,
            logger=True,
            on_step=True,
            on_epoch=False,
            sync_dist=True,
        )

        return loss

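    # Autoregressive generation: start from the projected CLAP token and, for
    # `mae_token_num` steps, feed the growing sequence back through GPT-2 and
    # append the last hidden state. The leading CLAP token is dropped from the
    # returned sequence.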
    def generate(self, batch, cond_dict=None, no_grad=False):
        if cond_dict is None:
            cond_dict = self.get_input(batch)
        input_embeds = cond_dict["film_clap_cond1"]
        steps = self.mae_token_num

        if no_grad:
            with torch.no_grad():
                model_input = self.linear_clap(input_embeds)
                for _ in range(steps):
                    output = self.model(inputs_embeds=model_input)["last_hidden_state"]
                    model_input = torch.cat([model_input, output[:, -1:, :]], dim=1)
        else:
            model_input = self.linear_clap(input_embeds)
            for _ in range(steps):
                output = self.model(inputs_embeds=model_input)["last_hidden_state"]
                model_input = torch.cat([model_input, output[:, -1:, :]], dim=1)

        return model_input[:, 1:], cond_dict

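    # Validation mirrors the teacher-forced training loss and additionally
    # reports the MSE of a fully autoregressive generation against the same
    # target.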
    def validation_step(self, batch, batch_idx):
        cond_dict = self.get_input(batch)

        input_embeds, target_embeds = (
            cond_dict["film_clap_cond1"],
            cond_dict["crossattn_audiomae_pooled"][0],
        )

        if "crossattn_audiomae_pooled_44" in cond_dict.keys():
            target_embeds = cond_dict["crossattn_audiomae_pooled_44"][0]

        if self.use_audiomae_linear:
            input_embeds = torch.cat(
                [self.linear_clap(input_embeds), self.linear_audiomae(target_embeds)],
                dim=1,
            )
        else:
            input_embeds = torch.cat(
                [self.linear_clap(input_embeds), target_embeds], dim=1
            )

        output_embeds = self.model(inputs_embeds=input_embeds)["last_hidden_state"]

        target = target_embeds
        output = output_embeds[:, :-1]

        loss = self.loss_fn(output, target)

        self.log(
            "val/loss",
            loss,
            prog_bar=True,
            logger=True,
            on_step=True,
            sync_dist=True,
            on_epoch=True,
        )

        generation_output, _ = self.generate(batch)
        ar_gen_loss = self.loss_fn(generation_output, target)

        self.log(
            "val/ar_gen_loss",
            ar_gen_loss,
            prog_bar=True,
            logger=True,
            on_step=True,
            sync_dist=True,
            on_epoch=True,
        )

        return {"loss": loss, "ar_gen_loss": ar_gen_loss}

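    # Fetch a single key from the batch, normalising the tensor entries
    # (contiguous memory format, float) and passing everything else through
    # unchanged.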
    def get_input_item(self, batch, k):
        fname, text, label_indices, waveform, stft, fbank = (
            batch["fname"],
            batch["text"],
            batch["label_vector"],
            batch["waveform"],
            batch["stft"],
            batch["log_mel_spec"],
        )
        ret = {}

        ret["fbank"] = (
            fbank.unsqueeze(1).to(memory_format=torch.contiguous_format).float()
        )
        ret["stft"] = stft.to(memory_format=torch.contiguous_format).float()
        ret["waveform"] = waveform.to(memory_format=torch.contiguous_format).float()
        ret["text"] = list(text)
        ret["fname"] = fname

        for key in batch.keys():
            if key not in ret.keys():
                ret[key] = batch[key]

        return ret[k]

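    # Run every registered conditioning model on its batch key and collect the
    # resulting conditions into a single dictionary (always conditional here).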
    def get_input(self, batch):
        cond_dict = {}
        if len(self.cond_stage_model_metadata.keys()) > 0:
            unconditional_cfg = False

            for cond_model_key in self.cond_stage_model_metadata.keys():
                cond_stage_key = self.cond_stage_model_metadata[cond_model_key][
                    "cond_stage_key"
                ]

                xc = self.get_input_item(batch, cond_stage_key)
                if isinstance(xc, torch.Tensor):
                    xc = xc.to(self.device)

                c = self.get_learned_conditioning(
                    xc, key=cond_model_key, unconditional_cfg=unconditional_cfg
                )
                cond_dict[cond_model_key] = c

        return cond_dict

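    # Instantiate each conditioning model from its config entry and record its
    # index together with the batch key and conditioning type it expects.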
    def instantiate_cond_stage(self, config):
        self.cond_stage_model_metadata = {}

        for i, cond_model_key in enumerate(config.keys()):
            model = instantiate_from_config(config[cond_model_key])
            self.cond_stage_models.append(model)
            self.cond_stage_model_metadata[cond_model_key] = {
                "model_idx": i,
                "cond_stage_key": config[cond_model_key]["cond_stage_key"],
                "conditioning_key": config[cond_model_key]["conditioning_key"],
            }

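    # Encode the input with the conditioning model registered under `key`, or
    # return that model's unconditional condition when CFG dropout is requested.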
    def get_learned_conditioning(self, c, key, unconditional_cfg):
        assert key in self.cond_stage_model_metadata.keys()

        if not unconditional_cfg:
            c = self.cond_stage_models[
                self.cond_stage_model_metadata[key]["model_idx"]
            ](c)
        else:
            if isinstance(c, torch.Tensor):
                batchsize = c.size(0)
            elif isinstance(c, list):
                batchsize = len(c)
            else:
                raise NotImplementedError()
            c = self.cond_stage_models[
                self.cond_stage_model_metadata[key]["model_idx"]
            ].get_unconditional_condition(batchsize)

        return c
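

# Illustrative usage sketch (not part of the original module). The conditioning
# config is hypothetical: the concrete targets and keys come from the
# surrounding audioldm_train configuration system and must provide the
# "film_clap_cond1" and "crossattn_audiomae_pooled" conditions consumed above.
#
#   model = CLAP2AudioMAE(
#       sequence_gen_length=8,        # number of AudioMAE tokens to generate
#       base_learning_rate=1e-4,
#       cond_stage_config=cond_stage_config,
#   )
#   trainer = pl.Trainer(max_epochs=10)
#   trainer.fit(model, train_loader, val_loader)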