zjowowen's picture
init space
079c32c
from typing import Tuple, List
from collections import namedtuple
import torch
import torch.nn.functional as F
# Small constant added to ``ratio`` before dividing by it in ``acer_policy_error``,
# so that a zero importance ratio does not produce a division by zero.
EPS = 1e-8


def acer_policy_error(
    q_values: torch.Tensor,
    q_retraces: torch.Tensor,
    v_pred: torch.Tensor,
    target_logit: torch.Tensor,
    actions: torch.Tensor,
    ratio: torch.Tensor,
    c_clip_ratio: float = 10.0
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Overview:
        Get ACER policy loss terms: the truncated importance-sampling term on the replayed
        actions and the bias-correction term summed over all actions.
        Only ``target_logit`` carries gradient. Every weighting coefficient (the advantages,
        the clipped importance ratios and the probabilities ``exp(target_logit)``) is
        computed under ``torch.no_grad()`` and is therefore a constant w.r.t. autograd,
        even if the caller passes tensors that require grad.
        The returned terms are not negated. NOTE(review): presumably the caller negates
        them before minimizing; confirm against the caller.
    Arguments:
        - q_values (:obj:`torch.Tensor`): Q values for every action
        - q_retraces (:obj:`torch.Tensor`): Q values calculated by the retrace method
        - v_pred (:obj:`torch.Tensor`): V values
        - target_logit (:obj:`torch.Tensor`): log-probabilities of the new policy;
          ``torch.exp(target_logit)`` is used as the new policy's probability
        - actions (:obj:`torch.Tensor`): the actions in the replay buffer
        - ratio (:obj:`torch.Tensor`): ratio of the new policy to the behavior policy
        - c_clip_ratio (:obj:`float`): clip value for ratio
    Returns:
        - actor_loss (:obj:`torch.Tensor`): policy term from q_retrace,
          ``min(ratio[a], c) * (q_retrace - v) * target_logit[a]``
        - bc_loss (:obj:`torch.Tensor`): bias correction term,
          ``sum_a [1 - c / ratio[a]]_+ * exp(target_logit[a]) * (q[a] - v) * target_logit[a]``
    Shapes:
        - q_values (:obj:`torch.FloatTensor`): :math:`(T, B, N)`, where B is batch size and N is action dim
        - q_retraces (:obj:`torch.FloatTensor`): :math:`(T, B, 1)`
        - v_pred (:obj:`torch.FloatTensor`): :math:`(T, B, 1)`
        - target_logit (:obj:`torch.FloatTensor`): :math:`(T, B, N)`
        - actions (:obj:`torch.LongTensor`): :math:`(T, B)`
        - ratio (:obj:`torch.FloatTensor`): :math:`(T, B, N)`
        - actor_loss (:obj:`torch.FloatTensor`): :math:`(T, B, 1)`
        - bc_loss (:obj:`torch.FloatTensor`): :math:`(T, B, 1)`
    Examples:
        >>> q_values = torch.randn(2, 3, 4)
        >>> q_retraces = torch.randn(2, 3, 1)
        >>> v_pred = torch.randn(2, 3, 1)
        >>> target_logit = torch.log_softmax(torch.randn(2, 3, 4), dim=-1)
        >>> actions = torch.randint(0, 4, (2, 3))
        >>> ratio = torch.rand(2, 3, 4) * 2
        >>> actor_loss, bc_loss = acer_policy_error(q_values, q_retraces, v_pred, target_logit, actions, ratio)
    """
    actions = actions.unsqueeze(-1)  # shape T,B,1 so it can index the action dim
    with torch.no_grad():
        advantage_retraces = q_retraces - v_pred  # shape T,B,1
        advantage_native = q_values - v_pred  # shape T,B,env_action_shape
        # truncated importance weight min(ratio(a_t), c) of the replayed action; shape T,B,1
        clipped_ratio = ratio.gather(-1, actions).clamp(max=c_clip_ratio)
        # bias-correction weight [1 - c / ratio(a)]_+ * pi(a) for every action; shape T,B,env_action_shape
        correction_weight = (1.0 - c_clip_ratio / (ratio + EPS)).clamp(min=0.0) * torch.exp(target_logit)
    # gradient flows only through the log-probabilities below
    actor_loss = clipped_ratio * advantage_retraces * target_logit.gather(-1, actions)  # shape T,B,1
    bias_correction_loss = correction_weight * advantage_native * target_logit  # shape T,B,env_action_shape
    bias_correction_loss = bias_correction_loss.sum(-1, keepdim=True)  # shape T,B,1
    return actor_loss, bias_correction_loss
def acer_value_error(q_values: torch.Tensor, q_retraces: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
    """
    Overview:
        Get ACER critic loss: half the squared difference between the retrace target and
        the Q value of the action taken in the replay buffer.
        ``q_retraces`` is not detached here. NOTE(review): presumably the caller computes it
        without gradient; confirm.
    Arguments:
        - q_values (:obj:`torch.Tensor`): Q values for every action
        - q_retraces (:obj:`torch.Tensor`): Q values calculated by the retrace method
        - actions (:obj:`torch.Tensor`): the actions in the replay buffer
    Returns:
        - critic_loss (:obj:`torch.Tensor`): ``0.5 * (q_retraces - q_values[actions]) ** 2``
    Shapes:
        - q_values (:obj:`torch.FloatTensor`): :math:`(T, B, N)`, where B is batch size and N is action dim
        - q_retraces (:obj:`torch.FloatTensor`): :math:`(T, B, 1)`
        - actions (:obj:`torch.LongTensor`): :math:`(T, B)`
        - critic_loss (:obj:`torch.FloatTensor`): :math:`(T, B, 1)`
    Examples:
        >>> q_values = torch.randn(2, 3, 4)
        >>> q_retraces = torch.randn(2, 3, 1)
        >>> actions = torch.randint(0, 4, (2, 3))
        >>> loss = acer_value_error(q_values, q_retraces, actions)
    """
    actions = actions.unsqueeze(-1)  # shape T,B,1 so it can index the action dim
    critic_loss = 0.5 * (q_retraces - q_values.gather(-1, actions)).pow(2)
    return critic_loss
def acer_trust_region_update(
    actor_gradients: List[torch.Tensor], target_logit: torch.Tensor, avg_logit: torch.Tensor,
    trust_region_value: float
) -> List[torch.Tensor]:
    """
    Overview:
        Calculate the actor gradient under the ACER trust-region constraint.
        With ``g`` the actor gradient, ``k = exp(avg_logit)`` and ``delta = trust_region_value``,
        the result is ``g - max(0, (k . g - delta) / ||k||^2) * k``. The dot product and the
        squared norm are taken over the last dimension.
        Runs under ``torch.no_grad()``, so the result carries no autograd history.
    Arguments:
        - actor_gradients (:obj:`list(torch.Tensor)`): gradients to constrain; only the first \
element is used by the current implementation
        - target_logit (:obj:`torch.Tensor`): log-probability of the new policy; not used by \
the current implementation
        - avg_logit (:obj:`torch.Tensor`): log-probability of the average policy; \
``exp(avg_logit)`` is used as ``k``
        - trust_region_value (:obj:`float`): the range of the trust region (``delta``)
    Returns:
        - update_gradients (:obj:`list(torch.Tensor)`): one-element list holding the \
constrained gradient
    Shapes:
        - actor_gradients (:obj:`list(torch.FloatTensor)`): each :math:`(T, B, N)`
        - target_logit (:obj:`torch.FloatTensor`): :math:`(T, B, N)`
        - avg_logit (:obj:`torch.FloatTensor`): :math:`(T, B, N)`
        - update_gradients (:obj:`list(torch.FloatTensor)`): :math:`(T, B, N)`
    Examples:
        >>> actor_gradients = [torch.randn(2, 3, 4)]
        >>> target_logit = torch.log_softmax(torch.randn(2, 3, 4), dim=-1)
        >>> avg_logit = torch.log_softmax(torch.randn(2, 3, 4), dim=-1)
        >>> update = acer_trust_region_update(actor_gradients, target_logit, avg_logit, 0.1)
    """
    with torch.no_grad():
        # k: the direction the projection removes; per the original doc, exp(avg_logit) is the average policy
        kl_gradient = torch.exp(avg_logit)
        # TODO: only one element is handled; more elements may be supported in the future
        actor_gradient = actor_gradients[0]
        # (k . g - delta) / ||k||^2, clamped at 0 so gradients already inside the region are untouched
        scale = actor_gradient.mul(kl_gradient).sum(-1, keepdim=True) - trust_region_value
        scale = torch.div(scale, kl_gradient.mul(kl_gradient).sum(-1, keepdim=True)).clamp(min=0.0)
        return [actor_gradient - scale * kl_gradient]