Skip to content

Commit eb2a7b3

Browse files
wtn and claude committed
fix: Support varying quantile values per group in group_by aggregation
Co-authored-by: Claude <noreply@anthropic.com>
1 parent 99835ea commit eb2a7b3

File tree

5 files changed

+437
-20
lines changed

5 files changed

+437
-20
lines changed

crates/polars-core/src/frame/column/mod.rs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -828,6 +828,25 @@ impl Column {
828828
.into()
829829
}
830830

831+
/// Aggregates each group to a quantile, where the quantile may differ between groups.
///
/// The column is materialized as a `Series` and the work is delegated to the `Series`
/// method of the same name; the resulting `Series` is converted back into a column.
///
/// # Safety
///
/// Does no bounds checks, groups must be correct.
#[cfg(feature = "algorithm_group_by")]
pub unsafe fn agg_varying_quantile(
    &self,
    groups: &GroupsType,
    quantiles: &[f64],
    method: QuantileMethod,
) -> Self {
    // @scalar-opt

    let series = self.as_materialized_series();
    // SAFETY: the caller upholds the contract documented on this function.
    let aggregated = unsafe { series.agg_varying_quantile(groups, quantiles, method) };
    aggregated.into()
}
849+
831850
/// # Safety
832851
///
833852
/// Does no bounds checks, groups must be correct.

crates/polars-core/src/frame/group_by/aggregations/dispatch.rs

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -370,6 +370,73 @@ impl Series {
370370
}
371371
}
372372

373+
#[doc(hidden)]
374+
pub unsafe fn agg_varying_quantile(
375+
&self,
376+
groups: &GroupsType,
377+
quantiles: &[f64],
378+
method: QuantileMethod,
379+
) -> Series {
380+
// Prevent a rechunk for every individual group.
381+
let s = if groups.len() > 1 {
382+
self.rechunk()
383+
} else {
384+
self.clone()
385+
};
386+
387+
use DataType::*;
388+
match s.dtype() {
389+
Float32 => s
390+
.f32()
391+
.unwrap()
392+
.agg_varying_quantile(groups, quantiles, method),
393+
Float64 => s
394+
.f64()
395+
.unwrap()
396+
.agg_varying_quantile(groups, quantiles, method),
397+
#[cfg(feature = "dtype-decimal")]
398+
Decimal(_, _) => s
399+
.cast(&DataType::Float64)
400+
.unwrap()
401+
.agg_varying_quantile(groups, quantiles, method),
402+
#[cfg(feature = "dtype-datetime")]
403+
Datetime(tu, tz) => self
404+
.to_physical_repr()
405+
.agg_varying_quantile(groups, quantiles, method)
406+
.cast(&Int64)
407+
.unwrap()
408+
.into_datetime(*tu, tz.clone()),
409+
#[cfg(feature = "dtype-duration")]
410+
Duration(tu) => self
411+
.to_physical_repr()
412+
.agg_varying_quantile(groups, quantiles, method)
413+
.cast(&Int64)
414+
.unwrap()
415+
.into_duration(*tu),
416+
#[cfg(feature = "dtype-time")]
417+
Time => self
418+
.to_physical_repr()
419+
.agg_varying_quantile(groups, quantiles, method)
420+
.cast(&Int64)
421+
.unwrap()
422+
.into_time(),
423+
#[cfg(feature = "dtype-date")]
424+
Date => (self
425+
.to_physical_repr()
426+
.agg_varying_quantile(groups, quantiles, method)
427+
.cast(&Float64)
428+
.unwrap()
429+
* (US_IN_DAY as f64))
430+
.cast(&DataType::Int64)
431+
.unwrap()
432+
.into_datetime(TimeUnit::Microseconds, None),
433+
dt if dt.is_primitive_numeric() => {
434+
apply_method_physical_integer!(s, agg_varying_quantile, groups, quantiles, method)
435+
},
436+
_ => Series::full_null(PlSmallStr::EMPTY, groups.len(), s.dtype()),
437+
}
438+
}
439+
373440
#[doc(hidden)]
374441
pub unsafe fn agg_last(&self, groups: &GroupsType) -> Series {
375442
// Prevent a rechunk for every individual group.

crates/polars-core/src/frame/group_by/aggregations/mod.rs

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,17 @@ where
179179
ca.into_series()
180180
}
181181

182+
/// Same as `agg_helper_idx_on_all` but also passes the group index to the closure.
183+
fn agg_helper_idx_on_all_with_idx<T, F>(groups: &GroupsIdx, f: F) -> Series
184+
where
185+
F: Fn((usize, &IdxVec)) -> Option<T::Native> + Send + Sync,
186+
T: PolarsNumericType,
187+
{
188+
let ca: ChunkedArray<T> =
189+
POOL.install(|| groups.all().into_par_iter().enumerate().map(f).collect());
190+
ca.into_series()
191+
}
192+
182193
pub fn _agg_helper_slice<T, F>(groups: &[[IdxSize; 2]], f: F) -> Series
183194
where
184195
F: Fn([IdxSize; 2]) -> Option<T::Native> + Send + Sync,
@@ -188,6 +199,22 @@ where
188199
ca.into_series()
189200
}
190201

202+
/// Same as `_agg_helper_slice` but also passes the group index to the closure.
203+
fn _agg_helper_slice_with_idx<T, F>(groups: &[[IdxSize; 2]], f: F) -> Series
204+
where
205+
F: Fn(usize, [IdxSize; 2]) -> Option<T::Native> + Send + Sync,
206+
T: PolarsNumericType,
207+
{
208+
let ca: ChunkedArray<T> = POOL.install(|| {
209+
groups
210+
.par_iter()
211+
.enumerate()
212+
.map(|(idx, &g)| f(idx, g))
213+
.collect()
214+
});
215+
ca.into_series()
216+
}
217+
191218
pub fn _agg_helper_slice_no_null<T, F>(groups: &[[IdxSize; 2]], f: F) -> Series
192219
where
193220
F: Fn([IdxSize; 2]) -> T::Native + Send + Sync,
@@ -334,6 +361,64 @@ where
334361
}
335362
}
336363

364+
/// Compute quantile aggregation where each group can have a different quantile value.
365+
unsafe fn agg_varying_quantile_generic<T, K>(
366+
ca: &ChunkedArray<T>,
367+
groups: &GroupsType,
368+
quantiles: &[f64],
369+
method: QuantileMethod,
370+
) -> Series
371+
where
372+
T: PolarsNumericType,
373+
ChunkedArray<T>: QuantileDispatcher<K::Native>,
374+
K: PolarsNumericType,
375+
<K as datatypes::PolarsNumericType>::Native: num_traits::Float + quantile_filter::SealedRolling,
376+
{
377+
match groups {
378+
GroupsType::Idx(groups) => {
379+
let ca = ca.rechunk();
380+
agg_helper_idx_on_all_with_idx::<K, _>(groups, |(group_idx, idx)| {
381+
debug_assert!(idx.len() <= ca.len());
382+
let quantile = quantiles[group_idx];
383+
if !(0.0..=1.0).contains(&quantile) {
384+
return None;
385+
}
386+
match idx.len() {
387+
0 => None,
388+
1 => {
389+
let idx = idx[0] as usize;
390+
ca.get(idx).map(|v| NumCast::from(v).unwrap())
391+
},
392+
_ => {
393+
let take = { ca.take_unchecked(idx) };
394+
take._quantile(quantile, method).unwrap()
395+
},
396+
}
397+
})
398+
},
399+
GroupsType::Slice { groups, .. } => {
400+
_agg_helper_slice_with_idx::<K, _>(groups, |group_idx, [first, len]| {
401+
debug_assert!(first + len <= ca.len() as IdxSize);
402+
let quantile = quantiles[group_idx];
403+
if !(0.0..=1.0).contains(&quantile) {
404+
return None;
405+
}
406+
match len {
407+
0 => None,
408+
1 => ca.get(first as usize).map(|v| NumCast::from(v).unwrap()),
409+
_ => {
410+
let arr_group = _slice_from_offsets(ca, first, len);
411+
arr_group
412+
._quantile(quantile, method)
413+
.unwrap_unchecked()
414+
.map(|flt| NumCast::from(flt).unwrap_unchecked())
415+
},
416+
}
417+
})
418+
},
419+
}
420+
}
421+
337422
unsafe fn agg_median_generic<T, K>(ca: &ChunkedArray<T>, groups: &GroupsType) -> Series
338423
where
339424
T: PolarsNumericType,
@@ -903,6 +988,14 @@ impl Float32Chunked {
903988
) -> Series {
904989
agg_quantile_generic::<_, Float32Type>(self, groups, quantile, method)
905990
}
991+
pub(crate) unsafe fn agg_varying_quantile(
992+
&self,
993+
groups: &GroupsType,
994+
quantiles: &[f64],
995+
method: QuantileMethod,
996+
) -> Series {
997+
agg_varying_quantile_generic::<_, Float32Type>(self, groups, quantiles, method)
998+
}
906999
pub(crate) unsafe fn agg_median(&self, groups: &GroupsType) -> Series {
9071000
agg_median_generic::<_, Float32Type>(self, groups)
9081001
}
@@ -916,6 +1009,14 @@ impl Float64Chunked {
9161009
) -> Series {
9171010
agg_quantile_generic::<_, Float64Type>(self, groups, quantile, method)
9181011
}
1012+
pub(crate) unsafe fn agg_varying_quantile(
1013+
&self,
1014+
groups: &GroupsType,
1015+
quantiles: &[f64],
1016+
method: QuantileMethod,
1017+
) -> Series {
1018+
agg_varying_quantile_generic::<_, Float64Type>(self, groups, quantiles, method)
1019+
}
9191020
pub(crate) unsafe fn agg_median(&self, groups: &GroupsType) -> Series {
9201021
agg_median_generic::<_, Float64Type>(self, groups)
9211022
}
@@ -1108,6 +1209,14 @@ where
11081209
) -> Series {
11091210
agg_quantile_generic::<_, Float64Type>(self, groups, quantile, method)
11101211
}
1212+
pub(crate) unsafe fn agg_varying_quantile(
1213+
&self,
1214+
groups: &GroupsType,
1215+
quantiles: &[f64],
1216+
method: QuantileMethod,
1217+
) -> Series {
1218+
agg_varying_quantile_generic::<_, Float64Type>(self, groups, quantiles, method)
1219+
}
11111220
pub(crate) unsafe fn agg_median(&self, groups: &GroupsType) -> Series {
11121221
agg_median_generic::<_, Float64Type>(self, groups)
11131222
}

crates/polars-expr/src/expressions/aggregation.rs

Lines changed: 61 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -483,6 +483,22 @@ impl AggQuantileExpr {
483483
);
484484
quantile.get(0).unwrap().try_extract()
485485
}
486+
487+
/// Extracts the single quantile shared by all groups from an already evaluated
/// quantile expression.
///
/// # Errors
///
/// * `ComputeError` when the expression produced more than one value.
/// * The lookup error from `get(0)` when the expression produced no value at all
///   (previously this case panicked on `unwrap`).
/// * Whatever `try_extract` reports when the value cannot be read as `f64`.
fn get_quantile_from_scalar(&self, quantile_ac: &AggregationContext) -> PolarsResult<f64> {
    let quantile_col = quantile_ac.get_values();
    polars_ensure!(quantile_col.len() <= 1, ComputeError:
        "polars only supports computing a single quantile; \
        make sure the 'quantile' expression input produces a single quantile"
    );
    // The length check above still admits an empty column, so propagate the
    // out-of-bounds error instead of unwrapping it.
    quantile_col.get(0)?.try_extract()
}
495+
496+
fn get_quantiles_per_group(&self, quantile_ac: &AggregationContext) -> PolarsResult<Vec<f64>> {
497+
let quantile_col = quantile_ac.get_values();
498+
let quantile_col = quantile_col.cast(&DataType::Float64)?;
499+
let quantile_ca = quantile_col.f64()?;
500+
Ok(quantile_ca.iter().map(|v| v.unwrap_or(f64::NAN)).collect())
501+
}
486502
}
487503

488504
impl PhysicalExpr for AggQuantileExpr {
@@ -505,6 +521,7 @@ impl PhysicalExpr for AggQuantileExpr {
505521
state: &ExecutionState,
506522
) -> PolarsResult<AggregationContext<'a>> {
507523
let mut ac = self.input.evaluate_on_groups(df, groups, state)?;
524+
let quantile_ac = self.quantile.evaluate_on_groups(df, groups, state)?;
508525

509526
// AggregatedScalar has no defined group structure. We fix it up here, so that we can
510527
// reliably call `agg_quantile` functions with the groups.
@@ -513,27 +530,52 @@ impl PhysicalExpr for AggQuantileExpr {
513530
// don't change names by aggregations as is done in polars-core
514531
let keep_name = ac.get_values().name().clone();
515532

516-
let quantile = self.get_quantile(df, state)?;
533+
// Check if quantile is a literal/scalar (same for all groups) or varies per group
534+
let is_uniform_quantile = quantile_ac.is_literal()
535+
|| matches!(quantile_ac.agg_state(), AggState::LiteralScalar(_));
517536

518-
if let AggState::LiteralScalar(c) = &mut ac.state {
519-
*c = c
520-
.quantile_reduce(quantile, self.method)?
521-
.into_column(keep_name);
522-
return Ok(ac);
523-
}
537+
if is_uniform_quantile {
538+
// Fast path: single quantile value for all groups
539+
let quantile = self.get_quantile_from_scalar(&quantile_ac)?;
524540

525-
// SAFETY:
526-
// groups are in bounds
527-
let mut agg = unsafe {
528-
ac.flat_naive()
529-
.into_owned()
530-
.agg_quantile(ac.groups(), quantile, self.method)
531-
};
532-
agg.rename(keep_name);
533-
Ok(AggregationContext::from_agg_state(
534-
AggregatedScalar(agg),
535-
Cow::Borrowed(groups),
536-
))
541+
if let AggState::LiteralScalar(c) = &mut ac.state {
542+
*c = c
543+
.quantile_reduce(quantile, self.method)?
544+
.into_column(keep_name);
545+
return Ok(ac);
546+
}
547+
548+
// SAFETY:
549+
// groups are in bounds
550+
let mut agg = unsafe {
551+
ac.flat_naive()
552+
.into_owned()
553+
.agg_quantile(ac.groups(), quantile, self.method)
554+
};
555+
agg.rename(keep_name);
556+
Ok(AggregationContext::from_agg_state(
557+
AggregatedScalar(agg),
558+
Cow::Borrowed(groups),
559+
))
560+
} else {
561+
// Different quantile value per group
562+
let quantiles = self.get_quantiles_per_group(&quantile_ac)?;
563+
564+
// SAFETY:
565+
// groups are in bounds
566+
let mut agg = unsafe {
567+
ac.flat_naive().into_owned().agg_varying_quantile(
568+
ac.groups(),
569+
&quantiles,
570+
self.method,
571+
)
572+
};
573+
agg.rename(keep_name);
574+
Ok(AggregationContext::from_agg_state(
575+
AggregatedScalar(agg),
576+
Cow::Borrowed(groups),
577+
))
578+
}
537579
}
538580

539581
fn to_field(&self, input_schema: &Schema) -> PolarsResult<Field> {

0 commit comments

Comments
 (0)