From 27eff84449a7b3f05a207948eca8d7d39b9461ff Mon Sep 17 00:00:00 2001 From: Kumar Ujjawal Date: Wed, 11 Mar 2026 09:59:04 +0530 Subject: [PATCH] fix(stats): widen sum_value integer arithmetic to SUM-compatible types --- datafusion/common/src/stats.rs | 94 ++++++++++++++++++- datafusion/datasource/src/statistics.rs | 81 +++++++++++++++- datafusion/physical-expr/src/projection.rs | 42 ++++++++- .../physical-plan/src/joins/cross_join.rs | 85 +++++++++++++---- datafusion/physical-plan/src/union.rs | 2 +- 5 files changed, 272 insertions(+), 32 deletions(-) diff --git a/datafusion/common/src/stats.rs b/datafusion/common/src/stats.rs index 3d4d9b6c6c4ae..fb506f0b9d273 100644 --- a/datafusion/common/src/stats.rs +++ b/datafusion/common/src/stats.rs @@ -180,6 +180,14 @@ impl Precision { } impl Precision { + fn sum_data_type(data_type: &DataType) -> DataType { + match data_type { + DataType::Int8 | DataType::Int16 | DataType::Int32 => DataType::Int64, + DataType::UInt8 | DataType::UInt16 | DataType::UInt32 => DataType::UInt64, + _ => data_type.clone(), + } + } + /// Calculates the sum of two (possibly inexact) [`ScalarValue`] values, /// conservatively propagating exactness information. If one of the input /// values is [`Precision::Absent`], the result is `Absent` too. @@ -198,6 +206,46 @@ impl Precision { } } + /// Casts integer values to the wider SQL `SUM` return type. + /// + /// This narrows overflow risk when `sum_value` statistics are merged: + /// `Int8/Int16/Int32 -> Int64` and `UInt8/UInt16/UInt32 -> UInt64`. + pub fn cast_to_sum_type(&self) -> Precision { + match self { + Precision::Exact(value) => { + let source_type = value.data_type(); + let target_type = Self::sum_data_type(&source_type); + if source_type == target_type { + Precision::Exact(value.clone()) + } else { + value + .cast_to(&target_type) + .map(Precision::Exact) + .unwrap_or(Precision::Absent) + } + } + Precision::Inexact(value) => { + let source_type = value.data_type(); + let target_type = Self::sum_data_type(&source_type); + if source_type == target_type { + Precision::Inexact(value.clone()) + } else { + value + .cast_to(&target_type) + .map(Precision::Inexact) + .unwrap_or(Precision::Absent) + } + } + Precision::Absent => Precision::Absent, + } + } + + /// SUM-style addition with integer widening to match SQL `SUM` return + /// types for smaller integral inputs. + pub fn add_for_sum(&self, other: &Precision) -> Precision { + self.cast_to_sum_type().add(&other.cast_to_sum_type()) + } + /// Calculates the difference of two (possibly inexact) [`ScalarValue`] values, /// conservatively propagating exactness information. If one of the input /// values is [`Precision::Absent`], the result is `Absent` too. @@ -636,7 +684,8 @@ impl Statistics { col_stats.null_count = col_stats.null_count.add(&item_col_stats.null_count); col_stats.max_value = col_stats.max_value.max(&item_col_stats.max_value); col_stats.min_value = col_stats.min_value.min(&item_col_stats.min_value); - col_stats.sum_value = col_stats.sum_value.add(&item_col_stats.sum_value); + col_stats.sum_value = + col_stats.sum_value.add_for_sum(&item_col_stats.sum_value); col_stats.distinct_count = Precision::Absent; col_stats.byte_size = col_stats.byte_size.add(&item_col_stats.byte_size); } @@ -948,6 +997,45 @@ mod tests { assert_eq!(precision.add(&Precision::Absent), Precision::Absent); } + #[test] + fn test_add_for_sum_scalar_integer_widening() { + let precision = Precision::Exact(ScalarValue::Int32(Some(42))); + + assert_eq!( + precision.add_for_sum(&Precision::Exact(ScalarValue::Int32(Some(23)))), + Precision::Exact(ScalarValue::Int64(Some(65))), + ); + assert_eq!( + precision.add_for_sum(&Precision::Inexact(ScalarValue::Int32(Some(23)))), + Precision::Inexact(ScalarValue::Int64(Some(65))), + ); + } + + #[test] + fn test_add_for_sum_prevents_int32_overflow() { + let lhs = Precision::Exact(ScalarValue::Int32(Some(i32::MAX))); + let rhs = Precision::Exact(ScalarValue::Int32(Some(1))); + + assert_eq!( + lhs.add_for_sum(&rhs), + Precision::Exact(ScalarValue::Int64(Some(i64::from(i32::MAX) + 1))), + ); + } + + #[test] + fn test_add_for_sum_scalar_unsigned_integer_widening() { + let precision = Precision::Exact(ScalarValue::UInt32(Some(42))); + + assert_eq!( + precision.add_for_sum(&Precision::Exact(ScalarValue::UInt32(Some(23)))), + Precision::Exact(ScalarValue::UInt64(Some(65))), + ); + assert_eq!( + precision.add_for_sum(&Precision::Inexact(ScalarValue::UInt32(Some(23)))), + Precision::Inexact(ScalarValue::UInt64(Some(65))), + ); + } + #[test] fn test_sub() { let precision1 = Precision::Exact(42); @@ -1193,7 +1281,7 @@ mod tests { ); assert_eq!( col1_stats.sum_value, - Precision::Exact(ScalarValue::Int32(Some(1100))) + Precision::Exact(ScalarValue::Int64(Some(1100))) ); // 500 + 600 let col2_stats = &summary_stats.column_statistics[1]; @@ -1208,7 +1296,7 @@ mod tests { ); assert_eq!( col2_stats.sum_value, - Precision::Exact(ScalarValue::Int32(Some(2200))) + Precision::Exact(ScalarValue::Int64(Some(2200))) ); // 1000 + 1200 } diff --git a/datafusion/datasource/src/statistics.rs b/datafusion/datasource/src/statistics.rs index b1a56e096c222..e5a1e4613b3d4 100644 --- a/datafusion/datasource/src/statistics.rs +++ b/datafusion/datasource/src/statistics.rs @@ -293,7 +293,7 @@ fn sort_columns_from_physical_sort_exprs( since = "47.0.0", note = "Please use `get_files_with_limit` and `compute_all_files_statistics` instead" )] -#[expect(unused)] +#[cfg_attr(not(test), expect(unused))] pub async fn get_statistics_with_limit( all_files: impl Stream)>>, file_schema: SchemaRef, @@ -329,7 +329,7 @@ pub async fn get_statistics_with_limit( col_stats_set[index].null_count = file_column.null_count; col_stats_set[index].max_value = file_column.max_value; col_stats_set[index].min_value = file_column.min_value; - col_stats_set[index].sum_value = file_column.sum_value; + col_stats_set[index].sum_value = file_column.sum_value.cast_to_sum_type(); } // If the number of rows exceeds the limit, we can stop processing @@ -374,7 +374,7 @@ pub async fn get_statistics_with_limit( col_stats.null_count = col_stats.null_count.add(file_nc); col_stats.max_value = col_stats.max_value.max(file_max); col_stats.min_value = col_stats.min_value.min(file_min); - col_stats.sum_value = col_stats.sum_value.add(file_sum); + col_stats.sum_value = col_stats.sum_value.add_for_sum(file_sum); col_stats.byte_size = col_stats.byte_size.add(file_sbs); } @@ -497,3 +497,78 @@ pub fn add_row_stats( ) -> Precision { file_num_rows.add(&num_rows) } + +#[cfg(test)] +mod tests { + use super::*; + use crate::PartitionedFile; + use arrow::datatypes::{DataType, Field, Schema}; + use futures::stream; + + fn file_stats(sum: u32) -> Statistics { + Statistics { + num_rows: Precision::Exact(1), + total_byte_size: Precision::Exact(4), + column_statistics: vec![ColumnStatistics { + null_count: Precision::Exact(0), + max_value: Precision::Exact(ScalarValue::UInt32(Some(sum))), + min_value: Precision::Exact(ScalarValue::UInt32(Some(sum))), + sum_value: Precision::Exact(ScalarValue::UInt32(Some(sum))), + distinct_count: Precision::Exact(1), + byte_size: Precision::Exact(4), + }], + } + } + + #[tokio::test] + #[expect(deprecated)] + async fn test_get_statistics_with_limit_casts_first_file_sum_to_sum_type() + -> Result<()> { + let schema = + Arc::new(Schema::new(vec![Field::new("c1", DataType::UInt32, true)])); + + let files = stream::iter(vec![Ok(( + PartitionedFile::new("f1.parquet", 1), + Arc::new(file_stats(100)), + ))]); + + let (_group, stats) = + get_statistics_with_limit(files, schema, None, false).await?; + + assert_eq!( + stats.column_statistics[0].sum_value, + Precision::Exact(ScalarValue::UInt64(Some(100))) + ); + + Ok(()) + } + + #[tokio::test] + #[expect(deprecated)] + async fn test_get_statistics_with_limit_merges_sum_with_unsigned_widening() + -> Result<()> { + let schema = + Arc::new(Schema::new(vec![Field::new("c1", DataType::UInt32, true)])); + + let files = stream::iter(vec![ + Ok(( + PartitionedFile::new("f1.parquet", 1), + Arc::new(file_stats(100)), + )), + Ok(( + PartitionedFile::new("f2.parquet", 1), + Arc::new(file_stats(200)), + )), + ]); + + let (_group, stats) = + get_statistics_with_limit(files, schema, None, true).await?; + + assert_eq!( + stats.column_statistics[0].sum_value, + Precision::Exact(ScalarValue::UInt64(Some(300))) + ); + + Ok(()) + } +} diff --git a/datafusion/physical-expr/src/projection.rs b/datafusion/physical-expr/src/projection.rs index dbbd289415277..de6377bddc012 100644 --- a/datafusion/physical-expr/src/projection.rs +++ b/datafusion/physical-expr/src/projection.rs @@ -693,12 +693,15 @@ impl ProjectionExprs { Precision::Absent }; - let sum_value = Precision::::from(stats.num_rows) - .cast_to(&value.data_type()) - .ok() - .map(|row_count| { - Precision::Exact(value.clone()).multiply(&row_count) + let widened_sum = Precision::Exact(value.clone()).cast_to_sum_type(); + let sum_value = widened_sum + .get_value() + .and_then(|sum| { + Precision::::from(stats.num_rows) + .cast_to(&sum.data_type()) + .ok() }) + .map(|row_count| widened_sum.multiply(&row_count)) .unwrap_or(Precision::Absent); ColumnStatistics { @@ -2866,6 +2869,35 @@ pub(crate) mod tests { Ok(()) } + #[test] + fn test_project_statistics_with_i32_literal_sum_widens_to_i64() -> Result<()> { + let input_stats = get_stats(); + let input_schema = get_schema(); + + let projection = ProjectionExprs::new(vec![ + ProjectionExpr { + expr: Arc::new(Literal::new(ScalarValue::Int32(Some(10)))), + alias: "constant".to_string(), + }, + ProjectionExpr { + expr: Arc::new(Column::new("col0", 0)), + alias: "num".to_string(), + }, + ]); + + let output_stats = projection.project_statistics( + input_stats, + &projection.project_schema(&input_schema)?, + )?; + + assert_eq!( + output_stats.column_statistics[0].sum_value, + Precision::Exact(ScalarValue::Int64(Some(50))) + ); + + Ok(()) + } + // Test statistics calculation for NULL literal (constant NULL column) #[test] fn test_project_statistics_with_null_literal() -> Result<()> { diff --git a/datafusion/physical-plan/src/joins/cross_join.rs b/datafusion/physical-plan/src/joins/cross_join.rs index 342cb7e70a78b..44dd194945ea2 100644 --- a/datafusion/physical-plan/src/joins/cross_join.rs +++ b/datafusion/physical-plan/src/joins/cross_join.rs @@ -447,32 +447,34 @@ fn stats_cartesian_product( // Min, max and distinct_count on the other hand are invariants. let cross_join_stats = left_col_stats .into_iter() - .map(|s| ColumnStatistics { - null_count: s.null_count.multiply(&right_row_count), - distinct_count: s.distinct_count, - min_value: s.min_value, - max_value: s.max_value, - sum_value: s - .sum_value - .get_value() - // Cast the row count into the same type as any existing sum value - .and_then(|v| { - Precision::::from(right_row_count) - .cast_to(&v.data_type()) - .ok() - }) - .map(|row_count| s.sum_value.multiply(&row_count)) - .unwrap_or(Precision::Absent), - byte_size: Precision::Absent, + .map(|s| { + let widened_sum = s.sum_value.cast_to_sum_type(); + ColumnStatistics { + null_count: s.null_count.multiply(&right_row_count), + distinct_count: s.distinct_count, + min_value: s.min_value, + max_value: s.max_value, + sum_value: widened_sum + .get_value() + // Cast the row count into the same type as any existing sum value + .and_then(|v| { + Precision::::from(right_row_count) + .cast_to(&v.data_type()) + .ok() + }) + .map(|row_count| widened_sum.multiply(&row_count)) + .unwrap_or(Precision::Absent), + byte_size: Precision::Absent, + } }) .chain(right_col_stats.into_iter().map(|s| { + let widened_sum = s.sum_value.cast_to_sum_type(); ColumnStatistics { null_count: s.null_count.multiply(&left_row_count), distinct_count: s.distinct_count, min_value: s.min_value, max_value: s.max_value, - sum_value: s - .sum_value + sum_value: widened_sum .get_value() // Cast the row count into the same type as any existing sum value .and_then(|v| { @@ -480,7 +482,7 @@ fn stats_cartesian_product( .cast_to(&v.data_type()) .ok() }) - .map(|row_count| s.sum_value.multiply(&row_count)) + .map(|row_count| widened_sum.multiply(&row_count)) .unwrap_or(Precision::Absent), byte_size: Precision::Absent, } @@ -864,6 +866,49 @@ mod tests { assert_eq!(result, expected); } + #[tokio::test] + async fn test_stats_cartesian_product_unsigned_sum_widens_to_u64() { + let left_row_count = 2; + let right_row_count = 3; + + let left = Statistics { + num_rows: Precision::Exact(left_row_count), + total_byte_size: Precision::Exact(10), + column_statistics: vec![ColumnStatistics { + distinct_count: Precision::Exact(2), + max_value: Precision::Exact(ScalarValue::UInt32(Some(10))), + min_value: Precision::Exact(ScalarValue::UInt32(Some(1))), + sum_value: Precision::Exact(ScalarValue::UInt32(Some(7))), + null_count: Precision::Exact(0), + byte_size: Precision::Absent, + }], + }; + + let right = Statistics { + num_rows: Precision::Exact(right_row_count), + total_byte_size: Precision::Exact(10), + column_statistics: vec![ColumnStatistics { + distinct_count: Precision::Exact(3), + max_value: Precision::Exact(ScalarValue::UInt32(Some(12))), + min_value: Precision::Exact(ScalarValue::UInt32(Some(0))), + sum_value: Precision::Exact(ScalarValue::UInt32(Some(11))), + null_count: Precision::Exact(0), + byte_size: Precision::Absent, + }], + }; + + let result = stats_cartesian_product(left, right); + + assert_eq!( + result.column_statistics[0].sum_value, + Precision::Exact(ScalarValue::UInt64(Some(21))) + ); + assert_eq!( + result.column_statistics[1].sum_value, + Precision::Exact(ScalarValue::UInt64(Some(22))) + ); + } + #[tokio::test] async fn test_join() -> Result<()> { let task_ctx = Arc::new(TaskContext::default()); diff --git a/datafusion/physical-plan/src/union.rs b/datafusion/physical-plan/src/union.rs index 9fc02e730d022..343c38a774cc4 100644 --- a/datafusion/physical-plan/src/union.rs +++ b/datafusion/physical-plan/src/union.rs @@ -825,7 +825,7 @@ fn col_stats_union( left.distinct_count = Precision::Absent; left.min_value = left.min_value.min(&right.min_value); left.max_value = left.max_value.max(&right.max_value); - left.sum_value = left.sum_value.add(&right.sum_value); + left.sum_value = left.sum_value.add_for_sum(&right.sum_value); left.null_count = left.null_count.add(&right.null_count); left