|
""" |
|
Overview: |
|
This code implements the Monte Carlo Tree Search (MCTS) algorithm with the integration of neural networks. |
|
The Node class represents a node in the Monte Carlo tree and implements the basic functionalities expected in a node. |
|
The MCTS class implements the specific search functionality and provides the optimal action through the ``get_next_action`` method. |
|
Compared to traditional MCTS, the introduction of value networks and policy networks brings several advantages. |
|
During the expansion of nodes, it is no longer necessary to explore every single child node, but instead, |
|
the child nodes are directly selected based on the prior probabilities provided by the neural network. |
|
This reduces the breadth of the search. When estimating the value of leaf nodes, there is no need for a rollout; |
|
instead, the value output by the neural network is used, which saves the depth of the search. |
|
""" |
|
|
|
import copy |
|
import math |
|
from typing import List, Tuple, Union, Callable, Type, Dict, Any |
|
|
|
import numpy as np |
|
import torch |
|
import torch.nn as nn |
|
from ding.envs import BaseEnv |
|
from easydict import EasyDict |
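

# A minimal, hypothetical sketch of the ``policy_forward_fn`` interface assumed
# by this module (not part of the original implementation): given the current
# environment, it returns a dict mapping each legal action to its prior
# probability, together with a scalar value estimate of the current state.
# A real implementation would query the policy and value networks; the uniform
# stub below only illustrates the expected signature.
def uniform_policy_forward_fn(env: BaseEnv) -> Tuple[Dict[int, float], float]:
    legal_actions = env.legal_actions
    # Uniform priors over the legal actions stand in for the policy network output.
    action_probs_dict = {a: 1.0 / len(legal_actions) for a in legal_actions}
    # A neutral value estimate stands in for the value network output.
    leaf_value = 0.0
    return action_probs_dict, leaf_value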
|
|
|
|
|
class Node(object):
    """
    Overview:
        A class for a node in a Monte Carlo tree. The properties of this class store basic information about the node,
        such as its parent node, child nodes, and the number of times the node has been visited.
        The methods of this class implement basic functionalities that a node should have, such as propagating the value back,
        checking if the node is the root node, and determining if it is a leaf node.
    """

    def __init__(self, parent: "Node" = None, prior_p: float = 1.0) -> None:
        """
        Overview:
            Initialize a Node object.
        Arguments:
            - parent (:obj:`Node`): The parent node of the current node.
            - prior_p (:obj:`Float`): The prior probability of selecting this node.
        """
        self._parent = parent
        # A map from action to the child node reached by taking that action.
        self._children = {}
        # The number of times this node has been visited during the search.
        self._visit_count = 0
        # The accumulated value backed up through this node; dividing it by
        # ``_visit_count`` gives the mean value.
        self._value_sum = 0
        self.prior_p = prior_p

    @property
    def value(self) -> float:
        """
        Overview:
            The value of the current node.
        Returns:
            - output (:obj:`Float`): Current mean value, used to compute the UCB score.
        """
        # An unvisited node has no value estimate yet, so return 0.
        if self._visit_count == 0:
            return 0
        return self._value_sum / self._visit_count

    def update(self, value: float) -> None:
        """
        Overview:
            Update the current node information, such as ``_visit_count`` and ``_value_sum``.
        Arguments:
            - value (:obj:`Float`): The value of the node.
        """
        self._visit_count += 1
        self._value_sum += value

    def update_recursive(self, leaf_value: float, battle_mode_in_simulation_env: str) -> None:
        """
        Overview:
            Update node information recursively.
            The same game state has opposite values in the eyes of two players playing against each other.
            The value of a node is evaluated from the perspective of the player corresponding to its parent node.
            In ``self_play_mode``, because the player to move changes at every step along the backpropagation path,
            the value must be negated once per step.
            In ``play_with_bot_mode``, since all nodes correspond to the same player, the value is not negated.
        Arguments:
            - leaf_value (:obj:`Float`): The value of the node.
            - battle_mode_in_simulation_env (:obj:`str`): The mode of MCTS, can be 'self_play_mode' or 'play_with_bot_mode'.
        """
        if battle_mode_in_simulation_env == 'self_play_mode':
            self.update(leaf_value)
            if self.is_root():
                return
            # Negate the value at every step up the tree, since the player to move alternates.
            self._parent.update_recursive(-leaf_value, battle_mode_in_simulation_env)
        elif battle_mode_in_simulation_env == 'play_with_bot_mode':
            self.update(leaf_value)
            if self.is_root():
                return
            # Propagate the value unchanged, since every node is evaluated from the same player's perspective.
            self._parent.update_recursive(leaf_value, battle_mode_in_simulation_env)
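
    # Illustration of the sign flip in ``self_play_mode``: calling
    # ``node.update_recursive(v, 'self_play_mode')`` applies ``+v`` at ``node``,
    # ``-v`` at its parent, ``+v`` at its grandparent, and so on up to the root,
    # because the player to move alternates at each level of the tree.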
|
|
|
    def is_leaf(self) -> bool:
        """
        Overview:
            Check whether the current node is a leaf node.
        Returns:
            - output (:obj:`Bool`): If ``self._children`` is empty, the node has not
              been expanded yet, which means it is a leaf node.
        """
        return self._children == {}

    def is_root(self) -> bool:
        """
        Overview:
            Check whether the current node is the root node.
        Returns:
            - output (:obj:`Bool`): If the node does not have a parent node,
              then it is the root node.
        """
        return self._parent is None

    @property
    def parent(self) -> "Node":
        """
        Overview:
            Get the parent node of the current node.
        Returns:
            - output (:obj:`Node`): The parent node of the current node.
        """
        return self._parent

    @property
    def children(self) -> Dict[int, "Node"]:
        """
        Overview:
            Get the dictionary of children of the current node.
        Returns:
            - output (:obj:`dict`): A dictionary mapping actions to the children of the current node.
        """
        return self._children

    @property
    def visit_count(self) -> int:
        """
        Overview:
            Get the number of times the current node has been visited.
        Returns:
            - output (:obj:`Int`): The number of times the current node has been visited.
        """
        return self._visit_count
|
|
|
|
|
class MCTS(object):
    """
    Overview:
        A class for Monte Carlo Tree Search (MCTS). The methods in this class implement the steps involved in MCTS,
        such as selection and expansion. Based on these, the ``_simulate`` method traverses from the root node to a leaf node.
        Finally, by repeatedly calling ``_simulate`` through ``get_next_action``, the optimal action is obtained.
    """

    def __init__(self, cfg: EasyDict, simulate_env: Type[BaseEnv]) -> None:
        """
        Overview:
            Initialize the MCTS process.
        Arguments:
            - cfg (:obj:`EasyDict`): A dictionary containing the configuration parameters for the MCTS process.
            - simulate_env (:obj:`BaseEnv`): The environment used to simulate game states during the search.
        """
        self._cfg = cfg
        # The maximum number of moves allowed in a game.
        self._max_moves = self._cfg.get('max_moves', 512)
        # The number of simulations to run for each move.
        self._num_simulations = self._cfg.get('num_simulations', 800)
        # UCB exploration constants; see ``_ucb_score`` for how they are used.
        self._pb_c_base = self._cfg.get('pb_c_base', 19652)
        self._pb_c_init = self._cfg.get('pb_c_init', 1.25)
        # Dirichlet noise parameters used to encourage exploration at the root.
        self._root_dirichlet_alpha = self._cfg.get('root_dirichlet_alpha', 0.3)
        self._root_noise_weight = self._cfg.get('root_noise_weight', 0.25)
        self.simulate_env = simulate_env
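
    # An example configuration with the default values made explicit
    # (hypothetical, for illustration only):
    #     cfg = EasyDict(dict(
    #         max_moves=512, num_simulations=800, pb_c_base=19652,
    #         pb_c_init=1.25, root_dirichlet_alpha=0.3, root_noise_weight=0.25,
    #     ))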
|
|
|
    def get_next_action(
            self,
            state_config_for_simulate_env_reset: Dict[str, Any],
            policy_forward_fn: Callable,
            temperature: float = 1.0,
            sample: bool = True
    ) -> Tuple[int, List[float]]:
        """
        Overview:
            Get the next action to take based on the current state of the game.
        Arguments:
            - state_config_for_simulate_env_reset (:obj:`Dict`): The state configuration used to reset the simulate env.
            - policy_forward_fn (:obj:`Function`): The Callable used to compute the action probabilities and state value.
            - temperature (:obj:`Float`): The exploration temperature applied to the visit counts.
            - sample (:obj:`Bool`): Whether to sample an action from the resulting probabilities or choose the most probable action.
        Returns:
            - action (:obj:`Int`): The selected action to take.
            - action_probs (:obj:`List`): The output probability of each action.
        """
        # Create a new root node for this search.
        root = Node()
        self.simulate_env.reset(
            start_player_index=state_config_for_simulate_env_reset.start_player_index,
            init_state=state_config_for_simulate_env_reset.init_state,
        )
        # Expand the root node with the priors given by the policy network.
        self._expand_leaf_node(root, self.simulate_env, policy_forward_fn)

        if sample:
            # Add Dirichlet noise to the root priors to encourage exploration.
            self._add_exploration_noise(root)

        for n in range(self._num_simulations):
            # Reset the simulate env to the current state before each simulation,
            # since ``_simulate`` modifies the env in place.
            self.simulate_env.reset(
                start_player_index=state_config_for_simulate_env_reset.start_player_index,
                init_state=state_config_for_simulate_env_reset.init_state,
            )
            self.simulate_env.battle_mode = self.simulate_env.battle_mode_in_simulation_env
            self.simulate_env.render_mode = None
            self._simulate(root, self.simulate_env, policy_forward_fn)

        # Collect the visit count of each action at the root; unexpanded actions count 0.
        action_visits = []
        for action in range(self.simulate_env.action_space.n):
            if action in root.children:
                action_visits.append((action, root.children[action].visit_count))
            else:
                action_visits.append((action, 0))

        actions, visits = zip(*action_visits)

        # Convert visit counts into a probability distribution, sharpened or
        # flattened by the exploration temperature.
        visits_t = torch.as_tensor(visits, dtype=torch.float32)
        visits_t = torch.pow(visits_t, 1 / temperature)
        action_probs = (visits_t / visits_t.sum()).numpy()

        if sample:
            action = np.random.choice(actions, p=action_probs)
        else:
            action = actions[np.argmax(action_probs)]

        return action, action_probs
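
    # Illustration of the temperature's effect: with root visit counts
    # (10, 30, 60), a temperature of 1.0 yields probabilities (0.1, 0.3, 0.6);
    # a temperature close to 0 concentrates the distribution on the most
    # visited action, while a large temperature flattens it towards uniform.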
|
|
|
    def _simulate(self, node: Node, simulate_env: Type[BaseEnv], policy_forward_fn: Callable) -> None:
        """
        Overview:
            Run a single playout from the root to a leaf, get a value at the leaf and propagate it back through its parents.
            The state is modified in place, so a deepcopy of the env must be provided.
        Arguments:
            - node (:obj:`Class Node`): Current node when performing the MCTS search.
            - simulate_env (:obj:`Class BaseGameEnv`): The simulate env.
            - policy_forward_fn (:obj:`Function`): The Callable used to compute the action probabilities and state value.
        """
        # Selection: descend the tree until a leaf node is reached.
        while not node.is_leaf():
            action, node = self._select_child(node, simulate_env)
            if action is None:
                break
            simulate_env.step(action)

        done, winner = simulate_env.get_done_winner()
        """
        In ``self_play_mode``, the leaf_value is calculated from the perspective of player ``simulate_env.current_player``.
        In ``play_with_bot_mode``, the leaf_value is calculated from the perspective of player 1.
        """
        if not done:
            # Expansion and evaluation: expand the leaf node and use the value
            # predicted by the value network instead of a rollout.
            leaf_value = self._expand_leaf_node(node, simulate_env, policy_forward_fn)
        else:
            if simulate_env.battle_mode_in_simulation_env == 'self_play_mode':
                # A draw is worth 0; otherwise the terminal state is scored from
                # the perspective of the current player.
                if winner == -1:
                    leaf_value = 0
                else:
                    leaf_value = 1 if simulate_env.current_player == winner else -1
            if simulate_env.battle_mode_in_simulation_env == 'play_with_bot_mode':
                # The terminal state is always scored from the perspective of player 1.
                if winner == -1:
                    leaf_value = 0
                elif winner == 1:
                    leaf_value = 1
                elif winner == 2:
                    leaf_value = -1

        # Backpropagation: update the values and visit counts along the path.
        if simulate_env.battle_mode_in_simulation_env == 'play_with_bot_mode':
            node.update_recursive(leaf_value, simulate_env.battle_mode_in_simulation_env)
        elif simulate_env.battle_mode_in_simulation_env == 'self_play_mode':
            # The leaf value is computed from the perspective of the player to
            # move at the leaf, while a node is evaluated from its parent's
            # perspective, so the value is negated once before backpropagation.
            node.update_recursive(-leaf_value, simulate_env.battle_mode_in_simulation_env)

    def _select_child(self, node: Node, simulate_env: Type[BaseEnv]) -> Tuple[Union[int, float], Node]:
        """
        Overview:
            Select the child with the highest UCB score.
        Arguments:
            - node (:obj:`Class Node`): Current node.
        Returns:
            - action (:obj:`Int`): The action with the highest UCB score.
            - child (:obj:`Node`): The child node reached by executing the action with the highest UCB score.
        """
        action = None
        child = None
        best_score = -float('inf')
        for action_tmp, child_tmp in node.children.items():
            """
            Check if the action is present in the list of legal actions for the current environment.
            This check is relevant only when the agent is training in ``play_with_bot_mode`` and the bot's actions involve strong randomness.
            """
            if action_tmp in simulate_env.legal_actions:
                score = self._ucb_score(node, child_tmp)
                if score > best_score:
                    best_score = score
                    action = action_tmp
                    child = child_tmp
        # If no child was selected (e.g. none of the expanded actions is currently
        # legal), fall back to the current node so the caller can stop the descent.
        if child is None:
            child = node
        return action, child
|
|
|
    def _expand_leaf_node(self, node: Node, simulate_env: Type[BaseEnv], policy_forward_fn: Callable) -> float:
        """
        Overview:
            Expand the node with the ``policy_forward_fn``.
        Arguments:
            - node (:obj:`Class Node`): Current node when performing the MCTS search.
            - simulate_env (:obj:`Class BaseGameEnv`): The simulate env.
            - policy_forward_fn (:obj:`Function`): The Callable used to compute the action probabilities and state value.
        Returns:
            - leaf_value (:obj:`Float`): The leaf node's value.
        """
        # Query the policy/value networks for the priors and the state value.
        action_probs_dict, leaf_value = policy_forward_fn(simulate_env)
        # Create one child per legal action, initialized with its prior
        # probability; priors of illegal actions are simply dropped.
        for action, prior_p in action_probs_dict.items():
            if action in simulate_env.legal_actions:
                node.children[action] = Node(parent=node, prior_p=prior_p)
        return leaf_value
|
|
|
    def _ucb_score(self, parent: Node, child: Node) -> float:
        r"""
        Overview:
            Compute the UCB score. The score of a node is based on its value, plus an exploration bonus based on its prior:

            UCB(s, a) = Q(s, a) + P(s, a) \cdot \frac{\sqrt{N(\text{parent})}}{1 + N(\text{child})} \cdot \left(c_1 + \log\left(\frac{N(\text{parent}) + c_2 + 1}{c_2}\right)\right)

            - Q(s, a): The value of the child node.
            - P(s, a): The prior probability of the child node.
            - N(parent): The visit count of the parent node.
            - N(child): The visit count of the child node.
            - c_1: A constant given by ``self._pb_c_init`` that controls the base influence of the prior P(s, a) relative to the value Q(s, a).
            - c_2: A constant given by ``self._pb_c_base`` that controls how quickly the influence of the prior grows as the parent node is visited more often.

            For more details, please refer to this paper: http://gauss.ececs.uc.edu/Workshops/isaim2010/papers/rosin.pdf
        Arguments:
            - parent (:obj:`Class Node`): Current node.
            - child (:obj:`Class Node`): Current node's child.
        Returns:
            - score (:obj:`Float`): The UCB score.
        """
        pb_c = math.log((parent.visit_count + self._pb_c_base + 1) / self._pb_c_base) + self._pb_c_init
        pb_c *= math.sqrt(parent.visit_count) / (child.visit_count + 1)
        prior_score = pb_c * child.prior_p
        value_score = child.value
        return prior_score + value_score
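
    # A worked example (values rounded): with ``pb_c_base=19652``, ``pb_c_init=1.25``,
    # ``parent.visit_count=50``, ``child.visit_count=10`` and ``child.prior_p=0.2``:
    #     pb_c = log((50 + 19652 + 1) / 19652) + 1.25 ≈ 1.2526
    #     pb_c *= sqrt(50) / (10 + 1) ≈ 1.2526 * 0.6428 ≈ 0.8052
    #     prior_score = 0.8052 * 0.2 ≈ 0.1610
    # so the exploration bonus added to ``child.value`` is about 0.161.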
|
|
|
    def _add_exploration_noise(self, node: Node) -> None:
        """
        Overview:
            Add Dirichlet exploration noise to the priors of the node's children.
        Arguments:
            - node (:obj:`Class Node`): Current node, typically the root.
        """
        # Sample one Dirichlet noise component per child action.
        actions = list(node.children.keys())
        alpha = [self._root_dirichlet_alpha] * len(actions)
        noise = np.random.dirichlet(alpha)
        # Mix the noise into the priors with weight ``self._root_noise_weight``.
        frac = self._root_noise_weight
        for a, n in zip(actions, noise):
            node.children[a].prior_p = node.children[a].prior_p * (1 - frac) + n * frac
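

# A minimal, self-contained sketch (not part of the original implementation) of
# how the root priors are perturbed, using the same Dirichlet mixing rule as
# ``_add_exploration_noise``. The priors and parameter values are illustrative.
if __name__ == '__main__':
    priors = np.array([0.5, 0.3, 0.2])
    root_dirichlet_alpha, root_noise_weight = 0.3, 0.25
    noise = np.random.dirichlet([root_dirichlet_alpha] * len(priors))
    mixed = priors * (1 - root_noise_weight) + noise * root_noise_weight
    # ``mixed`` still sums to 1 but redistributes some probability mass towards
    # random actions, which keeps the root exploring during data collection.
    print(mixed, mixed.sum())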