# lzero/envs/wrappers/action_discretization_env_wrapper.py
from itertools import product
import gym
import numpy as np
from easydict import EasyDict
from ding.envs import BaseEnvTimestep
from ding.torch_utils import to_ndarray
from ding.utils import ENV_WRAPPER_REGISTRY


@ENV_WRAPPER_REGISTRY.register('action_discretization_env_wrapper')
class ActionDiscretizationEnvWrapper(gym.Wrapper):
"""
Overview:
        A wrapper that manually discretizes a continuous action space. Each dimension of the original
        continuous action is divided into ``each_dim_disc_size`` equal bins, and the Cartesian product of
        the per-dimension bins forms the handcrafted discrete action set.
Interface:
``__init__``, ``reset``, ``step``
Properties:
- env (:obj:`gym.Env`): the environment to wrap.
"""

    def __init__(self, env: gym.Env, cfg: EasyDict) -> None:
        """
        Overview:
            Initialize ``self.`` See ``help(type(self))`` for accurate signature; \
            set up the discretization-related properties according to the given config.
        Arguments:
            - env (:obj:`gym.Env`): the environment to wrap.
            - cfg (:obj:`EasyDict`): the config dict of the environment.
        """
super().__init__(env)
        assert 'is_train' in cfg, '`is_train` flag must be set in the config of env'
self.is_train = cfg.is_train
self.cfg = cfg
self.env_name = cfg.env_name
self.continuous = cfg.continuous

    def reset(self, **kwargs):
        """
        Overview:
            Reset the state of the environment and rebuild the discretized action space.
        Arguments:
            - kwargs (:obj:`Dict`): keyword arguments passed to the wrapped environment's reset.
        Returns:
            - observation (:obj:`Any`): new observation after reset.
        """
obs = self.env.reset(**kwargs)
self._raw_action_space = self.env.action_space
if self.cfg.manually_discretization:
# disc_to_cont: transform discrete action index to original continuous action
self.m = self._raw_action_space.shape[0]
self.n = self.cfg.each_dim_disc_size
self.K = self.n ** self.m
self.disc_to_cont = list(product(*[list(range(self.n)) for dim in range(self.m)]))
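            # For example, with m = 2 action dimensions and n = 3 bins per dimension,
            # disc_to_cont = [(0, 0), (0, 1), ..., (2, 2)] and K = 3 ** 2 = 9 discrete actions.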
# the modified discrete action space
self._action_space = gym.spaces.Discrete(self.K)
return obs

    def step(self, action):
        """
        Overview:
            Step the environment with the given action. If ``manually_discretization`` is enabled, \
            the discrete action index is first mapped back to the original continuous action before \
            the wrapped environment is stepped.
        Arguments:
            - action (:obj:`Any`): the given action to step with.
        Returns:
            - timestep (:obj:`BaseEnvTimestep`): a namedtuple of:
                - obs (:obj:`Any`): observation after the input action.
                - reward (:obj:`Any`): amount of reward returned after the previous action.
                - done (:obj:`Bool`): whether the episode has ended, in which case further \
                  step() calls will return undefined results.
                - info (:obj:`Dict`): contains auxiliary diagnostic information (helpful \
                  for debugging, and sometimes learning).
        """
if self.cfg.manually_discretization:
# disc_to_cont: transform discrete action index to original continuous action
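            # Note: this mapping assumes the raw continuous action range is [-1, 1]; bin index
            # k in {0, ..., n - 1} maps to -1 + 2 * k / n, so the upper bound 1 is never reached exactly.
            # For example, with m = 2 and n = 3, discrete action index 5 selects
            # disc_to_cont[5] = (1, 2), which maps to the continuous action [-1/3, 1/3].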
action = [-1 + 2 / self.n * k for k in self.disc_to_cont[int(action)]]
action = to_ndarray(action)
# The core original env step.
obs, rew, done, info = self.env.step(action)
return BaseEnvTimestep(obs, rew, done, info)

    def __repr__(self) -> str:
return "Action Discretization Env."