File size: 5,305 Bytes
079c32c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
import time
import torch
from hpc_rll.origin.rnn import get_lstm
from hpc_rll.torch_utils.network.rnn import LSTM
from testbase import mean_relative_error, times

# Fail fast when no GPU is present: the HPC LSTM below is CUDA-only.
assert torch.cuda.is_available()
use_cuda = True

# Problem size shared by the validation and performance tests.
seq_len = 64
batch_size = 3
input_size = 1792
hidden_size = 384
num_layers = 3
norm_type = 'LN'  # layer normalization on the gate pre-activations
dropout = 0  # 0.1  (disabled so both implementations are deterministic and comparable)


# Note: need open load_params for hpc_lstm to validation
def lstm_val():
    """Validate the HPC LSTM against the reference ('normal') LSTM.

    Runs one forward/backward pass through both implementations on identical
    random inputs and prints the mean relative error of:
    the loss (forward), the input gradient (backward), and every parameter
    gradient (wx, wh, bias, LayerNorm gamma/beta).

    Generalized from the original hard-coded 3-layer version: the per-layer
    gradient gathering below now loops over ``num_layers``.
    """
    ori_lstm = get_lstm('normal', input_size, hidden_size, num_layers, norm_type, dropout)
    hpc_lstm = LSTM(seq_len, batch_size, input_size, hidden_size, num_layers, norm_type, dropout)

    ori_x = torch.randn(seq_len, batch_size, input_size)
    ori_h0 = torch.randn(num_layers, batch_size, hidden_size)
    ori_c0 = torch.randn(num_layers, batch_size, hidden_size)

    if use_cuda:
        ori_x = ori_x.cuda()
        ori_h0 = ori_h0.cuda()
        ori_c0 = ori_c0.cuda()
        ori_lstm = ori_lstm.cuda()
        hpc_lstm = hpc_lstm.cuda()

    # Reference pass.
    ori_x.requires_grad_(True)
    ori_output, ori_next_state = ori_lstm(ori_x, [ori_h0, ori_c0])
    ori_loss = ori_output.mean()
    ori_loss.backward()

    # HPC pass on the exact same input/initial state, detached so the two
    # computation graphs stay independent.
    hpc_x = ori_x.clone().detach()
    hpc_h0 = ori_h0.clone().detach()
    hpc_c0 = ori_c0.clone().detach()
    hpc_x.requires_grad_(True)
    hpc_output, hpc_next_state = hpc_lstm(hpc_x, [hpc_h0, hpc_c0])
    hpc_loss = hpc_output.mean()
    hpc_loss.backward()
    torch.cuda.synchronize()

    mre = mean_relative_error(
        torch.flatten(ori_loss).cpu().detach().numpy(),
        torch.flatten(hpc_loss).cpu().detach().numpy()
    )
    print("lstm fp mean_relative_error: " + str(mre))
    mre = mean_relative_error(
        torch.flatten(ori_x.grad).cpu().detach().numpy(),
        torch.flatten(hpc_x.grad).cpu().detach().numpy()
    )
    print("lstm bp mean_relative_error: " + str(mre))

    # The reference LSTM keeps per-layer weights; the HPC LSTM stores them
    # concatenated, so concatenate the reference grads across all layers.
    ori_wx_grad = torch.cat([ori_lstm.wx[i].grad for i in range(num_layers)])
    hpc_wx_grad = hpc_lstm.wx.grad
    mre = mean_relative_error(torch.flatten(ori_wx_grad).cpu().numpy(), torch.flatten(hpc_wx_grad).cpu().numpy())
    print("wx grad mean_relative_error: " + str(mre))

    ori_wh_grad = torch.cat([ori_lstm.wh[i].grad for i in range(num_layers)])
    hpc_wh_grad = hpc_lstm.wh.grad
    mre = mean_relative_error(torch.flatten(ori_wh_grad).cpu().numpy(), torch.flatten(hpc_wh_grad).cpu().numpy())
    print("wh grad mean_relative_error: " + str(mre))

    ori_bias_grad = ori_lstm.bias.grad
    hpc_bias_grad = hpc_lstm.bias.grad
    mre = mean_relative_error(torch.flatten(ori_bias_grad).cpu().numpy(), torch.flatten(hpc_bias_grad).cpu().numpy())
    print("bias grad mean_relative_error: " + str(mre))

    # LayerNorm parameters appear in ori_lstm.parameters() after index 0 as
    # (gamma_x, beta_x, gamma_h, beta_h) per layer; the HPC LSTM stores all
    # gammas (x then h, layer by layer) and all betas concatenated.
    # NOTE(review): layout verified only for the 'normal' get_lstm variant.
    params = list(ori_lstm.parameters())
    gamma_grads = []
    beta_grads = []
    for layer in range(num_layers):
        base = 1 + 4 * layer
        gamma_x, beta_x, gamma_h, beta_h = params[base:base + 4]
        gamma_grads.extend([gamma_x.grad, gamma_h.grad])
        beta_grads.extend([beta_x.grad, beta_h.grad])
    ori_gamma_grad = torch.cat(gamma_grads)
    ori_beta_grad = torch.cat(beta_grads)
    hpc_gamma_grad = hpc_lstm.ln_gamma.grad
    hpc_beta_grad = hpc_lstm.ln_beta.grad
    mre = mean_relative_error(torch.flatten(ori_gamma_grad).cpu().numpy(), torch.flatten(hpc_gamma_grad).cpu().numpy())
    print("ln gamma grad mean_relative_error: " + str(mre))
    mre = mean_relative_error(torch.flatten(ori_beta_grad).cpu().numpy(), torch.flatten(hpc_beta_grad).cpu().numpy())
    print("ln beta grad mean_relative_error: " + str(mre))


def lstm_perf():
    """Benchmark forward+backward wall-clock time of both LSTMs.

    For each implementation ('normal' reference and 'hpc'), runs ``times``
    iterations of forward + mean-loss backward on the same random input and
    prints the per-iteration time.
    """
    ori_lstm = get_lstm('normal', input_size, hidden_size, num_layers, norm_type, dropout)
    hpc_lstm = LSTM(seq_len, batch_size, input_size, hidden_size, num_layers, norm_type, dropout)

    lstms = {'normal': ori_lstm, 'hpc': hpc_lstm}

    for lstm_type, lstm in lstms.items():
        x = torch.rand(seq_len, batch_size, input_size)
        h0 = torch.randn(num_layers, batch_size, hidden_size)
        c0 = torch.randn(num_layers, batch_size, hidden_size)
        if use_cuda:
            x = x.cuda()
            h0 = h0.cuda()
            c0 = c0.cuda()
            lstm = lstm.cuda()

        prev_state = [h0, c0]
        x.requires_grad_(True)
        for i in range(times):
            # Drain previously queued GPU work (and, on the first iteration,
            # CUDA warm-up) so it is not billed to this iteration's timer.
            if use_cuda:
                torch.cuda.synchronize()
            t = time.time()
            output, _ = lstm(x, prev_state)
            loss = output.mean()
            loss.backward()
            # Kernels launch asynchronously; wait for completion before
            # reading the clock so the measured time is real.
            if use_cuda:
                torch.cuda.synchronize()
            print('epoch: {}, {} lstm cost time: {}'.format(i, lstm_type, time.time() - t))


if __name__ == '__main__':
    # Describe the benchmarked problem configuration before running.
    desc = "target problem: seq_len = {}, batch_size = {}, input_size = {}, hidden_size = {}, num_layers = {}, norm_type = {}, dropout = {}"  # noqa
    print(desc.format(seq_len, batch_size, input_size, hidden_size, num_layers, norm_type, dropout))
    print("==============lstm has no validation test================")
    # Validation is disabled by default; see the notes on lstm_val before enabling.
    #print("===============run lstm validation test==================")
    #lstm_val()
    print("===============run lstm performance test=================")
    lstm_perf()