# DI-engine: ding/reward_model/drex_reward_model.py
import copy
import pickle

from easydict import EasyDict

from ding.utils import REWARD_MODEL_REGISTRY
from .trex_reward_model import TrexRewardModel


@REWARD_MODEL_REGISTRY.register('drex')
class DrexRewardModel(TrexRewardModel):
"""
Overview:
The Drex reward model class (https://arxiv.org/pdf/1907.03976.pdf)
Interface:
``estimate``, ``train``, ``load_expert_data``, ``collect_data``, ``clear_date``, \
``__init__``, ``_train``,
Config:
== ==================== ====== ============= ======================================= ===============
ID Symbol Type Default Value Description Other(Shape)
== ==================== ====== ============= ======================================= ===============
1 ``type`` str drex | Reward model register name, refer |
| to registry ``REWARD_MODEL_REGISTRY`` |
3 | ``learning_rate`` float 0.00001 | learning rate for optimizer |
4 | ``update_per_`` int 100 | Number of updates per collect |
| ``collect`` | |
5 | ``batch_size`` int 64 | How many samples in a training batch |
6 | ``hidden_size`` int 128 | Linear model hidden size |
7 | ``num_trajs`` int 0 | Number of downsampled full |
| trajectories |
8 | ``num_snippets`` int 6000 | Number of short subtrajectories |
| to sample |
== ==================== ====== ============= ======================================= ================
"""
    config = dict(
        # (str) Reward model register name, refer to registry ``REWARD_MODEL_REGISTRY``.
        type='drex',
        # (float) The step size of gradient descent.
        learning_rate=1e-5,
        # (int) How many updates (iterations) to train after the collector's one collection.
        # Bigger "update_per_collect" means bigger off-policy.
        # collect data -> update policy -> collect data -> ...
        update_per_collect=100,
        # (int) How many samples in a training batch.
        batch_size=64,
        # (int) Linear model hidden size.
        hidden_size=128,
        # (int) Number of downsampled full trajectories.
        num_trajs=0,
        # (int) Number of short subtrajectories to sample.
        num_snippets=6000,
    )
    bc_cfg = None
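
    # Illustrative sketch (the field values and surrounding config layout here are assumptions,
    # not part of this module): the class-level defaults above are typically overridden through
    # the user config before the model is created. Note that ``load_expert_data`` below reads
    # ``self.cfg.reward_model.offline_data_path``, so the config passed in is expected to carry a
    # ``reward_model`` field, e.g.
    #
    #     cfg.reward_model.update(dict(
    #         type='drex',
    #         learning_rate=1e-5,
    #         batch_size=64,
    #         hidden_size=128,
    #         num_snippets=6000,
    #         offline_data_path='./drex_data',  # directory expected to contain suboptimal_data.pkl
    #     ))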
    def __init__(self, config: EasyDict, device: str, tb_logger: 'SummaryWriter') -> None:  # noqa
        """
        Overview:
            Initialize ``self.`` See ``help(type(self))`` for accurate signature.
        Arguments:
            - config (:obj:`EasyDict`): Training config
            - device (:obj:`str`): Device usage, i.e. "cpu" or "cuda"
            - tb_logger (:obj:`SummaryWriter`): Logger, by default a tensorboard ``SummaryWriter``
        """
        super(DrexRewardModel, self).__init__(copy.deepcopy(config), device, tb_logger)
        self.demo_data = []
        self.load_expert_data()

    def load_expert_data(self) -> None:
        """
        Overview:
            Get the expert data from ``config.expert_data_path`` (handled by the parent class) and \
                the suboptimal demonstration data from ``suboptimal_data.pkl`` under \
                ``config.reward_model.offline_data_path``.
        Effects:
            This is a side effect function which updates the expert data attribute \
                (i.e. ``self.expert_data``) with ``fn:concat_state_action_pairs`` and fills \
                ``self.demo_data`` with the loaded suboptimal demonstrations.
        """
        super(DrexRewardModel, self).load_expert_data()
        with open(self.cfg.reward_model.offline_data_path + '/suboptimal_data.pkl', 'rb') as f:
            self.demo_data = pickle.load(f)
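
    # The content of ``suboptimal_data.pkl`` is produced by the DREX data-collection pipeline
    # (an assumption about the surrounding entry point, not checked here): following the paper,
    # it holds rollouts of a behavioral-cloning policy under different noise levels. In ``train``
    # below it is only used to compare real episode returns against the returns predicted by the
    # learned reward model.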

    def train(self) -> None:
        # Train the preference-based reward model, then evaluate it on the suboptimal demo data.
        self._train()
        # Compare real returns with the returns predicted by the learned reward model.
        return_dict = self.pred_data(self.demo_data)
        res, pred_returns = return_dict['real'], return_dict['pred']
        self._logger.info("real: " + str(res))
        self._logger.info("pred: " + str(pred_returns))

        info = {
            "min_snippet_length": self.min_snippet_length,
            "max_snippet_length": self.max_snippet_length,
            "len_num_training_obs": len(self.training_obs),
            "len_num_labels": len(self.training_labels),
            "accuracy": self.calc_accuracy(self.reward_model, self.training_obs, self.training_labels),
        }
        self._logger.info(
            "accuracy and comparison:\n{}".format('\n'.join(['{}: {}'.format(k, v) for k, v in info.items()]))
        )
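
# Hedged usage sketch (the config layout, paths and the ``SummaryWriter`` setup are assumptions;
# only the constructor signature, ``train`` and the inherited ``estimate`` come from this module
# and its parent class):
#
#     from easydict import EasyDict
#     from torch.utils.tensorboard import SummaryWriter
#
#     cfg = EasyDict(...)  # full experiment config with a ``reward_model`` field, elided here
#     tb_logger = SummaryWriter('./log/drex_reward_model')
#     reward_model = DrexRewardModel(cfg, device='cpu', tb_logger=tb_logger)
#     reward_model.train()
#     # reward_model.estimate(data) can then be used to replace environment rewards during RL training.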