1414#include " io/pipelined_image_loader.hpp"
1515#include " kernels/densification_kernels.hpp"
1616#include " kernels/image_kernels.hpp"
17+ #include " kernels/mcmc_kernels.hpp"
1718#include " optimizer/adam_optimizer.hpp"
1819
1920#include < numeric>
@@ -24,6 +25,8 @@ namespace lfs::training {
2425 namespace {
2526 constexpr uint32_t IGS_PLUS_MAGIC = 0x4C464947 ; // "LFIG"
2627 constexpr uint32_t IGS_PLUS_VERSION = 1 ;
28+ constexpr int64_t ERROR_CANDIDATE_FACTOR = 4 ;
29+ constexpr float EDGE_SCORE_WEIGHT = 0 .25f ;
2730
2831 // Returns true if shape has any zero dimension (e.g., ShN at sh-degree 0)
2932 [[nodiscard]] inline bool has_zero_dimension (const lfs::core::TensorShape& shape) {
@@ -101,6 +104,12 @@ namespace lfs::training {
101104 float median = sorted[valid.numel () / 2 ].item_as <float >();
102105 tensor.div_ (std::max (median, 1e-9f ));
103106 }
107+
108+ lfs::core::Tensor normalized_by_positive_median (const lfs::core::Tensor& tensor) {
109+ auto normalized = tensor.clone ();
110+ normalize_by_positive_median_inplace (normalized);
111+ return normalized;
112+ }
104113 } // namespace
105114
106115 ImprovedGSPlus::ImprovedGSPlus (lfs::core::SplatData& splat_data)
@@ -145,7 +154,7 @@ namespace lfs::training {
145154 const double gamma = std::pow (0.1 , 1.0 / optimParams.iterations );
146155 _scheduler = std::make_unique<ExponentialLR>(*_optimizer, gamma, std::vector<ParamType>{ParamType::Means, ParamType::Scaling});
147156
148- // Initialize densification info: [2, N] tensor for tracking gradients
157+ // Initialize densification info: [2, N] tensor for tracking per-view densification statistics
149158 _splat_data->_densification_info = lfs::core::Tensor::zeros (
150159 {2 , static_cast <size_t >(_splat_data->size ())},
151160 _splat_data->means ().device ());
@@ -159,9 +168,10 @@ namespace lfs::training {
159168 this ->_current_step = 0 ;
160169
161170 this ->_budget_schedule = get_count_array ();
171+ ensure_error_score_shape ();
162172 }
163173
164- const lfs::core::Tensor ImprovedGSPlus::compute_gaussian_score (const lfs::core::Tensor& gradients ) {
174+ const lfs::core::Tensor ImprovedGSPlus::compute_gaussian_score () {
165175 const int64_t N = _splat_data->size ();
166176
167177 auto view_indices = random_cam_indices ();
@@ -207,40 +217,75 @@ namespace lfs::training {
207217 return gaussian_scores;
208218 }
209219
210- void ImprovedGSPlus::densify_with_score (const lfs::core::Tensor& scores, const lfs::core::Tensor& grads, const int64_t budget) {
211- // Get Number of Gaussians to densify
212- const lfs::core::Tensor grad_qualifiers = lfs::core::Tensor::where (grads >= _params->grad_threshold ,
213- lfs::core::Tensor::ones ({1 }), lfs::core::Tensor::zeros ({1 }))
214- .to (lfs::core::DataType::Bool);
215-
216- const int total_grads = static_cast <int >(grad_qualifiers.sum_scalar ());
217-
218- // Budget allocation
219- const int64_t curr_points = _splat_data->size ();
220- // budget caps
221- const int64_t curr_budget = std::min (budget, curr_points + total_grads);
222- const int64_t budget_for_alloc = curr_budget - curr_points;
223-
224- if (budget_for_alloc > 0 ) {
225- LAS_densify (scores, budget_for_alloc, grad_qualifiers, grads);
220+ void ImprovedGSPlus::ensure_error_score_shape () {
221+ const size_t n = static_cast <size_t >(_splat_data->size ());
222+ if (!_error_score_max.is_valid () ||
223+ _error_score_max.ndim () != 1 ||
224+ _error_score_max.numel () != n) {
225+ _error_score_max = lfs::core::Tensor::zeros ({n}, _splat_data->means ().device ());
226226 }
227227 }
228228
229- void ImprovedGSPlus::LAS_densify (const lfs::core::Tensor& scores, const int64_t budget_for_alloc, const lfs::core::Tensor& grad_mask, const lfs::core::Tensor& grads) {
229+ void ImprovedGSPlus::densify_with_score (const lfs::core::Tensor& edge_scores, const lfs::core::Tensor& error_scores, const int64_t budget) {
230+ const int64_t curr_points = static_cast <int64_t >(active_count ());
231+ const int64_t budget_for_alloc = std::max<int64_t >(0 , budget - curr_points);
232+ if (budget_for_alloc <= 0 ) {
233+ return ;
234+ }
230235
231- lfs::core::Tensor scores_masked;
236+ const auto active_indices = get_active_indices ();
237+ const int64_t total_active = static_cast <int64_t >(active_indices.numel ());
238+ if (total_active == 0 ) {
239+ return ;
240+ }
232241
233- if (_current_step < 3 ) {
234- scores_masked = scores;
235- } else {
236- const lfs::core::Tensor LAS_grad_mask = lfs::core::Tensor::where (grads >= 0.00004 ,
237- lfs::core::Tensor::ones ({1 }), lfs::core::Tensor::zeros ({1 }))
238- .to (lfs::core::DataType::Bool);
242+ const int64_t candidate_budget = std::min<int64_t >(
243+ total_active,
244+ std::max<int64_t >(budget_for_alloc, budget_for_alloc * ERROR_CANDIDATE_FACTOR));
245+
246+ const auto normalized_error = normalized_by_positive_median (error_scores);
247+ const auto normalized_edge = normalized_by_positive_median (edge_scores);
248+ const auto device = _splat_data->means ().device ();
249+
250+ auto active_mask = lfs::core::Tensor::zeros_bool ({static_cast <size_t >(_splat_data->size ())}, device);
251+ auto true_vals = lfs::core::Tensor::ones_bool ({static_cast <size_t >(active_indices.numel ())}, device);
252+ active_mask.index_put_ (active_indices, true_vals);
253+
254+ lfs::core::Tensor candidate_mask = active_mask;
255+ if (candidate_budget < total_active) {
256+ const auto active_error = normalized_error.index_select (0 , active_indices);
257+ auto [sorted_error, _] = active_error.sort (0 , true );
258+ const float threshold = sorted_error[candidate_budget - 1 ].item_as <float >();
259+ candidate_mask = active_mask.logical_and (normalized_error >= threshold);
260+ }
261+
262+ auto sampling_scores = normalized_error * (normalized_edge * EDGE_SCORE_WEIGHT + 1 .0f );
263+ sampling_scores = sampling_scores.masked_fill (~candidate_mask, 0 .0f );
264+
265+ int64_t selectable = static_cast <int64_t >(sampling_scores.count_nonzero ());
266+ if (selectable < budget_for_alloc) {
267+ auto edge_fallback = normalized_edge.masked_fill (~active_mask, 0 .0f );
268+ selectable = static_cast <int64_t >(edge_fallback.count_nonzero ());
269+ if (selectable > 0 ) {
270+ sampling_scores = std::move (edge_fallback);
271+ } else {
272+ auto active_weights = lfs::core::Tensor::zeros ({static_cast <size_t >(_splat_data->size ())}, device);
273+ auto active_weight_vals = lfs::core::Tensor::ones ({static_cast <size_t >(active_indices.numel ())}, device);
274+ active_weights.index_put_ (active_indices, active_weight_vals);
275+ sampling_scores = std::move (active_weights);
276+ selectable = total_active;
277+ }
278+ }
239279
240- scores_masked = scores.masked_fill (~LAS_grad_mask, 0 );
280+ if (selectable <= 0 ) {
281+ return ;
241282 }
242283
243- const lfs::core::Tensor sampled_idxs = lfs::core::Tensor::multinomial (scores_masked, budget_for_alloc, false );
284+ LAS_densify (sampling_scores.clamp_min (1e-12f ), std::min<int64_t >(budget_for_alloc, selectable));
285+ }
286+
287+ void ImprovedGSPlus::LAS_densify (const lfs::core::Tensor& scores, const int64_t budget_for_alloc) {
288+ const lfs::core::Tensor sampled_idxs = lfs::core::Tensor::multinomial (scores, budget_for_alloc, false );
244289
245290 LOG_DEBUG (" split(): {} Gaussians to long axis split" , budget_for_alloc);
246291
@@ -409,11 +454,7 @@ namespace lfs::training {
409454
410455 assert (_views && " set_views() must be called before training" );
411456
412- const lfs::core::Tensor numer = _splat_data->_densification_info [1 ];
413- const lfs::core::Tensor denom = _splat_data->_densification_info [0 ];
414- _precomputed_grads = numer / denom.clamp_min (1 .0f );
415-
416- _precomputed_scores = compute_gaussian_score (_precomputed_grads);
457+ _precomputed_scores = compute_gaussian_score ();
417458 _precompute_valid = true ;
418459 }
419460
@@ -427,10 +468,33 @@ namespace lfs::training {
427468 return ;
428469 }
429470
471+ {
472+ const size_t n = static_cast <size_t >(_splat_data->size ());
473+ const auto & info = _splat_data->_densification_info ;
474+ if (!info.is_valid () || info.ndim () != 2 || info.shape ()[0 ] < 2 || info.shape ()[1 ] != n) {
475+ _splat_data->_densification_info = lfs::core::Tensor::zeros ({2 , n}, _splat_data->means ().device ());
476+ }
477+ ensure_error_score_shape ();
478+
479+ const auto & accum = _splat_data->_densification_info ;
480+ if (accum.is_valid () &&
481+ accum.ndim () == 2 &&
482+ accum.shape ()[0 ] >= 2 &&
483+ accum.shape ()[1 ] == _error_score_max.numel ()) {
484+ const float * error_row = accum.ptr <float >() + accum.shape ()[1 ];
485+ lfs::training::mcmc::launch_elementwise_max_inplace (
486+ _error_score_max.ptr <float >(),
487+ error_row,
488+ _error_score_max.numel ());
489+ }
490+
491+ _splat_data->_densification_info .zero_ ();
492+ }
493+
430494 if (is_refining (iter)) {
431495 assert (_precompute_valid);
432496
433- densify_with_score (_precomputed_scores, _precomputed_grads , get_current_budget ());
497+ densify_with_score (_precomputed_scores, _error_score_max , get_current_budget ());
434498
435499 opacity_prune (iter);
436500
@@ -439,11 +503,12 @@ namespace lfs::training {
439503 _splat_data->_densification_info = lfs::core::Tensor::zeros (
440504 {2 , static_cast <size_t >(_splat_data->size ())},
441505 _splat_data->means ().device ());
506+ ensure_error_score_shape ();
507+ _error_score_max.zero_ ();
442508
443509 this ->_current_step ++;
444510
445511 _precomputed_scores = lfs::core::Tensor ();
446- _precomputed_grads = lfs::core::Tensor ();
447512 _precompute_valid = false ;
448513 }
449514
@@ -453,6 +518,7 @@ namespace lfs::training {
453518
454519 if (iter == _params->stop_refine ) {
455520 _splat_data->_densification_info = lfs::core::Tensor::empty ({0 });
521+ _error_score_max = lfs::core::Tensor::empty ({0 });
456522
457523 lfs::core::CudaMemoryPool::instance ().trim_cached_memory ();
458524 }
@@ -491,6 +557,51 @@ namespace lfs::training {
491557 }
492558 }
493559
560+ size_t ImprovedGSPlus::active_count () const {
561+ if (!_free_mask.is_valid ()) {
562+ return static_cast <size_t >(_splat_data->size ());
563+ }
564+
565+ const size_t current_size = static_cast <size_t >(_splat_data->size ());
566+ if (current_size == 0 ) {
567+ return 0 ;
568+ }
569+
570+ auto active_region = _free_mask.slice (0 , 0 , current_size);
571+ const auto free_count_val = static_cast <size_t >(active_region.sum_scalar ());
572+ return current_size - free_count_val;
573+ }
574+
575+ size_t ImprovedGSPlus::free_count () const {
576+ if (!_free_mask.is_valid ()) {
577+ return 0 ;
578+ }
579+
580+ const size_t current_size = static_cast <size_t >(_splat_data->size ());
581+ if (current_size == 0 ) {
582+ return 0 ;
583+ }
584+
585+ auto active_region = _free_mask.slice (0 , 0 , current_size);
586+ return static_cast <size_t >(active_region.sum_scalar ());
587+ }
588+
589+ lfs::core::Tensor ImprovedGSPlus::get_active_indices () const {
590+ const size_t current_size = static_cast <size_t >(_splat_data->size ());
591+ if (current_size == 0 ) {
592+ return lfs::core::Tensor ();
593+ }
594+
595+ if (!_free_mask.is_valid () || free_count () == 0 ) {
596+ auto all_active = lfs::core::Tensor::ones_bool ({current_size}, _splat_data->means ().device ());
597+ return all_active.nonzero ().squeeze (-1 );
598+ }
599+
600+ auto active_region = _free_mask.slice (0 , 0 , current_size);
601+ auto is_active = active_region.logical_not ();
602+ return is_active.nonzero ().squeeze (-1 );
603+ }
604+
494605 std::vector<int > ImprovedGSPlus::random_cam_indices (const int N) const {
495606 const int num_cam_dataset = _views->size ();
496607 int num_samples = 0 ;
@@ -598,6 +709,11 @@ namespace lfs::training {
598709 zero_optimizer_state (ParamType::ShN);
599710 zero_optimizer_state (ParamType::Opacity);
600711
712+ if (_error_score_max.is_valid () && _error_score_max.ndim () == 1 && _error_score_max.numel () >= _splat_data->size ()) {
713+ auto zeros = lfs::core::Tensor::zeros ({static_cast <size_t >(num_pruned)}, _error_score_max.device ());
714+ _error_score_max.index_put_ (prune_indices, zeros);
715+ }
716+
601717 LOG_DEBUG (" remove(): soft-deleted {} Gaussians (marked as free, rotation & gradients zeroed)" , num_pruned);
602718 }
603719
@@ -682,6 +798,11 @@ namespace lfs::training {
682798 auto false_vals = lfs::core::Tensor::zeros_bool ({static_cast <size_t >(slots_to_fill)}, target_indices.device ());
683799 _free_mask.index_put_ (target_indices, false_vals);
684800
801+ if (_error_score_max.is_valid () && _error_score_max.ndim () == 1 && _error_score_max.numel () >= current_size) {
802+ auto zeros = lfs::core::Tensor::zeros ({static_cast <size_t >(slots_to_fill)}, _error_score_max.device ());
803+ _error_score_max.index_put_ (target_indices, zeros);
804+ }
805+
685806 return {target_indices, count - slots_to_fill};
686807 }
687808
@@ -752,8 +873,8 @@ namespace lfs::training {
752873 const size_t capacity = _params->max_cap > 0 ? static_cast <size_t >(_params->max_cap )
753874 : static_cast <size_t >(_splat_data->size ());
754875 _free_mask = lfs::core::Tensor::zeros_bool ({capacity}, _splat_data->means ().device ());
755- _precomputed_grads = lfs::core::Tensor ();
756876 _precomputed_scores = lfs::core::Tensor ();
877+ _error_score_max = lfs::core::Tensor::zeros ({static_cast <size_t >(_splat_data->size ())}, _splat_data->means ().device ());
757878 _precompute_valid = false ;
758879 _current_step = 0 ;
759880 _budget_schedule = get_count_array ();
@@ -801,8 +922,8 @@ namespace lfs::training {
801922 _free_mask = lfs::core::Tensor::zeros_bool ({capacity}, _splat_data->means ().device ());
802923 }
803924
804- _precomputed_grads = lfs::core::Tensor ();
805925 _precomputed_scores = lfs::core::Tensor ();
926+ _error_score_max = lfs::core::Tensor::zeros ({static_cast <size_t >(_splat_data->size ())}, _splat_data->means ().device ());
806927 _precompute_valid = false ;
807928
808929 LOG_DEBUG (" Deserialized ImprovedGSPlus (version {})" , version);
0 commit comments