3030from .utils import depth_to_normal , get_projection_matrix
3131
3232
33+ def _compute_view_dirs_packed (
34+ means : Tensor , # [..., N, 3]
35+ campos : Tensor , # [..., C, 3]
36+ batch_ids : Tensor , # [nnz]
37+ camera_ids : Tensor , # [nnz]
38+ gaussian_ids : Tensor , # [nnz]
39+ indptr : Tensor , # [B*C+1]
40+ B : int ,
41+ C : int ,
42+ ) -> Tensor :
43+ """Compute view directions for packed Gaussian-camera pairs.
44+
45+ This function computes the view directions (means - campos) for each
46+ Gaussian-camera pair in the packed format. It automatically selects between
47+ a simple vectorized approach or an optimized loop-based approach based on
48+ the data size and whether campos requires gradients.
49+
50+ Args:
51+ means: The 3D centers of the Gaussians. [..., N, 3]
52+ campos: Camera positions in world coordinates [..., C, 3]
53+ batch_ids: The batch indices of the projected Gaussians. Int32 tensor of shape [nnz].
54+ camera_ids: The camera indices of the projected Gaussians. Int32 tensor of shape [nnz].
55+ gaussian_ids: The column indices of the projected Gaussians. Int32 tensor of shape [nnz].
56+ indptr: CSR-style index pointer into gaussian_ids for batch-camera pairs. Int32 tensor of shape [B*C+1].
57+ B: Number of batches
58+ C: Number of cameras
59+
60+ Returns:
61+ dirs: View directions [nnz, 3]
62+ """
63+ N = means .shape [- 2 ]
64+ nnz = batch_ids .shape [0 ]
65+ device = means .device
66+ means_flat = means .view (B , N , 3 )
67+ campos_flat = campos .view (B , C , 3 )
68+
69+ if B * C == 1 :
70+ # Single batch-camera pair. No indexed lookup for campos is needed.
71+ dirs = means_flat [0 , gaussian_ids ] - campos_flat [0 , 0 ] # [nnz, 3]
72+ else :
73+ avg_means_per_camera = nnz / (B * C )
74+ split_batch_camera_ops = (
75+ avg_means_per_camera > 10000
76+ and campos_flat .is_cuda
77+ and campos_flat .requires_grad
78+ )
79+
80+ if not split_batch_camera_ops :
81+ # Simple vectorized indexing for campos.
82+ dirs = (
83+ means_flat [batch_ids , gaussian_ids ] - campos_flat [batch_ids , camera_ids ]
84+ ) # [nnz, 3]
85+ else :
86+ # For large N with pose optimization: split into B*C separate operations
87+ # to avoid many-to-one indexing of campos in backward pass. This speeds up the
88+ # backwards pass and is more impactful when GPU occupancy is high.
89+ dirs = torch .empty ((nnz , 3 ), dtype = means_flat .dtype , device = device )
90+ indptr_cpu = indptr .cpu ()
91+ for b_idx in range (B ):
92+ for c_idx in range (C ):
93+ bc_idx = b_idx * C + c_idx
94+ start_idx = indptr_cpu [bc_idx ].item ()
95+ end_idx = indptr_cpu [bc_idx + 1 ].item ()
96+ if start_idx == end_idx :
97+ continue
98+
99+ # Get the gaussian indices for this batch-camera pair and compute dirs
100+ gids = gaussian_ids [start_idx :end_idx ]
101+ dirs [start_idx :end_idx ] = (
102+ means_flat [b_idx , gids ] - campos_flat [b_idx , c_idx ]
103+ )
104+
105+ return dirs
106+
107+
33108def rasterization (
34109 means : Tensor , # [..., N, 3]
35110 quats : Tensor , # [..., N, 4]
@@ -432,6 +507,7 @@ def reshape_view(C: int, world_view: torch.Tensor, N_world: list) -> torch.Tenso
432507 batch_ids ,
433508 camera_ids ,
434509 gaussian_ids ,
510+ indptr ,
435511 radii ,
436512 means2d ,
437513 depths ,
@@ -446,7 +522,7 @@ def reshape_view(C: int, world_view: torch.Tensor, N_world: list) -> torch.Tenso
446522 opacities = torch .broadcast_to (
447523 opacities [..., None , :], batch_dims + (C , N )
448524 ) # [..., C, N]
449- batch_ids , camera_ids , gaussian_ids = None , None , None
525+ indptr , batch_ids , camera_ids , gaussian_ids = None , None , None , None
450526 image_ids = None
451527
452528 if compensations is not None :
@@ -493,10 +569,17 @@ def reshape_view(C: int, world_view: torch.Tensor, N_world: list) -> torch.Tenso
493569 campos_rs = torch .inverse (viewmats_rs )[..., :3 , 3 ]
494570 campos = 0.5 * (campos + campos_rs ) # [..., C, 3]
495571 if packed :
496- dirs = (
497- means .view (B , N , 3 )[batch_ids , gaussian_ids ]
498- - campos .view (B , C , 3 )[batch_ids , camera_ids ]
572+ dirs = _compute_view_dirs_packed (
573+ means ,
574+ campos ,
575+ batch_ids ,
576+ camera_ids ,
577+ gaussian_ids ,
578+ indptr ,
579+ B ,
580+ C ,
499581 ) # [nnz, 3]
582+
500583 masks = (radii > 0 ).all (dim = - 1 ) # [nnz]
501584 if colors .dim () == num_batch_dims + 3 :
502585 # Turn [..., N, K, 3] into [nnz, 3]
0 commit comments