# Provenance (HuggingFace page residue, commented out so the file parses):
# zjowowen's picture
# init space
# 079c32c
# raw | history | blame
# 2.4 kB
from collections import namedtuple
import numpy as np
def convert(dictionary):
    """Wrap *dictionary* in a namedtuple so its values are reachable as attributes.

    A fresh ``GenericDict`` namedtuple type is built from the dict's keys on
    every call; keys must therefore be valid Python identifiers.
    """
    record_cls = namedtuple('GenericDict', dictionary.keys())
    return record_cls(**dictionary)
class MultiAgentEnv(object):
    """Abstract base class for multi-agent environments.

    Concrete environments subclass this and override the
    ``NotImplementedError`` stubs below. Configuration is delivered via
    ``kwargs["env_args"]`` (sacred-style), either as an attribute object or
    as a plain dict (which is converted to a namedtuple).
    """

    def __init__(self, batch_size=None, **kwargs):
        # Unpack arguments from sacred; plain dicts gain attribute access.
        env_args = kwargs["env_args"]
        if isinstance(env_args, dict):
            env_args = convert(env_args)
        self.args = env_args

        configured_seed = getattr(env_args, "seed", None)
        if configured_seed is not None:
            # NOTE(review): this attribute shadows the seed() method below,
            # so env.seed(x) stops being callable once a seed arg is given —
            # preserved as-is for backward compatibility.
            self.seed = configured_seed
            # Dedicated numpy random state for reproducible sampling.
            self.rs = np.random.RandomState(configured_seed)

    def step(self, actions):
        """Advance the environment one step; returns reward, terminated, info."""
        raise NotImplementedError

    def get_obs(self):
        """Return all agent observations in a list."""
        raise NotImplementedError

    def get_obs_agent(self, agent_id):
        """Return the observation for a single agent_id."""
        raise NotImplementedError

    def get_obs_size(self):
        """Return the shape of a single observation."""
        raise NotImplementedError

    def get_state(self):
        """Return the global state."""
        raise NotImplementedError

    def get_state_size(self):
        """Return the shape of the global state."""
        raise NotImplementedError

    def get_avail_actions(self):
        """Return available actions for every agent."""
        raise NotImplementedError

    def get_avail_agent_actions(self, agent_id):
        """Return the available actions for agent_id."""
        raise NotImplementedError

    def get_total_actions(self):
        """Return the total number of actions an agent could ever take."""
        # TODO: This is only suitable for a discrete 1 dimensional action space for each agent
        raise NotImplementedError

    def get_stats(self):
        """Return environment statistics."""
        raise NotImplementedError

    # TODO: Temp hack
    def get_agg_stats(self, stats):
        """Aggregate per-episode stats; currently a no-op returning an empty dict."""
        return {}

    def reset(self):
        """Reset the environment; returns initial observations and states."""
        raise NotImplementedError

    def render(self):
        """Render the environment."""
        raise NotImplementedError

    def close(self):
        """Release environment resources."""
        raise NotImplementedError

    def seed(self, seed):
        """Seed the environment's randomness (see shadowing NOTE in __init__)."""
        raise NotImplementedError

    def get_env_info(self):
        """Collect the shapes/counts a learner needs into a single dict.

        Relies on subclass-provided ``n_agents`` and ``episode_limit``
        attributes in addition to the size getters.
        """
        return {
            "state_shape": self.get_state_size(),
            "obs_shape": self.get_obs_size(),
            "n_actions": self.get_total_actions(),
            "n_agents": self.n_agents,
            "episode_limit": self.episode_limit,
        }