# lzero/envs/wrappers/lightzero_env_wrapper.py
import gym
import numpy as np
from easydict import EasyDict
from ding.envs import BaseEnvTimestep
from ding.utils import ENV_WRAPPER_REGISTRY


@ENV_WRAPPER_REGISTRY.register('lightzero_env_wrapper')
class LightZeroEnvWrapper(gym.Wrapper):
"""
Overview:
Package the classic_control, box2d environment into the format required by LightZero.
Wrap obs as a dict, containing keys: obs, action_mask and to_play.
Interface:
``__init__``, ``reset``, ``step``
Properties:
- env (:obj:`gym.Env`): the environment to wrap.
"""

    def __init__(self, env: gym.Env, cfg: EasyDict) -> None:
        """
        Overview:
            Initialize ``self``. See ``help(type(self))`` for the accurate signature.
        Arguments:
            - env (:obj:`gym.Env`): the environment to wrap.
            - cfg (:obj:`EasyDict`): the env config, which must contain the ``is_train`` flag.
        """
        super().__init__(env)
        assert 'is_train' in cfg, '`is_train` flag must be set in the env config'
        self.is_train = cfg.is_train
        self.cfg = cfg
        self.env_name = cfg.env_name
        self.continuous = cfg.continuous

    def reset(self, **kwargs):
        """
        Overview:
            Reset the state of the environment and the wrapper's episode statistics.
        Arguments:
            - kwargs (:obj:`Dict`): keyword arguments forwarded to the wrapped env's ``reset``.
        Returns:
            - lightzero_obs_dict (:obj:`Dict`): the new observation dict after reset, \
              with keys ``observation``, ``action_mask`` and ``to_play``.
        """
        # The core original env reset.
        obs = self.env.reset(**kwargs)
        self._eval_episode_return = 0.
        self._raw_observation_space = self.env.observation_space
        if self.cfg.continuous:
            # In the continuous-action case there is no meaningful action mask.
            action_mask = None
            self._observation_space = gym.spaces.Dict(
                {
                    'observation': self._raw_observation_space,
                    'action_mask': gym.spaces.Box(low=-np.inf, high=np.inf,
                                                  shape=(1, )),  # TODO: gym.spaces.Constant(None)
                    'to_play': gym.spaces.Box(low=-1, high=-1, shape=(1, )),  # TODO: gym.spaces.Constant(-1)
                }
            )
        else:
            # In the discrete-action case, every action is initially legal.
            action_mask = np.ones(self.env.action_space.n, 'int8')
            self._observation_space = gym.spaces.Dict(
                {
                    'observation': self._raw_observation_space,
                    # One binary {0, 1} entry per discrete action.
                    'action_mask': gym.spaces.MultiDiscrete([2] * self.env.action_space.n)
                    if isinstance(self.env.action_space, gym.spaces.Discrete) else
                    gym.spaces.MultiDiscrete([2] * self.env.action_space.shape[0]),
                    'to_play': gym.spaces.Box(low=-1, high=-1, shape=(1, )),  # TODO: gym.spaces.Constant(-1)
                }
            )
        # `to_play` is fixed to -1 for these single-agent environments.
        lightzero_obs_dict = {'observation': obs, 'action_mask': action_mask, 'to_play': -1}
        return lightzero_obs_dict

    def step(self, action):
        """
        Overview:
            Step the wrapped environment with the given action, accumulate the episode \
            return, and repack the observation into the LightZero dict format.
        Arguments:
            - action (:obj:`Any`): the given action to step with.
        Returns:
            - timestep (:obj:`BaseEnvTimestep`): a namedtuple of (obs, reward, done, info), \
              where obs is the LightZero observation dict; once the episode is done, \
              ``info['eval_episode_return']`` holds the accumulated episode return.
        """
        # The core original env step.
        obs, rew, done, info = self.env.step(action)
        if self.cfg.continuous:
            action_mask = None
        else:
            action_mask = np.ones(self.env.action_space.n, 'int8')
        lightzero_obs_dict = {'observation': obs, 'action_mask': action_mask, 'to_play': -1}
        self._eval_episode_return += rew
        if done:
            info['eval_episode_return'] = self._eval_episode_return
        return BaseEnvTimestep(lightzero_obs_dict, rew, done, info)

    def __repr__(self) -> str:
        return "LightZero Env."