File size: 3,239 Bytes
079c32c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
from typing import Optional
import torch
from torch import nn
from torch.distributions.transforms import TanhTransform


class NonegativeParameter(nn.Module):
    """
    Overview:
        This module will output a non-negative parameter during the forward process. The value is stored in log \
        space (``log_data``) and recovered with ``exp``, so gradient updates on ``log_data`` cannot make the \
        output negative.
    Interfaces:
        ``__init__``, ``forward``, ``set_data``.
    """

    def __init__(self, data: Optional[torch.Tensor] = None, requires_grad: bool = True, delta: float = 1e-8):
        """
        Overview:
            Initialize the NonegativeParameter object using the given arguments.
        Arguments:
            - data (:obj:`Optional[torch.Tensor]`): The initial value of generated parameter. If set to ``None``, the \
                default value is 0 (a tensor of shape ``(1,)``). Entries should be non-negative: ``torch.log`` \
                yields NaN for any entry below ``-delta``.
            - requires_grad (:obj:`bool`): Whether this parameter requires grad.
            - delta (:obj:`float`): Small offset added to ``data`` before taking the log, so that zero entries map \
                to the finite value ``log(delta)`` instead of ``-inf``. The same offset is used by ``set_data``.
        """
        super().__init__()
        if data is None:
            data = torch.zeros(1)
        # Remember the offset so that ``set_data`` encodes new values exactly like the constructor does.
        self.delta = delta
        self.log_data = nn.Parameter(torch.log(data + delta), requires_grad=requires_grad)

    def forward(self) -> torch.Tensor:
        """
        Overview:
            Output the non-negative parameter during the forward process.
        Returns:
            - parameter (:obj:`torch.Tensor`): ``exp(log_data)``, i.e. (up to float rounding) the stored \
                ``data + delta``.
        """
        return torch.exp(self.log_data)

    def set_data(self, data: torch.Tensor) -> None:
        """
        Overview:
            Set the value of the non-negative parameter, using the same ``delta`` offset as ``__init__``.
        Arguments:
            - data (:obj:`torch.Tensor`): The new value of the non-negative parameter (its shape may differ from \
                the previous value).
        .. note::
            A brand-new ``nn.Parameter`` replaces ``log_data`` (its ``requires_grad`` flag is preserved). An \
            optimizer constructed with the previous parameter object will therefore not track the new one.
        """
        self.log_data = nn.Parameter(torch.log(data + self.delta), requires_grad=self.log_data.requires_grad)


class TanhParameter(nn.Module):
    """
    Overview:
        This module will output a tanh parameter during the forward process, i.e. a value in ``(-1, 1)``. The \
        unconstrained pre-image ``atanh(data)`` is what is stored and optimized (``data_inv``).
    Interfaces:
        ``__init__``, ``forward``, ``set_data``.
    """

    def __init__(self, data: Optional[torch.Tensor] = None, requires_grad: bool = True):
        """
        Overview:
            Initialize the TanhParameter object using the given arguments.
        Arguments:
            - data (:obj:`Optional[torch.Tensor]`): The initial value of generated parameter. If set to ``None``, the \
                default value is 0 (a tensor of shape ``(1,)``). Entries should lie strictly inside ``(-1, 1)``; \
                ``atanh`` maps ``+-1`` to ``+-inf``.
            - requires_grad (:obj:`bool`): Whether this parameter requires grad.
        """
        super().__init__()
        if data is None:
            data = torch.zeros(1)
        # ``cache_size=0`` on purpose: a cached Transform returns its previous output whenever it is called again
        # with the *same tensor object*. ``data_inv`` is always the same Parameter object (in-place updates keep
        # its identity), so caching would make ``forward`` return a stale value whose graph was already freed.
        self.transform = TanhTransform(cache_size=0)

        self.data_inv = nn.Parameter(self.transform.inv(data), requires_grad=requires_grad)

    def forward(self) -> torch.Tensor:
        """
        Overview:
            Output the tanh parameter during the forward process.
        Returns:
            - parameter (:obj:`torch.Tensor`): ``tanh(data_inv)``, recomputed on every call.
        """
        return self.transform(self.data_inv)

    def set_data(self, data: torch.Tensor) -> None:
        """
        Overview:
            Set the value of the tanh parameter.
        Arguments:
            - data (:obj:`torch.Tensor`): The new value of the tanh parameter; entries should lie strictly inside \
                ``(-1, 1)``.
        .. note::
            A brand-new ``nn.Parameter`` replaces ``data_inv`` (its ``requires_grad`` flag is preserved). An \
            optimizer constructed with the previous parameter object will therefore not track the new one.
        """
        self.data_inv = nn.Parameter(self.transform.inv(data), requires_grad=self.data_inv.requires_grad)