|
from speaker_encoder.visualizations import Visualizations |
|
from speaker_encoder.data_objects import SpeakerVerificationDataLoader, SpeakerVerificationDataset |
|
from speaker_encoder.params_model import * |
|
from speaker_encoder.model import SpeakerEncoder |
|
from utils.profiler import Profiler |
|
from pathlib import Path |
|
import torch |
|
|
|
def sync(device: torch.device):
    """Block until all work queued on `device` has finished.

    CUDA operations are launched asynchronously, so without an explicit
    synchronization the profiler ticks in `train` would mostly measure launch
    overhead rather than the time the work actually takes. For any non-CUDA
    device this function does nothing.

    :param device: the device to wait on.
    """
    # The previous version returned unconditionally before this check, which made
    # the synchronization below unreachable and the profiling numbers misleading.
    if device.type == "cuda":
        torch.cuda.synchronize(device)
|
|
|
def _save_checkpoint(fpath: Path, step: int, model: torch.nn.Module,
                     optimizer: torch.optim.Optimizer):
    """Write model and optimizer state to `fpath`, creating parent directories if needed.

    The stored step is `step + 1`, so a run resumed from this file starts at the
    step after the one that was just completed.
    """
    fpath.parent.mkdir(parents=True, exist_ok=True)
    torch.save({
        "step": step + 1,
        "model_state": model.state_dict(),
        "optimizer_state": optimizer.state_dict(),
    }, fpath)


def train(run_id: str, clean_data_root: Path, models_dir: Path, umap_every: int, save_every: int,
          backup_every: int, vis_every: int, force_restart: bool, visdom_server: str,
          no_visdom: bool):
    """Train the speaker encoder, optionally resuming from an existing checkpoint.

    Hyperparameters (`speakers_per_batch`, `utterances_per_speaker`,
    `learning_rate_init`) come from `speaker_encoder.params_model`. The training
    loop runs for as long as the data loader yields batches.

    :param run_id: name of the run. The main checkpoint is `<models_dir>/<run_id>.pt`;
        backups and UMAP projections go to `<models_dir>/<run_id>_backups/`.
    :param clean_data_root: dataset root passed to `SpeakerVerificationDataset`.
    :param models_dir: directory for checkpoints; created (with parents) if missing
        when something is written to it.
    :param umap_every: draw and save UMAP projections every N steps (0 disables).
    :param save_every: overwrite the main checkpoint every N steps (0 disables).
    :param backup_every: write a step-numbered backup checkpoint every N steps (0 disables).
    :param vis_every: forwarded to `Visualizations`.
    :param force_restart: if True, ignore any existing checkpoint and start at step 1.
    :param visdom_server: visdom server address forwarded to `Visualizations`.
    :param no_visdom: disable the visdom visualizations.
    """
    # Create a dataset and a dataloader
    dataset = SpeakerVerificationDataset(clean_data_root)
    loader = SpeakerVerificationDataLoader(
        dataset,
        speakers_per_batch,
        utterances_per_speaker,
        num_workers=8,
    )

    # The forward pass runs on the GPU when one is available.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # NOTE(review): the loss is pinned to the CPU; the reason is not visible here —
    # confirm before changing it.
    loss_device = torch.device("cpu")

    # Create the model and the optimizer
    model = SpeakerEncoder(device, loss_device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate_init)
    init_step = 1

    # Configure file paths for the model
    state_fpath = models_dir.joinpath(run_id + ".pt")
    backup_dir = models_dir.joinpath(run_id + "_backups")

    # Load any existing model
    if not force_restart:
        if state_fpath.exists():
            print("Found existing model \"%s\", loading it and resuming training." % run_id)
            # Deserialize onto the CPU so a checkpoint written on a GPU machine can be
            # resumed on a CPU-only one; load_state_dict() copies each tensor into the
            # corresponding parameter, wherever that parameter lives.
            checkpoint = torch.load(state_fpath, map_location="cpu")
            init_step = checkpoint["step"]
            model.load_state_dict(checkpoint["model_state"])
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            # Override the learning rate stored in the checkpoint with the current setting.
            optimizer.param_groups[0]["lr"] = learning_rate_init
        else:
            print("No model \"%s\" found, starting training from scratch." % run_id)
    else:
        print("Starting the training from scratch.")
    model.train()

    # Initialize the visualization environment
    vis = Visualizations(run_id, vis_every, server=visdom_server, disabled=no_visdom)
    vis.log_dataset(dataset)
    vis.log_params()
    device_name = str(torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU")
    vis.log_implementation({"Device": device_name})

    # Training loop
    profiler = Profiler(summarize_every=10, disabled=False)
    for step, speaker_batch in enumerate(loader, init_step):
        profiler.tick("Blocking, waiting for batch (threaded)")

        # Forward pass
        inputs = torch.from_numpy(speaker_batch.data).to(device)
        sync(device)
        profiler.tick("Data to %s" % device)
        embeds = model(inputs)
        sync(device)
        profiler.tick("Forward pass")
        # Reshape to (speakers_per_batch, utterances_per_speaker, embedding_dim) for the loss.
        embeds_loss = embeds.view((speakers_per_batch, utterances_per_speaker, -1)).to(loss_device)
        loss, eer = model.loss(embeds_loss)
        sync(loss_device)
        profiler.tick("Loss")

        # Backward pass
        model.zero_grad()
        loss.backward()
        profiler.tick("Backward pass")
        model.do_gradient_ops()
        optimizer.step()
        profiler.tick("Parameter update")

        # Update visualizations
        vis.update(loss.item(), eer, step)

        # Draw projections and save them to the backup folder
        if umap_every != 0 and step % umap_every == 0:
            print("Drawing and saving projections (step %d)" % step)
            backup_dir.mkdir(parents=True, exist_ok=True)
            projection_fpath = backup_dir.joinpath("%s_umap_%06d.png" % (run_id, step))
            embeds = embeds.detach().cpu().numpy()
            vis.draw_projections(embeds, utterances_per_speaker, step, projection_fpath)
            vis.save()

        # Overwrite the latest version of the model
        if save_every != 0 and step % save_every == 0:
            print("Saving the model (step %d)" % step)
            _save_checkpoint(state_fpath, step, model, optimizer)

        # Make a step-numbered backup
        if backup_every != 0 and step % backup_every == 0:
            print("Making a backup (step %d)" % step)
            backup_fpath = backup_dir.joinpath("%s_bak_%06d.pt" % (run_id, step))
            _save_checkpoint(backup_fpath, step, model, optimizer)

        profiler.tick("Extras (visualizations, saving)")
|
|