diff --git a/.gitignore b/.gitignore
index 1e1c4ca8..749e6d4d 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,5 @@
 build/
 diff_gaussian_rasterization.egg-info/
 dist/
+__pycache__/
+_C.cpython*
diff --git a/cuda_rasterizer/auxiliary.h b/cuda_rasterizer/auxiliary.h
index 4d4b9b78..33620000 100644
--- a/cuda_rasterizer/auxiliary.h
+++ b/cuda_rasterizer/auxiliary.h
@@ -16,7 +16,8 @@
 #include "stdio.h"
 
 #define BLOCK_SIZE (BLOCK_X * BLOCK_Y)
-#define NUM_WARPS (BLOCK_SIZE/32)
+#define WARP_SIZE 32
+#define NUM_WARPS (BLOCK_SIZE/WARP_SIZE)
 
 // Spherical harmonics coefficients
 __device__ const float SH_C0 = 0.28209479177387814f;
@@ -99,7 +100,7 @@ __forceinline__ __device__ float3 transformVec4x3Transpose(const float3& p, const float* matrix)
 __forceinline__ __device__ float dnormvdz(float3 v, float3 dv)
 {
 	float sum2 = v.x * v.x + v.y * v.y + v.z * v.z;
-	float invsum32 = 1.0f / sqrt(sum2 * sum2 * sum2);
+	float invsum32 = rsqrtf(sum2 * sum2 * sum2);
 	float dnormvdz = (-v.x * v.z * dv.x - v.y * v.z * dv.y + (sum2 - v.z * v.z) * dv.z) * invsum32;
 	return dnormvdz;
 }
@@ -107,7 +108,7 @@ __forceinline__ __device__ float dnormvdz(float3 v, float3 dv)
 __forceinline__ __device__ float3 dnormvdv(float3 v, float3 dv)
 {
 	float sum2 = v.x * v.x + v.y * v.y + v.z * v.z;
-	float invsum32 = 1.0f / sqrt(sum2 * sum2 * sum2);
+	float invsum32 = rsqrtf(sum2 * sum2 * sum2);
 
 	float3 dnormvdv;
 	dnormvdv.x = ((+sum2 - v.x * v.x) * dv.x - v.y * v.x * dv.y - v.z * v.x * dv.z) * invsum32;
@@ -119,7 +120,7 @@ __forceinline__ __device__ float3 dnormvdv(float3 v, float3 dv)
 __forceinline__ __device__ float4 dnormvdv(float4 v, float4 dv)
 {
 	float sum2 = v.x * v.x + v.y * v.y + v.z * v.z + v.w * v.w;
-	float invsum32 = 1.0f / sqrt(sum2 * sum2 * sum2);
+	float invsum32 = rsqrtf(sum2 * sum2 * sum2);
 
 	float4 vdv = { v.x * dv.x, v.y * dv.y, v.z * dv.z, v.w * dv.w };
 	float vdv_sum = vdv.x + vdv.y + vdv.z + vdv.w;
diff --git a/cuda_rasterizer/backward.cu b/cuda_rasterizer/backward.cu
index 4aa41e1c..2ba6de3d 100644
--- a/cuda_rasterizer/backward.cu
+++ b/cuda_rasterizer/backward.cu
@@ -396,6 +396,7 @@ __global__ void preprocessCUDA(
 }
 
 // Backward version of the rendering procedure.
+#define USE_ATOMIC_THRESHOLD 10
 template <uint32_t C>
 __global__ void __launch_bounds__(BLOCK_X * BLOCK_Y)
 renderCUDA(
@@ -408,6 +409,7 @@ renderCUDA(
 	const float* __restrict__ colors,
 	const float* __restrict__ final_Ts,
 	const uint32_t* __restrict__ n_contrib,
+	const uint32_t* __restrict__ tiles_touched,
 	const float* __restrict__ dL_dpixels,
 	float3* __restrict__ dL_dmean2D,
 	float4* __restrict__ dL_dconic2D,
@@ -428,10 +430,10 @@
 	const int rounds = ((range.y - range.x + BLOCK_SIZE - 1) / BLOCK_SIZE);
-	bool done = !inside;
 	int toDo = range.y - range.x;
 
 	__shared__ int collected_id[BLOCK_SIZE];
+	__shared__ bool collected_use_atomic[BLOCK_SIZE];
 	__shared__ float2 collected_xy[BLOCK_SIZE];
 	__shared__ float4 collected_conic_opacity[BLOCK_SIZE];
 	__shared__ float collected_colors[C * BLOCK_SIZE];
@@ -470,6 +472,9 @@
 		if (range.x + progress < range.y)
 		{
 			const int coll_id = point_list[range.y - progress - 1];
+			const int cur_tiles_touched = tiles_touched[coll_id];
+			bool cur_use_atomic = cur_tiles_touched <= USE_ATOMIC_THRESHOLD;
+			collected_use_atomic[block.thread_rank()] = cur_use_atomic;
 			collected_id[block.thread_rank()] = coll_id;
 			collected_xy[block.thread_rank()] = points_xy_image[coll_id];
 			collected_conic_opacity[block.thread_rank()] = conic_opacity[coll_id];
@@ -478,80 +483,193 @@
 		}
 		block.sync();
 
+		static constexpr int REDUCTION_BATCH_SIZE = 16;
+		int cur_reduction_batch_idx = 0;
+		__shared__ int batch_j[REDUCTION_BATCH_SIZE];
+		__shared__ float batch_dL_dcolors[REDUCTION_BATCH_SIZE][NUM_WARPS][C];
+		__shared__ float2 batch_dL_dmean2D[REDUCTION_BATCH_SIZE][NUM_WARPS];
+		__shared__ float4 batch_dL_dconic2D_dopacity[REDUCTION_BATCH_SIZE][NUM_WARPS];
+
 		// Iterate over Gaussians
-		for (int j = 0; !done && j < min(BLOCK_SIZE, toDo); j++)
+		for (int j = 0; j < min(BLOCK_SIZE, toDo); j++)
 		{
 			// Keep track of current Gaussian ID. Skip, if this one
 			// is behind the last contributor for this pixel.
+			float cur_dL_dcolors[C] = {0};
+			float2 cur_dL_dmean2D = {0, 0};
+			float4 cur_dL_dconic2D_dopacity = {0, 0, 0, 0};
+			const bool use_atomic = collected_use_atomic[j];
+
 			contributor--;
-			if (contributor >= last_contributor)
-				continue;
-
-			// Compute blending values, as before.
-			const float2 xy = collected_xy[j];
-			const float2 d = { xy.x - pixf.x, xy.y - pixf.y };
-			const float4 con_o = collected_conic_opacity[j];
-			const float power = -0.5f * (con_o.x * d.x * d.x + con_o.z * d.y * d.y) - con_o.y * d.x * d.y;
-			if (power > 0.0f)
-				continue;
-
-			const float G = exp(power);
-			const float alpha = min(0.99f, con_o.w * G);
-			if (alpha < 1.0f / 255.0f)
-				continue;
-
-			T = T / (1.f - alpha);
-			const float dchannel_dcolor = alpha * T;
-
-			// Propagate gradients to per-Gaussian colors and keep
-			// gradients w.r.t. alpha (blending factor for a Gaussian/pixel
-			// pair).
-			float dL_dalpha = 0.0f;
-			const int global_id = collected_id[j];
-			for (int ch = 0; ch < C; ch++)
-			{
-				const float c = collected_colors[ch * BLOCK_SIZE + j];
-				// Update last color (to be used in the next iteration)
-				accum_rec[ch] = last_alpha * last_color[ch] + (1.f - last_alpha) * accum_rec[ch];
-				last_color[ch] = c;
-
-				const float dL_dchannel = dL_dpixel[ch];
-				dL_dalpha += (c - accum_rec[ch]) * dL_dchannel;
-				// Update the gradients w.r.t. color of the Gaussian.
-				// Atomic, since this pixel is just one of potentially
-				// many that were affected by this Gaussian.
-				atomicAdd(&(dL_dcolors[global_id * C + ch]), dchannel_dcolor * dL_dchannel);
+			if (inside && contributor < last_contributor) {
+				// Compute blending values, as before.
+				const float2 xy = collected_xy[j];
+				const float2 d = { xy.x - pixf.x, xy.y - pixf.y };
+				const float4 con_o = collected_conic_opacity[j];
+				const float power = -0.5f * (con_o.x * d.x * d.x + con_o.z * d.y * d.y) - con_o.y * d.x * d.y;
+
+				const float G = exp(power);
+				const float alpha = min(0.99f, con_o.w * G);
+				if (power <= 0.0f && alpha >= 1.0f / 255.0f) {
+					T = T / (1.f - alpha);
+					const float dchannel_dcolor = alpha * T;
+
+					// Propagate gradients to per-Gaussian colors and keep
+					// gradients w.r.t. alpha (blending factor for a Gaussian/pixel
+					// pair).
+					float dL_dalpha = 0.0f;
+					const int global_id = collected_id[j];
+					#pragma unroll
+					for (int ch = 0; ch < C; ch++)
+					{
+						const float c = collected_colors[ch * BLOCK_SIZE + j];
+						// Update last color (to be used in the next iteration)
+						accum_rec[ch] = last_alpha * last_color[ch] + (1.f - last_alpha) * accum_rec[ch];
+						last_color[ch] = c;
+
+						const float dL_dchannel = dL_dpixel[ch];
+						dL_dalpha += (c - accum_rec[ch]) * dL_dchannel;
+						// Update the gradients w.r.t. color of the Gaussian.
+						// Atomic, since this pixel is just one of potentially
+						// many that were affected by this Gaussian.
+						if (use_atomic) {
+							atomicAdd(&dL_dcolors[global_id * C + ch], dchannel_dcolor * dL_dchannel);
+						} else {
+							cur_dL_dcolors[ch] = dchannel_dcolor * dL_dchannel;
+						}
+					}
+					dL_dalpha *= T;
+					// Update last alpha (to be used in the next iteration)
+					last_alpha = alpha;
+
+					// Account for fact that alpha also influences how much of
+					// the background color is added if nothing left to blend
+					float bg_dot_dpixel = 0;
+					#pragma unroll
+					for (int i = 0; i < C; i++)
+						bg_dot_dpixel += bg_color[i] * dL_dpixel[i];
+					dL_dalpha += (-T_final / (1.f - alpha)) * bg_dot_dpixel;
+
+					// Helpful reusable temporary variables
+					const float dL_dG = con_o.w * dL_dalpha;
+					const float gdx = G * d.x;
+					const float gdy = G * d.y;
+					const float dG_ddelx = -gdx * con_o.x - gdy * con_o.y;
+					const float dG_ddely = -gdy * con_o.z - gdx * con_o.y;
+
+					if (use_atomic) {
+						// Update gradients w.r.t. 2D mean position of the Gaussian
+						atomicAdd(&dL_dmean2D[global_id].x, dL_dG * dG_ddelx * ddelx_dx);
+						atomicAdd(&dL_dmean2D[global_id].y, dL_dG * dG_ddely * ddely_dy);
+						// Update gradients w.r.t. 2D covariance (2x2 matrix, symmetric)
+						atomicAdd(&dL_dconic2D[global_id].x, -0.5f * gdx * d.x * dL_dG);
+						atomicAdd(&dL_dconic2D[global_id].y, -0.5f * gdx * d.y * dL_dG);
+						atomicAdd(&dL_dconic2D[global_id].w, -0.5f * gdy * d.y * dL_dG);
+						// Update gradients w.r.t. opacity of the Gaussian
+						atomicAdd(&dL_dopacity[global_id], G * dL_dalpha);
+					} else {
+						cur_dL_dmean2D = {
+							dL_dG * dG_ddelx * ddelx_dx,
+							dL_dG * dG_ddely * ddely_dy
+						};
+						cur_dL_dconic2D_dopacity = {
+							-0.5f * gdx * d.x * dL_dG,
+							-0.5f * gdx * d.y * dL_dG,
+							G * dL_dalpha,
+							-0.5f * gdy * d.y * dL_dG
+						};
+					}
+				}
+			}
-			dL_dalpha *= T;
-			// Update last alpha (to be used in the next iteration)
-			last_alpha = alpha;
-
-			// Account for fact that alpha also influences how much of
-			// the background color is added if nothing left to blend
-			float bg_dot_dpixel = 0;
-			for (int i = 0; i < C; i++)
-				bg_dot_dpixel += bg_color[i] * dL_dpixel[i];
-			dL_dalpha += (-T_final / (1.f - alpha)) * bg_dot_dpixel;
-
-			// Helpful reusable temporary variables
-			const float dL_dG = con_o.w * dL_dalpha;
-			const float gdx = G * d.x;
-			const float gdy = G * d.y;
-			const float dG_ddelx = -gdx * con_o.x - gdy * con_o.y;
-			const float dG_ddely = -gdy * con_o.z - gdx * con_o.y;
-
-			// Update gradients w.r.t. 2D mean position of the Gaussian
-			atomicAdd(&dL_dmean2D[global_id].x, dL_dG * dG_ddelx * ddelx_dx);
-			atomicAdd(&dL_dmean2D[global_id].y, dL_dG * dG_ddely * ddely_dy);
-
-			// Update gradients w.r.t. 2D covariance (2x2 matrix, symmetric)
-			atomicAdd(&dL_dconic2D[global_id].x, -0.5f * gdx * d.x * dL_dG);
-			atomicAdd(&dL_dconic2D[global_id].y, -0.5f * gdx * d.y * dL_dG);
-			atomicAdd(&dL_dconic2D[global_id].w, -0.5f * gdy * d.y * dL_dG);
+			if (!use_atomic) {
+				// Perform warp-level reduction of this Gaussian's gradients
+				#pragma unroll
+				for (int offset = WARP_SIZE/2; offset > 0; offset /= 2) {
+					#pragma unroll
+					for (int ch = 0; ch < C; ch++)
+						cur_dL_dcolors[ch] += __shfl_down_sync(0xFFFFFFFF, cur_dL_dcolors[ch], offset);
+					cur_dL_dmean2D.x += __shfl_down_sync(0xFFFFFFFF, cur_dL_dmean2D.x, offset);
+					cur_dL_dmean2D.y += __shfl_down_sync(0xFFFFFFFF, cur_dL_dmean2D.y, offset);
+					cur_dL_dconic2D_dopacity.x += __shfl_down_sync(0xFFFFFFFF, cur_dL_dconic2D_dopacity.x, offset);
+					cur_dL_dconic2D_dopacity.y += __shfl_down_sync(0xFFFFFFFF, cur_dL_dconic2D_dopacity.y, offset);
+					cur_dL_dconic2D_dopacity.z += __shfl_down_sync(0xFFFFFFFF, cur_dL_dconic2D_dopacity.z, offset);
+					cur_dL_dconic2D_dopacity.w += __shfl_down_sync(0xFFFFFFFF, cur_dL_dconic2D_dopacity.w, offset);
+				}
+
+				// Lane 0 of each warp stores its warp's partial sums in shared memory
+				if (block.thread_rank() % WARP_SIZE == 0)
+				{
+					int warp_id = block.thread_rank() / WARP_SIZE;
+					batch_j[cur_reduction_batch_idx] = j;
+					#pragma unroll
+					for (int ch = 0; ch < C; ch++)
+						batch_dL_dcolors[cur_reduction_batch_idx][warp_id][ch] = cur_dL_dcolors[ch];
+					batch_dL_dmean2D[cur_reduction_batch_idx][warp_id] = cur_dL_dmean2D;
+					batch_dL_dconic2D_dopacity[cur_reduction_batch_idx][warp_id] = cur_dL_dconic2D_dopacity;
+				}
+				cur_reduction_batch_idx += 1;
+			}
-			// Update gradients w.r.t. opacity of the Gaussian
-			atomicAdd(&(dL_dopacity[global_id]), G * dL_dalpha);
+			// If this is the last Gaussian in the batch, perform block-level
+			// reduction and store the results in global memory.
+			if (cur_reduction_batch_idx == REDUCTION_BATCH_SIZE || (j == min(BLOCK_SIZE, toDo) - 1 && cur_reduction_batch_idx != 0))
+			{
+				// Make sure we can perform this reduction with one warp
+				static_assert(NUM_WARPS <= WARP_SIZE);
+				// Make sure the number of warps is a power of 2
+				static_assert((NUM_WARPS & (NUM_WARPS - 1)) == 0);
+
+				// Wait for all warps to finish storing
+				block.sync();
+
+				for (int batch_id = block.thread_rank() / WARP_SIZE; batch_id < cur_reduction_batch_idx; batch_id += NUM_WARPS) {
+					int lane_id = block.thread_rank() % WARP_SIZE;
+
+					// Load this batch entry's per-warp partial sums into the
+					// first NUM_WARPS lanes (zero elsewhere) ...
+					#pragma unroll
+					for (int ch = 0; ch < C; ch++)
+						cur_dL_dcolors[ch] = lane_id < NUM_WARPS ? batch_dL_dcolors[batch_id][lane_id][ch] : 0;
+					cur_dL_dmean2D = lane_id < NUM_WARPS ? batch_dL_dmean2D[batch_id][lane_id] : float2{0, 0};
+					cur_dL_dconic2D_dopacity = lane_id < NUM_WARPS ? batch_dL_dconic2D_dopacity[batch_id][lane_id] : float4{0, 0, 0, 0};
+
+					// ... then combine them with a second warp-level reduction
+					#pragma unroll
+					for (int offset = NUM_WARPS/2; offset > 0; offset /= 2) {
+						#pragma unroll
+						for (int ch = 0; ch < C; ch++)
+							cur_dL_dcolors[ch] += __shfl_down_sync(0xFFFFFFFF, cur_dL_dcolors[ch], offset);
+						cur_dL_dmean2D.x += __shfl_down_sync(0xFFFFFFFF, cur_dL_dmean2D.x, offset);
+						cur_dL_dmean2D.y += __shfl_down_sync(0xFFFFFFFF, cur_dL_dmean2D.y, offset);
+						cur_dL_dconic2D_dopacity.x += __shfl_down_sync(0xFFFFFFFF, cur_dL_dconic2D_dopacity.x, offset);
+						cur_dL_dconic2D_dopacity.y += __shfl_down_sync(0xFFFFFFFF, cur_dL_dconic2D_dopacity.y, offset);
+						cur_dL_dconic2D_dopacity.z += __shfl_down_sync(0xFFFFFFFF, cur_dL_dconic2D_dopacity.z, offset);
+						cur_dL_dconic2D_dopacity.w += __shfl_down_sync(0xFFFFFFFF, cur_dL_dconic2D_dopacity.w, offset);
+					}
+
+					// Store the results in global memory
+					if (lane_id == 0)
+					{
+						const int global_id = collected_id[batch_j[batch_id]];
+						#pragma unroll
+						for (int ch = 0; ch < C; ch++)
+							atomicAdd(&dL_dcolors[global_id * C + ch], cur_dL_dcolors[ch]);
+						atomicAdd(&dL_dmean2D[global_id].x, cur_dL_dmean2D.x);
+						atomicAdd(&dL_dmean2D[global_id].y, cur_dL_dmean2D.y);
+						atomicAdd(&dL_dconic2D[global_id].x, cur_dL_dconic2D_dopacity.x);
+						atomicAdd(&dL_dconic2D[global_id].y, cur_dL_dconic2D_dopacity.y);
+						atomicAdd(&dL_dconic2D[global_id].w, cur_dL_dconic2D_dopacity.w);
+						atomicAdd(&dL_dopacity[global_id], cur_dL_dconic2D_dopacity.z);
+					}
+				}
+
+				// Wait for all warps to finish reducing
+				if (j != min(BLOCK_SIZE, toDo) - 1)
+					block.sync();
+
+				cur_reduction_batch_idx = 0;
+			}
 		}
 	}
 }
@@ -632,6 +750,7 @@ void BACKWARD::render(
 	const float* colors,
 	const float* final_Ts,
 	const uint32_t* n_contrib,
+	const uint32_t* tiles_touched,
 	const float* dL_dpixels,
 	float3* dL_dmean2D,
 	float4* dL_dconic2D,
@@ -648,10 +767,11 @@ void BACKWARD::render(
 		colors,
 		final_Ts,
 		n_contrib,
+		tiles_touched,
 		dL_dpixels,
 		dL_dmean2D,
 		dL_dconic2D,
 		dL_dopacity,
 		dL_dcolors
 		);
-}
\ No newline at end of file
+}
diff --git a/cuda_rasterizer/backward.h b/cuda_rasterizer/backward.h
index 93dd2e4b..d39e0bbd 100644
--- a/cuda_rasterizer/backward.h
+++ b/cuda_rasterizer/backward.h
@@ -31,6 +31,7 @@ namespace BACKWARD
 		const float* colors,
 		const float* final_Ts,
 		const uint32_t* n_contrib,
+		const uint32_t* tiles_touched,
 		const float* dL_dpixels,
 		float3* dL_dmean2D,
 		float4* dL_dconic2D,
diff --git a/cuda_rasterizer/rasterizer_impl.cu b/cuda_rasterizer/rasterizer_impl.cu
index f8782ac4..c3931845 100644
--- a/cuda_rasterizer/rasterizer_impl.cu
+++ b/cuda_rasterizer/rasterizer_impl.cu
@@ -399,6 +399,7 @@ void CudaRasterizer::Rasterizer::backward(
 		color_ptr,
 		imgState.accum_alpha,
 		imgState.n_contrib,
+		geomState.tiles_touched,
 		dL_dpix,
 		(float3*)dL_dmean2D,
 		(float4*)dL_dconic,
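---
Notes (editor commentary, not part of the patch):

* auxiliary.h: rsqrtf(x) maps to a single hardware reciprocal-square-root
  instruction, replacing a sqrt followed by a division. Its rounding differs
  slightly from 1.0f / sqrt(x), which is acceptable here since the result only
  feeds gradient computation.

* backward.cu: for Gaussians that touch more than USE_ATOMIC_THRESHOLD tiles,
  the patch replaces per-pixel atomicAdd traffic on shared gradient slots with
  a warp-level shuffle reduction (one atomic per warp instead of one per
  thread), buffered in shared memory and flushed by a block-level reduction
  every REDUCTION_BATCH_SIZE Gaussians. Small-footprint Gaussians keep plain
  atomics, where the extra shared-memory traffic and syncs would not pay off.
  The sketch below isolates the shuffle pattern in a self-contained CUDA
  program; the file name and the identifiers reduce_demo/partials are
  illustrative only and do not appear in the patch.

// reduce_demo.cu: warp-level sum reduction with __shfl_down_sync, the same
// pattern renderCUDA applies to its per-Gaussian gradient partial sums.
// Build with: nvcc reduce_demo.cu -o reduce_demo
#include <cstdio>
#include <cuda_runtime.h>

constexpr int WARP_SIZE = 32;

__global__ void reduce_demo(const float* partials, float* out, int n)
{
	int tid = blockIdx.x * blockDim.x + threadIdx.x;
	// Each lane holds one partial value (zero if out of range).
	float v = tid < n ? partials[tid] : 0.0f;

	// Tree reduction: after log2(32) = 5 shuffle steps, lane 0 of each
	// warp holds the sum over all 32 lanes of that warp.
	#pragma unroll
	for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2)
		v += __shfl_down_sync(0xFFFFFFFF, v, offset);

	// Only lane 0 touches global memory: one atomic per warp
	// instead of one per thread.
	if (threadIdx.x % WARP_SIZE == 0)
		atomicAdd(out, v);
}

int main()
{
	const int N = 256;
	float h_in[N], h_out = 0.0f;
	for (int i = 0; i < N; i++) h_in[i] = 1.0f;

	float *d_in, *d_out;
	cudaMalloc(&d_in, N * sizeof(float));
	cudaMalloc(&d_out, sizeof(float));
	cudaMemcpy(d_in, h_in, N * sizeof(float), cudaMemcpyHostToDevice);
	cudaMemcpy(d_out, &h_out, sizeof(float), cudaMemcpyHostToDevice);

	reduce_demo<<<1, N>>>(d_in, d_out, N);
	cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
	printf("sum = %f\n", h_out); // expect 256.0

	cudaFree(d_in);
	cudaFree(d_out);
	return 0;
}

The renderCUDA version of this pattern reduces nine values at once (C colors,
a float2 mean gradient, and a float4 packing the conic and opacity gradients)
and then repeats the same shuffle loop over NUM_WARPS per-warp partials to
collapse them to a single atomicAdd per Gaussian per block, which is why the
patch asserts NUM_WARPS <= WARP_SIZE and that NUM_WARPS is a power of two.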