diff --git a/polars_distance/Cargo.toml b/polars_distance/Cargo.toml index 1b8d6ab..9284626 100644 --- a/polars_distance/Cargo.toml +++ b/polars_distance/Cargo.toml @@ -17,6 +17,7 @@ serde = { version = "1", features = ["derive"] } distances = { version = "1.6.3"} rapidfuzz = { version = "0.5.0"} gestalt_ratio = { version = "0.2.1"} +num-traits = { version = "0.2" } [target.'cfg(target_os = "linux")'.dependencies] jemallocator = { version = "0.5", features = ["disable_initial_exec_tls"] } diff --git a/polars_distance/pyproject.toml b/polars_distance/pyproject.toml index 3d0d138..e4c2817 100644 --- a/polars_distance/pyproject.toml +++ b/polars_distance/pyproject.toml @@ -28,4 +28,4 @@ readme = "README.md" [tool.maturin] -module-name = "polars_distance._internal" \ No newline at end of file +module-name = "polars_distance._internal" diff --git a/polars_distance/src/array.rs b/polars_distance/src/array.rs index a96432b..88d5c38 100644 --- a/polars_distance/src/array.rs +++ b/polars_distance/src/array.rs @@ -1,31 +1,19 @@ -use distances::vectors::minkowski; +use distances::vectors::{bray_curtis, canberra, chebyshev, l3_norm, l4_norm, manhattan, minkowski}; use polars::prelude::arity::{try_binary_elementwise, try_unary_elementwise}; use polars::prelude::*; -use polars_arrow::array::{new_null_array, Array, PrimitiveArray}; +use polars_arrow::array::{new_null_array, PrimitiveArray}; +use num_traits::{Zero, One, Float, FromPrimitive}; -fn collect_into_vecf64(arr: Box) -> Vec { - arr.as_any() - .downcast_ref::>() - .unwrap() - .values_iter() - .copied() - .collect::>() -} - -fn collect_into_uint64(arr: Box) -> Vec { - arr.as_any() - .downcast_ref::>() - .unwrap() - .values_iter() - .copied() - .collect::>() -} - -pub fn distance_calc_numeric_inp( +pub fn vector_distance_calc( a: &ChunkedArray, b: &ChunkedArray, - f: fn(&[f64], &[f64]) -> f64, -) -> PolarsResult { + distance_fn: F, +) -> PolarsResult> +where + T: PolarsFloatType, + T::Native: Float + std::ops::Sub + FromPrimitive + Zero + One, + F: Fn(&[T::Native], &[T::Native]) -> T::Native, +{ polars_ensure!( a.inner_dtype() == b.inner_dtype(), ComputeError: "inner data types don't match" @@ -35,134 +23,14 @@ pub fn distance_calc_numeric_inp( ComputeError: "inner data types must be numeric" ); - let s1 = a.cast(&DataType::Array(Box::new(DataType::Float64), a.width()))?; - let s2 = b.cast(&DataType::Array(Box::new(DataType::Float64), a.width()))?; + // Cast to the target float type + let s1 = a.cast(&DataType::Array(Box::new(T::get_dtype()), a.width()))?; + let s2 = b.cast(&DataType::Array(Box::new(T::get_dtype()), a.width()))?; let a: &ArrayChunked = s1.array()?; let b: &ArrayChunked = s2.array()?; // If one side is a literal it will be shorter but is moved to RHS so we can use unsafe access - let (a, b) = if a.len() < b.len() { (b, a) } else { (a, b) }; - match b.len() { - 1 => match unsafe { b.get_unchecked(0) } { - Some(b_value) => { - if b_value.null_count() > 0 { - polars_bail!(ComputeError: "array cannot contain nulls") - } - try_unary_elementwise(a, |a| match a { - Some(a) => { - if a.null_count() > 0 { - polars_bail!(ComputeError: "array cannot contain nulls") - } - let a = &collect_into_vecf64(a); - let b = &collect_into_vecf64(b_value.clone()); - Ok(Some(f(a, b))) - } - _ => Ok(None), - }) - } - None => unsafe { - Ok(ChunkedArray::from_chunks( - a.name().clone(), - vec![new_null_array(ArrowDataType::Float64, a.len())], - )) - }, - }, - _ => try_binary_elementwise(a, b, |a, b| match (a, b) { - (Some(a), Some(b)) => { - if a.null_count() > 0 || b.null_count() > 0 { - polars_bail!(ComputeError: "array cannot contain nulls") - } else { - let a = &collect_into_vecf64(a); - let b = &collect_into_vecf64(b); - Ok(Some(f(a, b))) - } - } - _ => Ok(None), - }), - } -} - -pub fn distance_calc_uint_inp( - a: &ChunkedArray, - b: &ChunkedArray, - f: fn(&[u64], &[u64]) -> f64, -) -> PolarsResult { - polars_ensure!( - a.inner_dtype() == b.inner_dtype(), - ComputeError: "inner data types don't match" - ); - polars_ensure!( - a.inner_dtype().is_unsigned_integer(), - ComputeError: "inner data types must be unsigned integer" - ); - - let s1 = a.cast(&DataType::Array(Box::new(DataType::UInt64), a.width()))?; - let s2 = b.cast(&DataType::Array(Box::new(DataType::UInt64), a.width()))?; - - let a: &ArrayChunked = s1.array()?; - let b: &ArrayChunked = s2.array()?; - - let (a, b) = if a.len() < b.len() { (b, a) } else { (a, b) }; - match b.len() { - 1 => match unsafe { b.get_unchecked(0) } { - Some(b_value) => { - if b_value.null_count() > 0 { - polars_bail!(ComputeError: "array cannot contain nulls") - } - try_unary_elementwise(a, |a| match a { - Some(a) => { - if a.null_count() > 0 { - polars_bail!(ComputeError: "array cannot contain nulls") - } - let a = &collect_into_uint64(a); - let b = &collect_into_uint64(b_value.clone()); - Ok(Some(f(a, b))) - } - _ => Ok(None), - }) - } - None => unsafe { - Ok(ChunkedArray::from_chunks( - a.name().clone(), - vec![new_null_array(ArrowDataType::Float64, a.len())], - )) - }, - }, - _ => try_binary_elementwise(a, b, |a, b| match (a, b) { - (Some(a), Some(b)) => { - if a.null_count() > 0 || b.null_count() > 0 { - polars_bail!(ComputeError: "array cannot contain nulls") - } else { - let a = &collect_into_uint64(a); - let b = &collect_into_uint64(b); - Ok(Some(f(a, b))) - } - } - _ => Ok(None), - }), - } -} - -pub fn euclidean_dist( - a: &ChunkedArray, - b: &ChunkedArray, -) -> PolarsResult { - polars_ensure!( - a.inner_dtype() == b.inner_dtype(), - ComputeError: "inner data types don't match" - ); - polars_ensure!( - a.inner_dtype().is_numeric(), - ComputeError: "inner data types must be numeric" - ); - - let s1 = a.cast(&DataType::Array(Box::new(DataType::Float64), a.width()))?; - let s2 = b.cast(&DataType::Array(Box::new(DataType::Float64), a.width()))?; - - let a: &ArrayChunked = s1.array()?; - let b: &ArrayChunked = s2.array()?; - let (a, b) = if a.len() < b.len() { (b, a) } else { (a, b) }; match b.len() { 1 => match unsafe { b.get_unchecked(0) } { @@ -177,25 +45,32 @@ pub fn euclidean_dist( } let a = a .as_any() - .downcast_ref::>() + .downcast_ref::>() .unwrap() - .values_iter(); + .values() + .to_vec(); let b = b_value .as_any() - .downcast_ref::>() + .downcast_ref::>() .unwrap() - .values_iter(); - Ok(Some( - a.zip(b).map(|(x, y)| (x - y).powi(2)).sum::().sqrt(), - )) + .values() + .to_vec(); + Ok(Some(distance_fn(&a, &b))) } _ => Ok(None), }) } None => unsafe { + // Use T's data type to create a null array of appropriate type + let arrow_data_type = match T::get_dtype() { + DataType::Float32 => ArrowDataType::Float32, + DataType::Float64 => ArrowDataType::Float64, + _ => unreachable!("T must be Float32Type or Float64Type"), + }; + Ok(ChunkedArray::from_chunks( a.name().clone(), - vec![new_null_array(ArrowDataType::Float64, a.len())], + vec![new_null_array(arrow_data_type, a.len())], )) }, }, @@ -206,17 +81,17 @@ pub fn euclidean_dist( } else { let a = a .as_any() - .downcast_ref::>() + .downcast_ref::>() .unwrap() - .values_iter(); + .values() + .to_vec(); let b = b .as_any() - .downcast_ref::>() + .downcast_ref::>() .unwrap() - .values_iter(); - Ok(Some( - a.zip(b).map(|(x, y)| (x - y).powi(2)).sum::().sqrt(), - )) + .values() + .to_vec(); + Ok(Some(distance_fn(&a, &b))) } } _ => Ok(None), @@ -224,162 +99,139 @@ pub fn euclidean_dist( } } -pub fn cosine_dist( +pub fn euclidean_dist( a: &ChunkedArray, b: &ChunkedArray, -) -> PolarsResult { - polars_ensure!( - a.inner_dtype() == b.inner_dtype(), - ComputeError: "inner data types don't match" - ); - polars_ensure!( - a.inner_dtype().is_numeric(), - ComputeError: "inner data types must be numeric" - ); - - let s1 = a.cast(&DataType::Array(Box::new(DataType::Float64), a.width()))?; - let s2 = b.cast(&DataType::Array(Box::new(DataType::Float64), a.width()))?; - - let a: &ArrayChunked = s1.array()?; - let b: &ArrayChunked = s2.array()?; - - let (a, b) = if a.len() < b.len() { (b, a) } else { (a, b) }; - match b.len() { - 1 => match unsafe { b.get_unchecked(0) } { - Some(b_value) => { - if b_value.null_count() > 0 { - polars_bail!(ComputeError: "array cannot contain nulls") - } - try_unary_elementwise(a, |a| match a { - Some(a) => { - if a.null_count() > 0 { - polars_bail!(ComputeError: "array cannot contain nulls") - } - let a = a - .as_any() - .downcast_ref::>() - .unwrap() - .values_iter(); - let b = b_value - .as_any() - .downcast_ref::>() - .unwrap() - .values_iter(); +) -> PolarsResult> +where + T: PolarsFloatType, + T::Native: Float + std::ops::Sub + FromPrimitive + Zero + One, +{ + vector_distance_calc::(a, b, |a_slice, b_slice| { + a_slice + .iter() + .zip(b_slice.iter()) + .map(|(x, y)| (*x - *y).powi(2)) + .sum::() + .sqrt() + }) +} - let dot_prod: f64 = a.clone().zip(b.clone()).map(|(x, y)| x * y).sum(); - let mag1: f64 = a.map(|x| x.powi(2)).sum::().sqrt(); - let mag2: f64 = b.map(|y| y.powi(2)).sum::().sqrt(); +pub fn cosine_dist( + a: &ChunkedArray, + b: &ChunkedArray, +) -> PolarsResult> +where + T: PolarsFloatType, + T::Native: Float + std::ops::Sub + FromPrimitive + Zero + One, +{ + vector_distance_calc::(a, b, |a_slice, b_slice| { + let dot_prod = a_slice + .iter() + .zip(b_slice.iter()) + .map(|(x, y)| *x * *y) + .sum::(); + let mag1 = a_slice + .iter() + .map(|x| x.powi(2)) + .sum::() + .sqrt(); + let mag2 = b_slice + .iter() + .map(|y| y.powi(2)) + .sum::() + .sqrt(); + + if mag1.is_zero() || mag2.is_zero() { + T::Native::zero() + } else { + T::Native::one() - (dot_prod / (mag1 * mag2)) + } + }) +} - let res = if mag1 == 0.0 || mag2 == 0.0 { - 0.0 - } else { - 1.0 - (dot_prod / (mag1 * mag2)) - }; - Ok(Some(res)) - } - _ => Ok(None), - }) - } - None => unsafe { - Ok(ChunkedArray::from_chunks( - a.name().clone(), - vec![new_null_array(ArrowDataType::Float64, a.len())], - )) - }, - }, - _ => try_binary_elementwise(a, b, |a, b| match (a, b) { - (Some(a), Some(b)) => { - if a.null_count() > 0 || b.null_count() > 0 { - polars_bail!(ComputeError: "array cannot contain nulls") - } else { - let a = a - .as_any() - .downcast_ref::>() - .unwrap() - .values_iter(); - let b = b - .as_any() - .downcast_ref::>() - .unwrap() - .values_iter(); +pub fn minkowski_dist( + a: &ChunkedArray, + b: &ChunkedArray, + p: i32, +) -> PolarsResult> +where + T: PolarsFloatType, + T::Native: Float + std::ops::Sub + FromPrimitive + Zero + One, + ::Native: distances::number::Float, +{ + let metric = minkowski(p); + vector_distance_calc::(a, b, metric) +} - let dot_prod: f64 = a.clone().zip(b.clone()).map(|(x, y)| x * y).sum(); - let mag1: f64 = a.map(|x| x.powi(2)).sum::().sqrt(); - let mag2: f64 = b.map(|y| y.powi(2)).sum::().sqrt(); +pub fn chebyshev_dist( + a: &ChunkedArray, + b: &ChunkedArray, +) -> PolarsResult> +where + T: PolarsFloatType, + T::Native: Float + std::ops::Sub + FromPrimitive + Zero + One, + ::Native: distances::Number, +{ + vector_distance_calc::(a, b, chebyshev) +} - let res = if mag1 == 0.0 || mag2 == 0.0 { - 0.0 - } else { - 1.0 - (dot_prod / (mag1 * mag2)) - }; - Ok(Some(res)) - } - } - _ => Ok(None), - }), - } +pub fn canberra_dist( + a: &ChunkedArray, + b: &ChunkedArray, +) -> PolarsResult> +where + T: PolarsFloatType, + T::Native: Float + std::ops::Sub + FromPrimitive + Zero + One, + ::Native: distances::number::Float, +{ + vector_distance_calc::(a, b, canberra) } -pub fn minkowski_dist( +pub fn manhattan_dist( a: &ChunkedArray, b: &ChunkedArray, - p: i32, -) -> PolarsResult { - polars_ensure!( - a.inner_dtype() == b.inner_dtype(), - ComputeError: "inner data types don't match" - ); - polars_ensure!( - a.inner_dtype().is_numeric(), - ComputeError: "inner data types must be numeric" - ); +) -> PolarsResult> +where + T: PolarsFloatType, + T::Native: Float + std::ops::Sub + FromPrimitive + Zero + One, + ::Native: distances::Number, +{ + vector_distance_calc::(a, b, manhattan) +} - let s1 = a.cast(&DataType::Array(Box::new(DataType::Float64), a.width()))?; - let s2 = b.cast(&DataType::Array(Box::new(DataType::Float64), a.width()))?; +pub fn bray_curtis_dist( + a: &ChunkedArray, + b: &ChunkedArray, +) -> PolarsResult> +where + T: PolarsFloatType, + T::Native: Float + std::ops::Sub + FromPrimitive + Zero + One, + ::Native: distances::number::Float, +{ + vector_distance_calc::(a, b, bray_curtis) +} - let a: &ArrayChunked = s1.array()?; - let b: &ArrayChunked = s2.array()?; +pub fn l3_norm_dist( + a: &ChunkedArray, + b: &ChunkedArray, +) -> PolarsResult> +where + T: PolarsFloatType, + T::Native: Float + std::ops::Sub + FromPrimitive + Zero + One, + ::Native: distances::number::Float, +{ + vector_distance_calc::(a, b, l3_norm) +} - // If one side is a literal it will be shorter but is moved to RHS so we can use unsafe access - let (a, b) = if a.len() < b.len() { (b, a) } else { (a, b) }; - match b.len() { - 1 => match unsafe { b.get_unchecked(0) } { - Some(b_value) => { - if b_value.null_count() > 0 { - polars_bail!(ComputeError: "array cannot contain nulls") - } - try_unary_elementwise(a, |a| match a { - Some(a) => { - if a.null_count() > 0 { - polars_bail!(ComputeError: "array cannot contain nulls") - } - let a = &collect_into_vecf64(a); - let b = &collect_into_vecf64(b_value.clone()); - let metric = minkowski(p); - Ok(Some(metric(a, b))) - } - _ => Ok(None), - }) - } - None => unsafe { - Ok(ChunkedArray::from_chunks( - a.name().clone(), - vec![new_null_array(ArrowDataType::Float64, a.len())], - )) - }, - }, - _ => try_binary_elementwise(a, b, |a, b| match (a, b) { - (Some(a), Some(b)) => { - if a.null_count() > 0 || b.null_count() > 0 { - polars_bail!(ComputeError: "array cannot contain nulls") - } else { - let a = &collect_into_vecf64(a); - let b = &collect_into_vecf64(b); - let metric = minkowski(p); - Ok(Some(metric(a, b))) - } - } - _ => Ok(None), - }), - } +pub fn l4_norm_dist( + a: &ChunkedArray, + b: &ChunkedArray, +) -> PolarsResult> +where + T: PolarsFloatType, + T::Native: Float + std::ops::Sub + FromPrimitive + Zero + One, + ::Native: distances::number::Float, +{ + vector_distance_calc::(a, b, l4_norm) } diff --git a/polars_distance/src/expressions.rs b/polars_distance/src/expressions.rs index 9a85680..e1f4eff 100644 --- a/polars_distance/src/expressions.rs +++ b/polars_distance/src/expressions.rs @@ -1,5 +1,5 @@ use crate::array::{ - cosine_dist, distance_calc_numeric_inp, distance_calc_uint_inp, euclidean_dist, minkowski_dist, + bray_curtis_dist, canberra_dist, chebyshev_dist, cosine_dist, euclidean_dist, l3_norm_dist, l4_norm_dist, manhattan_dist, minkowski_dist, }; use crate::list::{ cosine_set_distance, jaccard_index, overlap_coef, sorensen_index, tversky_index, @@ -12,12 +12,12 @@ use crate::string::{ osa_normalized_dist, postfix_dist, postfix_normalized_dist, prefix_dist, prefix_normalized_dist, }; -use distances::vectors::{bray_curtis, canberra, chebyshev, l3_norm, l4_norm, manhattan}; use polars::prelude::*; use polars_arrow::array::new_null_array; use pyo3_polars::derive::polars_expr; use serde::Deserialize; + #[derive(Deserialize)] struct TverskyIndexKwargs { alpha: f64, @@ -86,6 +86,7 @@ fn hamming_str(inputs: &[Series]) -> PolarsResult { Ok(elementwise_str_u32(x, y, hamming_dist).into_series()) } + #[polars_expr(output_type=Float64)] fn hamming_normalized_str(inputs: &[Series]) -> PolarsResult { if inputs[0].dtype() != &DataType::String || inputs[1].dtype() != &DataType::String { @@ -274,37 +275,213 @@ fn prefix_normalized_str(inputs: &[Series]) -> PolarsResult { Ok(elementwise_str_f64(x, y, prefix_normalized_dist).into_series()) } +// General helper for all distance metrics +fn infer_distance_arr_output(input_fields: &[Field], metric_name: &str) -> PolarsResult { + // We expect two input Fields for a binary expression (the two Series). + // The first input_fields[0] must be an Array(F32 or F64, width). + // Similarly input_fields[1].dtype should match. + + if input_fields.len() != 2 { + polars_bail!(ShapeMismatch: "{}_arr expects 2 inputs, got {}", metric_name, input_fields.len()); + } + + // Get the type of first input + let first_type = match &input_fields[0].dtype { + // If the input is an Array with an inner dtype (like Float32 or Float64) + DataType::Array(inner, _width) => &**inner, + dt => { + polars_bail!( + ComputeError: + "{}_arr input must be an Array, got {}", + metric_name, + dt + ); + } + }; + + // Get the type of second input + let second_type = match &input_fields[1].dtype { + DataType::Array(inner, _width) => &**inner, + _ => first_type, // Default to first type if second isn't an array + }; + + // If either input is Float64, output is Float64, otherwise Float32 + match (first_type, second_type) { + (DataType::Float32, DataType::Float32) => { + Ok(Field::new(metric_name.into(), DataType::Float32)) + } + (DataType::Float64, _) | (_, DataType::Float64) => { + Ok(Field::new(metric_name.into(), DataType::Float64)) + } + (DataType::UInt64, _) | (_, DataType::UInt64) => { + Ok(Field::new(metric_name.into(), DataType::Float64)) + } + (DataType::Float32, _) | (_, DataType::Float32) => { + Ok(Field::new(metric_name.into(), DataType::Float32)) + } + _ => { + polars_bail!( + ComputeError: + "{} distance not supported for inner types", + metric_name + ); + } + } +} + +fn infer_euclidean_arr_dtype(input_fields: &[Field]) -> PolarsResult { + infer_distance_arr_output(input_fields, "euclidean") +} + +fn infer_cosine_arr_dtype(input_fields: &[Field]) -> PolarsResult { + infer_distance_arr_output(input_fields, "cosine") +} + +fn infer_chebyshev_arr_dtype(input_fields: &[Field]) -> PolarsResult { + infer_distance_arr_output(input_fields, "chebyshev") +} + +fn infer_canberra_arr_dtype(input_fields: &[Field]) -> PolarsResult { + infer_distance_arr_output(input_fields, "canberra") +} + +fn infer_manhatten_arr_dtype(input_fields: &[Field]) -> PolarsResult { + infer_distance_arr_output(input_fields, "manhatten") +} + +fn infer_l3_norm_arr_dtype(input_fields: &[Field]) -> PolarsResult { + infer_distance_arr_output(input_fields, "l3_norm") +} + +fn infer_l4_norm_arr_dtype(input_fields: &[Field]) -> PolarsResult { + infer_distance_arr_output(input_fields, "l4_norm") +} + +fn infer_bray_curtis_arr_dtype(input_fields: &[Field]) -> PolarsResult { + infer_distance_arr_output(input_fields, "bray_curtis") +} + +fn infer_minkowski_arr_dtype(input_fields: &[Field]) -> PolarsResult { + infer_distance_arr_output(input_fields, "minkowski") +} + +// Function to determine the output data type for distance calculations +pub fn determine_distance_output_type( + x_dtype: &DataType, + y_dtype: &DataType, + distance_name: &str, +) -> PolarsResult { + // If any input is Float64 or UInt64, output is Float64 + if matches!(x_dtype, DataType::Float64 | DataType::UInt64) || + matches!(y_dtype, DataType::Float64 | DataType::UInt64) { + return Ok(DataType::Float64); + } + // Both are Float32, output is Float32 + if matches!(x_dtype, DataType::Float32) && matches!(y_dtype, DataType::Float32) { + return Ok(DataType::Float32); + } + polars_bail!( + ComputeError: + "{} distance not supported for inner dtype: {:?} and {:?}", + distance_name, + x_dtype, + y_dtype + ) +} + +fn compute_array_distance( + x: &ArrayChunked, + y: &ArrayChunked, + distance_name: &str, + f32_impl: F, + f64_impl: G +) -> PolarsResult +where + F: FnOnce(&ArrayChunked, &ArrayChunked) -> PolarsResult, + G: FnOnce(&ArrayChunked, &ArrayChunked) -> PolarsResult +{ + if x.width() != y.width() { + polars_bail!(InvalidOperation: + "The dimensions of each array are not the same. + `x` width: {}, + `y` width: {}", x.width(), y.width()); + } + + // Determine the output type + let output_type = determine_distance_output_type(&x.inner_dtype(), &y.inner_dtype(), distance_name)?; + + match output_type { + DataType::Float32 => { + // Both must be Float32 already (from our type determination logic) + f32_impl(x, y) + }, + DataType::Float64 => { + // Create variables to hold the owned data if we need to cast + let x_cast_result = if !matches!(x.inner_dtype(), DataType::Float64) { + Some(x.cast(&DataType::Array(Box::new(DataType::Float64), x.width()))?) + } else { + None + }; + + let y_cast_result = if !matches!(y.inner_dtype(), DataType::Float64) { + Some(y.cast(&DataType::Array(Box::new(DataType::Float64), y.width()))?) + } else { + None + }; + + // Now get the references, from either the cast result or the original + let x_f64 = if let Some(cast_x) = &x_cast_result { + cast_x.array()? + } else { + x + }; + + let y_f64 = if let Some(cast_y) = &y_cast_result { + cast_y.array()? + } else { + y + }; + + f64_impl(x_f64, y_f64) + }, + _ => unreachable!("Output type should only be Float32 or Float64") + } +} + // ARRAY EXPRESSIONS -#[polars_expr(output_type=Float64)] +#[polars_expr(output_type_func=infer_euclidean_arr_dtype)] fn euclidean_arr(inputs: &[Series]) -> PolarsResult { let x: &ArrayChunked = inputs[0].array()?; let y: &ArrayChunked = inputs[1].array()?; - if x.width() != y.width() { - polars_bail!(InvalidOperation: - "The dimensions of each array are not the same. - `{}` width: {}, - `{}` width: {}", inputs[0].name(), x.width(), inputs[1].name(), y.width()); - } - euclidean_dist(x, y).map(|ca| ca.into_series()) + // Use our reusable function to compute the distance + compute_array_distance( + x, + y, + "euclidean", + |x_f32, y_f32| euclidean_dist::(x_f32, y_f32).map(|ca| ca.into_series()), + |x_f64, y_f64| euclidean_dist::(x_f64, y_f64).map(|ca| ca.into_series()) + ) } -#[polars_expr(output_type=Float64)] +#[polars_expr(output_type_func=infer_cosine_arr_dtype)] fn cosine_arr(inputs: &[Series]) -> PolarsResult { let x: &ArrayChunked = inputs[0].array()?; let y: &ArrayChunked = inputs[1].array()?; - if x.width() != y.width() { - polars_bail!(InvalidOperation: - "The dimensions of each array are not the same. - `{}` width: {}, - `{}` width: {}", inputs[0].name(), x.width(), inputs[1].name(), y.width()); - } - cosine_dist(x, y).map(|ca| ca.into_series()) + + // Use our reusable function to compute the distance + compute_array_distance( + x, + y, + "cosine", + |x_f32, y_f32| cosine_dist::(x_f32, y_f32).map(|ca| ca.into_series()), + |x_f64, y_f64| cosine_dist::(x_f64, y_f64).map(|ca| ca.into_series()) + ) } -#[polars_expr(output_type=Float64)] +#[polars_expr(output_type_func=infer_minkowski_arr_dtype)] fn minkowski_arr(inputs: &[Series], kwargs: MinkowskiKwargs) -> PolarsResult { let x: &ArrayChunked = inputs[0].array()?; let y: &ArrayChunked = inputs[1].array()?; @@ -315,10 +492,18 @@ fn minkowski_arr(inputs: &[Series], kwargs: MinkowskiKwargs) -> PolarsResult(x_f32, y_f32, p).map(|ca| ca.into_series()), + |x_f64, y_f64| minkowski_dist::(x_f64, y_f64, p).map(|ca| ca.into_series()) + ) +} + +#[polars_expr(output_type_func=infer_chebyshev_arr_dtype)] fn chebyshev_arr(inputs: &[Series]) -> PolarsResult { let x: &ArrayChunked = inputs[0].array()?; let y: &ArrayChunked = inputs[1].array()?; @@ -329,10 +514,17 @@ fn chebyshev_arr(inputs: &[Series]) -> PolarsResult { `{}` width: {}, `{}` width: {}", inputs[0].name(), x.width(), inputs[1].name(), y.width()); } - distance_calc_numeric_inp(x, y, chebyshev).map(|ca| ca.into_series()) + + compute_array_distance( + x, + y, + "chebyshev", + |x_f32, y_f32| chebyshev_dist::(x_f32, y_f32).map(|ca| ca.into_series()), + |x_f64, y_f64| chebyshev_dist::(x_f64, y_f64).map(|ca| ca.into_series()) + ) } -#[polars_expr(output_type=Float64)] +#[polars_expr(output_type_func=infer_canberra_arr_dtype)] fn canberra_arr(inputs: &[Series]) -> PolarsResult { let x: &ArrayChunked = inputs[0].array()?; let y: &ArrayChunked = inputs[1].array()?; @@ -343,10 +535,17 @@ fn canberra_arr(inputs: &[Series]) -> PolarsResult { `{}` width: {}, `{}` width: {}", inputs[0].name(), x.width(), inputs[1].name(), y.width()); } - distance_calc_numeric_inp(x, y, canberra).map(|ca| ca.into_series()) + + compute_array_distance( + x, + y, + "canberra", + |x_f32, y_f32| canberra_dist::(x_f32, y_f32).map(|ca| ca.into_series()), + |x_f64, y_f64| canberra_dist::(x_f64, y_f64).map(|ca| ca.into_series()) + ) } -#[polars_expr(output_type=Float64)] +#[polars_expr(output_type_func=infer_manhatten_arr_dtype)] fn manhatten_arr(inputs: &[Series]) -> PolarsResult { let x: &ArrayChunked = inputs[0].array()?; let y: &ArrayChunked = inputs[1].array()?; @@ -358,10 +557,16 @@ fn manhatten_arr(inputs: &[Series]) -> PolarsResult { `{}` width: {}", inputs[0].name(), x.width(), inputs[1].name(), y.width()); } - distance_calc_numeric_inp(x, y, manhattan).map(|ca| ca.into_series()) + compute_array_distance( + x, + y, + "manhatten", + |x_f32, y_f32| manhattan_dist::(x_f32, y_f32).map(|ca| ca.into_series()), + |x_f64, y_f64| manhattan_dist::(x_f64, y_f64).map(|ca| ca.into_series()) + ) } -#[polars_expr(output_type=Float64)] +#[polars_expr(output_type_func=infer_l3_norm_arr_dtype)] fn l3_norm_arr(inputs: &[Series]) -> PolarsResult { let x: &ArrayChunked = inputs[0].array()?; let y: &ArrayChunked = inputs[1].array()?; @@ -373,10 +578,16 @@ fn l3_norm_arr(inputs: &[Series]) -> PolarsResult { `{}` width: {}", inputs[0].name(), x.width(), inputs[1].name(), y.width()); } - distance_calc_numeric_inp(x, y, l3_norm).map(|ca| ca.into_series()) + compute_array_distance( + x, + y, + "l3_norm", + |x_f32, y_f32| l3_norm_dist::(x_f32, y_f32).map(|ca| ca.into_series()), + |x_f64, y_f64| l3_norm_dist::(x_f64, y_f64).map(|ca| ca.into_series()) + ) } -#[polars_expr(output_type=Float64)] +#[polars_expr(output_type_func=infer_l4_norm_arr_dtype)] fn l4_norm_arr(inputs: &[Series]) -> PolarsResult { let x: &ArrayChunked = inputs[0].array()?; let y: &ArrayChunked = inputs[1].array()?; @@ -388,10 +599,16 @@ fn l4_norm_arr(inputs: &[Series]) -> PolarsResult { `{}` width: {}", inputs[0].name(), x.width(), inputs[1].name(), y.width()); } - distance_calc_numeric_inp(x, y, l4_norm).map(|ca| ca.into_series()) + compute_array_distance( + x, + y, + "l4_norm", + |x_f32, y_f32| l4_norm_dist::(x_f32, y_f32).map(|ca| ca.into_series()), + |x_f64, y_f64| l4_norm_dist::(x_f64, y_f64).map(|ca| ca.into_series()) + ) } -#[polars_expr(output_type=Float64)] +#[polars_expr(output_type_func=infer_bray_curtis_arr_dtype)] fn bray_curtis_arr(inputs: &[Series]) -> PolarsResult { let x: &ArrayChunked = inputs[0].array()?; let y: &ArrayChunked = inputs[1].array()?; @@ -403,7 +620,13 @@ fn bray_curtis_arr(inputs: &[Series]) -> PolarsResult { `{}` width: {}", inputs[0].name(), x.width(), inputs[1].name(), y.width()); } - distance_calc_uint_inp(x, y, bray_curtis).map(|ca| ca.into_series()) + compute_array_distance( + x, + y, + "bray_curtis", + |x_f32, y_f32| bray_curtis_dist::(x_f32, y_f32).map(|ca| ca.into_series()), + |x_f64, y_f64| bray_curtis_dist::(x_f64, y_f64).map(|ca| ca.into_series()) + ) } // SET (list) EXPRESSIONS diff --git a/polars_distance/tests/test_distance_arr.py b/polars_distance/tests/test_distance_arr.py index 157f119..e82c317 100644 --- a/polars_distance/tests/test_distance_arr.py +++ b/polars_distance/tests/test_distance_arr.py @@ -1,6 +1,6 @@ -import pytest import polars as pl import polars_distance as pld +import pytest from polars.testing import assert_frame_equal @@ -10,12 +10,16 @@ def data(): { "arr": [[1.0, 2.0, 3.0, 4.0]], "arr2": [[10.0, 8.0, 5.0, 3.0]], + "arr3": [[1.0, 2.0, 3.0, 4.0]], + "arr4": [[10.0, 8.0, 5.0, 3.0]], "str_l": ["hello world"], "str_r": ["hela wrld"], }, schema={ "arr": pl.Array(inner=pl.Float64, shape=4), "arr2": pl.Array(inner=pl.Float64, shape=4), + "arr3": pl.Array(inner=pl.Float32, shape=4), + "arr4": pl.Array(inner=pl.Float32, shape=4), "str_l": pl.Utf8, "str_r": pl.Utf8, }, @@ -37,11 +41,13 @@ def data_sets(): def test_cosine(data): result = data.select( pld.col("arr").dist_arr.cosine("arr2").alias("dist_cosine"), + pld.col("arr3").dist_arr.cosine("arr4").alias("dist_cosine_f32"), ) expected = pl.DataFrame( [ pl.Series("dist_cosine", [0.31232593265732134], dtype=pl.Float64), + pl.Series("dist_cosine_f32", [0.31232593265732134], dtype=pl.Float32), ] ) @@ -51,11 +57,13 @@ def test_cosine(data): def test_chebyshev(data): result = data.select( pld.col("arr").dist_arr.chebyshev("arr2").alias("dist_chebyshev"), + pld.col("arr3").dist_arr.chebyshev("arr4").alias("dist_chebyshev_f32"), ) expected = pl.DataFrame( [ pl.Series("dist_chebyshev", [9.0], dtype=pl.Float64), + pl.Series("dist_chebyshev_f32", [9.0], dtype=pl.Float32), ] ) @@ -65,11 +73,13 @@ def test_chebyshev(data): def test_canberra(data): result = data.select( pld.col("arr").dist_arr.canberra("arr2").alias("dist_canberra"), + pld.col("arr3").dist_arr.canberra("arr4").alias("dist_canberra_f32"), ) expected = pl.DataFrame( [ pl.Series("dist_canberra", [1.811038961038961], dtype=pl.Float64), + pl.Series("dist_canberra_f32", [1.811038961038961], dtype=pl.Float32), ] ) @@ -82,11 +92,16 @@ def test_bray_curtis(data): .cast(pl.Array(pl.UInt64, 4)) .dist_arr.bray_curtis(pl.col("arr2").cast(pl.Array(pl.UInt64, 4))) .alias("dist_bray"), + pld.col("arr3") + .cast(pl.Array(pl.UInt64, 4)) + .dist_arr.bray_curtis(pl.col("arr4").cast(pl.Array(pl.UInt64, 4))) + .alias("dist_bray_f32"), ) expected = pl.DataFrame( [ pl.Series("dist_bray", [0.5], dtype=pl.Float64), + pl.Series("dist_bray_f32", [0.5], dtype=pl.Float64), ] ) @@ -96,11 +111,13 @@ def test_bray_curtis(data): def test_manhatten(data): result = data.select( pld.col("arr").dist_arr.manhatten("arr2").alias("dist_manhatten"), + pld.col("arr3").dist_arr.manhatten("arr4").alias("dist_manhatten_f32"), ) expected = pl.DataFrame( [ pl.Series("dist_manhatten", [18.0], dtype=pl.Float64), + pl.Series("dist_manhatten_f32", [18.0], dtype=pl.Float32), ] ) @@ -110,17 +127,30 @@ def test_manhatten(data): def test_euclidean(data): result = data.select( pld.col("arr").dist_arr.euclidean("arr2").alias("dist_euclidean"), + pld.col("arr3").dist_arr.euclidean("arr4").alias("dist_euclidean_f32"), ) expected = pl.DataFrame( [ pl.Series("dist_euclidean", [11.045361017187261], dtype=pl.Float64), + pl.Series("dist_euclidean_f32", [11.045361017187261], dtype=pl.Float32), ] ) assert_frame_equal(result, expected) +def test_float32_dtype_mixing(data): + """Mixing types should follow Polars casting rules (higher precision wins).""" + result_mixed = data.select( + pl.col("arr2").dist_arr.euclidean("arr3").alias("dist_mixed"), + pl.col("arr3").dist_arr.euclidean("arr2").alias("dist_mixed_rev"), + ) + + assert result_mixed["dist_mixed"].dtype == pl.Float64 + assert result_mixed["dist_mixed_rev"].dtype == pl.Float64 + + def test_hamming_str(data): result = data.select( pld.col("str_l").dist_str.hamming("str_r").alias("dist_hamming"),