|
from typing import Tuple, List
from collections import namedtuple

import torch
import torch.nn.functional as F

EPS = 1e-8


def acer_policy_error(
        q_values: torch.Tensor,
        q_retraces: torch.Tensor,
        v_pred: torch.Tensor,
        target_logit: torch.Tensor,
        actions: torch.Tensor,
        ratio: torch.Tensor,
        c_clip_ratio: float = 10.0
) -> Tuple[torch.Tensor, torch.Tensor]:
""" |
|
Overview: |
|
Get ACER policy loss. |
|
Arguments: |
|
- q_values (:obj:`torch.Tensor`): Q values |
|
- q_retraces (:obj:`torch.Tensor`): Q values (be calculated by retrace method) |
|
- v_pred (:obj:`torch.Tensor`): V values |
|
- target_pi (:obj:`torch.Tensor`): The new policy's probability |
|
- actions (:obj:`torch.Tensor`): The actions in replay buffer |
|
- ratio (:obj:`torch.Tensor`): ratio of new polcy with behavior policy |
|
- c_clip_ratio (:obj:`float`): clip value for ratio |
|
Returns: |
|
- actor_loss (:obj:`torch.Tensor`): policy loss from q_retrace |
|
- bc_loss (:obj:`torch.Tensor`): correct policy loss |
|
Shapes: |
|
- q_values (:obj:`torch.FloatTensor`): :math:`(T, B, N)`, where B is batch size and N is action dim |
|
- q_retraces (:obj:`torch.FloatTensor`): :math:`(T, B, 1)` |
|
- v_pred (:obj:`torch.FloatTensor`): :math:`(T, B, 1)` |
|
- target_pi (:obj:`torch.FloatTensor`): :math:`(T, B, N)` |
|
- actions (:obj:`torch.LongTensor`): :math:`(T, B)` |
|
- ratio (:obj:`torch.FloatTensor`): :math:`(T, B, N)` |
|
- actor_loss (:obj:`torch.FloatTensor`): :math:`(T, B, 1)` |
|
- bc_loss (:obj:`torch.FloatTensor`): :math:`(T, B, 1)` |
|
Examples: |
|
>>> q_values=torch.randn(2, 3, 4), |
|
>>> q_retraces=torch.randn(2, 3, 1), |
|
>>> v_pred=torch.randn(2, 3, 1), |
|
>>> target_pi=torch.randn(2, 3, 4), |
|
>>> actions=torch.randint(0, 4, (2, 3)), |
|
>>> ratio=torch.randn(2, 3, 4), |
|
>>> loss = acer_policy_error(q_values, q_retraces, v_pred, target_pi, actions, ratio) |
|
""" |
|
    actions = actions.unsqueeze(-1)
    with torch.no_grad():
        advantage_retraces = q_retraces - v_pred  # shape (T, B, 1)
        advantage_native = q_values - v_pred  # shape (T, B, N)
    # Truncated importance-sampling term: the ratio of the taken action is clipped at c_clip_ratio.
    actor_loss = ratio.gather(-1, actions).clamp(max=c_clip_ratio) * advantage_retraces * target_logit.gather(
        -1, actions
    )

    # Bias-correction term: only active where the ratio exceeds c_clip_ratio, weighted by the target policy.
    bias_correction_loss = (1.0 - c_clip_ratio / (ratio + EPS)).clamp(min=0.0) * torch.exp(target_logit).detach() * \
        advantage_native * target_logit
    bias_correction_loss = bias_correction_loss.sum(-1, keepdim=True)
    return actor_loss, bias_correction_loss
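

# A minimal usage sketch, not part of the original API: it builds random tensors with the
# documented shapes and reduces the two returned terms to a single scalar policy loss.
# Summing the two terms and negating them (they are objectives to be maximized) is an
# assumption borrowed from common ACER implementations, not something this module mandates.
def _acer_policy_error_example() -> torch.Tensor:
    T, B, N = 2, 3, 4
    target_logit = torch.randn(T, B, N).log_softmax(-1)  # log pi(a|s) of the current policy
    behaviour_logit = torch.randn(T, B, N).log_softmax(-1)  # log mu(a|s) of the behaviour policy
    ratio = torch.exp(target_logit - behaviour_logit)  # importance weights pi / mu
    q_values = torch.randn(T, B, N)
    q_retraces = torch.randn(T, B, 1)
    v_pred = torch.randn(T, B, 1)
    actions = torch.randint(0, N, (T, B))
    actor_loss, bc_loss = acer_policy_error(q_values, q_retraces, v_pred, target_logit, actions, ratio)
    # Both terms are objectives to be maximized, so the scalar loss is their negated mean.
    return -(actor_loss + bc_loss).mean()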
|
|
|
|
|
def acer_value_error(q_values: torch.Tensor, q_retraces: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
""" |
|
Overview: |
|
Get ACER critic loss. |
|
Arguments: |
|
- q_values (:obj:`torch.Tensor`): Q values |
|
- q_retraces (:obj:`torch.Tensor`): Q values (be calculated by retrace method) |
|
- actions (:obj:`torch.Tensor`): The actions in replay buffer |
|
- ratio (:obj:`torch.Tensor`): ratio of new polcy with behavior policy |
|
Returns: |
|
- critic_loss (:obj:`torch.Tensor`): critic loss |
|
Shapes: |
|
- q_values (:obj:`torch.FloatTensor`): :math:`(T, B, N)`, where B is batch size and N is action dim |
|
- q_retraces (:obj:`torch.FloatTensor`): :math:`(T, B, 1)` |
|
- actions (:obj:`torch.LongTensor`): :math:`(T, B)` |
|
- critic_loss (:obj:`torch.FloatTensor`): :math:`(T, B, 1)` |
|
Examples: |
|
>>> q_values=torch.randn(2, 3, 4) |
|
>>> q_retraces=torch.randn(2, 3, 1) |
|
>>> actions=torch.randint(0, 4, (2, 3)) |
|
>>> loss = acer_value_error(q_values, q_retraces, actions) |
|
""" |
|
    actions = actions.unsqueeze(-1)
    # Squared error between the retrace target and the Q value of the taken action.
    critic_loss = 0.5 * (q_retraces - q_values.gather(-1, actions)).pow(2)
    return critic_loss
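

# A minimal usage sketch, not part of the original API: random tensors with the documented
# shapes, reduced to a scalar critic loss with a plain mean. The mean reduction is an
# assumption for the example; a training loop may weight or mask the per-step losses differently.
def _acer_value_error_example() -> torch.Tensor:
    T, B, N = 2, 3, 4
    q_values = torch.randn(T, B, N, requires_grad=True)
    q_retraces = torch.randn(T, B, 1)  # retrace targets are treated as constants
    actions = torch.randint(0, N, (T, B))
    critic_loss = acer_value_error(q_values, q_retraces, actions)  # shape (T, B, 1)
    return critic_loss.mean()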
|
|
|
|
|
def acer_trust_region_update(
        actor_gradients: List[torch.Tensor], target_logit: torch.Tensor, avg_logit: torch.Tensor,
        trust_region_value: float
) -> List[torch.Tensor]:
""" |
|
Overview: |
|
calcuate gradient with trust region constrain |
|
Arguments: |
|
- actor_gradients (:obj:`list(torch.Tensor)`): gradients value's for different part |
|
- target_pi (:obj:`torch.Tensor`): The new policy's probability |
|
- avg_pi (:obj:`torch.Tensor`): The average policy's probability |
|
- trust_region_value (:obj:`float`): the range of trust region |
|
Returns: |
|
- update_gradients (:obj:`list(torch.Tensor)`): gradients with trust region constraint |
|
Shapes: |
|
- target_pi (:obj:`torch.FloatTensor`): :math:`(T, B, N)` |
|
- avg_pi (:obj:`torch.FloatTensor`): :math:`(T, B, N)` |
|
- update_gradients (:obj:`list(torch.FloatTensor)`): :math:`(T, B, N)` |
|
Examples: |
|
>>> actor_gradients=[torch.randn(2, 3, 4)] |
|
>>> target_pi=torch.randn(2, 3, 4) |
|
>>> avg_pi=torch.randn(2, 3, 4) |
|
>>> loss = acer_trust_region_update(actor_gradients, target_pi, avg_pi, 0.1) |
|
""" |
|
    with torch.no_grad():
        KL_gradients = [torch.exp(avg_logit)]
    update_gradients = []

    actor_gradient = actor_gradients[0]
    KL_gradient = KL_gradients[0]
    # Project the actor gradient so that its component along the KL gradient stays within trust_region_value.
    scale = actor_gradient.mul(KL_gradient).sum(-1, keepdim=True) - trust_region_value
    scale = torch.div(scale, KL_gradient.mul(KL_gradient).sum(-1, keepdim=True)).clamp(min=0.0)
    update_gradients.append(actor_gradient - scale * KL_gradient)
    return update_gradients
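

# A minimal usage sketch, not part of the original API: the gradient of the policy objective
# with respect to the target log-probabilities is obtained with torch.autograd.grad and then
# projected by acer_trust_region_update. Using a unit importance ratio and a plain sum
# objective here are simplifying assumptions for illustration only.
def _acer_trust_region_update_example() -> List[torch.Tensor]:
    T, B, N = 2, 3, 4
    target_logit = torch.randn(T, B, N).log_softmax(-1).requires_grad_(True)
    avg_logit = torch.randn(T, B, N).log_softmax(-1)
    q_values = torch.randn(T, B, N)
    q_retraces = torch.randn(T, B, 1)
    v_pred = torch.randn(T, B, 1)
    actions = torch.randint(0, N, (T, B))
    ratio = torch.ones(T, B, N)
    actor_loss, bc_loss = acer_policy_error(q_values, q_retraces, v_pred, target_logit, actions, ratio)
    objective = (actor_loss + bc_loss).sum()
    actor_gradients = list(torch.autograd.grad(objective, target_logit))
    return acer_trust_region_update(actor_gradients, target_logit, avg_logit, trust_region_value=0.1)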
|
|