from torch.distributions import Normal, Independent | |
import torch | |
# Scratch check of torch.distributions shape semantics for a diagonal-Gaussian
# policy head: a (mu, sigma) pair wrapped in Independent(Normal, 1), sampled
# several times, then scored with log_prob.
#
# sigma is a tiny constant to probe a near-deterministic policy.
# NOTE: 1e-7 is below float32 resolution for |mu| of order 1 (spacing there is
# ~6e-8 to ~1.2e-7), so many sampled values round to exactly mu and log_prob
# sits near its peak of -log(sigma) - 0.5*log(2*pi) per dimension. For a
# non-degenerate check use a larger, strictly positive sigma such as
# torch.rand([1, 2]) + 0.1. abs(torch.randn(...)) can yield 0, which Normal's
# argument validation rejects as a scale.
policy_logits = {'mu': torch.randn([1, 2]), 'sigma': torch.full((1, 2), 1e-7)}
num_of_sampled_actions = 20
mu, sigma = policy_logits['mu'], policy_logits['sigma']

# Independent(..., 1) reinterprets the last batch dim as an event dim, which
# gives batch_shape == (1,) and event_shape == (2,). A plain Normal(mu, sigma)
# would instead have batch_shape (1, 2) and an empty event_shape, and its
# log_prob would be per-dimension rather than summed over the 2 action dims.
dist = Independent(Normal(mu, sigma), 1)
print(dist.batch_shape, dist.event_shape)

# sample_shape is documented as a torch.Size (or a tuple of ints), not a tensor.
# Result shape: sample_shape + batch_shape + event_shape == (20, 1, 2).
sampled_actions = dist.sample(torch.Size([num_of_sampled_actions]))

# log_prob sums over the event dim, giving shape (20, 1). Append .unsqueeze(-1)
# when a trailing singleton dim is needed, giving (20, 1, 1).
log_prob = dist.log_prob(sampled_actions)
print(sampled_actions.shape, log_prob.shape)