77
88from django .conf import settings
99from django .core .management .base import BaseCommand , CommandError
10- from psycopg2 import ProgrammingError , connect , sql
10+ from psycopg2 import OperationalError , ProgrammingError , connect , sql
1111
1212DEFAULT_DATABASE = settings .DATABASES ["default" ]
1313DB_HOST = DEFAULT_DATABASE .get ("HOST" , "localhost" )
@@ -52,7 +52,14 @@ def handle(self, *args, **options):
5252
5353 temp_db = f"temp_{ DB_NAME } "
5454 try :
55- self ._execute_sql ("postgres" , [f"CREATE DATABASE { temp_db } TEMPLATE { DB_NAME } ;" ])
55+ self ._execute_sql (
56+ "postgres" ,
57+ [
58+ sql .SQL ("CREATE DATABASE {temp_db} TEMPLATE {DB_NAME};" ).format (
59+ temp_db = sql .Identifier (temp_db ), DB_NAME = sql .Identifier (DB_NAME )
60+ )
61+ ],
62+ )
5663
5764 self .stdout .write (self .style .SUCCESS (f"Created temporary DB: { temp_db } " ))
5865
@@ -86,20 +93,27 @@ def handle(self, *args, **options):
8693 raise CommandError (message ) from e
8794 finally :
8895 try :
89- self ._execute_sql ("postgres" , [f"DROP DATABASE IF EXISTS { temp_db } ;" ])
90- except CalledProcessError :
96+ self ._execute_sql (
97+ "postgres" ,
98+ [
99+ sql .SQL ("DROP DATABASE IF EXISTS {temp_db};" ).format (
100+ temp_db = sql .Identifier (temp_db )
101+ )
102+ ],
103+ )
104+ except (ProgrammingError , OperationalError ):
91105 self .stderr .write (
92106 self .style .WARNING (f"Failed to drop temp DB { temp_db } (ignored)." )
93107 )
94108
95- def _table_list_query (self ) -> str :
96- return """
109+ def _table_list_query (self ) -> sql . Composable :
110+ return sql . SQL ( """
97111 SELECT table_name
98112 FROM information_schema.columns
99113 WHERE table_schema = 'public' AND column_name = 'email';
100- """
114+ """ )
101115
102- def _remove_emails (self , tables : list [str ]) -> list [str ]:
116+ def _remove_emails (self , tables : list [str ]) -> list [sql . Composable ]:
103117 return [
104118 sql .SQL ("UPDATE {table} SET email = '';" ).format (table = sql .Identifier (table ))
105119 for table in tables
@@ -108,7 +122,7 @@ def _remove_emails(self, tables: list[str]) -> list[str]:
108122 def _execute_sql (
109123 self ,
110124 dbname : str ,
111- sql_queries : list [str ],
125+ sql_queries : list [sql . Composable ],
112126 ):
113127 connection = connect (
114128 dbname = dbname ,
0 commit comments