Skip to content

Commit f4c3952

Browse files
authored
Merge pull request #35 from UW-Macrostrat/database-updates
Added a function to filter statements
2 parents 3c22804 + a4a8f4b commit f4c3952

File tree

5 files changed

+65
-7
lines changed

5 files changed

+65
-7
lines changed

database/CHANGELOG.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,12 @@
11
# Changelog
22

3+
## [3.5.1] - 2024-12-21
4+
5+
- Add a `statement_filter` parameter to the `run_sql` function to allow for
6+
filtering of statements in a SQL file.
7+
- Improved the consistency of the `Database.run_sql` function with the `run_sql`
8+
utility function.
9+
310
## [3.5.0] - 2024-11-25
411

512
- Add database transfer utilities for asynchronous `pg_load` and `pg_dump`

database/macrostrat/database/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ def run_sql(self, fn, params=None, **kwargs):
128128
Returns: Iterator of results from the query.
129129
"""
130130
params = self._setup_params(params, kwargs)
131-
return iter(run_sql(self.session, fn, params, **kwargs))
131+
return run_sql(self.session, fn, params, **kwargs)
132132

133133
def run_query(self, sql, params=None, **kwargs):
134134
"""Run a single query on the database object, returning the result.

database/macrostrat/database/utils.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,11 @@
55
from typing import IO, Union
66
from warnings import warn
77

8+
import psycopg2.errors
89
from click import echo, secho
910
from psycopg2.extensions import set_wait_callback
1011
from psycopg2.extras import wait_select
1112
from psycopg2.sql import SQL, Composable, Composed
12-
import psycopg2.errors
1313
from rich.console import Console
1414
from sqlalchemy import MetaData, create_engine, text
1515
from sqlalchemy.engine import Connection, Engine
@@ -18,7 +18,7 @@
1818
InternalError,
1919
InvalidRequestError,
2020
ProgrammingError,
21-
OperationalError
21+
OperationalError,
2222
)
2323
from sqlalchemy.orm import sessionmaker
2424
from sqlalchemy.schema import Table
@@ -232,6 +232,9 @@ def infer_has_server_binds(sql):
232232
return "%s" in sql or search(r"%\(\w+\)s", sql)
233233

234234

235+
_default_statement_filter = lambda sql_text, params: True
236+
237+
235238
def _run_sql(connectable, sql, params=None, **kwargs):
236239
"""
237240
Internal function for running a query on a SQLAlchemy connectable,
@@ -247,6 +250,7 @@ def _run_sql(connectable, sql, params=None, **kwargs):
247250
raise_errors = kwargs.pop("raise_errors", False)
248251
has_server_binds = kwargs.pop("has_server_binds", None)
249252
ensure_single_query = kwargs.pop("ensure_single_query", False)
253+
statement_filter = kwargs.pop("statement_filter", _default_statement_filter)
250254

251255
if stop_on_error:
252256
raise_errors = True
@@ -288,6 +292,11 @@ def _run_sql(connectable, sql, params=None, **kwargs):
288292
if has_server_binds is None:
289293
has_server_binds = infer_has_server_binds(sql_text)
290294

295+
should_run = statement_filter(sql_text, params)
296+
if not should_run:
297+
pretty_print(sql_text, dim=True, strikethrough=True)
298+
continue
299+
291300
# This only does something for postgresql, but it's harmless to run it for other engines
292301
set_wait_callback(wait_select)
293302

@@ -325,7 +334,9 @@ def _run_sql(connectable, sql, params=None, **kwargs):
325334

326335
def _should_raise_query_error(err):
327336
"""Determine if an error should be raised for a query or not."""
328-
if not isinstance(err, (ProgrammingError, IntegrityError, InternalError, OperationalError)):
337+
if not isinstance(
338+
err, (ProgrammingError, IntegrityError, InternalError, OperationalError)
339+
):
329340
return True
330341

331342
orig_err = getattr(err, "orig", None)
@@ -336,7 +347,10 @@ def _should_raise_query_error(err):
336347
# We might want to change this behavior in the future, or support more graceful handling of errors from other
337348
# database backends.
338349
# Ideally we could handle operational errors more gracefully
339-
if isinstance(orig_err, psycopg2.errors.QueryCanceled) or getattr(orig_err, "pgcode", None) == "57014":
350+
if (
351+
isinstance(orig_err, psycopg2.errors.QueryCanceled)
352+
or getattr(orig_err, "pgcode", None) == "57014"
353+
):
340354
return True
341355

342356
return False
@@ -444,6 +458,9 @@ def run_sql(*args, **kwargs):
444458
returning a list after completion.
445459
ensure_single_query : bool
446460
If True, raise an error if multiple queries are passed when only one is expected.
461+
statement_filter : Callable
462+
A function that takes a SQL statement and parameters and returns True if the statement
463+
should be run, and False if it should be skipped.
447464
"""
448465
res = _run_sql(*args, **kwargs)
449466
if kwargs.pop("yield_results", False):

database/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ authors = ["Daven Quinn <[email protected]>"]
33
description = "A SQLAlchemy-based database toolkit."
44
name = "macrostrat.database"
55
packages = [{ include = "macrostrat" }]
6-
version = "3.5.0"
6+
version = "3.5.1"
77

88
[tool.poetry.dependencies]
99
GeoAlchemy2 = "^0.15.2"

database/tests/test_database.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from psycopg2.errors import SyntaxError
1212
from psycopg2.extensions import AsIs
1313
from psycopg2.sql import SQL, Identifier, Literal, Placeholder
14-
from pytest import fixture, mark, raises, warns
14+
from pytest import fixture, raises, warns
1515
from sqlalchemy.exc import ProgrammingError
1616
from sqlalchemy.sql import text
1717

@@ -106,6 +106,40 @@ def test_sql_text_inference_6():
106106
assert infer_is_sql_text(insert_sample_query)
107107

108108

109+
def test_sql_statement_filtering(db):
110+
sql = """
111+
INSERT INTO sample (name) VALUES (:name);
112+
113+
DELETE FROM sample WHERE name = :name;
114+
"""
115+
116+
assert infer_is_sql_text(sql)
117+
118+
with db.transaction(rollback="always"):
119+
# Make sure there are no samples
120+
assert _get_sample_count(db) == 0
121+
122+
# Run the SQL, filtering out the DELETE statement
123+
124+
def filter_func(statement, params):
125+
return not statement.startswith("DELETE")
126+
127+
res = db.run_sql(
128+
sql,
129+
params=dict(name="Test"),
130+
raise_errors=True,
131+
statement_filter=filter_func,
132+
yield_results=False,
133+
)
134+
135+
assert len(res) == 1
136+
assert _get_sample_count(db) == 1
137+
138+
139+
def _get_sample_count(db):
140+
return db.run_query("SELECT count(*) FROM sample").scalar()
141+
142+
109143
def test_sql_interpolation_psycopg(db):
110144
db.run_sql(insert_sample_query, params=dict(name="Test"), raise_errors=True)
111145
db.session.commit()

0 commit comments

Comments
 (0)