3 changes: 3 additions & 0 deletions gsplat/cuda/_wrapper.py
@@ -356,6 +356,7 @@ def fully_fused_projection(
- **batch_ids**. The batch indices of the projected Gaussians. Int32 tensor of shape [nnz].
- **camera_ids**. The camera indices of the projected Gaussians. Int32 tensor of shape [nnz].
- **gaussian_ids**. The column indices of the projected Gaussians. Int32 tensor of shape [nnz].
- **indptr**. CSR-style index pointer into gaussian_ids for batch-camera pairs. Int32 tensor of shape [B*C+1].
- **radii**. The maximum radius of the projected Gaussians in pixel units. Int32 tensor of shape [nnz, 2].
- **means**. Projected Gaussian means in 2D. [nnz, 2]
- **depths**. The z-depth of the projected Gaussians. [nnz]
@@ -1659,6 +1660,7 @@ def forward(
batch_ids,
camera_ids,
gaussian_ids,
indptr,
radii,
means2d,
depths,
@@ -1672,6 +1674,7 @@ def backward(
v_batch_ids,
v_camera_ids,
v_gaussian_ids,
v_indptr,
v_radii,
v_means2d,
v_depths,
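For reference, indptr follows the usual CSR convention: entry b*C + c marks where the Gaussians projected into batch b, camera c begin inside gaussian_ids, and the last entry equals nnz. A minimal sketch (not part of this PR; the helper name is illustrative) of slicing the packed outputs per batch-camera pair:

import torch

def iter_per_camera(indptr: torch.Tensor, gaussian_ids: torch.Tensor, B: int, C: int):
    # Yield (b, c, gids) for every batch-camera pair using the CSR-style indptr.
    # indptr: [B*C+1], monotone non-decreasing, with indptr[-1] == nnz.
    indptr_cpu = indptr.cpu()  # one device sync, then cheap scalar reads
    for b in range(B):
        for c in range(C):
            start = int(indptr_cpu[b * C + c])
            end = int(indptr_cpu[b * C + c + 1])
            if start == end:
                continue  # no Gaussians projected into this camera
            yield b, c, gaussian_ids[start:end]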
91 changes: 87 additions & 4 deletions gsplat/rendering.py
@@ -30,6 +30,81 @@
from .utils import depth_to_normal, get_projection_matrix


def _compute_view_dirs_packed(
    means: Tensor,  # [..., N, 3]
    campos: Tensor,  # [..., C, 3]
    batch_ids: Tensor,  # [nnz]
    camera_ids: Tensor,  # [nnz]
    gaussian_ids: Tensor,  # [nnz]
    indptr: Tensor,  # [B*C+1]
    B: int,
    C: int,
) -> Tensor:
    """Compute view directions for packed Gaussian-camera pairs.

    This function computes the view directions (means - campos) for each
    Gaussian-camera pair in the packed format. It automatically selects between
    a simple vectorized approach and an optimized loop-based approach, based on
    the average number of Gaussians per camera and whether campos is a CUDA
    tensor that requires gradients.

    Args:
        means: The 3D centers of the Gaussians. [..., N, 3]
        campos: Camera positions in world coordinates. [..., C, 3]
        batch_ids: The batch indices of the projected Gaussians. Int32 tensor of shape [nnz].
        camera_ids: The camera indices of the projected Gaussians. Int32 tensor of shape [nnz].
        gaussian_ids: The column indices of the projected Gaussians. Int32 tensor of shape [nnz].
        indptr: CSR-style index pointer into gaussian_ids for batch-camera pairs. Int32 tensor of shape [B*C+1].
        B: Number of batches.
        C: Number of cameras.

    Returns:
        dirs: View directions. [nnz, 3]
    """
    N = means.shape[-2]
    nnz = batch_ids.shape[0]
    device = means.device
    means_flat = means.view(B, N, 3)
    campos_flat = campos.view(B, C, 3)

    if B * C == 1:
        # Single batch-camera pair. No indexed lookup for campos is needed.
        dirs = means_flat[0, gaussian_ids] - campos_flat[0, 0]  # [nnz, 3]
    else:
        avg_means_per_camera = nnz / (B * C)
        split_batch_camera_ops = (
            avg_means_per_camera > 10000
            and campos_flat.is_cuda
            and campos_flat.requires_grad
        )

        if not split_batch_camera_ops:
            # Simple vectorized indexing for campos.
            dirs = (
                means_flat[batch_ids, gaussian_ids] - campos_flat[batch_ids, camera_ids]
            )  # [nnz, 3]
        else:
            # For large N with pose optimization: split into B*C separate operations
            # to avoid many-to-one indexing of campos, whose backward pass would
            # otherwise scatter many gradients into a few camera positions. The
            # speedup is largest when GPU occupancy is already high.
            dirs = torch.empty((nnz, 3), dtype=means_flat.dtype, device=device)
            indptr_cpu = indptr.cpu()
            for b_idx in range(B):
                for c_idx in range(C):
                    bc_idx = b_idx * C + c_idx
                    start_idx = indptr_cpu[bc_idx].item()
                    end_idx = indptr_cpu[bc_idx + 1].item()
                    if start_idx == end_idx:
                        continue

                    # Get the gaussian indices for this batch-camera pair and compute dirs
                    gids = gaussian_ids[start_idx:end_idx]
                    dirs[start_idx:end_idx] = (
                        means_flat[b_idx, gids] - campos_flat[b_idx, c_idx]
                    )

    return dirs
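Both branches compute the same values; only the shape of the autograd graph differs. A quick equivalence check one could run against this helper's logic (the sizes and the dense packing below are assumptions for illustration, not part of this PR):

import torch

# Hypothetical sizes; the dense packing below (every Gaussian visible in every
# camera, grouped by batch-camera pair) is made up for illustration.
B, C, N = 2, 3, 1000
means = torch.randn(B, N, 3)
campos = torch.randn(B, C, 3)

batch_ids = torch.arange(B).repeat_interleave(C * N)         # [B*C*N]
camera_ids = torch.arange(C).repeat_interleave(N).repeat(B)  # [B*C*N]
gaussian_ids = torch.arange(N).repeat(B * C)                 # [B*C*N]
counts = torch.full((B * C,), N, dtype=torch.int64)
indptr = torch.cat([torch.zeros(1, dtype=torch.int64), counts.cumsum(0)])  # [B*C+1]

# Vectorized path.
dirs_vec = means[batch_ids, gaussian_ids] - campos[batch_ids, camera_ids]

# Split path, mirroring the per-pair loop above.
dirs_split = torch.empty_like(dirs_vec)
for b in range(B):
    for c in range(C):
        s, e = indptr[b * C + c].item(), indptr[b * C + c + 1].item()
        gids = gaussian_ids[s:e]
        dirs_split[s:e] = means[b, gids] - campos[b, c]

torch.testing.assert_close(dirs_vec, dirs_split)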


def rasterization(
means: Tensor, # [..., N, 3]
quats: Tensor, # [..., N, 4]
@@ -432,6 +507,7 @@ def reshape_view(C: int, world_view: torch.Tensor, N_world: list) -> torch.Tensor
batch_ids,
camera_ids,
gaussian_ids,
indptr,
radii,
means2d,
depths,
@@ -446,7 +522,7 @@ def reshape_view(C: int, world_view: torch.Tensor, N_world: list) -> torch.Tensor
opacities = torch.broadcast_to(
opacities[..., None, :], batch_dims + (C, N)
) # [..., C, N]
batch_ids, camera_ids, gaussian_ids = None, None, None
indptr, batch_ids, camera_ids, gaussian_ids = None, None, None, None
image_ids = None

if compensations is not None:
@@ -493,10 +569,17 @@ def reshape_view(C: int, world_view: torch.Tensor, N_world: list) -> torch.Tensor
campos_rs = torch.inverse(viewmats_rs)[..., :3, 3]
campos = 0.5 * (campos + campos_rs) # [..., C, 3]
if packed:
dirs = (
means.view(B, N, 3)[batch_ids, gaussian_ids]
- campos.view(B, C, 3)[batch_ids, camera_ids]
dirs = _compute_view_dirs_packed(
means,
campos,
batch_ids,
camera_ids,
gaussian_ids,
indptr,
B,
C,
) # [nnz, 3]

masks = (radii > 0).all(dim=-1) # [nnz]
if colors.dim() == num_batch_dims + 3:
# Turn [..., N, K, 3] into [nnz, 3]
2 changes: 2 additions & 0 deletions tests/test_basic.py
@@ -306,6 +306,7 @@ def test_fully_fused_projection_packed(
batch_ids,
camera_ids,
gaussian_ids,
indptr,
radii,
means2d,
depths,
@@ -344,6 +345,7 @@ def test_fully_fused_projection_packed(
batch_ids,
camera_ids,
gaussian_ids,
indptr,
radii,
means2d,
depths,
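As a side note, the CSR layout implies a few invariants the packed outputs should satisfy. A hypothetical sketch of such checks (not the assertions actually used in tests/test_basic.py):

import torch

def check_packed_indptr(indptr, batch_ids, camera_ids, gaussian_ids, B, C):
    # Illustrative consistency checks for the CSR-style packed outputs.
    nnz = gaussian_ids.shape[0]
    assert indptr.shape == (B * C + 1,)
    assert indptr[0].item() == 0 and indptr[-1].item() == nnz
    assert (indptr[1:] >= indptr[:-1]).all()  # monotone non-decreasing
    # Segment b*C + c of gaussian_ids must correspond to (batch_ids, camera_ids) == (b, c).
    counts = (indptr[1:] - indptr[:-1]).long()  # [B*C]
    bc_ids = torch.repeat_interleave(torch.arange(B * C, device=indptr.device), counts)
    assert torch.equal(bc_ids, batch_ids.long() * C + camera_ids.long())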