diff --git a/CHANGELOG.rst b/CHANGELOG.rst index b96fc6e41..cf247db64 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -11,6 +11,11 @@ Changelog 1.1.8 ----- +Added +^^^^^ +- ``QuerySet.union()`` — SQL UNION query support for combining results from multiple QuerySets, including support for union across different models, ``union(all=True)`` for duplicates, ``order_by()``, ``limit()``, and ``count()``. +- Added comprehensive EXPLAIN support for MySQL and PostgreSQL. + Fixed ^^^^^ - ``MigrationRecorder`` now uses parameterized queries; fixes MariaDB/MySQL rejecting ISO-8601 ``applied_at`` values. (#2132) diff --git a/tests/backends/test_explain.py b/tests/backends/test_explain.py index e16d91c91..3d343e10f 100644 --- a/tests/backends/test_explain.py +++ b/tests/backends/test_explain.py @@ -2,7 +2,8 @@ from tests.testmodels import Tournament from tortoise.contrib.test import requireCapability -from tortoise.contrib.test.condition import NotEQ +from tortoise.contrib.test.condition import NotEQ, NotIn +from tortoise.exceptions import UnSupportedError @requireCapability(dialect=NotEQ("mssql")) @@ -18,3 +19,19 @@ async def test_explain(db): plan = await Tournament.all().explain() # This should have returned *some* information. assert len(str(plan)) > 20 + + +@requireCapability(dialect=NotIn("postgres", "mysql", "mssql")) +@pytest.mark.asyncio +async def test_explain_unsupported_output_fmt(db): + await Tournament.create(name="Test") + with pytest.raises(UnSupportedError, match="does not support different explain formats"): + await Tournament.all().explain(output_fmt="json") + + +@requireCapability(dialect=NotIn("postgres", "mysql", "mssql")) +@pytest.mark.asyncio +async def test_explain_unsupported_options(db): + await Tournament.create(name="Test") + with pytest.raises(UnSupportedError, match="does not support explain options"): + await Tournament.all().explain(analyze=True) diff --git a/tests/backends/test_mysql.py b/tests/backends/test_mysql.py index 37f7c1cf4..36bede415 100644 --- a/tests/backends/test_mysql.py +++ b/tests/backends/test_mysql.py @@ -3,13 +3,17 @@ """ import copy +import json import os import ssl import pytest +from tests.testmodels import Tournament from tortoise.backends.base.config_generator import generate_config from tortoise.context import TortoiseContext +from tortoise.contrib.test import requireCapability +from tortoise.exceptions import UnSupportedError def _get_db_config(): @@ -87,3 +91,65 @@ async def test_ssl_custom(): await ctx.init(db_config, _create_db=True) except ConnectionError: pass + + +@requireCapability(dialect="mysql") +@pytest.mark.asyncio +async def test_explain(db_simple): + await Tournament.create(name="Test") + result = await Tournament.all().explain() + data = json.loads(result[0]["EXPLAIN"]) + assert "query_plan" in data or "query_block" in data + + +@requireCapability(dialect="mysql") +@pytest.mark.asyncio +async def test_explain_format_traditional(db_simple): + await Tournament.create(name="Test") + result = await Tournament.all().explain(output_fmt="traditional") + assert "table" in result[0] + assert result[0]["table"] == "tournament" + + +@requireCapability(dialect="mysql") +@pytest.mark.asyncio +async def test_explain_format_tree(db_simple): + await Tournament.create(name="Test") + result = await Tournament.all().explain(output_fmt="tree") + assert isinstance(result[0]["EXPLAIN"], str) + assert "->" in result[0]["EXPLAIN"] + assert "tournament" in result[0]["EXPLAIN"] + + +@requireCapability(dialect="mysql") +@pytest.mark.asyncio +async def test_explain_analyze(db_simple): + await Tournament.create(name="Test") + # Older MySQL version don't support ANALYZE with JSON format, that's why we use TREE + result = await Tournament.all().explain(output_fmt="tree", analyze=True) + assert "actual" in result[0]["EXPLAIN"] + + +@requireCapability(dialect="mysql") +@pytest.mark.asyncio +async def test_explain_analyze_false(db_simple): + await Tournament.create(name="Test") + result = await Tournament.all().explain(analyze=False) + assert "query_plan" in result[0]["EXPLAIN"] or "query_block" in result[0]["EXPLAIN"] + assert "actual" not in result[0]["EXPLAIN"] + + +@requireCapability(dialect="mysql") +@pytest.mark.asyncio +async def test_explain_unsupported_format(db_simple): + await Tournament.create(name="Test") + with pytest.raises(UnSupportedError, match="Unsupported explain format"): + await Tournament.all().explain(output_fmt="invalid") + + +@requireCapability(dialect="mysql") +@pytest.mark.asyncio +async def test_explain_unsupported_option(db_simple): + await Tournament.create(name="Test") + with pytest.raises(UnSupportedError, match="Unsupported options"): + await Tournament.all().explain(unsupported_option=True) diff --git a/tests/backends/test_postgres.py b/tests/backends/test_postgres.py index 694469cd1..ca1b040c4 100644 --- a/tests/backends/test_postgres.py +++ b/tests/backends/test_postgres.py @@ -2,15 +2,19 @@ Test some PostgreSQL-specific features """ +import json import os import ssl +import xml.etree.ElementTree as ET import pytest +import yaml from tests.testmodels import Tournament from tortoise import Tortoise, connections from tortoise.backends.base.config_generator import generate_config -from tortoise.exceptions import OperationalError +from tortoise.contrib.test import requireCapability +from tortoise.exceptions import OperationalError, UnSupportedError def _get_db_config(): @@ -28,11 +32,10 @@ def _get_db_config(): return db_config, is_asyncpg, is_psycopg +@requireCapability(dialect="postgres") @pytest.mark.asyncio -async def test_schema(db_simple): - db_config, is_asyncpg, is_psycopg = _get_db_config() - if not is_asyncpg and not is_psycopg: - pytest.skip("PostgreSQL only") +async def test_schema(db_isolated): + db_config, is_asyncpg, _ = _get_db_config() if is_asyncpg: from asyncpg.exceptions import InvalidSchemaNameError @@ -75,11 +78,10 @@ async def test_schema(db_simple): await Tortoise._drop_databases() +@requireCapability(dialect="postgres") @pytest.mark.asyncio -async def test_ssl_true(): - db_config, is_asyncpg, is_psycopg = _get_db_config() - if not is_asyncpg and not is_psycopg: - pytest.skip("PostgreSQL only") +async def test_ssl_true(db_isolated): + db_config, _, _ = _get_db_config() db_config["connections"]["models"]["credentials"]["ssl"] = True ssl_failed = False @@ -95,11 +97,10 @@ async def test_ssl_true(): await Tortoise._drop_databases() +@requireCapability(dialect="postgres") @pytest.mark.asyncio -async def test_ssl_custom(): - db_config, is_asyncpg, is_psycopg = _get_db_config() - if not is_asyncpg and not is_psycopg: - pytest.skip("PostgreSQL only") +async def test_ssl_custom(db_isolated): + db_config, _, _ = _get_db_config() # Expect connectionerror or pass ssl_ctx = ssl.create_default_context() @@ -118,11 +119,10 @@ async def test_ssl_custom(): await Tortoise._drop_databases() +@requireCapability(dialect="postgres") @pytest.mark.asyncio -async def test_application_name(): +async def test_application_name(db_isolated): db_config, is_asyncpg, is_psycopg = _get_db_config() - if not is_asyncpg and not is_psycopg: - pytest.skip("PostgreSQL only") db_config["connections"]["models"]["credentials"]["application_name"] = "mytest_application" try: @@ -138,3 +138,162 @@ async def test_application_name(): finally: if Tortoise._inited: await Tortoise._drop_databases() + + +def _get_query_plan(result: list): + query_plan = result[0]["QUERY PLAN"] + if isinstance(query_plan, str): + query_plan = json.loads(query_plan) + return query_plan[0] + + +@requireCapability(dialect="postgres") +@pytest.mark.asyncio +async def test_explain(db_simple): + await Tournament.create(name="Test") + result = await Tournament.all().explain() + query_plan = _get_query_plan(result) + assert "Plan" in query_plan + + +@requireCapability(dialect="postgres") +@pytest.mark.asyncio +async def test_explain_format_text(db_simple): + await Tournament.create(name="Test") + result = await Tournament.all().explain(output_fmt="text") + assert isinstance(result[0]["QUERY PLAN"], str) + + +@requireCapability(dialect="postgres") +@pytest.mark.asyncio +async def test_explain_format_yaml(db_simple): + await Tournament.create(name="Test") + result = await Tournament.all().explain(output_fmt="yaml") + yaml.safe_dump(result[0]["QUERY PLAN"]) + + +@requireCapability(dialect="postgres") +@pytest.mark.asyncio +async def test_explain_format_xml(db_simple): + await Tournament.create(name="Test") + result = await Tournament.all().explain(output_fmt="xml") + ET.fromstring(result[0]["QUERY PLAN"]) + + +@requireCapability(dialect="postgres") +@pytest.mark.asyncio +async def test_explain_unsupported_format(db_simple): + await Tournament.create(name="Test") + with pytest.raises(UnSupportedError) as exc_info: + await Tournament.all().explain(output_fmt="invalid") + assert "Unsupported explain format" in str(exc_info.value) + + +@requireCapability(dialect="postgres") +@pytest.mark.asyncio +async def test_explain_analyze(db_simple): + await Tournament.create(name="Test") + result = await Tournament.all().explain(analyze=True) + query_plan = _get_query_plan(result) + assert "Plan" in query_plan + assert "Actual Loops" in query_plan["Plan"] + + +@requireCapability(dialect="postgres") +@pytest.mark.asyncio +async def test_explain_costs(db_simple): + await Tournament.create(name="Test") + result = await Tournament.all().explain(costs=True) + query_plan = _get_query_plan(result) + assert "Plan" in query_plan + assert "Total Cost" in query_plan["Plan"] + + +@requireCapability(dialect="postgres") +@pytest.mark.asyncio +async def test_explain_buffers(db_simple): + await Tournament.create(name="Test") + result = await Tournament.all().explain(buffers=True) + query_plan = _get_query_plan(result) + assert "Plan" in query_plan + assert "Shared Hit Blocks" in query_plan["Plan"] + + +@requireCapability(dialect="postgres") +@pytest.mark.asyncio +async def test_explain_timing(db_simple): + await Tournament.create(name="Test") + result = await Tournament.all().explain(analyze=True, timing=True) + query_plan = _get_query_plan(result) + assert "Plan" in query_plan + assert "Actual Total Time" in query_plan["Plan"] + + +@requireCapability(dialect="postgres") +@pytest.mark.asyncio +async def test_explain_memory(db_simple): + await Tournament.create(name="Test") + result = await Tournament.all().explain(memory=True) + query_plan = _get_query_plan(result) + assert "Plan" in query_plan + assert "Memory" in query_plan or "Memory" in str(query_plan) + + +@requireCapability(dialect="postgres") +@pytest.mark.asyncio +async def test_explain_settings(db_simple): + await Tournament.create(name="Test") + result = await Tournament.all().explain(settings=True) + query_plan = _get_query_plan(result) + assert "Plan" in query_plan + + +@requireCapability(dialect="postgres") +@pytest.mark.asyncio +async def test_explain_summary(db_simple): + await Tournament.create(name="Test") + result = await Tournament.all().explain(summary=True) + query_plan = _get_query_plan(result) + assert "Plan" in query_plan + assert "Planning Time" in query_plan + + +@requireCapability(dialect="postgres") +@pytest.mark.asyncio +async def test_explain_multiple_options(db_simple): + await Tournament.create(name="Test") + result = await Tournament.all().explain(analyze=True, costs=True, buffers=True) + query_plan = _get_query_plan(result) + assert "Plan" in query_plan + assert "Actual Loops" in query_plan["Plan"] + assert "Total Cost" in query_plan["Plan"] + assert "Shared Hit Blocks" in query_plan["Plan"] + + +@requireCapability(dialect="postgres") +@pytest.mark.asyncio +async def test_explain_unsupported_option(db_simple): + await Tournament.create(name="Test") + with pytest.raises(UnSupportedError) as exc_info: + await Tournament.all().explain(unsupported_option=True) + assert "UNSUPPORTED_OPTION" in str(exc_info.value) + + +@requireCapability(dialect="postgres") +@pytest.mark.asyncio +async def test_explain_option_false(db_simple): + await Tournament.create(name="Test") + result = await Tournament.all().explain(analyze=False) + query_plan = _get_query_plan(result) + assert "Plan" in query_plan + assert "Actual Loops" not in query_plan["Plan"] + + +@requireCapability(dialect="postgres") +@pytest.mark.asyncio +async def test_explain_default_verbose(db_simple): + await Tournament.create(name="Test") + result = await Tournament.all().explain() + query_plan = _get_query_plan(result) + assert "Plan" in query_plan + assert "Output" in query_plan["Plan"] diff --git a/tortoise/backends/base/executor.py b/tortoise/backends/base/executor.py index 0738816ab..9865b78fc 100644 --- a/tortoise/backends/base/executor.py +++ b/tortoise/backends/base/executor.py @@ -10,7 +10,7 @@ from pypika_tortoise import JoinType, Parameter, Table from pypika_tortoise.queries import QueryBuilder -from tortoise.exceptions import OperationalError +from tortoise.exceptions import OperationalError, UnSupportedError from tortoise.expressions import Expression, ResolveContext from tortoise.fields.base import DatabaseDefault from tortoise.fields.relational import ( @@ -96,7 +96,15 @@ def __init__( self.update_cache, ) = EXECUTOR_CACHE[key] - async def execute_explain(self, sql: str) -> Any: + async def execute_explain( + self, sql: str, output_fmt: str | None = None, **options: bool + ) -> Any: + if output_fmt: + raise UnSupportedError("This database does not support different explain formats") + + if options: + raise UnSupportedError("This database does not support explain options") + sql = " ".join((self.EXPLAIN_PREFIX, sql)) return (await self.db.execute_query(sql))[1] diff --git a/tortoise/backends/base_postgres/executor.py b/tortoise/backends/base_postgres/executor.py index 7286c5544..4788e6510 100644 --- a/tortoise/backends/base_postgres/executor.py +++ b/tortoise/backends/base_postgres/executor.py @@ -3,7 +3,7 @@ import uuid from collections.abc import Callable, Sequence from functools import partial -from typing import TYPE_CHECKING, cast +from typing import TYPE_CHECKING, Any, cast from pypika_tortoise.dialects import PostgreSQLQueryBuilder from pypika_tortoise.queries import QueryBuilder @@ -28,6 +28,7 @@ postgres_posix_regex, ) from tortoise.contrib.postgres.search import SearchCriterion +from tortoise.exceptions import UnSupportedError from tortoise.filters import ( array_contained_by, array_contains, @@ -53,7 +54,23 @@ def postgres_search( class BasePostgresExecutor(BaseExecutor): - EXPLAIN_PREFIX = "EXPLAIN (FORMAT JSON, VERBOSE)" + EXPLAIN_PREFIX = "EXPLAIN ({})" + EXPLAIN_SUPPORTED_FORMATS = ["TEXT", "JSON", "XML", "YAML"] + EXPLAIN_SUPPORTED_OPTIONS = frozenset( + [ + "ANALYZE", + "BUFFERS", + "COSTS", + "GENERIC_PLAN", + "MEMORY", + "SETTINGS", + "SERIALIZE", + "SUMMARY", + "TIMING", + "VERBOSE", + "WAL", + ] + ) DB_NATIVE = BaseExecutor.DB_NATIVE | {bool, uuid.UUID} FILTER_FUNC_OVERRIDE = { array_contains: postgres_array_contains, @@ -106,3 +123,22 @@ async def _process_insert_result(self, instance: Model, results: dict | None) -> model_field = db_projection[key] field_object = self.model._meta.fields_map[model_field] setattr(instance, model_field, field_object.to_python_value(val)) + + async def execute_explain( + self, sql: str, output_fmt: str | None = None, **options: bool + ) -> Any: + output_fmt = output_fmt or "JSON" + if output_fmt.upper() not in self.EXPLAIN_SUPPORTED_FORMATS: + raise UnSupportedError(f"Unsupported explain format: {output_fmt}") + + options = options or {"verbose": True} + + required_options = set(option.upper() for option, required in options.items() if required) + if unsupported_options := (required_options - self.EXPLAIN_SUPPORTED_OPTIONS): + raise UnSupportedError(f"Unsupported options: {unsupported_options}") + + required_options.add("FORMAT " + output_fmt.upper()) + postgres_options = ", ".join(required_options) + explain_statement = self.EXPLAIN_PREFIX.format(postgres_options) + sql = " ".join((explain_statement, sql)) + return (await self.db.execute_query(sql))[1] diff --git a/tortoise/backends/mssql/executor.py b/tortoise/backends/mssql/executor.py index db17684c6..8816b71e4 100644 --- a/tortoise/backends/mssql/executor.py +++ b/tortoise/backends/mssql/executor.py @@ -7,5 +7,7 @@ class MSSQLExecutor(ODBCExecutor): - async def execute_explain(self, sql: str) -> Any: + async def execute_explain( + self, sql: str, output_fmt: str | None = None, **options: bool + ) -> Any: raise UnSupportedError("MSSQL does not support explain") diff --git a/tortoise/backends/mysql/executor.py b/tortoise/backends/mysql/executor.py index 8bb911a6c..be5b4d6c7 100644 --- a/tortoise/backends/mysql/executor.py +++ b/tortoise/backends/mysql/executor.py @@ -1,4 +1,5 @@ import enum +from typing import Any from pypika_tortoise import SqlContext, functions from pypika_tortoise.enums import SqlTypes @@ -14,6 +15,7 @@ mysql_json_filter, ) from tortoise.contrib.mysql.search import SearchCriterion +from tortoise.exceptions import UnSupportedError from tortoise.fields import BigIntField, IntField, SmallIntField from tortoise.filters import ( Like, @@ -124,7 +126,26 @@ class MySQLExecutor(BaseExecutor): json_filter: mysql_json_filter, posix_regex: mysql_posix_regex, } - EXPLAIN_PREFIX = "EXPLAIN FORMAT=JSON" + EXPLAIN_SUPPORTED_FORMATS = ["JSON", "TRADITIONAL", "TREE"] + + async def execute_explain( + self, sql: str, output_fmt: str | None = None, **options: bool + ) -> Any: + output_fmt = output_fmt or "JSON" + if output_fmt.upper() not in self.EXPLAIN_SUPPORTED_FORMATS: + raise UnSupportedError(f"Unsupported explain format: {output_fmt}") + + if options and not all(k == "analyze" for k in options): + unsupported = [k for k in options if k != "analyze"] + raise UnSupportedError(f"Unsupported options: {set(unsupported)}") + + explain_parts = [] + if options.get("analyze"): + explain_parts.append("ANALYZE") + explain_parts.append(f"FORMAT={output_fmt.upper()}") + + explain_statement = "EXPLAIN " + " ".join(explain_parts) + return (await self.db.execute_query(f"{explain_statement} {sql}"))[1] async def _process_insert_result(self, instance: Model, results: int) -> None: pk_field_object = self.model._meta.pk diff --git a/tortoise/queryset.py b/tortoise/queryset.py index ffe0a1f22..23f6e0540 100644 --- a/tortoise/queryset.py +++ b/tortoise/queryset.py @@ -1056,25 +1056,32 @@ def prefetch_related(self, *args: str | Prefetch) -> QuerySet[MODEL]: queryset._prefetch_map[first_level_field].add(forwarded_prefetch) return queryset - async def explain(self) -> Any: + async def explain(self, output_fmt: str | None = None, **options: bool) -> Any: """Fetch and return information about the query execution plan. This is done by executing an ``EXPLAIN`` query whose exact prefix depends on the database backend, as documented below. - - PostgreSQL: ``EXPLAIN (FORMAT JSON, VERBOSE) ...`` - - SQLite: ``EXPLAIN QUERY PLAN ...`` - - MySQL: ``EXPLAIN FORMAT=JSON ...`` + :param output_fmt: The output format for the EXPLAIN result. + - PostgreSQL: ``text``, ``json``, ``xml``, ``yaml`` (default: ``json``) + - MySQL: ``json``, ``traditional``, ``tree`` (default: ``json``) + - SQLite, MSSQL, Oracle: Not supported (raises UnSupportedError) + :param options: Additional options for EXPLAIN (database-specific). + - PostgreSQL: ``analyze``, ``buffers``, ``costs``, ``memory``, ``settings``, ``summary``, ``timing``, ``verbose``, ``wal``, ``generic_plan``, ``serialize`` (if not provided default is ``verbose``) + - MySQL: ``analyze`` + - SQLite, MSSQL, Oracle: Not supported (raises UnSupportedError) .. note:: This is only meant to be used in an interactive environment for debugging and query optimization. **The output format may (and will) vary greatly depending on the database backend.** + + :raises UnSupportedError: If the database does not support the requested format or options. """ self._choose_db_if_not_chosen() self._make_query() return await self._db.executor_class(model=self.model, db=self._db).execute_explain( - self.query.get_sql() + self.query.get_sql(), output_fmt, **options ) def using_db(self, _db: BaseDBAsyncClient | None) -> QuerySet[MODEL]: