diff --git a/fvdb/src/GaussianSplatting.cpp b/fvdb/src/GaussianSplatting.cpp index a8f32fa808..1311f590f6 100644 --- a/fvdb/src/GaussianSplatting.cpp +++ b/fvdb/src/GaussianSplatting.cpp @@ -513,7 +513,7 @@ GaussianSplat3d::savePly(const std::string &filename) const { mShN.index({ torch::indexing::Slice(), validMask.jdata(), torch::indexing::Ellipsis }) .cpu() .permute({ 1, 0, 2 }) - .reshape({ mMeans.size(0), -1 }); + .reshape({ meansCPU.size(0), -1 }); plyf.add_properties_to_element("vertex", { "x", "y", "z" }, Type::FLOAT32, meansCPU.size(0), detail::tensorBytePointer(meansCPU), Type::INVALID, 0); diff --git a/fvdb/tests/unit/test_gsplat.py b/fvdb/tests/unit/test_gsplat.py index d26bc2002d..99d9004ad8 100644 --- a/fvdb/tests/unit/test_gsplat.py +++ b/fvdb/tests/unit/test_gsplat.py @@ -75,6 +75,18 @@ def setUp(self): requires_grad=True, ) + nan_mean = means.clone() + nan_mean[0] = torch.tensor([float("nan"), float("nan"), float("nan")], device=self.device) + self.nan_gs3d = GaussianSplat3d( + means=nan_mean, + quats=quats, + log_scales=torch.log(scales), + logit_opacities=torch.logit(opacities), + sh0=sh_0, + shN=sh_n, + requires_grad=True, + ) + self.num_cameras = self.cam_to_world_mats.shape[0] self.near_plane = 0.01 self.far_plane = 1e10 @@ -152,6 +164,52 @@ def test_save_ply(self): shN_loaded = shN_loaded.view(self.gs3d.num_gaussians, 15, 3).permute(1, 0, 2) self.assertTrue(torch.allclose(shN_loaded, self.gs3d.shN)) + def _create_gs3d_without_first_gaussian(self, gs3d): + """Helper to create a new GS3D instance with the first gaussian removed.""" + return GaussianSplat3d( + means=gs3d.means[1:], + quats=gs3d.quats[1:], + log_scales=gs3d.log_scales[1:], + logit_opacities=gs3d.logit_opacities[1:], + sh0=gs3d.sh0[:, 1:, :], + shN=gs3d.shN[:, 1:, :], + requires_grad=True, + ) + + def test_save_ply_handles_nan(self): + tf = tempfile.NamedTemporaryFile(delete=True, suffix=".ply") + + self.nan_gs3d.save_ply(tf.name) + + # Remove the first element from all tensors to compare with expected loaded ply + gs3d_without_nan = self._create_gs3d_without_first_gaussian(self.nan_gs3d) + + loaded = pcu.load_triangle_mesh(tf.name) + attribs = loaded.vertex_data.custom_attributes + means_loaded = torch.from_numpy(loaded.vertex_data.positions).to(self.device) + self.assertTrue(torch.allclose(means_loaded, gs3d_without_nan.means)) + + scales_loaded = torch.from_numpy( + np.stack([attribs["scale_0"], attribs["scale_1"], attribs["scale_2"]], axis=-1) + ).to(self.device) + self.assertTrue(torch.allclose(scales_loaded, gs3d_without_nan.log_scales)) + + quats_loaded = torch.from_numpy( + np.stack([attribs["rot_0"], attribs["rot_1"], attribs["rot_2"], attribs["rot_3"]], axis=-1) + ).to(self.device) + self.assertTrue(torch.allclose(quats_loaded, gs3d_without_nan.quats)) + + opacities_loaded = torch.from_numpy(attribs["opacity"]).to(self.device) + self.assertTrue(torch.allclose(opacities_loaded, gs3d_without_nan.logit_opacities)) + + sh0_loaded = ( + torch.from_numpy(np.stack([attribs[f"f_dc_{i}"] for i in range(3)], axis=1)).to(self.device).unsqueeze(0) + ) + self.assertTrue(torch.allclose(sh0_loaded, gs3d_without_nan.sh0)) + shN_loaded = torch.from_numpy(np.stack([attribs[f"f_rest_{i}"] for i in range(45)], axis=1)).to(self.device) + shN_loaded = shN_loaded.view(gs3d_without_nan.num_gaussians, 15, 3).permute(1, 0, 2) + self.assertTrue(torch.allclose(shN_loaded, gs3d_without_nan.shN)) + def test_gaussian_render(self): render_colors, render_alphas = self.gs3d.render_images( self.cam_to_world_mats, self.projection_mats, self.width, self.height, self.near_plane, self.far_plane