1 change: 1 addition & 0 deletions pyproject.toml
@@ -501,6 +501,7 @@ split-on-trailing-comma = false
"sqlspec/builder/mixins/**/*.*" = ["SLF001"]
"sqlspec/extensions/adk/converters.py" = ["S403"]
"sqlspec/migrations/utils.py" = ["S404"]
"sqlspec/adapters/spanner/config.py" = ["PLC2801"]
"tests/**/*.*" = [
"A",
"ARG",
4 changes: 3 additions & 1 deletion sqlspec/adapters/spanner/_type_handlers.py
@@ -8,7 +8,7 @@
"""

import base64
from datetime import datetime, timezone
from datetime import date, datetime, timezone
from typing import TYPE_CHECKING, Any
from uuid import UUID

@@ -167,6 +167,8 @@ def infer_spanner_param_types(params: "dict[str, Any] | None") -> "dict[str, Any
types[key] = param_types.BYTES
elif isinstance(value, datetime):
types[key] = param_types.TIMESTAMP
elif isinstance(value, date):
types[key] = param_types.DATE
elif isinstance(value, dict) and hasattr(param_types, "JSON"):
types[key] = param_types.JSON
elif isinstance(value, list):
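The branch order matters here: `datetime` is a subclass of `date`, so the `datetime` check must stay first for timestamps to keep mapping to TIMESTAMP. A minimal sketch of the new behavior, assuming google-cloud-spanner's `param_types` module and that the helper is importable directly from the module shown above:

```python
# Minimal sketch of the new date inference, assuming google-cloud-spanner is
# installed and infer_spanner_param_types is importable as shown.
from datetime import date, datetime, timezone

from google.cloud.spanner_v1 import param_types

from sqlspec.adapters.spanner._type_handlers import infer_spanner_param_types

params = {
    "birthday": date(1990, 1, 1),
    "created_at": datetime(2024, 1, 1, tzinfo=timezone.utc),
}
types = infer_spanner_param_types(params)

# datetime is checked before date, so timestamps still map to TIMESTAMP while
# plain dates now map to DATE instead of being left untyped.
assert types["created_at"] == param_types.TIMESTAMP
assert types["birthday"] == param_types.DATE
```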
35 changes: 26 additions & 9 deletions sqlspec/adapters/spanner/config.py
@@ -178,19 +178,37 @@ def _close_pool(self) -> None:
def provide_connection(
self, *args: Any, transaction: "bool" = False, **kwargs: Any
) -> Generator[SpannerConnection, None, None]:
"""Yield a Snapshot (default) or Batch context from the configured pool.
"""Yield a Snapshot (default) or Transaction context from the configured pool.

Note: Spanner does not support database.transaction() as a context manager.
For write operations requiring conditional logic, use database.run_in_transaction()
directly. The `transaction=True` option here uses database.batch() which is
suitable for simple insert/update/delete mutations.
Args:
*args: Additional positional arguments (unused, for interface compatibility).
transaction: If True, yields a Transaction context that supports
execute_update() for DML statements. If False (default), yields
a read-only Snapshot context for SELECT queries.
**kwargs: Additional keyword arguments (unused, for interface compatibility).

Note: For complex transactional logic with retries, use database.run_in_transaction()
directly. The Transaction context here auto-commits on successful exit.
"""
database = self.get_database()
if transaction:
with cast("Any", database).batch() as batch:
yield cast("SpannerConnection", batch)
session = cast("Any", database).session()
session.create()
try:
txn = session.transaction()
txn.__enter__()
try:
yield cast("SpannerConnection", txn)
if hasattr(txn, "_transaction_id") and txn._transaction_id is not None:
txn.commit()
except Exception:
if hasattr(txn, "_transaction_id") and txn._transaction_id is not None:
txn.rollback()
raise
finally:
session.delete()
else:
with cast("Any", database).snapshot() as snapshot:
with cast("Any", database).snapshot(multi_use=True) as snapshot:
yield cast("SpannerConnection", snapshot)

@contextmanager
@@ -209,7 +227,6 @@ def provide_session(
def provide_write_session(
self, *args: Any, statement_config: "StatementConfig | None" = None, **kwargs: Any
) -> Generator[SpannerSyncDriver, None, None]:
"""Convenience wrapper that always yields a write-capable transaction session."""
with self.provide_session(*args, statement_config=statement_config, transaction=True, **kwargs) as driver:
yield driver

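A hedged usage sketch of the two paths after this change, assuming an already-built `SpannerSyncConfig` named `config` and a hypothetical `users` table; `execute_update()` and `execute_sql()` are the standard Spanner Transaction/Snapshot methods:

```python
# Sketch only: `config` is assumed to be a configured SpannerSyncConfig.
from google.cloud.spanner_v1 import param_types

# transaction=True now yields a Transaction, so DML can go through
# execute_update(); the transaction commits automatically on a clean exit
# and rolls back if the block raises.
with config.provide_connection(transaction=True) as txn:
    txn.execute_update(
        "INSERT INTO users (id, name) VALUES (@id, @name)",
        params={"id": 1, "name": "Ada"},
        param_types={"id": param_types.INT64, "name": param_types.STRING},
    )

# The default path now yields a multi_use snapshot, so several reads can share
# one consistent view instead of the previous single-use snapshot.
with config.provide_connection() as snapshot:
    rows = list(snapshot.execute_sql("SELECT id, name FROM users"))
```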
37 changes: 13 additions & 24 deletions sqlspec/adapters/spanner/driver.py
@@ -204,32 +204,21 @@ def _execute_many(self, cursor: "SpannerConnection", statement: "SQL") -> Execut
raise SQLConversionError(msg)
conn = cast("Any", cursor)

parameter_sets = statement.parameters if isinstance(statement.parameters, list) else []
if not parameter_sets:
sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config)

if not prepared_parameters or not isinstance(prepared_parameters, list):
msg = "execute_many requires at least one parameter set"
raise SQLConversionError(msg)

base_params = parameter_sets[0]
base_statement = self.prepare_statement(
statement.raw_sql, *[base_params], statement_config=statement.statement_config
)
compiled_sql, _ = self._get_compiled_sql(base_statement, self.statement_config)

batch_inputs: list[dict[str, Any]] = []
for params in parameter_sets:
per_statement = self.prepare_statement(
statement.raw_sql, *[params], statement_config=statement.statement_config
)
_, processed_params = self._get_compiled_sql(per_statement, self.statement_config)
coerced_params = self._coerce_params(processed_params)
batch_args: list[tuple[str, dict[str, Any] | None, dict[str, Any]]] = []
for params in prepared_parameters:
coerced_params = self._coerce_params(params)
if coerced_params is None:
coerced_params = {}
batch_inputs.append(coerced_params)

batch_args = [(compiled_sql, p, self._infer_param_types(p)) for p in batch_inputs]
batch_args.append((sql, coerced_params, self._infer_param_types(coerced_params)))

row_counts = conn.batch_update(batch_args)
total_rows = int(sum(int(count) for count in row_counts))
_status, row_counts = conn.batch_update(batch_args)
total_rows = sum(row_counts) if row_counts else 0

return self.create_execution_result(cursor, rowcount_override=total_rows, is_many_result=True)

@@ -350,28 +339,28 @@ def _truncate_table_sync(self, table: str) -> None:


def _build_spanner_profile() -> DriverParameterProfile:
type_coercions: dict[type, Any] = {dict: to_json, list: to_json, tuple: to_json}
type_coercions: dict[type, Any] = {dict: to_json}
return DriverParameterProfile(
name="Spanner",
default_style=ParameterStyle.NAMED_AT,
supported_styles={ParameterStyle.NAMED_AT},
default_execution_style=ParameterStyle.NAMED_AT,
supported_execution_styles={ParameterStyle.NAMED_AT},
has_native_list_expansion=True,
json_serializer_strategy="helper",
json_serializer_strategy="none",
default_dialect="spanner",
preserve_parameter_format=True,
needs_static_script_compilation=False,
allow_mixed_parameter_styles=False,
preserve_original_params_for_many=True,
custom_type_coercions=type_coercions,
extras={"type_coercion_overrides": type_coercions},
extras={},
)


_SPANNER_PROFILE = _build_spanner_profile()
register_driver_profile("spanner", _SPANNER_PROFILE)

spanner_statement_config = build_statement_config_from_profile(
_SPANNER_PROFILE, statement_overrides={"dialect": "spanner"}, json_serializer=to_json
_SPANNER_PROFILE, statement_overrides={"dialect": "spanner"}
)
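For reference, a rough sketch of the batch shape the reworked `_execute_many()` builds: the DML is compiled once, each prepared parameter set becomes one `(sql, params, param_types)` entry, and `Transaction.batch_update()` returns a `(status, row_counts)` pair. The table and values below are made up for illustration:

```python
# Illustration of the batch_update() input/output shape; the table is
# hypothetical, and the real driver derives param_types via _infer_param_types().
from google.cloud.spanner_v1 import param_types

sql = "UPDATE users SET name = @name WHERE id = @id"
parameter_sets = [
    {"id": 1, "name": "Ada"},
    {"id": 2, "name": "Grace"},
]

# One tuple per parameter set, all sharing the single compiled statement.
batch_args = [
    (sql, params, {"id": param_types.INT64, "name": param_types.STRING})
    for params in parameter_sets
]

# Inside a transaction this would run as:
#   _status, row_counts = transaction.batch_update(batch_args)
#   total_rows = sum(row_counts) if row_counts else 0
# which is why the driver now unpacks two return values before summing.
```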
21 changes: 13 additions & 8 deletions sqlspec/extensions/litestar/plugin.py
@@ -71,11 +71,16 @@
CORRELATION_STATE_KEY = "sqlspec_correlation_id"

__all__ = (
"CORRELATION_STATE_KEY",
"DEFAULT_COMMIT_MODE",
"DEFAULT_CONNECTION_KEY",
"DEFAULT_CORRELATION_HEADER",
"DEFAULT_POOL_KEY",
"DEFAULT_SESSION_KEY",
"TRACE_CONTEXT_FALLBACK_HEADERS",
"CommitMode",
"CorrelationMiddleware",
"PluginConfigState",
"SQLSpecPlugin",
)

@@ -117,7 +122,7 @@ def _build_correlation_headers(*, primary: str, configured: list[str], auto_trac
return tuple(_dedupe_headers(header_order))


class _CorrelationMiddleware:
class CorrelationMiddleware:
__slots__ = ("_app", "_headers")

def __init__(self, app: "ASGIApp", *, headers: tuple[str, ...]) -> None:
@@ -153,7 +158,7 @@ async def __call__(self, scope: "Scope", receive: "Receive", send: "Send") -> No


@dataclass
class _PluginConfigState:
class PluginConfigState:
"""Internal state for each database configuration."""

config: "DatabaseConfigProtocol[Any, Any, Any]"
@@ -219,7 +224,7 @@ def __init__(self, sqlspec: SQLSpec, *, loader: "SQLFileLoader | None" = None) -
"""
self._sqlspec = sqlspec

self._plugin_configs: list[_PluginConfigState] = []
self._plugin_configs: list[PluginConfigState] = []
for cfg in self._sqlspec.configs.values():
config_union = cast(
"SyncDatabaseConfig[Any, Any, Any] | NoPoolSyncConfig[Any, Any] | AsyncDatabaseConfig[Any, Any, Any] | NoPoolAsyncConfig[Any, Any]",
@@ -276,9 +281,9 @@ def _create_config_state(
self,
config: "SyncDatabaseConfig[Any, Any, Any] | NoPoolSyncConfig[Any, Any] | AsyncDatabaseConfig[Any, Any, Any] | NoPoolAsyncConfig[Any, Any]",
settings: "dict[str, Any]",
) -> _PluginConfigState:
) -> PluginConfigState:
"""Create plugin state with handlers for the given configuration."""
state = _PluginConfigState(
state = PluginConfigState(
config=config,
connection_key=settings["connection_key"],
pool_key=settings["pool_key"],
@@ -296,7 +301,7 @@ def _create_config_state(
self._setup_handlers(state)
return state

def _setup_handlers(self, state: _PluginConfigState) -> None:
def _setup_handlers(self, state: PluginConfigState) -> None:
def _setup_handlers(self, state: PluginConfigState) -> None:
"""Setup handlers for the plugin state."""
connection_key = state.connection_key
pool_key = state.pool_key
@@ -403,7 +408,7 @@ def store_sqlspec_in_state() -> None:
app_config.type_decoders = decoders_list

if self._correlation_headers:
middleware = DefineMiddleware(_CorrelationMiddleware, headers=self._correlation_headers)
middleware = DefineMiddleware(CorrelationMiddleware, headers=self._correlation_headers)
existing_middleware = list(app_config.middleware or [])
existing_middleware.append(middleware)
app_config.middleware = existing_middleware
@@ -579,7 +584,7 @@ def provide_request_connection(

def _get_plugin_state(
self, key: "str | SyncConfigT | AsyncConfigT | type[SyncConfigT | AsyncConfigT]"
) -> _PluginConfigState:
) -> PluginConfigState:
"""Get plugin state for a configuration by key."""
if isinstance(key, str):
for state in self._plugin_configs:
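With the underscore prefixes dropped and both names listed in `__all__`, downstream code can import them directly; a hypothetical sketch (the `describe_state` helper is invented, not part of sqlspec):

```python
# Hypothetical downstream usage now that the names are public.
from sqlspec.extensions.litestar.plugin import CorrelationMiddleware, PluginConfigState

__all__ = ("CorrelationMiddleware",)  # re-exporting the middleware is now possible


def describe_state(state: PluginConfigState) -> str:
    # PluginConfigState exposes per-config keys such as connection_key and pool_key.
    return f"{state.connection_key} -> {state.pool_key}"
```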
21 changes: 20 additions & 1 deletion tests/integration/test_adapters/test_spanner/conftest.py
@@ -35,7 +35,11 @@ def spanner_database(


@pytest.fixture
def spanner_config(spanner_service: SpannerService, spanner_connection: spanner.Client) -> SpannerSyncConfig:
def spanner_config(
spanner_service: SpannerService, spanner_connection: spanner.Client, spanner_database: "Database"
) -> SpannerSyncConfig:
"""Create SpannerSyncConfig after ensuring database exists."""
_ = spanner_database # Ensure database is created before config
api_endpoint = f"{spanner_service.host}:{spanner_service.port}"

return SpannerSyncConfig(
@@ -53,12 +57,27 @@ def spanner_config(spanner_service: SpannerService, spanner_connection: spanner.

@pytest.fixture
def spanner_session(spanner_config: SpannerSyncConfig) -> Generator[SpannerSyncDriver, None, None]:
"""Read-only session for SELECT operations."""
sql = SQLSpec()
c = sql.add_config(spanner_config)
with sql.provide_session(c) as session:
yield session


@pytest.fixture
def spanner_write_session(spanner_config: SpannerSyncConfig) -> Generator[SpannerSyncDriver, None, None]:
"""Write-capable session for DML operations (INSERT/UPDATE/DELETE)."""
with spanner_config.provide_write_session() as session:
yield session


@pytest.fixture
def spanner_read_session(spanner_config: SpannerSyncConfig) -> Generator[SpannerSyncDriver, None, None]:
"""Read-only session for SELECT operations."""
with spanner_config.provide_session() as session:
yield session


def run_ddl(database: "Database", statements: "list[str]", timeout: int = 300) -> None:
"""Execute DDL statements on Spanner database."""
operation = database.update_ddl(statements) # type: ignore[no-untyped-call]
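A hypothetical test sketch using the split fixtures; the `test_table` schema and the `rows_affected`/`data` result attributes are assumptions about the surrounding suite rather than something defined in this diff:

```python
# Sketch only: test_table and the result attributes are assumed. Note that DML
# through spanner_write_session commits when the fixture's transaction exits.
def test_write_session_runs_dml(spanner_write_session) -> None:
    result = spanner_write_session.execute(
        "INSERT INTO test_table (id, name) VALUES (@id, @name)",
        {"id": 1, "name": "example"},
    )
    assert result.rows_affected == 1


def test_read_session_runs_select(spanner_read_session) -> None:
    result = spanner_read_session.execute("SELECT 1 AS answer")
    assert result.data[0]["answer"] == 1
```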