PKaushik committed on
Commit
cb1de97
1 Parent(s): a66e123
Files changed (1)
  1. yolov6/core/engine.py +273 -0
yolov6/core/engine.py CHANGED
@@ -0,0 +1,273 @@
+ #!/usr/bin/env python3
+ # -*- coding:utf-8 -*-
+ import os
+ import time
+ from copy import deepcopy
+ import os.path as osp
+
+ from tqdm import tqdm
+
+ import numpy as np
+ import torch
+ from torch.cuda import amp
+ from torch.nn.parallel import DistributedDataParallel as DDP
+ from torch.utils.tensorboard import SummaryWriter
+
+ import tools.eval as eval
+ from yolov6.data.data_load import create_dataloader
+ from yolov6.models.yolo import build_model
+ from yolov6.models.loss import ComputeLoss
+ from yolov6.utils.events import LOGGER, NCOLS, load_yaml, write_tblog
+ from yolov6.utils.ema import ModelEMA, de_parallel
+ from yolov6.utils.checkpoint import load_state_dict, save_checkpoint, strip_optimizer
+ from yolov6.solver.build import build_optimizer, build_lr_scheduler
+
+
+ class Trainer:
+     def __init__(self, args, cfg, device):
+         self.args = args
+         self.cfg = cfg
+         self.device = device
+
+         if args.resume:
+             self.ckpt = torch.load(args.resume, map_location='cpu')
+
+         self.rank = args.rank
+         self.local_rank = args.local_rank
+         self.world_size = args.world_size
+         self.main_process = self.rank in [-1, 0]
+         self.save_dir = args.save_dir
+         # get data loader
+         self.data_dict = load_yaml(args.data_path)
+         self.num_classes = self.data_dict['nc']
+         self.train_loader, self.val_loader = self.get_data_loader(args, cfg, self.data_dict)
+         # get model and optimizer
+         model = self.get_model(args, cfg, self.num_classes, device)
+         self.optimizer = self.get_optimizer(args, cfg, model)
+         self.scheduler, self.lf = self.get_lr_scheduler(args, cfg, self.optimizer)
+         self.ema = ModelEMA(model) if self.main_process else None
+         # tensorboard
+         self.tblogger = SummaryWriter(self.save_dir) if self.main_process else None
+         self.start_epoch = 0
+         # resume
+         if hasattr(self, "ckpt"):
+             resume_state_dict = self.ckpt['model'].float().state_dict()  # checkpoint state_dict as FP32
+             model.load_state_dict(resume_state_dict, strict=True)  # load
+             self.start_epoch = self.ckpt['epoch'] + 1
+             self.optimizer.load_state_dict(self.ckpt['optimizer'])
+             if self.main_process:
+                 self.ema.ema.load_state_dict(self.ckpt['ema'].float().state_dict())
+                 self.ema.updates = self.ckpt['updates']
+         self.model = self.parallel_model(args, model, device)
+         self.model.nc, self.model.names = self.data_dict['nc'], self.data_dict['names']
+
+         self.max_epoch = args.epochs
+         self.max_stepnum = len(self.train_loader)
+         self.batch_size = args.batch_size
+         self.img_size = args.img_size
+
+     # Training Process
+
+     def train(self):
+         try:
+             self.train_before_loop()
+             for self.epoch in range(self.start_epoch, self.max_epoch):
+                 self.train_in_loop()
+
+         except Exception as _:
+             LOGGER.error('ERROR in training loop or eval/save model.')
+             raise
+         finally:
+             self.train_after_loop()
+
+     # Training loop for each epoch
+     def train_in_loop(self):
+         try:
+             self.prepare_for_steps()
+             for self.step, self.batch_data in self.pbar:
+                 self.train_in_steps()
+                 self.print_details()
+         except Exception as _:
+             LOGGER.error('ERROR in training steps.')
+             raise
+         try:
+             self.eval_and_save()
+         except Exception as _:
+             LOGGER.error('ERROR in evaluate and save model.')
+             raise
+
+     # Training loop for each batch of data
+     def train_in_steps(self):
+         images, targets = self.prepro_data(self.batch_data, self.device)
+         # forward
+         with amp.autocast(enabled=self.device != 'cpu'):
+             preds = self.model(images)
+             total_loss, loss_items = self.compute_loss(preds, targets)
+             if self.rank != -1:
+                 total_loss *= self.world_size
+         # backward
+         self.scaler.scale(total_loss).backward()
+         self.loss_items = loss_items
+         self.update_optimizer()
+
+     def eval_and_save(self):
+         remaining_epochs = self.max_epoch - self.epoch
+         eval_interval = self.args.eval_interval if remaining_epochs > self.args.heavy_eval_range else 1
+         is_val_epoch = (not self.args.eval_final_only or (remaining_epochs == 1)) and (self.epoch % eval_interval == 0)
+         if self.main_process:
+             self.ema.update_attr(self.model, include=['nc', 'names', 'stride'])  # update attributes for ema model
+             if is_val_epoch:
+                 self.eval_model()
+                 self.ap = self.evaluate_results[0] * 0.1 + self.evaluate_results[1] * 0.9
+                 self.best_ap = max(self.ap, self.best_ap)
+             # save ckpt
+             ckpt = {
+                 'model': deepcopy(de_parallel(self.model)).half(),
+                 'ema': deepcopy(self.ema.ema).half(),
+                 'updates': self.ema.updates,
+                 'optimizer': self.optimizer.state_dict(),
+                 'epoch': self.epoch,
+             }
+
+             save_ckpt_dir = osp.join(self.save_dir, 'weights')
+             save_checkpoint(ckpt, (is_val_epoch) and (self.ap == self.best_ap), save_ckpt_dir, model_name='last_ckpt')
+             del ckpt
+             # log for tensorboard
+             write_tblog(self.tblogger, self.epoch, self.evaluate_results, self.mean_loss)
+
+     def eval_model(self):
+         results = eval.run(self.data_dict,
+                            batch_size=self.batch_size // self.world_size * 2,
+                            img_size=self.img_size,
+                            model=self.ema.ema,
+                            dataloader=self.val_loader,
+                            save_dir=self.save_dir,
+                            task='train')
+
+         LOGGER.info(f"Epoch: {self.epoch} | mAP@0.5: {results[0]} | mAP@0.50:0.95: {results[1]}")
+         self.evaluate_results = results[:2]
+
+     def train_before_loop(self):
+         LOGGER.info('Training start...')
+         self.start_time = time.time()
+         self.warmup_stepnum = max(round(self.cfg.solver.warmup_epochs * self.max_stepnum), 1000)
+         self.scheduler.last_epoch = self.start_epoch - 1
+         self.last_opt_step = -1
+         self.scaler = amp.GradScaler(enabled=self.device != 'cpu')
+
+         self.best_ap, self.ap = 0.0, 0.0
+         self.evaluate_results = (0, 0)  # AP50, AP50_95
+         self.compute_loss = ComputeLoss(iou_type=self.cfg.model.head.iou_type)
+
+     def prepare_for_steps(self):
+         if self.epoch > self.start_epoch:
+             self.scheduler.step()
+         self.model.train()
+         if self.rank != -1:
+             self.train_loader.sampler.set_epoch(self.epoch)
+         self.mean_loss = torch.zeros(4, device=self.device)
+         self.optimizer.zero_grad()
+
+         LOGGER.info(('\n' + '%10s' * 5) % ('Epoch', 'iou_loss', 'l1_loss', 'obj_loss', 'cls_loss'))
+         self.pbar = enumerate(self.train_loader)
+         if self.main_process:
+             self.pbar = tqdm(self.pbar, total=self.max_stepnum, ncols=NCOLS, bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}')
+
+     # Print loss after each step
+     def print_details(self):
+         if self.main_process:
+             self.mean_loss = (self.mean_loss * self.step + self.loss_items) / (self.step + 1)
+             self.pbar.set_description(('%10s' + '%10.4g' * 4) % (f'{self.epoch}/{self.max_epoch - 1}', \
+                                       *(self.mean_loss)))
+
+     # Empty cache if training finished
+     def train_after_loop(self):
+         if self.main_process:
+             LOGGER.info(f'\nTraining completed in {(time.time() - self.start_time) / 3600:.3f} hours.')
+             save_ckpt_dir = osp.join(self.save_dir, 'weights')
+             strip_optimizer(save_ckpt_dir, self.epoch)  # strip optimizers for saved pt model
+         if self.device != 'cpu':
+             torch.cuda.empty_cache()
+
+     def update_optimizer(self):
+         curr_step = self.step + self.max_stepnum * self.epoch
+         self.accumulate = max(1, round(64 / self.batch_size))
+         if curr_step <= self.warmup_stepnum:
+             self.accumulate = max(1, np.interp(curr_step, [0, self.warmup_stepnum], [1, 64 / self.batch_size]).round())
+             for k, param in enumerate(self.optimizer.param_groups):
+                 warmup_bias_lr = self.cfg.solver.warmup_bias_lr if k == 2 else 0.0
+                 param['lr'] = np.interp(curr_step, [0, self.warmup_stepnum], [warmup_bias_lr, param['initial_lr'] * self.lf(self.epoch)])
+                 if 'momentum' in param:
+                     param['momentum'] = np.interp(curr_step, [0, self.warmup_stepnum], [self.cfg.solver.warmup_momentum, self.cfg.solver.momentum])
+         if curr_step - self.last_opt_step >= self.accumulate:
+             self.scaler.step(self.optimizer)
+             self.scaler.update()
+             self.optimizer.zero_grad()
+             if self.ema:
+                 self.ema.update(self.model)
+             self.last_opt_step = curr_step
+
+     @staticmethod
+     def get_data_loader(args, cfg, data_dict):
+         train_path, val_path = data_dict['train'], data_dict['val']
+         # check data
+         nc = int(data_dict['nc'])
+         class_names = data_dict['names']
+         assert len(class_names) == nc, 'the length of class names does not match the number of classes defined'
+         grid_size = max(int(max(cfg.model.head.strides)), 32)
+         # create train dataloader
+         train_loader = create_dataloader(train_path, args.img_size, args.batch_size // args.world_size, grid_size,
+                                          hyp=dict(cfg.data_aug), augment=True, rect=False, rank=args.local_rank,
+                                          workers=args.workers, shuffle=True, check_images=args.check_images,
+                                          check_labels=args.check_labels, data_dict=data_dict, task='train')[0]
+         # create val dataloader
+         val_loader = None
+         if args.rank in [-1, 0]:
+             val_loader = create_dataloader(val_path, args.img_size, args.batch_size // args.world_size * 2, grid_size,
+                                            hyp=dict(cfg.data_aug), rect=True, rank=-1, pad=0.5,
+                                            workers=args.workers, check_images=args.check_images,
+                                            check_labels=args.check_labels, data_dict=data_dict, task='val')[0]
+
+         return train_loader, val_loader
+
+     @staticmethod
+     def prepro_data(batch_data, device):
+         images = batch_data[0].to(device, non_blocking=True).float() / 255
+         targets = batch_data[1].to(device)
+         return images, targets
+
+     def get_model(self, args, cfg, nc, device):
+         model = build_model(cfg, nc, device)
+         weights = cfg.model.pretrained
+         if weights:  # finetune if pretrained model is set
+             LOGGER.info(f'Loading state_dict from {weights} for fine-tuning...')
+             model = load_state_dict(weights, model, map_location=device)
+         LOGGER.info('Model: {}'.format(model))
+         return model
+
+     @staticmethod
+     def parallel_model(args, model, device):
+         # If DP mode
+         dp_mode = device.type != 'cpu' and args.rank == -1
+         if dp_mode and torch.cuda.device_count() > 1:
+             LOGGER.warning('WARNING: DP not recommended, use DDP instead.\n')
+             model = torch.nn.DataParallel(model)
+
+         # If DDP mode
+         ddp_mode = device.type != 'cpu' and args.rank != -1
+         if ddp_mode:
+             model = DDP(model, device_ids=[args.local_rank], output_device=args.local_rank)
+
+         return model
+
+     def get_optimizer(self, args, cfg, model):
+         accumulate = max(1, round(64 / args.batch_size))
+         cfg.solver.weight_decay *= args.batch_size * accumulate / 64
+         optimizer = build_optimizer(cfg, model)
+         return optimizer
+
+     @staticmethod
+     def get_lr_scheduler(args, cfg, optimizer):
+         epochs = args.epochs
+         lr_scheduler, lf = build_lr_scheduler(cfg, optimizer, epochs)
+         return lr_scheduler, lf
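
Usage note: the launcher script that constructs args, cfg and device is not part of this commit, so the snippet below is only a minimal sketch of how this Trainer is meant to be driven. The argument names mirror the attributes the class reads from args and cfg; the Config.fromfile helper, the config path and every default value are assumptions for illustration, not confirmed parts of this repository.

# Hypothetical driver (a tools/train.py stand-in) -- a sketch under the assumptions stated above.
from types import SimpleNamespace

import torch

from yolov6.core.engine import Trainer
from yolov6.utils.config import Config  # assumed config loader; adjust to the repo's actual helper

def main():
    # Hypothetical values; the real CLI flags and defaults live in the (unshown) launcher.
    args = SimpleNamespace(
        data_path='data/coco.yaml',            # dataset yaml providing 'train', 'val', 'nc', 'names'
        epochs=300, batch_size=32, img_size=640, workers=8,
        rank=-1, local_rank=-1, world_size=1,  # rank == -1 selects the single-GPU / DataParallel path
        save_dir='runs/train/exp',
        resume=False, check_images=False, check_labels=False,
        eval_interval=20, heavy_eval_range=50, eval_final_only=False,
    )
    cfg = Config.fromfile('configs/yolov6s.py')  # model / solver / data_aug settings
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    trainer = Trainer(args, cfg, device)  # builds loaders, model, optimizer, EMA, scheduler
    trainer.train()                       # runs the epoch loop, eval and checkpointing

if __name__ == '__main__':
    main()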