Skip to content

Commit 1a863bd

Browse files
Fine-Grained Load Balancing with Warp-per-Pixel Rendering (#162)
* load balancing branch of render kernel * fix: address code review feedback for load balancing feature
1 parent e55c114 commit 1a863bd

File tree

7 files changed

+352
-1
lines changed

7 files changed

+352
-1
lines changed

configs/render/3dgut.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,3 +24,4 @@ splat: # 3DGUT-specific settings
2424
# rendering
2525
k_buffer_size: 0 # 0 means unsorted
2626
global_z_order: true
27+
fine_grained_load_balancing: false

threedgut_tracer/include/3dgut/kernels/cuda/common/rayPayload.cuh

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,43 @@ __device__ __inline__ RayPayloadT initializeRay(const threedgut::RenderParameter
107107
return ray;
108108
}
109109

110+
111+
// Initialize ray based on given pixel coordinates (load-balanced mode)
112+
template <typename RayPayloadT>
113+
__device__ __inline__ RayPayloadT initializeRayPerPixel(const threedgut::RenderParameters& params,
114+
const tcnn::uvec2& pixel,
115+
const tcnn::vec3* __restrict__ sensorRayOriginPtr,
116+
const tcnn::vec3* __restrict__ sensorRayDirectionPtr,
117+
const tcnn::mat4x3& sensorToWorldTransform) {
118+
RayPayloadT ray;
119+
ray.flags = RayPayloadT::Default;
120+
121+
if ((pixel.x >= params.resolution.x) || (pixel.y >= params.resolution.y)) {
122+
return ray;
123+
}
124+
125+
ray.idx = pixel.x + params.resolution.x * pixel.y;
126+
ray.hitT = 0.0f;
127+
ray.transmittance = 1.0f;
128+
ray.features = tcnn::vec<RayPayloadT::FeatDim>::zero();
129+
130+
ray.origin = sensorToWorldTransform * tcnn::vec4(sensorRayOriginPtr[ray.idx], 1.0f);
131+
ray.direction = tcnn::mat3(sensorToWorldTransform) * sensorRayDirectionPtr[ray.idx];
132+
133+
ray.tMinMax = params.objectAABB.ray_intersect(ray.origin, ray.direction);
134+
ray.tMinMax.x = fmaxf(ray.tMinMax.x, 0.0f);
135+
136+
if (ray.tMinMax.y > ray.tMinMax.x) {
137+
ray.flags |= RayPayloadT::Valid | RayPayloadT::Alive;
138+
}
139+
140+
#if GAUSSIAN_ENABLE_HIT_COUNT
141+
ray.hitN = 0;
142+
#endif
143+
144+
return ray;
145+
}
146+
110147
template <typename TRayPayload>
111148
__device__ __inline__ void finalizeRay(const TRayPayload& ray,
112149
const threedgut::RenderParameters& params,

threedgut_tracer/include/3dgut/kernels/cuda/renderers/gutKBufferRenderer.cuh

Lines changed: 177 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,183 @@ struct GUTKBufferRenderer : Params {
291291
}
292292
}
293293

294+
// Fine-grained balanced forward rendering: Gaussian-wise parallelism with warp-level optimization
295+
template <typename TRay>
296+
static inline __device__ void evalForwardNoKBufferBalanced(
297+
const threedgut::RenderParameters& params,
298+
TRay& ray,
299+
const tcnn::uvec2* __restrict__ sortedTileRangeIndicesPtr,
300+
const uint32_t* __restrict__ sortedTileParticleIdxPtr,
301+
const tcnn::vec2* __restrict__ particlesProjectedPositionPtr,
302+
const tcnn::vec4* __restrict__ particlesProjectedConicOpacityPtr,
303+
const float* __restrict__ particlesGlobalDepthPtr,
304+
const float* __restrict__ particlesPrecomputedFeaturesPtr,
305+
const tcnn::uvec2& tile,
306+
const tcnn::uvec2& tileGrid,
307+
const int laneId,
308+
threedgut::MemoryHandles parameters,
309+
tcnn::vec2* __restrict__ particlesProjectedPositionGradPtr = nullptr,
310+
tcnn::vec4* __restrict__ particlesProjectedConicOpacityGradPtr = nullptr,
311+
float* __restrict__ particlesGlobalDepthGradPtr = nullptr,
312+
float* __restrict__ particlesPrecomputedFeaturesGradPtr = nullptr,
313+
threedgut::MemoryHandles parametersGradient = {}) {
314+
315+
using namespace threedgut;
316+
317+
// Get tile data: each warp processes particles from a single 16x16 tile
318+
const uint32_t tileIdx = tile.y * tileGrid.x + tile.x;
319+
const tcnn::uvec2 tileParticleRangeIndices = sortedTileRangeIndicesPtr[tileIdx];
320+
321+
uint32_t tileNumParticlesToProcess = tileParticleRangeIndices.y - tileParticleRangeIndices.x;
322+
323+
// Setup feature buffers based on rendering mode
324+
const TFeaturesVec* particleFeaturesBuffer =
325+
Params::PerRayParticleFeatures ? nullptr :
326+
reinterpret_cast<const TFeaturesVec*>(particlesPrecomputedFeaturesPtr);
327+
TFeaturesVec* particleFeaturesGradientBuffer =
328+
(Params::PerRayParticleFeatures || !Backward) ? nullptr :
329+
reinterpret_cast<TFeaturesVec*>(particlesPrecomputedFeaturesGradPtr);
330+
331+
// Initialize particle system
332+
Particles particles;
333+
particles.initializeDensity(parameters);
334+
if constexpr (Backward) {
335+
particles.initializeDensityGradient(parametersGradient);
336+
}
337+
particles.initializeFeatures(parameters);
338+
if constexpr (Backward && Params::PerRayParticleFeatures) {
339+
particles.initializeFeaturesGradient(parametersGradient);
340+
}
341+
342+
static_assert(Params::KHitBufferSize == 0, "evalForwardNoKBufferBalanced only supports K=0 (no hit buffer). Use evalKBuffer for K>0 cases.");
343+
344+
// Warp-aligned processing: round up to multiple of WarpSize to avoid divergence
345+
constexpr uint32_t WarpSize = GUTParameters::Tiling::WarpSize; // 32 threads per warp
346+
uint32_t alignedParticleCount = ((tileNumParticlesToProcess + WarpSize - 1) / WarpSize) * WarpSize;
347+
348+
// Main loop: Gaussian-wise parallelism - WarpSize threads process Gaussians, single ray
349+
for (uint32_t j = laneId; j < alignedParticleCount; j += WarpSize) {
350+
if (!ray.isAlive()) break;
351+
352+
float hitAlpha = 0.0f;
353+
float hitT = 0.0f;
354+
TFeaturesVec hitFeatures = TFeaturesVec::zero();
355+
bool validHit = false;
356+
357+
// Step 1: Each thread tests one Gaussian intersection
358+
if (j < tileNumParticlesToProcess) {
359+
const uint32_t toProcessSortedIndex = tileParticleRangeIndices.x + j;
360+
const uint32_t particleIdx = sortedTileParticleIdxPtr[toProcessSortedIndex];
361+
362+
if (particleIdx != GUTParameters::InvalidParticleIdx) {
363+
auto densityParams = particles.fetchDensityParameters(particleIdx);
364+
365+
if (particles.densityHit(ray.origin,
366+
ray.direction,
367+
densityParams,
368+
hitAlpha,
369+
hitT) &&
370+
(hitT > ray.tMinMax.x) &&
371+
(hitT < ray.tMinMax.y)) {
372+
373+
validHit = true;
374+
375+
// Get Gaussian features
376+
if constexpr (Params::PerRayParticleFeatures) {
377+
hitFeatures = particles.featuresFromBuffer(particleIdx, ray.direction);
378+
} else {
379+
hitFeatures = tcnn::max(particleFeaturesBuffer[particleIdx], 0.f);
380+
}
381+
}
382+
}
383+
}
384+
385+
// Skip if no hits in this warp batch
386+
constexpr uint32_t WarpMask = GUTParameters::Tiling::WarpMask; // 0xFFFFFFFF for full warp
387+
if (__all_sync(WarpMask, !validHit)) continue;
388+
389+
// Step 2: Compute per-thread transmittance contribution
390+
float localTransmittance = validHit ? (1.0f - hitAlpha) : 1.0f;
391+
392+
// Step 3: Warp-level prefix scan for cumulative transmittance
393+
for (uint32_t offset = 1; offset < WarpSize; offset <<= 1) {
394+
float n = __shfl_up_sync(WarpMask, localTransmittance, offset);
395+
if (laneId >= offset) {
396+
localTransmittance *= n;
397+
}
398+
}
399+
400+
// Get overall batch transmittance impact
401+
float batchTransmittance = __shfl_sync(WarpMask, localTransmittance, WarpSize - 1);
402+
float newTransmittance = ray.transmittance * batchTransmittance;
403+
404+
// Step 4: Early termination detection - find exact termination point
405+
unsigned int earlyTerminationMask = __ballot_sync(WarpMask,
406+
validHit && (ray.transmittance * localTransmittance) < Particles::MinTransmittanceThreshold);
407+
408+
bool shouldTerminate = false;
409+
int terminationLane = -1;
410+
411+
if (earlyTerminationMask) {
412+
terminationLane = __ffs(earlyTerminationMask) - 1; // Find first terminating lane
413+
shouldTerminate = true;
414+
ray.kill();
415+
}
416+
417+
// Step 5: Warp reduction for feature accumulation
418+
TFeaturesVec accumulatedFeatures = TFeaturesVec::zero();
419+
float accumulatedHitT = 0.0f;
420+
uint32_t accumulatedHitCount = 0;
421+
422+
// Only accumulate contributions before (and including) termination point
423+
bool shouldContribute = validHit && (!shouldTerminate || laneId <= terminationLane);
424+
425+
if (shouldContribute) {
426+
// Use precomputed prefix transmittance, excluding current particle
427+
float prefixTransmittance = (laneId > 0) ?
428+
(localTransmittance / (1.0f - hitAlpha)) : 1.0f;
429+
float particleTransmittance = ray.transmittance * prefixTransmittance;
430+
float hitWeight = hitAlpha * particleTransmittance;
431+
432+
// Compute weighted contributions
433+
for (int featIdx = 0; featIdx < Particles::FeaturesDim; ++featIdx) {
434+
accumulatedFeatures[featIdx] = hitFeatures[featIdx] * hitWeight;
435+
}
436+
accumulatedHitT = hitT * hitWeight;
437+
accumulatedHitCount = (hitWeight > 0.0f) ? 1 : 0;
438+
}
439+
440+
// Step 6: Warp-level reduction (tree-based sum)
441+
for (int featIdx = 0; featIdx < Particles::FeaturesDim; ++featIdx) {
442+
for (uint32_t offset = WarpSize / 2; offset > 0; offset >>= 1) {
443+
accumulatedFeatures[featIdx] += __shfl_down_sync(WarpMask, accumulatedFeatures[featIdx], offset);
444+
}
445+
}
446+
447+
for (uint32_t offset = WarpSize / 2; offset > 0; offset >>= 1) {
448+
accumulatedHitT += __shfl_down_sync(WarpMask, accumulatedHitT, offset);
449+
accumulatedHitCount += __shfl_down_sync(WarpMask, accumulatedHitCount, offset);
450+
}
451+
452+
// Step 7: Only lane 0 updates ray state (avoid race conditions)
453+
if (laneId == 0) {
454+
for (int featIdx = 0; featIdx < Particles::FeaturesDim; ++featIdx) {
455+
ray.features[featIdx] += accumulatedFeatures[featIdx];
456+
}
457+
ray.hitT += accumulatedHitT;
458+
ray.countHit(accumulatedHitCount);
459+
}
460+
461+
// Step 8: Update ray transmittance
462+
ray.transmittance = newTransmittance;
463+
464+
// Break on early termination
465+
if (shouldTerminate) {
466+
break;
467+
}
468+
}
469+
}
470+
294471
template <typename TRay>
295472
static inline __device__ void evalBackwardNoKBuffer(TRay& ray,
296473
Particles& particles,

threedgut_tracer/include/3dgut/kernels/cuda/renderers/gutRenderer.cuh

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,106 @@ __global__ void render(threedgut::RenderParameters params,
114114
finalizeRay(ray, params, sensorRayOriginPtr, worldHitCountPtr, worldHitDistancePtr, radianceDensityPtr, sensorToWorldTransform);
115115
}
116116

117+
// Fine-grained load balancing rendering kernel: static allocation per virtual tile
118+
__global__ void renderBalanced(threedgut::RenderParameters params,
119+
const tcnn::uvec2* __restrict__ sortedTileRangeIndicesPtr,
120+
const uint32_t* __restrict__ sortedTileDataPtr,
121+
const tcnn::vec3* __restrict__ sensorRayOriginPtr,
122+
const tcnn::vec3* __restrict__ sensorRayDirectionPtr,
123+
tcnn::mat4x3 sensorToWorldTransform,
124+
float* __restrict__ worldHitCountPtr,
125+
float* __restrict__ worldHitDistancePtr,
126+
tcnn::vec4* __restrict__ radianceDensityPtr,
127+
const tcnn::vec2* __restrict__ particlesProjectedPositionPtr,
128+
const tcnn::vec4* __restrict__ particlesProjectedConicOpacityPtr,
129+
const float* __restrict__ particlesGlobalDepthPtr,
130+
const float* __restrict__ particlesPrecomputedFeaturesPtr,
131+
const uint64_t* __restrict__ parameterMemoryHandles,
132+
const tcnn::uvec2 tileGrid) {
133+
134+
// Static allocation: each block handles one virtual tile
135+
using namespace threedgut;
136+
constexpr uint32_t virtualTilesPerTile = GUTParameters::Tiling::VirtualTilesPerTile;
137+
const uint32_t virtualTileId = blockIdx.x;
138+
139+
// Calculate total virtual tiles across all original tiles
140+
const uint32_t totalVirtualTiles = tileGrid.x * tileGrid.y * virtualTilesPerTile;
141+
142+
// Boundary check
143+
if (virtualTileId >= totalVirtualTiles) return;
144+
145+
// Map virtual tile back to original tile coordinates and local position
146+
const uint32_t originalTileId = virtualTileId / virtualTilesPerTile;
147+
const uint32_t virtualTileInOriginal = virtualTileId % virtualTilesPerTile;
148+
149+
const uint32_t originalTileX = originalTileId % tileGrid.x;
150+
const uint32_t originalTileY = originalTileId / tileGrid.x;
151+
152+
// Map virtual tile to pixel coordinates within original tile
153+
constexpr uint32_t virtualTilesPerTileX = GUTParameters::Tiling::VirtualTilesPerTileX;
154+
constexpr uint32_t virtualTileX = GUTParameters::Tiling::VirtualTileX;
155+
constexpr uint32_t virtualTileY = GUTParameters::Tiling::VirtualTileY;
156+
constexpr uint32_t warpSize = GUTParameters::Tiling::WarpSize;
157+
158+
const uint32_t virtualTileXPos = virtualTileInOriginal % virtualTilesPerTileX; // 0-7
159+
const uint32_t virtualTileYPos = virtualTileInOriginal / virtualTilesPerTileX; // 0-7
160+
161+
// Calculate base pixel coordinates for this virtual tile
162+
const uint32_t basePixelX = virtualTileXPos * virtualTileX; // 0,2,4,6,8,10,12,14
163+
const uint32_t basePixelY = virtualTileYPos * virtualTileY; // 0,2,4,6,8,10,12,14
164+
165+
// Warp-level processing: each warp handles one pixel in virtual tile
166+
const uint32_t warpId = threadIdx.x / warpSize;
167+
const uint32_t laneId = threadIdx.x & (warpSize - 1);
168+
169+
// Each block processes 1 virtual tile = virtualTileSize pixels, each warp handles 1 pixel
170+
constexpr uint32_t virtualTileSize = GUTParameters::Tiling::VirtualTileSize;
171+
constexpr uint32_t blockX = GUTParameters::Tiling::BlockX;
172+
constexpr uint32_t blockY = GUTParameters::Tiling::BlockY;
173+
174+
if (warpId < virtualTileSize) { // virtualTileSize warps per block (1 warp per pixel)
175+
// Arrange pixels in row-major order within virtualTileX x virtualTileY region
176+
// warp 0-3 maps to pixels: (0,0),(1,0),(0,1),(1,1) for 2x2 virtual tile
177+
const uint32_t pixelOffsetX = warpId % virtualTileX;
178+
const uint32_t pixelOffsetY = warpId / virtualTileX;
179+
180+
const uint32_t pixelLocalX = basePixelX + pixelOffsetX;
181+
const uint32_t pixelLocalY = basePixelY + pixelOffsetY;
182+
183+
const tcnn::uvec2 pixel = {
184+
originalTileX * blockX + pixelLocalX,
185+
originalTileY * blockY + pixelLocalY
186+
};
187+
188+
// Initialize ray for current pixel
189+
auto ray = initializeRayPerPixel<TGUTRenderer::TRayPayload>(
190+
params, pixel, sensorRayOriginPtr, sensorRayDirectionPtr, sensorToWorldTransform);
191+
192+
// Warp-level parallel rendering using original tile's particle data
193+
const tcnn::uvec2 originalTile = {originalTileX, originalTileY};
194+
195+
TGUTRenderer::evalForwardNoKBufferBalanced(params,
196+
ray,
197+
sortedTileRangeIndicesPtr,
198+
sortedTileDataPtr,
199+
particlesProjectedPositionPtr,
200+
particlesProjectedConicOpacityPtr,
201+
particlesGlobalDepthPtr,
202+
particlesPrecomputedFeaturesPtr,
203+
originalTile,
204+
tileGrid,
205+
laneId, // warp lane for parallel processing
206+
{parameterMemoryHandles});
207+
208+
// Write final results to output buffers
209+
// Only lane 0 should write, as only it has accumulated the correct values
210+
if (laneId == 0) {
211+
finalizeRay(ray, params, sensorRayOriginPtr, worldHitCountPtr,
212+
worldHitDistancePtr, radianceDensityPtr, sensorToWorldTransform);
213+
}
214+
}
215+
}
216+
117217
__global__ void renderBackward(threedgut::RenderParameters params,
118218
const tcnn::uvec2* __restrict__ sortedTileRangeIndicesPtr,
119219
const uint32_t* __restrict__ sortedTileDataPtr,

threedgut_tracer/include/3dgut/renderer/gutRendererParameters.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,17 @@ struct GUTParameters {
2626
static constexpr uint32_t WarpSize = 32;
2727
static constexpr uint32_t NumWarps = BlockSize / WarpSize;
2828
static constexpr uint32_t WarpMask = 0xFFFFFFFFU;
29+
30+
// Fine-grained load balancing parameters - base dimensions
31+
static constexpr uint32_t VirtualTileX = 2; // virtual tile width in pixels
32+
static constexpr uint32_t VirtualTileY = 2; // virtual tile height in pixels
33+
// Derived constants from base dimensions
34+
static constexpr uint32_t VirtualTileSize = VirtualTileX * VirtualTileY; // 4 pixels per virtual tile
35+
static constexpr uint32_t VirtualTilesPerTileX = BlockX / VirtualTileX; // 8 virtual tiles per row
36+
static constexpr uint32_t VirtualTilesPerTileY = BlockY / VirtualTileY; // 8 virtual tiles per column
37+
static constexpr uint32_t VirtualTilesPerTile = VirtualTilesPerTileX * VirtualTilesPerTileY; // 64 total
38+
static constexpr uint32_t FineGrainedWarpsPerBlock = VirtualTileSize; // 4 warps per block (1 per pixel)
39+
static constexpr uint32_t FineGrainedThreadsPerBlock = FineGrainedWarpsPerBlock * WarpSize; // 128 threads
2940
};
3041

3142
static constexpr uint32_t InvalidParticleIdx = -1U;

threedgut_tracer/setup_3dgut.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ def to_cpp_bool(value):
5656
f"-DGAUSSIAN_N_ROLLING_SHUTTER_ITERATIONS={conf.render.splat.n_rolling_shutter_iterations}",
5757
f"-DGAUSSIAN_K_BUFFER_SIZE={conf.render.splat.k_buffer_size}",
5858
f"-DGAUSSIAN_GLOBAL_Z_ORDER={to_cpp_bool(conf.render.splat.global_z_order)}",
59+
f"-DFINE_GRAINED_LOAD_BALANCING={to_cpp_bool(getattr(conf.render.splat, 'fine_grained_load_balancing', False))}",
5960
# -- Unscented Transform --
6061
f"-DGAUSSIAN_UT_ALPHA={ut_alpha}",
6162
f"-DGAUSSIAN_UT_BETA={ut_beta}",

0 commit comments

Comments
 (0)