Skip to content

Commit 4dd8b47

Browse files
authored
fix: Expression pl.concat was incorrectly marked as elementwise (#22019)
1 parent 8dd5382 commit 4dd8b47

File tree

3 files changed

+65
-1
lines changed

3 files changed

+65
-1
lines changed

Diff for: crates/polars-plan/src/dsl/functions/concat.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ pub fn concat_expr<E: AsRef<[IE]>, IE: Into<Expr> + Clone>(
9797
input: s,
9898
function: FunctionExpr::ConcatExpr(rechunk),
9999
options: FunctionOptions {
100-
collect_groups: ApplyOptions::ElementWise,
100+
collect_groups: ApplyOptions::GroupWise,
101101
flags: FunctionFlags::default() | FunctionFlags::INPUT_WILDCARD_EXPANSION,
102102
cast_options: Some(CastingRules::cast_to_supertypes()),
103103
..Default::default()

Diff for: crates/polars-stream/src/physical_plan/lower_expr.rs

+29
Original file line numberDiff line numberDiff line change
@@ -539,6 +539,35 @@ fn lower_exprs_with_ctx(
539539
transformed_exprs.push(ctx.expr_arena.add(AExpr::Column(out_name)));
540540
},
541541

542+
AExpr::Function {
543+
input: ref inner_exprs,
544+
function: FunctionExpr::ConcatExpr(_rechunk),
545+
options: _,
546+
} => {
547+
// We have to lower each expression separately as they might have different lengths.
548+
let mut concat_streams = Vec::new();
549+
let out_name = unique_column_name();
550+
for inner_expr in inner_exprs {
551+
let (trans_input, trans_expr) =
552+
lower_exprs_with_ctx(input, &[inner_expr.node()], ctx)?;
553+
let select_expr =
554+
ExprIR::new(trans_expr[0], OutputName::Alias(out_name.clone()));
555+
concat_streams.push(build_select_stream_with_ctx(
556+
trans_input,
557+
&[select_expr],
558+
ctx,
559+
)?);
560+
}
561+
562+
let output_schema = ctx.phys_sm[concat_streams[0].node].output_schema.clone();
563+
let node_kind = PhysNodeKind::OrderedUnion {
564+
inputs: concat_streams,
565+
};
566+
let node_key = ctx.phys_sm.insert(PhysNode::new(output_schema, node_kind));
567+
input_streams.insert(PhysStream::first(node_key));
568+
transformed_exprs.push(ctx.expr_arena.add(AExpr::Column(out_name)));
569+
},
570+
542571
AExpr::Function {
543572
input: ref inner_exprs,
544573
function: FunctionExpr::Boolean(BooleanFunction::IsIn { nulls_equal }),

Diff for: py-polars/tests/unit/functions/test_concat.py

+35
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import pytest
22

33
import polars as pl
4+
from polars.testing import assert_frame_equal
45

56

67
@pytest.mark.slow
@@ -59,3 +60,37 @@ def test_concat_vertically_relaxed() -> None:
5960
"a": [1.0, 0.2, 1.0, 2.0],
6061
"b": [None, 0.1, 2.0, 1.0],
6162
}
63+
64+
65+
def test_concat_group_by() -> None:
66+
df = pl.DataFrame(
67+
{
68+
"g": [0, 0, 0, 0, 1, 1, 1, 1],
69+
"a": [0, 1, 2, 3, 4, 5, 6, 7],
70+
"b": [8, 9, 10, 11, 12, 13, 14, 15],
71+
}
72+
)
73+
out = df.group_by("g").agg(pl.concat([pl.col.a, pl.col.b]))
74+
75+
assert_frame_equal(
76+
out,
77+
pl.DataFrame(
78+
{
79+
"g": [0, 1],
80+
"a": [[0, 1, 2, 3, 8, 9, 10, 11], [4, 5, 6, 7, 12, 13, 14, 15]],
81+
}
82+
),
83+
check_row_order=False,
84+
)
85+
86+
87+
def test_concat_19877() -> None:
88+
df = pl.DataFrame({"a": [1, 2], "b": [3, 4]})
89+
out = df.select(pl.concat([pl.col("a"), pl.col("b")]))
90+
assert_frame_equal(out, pl.DataFrame({"a": [1, 2, 3, 4]}))
91+
92+
93+
def test_concat_zip_series_21980() -> None:
94+
df = pl.DataFrame({"x": 1, "y": 2})
95+
out = df.select(pl.concat([pl.col.x, pl.col.y]), pl.Series([3, 4]))
96+
assert_frame_equal(out, pl.DataFrame({"x": [1, 2], "": [3, 4]}))

0 commit comments

Comments
 (0)