"""Validate and benchmark the HPC GAE kernel against the reference implementation."""
import time

import torch

from hpc_rll.origin.gae import gae, gae_data
from hpc_rll.rl_utils.gae import GAE
from testbase import mean_relative_error, times

# This script only exercises the CUDA path; fail fast with an explicit error
# (an `assert` would be stripped when running under `python -O`).
if not torch.cuda.is_available():
    raise RuntimeError("CUDA is required to run the GAE validation/performance tests")
use_cuda = True

# Problem size passed to GAE(T, B); `value` is built with T + 1 rows and
# `reward` with T rows (presumably time steps x batch — confirm with hpc_rll).
T = 1024
B = 64


def _make_inputs():
    """Build random inputs and an HPC ``GAE`` module for the current problem size.

    Returns:
        Tuple ``(value, reward, hpc_gae)`` where ``value`` is a ``(T + 1, B)``
        tensor and ``reward`` a ``(T, B)`` tensor drawn from ``torch.randn``.
        All three are moved to the GPU when ``use_cuda`` is set.
    """
    # Keep the original creation order so the torch RNG stream is consumed identically.
    value = torch.randn(T + 1, B)
    reward = torch.randn(T, B)
    hpc_gae = GAE(T, B)
    if use_cuda:
        value = value.cuda()
        reward = reward.cuda()
        hpc_gae = hpc_gae.cuda()
    return value, reward, hpc_gae


def _sync():
    """Block until all queued CUDA work has finished (no-op when not using CUDA)."""
    if use_cuda:
        torch.cuda.synchronize()


def _time_call(fn):
    """Return the elapsed seconds for ``fn()``, including any GPU work it queues."""
    # Drain previously queued work (e.g. host->device copies) so it is not
    # attributed to ``fn``.
    _sync()
    start = time.perf_counter()
    fn()
    # CUDA kernel launches return before the kernel finishes; wait for
    # completion before stopping the clock.
    _sync()
    return time.perf_counter() - start


def gae_val():
    """Compare HPC GAE output with the reference ``gae`` and print the mean relative error."""
    value, reward, hpc_gae = _make_inputs()
    ori_adv = gae(gae_data(value, reward))
    hpc_adv = hpc_gae(value, reward)
    _sync()
    mre = mean_relative_error(
        torch.flatten(ori_adv).cpu().detach().numpy(),
        torch.flatten(hpc_adv).cpu().detach().numpy(),
    )
    print("gae mean_relative_error: " + str(mre))


def gae_perf():
    """Time ``times`` runs each of the reference ``gae`` and the HPC ``GAE`` and print per-epoch costs."""
    value, reward, hpc_gae = _make_inputs()
    for i in range(times):
        elapsed = _time_call(lambda: gae(gae_data(value, reward)))
        print('epoch: {}, original gae cost time: {}'.format(i, elapsed))
    for i in range(times):
        elapsed = _time_call(lambda: hpc_gae(value, reward))
        print('epoch: {}, hpc gae cost time: {}'.format(i, elapsed))


if __name__ == '__main__':
    print("target problem: T = {}, B = {}".format(T, B))
    print("================run gae validation test================")
    gae_val()
    print("================run gae performance test================")
    gae_perf()