zjowowen's picture
init space
079c32c
raw
history blame
No virus
2.96 kB
import math
import torch
import torch.nn as nn
from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR
from ding.policy import Policy
from ding.model import model_wrap
from ding.torch_utils import to_device
from ding.utils import EasyTimer
class ImageClassificationPolicy(Policy):
    """Supervised image-classification policy trained with SGD and cross-entropy.

    Only the learn and eval hooks do real work here; the collect-side hooks
    (`_init_collect`, `_forward_collect`, `_process_transition`,
    `_get_train_sample`) are no-op stubs.
    """

    config = dict(type='image_classification', on_policy=False)

    def _init_learn(self):
        """Build the optimizer, timer, LR scheduler, learn model and loss.

        Reads `learning_rate`, `weight_decay`, `warmup_epoch`, `warmup_lr`,
        `decay_epoch` and `decay_rate` from `self._cfg.learn`.
        """
        model_params = self._model.parameters()
        base_lr = self._cfg.learn.learning_rate
        l2_penalty = self._cfg.learn.weight_decay
        self._optimizer = SGD(model_params, lr=base_lr, weight_decay=l2_penalty, momentum=0.9)
        self._timer = EasyTimer(cuda=True)

        def lr_lambda(epoch):
            # LambdaLR multiplies the base learning rate by this factor.
            # The config is re-read on every call rather than captured once.
            learn_cfg = self._cfg.learn
            if epoch <= learn_cfg.warmup_epoch:
                # Factor chosen so that base_lr * factor == warmup_lr.
                return learn_cfg.warmup_lr / learn_cfg.learning_rate
            # Step decay: one extra power of decay_rate per decay_epoch epochs.
            num_decays = epoch // learn_cfg.decay_epoch
            return math.pow(learn_cfg.decay_rate, num_decays)

        self._lr_scheduler = LambdaLR(self._optimizer, lr_lambda)
        # Stepped once here; this class never steps the scheduler again.
        # NOTE(review): later steps are presumably driven by the caller — confirm.
        self._lr_scheduler.step()

        self._learn_model = model_wrap(self._model, 'base')
        self._learn_model.reset()

        self._ce_loss = nn.CrossEntropyLoss()

    def _forward_learn(self, data):
        """Run one optimization step on a batch and report metrics.

        Args:
            data: A two-item sequence unpacked as ``(img, target)``.

        Returns:
            A dict with the keys ``cur_lr`` (mean learning rate over the
            optimizer's param groups), ``total_loss``, ``forward_time``,
            ``backward_time`` and ``sync_time``.
        """
        if self._cuda:
            data = to_device(data, self._device)
        self._learn_model.train()

        # Forward pass and loss, timed together.
        with self._timer:
            img, target = data
            logit = self._learn_model.forward(img)
            loss = self._ce_loss(logit, target)
        forward_time = self._timer.value

        # Gradient reset and backward pass.
        with self._timer:
            self._optimizer.zero_grad()
            loss.backward()
        backward_time = self._timer.value

        # Gradient synchronization; the section is timed even when multi_gpu
        # is off and nothing is synced.
        with self._timer:
            if self._cfg.multi_gpu:
                self.sync_gradients(self._learn_model)
        sync_time = self._timer.value

        self._optimizer.step()

        group_lrs = [group['lr'] for group in self._optimizer.param_groups]
        return {
            'cur_lr': sum(group_lrs) / len(group_lrs),
            'total_loss': loss.item(),
            'forward_time': forward_time,
            'backward_time': backward_time,
            'sync_time': sync_time,
        }

    def _monitor_vars_learn(self):
        """Return the metric names, matching the keys `_forward_learn` returns."""
        return ['cur_lr', 'total_loss', 'forward_time', 'backward_time', 'sync_time']

    def _init_eval(self):
        """Wrap the model with the 'base' wrapper for evaluation."""
        self._eval_model = model_wrap(self._model, 'base')

    def _forward_eval(self, data):
        """Run the eval model on `data` without gradient tracking.

        Args:
            data: Input passed directly to the eval model's ``forward``.

        Returns:
            The model output, moved to CPU when ``self._cuda`` is set.
        """
        if self._cuda:
            data = to_device(data, self._device)
        self._eval_model.eval()
        with torch.no_grad():
            prediction = self._eval_model.forward(data)
        return to_device(prediction, 'cpu') if self._cuda else prediction

    def _init_collect(self):
        """No-op stub; nothing is initialized for collection."""
        pass

    def _forward_collect(self, data):
        """No-op stub; ignores `data` and returns None."""
        pass

    def _process_transition(self):
        """No-op stub; returns None."""
        pass

    def _get_train_sample(self):
        """No-op stub; returns None."""
        pass