File size: 4,564 Bytes
2ded60b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 |
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#pragma OPENCL EXTENSION cl_khr_extended_async_copies : enable
__kernel void Convolution1x1_NHWC(
const __global half *in,
const __global half *out,
const __global half *w,
int IW,
int IH,
int IC,
int OW,
int OH,
int OC)
{
__local half in_local[8 * 1024];
__local half out_local[8 * 1024];
const int sizeAct = IW * IC;
event_t e1 = async_work_group_copy(in_local, in + get_group_id(0) * sizeAct, sizeAct, 0);
wait_group_events(1, &e1);
int oh = get_global_id(0);
int oc = get_global_id(1);
int stride;
int write_output = 0;
__global half *src;
__global half8 *w8 = (__global half8 *)(&w[oc * IC]);
__global half *w1 = (__global half *)(&w[oc * IC]);
for (uint ow = 0; ow < (OW & (~0x7)); ow += 8) {
uint iw = ow;
uint ih = oh;
half8 val8_0 = 0.0f;
half8 val8_1 = 0.0f;
half8 val8_2 = 0.0f;
half8 val8_3 = 0.0f;
half8 val8_4 = 0.0f;
half8 val8_5 = 0.0f;
half8 val8_6 = 0.0f;
half8 val8_7 = 0.0f;
__local half8 *in8_0 = (__local half8 *)(&in_local[(iw + 0) * IC]);
__local half8 *in8_1 = (__local half8 *)(&in_local[(iw + 1) * IC]);
__local half8 *in8_2 = (__local half8 *)(&in_local[(iw + 2) * IC]);
__local half8 *in8_3 = (__local half8 *)(&in_local[(iw + 3) * IC]);
__local half8 *in8_4 = (__local half8 *)(&in_local[(iw + 4) * IC]);
__local half8 *in8_5 = (__local half8 *)(&in_local[(iw + 5) * IC]);
__local half8 *in8_6 = (__local half8 *)(&in_local[(iw + 6) * IC]);
__local half8 *in8_7 = (__local half8 *)(&in_local[(iw + 7) * IC]);
for (uint ic = 0; ic < IC / 8; ++ic) {
val8_0 += (in8_0[ic]) * (w8[ic]);
val8_1 += (in8_1[ic]) * (w8[ic]);
val8_2 += (in8_2[ic]) * (w8[ic]);
val8_3 += (in8_3[ic]) * (w8[ic]);
val8_4 += (in8_4[ic]) * (w8[ic]);
val8_5 += (in8_5[ic]) * (w8[ic]);
val8_6 += (in8_6[ic]) * (w8[ic]);
val8_7 += (in8_7[ic]) * (w8[ic]);
}
half val_0 = 0.0f;
half val_1 = 0.0f;
half val_2 = 0.0f;
half val_3 = 0.0f;
half val_4 = 0.0f;
half val_5 = 0.0f;
half val_6 = 0.0f;
half val_7 = 0.0f;
for (uint ic = IC & (~0x7); ic < IC; ++ic) {
val_0 += *((__local half *)in8_0 + ic) * (*((__global half *)w8 + ic));
val_1 += *((__local half *)in8_1 + ic) * (*((__global half *)w8 + ic));
val_2 += *((__local half *)in8_2 + ic) * (*((__global half *)w8 + ic));
val_3 += *((__local half *)in8_3 + ic) * (*((__global half *)w8 + ic));
val_4 += *((__local half *)in8_4 + ic) * (*((__global half *)w8 + ic));
val_5 += *((__local half *)in8_5 + ic) * (*((__global half *)w8 + ic));
val_6 += *((__local half *)in8_6 + ic) * (*((__global half *)w8 + ic));
val_7 += *((__local half *)in8_7 + ic) * (*((__global half *)w8 + ic));
}
out_local[ow + 0] = __builtin_shave_sau_sumx_f16_r(val8_0) + val_0;
out_local[ow + 1] = __builtin_shave_sau_sumx_f16_r(val8_1) + val_1;
out_local[ow + 2] = __builtin_shave_sau_sumx_f16_r(val8_2) + val_2;
out_local[ow + 3] = __builtin_shave_sau_sumx_f16_r(val8_3) + val_3;
out_local[ow + 4] = __builtin_shave_sau_sumx_f16_r(val8_4) + val_4;
out_local[ow + 5] = __builtin_shave_sau_sumx_f16_r(val8_5) + val_5;
out_local[ow + 6] = __builtin_shave_sau_sumx_f16_r(val8_6) + val_6;
out_local[ow + 7] = __builtin_shave_sau_sumx_f16_r(val8_7) + val_7;
}
for (uint ow = (OW & (~0x7)); ow < OW; ow++) {
uint iw = ow;
uint ih = oh;
half8 val8 = 0.0f;
__local half8 *in8 = (__local half8 *)(&in_local[iw * IC]);
for (uint ic = 0; ic < IC / 8; ++ic) {
val8 += (in8[ic]) * (w8[ic]);
}
half val = 0.0f;
for (uint ic = (IC & (~0x7)); ic < IC; ++ic) {
val += (*((__local half *)in8 + ic)) * (*((__global half *)w8 + ic));
}
out_local[ow] = __builtin_shave_sau_sumx_f16_r(val8) + val;
}
barrier(CLK_LOCAL_MEM_FENCE);
event_t e2 = async_work_group_copy(
out + get_group_id(1) * OW * OH + get_group_id(0) * OW,
out_local,
OW,
0);
wait_group_events(1, &e2);
}
|