Spaces:
Runtime error
Runtime error
// Prototype of index_max_cuda_kernel; the definition is not part of this
// declaration and must have exactly this parameter list.
// Tensor shapes in the per-parameter comments come from the original
// annotations. The kernel's exact semantics and required grid/block layout are
// not visible here.
// NOTE(review): the name suggests it reduces index_vals to per-block maxima and
// writes selected block indices into `indices` — confirm against the definition.
// NOTE(review): the literal 32 in the shapes looks like a fixed tile width — TODO confirm.
// NOTE(review): sizes are `long`, which is 32-bit on LLP64 (Windows) hosts; kept
// as-is because changing it would break the match with the definition.
__global__ void index_max_cuda_kernel(
  float *index_vals,       // [batch_size, 32, num_block]
  int *indices,            // [batch_size, num_block]
  float *max_vals,         // [batch_size, A_num_block * 32]
  float *max_vals_scatter, // [batch_size, 32, num_block]
  long batch_size,
  long A_num_block,
  long B_num_block,
  long num_block
);
// Prototype of mm_to_sparse_cuda_kernel; the definition is not part of this
// declaration and must have exactly this parameter list.
// Tensor shapes in the per-parameter comments come from the original annotations.
// NOTE(review): from the name and shapes, this presumably multiplies 32-wide
// blocks of dense_A and dense_B, selected via `indices`, into the 32x32 tiles of
// sparse_C — confirm semantics and launch configuration against the definition.
// NOTE(review): sizes are `long`, which is 32-bit on LLP64 (Windows) hosts; kept
// as-is because changing it would break the match with the definition.
__global__ void mm_to_sparse_cuda_kernel(
  float *dense_A,  // [batch_size, A_num_block, dim, 32]
  float *dense_B,  // [batch_size, B_num_block, dim, 32]
  int *indices,    // [batch_size, num_block]
  float *sparse_C, // [batch_size, num_block, 32, 32]
  long batch_size,
  long A_num_block,
  long B_num_block,
  long dim,
  long num_block
);
// Prototype of sparse_dense_mm_cuda_kernel; the definition is not part of this
// declaration and must have exactly this parameter list.
// Tensor shapes in the per-parameter comments come from the original annotations.
// NOTE(review): from the name and shapes, this presumably multiplies the 32x32
// tiles of sparse_A by blocks of dense_B (selected via `indices`) and writes
// dense_C — confirm whether dense_C is overwritten or accumulated into, and the
// launch configuration, against the definition.
// NOTE(review): sizes are `long`, which is 32-bit on LLP64 (Windows) hosts; kept
// as-is because changing it would break the match with the definition.
__global__ void sparse_dense_mm_cuda_kernel(
  float *sparse_A, // [batch_size, num_block, 32, 32]
  int *indices,    // [batch_size, num_block]
  float *dense_B,  // [batch_size, B_num_block, dim, 32]
  float *dense_C,  // [batch_size, A_num_block, dim, 32]
  long batch_size,
  long A_num_block,
  long B_num_block,
  long dim,
  long num_block
);
// Prototype of reduce_sum_cuda_kernel; the definition is not part of this
// declaration and must have exactly this parameter list.
// Tensor shapes in the per-parameter comments come from the original annotations.
// NOTE(review): from the name and shapes, this presumably sums the 32x32 tiles of
// sparse_A along one axis into dense_C rows chosen via `indices` — confirm the
// reduction axis, overwrite-vs-accumulate behavior, and launch configuration
// against the definition.
// NOTE(review): sizes are `long`, which is 32-bit on LLP64 (Windows) hosts; kept
// as-is because changing it would break the match with the definition.
__global__ void reduce_sum_cuda_kernel(
  float *sparse_A, // [batch_size, num_block, 32, 32]
  int *indices,    // [batch_size, num_block]
  float *dense_C,  // [batch_size, A_num_block, 32]
  long batch_size,
  long A_num_block,
  long B_num_block,
  long num_block
);
// Prototype of scatter_cuda_kernel; the definition is not part of this
// declaration and must have exactly this parameter list.
// Tensor shapes in the per-parameter comments come from the original annotations.
// NOTE(review): from the name and shapes, this presumably copies 32-wide rows of
// dense_A (selected via `indices`) into the 32x32 tiles of sparse_C — confirm the
// exact mapping and launch configuration against the definition.
// NOTE(review): sizes are `long`, which is 32-bit on LLP64 (Windows) hosts; kept
// as-is because changing it would break the match with the definition.
__global__ void scatter_cuda_kernel(
  float *dense_A,  // [batch_size, A_num_block, 32]
  int *indices,    // [batch_size, num_block]
  float *sparse_C, // [batch_size, num_block, 32, 32]
  long batch_size,
  long A_num_block,
  long B_num_block,
  long num_block
);