-
Notifications
You must be signed in to change notification settings - Fork 1.8k
feat(cubestore): Add XIRR
aggregate function to Cube Store
#9520
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## master #9520 +/- ##
=======================================
Coverage 83.86% 83.86%
=======================================
Files 230 230
Lines 83825 83825
=======================================
Hits 70297 70297
Misses 13528 13528
Flags with carried forward coverage won't be shown. Click here to find out more. ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
For what it's worth here is a diff of the new cubestore xirr implementation against the cubesql xirr.rs. @@ -1,21 +1,28 @@
use std::sync::Arc;
+use chrono::Datelike as _;
use datafusion::{
arrow::{
- array::{ArrayRef, Date32Array, Float64Array, ListArray},
+ array::{ArrayRef, Date32Array, Float64Array, Int32Array, ListArray},
compute::cast,
datatypes::{DataType, Field, TimeUnit},
},
error::{DataFusionError, Result},
- logical_expr::{
- Accumulator, AccumulatorFunctionImplementation, AggregateUDF, ReturnTypeFunction,
- Signature, StateTypeFunction, TypeSignature, Volatility,
+ physical_plan::{
+ aggregates::{AccumulatorFunctionImplementation, StateTypeFunction},
+ functions::{ReturnTypeFunction, Signature},
+ udaf::AggregateUDF,
+ Accumulator,
},
scalar::ScalarValue,
};
+use smallvec::SmallVec;
-// Note: A copy/pasted and minimally(?) modified version of this is in Cubestore in udf_xirr.rs, and you might
-// want to update both.
+// This is copy/pasted and edited from cubesql in a file xirr.rs -- you might need to update both.
+//
+// Some differences here:
+// - the Accumulator trait has reset, merge, and update functions that operate on ScalarValues.
+// - List of Date32 isn't allowed, so we use List of Int32 in state values.
pub const XIRR_UDAF_NAME: &str = "xirr";
@@ -62,20 +69,20 @@
for payment_type in NUMERIC_TYPES {
for date_type in DATETIME_TYPES {
// Base signatures without `initial_guess` and `on_error` arguments
- type_signatures.push(TypeSignature::Exact(vec![
+ type_signatures.push(Signature::Exact(vec![
payment_type.clone(),
date_type.clone(),
]));
// Signatures with `initial_guess` argument; only [`DataType::Float64`] is accepted
const INITIAL_GUESS_TYPE: DataType = DataType::Float64;
- type_signatures.push(TypeSignature::Exact(vec![
+ type_signatures.push(Signature::Exact(vec![
payment_type.clone(),
date_type.clone(),
INITIAL_GUESS_TYPE,
]));
// Signatures with `initial_guess` and `on_error` arguments
for on_error_type in NUMERIC_TYPES {
- type_signatures.push(TypeSignature::Exact(vec![
+ type_signatures.push(Signature::Exact(vec![
payment_type.clone(),
date_type.clone(),
INITIAL_GUESS_TYPE,
@@ -86,17 +93,14 @@
}
type_signatures
};
- let signature = Signature::one_of(
- type_signatures,
- Volatility::Volatile, // due to the usage of [`f64::powf`]
- );
+ let signature = Signature::OneOf(type_signatures);
let return_type: ReturnTypeFunction = Arc::new(|_| Ok(Arc::new(DataType::Float64)));
let accumulator: AccumulatorFunctionImplementation =
Arc::new(|| Ok(Box::new(XirrAccumulator::new())));
let state_type: StateTypeFunction = Arc::new(|_| {
Ok(Arc::new(vec![
DataType::List(Box::new(Field::new("item", DataType::Float64, true))),
- DataType::List(Box::new(Field::new("item", DataType::Date32, true))),
+ DataType::List(Box::new(Field::new("item", DataType::Int32, true))), // Date32
DataType::List(Box::new(Field::new("item", DataType::Float64, true))),
DataType::List(Box::new(Field::new("item", DataType::Float64, true))),
]))
@@ -105,7 +109,7 @@
}
#[derive(Debug)]
-struct XirrAccumulator {
+pub struct XirrAccumulator {
/// Pairs of (payment, date).
pairs: Vec<(f64, i32)>,
initial_guess: ValueState<f64>,
@@ -113,7 +117,7 @@
}
impl XirrAccumulator {
- fn new() -> Self {
+ pub fn new() -> Self {
XirrAccumulator {
pairs: vec![],
initial_guess: ValueState::Unset,
@@ -169,14 +173,256 @@
}
}
+fn cast_scalar_to_float64(scalar: &ScalarValue) -> Result<Option<f64>> {
+ fn err(from_type: &str) -> Result<Option<f64>> {
+ Err(DataFusionError::Internal(format!(
+ "cannot cast {} to Float64",
+ from_type
+ )))
+ }
+ match scalar {
+ ScalarValue::Boolean(_) => err("Boolean"),
+ ScalarValue::Float32(o) => Ok(o.map(f64::from)),
+ ScalarValue::Float64(o) => Ok(*o),
+ ScalarValue::Int8(o) => Ok(o.map(f64::from)),
+ ScalarValue::Int16(o) => Ok(o.map(f64::from)),
+ ScalarValue::Int32(o) => Ok(o.map(f64::from)),
+ ScalarValue::Int64(o) => Ok(o.map(|x| x as f64)),
+ ScalarValue::Int96(o) => Ok(o.map(|x| x as f64)),
+ ScalarValue::Int64Decimal(o, scale) => {
+ Ok(o.map(|x| (x as f64) / 10f64.powi(*scale as i32)))
+ }
+ ScalarValue::Int96Decimal(o, scale) => {
+ Ok(o.map(|x| (x as f64) / 10f64.powi(*scale as i32)))
+ }
+ ScalarValue::UInt8(o) => Ok(o.map(f64::from)),
+ ScalarValue::UInt16(o) => Ok(o.map(f64::from)),
+ ScalarValue::UInt32(o) => Ok(o.map(f64::from)),
+ ScalarValue::UInt64(o) => Ok(o.map(|x| x as f64)),
+ ScalarValue::Utf8(_) => err("Utf8"),
+ ScalarValue::LargeUtf8(_) => err("LargeUtf8"),
+ ScalarValue::Binary(_) => err("Binary"),
+ ScalarValue::LargeBinary(_) => err("LargeBinary"),
+ ScalarValue::List(_, _dt) => err("List"),
+ ScalarValue::Date32(_) => err("Date32"),
+ ScalarValue::Date64(_) => err("Date64"),
+ ScalarValue::TimestampSecond(_) => err("TimestampSecond"),
+ ScalarValue::TimestampMillisecond(_) => err("TimestampMillisecond"),
+ ScalarValue::TimestampMicrosecond(_) => err("TimestampMicrosecond"),
+ ScalarValue::TimestampNanosecond(_) => err("TimestampNanosecond"),
+ ScalarValue::IntervalYearMonth(_) => err("IntervalYearMonth"),
+ ScalarValue::IntervalDayTime(_) => err("IntervalDayTime"),
+ }
+}
+
+fn cast_scalar_to_date32(scalar: &ScalarValue) -> Result<Option<i32>> {
+ fn err(from_type: &str) -> Result<Option<i32>> {
+ Err(DataFusionError::Internal(format!(
+ "cannot cast {} to Date32",
+ from_type
+ )))
+ }
+ fn string_to_date32(o: &Option<String>) -> Result<Option<i32>> {
+ if let Some(x) = o {
+ // Consistent with cast() in update_batch being configured with the "safe" option true, so we return None (null value) if there is a cast error.
+ Ok(x.parse::<chrono::NaiveDate>()
+ .map(|date| date.num_days_from_ce() - EPOCH_DAYS_FROM_CE)
+ .ok())
+ } else {
+ Ok(None)
+ }
+ }
+
+ // Number of days between 0001-01-01 and 1970-01-01
+ const EPOCH_DAYS_FROM_CE: i32 = 719_163;
+
+ const SECONDS_IN_DAY: i64 = 86_400;
+ const MILLISECONDS_IN_DAY: i64 = SECONDS_IN_DAY * 1_000;
+
+ match scalar {
+ ScalarValue::Boolean(_) => err("Boolean"),
+ ScalarValue::Float32(_) => err("Float32"),
+ ScalarValue::Float64(_) => err("Float64"),
+ ScalarValue::Int8(_) => err("Int8"),
+ ScalarValue::Int16(_) => err("Int16"),
+ ScalarValue::Int32(o) => Ok(*o),
+ ScalarValue::Int64(o) => Ok(o.and_then(|x| num::NumCast::from(x))),
+ ScalarValue::Int96(_) => err("Int96"),
+ ScalarValue::Int64Decimal(_, _scale) => err("Int64Decimal"),
+ ScalarValue::Int96Decimal(_, _scale) => err("Int96Decimal"),
+ ScalarValue::UInt8(_) => err("UInt8"),
+ ScalarValue::UInt16(_) => err("UInt16"),
+ ScalarValue::UInt32(_) => err("UInt32"),
+ ScalarValue::UInt64(_) => err("UInt64"),
+ ScalarValue::Utf8(o) => string_to_date32(o),
+ ScalarValue::LargeUtf8(o) => string_to_date32(o),
+ ScalarValue::Binary(_) => err("Binary"),
+ ScalarValue::LargeBinary(_) => err("LargeBinary"),
+ ScalarValue::List(_, _dt) => err("List"),
+ ScalarValue::Date32(o) => Ok(*o),
+ ScalarValue::Date64(o) => Ok(o.map(|x| (x / MILLISECONDS_IN_DAY) as i32)),
+ ScalarValue::TimestampSecond(o) => Ok(o.map(|x| (x / SECONDS_IN_DAY) as i32)),
+ ScalarValue::TimestampMillisecond(o) => Ok(o.map(|x| (x / MILLISECONDS_IN_DAY) as i32)),
+ ScalarValue::TimestampMicrosecond(o) => {
+ Ok(o.map(|x| (x / (1_000_000 * SECONDS_IN_DAY)) as i32))
+ }
+ ScalarValue::TimestampNanosecond(o) => {
+ Ok(o.map(|x| (x / (1_000_000_000 * SECONDS_IN_DAY)) as i32))
+ }
+ ScalarValue::IntervalYearMonth(_) => err("IntervalYearMonth"),
+ ScalarValue::IntervalDayTime(_) => err("IntervalDayTime"),
+ }
+}
+
impl Accumulator for XirrAccumulator {
- fn state(&self) -> Result<Vec<ScalarValue>> {
+ fn reset(&mut self) {
+ self.pairs.clear();
+ self.initial_guess = ValueState::Unset;
+ self.on_error = ValueState::Unset;
+ }
+
+ fn update(&mut self, values: &[ScalarValue]) -> Result<()> {
+ let payment = cast_scalar_to_float64(&values[0])?;
+ let date = cast_scalar_to_date32(&values[1])?;
+ self.add_pair(payment, date)?;
+ let values_len = values.len();
+ if values_len < 3 {
+ return Ok(());
+ }
+ let ScalarValue::Float64(initial_guess) = values[2] else {
+ return Err(DataFusionError::Internal(format!(
+ "XIRR initial guess should be a Float64 but it was of type {}",
+ values[2].get_datatype()
+ )));
+ };
+ self.set_initial_guess(initial_guess)?;
+ if values_len < 4 {
+ return Ok(());
+ }
+ let on_error = cast_scalar_to_float64(&values[3])?;
+ self.set_on_error(on_error)?;
+ Ok(())
+ }
+
+ fn merge(&mut self, states: &[ScalarValue]) -> Result<()> {
+ if states.len() != 4 {
+ return Err(DataFusionError::Internal(format!(
+ "Merging XIRR states list with {} columns instead of 4",
+ states.len()
+ )));
+ }
+ // payments and dates
+ {
+ let ScalarValue::List(payments, payments_datatype) = &states[0] else {
+ return Err(DataFusionError::Internal(format!(
+ "XIRR payments state must be a List but it was of type {}",
+ states[0].get_datatype()
+ )));
+ };
+ if payments_datatype.as_ref() != &DataType::Float64 {
+ return Err(DataFusionError::Internal(format!("XIRR payments state must be a List of Float64 but it was a List with element type {}", payments_datatype)));
+ }
+ let ScalarValue::List(dates, dates_datatype) = &states[1] else {
+ return Err(DataFusionError::Internal(format!(
+ "XIRR dates state must be a List but it was of type {}",
+ states[1].get_datatype()
+ )));
+ };
+ if dates_datatype.as_ref() != &DataType::Int32 {
+ return Err(DataFusionError::Internal(format!("XIRR dates state must be a List of Int32 but it was a List with element type {}", dates_datatype)));
+ }
+ let Some(payments) = payments else {
+ return Err(DataFusionError::Internal(format!(
+ "XIRR payments state is null in merge"
+ )));
+ };
+ let Some(dates) = dates else {
+ return Err(DataFusionError::Internal(format!(
+ "XIRR dates state is null, payments not null in merge"
+ )));
+ };
+
+ for (payment, date) in payments.iter().zip(dates.iter()) {
+ let ScalarValue::Float64(payment) = payment else {
+ return Err(DataFusionError::Internal(format!(
+ "XIRR payment in List is not a Float64"
+ )));
+ };
+ let ScalarValue::Int32(date) = date else {
+ // Date32
+ return Err(DataFusionError::Internal(format!(
+ "XIRR date in List is not an Int32"
+ )));
+ };
+ self.add_pair(*payment, *date)?;
+ }
+ }
+ // initial_guess
+ {
+ let ScalarValue::List(initial_guess_list, initial_guess_dt) = &states[2] else {
+ return Err(DataFusionError::Internal(format!(
+ "XIRR initial guess state is not a List in merge"
+ )));
+ };
+ if initial_guess_dt.as_ref() != &DataType::Float64 {
+ return Err(DataFusionError::Internal(format!(
+ "XIRR initial guess state is not a List of Float64 in merge"
+ )));
+ }
+ let Some(initial_guess_list) = initial_guess_list else {
+ return Err(DataFusionError::Internal(format!(
+ "XIRR initial guess state is a null list in merge"
+ )));
+ };
+ // To be clear this list has 0 or 1 elements which may be null.
+ for initial_guess in initial_guess_list.iter() {
+ let ScalarValue::Float64(guess) = initial_guess else {
+ return Err(DataFusionError::Internal(format!(
+ "XIRR initial guess in List is not a Float64"
+ )));
+ };
+ self.set_initial_guess(*guess)?;
+ }
+ }
+ // on_error
+ {
+ let ScalarValue::List(on_error_list, on_error_dt) = &states[3] else {
+ return Err(DataFusionError::Internal(format!(
+ "XIRR on_error state is not a List in merge"
+ )));
+ };
+ if on_error_dt.as_ref() != &DataType::Float64 {
+ return Err(DataFusionError::Internal(format!(
+ "XIRR on_error state is not a List of Float64 in merge"
+ )));
+ }
+
+ let Some(on_error_list) = on_error_list else {
+ return Err(DataFusionError::Internal(format!(
+ "XIRR on_error state is a null list in merge"
+ )));
+ };
+ // To be clear this list has 0 or 1 elements which may be null.
+ for on_error in on_error_list.iter() {
+ let ScalarValue::Float64(on_error) = on_error else {
+ return Err(DataFusionError::Internal(format!(
+ "XIRR on_error in List is not a Float64"
+ )));
+ };
+ self.set_on_error(*on_error)?;
+ }
+ }
+
+ Ok(())
+ }
+
+ fn state(&self) -> Result<SmallVec<[ScalarValue; 2]>> {
let (payments, dates): (Vec<_>, Vec<_>) = self
.pairs
.iter()
.map(|(payment, date)| {
let payment = ScalarValue::Float64(Some(*payment));
- let date = ScalarValue::Date32(Some(*date));
+ let date = ScalarValue::Int32(Some(*date)); // Date32
(payment, date)
})
.unzip();
@@ -188,9 +434,9 @@
ValueState::Unset => vec![],
ValueState::Set(on_error) => vec![ScalarValue::Float64(on_error)],
};
- Ok(vec![
+ Ok(smallvec::smallvec![
ScalarValue::List(Some(Box::new(payments)), Box::new(DataType::Float64)),
- ScalarValue::List(Some(Box::new(dates)), Box::new(DataType::Date32)),
+ ScalarValue::List(Some(Box::new(dates)), Box::new(DataType::Int32)), // Date32
ScalarValue::List(Some(Box::new(initial_guess)), Box::new(DataType::Float64)),
ScalarValue::List(Some(Box::new(on_error)), Box::new(DataType::Float64)),
])
@@ -224,6 +470,12 @@
}
fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
+ if states.len() != 4 {
+ return Err(DataFusionError::Internal(format!(
+ "Merging XIRR states list with {} columns instead of 4",
+ states.len()
+ )));
+ }
let payments = states[0]
.as_any()
.downcast_ref::<ListArray>()
@@ -235,14 +487,11 @@
.downcast_ref::<ListArray>()
.unwrap()
.values();
- let dates = dates.as_any().downcast_ref::<Date32Array>().unwrap();
+ let dates = dates.as_any().downcast_ref::<Int32Array>().unwrap(); // Date32Array
for (payment, date) in payments.into_iter().zip(dates) {
self.add_pair(payment, date)?;
}
- let states_len = states.len();
- if states_len < 3 {
- return Ok(());
- }
+
let initial_guesses = states[2]
.as_any()
.downcast_ref::<ListArray>()
@@ -255,9 +504,7 @@
for initial_guess in initial_guesses {
self.set_initial_guess(initial_guess)?;
}
- if states_len < 4 {
- return Ok(());
- }
+
let on_errors = states[3]
.as_any()
.downcast_ref::<ListArray>() |
Check List
Adds XIRR to cube store, following #9508.