Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 1 addition & 19 deletions tortoise/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

import importlib
import json
import logging
import os
import warnings
Expand Down Expand Up @@ -271,23 +270,6 @@ def _init_apps(
validate_connections=validate_connections,
)

@classmethod
def _get_config_from_config_file(cls, config_file: str) -> dict:
_, extension = os.path.splitext(config_file)
if extension in (".yml", ".yaml"):
import yaml # pylint: disable=C0415

with open(config_file) as f:
config = yaml.safe_load(f)
elif extension == ".json":
with open(config_file) as f:
config = json.load(f)
else:
raise ConfigurationError(
f"Unknown config extension {extension}, only .yml and .json are supported"
)
return config

@classmethod
def _build_initial_querysets(cls) -> None:
if cls.apps:
Expand Down Expand Up @@ -408,7 +390,7 @@ async def init(
# Normalize config: handle config_file case
normalized_config: dict[str, Any] | TortoiseConfig | None = config
if config_file:
normalized_config = cls._get_config_from_config_file(config_file)
normalized_config = TortoiseConfig._get_config_from_config_file(config_file)

# Debug logging
if logger.isEnabledFor(logging.DEBUG) and normalized_config is not None:
Expand Down
11 changes: 5 additions & 6 deletions tortoise/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,14 +177,13 @@ def _load_config(ctx: CLIContext) -> TortoiseConfig:
config_value = ctx.config
config_file = ctx.config_file
if config_file:
config_dict = Tortoise._get_config_from_config_file(config_file)
return TortoiseConfig.from_dict(config_dict)
return TortoiseConfig._get_config_from_config_file(config_file)
if not config_value:
config_value = utils.tortoise_orm_config()
if not config_value:
raise utils.CLIUsageError(
"You must specify TORTOISE_ORM in option or env, or pyproject.toml [tool.tortoise]",
)
if not config_value:
raise utils.CLIUsageError(
"You must specify TORTOISE_ORM in option or env, or pyproject.toml [tool.tortoise]",
)
Comment thread
waketzheng marked this conversation as resolved.
Outdated
return utils.get_tortoise_config(config_value)


Expand Down
49 changes: 48 additions & 1 deletion tortoise/config.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
from __future__ import annotations

import json
import os
from collections.abc import Mapping
from dataclasses import dataclass, field
from typing import Any
from typing import TYPE_CHECKING, Any

from tortoise.exceptions import ConfigurationError

if TYPE_CHECKING:
from collections.abc import Iterable
from types import ModuleType


@dataclass(frozen=True)
class DBUrlConfig:
Expand Down Expand Up @@ -202,3 +208,44 @@ def from_dict(cls, data: Mapping[str, Any]) -> TortoiseConfig:
use_tz=data.get("use_tz"),
timezone=data.get("timezone"),
)

@classmethod
def merge_args(
Comment thread
waketzheng marked this conversation as resolved.
Outdated
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure the method name is clear, merge_args seems to create a merged configuration from config, config_file, db_url and modules, but this is actually not the case.

Maybe generate_config could be a better name.

Moreover, maybe we need to normalize the classmethods in this class:

@classmethod
def generate_config(
    cls,
    config: dict[str, Any] | Self | None = None,
    config_file: str | None = None,
    db_url: str | None = None,
    modules: dict[str, Iterable[str | ModuleType]] | None = None,
) -> Self:
    ...

@classmethod
def generate_config_from_db_url_and_modules(cls, db_url: str, modules: dict[str, Iterable[str | ModuleType]]) -> Self:
    ...

@classmethod
def generate_config_from_config_file(cls, config_file: str) -> Self:
    ...

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the suggestions. I’ve updated the naming:

  1. To keep consistent with from_dict, I’ll use from_db_url_and_modules / from_config_file.

  2. Since there’s already a parameter called config, I used resolve_args instead of generate_config to avoid confusion.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like the from_* convention and it's consistent with the existing code. Maybe we can rename resolve_args to from_args? 🤔

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we can rename resolve_args to from_args?

That is one solution. However, I strongly recommend using resolve_* instead of from_* for this function, because:

  1. resolve_* checks whether arguments conflict, whereas from_* does not.
  2. resolve indicates that this function will parse the arguments, not just load configuration from them.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think from_* is clearer, but we can let @abondar decide...

Comment thread
waketzheng marked this conversation as resolved.
Outdated
cls,
config: dict[str, Any] | TortoiseConfig | None = None,
config_file: str | None = None,
db_url: str | None = None,
modules: dict[str, Iterable[str | ModuleType]] | None = None,
) -> TortoiseConfig:
if config is not None:
if config_file is not None:
raise ConfigurationError("Cannot specify both 'config' and 'config_file'")
return cls.from_dict(config) if isinstance(config, dict) else config
elif config_file is not None:
return cls._get_config_from_config_file(config_file)
elif db_url is None or modules is None:
raise ConfigurationError(
"Must provide either 'config', 'config_file', or both 'db_url' and 'modules'"
)
else:
from tortoise.backends.base.config_generator import generate_config

config_dict = generate_config(db_url, app_modules=modules)
return cls.from_dict(config_dict)

@classmethod
def _get_config_from_config_file(cls, config_file: str) -> TortoiseConfig:
Comment thread
waketzheng marked this conversation as resolved.
Outdated
_, extension = os.path.splitext(config_file)
if extension in (".yml", ".yaml"):
import yaml # pylint: disable=C0415

with open(config_file) as f:
config = yaml.safe_load(f)
elif extension == ".json":
with open(config_file) as f:
config = json.load(f)
else:
raise ConfigurationError(
f"Unknown config extension {extension}, only .yml and .json are supported"
)
return cls.from_dict(config)
41 changes: 1 addition & 40 deletions tortoise/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,26 +236,6 @@ def routers(self) -> list[type]:
"""
return self._routers

def _get_config_from_config_file(self, config_file: str) -> dict:
"""Load configuration from a JSON or YAML file."""
import json
import os

_, extension = os.path.splitext(config_file)
if extension in (".yml", ".yaml"):
import yaml # pylint: disable=C0415

with open(config_file) as f:
config = yaml.safe_load(f)
elif extension == ".json":
with open(config_file) as f:
config = json.load(f)
else:
raise ConfigurationError(
f"Unknown config extension {extension}, only .yml and .json are supported"
)
return config

async def init(
self,
config: dict[str, Any] | TortoiseConfig | None = None,
Expand Down Expand Up @@ -303,26 +283,7 @@ async def init(
"""
from tortoise.apps import Apps

# Handle config_file: load it as config dict
if config_file is not None:
if config is not None:
raise ConfigurationError("Cannot specify both 'config' and 'config_file'")
config = self._get_config_from_config_file(config_file)

# Convert input to TortoiseConfig for typed access
typed_config: TortoiseConfig
if config is None:
if db_url is None or modules is None:
raise ConfigurationError(
"Must provide either 'config', 'config_file', or both 'db_url' and 'modules'"
)
config_dict = generate_config(db_url, app_modules=modules)
typed_config = TortoiseConfig.from_dict(config_dict)
elif isinstance(config, TortoiseConfig):
typed_config = config
else:
typed_config = TortoiseConfig.from_dict(config)

typed_config = TortoiseConfig.merge_args(config, config_file, db_url, modules)
config_dict = typed_config.to_dict()
connections_config = config_dict["connections"]
apps_config = config_dict["apps"]
Expand Down
4 changes: 2 additions & 2 deletions tortoise/migrations/api/migrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@ async def migrate(
progress: Callable[[str, str, str], object] | None = None,
) -> None:
"""Run migrations for configured apps."""
if config_file:
config = TortoiseConfig._get_config_from_config_file(config_file)
if isinstance(config, TortoiseConfig):
config = config.to_dict()
if config_file:
config = Tortoise._get_config_from_config_file(config_file)
if not config:
raise ValueError("migrate requires a config or config_file")

Expand Down
4 changes: 2 additions & 2 deletions tortoise/migrations/api/plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@ async def plan(
"""
Print an ordered migration plan and return the formatted lines.
"""
if config_file:
config = TortoiseConfig._get_config_from_config_file(config_file)
if isinstance(config, TortoiseConfig):
config = config.to_dict()
if config_file:
config = Tortoise._get_config_from_config_file(config_file)
if not config:
raise ValueError("plan requires a config or config_file")

Expand Down
4 changes: 2 additions & 2 deletions tortoise/migrations/api/sqlmigrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,10 @@ async def sqlmigrate(
Returns:
A list of SQL strings (including descriptive comment annotations).
"""
if config_file:
config = TortoiseConfig._get_config_from_config_file(config_file)
if isinstance(config, TortoiseConfig):
config = config.to_dict()
if config_file:
config = Tortoise._get_config_from_config_file(config_file)
if not config:
raise ValueError("sqlmigrate requires a config or config_file")

Expand Down
Loading