@@ -84,32 +84,22 @@ InputData inputDataFromOpenSfM(const std::string &projectRoot){
8484
8585 torch::Tensor unorientedPoses = torch::zeros ({static_cast <long int >(shots.size ()), 4 , 4 }, torch::kFloat32 );
8686 size_t i = 0 ;
87- for (const auto &s : shots) {
87+ for (const auto &s : shots){
8888 Shot shot = s.second ;
8989
9090 torch::Tensor rotation = rodriguesToRotation (torch::from_blob (shot.rotation .data (), {static_cast <long >(shot.rotation .size ())}, torch::kFloat32 ));
9191 torch::Tensor translation = torch::from_blob (shot.translation .data (), {static_cast <long >(shot.translation .size ())}, torch::kFloat32 );
92-
9392 torch::Tensor w2c = torch::eye (4 , torch::kFloat32 );
9493 w2c.index_put_ ({Slice (None, 3 ), Slice (None, 3 )}, rotation);
95- w2c.index_put_ ({Slice (None, 3 ), Slice (3 , 4 )}, translation.reshape ({3 , 1 }));
96-
97- // Manually compute the inverse of w2c
98- torch::Tensor rotationT = rotation.transpose (0 , 1 ); // Transpose rotation (3x3)
99- torch::Tensor translationInv = -(rotationT.matmul (translation.reshape ({3 , 1 }))); // -R^T * t
94+ w2c.index_put_ ({Slice (None, 3 ), Slice (3 ,4 )}, translation.reshape ({3 , 1 }));
10095
101- torch::Tensor invW2C = torch::eye (4 , torch::kFloat32 );
102- invW2C.index_put_ ({Slice (None, 3 ), Slice (None, 3 )}, rotationT); // Set rotation part
103- invW2C.index_put_ ({Slice (None, 3 ), Slice (3 , 4 )}, translationInv); // Set translation part
104-
105- unorientedPoses[i] = invW2C;
96+ unorientedPoses[i] = torch::linalg_inv (w2c);
10697
10798 // Convert OpenSfM's camera CRS (OpenCV) to OpenGL
108- unorientedPoses[i].index_put_ ({Slice (0 , 3 ), Slice (1 , 3 )}, unorientedPoses[i].index ({Slice (0 , 3 ), Slice (1 , 3 )}) * -1 .0f );
99+ unorientedPoses[i].index_put_ ({Slice (0 , 3 ), Slice (1 ,3 )}, unorientedPoses[i].index ({Slice (0 , 3 ), Slice (1 ,3 )}) * -1 .0f );
109100 i++;
110101 }
111102
112-
113103 auto r = autoScaleAndCenterPoses (unorientedPoses);
114104 torch::Tensor poses = std::get<0 >(r);
115105 ret.translation = std::get<1 >(r);
0 commit comments