File size: 3,180 Bytes
079c32c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import time
import torch
from hpc_rll.origin.td import td_lambda_error, td_lambda_data
from hpc_rll.rl_utils.td import TDLambda
from testbase import mean_relative_error, times

# This script exercises the CUDA path of both implementations, so fail fast
# without a GPU. An explicit raise (unlike ``assert``) still fires under
# ``python -O``.
if not torch.cuda.is_available():
    raise RuntimeError("CUDA is required to run the hpc_rll TD(lambda) test")
use_cuda = True

# Problem size shared by both tests: value is (T + 1, B); reward and weight
# are (T, B).
T = 1024
B = 64


def td_val():
    """Compare the HPC ``TDLambda`` module with the original implementation.

    One random problem is drawn: value of shape (T + 1, B), reward and weight
    of shape (T, B). The original ``td_lambda_error`` (from ``hpc_rll.origin``)
    and the HPC ``TDLambda`` module each run a forward pass, a ``.mean()``
    reduction and a backward pass on their own copy of the inputs. The script
    then prints ``mean_relative_error`` between the two scalar losses ("fp")
    and between the two gradients w.r.t. value ("bp").
    """
    # The three draws happen in the order value, reward, weight.
    reference = [torch.randn(T + 1, B), torch.randn(T, B), torch.randn(T, B)]
    # Detached copies keep the HPC autograd graph separate from the original.
    candidate = [tensor.clone().detach() for tensor in reference]
    module = TDLambda(T, B)

    if use_cuda:
        reference = [tensor.cuda() for tensor in reference]
        candidate = [tensor.cuda() for tensor in candidate]
        module = module.cuda()

    ref_value, ref_reward, ref_weight = reference
    cand_value, cand_reward, cand_weight = candidate

    # Original implementation: forward, reduce, backward.
    ref_value.requires_grad_(True)
    ref_loss = td_lambda_error(
        td_lambda_data(ref_value, ref_reward, ref_weight)
    ).mean()
    ref_loss.backward()
    if use_cuda:
        torch.cuda.synchronize()

    # HPC implementation: same sequence on the detached copies.
    cand_value.requires_grad_(True)
    cand_loss = module(cand_value, cand_reward, cand_weight).mean()
    cand_loss.backward()
    if use_cuda:
        torch.cuda.synchronize()

    def as_numpy(tensor):
        # Flatten and move to host memory for the numpy-based comparison.
        return torch.flatten(tensor).cpu().detach().numpy()

    report = (
        ("td fp mean_relative_error: ", ref_loss, cand_loss),
        ("td bp mean_relative_error: ", ref_value.grad, cand_value.grad),
    )
    for prefix, expected, actual in report:
        error = mean_relative_error(as_numpy(expected), as_numpy(actual))
        print(prefix + str(error))


def _time_epochs(template, step):
    """Run ``step`` for ``times`` epochs and print the cost of each epoch.

    Args:
        template: ``str.format`` pattern that receives the epoch index and
            the elapsed seconds, in that order.
        step: zero-argument callable performing one forward + backward pass.
    """
    for i in range(times):
        if use_cuda:
            # Drain previously queued GPU work so it is not billed to this epoch.
            torch.cuda.synchronize()
        # perf_counter is monotonic and high-resolution, unlike time.time().
        start = time.perf_counter()
        step()
        if use_cuda:
            # Kernel launches are asynchronous; wait for them before reading
            # the clock.
            torch.cuda.synchronize()
        print(template.format(i, time.perf_counter() - start))


def td_perf():
    """Benchmark the original and HPC TD(lambda) forward + backward passes.

    Both implementations receive identical random inputs: value of shape
    (T + 1, B), reward and weight of shape (T, B). Each implementation runs
    for ``times`` epochs, and the wall-clock cost of every epoch is printed.
    """
    ori_value = torch.randn(T + 1, B)
    ori_reward = torch.randn(T, B)
    ori_weight = torch.randn(T, B)

    hpc_value = ori_value.clone().detach()
    hpc_reward = ori_reward.clone().detach()
    hpc_weight = ori_weight.clone().detach()
    hpc_td = TDLambda(T, B)

    if use_cuda:
        ori_value = ori_value.cuda()
        ori_reward = ori_reward.cuda()
        ori_weight = ori_weight.cuda()

        hpc_value = hpc_value.cuda()
        hpc_reward = hpc_reward.cuda()
        hpc_weight = hpc_weight.cuda()
        hpc_td = hpc_td.cuda()

    ori_value.requires_grad_(True)
    hpc_value.requires_grad_(True)

    def ori_step():
        # Drop the previous epoch's gradient so every epoch does the same work
        # instead of accumulating into the existing .grad.
        ori_value.grad = None
        loss = td_lambda_error(td_lambda_data(ori_value, ori_reward, ori_weight))
        loss.mean().backward()

    def hpc_step():
        hpc_value.grad = None
        loss = hpc_td(hpc_value, hpc_reward, hpc_weight)
        loss.mean().backward()

    _time_epochs('epoch: {}, original td cost time: {}', ori_step)
    _time_epochs('epoch: {}, hpc td cost time: {}', hpc_step)


if __name__ == '__main__':
    # Report the problem size, then run each test under its own banner.
    print("target problem: T = {}, B = {}".format(T, B))
    stages = (
        ("================run td validation test================", td_val),
        ("================run td performance test================", td_perf),
    )
    for banner, run_stage in stages:
        print(banner)
        run_stage()