|
18 | 18 | #include <format> |
19 | 19 | #include <memory> |
20 | 20 | #include <numeric> |
| 21 | +#include <optional> |
21 | 22 | #include <random> |
22 | 23 | #include <variant> |
23 | 24 |
|
@@ -91,6 +92,45 @@ namespace lfs::training { |
91 | 92 | } |
92 | 93 | } |
93 | 94 |
|
| 95 | + std::optional<float> computeSceneScaleFromPositions( |
| 96 | + const lfs::core::Tensor& positions, |
| 97 | + const lfs::core::Tensor& scene_center) { |
| 98 | + if (!positions.is_valid() || positions.ndim() != 2 || |
| 99 | + positions.size(0) == 0 || positions.size(1) < 3 || |
| 100 | + !scene_center.is_valid() || scene_center.numel() < 3) { |
| 101 | + return std::nullopt; |
| 102 | + } |
| 103 | + |
| 104 | + const auto center = scene_center.to(positions.device()); |
| 105 | + const auto dists = positions.sub(center).norm(2.0f, {1}, false); |
| 106 | + if (!dists.is_valid() || dists.size(0) == 0) { |
| 107 | + return std::nullopt; |
| 108 | + } |
| 109 | + |
| 110 | + const auto sorted_dists = dists.sort(0, false); |
| 111 | + return sorted_dists.first[dists.size(0) / 2].item(); |
| 112 | + } |
| 113 | + |
| 114 | + void recomputeInitSplatSceneScale( |
| 115 | + lfs::core::SplatData& model, |
| 116 | + const lfs::core::Tensor& scene_center, |
| 117 | + const std::filesystem::path& init_file) { |
| 118 | + const auto scene_scale = computeSceneScaleFromPositions(model.means_raw(), scene_center); |
| 119 | + if (!scene_scale) { |
| 120 | + LOG_WARN("Could not compute scene scale for init splat {}; keeping {}", |
| 121 | + lfs::core::path_to_utf8(init_file.filename()), |
| 122 | + model.get_scene_scale()); |
| 123 | + return; |
| 124 | + } |
| 125 | + |
| 126 | + const float previous_scale = model.get_scene_scale(); |
| 127 | + model.set_scene_scale(*scene_scale); |
| 128 | + LOG_INFO("Computed init scene scale from {}: {} -> {}", |
| 129 | + lfs::core::path_to_utf8(init_file.filename()), |
| 130 | + previous_scale, |
| 131 | + *scene_scale); |
| 132 | + } |
| 133 | + |
94 | 134 | std::expected<std::unique_ptr<lfs::core::SplatData>, std::string> loadAddedSplat( |
95 | 135 | const std::filesystem::path& path, |
96 | 136 | const int target_degree) { |
@@ -414,17 +454,18 @@ namespace lfs::training { |
414 | 454 | scene.setTrainingModelNode("Model"); |
415 | 455 | } else { |
416 | 456 | auto loader = lfs::io::Loader::create(); |
417 | | - auto load_result = loader->load(init_file); |
| 457 | + auto init_result = loader->load(init_file); |
418 | 458 |
|
419 | | - if (!load_result) { |
| 459 | + if (!init_result) { |
420 | 460 | return std::unexpected(std::format("Failed to load '{}': {}", |
421 | | - lfs::core::path_to_utf8(init_file), load_result.error().format())); |
| 461 | + lfs::core::path_to_utf8(init_file), init_result.error().format())); |
422 | 462 | } |
423 | 463 |
|
424 | 464 | try { |
425 | | - auto splat_data = std::move(*std::get<std::shared_ptr<lfs::core::SplatData>>(load_result->data)); |
| 465 | + auto splat_data = std::move(*std::get<std::shared_ptr<lfs::core::SplatData>>(init_result->data)); |
426 | 466 | auto model = std::make_unique<lfs::core::SplatData>(std::move(splat_data)); |
427 | 467 |
|
| 468 | + recomputeInitSplatSceneScale(*model, load_result->scene_center, init_file); |
428 | 469 | applyTrainingSHDegree(*model, params.optimization.sh_degree); |
429 | 470 |
|
430 | 471 | LOG_INFO("Loaded {} Gaussians from {} (sh={})", |
@@ -768,6 +809,7 @@ namespace lfs::training { |
768 | 809 | auto splat_data = std::move(*std::get<std::shared_ptr<lfs::core::SplatData>>(init_result->data)); |
769 | 810 | auto model = std::make_unique<lfs::core::SplatData>(std::move(splat_data)); |
770 | 811 |
|
| 812 | + recomputeInitSplatSceneScale(*model, load_result.scene_center, init_file); |
771 | 813 | applyTrainingSHDegree(*model, params.optimization.sh_degree); |
772 | 814 |
|
773 | 815 | LOG_INFO("Loaded {} gaussians from {} (sh={})", |
|
0 commit comments