from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from typing import Optional, Tuple, Dict, Union

import torch
import numpy as np
import torch.nn.functional as F


class Pd(object):
    """
    Overview:
        Abstract class for parameterizable probability distributions and sampling functions.
    Interfaces:
        ``neglogp``, ``entropy``, ``noise_mode``, ``mode``, ``sample``

    .. tip::
        In derived classes, ``logits`` should be an attribute member stored in the class.
    """

    def neglogp(self, x: torch.Tensor) -> torch.Tensor:
        """
        Overview:
            Calculate cross entropy between input x and logits
        Arguments:
            - x (:obj:`torch.Tensor`): the input tensor
        Return:
            - cross_entropy (:obj:`torch.Tensor`): the returned cross entropy loss
        """
        raise NotImplementedError

    def entropy(self) -> torch.Tensor:
        """
        Overview:
            Calculate the softmax entropy of logits
        Returns:
            - entropy (:obj:`torch.Tensor`): the calculated entropy
        """
        raise NotImplementedError

    def noise_mode(self):
        """
        Overview:
            Add noise to logits. This method is designed for randomness.
        """
        raise NotImplementedError

    def mode(self):
        """
        Overview:
            Return the logits argmax result. This method is designed for deterministic behaviour.
        """
        raise NotImplementedError

    def sample(self):
        """
        Overview:
            Sample from the logits' distribution by using softmax. This method is designed for multinomial sampling.
        """
        raise NotImplementedError


class CategoricalPd(Pd):
    """
    Overview:
        Categorical probability distribution sampler
    Interfaces:
        ``__init__``, ``neglogp``, ``entropy``, ``noise_mode``, ``mode``, ``sample``
    """

    def __init__(self, logits: Optional[torch.Tensor] = None) -> None:
        """
        Overview:
            Init the Pd with logits
        Arguments:
            - logits (:obj:`torch.Tensor`): logits to sample from
        """
        self.update_logits(logits)

    def update_logits(self, logits: torch.Tensor) -> None:
        """
        Overview:
            Update logits
        Arguments:
            - logits (:obj:`torch.Tensor`): logits to update
        """
        self.logits = logits

    def neglogp(self, x: torch.Tensor, reduction: str = 'mean') -> torch.Tensor:
        """
        Overview:
            Calculate cross entropy between input x and logits
        Arguments:
            - x (:obj:`torch.Tensor`): the input tensor
            - reduction (:obj:`str`): support ['none', 'mean'], default set to 'mean'
        Return:
            - cross_entropy (:obj:`torch.Tensor`): the returned cross entropy loss
        """
        return F.cross_entropy(self.logits, x, reduction=reduction)

    def entropy(self, reduction: str = 'mean') -> torch.Tensor:
        """
        Overview:
            Calculate the softmax entropy of logits
        Arguments:
            - reduction (:obj:`str`): support [None, 'mean'], default set to 'mean'
        Returns:
            - entropy (:obj:`torch.Tensor`): the calculated entropy
        """
        # Subtract the row-wise max for numerical stability, then compute the
        # softmax entropy H(p) = sum_i p_i * (log(z) - a_i), where p = exp(a) / z.
        a = self.logits - self.logits.max(dim=-1, keepdim=True)[0]
        ea = torch.exp(a)
        z = ea.sum(dim=-1, keepdim=True)
        p = ea / z
        entropy = (p * (torch.log(z) - a)).sum(dim=-1)
        assert reduction in [None, 'mean']
        if reduction is None:
            return entropy
        elif reduction == 'mean':
            return entropy.mean()

    def noise_mode(self, viz: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, np.ndarray]]]:
        """
        Overview:
            Add noise to logits and return the argmax of the noised logits
        Arguments:
            - viz (:obj:`bool`): Whether to return the numpy form of logits, noise and noise_logits; \
                short for ``visualize`` (because tensor type cannot be visualized in tb or text log)
        Returns:
            - result (:obj:`torch.Tensor`): the argmax of the noised logits
            - viz_feature (:obj:`Dict[str, np.ndarray]`): ndarray type data for visualization.
""" u = torch.rand_like(self.logits) u = -torch.log(-torch.log(u)) noise_logits = self.logits + u result = noise_logits.argmax(dim=-1) if viz: viz_feature = {} viz_feature['logits'] = self.logits.data.cpu().numpy() viz_feature['noise'] = u.data.cpu().numpy() viz_feature['noise_logits'] = noise_logits.data.cpu().numpy() return result, viz_feature else: return result def mode(self, viz: bool = False) -> Tuple[torch.Tensor, Dict[str, np.ndarray]]: """ Overview: return logits argmax result Arguments: - viz (:obj:`bool`): Whether to return numpy from of logits, noise and noise_logits; Short for ``visualize`` . (Because tensor type cannot visualize in tb or text log) Returns: - result (:obj:`torch.Tensor`): the logits argmax result - viz_feature (:obj:`Dict[str, np.ndarray]`): ndarray type data for visualization. """ result = self.logits.argmax(dim=-1) if viz: viz_feature = {} viz_feature['logits'] = self.logits.data.cpu().numpy() return result, viz_feature else: return result def sample(self, viz: bool = False) -> Tuple[torch.Tensor, Dict[str, np.ndarray]]: """ Overview: Sample from logits's distribution by using softmax Arguments: - viz (:obj:`bool`): Whether to return numpy from of logits, noise and noise_logits; \ Short for ``visualize`` . (Because tensor type cannot visualize in tb or text log) Returns: - result (:obj:`torch.Tensor`): the logits sampled result - viz_feature (:obj:`Dict[str, np.ndarray]`): ndarray type data for visualization. """ p = torch.softmax(self.logits, dim=1) result = torch.multinomial(p, 1).squeeze(1) if viz: viz_feature = {} viz_feature['logits'] = self.logits.data.cpu().numpy() return result, viz_feature else: return result class CategoricalPdPytorch(torch.distributions.Categorical): """ Overview: Wrapped ``torch.distributions.Categorical`` Interfaces: ``__init__``, ``update_logits``, ``update_probs``, ``sample``, ``neglogp``, ``mode``, ``entropy`` """ def __init__(self, probs: torch.Tensor = None) -> None: """ Overview: Initialize the CategoricalPdPytorch object. Arguments: - probs (:obj:`torch.Tensor`): The tensor of probabilities. 
""" if probs is not None: self.update_probs(probs) def update_logits(self, logits: torch.Tensor) -> None: """ Overview: Updata logits Arguments: - logits (:obj:`torch.Tensor`): logits to update """ super().__init__(logits=logits) def update_probs(self, probs: torch.Tensor) -> None: """ Overview: Updata probs Arguments: - probs (:obj:`torch.Tensor`): probs to update """ super().__init__(probs=probs) def sample(self) -> torch.Tensor: """ Overview: Sample from logits's distribution by using softmax Return: - result (:obj:`torch.Tensor`): the logits sampled result """ return super().sample() def neglogp(self, actions: torch.Tensor, reduction: str = 'mean') -> torch.Tensor: """ Overview: Calculate cross_entropy between input x and logits Arguments: - actions (:obj:`torch.Tensor`): the input action tensor - reduction (:obj:`str`): support [None, 'mean'], default set to mean Return: - cross_entropy (:obj:`torch.Tensor`): the returned cross_entropy loss """ neglogp = super().log_prob(actions) assert (reduction in ['none', 'mean']) if reduction == 'none': return neglogp elif reduction == 'mean': return neglogp.mean(dim=0) def mode(self) -> torch.Tensor: """ Overview: Return logits argmax result Return: - result(:obj:`torch.Tensor`): the logits argmax result """ return self.probs.argmax(dim=-1) def entropy(self, reduction: str = None) -> torch.Tensor: """ Overview: Calculate the softmax entropy of logits Arguments: - reduction (:obj:`str`): support [None, 'mean'], default set to mean Returns: - entropy (:obj:`torch.Tensor`): the calculated entropy """ entropy = super().entropy() assert (reduction in [None, 'mean']) if reduction is None: return entropy elif reduction == 'mean': return entropy.mean()