Skip to content

Commit b8da3b8

Browse files
authored
Add USD importer / exporter for gaussians + subsets (#948)
* add gaussians importer/exporter Signed-off-by: Clement Fuji Tsang <cfujitsang@nvidia.com> * add new kaolin.ops.gaussians.rst Signed-off-by: Clement Fuji Tsang <cfujitsang@nvidia.com> * exclude kaolin/io/usd/gaussians.py Signed-off-by: Clement Fuji Tsang <cfujitsang@nvidia.com> * exclude kaolin/ops/gaussians/transforms.py Signed-off-by: Clement Fuji Tsang <cfujitsang@nvidia.com> --------- Signed-off-by: Clement Fuji Tsang <cfujitsang@nvidia.com>
1 parent aea756e commit b8da3b8

14 files changed

Lines changed: 2448 additions & 147 deletions

File tree

ci/gitlab_jenkins_templates/ubuntu_test_cuda_CI.jenkins

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,7 @@ spec:
179179
export KAOLIN_TEST_MODELNET_PATH=/mnt/data/ModelNet
180180
export KAOLIN_TEST_SHREC16_PATH=/mnt/data/ci_shrec16
181181
export KAOLIN_TEST_GSPLATS_DIR=/mnt/gsplats
182+
export KAOLIN_TEST_TOYS_DATASET_PATH=/mnt/data/toys/
182183
pytest --durations=50 --import-mode=importlib -rs --cov=/kaolin/kaolin \
183184
--log-disable=PIL.PngImagePlugin \
184185
--log-disable=PIL.TiffImagePlugin \

docs/kaolin_ext.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ def run_apidoc(_):
3737
"kaolin/io/usd/materials.py",
3838
"kaolin/io/usd/voxelgrid.py",
3939
"kaolin/io/usd/pointcloud.py",
40+
"kaolin/io/usd/gaussians.py",
4041
"kaolin/math/quat/angle_axis.py",
4142
"kaolin/math/quat/euclidean.py",
4243
"kaolin/math/quat/matrix44.py",
@@ -56,6 +57,7 @@ def run_apidoc(_):
5657
"kaolin/ops/mesh/tetmesh.py",
5758
"kaolin/ops/mesh/trianglemesh.py",
5859
"kaolin/ops/gaussian/densifier.py",
60+
"kaolin/ops/gaussians/transforms.py",
5961
"kaolin/ops/spc/spc.py",
6062
"kaolin/ops/spc/convolution.py",
6163
"kaolin/ops/spc/points.py",

docs/modules/kaolin.ops.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ Tensor batching operators are in :ref:`kaolin.ops.batch`, conversions of 3D mode
1818
kaolin.ops.gcn
1919
kaolin.ops.mesh
2020
kaolin.ops.gaussian
21+
kaolin.ops.gaussians
2122
kaolin.ops.random
2223
kaolin.ops.reduction
2324
kaolin.ops.spc

examples/tutorial/physics/gaussian_utils.py

Lines changed: 79 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -1,63 +1,12 @@
11
import os
2+
from collections.abc import Sequence
3+
import math
24
import torch
35
import kaolin
46

57

68
PHYS_NOTEBOOKS_DIR = os.path.abspath(os.path.dirname(os.path.abspath(__file__)))
79

8-
# TODO(shumash): all of these should move to core library; address in v0.19.0
9-
def transform_xyz(xyz: torch.Tensor, transform: torch.Tensor):
10-
if len(transform.shape) == 2: # single transform for all the gaussians
11-
transform = transform.unsqueeze(0)
12-
res = (transform[..., :3, :3] @ xyz[:, :, None] + transform[..., :3, 3:]).squeeze(-1)
13-
return res
14-
15-
16-
def transform_rot(rot: torch.Tensor, transform: torch.Tensor):
17-
if len(transform.shape) == 2: # single transform for all the gaussians
18-
transform = transform.unsqueeze(0)
19-
20-
rot_quat = kaolin.math.quat.quat_from_rot33(transform[..., :3, :3])
21-
rot_unit = rot / torch.linalg.norm(rot, dim=-1).unsqueeze(-1)
22-
23-
# Note: gsplats use Hamiltonion convention [real, imag], whereas Kaolin uses the other convention[imag, real]
24-
rot_unit = torch.cat([rot_unit[:, 1:], rot_unit[:, :1]], dim=-1)
25-
26-
result = kaolin.math.quat.quat_mul(rot_quat, rot_unit)
27-
result = torch.cat([result[:, 3:], result[:, :3]], dim=-1)
28-
return result
29-
30-
31-
def decompose_4x4_transform(transform):
32-
""" Decompose 4x4 transform into translation, rotation, scale.
33-
Returns:
34-
translation, rotation, scale
35-
"""
36-
translation = transform[..., :3, 3:]
37-
scale = torch.linalg.norm(transform[..., :3, :3], dim=-2)
38-
rotation = transform[..., :3, :3] / scale.unsqueeze(-2)
39-
40-
return translation, rotation, scale
41-
42-
43-
def transform_gaussians(xyz, rotations, raw_scales, transform, use_log_scales=True):
44-
if len(transform.shape) == 2: # single transform for all the gaussians
45-
transform = transform.unsqueeze(0)
46-
47-
# transforms: n x 4 x 4, where 4 x 4 transform T is applied to pt as T @ pt.
48-
translation, rotation, scale = decompose_4x4_transform(transform)
49-
50-
new_xyz = transform_xyz(xyz, transform)
51-
new_rotations = transform_rot(rotations, rotation)
52-
53-
if not use_log_scales:
54-
new_scales = raw_scales * scale
55-
else:
56-
scaling_norm_factor = torch.log(scale) / raw_scales + 1
57-
new_scales = raw_scales * scaling_norm_factor
58-
59-
return new_xyz, new_rotations, new_scales
60-
6110
def pad_transforms(obj_tfms):
6211
"""
6312
Args:
@@ -70,15 +19,81 @@ def pad_transforms(obj_tfms):
7019
padded_tensor = torch.cat([obj_tfms, padding_row], dim=1)
7120
return padded_tensor
7221

73-
def transform_gaussians_lbs(xyz, rotations, raw_scales, skinning_weights, transforms):
74-
# N x 4 x 4 = sum((N x H x 1 x 1) * (1 x H x 4 x 4), dim=1)
75-
per_pt_transforms = torch.sum(skinning_weights.unsqueeze(-1).unsqueeze(-1) * transforms, dim=1)
76-
# log_tensor(per_pt_transforms, 'pt transforms', logger)
77-
78-
# convert relative transforms to absolute transforms
79-
per_pt_transforms = per_pt_transforms + torch.eye(4, dtype=per_pt_transforms.dtype,
80-
device=per_pt_transforms.device).unsqueeze(0)
81-
82-
new_xyz, new_rot, new_scales = transform_gaussians(xyz, rotations, raw_scales, per_pt_transforms)
22+
def transform_gaussians_lbs(xyz, rotations, raw_scales, skinning_weights, transforms, shs_feat=None):
23+
with torch.no_grad():
24+
# N x 4 x 4 = sum((N x H x 1 x 1) * (1 x H x 4 x 4), dim=1)
25+
per_pt_transforms = torch.sum(skinning_weights.unsqueeze(-1).unsqueeze(-1) * transforms, dim=1)
26+
# log_tensor(per_pt_transforms, 'pt transforms', logger)
27+
28+
# convert relative transforms to absolute transforms
29+
per_pt_transforms = per_pt_transforms + torch.eye(4, dtype=per_pt_transforms.dtype,
30+
device=per_pt_transforms.device).unsqueeze(0)
31+
32+
new_xyz, new_rot, new_scales, new_shs_feat = kaolin.ops.gaussians.transform_gaussians(
33+
xyz, rotations, raw_scales, per_pt_transforms, shs_feat=shs_feat)
34+
35+
return new_xyz, new_rot, new_scales, new_shs_feat
36+
37+
def concat_gaussians(gaussians):
38+
from gaussian_renderer import GaussianModel
39+
assert isinstance(gaussians, Sequence)
40+
xyz = []
41+
rotation = []
42+
scaling = []
43+
opacity = []
44+
features_dc = []
45+
features_rest = []
46+
max_sh_degree = gaussians[0].max_sh_degree
47+
for g in gaussians:
48+
assert isinstance(g, GaussianModel) and g.max_sh_degree == max_sh_degree
49+
xyz.append(g._xyz)
50+
rotation.append(g._rotation)
51+
scaling.append(g._scaling)
52+
opacity.append(g._opacity)
53+
features_dc.append(g._features_dc)
54+
features_rest.append(g._features_rest)
55+
output = GaussianModel(max_sh_degree)
56+
output._xyz = torch.cat(xyz, dim=0).float()
57+
output._rotation = torch.cat(rotation, dim=0).float()
58+
output._scaling = torch.cat(scaling, dim=0).float()
59+
output._opacity = torch.cat(opacity, dim=0).float()
60+
output._features_dc = torch.cat(features_dc, dim=0).float()
61+
output._features_rest = torch.cat(features_rest, dim=0).float()
62+
return output
63+
64+
def inverse_sigmoid(x):
65+
return torch.log(x/(1-x))
66+
67+
def quat_wxyz2xyzw(quat):
68+
return torch.cat([quat[:, 1:], quat[:, :1]], dim=-1)
69+
70+
def quat_xyzw2wxyz(quat):
71+
return torch.cat([quat[:, -1:], quat[:, :-1]], dim=-1)
72+
73+
inria_fields = ['xyz', 'rotation', 'scaling', 'opacity', 'features_dc', 'features_rest']
74+
usd_fields = ['positions', 'orientations', 'scales', 'opacities', 'sh_coeff']
75+
76+
def inria_to_usd(gaussians): #xyz, rotation, scaling, opacity, features_dc, features_rest):
77+
return {
78+
'positions': gaussians._xyz,
79+
'orientations': quat_wxyz2xyzw(gaussians._rotation),
80+
'scales': torch.exp(gaussians._scaling),
81+
'opacities': torch.sigmoid(gaussians._opacity).unsqueeze(-1),
82+
'sh_coeff': torch.cat([
83+
gaussians._features_dc,
84+
gaussians._features_rest
85+
], dim=1)
86+
}
87+
88+
def usd_to_inria(positions, orientations, scales, opacities, sh_coeff, **kwargs):
89+
from gaussian_renderer import GaussianModel
90+
degrees = math.isqrt(sh_coeff.shape[1]) - 1
91+
gaussians = GaussianModel(degrees)
92+
gaussians._xyz = positions.cuda()
93+
gaussians._rotation = quat_xyzw2wxyz(orientations).cuda()
94+
gaussians._scaling = torch.log(scales).cuda()
95+
gaussians._opacity = inverse_sigmoid(opacities).cuda()
96+
gaussians._features_dc = sh_coeff[:, :1].cuda()
97+
gaussians._features_rest = sh_coeff[:, 1:].cuda()
98+
return gaussians
8399

84-
return new_xyz, new_rot, new_scales

0 commit comments

Comments
 (0)