gomoku / LightZero /lzero /mcts /ptree /test_sez_sample.py
zjowowen's picture
init space
079c32c
raw
history blame
599 Bytes
from torch.distributions import Normal, Independent
import torch
# Scratch check of how sampled actions and their log-probabilities are shaped
# for a diagonal Gaussian policy with a batch of 1 and a 2-dimensional action.
#
# sigma is pinned to a tiny constant (1e-7) so every sample lands essentially
# on mu; swap in `abs(torch.randn([1, 2]))` to try a random positive scale.
policy_logits = {'mu': torch.randn([1, 2]), 'sigma': torch.zeros([1, 2]) + 1e-7}
num_of_sampled_actions = 20
mu = policy_logits['mu']
sigma = policy_logits['sigma']

# Independent(..., 1) reinterprets the last batch dim of Normal as an event
# dim: batch_shape becomes (1,) and event_shape becomes (2,). With a bare
# Normal(mu, sigma), batch_shape would be (1, 2) and log_prob would stay
# per-dimension instead of being summed over the action dimension.
dist = Independent(Normal(mu, sigma), 1)
print(dist.batch_shape, dist.event_shape)

# sample_shape is documented as a torch.Size; pass one explicitly rather than
# relying on a tensor being coerced element by element.
# Result shape: sample_shape + batch_shape + event_shape = (20, 1, 2).
sampled_actions = dist.sample(torch.Size([num_of_sampled_actions]))

# log_prob sums over the event dim, giving shape (20, 1); append
# `.unsqueeze(-1)` if a trailing singleton dim, i.e. (20, 1, 1), is needed.
log_prob = dist.log_prob(sampled_actions)