# S23DR-P2R / train_utils.py
import glob
import tqdm
import os
import torch
import numpy as np
from test_util import test_model
import wandb


def train_one_epoch(model, optim, data_loader, accumulated_iter, tbar, leave_pbar=False):
    total_it_each_epoch = len(data_loader)
    dataloader_iter = iter(data_loader)
    pbar = tqdm.tqdm(total=total_it_each_epoch, leave=leave_pbar, desc='train', dynamic_ncols=True)

    for cur_it in range(total_it_each_epoch):
        try:
            batch = next(dataloader_iter)
        except StopIteration:
            # Restart the iterator if the loader is exhausted mid-epoch.
            dataloader_iter = iter(data_loader)
            batch = next(dataloader_iter)
            print('new iters')

        try:
            cur_lr = float(optim.lr)
        except AttributeError:
            cur_lr = optim.param_groups[0]['lr']

        model.train()
        optim.zero_grad()

        load_data_to_gpu(batch)
        loss, loss_dict, disp_dict = model(batch)

        loss.backward()
        optim.step()

        accumulated_iter += 1
        disp_dict.update(loss_dict)
        disp_dict.update({'loss': loss.item(), 'lr': cur_lr})

        # Log to wandb
        wandb.log({"loss": loss.item(), "lr": cur_lr, "iter": accumulated_iter})

        # Log to the console progress bars
        pbar.update()
        pbar.set_postfix(dict(total_it=accumulated_iter))
        tbar.set_postfix(disp_dict)
        tbar.refresh()

    pbar.close()
    return accumulated_iter


def train_model(model, optim, data_loader, lr_sch, start_it, start_epoch, total_epochs, ckpt_save_dir, sampler=None,
                max_ckpt_save_num=5):
    with tqdm.trange(start_epoch, total_epochs, desc='epochs', dynamic_ncols=True) as tbar:
        accumulated_iter = start_it
        for e in tbar:
            if sampler is not None:
                sampler.set_epoch(e)

            # Enable the model's edge branch after the first six epochs.
            if e > 5:
                model.use_edge = True

            accumulated_iter = train_one_epoch(model, optim, data_loader, accumulated_iter, tbar,
                                               leave_pbar=(e + 1 == total_epochs))

            # Step the scheduler and clamp the learning rate to a floor of 1e-6.
            lr_sch.step()
            lr = max(optim.param_groups[0]['lr'], 1e-6)
            for param_group in optim.param_groups:
                param_group['lr'] = lr

            # Keep at most max_ckpt_save_num checkpoints on disk, removing the oldest first.
            ckpt_list = glob.glob(str(ckpt_save_dir / 'checkpoint_epoch_*.pth'))
            ckpt_list.sort(key=os.path.getmtime)
            if len(ckpt_list) >= max_ckpt_save_num:
                for cur_file_idx in range(0, len(ckpt_list) - max_ckpt_save_num + 1):
                    os.remove(ckpt_list[cur_file_idx])

            ckpt_name = ckpt_save_dir / ('checkpoint_epoch_%d' % (e + 1))
            save_checkpoint(
                checkpoint_state(model, optim, e + 1, accumulated_iter), filename=ckpt_name,
            )


def model_state_to_cpu(model_state):
    # Copy a state dict to CPU so the checkpoint can be loaded without a GPU.
    model_state_cpu = type(model_state)()
    for key, val in model_state.items():
        model_state_cpu[key] = val.cpu()
    return model_state_cpu


def checkpoint_state(model=None, optimizer=None, epoch=None, it=None):
    optim_state = optimizer.state_dict() if optimizer is not None else None
    if model is not None:
        if isinstance(model, torch.nn.parallel.DistributedDataParallel):
            # Unwrap DDP and move the weights to CPU before serialising.
            model_state = model_state_to_cpu(model.module.state_dict())
        else:
            model_state = model.state_dict()
    else:
        model_state = None
    return {'epoch': epoch, 'it': it, 'model_state': model_state, 'optimizer_state': optim_state}


def save_checkpoint(state, filename='checkpoint'):
    # The optimizer state is kept inside the single checkpoint file.
    filename = '{}.pth'.format(filename)
    torch.save(state, filename)


def load_data_to_gpu(batch_dict):
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    for key, val in batch_dict.items():
        if not isinstance(val, np.ndarray):
            continue
        # Numpy arrays are converted in place to float tensors on the target device.
        batch_dict[key] = torch.from_numpy(val).float().to(device)
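

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustration only, not the repo's real entry point).
# It assumes a hypothetical toy model whose forward(batch) returns
# (loss, loss_dict, disp_dict) and a collate function that yields numpy
# arrays, which is the contract train_one_epoch() and load_data_to_gpu()
# rely on; the actual model and dataset live elsewhere in this repo.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from pathlib import Path

    import torch.nn as nn
    from torch.utils.data import DataLoader, TensorDataset

    class _ToyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(4, 1)
            self.use_edge = False  # train_model() flips this flag after epoch 5

        def forward(self, batch):
            loss = self.fc(batch['x']).mean()
            return loss, {'toy_loss': loss.item()}, {}

    def _collate(samples):
        # Stack into a numpy array so load_data_to_gpu() moves it to the device.
        return {'x': np.stack([s[0].numpy() for s in samples])}

    wandb.init(mode='disabled')  # keep wandb.log() calls in train_one_epoch harmless

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    toy_model = _ToyModel().to(device)
    optim = torch.optim.Adam(toy_model.parameters(), lr=1e-3)
    lr_sch = torch.optim.lr_scheduler.StepLR(optim, step_size=10)
    loader = DataLoader(TensorDataset(torch.randn(32, 4)), batch_size=8, collate_fn=_collate)

    ckpt_dir = Path('./ckpt_sketch')
    ckpt_dir.mkdir(exist_ok=True)
    train_model(toy_model, optim, loader, lr_sch, start_it=0, start_epoch=0,
                total_epochs=1, ckpt_save_dir=ckpt_dir)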