Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 7 additions & 11 deletions aws_advanced_python_wrapper/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,21 +14,16 @@

from logging import DEBUG, getLogger

from aws_advanced_python_wrapper.pep249 import (DatabaseError, DataError,
Error, IntegrityError,
InterfaceError, InternalError,
NotSupportedError,
OperationalError,
ProgrammingError)
from .cleanup import release_resources
from .driver_info import DriverInfo
from .utils.utils import LogUtils
from .wrapper import AwsWrapperConnection
from aws_advanced_python_wrapper.pep249 import (
Error,
InterfaceError,
DatabaseError,
DataError,
OperationalError,
IntegrityError,
InternalError,
ProgrammingError,
NotSupportedError
)

# PEP249 compliance
connect = AwsWrapperConnection.connect
Expand Down Expand Up @@ -58,5 +53,6 @@

__version__ = DriverInfo.DRIVER_VERSION


def set_logger(name='aws_advanced_python_wrapper', level=DEBUG, format_string=None):
LogUtils.setup_logger(getLogger(name), level, format_string)
7 changes: 2 additions & 5 deletions aws_advanced_python_wrapper/sqlalchemy/mysql_orm_dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,8 @@
# limitations under the License.

# aws_advanced_python_wrapper/sqlalchemy/sqlalchemy_mysqlconnector_dialect.py
from sqlalchemy.dialects.mysql.mysqlconnector import MySQLDialect_mysqlconnector
import re

from aws_advanced_python_wrapper import AwsWrapperConnection
from sqlalchemy.dialects.mysql.mysqlconnector import \
MySQLDialect_mysqlconnector


class SqlAlchemyOrmMysqlDialect(MySQLDialect_mysqlconnector):
Expand All @@ -29,4 +27,3 @@ class SqlAlchemyOrmMysqlDialect(MySQLDialect_mysqlconnector):

name = 'mysql'
driver = 'aws_wrapper_mysqlconnector'

3 changes: 2 additions & 1 deletion aws_advanced_python_wrapper/sqlalchemy/pg_orm_dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import re

# aws_advanced_python_wrapper/sqlalchemy/sqlalchemy_psycopg_dialect.py
from psycopg import Connection
from sqlalchemy.dialects.postgresql.psycopg import PGDialect_psycopg
import re

from aws_advanced_python_wrapper import AwsWrapperConnection

Expand Down
89 changes: 45 additions & 44 deletions tests/integration/container/sqlalchemy/test_sqlalchemy_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,16 @@

from datetime import date, datetime, time, timezone
from decimal import Decimal
from typing import Any
from typing import Any, List, Optional

import pytest
from sqlalchemy import (JSON, BigInteger, Boolean, Date, DateTime, Float,
ForeignKey, Numeric, SmallInteger, String, Text, Time,
and_, create_engine, or_, text)
from sqlalchemy.orm import (DeclarativeBase, Mapped, Session, joinedload,
mapped_column, relationship, sessionmaker,
subqueryload)
from sqlalchemy.sql import func
from sqlalchemy.orm import (
declarative_base, sessionmaker, relationship, Session, joinedload,
subqueryload
)
from sqlalchemy import (
create_engine, Column, ForeignKey, Integer, BigInteger, SmallInteger,
Float, Numeric, String, Boolean, Date, Time, DateTime, Text, JSON, or_,
and_, text
)

from tests.integration.container.utils.rds_test_utility import RdsTestUtility
from ..utils.conditions import (disable_on_features, enable_on_deployments,
Expand All @@ -40,74 +37,76 @@
from ..utils.test_environment import TestEnvironment
from ..utils.test_environment_features import TestEnvironmentFeatures

class Base:
__allow_unmapped__ = True

Base = declarative_base(cls=Base)
class Base(DeclarativeBase):
pass


class TestModel(Base):
"""Basic test model for SQLAlchemy ORM functionality"""
__tablename__ = 'sqlalchemy_test_model'

id = Column(Integer, primary_key=True)
id: Mapped[int] = mapped_column(primary_key=True)
name: Mapped[str] = mapped_column(String(100))
email: Mapped[str] = mapped_column(String(254), unique=True)
age: Mapped[int] = mapped_column()
is_active: Mapped[Optional[bool]] = mapped_column(Boolean, default=True)
created_at: Mapped[Optional[datetime]] = mapped_column(DateTime, default=datetime.now(timezone.utc))

name = Column(String(100), nullable=False)
email = Column(String(254), nullable=False, unique=True)
age = Column(Integer, nullable=False)
is_active = Column(Boolean, default=True)
created_at = Column(DateTime, default=datetime.now(timezone.utc))

class DataTypeModel(Base):
"""Model for testing various data types"""
__tablename__ = 'sqlalchemy_data_type_model'

id = Column(Integer, primary_key=True)
id: Mapped[int] = mapped_column(primary_key=True)

# String fields
string_field = Column(String(255))
text_field = Column(Text)
string_field: Mapped[Optional[str]] = mapped_column(String(255))
text_field: Mapped[Optional[str]] = mapped_column(Text)

# Numeric fields
integer_field = Column(Integer)
small_integer_field = Column(SmallInteger)
big_integer_field = Column(BigInteger)
numeric_field = Column(Numeric(10, 2))
float_field = Column(Float)
integer_field: Mapped[Optional[int]] = mapped_column()
small_integer_field: Mapped[Optional[int]] = mapped_column(SmallInteger)
big_integer_field: Mapped[Optional[int]] = mapped_column(BigInteger)
numeric_field: Mapped[Optional[Decimal]] = mapped_column(Numeric(10, 2))
float_field: Mapped[Optional[float]] = mapped_column(Float)

# Boolean field
boolean_field = Column(Boolean, default=False)
boolean_field: Mapped[Optional[bool]] = mapped_column(Boolean, default=False)

# Date/Time fields
date_field = Column(Date)
time_field = Column(Time)
datetime_field = Column(DateTime)
date_field: Mapped[Optional[date]] = mapped_column(Date)
time_field: Mapped[Optional[time]] = mapped_column(Time)
datetime_field: Mapped[Optional[datetime]] = mapped_column(DateTime)

# JSON field (MySQL 5.7+)
json_field = Column(JSON)
json_field: Mapped[Optional[Any]] = mapped_column(JSON)


class Author(Base):
"""Author model for relationship testing"""
__tablename__ = 'sqlalchemy_author'

id = Column(Integer, primary_key=True)
name = Column(String(100), nullable=False)
email = Column(String(254), nullable=False)
birth_date = Column(Date)
id: Mapped[int] = mapped_column(primary_key=True)
name: Mapped[str] = mapped_column(String(100))
email: Mapped[str] = mapped_column(String(254))
birth_date: Mapped[Optional[date]] = mapped_column(Date)

books: Mapped[List[Book]] = relationship(back_populates='author', cascade='all, delete-orphan')

books = relationship('Book', back_populates='author', cascade='all, delete-orphan')

class Book(Base):
"""Book model for relationship testing"""
__tablename__ = 'sqlalchemy_book'

id = Column(Integer, primary_key=True)
title = Column(String(200), nullable=False)
author_id = Column(Integer, ForeignKey("sqlalchemy_author.id"), nullable=False)
publication_date = Column(Date, nullable=False)
pages = Column(Integer, nullable=False)
price = Column(Numeric(8, 2), nullable=False)
id: Mapped[int] = mapped_column(primary_key=True)
title: Mapped[str] = mapped_column(String(200))
author_id: Mapped[int] = mapped_column(ForeignKey("sqlalchemy_author.id"))
publication_date: Mapped[date] = mapped_column(Date)
pages: Mapped[int] = mapped_column()
price: Mapped[Decimal] = mapped_column(Numeric(8, 2))

author = relationship('Author', back_populates='books')
author: Mapped[Author] = relationship(back_populates='books')

@enable_on_engines([DatabaseEngine.MYSQL]) # MySQL Specific until PG is implemented
@enable_on_deployments([DatabaseEngineDeployment.AURORA, DatabaseEngineDeployment.RDS_MULTI_AZ_CLUSTER])
Expand Down Expand Up @@ -680,6 +679,7 @@ def test_sqlalchemy_distinct_queries(self, test_environment: TestEnvironment, se
def test_sqlalchemy_load_only_and_defer(self, test_environment: TestEnvironment, session):
"""Test SQLAlchemy load_only() and defer() for query optimization"""
from sqlalchemy.orm import defer, load_only

# Ensure clean slate
session.query(TestModel).delete()
session.commit()
Expand Down Expand Up @@ -747,6 +747,7 @@ def test_sqlalchemy_batch_retrieval(self, test_environment: TestEnvironment, ses
def test_sqlalchemy_conditional_expressions(self, test_environment: TestEnvironment, session):
"""Test SQLAlchemy case() conditional expressions"""
from sqlalchemy import String, case

# Ensure clean slate
session.query(TestModel).delete()
session.commit()
Expand Down
Loading