Skip to content

Commit 1e1ace1

Browse files
authored
Fix: Create a transaction and a session when migrating a snapshot (#4794)
1 parent eeb18fb commit 1e1ace1

File tree

2 files changed

+47
-37
lines changed

2 files changed

+47
-37
lines changed

sqlmesh/core/snapshot/evaluator.py

Lines changed: 33 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -911,39 +911,40 @@ def _migrate_snapshot(
911911
):
912912
return
913913

914+
deployability_index = DeployabilityIndex.all_deployable()
915+
render_kwargs: t.Dict[str, t.Any] = dict(
916+
engine_adapter=adapter,
917+
snapshots=parent_snapshots_by_name(snapshot, snapshots),
918+
runtime_stage=RuntimeStage.CREATING,
919+
deployability_index=deployability_index,
920+
)
914921
target_table_name = snapshot.table_name()
915-
if adapter.table_exists(target_table_name):
916-
evaluation_strategy = _evaluation_strategy(snapshot, adapter)
917-
tmp_table_name = snapshot.table_name(is_deployable=False)
918-
logger.info(
919-
"Migrating table schema from '%s' to '%s'",
920-
tmp_table_name,
921-
target_table_name,
922-
)
923-
evaluation_strategy.migrate(
924-
target_table_name=target_table_name,
925-
source_table_name=tmp_table_name,
926-
snapshot=snapshot,
927-
snapshots=parent_snapshots_by_name(snapshot, snapshots),
928-
allow_destructive_snapshots=allow_destructive_snapshots,
929-
)
930-
else:
931-
logger.info(
932-
"Creating table '%s' for the snapshot of the forward-only model %s",
933-
target_table_name,
934-
snapshot.snapshot_id,
935-
)
936-
deployability_index = DeployabilityIndex.all_deployable()
937-
render_kwargs: t.Dict[str, t.Any] = dict(
938-
engine_adapter=adapter,
939-
snapshots=parent_snapshots_by_name(snapshot, snapshots),
940-
runtime_stage=RuntimeStage.CREATING,
941-
deployability_index=deployability_index,
942-
)
943-
with (
944-
adapter.transaction(),
945-
adapter.session(snapshot.model.render_session_properties(**render_kwargs)),
946-
):
922+
923+
with (
924+
adapter.transaction(),
925+
adapter.session(snapshot.model.render_session_properties(**render_kwargs)),
926+
):
927+
if adapter.table_exists(target_table_name):
928+
evaluation_strategy = _evaluation_strategy(snapshot, adapter)
929+
tmp_table_name = snapshot.table_name(is_deployable=False)
930+
logger.info(
931+
"Migrating table schema from '%s' to '%s'",
932+
tmp_table_name,
933+
target_table_name,
934+
)
935+
evaluation_strategy.migrate(
936+
target_table_name=target_table_name,
937+
source_table_name=tmp_table_name,
938+
snapshot=snapshot,
939+
snapshots=parent_snapshots_by_name(snapshot, snapshots),
940+
allow_destructive_snapshots=allow_destructive_snapshots,
941+
)
942+
else:
943+
logger.info(
944+
"Creating table '%s' for the snapshot of the forward-only model %s",
945+
target_table_name,
946+
snapshot.snapshot_id,
947+
)
947948
self._execute_create(
948949
snapshot=snapshot,
949950
table_name=target_table_name,

tests/core/test_snapshot_evaluator.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@
5454
SnapshotTableCleanupTask,
5555
)
5656
from sqlmesh.core.snapshot.definition import to_view_mapping
57-
from sqlmesh.core.snapshot.evaluator import CustomMaterialization
57+
from sqlmesh.core.snapshot.evaluator import CustomMaterialization, SnapshotCreationFailedError
5858
from sqlmesh.utils.concurrency import NodeExecutionFailedError
5959
from sqlmesh.utils.date import to_timestamp
6060
from sqlmesh.utils.errors import ConfigError, SQLMeshError, DestructiveChangeError
@@ -92,13 +92,16 @@ def date_kwargs() -> t.Dict[str, str]:
9292

9393
@pytest.fixture
9494
def adapter_mock(mocker: MockerFixture):
95+
def mock_exit(self, exc_type, exc_value, traceback):
96+
pass
97+
9598
transaction_mock = mocker.Mock()
9699
transaction_mock.__enter__ = mocker.Mock()
97-
transaction_mock.__exit__ = mocker.Mock()
100+
transaction_mock.__exit__ = mock_exit
98101

99102
session_mock = mocker.Mock()
100103
session_mock.__enter__ = mocker.Mock()
101-
session_mock.__exit__ = mocker.Mock()
104+
session_mock.__exit__ = mock_exit
102105

103106
adapter_mock = mocker.Mock()
104107
adapter_mock.transaction.return_value = transaction_mock
@@ -1160,6 +1163,7 @@ def test_migrate(mocker: MockerFixture, make_snapshot):
11601163
cursor_mock = mocker.Mock()
11611164
connection_mock.cursor.return_value = cursor_mock
11621165
adapter = EngineAdapter(lambda: connection_mock, "")
1166+
session_spy = mocker.spy(adapter, "session")
11631167

11641168
current_table = "sqlmesh__test_schema.test_schema__test_model__1"
11651169

@@ -1201,6 +1205,8 @@ def columns(table_name):
12011205
]
12021206
)
12031207

1208+
session_spy.assert_called_once()
1209+
12041210

12051211
def test_migrate_missing_table(mocker: MockerFixture, make_snapshot):
12061212
connection_mock = mocker.NonCallableMock()
@@ -1596,7 +1602,8 @@ def test_drop_clone_in_dev_when_migration_fails(mocker: MockerFixture, adapter_m
15961602
),
15971603
]
15981604

1599-
evaluator.create([snapshot], {})
1605+
with pytest.raises(SnapshotCreationFailedError):
1606+
evaluator.create([snapshot], {})
16001607

16011608
adapter_mock.clone_table.assert_called_once_with(
16021609
f"sqlmesh__test_schema.test_schema__test_model__{snapshot.version}__dev",
@@ -2537,7 +2544,9 @@ def test_create_seed_on_error(mocker: MockerFixture, adapter_mock, make_snapshot
25372544
snapshot.categorize_as(SnapshotChangeCategory.BREAKING)
25382545

25392546
evaluator = SnapshotEvaluator(adapter_mock)
2540-
evaluator.create([snapshot], {})
2547+
2548+
with pytest.raises(SnapshotCreationFailedError):
2549+
evaluator.create([snapshot], {})
25412550

25422551
adapter_mock.replace_query.assert_called_once_with(
25432552
f"sqlmesh__db.db__seed__{snapshot.version}",

0 commit comments

Comments
 (0)