Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added .DS_Store
Binary file not shown.
Binary file added polars_hash/.DS_Store
Binary file not shown.
31 changes: 26 additions & 5 deletions polars_hash/src/expressions.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::option::Option; // Import Option and its methods

use crate::geohashers::{geohash_decoder, geohash_encoder, geohash_neighbors};
use crate::h3::h3_encoder;
use crate::sha_hashers::*;
Expand Down Expand Up @@ -178,41 +180,60 @@ fn ghash_encode(inputs: &[Series]) -> PolarsResult<Series> {
DataType::Int32 => inputs[1].cast(&DataType::Int64)?,
DataType::Int16 => inputs[1].cast(&DataType::Int64)?,
DataType::Int8 => inputs[1].cast(&DataType::Int64)?,
_ => polars_bail!(InvalidOperation:"Length input needs to be integer"),
_ => polars_bail!(InvalidOperation: "Length input needs to be integer"),
};
let len = len.i64()?;

// Create a default base if not provided
let default_base = UInt8Chunked::full("base", 16, ca.len());

let base = match inputs.get(2) {
Some(base_series) => match base_series.dtype() {
DataType::UInt8 => base_series.u8()?,
_ => polars_bail!(InvalidOperation: "Base input needs to be uint8"),
},
None => &default_base, // Borrow the default base 16 chunked array
};

let lat = ca.field_by_name("latitude")?;
let long = ca.field_by_name("longitude")?;
let lat = match lat.dtype() {
DataType::Float32 => lat.cast(&DataType::Float64)?,
DataType::Float64 => lat,
_ => polars_bail!(InvalidOperation:"Latitude input needs to be float"),
_ => polars_bail!(InvalidOperation: "Latitude input needs to be float"),
};

let long = match long.dtype() {
DataType::Float32 => long.cast(&DataType::Float64)?,
DataType::Float64 => long,
_ => polars_bail!(InvalidOperation:"Longitude input needs to be float"),
_ => polars_bail!(InvalidOperation: "Longitude input needs to be float"),
};

let ca_lat = lat.f64()?;
let ca_long = long.f64()?;

let out: StringChunked = match len.len() {
1 => match unsafe { len.get_unchecked(0) } {
Some(len) => try_binary_elementwise(ca_lat, ca_long, |ca_lat_opt, ca_long_opt| {
Some(len) => try_ternary_elementwise(ca_lat, ca_long, base, |ca_lat_opt, ca_long_opt, _base_opt| {
geohash_encoder(ca_lat_opt, ca_long_opt, Some(len))
}),
_ => Err(PolarsError::ComputeError(
"Length may not be null".to_string().into(),
)),
},
_ => try_ternary_elementwise(ca_lat, ca_long, len, geohash_encoder),
_ => try_ternary_elementwise(ca_lat, ca_long, len, |ca_lat_opt, ca_long_opt, _base_opt| {
geohash_encoder(ca_lat_opt, ca_long_opt, len)
}),
}?;
Ok(out.into_series())
}







#[polars_expr(output_type=String)]
fn h3_encode(inputs: &[Series]) -> PolarsResult<Series> {
let ca = inputs[0].struct_()?;
Expand Down