File size: 10,665 Bytes
079c32c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
import torch
import torch.nn.functional as F
from torch.distributions import Categorical, Independent, Normal
from collections import namedtuple
from .isw import compute_importance_weights
from ding.hpc_rl import hpc_wrapper


def vtrace_nstep_return(clipped_rhos, clipped_cs, reward, bootstrap_values, gamma=0.99, lambda_=0.95):
    """
    Overview:
        Compute the V-trace value targets ``v_s`` with a backward recursion over time. With
        ``delta_s = rho_s * (r_s + gamma * V(x_{s+1}) - V(x_s))`` the targets satisfy
        ``v_s - V(x_s) = delta_s + gamma * lambda_ * c_s * (v_{s+1} - V(x_{s+1}))``.
        The correction term beyond the last timestep is taken as 0.
    Arguments:
        - clipped_rhos (:obj:`torch.FloatTensor`): clipped importance weights ``rho`` scaling each TD error
        - clipped_cs (:obj:`torch.FloatTensor`): clipped importance weights ``c`` scaling the propagated trace
        - reward (:obj:`torch.FloatTensor`): reward of each timestep
        - bootstrap_values (:obj:`torch.FloatTensor`): value estimates ``V(x_0) ... V(x_T)``; the final entry \
            is only used as the bootstrap value for the last step
        - gamma (:obj:`float`): discount factor, defaults to 0.99
        - lambda_ (:obj:`float`): trace mixing factor, defaults to 0.95
    Returns:
        - vtrace_return (:obj:`torch.FloatTensor`): the V-trace target of every timestep (a tensor, \
            not a scalar loss); ``bootstrap_values`` itself is not modified
    Shapes:
        - clipped_rhos (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is timestep, B is batch size
        - clipped_cs (:obj:`torch.FloatTensor`): :math:`(T, B)`
        - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`
        - bootstrap_values (:obj:`torch.FloatTensor`): :math:`(T+1, B)`
        - vtrace_return (:obj:`torch.FloatTensor`):  :math:`(T, B)`
    """
    values, next_values = bootstrap_values[:-1], bootstrap_values[1:]
    td_errors = clipped_rhos * (reward + gamma * next_values - values)
    trace_decay = gamma * lambda_
    # Work on a copy so the caller's value tensor is left untouched.
    targets = values.clone()
    correction = 0.
    # Walk time backwards: correction at step s accumulates the discounted, c-weighted future corrections.
    for step in range(reward.size(0) - 1, -1, -1):
        correction = td_errors[step] + trace_decay * clipped_cs[step] * correction
        targets[step] += correction
    return targets


def vtrace_advantage(clipped_pg_rhos, reward, return_, bootstrap_values, gamma):
    """
    Overview:
        Compute the importance-weighted policy-gradient advantage
        ``clipped_pg_rhos * (reward + gamma * return_ - bootstrap_values)`` element-wise.
    Arguments:
        - clipped_pg_rhos (:obj:`torch.FloatTensor`): clipped importance weights for the policy gradient
        - reward (:obj:`torch.FloatTensor`): reward of each timestep
        - return_ (:obj:`torch.FloatTensor`): target value of the *next* state for each timestep
        - bootstrap_values (:obj:`torch.FloatTensor`): value estimate of the current state for each timestep
        - gamma (:obj:`float`): discount factor
    Returns:
        - vtrace_advantage (:obj:`torch.FloatTensor`): the element-wise advantage (not reduced, not a namedtuple)
    Shapes:
        - clipped_pg_rhos (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is timestep, B is batch size
        - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`
        - return (:obj:`torch.FloatTensor`): :math:`(T, B)`
        - bootstrap_values (:obj:`torch.FloatTensor`): :math:`(T, B)`
        - vtrace_advantage (:obj:`torch.FloatTensor`): :math:`(T, B)`
    """
    one_step_target = reward + gamma * return_
    return clipped_pg_rhos * (one_step_target - bootstrap_values)


# Input bundle for the vtrace error functions; a ``weight`` of None is replaced by all-ones there.
vtrace_data = namedtuple(
    'vtrace_data',
    (
        'target_output',
        'behaviour_output',
        'action',
        'value',
        'reward',
        'weight',
    ),
)
# Output bundle of the vtrace error functions: the three mean-reduced loss terms.
vtrace_loss = namedtuple('vtrace_loss', ('policy_loss', 'value_loss', 'entropy_loss'))


def shape_fn_vtrace_discrete_action(args, kwargs):
    r"""
    Overview:
        Shape hook passed to ``hpc_wrapper``. Report the shape of ``target_output`` of the vtrace input data.
        The data is taken from the first positional argument when there is one, otherwise from the
        ``data`` keyword argument.
    Arguments:
        - args (:obj:`tuple`): positional arguments of the wrapped call
        - kwargs (:obj:`dict`): keyword arguments of the wrapped call
    Returns:
        - shape: ``target_output.shape``, i.e. [T, B, N]
    """
    data = args[0] if len(args) > 0 else kwargs['data']
    return data.target_output.shape


@hpc_wrapper(
    shape_fn=shape_fn_vtrace_discrete_action,
    namedtuple_data=True,
    include_args=[0, 1, 2, 3, 4, 5],
    include_kwargs=['data', 'gamma', 'lambda_', 'rho_clip_ratio', 'c_clip_ratio', 'rho_pg_clip_ratio']
)
def vtrace_error_discrete_action(
    data: namedtuple,
    gamma: float = 0.99,
    lambda_: float = 0.95,
    rho_clip_ratio: float = 1.0,
    c_clip_ratio: float = 1.0,
    rho_pg_clip_ratio: float = 1.0
):
    """
    Overview:
        Implementation of vtrace(IMPALA: Scalable Distributed Deep-RL with Importance Weighted Actor-Learner\
        Architectures), (arXiv:1802.01561) for a discrete action space. The current policy is a categorical\
        distribution over ``target_output`` logits.
    Arguments:
        - data (:obj:`namedtuple`): input data with fields shown in ``vtrace_data``
            - target_output (:obj:`torch.Tensor`): output of the current policy network, interpreted as\
                logits (``Categorical(logits=target_output)``)
            - behaviour_output (:obj:`torch.Tensor`): output of the behaviour policy network that produced\
                the trajectory (collector), usually logits as well
            - action (:obj:`torch.Tensor`): the chosen action (index for the discrete action space) in the\
                trajectory, i.e.: behaviour_action
            - value (:obj:`torch.Tensor`): value estimates ``V(x_0) ... V(x_T)``; ``value[-1]`` is only used\
                as the bootstrap value after the last timestep
            - reward (:obj:`torch.Tensor`): reward of each timestep
            - weight (:obj:`torch.Tensor` or None): per-sample loss weight; ``None`` means all-ones
        - gamma (:obj:`float`): the future discount factor, defaults to 0.99
        - lambda_ (:obj:`float`): mix factor between 1-step (lambda_=0) and n-step, defaults to 0.95
        - rho_clip_ratio (:obj:`float`): the clipping threshold for importance weights (rho) when calculating\
            the baseline targets (vs), defaults to 1.0
        - c_clip_ratio (:obj:`float`): the clipping threshold for importance weights (c) when calculating\
            the baseline targets (vs), defaults to 1.0
        - rho_pg_clip_ratio (:obj:`float`): the clipping threshold for importance weights (rho) when calculating\
            the policy gradient advantage, defaults to 1.0
    Returns:
        - trace_loss (:obj:`namedtuple`): ``vtrace_loss(policy_loss, value_loss, entropy_loss)``, each a\
            differentiable 0-dim tensor:
            - policy_loss: negated weighted mean of ``log_prob(action) * advantage``
            - value_loss: weighted mean squared error between ``value[:-1]`` and the vtrace targets
            - entropy_loss: weighted mean entropy of the current (target) policy
    Shapes:
        - target_output (:obj:`torch.FloatTensor`): :math:`(T, B, N)`, where T is timestep, B is batch size and\
            N is action dim
        - behaviour_output (:obj:`torch.FloatTensor`): :math:`(T, B, N)`
        - action (:obj:`torch.LongTensor`): :math:`(T, B)`
        - value (:obj:`torch.FloatTensor`): :math:`(T+1, B)`
        - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`
        - weight (:obj:`torch.FloatTensor` or None): :math:`(T, B)`
    Examples:
        >>> T, B, N = 4, 8, 16
        >>> value = torch.randn(T + 1, B).requires_grad_(True)
        >>> reward = torch.rand(T, B)
        >>> target_output = torch.randn(T, B, N).requires_grad_(True)
        >>> behaviour_output = torch.randn(T, B, N)
        >>> action = torch.randint(0, N, size=(T, B))
        >>> data = vtrace_data(target_output, behaviour_output, action, value, reward, None)
        >>> loss = vtrace_error_discrete_action(data, rho_clip_ratio=1.1)
    """
    target_output, behaviour_output, action, value, reward, weight = data
    # Targets and advantages are treated as constants: no gradient flows through them,
    # only through the log-probabilities, entropies and value predictions used below.
    with torch.no_grad():
        # Importance ratios between target and behaviour policy, as computed by the isw helper.
        is_weights = compute_importance_weights(target_output, behaviour_output, action, 'discrete')
        rhos = torch.clamp(is_weights, max=rho_clip_ratio)
        cs = torch.clamp(is_weights, max=c_clip_ratio)
        return_ = vtrace_nstep_return(rhos, cs, reward, value, gamma, lambda_)
        pg_rhos = torch.clamp(is_weights, max=rho_pg_clip_ratio)
        # v_{s+1}: the vtrace target of the next step; the last step bootstraps from value[-1].
        return_t_plus_1 = torch.cat([return_[1:], value[-1:]], 0)
        adv = vtrace_advantage(pg_rhos, reward, return_t_plus_1, value[:-1], gamma)

    if weight is None:
        weight = torch.ones_like(reward)
    dist_target = Categorical(logits=target_output)
    pg_loss = -(dist_target.log_prob(action) * adv * weight).mean()
    value_loss = (F.mse_loss(value[:-1], return_, reduction='none') * weight).mean()
    entropy_loss = (dist_target.entropy() * weight).mean()
    return vtrace_loss(pg_loss, value_loss, entropy_loss)


def vtrace_error_continuous_action(
    data: namedtuple,
    gamma: float = 0.99,
    lambda_: float = 0.95,
    rho_clip_ratio: float = 1.0,
    c_clip_ratio: float = 1.0,
    rho_pg_clip_ratio: float = 1.0
):
    """
    Overview:
        Implementation of vtrace(IMPALA: Scalable Distributed Deep-RL with Importance Weighted Actor-Learner\
        Architectures), (arXiv:1802.01561) for a continuous action space. The current policy is a diagonal\
        Gaussian built from ``target_output['mu']`` and ``target_output['sigma']``.
    Arguments:
        - data (:obj:`namedtuple`): input data with fields shown in ``vtrace_data``
            - target_output (:obj:`dict{key:torch.Tensor}`): output of the current policy network.\
                Must contain ``'mu'`` (mean) and ``'sigma'`` (standard deviation, positive), which\
                parameterize ``Independent(Normal(mu, sigma), 1)``.
            - behaviour_output (:obj:`dict{key:torch.Tensor}`): output of the behaviour policy network that\
                produced the trajectory. It is consumed by ``compute_importance_weights(..., 'continuous')``,\
                presumably in the same ``'mu'``/``'sigma'`` format as ``target_output`` (confirm against ``isw``).
            - action (:obj:`torch.Tensor`): the continuous action taken in the trajectory,\
                i.e.: behaviour_action
            - value (:obj:`torch.Tensor`): value estimates ``V(x_0) ... V(x_T)``; ``value[-1]`` is only used\
                as the bootstrap value after the last timestep
            - reward (:obj:`torch.Tensor`): reward of each timestep
            - weight (:obj:`torch.Tensor` or None): per-sample loss weight; ``None`` means all-ones
        - gamma (:obj:`float`): the future discount factor, defaults to 0.99
        - lambda_ (:obj:`float`): mix factor between 1-step (lambda_=0) and n-step, defaults to 0.95
        - rho_clip_ratio (:obj:`float`): the clipping threshold for importance weights (rho) when calculating\
            the baseline targets (vs), defaults to 1.0
        - c_clip_ratio (:obj:`float`): the clipping threshold for importance weights (c) when calculating\
            the baseline targets (vs), defaults to 1.0
        - rho_pg_clip_ratio (:obj:`float`): the clipping threshold for importance weights (rho) when calculating\
            the policy gradient advantage, defaults to 1.0
    Returns:
        - trace_loss (:obj:`namedtuple`): ``vtrace_loss(policy_loss, value_loss, entropy_loss)``, each a\
            differentiable 0-dim tensor:
            - policy_loss: negated weighted mean of ``log_prob(action) * advantage``
            - value_loss: weighted mean squared error between ``value[:-1]`` and the vtrace targets
            - entropy_loss: weighted mean entropy of the current (target) policy
    Shapes:
        - target_output (:obj:`dict{key:torch.FloatTensor}`): :math:`(T, B, N)` per entry,\
            where T is timestep, B is batch size and N is action dim.
        - behaviour_output (:obj:`dict{key:torch.FloatTensor}`): :math:`(T, B, N)` per entry
        - action (:obj:`torch.FloatTensor`): :math:`(T, B, N)`
        - value (:obj:`torch.FloatTensor`): :math:`(T+1, B)`
        - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`
        - weight (:obj:`torch.FloatTensor` or None): :math:`(T, B)`
    Examples:
        >>> T, B, N = 4, 8, 16
        >>> value = torch.randn(T + 1, B).requires_grad_(True)
        >>> reward = torch.rand(T, B)
        >>> target_output = dict(
        ...     mu=torch.randn(T, B, N).requires_grad_(True),
        ...     sigma=torch.exp(torch.randn(T, B, N).requires_grad_(True)),
        ... )
        >>> behaviour_output = dict(
        ...     mu=torch.randn(T, B, N),
        ...     sigma=torch.exp(torch.randn(T, B, N)),
        ... )
        >>> action = torch.randn((T, B, N))
        >>> data = vtrace_data(target_output, behaviour_output, action, value, reward, None)
        >>> loss = vtrace_error_continuous_action(data, rho_clip_ratio=1.1)
    """
    target_output, behaviour_output, action, value, reward, weight = data
    # Targets and advantages are treated as constants: no gradient flows through them,
    # only through the log-probabilities, entropies and value predictions used below.
    with torch.no_grad():
        # Importance ratios between target and behaviour policy, as computed by the isw helper.
        is_weights = compute_importance_weights(target_output, behaviour_output, action, 'continuous')
        rhos = torch.clamp(is_weights, max=rho_clip_ratio)
        cs = torch.clamp(is_weights, max=c_clip_ratio)
        return_ = vtrace_nstep_return(rhos, cs, reward, value, gamma, lambda_)
        pg_rhos = torch.clamp(is_weights, max=rho_pg_clip_ratio)
        # v_{s+1}: the vtrace target of the next step; the last step bootstraps from value[-1].
        return_t_plus_1 = torch.cat([return_[1:], value[-1:]], 0)
        adv = vtrace_advantage(pg_rhos, reward, return_t_plus_1, value[:-1], gamma)

    if weight is None:
        weight = torch.ones_like(reward)
    # Independent(..., 1) sums log-probs / entropies over the last (action) dim, giving (T, B) terms.
    dist_target = Independent(Normal(loc=target_output['mu'], scale=target_output['sigma']), 1)
    pg_loss = -(dist_target.log_prob(action) * adv * weight).mean()
    value_loss = (F.mse_loss(value[:-1], return_, reduction='none') * weight).mean()
    entropy_loss = (dist_target.entropy() * weight).mean()
    return vtrace_loss(pg_loss, value_loss, entropy_loss)