Skip to content

Commit e4c5b7b

Browse files
wtnclaude
andcommitted
fix: Support varying quantile values per group in group_by aggregation
Co-authored-by: Claude <noreply@anthropic.com>
1 parent c241260 commit e4c5b7b

File tree

5 files changed

+429
-20
lines changed

5 files changed

+429
-20
lines changed

crates/polars-core/src/frame/column/mod.rs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -828,6 +828,25 @@ impl Column {
828828
.into()
829829
}
830830

831+
/// # Safety
832+
///
833+
/// Does no bounds checks, groups must be correct.
834+
#[cfg(feature = "algorithm_group_by")]
835+
pub unsafe fn agg_varying_quantile(
836+
&self,
837+
groups: &GroupsType,
838+
quantiles: &[f64],
839+
method: QuantileMethod,
840+
) -> Self {
841+
// @scalar-opt
842+
843+
unsafe {
844+
self.as_materialized_series()
845+
.agg_varying_quantile(groups, quantiles, method)
846+
}
847+
.into()
848+
}
849+
831850
/// # Safety
832851
///
833852
/// Does no bounds checks, groups must be correct.

crates/polars-core/src/frame/group_by/aggregations/dispatch.rs

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -370,6 +370,73 @@ impl Series {
370370
}
371371
}
372372

373+
#[doc(hidden)]
374+
pub unsafe fn agg_varying_quantile(
375+
&self,
376+
groups: &GroupsType,
377+
quantiles: &[f64],
378+
method: QuantileMethod,
379+
) -> Series {
380+
// Prevent a rechunk for every individual group.
381+
let s = if groups.len() > 1 {
382+
self.rechunk()
383+
} else {
384+
self.clone()
385+
};
386+
387+
use DataType::*;
388+
match s.dtype() {
389+
Float32 => s
390+
.f32()
391+
.unwrap()
392+
.agg_varying_quantile(groups, quantiles, method),
393+
Float64 => s
394+
.f64()
395+
.unwrap()
396+
.agg_varying_quantile(groups, quantiles, method),
397+
#[cfg(feature = "dtype-decimal")]
398+
Decimal(_, _) => s
399+
.cast(&DataType::Float64)
400+
.unwrap()
401+
.agg_varying_quantile(groups, quantiles, method),
402+
#[cfg(feature = "dtype-datetime")]
403+
Datetime(tu, tz) => self
404+
.to_physical_repr()
405+
.agg_varying_quantile(groups, quantiles, method)
406+
.cast(&Int64)
407+
.unwrap()
408+
.into_datetime(*tu, tz.clone()),
409+
#[cfg(feature = "dtype-duration")]
410+
Duration(tu) => self
411+
.to_physical_repr()
412+
.agg_varying_quantile(groups, quantiles, method)
413+
.cast(&Int64)
414+
.unwrap()
415+
.into_duration(*tu),
416+
#[cfg(feature = "dtype-time")]
417+
Time => self
418+
.to_physical_repr()
419+
.agg_varying_quantile(groups, quantiles, method)
420+
.cast(&Int64)
421+
.unwrap()
422+
.into_time(),
423+
#[cfg(feature = "dtype-date")]
424+
Date => (self
425+
.to_physical_repr()
426+
.agg_varying_quantile(groups, quantiles, method)
427+
.cast(&Float64)
428+
.unwrap()
429+
* (US_IN_DAY as f64))
430+
.cast(&DataType::Int64)
431+
.unwrap()
432+
.into_datetime(TimeUnit::Microseconds, None),
433+
dt if dt.is_primitive_numeric() => {
434+
apply_method_physical_integer!(s, agg_varying_quantile, groups, quantiles, method)
435+
},
436+
_ => Series::full_null(PlSmallStr::EMPTY, groups.len(), s.dtype()),
437+
}
438+
}
439+
373440
#[doc(hidden)]
374441
pub unsafe fn agg_last(&self, groups: &GroupsType) -> Series {
375442
// Prevent a rechunk for every individual group.

crates/polars-core/src/frame/group_by/aggregations/mod.rs

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,17 @@ where
181181
ca.into_series()
182182
}
183183

184+
/// Same as `agg_helper_idx_on_all` but also passes the group index to the closure.
185+
fn agg_helper_idx_on_all_with_idx<T, F>(groups: &GroupsIdx, f: F) -> Series
186+
where
187+
F: Fn((usize, &IdxVec)) -> Option<T::Native> + Send + Sync,
188+
T: PolarsNumericType,
189+
{
190+
let ca: ChunkedArray<T> =
191+
POOL.install(|| groups.all().into_par_iter().enumerate().map(f).collect());
192+
ca.into_series()
193+
}
194+
184195
pub fn _agg_helper_slice<T, F>(groups: &[[IdxSize; 2]], f: F) -> Series
185196
where
186197
F: Fn([IdxSize; 2]) -> Option<T::Native> + Send + Sync,
@@ -190,6 +201,22 @@ where
190201
ca.into_series()
191202
}
192203

204+
/// Same as `_agg_helper_slice` but also passes the group index to the closure.
205+
fn _agg_helper_slice_with_idx<T, F>(groups: &[[IdxSize; 2]], f: F) -> Series
206+
where
207+
F: Fn(usize, [IdxSize; 2]) -> Option<T::Native> + Send + Sync,
208+
T: PolarsNumericType,
209+
{
210+
let ca: ChunkedArray<T> = POOL.install(|| {
211+
groups
212+
.par_iter()
213+
.enumerate()
214+
.map(|(idx, &g)| f(idx, g))
215+
.collect()
216+
});
217+
ca.into_series()
218+
}
219+
193220
pub fn _agg_helper_slice_no_null<T, F>(groups: &[[IdxSize; 2]], f: F) -> Series
194221
where
195222
F: Fn([IdxSize; 2]) -> T::Native + Send + Sync,
@@ -336,6 +363,64 @@ where
336363
}
337364
}
338365

366+
/// Compute quantile aggregation where each group can have a different quantile value.
367+
unsafe fn agg_varying_quantile_generic<T, K>(
368+
ca: &ChunkedArray<T>,
369+
groups: &GroupsType,
370+
quantiles: &[f64],
371+
method: QuantileMethod,
372+
) -> Series
373+
where
374+
T: PolarsNumericType,
375+
ChunkedArray<T>: QuantileDispatcher<K::Native>,
376+
K: PolarsNumericType,
377+
<K as datatypes::PolarsNumericType>::Native: num_traits::Float + quantile_filter::SealedRolling,
378+
{
379+
match groups {
380+
GroupsType::Idx(groups) => {
381+
let ca = ca.rechunk();
382+
agg_helper_idx_on_all_with_idx::<K, _>(groups, |(group_idx, idx)| {
383+
debug_assert!(idx.len() <= ca.len());
384+
let quantile = quantiles[group_idx];
385+
if !(0.0..=1.0).contains(&quantile) {
386+
return None;
387+
}
388+
match idx.len() {
389+
0 => None,
390+
1 => {
391+
let idx = idx[0] as usize;
392+
ca.get(idx).map(|v| NumCast::from(v).unwrap())
393+
},
394+
_ => {
395+
let take = { ca.take_unchecked(idx) };
396+
take._quantile(quantile, method).unwrap()
397+
},
398+
}
399+
})
400+
},
401+
GroupsType::Slice { groups, .. } => {
402+
_agg_helper_slice_with_idx::<K, _>(groups, |group_idx, [first, len]| {
403+
debug_assert!(first + len <= ca.len() as IdxSize);
404+
let quantile = quantiles[group_idx];
405+
if !(0.0..=1.0).contains(&quantile) {
406+
return None;
407+
}
408+
match len {
409+
0 => None,
410+
1 => ca.get(first as usize).map(|v| NumCast::from(v).unwrap()),
411+
_ => {
412+
let arr_group = _slice_from_offsets(ca, first, len);
413+
arr_group
414+
._quantile(quantile, method)
415+
.unwrap_unchecked()
416+
.map(|flt| NumCast::from(flt).unwrap_unchecked())
417+
},
418+
}
419+
})
420+
},
421+
}
422+
}
423+
339424
unsafe fn agg_median_generic<T, K>(ca: &ChunkedArray<T>, groups: &GroupsType) -> Series
340425
where
341426
T: PolarsNumericType,
@@ -905,6 +990,14 @@ impl Float32Chunked {
905990
) -> Series {
906991
agg_quantile_generic::<_, Float32Type>(self, groups, quantile, method)
907992
}
993+
pub(crate) unsafe fn agg_varying_quantile(
994+
&self,
995+
groups: &GroupsType,
996+
quantiles: &[f64],
997+
method: QuantileMethod,
998+
) -> Series {
999+
agg_varying_quantile_generic::<_, Float32Type>(self, groups, quantiles, method)
1000+
}
9081001
pub(crate) unsafe fn agg_median(&self, groups: &GroupsType) -> Series {
9091002
agg_median_generic::<_, Float32Type>(self, groups)
9101003
}
@@ -918,6 +1011,14 @@ impl Float64Chunked {
9181011
) -> Series {
9191012
agg_quantile_generic::<_, Float64Type>(self, groups, quantile, method)
9201013
}
1014+
pub(crate) unsafe fn agg_varying_quantile(
1015+
&self,
1016+
groups: &GroupsType,
1017+
quantiles: &[f64],
1018+
method: QuantileMethod,
1019+
) -> Series {
1020+
agg_varying_quantile_generic::<_, Float64Type>(self, groups, quantiles, method)
1021+
}
9211022
pub(crate) unsafe fn agg_median(&self, groups: &GroupsType) -> Series {
9221023
agg_median_generic::<_, Float64Type>(self, groups)
9231024
}
@@ -1110,6 +1211,14 @@ where
11101211
) -> Series {
11111212
agg_quantile_generic::<_, Float64Type>(self, groups, quantile, method)
11121213
}
1214+
pub(crate) unsafe fn agg_varying_quantile(
1215+
&self,
1216+
groups: &GroupsType,
1217+
quantiles: &[f64],
1218+
method: QuantileMethod,
1219+
) -> Series {
1220+
agg_varying_quantile_generic::<_, Float64Type>(self, groups, quantiles, method)
1221+
}
11131222
pub(crate) unsafe fn agg_median(&self, groups: &GroupsType) -> Series {
11141223
agg_median_generic::<_, Float64Type>(self, groups)
11151224
}

crates/polars-expr/src/expressions/aggregation.rs

Lines changed: 61 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -483,6 +483,22 @@ impl AggQuantileExpr {
483483
);
484484
quantile.get(0).unwrap().try_extract()
485485
}
486+
487+
fn get_quantile_from_scalar(&self, quantile_ac: &AggregationContext) -> PolarsResult<f64> {
488+
let quantile_col = quantile_ac.get_values();
489+
polars_ensure!(quantile_col.len() <= 1, ComputeError:
490+
"polars only supports computing a single quantile; \
491+
make sure the 'quantile' expression input produces a single quantile"
492+
);
493+
quantile_col.get(0).unwrap().try_extract()
494+
}
495+
496+
fn get_quantiles_per_group(&self, quantile_ac: &AggregationContext) -> PolarsResult<Vec<f64>> {
497+
let quantile_col = quantile_ac.get_values();
498+
let quantile_col = quantile_col.cast(&DataType::Float64)?;
499+
let quantile_ca = quantile_col.f64()?;
500+
Ok(quantile_ca.iter().map(|v| v.unwrap_or(f64::NAN)).collect())
501+
}
486502
}
487503

488504
impl PhysicalExpr for AggQuantileExpr {
@@ -505,6 +521,7 @@ impl PhysicalExpr for AggQuantileExpr {
505521
state: &ExecutionState,
506522
) -> PolarsResult<AggregationContext<'a>> {
507523
let mut ac = self.input.evaluate_on_groups(df, groups, state)?;
524+
let quantile_ac = self.quantile.evaluate_on_groups(df, groups, state)?;
508525

509526
// AggregatedScalar has no defined group structure. We fix it up here, so that we can
510527
// reliably call `agg_quantile` functions with the groups.
@@ -513,27 +530,52 @@ impl PhysicalExpr for AggQuantileExpr {
513530
// don't change names by aggregations as is done in polars-core
514531
let keep_name = ac.get_values().name().clone();
515532

516-
let quantile = self.get_quantile(df, state)?;
533+
// Check if quantile is a literal/scalar (same for all groups) or varies per group
534+
let is_uniform_quantile = quantile_ac.is_literal()
535+
|| matches!(quantile_ac.agg_state(), AggState::LiteralScalar(_));
517536

518-
if let AggState::LiteralScalar(c) = &mut ac.state {
519-
*c = c
520-
.quantile_reduce(quantile, self.method)?
521-
.into_column(keep_name);
522-
return Ok(ac);
523-
}
537+
if is_uniform_quantile {
538+
// Fast path: single quantile value for all groups
539+
let quantile = self.get_quantile_from_scalar(&quantile_ac)?;
524540

525-
// SAFETY:
526-
// groups are in bounds
527-
let mut agg = unsafe {
528-
ac.flat_naive()
529-
.into_owned()
530-
.agg_quantile(ac.groups(), quantile, self.method)
531-
};
532-
agg.rename(keep_name);
533-
Ok(AggregationContext::from_agg_state(
534-
AggregatedScalar(agg),
535-
Cow::Borrowed(groups),
536-
))
541+
if let AggState::LiteralScalar(c) = &mut ac.state {
542+
*c = c
543+
.quantile_reduce(quantile, self.method)?
544+
.into_column(keep_name);
545+
return Ok(ac);
546+
}
547+
548+
// SAFETY:
549+
// groups are in bounds
550+
let mut agg = unsafe {
551+
ac.flat_naive()
552+
.into_owned()
553+
.agg_quantile(ac.groups(), quantile, self.method)
554+
};
555+
agg.rename(keep_name);
556+
Ok(AggregationContext::from_agg_state(
557+
AggregatedScalar(agg),
558+
Cow::Borrowed(groups),
559+
))
560+
} else {
561+
// Different quantile value per group
562+
let quantiles = self.get_quantiles_per_group(&quantile_ac)?;
563+
564+
// SAFETY:
565+
// groups are in bounds
566+
let mut agg = unsafe {
567+
ac.flat_naive().into_owned().agg_varying_quantile(
568+
ac.groups(),
569+
&quantiles,
570+
self.method,
571+
)
572+
};
573+
agg.rename(keep_name);
574+
Ok(AggregationContext::from_agg_state(
575+
AggregatedScalar(agg),
576+
Cow::Borrowed(groups),
577+
))
578+
}
537579
}
538580

539581
fn to_field(&self, input_schema: &Schema) -> PolarsResult<Field> {

0 commit comments

Comments
 (0)