#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

#include <ATen/cuda/CUDAContext.h>
#include <torch/torch.h>

#include <algorithm>
#include <stdexcept>
#include <type_traits> // std::is_same, used in kernel_grid_backward

#include <stdint.h>
#include <cstdio>


#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor")
#define CHECK_IS_INT(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, #x " must be an int tensor")
#define CHECK_IS_FLOATING(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Float || x.scalar_type() == at::ScalarType::Half || x.scalar_type() == at::ScalarType::Double, #x " must be a floating tensor")
|
|
|
|
|
|
|
// at::Half overload so that kernels instantiated by AT_DISPATCH_FLOATING_TYPES_AND_HALF
// compile. The half-precision backward path accumulates through packed __half2 atomics
// (see kernel_grid_backward); the scalar __half atomicAdd requires compute capability >= 70,
// so it is left commented out here and this overload is effectively a stub.
__device__ inline at::Half atomicAdd(at::Half *address, at::Half val) {
    //return atomicAdd(reinterpret_cast<__half*>(address), val);
}
|
|
|
|
|
template <typename T>
__host__ __device__ inline T div_round_up(T val, T divisor) {
    return (val + divisor - 1) / divisor;
}

template <typename T, typename T2>
__host__ __device__ inline T clamp(const T v, const T2 lo, const T2 hi) {
    return min(max(v, lo), hi);
}

template <typename T>
__device__ inline T smoothstep(T val) {
    return val*val*(3.0f - 2.0f * val);
}

template <typename T>
__device__ inline T smoothstep_derivative(T val) {
    return 6*val*(1.0f - val);
}
|
|
|
|
|
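// Spatial hash for grid corners that do not fit in a dense table: XOR of each integer
// coordinate multiplied by a large prime. The first prime is 1 so the first dimension
// stays coherent in memory; this appears to follow the Instant-NGP style spatial hash.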
template <uint32_t D>
__device__ uint32_t fast_hash(const uint32_t pos_grid[D]) {
    constexpr uint32_t primes[7] = { 1u, 2654435761u, 805459861u, 3674653429u, 2097192037u, 1434869437u, 2165219737u };

    uint32_t result = 0;
    #pragma unroll
    for (uint32_t i = 0; i < D; ++i) {
        result ^= pos_grid[i] * primes[i];
    }

    return result;
}
|
|
|
|
|
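// Map a D-dimensional integer corner coordinate to an index into this level's table.
// While the dense row-major stride still fits inside hashmap_size the index is computed
// densely; once it would overflow and gridtype == 0 (hash grid), fall back to fast_hash.
// The result is wrapped by hashmap_size, multiplied by C, and offset by the channel ch.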
template <uint32_t D, uint32_t C>
__device__ uint32_t get_grid_index(const uint32_t gridtype, const bool align_corners, const uint32_t ch, const uint32_t hashmap_size, const uint32_t resolution, const uint32_t pos_grid[D]) {
    uint32_t stride = 1;
    uint32_t index = 0;

    #pragma unroll
    for (uint32_t d = 0; d < D && stride <= hashmap_size; d++) {
        index += pos_grid[d] * stride;
        stride *= align_corners ? resolution: (resolution + 1);
    }

    if (gridtype == 0 && stride > hashmap_size) {
        index = fast_hash<D>(pos_grid);
    }

    return (index % hashmap_size) * C + ch;
}
|
|
|
|
|
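// Forward kernel: grid.x covers input points (one thread per point), grid.y covers levels.
// Inputs are expected in [0, 1]; out-of-range points write zero features (and zero dy_dx).
// Each thread interpolates the C features of the 2^D surrounding corners, either
// N-linearly or with smoothstep weights (interp == 1), and optionally fills dy_dx
// (laid out [B, L, D, C]) with d(output)/d(input) for the input-gradient pass.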
template <typename scalar_t, uint32_t D, uint32_t C>
__global__ void kernel_grid(
    const float * __restrict__ inputs,
    const scalar_t * __restrict__ grid,
    const int * __restrict__ offsets,
    scalar_t * __restrict__ outputs,
    const uint32_t B, const uint32_t L, const float S, const uint32_t H,
    scalar_t * __restrict__ dy_dx,
    const uint32_t gridtype,
    const bool align_corners,
    const uint32_t interp
) {
    const uint32_t b = blockIdx.x * blockDim.x + threadIdx.x;

    if (b >= B) return;

    const uint32_t level = blockIdx.y;

    // locate this thread's slice of the grid, inputs, and outputs
    grid += (uint32_t)offsets[level] * C;
    inputs += b * D;
    outputs += level * B * C + b * C;

    // check that the input lies in [0, 1]
    bool flag_oob = false;
    #pragma unroll
    for (uint32_t d = 0; d < D; d++) {
        if (inputs[d] < 0 || inputs[d] > 1) {
            flag_oob = true;
        }
    }

    // out-of-bounds points get zero features and zero dy_dx
    if (flag_oob) {
        #pragma unroll
        for (uint32_t ch = 0; ch < C; ch++) {
            outputs[ch] = 0;
        }
        if (dy_dx) {
            dy_dx += b * D * L * C + level * D * C; // [B, L, D, C]
            #pragma unroll
            for (uint32_t d = 0; d < D; d++) {
                #pragma unroll
                for (uint32_t ch = 0; ch < C; ch++) {
                    dy_dx[d * C + ch] = 0;
                }
            }
        }
        return;
    }

    const uint32_t hashmap_size = offsets[level + 1] - offsets[level];
    const float scale = exp2f(level * S) * H - 1.0f;
    const uint32_t resolution = (uint32_t)ceil(scale) + 1;

    // compute the containing cell and the fractional position inside it (in float for precision)
    float pos[D];
    float pos_deriv[D];
    uint32_t pos_grid[D];

    #pragma unroll
    for (uint32_t d = 0; d < D; d++) {
        pos[d] = inputs[d] * scale + (align_corners ? 0.0f : 0.5f);
        pos_grid[d] = floorf(pos[d]);
        pos[d] -= (float)pos_grid[d];

        // smoothstep weights instead of linear ones
        if (interp == 1) {
            pos_deriv[d] = smoothstep_derivative(pos[d]);
            pos[d] = smoothstep(pos[d]);
        } else {
            pos_deriv[d] = 1.0f;
        }
    }

    // interpolate over the 2^D corners, accumulating in registers
    scalar_t results[C] = {0};

    #pragma unroll
    for (uint32_t idx = 0; idx < (1 << D); idx++) {
        float w = 1;
        uint32_t pos_grid_local[D];

        #pragma unroll
        for (uint32_t d = 0; d < D; d++) {
            if ((idx & (1 << d)) == 0) {
                w *= 1 - pos[d];
                pos_grid_local[d] = pos_grid[d];
            } else {
                w *= pos[d];
                pos_grid_local[d] = pos_grid[d] + 1;
            }
        }

        uint32_t index = get_grid_index<D, C>(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid_local);

        #pragma unroll
        for (uint32_t ch = 0; ch < C; ch++) {
            results[ch] += w * grid[index + ch];
        }
    }

    // write the interpolated features to global memory
    #pragma unroll
    for (uint32_t ch = 0; ch < C; ch++) {
        outputs[ch] = results[ch];
    }

    // optionally compute dy_dx: finite differences between the two faces along each dimension
    if (dy_dx) {

        dy_dx += b * D * L * C + level * D * C; // [B, L, D, C]

        #pragma unroll
        for (uint32_t gd = 0; gd < D; gd++) {

            scalar_t results_grad[C] = {0};

            #pragma unroll
            for (uint32_t idx = 0; idx < (1 << (D - 1)); idx++) {
                float w = scale;
                uint32_t pos_grid_local[D];

                #pragma unroll
                for (uint32_t nd = 0; nd < D - 1; nd++) {
                    const uint32_t d = (nd >= gd) ? (nd + 1) : nd;

                    if ((idx & (1 << nd)) == 0) {
                        w *= 1 - pos[d];
                        pos_grid_local[d] = pos_grid[d];
                    } else {
                        w *= pos[d];
                        pos_grid_local[d] = pos_grid[d] + 1;
                    }
                }

                pos_grid_local[gd] = pos_grid[gd];
                uint32_t index_left = get_grid_index<D, C>(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid_local);
                pos_grid_local[gd] = pos_grid[gd] + 1;
                uint32_t index_right = get_grid_index<D, C>(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid_local);

                #pragma unroll
                for (uint32_t ch = 0; ch < C; ch++) {
                    results_grad[ch] += w * (grid[index_right + ch] - grid[index_left + ch]) * pos_deriv[gd];
                }
            }

            #pragma unroll
            for (uint32_t ch = 0; ch < C; ch++) {
                dy_dx[gd * C + ch] = results_grad[ch];
            }
        }
    }
}
|
|
|
|
|
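// Backward kernel w.r.t. the grid: each thread handles N_C consecutive channels of one
// point at one level and scatters w * grad into grad_grid with atomicAdd over the 2^D
// corners. When scalar_t is half and N_C is even, two channels are packed into a single
// __half2 atomic to cut the cost of half-precision atomics.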
template <typename scalar_t, uint32_t D, uint32_t C, uint32_t N_C>
__global__ void kernel_grid_backward(
    const scalar_t * __restrict__ grad,
    const float * __restrict__ inputs,
    const scalar_t * __restrict__ grid,
    const int * __restrict__ offsets,
    scalar_t * __restrict__ grad_grid,
    const uint32_t B, const uint32_t L, const float S, const uint32_t H,
    const uint32_t gridtype,
    const bool align_corners,
    const uint32_t interp
) {
    const uint32_t b = (blockIdx.x * blockDim.x + threadIdx.x) * N_C / C;
    if (b >= B) return;

    const uint32_t level = blockIdx.y;
    const uint32_t ch = (blockIdx.x * blockDim.x + threadIdx.x) * N_C - b * C;

    // locate this thread's slice
    grad_grid += offsets[level] * C;
    inputs += b * D;
    grad += level * B * C + b * C + ch;

    const uint32_t hashmap_size = offsets[level + 1] - offsets[level];
    const float scale = exp2f(level * S) * H - 1.0f;
    const uint32_t resolution = (uint32_t)ceil(scale) + 1;

    // out-of-bounds inputs contribute nothing (grad_grid is zero-initialized)
    #pragma unroll
    for (uint32_t d = 0; d < D; d++) {
        if (inputs[d] < 0 || inputs[d] > 1) {
            return;
        }
    }

    // compute the containing cell and fractional position
    float pos[D];
    uint32_t pos_grid[D];

    #pragma unroll
    for (uint32_t d = 0; d < D; d++) {
        pos[d] = inputs[d] * scale + (align_corners ? 0.0f : 0.5f);
        pos_grid[d] = floorf(pos[d]);
        pos[d] -= (float)pos_grid[d];

        if (interp == 1) {
            pos[d] = smoothstep(pos[d]);
        }
    }

    // fetch this thread's incoming gradient into registers
    scalar_t grad_cur[N_C] = {0};
    #pragma unroll
    for (uint32_t c = 0; c < N_C; c++) {
        grad_cur[c] = grad[c];
    }

    // scatter the weighted gradient to all 2^D corners
    #pragma unroll
    for (uint32_t idx = 0; idx < (1 << D); idx++) {
        float w = 1;
        uint32_t pos_grid_local[D];

        #pragma unroll
        for (uint32_t d = 0; d < D; d++) {
            if ((idx & (1 << d)) == 0) {
                w *= 1 - pos[d];
                pos_grid_local[d] = pos_grid[d];
            } else {
                w *= pos[d];
                pos_grid_local[d] = pos_grid[d] + 1;
            }
        }

        uint32_t index = get_grid_index<D, C>(gridtype, align_corners, ch, hashmap_size, resolution, pos_grid_local);

        // scalar __half atomics are slow, so pack two channels into a __half2 atomic when possible
        if (std::is_same<scalar_t, at::Half>::value && N_C % 2 == 0) {
            #pragma unroll
            for (uint32_t c = 0; c < N_C; c += 2) {
                __half2 v = {(__half)(w * grad_cur[c]), (__half)(w * grad_cur[c + 1])};
                atomicAdd((__half2*)&grad_grid[index + c], v);
            }
        } else {
            #pragma unroll
            for (uint32_t c = 0; c < N_C; c++) {
                atomicAdd(&grad_grid[index + c], w * grad_cur[c]);
            }
        }
    }
}
|
|
|
|
|
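// Backward kernel w.r.t. the inputs: one thread per (point, input dimension), applying
// the chain rule by summing grad[l, b, ch] * dy_dx[b, l, d, ch] over all levels and channels.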
template <typename scalar_t, uint32_t D, uint32_t C>
__global__ void kernel_input_backward(
    const scalar_t * __restrict__ grad,
    const scalar_t * __restrict__ dy_dx,
    scalar_t * __restrict__ grad_inputs,
    uint32_t B, uint32_t L
) {
    const uint32_t t = threadIdx.x + blockIdx.x * blockDim.x;
    if (t >= B * D) return;

    const uint32_t b = t / D;
    const uint32_t d = t - b * D;

    dy_dx += b * L * D * C;

    scalar_t result = 0;

    #pragma unroll
    for (uint32_t l = 0; l < L; l++) {
        #pragma unroll
        for (uint32_t ch = 0; ch < C; ch++) {
            result += grad[l * B * C + b * C + ch] * dy_dx[l * D * C + d * C + ch];
        }
    }

    grad_inputs[t] = result;
}
|
|
|
|
|
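// Launch helper for the forward pass: maps the runtime channel count C onto a
// compile-time template parameter; 512 threads per block over points, one block row per level.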
template <typename scalar_t, uint32_t D>
void kernel_grid_wrapper(const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *outputs, const uint32_t B, const uint32_t C, const uint32_t L, const float S, const uint32_t H, scalar_t *dy_dx, const uint32_t gridtype, const bool align_corners, const uint32_t interp) {
    static constexpr uint32_t N_THREAD = 512;
    const dim3 blocks_hashgrid = { div_round_up(B, N_THREAD), L, 1 };
    switch (C) {
        case 1: kernel_grid<scalar_t, D, 1><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners, interp); break;
        case 2: kernel_grid<scalar_t, D, 2><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners, interp); break;
        case 4: kernel_grid<scalar_t, D, 4><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners, interp); break;
        case 8: kernel_grid<scalar_t, D, 8><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners, interp); break;
        default: throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, or 8."};
    }
}

template <typename scalar_t>
void grid_encode_forward_cuda(const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, scalar_t *dy_dx, const uint32_t gridtype, const bool align_corners, const uint32_t interp) {
    switch (D) {
        case 2: kernel_grid_wrapper<scalar_t, 2>(inputs, embeddings, offsets, outputs, B, C, L, S, H, dy_dx, gridtype, align_corners, interp); break;
        case 3: kernel_grid_wrapper<scalar_t, 3>(inputs, embeddings, offsets, outputs, B, C, L, S, H, dy_dx, gridtype, align_corners, interp); break;
        case 4: kernel_grid_wrapper<scalar_t, 4>(inputs, embeddings, offsets, outputs, B, C, L, S, H, dy_dx, gridtype, align_corners, interp); break;
        case 5: kernel_grid_wrapper<scalar_t, 5>(inputs, embeddings, offsets, outputs, B, C, L, S, H, dy_dx, gridtype, align_corners, interp); break;
        default: throw std::runtime_error{"GridEncoding: D must be 2, 3, 4, or 5."};
    }
}
|
|
|
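// Launch helper for the backward pass. N_C = min(2, C) channels are handled per thread,
// so whenever C >= 2 the half-precision path can use packed __half2 atomics.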
template <typename scalar_t, uint32_t D>
void kernel_grid_backward_wrapper(const scalar_t *grad, const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *grad_embeddings, const uint32_t B, const uint32_t C, const uint32_t L, const float S, const uint32_t H, scalar_t *dy_dx, scalar_t *grad_inputs, const uint32_t gridtype, const bool align_corners, const uint32_t interp) {
    static constexpr uint32_t N_THREAD = 256;
    const uint32_t N_C = std::min(2u, C);
    const dim3 blocks_hashgrid = { div_round_up(B * C / N_C, N_THREAD), L, 1 };
    switch (C) {
        case 1:
            kernel_grid_backward<scalar_t, D, 1, 1><<<blocks_hashgrid, N_THREAD>>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners, interp);
            if (dy_dx) kernel_input_backward<scalar_t, D, 1><<<div_round_up(B * D, N_THREAD), N_THREAD>>>(grad, dy_dx, grad_inputs, B, L);
            break;
        case 2:
            kernel_grid_backward<scalar_t, D, 2, 2><<<blocks_hashgrid, N_THREAD>>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners, interp);
            if (dy_dx) kernel_input_backward<scalar_t, D, 2><<<div_round_up(B * D, N_THREAD), N_THREAD>>>(grad, dy_dx, grad_inputs, B, L);
            break;
        case 4:
            kernel_grid_backward<scalar_t, D, 4, 2><<<blocks_hashgrid, N_THREAD>>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners, interp);
            if (dy_dx) kernel_input_backward<scalar_t, D, 4><<<div_round_up(B * D, N_THREAD), N_THREAD>>>(grad, dy_dx, grad_inputs, B, L);
            break;
        case 8:
            kernel_grid_backward<scalar_t, D, 8, 2><<<blocks_hashgrid, N_THREAD>>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners, interp);
            if (dy_dx) kernel_input_backward<scalar_t, D, 8><<<div_round_up(B * D, N_THREAD), N_THREAD>>>(grad, dy_dx, grad_inputs, B, L);
            break;
        default: throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, or 8."};
    }
}

template <typename scalar_t>
void grid_encode_backward_cuda(const scalar_t *grad, const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, scalar_t *dy_dx, scalar_t *grad_inputs, const uint32_t gridtype, const bool align_corners, const uint32_t interp) {
    switch (D) {
        case 2: kernel_grid_backward_wrapper<scalar_t, 2>(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, dy_dx, grad_inputs, gridtype, align_corners, interp); break;
        case 3: kernel_grid_backward_wrapper<scalar_t, 3>(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, dy_dx, grad_inputs, gridtype, align_corners, interp); break;
        case 4: kernel_grid_backward_wrapper<scalar_t, 4>(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, dy_dx, grad_inputs, gridtype, align_corners, interp); break;
        case 5: kernel_grid_backward_wrapper<scalar_t, 5>(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, dy_dx, grad_inputs, gridtype, align_corners, interp); break;
        default: throw std::runtime_error{"GridEncoding: D must be 2, 3, 4, or 5."};
    }
}
|
|
|
|
|
|
|
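// Host entry point for the forward pass. Expected contiguous layouts: inputs [B, D] float
// in [0, 1], offsets [L + 1] int32, embeddings [offsets[L], C], outputs [L, B, C], and the
// optional dy_dx buffer [B, L * D * C] used later for input gradients.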
void grid_encode_forward(const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, at::optional<at::Tensor> dy_dx, const uint32_t gridtype, const bool align_corners, const uint32_t interp) {
    CHECK_CUDA(inputs);
    CHECK_CUDA(embeddings);
    CHECK_CUDA(offsets);
    CHECK_CUDA(outputs);

    CHECK_CONTIGUOUS(inputs);
    CHECK_CONTIGUOUS(embeddings);
    CHECK_CONTIGUOUS(offsets);
    CHECK_CONTIGUOUS(outputs);

    CHECK_IS_FLOATING(inputs);
    CHECK_IS_FLOATING(embeddings);
    CHECK_IS_INT(offsets);
    CHECK_IS_FLOATING(outputs);

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
    embeddings.scalar_type(), "grid_encode_forward", ([&] {
        grid_encode_forward_cuda<scalar_t>(inputs.data_ptr<float>(), embeddings.data_ptr<scalar_t>(), offsets.data_ptr<int>(), outputs.data_ptr<scalar_t>(), B, D, C, L, S, H, dy_dx.has_value() ? dy_dx.value().data_ptr<scalar_t>() : nullptr, gridtype, align_corners, interp);
    }));
}
|
|
|
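// Host entry point for the backward pass: grad is expected as [L, B, C] (matching the
// forward outputs); grad_inputs, if provided, is filled as [B, D].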
void grid_encode_backward(const at::Tensor grad, const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const at::optional<at::Tensor> dy_dx, at::optional<at::Tensor> grad_inputs, const uint32_t gridtype, const bool align_corners, const uint32_t interp) {
    CHECK_CUDA(grad);
    CHECK_CUDA(inputs);
    CHECK_CUDA(embeddings);
    CHECK_CUDA(offsets);
    CHECK_CUDA(grad_embeddings);

    CHECK_CONTIGUOUS(grad);
    CHECK_CONTIGUOUS(inputs);
    CHECK_CONTIGUOUS(embeddings);
    CHECK_CONTIGUOUS(offsets);
    CHECK_CONTIGUOUS(grad_embeddings);

    CHECK_IS_FLOATING(grad);
    CHECK_IS_FLOATING(inputs);
    CHECK_IS_FLOATING(embeddings);
    CHECK_IS_INT(offsets);
    CHECK_IS_FLOATING(grad_embeddings);

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
    grad.scalar_type(), "grid_encode_backward", ([&] {
        grid_encode_backward_cuda<scalar_t>(grad.data_ptr<scalar_t>(), inputs.data_ptr<float>(), embeddings.data_ptr<scalar_t>(), offsets.data_ptr<int>(), grad_embeddings.data_ptr<scalar_t>(), B, D, C, L, S, H, dy_dx.has_value() ? dy_dx.value().data_ptr<scalar_t>() : nullptr, grad_inputs.has_value() ? grad_inputs.value().data_ptr<scalar_t>() : nullptr, gridtype, align_corners, interp);
    }));
}
|
|
|
|
|
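// Total-variation regularization gradient: for each sampled point, take the grid vertex it
// falls in and accumulate the differences to its +/-1 neighbors along every dimension into
// grad (normalized by rsqrt of the summed squared differences and scaled by weight / (2 * D)).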
template <typename scalar_t, uint32_t D, uint32_t C>
__global__ void kernel_grad_tv(
    const scalar_t * __restrict__ inputs,
    const scalar_t * __restrict__ grid,
    scalar_t * __restrict__ grad,
    const int * __restrict__ offsets,
    const float weight,
    const uint32_t B, const uint32_t L, const float S, const uint32_t H,
    const uint32_t gridtype,
    const bool align_corners
) {
    const uint32_t b = blockIdx.x * blockDim.x + threadIdx.x;

    if (b >= B) return;

    const uint32_t level = blockIdx.y;

    // locate this thread's slice
    inputs += b * D;
    grid += (uint32_t)offsets[level] * C;
    grad += (uint32_t)offsets[level] * C;

    // out-of-bounds points contribute nothing
    bool flag_oob = false;
    #pragma unroll
    for (uint32_t d = 0; d < D; d++) {
        if (inputs[d] < 0 || inputs[d] > 1) {
            flag_oob = true;
        }
    }

    if (flag_oob) return;

    const uint32_t hashmap_size = offsets[level + 1] - offsets[level];
    const float scale = exp2f(level * S) * H - 1.0f;
    const uint32_t resolution = (uint32_t)ceil(scale) + 1;

    // compute the grid vertex this point falls in
    float pos[D];
    uint32_t pos_grid[D];

    #pragma unroll
    for (uint32_t d = 0; d < D; d++) {
        pos[d] = inputs[d] * scale + (align_corners ? 0.0f : 0.5f);
        pos_grid[d] = floorf(pos[d]);
    }

    scalar_t results[C] = {0}; // accumulated differences, in registers
    scalar_t idelta[C] = {0};  // accumulated squared differences, for normalization

    uint32_t index = get_grid_index<D, C>(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid);

    scalar_t w = weight / (2 * D);

    #pragma unroll
    for (uint32_t d = 0; d < D; d++) {

        uint32_t cur_d = pos_grid[d];
        scalar_t grad_val;

        // right neighbor along dimension d
        if (cur_d < resolution) {
            pos_grid[d] = cur_d + 1;
            uint32_t index_right = get_grid_index<D, C>(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid);

            #pragma unroll
            for (uint32_t ch = 0; ch < C; ch++) {
                grad_val = (grid[index + ch] - grid[index_right + ch]);
                results[ch] += grad_val;
                idelta[ch] += grad_val * grad_val;
            }
        }

        // left neighbor along dimension d
        if (cur_d > 0) {
            pos_grid[d] = cur_d - 1;
            uint32_t index_left = get_grid_index<D, C>(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid);

            #pragma unroll
            for (uint32_t ch = 0; ch < C; ch++) {
                grad_val = (grid[index + ch] - grid[index_left + ch]);
                results[ch] += grad_val;
                idelta[ch] += grad_val * grad_val;
            }
        }

        // restore the probed coordinate
        pos_grid[d] = cur_d;
    }

    // accumulate into the (shared) gradient buffer
    #pragma unroll
    for (uint32_t ch = 0; ch < C; ch++) {
        atomicAdd(&grad[index + ch], w * results[ch] * rsqrtf(idelta[ch] + 1e-9f));
    }
}
|
|
|
|
|
template <typename scalar_t, uint32_t D>
void kernel_grad_tv_wrapper(const scalar_t *inputs, const scalar_t *embeddings, scalar_t *grad, const int *offsets, const float weight, const uint32_t B, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const uint32_t gridtype, const bool align_corners) {
    static constexpr uint32_t N_THREAD = 512;
    const dim3 blocks_hashgrid = { div_round_up(B, N_THREAD), L, 1 };
    switch (C) {
        case 1: kernel_grad_tv<scalar_t, D, 1><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, grad, offsets, weight, B, L, S, H, gridtype, align_corners); break;
        case 2: kernel_grad_tv<scalar_t, D, 2><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, grad, offsets, weight, B, L, S, H, gridtype, align_corners); break;
        case 4: kernel_grad_tv<scalar_t, D, 4><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, grad, offsets, weight, B, L, S, H, gridtype, align_corners); break;
        case 8: kernel_grad_tv<scalar_t, D, 8><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, grad, offsets, weight, B, L, S, H, gridtype, align_corners); break;
        default: throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, or 8."};
    }
}
|
|
|
|
|
template <typename scalar_t>
void grad_total_variation_cuda(const scalar_t *inputs, const scalar_t *embeddings, scalar_t *grad, const int *offsets, const float weight, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const uint32_t gridtype, const bool align_corners) {
    switch (D) {
        case 2: kernel_grad_tv_wrapper<scalar_t, 2>(inputs, embeddings, grad, offsets, weight, B, C, L, S, H, gridtype, align_corners); break;
        case 3: kernel_grad_tv_wrapper<scalar_t, 3>(inputs, embeddings, grad, offsets, weight, B, C, L, S, H, gridtype, align_corners); break;
        case 4: kernel_grad_tv_wrapper<scalar_t, 4>(inputs, embeddings, grad, offsets, weight, B, C, L, S, H, gridtype, align_corners); break;
        case 5: kernel_grad_tv_wrapper<scalar_t, 5>(inputs, embeddings, grad, offsets, weight, B, C, L, S, H, gridtype, align_corners); break;
        default: throw std::runtime_error{"GridEncoding: D must be 2, 3, 4, or 5."};
    }
}
|
|
|
|
|
void grad_total_variation(const at::Tensor inputs, const at::Tensor embeddings, at::Tensor grad, const at::Tensor offsets, const float weight, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const uint32_t gridtype, const bool align_corners) {

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
    embeddings.scalar_type(), "grad_total_variation", ([&] {
        grad_total_variation_cuda<scalar_t>(inputs.data_ptr<scalar_t>(), embeddings.data_ptr<scalar_t>(), grad.data_ptr<scalar_t>(), offsets.data_ptr<int>(), weight, B, D, C, L, S, H, gridtype, align_corners);
    }));
}