File size: 4,696 Bytes
079c32c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 |
import time
import torch
from typing import Tuple
from hpc_rll.origin.scatter_connection import ScatterConnection
from hpc_rll.torch_utils.network.scatter_connection import ScatterConnection as HPCScatterConnection
from testbase import mean_relative_error, times
# The HPC ScatterConnection is run on CUDA by the tests below, so a GPU is required.
assert torch.cuda.is_available()
use_cuda = True

# Problem size shared by both tests: inputs are (B, M, N), and each of the M
# entries per batch is given an (h, w) location on an H x W grid.
B, M, N = 256, 256, 256
H, W = 16, 16
# Note: the original (origin) GPU implementation of cover mode is not deterministic, so the validation test uses the original CPU implementation instead
def scatter_val():
    """Compare the HPC ScatterConnection against the original implementation.

    For each scatter type ('add' and 'cover'), build a random input of shape
    (B, M, N) and random integer locations of shape (B, M, 2). The last axis
    of the locations holds an h index in [0, H) and a w index in [0, W).

    Both implementations run forward and backward on the same loss, the mean
    of the squared output. The mean relative error of the loss ("fp") and of
    the input gradient ("bp") is printed for each type. Nothing is asserted;
    this is a report only.

    The original implementation stays on the CPU because its GPU 'cover' mode
    is not deterministic. Only the HPC side is moved to CUDA, and only when
    ``use_cuda`` is set.
    """
    for scatter_type in ['add', 'cover']:
        ori_input = torch.randn(B, M, N)
        # h is drawn before w (same RNG order as before), giving (B, M, 2) locations.
        ori_location = torch.stack(
            [
                torch.randint(low=0, high=H, size=(B, M)),
                torch.randint(low=0, high=W, size=(B, M)),
            ],
            dim=2,
        )
        ori_scatter = ScatterConnection(scatter_type)

        # The HPC side works on independent copies of the same data.
        hpc_input = ori_input.clone().detach()
        hpc_location = ori_location.clone().detach()
        hpc_scatter = HPCScatterConnection(B, M, N, H, W, scatter_type)
        if use_cuda:
            hpc_input = hpc_input.cuda()
            hpc_location = hpc_location.cuda()
            hpc_scatter = hpc_scatter.cuda()

        # Reference forward/backward (kept on the CPU, see docstring).
        ori_input.requires_grad_(True)
        ori_output = ori_scatter(ori_input, (H, W), ori_location)
        ori_loss = (ori_output * ori_output).mean()
        ori_loss.backward()
        if use_cuda:
            torch.cuda.synchronize()

        hpc_input.requires_grad_(True)
        hpc_output = hpc_scatter(hpc_input, hpc_location)
        hpc_loss = (hpc_output * hpc_output).mean()
        hpc_loss.backward()
        if use_cuda:
            torch.cuda.synchronize()

        # NOTE(review): "fp" compares only the scalar loss, not the full output tensor.
        mre = mean_relative_error(
            torch.flatten(ori_loss).cpu().detach().numpy(),
            torch.flatten(hpc_loss).cpu().detach().numpy()
        )
        print("scatter type {} fp mean_relative_error: {}".format(scatter_type, str(mre)))
        mre = mean_relative_error(
            torch.flatten(ori_input.grad).cpu().detach().numpy(),
            torch.flatten(hpc_input.grad).cpu().detach().numpy()
        )
        print("scatter type {} bp mean_relative_error: {}".format(scatter_type, str(mre)))
# Note: the performance test runs the original (origin) implementation on the GPU
def scatter_perf():
    """Time forward + backward of the original and HPC ScatterConnection.

    For each scatter type ('add' and 'cover'), both implementations are moved
    to CUDA when ``use_cuda`` is set. Each one runs ``times`` iterations, and
    the cost of every iteration is printed. An iteration computes the mean of
    the squared output and back-propagates it to the input.
    """
    for scatter_type in ['add', 'cover']:
        ori_input = torch.randn(B, M, N)
        # h is drawn before w (same RNG order as before), giving (B, M, 2) locations.
        ori_location = torch.stack(
            [
                torch.randint(low=0, high=H, size=(B, M)),
                torch.randint(low=0, high=W, size=(B, M)),
            ],
            dim=2,
        )
        ori_scatter = ScatterConnection(scatter_type)
        hpc_input = ori_input.clone().detach()
        hpc_location = ori_location.clone().detach()
        hpc_scatter = HPCScatterConnection(B, M, N, H, W, scatter_type)
        if use_cuda:
            ori_input = ori_input.cuda()
            ori_location = ori_location.cuda()
            ori_scatter = ori_scatter.cuda()
            hpc_input = hpc_input.cuda()
            hpc_location = hpc_location.cuda()
            hpc_scatter = hpc_scatter.cuda()

        # Loop-invariant: enable gradients once, outside the timed region.
        ori_input.requires_grad_(True)
        hpc_input.requires_grad_(True)
        if use_cuda:
            # Keep any pending setup work from being billed to iteration 0.
            torch.cuda.synchronize()

        for i in range(times):
            # Clear the previous gradient so every iteration does the same work;
            # otherwise backward() also accumulates into an existing .grad.
            ori_input.grad = None
            t = time.perf_counter()
            ori_output = ori_scatter(ori_input, (H, W), ori_location)
            ori_loss = (ori_output * ori_output).mean()
            ori_loss.backward()
            if use_cuda:
                torch.cuda.synchronize()
            print('epoch: {}, original scatter type {} cost time: {}'.format(i, scatter_type, time.perf_counter() - t))

        for i in range(times):
            hpc_input.grad = None
            t = time.perf_counter()
            hpc_output = hpc_scatter(hpc_input, hpc_location)
            hpc_loss = (hpc_output * hpc_output).mean()
            hpc_loss.backward()
            if use_cuda:
                torch.cuda.synchronize()
            print('epoch: {}, hpc scatter type {} cost time: {}'.format(i, scatter_type, time.perf_counter() - t))
if __name__ == '__main__':
    print("target problem: B = {}, M = {}, N = {}, H = {}, W = {}".format(B, M, N, H, W))
    # Validation runs first, then benchmarking, each preceded by its banner.
    for banner, run_suite in (
        ("================run scatter validation test================", scatter_val),
        ("================run scatter performance test================", scatter_perf),
    ):
        print(banner)
        run_suite()
|