3
3
import contextlib
4
4
import importlib
5
5
import os
6
+ import re
6
7
from collections .abc import Iterable
7
8
from datetime import datetime
8
9
from pathlib import Path
23
24
get_app_connection ,
24
25
get_dict_diff_by_key ,
25
26
get_models_describe ,
27
+ import_py_file ,
26
28
is_default_function ,
29
+ run_async ,
27
30
)
28
31
29
32
MIGRATE_TEMPLATE = """from tortoise import BaseDBAsyncClient
@@ -163,7 +166,7 @@ def _exclude_extra_field_types(cls, diffs) -> list[tuple]:
163
166
]
164
167
165
168
@classmethod
166
- async def migrate (cls , name : str , empty : bool ) -> str :
169
+ async def migrate (cls , name : str , empty : bool , no_input : bool = False ) -> str :
167
170
"""
168
171
diff old models and new models to generate diff content
169
172
:param name: str name for migration
@@ -174,8 +177,8 @@ async def migrate(cls, name: str, empty: bool) -> str:
174
177
return await cls ._generate_diff_py (name )
175
178
new_version_content = get_models_describe (cls .app )
176
179
last_version = cast (dict , cls ._last_version_content )
177
- cls .diff_models (last_version , new_version_content )
178
- cls .diff_models (new_version_content , last_version , False )
180
+ cls .diff_models (last_version , new_version_content , no_input = no_input )
181
+ cls .diff_models (new_version_content , last_version , False , no_input = no_input )
179
182
180
183
cls ._merge_operators ()
181
184
@@ -393,9 +396,31 @@ def _handle_o2o_fields(
393
396
key , old_model_describe , new_model_describe , model , old_models , new_models , upgrade
394
397
)
395
398
399
+ @classmethod
400
+ def _is_unique_constraint (cls , model : type [Model ], index_name : str ) -> bool :
401
+ if cls .dialect != "postgres" :
402
+ return False
403
+ # For postgresql, if a unique_together was created when generating the table, it is
404
+ # a constraint. And if it was created after table generated, it will be a unique index.
405
+ migrate_files = cls .get_all_version_files ()
406
+ if len (migrate_files ) < 2 :
407
+ return True
408
+ pattern = re .compile (rf' "?{ index_name } "? ' )
409
+ for filename in reversed (migrate_files [1 :]):
410
+ module = import_py_file (Path (cls .migrate_location , filename ))
411
+ upgrade_sql = run_async (module .upgrade , None )
412
+ if pattern .search (upgrade_sql ):
413
+ line = [i for i in upgrade_sql .splitlines () if pattern .search (i )][0 ]
414
+ prefix_words = pattern .split (line )[0 ].lower ().split ()
415
+ if "drop" in prefix_words :
416
+ # The migrate file may be generated by `aerich migrate` without applied by `aerich upgrade`
417
+ continue
418
+ return "constraint" in prefix_words
419
+ return True
420
+
396
421
@classmethod
397
422
def diff_models (
398
- cls , old_models : dict [str , dict ], new_models : dict [str , dict ], upgrade = True
423
+ cls , old_models : dict [str , dict ], new_models : dict [str , dict ], upgrade = True , no_input = False
399
424
) -> None :
400
425
"""
401
426
diff models and add operators
@@ -467,7 +492,15 @@ def diff_models(
467
492
cls ._add_operator (cls ._add_index (model , index , True ), upgrade , True )
468
493
# remove unique_together
469
494
for index in old_unique_together .difference (new_unique_together ):
470
- cls ._add_operator (cls ._drop_index (model , index , True ), upgrade , True )
495
+ index_name = cls ._unique_index_name (model , index )
496
+ if upgrade and cls ._is_unique_constraint (model , index_name ):
497
+ cls ._add_operator (
498
+ cls .ddl .drop_unique_constraint (model , index_name ), upgrade , True
499
+ )
500
+ else :
501
+ cls ._add_operator (
502
+ cls .ddl .drop_index_by_name (model , index_name ), upgrade , True
503
+ )
471
504
# add indexes
472
505
for idx in new_indexes .difference (old_indexes ):
473
506
cls ._add_operator (cls ._add_index (model , idx ), upgrade , fk_m2m_index = True )
@@ -536,7 +569,7 @@ def diff_models(
536
569
# print a empty line to warn that is another model
537
570
prefix = "\n " + prefix
538
571
models_with_rename_field .add (new_model_str )
539
- is_rename = click .prompt (
572
+ is_rename = no_input or click .prompt (
540
573
f"{ prefix } Rename { old_data_field_name } to { new_data_field_name } ?" ,
541
574
default = True ,
542
575
type = bool ,
@@ -757,6 +790,11 @@ def _resolve_fk_fields_name(cls, model: type[Model], fields_name: Iterable[str])
757
790
ret .append (field_name )
758
791
return ret
759
792
793
+ @classmethod
794
+ def _unique_index_name (cls , model : type [Model ], fields_name : Iterable [str ]) -> str :
795
+ field_names = cls ._resolve_fk_fields_name (model , fields_name )
796
+ return cls .ddl ._index_name (True , model , field_names )
797
+
760
798
@classmethod
761
799
def _drop_index (
762
800
cls , model : type [Model ], fields_name : Iterable [str ] | Index , unique = False
0 commit comments