Skip to content

Commit 40936a5

Browse files
alexander-beedie and dongchao-1
authored and committed
feat(python): Support running Polars SQL queries against any objects implementing the PyCapsule interface (pola-rs#22235)
1 parent e0089d4 commit 40936a5

File tree

2 files changed

+40
-48
lines changed

2 files changed

+40
-48
lines changed

py-polars/polars/sql/context.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
from polars._typing import FrameType
1414
from polars._utils.deprecation import deprecate_renamed_parameter
15+
from polars._utils.pycapsule import is_pycapsule
1516
from polars._utils.unstable import issue_unstable_warning
1617
from polars._utils.various import _get_stack_locals
1718
from polars._utils.wrap import wrap_ldf
@@ -52,14 +53,14 @@
5253
pa.RecordBatch,
5354
]
5455

55-
5656
__all__ = ["SQLContext"]
5757

5858

5959
def _compatible_frame(obj: Any) -> bool:
6060
"""Check if the object can be converted to DataFrame."""
6161
return (
62-
isinstance(obj, (DataFrame, LazyFrame, Series))
62+
is_pycapsule(obj)
63+
or isinstance(obj, LazyFrame)
6364
or (_check_for_pandas(obj) and isinstance(obj, (pd.DataFrame, pd.Series)))
6465
or (_check_for_pyarrow(obj) and isinstance(obj, (pa.Table, pa.RecordBatch)))
6566
)
@@ -68,14 +69,16 @@ def _compatible_frame(obj: Any) -> bool:
6869
def _ensure_lazyframe(obj: Any) -> LazyFrame:
6970
"""Return LazyFrame from compatible input."""
7071
if isinstance(obj, (DataFrame, LazyFrame)):
71-
return obj if isinstance(obj, LazyFrame) else obj.lazy()
72+
return obj.lazy()
7273
elif isinstance(obj, Series):
7374
return obj.to_frame().lazy()
7475
elif _check_for_pandas(obj) and isinstance(obj, (pd.DataFrame, pd.Series)):
7576
if isinstance(frame := from_pandas(obj), Series):
7677
frame = frame.to_frame()
7778
return frame.lazy()
78-
elif _check_for_pyarrow(obj) and isinstance(obj, (pa.Table, pa.RecordBatch)):
79+
elif is_pycapsule(obj) or (
80+
_check_for_pyarrow(obj) and isinstance(obj, (pa.Table, pa.RecordBatch))
81+
):
7982
return from_arrow(obj).lazy() # type: ignore[union-attr]
8083
else:
8184
msg = f"Unrecognised frame type: {type(obj)}"

py-polars/tests/unit/sql/test_miscellaneous.py

Lines changed: 33 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import polars as pl
1010
from polars.exceptions import ColumnNotFoundError, SQLInterfaceError, SQLSyntaxError
1111
from polars.testing import assert_frame_equal
12+
from tests.unit.utils.pycapsule_utils import PyCapsuleStreamHolder
1213

1314
if TYPE_CHECKING:
1415
from polars.datatypes import DataType
@@ -243,6 +244,7 @@ def test_sql_on_compatible_frame_types() -> None:
243244
dfp = df.to_pandas()
244245
dfa = df.to_arrow()
245246
dfb = dfa.to_batches()[0] # noqa: F841
247+
dfo = PyCapsuleStreamHolder(df) # noqa: F841
246248

247249
# run polars sql query against all frame types
248250
for dfs in ( # noqa: B007
@@ -256,14 +258,15 @@ def test_sql_on_compatible_frame_types() -> None:
256258
UNION ALL SELECT * FROM dfp -- pandas frame
257259
UNION ALL SELECT * FROM dfa -- pyarrow table
258260
UNION ALL SELECT * FROM dfb -- pyarrow record batch
261+
UNION ALL SELECT * FROM dfo -- arbitrary pycapsule object
259262
) tbl
260263
INNER JOIN dfs ON dfs.c == tbl.b -- join on pandas/polars series
261264
GROUP BY "a", "b"
262265
ORDER BY "a", "b"
263266
"""
264267
).collect()
265268

266-
expected = pl.DataFrame({"a": [1, 3], "b": [4, 6], "cc": [16, 24]})
269+
expected = pl.DataFrame({"a": [1, 3], "b": [4, 6], "cc": [20, 30]})
267270
assert_frame_equal(left=expected, right=res)
268271

269272
# register and operate on non-polars frames
@@ -400,60 +403,60 @@ def test_select_output_heights_20058_21084(filter_expr: str, order_expr: str) ->
400403
)
401404

402405
assert_frame_equal(
403-
df.sql(f"""\
404-
SELECT 1 + 1 as a, 1 as b FROM self {filter_expr} {order_expr}
405-
""").cast(pl.Int64),
406+
df.sql(f"SELECT 1 + 1 as a, 1 as b FROM self {filter_expr} {order_expr}").cast(
407+
pl.Int64
408+
),
406409
pl.DataFrame({"a": [2, 2, 2], "b": [1, 1, 1]}),
407410
)
408411

409412
# Queries that aggregate to unit height
410413

411414
assert_frame_equal(
412-
df.sql(f"""\
413-
SELECT COUNT(*) as a FROM self {filter_expr} {order_expr}
414-
""").cast(pl.Int64),
415+
df.sql(f"SELECT COUNT(*) as a FROM self {filter_expr} {order_expr}").cast(
416+
pl.Int64
417+
),
415418
pl.DataFrame({"a": 3}),
416419
)
417420

418421
assert_frame_equal(
419-
df.sql(f"""\
420-
SELECT COUNT(*) as a, 1 as b FROM self {filter_expr} {order_expr}
421-
""").cast(pl.Int64),
422+
df.sql(
423+
f"SELECT COUNT(*) as a, 1 as b FROM self {filter_expr} {order_expr}"
424+
).cast(pl.Int64),
422425
pl.DataFrame({"a": 3, "b": 1}),
423426
)
424427

425428
assert_frame_equal(
426-
df.sql(f"""\
427-
SELECT FIRST(a) as a, 1 as b FROM self {filter_expr} {order_expr}
428-
""").cast(pl.Int64),
429+
df.sql(
430+
f"SELECT FIRST(a) as a, 1 as b FROM self {filter_expr} {order_expr}"
431+
).cast(pl.Int64),
429432
pl.DataFrame({"a": 1, "b": 1}),
430433
)
431434

432435
assert_frame_equal(
433-
df.sql(f"""\
434-
SELECT SUM(a) as a, 1 as b FROM self {filter_expr} {order_expr}
435-
""").cast(pl.Int64),
436+
df.sql(f"SELECT SUM(a) as a, 1 as b FROM self {filter_expr} {order_expr}").cast(
437+
pl.Int64
438+
),
436439
pl.DataFrame({"a": 6, "b": 1}),
437440
)
438441

439442
assert_frame_equal(
440-
df.sql(f"""\
441-
SELECT FIRST(1) as a, 1 as b FROM self {filter_expr} {order_expr}
442-
""").cast(pl.Int64),
443+
df.sql(
444+
f"SELECT FIRST(1) as a, 1 as b FROM self {filter_expr} {order_expr}"
445+
).cast(pl.Int64),
443446
pl.DataFrame({"a": 1, "b": 1}),
444447
)
445448

446449
assert_frame_equal(
447-
df.sql(f"""\
448-
SELECT FIRST(1) + 1 as a, 1 as b FROM self {filter_expr} {order_expr}
449-
""").cast(pl.Int64),
450+
df.sql(
451+
f"SELECT FIRST(1) + 1 as a, 1 as b FROM self {filter_expr} {order_expr}"
452+
).cast(pl.Int64),
450453
pl.DataFrame({"a": 2, "b": 1}),
451454
)
452455

453456
assert_frame_equal(
454-
df.sql(f"""\
455-
SELECT FIRST(1 + 1) as a, 1 as b FROM self {filter_expr} {order_expr}
456-
""").cast(pl.Int64),
457+
df.sql(
458+
f"SELECT FIRST(1 + 1) as a, 1 as b FROM self {filter_expr} {order_expr}"
459+
).cast(pl.Int64),
457460
pl.DataFrame({"a": 2, "b": 1}),
458461
)
459462

@@ -473,48 +476,34 @@ def test_select_explode_height_filter_order_by() -> None:
473476
# extended with NULLs.
474477

475478
assert_frame_equal(
476-
df.sql(
477-
"""\
478-
SELECT UNNEST(list_long) as list FROM self ORDER BY sort_key
479-
"""
480-
),
479+
df.sql("SELECT UNNEST(list_long) as list FROM self ORDER BY sort_key"),
481480
pl.Series("list", [2, 1, 3, 4, 5, 6]).to_frame(),
482481
)
483482

484483
assert_frame_equal(
485484
df.sql(
486-
"""\
487-
SELECT UNNEST(list_long) as list FROM self ORDER BY sort_key NULLS FIRST
488-
"""
485+
"SELECT UNNEST(list_long) as list FROM self ORDER BY sort_key NULLS FIRST"
489486
),
490487
pl.Series("list", [3, 4, 5, 6, 2, 1]).to_frame(),
491488
)
492489

493490
# Literals are broadcasted to output height of UNNEST:
494491
assert_frame_equal(
495-
df.sql(
496-
"""\
497-
SELECT UNNEST(list_long) as list, 1 as x FROM self ORDER BY sort_key
498-
"""
499-
),
492+
df.sql("SELECT UNNEST(list_long) as list, 1 as x FROM self ORDER BY sort_key"),
500493
pl.select(pl.Series("list", [2, 1, 3, 4, 5, 6]), x=1),
501494
)
502495

503496
# Note: Filter applies before projections in SQL
504497
assert_frame_equal(
505498
df.sql(
506-
"""\
507-
SELECT UNNEST(list_long) as list FROM self WHERE filter_mask ORDER BY sort_key
508-
"""
499+
"SELECT UNNEST(list_long) as list FROM self WHERE filter_mask ORDER BY sort_key"
509500
),
510501
pl.Series("list", [4, 5, 6]).to_frame(),
511502
)
512503

513504
assert_frame_equal(
514505
df.sql(
515-
"""\
516-
SELECT UNNEST(list_long) as list FROM self WHERE filter_mask_all_true ORDER BY sort_key
517-
"""
506+
"SELECT UNNEST(list_long) as list FROM self WHERE filter_mask_all_true ORDER BY sort_key"
518507
),
519508
pl.Series("list", [2, 1, 3, 4, 5, 6]).to_frame(),
520509
)

0 commit comments

Comments (0)