Skip to content

Commit e35a43a

Browse files
Packed backward pass speedup via unrolled camera position indexing (#831)
* Optimize packed viewdir pass by reducing many-to-one indexing operations Co-authored-by: AbhinavGrover <[email protected]> * Copy the indptr to cpu to avoid GPU sync in the loop * Resolved PR comments * Format * Format * typo --------- Co-authored-by: AbhinavGrover <[email protected]>
1 parent 65042cc commit e35a43a

File tree

3 files changed

+92
-4
lines changed

3 files changed

+92
-4
lines changed

gsplat/cuda/_wrapper.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -356,6 +356,7 @@ def fully_fused_projection(
356356
- **batch_ids**. The batch indices of the projected Gaussians. Int32 tensor of shape [nnz].
357357
- **camera_ids**. The camera indices of the projected Gaussians. Int32 tensor of shape [nnz].
358358
- **gaussian_ids**. The column indices of the projected Gaussians. Int32 tensor of shape [nnz].
359+
- **indptr**. CSR-style index pointer into gaussian_ids for batch-camera pairs. Int32 tensor of shape [B*C+1].
359360
- **radii**. The maximum radius of the projected Gaussians in pixel unit. Int32 tensor of shape [nnz, 2].
360361
- **means**. Projected Gaussian means in 2D. [nnz, 2]
361362
- **depths**. The z-depth of the projected Gaussians. [nnz]
@@ -1659,6 +1660,7 @@ def forward(
16591660
batch_ids,
16601661
camera_ids,
16611662
gaussian_ids,
1663+
indptr,
16621664
radii,
16631665
means2d,
16641666
depths,
@@ -1672,6 +1674,7 @@ def backward(
16721674
v_batch_ids,
16731675
v_camera_ids,
16741676
v_gaussian_ids,
1677+
v_indptr,
16751678
v_radii,
16761679
v_means2d,
16771680
v_depths,

gsplat/rendering.py

Lines changed: 87 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,81 @@
3030
from .utils import depth_to_normal, get_projection_matrix
3131

3232

33+
def _compute_view_dirs_packed(
34+
means: Tensor, # [..., N, 3]
35+
campos: Tensor, # [..., C, 3]
36+
batch_ids: Tensor, # [nnz]
37+
camera_ids: Tensor, # [nnz]
38+
gaussian_ids: Tensor, # [nnz]
39+
indptr: Tensor, # [B*C+1]
40+
B: int,
41+
C: int,
42+
) -> Tensor:
43+
"""Compute view directions for packed Gaussian-camera pairs.
44+
45+
This function computes the view directions (means - campos) for each
46+
Gaussian-camera pair in the packed format. It automatically selects between
47+
a simple vectorized approach or an optimized loop-based approach based on
48+
the data size and whether campos requires gradients.
49+
50+
Args:
51+
means: The 3D centers of the Gaussians. [..., N, 3]
52+
campos: Camera positions in world coordinates [..., C, 3]
53+
batch_ids: The batch indices of the projected Gaussians. Int32 tensor of shape [nnz].
54+
camera_ids: The camera indices of the projected Gaussians. Int32 tensor of shape [nnz].
55+
gaussian_ids: The column indices of the projected Gaussians. Int32 tensor of shape [nnz].
56+
indptr: CSR-style index pointer into gaussian_ids for batch-camera pairs. Int32 tensor of shape [B*C+1].
57+
B: Number of batches
58+
C: Number of cameras
59+
60+
Returns:
61+
dirs: View directions [nnz, 3]
62+
"""
63+
N = means.shape[-2]
64+
nnz = batch_ids.shape[0]
65+
device = means.device
66+
means_flat = means.view(B, N, 3)
67+
campos_flat = campos.view(B, C, 3)
68+
69+
if B * C == 1:
70+
# Single batch-camera pair. No indexed lookup for campos is needed.
71+
dirs = means_flat[0, gaussian_ids] - campos_flat[0, 0] # [nnz, 3]
72+
else:
73+
avg_means_per_camera = nnz / (B * C)
74+
split_batch_camera_ops = (
75+
avg_means_per_camera > 10000
76+
and campos_flat.is_cuda
77+
and campos_flat.requires_grad
78+
)
79+
80+
if not split_batch_camera_ops:
81+
# Simple vectorized indexing for campos.
82+
dirs = (
83+
means_flat[batch_ids, gaussian_ids] - campos_flat[batch_ids, camera_ids]
84+
) # [nnz, 3]
85+
else:
86+
# For large N with pose optimization: split into B*C separate operations
87+
# to avoid many-to-one indexing of campos in backward pass. This speeds up the
88+
# backwards pass and is more impactful when GPU occupancy is high.
89+
dirs = torch.empty((nnz, 3), dtype=means_flat.dtype, device=device)
90+
indptr_cpu = indptr.cpu()
91+
for b_idx in range(B):
92+
for c_idx in range(C):
93+
bc_idx = b_idx * C + c_idx
94+
start_idx = indptr_cpu[bc_idx].item()
95+
end_idx = indptr_cpu[bc_idx + 1].item()
96+
if start_idx == end_idx:
97+
continue
98+
99+
# Get the gaussian indices for this batch-camera pair and compute dirs
100+
gids = gaussian_ids[start_idx:end_idx]
101+
dirs[start_idx:end_idx] = (
102+
means_flat[b_idx, gids] - campos_flat[b_idx, c_idx]
103+
)
104+
105+
return dirs
106+
107+
33108
def rasterization(
34109
means: Tensor, # [..., N, 3]
35110
quats: Tensor, # [..., N, 4]
@@ -432,6 +507,7 @@ def reshape_view(C: int, world_view: torch.Tensor, N_world: list) -> torch.Tenso
432507
batch_ids,
433508
camera_ids,
434509
gaussian_ids,
510+
indptr,
435511
radii,
436512
means2d,
437513
depths,
@@ -446,7 +522,7 @@ def reshape_view(C: int, world_view: torch.Tensor, N_world: list) -> torch.Tenso
446522
opacities = torch.broadcast_to(
447523
opacities[..., None, :], batch_dims + (C, N)
448524
) # [..., C, N]
449-
batch_ids, camera_ids, gaussian_ids = None, None, None
525+
indptr, batch_ids, camera_ids, gaussian_ids = None, None, None, None
450526
image_ids = None
451527

452528
if compensations is not None:
@@ -493,10 +569,17 @@ def reshape_view(C: int, world_view: torch.Tensor, N_world: list) -> torch.Tenso
493569
campos_rs = torch.inverse(viewmats_rs)[..., :3, 3]
494570
campos = 0.5 * (campos + campos_rs) # [..., C, 3]
495571
if packed:
496-
dirs = (
497-
means.view(B, N, 3)[batch_ids, gaussian_ids]
498-
- campos.view(B, C, 3)[batch_ids, camera_ids]
572+
dirs = _compute_view_dirs_packed(
573+
means,
574+
campos,
575+
batch_ids,
576+
camera_ids,
577+
gaussian_ids,
578+
indptr,
579+
B,
580+
C,
499581
) # [nnz, 3]
582+
500583
masks = (radii > 0).all(dim=-1) # [nnz]
501584
if colors.dim() == num_batch_dims + 3:
502585
# Turn [..., N, K, 3] into [nnz, 3]

tests/test_basic.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,7 @@ def test_fully_fused_projection_packed(
306306
batch_ids,
307307
camera_ids,
308308
gaussian_ids,
309+
indptr,
309310
radii,
310311
means2d,
311312
depths,
@@ -344,6 +345,7 @@ def test_fully_fused_projection_packed(
344345
batch_ids,
345346
camera_ids,
346347
gaussian_ids,
348+
indptr,
347349
radii,
348350
means2d,
349351
depths,

0 commit comments

Comments
 (0)