import os
from dataclasses import dataclass
from typing import Any

import numpy as np

try:
    from graphviz import Digraph
except ImportError:
    # graphviz is only required by ``plot_simulation_graph``; the remaining
    # helpers in this module stay importable without it.
    Digraph = None


def generate_random_actions_discrete(
        num_actions: int, action_space_size: int, num_of_sampled_actions: int, reshape=False
):
    """
    Overview:
        Generate a list of uniformly random discrete actions.
    Arguments:
        - num_actions (:obj:`int`): The number of entries in the returned list.
        - action_space_size (:obj:`int`): Exclusive upper bound of the sampled action ids \
            (ids are drawn from ``[0, action_space_size)``).
        - num_of_sampled_actions (:obj:`int`): How many action ids each entry contains.
        - reshape (:obj:`bool`): If True and ``num_of_sampled_actions > 1``, every entry is \
            reshaped to ``(num_of_sampled_actions, 1)``.
    Returns:
        - actions (:obj:`list`): ``num_actions`` entries. Each entry is a scalar when \
            ``num_of_sampled_actions == 1``, otherwise a numpy array.
    """
    # One ``randint`` call per entry, in order, so seeded runs stay reproducible.
    actions = [
        np.random.randint(0, action_space_size, num_of_sampled_actions).reshape(-1)
        for _ in range(num_actions)
    ]

    # A single sampled action per entry is flattened to a plain list of numbers.
    if num_of_sampled_actions == 1:
        return [action[0] for action in actions]

    if reshape and num_of_sampled_actions > 1:
        return [action.reshape(num_of_sampled_actions, 1) for action in actions]

    return actions


@dataclass
class BufferedData:
    # Plain record bundling a payload with an index string and a metadata dict.
    data: Any  # the stored payload (any type)
    index: str  # string identifier of this record
    meta: dict  # metadata attached to the payload


def get_augmented_data(board_size, play_data):
    """
    Overview:
        Augment board-game samples with the 4 rotations of every sample and the horizontal
        flip of each rotation, i.e. 8 output samples per input sample.
    Arguments:
        - board_size (:obj:`int`): Side length of the square board; ``mcts_prob`` must hold \
            ``board_size * board_size`` entries.
        - play_data (:obj:`list`): List of dicts with the keys ``'state'`` (an iterable of 2D \
            planes), ``'mcts_prob'`` (a flat array-like) and ``'winner'``.
    Returns:
        - extend_data (:obj:`list`): List of dicts with the same three keys. ``'state'`` and \
            ``'mcts_prob'`` are numpy arrays, ``'winner'`` is passed through unchanged.
    """
    extend_data = []
    for data in play_data:
        state = data['state']
        mcts_prob = data['mcts_prob']
        winner = data['winner']

        # The probability grid is flipped vertically before the transform and flipped back
        # afterwards. NOTE(review): this presumably mirrors the row convention of the board
        # planes - confirm against the environment that produces ``state``.
        # The flipped grid does not depend on the rotation index, so build it once.
        flipped_prob = np.flipud(np.asarray(mcts_prob).reshape(board_size, board_size))

        for i in range(1, 5):
            # rotate counterclockwise by i * 90 degrees (i == 4 is the identity)
            equi_state = np.array([np.rot90(s, i) for s in state])
            equi_mcts_prob = np.rot90(flipped_prob, i)
            extend_data.append(
                {
                    'state': equi_state,
                    'mcts_prob': np.flipud(equi_mcts_prob).flatten(),
                    'winner': winner
                }
            )

            # flip the rotated sample horizontally
            equi_state = np.array([np.fliplr(s) for s in equi_state])
            equi_mcts_prob = np.fliplr(equi_mcts_prob)
            extend_data.append(
                {
                    'state': equi_state,
                    'mcts_prob': np.flipud(equi_mcts_prob).flatten(),
                    'winner': winner
                }
            )
    return extend_data


def prepare_observation(observation_list, model_type='conv'):
    """
    Overview:
        Prepare the observations to satisfy the input format of model.
        if model_type='conv':
            [B, S, W, H, C] -> [B, S x C, W, H]
            where B is batch size, S is stack num, W is width, H is height, and C is the number of channels.
            A 3-dimensional input [B, S, O] (vector obs) becomes [B, S, O, 1].
        if model_type='mlp':
            [B, S, O] -> [B, S x O]
            where B is batch size, S is stack num, O is obs shape.
    Arguments:
        - observation_list (:obj:`List`): list of observations.
        - model_type (:obj:`str`): type of the model. (default is 'conv')
    Returns:
        - observation_array (:obj:`np.ndarray`): the reshaped observations.
    """
    assert model_type in ['conv', 'mlp']
    observation_array = np.array(observation_list)

    if model_type == 'mlp':
        # [B, S, O] -> [B, S*O]
        return observation_array.reshape(observation_array.shape[0], -1)

    # model_type == 'conv'
    if observation_array.ndim == 3:
        # vector obs input, e.g. classical control and box2d environments:
        # add a trailing axis of size one, [B, S, O] -> [B, S, O, 1]
        observation_array = observation_array[..., np.newaxis]
    elif observation_array.ndim == 5:
        # image obs input, e.g. atari environments:
        # move channels in front of the spatial axes, [B, S, W, H, C] -> [B, S, C, W, H]
        observation_array = np.transpose(observation_array, (0, 1, 4, 2, 3))

    # Merge every axis between the batch axis and the last two axes:
    # [B, S, C, W, H] -> [B, S*C, W, H]
    shape = observation_array.shape
    return observation_array.reshape((shape[0], -1, shape[-2], shape[-1]))


def obtain_tree_topology(root, to_play=-1):
    """
    Overview:
        Walk the search tree below ``root`` depth-first and collect its nodes and edges.
        Only children whose ``expanded`` attribute is truthy are followed. As a side effect,
        each followed child gets ``parent_simulation_index`` set to its parent's index.
    Arguments:
        - root: The root node. Nodes must provide ``simulation_index``, ``visit_count``, \
            ``prior``, ``value``, ``legal_actions`` and ``get_child(action)``.
        - to_play (:obj:`int`): Unused; kept for interface compatibility.
    Returns:
        - edge_topology_list (:obj:`list`): Dicts with the keys ``'parent_id'`` and ``'child_id'``.
        - node_id_list (:obj:`list`): ``simulation_index`` of every visited node, in visit order.
        - node_topology_list (:obj:`list`): Dicts with the keys ``'node_id'``, ``'visit_count'``, \
            ``'policy_prior'`` and ``'value'``.
    """
    edge_topology_list = []
    node_topology_list = []
    node_id_list = []

    node_stack = [root]
    while node_stack:
        node = node_stack.pop()
        node_topology_list.append(
            {
                'node_id': node.simulation_index,
                'visit_count': node.visit_count,
                'policy_prior': node.prior,
                'value': node.value,
            }
        )
        node_id_list.append(node.simulation_index)

        for a in node.legal_actions:
            child = node.get_child(a)
            if not child.expanded:
                continue
            child.parent_simulation_index = node.simulation_index
            edge_topology_list.append(
                {
                    'parent_id': node.simulation_index,
                    'child_id': child.simulation_index,
                }
            )
            node_stack.append(child)

    return edge_topology_list, node_id_list, node_topology_list


def plot_simulation_graph(env_root, current_step, graph_directory=None):
    """
    Overview:
        Render the search tree below ``env_root`` to a PNG image with graphviz.
    Arguments:
        - env_root: Root node of the tree, see ``obtain_tree_topology``.
        - current_step: Step identifier embedded in the output file name.
        - graph_directory (:obj:`str`): Output directory, created if missing. \
            Defaults to ``'./data_visualize/'``.
    Raises:
        - ImportError: If the ``graphviz`` package is not installed.
    """
    if Digraph is None:
        raise ImportError("plot_simulation_graph requires the 'graphviz' package to be installed")

    edge_topology_list, _, node_topology_list = obtain_tree_topology(env_root)

    dot = Digraph(comment='this is direction')
    for node_topology in node_topology_list:
        node_name = str(node_topology['node_id'])
        label = (
            f"node_id: {node_topology['node_id']}, \n "
            f"visit_count: {node_topology['visit_count']}, \n "
            f"policy_prior: {round(node_topology['policy_prior'], 4)}, \n "
            f"value: {round(node_topology['value'], 4)}"
        )
        dot.node(node_name, label=label)

    for edge_topology in edge_topology_list:
        parent_id = str(edge_topology['parent_id'])
        child_id = str(edge_topology['child_id'])
        dot.edge(parent_id, child_id, label=parent_id + '-' + child_id)

    if graph_directory is None:
        graph_directory = './data_visualize/'
    # ``exist_ok`` avoids the check-then-create race of a separate ``os.path.exists`` test.
    os.makedirs(graph_directory, exist_ok=True)
    # Join instead of concatenating so a directory given without a trailing separator
    # still receives the file inside it.
    graph_path = os.path.join(graph_directory, 'simulation_visualize_' + str(current_step) + 'step.gv')
    dot.format = 'png'
    dot.render(graph_path, view=False)