gomoku / DI-engine /dizoo /beergame /envs /beergame_core.py
zjowowen's picture
init space
079c32c
raw
history blame
No virus
4.74 kB
from __future__ import print_function
from dizoo.beergame.envs import clBeerGame
from torch import Tensor
import numpy as np
import random
from .utils import get_config, update_config
import gym
import os
from typing import Optional
class BeerGame:
    """Single-agent wrapper around ``clBeerGame`` for one role of the beer game.

    The player at index ``role`` is assigned agent type ``"srdqn"``. When
    ``agent_type`` is ``'bs'`` or ``'Strm'``, every other player gets that type.
    The wrapper exposes ``reset``/``step``/``seed``/``close`` plus gym-style
    ``observation_space``, ``action_space`` and ``reward_space`` attributes.
    """

    def __init__(self, role: int, agent_type: str, demandDistribution: int) -> None:
        """Build the config, the demand training set and the underlying game.

        Args:
            role: Index of the player controlled by the learning agent.
            agent_type: ``'bs'`` or ``'Strm'`` assigns that type to all four
                players before ``role`` is overridden with ``'srdqn'``. Any
                other value keeps the ``agentTypes`` produced by
                ``get_config``/``update_config``.
            demandDistribution: Demand source: 0=uniform, 1=normal
                distribution, 2=the sequence 4,4,4,4,8,..., 3=basket data,
                4=forecast data.

        Raises:
            NotImplementedError: If ``observation_data`` is enabled in the config.
            ValueError: If ``demandDistribution`` is not one of 0-4.
        """
        self._cfg, _ = get_config()
        self._role = role
        self._cfg = update_config(self._cfg)
        # Set agent types; the controlled role is always overridden below.
        if agent_type == 'bs':
            self._cfg.agentTypes = ["bs", "bs", "bs", "bs"]
        elif agent_type == 'Strm':
            self._cfg.agentTypes = ["Strm", "Strm", "Strm", "Strm"]
        self._cfg.agentTypes[role] = "srdqn"
        self._cfg.demandDistribution = demandDistribution
        self._demandTr = self._load_demand()
        # Initialize an instance of the beer game.
        self._env = clBeerGame(self._cfg)
        self.observation_space = gym.spaces.Box(
            low=float("-inf"),
            high=float("inf"),
            shape=(self._cfg.stateDim * self._cfg.multPerdInpt, ),
            dtype=np.float32
        )  # state_space = state_dim * m (considering the reward delay)
        self.action_space = gym.spaces.Discrete(self._cfg.actionListLen)  # length of action list
        self.reward_space = gym.spaces.Box(low=float("-inf"), high=float("inf"), shape=(1, ), dtype=np.float32)
        # Number of demand rows available for sampling episodes.
        self._demand_len = np.shape(self._demandTr)[0]

    def _load_demand(self) -> np.ndarray:
        """Return the demand training set selected by ``self._cfg``.

        Raises:
            NotImplementedError: If ``observation_data`` is enabled.
            ValueError: If ``demandDistribution`` is not one of 0-4.
        """
        cfg = self._cfg
        if cfg.observation_data:
            # Only a path prefix ('data/demandTr-obs-') is known for this mode and
            # no file is loaded; fail here instead of leaving _demandTr unset.
            raise NotImplementedError(
                "observation_data=True is not supported: no demand file is loaded for it"
            )
        dist = cfg.demandDistribution
        if dist == 3 or dist == 4:
            # 3 = basket data, 4 = forecast data, optionally the scaled variant.
            if dist == 3:
                adsr = 'data/basket_data/scaled' if cfg.scaled else 'data/basket_data'
            else:
                adsr = 'data/forecast_data/scaled' if cfg.scaled else 'data/forecast_data'
            direc = os.path.realpath(adsr + '/demandTr-' + str(cfg.data_id) + '.npy')
            demand = np.load(direc)
            print("loaded training set=", direc)
            return demand
        if dist == 0:  # uniform
            return np.random.randint(0, cfg.demandUp, size=[cfg.demandSize, cfg.TUp])
        if dist == 1:  # normal distribution
            return np.round(
                np.random.normal(cfg.demandMu, cfg.demandSigma, size=[cfg.demandSize, cfg.TUp])
            ).astype(int)
        if dist == 2:  # the sequence of 4,4,4,4,8,...
            # NOTE(review): 4 + 98 = 102 periods regardless of cfg.TUp -- confirm intended.
            return np.concatenate(
                (4 * np.ones((cfg.demandSize, 4)), 8 * np.ones((cfg.demandSize, 98))), axis=1
            ).astype(int)
        raise ValueError(
            "unsupported demandDistribution {!r}; expected one of 0, 1, 2, 3, 4".format(dist)
        )

    def _sample_demand(self):
        """Pick one demand row uniformly at random using Python's ``random`` module."""
        return self._demandTr[random.randint(0, self._demand_len - 1)]

    @staticmethod
    def _flatten_state(state) -> list:
        """Flatten a 2-D state (sequence of rows) into one flat list, row by row."""
        return [value for row in state for value in row]

    def reset(self) -> list:
        """Start a new game on a randomly chosen demand row.

        Returns:
            The controlled player's ``currentState`` flattened row by row.
        """
        self._env.resetGame(demand=self._sample_demand())
        return self._flatten_state(self._env.players[self._role].currentState)

    def seed(self, seed: int) -> None:
        """Seed the global NumPy and Python RNGs.

        Python's ``random`` is seeded too because ``reset`` and
        ``enable_save_figure`` choose the demand row with ``random.randint``.
        """
        self._seed = seed
        np.random.seed(self._seed)
        random.seed(self._seed)

    def close(self) -> None:
        """Release resources; nothing to do for this environment."""
        pass

    def step(self, action: np.ndarray):
        """Apply ``action``, advance the game one period and return the transition.

        The controlled player's ``currentState`` is rolled: its first row is
        dropped and ``nextObservation`` is appended as the last row.

        Returns:
            ``(obs, rew, done, info)``: the flattened new state, the player's
            ``curReward``, whether ``curTime`` reached ``T``, and an empty dict.
        """
        self._env.handelAction(action)
        self._env.next()
        player = self._env.players[self._role]
        newstate = np.append(player.currentState[1:, :], [player.nextObservation], axis=0)
        player.currentState = newstate
        obs = self._flatten_state(newstate)
        rew = player.curReward
        done = (self._env.curTime == self._env.T)
        return obs, rew, done, {}

    def reward_shaping(self, reward: Tensor) -> Tensor:
        """Return ``reward`` plus a shaping term derived from ``distTotReward``.

        The term is ``distCoeff / 3 * (totRew - cumReward) / T``. The divisor 3
        presumably corresponds to the other three players (TODO confirm).
        The input is not modified.
        """
        self._totRew, self._cumReward = self._env.distTotReward(self._role)
        # Out-of-place add: ``+=`` would mutate the caller's tensor in place and
        # fails for integer tensors (float result can't be cast back to long).
        return reward + (self._cfg.distCoeff / 3) * ((self._totRew - self._cumReward) / (self._env.T))

    def enable_save_figure(self, figure_path: Optional[str] = None) -> None:
        """Enable figure saving to ``figure_path`` (default ``'./'``) and run ``doTestMid`` on a random demand row."""
        self._cfg.ifSaveFigure = True
        if figure_path is None:
            figure_path = './'
        self._cfg.figure_dir = figure_path
        self._env.doTestMid(self._sample_demand())