feat(cubestore): Add XIRR aggregate function to Cube Store #9520

Merged
1 commit merged on May 6, 2025

@srh (Member) commented Apr 26, 2025

Check List

  • Tests have been run in packages where changes were made, if available
  • Linter has been run for changed code
  • Tests for the changes have been added if not covered yet
  • Docs have been added / updated if required

Adds the XIRR aggregate function to Cube Store, following #9508.
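
For readers unfamiliar with XIRR: it is the annual rate at which the net present value of a series of dated payments is zero. The block below is only a minimal sketch of that definition, not the code in this PR. It assumes an actual/365 day count (the usual spreadsheet convention) and plain Newton iteration; `npv`, `npv_derivative`, and `xirr_sketch` are hypothetical names, and the iteration limit and tolerance are arbitrary choices.

```rust
/// Net present value of dated payments at an annual `rate`.
/// Each pair is (payment, days since the Unix epoch), mirroring the
/// `Vec<(f64, i32)>` that `XirrAccumulator` keeps in this PR.
fn npv(pairs: &[(f64, i32)], rate: f64) -> f64 {
    let first_day = f64::from(pairs[0].1);
    pairs
        .iter()
        .map(|&(payment, day)| {
            // Assumption: actual/365 day count.
            let years = (f64::from(day) - first_day) / 365.0;
            payment / (1.0 + rate).powf(years)
        })
        .sum()
}

/// Derivative of `npv` with respect to `rate`.
fn npv_derivative(pairs: &[(f64, i32)], rate: f64) -> f64 {
    let first_day = f64::from(pairs[0].1);
    pairs
        .iter()
        .map(|&(payment, day)| {
            let years = (f64::from(day) - first_day) / 365.0;
            -years * payment / (1.0 + rate).powf(years + 1.0)
        })
        .sum()
}

/// Hypothetical helper: Newton's method on `npv`. Returns `None` when the
/// input is empty or the iteration does not converge.
fn xirr_sketch(pairs: &[(f64, i32)], initial_guess: f64) -> Option<f64> {
    if pairs.is_empty() {
        return None;
    }
    let mut rate = initial_guess;
    for _ in 0..100 {
        let value = npv(pairs, rate);
        if value.abs() < 1e-9 {
            return Some(rate);
        }
        let slope = npv_derivative(pairs, rate);
        if slope == 0.0 || !slope.is_finite() {
            return None;
        }
        rate -= value / slope;
        // (1 + rate) must stay positive for `powf` to be meaningful.
        if !rate.is_finite() || rate <= -1.0 {
            return None;
        }
    }
    None
}
```

With an outflow of 1000 on day 0 and an inflow of 1100 on day 365, the sketch converges to a rate of about 0.10. The `initial_guess` and `on_error` arguments of the real function correspond to the starting point of the iteration and the value returned when no rate is found.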

@srh srh requested review from a team as code owners April 26, 2025 00:19
codecov bot commented Apr 26, 2025

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 83.86%. Comparing base (d28333a) to head (2d8d8e2).
Report is 26 commits behind head on master.

Additional details and impacted files
@@           Coverage Diff           @@
##           master    #9520   +/-   ##
=======================================
  Coverage   83.86%   83.86%           
=======================================
  Files         230      230           
  Lines       83825    83825           
=======================================
  Hits        70297    70297           
  Misses      13528    13528           
Flag Coverage Δ
cubesql 83.86% <ø> (ø)


@srh (Member, Author) commented May 6, 2025

For what it's worth, here is a diff of the new cubestore xirr implementation against the cubesql xirr.rs.

@@ -1,21 +1,28 @@
 use std::sync::Arc;
 
+use chrono::Datelike as _;
 use datafusion::{
     arrow::{
-        array::{ArrayRef, Date32Array, Float64Array, ListArray},
+        array::{ArrayRef, Date32Array, Float64Array, Int32Array, ListArray},
         compute::cast,
         datatypes::{DataType, Field, TimeUnit},
     },
     error::{DataFusionError, Result},
-    logical_expr::{
-        Accumulator, AccumulatorFunctionImplementation, AggregateUDF, ReturnTypeFunction,
-        Signature, StateTypeFunction, TypeSignature, Volatility,
+    physical_plan::{
+        aggregates::{AccumulatorFunctionImplementation, StateTypeFunction},
+        functions::{ReturnTypeFunction, Signature},
+        udaf::AggregateUDF,
+        Accumulator,
     },
     scalar::ScalarValue,
 };
+use smallvec::SmallVec;
 
-// Note:  A copy/pasted and minimally(?) modified version of this is in Cubestore in udf_xirr.rs, and you might
-// want to update both.
+// This is copy/pasted and edited from cubesql in a file xirr.rs -- you might need to update both.
+//
+// Some differences here:
+// - the Accumulator trait has reset, merge, and update functions that operate on ScalarValues.
+// - List of Date32 isn't allowed, so we use List of Int32 in state values.
 
 pub const XIRR_UDAF_NAME: &str = "xirr";
 
@@ -62,20 +69,20 @@
         for payment_type in NUMERIC_TYPES {
             for date_type in DATETIME_TYPES {
                 // Base signatures without `initial_guess` and `on_error` arguments
-                type_signatures.push(TypeSignature::Exact(vec![
+                type_signatures.push(Signature::Exact(vec![
                     payment_type.clone(),
                     date_type.clone(),
                 ]));
                 // Signatures with `initial_guess` argument; only [`DataType::Float64`] is accepted
                 const INITIAL_GUESS_TYPE: DataType = DataType::Float64;
-                type_signatures.push(TypeSignature::Exact(vec![
+                type_signatures.push(Signature::Exact(vec![
                     payment_type.clone(),
                     date_type.clone(),
                     INITIAL_GUESS_TYPE,
                 ]));
                 // Signatures with `initial_guess` and `on_error` arguments
                 for on_error_type in NUMERIC_TYPES {
-                    type_signatures.push(TypeSignature::Exact(vec![
+                    type_signatures.push(Signature::Exact(vec![
                         payment_type.clone(),
                         date_type.clone(),
                         INITIAL_GUESS_TYPE,
@@ -86,17 +93,14 @@
         }
         type_signatures
     };
-    let signature = Signature::one_of(
-        type_signatures,
-        Volatility::Volatile, // due to the usage of [`f64::powf`]
-    );
+    let signature = Signature::OneOf(type_signatures);
     let return_type: ReturnTypeFunction = Arc::new(|_| Ok(Arc::new(DataType::Float64)));
     let accumulator: AccumulatorFunctionImplementation =
         Arc::new(|| Ok(Box::new(XirrAccumulator::new())));
     let state_type: StateTypeFunction = Arc::new(|_| {
         Ok(Arc::new(vec![
             DataType::List(Box::new(Field::new("item", DataType::Float64, true))),
-            DataType::List(Box::new(Field::new("item", DataType::Date32, true))),
+            DataType::List(Box::new(Field::new("item", DataType::Int32, true))), // Date32
             DataType::List(Box::new(Field::new("item", DataType::Float64, true))),
             DataType::List(Box::new(Field::new("item", DataType::Float64, true))),
         ]))
@@ -105,7 +109,7 @@
 }
 
 #[derive(Debug)]
-struct XirrAccumulator {
+pub struct XirrAccumulator {
     /// Pairs of (payment, date).
     pairs: Vec<(f64, i32)>,
     initial_guess: ValueState<f64>,
@@ -113,7 +117,7 @@
 }
 
 impl XirrAccumulator {
-    fn new() -> Self {
+    pub fn new() -> Self {
         XirrAccumulator {
             pairs: vec![],
             initial_guess: ValueState::Unset,
@@ -169,14 +173,256 @@
     }
 }
 
+fn cast_scalar_to_float64(scalar: &ScalarValue) -> Result<Option<f64>> {
+    fn err(from_type: &str) -> Result<Option<f64>> {
+        Err(DataFusionError::Internal(format!(
+            "cannot cast {} to Float64",
+            from_type
+        )))
+    }
+    match scalar {
+        ScalarValue::Boolean(_) => err("Boolean"),
+        ScalarValue::Float32(o) => Ok(o.map(f64::from)),
+        ScalarValue::Float64(o) => Ok(*o),
+        ScalarValue::Int8(o) => Ok(o.map(f64::from)),
+        ScalarValue::Int16(o) => Ok(o.map(f64::from)),
+        ScalarValue::Int32(o) => Ok(o.map(f64::from)),
+        ScalarValue::Int64(o) => Ok(o.map(|x| x as f64)),
+        ScalarValue::Int96(o) => Ok(o.map(|x| x as f64)),
+        ScalarValue::Int64Decimal(o, scale) => {
+            Ok(o.map(|x| (x as f64) / 10f64.powi(*scale as i32)))
+        }
+        ScalarValue::Int96Decimal(o, scale) => {
+            Ok(o.map(|x| (x as f64) / 10f64.powi(*scale as i32)))
+        }
+        ScalarValue::UInt8(o) => Ok(o.map(f64::from)),
+        ScalarValue::UInt16(o) => Ok(o.map(f64::from)),
+        ScalarValue::UInt32(o) => Ok(o.map(f64::from)),
+        ScalarValue::UInt64(o) => Ok(o.map(|x| x as f64)),
+        ScalarValue::Utf8(_) => err("Utf8"),
+        ScalarValue::LargeUtf8(_) => err("LargeUtf8"),
+        ScalarValue::Binary(_) => err("Binary"),
+        ScalarValue::LargeBinary(_) => err("LargeBinary"),
+        ScalarValue::List(_, _dt) => err("List"),
+        ScalarValue::Date32(_) => err("Date32"),
+        ScalarValue::Date64(_) => err("Date64"),
+        ScalarValue::TimestampSecond(_) => err("TimestampSecond"),
+        ScalarValue::TimestampMillisecond(_) => err("TimestampMillisecond"),
+        ScalarValue::TimestampMicrosecond(_) => err("TimestampMicrosecond"),
+        ScalarValue::TimestampNanosecond(_) => err("TimestampNanosecond"),
+        ScalarValue::IntervalYearMonth(_) => err("IntervalYearMonth"),
+        ScalarValue::IntervalDayTime(_) => err("IntervalDayTime"),
+    }
+}
+
+fn cast_scalar_to_date32(scalar: &ScalarValue) -> Result<Option<i32>> {
+    fn err(from_type: &str) -> Result<Option<i32>> {
+        Err(DataFusionError::Internal(format!(
+            "cannot cast {} to Date32",
+            from_type
+        )))
+    }
+    fn string_to_date32(o: &Option<String>) -> Result<Option<i32>> {
+        if let Some(x) = o {
+            // Consistent with cast() in update_batch being configured with the "safe" option true, so we return None (null value) if there is a cast error.
+            Ok(x.parse::<chrono::NaiveDate>()
+                .map(|date| date.num_days_from_ce() - EPOCH_DAYS_FROM_CE)
+                .ok())
+        } else {
+            Ok(None)
+        }
+    }
+
+    // Value of chrono's `num_days_from_ce()` for 1970-01-01 (0001-01-01 counts as day 1)
+    const EPOCH_DAYS_FROM_CE: i32 = 719_163;
+
+    const SECONDS_IN_DAY: i64 = 86_400;
+    const MILLISECONDS_IN_DAY: i64 = SECONDS_IN_DAY * 1_000;
+
+    match scalar {
+        ScalarValue::Boolean(_) => err("Boolean"),
+        ScalarValue::Float32(_) => err("Float32"),
+        ScalarValue::Float64(_) => err("Float64"),
+        ScalarValue::Int8(_) => err("Int8"),
+        ScalarValue::Int16(_) => err("Int16"),
+        ScalarValue::Int32(o) => Ok(*o),
+        ScalarValue::Int64(o) => Ok(o.and_then(|x| num::NumCast::from(x))),
+        ScalarValue::Int96(_) => err("Int96"),
+        ScalarValue::Int64Decimal(_, _scale) => err("Int64Decimal"),
+        ScalarValue::Int96Decimal(_, _scale) => err("Int96Decimal"),
+        ScalarValue::UInt8(_) => err("UInt8"),
+        ScalarValue::UInt16(_) => err("UInt16"),
+        ScalarValue::UInt32(_) => err("UInt32"),
+        ScalarValue::UInt64(_) => err("UInt64"),
+        ScalarValue::Utf8(o) => string_to_date32(o),
+        ScalarValue::LargeUtf8(o) => string_to_date32(o),
+        ScalarValue::Binary(_) => err("Binary"),
+        ScalarValue::LargeBinary(_) => err("LargeBinary"),
+        ScalarValue::List(_, _dt) => err("List"),
+        ScalarValue::Date32(o) => Ok(*o),
+        ScalarValue::Date64(o) => Ok(o.map(|x| (x / MILLISECONDS_IN_DAY) as i32)),
+        ScalarValue::TimestampSecond(o) => Ok(o.map(|x| (x / SECONDS_IN_DAY) as i32)),
+        ScalarValue::TimestampMillisecond(o) => Ok(o.map(|x| (x / MILLISECONDS_IN_DAY) as i32)),
+        ScalarValue::TimestampMicrosecond(o) => {
+            Ok(o.map(|x| (x / (1_000_000 * SECONDS_IN_DAY)) as i32))
+        }
+        ScalarValue::TimestampNanosecond(o) => {
+            Ok(o.map(|x| (x / (1_000_000_000 * SECONDS_IN_DAY)) as i32))
+        }
+        ScalarValue::IntervalYearMonth(_) => err("IntervalYearMonth"),
+        ScalarValue::IntervalDayTime(_) => err("IntervalDayTime"),
+    }
+}
+
 impl Accumulator for XirrAccumulator {
-    fn state(&self) -> Result<Vec<ScalarValue>> {
+    fn reset(&mut self) {
+        self.pairs.clear();
+        self.initial_guess = ValueState::Unset;
+        self.on_error = ValueState::Unset;
+    }
+
+    fn update(&mut self, values: &[ScalarValue]) -> Result<()> {
+        let payment = cast_scalar_to_float64(&values[0])?;
+        let date = cast_scalar_to_date32(&values[1])?;
+        self.add_pair(payment, date)?;
+        let values_len = values.len();
+        if values_len < 3 {
+            return Ok(());
+        }
+        let ScalarValue::Float64(initial_guess) = values[2] else {
+            return Err(DataFusionError::Internal(format!(
+                "XIRR initial guess should be a Float64 but it was of type {}",
+                values[2].get_datatype()
+            )));
+        };
+        self.set_initial_guess(initial_guess)?;
+        if values_len < 4 {
+            return Ok(());
+        }
+        let on_error = cast_scalar_to_float64(&values[3])?;
+        self.set_on_error(on_error)?;
+        Ok(())
+    }
+
+    fn merge(&mut self, states: &[ScalarValue]) -> Result<()> {
+        if states.len() != 4 {
+            return Err(DataFusionError::Internal(format!(
+                "Merging XIRR states list with {} columns instead of 4",
+                states.len()
+            )));
+        }
+        // payments and dates
+        {
+            let ScalarValue::List(payments, payments_datatype) = &states[0] else {
+                return Err(DataFusionError::Internal(format!(
+                    "XIRR payments state must be a List but it was of type {}",
+                    states[0].get_datatype()
+                )));
+            };
+            if payments_datatype.as_ref() != &DataType::Float64 {
+                return Err(DataFusionError::Internal(format!("XIRR payments state must be a List of Float64 but it was a List with element type {}", payments_datatype)));
+            }
+            let ScalarValue::List(dates, dates_datatype) = &states[1] else {
+                return Err(DataFusionError::Internal(format!(
+                    "XIRR dates state must be a List but it was of type {}",
+                    states[1].get_datatype()
+                )));
+            };
+            if dates_datatype.as_ref() != &DataType::Int32 {
+                return Err(DataFusionError::Internal(format!("XIRR dates state must be a List of Int32 but it was a List with element type {}", dates_datatype)));
+            }
+            let Some(payments) = payments else {
+                return Err(DataFusionError::Internal(format!(
+                    "XIRR payments state is null in merge"
+                )));
+            };
+            let Some(dates) = dates else {
+                return Err(DataFusionError::Internal(format!(
+                    "XIRR dates state is null, payments not null in merge"
+                )));
+            };
+
+            for (payment, date) in payments.iter().zip(dates.iter()) {
+                let ScalarValue::Float64(payment) = payment else {
+                    return Err(DataFusionError::Internal(format!(
+                        "XIRR payment in List is not a Float64"
+                    )));
+                };
+                let ScalarValue::Int32(date) = date else {
+                    // Date32
+                    return Err(DataFusionError::Internal(format!(
+                        "XIRR date in List is not an Int32"
+                    )));
+                };
+                self.add_pair(*payment, *date)?;
+            }
+        }
+        // initial_guess
+        {
+            let ScalarValue::List(initial_guess_list, initial_guess_dt) = &states[2] else {
+                return Err(DataFusionError::Internal(format!(
+                    "XIRR initial guess state is not a List in merge"
+                )));
+            };
+            if initial_guess_dt.as_ref() != &DataType::Float64 {
+                return Err(DataFusionError::Internal(format!(
+                    "XIRR initial guess state is not a List of Float64 in merge"
+                )));
+            }
+            let Some(initial_guess_list) = initial_guess_list else {
+                return Err(DataFusionError::Internal(format!(
+                    "XIRR initial guess state is a null list in merge"
+                )));
+            };
+            // To be clear this list has 0 or 1 elements which may be null.
+            for initial_guess in initial_guess_list.iter() {
+                let ScalarValue::Float64(guess) = initial_guess else {
+                    return Err(DataFusionError::Internal(format!(
+                        "XIRR initial guess in List is not a Float64"
+                    )));
+                };
+                self.set_initial_guess(*guess)?;
+            }
+        }
+        // on_error
+        {
+            let ScalarValue::List(on_error_list, on_error_dt) = &states[3] else {
+                return Err(DataFusionError::Internal(format!(
+                    "XIRR on_error state is not a List in merge"
+                )));
+            };
+            if on_error_dt.as_ref() != &DataType::Float64 {
+                return Err(DataFusionError::Internal(format!(
+                    "XIRR on_error state is not a List of Float64 in merge"
+                )));
+            }
+
+            let Some(on_error_list) = on_error_list else {
+                return Err(DataFusionError::Internal(format!(
+                    "XIRR on_error state is a null list in merge"
+                )));
+            };
+            // To be clear this list has 0 or 1 elements which may be null.
+            for on_error in on_error_list.iter() {
+                let ScalarValue::Float64(on_error) = on_error else {
+                    return Err(DataFusionError::Internal(format!(
+                        "XIRR on_error in List is not a Float64"
+                    )));
+                };
+                self.set_on_error(*on_error)?;
+            }
+        }
+
+        Ok(())
+    }
+
+    fn state(&self) -> Result<SmallVec<[ScalarValue; 2]>> {
         let (payments, dates): (Vec<_>, Vec<_>) = self
             .pairs
             .iter()
             .map(|(payment, date)| {
                 let payment = ScalarValue::Float64(Some(*payment));
-                let date = ScalarValue::Date32(Some(*date));
+                let date = ScalarValue::Int32(Some(*date)); // Date32
                 (payment, date)
             })
             .unzip();
@@ -188,9 +434,9 @@
             ValueState::Unset => vec![],
             ValueState::Set(on_error) => vec![ScalarValue::Float64(on_error)],
         };
-        Ok(vec![
+        Ok(smallvec::smallvec![
             ScalarValue::List(Some(Box::new(payments)), Box::new(DataType::Float64)),
-            ScalarValue::List(Some(Box::new(dates)), Box::new(DataType::Date32)),
+            ScalarValue::List(Some(Box::new(dates)), Box::new(DataType::Int32)), // Date32
             ScalarValue::List(Some(Box::new(initial_guess)), Box::new(DataType::Float64)),
             ScalarValue::List(Some(Box::new(on_error)), Box::new(DataType::Float64)),
         ])
@@ -224,6 +470,12 @@
     }
 
     fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
+        if states.len() != 4 {
+            return Err(DataFusionError::Internal(format!(
+                "Merging XIRR states list with {} columns instead of 4",
+                states.len()
+            )));
+        }
         let payments = states[0]
             .as_any()
             .downcast_ref::<ListArray>()
@@ -235,14 +487,11 @@
             .downcast_ref::<ListArray>()
             .unwrap()
             .values();
-        let dates = dates.as_any().downcast_ref::<Date32Array>().unwrap();
+        let dates = dates.as_any().downcast_ref::<Int32Array>().unwrap(); // Date32Array
         for (payment, date) in payments.into_iter().zip(dates) {
             self.add_pair(payment, date)?;
         }
-        let states_len = states.len();
-        if states_len < 3 {
-            return Ok(());
-        }
+
         let initial_guesses = states[2]
             .as_any()
             .downcast_ref::<ListArray>()
@@ -255,9 +504,7 @@
         for initial_guess in initial_guesses {
             self.set_initial_guess(initial_guess)?;
         }
-        if states_len < 4 {
-            return Ok(());
-        }
+
         let on_errors = states[3]
             .as_any()
             .downcast_ref::<ListArray>()
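
The state layout in the diff (four list columns, with dates carried as Int32 and the two optional arguments stored as lists of zero or one element) can be illustrated without DataFusion. The following is a simplified sketch, not the PR's code: `StateColumns` and `AccumulatorSketch` are hypothetical stand-ins for the `ScalarValue::List` columns and `XirrAccumulator`, and the overwrite-on-merge behaviour is an assumption, since `set_initial_guess` and `set_on_error` are not shown in the diff and may validate their input instead.

```rust
/// Mirrors the `ValueState` used in the diff: `Unset` serializes to an empty
/// list, `Set(v)` to a one-element list whose element may be null.
#[derive(Debug, Clone, Copy, PartialEq)]
enum ValueState {
    Unset,
    Set(Option<f64>),
}

impl ValueState {
    fn to_list(self) -> Vec<Option<f64>> {
        match self {
            ValueState::Unset => vec![],
            ValueState::Set(value) => vec![value],
        }
    }
}

/// Hypothetical plain-Rust stand-in for the four list-typed state columns.
#[derive(Debug, Clone, PartialEq)]
struct StateColumns {
    payments: Vec<f64>,
    dates: Vec<i32>, // Date32 values carried as Int32
    initial_guess: Vec<Option<f64>>,
    on_error: Vec<Option<f64>>,
}

#[derive(Debug, PartialEq)]
struct AccumulatorSketch {
    pairs: Vec<(f64, i32)>,
    initial_guess: ValueState,
    on_error: ValueState,
}

impl AccumulatorSketch {
    fn new() -> Self {
        AccumulatorSketch {
            pairs: vec![],
            initial_guess: ValueState::Unset,
            on_error: ValueState::Unset,
        }
    }

    /// Split the pairs into two parallel columns and encode the optional
    /// arguments as lists of zero or one element.
    fn state(&self) -> StateColumns {
        let (payments, dates): (Vec<f64>, Vec<i32>) = self.pairs.iter().copied().unzip();
        StateColumns {
            payments,
            dates,
            initial_guess: self.initial_guess.to_list(),
            on_error: self.on_error.to_list(),
        }
    }

    /// Fold another partial state into this accumulator.
    fn merge(&mut self, other: &StateColumns) {
        self.pairs
            .extend(other.payments.iter().copied().zip(other.dates.iter().copied()));
        // Assumption: a later value simply overwrites an earlier one.
        for guess in &other.initial_guess {
            self.initial_guess = ValueState::Set(*guess);
        }
        for on_error in &other.on_error {
            self.on_error = ValueState::Set(*on_error);
        }
    }
}
```

Merging the output of `state()` into a fresh accumulator reproduces the original, which is the property the `merge` and `merge_batch` paths rely on. Encoding "unset" as an empty list keeps it distinct from "set to null", which a plain nullable Float64 column could not express.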

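One detail of `cast_scalar_to_date32` that may be worth a second look: the timestamp branches divide with `/`, and Rust integer division truncates toward zero. The sketch below (hypothetical helper names, standard library only) contrasts that with `div_euclid`, which floors. Whether the difference matters here depends on whether pre-1970 timestamps ever reach this path, which the diff alone does not say.

```rust
const SECONDS_IN_DAY: i64 = 86_400;

/// Same arithmetic as the `TimestampSecond` branch in the diff:
/// `/` truncates toward zero.
fn days_truncating(timestamp_seconds: i64) -> i32 {
    (timestamp_seconds / SECONDS_IN_DAY) as i32
}

/// Alternative that floors, so instants before the epoch land on the
/// previous calendar day.
fn days_flooring(timestamp_seconds: i64) -> i32 {
    timestamp_seconds.div_euclid(SECONDS_IN_DAY) as i32
}
```

For non-negative timestamps the two agree. For one second before the epoch, `days_truncating(-1)` gives day 0 (1970-01-01), while `days_flooring(-1)` gives day -1 (1969-12-31).
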
@srh srh merged commit 785142d into master May 6, 2025
74 of 77 checks passed
@srh srh deleted the cubestore/xirr-udaf branch May 6, 2025 14:50