# Source: gomoku / LightZero / lzero / policy / scaling_transform.py
# Last commit: 079c32c "init space" by zjowowen (file size 5.38 kB)
from typing import Union
import numpy as np
import torch
class DiscreteSupport(object):
    """
    Overview:
        Evenly spaced support ``min, min + delta, min + 2 * delta, ...`` that never exceeds ``max``.
        It is used to build categorical representations of scalars.
    Arguments:
        - min: first (lowest) support value.
        - max: upper bound of the support. It is included exactly when ``max - min`` is a multiple of ``delta``.
        - delta: spacing between consecutive atoms. Must be positive.
    Attributes:
        - range: numpy array holding the support values.
        - size / set_size: number of atoms in ``range``.
    """

    def __init__(self, min: int, max: int, delta: float = 1.) -> None:
        assert min < max
        assert delta > 0
        self.min = min
        self.max = max
        # ``np.arange(min, max + 1, delta)`` overshoots ``max`` whenever delta != 1
        # (e.g. -2..2 with step 0.5 would include 2.5). Building the support from an
        # integer atom count keeps it inside [min, max]. The tiny tolerance absorbs
        # floating-point error in (max - min) / delta. For delta == 1 the values and
        # dtype are identical to the arange form.
        num_atoms = int(np.floor((max - min) / delta + 1e-9)) + 1
        self.range = min + delta * np.arange(num_atoms)
        self.size = len(self.range)
        self.set_size = len(self.range)
        self.delta = delta
def scalar_transform(x: torch.Tensor, epsilon: float = 0.001, delta: float = 1.) -> torch.Tensor:
    """
    Overview:
        Compress a raw value into the scaled space with the h(.) function of
        https://arxiv.org/pdf/1805.11593.pdf:
            h(x) = sign(x) * (sqrt(|x| + 1) - 1) + epsilon * x.
        When ``delta`` != 1, the magnitude and the linear term use ``x / delta``.
    Arguments:
        - x: tensor of raw values. The transform is element-wise, so any shape works.
        - epsilon: weight of the linear term that keeps h invertible.
        - delta: unit in which ``x`` is expressed before the transform.
    Returns:
        - output: tensor of the same shape as ``x`` with the transformed values.
    Reference:
        - MuZero: Appendix F: Network Architecture
        - https://arxiv.org/pdf/1805.11593.pdf (Page-11) Appendix A : Proposition A.2
    """
    if delta == 1:
        # Common case: skip the division by delta (speed-up only, same result).
        magnitude = torch.abs(x)
        linear_term = epsilon * x
    else:
        magnitude = torch.abs(x / delta)
        linear_term = epsilon * x / delta
    return torch.sign(x) * (torch.sqrt(magnitude + 1) - 1) + linear_term
def inverse_scalar_transform(
        logits: torch.Tensor,
        support_size: int,
        epsilon: float = 0.001,
        categorical_distribution: bool = True
) -> torch.Tensor:
    """
    Overview:
        Map a scaled value, or its categorical representation, back to the original value with the
        h^(-1)(.) function of https://arxiv.org/pdf/1805.11593.pdf.
    Arguments:
        - logits: if ``categorical_distribution`` is True, logits over the integer support
          ``-support_size .. support_size`` along dim 1. Otherwise, the scaled values themselves.
        - support_size: half-width of the categorical support (unused when ``categorical_distribution`` is False).
        - epsilon: epsilon of h(.). Use the same value as in ``scalar_transform`` for an exact inverse.
        - categorical_distribution: whether ``logits`` is a categorical representation.
    Returns:
        - output: values in the original scale. In the categorical case dim 1 is reduced to size 1 (keepdim).
    Reference:
        - MuZero Appendix F: Network Architecture.
        - https://arxiv.org/pdf/1805.11593.pdf Appendix A: Proposition A.2
    """
    if not categorical_distribution:
        value = logits
    else:
        # Expected scaled value under softmax(logits) over the support atoms.
        support = DiscreteSupport(-support_size, support_size, delta=1)
        probs = torch.softmax(logits, dim=1)
        atoms = torch.from_numpy(support.range).unsqueeze(0).to(device=probs.device)
        value = (atoms * probs).sum(1, keepdim=True)

    # Closed-form h^(-1)(.). Tiny outputs are deliberately not snapped to 0, to save time.
    root = torch.sqrt(1 + 4 * epsilon * (torch.abs(value) + 1 + epsilon))
    return torch.sign(value) * (((root - 1) / (2 * epsilon)) ** 2 - 1)
class InverseScalarTransform:
    """
    Overview:
        Callable form of the h^(-1)(.) function of https://arxiv.org/pdf/1805.11593.pdf. It transforms the
        scaled value, or its categorical representation, back to the original value. The categorical support
        tensor is built once in ``__init__`` and reused on every call.
    Reference:
        - MuZero Appendix F: Network Architecture.
        - https://arxiv.org/pdf/1805.11593.pdf Appendix A: Proposition A.2
    """

    def __init__(
            self,
            support_size: int,
            device: Union[str, torch.device] = 'cpu',
            categorical_distribution: bool = True
    ) -> None:
        """
        Arguments:
            - support_size: the support is the integers ``-support_size .. support_size``.
            - device: device the support tensor is moved to. Presumably the logits passed to ``__call__``
              live on the same device.
            - categorical_distribution: True if ``__call__`` receives logits over the support,
              False if it receives scaled values directly.
        """
        scalar_support = DiscreteSupport(-support_size, support_size, delta=1)
        # Shape (1, num_atoms) so it broadcasts against a batch of logits along dim 1.
        self.value_support = torch.from_numpy(scalar_support.range).unsqueeze(0)
        self.value_support = self.value_support.to(device)
        self.categorical_distribution = categorical_distribution

    def __call__(self, logits: torch.Tensor, epsilon: float = 0.001) -> torch.Tensor:
        """
        Arguments:
            - logits: logits over the support atoms along dim 1 if ``categorical_distribution`` is True,
              otherwise the scaled values themselves.
            - epsilon: epsilon of h(.). Use the same value as in ``scalar_transform`` for an exact inverse.
        Returns:
            - output: values in the original scale. In the categorical case dim 1 is reduced to size 1 (keepdim).
        """
        if self.categorical_distribution:
            value_probs = torch.softmax(logits, dim=1)
            # Multiply out of place. Softmax saves its output for the backward pass, so an in-place
            # ``mul_`` here would modify that saved tensor and make backward raise a RuntimeError
            # whenever gradients flow through the returned value. The values are identical.
            value = (value_probs * self.value_support).sum(1, keepdim=True)
        else:
            value = logits
        tmp = ((torch.sqrt(1 + 4 * epsilon * (torch.abs(value) + 1 + epsilon)) - 1) / (2 * epsilon))
        # t * t is faster than t ** 2
        output = torch.sign(value) * (tmp * tmp - 1)
        return output
def visit_count_temperature(
        manual_temperature_decay: bool, fixed_temperature_value: float,
        threshold_training_steps_for_final_lr_temperature: int, trained_steps: int
) -> float:
    """
    Overview:
        Pick the visit-count temperature for the current training step.
        - With ``manual_temperature_decay`` off, ``fixed_temperature_value`` is always returned.
        - Otherwise a step schedule is used: 1.0 below 50% of the threshold, 0.5 below 75% of it,
          and 0.25 afterwards.
    Arguments:
        - manual_temperature_decay: whether to use the step schedule.
        - fixed_temperature_value: temperature returned when the schedule is disabled.
        - threshold_training_steps_for_final_lr_temperature: step count the schedule fractions refer to.
        - trained_steps: number of training steps performed so far.
    Returns:
        - temperature: the selected temperature.
    """
    if not manual_temperature_decay:
        return fixed_temperature_value
    # (fraction of the threshold, temperature used while trained_steps is still below that point)
    schedule = ((0.5, 1.0), (0.75, 0.5))
    for fraction, temperature in schedule:
        if trained_steps < fraction * threshold_training_steps_for_final_lr_temperature:
            return temperature
    return 0.25
def phi_transform(discrete_support: 'DiscreteSupport', x: torch.Tensor) -> torch.Tensor:
    """
    Overview:
        Apply the transformation ``phi`` to obtain an equivalent categorical representation of each scalar.
        Each scalar is clamped to ``[discrete_support.min, discrete_support.max]`` and written as a linear
        combination of its two adjacent support atoms, whose weights sum to 1.
    Arguments:
        - discrete_support: object exposing ``min``, ``max``, ``set_size`` and ``delta``. Atom ``k`` is taken
          to be ``min + k * delta``, and ``min`` is assumed to be a multiple of ``delta``.
        - x: tensor of scalars, any shape (e.g. ``(batch, unroll)``). It is NOT modified.
    Returns:
        - target: tensor of shape ``x.shape + (set_size,)`` on ``x.device``, with the dtype of the
          computed weights (the dtype of ``x`` for floating-point input).
    Reference:
        - MuZero paper Appendix F: Network Architecture.
    """
    support_min = discrete_support.min
    support_max = discrete_support.max
    set_size = discrete_support.set_size
    delta = discrete_support.delta

    # Clamp out of place so the caller's tensor is left untouched. Then work in units of delta,
    # so that floor/ceil land on neighbouring atoms even when delta != 1. For delta == 1
    # this is exactly x itself.
    scaled = x.clamp(support_min, support_max) / delta
    scaled_low = scaled.floor()
    scaled_high = scaled.ceil()
    p_high = scaled - scaled_low
    p_low = 1 - p_high

    # The support minimum sits at index 0, so shift by min / delta (an integer by assumption).
    offset = round(support_min / delta)
    low_idx = (scaled_low - offset).long().unsqueeze(-1)
    high_idx = (scaled_high - offset).long().unsqueeze(-1)

    target = torch.zeros(*x.shape, set_size, dtype=p_high.dtype, device=x.device)
    # Order matters. When a scalar lies exactly on an atom, low == high, and the second
    # scatter must overwrite p_high (= 0) with p_low (= 1).
    target.scatter_(-1, high_idx, p_high.unsqueeze(-1))
    target.scatter_(-1, low_idx, p_low.unsqueeze(-1))
    return target
def cross_entropy_loss(prediction: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """
    Overview:
        Cross entropy between softmax(``prediction``) and ``target``. Both are taken along dim 1,
        which is summed out; other dims are not reduced.
    Arguments:
        - prediction: logits with the categories along dim 1.
        - target: per-category weights (typically a probability distribution), assumed to be
          broadcast-compatible with ``prediction``.
    Returns:
        - loss: ``-sum_k target[:, k] * log_softmax(prediction)[:, k]`` with dim 1 removed.
    """
    log_probs = torch.log_softmax(prediction, dim=1)
    weighted = log_probs * target
    return weighted.sum(1).neg()