#include <stdint.h>

#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

#include <ATen/cuda/CUDAContext.h>
#include <torch/torch.h>

#include <algorithm>
#include <stdexcept>
#include <cstdio>

#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor")
#define CHECK_IS_INT(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, #x " must be an int tensor")
#define CHECK_IS_FLOATING(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Float || x.scalar_type() == at::ScalarType::Half || x.scalar_type() == at::ScalarType::Double, #x " must be a floating tensor")


// just for compatibility of half precision in AT_DISPATCH_FLOATING_TYPES_AND_HALF... program will never reach here!
__device__ inline at::Half atomicAdd(at::Half *address, at::Half val) {
    // requires CUDA >= 10 and ARCH >= 70
    // this is very slow compared to float or __half2, never use it.
    //return atomicAdd(reinterpret_cast<__half*>(address), val);
}


template <typename T>
__host__ __device__ inline T div_round_up(T val, T divisor) {
    return (val + divisor - 1) / divisor;
}

template <typename T, typename T2>
__host__ __device__ inline T clamp(const T v, const T2 lo, const T2 hi) {
    return min(max(v, lo), hi);
}

template <typename T>
__device__ inline T smoothstep(T val) {
    return val * val * (3.0f - 2.0f * val);
}

template <typename T>
__device__ inline T smoothstep_derivative(T val) {
    return 6 * val * (1.0f - val);
}


template <uint32_t D>
__device__ uint32_t fast_hash(const uint32_t pos_grid[D]) {

    // coherent type of hashing
    constexpr uint32_t primes[7] = { 1u, 2654435761u, 805459861u, 3674653429u, 2097192037u, 1434869437u, 2165219737u };

    uint32_t result = 0;
    #pragma unroll
    for (uint32_t i = 0; i < D; ++i) {
        result ^= pos_grid[i] * primes[i];
    }

    return result;
}
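// fast_hash is the Instant-NGP style spatial hash: each integer grid coordinate is
// multiplied by a large prime and the products are XOR-ed together. primes[0] == 1,
// so the first dimension enters the hash unmixed, which is intended to keep nearby
// queries cache-coherent. Illustrative example (D = 3, coordinates chosen arbitrarily):
//   pos_grid = {12, 7, 3}
//   result   = (12 * 1u) ^ (7 * 2654435761u) ^ (3 * 805459861u)
// get_grid_index below reduces this value modulo the level's hashmap_size.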
template <uint32_t D, uint32_t C>
__device__ uint32_t get_grid_index(const uint32_t gridtype, const bool align_corners, const uint32_t ch, const uint32_t hashmap_size, const uint32_t resolution, const uint32_t pos_grid[D]) {
    uint32_t stride = 1;
    uint32_t index = 0;

    #pragma unroll
    for (uint32_t d = 0; d < D && stride <= hashmap_size; d++) {
        index += pos_grid[d] * stride;
        stride *= align_corners ? resolution : (resolution + 1);
    }

    // NOTE: for NeRF, the hash is in fact not necessary. Check https://github.com/NVlabs/instant-ngp/issues/97.
    // gridtype: 0 == hash, 1 == tiled
    if (gridtype == 0 && stride > hashmap_size) {
        index = fast_hash<D>(pos_grid);
    }

    return (index % hashmap_size) * C + ch;
}


template <typename scalar_t, uint32_t D, uint32_t C>
__global__ void kernel_grid(
    const float * __restrict__ inputs,
    const scalar_t * __restrict__ grid,
    const int * __restrict__ offsets,
    scalar_t * __restrict__ outputs,
    const uint32_t B, const uint32_t L, const float S, const uint32_t H,
    scalar_t * __restrict__ dy_dx,
    const uint32_t gridtype,
    const bool align_corners,
    const uint32_t interp
) {
    const uint32_t b = blockIdx.x * blockDim.x + threadIdx.x;

    if (b >= B) return;

    const uint32_t level = blockIdx.y;

    // locate
    grid += (uint32_t)offsets[level] * C;
    inputs += b * D;
    outputs += level * B * C + b * C;

    // check input range (should be in [0, 1])
    bool flag_oob = false;
    #pragma unroll
    for (uint32_t d = 0; d < D; d++) {
        if (inputs[d] < 0 || inputs[d] > 1) {
            flag_oob = true;
        }
    }

    // if input is out of bounds, just set the output to 0
    if (flag_oob) {
        #pragma unroll
        for (uint32_t ch = 0; ch < C; ch++) {
            outputs[ch] = 0;
        }
        if (dy_dx) {
            dy_dx += b * D * L * C + level * D * C; // B L D C
            #pragma unroll
            for (uint32_t d = 0; d < D; d++) {
                #pragma unroll
                for (uint32_t ch = 0; ch < C; ch++) {
                    dy_dx[d * C + ch] = 0;
                }
            }
        }
        return;
    }

    const uint32_t hashmap_size = offsets[level + 1] - offsets[level];
    const float scale = exp2f(level * S) * H - 1.0f;
    const uint32_t resolution = (uint32_t)ceil(scale) + 1;

    // calculate coordinate (always use float for precision!)
    float pos[D];
    float pos_deriv[D];
    uint32_t pos_grid[D];

    #pragma unroll
    for (uint32_t d = 0; d < D; d++) {
        pos[d] = inputs[d] * scale + (align_corners ? 0.0f : 0.5f);
        pos_grid[d] = floorf(pos[d]);
        pos[d] -= (float)pos_grid[d];
        // smoothstep instead of linear
        if (interp == 1) {
            pos_deriv[d] = smoothstep_derivative(pos[d]);
            pos[d] = smoothstep(pos[d]);
        } else {
            pos_deriv[d] = 1.0f; // the derivative of linear interpolation is 1
        }
    }

    //printf("[b=%d, l=%d] pos=(%f, %f)+(%d, %d)\n", b, level, pos[0], pos[1], pos_grid[0], pos_grid[1]);

    // interpolate
    scalar_t results[C] = {0}; // temp results in registers

    #pragma unroll
    for (uint32_t idx = 0; idx < (1 << D); idx++) {
        float w = 1;
        uint32_t pos_grid_local[D];

        #pragma unroll
        for (uint32_t d = 0; d < D; d++) {
            if ((idx & (1 << d)) == 0) {
                w *= 1 - pos[d];
                pos_grid_local[d] = pos_grid[d];
            } else {
                w *= pos[d];
                pos_grid_local[d] = pos_grid[d] + 1;
            }
        }

        uint32_t index = get_grid_index<D, C>(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid_local);

        // writing to registers (fast)
        #pragma unroll
        for (uint32_t ch = 0; ch < C; ch++) {
            results[ch] += w * grid[index + ch];
        }

        //printf("[b=%d, l=%d] int %d, idx %d, w %f, val %f\n", b, level, idx, index, w, grid[index]);
    }

    // writing to global memory (slow)
    #pragma unroll
    for (uint32_t ch = 0; ch < C; ch++) {
        outputs[ch] = results[ch];
    }
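    // The block below fills dy_dx (layout [B, L, D, C]) with the analytic partial
    // derivatives d(output)/d(input): for each input dimension gd, the multilinear
    // weight is differentiated by taking the difference between the "right" and
    // "left" corners along gd, weighted by the remaining (D - 1) interpolation
    // weights, scaled by `scale` (chain rule through pos = input * scale) and by
    // pos_deriv (chain rule through the optional smoothstep).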
    // prepare dy_dx
    // differentiable (soft) indexing: https://discuss.pytorch.org/t/differentiable-indexing/17647/9
    if (dy_dx) {

        dy_dx += b * D * L * C + level * D * C; // B L D C

        #pragma unroll
        for (uint32_t gd = 0; gd < D; gd++) {

            scalar_t results_grad[C] = {0};

            #pragma unroll
            for (uint32_t idx = 0; idx < (1 << (D - 1)); idx++) {
                float w = scale;
                uint32_t pos_grid_local[D];

                #pragma unroll
                for (uint32_t nd = 0; nd < D - 1; nd++) {
                    const uint32_t d = (nd >= gd) ? (nd + 1) : nd;

                    if ((idx & (1 << nd)) == 0) {
                        w *= 1 - pos[d];
                        pos_grid_local[d] = pos_grid[d];
                    } else {
                        w *= pos[d];
                        pos_grid_local[d] = pos_grid[d] + 1;
                    }
                }

                pos_grid_local[gd] = pos_grid[gd];
                uint32_t index_left = get_grid_index<D, C>(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid_local);
                pos_grid_local[gd] = pos_grid[gd] + 1;
                uint32_t index_right = get_grid_index<D, C>(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid_local);

                #pragma unroll
                for (uint32_t ch = 0; ch < C; ch++) {
                    results_grad[ch] += w * (grid[index_right + ch] - grid[index_left + ch]) * pos_deriv[gd];
                }
            }

            #pragma unroll
            for (uint32_t ch = 0; ch < C; ch++) {
                dy_dx[gd * C + ch] = results_grad[ch];
            }
        }
    }
}
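// Backward pass w.r.t. the grid parameters. Each thread handles N_C channels of one
// (sample, level) pair and scatter-adds w * grad into the embeddings of the 2^D
// surrounding corners. Atomics are required because different samples (and, on the
// hashed levels, colliding grid cells) may update the same embedding entry.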
template <typename scalar_t, uint32_t D, uint32_t C, uint32_t N_C>
__global__ void kernel_grid_backward(
    const scalar_t * __restrict__ grad,
    const float * __restrict__ inputs,
    const scalar_t * __restrict__ grid,
    const int * __restrict__ offsets,
    scalar_t * __restrict__ grad_grid,
    const uint32_t B, const uint32_t L, const float S, const uint32_t H,
    const uint32_t gridtype,
    const bool align_corners,
    const uint32_t interp
) {
    const uint32_t b = (blockIdx.x * blockDim.x + threadIdx.x) * N_C / C;
    if (b >= B) return;

    const uint32_t level = blockIdx.y;
    const uint32_t ch = (blockIdx.x * blockDim.x + threadIdx.x) * N_C - b * C;

    // locate
    grad_grid += offsets[level] * C;
    inputs += b * D;
    grad += level * B * C + b * C + ch; // L, B, C

    const uint32_t hashmap_size = offsets[level + 1] - offsets[level];
    const float scale = exp2f(level * S) * H - 1.0f;
    const uint32_t resolution = (uint32_t)ceil(scale) + 1;

    // check input range (should be in [0, 1])
    #pragma unroll
    for (uint32_t d = 0; d < D; d++) {
        if (inputs[d] < 0 || inputs[d] > 1) {
            return; // grad_grid is initialized to 0, so we simply return.
        }
    }

    // calculate coordinate
    float pos[D];
    uint32_t pos_grid[D];

    #pragma unroll
    for (uint32_t d = 0; d < D; d++) {
        pos[d] = inputs[d] * scale + (align_corners ? 0.0f : 0.5f);
        pos_grid[d] = floorf(pos[d]);
        pos[d] -= (float)pos_grid[d];
        // smoothstep instead of linear
        if (interp == 1) {
            pos[d] = smoothstep(pos[d]);
        }
    }

    scalar_t grad_cur[N_C] = {0}; // fetch to registers
    #pragma unroll
    for (uint32_t c = 0; c < N_C; c++) {
        grad_cur[c] = grad[c];
    }

    // interpolate
    #pragma unroll
    for (uint32_t idx = 0; idx < (1 << D); idx++) {
        float w = 1;
        uint32_t pos_grid_local[D];

        #pragma unroll
        for (uint32_t d = 0; d < D; d++) {
            if ((idx & (1 << d)) == 0) {
                w *= 1 - pos[d];
                pos_grid_local[d] = pos_grid[d];
            } else {
                w *= pos[d];
                pos_grid_local[d] = pos_grid[d] + 1;
            }
        }

        uint32_t index = get_grid_index<D, C>(gridtype, align_corners, ch, hashmap_size, resolution, pos_grid_local);

        // atomicAdd for __half is slow (especially for large values), so we use __half2 if N_C % 2 == 0
        // TODO: use float, which is better than __half, if N_C % 2 != 0
        if (std::is_same<scalar_t, at::Half>::value && N_C % 2 == 0) {
            #pragma unroll
            for (uint32_t c = 0; c < N_C; c += 2) {
                // process two __half at once (by interpreting them as a __half2)
                __half2 v = {(__half)(w * grad_cur[c]), (__half)(w * grad_cur[c + 1])};
                atomicAdd((__half2*)&grad_grid[index + c], v);
            }
        // float, or __half when N_C % 2 != 0 (which means C == 1)
        } else {
            #pragma unroll
            for (uint32_t c = 0; c < N_C; c++) {
                atomicAdd(&grad_grid[index + c], w * grad_cur[c]);
            }
        }
    }
}


template <typename scalar_t, uint32_t D, uint32_t C>
__global__ void kernel_input_backward(
    const scalar_t * __restrict__ grad,
    const scalar_t * __restrict__ dy_dx,
    scalar_t * __restrict__ grad_inputs,
    uint32_t B, uint32_t L
) {
    const uint32_t t = threadIdx.x + blockIdx.x * blockDim.x;
    if (t >= B * D) return;

    const uint32_t b = t / D;
    const uint32_t d = t - b * D;

    dy_dx += b * L * D * C;

    scalar_t result = 0;

    #pragma unroll
    for (int l = 0; l < L; l++) {
        #pragma unroll
        for (int ch = 0; ch < C; ch++) {
            result += grad[l * B * C + b * C + ch] * dy_dx[l * D * C + d * C + ch];
        }
    }

    grad_inputs[t] = result;
}


template <typename scalar_t, uint32_t D>
void kernel_grid_wrapper(const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *outputs, const uint32_t B, const uint32_t C, const uint32_t L, const float S, const uint32_t H, scalar_t *dy_dx, const uint32_t gridtype, const bool align_corners, const uint32_t interp) {
    static constexpr uint32_t N_THREAD = 512;
    const dim3 blocks_hashgrid = { div_round_up(B, N_THREAD), L, 1 };
    switch (C) {
        case 1: kernel_grid<scalar_t, D, 1><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners, interp); break;
        case 2: kernel_grid<scalar_t, D, 2><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners, interp); break;
        case 4: kernel_grid<scalar_t, D, 4><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners, interp); break;
        case 8: kernel_grid<scalar_t, D, 8><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners, interp); break;
        default: throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, or 8."};
    }
}
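// Launch geometry used by kernel_grid_wrapper: the x dimension covers the batch
// (one thread per sample, N_THREAD threads per block) and the y dimension covers the
// L resolution levels, so each level gets its own row of blocks.
// Example (values illustrative): B = 262144, N_THREAD = 512, L = 16
//   blocks_hashgrid = { div_round_up(262144u, 512u), 16, 1 } = { 512, 16, 1 }
// The switch over C turns the runtime channel count into a compile-time template
// parameter so per-thread arrays such as results[C] can live in registers.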
// inputs: [B, D], float, in [0, 1]
// embeddings: [sO, C], float
// offsets: [L + 1], uint32_t
// outputs: [L, B, C], float (L first, so only one level of hashmap needs to fit into cache at a time.)
// H: base resolution
// dy_dx: [B, L * D * C]
template <typename scalar_t>
void grid_encode_forward_cuda(const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, scalar_t *dy_dx, const uint32_t gridtype, const bool align_corners, const uint32_t interp) {
    switch (D) {
        case 2: kernel_grid_wrapper<scalar_t, 2>(inputs, embeddings, offsets, outputs, B, C, L, S, H, dy_dx, gridtype, align_corners, interp); break;
        case 3: kernel_grid_wrapper<scalar_t, 3>(inputs, embeddings, offsets, outputs, B, C, L, S, H, dy_dx, gridtype, align_corners, interp); break;
        case 4: kernel_grid_wrapper<scalar_t, 4>(inputs, embeddings, offsets, outputs, B, C, L, S, H, dy_dx, gridtype, align_corners, interp); break;
        case 5: kernel_grid_wrapper<scalar_t, 5>(inputs, embeddings, offsets, outputs, B, C, L, S, H, dy_dx, gridtype, align_corners, interp); break;
        default: throw std::runtime_error{"GridEncoding: D must be 2, 3, 4, or 5."};
    }
}


template <typename scalar_t, uint32_t D>
void kernel_grid_backward_wrapper(const scalar_t *grad, const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *grad_embeddings, const uint32_t B, const uint32_t C, const uint32_t L, const float S, const uint32_t H, scalar_t *dy_dx, scalar_t *grad_inputs, const uint32_t gridtype, const bool align_corners, const uint32_t interp) {
    static constexpr uint32_t N_THREAD = 256;
    const uint32_t N_C = std::min(2u, C); // n_features_per_thread
    const dim3 blocks_hashgrid = { div_round_up(B * C / N_C, N_THREAD), L, 1 };
    switch (C) {
        case 1:
            kernel_grid_backward<scalar_t, D, 1, 1><<<blocks_hashgrid, N_THREAD>>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners, interp);
            if (dy_dx) kernel_input_backward<scalar_t, D, 1><<<div_round_up(B * D, N_THREAD), N_THREAD>>>(grad, dy_dx, grad_inputs, B, L);
            break;
        case 2:
            kernel_grid_backward<scalar_t, D, 2, 2><<<blocks_hashgrid, N_THREAD>>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners, interp);
            if (dy_dx) kernel_input_backward<scalar_t, D, 2><<<div_round_up(B * D, N_THREAD), N_THREAD>>>(grad, dy_dx, grad_inputs, B, L);
            break;
        case 4:
            kernel_grid_backward<scalar_t, D, 4, 2><<<blocks_hashgrid, N_THREAD>>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners, interp);
            if (dy_dx) kernel_input_backward<scalar_t, D, 4><<<div_round_up(B * D, N_THREAD), N_THREAD>>>(grad, dy_dx, grad_inputs, B, L);
            break;
        case 8:
            kernel_grid_backward<scalar_t, D, 8, 2><<<blocks_hashgrid, N_THREAD>>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners, interp);
            if (dy_dx) kernel_input_backward<scalar_t, D, 8><<<div_round_up(B * D, N_THREAD), N_THREAD>>>(grad, dy_dx, grad_inputs, B, L);
            break;
        default: throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, or 8."};
    }
}
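// In the backward wrapper above, N_C = min(2, C) is the number of features handled by
// one thread, so the launch uses B * C / N_C threads per level. With C >= 2 and half
// precision, the two features written by a thread are adjacent in memory, which is what
// allows kernel_grid_backward to pack them into a single __half2 atomicAdd.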
// grad: [L, B, C], float
// inputs: [B, D], float, in [0, 1]
// embeddings: [sO, C], float
// offsets: [L + 1], uint32_t
// grad_embeddings: [sO, C]
// H: base resolution
template <typename scalar_t>
void grid_encode_backward_cuda(const scalar_t *grad, const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, scalar_t *dy_dx, scalar_t *grad_inputs, const uint32_t gridtype, const bool align_corners, const uint32_t interp) {
    switch (D) {
        case 2: kernel_grid_backward_wrapper<scalar_t, 2>(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, dy_dx, grad_inputs, gridtype, align_corners, interp); break;
        case 3: kernel_grid_backward_wrapper<scalar_t, 3>(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, dy_dx, grad_inputs, gridtype, align_corners, interp); break;
        case 4: kernel_grid_backward_wrapper<scalar_t, 4>(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, dy_dx, grad_inputs, gridtype, align_corners, interp); break;
        case 5: kernel_grid_backward_wrapper<scalar_t, 5>(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, dy_dx, grad_inputs, gridtype, align_corners, interp); break;
        default: throw std::runtime_error{"GridEncoding: D must be 2, 3, 4, or 5."};
    }
}


void grid_encode_forward(const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, at::optional<at::Tensor> dy_dx, const uint32_t gridtype, const bool align_corners, const uint32_t interp) {
    CHECK_CUDA(inputs);
    CHECK_CUDA(embeddings);
    CHECK_CUDA(offsets);
    CHECK_CUDA(outputs);
    // CHECK_CUDA(dy_dx);

    CHECK_CONTIGUOUS(inputs);
    CHECK_CONTIGUOUS(embeddings);
    CHECK_CONTIGUOUS(offsets);
    CHECK_CONTIGUOUS(outputs);
    // CHECK_CONTIGUOUS(dy_dx);

    CHECK_IS_FLOATING(inputs);
    CHECK_IS_FLOATING(embeddings);
    CHECK_IS_INT(offsets);
    CHECK_IS_FLOATING(outputs);
    // CHECK_IS_FLOATING(dy_dx);

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
    embeddings.scalar_type(), "grid_encode_forward", ([&] {
        grid_encode_forward_cuda<scalar_t>(inputs.data_ptr<float>(), embeddings.data_ptr<scalar_t>(), offsets.data_ptr<int>(), outputs.data_ptr<scalar_t>(), B, D, C, L, S, H, dy_dx.has_value() ? dy_dx.value().data_ptr<scalar_t>() : nullptr, gridtype, align_corners, interp);
    }));
}
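// Expected call pattern for the host entry points (a sketch; the exact Python-side
// binding and autograd wrapper live outside this file):
//   inputs:     [B, D] float32, values in [0, 1]
//   embeddings: [sO, C] in the dispatched scalar_t (float, half, or double)
//   offsets:    [L + 1] int32, cumulative per-level table sizes (sO == offsets[L])
//   outputs:    [L, B, C], same dtype as embeddings
//   dy_dx:      optional [B, L * D * C], only needed when input gradients are required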
void grid_encode_backward(const at::Tensor grad, const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const at::optional<at::Tensor> dy_dx, at::optional<at::Tensor> grad_inputs, const uint32_t gridtype, const bool align_corners, const uint32_t interp) {
    CHECK_CUDA(grad);
    CHECK_CUDA(inputs);
    CHECK_CUDA(embeddings);
    CHECK_CUDA(offsets);
    CHECK_CUDA(grad_embeddings);
    // CHECK_CUDA(dy_dx);
    // CHECK_CUDA(grad_inputs);

    CHECK_CONTIGUOUS(grad);
    CHECK_CONTIGUOUS(inputs);
    CHECK_CONTIGUOUS(embeddings);
    CHECK_CONTIGUOUS(offsets);
    CHECK_CONTIGUOUS(grad_embeddings);
    // CHECK_CONTIGUOUS(dy_dx);
    // CHECK_CONTIGUOUS(grad_inputs);

    CHECK_IS_FLOATING(grad);
    CHECK_IS_FLOATING(inputs);
    CHECK_IS_FLOATING(embeddings);
    CHECK_IS_INT(offsets);
    CHECK_IS_FLOATING(grad_embeddings);
    // CHECK_IS_FLOATING(dy_dx);
    // CHECK_IS_FLOATING(grad_inputs);

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
    grad.scalar_type(), "grid_encode_backward", ([&] {
        grid_encode_backward_cuda<scalar_t>(grad.data_ptr<scalar_t>(), inputs.data_ptr<float>(), embeddings.data_ptr<scalar_t>(), offsets.data_ptr<int>(), grad_embeddings.data_ptr<scalar_t>(), B, D, C, L, S, H, dy_dx.has_value() ? dy_dx.value().data_ptr<scalar_t>() : nullptr, grad_inputs.has_value() ? grad_inputs.value().data_ptr<scalar_t>() : nullptr, gridtype, align_corners, interp);
    }));
}
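// Total-variation regularization on the grid parameters. For each sample, kernel_grad_tv
// looks at the grid vertex at floor(pos) and accumulates, for every axis, the difference
// to its left and right neighbours; the per-channel sum is normalized by
// rsqrtf(sum of squared differences + 1e-9), which roughly corresponds to the gradient of
// an L2-style TV term at that vertex, and scaled by weight / (2 * D) before being
// atomically added into grad.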
template <typename scalar_t, uint32_t D, uint32_t C>
__global__ void kernel_grad_tv(
    const scalar_t * __restrict__ inputs,
    const scalar_t * __restrict__ grid,
    scalar_t * __restrict__ grad,
    const int * __restrict__ offsets,
    const float weight,
    const uint32_t B, const uint32_t L, const float S, const uint32_t H,
    const uint32_t gridtype,
    const bool align_corners
) {
    const uint32_t b = blockIdx.x * blockDim.x + threadIdx.x;

    if (b >= B) return;

    const uint32_t level = blockIdx.y;

    // locate
    inputs += b * D;
    grid += (uint32_t)offsets[level] * C;
    grad += (uint32_t)offsets[level] * C;

    // check input range (should be in [0, 1])
    bool flag_oob = false;
    #pragma unroll
    for (uint32_t d = 0; d < D; d++) {
        if (inputs[d] < 0 || inputs[d] > 1) {
            flag_oob = true;
        }
    }

    // if input is out of bounds, do nothing
    if (flag_oob) return;

    const uint32_t hashmap_size = offsets[level + 1] - offsets[level];
    const float scale = exp2f(level * S) * H - 1.0f;
    const uint32_t resolution = (uint32_t)ceil(scale) + 1;

    // calculate coordinate
    float pos[D];
    uint32_t pos_grid[D]; // [0, resolution]

    #pragma unroll
    for (uint32_t d = 0; d < D; d++) {
        pos[d] = inputs[d] * scale + (align_corners ? 0.0f : 0.5f);
        pos_grid[d] = floorf(pos[d]);
        // pos[d] -= (float)pos_grid[d]; // not used
    }

    //printf("[b=%d, l=%d] pos=(%f, %f)+(%d, %d)\n", b, level, pos[0], pos[1], pos_grid[0], pos_grid[1]);

    // total variation on pos_grid
    scalar_t results[C] = {0}; // temp results in registers
    scalar_t idelta[C] = {0};

    uint32_t index = get_grid_index<D, C>(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid);

    scalar_t w = weight / (2 * D);

    #pragma unroll
    for (uint32_t d = 0; d < D; d++) {

        uint32_t cur_d = pos_grid[d];
        scalar_t grad_val;

        // right side
        if (cur_d < resolution) {
            pos_grid[d] = cur_d + 1;
            uint32_t index_right = get_grid_index<D, C>(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid);

            #pragma unroll
            for (uint32_t ch = 0; ch < C; ch++) {
                // results[ch] += w * clamp(grid[index + ch] - grid[index_right + ch], -1.0f, 1.0f);
                grad_val = (grid[index + ch] - grid[index_right + ch]);
                results[ch] += grad_val;
                idelta[ch] += grad_val * grad_val;
            }
        }

        // left side
        if (cur_d > 0) {
            pos_grid[d] = cur_d - 1;
            uint32_t index_left = get_grid_index<D, C>(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid);

            #pragma unroll
            for (uint32_t ch = 0; ch < C; ch++) {
                // results[ch] += w * clamp(grid[index + ch] - grid[index_left + ch], -1.0f, 1.0f);
                grad_val = (grid[index + ch] - grid[index_left + ch]);
                results[ch] += grad_val;
                idelta[ch] += grad_val * grad_val;
            }
        }

        // reset
        pos_grid[d] = cur_d;
    }

    // writing to global memory (slow)
    #pragma unroll
    for (uint32_t ch = 0; ch < C; ch++) {
        // index may collide, so use atomics!
        atomicAdd(&grad[index + ch], w * results[ch] * rsqrtf(idelta[ch] + 1e-9f));
    }
}


template <typename scalar_t, uint32_t D>
void kernel_grad_tv_wrapper(const scalar_t *inputs, const scalar_t *embeddings, scalar_t *grad, const int *offsets, const float weight, const uint32_t B, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const uint32_t gridtype, const bool align_corners) {
    static constexpr uint32_t N_THREAD = 512;
    const dim3 blocks_hashgrid = { div_round_up(B, N_THREAD), L, 1 };
    switch (C) {
        case 1: kernel_grad_tv<scalar_t, D, 1><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, grad, offsets, weight, B, L, S, H, gridtype, align_corners); break;
        case 2: kernel_grad_tv<scalar_t, D, 2><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, grad, offsets, weight, B, L, S, H, gridtype, align_corners); break;
        case 4: kernel_grad_tv<scalar_t, D, 4><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, grad, offsets, weight, B, L, S, H, gridtype, align_corners); break;
        case 8: kernel_grad_tv<scalar_t, D, 8><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, grad, offsets, weight, B, L, S, H, gridtype, align_corners); break;
        default: throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, or 8."};
    }
}


template <typename scalar_t>
void grad_total_variation_cuda(const scalar_t *inputs, const scalar_t *embeddings, scalar_t *grad, const int *offsets, const float weight, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const uint32_t gridtype, const bool align_corners) {
    switch (D) {
        case 2: kernel_grad_tv_wrapper<scalar_t, 2>(inputs, embeddings, grad, offsets, weight, B, C, L, S, H, gridtype, align_corners); break;
        case 3: kernel_grad_tv_wrapper<scalar_t, 3>(inputs, embeddings, grad, offsets, weight, B, C, L, S, H, gridtype, align_corners); break;
        case 4: kernel_grad_tv_wrapper<scalar_t, 4>(inputs, embeddings, grad, offsets, weight, B, C, L, S, H, gridtype, align_corners); break;
        case 5: kernel_grad_tv_wrapper<scalar_t, 5>(inputs, embeddings, grad, offsets, weight, B, C, L, S, H, gridtype, align_corners); break;
        default: throw std::runtime_error{"GridEncoding: D must be 2, 3, 4, or 5."};
    }
}


void grad_total_variation(const at::Tensor inputs, const at::Tensor embeddings, at::Tensor grad, const at::Tensor offsets, const float weight, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const uint32_t gridtype, const bool align_corners) {

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
    embeddings.scalar_type(), "grad_total_variation", ([&] {
        grad_total_variation_cuda<scalar_t>(inputs.data_ptr<scalar_t>(), embeddings.data_ptr<scalar_t>(), grad.data_ptr<scalar_t>(), offsets.data_ptr<int>(), weight, B, D, C, L, S, H, gridtype, align_corners);
    }));
}
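// Note: unlike grid_encode_forward / grid_encode_backward, grad_total_variation performs
// no CHECK_CUDA / CHECK_CONTIGUOUS / dtype validation on its arguments, so callers are
// expected to pass contiguous CUDA tensors whose dtypes match the embeddings.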