from typing import Union
import numpy as np
import torch


class DiscreteSupport(object):
    """
    Overview:
        The discrete support atoms (a uniform grid over [min, max] with step
        ``delta``) used for the categorical representation of scalar values.
    """

    def __init__(self, min: int, max: int, delta: float = 1.) -> None:
        assert min < max
        self.min = min
        self.max = max
        self.range = np.arange(min, max + 1, delta)
        self.size = len(self.range)
        # ``set_size`` is an alias of ``size``
        self.set_size = len(self.range)
        self.delta = delta

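
def _example_discrete_support() -> None:
    """
    Illustrative usage sketch (this helper is hypothetical, not part of the
    original module): a support of 601 integer atoms covering [-300, 300],
    as in MuZero-style categorical value/reward heads.
    """
    support = DiscreteSupport(-300, 300, delta=1.)
    assert support.size == 601
    assert support.range[0] == -300 and support.range[-1] == 300
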

def scalar_transform(x: torch.Tensor, epsilon: float = 0.001, delta: float = 1.) -> torch.Tensor:
    """
    Overview:
        Transform the original value to the scaled value, i.e. the h(.) function
        in the paper https://arxiv.org/pdf/1805.11593.pdf.
    Reference:
        - MuZero: Appendix F: Network Architecture
        - https://arxiv.org/pdf/1805.11593.pdf (Page 11) Appendix A: Proposition A.2
    """
    # h(.) function: sign(x) * (sqrt(|x| + 1) - 1) + epsilon * x.
    # The epsilon term keeps the transform strictly monotonic and thus invertible.
    if delta == 1:  # fast path for the common case delta == 1
        output = torch.sign(x) * (torch.sqrt(torch.abs(x) + 1) - 1) + epsilon * x
    else:
        # general case: rescale x by delta before applying the transform
        output = torch.sign(x) * (torch.sqrt(torch.abs(x / delta) + 1) - 1) + epsilon * x / delta
    return output

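
def _example_scalar_transform() -> None:
    """
    Illustrative usage sketch (this helper is hypothetical, not part of the
    original module): h(.) compresses the value scale, e.g.
    h(100) = sqrt(101) - 1 + 0.1 ~ 9.15 with the default epsilon.
    """
    x = torch.tensor([[0., 1., 100.]])
    hx = scalar_transform(x)
    # h(0) = 0; h(1) = sqrt(2) - 1 + 0.001 ~ 0.4152; h(100) ~ 9.1499
    expected = torch.tensor([[0., 0.4152, 9.1499]])
    assert torch.allclose(hx, expected, atol=1e-3)
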

def inverse_scalar_transform(
        logits: torch.Tensor,
        support_size: int,
        epsilon: float = 0.001,
        categorical_distribution: bool = True
) -> torch.Tensor:
    """
    Overview:
        Transform the scaled value or its categorical representation back to
        the original value, i.e. the h^(-1)(.) function in the paper
        https://arxiv.org/pdf/1805.11593.pdf.
    Reference:
        - MuZero: Appendix F: Network Architecture
        - https://arxiv.org/pdf/1805.11593.pdf Appendix A: Proposition A.2
    """
    if categorical_distribution:
        # take the expectation of the support atoms under the predicted categorical distribution
        scalar_support = DiscreteSupport(-support_size, support_size, delta=1)
        value_probs = torch.softmax(logits, dim=1)

        value_support = torch.from_numpy(scalar_support.range).unsqueeze(0)

        value_support = value_support.to(device=value_probs.device)
        value = (value_support * value_probs).sum(1, keepdim=True)
    else:
        value = logits

    # h^(-1)(.) function
    output = torch.sign(value) * (
        ((torch.sqrt(1 + 4 * epsilon * (torch.abs(value) + 1 + epsilon)) - 1) / (2 * epsilon)) ** 2 - 1
    )

    # TODO(pu): the following rounding of near-zero outputs is disabled to save time
    # output[torch.abs(output) < epsilon] = 0.

    return output

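
def _example_inverse_scalar_transform() -> None:
    """
    Illustrative usage sketch (this helper is hypothetical, not part of the
    original module): h^(-1)(h(x)) recovers x up to floating-point error when
    the scaled value is passed through directly
    (categorical_distribution=False).
    """
    x = torch.tensor([[-5., 0., 42.]])
    hx = scalar_transform(x)
    x_rec = inverse_scalar_transform(hx, support_size=300, categorical_distribution=False)
    assert torch.allclose(x_rec, x, atol=1e-2)
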

class InverseScalarTransform:
    """
    Overview:
        Transform the scaled value or its categorical representation back to
        the original value, i.e. the h^(-1)(.) function in the paper
        https://arxiv.org/pdf/1805.11593.pdf. Unlike the functional
        ``inverse_scalar_transform`` above, the support tensor is built once
        in ``__init__`` and reused across calls.
    Reference:
        - MuZero: Appendix F: Network Architecture
        - https://arxiv.org/pdf/1805.11593.pdf Appendix A: Proposition A.2
    """

    def __init__(
            self,
            support_size: int,
            device: Union[str, torch.device] = 'cpu',
            categorical_distribution: bool = True
    ) -> None:
        scalar_support = DiscreteSupport(-support_size, support_size, delta=1)
        self.value_support = torch.from_numpy(scalar_support.range).unsqueeze(0)
        self.value_support = self.value_support.to(device)
        self.categorical_distribution = categorical_distribution

    def __call__(self, logits: torch.Tensor, epsilon: float = 0.001) -> torch.Tensor:
        if self.categorical_distribution:
            value_probs = torch.softmax(logits, dim=1)
            # note: ``mul_`` modifies the softmax output in place, which is only
            # safe when no gradient through the softmax is required (inference)
            value = value_probs.mul_(self.value_support).sum(1, keepdim=True)
        else:
            value = logits
        tmp = ((torch.sqrt(1 + 4 * epsilon * (torch.abs(value) + 1 + epsilon)) - 1) / (2 * epsilon))
        # t * t is faster than t ** 2
        output = torch.sign(value) * (tmp * tmp - 1)

        return output

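
def _example_inverse_scalar_transform_class() -> None:
    """
    Illustrative usage sketch (this helper is hypothetical, not part of the
    original module): uniform logits over a symmetric support have
    expectation 0, and h^(-1)(0) = 0.
    """
    inv_transform = InverseScalarTransform(support_size=300, device='cpu')
    logits = torch.zeros(2, 601)  # uniform distribution over 601 support atoms
    value = inv_transform(logits)
    assert torch.allclose(value, torch.zeros(2, 1), atol=1e-3)
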

def visit_count_temperature(
        manual_temperature_decay: bool, fixed_temperature_value: float,
        threshold_training_steps_for_final_lr_temperature: int, trained_steps: int
) -> float:
    """
    Overview:
        The temperature used when sampling actions from MCTS visit counts.
        With manual decay enabled, it steps down 1.0 -> 0.5 -> 0.25 at 50%
        and 75% of the given threshold; otherwise the fixed value is used.
    """
    if manual_temperature_decay:
        if trained_steps < 0.5 * threshold_training_steps_for_final_lr_temperature:
            return 1.0
        elif trained_steps < 0.75 * threshold_training_steps_for_final_lr_temperature:
            return 0.5
        else:
            return 0.25
    else:
        return fixed_temperature_value

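
def _example_visit_count_temperature() -> None:
    """
    Illustrative usage sketch (this helper is hypothetical, not part of the
    original module): with manual decay over a 100k-step threshold, the
    temperature steps down 1.0 -> 0.5 -> 0.25.
    """
    kwargs = dict(
        manual_temperature_decay=True,
        fixed_temperature_value=0.25,
        threshold_training_steps_for_final_lr_temperature=100000,
    )
    assert visit_count_temperature(trained_steps=10000, **kwargs) == 1.0
    assert visit_count_temperature(trained_steps=60000, **kwargs) == 0.5
    assert visit_count_temperature(trained_steps=90000, **kwargs) == 0.25
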

def phi_transform(discrete_support: DiscreteSupport, x: torch.Tensor) -> torch.Tensor:
    """
    Overview:
        Apply the transformation ``phi`` to the scalar to obtain an equivalent
        categorical representation. After this transformation, each scalar is
        represented as the linear combination of its two adjacent supports.
    Reference:
        - MuZero paper Appendix F: Network Architecture.
    """
    min = discrete_support.min
    max = discrete_support.max
    set_size = discrete_support.set_size
    delta = discrete_support.delta

    # note: ``clamp_`` modifies ``x`` in place
    x.clamp_(min, max)
    x_low = x.floor()
    x_high = x.ceil()
    p_high = x - x_low
    p_low = 1 - p_high

    target = torch.zeros(x.shape[0], x.shape[1], set_size, device=x.device)
    # map the two adjacent support values to their indices in the support grid;
    # the ``floor``/``ceil`` above assume delta == 1
    x_high_idx, x_low_idx = (x_high - min) / delta, (x_low - min) / delta
    target.scatter_(2, x_high_idx.long().unsqueeze(-1), p_high.unsqueeze(-1))
    target.scatter_(2, x_low_idx.long().unsqueeze(-1), p_low.unsqueeze(-1))

    return target

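
def _example_phi_transform() -> None:
    """
    Illustrative usage sketch (this helper is hypothetical, not part of the
    original module): with unit supports, the scalar 3.7 is split between its
    two neighbours, placing weight 0.7 on atom 4 and weight 0.3 on atom 3.
    """
    support = DiscreteSupport(-5, 5, delta=1.)
    x = torch.tensor([[3.7]])
    target = phi_transform(support, x)
    assert target.shape == (1, 1, support.set_size)
    assert torch.allclose(target[0, 0, 8], torch.tensor(0.3), atol=1e-5)  # atom 3 at index 8
    assert torch.allclose(target[0, 0, 9], torch.tensor(0.7), atol=1e-5)  # atom 4 at index 9
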

def cross_entropy_loss(prediction: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    # soft-label cross entropy between predicted logits and a categorical target
    return -(torch.log_softmax(prediction, dim=1) * target).sum(1)
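

def _example_cross_entropy_loss() -> None:
    """
    Illustrative usage sketch (this helper is hypothetical, not part of the
    original module): a categorical target built by ``phi_transform`` is
    trained against value logits with this soft-label cross-entropy.
    """
    support = DiscreteSupport(-5, 5, delta=1.)
    scaled = scalar_transform(torch.tensor([[3.7]]))
    target = phi_transform(support, scaled).squeeze(1)  # shape (1, 11)
    logits = torch.randn(1, support.set_size, requires_grad=True)
    loss = cross_entropy_loss(logits, target).mean()
    loss.backward()  # gradients flow into the logits
    assert logits.grad is not None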