|
from __future__ import absolute_import |
|
from __future__ import division |
|
from __future__ import print_function |
|
from typing import Dict, Optional, Tuple, Union
|
|
|
import torch |
|
import numpy as np |
|
import torch.nn.functional as F |
|
|
|
|
|
class Pd(object): |
|
""" |
|
Overview: |
|
Abstract class for parameterizable probability distributions and sampling functions. |
|
Interfaces: |
|
``neglogp``, ``entropy``, ``noise_mode``, ``mode``, ``sample`` |
|
|
|
.. tip:: |
|
|
|
        In derived classes, ``logits`` should be stored as an instance attribute.
|
""" |
|
|
|
def neglogp(self, x: torch.Tensor) -> torch.Tensor: |
|
""" |
|
Overview: |
|
Calculate cross_entropy between input x and logits |
|
Arguments: |
|
- x (:obj:`torch.Tensor`): the input tensor |
|
Return: |
|
- cross_entropy (:obj:`torch.Tensor`): the returned cross_entropy loss |
|
""" |
|
raise NotImplementedError |
|
|
|
def entropy(self) -> torch.Tensor: |
|
""" |
|
Overview: |
|
Calculate the softmax entropy of logits |
|
|
Returns: |
|
- entropy (:obj:`torch.Tensor`): the calculated entropy |
|
""" |
|
raise NotImplementedError |
|
|
|
def noise_mode(self): |
|
""" |
|
Overview: |
|
            Add noise to the logits before taking the argmax. This method is designed for stochastic action selection.
|
""" |
|
raise NotImplementedError |
|
|
|
def mode(self): |
|
""" |
|
Overview: |
|
            Return the argmax of the logits. This method is designed for deterministic action selection.
|
""" |
|
raise NotImplementedError |
|
|
|
def sample(self): |
|
""" |
|
Overview: |
|
            Sample from the softmax distribution over the logits. This method is designed for multinomial sampling.
|
""" |
|
raise NotImplementedError |
|
|
|
|
|
class CategoricalPd(Pd): |
|
""" |
|
Overview: |
|
        Categorical probability distribution sampler
|
Interfaces: |
|
``__init__``, ``neglogp``, ``entropy``, ``noise_mode``, ``mode``, ``sample`` |
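
    Examples:
        >>> # A minimal illustrative sketch; shapes assumed: batch of 4, 6 classes.
        >>> pd = CategoricalPd(torch.randn(4, 6))
        >>> greedy_action = pd.mode()      # deterministic, shape (4, )
        >>> sampled_action = pd.sample()   # stochastic, shape (4, )
        >>> loss = pd.neglogp(sampled_action)  # scalar cross-entropy loss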
|
""" |
|
|
|
    def __init__(self, logits: Optional[torch.Tensor] = None) -> None:
|
""" |
|
Overview: |
|
            Initialize the distribution with logits.
        Arguments:
            - logits (:obj:`torch.Tensor`): logits to sample from
|
""" |
|
self.update_logits(logits) |
|
|
|
def update_logits(self, logits: torch.Tensor) -> None: |
|
""" |
|
Overview: |
|
            Update the logits.
|
Arguments: |
|
- logits (:obj:`torch.Tensor`): logits to update |
|
""" |
|
self.logits = logits |
|
|
|
    def neglogp(self, x: torch.Tensor, reduction: str = 'mean') -> torch.Tensor:
|
""" |
|
Overview: |
|
Calculate cross_entropy between input x and logits |
|
Arguments: |
|
- x (:obj:`torch.Tensor`): the input tensor |
|
            - reduction (:obj:`str`): support ['none', 'mean'], defaults to 'mean'
|
Return: |
|
- cross_entropy (:obj:`torch.Tensor`): the returned cross_entropy loss |
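
        Examples:
            >>> # Illustrative shapes: logits (B, N) and integer class targets (B, ).
            >>> pd = CategoricalPd(torch.randn(4, 6))
            >>> target = torch.randint(0, 6, (4, ))
            >>> loss = pd.neglogp(target)  # scalar with the default reduction='mean'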
|
""" |
|
return F.cross_entropy(self.logits, x, reduction=reduction) |
|
|
|
def entropy(self, reduction: str = 'mean') -> torch.Tensor: |
|
""" |
|
Overview: |
|
Calculate the softmax entropy of logits |
|
Arguments: |
|
            - reduction (:obj:`str`): support [None, 'mean'], defaults to 'mean'
|
Returns: |
|
- entropy (:obj:`torch.Tensor`): the calculated entropy |
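
        Examples:
            >>> # Sanity check: uniform logits give the maximum entropy log(N).
            >>> pd = CategoricalPd(torch.zeros(4, 6))
            >>> ent = pd.entropy()  # approximately log(6) = 1.7918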
|
""" |
|
        # Numerically stable softmax entropy: subtract the per-row max so that
        # exp() cannot overflow.
        a = self.logits - self.logits.max(dim=-1, keepdim=True)[0]
        ea = torch.exp(a)
        z = ea.sum(dim=-1, keepdim=True)
        p = ea / z
        # H(p) = -sum_i p_i * log(p_i), where log(p_i) = a_i - log(z).
        entropy = (p * (torch.log(z) - a)).sum(dim=-1)
|
assert (reduction in [None, 'mean']) |
|
if reduction is None: |
|
return entropy |
|
elif reduction == 'mean': |
|
return entropy.mean() |
|
|
|
    def noise_mode(self, viz: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, np.ndarray]]]:
|
""" |
|
Overview: |
|
            Add Gumbel noise to the logits and return the argmax (the Gumbel-max trick).
        Arguments:
            - viz (:obj:`bool`): Whether to also return the numpy form of the logits, noise and noised logits; \
                short for ``visualize``. (Tensors cannot be visualized directly in TensorBoard or text logs.)
|
Returns: |
|
            - result (:obj:`torch.Tensor`): the argmax of the noised logits
|
- viz_feature (:obj:`Dict[str, np.ndarray]`): ndarray type data for visualization. |
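
        Examples:
            >>> # A minimal sketch; shapes are illustrative.
            >>> pd = CategoricalPd(torch.randn(4, 6))
            >>> action = pd.noise_mode()  # shape (4, )
            >>> action, viz = pd.noise_mode(viz=True)
            >>> sorted(viz.keys())
            ['logits', 'noise', 'noise_logits']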
|
""" |
|
        # Gumbel-max trick: u = -log(-log(U)) with U ~ Uniform(0, 1) is
        # Gumbel(0, 1) noise; argmax(logits + u) is an exact sample from the
        # categorical distribution defined by the logits.
        u = torch.rand_like(self.logits)
        u = -torch.log(-torch.log(u))
        noise_logits = self.logits + u
|
result = noise_logits.argmax(dim=-1) |
|
if viz: |
|
viz_feature = {} |
|
viz_feature['logits'] = self.logits.data.cpu().numpy() |
|
viz_feature['noise'] = u.data.cpu().numpy() |
|
viz_feature['noise_logits'] = noise_logits.data.cpu().numpy() |
|
return result, viz_feature |
|
else: |
|
return result |
|
|
|
    def mode(self, viz: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, np.ndarray]]]:
|
""" |
|
Overview: |
|
            Return the argmax of the logits.
        Arguments:
            - viz (:obj:`bool`): Whether to also return the numpy form of the logits; \
                short for ``visualize``. (Tensors cannot be visualized directly in TensorBoard or text logs.)
|
Returns: |
|
            - result (:obj:`torch.Tensor`): the argmax action indices
|
- viz_feature (:obj:`Dict[str, np.ndarray]`): ndarray type data for visualization. |
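
        Examples:
            >>> # Greedy action selection over illustrative (4, 6) logits.
            >>> pd = CategoricalPd(torch.randn(4, 6))
            >>> action = pd.mode()  # shape (4, )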
|
""" |
|
result = self.logits.argmax(dim=-1) |
|
if viz: |
|
viz_feature = {} |
|
viz_feature['logits'] = self.logits.data.cpu().numpy() |
|
return result, viz_feature |
|
else: |
|
return result |
|
|
|
    def sample(self, viz: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, np.ndarray]]]:
|
""" |
|
Overview: |
|
            Sample from the softmax distribution over the logits.
        Arguments:
            - viz (:obj:`bool`): Whether to also return the numpy form of the logits; \
                short for ``visualize``. (Tensors cannot be visualized directly in TensorBoard or text logs.)
|
Returns: |
|
            - result (:obj:`torch.Tensor`): the sampled action indices
|
- viz_feature (:obj:`Dict[str, np.ndarray]`): ndarray type data for visualization. |
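
        Examples:
            >>> # One multinomial draw per batch row over illustrative (4, 6) logits.
            >>> pd = CategoricalPd(torch.randn(4, 6))
            >>> action = pd.sample()  # shape (4, )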
|
""" |
|
        # Softmax over the last (class) dimension, then one multinomial draw per
        # row; ``torch.multinomial`` expects a 2D (batch, num_classes) probability tensor.
        p = torch.softmax(self.logits, dim=-1)
        result = torch.multinomial(p, 1).squeeze(1)
|
if viz: |
|
viz_feature = {} |
|
viz_feature['logits'] = self.logits.data.cpu().numpy() |
|
return result, viz_feature |
|
else: |
|
return result |
|
|
|
|
|
class CategoricalPdPytorch(torch.distributions.Categorical): |
|
""" |
|
Overview: |
|
        Wrapper around ``torch.distributions.Categorical``
|
|
|
Interfaces: |
|
``__init__``, ``update_logits``, ``update_probs``, ``sample``, ``neglogp``, ``mode``, ``entropy`` |
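
    Examples:
        >>> # A minimal illustrative sketch: rows of ``probs`` must sum to 1.
        >>> probs = torch.softmax(torch.randn(4, 6), dim=-1)
        >>> pd = CategoricalPdPytorch(probs)
        >>> action = pd.sample()
        >>> loss = pd.neglogp(action)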
|
""" |
|
|
|
    def __init__(self, probs: Optional[torch.Tensor] = None) -> None:
|
""" |
|
Overview: |
|
Initialize the CategoricalPdPytorch object. |
|
Arguments: |
|
- probs (:obj:`torch.Tensor`): The tensor of probabilities. |
|
""" |
|
if probs is not None: |
|
self.update_probs(probs) |
|
|
|
def update_logits(self, logits: torch.Tensor) -> None: |
|
""" |
|
Overview: |
|
            Update the logits (re-initializes the underlying distribution).
|
Arguments: |
|
- logits (:obj:`torch.Tensor`): logits to update |
|
""" |
|
super().__init__(logits=logits) |
|
|
|
def update_probs(self, probs: torch.Tensor) -> None: |
|
""" |
|
Overview: |
|
            Update the probs (re-initializes the underlying distribution).
|
Arguments: |
|
- probs (:obj:`torch.Tensor`): probs to update |
|
""" |
|
super().__init__(probs=probs) |
|
|
|
def sample(self) -> torch.Tensor: |
|
""" |
|
Overview: |
|
            Sample from the categorical distribution.
        Return:
            - result (:obj:`torch.Tensor`): the sampled action indices
|
""" |
|
return super().sample() |
|
|
|
def neglogp(self, actions: torch.Tensor, reduction: str = 'mean') -> torch.Tensor: |
|
""" |
|
Overview: |
|
            Calculate the cross entropy (negative log-probability) between the input actions and the distribution.
        Arguments:
            - actions (:obj:`torch.Tensor`): the input action tensor
            - reduction (:obj:`str`): support ['none', 'mean'], defaults to 'mean'
|
Return: |
|
- cross_entropy (:obj:`torch.Tensor`): the returned cross_entropy loss |
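
        Examples:
            >>> # Sanity check: a uniform distribution gives neglogp log(N).
            >>> pd = CategoricalPdPytorch(torch.full((4, 6), 1 / 6))
            >>> actions = torch.randint(0, 6, (4, ))
            >>> loss = pd.neglogp(actions)  # approximately log(6) = 1.7918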
|
""" |
|
        # log_prob returns log(p(a)); negate it to obtain the negative
        # log-probability, i.e. the cross entropy.
        neglogp = -super().log_prob(actions)
|
assert (reduction in ['none', 'mean']) |
|
if reduction == 'none': |
|
return neglogp |
|
elif reduction == 'mean': |
|
return neglogp.mean(dim=0) |
|
|
|
def mode(self) -> torch.Tensor: |
|
""" |
|
Overview: |
|
            Return the argmax of the probs (equivalent to the logits argmax).
        Return:
            - result (:obj:`torch.Tensor`): the argmax result
|
""" |
|
return self.probs.argmax(dim=-1) |
|
|
|
    def entropy(self, reduction: Optional[str] = None) -> torch.Tensor:
|
""" |
|
Overview: |
|
Calculate the softmax entropy of logits |
|
Arguments: |
|
            - reduction (:obj:`str`): support [None, 'mean'], defaults to None
|
Returns: |
|
- entropy (:obj:`torch.Tensor`): the calculated entropy |
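
        Examples:
            >>> # Sanity check: a uniform distribution has entropy log(N).
            >>> pd = CategoricalPdPytorch(torch.full((4, 6), 1 / 6))
            >>> ent = pd.entropy('mean')  # approximately log(6) = 1.7918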
|
""" |
|
entropy = super().entropy() |
|
assert (reduction in [None, 'mean']) |
|
if reduction is None: |
|
return entropy |
|
elif reduction == 'mean': |
|
return entropy.mean() |
|
|