Skip to content

Commit b2f13bf

Browse files
committed
feat: add query_arrow/pandas/polars/pylist helpers on Client
Adds one-shot output ergonomics so callers don't have to reach for .read_all().to_pandas() on the underlying Flight or ADBC reader. Each helper accepts an optional `params` list that routes through the existing ADBC parameterized path, otherwise it goes through Flight as before. polars is gated behind an optional extra (`spicepy[polars]`) and raises a clear ImportError if missing.
1 parent 8d0250c commit b2f13bf

3 files changed

Lines changed: 292 additions & 1 deletion

File tree

pyproject.toml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,11 +56,15 @@ test = [
5656
"pyarrow>=23.0.1",
5757
"adbc-driver-flightsql>=1.11.0",
5858
"adbc-driver-manager>=1.11.0",
59+
"polars>=1.0.0",
5960
]
6061
params = [
6162
"adbc-driver-flightsql>=1.11.0",
6263
"adbc-driver-manager>=1.11.0",
6364
]
65+
polars = [
66+
"polars>=1.0.0",
67+
]
6468

6569
# ============== Tool Configuration ==============
6670

@@ -93,6 +97,8 @@ module = [
9397
"adbc_driver_manager.*",
9498
"certifi",
9599
"pandas",
100+
"polars",
101+
"polars.*",
96102
"_pytest.*",
97103
"pytest.*",
98104
]

spicepy/_client.py

Lines changed: 90 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,15 @@
33
from pathlib import Path
44
import platform
55
import threading
6-
from typing import Any
6+
from typing import TYPE_CHECKING, Any, cast
77

88
import certifi
99
import pyarrow as pa
1010

11+
if TYPE_CHECKING:
12+
import pandas as pd
13+
import polars as pl
14+
1115
# pylint: disable=E0611
1216
from pyarrow._flight import (
1317
FlightCallOptions,
@@ -409,6 +413,91 @@ def query_with_params(
409413
adbc = self._ensure_adbc_client()
410414
return adbc.query_with_params(sql, params)
411415

416+
def _read_table(
417+
self,
418+
sql: str,
419+
params: list[Any] | None,
420+
timeout: int | None,
421+
) -> pa.Table:
422+
if params is not None:
423+
return self.query_with_params(sql, params).read_all()
424+
kwargs: dict[str, Any] = {}
425+
if timeout is not None:
426+
kwargs["timeout"] = timeout
427+
return self.query(sql, **kwargs).read_all()
428+
429+
def query_arrow(
430+
self,
431+
sql: str,
432+
*,
433+
params: list[Any] | None = None,
434+
timeout: int | None = None,
435+
) -> pa.Table:
436+
"""Execute a SQL query and return results as a PyArrow Table.
437+
438+
Args:
439+
sql: SQL query string. Use $1, $2, ... placeholders if passing params.
440+
params: Optional list of parameter values. When provided, the query is
441+
executed via ADBC FlightSQL with prepared statements. See
442+
:meth:`query_with_params` for parameter format.
443+
timeout: Optional query timeout in seconds (ignored when params is set).
444+
445+
Returns:
446+
Arrow Table with all query results materialized in memory.
447+
"""
448+
return self._read_table(sql, params, timeout)
449+
450+
def query_pandas(
451+
self,
452+
sql: str,
453+
*,
454+
params: list[Any] | None = None,
455+
timeout: int | None = None,
456+
) -> "pd.DataFrame":
457+
"""Execute a SQL query and return results as a pandas DataFrame.
458+
459+
See :meth:`query_arrow` for argument semantics.
460+
"""
461+
return cast("pd.DataFrame", self._read_table(sql, params, timeout).to_pandas())
462+
463+
def query_polars(
464+
self,
465+
sql: str,
466+
*,
467+
params: list[Any] | None = None,
468+
timeout: int | None = None,
469+
) -> "pl.DataFrame":
470+
"""Execute a SQL query and return results as a polars DataFrame.
471+
472+
Requires the optional ``polars`` dependency:
473+
``pip install spicepy[polars]``.
474+
475+
See :meth:`query_arrow` for argument semantics.
476+
"""
477+
try:
478+
import polars as pl
479+
except ImportError as exc:
480+
raise ImportError(
481+
"polars is not installed. Install it with: pip install spicepy[polars]"
482+
) from exc
483+
return cast("pl.DataFrame", pl.from_arrow(self._read_table(sql, params, timeout)))
484+
485+
def query_pylist(
486+
self,
487+
sql: str,
488+
*,
489+
params: list[Any] | None = None,
490+
timeout: int | None = None,
491+
) -> list[dict[str, Any]]:
492+
"""Execute a SQL query and return results as a list of row dicts.
493+
494+
See :meth:`query_arrow` for argument semantics.
495+
"""
496+
return cast(
497+
"list[dict[str, Any]]",
498+
self._read_table(sql, params, timeout).to_pylist(),
499+
)
500+
412501
def refresh_dataset(
413502
self, dataset: str, refresh_opts: RefreshOpts | None = None
414503
) -> Any:

tests/test_client.py

Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1017,6 +1017,202 @@ def test_thread_join_raises_exception(self) -> None:
10171017
thread.join()
10181018

10191019

1020+
class TestClientQueryHelpers:
1021+
"""Test Client query_arrow / query_pandas / query_polars / query_pylist."""
1022+
1023+
@staticmethod
1024+
def _sample_table() -> pa.Table:
1025+
return pa.table({"id": [1, 2, 3], "name": ["a", "b", "c"]})
1026+
1027+
@patch("spicepy._client._SpiceFlight")
1028+
@patch("spicepy._client._Cert")
1029+
def test_query_arrow_returns_table(
1030+
self,
1031+
mock_cert_class: MagicMock,
1032+
mock_flight_class: MagicMock,
1033+
) -> None:
1034+
mock_cert = MagicMock()
1035+
mock_cert.tls_root_certs = b"cert"
1036+
mock_cert_class.return_value = mock_cert
1037+
1038+
table = self._sample_table()
1039+
mock_reader = MagicMock()
1040+
mock_reader.read_all.return_value = table
1041+
mock_flight = MagicMock()
1042+
mock_flight.query.return_value = mock_reader
1043+
mock_flight_class.return_value = mock_flight
1044+
1045+
client = Client()
1046+
result = client.query_arrow("SELECT * FROM t")
1047+
1048+
assert result is table
1049+
mock_flight.query.assert_called_once_with("SELECT * FROM t")
1050+
1051+
@patch("spicepy._client._SpiceFlight")
1052+
@patch("spicepy._client._Cert")
1053+
def test_query_arrow_passes_timeout(
1054+
self,
1055+
mock_cert_class: MagicMock,
1056+
mock_flight_class: MagicMock,
1057+
) -> None:
1058+
mock_cert = MagicMock()
1059+
mock_cert.tls_root_certs = b"cert"
1060+
mock_cert_class.return_value = mock_cert
1061+
1062+
mock_reader = MagicMock()
1063+
mock_reader.read_all.return_value = self._sample_table()
1064+
mock_flight = MagicMock()
1065+
mock_flight.query.return_value = mock_reader
1066+
mock_flight_class.return_value = mock_flight
1067+
1068+
client = Client()
1069+
client.query_arrow("SELECT 1", timeout=30)
1070+
1071+
mock_flight.query.assert_called_once_with("SELECT 1", timeout=30)
1072+
1073+
@patch("spicepy._client._SpiceFlight")
1074+
@patch("spicepy._client._Cert")
1075+
@patch("spicepy._client._ADBCClient")
1076+
def test_query_arrow_routes_params_to_adbc(
1077+
self,
1078+
mock_adbc_class: MagicMock,
1079+
mock_cert_class: MagicMock,
1080+
mock_flight_class: MagicMock,
1081+
) -> None:
1082+
mock_cert = MagicMock()
1083+
mock_cert.tls_root_certs = b"cert"
1084+
mock_cert_class.return_value = mock_cert
1085+
1086+
table = self._sample_table()
1087+
adbc_reader = MagicMock()
1088+
adbc_reader.read_all.return_value = table
1089+
mock_adbc = MagicMock()
1090+
mock_adbc.query_with_params.return_value = adbc_reader
1091+
mock_adbc_class.return_value = mock_adbc
1092+
1093+
flight_instance = MagicMock()
1094+
mock_flight_class.return_value = flight_instance
1095+
1096+
client = Client()
1097+
result = client.query_arrow("SELECT * FROM t WHERE id = $1", params=[1])
1098+
1099+
assert result is table
1100+
mock_adbc.query_with_params.assert_called_once_with(
1101+
"SELECT * FROM t WHERE id = $1", [1]
1102+
)
1103+
flight_instance.query.assert_not_called()
1104+
1105+
@patch("spicepy._client._SpiceFlight")
1106+
@patch("spicepy._client._Cert")
1107+
def test_query_pandas_returns_dataframe(
1108+
self,
1109+
mock_cert_class: MagicMock,
1110+
mock_flight_class: MagicMock,
1111+
) -> None:
1112+
import pandas as pd
1113+
1114+
mock_cert = MagicMock()
1115+
mock_cert.tls_root_certs = b"cert"
1116+
mock_cert_class.return_value = mock_cert
1117+
1118+
mock_reader = MagicMock()
1119+
mock_reader.read_all.return_value = self._sample_table()
1120+
mock_flight = MagicMock()
1121+
mock_flight.query.return_value = mock_reader
1122+
mock_flight_class.return_value = mock_flight
1123+
1124+
client = Client()
1125+
df = client.query_pandas("SELECT * FROM t")
1126+
1127+
assert isinstance(df, pd.DataFrame)
1128+
assert list(df.columns) == ["id", "name"]
1129+
assert len(df) == 3
1130+
1131+
@patch("spicepy._client._SpiceFlight")
1132+
@patch("spicepy._client._Cert")
1133+
def test_query_pylist_returns_row_dicts(
1134+
self,
1135+
mock_cert_class: MagicMock,
1136+
mock_flight_class: MagicMock,
1137+
) -> None:
1138+
mock_cert = MagicMock()
1139+
mock_cert.tls_root_certs = b"cert"
1140+
mock_cert_class.return_value = mock_cert
1141+
1142+
mock_reader = MagicMock()
1143+
mock_reader.read_all.return_value = self._sample_table()
1144+
mock_flight = MagicMock()
1145+
mock_flight.query.return_value = mock_reader
1146+
mock_flight_class.return_value = mock_flight
1147+
1148+
client = Client()
1149+
rows = client.query_pylist("SELECT * FROM t")
1150+
1151+
assert rows == [
1152+
{"id": 1, "name": "a"},
1153+
{"id": 2, "name": "b"},
1154+
{"id": 3, "name": "c"},
1155+
]
1156+
1157+
@patch("spicepy._client._SpiceFlight")
1158+
@patch("spicepy._client._Cert")
1159+
def test_query_polars_returns_dataframe(
1160+
self,
1161+
mock_cert_class: MagicMock,
1162+
mock_flight_class: MagicMock,
1163+
) -> None:
1164+
pl = pytest.importorskip("polars")
1165+
1166+
mock_cert = MagicMock()
1167+
mock_cert.tls_root_certs = b"cert"
1168+
mock_cert_class.return_value = mock_cert
1169+
1170+
mock_reader = MagicMock()
1171+
mock_reader.read_all.return_value = self._sample_table()
1172+
mock_flight = MagicMock()
1173+
mock_flight.query.return_value = mock_reader
1174+
mock_flight_class.return_value = mock_flight
1175+
1176+
client = Client()
1177+
df = client.query_polars("SELECT * FROM t")
1178+
1179+
assert isinstance(df, pl.DataFrame)
1180+
assert df.columns == ["id", "name"]
1181+
assert df.height == 3
1182+
1183+
@patch("spicepy._client._SpiceFlight")
1184+
@patch("spicepy._client._Cert")
1185+
def test_query_polars_missing_raises_importerror(
1186+
self,
1187+
mock_cert_class: MagicMock,
1188+
mock_flight_class: MagicMock,
1189+
) -> None:
1190+
mock_cert = MagicMock()
1191+
mock_cert.tls_root_certs = b"cert"
1192+
mock_cert_class.return_value = mock_cert
1193+
1194+
mock_reader = MagicMock()
1195+
mock_reader.read_all.return_value = self._sample_table()
1196+
mock_flight = MagicMock()
1197+
mock_flight.query.return_value = mock_reader
1198+
mock_flight_class.return_value = mock_flight
1199+
1200+
client = Client()
1201+
1202+
import builtins
1203+
1204+
real_import = builtins.__import__
1205+
1206+
def fake_import(name, *args, **kwargs):
1207+
if name == "polars":
1208+
raise ImportError("No module named 'polars'")
1209+
return real_import(name, *args, **kwargs)
1210+
1211+
with patch("builtins.__import__", side_effect=fake_import):
1212+
with pytest.raises(ImportError, match="polars is not installed"):
1213+
client.query_polars("SELECT 1")
1214+
1215+
10201216
class TestIsMacosArm64:
10211217
"""Test is_macos_arm64 function."""
10221218

0 commit comments

Comments
 (0)