diff --git a/gsplat/cuda/_wrapper.py b/gsplat/cuda/_wrapper.py index 286685ff9..cd23238b6 100644 --- a/gsplat/cuda/_wrapper.py +++ b/gsplat/cuda/_wrapper.py @@ -1108,7 +1108,7 @@ def backward(ctx, v_radii, v_means2d, v_depths, v_conics, v_compensations): camera_model_type = ctx.camera_model_type if v_compensations is not None: v_compensations = v_compensations.contiguous() - v_means, v_covars, v_quats, v_scales, v_viewmats = _make_lazy_cuda_func( + v_means, v_covars, v_quats, v_scales, v_viewmats, v_Ks = _make_lazy_cuda_func( "projection_ewa_3dgs_fused_bwd" )( means, @@ -1129,6 +1129,7 @@ def backward(ctx, v_radii, v_means2d, v_depths, v_conics, v_compensations): v_conics.contiguous(), v_compensations, ctx.needs_input_grad[4], # viewmats_requires_grad + ctx.needs_input_grad[5], # Ks_requires_grad ) if not ctx.needs_input_grad[0]: v_means = None @@ -1140,12 +1141,15 @@ def backward(ctx, v_radii, v_means2d, v_depths, v_conics, v_compensations): v_scales = None if not ctx.needs_input_grad[4]: v_viewmats = None + if not ctx.needs_input_grad[5]: + v_Ks = None return ( v_means, v_covars, v_quats, v_scales, v_viewmats, + v_Ks, None, None, None, @@ -2002,7 +2006,7 @@ def backward(ctx, v_radii, v_means2d, v_depths, v_ray_transforms, v_normals): width = ctx.width height = ctx.height eps2d = ctx.eps2d - v_means, v_quats, v_scales, v_viewmats = _make_lazy_cuda_func( + v_means, v_quats, v_scales, v_viewmats, v_Ks = _make_lazy_cuda_func( "projection_2dgs_fused_bwd" )( means, @@ -2019,6 +2023,7 @@ def backward(ctx, v_radii, v_means2d, v_depths, v_ray_transforms, v_normals): v_normals.contiguous(), v_ray_transforms.contiguous(), ctx.needs_input_grad[3], # viewmats_requires_grad + ctx.needs_input_grad[4], # Ks_requires_grad ) if not ctx.needs_input_grad[0]: v_means = None @@ -2028,12 +2033,15 @@ def backward(ctx, v_radii, v_means2d, v_depths, v_ray_transforms, v_normals): v_scales = None if not ctx.needs_input_grad[3]: v_viewmats = None + if not ctx.needs_input_grad[4]: + v_Ks = None return ( v_means, v_quats, v_scales, v_viewmats, + v_Ks, None, None, None, diff --git a/gsplat/cuda/csrc/Projection.cpp b/gsplat/cuda/csrc/Projection.cpp index d319b7a13..93ed8de5c 100644 --- a/gsplat/cuda/csrc/Projection.cpp +++ b/gsplat/cuda/csrc/Projection.cpp @@ -188,7 +188,7 @@ projection_ewa_3dgs_fused_fwd( return std::make_tuple(radii, means2d, depths, conics, compensations); } -std::tuple +std::tuple projection_ewa_3dgs_fused_bwd( // fwd inputs const at::Tensor means, // [..., N, 3] @@ -210,7 +210,8 @@ projection_ewa_3dgs_fused_bwd( const at::Tensor v_depths, // [..., C, N] const at::Tensor v_conics, // [..., C, N, 3] const at::optional v_compensations, // [..., C, N] optional - const bool viewmats_requires_grad + const bool viewmats_requires_grad, + const bool Ks_requires_grad ) { DEVICE_GUARD(means); CHECK_INPUT(means); @@ -248,6 +249,10 @@ projection_ewa_3dgs_fused_bwd( if (viewmats_requires_grad) { v_viewmats = at::zeros_like(viewmats); } + at::Tensor v_Ks; + if (Ks_requires_grad) { + v_Ks = at::zeros_like(Ks); + } launch_projection_ewa_3dgs_fused_bwd_kernel( // inputs @@ -269,15 +274,17 @@ projection_ewa_3dgs_fused_bwd( v_conics, v_compensations, viewmats_requires_grad, + Ks_requires_grad, // outputs v_means, v_covars, v_quats, v_scales, - v_viewmats + v_viewmats, + v_Ks ); - return std::make_tuple(v_means, v_covars, v_quats, v_scales, v_viewmats); + return std::make_tuple(v_means, v_covars, v_quats, v_scales, v_viewmats, v_Ks); } std::tuple< @@ -619,7 +626,7 @@ projection_2dgs_fused_fwd( return 
std::make_tuple(radii, means2d, depths, ray_transforms, normals); } -std::tuple +std::tuple projection_2dgs_fused_bwd( // fwd inputs const at::Tensor means, // [..., N, 3] @@ -637,7 +644,8 @@ projection_2dgs_fused_bwd( const at::Tensor v_depths, // [..., C, N] const at::Tensor v_normals, // [..., C, N, 3] const at::Tensor v_ray_transforms, // [..., C, N, 3, 3] - const bool viewmats_requires_grad + const bool viewmats_requires_grad, + const bool Ks_requires_grad ) { DEVICE_GUARD(means); CHECK_INPUT(means); @@ -659,6 +667,10 @@ projection_2dgs_fused_bwd( if (viewmats_requires_grad) { v_viewmats = at::zeros_like(viewmats); } + at::Tensor v_Ks; + if (Ks_requires_grad) { + v_Ks = at::zeros_like(Ks); + } launch_projection_2dgs_fused_bwd_kernel( // inputs @@ -676,14 +688,16 @@ projection_2dgs_fused_bwd( v_normals, v_ray_transforms, viewmats_requires_grad, + Ks_requires_grad, // outputs v_means, v_quats, v_scales, - v_viewmats + v_viewmats, + v_Ks ); - return std::make_tuple(v_means, v_quats, v_scales, v_viewmats); + return std::make_tuple(v_means, v_quats, v_scales, v_viewmats, v_Ks); } std::tuple< @@ -815,7 +829,7 @@ projection_2dgs_packed_fwd( ); } -std::tuple +std::tuple projection_2dgs_packed_bwd( // fwd inputs const at::Tensor means, // [..., N, 3] @@ -836,6 +850,7 @@ projection_2dgs_packed_bwd( const at::Tensor v_ray_transforms, // [nnz, 3, 3] const at::Tensor v_normals, // [nnz, 3] const bool viewmats_requires_grad, + const bool Ks_requires_grad, const bool sparse_grad ) { DEVICE_GUARD(means); @@ -859,7 +874,7 @@ projection_2dgs_packed_bwd( uint32_t C = viewmats.size(-3); // number of cameras uint32_t nnz = batch_ids.size(0); - at::Tensor v_means, v_quats, v_scales, v_viewmats; + at::Tensor v_means, v_quats, v_scales, v_viewmats, v_Ks; if (sparse_grad) { v_means = at::zeros({nnz, 3}, opt); v_quats = at::zeros({nnz, 4}, opt); @@ -872,6 +887,9 @@ projection_2dgs_packed_bwd( if (viewmats_requires_grad) { v_viewmats = at::zeros_like(viewmats, opt); } + if (Ks_requires_grad) { + v_Ks = at::zeros_like(Ks, opt); + } launch_projection_2dgs_packed_bwd_kernel( // fwd inputs @@ -898,9 +916,11 @@ projection_2dgs_packed_bwd( v_quats, v_scales, v_viewmats.defined() ? at::optional(v_viewmats) - : c10::nullopt + : c10::nullopt, + v_Ks.defined() ? 
at::optional(v_Ks) + : c10::nullopt ); - return std::make_tuple(v_means, v_quats, v_scales, v_viewmats); + return std::make_tuple(v_means, v_quats, v_scales, v_viewmats, v_Ks); } std::tuple< diff --git a/gsplat/cuda/csrc/Projection.h b/gsplat/cuda/csrc/Projection.h index a4b388b03..48464c910 100644 --- a/gsplat/cuda/csrc/Projection.h +++ b/gsplat/cuda/csrc/Projection.h @@ -82,12 +82,14 @@ void launch_projection_ewa_3dgs_fused_bwd_kernel( const at::Tensor v_conics, // [..., C, N, 3] const at::optional v_compensations, // [..., C, N] optional const bool viewmats_requires_grad, + const bool Ks_requires_grad, // outputs at::Tensor v_means, // [..., N, 3] at::Tensor v_covars, // [..., N, 3, 3] at::Tensor v_quats, // [..., N, 4] at::Tensor v_scales, // [..., N, 3] - at::Tensor v_viewmats // [..., C, 4, 4] + at::Tensor v_viewmats, // [..., C, 4, 4] + at::Tensor v_Ks // [..., C, 3, 3] ); void launch_projection_ewa_3dgs_packed_fwd_kernel( @@ -189,11 +191,13 @@ void launch_projection_2dgs_fused_bwd_kernel( const at::Tensor v_normals, // [..., C, N, 3] const at::Tensor v_ray_transforms, // [..., C, N, 3, 3] const bool viewmats_requires_grad, + const bool Ks_requires_grad, // outputs at::Tensor v_means, // [..., N, 3] at::Tensor v_quats, // [..., N, 4] at::Tensor v_scales, // [..., N, 3] - at::Tensor v_viewmats // [..., C, 4, 4] + at::Tensor v_viewmats, // [..., C, 4, 4] + at::Tensor v_Ks // [..., C, 3, 3] ); void launch_projection_2dgs_packed_fwd_kernel( @@ -246,7 +250,8 @@ void launch_projection_2dgs_packed_bwd_kernel( at::Tensor v_means, // [..., N, 3] or [nnz, 3] at::Tensor v_quats, // [..., N, 4] or [nnz, 4] at::Tensor v_scales, // [..., N, 3] or [nnz, 3] - at::optional v_viewmats // [..., C, 4, 4] Optional + at::optional v_viewmats, // [..., C, 4, 4] Optional + at::optional v_Ks // [..., C, 3, 3] Optional ); void launch_projection_ut_3dgs_fused_kernel( diff --git a/gsplat/cuda/csrc/Projection2DGS.cuh b/gsplat/cuda/csrc/Projection2DGS.cuh index d26a63971..fb04d7464 100644 --- a/gsplat/cuda/csrc/Projection2DGS.cuh +++ b/gsplat/cuda/csrc/Projection2DGS.cuh @@ -23,7 +23,11 @@ inline __device__ void compute_ray_transforms_aabb_vjp( vec2 &v_scale, vec3 &v_mean, mat3 &v_R, - vec3 &v_t + vec3 &v_t, + float &v_fx, + float &v_fy, + float &v_cx, + float &v_cy ) { if (v_means2d[0] != 0 || v_means2d[1] != 0) { const float distance = ray_transforms[6] * ray_transforms[6] + @@ -81,12 +85,25 @@ inline __device__ void compute_ray_transforms_aabb_vjp( v_R += glm::outerProduct(v_M[2], mean_w); - mat3 RS = quat_to_rotmat(quat) * + mat3 RS = R * mat3(scale[0], 0.0, 0.0, 0.0, scale[1], 0.0, 0.0, 0.0, 1.0); mat3 v_RS_cam = mat3(v_M[0], v_M[1], v_normals * multiplier); v_R += v_RS_cam * glm::transpose(RS); v_t += v_M[2]; + + // From forward pass: M = ray_transforms = (WH)^T * K^T + // where W is the camera rotation matrix R (not full projection like in paper) + // and WH represents Gaussian geometry in camera coordinates + // So ∂M/∂K^T = (WH)^T + // Therefore ∂L/∂K^T = (∂L/∂M)^T * (WH)^T = _v_ray_transforms^T * (WH)^T + mat3 RS_camera = W * RS; + mat3 WH = mat3(RS_camera[0], RS_camera[1], mean_c); + mat3 v_K_T = glm::transpose(WH * _v_ray_transforms); + v_fx += v_K_T[0][0]; + v_fy += v_K_T[1][1]; + v_cx += v_K_T[2][0]; + v_cy += v_K_T[2][1]; } } // namespace gsplat \ No newline at end of file diff --git a/gsplat/cuda/csrc/Projection2DGSFused.cu b/gsplat/cuda/csrc/Projection2DGSFused.cu index 9d6e1c551..b45671adc 100644 --- a/gsplat/cuda/csrc/Projection2DGSFused.cu +++ b/gsplat/cuda/csrc/Projection2DGSFused.cu @@ 
-345,7 +345,8 @@ __global__ void projection_2dgs_fused_bwd_kernel( scalar_t *__restrict__ v_means, // [B, N, 3] scalar_t *__restrict__ v_quats, // [B, N, 4] scalar_t *__restrict__ v_scales, // [B, N, 3] - scalar_t *__restrict__ v_viewmats // [B, C, 4, 4] + scalar_t *__restrict__ v_viewmats, // [B, C, 4, 4] + scalar_t *__restrict__ v_Ks // [B, C, 3, 3] ) { // parallelize over C * N. uint32_t idx = cg::this_grid().thread_rank(); @@ -411,6 +412,12 @@ __global__ void projection_2dgs_fused_bwd_kernel( vec4 v_quat(0.f); mat3 v_R(0.f); vec3 v_t(0.f); + float v_fx(0.f); + float v_fy(0.f); + float v_cx(0.f); + float v_cy(0.f); + + // Use the overloaded version that computes intrinsics gradients compute_ray_transforms_aabb_vjp( ray_transforms, v_means2d, @@ -427,7 +434,11 @@ __global__ void projection_2dgs_fused_bwd_kernel( v_scale, v_mean, v_R, - v_t + v_t, + v_fx, + v_fy, + v_cx, + v_cy ); // #if __CUDA_ARCH__ >= 700 @@ -475,6 +486,22 @@ __global__ void projection_2dgs_fused_bwd_kernel( } } } + + // Write intrinsics gradients if needed + if (v_Ks != nullptr) { + auto warp_group_c = cg::labeled_partition(warp, cid); + warpSum(v_fx, warp_group_c); + warpSum(v_fy, warp_group_c); + warpSum(v_cx, warp_group_c); + warpSum(v_cy, warp_group_c); + if (warp_group_c.thread_rank() == 0) { + v_Ks += bid * C * 9 + cid * 9; + gpuAtomicAdd(v_Ks + 0, v_fx); // [0,0] = fx + gpuAtomicAdd(v_Ks + 4, v_fy); // [1,1] = fy + gpuAtomicAdd(v_Ks + 2, v_cx); // [0,2] = cx + gpuAtomicAdd(v_Ks + 5, v_cy); // [1,2] = cy + } + } } void launch_projection_2dgs_fused_bwd_kernel( @@ -495,11 +522,13 @@ void launch_projection_2dgs_fused_bwd_kernel( const at::Tensor v_normals, // [..., C, N, 3] const at::Tensor v_ray_transforms, // [..., C, N, 3, 3] const bool viewmats_requires_grad, + const bool Ks_requires_grad, // outputs at::Tensor v_means, // [..., N, 3] at::Tensor v_quats, // [..., N, 4] at::Tensor v_scales, // [..., N, 3] - at::Tensor v_viewmats // [..., C, 4, 4] + at::Tensor v_viewmats, // [..., C, 4, 4] + at::Tensor v_Ks // [..., C, 3, 3] ) { uint32_t N = means.size(-2); // number of gaussians uint32_t B = means.numel() / (N * 3); // number of batches @@ -536,7 +565,8 @@ void launch_projection_2dgs_fused_bwd_kernel( v_means.data_ptr(), v_quats.data_ptr(), v_scales.data_ptr(), - viewmats_requires_grad ? v_viewmats.data_ptr() : nullptr + viewmats_requires_grad ? v_viewmats.data_ptr() : nullptr, + Ks_requires_grad ? v_Ks.data_ptr() : nullptr ); } diff --git a/gsplat/cuda/csrc/Projection2DGSPacked.cu b/gsplat/cuda/csrc/Projection2DGSPacked.cu index 4667dc770..af73508b9 100644 --- a/gsplat/cuda/csrc/Projection2DGSPacked.cu +++ b/gsplat/cuda/csrc/Projection2DGSPacked.cu @@ -308,7 +308,8 @@ __global__ void projection_2dgs_packed_bwd_kernel( scalar_t *__restrict__ v_means, // [B, N, 3] or [nnz, 3] scalar_t *__restrict__ v_quats, // [B, N, 4] or [nnz, 4] Optional scalar_t *__restrict__ v_scales, // [B, N, 3] or [nnz, 3] Optional - scalar_t *__restrict__ v_viewmats // [B, C, 4, 4] Optional + scalar_t *__restrict__ v_viewmats, // [B, C, 4, 4] Optional + scalar_t *__restrict__ v_Ks // [B, C, 3, 3] Optional ) { // parallelize over nnz. 
uint32_t idx = cg::this_grid().thread_rank(); @@ -373,6 +374,7 @@ __global__ void projection_2dgs_packed_bwd_kernel( vec4 v_quat(0.f); mat3 v_R(0.f); vec3 v_t(0.f); + float v_fx(0.f), v_fy(0.f), v_cx(0.f), v_cy(0.f); compute_ray_transforms_aabb_vjp( ray_transforms, v_means2d, @@ -389,7 +391,11 @@ __global__ void projection_2dgs_packed_bwd_kernel( v_scale, v_mean, v_R, - v_t + v_t, + v_fx, + v_fy, + v_cx, + v_cy ); auto warp = cg::tiled_partition<32>(cg::this_thread_block()); @@ -456,6 +462,24 @@ __global__ void projection_2dgs_packed_bwd_kernel( } } } + + if (v_Ks != nullptr) { + auto warp_group_c = cg::labeled_partition(warp, cid); + // Sum intrinsic gradients across warp for the same camera + float warp_v_fx = v_fx, warp_v_fy = v_fy, warp_v_cx = v_cx, warp_v_cy = v_cy; + warpSum(warp_v_fx, warp_group_c); + warpSum(warp_v_fy, warp_group_c); + warpSum(warp_v_cx, warp_group_c); + warpSum(warp_v_cy, warp_group_c); + if (warp_group_c.thread_rank() == 0) { + v_Ks += bid * C * 9 + cid * 9; + // Accumulate gradients into the Ks matrix [fx, 0, cx; 0, fy, cy; 0, 0, 1] + gpuAtomicAdd(v_Ks + 0, warp_v_fx); // [0,0] = fx + gpuAtomicAdd(v_Ks + 2, warp_v_cx); // [0,2] = cx + gpuAtomicAdd(v_Ks + 4, warp_v_fy); // [1,1] = fy + gpuAtomicAdd(v_Ks + 5, warp_v_cy); // [1,2] = cy + } + } } void launch_projection_2dgs_packed_bwd_kernel( @@ -482,7 +506,8 @@ void launch_projection_2dgs_packed_bwd_kernel( at::Tensor v_means, // [..., N, 3] or [nnz, 3] at::Tensor v_quats, // [..., N, 4] or [nnz, 4] at::Tensor v_scales, // [..., N, 3] or [nnz, 3] - at::optional v_viewmats // [..., C, 4, 4] Optional + at::optional v_viewmats, // [..., C, 4, 4] Optional + at::optional v_Ks // [..., C, 3, 3] Optional ) { uint32_t N = means.size(-2); // number of gaussians uint32_t B = means.numel() / (N * 3); // number of batches @@ -524,7 +549,8 @@ void launch_projection_2dgs_packed_bwd_kernel( v_quats.data_ptr(), v_scales.data_ptr(), v_viewmats.has_value() ? v_viewmats.value().data_ptr() - : nullptr + : nullptr, + v_Ks.has_value() ? v_Ks.value().data_ptr() : nullptr ); } diff --git a/gsplat/cuda/csrc/ProjectionEWA3DGSFused.cu b/gsplat/cuda/csrc/ProjectionEWA3DGSFused.cu index 09997d04a..4cde3e329 100644 --- a/gsplat/cuda/csrc/ProjectionEWA3DGSFused.cu +++ b/gsplat/cuda/csrc/ProjectionEWA3DGSFused.cu @@ -320,7 +320,8 @@ __global__ void projection_ewa_3dgs_fused_bwd_kernel( scalar_t *__restrict__ v_covars, // [B, N, 6] optional scalar_t *__restrict__ v_quats, // [B, N, 4] optional scalar_t *__restrict__ v_scales, // [B, N, 3] optional - scalar_t *__restrict__ v_viewmats // [B, C, 4, 4] optional + scalar_t *__restrict__ v_viewmats, // [B, C, 4, 4] optional + scalar_t *__restrict__ v_Ks // [B, C, 3, 3] optional ) { // parallelize over B * C * N. 
uint32_t idx = cg::this_grid().thread_rank(); @@ -403,55 +404,119 @@ __global__ void projection_ewa_3dgs_fused_bwd_kernel( float fx = Ks[0], cx = Ks[2], fy = Ks[4], cy = Ks[5]; mat3 v_covar_c(0.f); vec3 v_mean_c(0.f); + float v_fx(0.f), v_fy(0.f), v_cx(0.f), v_cy(0.f); switch (camera_model) { case CameraModelType::PINHOLE: // perspective projection - persp_proj_vjp( - mean_c, - covar_c, - fx, - fy, - cx, - cy, - image_width, - image_height, - v_covar2d, - glm::make_vec2(v_means2d), - v_mean_c, - v_covar_c - ); + if (v_Ks != nullptr) { + persp_proj_vjp( + mean_c, + covar_c, + fx, + fy, + cx, + cy, + image_width, + image_height, + v_covar2d, + glm::make_vec2(v_means2d), + v_mean_c, + v_covar_c, + v_fx, + v_fy, + v_cx, + v_cy + ); + } else { + persp_proj_vjp( + mean_c, + covar_c, + fx, + fy, + cx, + cy, + image_width, + image_height, + v_covar2d, + glm::make_vec2(v_means2d), + v_mean_c, + v_covar_c + ); + } break; case CameraModelType::ORTHO: // orthographic projection - ortho_proj_vjp( - mean_c, - covar_c, - fx, - fy, - cx, - cy, - image_width, - image_height, - v_covar2d, - glm::make_vec2(v_means2d), - v_mean_c, - v_covar_c - ); + if (v_Ks != nullptr) { + ortho_proj_vjp( + mean_c, + covar_c, + fx, + fy, + cx, + cy, + image_width, + image_height, + v_covar2d, + glm::make_vec2(v_means2d), + v_mean_c, + v_covar_c, + v_fx, + v_fy, + v_cx, + v_cy + ); + } else { + ortho_proj_vjp( + mean_c, + covar_c, + fx, + fy, + cx, + cy, + image_width, + image_height, + v_covar2d, + glm::make_vec2(v_means2d), + v_mean_c, + v_covar_c + ); + } break; case CameraModelType::FISHEYE: // fisheye projection - fisheye_proj_vjp( - mean_c, - covar_c, - fx, - fy, - cx, - cy, - image_width, - image_height, - v_covar2d, - glm::make_vec2(v_means2d), - v_mean_c, - v_covar_c - ); + if (v_Ks != nullptr) { + fisheye_proj_vjp( + mean_c, + covar_c, + fx, + fy, + cx, + cy, + image_width, + image_height, + v_covar2d, + glm::make_vec2(v_means2d), + v_mean_c, + v_covar_c, + v_fx, + v_fy, + v_cx, + v_cy + ); + } else { + fisheye_proj_vjp( + mean_c, + covar_c, + fx, + fy, + cx, + cy, + image_width, + image_height, + v_covar2d, + glm::make_vec2(v_means2d), + v_mean_c, + v_covar_c + ); + } break; } @@ -528,6 +593,21 @@ __global__ void projection_ewa_3dgs_fused_bwd_kernel( } } } + if (v_Ks != nullptr) { + // Only one thread per camera accumulates intrinsics gradients + auto warp_group_c = cg::labeled_partition(warp, cid); + warpSum(v_fx, warp_group_c); + warpSum(v_fy, warp_group_c); + warpSum(v_cx, warp_group_c); + warpSum(v_cy, warp_group_c); + if (warp_group_c.thread_rank() == 0) { + v_Ks += bid * C * 9 + cid * 9; + gpuAtomicAdd(v_Ks + 0, v_fx); // fx + gpuAtomicAdd(v_Ks + 2, v_cx); // cx + gpuAtomicAdd(v_Ks + 4, v_fy); // fy + gpuAtomicAdd(v_Ks + 5, v_cy); // cy + } + } } void launch_projection_ewa_3dgs_fused_bwd_kernel( @@ -553,12 +633,14 @@ void launch_projection_ewa_3dgs_fused_bwd_kernel( const at::Tensor v_conics, // [..., C, N, 3] const at::optional v_compensations, // [..., C, N] optional const bool viewmats_requires_grad, + const bool Ks_requires_grad, // outputs at::Tensor v_means, // [..., N, 3] at::Tensor v_covars, // [..., N, 3, 3] at::Tensor v_quats, // [..., N, 4] at::Tensor v_scales, // [..., N, 3] - at::Tensor v_viewmats // [..., C, 4, 4] + at::Tensor v_viewmats, // [..., C, 4, 4] + at::Tensor v_Ks // [..., C, 3, 3] ) { uint32_t N = means.size(-2); // number of gaussians uint32_t C = viewmats.size(-3); // number of cameras @@ -617,7 +699,9 @@ void launch_projection_ewa_3dgs_fused_bwd_kernel( covars.has_value() ? 
nullptr : v_scales.data_ptr(), viewmats_requires_grad ? v_viewmats.data_ptr() - : nullptr + : nullptr, + Ks_requires_grad ? v_Ks.data_ptr() + : nullptr ); } ); diff --git a/gsplat/cuda/include/Ops.h b/gsplat/cuda/include/Ops.h index 3638a13c1..26b9a90e0 100644 --- a/gsplat/cuda/include/Ops.h +++ b/gsplat/cuda/include/Ops.h @@ -62,7 +62,7 @@ projection_ewa_3dgs_fused_fwd( const bool calc_compensations, const CameraModelType camera_model ); -std::tuple +std::tuple projection_ewa_3dgs_fused_bwd( // fwd inputs const at::Tensor means, // [..., N, 3] @@ -84,7 +84,8 @@ projection_ewa_3dgs_fused_bwd( const at::Tensor v_depths, // [..., C, N] const at::Tensor v_conics, // [..., C, N, 3] const at::optional v_compensations, // [..., C, N] optional - const bool viewmats_requires_grad + const bool viewmats_requires_grad, + const bool Ks_requires_grad ); // On top of fusing the operations like `projection_ewa_3dgs_fused_{fwd, bwd}`, @@ -310,7 +311,7 @@ projection_2dgs_fused_fwd( const float far_plane, const float radius_clip ); -std::tuple +std::tuple projection_2dgs_fused_bwd( // fwd inputs const at::Tensor means, // [..., N, 3] @@ -328,7 +329,8 @@ projection_2dgs_fused_bwd( const at::Tensor v_depths, // [..., C, N] const at::Tensor v_normals, // [..., C, N, 3] const at::Tensor v_ray_transforms, // [..., C, N, 3, 3] - const bool viewmats_requires_grad + const bool viewmats_requires_grad, + const bool Ks_requires_grad ); std::tuple< @@ -353,7 +355,7 @@ projection_2dgs_packed_fwd( const float far_plane, const float radius_clip ); -std::tuple +std::tuple projection_2dgs_packed_bwd( // fwd inputs const at::Tensor means, // [..., N, 3] @@ -374,6 +376,7 @@ projection_2dgs_packed_bwd( const at::Tensor v_ray_transforms, // [nnz, 3, 3] const at::Tensor v_normals, // [nnz, 3] const bool viewmats_requires_grad, + const bool Ks_requires_grad, const bool sparse_grad ); diff --git a/gsplat/cuda/include/Utils.cuh b/gsplat/cuda/include/Utils.cuh index 2809d93f0..992f13ba4 100644 --- a/gsplat/cuda/include/Utils.cuh +++ b/gsplat/cuda/include/Utils.cuh @@ -495,6 +495,64 @@ inline __device__ void ortho_proj_vjp( v_mean3d += vec3(fx * v_mean2d[0], fy * v_mean2d[1], 0.f); } +// Overloaded version that also computes gradients w.r.t. intrinsics +inline __device__ void ortho_proj_vjp( + // fwd inputs + const vec3 mean3d, + const mat3 cov3d, + const float fx, + const float fy, + const float cx, + const float cy, + const uint32_t width, + const uint32_t height, + // grad outputs + const mat2 v_cov2d, + const vec2 v_mean2d, + // grad inputs + vec3 &v_mean3d, + mat3 &v_cov3d, + float &v_fx, + float &v_fy, + float &v_cx, + float &v_cy +) { + float x = mean3d[0], y = mean3d[1], z = mean3d[2]; + + // mat3x2 is 3 columns x 2 rows. + mat3x2 J = mat3x2( + fx, + 0.f, // 1st column + 0.f, + fy, // 2nd column + 0.f, + 0.f // 3rd column + ); + + // cov = J * V * Jt; G = df/dcov = v_cov + // -> df/dV = Jt * G * J + // -> df/dJ = G * J * Vt + Gt * J * V + v_cov3d += glm::transpose(J) * v_cov2d * J; + + // Gradients w.r.t. mean + v_mean3d += vec3(fx * v_mean2d[0], fy * v_mean2d[1], 0.f); + + // Gradients w.r.t. intrinsics from mean2d = [fx * x + cx, fy * y + cy] + v_fx += x * v_mean2d[0]; // d(mean2d.x)/dfx = x + v_fy += y * v_mean2d[1]; // d(mean2d.y)/dfy = y + v_cx += v_mean2d[0]; // d(mean2d.x)/dcx = 1 + v_cy += v_mean2d[1]; // d(mean2d.y)/dcy = 1 + + // Gradients from covariance through Jacobian + mat3x2 v_J = v_cov2d * J * glm::transpose(cov3d) + + glm::transpose(v_cov2d) * J * cov3d; + + // Gradients w.r.t. 
intrinsics from Jacobian elements + // J = [fx, 0; 0, fy; 0, 0] + v_fx += v_J[0][0]; + v_fy += v_J[1][1]; +} + inline __device__ void persp_proj( // inputs const vec3 mean3d, @@ -615,6 +673,100 @@ inline __device__ void persp_proj_vjp( 2.f * fy * ty * rz3 * v_J[2][1]; } +// Overloaded version that also computes gradients w.r.t. intrinsics +inline __device__ void persp_proj_vjp( + // fwd inputs + const vec3 mean3d, + const mat3 cov3d, + const float fx, + const float fy, + const float cx, + const float cy, + const uint32_t width, + const uint32_t height, + // grad outputs + const mat2 v_cov2d, + const vec2 v_mean2d, + // grad inputs + vec3 &v_mean3d, + mat3 &v_cov3d, + float &v_fx, + float &v_fy, + float &v_cx, + float &v_cy +) { + float x = mean3d[0], y = mean3d[1], z = mean3d[2]; + + float tan_fovx = 0.5f * width / fx; + float tan_fovy = 0.5f * height / fy; + float lim_x_pos = (width - cx) / fx + 0.3f * tan_fovx; + float lim_x_neg = cx / fx + 0.3f * tan_fovx; + float lim_y_pos = (height - cy) / fy + 0.3f * tan_fovy; + float lim_y_neg = cy / fy + 0.3f * tan_fovy; + + float rz = 1.f / z; + float rz2 = rz * rz; + float tx = z * min(lim_x_pos, max(-lim_x_neg, x * rz)); + float ty = z * min(lim_y_pos, max(-lim_y_neg, y * rz)); + + // mat3x2 is 3 columns x 2 rows. + mat3x2 J = mat3x2( + fx * rz, + 0.f, // 1st column + 0.f, + fy * rz, // 2nd column + -fx * tx * rz2, + -fy * ty * rz2 // 3rd column + ); + + // cov = J * V * Jt; G = df/dcov = v_cov + // -> df/dV = Jt * G * J + // -> df/dJ = G * J * Vt + Gt * J * V + v_cov3d += glm::transpose(J) * v_cov2d * J; + + // Gradients w.r.t. mean + v_mean3d += vec3( + fx * rz * v_mean2d[0], + fy * rz * v_mean2d[1], + -(fx * x * v_mean2d[0] + fy * y * v_mean2d[1]) * rz2 + ); + + // Gradients w.r.t. intrinsics from mean2d = [fx * x * rz + cx, fy * y * rz + cy] + v_fx += x * rz * v_mean2d[0]; // d(mean2d.x)/dfx = x * rz + v_fy += y * rz * v_mean2d[1]; // d(mean2d.y)/dfy = y * rz + v_cx += v_mean2d[0]; // d(mean2d.x)/dcx = 1 + v_cy += v_mean2d[1]; // d(mean2d.y)/dcy = 1 + + // Gradients from covariance through Jacobian + float rz3 = rz2 * rz; + mat3x2 v_J = v_cov2d * J * glm::transpose(cov3d) + + glm::transpose(v_cov2d) * J * cov3d; + + // Gradients w.r.t. intrinsics from Jacobian elements + // J = [fx*rz, 0; 0, fy*rz; -fx*tx*rz2, -fy*ty*rz2] + v_fx += rz * v_J[0][0] - tx * rz2 * v_J[2][0]; + v_fy += rz * v_J[1][1] - ty * rz2 * v_J[2][1]; + + // Additional gradients from fov clipping limits + // The clipping limits depend on fx, fy, cx, cy but for simplicity + // we'll ignore these higher-order effects for now + + // fov clipping + if (x * rz <= lim_x_pos && x * rz >= -lim_x_neg) { + v_mean3d.x += -fx * rz2 * v_J[2][0]; + } else { + v_mean3d.z += -fx * rz3 * v_J[2][0] * tx; + } + if (y * rz <= lim_y_pos && y * rz >= -lim_y_neg) { + v_mean3d.y += -fy * rz2 * v_J[2][1]; + } else { + v_mean3d.z += -fy * rz3 * v_J[2][1] * ty; + } + v_mean3d.z += -fx * rz2 * v_J[0][0] - fy * rz2 * v_J[1][1] + + 2.f * fx * tx * rz3 * v_J[2][0] + + 2.f * fy * ty * rz3 * v_J[2][1]; +} + inline __device__ void fisheye_proj( // inputs const vec3 mean3d, @@ -756,6 +908,41 @@ inline __device__ void fisheye_proj_vjp( v_mean3d.z += dL_dtz_raw; } +// Overloaded version that also computes gradients w.r.t. 
intrinsics +// Note: For fisheye projection, intrinsics gradients are complex and not implemented yet +inline __device__ void fisheye_proj_vjp( + // fwd inputs + const vec3 mean3d, + const mat3 cov3d, + const float fx, + const float fy, + const float cx, + const float cy, + const uint32_t width, + const uint32_t height, + // grad outputs + const mat2 v_cov2d, + const vec2 v_mean2d, + // grad inputs + vec3 &v_mean3d, + mat3 &v_cov3d, + float &v_fx, + float &v_fy, + float &v_cx, + float &v_cy +) { + // For now, call the original version and set intrinsics gradients to zero + fisheye_proj_vjp(mean3d, cov3d, fx, fy, cx, cy, width, height, v_cov2d, v_mean2d, v_mean3d, v_cov3d); + + // TODO: Implement proper fisheye intrinsics gradients + // For now, only handle the simple parts + v_cx += v_mean2d[0]; // d(mean2d.x)/dcx = 1 + v_cy += v_mean2d[1]; // d(mean2d.y)/dcy = 1 + // fx, fy gradients are more complex for fisheye - set to 0 for now + v_fx += 0.0f; + v_fy += 0.0f; +} + inline __device__ vec3 safe_normalize(vec3 v) { const float l = v.x * v.x + v.y * v.y + v.z * v.z; return l > 0.0f ? (v * rsqrtf(l)) : v; diff --git a/test_intrinsics_optimization.py b/test_intrinsics_optimization.py new file mode 100644 index 000000000..dbfa3ec85 --- /dev/null +++ b/test_intrinsics_optimization.py @@ -0,0 +1,306 @@ +#!/usr/bin/env python3 + +import torch +import torch.nn.functional as F +import matplotlib.pyplot as plt +import numpy as np +from pathlib import Path +import os + +# Import gsplat functions +from gsplat import rasterization, rasterization_2dgs + +def create_random_scene(n_gaussians, device="cuda:0"): + """Create a random scene of 3D Gaussians.""" + torch.manual_seed(41) # For reproducibility + + # Random positions in a reasonable range + means = torch.randn(n_gaussians, 3, device=device) * 3.0 + means[:, 2] += 8.0 # Push gaussians further away from camera (z > 5) + + # Random orientations (quaternions) + quats = F.normalize(torch.randn(n_gaussians, 4, device=device), dim=-1) + + # Smaller, more reasonable scales + scales = torch.exp(torch.randn(n_gaussians, 3, device=device) * 0.3 - 2.0) # Much smaller scales + + # More varied colors (not all purple) + colors = torch.rand(n_gaussians, 3, device=device) + + # Random opacities (not too high) + opacities = torch.rand(n_gaussians, device=device) * 0.5 + 0.2 # Lower opacity range + + return { + 'means': means, + 'quats': quats, + 'scales': scales, + 'colors': colors, + 'opacities': opacities + } + +def create_camera_setup(device="cuda:0"): + """Create camera parameters.""" + # Ground truth intrinsics + fx_true, fy_true = 800.0, 800.0 + cx_true, cy_true = 320.0, 240.0 + + Ks_true = torch.tensor([[[fx_true, 0.0, cx_true], + [0.0, fy_true, cy_true], + [0.0, 0.0, 1.0]]], device=device) + + # Initial guess (smaller perturbation) + fx_init = fx_true + torch.randn(1).item() * 100 # ±100 pixel error + fy_init = fy_true + torch.randn(1).item() * 100 + cx_init = cx_true + torch.randn(1).item() * 50 # ±50 pixel error + cy_init = cy_true + torch.randn(1).item() * 50 + + Ks_init = torch.tensor([[[fx_init, 0.0, cx_init], + [0.0, fy_init, cy_init], + [0.0, 0.0, 1.0]]], device=device, requires_grad=True) + + # Simple camera pose (identity) + viewmats = torch.eye(4, device=device).unsqueeze(0) + + # Image dimensions + width, height = 640, 480 + + return { + 'Ks_true': Ks_true, + 'Ks_init': Ks_init, + 'viewmats': viewmats, + 'width': width, + 'height': height + } + +def render_gaussians(scene, camera, Ks, use_2dgs): + """Render 3D Gaussians to 2D image.""" + # Use the 
full rasterization pipeline + if use_2dgs: + render_colors = rasterization_2dgs( + means=scene['means'].unsqueeze(0), + quats=scene['quats'].unsqueeze(0), + scales=scene['scales'].unsqueeze(0), + opacities=scene['opacities'].unsqueeze(0), + colors=scene['colors'].unsqueeze(1).unsqueeze(0), + viewmats=camera['viewmats'].unsqueeze(0), + Ks=Ks.unsqueeze(0), + width=camera['width'], + height=camera['height'], + # near_plane=0.01, + # far_plane=100.0, + packed=False, # Use non-packed version which has intrinsics gradients + backgrounds=torch.zeros(1, 1, 3, device=Ks.device), # [1, 3] for single camera + sh_degree=0, + + )[0] + render_colors = render_colors[0] # Remove batch dimension + else: + render_colors = rasterization( + means=scene['means'], + quats=scene['quats'], + scales=scene['scales'], + opacities=scene['opacities'], + colors=scene['colors'], + viewmats=camera['viewmats'], + Ks=Ks, + width=camera['width'], + height=camera['height'], + packed=False, # Use non-packed version which has intrinsics gradients + backgrounds=torch.zeros(1, 3, device=Ks.device), # [1, 3] for single camera + camera_model="pinhole" + )[0] + return render_colors + +def optimize_intrinsics(scene, camera, n_iterations=200, lr=100.0, use_2dgs=None): + """Optimize camera intrinsics to match ground truth rendering.""" + + assert use_2dgs is not None, "use_2dgs must be specified (True or False)" + + # Create output directory + output_dir = Path("intrinsics_optimization_results") + output_dir.mkdir(exist_ok=True) + + # Render ground truth image + with torch.no_grad(): + gt_image = render_gaussians(scene, camera, camera['Ks_true'], use_2dgs=use_2dgs) + gt_image = gt_image.squeeze(0) # Remove batch dimension + + # Save ground truth + plt.figure(figsize=(12, 4)) + plt.subplot(1, 3, 1) + plt.imshow(gt_image.cpu().numpy()) + plt.title("Ground Truth") + plt.axis('off') + plt.tight_layout() + plt.savefig(output_dir / "00_ground_truth.png", dpi=150, bbox_inches='tight') + plt.close() + + # Optimization setup + Ks_opt = camera['Ks_init'].clone().detach().requires_grad_(True) + optimizer = torch.optim.Adam([Ks_opt], lr=lr) + # optimizer = torch.optim.SGD([Ks_opt], lr=lr*10, momentum=0.9) + + losses = [] + intrinsics_history = [] + + print("Starting intrinsics optimization...") + print(f"Ground truth: fx={camera['Ks_true'][0,0,0]:.1f}, fy={camera['Ks_true'][0,1,1]:.1f}, " + f"cx={camera['Ks_true'][0,0,2]:.1f}, cy={camera['Ks_true'][0,1,2]:.1f}") + print(f"Initial guess: fx={Ks_opt[0,0,0]:.1f}, fy={Ks_opt[0,1,1]:.1f}, " + f"cx={Ks_opt[0,0,2]:.1f}, cy={Ks_opt[0,1,2]:.1f}") + print() + + for iteration in range(n_iterations): + optimizer.zero_grad() + + # Render with current intrinsics + pred_image = render_gaussians(scene, camera, Ks_opt, use_2dgs=use_2dgs) + pred_image = pred_image.squeeze(0) + + # Compute loss (MSE between images) + loss = F.mse_loss(pred_image, gt_image) + + # Backward pass + loss.backward() + + # Debug: Check if gradients are being computed + if iteration == 0: + print(f"Ks gradients at iteration 0: {Ks_opt.grad}") + if Ks_opt.grad is None: + print("WARNING: No gradients computed for Ks!") + else: + print(f"Gradient magnitudes: fx={Ks_opt.grad[0,0,0]:.6f}, fy={Ks_opt.grad[0,1,1]:.6f}, " + f"cx={Ks_opt.grad[0,0,2]:.6f}, cy={Ks_opt.grad[0,1,2]:.6f}") + + optimizer.step() + + # Store metrics + losses.append(loss.item()) + intrinsics_history.append([ + Ks_opt[0,0,0].item(), # fx + Ks_opt[0,1,1].item(), # fy + Ks_opt[0,0,2].item(), # cx + Ks_opt[0,1,2].item() # cy + ]) + + # Print progress + if iteration % 20 == 0: + 
print(f"Iter {iteration:3d}: Loss={loss.item():.6f}, " + f"fx={Ks_opt[0,0,0]:.1f}, fy={Ks_opt[0,1,1]:.1f}, " + f"cx={Ks_opt[0,0,2]:.1f}, cy={Ks_opt[0,1,2]:.1f}") + + # Save intermediate results + if iteration % 40 == 0 or iteration == n_iterations - 1: + plt.figure(figsize=(15, 5)) + + # Ground truth + plt.subplot(1, 3, 1) + plt.imshow(gt_image.cpu().numpy()) + plt.title("Ground Truth") + plt.axis('off') + + # Current prediction + plt.subplot(1, 3, 2) + plt.imshow(pred_image.detach().cpu().numpy()) + plt.title(f"Prediction (Iter {iteration})") + plt.axis('off') + + # Difference + plt.subplot(1, 3, 3) + diff = torch.abs(pred_image - gt_image).detach().cpu().numpy() + plt.imshow(diff) + plt.title(f"Abs Difference (Loss={loss.item():.4f})") + plt.axis('off') + + plt.tight_layout() + plt.savefig(output_dir / f"iter_{iteration:03d}.png", dpi=150, bbox_inches='tight') + plt.close() + + # Plot optimization curves + plt.figure(figsize=(15, 4)) + + # Loss curve + plt.subplot(1, 3, 1) + plt.plot(losses) + plt.xlabel('Iteration') + plt.ylabel('MSE Loss') + plt.title('Optimization Progress') + plt.yscale('log') + plt.grid(True) + + # Intrinsics convergence + intrinsics_history = np.array(intrinsics_history) + gt_values = [camera['Ks_true'][0,0,0].item(), camera['Ks_true'][0,1,1].item(), + camera['Ks_true'][0,0,2].item(), camera['Ks_true'][0,1,2].item()] + + plt.subplot(1, 3, 2) + labels = ['fx', 'fy', 'cx', 'cy'] + for i, (label, gt_val) in enumerate(zip(labels, gt_values)): + plt.plot(intrinsics_history[:, i], label=f'{label} (pred)') + plt.axhline(y=gt_val, color=f'C{i}', linestyle='--', alpha=0.7, label=f'{label} (GT)') + plt.xlabel('Iteration') + plt.ylabel('Parameter Value') + plt.title('Intrinsics Convergence') + plt.legend() + plt.grid(True) + + # Final error + plt.subplot(1, 3, 3) + final_errors = np.abs(intrinsics_history[-1] - gt_values) + plt.bar(labels, final_errors) + plt.ylabel('Absolute Error') + plt.title('Final Parameter Errors') + plt.grid(True) + + plt.tight_layout() + plt.savefig(output_dir / "optimization_curves.png", dpi=150, bbox_inches='tight') + plt.close() + + # Final summary + print("\n" + "="*60) + print("OPTIMIZATION COMPLETE") + print("="*60) + print("Ground Truth:") + print(f" fx={camera['Ks_true'][0,0,0]:.2f}, fy={camera['Ks_true'][0,1,1]:.2f}") + print(f" cx={camera['Ks_true'][0,0,2]:.2f}, cy={camera['Ks_true'][0,1,2]:.2f}") + print("\nFinal Prediction:") + print(f" fx={Ks_opt[0,0,0]:.2f}, fy={Ks_opt[0,1,1]:.2f}") + print(f" cx={Ks_opt[0,0,2]:.2f}, cy={Ks_opt[0,1,2]:.2f}") + print("\nAbsolute Errors:") + print(f" fx: {abs(Ks_opt[0,0,0] - camera['Ks_true'][0,0,0]):.2f}") + print(f" fy: {abs(Ks_opt[0,1,1] - camera['Ks_true'][0,1,1]):.2f}") + print(f" cx: {abs(Ks_opt[0,0,2] - camera['Ks_true'][0,0,2]):.2f}") + print(f" cy: {abs(Ks_opt[0,1,2] - camera['Ks_true'][0,1,2]):.2f}") + print(f"\nFinal Loss: {losses[-1]:.8f}") + print(f"\nResults saved to: {output_dir.absolute()}") + +def main(): + import argparse + parser = argparse.ArgumentParser(description="Optimize camera intrinsics using 3D Gaussian splatting.") + parser.add_argument('--use_2dgs', action='store_true', help="Use 2DGS rasterization (default is standard rasterization).") + parser.add_argument('--use_3dgs', action='store_false', dest='use_2dgs', help="Use standard 3D rasterization.") + args = parser.parse_args() + use_2dgs = args.use_2dgs + print(f"Using {'2DGS' if use_2dgs else '3DGS'} rasterization for optimization.") + """Main function to run the intrinsics optimization test.""" + device = "cuda:0" if 
torch.cuda.is_available() else "cpu" + print(f"Using device: {device}") + + # Create scene and camera + scene = create_random_scene(n_gaussians=1000, device=device) + camera = create_camera_setup(device=device) + + # Print scene statistics + print(f"\nScene Statistics:") + print(f" Number of Gaussians: {len(scene['means'])}") + print(f" Mean positions: {scene['means'].mean(0).cpu().numpy()}") + print(f" Position std: {scene['means'].std(0).cpu().numpy()}") + print(f" Z range: {scene['means'][:, 2].min().item():.2f} to {scene['means'][:, 2].max().item():.2f}") + print(f" Scale range: {scene['scales'].min().item():.4f} to {scene['scales'].max().item():.4f}") + print(f" Opacity range: {scene['opacities'].min().item():.3f} to {scene['opacities'].max().item():.3f}") + + # Run optimization + optimize_intrinsics(scene, camera, n_iterations=200, lr=1.0, use_2dgs=use_2dgs) + +if __name__ == "__main__": + main()
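
The _wrapper.py changes above follow the standard torch.autograd.Function contract: backward() returns one gradient per forward() input, and each gradient is gated on ctx.needs_input_grad so the CUDA op only materializes v_Ks when Ks actually requires grad. A minimal toy sketch of that contract (illustrative only — ToyProject is not gsplat's fused op):

import torch

class ToyProject(torch.autograd.Function):
    # Toy stand-in for the fused projection op: out = means @ K^T.
    @staticmethod
    def forward(ctx, means, Ks):
        ctx.save_for_backward(means, Ks)
        return means @ Ks.transpose(-1, -2)

    @staticmethod
    def backward(ctx, v_out):
        means, Ks = ctx.saved_tensors
        # Only materialize a gradient when the corresponding input needs it,
        # mirroring the needs_input_grad[...] checks in _wrapper.py.
        v_means = v_out @ Ks if ctx.needs_input_grad[0] else None
        v_Ks = v_out.transpose(-1, -2) @ means if ctx.needs_input_grad[1] else None
        return v_means, v_Ks  # one entry per forward() input

means = torch.randn(4, 3)
Ks = torch.eye(3, requires_grad=True)
ToyProject.apply(means, Ks).sum().backward()
assert Ks.grad.shape == (3, 3)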
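
The intrinsics gradient added to compute_ray_transforms_aabb_vjp is the plain matrix chain rule stated in its comment. Writing G = ∂L/∂M for the incoming _v_ray_transforms and using the forward relation M = (WH)^T K^T:

\[
\frac{\partial L}{\partial K^{\top}} = (WH)\,G,
\qquad
\frac{\partial L}{\partial K} = \bigl((WH)\,G\bigr)^{\top} = G^{\top}(WH)^{\top},
\]

which is what glm::transpose(WH * _v_ray_transforms) evaluates. Because

\[
K = \begin{bmatrix} f_x & 0 & c_x \\ 0 & f_y & c_y \\ 0 & 0 & 1 \end{bmatrix},
\]

only four entries of \(\partial L/\partial K\) are needed: \(\partial L/\partial f_x = [\partial L/\partial K]_{11}\), \(\partial L/\partial f_y = [\partial L/\partial K]_{22}\), \(\partial L/\partial c_x = [\partial L/\partial K]_{13}\), \(\partial L/\partial c_y = [\partial L/\partial K]_{23}\) (1-based row/column indices). With GLM's column-major [column][row] indexing these are the slots v_K_T[0][0], v_K_T[1][1], v_K_T[2][0] and v_K_T[2][1] that the kernel accumulates.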
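
In all three backward kernels the per-Gaussian contributions v_fx/v_fy/v_cx/v_cy are first summed within a warp (cg::labeled_partition keyed on the camera id, so the pre-reduction only cuts the number of global atomics) and then atomically added into the flat 3x3 per-camera gradient at byte offsets 0, 2, 4, 5 of the row-major matrix. A reference computation in plain PyTorch, assuming toy dense per-(camera, Gaussian) contributions:

import torch

C, N = 2, 5  # cameras, Gaussians (toy sizes)
v_fx, v_fy = torch.randn(C, N), torch.randn(C, N)
v_cx, v_cy = torch.randn(C, N), torch.randn(C, N)

v_Ks = torch.zeros(C, 3, 3)
v_Ks[:, 0, 0] = v_fx.sum(dim=1)  # flat offset 0 -> K[0][0] = fx
v_Ks[:, 0, 2] = v_cx.sum(dim=1)  # flat offset 2 -> K[0][2] = cx
v_Ks[:, 1, 1] = v_fy.sum(dim=1)  # flat offset 4 -> K[1][1] = fy
v_Ks[:, 1, 2] = v_cy.sum(dim=1)  # flat offset 5 -> K[1][2] = cy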
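
The easiest part of the new persp_proj_vjp overload to sanity-check is the mean-projection path: mean2d = (fx·x/z + cx, fy·y/z + cy), so ∂/∂fx = x/z, ∂/∂fy = y/z and ∂/∂cx = ∂/∂cy = 1, exactly the v_fx += x * rz * v_mean2d[0] lines. A standalone autograd cross-check (toy numbers, not part of the PR):

import torch

x, y, z = 0.3, -0.2, 4.0
K = torch.tensor([800.0, 800.0, 320.0, 240.0], requires_grad=True)  # fx, fy, cx, cy
fx, fy, cx, cy = K
mean2d = torch.stack([fx * x / z + cx, fy * y / z + cy])
v_mean2d = torch.tensor([1.0, 2.0])  # arbitrary upstream gradient
(mean2d * v_mean2d).sum().backward()
expected = torch.tensor([x / z * 1.0, y / z * 2.0, 1.0, 2.0])
assert torch.allclose(K.grad, expected)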
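
The covariance path of the pinhole/ortho overloads in Utils.cuh reuses the v_J term the existing VJPs already compute: with \(\Sigma_{2D} = J\,\Sigma\,J^{\top}\) and \(G = \partial L/\partial\Sigma_{2D}\), one has \(\partial L/\partial J = G J \Sigma^{\top} + G^{\top} J \Sigma\). The intrinsics then enter only through J's explicit dependence on \(f_x, f_y\), giving the extra covariance-path contributions

\[
J_{\text{ortho}} = \begin{bmatrix} f_x & 0 & 0 \\ 0 & f_y & 0 \end{bmatrix}
\;\Rightarrow\;
\Delta f_x = v_{J,00},\quad \Delta f_y = v_{J,11},
\]
\[
J_{\text{persp}} = \begin{bmatrix} f_x/z & 0 & -f_x t_x/z^2 \\ 0 & f_y/z & -f_y t_y/z^2 \end{bmatrix}
\;\Rightarrow\;
\Delta f_x = \frac{v_{J,00}}{z} - \frac{t_x\,v_{J,02}}{z^2},\quad
\Delta f_y = \frac{v_{J,11}}{z} - \frac{t_y\,v_{J,12}}{z^2},
\]

which matches the v_fx += rz * v_J[0][0] - tx * rz2 * v_J[2][0] lines (again [column][row] indexing). \(c_x\) and \(c_y\) do not appear in J, so the covariance term contributes nothing to them; the fisheye overload currently propagates only the c_x/c_y mean terms and leaves its focal-length gradients at zero.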
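
Beyond the visual loop in test_intrinsics_optimization.py, a quick numerical sanity check of the new v_Ks path is to compare the analytic gradient against a central difference in a single intrinsic entry. This is a sketch, assuming scene, camera and a ground-truth rendering gt_image have been set up with the helper functions above (pinhole camera, packed=False, since the packed 3DGS backward is not extended by this change); agreement is only approximate because rasterization is not smooth everywhere:

import torch
from gsplat import rasterization

def render_loss(Ks):
    img = rasterization(
        means=scene['means'], quats=scene['quats'], scales=scene['scales'],
        opacities=scene['opacities'], colors=scene['colors'],
        viewmats=camera['viewmats'], Ks=Ks,
        width=camera['width'], height=camera['height'],
        packed=False, camera_model="pinhole",
    )[0]
    return (img - gt_image).pow(2).mean()

Ks = camera['Ks_init'].detach().clone().requires_grad_(True)
analytic = torch.autograd.grad(render_loss(Ks), Ks)[0][0, 0, 0]  # d loss / d fx

eps = 1e-2
Kp = Ks.detach().clone(); Kp[0, 0, 0] += eps
Km = Ks.detach().clone(); Km[0, 0, 0] -= eps
numeric = (render_loss(Kp) - render_loss(Km)) / (2 * eps)
print(float(analytic), float(numeric))  # should agree to a few significant digits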