Spaces:
Runtime error
Runtime error
"""
trainer.py - wrapper and utility functions for network training.
Computes losses, back-propagates, updates parameters, and handles logging and checkpointing.
"""
import datetime | |
import os | |
import time | |
import numpy as np | |
import torch | |
import torch.nn as nn | |
import torch.optim as optim | |
from model.network import XMem | |
from model.losses import LossComputer | |
from util.log_integrator import Integrator | |
from util.image_saver import pool_pairs | |
class XMemTrainer:
    """Training wrapper around an XMem network wrapped in DistributedDataParallel.

    Holds the model, AdamW optimizer, MultiStepLR scheduler, optional AMP
    GradScaler and loss computer, and drives one forward (and, in training
    mode, backward) pass per call to ``do_pass``. Logging and saving only
    happen on the process that was given a ``logger``; processes constructed
    with ``logger=None`` run the same computation but skip them.
    """

    def __init__(self, config, logger=None, save_path=None, local_rank=0, world_size=1):
        """Build the model, optimizer, scheduler and logging state.

        Args:
            config: indexable training configuration. Keys read here:
                'num_frames', 'num_ref_frames', 'deep_update_prob', 'lr',
                'weight_decay', 'steps', 'gamma', 'amp', 'log_text_interval',
                'log_image_interval', 'save_network_interval',
                'save_checkpoint_interval', 'debug' ('hidden_dim' is read in
                do_pass; XMem and LossComputer receive the whole config).
            logger: logger object, or None on processes that must not log/save.
            save_path: path prefix for saved networks/checkpoints; None
                disables saving.
            local_rank: CUDA device index used for this process's DDP replica.
            world_size: number of distributed processes (passed to Integrator).
        """
        self.config = config
        self.num_frames = config['num_frames']
        self.num_ref_frames = config['num_ref_frames']
        self.deep_update_prob = config['deep_update_prob']
        self.local_rank = local_rank

        self.XMem = nn.parallel.DistributedDataParallel(
            XMem(config).cuda(),
            device_ids=[local_rank], output_device=local_rank, broadcast_buffers=False)

        # Set up logger when local_rank=0
        self.logger = logger
        self.save_path = save_path
        # do_pass measures the wall-clock time of each text-logging window on
        # every process, so the timer must exist even without a logger.
        self.last_time = time.time()
        # Running average of seconds per iteration, used for the ETA.
        self._avg_time_per_it = None
        if logger is not None:
            self.logger.log_string(
                'model_size', str(sum(param.nelement() for param in self.XMem.parameters())))
        # Needed on every process: train() reads it and do_pass finalizes it.
        self.train_integrator = Integrator(
            self.logger, distributed=True, local_rank=local_rank, world_size=world_size)
        self.loss_computer = LossComputer(config)

        self.train()
        self.optimizer = optim.AdamW(
            [p for p in self.XMem.parameters() if p.requires_grad],
            lr=config['lr'], weight_decay=config['weight_decay'])
        self.scheduler = optim.lr_scheduler.MultiStepLR(self.optimizer, config['steps'], config['gamma'])
        if config['amp']:
            self.scaler = torch.cuda.amp.GradScaler()

        # Logging info
        self.log_text_interval = config['log_text_interval']
        self.log_image_interval = config['log_image_interval']
        self.save_network_interval = config['save_network_interval']
        self.save_checkpoint_interval = config['save_checkpoint_interval']
        if config['debug']:
            self.log_text_interval = self.log_image_interval = 1

    def do_pass(self, data, max_it, it=0):
        """Run the network over one batch of clips; in training mode also take an optimizer step.

        Args:
            data: batch dict. Keys read here: 'rgb', 'first_frame_gt',
                'selector' and data['info']['num_objects']; the whole dict is
                also passed to the loss computer and the image logger. Tensor
                values are moved to the GPU in place.
            max_it: total number of training iterations (only used for the ETA).
            it: current iteration index.
        """
        # No need to store the gradient outside training
        torch.set_grad_enabled(self._is_train)

        for k, v in data.items():
            # Only tensors are moved; lists/dicts/ints (e.g. data['info']) stay on the host.
            if isinstance(v, torch.Tensor):
                data[k] = v.cuda(non_blocking=True)

        out = {}
        # frames: dim 0 is the batch, dim 1 is indexed by the frame index ti.
        frames = data['rgb']
        first_frame_gt = data['first_frame_gt'].float()
        b = frames.shape[0]
        num_filled_objects = [o.item() for o in data['info']['num_objects']]
        num_objects = first_frame_gt.shape[2]
        selector = data['selector'].unsqueeze(2).unsqueeze(2)

        with torch.cuda.amp.autocast(enabled=self.config['amp']):
            # image features never change, compute once
            key, shrinkage, selection, f16, f8, f4 = self.XMem('encode_key', frames)

            # Allocate the hidden state on the same device as the key features.
            hidden = torch.zeros((b, num_objects, self.config['hidden_dim'], *key.shape[-2:]),
                                 device=key.device)
            v16, hidden = self.XMem('encode_value', frames[:, 0], f16[:, 0], hidden, first_frame_gt[:, 0])
            values = v16.unsqueeze(3)  # add the time dimension

            for ti in range(1, self.num_frames):
                ref_keys, ref_shrinkage, ref_values = self._select_memory(ti, b, key, shrinkage, values)

                # Segment frame ti
                memory_readout = self.XMem(
                    'read_memory', key[:, :, ti], selection[:, :, ti] if selection is not None else None,
                    ref_keys, ref_shrinkage, ref_values)
                is_last_frame = ti == self.num_frames - 1
                hidden, logits, masks = self.XMem(
                    'segment', (f16[:, ti], f8[:, ti], f4[:, ti]), memory_readout,
                    hidden, selector, h_out=not is_last_frame)

                # No need to encode the last frame
                if not is_last_frame:
                    is_deep_update = np.random.rand() < self.deep_update_prob
                    v16, hidden = self.XMem('encode_value', frames[:, ti], f16[:, ti], hidden, masks,
                                            is_deep_update=is_deep_update)
                    values = torch.cat([values, v16.unsqueeze(3)], 3)

                out[f'masks_{ti}'] = masks
                out[f'logits_{ti}'] = logits

            # The loss is part of the forward computation, so it stays under autocast.
            if self._do_log or self._is_train:
                losses = self.loss_computer.compute({**data, **out}, num_filled_objects, it)

        # Logging
        if self._do_log:
            self.integrator.add_dict(losses)
            if (self._is_train and it % self.log_image_interval == 0 and it != 0
                    and self.logger is not None):
                images = {**data, **out}
                size = (384, 384)
                self.logger.log_cv2('train/pairs', pool_pairs(images, size, num_filled_objects), it)

        # Everything below (periodic logging/saving and the optimizer step)
        # only applies in training mode; in val()/test() mode there is nothing
        # to back-propagate (and in test() mode no loss was computed).
        if not self._is_train:
            return

        if it % self.log_text_interval == 0 and it != 0:
            self._log_text_progress(it, max_it)
        if it % self.save_network_interval == 0 and it != 0 and self.logger is not None:
            self.save_network(it)
        if it % self.save_checkpoint_interval == 0 and it != 0 and self.logger is not None:
            self.save_checkpoint(it)

        # Backward pass, outside the autocast region.
        self.optimizer.zero_grad(set_to_none=True)
        if self.config['amp']:
            self.scaler.scale(losses['total_loss']).backward()
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            losses['total_loss'].backward()
            self.optimizer.step()

        self.scheduler.step()

    def _select_memory(self, ti, b, key, shrinkage, values):
        """Return ``(ref_keys, ref_shrinkage, ref_values)`` that frame ``ti`` reads from.

        ``key``/``shrinkage`` carry the frame index on dim 2 and ``values`` on
        dim 3. At this point ``values`` holds exactly frames ``0..ti-1``.
        While ``ti <= num_ref_frames`` all of those frames are used; otherwise,
        independently for each of the ``b`` batch elements, frame 0 is always
        kept and ``num_ref_frames - 1`` distinct frames are drawn at random
        from ``1..ti-1``.
        """
        if ti <= self.num_ref_frames:
            ref_shrinkage = shrinkage[:, :, :ti] if shrinkage is not None else None
            return key[:, :, :ti], ref_shrinkage, values

        # pick num_ref_frames random frames
        # this is not very efficient but I think we would
        # need broadcasting in gather which we don't have
        filler_one = torch.zeros(1, dtype=torch.int64)
        indices = [
            torch.cat([filler_one, torch.randperm(ti - 1)[:self.num_ref_frames - 1] + 1])
            for _ in range(b)]
        ref_values = torch.stack([values[bi, :, :, indices[bi]] for bi in range(b)], 0)
        ref_keys = torch.stack([key[bi, :, indices[bi]] for bi in range(b)], 0)
        if shrinkage is None:
            ref_shrinkage = None
        else:
            ref_shrinkage = torch.stack([shrinkage[bi, :, indices[bi]] for bi in range(b)], 0)
        return ref_keys, ref_shrinkage, ref_values

    def _estimate_remaining_time(self, time_spent, it, max_it):
        """Fold one logging window into the running time average and return the ETA in seconds.

        ``time_spent`` is the wall-clock time of the last ``log_text_interval``
        iterations. The running per-iteration average persists across calls
        (exponential average with weight 0.5, seeded by the first window).
        """
        time_per_it = time_spent / self.log_text_interval
        if self._avg_time_per_it is None:
            self._avg_time_per_it = time_per_it
        else:
            self._avg_time_per_it = 0.5 * self._avg_time_per_it + 0.5 * time_per_it
        return self._avg_time_per_it * max(max_it - it, 0)

    def _log_text_progress(self, it, max_it):
        """Log LR and time per iteration, print the ETA, and flush the train integrator."""
        time_spent = time.time() - self.last_time
        eta_seconds = self._estimate_remaining_time(time_spent, it, max_it)
        if self.logger is not None:
            self.logger.log_scalar('train/lr', self.scheduler.get_last_lr()[0], it)
            self.logger.log_metrics('train', 'time', time_spent / self.log_text_interval, it)
            eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
            print(f'ETA: {eta_string}')
        self.last_time = time.time()
        self.train_integrator.finalize('train', it)
        self.train_integrator.reset_except_hooks()

    def _ensure_save_dir(self):
        """Create the parent directory of ``save_path`` if it has one."""
        save_dir = os.path.dirname(self.save_path)
        # A bare file-name prefix has no directory part; os.makedirs('') would raise.
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)

    def save_network(self, it):
        """Save the unwrapped network weights to ``{save_path}_{it}.pth`` (no-op if saving is disabled)."""
        if self.save_path is None:
            print('Saving has been disabled.')
            return

        self._ensure_save_dir()
        model_path = f'{self.save_path}_{it}.pth'
        torch.save(self.XMem.module.state_dict(), model_path)
        print(f'Network saved to {model_path}.')

    def save_checkpoint(self, it):
        """Save iteration, network, optimizer and scheduler state to ``{save_path}_checkpoint_{it}.pth``."""
        if self.save_path is None:
            print('Saving has been disabled.')
            return

        self._ensure_save_dir()
        checkpoint_path = f'{self.save_path}_checkpoint_{it}.pth'
        checkpoint = {
            'it': it,
            'network': self.XMem.module.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'scheduler': self.scheduler.state_dict()}
        torch.save(checkpoint, checkpoint_path)
        print(f'Checkpoint saved to {checkpoint_path}.')

    def load_checkpoint(self, path):
        """Restore network, optimizer and scheduler state from a checkpoint; return its iteration.

        This method loads everything and should be used to resume training.
        Tensors stored on ``cuda:0`` are remapped to this process's device.
        """
        map_location = f'cuda:{self.local_rank}'
        checkpoint = torch.load(path, map_location={'cuda:0': map_location})

        it = checkpoint['it']
        self.XMem.module.load_state_dict(checkpoint['network'])
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.scheduler.load_state_dict(checkpoint['scheduler'])

        print('Network weights, optimizer states, and scheduler states loaded.')
        return it

    def load_network_in_memory(self, src_dict):
        """Load network weights from an in-memory state dict via the network's ``load_weights``."""
        self.XMem.module.load_weights(src_dict)
        print('Network weight loaded from memory.')

    def load_network(self, path):
        """Load only the network weights from ``path`` (use this for a pretrained model)."""
        map_location = f'cuda:{self.local_rank}'
        src_dict = torch.load(path, map_location={'cuda:0': map_location})

        self.load_network_in_memory(src_dict)
        print(f'Network weight loaded from {path}')

    def train(self):
        """Enable training mode (gradients, logging, optimizer steps) and return self."""
        self._is_train = True
        self._do_log = True
        self.integrator = self.train_integrator
        # NOTE(review): the network is deliberately left in eval() mode even
        # while training — presumably to freeze normalization-layer statistics.
        # Confirm before changing this to self.XMem.train().
        self.XMem.eval()
        return self

    def val(self):
        """Enable validation mode (no gradients, losses still logged) and return self."""
        self._is_train = False
        self._do_log = True
        self.XMem.eval()
        return self

    def test(self):
        """Enable test mode (no gradients, no loss computation or logging) and return self."""
        self._is_train = False
        self._do_log = False
        self.XMem.eval()
        return self