File size: 811 Bytes
079c32c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 |
import unittest
import pytest
import torch
from ding.torch_utils.parameter import NonegativeParameter, TanhParameter
@pytest.mark.unittest
def test_nonegative_parameter():
    """Check that ``NonegativeParameter`` hands back the data it was given.

    The wrapper is constructed from ``[2.0, 3.0]`` and must return those
    values when called. After ``set_data`` is called with the scalar ``1``,
    calling the wrapper must return ``1``.
    """
    nonegative_parameter = NonegativeParameter(torch.tensor([2.0, 3.0]))
    # Build the reference from a separate tensor so the check cannot be
    # satisfied merely by sharing storage with the constructor argument.
    expected = torch.tensor([2.0, 3.0])
    # Compare with a tolerance instead of exact equality: a bit-exact match
    # is brittle for float32 values read back from a wrapper (presumably the
    # value is stored in a transformed form — confirm against
    # ding.torch_utils.parameter). Same atol as test_tanh_parameter.
    assert torch.allclose(nonegative_parameter(), expected, atol=1e-6)

    nonegative_parameter.set_data(torch.tensor(1))
    updated = nonegative_parameter()
    # No ``.all()`` on purpose: as with the original ``== 1`` check, a
    # result with more than one element makes the assert raise.
    assert torch.isclose(updated, torch.ones_like(updated), atol=1e-6)
@pytest.mark.unittest
def test_tanh_parameter():
    """Check that ``TanhParameter`` hands back the data it was given.

    The wrapper is constructed from ``[0.5, -0.2]`` and must return those
    values (within ``1e-6``) when called. After ``set_data`` is called with
    the scalar ``0.3``, calling the wrapper must return ``0.3`` within the
    same tolerance.
    """
    tanh_parameter = TanhParameter(torch.tensor([0.5, -0.2]))
    expected = torch.tensor([0.5, -0.2])
    # Difference-vs-zero keeps the tolerance purely absolute (isclose's
    # relative term is scaled by the zero reference and vanishes).
    assert torch.isclose(tanh_parameter() - expected, torch.zeros(2), atol=1e-6).all()

    tanh_parameter.set_data(torch.tensor(0.3))
    updated = tanh_parameter()
    # Use the same tolerance as the first check rather than exact float
    # equality, which is brittle for float32 results. No ``.all()`` on
    # purpose: as with the original ``== 0.3`` check, a result with more than
    # one element makes the assert raise.
    assert torch.isclose(updated, torch.full_like(updated, 0.3), atol=1e-6)
if __name__ == "__main__":
    # Direct-execution entry point: run every test in order without pytest.
    for run_case in (test_nonegative_parameter, test_tanh_parameter):
        run_case()
|