diff --git a/aws_advanced_python_wrapper/__init__.py b/aws_advanced_python_wrapper/__init__.py index 7d4dbf38..459a67fa 100644 --- a/aws_advanced_python_wrapper/__init__.py +++ b/aws_advanced_python_wrapper/__init__.py @@ -14,21 +14,16 @@ from logging import DEBUG, getLogger +from aws_advanced_python_wrapper.pep249 import (DatabaseError, DataError, + Error, IntegrityError, + InterfaceError, InternalError, + NotSupportedError, + OperationalError, + ProgrammingError) from .cleanup import release_resources from .driver_info import DriverInfo from .utils.utils import LogUtils from .wrapper import AwsWrapperConnection -from aws_advanced_python_wrapper.pep249 import ( - Error, - InterfaceError, - DatabaseError, - DataError, - OperationalError, - IntegrityError, - InternalError, - ProgrammingError, - NotSupportedError -) # PEP249 compliance connect = AwsWrapperConnection.connect @@ -58,5 +53,6 @@ __version__ = DriverInfo.DRIVER_VERSION + def set_logger(name='aws_advanced_python_wrapper', level=DEBUG, format_string=None): LogUtils.setup_logger(getLogger(name), level, format_string) diff --git a/aws_advanced_python_wrapper/sqlalchemy/mysql_orm_dialect.py b/aws_advanced_python_wrapper/sqlalchemy/mysql_orm_dialect.py index fde407e5..dc8da639 100644 --- a/aws_advanced_python_wrapper/sqlalchemy/mysql_orm_dialect.py +++ b/aws_advanced_python_wrapper/sqlalchemy/mysql_orm_dialect.py @@ -13,10 +13,8 @@ # limitations under the License. # aws_advanced_python_wrapper/sqlalchemy/sqlalchemy_mysqlconnector_dialect.py -from sqlalchemy.dialects.mysql.mysqlconnector import MySQLDialect_mysqlconnector -import re - -from aws_advanced_python_wrapper import AwsWrapperConnection +from sqlalchemy.dialects.mysql.mysqlconnector import \ + MySQLDialect_mysqlconnector class SqlAlchemyOrmMysqlDialect(MySQLDialect_mysqlconnector): @@ -29,4 +27,3 @@ class SqlAlchemyOrmMysqlDialect(MySQLDialect_mysqlconnector): name = 'mysql' driver = 'aws_wrapper_mysqlconnector' - diff --git a/aws_advanced_python_wrapper/sqlalchemy/pg_orm_dialect.py b/aws_advanced_python_wrapper/sqlalchemy/pg_orm_dialect.py index d792ce50..c066520b 100644 --- a/aws_advanced_python_wrapper/sqlalchemy/pg_orm_dialect.py +++ b/aws_advanced_python_wrapper/sqlalchemy/pg_orm_dialect.py @@ -12,10 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +import re + # aws_advanced_python_wrapper/sqlalchemy/sqlalchemy_psycopg_dialect.py from psycopg import Connection from sqlalchemy.dialects.postgresql.psycopg import PGDialect_psycopg -import re from aws_advanced_python_wrapper import AwsWrapperConnection diff --git a/tests/integration/container/sqlalchemy/test_sqlalchemy_basic.py b/tests/integration/container/sqlalchemy/test_sqlalchemy_basic.py index 46b2e863..82427f95 100644 --- a/tests/integration/container/sqlalchemy/test_sqlalchemy_basic.py +++ b/tests/integration/container/sqlalchemy/test_sqlalchemy_basic.py @@ -18,19 +18,16 @@ from datetime import date, datetime, time, timezone from decimal import Decimal -from typing import Any +from typing import Any, List, Optional import pytest +from sqlalchemy import (JSON, BigInteger, Boolean, Date, DateTime, Float, + ForeignKey, Numeric, SmallInteger, String, Text, Time, + and_, create_engine, or_, text) +from sqlalchemy.orm import (DeclarativeBase, Mapped, Session, joinedload, + mapped_column, relationship, sessionmaker, + subqueryload) from sqlalchemy.sql import func -from sqlalchemy.orm import ( - declarative_base, sessionmaker, relationship, Session, joinedload, - subqueryload -) -from sqlalchemy import ( - create_engine, Column, ForeignKey, Integer, BigInteger, SmallInteger, - Float, Numeric, String, Boolean, Date, Time, DateTime, Text, JSON, or_, - and_, text -) from tests.integration.container.utils.rds_test_utility import RdsTestUtility from ..utils.conditions import (disable_on_features, enable_on_deployments, @@ -40,74 +37,76 @@ from ..utils.test_environment import TestEnvironment from ..utils.test_environment_features import TestEnvironmentFeatures -class Base: - __allow_unmapped__ = True -Base = declarative_base(cls=Base) +class Base(DeclarativeBase): + pass + class TestModel(Base): """Basic test model for SQLAlchemy ORM functionality""" __tablename__ = 'sqlalchemy_test_model' - id = Column(Integer, primary_key=True) + id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] = mapped_column(String(100)) + email: Mapped[str] = mapped_column(String(254), unique=True) + age: Mapped[int] = mapped_column() + is_active: Mapped[Optional[bool]] = mapped_column(Boolean, default=True) + created_at: Mapped[Optional[datetime]] = mapped_column(DateTime, default=datetime.now(timezone.utc)) - name = Column(String(100), nullable=False) - email = Column(String(254), nullable=False, unique=True) - age = Column(Integer, nullable=False) - is_active = Column(Boolean, default=True) - created_at = Column(DateTime, default=datetime.now(timezone.utc)) class DataTypeModel(Base): """Model for testing various data types""" __tablename__ = 'sqlalchemy_data_type_model' - id = Column(Integer, primary_key=True) + id: Mapped[int] = mapped_column(primary_key=True) # String fields - string_field = Column(String(255)) - text_field = Column(Text) + string_field: Mapped[Optional[str]] = mapped_column(String(255)) + text_field: Mapped[Optional[str]] = mapped_column(Text) # Numeric fields - integer_field = Column(Integer) - small_integer_field = Column(SmallInteger) - big_integer_field = Column(BigInteger) - numeric_field = Column(Numeric(10, 2)) - float_field = Column(Float) + integer_field: Mapped[Optional[int]] = mapped_column() + small_integer_field: Mapped[Optional[int]] = mapped_column(SmallInteger) + big_integer_field: Mapped[Optional[int]] = mapped_column(BigInteger) + numeric_field: Mapped[Optional[Decimal]] = mapped_column(Numeric(10, 2)) + float_field: Mapped[Optional[float]] = mapped_column(Float) # Boolean field - boolean_field = Column(Boolean, default=False) + boolean_field: Mapped[Optional[bool]] = mapped_column(Boolean, default=False) # Date/Time fields - date_field = Column(Date) - time_field = Column(Time) - datetime_field = Column(DateTime) + date_field: Mapped[Optional[date]] = mapped_column(Date) + time_field: Mapped[Optional[time]] = mapped_column(Time) + datetime_field: Mapped[Optional[datetime]] = mapped_column(DateTime) # JSON field (MySQL 5.7+) - json_field = Column(JSON) + json_field: Mapped[Optional[Any]] = mapped_column(JSON) + class Author(Base): """Author model for relationship testing""" __tablename__ = 'sqlalchemy_author' - id = Column(Integer, primary_key=True) - name = Column(String(100), nullable=False) - email = Column(String(254), nullable=False) - birth_date = Column(Date) + id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] = mapped_column(String(100)) + email: Mapped[str] = mapped_column(String(254)) + birth_date: Mapped[Optional[date]] = mapped_column(Date) + + books: Mapped[List[Book]] = relationship(back_populates='author', cascade='all, delete-orphan') - books = relationship('Book', back_populates='author', cascade='all, delete-orphan') class Book(Base): """Book model for relationship testing""" __tablename__ = 'sqlalchemy_book' - id = Column(Integer, primary_key=True) - title = Column(String(200), nullable=False) - author_id = Column(Integer, ForeignKey("sqlalchemy_author.id"), nullable=False) - publication_date = Column(Date, nullable=False) - pages = Column(Integer, nullable=False) - price = Column(Numeric(8, 2), nullable=False) + id: Mapped[int] = mapped_column(primary_key=True) + title: Mapped[str] = mapped_column(String(200)) + author_id: Mapped[int] = mapped_column(ForeignKey("sqlalchemy_author.id")) + publication_date: Mapped[date] = mapped_column(Date) + pages: Mapped[int] = mapped_column() + price: Mapped[Decimal] = mapped_column(Numeric(8, 2)) - author = relationship('Author', back_populates='books') + author: Mapped[Author] = relationship(back_populates='books') @enable_on_engines([DatabaseEngine.MYSQL]) # MySQL Specific until PG is implemented @enable_on_deployments([DatabaseEngineDeployment.AURORA, DatabaseEngineDeployment.RDS_MULTI_AZ_CLUSTER]) @@ -680,6 +679,7 @@ def test_sqlalchemy_distinct_queries(self, test_environment: TestEnvironment, se def test_sqlalchemy_load_only_and_defer(self, test_environment: TestEnvironment, session): """Test SQLAlchemy load_only() and defer() for query optimization""" from sqlalchemy.orm import defer, load_only + # Ensure clean slate session.query(TestModel).delete() session.commit() @@ -747,6 +747,7 @@ def test_sqlalchemy_batch_retrieval(self, test_environment: TestEnvironment, ses def test_sqlalchemy_conditional_expressions(self, test_environment: TestEnvironment, session): """Test SQLAlchemy case() conditional expressions""" from sqlalchemy import String, case + # Ensure clean slate session.query(TestModel).delete() session.commit()