zjowowen's picture
init space
079c32c
raw
history blame
No virus
5.31 kB
import time
import torch
from hpc_rll.origin.rnn import get_lstm
from hpc_rll.torch_utils.network.rnn import LSTM
from testbase import mean_relative_error, times
# Fail fast at import time when no CUDA device is available.
assert torch.cuda.is_available()
# When true, lstm_val/lstm_perf move the inputs and both LSTM modules to the GPU.
use_cuda = True
# Problem size shared by the validation and performance tests.
seq_len = 64  # first dimension of the input tensor
batch_size = 3  # second dimension of the input and of the h0/c0 state tensors
input_size = 1792  # last dimension of the input tensor
hidden_size = 384  # last dimension of the h0/c0 state tensors
num_layers = 3  # first dimension of h0/c0; lstm_val hard-codes exactly three layers
norm_type = 'LN'  # forwarded to both LSTM constructors; presumably layer normalization -- confirm
dropout = 0  # forwarded to both LSTM constructors; 0.1 is the alternative noted by the author
# Note: validation requires load_params to be enabled for hpc_lstm.
# Note: validation only supports the case num_layers = 3.
def lstm_val():
    """Run one forward/backward pass through both LSTMs and print their mismatch.

    A reference LSTM (from ``get_lstm``) and the HPC ``LSTM`` are built from the
    module-level configuration. The reference module is fed random input and
    state; the HPC module is fed detached copies of the same tensors. Both use
    ``output.mean()`` as the loss. The mean relative error between the two is
    printed for the loss, the input gradient, the ``wx``/``wh``/``bias``
    gradients, and the normalization gamma/beta gradients.

    Nothing in this function copies weights from one module to the other, so
    the printed errors are only meaningful if the two modules already hold
    identical parameters. The layer indices are hard-coded for three layers.
    """
    ref_lstm = get_lstm('normal', input_size, hidden_size, num_layers, norm_type, dropout)
    fast_lstm = LSTM(seq_len, batch_size, input_size, hidden_size, num_layers, norm_type, dropout)

    ref_x = torch.randn(seq_len, batch_size, input_size)
    ref_h0 = torch.randn(num_layers, batch_size, hidden_size)
    ref_c0 = torch.randn(num_layers, batch_size, hidden_size)

    if use_cuda:
        ref_x, ref_h0, ref_c0 = ref_x.cuda(), ref_h0.cuda(), ref_c0.cuda()
        ref_lstm, fast_lstm = ref_lstm.cuda(), fast_lstm.cuda()

    def report(prefix, ref_tensor, fast_tensor):
        # Flatten both tensors, bring them to host memory as numpy arrays and
        # print the mean relative error behind the given message prefix.
        err = mean_relative_error(
            torch.flatten(ref_tensor).cpu().detach().numpy(),
            torch.flatten(fast_tensor).cpu().detach().numpy(),
        )
        print(prefix + str(err))

    # Reference pass.
    ref_x.requires_grad_(True)
    ref_output, _ = ref_lstm(ref_x, [ref_h0, ref_c0])
    ref_loss = ref_output.mean()
    ref_loss.backward()

    # HPC pass on detached copies, so each input accumulates its own gradient.
    fast_x = ref_x.clone().detach()
    fast_h0 = ref_h0.clone().detach()
    fast_c0 = ref_c0.clone().detach()
    fast_x.requires_grad_(True)
    fast_output, _ = fast_lstm(fast_x, [fast_h0, fast_c0])
    fast_loss = fast_output.mean()
    fast_loss.backward()

    # Wait for all queued GPU work before reading results back.
    torch.cuda.synchronize()

    report("lstm fp mean_relative_error: ", ref_loss, fast_loss)
    report("lstm bp mean_relative_error: ", ref_x.grad, fast_x.grad)

    # The reference keeps one weight per layer; the HPC module exposes a single
    # tensor, so the per-layer gradients are concatenated before comparing.
    report(
        "wx grad mean_relative_error: ",
        torch.cat([ref_lstm.wx[layer].grad for layer in range(3)]),
        fast_lstm.wx.grad,
    )
    report(
        "wh grad mean_relative_error: ",
        torch.cat([ref_lstm.wh[layer].grad for layer in range(3)]),
        fast_lstm.wh.grad,
    )
    report("bias grad mean_relative_error: ", ref_lstm.bias.grad, fast_lstm.bias.grad)

    # Parameters 1..12 of the reference module are read as, per layer,
    # (gamma_x, beta_x, gamma_h, beta_h): odd indices are gammas, even indices
    # are betas. NOTE(review): this relies on the registration order inside
    # get_lstm -- confirm against that implementation.
    ref_params = list(ref_lstm.parameters())
    gamma_grads = [ref_params[idx].grad for idx in range(1, 13, 2)]
    beta_grads = [ref_params[idx].grad for idx in range(2, 13, 2)]
    ref_gamma_grad = torch.cat(gamma_grads)
    ref_beta_grad = torch.cat(beta_grads)

    report("ln gamma grad mean_relative_error: ", ref_gamma_grad, fast_lstm.ln_gamma.grad)
    report("ln beta grad mean_relative_error: ", ref_beta_grad, fast_lstm.ln_beta.grad)
def lstm_perf():
    """Time forward + backward passes of the reference and the HPC LSTM.

    Both modules are built from the module-level configuration. For each one,
    ``times`` iterations (imported from ``testbase``) of forward,
    ``output.mean()`` and ``backward()`` are run on a fixed random input, and
    the elapsed time of every iteration is printed.

    Timing uses ``time.perf_counter()``, a monotonic high-resolution clock
    meant for measuring short intervals. When ``use_cuda`` is set, the device
    is synchronized once before the timed loop, so setup work is not billed to
    the first iteration, and again at the end of every iteration, so
    asynchronously queued kernels are included in the measurement.
    """
    ori_lstm = get_lstm('normal', input_size, hidden_size, num_layers, norm_type, dropout)
    hpc_lstm = LSTM(seq_len, batch_size, input_size, hidden_size, num_layers, norm_type, dropout)

    lstms = {'normal': ori_lstm, 'hpc': hpc_lstm}

    for lstm_type, lstm in lstms.items():
        x = torch.rand(seq_len, batch_size, input_size)
        h0 = torch.randn(num_layers, batch_size, hidden_size)
        c0 = torch.randn(num_layers, batch_size, hidden_size)
        if use_cuda:
            x = x.cuda()
            h0 = h0.cuda()
            c0 = c0.cuda()
            lstm = lstm.cuda()

        prev_state = [h0, c0]
        x.requires_grad_(True)

        if use_cuda:
            # Drain pending setup work before the first measurement starts.
            torch.cuda.synchronize()

        for i in range(times):
            start = time.perf_counter()
            output, _ = lstm(x, prev_state)
            loss = output.mean()
            loss.backward()
            if use_cuda:
                # Kernels are launched asynchronously; wait for them so the
                # interval covers the actual computation.
                torch.cuda.synchronize()
            elapsed = time.perf_counter() - start
            print('epoch: {}, {} lstm cost time: {}'.format(i, lstm_type, elapsed))
if __name__ == '__main__':
    # Echo the benchmark configuration so the timings below can be interpreted.
    config_banner = (
        "target problem: seq_len = {}, batch_size = {}, input_size = {}, hidden_size = {}, num_layers = {}, norm_type = {}, dropout = {}"  # noqa
    ).format(seq_len, batch_size, input_size, hidden_size, num_layers, norm_type, dropout)
    print(config_banner)

    # Validation is disabled; the banner says so. To run it, uncomment the two
    # lines below.
    print("==============lstm has no validation test================")
    # print("===============run lstm validation test==================")
    # lstm_val()

    print("===============run lstm performance test=================")
    lstm_perf()