Skip to content

feat: aerich.Command support async with syntax #427

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Feb 27, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 14 additions & 14 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -226,14 +226,14 @@ from tortoise import Model, fields


class Test(Model):
date = fields.DateField(null=True, )
datetime = fields.DatetimeField(auto_now=True, )
decimal = fields.DecimalField(max_digits=10, decimal_places=2, )
float = fields.FloatField(null=True, )
id = fields.IntField(pk=True, )
string = fields.CharField(max_length=200, null=True, )
time = fields.TimeField(null=True, )
tinyint = fields.BooleanField(null=True, )
date = fields.DateField(null=True)
datetime = fields.DatetimeField(auto_now=True)
decimal = fields.DecimalField(max_digits=10, decimal_places=2)
float = fields.FloatField(null=True)
id = fields.IntField(primary_key=True)
string = fields.CharField(max_length=200, null=True)
time = fields.TimeField(null=True)
tinyint = fields.BooleanField(null=True)
```

Note that this command is limited and can't infer some fields, such as `IntEnumField`, `ForeignKeyField`, and others.
Expand All @@ -243,8 +243,8 @@ Note that this command is limited and can't infer some fields, such as `IntEnumF
```python
tortoise_orm = {
"connections": {
"default": expand_db_url(db_url, True),
"second": expand_db_url(db_url_second, True),
"default": "postgres://postgres_user:[email protected]:5432/db1",
"second": "postgres://postgres_user:[email protected]:5432/db2",
},
"apps": {
"models": {"models": ["tests.models", "aerich.models"], "default_connection": "default"},
Expand All @@ -253,7 +253,7 @@ tortoise_orm = {
}
```

You only need to specify `aerich.models` in one app, and must specify `--app` when running `aerich migrate` and so on.
You only need to specify `aerich.models` in one app, and must specify `--app` when running `aerich migrate` and so on, e.g. `aerich --app models_second migrate`.

## Restore `aerich` workflow

Expand All @@ -273,9 +273,9 @@ You can use `aerich` out of cli by use `Command` class.
```python
from aerich import Command

command = Command(tortoise_config=config, app='models')
await command.init()
await command.migrate('test')
async with Command(tortoise_config=config, app='models') as command:
await command.migrate('test')
await command.upgrade()
```

## Upgrade/Downgrade with `--fake` option
Expand Down
28 changes: 19 additions & 9 deletions aerich/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@

import os
import platform
from contextlib import AbstractAsyncContextManager
from pathlib import Path
from typing import TYPE_CHECKING

import tortoise
from tortoise import Tortoise, generate_schema_for_client
from tortoise import Tortoise, connections, generate_schema_for_client
from tortoise.exceptions import OperationalError
from tortoise.transactions import in_transaction
from tortoise.utils import get_schema_sql
Expand Down Expand Up @@ -59,10 +60,9 @@ def _init_tortoise_0_24_1_patch():
from tortoise.backends.base.schema_generator import BaseSchemaGenerator, cast, re

def _get_m2m_tables(
self, model: type[Model], table_name: str, safe: bool, models_tables: list[str]
) -> list[str]:
self, model: type[Model], db_table: str, safe: bool, models_tables: list[str]
) -> list[str]: # Copied from tortoise-orm
m2m_tables_for_create = []
db_table = table_name
for m2m_field in model._meta.m2m_fields:
field_object = cast("ManyToManyFieldInstance", model._meta.fields_map[m2m_field])
if field_object._generated or field_object.through in models_tables:
Expand All @@ -88,15 +88,15 @@ def _get_m2m_tables(
else:
backward_fk = forward_fk = ""
exists = "IF NOT EXISTS " if safe else ""
table_name = field_object.through
through_table_name = field_object.through
backward_type = self._get_pk_field_sql_type(model._meta.pk)
forward_type = self._get_pk_field_sql_type(field_object.related_model._meta.pk)
comment = ""
if desc := field_object.description:
comment = self._table_comment_generator(table=table_name, comment=desc)
comment = self._table_comment_generator(table=through_table_name, comment=desc)
m2m_create_string = self.M2M_TABLE_TEMPLATE.format(
exists=exists,
table_name=table_name,
table_name=through_table_name,
backward_fk=backward_fk,
forward_fk=forward_fk,
backward_key=backward_key,
Expand All @@ -116,7 +116,7 @@ def _get_m2m_tables(
m2m_create_string += self._post_table_hook()
if field_object.create_unique_index:
unique_index_create_sql = self._get_unique_index_sql(
exists, table_name, [backward_key, forward_key]
exists, through_table_name, [backward_key, forward_key]
)
if unique_index_create_sql.endswith(";"):
m2m_create_string += "\n" + unique_index_create_sql
Expand All @@ -136,7 +136,7 @@ def _get_m2m_tables(
_init_tortoise_0_24_1_patch()


class Command:
class Command(AbstractAsyncContextManager):
def __init__(
self,
tortoise_config: dict,
Expand All @@ -151,6 +151,16 @@ def __init__(
async def init(self) -> None:
await Migrate.init(self.tortoise_config, self.app, self.location)

async def __aenter__(self) -> Command:
await self.init()
return self

async def close(self) -> None:
await connections.close_all()

async def __aexit__(self, *args, **kw) -> None:
await self.close()

async def _upgrade(self, conn, version_file, fake: bool = False) -> None:
file_path = Path(Migrate.migrate_location, version_file)
m = import_py_file(file_path)
Expand Down
11 changes: 11 additions & 0 deletions tests/test_command.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from aerich import Command
from conftest import tortoise_orm


async def test_command(mocker):
mocker.patch("os.listdir", return_value=[])
async with Command(tortoise_orm) as command:
history = await command.history()
heads = await command.heads()
assert history == []
assert heads == []