diff --git a/docs/sphinx_autodoc_typehints.py b/docs/sphinx_autodoc_typehints.py index d0e5aed1d..67d1f1216 100644 --- a/docs/sphinx_autodoc_typehints.py +++ b/docs/sphinx_autodoc_typehints.py @@ -15,25 +15,33 @@ Protocol = None logger = logging.getLogger(__name__) -pydata_annotations = {'Any', 'AnyStr', 'Callable', 'ClassVar', 'NoReturn', 'Optional', 'Tuple', - 'Union'} +pydata_annotations = { + "Any", + "AnyStr", + "Callable", + "ClassVar", + "NoReturn", + "Optional", + "Tuple", + "Union", +} def format_annotation(annotation, fully_qualified=False): - if inspect.isclass(annotation) and annotation.__module__ == 'builtins': - if annotation.__qualname__ == 'NoneType': - return '``None``' + if inspect.isclass(annotation) and annotation.__module__ == "builtins": + if annotation.__qualname__ == "NoneType": + return "``None``" else: - return ':py:class:`{}`'.format(annotation.__qualname__) + return ":py:class:`{}`".format(annotation.__qualname__) annotation_cls = annotation if inspect.isclass(annotation) else type(annotation) - if annotation_cls.__module__ == 'typing': - class_name = str(annotation).split('[')[0].split('.')[-1] + if annotation_cls.__module__ == "typing": + class_name = str(annotation).split("[")[0].split(".")[-1] params = None - module = 'typing' - extra = '' + module = "typing" + extra = "" - origin = getattr(annotation, '__origin__', None) + origin = getattr(annotation, "__origin__", None) if inspect.isclass(origin): annotation_cls = annotation.__origin__ try: @@ -44,13 +52,13 @@ def format_annotation(annotation, fully_qualified=False): pass # annotation_cls was either the "type" object or typing.Type if annotation is Any: - return ':py:data:`{}typing.Any`'.format("" if fully_qualified else "~") + return ":py:data:`{}typing.Any`".format("" if fully_qualified else "~") elif annotation is AnyStr: - return ':py:data:`{}typing.AnyStr`'.format("" if fully_qualified else "~") + return ":py:data:`{}typing.AnyStr`".format("" if fully_qualified else "~") elif isinstance(annotation, TypeVar): bound = annotation.__bound__ if bound: - if 'ForwardRef(' in str(bound): + if "ForwardRef(" in str(bound): try: bound = bound._evaluate(sys.modules[annotation.__module__].__dict__, None) except: @@ -59,28 +67,34 @@ def format_annotation(annotation, fully_qualified=False): except: bound = bound.__forward_arg__ return format_annotation(bound, fully_qualified) - return '\\%r' % annotation - elif (annotation is Union or getattr(annotation, '__origin__', None) is Union or - hasattr(annotation, '__union_params__')): - if hasattr(annotation, '__union_params__'): + return "\\%r" % annotation + elif ( + annotation is Union + or getattr(annotation, "__origin__", None) is Union + or hasattr(annotation, "__union_params__") + ): + if hasattr(annotation, "__union_params__"): params = annotation.__union_params__ - elif hasattr(annotation, '__args__'): + elif hasattr(annotation, "__args__"): params = annotation.__args__ - if params and len(params) == 2 and (hasattr(params[1], '__qualname__') and - params[1].__qualname__ == 'NoneType'): - class_name = 'Optional' + if ( + params + and len(params) == 2 + and (hasattr(params[1], "__qualname__") and params[1].__qualname__ == "NoneType") + ): + class_name = "Optional" params = (params[0],) - elif annotation_cls.__qualname__ == 'Tuple' and hasattr(annotation, '__tuple_params__'): + elif annotation_cls.__qualname__ == "Tuple" and hasattr(annotation, "__tuple_params__"): params = annotation.__tuple_params__ if annotation.__tuple_use_ellipsis__: params += (Ellipsis,) - elif annotation_cls.__qualname__ == 'Callable': + elif annotation_cls.__qualname__ == "Callable": arg_annotations = result_annotation = None - if hasattr(annotation, '__result__'): + if hasattr(annotation, "__result__"): arg_annotations = annotation.__args__ result_annotation = annotation.__result__ - elif getattr(annotation, '__args__', None): + elif getattr(annotation, "__args__", None): arg_annotations = annotation.__args__[:-1] result_annotation = annotation.__args__[-1] @@ -88,66 +102,74 @@ def format_annotation(annotation, fully_qualified=False): params = [Ellipsis, result_annotation] elif arg_annotations is not None: params = [ - '\\[{}]'.format( - ', '.join( - format_annotation(param, fully_qualified) - for param in arg_annotations)), - result_annotation + "\\[{}]".format( + ", ".join( + format_annotation(param, fully_qualified) for param in arg_annotations + ) + ), + result_annotation, ] - elif str(annotation).startswith('typing.ClassVar[') and hasattr(annotation, '__type__'): + elif str(annotation).startswith("typing.ClassVar[") and hasattr(annotation, "__type__"): # < py3.7 params = (annotation.__type__,) - elif hasattr(annotation, 'type_var'): + elif hasattr(annotation, "type_var"): # Type alias class_name = annotation.name params = (annotation.type_var,) - elif getattr(annotation, '__args__', None) is not None: + elif getattr(annotation, "__args__", None) is not None: params = annotation.__args__ - elif hasattr(annotation, '__parameters__'): + elif hasattr(annotation, "__parameters__"): params = annotation.__parameters__ if params: - extra = '\\[{}]'.format(', '.join( - format_annotation(param, fully_qualified) for param in params)) + extra = "\\[{}]".format( + ", ".join(format_annotation(param, fully_qualified) for param in params) + ) - return '{prefix}`{qualify}{module}.{name}`{extra}'.format( - prefix=':py:data:' if class_name in pydata_annotations else ':py:class:', + return "{prefix}`{qualify}{module}.{name}`{extra}".format( + prefix=":py:data:" if class_name in pydata_annotations else ":py:class:", qualify="" if fully_qualified else "~", module=module, name=class_name, - extra=extra + extra=extra, ) elif annotation is Ellipsis: - return '...' - elif (inspect.isfunction(annotation) and annotation.__module__ == 'typing' and - hasattr(annotation, '__name__') and hasattr(annotation, '__supertype__')): - return ':py:func:`{qualify}typing.NewType`\\(:py:data:`~{name}`, {extra})'.format( + return "..." + elif ( + inspect.isfunction(annotation) + and annotation.__module__ == "typing" + and hasattr(annotation, "__name__") + and hasattr(annotation, "__supertype__") + ): + return ":py:func:`{qualify}typing.NewType`\\(:py:data:`~{name}`, {extra})".format( qualify="" if fully_qualified else "~", name=annotation.__name__, extra=format_annotation(annotation.__supertype__, fully_qualified), ) - elif inspect.isclass(annotation) or inspect.isclass(getattr(annotation, '__origin__', None)): + elif inspect.isclass(annotation) or inspect.isclass(getattr(annotation, "__origin__", None)): if not inspect.isclass(annotation): annotation_cls = annotation.__origin__ - extra = '' + extra = "" try: mro = annotation_cls.mro() except TypeError: pass else: if Generic in mro or (Protocol and Protocol in mro): - params = (getattr(annotation, '__parameters__', None) or - getattr(annotation, '__args__', None)) + params = getattr(annotation, "__parameters__", None) or getattr( + annotation, "__args__", None + ) if params: - extra = '\\[{}]'.format(', '.join( - format_annotation(param, fully_qualified) for param in params)) + extra = "\\[{}]".format( + ", ".join(format_annotation(param, fully_qualified) for param in params) + ) - return ':py:class:`{qualify}{module}.{name}`{extra}'.format( + return ":py:class:`{qualify}{module}.{name}`{extra}".format( qualify="" if fully_qualified else "~", module=annotation.__module__, name=annotation_cls.__qualname__, - extra=extra + extra=extra, ) return str(annotation) @@ -157,31 +179,31 @@ def process_signature(app, what: str, name: str, obj, options, signature, return if not callable(obj): return - if what in ('class', 'exception'): - obj = getattr(obj, '__init__', getattr(obj, '__new__', None)) + if what in ("class", "exception"): + obj = getattr(obj, "__init__", getattr(obj, "__new__", None)) - if not getattr(obj, '__annotations__', None): + if not getattr(obj, "__annotations__", None): return obj = inspect.unwrap(obj) signature = Signature(obj) parameters = [ - param.replace(annotation=inspect.Parameter.empty) - for param in signature.parameters.values() + param.replace(annotation=inspect.Parameter.empty) for param in signature.parameters.values() ] - if '' in obj.__qualname__: + if "" in obj.__qualname__: logger.warning( 'Cannot treat a function defined as a local function: "%s" (use @functools.wraps)', - name) + name, + ) return if parameters: - if what in ('class', 'exception'): + if what in ("class", "exception"): del parameters[0] - elif what == 'method': + elif what == "method": outer = inspect.getmodule(obj) - for clsname in obj.__qualname__.split('.')[:-1]: + for clsname in obj.__qualname__.split(".")[:-1]: outer = getattr(outer, clsname) method_name = obj.__name__ @@ -189,18 +211,16 @@ def process_signature(app, what: str, name: str, obj, options, signature, return # If the method starts with double underscore (dunder) # Python applies mangling so we need to prepend the class name. # This doesn't happen if it always ends with double underscore. - class_name = obj.__qualname__.split('.')[-2] + class_name = obj.__qualname__.split(".")[-2] method_name = "_{c}{m}".format(c=class_name, m=method_name) method_object = outer.__dict__[method_name] if outer else obj if not isinstance(method_object, (classmethod, staticmethod)): del parameters[0] - signature = signature.replace( - parameters=parameters, - return_annotation=inspect.Signature.empty) + signature = signature.replace(parameters=parameters, return_annotation=inspect.Signature.empty) - return stringify_signature(signature).replace('\\', '\\\\'), None + return stringify_signature(signature).replace("\\", "\\\\"), None def get_all_type_hints(obj, name): @@ -216,8 +236,9 @@ def get_all_type_hints(obj, name): try: rv = get_type_hints(obj, localns=type_globals.__dict__) except Exception as exc: - logger.warning('Cannot resolve forward reference in type annotations of "%s": %s', - name, exc) + logger.warning( + 'Cannot resolve forward reference in type annotations of "%s": %s', name, exc + ) rv = obj.__annotations__ if rv: @@ -238,8 +259,9 @@ def get_all_type_hints(obj, name): try: rv = get_type_hints(obj, localns=type_globals.__dict__) except: - logger.warning('Cannot resolve forward reference in type annotations of "%s": %s', - name, exc) + logger.warning( + 'Cannot resolve forward reference in type annotations of "%s": %s', name, exc + ) rv = obj.__annotations__ return rv @@ -254,14 +276,16 @@ def backfill_type_hints(obj, name): return {} else: import ast - parse_kwargs = {'type_comments': True} + + parse_kwargs = {"type_comments": True} def _one_child(module): children = module.body # use the body to ignore type comments if len(children) != 1: logger.warning( - 'Did not get exactly one node from AST for "%s", got %s', name, len(children)) + 'Did not get exactly one node from AST for "%s", got %s', name, len(children) + ) return return children[0] @@ -284,14 +308,14 @@ def _one_child(module): return {} try: - comment_args_str, comment_returns = type_comment.split(' -> ') + comment_args_str, comment_returns = type_comment.split(" -> ") except ValueError: logger.warning('Unparseable type hint comment for "%s": Expected to contain ` -> `', name) return {} rv = {} if comment_returns: - rv['return'] = comment_returns + rv["return"] = comment_returns args = load_args(obj_ast) comment_args = split_type_comment_args(comment_args_str) @@ -323,7 +347,7 @@ def _one_child(module): def load_args(obj_ast): func_args = obj_ast.args args = [] - pos_only = getattr(func_args, 'posonlyargs', None) + pos_only = getattr(func_args, "posonlyargs", None) if pos_only: args.extend(pos_only) @@ -357,7 +381,7 @@ def add(val): add(comment[start_arg_at:at]) start_arg_at = at + 1 - add(comment[start_arg_at: at + 1]) + add(comment[start_arg_at : at + 1]) return result @@ -366,22 +390,23 @@ def process_docstring(app, what, name, obj, options, lines): obj = obj.fget if callable(obj): - if what in ('class', 'exception'): - obj = getattr(obj, '__init__') + if what in ("class", "exception"): + obj = getattr(obj, "__init__") obj = inspect.unwrap(obj) type_hints = get_all_type_hints(obj, name) for argname, annotation in type_hints.items(): - if argname == 'return': + if argname == "return": continue # this is handled separately later - if argname.endswith('_'): - argname = '{}\\_'.format(argname[:-1]) + if argname.endswith("_"): + argname = "{}\\_".format(argname[:-1]) formatted_annotation = format_annotation( - annotation, fully_qualified=app.config.typehints_fully_qualified) + annotation, fully_qualified=app.config.typehints_fully_qualified + ) - searchfor = ':param {}:'.format(argname) + searchfor = ":param {}:".format(argname) insert_index = None for i, line in enumerate(lines): @@ -394,31 +419,29 @@ def process_docstring(app, what, name, obj, options, lines): insert_index = len(lines) if insert_index is not None: - lines.insert( - insert_index, - ':type {}: {}'.format(argname, formatted_annotation) - ) + lines.insert(insert_index, ":type {}: {}".format(argname, formatted_annotation)) - if 'return' in type_hints and what not in ('class', 'exception'): + if "return" in type_hints and what not in ("class", "exception"): formatted_annotation = format_annotation( - type_hints['return'], fully_qualified=app.config.typehints_fully_qualified) + type_hints["return"], fully_qualified=app.config.typehints_fully_qualified + ) insert_index = len(lines) for i, line in enumerate(lines): - if line.startswith(':rtype:'): + if line.startswith(":rtype:"): insert_index = None break - elif line.startswith(':return:') or line.startswith(':returns:'): + elif line.startswith(":return:") or line.startswith(":returns:"): insert_index = i if insert_index is not None: if insert_index == len(lines): # Ensure that :rtype: doesn't get joined with a paragraph of text, which # prevents it being interpreted. - lines.append('') + lines.append("") insert_index += 1 - lines.insert(insert_index, ':rtype: {}'.format(formatted_annotation)) + lines.insert(insert_index, ":rtype: {}".format(formatted_annotation)) def builder_ready(app): @@ -427,10 +450,10 @@ def builder_ready(app): def setup(app): - app.add_config_value('set_type_checking_flag', False, 'html') - app.add_config_value('always_document_param_types', False, 'html') - app.add_config_value('typehints_fully_qualified', False, 'env') - app.connect('builder-inited', builder_ready) - app.connect('autodoc-process-signature', process_signature) - app.connect('autodoc-process-docstring', process_docstring) + app.add_config_value("set_type_checking_flag", False, "html") + app.add_config_value("always_document_param_types", False, "html") + app.add_config_value("typehints_fully_qualified", False, "env") + app.connect("builder-inited", builder_ready) + app.connect("autodoc-process-signature", process_signature) + app.connect("autodoc-process-docstring", process_docstring) return dict(parallel_read_safe=True) diff --git a/tortoise/backends/base/schema_generator.py b/tortoise/backends/base/schema_generator.py index ac07d117f..a08149aac 100644 --- a/tortoise/backends/base/schema_generator.py +++ b/tortoise/backends/base/schema_generator.py @@ -1,5 +1,5 @@ from hashlib import sha256 -from typing import TYPE_CHECKING, Any, List, Set, Type, cast +from typing import TYPE_CHECKING, Any, List, Set, Type, cast, Dict from tortoise.exceptions import ConfigurationError from tortoise.fields import JSONField, TextField, UUIDField @@ -425,24 +425,56 @@ def get_create_schema_sql(self, safe: bool = True) -> str: created_tables: Set[dict] = set() ordered_tables_for_create: List[str] = [] m2m_tables_to_create: List[str] = [] - while True: - if len(created_tables) == tables_to_create_count: - break - try: - next_table_for_create = next( - t - for t in tables_to_create - if t["references"].issubset(created_tables | {t["table"]}) + + while len(created_tables) != tables_to_create_count: + if not tables_to_create: + # This means an exception will be raised! The following is forensics. + + discovered_tables: Dict[str, Type[Model]] = {} + for model in models_to_create: + table_name = str(model._meta.basetable).replace('"', "") + if table_name in discovered_tables: + other_cyclic_model = discovered_tables[table_name] + msg = ( + f"Model {model._meta.full_name} overlaps with model {other_cyclic_model._meta.full_name}. " + f"Make sure to use typing.TYPE_CHECKING if models are in multiple Python modules." + ) + raise ConfigurationError(msg) + discovered_tables[table_name] = model + + raise ConfigurationError(_FORENSIC_FAIL_MSG) + + for table in tables_to_create: + if table["references"].issubset(created_tables | {table["table"]}): + next_table_to_create = table + break + else: # if no break + try: + t = tables_to_create[0] + except IndexError: + raise ConfigurationError( + f"Forensic of error regarding foreign key (FK) references failed:\n{_FORENSIC_FAIL_MSG}" + ) + + table = t["table"] + refs = [i for i in t["references"] if i != table and i not in created_tables] + raise ConfigurationError( + f"Failed to create schema(`{table}`) due to cyclic foreign key (FK) references({refs})" ) - except StopIteration: - raise ConfigurationError("Can't create schema due to cyclic fk references") - tables_to_create.remove(next_table_for_create) - created_tables.add(next_table_for_create["table"]) - ordered_tables_for_create.append(next_table_for_create["table_creation_string"]) - m2m_tables_to_create += next_table_for_create["m2m_tables"] + + tables_to_create.remove(next_table_to_create) + created_tables.add(next_table_to_create["table"]) + ordered_tables_for_create.append(next_table_to_create["table_creation_string"]) + m2m_tables_to_create += next_table_to_create["m2m_tables"] schema_creation_string = "\n".join(ordered_tables_for_create + m2m_tables_to_create) return schema_creation_string async def generate_from_string(self, creation_string: str) -> None: await self.client.execute_script(creation_string) + + +_FORENSIC_FAIL_MSG = ( + "Something to do with your model structure, raise an issue on GitHub, and perhaps " + "reference the previous PR: https://github.com/tortoise/tortoise-orm/pull/1236" +)