# Source: gomoku / DI-engine / ding / envs / common / tests / test_common_function.py
# (Hugging Face Space file view; commit 079c32c "init space" by zjowowen; 4.1 kB)
import os
import random
import shutil
import tempfile

import numpy as np
import pytest
import torch

from ding.envs.common.common_function import sqrt_one_hot, div_one_hot, div_func, clip_one_hot, \
    reorder_one_hot, reorder_one_hot_array, reorder_boolean_vector, \
    batch_binary_encode, get_postion_vector, \
    affine_transform, save_frames_as_gif
# Sparse ids used by the reorder fixtures: each value v is remapped to its index in
# this list (2 -> 0, 3 -> 1, 5 -> 2, 7 -> 3, 11 -> 4).
VALUES = [2, 3, 5, 7, 11]
@pytest.fixture(scope="function")
def setup_reorder_array():
    """Build an inverse lookup array mapping each value in ``VALUES`` to its index.

    Slots for values that do not appear in ``VALUES`` hold ``-1``. The array length is
    derived from ``max(VALUES)`` (instead of a hard-coded size) so the lookup stays
    valid if ``VALUES`` is ever extended with larger ids.
    """
    ret = np.full(max(VALUES) + 1, -1)
    for i, v in enumerate(VALUES):
        ret[v] = i
    return ret
@pytest.fixture(scope="function")
def setup_reorder_dict():
    """Return a dict mapping each value in ``VALUES`` to its position in that list."""
    positions = range(len(VALUES))
    return dict(zip(VALUES, positions))
def generate_data():
    """Create one random sample dict with an ``'obs'`` vector and an optional ``'priority'``.

    ``'obs'`` is a length-4 array from ``np.random.randn``. A second uniform draw then
    selects one of three cases, each covering a third of ``[0, 1)``:

    * below 1/3: no ``'priority'`` key at all;
    * from 1/3 up to 2/3: ``'priority'`` is ``None``;
    * 2/3 and above: ``'priority'`` is a fresh ``np.random.uniform()`` draw.
    """
    sample = {'obs': np.random.randn(4)}
    branch = np.random.uniform()
    if branch >= 2. / 3:
        sample['priority'] = np.random.uniform()
    elif branch >= 1. / 3:
        sample['priority'] = None
    # otherwise the 'priority' key is deliberately left out
    return sample
@pytest.mark.unittest
class TestEnvCommonFunc:
    """Unit tests for the encoding helpers imported from ``ding.envs.common.common_function``."""

    @staticmethod
    def _assert_one_hot(encoded, expected_shape):
        """Assert ``encoded`` has ``expected_shape``, a max of 1, and exactly one hot entry per last-dim vector."""
        assert encoded.shape == expected_shape
        assert encoded.max().item() == 1
        vectors = encoded.reshape(-1, encoded.shape[-1])
        assert [v.sum().item() for v in vectors] == [1] * vectors.shape[0]

    def test_one_hot(self):
        """Check the sqrt/div/clip one-hot encoders and ``div_func`` on a 2x3 input."""
        a = torch.Tensor([[3, 4, 5], [1, 2, 6]])
        sqrt_dim = 3
        self._assert_one_hot(sqrt_one_hot(a, 6), (2, 3, sqrt_dim))
        div_dim = 4
        self._assert_one_hot(div_one_hot(a, 6, 2), (2, 3, div_dim))
        a_di = div_func(a, 2)
        assert a_di.shape == (2, 1, 3)
        assert torch.eq(a_di.squeeze() * 2, a).all()
        clip_dim = 4
        self._assert_one_hot(clip_one_hot(a.long(), 4), (2, 3, clip_dim))

    def test_reorder(self, setup_reorder_array, setup_reorder_dict):
        """Check that array- and dict-based reordering agree, and that
        ``reorder_boolean_vector`` equals the element-wise sum of the one-hot rows."""
        a = torch.LongTensor([2, 7])  # VALUES = [2, 3, 5, 7, 11]
        a_array = reorder_one_hot_array(a, setup_reorder_array, 5)
        a_dict = reorder_one_hot(a, setup_reorder_dict, 5)
        assert torch.eq(a_array, a_dict).all()
        reorder_dim = 5
        self._assert_one_hot(a_array, (2, reorder_dim))
        a_bool = reorder_boolean_vector(a, setup_reorder_dict, 5)
        # validate the boolean vector itself, not the one-hot array checked above
        assert a_bool.max().item() == 1
        assert torch.eq(a_bool, a_array.sum(dim=0)).all()

    def test_binary(self):
        """Check ``batch_binary_encode`` against Python's zero-padded, MSB-first binary formatting."""
        a = torch.LongTensor([445, 1023])
        bit_num = 10
        a_binary = batch_binary_encode(a, bit_num)
        bit_format = '0{}b'.format(bit_num)
        ans = torch.Tensor([[int(bit) for bit in format(int(number), bit_format)] for number in a])
        assert torch.eq(a_binary, ans).all()

    def test_position(self):
        """Check that ``get_postion_vector`` returns shape (64,) for 32 random positions."""
        a = [random.randint(0, 5000) for _ in range(32)]
        a_position = get_postion_vector(a)
        assert a_position.shape == (64, )

    def test_affine_transform(self):
        """Check ``affine_transform`` on a torch tensor (min/max form) and a numpy array (alpha/beta form)."""
        a = torch.rand(4, 3)
        # rescale to exactly [-1, 1] so the transformed extremes can be compared exactly
        a = (a - a.min()) / (a.max() - a.min())
        a = a * 2 - 1
        ans = affine_transform(a, min_val=-2, max_val=2)
        assert ans.shape == (4, 3)
        assert ans.min() == -2 and ans.max() == 2
        a = np.random.rand(3, 5)
        a = (a - a.min()) / (a.max() - a.min())
        a = a * 2 - 1
        ans = affine_transform(a, alpha=4, beta=1)
        assert ans.shape == (3, 5)
        # expected extremes match alpha * a + beta applied to the [-1, 1] endpoints
        assert ans.min() == -3 and ans.max() == 5
@pytest.mark.other
def test_save_frames_as_gif():
    """Smoke-test ``save_frames_as_gif`` by writing 100 random 84x84x3 frames to a GIF.

    The GIF goes into a throwaway temporary directory. The directory is removed even if
    the call raises, and the test never creates or deletes a directory in the current
    working tree.
    """
    frames = [np.random.randint(0, 255, [84, 84, 3]) for _ in range(100)]
    env_id = 'test'
    save_replay_count = 1
    with tempfile.TemporaryDirectory() as replay_path_gif:
        path = os.path.join(replay_path_gif, '{}_episode_{}.gif'.format(env_id, save_replay_count))
        save_frames_as_gif(frames, path)