Skip to content

Commit 4059d6e

Browse files
authored
fix: migration incorrectly remove constraints for unique together deletions with postgresql (#450)
* feat: support --no-input for migrate * fix: postgresql failed to remove unique constraint * docs: update changelog * chore: upgrade deps
1 parent e7d8ab5 commit 4059d6e

17 files changed

+727
-338
lines changed

CHANGELOG.md

+5
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,18 @@
33
## 0.9
44
### [0.9.0]**(Unreleased)**
55

6+
#### Added
7+
- Support `--no-input` for aerich migrate. ([#450])
8+
69
### Changed
710
- Drop support for Python3.8. ([#446])
811

912
#### Fixed
1013
- fix: m2m migrate raises TypeError. ([#448])
1114
- fix: `aerich init-db` process is suspended. ([#435])
15+
- fix: migration will incorrectly remove constraints with index deletions. ([#450])
1216

17+
[#450]: https://github.com/tortoise/aerich/pull/450
1318
[#448]: https://github.com/tortoise/aerich/pull/448
1419
[#446]: https://github.com/tortoise/aerich/pull/446
1520
[#435]: https://github.com/tortoise/aerich/pull/435

aerich/__init__.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -250,8 +250,10 @@ async def inspectdb(self, tables: list[str] | None = None) -> str:
250250
inspect = cls(connection, tables)
251251
return await inspect.inspect()
252252

253-
async def migrate(self, name: str = "update", empty: bool = False) -> str:
254-
return await Migrate.migrate(name, empty)
253+
async def migrate(
254+
self, name: str = "update", empty: bool = False, no_input: bool = False
255+
) -> str:
256+
return await Migrate.migrate(name, empty, no_input)
255257

256258
async def init_db(self, safe: bool) -> None:
257259
location = self.location

aerich/cli.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -89,10 +89,11 @@ async def cli(ctx: Context, config, app) -> None:
8989
@cli.command(help="Generate a migration file for the current state of the models.")
9090
@click.option("--name", default="update", show_default=True, help="Migration name.")
9191
@click.option("--empty", default=False, is_flag=True, help="Generate an empty migration file.")
92+
@click.option("--no-input", default=False, is_flag=True, help="Do not ask for prompt.")
9293
@click.pass_context
93-
async def migrate(ctx: Context, name, empty) -> None:
94+
async def migrate(ctx: Context, name, empty, no_input) -> None:
9495
command = ctx.obj["command"]
95-
ret = await command.migrate(name, empty)
96+
ret = await command.migrate(name, empty, no_input)
9697
if not ret:
9798
return click.secho("No changes detected", fg=Color.yellow)
9899
click.secho(f"Success creating migration file {ret}", fg=Color.green)

aerich/ddl/__init__.py

+7
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ class BaseDDL:
2525
)
2626
_ADD_INDEX_TEMPLATE = 'ALTER TABLE "{table_name}" ADD {index_type}{unique}INDEX "{index_name}" ({column_names}){extra}'
2727
_DROP_INDEX_TEMPLATE = 'ALTER TABLE "{table_name}" DROP INDEX IF EXISTS "{index_name}"'
28+
_DROP_CONSTRAINT_TEMPLATE = 'ALTER TABLE "{table_name}" DROP CONSTRAINT IF EXISTS "{name}"'
2829
_ADD_FK_TEMPLATE = 'ALTER TABLE "{table_name}" ADD CONSTRAINT "{fk_name}" FOREIGN KEY ("{db_column}") REFERENCES "{table}" ("{field}") ON DELETE {on_delete}'
2930
_DROP_FK_TEMPLATE = 'ALTER TABLE "{table_name}" DROP FOREIGN KEY "{fk_name}"'
3031
_M2M_TABLE_TEMPLATE = (
@@ -223,6 +224,12 @@ def drop_index(
223224
def drop_index_by_name(self, model: type[Model], index_name: str) -> str:
224225
return self.drop_index(model, [], name=index_name)
225226

227+
def drop_unique_constraint(self, model: type[Model], name: str) -> str:
228+
return self._DROP_CONSTRAINT_TEMPLATE.format(
229+
table_name=model._meta.db_table,
230+
name=name,
231+
)
232+
226233
def _generate_fk_name(
227234
self, db_table: str, field_describe: dict, reference_table_describe: dict
228235
) -> str:

aerich/migrate.py

+44-6
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import contextlib
44
import importlib
55
import os
6+
import re
67
from collections.abc import Iterable
78
from datetime import datetime
89
from pathlib import Path
@@ -23,7 +24,9 @@
2324
get_app_connection,
2425
get_dict_diff_by_key,
2526
get_models_describe,
27+
import_py_file,
2628
is_default_function,
29+
run_async,
2730
)
2831

2932
MIGRATE_TEMPLATE = """from tortoise import BaseDBAsyncClient
@@ -163,7 +166,7 @@ def _exclude_extra_field_types(cls, diffs) -> list[tuple]:
163166
]
164167

165168
@classmethod
166-
async def migrate(cls, name: str, empty: bool) -> str:
169+
async def migrate(cls, name: str, empty: bool, no_input: bool = False) -> str:
167170
"""
168171
diff old models and new models to generate diff content
169172
:param name: str name for migration
@@ -174,8 +177,8 @@ async def migrate(cls, name: str, empty: bool) -> str:
174177
return await cls._generate_diff_py(name)
175178
new_version_content = get_models_describe(cls.app)
176179
last_version = cast(dict, cls._last_version_content)
177-
cls.diff_models(last_version, new_version_content)
178-
cls.diff_models(new_version_content, last_version, False)
180+
cls.diff_models(last_version, new_version_content, no_input=no_input)
181+
cls.diff_models(new_version_content, last_version, False, no_input=no_input)
179182

180183
cls._merge_operators()
181184

@@ -393,9 +396,31 @@ def _handle_o2o_fields(
393396
key, old_model_describe, new_model_describe, model, old_models, new_models, upgrade
394397
)
395398

399+
@classmethod
400+
def _is_unique_constraint(cls, model: type[Model], index_name: str) -> bool:
401+
if cls.dialect != "postgres":
402+
return False
403+
# For postgresql, if a unique_together was created when generating the table, it is
404+
# a constraint. And if it was created after table generated, it will be a unique index.
405+
migrate_files = cls.get_all_version_files()
406+
if len(migrate_files) < 2:
407+
return True
408+
pattern = re.compile(rf' "?{index_name}"? ')
409+
for filename in reversed(migrate_files[1:]):
410+
module = import_py_file(Path(cls.migrate_location, filename))
411+
upgrade_sql = run_async(module.upgrade, None)
412+
if pattern.search(upgrade_sql):
413+
line = [i for i in upgrade_sql.splitlines() if pattern.search(i)][0]
414+
prefix_words = pattern.split(line)[0].lower().split()
415+
if "drop" in prefix_words:
416+
# The migrate file may be generated by `aerich migrate` without applied by `aerich upgrade`
417+
continue
418+
return "constraint" in prefix_words
419+
return True
420+
396421
@classmethod
397422
def diff_models(
398-
cls, old_models: dict[str, dict], new_models: dict[str, dict], upgrade=True
423+
cls, old_models: dict[str, dict], new_models: dict[str, dict], upgrade=True, no_input=False
399424
) -> None:
400425
"""
401426
diff models and add operators
@@ -467,7 +492,15 @@ def diff_models(
467492
cls._add_operator(cls._add_index(model, index, True), upgrade, True)
468493
# remove unique_together
469494
for index in old_unique_together.difference(new_unique_together):
470-
cls._add_operator(cls._drop_index(model, index, True), upgrade, True)
495+
index_name = cls._unique_index_name(model, index)
496+
if upgrade and cls._is_unique_constraint(model, index_name):
497+
cls._add_operator(
498+
cls.ddl.drop_unique_constraint(model, index_name), upgrade, True
499+
)
500+
else:
501+
cls._add_operator(
502+
cls.ddl.drop_index_by_name(model, index_name), upgrade, True
503+
)
471504
# add indexes
472505
for idx in new_indexes.difference(old_indexes):
473506
cls._add_operator(cls._add_index(model, idx), upgrade, fk_m2m_index=True)
@@ -536,7 +569,7 @@ def diff_models(
536569
# print a empty line to warn that is another model
537570
prefix = "\n" + prefix
538571
models_with_rename_field.add(new_model_str)
539-
is_rename = click.prompt(
572+
is_rename = no_input or click.prompt(
540573
f"{prefix}Rename {old_data_field_name} to {new_data_field_name}?",
541574
default=True,
542575
type=bool,
@@ -757,6 +790,11 @@ def _resolve_fk_fields_name(cls, model: type[Model], fields_name: Iterable[str])
757790
ret.append(field_name)
758791
return ret
759792

793+
@classmethod
794+
def _unique_index_name(cls, model: type[Model], fields_name: Iterable[str]) -> str:
795+
field_names = cls._resolve_fk_fields_name(model, fields_name)
796+
return cls.ddl._index_name(True, model, field_names)
797+
760798
@classmethod
761799
def _drop_index(
762800
cls, model: type[Model], fields_name: Iterable[str] | Index, unique=False

aerich/utils.py

+31-1
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,25 @@
44
import os
55
import re
66
import sys
7-
from collections.abc import Generator
7+
from collections.abc import Awaitable, Callable, Generator
88
from pathlib import Path
99
from types import ModuleType
10+
from typing import TypeVar
1011

12+
from anyio import from_thread
1113
from asyncclick import BadOptionUsage, ClickException, Context
1214
from dictdiffer import diff
1315
from tortoise import BaseDBAsyncClient, Tortoise
1416

17+
if sys.version_info >= (3, 11):
18+
from typing import ParamSpec, TypeVarTuple, Unpack
19+
else:
20+
from typing_extensions import ParamSpec, TypeVarTuple, Unpack
21+
22+
T_Retval = TypeVar("T_Retval")
23+
PosArgsT = TypeVarTuple("PosArgsT")
24+
P = ParamSpec("P")
25+
1526

1627
def add_src_path(path: str) -> str:
1728
"""
@@ -141,3 +152,22 @@ def get_dict_diff_by_key(
141152
if additions:
142153
for index in sorted(additions):
143154
yield from diff([], [new_fields[index]]) # add
155+
156+
157+
def run_async(
158+
async_func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]],
159+
*args: Unpack[PosArgsT],
160+
) -> T_Retval:
161+
"""Run async function in worker thread and get the result of it"""
162+
# `asyncio.run(async_func())` can get the result of async function,
163+
# but it will close the running loop.
164+
result: list[T_Retval] = []
165+
166+
async def runner() -> None:
167+
res = await async_func(*args)
168+
result.append(res)
169+
170+
with from_thread.start_blocking_portal() as portal:
171+
portal.call(runner)
172+
173+
return result[0]

conftest.py

+22-5
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import os
55
import sys
66
from collections.abc import Generator
7+
from contextlib import contextmanager
78
from pathlib import Path
89

910
import pytest
@@ -31,6 +32,7 @@
3132
"models_second": {"models": ["tests.models_second"], "default_connection": "second"},
3233
},
3334
}
35+
TEST_DIR = Path(__file__).parent / "tests"
3436

3537

3638
@pytest.fixture(scope="function", autouse=True)
@@ -70,14 +72,11 @@ async def initialize_tests(event_loop, request) -> None:
7072
request.addfinalizer(lambda: event_loop.run_until_complete(Tortoise._drop_databases()))
7173

7274

73-
@pytest.fixture
74-
def new_aerich_project(tmp_path: Path):
75-
test_dir = Path(__file__).parent / "tests"
76-
asset_dir = test_dir / "assets" / "fake"
75+
@contextmanager
76+
def _new_aerich_project(tmp_path: Path, asset_dir: Path, models_py: Path, test_dir=TEST_DIR):
7777
settings_py = asset_dir / "settings.py"
7878
_tests_py = asset_dir / "_tests.py"
7979
db_py = asset_dir / "db.py"
80-
models_py = test_dir / "models.py"
8180
models_second_py = test_dir / "models_second.py"
8281
copy_files(settings_py, _tests_py, models_py, models_second_py, db_py, target_dir=tmp_path)
8382
dst_dir = tmp_path / "tests"
@@ -95,3 +94,21 @@ def new_aerich_project(tmp_path: Path):
9594
run_shell("python db.py drop", capture_output=False)
9695
if should_remove:
9796
sys.path.remove(str(tmp_path))
97+
98+
99+
@pytest.fixture
100+
def new_aerich_project(tmp_path: Path):
101+
# Create a tortoise project in tmp_path that managed by aerich using assets from tests/assets/fake/
102+
asset_dir = TEST_DIR / "assets" / "fake"
103+
models_py = TEST_DIR / "models.py"
104+
with _new_aerich_project(tmp_path, asset_dir, models_py):
105+
yield
106+
107+
108+
@pytest.fixture
109+
def tmp_aerich_project(tmp_path: Path):
110+
# Create a tortoise project in tmp_path that managed by aerich using assets from tests/assets/remove_constraint/
111+
asset_dir = TEST_DIR / "assets" / "remove_constraint"
112+
models_py = asset_dir / "models.py"
113+
with _new_aerich_project(tmp_path, asset_dir, models_py):
114+
yield

0 commit comments

Comments
 (0)