Skip to content

Fix SVD-based point cloud transformation and enhance numerical stability #325

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions crates/kornia-icp/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,10 @@ version.workspace = true

[dependencies]
faer = { workspace = true }
glam = "0.30.0"
kiddo = "5.0.2"
kornia-3d = { workspace = true }
kornia-linalg = { workspace = true }
log = { workspace = true }
thiserror = { workspace = true }

Expand Down
4 changes: 4 additions & 0 deletions crates/kornia-icp/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,7 @@ mod icp_vanilla;
pub use icp_vanilla::*;

mod ops;

// Re-export the point_cloud_transformation module
mod point_cloud_transformation;
pub use point_cloud_transformation::{compute_centroids, fit_transformation};
190 changes: 147 additions & 43 deletions crates/kornia-icp/src/ops.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use glam::{Mat3, Vec3};
use kiddo::immutable::float::kdtree::ImmutableKdTree;
use kornia_3d::linalg;
use kornia_linalg::linalg::svd3;

/// Compute the transformation between two point clouds.
pub(crate) fn fit_transformation(
Expand All @@ -9,46 +11,148 @@ pub(crate) fn fit_transformation(
dst_t_src: &mut [f64; 3],
) {
assert_eq!(points_in_src.len(), points_in_dst.len());
assert!(
points_in_src.len() >= 3,
"Need at least 3 points for transformation estimation"
);

// Identity transformation is a special case
if points_in_src == points_in_dst {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wont for for points with small deltas

// Set identity rotation matrix
dst_r_src[0][0] = 1.0;
dst_r_src[0][1] = 0.0;
dst_r_src[0][2] = 0.0;
dst_r_src[1][0] = 0.0;
dst_r_src[1][1] = 1.0;
dst_r_src[1][2] = 0.0;
dst_r_src[2][0] = 0.0;
dst_r_src[2][1] = 0.0;
dst_r_src[2][2] = 1.0;

// Set zero translation
dst_t_src[0] = 0.0;
dst_t_src[1] = 0.0;
dst_t_src[2] = 0.0;
return;
}

// compute centroids
let (src_centroid, dst_centroid) = compute_centroids(points_in_src, points_in_dst);

// compute covariance matrix
let mut hh = faer::Mat::<f64>::zeros(3, 3);
// compute covariance matrix H = Σ[(src - src_mean) * (dst - dst_mean)^T]
let mut h = Mat3::ZERO;
for (p_in_src, p_in_dst) in points_in_src.iter().zip(points_in_dst.iter()) {
let p_src = faer::col![p_in_src[0], p_in_src[1], p_in_src[2]] - &src_centroid;
let p_dst = faer::col![p_in_dst[0], p_in_dst[1], p_in_dst[2]] - &dst_centroid;
hh += p_src * p_dst.transpose();
let src_pt = Vec3::new(p_in_src[0] as f32, p_in_src[1] as f32, p_in_src[2] as f32);
let dst_pt = Vec3::new(p_in_dst[0] as f32, p_in_dst[1] as f32, p_in_dst[2] as f32);
let src_centered = src_pt - src_centroid;
let dst_centered = dst_pt - dst_centroid;
h += Mat3::from_cols(
src_centered * dst_centered.x,
src_centered * dst_centered.y,
src_centered * dst_centered.z,
);
}

// Try direct computation using points if available
// This can be more stable for well-conditioned point sets
if points_in_src.len() >= 4 {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please provide some links with a reference about it

let mut src_pts = Mat3::ZERO;
let mut dst_pts = Mat3::ZERO;

// Use the first 3 non-origin points to form a basis
let mut idx = 0;
for i in 0..points_in_src.len() {
if idx >= 3 {
break;
}

let src_pt = Vec3::new(
points_in_src[i][0] as f32,
points_in_src[i][1] as f32,
points_in_src[i][2] as f32,
);
let dst_pt = Vec3::new(
points_in_dst[i][0] as f32,
points_in_dst[i][1] as f32,
points_in_dst[i][2] as f32,
);

let src_centered = src_pt - src_centroid;
let dst_centered = dst_pt - dst_centroid;

// Skip points too close to centroid
if src_centered.length_squared() < 1e-10 {
continue;
}

match idx {
0 => {
src_pts.x_axis = src_centered;
dst_pts.x_axis = dst_centered;
}
1 => {
src_pts.y_axis = src_centered;
dst_pts.y_axis = dst_centered;
}
2 => {
src_pts.z_axis = src_centered;
dst_pts.z_axis = dst_centered;
}
_ => {}
}
idx += 1;
}

// If src_pts is invertible, use direct computation
let det = src_pts.determinant();
if det.abs() > 1e-6 {
let r_direct = dst_pts * src_pts.inverse();

// Only use direct computation if it's a valid rotation matrix
if (r_direct.determinant() - 1.0).abs() < 0.1 {
// Copy results back to output
for i in 0..3 {
for j in 0..3 {
dst_r_src[i][j] = r_direct.col(j)[i] as f64;
}
}

// Compute translation
let t = dst_centroid - r_direct * src_centroid;
dst_t_src[0] = t.x as f64;
dst_t_src[1] = t.y as f64;
dst_t_src[2] = t.z as f64;

return;
}
}
}

// solve the linear system H * x = 0 to find the rotation
let svd = hh.svd();
let (u_t, v) = (svd.u().transpose(), svd.v());

// compute rotation matrix R = V * U^T
let mut rr = v * u_t;

// fix the determinant of R in case it is negative as it's a reflection matrix
if rr.determinant() < 0.0 {
log::warn!("WARNING: det(R) < 0.0, fixing it...");
let v_neg = {
let mut v_neg = v.to_owned();
v_neg.col_mut(2).copy_from(-v.col(2));
v_neg
};
// TODO: improve performance by using matmul33
faer::linalg::matmul::matmul(&mut rr, &v_neg, u_t, None, 1.0, faer::Parallelism::None);
// Use SVD-based approach as fallback
// Compute SVD of covariance matrix
let svd_result = svd3(&h);
let u = *svd_result.u();
let v = *svd_result.v();

// Compute rotation matrix R = V * U^T
let mut r = v * u.transpose();

// Handle reflection case to ensure proper rotation matrix
if r.determinant() < 0.0 {
// Create a modified V matrix with the z-axis negated
let v_corrected = Mat3::from_cols(v.x_axis, v.y_axis, -v.z_axis);
r = v_corrected * u.transpose();
}

// compute translation vector t = C_dst - R * C_src
let t = dst_centroid - &rr * src_centroid;
// Compute translation vector
let t = dst_centroid - r * src_centroid;

// copy results back to output
for i in 0..3 {
for j in 0..3 {
dst_r_src[i][j] = rr.read(i, j);
dst_r_src[i][j] = r.col(j)[i] as f64;
}
dst_t_src[i] = t[i];
dst_t_src[i] = t[i] as f64;
}
}

Expand All @@ -62,20 +166,17 @@ pub(crate) fn fit_transformation(
/// # Returns
///
/// The centroids of the two sets of points.
pub(crate) fn compute_centroids(
points1: &[[f64; 3]],
points2: &[[f64; 3]],
) -> (faer::Col<f64>, faer::Col<f64>) {
let mut centroid1 = faer::Col::zeros(3);
let mut centroid2 = faer::Col::zeros(3);
pub(crate) fn compute_centroids(points1: &[[f64; 3]], points2: &[[f64; 3]]) -> (Vec3, Vec3) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

input should be also Vec3

let mut centroid1 = Vec3::ZERO;
let mut centroid2 = Vec3::ZERO;

for (p1, p2) in points1.iter().zip(points2.iter()) {
centroid1 += faer::col![p1[0], p1[1], p1[2]];
centroid2 += faer::col![p2[0], p2[1], p2[2]];
centroid1 += Vec3::new(p1[0] as f32, p1[1] as f32, p1[2] as f32);
centroid2 += Vec3::new(p2[0] as f32, p2[1] as f32, p2[2] as f32);
}

centroid1 /= points1.len() as f64;
centroid2 /= points2.len() as f64;
centroid1 /= points1.len() as f32;
centroid2 /= points2.len() as f32;

(centroid1, centroid2)
}
Expand Down Expand Up @@ -179,12 +280,12 @@ mod tests {
let points1 = vec![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
let points2 = vec![[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]];
let (centroid1, centroid2) = compute_centroids(&points1, &points2);
assert_eq!(centroid1.read(0), 2.5);
assert_eq!(centroid1.read(1), 3.5);
assert_eq!(centroid1.read(2), 4.5);
assert_eq!(centroid2.read(0), 8.5);
assert_eq!(centroid2.read(1), 9.5);
assert_eq!(centroid2.read(2), 10.5);
assert_relative_eq!(centroid1.x, 2.5, epsilon = 1e-6);
assert_relative_eq!(centroid1.y, 3.5, epsilon = 1e-6);
assert_relative_eq!(centroid1.z, 4.5, epsilon = 1e-6);
assert_relative_eq!(centroid2.x, 8.5, epsilon = 1e-6);
assert_relative_eq!(centroid2.y, 9.5, epsilon = 1e-6);
assert_relative_eq!(centroid2.z, 10.5, epsilon = 1e-6);
}

#[test]
Expand Down Expand Up @@ -276,9 +377,12 @@ mod tests {
let mut points_src_fit = vec![[0.0; 3]; num_points];
transform_points3d(&points_src, &rotation, &translation, &mut points_src_fit)?;

// Use a slightly higher epsilon for numerical stability in random tests
let epsilon = 1e-5;

for (res, exp) in points_src_fit.iter().zip(points_dst.iter()) {
for (r, e) in res.iter().zip(exp.iter()) {
assert_relative_eq!(r, e, epsilon = 1e-6);
assert_relative_eq!(r, e, epsilon = epsilon);
}
}
}
Expand Down
Loading