import time
import torch
from hpc_rll.origin.rnn import get_lstm
from hpc_rll.torch_utils.network.rnn import LSTM
from testbase import mean_relative_error, times

assert torch.cuda.is_available()
use_cuda = True
seq_len = 64
batch_size = 3
input_size = 1792
hidden_size = 384
num_layers = 3
norm_type = 'LN'
dropout = 0  # 0.1


# Note: load_params must be enabled for hpc_lstm before running the validation test
# Note: the parameter-gradient checks below only apply to the case of num_layers = 3
def lstm_val():
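    """Validate the HPC LSTM against the reference implementation.

    Runs the same random problem through both LSTMs and compares the loss,
    the input gradient, and each parameter gradient with mean_relative_error.
    """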
    ori_lstm = get_lstm('normal', input_size, hidden_size, num_layers, norm_type, dropout)
    hpc_lstm = LSTM(seq_len, batch_size, input_size, hidden_size, num_layers, norm_type, dropout)

    ori_x = torch.randn(seq_len, batch_size, input_size)
    ori_h0 = torch.randn(num_layers, batch_size, hidden_size)
    ori_c0 = torch.randn(num_layers, batch_size, hidden_size)
    if use_cuda:
        ori_x = ori_x.cuda()
        ori_h0 = ori_h0.cuda()
        ori_c0 = ori_c0.cuda()
        ori_lstm = ori_lstm.cuda()
        hpc_lstm = hpc_lstm.cuda()
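
    # forward/backward pass through the reference LSTM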
    ori_x.requires_grad_(True)
    ori_output, ori_next_state = ori_lstm(ori_x, [ori_h0, ori_c0])
    ori_loss = ori_output.mean()
    ori_loss.backward()
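
    # same forward/backward pass through the HPC LSTM on detached copies of the inputs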
    hpc_x = ori_x.clone().detach()
    hpc_h0 = ori_h0.clone().detach()
    hpc_c0 = ori_c0.clone().detach()
    hpc_x.requires_grad_(True)
    hpc_output, hpc_next_state = hpc_lstm(hpc_x, [hpc_h0, hpc_c0])
    hpc_loss = hpc_output.mean()
    hpc_loss.backward()
    torch.cuda.synchronize()
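
    # forward-pass (loss) and backward-pass (input gradient) relative errors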
    mre = mean_relative_error(
        torch.flatten(ori_loss).cpu().detach().numpy(),
        torch.flatten(hpc_loss).cpu().detach().numpy()
    )
    print("lstm fp mean_relative_error: " + str(mre))

    mre = mean_relative_error(
        torch.flatten(ori_x.grad).cpu().detach().numpy(),
        torch.flatten(hpc_x.grad).cpu().detach().numpy()
    )
    print("lstm bp mean_relative_error: " + str(mre))
    ori_wx_grad = torch.cat((ori_lstm.wx[0].grad, ori_lstm.wx[1].grad, ori_lstm.wx[2].grad))
    hpc_wx_grad = hpc_lstm.wx.grad
    mre = mean_relative_error(torch.flatten(ori_wx_grad).cpu().numpy(), torch.flatten(hpc_wx_grad).cpu().numpy())
    print("wx grad mean_relative_error: " + str(mre))

    ori_wh_grad = torch.cat((ori_lstm.wh[0].grad, ori_lstm.wh[1].grad, ori_lstm.wh[2].grad))
    hpc_wh_grad = hpc_lstm.wh.grad
    mre = mean_relative_error(torch.flatten(ori_wh_grad).cpu().numpy(), torch.flatten(hpc_wh_grad).cpu().numpy())
    print("wh grad mean_relative_error: " + str(mre))

    ori_bias_grad = ori_lstm.bias.grad
    hpc_bias_grad = hpc_lstm.bias.grad
    mre = mean_relative_error(torch.flatten(ori_bias_grad).cpu().numpy(), torch.flatten(hpc_bias_grad).cpu().numpy())
    print("bias grad mean_relative_error: " + str(mre))
    params = list(ori_lstm.parameters())
    gamma_0_x = params[1]
    beta_0_x = params[2]
    gamma_0_h = params[3]
    beta_0_h = params[4]
    gamma_1_x = params[5]
    beta_1_x = params[6]
    gamma_1_h = params[7]
    beta_1_h = params[8]
    gamma_2_x = params[9]
    beta_2_x = params[10]
    gamma_2_h = params[11]
    beta_2_h = params[12]
    ori_gamma_grad = torch.cat(
        (gamma_0_x.grad, gamma_0_h.grad, gamma_1_x.grad, gamma_1_h.grad, gamma_2_x.grad, gamma_2_h.grad)
    )
    ori_beta_grad = torch.cat(
        (beta_0_x.grad, beta_0_h.grad, beta_1_x.grad, beta_1_h.grad, beta_2_x.grad, beta_2_h.grad)
    )
    hpc_gamma_grad = hpc_lstm.ln_gamma.grad
    hpc_beta_grad = hpc_lstm.ln_beta.grad
    mre = mean_relative_error(torch.flatten(ori_gamma_grad).cpu().numpy(), torch.flatten(hpc_gamma_grad).cpu().numpy())
    print("ln gamma grad mean_relative_error: " + str(mre))
    mre = mean_relative_error(torch.flatten(ori_beta_grad).cpu().numpy(), torch.flatten(hpc_beta_grad).cpu().numpy())
    print("ln beta grad mean_relative_error: " + str(mre))


def lstm_perf():
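    """Benchmark the reference ('normal') and HPC LSTMs by timing forward+backward passes."""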
    ori_lstm = get_lstm('normal', input_size, hidden_size, num_layers, norm_type, dropout)
    hpc_lstm = LSTM(seq_len, batch_size, input_size, hidden_size, num_layers, norm_type, dropout)
    lstms = {'normal': ori_lstm, 'hpc': hpc_lstm}

    for lstm_type, lstm in lstms.items():
        x = torch.rand(seq_len, batch_size, input_size)
        h0 = torch.randn(num_layers, batch_size, hidden_size)
        c0 = torch.randn(num_layers, batch_size, hidden_size)
        if use_cuda:
            x = x.cuda()
            h0 = h0.cuda()
            c0 = c0.cuda()
            lstm = lstm.cuda()
        prev_state = [h0, c0]
        x.requires_grad_(True)
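        # each iteration times one forward + backward pass; synchronize so the
        # measurement includes the asynchronous CUDA kernels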
        for i in range(times):
            t = time.time()
            output, _ = lstm(x, prev_state)
            loss = output.mean()
            loss.backward()
            if use_cuda:
                torch.cuda.synchronize()
            print('epoch: {}, {} lstm cost time: {}'.format(i, lstm_type, time.time() - t))


if __name__ == '__main__':
    print(
        "target problem: seq_len = {}, batch_size = {}, input_size = {}, hidden_size = {}, num_layers = {}, norm_type = {}, dropout = {}"  # noqa
        .format(seq_len, batch_size, input_size, hidden_size, num_layers, norm_type, dropout)
    )
    print("==============lstm has no validation test================")
    #print("===============run lstm validation test==================")
    #lstm_val()
    print("===============run lstm performance test=================")
    lstm_perf()