|
''' |
|
author: wayn391@mastertones |
|
''' |
|
|
|
import os |
|
import json |
|
import time |
|
import yaml |
|
import datetime |
|
import torch |
|
import matplotlib.pyplot as plt |
|
from . import utils |
|
from torch.utils.tensorboard import SummaryWriter |
|
|
|
class Saver(object):
    """Experiment bookkeeping: text log, TensorBoard summaries, checkpoints.

    Everything is written under ``args.env.expdir``:
      * ``log_info.txt``            -- appended to by :meth:`log_info`
      * ``logs/``                   -- TensorBoard event files (``SummaryWriter``)
      * ``config.yaml``             -- ``dict(args)`` dumped at construction
      * ``<name>[_<postfix>].pt``   -- checkpoints from :meth:`save_model`
    """

    def __init__(
            self,
            args,
            initial_global_step=-1):
        """Create the experiment directory, TensorBoard writer and config dump.

        Args:
            args: configuration object exposing ``args.env.expdir`` and
                ``args.data.sampling_rate`` and convertible with ``dict()``
                (presumably a dot-accessible dict -- confirm against caller).
            initial_global_step: starting value of :attr:`global_step`, the
                step attached to every TensorBoard summary.
        """
        self.expdir = args.env.expdir
        self.sample_rate = args.data.sampling_rate

        self.global_step = initial_global_step
        now = time.time()
        self.init_time = now   # reference point for get_total_time()
        self.last_time = now   # reference point for get_interval_time()

        os.makedirs(self.expdir, exist_ok=True)

        self.path_log_info = os.path.join(self.expdir, 'log_info.txt')

        self.writer = SummaryWriter(os.path.join(self.expdir, 'logs'))

        # snapshot the configuration next to the run's outputs
        path_config = os.path.join(self.expdir, 'config.yaml')
        with open(path_config, "w") as out_config:
            yaml.dump(dict(args), out_config)

    @staticmethod
    def _format_log_item(key, value):
        """Render one ``key: value`` line; ints get thousands separators."""
        # bool subclasses int; without this check True would render as '1'
        if isinstance(value, int) and not isinstance(value, bool):
            return '{}: {:,}'.format(key, value)
        return '{}: {}'.format(key, value)

    def log_info(self, msg):
        """Print ``msg`` and append it (plus a newline) to ``log_info.txt``.

        Args:
            msg: a dict, rendered as one ``key: value`` line per item with
                integers thousands-separated; anything else is written as
                ``str(msg)``.
        """
        if isinstance(msg, dict):
            msg_str = '\n'.join(
                self._format_log_item(k, v) for k, v in msg.items())
        else:
            # str() so non-string messages can be concatenated and written
            msg_str = str(msg)

        print(msg_str)

        with open(self.path_log_info, 'a') as fp:
            fp.write(msg_str + '\n')

    def log_value(self, dict):
        """Write each ``{tag: scalar}`` entry to TensorBoard at the current step.

        The parameter name shadows the builtin ``dict``; it is kept so that
        existing keyword callers keep working.
        """
        for k, v in dict.items():
            self.writer.add_scalar(k, v, self.global_step)

    def log_spec(self, name, spec, spec_out, vmin=-14, vmax=3.5):
        """Log a side-by-side spectrogram figure to TensorBoard.

        Three panels are concatenated along the last axis: the error
        ``|spec_out - spec| + vmin``, the target ``spec`` and the prediction
        ``spec_out``. Only index 0 of the first dimension is plotted.

        Args:
            name: TensorBoard tag.
            spec: target tensor; assumed (batch, frames, bins) -- TODO confirm.
            spec_out: predicted tensor, same shape as ``spec``.
            vmin, vmax: colour limits for ``plt.pcolor``; ``vmin`` is also
                added as an offset to the error panel.
        """
        spec_cat = torch.cat([(spec_out - spec).abs() + vmin, spec, spec_out], -1)
        # detach() so tensors that still require grad can be converted
        spec = spec_cat[0].detach().cpu().numpy()
        fig = plt.figure(figsize=(12, 9))
        plt.pcolor(spec.T, vmin=vmin, vmax=vmax)
        plt.tight_layout()
        self.writer.add_figure(name, fig, self.global_step)

    def log_audio(self, dict):
        """Write each ``{tag: waveform}`` entry to TensorBoard as audio.

        Uses the sampling rate read from ``args.data.sampling_rate``. The
        parameter name shadows the builtin ``dict``; kept for compatibility.
        """
        for k, v in dict.items():
            self.writer.add_audio(k, v, global_step=self.global_step, sample_rate=self.sample_rate)

    def get_interval_time(self, update=True):
        """Return seconds elapsed since the last (updating) call.

        Args:
            update: if true, reset the reference point to now.
        """
        cur_time = time.time()
        time_interval = cur_time - self.last_time
        if update:
            self.last_time = cur_time
        return time_interval

    @staticmethod
    def _format_elapsed(seconds):
        """Format ``seconds`` as ``[D day(s), ]H:MM:SS.f`` (one decimal digit)."""
        text = str(datetime.timedelta(seconds=seconds))
        # str(timedelta) omits the fraction when microseconds == 0; pad it so
        # the slice below always drops exactly the last five microsecond digits
        if '.' not in text:
            text += '.000000'
        return text[:-5]

    def get_total_time(self, to_str=True):
        """Return time elapsed since construction.

        Args:
            to_str: if true, return a string like ``'1:02:05.3'``; otherwise
                return the raw number of seconds as a float.
        """
        total_time = time.time() - self.init_time
        if to_str:
            return self._format_elapsed(total_time)
        return total_time

    def _checkpoint_path(self, name, postfix):
        """Return ``<expdir>/<name>[_<postfix>].pt`` (postfix omitted if empty)."""
        if postfix:
            name = name + '_' + postfix
        return os.path.join(self.expdir, name + '.pt')

    def save_model(
            self,
            model,
            optimizer,
            name='model',
            postfix='',
            to_json=False):
        """Save model/optimizer state dicts and the current global step.

        The checkpoint is a dict with keys ``'global_step'``, ``'model'`` and
        ``'optimizer'``, written with ``torch.save``.

        Args:
            model: object exposing ``state_dict()``.
            optimizer: object exposing ``state_dict()``.
            name: checkpoint file stem.
            postfix: optional suffix, joined to ``name`` with ``'_'``.
            to_json: also export through ``utils.to_json`` to
                ``<expdir>/<name>.json`` (postfix not included).
        """
        path_pt = self._checkpoint_path(name, postfix)

        torch.save({
            'global_step': self.global_step,
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict()}, path_pt)
        # report only once the write has succeeded
        print(' [*] model checkpoint saved: {}'.format(path_pt))

        if to_json:
            path_json = os.path.join(self.expdir, name + '.json')
            # NOTE(review): confirm utils.to_json accepts the checkpoint
            # dict layout saved above (source path, destination path).
            utils.to_json(path_pt, path_json)

    def delete_model(self, name='model', postfix=''):
        """Delete the checkpoint ``save_model(name=..., postfix=...)`` wrote.

        A missing file is ignored; the message is printed only when a file
        was actually removed.
        """
        path_pt = self._checkpoint_path(name, postfix)
        try:
            os.remove(path_pt)
        except FileNotFoundError:
            return
        print(' [*] model checkpoint deleted: {}'.format(path_pt))

    def global_step_increment(self):
        """Advance the step counter used by all TensorBoard summaries."""
        self.global_step += 1
|
|
|
|
|
|