99
1010import click
1111import structlog
12+ from alembic .config import Config
13+ from alembic .runtime .migration import MigrationContext
14+ from alembic .script import ScriptDirectory
1215from safir .database import (
1316 create_database_engine ,
1417 is_database_current ,
1518 stamp_database ,
1619)
1720from safir .logging import configure_logging
21+ from sqlalchemy import Connection
22+ from sqlalchemy .ext .asyncio import AsyncEngine
1823
1924from .config import config
2025from .database import init_database
@@ -85,12 +90,39 @@ def init(*, alembic_config_path: Path, reset: bool) -> None:
8590)
8691def update_db_schema (* , alembic_config_path : Path ) -> None :
8792 """Update the SQL database schema."""
93+ logger = structlog .get_logger ("docverse" )
94+
95+ engine = create_database_engine (
96+ config .database_url , config .database_password
97+ )
98+
99+ # Get current and target revisions
100+ alembic_config = Config (str (alembic_config_path ))
101+ alembic_scripts = ScriptDirectory .from_config (alembic_config )
102+ head_rev = alembic_scripts .get_current_head ()
103+ current_rev = asyncio .run (_get_current_revision (engine ))
104+
105+ logger .info (
106+ "Starting database schema update" ,
107+ current_revision = current_rev ,
108+ target_revision = head_rev ,
109+ )
110+
88111 subprocess .run (
89112 ["alembic" , "upgrade" , "head" ],
90113 check = True ,
91114 cwd = str (alembic_config_path .parent ),
92115 )
93116
117+ new_rev = asyncio .run (_get_current_revision (engine ))
118+ asyncio .run (engine .dispose ())
119+
120+ logger .info (
121+ "Database schema update complete" ,
122+ previous_revision = current_rev ,
123+ current_revision = new_rev ,
124+ )
125+
94126
95127@main .command ()
96128@click .option (
@@ -110,3 +142,18 @@ def validate_db_schema(*, alembic_config_path: Path) -> None:
110142 ):
111143 msg = "Database schema is not current"
112144 raise click .ClickException (msg )
145+
146+
147+ async def _get_current_revision (engine : AsyncEngine ) -> str | None :
148+ """Get the current Alembic revision from the database."""
149+
150+ def _get_heads (connection : Connection ) -> set [str ]:
151+ context = MigrationContext .configure (connection )
152+ return set (context .get_current_heads ())
153+
154+ async with engine .begin () as connection :
155+ heads = await connection .run_sync (_get_heads )
156+ # Return single revision or comma-joined if multiple heads
157+ if not heads :
158+ return None
159+ return "," .join (sorted (heads ))
0 commit comments