11import os
2+ from collections .abc import Sequence
3+ import math
24import torch
35import kaolin
46
57
68PHYS_NOTEBOOKS_DIR = os .path .abspath (os .path .dirname (os .path .abspath (__file__ )))
79
8- # TODO(shumash): all of these should move to core library; address in v0.19.0
9- def transform_xyz (xyz : torch .Tensor , transform : torch .Tensor ):
10- if len (transform .shape ) == 2 : # single transform for all the gaussians
11- transform = transform .unsqueeze (0 )
12- res = (transform [..., :3 , :3 ] @ xyz [:, :, None ] + transform [..., :3 , 3 :]).squeeze (- 1 )
13- return res
14-
15-
16- def transform_rot (rot : torch .Tensor , transform : torch .Tensor ):
17- if len (transform .shape ) == 2 : # single transform for all the gaussians
18- transform = transform .unsqueeze (0 )
19-
20- rot_quat = kaolin .math .quat .quat_from_rot33 (transform [..., :3 , :3 ])
21- rot_unit = rot / torch .linalg .norm (rot , dim = - 1 ).unsqueeze (- 1 )
22-
23- # Note: gsplats use Hamiltonion convention [real, imag], whereas Kaolin uses the other convention[imag, real]
24- rot_unit = torch .cat ([rot_unit [:, 1 :], rot_unit [:, :1 ]], dim = - 1 )
25-
26- result = kaolin .math .quat .quat_mul (rot_quat , rot_unit )
27- result = torch .cat ([result [:, 3 :], result [:, :3 ]], dim = - 1 )
28- return result
29-
30-
31- def decompose_4x4_transform (transform ):
32- """ Decompose 4x4 transform into translation, rotation, scale.
33- Returns:
34- translation, rotation, scale
35- """
36- translation = transform [..., :3 , 3 :]
37- scale = torch .linalg .norm (transform [..., :3 , :3 ], dim = - 2 )
38- rotation = transform [..., :3 , :3 ] / scale .unsqueeze (- 2 )
39-
40- return translation , rotation , scale
41-
42-
43- def transform_gaussians (xyz , rotations , raw_scales , transform , use_log_scales = True ):
44- if len (transform .shape ) == 2 : # single transform for all the gaussians
45- transform = transform .unsqueeze (0 )
46-
47- # transforms: n x 4 x 4, where 4 x 4 transform T is applied to pt as T @ pt.
48- translation , rotation , scale = decompose_4x4_transform (transform )
49-
50- new_xyz = transform_xyz (xyz , transform )
51- new_rotations = transform_rot (rotations , rotation )
52-
53- if not use_log_scales :
54- new_scales = raw_scales * scale
55- else :
56- scaling_norm_factor = torch .log (scale ) / raw_scales + 1
57- new_scales = raw_scales * scaling_norm_factor
58-
59- return new_xyz , new_rotations , new_scales
60-
6110def pad_transforms (obj_tfms ):
6211 """
6312 Args:
@@ -70,15 +19,81 @@ def pad_transforms(obj_tfms):
7019 padded_tensor = torch .cat ([obj_tfms , padding_row ], dim = 1 )
7120 return padded_tensor
7221
73- def transform_gaussians_lbs (xyz , rotations , raw_scales , skinning_weights , transforms ):
74- # N x 4 x 4 = sum((N x H x 1 x 1) * (1 x H x 4 x 4), dim=1)
75- per_pt_transforms = torch .sum (skinning_weights .unsqueeze (- 1 ).unsqueeze (- 1 ) * transforms , dim = 1 )
76- # log_tensor(per_pt_transforms, 'pt transforms', logger)
77-
78- # convert relative transforms to absolute transforms
79- per_pt_transforms = per_pt_transforms + torch .eye (4 , dtype = per_pt_transforms .dtype ,
80- device = per_pt_transforms .device ).unsqueeze (0 )
81-
82- new_xyz , new_rot , new_scales = transform_gaussians (xyz , rotations , raw_scales , per_pt_transforms )
22+ def transform_gaussians_lbs (xyz , rotations , raw_scales , skinning_weights , transforms , shs_feat = None ):
23+ with torch .no_grad ():
24+ # N x 4 x 4 = sum((N x H x 1 x 1) * (1 x H x 4 x 4), dim=1)
25+ per_pt_transforms = torch .sum (skinning_weights .unsqueeze (- 1 ).unsqueeze (- 1 ) * transforms , dim = 1 )
26+ # log_tensor(per_pt_transforms, 'pt transforms', logger)
27+
28+ # convert relative transforms to absolute transforms
29+ per_pt_transforms = per_pt_transforms + torch .eye (4 , dtype = per_pt_transforms .dtype ,
30+ device = per_pt_transforms .device ).unsqueeze (0 )
31+
32+ new_xyz , new_rot , new_scales , new_shs_feat = kaolin .ops .gaussians .transform_gaussians (
33+ xyz , rotations , raw_scales , per_pt_transforms , shs_feat = shs_feat )
34+
35+ return new_xyz , new_rot , new_scales , new_shs_feat
36+
37+ def concat_gaussians (gaussians ):
38+ from gaussian_renderer import GaussianModel
39+ assert isinstance (gaussians , Sequence )
40+ xyz = []
41+ rotation = []
42+ scaling = []
43+ opacity = []
44+ features_dc = []
45+ features_rest = []
46+ max_sh_degree = gaussians [0 ].max_sh_degree
47+ for g in gaussians :
48+ assert isinstance (g , GaussianModel ) and g .max_sh_degree == max_sh_degree
49+ xyz .append (g ._xyz )
50+ rotation .append (g ._rotation )
51+ scaling .append (g ._scaling )
52+ opacity .append (g ._opacity )
53+ features_dc .append (g ._features_dc )
54+ features_rest .append (g ._features_rest )
55+ output = GaussianModel (max_sh_degree )
56+ output ._xyz = torch .cat (xyz , dim = 0 ).float ()
57+ output ._rotation = torch .cat (rotation , dim = 0 ).float ()
58+ output ._scaling = torch .cat (scaling , dim = 0 ).float ()
59+ output ._opacity = torch .cat (opacity , dim = 0 ).float ()
60+ output ._features_dc = torch .cat (features_dc , dim = 0 ).float ()
61+ output ._features_rest = torch .cat (features_rest , dim = 0 ).float ()
62+ return output
63+
64+ def inverse_sigmoid (x ):
65+ return torch .log (x / (1 - x ))
66+
67+ def quat_wxyz2xyzw (quat ):
68+ return torch .cat ([quat [:, 1 :], quat [:, :1 ]], dim = - 1 )
69+
70+ def quat_xyzw2wxyz (quat ):
71+ return torch .cat ([quat [:, - 1 :], quat [:, :- 1 ]], dim = - 1 )
72+
73+ inria_fields = ['xyz' , 'rotation' , 'scaling' , 'opacity' , 'features_dc' , 'features_rest' ]
74+ usd_fields = ['positions' , 'orientations' , 'scales' , 'opacities' , 'sh_coeff' ]
75+
76+ def inria_to_usd (gaussians ): #xyz, rotation, scaling, opacity, features_dc, features_rest):
77+ return {
78+ 'positions' : gaussians ._xyz ,
79+ 'orientations' : quat_wxyz2xyzw (gaussians ._rotation ),
80+ 'scales' : torch .exp (gaussians ._scaling ),
81+ 'opacities' : torch .sigmoid (gaussians ._opacity ).unsqueeze (- 1 ),
82+ 'sh_coeff' : torch .cat ([
83+ gaussians ._features_dc ,
84+ gaussians ._features_rest
85+ ], dim = 1 )
86+ }
87+
88+ def usd_to_inria (positions , orientations , scales , opacities , sh_coeff , ** kwargs ):
89+ from gaussian_renderer import GaussianModel
90+ degrees = math .isqrt (sh_coeff .shape [1 ]) - 1
91+ gaussians = GaussianModel (degrees )
92+ gaussians ._xyz = positions .cuda ()
93+ gaussians ._rotation = quat_xyzw2wxyz (orientations ).cuda ()
94+ gaussians ._scaling = torch .log (scales ).cuda ()
95+ gaussians ._opacity = inverse_sigmoid (opacities ).cuda ()
96+ gaussians ._features_dc = sh_coeff [:, :1 ].cuda ()
97+ gaussians ._features_rest = sh_coeff [:, 1 :].cuda ()
98+ return gaussians
8399
84- return new_xyz , new_rot , new_scales
0 commit comments