from functools import partial |
from ditk import logging |
import itertools |
import copy |
import numpy as np |
import multiprocessing |
import torch |
import torch.nn as nn |
from ding.utils import WORLD_MODEL_REGISTRY |
from ding.utils.data import default_collate |
from ding.torch_utils import unsqueeze_repeat |
from ding.world_model.base_world_model import HybridWorldModel |
from ding.world_model.model.ensemble import EnsembleModel, StandardScaler |
def tree_query(datapoint, tree, k):
    """
    Query ``tree`` for the ``k + 1`` nearest entries to ``datapoint`` and return the
    indices of all hits except the first one.

    The first hit is skipped because, when ``datapoint`` is itself stored in the tree,
    it is normally the point itself (distance 0).
    """
    _, indices = tree.query(datapoint, k=k + 1)
    return indices[1:]
def get_neighbor_index(data, k, serial=False):
    """
    Return, for every row of ``data``, the indices of its ``k`` nearest neighbours
    among the other rows, found with a scipy KD-tree.

    data: [B, N] torch tensor (moved to CPU and converted to numpy internally)
    k: int, number of neighbours per row
    serial: bool, if True query the tree row by row in this process; otherwise the
        queries are distributed over a process pool with one worker per CPU
    ret: [B, k] long tensor of row indices into ``data``

    If scipy is not installed a warning is logged and the interpreter exits with status 1.
    """
    try:
        from scipy.spatial import KDTree
    except ImportError:
        import sys
        logging.warning("Please install scipy first, such as `pip3 install scipy`.")
        sys.exit(1)
    # detach() so that a tensor that requires grad can still be converted to numpy
    data = data.detach().cpu().numpy()
    tree = KDTree(data)
    if serial:
        neighbors = [tree_query(point, tree, k) for point in data]
    else:
        fn = partial(tree_query, tree=tree, k=k)
        # The context manager tears the workers down even if a query raises,
        # so no worker processes are leaked.
        with multiprocessing.Pool(processes=multiprocessing.cpu_count()) as pool:
            neighbors = pool.map(fn, data)
    return torch.from_numpy(np.array(neighbors, dtype=np.int32)).to(torch.long)
def get_batch_jacobian(net, x, noutputs):
    """
    Compute the per-sample Jacobian of ``net`` with a single backward pass.

    Every input row is tiled ``noutputs`` times and an identity matrix is used as the
    upstream gradient, so that the i-th copy of a row receives the gradient of the
    i-th output.

    net: callable mapping a [n, noutputs, d] tensor to a [n, noutputs, noutputs] tensor
    x: [n, d]
    noutputs: int, number of outputs of ``net`` per sample
    ret: [n, noutputs, d]; built with ``create_graph=True`` so it can be differentiated again
    """
    batch = x.size(0)
    tiled = x.unsqueeze(1).repeat(1, noutputs, 1)
    tiled.requires_grad_(True)
    out = net(tiled)
    # One identity matrix per sample: row i selects output i of copy i.
    selector = torch.eye(noutputs).reshape(1, noutputs, noutputs).repeat(batch, 1, 1).to(tiled.device)
    (jacobian, ) = torch.autograd.grad(out, tiled, selector, create_graph=True)
    return jacobian
class EnsembleGradientModel(EnsembleModel):
    """Ensemble model whose optimisation step also takes a weighted regularisation loss."""

    def train(self, loss, loss_reg, reg):
        """
        Run one optimiser step on ``loss`` plus its auxiliary terms.

        loss: scalar tensor; it is modified in place by the additions below
        loss_reg: scalar tensor, regularisation loss
        reg: weight applied to ``loss_reg``
        """
        self.optimizer.zero_grad()
        # Penalty on the log-variance bounds held by the model.
        logvar_bound_penalty = 0.01 * torch.sum(self.max_logvar) - 0.01 * torch.sum(self.min_logvar)
        loss += logvar_bound_penalty
        loss += reg * loss_reg
        if self.use_decay:
            loss += self.get_decay_loss()
        loss.backward()
        self.optimizer.step()
@WORLD_MODEL_REGISTRY.register('ddppo')
class DDPPOWorldMode(HybridWorldModel, nn.Module):
    """rollout model + gradient model

    Two ensembles are kept. ``rollout_model`` produces the forward predictions of
    :meth:`step`. ``gradient_model`` (optional, enabled by ``config.model.gradient_model``)
    is fitted with an extra directional-derivative regulariser built from nearest
    neighbours and, when autograd is enabled, replaces the rollout model in the
    backward pass of :meth:`step`.
    """
    config = dict(
        model=dict(
            ensemble_size=7,  # number of networks in each ensemble
            elite_size=5,  # number of networks with the lowest holdout loss kept as elites
            state_size=None,
            action_size=None,
            reward_size=1,
            hidden_size=200,
            use_decay=False,
            batch_size=256,  # mini-batch size used while fitting the ensembles
            holdout_ratio=0.2,  # leading fraction of the sampled data used as holdout set
            max_epochs_since_update=5,  # early-stopping patience in epochs
            deterministic_rollout=True,  # if True, step() uses the predicted mean without noise
            gradient_model=True,  # whether the gradient model is built and trained
            k=3,  # number of nearest neighbours per sample used by the regulariser
            reg=1,  # weight of the regularisation loss
            neighbor_pool_size=10000,  # maximum number of samples in the neighbour pool
            train_freq_gradient_model=250  # minimum env steps between two gradient-model fits
        ),
    )

    def __init__(self, cfg, env, tb_logger):
        """
        Build the rollout ensemble, the input scaler and, if enabled, the gradient ensemble.

        cfg: config object; the fields listed in ``config['model']`` are read from ``cfg.model``
        env: environment handed to ``HybridWorldModel``; ``env.termination_fn`` is used by ``step``
        tb_logger: scalar logger exposing ``add_scalar`` (may be None, in which case nothing is logged)
        """
        HybridWorldModel.__init__(self, cfg, env, tb_logger)
        nn.Module.__init__(self)

        cfg = cfg.model
        self.ensemble_size = cfg.ensemble_size
        self.elite_size = cfg.elite_size
        self.state_size = cfg.state_size
        self.action_size = cfg.action_size
        self.reward_size = cfg.reward_size
        self.hidden_size = cfg.hidden_size
        self.use_decay = cfg.use_decay
        self.batch_size = cfg.batch_size
        self.holdout_ratio = cfg.holdout_ratio
        self.max_epochs_since_update = cfg.max_epochs_since_update
        self.deterministic_rollout = cfg.deterministic_rollout
        self.gradient_model = cfg.gradient_model
        self.k = cfg.k
        self.reg = cfg.reg
        self.neighbor_pool_size = cfg.neighbor_pool_size
        self.train_freq_gradient_model = cfg.train_freq_gradient_model

        self.rollout_model = EnsembleModel(
            self.state_size,
            self.action_size,
            self.reward_size,
            self.ensemble_size,
            self.hidden_size,
            use_decay=self.use_decay
        )
        self.scaler = StandardScaler(self.state_size + self.action_size)

        self.ensemble_mse_losses = []
        self.model_variances = []
        self.elite_model_idxes = []

        if self.gradient_model:
            # The boolean flag is replaced by the model itself; later truthiness checks
            # on ``self.gradient_model`` therefore still mean "gradient model enabled".
            self.gradient_model = EnsembleGradientModel(
                self.state_size,
                self.action_size,
                self.reward_size,
                self.ensemble_size,
                self.hidden_size,
                use_decay=self.use_decay
            )
            self.elite_model_idxes_gradient_model = []

        self.last_train_step_gradient_model = 0
        # If True, nearest neighbours are searched in-process instead of in a process pool.
        self.serial_calc_nn = False

        # NOTE(review): ``_cuda`` is not set in this class; presumably set by HybridWorldModel.__init__.
        if self._cuda:
            self.cuda()

    def step(self, obs, act, batch_size=8192):
        """
        Predict reward, next observation and termination for a batch of transitions.

        obs: observation tensor, concatenated with ``act`` along dim 1
        act: action tensor; a 1-D tensor is treated as one scalar action per sample
        batch_size: number of samples pushed through the ensemble at once
        ret: ``(rewards, next_obs, done)`` where ``done`` is ``self.env.termination_fn(next_obs)``

        Each sample is answered by a randomly chosen elite member of the rollout ensemble.
        When autograd is enabled and the gradient model exists, the forward values still
        come from the rollout model but gradients are taken through the gradient model.
        """

        class Predict(torch.autograd.Function):
            # Forward pass uses the rollout model, backward pass the gradient model.

            @staticmethod
            def forward(ctx, x):
                ctx.save_for_backward(x)
                mean, var = self.rollout_model(x, ret_log_var=False)
                return torch.cat([mean, var], dim=-1)

            @staticmethod
            def backward(ctx, grad_out):
                x, = ctx.saved_tensors
                with torch.enable_grad():
                    x = x.detach()
                    x.requires_grad_(True)
                    mean, var = self.gradient_model(x, ret_log_var=False)
                    y = torch.cat([mean, var], dim=-1)
                    return torch.autograd.grad(y, x, grad_outputs=grad_out, create_graph=True)

        if len(act.shape) == 1:
            act = act.unsqueeze(1)
        if self._cuda:
            obs = obs.cuda()
            act = act.cuda()
        inputs = torch.cat([obs, act], dim=1)
        inputs = self.scaler.transform(inputs)

        # predict, chunk by chunk to bound memory usage
        ensemble_mean, ensemble_var = [], []
        for i in range(0, inputs.shape[0], batch_size):
            batch_input = unsqueeze_repeat(inputs[i:i + batch_size], self.ensemble_size)
            if not torch.is_grad_enabled() or not self.gradient_model:
                b_mean, b_var = self.rollout_model(batch_input, ret_log_var=False)
            else:
                # use the gradient model to compute gradients during the backward pass
                output = Predict.apply(batch_input)
                b_mean, b_var = output.chunk(2, dim=2)
            ensemble_mean.append(b_mean)
            ensemble_var.append(b_var)
        ensemble_mean = torch.cat(ensemble_mean, 1)
        ensemble_var = torch.cat(ensemble_var, 1)
        # The models are fitted on ``next_obs - obs``; add ``obs`` back to get the next observation.
        ensemble_mean[:, :, 1:] += obs.unsqueeze(0)

        # sample from the predicted distribution
        if self.deterministic_rollout:
            ensemble_sample = ensemble_mean
        else:
            ensemble_std = ensemble_var.sqrt()
            ensemble_sample = ensemble_mean + torch.randn_like(ensemble_mean) * ensemble_std

        # pick one elite model per sample
        model_idxes = torch.from_numpy(np.random.choice(self.elite_model_idxes, size=len(obs))).to(inputs.device)
        batch_idxes = torch.arange(len(obs)).to(inputs.device)
        sample = ensemble_sample[model_idxes, batch_idxes]
        rewards, next_obs = sample[:, 0], sample[:, 1:]

        return rewards, next_obs, self.env.termination_fn(next_obs)

    def _collate_samples(self, data):
        """
        Turn a list of sampled transitions into model inputs and regression labels.

        data: list of transitions accepted by ``default_collate``; the collated result must
            provide the keys ``'obs'``, ``'action'``, ``'reward'`` and ``'next_obs'``
        ret: ``(inputs, labels)`` with ``inputs = [obs, action]`` and
            ``labels = [reward, next_obs - obs]`` concatenated along dim 1
            (moved to CUDA when ``self._cuda`` is set)
        """
        data = default_collate(data)
        obs = data['obs']
        action = data['action']
        reward = data['reward']
        next_obs = data['next_obs']
        if len(reward.shape) == 1:
            reward = reward.unsqueeze(1)
        if len(action.shape) == 1:
            action = action.unsqueeze(1)
        inputs = torch.cat([obs, action], dim=1)
        labels = torch.cat([reward, next_obs - obs], dim=1)
        if self._cuda:
            inputs = inputs.cuda()
            labels = labels.cuda()
        return inputs, labels

    def eval(self, env_buffer, envstep, train_iter):
        """
        Evaluate the rollout model on ``self.eval_freq`` samples drawn from ``env_buffer``
        and log the resulting losses (if a logger is available).

        env_buffer: buffer exposing ``sample(size, train_iter)``
        envstep: current environment step, used as the logging step
        train_iter: current training iteration, forwarded to the buffer
        """
        data = env_buffer.sample(self.eval_freq, train_iter)
        inputs, labels = self._collate_samples(data)

        # normalize
        inputs = self.scaler.transform(inputs)

        # repeat for ensemble
        inputs = unsqueeze_repeat(inputs, self.ensemble_size)
        labels = unsqueeze_repeat(labels, self.ensemble_size)

        # eval
        with torch.no_grad():
            mean, logvar = self.rollout_model(inputs, ret_log_var=True)
            _, mse_loss = self.rollout_model.loss(mean, logvar, labels)
            # error of the ensemble-averaged prediction
            ensemble_mse_loss = torch.pow(mean.mean(0) - labels[0], 2)
            # disagreement between ensemble members
            model_variance = mean.var(0)
            # Same guard as in train(): the logger is optional.
            if self.tb_logger is not None:
                self.tb_logger.add_scalar('env_model_step/eval_mse_loss', mse_loss.mean().item(), envstep)
                self.tb_logger.add_scalar(
                    'env_model_step/eval_ensemble_mse_loss',
                    ensemble_mse_loss.mean().item(), envstep
                )
                self.tb_logger.add_scalar(
                    'env_model_step/eval_model_variances',
                    model_variance.mean().item(), envstep
                )

        self.last_eval_step = envstep

    def train(self, env_buffer, envstep, train_iter):
        """
        Fit the rollout model on the whole buffer and, at most once every
        ``train_freq_gradient_model`` env steps, the gradient model as well.

        env_buffer: buffer exposing ``count()`` and ``sample(...)``
        envstep: current environment step, used for scheduling and as the logging step
        train_iter: current training iteration, forwarded to the buffer

        NOTE(review): this overrides ``nn.Module.train`` with a different signature.
        """
        logvar = dict()

        data = env_buffer.sample(env_buffer.count(), train_iter)
        inputs, labels = self._collate_samples(data)
        logvar.update(self._train_rollout_model(inputs, labels))

        if self.gradient_model:
            # update neighbor pool
            if (envstep - self.last_train_step_gradient_model) >= self.train_freq_gradient_model:
                n = min(env_buffer.count(), self.neighbor_pool_size)
                self.neighbor_pool = env_buffer.sample(n, train_iter, sample_range=slice(-n, None))
                inputs_reg, labels_reg = self._collate_samples(self.neighbor_pool)
                logvar.update(self._train_gradient_model(inputs, labels, inputs_reg, labels_reg))
                self.last_train_step_gradient_model = envstep

        self.last_train_step = envstep

        # log
        if self.tb_logger is not None:
            for k, v in logvar.items():
                self.tb_logger.add_scalar('env_model_step/' + k, v, envstep)

    def _train_rollout_model(self, inputs, labels):
        """
        Fit the rollout ensemble with early stopping on a holdout split and select the elites.

        inputs: unnormalised model inputs, one row per sample
        labels: regression targets, one row per sample
        ret: dict of scalar statistics keyed by ``'rollout_model/...'``
        """
        # split: the first ``holdout_ratio`` fraction is held out
        num_holdout = int(inputs.shape[0] * self.holdout_ratio)
        train_inputs, train_labels = inputs[num_holdout:], labels[num_holdout:]
        holdout_inputs, holdout_labels = inputs[:num_holdout], labels[:num_holdout]

        # normalize: the scaler is fitted on the training split only
        self.scaler.fit(train_inputs)
        train_inputs = self.scaler.transform(train_inputs)
        holdout_inputs = self.scaler.transform(holdout_inputs)

        # repeat for ensemble
        holdout_inputs = unsqueeze_repeat(holdout_inputs, self.ensemble_size)
        holdout_labels = unsqueeze_repeat(holdout_labels, self.ensemble_size)

        self._epochs_since_update = 0
        self._snapshots = {i: (-1, 1e10) for i in range(self.ensemble_size)}
        self._save_states()
        for epoch in itertools.count():
            # every ensemble member sees the data in its own random order
            train_idx = torch.stack([torch.randperm(train_inputs.shape[0])
                                     for _ in range(self.ensemble_size)]).to(train_inputs.device)
            self.mse_loss = []
            for start_pos in range(0, train_inputs.shape[0], self.batch_size):
                idx = train_idx[:, start_pos:start_pos + self.batch_size]
                train_input = train_inputs[idx]
                train_label = train_labels[idx]
                mean, logvar = self.rollout_model(train_input, ret_log_var=True)
                loss, mse_loss = self.rollout_model.loss(mean, logvar, train_label)
                self.rollout_model.train(loss)
                self.mse_loss.append(mse_loss.mean().item())
            self.mse_loss = sum(self.mse_loss) / len(self.mse_loss)

            with torch.no_grad():
                holdout_mean, holdout_logvar = self.rollout_model(holdout_inputs, ret_log_var=True)
                _, holdout_mse_loss = self.rollout_model.loss(holdout_mean, holdout_logvar, holdout_labels)
                self.curr_holdout_mse_loss = holdout_mse_loss.mean().item()
                break_train = self._save_best(epoch, holdout_mse_loss)
                if break_train:
                    break

        # restore the best snapshot of every member before ranking them
        self._load_states()
        with torch.no_grad():
            holdout_mean, holdout_logvar = self.rollout_model(holdout_inputs, ret_log_var=True)
            _, holdout_mse_loss = self.rollout_model.loss(holdout_mean, holdout_logvar, holdout_labels)
            sorted_loss, sorted_loss_idx = holdout_mse_loss.sort()
            sorted_loss = sorted_loss.detach().cpu().numpy().tolist()
            sorted_loss_idx = sorted_loss_idx.detach().cpu().numpy().tolist()
            self.elite_model_idxes = sorted_loss_idx[:self.elite_size]
            self.top_holdout_mse_loss = sorted_loss[0]
            self.middle_holdout_mse_loss = sorted_loss[self.ensemble_size // 2]
            self.bottom_holdout_mse_loss = sorted_loss[-1]
            self.best_holdout_mse_loss = holdout_mse_loss.mean().item()
        return {
            'rollout_model/mse_loss': self.mse_loss,
            'rollout_model/curr_holdout_mse_loss': self.curr_holdout_mse_loss,
            'rollout_model/best_holdout_mse_loss': self.best_holdout_mse_loss,
            'rollout_model/top_holdout_mse_loss': self.top_holdout_mse_loss,
            'rollout_model/middle_holdout_mse_loss': self.middle_holdout_mse_loss,
            'rollout_model/bottom_holdout_mse_loss': self.bottom_holdout_mse_loss,
        }

    def _get_jacobian(self, model, train_input_reg):
        """
        Jacobian of the predicted ``[reward, next_state]`` w.r.t. the unnormalised ``[state, action]``.

        train_input_reg: [ensemble_size, B, state_size+action_size]
        ret: [ensemble_size, B, state_size+reward_size, state_size+action_size]
        """

        def func(x):
            x = x.view(self.ensemble_size, -1, self.state_size + self.action_size)
            state = x[:, :, :self.state_size]
            x = self.scaler.transform(x)
            y, _ = model(x)
            # The model predicts the state delta; add the state so the derivative is
            # that of the next state rather than of the delta.
            null = torch.zeros_like(y)
            null[:, :, self.reward_size:] += state
            y = y + null

            return y.view(-1, self.state_size + self.reward_size, self.state_size + self.reward_size)

        # reshape input to [ensemble_size * B, state_size+action_size]
        train_input_reg = train_input_reg.view(-1, self.state_size + self.action_size)
        jacobian = get_batch_jacobian(func, train_input_reg, self.state_size + self.reward_size)

        # reshape jacobian back to [ensemble_size, B, state_size+reward_size, state_size+action_size]
        return jacobian.view(
            self.ensemble_size, -1, self.state_size + self.reward_size, self.state_size + self.action_size
        )

    @staticmethod
    def _reg_batch_bounds(start_pos, batch_size, pool_size):
        """
        Return the ``(begin, end)`` column range of the regularisation permutation that
        accompanies the training mini-batch starting at ``start_pos``.

        The window ``[start_pos, start_pos + batch_size)`` is mapped into the pool modulo
        ``pool_size``. If it wraps around, only the wrapped part ``[0, end)`` is used.
        If it ends exactly on the pool boundary, the part up to the boundary is used,
        so the returned range is never empty for a non-empty pool.
        """
        begin = start_pos % pool_size
        end = (start_pos + batch_size) % pool_size
        if begin < end:
            return begin, end
        if end == 0:
            # ``0:0`` would select nothing and turn the regularisation loss into NaN.
            return begin, pool_size
        return 0, end

    def _train_gradient_model(self, inputs, labels, inputs_reg, labels_reg):
        """
        Fit the gradient ensemble on the regression loss plus a directional-derivative
        regulariser, with early stopping on a holdout split.

        The regulariser asks the model Jacobian, applied to the input difference between a
        sample and each of its ``k`` nearest neighbours, to match the corresponding label
        difference.

        inputs, labels: unnormalised regression data, one row per sample
        inputs_reg, labels_reg: unnormalised neighbour pool, one row per sample
        ret: dict of scalar statistics keyed by ``'gradient_model/...'``
        """
        # split
        num_holdout = int(inputs.shape[0] * self.holdout_ratio)
        train_inputs, train_labels = inputs[num_holdout:], labels[num_holdout:]
        holdout_inputs, holdout_labels = inputs[:num_holdout], labels[:num_holdout]

        # normalize: the scaler fitted by the rollout training is reused, not refitted
        train_inputs = self.scaler.transform(train_inputs)
        holdout_inputs = self.scaler.transform(holdout_inputs)

        # repeat for ensemble
        holdout_inputs = unsqueeze_repeat(holdout_inputs, self.ensemble_size)
        holdout_labels = unsqueeze_repeat(holdout_labels, self.ensemble_size)

        # no split and no normalization on the regularisation data;
        # _get_jacobian applies the scaler internally
        train_inputs_reg, train_labels_reg = inputs_reg, labels_reg
        pool_size = train_inputs_reg.shape[0]

        neighbor_index = get_neighbor_index(train_inputs_reg, self.k, serial=self.serial_calc_nn)
        neighbor_inputs = train_inputs_reg[neighbor_index]  # [N, k, state_size+action_size]
        neighbor_labels = train_labels_reg[neighbor_index]  # [N, k, state_size+reward_size]
        neighbor_inputs_distance = (neighbor_inputs - train_inputs_reg.unsqueeze(1))
        neighbor_labels_distance = (neighbor_labels - train_labels_reg.unsqueeze(1))

        self._epochs_since_update = 0
        self._snapshots = {i: (-1, 1e10) for i in range(self.ensemble_size)}
        self._save_states()
        for epoch in itertools.count():
            train_idx = torch.stack([torch.randperm(train_inputs.shape[0])
                                     for _ in range(self.ensemble_size)]).to(train_inputs.device)

            train_idx_reg = torch.stack([torch.randperm(pool_size)
                                         for _ in range(self.ensemble_size)]).to(train_inputs_reg.device)

            self.mse_loss = []
            self.grad_loss = []
            for start_pos in range(0, train_inputs.shape[0], self.batch_size):
                # regression loss
                idx = train_idx[:, start_pos:start_pos + self.batch_size]
                train_input = train_inputs[idx]
                train_label = train_labels[idx]
                mean, logvar = self.gradient_model(train_input, ret_log_var=True)
                loss, mse_loss = self.gradient_model.loss(mean, logvar, train_label)

                # regularisation loss
                reg_begin, reg_end = self._reg_batch_bounds(start_pos, self.batch_size, pool_size)
                idx_reg = train_idx_reg[:, reg_begin:reg_end]

                train_input_reg = train_inputs_reg[idx_reg]
                # [ensemble_size, B, k, state_size+action_size]
                neighbor_input_distance = neighbor_inputs_distance[idx_reg]
                # [ensemble_size, B, k, state_size+reward_size]
                neighbor_label_distance = neighbor_labels_distance[idx_reg]

                # [ensemble_size, B, k (repeated), state_size+reward_size, state_size+action_size]
                jacobian = self._get_jacobian(self.gradient_model, train_input_reg).unsqueeze(2).repeat_interleave(
                    self.k, dim=2
                )

                # [ensemble_size, B, k, state_size+reward_size]
                directional_derivative = (jacobian @ neighbor_input_distance.unsqueeze(-1)).squeeze(-1)

                # summed over the ensemble members, averaged over everything else
                loss_reg = torch.pow((neighbor_label_distance - directional_derivative), 2).sum(0).mean()

                self.gradient_model.train(loss, loss_reg, self.reg)
                self.mse_loss.append(mse_loss.mean().item())
                self.grad_loss.append(loss_reg.item())

            self.mse_loss = sum(self.mse_loss) / len(self.mse_loss)
            self.grad_loss = sum(self.grad_loss) / len(self.grad_loss)

            with torch.no_grad():
                holdout_mean, holdout_logvar = self.gradient_model(holdout_inputs, ret_log_var=True)
                _, holdout_mse_loss = self.gradient_model.loss(holdout_mean, holdout_logvar, holdout_labels)
                self.curr_holdout_mse_loss = holdout_mse_loss.mean().item()
                break_train = self._save_best(epoch, holdout_mse_loss)
                if break_train:
                    break

        # restore the best snapshot of every member before ranking them
        self._load_states()
        with torch.no_grad():
            holdout_mean, holdout_logvar = self.gradient_model(holdout_inputs, ret_log_var=True)
            _, holdout_mse_loss = self.gradient_model.loss(holdout_mean, holdout_logvar, holdout_labels)
            sorted_loss, sorted_loss_idx = holdout_mse_loss.sort()
            sorted_loss = sorted_loss.detach().cpu().numpy().tolist()
            sorted_loss_idx = sorted_loss_idx.detach().cpu().numpy().tolist()
            self.elite_model_idxes_gradient_model = sorted_loss_idx[:self.elite_size]
            self.top_holdout_mse_loss = sorted_loss[0]
            self.middle_holdout_mse_loss = sorted_loss[self.ensemble_size // 2]
            self.bottom_holdout_mse_loss = sorted_loss[-1]
            self.best_holdout_mse_loss = holdout_mse_loss.mean().item()
        return {
            'gradient_model/mse_loss': self.mse_loss,
            'gradient_model/grad_loss': self.grad_loss,
            'gradient_model/curr_holdout_mse_loss': self.curr_holdout_mse_loss,
            'gradient_model/best_holdout_mse_loss': self.best_holdout_mse_loss,
            'gradient_model/top_holdout_mse_loss': self.top_holdout_mse_loss,
            'gradient_model/middle_holdout_mse_loss': self.middle_holdout_mse_loss,
            'gradient_model/bottom_holdout_mse_loss': self.bottom_holdout_mse_loss,
        }

    def _save_states(self, ):
        """Store a deep copy of the complete ``state_dict`` of this module as the snapshot."""
        self._states = copy.deepcopy(self.state_dict())

    def _save_state(self, id):
        """
        Copy entry ``id`` (index along the first dimension) of every tensor whose key
        contains ``'weight'`` or ``'bias'`` from the live parameters into the snapshot.

        id: index along the first dimension; presumably the ensemble member index
        """
        state_dict = self.state_dict()
        for k, v in state_dict.items():
            if 'weight' in k or 'bias' in k:
                self._states[k].data[id] = copy.deepcopy(v.data[id])

    def _load_states(self):
        """Load the snapshot kept in ``self._states`` back into this module."""
        self.load_state_dict(self._states)

    def _save_best(self, epoch, holdout_losses):
        """
        Update the per-member snapshots from the current holdout losses.

        A member's snapshot is replaced when its loss improved by more than 1 % relative
        to the best loss recorded for it so far.

        epoch: index of the epoch that produced ``holdout_losses``
        holdout_losses: one holdout loss per ensemble member
        ret: True once more than ``max_epochs_since_update`` consecutive epochs passed
            without any member improving, i.e. training should stop
        """
        updated = False
        for i, current in enumerate(holdout_losses):
            _, best = self._snapshots[i]
            improvement = (best - current) / best
            if improvement > 0.01:
                self._snapshots[i] = (epoch, current)
                self._save_state(i)
                updated = True
        if updated:
            self._epochs_since_update = 0
        else:
            self._epochs_since_update += 1
        return self._epochs_since_update > self.max_epochs_since_update