import math
from typing import List, Dict, Any, Tuple
from collections import namedtuple

import torch
import torch.nn as nn
from torch.optim import Adam, SGD, AdamW
from torch.optim.lr_scheduler import LambdaLR

from ding.policy import Policy
from ding.model import model_wrap
from ding.torch_utils import to_device
from ding.utils import EasyTimer
from ding.utils import POLICY_REGISTRY


@POLICY_REGISTRY.register('pc_bfs')
class ProcedureCloningBFSPolicy(Policy):
    """Imitation policy that clones the intermediate steps of a BFS procedure
    on maze grids: the model learns to map the current BFS action-map to the
    next one, and at evaluation time the map is iterated until the agent's own
    cell receives a real action (or ``max_bfs_steps`` is exhausted).
    """

    def default_model(self) -> Tuple[str, List[str]]:
        """Return the default model name and its import path."""
        return 'pc_bfs', ['ding.model.template.procedure_cloning']

    # NOTE(review): registered under 'pc_bfs' but config ``type`` is 'pc' —
    # confirm this mismatch is intentional before changing it.
    config = dict(
        type='pc',
        cuda=False,
        on_policy=False,
        continuous=False,
        # Upper bound on BFS-map refinement iterations during evaluation.
        max_bfs_steps=100,
        learn=dict(
            update_per_collect=1,
            batch_size=32,
            learning_rate=1e-5,
            lr_decay=False,
            decay_epoch=30,
            decay_rate=0.1,
            warmup_lr=1e-4,
            warmup_epoch=3,
            optimizer='SGD',
            momentum=0.9,
            weight_decay=1e-4,
        ),
        collect=dict(
            unroll_len=1,
            noise=False,
            noise_sigma=0.2,
            noise_range=dict(
                min=-0.5,
                max=0.5,
            ),
        ),
        eval=dict(),
        other=dict(replay_buffer=dict(replay_buffer_size=10000)),
    )

    def _init_learn(self):
        """Build the optimizer, optional LR schedule, model wrapper and loss.

        Reads ``self._cfg.learn`` for optimizer hyper-parameters and
        ``self._cfg.maze_size`` / ``self._cfg.num_actions`` for the task shape.
        """
        assert self._cfg.learn.optimizer in ['SGD', 'Adam']
        if self._cfg.learn.optimizer == 'SGD':
            self._optimizer = SGD(
                self._model.parameters(),
                lr=self._cfg.learn.learning_rate,
                weight_decay=self._cfg.learn.weight_decay,
                momentum=self._cfg.learn.momentum
            )
        elif self._cfg.learn.optimizer == 'Adam':
            # AdamW implements decoupled weight decay; plain Adam is only used
            # when weight decay is explicitly disabled.
            if self._cfg.learn.weight_decay is None:
                self._optimizer = Adam(
                    self._model.parameters(),
                    lr=self._cfg.learn.learning_rate,
                )
            else:
                self._optimizer = AdamW(
                    self._model.parameters(),
                    lr=self._cfg.learn.learning_rate,
                    weight_decay=self._cfg.learn.weight_decay
                )
        if self._cfg.learn.lr_decay:

            def lr_scheduler_fn(epoch):
                # Constant warmup ratio first, then stepwise decay every
                # ``decay_epoch`` epochs.
                if epoch <= self._cfg.learn.warmup_epoch:
                    return self._cfg.learn.warmup_lr / self._cfg.learn.learning_rate
                ratio = (epoch - self._cfg.learn.warmup_epoch) // self._cfg.learn.decay_epoch
                return math.pow(self._cfg.learn.decay_rate, ratio)

            self._lr_scheduler = LambdaLR(self._optimizer, lr_scheduler_fn)
        self._timer = EasyTimer(cuda=True)
        self._learn_model = model_wrap(self._model, 'base')
        self._learn_model.reset()
        self._max_bfs_steps = self._cfg.max_bfs_steps
        self._maze_size = self._cfg.maze_size
        self._num_actions = self._cfg.num_actions
        self._loss = nn.CrossEntropyLoss()

    def process_states(self, observations, maze_maps):
        """Returns [B, W, W, 3] binary values. Channels are (wall; goal; obs)"""
        # One-hot of the agent position flattened over the W*W grid.
        loc = torch.nn.functional.one_hot(
            (observations[:, 0] * self._maze_size + observations[:, 1]).long(),
            self._maze_size * self._maze_size,
        ).long()
        loc = torch.reshape(loc, [observations.shape[0], self._maze_size, self._maze_size])
        # NOTE(review): cat along dim=-1 assumes ``maze_maps`` and ``loc`` are
        # compatible on the last axis — confirm maze_maps layout at the caller.
        states = torch.cat([maze_maps, loc], dim=-1).long()
        return states

    def _forward_learn(self, data):
        """One gradient step on a batch of (obs, bfs_in, bfs_out) tensors.

        Returns a dict of scalars for logging: current LR, cross-entropy loss
        and per-cell prediction accuracy.
        """
        if self._cuda:
            collated_data = to_device(data, self._device)
        else:
            collated_data = data
        # BUGFIX: the original had a stray trailing comma
        # (``observations = collated_data['obs'],``) which made ``observations``
        # a 1-tuple and crashed the subsequent ``torch.cat``.
        observations = collated_data['obs']
        bfs_input_maps = collated_data['bfs_in'].long()
        bfs_output_maps = collated_data['bfs_out'].long()
        states = observations
        bfs_input_onehot = torch.nn.functional.one_hot(bfs_input_maps, self._num_actions + 1).float()

        bfs_states = torch.cat([
            states,
            bfs_input_onehot,
        ], dim=-1)
        logits = self._model(bfs_states)['logit']
        logits = logits.flatten(0, -2)
        labels = bfs_output_maps.flatten(0, -1)

        loss = self._loss(logits, labels)
        preds = torch.argmax(logits, dim=-1)
        # ``.item()`` so logging receives a plain Python float, not a tensor.
        acc = (torch.sum(preds == labels) / preds.shape[0]).item()

        self._optimizer.zero_grad()
        loss.backward()
        self._optimizer.step()

        pred_loss = loss.item()
        cur_lr = [param_group['lr'] for param_group in self._optimizer.param_groups]
        cur_lr = sum(cur_lr) / len(cur_lr)
        return {'cur_lr': cur_lr, 'total_loss': pred_loss, 'acc': acc}

    def _monitor_vars_learn(self):
        """Names of the scalars emitted by ``_forward_learn`` for the logger."""
        return ['cur_lr', 'total_loss', 'acc']

    def _init_eval(self):
        """Wrap the model for evaluation mode."""
        self._eval_model = model_wrap(self._model, wrapper_name='base')
        self._eval_model.reset()

    def _forward_eval(self, data):
        """Iteratively refine the BFS action map until the agent's own cell is
        assigned a real action, then emit that action (random fallback if the
        budget of ``max_bfs_steps`` iterations runs out).
        """
        if self._cuda:
            data = to_device(data, self._device)
        max_len = self._max_bfs_steps
        data_id = list(data.keys())
        output = {}

        for ii in data_id:
            states = data[ii].unsqueeze(0)
            # Start with every cell labeled "no action yet" (= num_actions).
            bfs_input_maps = self._num_actions * torch.ones([1, self._maze_size, self._maze_size]).long()
            if self._cuda:
                bfs_input_maps = to_device(bfs_input_maps, self._device)
            # Agent location: the cell whose last state channel equals 1.
            xy = torch.where(states[:, :, :, -1] == 1)
            observation = (xy[1][0].item(), xy[2][0].item())
            i = 0
            while bfs_input_maps[0, observation[0], observation[1]].item() == self._num_actions and i < max_len:
                bfs_input_onehot = torch.nn.functional.one_hot(bfs_input_maps, self._num_actions + 1).long()
                bfs_states = torch.cat([
                    states,
                    bfs_input_onehot,
                ], dim=-1)
                logits = self._model(bfs_states)['logit']
                bfs_input_maps = torch.argmax(logits, dim=-1)
                i += 1
            action = bfs_input_maps[0, observation[0], observation[1]]
            if self._cuda:
                action = to_device(action, 'cpu')
            # BUGFIX: the original only wrapped the result in a dict on the
            # CUDA path, so the CPU path crashed on ``output[ii]['action']``.
            output[ii] = {'action': action, 'info': {}}
            if output[ii]['action'].item() == self._num_actions:
                # The map never reached the agent's cell: random fallback.
                output[ii]['action'] = torch.randint(low=0, high=self._num_actions, size=[1])[0]
        return output

    def _init_collect(self) -> None:
        """Collect mode is not supported for this offline imitation policy."""
        raise NotImplementedError

    def _forward_collect(self, data: Dict[int, Any], **kwargs) -> Dict[int, Any]:
        """Collect mode is not supported for this offline imitation policy."""
        raise NotImplementedError

    def _process_transition(self, obs: Any, model_output: dict, timestep: namedtuple) -> dict:
        """Collect mode is not supported for this offline imitation policy."""
        raise NotImplementedError

    def _get_train_sample(self, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Collect mode is not supported for this offline imitation policy."""
        raise NotImplementedError