# lzero/worker/alphazero_collector.py
from collections import namedtuple
from typing import Optional, Any, List, Dict
import numpy as np
from ding.envs import BaseEnvManager
from ding.torch_utils import to_ndarray
from ding.utils import build_logger, EasyTimer, SERIAL_COLLECTOR_REGISTRY, one_time_warning, get_rank, get_world_size, \
broadcast_object_list, allreduce_data
from ding.worker.collector.base_serial_collector import ISerialCollector, CachePool, TrajBuffer, INF, \
to_tensor_transitions
@SERIAL_COLLECTOR_REGISTRY.register('episode_alphazero')
class AlphaZeroCollector(ISerialCollector):
"""
Overview:
AlphaZero collector (n_episode).
Interfaces:
__init__, reset, reset_env, reset_policy, collect, close
Property:
envstep
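    Examples:
        A minimal usage sketch (``env_manager`` and ``policy`` are assumed to be an already-launched
        BaseEnvManager and a DI-engine style policy prepared elsewhere; they are not defined in this file):

        >>> collector = AlphaZeroCollector(env=env_manager, policy=policy.collect_mode, exp_name='az_demo')
        >>> new_data = collector.collect(n_episode=8, train_iter=0, policy_kwargs={'temperature': 1.0})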
"""
    # To be compatible with ISerialCollector
config = dict()
def __init__(
self,
collect_print_freq: int = 100,
env: BaseEnvManager = None,
policy: namedtuple = None,
tb_logger: 'SummaryWriter' = None, # noqa
exp_name: Optional[str] = 'default_experiment',
instance_name: Optional[str] = 'collector',
env_config=None,
) -> None:
"""
Overview:
Init the AlphaZero collector according to input arguments.
Arguments:
            - collect_print_freq (:obj:`int`): The frequency, in training iterations, at which collection \
                statistics are printed to the log.
            - env (:obj:`BaseEnvManager`): The env for the collection, the BaseEnvManager object or \
                its derivatives are supported.
            - policy (:obj:`Policy`): The policy to be used for collecting data.
            - tb_logger (:obj:`SummaryWriter`): Tensorboard logger. If provided, it is reused and only a \
                text logger is created; otherwise both a text logger and a new tensorboard logger are built.
            - exp_name (:obj:`str`): Experiment name, which is used to indicate the output directory.
            - instance_name (:obj:`Optional[str]`): Name of this collector instance.
            - env_config: Config of the environment.
"""
self._exp_name = exp_name
self._instance_name = instance_name
self._collect_print_freq = collect_print_freq
self._timer = EasyTimer()
self._end_flag = False
self._env_config = env_config
self._rank = get_rank()
self._world_size = get_world_size()
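        # Only rank 0 writes tensorboard summaries; other ranks build a plain text logger without tensorboard.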
if self._rank == 0:
if tb_logger is not None:
self._logger, _ = build_logger(
path='./{}/log/{}'.format(self._exp_name, self._instance_name),
name=self._instance_name,
need_tb=False
)
self._tb_logger = tb_logger
else:
self._logger, self._tb_logger = build_logger(
path='./{}/log/{}'.format(self._exp_name, self._instance_name), name=self._instance_name
)
else:
self._logger, _ = build_logger(
path='./{}/log/{}'.format(self._exp_name, self._instance_name), name=self._instance_name, need_tb=False
)
self._tb_logger = None
self.reset(policy, env)
def reset_env(self, _env: Optional[BaseEnvManager] = None) -> None:
"""
Overview:
Reset the environment.
If _env is None, reset the old environment.
If _env is not None, replace the old environment in the collector with the new passed \
in environment and launch.
Arguments:
            - env (:obj:`Optional[BaseEnvManager]`): An instance of a subclass of the vectorized \
                env manager (BaseEnvManager).
"""
if _env is not None:
self._env = _env
self._env.launch()
self._env_num = self._env.env_num
else:
self._env.reset()
def reset_policy(self, _policy: Optional[namedtuple] = None) -> None:
"""
Overview:
Reset the policy.
If _policy is None, reset the old policy.
If _policy is not None, replace the old policy in the collector with the new passed in policy.
Arguments:
- policy (:obj:`Optional[namedtuple]`): the api namedtuple of collect_mode policy
"""
assert hasattr(self, '_env'), "please set env first"
if _policy is not None:
self._policy = _policy
self._default_n_episode = _policy.get_attribute('cfg').get('n_episode', None)
self._on_policy = _policy.get_attribute('cfg').on_policy
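            # Episode-based collection does not truncate trajectories: the per-env trajectory buffer
            # length is unbounded (INF) and is only flushed when an episode terminates.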
self._traj_len = INF
self._logger.debug(
'Set default n_episode mode(n_episode({}), env_num({}), traj_len({}))'.format(
self._default_n_episode, self._env_num, self._traj_len
)
)
self._policy.reset()
def reset(self, _policy: Optional[namedtuple] = None, _env: Optional[BaseEnvManager] = None) -> None:
"""
Overview:
Reset the environment and policy.
If _env is None, reset the old environment.
If _env is not None, replace the old environment in the collector with the new passed \
in environment and launch.
If _policy is None, reset the old policy.
If _policy is not None, replace the old policy in the collector with the new passed in policy.
Arguments:
- policy (:obj:`Optional[namedtuple]`): the api namedtuple of collect_mode policy
            - env (:obj:`Optional[BaseEnvManager]`): An instance of a subclass of the vectorized \
                env manager (BaseEnvManager).
"""
if _env is not None:
self.reset_env(_env)
if _policy is not None:
self.reset_policy(_policy)
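        # Cache pools keep the latest obs and policy outputs of each env so that complete
        # transitions can be assembled after the env step returns.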
self._obs_pool = CachePool('obs', self._env_num, deepcopy=False)
self._policy_output_pool = CachePool('policy_output', self._env_num)
# _traj_buffer is {env_id: TrajBuffer}, is used to store traj_len pieces of transitions
self._traj_buffer = {env_id: TrajBuffer(maxlen=self._traj_len) for env_id in range(self._env_num)}
self._env_info = {env_id: {'time': 0., 'step': 0} for env_id in range(self._env_num)}
self._episode_info = []
self._total_envstep_count = 0
self._total_episode_count = 0
self._total_duration = 0
self._last_train_iter = 0
self._end_flag = False
def _reset_stat(self, env_id: int) -> None:
"""
Overview:
            Reset the collector's state for a given env_id, including its traj_buffer, obs_pool, \
            policy_output_pool and env_info. Refer to base_serial_collector for more details.
        Arguments:
            - env_id (:obj:`int`): the id of the environment whose collector state should be reset.
"""
self._traj_buffer[env_id].clear()
self._obs_pool.reset(env_id)
self._policy_output_pool.reset(env_id)
self._env_info[env_id] = {'time': 0., 'step': 0}
def collect(self,
n_episode: Optional[int] = None,
train_iter: int = 0,
policy_kwargs: Optional[dict] = None) -> List[Any]:
"""
Overview:
            Collect `n_episode` episodes of data with the given policy_kwargs, using a policy that has \
            been trained for `train_iter` iterations.
        Arguments:
            - n_episode (:obj:`int`): the number of episodes to collect.
            - train_iter (:obj:`int`): the number of training iterations the policy has undergone.
            - policy_kwargs (:obj:`dict`): the keyword args for policy forward.
Returns:
- return_data (:obj:`List`): A list containing collected episodes.
"""
if n_episode is None:
if self._default_n_episode is None:
raise RuntimeError("Please specify collect n_episode")
else:
n_episode = self._default_n_episode
        assert n_episode >= self._env_num, "Please make sure n_episode({}) >= env_num({})".format(
            n_episode, self._env_num
        )
        if policy_kwargs is None:
            policy_kwargs = {}
        # The MCTS visit-count temperature controls exploration; 1.0 is an assumed fallback used only
        # when the caller does not pass an explicit temperature.
        temperature = policy_kwargs.get('temperature', 1.0)
collected_episode = 0
collected_step = 0
return_data = []
ready_env_id = set()
remain_episode = n_episode
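        # Keep collecting until `n_episode` episodes have finished; each iteration performs one
        # batched policy forward over the ready envs followed by one vectorized env step.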
while True:
with self._timer:
# Get current env obs.
obs = self._env.ready_obs
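                # Only activate as many newly-ready envs as episodes remain, so that no more than
                # `n_episode` episodes are collected in total.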
new_available_env_id = set(obs.keys()).difference(ready_env_id)
ready_env_id = ready_env_id.union(set(list(new_available_env_id)[:remain_episode]))
remain_episode -= min(len(new_available_env_id), remain_episode)
obs_ = {env_id: obs[env_id] for env_id in ready_env_id}
# Policy forward.
self._obs_pool.update(obs_)
# ==============================================================
# policy forward
# ==============================================================
policy_output = self._policy.forward(obs_, temperature)
self._policy_output_pool.update(policy_output)
# Interact with env.
actions = {env_id: output['action'] for env_id, output in policy_output.items()}
actions = to_ndarray(actions)
# ==============================================================
# Interact with env.
# ==============================================================
timesteps = self._env.step(actions)
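            # Amortize the duration of the batched env.step() call evenly over the returned timesteps.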
interaction_duration = self._timer.value / len(timesteps)
for env_id, timestep in timesteps.items():
with self._timer:
if timestep.info.get('abnormal', False):
                        # If an env returns an abnormal timestep, reset that env and all of its related state.
                        # Assume there are no special reset parameters, so simply reset this env.
self._env.reset({env_id: None})
self._policy.reset([env_id])
self._reset_stat(env_id)
                        self._logger.info('Env {} returned an abnormal step, its info is {}'.format(env_id, timestep.info))
continue
transition = self._policy.process_transition(
self._obs_pool[env_id], self._policy_output_pool[env_id], timestep
)
transition['collect_iter'] = train_iter
self._traj_buffer[env_id].append(transition)
self._env_info[env_id]['step'] += 1
collected_step += 1
                    # prepare data: when the episode terminates, pack the whole trajectory into one sample
if timestep.done:
transitions = to_tensor_transitions(self._traj_buffer[env_id])
# reward_shaping
transitions = self.reward_shaping(transitions, timestep.info['eval_episode_return'])
return_data.append(transitions)
self._traj_buffer[env_id].clear()
self._env_info[env_id]['time'] += self._timer.value + interaction_duration
if timestep.done:
self._total_episode_count += 1
# the eval_episode_return is calculated from Player 1's perspective
reward = timestep.info['eval_episode_return']
info = {
                        'reward': reward,  # the episode return from Player 1's perspective
'time': self._env_info[env_id]['time'],
'step': self._env_info[env_id]['step'],
}
collected_episode += 1
self._episode_info.append(info)
self._policy.reset([env_id])
self._reset_stat(env_id)
ready_env_id.remove(env_id)
if collected_episode >= n_episode:
break
collected_duration = sum([d['time'] for d in self._episode_info])
        # Reduce the collected statistics across all ranks when DDP is enabled.
if self._world_size > 1:
collected_step = allreduce_data(collected_step, 'sum')
collected_episode = allreduce_data(collected_episode, 'sum')
collected_duration = allreduce_data(collected_duration, 'sum')
self._total_envstep_count += collected_step
self._total_episode_count += collected_episode
self._total_duration += collected_duration
# log
self._output_log(train_iter)
return return_data
@property
def envstep(self) -> int:
"""
Overview:
            Get the total environment step count.
        Returns:
            - envstep (:obj:`int`): the total envstep count.
"""
return self._total_envstep_count
def close(self) -> None:
"""
Overview:
Close the collector. If end_flag is False, close the environment, flush the tb_logger\
and close the tb_logger.
"""
if self._end_flag:
return
self._end_flag = True
self._env.close()
if self._tb_logger:
self._tb_logger.flush()
self._tb_logger.close()
def __del__(self) -> None:
"""
Overview:
            Execute the close command and close the collector. __del__ is automatically called \
            when the collector instance is destroyed.
"""
self.close()
def _output_log(self, train_iter: int) -> None:
"""
Overview:
Print the output log information. You can refer to Docs/Best Practice/How to understand\
training generated folders/Serial mode/log/collector for more details.
Arguments:
- train_iter (:obj:`int`): the number of training iteration.
"""
if self._rank != 0:
return
if (train_iter - self._last_train_iter) >= self._collect_print_freq and len(self._episode_info) > 0:
self._last_train_iter = train_iter
episode_count = len(self._episode_info)
envstep_count = sum([d['step'] for d in self._episode_info])
duration = sum([d['time'] for d in self._episode_info])
episode_reward = [d['reward'] for d in self._episode_info]
self._total_duration += duration
info = {
'episode_count': episode_count,
'envstep_count': envstep_count,
'avg_envstep_per_episode': envstep_count / episode_count,
'avg_envstep_per_sec': envstep_count / duration,
'avg_episode_per_sec': episode_count / duration,
'collect_time': duration,
'reward_mean': np.mean(episode_reward),
'reward_std': np.std(episode_reward),
'reward_max': np.max(episode_reward),
'reward_min': np.min(episode_reward),
'total_envstep_count': self._total_envstep_count,
'total_episode_count': self._total_episode_count,
'total_duration': self._total_duration,
}
self._episode_info.clear()
self._logger.info("collect end:\n{}".format('\n'.join(['{}: {}'.format(k, v) for k, v in info.items()])))
for k, v in info.items():
if k in ['each_reward']:
continue
self._tb_logger.add_scalar('{}_iter/'.format(self._instance_name) + k, v, train_iter)
if k in ['total_envstep_count']:
continue
self._tb_logger.add_scalar('{}_step/'.format(self._instance_name) + k, v, self._total_envstep_count)
    def reward_shaping(self, transitions: List[dict], eval_episode_return: float) -> List[dict]:
        """
        Overview:
            Shape the reward of each transition in an episode according to the player it belongs to.
        Arguments:
            - transitions (:obj:`List[dict]`): the transitions of one episode.
            - eval_episode_return (:obj:`float`): the episode return evaluated from Player 1's perspective.
        Returns:
            - transitions (:obj:`List[dict]`): the transitions with shaped rewards.
        """
reward = transitions[-1]['reward']
to_play = transitions[-1]['obs']['to_play']
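        # In self_play_mode a transition is credited +reward if its `to_play` matches the `to_play` of
        # the final transition and -reward otherwise; in play_with_bot_mode (to_play == -1) every
        # transition receives the episode return evaluated from Player 1's perspective.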
for t in transitions:
if t['obs']['to_play'] == -1:
# play_with_bot_mode
# the eval_episode_return is calculated from Player 1's perspective
t['reward'] = eval_episode_return
else:
# self_play_mode
if t['obs']['to_play'] == to_play:
t['reward'] = int(reward)
else:
t['reward'] = int(-reward)
return transitions