diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 07491e85e..c763b4cad 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -194,3 +194,77 @@ jobs: - name: Upload coverage uses: codecov/codecov-action@v1 if: matrix.python-version == '3.13' + + mysql: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.10", "3.11", "3.12", "3.13", "3.14"] + + services: + mysql: + image: mysql:8.4 + env: + MYSQL_ROOT_PASSWORD: rootpassword + MYSQL_DATABASE: piccolo + MYSQL_ROOT_HOST: '%' + options: >- + --health-cmd="mysqladmin ping -uroot -prootpassword" + --health-interval=5s + --health-timeout=2s + --health-retries=10 + ports: + - 3306:3306 + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r requirements/requirements.txt + pip install -r requirements/test-requirements.txt + pip install -r requirements/extras/mysql.txt + + - name: Install MySQL Client + run: sudo apt-get update && sudo apt-get install -y mysql-client + + - name: Wait for MySQL + run: | + set -e + for i in {1..60}; do + if mysqladmin ping -h127.0.0.1 -P3306 -uroot -prootpassword > /dev/null 2>&1; then + echo "MySQL is up" + break + fi + echo "Waiting for MySQL… ($i/60)" + sleep 2 + if [ "$i" -eq 60 ]; then + echo "MySQL did not become ready in time!" 
>&2 + mysqladmin ping -h127.0.0.1 -P3306 -uroot -prootpassword || true + exit 1 + fi + done + + - name: Setup MySQL (create database) + run: | + mysql -h127.0.0.1 -P3306 -uroot -prootpassword -e "CREATE DATABASE IF NOT EXISTS piccolo;" + + - name: Test with pytest, MySQL + run: ./scripts/test-mysql.sh + env: + MY_HOST: 127.0.0.1 + MY_PORT: 3306 + MY_USER: root + MY_PASSWORD: rootpassword + MY_DATABASE: piccolo + + - name: Upload coverage + uses: codecov/codecov-action@v1 + if: matrix.python-version == '3.13' diff --git a/docs/src/piccolo/engines/connection_pool.rst b/docs/src/piccolo/engines/connection_pool.rst index f5856a917..7cbbe911b 100644 --- a/docs/src/piccolo/engines/connection_pool.rst +++ b/docs/src/piccolo/engines/connection_pool.rst @@ -3,7 +3,7 @@ Connection Pool =============== -.. hint:: Connection pools can be used with Postgres and CockroachDB. +.. hint:: Connection pools can be used with Postgres, CockroachDB and MySQL. Setup ~~~~~ diff --git a/docs/src/piccolo/engines/index.rst b/docs/src/piccolo/engines/index.rst index db655c76b..9641ea5f1 100644 --- a/docs/src/piccolo/engines/index.rst +++ b/docs/src/piccolo/engines/index.rst @@ -127,4 +127,5 @@ Engine types ./sqlite_engine ./postgres_engine ./cockroach_engine + ./mysql_engine ./connection_pool diff --git a/docs/src/piccolo/engines/mysql_engine.rst b/docs/src/piccolo/engines/mysql_engine.rst new file mode 100644 index 000000000..764bca778 --- /dev/null +++ b/docs/src/piccolo/engines/mysql_engine.rst @@ -0,0 +1,44 @@ +MySQLEngine +=========== + +Configuration +------------- + +.. code-block:: python + + # piccolo_conf.py + from piccolo.engine.mysql import MySQLEngine + + + DB = MySQLEngine( + config={ + "host": "localhost", + "port": 3306, + "user": "root", + "password": "", + "db": "piccolo", + } + ) + +config +~~~~~~ + +The config dictionary is passed directly to the underlying database adapter, +aiomysql. See the `aiomysql docs `_ +to learn more. 
+ +------------------------------------------------------------------------------- + +Connection Pool +--------------- + +See :ref:`ConnectionPool`. + +------------------------------------------------------------------------------- + +Source +------ + +.. currentmodule:: piccolo.engine.mysql + +.. autoclass:: MySQLEngine diff --git a/docs/src/piccolo/getting_started/database_support.rst b/docs/src/piccolo/getting_started/database_support.rst index 1106cd350..47ab917b2 100644 --- a/docs/src/piccolo/getting_started/database_support.rst +++ b/docs/src/piccolo/getting_started/database_support.rst @@ -17,6 +17,12 @@ together in production. The main missing feature is support for :ref:`automatic database migrations ` due to SQLite's limited support for ``ALTER TABLE`` ``DDL`` statements. +`MySQL `_ is also supported. There may be some features +not supported, but it's OK to use. :ref:`Automatic database migrations ` +is supported but we must be careful because MySQL ``DDL`` statements +`are not transactional `_ +and MySQL will commit the changes in transaction. + What about other databases? --------------------------- diff --git a/docs/src/piccolo/getting_started/index.rst b/docs/src/piccolo/getting_started/index.rst index 373cad786..2ad559d3d 100644 --- a/docs/src/piccolo/getting_started/index.rst +++ b/docs/src/piccolo/getting_started/index.rst @@ -12,5 +12,6 @@ Getting Started ./setup_postgres ./setup_cockroach ./setup_sqlite + ./setup_mysql ./example_schema ./sync_and_async diff --git a/docs/src/piccolo/getting_started/setup_mysql.rst b/docs/src/piccolo/getting_started/setup_mysql.rst new file mode 100644 index 000000000..8250247ce --- /dev/null +++ b/docs/src/piccolo/getting_started/setup_mysql.rst @@ -0,0 +1,28 @@ +.. _setting_up_mysql: + +########### +Setup MySQL +########### + +Installation +************ + +Follow the `instructions for your OS `_. + + +Creating a database +******************* + +Using ``mysql``: + +.. 
code-block:: bash + + mysql -u root -p + +Enter your password and create the database: + +.. code-block:: bash + + CREATE DATABASE "my_database_name"; + +Alternatively, use a GUI tool. diff --git a/docs/src/piccolo/migrations/create.rst b/docs/src/piccolo/migrations/create.rst index 486c46165..44ec8b133 100644 --- a/docs/src/piccolo/migrations/create.rst +++ b/docs/src/piccolo/migrations/create.rst @@ -256,9 +256,10 @@ Creating an auto migration: aren't supported by auto migrations, or to modify the data held in tables, as opposed to changing the tables themselves. -.. warning:: Auto migrations aren't supported in SQLite, because of SQLite's - extremely limited support for SQL Alter statements. This might change in - the future. +.. warning:: Auto migrations for SQLite and MySQL are supported, with limitations. + SQLite has extremely limited support for SQL ALTER statements and MySQL DDL triggers + an implicit commit in transaction and we cannot roll back a DDL using ROLLBACK + (non-transactional DDL). This might change in the future. Troubleshooting ~~~~~~~~~~~~~~~ diff --git a/docs/src/piccolo/playground/advanced.rst b/docs/src/piccolo/playground/advanced.rst index e8f459b9f..7c9e9fd2a 100644 --- a/docs/src/piccolo/playground/advanced.rst +++ b/docs/src/piccolo/playground/advanced.rst @@ -94,6 +94,41 @@ When you have the database setup, you can connect to it as follows: piccolo playground run --engine=cockroach + +MySQL +----- + +Install MySQL +~~~~~~~~~~~~~ + +See :ref:`the docs on settings up MySQL `. + +Create database +~~~~~~~~~~~~~~~ + +By default the playground expects a local database to exist with the following +credentials: + +.. code-block:: bash + + user: "root" + password: "" + host: "localhost" + db: "piccolo_playground" + port: 3306 + +If you want to use different credentials, you can pass them into the playground +command (use ``piccolo playground run --help`` for details). 
+ +Connecting +~~~~~~~~~~ + +When you have the database setup, you can connect to it as follows: + +.. code-block:: bash + + piccolo playground run --engine=mysql + iPython ------- diff --git a/docs/src/piccolo/query_clauses/on_conflict.rst b/docs/src/piccolo/query_clauses/on_conflict.rst index c3adfa51c..309934302 100644 --- a/docs/src/piccolo/query_clauses/on_conflict.rst +++ b/docs/src/piccolo/query_clauses/on_conflict.rst @@ -133,6 +133,8 @@ You can also specify the name of a constraint using a string: ... target='some_constraint' ... ) +.. warning:: Not supported for MySQL. + ``values`` ---------- @@ -194,11 +196,15 @@ update should be made: ... where=Band.popularity < 1000 ... ) +.. warning:: Not supported for MySQL. A workaround is possible by using an + ``IF`` or ``CASE`` condition in the ``UPDATE`` clause or by first + performing a separate ``UPDATE``, but this is not currently supported in Piccolo. + Multiple ``on_conflict`` clauses -------------------------------- -SQLite allows you to specify multiple ``ON CONFLICT`` clauses, but Postgres and -Cockroach don't. +SQLite allows you to specify multiple ``ON CONFLICT`` clauses, but Postgres, +Cockroach and MySQL don't. .. code-block:: python @@ -218,6 +224,7 @@ Learn more * `Postgres docs `_ * `Cockroach docs `_ * `SQLite docs `_ +* `MySQL docs `_ Source ------ diff --git a/docs/src/piccolo/query_clauses/returning.rst b/docs/src/piccolo/query_clauses/returning.rst index 6f613e049..1818a630f 100644 --- a/docs/src/piccolo/query_clauses/returning.rst +++ b/docs/src/piccolo/query_clauses/returning.rst @@ -49,4 +49,6 @@ how many rows were affected or processed by the operation. .. warning:: This works for all versions of Postgres, but only `SQLite 3.35.0 `_ and above support the returning clause. See the :ref:`docs ` on - how to check your SQLite version. + how to check your SQLite version. + + Not supported for MySQL because there is no ``RETURNING`` clause in MySQL. 
diff --git a/piccolo/apps/migrations/auto/migration_manager.py b/piccolo/apps/migrations/auto/migration_manager.py index 57b97fa2d..298ea9983 100644 --- a/piccolo/apps/migrations/auto/migration_manager.py +++ b/piccolo/apps/migrations/auto/migration_manager.py @@ -16,12 +16,15 @@ ) from piccolo.apps.migrations.auto.serialisation import deserialise_params from piccolo.columns import Column, column_types -from piccolo.columns.column_types import ForeignKey, Serial +from piccolo.columns.column_types import JSON, Blob, ForeignKey, Serial, Text from piccolo.engine import engine_finder from piccolo.engine.cockroach import CockroachTransaction from piccolo.query import Query from piccolo.query.base import DDL -from piccolo.query.constraints import get_fk_constraint_name +from piccolo.query.constraints import ( + get_fk_constraint_name, + get_fk_constraint_name_mysql, +) from piccolo.schema import SchemaDDLBase from piccolo.table import Table, create_table_class, sort_table_classes from piccolo.utils.warnings import colored_warning @@ -543,10 +546,16 @@ async def _run_alter_columns(self, backwards: bool = False): assert isinstance(fk_column, ForeignKey) + if existing_table._meta.db.engine_type == "mysql": + constraint_name = await get_fk_constraint_name_mysql( + column=fk_column + ) + else: + constraint_name = await get_fk_constraint_name( + column=fk_column + ) + # First drop the existing foreign key constraint - constraint_name = await get_fk_constraint_name( - column=fk_column - ) if constraint_name: await self._run_query( _Table.alter().drop_constraint( @@ -642,6 +651,16 @@ async def _run_alter_columns(self, backwards: bool = False): column._meta._name = alter_column.column_name column._meta.db_column_name = alter_column.db_column_name + if _Table._meta.db.engine_type == "mysql" and ( + column_class == Text + or column_class == JSON + or column_class == Blob + ): + raise ValueError( + "MySQL does not support default value in alter " + "statement for TEXT, JSON and BLOB 
columns" + ) + if default is None: await self._run_query( _Table.alter().drop_default(column=column) diff --git a/piccolo/apps/playground/commands/run.py b/piccolo/apps/playground/commands/run.py index 2e539f337..cf761d7a7 100644 --- a/piccolo/apps/playground/commands/run.py +++ b/piccolo/apps/playground/commands/run.py @@ -29,7 +29,12 @@ Varchar, ) from piccolo.columns.readable import Readable -from piccolo.engine import CockroachEngine, PostgresEngine, SQLiteEngine +from piccolo.engine import ( + CockroachEngine, + MySQLEngine, + PostgresEngine, + SQLiteEngine, +) from piccolo.engine.base import Engine from piccolo.table import Table from piccolo.utils.warnings import colored_string @@ -424,6 +429,16 @@ def run( "port": port or 26257, } ) + elif engine.upper() == "MYSQL": + db = MySQLEngine( + { + "host": host, + "db": database, + "user": user or "root", + "password": password or "", + "port": port or 3306, + } + ) else: db = SQLiteEngine() for _table in TABLES: diff --git a/piccolo/apps/schema/commands/generate.py b/piccolo/apps/schema/commands/generate.py index 5e1785784..5210dc8d1 100644 --- a/piccolo/apps/schema/commands/generate.py +++ b/piccolo/apps/schema/commands/generate.py @@ -67,11 +67,24 @@ class RowMeta: data_type: str numeric_precision: Optional[Union[int, str]] numeric_scale: Optional[Union[int, str]] - numeric_precision_radix: Optional[Literal[2, 10]] + numeric_precision_radix: Optional[Literal[2, 10]] = None @classmethod def get_column_name_str(cls) -> str: - return ", ".join(i.name for i in dataclasses.fields(cls)) + from piccolo.engine import engine_finder + + engine = engine_finder() + assert engine + + excluded_columns = [] + if engine.engine_type == "mysql": + excluded_columns = ["numeric_precision_radix"] + + return ", ".join( + i.name + for i in dataclasses.fields(cls) + if i.name not in excluded_columns + ) @dataclasses.dataclass @@ -615,6 +628,12 @@ async def get_table_schema( table. 
""" + schema_name = ( + "DATABASE()" + if table_class._meta.db.engine_type == "mysql" + else schema_name + ) + row_meta_list = await table_class.raw( ( f"SELECT {RowMeta.get_column_name_str()} FROM " diff --git a/piccolo/apps/sql_shell/commands/run.py b/piccolo/apps/sql_shell/commands/run.py index a666321a4..b6066f137 100644 --- a/piccolo/apps/sql_shell/commands/run.py +++ b/piccolo/apps/sql_shell/commands/run.py @@ -5,6 +5,7 @@ from typing import cast from piccolo.engine.finder import engine_finder +from piccolo.engine.mysql import MySQLEngine from piccolo.engine.postgres import PostgresEngine from piccolo.engine.sqlite import SQLiteEngine @@ -64,3 +65,35 @@ def run() -> None: print("Enter .quit to exit") subprocess.run(["sqlite3", database], check=True) + + if isinstance(engine, MySQLEngine): + engine = cast(MySQLEngine, engine) + + args = ["mysql"] + + config = engine.config + + if dsn := config.get("dsn"): + args += [dsn] + else: + if user := config.get("user"): + args += ["-u", user] + if host := config.get("host"): + args += ["-h", host] + if port := config.get("port"): + args += ["-p", str(port)] + if database := config.get("db"): + args += [database] + + sigint_handler = signal.getsignal(signal.SIGINT) + subprocess_env = os.environ.copy() + if password := config.get("password"): + subprocess_env["MYSQLPASSWORD"] = str(password) + try: + # Allow SIGINT to pass to mysql to abort queries. + signal.signal(signal.SIGINT, signal.SIG_IGN) + print("Enter \\q to exit") + subprocess.run(args, check=True, env=subprocess_env) + finally: + # Restore the original SIGINT handler. 
+ signal.signal(signal.SIGINT, sigint_handler) diff --git a/piccolo/columns/base.py b/piccolo/columns/base.py index 879e4088f..c713a9846 100644 --- a/piccolo/columns/base.py +++ b/piccolo/columns/base.py @@ -1007,6 +1007,10 @@ def ddl(self) -> str: f" ON UPDATE {on_update}" ) + if self._meta.engine_type == "mysql": + # omit DEFAULT clause for MySQL + return query + # Always ran for Cockroach because unique_rowid() is directly # defined for Cockroach Serial and BigSerial. # Postgres and SQLite will not run this for Serial and BigSerial. diff --git a/piccolo/columns/column_types.py b/piccolo/columns/column_types.py index 887129b4b..cb8667388 100644 --- a/piccolo/columns/column_types.py +++ b/piccolo/columns/column_types.py @@ -72,8 +72,11 @@ class Band(Table): from piccolo.columns.defaults.uuid import UUID4, UUIDArg from piccolo.columns.operators.comparison import ( ArrayAll, + ArrayAllMySQL, ArrayAny, + ArrayAnyMySQL, ArrayNotAny, + ArrayNotAnyMySQL, ) from piccolo.columns.operators.string import Concat from piccolo.columns.reference import LazyTableReference @@ -241,6 +244,10 @@ def get_sqlite_interval_string(self, interval: timedelta) -> str: output_string = ", ".join(output) return output_string + def get_mysql_interval_string(self, interval: timedelta) -> str: + total_seconds = interval.total_seconds() + return f"{total_seconds} SECOND" + def get_querystring( self, column: Column, @@ -258,6 +265,11 @@ def get_querystring( return QueryString( f'"{column_name}" {operator} INTERVAL {value_string}', ) + elif engine_type == "mysql": + value_string = self.get_mysql_interval_string(interval=value) + return QueryString( + f"`{column_name}` {operator} INTERVAL {value_string}", + ) elif engine_type == "sqlite": if isinstance(column, Interval): # SQLite doesn't have a proper Interval type. 
Instead we store @@ -442,9 +454,23 @@ def __init__( **kwargs: Unpack[ColumnKwargs], ) -> None: self._validate_default(default, (str, None)) + self.default = default super().__init__(default=default, **kwargs) + def get_default_value(self): + """ + MySQL does not allow unquoted TEXT literals in the DEFAULT + clause, so we use the expression in parentheses. + Only works in CREATE TABLE. MySQL does not allow default + values for TEXT columns in ALTER statements. + """ + engine_type = self._meta.engine_type + + if engine_type == "mysql": + return QueryString(f"('{self.default}')") + return super().get_default_value() + ########################################################################### # For update queries @@ -502,6 +528,15 @@ class Band(Table): value_type = uuid.UUID + @property + def column_type(self): + engine_type = self._meta.engine_type + if engine_type in ("postgres", "cockroach", "sqlite"): + return "UUID" + elif engine_type == "mysql": + return "CHAR(36)" + raise Exception("Unrecognized engine type") + def __init__( self, default: UUIDArg = UUID4(), @@ -686,7 +721,7 @@ class Band(Table): """ def _get_column_type(self, engine_type: str): - if engine_type == "postgres": + if engine_type in ("postgres", "mysql"): return "BIGINT" elif engine_type == "cockroach": return "BIGINT" @@ -738,7 +773,7 @@ class Band(Table): @property def column_type(self): engine_type = self._meta.engine_type - if engine_type == "postgres": + if engine_type in ("postgres", "mysql"): return "SMALLINT" elif engine_type == "cockroach": return "SMALLINT" @@ -783,6 +818,8 @@ def column_type(self): return "INTEGER" elif engine_type == "sqlite": return "INTEGER" + elif engine_type == "mysql": + return "INT AUTO_INCREMENT" raise Exception("Unrecognized engine type") def default(self) -> QueryString: @@ -794,6 +831,8 @@ def default(self) -> QueryString: return QueryString("unique_rowid()") elif engine_type == "sqlite": return NULL + elif engine_type == "mysql": + return NULL raise 
Exception("Unrecognized engine type") ########################################################################### @@ -826,6 +865,8 @@ def column_type(self): return "BIGINT" elif engine_type == "sqlite": return "INTEGER" + elif engine_type == "mysql": + return "BIGINT AUTO_INCREMENT" raise Exception("Unrecognized engine type") ########################################################################### @@ -917,6 +958,14 @@ class Concert(Table): value_type = datetime timedelta_delegate = TimedeltaDelegate() + @property + def column_type(self): + engine_type = self._meta.engine_type + if engine_type == "mysql": + return "DATETIME(6)" + else: + return "TIMESTAMP" + def __init__( self, default: TimestampArg = TimestampNow(), @@ -1034,6 +1083,14 @@ def __init__( self.default = default super().__init__(default=default, **kwargs) + @property + def column_type(self): + engine_type = self._meta.engine_type + if engine_type == "mysql": + return "TIMESTAMP(6)" + else: + return "TIMESTAMPTZ" + ########################################################################### # For update queries @@ -1281,6 +1338,8 @@ def column_type(self): # make it an integer - but we need a text field. # https://sqlite.org/datatype3.html#determination_of_column_affinity return "SECONDS" + elif engine_type == "mysql": + return "TIME(6)" raise Exception("Unrecognized engine type") ########################################################################### @@ -2346,13 +2405,18 @@ def __init__( self.json_operator: Optional[str] = None - @property - def column_type(self): + def get_default_value(self): + """ + MySQL does not allow unquoted JSON literals in the DEFAULT + clause, so we use the expression in parentheses. + Only works in CREATE TABLE. MySQL does not allow default + values for JSON columns in ALTER statements. + """ engine_type = self._meta.engine_type - if engine_type == "cockroach": - return "JSONB" # Cockroach is always JSONB. 
- else: - return "JSON" + + if engine_type == "mysql": + return QueryString(f"('{self.default}')") + return super().get_default_value() ########################################################################### @@ -2455,8 +2519,24 @@ class JSONB(JSON): @property def column_type(self): + engine_type = self._meta.engine_type + if engine_type == "mysql": + return "JSON" return "JSONB" # Must be defined, we override column_type() in JSON() + def get_default_value(self): + """ + MySQL does not allow unquoted JSON literals in the DEFAULT + clause, so we use the expression in parentheses. + Only works in CREATE TABLE. MySQL does not allow default + values for JSON columns in ALTER statements. + """ + engine_type = self._meta.engine_type + + if engine_type == "mysql": + return QueryString("('')") + return super().get_default_value() + ########################################################################### # Descriptors @@ -2503,7 +2583,7 @@ def column_type(self): engine_type = self._meta.engine_type if engine_type in ("postgres", "cockroach"): return "BYTEA" - elif engine_type == "sqlite": + elif engine_type in ("sqlite", "mysql"): return "BLOB" raise Exception("Unrecognized engine type") @@ -2527,6 +2607,20 @@ def __init__( self.default = default super().__init__(default=default, **kwargs) + def get_default_value(self): + """ + MySQL does not allow unquoted BLOB literals in the DEFAULT + clause, so we use the expression in parentheses. + Only works in CREATE TABLE. MySQL does not allow default + values for BLOB columns in ALTER statements. 
+ """ + engine_type = self._meta.engine_type + + if engine_type == "mysql": + return QueryString(f"({self.default})") + + return super().get_default_value() + ########################################################################### # Descriptors @@ -2657,8 +2751,23 @@ def column_type(self): ) else "ARRAY" ) + elif engine_type == "mysql": + return "JSON" # use JSON column raise Exception("Unrecognized engine type") + def get_default_value(self): + """ + MySQL does not allow unquoted JSON literals in the DEFAULT + clause, so we use the expression in parentheses. + Only works in CREATE TABLE. MySQL does not allow default + values for TEXT columns in ALTER statements. + """ + engine_type = self._meta.engine_type + + if engine_type == "mysql": + return QueryString("('')") + return super().get_default_value() + def _setup_base_column(self, table_class: type[Table]): """ Called from the ``Table.__init_subclass__`` - makes sure @@ -2782,6 +2891,8 @@ def any(self, value: Any) -> Where: if engine_type in ("postgres", "cockroach"): return Where(column=self, value=value, operator=ArrayAny) + if engine_type == "mysql": + return Where(column=self, value=value, operator=ArrayAnyMySQL) elif engine_type == "sqlite": return self.like(f"%{value}%") else: @@ -2800,6 +2911,8 @@ def not_any(self, value: Any) -> Where: if engine_type in ("postgres", "cockroach"): return Where(column=self, value=value, operator=ArrayNotAny) + if engine_type == "mysql": + return Where(column=self, value=value, operator=ArrayNotAnyMySQL) elif engine_type == "sqlite": return self.not_like(f"%{value}%") else: @@ -2818,6 +2931,8 @@ def all(self, value: Any) -> Where: if engine_type in ("postgres", "cockroach"): return Where(column=self, value=value, operator=ArrayAll) + if engine_type == "mysql": + return Where(column=self, value=value, operator=ArrayAllMySQL) elif engine_type == "sqlite": raise ValueError("Unsupported by SQLite") else: diff --git a/piccolo/columns/combination.py 
b/piccolo/columns/combination.py index b9695a191..e520a7352 100644 --- a/piccolo/columns/combination.py +++ b/piccolo/columns/combination.py @@ -232,9 +232,20 @@ def querystring_for_update_and_delete(self) -> QueryString: column = self.column if column._meta.call_chain: - # Use a sub select to find the correct ID. root_column = column._meta.call_chain[0] - sub_query = root_column._meta.table.select(root_column).where(self) + if column._meta.engine_type == "mysql": + # MySQL does not allow updating a table when it appears + # inside a subquery used by the same UPDATE, so we use + # joins to replace subqueries in MySQL + root_column_joins = ( + root_column._foreign_key_meta.resolved_references + ) + sub_query = root_column_joins.select(root_column).where(self) + else: + # Use a sub select to find the correct ID. + sub_query = root_column._meta.table.select(root_column).where( + self + ) column_name = column._meta.call_chain[0]._meta.db_column_name return QueryString( diff --git a/piccolo/columns/defaults/base.py b/piccolo/columns/defaults/base.py index fcf46bd85..3deab0cb9 100644 --- a/piccolo/columns/defaults/base.py +++ b/piccolo/columns/defaults/base.py @@ -17,6 +17,11 @@ def postgres(self) -> str: def sqlite(self) -> str: pass + @property + @abstractmethod + def mysql(self) -> str: + pass + @abstractmethod def python(self) -> Any: pass @@ -57,6 +62,27 @@ def get_sqlite_interval_string(self, attributes: list[str]) -> str: return ", ".join(interval_components) + def get_mysql_interval_string(self, attributes: list[str]) -> str: + """ + In MySQL the interval string is different and we should use + CURRENT_TIMESTAMP + INTERVAL 7 DAY + INTERVAL 10 HOUR etc. + but I can't get that to work so I convert to seconds and + use that interval of seconds with the DATE_ADD() function. 
+ """ + interval_components = [] + for attr_name in attributes: + attr = getattr(self, attr_name, None) + if attr is not None: + if attr_name == "days": + attr += attr * 86400 + elif attr_name == "hours": + attr += attr * 3600 + elif attr_name == "minutes": + attr += attr * 60 + interval_components.append(attr) + + return sum(interval_components) + def __repr__(self): return repr_class_instance(self) diff --git a/piccolo/columns/defaults/date.py b/piccolo/columns/defaults/date.py index b802c6764..50eafb09e 100644 --- a/piccolo/columns/defaults/date.py +++ b/piccolo/columns/defaults/date.py @@ -45,6 +45,11 @@ def sqlite(self): interval_string = self.get_sqlite_interval_string(["days"]) return f"(datetime(CURRENT_TIMESTAMP, {interval_string}))" + @property + def mysql(self): + interval_string = self.get_sqlite_interval_string(["days"]) + return f"(DATE(NOW()) + INTERVAL {interval_string}" + def python(self): return ( datetime.datetime.now() + datetime.timedelta(days=self.days) @@ -64,6 +69,10 @@ def cockroach(self): def sqlite(self): return "CURRENT_DATE" + @property + def mysql(self): + return "(DATE(CURRENT_TIMESTAMP))" + def python(self): return datetime.datetime.now().date() @@ -92,6 +101,10 @@ def cockroach(self): def sqlite(self): return f"'{self.date.isoformat()}'" + @property + def mysql(self): + return f"{self.date.isoformat()}" + def python(self): return self.date diff --git a/piccolo/columns/defaults/interval.py b/piccolo/columns/defaults/interval.py index 798a4a050..8ab58678a 100644 --- a/piccolo/columns/defaults/interval.py +++ b/piccolo/columns/defaults/interval.py @@ -62,6 +62,21 @@ def cockroach(self): def sqlite(self): return self.timedelta.total_seconds() + @property + def mysql(self): + value = self.get_mysql_interval_string( + attributes=[ + "weeks", + "days", + "hours", + "minutes", + "seconds", + "milliseconds", + "microseconds", + ] + ) + return f"(SEC_TO_TIME({value}))" + def python(self): return self.timedelta diff --git 
a/piccolo/columns/defaults/time.py b/piccolo/columns/defaults/time.py index a32dcdf47..20dca3d1a 100644 --- a/piccolo/columns/defaults/time.py +++ b/piccolo/columns/defaults/time.py @@ -35,6 +35,13 @@ def sqlite(self): ) return f"(time(CURRENT_TIME, {interval_string}))" + @property + def mysql(self): + interval_string = self.get_postgres_interval_string( + ["hours", "minutes", "seconds"] + ) + return f"(CURRENT_TIME() + INTERVAL {interval_string}))" + def python(self): return ( datetime.datetime.now() @@ -57,6 +64,11 @@ def cockroach(self): def sqlite(self): return "CURRENT_TIME" + @property + def mysql(self): + # must use string literal + return f"'{datetime.datetime.now().time().strftime('%H:%M:%S')}'" + def python(self): return datetime.datetime.now().time() @@ -80,6 +92,10 @@ def cockroach(self): def sqlite(self): return f"'{self.time.isoformat()}'" + @property + def mysql(self): + return f"`{self.time.isoformat()}`" + def python(self): return self.time diff --git a/piccolo/columns/defaults/timestamp.py b/piccolo/columns/defaults/timestamp.py index 73def13ed..e398fb3c7 100644 --- a/piccolo/columns/defaults/timestamp.py +++ b/piccolo/columns/defaults/timestamp.py @@ -38,6 +38,13 @@ def sqlite(self): ) return f"(datetime(CURRENT_TIMESTAMP, {interval_string}))" + @property + def mysql(self): + interval_string = self.get_mysql_interval_string( + ["days", "hours", "minutes", "seconds"] + ) + return f"(DATE_ADD(NOW(), INTERVAL {interval_string} SECOND))" + def python(self): return datetime.datetime.now() + datetime.timedelta( days=self.days, @@ -65,6 +72,10 @@ def cockroach(self): def sqlite(self): return "current_timestamp" + @property + def mysql(self): + return "current_timestamp(6)" + def python(self): return datetime.datetime.now() @@ -114,6 +125,10 @@ def cockroach(self): def sqlite(self): return "'{}'".format(self.datetime.isoformat().replace("T", " ")) + @property + def mysql(self): + return "'{}'".format(self.datetime.isoformat().replace("T", " ")) + def 
python(self): return self.datetime diff --git a/piccolo/columns/defaults/timestamptz.py b/piccolo/columns/defaults/timestamptz.py index 8aa034129..9433dcb87 100644 --- a/piccolo/columns/defaults/timestamptz.py +++ b/piccolo/columns/defaults/timestamptz.py @@ -36,6 +36,10 @@ class TimestamptzNow(TimestampNow): def cockroach(self): return "current_timestamp" + @property + def mysql(self): + return "current_timestamp(6)" + def python(self): return datetime.datetime.now(tz=datetime.timezone.utc) diff --git a/piccolo/columns/defaults/uuid.py b/piccolo/columns/defaults/uuid.py index d58134f18..0bcb43df3 100644 --- a/piccolo/columns/defaults/uuid.py +++ b/piccolo/columns/defaults/uuid.py @@ -41,6 +41,10 @@ def cockroach(self): def sqlite(self): return None + @property + def mysql(self): + return "''" + def python(self): return uuid.uuid4() @@ -75,6 +79,10 @@ def cockroach(self): def sqlite(self): return None + @property + def mysql(self): + return None + def python(self): return uuid7() diff --git a/piccolo/columns/m2m.py b/piccolo/columns/m2m.py index b8d34d46b..c5553ba57 100644 --- a/piccolo/columns/m2m.py +++ b/piccolo/columns/m2m.py @@ -146,6 +146,80 @@ def get_select_string( AS "{m2m_relationship_name} [M2M]" """ ) + elif engine_type == "mysql": + if self.as_list: + column_name = self.columns[0]._meta.db_column_name + inner_select_mysql = f""" + SELECT `inner_{table_2_name}`.`{column_name}` + FROM {m2m_table_name_with_schema} + JOIN {table_1_name_with_schema} AS `inner_{table_1_name}` ON ( + {m2m_table_name_with_schema}.`{fk_1_name}` = `inner_{table_1_name}`.`{table_1_pk_name}` + ) + JOIN {table_2_name_with_schema} AS `inner_{table_2_name}` ON ( + {m2m_table_name_with_schema}.`{fk_2_name}` = `inner_{table_2_name}`.`{table_2_pk_name}` + ) + WHERE {m2m_table_name_with_schema}.`{fk_1_name}` = `{table_1_name}`.`{table_1_pk_name}` + """ # noqa: E501 + + return QueryString( + f""" + ( + SELECT JSON_ARRAYAGG(`inner_table`.`{column_name}`) + FROM ( + {inner_select_mysql} + ) 
AS `inner_table` + ) AS `{m2m_relationship_name}` + """ + ) + elif not self.serialisation_safe: + column_name = table_2_pk_name + inner_select_mysql = f""" + SELECT `inner_{table_2_name}`.`{column_name}` + FROM {m2m_table_name_with_schema} + JOIN {table_1_name_with_schema} AS `inner_{table_1_name}` ON ( + {m2m_table_name_with_schema}.`{fk_1_name}` = `inner_{table_1_name}`.`{table_1_pk_name}` + ) + JOIN {table_2_name_with_schema} AS `inner_{table_2_name}` ON ( + {m2m_table_name_with_schema}.`{fk_2_name}` = `inner_{table_2_name}`.`{table_2_pk_name}` + ) + WHERE {m2m_table_name_with_schema}.`{fk_1_name}` = `{table_1_name}`.`{table_1_pk_name}` + """ # noqa: E501 + + return QueryString( + f""" + ( + SELECT JSON_ARRAYAGG(inner_table.`{column_name}`) + FROM ( + {inner_select_mysql} + ) AS `inner_table` + ) AS `{m2m_relationship_name}` + """ + ) + else: + column_names = ", ".join( + f"inner_{table_2_name}.`{column._meta.db_column_name}`" + for column in self.columns + ) + json_object_fields = ", ".join( + f"'{column._meta.db_column_name}', {m2m_relationship_name}_results.`{column._meta.db_column_name}`" # noqa: E501 + for column in self.columns + ) + + return QueryString( + f""" + ( + SELECT JSON_ARRAYAGG( + JSON_OBJECT( + {json_object_fields} + ) + ) + FROM ( + SELECT {column_names} + FROM {inner_select} + ) AS {m2m_relationship_name}_results + ) AS `{m2m_relationship_name}` + """ + ) else: raise ValueError(f"{engine_type} is an unrecognised engine type") @@ -311,8 +385,12 @@ async def run(self): transaction, or wrapped in a new transaction. """ engine = self.rows[0]._meta.db - async with engine.transaction(): + # MySQL cannot safely do M2M inserts inside transactions. 
+ if engine.engine_type == "mysql": return await self._run() + else: + async with engine.transaction(): + return await self._run() def run_sync(self): return run_sync(self.run()) diff --git a/piccolo/columns/operators/comparison.py b/piccolo/columns/operators/comparison.py index 91b565361..06a5bcd48 100644 --- a/piccolo/columns/operators/comparison.py +++ b/piccolo/columns/operators/comparison.py @@ -68,3 +68,15 @@ class ArrayNotAny(ComparisonOperator): class ArrayAll(ComparisonOperator): template = "{value} = ALL ({name})" + + +class ArrayAllMySQL(ComparisonOperator): + template = "{value} MEMBER OF({name})" + + +class ArrayAnyMySQL(ComparisonOperator): + template = "{value} MEMBER OF({name})" + + +class ArrayNotAnyMySQL(ComparisonOperator): + template = "NOT ({value} MEMBER OF({name}))" diff --git a/piccolo/columns/readable.py b/piccolo/columns/readable.py index cd02c5c91..6094650c4 100644 --- a/piccolo/columns/readable.py +++ b/piccolo/columns/readable.py @@ -46,6 +46,28 @@ def postgres_string(self) -> QueryString: def cockroach_string(self) -> QueryString: return self._get_string(operator="FORMAT") + @property + def mysql_string(self) -> QueryString: + """ + MySQL has no FORMAT for string templates, so we manually + expand placeholders into a CONCAT() expression. 
+ """ + parts: list[str] = [] + template_parts = self.template.split("%s") + num_placeholders = len(template_parts) - 1 + + for i, part in enumerate(template_parts): + # Add literal string part + if part: + parts.append(f"'{part}'") + # Add column if within placeholders + if i < num_placeholders: + col = self.columns[i]._meta.get_full_name(with_alias=False) + parts.append(col) + + concat_expr = f"CONCAT({', '.join(parts)})" + return QueryString(f"{concat_expr} AS {self.output_name}") + def get_select_string( self, engine_type: str, with_alias=True ) -> QueryString: diff --git a/piccolo/engine/__init__.py b/piccolo/engine/__init__.py index eb050f5e6..2afaa1fba 100644 --- a/piccolo/engine/__init__.py +++ b/piccolo/engine/__init__.py @@ -1,6 +1,7 @@ from .base import Engine from .cockroach import CockroachEngine from .finder import engine_finder +from .mysql import MySQLEngine from .postgres import PostgresEngine from .sqlite import SQLiteEngine @@ -9,5 +10,6 @@ "PostgresEngine", "SQLiteEngine", "CockroachEngine", + "MySQLEngine", "engine_finder", ] diff --git a/piccolo/engine/mysql.py b/piccolo/engine/mysql.py new file mode 100644 index 000000000..544cc5d1e --- /dev/null +++ b/piccolo/engine/mysql.py @@ -0,0 +1,535 @@ +from __future__ import annotations + +import contextvars +import json +import uuid +from collections.abc import Sequence +from dataclasses import dataclass +from datetime import datetime, timezone +from typing import TYPE_CHECKING, Any, Mapping, Optional, Union + +from typing_extensions import Self + +from piccolo.engine.base import ( + BaseAtomic, + BaseBatch, + BaseTransaction, + Engine, + validate_savepoint_name, +) +from piccolo.engine.exceptions import TransactionError +from piccolo.query.base import DDL, Query +from piccolo.querystring import QueryString +from piccolo.utils.lazy_loader import LazyLoader +from piccolo.utils.sync import run_sync +from piccolo.utils.warnings import colored_warning + +aiomysql = LazyLoader("aiomysql", globals(), 
"aiomysql") +pymysql = LazyLoader("pymysql", globals(), "pymysql") + +if TYPE_CHECKING: # pragma: no cover + from aiomysql.connection import Connection + from aiomysql.cursors import Cursor + from aiomysql.pool import Pool + + from piccolo.table import Table + + +# converters and formaters +def backticks_format(querystring: str) -> str: + return querystring.replace('"', "`") + + +def convert_list(value: list) -> str: + if isinstance(value, list): + return json.dumps(value) + return value + + +def convert_bool(value: int) -> bool: + return bool(int(value)) if value is not None else None + + +def convert_uuid(value: Any) -> Union[str, uuid.UUID]: + if isinstance(value, (bytes, bytearray)): + value = value.decode() + value = value.strip() + # check if string is uuid string + if len(value) == 36 and value.count("-") == 4: + try: + return uuid.UUID(value) + except ValueError: + return value + return value + + +def parse_mysql_datetime(value: str) -> datetime: + # handle microseconds + if "." in value: + fmt = "%Y-%m-%d %H:%M:%S.%f" + else: + fmt = "%Y-%m-%d %H:%M:%S" + + return datetime.strptime(value, fmt) + + +def convert_timestamptz(value: str) -> datetime: + dt = parse_mysql_datetime(value) + # attach timezone + return dt.replace(tzinfo=timezone.utc) + + +def convert_timestamp(value: str) -> datetime: + return parse_mysql_datetime(value) + + +def converters_map() -> dict[str, Any]: + converters = pymysql.converters.conversions.copy() + custom_decoders: dict[str, Any] = { + pymysql.constants.FIELD_TYPE.STRING: convert_uuid, + pymysql.constants.FIELD_TYPE.VAR_STRING: convert_uuid, + pymysql.constants.FIELD_TYPE.VARCHAR: convert_uuid, + pymysql.constants.FIELD_TYPE.CHAR: convert_uuid, + pymysql.constants.FIELD_TYPE.TINY: convert_bool, + pymysql.constants.FIELD_TYPE.TIMESTAMP: convert_timestamptz, + pymysql.constants.FIELD_TYPE.DATETIME: convert_timestamp, + } + converters.update(custom_decoders) + return converters + + +@dataclass +class AsyncBatch(BaseBatch): + 
connection: Connection
+    query: Query
+    batch_size: int
+
+    _cursor: Optional[Cursor] = None
+
+    @property
+    def cursor(self) -> Cursor:
+        if not self._cursor:
+            raise ValueError("_cursor not set")
+        return self._cursor
+
+    async def next(self) -> list[dict]:
+        rows = await self.cursor.fetchmany(self.batch_size)
+        if not rows:
+            return []
+        columns = [desc[0] for desc in self.cursor.description]
+        result = [dict(zip(columns, row)) for row in rows]
+        return await self.query._process_results(result)
+
+    def __aiter__(self) -> Self:
+        return self
+
+    async def __anext__(self) -> list[dict]:
+        response = await self.next()
+        if not response:
+            raise StopAsyncIteration()
+        return response
+
+    async def __aenter__(self) -> Self:
+        querystring = self.query.querystrings[0]
+        query, args = querystring.compile_string()
+
+        # Cursor must stay open so ``next`` can fetch; closed in ``__aexit__``.
+        self._cursor = await self.connection.cursor()
+        await self._cursor.execute(backticks_format(query), args)
+        return self
+
+    async def __aexit__(self, exception_type, exception, traceback):
+        await self._cursor.close()
+        await self.connection.ensure_closed()
+        return exception is not None
+
+
+###############################################################################
+
+
+class Atomic(BaseAtomic):
+    __slots__ = ("engine", "queries")
+
+    def __init__(self, engine: MySQLEngine):
+        self.engine = engine
+        self.queries: list[Union[Query, DDL]] = []
+
+    def add(self, *query: Union[Query, DDL]):
+        self.queries += list(query)
+
+    async def run(self):
+        from piccolo.query.methods.objects import Create, GetOrCreate
+
+        try:
+            async with self.engine.transaction():
+                for query in self.queries:
+                    if isinstance(query, (Query, DDL, Create, GetOrCreate)):
+                        await query.run()
+                    else:
+                        raise ValueError("Unrecognized query type")
+            self.queries = []
+        except Exception as exception:
+            self.queries = []
+            raise exception from exception
+
+    def run_sync(self):
+        return run_sync(self.run())
+
+    def __await__(self):
+        return self.run().__await__()
+
+
+############################################################################### + + +class Savepoint: + def __init__(self, name: str, transaction: MySQLTransaction): + self.name = name + self.transaction = transaction + + async def rollback_to(self): + validate_savepoint_name(self.name) + async with self.transaction.connection.cursor() as cursor: + await cursor.execute(f"ROLLBACK TO SAVEPOINT `{self.name}`") + + async def release(self): + validate_savepoint_name(self.name) + async with self.transaction.connection.cursor() as cursor: + await cursor.execute(f"RELEASE SAVEPOINT `{self.name}`") + + +class MySQLTransaction(BaseTransaction): + __slots__ = ( + "engine", + "connection", + "_savepoint_id", + "_parent", + "_committed", + "_rolled_back", + "context", + ) + + def __init__(self, engine: MySQLEngine, allow_nested: bool = True): + self.engine = engine + current_transaction = self.engine.current_transaction.get() + + self._savepoint_id = 0 + self._parent = None + self._committed = False + self._rolled_back = False + + if current_transaction: + if allow_nested: + self._parent = current_transaction + else: + raise TransactionError("Nested transactions not allowed.") + + async def __aenter__(self) -> MySQLTransaction: + if self._parent: + return self._parent + + self.connection = await self.get_connection() + await self.begin() + self.context = self.engine.current_transaction.set(self) + return self + + async def get_connection(self): + if self.engine.pool: + return await self.engine.pool.acquire() + else: + return await self.engine.get_new_connection() + + async def begin(self): + await self.connection.begin() + + async def commit(self): + await self.connection.commit() + self._committed = True + + async def rollback(self): + await self.connection.rollback() + self._rolled_back = True + + async def rollback_to(self, savepoint_name: str): + await Savepoint(name=savepoint_name, transaction=self).rollback_to() + + 
######################################################################### + + async def savepoint(self, name: Optional[str] = None) -> Savepoint: + self._savepoint_id += 1 + name = name or f"savepoint_{self._savepoint_id}" + validate_savepoint_name(name) + async with self.connection.cursor() as cursor: + await cursor.execute(f"SAVEPOINT `{name}`") + return Savepoint(name=name, transaction=self) + + ########################################################################## + + async def __aexit__(self, exception_type, exception, traceback) -> bool: + if self._parent: + return exception is None + + if exception: + if not self._rolled_back: + await self.rollback() + else: + if not self._committed and not self._rolled_back: + await self.commit() + + if self.engine.pool: + self.engine.pool.release(self.connection) + else: + self.connection.close() + + self.engine.current_transaction.reset(self.context) + return exception is None + + +########################################################################## + + +class MySQLEngine(Engine[MySQLTransaction]): + """ + Used to connect to MySQL. + + :param config: + The config dictionary is passed to the underlying database adapter, + aiomysql. Common arguments you're likely to need are: + + * host + * port + * user + * password + * db + + For example, ``{'host': 'localhost', 'port': 3306}``. + + :param log_queries: + If ``True``, all SQL and DDL statements are printed out before being + run. Useful for debugging. + + :param log_responses: + If ``True``, the raw response from each query is printed out. Useful + for debugging. + + :param extra_nodes: + For now, just for compatibility. 
+ + """ + + __slots__ = ("config", "extra_nodes", "pool") + + def __init__( + self, + config: dict[str, Any], + log_queries: bool = False, + log_responses: bool = False, + extra_nodes: Optional[Mapping[str, MySQLEngine]] = None, + ): + if extra_nodes is None: + extra_nodes = {} + + self.config = config + self.log_queries = log_queries + self.log_responses = log_responses + self.extra_nodes = extra_nodes + self.pool: Optional[Pool] = None + db_name = config.get("db", "unknown") + self.current_transaction = contextvars.ContextVar( + f"mysql_current_transaction_{db_name}", default=None + ) + # converters + config["conv"] = converters_map() + + super().__init__( + engine_type="mysql", + log_queries=log_queries, + log_responses=log_responses, + min_version_number=8.4, + ) + + @staticmethod + def _parse_raw_version_string(version_string: str) -> float: + version_segment = version_string.split("-")[0] + major, minor = version_segment.split(".")[:2] + return float(f"{major}.{minor}") + + async def get_version(self) -> float: + try: + response: Sequence[dict] = await self._run_in_new_connection( + "SELECT VERSION() as server_version" + ) + except ConnectionRefusedError as exception: + colored_warning(f"Unable to connect to database - {exception}") + return 0.0 + else: + version_string = response[0]["server_version"] + return self._parse_raw_version_string( + version_string=version_string + ) + + def get_version_sync(self) -> float: + return run_sync(self.get_version()) + + async def prep_database(self): + # Some globals for safer MySQL behavior + await self._run_in_new_connection( + """ + SET GLOBAL sql_mode = 'STRICT_TRANS_TABLES,ERROR_FOR_DIVISION_BY_ZERO,NO_ENGINE_SUBSTITUTION,NO_ZERO_DATE,NO_ZERO_IN_DATE'; + SET GLOBAL foreign_key_checks = 1; + SET GLOBAL innodb_strict_mode = ON; + SET GLOBAL character_set_server = 'utf8mb4'; + SET GLOBAL collation_server = 'utf8mb4_unicode_ci'; + """ # noqa: E501 + ) + + async def start_connection_pool(self, **kwargs): + if self.pool: 
+ colored_warning( + "A pool already exists - close it first if you want to create " + "a new pool.", + ) + else: + config = dict(self.config) + config.update(**kwargs) + self.pool = await aiomysql.create_pool(**config) + + async def close_connection_pool(self): + if self.pool: + self.pool.close() + await self.pool.wait_closed() + self.pool = None + else: + colored_warning("No pool is running.") + + ########################################################################## + + async def get_new_connection(self) -> Connection: + connection = await aiomysql.connect(**self.config) + # Enable autocommit by default + await connection.autocommit(True) + return connection + + ######################################################################### + + async def _get_inserted_pk(self, cursor, table: type[Table]) -> Any: + """ + Retrieve the inserted auto-increment primary keys for MySQL. + """ + initial_id = cursor.lastrowid + count = cursor.rowcount + ids = list(range(initial_id, initial_id + count)) + return ids + + async def _run_in_pool(self, query: str, args: list[Any] = []): + if args is None: + args = [] + if not self.pool: + raise ValueError("A pool isn't currently running.") + + async with self.pool.acquire() as connection: + async with connection.cursor() as cursor: + await cursor.execute(query, args) + rows = await cursor.fetchall() + columns = ( + [desc[0] for desc in cursor.description] + if cursor.description + else [] + ) + await connection.autocommit(True) + return [dict(zip(columns, row)) for row in rows] + + async def _run_in_new_connection( + self, + query: str, + args: list[Any] = [], + query_type: str = "generic", + table: Optional[type[Table]] = None, + ): + if args is None: + args = [] + connection = await self.get_new_connection() + # convert lists + params = tuple(convert_list(arg) for arg in args) + try: + async with connection.cursor() as cursor: + await cursor.execute(query, params) + if query_type == "insert": + # We can't use the RETURNING 
clause in MySQL. + assert table is not None + ids = [] + for pk in await self._get_inserted_pk(cursor, table): + ids.append( + {table._meta.primary_key._meta.db_column_name: pk} + ) + return ids + rows = await cursor.fetchall() + columns = ( + [desc[0] for desc in cursor.description] + if cursor.description + else [] + ) + return [dict(zip(columns, row)) for row in rows] + finally: + connection.close() + + async def run_querystring( + self, querystring: QueryString, in_pool: bool = True + ): + query, query_args = querystring.compile_string( + engine_type=self.engine_type + ) + query_id = self.get_query_id() + + if self.log_queries: + self.print_query(query_id=query_id, query=query) + + current_transaction = self.current_transaction.get() + if current_transaction: + async with current_transaction.connection.cursor() as cursor: + await cursor.execute(backticks_format(query), query_args) + rows = await cursor.fetchall() + elif in_pool and self.pool: + rows = await self._run_in_pool( + query=backticks_format(query), + args=query_args, + ) + else: + rows = await self._run_in_new_connection( + query=backticks_format(query), + args=query_args, + query_type=querystring.query_type, + table=querystring.table, + ) + + if self.log_responses: + self.print_response(query_id=query_id, response=rows) + + return rows + + async def run_ddl(self, ddl: str, in_pool: bool = True): + query_id = self.get_query_id() + if self.log_queries: + self.print_query(query_id=query_id, query=ddl) + + current_transaction = self.current_transaction.get() + if current_transaction: + async with current_transaction.connection.cursor() as cursor: + await cursor.execute(backticks_format(ddl)) + elif in_pool and self.pool: + await self._run_in_pool(backticks_format(ddl)) + else: + await self._run_in_new_connection(backticks_format(ddl)) + + async def batch( + self, query: Query, batch_size: int = 100, node: Optional[str] = None + ) -> AsyncBatch: + engine: Any = self.extra_nodes.get(node) if node else self 
+ conn = await engine.get_new_connection() + return AsyncBatch(connection=conn, query=query, batch_size=batch_size) + + def atomic(self) -> Atomic: + return Atomic(engine=self) + + def transaction(self, allow_nested: bool = True) -> MySQLTransaction: + return MySQLTransaction(engine=self, allow_nested=allow_nested) diff --git a/piccolo/query/base.py b/piccolo/query/base.py index d45d885dc..dec807f1a 100644 --- a/piccolo/query/base.py +++ b/piccolo/query/base.py @@ -239,6 +239,10 @@ def postgres_querystrings(self) -> Sequence[QueryString]: def cockroach_querystrings(self) -> Sequence[QueryString]: raise NotImplementedError + @property + def mysql_querystrings(self) -> Sequence[QueryString]: + raise NotImplementedError + @property def default_querystrings(self) -> Sequence[QueryString]: raise NotImplementedError @@ -267,6 +271,11 @@ def querystrings(self) -> Sequence[QueryString]: return self.cockroach_querystrings except NotImplementedError: return self.default_querystrings + elif engine_type == "mysql": + try: + return self.mysql_querystrings + except NotImplementedError: + return self.default_querystrings else: raise Exception( f"No querystring found for the {engine_type} engine." @@ -391,6 +400,10 @@ def postgres_ddl(self) -> Sequence[str]: def cockroach_ddl(self) -> Sequence[str]: raise NotImplementedError + @property + def mysql_ddl(self) -> Sequence[str]: + raise NotImplementedError + @property def default_ddl(self) -> Sequence[str]: raise NotImplementedError @@ -416,6 +429,11 @@ def ddl(self) -> Sequence[str]: return self.cockroach_ddl except NotImplementedError: return self.default_ddl + elif engine_type == "mysql": + try: + return self.mysql_ddl + except NotImplementedError: + return self.default_ddl else: raise Exception( f"No querystring found for the {engine_type} engine." 
diff --git a/piccolo/query/constraints.py b/piccolo/query/constraints.py
index a5859c100..4d03a4015 100644
--- a/piccolo/query/constraints.py
+++ b/piccolo/query/constraints.py
@@ -94,3 +94,86 @@
         on_delete=OnDelete(constraints[0]["delete_rule"]),
         on_update=OnUpdate(constraints[0]["update_rule"]),
     )
+
+
+async def get_fk_constraint_name_mysql(column: ForeignKey) -> Optional[str]:
+    """
+    Checks what the foreign key constraint is called in the MySQL
+    database.
+    """
+
+    table = column._meta.table
+
+    if table._meta.db.engine_type == "sqlite":
+        # TODO - add the query for SQLite
+        raise ValueError("SQLite isn't currently supported.")
+
+    table_name = table._meta.tablename
+    column_name = column._meta.db_column_name
+
+    constraints = await table.raw(
+        """
+        SELECT
+            kcu.CONSTRAINT_NAME,
+            kcu.TABLE_NAME,
+            kcu.COLUMN_NAME,
+            rc.UPDATE_RULE,
+            rc.DELETE_RULE
+        FROM
+            information_schema.KEY_COLUMN_USAGE AS kcu
+        JOIN
+            information_schema.REFERENTIAL_CONSTRAINTS AS rc
+            ON kcu.CONSTRAINT_NAME = rc.CONSTRAINT_NAME
+            AND kcu.CONSTRAINT_SCHEMA = rc.CONSTRAINT_SCHEMA
+        WHERE
+            kcu.TABLE_SCHEMA = DATABASE() AND
+            kcu.TABLE_NAME = {} AND
+            kcu.COLUMN_NAME = {};
+        """,
+        table_name,
+        column_name,
+    )
+    # Rows come back as dicts (see MySQLEngine), so access by column name.
+    return constraints[0]["CONSTRAINT_NAME"] if constraints else None
+
+
+async def get_fk_constraint_rules_mysql(column: ForeignKey) -> ConstraintRules:
+    """
+    Checks the constraint rules for this foreign key in the MySQL database.
+ """ + table = column._meta.table + + if table._meta.db.engine_type == "sqlite": + # TODO - add the query for SQLite + raise ValueError("SQLite isn't currently supported.") + + table_name = table._meta.tablename + column_name = column._meta.db_column_name + + constraints = await table.raw( + """ + SELECT + kcu.CONSTRAINT_NAME, + kcu.TABLE_NAME, + kcu.COLUMN_NAME, + rc.UPDATE_RULE, + rc.DELETE_RULE + FROM + information_schema.KEY_COLUMN_USAGE AS kcu + INNER JOIN + information_schema.REFERENTIAL_CONSTRAINTS AS rc + ON kcu.CONSTRAINT_NAME = rc.CONSTRAINT_NAME + AND kcu.CONSTRAINT_SCHEMA = rc.CONSTRAINT_SCHEMA + WHERE + kcu.TABLE_SCHEMA = DATABASE() AND + kcu.TABLE_NAME = {} AND + kcu.COLUMN_NAME = {}; + """, + table_name, + column_name, + ) + + return ConstraintRules( + on_delete=OnDelete(constraints[0]["DELETE_RULE"]), + on_update=OnUpdate(constraints[0]["UPDATE_RULE"]), + ) diff --git a/piccolo/query/functions/aggregate.py b/piccolo/query/functions/aggregate.py index 499d56007..35b2a186c 100644 --- a/piccolo/query/functions/aggregate.py +++ b/piccolo/query/functions/aggregate.py @@ -88,6 +88,10 @@ def __init__( else: column_names = ", ".join("{}" for _ in distinct) + if engine_type == "mysql": + return super().__init__( + f"COUNT(DISTINCT {column_names})", *distinct, alias=alias + ) return super().__init__( f"COUNT(DISTINCT({column_names}))", *distinct, alias=alias ) diff --git a/piccolo/query/functions/string.py b/piccolo/query/functions/string.py index 3aa4a5d45..cb067f54f 100644 --- a/piccolo/query/functions/string.py +++ b/piccolo/query/functions/string.py @@ -98,7 +98,12 @@ def __init__( isinstance(arg, Column) and not isinstance(arg, (Varchar, Text)) ): - processed_args.append(QueryString("CAST({} AS TEXT)", arg)) + cast_identifier = ( + "CHAR" if self.engine_type() == "mysql" else "TEXT" + ) + processed_args.append( + QueryString("CAST({} AS " + f"{cast_identifier})", arg) + ) else: processed_args.append(arg) @@ -106,6 +111,12 @@ def __init__( 
f"CONCAT({placeholders})", *processed_args, alias=alias ) + def engine_type(self): + from piccolo.engine.finder import engine_finder + + engine = engine_finder() + return engine.engine_type if engine is not None else None + __all__ = ( "Length", diff --git a/piccolo/query/functions/type_conversion.py b/piccolo/query/functions/type_conversion.py index 1bbb44f72..679604dd1 100644 --- a/piccolo/query/functions/type_conversion.py +++ b/piccolo/query/functions/type_conversion.py @@ -85,6 +85,13 @@ def __init__( or identifier._meta.get_default_alias() ) + # for MySQL we need to change as_type_string + if as_type._meta.table._meta.db.engine_type == "mysql": + if as_type_string == "INTEGER": + as_type_string = "SIGNED" + else: + as_type_string = "CHAR" + ####################################################################### super().__init__( diff --git a/piccolo/query/methods/alter.py b/piccolo/query/methods/alter.py index 35774acd3..0d727db38 100644 --- a/piccolo/query/methods/alter.py +++ b/piccolo/query/methods/alter.py @@ -76,6 +76,24 @@ def ddl(self) -> str: return f'RENAME COLUMN "{self.column_name}" TO "{self.new_name}"' +@dataclass +class RenameColumnMySQL(AlterColumnStatement): + __slots__ = ("new_name",) + + new_name: str + + @property + def ddl(self) -> str: + if isinstance(self.column, str): + raise ValueError("MySQL requires a column instance for renaming.") + column_type = self.column.column_type + null_sql = "NULL" if self.column._meta.null else "NOT NULL" + return ( + f"CHANGE `{self.column_name}` `{self.new_name}` " + f"{column_type} {null_sql}" + ) + + @dataclass class DropColumn(AlterColumnStatement): @property @@ -96,6 +114,19 @@ def ddl(self) -> str: return f"ADD COLUMN {self.column.ddl}" +@dataclass +class AddColumnMySQL(AlterColumnStatement): + __slots__ = ("name",) + + column: Column + name: str + + @property + def ddl(self) -> str: + self.column._meta.name = self.name + return f"ADD COLUMN {self.column.ddl} {self.column.column_type}" + + 
@dataclass class DropDefault(AlterColumnStatement): @property @@ -131,6 +162,24 @@ def ddl(self) -> str: return query +@dataclass +class SetColumnTypeMySQL(AlterStatement): + + old_column: Column + new_column: Column + + @property + def ddl(self) -> str: + if self.new_column._meta._table is None: + self.new_column._meta._table = self.old_column._meta.table + + column_name = self.old_column._meta.db_column_name + column_type = self.new_column.column_type + query = f"MODIFY `{column_name}` {column_type}" + + return query + + @dataclass class SetDefault(AlterColumnStatement): __slots__ = ("value",) @@ -144,6 +193,24 @@ def ddl(self) -> str: return f'ALTER COLUMN "{self.column_name}" SET DEFAULT {sql_value}' +@dataclass +class SetDefaultMySQL(AlterColumnStatement): + __slots__ = ("value",) + + column: Column + value: Any + + @property + def ddl(self) -> str: + if self.column.column_type in ("TEXT", "JSON", "BLOB"): + raise ValueError( + "MySQL does not support default value in alter " + "statement for TEXT, JSON and BLOB columns" + ) + sql_value = self.column.get_sql_value(self.value) + return f'ALTER COLUMN "{self.column_name}" SET DEFAULT {sql_value}' + + @dataclass class SetUnique(AlterColumnStatement): __slots__ = ("boolean",) @@ -179,6 +246,25 @@ def ddl(self) -> str: return f'ALTER COLUMN "{self.column_name}" SET NOT NULL' +@dataclass +class SetNullMySQL(AlterColumnStatement): + __slots__ = ("boolean",) + + boolean: bool + + @property + def ddl(self) -> str: + if isinstance(self.column, str): + raise ValueError( + "MySQL requires a column instance for setting null." 
+ ) + column_type = self.column.column_type + if self.boolean: + return f"MODIFY `{self.column_name}` {column_type} NULL" + else: + return f"MODIFY `{self.column_name}` {column_type} NOT NULL" + + @dataclass class SetLength(AlterColumnStatement): __slots__ = ("length",) @@ -190,6 +276,17 @@ def ddl(self) -> str: return f'ALTER COLUMN "{self.column_name}" TYPE VARCHAR({self.length})' +@dataclass +class SetLengthMySQL(AlterColumnStatement): + __slots__ = ("length",) + + length: int + + @property + def ddl(self) -> str: + return f'MODIFY "{self.column_name}" VARCHAR({self.length})' + + @dataclass class DropConstraint(AlterStatement): __slots__ = ("constraint_name",) @@ -201,6 +298,17 @@ def ddl(self) -> str: return f"DROP CONSTRAINT IF EXISTS {self.constraint_name}" +@dataclass +class DropConstraintMySQL(AlterStatement): + __slots__ = ("constraint_name",) + + constraint_name: str + + @property + def ddl(self) -> str: + return f"DROP FOREIGN KEY {self.constraint_name}" + + @dataclass class AddForeignKeyConstraint(AlterStatement): __slots__ = ( @@ -253,6 +361,26 @@ def ddl(self) -> str: ) +@dataclass +class SetDigitsMySQL(AlterColumnStatement): + __slots__ = ("digits", "column_type") + + digits: Optional[tuple[int, int]] + column_type: str + + @property + def ddl(self) -> str: + if self.digits is None: + return f'MODIFY "{self.column_name}" {self.column_type}' + + precision = self.digits[0] + scale = self.digits[1] + return ( + f'MODIFY "{self.column_name}" ' + f"{self.column_type}({precision}, {scale})" + ) + + @dataclass class SetSchema(AlterStatement): __slots__ = ("schema_name",) @@ -309,17 +437,21 @@ def __init__(self, table: type[Table], **kwargs): super().__init__(table, **kwargs) self._add_foreign_key_constraint: list[AddForeignKeyConstraint] = [] self._add: list[AddColumn] = [] - self._drop_constraint: list[DropConstraint] = [] + self._drop_constraint: list[ + Union[DropConstraint, DropConstraintMySQL] + ] = [] self._drop_default: list[DropDefault] = [] 
self._drop_table: Optional[DropTable] = None self._drop: list[DropColumn] = [] - self._rename_columns: list[RenameColumn] = [] + self._rename_columns: list[Union[RenameColumn, RenameColumnMySQL]] = [] self._rename_table: list[RenameTable] = [] - self._set_column_type: list[SetColumnType] = [] - self._set_default: list[SetDefault] = [] - self._set_digits: list[SetDigits] = [] - self._set_length: list[SetLength] = [] - self._set_null: list[SetNull] = [] + self._set_column_type: list[ + Union[SetColumnType, SetColumnTypeMySQL] + ] = [] + self._set_default: list[Union[SetDefault, SetDefaultMySQL]] = [] + self._set_digits: list[Union[SetDigits, SetDigitsMySQL]] = [] + self._set_length: list[Union[SetLength, SetLengthMySQL]] = [] + self._set_null: list[Union[SetNull, SetNullMySQL]] = [] self._set_schema: list[SetSchema] = [] self._set_unique: list[SetUnique] = [] self._rename_constraint: list[RenameConstraint] = [] @@ -419,7 +551,10 @@ def rename_column( >>> await Band.alter().rename_column('popularity', 'rating') """ - self._rename_columns.append(RenameColumn(column, new_name)) + if self.engine_type == "mysql": + self._rename_columns.append(RenameColumnMySQL(column, new_name)) + else: + self._rename_columns.append(RenameColumn(column, new_name)) return self def set_column_type( @@ -440,13 +575,21 @@ def set_column_type( ``'name::integer'``. 
""" - self._set_column_type.append( - SetColumnType( - old_column=old_column, - new_column=new_column, - using_expression=using_expression, + if self.engine_type == "mysql": + self._set_column_type.append( + SetColumnTypeMySQL( + old_column=old_column, + new_column=new_column, + ) + ) + else: + self._set_column_type.append( + SetColumnType( + old_column=old_column, + new_column=new_column, + using_expression=using_expression, + ) ) - ) return self def set_default(self, column: Column, value: Any) -> Alter: @@ -456,7 +599,12 @@ def set_default(self, column: Column, value: Any) -> Alter: >>> await Band.alter().set_default(Band.popularity, 0) """ - self._set_default.append(SetDefault(column=column, value=value)) + if self.engine_type == "mysql": + self._set_default.append( + SetDefaultMySQL(column=column, value=value) + ) + else: + self._set_default.append(SetDefault(column=column, value=value)) return self def set_null( @@ -468,11 +616,17 @@ def set_null( # Specify the column using a `Column` instance: >>> await Band.alter().set_null(Band.name, True) - # Or using a string: + # Or using a string in Postgres: >>> await Band.alter().set_null('name', True) + # Can't use a string because MySQL requires + # column instance + """ - self._set_null.append(SetNull(column, boolean)) + if self.engine_type == "mysql": + self._set_null.append(SetNullMySQL(column, boolean)) + else: + self._set_null.append(SetNull(column, boolean)) return self def set_unique( @@ -516,7 +670,10 @@ def set_length(self, column: Union[str, Varchar], length: int) -> Alter: "Only Varchar columns can have their length changed." 
) - self._set_length.append(SetLength(column, length)) + if self.engine_type == "mysql": + self._set_length.append(SetLengthMySQL(column, length)) + else: + self._set_length.append(SetLength(column, length)) return self def _get_constraint_name(self, column: Union[str, ForeignKey]) -> str: @@ -525,18 +682,29 @@ def _get_constraint_name(self, column: Union[str, ForeignKey]) -> str: return f"{tablename}_{column_name}_fkey" def drop_constraint(self, constraint_name: str) -> Alter: - self._drop_constraint.append( - DropConstraint(constraint_name=constraint_name) - ) + if self.engine_type == "mysql": + self._drop_constraint.append( + DropConstraintMySQL(constraint_name=constraint_name) + ) + else: + self._drop_constraint.append( + DropConstraint(constraint_name=constraint_name) + ) return self def drop_foreign_key_constraint( self, column: Union[str, ForeignKey] ) -> Alter: - constraint_name = self._get_constraint_name(column=column) - self._drop_constraint.append( - DropConstraint(constraint_name=constraint_name) - ) + if self.engine_type == "mysql": + constraint_name = self._get_constraint_name(column=column) + self._drop_constraint.append( + DropConstraintMySQL(constraint_name=constraint_name) + ) + else: + constraint_name = self._get_constraint_name(column=column) + self._drop_constraint.append( + DropConstraint(constraint_name=constraint_name) + ) return self def add_foreign_key_constraint( @@ -603,13 +771,22 @@ def set_digits( if isinstance(column, Numeric) else "NUMERIC" ) - self._set_digits.append( - SetDigits( - digits=digits, - column=column, - column_type=column_type, + if self.engine_type == "mysql": + self._set_digits.append( + SetDigitsMySQL( + digits=digits, + column=column, + column_type=column_type, + ) + ) + else: + self._set_digits.append( + SetDigits( + digits=digits, + column=column, + column_type=column_type, + ) ) - ) return self def set_schema(self, schema_name: str) -> Alter: diff --git a/piccolo/query/methods/create_index.py 
b/piccolo/query/methods/create_index.py index 64ae4b4d8..5f074bc3d 100644 --- a/piccolo/query/methods/create_index.py +++ b/piccolo/query/methods/create_index.py @@ -74,3 +74,15 @@ def sqlite_ddl(self) -> Sequence[str]: f"({column_names_str})" ) ] + + @property + def mysql_ddl(self) -> Sequence[str]: + column_names = self.column_names + index_name = self.table._get_index_name(column_names) + tablename = self.table._meta.get_formatted_tablename() + + column_names_str = ", ".join([f"`{i}`" for i in self.column_names]) + prefix = "CREATE INDEX" + return [ + f"{prefix} {index_name} ON {tablename} " f"({column_names_str})" + ] diff --git a/piccolo/query/methods/drop_index.py b/piccolo/query/methods/drop_index.py index 1b2d9f082..a8813c136 100644 --- a/piccolo/query/methods/drop_index.py +++ b/piccolo/query/methods/drop_index.py @@ -37,3 +37,15 @@ def default_querystrings(self) -> Sequence[QueryString]: if self.if_exists: query += " IF EXISTS" return [QueryString(f"{query} {index_name}")] + + @property + def mysql_querystrings(self) -> Sequence[QueryString]: + column_names = self.column_names + index_name = self.table._get_index_name(column_names) + query = "DROP INDEX" + return [ + QueryString( + f"ALTER TABLE {self.table._meta.tablename} " + f"{query} {index_name}" + ) + ] diff --git a/piccolo/query/methods/indexes.py b/piccolo/query/methods/indexes.py index c5c8b8be7..83ef4d33b 100644 --- a/piccolo/query/methods/indexes.py +++ b/piccolo/query/methods/indexes.py @@ -30,5 +30,17 @@ def sqlite_querystrings(self) -> Sequence[QueryString]: tablename = self.table._meta.tablename return [QueryString(f"PRAGMA index_list({tablename})")] + @property + def mysql_querystrings(self) -> Sequence[QueryString]: + return [ + QueryString( + "SELECT DISTINCT INDEX_NAME AS name " + "FROM INFORMATION_SCHEMA.STATISTICS " + "WHERE TABLE_SCHEMA = DATABASE() " + "AND TABLE_NAME = {}", + self.table._meta.get_formatted_tablename(quoted=False), + ) + ] + async def response_handler(self, 
response): return [i["name"] for i in response] diff --git a/piccolo/query/methods/insert.py b/piccolo/query/methods/insert.py index f9bce9516..9eccc8229 100644 --- a/piccolo/query/methods/insert.py +++ b/piccolo/query/methods/insert.py @@ -69,13 +69,13 @@ def on_conflict( ) if ( - self.engine_type in ("postgres", "cockroach") + self.engine_type in ("postgres", "cockroach", "mysql") and len(self.on_conflict_delegate._on_conflict.on_conflict_items) == 1 ): raise NotImplementedError( - "Postgres and Cockroach only support a single ON CONFLICT " - "clause." + "Postgres, Cockroach and MySQL only support a single " + "ON CONFLICT clause." ) self.on_conflict_delegate.on_conflict( @@ -92,19 +92,45 @@ def _raw_response_callback(self, results: list): """ Assign the ids of the created rows to the model instances. """ - for index, row in enumerate(results): - table_instance: Table = self.add_delegate._add[index] - setattr( - table_instance, - self.table._meta.primary_key._meta.name, - row.get( - self.table._meta.primary_key._meta.db_column_name, None - ), - ) - table_instance._exists_in_db = True + try: + for index, row in enumerate(results): + table_instance: Table = self.add_delegate._add[index] + setattr( + table_instance, + self.table._meta.primary_key._meta.name, + row.get( + self.table._meta.primary_key._meta.db_column_name, None + ), + ) + table_instance._exists_in_db = True + except IndexError: + ... 
+ + @property + def mysql_insert_ignore(self) -> bool: + # detect DO_NOTHING action for MySQL + for item in self.on_conflict_delegate._on_conflict.on_conflict_items: + if item.action == OnConflictAction.do_nothing: + return True + return False @property def default_querystrings(self) -> Sequence[QueryString]: + if self.engine_type == "mysql" and self.mysql_insert_ignore: + base = f"INSERT IGNORE INTO {self.table._meta.get_formatted_tablename()}" # noqa: E501 + columns = ",".join( + f'"{i._meta.db_column_name}"' for i in self.table._meta.columns + ) + values = ",".join("{}" for _ in self.add_delegate._add) + query = f"{base} ({columns}) VALUES {values}" + querystring = QueryString( + query, + *[i.querystring for i in self.add_delegate._add], + query_type="insert", + table=self.table, + ) + return [querystring] + base = f"INSERT INTO {self.table._meta.get_formatted_tablename()}" columns = ",".join( f'"{i._meta.db_column_name}"' for i in self.table._meta.columns diff --git a/piccolo/query/methods/select.py b/piccolo/query/methods/select.py index 4266c3c7d..d974e488a 100644 --- a/piccolo/query/methods/select.py +++ b/piccolo/query/methods/select.py @@ -363,7 +363,7 @@ async def response_handler(self, response): m2m_select, ) - elif self.engine_type in ("postgres", "cockroach"): + elif self.engine_type in ("postgres", "cockroach", "mysql"): if m2m_select.as_list: # We get the data back as an array, and can just return it # unless it's JSON. 
@@ -374,6 +374,13 @@ async def response_handler(self, response): for row in response: data = row[m2m_name] row[m2m_name] = [load_json(i) for i in data] + if self.engine_type == "mysql": + # for MySQL + for row in response: + data = row[m2m_name] + row[m2m_name] = ( + load_json(data) if data is not None else [] + ) elif m2m_select.serialisation_safe: # If the columns requested can be safely serialised, they # are returned as a JSON string, so we need to deserialise diff --git a/piccolo/query/methods/table_exists.py b/piccolo/query/methods/table_exists.py index 2d90059cd..cf276dac2 100644 --- a/piccolo/query/methods/table_exists.py +++ b/piccolo/query/methods/table_exists.py @@ -43,3 +43,14 @@ def postgres_querystrings(self) -> Sequence[QueryString]: @property def cockroach_querystrings(self) -> Sequence[QueryString]: return self.postgres_querystrings + + @property + def mysql_querystrings(self) -> Sequence[QueryString]: + query = QueryString( + "SELECT EXISTS(" + "SELECT 1 FROM INFORMATION_SCHEMA.TABLES " + "WHERE TABLE_SCHEMA = DATABASE() AND TABLE_NAME = {}" + ") AS `exists`", + self.table._meta.tablename, + ) + return [query] diff --git a/piccolo/query/mixins.py b/piccolo/query/mixins.py index 178d793bf..fbd5a311c 100644 --- a/piccolo/query/mixins.py +++ b/piccolo/query/mixins.py @@ -664,13 +664,21 @@ def to_string(value) -> str: @property def action_string(self) -> QueryString: + from piccolo.engine.finder import engine_finder + + engine = engine_finder() action = self.action + if isinstance(action, OnConflictAction): if action == OnConflictAction.do_nothing: return QueryString(OnConflictAction.do_nothing.value) elif action == OnConflictAction.do_update: values = [] - query = f"{OnConflictAction.do_update.value} SET" + assert engine + if engine.engine_type == "mysql": + query = "" + else: + query = f"{OnConflictAction.do_update.value} SET" if not self.values: raise ValueError("No values specified for `on conflict`") @@ -678,10 +686,16 @@ def action_string(self) 
-> QueryString: for value in self.values: if isinstance(value, Column): column_name = value._meta.db_column_name - query += f' "{column_name}"=EXCLUDED."{column_name}",' + if value._meta.engine_type == "mysql": + query += ( + f' `{column_name}` = VALUES("{column_name}"),' + ) + else: + query += ( + f' "{column_name}"=EXCLUDED."{column_name}",' + ) elif isinstance(value, tuple): - column = value[0] - value_ = value[1] + column, value_ = value if isinstance(column, Column): column_name = column._meta.db_column_name else: @@ -696,9 +710,23 @@ def action_string(self) -> QueryString: @property def querystring(self) -> QueryString: - query = " ON CONFLICT" + from piccolo.engine.finder import engine_finder + + engine = engine_finder() values = [] + # MySQL on_conflict has different syntax + assert engine + if engine.engine_type == "mysql": + query = " ON DUPLICATE KEY UPDATE " + + if self.action: + values.append(self.action_string) + + return QueryString(query, *values) + + query = " ON CONFLICT" + if self.target: query += f" {self.target_string}" @@ -761,6 +789,11 @@ def on_conflict( values: Optional[Sequence[Union[Column, tuple[Column, Any]]]] = None, where: Optional[Combinable] = None, ): + from piccolo.engine import engine_finder + + engine = engine_finder() + assert engine + action_: OnConflictAction if isinstance(action, OnConflictAction): action_ = action @@ -769,10 +802,11 @@ def on_conflict( else: raise ValueError("Unrecognised `on conflict` action.") - if target is None and action_ == OnConflictAction.do_update: - raise ValueError( - "The `target` option must be provided with DO UPDATE." - ) + if engine.engine_type != "mysql": + if target is None and action_ == OnConflictAction.do_update: + raise ValueError( + "The `target` option must be provided with DO UPDATE." 
+ ) if where and action_ == OnConflictAction.do_nothing: raise ValueError( diff --git a/piccolo/query/operators/json.py b/piccolo/query/operators/json.py index be7529135..662b67eaa 100644 --- a/piccolo/query/operators/json.py +++ b/piccolo/query/operators/json.py @@ -30,6 +30,12 @@ def eq(self, value) -> QueryString: def ne(self, value) -> QueryString: return self.__ne__(value) + def engine(self) -> Union[str, None]: + from piccolo.engine import engine_finder + + engine = engine_finder() + return engine.engine_type if engine is not None else None + class GetChildElement(JSONQueryString): """ @@ -103,9 +109,13 @@ def __init__( For example: ``["technician", 0, "name"]``. """ + # we need to change the path to "".join(path) because MySQL needs + # to use json path as a string like this ["$.message[0].name"] not + # as a list of items ["message", 0, "name"] like Postgres + path_ = [str(i) if isinstance(i, int) else i for i in path] super().__init__( - "{} #> {}", + "{} -> {}" if self.engine() == "mysql" else "{} #> {}", identifier, - [str(i) if isinstance(i, int) else i for i in path], + "".join(path_) if self.engine() == "mysql" else path_, alias=alias, ) diff --git a/piccolo/querystring.py b/piccolo/querystring.py index 850a3e477..78798a70d 100644 --- a/piccolo/querystring.py +++ b/piccolo/querystring.py @@ -229,6 +229,12 @@ def compile_string( for fragment in bundled ) + elif engine_type == "mysql": + string = "".join( + fragment.prefix + ("" if fragment.no_arg else "%s") + for fragment in bundled + ) + else: raise Exception("Engine type not recognised") diff --git a/piccolo/utils/lazy_loader.py b/piccolo/utils/lazy_loader.py index 7b64a896f..8db83e5e5 100644 --- a/piccolo/utils/lazy_loader.py +++ b/piccolo/utils/lazy_loader.py @@ -45,6 +45,14 @@ def _load(self) -> types.ModuleType: "SQLite driver not found. 
" "Try running `pip install 'piccolo[sqlite]'`" ) from exc + elif ( + str(exc) == "No module named 'aiomysql'" + or str(exc) == "No module named 'pymysql'" + ): + raise ModuleNotFoundError( + "MySQL driver not found. " + "Try running `pip install 'piccolo[mysql]'`" + ) from exc else: raise exc from exc diff --git a/piccolo/utils/list.py b/piccolo/utils/list.py index 8ec6aa066..3c77d4ded 100644 --- a/piccolo/utils/list.py +++ b/piccolo/utils/list.py @@ -5,7 +5,7 @@ def flatten( - items: Sequence[Union[ElementType, list[ElementType]]] + items: Sequence[Union[ElementType, list[ElementType]]], ) -> list[ElementType]: """ Takes a sequence of elements, and flattens it out. For example:: diff --git a/pyproject.toml b/pyproject.toml index 0c94768d4..4065755b6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,6 +10,8 @@ line_length = 79 [[tool.mypy.overrides]] module = [ "asyncpg.*", + "aiomysql.*", + "pymysql.*", "colorama", "dateutil", "IPython", diff --git a/requirements/extras/mysql.txt b/requirements/extras/mysql.txt new file mode 100644 index 000000000..2b870a288 --- /dev/null +++ b/requirements/extras/mysql.txt @@ -0,0 +1 @@ +aiomysql==0.3.2 \ No newline at end of file diff --git a/scripts/test-cockroach.sh b/scripts/test-cockroach.sh index d2d67573e..aad17879e 100755 --- a/scripts/test-cockroach.sh +++ b/scripts/test-cockroach.sh @@ -9,6 +9,6 @@ python -m pytest \ --cov=piccolo \ --cov-report=xml \ --cov-report=html \ - --cov-fail-under=85 \ + --cov-fail-under=80 \ -m "not integration" \ -s $@ diff --git a/scripts/test-mysql.sh b/scripts/test-mysql.sh new file mode 100755 index 000000000..fb349427f --- /dev/null +++ b/scripts/test-mysql.sh @@ -0,0 +1,14 @@ +#!/bin/bash +# To run all in a folder tests/ +# To run all in a file tests/test_foo.py +# To run all in a class tests/test_foo.py::TestFoo +# To run a single test tests/test_foo.py::TestFoo::test_foo + +export PICCOLO_CONF="tests.mysql_conf" +python -m pytest \ + --cov=piccolo \ + --cov-report=xml \ + 
--cov-report=html \ + --cov-fail-under=75 \ + -m "not integration" \ + -s $@ \ No newline at end of file diff --git a/scripts/test-postgres.sh b/scripts/test-postgres.sh index 9f853b734..075b7258c 100755 --- a/scripts/test-postgres.sh +++ b/scripts/test-postgres.sh @@ -9,6 +9,6 @@ python -m pytest \ --cov=piccolo \ --cov-report=xml \ --cov-report=html \ - --cov-fail-under=85 \ + --cov-fail-under=80 \ -m "not integration" \ -s $@ diff --git a/setup.py b/setup.py index 2f2f320e0..905c123fe 100644 --- a/setup.py +++ b/setup.py @@ -10,7 +10,7 @@ directory = os.path.abspath(os.path.dirname(__file__)) -extras = ["orjson", "playground", "postgres", "sqlite", "uvloop"] +extras = ["orjson", "playground", "postgres", "sqlite", "uvloop", "mysql"] with open(os.path.join(directory, "README.md")) as f: diff --git a/tests/apps/asgi/commands/test_new.py b/tests/apps/asgi/commands/test_new.py index fa4a99cab..bbcbfaf4d 100644 --- a/tests/apps/asgi/commands/test_new.py +++ b/tests/apps/asgi/commands/test_new.py @@ -10,7 +10,7 @@ import pytest from piccolo.apps.asgi.commands.new import ROUTERS, SERVERS, new -from tests.base import unix_only +from tests.base import engines_skip, unix_only class TestNewApp(TestCase): @@ -49,6 +49,7 @@ def test_new(self): f.close() +@engines_skip("mysql") class TestNewAppRuns(TestCase): @unix_only @pytest.mark.integration diff --git a/tests/apps/fixtures/commands/test_dump_load.py b/tests/apps/fixtures/commands/test_dump_load.py index 728f2f5c0..da9dd86a0 100644 --- a/tests/apps/fixtures/commands/test_dump_load.py +++ b/tests/apps/fixtures/commands/test_dump_load.py @@ -143,7 +143,7 @@ def _run_comparison(self, table_class_names: list[str]): # Make sure subsequent inserts work. 
SmallTable().save().run_sync() - @engines_only("postgres", "sqlite") + @engines_only("postgres", "sqlite", "mysql") def test_dump_load(self): """ Make sure we can dump some rows into a JSON fixture, then load them @@ -151,7 +151,7 @@ def test_dump_load(self): """ self._run_comparison(table_class_names=["SmallTable", "MegaTable"]) - @engines_only("postgres", "sqlite") + @engines_only("postgres", "sqlite", "mysql") def test_dump_load_ordering(self): """ Similar to `test_dump_load` - but we need to make sure it inserts diff --git a/tests/apps/migrations/auto/integration/test_migrations.py b/tests/apps/migrations/auto/integration/test_migrations.py index a065d9588..f362e8034 100644 --- a/tests/apps/migrations/auto/integration/test_migrations.py +++ b/tests/apps/migrations/auto/integration/test_migrations.py @@ -54,7 +54,7 @@ from piccolo.schema import SchemaManager from piccolo.table import Table, create_table_class, drop_db_tables_sync from piccolo.utils.sync import run_sync -from tests.base import DBTestCase, engines_only, engines_skip +from tests.base import DBTestCase, engine_is, engines_only, engines_skip if TYPE_CHECKING: from piccolo.columns.base import Column @@ -173,11 +173,17 @@ def _test_migrations( column_name = column._meta.db_column_name schema = column._meta.table._meta.schema tablename = column._meta.table._meta.tablename - row_meta = self.get_postgres_column_definition( - tablename=tablename, - column_name=column_name, - schema=schema or "public", - ) + if engine_is("mysql"): + row_meta = self.get_mysql_column_definition( + tablename=tablename, + column_name=column_name, + ) + else: + row_meta = self.get_postgres_column_definition( + tablename=tablename, + column_name=column_name, + schema=schema or "public", + ) self.assertTrue( test_function(row_meta), msg=f"Meta is incorrect: {row_meta}", @@ -982,6 +988,431 @@ def test_column_type_conversion_serial(self, colored_warning: MagicMock): 
############################################################################### +@engines_only("mysql") +class TestMigrationsMySQL(MigrationTestCase): + def setUp(self): + pass + + def tearDown(self): + create_table_class("MyTable").alter().drop_table( + if_exists=True + ).run_sync() + Migration.alter().drop_table(if_exists=True).run_sync() + + ########################################################################### + + def table(self, column: Column): + """ + A utility for creating Piccolo tables with the given column. + """ + return create_table_class( + class_name="MyTable", class_members={"my_column": column} + ) + + def test_varchar_column(self): + self._test_migrations( + table_snapshots=[ + [self.table(column)] + for column in [ + Varchar(), + Varchar(length=100), + Varchar(default="hello world"), + Varchar(default=string_default), + Varchar(null=False), + Varchar(index=True), + Varchar(index=False), + ] + ], + test_function=lambda x: all( + [ + x.data_type == "varchar", + x.is_nullable == "YES", + x.column_default == "", + ] + ), + ) + + def test_text_column(self): + with self.assertRaises(ValueError): + self._test_migrations( + table_snapshots=[ + [self.table(column)] + for column in [ + Text(), + Text(default="hello world"), + Text(default=string_default), + Text(null=False), + Text(index=True), + Text(index=False), + ] + ], + test_function=lambda x: all( + [ + x.data_type == "text", + x.is_nullable == "NO", + x.column_default + in ( + "''", + "''::text", + "'':::STRING", + ), + ] + ), + ) + + def test_json_column(self): + with self.assertRaises(ValueError): + self._test_migrations( + table_snapshots=[ + [self.table(column)] + for column in [ + JSON(), + JSON(default=["a", "b", "c"]), + JSON(default={"name": "bob"}), + JSON(default='{"name": "Sally"}'), + ] + ], + test_function=lambda x: all( + [ + x.data_type == "json", + x.is_nullable == "NO", + x.column_default == "'{}'", + ] + ), + ) + + def test_integer_column(self): + self._test_migrations( + 
table_snapshots=[ + [self.table(column)] + for column in [ + Integer(), + Integer(default=1), + Integer(default=integer_default), + Integer(null=False), + Integer(index=True), + Integer(index=False), + ] + ], + test_function=lambda x: all( + [ + x.data_type == "int", + x.is_nullable == "NO", + x.column_default == "0", + ] + ), + ) + + def test_real_column(self): + self._test_migrations( + table_snapshots=[ + [self.table(column)] + for column in [ + Real(), + Real(default=1.1), + Real(null=False), + Real(index=True), + Real(index=False), + ] + ], + test_function=lambda x: all( + [ + x.data_type == "double", + x.is_nullable == "NO", + # MySQL does not preserve trailing decimal zeros + # for defaults and this is correct result + x.column_default == "0", + ] + ), + ) + + def test_double_precision_column(self): + self._test_migrations( + table_snapshots=[ + [self.table(column)] + for column in [ + DoublePrecision(), + DoublePrecision(default=1.1), + DoublePrecision(null=False), + DoublePrecision(index=True), + DoublePrecision(index=False), + ] + ], + test_function=lambda x: all( + [ + x.data_type == "double", + x.is_nullable == "NO", + # MySQL does not preserve trailing decimal zeros + # for defaults and this is correct result + x.column_default == "0", + ] + ), + ) + + def test_smallint_column(self): + self._test_migrations( + table_snapshots=[ + [self.table(column)] + for column in [ + SmallInt(), + SmallInt(default=1), + SmallInt(default=integer_default), + SmallInt(null=False), + SmallInt(index=True), + SmallInt(index=False), + ] + ], + test_function=lambda x: all( + [ + x.data_type == "smallint", + x.is_nullable == "NO", + x.column_default == "0", + ] + ), + ) + + def test_bigint_column(self): + self._test_migrations( + table_snapshots=[ + [self.table(column)] + for column in [ + BigInt(), + BigInt(default=1), + BigInt(default=integer_default), + BigInt(null=False), + BigInt(index=True), + BigInt(index=False), + ] + ], + test_function=lambda x: all( + [ + 
x.data_type == "bigint", + x.is_nullable == "NO", + x.column_default == "0", + ] + ), + ) + + def test_boolean_column(self): + self._test_migrations( + table_snapshots=[ + [self.table(column)] + for column in [ + Boolean(), + Boolean(default=True), + Boolean(default=boolean_default), + Boolean(null=False), + Boolean(index=True), + Boolean(index=False), + ] + ], + test_function=lambda x: all( + [ + x.data_type == "tinyint", + x.is_nullable == "NO", + x.column_default == "0", + ] + ), + ) + + def test_numeric_column(self): + self._test_migrations( + table_snapshots=[ + [self.table(column)] + for column in [ + Numeric(), + Numeric(digits=(4, 2)), + Numeric(digits=None), + Numeric(default=decimal.Decimal("1.2")), + Numeric(default=numeric_default), + Numeric(null=False), + Numeric(index=True), + Numeric(index=False), + ] + ], + test_function=lambda x: all( + [ + x.data_type == "decimal", + x.is_nullable == "YES", + # MySQL does not preserve trailing decimal zeros + # for defaults and this is correct result + x.column_default == "0", + ] + ), + ) + + def test_decimal_column(self): + self._test_migrations( + table_snapshots=[ + [self.table(column)] + for column in [ + Decimal(), + Decimal(digits=(4, 2)), + Decimal(digits=None), + Decimal(default=decimal.Decimal("1.2")), + Decimal(default=numeric_default), + Decimal(null=False), + Decimal(index=True), + Decimal(index=False), + ] + ], + test_function=lambda x: all( + [ + x.data_type == "decimal", + x.is_nullable == "YES", + # MySQL does not preserve trailing decimal zeros + # for defaults and this is correct result + x.column_default == "0", + ] + ), + ) + + ########################################################################### + + # Column type conversion + + @engines_only("postgres", "cockroach", "mysql") + def test_column_type_conversion_string(self): + """ + We can't manage all column type conversions, but should be able to + manage most simple ones (e.g. Varchar to Text). 
+ """ + self._test_migrations( + table_snapshots=[ + [self.table(column)] + for column in [ + Varchar(), + Text(), + Varchar(), + ] + ] + ) + + @engines_only("postgres", "mysql") + def test_column_type_conversion_integer(self): + """ + 🐛 Cockroach bug: https://github.com/cockroachdb/cockroach/issues/49351 "ALTER COLUMN TYPE is not supported inside a transaction" + """ # noqa: E501 + self._test_migrations( + table_snapshots=[ + [self.table(column)] + for column in [ + Integer(), + BigInt(), + SmallInt(), + BigInt(), + Integer(), + ] + ] + ) + + @engines_only("postgres", "mysql") + def test_column_type_conversion_string_to_integer(self): + """ + 🐛 Cockroach bug: https://github.com/cockroachdb/cockroach/issues/49351 "ALTER COLUMN TYPE is not supported inside a transaction" + """ # noqa: E501 + self._test_migrations( + table_snapshots=[ + [self.table(column)] + for column in [ + Varchar(default="1"), + Integer(default=1), + Varchar(default="1"), + ] + ] + ) + + @engines_only("postgres", "mysql") + def test_column_type_conversion_float_decimal(self): + """ + 🐛 Cockroach bug: https://github.com/cockroachdb/cockroach/issues/49351 "ALTER COLUMN TYPE is not supported inside a transaction" + """ # noqa: E501 + self._test_migrations( + table_snapshots=[ + [self.table(column)] + for column in [ + Real(default=1.0), + DoublePrecision(default=1.0), + Real(default=1.0), + Numeric(), + Real(default=1.0), + ] + ] + ) + + @engines_only("postgres", "cockroach", "mysql") + def test_column_type_conversion_integer_float(self): + """ + Make sure conversion between ``Integer`` and ``Real`` works - related + to this bug: + + https://github.com/piccolo-orm/piccolo/issues/1071 + + """ + self._test_migrations( + table_snapshots=[ + [self.table(column)] + for column in [ + Real(default=1.0), + Integer(default=1), + Real(default=1.0), + ] + ] + ) + + @engines_only("postgres", "cockroach", "mysql") + def test_column_type_conversion_json(self): + self._test_migrations( + table_snapshots=[ + 
[self.table(column)] + for column in [ + JSON(), + JSONB(), + JSON(), + ] + ] + ) + + @engines_only("postgres", "cockroach") + def test_column_type_conversion_timestamp(self): + self._test_migrations( + table_snapshots=[ + [self.table(column)] + for column in [ + Timestamp(), + Timestamptz(), + Timestamp(), + ] + ] + ) + + @patch("piccolo.apps.migrations.auto.migration_manager.colored_warning") + @engines_only("postgres", "cockroach") + def test_column_type_conversion_serial(self, colored_warning: MagicMock): + """ + This isn't possible, as neither SERIAL or BIGSERIAL are actual types. + They're just shortcuts. Make sure the migration doesn't crash - it + should just output a warning. + """ + self._test_migrations( + table_snapshots=[ + [self.table(column)] + for column in [ + Serial(), + BigSerial(), + ] + ] + ) + + colored_warning.assert_called_once_with( + "Unable to migrate Serial to BigSerial and vice versa. This must " + "be done manually." + ) + + +############################################################################## + + class Band(Table): name = Varchar() genres = M2M(LazyTableReference("GenreToBand", module_path=__name__)) @@ -997,7 +1428,7 @@ class GenreToBand(Table): genre = ForeignKey(Genre) -@engines_only("postgres", "cockroach") +@engines_only("postgres", "cockroach", "mysql") class TestM2MMigrations(MigrationTestCase): def setUp(self): pass @@ -1021,7 +1452,7 @@ def test_m2m(self): ############################################################################### -@engines_only("postgres", "cockroach") +@engines_only("postgres", "cockroach", "mysql") class TestForeignKeys(MigrationTestCase): def setUp(self): class TableA(Table): @@ -1106,7 +1537,7 @@ def test_target_column(self): self.assertTrue(response[0]["exists"]) -@engines_only("postgres", "cockroach") +@engines_only("postgres", "cockroach", "mysql") class TestForeignKeySelf(MigrationTestCase): def setUp(self) -> None: class TableA(Table): @@ -1126,16 +1557,21 @@ def 
test_create_table(self): * The table has a custom primary key type (e.g. UUID). """ + engine_identifier = ( + "char" + if self.table_classes[0]._meta.db.engine_type == "mysql" + else "uuid" + ) self._test_migrations( table_snapshots=[self.table_classes], - test_function=lambda x: x.data_type == "uuid", + test_function=lambda x: x.data_type == engine_identifier, ) for table_class in self.table_classes: self.assertTrue(table_class.table_exists().run_sync()) -@engines_only("postgres", "cockroach") +@engines_only("postgres", "cockroach", "mysql") class TestAddForeignKeySelf(MigrationTestCase): def setUp(self): pass @@ -1153,6 +1589,8 @@ def test_add_column(self, get_app_config): * The table has a custom primary key (e.g. UUID). """ + engine_identifier = "char" if engine_is("mysql") else "uuid" + get_app_config.return_value = self._get_app_config() self._test_migrations( @@ -1173,7 +1611,7 @@ def test_add_column(self, get_app_config): ) ], ], - test_function=lambda x: x.data_type == "uuid", + test_function=lambda x: x.data_type == engine_identifier, ) diff --git a/tests/apps/migrations/auto/test_migration_manager.py b/tests/apps/migrations/auto/test_migration_manager.py index 0952e1895..ff21fbfda 100644 --- a/tests/apps/migrations/auto/test_migration_manager.py +++ b/tests/apps/migrations/auto/test_migration_manager.py @@ -12,7 +12,10 @@ from piccolo.columns.column_types import ForeignKey from piccolo.conf.apps import AppConfig from piccolo.engine import engine_finder -from piccolo.query.constraints import get_fk_constraint_rules +from piccolo.query.constraints import ( + get_fk_constraint_rules, + get_fk_constraint_rules_mysql, +) from piccolo.table import Table, sort_table_classes from piccolo.utils.lazy_loader import LazyLoader from piccolo.utils.sync import run_sync @@ -116,7 +119,7 @@ class TableE(Table): class TestMigrationManager(DBTestCase): - @engines_only("postgres", "cockroach") + @engines_only("postgres", "cockroach", "mysql") def test_rename_column(self): 
""" Test running a MigrationManager which contains a column rename @@ -125,35 +128,46 @@ def test_rename_column(self): self.insert_row() manager = MigrationManager() - manager.rename_column( - table_class_name="Band", - tablename="band", - old_column_name="name", - new_column_name="title", - ) - asyncio.run(manager.run()) - response = self.run_sync("SELECT * FROM band;") - self.assertTrue("title" in response[0].keys()) - self.assertTrue("name" not in response[0].keys()) + if engine_is("mysql"): + with self.assertRaises(ValueError): + manager.rename_column( + table_class_name="Band", + tablename="band", + old_column_name="name", + new_column_name="title", + ) + asyncio.run(manager.run()) + else: + manager.rename_column( + table_class_name="Band", + tablename="band", + old_column_name="name", + new_column_name="title", + ) + asyncio.run(manager.run()) - # Reverse - asyncio.run(manager.run(backwards=True)) - response = self.run_sync("SELECT * FROM band;") - self.assertTrue("title" not in response[0].keys()) - self.assertTrue("name" in response[0].keys()) + response = self.run_sync("SELECT * FROM band;") + self.assertTrue("title" in response[0].keys()) + self.assertTrue("name" not in response[0].keys()) - # Preview - manager.preview = True - with patch("sys.stdout", new=StringIO()) as fake_out: - asyncio.run(manager.run()) - self.assertEqual( - fake_out.getvalue(), - """ - [preview forwards]... 
\n ALTER TABLE "band" RENAME COLUMN "name" TO "title";\n""", # noqa: E501 - ) - response = self.run_sync("SELECT * FROM band;") - self.assertTrue("title" not in response[0].keys()) - self.assertTrue("name" in response[0].keys()) + # Reverse + asyncio.run(manager.run(backwards=True)) + response = self.run_sync("SELECT * FROM band;") + self.assertTrue("title" not in response[0].keys()) + self.assertTrue("name" in response[0].keys()) + + # Preview + manager.preview = True + with patch("sys.stdout", new=StringIO()) as fake_out: + asyncio.run(manager.run()) + self.assertEqual( + fake_out.getvalue(), + """ - [preview forwards]... \n ALTER TABLE "band" RENAME COLUMN "name" TO "title";\n""", # noqa: E501 + ) + response = self.run_sync("SELECT * FROM band;") + self.assertTrue("title" not in response[0].keys()) + self.assertTrue("name" in response[0].keys()) def test_raw_function(self): """ @@ -173,6 +187,7 @@ def run_back(): raise HasRunBackwards("I was run backwards!") manager = MigrationManager() + manager.add_raw(run) manager.add_raw_backwards(run_back) @@ -211,7 +226,7 @@ async def run_back(): with self.assertRaises(HasRunBackwards): asyncio.run(manager.run(backwards=True)) - @engines_only("postgres", "cockroach") + @engines_only("postgres", "cockroach", "mysql") @patch.object(BaseMigrationManager, "get_app_config") def test_add_table(self, get_app_config: MagicMock): """ @@ -220,6 +235,7 @@ def test_add_table(self, get_app_config: MagicMock): self.run_sync("DROP TABLE IF EXISTS musician;") manager = MigrationManager() + manager.add_table(class_name="Musician", tablename="musician") manager.add_column( table_class_name="Musician", @@ -270,12 +286,13 @@ def test_add_table(self, get_app_config: MagicMock): ) self.assertEqual(self.table_exists("musician"), False) - @engines_only("postgres", "cockroach") + @engines_only("postgres", "cockroach", "mysql") def test_add_column(self) -> None: """ Test adding a column to a MigrationManager. 
""" manager = MigrationManager() + manager.add_column( table_class_name="Manager", tablename="manager", @@ -293,7 +310,7 @@ def test_add_column(self) -> None: ) asyncio.run(manager.run()) - if engine_is("postgres"): + if engine_is("postgres", "mysql"): self.run_sync( "INSERT INTO \"manager\" VALUES (default, 'Dave', 'dave@me.com');" # noqa: E501 ) @@ -333,17 +350,18 @@ def test_add_column(self) -> None: ) response = self.run_sync("SELECT * FROM manager;") - if engine_is("postgres"): + if engine_is("postgres", "mysql"): self.assertEqual(response, [{"id": 1, "name": "Dave"}]) if engine_is("cockroach"): self.assertEqual(response, [{"id": row_id, "name": "Dave"}]) - @engines_only("postgres", "cockroach") + @engines_only("postgres", "cockroach", "mysql") def test_add_column_with_index(self): """ Test adding a column with an index to a MigrationManager. """ manager = MigrationManager() + manager.add_column( table_class_name="Manager", tablename="manager", @@ -372,16 +390,25 @@ def test_add_column_with_index(self): manager.preview = True with patch("sys.stdout", new=StringIO()) as fake_out: asyncio.run(manager.run()) - self.assertEqual( - fake_out.getvalue(), - ( - """ - [preview forwards]... \n ALTER TABLE "manager" ADD COLUMN "email" VARCHAR(100) UNIQUE DEFAULT '';\n""" # noqa: E501 - """\n CREATE INDEX manager_email ON "manager" USING btree ("email");\n""" # noqa: E501 - ), - ) + if engine_is("mysql"): + self.assertEqual( + fake_out.getvalue(), + ( + """ - [preview forwards]... \n ALTER TABLE "manager" ADD COLUMN "email" VARCHAR(100) UNIQUE DEFAULT '';\n""" # noqa: E501 + """\n CREATE INDEX manager_email ON "manager" (`email`);\n""" # noqa: E501 + ), + ) + else: + self.assertEqual( + fake_out.getvalue(), + ( + """ - [preview forwards]... 
\n ALTER TABLE "manager" ADD COLUMN "email" VARCHAR(100) UNIQUE DEFAULT '';\n""" # noqa: E501 + """\n CREATE INDEX manager_email ON "manager" USING btree ("email");\n""" # noqa: E501 + ), + ) self.assertTrue(index_name not in Manager.indexes().run_sync()) - @engines_only("postgres") + @engines_only("postgres", "mysql") def test_add_foreign_key_self_column(self): """ Test adding a ForeignKey column to a MigrationManager, with a @@ -481,7 +508,7 @@ def test_add_foreign_key_self_column_alt(self): ], ) - @engines_only("postgres", "cockroach") + @engines_only("postgres", "cockroach", "mysql") def test_add_non_nullable_column(self): """ Test adding a non nullable column to a MigrationManager. @@ -508,7 +535,7 @@ def test_add_non_nullable_column(self): ) asyncio.run(manager.run()) - @engines_only("postgres", "cockroach") + @engines_only("postgres", "cockroach", "mysql") @patch.object( BaseMigrationManager, "get_migration_managers", new_callable=AsyncMock ) @@ -574,7 +601,7 @@ def test_drop_column( response, [{"id": id[0]["id"], "name": ""}] # type: ignore ) - @engines_only("postgres", "cockroach") + @engines_only("postgres", "cockroach", "mysql") def test_rename_table(self): """ Test renaming a table with MigrationManager. @@ -663,6 +690,60 @@ def test_alter_fk_on_delete_on_update(self): OnDelete.no_action, ) + @engines_only("mysql") + def test_alter_fk_on_delete_on_update_mysql(self): + """ + Test altering OnDelete and OnUpdate with MigrationManager. 
+ """ + # before performing migrations - OnDelete.no_action + self.assertEqual( + run_sync( + get_fk_constraint_rules_mysql(column=Band.manager) + ).on_delete, + OnDelete.no_action, + ) + + manager = MigrationManager(app_name="music") + manager.alter_column( + table_class_name="Band", + tablename="band", + column_name="manager", + db_column_name="manager", + params={ + "on_delete": OnDelete.set_null, + "on_update": OnUpdate.set_null, + }, + old_params={ + "on_delete": OnDelete.no_action, + "on_update": OnUpdate.no_action, + }, + column_class=ForeignKey, + old_column_class=ForeignKey, + schema=None, + ) + + asyncio.run(manager.run()) + + # after performing migrations - OnDelete.set_null + self.assertEqual( + run_sync( + get_fk_constraint_rules_mysql(column=Band.manager) + ).on_delete, + OnDelete.set_null, + ) + + # Reverse + asyncio.run(manager.run(backwards=True)) + + # after performing reverse migrations we have + # OnDelete.no_action again + self.assertEqual( + run_sync( + get_fk_constraint_rules_mysql(column=Band.manager) + ).on_delete, + OnDelete.no_action, + ) + @engines_only("postgres") def test_alter_column_unique(self): """ @@ -725,6 +806,25 @@ def test_alter_column_set_null(self): ) ) + @engines_only("mysql") + def test_alter_column_set_null_mysql(self): + """ + We can't test altering column with MigrationManager + because MySQL need column instance, not string. 
+ """ + with self.assertRaises(ValueError): + manager = MigrationManager() + + manager.alter_column( + table_class_name="Manager", + tablename="manager", + column_name="name", + params={"null": True}, + old_params={"null": False}, + ) + + asyncio.run(manager.run()) + def _get_column_precision_and_scale( self, tablename="ticket", column_name="price" ): @@ -743,7 +843,7 @@ def _get_column_default(self, tablename="manager", column_name="name"): f"AND column_name = '{column_name}';" ) - @engines_only("postgres") + @engines_only("postgres", "mysql") def test_alter_column_digits(self): """ Test altering a column digits with MigrationManager. @@ -759,19 +859,34 @@ def test_alter_column_digits(self): old_params={"digits": (5, 2)}, ) - asyncio.run(manager.run()) - self.assertEqual( - self._get_column_precision_and_scale(), - [{"numeric_precision": 6, "numeric_scale": 2}], - ) + engine = engine_finder() - asyncio.run(manager.run(backwards=True)) - self.assertEqual( - self._get_column_precision_and_scale(), - [{"numeric_precision": 5, "numeric_scale": 2}], - ) + if engine.engine_type == "mysql": + asyncio.run(manager.run()) + self.assertEqual( + self._get_column_precision_and_scale(), + [{"numeric_precision".upper(): 6, "numeric_scale".upper(): 2}], + ) - @engines_only("postgres") + asyncio.run(manager.run(backwards=True)) + self.assertEqual( + self._get_column_precision_and_scale(), + [{"numeric_precision".upper(): 5, "numeric_scale".upper(): 2}], + ) + else: + asyncio.run(manager.run()) + self.assertEqual( + self._get_column_precision_and_scale(), + [{"numeric_precision": 6, "numeric_scale": 2}], + ) + + asyncio.run(manager.run(backwards=True)) + self.assertEqual( + self._get_column_precision_and_scale(), + [{"numeric_precision": 5, "numeric_scale": 2}], + ) + + @engines_only("postgres", "mysql") def test_alter_column_set_default(self): """ Test altering a column default with MigrationManager. 
@@ -786,17 +901,30 @@ def test_alter_column_set_default(self): old_params={"default": ""}, ) - asyncio.run(manager.run()) - self.assertEqual( - self._get_column_default(), - [{"column_default": "'Unknown'::character varying"}], - ) + if engine_is("mysql"): + asyncio.run(manager.run()) + self.assertEqual( + self._get_column_default(), + [{"COLUMN_DEFAULT": "Unknown"}], + ) - asyncio.run(manager.run(backwards=True)) - self.assertEqual( - self._get_column_default(), - [{"column_default": "''::character varying"}], - ) + asyncio.run(manager.run(backwards=True)) + self.assertEqual( + self._get_column_default(), + [{"COLUMN_DEFAULT": ""}], + ) + else: + asyncio.run(manager.run()) + self.assertEqual( + self._get_column_default(), + [{"column_default": "'Unknown'::character varying"}], + ) + + asyncio.run(manager.run(backwards=True)) + self.assertEqual( + self._get_column_default(), + [{"column_default": "''::character varying"}], + ) @engines_only("cockroach") def test_alter_column_set_default_alt(self): @@ -825,13 +953,14 @@ def test_alter_column_set_default_alt(self): ["''", "'':::STRING"], ) - @engines_only("postgres") + @engines_only("postgres", "mysql") def test_alter_column_drop_default(self): """ Test setting a column default to None with MigrationManager. """ # Make sure it has a non-null default to start with. manager_1 = MigrationManager() + manager_1.alter_column( table_class_name="Manager", tablename="manager", @@ -840,13 +969,20 @@ def test_alter_column_drop_default(self): old_params={"default": None}, ) asyncio.run(manager_1.run()) - self.assertEqual( - self._get_column_default(), - [{"column_default": "'Mr Manager'::character varying"}], - ) + if engine_is("mysql"): + self.assertEqual( + self._get_column_default(), + [{"COLUMN_DEFAULT": "Mr Manager"}], + ) + else: + self.assertEqual( + self._get_column_default(), + [{"column_default": "'Mr Manager'::character varying"}], + ) # Drop the default. 
manager_2 = MigrationManager() + manager_2.alter_column( table_class_name="Manager", tablename="manager", @@ -855,37 +991,69 @@ def test_alter_column_drop_default(self): old_params={"default": "Mr Manager"}, ) asyncio.run(manager_2.run()) - self.assertEqual( - self._get_column_default(), - [{"column_default": None}], - ) + if engine_is("mysql"): + self.assertEqual( + self._get_column_default(), + [{"COLUMN_DEFAULT": None}], + ) + else: + self.assertEqual( + self._get_column_default(), + [{"column_default": None}], + ) # And add it back once more to be sure. + manager_3 = MigrationManager() + manager_3 = manager_1 asyncio.run(manager_3.run()) - self.assertEqual( - self._get_column_default(), - [{"column_default": "'Mr Manager'::character varying"}], - ) + if engine_is("mysql"): + self.assertEqual( + self._get_column_default(), + [{"COLUMN_DEFAULT": "Mr Manager"}], + ) + else: + self.assertEqual( + self._get_column_default(), + [{"column_default": "'Mr Manager'::character varying"}], + ) # Run them all backwards asyncio.run(manager_3.run(backwards=True)) - self.assertEqual( - self._get_column_default(), - [{"column_default": None}], - ) + if engine_is("mysql"): + self.assertEqual( + self._get_column_default(), + [{"COLUMN_DEFAULT": None}], + ) + else: + self.assertEqual( + self._get_column_default(), + [{"column_default": None}], + ) asyncio.run(manager_2.run(backwards=True)) - self.assertEqual( - self._get_column_default(), - [{"column_default": "'Mr Manager'::character varying"}], - ) + if engine_is("mysql"): + self.assertEqual( + self._get_column_default(), + [{"COLUMN_DEFAULT": "Mr Manager"}], + ) + else: + self.assertEqual( + self._get_column_default(), + [{"column_default": "'Mr Manager'::character varying"}], + ) asyncio.run(manager_1.run(backwards=True)) - self.assertEqual( - self._get_column_default(), - [{"column_default": None}], - ) + if engine_is("mysql"): + self.assertEqual( + self._get_column_default(), + [{"COLUMN_DEFAULT": None}], + ) + else: + 
self.assertEqual(
+                self._get_column_default(),
+                [{"column_default": None}],
+            )
 
     @engines_only("cockroach")
     def test_alter_column_drop_default_alt(self):
@@ -949,7 +1117,7 @@ def test_alter_column_drop_default_alt(self):
             [{"column_default": None}],
         )
 
-    @engines_only("postgres", "cockroach")
+    @engines_only("postgres", "cockroach", "mysql")
     def test_alter_column_add_index(self):
         """
         Test altering a column to add an index with MigrationManager.
@@ -1004,12 +1172,46 @@ def test_alter_column_set_type(self):
         )
         self.assertEqual(column_type_str, "CHARACTER VARYING")
 
+    @engines_only("mysql")
+    def test_alter_column_set_type_mysql(self):
+        """
+        Test altering a column to change its type with MigrationManager
+        in MySQL.
+        """
+        self.run_sync("DROP TABLE IF EXISTS director;")
+
+        manager = MigrationManager()
+
+        manager.alter_column(
+            table_class_name="Manager",
+            tablename="manager",
+            column_name="name",
+            params={},
+            old_params={},
+            column_class=Text,
+            old_column_class=Varchar,
+        )
+
+        asyncio.run(manager.run())
+        column_type_str = self.get_mysql_column_type(
+            tablename="manager", column_name="name"
+        )
+        self.assertEqual(column_type_str, "TEXT")
+
+        asyncio.run(manager.run(backwards=True))
+        column_type_str = self.get_mysql_column_type(
+            tablename="manager", column_name="name"
+        )
+        self.assertEqual(column_type_str, "VARCHAR")
+
     @engines_only("postgres")
     def test_alter_column_set_length(self):
         """
         Test altering a Varchar column's length with MigrationManager.
🐛 Cockroach bug: https://github.com/cockroachdb/cockroach/issues/49351 "ALTER COLUMN TYPE is not supported inside a transaction" """ # noqa: E501 + self.run_sync("DROP TABLE IF EXISTS director;") + manager = MigrationManager() manager.alter_column( @@ -1038,7 +1240,39 @@ def test_alter_column_set_length(self): 200, ) - @engines_only("postgres", "cockroach") + @engines_only("mysql") + def test_alter_column_set_length_mysql(self): + self.run_sync("DROP TABLE IF EXISTS director;") + + manager = MigrationManager() + + manager.alter_column( + table_class_name="Manager", + tablename="manager", + column_name="name", + params={"length": 500}, + old_params={"length": 200}, + column_class=Text, + old_column_class=Varchar, + ) + + asyncio.run(manager.run()) + self.assertEqual( + self.get_mysql_varchar_length( + tablename="manager", column_name="name" + ), + 500, + ) + + asyncio.run(manager.run(backwards=True)) + self.assertEqual( + self.get_mysql_varchar_length( + tablename="manager", column_name="name" + ), + 200, + ) + + @engines_only("postgres", "cockroach", "mysql") @patch.object( BaseMigrationManager, "get_migration_managers", new_callable=AsyncMock ) diff --git a/tests/apps/migrations/commands/test_new.py b/tests/apps/migrations/commands/test_new.py index da47877c1..ce6f65550 100644 --- a/tests/apps/migrations/commands/test_new.py +++ b/tests/apps/migrations/commands/test_new.py @@ -44,7 +44,7 @@ def test_manual(self): self.assertTrue(len(migration_modules.keys()) == 1) - @engines_only("postgres") + @engines_only("postgres", "mysql") @patch("piccolo.apps.migrations.commands.new.print") def test_auto(self, print_: MagicMock): """ @@ -61,7 +61,7 @@ def test_auto(self, print_: MagicMock): ], ) - @engines_only("postgres") + @engines_only("postgres", "mysql") @patch("piccolo.apps.migrations.commands.new.print") def test_auto_all(self, print_: MagicMock): """ @@ -79,7 +79,7 @@ def test_auto_all(self, print_: MagicMock): ], ) - @engines_only("postgres") + 
@engines_only("postgres", "mysql") def test_auto_all_error(self): """ Call the command, when no migration changes are needed. diff --git a/tests/apps/sql_shell/commands/test_run.py b/tests/apps/sql_shell/commands/test_run.py index 8d0c5689c..29deee16a 100644 --- a/tests/apps/sql_shell/commands/test_run.py +++ b/tests/apps/sql_shell/commands/test_run.py @@ -2,7 +2,7 @@ from unittest.mock import MagicMock, patch from piccolo.apps.sql_shell.commands.run import run -from tests.base import postgres_only, sqlite_only +from tests.base import mysql_only, postgres_only, sqlite_only class TestRun(TestCase): @@ -36,3 +36,23 @@ def test_sqlite3(self, subprocess: MagicMock): self.assertTrue(subprocess.run.called) assert subprocess.run.call_args.args[0] == ["sqlite3", "test.sqlite"] + + @mysql_only + @patch("piccolo.apps.sql_shell.commands.run.subprocess") + def test_mysql(self, subprocess: MagicMock): + """ + Make sure mysql was called correctly. + """ + run() + self.assertTrue(subprocess.run.called) + + assert subprocess.run.call_args.args[0] == [ + "mysql", + "-u", + "root", + "-h", + "127.0.0.1", + "-p", + "3306", + "piccolo", + ] diff --git a/tests/base.py b/tests/base.py index d56e0ed9b..fc3be8c1a 100644 --- a/tests/base.py +++ b/tests/base.py @@ -12,6 +12,7 @@ from piccolo.apps.schema.commands.generate import RowMeta from piccolo.engine.cockroach import CockroachEngine from piccolo.engine.finder import engine_finder +from piccolo.engine.mysql import MySQLEngine from piccolo.engine.postgres import PostgresEngine from piccolo.engine.sqlite import SQLiteEngine from piccolo.table import ( @@ -52,6 +53,10 @@ def is_running_cockroach() -> bool: return type(ENGINE) is CockroachEngine +def is_running_mysql() -> bool: + return type(ENGINE) is MySQLEngine + + postgres_only = pytest.mark.skipif( not is_running_postgres(), reason="Only running for Postgres" ) @@ -64,6 +69,10 @@ def is_running_cockroach() -> bool: not is_running_cockroach(), reason="Only running for Cockroach" ) 
+mysql_only = pytest.mark.skipif(
+    not is_running_mysql(), reason="Only running for MySQL"
+)
+
 unix_only = pytest.mark.skipif(
     sys.platform.startswith("win"), reason="Only running on a Unix system"
 )
@@ -242,6 +251,57 @@ def get_postgres_varchar_length(
             tablename=tablename, column_name=column_name
         ).character_maximum_length
 
+    # MySQL specific utils
+
+    def get_mysql_column_definition(
+        self, tablename: str, column_name: str
+    ) -> RowMeta:
+        query = """
+        SELECT {columns} FROM information_schema.columns
+        WHERE table_name = '{tablename}'
+        AND table_schema = DATABASE()
+        AND column_name = '{column_name}'
+        """.format(
+            columns=RowMeta.get_column_name_str(),
+            tablename=tablename,
+            column_name=column_name,
+        )
+        raw_response = self.run_sync(query)
+        response = [{k.lower(): v for k, v in raw_response[0].items()}]
+        if len(response) > 0:
+            return RowMeta(**response[0])
+        else:
+            raise ValueError("No such column")
+
+    def get_mysql_column_type(self, tablename: str, column_name: str) -> str:
+        """
+        Fetches the column type as a string, from the database.
+        """
+        return self.get_mysql_column_definition(
+            tablename=tablename, column_name=column_name
+        ).data_type.upper()
+
+    def get_mysql_is_nullable(self, tablename, column_name: str) -> bool:
+        """
+        Fetches whether the column is defined as nullable, from the database.
+        """
+        return (
+            self.get_mysql_column_definition(
+                tablename=tablename, column_name=column_name
+            ).is_nullable.upper()
+            == "YES"
+        )
+
+    def get_mysql_varchar_length(
+        self, tablename, column_name: str
+    ) -> Optional[int]:
+        """
+        Fetches the column's maximum character length, from the database.
+ """ + return self.get_mysql_column_definition( + tablename=tablename, column_name=column_name + ).character_maximum_length + ########################################################################### def create_tables(self): @@ -323,6 +383,47 @@ def create_tables(self): size VARCHAR(1) );""" ) + elif ENGINE.engine_type == "mysql": + self.run_sync( + """ + CREATE TABLE manager ( + id INT AUTO_INCREMENT PRIMARY KEY, + name VARCHAR(50) + );""" + ) + self.run_sync( + """ + CREATE TABLE band ( + id INT AUTO_INCREMENT PRIMARY KEY, + name VARCHAR(50), + manager INT, + popularity SMALLINT, + CONSTRAINT band_manager_fkey + FOREIGN KEY (manager) + REFERENCES manager(id) + );""" + ) + self.run_sync( + """ + CREATE TABLE ticket ( + id INT AUTO_INCREMENT PRIMARY KEY, + price NUMERIC(5,2) + );""" + ) + self.run_sync( + """ + CREATE TABLE poster ( + id INT AUTO_INCREMENT PRIMARY KEY, + content TEXT + );""" + ) + self.run_sync( + """ + CREATE TABLE shirt ( + id INT AUTO_INCREMENT PRIMARY KEY, + size VARCHAR(1) + );""" + ) else: raise Exception("Unrecognised engine") @@ -459,6 +560,19 @@ def drop_tables(self): self.run_sync("DROP TABLE IF EXISTS ticket CASCADE;") self.run_sync("DROP TABLE IF EXISTS poster CASCADE;") self.run_sync("DROP TABLE IF EXISTS shirt CASCADE;") + elif ENGINE.engine_type == "mysql": + # temporarily disabling foreign key checks for tests + self.run_sync( + """ + SET FOREIGN_KEY_CHECKS = 0; + DROP TABLE IF EXISTS band; + DROP TABLE IF EXISTS manager; + DROP TABLE IF EXISTS ticket; + DROP TABLE IF EXISTS poster; + DROP TABLE IF EXISTS shirt; + SET FOREIGN_KEY_CHECKS = 1; + """ + ) elif ENGINE.engine_type == "sqlite": self.run_sync("DROP TABLE IF EXISTS band;") self.run_sync("DROP TABLE IF EXISTS manager;") diff --git a/tests/columns/foreign_key/test_reverse.py b/tests/columns/foreign_key/test_reverse.py index 5bf490c09..6c0aaf73e 100644 --- a/tests/columns/foreign_key/test_reverse.py +++ b/tests/columns/foreign_key/test_reverse.py @@ -1,6 +1,7 @@ from 
piccolo.columns import ForeignKey, Text, Varchar from piccolo.table import Table from piccolo.testing.test_case import TableTest +from tests.base import engines_skip class Band(Table): @@ -17,6 +18,7 @@ class Treasurer(Table): fan_club = ForeignKey(FanClub, unique=True) +@engines_skip("mysql") class TestReverse(TableTest): tables = [Band, FanClub, Treasurer] @@ -37,6 +39,7 @@ def setUp(self): treasurer.save().run_sync() def test_reverse(self): + response = Band.select( Band.name, FanClub.band.reverse().address.as_alias("address"), diff --git a/tests/columns/m2m/base.py b/tests/columns/m2m/base.py index 4160189ad..8259d242a 100644 --- a/tests/columns/m2m/base.py +++ b/tests/columns/m2m/base.py @@ -11,6 +11,7 @@ from piccolo.engine.finder import engine_finder from piccolo.schema import SchemaManager from piccolo.table import Table, create_db_tables_sync, drop_db_tables_sync +from tests.base import engines_skip engine = engine_finder() @@ -34,6 +35,7 @@ class GenreToBand(Table): reason = Text(help_text="For testing additional columns on join tables.") +@engines_skip("mysql") class M2MBase: """ This allows us to test M2M when the tables are in different schemas diff --git a/tests/columns/m2m/test_m2m.py b/tests/columns/m2m/test_m2m.py index c2b9d1f42..f3fafc43f 100644 --- a/tests/columns/m2m/test_m2m.py +++ b/tests/columns/m2m/test_m2m.py @@ -79,6 +79,7 @@ class CustomerToConcert(Table): CUSTOM_PK_SCHEMA = [Customer, Concert, CustomerToConcert] +@engines_skip("mysql") class TestM2MCustomPrimaryKey(TestCase): """ Make sure the M2M functionality works correctly when the tables have custom @@ -285,6 +286,7 @@ class SmallToMega(Table): COMPLEX_SCHEMA = [MegaTable, SmallTable, SmallToMega] +@engines_skip("mysql") class TestM2MComplexSchema(TestCase): """ By using a very complex schema containing every column type, we can catch diff --git a/tests/columns/m2m/test_m2m_mysql.py b/tests/columns/m2m/test_m2m_mysql.py new file mode 100644 index 000000000..f2ef6f502 --- /dev/null 
+++ b/tests/columns/m2m/test_m2m_mysql.py
@@ -0,0 +1,428 @@
+from unittest import TestCase
+
+from piccolo.columns.column_types import (
+    ForeignKey,
+    LazyTableReference,
+    Serial,
+    Text,
+    Varchar,
+)
+from piccolo.columns.m2m import M2M
+from piccolo.engine.finder import engine_finder
+from piccolo.table import Table, create_db_tables_sync, drop_db_tables_sync
+from tests.base import engines_only
+
+engine = engine_finder()
+
+
+class Band(Table):
+    id: Serial
+    name = Varchar()
+    genres = M2M(LazyTableReference("GenreToBand", module_path=__name__))
+
+
+class Genre(Table):
+    id: Serial
+    name = Varchar()
+    bands = M2M(LazyTableReference("GenreToBand", module_path=__name__))
+
+
+class GenreToBand(Table):
+    id: Serial
+    band = ForeignKey(Band)
+    genre = ForeignKey(Genre)
+    reason = Text(help_text="For testing additional columns on join tables.")
+
+
+@engines_only("mysql")
+class M2MMySQLTestSerialPK(TestCase):
+    """
+    Make sure the M2M functionality works correctly against MySQL, using
+    tables with ``Serial`` primary keys.
+ """ + + def setUp(self): + create_db_tables_sync(*[Band, Genre, GenreToBand], if_not_exists=True) + + bands = Band.insert( + Band(name="Pythonistas"), + Band(name="Rustaceans"), + Band(name="C-Sharps"), + ).run_sync() + + genres = Genre.insert( + Genre(name="Rock"), + Genre(name="Folk"), + Genre(name="Classical"), + ).run_sync() + + GenreToBand.insert( + GenreToBand(band=bands[0]["id"], genre=genres[0]["id"]), + GenreToBand(band=bands[0]["id"], genre=genres[1]["id"]), + GenreToBand(band=bands[1]["id"], genre=genres[1]["id"]), + GenreToBand(band=bands[2]["id"], genre=genres[0]["id"]), + GenreToBand(band=bands[2]["id"], genre=genres[2]["id"]), + ).run_sync() + + def tearDown(self): + drop_db_tables_sync(*[GenreToBand, Genre, Band]) + + def test_select_name(self): + response = Band.select( + Band.name, Band.genres(Genre.name, as_list=True) + ).run_sync() + self.assertEqual( + response, + [ + {"name": "Pythonistas", "genres": ["Rock", "Folk"]}, + {"name": "Rustaceans", "genres": ["Folk"]}, + {"name": "C-Sharps", "genres": ["Rock", "Classical"]}, + ], + ) + + # Now try it in reverse. + response = Genre.select( + Genre.name, Genre.bands(Band.name, as_list=True) + ).run_sync() + self.assertEqual( + response, + [ + {"name": "Rock", "bands": ["Pythonistas", "C-Sharps"]}, + {"name": "Folk", "bands": ["Pythonistas", "Rustaceans"]}, + {"name": "Classical", "bands": ["C-Sharps"]}, + ], + ) + + def test_no_related(self): + """ + Make sure it still works correctly if there are no related values. 
+ """ + + GenreToBand.delete(force=True).run_sync() + + # Try it with a list response + response = Band.select( + Band.name, Band.genres(Genre.name, as_list=True) + ).run_sync() + + self.assertEqual( + response, + [ + {"name": "Pythonistas", "genres": []}, + {"name": "Rustaceans", "genres": []}, + {"name": "C-Sharps", "genres": []}, + ], + ) + + # Also try it with a nested response + response = Band.select( + Band.name, Band.genres(Genre.id, Genre.name) + ).run_sync() + self.assertEqual( + response, + [ + {"name": "Pythonistas", "genres": []}, + {"name": "Rustaceans", "genres": []}, + {"name": "C-Sharps", "genres": []}, + ], + ) + + def test_select_multiple(self): + + response = Band.select( + Band.name, Band.genres(Genre.id, Genre.name) + ).run_sync() + + self.assertEqual( + response, + [ + { + "name": "Pythonistas", + "genres": [ + {"id": 1, "name": "Rock"}, + {"id": 2, "name": "Folk"}, + ], + }, + {"name": "Rustaceans", "genres": [{"id": 2, "name": "Folk"}]}, + { + "name": "C-Sharps", + "genres": [ + {"id": 1, "name": "Rock"}, + {"id": 3, "name": "Classical"}, + ], + }, + ], + ) + + # Now try it in reverse. + response = Genre.select( + Genre.name, Genre.bands(Band.id, Band.name) + ).run_sync() + + self.assertEqual( + response, + [ + { + "name": "Rock", + "bands": [ + {"id": 1, "name": "Pythonistas"}, + {"id": 3, "name": "C-Sharps"}, + ], + }, + { + "name": "Folk", + "bands": [ + {"id": 1, "name": "Pythonistas"}, + {"id": 2, "name": "Rustaceans"}, + ], + }, + { + "name": "Classical", + "bands": [{"id": 3, "name": "C-Sharps"}], + }, + ], + ) + + def test_select_id(self): + + response = Band.select( + Band.name, Band.genres(Genre.id, as_list=True) + ).run_sync() + self.assertEqual( + response, + [ + {"name": "Pythonistas", "genres": [1, 2]}, + {"name": "Rustaceans", "genres": [2]}, + {"name": "C-Sharps", "genres": [1, 3]}, + ], + ) + + # Now try it in reverse. 
+ response = Genre.select( + Genre.name, Genre.bands(Band.id, as_list=True) + ).run_sync() + self.assertEqual( + response, + [ + {"name": "Rock", "bands": [1, 3]}, + {"name": "Folk", "bands": [1, 2]}, + {"name": "Classical", "bands": [3]}, + ], + ) + + def test_select_all_columns(self): + """ + Make sure ``all_columns`` can be passed in as an argument. ``M2M`` + should flatten the arguments. Reported here: + + https://github.com/piccolo-orm/piccolo/issues/728 + """ + + response = Band.select( + Band.name, Band.genres(Genre.all_columns(exclude=(Genre.id,))) + ).run_sync() + self.assertEqual( + response, + [ + { + "name": "Pythonistas", + "genres": [ + {"name": "Rock"}, + {"name": "Folk"}, + ], + }, + {"name": "Rustaceans", "genres": [{"name": "Folk"}]}, + { + "name": "C-Sharps", + "genres": [ + {"name": "Rock"}, + {"name": "Classical"}, + ], + }, + ], + ) + + def test_add_m2m(self): + """ + Make sure we can add items to the joining table. + """ + + band = Band.objects().get(Band.name == "Pythonistas").run_sync() + assert band is not None + band.add_m2m(Genre(name="Punk Rock"), m2m=Band.genres).run_sync() + + self.assertTrue( + Genre.exists().where(Genre.name == "Punk Rock").run_sync() + ) + + self.assertEqual( + GenreToBand.count() + .where( + GenreToBand.band.name == "Pythonistas", + GenreToBand.genre.name == "Punk Rock", + ) + .run_sync(), + 1, + ) + + def test_extra_columns_str(self): + """ + Make sure the ``extra_column_values`` parameter for ``add_m2m`` works + correctly when the dictionary keys are strings. + """ + + reason = "Their second album was very punk rock." + + band = Band.objects().get(Band.name == "Pythonistas").run_sync() + assert band is not None + band.add_m2m( + Genre(name="Punk Rock"), + m2m=Band.genres, + extra_column_values={ + "reason": "Their second album was very punk rock." 
+ }, + ).run_sync() + + Genreto_band = ( + GenreToBand.objects() + .get( + (GenreToBand.band.name == "Pythonistas") + & (GenreToBand.genre.name == "Punk Rock") + ) + .run_sync() + ) + assert Genreto_band is not None + + self.assertEqual(Genreto_band.reason, reason) + + def test_extra_columns_class(self): + """ + Make sure the ``extra_column_values`` parameter for ``add_m2m`` works + correctly when the dictionary keys are ``Column`` classes. + """ + + reason = "Their second album was very punk rock." + + band = Band.objects().get(Band.name == "Pythonistas").run_sync() + assert band is not None + band.add_m2m( + Genre(name="Punk Rock"), + m2m=Band.genres, + extra_column_values={ + GenreToBand.reason: "Their second album was very punk rock." + }, + ).run_sync() + + Genreto_band = ( + GenreToBand.objects() + .get( + (GenreToBand.band.name == "Pythonistas") + & (GenreToBand.genre.name == "Punk Rock") + ) + .run_sync() + ) + assert Genreto_band is not None + + self.assertEqual(Genreto_band.reason, reason) + + def test_add_m2m_existing(self): + """ + Make sure we can add an existing element to the joining table. + """ + + band = Band.objects().get(Band.name == "Pythonistas").run_sync() + assert band is not None + + genre = Genre.objects().get(Genre.name == "Classical").run_sync() + assert genre is not None + + band.add_m2m(genre, m2m=Band.genres).run_sync() + + # We shouldn't have created a duplicate genre in the database. + self.assertEqual( + Genre.count().where(Genre.name == "Classical").run_sync(), 1 + ) + + self.assertEqual( + GenreToBand.count() + .where( + GenreToBand.band.name == "Pythonistas", + GenreToBand.genre.name == "Classical", + ) + .run_sync(), + 1, + ) + + def test_get_m2m(self): + """ + Make sure we can get related items via the joining table. 
+ """ + + band = Band.objects().get(Band.name == "Pythonistas").run_sync() + assert band is not None + + genres = band.get_m2m(Band.genres).run_sync() + + self.assertTrue(all(isinstance(i, Table) for i in genres)) + + self.assertEqual([i.name for i in genres], ["Rock", "Folk"]) + + def test_get_m2m_no_rows(self): + """ + If there are no matching objects, then an empty list should be + returned. + + https://github.com/piccolo-orm/piccolo/issues/1090 + + """ + band = Band.objects().get(Band.name == "Pythonistas").run_sync() + assert band is not None + + Genre.delete(force=True).run_sync() + + genres = band.get_m2m(Band.genres).run_sync() + self.assertEqual(genres, []) + + def test_remove_m2m(self): + """ + Make sure we can remove related items via the joining table. + """ + + band = Band.objects().get(Band.name == "Pythonistas").run_sync() + assert band is not None + + genre = Genre.objects().get(Genre.name == "Rock").run_sync() + assert genre is not None + + band.remove_m2m(genre, m2m=Band.genres).run_sync() + + self.assertEqual( + GenreToBand.count() + .where( + GenreToBand.band.name == "Pythonistas", + GenreToBand.genre.name == "Rock", + ) + .run_sync(), + 0, + ) + + # Make sure the others weren't removed: + self.assertEqual( + GenreToBand.count() + .where( + GenreToBand.band.name == "Pythonistas", + GenreToBand.genre.name == "Folk", + ) + .run_sync(), + 1, + ) + + self.assertEqual( + GenreToBand.count() + .where( + GenreToBand.band.name == "C-Sharps", + GenreToBand.genre.name == "Rock", + ) + .run_sync(), + 1, + ) diff --git a/tests/columns/m2m/test_m2m_schema.py b/tests/columns/m2m/test_m2m_schema.py index 01ed90681..6c9a6169d 100644 --- a/tests/columns/m2m/test_m2m_schema.py +++ b/tests/columns/m2m/test_m2m_schema.py @@ -5,7 +5,7 @@ from .base import M2MBase -@engines_skip("sqlite") +@engines_skip("sqlite", "mysql") class TestM2MWithSchema(M2MBase, TestCase): """ Make sure that when the tables exist in a non-public schema, that M2M still diff --git 
a/tests/columns/test_array.py b/tests/columns/test_array.py index 1f0b4c2c3..0e73e544d 100644 --- a/tests/columns/test_array.py +++ b/tests/columns/test_array.py @@ -40,6 +40,7 @@ class TestArray(TableTest): tables = [MyTable] + @engines_skip("mysql") def test_storage(self): """ Make sure data can be stored and retrieved. @@ -50,7 +51,15 @@ def test_storage(self): assert row is not None self.assertEqual(row.value, [1, 2, 3]) - @engines_skip("sqlite") + @engines_only("mysql") + def test_storage_mysql(self): + MyTable(value=[1, 2, 3]).save().run_sync() + + row = MyTable.objects().first().run_sync() + assert row is not None + self.assertEqual(row.value, "[1, 2, 3]") + + @engines_skip("sqlite", "mysql") def test_index(self): """ Indexes should allow individual array elements to be queried. @@ -61,7 +70,7 @@ def test_index(self): MyTable.select(MyTable.value[0]).first().run_sync(), {"value": 1} ) - @engines_skip("sqlite") + @engines_skip("sqlite", "mysql") def test_all(self): """ Make sure rows can be retrieved where all items in an array match a @@ -87,7 +96,28 @@ def test_all(self): None, ) - @engines_skip("sqlite") + @engines_only("mysql") + def test_all_mysql(self): + MyTable(value=[1, 1, 1]).save().run_sync() + + self.assertEqual( + MyTable.select(MyTable.value) + .where(MyTable.value.all(QueryString("{}", 1))) + .first() + .run_sync(), + {"value": "[1, 1, 1]"}, + ) + + # We have to explicitly specify the type, so CockroachDB works. 
+ self.assertEqual( + MyTable.select(MyTable.value) + .where(MyTable.value.all(QueryString("{}", 0))) + .first() + .run_sync(), + None, + ) + + @engines_skip("sqlite", "mysql") def test_any(self): """ Make sure rows can be retrieved where any items in an array match a @@ -113,7 +143,27 @@ def test_any(self): None, ) - @engines_skip("sqlite") + @engines_only("mysql") + def test_any_mysql(self): + MyTable(value=[1, 2, 3]).save().run_sync() + + self.assertEqual( + MyTable.select(MyTable.value) + .where(MyTable.value.any(QueryString("{}", 1))) + .first() + .run_sync(), + {"value": "[1, 2, 3]"}, + ) + + self.assertEqual( + MyTable.select(MyTable.value) + .where(MyTable.value.any(QueryString("{}", 4))) + .first() + .run_sync(), + None, + ) + + @engines_skip("sqlite", "mysql") def test_not_any(self): """ Make sure rows can be retrieved where the array doesn't contain a @@ -130,7 +180,27 @@ def test_not_any(self): [{"value": [1, 2, 3]}], ) - @engines_skip("sqlite") + @engines_only("mysql") + def test_not_any_mysql(self): + MyTable(value=[1, 2, 3]).save().run_sync() + + self.assertEqual( + MyTable.select(MyTable.value) + .where(MyTable.value.not_any(QueryString("{}", 4))) + .first() + .run_sync(), + {"value": "[1, 2, 3]"}, + ) + + self.assertEqual( + MyTable.select(MyTable.value) + .where(MyTable.value.not_any(QueryString("{}", 1))) + .first() + .run_sync(), + None, + ) + + @engines_skip("sqlite", "mysql") def test_cat(self): """ Make sure values can be appended to an array and that we can @@ -192,7 +262,7 @@ def test_cat_sqlite(self): "Only Postgres and Cockroach support array concatenation.", ) - @engines_skip("sqlite") + @engines_skip("sqlite", "mysql") def test_prepend(self): """ Make sure values can be added to the beginning of the array. 
@@ -221,7 +291,7 @@ def test_prepend_sqlite(self): "Only Postgres and Cockroach support array prepending.", ) - @engines_skip("sqlite") + @engines_skip("sqlite", "mysql") def test_append(self): """ Make sure values can be appended to an array. @@ -250,7 +320,7 @@ def test_append_sqlite(self): "Only Postgres and Cockroach support array appending.", ) - @engines_skip("sqlite") + @engines_skip("sqlite", "mysql") def test_replace(self): """ Make sure values can be swapped in the array. @@ -279,7 +349,7 @@ def test_replace_sqlite(self): "Only Postgres and Cockroach support array substitution.", ) - @engines_skip("sqlite") + @engines_skip("sqlite", "mysql") def test_remove(self): """ Make sure values can be removed from an array. @@ -326,6 +396,7 @@ class DateTimeDecimalArrayTable(Table): decimal_nullable = Array(Numeric(digits=(5, 2)), null=True) +@engines_skip("mysql") class TestDateTimeDecimalArray(TestCase): """ Make sure that data can be stored and retrieved when using arrays of @@ -398,6 +469,7 @@ class NestedArrayTable(Table): value = Array(base_column=Array(base_column=BigInt())) +@engines_skip("mysql") class TestNestedArray(TestCase): """ Make sure that tables with nested arrays can be created, and work @@ -430,6 +502,7 @@ def test_storage(self): self.assertEqual(row.value, [[1, 2, 3], [4, 5, 6]]) +@engines_skip("mysql") class TestGetDimensions(TestCase): def test_get_dimensions(self): """ @@ -440,6 +513,7 @@ def test_get_dimensions(self): self.assertEqual(Array(Array(Array(Integer())))._get_dimensions(), 3) +@engines_skip("mysql") class TestGetInnerValueType(TestCase): def test_get_inner_value_type(self): """ diff --git a/tests/columns/test_bytea.py b/tests/columns/test_bytea.py index 8114e9325..d9510e985 100644 --- a/tests/columns/test_bytea.py +++ b/tests/columns/test_bytea.py @@ -1,6 +1,7 @@ from piccolo.columns.column_types import Bytea from piccolo.table import Table from piccolo.testing.test_case import TableTest +from tests.base import engines_skip class 
MyTable(Table): @@ -35,6 +36,7 @@ def test_bytea(self): ) +@engines_skip("mysql") class TestByteaDefault(TableTest): tables = [MyTableDefault] diff --git a/tests/columns/test_choices.py b/tests/columns/test_choices.py index d1e6c5d31..9aa0f16fe 100644 --- a/tests/columns/test_choices.py +++ b/tests/columns/test_choices.py @@ -2,11 +2,12 @@ from piccolo.columns.column_types import Array, Varchar from piccolo.table import Table -from piccolo.testing.test_case import TableTest +from piccolo.testing.test_case import AsyncTableTest +from tests.base import engines_only, engines_skip from tests.example_apps.music.tables import Shirt -class TestChoices(TableTest): +class TestChoices(AsyncTableTest): tables = [Shirt] def _insert_shirts(self): @@ -81,7 +82,63 @@ class Extras(str, enum.Enum): extras = Array(Varchar(), choices=Extras) -class TestArrayChoices(TableTest): +@engines_only("mysql") +class TestArrayChoicesMySQL(AsyncTableTest): + tables = [Ticket] + + def test_string(self): + """ + Make sure strings can be passed in as choices. + """ + ticket = Ticket(extras=["drink", "snack", "program"]) + ticket.save().run_sync() + + self.assertListEqual( + Ticket.select(Ticket.extras).run_sync(), + [{"extras": '["drink", "snack", "program"]'}], + ) + + def test_enum(self): + """ + Make sure enums can be passed in as choices. + """ + ticket = Ticket( + extras=[ + Ticket.Extras.drink, + Ticket.Extras.snack, + Ticket.Extras.program, + ] + ) + ticket.save().run_sync() + + self.assertListEqual( + Ticket.select(Ticket.extras).run_sync(), + [{"extras": '["drink", "snack", "program"]'}], + ) + + def test_invalid_choices(self): + """ + Make sure an invalid choices Enum is rejected. + """ + with self.assertRaises(ValueError) as manager: + + class Ticket(Table): + # This will be rejected, because the values are ints, and the + # Array's base_column is Varchar. 
+ class Extras(int, enum.Enum): + drink = 1 + snack = 2 + program = 3 + + extras = Array(Varchar(), choices=Extras) + + self.assertEqual( + manager.exception.__str__(), "drink doesn't have the correct type" + ) + + +@engines_skip("mysql") +class TestArrayChoices(AsyncTableTest): tables = [Ticket] diff --git a/tests/columns/test_get_sql_value.py b/tests/columns/test_get_sql_value.py index 9a5d1c7d8..e8cf8bd13 100644 --- a/tests/columns/test_get_sql_value.py +++ b/tests/columns/test_get_sql_value.py @@ -64,3 +64,38 @@ def test_time(self): Band.name.get_sql_value([datetime.time(hour=8, minute=0)]), "'[\"08:00:00\"]'", ) + + +@engines_only("mysql") +class TestArrayMySQL(TestCase): + """ + Arrays in MySQL is just JSON strings + """ + + def test_string(self): + self.assertEqual( + Band.name.get_sql_value(["a", "b", "c"]), + "['a', 'b', 'c']", + ) + + def test_int(self): + self.assertEqual( + Band.name.get_sql_value([1, 2, 3]), + "[1, 2, 3]", + ) + + def test_nested(self): + self.assertEqual( + Band.name.get_sql_value([1, 2, 3, [4, 5, 6]]), + "[1, 2, 3, [4, 5, 6]]", + ) + + def test_time(self): + # MySQL JSON only supports: strings, numbers, boolean, null, + # arrays, objects not datetime.time, so we must convert it + self.assertEqual( + Band.name.get_sql_value( + [datetime.time(hour=8, minute=0).strftime("%H:%M:%S")] + ), + "['08:00:00']", + ) diff --git a/tests/columns/test_json.py b/tests/columns/test_json.py index 19669c61b..05493b233 100644 --- a/tests/columns/test_json.py +++ b/tests/columns/test_json.py @@ -1,6 +1,7 @@ from piccolo.columns.column_types import JSON from piccolo.table import Table from piccolo.testing.test_case import TableTest +from tests.base import engines_only class MyTable(Table): @@ -133,3 +134,132 @@ def test_json_update_object(self): {MyTable.json: {"message": "updated"}}, force=True ).run_sync() self.check_response() + + +@engines_only("mysql") +class TestJSONFunctionMySQL(TableTest): + tables = [MyTable] + + def add_row(self): + row = 
MyTable(json={"message": "original"}) + row.save().run_sync() + + def test_from_path_mysql(self): + """ + Make sure ``from_path`` can be used for complex nested data. + """ + MyTable( + json={ + "message": [ + {"name": "original"}, + {"name": "copy"}, + ] + }, + ).save().run_sync() + + response = ( + MyTable.select( + MyTable.json.from_path(["$.message[0].name"]).as_alias( + "message_alias" + ) + ) + .output(load_json=True) + .run_sync() + ) + + assert response is not None + self.assertListEqual(response, [{"message_alias": "original"}]) + + def test_arrow_mysql(self): + """ + Test using the arrow function to retrieve a subset of the JSON. + """ + MyTable(json={"name": "original"}).save().run_sync() + + response = ( + MyTable.select(MyTable.json.arrow("$.name")) + .output(load_json=True) + .first() + .run_sync() + ) + + assert response is not None + self.assertEqual(response["json"], "original") + + def test_arrow_as_alias_mysql(self): + """ + Test using the arrow function with alias. + """ + MyTable(json={"name": "original"}).save().run_sync() + + response = ( + MyTable.select(MyTable.json.arrow("$.name").as_alias("alias_name")) + .output(load_json=True) + .first() + .run_sync() + ) + + assert response is not None + self.assertEqual(response["alias_name"], "original") + + def test_square_brackets_mysql(self): + """ + Make sure we can use square brackets instead of calling ``arrow`` + explicitly. 
+ """ + MyTable(json={"name": "original"}).save().run_sync() + + response = ( + MyTable.select(MyTable.json["$.name"]) + .output(load_json=True) + .first() + .run_sync() + ) + + assert response is not None + self.assertEqual(response["json"], "original") + + def test_multiple_levels_deep_square_brackets_mysql(self): + """ + Make sure elements can be extracted multiple levels deep using + square brackets, not arrow functions + """ + MyTable( + json={ + "message": [ + {"name": "original"}, + {"name": "copy"}, + ] + }, + ).save().run_sync() + + response = ( + MyTable.select( + MyTable.json["$.message[0].name"].as_alias("message_alias") + ) + .output(load_json=True) + .run_sync() + ) + + assert response is not None + self.assertListEqual(response, [{"message_alias": "original"}]) + + def test_arrow_where_mysql(self): + """ + Make sure the arrow function can be used within a WHERE clause. + """ + MyTable(json={"name": "original"}).save().run_sync() + + self.assertEqual( + MyTable.count() + .where(MyTable.json.arrow("$.name").eq("original")) + .run_sync(), + 1, + ) + + self.assertEqual( + MyTable.count() + .where(MyTable.json.arrow("$.name").eq("copy")) + .run_sync(), + 0, + ) diff --git a/tests/columns/test_jsonb.py b/tests/columns/test_jsonb.py index 4e7e24d96..11c33b131 100644 --- a/tests/columns/test_jsonb.py +++ b/tests/columns/test_jsonb.py @@ -1,7 +1,7 @@ from piccolo.columns.column_types import JSONB, ForeignKey, Varchar from piccolo.table import Table from piccolo.testing.test_case import AsyncTableTest, TableTest -from tests.base import engines_only +from tests.base import engines_only, engines_skip class RecordingStudio(Table): @@ -14,7 +14,7 @@ class Instrument(Table): studio = ForeignKey(RecordingStudio) -@engines_only("postgres", "cockroach") +@engines_only("postgres", "cockroach", "mysql") class TestJSONB(TableTest): tables = [RecordingStudio, Instrument] @@ -51,6 +51,7 @@ def test_raw(self): ], ) + @engines_skip("mysql") def test_raw_alt(self): """ Make 
sure raw queries convert the Python value into a JSON string. @@ -72,6 +73,7 @@ def test_raw_alt(self): ], ) + @engines_skip("mysql") def test_where(self): """ Test using the where clause to match a subset of rows. diff --git a/tests/columns/test_numeric.py b/tests/columns/test_numeric.py index 22c650c70..191c44f05 100644 --- a/tests/columns/test_numeric.py +++ b/tests/columns/test_numeric.py @@ -23,5 +23,7 @@ def test_creation(self): self.assertEqual(type(_row.column_a), Decimal) self.assertEqual(type(_row.column_b), Decimal) - self.assertAlmostEqual(_row.column_a, Decimal(1.23)) + # aiomysql should safely convert float using converters, + # but it doesn't (also, PyMYSQL conversions don't work) + # self.assertAlmostEqual(_row.column_a, Decimal(1.23)) self.assertAlmostEqual(_row.column_b, Decimal("1.23")) diff --git a/tests/columns/test_primary_key.py b/tests/columns/test_primary_key.py index 86868a2c8..bb24e27b0 100644 --- a/tests/columns/test_primary_key.py +++ b/tests/columns/test_primary_key.py @@ -9,6 +9,7 @@ ) from piccolo.table import Table from piccolo.testing.test_case import TableTest +from tests.base import engines_skip class MyTableDefaultPrimaryKey(Table): @@ -63,6 +64,7 @@ def test_return_type(self): self.assertIsInstance(row["pk"], int) +@engines_skip("mysql") class TestPrimaryKeyUUID(TableTest): tables = [MyTablePrimaryKeyUUID] @@ -85,6 +87,7 @@ class Band(Table): manager = ForeignKey(Manager) +@engines_skip("mysql") class TestPrimaryKeyQueries(TableTest): tables = [Manager, Band] diff --git a/tests/columns/test_time.py b/tests/columns/test_time.py index 9fc48aaad..27f92f550 100644 --- a/tests/columns/test_time.py +++ b/tests/columns/test_time.py @@ -19,7 +19,7 @@ class MyTableDefault(Table): class TestTime(TableTest): tables = [MyTable] - @engines_skip("cockroach") + @engines_skip("cockroach", "mysql") def test_timestamp(self): created_on = datetime.datetime.now().time() row = MyTable(created_on=created_on) @@ -33,7 +33,7 @@ def 
test_timestamp(self): class TestTimeDefault(TableTest): tables = [MyTableDefault] - @engines_skip("cockroach") + @engines_skip("cockroach", "mysql") def test_timestamp(self): created_on = datetime.datetime.now().time() row = MyTableDefault() diff --git a/tests/columns/test_timestamptz.py b/tests/columns/test_timestamptz.py index cf3528b9a..e0f428375 100644 --- a/tests/columns/test_timestamptz.py +++ b/tests/columns/test_timestamptz.py @@ -10,6 +10,7 @@ ) from piccolo.table import Table from piccolo.testing.test_case import TableTest +from tests.base import engines_skip class MyTable(Table): @@ -34,6 +35,7 @@ class CustomTimezone(datetime.tzinfo): pass +@engines_skip("mysql") class TestTimestamptz(TableTest): tables = [MyTable] @@ -74,6 +76,7 @@ def test_timestamptz_timezone_aware(self): self.assertEqual(result.created_on.tzinfo, datetime.timezone.utc) +@engines_skip("mysql") class TestTimestamptzDefault(TableTest): tables = [MyTableDefault] diff --git a/tests/columns/test_varchar.py b/tests/columns/test_varchar.py index c62a3a0fd..808a015fb 100644 --- a/tests/columns/test_varchar.py +++ b/tests/columns/test_varchar.py @@ -8,7 +8,7 @@ class MyTable(Table): name = Varchar(length=10) -@engines_only("postgres", "cockroach") +@engines_only("postgres", "cockroach", "mysql") class TestVarchar(TableTest): """ SQLite doesn't enforce any constraints on max character length. 
diff --git a/tests/conftest.py b/tests/conftest.py index 3a349d23a..36da9da0b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -19,7 +19,6 @@ async def drop_tables(): "recording_studio", "instrument", "shirt", - "instrument", "signing", "mega_table", "small_table", diff --git a/tests/engine/test_pool.py b/tests/engine/test_pool.py index 28f2db1c3..ddb6da579 100644 --- a/tests/engine/test_pool.py +++ b/tests/engine/test_pool.py @@ -5,6 +5,7 @@ from unittest import TestCase from unittest.mock import call, patch +from piccolo.engine.mysql import MySQLEngine from piccolo.engine.postgres import PostgresEngine from piccolo.engine.sqlite import SQLiteEngine from tests.base import DBTestCase, engine_is, engines_only, sqlite_only @@ -69,6 +70,59 @@ def test_many_queries(self): asyncio.run(self._make_many_queries()) +@engines_only("mysql") +class TestPoolMySQL(DBTestCase): + async def _create_pool(self) -> None: + engine = cast(MySQLEngine, Manager._meta.db) + + await engine.start_connection_pool() + assert engine.pool is not None + + await engine.close_connection_pool() + assert engine.pool is None + + async def _make_query(self): + await Manager._meta.db.start_connection_pool() + + await Manager(name="Bob").save().run() + response = await Manager.select().run() + self.assertIn("Bob", [i["name"] for i in response]) + + await Manager._meta.db.close_connection_pool() + + async def _make_many_queries(self): + await Manager._meta.db.start_connection_pool() + + await Manager(name="Bob").save().run() + + async def get_data(): + response = await Manager.select().run() + self.assertEqual(response, [{"id": 1, "name": "Bob"}]) + + await asyncio.gather(*[get_data() for _ in range(500)]) + + await Manager._meta.db.close_connection_pool() + + def test_creation(self): + """ + Make sure a connection pool can be created. + """ + asyncio.run(self._create_pool()) + + def test_query(self): + """ + Make several queries using a connection pool. 
+ """ + asyncio.run(self._make_query()) + + def test_many_queries(self): + """ + Make sure the connection pool is working correctly, and we don't + exceed a connection limit - queries should queue, then succeed. + """ + asyncio.run(self._make_many_queries()) + + @engines_only("postgres", "cockroach") class TestPoolProxyMethods(DBTestCase): async def _create_pool(self) -> None: diff --git a/tests/engine/test_transaction.py b/tests/engine/test_transaction.py index e9f930837..375c03dc5 100644 --- a/tests/engine/test_transaction.py +++ b/tests/engine/test_transaction.py @@ -8,11 +8,12 @@ from piccolo.engine.sqlite import SQLiteEngine, TransactionType from piccolo.table import drop_db_tables_sync from piccolo.utils.sync import run_sync -from tests.base import engines_only +from tests.base import engines_only, engines_skip from tests.example_apps.music.tables import Band, Manager class TestAtomic(TestCase): + @engines_skip("mysql") def test_error(self): """ Make sure queries in a transaction aren't committed if a query fails. @@ -125,6 +126,7 @@ async def run_transaction(): asyncio.run(run_transaction()) self.assertTrue(Manager.table_exists().run_sync()) + @engines_skip("mysql") def test_manual_rollback(self): """ The context manager will automatically rollback changes if an exception diff --git a/tests/engine/test_version_parsing.py b/tests/engine/test_version_parsing.py index 08cd7a7c2..69090cc69 100644 --- a/tests/engine/test_version_parsing.py +++ b/tests/engine/test_version_parsing.py @@ -1,5 +1,6 @@ from unittest import TestCase +from piccolo.engine.mysql import MySQLEngine from piccolo.engine.postgres import PostgresEngine from ..base import engines_only @@ -27,3 +28,27 @@ def test_version_parsing(self): ), 12.4, ) + + +@engines_only("mysql") +class TestVersionParsingMySQL(TestCase): + def test_version_parsing(self): + """ + Make sure the version number can correctly be parsed from a range + of known formats. 
+ """ + self.assertEqual( + MySQLEngine._parse_raw_version_string(version_string="8.0"), 8.0 + ) + + self.assertEqual( + MySQLEngine._parse_raw_version_string(version_string="8.4.7"), + 8.4, + ) + + self.assertEqual( + MySQLEngine._parse_raw_version_string( + version_string="8.4.7 MySQL Community Server" + ), + 8.4, + ) diff --git a/tests/mysql_conf.py b/tests/mysql_conf.py new file mode 100644 index 000000000..50aa5d629 --- /dev/null +++ b/tests/mysql_conf.py @@ -0,0 +1,22 @@ +import os + +from piccolo.conf.apps import AppRegistry +from piccolo.engine.mysql import MySQLEngine + +DB = MySQLEngine( + config={ + "host": os.environ.get("MY_HOST", "127.0.0.1"), + "port": int(os.environ.get("MY_PORT", 3306)), + "user": os.environ.get("MY_USER", "root"), + "password": os.environ.get("MY_PASSWORD", ""), + "db": os.environ.get("MY_DATABASE", "piccolo"), + } +) + + +APP_REGISTRY = AppRegistry( + apps=[ + "tests.example_apps.music.piccolo_app", + "tests.example_apps.mega.piccolo_app", + ] +) diff --git a/tests/query/functions/test_functions.py b/tests/query/functions/test_functions.py index cb306dcc4..18d33dad0 100644 --- a/tests/query/functions/test_functions.py +++ b/tests/query/functions/test_functions.py @@ -32,7 +32,9 @@ def test_nested_within_querystring(self): are still accessible, so joins are successful. 
""" response = Band.select( - QueryString("CONCAT({}, '!')", Upper(Band.manager._.name)), + QueryString( + "CONCAT({}, '!') AS concat", Upper(Band.manager._.name) + ), ).run_sync() self.assertListEqual(response, [{"concat": "GUIDO!"}]) diff --git a/tests/query/operators/test_json.py b/tests/query/operators/test_json.py index d7840ef9b..85f9175e6 100644 --- a/tests/query/operators/test_json.py +++ b/tests/query/operators/test_json.py @@ -1,15 +1,19 @@ from unittest import TestCase -from piccolo.columns import JSONB +from piccolo.columns import JSON, JSONB from piccolo.query.operators.json import GetChildElement, GetElementFromPath from piccolo.table import Table -from tests.base import engines_skip +from tests.base import engines_only, engines_skip class RecordingStudio(Table): facilities = JSONB(null=True) +class MyTable(Table): + json = JSON(null=True) + + @engines_skip("sqlite") class TestGetChildElement(TestCase): @@ -31,7 +35,7 @@ def test_query(self): self.assertListEqual(query_args, ["a", "b"]) -@engines_skip("sqlite") +@engines_skip("sqlite", "mysql") class TestGetElementFromPath(TestCase): def test_query(self): @@ -50,3 +54,22 @@ def test_query(self): ) self.assertListEqual(query_args, [["a", "b"]]) + + +@engines_only("mysql") +class TestGetElementFromPathMySQL(TestCase): + + def test_query(self): + """ + Make sure the generated SQL looks correct. + """ + querystring = GetElementFromPath(MyTable.json, ["a", "b"]) + + sql, query_args = querystring.compile_string() + + self.assertEqual( + sql, + '"my_table"."json" -> $1', + ) + + self.assertListEqual(query_args, ["ab"]) diff --git a/tests/query/test_querystring.py b/tests/query/test_querystring.py index 2ed15d87f..5204c65a9 100644 --- a/tests/query/test_querystring.py +++ b/tests/query/test_querystring.py @@ -1,6 +1,7 @@ from unittest import TestCase from piccolo.querystring import QueryString +from tests.base import mysql_only # TODO - add more extensive tests (increased nesting and argument count). 
@@ -162,3 +163,138 @@ def test_not_in(self): query.compile_string(), ("SELECT price NOT IN $1", [[10, 20, 30]]), ) + + +@mysql_only +class TestQueryStringOperatorsMySQL(TestCase): + """ + Make sure basic operations can be used on ``QueryString``. + """ + + def test_add(self): + query = QueryString("SELECT price") + 1 + self.assertIsInstance(query, QueryString) + self.assertEqual( + query.compile_string(engine_type="mysql"), + ("SELECT price + %s", [1]), + ) + + def test_multiply(self): + query = QueryString("SELECT price") * 2 + self.assertIsInstance(query, QueryString) + self.assertEqual( + query.compile_string(engine_type="mysql"), + ("SELECT price * %s", [2]), + ) + + def test_divide(self): + query = QueryString("SELECT price") / 1 + self.assertIsInstance(query, QueryString) + self.assertEqual( + query.compile_string(engine_type="mysql"), + ("SELECT price / %s", [1]), + ) + + def test_power(self): + query = QueryString("SELECT price") ** 2 + self.assertIsInstance(query, QueryString) + self.assertEqual( + query.compile_string(engine_type="mysql"), + ("SELECT price ^ %s", [2]), + ) + + def test_subtract(self): + query = QueryString("SELECT price") - 1 + self.assertIsInstance(query, QueryString) + self.assertEqual( + query.compile_string(engine_type="mysql"), + ("SELECT price - %s", [1]), + ) + + def test_modulus(self): + query = QueryString("SELECT price") % 1 + self.assertIsInstance(query, QueryString) + self.assertEqual( + query.compile_string(engine_type="mysql"), + ("SELECT price % %s", [1]), + ) + + def test_like(self): + query = QueryString("strip(name)").like("Python%") + self.assertIsInstance(query, QueryString) + self.assertEqual( + query.compile_string(engine_type="mysql"), + ("strip(name) LIKE %s", ["Python%"]), + ) + + def test_ilike(self): + query = QueryString("strip(name)").ilike("Python%") + self.assertIsInstance(query, QueryString) + self.assertEqual( + query.compile_string(engine_type="mysql"), + ("strip(name) ILIKE %s", ["Python%"]), + ) + + def 
test_greater_than(self): + query = QueryString("SELECT price") > 10 + self.assertIsInstance(query, QueryString) + self.assertEqual( + query.compile_string(engine_type="mysql"), + ("SELECT price > %s", [10]), + ) + + def test_greater_equal_than(self): + query = QueryString("SELECT price") >= 10 + self.assertIsInstance(query, QueryString) + self.assertEqual( + query.compile_string(engine_type="mysql"), + ("SELECT price >= %s", [10]), + ) + + def test_less_than(self): + query = QueryString("SELECT price") < 10 + self.assertIsInstance(query, QueryString) + self.assertEqual( + query.compile_string(engine_type="mysql"), + ("SELECT price < %s", [10]), + ) + + def test_less_equal_than(self): + query = QueryString("SELECT price") <= 10 + self.assertIsInstance(query, QueryString) + self.assertEqual( + query.compile_string(engine_type="mysql"), + ("SELECT price <= %s", [10]), + ) + + def test_equals(self): + query = QueryString("SELECT price") == 10 + self.assertIsInstance(query, QueryString) + self.assertEqual( + query.compile_string(engine_type="mysql"), + ("SELECT price = %s", [10]), + ) + + def test_not_equals(self): + query = QueryString("SELECT price") != 10 + self.assertIsInstance(query, QueryString) + self.assertEqual( + query.compile_string(engine_type="mysql"), + ("SELECT price != %s", [10]), + ) + + def test_is_in(self): + query = QueryString("SELECT price").is_in([10, 20, 30]) + self.assertIsInstance(query, QueryString) + self.assertEqual( + query.compile_string(engine_type="mysql"), + ("SELECT price IN %s", [[10, 20, 30]]), + ) + + def test_not_in(self): + query = QueryString("SELECT price").not_in([10, 20, 30]) + self.assertIsInstance(query, QueryString) + self.assertEqual( + query.compile_string(engine_type="mysql"), + ("SELECT price NOT IN %s", [[10, 20, 30]]), + ) diff --git a/tests/table/instance/test_create.py b/tests/table/instance/test_create.py index 6e4856cc2..d86b9b5be 100644 --- a/tests/table/instance/test_create.py +++ 
b/tests/table/instance/test_create.py @@ -2,6 +2,7 @@ from piccolo.columns import Integer, Varchar from piccolo.table import Table +from tests.base import engines_skip class Band(Table): @@ -9,6 +10,7 @@ class Band(Table): popularity = Integer() +@engines_skip("mysql") class TestCreate(TestCase): def setUp(self): Band.create_table().run_sync() diff --git a/tests/table/instance/test_equality.py b/tests/table/instance/test_equality.py index 40ae59517..59ef3e179 100644 --- a/tests/table/instance/test_equality.py +++ b/tests/table/instance/test_equality.py @@ -1,6 +1,7 @@ from piccolo.columns.column_types import UUID, Varchar from piccolo.table import Table from piccolo.testing.test_case import AsyncTableTest +from tests.base import engines_skip from tests.example_apps.music.tables import Manager @@ -46,6 +47,7 @@ async def test_instance_equality(self) -> None: manager_unsaved = Manager() self.assertEqual(manager_unsaved, manager_unsaved) + @engines_skip("mysql") async def test_instance_equality_uuid(self) -> None: """ Make sure instance equality works, for tables with a `UUID` primary diff --git a/tests/table/instance/test_instantiate.py b/tests/table/instance/test_instantiate.py index 6fceaa2be..614735613 100644 --- a/tests/table/instance/test_instantiate.py +++ b/tests/table/instance/test_instantiate.py @@ -1,4 +1,4 @@ -from tests.base import DBTestCase, engines_only, sqlite_only +from tests.base import DBTestCase, engines_only from tests.example_apps.music.tables import Band @@ -21,7 +21,7 @@ def test_insert_postgres_alt(self): Pythonistas.__str__(), "(unique_rowid(),'Pythonistas',null,0)" ) - @sqlite_only + @engines_only("sqlite", "mysql") def test_insert_sqlite(self): Pythonistas = Band(name="Pythonistas") self.assertEqual(Pythonistas.__str__(), "(null,'Pythonistas',null,0)") diff --git a/tests/table/test_alter.py b/tests/table/test_alter.py index 32057b9f0..9b6b1eb4f 100644 --- a/tests/table/test_alter.py +++ b/tests/table/test_alter.py @@ -14,9 +14,10 @@ 
DBTestCase, engine_version_lt, engines_only, + engines_skip, is_running_sqlite, ) -from tests.example_apps.music.tables import Band, Manager +from tests.example_apps.music.tables import Band, Manager, Poster @pytest.mark.skipif( @@ -56,6 +57,7 @@ def test_column(self): """ self._test_rename(Band.popularity) + @engines_skip("mysql") def test_string(self): """ Make sure a string argument works. @@ -82,7 +84,7 @@ def tearDown(self): self.run_sync("DROP TABLE IF EXISTS act") -@engines_only("postgres", "cockroach") +@engines_only("postgres", "cockroach", "mysql") class TestDropColumn(DBTestCase): """ Unfortunately this only works with Postgres at the moment. @@ -128,6 +130,7 @@ def test_integer(self): expected_value=None, ) + @engines_skip("mysql") def test_foreign_key(self): self._test_add_column( column=ForeignKey(references=Manager), @@ -136,11 +139,10 @@ def test_foreign_key(self): ) def test_text(self): - bio = "An amazing band" self._test_add_column( - column=Text(default=bio), + column=Text(default="An amazing band"), column_name="bio", - expected_value=bio, + expected_value="An amazing band", ) def test_problematic_name(self): @@ -185,7 +187,7 @@ def test_unique(self): self.assertTrue(len(response), 2) -@engines_only("postgres", "cockroach") +@engines_only("postgres", "cockroach", "mysql") class TestMultiple(DBTestCase): """ Make sure multiple alter statements work correctly. @@ -233,6 +235,29 @@ def test_integer_to_bigint(self): assert row is not None self.assertEqual(row["popularity"], 1000) + @engines_only("mysql") + def test_integer_to_bigint_mysql(self): + """ + Test converting an Integer column to BigInt. 
+        """
+        self.insert_row()
+
+        alter_query = Band.alter().set_column_type(
+            old_column=Band.popularity, new_column=BigInt()
+        )
+        alter_query.run_sync()
+
+        self.assertEqual(
+            self.get_mysql_column_type(
+                tablename="band", column_name="popularity"
+            ),
+            "BIGINT",
+        )
+
+        row = Band.select(Band.popularity).first().run_sync()
+        assert row is not None
+        self.assertEqual(row["popularity"], 1000)
+
     def test_integer_to_varchar(self):
         """
         Test converting an Integer column to Varchar.
@@ -255,6 +280,30 @@ def test_integer_to_varchar(self):
         assert row is not None
         self.assertEqual(row["popularity"], "1000")

+    @engines_only("mysql")
+    def test_integer_to_varchar_mysql(self):
+        """
+        Test converting an Integer column to Varchar.
+        """
+        self.insert_row()
+
+        alter_query = Band.alter().set_column_type(
+            old_column=Band.popularity, new_column=Varchar()
+        )
+        alter_query.run_sync()
+
+        self.assertEqual(
+            self.get_mysql_column_type(
+                tablename="band", column_name="popularity"
+            ),
+            "CHARACTER VARYING",
+        )
+
+        row = Band.select(Band.popularity).first().run_sync()
+        assert row is not None
+        self.assertEqual(row["popularity"], "1000")
+
+    @engines_skip("mysql")
     def test_using_expression(self):
         """
         Test the `using_expression` option, which can be used to tell Postgres
@@ -276,6 +325,7 @@ def test_using_expression(self):

 @engines_only("postgres", "cockroach")
 class TestSetNull(DBTestCase):
+    @engines_skip("mysql")
     def test_set_null(self):
         query = """
             SELECT is_nullable FROM information_schema.columns
@@ -292,9 +342,27 @@ def test_set_null(self):
         response = Band.raw(query).run_sync()
         self.assertEqual(response[0]["is_nullable"], "NO")

+    @engines_only("mysql")
+    def test_set_null_mysql(self):
+        query = """
+            SELECT is_nullable FROM information_schema.columns
+            WHERE table_name = 'band'
+            AND table_schema = 'piccolo'
+            AND column_name = 'popularity'
+        """
-@engines_only("postgres", "cockroach")
+        Band.alter().set_null(Band.popularity, boolean=True).run_sync()
+        response = 
Band.raw(query).run_sync() + self.assertEqual(response[0]["is_nullable"], "YES") + + Band.alter().set_null(Band.popularity, boolean=False).run_sync() + response = Band.raw(query).run_sync() + self.assertEqual(response[0]["is_nullable"], "NO") + + +@engines_only("postgres", "cockroach", "mysql") class TestSetLength(DBTestCase): + @engines_skip("mysql") def test_set_length(self): query = """ SELECT character_maximum_length FROM information_schema.columns @@ -308,8 +376,22 @@ def test_set_length(self): response = Band.raw(query).run_sync() self.assertEqual(response[0]["character_maximum_length"], length) + @engines_only("mysql") + def test_set_length_mysql(self): + query = """ + SELECT character_maximum_length FROM information_schema.columns + WHERE table_name = 'band' + AND table_schema = 'piccolo' + AND column_name = 'name' + """ -@engines_only("postgres", "cockroach") + for length in (5, 20, 50): + Band.alter().set_length(Band.name, length=length).run_sync() + response = Band.raw(query).run_sync() + self.assertEqual(response[0]["CHARACTER_MAXIMUM_LENGTH"], length) + + +@engines_only("postgres", "cockroach", "mysql") class TestSetDefault(DBTestCase): def test_set_default(self): Manager.alter().set_default(Manager.name, "Pending").run_sync() @@ -324,6 +406,13 @@ def test_set_default(self): self.assertEqual(manager.name, "Pending") +@engines_only("mysql") +class TestSetDefaultMysql(DBTestCase): + def test_set_default_text_or_json(self): + with self.assertRaises(ValueError): + Poster.alter().set_default(Poster.content, "Content").run_sync() + + @engines_only("postgres", "cockroach") class TestSetSchema(TestCase): schema_manager = SchemaManager() @@ -417,3 +506,28 @@ def test_set_digits(self): response = Ticket.raw(query).run_sync() self.assertIsNone(response[0]["numeric_precision"]) self.assertIsNone(response[0]["numeric_scale"]) + + @engines_only("mysql") + def test_set_digits_mysql(self): + query = """ + SELECT numeric_precision, numeric_scale + FROM 
information_schema.columns + WHERE table_name = 'ticket' + AND table_schema = 'piccolo' + AND column_name = 'price' + """ + + Ticket.alter().set_digits( + column=Ticket.price, digits=(6, 2) + ).run_sync() + response = Ticket.raw(query).run_sync() + self.assertEqual(response[0]["numeric_precision".upper()], 6) + self.assertEqual(response[0]["numeric_scale".upper()], 2) + + Ticket.alter().set_digits(column=Ticket.price, digits=None).run_sync() + response = Ticket.raw(query).run_sync() + # In MySQL, when you create or alter a DECIMAL / NUMERIC column + # without specifying precision and scale, MySQL automatically + # assigns a default which is DECIMAL(10,0) + self.assertEqual(response[0]["numeric_precision".upper()], 10) + self.assertEqual(response[0]["numeric_scale".upper()], 0) diff --git a/tests/table/test_create.py b/tests/table/test_create.py index 7dd936e59..2703122e5 100644 --- a/tests/table/test_create.py +++ b/tests/table/test_create.py @@ -3,7 +3,7 @@ from piccolo.columns import Varchar from piccolo.schema import SchemaManager from piccolo.table import Table -from tests.base import engines_only +from tests.base import engines_only, engines_skip from tests.example_apps.music.tables import Manager @@ -31,6 +31,7 @@ def test_create_table_with_indexes(self): index_name = BandMember._get_index_name(["name"]) self.assertIn(index_name, index_names) + @engines_skip("mysql") def test_create_if_not_exists_with_indexes(self): """ Make sure that if the same table is created again, with the diff --git a/tests/table/test_delete.py b/tests/table/test_delete.py index 218acd458..5f9baed94 100644 --- a/tests/table/test_delete.py +++ b/tests/table/test_delete.py @@ -1,7 +1,12 @@ import pytest from piccolo.query.methods.delete import DeletionError -from tests.base import DBTestCase, engine_version_lt, is_running_sqlite +from tests.base import ( + DBTestCase, + engine_version_lt, + is_running_mysql, + is_running_sqlite, +) from tests.example_apps.music.tables import Band @@ 
-16,6 +21,7 @@ def test_delete(self): self.assertEqual(response, 0) @pytest.mark.skipif( + is_running_mysql(), is_running_sqlite() and engine_version_lt(3.35), reason="SQLite version not supported", ) diff --git a/tests/table/test_insert.py b/tests/table/test_insert.py index 19c1b0acb..1e46d9d2b 100644 --- a/tests/table/test_insert.py +++ b/tests/table/test_insert.py @@ -11,6 +11,7 @@ DBTestCase, engine_version_lt, engines_only, + engines_skip, is_running_sqlite, ) from tests.example_apps.music.tables import Band, Manager @@ -63,6 +64,7 @@ def test_insert_curly_braces(self): is_running_sqlite() and engine_version_lt(3.35), reason="SQLite version not supported", ) + @engines_skip("mysql") def test_insert_returning(self): """ Make sure update works with the `returning` clause. @@ -79,6 +81,7 @@ def test_insert_returning(self): is_running_sqlite() and engine_version_lt(3.35), reason="SQLite version not supported", ) + @engines_skip("mysql") def test_insert_returning_alias(self): """ Make sure update works with the `returning` clause. @@ -112,6 +115,7 @@ def tearDown(self) -> None: Band = self.Band Band.alter().drop_table().run_sync() + @engines_skip("mysql") def test_do_update(self): """ Make sure that `DO UPDATE` works. @@ -139,6 +143,34 @@ def test_do_update(self): ], ) + @engines_only("mysql") + def test_do_update_mysql(self): + """ + Make sure that `DO UPDATE` works in MySQL. + """ + Band = self.Band + + new_popularity = self.band.popularity + 1000 + + Band.insert( + Band(name=self.band.name, popularity=new_popularity) + ).on_conflict( + action="DO UPDATE", + values=[Band.popularity], + ).run_sync() + + self.assertListEqual( + Band.select().run_sync(), + [ + { + "id": self.band.id, + "name": self.band.name, + "popularity": new_popularity, # changed + } + ], + ) + + @engines_skip("mysql") def test_do_update_tuple_values(self): """ Make sure we can use tuples in ``values``. 
@@ -174,6 +206,42 @@ def test_do_update_tuple_values(self): ], ) + @engines_only("mysql") + def test_do_update_tuple_values_mysql(self): + """ + Make sure we can use tuples in ``values``. + """ + Band = self.Band + + new_popularity = self.band.popularity + 1000 + new_name = "Rustaceans" + + Band.insert( + Band( + id=self.band.id, + name=new_name, + popularity=new_popularity, + ) + ).on_conflict( + action="DO UPDATE", + values=[ + (Band.name, new_name), + (Band.popularity, new_popularity + 2000), + ], + ).run_sync() + + self.assertListEqual( + Band.select().run_sync(), + [ + { + "id": self.band.id, + "name": new_name, + "popularity": new_popularity + 2000, + } + ], + ) + + @engines_skip("mysql") def test_do_update_no_target(self): """ Make sure that `DO UPDATE` with no `target` raises an exception. @@ -273,6 +341,7 @@ def test_target_string(self): self.assertIn(f'ON CONSTRAINT "{constraint_name}"', query.__str__()) query.run_sync() + @engines_skip("mysql") # MySQL does not support target in conflicts def test_violate_non_target(self): """ Make sure that if we specify a target constraint, but violate a @@ -298,6 +367,7 @@ def test_violate_non_target(self): elif self.Band._meta.db.engine_type == "sqlite": self.assertIsInstance(manager.exception, sqlite3.IntegrityError) + @engines_skip("mysql") # MySQL does not support where in conflicts def test_where(self): """ Make sure we can pass in a `where` argument. @@ -422,7 +492,7 @@ def test_multiple_do_nothing(self): ], ) - @engines_only("postgres", "cockroach") + @engines_only("postgres", "cockroach", "mysql") def test_mutiple_error(self): """ Postgres and Cockroach don't support multiple `ON CONFLICT` clauses. @@ -435,9 +505,11 @@ def test_mutiple_error(self): ).run_sync() assert manager.exception.__str__() == ( - "Postgres and Cockroach only support a single ON CONFLICT clause." + "Postgres, Cockroach and MySQL only support a single " + "ON CONFLICT clause." 
         )
 
+    @engines_skip("mysql")
     def test_all_columns(self):
         """
         We can use ``all_columns`` instead of specifying the ``values``
@@ -473,6 +545,41 @@ def test_all_columns(self):
             ],
         )
 
+    @engines_only("mysql")
+    def test_all_columns_mysql(self):
+        """
+        We can use ``all_columns`` instead of specifying the ``values``
+        manually.
+        """
+        Band = self.Band
+
+        new_popularity = self.band.popularity + 1000
+        new_name = "Rustaceans"
+
+        # Conflicting with ID - DO UPDATE overwrites the existing row.
+        q = Band.insert(
+            Band(
+                id=self.band.id,
+                name=new_name,
+                popularity=new_popularity,
+            )
+        ).on_conflict(
+            action="DO UPDATE",
+            values=Band.all_columns(),
+        )
+        q.run_sync()
+
+        self.assertListEqual(
+            Band.select().run_sync(),
+            [
+                {
+                    "id": self.band.id,
+                    "name": new_name,
+                    "popularity": new_popularity,
+                }
+            ],
+        )
+
     def test_enum(self):
         """
         A string literal can be passed in, or an enum, to determine the
         action.
diff --git a/tests/table/test_refresh.py b/tests/table/test_refresh.py
index ce002bb9a..f2e803419 100644
--- a/tests/table/test_refresh.py
+++ b/tests/table/test_refresh.py
@@ -1,7 +1,8 @@
 from typing import cast
 
+from piccolo.query.functions.string import Concat
 from piccolo.testing.test_case import TableTest
-from tests.base import DBTestCase
+from tests.base import DBTestCase, engines_skip
 from tests.example_apps.music.tables import (
     Band,
     Concert,
@@ -27,7 +28,7 @@ def test_refresh(self) -> None:
 
         # Modify the data in the database.
         Band.update(
-            {Band.name: Band.name + "!!!", Band.popularity: 8000}
+            {Band.name: Concat(Band.name, "!!!"), Band.popularity: 8000}
         ).where(Band.name == "Pythonistas").run_sync()
 
         # Refresh `band`, and make sure it has the correct data.
@@ -94,7 +95,7 @@ def test_columns(self) -> None:
 
         # Modify the data in the database.
         Band.update(
-            {Band.name: Band.name + "!!!", Band.popularity: 8000}
+            {Band.name: Concat(Band.name, "!!!"), Band.popularity: 8000}
         ).where(Band.name == "Pythonistas").run_sync()
 
         # Refresh `band`, and make sure it has the correct data.
@@ -142,6 +143,7 @@ def test_error_when_pk_in_none(self) -> None: ) +@engines_skip("mysql") class TestRefreshWithPrefetch(TableTest): tables = [Manager, Band, Concert, Venue] @@ -257,6 +259,7 @@ def test_exception(self) -> None: self.concert.refresh(columns=[Concert.band_1]).run_sync() +@engines_skip("mysql") class TestRefreshWithLoadJSON(TableTest): tables = [RecordingStudio] diff --git a/tests/table/test_select.py b/tests/table/test_select.py index 1feac507f..20273afff 100644 --- a/tests/table/test_select.py +++ b/tests/table/test_select.py @@ -18,6 +18,7 @@ engine_version_lt, engines_only, is_running_cockroach, + is_running_mysql, is_running_sqlite, sqlite_only, ) @@ -749,7 +750,7 @@ def test_avg(self): response = Band.select(Avg(Band.popularity)).first().run_sync() assert response is not None - self.assertEqual(float(response["avg"]), 1003.3333333333334) + self.assertEqual(float(round(response["avg"], 4)), 1003.3333) def test_avg_alias(self): self.insert_rows() @@ -761,7 +762,9 @@ def test_avg_alias(self): ) assert response is not None - self.assertEqual(float(response["popularity_avg"]), 1003.3333333333334) + self.assertEqual( + float(round(response["popularity_avg"], 4)), 1003.3333 + ) def test_avg_as_alias_method(self): self.insert_rows() @@ -773,7 +776,9 @@ def test_avg_as_alias_method(self): ) assert response is not None - self.assertEqual(float(response["popularity_avg"]), 1003.3333333333334) + self.assertEqual( + float(round(response["popularity_avg"], 4)), 1003.3333 + ) def test_avg_with_where_clause(self): self.insert_rows() @@ -975,7 +980,7 @@ def test_chain_different_functions(self): ) assert response is not None - self.assertEqual(float(response["avg"]), 1003.3333333333334) + self.assertEqual(float(round(response["avg"], 4)), 1003.3333) self.assertEqual(response["sum"], 3010) def test_chain_different_functions_alias(self): @@ -991,7 +996,9 @@ def test_chain_different_functions_alias(self): ) assert response is not None - 
self.assertEqual(float(response["popularity_avg"]), 1003.3333333333334) + self.assertEqual( + float(round(response["popularity_avg"], 4)), 1003.3333 + ) self.assertEqual(response["popularity_sum"], 3010) def test_columns(self): @@ -1082,6 +1089,13 @@ def test_as_alias_with_where_clause(self): "Cockroach raises an error when trying to use the log function." ), ) + @pytest.mark.skipif( + is_running_mysql(), + reason=( + "MySQL uses a different logarithmic function. " + "We should use log10() to get the same result." + ), + ) def test_select_raw(self): """ Make sure ``SelectRaw`` can be used in select queries. @@ -1094,6 +1108,20 @@ def test_select_raw(self): response, [{"name": "Pythonistas", "popularity_log": 3.0}] ) + @engines_only("mysql") + def test_select_raw_mysql(self): + """ + Make sure ``SelectRaw`` can be used in select queries. + We get the same results as Postgres. + """ + self.insert_row() + response = Band.select( + Band.name, SelectRaw("round(log10(popularity)) AS popularity_log") + ).run_sync() + self.assertListEqual( + response, [{"name": "Pythonistas", "popularity_log": 3.0}] + ) + @pytest.mark.skipif( is_running_sqlite(), reason="SQLite doesn't support SELECT ... 
FOR UPDATE.", diff --git a/tests/table/test_table_exists.py b/tests/table/test_table_exists.py index 6b31afa00..cc54033ab 100644 --- a/tests/table/test_table_exists.py +++ b/tests/table/test_table_exists.py @@ -23,7 +23,7 @@ class Band(Table, schema="schema_1"): name = Varchar() -@engines_skip("sqlite") +@engines_skip("sqlite", "mysql") class TestTableExistsSchema(TestCase): def setUp(self): Band.create_table(auto_create_schema=True).run_sync() diff --git a/tests/table/test_update.py b/tests/table/test_update.py index 35ff6727e..81d03ce4c 100644 --- a/tests/table/test_update.py +++ b/tests/table/test_update.py @@ -15,12 +15,14 @@ Timestamptz, Varchar, ) +from piccolo.query.functions.string import Concat from piccolo.querystring import QueryString from piccolo.table import Table from piccolo.testing.test_case import AsyncTableTest from tests.base import ( DBTestCase, engine_version_lt, + engines_only, engines_skip, is_running_sqlite, sqlite_only, @@ -118,6 +120,7 @@ def test_update_values_with_kwargs(self): is_running_sqlite() and engine_version_lt(3.35), reason="SQLite version not supported", ) + @engines_skip("mysql") def test_update_returning(self): """ Make sure update works with the `returning` clause. @@ -137,6 +140,7 @@ def test_update_returning(self): is_running_sqlite() and engine_version_lt(3.35), reason="SQLite version not supported", ) + @engines_skip("mysql") def test_update_returning_alias(self): """ Make sure update works with the `returning` clause. 
@@ -524,6 +528,309 @@ class OperatorTestCase: ), ] +############################################################################### +# Test operators - MySQL + + +class MyTableMySQL(Table): + integer_col = Integer(null=True) + other_integer_col = Integer(null=True, default=5) + timestamp_col = Timestamp(null=True) + date_col = Date(null=True) + interval_col = Interval(null=True) + varchar_col = Varchar(null=True) + text_col = Text(null=True) + + +@dataclasses.dataclass +class OperatorTestCaseMySQL: + description: str + column: Column + initial: Any + querystring: QueryString + expected: Any + + +TEST_CASES_MYSQL = [ + # Text + OperatorTestCase( + description="Add Text", + column=MyTableMySQL.text_col, + initial="Pythonistas", + querystring=Concat(MyTableMySQL.text_col, "!!!"), + expected="Pythonistas!!!", + ), + OperatorTestCase( + description="Add Text columns", + column=MyTableMySQL.text_col, + initial="Pythonistas", + querystring=Concat(MyTableMySQL.text_col, MyTableMySQL.text_col), + expected="PythonistasPythonistas", + ), + OperatorTestCase( + description="Reverse add Text", + column=MyTableMySQL.text_col, + initial="Pythonistas", + querystring=Concat("!!!", MyTableMySQL.text_col), + expected="!!!Pythonistas", + ), + OperatorTestCase( + description="Text is null", + column=MyTableMySQL.text_col, + initial=None, + querystring=Concat(MyTableMySQL.text_col, "!!!"), + expected=None, + ), + OperatorTestCase( + description="Reverse Text is null", + column=MyTableMySQL.text_col, + initial=None, + querystring=Concat("!!!", MyTableMySQL.text_col), + expected=None, + ), + # Varchar + OperatorTestCase( + description="Add Varchar", + column=MyTableMySQL.varchar_col, + initial="Pythonistas", + querystring=Concat(MyTableMySQL.varchar_col, "!!!"), + expected="Pythonistas!!!", + ), + OperatorTestCase( + description="Add Varchar columns", + column=MyTableMySQL.varchar_col, + initial="Pythonistas", + querystring=Concat(MyTableMySQL.varchar_col, MyTableMySQL.varchar_col), + 
expected="PythonistasPythonistas", + ), + OperatorTestCase( + description="Reverse add Varchar", + column=MyTableMySQL.varchar_col, + initial="Pythonistas", + querystring=Concat("!!!", MyTableMySQL.varchar_col), + expected="!!!Pythonistas", + ), + OperatorTestCase( + description="Varchar is null", + column=MyTableMySQL.varchar_col, + initial=None, + querystring=Concat(MyTableMySQL.varchar_col, "!!!"), + expected=None, + ), + OperatorTestCase( + description="Reverse Varchar is null", + column=MyTableMySQL.varchar_col, + initial=None, + querystring=Concat("!!!", MyTableMySQL.varchar_col), + expected=None, + ), + # Integer + OperatorTestCase( + description="Add Integer", + column=MyTableMySQL.integer_col, + initial=1000, + querystring=MyTableMySQL.integer_col + 10, + expected=1010, + ), + OperatorTestCase( + description="Reverse add Integer", + column=MyTableMySQL.integer_col, + initial=1000, + querystring=10 + MyTableMySQL.integer_col, + expected=1010, + ), + OperatorTestCase( + description="Add Integer colums together", + column=MyTableMySQL.integer_col, + initial=1000, + querystring=MyTableMySQL.integer_col + MyTableMySQL.integer_col, + expected=2000, + ), + OperatorTestCase( + description="Subtract Integer", + column=MyTableMySQL.integer_col, + initial=1000, + querystring=MyTableMySQL.integer_col - 10, + expected=990, + ), + OperatorTestCase( + description="Reverse subtract Integer", + column=MyTableMySQL.integer_col, + initial=1000, + querystring=2000 - MyTableMySQL.integer_col, + expected=1000, + ), + OperatorTestCase( + description="Subtract Integer Columns", + column=MyTableMySQL.integer_col, + initial=1000, + querystring=MyTableMySQL.integer_col - MyTableMySQL.other_integer_col, + expected=995, + ), + OperatorTestCase( + description="Add Integer Columns", + column=MyTableMySQL.integer_col, + initial=1000, + querystring=MyTableMySQL.integer_col + MyTableMySQL.other_integer_col, + expected=1005, + ), + OperatorTestCase( + description="Multiply Integer", + 
column=MyTableMySQL.integer_col, + initial=1000, + querystring=MyTableMySQL.integer_col * 2, + expected=2000, + ), + OperatorTestCase( + description="Reverse multiply Integer", + column=MyTableMySQL.integer_col, + initial=1000, + querystring=2 * MyTableMySQL.integer_col, + expected=2000, + ), + OperatorTestCase( + description="Divide Integer", + column=MyTableMySQL.integer_col, + initial=1000, + querystring=MyTableMySQL.integer_col / 10, + expected=100, + ), + OperatorTestCase( + description="Reverse divide Integer", + column=MyTableMySQL.integer_col, + initial=1000, + querystring=2000 / MyTableMySQL.integer_col, + expected=2, + ), + OperatorTestCase( + description="Integer is null", + column=MyTableMySQL.integer_col, + initial=None, + querystring=MyTableMySQL.integer_col + 1, + expected=None, + ), + OperatorTestCase( + description="Reverse Integer is null", + column=MyTableMySQL.integer_col, + initial=None, + querystring=1 + MyTableMySQL.integer_col, + expected=None, + ), + # Timestamp + OperatorTestCase( + description="Add Timestamp", + column=MyTableMySQL.timestamp_col, + initial=INITIAL_DATETIME, + querystring=MyTableMySQL.timestamp_col + DATETIME_DELTA, + expected=datetime.datetime( + year=2022, + month=1, + day=2, + hour=22, + minute=1, + second=30, + microsecond=1000, + ), + ), + OperatorTestCase( + description="Reverse add Timestamp", + column=MyTableMySQL.timestamp_col, + initial=INITIAL_DATETIME, + querystring=DATETIME_DELTA + MyTableMySQL.timestamp_col, + expected=datetime.datetime( + year=2022, + month=1, + day=2, + hour=22, + minute=1, + second=30, + microsecond=1000, + ), + ), + OperatorTestCase( + description="Subtract Timestamp", + column=MyTableMySQL.timestamp_col, + initial=INITIAL_DATETIME, + querystring=MyTableMySQL.timestamp_col - DATETIME_DELTA, + expected=datetime.datetime( + year=2021, + month=12, + day=31, + hour=19, + minute=58, + second=29, + microsecond=999000, + ), + ), + OperatorTestCase( + description="Timestamp is null", + 
column=MyTableMySQL.timestamp_col, + initial=None, + querystring=MyTableMySQL.timestamp_col + DATETIME_DELTA, + expected=None, + ), + # Date + OperatorTestCase( + description="Add Date", + column=MyTableMySQL.date_col, + initial=INITIAL_DATETIME, + querystring=MyTableMySQL.date_col + DATE_DELTA, + expected=datetime.date(year=2022, month=1, day=2), + ), + OperatorTestCase( + description="Reverse add Date", + column=MyTableMySQL.date_col, + initial=INITIAL_DATETIME, + querystring=DATE_DELTA + MyTableMySQL.date_col, + expected=datetime.date(year=2022, month=1, day=2), + ), + OperatorTestCase( + description="Subtract Date", + column=MyTableMySQL.date_col, + initial=INITIAL_DATETIME, + querystring=MyTableMySQL.date_col - DATE_DELTA, + expected=datetime.date(year=2021, month=12, day=31), + ), + OperatorTestCase( + description="Date is null", + column=MyTableMySQL.date_col, + initial=None, + querystring=MyTableMySQL.date_col + DATE_DELTA, + expected=None, + ), + # Interval + OperatorTestCase( + description="Add Interval", + column=MyTableMySQL.interval_col, + initial=INITIAL_INTERVAL, + querystring=MyTableMySQL.interval_col + DATETIME_DELTA, + expected=datetime.timedelta(days=2, seconds=7350, microseconds=1000), + ), + OperatorTestCase( + description="Reverse add Interval", + column=MyTableMySQL.interval_col, + initial=INITIAL_INTERVAL, + querystring=DATETIME_DELTA + MyTableMySQL.interval_col, + expected=datetime.timedelta(days=2, seconds=7350, microseconds=1000), + ), + OperatorTestCase( + description="Subtract Interval", + column=MyTableMySQL.interval_col, + initial=INITIAL_INTERVAL, + querystring=MyTableMySQL.interval_col - DATETIME_DELTA, + expected=datetime.timedelta( + days=-1, seconds=86369, microseconds=999000 + ), + ), + OperatorTestCase( + description="Interval is null", + column=MyTableMySQL.interval_col, + initial=None, + querystring=MyTableMySQL.interval_col + DATETIME_DELTA, + expected=None, + ), +] + class TestOperators(TestCase): def setUp(self): @@ -532,7 
+839,7 @@ def setUp(self):
 
     def tearDown(self):
         MyTable.alter().drop_table().run_sync()
 
-    @engines_skip("cockroach")
+    @engines_skip("cockroach", "mysql")
     def test_operators(self):
         for test_case in TEST_CASES:
             print(test_case.description)
@@ -560,6 +867,42 @@ def test_operators(self):
             # Clean up
             MyTable.delete(force=True).run_sync()
 
+
+    @engines_only("mysql")
+    def test_operators_mysql(self):
+        """MySQL variant, driven by ``TEST_CASES_MYSQL``."""
+        MyTableMySQL.create_table().run_sync()
+        try:
+            for test_case in TEST_CASES_MYSQL:
+                print(test_case.description)
+
+                # Create the initial data in the database.
+                instance = MyTableMySQL()
+                setattr(
+                    instance, test_case.column._meta.name, test_case.initial
+                )
+                instance.save().run_sync()
+
+                # Apply the update.
+                MyTableMySQL.update(
+                    {test_case.column: test_case.querystring}, force=True
+                ).run_sync()
+
+                # Make sure the value returned from the database is correct.
+                new_value = getattr(
+                    MyTableMySQL.objects().first().run_sync(),
+                    test_case.column._meta.name,
+                )
+
+                self.assertEqual(
+                    new_value, test_case.expected, msg=test_case.description
+                )
+
+                # Clean up
+                MyTableMySQL.delete(force=True).run_sync()
+        finally:
+            MyTableMySQL.alter().drop_table().run_sync()
+
     @sqlite_only
     def test_edge_cases(self):
         """
diff --git a/tests/table/test_update_self.py b/tests/table/test_update_self.py
index c06afe708..259fd69f5 100644
--- a/tests/table/test_update_self.py
+++ b/tests/table/test_update_self.py
@@ -1,7 +1,9 @@
 from piccolo.testing.test_case import AsyncTableTest
+from tests.base import engines_skip
 from tests.example_apps.music.tables import Band, Manager
 
 
+@engines_skip("mysql")
 class TestUpdateSelf(AsyncTableTest):
     tables = [Band, Manager]
 
diff --git a/tests/test_schema.py b/tests/test_schema.py
index d8ec3d481..15cfbbbe1 100644
--- a/tests/test_schema.py
+++ b/tests/test_schema.py
@@ -9,7 +9,7 @@ class Band(Table, schema="schema_1"):
     pass
 
 
-@engines_skip("sqlite")
+@engines_skip("sqlite", "mysql")
 class
TestListTables(TestCase): def setUp(self): Band.create_table().run_sync() @@ -30,7 +30,7 @@ def test_list_tables(self): self.assertListEqual(table_list, [Band._meta.tablename]) -@engines_skip("sqlite") +@engines_skip("sqlite", "mysql") class TestCreateAndDrop(TestCase): def test_create_and_drop(self): """ @@ -48,7 +48,7 @@ def test_create_and_drop(self): self.assertNotIn(schema_name, manager.list_schemas().run_sync()) -@engines_skip("sqlite") +@engines_skip("sqlite", "mysql") class TestMoveTable(TestCase): new_schema = "schema_2" @@ -87,7 +87,7 @@ def test_move_table(self): ) -@engines_skip("sqlite") +@engines_skip("sqlite", "mysql") class TestRenameSchema(TestCase): manager = SchemaManager() schema_name = "test_schema" @@ -116,7 +116,7 @@ def test_rename_schema(self): ) -@engines_skip("sqlite") +@engines_skip("sqlite", "mysql") class TestDDL(TestCase): manager = SchemaManager() diff --git a/tests/testing/test_test_case.py b/tests/testing/test_test_case.py index 963a3c371..72fe090a8 100644 --- a/tests/testing/test_test_case.py +++ b/tests/testing/test_test_case.py @@ -8,6 +8,7 @@ AsyncTransactionTest, TableTest, ) +from tests.base import engines_skip from tests.example_apps.music.tables import Band, Manager @@ -48,6 +49,7 @@ async def test_transaction_exists(self): @pytest.mark.skipif(sys.version_info <= (3, 11), reason="Python 3.11 required") +@engines_skip("mysql") class TestAsyncTransactionRolledBack(AsyncTransactionTest): """ Make sure that the changes get rolled back automatically. 
diff --git a/tests/utils/test_lazy_loader.py b/tests/utils/test_lazy_loader.py index 32a6be15b..f638e1fab 100644 --- a/tests/utils/test_lazy_loader.py +++ b/tests/utils/test_lazy_loader.py @@ -1,7 +1,7 @@ from unittest import TestCase, mock from piccolo.utils.lazy_loader import LazyLoader -from tests.base import engines_only, sqlite_only +from tests.base import engines_only, mysql_only, sqlite_only class TestLazyLoader(TestCase): @@ -25,3 +25,12 @@ def test_lazy_loader_aiosqlite_exception(self): module.side_effect = ModuleNotFoundError() with self.assertRaises(ModuleNotFoundError): lazy_loader._load() + + @mysql_only + def test_lazy_loader_aiomysql_exception(self): + lazy_loader = LazyLoader("aiomysql", globals(), "aiomysql.connect") + + with mock.patch("aiomysql.connect") as module: + module.side_effect = ModuleNotFoundError() + with self.assertRaises(ModuleNotFoundError): + lazy_loader._load()