gomoku / DI-engine /ding /envs /env_manager /envpool_env_manager.py
zjowowen's picture
init space
079c32c
raw
history blame
No virus
4.22 kB
import gym
from easydict import EasyDict
from copy import deepcopy
import numpy as np
from collections import namedtuple
from typing import Any, Union, List, Tuple, Dict, Callable, Optional
from ditk import logging
try:
import envpool
except ImportError:
import sys
logging.warning("Please install envpool first, use 'pip install envpool'")
envpool = None
from ding.envs import BaseEnvTimestep
from ding.utils import ENV_MANAGER_REGISTRY, deep_merge_dicts
from ding.torch_utils import to_ndarray
@ENV_MANAGER_REGISTRY.register('env_pool')
class PoolEnvManager:
'''
Overview:
Envpool now supports Atari, Classic Control, Toy Text, ViZDoom.
Here we list some commonly used env_ids as follows.
For more examples, you can refer to <https://envpool.readthedocs.io/en/latest/api/atari.html>.
- Atari: "Pong-v5", "SpaceInvaders-v5", "Qbert-v5"
- Classic Control: "CartPole-v0", "CartPole-v1", "Pendulum-v1"
'''
@classmethod
def default_config(cls) -> EasyDict:
return EasyDict(deepcopy(cls.config))
config = dict(
type='envpool',
# Sync mode: batch_size == env_num
# Async mode: batch_size < env_num
env_num=8,
batch_size=8,
)
def __init__(self, cfg: EasyDict) -> None:
self._cfg = cfg
self._env_num = cfg.env_num
self._batch_size = cfg.batch_size
self._ready_obs = {}
self._closed = True
self._seed = None
def launch(self) -> None:
assert self._closed, "Please first close the env manager"
if self._seed is None:
seed = 0
else:
seed = self._seed
self._envs = envpool.make(
task_id=self._cfg.env_id,
env_type="gym",
num_envs=self._env_num,
batch_size=self._batch_size,
seed=seed,
episodic_life=self._cfg.episodic_life,
reward_clip=self._cfg.reward_clip,
stack_num=self._cfg.stack_num,
gray_scale=self._cfg.gray_scale,
frame_skip=self._cfg.frame_skip
)
self._closed = False
self.reset()
def reset(self) -> None:
self._ready_obs = {}
self._envs.async_reset()
while True:
obs, _, _, info = self._envs.recv()
env_id = info['env_id']
obs = obs.astype(np.float32)
self._ready_obs = deep_merge_dicts({i: o for i, o in zip(env_id, obs)}, self._ready_obs)
if len(self._ready_obs) == self._env_num:
break
self._eval_episode_return = [0. for _ in range(self._env_num)]
def step(self, action: dict) -> Dict[int, namedtuple]:
env_id = np.array(list(action.keys()))
action = np.array(list(action.values()))
if len(action.shape) == 2:
action = action.squeeze(1)
self._envs.send(action, env_id)
obs, rew, done, info = self._envs.recv()
obs = obs.astype(np.float32)
rew = rew.astype(np.float32)
env_id = info['env_id']
timesteps = {}
self._ready_obs = {}
for i in range(len(env_id)):
d = bool(done[i])
r = to_ndarray([rew[i]])
self._eval_episode_return[env_id[i]] += r
timesteps[env_id[i]] = BaseEnvTimestep(obs[i], r, d, info={'env_id': i})
if d:
timesteps[env_id[i]].info['eval_episode_return'] = self._eval_episode_return[env_id[i]]
self._eval_episode_return[env_id[i]] = 0.
self._ready_obs[env_id[i]] = obs[i]
return timesteps
def close(self) -> None:
if self._closed:
return
# Envpool has no `close` API
self._closed = True
def seed(self, seed: int, dynamic_seed=False) -> None:
# The i-th environment seed in Envpool will be set with i+seed, so we don't do extra transformation here
self._seed = seed
logging.warning("envpool doesn't support dynamic_seed in different episode")
@property
def env_num(self) -> int:
return self._env_num
@property
def ready_obs(self) -> Dict[int, Any]:
return self._ready_obs