|
import sys |
|
sys.path.append("/Users/puyuan/code/LightZero/") |
|
from functools import partial |
|
import torch |
|
from ding.config import compile_config |
|
from ding.envs import create_env_manager |
|
from ding.envs import get_vec_env_setting |
|
from ding.policy import create_policy |
|
from ding.utils import set_pkg_seed |
|
from zoo.board_games.gomoku.config.gomoku_alphazero_bot_mode_config import main_config, create_config |
|
import numpy as np |
|
|
|
|
|
class Agent:
    """Thin wrapper around a LightZero AlphaZero policy for Gomoku bot play.

    Builds the policy (and its vectorized evaluation environments) from the
    zoo's ``gomoku_alphazero_bot_mode`` configs and optionally restores a
    checkpoint into the policy.
    """

    def __init__(self, seed=0, model_path=None):
        """Compile the config, build env managers, and create the policy.

        Args:
            seed (int): random seed applied to envs, torch/numpy (via
                ``set_pkg_seed``) and the compiled config.
            model_path (str | None): optional checkpoint path loaded into the
                policy's ``learn_mode``; ``None`` keeps the fresh weights.
                (Previously this was a hard-coded local, so checkpoints could
                never be loaded; the default preserves the old behavior.)
        """
        # Force a single, renderable evaluation setup: python (non-ctree)
        # MCTS, one evaluator env, frames saved under ./video.
        main_config.env.agent_vs_human = False
        main_config.env.render_mode = 'image_savefile_mode'
        main_config.env.replay_path = './video'
        create_config.env_manager.type = 'base'
        main_config.env.alphazero_mcts_ctree = False
        main_config.policy.mcts_ctree = False
        main_config.env.evaluator_env_num = 1
        main_config.env.n_evaluator_episode = 1

        cfg, create_cfg = main_config, create_config
        cfg.policy.device = 'cuda' if cfg.policy.cuda and torch.cuda.is_available() else 'cpu'

        cfg = compile_config(cfg, seed=seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True)

        # Build and seed both env managers to mirror the standard LightZero
        # entry pipeline; only the evaluator envs are used by this agent.
        env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
        collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg])
        evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg])
        collector_env.seed(cfg.seed)
        evaluator_env.seed(cfg.seed, dynamic_seed=False)
        set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)

        self.policy = create_policy(cfg.policy, model=None, enable_field=['learn', 'collect', 'eval'])

        if model_path is not None:
            self.policy.learn_mode.load_state_dict(torch.load(model_path, map_location=cfg.policy.device))

    def compute_action(self, obs):
        """Return the policy's action for a single observation.

        The eval-mode policy expects a dict keyed by env id, so the single
        observation is wrapped under id ``'0'`` and the matching action is
        unwrapped from the output.
        """
        policy_output = self.policy.eval_mode.forward({'0': obs})
        return policy_output['0']['action']
|
|
|
|
|
if __name__ == '__main__':
    from easydict import EasyDict
    from zoo.board_games.gomoku.envs.gomoku_env import GomokuEnv

    cfg = EasyDict(
        prob_random_agent=0,
        board_size=15,
        battle_mode='self_play_mode',
        channel_last=False,
        scale=False,
        agent_vs_human=False,
        bot_action_type='v1',
        prob_random_action_in_bot=0.,
        check_action_to_connect4_in_bot_v0=False,
        render_mode='state_realtime_mode',
        replay_path=None,
        screen_scaling=9,
        alphazero_mcts_ctree=False,
    )
    env = GomokuEnv(cfg)
    obs = env.reset()
    agent = Agent()

    # Alternate moves: the env's built-in random bot moves first, then the
    # AlphaZero agent replies, until either move ends the episode.
    while True:
        observation, reward, done, info = env.step(env.random_action())
        if done:
            break

        agent_action = agent.compute_action(observation)
        _, _, done, _ = env.step(agent_action)

        print('orig bot action: {}'.format(agent_action))
        # Decode the flat action index into (row, col) using the configured
        # board size instead of a hard-coded 15.
        agent_action = {'i': int(agent_action // cfg.board_size), 'j': int(agent_action % cfg.board_size)}
        print('bot action: {}'.format(agent_action))

        # Without this check the next iteration would step a finished env.
        if done:
            break