from typing import List, Dict, Any, Optional, Callable, Tuple
import copy
import numpy as np
import torch


class HerRewardModel:
    """
    Overview:
        Hindsight Experience Replay model.

    .. note::
        - her_strategy (:obj:`str`): Type of strategy that HER uses, should be in ['final', 'future', 'episode'].
        - her_replay_k (:obj:`int`): Number of new episodes generated from an original episode. \
            (Not used in episodic HER)
        - episode_size (:obj:`int`): How many episodes to sample in one iteration.
        - sample_per_episode (:obj:`int`): How many new samples are generated from one episode.

    .. note::
        HER requires the whole episode trajectory to relabel the goals. However, episode lengths vary
        and may have high variance. As a result, we **recommend** that you only use some transitions
        of the complete episode by specifying ``episode_size`` and ``sample_per_episode`` in config.
        Therefore, in one iteration, ``batch_size`` would be ``episode_size`` * ``sample_per_episode``.
    """

    def __init__(
            self,
            cfg: dict,
            cuda: bool = False,
    ) -> None:
        self._cuda = cuda and torch.cuda.is_available()
        self._device = 'cuda' if self._cuda else 'cpu'
        self._her_strategy = cfg.her_strategy
        assert self._her_strategy in ['final', 'future', 'episode']
        # `her_replay_k` may not be used in episodic HER, so it defaults to 1.
        self._her_replay_k = cfg.get('her_replay_k', 1)
        self._episode_size = cfg.get('episode_size', None)
        self._sample_per_episode = cfg.get('sample_per_episode', None)

    def estimate(
            self,
            episode: List[Dict[str, Any]],
            merge_func: Optional[Callable] = None,
            split_func: Optional[Callable] = None,
            goal_reward_func: Optional[Callable] = None
    ) -> List[List[Dict[str, Any]]]:
        """
        Overview:
            Get HER processed episodes from an original episode.
        Arguments:
            - episode (:obj:`List[Dict[str, Any]]`): Episode list, each element is a transition.
            - merge_func (:obj:`Callable`): The merge function to use, default set to None. If None, \
                then use ``__her_default_merge_func``.
            - split_func (:obj:`Callable`): The split function to use, default set to None. If None, \
                then use ``__her_default_split_func``.
            - goal_reward_func (:obj:`Callable`): The goal_reward function to use, default set to None. If None, \
                then use ``__her_default_goal_reward_func``.
        Returns:
            - new_episodes (:obj:`List[List[Dict[str, Any]]]`): ``her_replay_k`` relabeled episodes, \
                each a list of processed transitions.
        """
        if merge_func is None:
            merge_func = HerRewardModel.__her_default_merge_func
        if split_func is None:
            split_func = HerRewardModel.__her_default_split_func
        if goal_reward_func is None:
            goal_reward_func = HerRewardModel.__her_default_goal_reward_func
        new_episodes = [[] for _ in range(self._her_replay_k)]
        if self._sample_per_episode is None:
            # Use the complete episode.
            indices = range(len(episode))
        else:
            # Use some randomly chosen transitions of the episode.
            indices = np.random.randint(0, len(episode), self._sample_per_episode)
        for idx in indices:
            obs, _, _ = split_func(episode[idx]['obs'])
            next_obs, _, achieved_goal = split_func(episode[idx]['next_obs'])
            for k in range(self._her_replay_k):
                # Select the transition whose achieved goal will serve as the new desired goal.
                if self._her_strategy == 'final':
                    p_idx = -1
                elif self._her_strategy == 'episode':
                    p_idx = np.random.randint(0, len(episode))
                elif self._her_strategy == 'future':
                    p_idx = np.random.randint(idx, len(episode))
                _, _, new_desired_goal = split_func(episode[p_idx]['next_obs'])
                timestep = {
                    key: copy.deepcopy(value)
                    for key, value in episode[idx].items() if key not in ['obs', 'next_obs', 'reward']
                }
                timestep['obs'] = merge_func(obs, new_desired_goal)
                timestep['next_obs'] = merge_func(next_obs, new_desired_goal)
                timestep['reward'] = goal_reward_func(achieved_goal, new_desired_goal).to(self._device)
                new_episodes[k].append(timestep)
        return new_episodes

    @staticmethod
    def __her_default_merge_func(x: Any, y: Any) -> Any:
        r"""
        Overview:
            Merge the original observation and the (new) goal into one HER observation by concatenation.
        Arguments:
            - x (:obj:`Any`): The original observation.
            - y (:obj:`Any`): The goal to merge in.
        Returns:
            - ret (:obj:`Any`): The merged observation.
        """
        # TODO(nyz) dict/list merge_func
        return torch.cat([x, y], dim=0)

    @staticmethod
    def __her_default_split_func(x: Any) -> Tuple[Any, Any, Any]:
        r"""
        Overview:
            Split the input into obs, desired goal, and achieved goal.
        Arguments:
            - x (:obj:`Any`): The input to split.
        Returns:
            - obs (:obj:`torch.Tensor`): The original obs.
            - desired_goal (:obj:`torch.Tensor`): The goal the agent is asked to reach.
            - achieved_goal (:obj:`torch.Tensor`): The goal actually achieved in this observation.
        """
        # TODO(nyz) dict/list split_func
        # achieved_goal = f(obs), default: f == identity function
        obs, desired_goal = torch.chunk(x, 2)
        achieved_goal = obs
        return obs, desired_goal, achieved_goal

    @staticmethod
    def __her_default_goal_reward_func(achieved_goal: torch.Tensor, desired_goal: torch.Tensor) -> torch.Tensor:
        r"""
        Overview:
            Get the goal-conditioned reward according to whether the achieved_goal matches the desired_goal.
        Arguments:
            - achieved_goal (:obj:`torch.Tensor`): The achieved goal.
            - desired_goal (:obj:`torch.Tensor`): The desired goal.
        Returns:
            - goal_reward (:obj:`torch.Tensor`): 1 if the achieved_goal matches the desired_goal, else 0.
        """
        if (achieved_goal == desired_goal).all():
            return torch.FloatTensor([1])
        else:
            return torch.FloatTensor([0])

    @property
    def episode_size(self) -> Optional[int]:
        return self._episode_size

    @property
    def sample_per_episode(self) -> Optional[int]:
        return self._sample_per_episode
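

if __name__ == '__main__':
    # Minimal usage sketch (illustrative, not part of the original module). Assumptions: a toy
    # 1-d goal space where each observation tensor is ``[achieved_goal, desired_goal]`` (which is
    # what the default ``torch.chunk``-based split function expects), and a config object that
    # supports both attribute access and ``.get`` (e.g. easydict.EasyDict, emulated here with a
    # tiny dict subclass). Config keys follow the class docstring above.

    class _AttrDict(dict):
        """Hypothetical helper so ``cfg.her_strategy`` and ``cfg.get(...)`` both work."""
        __getattr__ = dict.__getitem__

    cfg = _AttrDict(her_strategy='future', her_replay_k=2, sample_per_episode=4)
    her_model = HerRewardModel(cfg, cuda=False)

    # Build a toy 8-step episode; each obs concatenates the achieved bit and the desired bit,
    # and only the last step actually reaches the original desired goal.
    desired = torch.FloatTensor([1.])
    episode = []
    for t in range(8):
        achieved = torch.FloatTensor([1. if t == 7 else 0.])
        episode.append(
            {
                'obs': torch.cat([torch.FloatTensor([0.]), desired]),
                'next_obs': torch.cat([achieved, desired]),
                'action': torch.LongTensor([t % 2]),
                'reward': torch.FloatTensor([0.]),
            }
        )

    new_episodes = her_model.estimate(episode)
    # ``estimate`` returns ``her_replay_k`` relabeled episodes, each with ``sample_per_episode``
    # transitions whose goals (and rewards) have been rewritten in hindsight.
    assert len(new_episodes) == 2 and len(new_episodes[0]) == 4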