#!/usr/bin/env python3
# Copyright (c) Megvii, Inc. and its affiliates.

import datetime
import os
import time
from loguru import logger

import torch
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter

from yolox.data import DataPrefetcher
from yolox.exp import Exp
from yolox.utils import (
    MeterBuffer,
    ModelEMA,
    WandbLogger,
    adjust_status,
    all_reduce_norm,
    get_local_rank,
    get_model_info,
    get_rank,
    get_world_size,
    gpu_mem_usage,
    is_parallel,
    load_ckpt,
    mem_usage,
    occupy_mem,
    save_checkpoint,
    setup_logger,
    synchronize
)
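
# Typical usage (sketch): in the YOLOX repo this Trainer is driven by a launch
# script (tools/train.py) rather than run directly. Assuming an Exp subclass and
# an argparse.Namespace carrying the fields read below (experiment_name,
# batch_size, fp16, resume, ckpt, start_epoch, occupy, cache, logger), the flow
# is roughly:
#
#     exp = MyExp()                  # hypothetical yolox.exp.Exp subclass
#     trainer = Trainer(exp, args)
#     trainer.train()                # before_train -> epoch/iter loops -> after_train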


class Trainer:
    def __init__(self, exp: Exp, args):
        # __init__ only defines some basic attributes; other attributes such as
        # the model and optimizer are built in the `before_train` method.
        self.exp = exp
        self.args = args

        # training related attr
        self.max_epoch = exp.max_epoch
        self.amp_training = args.fp16
        self.scaler = torch.cuda.amp.GradScaler(enabled=args.fp16)
        self.is_distributed = get_world_size() > 1
        self.rank = get_rank()
        self.local_rank = get_local_rank()
        self.device = "cuda:{}".format(self.local_rank)
        self.use_model_ema = exp.ema
        self.save_history_ckpt = exp.save_history_ckpt

        # data/dataloader related attr
        self.data_type = torch.float16 if args.fp16 else torch.float32
        self.input_size = exp.input_size
        self.best_ap = 0

        # metric record
        self.meter = MeterBuffer(window_size=exp.print_interval)
        self.file_name = os.path.join(exp.output_dir, args.experiment_name)

        if self.rank == 0:
            os.makedirs(self.file_name, exist_ok=True)

        setup_logger(
            self.file_name,
            distributed_rank=self.rank,
            filename="train_log.txt",
            mode="a",
        )

    def train(self):
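        """Top-level entry point: set everything up, run the epoch loop, and
        always tear down via `after_train`, even if training raises."""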
        self.before_train()
        try:
            self.train_in_epoch()
        except Exception:
            raise
        finally:
            self.after_train()

    def train_in_epoch(self):
        for self.epoch in range(self.start_epoch, self.max_epoch):
            self.before_epoch()
            self.train_in_iter()
            self.after_epoch()

    def train_in_iter(self):
        for self.iter in range(self.max_iter):
            self.before_iter()
            self.train_one_iter()
            self.after_iter()

    def train_one_iter(self):
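        """Run a single training iteration: fetch a batch from the prefetcher,
        forward it under autocast, backpropagate through the gradient scaler,
        step the optimizer, update the EMA weights and the learning rate, and
        record timing/loss metrics."""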
        iter_start_time = time.time()

        inps, targets = self.prefetcher.next()
        inps = inps.to(self.data_type)
        targets = targets.to(self.data_type)
        targets.requires_grad = False
        inps, targets = self.exp.preprocess(inps, targets, self.input_size)
        data_end_time = time.time()

        with torch.cuda.amp.autocast(enabled=self.amp_training):
            outputs = self.model(inps, targets)

        loss = outputs["total_loss"]

        self.optimizer.zero_grad()
        self.scaler.scale(loss).backward()
        self.scaler.step(self.optimizer)
        self.scaler.update()

        if self.use_model_ema:
            self.ema_model.update(self.model)

        lr = self.lr_scheduler.update_lr(self.progress_in_iter + 1)
        for param_group in self.optimizer.param_groups:
            param_group["lr"] = lr

        iter_end_time = time.time()
        self.meter.update(
            iter_time=iter_end_time - iter_start_time,
            data_time=data_end_time - iter_start_time,
            lr=lr,
            **outputs,
        )

    def before_train(self):
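        """Build everything needed for training: the model, optimizer, optional
        resume/fine-tune checkpoint, data loader and prefetcher, LR scheduler,
        DDP wrapper, EMA model, evaluator, and the tensorboard/wandb logger."""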
logger.info("args: {}".format(self.args)) | |
logger.info("exp value:\n{}".format(self.exp)) | |
# model related init | |
torch.cuda.set_device(self.local_rank) | |
model = self.exp.get_model() | |
logger.info( | |
"Model Summary: {}".format(get_model_info(model, self.exp.test_size)) | |
) | |
model.to(self.device) | |
# solver related init | |
self.optimizer = self.exp.get_optimizer(self.args.batch_size) | |
# value of epoch will be set in `resume_train` | |
model = self.resume_train(model) | |
# data related init | |
self.no_aug = self.start_epoch >= self.max_epoch - self.exp.no_aug_epochs | |
self.train_loader = self.exp.get_data_loader( | |
batch_size=self.args.batch_size, | |
is_distributed=self.is_distributed, | |
no_aug=self.no_aug, | |
cache_img=self.args.cache, | |
) | |
logger.info("init prefetcher, this might take one minute or less...") | |
self.prefetcher = DataPrefetcher(self.train_loader) | |
# max_iter means iters per epoch | |
self.max_iter = len(self.train_loader) | |
self.lr_scheduler = self.exp.get_lr_scheduler( | |
self.exp.basic_lr_per_img * self.args.batch_size, self.max_iter | |
) | |
if self.args.occupy: | |
occupy_mem(self.local_rank) | |
if self.is_distributed: | |
model = DDP(model, device_ids=[self.local_rank], broadcast_buffers=False) | |
if self.use_model_ema: | |
self.ema_model = ModelEMA(model, 0.9998) | |
self.ema_model.updates = self.max_iter * self.start_epoch | |
self.model = model | |
self.evaluator = self.exp.get_evaluator( | |
batch_size=self.args.batch_size, is_distributed=self.is_distributed | |
) | |
# Tensorboard and Wandb loggers | |
if self.rank == 0: | |
if self.args.logger == "tensorboard": | |
self.tblogger = SummaryWriter(os.path.join(self.file_name, "tensorboard")) | |
elif self.args.logger == "wandb": | |
self.wandb_logger = WandbLogger.initialize_wandb_logger( | |
self.args, | |
self.exp, | |
self.evaluator.dataloader.dataset | |
) | |
else: | |
raise ValueError("logger must be either 'tensorboard' or 'wandb'") | |
logger.info("Training start...") | |
logger.info("\n{}".format(model)) | |

    def after_train(self):
        logger.info(
            "Training of experiment is done and the best AP is {:.2f}".format(self.best_ap * 100)
        )
        if self.rank == 0:
            if self.args.logger == "wandb":
                self.wandb_logger.finish()

    def before_epoch(self):
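        """During the last `no_aug_epochs` epochs, turn off mosaic augmentation,
        enable the extra L1 loss on the detection head, evaluate every epoch,
        and snapshot the weights from the final mosaic epoch."""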
logger.info("---> start train epoch{}".format(self.epoch + 1)) | |
if self.epoch + 1 == self.max_epoch - self.exp.no_aug_epochs or self.no_aug: | |
logger.info("--->No mosaic aug now!") | |
self.train_loader.close_mosaic() | |
logger.info("--->Add additional L1 loss now!") | |
if self.is_distributed: | |
self.model.module.head.use_l1 = True | |
else: | |
self.model.head.use_l1 = True | |
self.exp.eval_interval = 1 | |
if not self.no_aug: | |
self.save_ckpt(ckpt_name="last_mosaic_epoch") | |

    def after_epoch(self):
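        """Save the latest checkpoint and run evaluation every `eval_interval` epochs."""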
self.save_ckpt(ckpt_name="latest") | |
if (self.epoch + 1) % self.exp.eval_interval == 0: | |
all_reduce_norm(self.model) | |
self.evaluate_and_save_model() | |
def before_iter(self): | |
pass | |

    def after_iter(self):
        """
        `after_iter` contains two parts of logic:
            * log training status information every `print_interval` iterations
            * randomly resize the training input every 10 iterations
        """
        # log needed information
        if (self.iter + 1) % self.exp.print_interval == 0:
            # TODO check ETA logic
            left_iters = self.max_iter * self.max_epoch - (self.progress_in_iter + 1)
            eta_seconds = self.meter["iter_time"].global_avg * left_iters
            eta_str = "ETA: {}".format(datetime.timedelta(seconds=int(eta_seconds)))

            progress_str = "epoch: {}/{}, iter: {}/{}".format(
                self.epoch + 1, self.max_epoch, self.iter + 1, self.max_iter
            )
            loss_meter = self.meter.get_filtered_meter("loss")
            loss_str = ", ".join(
                ["{}: {:.1f}".format(k, v.latest) for k, v in loss_meter.items()]
            )

            time_meter = self.meter.get_filtered_meter("time")
            time_str = ", ".join(
                ["{}: {:.3f}s".format(k, v.avg) for k, v in time_meter.items()]
            )

            mem_str = "gpu mem: {:.0f}Mb, mem: {:.1f}Gb".format(gpu_mem_usage(), mem_usage())

            logger.info(
                "{}, {}, {}, {}, lr: {:.3e}".format(
                    progress_str,
                    mem_str,
                    time_str,
                    loss_str,
                    self.meter["lr"].latest,
                )
                + (", size: {:d}, {}".format(self.input_size[0], eta_str))
            )

            if self.rank == 0:
                if self.args.logger == "tensorboard":
                    self.tblogger.add_scalar(
                        "train/lr", self.meter["lr"].latest, self.progress_in_iter)
                    for k, v in loss_meter.items():
                        self.tblogger.add_scalar(
                            f"train/{k}", v.latest, self.progress_in_iter)
                if self.args.logger == "wandb":
                    metrics = {"train/" + k: v.latest for k, v in loss_meter.items()}
                    metrics.update({
                        "train/lr": self.meter["lr"].latest
                    })
                    self.wandb_logger.log_metrics(metrics, step=self.progress_in_iter)

            self.meter.clear_meters()

        # random resizing
        if (self.progress_in_iter + 1) % 10 == 0:
            self.input_size = self.exp.random_resize(
                self.train_loader, self.epoch, self.rank, self.is_distributed
            )

    @property
    def progress_in_iter(self):
        # global iteration index, used for LR scheduling and logging
        return self.epoch * self.max_iter + self.iter

    def resume_train(self, model):
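        """Restore training state.

        With `args.resume`, load model/optimizer state, `best_ap`, and the
        starting epoch from the latest (or an explicitly given) checkpoint.
        Otherwise, optionally load weights from `args.ckpt` for fine-tuning and
        start from epoch 0. Also sets `self.start_epoch`.
        """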
        if self.args.resume:
            logger.info("resume training")
            if self.args.ckpt is None:
                ckpt_file = os.path.join(self.file_name, "latest" + "_ckpt.pth")
            else:
                ckpt_file = self.args.ckpt

            ckpt = torch.load(ckpt_file, map_location=self.device)
            # resume the model/optimizer state dict
            model.load_state_dict(ckpt["model"])
            self.optimizer.load_state_dict(ckpt["optimizer"])
            self.best_ap = ckpt.pop("best_ap", 0)
            # resume the training states variables
            start_epoch = (
                self.args.start_epoch - 1
                if self.args.start_epoch is not None
                else ckpt["start_epoch"]
            )
            self.start_epoch = start_epoch
            logger.info(
                "loaded checkpoint '{}' (epoch {})".format(
                    self.args.resume, self.start_epoch
                )
            )  # noqa
        else:
            if self.args.ckpt is not None:
                logger.info("loading checkpoint for fine tuning")
                ckpt_file = self.args.ckpt
                ckpt = torch.load(ckpt_file, map_location=self.device)["model"]
                model = load_ckpt(model, ckpt)
            self.start_epoch = 0

        return model

    def evaluate_and_save_model(self):
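        """Evaluate the (EMA) model on the validation set, log AP metrics to the
        configured logger, and save checkpoints, marking the best one so far."""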
        if self.use_model_ema:
            evalmodel = self.ema_model.ema
        else:
            evalmodel = self.model
            if is_parallel(evalmodel):
                evalmodel = evalmodel.module

        with adjust_status(evalmodel, training=False):
            (ap50_95, ap50, summary), predictions = self.exp.eval(
                evalmodel, self.evaluator, self.is_distributed, return_outputs=True
            )

        update_best_ckpt = ap50_95 > self.best_ap
        self.best_ap = max(self.best_ap, ap50_95)

        if self.rank == 0:
            if self.args.logger == "tensorboard":
                self.tblogger.add_scalar("val/COCOAP50", ap50, self.epoch + 1)
                self.tblogger.add_scalar("val/COCOAP50_95", ap50_95, self.epoch + 1)
            if self.args.logger == "wandb":
                self.wandb_logger.log_metrics({
                    "val/COCOAP50": ap50,
                    "val/COCOAP50_95": ap50_95,
                    "train/epoch": self.epoch + 1,
                })
                self.wandb_logger.log_images(predictions)
            logger.info("\n" + summary)
        synchronize()

        self.save_ckpt("last_epoch", update_best_ckpt, ap=ap50_95)
        if self.save_history_ckpt:
            self.save_ckpt(f"epoch_{self.epoch + 1}", ap=ap50_95)

    def save_ckpt(self, ckpt_name, update_best_ckpt=False, ap=None):
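        """On rank 0, write the (EMA) weights, optimizer state, and AP bookkeeping
        to `self.file_name` as `<ckpt_name>_ckpt.pth`; when `update_best_ckpt` is
        True the checkpoint is also recorded as the best one. Optionally uploads
        the checkpoint to wandb."""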
        if self.rank == 0:
            save_model = self.ema_model.ema if self.use_model_ema else self.model
            logger.info("Save weights to {}".format(self.file_name))
            ckpt_state = {
                "start_epoch": self.epoch + 1,
                "model": save_model.state_dict(),
                "optimizer": self.optimizer.state_dict(),
                "best_ap": self.best_ap,
                "curr_ap": ap,
            }
            save_checkpoint(
                ckpt_state,
                update_best_ckpt,
                self.file_name,
                ckpt_name,
            )

            if self.args.logger == "wandb":
                self.wandb_logger.save_checkpoint(
                    self.file_name,
                    ckpt_name,
                    update_best_ckpt,
                    metadata={
                        "epoch": self.epoch + 1,
                        "optimizer": self.optimizer.state_dict(),
                        "best_ap": self.best_ap,
                        "curr_ap": ap
                    }
                )