Skip to content

Commit 46e8249

Browse files
Optimize packed viewdir pass by reducing many-to-one indexing operations
Co-authored-by: AbhinavGrover <[email protected]>
1 parent 65042cc commit 46e8249

File tree

3 files changed

+39
-11
lines changed

3 files changed

+39
-11
lines changed

gsplat/cuda/_wrapper.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -356,6 +356,7 @@ def fully_fused_projection(
356356
- **batch_ids**. The batch indices of the projected Gaussians. Int32 tensor of shape [nnz].
357357
- **camera_ids**. The camera indices of the projected Gaussians. Int32 tensor of shape [nnz].
358358
- **gaussian_ids**. The column indices of the projected Gaussians. Int32 tensor of shape [nnz].
359+
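- **indptr**. CSR-style index pointer into gaussian_ids: `indptr[i]` and `indptr[i + 1]` delimit the slice of gaussian_ids belonging to the i-th batch-camera pair, where i = b * C + c. Int32 tensor of shape [B*C+1].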
359360
- **radii**. The maximum radius of the projected Gaussians in pixel units. Int32 tensor of shape [nnz, 2].
360361
- **means**. Projected Gaussian means in 2D. [nnz, 2]
361362
- **depths**. The z-depth of the projected Gaussians. [nnz]
@@ -1239,9 +1240,11 @@ def fully_fused_projection_with_ut(
12391240
radial_coeffs.contiguous() if radial_coeffs is not None else None,
12401241
tangential_coeffs.contiguous() if tangential_coeffs is not None else None,
12411242
thin_prism_coeffs.contiguous() if thin_prism_coeffs is not None else None,
1242-
ftheta_coeffs.to_cpp()
1243-
if ftheta_coeffs is not None
1244-
else FThetaCameraDistortionParameters.to_cpp_default(),
1243+
(
1244+
ftheta_coeffs.to_cpp()
1245+
if ftheta_coeffs is not None
1246+
else FThetaCameraDistortionParameters.to_cpp_default()
1247+
),
12451248
)
12461249
if not calc_compensations:
12471250
compensations = None
@@ -1509,9 +1512,13 @@ def backward(
15091512
tile_size = ctx.tile_size
15101513
ftheta_coeffs = ctx.ftheta_coeffs
15111514

1512-
(v_means, v_quats, v_scales, v_colors, v_opacities,) = _make_lazy_cuda_func(
1513-
"rasterize_to_pixels_from_world_3dgs_bwd"
1514-
)(
1515+
(
1516+
v_means,
1517+
v_quats,
1518+
v_scales,
1519+
v_colors,
1520+
v_opacities,
1521+
) = _make_lazy_cuda_func("rasterize_to_pixels_from_world_3dgs_bwd")(
15151522
means,
15161523
quats,
15171524
scales,
@@ -1659,6 +1666,7 @@ def forward(
16591666
batch_ids,
16601667
camera_ids,
16611668
gaussian_ids,
1669+
indptr,
16621670
radii,
16631671
means2d,
16641672
depths,
@@ -1672,6 +1680,7 @@ def backward(
16721680
v_batch_ids,
16731681
v_camera_ids,
16741682
v_gaussian_ids,
1683+
v_indptr,
16751684
v_radii,
16761685
v_means2d,
16771686
v_depths,

gsplat/rendering.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -432,6 +432,7 @@ def reshape_view(C: int, world_view: torch.Tensor, N_world: list) -> torch.Tenso
432432
batch_ids,
433433
camera_ids,
434434
gaussian_ids,
435+
indptr,
435436
radii,
436437
means2d,
437438
depths,
@@ -446,7 +447,7 @@ def reshape_view(C: int, world_view: torch.Tensor, N_world: list) -> torch.Tenso
446447
opacities = torch.broadcast_to(
447448
opacities[..., None, :], batch_dims + (C, N)
448449
) # [..., C, N]
449-
batch_ids, camera_ids, gaussian_ids = None, None, None
450+
indptr, batch_ids, camera_ids, gaussian_ids = None, None, None, None
450451
image_ids = None
451452

452453
if compensations is not None:
@@ -493,10 +494,26 @@ def reshape_view(C: int, world_view: torch.Tensor, N_world: list) -> torch.Tenso
493494
campos_rs = torch.inverse(viewmats_rs)[..., :3, 3]
494495
campos = 0.5 * (campos + campos_rs) # [..., C, 3]
495496
if packed:
496-
dirs = (
497-
means.view(B, N, 3)[batch_ids, gaussian_ids]
498-
- campos.view(B, C, 3)[batch_ids, camera_ids]
499-
) # [nnz, 3]
497+
nnz = batch_ids.shape[0]
498+
means_flat = means.view(B, N, 3)
499+
campos_flat = campos.view(B, C, 3)
500+
501+
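# Compute dirs one (batch, camera) pair at a time (B*C steps in total): each camera position is broadcast over its own slice instead of being gathered through a many-to-one index, which the backward pass would otherwise have to reduce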
502+
dirs = torch.empty((nnz, 3), dtype=means.dtype, device=device)
503+
for b_idx in range(B):
504+
for c_idx in range(C):
505+
bc_idx = b_idx * C + c_idx
506+
start_idx = indptr[bc_idx].item()
507+
end_idx = indptr[bc_idx + 1].item()
508+
if start_idx == end_idx:
509+
continue
510+
511+
# Get the gaussian indices for this batch-camera pair and compute dirs
512+
gids = gaussian_ids[start_idx:end_idx]
513+
dirs[start_idx:end_idx] = (
514+
means_flat[b_idx, gids] - campos_flat[b_idx, c_idx]
515+
)
516+
500517
masks = (radii > 0).all(dim=-1) # [nnz]
501518
if colors.dim() == num_batch_dims + 3:
502519
# Turn [..., N, K, 3] into [nnz, 3]

tests/test_basic.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,7 @@ def test_fully_fused_projection_packed(
306306
batch_ids,
307307
camera_ids,
308308
gaussian_ids,
309+
indptr,
309310
radii,
310311
means2d,
311312
depths,
@@ -344,6 +345,7 @@ def test_fully_fused_projection_packed(
344345
batch_ids,
345346
camera_ids,
346347
gaussian_ids,
348+
indptr,
347349
radii,
348350
means2d,
349351
depths,

0 commit comments

Comments
 (0)