@@ -483,6 +483,57 @@ impl AggQuantileExpr {
483483 ) ;
484484 quantile. get ( 0 ) . unwrap ( ) . try_extract ( )
485485 }
486+
487+ fn get_quantile_from_scalar ( & self , quantile_ac : & AggregationContext ) -> PolarsResult < f64 > {
488+ let quantile_col = quantile_ac. get_values ( ) ;
489+ polars_ensure ! ( quantile_col. len( ) <= 1 , ComputeError :
490+ "polars only supports computing a single quantile; \
491+ make sure the 'quantile' expression input produces a single quantile"
492+ ) ;
493+ quantile_col. get ( 0 ) . unwrap ( ) . try_extract ( )
494+ }
495+
496+ /// Compute quantile per group when quantile values vary by group.
497+ fn agg_quantile_per_group (
498+ & self ,
499+ ac : & mut AggregationContext ,
500+ quantile_ac : & mut AggregationContext ,
501+ keep_name : PlSmallStr ,
502+ ) -> PolarsResult < Column > {
503+ let ac_list = ac. aggregated_as_list ( ) ;
504+ let quantile_values = quantile_ac. flat_naive ( ) ;
505+
506+ // Get quantile values as f64
507+ let quantile_values = quantile_values. cast ( & DataType :: Float64 ) ?;
508+ let quantile_arr = quantile_values. f64 ( ) ?;
509+
510+ let method = self . method ;
511+ let result = ac_list
512+ . amortized_iter ( )
513+ . zip ( quantile_arr. iter ( ) )
514+ . map ( |( opt_s, opt_q) | match ( opt_s, opt_q) {
515+ ( Some ( s) , Some ( q) ) => {
516+ let s = s. as_ref ( ) ;
517+ if !( 0.0 ..=1.0 ) . contains ( & q) {
518+ Ok ( None )
519+ } else {
520+ let scalar = s. quantile_reduce ( q, method) ?;
521+ // Extract f64 from the scalar, returning None if null
522+ if scalar. is_null ( ) {
523+ Ok ( None )
524+ } else {
525+ scalar. value ( ) . try_extract :: < f64 > ( ) . map ( Some )
526+ }
527+ }
528+ } ,
529+ _ => Ok ( None ) ,
530+ } )
531+ . collect :: < PolarsResult < Float64Chunked > > ( ) ?
532+ . with_name ( keep_name)
533+ . into_column ( ) ;
534+
535+ Ok ( result)
536+ }
486537}
487538
488539impl PhysicalExpr for AggQuantileExpr {
@@ -505,6 +556,7 @@ impl PhysicalExpr for AggQuantileExpr {
505556 state : & ExecutionState ,
506557 ) -> PolarsResult < AggregationContext < ' a > > {
507558 let mut ac = self . input . evaluate_on_groups ( df, groups, state) ?;
559+ let mut quantile_ac = self . quantile . evaluate_on_groups ( df, groups, state) ?;
508560
509561 // AggregatedScalar has no defined group structure. We fix it up here, so that we can
510562 // reliably call `agg_quantile` functions with the groups.
@@ -513,27 +565,41 @@ impl PhysicalExpr for AggQuantileExpr {
513565 // don't change names by aggregations as is done in polars-core
514566 let keep_name = ac. get_values ( ) . name ( ) . clone ( ) ;
515567
516- let quantile = self . get_quantile ( df, state) ?;
568+ // Check if quantile is a literal/scalar (same for all groups) or varies per group
569+ let is_uniform_quantile = quantile_ac. is_literal ( )
570+ || matches ! ( quantile_ac. agg_state( ) , AggState :: LiteralScalar ( _) ) ;
517571
518- if let AggState :: LiteralScalar ( c) = & mut ac. state {
519- * c = c
520- . quantile_reduce ( quantile, self . method ) ?
521- . into_column ( keep_name) ;
522- return Ok ( ac) ;
523- }
572+ if is_uniform_quantile {
573+ // Fast path: single quantile value for all groups
574+ let quantile = self . get_quantile_from_scalar ( & quantile_ac) ?;
524575
525- // SAFETY:
526- // groups are in bounds
527- let mut agg = unsafe {
528- ac. flat_naive ( )
529- . into_owned ( )
530- . agg_quantile ( ac. groups ( ) , quantile, self . method )
531- } ;
532- agg. rename ( keep_name) ;
533- Ok ( AggregationContext :: from_agg_state (
534- AggregatedScalar ( agg) ,
535- Cow :: Borrowed ( groups) ,
536- ) )
576+ if let AggState :: LiteralScalar ( c) = & mut ac. state {
577+ * c = c
578+ . quantile_reduce ( quantile, self . method ) ?
579+ . into_column ( keep_name) ;
580+ return Ok ( ac) ;
581+ }
582+
583+ // SAFETY:
584+ // groups are in bounds
585+ let mut agg = unsafe {
586+ ac. flat_naive ( )
587+ . into_owned ( )
588+ . agg_quantile ( ac. groups ( ) , quantile, self . method )
589+ } ;
590+ agg. rename ( keep_name) ;
591+ Ok ( AggregationContext :: from_agg_state (
592+ AggregatedScalar ( agg) ,
593+ Cow :: Borrowed ( groups) ,
594+ ) )
595+ } else {
596+ // Slow path: different quantile value per group
597+ let agg = self . agg_quantile_per_group ( & mut ac, & mut quantile_ac, keep_name) ?;
598+ Ok ( AggregationContext :: from_agg_state (
599+ AggregatedScalar ( agg) ,
600+ Cow :: Borrowed ( groups) ,
601+ ) )
602+ }
537603 }
538604
539605 fn to_field ( & self , input_schema : & Schema ) -> PolarsResult < Field > {
0 commit comments