azamat's picture
Init
6127b48
raw
history blame
1.36 kB
import datetime
import os
from TTS.utils.io import save_fsspec
def save_checkpoint(model, optimizer, model_loss, out_path, current_step):
    """Write a training checkpoint named ``checkpoint_<current_step>.pth``.

    The file is placed under ``out_path`` (joined with ``os.path.join``) and
    written via ``save_fsspec``. The saved payload is a dict with keys
    ``model``, ``optimizer``, ``step``, ``loss`` and ``date``.

    Args:
        model: object exposing ``state_dict()`` (presumably a torch
            ``nn.Module`` — confirm against callers).
        optimizer: object exposing ``state_dict()``, or ``None``; when
            ``None`` the ``optimizer`` entry is stored as ``None``.
        model_loss: loss value stored alongside the weights.
        out_path: directory the checkpoint file is joined onto.
        current_step: training step, used both in the file name and payload.
    """
    target = os.path.join(out_path, f"checkpoint_{current_step}.pth")
    print(f" | | > Checkpoint saving : {target}")

    model_weights = model.state_dict()
    optim_weights = None if optimizer is None else optimizer.state_dict()
    # Human-readable local date, e.g. "January 05, 2024".
    saved_on = datetime.date.today().strftime("%B %d, %Y")

    save_fsspec(
        {
            "model": model_weights,
            "optimizer": optim_weights,
            "step": current_step,
            "loss": model_loss,
            "date": saved_on,
        },
        target,
    )
def save_best_model(model, optimizer, model_loss, best_loss, out_path, current_step):
    """Save ``best_model.pth`` if ``model_loss`` improves on ``best_loss``.

    When ``model_loss`` is strictly lower than ``best_loss``, the model and
    optimizer states, step, loss and today's date are written to
    ``os.path.join(out_path, "best_model.pth")`` via ``save_fsspec``. This
    overwrites any previous best model. Otherwise nothing is read from
    ``model``/``optimizer`` and nothing is written.

    Args:
        model: object exposing ``state_dict()`` (presumably a torch
            ``nn.Module`` — confirm against callers).
        optimizer: object exposing ``state_dict()``, or ``None``; when
            ``None`` the ``optimizer`` entry is stored as ``None``.
        model_loss: loss of the current model.
        best_loss: best (lowest) loss seen so far.
        out_path: directory the ``best_model.pth`` file is joined onto.
        current_step: training step stored in the payload.

    Returns:
        ``model_loss`` if it was strictly lower than ``best_loss`` (a new
        best was saved), otherwise ``best_loss`` unchanged.
    """
    # Strict "<" so equal (or NaN) losses do not trigger a re-save.
    if model_loss < best_loss:
        new_state_dict = model.state_dict()
        state = {
            "model": new_state_dict,
            # Match save_checkpoint: permit saving without an optimizer
            # instead of crashing with AttributeError on None.
            "optimizer": optimizer.state_dict() if optimizer is not None else None,
            "step": current_step,
            "loss": model_loss,
            "date": datetime.date.today().strftime("%B %d, %Y"),
        }
        best_loss = model_loss
        bestmodel_path = os.path.join(out_path, "best_model.pth")
        print("\n > BEST MODEL ({0:.5f}) : {1:}".format(model_loss, bestmodel_path))
        save_fsspec(state, bestmodel_path)
    return best_loss