Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 18 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,15 @@ test = [
"pyarrow>=23.0.1",
"adbc-driver-flightsql>=1.11.0",
"adbc-driver-manager>=1.11.0",
"polars>=1.0.0",
]
params = [
"adbc-driver-flightsql>=1.11.0",
"adbc-driver-manager>=1.11.0",
]
polars = [
"polars>=1.0.0",
]

# ============== Tool Configuration ==============

Expand Down Expand Up @@ -93,6 +97,8 @@ module = [
"adbc_driver_manager.*",
"certifi",
"pandas",
"polars",
"polars.*",
"_pytest.*",
"pytest.*",
]
Expand Down Expand Up @@ -185,6 +191,14 @@ ignore = [
"PLR2004", # Magic values ok in tests
"ARG", # Unused arguments ok in tests (fixtures)
"T20", # Print statements ok in tests
"N812", # `functions as F` is the DataFrame convention
"S608", # asserting SQL fragments is the point of these tests
]
"spicepy/_client.py" = [
"S608", # SQL is constructed from escaped identifiers/literals; lint can't tell
]
"spicepy/_dataframe.py" = [
"S608", # The DataFrame layer's entire job is composing SQL; identifiers and literals are escaped
]

[tool.ruff.lint.isort]
Expand Down Expand Up @@ -246,7 +260,10 @@ directory = "htmlcov"

[tool.bandit]
exclude_dirs = ["tests", ".venv"]
skips = ["B101"] # assert_used
skips = [
"B101", # assert_used
"B608", # hardcoded_sql_expressions: composing SQL with escaped identifiers/literals is this package's job
]

[dependency-groups]
dev = [
Expand Down
14 changes: 14 additions & 0 deletions spicepy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,19 @@
"""

# flake8: noqa
from . import functions
from ._client import Client
from ._dataframe import SpiceDataFrame
from ._expr import Expr, case, col, lit
from ._http import RefreshOpts

# Names exported by ``from spicepy import *``; every entry is bound by the
# imports above. Entries are kept in ASCII sort order (capitalized first).
__all__ = [
"Client",
"Expr",
"RefreshOpts",
"SpiceDataFrame",
"case",
"col",
"functions",
"lit",
]
282 changes: 281 additions & 1 deletion spicepy/_client.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,20 @@
from collections.abc import Iterator
import json
import os
from pathlib import Path
import platform
import threading
from typing import Any
from typing import TYPE_CHECKING, Any, cast

import certifi
import pyarrow as pa

if TYPE_CHECKING:
import pandas as pd
import polars as pl

from ._dataframe import SpiceDataFrame

# pylint: disable=E0611
from pyarrow._flight import (
FlightCallOptions,
Expand Down Expand Up @@ -409,6 +416,279 @@ def query_with_params(
adbc = self._ensure_adbc_client()
return adbc.query_with_params(sql, params)

def _read_table(
    self,
    sql: str,
    params: list[Any] | None,
    timeout: int | None,
) -> pa.Table:
    """Run ``sql`` and materialize the complete result as an Arrow table.

    Args:
        sql: SQL text to execute.
        params: Positional parameter values. Any value other than ``None``
            (including an empty list) routes the query through
            :meth:`query_with_params`.
        timeout: Forwarded to :meth:`query` as ``timeout=``. It is only
            consulted when ``params`` is ``None``.

    Returns:
        Whatever ``read_all()`` returns on the reader produced by the
        selected query method.
    """
    if params is None:
        # Pass ``timeout`` only when the caller supplied one so that
        # ``query`` keeps its own default otherwise.
        options: dict[str, Any] = {} if timeout is None else {"timeout": timeout}
        reader = self.query(sql, **options)
    else:
        reader = self.query_with_params(sql, params)
    return reader.read_all()

def query_arrow(
    self,
    sql: str,
    *,
    params: list[Any] | None = None,
    timeout: int | None = None,
) -> pa.Table:
    """Run a SQL query and hand back the result as a PyArrow Table.

    Args:
        sql: SQL query string. Use ``$1``, ``$2``, ... placeholders when
            passing ``params``.
        params: Optional list of parameter values. When given, the query
            runs via ADBC FlightSQL with prepared statements; see
            :meth:`query_with_params` for the parameter format.
        timeout: Optional query timeout in seconds. Ignored whenever
            ``params`` is set.

    Returns:
        An Arrow Table holding every result row, fully materialized in
        memory.
    """
    arrow_table = self._read_table(sql, params, timeout)
    return arrow_table

def query_pandas(
self,
sql: str,
*,
params: list[Any] | None = None,
timeout: int | None = None,
) -> "pd.DataFrame":
"""Execute a SQL query and return results as a pandas DataFrame.

See :meth:`query_arrow` for argument semantics.
"""
return cast("pd.DataFrame", self._read_table(sql, params, timeout).to_pandas())

def query_polars(
self,
sql: str,
*,
params: list[Any] | None = None,
timeout: int | None = None,
) -> "pl.DataFrame":
"""Execute a SQL query and return results as a polars DataFrame.

Requires the optional ``polars`` dependency:
``pip install spicepy[polars]``.

See :meth:`query_arrow` for argument semantics.
"""
try:
import polars as pl
except ImportError as exc:
raise ImportError(
"polars is not installed. Install it with: pip install spicepy[polars]"
) from exc
return cast(
"pl.DataFrame", pl.from_arrow(self._read_table(sql, params, timeout))
)

def query_pylist(
self,
sql: str,
*,
params: list[Any] | None = None,
timeout: int | None = None,
) -> list[dict[str, Any]]:
"""Execute a SQL query and return results as a list of row dicts.

See :meth:`query_arrow` for argument semantics.
"""
return cast(
"list[dict[str, Any]]",
self._read_table(sql, params, timeout).to_pylist(),
)

def query_pydict(
self,
sql: str,
*,
params: list[Any] | None = None,
timeout: int | None = None,
) -> dict[str, list[Any]]:
"""Execute a SQL query and return results as a column-oriented dict."""
return cast(
"dict[str, list[Any]]",
self._read_table(sql, params, timeout).to_pydict(),
)

def query_batches(
    self,
    sql: str,
    *,
    timeout: int | None = None,
) -> Iterator[pa.RecordBatch]:
    """Run a SQL query and lazily yield its Arrow RecordBatches.

    In contrast to :meth:`query_arrow`, the full result is never collected
    in memory first; batches are yielded one by one while the reader is
    iterated. Because this is a generator, the query itself is only issued
    when iteration starts.

    Args:
        sql: SQL query string.
        timeout: Optional value forwarded to :meth:`query` as ``timeout=``;
            omitted from the call when ``None``.

    Yields:
        Each non-``None`` batch produced by the reader.
    """
    options: dict[str, Any] = {} if timeout is None else {"timeout": timeout}
    for item in self.query(sql, **options):
        # A FlightStreamReader yields FlightStreamChunk objects that carry
        # the batch in ``.data``; a pa.RecordBatchReader yields the
        # RecordBatch itself.
        payload = getattr(item, "data", item)
        if payload is None:
            continue
        yield payload

# ------------------------------------------------------------------
# Catalog introspection
# ------------------------------------------------------------------

def catalogs(self) -> list[str]:
    """Return the distinct catalog names from ``information_schema.schemata``.

    Names come back in the order produced by the query's
    ``ORDER BY catalog_name``.
    """
    statement = (
        "SELECT DISTINCT catalog_name FROM information_schema.schemata "
        "ORDER BY catalog_name"
    )
    return [row["catalog_name"] for row in self.query_pylist(statement)]

def schemas(self, catalog: str | None = None) -> list[str]:
"""List schema names, optionally restricted to a catalog."""
if catalog is None:
rows = self.query_pylist(
"SELECT schema_name FROM information_schema.schemata "
"ORDER BY schema_name"
)
else:
from ._sql import quote_literal

rows = self.query_pylist(
"SELECT schema_name FROM information_schema.schemata "
f"WHERE catalog_name = {quote_literal(catalog)} "
"ORDER BY schema_name"
)
return [r["schema_name"] for r in rows]

def tables(self, schema: str | None = None) -> list[str]:
"""List table names, optionally restricted to a schema."""
if schema is None:
rows = self.query_pylist(
"SELECT table_name FROM information_schema.tables "
"ORDER BY table_schema, table_name"
)
else:
from ._sql import quote_literal

rows = self.query_pylist(
"SELECT table_name FROM information_schema.tables "
f"WHERE table_schema = {quote_literal(schema)} "
"ORDER BY table_name"
)
return [r["table_name"] for r in rows]

def describe(self, table: str) -> "pd.DataFrame":
    """Run ``DESCRIBE`` on ``table`` and return the result as a pandas DataFrame.

    ``table`` is passed through ``quote_ident`` before being placed in the
    statement.
    """
    from ._sql import quote_ident

    statement = f"DESCRIBE {quote_ident(table)}"
    return self.query_pandas(statement)

def get_schema(self, sql: str) -> pa.Schema:
    """Return the Arrow schema of a query without materializing rows.

    The query is wrapped as ``SELECT * FROM (<sql>) LIMIT 0`` so that only
    the schema travels back.

    Args:
        sql: The query to inspect. Trailing semicolons and whitespace are
            removed before wrapping.

    Returns:
        The ``schema`` of the (empty) table read from the wrapped query.
    """
    # A statement terminator left inside the parentheses ("(SELECT 1;)")
    # would make the wrapped statement invalid, so drop it first.
    inner = sql.rstrip("; \t\r\n")
    return self.query(f"SELECT * FROM ({inner}) LIMIT 0").read_all().schema

def explain(
    self,
    sql: str,
    *,
    analyze: bool = False,
    verbose: bool = False,
) -> str:
    """Run ``EXPLAIN`` on ``sql`` and return the plan as text.

    Args:
        sql: The statement to explain.
        analyze: Add the ``ANALYZE`` keyword.
        verbose: Add the ``VERBOSE`` keyword (after ``ANALYZE`` when both
            are set).

    Returns:
        One line per result row, joined with newlines. Each line is the
        row's ``plan`` value, falling back to ``plan_type`` and finally to
        ``str(row)`` when the earlier values are missing or falsy.
    """
    keywords = ["EXPLAIN"]
    if analyze:
        keywords.append("ANALYZE")
    if verbose:
        keywords.append("VERBOSE")
    prefix = " ".join(keywords)

    lines = [
        row.get("plan") or row.get("plan_type") or str(row)
        for row in self.query_pylist(f"{prefix} {sql}")
    ]
    return "\n".join(lines)

def show(self, sql: str, n: int = 20) -> None:
    """Pretty-print the first ``n`` rows of ``sql``.

    The query is wrapped as ``SELECT * FROM (<sql>) LIMIT <n>``, fetched as
    a pandas DataFrame, and printed without the index column.

    Args:
        sql: The query to preview. Trailing semicolons and whitespace are
            removed before wrapping.
        n: Maximum number of rows to print. Coerced with ``int()`` so only
            an integer literal reaches the SQL text.
    """
    # A statement terminator left inside the parentheses ("(SELECT 1;)")
    # would make the wrapped statement invalid, so drop it first.
    inner = sql.rstrip("; \t\r\n")
    df = self.query_pandas(f"SELECT * FROM ({inner}) LIMIT {int(n)}")
    print(df.to_string(index=False))  # noqa: T201

# ------------------------------------------------------------------
# Writers
# ------------------------------------------------------------------

def write_parquet(self, sql: str, path: str, **kwargs: Any) -> None:
    """Stream the result of ``sql`` into a Parquet file at ``path``.

    Batches are written as they are read from the query reader, so the
    full result is not collected first.

    Args:
        sql: Query whose result is written.
        path: Destination passed to ``pyarrow.parquet.ParquetWriter``.
        **kwargs: Extra keyword arguments forwarded to ``ParquetWriter``.
    """
    import pyarrow.parquet as pq

    stream = self.query(sql)
    with pq.ParquetWriter(path, stream.schema, **kwargs) as sink:
        for item in stream:
            # Flight readers wrap each batch in a chunk exposing ``.data``;
            # other readers yield the batch directly.
            payload = getattr(item, "data", item)
            if payload is None:
                continue
            sink.write_batch(payload)

def write_csv(self, sql: str, path: str, **kwargs: Any) -> None:
    """Stream the result of ``sql`` into a CSV file at ``path``.

    Batches are written as they are read from the query reader, so the
    full result is not collected first.

    Args:
        sql: Query whose result is written.
        path: Destination passed to ``pyarrow.csv.CSVWriter``.
        **kwargs: Extra keyword arguments forwarded to ``CSVWriter``.
    """
    import pyarrow.csv as pa_csv

    stream = self.query(sql)
    with pa_csv.CSVWriter(path, stream.schema, **kwargs) as sink:
        for item in stream:
            # Flight readers wrap each batch in a chunk exposing ``.data``;
            # other readers yield the batch directly.
            payload = getattr(item, "data", item)
            if payload is None:
                continue
            sink.write_batch(payload)

def write_json(self, sql: str, path: str) -> None:
    """Write the result of ``sql`` to ``path`` as newline-delimited JSON.

    Each result row becomes one JSON object on its own line, encoded as
    UTF-8. Values that ``json`` cannot serialize natively are converted
    with ``str`` (``default=str``).

    Args:
        sql: Query whose result is written.
        path: File to create or overwrite.
    """
    table = self.query_arrow(sql)
    with open(path, "w", encoding="utf-8") as fh:
        # Convert one record batch at a time so only a single batch's worth
        # of Python dicts exists at once, instead of a list of dicts for
        # the entire result.
        for batch in table.to_batches():
            fh.writelines(
                json.dumps(row, default=str) + "\n" for row in batch.to_pylist()
            )

# ------------------------------------------------------------------
# DataFrame entry points
# ------------------------------------------------------------------

def table(self, name: str) -> "SpiceDataFrame":
    """Build a lazy DataFrame that selects every column of table ``name``.

    ``name`` is passed through ``quote_ident`` before being placed in the
    ``SELECT * FROM ...`` statement that backs the DataFrame.
    """
    from ._dataframe import SpiceDataFrame
    from ._sql import quote_ident

    statement = f"SELECT * FROM {quote_ident(name)}"
    return SpiceDataFrame(self, statement)

def sql(self, query: str) -> "SpiceDataFrame":
    """Wrap an arbitrary SQL ``query`` in a lazy DataFrame.

    The query text is handed to ``SpiceDataFrame`` unchanged, together with
    this client.
    """
    from ._dataframe import SpiceDataFrame

    frame = SpiceDataFrame(self, query)
    return frame

def from_arrow(self, table: pa.Table) -> "SpiceDataFrame":
    """Build a lazy DataFrame whose rows come from an inline VALUES clause.

    The table is converted to a list of row dicts and handed to
    ``values_dataframe``. Meant for small literal tables; larger data
    should be loaded server-side as a dataset.
    """
    from ._dataframe import values_dataframe

    rows = table.to_pylist()
    return values_dataframe(self, rows)

def from_pandas(self, df: "pd.DataFrame") -> "SpiceDataFrame":
    """Build a lazy DataFrame from a pandas DataFrame via inline VALUES.

    The frame is converted with ``pa.Table.from_pandas`` and then delegated
    to :meth:`from_arrow`.
    """
    import pyarrow as pa

    arrow_table = pa.Table.from_pandas(df)
    return self.from_arrow(arrow_table)

def from_pydict(self, data: dict[str, list[Any]]) -> "SpiceDataFrame":
    """Build a lazy DataFrame from a column-oriented dict via inline VALUES.

    The dict is converted with ``pa.Table.from_pydict`` and then delegated
    to :meth:`from_arrow`.
    """
    import pyarrow as pa

    arrow_table = pa.Table.from_pydict(data)
    return self.from_arrow(arrow_table)

def refresh_dataset(
self, dataset: str, refresh_opts: RefreshOpts | None = None
) -> Any:
Expand Down
Loading
Loading