Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 15 additions & 6 deletions gsplat/cuda/_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,7 @@ def fully_fused_projection(
- **batch_ids**. The batch indices of the projected Gaussians. Int32 tensor of shape [nnz].
- **camera_ids**. The camera indices of the projected Gaussians. Int32 tensor of shape [nnz].
- **gaussian_ids**. The column indices of the projected Gaussians. Int32 tensor of shape [nnz].
- **indptr**. CSR-style index pointer into gaussian_ids for batch-camera pairs. Int32 tensor of shape [B*C+1].
- **radii**. The maximum radius of the projected Gaussians in pixel unit. Int32 tensor of shape [nnz, 2].
- **means**. Projected Gaussian means in 2D. [nnz, 2]
- **depths**. The z-depth of the projected Gaussians. [nnz]
Expand Down Expand Up @@ -1239,9 +1240,11 @@ def fully_fused_projection_with_ut(
radial_coeffs.contiguous() if radial_coeffs is not None else None,
tangential_coeffs.contiguous() if tangential_coeffs is not None else None,
thin_prism_coeffs.contiguous() if thin_prism_coeffs is not None else None,
ftheta_coeffs.to_cpp()
if ftheta_coeffs is not None
else FThetaCameraDistortionParameters.to_cpp_default(),
(
ftheta_coeffs.to_cpp()
if ftheta_coeffs is not None
else FThetaCameraDistortionParameters.to_cpp_default()
),
)
if not calc_compensations:
compensations = None
Expand Down Expand Up @@ -1509,9 +1512,13 @@ def backward(
tile_size = ctx.tile_size
ftheta_coeffs = ctx.ftheta_coeffs

(v_means, v_quats, v_scales, v_colors, v_opacities,) = _make_lazy_cuda_func(
"rasterize_to_pixels_from_world_3dgs_bwd"
)(
(
v_means,
v_quats,
v_scales,
v_colors,
v_opacities,
) = _make_lazy_cuda_func("rasterize_to_pixels_from_world_3dgs_bwd")(
means,
quats,
scales,
Expand Down Expand Up @@ -1659,6 +1666,7 @@ def forward(
batch_ids,
camera_ids,
gaussian_ids,
indptr,
radii,
means2d,
depths,
Expand All @@ -1672,6 +1680,7 @@ def backward(
v_batch_ids,
v_camera_ids,
v_gaussian_ids,
v_indptr,
v_radii,
v_means2d,
v_depths,
Expand Down
28 changes: 23 additions & 5 deletions gsplat/rendering.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,6 +432,7 @@ def reshape_view(C: int, world_view: torch.Tensor, N_world: list) -> torch.Tenso
batch_ids,
camera_ids,
gaussian_ids,
indptr,
radii,
means2d,
depths,
Expand All @@ -446,7 +447,7 @@ def reshape_view(C: int, world_view: torch.Tensor, N_world: list) -> torch.Tenso
opacities = torch.broadcast_to(
opacities[..., None, :], batch_dims + (C, N)
) # [..., C, N]
batch_ids, camera_ids, gaussian_ids = None, None, None
indptr, batch_ids, camera_ids, gaussian_ids = None, None, None, None
image_ids = None

if compensations is not None:
Expand Down Expand Up @@ -493,10 +494,27 @@ def reshape_view(C: int, world_view: torch.Tensor, N_world: list) -> torch.Tenso
campos_rs = torch.inverse(viewmats_rs)[..., :3, 3]
campos = 0.5 * (campos + campos_rs) # [..., C, 3]
if packed:
dirs = (
means.view(B, N, 3)[batch_ids, gaussian_ids]
- campos.view(B, C, 3)[batch_ids, camera_ids]
) # [nnz, 3]
nnz = batch_ids.shape[0]
means_flat = means.view(B, N, 3)
campos_flat = campos.view(B, C, 3)

# Compute dirs in B*C steps to avoid many-to-one indexing of cameras in backward pass
dirs = torch.empty((nnz, 3), dtype=means.dtype, device=device)
indptr_cpu = indptr.cpu()
for b_idx in range(B):
for c_idx in range(C):
bc_idx = b_idx * C + c_idx
start_idx = indptr_cpu[bc_idx].item()
end_idx = indptr_cpu[bc_idx + 1].item()
if start_idx == end_idx:
continue

# Get the gaussian indices for this batch-camera pair and compute dirs
gids = gaussian_ids[start_idx:end_idx]
dirs[start_idx:end_idx] = (
means_flat[b_idx, gids] - campos_flat[b_idx, c_idx]
)

masks = (radii > 0).all(dim=-1) # [nnz]
if colors.dim() == num_batch_dims + 3:
# Turn [..., N, K, 3] into [nnz, 3]
Expand Down
2 changes: 2 additions & 0 deletions tests/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,7 @@ def test_fully_fused_projection_packed(
batch_ids,
camera_ids,
gaussian_ids,
indptr,
radii,
means2d,
depths,
Expand Down Expand Up @@ -344,6 +345,7 @@ def test_fully_fused_projection_packed(
batch_ids,
camera_ids,
gaussian_ids,
indptr,
radii,
means2d,
depths,
Expand Down
Loading