Skip to content

Commit 68bb02e

Browse files
committed
downsample init pc only on import
1 parent 6703442 commit 68bb02e

2 files changed

Lines changed: 76 additions & 7 deletions

File tree

src/training/training_setup.cpp

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -437,13 +437,6 @@ namespace lfs::training {
437437
return std::unexpected(std::format("Init failed: {}", splat_result.error()));
438438
}
439439

440-
const int max_cap = params.optimization.max_cap;
441-
if (max_cap > 0 && max_cap < static_cast<int>(splat_result->size())) {
442-
LOG_WARN("Max cap ({}) is less than initial splat count ({}), randomly selecting {} splats",
443-
max_cap, splat_result->size(), max_cap);
444-
lfs::core::random_choose(*splat_result, max_cap);
445-
}
446-
447440
auto model = std::make_unique<lfs::core::SplatData>(std::move(*splat_result));
448441
LOG_INFO("Init {} gaussians from {} (sh={})",
449442
model->size(), lfs::core::path_to_utf8(init_file.filename()), model->get_max_sh_degree());

tests/test_checkpoint_resume.cpp

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,16 @@
44
#include <algorithm>
55
#include <cctype>
66
#include <filesystem>
7+
#include <fstream>
78
#include <gtest/gtest.h>
89

10+
#include "core/camera.hpp"
911
#include "core/logger.hpp"
1012
#include "core/parameters.hpp"
13+
#include "core/path_utils.hpp"
1114
#include "core/scene.hpp"
15+
#include "core/tensor.hpp"
16+
#include "io/loader.hpp"
1217
#include "training/checkpoint.hpp"
1318
#include "training/trainer.hpp"
1419
#include "training/training_setup.hpp"
@@ -19,6 +24,77 @@ namespace {
1924
constexpr int CHECKPOINT_ITER = 1200;
2025
constexpr int TOTAL_ITER = 2100;
2126

27+
TEST(TrainingSetupRegressionTest, ApplyLoadedDatasetKeepsFullInitPointCloudUntilTrainingStarts) {
28+
constexpr size_t initial_points = 12;
29+
constexpr int target_splats = 5;
30+
31+
const auto temp_dir = std::filesystem::temp_directory_path() / "lfs_training_setup_full_init_regression";
32+
std::error_code ec;
33+
std::filesystem::remove_all(temp_dir, ec);
34+
std::filesystem::create_directories(temp_dir);
35+
36+
const auto init_path = temp_dir / "init_points.ply";
37+
{
38+
std::ofstream ply(init_path);
39+
ASSERT_TRUE(ply.is_open());
40+
ply << "ply\n"
41+
"format ascii 1.0\n"
42+
"element vertex "
43+
<< initial_points
44+
<< "\n"
45+
"property float x\n"
46+
"property float y\n"
47+
"property float z\n"
48+
"property uchar red\n"
49+
"property uchar green\n"
50+
"property uchar blue\n"
51+
"end_header\n";
52+
for (size_t i = 0; i < initial_points; ++i) {
53+
ply << static_cast<float>(i) << ' '
54+
<< static_cast<float>(i % 3) << ' '
55+
<< static_cast<float>(i % 5) << ' '
56+
<< static_cast<int>(10 + i) << ' '
57+
<< static_cast<int>(20 + i) << ' '
58+
<< static_cast<int>(30 + i) << '\n';
59+
}
60+
}
61+
62+
lfs::core::param::TrainingParameters params;
63+
params.dataset.data_path = temp_dir / "dataset";
64+
params.init_path = lfs::core::path_to_utf8(init_path);
65+
params.optimization.max_cap = target_splats;
66+
67+
lfs::io::LoadedScene loaded_scene;
68+
loaded_scene.cameras.push_back(std::make_shared<lfs::core::Camera>());
69+
70+
lfs::io::LoadResult load_result;
71+
load_result.data = std::move(loaded_scene);
72+
load_result.scene_center = lfs::core::Tensor::from_vector(
73+
std::vector<float>{0.0f, 0.0f, 0.0f},
74+
{size_t{3}},
75+
lfs::core::Device::CPU);
76+
load_result.loader_used = "test";
77+
78+
lfs::core::Scene scene;
79+
auto apply_result = lfs::training::applyLoadResultToScene(params, scene, std::move(load_result));
80+
ASSERT_TRUE(apply_result.has_value()) << apply_result.error();
81+
82+
const auto* model = scene.getTrainingModel();
83+
ASSERT_NE(model, nullptr);
84+
EXPECT_EQ(static_cast<size_t>(model->size()), initial_points);
85+
EXPECT_EQ(scene.getTrainingModelGaussianCount(), initial_points);
86+
87+
auto trainer = std::make_unique<lfs::training::Trainer>(scene);
88+
auto init_result = trainer->initialize(params);
89+
ASSERT_TRUE(init_result.has_value()) << init_result.error();
90+
91+
EXPECT_EQ(static_cast<size_t>(trainer->get_strategy().get_model().size()), static_cast<size_t>(target_splats));
92+
EXPECT_EQ(scene.getTrainingModelGaussianCount(), static_cast<size_t>(target_splats));
93+
94+
trainer->shutdown();
95+
std::filesystem::remove_all(temp_dir, ec);
96+
}
97+
2298
class CheckpointResumeTest : public ::testing::TestWithParam<std::tuple<std::string, int>> {
2399
protected:
24100
void SetUp() override {

0 commit comments

Comments
 (0)