Skip to content

Commit bcffda0

Browse files
committed
fix scene_scale at --init
1 parent 250838d commit bcffda0

2 files changed

Lines changed: 47 additions & 4 deletions

File tree

src/core/include/core/splat_data.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ namespace lfs::core {
9090
int get_active_sh_degree() const { return _active_sh_degree; }
9191
int get_max_sh_degree() const { return _max_sh_degree; }
9292
float get_scene_scale() const { return _scene_scale; }
93+
void set_scene_scale(float scene_scale) { _scene_scale = scene_scale; }
9394
unsigned long size() const { return static_cast<unsigned long>(_means.shape()[0]); }
9495

9596
// ========== Raw tensor access (for optimization) ==========

src/training/training_setup.cpp

Lines changed: 46 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include <format>
1919
#include <memory>
2020
#include <numeric>
21+
#include <optional>
2122
#include <random>
2223
#include <variant>
2324

@@ -91,6 +92,45 @@ namespace lfs::training {
9192
}
9293
}
9394

95+
std::optional<float> computeSceneScaleFromPositions(
96+
const lfs::core::Tensor& positions,
97+
const lfs::core::Tensor& scene_center) {
98+
if (!positions.is_valid() || positions.ndim() != 2 ||
99+
positions.size(0) == 0 || positions.size(1) < 3 ||
100+
!scene_center.is_valid() || scene_center.numel() < 3) {
101+
return std::nullopt;
102+
}
103+
104+
const auto center = scene_center.to(positions.device());
105+
const auto dists = positions.sub(center).norm(2.0f, {1}, false);
106+
if (!dists.is_valid() || dists.size(0) == 0) {
107+
return std::nullopt;
108+
}
109+
110+
const auto sorted_dists = dists.sort(0, false);
111+
return sorted_dists.first[dists.size(0) / 2].item();
112+
}
113+
114+
void recomputeInitSplatSceneScale(
115+
lfs::core::SplatData& model,
116+
const lfs::core::Tensor& scene_center,
117+
const std::filesystem::path& init_file) {
118+
const auto scene_scale = computeSceneScaleFromPositions(model.means_raw(), scene_center);
119+
if (!scene_scale) {
120+
LOG_WARN("Could not compute scene scale for init splat {}; keeping {}",
121+
lfs::core::path_to_utf8(init_file.filename()),
122+
model.get_scene_scale());
123+
return;
124+
}
125+
126+
const float previous_scale = model.get_scene_scale();
127+
model.set_scene_scale(*scene_scale);
128+
LOG_INFO("Computed init scene scale from {}: {} -> {}",
129+
lfs::core::path_to_utf8(init_file.filename()),
130+
previous_scale,
131+
*scene_scale);
132+
}
133+
94134
std::expected<std::unique_ptr<lfs::core::SplatData>, std::string> loadAddedSplat(
95135
const std::filesystem::path& path,
96136
const int target_degree) {
@@ -414,17 +454,18 @@ namespace lfs::training {
414454
scene.setTrainingModelNode("Model");
415455
} else {
416456
auto loader = lfs::io::Loader::create();
417-
auto load_result = loader->load(init_file);
457+
auto init_result = loader->load(init_file);
418458

419-
if (!load_result) {
459+
if (!init_result) {
420460
return std::unexpected(std::format("Failed to load '{}': {}",
421-
lfs::core::path_to_utf8(init_file), load_result.error().format()));
461+
lfs::core::path_to_utf8(init_file), init_result.error().format()));
422462
}
423463

424464
try {
425-
auto splat_data = std::move(*std::get<std::shared_ptr<lfs::core::SplatData>>(load_result->data));
465+
auto splat_data = std::move(*std::get<std::shared_ptr<lfs::core::SplatData>>(init_result->data));
426466
auto model = std::make_unique<lfs::core::SplatData>(std::move(splat_data));
427467

468+
recomputeInitSplatSceneScale(*model, load_result->scene_center, init_file);
428469
applyTrainingSHDegree(*model, params.optimization.sh_degree);
429470

430471
LOG_INFO("Loaded {} Gaussians from {} (sh={})",
@@ -768,6 +809,7 @@ namespace lfs::training {
768809
auto splat_data = std::move(*std::get<std::shared_ptr<lfs::core::SplatData>>(init_result->data));
769810
auto model = std::make_unique<lfs::core::SplatData>(std::move(splat_data));
770811

812+
recomputeInitSplatSceneScale(*model, load_result.scene_center, init_file);
771813
applyTrainingSHDegree(*model, params.optimization.sh_degree);
772814

773815
LOG_INFO("Loaded {} gaussians from {} (sh={})",

0 commit comments

Comments
 (0)