File size: 4,696 Bytes
079c32c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import time
import torch
from typing import Tuple
from hpc_rll.origin.scatter_connection import ScatterConnection
from hpc_rll.torch_utils.network.scatter_connection import ScatterConnection as HPCScatterConnection
from testbase import mean_relative_error, times

# The HPC implementation under test runs CUDA kernels, so a GPU is mandatory.
# An explicit raise (rather than `assert`) keeps this check active even when
# Python runs with `-O`, which strips assert statements.
if not torch.cuda.is_available():
    raise RuntimeError("scatter_connection test requires a CUDA-capable GPU")
use_cuda = True

# Problem size: the input features are shaped (B, M, N), and each of the M rows
# per batch item is scattered to an (h, w) cell of an H x W grid.
B = 256
M = 256
N = 256
H = 16
W = 16


# Note: the original GPU version of cover mode is not deterministic, so the validation test uses the original CPU version instead
def scatter_val():
    """Validate the HPC ScatterConnection against the original implementation.

    For each scatter type ('add' and 'cover'), a random (B, M, N) feature
    tensor and random (h, w) locations are pushed through both
    implementations. A mean-squared loss is back-propagated through each,
    and the mean relative error is printed for the loss (forward pass) and
    for the input gradient (backward pass).
    """
    def as_flat_numpy(tensor):
        # Flatten and move to host memory so mean_relative_error sees numpy.
        return torch.flatten(tensor).cpu().detach().numpy()

    for mode in ['add', 'cover']:
        features = torch.randn(B, M, N)
        row_idx = torch.randint(low=0, high=H, size=(B, M)).unsqueeze(dim=2)
        col_idx = torch.randint(low=0, high=W, size=(B, M)).unsqueeze(dim=2)
        coords = torch.cat([row_idx, col_idx], dim=2)
        reference = ScatterConnection(mode)

        fast_features = features.clone().detach()
        fast_coords = coords.clone().detach()
        fast = HPCScatterConnection(B, M, N, H, W, mode)

        # Only the HPC side moves to the GPU. The reference deliberately stays
        # on the CPU because its GPU cover mode is not deterministic.
        if use_cuda:
            fast_features = fast_features.cuda()
            fast_coords = fast_coords.cuda()
            fast = fast.cuda()

        # Reference forward + backward.
        features.requires_grad_(True)
        ref_out = reference(features, (H, W), coords)
        ref_loss = (ref_out * ref_out).mean()
        ref_loss.backward()
        if use_cuda:
            torch.cuda.synchronize()

        # HPC forward + backward.
        fast_features.requires_grad_(True)
        fast_out = fast(fast_features, fast_coords)
        fast_loss = (fast_out * fast_out).mean()
        fast_loss.backward()
        if use_cuda:
            torch.cuda.synchronize()

        fp_err = mean_relative_error(as_flat_numpy(ref_loss), as_flat_numpy(fast_loss))
        print("scatter type {} fp mean_relative_error: {}".format(mode, str(fp_err)))
        bp_err = mean_relative_error(as_flat_numpy(features.grad), as_flat_numpy(fast_features.grad))
        print("scatter type {} bp mean_relative_error: {}".format(mode, str(bp_err)))


# Note: the performance test uses the original GPU version
def scatter_perf():
    """Benchmark the original and HPC ScatterConnection implementations.

    For each scatter type ('add' and 'cover'), both implementations receive
    the same random (B, M, N) input and (h, w) locations. Each then runs
    ``times`` forward + backward iterations of a mean-squared loss, and the
    elapsed time of every iteration is printed. Unlike ``scatter_val``, the
    original implementation is also moved to the GPU when ``use_cuda`` is set.
    """
    for scatter_type in ['add', 'cover']:
        ori_input = torch.randn(B, M, N)
        h = torch.randint(low=0, high=H, size=(B, M)).unsqueeze(dim=2)
        w = torch.randint(low=0, high=W, size=(B, M)).unsqueeze(dim=2)
        ori_location = torch.cat([h, w], dim=2)
        ori_scatter = ScatterConnection(scatter_type)

        hpc_input = ori_input.clone().detach()
        hpc_location = ori_location.clone().detach()
        hpc_scatter = HPCScatterConnection(B, M, N, H, W, scatter_type)

        if use_cuda:
            ori_input = ori_input.cuda()
            ori_location = ori_location.cuda()
            ori_scatter = ori_scatter.cuda()

            hpc_input = hpc_input.cuda()
            hpc_location = hpc_location.cuda()
            hpc_scatter = hpc_scatter.cuda()

        # Both inputs are leaf tensors, so enabling grad once (instead of on
        # every iteration) is enough. Gradients accumulate into .grad across
        # iterations on both sides alike, which keeps the comparison fair.
        ori_input.requires_grad_(True)
        hpc_input.requires_grad_(True)

        # Each closure is called within this same scatter_type iteration, so
        # it always sees the tensors and modules built above.
        runs = (
            ('original', lambda: ori_scatter(ori_input, (H, W), ori_location)),
            ('hpc', lambda: hpc_scatter(hpc_input, hpc_location)),
        )
        for label, forward in runs:
            for i in range(times):
                # Drain any previously queued GPU work (e.g. the host-to-device
                # copies above) so it is not charged to this iteration.
                if use_cuda:
                    torch.cuda.synchronize()
                # perf_counter is monotonic and high-resolution. time.time() is
                # neither, and it can jump if the system clock is adjusted.
                start = time.perf_counter()
                output = forward()
                loss = (output * output).mean()
                loss.backward()
                if use_cuda:
                    # Kernels launch asynchronously; wait for them before
                    # reading the clock.
                    torch.cuda.synchronize()
                elapsed = time.perf_counter() - start
                print('epoch: {}, {} scatter type {} cost time: {}'.format(i, label, scatter_type, elapsed))


if __name__ == '__main__':
    # Report the problem size, then run validation followed by benchmarking,
    # each preceded by its banner.
    print("target problem: B = {}, M = {}, N = {}, H = {}, W = {}".format(B, M, N, H, W))
    phases = (
        ("================run scatter validation test================", scatter_val),
        ("================run scatter performance test================", scatter_perf),
    )
    for banner, run_phase in phases:
        print(banner)
        run_phase()