File size: 2,664 Bytes
079c32c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
from collections import namedtuple
import torch
from ding.hpc_rl import hpc_wrapper

# Input record consumed by ``gae``: value estimates, next-step value estimates, rewards, and the optional
# ``done`` / ``traj_flag`` masks (``gae`` accepts None for either of the last two).
gae_data = namedtuple('gae_data', 'value next_value reward done traj_flag')


def shape_fn_gae(args, kwargs):
    r"""
    Overview:
        Report the problem shape the hpc implementation of ``gae`` should use. The shape is read from the \
        ``reward`` field of the ``gae_data`` argument, whether that argument was passed positionally (first \
        positional argument) or by keyword (``data``).
    Arguments:
        - args (:obj:`tuple`): positional arguments of the ``gae`` call.
        - kwargs (:obj:`dict`): keyword arguments of the ``gae`` call.
    Returns:
        - shape: ``data.reward.shape``, i.e. :math:`(T, B)`
    """
    # Positional ``data`` wins; otherwise it must have been supplied as a keyword.
    data = args[0] if args else kwargs['data']
    return data.reward.shape


@hpc_wrapper(
    shape_fn=shape_fn_gae, namedtuple_data=True, include_args=[0, 1, 2], include_kwargs=['data', 'gamma', 'lambda_']
)
def gae(data: namedtuple, gamma: float = 0.99, lambda_: float = 0.97) -> torch.FloatTensor:
    """
    Overview:
        Implementation of Generalized Advantage Estimator (arXiv:1506.02438).
        For each step t it computes ``delta_t = reward_t + gamma * next_value_t * (1 - done_t) - value_t`` and \
        accumulates it backwards in time as ``adv_t = delta_t + gamma * lambda_ * (1 - traj_flag_t) * adv_{t+1}``. \
        The tensors inside ``data`` are not modified.
    Arguments:
        - data (:obj:`namedtuple`): gae input data with fields \
            ['value', 'next_value', 'reward', 'done', 'traj_flag'], which contains some episodes or trajectories \
            data. ``done`` may be None (treated as all zeros) and ``traj_flag`` may be None (falls back to ``done``).
        - gamma (:obj:`float`): the future discount factor, should be in [0, 1], defaults to 0.99.
        - lambda_ (:obj:`float`): the gae parameter lambda, should be in [0, 1], defaults to 0.97, when \
            lambda -> 0, it induces bias, but when lambda -> 1, it has high variance due to the sum of terms.
    Returns:
        - adv (:obj:`torch.FloatTensor`): the calculated advantage
    Shapes:
        - value (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is trajectory length and B is batch size; \
            in the multi-agent case it may be :math:`(T, B, A)` while reward stays :math:`(T, B)`
        - next_value (:obj:`torch.FloatTensor`): same shape as value
        - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`
        - done (:obj:`torch.Tensor` or None): same shape as reward; an entry of 1 removes the bootstrapped \
            ``next_value`` term at that step
        - traj_flag (:obj:`torch.Tensor` or None): same shape as reward; an entry of 1 stops the advantage of \
            later steps from being carried back into that step
        - adv (:obj:`torch.FloatTensor`): same shape as value
    Examples:
        >>> value = torch.randn(2, 3)
        >>> next_value = torch.randn(2, 3)
        >>> reward = torch.randn(2, 3)
        >>> data = gae_data(value, next_value, reward, None, None)
        >>> adv = gae(data)
    """
    value, next_value, reward, done, traj_flag = data
    if done is None:
        done = torch.zeros_like(reward)  # zeros_like already inherits reward's device and dtype
    if traj_flag is None:
        traj_flag = done
    done = done.float()
    traj_flag = traj_flag.float()
    if len(value.shape) == len(reward.shape) + 1:  # for some marl case: value(T, B, A), reward(T, B)
        # Add a trailing axis so the per-step terms broadcast against the per-agent values.
        reward = reward.unsqueeze(-1)
        done = done.unsqueeze(-1)
        traj_flag = traj_flag.unsqueeze(-1)

    # Out-of-place on purpose: ``next_value *= ...`` would overwrite the caller's tensor inside ``data``
    # (and raises for leaf tensors that require grad).
    next_value = next_value * (1 - done)
    delta = reward + gamma * next_value - value
    factor = gamma * lambda_ * (1 - traj_flag)
    adv = torch.zeros_like(value)
    # ``new_zeros(value.shape[1:])`` rather than ``zeros_like(value[0])`` so that an empty trajectory (T == 0)
    # yields an empty advantage tensor instead of raising IndexError.
    gae_item = value.new_zeros(value.shape[1:])

    # Backward recursion: each step's advantage is its TD error plus the discounted advantage of the next step.
    for t in reversed(range(reward.shape[0])):
        gae_item = delta[t] + factor[t] * gae_item
        adv[t] = gae_item
    return adv