Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 61 additions & 40 deletions crates/polars-plan/src/plans/python/pyarrow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,48 @@ fn sanitize(name: &str) -> Option<&str> {
}
}

fn series_to_pyarrow_list(s: &polars_core::prelude::Series) -> Option<String> {
if s.is_empty() || s.len() > 100 {
return None;
}
let mut list_repr = String::with_capacity(s.len() * 5);
list_repr.push('[');
for av in s.iter() {
match av {
AnyValue::Boolean(v) => {
let s = if v { "True" } else { "False" };
write!(list_repr, "{s},").unwrap();
},
#[cfg(feature = "dtype-datetime")]
AnyValue::Datetime(v, tu, tz) => {
let dtm = to_py_datetime(v, &tu, tz);
write!(list_repr, "{dtm},").unwrap();
},
#[cfg(feature = "dtype-date")]
AnyValue::Date(v) => {
write!(list_repr, "to_py_date({v}),").unwrap();
},
AnyValue::String(s) => {
let _ = sanitize(s)?;
write!(list_repr, "{av},").unwrap();
},
// Hard to sanitize
AnyValue::Binary(_) | AnyValue::List(_) => return None,
#[cfg(feature = "dtype-array")]
AnyValue::Array(_, _) => return None,
#[cfg(feature = "dtype-struct")]
AnyValue::Struct(_, _, _) => return None,
_ => {
write!(list_repr, "{av},").unwrap();
},
}
}
// pop last comma
list_repr.pop();
list_repr.push(']');
Some(list_repr)
}

// convert to a pyarrow expression that can be evaluated with pythons eval
pub fn predicate_to_pa(
predicate: Node,
Expand All @@ -55,45 +97,10 @@ pub fn predicate_to_pa(
Some(format!("pa.compute.field('{name}')"))
},
AExpr::Literal(LiteralValue::Series(s)) => {
if !args.allow_literal_series || s.is_empty() || s.len() > 100 {
if !args.allow_literal_series {
None
} else {
let mut list_repr = String::with_capacity(s.len() * 5);
list_repr.push('[');
for av in s.iter() {
match av {
AnyValue::Boolean(v) => {
let s = if v { "True" } else { "False" };
write!(list_repr, "{s},").unwrap();
},
#[cfg(feature = "dtype-datetime")]
AnyValue::Datetime(v, tu, tz) => {
let dtm = to_py_datetime(v, &tu, tz);
write!(list_repr, "{dtm},").unwrap();
},
#[cfg(feature = "dtype-date")]
AnyValue::Date(v) => {
write!(list_repr, "to_py_date({v}),").unwrap();
},
AnyValue::String(s) => {
let _ = sanitize(s)?;
write!(list_repr, "{av},").unwrap();
},
// Hard to sanitize
AnyValue::Binary(_) | AnyValue::List(_) => return None,
#[cfg(feature = "dtype-array")]
AnyValue::Array(_, _) => return None,
#[cfg(feature = "dtype-struct")]
AnyValue::Struct(_, _, _) => return None,
_ => {
write!(list_repr, "{av},").unwrap();
},
}
}
// pop last comma
list_repr.pop();
list_repr.push(']');
Some(list_repr)
series_to_pyarrow_list(s)
}
},
AExpr::Literal(lv) => {
Expand Down Expand Up @@ -163,9 +170,23 @@ pub fn predicate_to_pa(
..
} => {
let col = predicate_to_pa(input.first()?.node(), expr_arena, args)?;
let mut args = args;
args.allow_literal_series = true;
let values = predicate_to_pa(input.get(1)?.node(), expr_arena, args)?;
let rhs_node = input.get(1)?.node();
let mut is_in_args = args;
is_in_args.allow_literal_series = true;
let values = predicate_to_pa(rhs_node, expr_arena, is_in_args)
.or_else(|| {
// Handle AnyValue::List directly for is_in RHS only
match expr_arena.get(rhs_node) {
AExpr::Literal(lv) => {
let av = lv.to_any_value()?;
match av.as_borrowed() {
AnyValue::List(s) => series_to_pyarrow_list(&s),
_ => None,
}
},
_ => None,
}
})?;

Some(format!("({col}).isin({values})"))
},
Expand Down
48 changes: 48 additions & 0 deletions py-polars/tests/unit/io/test_pyarrow_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,54 @@ def test_pyarrow_dataset_partial_predicate_pushdown(
assert_frame_equal(result, expected)


def test_pyarrow_dataset_is_in_predicate_pushdown(
tmp_path: Path,
plmonkeypatch: PlMonkeyPatch,
capfd: pytest.CaptureFixture[str],
) -> None:
plmonkeypatch.setenv("POLARS_VERBOSE_SENSITIVE", "1")

df = pl.DataFrame({"id": [1, 2, 3, 4, 5], "val": [10, 20, 30, 40, 50]})
path = tmp_path / "test.parquet"
df.write_parquet(path)
dset = ds.dataset(path, format="parquet")

q = pl.scan_pyarrow_dataset(dset).filter(pl.col("id").is_in([1, 3]))

capfd.readouterr()
result = q.collect()
capture = capfd.readouterr().err

# Verify: predicate fully pushed to pyarrow
assert "(pa.compute.field('id')).isin([1,3])" in capture
# Verify: no residual predicate
assert "residual predicate: None" in capture

# Verify: correctness
expected = df.lazy().filter(pl.col("id").is_in([1, 3])).collect()
assert_frame_equal(result, expected)


def test_pyarrow_dataset_list_literal_not_pushed(
tmp_path: Path,
) -> None:
"""List literals outside of is_in() should not be pushed down to pyarrow."""
df = pl.DataFrame({"id": [1, 2, 3, 4, 5], "val": [10, 20, 30, 40, 50]})
path = tmp_path / "test.parquet"
df.write_parquet(path)
dset = ds.dataset(path, format="parquet")

# is_in predicate should be fully pushed (no FILTER in plan)
q_is_in = pl.scan_pyarrow_dataset(dset).filter(pl.col("id").is_in([1, 3]))
assert "FILTER" not in q_is_in.explain()

# A bare series literal should NOT be pushed (FILTER remains in plan)
q_series = pl.scan_pyarrow_dataset(dset).filter(
pl.col("id") == pl.lit(pl.Series([1, 3]))
)
assert "FILTER" in q_series.explain()


def test_pyarrow_dataset_comm_subplan_elim(tmp_path: Path) -> None:
df0 = pl.DataFrame({"a": [1, 2, 3]})

Expand Down