zjowowen's picture
init space
079c32c
raw
history blame
4 kB
from collections import namedtuple
import torch
import torch.nn.functional as F
from torch.distributions import Independent, Normal
# Input record for the A2C loss functions; the field order below is the order
# in which ``a2c_error`` and ``a2c_error_continuous`` unpack their argument.
a2c_data = namedtuple(
    'a2c_data',
    ['logit', 'action', 'value', 'adv', 'return_', 'weight'],
)
# Output record of the A2C loss functions: one scalar tensor per loss term.
a2c_loss = namedtuple(
    'a2c_loss',
    ['policy_loss', 'value_loss', 'entropy_loss'],
)
def a2c_error(data: namedtuple) -> namedtuple:
    """
    Overview:
        Compute the three A2C (Advantage Actor-Critic, arXiv:1602.01783) loss terms for a \
        discrete action space, using a categorical policy built from ``logit``.
    Arguments:
        - data (:obj:`namedtuple`): A2C input data with the fields listed in ``a2c_data``; it is \
            unpacked positionally as ``(logit, action, value, adv, return_, weight)``.
    Returns:
        - a2c_loss (:obj:`namedtuple`): The loss terms ``policy_loss``, ``value_loss`` and \
            ``entropy_loss``, each a differentiable 0-dim tensor.
    Shapes:
        - logit (:obj:`torch.FloatTensor`): :math:`(B, N)`, where B is batch size and N is action dim
        - action (:obj:`torch.LongTensor`): :math:`(B, )`
        - value (:obj:`torch.FloatTensor`): :math:`(B, )`
        - adv (:obj:`torch.FloatTensor`): :math:`(B, )`
        - return (:obj:`torch.FloatTensor`): :math:`(B, )`
        - weight (:obj:`torch.FloatTensor` or :obj:`None`): :math:`(B, )`
        - policy_loss (:obj:`torch.FloatTensor`): :math:`()`, 0-dim tensor
        - value_loss (:obj:`torch.FloatTensor`): :math:`()`
        - entropy_loss (:obj:`torch.FloatTensor`): :math:`()`
    Notes:
        - ``entropy_loss`` is the weighted mean entropy itself; it is neither negated nor scaled here.
        - ``adv`` is used as given; this function does not detach it.
    Examples:
        >>> data = a2c_data(
        >>>     logit=torch.randn(2, 3),
        >>>     action=torch.randint(0, 3, (2, )),
        >>>     value=torch.randn(2, ),
        >>>     adv=torch.randn(2, ),
        >>>     return_=torch.randn(2, ),
        >>>     weight=torch.ones(2, ),
        >>> )
        >>> loss = a2c_error(data)
    """
    logit, action, value, adv, return_, weight = data
    # A missing weight means every sample contributes equally.
    sample_weight = torch.ones_like(value) if weight is None else weight

    policy = torch.distributions.categorical.Categorical(logits=logit)
    action_logp = policy.log_prob(action)

    # Critic term: per-sample squared error, weighted, then averaged over the batch.
    per_sample_mse = F.mse_loss(return_, value, reduction='none')
    critic_term = (per_sample_mse * sample_weight).mean()

    # Entropy term: weighted mean entropy of the policy distribution.
    entropy_term = (policy.entropy() * sample_weight).mean()

    # Actor term: negative weighted mean of log-probability times advantage.
    weighted_objective = action_logp * adv * sample_weight
    actor_term = -weighted_objective.mean()

    return a2c_loss(policy_loss=actor_term, value_loss=critic_term, entropy_loss=entropy_term)
def a2c_error_continuous(data: namedtuple) -> namedtuple:
    """
    Overview:
        Compute the three A2C (Advantage Actor-Critic, arXiv:1602.01783) loss terms for a \
        continuous action space, using a diagonal Gaussian policy whose last dimension is \
        treated as the event dimension.
    Arguments:
        - data (:obj:`namedtuple`): A2C input data with the fields listed in ``a2c_data``; it is \
            unpacked positionally as ``(logit, action, value, adv, return_, weight)``. Here \
            ``logit`` is indexed with the keys ``'mu'`` and ``'sigma'``.
    Returns:
        - a2c_loss (:obj:`namedtuple`): The loss terms ``policy_loss``, ``value_loss`` and \
            ``entropy_loss``, each a differentiable 0-dim tensor.
    Shapes:
        - logit['mu'] (:obj:`torch.FloatTensor`): :math:`(B, N)`, where B is batch size and N is action dim
        - logit['sigma'] (:obj:`torch.FloatTensor`): :math:`(B, N)`
        - action (:obj:`torch.FloatTensor`): :math:`(B, N)`
        - value (:obj:`torch.FloatTensor`): :math:`(B, )`
        - adv (:obj:`torch.FloatTensor`): :math:`(B, )`
        - return (:obj:`torch.FloatTensor`): :math:`(B, )`
        - weight (:obj:`torch.FloatTensor` or :obj:`None`): :math:`(B, )`
        - policy_loss (:obj:`torch.FloatTensor`): :math:`()`, 0-dim tensor
        - value_loss (:obj:`torch.FloatTensor`): :math:`()`
        - entropy_loss (:obj:`torch.FloatTensor`): :math:`()`
    Notes:
        - ``entropy_loss`` is the weighted mean entropy itself; it is neither negated nor scaled here.
        - ``adv`` is used as given; this function does not detach it.
    Examples:
        >>> data = a2c_data(
        >>>     logit={'mu': torch.randn(2, 3), 'sigma': torch.sqrt(torch.randn(2, 3)**2)},
        >>>     action=torch.randn(2, 3),
        >>>     value=torch.randn(2, ),
        >>>     adv=torch.randn(2, ),
        >>>     return_=torch.randn(2, ),
        >>>     weight=torch.ones(2, ),
        >>> )
        >>> loss = a2c_error_continuous(data)
    """
    logit, action, value, adv, return_, weight = data
    # A missing weight means every sample contributes equally.
    sample_weight = torch.ones_like(value) if weight is None else weight

    # Independent(..., 1) folds the last dimension into the event, so log_prob
    # and entropy yield one value per batch element.
    gaussian = Normal(logit['mu'], logit['sigma'])
    policy = Independent(gaussian, 1)
    action_logp = policy.log_prob(action)

    # Critic term: per-sample squared error, weighted, then averaged over the batch.
    per_sample_mse = F.mse_loss(return_, value, reduction='none')
    critic_term = (per_sample_mse * sample_weight).mean()

    # Entropy term: weighted mean entropy of the policy distribution.
    entropy_term = (policy.entropy() * sample_weight).mean()

    # Actor term: negative weighted mean of log-probability times advantage.
    weighted_objective = action_logp * adv * sample_weight
    actor_term = -weighted_objective.mean()

    return a2c_loss(policy_loss=actor_term, value_loss=critic_term, entropy_loss=entropy_term)