gomoku / DI-engine /dizoo /common /policy /md_rainbow_dqn.py
zjowowen's picture
init space
079c32c
raw
history blame
No virus
4.18 kB
from typing import Dict, Any
import torch
from ding.torch_utils import to_device
from ding.rl_utils import dist_nstep_td_data, dist_nstep_td_error, dist_1step_td_data, dist_1step_td_error
from ding.policy import RainbowDQNPolicy
from ding.utils import POLICY_REGISTRY
from ding.policy.common_utils import default_preprocess_learn
@POLICY_REGISTRY.register('md_rainbow_dqn')
class MultiDiscreteRainbowDQNPolicy(RainbowDQNPolicy):
    r"""
    Overview:
        Multi-discrete action space Rainbow DQN algorithms.
        Only the learn-mode forward pass is overridden here; everything else is inherited from
        ``RainbowDQNPolicy``.
    """

    def _dist_nstep_td(self, q_dist, target_q_dist, action, target_q_action, data: dict, value_gamma) -> tuple:
        """
        Overview:
            Build the n-step distributional TD data for one action head and evaluate its TD error.
        Arguments:
            - q_dist: 'distribution' output of the learn model on ``data['obs']`` for this head
            - target_q_dist: 'distribution' output of the target model on ``data['next_obs']`` for this head
            - action: the action stored in the batch for this head
            - target_q_action: 'action' output of the learn model on ``data['next_obs']`` for this head
            - data (:obj:`dict`): The preprocessed batch; 'reward', 'done' and 'weight' are read from it
            - value_gamma: Value of ``data.get('value_gamma')``, may be None
        Returns:
            - result (:obj:`tuple`): ``(loss, td_error_per_sample)`` exactly as returned by ``dist_nstep_td_error``
        """
        td_data = dist_nstep_td_data(
            q_dist, target_q_dist, action, target_q_action, data['reward'], data['done'], data['weight']
        )
        return dist_nstep_td_error(
            td_data,
            self._gamma,
            self._v_min,
            self._v_max,
            self._n_atom,
            nstep=self._nstep,
            value_gamma=value_gamma
        )

    def _forward_learn(self, data: dict) -> Dict[str, Any]:
        """
        Overview:
            Forward and backward function of learn mode, acquire the data and calculate the loss and \
            optimize learner model
        Arguments:
            - data (:obj:`dict`): Dict type data, including at least ['obs', 'next_obs', 'reward', 'action']
        Returns:
            - info_dict (:obj:`Dict[str, Any]`): Including cur_lr, total_loss and priority
                - cur_lr (:obj:`float`): current learning rate
                - total_loss (:obj:`float`): the calculated loss
                - priority (:obj:`list`): the priority of samples, i.e. the absolute per-sample TD error \
                    (averaged over action heads when the model returns one distribution per head)
        """
        data = default_preprocess_learn(
            data,
            use_priority=self._priority,
            use_priority_IS_weight=self._cfg.priority_IS_weight,
            ignore_done=self._cfg.learn.ignore_done,
            use_nstep=True
        )
        if self._cuda:
            data = to_device(data, self._device)
        # ====================
        # Rainbow forward
        # ====================
        self._learn_model.train()
        self._target_model.train()

        # reset noise of noisenet for both main model and target model
        self._reset_noise(self._learn_model)
        self._reset_noise(self._target_model)
        q_dist = self._learn_model.forward(data['obs'])['distribution']
        with torch.no_grad():
            target_q_dist = self._target_model.forward(data['next_obs'])['distribution']
            # reset the learn model's noise once more before it selects the next-step action
            self._reset_noise(self._learn_model)
            target_q_action = self._learn_model.forward(data['next_obs'])['action']

        value_gamma = data.get('value_gamma', None)
        if isinstance(q_dist, torch.Tensor):
            # Single action head: the model returned one distribution tensor.
            loss, td_error_per_sample = self._dist_nstep_td(
                q_dist, target_q_dist, data['action'], target_q_action, data, value_gamma
            )
            # The return statement reads ``td_error_per_sample_mean``; with a single head the
            # per-sample TD error is already the value to report. Without this binding the
            # single-head path would fail with UnboundLocalError after the optimizer step.
            td_error_per_sample_mean = td_error_per_sample
        else:
            # Multiple action heads: ``q_dist`` is an indexable sequence with one entry per head.
            act_num = len(q_dist)
            losses = []
            td_error_per_samples = []
            for i in range(act_num):
                td_loss, td_error_per_sample = self._dist_nstep_td(
                    q_dist[i], target_q_dist[i], data['action'][i], target_q_action[i], data, value_gamma
                )
                losses.append(td_loss)
                td_error_per_samples.append(td_error_per_sample)
            # Average over heads; the small epsilon keeps the denominator non-zero.
            loss = sum(losses) / (len(losses) + 1e-8)
            td_error_per_sample_mean = sum(td_error_per_samples) / (len(td_error_per_samples) + 1e-8)
        # ====================
        # Rainbow update
        # ====================
        self._optimizer.zero_grad()
        loss.backward()
        self._optimizer.step()
        # =============
        # after update
        # =============
        self._target_model.update(self._learn_model.state_dict())
        return {
            'cur_lr': self._optimizer.defaults['lr'],
            'total_loss': loss.item(),
            'priority': td_error_per_sample_mean.abs().tolist(),
        }