44#include < algorithm>
55#include < cctype>
66#include < filesystem>
7+ #include < fstream>
78#include < gtest/gtest.h>
89
10+ #include " core/camera.hpp"
911#include " core/logger.hpp"
1012#include " core/parameters.hpp"
13+ #include " core/path_utils.hpp"
1114#include " core/scene.hpp"
15+ #include " core/tensor.hpp"
16+ #include " io/loader.hpp"
1217#include " training/checkpoint.hpp"
1318#include " training/trainer.hpp"
1419#include " training/training_setup.hpp"
@@ -19,6 +24,77 @@ namespace {
1924 constexpr int CHECKPOINT_ITER = 1200 ;
2025 constexpr int TOTAL_ITER = 2100 ;
2126
27+ TEST (TrainingSetupRegressionTest, ApplyLoadedDatasetKeepsFullInitPointCloudUntilTrainingStarts) {
28+ constexpr size_t initial_points = 12 ;
29+ constexpr int target_splats = 5 ;
30+
31+ const auto temp_dir = std::filesystem::temp_directory_path () / " lfs_training_setup_full_init_regression" ;
32+ std::error_code ec;
33+ std::filesystem::remove_all (temp_dir, ec);
34+ std::filesystem::create_directories (temp_dir);
35+
36+ const auto init_path = temp_dir / " init_points.ply" ;
37+ {
38+ std::ofstream ply (init_path);
39+ ASSERT_TRUE (ply.is_open ());
40+ ply << " ply\n "
41+ " format ascii 1.0\n "
42+ " element vertex "
43+ << initial_points
44+ << " \n "
45+ " property float x\n "
46+ " property float y\n "
47+ " property float z\n "
48+ " property uchar red\n "
49+ " property uchar green\n "
50+ " property uchar blue\n "
51+ " end_header\n " ;
52+ for (size_t i = 0 ; i < initial_points; ++i) {
53+ ply << static_cast <float >(i) << ' '
54+ << static_cast <float >(i % 3 ) << ' '
55+ << static_cast <float >(i % 5 ) << ' '
56+ << static_cast <int >(10 + i) << ' '
57+ << static_cast <int >(20 + i) << ' '
58+ << static_cast <int >(30 + i) << ' \n ' ;
59+ }
60+ }
61+
62+ lfs::core::param::TrainingParameters params;
63+ params.dataset .data_path = temp_dir / " dataset" ;
64+ params.init_path = lfs::core::path_to_utf8 (init_path);
65+ params.optimization .max_cap = target_splats;
66+
67+ lfs::io::LoadedScene loaded_scene;
68+ loaded_scene.cameras .push_back (std::make_shared<lfs::core::Camera>());
69+
70+ lfs::io::LoadResult load_result;
71+ load_result.data = std::move (loaded_scene);
72+ load_result.scene_center = lfs::core::Tensor::from_vector (
73+ std::vector<float >{0 .0f , 0 .0f , 0 .0f },
74+ {size_t {3 }},
75+ lfs::core::Device::CPU);
76+ load_result.loader_used = " test" ;
77+
78+ lfs::core::Scene scene;
79+ auto apply_result = lfs::training::applyLoadResultToScene (params, scene, std::move (load_result));
80+ ASSERT_TRUE (apply_result.has_value ()) << apply_result.error ();
81+
82+ const auto * model = scene.getTrainingModel ();
83+ ASSERT_NE (model, nullptr );
84+ EXPECT_EQ (static_cast <size_t >(model->size ()), initial_points);
85+ EXPECT_EQ (scene.getTrainingModelGaussianCount (), initial_points);
86+
87+ auto trainer = std::make_unique<lfs::training::Trainer>(scene);
88+ auto init_result = trainer->initialize (params);
89+ ASSERT_TRUE (init_result.has_value ()) << init_result.error ();
90+
91+ EXPECT_EQ (static_cast <size_t >(trainer->get_strategy ().get_model ().size ()), static_cast <size_t >(target_splats));
92+ EXPECT_EQ (scene.getTrainingModelGaussianCount (), static_cast <size_t >(target_splats));
93+
94+ trainer->shutdown ();
95+ std::filesystem::remove_all (temp_dir, ec);
96+ }
97+
2298 class CheckpointResumeTest : public ::testing::TestWithParam<std::tuple<std::string, int >> {
2399 protected:
24100 void SetUp () override {
0 commit comments