Skip to content

Commit 7929246

Browse files
authored
feat(rust, python): prefer streaming groupby if partitionable (#5580)
1 parent 4e35d9c commit 7929246

File tree

23 files changed

+186
-73
lines changed

23 files changed

+186
-73
lines changed

polars/polars-core/src/frame/from.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,3 +24,13 @@ impl TryFrom<StructArray> for DataFrame {
2424
DataFrame::new(columns)
2525
}
2626
}
27+
28+
impl From<&Schema> for DataFrame {
29+
fn from(schema: &Schema) -> Self {
30+
let cols = schema
31+
.iter()
32+
.map(|(name, dtype)| Series::new_empty(name, dtype))
33+
.collect();
34+
DataFrame::new_no_checks(cols)
35+
}
36+
}

polars/polars-core/src/utils/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@ pub fn split_df_as_ref(df: &DataFrame, n: usize) -> PolarsResult<Vec<DataFrame>>
199199
#[doc(hidden)]
200200
/// Split a [`DataFrame`] into `n` parts. We take a `&mut` to be able to repartition/align chunks.
201201
pub fn split_df(df: &mut DataFrame, n: usize) -> PolarsResult<Vec<DataFrame>> {
202-
if n == 0 {
202+
if n == 0 || df.height() == 0 {
203203
return Ok(vec![df.clone()]);
204204
}
205205
// make sure that chunks are aligned.

polars/polars-lazy/Cargo.toml

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,15 +38,15 @@ temporal = ["dtype-datetime", "dtype-date", "dtype-time", "dtype-duration", "pol
3838
fmt = ["polars-core/fmt", "polars-plan/fmt"]
3939
strings = ["polars-plan/strings"]
4040
future = []
41-
dtype-u8 = ["polars-plan/dtype-u8"]
42-
dtype-u16 = ["polars-plan/dtype-u16"]
43-
dtype-i8 = ["polars-plan/dtype-i8"]
44-
dtype-i16 = ["polars-plan/dtype-i16"]
41+
dtype-u8 = ["polars-plan/dtype-u8", "polars-pipe/dtype-u8"]
42+
dtype-u16 = ["polars-plan/dtype-u16", "polars-pipe/dtype-u16"]
43+
dtype-i8 = ["polars-plan/dtype-i8", "polars-pipe/dtype-i8"]
44+
dtype-i16 = ["polars-plan/dtype-i16", "polars-pipe/dtype-i16"]
4545
dtype-date = ["polars-plan/dtype-date", "polars-time/dtype-date", "temporal"]
4646
dtype-datetime = ["polars-plan/dtype-datetime", "polars-time/dtype-datetime", "temporal"]
4747
dtype-duration = ["polars-plan/dtype-duration", "polars-time/dtype-duration", "temporal"]
4848
dtype-time = ["polars-core/dtype-time", "temporal"]
49-
dtype-categorical = ["polars-plan/dtype-categorical"]
49+
dtype-categorical = ["polars-plan/dtype-categorical", "polars-pipe/dtype-categorical"]
5050
dtype-struct = ["polars-plan/dtype-struct"]
5151
dtype-binary = ["polars-plan/dtype-binary"]
5252
object = ["polars-plan/object"]

polars/polars-lazy/polars-pipe/Cargo.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,3 +26,8 @@ csv-file = ["polars-plan/csv-file", "polars-io/csv-file"]
2626
parquet = ["polars-plan/parquet", "polars-io/parquet"]
2727
nightly = ["polars-core/nightly", "polars-utils/nightly", "hashbrown/nightly"]
2828
cross_join = ["polars-core/cross_join"]
29+
dtype-u8 = ["polars-core/dtype-u8"]
30+
dtype-u16 = ["polars-core/dtype-u16"]
31+
dtype-i8 = ["polars-core/dtype-i8"]
32+
dtype-i16 = ["polars-core/dtype-i16"]
33+
dtype-categorical = ["polars-core/dtype-categorical"]

polars/polars-lazy/polars-pipe/src/executors/sinks/groupby/aggregates/convert.rs

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ use crate::executors::sinks::groupby::aggregates::count::CountAgg;
1414
use crate::executors::sinks::groupby::aggregates::first::FirstAgg;
1515
use crate::executors::sinks::groupby::aggregates::last::LastAgg;
1616
use crate::executors::sinks::groupby::aggregates::mean::MeanAgg;
17+
use crate::executors::sinks::groupby::aggregates::null::NullAgg;
1718
use crate::executors::sinks::groupby::aggregates::{AggregateFunction, SumAgg};
1819
use crate::expressions::PhysicalPipedExpr;
1920
use crate::operators::DataChunk;
@@ -89,14 +90,24 @@ where
8990
{
9091
match expr_arena.get(node) {
9192
AExpr::Alias(input, _) => convert_to_hash_agg(*input, expr_arena, schema, to_physical),
92-
AExpr::Count | AExpr::Agg(AAggExpr::Count(_)) => (
93+
AExpr::Count => (
9394
Arc::new(Count {}),
9495
AggregateFunction::Count(CountAgg::new()),
9596
),
9697
AExpr::Agg(agg) => match agg {
9798
AAggExpr::Sum(input) => {
9899
let phys_expr = to_physical(*input, expr_arena).unwrap();
99-
let agg_fn = match phys_expr.field(schema).unwrap().dtype.to_physical() {
100+
let logical_dtype = phys_expr.field(schema).unwrap().dtype;
101+
102+
#[cfg(feature = "dtype-categorical")]
103+
if matches!(logical_dtype, DataType::Categorical(_)) {
104+
return (
105+
phys_expr,
106+
AggregateFunction::Null(NullAgg::new(logical_dtype)),
107+
);
108+
}
109+
110+
let agg_fn = match logical_dtype.to_physical() {
100111
// Boolean is aggregated as the IDX type.
101112
DataType::Boolean => {
102113
if std::mem::size_of::<IdxSize>() == 4 {
@@ -117,19 +128,28 @@ where
117128
DataType::Int64 => AggregateFunction::SumI64(SumAgg::<i64>::new()),
118129
DataType::Float32 => AggregateFunction::SumF32(SumAgg::<f32>::new()),
119130
DataType::Float64 => AggregateFunction::SumF64(SumAgg::<f64>::new()),
120-
_ => unreachable!(),
131+
dt => AggregateFunction::Null(NullAgg::new(dt)),
121132
};
122133
(phys_expr, agg_fn)
123134
}
124135
AAggExpr::Mean(input) => {
125136
let phys_expr = to_physical(*input, expr_arena).unwrap();
126-
let agg_fn = match phys_expr.field(schema).unwrap().dtype.to_physical() {
137+
138+
let logical_dtype = phys_expr.field(schema).unwrap().dtype;
139+
#[cfg(feature = "dtype-categorical")]
140+
if matches!(logical_dtype, DataType::Categorical(_)) {
141+
return (
142+
phys_expr,
143+
AggregateFunction::Null(NullAgg::new(logical_dtype)),
144+
);
145+
}
146+
let agg_fn = match logical_dtype.to_physical() {
127147
dt if dt.is_integer() => AggregateFunction::MeanF64(MeanAgg::<f64>::new()),
128148
// Boolean is aggregated as the IDX type.
129149
DataType::Boolean => AggregateFunction::MeanF64(MeanAgg::<f64>::new()),
130150
DataType::Float32 => AggregateFunction::MeanF32(MeanAgg::<f32>::new()),
131151
DataType::Float64 => AggregateFunction::MeanF64(MeanAgg::<f64>::new()),
132-
_ => unreachable!(),
152+
dt => AggregateFunction::Null(NullAgg::new(dt)),
133153
};
134154
(phys_expr, agg_fn)
135155
}
@@ -143,6 +163,10 @@ where
143163
let dtype = phys_expr.field(schema).unwrap().dtype;
144164
(phys_expr, AggregateFunction::Last(LastAgg::new(dtype)))
145165
}
166+
AAggExpr::Count(input) => {
167+
let phys_expr = to_physical(*input, expr_arena).unwrap();
168+
(phys_expr, AggregateFunction::Count(CountAgg::new()))
169+
}
146170
agg => panic!("{:?} not yet implemented.", agg),
147171
},
148172
_ => todo!(),

polars/polars-lazy/polars-pipe/src/executors/sinks/groupby/aggregates/count.rs

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,6 @@ impl AggregateFn for CountAgg {
3838
self.count += other.count;
3939
}
4040

41-
fn split(&self) -> Box<dyn AggregateFn> {
42-
Box::new(Self::new())
43-
}
44-
4541
fn finalize(&mut self) -> AnyValue<'static> {
4642
AnyValue::from(self.count)
4743
}

polars/polars-lazy/polars-pipe/src/executors/sinks/groupby/aggregates/first.rs

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,10 +44,6 @@ impl AggregateFn for FirstAgg {
4444
};
4545
}
4646

47-
fn split(&self) -> Box<dyn AggregateFn> {
48-
Box::new(Self::new(self.dtype.clone()))
49-
}
50-
5147
fn finalize(&mut self) -> AnyValue<'static> {
5248
std::mem::take(&mut self.first).unwrap_or(AnyValue::Null)
5349
}

polars/polars-lazy/polars-pipe/src/executors/sinks/groupby/aggregates/interface.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ use crate::executors::sinks::groupby::aggregates::count::CountAgg;
88
use crate::executors::sinks::groupby::aggregates::first::FirstAgg;
99
use crate::executors::sinks::groupby::aggregates::last::LastAgg;
1010
use crate::executors::sinks::groupby::aggregates::mean::MeanAgg;
11+
use crate::executors::sinks::groupby::aggregates::null::NullAgg;
1112
use crate::executors::sinks::groupby::aggregates::SumAgg;
1213
use crate::operators::IdxSize;
1314

@@ -54,8 +55,6 @@ pub trait AggregateFn: Send + Sync {
5455

5556
fn combine(&mut self, other: &dyn Any);
5657

57-
fn split(&self) -> Box<dyn AggregateFn>;
58-
5958
fn finalize(&mut self) -> AnyValue<'static>;
6059

6160
fn as_any(&self) -> &dyn Any;
@@ -76,10 +75,10 @@ pub enum AggregateFunction {
7675
SumI64(SumAgg<i64>),
7776
MeanF32(MeanAgg<f32>),
7877
MeanF64(MeanAgg<f64>),
79-
// place holder for any aggregate function
80-
// this is not preferred because of the extra
81-
// indirection
82-
// Other(Box<dyn AggregateFn>)
78+
Null(NullAgg), // place holder for any aggregate function
79+
// this is not preferred because of the extra
80+
// indirection
81+
// Other(Box<dyn AggregateFn>)
8382
}
8483

8584
impl AggregateFunction {
@@ -97,6 +96,7 @@ impl AggregateFunction {
9796
MeanF32(_) => MeanF32(MeanAgg::new()),
9897
MeanF64(_) => MeanF64(MeanAgg::new()),
9998
Count(_) => Count(CountAgg::new()),
99+
Null(a) => Null(a.clone()),
100100
}
101101
}
102102
}

polars/polars-lazy/polars-pipe/src/executors/sinks/groupby/aggregates/last.rs

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,6 @@ impl AggregateFn for LastAgg {
4242
};
4343
}
4444

45-
fn split(&self) -> Box<dyn AggregateFn> {
46-
Box::new(Self::new(self.dtype.clone()))
47-
}
48-
4945
fn finalize(&mut self) -> AnyValue<'static> {
5046
std::mem::take(&mut self.last)
5147
}

polars/polars-lazy/polars-pipe/src/executors/sinks/groupby/aggregates/mean.rs

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -100,10 +100,6 @@ impl<K: NumericNative + Add<Output = K> + NumCast> AggregateFn for MeanAgg<K> {
100100
};
101101
}
102102

103-
fn split(&self) -> Box<dyn AggregateFn> {
104-
Box::new(Self::new())
105-
}
106-
107103
fn finalize(&mut self) -> AnyValue<'static> {
108104
if let Some(val) = self.sum {
109105
unsafe {

0 commit comments

Comments
 (0)