diff --git a/.DS_Store b/.DS_Store new file mode 100644 index 0000000..b4bce19 Binary files /dev/null and b/.DS_Store differ diff --git a/polars_hash/.DS_Store b/polars_hash/.DS_Store new file mode 100644 index 0000000..ebc16a3 Binary files /dev/null and b/polars_hash/.DS_Store differ diff --git a/polars_hash/src/expressions.rs b/polars_hash/src/expressions.rs index 47ab732..99fa9ef 100644 --- a/polars_hash/src/expressions.rs +++ b/polars_hash/src/expressions.rs @@ -1,3 +1,5 @@ +use std::option::Option; // Import Option and its methods + use crate::geohashers::{geohash_decoder, geohash_encoder, geohash_neighbors}; use crate::h3::h3_encoder; use crate::sha_hashers::*; @@ -178,22 +180,33 @@ fn ghash_encode(inputs: &[Series]) -> PolarsResult { DataType::Int32 => inputs[1].cast(&DataType::Int64)?, DataType::Int16 => inputs[1].cast(&DataType::Int64)?, DataType::Int8 => inputs[1].cast(&DataType::Int64)?, - _ => polars_bail!(InvalidOperation:"Length input needs to be integer"), + _ => polars_bail!(InvalidOperation: "Length input needs to be integer"), }; let len = len.i64()?; + // Create a default base if not provided + let default_base = UInt8Chunked::full("base", 16, ca.len()); + + let base = match inputs.get(2) { + Some(base_series) => match base_series.dtype() { + DataType::UInt8 => base_series.u8()?, + _ => polars_bail!(InvalidOperation: "Base input needs to be uint8"), + }, + None => &default_base, // Borrow the default base 16 chunked array + }; + let lat = ca.field_by_name("latitude")?; let long = ca.field_by_name("longitude")?; let lat = match lat.dtype() { DataType::Float32 => lat.cast(&DataType::Float64)?, DataType::Float64 => lat, - _ => polars_bail!(InvalidOperation:"Latitude input needs to be float"), + _ => polars_bail!(InvalidOperation: "Latitude input needs to be float"), }; let long = match long.dtype() { DataType::Float32 => long.cast(&DataType::Float64)?, DataType::Float64 => long, - _ => polars_bail!(InvalidOperation:"Longitude input needs to be float"), + _ => polars_bail!(InvalidOperation: "Longitude input needs to be float"), }; let ca_lat = lat.f64()?; @@ -201,18 +214,26 @@ fn ghash_encode(inputs: &[Series]) -> PolarsResult { let out: StringChunked = match len.len() { 1 => match unsafe { len.get_unchecked(0) } { - Some(len) => try_binary_elementwise(ca_lat, ca_long, |ca_lat_opt, ca_long_opt| { + Some(len) => try_ternary_elementwise(ca_lat, ca_long, base, |ca_lat_opt, ca_long_opt, _base_opt| { geohash_encoder(ca_lat_opt, ca_long_opt, Some(len)) }), _ => Err(PolarsError::ComputeError( "Length may not be null".to_string().into(), )), }, - _ => try_ternary_elementwise(ca_lat, ca_long, len, geohash_encoder), + _ => try_ternary_elementwise(ca_lat, ca_long, len, |ca_lat_opt, ca_long_opt, _base_opt| { + geohash_encoder(ca_lat_opt, ca_long_opt, len) + }), }?; Ok(out.into_series()) } + + + + + + #[polars_expr(output_type=String)] fn h3_encode(inputs: &[Series]) -> PolarsResult { let ca = inputs[0].struct_()?;