55from typing import IO , Union
66from warnings import warn
77
8+ import psycopg2 .errors
89from click import echo , secho
910from psycopg2 .extensions import set_wait_callback
1011from psycopg2 .extras import wait_select
1112from psycopg2 .sql import SQL , Composable , Composed
12- import psycopg2 .errors
1313from rich .console import Console
1414from sqlalchemy import MetaData , create_engine , text
1515from sqlalchemy .engine import Connection , Engine
1818 InternalError ,
1919 InvalidRequestError ,
2020 ProgrammingError ,
21- OperationalError
21+ OperationalError ,
2222)
2323from sqlalchemy .orm import sessionmaker
2424from sqlalchemy .schema import Table
@@ -232,6 +232,9 @@ def infer_has_server_binds(sql):
232232 return "%s" in sql or search (r"%\(\w+\)s" , sql )
233233
234234
235+ _default_statement_filter = lambda sql_text , params : True
236+
237+
235238def _run_sql (connectable , sql , params = None , ** kwargs ):
236239 """
237240 Internal function for running a query on a SQLAlchemy connectable,
@@ -247,6 +250,7 @@ def _run_sql(connectable, sql, params=None, **kwargs):
247250 raise_errors = kwargs .pop ("raise_errors" , False )
248251 has_server_binds = kwargs .pop ("has_server_binds" , None )
249252 ensure_single_query = kwargs .pop ("ensure_single_query" , False )
253+ statement_filter = kwargs .pop ("statement_filter" , _default_statement_filter )
250254
251255 if stop_on_error :
252256 raise_errors = True
@@ -288,6 +292,11 @@ def _run_sql(connectable, sql, params=None, **kwargs):
288292 if has_server_binds is None :
289293 has_server_binds = infer_has_server_binds (sql_text )
290294
295+ should_run = statement_filter (sql_text , params )
296+ if not should_run :
297+ pretty_print (sql_text , dim = True , strikethrough = True )
298+ continue
299+
291300 # This only does something for postgresql, but it's harmless to run it for other engines
292301 set_wait_callback (wait_select )
293302
@@ -325,7 +334,9 @@ def _run_sql(connectable, sql, params=None, **kwargs):
325334
326335def _should_raise_query_error (err ):
327336 """Determine if an error should be raised for a query or not."""
328- if not isinstance (err , (ProgrammingError , IntegrityError , InternalError , OperationalError )):
337+ if not isinstance (
338+ err , (ProgrammingError , IntegrityError , InternalError , OperationalError )
339+ ):
329340 return True
330341
331342 orig_err = getattr (err , "orig" , None )
@@ -336,7 +347,10 @@ def _should_raise_query_error(err):
336347 # We might want to change this behavior in the future, or support more graceful handling of errors from other
337348 # database backends.
338349 # Ideally we could handle operational errors more gracefully
339- if isinstance (orig_err , psycopg2 .errors .QueryCanceled ) or getattr (orig_err , "pgcode" , None ) == "57014" :
350+ if (
351+ isinstance (orig_err , psycopg2 .errors .QueryCanceled )
352+ or getattr (orig_err , "pgcode" , None ) == "57014"
353+ ):
340354 return True
341355
342356 return False
@@ -444,6 +458,9 @@ def run_sql(*args, **kwargs):
444458 returning a list after completion.
445459 ensure_single_query : bool
446460 If True, raise an error if multiple queries are passed when only one is expected.
461+ statement_filter : Callable
462+ A function that takes a SQL statement and parameters and returns True if the statement
463+ should be run, and False if it should be skipped.
447464 """
448465 res = _run_sql (* args , ** kwargs )
449466 if kwargs .pop ("yield_results" , False ):
0 commit comments