diff --git a/Cargo.lock b/Cargo.lock index f23ca88..5df8242 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1,6 +1,6 @@ # This file is automatically @generated by Cargo. # It is not intended for manual editing. -version = 3 +version = 4 [[package]] name = "autocfg" @@ -184,6 +184,15 @@ version = "1.21.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d" +[[package]] +name = "ordered-float" +version = "5.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f4779c6901a562440c3786d08192c6fbda7c1c2060edd10006b05ee35d10f2d" +dependencies = [ + "num-traits", +] + [[package]] name = "petal-decomposition" version = "0.9.0" @@ -202,6 +211,18 @@ dependencies = [ "thiserror", ] +[[package]] +name = "petal-neighbors" +version = "0.18.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "16fd9983c4b760302fe9cee8ad204ee6555ed967df8fab30974138a001026c91" +dependencies = [ + "ndarray", + "num-traits", + "ordered-float", + "thiserror", +] + [[package]] name = "portable-atomic" version = "1.11.1" @@ -302,6 +323,7 @@ version = "0.1.0" dependencies = [ "numpy", "petal-decomposition", + "petal-neighbors", "pyo3", ] diff --git a/Cargo.toml b/Cargo.toml index a338804..77d7d5c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,8 +10,9 @@ edition = "2018" [dependencies] numpy = "0.27.1" petal-decomposition = "0.9" +petal-neighbors = "0.18" pyo3 = { version = "0.27", features = ["extension-module"] } [lib] -name = "decomposition" +name = "pypetal" crate-type = ["cdylib"] diff --git a/setup.py b/setup.py index ddd690f..bb8b36d 100644 --- a/setup.py +++ b/setup.py @@ -4,7 +4,7 @@ setup( name="pypetal", version="0.1.0", - rust_extensions=[RustExtension("pypetal.decomposition")], + rust_extensions=[RustExtension("pypetal.pypetal")], packages=["pypetal"], zip_safe=False, ) diff --git a/src/decomposition.rs b/src/decomposition.rs index c55ada5..4118599 100644 --- a/src/decomposition.rs +++ b/src/decomposition.rs @@ -6,7 +6,7 @@ use pyo3::exceptions::PyException; use pyo3::prelude::*; #[pymodule] -fn decomposition(m: &Bound<'_, PyModule>) -> PyResult<()> { +pub fn decomposition(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; Ok(()) diff --git a/src/lib.rs b/src/lib.rs index ad5e474..45756f0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1 +1,13 @@ mod decomposition; +mod neighbors; + +use pyo3::prelude::*; +use pyo3::wrap_pymodule; + +/// A Python module implemented in Rust. +#[pymodule] +fn pypetal(m: &Bound<'_, PyModule>) -> PyResult<()> { + m.add_wrapped(wrap_pymodule!(self::decomposition::decomposition))?; + m.add_wrapped(wrap_pymodule!(self::neighbors::neighbors))?; + Ok(()) +} diff --git a/src/neighbors.rs b/src/neighbors.rs new file mode 100644 index 0000000..20ee928 --- /dev/null +++ b/src/neighbors.rs @@ -0,0 +1,163 @@ +#![allow(clippy::used_underscore_binding)] + +use numpy::{IntoPyArray, PyArray1, PyReadonlyArray1, PyReadonlyArray2}; +use petal_neighbors as petal; +use pyo3::exceptions::PyValueError; +use pyo3::prelude::*; + +#[pymodule] +pub fn neighbors(m: &Bound<'_, PyModule>) -> PyResult<()> { + m.add_class::()?; + Ok(()) +} + +/// A ball tree data structure for efficient nearest neighbor searches. +/// +/// The ball tree partitions data points into a nested set of hyperspheres +/// ("balls"), allowing for efficient nearest neighbor queries in +/// multi-dimensional spaces. +#[pyclass] +pub struct BallTree { + inner: petal::BallTree<'static, f64, petal::distance::Euclidean>, +} + +#[pymethods] +impl BallTree { + /// Creates a new BallTree from a 2D array of points using Euclidean distance. + /// + /// # Arguments + /// + /// * `points` - A 2D numpy array where each row represents a point in the space. + /// + /// # Errors + /// + /// Returns an error if the input array is empty or has non-contiguous rows. + #[new] + #[allow(clippy::needless_pass_by_value)] + fn new(points: PyReadonlyArray2) -> PyResult { + let points = points.as_array().to_owned(); + let inner = petal::BallTree::euclidean(points) + .map_err(|err| PyValueError::new_err(format!("{err}")))?; + Ok(BallTree { inner }) + } + + /// Returns the number of points in the tree. + #[getter] + fn n_samples(&self) -> usize { + self.inner.num_points() + } + + /// Returns the number of nodes in the tree. + #[getter] + fn n_nodes(&self) -> usize { + self.inner.num_nodes() + } + + /// Finds the single nearest neighbor to a query point. + /// + /// # Arguments + /// + /// * `point` - A 1D numpy array representing the query point. + /// + /// # Returns + /// + /// A tuple containing the index of the nearest neighbor and its distance. + #[allow(clippy::needless_pass_by_value)] + fn query_nearest(&self, point: PyReadonlyArray1) -> (usize, f64) { + let point = point.as_array(); + self.inner.query_nearest(&point) + } + + /// Finds the k nearest neighbors to a query point. + /// + /// # Arguments + /// + /// * `point` - A 1D numpy array representing the query point. + /// * `k` - The number of nearest neighbors to find. + /// + /// # Returns + /// + /// A tuple containing two numpy arrays: + /// - The indices of the k nearest neighbors. + /// - The distances to the k nearest neighbors. + /// + /// Results are sorted by ascending distance. + #[allow(clippy::needless_pass_by_value)] + fn query<'py>( + &self, + py: Python<'py>, + point: PyReadonlyArray1, + k: usize, + ) -> (Bound<'py, PyArray1>, Bound<'py, PyArray1>) { + let point = point.as_array(); + let (indices, distances) = self.inner.query(&point, k); + (indices.into_pyarray(py), distances.into_pyarray(py)) + } + + /// Finds all neighbors within a given radius of a query point. + /// + /// # Arguments + /// + /// * `point` - A 1D numpy array representing the query point. + /// * `radius` - The maximum distance for neighbors to be included. + /// + /// # Returns + /// + /// A numpy array containing the indices of all points within the radius. + #[allow(clippy::needless_pass_by_value)] + fn query_radius<'py>( + &self, + py: Python<'py>, + point: PyReadonlyArray1, + radius: f64, + ) -> Bound<'py, PyArray1> { + let point = point.as_array(); + let indices = self.inner.query_radius(&point, radius); + indices.into_pyarray(py) + } +} + +#[cfg(test)] +mod tests { + use numpy::ndarray::array; + use petal_neighbors as petal; + + #[test] + fn test_ball_tree_construction() { + let points = array![[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]; + let tree = petal::BallTree::euclidean(points).unwrap(); + assert_eq!(tree.num_points(), 4); + } + + #[test] + fn test_ball_tree_query_nearest() { + let points = array![[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]; + let tree = petal::BallTree::euclidean(points).unwrap(); + let query = array![0.1, 0.1]; + let (idx, dist) = tree.query_nearest(&query); + assert_eq!(idx, 0); + assert!((dist - 0.1_f64.hypot(0.1)).abs() < 1e-10); + } + + #[test] + fn test_ball_tree_query_k_nearest() { + let points = array![[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]; + let tree = petal::BallTree::euclidean(points).unwrap(); + let query = array![0.0, 0.0]; + let (indices, distances) = tree.query(&query, 2); + assert_eq!(indices.len(), 2); + assert_eq!(distances.len(), 2); + assert_eq!(indices[0], 0); // Nearest is the origin itself + assert!((distances[0] - 0.0_f64).abs() < 1e-10); + } + + #[test] + fn test_ball_tree_query_radius() { + let points = array![[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]; + let tree = petal::BallTree::euclidean(points).unwrap(); + let query = array![0.0, 0.0]; + let indices = tree.query_radius(&query, 1.1); + // Should include the origin, (1,0), and (0,1), but not (1,1) which is at distance sqrt(2) + assert_eq!(indices.len(), 3); + } +}