from typing import Any, Callable, Dict
import random
import lightning.pytorch as pl
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import LambdaLR

class AudioSep(pl.LightningModule):
    def __init__(
        self,
        ss_model: nn.Module,
        waveform_mixer,
        query_encoder,
        loss_function: Callable,
        optimizer_type: str,
        learning_rate: float,
        lr_lambda_func: Callable[[int], float],
        use_text_ratio: float = 1.0,
    ):
        r"""PyTorch Lightning wrapper of a PyTorch separation model, covering
        the training step, optimization, and learning-rate scheduling.

        Args:
            ss_model: nn.Module, the source separation model
            waveform_mixer: callable that mixes clean waveforms into training mixtures
            query_encoder: encoder that maps text/audio queries to condition embeddings
            loss_function: function or callable object computing the training loss
            optimizer_type: str, e.g. 'AdamW'
            learning_rate: float
            lr_lambda_func: function mapping a global step index to an LR multiplier
            use_text_ratio: float, fraction of queries encoded from text (vs. audio)
        """
        super().__init__()
        self.ss_model = ss_model
        self.waveform_mixer = waveform_mixer
        self.query_encoder = query_encoder
        self.query_encoder_type = self.query_encoder.encoder_type
        self.use_text_ratio = use_text_ratio
        self.loss_function = loss_function
        self.optimizer_type = optimizer_type
        self.learning_rate = learning_rate
        self.lr_lambda_func = lr_lambda_func
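
    # A minimal sketch of an `lr_lambda_func` that could be passed in above
    # (a linear warm-up; illustrative only, not the project's exact schedule):
    #
    #     def lr_lambda(step: int, warm_up_steps: int = 1000) -> float:
    #         # Scale the LR linearly from 0 to 1 over the warm-up period,
    #         # then hold it constant.
    #         return min((step + 1) / warm_up_steps, 1.0)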
    def forward(self, x):
        # Inference is done by calling self.ss_model directly; this wrapper
        # only implements the Lightning training loop.
        pass
    def training_step(self, batch_data_dict: Dict[str, Any], batch_idx: int):
        r"""Forward a mini-batch to the model, calculate the loss, and train
        for one step. A mini-batch is evenly distributed across multiple
        devices (if present) for parallel training.

        Args:
            batch_data_dict: e.g.
                'audio_text': {
                    'text': ['a sound of dog', ...],
                    'waveform': (batch_size, 1, samples)
                }
            batch_idx: int

        Returns:
            loss: torch.Tensor, loss of this mini-batch
        """
        # [important] fix random seeds across devices
        random.seed(batch_idx)

        batch_audio_text_dict = batch_data_dict['audio_text']
        batch_text = batch_audio_text_dict['text']
        batch_audio = batch_audio_text_dict['waveform']

        # Mix single-source waveforms into training mixtures; the clean
        # segments serve as separation targets.
        mixtures, segments = self.waveform_mixer(waveforms=batch_audio)

        # Compute the query (condition) embedding for the audio-text pair.
        # Only the CLAP encoder is handled here; any other encoder type would
        # leave `conditions` undefined and raise a NameError below.
        if self.query_encoder_type == 'CLAP':
            conditions = self.query_encoder.get_query_embed(
                modality='hybird',  # spelling expected by the CLAP encoder API
                text=batch_text,
                audio=segments.squeeze(1),
                use_text_ratio=self.use_text_ratio,
            )

        input_dict = {
            'mixture': mixtures[:, None, :].squeeze(1),
            'condition': conditions,
        }

        target_dict = {
            'segment': segments.squeeze(1),
        }

        self.ss_model.train()
        sep_segment = self.ss_model(input_dict)['waveform']
        # (batch_size, 1, segment_samples) -> (batch_size, segment_samples)
        sep_segment = sep_segment.squeeze()

        output_dict = {
            'segment': sep_segment,
        }

        # Calculate loss.
        loss = self.loss_function(output_dict, target_dict)
        self.log_dict({"train_loss": loss})

        return loss
    def test_step(self, batch, batch_idx):
        pass
    def configure_optimizers(self):
        r"""Configure the optimizer and the per-step LR scheduler."""
        if self.optimizer_type == "AdamW":
            optimizer = optim.AdamW(
                params=self.ss_model.parameters(),
                lr=self.learning_rate,
                betas=(0.9, 0.999),
                eps=1e-08,
                weight_decay=0.0,
                amsgrad=True,
            )
        else:
            raise NotImplementedError

        scheduler = LambdaLR(optimizer, self.lr_lambda_func)

        output_dict = {
            "optimizer": optimizer,
            "lr_scheduler": {
                'scheduler': scheduler,
                'interval': 'step',
                'frequency': 1,
            },
        }

        return output_dict
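
# A minimal usage sketch (assumed setup: the separation backbone, waveform
# mixer, CLAP query encoder, and loss helpers from the AudioSep repo; the
# names and argument values below are illustrative, not a verified entry
# point):
#
#     model = AudioSep(
#         ss_model=ss_model,
#         waveform_mixer=waveform_mixer,
#         query_encoder=query_encoder,
#         loss_function=loss_function,
#         optimizer_type='AdamW',
#         learning_rate=1e-4,
#         lr_lambda_func=lr_lambda,
#         use_text_ratio=1.0,
#     )
#     trainer = pl.Trainer(max_steps=100000)
#     trainer.fit(model, train_dataloaders=train_loader)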

def get_model_class(model_type):
    if model_type == 'ResUNet30':
        from models.resunet import ResUNet30
        return ResUNet30
    else:
        raise NotImplementedError
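
# Example (hypothetical constructor arguments; check models/resunet.py for
# the actual `ResUNet30` signature):
#
#     SsModel = get_model_class('ResUNet30')
#     ss_model = SsModel(input_channels=1, output_channels=1)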