|
import gym |
|
import numpy as np |
|
from easydict import EasyDict |
|
|
|
from ding.envs import BaseEnvTimestep |
|
from ding.utils import ENV_WRAPPER_REGISTRY |
|
|
|
|
|
@ENV_WRAPPER_REGISTRY.register('lightzero_env_wrapper')
class LightZeroEnvWrapper(gym.Wrapper):
    """
    Overview:
        Package the classic_control, box2d environment into the format required by LightZero.
        Wrap obs as a dict, containing keys: obs, action_mask and to_play.
    Interface:
        ``__init__``, ``reset``, ``step``
    Properties:
        - env (:obj:`gym.Env`): the environment to wrap.
    """

    def __init__(self, env: gym.Env, cfg: EasyDict) -> None:
        """
        Overview:
            Wrap ``env`` so that observations follow the LightZero dict format.
        Arguments:
            - env (:obj:`gym.Env`): the environment to wrap.
            - cfg (:obj:`EasyDict`): config; must contain ``is_train``, ``env_name`` and \
                ``continuous`` keys.
        """
        super().__init__(env)
        assert 'is_train' in cfg, '`is_train` flag must set in the config of env'
        self.is_train = cfg.is_train
        self.cfg = cfg
        self.env_name = cfg.env_name
        self.continuous = cfg.continuous

    def _action_mask(self):
        """
        Overview:
            Build the action mask used in the LightZero obs dict: ``None`` for a
            continuous action space, otherwise an all-ones ``int8`` vector of
            length ``action_space.n`` (every discrete action is legal).
        Returns:
            - action_mask (:obj:`Optional[np.ndarray]`): mask or ``None``.
        """
        if self.cfg.continuous:
            return None
        return np.ones(self.env.action_space.n, 'int8')

    def reset(self, **kwargs):
        """
        Overview:
            Reset the wrapped environment, reset the episode-return accumulator,
            rebuild ``self._observation_space`` as a dict space, and return the
            initial observation packed into the LightZero dict format.
        Arguments:
            - kwargs (:obj:`Dict`): keyword arguments forwarded to ``env.reset``.
        Returns:
            - lightzero_obs_dict (:obj:`Dict`): ``{'observation': ..., 'action_mask': ..., 'to_play': -1}``.
        """
        obs = self.env.reset(**kwargs)
        self._eval_episode_return = 0.
        self._raw_observation_space = self.env.observation_space

        action_mask = self._action_mask()

        if self.cfg.continuous:
            # Continuous actions carry no mask; declare an unbounded 1-d placeholder
            # space. NOTE: the original code used ``low=np.inf`` which yields a
            # degenerate (empty) interval; ``-np.inf`` is the intended lower bound.
            mask_space = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(1, ))
        else:
            # Each entry of the mask is a binary flag (action legal / illegal).
            # The non-Discrete fallback (``shape[0]``) is kept from the original
            # code, although in the discrete config the space is expected to be
            # ``Discrete`` — TODO confirm against callers.
            if isinstance(self.env.action_space, gym.spaces.Discrete):
                mask_len = self.env.action_space.n
            else:
                mask_len = self.env.action_space.shape[0]
            mask_space = gym.spaces.MultiDiscrete([2 for _ in range(mask_len)])

        self._observation_space = gym.spaces.Dict(
            {
                'observation': self._raw_observation_space,
                'action_mask': mask_space,
                # Single-agent setting: ``to_play`` is fixed at -1.
                'to_play': gym.spaces.Box(low=-1, high=-1, shape=(1, )),
            }
        )

        lightzero_obs_dict = {'observation': obs, 'action_mask': action_mask, 'to_play': -1}
        return lightzero_obs_dict

    def step(self, action):
        """
        Overview:
            Step the wrapped environment with ``action``, accumulate the episode
            return, and pack the observation into the LightZero dict format.
        Arguments:
            - action (:obj:`Any`): the action to step with.
        Returns:
            - timestep (:obj:`BaseEnvTimestep`): ``(obs_dict, reward, done, info)``; \
                when ``done`` is True, ``info['eval_episode_return']`` holds the \
                accumulated episode return.
        """
        obs, rew, done, info = self.env.step(action)

        lightzero_obs_dict = {'observation': obs, 'action_mask': self._action_mask(), 'to_play': -1}

        self._eval_episode_return += rew
        if done:
            info['eval_episode_return'] = self._eval_episode_return

        return BaseEnvTimestep(lightzero_obs_dict, rew, done, info)

    def __repr__(self) -> str:
        return "LightZero Env."