from typing import Tuple, List
from collections import namedtuple
import torch
import torch.nn.functional as F
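# small constant to avoid division by zero when dividing by the importance ratio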
EPS = 1e-8


def acer_policy_error(
        q_values: torch.Tensor,
        q_retraces: torch.Tensor,
        v_pred: torch.Tensor,
        target_logit: torch.Tensor,
        actions: torch.Tensor,
        ratio: torch.Tensor,
        c_clip_ratio: float = 10.0
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Overview:
        Get ACER policy loss.
    Arguments:
        - q_values (:obj:`torch.Tensor`): Q values of every action
        - q_retraces (:obj:`torch.Tensor`): Q values computed by the retrace method
        - v_pred (:obj:`torch.Tensor`): V values (state-value estimates)
        - target_logit (:obj:`torch.Tensor`): log probability of the new (target) policy
        - actions (:obj:`torch.Tensor`): the actions sampled from the replay buffer
        - ratio (:obj:`torch.Tensor`): probability ratio between the new policy and the behavior policy
        - c_clip_ratio (:obj:`float`): clipping threshold for the importance ratio
    Returns:
        - actor_loss (:obj:`torch.Tensor`): policy loss computed from q_retraces
        - bc_loss (:obj:`torch.Tensor`): bias correction policy loss
    Shapes:
        - q_values (:obj:`torch.FloatTensor`): :math:`(T, B, N)`, where B is batch size and N is action dim
        - q_retraces (:obj:`torch.FloatTensor`): :math:`(T, B, 1)`
        - v_pred (:obj:`torch.FloatTensor`): :math:`(T, B, 1)`
        - target_logit (:obj:`torch.FloatTensor`): :math:`(T, B, N)`
        - actions (:obj:`torch.LongTensor`): :math:`(T, B)`
        - ratio (:obj:`torch.FloatTensor`): :math:`(T, B, N)`
        - actor_loss (:obj:`torch.FloatTensor`): :math:`(T, B, 1)`
        - bc_loss (:obj:`torch.FloatTensor`): :math:`(T, B, 1)`
    Examples:
        >>> q_values = torch.randn(2, 3, 4)
        >>> q_retraces = torch.randn(2, 3, 1)
        >>> v_pred = torch.randn(2, 3, 1)
        >>> target_logit = torch.randn(2, 3, 4).log_softmax(-1)
        >>> actions = torch.randint(0, 4, (2, 3))
        >>> ratio = torch.rand(2, 3, 4)
        >>> actor_loss, bc_loss = acer_policy_error(q_values, q_retraces, v_pred, target_logit, actions, ratio)
    """
    actions = actions.unsqueeze(-1)
    with torch.no_grad():
        advantage_retraces = q_retraces - v_pred  # shape T,B,1
        advantage_native = q_values - v_pred  # shape T,B,env_action_shape
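    # truncated importance-sampling term: min(rho_t(a_t), c) * (Q_ret - V) * log pi(a_t|x_t);
    # its gradient w.r.t. the policy parameters (with ``ratio`` detached) recovers the truncated
    # term of the ACER policy gradient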
    actor_loss = ratio.gather(-1, actions).clamp(max=c_clip_ratio) * advantage_retraces * target_logit.gather(
        -1, actions
    )  # shape T,B,1

    # bias correction term; the probability factor exp(target_logit) is detached, so gradients
    # only flow through the trailing target_logit
    bias_correction_loss = (1.0 - c_clip_ratio / (ratio + EPS)).clamp(min=0.0) * torch.exp(target_logit).detach() * \
        advantage_native * target_logit  # shape T,B,env_action_shape
    bias_correction_loss = bias_correction_loss.sum(-1, keepdim=True)
    return actor_loss, bias_correction_loss


def acer_value_error(q_values: torch.Tensor, q_retraces: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
    """
    Overview:
        Get ACER critic loss.
    Arguments:
        - q_values (:obj:`torch.Tensor`): Q values of every action
        - q_retraces (:obj:`torch.Tensor`): Q values computed by the retrace method
        - actions (:obj:`torch.Tensor`): the actions sampled from the replay buffer
    Returns:
        - critic_loss (:obj:`torch.Tensor`): critic loss
    Shapes:
        - q_values (:obj:`torch.FloatTensor`): :math:`(T, B, N)`, where B is batch size and N is action dim
        - q_retraces (:obj:`torch.FloatTensor`): :math:`(T, B, 1)`
        - actions (:obj:`torch.LongTensor`): :math:`(T, B)`
        - critic_loss (:obj:`torch.FloatTensor`): :math:`(T, B, 1)`
    Examples:
        >>> q_values=torch.randn(2, 3, 4)
        >>> q_retraces=torch.randn(2, 3, 1)
        >>> actions=torch.randint(0, 4, (2, 3))
        >>> loss = acer_value_error(q_values, q_retraces, actions)
    """
    actions = actions.unsqueeze(-1)
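    # 0.5 * squared error between the retrace target and the Q value of the executed action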
    critic_loss = 0.5 * (q_retraces - q_values.gather(-1, actions)).pow(2)
    return critic_loss


def acer_trust_region_update(
        actor_gradients: List[torch.Tensor], target_logit: torch.Tensor, avg_logit: torch.Tensor,
        trust_region_value: float
) -> List[torch.Tensor]:
    """
    Overview:
        Calculate the actor gradients under the trust region constraint.
    Arguments:
        - actor_gradients (:obj:`list(torch.Tensor)`): gradients of the actor objective w.r.t. the policy logits
        - target_logit (:obj:`torch.Tensor`): log probability of the new (target) policy
        - avg_logit (:obj:`torch.Tensor`): log probability of the average policy
        - trust_region_value (:obj:`float`): the bound of the trust region
    Returns:
        - update_gradients (:obj:`list(torch.Tensor)`): gradients with the trust region constraint applied
    Shapes:
        - target_logit (:obj:`torch.FloatTensor`): :math:`(T, B, N)`
        - avg_logit (:obj:`torch.FloatTensor`): :math:`(T, B, N)`
        - update_gradients (:obj:`list(torch.FloatTensor)`): :math:`(T, B, N)`
    Examples:
        >>> actor_gradients = [torch.randn(2, 3, 4)]
        >>> target_logit = torch.randn(2, 3, 4).log_softmax(-1)
        >>> avg_logit = torch.randn(2, 3, 4).log_softmax(-1)
        >>> update_gradients = acer_trust_region_update(actor_gradients, target_logit, avg_logit, 0.1)
    """
    with torch.no_grad():
        KL_gradients = [torch.exp(avg_logit)]
    update_gradients = []
    # TODO: currently there is only one element in this list; more elements may be used in the future
    actor_gradient = actor_gradients[0]
    KL_gradient = KL_gradients[0]
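    # closed-form trust region step: with g the actor gradient, k = exp(avg_logit) and
    # delta = trust_region_value, compute scale = max(0, (k . g - delta) / ||k||^2) and return
    # g - scale * k, so that the projected gradient g' satisfies k . g' <= delta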
    scale = actor_gradient.mul(KL_gradient).sum(-1, keepdim=True) - trust_region_value
    scale = torch.div(scale, KL_gradient.mul(KL_gradient).sum(-1, keepdim=True)).clamp(min=0.0)
    update_gradients.append(actor_gradient - scale * KL_gradient)
    return update_gradients
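

# ----------------------------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the module API): one way the three helpers above
# might be combined in a single ACER update step. The tensor shapes (T=2, B=3, N=4), the plain
# sum of the two policy terms, and the trust region bound of 0.1 are assumptions made for this
# example, not values prescribed by this module.
# ----------------------------------------------------------------------------------------------
if __name__ == '__main__':
    T, B, N = 2, 3, 4
    # log probabilities of the target, average and behavior policies
    target_logit = torch.randn(T, B, N).log_softmax(-1).requires_grad_(True)
    avg_logit = torch.randn(T, B, N).log_softmax(-1)
    behaviour_logit = torch.randn(T, B, N).log_softmax(-1)
    q_values = torch.randn(T, B, N)
    q_retraces = torch.randn(T, B, 1)  # in practice, produced by a retrace estimator
    v_pred = torch.randn(T, B, 1)
    actions = torch.randint(0, N, (T, B))
    with torch.no_grad():
        ratio = torch.exp(target_logit - behaviour_logit)  # importance ratio, no gradient

    actor_loss, bc_loss = acer_policy_error(q_values, q_retraces, v_pred, target_logit, actions, ratio)
    critic_loss = acer_value_error(q_values, q_retraces, actions)

    # gradient of the (to-be-maximized) policy objective w.r.t. the target logits,
    # then projected back into the trust region around the average policy
    policy_objective = (actor_loss + bc_loss).sum()
    actor_gradients = torch.autograd.grad(policy_objective, target_logit)
    update_gradients = acer_trust_region_update(list(actor_gradients), target_logit, avg_logit, 0.1)
    print(critic_loss.mean().item(), update_gradients[0].shape)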