|
|
|
|
|
|
|
|
|
#pragma OPENCL EXTENSION cl_khr_fp16 : enable |
|
#pragma OPENCL EXTENSION cl_khr_extended_async_copies : enable |
|
|
|
/**
 * Linear search for `value` in the half-precision range [begin, end).
 *
 * @param begin  first element to inspect
 * @param end    one past the last element to inspect
 * @param value  value to compare each element against with `==`
 * @return pointer to the first element equal to `value`, or `end` when no
 *         element matches. An empty or inverted range (begin >= end) is
 *         treated as "no match": nothing is dereferenced and `end` is returned.
 *
 * The inputs are const-qualified while the result is not (the same shape as
 * C's strchr), so the qualifier is dropped with an explicit cast rather than
 * an implicit, diagnostic-triggering conversion.
 */
__global half *find(__global const half *begin, __global const half *end, half value)
{
    // `<` instead of `!=`: an inverted range must not walk off the buffer.
    for (__global const half *it = begin; it < end; ++it) {
        if (*it == value) {
            return (__global half *)it;
        }
    }

    return (__global half *)end;
}
|
|
|
/**
 * Greedy (arg-max) CTC decoder.
 *
 * For every batch row and every time step inside the sequence, picks the class
 * with the highest probability and appends it to that batch's output row
 * unless it is the last class (index C - 1, presumably the CTC blank label)
 * or a repeat of the previous step's class. Unused output slots hold -1.
 *
 * @param probabilities        input scores, indexed below as [t][b][c]
 *                             (t * C * B + b * C + c)
 * @param sequence_indicators  per-batch indicators, indexed as [b][t]; the
 *                             sequence length is the position of the first
 *                             zero found after element 0 (or T if none)
 * @param output               decoded class indices, written as [b][t]
 * @param width                number of classes (C)
 * @param height               number of batches (B)
 * @param channels             number of time steps (T)
 *
 * No work-item id is consulted, so every work-item in the group performs the
 * same computation on the shared local buffers.
 */
__kernel void CTCDecoder(
    __global half *restrict probabilities,
    __global half *restrict sequence_indicators,
    __global half *restrict output,
    int width,
    int height,
    int channels)
{
    // Fixed-size local staging buffers.
    // NOTE(review): nothing in this kernel checks that width/height/channels
    // fit these capacities (the fill loop below writes B * T elements into an
    // 88-element buffer) - the host side must guarantee it. The "* 1" factor
    // suggests the sizes were chosen for a single batch; confirm with caller.
    __local half local_src[88 * 1 * 77];
    __local half local_dst[88 * 1];

    // Stage the probabilities in local memory.
    // NOTE(review): async_work_group_copy_2D2D is used here in a 7-argument
    // vendor form; the arguments are presumed to be (dst, src,
    // elements_per_line, num_lines, src_line_stride, dst_line_stride, event)
    // - confirm against the target compiler's documentation.
    event_t e1 = async_work_group_copy_2D2D(
        local_src,
        probabilities,
        width,
        height * channels,
        width * (height - 1),
        width * (height - 1),
        0);

    wait_group_events(1, &e1);

    const int T = channels;  // time steps
    const int B = height;    // batches
    const int C = width;     // classes

    // Pre-fill the whole output with the "no label" marker.
    #pragma unroll 4
    for (int i = 0; i < B * T; i++) {
        local_dst[i] = -1.h;
    }

    for (int b = 0; b < B; ++b) {
        __global const half *restrict seq_ind = sequence_indicators + b * T;

        // Element 0 is skipped: the search for the terminating zero starts at
        // index 1, and falls back to T when no zero is present.
        const int seq_len = find(seq_ind + 1, seq_ind + T, 0.h) - seq_ind;
        const int time    = min(seq_len, T);

        int prev_class_idx = -1;

        // Write position inside this batch's output row. It must restart at 0
        // for every batch: the destination index already adds b * T, so a
        // counter carried over from the previous batch would misplace labels.
        int output_index = 0;

        #pragma unroll 4
        for (int t = 0; t < time; ++t) {
            __local const half *restrict probs = local_src + b * C + t * C * B;

            // Arg-max over the classes; ties keep the lowest index.
            int max_class_idx = 0;
            half max_prob     = probs[0];
            for (int c = 1; c < C; ++c) {
                const half prob = probs[c];
                if (prob > max_prob) {
                    max_class_idx = c;
                    max_prob      = prob;
                }
            }

            // Emit only non-last classes that differ from the previous step.
            if (max_class_idx < C - 1 && max_class_idx != prev_class_idx) {
                local_dst[b * T + output_index] = (half)max_class_idx;
                output_index++;
            }

            prev_class_idx = max_class_idx;
        }
    }

    // Make the local_dst writes visible to the whole work-group before the
    // copy-out is issued.
    barrier(CLK_LOCAL_MEM_FENCE);

    // Copy the decoded rows to global memory (same presumed argument order as
    // the copy-in above; both strides are 0).
    event_t e2 = async_work_group_copy_2D2D(
        output,
        local_dst,
        channels,
        height,
        0,
        0,
        0);

    wait_group_events(1, &e2);
}
|
|