@@ -20,7 +20,7 @@ torch::Tensor psnr(const torch::Tensor& rendered, const torch::Tensor& gt);
2020torch::Tensor l1 (const torch::Tensor& rendered, const torch::Tensor& gt);
2121
2222struct Model {
23- Model (const Points &points , int numCameras,
23+ Model (const InputData &inputData , int numCameras,
2424 int numDownscales, int resolutionSchedule, int shDegree, int shDegreeInterval,
2525 int refineEvery, int warmupLength, int resetAlphaEvery, int stopSplitAt, float densifyGradThresh, float densifySizeThresh, int stopScreenSizeAt, float splitScreenSize,
2626 int maxSteps,
@@ -30,17 +30,21 @@ struct Model{
3030 refineEvery (refineEvery), warmupLength(warmupLength), resetAlphaEvery(resetAlphaEvery), stopSplitAt(stopSplitAt), densifyGradThresh(densifyGradThresh), densifySizeThresh(densifySizeThresh), stopScreenSizeAt(stopScreenSizeAt), splitScreenSize(splitScreenSize),
3131 maxSteps (maxSteps),
3232 device (device), ssim(11 , 3 ){
33- long long numPoints = points.xyz .size (0 );
33+
34+ long long numPoints = inputData.points .xyz .size (0 );
35+ scale = inputData.scale ;
36+ translation = inputData.translation ;
37+
3438 torch::manual_seed (42 );
3539
36- means = points.xyz .to (device).requires_grad_ ();
37- scales = PointsTensor (points.xyz ).scales ().repeat ({1 , 3 }).log ().to (device).requires_grad_ ();
40+ means = inputData. points .xyz .to (device).requires_grad_ ();
41+ scales = PointsTensor (inputData. points .xyz ).scales ().repeat ({1 , 3 }).log ().to (device).requires_grad_ ();
3842 quats = randomQuatTensor (numPoints).to (device).requires_grad_ ();
3943
4044 int dimSh = numShBases (shDegree);
4145 torch::Tensor shs = torch::zeros ({numPoints, dimSh, 3 }, torch::TensorOptions ().dtype (torch::kFloat32 ).device (device));
4246
43- shs.index ({Slice (), 0 , Slice (None, 3 )}) = rgb2sh (points.rgb .toType (torch::kFloat64 ) / 255.0 ).toType (torch::kFloat32 );
47+ shs.index ({Slice (), 0 , Slice (None, 3 )}) = rgb2sh (inputData. points .rgb .toType (torch::kFloat64 ) / 255.0 ).toType (torch::kFloat32 );
4448 shs.index ({Slice (), Slice (1 , None), Slice (3 , None)}) = 0 .0f ;
4549
4650 featuresDc = shs.index ({Slice (), 0 , Slice ()}).to (device).requires_grad_ ();
@@ -78,6 +82,7 @@ struct Model{
7882 int getDownscaleFactor (int step);
7983 void afterTrain (int step);
8084 void savePlySplat (const std::string &filename);
85+ void saveDebugPly (const std::string &filename);
8186 torch::Tensor mainLoss (torch::Tensor &rgb, torch::Tensor >, float ssimWeight);
8287
8388 void addToOptimizer (torch::optim::Adam *optimizer, const torch::Tensor &newParam, const torch::Tensor &idcs, int nSamples);
@@ -126,6 +131,9 @@ struct Model{
126131 int stopScreenSizeAt;
127132 float splitScreenSize;
128133 int maxSteps;
134+
135+ float scale;
136+ torch::Tensor translation;
129137};
130138
131139
0 commit comments