Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions gsplat/cuda/_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -1108,7 +1108,7 @@ def backward(ctx, v_radii, v_means2d, v_depths, v_conics, v_compensations):
camera_model_type = ctx.camera_model_type
if v_compensations is not None:
v_compensations = v_compensations.contiguous()
v_means, v_covars, v_quats, v_scales, v_viewmats = _make_lazy_cuda_func(
v_means, v_covars, v_quats, v_scales, v_viewmats, v_Ks = _make_lazy_cuda_func(
"projection_ewa_3dgs_fused_bwd"
)(
means,
Expand All @@ -1129,6 +1129,7 @@ def backward(ctx, v_radii, v_means2d, v_depths, v_conics, v_compensations):
v_conics.contiguous(),
v_compensations,
ctx.needs_input_grad[4], # viewmats_requires_grad
ctx.needs_input_grad[5], # Ks_requires_grad
)
if not ctx.needs_input_grad[0]:
v_means = None
Expand All @@ -1140,12 +1141,15 @@ def backward(ctx, v_radii, v_means2d, v_depths, v_conics, v_compensations):
v_scales = None
if not ctx.needs_input_grad[4]:
v_viewmats = None
if not ctx.needs_input_grad[5]:
v_Ks = None
return (
v_means,
v_covars,
v_quats,
v_scales,
v_viewmats,
v_Ks,
None,
None,
None,
Expand Down Expand Up @@ -2002,7 +2006,7 @@ def backward(ctx, v_radii, v_means2d, v_depths, v_ray_transforms, v_normals):
width = ctx.width
height = ctx.height
eps2d = ctx.eps2d
v_means, v_quats, v_scales, v_viewmats = _make_lazy_cuda_func(
v_means, v_quats, v_scales, v_viewmats, v_Ks = _make_lazy_cuda_func(
"projection_2dgs_fused_bwd"
)(
means,
Expand All @@ -2019,6 +2023,7 @@ def backward(ctx, v_radii, v_means2d, v_depths, v_ray_transforms, v_normals):
v_normals.contiguous(),
v_ray_transforms.contiguous(),
ctx.needs_input_grad[3], # viewmats_requires_grad
ctx.needs_input_grad[4], # Ks_requires_grad
)
if not ctx.needs_input_grad[0]:
v_means = None
Expand All @@ -2028,12 +2033,15 @@ def backward(ctx, v_radii, v_means2d, v_depths, v_ray_transforms, v_normals):
v_scales = None
if not ctx.needs_input_grad[3]:
v_viewmats = None
if not ctx.needs_input_grad[4]:
v_Ks = None

return (
v_means,
v_quats,
v_scales,
v_viewmats,
v_Ks,
None,
None,
None,
Expand Down
44 changes: 32 additions & 12 deletions gsplat/cuda/csrc/Projection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ projection_ewa_3dgs_fused_fwd(
return std::make_tuple(radii, means2d, depths, conics, compensations);
}

std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
projection_ewa_3dgs_fused_bwd(
// fwd inputs
const at::Tensor means, // [..., N, 3]
Expand All @@ -210,7 +210,8 @@ projection_ewa_3dgs_fused_bwd(
const at::Tensor v_depths, // [..., C, N]
const at::Tensor v_conics, // [..., C, N, 3]
const at::optional<at::Tensor> v_compensations, // [..., C, N] optional
const bool viewmats_requires_grad
const bool viewmats_requires_grad,
const bool Ks_requires_grad
) {
DEVICE_GUARD(means);
CHECK_INPUT(means);
Expand Down Expand Up @@ -248,6 +249,10 @@ projection_ewa_3dgs_fused_bwd(
if (viewmats_requires_grad) {
v_viewmats = at::zeros_like(viewmats);
}
at::Tensor v_Ks;
if (Ks_requires_grad) {
v_Ks = at::zeros_like(Ks);
}

launch_projection_ewa_3dgs_fused_bwd_kernel(
// inputs
Expand All @@ -269,15 +274,17 @@ projection_ewa_3dgs_fused_bwd(
v_conics,
v_compensations,
viewmats_requires_grad,
Ks_requires_grad,
// outputs
v_means,
v_covars,
v_quats,
v_scales,
v_viewmats
v_viewmats,
v_Ks
);

return std::make_tuple(v_means, v_covars, v_quats, v_scales, v_viewmats);
return std::make_tuple(v_means, v_covars, v_quats, v_scales, v_viewmats, v_Ks);
}

std::tuple<
Expand Down Expand Up @@ -619,7 +626,7 @@ projection_2dgs_fused_fwd(
return std::make_tuple(radii, means2d, depths, ray_transforms, normals);
}

std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
projection_2dgs_fused_bwd(
// fwd inputs
const at::Tensor means, // [..., N, 3]
Expand All @@ -637,7 +644,8 @@ projection_2dgs_fused_bwd(
const at::Tensor v_depths, // [..., C, N]
const at::Tensor v_normals, // [..., C, N, 3]
const at::Tensor v_ray_transforms, // [..., C, N, 3, 3]
const bool viewmats_requires_grad
const bool viewmats_requires_grad,
const bool Ks_requires_grad
) {
DEVICE_GUARD(means);
CHECK_INPUT(means);
Expand All @@ -659,6 +667,10 @@ projection_2dgs_fused_bwd(
if (viewmats_requires_grad) {
v_viewmats = at::zeros_like(viewmats);
}
at::Tensor v_Ks;
if (Ks_requires_grad) {
v_Ks = at::zeros_like(Ks);
}

launch_projection_2dgs_fused_bwd_kernel(
// inputs
Expand All @@ -676,14 +688,16 @@ projection_2dgs_fused_bwd(
v_normals,
v_ray_transforms,
viewmats_requires_grad,
Ks_requires_grad,
// outputs
v_means,
v_quats,
v_scales,
v_viewmats
v_viewmats,
v_Ks
);

return std::make_tuple(v_means, v_quats, v_scales, v_viewmats);
return std::make_tuple(v_means, v_quats, v_scales, v_viewmats, v_Ks);
}

std::tuple<
Expand Down Expand Up @@ -815,7 +829,7 @@ projection_2dgs_packed_fwd(
);
}

std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
projection_2dgs_packed_bwd(
// fwd inputs
const at::Tensor means, // [..., N, 3]
Expand All @@ -836,6 +850,7 @@ projection_2dgs_packed_bwd(
const at::Tensor v_ray_transforms, // [nnz, 3, 3]
const at::Tensor v_normals, // [nnz, 3]
const bool viewmats_requires_grad,
const bool Ks_requires_grad,
const bool sparse_grad
) {
DEVICE_GUARD(means);
Expand All @@ -859,7 +874,7 @@ projection_2dgs_packed_bwd(
uint32_t C = viewmats.size(-3); // number of cameras
uint32_t nnz = batch_ids.size(0);

at::Tensor v_means, v_quats, v_scales, v_viewmats;
at::Tensor v_means, v_quats, v_scales, v_viewmats, v_Ks;
if (sparse_grad) {
v_means = at::zeros({nnz, 3}, opt);
v_quats = at::zeros({nnz, 4}, opt);
Expand All @@ -872,6 +887,9 @@ projection_2dgs_packed_bwd(
if (viewmats_requires_grad) {
v_viewmats = at::zeros_like(viewmats, opt);
}
if (Ks_requires_grad) {
v_Ks = at::zeros_like(Ks, opt);
}

launch_projection_2dgs_packed_bwd_kernel(
// fwd inputs
Expand All @@ -898,9 +916,11 @@ projection_2dgs_packed_bwd(
v_quats,
v_scales,
v_viewmats.defined() ? at::optional<at::Tensor>(v_viewmats)
: c10::nullopt
: c10::nullopt,
v_Ks.defined() ? at::optional<at::Tensor>(v_Ks)
: c10::nullopt
);
return std::make_tuple(v_means, v_quats, v_scales, v_viewmats);
return std::make_tuple(v_means, v_quats, v_scales, v_viewmats, v_Ks);
}

std::tuple<
Expand Down
11 changes: 8 additions & 3 deletions gsplat/cuda/csrc/Projection.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,12 +82,14 @@ void launch_projection_ewa_3dgs_fused_bwd_kernel(
const at::Tensor v_conics, // [..., C, N, 3]
const at::optional<at::Tensor> v_compensations, // [..., C, N] optional
const bool viewmats_requires_grad,
const bool Ks_requires_grad,
// outputs
at::Tensor v_means, // [..., N, 3]
at::Tensor v_covars, // [..., N, 3, 3]
at::Tensor v_quats, // [..., N, 4]
at::Tensor v_scales, // [..., N, 3]
at::Tensor v_viewmats // [..., C, 4, 4]
at::Tensor v_viewmats, // [..., C, 4, 4]
at::Tensor v_Ks // [..., C, 3, 3]
);

void launch_projection_ewa_3dgs_packed_fwd_kernel(
Expand Down Expand Up @@ -189,11 +191,13 @@ void launch_projection_2dgs_fused_bwd_kernel(
const at::Tensor v_normals, // [..., C, N, 3]
const at::Tensor v_ray_transforms, // [..., C, N, 3, 3]
const bool viewmats_requires_grad,
const bool Ks_requires_grad,
// outputs
at::Tensor v_means, // [..., N, 3]
at::Tensor v_quats, // [..., N, 4]
at::Tensor v_scales, // [..., N, 3]
at::Tensor v_viewmats // [..., C, 4, 4]
at::Tensor v_viewmats, // [..., C, 4, 4]
at::Tensor v_Ks // [..., C, 3, 3]
);

void launch_projection_2dgs_packed_fwd_kernel(
Expand Down Expand Up @@ -246,7 +250,8 @@ void launch_projection_2dgs_packed_bwd_kernel(
at::Tensor v_means, // [..., N, 3] or [nnz, 3]
at::Tensor v_quats, // [..., N, 4] or [nnz, 4]
at::Tensor v_scales, // [..., N, 3] or [nnz, 3]
at::optional<at::Tensor> v_viewmats // [..., C, 4, 4] Optional
at::optional<at::Tensor> v_viewmats, // [..., C, 4, 4] Optional
at::optional<at::Tensor> v_Ks // [..., C, 3, 3] Optional
);

void launch_projection_ut_3dgs_fused_kernel(
Expand Down
21 changes: 19 additions & 2 deletions gsplat/cuda/csrc/Projection2DGS.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,11 @@ inline __device__ void compute_ray_transforms_aabb_vjp(
vec2 &v_scale,
vec3 &v_mean,
mat3 &v_R,
vec3 &v_t
vec3 &v_t,
float &v_fx,
float &v_fy,
float &v_cx,
float &v_cy
) {
if (v_means2d[0] != 0 || v_means2d[1] != 0) {
const float distance = ray_transforms[6] * ray_transforms[6] +
Expand Down Expand Up @@ -81,12 +85,25 @@ inline __device__ void compute_ray_transforms_aabb_vjp(

v_R += glm::outerProduct(v_M[2], mean_w);

mat3 RS = quat_to_rotmat(quat) *
mat3 RS = R *
mat3(scale[0], 0.0, 0.0, 0.0, scale[1], 0.0, 0.0, 0.0, 1.0);
mat3 v_RS_cam = mat3(v_M[0], v_M[1], v_normals * multiplier);

v_R += v_RS_cam * glm::transpose(RS);
v_t += v_M[2];

// From forward pass: M = ray_transforms = (WH)^T * K^T
// where W is the camera rotation matrix R (not full projection like in paper)
// and WH represents Gaussian geometry in camera coordinates
// So ∂M/∂K^T = (WH)^T
// Therefore ∂L/∂K^T = (∂L/∂M)^T * (WH)^T = _v_ray_transforms^T * (WH)^T
mat3 RS_camera = W * RS;
mat3 WH = mat3(RS_camera[0], RS_camera[1], mean_c);
mat3 v_K_T = glm::transpose(WH * _v_ray_transforms);
v_fx += v_K_T[0][0];
v_fy += v_K_T[1][1];
v_cx += v_K_T[2][0];
v_cy += v_K_T[2][1];
}

} // namespace gsplat
38 changes: 34 additions & 4 deletions gsplat/cuda/csrc/Projection2DGSFused.cu
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,8 @@ __global__ void projection_2dgs_fused_bwd_kernel(
scalar_t *__restrict__ v_means, // [B, N, 3]
scalar_t *__restrict__ v_quats, // [B, N, 4]
scalar_t *__restrict__ v_scales, // [B, N, 3]
scalar_t *__restrict__ v_viewmats // [B, C, 4, 4]
scalar_t *__restrict__ v_viewmats, // [B, C, 4, 4]
scalar_t *__restrict__ v_Ks // [B, C, 3, 3]
) {
// parallelize over C * N.
uint32_t idx = cg::this_grid().thread_rank();
Expand Down Expand Up @@ -411,6 +412,12 @@ __global__ void projection_2dgs_fused_bwd_kernel(
vec4 v_quat(0.f);
mat3 v_R(0.f);
vec3 v_t(0.f);
float v_fx(0.f);
float v_fy(0.f);
float v_cx(0.f);
float v_cy(0.f);

// Compute the VJP; v_fx, v_fy, v_cx, v_cy also receive the intrinsics gradients
compute_ray_transforms_aabb_vjp(
ray_transforms,
v_means2d,
Expand All @@ -427,7 +434,11 @@ __global__ void projection_2dgs_fused_bwd_kernel(
v_scale,
v_mean,
v_R,
v_t
v_t,
v_fx,
v_fy,
v_cx,
v_cy
);

// #if __CUDA_ARCH__ >= 700
Expand Down Expand Up @@ -475,6 +486,22 @@ __global__ void projection_2dgs_fused_bwd_kernel(
}
}
}

// Write intrinsics gradients if needed
if (v_Ks != nullptr) {
auto warp_group_c = cg::labeled_partition(warp, cid);
warpSum(v_fx, warp_group_c);
warpSum(v_fy, warp_group_c);
warpSum(v_cx, warp_group_c);
warpSum(v_cy, warp_group_c);
if (warp_group_c.thread_rank() == 0) {
v_Ks += bid * C * 9 + cid * 9;
gpuAtomicAdd(v_Ks + 0, v_fx); // [0,0] = fx
gpuAtomicAdd(v_Ks + 4, v_fy); // [1,1] = fy
gpuAtomicAdd(v_Ks + 2, v_cx); // [0,2] = cx
gpuAtomicAdd(v_Ks + 5, v_cy); // [1,2] = cy
}
}
}

void launch_projection_2dgs_fused_bwd_kernel(
Expand All @@ -495,11 +522,13 @@ void launch_projection_2dgs_fused_bwd_kernel(
const at::Tensor v_normals, // [..., C, N, 3]
const at::Tensor v_ray_transforms, // [..., C, N, 3, 3]
const bool viewmats_requires_grad,
const bool Ks_requires_grad,
// outputs
at::Tensor v_means, // [..., N, 3]
at::Tensor v_quats, // [..., N, 4]
at::Tensor v_scales, // [..., N, 3]
at::Tensor v_viewmats // [..., C, 4, 4]
at::Tensor v_viewmats, // [..., C, 4, 4]
at::Tensor v_Ks // [..., C, 3, 3]
) {
uint32_t N = means.size(-2); // number of gaussians
uint32_t B = means.numel() / (N * 3); // number of batches
Expand Down Expand Up @@ -536,7 +565,8 @@ void launch_projection_2dgs_fused_bwd_kernel(
v_means.data_ptr<float>(),
v_quats.data_ptr<float>(),
v_scales.data_ptr<float>(),
viewmats_requires_grad ? v_viewmats.data_ptr<float>() : nullptr
viewmats_requires_grad ? v_viewmats.data_ptr<float>() : nullptr,
Ks_requires_grad ? v_Ks.data_ptr<float>() : nullptr
);
}

Expand Down
Loading
Loading