Skip to content

Commit 9ab292c

Browse files
wtn and claude committed
fix: Support varying quantile values per group in group_by aggregation
Co-authored-by: Claude <noreply@anthropic.com>
1 parent 04e8c78 commit 9ab292c

File tree

5 files changed

+293
-19
lines changed

5 files changed

+293
-19
lines changed

crates/polars-core/src/frame/column/mod.rs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -831,6 +831,25 @@ impl Column {
831831
.into()
832832
}
833833

834+
/// Group-wise quantile aggregation that takes a slice of quantiles instead of a single
/// value; the work is delegated to the materialized series' `agg_varying_quantile`.
///
/// # Safety
///
/// Does no bounds checks, groups must be correct.
#[cfg(feature = "algorithm_group_by")]
pub unsafe fn agg_varying_quantile(
    &self,
    groups: &GroupsType,
    quantiles: &[f64],
    method: QuantileMethod,
) -> Self {
    // @scalar-opt
    let series = self.as_materialized_series();
    // SAFETY: this function's own contract is forwarded unchanged to the series method.
    let aggregated = unsafe { series.agg_varying_quantile(groups, quantiles, method) };
    aggregated.into()
}
852+
834853
/// # Safety
835854
///
836855
/// Does no bounds checks, groups must be correct.

crates/polars-core/src/frame/group_by/aggregations/dispatch.rs

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -369,6 +369,73 @@ impl Series {
369369
}
370370
}
371371

372+
#[doc(hidden)]
373+
pub unsafe fn agg_varying_quantile(
374+
&self,
375+
groups: &GroupsType,
376+
quantiles: &[f64],
377+
method: QuantileMethod,
378+
) -> Series {
379+
// Prevent a rechunk for every individual group.
380+
let s = if groups.len() > 1 {
381+
self.rechunk()
382+
} else {
383+
self.clone()
384+
};
385+
386+
use DataType::*;
387+
match s.dtype() {
388+
Float32 => s
389+
.f32()
390+
.unwrap()
391+
.agg_varying_quantile(groups, quantiles, method),
392+
Float64 => s
393+
.f64()
394+
.unwrap()
395+
.agg_varying_quantile(groups, quantiles, method),
396+
#[cfg(feature = "dtype-decimal")]
397+
Decimal(_, _) => s
398+
.cast(&DataType::Float64)
399+
.unwrap()
400+
.agg_varying_quantile(groups, quantiles, method),
401+
#[cfg(feature = "dtype-datetime")]
402+
Datetime(tu, tz) => self
403+
.to_physical_repr()
404+
.agg_varying_quantile(groups, quantiles, method)
405+
.cast(&Int64)
406+
.unwrap()
407+
.into_datetime(*tu, tz.clone()),
408+
#[cfg(feature = "dtype-duration")]
409+
Duration(tu) => self
410+
.to_physical_repr()
411+
.agg_varying_quantile(groups, quantiles, method)
412+
.cast(&Int64)
413+
.unwrap()
414+
.into_duration(*tu),
415+
#[cfg(feature = "dtype-time")]
416+
Time => self
417+
.to_physical_repr()
418+
.agg_varying_quantile(groups, quantiles, method)
419+
.cast(&Int64)
420+
.unwrap()
421+
.into_time(),
422+
#[cfg(feature = "dtype-date")]
423+
Date => (self
424+
.to_physical_repr()
425+
.agg_varying_quantile(groups, quantiles, method)
426+
.cast(&Float64)
427+
.unwrap()
428+
* (US_IN_DAY as f64))
429+
.cast(&DataType::Int64)
430+
.unwrap()
431+
.into_datetime(TimeUnit::Microseconds, None),
432+
dt if dt.is_primitive_numeric() => {
433+
apply_method_physical_integer!(s, agg_varying_quantile, groups, quantiles, method)
434+
},
435+
_ => Series::full_null(PlSmallStr::EMPTY, groups.len(), s.dtype()),
436+
}
437+
}
438+
372439
#[doc(hidden)]
373440
pub unsafe fn agg_last(&self, groups: &GroupsType) -> Series {
374441
// Prevent a rechunk for every individual group.

crates/polars-core/src/frame/group_by/aggregations/mod.rs

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,23 @@ where
193193
ca.into_series()
194194
}
195195

196+
/// Same as `agg_helper_idx_on_all` but also passes the group index to the closure.
197+
fn agg_helper_idx_on_all_with_idx<T, F>(groups: &GroupsIdx, f: F) -> Series
198+
where
199+
F: Fn(usize, &IdxVec) -> Option<T::Native> + Send + Sync,
200+
T: PolarsNumericType,
201+
{
202+
let ca: ChunkedArray<T> = POOL.install(|| {
203+
groups
204+
.all()
205+
.into_par_iter()
206+
.enumerate()
207+
.map(|(idx, g)| f(idx, g))
208+
.collect()
209+
});
210+
ca.into_series()
211+
}
212+
196213
pub fn _agg_helper_slice<T, F>(groups: &[[IdxSize; 2]], f: F) -> Series
197214
where
198215
F: Fn([IdxSize; 2]) -> Option<T::Native> + Send + Sync,
@@ -202,6 +219,22 @@ where
202219
ca.into_series()
203220
}
204221

222+
/// Same as `_agg_helper_slice` but also passes the group index to the closure.
223+
fn _agg_helper_slice_with_idx<T, F>(groups: &[[IdxSize; 2]], f: F) -> Series
224+
where
225+
F: Fn(usize, [IdxSize; 2]) -> Option<T::Native> + Send + Sync,
226+
T: PolarsNumericType,
227+
{
228+
let ca: ChunkedArray<T> = POOL.install(|| {
229+
groups
230+
.par_iter()
231+
.enumerate()
232+
.map(|(idx, &g)| f(idx, g))
233+
.collect()
234+
});
235+
ca.into_series()
236+
}
237+
205238
pub fn _agg_helper_slice_no_null<T, F>(groups: &[[IdxSize; 2]], f: F) -> Series
206239
where
207240
F: Fn([IdxSize; 2]) -> T::Native + Send + Sync,
@@ -344,6 +377,58 @@ where
344377
}
345378
}
346379

380+
/// Compute quantile aggregation where each group can have a different quantile value.
381+
unsafe fn agg_varying_quantile_generic<T, K>(
382+
ca: &ChunkedArray<T>,
383+
groups: &GroupsType,
384+
quantiles: &[f64],
385+
method: QuantileMethod,
386+
) -> Series
387+
where
388+
T: PolarsNumericType,
389+
ChunkedArray<T>: QuantileDispatcher<K::Native>,
390+
K: PolarsNumericType,
391+
<K as datatypes::PolarsNumericType>::Native: num_traits::Float + quantile_filter::SealedRolling,
392+
{
393+
match groups {
394+
GroupsType::Idx(groups) => {
395+
let ca = ca.rechunk();
396+
agg_helper_idx_on_all_with_idx::<K, _>(groups, |group_idx, idx| {
397+
debug_assert!(idx.len() <= ca.len());
398+
if idx.is_empty() {
399+
return None;
400+
}
401+
let quantile = quantiles[group_idx];
402+
if !(0.0..=1.0).contains(&quantile) {
403+
return None;
404+
}
405+
let take = { ca.take_unchecked(idx) };
406+
take._quantile(quantile, method).unwrap_unchecked()
407+
})
408+
},
409+
GroupsType::Slice { groups, .. } => {
410+
_agg_helper_slice_with_idx::<K, _>(groups, |group_idx, [first, len]| {
411+
debug_assert!(first + len <= ca.len() as IdxSize);
412+
let quantile = quantiles[group_idx];
413+
if !(0.0..=1.0).contains(&quantile) {
414+
return None;
415+
}
416+
match len {
417+
0 => None,
418+
1 => ca.get(first as usize).map(|v| NumCast::from(v).unwrap()),
419+
_ => {
420+
let arr_group = _slice_from_offsets(ca, first, len);
421+
arr_group
422+
._quantile(quantile, method)
423+
.unwrap_unchecked()
424+
.map(|flt| NumCast::from(flt).unwrap_unchecked())
425+
},
426+
}
427+
})
428+
},
429+
}
430+
}
431+
347432
unsafe fn agg_median_generic<T, K>(ca: &ChunkedArray<T>, groups: &GroupsType) -> Series
348433
where
349434
T: PolarsNumericType,
@@ -895,6 +980,14 @@ impl Float32Chunked {
895980
) -> Series {
896981
agg_quantile_generic::<_, Float32Type>(self, groups, quantile, method)
897982
}
983+
/// Forwards to `agg_varying_quantile_generic` with `Float32Type` as the second type argument.
pub(crate) unsafe fn agg_varying_quantile(
984+
&self,
985+
groups: &GroupsType,
986+
quantiles: &[f64],
987+
method: QuantileMethod,
988+
) -> Series {
989+
agg_varying_quantile_generic::<_, Float32Type>(self, groups, quantiles, method)
990+
}
898991
/// Forwards to `agg_median_generic` with `Float32Type` as the second type argument.
pub(crate) unsafe fn agg_median(&self, groups: &GroupsType) -> Series {
899992
agg_median_generic::<_, Float32Type>(self, groups)
900993
}
@@ -908,6 +1001,14 @@ impl Float64Chunked {
9081001
) -> Series {
9091002
agg_quantile_generic::<_, Float64Type>(self, groups, quantile, method)
9101003
}
1004+
/// Forwards to `agg_varying_quantile_generic` with `Float64Type` as the second type argument.
pub(crate) unsafe fn agg_varying_quantile(
1005+
&self,
1006+
groups: &GroupsType,
1007+
quantiles: &[f64],
1008+
method: QuantileMethod,
1009+
) -> Series {
1010+
agg_varying_quantile_generic::<_, Float64Type>(self, groups, quantiles, method)
1011+
}
9111012
/// Forwards to `agg_median_generic` with `Float64Type` as the second type argument.
pub(crate) unsafe fn agg_median(&self, groups: &GroupsType) -> Series {
9121013
agg_median_generic::<_, Float64Type>(self, groups)
9131014
}
@@ -1097,6 +1198,14 @@ where
10971198
) -> Series {
10981199
agg_quantile_generic::<_, Float64Type>(self, groups, quantile, method)
10991200
}
1201+
/// Forwards to `agg_varying_quantile_generic` with `Float64Type` as the second type argument.
pub(crate) unsafe fn agg_varying_quantile(
1202+
&self,
1203+
groups: &GroupsType,
1204+
quantiles: &[f64],
1205+
method: QuantileMethod,
1206+
) -> Series {
1207+
agg_varying_quantile_generic::<_, Float64Type>(self, groups, quantiles, method)
1208+
}
11001209
/// Forwards to `agg_median_generic` with `Float64Type` as the second type argument.
pub(crate) unsafe fn agg_median(&self, groups: &GroupsType) -> Series {
11011210
agg_median_generic::<_, Float64Type>(self, groups)
11021211
}

crates/polars-expr/src/expressions/aggregation.rs

Lines changed: 59 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -483,6 +483,22 @@ impl AggQuantileExpr {
483483
);
484484
quantile.get(0).unwrap().try_extract()
485485
}
486+
487+
/// Extracts the single quantile shared by all groups from an evaluated quantile expression.
///
/// # Errors
///
/// - `ComputeError` when the expression produced more than one value.
/// - The lookup error from `get(0)` when the expression produced no value at all.
/// - Whatever `try_extract` reports when the value cannot be extracted as `f64`.
fn get_quantile_from_scalar(&self, quantile_ac: &AggregationContext) -> PolarsResult<f64> {
    let quantile_col = quantile_ac.get_values();
    polars_ensure!(quantile_col.len() <= 1, ComputeError:
        "polars only supports computing a single quantile; \
        make sure the 'quantile' expression input produces a single quantile"
    );
    // `len() <= 1` still admits an empty column, so the lookup is propagated with `?`
    // instead of being unwrapped (which would panic on empty input).
    quantile_col.get(0)?.try_extract()
}
495+
496+
/// Reads the evaluated quantile expression as a vector of `f64`, one entry per value.
///
/// The values are cast to `Float64`; a null entry is represented as `f64::NAN`.
///
/// # Errors
///
/// Propagates a failing cast to `Float64` or a failing `f64()` downcast.
fn get_quantiles_per_group(&self, quantile_ac: &AggregationContext) -> PolarsResult<Vec<f64>> {
    let as_float = quantile_ac.get_values().cast(&DataType::Float64)?;
    let levels = as_float.f64()?;

    let mut out = Vec::with_capacity(levels.len());
    for level in levels.iter() {
        out.push(level.unwrap_or(f64::NAN));
    }
    Ok(out)
}
486502
}
487503

488504
impl PhysicalExpr for AggQuantileExpr {
@@ -505,6 +521,7 @@ impl PhysicalExpr for AggQuantileExpr {
505521
state: &ExecutionState,
506522
) -> PolarsResult<AggregationContext<'a>> {
507523
let mut ac = self.input.evaluate_on_groups(df, groups, state)?;
524+
let quantile_ac = self.quantile.evaluate_on_groups(df, groups, state)?;
508525

509526
// AggregatedScalar has no defined group structure. We fix it up here, so that we can
510527
// reliably call `agg_quantile` functions with the groups.
@@ -513,27 +530,50 @@ impl PhysicalExpr for AggQuantileExpr {
513530
// don't change names by aggregations as is done in polars-core
514531
let keep_name = ac.get_values().name().clone();
515532

516-
let quantile = self.get_quantile(df, state)?;
533+
// Check if quantile is a literal/scalar (same for all groups) or varies per group
534+
let is_uniform_quantile = quantile_ac.is_literal()
535+
|| matches!(quantile_ac.agg_state(), AggState::LiteralScalar(_));
517536

518-
if let AggState::LiteralScalar(c) = &mut ac.state {
519-
*c = c
520-
.quantile_reduce(quantile, self.method)?
521-
.into_column(keep_name);
522-
return Ok(ac);
523-
}
537+
if is_uniform_quantile {
538+
// Fast path: single quantile value for all groups
539+
let quantile = self.get_quantile_from_scalar(&quantile_ac)?;
524540

525-
// SAFETY:
526-
// groups are in bounds
527-
let mut agg = unsafe {
528-
ac.flat_naive()
529-
.into_owned()
530-
.agg_quantile(ac.groups(), quantile, self.method)
531-
};
532-
agg.rename(keep_name);
533-
Ok(AggregationContext::from_agg_state(
534-
AggregatedScalar(agg),
535-
Cow::Borrowed(groups),
536-
))
541+
if let AggState::LiteralScalar(c) = &mut ac.state {
542+
*c = c
543+
.quantile_reduce(quantile, self.method)?
544+
.into_column(keep_name);
545+
return Ok(ac);
546+
}
547+
548+
// SAFETY:
549+
// groups are in bounds
550+
let mut agg = unsafe {
551+
ac.flat_naive()
552+
.into_owned()
553+
.agg_quantile(ac.groups(), quantile, self.method)
554+
};
555+
agg.rename(keep_name);
556+
Ok(AggregationContext::from_agg_state(
557+
AggregatedScalar(agg),
558+
Cow::Borrowed(groups),
559+
))
560+
} else {
561+
// Different quantile value per group
562+
let quantiles = self.get_quantiles_per_group(&quantile_ac)?;
563+
564+
// SAFETY:
565+
// groups are in bounds
566+
let mut agg = unsafe {
567+
ac.flat_naive()
568+
.into_owned()
569+
.agg_varying_quantile(ac.groups(), &quantiles, self.method)
570+
};
571+
agg.rename(keep_name);
572+
Ok(AggregationContext::from_agg_state(
573+
AggregatedScalar(agg),
574+
Cow::Borrowed(groups),
575+
))
576+
}
537577
}
538578

539579
fn to_field(&self, input_schema: &Schema) -> PolarsResult<Field> {

py-polars/tests/unit/operations/aggregation/test_aggregations.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,45 @@ def test_quantile_expr_input() -> None:
2727
)
2828

2929

30+
def test_quantile_varying_by_group_20951() -> None:
    """The quantile level passed to `quantile` may differ from group to group."""
    # Case 1: the integer group key doubles as the quantile level (0 and 1).
    frame = pl.DataFrame(
        {
            "value": [1, 2, 1, 2],
            "quantile": [0, 0, 1, 1],
        }
    )
    level = pl.col.quantile.first()
    aggregated = frame.group_by(pl.col.quantile).agg(pl.col.value.quantile(level))
    result = aggregated.sort("quantile")

    expected = pl.DataFrame(
        {
            "quantile": [0, 1],
            "value": [1.0, 2.0],
        }
    )
    assert_frame_equal(result, expected)

    # Case 2: float levels stored in a separate column, grouped by a string key.
    frame = pl.DataFrame(
        {
            "group": ["a", "a", "a", "b", "b", "b"],
            "value": [1, 2, 3, 10, 20, 30],
            "q": [0.0, 0.0, 0.0, 0.5, 0.5, 0.5],
        }
    )
    level = pl.col.q.first()
    aggregated = frame.group_by("group").agg(pl.col.value.quantile(level))
    result = aggregated.sort("group")

    expected = pl.DataFrame(
        {
            "group": ["a", "b"],
            "value": [1.0, 20.0],
        }
    )
    assert_frame_equal(result, expected)
67+
68+
3069
def test_boolean_aggs() -> None:
3170
df = pl.DataFrame({"bool": [True, False, None, True]})
3271

0 commit comments

Comments
 (0)