gomoku / DI-engine /ding /hpc_rl /tests /test_scatter.py
zjowowen's picture
init space
079c32c
raw
history blame
4.7 kB
import time
import torch
from typing import Tuple
from hpc_rll.origin.scatter_connection import ScatterConnection
from hpc_rll.torch_utils.network.scatter_connection import ScatterConnection as HPCScatterConnection
from testbase import mean_relative_error, times
# The HPC scatter connection is moved to the GPU below, so refuse to run without CUDA.
assert torch.cuda.is_available()
# When True, tensors/modules are moved to the GPU and CUDA is synchronized after each backward pass.
use_cuda = True
# Problem size: the random input tensor has shape (B, M, N).
B = 256
M = 256
N = 256
# Location indices are drawn from [0, H) and [0, W); (H, W) is also the
# spatial size handed to both scatter implementations.
H = 16
W = 16
# Note: the original GPU implementation of cover mode is non-deterministic, so the validation test uses the original CPU implementation instead.
def scatter_val():
    """Validate the HPC scatter connection against the reference implementation.

    For each scatter type ('add', 'cover') the same random input and random
    (h, w) locations are fed to both implementations, a mean-of-squares loss
    is back-propagated through each, and the mean relative error of the loss
    ("fp") and of the input gradient ("bp") is printed.

    The reference module and its tensors stay on the CPU; only the HPC side
    is moved to the GPU when ``use_cuda`` is set.
    """

    def _to_flat_numpy(tensor):
        # Flatten and copy to host memory so both sides compare as NumPy arrays.
        return torch.flatten(tensor).cpu().detach().numpy()

    for mode in ('add', 'cover'):
        # Shared random problem: a (B, M, N) input plus one (h, w) index pair
        # for every (B, M) entry.
        ref_input = torch.randn(B, M, N)
        h_idx = torch.randint(low=0, high=H, size=(B, M)).unsqueeze(dim=2)
        w_idx = torch.randint(low=0, high=W, size=(B, M)).unsqueeze(dim=2)
        ref_location = torch.cat([h_idx, w_idx], dim=2)
        ref_scatter = ScatterConnection(mode)

        # Independent copies for the HPC side so the two autograd graphs are separate.
        fast_input = ref_input.clone().detach()
        fast_location = ref_location.clone().detach()
        fast_scatter = HPCScatterConnection(B, M, N, H, W, mode)

        if use_cuda:
            # Only the HPC side goes to the GPU; the reference is deliberately
            # kept on the CPU because its GPU cover mode is non-deterministic.
            fast_input = fast_input.cuda()
            fast_location = fast_location.cuda()
            fast_scatter = fast_scatter.cuda()

        # Reference forward/backward pass.
        ref_input.requires_grad_(True)
        ref_output = ref_scatter(ref_input, (H, W), ref_location)
        ref_loss = (ref_output * ref_output).mean()
        ref_loss.backward()
        if use_cuda:
            torch.cuda.synchronize()

        # HPC forward/backward pass.
        fast_input.requires_grad_(True)
        fast_output = fast_scatter(fast_input, fast_location)
        fast_loss = (fast_output * fast_output).mean()
        fast_loss.backward()
        if use_cuda:
            torch.cuda.synchronize()

        # Forward check: compare the scalar losses.
        mre = mean_relative_error(_to_flat_numpy(ref_loss), _to_flat_numpy(fast_loss))
        print("scatter type {} fp mean_relative_error: {}".format(mode, str(mre)))

        # Backward check: compare the gradients with respect to the input.
        mre = mean_relative_error(_to_flat_numpy(ref_input.grad), _to_flat_numpy(fast_input.grad))
        print("scatter type {} bp mean_relative_error: {}".format(mode, str(mre)))
# Note: the performance test uses the original GPU implementation.
def scatter_perf():
    """Benchmark the reference and HPC scatter connections.

    For each scatter type ('add', 'cover') a random (B, M, N) input and
    random (h, w) locations are generated and, when ``use_cuda`` is set, both
    implementations and their tensors are moved to the GPU. Each
    implementation then runs ``times`` forward/backward iterations with a
    mean-of-squares loss, and the elapsed time of every iteration is printed.

    Durations are measured with ``time.perf_counter()``, a monotonic
    high-resolution clock, rather than the wall clock, which may jump when
    the system time is adjusted.
    """
    for scatter_type in ['add', 'cover']:
        ori_input = torch.randn(B, M, N)
        h = torch.randint(low=0, high=H, size=(B, M)).unsqueeze(dim=2)
        w = torch.randint(low=0, high=W, size=(B, M)).unsqueeze(dim=2)
        ori_location = torch.cat([h, w], dim=2)
        ori_scatter = ScatterConnection(scatter_type)

        # Separate copies so each implementation owns its own autograd leaf.
        hpc_input = ori_input.clone().detach()
        hpc_location = ori_location.clone().detach()
        hpc_scatter = HPCScatterConnection(B, M, N, H, W, scatter_type)

        if use_cuda:
            ori_input = ori_input.cuda()
            ori_location = ori_location.cuda()
            ori_scatter = ori_scatter.cuda()

            hpc_input = hpc_input.cuda()
            hpc_location = hpc_location.cuda()
            hpc_scatter = hpc_scatter.cuda()

        # Time the reference implementation.
        for i in range(times):
            t = time.perf_counter()
            ori_input.requires_grad_(True)
            ori_output = ori_scatter(ori_input, (H, W), ori_location)
            ori_loss = ori_output * ori_output
            ori_loss = ori_loss.mean()
            ori_loss.backward()
            if use_cuda:
                # Wait for queued GPU work so the interval covers the whole iteration.
                torch.cuda.synchronize()
            print(
                'epoch: {}, original scatter type {} cost time: {}'.format(
                    i, scatter_type, time.perf_counter() - t
                )
            )

        # Time the HPC implementation.
        for i in range(times):
            t = time.perf_counter()
            hpc_input.requires_grad_(True)
            hpc_output = hpc_scatter(hpc_input, hpc_location)
            hpc_loss = hpc_output * hpc_output
            hpc_loss = hpc_loss.mean()
            hpc_loss.backward()
            if use_cuda:
                # Wait for queued GPU work so the interval covers the whole iteration.
                torch.cuda.synchronize()
            print(
                'epoch: {}, hpc scatter type {} cost time: {}'.format(
                    i, scatter_type, time.perf_counter() - t
                )
            )
if __name__ == '__main__':
    # Report the problem size once, then run each stage under its banner.
    problem_summary = "target problem: B = {}, M = {}, N = {}, H = {}, W = {}".format(B, M, N, H, W)
    print(problem_summary)
    stages = (
        ("================run scatter validation test================", scatter_val),
        ("================run scatter performance test================", scatter_perf),
    )
    for banner, run_stage in stages:
        print(banner)
        run_stage()