Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 23 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@ edition = "2018"
[dependencies]
numpy = "0.27.1"
petal-decomposition = "0.9"
petal-neighbors = "0.18"
pyo3 = { version = "0.27", features = ["extension-module"] }

[lib]
name = "decomposition"
name = "pypetal"
crate-type = ["cdylib"]
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
setup(
name="pypetal",
version="0.1.0",
rust_extensions=[RustExtension("pypetal.decomposition")],
rust_extensions=[RustExtension("pypetal.pypetal")],
packages=["pypetal"],
zip_safe=False,
)
2 changes: 1 addition & 1 deletion src/decomposition.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use pyo3::exceptions::PyException;
use pyo3::prelude::*;

#[pymodule]
fn decomposition(m: &Bound<'_, PyModule>) -> PyResult<()> {
pub fn decomposition(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<FastIca>()?;
m.add_class::<Pca>()?;
Ok(())
Expand Down
12 changes: 12 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1 +1,13 @@
mod decomposition;
mod neighbors;

use pyo3::prelude::*;
use pyo3::wrap_pymodule;

/// A Python module implemented in Rust.
#[pymodule]
fn pypetal(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_wrapped(wrap_pymodule!(self::decomposition::decomposition))?;
m.add_wrapped(wrap_pymodule!(self::neighbors::neighbors))?;
Ok(())
}
163 changes: 163 additions & 0 deletions src/neighbors.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
#![allow(clippy::used_underscore_binding)]

use numpy::{IntoPyArray, PyArray1, PyReadonlyArray1, PyReadonlyArray2};
use petal_neighbors as petal;
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;

#[pymodule]
pub fn neighbors(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<BallTree>()?;
Ok(())
}

/// A ball tree data structure for efficient nearest neighbor searches.
///
/// The ball tree partitions data points into a nested set of hyperspheres
/// ("balls"), allowing for efficient nearest neighbor queries in
/// multi-dimensional spaces.
#[pyclass]
pub struct BallTree {
inner: petal::BallTree<'static, f64, petal::distance::Euclidean>,
}

#[pymethods]
impl BallTree {
/// Creates a new BallTree from a 2D array of points using Euclidean distance.
///
/// # Arguments
///
/// * `points` - A 2D numpy array where each row represents a point in the space.
///
/// # Errors
///
/// Returns an error if the input array is empty or has non-contiguous rows.
#[new]
#[allow(clippy::needless_pass_by_value)]
fn new(points: PyReadonlyArray2<f64>) -> PyResult<Self> {
let points = points.as_array().to_owned();
let inner = petal::BallTree::euclidean(points)
.map_err(|err| PyValueError::new_err(format!("{err}")))?;
Ok(BallTree { inner })
}

/// Returns the number of points in the tree.
#[getter]
fn n_samples(&self) -> usize {
self.inner.num_points()
}

/// Returns the number of nodes in the tree.
#[getter]
fn n_nodes(&self) -> usize {
self.inner.num_nodes()
}

/// Finds the single nearest neighbor to a query point.
///
/// # Arguments
///
/// * `point` - A 1D numpy array representing the query point.
///
/// # Returns
///
/// A tuple containing the index of the nearest neighbor and its distance.
#[allow(clippy::needless_pass_by_value)]
fn query_nearest(&self, point: PyReadonlyArray1<f64>) -> (usize, f64) {
let point = point.as_array();
self.inner.query_nearest(&point)
}

/// Finds the k nearest neighbors to a query point.
///
/// # Arguments
///
/// * `point` - A 1D numpy array representing the query point.
/// * `k` - The number of nearest neighbors to find.
///
/// # Returns
///
/// A tuple containing two numpy arrays:
/// - The indices of the k nearest neighbors.
/// - The distances to the k nearest neighbors.
///
/// Results are sorted by ascending distance.
#[allow(clippy::needless_pass_by_value)]
fn query<'py>(
&self,
py: Python<'py>,
point: PyReadonlyArray1<f64>,
k: usize,
) -> (Bound<'py, PyArray1<usize>>, Bound<'py, PyArray1<f64>>) {
let point = point.as_array();
let (indices, distances) = self.inner.query(&point, k);
(indices.into_pyarray(py), distances.into_pyarray(py))
}

/// Finds all neighbors within a given radius of a query point.
///
/// # Arguments
///
/// * `point` - A 1D numpy array representing the query point.
/// * `radius` - The maximum distance for neighbors to be included.
///
/// # Returns
///
/// A numpy array containing the indices of all points within the radius.
#[allow(clippy::needless_pass_by_value)]
fn query_radius<'py>(
&self,
py: Python<'py>,
point: PyReadonlyArray1<f64>,
radius: f64,
) -> Bound<'py, PyArray1<usize>> {
let point = point.as_array();
let indices = self.inner.query_radius(&point, radius);
indices.into_pyarray(py)
}
}

#[cfg(test)]
mod tests {
use numpy::ndarray::array;
use petal_neighbors as petal;

#[test]
fn test_ball_tree_construction() {
let points = array![[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]];
let tree = petal::BallTree::euclidean(points).unwrap();
assert_eq!(tree.num_points(), 4);
}

#[test]
fn test_ball_tree_query_nearest() {
let points = array![[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]];
let tree = petal::BallTree::euclidean(points).unwrap();
let query = array![0.1, 0.1];
let (idx, dist) = tree.query_nearest(&query);
assert_eq!(idx, 0);
assert!((dist - 0.1_f64.hypot(0.1)).abs() < 1e-10);
}

#[test]
fn test_ball_tree_query_k_nearest() {
let points = array![[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]];
let tree = petal::BallTree::euclidean(points).unwrap();
let query = array![0.0, 0.0];
let (indices, distances) = tree.query(&query, 2);
assert_eq!(indices.len(), 2);
assert_eq!(distances.len(), 2);
assert_eq!(indices[0], 0); // Nearest is the origin itself
assert!((distances[0] - 0.0_f64).abs() < 1e-10);
}

#[test]
fn test_ball_tree_query_radius() {
let points = array![[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]];
let tree = petal::BallTree::euclidean(points).unwrap();
let query = array![0.0, 0.0];
let indices = tree.query_radius(&query, 1.1);
// Should include the origin, (1,0), and (0,1), but not (1,1) which is at distance sqrt(2)
assert_eq!(indices.len(), 3);
}
}
Loading