Skip to content

Commit 23c0c6f

Browse files
committed
improved igs+
1 parent 224b1af commit 23c0c6f

3 files changed

Lines changed: 167 additions & 45 deletions

File tree

src/training/strategies/improved_gs_plus.cpp

Lines changed: 158 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "io/pipelined_image_loader.hpp"
1515
#include "kernels/densification_kernels.hpp"
1616
#include "kernels/image_kernels.hpp"
17+
#include "kernels/mcmc_kernels.hpp"
1718
#include "optimizer/adam_optimizer.hpp"
1819

1920
#include <numeric>
@@ -24,6 +25,8 @@ namespace lfs::training {
2425
namespace {
2526
constexpr uint32_t IGS_PLUS_MAGIC = 0x4C464947; // "LFIG"
2627
constexpr uint32_t IGS_PLUS_VERSION = 1;
28+
constexpr int64_t ERROR_CANDIDATE_FACTOR = 4;
29+
constexpr float EDGE_SCORE_WEIGHT = 0.25f;
2730

2831
// Returns true if shape has any zero dimension (e.g., ShN at sh-degree 0)
2932
[[nodiscard]] inline bool has_zero_dimension(const lfs::core::TensorShape& shape) {
@@ -101,6 +104,12 @@ namespace lfs::training {
101104
float median = sorted[valid.numel() / 2].item_as<float>();
102105
tensor.div_(std::max(median, 1e-9f));
103106
}
107+
108+
// Non-mutating counterpart of normalize_by_positive_median_inplace():
// clones the input, normalizes the clone in place and returns it.
// The argument itself is not written to (assuming clone() yields an
// independent copy -- confirm against the Tensor implementation).
lfs::core::Tensor normalized_by_positive_median(const lfs::core::Tensor& tensor) {
    auto scaled_copy = tensor.clone();
    normalize_by_positive_median_inplace(scaled_copy);
    return scaled_copy;
}
104113
} // namespace
105114

106115
ImprovedGSPlus::ImprovedGSPlus(lfs::core::SplatData& splat_data)
@@ -145,7 +154,7 @@ namespace lfs::training {
145154
const double gamma = std::pow(0.1, 1.0 / optimParams.iterations);
146155
_scheduler = std::make_unique<ExponentialLR>(*_optimizer, gamma, std::vector<ParamType>{ParamType::Means, ParamType::Scaling});
147156

148-
// Initialize densification info: [2, N] tensor for tracking gradients
157+
// Initialize densification info: [2, N] tensor for tracking per-view densification statistics
149158
_splat_data->_densification_info = lfs::core::Tensor::zeros(
150159
{2, static_cast<size_t>(_splat_data->size())},
151160
_splat_data->means().device());
@@ -159,9 +168,10 @@ namespace lfs::training {
159168
this->_current_step = 0;
160169

161170
this->_budget_schedule = get_count_array();
171+
ensure_error_score_shape();
162172
}
163173

164-
const lfs::core::Tensor ImprovedGSPlus::compute_gaussian_score(const lfs::core::Tensor& gradients) {
174+
const lfs::core::Tensor ImprovedGSPlus::compute_gaussian_score() {
165175
const int64_t N = _splat_data->size();
166176

167177
auto view_indices = random_cam_indices();
@@ -207,40 +217,75 @@ namespace lfs::training {
207217
return gaussian_scores;
208218
}
209219

210-
void ImprovedGSPlus::densify_with_score(const lfs::core::Tensor& scores, const lfs::core::Tensor& grads, const int64_t budget) {
211-
// Get Number of Gaussians to densify
212-
const lfs::core::Tensor grad_qualifiers = lfs::core::Tensor::where(grads >= _params->grad_threshold,
213-
lfs::core::Tensor::ones({1}), lfs::core::Tensor::zeros({1}))
214-
.to(lfs::core::DataType::Bool);
215-
216-
const int total_grads = static_cast<int>(grad_qualifiers.sum_scalar());
217-
218-
// Budget allocation
219-
const int64_t curr_points = _splat_data->size();
220-
// budget caps
221-
const int64_t curr_budget = std::min(budget, curr_points + total_grads);
222-
const int64_t budget_for_alloc = curr_budget - curr_points;
223-
224-
if (budget_for_alloc > 0) {
225-
LAS_densify(scores, budget_for_alloc, grad_qualifiers, grads);
220+
void ImprovedGSPlus::ensure_error_score_shape() {
221+
const size_t n = static_cast<size_t>(_splat_data->size());
222+
if (!_error_score_max.is_valid() ||
223+
_error_score_max.ndim() != 1 ||
224+
_error_score_max.numel() != n) {
225+
_error_score_max = lfs::core::Tensor::zeros({n}, _splat_data->means().device());
226226
}
227227
}
228228

229-
void ImprovedGSPlus::LAS_densify(const lfs::core::Tensor& scores, const int64_t budget_for_alloc, const lfs::core::Tensor& grad_mask, const lfs::core::Tensor& grads) {
229+
void ImprovedGSPlus::densify_with_score(const lfs::core::Tensor& edge_scores, const lfs::core::Tensor& error_scores, const int64_t budget) {
230+
const int64_t curr_points = static_cast<int64_t>(active_count());
231+
const int64_t budget_for_alloc = std::max<int64_t>(0, budget - curr_points);
232+
if (budget_for_alloc <= 0) {
233+
return;
234+
}
230235

231-
lfs::core::Tensor scores_masked;
236+
const auto active_indices = get_active_indices();
237+
const int64_t total_active = static_cast<int64_t>(active_indices.numel());
238+
if (total_active == 0) {
239+
return;
240+
}
232241

233-
if (_current_step < 3) {
234-
scores_masked = scores;
235-
} else {
236-
const lfs::core::Tensor LAS_grad_mask = lfs::core::Tensor::where(grads >= 0.00004,
237-
lfs::core::Tensor::ones({1}), lfs::core::Tensor::zeros({1}))
238-
.to(lfs::core::DataType::Bool);
242+
const int64_t candidate_budget = std::min<int64_t>(
243+
total_active,
244+
std::max<int64_t>(budget_for_alloc, budget_for_alloc * ERROR_CANDIDATE_FACTOR));
245+
246+
const auto normalized_error = normalized_by_positive_median(error_scores);
247+
const auto normalized_edge = normalized_by_positive_median(edge_scores);
248+
const auto device = _splat_data->means().device();
249+
250+
auto active_mask = lfs::core::Tensor::zeros_bool({static_cast<size_t>(_splat_data->size())}, device);
251+
auto true_vals = lfs::core::Tensor::ones_bool({static_cast<size_t>(active_indices.numel())}, device);
252+
active_mask.index_put_(active_indices, true_vals);
253+
254+
lfs::core::Tensor candidate_mask = active_mask;
255+
if (candidate_budget < total_active) {
256+
const auto active_error = normalized_error.index_select(0, active_indices);
257+
auto [sorted_error, _] = active_error.sort(0, true);
258+
const float threshold = sorted_error[candidate_budget - 1].item_as<float>();
259+
candidate_mask = active_mask.logical_and(normalized_error >= threshold);
260+
}
261+
262+
auto sampling_scores = normalized_error * (normalized_edge * EDGE_SCORE_WEIGHT + 1.0f);
263+
sampling_scores = sampling_scores.masked_fill(~candidate_mask, 0.0f);
264+
265+
int64_t selectable = static_cast<int64_t>(sampling_scores.count_nonzero());
266+
if (selectable < budget_for_alloc) {
267+
auto edge_fallback = normalized_edge.masked_fill(~active_mask, 0.0f);
268+
selectable = static_cast<int64_t>(edge_fallback.count_nonzero());
269+
if (selectable > 0) {
270+
sampling_scores = std::move(edge_fallback);
271+
} else {
272+
auto active_weights = lfs::core::Tensor::zeros({static_cast<size_t>(_splat_data->size())}, device);
273+
auto active_weight_vals = lfs::core::Tensor::ones({static_cast<size_t>(active_indices.numel())}, device);
274+
active_weights.index_put_(active_indices, active_weight_vals);
275+
sampling_scores = std::move(active_weights);
276+
selectable = total_active;
277+
}
278+
}
239279

240-
scores_masked = scores.masked_fill(~LAS_grad_mask, 0);
280+
if (selectable <= 0) {
281+
return;
241282
}
242283

243-
const lfs::core::Tensor sampled_idxs = lfs::core::Tensor::multinomial(scores_masked, budget_for_alloc, false);
284+
LAS_densify(sampling_scores.clamp_min(1e-12f), std::min<int64_t>(budget_for_alloc, selectable));
285+
}
286+
287+
void ImprovedGSPlus::LAS_densify(const lfs::core::Tensor& scores, const int64_t budget_for_alloc) {
288+
const lfs::core::Tensor sampled_idxs = lfs::core::Tensor::multinomial(scores, budget_for_alloc, false);
244289

245290
LOG_DEBUG("split(): {} Gaussians to long axis split", budget_for_alloc);
246291

@@ -409,11 +454,7 @@ namespace lfs::training {
409454

410455
assert(_views && "set_views() must be called before training");
411456

412-
const lfs::core::Tensor numer = _splat_data->_densification_info[1];
413-
const lfs::core::Tensor denom = _splat_data->_densification_info[0];
414-
_precomputed_grads = numer / denom.clamp_min(1.0f);
415-
416-
_precomputed_scores = compute_gaussian_score(_precomputed_grads);
457+
_precomputed_scores = compute_gaussian_score();
417458
_precompute_valid = true;
418459
}
419460

@@ -427,10 +468,33 @@ namespace lfs::training {
427468
return;
428469
}
429470

471+
{
472+
const size_t n = static_cast<size_t>(_splat_data->size());
473+
const auto& info = _splat_data->_densification_info;
474+
if (!info.is_valid() || info.ndim() != 2 || info.shape()[0] < 2 || info.shape()[1] != n) {
475+
_splat_data->_densification_info = lfs::core::Tensor::zeros({2, n}, _splat_data->means().device());
476+
}
477+
ensure_error_score_shape();
478+
479+
const auto& accum = _splat_data->_densification_info;
480+
if (accum.is_valid() &&
481+
accum.ndim() == 2 &&
482+
accum.shape()[0] >= 2 &&
483+
accum.shape()[1] == _error_score_max.numel()) {
484+
const float* error_row = accum.ptr<float>() + accum.shape()[1];
485+
lfs::training::mcmc::launch_elementwise_max_inplace(
486+
_error_score_max.ptr<float>(),
487+
error_row,
488+
_error_score_max.numel());
489+
}
490+
491+
_splat_data->_densification_info.zero_();
492+
}
493+
430494
if (is_refining(iter)) {
431495
assert(_precompute_valid);
432496

433-
densify_with_score(_precomputed_scores, _precomputed_grads, get_current_budget());
497+
densify_with_score(_precomputed_scores, _error_score_max, get_current_budget());
434498

435499
opacity_prune(iter);
436500

@@ -439,11 +503,12 @@ namespace lfs::training {
439503
_splat_data->_densification_info = lfs::core::Tensor::zeros(
440504
{2, static_cast<size_t>(_splat_data->size())},
441505
_splat_data->means().device());
506+
ensure_error_score_shape();
507+
_error_score_max.zero_();
442508

443509
this->_current_step++;
444510

445511
_precomputed_scores = lfs::core::Tensor();
446-
_precomputed_grads = lfs::core::Tensor();
447512
_precompute_valid = false;
448513
}
449514

@@ -453,6 +518,7 @@ namespace lfs::training {
453518

454519
if (iter == _params->stop_refine) {
455520
_splat_data->_densification_info = lfs::core::Tensor::empty({0});
521+
_error_score_max = lfs::core::Tensor::empty({0});
456522

457523
lfs::core::CudaMemoryPool::instance().trim_cached_memory();
458524
}
@@ -491,6 +557,51 @@ namespace lfs::training {
491557
}
492558
}
493559

560+
size_t ImprovedGSPlus::active_count() const {
561+
if (!_free_mask.is_valid()) {
562+
return static_cast<size_t>(_splat_data->size());
563+
}
564+
565+
const size_t current_size = static_cast<size_t>(_splat_data->size());
566+
if (current_size == 0) {
567+
return 0;
568+
}
569+
570+
auto active_region = _free_mask.slice(0, 0, current_size);
571+
const auto free_count_val = static_cast<size_t>(active_region.sum_scalar());
572+
return current_size - free_count_val;
573+
}
574+
575+
size_t ImprovedGSPlus::free_count() const {
576+
if (!_free_mask.is_valid()) {
577+
return 0;
578+
}
579+
580+
const size_t current_size = static_cast<size_t>(_splat_data->size());
581+
if (current_size == 0) {
582+
return 0;
583+
}
584+
585+
auto active_region = _free_mask.slice(0, 0, current_size);
586+
return static_cast<size_t>(active_region.sum_scalar());
587+
}
588+
589+
// Indices of the slots in [0, size) that are not flagged as free.
//
// Returns a default-constructed Tensor when the model holds no Gaussians.
// Otherwise builds a boolean "live" mask over the first `size` slots and
// returns nonzero().squeeze(-1) of it.
lfs::core::Tensor ImprovedGSPlus::get_active_indices() const {
    const auto slot_count = static_cast<size_t>(_splat_data->size());
    if (slot_count == 0) {
        return lfs::core::Tensor();
    }

    // No valid mask, or a mask with no free entries in range: every slot is live.
    const bool nothing_free = !_free_mask.is_valid() || free_count() == 0;

    auto live_mask = nothing_free
                         ? lfs::core::Tensor::ones_bool({slot_count}, _splat_data->means().device())
                         : _free_mask.slice(0, 0, slot_count).logical_not();

    return live_mask.nonzero().squeeze(-1);
}
604+
494605
std::vector<int> ImprovedGSPlus::random_cam_indices(const int N) const {
495606
const int num_cam_dataset = _views->size();
496607
int num_samples = 0;
@@ -598,6 +709,11 @@ namespace lfs::training {
598709
zero_optimizer_state(ParamType::ShN);
599710
zero_optimizer_state(ParamType::Opacity);
600711

712+
if (_error_score_max.is_valid() && _error_score_max.ndim() == 1 && _error_score_max.numel() >= _splat_data->size()) {
713+
auto zeros = lfs::core::Tensor::zeros({static_cast<size_t>(num_pruned)}, _error_score_max.device());
714+
_error_score_max.index_put_(prune_indices, zeros);
715+
}
716+
601717
LOG_DEBUG("remove(): soft-deleted {} Gaussians (marked as free, rotation & gradients zeroed)", num_pruned);
602718
}
603719

@@ -682,6 +798,11 @@ namespace lfs::training {
682798
auto false_vals = lfs::core::Tensor::zeros_bool({static_cast<size_t>(slots_to_fill)}, target_indices.device());
683799
_free_mask.index_put_(target_indices, false_vals);
684800

801+
if (_error_score_max.is_valid() && _error_score_max.ndim() == 1 && _error_score_max.numel() >= current_size) {
802+
auto zeros = lfs::core::Tensor::zeros({static_cast<size_t>(slots_to_fill)}, _error_score_max.device());
803+
_error_score_max.index_put_(target_indices, zeros);
804+
}
805+
685806
return {target_indices, count - slots_to_fill};
686807
}
687808

@@ -752,8 +873,8 @@ namespace lfs::training {
752873
const size_t capacity = _params->max_cap > 0 ? static_cast<size_t>(_params->max_cap)
753874
: static_cast<size_t>(_splat_data->size());
754875
_free_mask = lfs::core::Tensor::zeros_bool({capacity}, _splat_data->means().device());
755-
_precomputed_grads = lfs::core::Tensor();
756876
_precomputed_scores = lfs::core::Tensor();
877+
_error_score_max = lfs::core::Tensor::zeros({static_cast<size_t>(_splat_data->size())}, _splat_data->means().device());
757878
_precompute_valid = false;
758879
_current_step = 0;
759880
_budget_schedule = get_count_array();
@@ -801,8 +922,8 @@ namespace lfs::training {
801922
_free_mask = lfs::core::Tensor::zeros_bool({capacity}, _splat_data->means().device());
802923
}
803924

804-
_precomputed_grads = lfs::core::Tensor();
805925
_precomputed_scores = lfs::core::Tensor();
926+
_error_score_max = lfs::core::Tensor::zeros({static_cast<size_t>(_splat_data->size())}, _splat_data->means().device());
806927
_precompute_valid = false;
807928

808929
LOG_DEBUG("Deserialized ImprovedGSPlus (version {})", version);

src/training/strategies/improved_gs_plus.hpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -81,9 +81,10 @@ namespace lfs::training {
8181

8282
std::vector<int64_t> get_count_array();
8383

84-
const lfs::core::Tensor compute_gaussian_score(const lfs::core::Tensor& gradients);
85-
void densify_with_score(const lfs::core::Tensor& scores, const lfs::core::Tensor& grads, const int64_t budget);
86-
void LAS_densify(const lfs::core::Tensor& scores, const int64_t allocation_budget, const lfs::core::Tensor& grad_mask, const lfs::core::Tensor& grads);
84+
const lfs::core::Tensor compute_gaussian_score();
85+
void ensure_error_score_shape();
86+
void densify_with_score(const lfs::core::Tensor& edge_scores, const lfs::core::Tensor& error_scores, const int64_t budget);
87+
void LAS_densify(const lfs::core::Tensor& scores, const int64_t allocation_budget);
8788

8889
void reset_opacity();
8990
void prune_post_reset();
@@ -118,8 +119,8 @@ namespace lfs::training {
118119
std::unique_ptr<const lfs::core::param::OptimizationParameters> _params;
119120

120121
// Pre-computed edge scores for non-blocking densification, plus the running per-Gaussian maximum of the error scores
121-
lfs::core::Tensor _precomputed_grads;
122122
lfs::core::Tensor _precomputed_scores;
123+
lfs::core::Tensor _error_score_max;
123124
bool _precompute_valid = false;
124125

125126
// Free slot tracking - bool tensor [capacity], true = slot is free for reuse

src/training/trainer.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1212,9 +1212,9 @@ namespace lfs::training {
12121212
ppisp_cam_idx >= 0 &&
12131213
ppisp_cam_idx < ppisp_controller_pool_->num_cameras();
12141214
const bool use_pixel_error_densification =
1215-
(params_.optimization.strategy == "mcmc");
1216-
const bool use_ssim_error = use_pixel_error_densification &&
1217-
(params_.optimization.strategy == "mcmc");
1215+
(params_.optimization.strategy == "mcmc" ||
1216+
params_.optimization.strategy == "igs+");
1217+
const bool use_ssim_error = use_pixel_error_densification;
12181218

12191219
// Loop over tiles (row-major order)
12201220
for (int tile_idx = 0; tile_idx < num_tiles; ++tile_idx) {
@@ -1525,7 +1525,7 @@ namespace lfs::training {
15251525
tile_error_map = densification_error_map_;
15261526
}
15271527
} else if (use_ssim_error) {
1528-
// lambda_dssim == 0 but MCMC needs SSIM error: standalone pass
1528+
// lambda_dssim == 0 but error-priority densification still needs SSIM error
15291529
lfs::core::Tensor pred_chw = corrected_image;
15301530
lfs::core::Tensor gt_chw = gt_tile;
15311531
if (pred_chw.ndim() == 3 && pred_chw.shape()[2] == 3 &&

0 commit comments

Comments
 (0)