@@ -291,6 +291,183 @@ struct GUTKBufferRenderer : Params {
291291 }
292292 }
293293
// Fine-grained balanced forward renderer: Gaussian-wise parallelism with
// warp-level cooperation.
//
// One warp marches a SINGLE ray through the sorted particle list of one
// screen tile: each lane intersects one Gaussian per iteration, a warp-level
// inclusive prefix-product yields the per-lane transmittance, and a tree
// reduction folds the alpha-weighted contributions back into the ray.
//
// Preconditions:
//  - Must be executed by all WarpSize lanes of a warp with warp-uniform
//    `tile`/`tileGrid` (every __shfl_*_sync / __ballot_sync / __all_sync
//    below uses the full-warp mask).
//  - Only supports K == 0 (no hit buffer); enforced by the static_assert.
//
// NOTE(review): assumes each lane holds a replicated copy of the same ray.
// Only lane 0's features/hitT/hit-count stay consistent after the reduction,
// while transmittance and liveness are updated uniformly on all lanes --
// confirm against the caller that only lane 0's ray state is written out.
template <typename TRay>
static inline __device__ void evalForwardNoKBufferBalanced(
    const threedgut::RenderParameters& params,
    TRay& ray,
    const tcnn::uvec2* __restrict__ sortedTileRangeIndicesPtr,
    const uint32_t* __restrict__ sortedTileParticleIdxPtr,
    const tcnn::vec2* __restrict__ particlesProjectedPositionPtr,
    const tcnn::vec4* __restrict__ particlesProjectedConicOpacityPtr,
    const float* __restrict__ particlesGlobalDepthPtr,
    const float* __restrict__ particlesPrecomputedFeaturesPtr,
    const tcnn::uvec2& tile,
    const tcnn::uvec2& tileGrid,
    const int laneId,
    threedgut::MemoryHandles parameters,
    tcnn::vec2* __restrict__ particlesProjectedPositionGradPtr     = nullptr,
    tcnn::vec4* __restrict__ particlesProjectedConicOpacityGradPtr = nullptr,
    float* __restrict__ particlesGlobalDepthGradPtr                = nullptr,
    float* __restrict__ particlesPrecomputedFeaturesGradPtr        = nullptr,
    threedgut::MemoryHandles parametersGradient = {}) {

    using namespace threedgut;

    static_assert(Params::KHitBufferSize == 0,
                  "evalForwardNoKBufferBalanced only supports K=0 (no hit buffer). Use evalKBuffer for K>0 cases.");

    // Each warp consumes the [x, y) particle range of exactly one tile.
    const uint32_t tileIdx                        = tile.y * tileGrid.x + tile.x;
    const tcnn::uvec2 tileParticleRangeIndices    = sortedTileRangeIndicesPtr[tileIdx];
    const uint32_t tileNumParticlesToProcess      = tileParticleRangeIndices.y - tileParticleRangeIndices.x;

    // Precomputed (view-independent) features come from a flat buffer;
    // per-ray features are evaluated on the fly from the particle system.
    const TFeaturesVec* particleFeaturesBuffer =
        Params::PerRayParticleFeatures
            ? nullptr
            : reinterpret_cast<const TFeaturesVec*>(particlesPrecomputedFeaturesPtr);
    // NOTE(review): kept for signature/usage symmetry with the backward path;
    // not read in the code below.
    TFeaturesVec* particleFeaturesGradientBuffer =
        (Params::PerRayParticleFeatures || !Backward)
            ? nullptr
            : reinterpret_cast<TFeaturesVec*>(particlesPrecomputedFeaturesGradPtr);
    (void)particleFeaturesGradientBuffer;

    // Initialize the particle system (density always; gradients only when the
    // Backward template flag is set).
    Particles particles;
    particles.initializeDensity(parameters);
    if constexpr (Backward) {
        particles.initializeDensityGradient(parametersGradient);
    }
    particles.initializeFeatures(parameters);
    if constexpr (Backward && Params::PerRayParticleFeatures) {
        particles.initializeFeaturesGradient(parametersGradient);
    }

    // Round the particle count up to a warp multiple so every lane runs the
    // same number of iterations (required: all warp-sync intrinsics below use
    // the full-warp mask).
    constexpr uint32_t WarpSize = GUTParameters::Tiling::WarpSize; // 32 threads per warp
    constexpr uint32_t WarpMask = GUTParameters::Tiling::WarpMask; // full-warp participation mask
    const uint32_t alignedParticleCount =
        ((tileNumParticlesToProcess + WarpSize - 1) / WarpSize) * WarpSize;

    // Main loop: WarpSize lanes process WarpSize Gaussians per batch.
    for (uint32_t j = laneId; j < alignedParticleCount; j += WarpSize) {
        // kill() and the transmittance update below are warp-uniform, so all
        // lanes observe the same liveness and leave the loop together.
        if (!ray.isAlive()) {
            break;
        }

        float hitAlpha           = 0.0f;
        float hitT               = 0.0f;
        TFeaturesVec hitFeatures = TFeaturesVec::zero();
        bool validHit            = false;

        // Step 1: each lane tests one Gaussian against the shared ray.
        if (j < tileNumParticlesToProcess) {
            const uint32_t toProcessSortedIndex = tileParticleRangeIndices.x + j;
            const uint32_t particleIdx          = sortedTileParticleIdxPtr[toProcessSortedIndex];

            if (particleIdx != GUTParameters::InvalidParticleIdx) {
                const auto densityParams = particles.fetchDensityParameters(particleIdx);

                if (particles.densityHit(ray.origin, ray.direction, densityParams, hitAlpha, hitT) &&
                    (hitT > ray.tMinMax.x) && (hitT < ray.tMinMax.y)) {

                    validHit = true;

                    if constexpr (Params::PerRayParticleFeatures) {
                        hitFeatures = particles.featuresFromBuffer(particleIdx, ray.direction);
                    } else {
                        hitFeatures = tcnn::max(particleFeaturesBuffer[particleIdx], 0.f);
                    }
                }
            }
        }

        // Nothing hit anywhere in this warp batch -> next batch.
        if (__all_sync(WarpMask, !validHit)) {
            continue;
        }

        // Steps 2+3: inclusive warp prefix-product of the per-particle
        // transmittance (1 - alpha); lane i ends with prod_{k<=i}(1 - alpha_k).
        float inclTransmittance = validHit ? (1.0f - hitAlpha) : 1.0f;
        for (uint32_t offset = 1; offset < WarpSize; offset <<= 1) {
            const float neighbor = __shfl_up_sync(WarpMask, inclTransmittance, offset);
            if (laneId >= offset) {
                inclTransmittance *= neighbor;
            }
        }

        // Exclusive prefix = previous lane's inclusive value (uniform shuffle,
        // executed by all lanes). Fix: the former division of the inclusive
        // value by (1 - hitAlpha) produced Inf/NaN for fully opaque particles
        // (hitAlpha == 1).
        const float prevInclTransmittance = __shfl_up_sync(WarpMask, inclTransmittance, 1);
        const float exclTransmittance     = (laneId > 0) ? prevInclTransmittance : 1.0f;

        // Step 4: detect early termination. The scan is monotonically
        // non-increasing across lanes, so the first set ballot bit is the exact
        // lane at which the ray's transmittance drops below the threshold.
        const uint32_t earlyTerminationMask = __ballot_sync(
            WarpMask,
            validHit && (ray.transmittance * inclTransmittance) < Particles::MinTransmittanceThreshold);

        const bool shouldTerminate = (earlyTerminationMask != 0u);
        // Contributions are taken up to and including terminationLane; with no
        // termination the whole batch (lane WarpSize - 1) contributes.
        const int terminationLane =
            shouldTerminate ? (__ffs(earlyTerminationMask) - 1) : static_cast<int>(WarpSize - 1);
        if (shouldTerminate) {
            ray.kill();
        }

        // Step 5: per-lane weighted contribution (classic front-to-back alpha
        // compositing weight: alpha_i * T_ray * prod_{k<i}(1 - alpha_k)).
        const float rayTransmittanceIn = ray.transmittance;
        TFeaturesVec accumulatedFeatures = TFeaturesVec::zero();
        float accumulatedHitT            = 0.0f;
        uint32_t accumulatedHitCount     = 0u;

        if (validHit && laneId <= terminationLane) {
            const float hitWeight = hitAlpha * rayTransmittanceIn * exclTransmittance;

            for (int featIdx = 0; featIdx < Particles::FeaturesDim; ++featIdx) {
                accumulatedFeatures[featIdx] = hitFeatures[featIdx] * hitWeight;
            }
            accumulatedHitT     = hitT * hitWeight;
            accumulatedHitCount = (hitWeight > 0.0f) ? 1u : 0u;
        }

        // Step 6: warp tree reduction of the weighted contributions.
        for (int featIdx = 0; featIdx < Particles::FeaturesDim; ++featIdx) {
            for (uint32_t offset = WarpSize / 2; offset > 0; offset >>= 1) {
                accumulatedFeatures[featIdx] +=
                    __shfl_down_sync(WarpMask, accumulatedFeatures[featIdx], offset);
            }
        }
        for (uint32_t offset = WarpSize / 2; offset > 0; offset >>= 1) {
            accumulatedHitT     += __shfl_down_sync(WarpMask, accumulatedHitT, offset);
            accumulatedHitCount += __shfl_down_sync(WarpMask, accumulatedHitCount, offset);
        }

        // Step 7: only lane 0 folds the reduced values into the ray state.
        if (laneId == 0) {
            for (int featIdx = 0; featIdx < Particles::FeaturesDim; ++featIdx) {
                ray.features[featIdx] += accumulatedFeatures[featIdx];
            }
            ray.hitT += accumulatedHitT;
            ray.countHit(accumulatedHitCount);
        }

        // Step 8: advance the transmittance to the termination point (the whole
        // batch when no lane terminated); warp-uniform since terminationLane and
        // the scan source lane are uniform. Fix: the full-batch transmittance
        // was previously applied even on early termination, inconsistent with
        // the contributions accumulated only up to terminationLane.
        ray.transmittance =
            rayTransmittanceIn * __shfl_sync(WarpMask, inclTransmittance, terminationLane);

        if (shouldTerminate) {
            break;
        }
    }
}
470+
294471 template <typename TRay>
295472 static inline __device__ void evalBackwardNoKBuffer (TRay& ray,
296473 Particles& particles,
0 commit comments