
Commit 1a1e0cc

MrNeRF and maturk authored
Implement gaussian splatting ply file saving (#427)
* implement ply saving
* fix colors
* Add own flag for saving ply files
* fix appearance embeddings
* remove open3d
* align order with INRIA ply
* filter NaN and Infs
* add flag to save ply and move save_ply to utils

Co-authored-by: maturk <[email protected]>
1 parent 2df0a95 commit 1a1e0cc

File tree

2 files changed: +119 -0 lines changed


examples/simple_trainer.py

Lines changed: 26 additions & 0 deletions
@@ -41,6 +41,7 @@
 from gsplat.rendering import rasterization
 from gsplat.strategy import DefaultStrategy, MCMCStrategy
 from gsplat.optimizers import SelectiveAdam
+from gsplat.utils import save_ply


 @dataclass
@@ -85,6 +86,10 @@ class Config:
     eval_steps: List[int] = field(default_factory=lambda: [7_000, 30_000])
     # Steps to save the model
     save_steps: List[int] = field(default_factory=lambda: [7_000, 30_000])
+    # Whether to save ply file (storage size can be large)
+    save_ply: bool = False
+    # Steps to save the model as ply
+    ply_steps: List[int] = field(default_factory=lambda: [7_000, 30_000])

     # Initialization strategy
     init_type: str = "sfm"
@@ -167,6 +172,7 @@ class Config:
     def adjust_steps(self, factor: float):
         self.eval_steps = [int(i * factor) for i in self.eval_steps]
         self.save_steps = [int(i * factor) for i in self.save_steps]
+        self.ply_steps = [int(i * factor) for i in self.ply_steps]
         self.max_steps = int(self.max_steps * factor)
         self.sh_degree_interval = int(self.sh_degree_interval * factor)

@@ -294,6 +300,8 @@ def __init__(
         os.makedirs(self.stats_dir, exist_ok=True)
         self.render_dir = f"{cfg.result_dir}/renders"
         os.makedirs(self.render_dir, exist_ok=True)
+        self.ply_dir = f"{cfg.result_dir}/ply"
+        os.makedirs(self.ply_dir, exist_ok=True)

         # Tensorboard
         self.writer = SummaryWriter(log_dir=f"{cfg.result_dir}/tb")
@@ -735,6 +743,24 @@ def train(self):
                 torch.save(
                     data, f"{self.ckpt_dir}/ckpt_{step}_rank{self.world_rank}.pt"
                 )
+            if (
+                step in [i - 1 for i in cfg.ply_steps]
+                or step == max_steps - 1
+                and cfg.save_ply
+            ):
+                rgb = None
+                if self.cfg.app_opt:
+                    # eval at origin to bake the appeareance into the colors
+                    rgb = self.app_module(
+                        features=self.splats["features"],
+                        embed_ids=None,
+                        dirs=torch.zeros_like(self.splats["means"][None, :, :]),
+                        sh_degree=sh_degree_to_use,
+                    )
+                    rgb = rgb + self.splats["colors"]
+                    rgb = torch.sigmoid(rgb).squeeze(0)
+
+                save_ply(self.splats, f"{self.ply_dir}/point_cloud_{step}.ply", rgb)

             # Turn Gradients into Sparse Tensor before running optimizer
             if cfg.sparse_grad:
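For reference, the same export can also be run outside the training loop from a saved checkpoint. The sketch below is illustrative only, not part of this commit: the paths are placeholders, and it assumes a run without the appearance module (app_opt off), so that the checkpoint stores the Gaussian parameters under a "splats" entry containing means, scales, quats, opacities, sh0 and shN, as the surrounding train() code suggests.

# Hypothetical post-hoc export from a checkpoint written by simple_trainer.py.
# Assumptions: no app_opt, checkpoint key "splats", placeholder paths.
import torch
from gsplat.utils import save_ply

ckpt = torch.load("results/garden/ckpts/ckpt_29999_rank0.pt", map_location="cpu")
# Rewrap the plain tensors as a ParameterDict to match save_ply's signature.
splats = torch.nn.ParameterDict(
    {k: torch.nn.Parameter(v) for k, v in ckpt["splats"].items()}
)
save_ply(splats, "results/garden/ply/point_cloud_29999.ply")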

gsplat/utils.py

Lines changed: 93 additions & 0 deletions
@@ -1,8 +1,101 @@
 import math
+import struct

 import torch
 import torch.nn.functional as F
 from torch import Tensor
+import numpy as np
+
+
+def save_ply(splats: torch.nn.ParameterDict, dir: str, colors: torch.Tensor = None):
+    # Convert all tensors to numpy arrays in one go
+    print(f"Saving ply to {dir}")
+    numpy_data = {k: v.detach().cpu().numpy() for k, v in splats.items()}
+
+    means = numpy_data["means"]
+    scales = numpy_data["scales"]
+    quats = numpy_data["quats"]
+    opacities = numpy_data["opacities"]
+
+    sh0 = numpy_data["sh0"].transpose(0, 2, 1).reshape(means.shape[0], -1)
+    shN = numpy_data["shN"].transpose(0, 2, 1).reshape(means.shape[0], -1)
+
+    # Create a mask to identify rows with NaN or Inf in any of the numpy_data arrays
+    invalid_mask = (
+        np.isnan(means).any(axis=1)
+        | np.isinf(means).any(axis=1)
+        | np.isnan(scales).any(axis=1)
+        | np.isinf(scales).any(axis=1)
+        | np.isnan(quats).any(axis=1)
+        | np.isinf(quats).any(axis=1)
+        | np.isnan(opacities).any(axis=0)
+        | np.isinf(opacities).any(axis=0)
+        | np.isnan(sh0).any(axis=1)
+        | np.isinf(sh0).any(axis=1)
+        | np.isnan(shN).any(axis=1)
+        | np.isinf(shN).any(axis=1)
+    )
+
+    # Filter out rows with NaNs or Infs from all data arrays
+    means = means[~invalid_mask]
+    scales = scales[~invalid_mask]
+    quats = quats[~invalid_mask]
+    opacities = opacities[~invalid_mask]
+    sh0 = sh0[~invalid_mask]
+    shN = shN[~invalid_mask]
+
+    num_points = means.shape[0]
+
+    with open(dir, "wb") as f:
+        # Write PLY header
+        f.write(b"ply\n")
+        f.write(b"format binary_little_endian 1.0\n")
+        f.write(f"element vertex {num_points}\n".encode())
+        f.write(b"property float x\n")
+        f.write(b"property float y\n")
+        f.write(b"property float z\n")
+        f.write(b"property float nx\n")
+        f.write(b"property float ny\n")
+        f.write(b"property float nz\n")
+
+        if colors is not None:
+            for j in range(colors.shape[1]):
+                f.write(f"property float f_dc_{j}\n".encode())
+        else:
+            for i, data in enumerate([sh0, shN]):
+                prefix = "f_dc" if i == 0 else "f_rest"
+                for j in range(data.shape[1]):
+                    f.write(f"property float {prefix}_{j}\n".encode())
+
+        f.write(b"property float opacity\n")
+
+        for i in range(scales.shape[1]):
+            f.write(f"property float scale_{i}\n".encode())
+        for i in range(quats.shape[1]):
+            f.write(f"property float rot_{i}\n".encode())
+
+        f.write(b"end_header\n")
+
+        # Write vertex data
+        for i in range(num_points):
+            f.write(struct.pack("<fff", *means[i]))  # x, y, z
+            f.write(struct.pack("<fff", 0, 0, 0))  # nx, ny, nz (zeros)
+
+            if colors is not None:
+                color = colors.detach().cpu().numpy()
+                for j in range(color.shape[1]):
+                    f_dc = (color[i, j] - 0.5) / 0.2820947917738781
+                    f.write(struct.pack("<f", f_dc))
+            else:
+                for data in [sh0, shN]:
+                    for j in range(data.shape[1]):
+                        f.write(struct.pack("<f", data[i, j]))
+
+            f.write(struct.pack("<f", opacities[i]))  # opacity
+
+            for data in [scales, quats]:
+                for j in range(data.shape[1]):
+                    f.write(struct.pack("<f", data[i, j]))


 def normalized_quat_to_rotmat(quat: Tensor) -> Tensor:
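An exported file can be sanity-checked by reading it back with a generic PLY reader. The snippet below is a minimal sketch, not part of gsplat: it uses the third-party plyfile package, assumes the default layout written above (colors=None, so f_dc_* and f_rest_* hold raw SH coefficients), and uses C0 = 1/(2*sqrt(pi)), the same band-0 spherical-harmonics constant 0.2820947917738781 that appears in save_ply.

# Minimal read-back sketch (not part of gsplat); assumes plyfile is installed
# and the file was written by save_ply with colors=None. Path is a placeholder.
import numpy as np
from plyfile import PlyData

ply = PlyData.read("point_cloud_29999.ply")
v = ply["vertex"].data  # numpy structured array, one row per Gaussian

means = np.stack([v["x"], v["y"], v["z"]], axis=-1)

# f_dc_0..2 are band-0 SH coefficients; the usual DC-to-RGB conversion uses
# C0 = 1 / (2 * sqrt(pi)), the inverse of the (color - 0.5) / C0 mapping above.
C0 = 0.28209479177387814
f_dc = np.stack([v["f_dc_0"], v["f_dc_1"], v["f_dc_2"]], axis=-1)
base_rgb = np.clip(0.5 + C0 * f_dc, 0.0, 1.0)

# Assumption: the trainer stores opacities pre-sigmoid (as logits), so apply a
# sigmoid to map them into [0, 1] for inspection.
opacity = 1.0 / (1.0 + np.exp(-v["opacity"].astype(np.float64)))

print(means.shape, base_rgb.min(), base_rgb.max(), opacity.mean())

Per the commit message the property order is aligned with the INRIA ply layout, so standard 3DGS viewers should be able to open the exported files directly.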
