File size: 811 Bytes
079c32c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import unittest
import pytest
import torch
from ding.torch_utils.parameter import NonegativeParameter, TanhParameter


@pytest.mark.unittest
def test_nonegative_parameter():
    """Check that ``NonegativeParameter`` returns the value it was built or reset with.

    The parameter is constructed from a 1-D float tensor and calling it must
    return that tensor; ``set_data`` is then given a 0-dim integer tensor and
    calling the parameter must return the new value.

    Comparisons use a small tolerance instead of exact equality.
    NOTE(review): the parameter presumably stores its value in a transformed
    space (e.g. log), so the returned value can differ from the input in the
    last float bits — confirm against ``ding.torch_utils.parameter``.
    """
    nonegative_parameter = NonegativeParameter(torch.tensor([2.0, 3.0]))
    out = nonegative_parameter()
    assert torch.allclose(out, torch.tensor([2.0, 3.0]), rtol=1e-6, atol=1e-6)

    # set_data receives an integer 0-dim tensor, exactly as before.
    nonegative_parameter.set_data(torch.tensor(1))
    out = nonegative_parameter()
    # ones_like matches out's dtype/shape, so allclose never trips on a dtype mismatch.
    assert torch.allclose(out, torch.ones_like(out), rtol=1e-6, atol=1e-6)


@pytest.mark.unittest
def test_tanh_parameter():
    """Check that ``TanhParameter`` returns the value it was built or reset with.

    The parameter is constructed from a 1-D tensor of values inside (-1, 1)
    and calling it must return that tensor; ``set_data`` then stores a 0-dim
    tensor and calling the parameter must return the new value.

    Both checks use the same absolute tolerance (1e-6). The construction check
    already needed it, so the stored value evidently does not round-trip
    bit-exactly, and the ``set_data`` check must not rely on exact equality either.
    """
    tanh_parameter = TanhParameter(torch.tensor([0.5, -0.2]))
    out = tanh_parameter()
    assert torch.allclose(out, torch.tensor([0.5, -0.2]), rtol=0.0, atol=1e-6)

    tanh_parameter.set_data(torch.tensor(0.3))
    out = tanh_parameter()
    # full_like matches out's dtype/shape, so allclose never trips on a dtype mismatch.
    assert torch.allclose(out, torch.full_like(out, 0.3), rtol=0.0, atol=1e-6)


if __name__ == "__main__":
    # Allow running this file directly: call each test in definition order, bypassing pytest.
    for _test_fn in (test_nonegative_parameter, test_tanh_parameter):
        _test_fn()