zjowowen's picture
init space
079c32c
raw
history blame
No virus
6.88 kB
import os
from dataclasses import dataclass
from typing import Any
import numpy as np
from graphviz import Digraph
def generate_random_actions_discrete(num_actions: int, action_space_size: int, num_of_sampled_actions: int,
                                     reshape=False):
    """
    Overview:
        Draw ``num_actions`` independent groups of random discrete action ids using ``np.random.randint``.
    Arguments:
        - num_actions (:obj:`int`): Number of entries in the returned list.
        - action_space_size (:obj:`int`): Exclusive upper bound of the ids; every id lies in \
            ``[0, action_space_size)``.
        - num_of_sampled_actions (:obj:`int`): Number of ids drawn for each entry.
        - reshape (:obj:`bool`): When True and ``num_of_sampled_actions > 1``, each entry is returned with \
            shape ``(num_of_sampled_actions, 1)`` instead of ``(num_of_sampled_actions,)``.
    Returns:
        - actions (:obj:`list`): ``num_actions`` entries. Each entry is a NumPy integer scalar when \
            ``num_of_sampled_actions == 1``, otherwise a NumPy integer array.
    """
    sampled = []
    for _ in range(num_actions):
        draw = np.random.randint(0, action_space_size, num_of_sampled_actions)
        sampled.append(draw.reshape(-1))

    # A single sampled action per entry collapses to a bare scalar; ``reshape`` never applies here.
    if num_of_sampled_actions == 1:
        return [draw[0] for draw in sampled]

    # Optionally turn every 1-D entry into a column vector.
    if reshape and num_of_sampled_actions > 1:
        return [draw.reshape(num_of_sampled_actions, 1) for draw in sampled]

    return sampled
@dataclass
class BufferedData:
    """
    Overview:
        Simple record that bundles a payload with an index string and a metadata dictionary.
    """
    # The stored payload; no type restriction is imposed.
    data: Any
    # String key for this record (presumably its identifier inside a buffer -- confirm against callers).
    index: str
    # Free-form metadata attached to the payload.
    meta: dict
def get_augmented_data(board_size, play_data):
    """
    Overview:
        Enlarge a set of board-game samples with the rotations and mirror images of every sample.
        Each input sample produces eight output samples: for 1, 2, 3 and 4 counterclockwise quarter
        turns, the rotated sample followed by its left-right mirrored version.
    Arguments:
        - board_size (:obj:`int`): Side length used to view the flat ``mcts_prob`` vector as a \
            ``board_size x board_size`` grid.
        - play_data (:obj:`list`): Samples of the form \
            ``{'state': ..., 'mcts_prob': ..., 'winner': ...}``. ``state`` is iterated plane by plane and \
            each plane is rotated/flipped as a 2-D array; ``mcts_prob`` must support ``reshape``.
    Returns:
        - extend_data (:obj:`list`): Dictionaries with the same three keys. ``winner`` is copied unchanged.
    """
    augmented = []
    for sample in play_data:
        planes = sample['state']
        policy = sample['mcts_prob']
        outcome = sample['winner']

        # The policy grid is flipped vertically before rotating and flipped back before flattening.
        # NOTE(review): presumably the flat policy uses the opposite row order to the state planes -- confirm.
        policy_grid = np.flipud(policy.reshape(board_size, board_size))

        for quarter_turns in (1, 2, 3, 4):
            # Counterclockwise rotation by ``quarter_turns`` * 90 degrees.
            turned_planes = np.array([np.rot90(plane, quarter_turns) for plane in planes])
            turned_policy = np.rot90(policy_grid, quarter_turns)
            augmented.append(
                {
                    'state': turned_planes,
                    'mcts_prob': np.flipud(turned_policy).flatten(),
                    'winner': outcome
                }
            )

            # Left-right mirror image of the rotated sample.
            mirrored_planes = np.array([np.fliplr(plane) for plane in turned_planes])
            mirrored_policy = np.fliplr(turned_policy)
            augmented.append(
                {
                    'state': mirrored_planes,
                    'mcts_prob': np.flipud(mirrored_policy).flatten(),
                    'winner': outcome
                }
            )
    return augmented
def prepare_observation(observation_list, model_type='conv'):
    """
    Overview:
        Convert a batch of stacked observations into the array layout expected by the model.
        if model_type='conv':
            [B, S, W, H, C] -> [B, S x C, W, H]
            [B, S, O]       -> [B, S, O, 1]
            where B is batch size, S is stack num, W is width, H is height, C is the number of channels
            and O is the vector obs shape. Arrays of any other rank are only reshaped to
            [B, -1, shape[-2], shape[-1]].
        if model_type='mlp':
            [B, S, O] -> [B, S x O]
            i.e. everything after the batch axis is flattened.
    Arguments:
        - observation_list (:obj:`List`): list of observations; converted with ``np.array``.
        - model_type (:obj:`str`): type of the model, either 'conv' or 'mlp'. (default is 'conv')
    Returns:
        - observation_array (:obj:`np.ndarray`): the rearranged observations.
    """
    assert model_type in ['conv', 'mlp']
    obs = np.array(observation_list)

    if model_type == 'mlp':
        # Vector obs: keep the batch axis and flatten the rest, [B, S, O] -> [B, S*O].
        obs = obs.reshape(obs.shape[0], -1)
    elif model_type == 'conv':
        rank = obs.ndim
        if rank == 3:
            # Vector obs fed to a conv model (e.g. classical control and box2d environments):
            # append a trailing singleton axis, [B, S, O] -> [B, S, O, 1].
            obs = obs[..., np.newaxis]
        elif rank == 5:
            # Image obs (e.g. atari environments): move channels ahead of the spatial axes,
            # [B, S, W, H, C] -> [B, S, C, W, H].
            obs = np.transpose(obs, (0, 1, 4, 2, 3))

        # Merge every axis between the batch axis and the last two axes,
        # e.g. [B, S, C, W, H] -> [B, S*C, W, H].
        batch, rows, cols = obs.shape[0], obs.shape[-2], obs.shape[-1]
        obs = obs.reshape((batch, -1, rows, cols))

    return obs
def obtain_tree_topology(root, to_play=-1):
    """
    Overview:
        Traverse the expanded part of a search tree with an explicit stack and collect its nodes and edges.
        As a side effect, every expanded child that is reached gets its ``parent_simulation_index``
        attribute set to the ``simulation_index`` of its parent.
    Arguments:
        - root: Root node. Nodes must expose ``simulation_index``, ``visit_count``, ``prior``, ``value``, \
            ``legal_actions`` and ``get_child(action)``; children must expose ``expanded``.
        - to_play: Not used by this function.
    Returns:
        - edge_topology_list (:obj:`list`): One ``{'parent_id', 'child_id'}`` dict per parent/expanded-child pair.
        - node_id_list (:obj:`list`): ``simulation_index`` of every visited node, in visiting order.
        - node_topology_list (:obj:`list`): One ``{'node_id', 'visit_count', 'policy_prior', 'value'}`` dict \
            per visited node, in the same order as ``node_id_list``.
    """
    edges = []
    node_records = []
    node_ids = []

    pending = [root]
    while pending:
        # Last-in-first-out: the most recently pushed node is processed next.
        current = pending.pop()

        node_records.append(
            {
                'node_id': current.simulation_index,
                'visit_count': current.visit_count,
                'policy_prior': current.prior,
                'value': current.value,
            }
        )
        node_ids.append(current.simulation_index)

        for action in current.legal_actions:
            child = current.get_child(action)
            if not child.expanded:
                # Unexpanded children are neither recorded nor traversed.
                continue
            child.parent_simulation_index = current.simulation_index
            edges.append(
                {
                    'parent_id': current.simulation_index,
                    'child_id': child.simulation_index,
                }
            )
            pending.append(child)

    return edges, node_ids, node_records
def plot_simulation_graph(env_root, current_step, graph_directory=None):
    """
    Overview:
        Build a graphviz ``Digraph`` of the expanded search tree below ``env_root`` and render it as PNG.
        Every node is labelled with its id, visit count, policy prior and value (the latter two rounded
        to 4 decimals); every edge is labelled ``<parent_id>-<child_id>``.
    Arguments:
        - env_root: Root node of the tree; passed to ``obtain_tree_topology``.
        - current_step: Step identifier embedded in the output file name \
            (``simulation_visualize_<current_step>step.gv``).
        - graph_directory (:obj:`str`): Output directory, created if missing. Defaults to \
            ``'./data_visualize/'``. A trailing path separator is optional.
    """
    edge_topology_list, _, node_topology_list = obtain_tree_topology(env_root)

    dot = Digraph(comment='this is direction')
    for node_topology in node_topology_list:
        node_name = str(node_topology['node_id'])
        label = (
            f"node_id: {node_topology['node_id']}, \n "
            f"visit_count: {node_topology['visit_count']}, \n "
            f"policy_prior: {round(node_topology['policy_prior'], 4)}, \n "
            f"value: {round(node_topology['value'], 4)}"
        )
        dot.node(node_name, label=label)

    for edge_topology in edge_topology_list:
        parent_id = str(edge_topology['parent_id'])
        child_id = str(edge_topology['child_id'])
        dot.edge(parent_id, child_id, label=parent_id + '-' + child_id)

    if graph_directory is None:
        graph_directory = './data_visualize/'
    # exist_ok avoids the check-then-create race of a separate os.path.exists() test.
    os.makedirs(graph_directory, exist_ok=True)
    # Join instead of concatenating so a directory given without a trailing separator
    # (e.g. 'out') still places the file inside it rather than beside it as 'outsimulation_...'.
    graph_path = os.path.join(graph_directory, 'simulation_visualize_' + str(current_step) + 'step.gv')
    dot.format = 'png'
    dot.render(graph_path, view=False)