From c7ab03fcb9e8a9d2b6fffa74c008ffeb02b1aa40 Mon Sep 17 00:00:00 2001 From: aaron-congo Date: Fri, 6 Feb 2026 14:41:36 -0800 Subject: [PATCH 01/15] Simple PG workflow working --- aws_advanced_python_wrapper/__init__.py | 24 ++- .../sqlalchemy/orm_dialect.py | 139 ++++++++++++++++++ pyproject.toml | 2 + tests/unit/test_sqlalchemy_orm.py | 66 +++++++++ 4 files changed, 230 insertions(+), 1 deletion(-) create mode 100644 aws_advanced_python_wrapper/sqlalchemy/orm_dialect.py create mode 100644 tests/unit/test_sqlalchemy_orm.py diff --git a/aws_advanced_python_wrapper/__init__.py b/aws_advanced_python_wrapper/__init__.py index fbac66233..7d4dbf38a 100644 --- a/aws_advanced_python_wrapper/__init__.py +++ b/aws_advanced_python_wrapper/__init__.py @@ -15,8 +15,20 @@ from logging import DEBUG, getLogger from .cleanup import release_resources +from .driver_info import DriverInfo from .utils.utils import LogUtils from .wrapper import AwsWrapperConnection +from aws_advanced_python_wrapper.pep249 import ( + Error, + InterfaceError, + DatabaseError, + DataError, + OperationalError, + IntegrityError, + InternalError, + ProgrammingError, + NotSupportedError +) # PEP249 compliance connect = AwsWrapperConnection.connect @@ -32,9 +44,19 @@ 'set_logger', 'apilevel', 'threadsafety', - 'paramstyle' + 'paramstyle', + 'Error', + 'InterfaceError', + 'DatabaseError', + 'DataError', + 'OperationalError', + 'IntegrityError', + 'InternalError', + 'ProgrammingError', + 'NotSupportedError' ] +__version__ = DriverInfo.DRIVER_VERSION def set_logger(name='aws_advanced_python_wrapper', level=DEBUG, format_string=None): LogUtils.setup_logger(getLogger(name), level, format_string) diff --git a/aws_advanced_python_wrapper/sqlalchemy/orm_dialect.py b/aws_advanced_python_wrapper/sqlalchemy/orm_dialect.py new file mode 100644 index 000000000..498e7adce --- /dev/null +++ b/aws_advanced_python_wrapper/sqlalchemy/orm_dialect.py @@ -0,0 +1,139 @@ +# aws_advanced_python_wrapper/sqlalchemy/sqlalchemy_psycopg_dialect.py +from psycopg import Connection +from sqlalchemy.dialects.postgresql.psycopg import PGDialect_psycopg +import re + +class SqlAlchemyOrmPgDialect(PGDialect_psycopg): + """ + SQLAlchemy dialect for AWS Advanced Python Wrapper. + Extends PostgreSQL psycopg dialect with Aurora-aware connection handling. + """ + + name = 'postgresql' + driver = 'aws_wrapper' + + def __init__(self, **kwargs): + # Skip parent's version check since we're a wrapper, not psycopg itself + super(PGDialect_psycopg, self).__init__(**kwargs) + + # Dynamically detect the actual psycopg version we're wrapping to ensure + # SQLAlchemy uses the correct feature set and SQL generation + try: + import psycopg + m = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", psycopg.__version__) + if m: + self.psycopg_version = tuple( + int(x) for x in m.group(1, 2, 3) if x is not None + ) + else: + self.psycopg_version = (3, 0, 2) # Minimum supported + except (ImportError, AttributeError): + self.psycopg_version = (3, 0, 2) + + @classmethod + def import_dbapi(cls): + """ + Return the DB-API 2.0 module. + SQLAlchemy calls this to get the driver module. + """ + import aws_advanced_python_wrapper + return aws_advanced_python_wrapper + + def create_connect_args(self, url): + """ + Transform SQLAlchemy URL into connection arguments. + Must include 'target' parameter for the wrapper. + """ + # Extract standard connection parameters + opts = url.translate_connect_args(username='user') + + # Add query string parameters + opts.update(url.query) + + # Add the required 'target' parameter for your wrapper + if 'target' not in opts: + opts['target'] = Connection.connect + + # Return empty args list and kwargs dict + return ([], opts) + + def on_connect(self): + """ + Return a callable that will be executed on new connections. This can be used if we need to set any session-level + parameters. + """ + + def set_session_params(conn): + # Set any Aurora-specific session parameters + cursor = conn.cursor() + try: + # Example: Set statement timeout + cursor.execute("SET statement_timeout = '60s'") + finally: + cursor.close() + + return set_session_params + + def get_isolation_level(self, dbapi_connection): + """Get the current isolation level""" + cursor = dbapi_connection.cursor() + try: + cursor.execute("SHOW transaction_isolation") + val = cursor.fetchone() + if val: + # Extract first element from tuple and format + return val.upper().replace(' ', '_') + return 'READ_COMMITTED' # PostgreSQL's default + finally: + cursor.close() + + def initialize(self, connection): + """ + Override initialization to handle type introspection. + The parent class tries to use TypeInfo.fetch() which requires + a native psycopg connection, not our wrapper. + """ + # Find the AwsWrapperConnection at whatever nesting level + wrapper_conn = self._get_wrapper_connection(connection) + + if wrapper_conn and hasattr(wrapper_conn, 'connection'): + # Get the underlying psycopg connection + underlying_conn = wrapper_conn.connection + + # Temporarily swap the entire connection chain + original_dbapi_conn = connection.connection + connection.connection = underlying_conn + + try: + # Call parent initialization with native psycopg connection + super().initialize(connection) + finally: + # Restore original connection chain + connection.connection = original_dbapi_conn + else: + # If we can't find wrapper or it doesn't expose underlying connection, + # skip type introspection (custom types won't be auto-configured) + pass + + def _get_wrapper_connection(self, connection): + """ + Traverse the connection chain to find AwsWrapperConnection. + Handles variable nesting depths depending on pool configuration. + """ + from aws_advanced_python_wrapper import AwsWrapperConnection + + # Start with the DBAPI connection + current = connection.connection + + # Traverse up to 5 levels deep (reasonable limit) + for _ in range(5): + if isinstance(current, AwsWrapperConnection): + return current + + # Try to go deeper if there's a .connection attribute + if hasattr(current, 'connection'): + current = current.connection + else: + break + + return None diff --git a/pyproject.toml b/pyproject.toml index ffd73d2f4..80d787048 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -84,3 +84,5 @@ filterwarnings = [ 'ignore:Exception during reset or similar:pytest.PytestUnhandledThreadExceptionWarning' ] +[tool.poetry.plugins."sqlalchemy.dialects"] +"postgresql.aws_wrapper" = "aws_advanced_python_wrapper.sqlalchemy.orm_dialect:SqlAlchemyOrmPgDialect" diff --git a/tests/unit/test_sqlalchemy_orm.py b/tests/unit/test_sqlalchemy_orm.py new file mode 100644 index 000000000..7ec7bec4b --- /dev/null +++ b/tests/unit/test_sqlalchemy_orm.py @@ -0,0 +1,66 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from sqlalchemy import create_engine, Column, Integer, String +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import sessionmaker + +class TestSqlAlchemyORM: + def test_basic_workflow(self): + # Step 1: Create engine (connection to database) + engine = create_engine('postgresql+aws_wrapper://pguser:pgpassword@mydb.cluster-XYZ.us-west-1.rds.amazonaws.com:5432/somedb') + + # Step 2: Define base class for declarative models + Base = declarative_base() + + # Step 3: Define model class (separate from database operations) + class User(Base): + __tablename__ = 'users' + + id = Column(Integer, primary_key=True) + name = Column(String(50)) + email = Column(String(100)) + + # Step 4: Create tables + Base.metadata.create_all(engine) + + # Step 5: Create session factory + Session = sessionmaker(bind=engine) + + # Step 6: Use session for database operations + session = Session() + + # INSERT - Create new object and add to session + new_user = User(name='John Doe', email='john@example.com') + session.add(new_user) + session.commit() # Explicit commit required + + # SELECT - Query using session + users = session.query(User).filter(User.name == 'John Doe').all() + for user in users: + print(f"{user.name}: {user.email}") + + + # UPDATE - Modify object and commit + user = session.query(User).filter(User.name == "John Doe").first() + user.email = 'newemail@example.com' + session.commit() # Changes tracked by session + + # DELETE - Remove object from session + user_to_delete = session.query(User).filter(User.name == "John Doe").first() + session.delete(user_to_delete) + session.commit() + + # Always close session when done + session.close() From 6e16513087b60f270aba0d3dd9e3c0a8df5d5377 Mon Sep 17 00:00:00 2001 From: aaron-congo Date: Fri, 6 Feb 2026 15:23:33 -0800 Subject: [PATCH 02/15] Cleanup --- .../sqlalchemy/orm_dialect.py | 84 +++++++++++-------- tests/unit/test_sqlalchemy_orm.py | 46 +++++----- 2 files changed, 68 insertions(+), 62 deletions(-) diff --git a/aws_advanced_python_wrapper/sqlalchemy/orm_dialect.py b/aws_advanced_python_wrapper/sqlalchemy/orm_dialect.py index 498e7adce..c71e6f22a 100644 --- a/aws_advanced_python_wrapper/sqlalchemy/orm_dialect.py +++ b/aws_advanced_python_wrapper/sqlalchemy/orm_dialect.py @@ -3,21 +3,30 @@ from sqlalchemy.dialects.postgresql.psycopg import PGDialect_psycopg import re +from aws_advanced_python_wrapper import AwsWrapperConnection + + class SqlAlchemyOrmPgDialect(PGDialect_psycopg): """ - SQLAlchemy dialect for AWS Advanced Python Wrapper. - Extends PostgreSQL psycopg dialect with Aurora-aware connection handling. + SQLAlchemy dialect for AWS Advanced Python Wrapper with psycopg. Extends the SQLAlchemy PostgreSQL psycopg dialect. + This dialect is not related to the DriverDialect or DatabaseDialect classes used by our driver. Instead, it is used + directly by SQLAlchemy. This dialect is registered in pyproject.toml and is selected by prefixing the connection + string passed to create_engine with "postgresql+aws_wrapper://" ("[name]+[driver]"). """ name = 'postgresql' driver = 'aws_wrapper' def __init__(self, **kwargs): - # Skip parent's version check since we're a wrapper, not psycopg itself + # PGDialect_psycopg's __init__ function checks the driver version and raises an exception if it is lower than + # 3.0.2. If we call it, the exception is raised because it mistakenly interprets our driver version as its own. + # As a workaround we call the grandparent __init__ instead of the parent's __init__. + # TODO: since we are calling the grandparent's __init__ instead of the parent's __init__, we should investigate + # whether any important code in the parent's __init__ needs to be executed. super(PGDialect_psycopg, self).__init__(**kwargs) - # Dynamically detect the actual psycopg version we're wrapping to ensure - # SQLAlchemy uses the correct feature set and SQL generation + # Dynamically detect the actual psycopg version installed and set it as self.psycopg_version. Note that setting + # this field before calling super().__init__ does not avoid the issue noted above. try: import psycopg m = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", psycopg.__version__) @@ -26,8 +35,10 @@ def __init__(self, **kwargs): int(x) for x in m.group(1, 2, 3) if x is not None ) else: - self.psycopg_version = (3, 0, 2) # Minimum supported + # Fallback to 3.0.2 if version parsing fails, which is the minimum required psycopg version. + self.psycopg_version = (3, 0, 2) except (ImportError, AttributeError): + # Fallback to 3.0.2 if version parsing fails, which is the minimum required psycopg version. self.psycopg_version = (3, 0, 2) @classmethod @@ -42,7 +53,7 @@ def import_dbapi(cls): def create_connect_args(self, url): """ Transform SQLAlchemy URL into connection arguments. - Must include 'target' parameter for the wrapper. + Must include the 'target' parameter for our wrapper driver. """ # Extract standard connection parameters opts = url.translate_connect_args(username='user') @@ -50,7 +61,7 @@ def create_connect_args(self, url): # Add query string parameters opts.update(url.query) - # Add the required 'target' parameter for your wrapper + # Add the required 'target' parameter for our wrapper if 'target' not in opts: opts['target'] = Connection.connect @@ -62,7 +73,6 @@ def on_connect(self): Return a callable that will be executed on new connections. This can be used if we need to set any session-level parameters. """ - def set_session_params(conn): # Set any Aurora-specific session parameters cursor = conn.cursor() @@ -75,15 +85,13 @@ def set_session_params(conn): return set_session_params def get_isolation_level(self, dbapi_connection): - """Get the current isolation level""" cursor = dbapi_connection.cursor() try: cursor.execute("SHOW transaction_isolation") val = cursor.fetchone() if val: - # Extract first element from tuple and format return val.upper().replace(' ', '_') - return 'READ_COMMITTED' # PostgreSQL's default + return 'READ_COMMITTED' # return Postgres' default isolation level. finally: cursor.close() @@ -91,48 +99,50 @@ def initialize(self, connection): """ Override initialization to handle type introspection. The parent class tries to use TypeInfo.fetch() which requires - a native psycopg connection, not our wrapper. + a native psycopg connection, not AwsWrapperConnection. """ - # Find the AwsWrapperConnection at whatever nesting level - wrapper_conn = self._get_wrapper_connection(connection) - - if wrapper_conn and hasattr(wrapper_conn, 'connection'): - # Get the underlying psycopg connection - underlying_conn = wrapper_conn.connection + # Unwrap SQLAlchemy's connection object + wrapper_conn, wrapper_parent = self._get_wrapper_connection_and_parent(connection) - # Temporarily swap the entire connection chain - original_dbapi_conn = connection.connection - connection.connection = underlying_conn + # Check if wrapper_conn and wrapper_parent expose their underlying connections + if wrapper_conn and hasattr(wrapper_conn, 'connection') and wrapper_parent and hasattr(wrapper_parent.connection, 'connection'): + # Temporarily remove the AwsWrapperConnection from the connection chain + psycopg_conn = wrapper_conn.connection + wrapper_parent.connection = psycopg_conn try: - # Call parent initialization with native psycopg connection super().initialize(connection) finally: - # Restore original connection chain - connection.connection = original_dbapi_conn + # Restore wrapper connection in the connection chain. + wrapper_parent.connection = wrapper_conn else: - # If we can't find wrapper or it doesn't expose underlying connection, - # skip type introspection (custom types won't be auto-configured) + # If unable to swap underlying pscyopg connection, skip type introspection. + # This means custom types (hstore, json, etc.) won't be auto-configured. pass - def _get_wrapper_connection(self, connection): + def _get_wrapper_connection_and_parent(self, connection): """ - Traverse the connection chain to find AwsWrapperConnection. - Handles variable nesting depths depending on pool configuration. - """ - from aws_advanced_python_wrapper import AwsWrapperConnection + Traverse the connection chain to find AwsWrapperConnection and its parent connection. + + Args: + connection: SQLAlchemy Connection object + Returns: + AwsWrapperConnection instance or None, parent connection of AwsWrapperConnection or None + """ # Start with the DBAPI connection - current = connection.connection + parent = connection + child = connection.connection # Traverse up to 5 levels deep (reasonable limit) for _ in range(5): - if isinstance(current, AwsWrapperConnection): - return current + if isinstance(child, AwsWrapperConnection): + return child, parent # Try to go deeper if there's a .connection attribute - if hasattr(current, 'connection'): - current = current.connection + if hasattr(child, 'connection'): + parent = child + child = child.connection else: break diff --git a/tests/unit/test_sqlalchemy_orm.py b/tests/unit/test_sqlalchemy_orm.py index 7ec7bec4b..ffbbcf35a 100644 --- a/tests/unit/test_sqlalchemy_orm.py +++ b/tests/unit/test_sqlalchemy_orm.py @@ -39,28 +39,24 @@ class User(Base): Session = sessionmaker(bind=engine) # Step 6: Use session for database operations - session = Session() - - # INSERT - Create new object and add to session - new_user = User(name='John Doe', email='john@example.com') - session.add(new_user) - session.commit() # Explicit commit required - - # SELECT - Query using session - users = session.query(User).filter(User.name == 'John Doe').all() - for user in users: - print(f"{user.name}: {user.email}") - - - # UPDATE - Modify object and commit - user = session.query(User).filter(User.name == "John Doe").first() - user.email = 'newemail@example.com' - session.commit() # Changes tracked by session - - # DELETE - Remove object from session - user_to_delete = session.query(User).filter(User.name == "John Doe").first() - session.delete(user_to_delete) - session.commit() - - # Always close session when done - session.close() + with Session() as session: + # INSERT - Create new object and add to session + new_user = User(name='John Doe', email='john@example.com') + session.add(new_user) + session.commit() # Explicit commit required + + # SELECT - Query using session + users = session.query(User).filter(User.name == 'John Doe').all() + for user in users: + print(f"{user.name}: {user.email}") + + + # UPDATE - Modify object and commit + user = session.query(User).filter(User.name == "John Doe").first() + user.email = 'newemail@example.com' + session.commit() + + # DELETE - Remove object from session + user_to_delete = session.query(User).filter(User.name == "John Doe").first() + session.delete(user_to_delete) + session.commit() From af9e19f4350f3fc0e5ae0a31d5d331144f303635 Mon Sep 17 00:00:00 2001 From: aaron-congo Date: Mon, 9 Feb 2026 10:20:28 -0800 Subject: [PATCH 03/15] Fix failover2 wrong writer host --- .../cluster_topology_monitor.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/aws_advanced_python_wrapper/cluster_topology_monitor.py b/aws_advanced_python_wrapper/cluster_topology_monitor.py index 173cb0014..660b48bae 100644 --- a/aws_advanced_python_wrapper/cluster_topology_monitor.py +++ b/aws_advanced_python_wrapper/cluster_topology_monitor.py @@ -21,7 +21,6 @@ from time import perf_counter_ns from typing import TYPE_CHECKING, Dict, Optional -from aws_advanced_python_wrapper.errors import AwsWrapperError from aws_advanced_python_wrapper.host_availability import HostAvailability from aws_advanced_python_wrapper.hostinfo import HostInfo, Topology from aws_advanced_python_wrapper.utils import services_container @@ -345,8 +344,8 @@ def _open_any_connection_and_update_topology(self) -> Topology: self._cluster_id, self._initial_host_info.host) try: - writer_id = self._topology_utils.get_writer_id_if_connected( - conn, self._plugin_service.driver_dialect) + writer_id = self._topology_utils.get_writer_host_if_connected( + conn, self._plugin_service.driver_dialect) if writer_id: self._is_verified_writer_connection = True writer_verified_by_this_thread = True @@ -355,10 +354,9 @@ def _open_any_connection_and_update_topology(self) -> Topology: writer_host_info = self._initial_host_info self._writer_host_info.set(writer_host_info) else: - instance_template = self._get_instance_template(writer_id, conn) - writer_host = instance_template.host.replace("?", writer_id) - port = instance_template.port \ - if instance_template.is_port_specified() \ + writer_host = self._instance_template.host.replace("?", writer_id) + port = self._instance_template.port \ + if self._instance_template.is_port_specified() \ else self._initial_host_info.port writer_host_info = HostInfo( writer_host, From 800b3a746f2c7b8720379053002f5985cfd4a1cb Mon Sep 17 00:00:00 2001 From: Jonathan Louie Date: Tue, 24 Mar 2026 17:24:04 -0700 Subject: [PATCH 04/15] Add mysql-connector SQLAlchemy ORM --- .../sqlalchemy/mysql_orm_dialect.py | 19 + .../{orm_dialect.py => pg_orm_dialect.py} | 4 +- .../sqlalchemy/test_sqlalchemy_basic.py | 993 ++++++++++++++++++ tests/unit/test_sqlalchemy_orm.py | 2 +- 4 files changed, 1015 insertions(+), 3 deletions(-) create mode 100644 aws_advanced_python_wrapper/sqlalchemy/mysql_orm_dialect.py rename aws_advanced_python_wrapper/sqlalchemy/{orm_dialect.py => pg_orm_dialect.py} (97%) create mode 100644 tests/integration/container/sqlalchemy/test_sqlalchemy_basic.py diff --git a/aws_advanced_python_wrapper/sqlalchemy/mysql_orm_dialect.py b/aws_advanced_python_wrapper/sqlalchemy/mysql_orm_dialect.py new file mode 100644 index 000000000..6d7ff34db --- /dev/null +++ b/aws_advanced_python_wrapper/sqlalchemy/mysql_orm_dialect.py @@ -0,0 +1,19 @@ +# aws_advanced_python_wrapper/sqlalchemy/sqlalchemy_mysqlconnector_dialect.py +from psycopg import Connection +from sqlalchemy.dialects.mysql.mysqlconnector import MySQLDialect_mysqlconnector +import re + +from aws_advanced_python_wrapper import AwsWrapperConnection + + +class SqlAlchemyOrmMysqlDialect(MySQLDialect_mysqlconnector): + """ + SQLAlchemy dialect for AWS Advanced Python Wrapper with mysqlconnector. Extends the SQLAlchemy MySQL mysqlconnector dialect. + This dialect is not related to the DriverDialect or DatabaseDialect classes used by our driver. Instead, it is used + directly by SQLAlchemy. This dialect is registered in pyproject.toml and is selected by prefixing the connection + string passed to create_engine with "mysql+aws_wrapper_mysqlconnector://" ("[name]+[driver]"). + """ + + name = 'mysql' + driver = 'aws_wrapper_mysqlconnector' + diff --git a/aws_advanced_python_wrapper/sqlalchemy/orm_dialect.py b/aws_advanced_python_wrapper/sqlalchemy/pg_orm_dialect.py similarity index 97% rename from aws_advanced_python_wrapper/sqlalchemy/orm_dialect.py rename to aws_advanced_python_wrapper/sqlalchemy/pg_orm_dialect.py index c71e6f22a..c2780b861 100644 --- a/aws_advanced_python_wrapper/sqlalchemy/orm_dialect.py +++ b/aws_advanced_python_wrapper/sqlalchemy/pg_orm_dialect.py @@ -11,11 +11,11 @@ class SqlAlchemyOrmPgDialect(PGDialect_psycopg): SQLAlchemy dialect for AWS Advanced Python Wrapper with psycopg. Extends the SQLAlchemy PostgreSQL psycopg dialect. This dialect is not related to the DriverDialect or DatabaseDialect classes used by our driver. Instead, it is used directly by SQLAlchemy. This dialect is registered in pyproject.toml and is selected by prefixing the connection - string passed to create_engine with "postgresql+aws_wrapper://" ("[name]+[driver]"). + string passed to create_engine with "postgresql+aws_wrapper_psycopg://" ("[name]+[driver]"). """ name = 'postgresql' - driver = 'aws_wrapper' + driver = 'aws_wrapper_psycopg' def __init__(self, **kwargs): # PGDialect_psycopg's __init__ function checks the driver version and raises an exception if it is lower than diff --git a/tests/integration/container/sqlalchemy/test_sqlalchemy_basic.py b/tests/integration/container/sqlalchemy/test_sqlalchemy_basic.py new file mode 100644 index 000000000..bea23f156 --- /dev/null +++ b/tests/integration/container/sqlalchemy/test_sqlalchemy_basic.py @@ -0,0 +1,993 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# flake8: noqa: N806 + +from __future__ import annotations + +from datetime import date, datetime, time +from decimal import Decimal +from typing import Any + +import pytest +from sqlalchemy.orm import declarative_base, Mapped, mapped_column +from sqlalchemy import Column, Integer, String + +from tests.integration.container.utils.rds_test_utility import RdsTestUtility +from ..utils.conditions import (disable_on_features, enable_on_deployments, + enable_on_engines) +from ..utils.database_engine import DatabaseEngine +from ..utils.database_engine_deployment import DatabaseEngineDeployment +from ..utils.test_environment import TestEnvironment +from ..utils.test_environment_features import TestEnvironmentFeatures + + +@enable_on_engines([DatabaseEngine.MYSQL]) # MySQL Specific until PG is implemented +@enable_on_deployments([DatabaseEngineDeployment.AURORA, DatabaseEngineDeployment.RDS_MULTI_AZ_CLUSTER]) +@disable_on_features([TestEnvironmentFeatures.RUN_AUTOSCALING_TESTS_ONLY, + TestEnvironmentFeatures.BLUE_GREEN_DEPLOYMENT, + TestEnvironmentFeatures.PERFORMANCE]) +class TestSqlAlchemy: + TestModel: Any + DataTypeModel: Any + Author: Any + Book: Any + + @pytest.fixture(scope='class') + def rds_utils(self): + region: str = TestEnvironment.get_current().get_info().get_region() + return RdsTestUtility(region) + + @pytest.fixture(scope='class') + def sqlalchemy_models(self, sqlalchemy_setup): + """Create SQLAlchemy models after SQLAlchemy is set up""" + + Base = declarative_base() + + class TestModel(Base): + """Basic test model for SQLAlchemy ORM functionality""" + __tablename__ = 'sqlalchemy_test_model' + + name: Mapped[str] = mapped_column(String(100)) + email: Mapped[str] = mapped_column(String, primary_key=True) + age: Mapped[int] = mapped_column(Integer) + is_active: Mapped[bool] = mapped_column(Bool, server_default=True) + created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now(), onupdate=func.now()) + +''' + class DataTypeModel(models.Model): + """Model for testing various data types""" + # String fields + char_field = models.CharField(max_length=255, null=True, blank=True) + text_field = models.TextField(null=True, blank=True) + + # Numeric fields + integer_field = models.IntegerField(null=True, blank=True) + big_integer_field = models.BigIntegerField(null=True, blank=True) + decimal_field = models.DecimalField(max_digits=10, decimal_places=2, null=True, blank=True) + float_field = models.FloatField(null=True, blank=True) + + # Boolean field + boolean_field = models.BooleanField(default=False) + + # Date/Time fields + date_field = models.DateField(null=True, blank=True) + time_field = models.TimeField(null=True, blank=True) + datetime_field = models.DateTimeField(null=True, blank=True) + + # JSON field (MySQL 5.7+) + json_field = models.JSONField(null=True, blank=True) + + class Meta: + app_label = 'test_app' + db_table = 'django_data_type_model' + + class Author(models.Model): + """Author model for relationship testing""" + name = models.CharField(max_length=100) + email = models.EmailField() + birth_date = models.DateField(null=True, blank=True) + + class Meta: + app_label = 'test_app' + db_table = 'django_author' + + # Store Author first so it's available for Book's ForeignKey + TestDjango.Author = Author + + class Book(models.Model): + """Book model for relationship testing""" + title = models.CharField(max_length=200) + author = models.ForeignKey(TestDjango.Author, on_delete=models.CASCADE, related_name='books') + publication_date = models.DateField() + pages = models.IntegerField() + price = models.DecimalField(max_digits=8, decimal_places=2) + + class Meta: + app_label = 'test_app' + db_table = 'django_book' + + # Store models as class attributes for easy access + TestDjango.TestModel = TestModel + TestDjango.DataTypeModel = DataTypeModel + TestDjango.Book = Book + + # Create tables for our test models + with connection.schema_editor() as schema_editor: + schema_editor.create_model(TestModel) + schema_editor.create_model(DataTypeModel) + schema_editor.create_model(Author) + schema_editor.create_model(Book) + + yield + + # Clean up tables + with connection.schema_editor() as schema_editor: + schema_editor.delete_model(Book) + schema_editor.delete_model(Author) + schema_editor.delete_model(DataTypeModel) + schema_editor.delete_model(TestModel) + + @pytest.fixture(scope='class') + def django_setup(self, conn_utils): + """Setup Django configuration for testing""" + # Configure Django settings + if not settings.configured: + db_config = { + 'ENGINE': 'aws_advanced_python_wrapper.django.backends.mysql_connector', + 'NAME': conn_utils.dbname, + 'USER': conn_utils.user, + 'PASSWORD': conn_utils.password, + 'HOST': conn_utils.writer_cluster_host, + 'PORT': conn_utils.port, + 'OPTIONS': { + 'plugins': 'failover_v2,aurora_connection_tracker', + 'connect_timeout': 10, + 'autocommit': True, + }, + } + + settings.configure( + DEBUG=True, + DATABASES={'default': db_config}, + INSTALLED_APPS=[ + 'django.contrib.contenttypes', + 'django.contrib.auth', + ], + SECRET_KEY='test-secret-key-for-django-tests', + USE_TZ=True, + ) + + django.setup() + setup_test_environment() + + yield + connections.close_all() + + teardown_test_environment() + + def test_django_backend_configuration(self, test_environment: TestEnvironment, django_models): + """Test Django backend configuration with empty plugins""" + # Verify that the connection is using the AWS wrapper + assert hasattr(connection, 'connection') + + # Test basic connection functionality + assert self.TestModel.objects.count() == 0 + + def test_django_basic_model_operations(self, test_environment: TestEnvironment, django_models): + """Test basic Django ORM operations (CRUD)""" + TestModel = self.TestModel + + # Ensure clean slate + TestModel.objects.all().delete() + + # Create + test_obj = TestModel.objects.create( + name="John Doe", + email="john@example.com", + age=30, + is_active=True + ) + assert test_obj.id is not None + assert test_obj.name == "John Doe" + + # Read + retrieved_obj = TestModel.objects.get(id=test_obj.id) + assert retrieved_obj.name == "John Doe" + assert retrieved_obj.email == "john@example.com" + assert retrieved_obj.age == 30 + assert retrieved_obj.is_active is True + + # Update + retrieved_obj.name = "Jane Doe" + retrieved_obj.age = 25 + retrieved_obj.save() + + updated_obj = TestModel.objects.get(id=test_obj.id) + assert updated_obj.name == "Jane Doe" + assert updated_obj.age == 25 + + # Delete + updated_obj.delete() + assert TestModel.objects.filter(id=test_obj.id).count() == 0 + + def test_django_queryset_operations(self, test_environment: TestEnvironment, django_models): + """Test Django QuerySet operations""" + TestModel = self.TestModel + + # Ensure clean slate + TestModel.objects.all().delete() + + # Create test data + TestModel.objects.create(name="Alice", email="alice@example.com", age=25, is_active=True) + TestModel.objects.create(name="Bob", email="bob@example.com", age=30, is_active=False) + TestModel.objects.create(name="Charlie", email="charlie@example.com", age=35, is_active=True) + + # Test filtering + active_users = TestModel.objects.filter(is_active=True) + assert active_users.count() == 2 + + # Test ordering + ordered_users = TestModel.objects.order_by('age') + ages = [user.age for user in ordered_users] + assert ages == [25, 30, 35] + + # Test complex queries + young_active_users = TestModel.objects.filter(age__lt=30, is_active=True) + assert young_active_users.count() == 1 + assert young_active_users.first().name == "Alice" + + # Test exclude + non_bob_users = TestModel.objects.exclude(name="Bob") + assert non_bob_users.count() == 2 + + # Test exists + assert TestModel.objects.filter(name="Alice").exists() + assert not TestModel.objects.filter(name="David").exists() + + # Clean up + TestModel.objects.all().delete() + + def test_django_data_types(self, test_environment: TestEnvironment, django_models): + """Test Django ORM with various data types""" + DataTypeModel = self.DataTypeModel + + # Ensure clean slate + DataTypeModel.objects.all().delete() + + # Create test data with various data types + test_datetime = datetime(2023, 12, 25, 14, 30, 0) + test_datetime_aware = timezone.make_aware(test_datetime) + + test_data = DataTypeModel.objects.create( + char_field="Test String", + text_field="This is a longer text field content", + integer_field=42, + big_integer_field=9223372036854775807, + decimal_field=Decimal('123.45'), + float_field=3.14159, + boolean_field=True, + date_field=date(2023, 12, 25), + time_field=time(14, 30, 0), + datetime_field=test_datetime_aware, # Use timezone-aware datetime + json_field={"key": "value", "number": 123, "array": [1, 2, 3]} + ) + + # Retrieve and verify data + retrieved = DataTypeModel.objects.get(id=test_data.id) + + assert retrieved.char_field == "Test String" + assert retrieved.text_field == "This is a longer text field content" + assert retrieved.integer_field == 42 + assert retrieved.big_integer_field == 9223372036854775807 + assert retrieved.decimal_field == Decimal('123.45') + assert abs(retrieved.float_field - 3.14159) < 0.001 + assert retrieved.boolean_field is True + assert retrieved.date_field == date(2023, 12, 25) + assert retrieved.time_field == time(14, 30, 0) + # Compare timezone-aware datetimes + assert retrieved.datetime_field == test_datetime_aware + assert retrieved.json_field == {"key": "value", "number": 123, "array": [1, 2, 3]} + + # Clean up + DataTypeModel.objects.all().delete() + + def test_django_null_values(self, test_environment: TestEnvironment, django_models): + """Test Django ORM handling of NULL values""" + DataTypeModel = self.DataTypeModel + + # First, ensure we start with a clean slate + DataTypeModel.objects.all().delete() + + # Create object with NULL values + test_obj = DataTypeModel.objects.create( + char_field=None, + integer_field=None, + date_field=None, + boolean_field=False # This field has default=False, so it won't be NULL + ) + + # Retrieve and verify NULL values + retrieved = DataTypeModel.objects.get(id=test_obj.id) + assert retrieved.char_field is None + assert retrieved.integer_field is None + assert retrieved.date_field is None + assert retrieved.boolean_field is False + + # Test filtering with NULL values + null_char_objects = DataTypeModel.objects.filter(char_field__isnull=True) + assert null_char_objects.count() == 1 + + not_null_char_objects = DataTypeModel.objects.filter(char_field__isnull=False) + assert not_null_char_objects.count() == 0 + + # Create an object with non-NULL values to test the opposite + DataTypeModel.objects.create( + char_field="Not NULL", + integer_field=42, + date_field=date(2023, 1, 1) + ) + + # Now test filtering again + null_char_objects = DataTypeModel.objects.filter(char_field__isnull=True) + assert null_char_objects.count() == 1 # Still one NULL object + + not_null_char_objects = DataTypeModel.objects.filter(char_field__isnull=False) + assert not_null_char_objects.count() == 1 # Now one non-NULL object + + # Clean up + DataTypeModel.objects.all().delete() + + def test_django_relationships(self, test_environment: TestEnvironment, django_models): + """Test Django ORM relationships (ForeignKey)""" + Author = self.Author + Book = self.Book + + # Create author + author = Author.objects.create( + name="J.K. Rowling", + email="jk@example.com", + birth_date=date(1965, 7, 31) + ) + + # Create books + book1 = Book.objects.create( + title="Harry Potter and the Philosopher's Stone", + author=author, + publication_date=date(1997, 6, 26), + pages=223, + price=Decimal('12.99') + ) + + book2 = Book.objects.create( + title="Harry Potter and the Chamber of Secrets", + author=author, + publication_date=date(1998, 7, 2), + pages=251, + price=Decimal('13.99') + ) + + # Test forward relationship + assert book1.author.name == "J.K. Rowling" + assert book2.author.email == "jk@example.com" + + # Test reverse relationship + author_books = author.books.all() + assert author_books.count() == 2 + book_titles = [book.title for book in author_books.order_by('publication_date')] + assert "Harry Potter and the Philosopher's Stone" in book_titles + assert "Harry Potter and the Chamber of Secrets" in book_titles + + # Test related queries + books_by_author = Book.objects.filter(author__name="J.K. Rowling") + assert books_by_author.count() == 2 + + # Test select_related for optimization + book_with_author = Book.objects.select_related('author').get(id=book1.id) + assert book_with_author.author.name == "J.K. Rowling" + + # Clean up + Book.objects.all().delete() + Author.objects.all().delete() + + def test_django_aggregations(self, test_environment: TestEnvironment, django_models): + """Test Django ORM aggregations""" + Author = self.Author + Book = self.Book + + # Create test data + author = Author.objects.create(name="Test Author", email="test@example.com") + + Book.objects.create(title="Book 1", author=author, publication_date=date(2020, 1, 1), pages=100, price=Decimal('10.00')) + Book.objects.create(title="Book 2", author=author, publication_date=date(2021, 1, 1), pages=200, price=Decimal('20.00')) + Book.objects.create(title="Book 3", author=author, publication_date=date(2022, 1, 1), pages=300, price=Decimal('30.00')) + + # Test aggregations + stats = Book.objects.aggregate( + total_books=Count('id'), + total_pages=Sum('pages'), + avg_price=Avg('price'), + max_pages=Max('pages'), + min_price=Min('price') + ) + + assert stats['total_books'] == 3 + assert stats['total_pages'] == 600 + assert abs(float(stats['avg_price']) - 20.0) < 0.01 + assert stats['max_pages'] == 300 + assert stats['min_price'] == Decimal('10.00') + + # Clean up + Book.objects.all().delete() + Author.objects.all().delete() + + def test_django_transactions(self, test_environment: TestEnvironment, django_models): + """Test Django transaction handling""" + TestModel = self.TestModel + + # Ensure clean slate + TestModel.objects.all().delete() + + initial_count = TestModel.objects.count() + + # Test successful transaction + with transaction.atomic(): + TestModel.objects.create(name="User 1", email="user1@example.com", age=25) + TestModel.objects.create(name="User 2", email="user2@example.com", age=30) + + assert TestModel.objects.count() == initial_count + 2 + + # Test rollback transaction + try: + with transaction.atomic(): + TestModel.objects.create(name="User 3", email="user3@example.com", age=35) + TestModel.objects.create(name="User 4", email="user4@example.com", age=40) + # Force an error to trigger rollback + raise Exception("Force rollback") + except Exception: + pass # Expected exception + + # Should still have only 2 additional records (rollback occurred) + assert TestModel.objects.count() == initial_count + 2 + + # Clean up + TestModel.objects.all().delete() + + def test_django_bulk_operations(self, test_environment: TestEnvironment, django_models): + """Test Django bulk operations""" + TestModel = self.TestModel + + # Ensure clean slate + TestModel.objects.all().delete() + + # Test bulk_create + test_objects = [ + TestModel(name=f"User {i}", email=f"user{i}@example.com", age=20 + i) + for i in range(10) + ] + + created_objects = TestModel.objects.bulk_create(test_objects) + assert len(created_objects) == 10 + assert TestModel.objects.count() == 10 + + # Test bulk_update - need to get the objects first and modify them + objects_to_update = list(TestModel.objects.all()) + for obj in objects_to_update: + obj.age += 5 + + TestModel.objects.bulk_update(objects_to_update, ['age']) + + # Verify updates - get fresh objects from database + ages = list(TestModel.objects.values_list('age', flat=True).order_by('name')) + expected_ages = [25 + i for i in range(10)] # 20+i+5 for i in range(10) + assert ages == expected_ages + + # Clean up + TestModel.objects.all().delete() + + def test_django_complex_queries(self, test_environment: TestEnvironment, django_models): + """Test complex Django queries with Q objects and F expressions""" + TestModel = self.TestModel + + # Ensure clean slate + TestModel.objects.all().delete() + + # Create test data + TestModel.objects.create(name="Alice", email="alice@example.com", age=25, is_active=True) + TestModel.objects.create(name="Bob", email="bob@example.com", age=30, is_active=False) + TestModel.objects.create(name="Charlie", email="charlie@example.com", age=35, is_active=True) + TestModel.objects.create(name="David", email="david@example.com", age=28, is_active=True) + + # Test Q objects for complex conditions + complex_query = TestModel.objects.filter( + Q(age__gte=30) | Q(name__startswith='A') + ) + assert complex_query.count() == 3 # Bob (30), Charlie (35), Alice (starts with A) + + # Test F expressions + TestModel.objects.filter(age__lt=30).update(age=F('age') + 5) + + # Verify F expression update + alice = TestModel.objects.get(name="Alice") + david = TestModel.objects.get(name="David") + assert alice.age == 30 # 25 + 5 + assert david.age == 33 # 28 + 5 + + # Clean up, might get a failover error from this connection + TestModel.objects.all().delete() + + def test_django_raw_sql_queries(self, test_environment: TestEnvironment, django_models): + """Test Django raw SQL query execution""" + TestModel = self.TestModel + + # Ensure clean slate + TestModel.objects.all().delete() + + # Create test data + TestModel.objects.create(name="Alice", email="alice@example.com", age=25, is_active=True) + TestModel.objects.create(name="Bob", email="bob@example.com", age=30, is_active=False) + TestModel.objects.create(name="Charlie", email="charlie@example.com", age=35, is_active=True) + + # Test raw() method + raw_results = TestModel.objects.raw( + f'SELECT * FROM {TestModel._meta.db_table} WHERE age >= %s ORDER BY age', + [30] + ) + raw_list = list(raw_results) + assert len(raw_list) == 2 + assert raw_list[0].name == "Bob" + assert raw_list[1].name == "Charlie" + + # Test connection.cursor() for custom SQL + with connection.cursor() as cursor: + cursor.execute( + f'SELECT name, age FROM {TestModel._meta.db_table} WHERE is_active = %s ORDER BY age', + [True] + ) + rows = cursor.fetchall() + assert len(rows) == 2 + assert rows[0][0] == "Alice" # name + assert rows[0][1] == 25 # age + assert rows[1][0] == "Charlie" + assert rows[1][1] == 35 + + # Test raw SQL with connection for aggregate + with connection.cursor() as cursor: + cursor.execute(f'SELECT COUNT(*), AVG(age) FROM {TestModel._meta.db_table}') + count, avg_age = cursor.fetchone() + assert count == 3 + assert abs(float(avg_age) - 30.0) < 0.01 + + # Clean up + TestModel.objects.all().delete() + + def test_django_get_or_create(self, test_environment: TestEnvironment, django_models): + """Test Django get_or_create pattern""" + TestModel = self.TestModel + + # Ensure clean slate + TestModel.objects.all().delete() + + # Test create case + obj1, created1 = TestModel.objects.get_or_create( + email="test@example.com", + defaults={'name': 'Test User', 'age': 25, 'is_active': True} + ) + assert created1 is True + assert obj1.name == "Test User" + assert obj1.age == 25 + + # Test get case (object already exists) + obj2, created2 = TestModel.objects.get_or_create( + email="test@example.com", + defaults={'name': 'Different Name', 'age': 30, 'is_active': False} + ) + assert created2 is False + assert obj2.id == obj1.id + assert obj2.name == "Test User" # Should keep original values + assert obj2.age == 25 + + # Verify only one object exists + assert TestModel.objects.filter(email="test@example.com").count() == 1 + + # Clean up + TestModel.objects.all().delete() + + def test_django_update_or_create(self, test_environment: TestEnvironment, django_models): + """Test Django update_or_create pattern""" + TestModel = self.TestModel + + # Ensure clean slate + TestModel.objects.all().delete() + + # Test create case + obj1, created1 = TestModel.objects.update_or_create( + email="update@example.com", + defaults={'name': 'Initial Name', 'age': 25, 'is_active': True} + ) + assert created1 is True + assert obj1.name == "Initial Name" + assert obj1.age == 25 + + # Test update case (object already exists) + obj2, created2 = TestModel.objects.update_or_create( + email="update@example.com", + defaults={'name': 'Updated Name', 'age': 30, 'is_active': False} + ) + assert created2 is False + assert obj2.id == obj1.id + assert obj2.name == "Updated Name" # Should be updated + assert obj2.age == 30 + assert obj2.is_active is False + + # Verify only one object exists + assert TestModel.objects.filter(email="update@example.com").count() == 1 + + # Verify the update persisted + retrieved = TestModel.objects.get(email="update@example.com") + assert retrieved.name == "Updated Name" + assert retrieved.age == 30 + + # Clean up + TestModel.objects.all().delete() + + def test_django_prefetch_related(self, test_environment: TestEnvironment, django_models): + """Test Django prefetch_related for optimizing queries""" + Author = self.Author + Book = self.Book + + # Create test data + author1 = Author.objects.create(name="Author 1", email="author1@example.com") + author2 = Author.objects.create(name="Author 2", email="author2@example.com") + + Book.objects.create(title="Book 1A", author=author1, publication_date=date(2020, 1, 1), pages=100, price=Decimal('10.00')) + Book.objects.create(title="Book 1B", author=author1, publication_date=date(2021, 1, 1), pages=200, price=Decimal('20.00')) + Book.objects.create(title="Book 2A", author=author2, publication_date=date(2022, 1, 1), pages=300, price=Decimal('30.00')) + + # Test prefetch_related + authors = Author.objects.prefetch_related('books').all() + + # Access related books (should not trigger additional queries due to prefetch) + for author in authors: + books = list(author.books.all()) + if author.name == "Author 1": + assert len(books) == 2 + book_titles = [book.title for book in books] + assert "Book 1A" in book_titles + assert "Book 1B" in book_titles + elif author.name == "Author 2": + assert len(books) == 1 + assert books[0].title == "Book 2A" + + # Clean up + Book.objects.all().delete() + Author.objects.all().delete() + + def test_django_database_functions(self, test_environment: TestEnvironment, django_models): + """Test Django database functions""" + TestModel = self.TestModel + + # Ensure clean slate + TestModel.objects.all().delete() + + # Create test data + TestModel.objects.create(name="alice", email="alice@example.com", age=25) + TestModel.objects.create(name="BOB", email="bob@example.com", age=30) + TestModel.objects.create(name="Charlie", email="charlie@example.com", age=35) + + # Test Upper function + upper_names = TestModel.objects.annotate(upper_name=Upper('name')).values_list('upper_name', flat=True) + upper_list = list(upper_names) + assert "ALICE" in upper_list + assert "BOB" in upper_list + assert "CHARLIE" in upper_list + + # Test Lower function + lower_names = TestModel.objects.annotate(lower_name=Lower('name')).values_list('lower_name', flat=True) + lower_list = list(lower_names) + assert "alice" in lower_list + assert "bob" in lower_list + assert "charlie" in lower_list + + # Test Length function + name_lengths = TestModel.objects.annotate(name_length=Length('name')).filter(name_length__gte=5) + assert name_lengths.count() == 2 # "alice" (5) and "Charlie" (7) + + # Test Concat function + full_info = TestModel.objects.annotate( + full_info=Concat('name', Value(' - '), 'email', output_field=CharField()) + ).first() + assert ' - ' in full_info.full_info + assert '@example.com' in full_info.full_info + + # Clean up + TestModel.objects.all().delete() + + def test_django_annotations(self, test_environment: TestEnvironment, django_models): + """Test Django annotations with expressions""" + TestModel = self.TestModel + Book = self.Book + Author = self.Author + + # Create test data for TestModel + TestModel.objects.create(name="Alice", email="alice@example.com", age=25, is_active=True) + TestModel.objects.create(name="Bob", email="bob@example.com", age=30, is_active=False) + TestModel.objects.create(name="Charlie", email="charlie@example.com", age=35, is_active=True) + + # Test annotate with F expression for calculations + test_with_age_plus_ten = TestModel.objects.annotate( + age_plus_ten=F('age') + 10 + ).order_by('age') + + # Verify calculation + first_obj = test_with_age_plus_ten.first() + assert first_obj.age_plus_ten == first_obj.age + 10 + assert first_obj.age_plus_ten == 35 # 25 + 10 + + # Create books for F expression testing + author = Author.objects.create(name="Test Author", email="test@example.com") + Book.objects.create(title="Book 1", author=author, publication_date=date(2020, 1, 1), pages=100, price=Decimal('10.00')) + Book.objects.create(title="Book 2", author=author, publication_date=date(2021, 1, 1), pages=200, price=Decimal('20.00')) + Book.objects.create(title="Book 3", author=author, publication_date=date(2022, 1, 1), pages=300, price=Decimal('30.00')) + + # Test annotate with F expression for price per page + books_with_price_per_page = Book.objects.annotate( + price_per_page=F('price') / F('pages') + ).order_by('price_per_page') + + # Verify calculation + first_book = books_with_price_per_page.first() + expected_price_per_page = float(first_book.price) / first_book.pages + assert abs(float(first_book.price_per_page) - expected_price_per_page) < 0.001 + + # Test filtering on annotated field - use a lower threshold to avoid precision issues + cheap_books = Book.objects.annotate( + price_per_page=F('price') / F('pages') + ).filter(price_per_page__lte=0.15) + assert cheap_books.count() == 3 # All books have price_per_page = 0.10 + + # Clean up + TestModel.objects.all().delete() + Book.objects.all().delete() + Author.objects.all().delete() + + def test_django_values_and_values_list(self, test_environment: TestEnvironment, django_models): + """Test Django values() and values_list() methods""" + TestModel = self.TestModel + + # Ensure clean slate + TestModel.objects.all().delete() + + # Create test data + TestModel.objects.create(name="Alice", email="alice@example.com", age=25, is_active=True) + TestModel.objects.create(name="Bob", email="bob@example.com", age=30, is_active=False) + TestModel.objects.create(name="Charlie", email="charlie@example.com", age=35, is_active=True) + + # Test values() - returns list of dictionaries + values_result = TestModel.objects.values('name', 'age').order_by('age') + values_list = list(values_result) + assert len(values_list) == 3 + assert values_list[0] == {'name': 'Alice', 'age': 25} + assert values_list[1] == {'name': 'Bob', 'age': 30} + assert values_list[2] == {'name': 'Charlie', 'age': 35} + + # Test values_list() - returns list of tuples + values_list_result = TestModel.objects.values_list('name', 'age').order_by('age') + tuples_list = list(values_list_result) + assert len(tuples_list) == 3 + assert tuples_list[0] == ('Alice', 25) + assert tuples_list[1] == ('Bob', 30) + assert tuples_list[2] == ('Charlie', 35) + + # Test values_list() with flat=True - returns flat list + names = TestModel.objects.values_list('name', flat=True).order_by('name') + names_list = list(names) + assert names_list == ['Alice', 'Bob', 'Charlie'] + + # Test values() with filtering + active_users = TestModel.objects.filter(is_active=True).values('name', 'email') + active_list = list(active_users) + assert len(active_list) == 2 + active_names = [user['name'] for user in active_list] + assert 'Alice' in active_names + assert 'Charlie' in active_names + assert 'Bob' not in active_names + + # Clean up + TestModel.objects.all().delete() + + def test_django_distinct_queries(self, test_environment: TestEnvironment, django_models): + """Test Django distinct() functionality""" + TestModel = self.TestModel + + # Ensure clean slate + TestModel.objects.all().delete() + + # Create test data with duplicate ages + TestModel.objects.create(name="Alice", email="alice@example.com", age=25, is_active=True) + TestModel.objects.create(name="Bob", email="bob@example.com", age=30, is_active=False) + TestModel.objects.create(name="Charlie", email="charlie@example.com", age=25, is_active=True) + TestModel.objects.create(name="David", email="david@example.com", age=30, is_active=True) + + # Test distinct ages + distinct_ages = TestModel.objects.values_list('age', flat=True).distinct().order_by('age') + ages_list = list(distinct_ages) + assert ages_list == [25, 30] + + # Test distinct with multiple fields + distinct_age_status = TestModel.objects.values('age', 'is_active').distinct().order_by('age', 'is_active') + distinct_list = list(distinct_age_status) + assert len(distinct_list) == 3 # (25, True), (30, False), (30, True) + + # Test count with distinct + total_count = TestModel.objects.count() + distinct_age_count = TestModel.objects.values('age').distinct().count() + assert total_count == 4 + assert distinct_age_count == 2 + + # Clean up + TestModel.objects.all().delete() + + def test_django_only_and_defer(self, test_environment: TestEnvironment, django_models): + """Test Django only() and defer() for query optimization""" + TestModel = self.TestModel + + # Ensure clean slate + TestModel.objects.all().delete() + + # Create test data + obj = TestModel.objects.create( + name="Test User", + email="test@example.com", + age=30, + is_active=True + ) + + # Test only() - load only specific fields + obj_only = TestModel.objects.only('name', 'email').get(id=obj.id) + assert obj_only.name == "Test User" + assert obj_only.email == "test@example.com" + # Accessing deferred fields will trigger additional query, but should still work + assert obj_only.age == 30 + + # Test defer() - exclude specific fields from loading + obj_defer = TestModel.objects.defer('age', 'is_active').get(id=obj.id) + assert obj_defer.name == "Test User" + assert obj_defer.email == "test@example.com" + # Accessing deferred fields will trigger additional query, but should still work + assert obj_defer.age == 30 + + # Clean up + TestModel.objects.all().delete() + + def test_django_in_bulk(self, test_environment: TestEnvironment, django_models): + """Test Django in_bulk() for batch retrieval""" + TestModel = self.TestModel + + # Ensure clean slate + TestModel.objects.all().delete() + + # Create test data + obj1 = TestModel.objects.create(name="User 1", email="user1@example.com", age=25) + obj2 = TestModel.objects.create(name="User 2", email="user2@example.com", age=30) + obj3 = TestModel.objects.create(name="User 3", email="user3@example.com", age=35) + + # Test in_bulk with IDs (default behavior) + bulk_result = TestModel.objects.in_bulk([obj1.id, obj2.id, obj3.id]) + assert len(bulk_result) == 3 + assert bulk_result[obj1.id].name == "User 1" + assert bulk_result[obj2.id].name == "User 2" + assert bulk_result[obj3.id].name == "User 3" + + # Test in_bulk with all IDs (no list provided) + bulk_all = TestModel.objects.in_bulk() + assert len(bulk_all) == 3 + assert obj1.id in bulk_all + assert obj2.id in bulk_all + assert obj3.id in bulk_all + + # Test in_bulk with email field (unique field) + bulk_by_email = TestModel.objects.in_bulk( + ["user1@example.com", "user3@example.com"], + field_name='email' + ) + assert len(bulk_by_email) == 2 + assert bulk_by_email["user1@example.com"].name == "User 1" + assert bulk_by_email["user3@example.com"].name == "User 3" + + # Clean up + TestModel.objects.all().delete() + + def test_django_conditional_expressions(self, test_environment: TestEnvironment, django_models): + """Test Django Case/When conditional expressions""" + from django.db.models import Case, IntegerField, Value, When + + TestModel = self.TestModel + + # Ensure clean slate + TestModel.objects.all().delete() + + # Create test data + TestModel.objects.create(name="Alice", email="alice@example.com", age=25, is_active=True) + TestModel.objects.create(name="Bob", email="bob@example.com", age=30, is_active=False) + TestModel.objects.create(name="Charlie", email="charlie@example.com", age=35, is_active=True) + + # Test Case/When for conditional logic + results = TestModel.objects.annotate( + age_category=Case( + When(age__lt=30, then=Value('young')), + When(age__gte=30, age__lt=40, then=Value('middle')), + default=Value('senior'), + output_field=CharField() + ) + ).order_by('age') + + results_list = list(results) + assert results_list[0].age_category == 'young' # Alice, 25 + assert results_list[1].age_category == 'middle' # Bob, 30 + assert results_list[2].age_category == 'middle' # Charlie, 35 + + # Test Case/When with integer output + priority_results = TestModel.objects.annotate( + priority=Case( + When(is_active=True, age__lt=30, then=Value(1)), + When(is_active=True, then=Value(2)), + When(is_active=False, then=Value(3)), + default=Value(4), + output_field=IntegerField() + ) + ).order_by('priority', 'name') + + priority_list = list(priority_results) + assert priority_list[0].name == 'Alice' # priority 1: active and young + assert priority_list[1].name == 'Charlie' # priority 2: active but not young + assert priority_list[2].name == 'Bob' # priority 3: not active + + # Clean up + TestModel.objects.all().delete() + + def test_django_iterator(self, test_environment: TestEnvironment, django_models): + """Test Django iterator() for memory-efficient queries""" + TestModel = self.TestModel + + # Ensure clean slate + TestModel.objects.all().delete() + + # Create test data + for i in range(20): + TestModel.objects.create( + name=f"User {i}", + email=f"user{i}@example.com", + age=20 + i + ) + + # Test iterator() - processes results without caching + count = 0 + for obj in TestModel.objects.iterator(): + assert obj.name.startswith("User") + count += 1 + assert count == 20 + + # Test iterator with chunk_size + count = 0 + for obj in TestModel.objects.iterator(chunk_size=5): + assert obj.email.endswith("@example.com") + count += 1 + assert count == 20 + + # Clean up + TestModel.objects.all().delete() + +''' + diff --git a/tests/unit/test_sqlalchemy_orm.py b/tests/unit/test_sqlalchemy_orm.py index ffbbcf35a..bf8468acd 100644 --- a/tests/unit/test_sqlalchemy_orm.py +++ b/tests/unit/test_sqlalchemy_orm.py @@ -19,7 +19,7 @@ class TestSqlAlchemyORM: def test_basic_workflow(self): # Step 1: Create engine (connection to database) - engine = create_engine('postgresql+aws_wrapper://pguser:pgpassword@mydb.cluster-XYZ.us-west-1.rds.amazonaws.com:5432/somedb') + engine = create_engine('mysql+aws_wrapper_mysqlconnector://mysqlmaster:mysqlpassword@database-mysql-ulojonat.cluster-cx422ywmsto6.us-east-2.rds.amazonaws.com:3306/mysqldb') # Step 2: Define base class for declarative models Base = declarative_base() From 0033f903ae74b986b8383725037a12c901277a7c Mon Sep 17 00:00:00 2001 From: Jonathan Louie Date: Wed, 25 Mar 2026 11:44:28 -0700 Subject: [PATCH 05/15] Revert connection string in sqlalchemy orm unit test --- tests/unit/test_sqlalchemy_orm.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/unit/test_sqlalchemy_orm.py b/tests/unit/test_sqlalchemy_orm.py index bf8468acd..70acb6f58 100644 --- a/tests/unit/test_sqlalchemy_orm.py +++ b/tests/unit/test_sqlalchemy_orm.py @@ -19,8 +19,7 @@ class TestSqlAlchemyORM: def test_basic_workflow(self): # Step 1: Create engine (connection to database) - engine = create_engine('mysql+aws_wrapper_mysqlconnector://mysqlmaster:mysqlpassword@database-mysql-ulojonat.cluster-cx422ywmsto6.us-east-2.rds.amazonaws.com:3306/mysqldb') - + engine = create_engine('postgresql+aws_wrapper://pguser:pgpassword@mydb.cluster-XYZ.us-west-1.rds.amazonaws.com:5432/somedb') # Step 2: Define base class for declarative models Base = declarative_base() From 65a7b380ec8380c17d9f7194280c71484320614b Mon Sep 17 00:00:00 2001 From: Jonathan Louie Date: Thu, 26 Mar 2026 15:27:08 -0700 Subject: [PATCH 06/15] Add __init__.py for sqlalchemy integration tests --- .../container/sqlalchemy/__init__.py | 13 ++++++++++ .../sqlalchemy/test_sqlalchemy_basic.py | 24 ++++++++++--------- 2 files changed, 26 insertions(+), 11 deletions(-) create mode 100644 tests/integration/container/sqlalchemy/__init__.py diff --git a/tests/integration/container/sqlalchemy/__init__.py b/tests/integration/container/sqlalchemy/__init__.py new file mode 100644 index 000000000..bd4acb2bf --- /dev/null +++ b/tests/integration/container/sqlalchemy/__init__.py @@ -0,0 +1,13 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/integration/container/sqlalchemy/test_sqlalchemy_basic.py b/tests/integration/container/sqlalchemy/test_sqlalchemy_basic.py index bea23f156..8766dbaeb 100644 --- a/tests/integration/container/sqlalchemy/test_sqlalchemy_basic.py +++ b/tests/integration/container/sqlalchemy/test_sqlalchemy_basic.py @@ -21,8 +21,8 @@ from typing import Any import pytest -from sqlalchemy.orm import declarative_base, Mapped, mapped_column -from sqlalchemy import Column, Integer, String +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column +from sqlalchemy import Column, Integer, String, Boolean, DateTime from tests.integration.container.utils.rds_test_utility import RdsTestUtility from ..utils.conditions import (disable_on_features, enable_on_deployments, @@ -52,18 +52,20 @@ def rds_utils(self): @pytest.fixture(scope='class') def sqlalchemy_models(self, sqlalchemy_setup): """Create SQLAlchemy models after SQLAlchemy is set up""" + #Base = declarative_base() - Base = declarative_base() + class Base(DeclarativeBase): + pass - class TestModel(Base): - """Basic test model for SQLAlchemy ORM functionality""" - __tablename__ = 'sqlalchemy_test_model' + class TestModel(Base): + """Basic test model for SQLAlchemy ORM functionality""" + __tablename__ = 'sqlalchemy_test_model' - name: Mapped[str] = mapped_column(String(100)) - email: Mapped[str] = mapped_column(String, primary_key=True) - age: Mapped[int] = mapped_column(Integer) - is_active: Mapped[bool] = mapped_column(Bool, server_default=True) - created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now(), onupdate=func.now()) + name: Mapped[str] = mapped_column(String(100)) + email: Mapped[str] = mapped_column(String, primary_key=True) + age: Mapped[int] = mapped_column(Integer) + is_active: Mapped[bool] = mapped_column(Boolean) + created_at: Mapped[datetime] = mapped_column(DateTime) ''' class DataTypeModel(models.Model): From 6e055b303d96820f64a043811fe99bd6a506b133 Mon Sep 17 00:00:00 2001 From: Jonathan Louie Date: Thu, 26 Mar 2026 15:59:06 -0700 Subject: [PATCH 07/15] Fix RdsUtils not being found --- .../cluster_topology_monitor.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/aws_advanced_python_wrapper/cluster_topology_monitor.py b/aws_advanced_python_wrapper/cluster_topology_monitor.py index 660b48bae..173cb0014 100644 --- a/aws_advanced_python_wrapper/cluster_topology_monitor.py +++ b/aws_advanced_python_wrapper/cluster_topology_monitor.py @@ -21,6 +21,7 @@ from time import perf_counter_ns from typing import TYPE_CHECKING, Dict, Optional +from aws_advanced_python_wrapper.errors import AwsWrapperError from aws_advanced_python_wrapper.host_availability import HostAvailability from aws_advanced_python_wrapper.hostinfo import HostInfo, Topology from aws_advanced_python_wrapper.utils import services_container @@ -344,8 +345,8 @@ def _open_any_connection_and_update_topology(self) -> Topology: self._cluster_id, self._initial_host_info.host) try: - writer_id = self._topology_utils.get_writer_host_if_connected( - conn, self._plugin_service.driver_dialect) + writer_id = self._topology_utils.get_writer_id_if_connected( + conn, self._plugin_service.driver_dialect) if writer_id: self._is_verified_writer_connection = True writer_verified_by_this_thread = True @@ -354,9 +355,10 @@ def _open_any_connection_and_update_topology(self) -> Topology: writer_host_info = self._initial_host_info self._writer_host_info.set(writer_host_info) else: - writer_host = self._instance_template.host.replace("?", writer_id) - port = self._instance_template.port \ - if self._instance_template.is_port_specified() \ + instance_template = self._get_instance_template(writer_id, conn) + writer_host = instance_template.host.replace("?", writer_id) + port = instance_template.port \ + if instance_template.is_port_specified() \ else self._initial_host_info.port writer_host_info = HostInfo( writer_host, From 49bbeea09e730f2eab4e59ce4901a11f97e5bb12 Mon Sep 17 00:00:00 2001 From: Jonathan Louie Date: Wed, 1 Apr 2026 18:21:14 -0700 Subject: [PATCH 08/15] Translate basic django test to sqlalchemy --- pyproject.toml | 3 +- .../sqlalchemy/test_sqlalchemy_basic.py | 235 +++++++----------- .../container/utils/test_database_info.py | 2 +- .../utils/test_environment_request.py | 2 +- 4 files changed, 96 insertions(+), 146 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 80d787048..7cf7f9c61 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -85,4 +85,5 @@ filterwarnings = [ ] [tool.poetry.plugins."sqlalchemy.dialects"] -"postgresql.aws_wrapper" = "aws_advanced_python_wrapper.sqlalchemy.orm_dialect:SqlAlchemyOrmPgDialect" +"postgresql.aws_wrapper_psycopg" = "aws_advanced_python_wrapper.sqlalchemy.pg_orm_dialect:SqlAlchemyOrmPgDialect" +"mysql.aws_wrapper_mysqlconnector" = "aws_advanced_python_wrapper.sqlalchemy.mysql_orm_dialect:SqlAlchemyOrmMysqlDialect" diff --git a/tests/integration/container/sqlalchemy/test_sqlalchemy_basic.py b/tests/integration/container/sqlalchemy/test_sqlalchemy_basic.py index 8766dbaeb..ad080230a 100644 --- a/tests/integration/container/sqlalchemy/test_sqlalchemy_basic.py +++ b/tests/integration/container/sqlalchemy/test_sqlalchemy_basic.py @@ -21,8 +21,8 @@ from typing import Any import pytest -from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column -from sqlalchemy import Column, Integer, String, Boolean, DateTime +from sqlalchemy.orm import declarative_base, sessionmaker +from sqlalchemy import create_engine, Column, ForeignKey, Integer, BigInteger, SmallInteger, Float, Numeric, String, Boolean, Date, Time, DateTime, Text, JSON from tests.integration.container.utils.rds_test_utility import RdsTestUtility from ..utils.conditions import (disable_on_features, enable_on_deployments, @@ -39,146 +39,94 @@ TestEnvironmentFeatures.BLUE_GREEN_DEPLOYMENT, TestEnvironmentFeatures.PERFORMANCE]) class TestSqlAlchemy: - TestModel: Any - DataTypeModel: Any - Author: Any - Book: Any - @pytest.fixture(scope='class') def rds_utils(self): region: str = TestEnvironment.get_current().get_info().get_region() return RdsTestUtility(region) - @pytest.fixture(scope='class') - def sqlalchemy_models(self, sqlalchemy_setup): - """Create SQLAlchemy models after SQLAlchemy is set up""" - #Base = declarative_base() - - class Base(DeclarativeBase): - pass + Base = declarative_base() class TestModel(Base): """Basic test model for SQLAlchemy ORM functionality""" __tablename__ = 'sqlalchemy_test_model' - name: Mapped[str] = mapped_column(String(100)) - email: Mapped[str] = mapped_column(String, primary_key=True) - age: Mapped[int] = mapped_column(Integer) - is_active: Mapped[bool] = mapped_column(Boolean) - created_at: Mapped[datetime] = mapped_column(DateTime) - -''' - class DataTypeModel(models.Model): - """Model for testing various data types""" - # String fields - char_field = models.CharField(max_length=255, null=True, blank=True) - text_field = models.TextField(null=True, blank=True) - - # Numeric fields - integer_field = models.IntegerField(null=True, blank=True) - big_integer_field = models.BigIntegerField(null=True, blank=True) - decimal_field = models.DecimalField(max_digits=10, decimal_places=2, null=True, blank=True) - float_field = models.FloatField(null=True, blank=True) - - # Boolean field - boolean_field = models.BooleanField(default=False) - - # Date/Time fields - date_field = models.DateField(null=True, blank=True) - time_field = models.TimeField(null=True, blank=True) - datetime_field = models.DateTimeField(null=True, blank=True) - - # JSON field (MySQL 5.7+) - json_field = models.JSONField(null=True, blank=True) - - class Meta: - app_label = 'test_app' - db_table = 'django_data_type_model' - - class Author(models.Model): - """Author model for relationship testing""" - name = models.CharField(max_length=100) - email = models.EmailField() - birth_date = models.DateField(null=True, blank=True) - - class Meta: - app_label = 'test_app' - db_table = 'django_author' - - # Store Author first so it's available for Book's ForeignKey - TestDjango.Author = Author - - class Book(models.Model): - """Book model for relationship testing""" - title = models.CharField(max_length=200) - author = models.ForeignKey(TestDjango.Author, on_delete=models.CASCADE, related_name='books') - publication_date = models.DateField() - pages = models.IntegerField() - price = models.DecimalField(max_digits=8, decimal_places=2) - - class Meta: - app_label = 'test_app' - db_table = 'django_book' - - # Store models as class attributes for easy access - TestDjango.TestModel = TestModel - TestDjango.DataTypeModel = DataTypeModel - TestDjango.Book = Book - - # Create tables for our test models - with connection.schema_editor() as schema_editor: - schema_editor.create_model(TestModel) - schema_editor.create_model(DataTypeModel) - schema_editor.create_model(Author) - schema_editor.create_model(Book) - - yield - - # Clean up tables - with connection.schema_editor() as schema_editor: - schema_editor.delete_model(Book) - schema_editor.delete_model(Author) - schema_editor.delete_model(DataTypeModel) - schema_editor.delete_model(TestModel) - - @pytest.fixture(scope='class') - def django_setup(self, conn_utils): - """Setup Django configuration for testing""" - # Configure Django settings - if not settings.configured: - db_config = { - 'ENGINE': 'aws_advanced_python_wrapper.django.backends.mysql_connector', - 'NAME': conn_utils.dbname, - 'USER': conn_utils.user, - 'PASSWORD': conn_utils.password, - 'HOST': conn_utils.writer_cluster_host, - 'PORT': conn_utils.port, - 'OPTIONS': { - 'plugins': 'failover_v2,aurora_connection_tracker', - 'connect_timeout': 10, - 'autocommit': True, - }, - } - - settings.configure( - DEBUG=True, - DATABASES={'default': db_config}, - INSTALLED_APPS=[ - 'django.contrib.contenttypes', - 'django.contrib.auth', - ], - SECRET_KEY='test-secret-key-for-django-tests', - USE_TZ=True, - ) - - django.setup() - setup_test_environment() - - yield - connections.close_all() - - teardown_test_environment() - + id = Column(Integer, primary_key=True) + + name = Column(String(100)) + email = Column(String, primary_key=True) + age = Column(Integer) + is_active = Column(Boolean) + created_at = Column(DateTime) + + class DataTypeModel(Base): + """Model for testing various data types""" + __tablename__ = 'sqlalchemy_data_type_model' + + id = Column(Integer, primary_key=True) + + # String fields + string_field = Column(String(255)) + text_field = Column(Text) + + # Numeric fields + integer_field = Column(Integer) + small_integer_field = Column(SmallInteger) + big_integer_field = Column(BigInteger) + numeric_field = Column(Numeric) + float_field = Column(Float) + + # Boolean field + boolean_field = Column(Boolean) + + # Date/Time fields + date_field = Column(Date) + time_field = Column(Time) + datetime_field = Column(DateTime) + + # JSON field (MySQL 5.7+) + json_field = Column(JSON) + + class Author(Base): + """Author model for relationship testing""" + __tablename__ = 'sqlalchemy_author' + + id = Column(Integer, primary_key=True) + name = Column(String(100)) + email = Column(String) + birth_date = Column(Date) + + class Book(Base): + """Book model for relationship testing""" + __tablename__ = 'sqlalchemy_book' + + id = Column(Integer, primary_key=True) + title = Column(String(200)) + author = Column(String, ForeignKey("Author.id")) + publication_date = Column(Date) + pages = Column(Integer) + price = Column(Numeric) + + @pytest.fixture(scope="class") + def engine(self, conn_utils): + conn_str = f'mysql+aws_wrapper_mysqlconnector://{conn_utils.user}:{conn_utils.password}@{conn_utils.writer_cluster_host}:{conn_utils.port}/{conn_utils.dbname}' + engine = create_engine(conn_str) + Base.metadata.create_all(engine) + yield engine + Base.metadata.drop_all(engine) + + @pytest.fixture(scope="class") + def Session(self, engine): + Session = sessionmaker(bind=engine) + yield Session + + @pytest.fixture(scope="class") + def session(self, Session): + session = Session() + yield session + session.rollback() + session.close() + + ''' def test_django_backend_configuration(self, test_environment: TestEnvironment, django_models): """Test Django backend configuration with empty plugins""" # Verify that the connection is using the AWS wrapper @@ -186,26 +134,25 @@ def test_django_backend_configuration(self, test_environment: TestEnvironment, d # Test basic connection functionality assert self.TestModel.objects.count() == 0 + ''' - def test_django_basic_model_operations(self, test_environment: TestEnvironment, django_models): + def test_sqlalchemy_basic_model_operations(self, session, test_environment: TestEnvironment): """Test basic Django ORM operations (CRUD)""" - TestModel = self.TestModel - - # Ensure clean slate - TestModel.objects.all().delete() # Create - test_obj = TestModel.objects.create( + test_obj = TestModel( name="John Doe", email="john@example.com", age=30, is_active=True ) + session.add(test_obj) + session.commit() assert test_obj.id is not None assert test_obj.name == "John Doe" # Read - retrieved_obj = TestModel.objects.get(id=test_obj.id) + retrieved_obj = session.query(TestModel).filter(TestModel.id == test_obj.id).first() assert retrieved_obj.name == "John Doe" assert retrieved_obj.email == "john@example.com" assert retrieved_obj.age == 30 @@ -214,16 +161,18 @@ def test_django_basic_model_operations(self, test_environment: TestEnvironment, # Update retrieved_obj.name = "Jane Doe" retrieved_obj.age = 25 - retrieved_obj.save() + session.commit() - updated_obj = TestModel.objects.get(id=test_obj.id) + updated_obj = session.query(TestModel).filter(TestModel.id == test_obj.id).first() assert updated_obj.name == "Jane Doe" assert updated_obj.age == 25 # Delete - updated_obj.delete() - assert TestModel.objects.filter(id=test_obj.id).count() == 0 + session.delete(updated_obj) + session.commit() + assert session.query(TestModel).filter(TestModel.id == test_obj.id).count() == 0 +''' def test_django_queryset_operations(self, test_environment: TestEnvironment, django_models): """Test Django QuerySet operations""" TestModel = self.TestModel diff --git a/tests/integration/container/utils/test_database_info.py b/tests/integration/container/utils/test_database_info.py index a1b3a0944..edc49faf0 100644 --- a/tests/integration/container/utils/test_database_info.py +++ b/tests/integration/container/utils/test_database_info.py @@ -42,7 +42,7 @@ def __init__(self, database_info: Dict[str, Any]) -> None: self._username = typing.cast('str', database_info.get("username")) self._password = typing.cast('str', database_info.get("password")) - self._default_db_name = typing.cast('str', database_info.get("defaultDbName")) + self._default_db_name = "mysqldb" self._cluster_endpoint = typing.cast('str', database_info.get("clusterEndpoint")) self._cluster_endpoint_port = typing.cast('int', database_info.get("clusterEndpointPort")) self._cluster_read_only_endpoint = typing.cast('str', database_info.get("clusterReadOnlyEndpoint")) diff --git a/tests/integration/container/utils/test_environment_request.py b/tests/integration/container/utils/test_environment_request.py index db1bbeef2..def700293 100644 --- a/tests/integration/container/utils/test_environment_request.py +++ b/tests/integration/container/utils/test_environment_request.py @@ -63,7 +63,7 @@ def get_features(self) -> Set[TestEnvironmentFeatures]: return self._features def get_num_of_instances(self) -> int: - return self._num_of_instances + return 3 def get_display_name(self) -> str: return "Test environment [{0}, {1}, {2}, {3}, {4}, {5}]".format( From 92659f372c67a1676ba460136c814b2a846b8533 Mon Sep 17 00:00:00 2001 From: Jonathan Louie Date: Sat, 4 Apr 2026 16:15:12 -0700 Subject: [PATCH 09/15] Add basic CRUD test for sqlalchemy ORM mysql tests --- .../sqlalchemy/test_sqlalchemy_basic.py | 136 +++++++++--------- 1 file changed, 71 insertions(+), 65 deletions(-) diff --git a/tests/integration/container/sqlalchemy/test_sqlalchemy_basic.py b/tests/integration/container/sqlalchemy/test_sqlalchemy_basic.py index ad080230a..49ebd56cd 100644 --- a/tests/integration/container/sqlalchemy/test_sqlalchemy_basic.py +++ b/tests/integration/container/sqlalchemy/test_sqlalchemy_basic.py @@ -16,13 +16,16 @@ from __future__ import annotations -from datetime import date, datetime, time +from datetime import date, datetime, time, timezone from decimal import Decimal from typing import Any import pytest -from sqlalchemy.orm import declarative_base, sessionmaker -from sqlalchemy import create_engine, Column, ForeignKey, Integer, BigInteger, SmallInteger, Float, Numeric, String, Boolean, Date, Time, DateTime, Text, JSON +from sqlalchemy.orm import declarative_base, sessionmaker, relationship +from sqlalchemy import ( + create_engine, Column, ForeignKey, Integer, BigInteger, SmallInteger, + Float, Numeric, String, Boolean, Date, Time, DateTime, Text, JSON +) from tests.integration.container.utils.rds_test_utility import RdsTestUtility from ..utils.conditions import (disable_on_features, enable_on_deployments, @@ -32,79 +35,83 @@ from ..utils.test_environment import TestEnvironment from ..utils.test_environment_features import TestEnvironmentFeatures +Base = declarative_base() -@enable_on_engines([DatabaseEngine.MYSQL]) # MySQL Specific until PG is implemented -@enable_on_deployments([DatabaseEngineDeployment.AURORA, DatabaseEngineDeployment.RDS_MULTI_AZ_CLUSTER]) -@disable_on_features([TestEnvironmentFeatures.RUN_AUTOSCALING_TESTS_ONLY, - TestEnvironmentFeatures.BLUE_GREEN_DEPLOYMENT, - TestEnvironmentFeatures.PERFORMANCE]) -class TestSqlAlchemy: - @pytest.fixture(scope='class') - def rds_utils(self): - region: str = TestEnvironment.get_current().get_info().get_region() - return RdsTestUtility(region) +class TestModel(Base): + """Basic test model for SQLAlchemy ORM functionality""" + __tablename__ = 'sqlalchemy_test_model' - Base = declarative_base() + id = Column(Integer, primary_key=True) - class TestModel(Base): - """Basic test model for SQLAlchemy ORM functionality""" - __tablename__ = 'sqlalchemy_test_model' + name = Column(String(100), nullable=False) + email = Column(String(254), nullable=False, unique=True) + age = Column(Integer, nullable=False) + is_active = Column(Boolean, default=True) + created_at = Column(DateTime, default=datetime.now(timezone.utc)) - id = Column(Integer, primary_key=True) +class DataTypeModel(Base): + """Model for testing various data types""" + __tablename__ = 'sqlalchemy_data_type_model' - name = Column(String(100)) - email = Column(String, primary_key=True) - age = Column(Integer) - is_active = Column(Boolean) - created_at = Column(DateTime) + id = Column(Integer, primary_key=True) - class DataTypeModel(Base): - """Model for testing various data types""" - __tablename__ = 'sqlalchemy_data_type_model' + # String fields + string_field = Column(String(255)) + text_field = Column(Text) - id = Column(Integer, primary_key=True) + # Numeric fields + integer_field = Column(Integer) + small_integer_field = Column(SmallInteger) + big_integer_field = Column(BigInteger) + numeric_field = Column(Numeric(10, 2)) + float_field = Column(Float) - # String fields - string_field = Column(String(255)) - text_field = Column(Text) + # Boolean field + boolean_field = Column(Boolean, default=False) - # Numeric fields - integer_field = Column(Integer) - small_integer_field = Column(SmallInteger) - big_integer_field = Column(BigInteger) - numeric_field = Column(Numeric) - float_field = Column(Float) + # Date/Time fields + date_field = Column(Date) + time_field = Column(Time) + datetime_field = Column(DateTime) - # Boolean field - boolean_field = Column(Boolean) + # JSON field (MySQL 5.7+) + json_field = Column(JSON) - # Date/Time fields - date_field = Column(Date) - time_field = Column(Time) - datetime_field = Column(DateTime) +class Author(Base): + """Author model for relationship testing""" + __tablename__ = 'sqlalchemy_author' - # JSON field (MySQL 5.7+) - json_field = Column(JSON) + id = Column(Integer, primary_key=True) + name = Column(String(100), nullable=False) + email = Column(String(254), nullable=False) + birth_date = Column(Date) - class Author(Base): - """Author model for relationship testing""" - __tablename__ = 'sqlalchemy_author' + books = relationship('Book', back_populates='author', cascade='all, delete-orphan') - id = Column(Integer, primary_key=True) - name = Column(String(100)) - email = Column(String) - birth_date = Column(Date) +class Book(Base): + """Book model for relationship testing""" + __tablename__ = 'sqlalchemy_book' - class Book(Base): - """Book model for relationship testing""" - __tablename__ = 'sqlalchemy_book' + id = Column(Integer, primary_key=True) + title = Column(String(200), nullable=False) + author_id = Column(Integer, ForeignKey("sqlalchemy_author.id"), nullable=False) + publication_date = Column(Date, nullable=False) + pages = Column(Integer, nullable=False) + price = Column(Numeric(8, 2), nullable=False) + + author = relationship('Author', back_populates='books') + +@enable_on_engines([DatabaseEngine.MYSQL]) # MySQL Specific until PG is implemented +@enable_on_deployments([DatabaseEngineDeployment.AURORA, DatabaseEngineDeployment.RDS_MULTI_AZ_CLUSTER]) +@disable_on_features([TestEnvironmentFeatures.RUN_AUTOSCALING_TESTS_ONLY, + TestEnvironmentFeatures.BLUE_GREEN_DEPLOYMENT, + TestEnvironmentFeatures.PERFORMANCE]) +class TestSqlAlchemy: + @pytest.fixture(scope='class') + def rds_utils(self): + region: str = TestEnvironment.get_current().get_info().get_region() + return RdsTestUtility(region) - id = Column(Integer, primary_key=True) - title = Column(String(200)) - author = Column(String, ForeignKey("Author.id")) - publication_date = Column(Date) - pages = Column(Integer) - price = Column(Numeric) @pytest.fixture(scope="class") def engine(self, conn_utils): @@ -152,7 +159,7 @@ def test_sqlalchemy_basic_model_operations(self, session, test_environment: Test assert test_obj.name == "John Doe" # Read - retrieved_obj = session.query(TestModel).filter(TestModel.id == test_obj.id).first() + retrieved_obj = session.query(TestModel).filter_by(id = test_obj.id).one() assert retrieved_obj.name == "John Doe" assert retrieved_obj.email == "john@example.com" assert retrieved_obj.age == 30 @@ -163,7 +170,7 @@ def test_sqlalchemy_basic_model_operations(self, session, test_environment: Test retrieved_obj.age = 25 session.commit() - updated_obj = session.query(TestModel).filter(TestModel.id == test_obj.id).first() + updated_obj = session.query(TestModel).filter_by(id = test_obj.id).one() assert updated_obj.name == "Jane Doe" assert updated_obj.age == 25 @@ -172,7 +179,7 @@ def test_sqlalchemy_basic_model_operations(self, session, test_environment: Test session.commit() assert session.query(TestModel).filter(TestModel.id == test_obj.id).count() == 0 -''' + ''' def test_django_queryset_operations(self, test_environment: TestEnvironment, django_models): """Test Django QuerySet operations""" TestModel = self.TestModel @@ -939,6 +946,5 @@ def test_django_iterator(self, test_environment: TestEnvironment, django_models) # Clean up TestModel.objects.all().delete() - -''' + ''' From 80fb9d053b8aacf08a3a8d401f99e5df10c123f7 Mon Sep 17 00:00:00 2001 From: Jonathan Louie Date: Tue, 7 Apr 2026 16:41:43 -0700 Subject: [PATCH 10/15] Add remaining basic MySQL SQLAlchemy ORM tests --- .../sqlalchemy/test_sqlalchemy_basic.py | 1081 +++++++---------- 1 file changed, 470 insertions(+), 611 deletions(-) diff --git a/tests/integration/container/sqlalchemy/test_sqlalchemy_basic.py b/tests/integration/container/sqlalchemy/test_sqlalchemy_basic.py index 49ebd56cd..822b92ac8 100644 --- a/tests/integration/container/sqlalchemy/test_sqlalchemy_basic.py +++ b/tests/integration/container/sqlalchemy/test_sqlalchemy_basic.py @@ -21,10 +21,15 @@ from typing import Any import pytest -from sqlalchemy.orm import declarative_base, sessionmaker, relationship +from sqlalchemy.sql import func +from sqlalchemy.orm import ( + declarative_base, sessionmaker, relationship, Session, joinedload, + subqueryload +) from sqlalchemy import ( create_engine, Column, ForeignKey, Integer, BigInteger, SmallInteger, - Float, Numeric, String, Boolean, Date, Time, DateTime, Text, JSON + Float, Numeric, String, Boolean, Date, Time, DateTime, Text, JSON, or_, + and_, text ) from tests.integration.container.utils.rds_test_utility import RdsTestUtility @@ -133,18 +138,18 @@ def session(self, Session): session.rollback() session.close() - ''' - def test_django_backend_configuration(self, test_environment: TestEnvironment, django_models): - """Test Django backend configuration with empty plugins""" + def test_sqlalchemy_backend_configuration(self, test_environment: TestEnvironment, engine): + """Test SQLAlchemy backend configuration with empty plugins""" # Verify that the connection is using the AWS wrapper - assert hasattr(connection, 'connection') + with engine.connect() as connection: + assert connection.connection is not None # Test basic connection functionality - assert self.TestModel.objects.count() == 0 - ''' + with Session(engine) as session: + assert session.query(TestModel).count() == 0 def test_sqlalchemy_basic_model_operations(self, session, test_environment: TestEnvironment): - """Test basic Django ORM operations (CRUD)""" + """Test basic SQLAlchemy ORM operations (CRUD)""" # Create test_obj = TestModel( @@ -179,772 +184,626 @@ def test_sqlalchemy_basic_model_operations(self, session, test_environment: Test session.commit() assert session.query(TestModel).filter(TestModel.id == test_obj.id).count() == 0 - ''' - def test_django_queryset_operations(self, test_environment: TestEnvironment, django_models): - """Test Django QuerySet operations""" - TestModel = self.TestModel - + def test_sqlalchemy_query_operations(self, test_environment: TestEnvironment, session): + """Test SQLAlchemy query operations""" # Ensure clean slate - TestModel.objects.all().delete() - + session.query(TestModel).delete() + session.commit() # Create test data - TestModel.objects.create(name="Alice", email="alice@example.com", age=25, is_active=True) - TestModel.objects.create(name="Bob", email="bob@example.com", age=30, is_active=False) - TestModel.objects.create(name="Charlie", email="charlie@example.com", age=35, is_active=True) - + session.add_all([ + TestModel(name="Alice", email="alice@example.com", age=25, is_active=True), + TestModel(name="Bob", email="bob@example.com", age=30, is_active=False), + TestModel(name="Charlie", email="charlie@example.com", age=35, is_active=True), + ]) + session.commit() # Test filtering - active_users = TestModel.objects.filter(is_active=True) - assert active_users.count() == 2 - + active_users = session.query(TestModel).filter(TestModel.is_active == True).all() + assert len(active_users) == 2 # Test ordering - ordered_users = TestModel.objects.order_by('age') + ordered_users = session.query(TestModel).order_by(TestModel.age).all() ages = [user.age for user in ordered_users] assert ages == [25, 30, 35] - # Test complex queries - young_active_users = TestModel.objects.filter(age__lt=30, is_active=True) - assert young_active_users.count() == 1 - assert young_active_users.first().name == "Alice" - - # Test exclude - non_bob_users = TestModel.objects.exclude(name="Bob") - assert non_bob_users.count() == 2 - + young_active_users = session.query(TestModel).filter( + TestModel.age < 30, TestModel.is_active == True + ).all() + assert len(young_active_users) == 1 + assert young_active_users[0].name == "Alice" + # Test exclude (using NOT) + non_bob_users = session.query(TestModel).filter(TestModel.name != "Bob").all() + assert len(non_bob_users) == 2 # Test exists - assert TestModel.objects.filter(name="Alice").exists() - assert not TestModel.objects.filter(name="David").exists() - + assert session.query(TestModel).filter(TestModel.name == "Alice").first() is not None + assert session.query(TestModel).filter(TestModel.name == "David").first() is None # Clean up - TestModel.objects.all().delete() - - def test_django_data_types(self, test_environment: TestEnvironment, django_models): - """Test Django ORM with various data types""" - DataTypeModel = self.DataTypeModel + session.query(TestModel).delete() + session.commit() + def test_sqlalchemy_data_types(self, test_environment: TestEnvironment, session): + """Test SQLAlchemy with various data types""" # Ensure clean slate - DataTypeModel.objects.all().delete() - + session.query(DataTypeModel).delete() + session.commit() # Create test data with various data types test_datetime = datetime(2023, 12, 25, 14, 30, 0) - test_datetime_aware = timezone.make_aware(test_datetime) - - test_data = DataTypeModel.objects.create( - char_field="Test String", + test_data = DataTypeModel( + string_field="Test String", text_field="This is a longer text field content", integer_field=42, + small_integer_field=5, big_integer_field=9223372036854775807, - decimal_field=Decimal('123.45'), + numeric_field=Decimal('123.45'), float_field=3.14159, boolean_field=True, date_field=date(2023, 12, 25), time_field=time(14, 30, 0), - datetime_field=test_datetime_aware, # Use timezone-aware datetime - json_field={"key": "value", "number": 123, "array": [1, 2, 3]} + datetime_field=test_datetime, + json_field={"key": "value", "number": 123, "array": [1, 2, 3]}, ) - + session.add(test_data) + session.commit() # Retrieve and verify data - retrieved = DataTypeModel.objects.get(id=test_data.id) - - assert retrieved.char_field == "Test String" + retrieved = session.query(DataTypeModel).get(test_data.id) + assert retrieved.string_field == "Test String" assert retrieved.text_field == "This is a longer text field content" assert retrieved.integer_field == 42 + assert retrieved.small_integer_field == 5 assert retrieved.big_integer_field == 9223372036854775807 - assert retrieved.decimal_field == Decimal('123.45') + assert retrieved.numeric_field == Decimal('123.45') assert abs(retrieved.float_field - 3.14159) < 0.001 assert retrieved.boolean_field is True assert retrieved.date_field == date(2023, 12, 25) assert retrieved.time_field == time(14, 30, 0) - # Compare timezone-aware datetimes - assert retrieved.datetime_field == test_datetime_aware + assert retrieved.datetime_field == test_datetime assert retrieved.json_field == {"key": "value", "number": 123, "array": [1, 2, 3]} - # Clean up - DataTypeModel.objects.all().delete() - - def test_django_null_values(self, test_environment: TestEnvironment, django_models): - """Test Django ORM handling of NULL values""" - DataTypeModel = self.DataTypeModel - - # First, ensure we start with a clean slate - DataTypeModel.objects.all().delete() + session.query(DataTypeModel).delete() + session.commit() + def test_sqlalchemy_null_values(self, test_environment: TestEnvironment, session): + """Test SQLAlchemy handling of NULL values""" + # Ensure clean slate + session.query(DataTypeModel).delete() + session.commit() # Create object with NULL values - test_obj = DataTypeModel.objects.create( - char_field=None, + test_obj = DataTypeModel( + string_field=None, integer_field=None, date_field=None, - boolean_field=False # This field has default=False, so it won't be NULL + boolean_field=False, ) - + session.add(test_obj) + session.commit() # Retrieve and verify NULL values - retrieved = DataTypeModel.objects.get(id=test_obj.id) - assert retrieved.char_field is None + retrieved = session.query(DataTypeModel).get(test_obj.id) + assert retrieved.string_field is None assert retrieved.integer_field is None assert retrieved.date_field is None assert retrieved.boolean_field is False - # Test filtering with NULL values - null_char_objects = DataTypeModel.objects.filter(char_field__isnull=True) - assert null_char_objects.count() == 1 - - not_null_char_objects = DataTypeModel.objects.filter(char_field__isnull=False) - assert not_null_char_objects.count() == 0 - + null_char_objects = session.query(DataTypeModel).filter(DataTypeModel.string_field.is_(None)).all() + assert len(null_char_objects) == 1 + not_null_char_objects = session.query(DataTypeModel).filter(DataTypeModel.string_field.isnot(None)).all() + assert len(not_null_char_objects) == 0 # Create an object with non-NULL values to test the opposite - DataTypeModel.objects.create( - char_field="Not NULL", + session.add(DataTypeModel( + string_field="Not NULL", integer_field=42, - date_field=date(2023, 1, 1) - ) - + date_field=date(2023, 1, 1), + )) + session.commit() # Now test filtering again - null_char_objects = DataTypeModel.objects.filter(char_field__isnull=True) - assert null_char_objects.count() == 1 # Still one NULL object - - not_null_char_objects = DataTypeModel.objects.filter(char_field__isnull=False) - assert not_null_char_objects.count() == 1 # Now one non-NULL object - + null_string_objects = session.query(DataTypeModel).filter(DataTypeModel.string_field.is_(None)).all() + # Still one NULL object + assert len(null_string_objects) == 1 + not_null_string_objects = session.query(DataTypeModel).filter(DataTypeModel.string_field.isnot(None)).all() + # Now one non-NULL object + assert len(not_null_string_objects) == 1 # Clean up - DataTypeModel.objects.all().delete() - - def test_django_relationships(self, test_environment: TestEnvironment, django_models): - """Test Django ORM relationships (ForeignKey)""" - Author = self.Author - Book = self.Book + session.query(DataTypeModel).delete() + session.commit() + def test_sqlalchemy_relationships(self, test_environment: TestEnvironment, session): + """Test SQLAlchemy relationships (ForeignKey)""" # Create author - author = Author.objects.create( + author = Author( name="J.K. Rowling", email="jk@example.com", - birth_date=date(1965, 7, 31) + birth_date=date(1965, 7, 31), ) - + session.add(author) + session.commit() # Create books - book1 = Book.objects.create( + book1 = Book( title="Harry Potter and the Philosopher's Stone", - author=author, + author_id=author.id, publication_date=date(1997, 6, 26), pages=223, - price=Decimal('12.99') + price=Decimal('12.99'), ) - - book2 = Book.objects.create( + book2 = Book( title="Harry Potter and the Chamber of Secrets", - author=author, + author_id=author.id, publication_date=date(1998, 7, 2), pages=251, - price=Decimal('13.99') + price=Decimal('13.99'), ) - + session.add_all([book1, book2]) + session.commit() # Test forward relationship assert book1.author.name == "J.K. Rowling" assert book2.author.email == "jk@example.com" - # Test reverse relationship - author_books = author.books.all() - assert author_books.count() == 2 - book_titles = [book.title for book in author_books.order_by('publication_date')] + assert len(author.books) == 2 + book_titles = [book.title for book in sorted(author.books, key=lambda b: b.publication_date)] assert "Harry Potter and the Philosopher's Stone" in book_titles assert "Harry Potter and the Chamber of Secrets" in book_titles - # Test related queries - books_by_author = Book.objects.filter(author__name="J.K. Rowling") - assert books_by_author.count() == 2 - - # Test select_related for optimization - book_with_author = Book.objects.select_related('author').get(id=book1.id) + books_by_author = session.query(Book).join(Author).filter(Author.name == "J.K. Rowling").all() + assert len(books_by_author) == 2 + # Test joinedload for optimization + book_with_author = session.query(Book).options( + joinedload(Book.author) + ).filter(Book.id == book1.id).one() assert book_with_author.author.name == "J.K. Rowling" - - # Clean up - Book.objects.all().delete() - Author.objects.all().delete() - - def test_django_aggregations(self, test_environment: TestEnvironment, django_models): - """Test Django ORM aggregations""" - Author = self.Author - Book = self.Book - - # Create test data - author = Author.objects.create(name="Test Author", email="test@example.com") - - Book.objects.create(title="Book 1", author=author, publication_date=date(2020, 1, 1), pages=100, price=Decimal('10.00')) - Book.objects.create(title="Book 2", author=author, publication_date=date(2021, 1, 1), pages=200, price=Decimal('20.00')) - Book.objects.create(title="Book 3", author=author, publication_date=date(2022, 1, 1), pages=300, price=Decimal('30.00')) - - # Test aggregations - stats = Book.objects.aggregate( - total_books=Count('id'), - total_pages=Sum('pages'), - avg_price=Avg('price'), - max_pages=Max('pages'), - min_price=Min('price') - ) - - assert stats['total_books'] == 3 - assert stats['total_pages'] == 600 - assert abs(float(stats['avg_price']) - 20.0) < 0.01 - assert stats['max_pages'] == 300 - assert stats['min_price'] == Decimal('10.00') - # Clean up - Book.objects.all().delete() - Author.objects.all().delete() - - def test_django_transactions(self, test_environment: TestEnvironment, django_models): - """Test Django transaction handling""" - TestModel = self.TestModel - - # Ensure clean slate - TestModel.objects.all().delete() + session.query(Book).delete() + session.query(Author).delete() + session.commit() - initial_count = TestModel.objects.count() + def test_sqlalchemy_aggregations(self, test_environment: TestEnvironment, session): + """Test SQLAlchemy aggregations""" + author = Author(name="Test Author", email="test@example.com") + session.add(author) + session.flush() + books = [ + Book(title="Book 1", author_id=author.id, publication_date=date(2020, 1, 1), pages=100, price=Decimal('10.00')), + Book(title="Book 2", author_id=author.id, publication_date=date(2021, 1, 1), pages=200, price=Decimal('20.00')), + Book(title="Book 3", author_id=author.id, publication_date=date(2022, 1, 1), pages=300, price=Decimal('30.00')), + ] + session.add_all(books) + session.flush() + stats = session.query( + func.count(Book.id).label('total_books'), + func.sum(Book.pages).label('total_pages'), + func.avg(Book.price).label('avg_price'), + func.max(Book.pages).label('max_pages'), + func.min(Book.price).label('min_price'), + ).one() + assert stats.total_books == 3 + assert stats.total_pages == 600 + assert abs(float(stats.avg_price) - 20.0) < 0.01 + assert stats.max_pages == 300 + assert stats.min_price == Decimal('10.00') + session.rollback() + def test_sqlalchemy_transactions(self, test_environment: TestEnvironment, session): + """Test SQLAlchemy transaction handling""" + session.query(TestModel).delete() + session.commit() + initial_count = session.query(TestModel).count() # Test successful transaction - with transaction.atomic(): - TestModel.objects.create(name="User 1", email="user1@example.com", age=25) - TestModel.objects.create(name="User 2", email="user2@example.com", age=30) - - assert TestModel.objects.count() == initial_count + 2 - + session.add(TestModel(name="User 1", email="user1@example.com", age=25)) + session.add(TestModel(name="User 2", email="user2@example.com", age=30)) + session.commit() + assert session.query(TestModel).count() == initial_count + 2 # Test rollback transaction try: - with transaction.atomic(): - TestModel.objects.create(name="User 3", email="user3@example.com", age=35) - TestModel.objects.create(name="User 4", email="user4@example.com", age=40) - # Force an error to trigger rollback - raise Exception("Force rollback") + session.add(TestModel(name="User 3", email="user3@example.com", age=35)) + session.add(TestModel(name="User 4", email="user4@example.com", age=40)) + session.flush() + raise Exception("Force rollback") except Exception: - pass # Expected exception - - # Should still have only 2 additional records (rollback occurred) - assert TestModel.objects.count() == initial_count + 2 - - # Clean up - TestModel.objects.all().delete() - - def test_django_bulk_operations(self, test_environment: TestEnvironment, django_models): - """Test Django bulk operations""" - TestModel = self.TestModel - - # Ensure clean slate - TestModel.objects.all().delete() + session.rollback() + assert session.query(TestModel).count() == initial_count + 2 + session.query(TestModel).delete() + session.commit() - # Test bulk_create - test_objects = [ + def test_sqlalchemy_bulk_operations(self, test_environment: TestEnvironment, session): + """Test SQLAlchemy bulk operations""" + session.query(TestModel).delete() + session.commit() + # Test bulk insert + session.bulk_save_objects([ TestModel(name=f"User {i}", email=f"user{i}@example.com", age=20 + i) for i in range(10) - ] - - created_objects = TestModel.objects.bulk_create(test_objects) - assert len(created_objects) == 10 - assert TestModel.objects.count() == 10 - - # Test bulk_update - need to get the objects first and modify them - objects_to_update = list(TestModel.objects.all()) - for obj in objects_to_update: - obj.age += 5 - - TestModel.objects.bulk_update(objects_to_update, ['age']) - - # Verify updates - get fresh objects from database - ages = list(TestModel.objects.values_list('age', flat=True).order_by('name')) - expected_ages = [25 + i for i in range(10)] # 20+i+5 for i in range(10) + ]) + session.commit() + assert session.query(TestModel).count() == 10 + # Test bulk update + session.query(TestModel).update({TestModel.age: TestModel.age + 5}) + session.commit() + ages = [r.age for r in session.query(TestModel).order_by(TestModel.name).all()] + expected_ages = [25 + i for i in range(10)] assert ages == expected_ages + session.query(TestModel).delete() + session.commit() - # Clean up - TestModel.objects.all().delete() - - def test_django_complex_queries(self, test_environment: TestEnvironment, django_models): - """Test complex Django queries with Q objects and F expressions""" - TestModel = self.TestModel - - # Ensure clean slate - TestModel.objects.all().delete() - - # Create test data - TestModel.objects.create(name="Alice", email="alice@example.com", age=25, is_active=True) - TestModel.objects.create(name="Bob", email="bob@example.com", age=30, is_active=False) - TestModel.objects.create(name="Charlie", email="charlie@example.com", age=35, is_active=True) - TestModel.objects.create(name="David", email="david@example.com", age=28, is_active=True) - - # Test Q objects for complex conditions - complex_query = TestModel.objects.filter( - Q(age__gte=30) | Q(name__startswith='A') - ) - assert complex_query.count() == 3 # Bob (30), Charlie (35), Alice (starts with A) - - # Test F expressions - TestModel.objects.filter(age__lt=30).update(age=F('age') + 5) - - # Verify F expression update - alice = TestModel.objects.get(name="Alice") - david = TestModel.objects.get(name="David") - assert alice.age == 30 # 25 + 5 - assert david.age == 33 # 28 + 5 - - # Clean up, might get a failover error from this connection - TestModel.objects.all().delete() - - def test_django_raw_sql_queries(self, test_environment: TestEnvironment, django_models): - """Test Django raw SQL query execution""" - TestModel = self.TestModel - - # Ensure clean slate - TestModel.objects.all().delete() - - # Create test data - TestModel.objects.create(name="Alice", email="alice@example.com", age=25, is_active=True) - TestModel.objects.create(name="Bob", email="bob@example.com", age=30, is_active=False) - TestModel.objects.create(name="Charlie", email="charlie@example.com", age=35, is_active=True) - - # Test raw() method - raw_results = TestModel.objects.raw( - f'SELECT * FROM {TestModel._meta.db_table} WHERE age >= %s ORDER BY age', - [30] + def test_sqlalchemy_complex_queries(self, test_environment: TestEnvironment, session): + """Test complex SQLAlchemy queries with or_/and_ and column expressions""" + session.query(TestModel).delete() + session.commit() + session.add_all([ + TestModel(name="Alice", email="alice@example.com", age=25, is_active=True), + TestModel(name="Bob", email="bob@example.com", age=30, is_active=False), + TestModel(name="Charlie", email="charlie@example.com", age=35, is_active=True), + TestModel(name="David", email="david@example.com", age=28, is_active=True), + ]) + session.commit() + # Test or_ for complex conditions + results = session.query(TestModel).filter( + or_(TestModel.age >= 30, TestModel.name.like('A%')) + ).all() + assert len(results) == 3 + # Test column expression update (equivalent to Django's F expressions) + session.query(TestModel).filter(TestModel.age < 30).update( + {TestModel.age: TestModel.age + 5}, synchronize_session='fetch' ) - raw_list = list(raw_results) - assert len(raw_list) == 2 - assert raw_list[0].name == "Bob" - assert raw_list[1].name == "Charlie" - - # Test connection.cursor() for custom SQL - with connection.cursor() as cursor: - cursor.execute( - f'SELECT name, age FROM {TestModel._meta.db_table} WHERE is_active = %s ORDER BY age', - [True] - ) - rows = cursor.fetchall() - assert len(rows) == 2 - assert rows[0][0] == "Alice" # name - assert rows[0][1] == 25 # age - assert rows[1][0] == "Charlie" - assert rows[1][1] == 35 - - # Test raw SQL with connection for aggregate - with connection.cursor() as cursor: - cursor.execute(f'SELECT COUNT(*), AVG(age) FROM {TestModel._meta.db_table}') - count, avg_age = cursor.fetchone() - assert count == 3 - assert abs(float(avg_age) - 30.0) < 0.01 - - # Clean up - TestModel.objects.all().delete() - - def test_django_get_or_create(self, test_environment: TestEnvironment, django_models): - """Test Django get_or_create pattern""" - TestModel = self.TestModel + session.commit() + alice = session.query(TestModel).filter_by(name="Alice").one() + david = session.query(TestModel).filter_by(name="David").one() + assert alice.age == 30 + assert david.age == 33 + session.query(TestModel).delete() + session.commit() - # Ensure clean slate - TestModel.objects.all().delete() + def test_sqlalchemy_raw_sql_queries(self, test_environment: TestEnvironment, session): + """Test SQLAlchemy raw SQL query execution""" + session.query(TestModel).delete() + session.commit() + session.add_all([ + TestModel(name="Alice", email="alice@example.com", age=25, is_active=True), + TestModel(name="Bob", email="bob@example.com", age=30, is_active=False), + TestModel(name="Charlie", email="charlie@example.com", age=35, is_active=True), + ]) + session.commit() + table = TestModel.__tablename__ + # Test raw SQL with text() + rows = session.execute( + text(f'SELECT * FROM {table} WHERE age >= :age ORDER BY age'), + {'age': 30} + ).fetchall() + assert len(rows) == 2 + # Test raw SQL for specific columns + rows = session.execute( + text(f'SELECT name, age FROM {table} WHERE is_active = :active ORDER BY age'), + {'active': True} + ).fetchall() + assert len(rows) == 2 + assert rows[0][0] == "Alice" + assert rows[0][1] == 25 + assert rows[1][0] == "Charlie" + assert rows[1][1] == 35 + # Test raw SQL aggregate + result = session.execute( + text(f'SELECT COUNT(*), AVG(age) FROM {table}') + ).fetchone() + assert result[0] == 3 + assert abs(float(result[1]) - 30.0) < 0.01 + session.query(TestModel).delete() + session.commit() + def test_sqlalchemy_get_or_create(self, test_environment: TestEnvironment, session): + """Test SQLAlchemy get-or-create pattern""" + session.query(TestModel).delete() + session.commit() # Test create case - obj1, created1 = TestModel.objects.get_or_create( - email="test@example.com", - defaults={'name': 'Test User', 'age': 25, 'is_active': True} - ) + obj1 = session.query(TestModel).filter_by(email="test@example.com").first() + created1 = obj1 is None + if created1: + obj1 = TestModel(name="Test User", email="test@example.com", age=25, is_active=True) + session.add(obj1) + session.commit() assert created1 is True assert obj1.name == "Test User" assert obj1.age == 25 - - # Test get case (object already exists) - obj2, created2 = TestModel.objects.get_or_create( - email="test@example.com", - defaults={'name': 'Different Name', 'age': 30, 'is_active': False} - ) + # Test get case + obj2 = session.query(TestModel).filter_by(email="test@example.com").first() + created2 = obj2 is None + if created2: + obj2 = TestModel(name="Different Name", email="test@example.com", age=30, is_active=False) + session.add(obj2) + session.commit() assert created2 is False assert obj2.id == obj1.id - assert obj2.name == "Test User" # Should keep original values + assert obj2.name == "Test User" assert obj2.age == 25 + assert session.query(TestModel).filter_by(email="test@example.com").count() == 1 + session.query(TestModel).delete() + session.commit() - # Verify only one object exists - assert TestModel.objects.filter(email="test@example.com").count() == 1 - - # Clean up - TestModel.objects.all().delete() - - def test_django_update_or_create(self, test_environment: TestEnvironment, django_models): - """Test Django update_or_create pattern""" - TestModel = self.TestModel - - # Ensure clean slate - TestModel.objects.all().delete() - + def test_sqlalchemy_update_or_create(self, test_environment: TestEnvironment, session): + """Test SQLAlchemy update-or-create pattern""" + session.query(TestModel).delete() + session.commit() # Test create case - obj1, created1 = TestModel.objects.update_or_create( - email="update@example.com", - defaults={'name': 'Initial Name', 'age': 25, 'is_active': True} - ) + obj1 = session.query(TestModel).filter_by(email="update@example.com").first() + created1 = obj1 is None + if created1: + obj1 = TestModel(name="Initial Name", email="update@example.com", age=25, is_active=True) + session.add(obj1) + session.commit() assert created1 is True assert obj1.name == "Initial Name" assert obj1.age == 25 - - # Test update case (object already exists) - obj2, created2 = TestModel.objects.update_or_create( - email="update@example.com", - defaults={'name': 'Updated Name', 'age': 30, 'is_active': False} - ) + # Test update case + obj2 = session.query(TestModel).filter_by(email="update@example.com").first() + created2 = obj2 is None + if created2: + obj2 = TestModel(name="Updated Name", email="update@example.com", age=30, is_active=False) + session.add(obj2) + else: + obj2.name = "Updated Name" + obj2.age = 30 + obj2.is_active = False + session.commit() assert created2 is False assert obj2.id == obj1.id - assert obj2.name == "Updated Name" # Should be updated + assert obj2.name == "Updated Name" assert obj2.age == 30 assert obj2.is_active is False - - # Verify only one object exists - assert TestModel.objects.filter(email="update@example.com").count() == 1 - - # Verify the update persisted - retrieved = TestModel.objects.get(email="update@example.com") + assert session.query(TestModel).filter_by(email="update@example.com").count() == 1 + retrieved = session.query(TestModel).filter_by(email="update@example.com").one() assert retrieved.name == "Updated Name" assert retrieved.age == 30 + session.query(TestModel).delete() + session.commit() - # Clean up - TestModel.objects.all().delete() - - def test_django_prefetch_related(self, test_environment: TestEnvironment, django_models): - """Test Django prefetch_related for optimizing queries""" - Author = self.Author - Book = self.Book - - # Create test data - author1 = Author.objects.create(name="Author 1", email="author1@example.com") - author2 = Author.objects.create(name="Author 2", email="author2@example.com") - - Book.objects.create(title="Book 1A", author=author1, publication_date=date(2020, 1, 1), pages=100, price=Decimal('10.00')) - Book.objects.create(title="Book 1B", author=author1, publication_date=date(2021, 1, 1), pages=200, price=Decimal('20.00')) - Book.objects.create(title="Book 2A", author=author2, publication_date=date(2022, 1, 1), pages=300, price=Decimal('30.00')) - - # Test prefetch_related - authors = Author.objects.prefetch_related('books').all() - - # Access related books (should not trigger additional queries due to prefetch) + def test_sqlalchemy_eager_loading(self, test_environment: TestEnvironment, session): + """Test SQLAlchemy eager loading for optimizing queries""" + author1 = Author(name="Author 1", email="author1@example.com") + author2 = Author(name="Author 2", email="author2@example.com") + session.add_all([author1, author2]) + session.flush() + session.add_all([ + Book(title="Book 1A", author_id=author1.id, publication_date=date(2020, 1, 1), pages=100, price=Decimal('10.00')), + Book(title="Book 1B", author_id=author1.id, publication_date=date(2021, 1, 1), pages=200, price=Decimal('20.00')), + Book(title="Book 2A", author_id=author2.id, publication_date=date(2022, 1, 1), pages=300, price=Decimal('30.00')), + ]) + session.commit() + # Test subqueryload (equivalent to Django's prefetch_related) + authors = session.query(Author).options(subqueryload(Author.books)).all() for author in authors: - books = list(author.books.all()) if author.name == "Author 1": - assert len(books) == 2 - book_titles = [book.title for book in books] - assert "Book 1A" in book_titles - assert "Book 1B" in book_titles + assert len(author.books) == 2 + titles = [b.title for b in author.books] + assert "Book 1A" in titles + assert "Book 1B" in titles elif author.name == "Author 2": - assert len(books) == 1 - assert books[0].title == "Book 2A" - - # Clean up - Book.objects.all().delete() - Author.objects.all().delete() - - def test_django_database_functions(self, test_environment: TestEnvironment, django_models): - """Test Django database functions""" - TestModel = self.TestModel - - # Ensure clean slate - TestModel.objects.all().delete() + assert len(author.books) == 1 + assert author.books[0].title == "Book 2A" + session.rollback() - # Create test data - TestModel.objects.create(name="alice", email="alice@example.com", age=25) - TestModel.objects.create(name="BOB", email="bob@example.com", age=30) - TestModel.objects.create(name="Charlie", email="charlie@example.com", age=35) - - # Test Upper function - upper_names = TestModel.objects.annotate(upper_name=Upper('name')).values_list('upper_name', flat=True) - upper_list = list(upper_names) - assert "ALICE" in upper_list - assert "BOB" in upper_list - assert "CHARLIE" in upper_list - - # Test Lower function - lower_names = TestModel.objects.annotate(lower_name=Lower('name')).values_list('lower_name', flat=True) - lower_list = list(lower_names) - assert "alice" in lower_list - assert "bob" in lower_list - assert "charlie" in lower_list - - # Test Length function - name_lengths = TestModel.objects.annotate(name_length=Length('name')).filter(name_length__gte=5) - assert name_lengths.count() == 2 # "alice" (5) and "Charlie" (7) - - # Test Concat function - full_info = TestModel.objects.annotate( - full_info=Concat('name', Value(' - '), 'email', output_field=CharField()) + def test_sqlalchemy_database_functions(self, test_environment: TestEnvironment, session): + """Test SQLAlchemy database functions""" + session.query(TestModel).delete() + session.commit() + session.add_all([ + TestModel(name="alice", email="alice@example.com", age=25), + TestModel(name="BOB", email="bob@example.com", age=30), + TestModel(name="Charlie", email="charlie@example.com", age=35), + ]) + session.commit() + # Test upper + upper_names = [r[0] for r in session.query(func.upper(TestModel.name)).all()] + assert "ALICE" in upper_names + assert "BOB" in upper_names + assert "CHARLIE" in upper_names + # Test lower + lower_names = [r[0] for r in session.query(func.lower(TestModel.name)).all()] + assert "alice" in lower_names + assert "bob" in lower_names + assert "charlie" in lower_names + # Test length + results = session.query(TestModel).filter(func.length(TestModel.name) >= 5).all() + assert len(results) == 2 # "alice" (5) and "Charlie" (7) + # Test concat + result = session.query( + func.concat(TestModel.name, ' - ', TestModel.email) ).first() - assert ' - ' in full_info.full_info - assert '@example.com' in full_info.full_info - - # Clean up - TestModel.objects.all().delete() - - def test_django_annotations(self, test_environment: TestEnvironment, django_models): - """Test Django annotations with expressions""" - TestModel = self.TestModel - Book = self.Book - Author = self.Author - - # Create test data for TestModel - TestModel.objects.create(name="Alice", email="alice@example.com", age=25, is_active=True) - TestModel.objects.create(name="Bob", email="bob@example.com", age=30, is_active=False) - TestModel.objects.create(name="Charlie", email="charlie@example.com", age=35, is_active=True) - - # Test annotate with F expression for calculations - test_with_age_plus_ten = TestModel.objects.annotate( - age_plus_ten=F('age') + 10 - ).order_by('age') - - # Verify calculation - first_obj = test_with_age_plus_ten.first() - assert first_obj.age_plus_ten == first_obj.age + 10 - assert first_obj.age_plus_ten == 35 # 25 + 10 - - # Create books for F expression testing - author = Author.objects.create(name="Test Author", email="test@example.com") - Book.objects.create(title="Book 1", author=author, publication_date=date(2020, 1, 1), pages=100, price=Decimal('10.00')) - Book.objects.create(title="Book 2", author=author, publication_date=date(2021, 1, 1), pages=200, price=Decimal('20.00')) - Book.objects.create(title="Book 3", author=author, publication_date=date(2022, 1, 1), pages=300, price=Decimal('30.00')) - - # Test annotate with F expression for price per page - books_with_price_per_page = Book.objects.annotate( - price_per_page=F('price') / F('pages') - ).order_by('price_per_page') - - # Verify calculation - first_book = books_with_price_per_page.first() - expected_price_per_page = float(first_book.price) / first_book.pages - assert abs(float(first_book.price_per_page) - expected_price_per_page) < 0.001 - - # Test filtering on annotated field - use a lower threshold to avoid precision issues - cheap_books = Book.objects.annotate( - price_per_page=F('price') / F('pages') - ).filter(price_per_page__lte=0.15) - assert cheap_books.count() == 3 # All books have price_per_page = 0.10 - - # Clean up - TestModel.objects.all().delete() - Book.objects.all().delete() - Author.objects.all().delete() - - def test_django_values_and_values_list(self, test_environment: TestEnvironment, django_models): - """Test Django values() and values_list() methods""" - TestModel = self.TestModel + assert ' - ' in result[0] + assert '@example.com' in result[0] + session.query(TestModel).delete() + session.commit() + def test_sqlalchemy_values_and_values_list(self, test_environment: TestEnvironment, session): + """Test SQLAlchemy equivalents of Django's values() and values_list() functions""" # Ensure clean slate - TestModel.objects.all().delete() - + session.query(TestModel).delete() + session.commit() # Create test data - TestModel.objects.create(name="Alice", email="alice@example.com", age=25, is_active=True) - TestModel.objects.create(name="Bob", email="bob@example.com", age=30, is_active=False) - TestModel.objects.create(name="Charlie", email="charlie@example.com", age=35, is_active=True) - - # Test values() - returns list of dictionaries - values_result = TestModel.objects.values('name', 'age').order_by('age') - values_list = list(values_result) - assert len(values_list) == 3 - assert values_list[0] == {'name': 'Alice', 'age': 25} - assert values_list[1] == {'name': 'Bob', 'age': 30} - assert values_list[2] == {'name': 'Charlie', 'age': 35} - - # Test values_list() - returns list of tuples - values_list_result = TestModel.objects.values_list('name', 'age').order_by('age') - tuples_list = list(values_list_result) - assert len(tuples_list) == 3 - assert tuples_list[0] == ('Alice', 25) - assert tuples_list[1] == ('Bob', 30) - assert tuples_list[2] == ('Charlie', 35) - - # Test values_list() with flat=True - returns flat list - names = TestModel.objects.values_list('name', flat=True).order_by('name') - names_list = list(names) - assert names_list == ['Alice', 'Bob', 'Charlie'] - - # Test values() with filtering - active_users = TestModel.objects.filter(is_active=True).values('name', 'email') - active_list = list(active_users) - assert len(active_list) == 2 - active_names = [user['name'] for user in active_list] + session.add_all([ + TestModel(name="Alice", email="alice@example.com", age=25, is_active=True), + TestModel(name="Bob", email="bob@example.com", age=30, is_active=False), + TestModel(name="Charlie", email="charlie@example.com", age=35, is_active=True), + ]) + session.commit() + # Convert values to dicts (equivalent to Django's values()) + values_result = session.query(TestModel.name, TestModel.age).order_by(TestModel.age).all() + assert len(values_result) == 3 + assert values_result[0] == ('Alice', 25) + assert values_result[1] == ('Bob', 30) + assert values_result[2] == ('Charlie', 35) + values_dicts = [{'name': r.name, 'age': r.age} for r in values_result] + assert values_dicts[0] == {'name': 'Alice', 'age': 25} + assert values_dicts[1] == {'name': 'Bob', 'age': 30} + assert values_dicts[2] == {'name': 'Charlie', 'age': 35} + # Test flat list (equivalent to Django's values_list with flat=True) + names = [r[0] for r in session.query(TestModel.name).order_by(TestModel.name).all()] + assert names == ['Alice', 'Bob', 'Charlie'] + # Test with filtering + active_users = session.query(TestModel.name, TestModel.email).filter( + TestModel.is_active == True + ).all() + assert len(active_users) == 2 + active_names = [r.name for r in active_users] assert 'Alice' in active_names assert 'Charlie' in active_names assert 'Bob' not in active_names - # Clean up - TestModel.objects.all().delete() - - def test_django_distinct_queries(self, test_environment: TestEnvironment, django_models): - """Test Django distinct() functionality""" - TestModel = self.TestModel + session.query(TestModel).delete() + session.commit() + def test_sqlalchemy_distinct_queries(self, test_environment: TestEnvironment, session): + """Test SQLAlchemy distinct() functionality""" # Ensure clean slate - TestModel.objects.all().delete() - + session.query(TestModel).delete() + session.commit() # Create test data with duplicate ages - TestModel.objects.create(name="Alice", email="alice@example.com", age=25, is_active=True) - TestModel.objects.create(name="Bob", email="bob@example.com", age=30, is_active=False) - TestModel.objects.create(name="Charlie", email="charlie@example.com", age=25, is_active=True) - TestModel.objects.create(name="David", email="david@example.com", age=30, is_active=True) - + session.add_all([ + TestModel(name="Alice", email="alice@example.com", age=25, is_active=True), + TestModel(name="Bob", email="bob@example.com", age=30, is_active=False), + TestModel(name="Charlie", email="charlie@example.com", age=25, is_active=True), + TestModel(name="David", email="david@example.com", age=30, is_active=True), + ]) + session.commit() # Test distinct ages - distinct_ages = TestModel.objects.values_list('age', flat=True).distinct().order_by('age') - ages_list = list(distinct_ages) + ages_list = [r[0] for r in session.query(TestModel.age).distinct().order_by(TestModel.age).all()] assert ages_list == [25, 30] - # Test distinct with multiple fields - distinct_age_status = TestModel.objects.values('age', 'is_active').distinct().order_by('age', 'is_active') - distinct_list = list(distinct_age_status) + distinct_list = session.query(TestModel.age, TestModel.is_active).distinct().order_by( + TestModel.age, TestModel.is_active + ).all() assert len(distinct_list) == 3 # (25, True), (30, False), (30, True) - # Test count with distinct - total_count = TestModel.objects.count() - distinct_age_count = TestModel.objects.values('age').distinct().count() + total_count = session.query(TestModel).count() + distinct_age_count = session.query(TestModel.age).distinct().count() assert total_count == 4 assert distinct_age_count == 2 - # Clean up - TestModel.objects.all().delete() - - def test_django_only_and_defer(self, test_environment: TestEnvironment, django_models): - """Test Django only() and defer() for query optimization""" - TestModel = self.TestModel + session.query(TestModel).delete() + session.commit() + def test_sqlalchemy_load_only_and_defer(self, test_environment: TestEnvironment, session): + """Test SQLAlchemy load_only() and defer() for query optimization""" + from sqlalchemy.orm import defer, load_only # Ensure clean slate - TestModel.objects.all().delete() - + session.query(TestModel).delete() + session.commit() # Create test data - obj = TestModel.objects.create( - name="Test User", - email="test@example.com", - age=30, - is_active=True - ) - - # Test only() - load only specific fields - obj_only = TestModel.objects.only('name', 'email').get(id=obj.id) + obj = TestModel(name="Test User", email="test@example.com", age=30, is_active=True) + session.add(obj) + session.commit() + obj_id = obj.id + session.expire_all() + # Test load_only() - load only specific fields + obj_only = session.query(TestModel).options( + load_only(TestModel.name, TestModel.email) + ).get(obj_id) assert obj_only.name == "Test User" assert obj_only.email == "test@example.com" - # Accessing deferred fields will trigger additional query, but should still work assert obj_only.age == 30 - + session.expire_all() # Test defer() - exclude specific fields from loading - obj_defer = TestModel.objects.defer('age', 'is_active').get(id=obj.id) + obj_defer = session.query(TestModel).options( + defer(TestModel.age), defer(TestModel.is_active) + ).get(obj_id) assert obj_defer.name == "Test User" assert obj_defer.email == "test@example.com" - # Accessing deferred fields will trigger additional query, but should still work assert obj_defer.age == 30 - # Clean up - TestModel.objects.all().delete() - - def test_django_in_bulk(self, test_environment: TestEnvironment, django_models): - """Test Django in_bulk() for batch retrieval""" - TestModel = self.TestModel + session.query(TestModel).delete() + session.commit() + def test_sqlalchemy_batch_retrieval(self, test_environment: TestEnvironment, session): + """Test SQLAlchemy batch retrieval (equivalent to Django's in_bulk)""" # Ensure clean slate - TestModel.objects.all().delete() - + session.query(TestModel).delete() + session.commit() # Create test data - obj1 = TestModel.objects.create(name="User 1", email="user1@example.com", age=25) - obj2 = TestModel.objects.create(name="User 2", email="user2@example.com", age=30) - obj3 = TestModel.objects.create(name="User 3", email="user3@example.com", age=35) - - # Test in_bulk with IDs (default behavior) - bulk_result = TestModel.objects.in_bulk([obj1.id, obj2.id, obj3.id]) + obj1 = TestModel(name="User 1", email="user1@example.com", age=25) + obj2 = TestModel(name="User 2", email="user2@example.com", age=30) + obj3 = TestModel(name="User 3", email="user3@example.com", age=35) + session.add_all([obj1, obj2, obj3]) + session.commit() + # Test bulk retrieval by IDs + ids = [obj1.id, obj2.id, obj3.id] + bulk_result = {o.id: o for o in session.query(TestModel).filter(TestModel.id.in_(ids)).all()} assert len(bulk_result) == 3 assert bulk_result[obj1.id].name == "User 1" assert bulk_result[obj2.id].name == "User 2" assert bulk_result[obj3.id].name == "User 3" - - # Test in_bulk with all IDs (no list provided) - bulk_all = TestModel.objects.in_bulk() + # Test bulk retrieval of all + bulk_all = {o.id: o for o in session.query(TestModel).all()} assert len(bulk_all) == 3 assert obj1.id in bulk_all assert obj2.id in bulk_all assert obj3.id in bulk_all - - # Test in_bulk with email field (unique field) - bulk_by_email = TestModel.objects.in_bulk( - ["user1@example.com", "user3@example.com"], - field_name='email' - ) + # Test bulk retrieval by email field + emails = ["user1@example.com", "user3@example.com"] + bulk_by_email = { + o.email: o for o in session.query(TestModel).filter(TestModel.email.in_(emails)).all() + } assert len(bulk_by_email) == 2 assert bulk_by_email["user1@example.com"].name == "User 1" assert bulk_by_email["user3@example.com"].name == "User 3" - # Clean up - TestModel.objects.all().delete() - - def test_django_conditional_expressions(self, test_environment: TestEnvironment, django_models): - """Test Django Case/When conditional expressions""" - from django.db.models import Case, IntegerField, Value, When - - TestModel = self.TestModel + session.query(TestModel).delete() + session.commit() + def test_sqlalchemy_conditional_expressions(self, test_environment: TestEnvironment, session): + """Test SQLAlchemy case() conditional expressions""" + from sqlalchemy import String, case # Ensure clean slate - TestModel.objects.all().delete() - + session.query(TestModel).delete() + session.commit() # Create test data - TestModel.objects.create(name="Alice", email="alice@example.com", age=25, is_active=True) - TestModel.objects.create(name="Bob", email="bob@example.com", age=30, is_active=False) - TestModel.objects.create(name="Charlie", email="charlie@example.com", age=35, is_active=True) - - # Test Case/When for conditional logic - results = TestModel.objects.annotate( - age_category=Case( - When(age__lt=30, then=Value('young')), - When(age__gte=30, age__lt=40, then=Value('middle')), - default=Value('senior'), - output_field=CharField() - ) - ).order_by('age') - - results_list = list(results) - assert results_list[0].age_category == 'young' # Alice, 25 - assert results_list[1].age_category == 'middle' # Bob, 30 - assert results_list[2].age_category == 'middle' # Charlie, 35 - - # Test Case/When with integer output - priority_results = TestModel.objects.annotate( - priority=Case( - When(is_active=True, age__lt=30, then=Value(1)), - When(is_active=True, then=Value(2)), - When(is_active=False, then=Value(3)), - default=Value(4), - output_field=IntegerField() - ) - ).order_by('priority', 'name') - - priority_list = list(priority_results) - assert priority_list[0].name == 'Alice' # priority 1: active and young - assert priority_list[1].name == 'Charlie' # priority 2: active but not young - assert priority_list[2].name == 'Bob' # priority 3: not active - + session.add_all([ + TestModel(name="Alice", email="alice@example.com", age=25, is_active=True), + TestModel(name="Bob", email="bob@example.com", age=30, is_active=False), + TestModel(name="Charlie", email="charlie@example.com", age=35, is_active=True), + ]) + session.commit() + # Test case() for conditional logic + age_category = case( + (TestModel.age < 30, 'young'), + (TestModel.age.between(30, 39), 'middle'), + else_='senior' + ).label('age_category') + results = session.query(TestModel, age_category).order_by(TestModel.age).all() + assert results[0].age_category == 'young' # Alice, 25 + assert results[1].age_category == 'middle' # Bob, 30 + assert results[2].age_category == 'middle' # Charlie, 35 + # Test case() with integer output + from sqlalchemy import Integer + priority = case( + (and_(TestModel.is_active == True, TestModel.age < 30), 1), + (TestModel.is_active == True, 2), + (TestModel.is_active == False, 3), + else_=4 + ).label('priority') + results = session.query(TestModel, priority).order_by('priority', TestModel.name).all() + assert results[0].TestModel.name == 'Alice' # priority 1 + assert results[1].TestModel.name == 'Charlie' # priority 2 + assert results[2].TestModel.name == 'Bob' # priority 3 # Clean up - TestModel.objects.all().delete() - - def test_django_iterator(self, test_environment: TestEnvironment, django_models): - """Test Django iterator() for memory-efficient queries""" - TestModel = self.TestModel + session.query(TestModel).delete() + session.commit() + def test_sqlalchemy_yield_per(self, test_environment: TestEnvironment, session): + """Test SQLAlchemy yield_per() for memory-efficient queries""" # Ensure clean slate - TestModel.objects.all().delete() - + session.query(TestModel).delete() + session.commit() # Create test data - for i in range(20): - TestModel.objects.create( - name=f"User {i}", - email=f"user{i}@example.com", - age=20 + i - ) - - # Test iterator() - processes results without caching + session.add_all([ + TestModel(name=f"User {i}", email=f"user{i}@example.com", age=20 + i) + for i in range(20) + ]) + session.commit() + # Test yield_per() - processes results without caching all at once count = 0 - for obj in TestModel.objects.iterator(): + for obj in session.query(TestModel).yield_per(100): assert obj.name.startswith("User") count += 1 assert count == 20 - - # Test iterator with chunk_size + # Test yield_per with smaller chunk size count = 0 - for obj in TestModel.objects.iterator(chunk_size=5): + for obj in session.query(TestModel).yield_per(5): assert obj.email.endswith("@example.com") count += 1 assert count == 20 - # Clean up - TestModel.objects.all().delete() - ''' + session.query(TestModel).delete() + session.commit() From 904790797c83b9cc66d29f9b748e720f35001a9a Mon Sep 17 00:00:00 2001 From: Jonathan Louie Date: Tue, 7 Apr 2026 16:58:33 -0700 Subject: [PATCH 11/15] Remove temporary changes to get tests to run locally --- .../container/utils/test_database_info.py | 2 +- .../utils/test_environment_request.py | 2 +- tests/unit/test_sqlalchemy_orm.py | 61 ------------------- 3 files changed, 2 insertions(+), 63 deletions(-) delete mode 100644 tests/unit/test_sqlalchemy_orm.py diff --git a/tests/integration/container/utils/test_database_info.py b/tests/integration/container/utils/test_database_info.py index edc49faf0..a1b3a0944 100644 --- a/tests/integration/container/utils/test_database_info.py +++ b/tests/integration/container/utils/test_database_info.py @@ -42,7 +42,7 @@ def __init__(self, database_info: Dict[str, Any]) -> None: self._username = typing.cast('str', database_info.get("username")) self._password = typing.cast('str', database_info.get("password")) - self._default_db_name = "mysqldb" + self._default_db_name = typing.cast('str', database_info.get("defaultDbName")) self._cluster_endpoint = typing.cast('str', database_info.get("clusterEndpoint")) self._cluster_endpoint_port = typing.cast('int', database_info.get("clusterEndpointPort")) self._cluster_read_only_endpoint = typing.cast('str', database_info.get("clusterReadOnlyEndpoint")) diff --git a/tests/integration/container/utils/test_environment_request.py b/tests/integration/container/utils/test_environment_request.py index def700293..db1bbeef2 100644 --- a/tests/integration/container/utils/test_environment_request.py +++ b/tests/integration/container/utils/test_environment_request.py @@ -63,7 +63,7 @@ def get_features(self) -> Set[TestEnvironmentFeatures]: return self._features def get_num_of_instances(self) -> int: - return 3 + return self._num_of_instances def get_display_name(self) -> str: return "Test environment [{0}, {1}, {2}, {3}, {4}, {5}]".format( diff --git a/tests/unit/test_sqlalchemy_orm.py b/tests/unit/test_sqlalchemy_orm.py deleted file mode 100644 index 70acb6f58..000000000 --- a/tests/unit/test_sqlalchemy_orm.py +++ /dev/null @@ -1,61 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). -# You may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from sqlalchemy import create_engine, Column, Integer, String -from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.orm import sessionmaker - -class TestSqlAlchemyORM: - def test_basic_workflow(self): - # Step 1: Create engine (connection to database) - engine = create_engine('postgresql+aws_wrapper://pguser:pgpassword@mydb.cluster-XYZ.us-west-1.rds.amazonaws.com:5432/somedb') - # Step 2: Define base class for declarative models - Base = declarative_base() - - # Step 3: Define model class (separate from database operations) - class User(Base): - __tablename__ = 'users' - - id = Column(Integer, primary_key=True) - name = Column(String(50)) - email = Column(String(100)) - - # Step 4: Create tables - Base.metadata.create_all(engine) - - # Step 5: Create session factory - Session = sessionmaker(bind=engine) - - # Step 6: Use session for database operations - with Session() as session: - # INSERT - Create new object and add to session - new_user = User(name='John Doe', email='john@example.com') - session.add(new_user) - session.commit() # Explicit commit required - - # SELECT - Query using session - users = session.query(User).filter(User.name == 'John Doe').all() - for user in users: - print(f"{user.name}: {user.email}") - - - # UPDATE - Modify object and commit - user = session.query(User).filter(User.name == "John Doe").first() - user.email = 'newemail@example.com' - session.commit() - - # DELETE - Remove object from session - user_to_delete = session.query(User).filter(User.name == "John Doe").first() - session.delete(user_to_delete) - session.commit() From 0e089385782853c0e2b4c9e787b3ae03bc5e8544 Mon Sep 17 00:00:00 2001 From: Jonathan Louie Date: Tue, 7 Apr 2026 17:01:28 -0700 Subject: [PATCH 12/15] Add license headers and remove unused import --- .../sqlalchemy/mysql_orm_dialect.py | 15 ++++++++++++++- .../sqlalchemy/pg_orm_dialect.py | 14 ++++++++++++++ 2 files changed, 28 insertions(+), 1 deletion(-) diff --git a/aws_advanced_python_wrapper/sqlalchemy/mysql_orm_dialect.py b/aws_advanced_python_wrapper/sqlalchemy/mysql_orm_dialect.py index 6d7ff34db..fde407e5c 100644 --- a/aws_advanced_python_wrapper/sqlalchemy/mysql_orm_dialect.py +++ b/aws_advanced_python_wrapper/sqlalchemy/mysql_orm_dialect.py @@ -1,5 +1,18 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + # aws_advanced_python_wrapper/sqlalchemy/sqlalchemy_mysqlconnector_dialect.py -from psycopg import Connection from sqlalchemy.dialects.mysql.mysqlconnector import MySQLDialect_mysqlconnector import re diff --git a/aws_advanced_python_wrapper/sqlalchemy/pg_orm_dialect.py b/aws_advanced_python_wrapper/sqlalchemy/pg_orm_dialect.py index c2780b861..d792ce501 100644 --- a/aws_advanced_python_wrapper/sqlalchemy/pg_orm_dialect.py +++ b/aws_advanced_python_wrapper/sqlalchemy/pg_orm_dialect.py @@ -1,3 +1,17 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + # aws_advanced_python_wrapper/sqlalchemy/sqlalchemy_psycopg_dialect.py from psycopg import Connection from sqlalchemy.dialects.postgresql.psycopg import PGDialect_psycopg From d4eb5bdc2599a126b68f501849353b65fc2c32bd Mon Sep 17 00:00:00 2001 From: Jonathan Louie Date: Thu, 9 Apr 2026 11:14:13 -0700 Subject: [PATCH 13/15] Try fixing mypy errors in tests --- .../integration/container/sqlalchemy/test_sqlalchemy_basic.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/integration/container/sqlalchemy/test_sqlalchemy_basic.py b/tests/integration/container/sqlalchemy/test_sqlalchemy_basic.py index 822b92ac8..9f099febd 100644 --- a/tests/integration/container/sqlalchemy/test_sqlalchemy_basic.py +++ b/tests/integration/container/sqlalchemy/test_sqlalchemy_basic.py @@ -688,7 +688,7 @@ def test_sqlalchemy_load_only_and_defer(self, test_environment: TestEnvironment, session.expire_all() # Test load_only() - load only specific fields obj_only = session.query(TestModel).options( - load_only(TestModel.name, TestModel.email) + load_only('TestModel.name', 'TestModel.email') ).get(obj_id) assert obj_only.name == "Test User" assert obj_only.email == "test@example.com" @@ -696,7 +696,7 @@ def test_sqlalchemy_load_only_and_defer(self, test_environment: TestEnvironment, session.expire_all() # Test defer() - exclude specific fields from loading obj_defer = session.query(TestModel).options( - defer(TestModel.age), defer(TestModel.is_active) + defer('TestModel.age'), defer('TestModel.is_active') ).get(obj_id) assert obj_defer.name == "Test User" assert obj_defer.email == "test@example.com" From b344ef5fcacd2cb9730bc0d7962fdff47ab925ca Mon Sep 17 00:00:00 2001 From: Jonathan Louie Date: Mon, 13 Apr 2026 13:22:35 -0700 Subject: [PATCH 14/15] Try to fix mypy Base class error --- .../container/sqlalchemy/test_sqlalchemy_basic.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/tests/integration/container/sqlalchemy/test_sqlalchemy_basic.py b/tests/integration/container/sqlalchemy/test_sqlalchemy_basic.py index 9f099febd..46b2e8639 100644 --- a/tests/integration/container/sqlalchemy/test_sqlalchemy_basic.py +++ b/tests/integration/container/sqlalchemy/test_sqlalchemy_basic.py @@ -40,7 +40,10 @@ from ..utils.test_environment import TestEnvironment from ..utils.test_environment_features import TestEnvironmentFeatures -Base = declarative_base() +class Base: + __allow_unmapped__ = True + +Base = declarative_base(cls=Base) class TestModel(Base): """Basic test model for SQLAlchemy ORM functionality""" @@ -688,7 +691,7 @@ def test_sqlalchemy_load_only_and_defer(self, test_environment: TestEnvironment, session.expire_all() # Test load_only() - load only specific fields obj_only = session.query(TestModel).options( - load_only('TestModel.name', 'TestModel.email') + load_only(TestModel.name, TestModel.email) ).get(obj_id) assert obj_only.name == "Test User" assert obj_only.email == "test@example.com" @@ -696,7 +699,7 @@ def test_sqlalchemy_load_only_and_defer(self, test_environment: TestEnvironment, session.expire_all() # Test defer() - exclude specific fields from loading obj_defer = session.query(TestModel).options( - defer('TestModel.age'), defer('TestModel.is_active') + defer(TestModel.age), defer(TestModel.is_active) ).get(obj_id) assert obj_defer.name == "Test User" assert obj_defer.email == "test@example.com" From 919284aad56e3e4a61b305f96d45266f274008f8 Mon Sep 17 00:00:00 2001 From: Karen <64801825+karenc-bq@users.noreply.github.com> Date: Mon, 13 Apr 2026 15:02:35 -0700 Subject: [PATCH 15/15] fix: test failures due to using legacy sqlalchemy api (#1226) --- aws_advanced_python_wrapper/__init__.py | 18 ++-- .../sqlalchemy/mysql_orm_dialect.py | 7 +- .../sqlalchemy/pg_orm_dialect.py | 3 +- .../sqlalchemy/test_sqlalchemy_basic.py | 89 ++++++++++--------- 4 files changed, 56 insertions(+), 61 deletions(-) diff --git a/aws_advanced_python_wrapper/__init__.py b/aws_advanced_python_wrapper/__init__.py index 7d4dbf38a..459a67fab 100644 --- a/aws_advanced_python_wrapper/__init__.py +++ b/aws_advanced_python_wrapper/__init__.py @@ -14,21 +14,16 @@ from logging import DEBUG, getLogger +from aws_advanced_python_wrapper.pep249 import (DatabaseError, DataError, + Error, IntegrityError, + InterfaceError, InternalError, + NotSupportedError, + OperationalError, + ProgrammingError) from .cleanup import release_resources from .driver_info import DriverInfo from .utils.utils import LogUtils from .wrapper import AwsWrapperConnection -from aws_advanced_python_wrapper.pep249 import ( - Error, - InterfaceError, - DatabaseError, - DataError, - OperationalError, - IntegrityError, - InternalError, - ProgrammingError, - NotSupportedError -) # PEP249 compliance connect = AwsWrapperConnection.connect @@ -58,5 +53,6 @@ __version__ = DriverInfo.DRIVER_VERSION + def set_logger(name='aws_advanced_python_wrapper', level=DEBUG, format_string=None): LogUtils.setup_logger(getLogger(name), level, format_string) diff --git a/aws_advanced_python_wrapper/sqlalchemy/mysql_orm_dialect.py b/aws_advanced_python_wrapper/sqlalchemy/mysql_orm_dialect.py index fde407e5c..dc8da6394 100644 --- a/aws_advanced_python_wrapper/sqlalchemy/mysql_orm_dialect.py +++ b/aws_advanced_python_wrapper/sqlalchemy/mysql_orm_dialect.py @@ -13,10 +13,8 @@ # limitations under the License. # aws_advanced_python_wrapper/sqlalchemy/sqlalchemy_mysqlconnector_dialect.py -from sqlalchemy.dialects.mysql.mysqlconnector import MySQLDialect_mysqlconnector -import re - -from aws_advanced_python_wrapper import AwsWrapperConnection +from sqlalchemy.dialects.mysql.mysqlconnector import \ + MySQLDialect_mysqlconnector class SqlAlchemyOrmMysqlDialect(MySQLDialect_mysqlconnector): @@ -29,4 +27,3 @@ class SqlAlchemyOrmMysqlDialect(MySQLDialect_mysqlconnector): name = 'mysql' driver = 'aws_wrapper_mysqlconnector' - diff --git a/aws_advanced_python_wrapper/sqlalchemy/pg_orm_dialect.py b/aws_advanced_python_wrapper/sqlalchemy/pg_orm_dialect.py index d792ce501..c066520bc 100644 --- a/aws_advanced_python_wrapper/sqlalchemy/pg_orm_dialect.py +++ b/aws_advanced_python_wrapper/sqlalchemy/pg_orm_dialect.py @@ -12,10 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +import re + # aws_advanced_python_wrapper/sqlalchemy/sqlalchemy_psycopg_dialect.py from psycopg import Connection from sqlalchemy.dialects.postgresql.psycopg import PGDialect_psycopg -import re from aws_advanced_python_wrapper import AwsWrapperConnection diff --git a/tests/integration/container/sqlalchemy/test_sqlalchemy_basic.py b/tests/integration/container/sqlalchemy/test_sqlalchemy_basic.py index 46b2e8639..82427f95b 100644 --- a/tests/integration/container/sqlalchemy/test_sqlalchemy_basic.py +++ b/tests/integration/container/sqlalchemy/test_sqlalchemy_basic.py @@ -18,19 +18,16 @@ from datetime import date, datetime, time, timezone from decimal import Decimal -from typing import Any +from typing import Any, List, Optional import pytest +from sqlalchemy import (JSON, BigInteger, Boolean, Date, DateTime, Float, + ForeignKey, Numeric, SmallInteger, String, Text, Time, + and_, create_engine, or_, text) +from sqlalchemy.orm import (DeclarativeBase, Mapped, Session, joinedload, + mapped_column, relationship, sessionmaker, + subqueryload) from sqlalchemy.sql import func -from sqlalchemy.orm import ( - declarative_base, sessionmaker, relationship, Session, joinedload, - subqueryload -) -from sqlalchemy import ( - create_engine, Column, ForeignKey, Integer, BigInteger, SmallInteger, - Float, Numeric, String, Boolean, Date, Time, DateTime, Text, JSON, or_, - and_, text -) from tests.integration.container.utils.rds_test_utility import RdsTestUtility from ..utils.conditions import (disable_on_features, enable_on_deployments, @@ -40,74 +37,76 @@ from ..utils.test_environment import TestEnvironment from ..utils.test_environment_features import TestEnvironmentFeatures -class Base: - __allow_unmapped__ = True -Base = declarative_base(cls=Base) +class Base(DeclarativeBase): + pass + class TestModel(Base): """Basic test model for SQLAlchemy ORM functionality""" __tablename__ = 'sqlalchemy_test_model' - id = Column(Integer, primary_key=True) + id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] = mapped_column(String(100)) + email: Mapped[str] = mapped_column(String(254), unique=True) + age: Mapped[int] = mapped_column() + is_active: Mapped[Optional[bool]] = mapped_column(Boolean, default=True) + created_at: Mapped[Optional[datetime]] = mapped_column(DateTime, default=datetime.now(timezone.utc)) - name = Column(String(100), nullable=False) - email = Column(String(254), nullable=False, unique=True) - age = Column(Integer, nullable=False) - is_active = Column(Boolean, default=True) - created_at = Column(DateTime, default=datetime.now(timezone.utc)) class DataTypeModel(Base): """Model for testing various data types""" __tablename__ = 'sqlalchemy_data_type_model' - id = Column(Integer, primary_key=True) + id: Mapped[int] = mapped_column(primary_key=True) # String fields - string_field = Column(String(255)) - text_field = Column(Text) + string_field: Mapped[Optional[str]] = mapped_column(String(255)) + text_field: Mapped[Optional[str]] = mapped_column(Text) # Numeric fields - integer_field = Column(Integer) - small_integer_field = Column(SmallInteger) - big_integer_field = Column(BigInteger) - numeric_field = Column(Numeric(10, 2)) - float_field = Column(Float) + integer_field: Mapped[Optional[int]] = mapped_column() + small_integer_field: Mapped[Optional[int]] = mapped_column(SmallInteger) + big_integer_field: Mapped[Optional[int]] = mapped_column(BigInteger) + numeric_field: Mapped[Optional[Decimal]] = mapped_column(Numeric(10, 2)) + float_field: Mapped[Optional[float]] = mapped_column(Float) # Boolean field - boolean_field = Column(Boolean, default=False) + boolean_field: Mapped[Optional[bool]] = mapped_column(Boolean, default=False) # Date/Time fields - date_field = Column(Date) - time_field = Column(Time) - datetime_field = Column(DateTime) + date_field: Mapped[Optional[date]] = mapped_column(Date) + time_field: Mapped[Optional[time]] = mapped_column(Time) + datetime_field: Mapped[Optional[datetime]] = mapped_column(DateTime) # JSON field (MySQL 5.7+) - json_field = Column(JSON) + json_field: Mapped[Optional[Any]] = mapped_column(JSON) + class Author(Base): """Author model for relationship testing""" __tablename__ = 'sqlalchemy_author' - id = Column(Integer, primary_key=True) - name = Column(String(100), nullable=False) - email = Column(String(254), nullable=False) - birth_date = Column(Date) + id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] = mapped_column(String(100)) + email: Mapped[str] = mapped_column(String(254)) + birth_date: Mapped[Optional[date]] = mapped_column(Date) + + books: Mapped[List[Book]] = relationship(back_populates='author', cascade='all, delete-orphan') - books = relationship('Book', back_populates='author', cascade='all, delete-orphan') class Book(Base): """Book model for relationship testing""" __tablename__ = 'sqlalchemy_book' - id = Column(Integer, primary_key=True) - title = Column(String(200), nullable=False) - author_id = Column(Integer, ForeignKey("sqlalchemy_author.id"), nullable=False) - publication_date = Column(Date, nullable=False) - pages = Column(Integer, nullable=False) - price = Column(Numeric(8, 2), nullable=False) + id: Mapped[int] = mapped_column(primary_key=True) + title: Mapped[str] = mapped_column(String(200)) + author_id: Mapped[int] = mapped_column(ForeignKey("sqlalchemy_author.id")) + publication_date: Mapped[date] = mapped_column(Date) + pages: Mapped[int] = mapped_column() + price: Mapped[Decimal] = mapped_column(Numeric(8, 2)) - author = relationship('Author', back_populates='books') + author: Mapped[Author] = relationship(back_populates='books') @enable_on_engines([DatabaseEngine.MYSQL]) # MySQL Specific until PG is implemented @enable_on_deployments([DatabaseEngineDeployment.AURORA, DatabaseEngineDeployment.RDS_MULTI_AZ_CLUSTER]) @@ -680,6 +679,7 @@ def test_sqlalchemy_distinct_queries(self, test_environment: TestEnvironment, se def test_sqlalchemy_load_only_and_defer(self, test_environment: TestEnvironment, session): """Test SQLAlchemy load_only() and defer() for query optimization""" from sqlalchemy.orm import defer, load_only + # Ensure clean slate session.query(TestModel).delete() session.commit() @@ -747,6 +747,7 @@ def test_sqlalchemy_batch_retrieval(self, test_environment: TestEnvironment, ses def test_sqlalchemy_conditional_expressions(self, test_environment: TestEnvironment, session): """Test SQLAlchemy case() conditional expressions""" from sqlalchemy import String, case + # Ensure clean slate session.query(TestModel).delete() session.commit()