File size: 4,104 Bytes
079c32c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import os
import random
import shutil
import tempfile

import numpy as np
import pytest
import torch
from ding.envs.common.common_function import sqrt_one_hot, div_one_hot, div_func, clip_one_hot, \
    reorder_one_hot, reorder_one_hot_array, reorder_boolean_vector, \
    batch_binary_encode, get_postion_vector, \
    affine_transform, save_frames_as_gif

# Sparse ids exercised by the reorder tests. The reorder fixtures below map each
# value to its position in this list (2 -> 0, 3 -> 1, 5 -> 2, 7 -> 3, 11 -> 4).
VALUES = [2, 3, 5, 7, 11]


@pytest.fixture(scope="function")
def setup_reorder_array():
    """Dense lookup table mapping each value in ``VALUES`` to its index.

    ``ret[v]`` holds the position of ``v`` in ``VALUES``; every other slot is
    -1, marking values that are not in ``VALUES``.
    """
    # Size the table from VALUES rather than a hard-coded 12 so it keeps
    # covering every value if VALUES is ever extended.
    ret = np.full(max(VALUES) + 1, -1)
    ret[VALUES] = np.arange(len(VALUES))
    return ret


@pytest.fixture(scope="function")
def setup_reorder_dict():
    """Map each value in ``VALUES`` to its position, e.g. ``{2: 0, 3: 1, ...}``."""
    positions = range(len(VALUES))
    return dict(zip(VALUES, positions))


def generate_data():
    """Build one random sample dict with an optional ``priority`` entry.

    ``obs`` is always a length-4 standard-normal vector. A single uniform draw
    then picks one of three equally likely cases: ``priority`` is absent, is
    present with value ``None``, or is present with a fresh uniform float.
    """
    sample = {'obs': np.random.randn(4)}
    draw = np.random.uniform()
    if draw >= 2. / 3:
        sample['priority'] = np.random.uniform()
    elif draw >= 1. / 3:
        sample['priority'] = None
    # draw < 1/3: leave the 'priority' key out entirely
    return sample


@pytest.mark.unittest
class TestEnvCommonFunc:
    """Unit tests for the encoding helpers imported from ``ding.envs.common.common_function``."""

    def test_one_hot(self):
        """Check ``sqrt_one_hot``, ``div_one_hot``, ``clip_one_hot`` and ``div_func`` on a 2x3 tensor.

        For each one-hot encoder, every element of ``a`` must become a vector with
        exactly one hot entry: each vector sums to 1 and the global maximum is 1.
        """
        a = torch.Tensor([[3, 4, 5], [1, 2, 6]])
        # One encoded vector is expected per element of ``a`` (2 * 3 = 6).
        expected_sums = [1] * a.numel()

        a_sqrt = sqrt_one_hot(a, 6)
        assert a_sqrt.max().item() == 1
        assert [j.sum().item() for i in a_sqrt for j in i] == expected_sums
        sqrt_dim = 3
        assert a_sqrt.shape == (2, 3, sqrt_dim)

        a_div = div_one_hot(a, 6, 2)
        assert a_div.max().item() == 1
        assert [j.sum().item() for i in a_div for j in i] == expected_sums
        div_dim = 4
        assert a_div.shape == (2, 3, div_dim)

        # Expected: values halved and a singleton axis inserted, (2, 3) -> (2, 1, 3).
        a_di = div_func(a, 2)
        assert a_di.shape == (2, 1, 3)
        assert torch.eq(a_di.squeeze() * 2, a).all()

        a_clip = clip_one_hot(a.long(), 4)
        assert a_clip.max().item() == 1
        assert [j.sum().item() for i in a_clip for j in i] == expected_sums
        clip_dim = 4
        assert a_clip.shape == (2, 3, clip_dim)

    def test_reorder(self, setup_reorder_array, setup_reorder_dict):
        """Array- and dict-based reordering must agree and yield valid one-hot / boolean encodings."""
        a = torch.LongTensor([2, 7])  # VALUES = [2, 3, 5, 7, 11] -> reordered indices 0 and 3
        reorder_dim = 5

        a_array = reorder_one_hot_array(a, setup_reorder_array, reorder_dim)
        a_dict = reorder_one_hot(a, setup_reorder_dict, reorder_dim)
        assert torch.eq(a_array, a_dict).all()
        assert a_array.max().item() == 1
        assert [j.sum().item() for j in a_array] == [1] * len(a)
        assert a_array.shape == (2, reorder_dim)

        a_bool = reorder_boolean_vector(a, setup_reorder_dict, reorder_dim)
        # Range-check the boolean vector itself, not a_array a second time.
        assert a_bool.max().item() == 1
        # The boolean vector should equal the sum of the one-hot rows.
        assert torch.eq(a_bool, a_array.sum(dim=0)).all()

    def test_binary(self):
        """``batch_binary_encode`` must match Python's own zero-padded binary formatting."""
        bit_num = 10
        a = torch.LongTensor([445, 1023])
        a_binary = batch_binary_encode(a, bit_num)
        # Expected: most-significant bit first, left-padded with zeros to ``bit_num`` bits.
        bit_format = '0{}b'.format(bit_num)
        ans = torch.Tensor([[int(bit) for bit in format(number, bit_format)] for number in a.tolist()])
        assert torch.eq(a_binary, ans).all()

    def test_position(self):
        """``get_postion_vector`` on 32 random integers should yield a 64-element vector."""
        a = [random.randint(0, 5000) for _ in range(32)]
        a_position = get_postion_vector(a)
        assert a_position.shape == (64, )

    def test_affine_transform(self):
        """Check ``affine_transform`` with explicit bounds (torch) and with alpha/beta (numpy).

        Inputs are min-max normalised so their extremes land exactly on -1 and 1;
        the exact comparisons below rely on those extremes mapping exactly onto
        the expected output bounds.
        """
        a = torch.rand(4, 3)
        a = (a - a.min()) / (a.max() - a.min())
        a = a * 2 - 1
        ans = affine_transform(a, min_val=-2, max_val=2)
        assert ans.shape == (4, 3)
        assert ans.min() == -2 and ans.max() == 2

        a = np.random.rand(3, 5)
        a = (a - a.min()) / (a.max() - a.min())
        a = a * 2 - 1
        ans = affine_transform(a, alpha=4, beta=1)
        assert ans.shape == (3, 5)
        assert ans.min() == -3 and ans.max() == 5


@pytest.mark.other
def test_save_frames_as_gif():
    """Smoke-test ``save_frames_as_gif`` on 100 random 84x84x3 frames.

    The GIF is written into a fresh temporary directory that is always removed,
    even if saving fails. The test therefore neither leaves files behind nor
    deletes a pre-existing ``./replay_path_gif`` in the working directory.
    """
    # uint8 values in [0, 255] are the conventional pixel format for image frames.
    frames = [np.random.randint(0, 256, [84, 84, 3], dtype=np.uint8) for _ in range(100)]
    env_id = 'test'
    save_replay_count = 1
    replay_path_gif = tempfile.mkdtemp(prefix='replay_path_gif_')
    try:
        path = os.path.join(replay_path_gif, '{}_episode_{}.gif'.format(env_id, save_replay_count))
        save_frames_as_gif(frames, path)
        assert os.path.isfile(path)
    finally:
        shutil.rmtree(replay_path_gif)