# GreedRL / greedrl / pyenv.py
# Commit db26c81 ("add greedrl") by 先坤; 14.3 kB.
import torch
import json
import math
from collections import OrderedDict
from .const import *
from .utils import to_list
from .norm import Norm1D, Norm2D
from .variable import AttributeVariable, WorkerTaskSequence
class PyEnv(object):
    """Batched, step-by-step decoding environment for a GreedRL ``problem``.

    The batch has ``batch_size`` rows; row ``b`` belongs to problem instance
    ``b // sample_num``. Each call to :meth:`step` applies one action per row.
    Actions are encoded as integers (NW = worker count, NT = task count):

    * ``[0, NW)``            -- start worker ``a``
    * ``[NW, 2*NW)``         -- end worker ``a - NW``
    * ``[2*NW, 2*NW + NT)``  -- perform task ``a - 2*NW``
    * ``2*NW + NT``          -- finish; only unmasked once the row is finished

    :meth:`mask` has one column per action, and ``True`` marks an action that
    may not be chosen. Per-step costs and ``(kind, index, done)`` records are
    accumulated in buffers that grow on demand.
    """

    def __init__(self, problem, batch_size, sample_num, nn_args):
        """Create the environment state.

        Args:
            problem: problem definition. This class reads ``worker_num``,
                ``task_num``, ``device``, ``variables`` (variable classes),
                ``constraint`` and ``objective`` (factories) from it.
            batch_size: total number of rows; row ``b`` maps to problem
                ``b // sample_num``.
            sample_num: number of consecutive rows sharing one problem.
            nn_args: dict providing ``'feature_dict'`` and ``'variable_dim'``,
                consumed by :meth:`make_feat`.
        """
        super(PyEnv, self).__init__()
        self._problem = problem
        self._batch_size = batch_size
        self._sample_num = sample_num
        # Batch row whose state is printed after every step; -1 disables it.
        self._debug = -1

        self._NW = problem.worker_num
        self._NWW = problem.worker_num * 2
        self._NT = problem.task_num
        self._NWWT = self._NWW + self._NT

        self._feats_dict = nn_args['feature_dict']
        self._vars_dim = nn_args['variable_dim']

        self._vars_dict = {}
        self._vars = [var(problem, batch_size, sample_num) for var in problem.variables]
        for variable in self._vars:
            # Snapshot the tensor version so unexpected in-place edits are caught later.
            save_variable_version(variable)
            assert variable.name not in self._vars_dict, \
                "duplicated variable, name: {}".format(variable.name)
            self._vars_dict[variable.name] = variable

        self._constraint = problem.constraint()
        self._objective = problem.objective()

        device = problem.device
        self._worker_index = torch.full((self._batch_size,), -1,
                                        dtype=torch.int64, device=device)
        self._batch_index = torch.arange(self._batch_size,
                                         dtype=torch.int64, device=device)
        # Same as batch_index // sample_num: rows of one problem are contiguous.
        self._problem_index = torch.div(self._batch_index, sample_num, rounding_mode='trunc')
        self._feasible = torch.ones(self._batch_size, dtype=torch.bool, device=device)
        # Per-step cost and (kind, index, done) records; extended by _grow_buffers.
        self._cost = torch.zeros(self._batch_size, self._NT * 2,
                                 dtype=torch.float32, device=device)
        self._mask = torch.zeros(self._batch_size, self._NWWT + 1,
                                 dtype=torch.bool, device=device)
        self._worker_task_sequence = torch.full((self._batch_size, self._NT * 2, 3), -1,
                                                dtype=torch.int64, device=device)
        self._step = 0

        self.register_variables(self._constraint)
        self._finished = self._constraint.finished()
        if hasattr(self._constraint, 'mask_worker_start'):
            self.register_variables(self._constraint)
            mask_start = self._constraint.mask_worker_start()
        else:
            mask_start = False
        # Initially only worker-start actions can be unmasked.
        self._mask[:, :self._NW] = mask_start
        self._mask[:, self._NW:] = True

        if self._debug >= 0:
            print("\n$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$")
            print("new env")
            print("$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$\n")

    def time(self):
        """Return the number of steps recorded so far (including ``finalize``)."""
        return self._step

    def step(self, chosen):
        """Apply one action per batch row, without autograd tracking."""
        with torch.no_grad():
            self._do_step(chosen)

    def _do_step(self, chosen):
        """Dispatch every row's action to the matching ``step_*`` handler.

        Args:
            chosen: tensor of shape (batch_size,) holding one action id per
                row (see the class docstring for the encoding).
        """
        if self._debug >= 0:
            print("----------------------------------------------------------------------")
            # Mask rows are per batch row, so index them by batch row (not problem id).
            feasible = self._feasible & ~self._mask[self._batch_index, chosen]
            print("feasible={}".format(feasible[self._debug].tolist()))

        is_start = (chosen >= 0) & (chosen < self._NW)
        if torch.any(is_start):
            b_index = self._batch_index[is_start]
            p_index = self._problem_index[is_start]
            w_index = chosen[is_start]
            self.step_worker_start(b_index, p_index, w_index)

        is_end = (chosen >= self._NW) & (chosen < self._NWW)
        if torch.any(is_end):
            b_index = self._batch_index[is_end]
            p_index = self._problem_index[is_end]
            w_index = chosen[is_end] - self._NW
            self.step_worker_end(b_index, p_index, w_index)

        is_task = (chosen >= self._NWW) & (chosen < self._NWWT)
        if torch.any(is_task):
            b_index = self._batch_index[is_task]
            p_index = self._problem_index[is_task]
            t_index = chosen[is_task] - self._NWW
            step_task_b_index = b_index
            self.step_task(b_index, p_index, t_index)
        else:
            step_task_b_index = None

        is_finish = chosen == self._NWWT
        if torch.any(is_finish):
            b_index = self._batch_index[is_finish]
            self._worker_task_sequence[b_index, self._step, 0] = GRL_FINISH
            self._worker_task_sequence[b_index, self._step, 1] = 0
            self._worker_task_sequence[b_index, self._step, 2] = -1

        self.update_mask(step_task_b_index)
        # Variables may only be modified through their own step_* hooks.
        for var in self._vars:
            check_variable_version(var)

        if self._debug >= 0:
            self._print_debug_step()

        self._step += 1
        if self._step >= self._cost.size(1):
            self._grow_buffers(chosen.device)

    def _print_debug_step(self):
        """Print the current step record and all variables for row ``self._debug``."""
        row = self._debug
        print("worker_task_sequence[{}]={}".format(self._step,
              self._worker_task_sequence[row, self._step].tolist()))
        for var in self._vars:
            if var.value is None:
                print("{}={}".format(var.name, None))
            elif isinstance(var, AttributeVariable):
                print("{}={}".format(var.name, to_list(var.value)))
            else:
                print("{}={}".format(var.name, to_list(var.value[row])))

    def _grow_buffers(self, device):
        """Extend the cost and sequence buffers by NT steps, keeping recorded steps."""
        size = self._step + self._NT
        cost = torch.zeros(self._batch_size, size, dtype=torch.float32, device=device)
        cost[:, 0:self._step] = self._cost
        self._cost = cost
        sequence = torch.full((self._batch_size, size, 3), -1,
                              dtype=torch.int64, device=device)
        sequence[:, 0:self._step, :] = self._worker_task_sequence
        self._worker_task_sequence = sequence

    def step_worker_start(self, b_index, p_index, w_index):
        """Start worker ``w_index`` on rows ``b_index``.

        This runs the variables' ``step_worker_start`` hooks and the
        objective's cost, then unmasks tasks/finish and masks all worker
        actions.
        """
        self._worker_task_sequence[b_index, self._step, 0] = GRL_WORKER_START
        self._worker_task_sequence[b_index, self._step, 1] = w_index
        self._worker_task_sequence[b_index, self._step, 2] = -1
        for var in self._vars:
            if hasattr(var, 'step_worker_start'):
                var.step_worker_start(b_index, p_index, w_index)
                save_variable_version(var)

        if hasattr(self._objective, 'step_worker_start'):
            self.register_variables(self._objective, b_index)
            self.update_cost(self._objective.step_worker_start(), b_index)

        self._worker_index[b_index] = w_index
        self._mask[b_index, :self._NWW] = True
        self._mask[b_index, self._NWW:] = False

    def step_worker_end(self, b_index, p_index, w_index):
        """End worker ``w_index`` on rows ``b_index`` and refresh those rows' finished flags.

        Afterwards only worker-start actions (subject to the constraint's
        ``mask_worker_start``) remain unmasked for those rows.
        """
        self._worker_task_sequence[b_index, self._step, 0] = GRL_WORKER_END
        self._worker_task_sequence[b_index, self._step, 1] = w_index
        self._worker_task_sequence[b_index, self._step, 2] = -1
        for var in self._vars:
            if hasattr(var, 'step_worker_end'):
                var.step_worker_end(b_index, p_index, w_index)
                save_variable_version(var)

        if hasattr(self._objective, 'step_worker_end'):
            self.register_variables(self._objective, b_index)
            self.update_cost(self._objective.step_worker_end(), b_index)

        self._worker_index[b_index] = -1

        self.register_variables(self._constraint, b_index)
        self._finished[b_index] |= self._constraint.finished()

        if hasattr(self._constraint, 'mask_worker_start'):
            mask_start = self._constraint.mask_worker_start()
        else:
            mask_start = False
        self._mask[b_index, :self._NW] = mask_start
        self._mask[b_index, self._NW:] = True

    def step_task(self, b_index, p_index, t_index):
        """Perform task ``t_index`` with the active worker on rows ``b_index``.

        Variable hooks taking 4 arguments (including ``self``) are called once.
        Hooks taking 5 arguments are called first with ``done=None`` and, if
        the constraint defines ``do_task``, called again with its result.

        Returns:
            The constraint's ``do_task()`` result, or ``None`` if it has none.
        """
        self._worker_task_sequence[b_index, self._step, 0] = GRL_TASK
        self._worker_task_sequence[b_index, self._step, 1] = t_index
        for var in self._vars:
            if not hasattr(var, 'step_task'):
                continue
            elif var.step_task.__code__.co_argcount == 4:
                var.step_task(b_index, p_index, t_index)
            else:
                var.step_task(b_index, p_index, t_index, None)
            save_variable_version(var)

        if hasattr(self._constraint, 'do_task'):
            self.register_variables(self._constraint, b_index)
            done = self._constraint.do_task()
            self._worker_task_sequence[b_index, self._step, 2] = done.long()
            for var in self._vars:
                if not hasattr(var, 'step_task'):
                    continue
                elif var.step_task.__code__.co_argcount == 4:
                    pass
                else:
                    check_variable_version(var)
                    var.step_task(b_index, p_index, t_index, done)
                    save_variable_version(var)
        else:
            done = None

        if hasattr(self._objective, 'step_task'):
            self.register_variables(self._objective, b_index)
            self.update_cost(self._objective.step_task(), b_index)

        if hasattr(self._constraint, 'mask_worker_end'):
            self.register_variables(self._constraint, b_index)
            mask_end = self._constraint.mask_worker_end()
        else:
            mask_end = False

        # Column of the "end worker" action for each row's active worker.
        w_index = self._NW + self._worker_index[b_index]
        self._mask[b_index, w_index] = mask_end
        self._mask[b_index, self._NWW:] = False
        return done

    def update_cost(self, cost, b_index=None):
        """Record ``cost`` for the current step on rows ``b_index`` (all rows if ``None``).

        Args:
            cost: a tensor, ``int`` or ``float``, or a ``(cost, feasible)``
                tuple. In the tuple form, ``feasible`` is AND-ed into the
                feasibility flags of the same rows.
            b_index: row indices to update, or ``None`` for every row.
        """
        if isinstance(cost, tuple):
            cost, feasible = cost
            if b_index is None:
                self._feasible &= feasible
            else:
                self._feasible[b_index] &= feasible

        if isinstance(cost, torch.Tensor):
            cost = cost.float()
        else:
            assert type(cost) in (int, float), "unexpected cost's type: {}".format(type(cost))

        if b_index is None:
            self._cost[:, self._step] = cost
        else:
            self._cost[b_index, self._step] = cost

    def update_mask(self, step_task_b_index):
        """Apply finished-row, constraint task and end-worker masking after a step.

        Args:
            step_task_b_index: rows that performed a task this step, or ``None``.
        """
        # Finished rows: everything masked except the finish action (last column).
        self._mask |= self._finished[:, None]
        self._mask[:, -1] = ~self._finished

        self.register_variables(self._constraint)
        self._mask[:, self._NWW:self._NWWT] |= self._constraint.mask_task()

        if step_task_b_index is not None:
            b_index = step_task_b_index
            w_index = self._NW + self._worker_index[b_index]
            task_mask = self._mask[b_index, self._NWW:self._NWWT]
            # When every task is masked, the active worker must be allowed to end.
            self._mask[b_index, w_index] &= ~torch.all(task_mask, 1)

    def batch_size(self):
        """Return the number of batch rows."""
        return self._batch_size

    def sample_num(self):
        """Return the number of rows per problem instance."""
        return self._sample_num

    def mask(self):
        """Return a copy of the action mask (``True`` = action not allowed)."""
        return self._mask.clone()

    def cost(self):
        """Return the per-step costs recorded so far, shape (batch_size, time())."""
        return self._cost[:, 0:self._step]

    def feasible(self):
        """Return the per-row feasibility flags (the internal tensor, not a copy)."""
        return self._feasible

    def worker_task_sequence(self):
        """Return the recorded ``(kind, index, done)`` triples, one per step so far."""
        return self._worker_task_sequence[:, 0:self._step]

    def var(self, name):
        """Return the current value of the variable called ``name``."""
        return self._vars_dict[name].value

    def register_variables(self, obj, b_index=None, finished=False):
        """Expose every variable's value as an attribute of ``obj``.

        Values are restricted to rows ``b_index``. ``None`` values,
        ``AttributeVariable`` values and the whole-batch case (``b_index is
        None``) are exposed unindexed. A variable's ``ext_values`` entries
        are exposed as ``<var name>_<key>``.

        Args:
            obj: constraint/objective instance receiving the attributes.
            b_index: selected row indices, or ``None`` for all rows.
            finished: unused; kept for interface compatibility.
        """
        for var in self._vars:
            if var.value is None or b_index is None \
                    or isinstance(var, AttributeVariable):
                value = var.value
            else:
                value = var.value[b_index]
            obj.__dict__[var.name] = value

            if not hasattr(var, 'ext_values'):
                continue
            for k, v in var.ext_values.items():
                # Only index when a row subset is requested; v[None] would add a dim.
                obj.__dict__[var.name + '_' + k] = v if b_index is None else v[b_index]

    def finished(self):
        """Return the per-row finished flags (the internal tensor)."""
        return self._finished

    def all_finished(self):
        """Return whether every row is finished (as a tensor)."""
        return torch.all(self.finished())

    def finalize(self):
        """Record a final FINISH step for all rows and run the ``step_finish`` hooks."""
        self._worker_task_sequence[:, self._step, 0] = GRL_FINISH
        self._worker_task_sequence[:, self._step, 1] = 0
        self._worker_task_sequence[:, self._step, 2] = -1
        for var in self._vars:
            if hasattr(var, 'step_finish'):
                # The sequence passed in excludes the FINISH record written above.
                var.step_finish(self.worker_task_sequence())

        if hasattr(self._objective, 'step_finish'):
            self.register_variables(self._objective, finished=True)
            self.update_cost(self._objective.step_finish())

        self._step += 1

    def make_feat(self):
        """Build the normalised feature tensor without autograd tracking."""
        with torch.no_grad():
            return self.do_make_feat()

    def do_make_feat(self):
        """Concatenate variable features and normalise them over unmasked actions.

        Returns:
            ``None`` if no variable features are configured. Otherwise a
            contiguous tensor of shape (batch, feature_dim, NWWT + 1), where
            worker and finish columns are zero-padded before normalisation.
        """
        if not self._vars_dim:
            return None

        feature_list = []
        for k, dim in self._vars_dim.items():
            f = self._feats_dict[k]
            var = self._vars_dict[f.name]
            v = var.make_feat()
            if v.dim() == 2:
                v = v[:, :, None]
            assert dim == v.size(-1), \
                "feature dim error, feature: {}, expected: {}, actual: {}".format(k, dim, v.size(-1))
            feature_list.append(v.float())

        v = torch.cat(feature_list, 2)
        # Pad with zero columns for the 2*NW worker actions and the finish action.
        u = v.new_zeros(v.size(0), self._NWW, v.size(2))
        f = v.new_zeros(v.size(0), 1, v.size(2))
        v = torch.cat([u, v, f], 1).permute(0, 2, 1)
        v[self._mask[:, None, :].expand(v.size())] = 0

        # Number of unmasked actions per row (epsilon avoids division by zero).
        norm = v.new_ones(self._mask.size())
        norm[self._mask] = 0
        norm = norm.sum(1) + 1e-10
        norm = norm[:, None, None]

        avg = v.sum(-1, keepdim=True) / norm
        v = v - avg
        # NOTE(review): this divides the L2 norm by the count rather than by
        # sqrt(count), so it is not a true std. It is kept as-is because
        # trained models depend on these exact feature scales.
        std = v.norm(dim=-1, keepdim=True) / norm + 1e-10
        v = v / std
        return v.contiguous()
def save_variable_version(var):
    """Remember the in-place version counter of ``var.value``.

    Non-tensor values (including ``None``) are ignored. The stored counter is
    later compared by ``check_variable_version`` to detect in-place changes.
    """
    current = var.value
    if not isinstance(current, torch.Tensor):
        return
    var.__version__ = current._version
def check_variable_version(var):
    """Assert that ``var.value`` was not modified in place since the last save.

    Non-tensor values (including ``None``) are not checked.
    """
    current = var.value
    if not isinstance(current, torch.Tensor):
        return
    assert var.__version__ == current._version, \
        "variable's value is modified, name: {}".format(var.name)