Skip to content

fix: migration incorrectly remove constraints for unique together deletions with postgresql #450

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,18 @@
## 0.9
### [0.9.0]**(Unreleased)**

#### Added
- Support `--no-input` for aerich migrate. ([#450])

### Changed
- Drop support for Python3.8. ([#446])

#### Fixed
- fix: m2m migrate raises TypeError. ([#448])
- fix: `aerich init-db` process is suspended. ([#435])
- fix: migration will incorrectly remove constraints with index deletions. ([#450])

[#450]: https://github.com/tortoise/aerich/pull/450
[#448]: https://github.com/tortoise/aerich/pull/448
[#446]: https://github.com/tortoise/aerich/pull/446
[#435]: https://github.com/tortoise/aerich/pull/435
Expand Down
6 changes: 4 additions & 2 deletions aerich/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,8 +250,10 @@ async def inspectdb(self, tables: list[str] | None = None) -> str:
inspect = cls(connection, tables)
return await inspect.inspect()

async def migrate(self, name: str = "update", empty: bool = False) -> str:
return await Migrate.migrate(name, empty)
async def migrate(
self, name: str = "update", empty: bool = False, no_input: bool = False
) -> str:
return await Migrate.migrate(name, empty, no_input)

async def init_db(self, safe: bool) -> None:
location = self.location
Expand Down
5 changes: 3 additions & 2 deletions aerich/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,10 +89,11 @@ async def cli(ctx: Context, config, app) -> None:
@cli.command(help="Generate a migration file for the current state of the models.")
@click.option("--name", default="update", show_default=True, help="Migration name.")
@click.option("--empty", default=False, is_flag=True, help="Generate an empty migration file.")
@click.option("--no-input", default=False, is_flag=True, help="Do not ask for prompt.")
@click.pass_context
async def migrate(ctx: Context, name, empty) -> None:
async def migrate(ctx: Context, name, empty, no_input) -> None:
command = ctx.obj["command"]
ret = await command.migrate(name, empty)
ret = await command.migrate(name, empty, no_input)
if not ret:
return click.secho("No changes detected", fg=Color.yellow)
click.secho(f"Success creating migration file {ret}", fg=Color.green)
Expand Down
7 changes: 7 additions & 0 deletions aerich/ddl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ class BaseDDL:
)
_ADD_INDEX_TEMPLATE = 'ALTER TABLE "{table_name}" ADD {index_type}{unique}INDEX "{index_name}" ({column_names}){extra}'
_DROP_INDEX_TEMPLATE = 'ALTER TABLE "{table_name}" DROP INDEX IF EXISTS "{index_name}"'
_DROP_CONSTRAINT_TEMPLATE = 'ALTER TABLE "{table_name}" DROP CONSTRAINT IF EXISTS "{name}"'
_ADD_FK_TEMPLATE = 'ALTER TABLE "{table_name}" ADD CONSTRAINT "{fk_name}" FOREIGN KEY ("{db_column}") REFERENCES "{table}" ("{field}") ON DELETE {on_delete}'
_DROP_FK_TEMPLATE = 'ALTER TABLE "{table_name}" DROP FOREIGN KEY "{fk_name}"'
_M2M_TABLE_TEMPLATE = (
Expand Down Expand Up @@ -223,6 +224,12 @@ def drop_index(
def drop_index_by_name(self, model: type[Model], index_name: str) -> str:
return self.drop_index(model, [], name=index_name)

def drop_unique_constraint(self, model: type[Model], name: str) -> str:
return self._DROP_CONSTRAINT_TEMPLATE.format(
table_name=model._meta.db_table,
name=name,
)

def _generate_fk_name(
self, db_table: str, field_describe: dict, reference_table_describe: dict
) -> str:
Expand Down
50 changes: 44 additions & 6 deletions aerich/migrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import contextlib
import importlib
import os
import re
from collections.abc import Iterable
from datetime import datetime
from pathlib import Path
Expand All @@ -23,7 +24,9 @@
get_app_connection,
get_dict_diff_by_key,
get_models_describe,
import_py_file,
is_default_function,
run_async,
)

MIGRATE_TEMPLATE = """from tortoise import BaseDBAsyncClient
Expand Down Expand Up @@ -163,7 +166,7 @@ def _exclude_extra_field_types(cls, diffs) -> list[tuple]:
]

@classmethod
async def migrate(cls, name: str, empty: bool) -> str:
async def migrate(cls, name: str, empty: bool, no_input: bool = False) -> str:
"""
diff old models and new models to generate diff content
:param name: str name for migration
Expand All @@ -174,8 +177,8 @@ async def migrate(cls, name: str, empty: bool) -> str:
return await cls._generate_diff_py(name)
new_version_content = get_models_describe(cls.app)
last_version = cast(dict, cls._last_version_content)
cls.diff_models(last_version, new_version_content)
cls.diff_models(new_version_content, last_version, False)
cls.diff_models(last_version, new_version_content, no_input=no_input)
cls.diff_models(new_version_content, last_version, False, no_input=no_input)

cls._merge_operators()

Expand Down Expand Up @@ -393,9 +396,31 @@ def _handle_o2o_fields(
key, old_model_describe, new_model_describe, model, old_models, new_models, upgrade
)

@classmethod
def _is_unique_constraint(cls, model: type[Model], index_name: str) -> bool:
if cls.dialect != "postgres":
return False
# For postgresql, if a unique_together was created when generating the table, it is
# a constraint. And if it was created after table generated, it will be a unique index.
migrate_files = cls.get_all_version_files()
if len(migrate_files) < 2:
return True
pattern = re.compile(rf' "?{index_name}"? ')
for filename in reversed(migrate_files[1:]):
module = import_py_file(Path(cls.migrate_location, filename))
upgrade_sql = run_async(module.upgrade, None)
if pattern.search(upgrade_sql):
line = [i for i in upgrade_sql.splitlines() if pattern.search(i)][0]
prefix_words = pattern.split(line)[0].lower().split()
if "drop" in prefix_words:
# The migrate file may be generated by `aerich migrate` without applied by `aerich upgrade`
continue
return "constraint" in prefix_words
return True

@classmethod
def diff_models(
cls, old_models: dict[str, dict], new_models: dict[str, dict], upgrade=True
cls, old_models: dict[str, dict], new_models: dict[str, dict], upgrade=True, no_input=False
) -> None:
"""
diff models and add operators
Expand Down Expand Up @@ -467,7 +492,15 @@ def diff_models(
cls._add_operator(cls._add_index(model, index, True), upgrade, True)
# remove unique_together
for index in old_unique_together.difference(new_unique_together):
cls._add_operator(cls._drop_index(model, index, True), upgrade, True)
index_name = cls._unique_index_name(model, index)
if upgrade and cls._is_unique_constraint(model, index_name):
cls._add_operator(
cls.ddl.drop_unique_constraint(model, index_name), upgrade, True
)
else:
cls._add_operator(
cls.ddl.drop_index_by_name(model, index_name), upgrade, True
)
# add indexes
for idx in new_indexes.difference(old_indexes):
cls._add_operator(cls._add_index(model, idx), upgrade, fk_m2m_index=True)
Expand Down Expand Up @@ -536,7 +569,7 @@ def diff_models(
# print a empty line to warn that is another model
prefix = "\n" + prefix
models_with_rename_field.add(new_model_str)
is_rename = click.prompt(
is_rename = no_input or click.prompt(
f"{prefix}Rename {old_data_field_name} to {new_data_field_name}?",
default=True,
type=bool,
Expand Down Expand Up @@ -757,6 +790,11 @@ def _resolve_fk_fields_name(cls, model: type[Model], fields_name: Iterable[str])
ret.append(field_name)
return ret

@classmethod
def _unique_index_name(cls, model: type[Model], fields_name: Iterable[str]) -> str:
field_names = cls._resolve_fk_fields_name(model, fields_name)
return cls.ddl._index_name(True, model, field_names)

@classmethod
def _drop_index(
cls, model: type[Model], fields_name: Iterable[str] | Index, unique=False
Expand Down
32 changes: 31 additions & 1 deletion aerich/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,25 @@
import os
import re
import sys
from collections.abc import Generator
from collections.abc import Awaitable, Callable, Generator
from pathlib import Path
from types import ModuleType
from typing import TypeVar

from anyio import from_thread
from asyncclick import BadOptionUsage, ClickException, Context
from dictdiffer import diff
from tortoise import BaseDBAsyncClient, Tortoise

if sys.version_info >= (3, 11):
from typing import ParamSpec, TypeVarTuple, Unpack
else:
from typing_extensions import ParamSpec, TypeVarTuple, Unpack

T_Retval = TypeVar("T_Retval")
PosArgsT = TypeVarTuple("PosArgsT")
P = ParamSpec("P")


def add_src_path(path: str) -> str:
"""
Expand Down Expand Up @@ -141,3 +152,22 @@ def get_dict_diff_by_key(
if additions:
for index in sorted(additions):
yield from diff([], [new_fields[index]]) # add


def run_async(
async_func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]],
*args: Unpack[PosArgsT],
) -> T_Retval:
"""Run async function in worker thread and get the result of it"""
# `asyncio.run(async_func())` can get the result of async function,
# but it will close the running loop.
result: list[T_Retval] = []

async def runner() -> None:
res = await async_func(*args)
result.append(res)

with from_thread.start_blocking_portal() as portal:
portal.call(runner)

return result[0]
27 changes: 22 additions & 5 deletions conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import os
import sys
from collections.abc import Generator
from contextlib import contextmanager
from pathlib import Path

import pytest
Expand Down Expand Up @@ -31,6 +32,7 @@
"models_second": {"models": ["tests.models_second"], "default_connection": "second"},
},
}
TEST_DIR = Path(__file__).parent / "tests"


@pytest.fixture(scope="function", autouse=True)
Expand Down Expand Up @@ -70,14 +72,11 @@ async def initialize_tests(event_loop, request) -> None:
request.addfinalizer(lambda: event_loop.run_until_complete(Tortoise._drop_databases()))


@pytest.fixture
def new_aerich_project(tmp_path: Path):
test_dir = Path(__file__).parent / "tests"
asset_dir = test_dir / "assets" / "fake"
@contextmanager
def _new_aerich_project(tmp_path: Path, asset_dir: Path, models_py: Path, test_dir=TEST_DIR):
settings_py = asset_dir / "settings.py"
_tests_py = asset_dir / "_tests.py"
db_py = asset_dir / "db.py"
models_py = test_dir / "models.py"
models_second_py = test_dir / "models_second.py"
copy_files(settings_py, _tests_py, models_py, models_second_py, db_py, target_dir=tmp_path)
dst_dir = tmp_path / "tests"
Expand All @@ -95,3 +94,21 @@ def new_aerich_project(tmp_path: Path):
run_shell("python db.py drop", capture_output=False)
if should_remove:
sys.path.remove(str(tmp_path))


@pytest.fixture
def new_aerich_project(tmp_path: Path):
# Create a tortoise project in tmp_path that managed by aerich using assets from tests/assets/fake/
asset_dir = TEST_DIR / "assets" / "fake"
models_py = TEST_DIR / "models.py"
with _new_aerich_project(tmp_path, asset_dir, models_py):
yield


@pytest.fixture
def tmp_aerich_project(tmp_path: Path):
# Create a tortoise project in tmp_path that managed by aerich using assets from tests/assets/remove_constraint/
asset_dir = TEST_DIR / "assets" / "remove_constraint"
models_py = asset_dir / "models.py"
with _new_aerich_project(tmp_path, asset_dir, models_py):
yield
Loading