Skip to content

Commit 687811d

Browse files
authored
feat: Improve read_database typing (#19444)
1 parent dc47e92 commit 687811d

File tree

2 files changed

+40
-5
lines changed

2 files changed

+40
-5
lines changed

py-polars/polars/io/database/functions.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,12 @@
2525
except ImportError:
2626
Selectable: TypeAlias = Any # type: ignore[no-redef]
2727

28+
from sqlalchemy.sql.elements import TextClause
29+
2830

2931
@overload
3032
def read_database(
31-
query: str | Selectable,
33+
query: str | TextClause | Selectable,
3234
connection: ConnectionOrCursor | str,
3335
*,
3436
iter_batches: Literal[False] = ...,
@@ -41,7 +43,7 @@ def read_database(
4143

4244
@overload
4345
def read_database(
44-
query: str | Selectable,
46+
query: str | TextClause | Selectable,
4547
connection: ConnectionOrCursor | str,
4648
*,
4749
iter_batches: Literal[True],
@@ -54,7 +56,7 @@ def read_database(
5456

5557
@overload
5658
def read_database(
57-
query: str | Selectable,
59+
query: str | TextClause | Selectable,
5860
connection: ConnectionOrCursor | str,
5961
*,
6062
iter_batches: bool,
@@ -66,7 +68,7 @@ def read_database(
6668

6769

6870
def read_database(
69-
query: str | Selectable,
71+
query: str | TextClause | Selectable,
7072
connection: ConnectionOrCursor | str,
7173
*,
7274
iter_batches: bool = False,

py-polars/tests/unit/io/database/test_read.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import pyarrow as pa
1313
import pytest
1414
import sqlalchemy
15-
from sqlalchemy import Integer, MetaData, Table, create_engine, func, select
15+
from sqlalchemy import Integer, MetaData, Table, create_engine, func, select, text
1616
from sqlalchemy.orm import sessionmaker
1717
from sqlalchemy.sql.expression import cast as alchemy_cast
1818

@@ -383,6 +383,39 @@ def test_read_database_alchemy_selectable(tmp_sqlite_db: Path) -> None:
383383
assert_frame_equal(batches[0], expected)
384384

385385

386+
def test_read_database_alchemy_textclause(tmp_sqlite_db: Path) -> None:
387+
# various flavours of alchemy connection
388+
alchemy_engine = create_engine(f"sqlite:///{tmp_sqlite_db}")
389+
alchemy_session: ConnectionOrCursor = sessionmaker(bind=alchemy_engine)()
390+
alchemy_conn: ConnectionOrCursor = alchemy_engine.connect()
391+
392+
# establish sqlalchemy "textclause" and validate usage
393+
textclause_query = text("""
394+
SELECT CAST(STRFTIME('%Y',"date") AS INT) as "year", name, value
395+
FROM test_data
396+
WHERE value < 0
397+
""")
398+
399+
expected = pl.DataFrame({"year": [2021], "name": ["other"], "value": [-99.5]})
400+
401+
for conn in (alchemy_session, alchemy_engine, alchemy_conn):
402+
assert_frame_equal(
403+
pl.read_database(textclause_query, connection=conn),
404+
expected,
405+
)
406+
407+
batches = list(
408+
pl.read_database(
409+
textclause_query,
410+
connection=conn,
411+
iter_batches=True,
412+
batch_size=1,
413+
)
414+
)
415+
assert len(batches) == 1
416+
assert_frame_equal(batches[0], expected)
417+
418+
386419
def test_read_database_parameterised(tmp_sqlite_db: Path) -> None:
387420
# raw cursor "execute" only takes positional params, alchemy cursor takes kwargs
388421
alchemy_engine = create_engine(f"sqlite:///{tmp_sqlite_db}")

0 commit comments

Comments
 (0)