Skip to content

Commit 320af19

Browse files
wtn and claude committed
fix: Support varying quantile values per group in group_by aggregation
Co-authored-by: Claude <noreply@anthropic.com>
1 parent 04e8c78 commit 320af19

File tree

2 files changed

+124
-19
lines changed

2 files changed

+124
-19
lines changed

crates/polars-expr/src/expressions/aggregation.rs

Lines changed: 85 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -483,6 +483,57 @@ impl AggQuantileExpr {
483483
);
484484
quantile.get(0).unwrap().try_extract()
485485
}
486+
487+
/// Extract the single quantile value from an already evaluated quantile context.
///
/// Reads the values held by `quantile_ac` and converts the first (and only)
/// entry to `f64`.
///
/// # Errors
///
/// Returns a `ComputeError` when the quantile input does not hold exactly one
/// value, and propagates any error from fetching or extracting that value.
fn get_quantile_from_scalar(&self, quantile_ac: &AggregationContext) -> PolarsResult<f64> {
    let quantile_col = quantile_ac.get_values();
    // An empty input must be rejected here as well: it used to pass a `<= 1`
    // check and then panic on the `unwrap` of the out-of-bounds lookup.
    polars_ensure!(quantile_col.len() == 1, ComputeError:
        "polars only supports computing a single quantile; \
        make sure the 'quantile' expression input produces a single quantile"
    );
    quantile_col.get(0)?.try_extract()
}
495+
496+
/// Compute quantile per group when quantile values vary by group.
497+
fn agg_quantile_per_group(
498+
&self,
499+
ac: &mut AggregationContext,
500+
quantile_ac: &mut AggregationContext,
501+
keep_name: PlSmallStr,
502+
) -> PolarsResult<Column> {
503+
let ac_list = ac.aggregated_as_list();
504+
let quantile_values = quantile_ac.flat_naive();
505+
506+
// Get quantile values as f64
507+
let quantile_values = quantile_values.cast(&DataType::Float64)?;
508+
let quantile_arr = quantile_values.f64()?;
509+
510+
let method = self.method;
511+
let result = ac_list
512+
.amortized_iter()
513+
.zip(quantile_arr.iter())
514+
.map(|(opt_s, opt_q)| match (opt_s, opt_q) {
515+
(Some(s), Some(q)) => {
516+
let s = s.as_ref();
517+
if !(0.0..=1.0).contains(&q) {
518+
Ok(None)
519+
} else {
520+
let scalar = s.quantile_reduce(q, method)?;
521+
// Extract f64 from the scalar, returning None if null
522+
if scalar.is_null() {
523+
Ok(None)
524+
} else {
525+
scalar.value().try_extract::<f64>().map(Some)
526+
}
527+
}
528+
},
529+
_ => Ok(None),
530+
})
531+
.collect::<PolarsResult<Float64Chunked>>()?
532+
.with_name(keep_name)
533+
.into_column();
534+
535+
Ok(result)
536+
}
486537
}
487538

488539
impl PhysicalExpr for AggQuantileExpr {
@@ -505,6 +556,7 @@ impl PhysicalExpr for AggQuantileExpr {
505556
state: &ExecutionState,
506557
) -> PolarsResult<AggregationContext<'a>> {
507558
let mut ac = self.input.evaluate_on_groups(df, groups, state)?;
559+
let mut quantile_ac = self.quantile.evaluate_on_groups(df, groups, state)?;
508560

509561
// AggregatedScalar has no defined group structure. We fix it up here, so that we can
510562
// reliably call `agg_quantile` functions with the groups.
@@ -513,27 +565,41 @@ impl PhysicalExpr for AggQuantileExpr {
513565
// don't change names by aggregations as is done in polars-core
514566
let keep_name = ac.get_values().name().clone();
515567

516-
let quantile = self.get_quantile(df, state)?;
568+
// Check if quantile is a literal/scalar (same for all groups) or varies per group
569+
let is_uniform_quantile = quantile_ac.is_literal()
570+
|| matches!(quantile_ac.agg_state(), AggState::LiteralScalar(_));
517571

518-
if let AggState::LiteralScalar(c) = &mut ac.state {
519-
*c = c
520-
.quantile_reduce(quantile, self.method)?
521-
.into_column(keep_name);
522-
return Ok(ac);
523-
}
572+
if is_uniform_quantile {
573+
// Fast path: single quantile value for all groups
574+
let quantile = self.get_quantile_from_scalar(&quantile_ac)?;
524575

525-
// SAFETY:
526-
// groups are in bounds
527-
let mut agg = unsafe {
528-
ac.flat_naive()
529-
.into_owned()
530-
.agg_quantile(ac.groups(), quantile, self.method)
531-
};
532-
agg.rename(keep_name);
533-
Ok(AggregationContext::from_agg_state(
534-
AggregatedScalar(agg),
535-
Cow::Borrowed(groups),
536-
))
576+
if let AggState::LiteralScalar(c) = &mut ac.state {
577+
*c = c
578+
.quantile_reduce(quantile, self.method)?
579+
.into_column(keep_name);
580+
return Ok(ac);
581+
}
582+
583+
// SAFETY:
584+
// groups are in bounds
585+
let mut agg = unsafe {
586+
ac.flat_naive()
587+
.into_owned()
588+
.agg_quantile(ac.groups(), quantile, self.method)
589+
};
590+
agg.rename(keep_name);
591+
Ok(AggregationContext::from_agg_state(
592+
AggregatedScalar(agg),
593+
Cow::Borrowed(groups),
594+
))
595+
} else {
596+
// Slow path: different quantile value per group
597+
let agg = self.agg_quantile_per_group(&mut ac, &mut quantile_ac, keep_name)?;
598+
Ok(AggregationContext::from_agg_state(
599+
AggregatedScalar(agg),
600+
Cow::Borrowed(groups),
601+
))
602+
}
537603
}
538604

539605
fn to_field(&self, input_schema: &Schema) -> PolarsResult<Field> {

py-polars/tests/unit/operations/aggregation/test_aggregations.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,45 @@ def test_quantile_expr_input() -> None:
2727
)
2828

2929

30+
def test_quantile_varying_by_group_20951() -> None:
    """Check that a per-group quantile expression is applied group by group."""
    # Case 1: integer quantiles that also serve as the grouping key.
    int_q_frame = pl.DataFrame(
        {
            "value": [1, 2, 1, 2],
            "quantile": [0, 0, 1, 1],
        }
    )
    int_q_agg = pl.col.value.quantile(pl.col.quantile.first())
    int_q_result = int_q_frame.group_by(pl.col.quantile).agg(int_q_agg).sort("quantile")
    int_q_expected = pl.DataFrame(
        {
            "quantile": [0, 1],
            "value": [1.0, 2.0],
        }
    )
    assert_frame_equal(int_q_result, int_q_expected)

    # Case 2: float quantiles stored in a column separate from the group key.
    float_q_frame = pl.DataFrame(
        {
            "group": ["a", "a", "a", "b", "b", "b"],
            "value": [1, 2, 3, 10, 20, 30],
            "q": [0.0, 0.0, 0.0, 0.5, 0.5, 0.5],
        }
    )
    float_q_agg = pl.col.value.quantile(pl.col.q.first())
    float_q_result = float_q_frame.group_by("group").agg(float_q_agg).sort("group")
    float_q_expected = pl.DataFrame(
        {
            "group": ["a", "b"],
            "value": [1.0, 20.0],
        }
    )
    assert_frame_equal(float_q_result, float_q_expected)
67+
68+
3069
def test_boolean_aggs() -> None:
3170
df = pl.DataFrame({"bool": [True, False, None, True]})
3271

0 commit comments

Comments
 (0)