diff --git a/polars_distance/Cargo.toml b/polars_distance/Cargo.toml index 1b8d6ab..9284626 100644 --- a/polars_distance/Cargo.toml +++ b/polars_distance/Cargo.toml @@ -17,6 +17,7 @@ serde = { version = "1", features = ["derive"] } distances = { version = "1.6.3"} rapidfuzz = { version = "0.5.0"} gestalt_ratio = { version = "0.2.1"} +num-traits = { version = "0.2" } [target.'cfg(target_os = "linux")'.dependencies] jemallocator = { version = "0.5", features = ["disable_initial_exec_tls"] } diff --git a/polars_distance/pyproject.toml b/polars_distance/pyproject.toml index 3d0d138..e4c2817 100644 --- a/polars_distance/pyproject.toml +++ b/polars_distance/pyproject.toml @@ -28,4 +28,4 @@ readme = "README.md" [tool.maturin] -module-name = "polars_distance._internal" \ No newline at end of file +module-name = "polars_distance._internal" diff --git a/polars_distance/src/array.rs b/polars_distance/src/array.rs index a96432b..679f38e 100644 --- a/polars_distance/src/array.rs +++ b/polars_distance/src/array.rs @@ -2,6 +2,8 @@ use distances::vectors::minkowski; use polars::prelude::arity::{try_binary_elementwise, try_unary_elementwise}; use polars::prelude::*; use polars_arrow::array::{new_null_array, Array, PrimitiveArray}; +use num_traits::{Zero, One, Float, FromPrimitive}; + fn collect_into_vecf64(arr: Box) -> Vec { arr.as_any() @@ -144,10 +146,14 @@ pub fn distance_calc_uint_inp( } } -pub fn euclidean_dist( +pub fn euclidean_dist( a: &ChunkedArray, b: &ChunkedArray, -) -> PolarsResult { +) -> PolarsResult> +where + T: PolarsFloatType, + T::Native: Float + std::ops::Sub + FromPrimitive + Zero + One, +{ polars_ensure!( a.inner_dtype() == b.inner_dtype(), ComputeError: "inner data types don't match" @@ -157,8 +163,9 @@ pub fn euclidean_dist( ComputeError: "inner data types must be numeric" ); - let s1 = a.cast(&DataType::Array(Box::new(DataType::Float64), a.width()))?; - let s2 = b.cast(&DataType::Array(Box::new(DataType::Float64), a.width()))?; + // Cast to the target float type + let s1 = a.cast(&DataType::Array(Box::new(T::get_dtype()), a.width()))?; + let s2 = b.cast(&DataType::Array(Box::new(T::get_dtype()), a.width()))?; let a: &ArrayChunked = s1.array()?; let b: &ArrayChunked = s2.array()?; @@ -177,25 +184,32 @@ pub fn euclidean_dist( } let a = a .as_any() - .downcast_ref::>() + .downcast_ref::>() .unwrap() .values_iter(); let b = b_value .as_any() - .downcast_ref::>() + .downcast_ref::>() .unwrap() .values_iter(); Ok(Some( - a.zip(b).map(|(x, y)| (x - y).powi(2)).sum::().sqrt(), + a.zip(b).map(|(x, y)| (*x - *y).powi(2)).sum::().sqrt(), )) } _ => Ok(None), }) } None => unsafe { + // Use T's data type to create a null array of appropriate type + let arrow_data_type = match T::get_dtype() { + DataType::Float32 => ArrowDataType::Float32, + DataType::Float64 => ArrowDataType::Float64, + _ => unreachable!("T must be Float32Type or Float64Type"), + }; + Ok(ChunkedArray::from_chunks( a.name().clone(), - vec![new_null_array(ArrowDataType::Float64, a.len())], + vec![new_null_array(arrow_data_type, a.len())], )) }, }, @@ -206,16 +220,16 @@ pub fn euclidean_dist( } else { let a = a .as_any() - .downcast_ref::>() + .downcast_ref::>() .unwrap() .values_iter(); let b = b .as_any() - .downcast_ref::>() + .downcast_ref::>() .unwrap() .values_iter(); Ok(Some( - a.zip(b).map(|(x, y)| (x - y).powi(2)).sum::().sqrt(), + a.zip(b).map(|(x, y)| (*x - *y).powi(2)).sum::().sqrt(), )) } } @@ -224,10 +238,14 @@ pub fn euclidean_dist( } } -pub fn cosine_dist( +pub fn cosine_dist( a: &ChunkedArray, b: &ChunkedArray, -) -> PolarsResult { +) -> PolarsResult> +where + T: PolarsFloatType, + T::Native: Float + std::ops::Sub + FromPrimitive + Zero + One, +{ polars_ensure!( a.inner_dtype() == b.inner_dtype(), ComputeError: "inner data types don't match" @@ -237,8 +255,9 @@ pub fn cosine_dist( ComputeError: "inner data types must be numeric" ); - let s1 = a.cast(&DataType::Array(Box::new(DataType::Float64), a.width()))?; - let s2 = b.cast(&DataType::Array(Box::new(DataType::Float64), a.width()))?; + // Cast to the target float type + let s1 = a.cast(&DataType::Array(Box::new(T::get_dtype()), a.width()))?; + let s2 = b.cast(&DataType::Array(Box::new(T::get_dtype()), a.width()))?; let a: &ArrayChunked = s1.array()?; let b: &ArrayChunked = s2.array()?; @@ -250,30 +269,33 @@ pub fn cosine_dist( if b_value.null_count() > 0 { polars_bail!(ComputeError: "array cannot contain nulls") } - try_unary_elementwise(a, |a| match a { + arity::try_unary_elementwise(a, |a| match a { Some(a) => { if a.null_count() > 0 { polars_bail!(ComputeError: "array cannot contain nulls") } let a = a .as_any() - .downcast_ref::>() + .downcast_ref::>() .unwrap() .values_iter(); let b = b_value .as_any() - .downcast_ref::>() + .downcast_ref::>() .unwrap() .values_iter(); - let dot_prod: f64 = a.clone().zip(b.clone()).map(|(x, y)| x * y).sum(); - let mag1: f64 = a.map(|x| x.powi(2)).sum::().sqrt(); - let mag2: f64 = b.map(|y| y.powi(2)).sum::().sqrt(); + let dot_prod: T::Native = a.clone().zip(b.clone()).map(|(x, y)| *x * *y).sum(); + let mag1 = a.map(|x| x.powi(2)).sum::().sqrt(); + let mag2 = b.map(|y| y.powi(2)).sum::().sqrt(); - let res = if mag1 == 0.0 || mag2 == 0.0 { - 0.0 + let zero = T::Native::zero(); + let one = T::Native::one(); + + let res = if mag1 == zero || mag2 == zero { + zero } else { - 1.0 - (dot_prod / (mag1 * mag2)) + one - (dot_prod / (mag1 * mag2)) }; Ok(Some(res)) } @@ -281,36 +303,46 @@ pub fn cosine_dist( }) } None => unsafe { + // Use T's data type to create a null array of appropriate type + let arrow_data_type = match T::get_dtype() { + DataType::Float32 => ArrowDataType::Float32, + DataType::Float64 => ArrowDataType::Float64, + _ => unreachable!("T must be Float32Type or Float64Type"), + }; + Ok(ChunkedArray::from_chunks( a.name().clone(), - vec![new_null_array(ArrowDataType::Float64, a.len())], + vec![new_null_array(arrow_data_type, a.len())], )) }, }, - _ => try_binary_elementwise(a, b, |a, b| match (a, b) { + _ => arity::try_binary_elementwise(a, b, |a, b| match (a, b) { (Some(a), Some(b)) => { if a.null_count() > 0 || b.null_count() > 0 { polars_bail!(ComputeError: "array cannot contain nulls") } else { let a = a .as_any() - .downcast_ref::>() + .downcast_ref::>() .unwrap() .values_iter(); let b = b .as_any() - .downcast_ref::>() + .downcast_ref::>() .unwrap() .values_iter(); - let dot_prod: f64 = a.clone().zip(b.clone()).map(|(x, y)| x * y).sum(); - let mag1: f64 = a.map(|x| x.powi(2)).sum::().sqrt(); - let mag2: f64 = b.map(|y| y.powi(2)).sum::().sqrt(); + let dot_prod: T::Native = a.clone().zip(b.clone()).map(|(x, y)| *x * *y).sum(); + let mag1 = a.map(|x| x.powi(2)).sum::().sqrt(); + let mag2 = b.map(|y| y.powi(2)).sum::().sqrt(); - let res = if mag1 == 0.0 || mag2 == 0.0 { - 0.0 + let zero = T::Native::zero(); + let one = T::Native::one(); + + let res = if mag1 == zero || mag2 == zero { + zero } else { - 1.0 - (dot_prod / (mag1 * mag2)) + one - (dot_prod / (mag1 * mag2)) }; Ok(Some(res)) } @@ -320,11 +352,16 @@ pub fn cosine_dist( } } -pub fn minkowski_dist( +pub fn vector_distance_calc( a: &ChunkedArray, b: &ChunkedArray, - p: i32, -) -> PolarsResult { + distance_fn: F, +) -> PolarsResult> +where + T: PolarsFloatType, + T::Native: Float + std::ops::Sub + FromPrimitive + Zero + One, + F: Fn(&[T::Native], &[T::Native]) -> T::Native, +{ polars_ensure!( a.inner_dtype() == b.inner_dtype(), ComputeError: "inner data types don't match" @@ -334,8 +371,9 @@ pub fn minkowski_dist( ComputeError: "inner data types must be numeric" ); - let s1 = a.cast(&DataType::Array(Box::new(DataType::Float64), a.width()))?; - let s2 = b.cast(&DataType::Array(Box::new(DataType::Float64), a.width()))?; + // Cast to the target float type + let s1 = a.cast(&DataType::Array(Box::new(T::get_dtype()), a.width()))?; + let s2 = b.cast(&DataType::Array(Box::new(T::get_dtype()), a.width()))?; let a: &ArrayChunked = s1.array()?; let b: &ArrayChunked = s2.array()?; @@ -348,38 +386,202 @@ pub fn minkowski_dist( if b_value.null_count() > 0 { polars_bail!(ComputeError: "array cannot contain nulls") } - try_unary_elementwise(a, |a| match a { + arity::try_unary_elementwise(a, |a| match a { Some(a) => { if a.null_count() > 0 { polars_bail!(ComputeError: "array cannot contain nulls") } - let a = &collect_into_vecf64(a); - let b = &collect_into_vecf64(b_value.clone()); - let metric = minkowski(p); - Ok(Some(metric(a, b))) + let a = a + .as_any() + .downcast_ref::>() + .unwrap() + .values() + .to_vec(); + let b = b_value + .as_any() + .downcast_ref::>() + .unwrap() + .values() + .to_vec(); + Ok(Some(distance_fn(&a, &b))) } _ => Ok(None), }) } None => unsafe { + // Use T's data type to create a null array of appropriate type + let arrow_data_type = match T::get_dtype() { + DataType::Float32 => ArrowDataType::Float32, + DataType::Float64 => ArrowDataType::Float64, + _ => unreachable!("T must be Float32Type or Float64Type"), + }; + Ok(ChunkedArray::from_chunks( a.name().clone(), - vec![new_null_array(ArrowDataType::Float64, a.len())], + vec![new_null_array(arrow_data_type, a.len())], )) }, }, - _ => try_binary_elementwise(a, b, |a, b| match (a, b) { + _ => arity::try_binary_elementwise(a, b, |a, b| match (a, b) { (Some(a), Some(b)) => { if a.null_count() > 0 || b.null_count() > 0 { polars_bail!(ComputeError: "array cannot contain nulls") } else { - let a = &collect_into_vecf64(a); - let b = &collect_into_vecf64(b); - let metric = minkowski(p); - Ok(Some(metric(a, b))) + let a = a + .as_any() + .downcast_ref::>() + .unwrap() + .values() + .to_vec(); + let b = b + .as_any() + .downcast_ref::>() + .unwrap() + .values() + .to_vec(); + Ok(Some(distance_fn(&a, &b))) } } _ => Ok(None), }), } } + +pub fn minkowski_dist( + a: &ChunkedArray, + b: &ChunkedArray, + p: i32, +) -> PolarsResult> +where + T: PolarsFloatType, + T::Native: Float + std::ops::Sub + FromPrimitive + Zero + One, +{ + let p_float = T::Native::from_f64(p as f64).unwrap(); + let inv_p = T::Native::one() / p_float; + + vector_distance_calc::(a, b, move |a, b| { + let sum: T::Native = a.iter() + .zip(b.iter()) + .map(|(x, y)| (*x - *y).abs().powf(p_float)) + .sum(); + sum.powf(inv_p) + }) +} + +pub fn chebyshev_dist( + a: &ChunkedArray, + b: &ChunkedArray, +) -> PolarsResult> +where + T: PolarsFloatType, + T::Native: Float + std::ops::Sub + FromPrimitive + Zero + One, +{ + vector_distance_calc::(a, b, |a, b| { + a.iter() + .zip(b.iter()) + .map(|(x, y)| (*x - *y).abs()) + .fold(T::Native::zero(), |max, val| if val > max { val } else { max }) + }) +} + +pub fn canberra_dist( + a: &ChunkedArray, + b: &ChunkedArray, +) -> PolarsResult> +where + T: PolarsFloatType, + T::Native: Float + std::ops::Sub + FromPrimitive + Zero + One, +{ + vector_distance_calc::(a, b, |a, b| { + a.iter() + .zip(b.iter()) + .map(|(x, y)| { + let abs_diff = (*x - *y).abs(); + let abs_sum = x.abs() + y.abs(); + if abs_sum > T::Native::zero() { + abs_diff / abs_sum + } else { + T::Native::zero() + } + }) + .sum() + }) +} + +pub fn manhattan_dist( + a: &ChunkedArray, + b: &ChunkedArray, +) -> PolarsResult> +where + T: PolarsFloatType, + T::Native: Float + std::ops::Sub + FromPrimitive + Zero + One, +{ + vector_distance_calc::(a, b, |a, b| { + a.iter() + .zip(b.iter()) + .map(|(x, y)| (*x - *y).abs()) + .sum() + }) +} + +pub fn bray_curtis_dist( + a: &ChunkedArray, + b: &ChunkedArray, +) -> PolarsResult> +where + T: PolarsFloatType, + T::Native: Float + std::ops::Sub + FromPrimitive + Zero + One, +{ + vector_distance_calc::(a, b, |a, b| { + let sum_abs_diff: T::Native = a.iter() + .zip(b.iter()) + .map(|(x, y)| (*x - *y).abs()) + .sum(); + let sum_abs_sum: T::Native = a.iter() + .zip(b.iter()) + .map(|(x, y)| x.abs() + y.abs()) + .sum(); + + if sum_abs_sum > T::Native::zero() { + sum_abs_diff / sum_abs_sum + } else { + T::Native::zero() + } + }) +} + +pub fn l3_norm_dist( + a: &ChunkedArray, + b: &ChunkedArray, +) -> PolarsResult> +where + T: PolarsFloatType, + T::Native: Float + std::ops::Sub + FromPrimitive + Zero + One, +{ + vector_distance_calc::(a, b, |a, b| { + let sum: T::Native = a.iter() + .zip(b.iter()) + .map(|(x, y)| (*x - *y).abs().powi(3)) + .sum(); + let one_third = T::Native::from_f64(1.0/3.0).unwrap(); + sum.powf(one_third) + }) +} + +pub fn l4_norm_dist( + a: &ChunkedArray, + b: &ChunkedArray, +) -> PolarsResult> +where + T: PolarsFloatType, + T::Native: Float + std::ops::Sub + FromPrimitive + Zero + One, +{ + vector_distance_calc::(a, b, |a, b| { + let sum: T::Native = a.iter() + .zip(b.iter()) + .map(|(x, y)| (*x - *y).abs().powi(4)) + .sum(); + let one_fourth = T::Native::from_f64(0.25).unwrap(); + sum.powf(one_fourth) + }) +} diff --git a/polars_distance/src/expressions.rs b/polars_distance/src/expressions.rs index 9a85680..be03afe 100644 --- a/polars_distance/src/expressions.rs +++ b/polars_distance/src/expressions.rs @@ -18,6 +18,7 @@ use polars_arrow::array::new_null_array; use pyo3_polars::derive::polars_expr; use serde::Deserialize; + #[derive(Deserialize)] struct TverskyIndexKwargs { alpha: f64, @@ -86,6 +87,7 @@ fn hamming_str(inputs: &[Series]) -> PolarsResult { Ok(elementwise_str_u32(x, y, hamming_dist).into_series()) } + #[polars_expr(output_type=Float64)] fn hamming_normalized_str(inputs: &[Series]) -> PolarsResult { if inputs[0].dtype() != &DataType::String || inputs[1].dtype() != &DataType::String { @@ -274,37 +276,244 @@ fn prefix_normalized_str(inputs: &[Series]) -> PolarsResult { Ok(elementwise_str_f64(x, y, prefix_normalized_dist).into_series()) } +// General helper for all distance metrics +fn infer_distance_arr_output(input_fields: &[Field], metric_name: &str) -> PolarsResult { + // We expect two input Fields for a binary expression (the two Series). + // The first input_fields[0] must be an Array(F32 or F64, width). + // Similarly input_fields[1].dtype should match. + + if input_fields.len() != 2 { + polars_bail!(ShapeMismatch: "{}_arr expects 2 inputs, got {}", metric_name, input_fields.len()); + } + + // Get the type of first input + let first_type = match &input_fields[0].dtype { + // If the input is an Array with an inner dtype (like Float32 or Float64) + DataType::Array(inner, _width) => &**inner, + dt => { + polars_bail!( + ComputeError: + "{}_arr input must be an Array, got {}", + metric_name, + dt + ); + } + }; + + // Get the type of second input + let second_type = match &input_fields[1].dtype { + DataType::Array(inner, _width) => &**inner, + _ => first_type, // Default to first type if second isn't an array + }; + + // If either input is Float64, output is Float64, otherwise Float32 + match (first_type, second_type) { + (DataType::Float32, DataType::Float32) => { + Ok(Field::new(metric_name.into(), DataType::Float32)) + } + (DataType::Float64, _) | (_, DataType::Float64) => { + Ok(Field::new(metric_name.into(), DataType::Float64)) + } + (DataType::UInt64, _) | (_, DataType::UInt64) => { + Ok(Field::new(metric_name.into(), DataType::Float64)) + } + (DataType::Float32, _) | (_, DataType::Float32) => { + Ok(Field::new(metric_name.into(), DataType::Float32)) + } + _ => { + polars_bail!( + ComputeError: + "{} distance not supported for inner types", + metric_name + ); + } + } +} + +fn infer_euclidean_arr_dtype(input_fields: &[Field]) -> PolarsResult { + infer_distance_arr_output(input_fields, "euclidean") +} + +fn infer_cosine_arr_dtype(input_fields: &[Field]) -> PolarsResult { + infer_distance_arr_output(input_fields, "cosine") +} + +fn infer_chebyshev_arr_dtype(input_fields: &[Field]) -> PolarsResult { + infer_distance_arr_output(input_fields, "chebyshev") +} + +fn infer_canberra_arr_dtype(input_fields: &[Field]) -> PolarsResult { + infer_distance_arr_output(input_fields, "canberra") +} + +fn infer_manhatten_arr_dtype(input_fields: &[Field]) -> PolarsResult { + infer_distance_arr_output(input_fields, "manhatten") +} + +fn infer_l3_norm_arr_dtype(input_fields: &[Field]) -> PolarsResult { + infer_distance_arr_output(input_fields, "l3_norm") +} + +fn infer_l4_norm_arr_dtype(input_fields: &[Field]) -> PolarsResult { + infer_distance_arr_output(input_fields, "l4_norm") +} + +fn infer_bray_curtis_arr_dtype(input_fields: &[Field]) -> PolarsResult { + infer_distance_arr_output(input_fields, "bray_curtis") +} + +fn infer_minkowski_arr_dtype(input_fields: &[Field]) -> PolarsResult { + infer_distance_arr_output(input_fields, "minkowski") +} + +// Function to determine the output data type for distance calculations +pub fn determine_distance_output_type( + x_dtype: &DataType, + y_dtype: &DataType, + distance_name: &str, +) -> PolarsResult { + match (x_dtype, y_dtype) { + // Both Float32 - leave as Float32 + (DataType::Float32, DataType::Float32) => Ok(DataType::Float32), + + // Mix of Float32 and Float64 - Float64 wins + (DataType::Float32, DataType::Float64) => Ok(DataType::Float64), + (DataType::Float64, DataType::Float32) => Ok(DataType::Float64), + (DataType::Float64, DataType::Float64) => Ok(DataType::Float64), + + + // UInt64 and either Float32/Float64 - convert to Float64 + (DataType::UInt64, DataType::UInt64) => Ok(DataType::Float64), + (DataType::UInt64, DataType::Float32) => Ok(DataType::Float64), + (DataType::UInt64, DataType::Float64) => Ok(DataType::Float64), + (DataType::Float32, DataType::UInt64) => Ok(DataType::Float64), + (DataType::Float64, DataType::UInt64) => Ok(DataType::Float64), + + (other, _) => polars_bail!( + ComputeError: + "{} distance not supported for inner dtype: {}", + distance_name, + other + ), + } +} + +pub fn compute_array_distance( + x: &ArrayChunked, + y: &ArrayChunked, + distance_name: &str, + f32_impl: F, + f64_impl: G +) -> PolarsResult +where + F: FnOnce(&ArrayChunked, &ArrayChunked) -> PolarsResult, + G: FnOnce(&ArrayChunked, &ArrayChunked) -> PolarsResult +{ + if x.width() != y.width() { + polars_bail!(InvalidOperation: + "The dimensions of each array are not the same. + `x` width: {}, + `y` width: {}", x.width(), y.width()); + } + + + // Determine the output type + let output_type = determine_distance_output_type(&x.inner_dtype(), &y.inner_dtype(), distance_name)?; + + // Handle the type casting and calculation + match (x.inner_dtype(), y.inner_dtype(), &output_type) { + // Float32 unmixed case - keep as Float32 + // Both Float32 -> Float32 output + (DataType::Float32, DataType::Float32, DataType::Float32) => { + f32_impl(x, y) + }, + // ---------------------------------------------------------- + // Float64 un/mixed cases - cast to Float64 + // Both Float64 -> Float64 output + (DataType::Float64, DataType::Float64, DataType::Float64) => { + f64_impl(x, y) + }, + // Float32 and Float64 -> cast Float32 to Float64, Float64 output + (DataType::Float32, DataType::Float64, DataType::Float64) => { + // Cast x to Float64 + let x_f64 = x.cast(&DataType::Array(Box::new(DataType::Float64), x.width()))?; + f64_impl(x_f64.array()?, y) + }, + // Float64 and Float32 -> cast Float32 to Float64, Float64 output + (DataType::Float64, DataType::Float32, DataType::Float64) => { + // Cast y to Float64 + let y_f64 = y.cast(&DataType::Array(Box::new(DataType::Float64), y.width()))?; + f64_impl(x, y_f64.array()?) + }, + // ---------------------------------------------------------- + // UInt64 un/mixed cases - cast to Float64 + (DataType::UInt64, DataType::UInt64, DataType::Float64) => { + // Cast both to Float64 + let x_f64 = x.cast(&DataType::Array(Box::new(DataType::Float64), x.width()))?; + let y_f64 = y.cast(&DataType::Array(Box::new(DataType::Float64), y.width()))?; + f64_impl(x_f64.array()?, y_f64.array()?) + }, + (DataType::UInt64, DataType::Float32, DataType::Float64) => { + // Cast both to Float64 + let x_f64 = x.cast(&DataType::Array(Box::new(DataType::Float64), x.width()))?; + let y_f64 = y.cast(&DataType::Array(Box::new(DataType::Float64), y.width()))?; + f64_impl(x_f64.array()?, y_f64.array()?) + }, + (DataType::UInt64, DataType::Float64, DataType::Float64) => { + // Cast x to Float64 + let x_f64 = x.cast(&DataType::Array(Box::new(DataType::Float64), x.width()))?; + f64_impl(x_f64.array()?, y) + }, + (DataType::Float32, DataType::UInt64, DataType::Float64) => { + // Cast both to Float64 + let x_f64 = x.cast(&DataType::Array(Box::new(DataType::Float64), x.width()))?; + let y_f64 = y.cast(&DataType::Array(Box::new(DataType::Float64), y.width()))?; + f64_impl(x_f64.array()?, y_f64.array()?) + }, + (DataType::Float64, DataType::UInt64, DataType::Float64) => { + // Cast y to Float64 + let y_f64 = y.cast(&DataType::Array(Box::new(DataType::Float64), y.width()))?; + f64_impl(x, y_f64.array()?) + }, + // This should be unreachable since we've filtered the types above + _ => unreachable!() + } +} + // ARRAY EXPRESSIONS -#[polars_expr(output_type=Float64)] +#[polars_expr(output_type_func=infer_euclidean_arr_dtype)] fn euclidean_arr(inputs: &[Series]) -> PolarsResult { let x: &ArrayChunked = inputs[0].array()?; let y: &ArrayChunked = inputs[1].array()?; - if x.width() != y.width() { - polars_bail!(InvalidOperation: - "The dimensions of each array are not the same. - `{}` width: {}, - `{}` width: {}", inputs[0].name(), x.width(), inputs[1].name(), y.width()); - } - euclidean_dist(x, y).map(|ca| ca.into_series()) + // Use our reusable function to compute the distance + compute_array_distance( + x, + y, + "euclidean", + |x_f32, y_f32| crate::array::euclidean_dist::(x_f32, y_f32).map(|ca| ca.into_series()), + |x_f64, y_f64| crate::array::euclidean_dist::(x_f64, y_f64).map(|ca| ca.into_series()) + ) } -#[polars_expr(output_type=Float64)] +#[polars_expr(output_type_func=infer_cosine_arr_dtype)] fn cosine_arr(inputs: &[Series]) -> PolarsResult { let x: &ArrayChunked = inputs[0].array()?; let y: &ArrayChunked = inputs[1].array()?; - if x.width() != y.width() { - polars_bail!(InvalidOperation: - "The dimensions of each array are not the same. - `{}` width: {}, - `{}` width: {}", inputs[0].name(), x.width(), inputs[1].name(), y.width()); - } - cosine_dist(x, y).map(|ca| ca.into_series()) + + // Use our reusable function to compute the distance + compute_array_distance( + x, + y, + "cosine", + |x_f32, y_f32| crate::array::cosine_dist::(x_f32, y_f32).map(|ca| ca.into_series()), + |x_f64, y_f64| crate::array::cosine_dist::(x_f64, y_f64).map(|ca| ca.into_series()) + ) } -#[polars_expr(output_type=Float64)] +#[polars_expr(output_type_func=infer_minkowski_arr_dtype)] fn minkowski_arr(inputs: &[Series], kwargs: MinkowskiKwargs) -> PolarsResult { let x: &ArrayChunked = inputs[0].array()?; let y: &ArrayChunked = inputs[1].array()?; @@ -315,10 +524,18 @@ fn minkowski_arr(inputs: &[Series], kwargs: MinkowskiKwargs) -> PolarsResult(x_f32, y_f32, p).map(|ca| ca.into_series()), + |x_f64, y_f64| crate::array::minkowski_dist::(x_f64, y_f64, p).map(|ca| ca.into_series()) + ) +} + +#[polars_expr(output_type_func=infer_chebyshev_arr_dtype)] fn chebyshev_arr(inputs: &[Series]) -> PolarsResult { let x: &ArrayChunked = inputs[0].array()?; let y: &ArrayChunked = inputs[1].array()?; @@ -329,10 +546,17 @@ fn chebyshev_arr(inputs: &[Series]) -> PolarsResult { `{}` width: {}, `{}` width: {}", inputs[0].name(), x.width(), inputs[1].name(), y.width()); } - distance_calc_numeric_inp(x, y, chebyshev).map(|ca| ca.into_series()) + + compute_array_distance( + x, + y, + "chebyshev", + |x_f32, y_f32| crate::array::chebyshev_dist::(x_f32, y_f32).map(|ca| ca.into_series()), + |x_f64, y_f64| crate::array::chebyshev_dist::(x_f64, y_f64).map(|ca| ca.into_series()) + ) } -#[polars_expr(output_type=Float64)] +#[polars_expr(output_type_func=infer_canberra_arr_dtype)] fn canberra_arr(inputs: &[Series]) -> PolarsResult { let x: &ArrayChunked = inputs[0].array()?; let y: &ArrayChunked = inputs[1].array()?; @@ -343,10 +567,17 @@ fn canberra_arr(inputs: &[Series]) -> PolarsResult { `{}` width: {}, `{}` width: {}", inputs[0].name(), x.width(), inputs[1].name(), y.width()); } - distance_calc_numeric_inp(x, y, canberra).map(|ca| ca.into_series()) + + compute_array_distance( + x, + y, + "canberra", + |x_f32, y_f32| crate::array::canberra_dist::(x_f32, y_f32).map(|ca| ca.into_series()), + |x_f64, y_f64| crate::array::canberra_dist::(x_f64, y_f64).map(|ca| ca.into_series()) + ) } -#[polars_expr(output_type=Float64)] +#[polars_expr(output_type_func=infer_manhatten_arr_dtype)] fn manhatten_arr(inputs: &[Series]) -> PolarsResult { let x: &ArrayChunked = inputs[0].array()?; let y: &ArrayChunked = inputs[1].array()?; @@ -358,10 +589,16 @@ fn manhatten_arr(inputs: &[Series]) -> PolarsResult { `{}` width: {}", inputs[0].name(), x.width(), inputs[1].name(), y.width()); } - distance_calc_numeric_inp(x, y, manhattan).map(|ca| ca.into_series()) + compute_array_distance( + x, + y, + "manhatten", + |x_f32, y_f32| crate::array::manhattan_dist::(x_f32, y_f32).map(|ca| ca.into_series()), + |x_f64, y_f64| crate::array::manhattan_dist::(x_f64, y_f64).map(|ca| ca.into_series()) + ) } -#[polars_expr(output_type=Float64)] +#[polars_expr(output_type_func=infer_l3_norm_arr_dtype)] fn l3_norm_arr(inputs: &[Series]) -> PolarsResult { let x: &ArrayChunked = inputs[0].array()?; let y: &ArrayChunked = inputs[1].array()?; @@ -373,10 +610,16 @@ fn l3_norm_arr(inputs: &[Series]) -> PolarsResult { `{}` width: {}", inputs[0].name(), x.width(), inputs[1].name(), y.width()); } - distance_calc_numeric_inp(x, y, l3_norm).map(|ca| ca.into_series()) + compute_array_distance( + x, + y, + "l3_norm", + |x_f32, y_f32| crate::array::l3_norm_dist::(x_f32, y_f32).map(|ca| ca.into_series()), + |x_f64, y_f64| crate::array::l3_norm_dist::(x_f64, y_f64).map(|ca| ca.into_series()) + ) } -#[polars_expr(output_type=Float64)] +#[polars_expr(output_type_func=infer_l4_norm_arr_dtype)] fn l4_norm_arr(inputs: &[Series]) -> PolarsResult { let x: &ArrayChunked = inputs[0].array()?; let y: &ArrayChunked = inputs[1].array()?; @@ -388,10 +631,16 @@ fn l4_norm_arr(inputs: &[Series]) -> PolarsResult { `{}` width: {}", inputs[0].name(), x.width(), inputs[1].name(), y.width()); } - distance_calc_numeric_inp(x, y, l4_norm).map(|ca| ca.into_series()) + compute_array_distance( + x, + y, + "l4_norm", + |x_f32, y_f32| crate::array::l4_norm_dist::(x_f32, y_f32).map(|ca| ca.into_series()), + |x_f64, y_f64| crate::array::l4_norm_dist::(x_f64, y_f64).map(|ca| ca.into_series()) + ) } -#[polars_expr(output_type=Float64)] +#[polars_expr(output_type_func=infer_bray_curtis_arr_dtype)] fn bray_curtis_arr(inputs: &[Series]) -> PolarsResult { let x: &ArrayChunked = inputs[0].array()?; let y: &ArrayChunked = inputs[1].array()?; @@ -403,7 +652,13 @@ fn bray_curtis_arr(inputs: &[Series]) -> PolarsResult { `{}` width: {}", inputs[0].name(), x.width(), inputs[1].name(), y.width()); } - distance_calc_uint_inp(x, y, bray_curtis).map(|ca| ca.into_series()) + compute_array_distance( + x, + y, + "bray_curtis", + |x_f32, y_f32| crate::array::bray_curtis_dist::(x_f32, y_f32).map(|ca| ca.into_series()), + |x_f64, y_f64| crate::array::bray_curtis_dist::(x_f64, y_f64).map(|ca| ca.into_series()) + ) } // SET (list) EXPRESSIONS diff --git a/polars_distance/tests/test_distance_arr.py b/polars_distance/tests/test_distance_arr.py index 157f119..fd61c13 100644 --- a/polars_distance/tests/test_distance_arr.py +++ b/polars_distance/tests/test_distance_arr.py @@ -1,6 +1,6 @@ -import pytest import polars as pl import polars_distance as pld +import pytest from polars.testing import assert_frame_equal @@ -10,12 +10,16 @@ def data(): { "arr": [[1.0, 2.0, 3.0, 4.0]], "arr2": [[10.0, 8.0, 5.0, 3.0]], + "arr3": [[1.0, 2.0, 3.0, 4.0]], + "arr4": [[10.0, 8.0, 5.0, 3.0]], "str_l": ["hello world"], "str_r": ["hela wrld"], }, schema={ "arr": pl.Array(inner=pl.Float64, shape=4), "arr2": pl.Array(inner=pl.Float64, shape=4), + "arr3": pl.Array(inner=pl.Float32, shape=4), + "arr4": pl.Array(inner=pl.Float32, shape=4), "str_l": pl.Utf8, "str_r": pl.Utf8, }, @@ -121,6 +125,25 @@ def test_euclidean(data): assert_frame_equal(result, expected) +def test_euclidean_float32(data): + # Create dataframe with float32 arrays + # Test with float32 arrays + result = data.select( + pld.col("arr3").dist_arr.euclidean("arr4").alias("dist_euclidean_f32"), + ) + # Check both the type and approximate value + assert result["dist_euclidean_f32"].dtype == pl.Float32 + assert abs(result["dist_euclidean_f32"][0] - 11.045361) < 1e-5 + + # Mixing types should follow Polars casting rules (higher precision wins) + result_mixed = data.select( + pl.col("arr2").dist_arr.euclidean("arr3").alias("dist_mixed"), + pl.col("arr3").dist_arr.euclidean("arr2").alias("dist_mixed_rev"), + ) + assert result_mixed["dist_mixed"].dtype == pl.Float64 + assert result_mixed["dist_mixed_rev"].dtype == pl.Float64 + + def test_hamming_str(data): result = data.select( pld.col("str_l").dist_str.hamming("str_r").alias("dist_hamming"),