Skip to content

Commit 890edce

Browse files
authored
fix(spanner): improve write handling in Spanner (#273)
* fix: improve write handling in Spanner * refactor: rename internal classes and update related tests for clarity
1 parent 1c6a2ad commit 890edce

File tree

16 files changed

+1654
-422
lines changed

16 files changed

+1654
-422
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -501,6 +501,7 @@ split-on-trailing-comma = false
501501
"sqlspec/builder/mixins/**/*.*" = ["SLF001"]
502502
"sqlspec/extensions/adk/converters.py" = ["S403"]
503503
"sqlspec/migrations/utils.py" = ["S404"]
504+
"sqlspec/adapters/spanner/config.py" = ["PLC2801"]
504505
"tests/**/*.*" = [
505506
"A",
506507
"ARG",

sqlspec/adapters/spanner/_type_handlers.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
"""
99

1010
import base64
11-
from datetime import datetime, timezone
11+
from datetime import date, datetime, timezone
1212
from typing import TYPE_CHECKING, Any
1313
from uuid import UUID
1414

@@ -167,6 +167,8 @@ def infer_spanner_param_types(params: "dict[str, Any] | None") -> "dict[str, Any
167167
types[key] = param_types.BYTES
168168
elif isinstance(value, datetime):
169169
types[key] = param_types.TIMESTAMP
170+
elif isinstance(value, date):
171+
types[key] = param_types.DATE
170172
elif isinstance(value, dict) and hasattr(param_types, "JSON"):
171173
types[key] = param_types.JSON
172174
elif isinstance(value, list):

sqlspec/adapters/spanner/config.py

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -178,19 +178,37 @@ def _close_pool(self) -> None:
178178
def provide_connection(
179179
self, *args: Any, transaction: "bool" = False, **kwargs: Any
180180
) -> Generator[SpannerConnection, None, None]:
181-
"""Yield a Snapshot (default) or Batch context from the configured pool.
181+
"""Yield a Snapshot (default) or Transaction context from the configured pool.
182182
183-
Note: Spanner does not support database.transaction() as a context manager.
184-
For write operations requiring conditional logic, use database.run_in_transaction()
185-
directly. The `transaction=True` option here uses database.batch() which is
186-
suitable for simple insert/update/delete mutations.
183+
Args:
184+
*args: Additional positional arguments (unused, for interface compatibility).
185+
transaction: If True, yields a Transaction context that supports
186+
execute_update() for DML statements. If False (default), yields
187+
a read-only Snapshot context for SELECT queries.
188+
**kwargs: Additional keyword arguments (unused, for interface compatibility).
189+
190+
Note: For complex transactional logic with retries, use database.run_in_transaction()
191+
directly. The Transaction context here auto-commits on successful exit.
187192
"""
188193
database = self.get_database()
189194
if transaction:
190-
with cast("Any", database).batch() as batch:
191-
yield cast("SpannerConnection", batch)
195+
session = cast("Any", database).session()
196+
session.create()
197+
try:
198+
txn = session.transaction()
199+
txn.__enter__()
200+
try:
201+
yield cast("SpannerConnection", txn)
202+
if hasattr(txn, "_transaction_id") and txn._transaction_id is not None:
203+
txn.commit()
204+
except Exception:
205+
if hasattr(txn, "_transaction_id") and txn._transaction_id is not None:
206+
txn.rollback()
207+
raise
208+
finally:
209+
session.delete()
192210
else:
193-
with cast("Any", database).snapshot() as snapshot:
211+
with cast("Any", database).snapshot(multi_use=True) as snapshot:
194212
yield cast("SpannerConnection", snapshot)
195213

196214
@contextmanager
@@ -209,7 +227,6 @@ def provide_session(
209227
def provide_write_session(
210228
self, *args: Any, statement_config: "StatementConfig | None" = None, **kwargs: Any
211229
) -> Generator[SpannerSyncDriver, None, None]:
212-
"""Convenience wrapper that always yields a write-capable transaction session."""
213230
with self.provide_session(*args, statement_config=statement_config, transaction=True, **kwargs) as driver:
214231
yield driver
215232

sqlspec/adapters/spanner/driver.py

Lines changed: 13 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -204,32 +204,21 @@ def _execute_many(self, cursor: "SpannerConnection", statement: "SQL") -> Execut
204204
raise SQLConversionError(msg)
205205
conn = cast("Any", cursor)
206206

207-
parameter_sets = statement.parameters if isinstance(statement.parameters, list) else []
208-
if not parameter_sets:
207+
sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config)
208+
209+
if not prepared_parameters or not isinstance(prepared_parameters, list):
209210
msg = "execute_many requires at least one parameter set"
210211
raise SQLConversionError(msg)
211212

212-
base_params = parameter_sets[0]
213-
base_statement = self.prepare_statement(
214-
statement.raw_sql, *[base_params], statement_config=statement.statement_config
215-
)
216-
compiled_sql, _ = self._get_compiled_sql(base_statement, self.statement_config)
217-
218-
batch_inputs: list[dict[str, Any]] = []
219-
for params in parameter_sets:
220-
per_statement = self.prepare_statement(
221-
statement.raw_sql, *[params], statement_config=statement.statement_config
222-
)
223-
_, processed_params = self._get_compiled_sql(per_statement, self.statement_config)
224-
coerced_params = self._coerce_params(processed_params)
213+
batch_args: list[tuple[str, dict[str, Any] | None, dict[str, Any]]] = []
214+
for params in prepared_parameters:
215+
coerced_params = self._coerce_params(params)
225216
if coerced_params is None:
226217
coerced_params = {}
227-
batch_inputs.append(coerced_params)
228-
229-
batch_args = [(compiled_sql, p, self._infer_param_types(p)) for p in batch_inputs]
218+
batch_args.append((sql, coerced_params, self._infer_param_types(coerced_params)))
230219

231-
row_counts = conn.batch_update(batch_args)
232-
total_rows = int(sum(int(count) for count in row_counts))
220+
_status, row_counts = conn.batch_update(batch_args)
221+
total_rows = sum(row_counts) if row_counts else 0
233222

234223
return self.create_execution_result(cursor, rowcount_override=total_rows, is_many_result=True)
235224

@@ -350,28 +339,28 @@ def _truncate_table_sync(self, table: str) -> None:
350339

351340

352341
def _build_spanner_profile() -> DriverParameterProfile:
353-
type_coercions: dict[type, Any] = {dict: to_json, list: to_json, tuple: to_json}
342+
type_coercions: dict[type, Any] = {dict: to_json}
354343
return DriverParameterProfile(
355344
name="Spanner",
356345
default_style=ParameterStyle.NAMED_AT,
357346
supported_styles={ParameterStyle.NAMED_AT},
358347
default_execution_style=ParameterStyle.NAMED_AT,
359348
supported_execution_styles={ParameterStyle.NAMED_AT},
360349
has_native_list_expansion=True,
361-
json_serializer_strategy="helper",
350+
json_serializer_strategy="none",
362351
default_dialect="spanner",
363352
preserve_parameter_format=True,
364353
needs_static_script_compilation=False,
365354
allow_mixed_parameter_styles=False,
366355
preserve_original_params_for_many=True,
367356
custom_type_coercions=type_coercions,
368-
extras={"type_coercion_overrides": type_coercions},
357+
extras={},
369358
)
370359

371360

372361
_SPANNER_PROFILE = _build_spanner_profile()
373362
register_driver_profile("spanner", _SPANNER_PROFILE)
374363

375364
spanner_statement_config = build_statement_config_from_profile(
376-
_SPANNER_PROFILE, statement_overrides={"dialect": "spanner"}, json_serializer=to_json
365+
_SPANNER_PROFILE, statement_overrides={"dialect": "spanner"}
377366
)

sqlspec/extensions/litestar/plugin.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -71,11 +71,16 @@
7171
CORRELATION_STATE_KEY = "sqlspec_correlation_id"
7272

7373
__all__ = (
74+
"CORRELATION_STATE_KEY",
7475
"DEFAULT_COMMIT_MODE",
7576
"DEFAULT_CONNECTION_KEY",
77+
"DEFAULT_CORRELATION_HEADER",
7678
"DEFAULT_POOL_KEY",
7779
"DEFAULT_SESSION_KEY",
80+
"TRACE_CONTEXT_FALLBACK_HEADERS",
7881
"CommitMode",
82+
"CorrelationMiddleware",
83+
"PluginConfigState",
7984
"SQLSpecPlugin",
8085
)
8186

@@ -117,7 +122,7 @@ def _build_correlation_headers(*, primary: str, configured: list[str], auto_trac
117122
return tuple(_dedupe_headers(header_order))
118123

119124

120-
class _CorrelationMiddleware:
125+
class CorrelationMiddleware:
121126
__slots__ = ("_app", "_headers")
122127

123128
def __init__(self, app: "ASGIApp", *, headers: tuple[str, ...]) -> None:
@@ -153,7 +158,7 @@ async def __call__(self, scope: "Scope", receive: "Receive", send: "Send") -> No
153158

154159

155160
@dataclass
156-
class _PluginConfigState:
161+
class PluginConfigState:
157162
"""Internal state for each database configuration."""
158163

159164
config: "DatabaseConfigProtocol[Any, Any, Any]"
@@ -219,7 +224,7 @@ def __init__(self, sqlspec: SQLSpec, *, loader: "SQLFileLoader | None" = None) -
219224
"""
220225
self._sqlspec = sqlspec
221226

222-
self._plugin_configs: list[_PluginConfigState] = []
227+
self._plugin_configs: list[PluginConfigState] = []
223228
for cfg in self._sqlspec.configs.values():
224229
config_union = cast(
225230
"SyncDatabaseConfig[Any, Any, Any] | NoPoolSyncConfig[Any, Any] | AsyncDatabaseConfig[Any, Any, Any] | NoPoolAsyncConfig[Any, Any]",
@@ -276,9 +281,9 @@ def _create_config_state(
276281
self,
277282
config: "SyncDatabaseConfig[Any, Any, Any] | NoPoolSyncConfig[Any, Any] | AsyncDatabaseConfig[Any, Any, Any] | NoPoolAsyncConfig[Any, Any]",
278283
settings: "dict[str, Any]",
279-
) -> _PluginConfigState:
284+
) -> PluginConfigState:
280285
"""Create plugin state with handlers for the given configuration."""
281-
state = _PluginConfigState(
286+
state = PluginConfigState(
282287
config=config,
283288
connection_key=settings["connection_key"],
284289
pool_key=settings["pool_key"],
@@ -296,7 +301,7 @@ def _create_config_state(
296301
self._setup_handlers(state)
297302
return state
298303

299-
def _setup_handlers(self, state: _PluginConfigState) -> None:
304+
def _setup_handlers(self, state: PluginConfigState) -> None:
300305
"""Setup handlers for the plugin state."""
301306
connection_key = state.connection_key
302307
pool_key = state.pool_key
@@ -403,7 +408,7 @@ def store_sqlspec_in_state() -> None:
403408
app_config.type_decoders = decoders_list
404409

405410
if self._correlation_headers:
406-
middleware = DefineMiddleware(_CorrelationMiddleware, headers=self._correlation_headers)
411+
middleware = DefineMiddleware(CorrelationMiddleware, headers=self._correlation_headers)
407412
existing_middleware = list(app_config.middleware or [])
408413
existing_middleware.append(middleware)
409414
app_config.middleware = existing_middleware
@@ -579,7 +584,7 @@ def provide_request_connection(
579584

580585
def _get_plugin_state(
581586
self, key: "str | SyncConfigT | AsyncConfigT | type[SyncConfigT | AsyncConfigT]"
582-
) -> _PluginConfigState:
587+
) -> PluginConfigState:
583588
"""Get plugin state for a configuration by key."""
584589
if isinstance(key, str):
585590
for state in self._plugin_configs:

tests/integration/test_adapters/test_spanner/conftest.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,11 @@ def spanner_database(
3535

3636

3737
@pytest.fixture
38-
def spanner_config(spanner_service: SpannerService, spanner_connection: spanner.Client) -> SpannerSyncConfig:
38+
def spanner_config(
39+
spanner_service: SpannerService, spanner_connection: spanner.Client, spanner_database: "Database"
40+
) -> SpannerSyncConfig:
41+
"""Create SpannerSyncConfig after ensuring database exists."""
42+
_ = spanner_database # Ensure database is created before config
3943
api_endpoint = f"{spanner_service.host}:{spanner_service.port}"
4044

4145
return SpannerSyncConfig(
@@ -53,12 +57,27 @@ def spanner_config(spanner_service: SpannerService, spanner_connection: spanner.
5357

5458
@pytest.fixture
5559
def spanner_session(spanner_config: SpannerSyncConfig) -> Generator[SpannerSyncDriver, None, None]:
60+
"""Read-only session for SELECT operations."""
5661
sql = SQLSpec()
5762
c = sql.add_config(spanner_config)
5863
with sql.provide_session(c) as session:
5964
yield session
6065

6166

67+
@pytest.fixture
68+
def spanner_write_session(spanner_config: SpannerSyncConfig) -> Generator[SpannerSyncDriver, None, None]:
69+
"""Write-capable session for DML operations (INSERT/UPDATE/DELETE)."""
70+
with spanner_config.provide_write_session() as session:
71+
yield session
72+
73+
74+
@pytest.fixture
75+
def spanner_read_session(spanner_config: SpannerSyncConfig) -> Generator[SpannerSyncDriver, None, None]:
76+
"""Read-only session for SELECT operations."""
77+
with spanner_config.provide_session() as session:
78+
yield session
79+
80+
6281
def run_ddl(database: "Database", statements: "list[str]", timeout: int = 300) -> None:
6382
"""Execute DDL statements on Spanner database."""
6483
operation = database.update_ddl(statements) # type: ignore[no-untyped-call]

0 commit comments

Comments
 (0)