Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions crates/polars-core/src/frame/column/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -844,6 +844,25 @@ impl Column {
.into()
}

/// Aggregates each group to a quantile by delegating to the materialized series,
/// passing `quantiles` through unchanged.
///
/// # Safety
///
/// Does no bounds checks, groups must be correct.
#[cfg(feature = "algorithm_group_by")]
pub unsafe fn agg_varying_quantile(
    &self,
    groups: &GroupsType,
    quantiles: &[f64],
    method: QuantileMethod,
) -> Self {
    // @scalar-opt

    let series = self.as_materialized_series();
    // SAFETY: the caller guarantees that `groups` is correct for this column.
    let aggregated = unsafe { series.agg_varying_quantile(groups, quantiles, method) };
    aggregated.into()
}

/// # Safety
///
/// Does no bounds checks, groups must be correct.
Expand Down
67 changes: 67 additions & 0 deletions crates/polars-core/src/frame/group_by/aggregations/dispatch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -504,6 +504,73 @@ impl Series {
}
}

/// Aggregates each group to a quantile, dispatching on the dtype of the series.
///
/// `quantiles` is forwarded to the typed implementations together with `groups` and `method`.
///
/// Dispatch rules:
/// - `Float32` / `Float64` are aggregated directly.
/// - `Decimal` is cast to `Float64` first.
/// - `Datetime`, `Duration` and `Time` are aggregated on their physical representation and
///   cast back through `Int64` into the original logical type.
/// - `Date` is aggregated on its physical representation, scaled by `US_IN_DAY` and returned
///   as a microsecond `Datetime` without time zone.
/// - Remaining primitive numeric dtypes go through `apply_method_physical_integer!`.
/// - Any other dtype yields an all-null series with one entry per group.
///
/// # Safety
///
/// `groups` is passed on to unsafe aggregation kernels; the caller must ensure it is
/// correct for this series.
#[doc(hidden)]
pub unsafe fn agg_varying_quantile(
    &self,
    groups: &GroupsType,
    quantiles: &[f64],
    method: QuantileMethod,
) -> Series {
    // Prevent a rechunk for every individual group.
    let s = if groups.len() > 1 {
        self.rechunk()
    } else {
        self.clone()
    };

    use DataType::*;
    match s.dtype() {
        Float32 => s
            .f32()
            .unwrap()
            .agg_varying_quantile(groups, quantiles, method),
        Float64 => s
            .f64()
            .unwrap()
            .agg_varying_quantile(groups, quantiles, method),
        #[cfg(feature = "dtype-decimal")]
        Decimal(_, _) => s
            .cast(&DataType::Float64)
            .unwrap()
            .agg_varying_quantile(groups, quantiles, method),
        // The temporal arms below dispatch on `s` (not `self`) so the rechunked copy made
        // above is the one that gets aggregated, like in every other arm.
        #[cfg(feature = "dtype-datetime")]
        Datetime(tu, tz) => s
            .to_physical_repr()
            .agg_varying_quantile(groups, quantiles, method)
            .cast(&Int64)
            .unwrap()
            .into_datetime(*tu, tz.clone()),
        #[cfg(feature = "dtype-duration")]
        Duration(tu) => s
            .to_physical_repr()
            .agg_varying_quantile(groups, quantiles, method)
            .cast(&Int64)
            .unwrap()
            .into_duration(*tu),
        #[cfg(feature = "dtype-time")]
        Time => s
            .to_physical_repr()
            .agg_varying_quantile(groups, quantiles, method)
            .cast(&Int64)
            .unwrap()
            .into_time(),
        // A quantile of dates may fall between two days, so the result is expressed as a
        // microsecond datetime instead of a date.
        #[cfg(feature = "dtype-date")]
        Date => (s
            .to_physical_repr()
            .agg_varying_quantile(groups, quantiles, method)
            .cast(&Float64)
            .unwrap()
            * (US_IN_DAY as f64))
            .cast(&DataType::Int64)
            .unwrap()
            .into_datetime(TimeUnit::Microseconds, None),
        dt if dt.is_primitive_numeric() => {
            apply_method_physical_integer!(s, agg_varying_quantile, groups, quantiles, method)
        },
        _ => Series::full_null(PlSmallStr::EMPTY, groups.len(), s.dtype()),
    }
}

#[doc(hidden)]
pub unsafe fn agg_last(&self, groups: &GroupsType) -> Series {
// Prevent a rechunk for every individual group.
Expand Down
109 changes: 109 additions & 0 deletions crates/polars-core/src/frame/group_by/aggregations/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,17 @@ where
ca.into_series()
}

/// Same as `agg_helper_idx_on_all` but also passes the group index to the closure.
fn agg_helper_idx_on_all_with_idx<T, F>(groups: &GroupsIdx, f: F) -> Series
where
F: Fn((usize, &IdxVec)) -> Option<T::Native> + Send + Sync,
T: PolarsNumericType,
{
let ca: ChunkedArray<T> =
POOL.install(|| groups.all().into_par_iter().enumerate().map(f).collect());
ca.into_series()
}

pub fn _agg_helper_slice<T, F>(groups: &[[IdxSize; 2]], f: F) -> Series
where
F: Fn([IdxSize; 2]) -> Option<T::Native> + Send + Sync,
Expand All @@ -191,6 +202,22 @@ where
ca.into_series()
}

/// Same as `_agg_helper_slice` but also passes the group index to the closure.
fn _agg_helper_slice_with_idx<T, F>(groups: &[[IdxSize; 2]], f: F) -> Series
where
F: Fn(usize, [IdxSize; 2]) -> Option<T::Native> + Send + Sync,
T: PolarsNumericType,
{
let ca: ChunkedArray<T> = POOL.install(|| {
groups
.par_iter()
.enumerate()
.map(|(idx, &g)| f(idx, g))
.collect()
});
ca.into_series()
}

pub fn _agg_helper_slice_no_null<T, F>(groups: &[[IdxSize; 2]], f: F) -> Series
where
F: Fn([IdxSize; 2]) -> T::Native + Send + Sync,
Expand Down Expand Up @@ -337,6 +364,64 @@ where
}
}

/// Quantile aggregation in which every group is evaluated at its own quantile.
///
/// `quantiles` is indexed by group position. A group produces a null when its quantile is
/// outside `[0, 1]` or when the group is empty. A group holding a single row produces that
/// row's value (or null) converted to `K::Native`. Larger groups are gathered or sliced
/// and passed to `_quantile`.
unsafe fn agg_varying_quantile_generic<T, K>(
    ca: &ChunkedArray<T>,
    groups: &GroupsType,
    quantiles: &[f64],
    method: QuantileMethod,
) -> Series
where
    T: PolarsNumericType,
    ChunkedArray<T>: QuantileDispatcher<K::Native>,
    K: PolarsNumericType,
    <K as datatypes::PolarsNumericType>::Native: num_traits::Float + quantile_filter::SealedRolling,
{
    // Looks up the quantile of one group; `None` when it is not inside the closed unit
    // interval (a NaN quantile never satisfies `contains`).
    let quantile_of = |group_idx: usize| -> Option<f64> {
        let q = quantiles[group_idx];
        (0.0..=1.0).contains(&q).then_some(q)
    };

    match groups {
        GroupsType::Idx(groups) => {
            let ca = ca.rechunk();
            agg_helper_idx_on_all_with_idx::<K, _>(groups, |(group_idx, idx)| {
                debug_assert!(idx.len() <= ca.len());
                let quantile = quantile_of(group_idx)?;

                let n_rows = idx.len();
                if n_rows == 0 {
                    None
                } else if n_rows == 1 {
                    let row = idx[0] as usize;
                    ca.get(row).map(|v| NumCast::from(v).unwrap())
                } else {
                    let gathered = ca.take_unchecked(idx);
                    gathered._quantile(quantile, method).unwrap()
                }
            })
        },
        GroupsType::Slice { groups, .. } => {
            _agg_helper_slice_with_idx::<K, _>(groups, |group_idx, [first, len]| {
                debug_assert!(first + len <= ca.len() as IdxSize);
                let quantile = quantile_of(group_idx)?;

                if len == 0 {
                    None
                } else if len == 1 {
                    ca.get(first as usize).map(|v| NumCast::from(v).unwrap())
                } else {
                    let window = _slice_from_offsets(ca, first, len);
                    window
                        ._quantile(quantile, method)
                        .unwrap_unchecked()
                        .map(|flt| NumCast::from(flt).unwrap_unchecked())
                }
            })
        },
    }
}

unsafe fn agg_median_generic<T, K>(ca: &ChunkedArray<T>, groups: &GroupsType) -> Series
where
T: PolarsNumericType,
Expand Down Expand Up @@ -1137,6 +1222,14 @@ impl Float32Chunked {
) -> Series {
agg_quantile_generic::<_, Float32Type>(self, groups, quantile, method)
}
/// Per-group quantile aggregation for `Float32` data; delegates to
/// `agg_varying_quantile_generic` with `Float32Type` as the output type.
///
/// # Safety
///
/// `groups` is forwarded unchecked; the caller must ensure it is correct for `self`.
pub(crate) unsafe fn agg_varying_quantile(
    &self,
    groups: &GroupsType,
    quantiles: &[f64],
    method: QuantileMethod,
) -> Series {
    // SAFETY: the caller upholds the contract of `agg_varying_quantile_generic`.
    unsafe { agg_varying_quantile_generic::<_, Float32Type>(self, groups, quantiles, method) }
}
pub(crate) unsafe fn agg_median(&self, groups: &GroupsType) -> Series {
agg_median_generic::<_, Float32Type>(self, groups)
}
Expand All @@ -1150,6 +1243,14 @@ impl Float64Chunked {
) -> Series {
agg_quantile_generic::<_, Float64Type>(self, groups, quantile, method)
}
/// Per-group quantile aggregation for `Float64` data; delegates to
/// `agg_varying_quantile_generic` with `Float64Type` as the output type.
///
/// # Safety
///
/// `groups` is forwarded unchecked; the caller must ensure it is correct for `self`.
pub(crate) unsafe fn agg_varying_quantile(
    &self,
    groups: &GroupsType,
    quantiles: &[f64],
    method: QuantileMethod,
) -> Series {
    // SAFETY: the caller upholds the contract of `agg_varying_quantile_generic`.
    unsafe { agg_varying_quantile_generic::<_, Float64Type>(self, groups, quantiles, method) }
}
pub(crate) unsafe fn agg_median(&self, groups: &GroupsType) -> Series {
agg_median_generic::<_, Float64Type>(self, groups)
}
Expand Down Expand Up @@ -1342,6 +1443,14 @@ where
) -> Series {
agg_quantile_generic::<_, Float64Type>(self, groups, quantile, method)
}
/// Per-group quantile aggregation; delegates to `agg_varying_quantile_generic` with
/// `Float64Type` as the output type.
///
/// # Safety
///
/// `groups` is forwarded unchecked; the caller must ensure it is correct for `self`.
pub(crate) unsafe fn agg_varying_quantile(
    &self,
    groups: &GroupsType,
    quantiles: &[f64],
    method: QuantileMethod,
) -> Series {
    // SAFETY: the caller upholds the contract of `agg_varying_quantile_generic`.
    unsafe { agg_varying_quantile_generic::<_, Float64Type>(self, groups, quantiles, method) }
}
pub(crate) unsafe fn agg_median(&self, groups: &GroupsType) -> Series {
agg_median_generic::<_, Float64Type>(self, groups)
}
Expand Down
80 changes: 61 additions & 19 deletions crates/polars-expr/src/expressions/aggregation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -520,6 +520,22 @@ impl AggQuantileExpr {
);
quantile.get(0).unwrap().try_extract()
}

/// Extracts the single quantile shared by all groups from an evaluated quantile expression.
///
/// # Errors
///
/// Returns a `ComputeError` when the expression produced more than one value. Returns the
/// underlying error (instead of panicking) when the expression produced no value or the
/// value cannot be extracted as `f64`.
fn get_quantile_from_scalar(&self, quantile_ac: &AggregationContext) -> PolarsResult<f64> {
    let quantile_col = quantile_ac.get_values();
    polars_ensure!(quantile_col.len() <= 1, ComputeError:
        "polars only supports computing a single quantile; \
        make sure the 'quantile' expression input produces a single quantile"
    );
    // The length check above still admits an empty column, so propagate the error from
    // `get` rather than unwrapping it.
    quantile_col.get(0)?.try_extract()
}

/// Reads the evaluated quantile expression as a vector of `f64` values.
///
/// The values are cast to `Float64`; a null entry is mapped to `f64::NAN`.
///
/// # Errors
///
/// Fails when the values cannot be cast to `Float64`.
fn get_quantiles_per_group(&self, quantile_ac: &AggregationContext) -> PolarsResult<Vec<f64>> {
    let as_f64 = quantile_ac.get_values().cast(&DataType::Float64)?;
    let quantiles: Vec<f64> = as_f64
        .f64()?
        .iter()
        .map(|maybe_q| match maybe_q {
            Some(q) => q,
            None => f64::NAN,
        })
        .collect();
    Ok(quantiles)
}
}

impl PhysicalExpr for AggQuantileExpr {
Expand All @@ -542,6 +558,7 @@ impl PhysicalExpr for AggQuantileExpr {
state: &ExecutionState,
) -> PolarsResult<AggregationContext<'a>> {
let mut ac = self.input.evaluate_on_groups(df, groups, state)?;
let quantile_ac = self.quantile.evaluate_on_groups(df, groups, state)?;

// AggregatedScalar has no defined group structure. We fix it up here, so that we can
// reliably call `agg_quantile` functions with the groups.
Expand All @@ -550,27 +567,52 @@ impl PhysicalExpr for AggQuantileExpr {
// don't change names by aggregations as is done in polars-core
let keep_name = ac.get_values().name().clone();

let quantile = self.get_quantile(df, state)?;
// Check if quantile is a literal/scalar (same for all groups) or varies per group
let is_uniform_quantile = quantile_ac.is_literal()
|| matches!(quantile_ac.agg_state(), AggState::LiteralScalar(_));

if let AggState::LiteralScalar(c) = &mut ac.state {
*c = c
.quantile_reduce(quantile, self.method)?
.into_column(keep_name);
return Ok(ac);
}
if is_uniform_quantile {
// Fast path: single quantile value for all groups
let quantile = self.get_quantile_from_scalar(&quantile_ac)?;

// SAFETY:
// groups are in bounds
let mut agg = unsafe {
ac.flat_naive()
.into_owned()
.agg_quantile(ac.groups(), quantile, self.method)
};
agg.rename(keep_name);
Ok(AggregationContext::from_agg_state(
AggregatedScalar(agg),
Cow::Borrowed(groups),
))
if let AggState::LiteralScalar(c) = &mut ac.state {
*c = c
.quantile_reduce(quantile, self.method)?
.into_column(keep_name);
return Ok(ac);
}

// SAFETY:
// groups are in bounds
let mut agg = unsafe {
ac.flat_naive()
.into_owned()
.agg_quantile(ac.groups(), quantile, self.method)
};
agg.rename(keep_name);
Ok(AggregationContext::from_agg_state(
AggregatedScalar(agg),
Cow::Borrowed(groups),
))
} else {
// Different quantile value per group
let quantiles = self.get_quantiles_per_group(&quantile_ac)?;

// SAFETY:
// groups are in bounds
let mut agg = unsafe {
ac.flat_naive().into_owned().agg_varying_quantile(
ac.groups(),
&quantiles,
self.method,
)
};
agg.rename(keep_name);
Ok(AggregationContext::from_agg_state(
AggregatedScalar(agg),
Cow::Borrowed(groups),
))
}
}

fn to_field(&self, input_schema: &Schema) -> PolarsResult<Field> {
Expand Down
Loading
Loading