from collections import namedtuple
import torch
from ding.hpc_rl import hpc_wrapper

gae_data = namedtuple('gae_data', ['value', 'next_value', 'reward', 'done', 'traj_flag'])


def shape_fn_gae(args, kwargs):
    r"""
    Overview:
        Return the shape of the gae input data, which is used by the hpc wrapper.
    Returns:
        shape: [T, B]
    """
    if len(args) <= 0:
        tmp = kwargs['data'].reward.shape
    else:
        tmp = args[0].reward.shape
    return tmp
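

# Illustrative doctest-style check (toy tensors, not part of the module itself):
# the hpc wrapper only needs the (T, B) shape of the reward tensor, so
# shape_fn_gae can be exercised directly with a small gae_data batch.
# >>> toy = gae_data(torch.randn(2, 3), torch.randn(2, 3), torch.randn(2, 3), None, None)
# >>> shape_fn_gae((toy,), {})
# torch.Size([2, 3])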


@hpc_wrapper(
    shape_fn=shape_fn_gae, namedtuple_data=True, include_args=[0, 1, 2], include_kwargs=['data', 'gamma', 'lambda_']
)
def gae(data: namedtuple, gamma: float = 0.99, lambda_: float = 0.97) -> torch.FloatTensor:
    """
    Overview:
        Implementation of the Generalized Advantage Estimator (arXiv:1506.02438).
    Arguments:
        - data (:obj:`namedtuple`): gae input data with fields ['value', 'next_value', 'reward', 'done', \
            'traj_flag'] (``done`` and ``traj_flag`` may be None), which contains some episodes or \
            trajectories data.
        - gamma (:obj:`float`): the future discount factor, should be in [0, 1], defaults to 0.99.
        - lambda_ (:obj:`float`): the gae parameter lambda, should be in [0, 1], defaults to 0.97. When lambda_ -> 0, \
            the estimator has lower variance but is more biased; when lambda_ -> 1, it is less biased but has high \
            variance due to the long sum of terms.
    Returns:
        - adv (:obj:`torch.FloatTensor`): the calculated advantage
    Shapes:
        - value (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is trajectory length and B is batch size
        - next_value (:obj:`torch.FloatTensor`): :math:`(T, B)`
        - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`
        - adv (:obj:`torch.FloatTensor`): :math:`(T, B)`
    Examples:
        >>> value = torch.randn(2, 3)
        >>> next_value = torch.randn(2, 3)
        >>> reward = torch.randn(2, 3)
        >>> data = gae_data(value, next_value, reward, None, None)
        >>> adv = gae(data)
    """
    value, next_value, reward, done, traj_flag = data
    if done is None:
        done = torch.zeros_like(reward, device=reward.device)
    if traj_flag is None:
        traj_flag = done
    done = done.float()
    traj_flag = traj_flag.float()
    if len(value.shape) == len(reward.shape) + 1:  # for some marl case: value(T, B, A), reward(T, B)
        reward = reward.unsqueeze(-1)
        done = done.unsqueeze(-1)
        traj_flag = traj_flag.unsqueeze(-1)
    # mask out the bootstrap value at terminal steps, then compute the one-step
    # TD residual: delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
    next_value *= (1 - done)
    delta = reward + gamma * next_value - value
    # recursion coefficient gamma * lambda, reset to 0 at trajectory boundaries
    factor = gamma * lambda_ * (1 - traj_flag)
    adv = torch.zeros_like(value)
    gae_item = torch.zeros_like(value[0])
    # backward recursion over time: A_t = delta_t + gamma * lambda * A_{t+1}
    for t in reversed(range(reward.shape[0])):
        gae_item = delta[t] + factor[t] * gae_item
        adv[t] = gae_item
    return adv
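

# A minimal usage sketch (illustrative only, with hypothetical toy tensors, and
# assuming the hpc wrapper falls back to the pure PyTorch path above when hpc is
# unavailable): with an explicit ``done`` mask, the bootstrap value is dropped
# and the recursion is cut at episode boundaries, so the advantage of a terminal
# step reduces to the one-step residual reward - value.
if __name__ == "__main__":
    T, B = 4, 2
    value = torch.randn(T, B)
    next_value = torch.randn(T, B)
    reward = torch.randn(T, B)
    done = torch.zeros(T, B)
    done[1, 0] = 1.  # the episode of the first batch element ends at t=1
    data = gae_data(value, next_value, reward, done, None)
    adv = gae(data, gamma=0.99, lambda_=0.97)
    assert adv.shape == (T, B)
    assert torch.allclose(adv[1, 0], reward[1, 0] - value[1, 0])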