diff --git a/tortoise/__init__.py b/tortoise/__init__.py index cd3ca65f2..d84aaef36 100644 --- a/tortoise/__init__.py +++ b/tortoise/__init__.py @@ -1,7 +1,6 @@ from __future__ import annotations import importlib -import json import logging import os import warnings @@ -271,23 +270,6 @@ def _init_apps( validate_connections=validate_connections, ) - @classmethod - def _get_config_from_config_file(cls, config_file: str) -> dict: - _, extension = os.path.splitext(config_file) - if extension in (".yml", ".yaml"): - import yaml # pylint: disable=C0415 - - with open(config_file) as f: - config = yaml.safe_load(f) - elif extension == ".json": - with open(config_file) as f: - config = json.load(f) - else: - raise ConfigurationError( - f"Unknown config extension {extension}, only .yml and .json are supported" - ) - return config - @classmethod def _build_initial_querysets(cls) -> None: if cls.apps: @@ -408,7 +390,7 @@ async def init( # Normalize config: handle config_file case normalized_config: dict[str, Any] | TortoiseConfig | None = config if config_file: - normalized_config = cls._get_config_from_config_file(config_file) + normalized_config = TortoiseConfig.from_config_file(config_file) # Debug logging if logger.isEnabledFor(logging.DEBUG) and normalized_config is not None: diff --git a/tortoise/cli/cli.py b/tortoise/cli/cli.py index b3269de49..5d1f69c59 100644 --- a/tortoise/cli/cli.py +++ b/tortoise/cli/cli.py @@ -174,13 +174,9 @@ def _load_config(ctx: CLIContext) -> TortoiseConfig: Returns: TortoiseConfig: Validated configuration object """ - config_value = ctx.config - config_file = ctx.config_file - if config_file: - config_dict = Tortoise._get_config_from_config_file(config_file) - return TortoiseConfig.from_dict(config_dict) - if not config_value: - config_value = utils.tortoise_orm_config() + if config_file := ctx.config_file: + return TortoiseConfig.from_config_file(config_file) + config_value = ctx.config or utils.tortoise_orm_config() if not config_value: raise utils.CLIUsageError( "You must specify TORTOISE_ORM in option or env, or pyproject.toml [tool.tortoise]", diff --git a/tortoise/config.py b/tortoise/config.py index 8c2be9d16..f7527093b 100644 --- a/tortoise/config.py +++ b/tortoise/config.py @@ -1,11 +1,24 @@ from __future__ import annotations +import json +import os from collections.abc import Mapping from dataclasses import dataclass, field -from typing import Any +from typing import TYPE_CHECKING, Any +from tortoise.backends.base.config_generator import generate_config from tortoise.exceptions import ConfigurationError +if TYPE_CHECKING: + import sys + from collections.abc import Iterable + from types import ModuleType + + if sys.version_info >= (3, 11): + from typing import Self + else: + from typing_extensions import Self + @dataclass(frozen=True) class DBUrlConfig: @@ -46,7 +59,7 @@ def to_config(self) -> str | dict[str, Any]: return {"engine": self.engine, "credentials": self.credentials} @classmethod - def from_dict(cls, data: Mapping[str, Any]) -> ConnectionConfig: + def from_dict(cls, data: Mapping[str, Any]) -> Self: if not isinstance(data, Mapping): raise ConfigurationError("ConnectionConfig must be created from a mapping") credentials = data.get("credentials", {}) @@ -85,7 +98,7 @@ def to_dict(self) -> dict[str, Any]: return data @classmethod - def from_dict(cls, data: Mapping[str, Any]) -> AppConfig: + def from_dict(cls, data: Mapping[str, Any]) -> Self: if not isinstance(data, Mapping): raise ConfigurationError("AppConfig must be created from a mapping") if "models" not in data: @@ -159,7 +172,7 @@ def to_dict(self) -> dict[str, Any]: return config @classmethod - def from_dict(cls, data: Mapping[str, Any]) -> TortoiseConfig: + def from_dict(cls, data: Mapping[str, Any]) -> Self: if not isinstance(data, Mapping): raise ConfigurationError("TortoiseConfig must be created from a mapping") @@ -202,3 +215,102 @@ def from_dict(cls, data: Mapping[str, Any]) -> TortoiseConfig: use_tz=data.get("use_tz"), timezone=data.get("timezone"), ) + + @classmethod + def from_config_file(cls, config_file: str) -> Self: + """ + Load configuration from a YAML or JSON file. + + Args: + config_file (str): Path to the configuration file. Supported extensions: .yml, .yaml, .json. + + Returns: + Self: The constructed TortoiseConfig. + + Raises: + ConfigurationError: If the file is missing, unsupported, or contents are invalid. + """ + _, extension = os.path.splitext(config_file) + if extension in (".yml", ".yaml"): + import yaml # pylint: disable=C0415 + + with open(config_file) as f: + config = yaml.safe_load(f) + elif extension == ".json": + with open(config_file) as f: + config = json.load(f) + else: + raise ConfigurationError( + f"Unknown config extension {extension}, only .yml and .json are supported" + ) + return cls.from_dict(config) + + @classmethod + def from_db_url_and_modules( + cls, db_url: str, modules: dict[str, Iterable[str | ModuleType]] + ) -> Self: + """ + Create a TortoiseConfig instance using a database URL and app modules mapping. + + This factory method builds a configuration dictionary using the provided database URL and modules, + and returns a TortoiseConfig instance based on that configuration. + + Args: + db_url: Database connection URL as a string. + modules: + A mapping where keys are app names, and values are iterables of Python module names + (as strings or Python module types) containing ORM models. + + Returns: + Self: The constructed TortoiseConfig instance. + + Raises: + ConfigurationError: If the generated config is invalid. + """ + config_dict = generate_config(db_url, app_modules=modules) + return cls.from_dict(config_dict) + + @classmethod + def resolve_args( + cls, + config: dict[str, Any] | Self | None = None, + config_file: str | None = None, + db_url: str | None = None, + modules: dict[str, Iterable[str | ModuleType]] | None = None, + ) -> Self: + """ + Parse and resolve multiple configuration argument sources into a unified TortoiseConfig instance. + + Accepts (in order of priority): + - `config` dict or TortoiseConfig instance, + - `config_file` path, + - or both `db_url` and `modules`. + + Args: + config (dict[str, Any] | TortoiseConfig | None): + config_file (str | None): Path to a config YAML or JSON file. + db_url (str | None): Database URL for config generation. + modules (dict[str, Iterable[str | ModuleType]] | None): App modules for config generation. + Args: + config: A configuration dict or TortoiseConfig instance. + config_file: Path to config file. + db_url: Database URL for config generation. + modules: App modules for config generation. + + Returns: + TortoiseConfig instance with resolved configuration. + + Raises: + ConfigurationError: If arguments are invalid or conflicting. + """ + if config is not None: + if config_file is not None: + raise ConfigurationError("Cannot specify both 'config' and 'config_file'") + return cls.from_dict(config) if isinstance(config, dict) else config + elif config_file is not None: + return cls.from_config_file(config_file) + elif db_url is not None and modules is not None: + return cls.from_db_url_and_modules(db_url, modules) + raise ConfigurationError( + "Must provide either 'config', 'config_file', or both 'db_url' and 'modules'" + ) diff --git a/tortoise/context.py b/tortoise/context.py index bc274b215..9b9a14014 100644 --- a/tortoise/context.py +++ b/tortoise/context.py @@ -236,26 +236,6 @@ def routers(self) -> list[type]: """ return self._routers - def _get_config_from_config_file(self, config_file: str) -> dict: - """Load configuration from a JSON or YAML file.""" - import json - import os - - _, extension = os.path.splitext(config_file) - if extension in (".yml", ".yaml"): - import yaml # pylint: disable=C0415 - - with open(config_file) as f: - config = yaml.safe_load(f) - elif extension == ".json": - with open(config_file) as f: - config = json.load(f) - else: - raise ConfigurationError( - f"Unknown config extension {extension}, only .yml and .json are supported" - ) - return config - async def init( self, config: dict[str, Any] | TortoiseConfig | None = None, @@ -303,26 +283,7 @@ async def init( """ from tortoise.apps import Apps - # Handle config_file: load it as config dict - if config_file is not None: - if config is not None: - raise ConfigurationError("Cannot specify both 'config' and 'config_file'") - config = self._get_config_from_config_file(config_file) - - # Convert input to TortoiseConfig for typed access - typed_config: TortoiseConfig - if config is None: - if db_url is None or modules is None: - raise ConfigurationError( - "Must provide either 'config', 'config_file', or both 'db_url' and 'modules'" - ) - config_dict = generate_config(db_url, app_modules=modules) - typed_config = TortoiseConfig.from_dict(config_dict) - elif isinstance(config, TortoiseConfig): - typed_config = config - else: - typed_config = TortoiseConfig.from_dict(config) - + typed_config = TortoiseConfig.resolve_args(config, config_file, db_url, modules) config_dict = typed_config.to_dict() connections_config = config_dict["connections"] apps_config = config_dict["apps"] diff --git a/tortoise/migrations/api/migrate.py b/tortoise/migrations/api/migrate.py index 508f8cea3..7576e880a 100644 --- a/tortoise/migrations/api/migrate.py +++ b/tortoise/migrations/api/migrate.py @@ -23,10 +23,7 @@ async def migrate( progress: Callable[[str, str, str], object] | None = None, ) -> None: """Run migrations for configured apps.""" - if isinstance(config, TortoiseConfig): - config = config.to_dict() - if config_file: - config = Tortoise._get_config_from_config_file(config_file) + config = TortoiseConfig.resolve_args(config, config_file).to_dict() if not config: raise ValueError("migrate requires a config or config_file") diff --git a/tortoise/migrations/api/plan.py b/tortoise/migrations/api/plan.py index 2dce67f56..957e4d0cb 100644 --- a/tortoise/migrations/api/plan.py +++ b/tortoise/migrations/api/plan.py @@ -19,10 +19,7 @@ async def plan( """ Print an ordered migration plan and return the formatted lines. """ - if isinstance(config, TortoiseConfig): - config = config.to_dict() - if config_file: - config = Tortoise._get_config_from_config_file(config_file) + config = TortoiseConfig.resolve_args(config, config_file).to_dict() if not config: raise ValueError("plan requires a config or config_file") diff --git a/tortoise/migrations/api/sqlmigrate.py b/tortoise/migrations/api/sqlmigrate.py index 37ecaeb0f..2be4a2bcc 100644 --- a/tortoise/migrations/api/sqlmigrate.py +++ b/tortoise/migrations/api/sqlmigrate.py @@ -43,10 +43,7 @@ async def sqlmigrate( Returns: A list of SQL strings (including descriptive comment annotations). """ - if isinstance(config, TortoiseConfig): - config = config.to_dict() - if config_file: - config = Tortoise._get_config_from_config_file(config_file) + config = TortoiseConfig.resolve_args(config, config_file).to_dict() if not config: raise ValueError("sqlmigrate requires a config or config_file")