import torch
import torch.nn.functional as F
from torch.distributions import Categorical, Independent, Normal
from collections import namedtuple
from .isw import compute_importance_weights
from ding.hpc_rl import hpc_wrapper


def vtrace_nstep_return(clipped_rhos, clipped_cs, reward, bootstrap_values, gamma=0.99, lambda_=0.95):
    """
    Overview:
        Computation of the v-trace n-step return, i.e. the value targets (vs) in the IMPALA paper.
    Returns:
        - vtrace_return (:obj:`torch.FloatTensor`): the v-trace value targets, with the same shape as ``reward``
    Shapes:
        - clipped_rhos (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is timestep, B is batch size
        - clipped_cs (:obj:`torch.FloatTensor`): :math:`(T, B)`
        - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`
        - bootstrap_values (:obj:`torch.FloatTensor`): :math:`(T+1, B)`
        - vtrace_return (:obj:`torch.FloatTensor`): :math:`(T, B)`
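    Examples:
        >>> # Illustrative shape-checking sketch with random tensors; in practice the clipped
        >>> # weights come from compute_importance_weights followed by torch.clamp, as in
        >>> # vtrace_error_discrete_action below.
        >>> T, B = 4, 8
        >>> clipped_rhos = torch.rand(T, B)
        >>> clipped_cs = torch.rand(T, B)
        >>> reward = torch.rand(T, B)
        >>> bootstrap_values = torch.randn(T + 1, B)
        >>> vtrace_return = vtrace_nstep_return(clipped_rhos, clipped_cs, reward, bootstrap_values)
        >>> assert vtrace_return.shape == (T, B)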
""" |
|
    # Clipped-rho-weighted temporal-difference errors of the bootstrap values.
    deltas = clipped_rhos * (reward + gamma * bootstrap_values[1:] - bootstrap_values[:-1])
    factor = gamma * lambda_
    result = bootstrap_values[:-1].clone()
    vtrace_item = 0.
    # Backward recursion: accumulate the clipped-c-weighted correction and add it to the bootstrap values.
    for t in reversed(range(reward.size()[0])):
        vtrace_item = deltas[t] + factor * clipped_cs[t] * vtrace_item
        result[t] += vtrace_item
    return result


def vtrace_advantage(clipped_pg_rhos, reward, return_, bootstrap_values, gamma):
    """
    Overview:
        Computation of the v-trace advantage used for the policy gradient.
    Returns:
        - vtrace_advantage (:obj:`torch.FloatTensor`): the v-trace advantage, with the same shape as ``reward``
    Shapes:
        - clipped_pg_rhos (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is timestep, B is batch size
        - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`
        - return_ (:obj:`torch.FloatTensor`): :math:`(T, B)`
        - bootstrap_values (:obj:`torch.FloatTensor`): :math:`(T, B)`
        - vtrace_advantage (:obj:`torch.FloatTensor`): :math:`(T, B)`
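    Examples:
        >>> # Illustrative shape-checking sketch with random tensors; ``return_`` is the one-step-shifted
        >>> # v-trace return and ``bootstrap_values`` are the current value estimates (see the callers below).
        >>> T, B = 4, 8
        >>> clipped_pg_rhos = torch.rand(T, B)
        >>> reward = torch.rand(T, B)
        >>> return_ = torch.randn(T, B)
        >>> bootstrap_values = torch.randn(T, B)
        >>> adv = vtrace_advantage(clipped_pg_rhos, reward, return_, bootstrap_values, gamma=0.99)
        >>> assert adv.shape == (T, B)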
""" |
|
    return clipped_pg_rhos * (reward + gamma * return_ - bootstrap_values)


vtrace_data = namedtuple('vtrace_data', ['target_output', 'behaviour_output', 'action', 'value', 'reward', 'weight'])
vtrace_loss = namedtuple('vtrace_loss', ['policy_loss', 'value_loss', 'entropy_loss'])


def shape_fn_vtrace_discrete_action(args, kwargs):
    r"""
    Overview:
        Return the shape of the v-trace input (``target_output``) for the hpc wrapper.
    Returns:
        shape: [T, B, N]
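    Examples:
        >>> # Direct-call sketch; in normal use the hpc wrapper supplies ``args``/``kwargs`` from the
        >>> # decorated vtrace_error_discrete_action call.
        >>> T, B, N = 4, 8, 16
        >>> data = vtrace_data(torch.randn(T, B, N), torch.randn(T, B, N),
        >>>                    torch.randint(0, N, (T, B)), torch.randn(T + 1, B),
        >>>                    torch.rand(T, B), None)
        >>> shape_fn_vtrace_discrete_action([data], {})  # torch.Size([4, 8, 16])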
""" |
|
    if len(args) <= 0:
        tmp = kwargs['data'].target_output.shape
    else:
        tmp = args[0].target_output.shape
    return tmp


@hpc_wrapper(
    shape_fn=shape_fn_vtrace_discrete_action,
    namedtuple_data=True,
    include_args=[0, 1, 2, 3, 4, 5],
    include_kwargs=['data', 'gamma', 'lambda_', 'rho_clip_ratio', 'c_clip_ratio', 'rho_pg_clip_ratio']
)
def vtrace_error_discrete_action(
        data: namedtuple,
        gamma: float = 0.99,
        lambda_: float = 0.95,
        rho_clip_ratio: float = 1.0,
        c_clip_ratio: float = 1.0,
        rho_pg_clip_ratio: float = 1.0
):
""" |
|
Overview: |
|
Implementation of vtrace(IMPALA: Scalable Distributed Deep-RL with Importance Weighted Actor-Learner\ |
|
Architectures), (arXiv:1802.01561) |
|
Arguments: |
|
- data (:obj:`namedtuple`): input data with fields shown in ``vtrace_data`` |
|
- target_output (:obj:`torch.Tensor`): the output taking the action by the current policy network,\ |
|
usually this output is network output logit |
|
- behaviour_output (:obj:`torch.Tensor`): the output taking the action by the behaviour policy network,\ |
|
usually this output is network output logit, which is used to produce the trajectory(collector) |
|
- action (:obj:`torch.Tensor`): the chosen action(index for the discrete action space) in trajectory,\ |
|
i.e.: behaviour_action |
|
- gamma: (:obj:`float`): the future discount factor, defaults to 0.95 |
|
- lambda: (:obj:`float`): mix factor between 1-step (lambda_=0) and n-step, defaults to 1.0 |
|
- rho_clip_ratio (:obj:`float`): the clipping threshold for importance weights (rho) when calculating\ |
|
the baseline targets (vs) |
|
- c_clip_ratio (:obj:`float`): the clipping threshold for importance weights (c) when calculating\ |
|
the baseline targets (vs) |
|
- rho_pg_clip_ratio (:obj:`float`): the clipping threshold for importance weights (rho) when calculating\ |
|
the policy gradient advantage |
|
Returns: |
|
- trace_loss (:obj:`namedtuple`): the vtrace loss item, all of them are the differentiable 0-dim tensor |
|
Shapes: |
|
- target_output (:obj:`torch.FloatTensor`): :math:`(T, B, N)`, where T is timestep, B is batch size and\ |
|
N is action dim |
|
- behaviour_output (:obj:`torch.FloatTensor`): :math:`(T, B, N)` |
|
- action (:obj:`torch.LongTensor`): :math:`(T, B)` |
|
- value (:obj:`torch.FloatTensor`): :math:`(T+1, B)` |
|
- reward (:obj:`torch.LongTensor`): :math:`(T, B)` |
|
- weight (:obj:`torch.LongTensor`): :math:`(T, B)` |
|
Examples: |
|
>>> T, B, N = 4, 8, 16 |
|
>>> value = torch.randn(T + 1, B).requires_grad_(True) |
|
>>> reward = torch.rand(T, B) |
|
>>> target_output = torch.randn(T, B, N).requires_grad_(True) |
|
>>> behaviour_output = torch.randn(T, B, N) |
|
>>> action = torch.randint(0, N, size=(T, B)) |
|
>>> data = vtrace_data(target_output, behaviour_output, action, value, reward, None) |
|
>>> loss = vtrace_error_discrete_action(data, rho_clip_ratio=1.1) |
|
""" |
|
    target_output, behaviour_output, action, value, reward, weight = data
    with torch.no_grad():
        # Importance sampling ratios between the target and behaviour policies, clipped separately
        # for the value targets (rhos, cs) and the policy gradient (pg_rhos).
        IS = compute_importance_weights(target_output, behaviour_output, action, 'discrete')
        rhos = torch.clamp(IS, max=rho_clip_ratio)
        cs = torch.clamp(IS, max=c_clip_ratio)
        return_ = vtrace_nstep_return(rhos, cs, reward, value, gamma, lambda_)
        pg_rhos = torch.clamp(IS, max=rho_pg_clip_ratio)
        # Shift the v-trace return by one step and bootstrap the last step with the value estimate.
        return_t_plus_1 = torch.cat([return_[1:], value[-1:]], 0)
        adv = vtrace_advantage(pg_rhos, reward, return_t_plus_1, value[:-1], gamma)

    if weight is None:
        weight = torch.ones_like(reward)
    dist_target = Categorical(logits=target_output)
    pg_loss = -(dist_target.log_prob(action) * adv * weight).mean()
    value_loss = (F.mse_loss(value[:-1], return_, reduction='none') * weight).mean()
    entropy_loss = (dist_target.entropy() * weight).mean()
    return vtrace_loss(pg_loss, value_loss, entropy_loss)


def vtrace_error_continuous_action(
        data: namedtuple,
        gamma: float = 0.99,
        lambda_: float = 0.95,
        rho_clip_ratio: float = 1.0,
        c_clip_ratio: float = 1.0,
        rho_pg_clip_ratio: float = 1.0
):
""" |
|
Overview: |
|
Implementation of vtrace(IMPALA: Scalable Distributed Deep-RL with Importance Weighted Actor-Learner\ |
|
Architectures), (arXiv:1802.01561) |
|
Arguments: |
|
- data (:obj:`namedtuple`): input data with fields shown in ``vtrace_data`` |
|
- target_output (:obj:`dict{key:torch.Tensor}`): the output taking the action \ |
|
by the current policy network, usually this output is network output, \ |
|
which represents the distribution by reparameterization trick. |
|
- behaviour_output (:obj:`dict{key:torch.Tensor}`): the output taking the action \ |
|
by the behaviour policy network, usually this output is network output logit, \ |
|
which represents the distribution by reparameterization trick. |
|
- action (:obj:`torch.Tensor`): the chosen action(index for the discrete action space) in trajectory, \ |
|
i.e.: behaviour_action |
|
- gamma: (:obj:`float`): the future discount factor, defaults to 0.95 |
|
- lambda: (:obj:`float`): mix factor between 1-step (lambda_=0) and n-step, defaults to 1.0 |
|
- rho_clip_ratio (:obj:`float`): the clipping threshold for importance weights (rho) when calculating\ |
|
the baseline targets (vs) |
|
- c_clip_ratio (:obj:`float`): the clipping threshold for importance weights (c) when calculating\ |
|
the baseline targets (vs) |
|
- rho_pg_clip_ratio (:obj:`float`): the clipping threshold for importance weights (rho) when calculating\ |
|
the policy gradient advantage |
|
Returns: |
|
- trace_loss (:obj:`namedtuple`): the vtrace loss item, all of them are the differentiable 0-dim tensor |
|
Shapes: |
|
- target_output (:obj:`dict{key:torch.FloatTensor}`): :math:`(T, B, N)`, \ |
|
where T is timestep, B is batch size and \ |
|
N is action dim. The keys are usually parameters of reparameterization trick. |
|
- behaviour_output (:obj:`dict{key:torch.FloatTensor}`): :math:`(T, B, N)` |
|
- action (:obj:`torch.LongTensor`): :math:`(T, B)` |
|
- value (:obj:`torch.FloatTensor`): :math:`(T+1, B)` |
|
- reward (:obj:`torch.LongTensor`): :math:`(T, B)` |
|
- weight (:obj:`torch.LongTensor`): :math:`(T, B)` |
|
Examples: |
|
>>> T, B, N = 4, 8, 16 |
|
>>> value = torch.randn(T + 1, B).requires_grad_(True) |
|
>>> reward = torch.rand(T, B) |
|
>>> target_output = dict( |
|
>>> 'mu': torch.randn(T, B, N).requires_grad_(True), |
|
>>> 'sigma': torch.exp(torch.randn(T, B, N).requires_grad_(True)), |
|
>>> ) |
|
>>> behaviour_output = dict( |
|
>>> 'mu': torch.randn(T, B, N), |
|
>>> 'sigma': torch.exp(torch.randn(T, B, N)), |
|
>>> ) |
|
>>> action = torch.randn((T, B, N)) |
|
>>> data = vtrace_data(target_output, behaviour_output, action, value, reward, None) |
|
>>> loss = vtrace_error_continuous_action(data, rho_clip_ratio=1.1) |
|
""" |
|
    target_output, behaviour_output, action, value, reward, weight = data
    with torch.no_grad():
        # Importance sampling ratios between the target and behaviour policies, clipped separately
        # for the value targets (rhos, cs) and the policy gradient (pg_rhos).
        IS = compute_importance_weights(target_output, behaviour_output, action, 'continuous')
        rhos = torch.clamp(IS, max=rho_clip_ratio)
        cs = torch.clamp(IS, max=c_clip_ratio)
        return_ = vtrace_nstep_return(rhos, cs, reward, value, gamma, lambda_)
        pg_rhos = torch.clamp(IS, max=rho_pg_clip_ratio)
        # Shift the v-trace return by one step and bootstrap the last step with the value estimate.
        return_t_plus_1 = torch.cat([return_[1:], value[-1:]], 0)
        adv = vtrace_advantage(pg_rhos, reward, return_t_plus_1, value[:-1], gamma)

    if weight is None:
        weight = torch.ones_like(reward)
    # Diagonal Gaussian policy; Independent(.., 1) sums log-probs over the action dimension.
    dist_target = Independent(Normal(loc=target_output['mu'], scale=target_output['sigma']), 1)
    pg_loss = -(dist_target.log_prob(action) * adv * weight).mean()
    value_loss = (F.mse_loss(value[:-1], return_, reduction='none') * weight).mean()
    entropy_loss = (dist_target.entropy() * weight).mean()
    return vtrace_loss(pg_loss, value_loss, entropy_loss)