Skip to content

Commit 521121d

Browse files
authored
Merge pull request #56 from pierotofy/orient
Preserve scale/orientation of scene input
2 parents 1b1db43 + c584323 commit 521121d

File tree

10 files changed

+84
-43
lines changed

10 files changed

+84
-43
lines changed

colmap.cpp

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -127,12 +127,11 @@ InputData inputDataFromColmap(const std::string &projectRoot){
127127
}
128128

129129
imgf.close();
130-
131-
auto r = autoOrientAndCenterPoses(unorientedPoses);
130+
131+
auto r = autoScaleAndCenterPoses(unorientedPoses);
132132
torch::Tensor poses = std::get<0>(r);
133-
ret.transformMatrix = std::get<1>(r);
134-
ret.scaleFactor = 1.0f / torch::max(torch::abs(poses.index({Slice(), Slice(None, 3), 3}))).item<float>();
135-
poses.index({Slice(), Slice(None, 3), 3}) *= ret.scaleFactor;
133+
ret.translation = std::get<1>(r);
134+
ret.scale = std::get<2>(r);
136135

137136
for (size_t i = 0; i < ret.cameras.size(); i++){
138137
ret.cameras[i].camToWorld = poses[i];
@@ -141,9 +140,7 @@ InputData inputDataFromColmap(const std::string &projectRoot){
141140
PointSet *pSet = readPointSet(pointsPath.string());
142141
torch::Tensor points = pSet->pointsTensor().clone();
143142

144-
ret.points.xyz = torch::matmul(torch::cat({points, torch::ones_like(points.index({"...", Slice(None, 1)}))}, -1),
145-
ret.transformMatrix.transpose(0, 1));
146-
ret.points.xyz *= ret.scaleFactor;
143+
ret.points.xyz = (points - ret.translation) * ret.scale;
147144
ret.points.rgb = pSet->colorsTensor().clone();
148145

149146
RELEASE_POINTSET(pSet);

input_data.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,8 @@ struct Points{
5252
};
5353
struct InputData{
5454
std::vector<Camera> cameras;
55-
float scaleFactor;
56-
torch::Tensor transformMatrix;
55+
float scale;
56+
torch::Tensor translation;
5757
Points points;
5858

5959
std::tuple<std::vector<Camera>, Camera *> getCameras(bool validate, const std::string &valImage = "random");

model.cpp

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -498,10 +498,10 @@ void Model::savePlySplat(const std::string &filename){
498498

499499
float zeros[] = { 0.0f, 0.0f, 0.0f };
500500

501-
torch::Tensor meansCpu = means.cpu();
501+
torch::Tensor meansCpu = (means.cpu() / scale) + translation;
502502
torch::Tensor featuresDcCpu = featuresDc.cpu();
503503
torch::Tensor opacitiesCpu = opacities.cpu();
504-
torch::Tensor scalesCpu = scales.cpu();
504+
torch::Tensor scalesCpu = torch::log((torch::exp(scales.cpu()) / scale));
505505
torch::Tensor quatsCpu = quats.cpu();
506506

507507
for (size_t i = 0; i < numPoints; i++) {
@@ -518,6 +518,35 @@ void Model::savePlySplat(const std::string &filename){
518518
std::cout << "Wrote " << filename << std::endl;
519519
}
520520

521+
void Model::saveDebugPly(const std::string &filename){
522+
// A standard PLY
523+
std::ofstream o(filename, std::ios::binary);
524+
int numPoints = means.size(0);
525+
526+
o << "ply" << std::endl;
527+
o << "format binary_little_endian 1.0" << std::endl;
528+
o << "comment Generated by opensplat" << std::endl;
529+
o << "element vertex " << numPoints << std::endl;
530+
o << "property float x" << std::endl;
531+
o << "property float y" << std::endl;
532+
o << "property float z" << std::endl;
533+
o << "property uchar red" << std::endl;
534+
o << "property uchar green" << std::endl;
535+
o << "property uchar blue" << std::endl;
536+
o << "end_header" << std::endl;
537+
538+
torch::Tensor meansCpu = (means.cpu() / scale) + translation;
539+
torch::Tensor rgbsCpu = (sh2rgb(featuresDc.cpu()) * 255.0f).toType(torch::kUInt8);
540+
541+
for (size_t i = 0; i < numPoints; i++) {
542+
o.write(reinterpret_cast<const char *>(meansCpu[i].data_ptr()), sizeof(float) * 3);
543+
o.write(reinterpret_cast<const char *>(rgbsCpu[i].data_ptr()), sizeof(uint8_t) * 3);
544+
}
545+
546+
o.close();
547+
std::cout << "Wrote " << filename << std::endl;
548+
}
549+
521550
torch::Tensor Model::mainLoss(torch::Tensor &rgb, torch::Tensor &gt, float ssimWeight){
522551
torch::Tensor ssimLoss = 1.0f - ssim.eval(rgb, gt);
523552
torch::Tensor l1Loss = l1(rgb, gt);

model.hpp

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ torch::Tensor psnr(const torch::Tensor& rendered, const torch::Tensor& gt);
2020
torch::Tensor l1(const torch::Tensor& rendered, const torch::Tensor& gt);
2121

2222
struct Model{
23-
Model(const Points &points, int numCameras,
23+
Model(const InputData &inputData, int numCameras,
2424
int numDownscales, int resolutionSchedule, int shDegree, int shDegreeInterval,
2525
int refineEvery, int warmupLength, int resetAlphaEvery, int stopSplitAt, float densifyGradThresh, float densifySizeThresh, int stopScreenSizeAt, float splitScreenSize,
2626
int maxSteps,
@@ -30,17 +30,21 @@ struct Model{
3030
refineEvery(refineEvery), warmupLength(warmupLength), resetAlphaEvery(resetAlphaEvery), stopSplitAt(stopSplitAt), densifyGradThresh(densifyGradThresh), densifySizeThresh(densifySizeThresh), stopScreenSizeAt(stopScreenSizeAt), splitScreenSize(splitScreenSize),
3131
maxSteps(maxSteps),
3232
device(device), ssim(11, 3){
33-
long long numPoints = points.xyz.size(0);
33+
34+
long long numPoints = inputData.points.xyz.size(0);
35+
scale = inputData.scale;
36+
translation = inputData.translation;
37+
3438
torch::manual_seed(42);
3539

36-
means = points.xyz.to(device).requires_grad_();
37-
scales = PointsTensor(points.xyz).scales().repeat({1, 3}).log().to(device).requires_grad_();
40+
means = inputData.points.xyz.to(device).requires_grad_();
41+
scales = PointsTensor(inputData.points.xyz).scales().repeat({1, 3}).log().to(device).requires_grad_();
3842
quats = randomQuatTensor(numPoints).to(device).requires_grad_();
3943

4044
int dimSh = numShBases(shDegree);
4145
torch::Tensor shs = torch::zeros({numPoints, dimSh, 3}, torch::TensorOptions().dtype(torch::kFloat32).device(device));
4246

43-
shs.index({Slice(), 0, Slice(None, 3)}) = rgb2sh(points.rgb.toType(torch::kFloat64) / 255.0).toType(torch::kFloat32);
47+
shs.index({Slice(), 0, Slice(None, 3)}) = rgb2sh(inputData.points.rgb.toType(torch::kFloat64) / 255.0).toType(torch::kFloat32);
4448
shs.index({Slice(), Slice(1, None), Slice(3, None)}) = 0.0f;
4549

4650
featuresDc = shs.index({Slice(), 0, Slice()}).to(device).requires_grad_();
@@ -78,6 +82,7 @@ struct Model{
7882
int getDownscaleFactor(int step);
7983
void afterTrain(int step);
8084
void savePlySplat(const std::string &filename);
85+
void saveDebugPly(const std::string &filename);
8186
torch::Tensor mainLoss(torch::Tensor &rgb, torch::Tensor &gt, float ssimWeight);
8287

8388
void addToOptimizer(torch::optim::Adam *optimizer, const torch::Tensor &newParam, const torch::Tensor &idcs, int nSamples);
@@ -126,6 +131,9 @@ struct Model{
126131
int stopScreenSizeAt;
127132
float splitScreenSize;
128133
int maxSteps;
134+
135+
float scale;
136+
torch::Tensor translation;
129137
};
130138

131139

nerfstudio.cpp

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -136,12 +136,10 @@ InputData inputDataFromNerfStudio(const std::string &projectRoot){
136136

137137
torch::Tensor unorientedPoses = posesFromTransforms(t);
138138

139-
auto r = autoOrientAndCenterPoses(unorientedPoses);
139+
auto r = autoScaleAndCenterPoses(unorientedPoses);
140140
torch::Tensor poses = std::get<0>(r);
141-
ret.transformMatrix = std::get<1>(r);
142-
143-
ret.scaleFactor = 1.0f / torch::max(torch::abs(poses.index({Slice(), Slice(None, 3), 3}))).item<float>();
144-
poses.index({Slice(), Slice(None, 3), 3}) *= ret.scaleFactor;
141+
ret.translation = std::get<1>(r);
142+
ret.scale = std::get<2>(r);
145143

146144
// aabbScale = [[-1.0, -1.0, -1.0], [1.0, 1.0, 1.0]]
147145

@@ -158,10 +156,8 @@ InputData inputDataFromNerfStudio(const std::string &projectRoot){
158156
}
159157

160158
torch::Tensor points = pSet->pointsTensor().clone();
161-
162-
ret.points.xyz = torch::matmul(torch::cat({points, torch::ones_like(points.index({"...", Slice(None, 1)}))}, -1),
163-
ret.transformMatrix.transpose(0, 1));
164-
ret.points.xyz *= ret.scaleFactor;
159+
160+
ret.points.xyz = (points - ret.translation) * ret.scale;
165161
ret.points.rgb = pSet->colorsTensor().clone();
166162

167163
RELEASE_POINTSET(pSet);

opensplat.cpp

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -101,12 +101,12 @@ int main(int argc, char *argv[]){
101101
std::vector<Camera> cams = std::get<0>(t);
102102
Camera *valCam = std::get<1>(t);
103103

104-
Model model(inputData.points,
105-
cams.size(),
106-
numDownscales, resolutionSchedule, shDegree, shDegreeInterval,
107-
refineEvery, warmupLength, resetAlphaEvery, stopSplitAt, densifyGradThresh, densifySizeThresh, stopScreenSizeAt, splitScreenSize,
108-
numIters,
109-
device);
104+
Model model(inputData,
105+
cams.size(),
106+
numDownscales, resolutionSchedule, shDegree, shDegreeInterval,
107+
refineEvery, warmupLength, resetAlphaEvery, stopSplitAt, densifyGradThresh, densifySizeThresh, stopScreenSizeAt, splitScreenSize,
108+
numIters,
109+
device);
110110

111111
std::vector< size_t > camIndices( cams.size() );
112112
std::iota( camIndices.begin(), camIndices.end(), 0 );
@@ -145,6 +145,7 @@ int main(int argc, char *argv[]){
145145
}
146146

147147
model.savePlySplat(outputScene);
148+
// model.saveDebugPly("debug.ply");
148149

149150
// Validate
150151
if (valCam != nullptr){

spherical_harmonics.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,18 @@ int degFromSh(int numBases){
1515
}
1616
}
1717

18+
const double C0 = 0.28209479177387814;
19+
1820
torch::Tensor rgb2sh(const torch::Tensor &rgb){
1921
// Converts from RGB values [0,1] to the 0th spherical harmonic coefficient
20-
const double C0 = 0.28209479177387814;
2122
return (rgb - 0.5) / C0;
2223
}
2324

25+
torch::Tensor sh2rgb(const torch::Tensor &sh){
26+
// Converts from 0th spherical harmonic coefficients to RGB values [0,1]
27+
return (sh * C0) + 0.5;
28+
}
29+
2430
#if defined(USE_HIP) || defined(USE_CUDA)
2531

2632
torch::Tensor SphericalHarmonics::forward(AutogradContext *ctx,

spherical_harmonics.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ using namespace torch::autograd;
88

99
int degFromSh(int numBases);
1010
torch::Tensor rgb2sh(const torch::Tensor &rgb);
11+
torch::Tensor sh2rgb(const torch::Tensor &sh);
1112

1213
#if defined(USE_HIP) || defined(USE_CUDA)
1314

tensor_math.cpp

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -28,17 +28,20 @@ torch::Tensor quatToRotMat(const torch::Tensor &quat){
2828

2929
}
3030

31-
std::tuple<torch::Tensor, torch::Tensor> autoOrientAndCenterPoses(const torch::Tensor &poses){
32-
// Center at mean and orient up
31+
std::tuple<torch::Tensor, torch::Tensor, float> autoScaleAndCenterPoses(const torch::Tensor &poses){
32+
// Center at mean
3333
torch::Tensor origins = poses.index({"...", Slice(None, 3), 3});
34-
torch::Tensor translation = torch::mean(origins, 0);
35-
torch::Tensor up = torch::mean(poses.index({Slice(), Slice(None, 3), 1}), 0);
36-
up = up / up.norm();
34+
torch::Tensor center = torch::mean(origins, 0);
35+
origins -= center;
36+
37+
// Scale
38+
float f = 1.0f / torch::max(torch::abs(origins)).item<float>();
39+
origins *= f;
3740

38-
torch::Tensor rotation = rotationMatrix(up, torch::tensor({0, 0, 1}, torch::kFloat32));
39-
torch::Tensor transform = torch::cat({rotation, torch::matmul(rotation, -translation.index({"...", None}))}, -1);
40-
torch::Tensor orientedPoses = torch::matmul(transform, poses);
41-
return std::make_tuple(orientedPoses, transform);
41+
torch::Tensor transformedPoses = poses.clone();
42+
transformedPoses.index_put_({"...", Slice(None, 3), 3}, origins);
43+
44+
return std::make_tuple(transformedPoses, center, f);
4245
}
4346

4447

tensor_math.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
#include <tuple>
66

77
torch::Tensor quatToRotMat(const torch::Tensor &quat);
8-
std::tuple<torch::Tensor, torch::Tensor> autoOrientAndCenterPoses(const torch::Tensor &poses);
8+
std::tuple<torch::Tensor, torch::Tensor, float> autoScaleAndCenterPoses(const torch::Tensor &poses);
99
torch::Tensor rotationMatrix(const torch::Tensor &a, const torch::Tensor &b);
1010

1111

0 commit comments

Comments
 (0)