import torch
import treetensor.torch as ttorch
from torch.distributions import Normal, Independent


class ArgmaxSampler:
    '''
    Overview:
        Deterministic sampler for discrete actions: picks the index of the largest entry \
        along the last dimension of the input.
    '''

    def __call__(self, logit: torch.Tensor) -> torch.Tensor:
        '''
        Overview:
            Return the index of the maximum value along the last dimension.
        Arguments:
            - logit (:obj:`torch.Tensor`): The input tensor; the last dimension is reduced.
        Returns:
            - action (:obj:`torch.Tensor`): The index of the maximum value, with the last \
                dimension of ``logit`` removed.
        '''
        return logit.argmax(dim=-1)


class MultinomialSampler:
    '''
    Overview:
        Stochastic sampler for discrete actions: draws an index from the categorical \
        distribution defined by the input logits.
    '''

    def __call__(self, logit: torch.Tensor) -> torch.Tensor:
        '''
        Overview:
            Sample an index from ``Categorical(logits=logit)``.
        Arguments:
            - logit (:obj:`torch.Tensor`): Unnormalized log-probabilities; categories are laid \
                out along the last dimension.
        Returns:
            - action (:obj:`torch.Tensor`): The sampled index.
        '''
        dist = torch.distributions.Categorical(logits=logit)
        return dist.sample()


class MuSampler:
    '''
    Overview:
        Deterministic sampler for continuous actions: returns the ``mu`` field of the input.
    '''

    def __call__(self, logit: ttorch.Tensor) -> torch.Tensor:
        '''
        Overview:
            Return the ``mu`` field of the input unchanged.
        Arguments:
            - logit (:obj:`ttorch.Tensor`): The input tree tensor; only its ``mu`` field is read.
        Returns:
            - action (:obj:`torch.Tensor`): ``logit.mu``.
        '''
        return logit.mu


class ReparameterizationSampler:
    '''
    Overview:
        Stochastic sampler for continuous actions: draws a reparameterized sample from the \
        Normal distribution described by the ``mu`` and ``sigma`` fields of the input.
    '''

    def __call__(self, logit: ttorch.Tensor) -> torch.Tensor:
        '''
        Overview:
            Draw a reparameterized sample from ``Normal(logit.mu, logit.sigma)``.
        Arguments:
            - logit (:obj:`ttorch.Tensor`): The input tree tensor with fields ``mu`` (used as the \
                Normal mean) and ``sigma`` (used as the Normal standard deviation).
        Returns:
            - action (:obj:`torch.Tensor`): The reparameterized sample (obtained via ``rsample``).
        '''
        dist = Normal(logit.mu, logit.sigma)
        # Reinterpret the last batch dimension as an event dimension (needs at least one dimension).
        dist = Independent(dist, 1)
        return dist.rsample()


class HybridStochasticSampler:
    '''
    Overview:
        Stochastic sampler for hybrid actions: samples the discrete action type from a \
        categorical distribution and the continuous action args from a Normal distribution \
        (reparameterized).
    '''

    def __call__(self, logit: ttorch.Tensor) -> ttorch.Tensor:
        '''
        Overview:
            Return the sampled action type and the reparameterized action args.
        Arguments:
            - logit (:obj:`ttorch.Tensor`): The input tree tensor with fields ``action_type`` \
                (logits of the discrete part) and ``action_args`` (holding ``mu`` and ``sigma`` \
                of the continuous part).
        Returns:
            - action (:obj:`ttorch.Tensor`): A tree tensor with keys ``action_type`` (sampled index) \
                and ``action_args`` (reparameterized sample).
        '''
        # Discrete part: sample an index from the categorical distribution over action types.
        type_dist = torch.distributions.Categorical(logits=logit.action_type)
        action_type = type_dist.sample()
        # Continuous part: reparameterized sample, last dimension treated as the event dimension.
        args_dist = Independent(Normal(logit.action_args.mu, logit.action_args.sigma), 1)
        action_args = args_dist.rsample()
        return ttorch.as_tensor({
            'action_type': action_type,
            'action_args': action_args,
        })


class HybridDeterminsticSampler:
    '''
    Overview:
        Deterministic sampler for hybrid actions: takes the argmax of the action type logits \
        and the ``mu`` of the action args. The class name keeps its historical spelling; \
        ``HybridDeterministicSampler`` is provided as a correctly spelled alias.
    '''

    def __call__(self, logit: ttorch.Tensor) -> ttorch.Tensor:
        '''
        Overview:
            Return the argmax action type and the mu action args.
        Arguments:
            - logit (:obj:`ttorch.Tensor`): The input tree tensor with fields ``action_type`` \
                (reduced by argmax over the last dimension) and ``action_args`` (only its ``mu`` \
                field is read).
        Returns:
            - action (:obj:`ttorch.Tensor`): A tree tensor with keys ``action_type`` (argmax index) \
                and ``action_args`` (``logit.action_args.mu``).
        '''
        action_type = logit.action_type.argmax(dim=-1)
        action_args = logit.action_args.mu
        return ttorch.as_tensor({
            'action_type': action_type,
            'action_args': action_args,
        })


# Correctly spelled alias. The misspelled class above stays the primary definition so that
# existing imports (and its ``__name__``) are unaffected.
HybridDeterministicSampler = HybridDeterminsticSampler