Skip to content

Commit 1a2a088

Browse files
coastalwhiteorlp
andauthored
perf: Native streaming int_range with len or count (#24280)
Co-authored-by: Orson Peters <orsonpeters@gmail.com>
1 parent 3be90aa commit 1a2a088

File tree

3 files changed

+145
-29
lines changed

3 files changed

+145
-29
lines changed

crates/polars-plan/src/plans/aexpr/builder.rs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -381,6 +381,22 @@ impl AExprBuilder {
381381
nc.gt(idx_zero, arena)
382382
}
383383

384+
pub fn drop_nulls(self, arena: &mut Arena<AExpr>) -> Self {
385+
Self::function(
386+
vec![self.expr_ir_retain_name(arena)],
387+
IRFunctionExpr::DropNulls,
388+
arena,
389+
)
390+
}
391+
392+
pub fn drop_nans(self, arena: &mut Arena<AExpr>) -> Self {
393+
Self::function(
394+
vec![self.expr_ir_retain_name(arena)],
395+
IRFunctionExpr::DropNans,
396+
arena,
397+
)
398+
}
399+
384400
pub fn eq(self, other: impl IntoAExprBuilder, arena: &mut Arena<AExpr>) -> Self {
385401
self.binary_op(other, Operator::Eq, arena)
386402
}
@@ -465,6 +481,10 @@ impl AExprBuilder {
465481
ExprIR::new(self.node(), OutputName::Alias(name.into()))
466482
}
467483

484+
pub fn expr_ir_retain_name(self, arena: &Arena<AExpr>) -> ExprIR {
485+
ExprIR::from_node(self.node(), arena)
486+
}
487+
468488
pub fn expr_ir_unnamed(self) -> ExprIR {
469489
self.expr_ir(PlSmallStr::EMPTY)
470490
}

crates/polars-stream/src/physical_plan/lower_expr.rs

Lines changed: 86 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1196,6 +1196,92 @@ fn lower_exprs_with_ctx(
11961196
transformed_exprs.push(ctx.expr_arena.add(AExpr::Column(value_key.clone())));
11971197
},
11981198

1199+
// pl.row_index() maps to this.
1200+
#[cfg(feature = "range")]
1201+
AExpr::Function {
1202+
input: ref inner_exprs,
1203+
function: IRFunctionExpr::Range(IRRangeFunction::IntRange { step: 1, dtype }),
1204+
options: _,
1205+
} if {
1206+
let start_is_zero = match ctx.expr_arena.get(inner_exprs[0].node()) {
1207+
AExpr::Literal(lit) => lit.extract_usize().ok() == Some(0),
1208+
_ => false,
1209+
};
1210+
let stop_is_len = matches!(ctx.expr_arena.get(inner_exprs[1].node()), AExpr::Len);
1211+
1212+
dtype == DataType::IDX_DTYPE && start_is_zero && stop_is_len
1213+
} =>
1214+
{
1215+
let out_name = unique_column_name();
1216+
let row_idx_col_aexpr = ctx.expr_arena.add(AExpr::Column(out_name.clone()));
1217+
let row_idx_col_expr_ir =
1218+
ExprIR::new(row_idx_col_aexpr, OutputName::ColumnLhs(out_name.clone()));
1219+
let row_idx_stream = build_select_stream_with_ctx(
1220+
build_row_idx_stream(input, out_name, None, ctx.phys_sm),
1221+
&[row_idx_col_expr_ir],
1222+
ctx,
1223+
)?;
1224+
input_streams.insert(row_idx_stream);
1225+
transformed_exprs.push(row_idx_col_aexpr);
1226+
},
1227+
1228+
#[cfg(feature = "range")]
1229+
AExpr::Function {
1230+
input: ref inner_exprs,
1231+
function: IRFunctionExpr::Range(IRRangeFunction::IntRange { step: 1, dtype }),
1232+
options: _,
1233+
} if {
1234+
let start_is_zero = match ctx.expr_arena.get(inner_exprs[0].node()) {
1235+
AExpr::Literal(lit) => lit.extract_usize().ok() == Some(0),
1236+
_ => false,
1237+
};
1238+
let stop_is_count = matches!(
1239+
ctx.expr_arena.get(inner_exprs[1].node()),
1240+
AExpr::Agg(IRAggExpr::Count { .. })
1241+
);
1242+
1243+
start_is_zero && stop_is_count
1244+
} =>
1245+
{
1246+
let AExpr::Agg(IRAggExpr::Count {
1247+
input: input_expr,
1248+
include_nulls,
1249+
}) = ctx.expr_arena.get(inner_exprs[1].node())
1250+
else {
1251+
unreachable!();
1252+
};
1253+
let (input_expr, include_nulls) = (*input_expr, *include_nulls);
1254+
1255+
let out_name = unique_column_name();
1256+
let mut row_idx_col_aexpr = ctx.expr_arena.add(AExpr::Column(out_name.clone()));
1257+
if dtype != IDX_DTYPE {
1258+
row_idx_col_aexpr = AExprBuilder::new_from_node(row_idx_col_aexpr)
1259+
.cast(dtype, ctx.expr_arena)
1260+
.node();
1261+
}
1262+
let row_idx_col_expr_ir =
1263+
ExprIR::new(row_idx_col_aexpr, OutputName::ColumnLhs(out_name.clone()));
1264+
1265+
let mut input_expr = AExprBuilder::new_from_node(input_expr);
1266+
if !include_nulls {
1267+
input_expr = input_expr.drop_nulls(ctx.expr_arena);
1268+
}
1269+
let input_expr = input_expr.expr_ir_retain_name(ctx.expr_arena);
1270+
1271+
let row_idx_stream = build_select_stream_with_ctx(
1272+
build_row_idx_stream(
1273+
build_select_stream_with_ctx(input, &[input_expr], ctx)?,
1274+
out_name,
1275+
None,
1276+
ctx.phys_sm,
1277+
),
1278+
&[row_idx_col_expr_ir],
1279+
ctx,
1280+
)?;
1281+
input_streams.insert(row_idx_stream);
1282+
transformed_exprs.push(row_idx_col_aexpr);
1283+
},
1284+
11991285
// Lower arbitrary elementwise functions.
12001286
ref node @ AExpr::Function {
12011287
input: ref inner_exprs,
@@ -1605,35 +1691,6 @@ fn lower_exprs_with_ctx(
16051691
transformed_exprs.push(AExprBuilder::col(out_name.clone(), ctx.expr_arena).node());
16061692
},
16071693

1608-
// pl.row_index() maps to this.
1609-
#[cfg(feature = "range")]
1610-
AExpr::Function {
1611-
input: ref inner_exprs,
1612-
function: IRFunctionExpr::Range(IRRangeFunction::IntRange { step: 1, dtype }),
1613-
options: _,
1614-
} if {
1615-
let start_is_zero = match ctx.expr_arena.get(inner_exprs[0].node()) {
1616-
AExpr::Literal(lit) => lit.extract_usize().ok() == Some(0),
1617-
_ => false,
1618-
};
1619-
let stop_is_len = matches!(ctx.expr_arena.get(inner_exprs[1].node()), AExpr::Len);
1620-
1621-
dtype == DataType::IDX_DTYPE && start_is_zero && stop_is_len
1622-
} =>
1623-
{
1624-
let out_name = unique_column_name();
1625-
let row_idx_col_aexpr = ctx.expr_arena.add(AExpr::Column(out_name.clone()));
1626-
let row_idx_col_expr_ir =
1627-
ExprIR::new(row_idx_col_aexpr, OutputName::ColumnLhs(out_name.clone()));
1628-
let row_idx_stream = build_select_stream_with_ctx(
1629-
build_row_idx_stream(input, out_name, None, ctx.phys_sm),
1630-
&[row_idx_col_expr_ir],
1631-
ctx,
1632-
)?;
1633-
input_streams.insert(row_idx_stream);
1634-
transformed_exprs.push(row_idx_col_aexpr);
1635-
},
1636-
16371694
AExpr::Slice {
16381695
input: inner,
16391696
offset,

py-polars/tests/unit/functions/range/test_int_range.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -286,3 +286,42 @@ def test_int_ranges_non_numeric_input_should_error() -> None:
286286
_ = df.select(pl.int_ranges("start", "end"))
287287

288288
assert "conversion from `str` to `i64` failed" in str(excinfo.value)
289+
290+
291+
def test_int_range_len_count() -> None:
292+
values = [1, 2, None, 4, 5, 6]
293+
294+
lf = pl.Series("a", values).to_frame().lazy()
295+
296+
def irange(e: pl.Expr) -> pl.LazyFrame:
297+
return lf.select(r=pl.int_range(0, e, dtype=pl.get_index_type()))
298+
299+
q = irange(pl.len())
300+
assert_series_equal(
301+
q.collect().to_series(),
302+
pl.Series("r", [0, 1, 2, 3, 4, 5], pl.get_index_type()),
303+
)
304+
305+
q = irange(pl.col.a.len())
306+
assert_series_equal(
307+
q.collect().to_series(),
308+
pl.Series("r", [0, 1, 2, 3, 4, 5], pl.get_index_type()),
309+
)
310+
311+
q = irange(pl.col.a.filter(pl.col.a.ne_missing(4)).len())
312+
assert_series_equal(
313+
q.collect().to_series(),
314+
pl.Series("r", [0, 1, 2, 3, 4], pl.get_index_type()),
315+
)
316+
317+
q = irange(pl.col.a.count())
318+
assert_series_equal(
319+
q.collect().to_series(),
320+
pl.Series("r", [0, 1, 2, 3, 4], pl.get_index_type()),
321+
)
322+
323+
q = irange(pl.col.a.filter(pl.col.a.ne_missing(4)).count())
324+
assert_series_equal(
325+
q.collect().to_series(),
326+
pl.Series("r", [0, 1, 2, 3], pl.get_index_type()),
327+
)

0 commit comments

Comments
 (0)