From 4936204ba12dca80bf37282899edea0a0274ac30 Mon Sep 17 00:00:00 2001 From: "pushpendra.garg@oracle.com" Date: Mon, 9 Feb 2026 09:54:17 +0000 Subject: [PATCH] Selet AI Tests for Vector Index. --- tests/credential/conftest.py | 74 ++ .../credential/test_2200_async_create_cred.py | 395 ++++++++ tests/credential/test_2200_create_cred.py | 341 +++++++ tests/credential/test_2300_async_drop_cred.py | 254 +++++ tests/credential/test_2300_drop_cred.py | 198 ++++ tests/provider/conftest.py | 97 ++ tests/provider/test_2400_async_enable.py | 277 +++++ tests/provider/test_2400_enable.py | 272 +++++ tests/provider/test_2500_async_disable.py | 275 +++++ tests/provider/test_2500_disable.py | 284 ++++++ tests/test_1010_async_connection.py | 378 +++++++ tests/test_1010_connection.py | 359 +++++++ tests/vector_index/conftest.py | 104 ++ .../test_5000_async_create_index.py | 453 +++++++++ tests/vector_index/test_5000_create_index.py | 434 ++++++++ .../test_5100_async_drop_index.py | 592 +++++++++++ tests/vector_index/test_5100_drop_index.py | 540 ++++++++++ .../test_5200_async_setindex_attributes.py | 959 ++++++++++++++++++ .../test_5200_setindex_attributes.py | 801 +++++++++++++++ .../test_5300_async_getindex_attributes.py | 442 ++++++++ .../test_5300_getindex_attributes.py | 429 ++++++++ .../test_5400_async_list_index.py | 406 ++++++++ tests/vector_index/test_5400_list_index.py | 350 +++++++ .../test_5500_async_enable_disable_index.py | 395 ++++++++ .../test_5500_enable_disable_index.py | 360 +++++++ 25 files changed, 9469 insertions(+) create mode 100644 tests/credential/conftest.py create mode 100644 tests/credential/test_2200_async_create_cred.py create mode 100644 tests/credential/test_2200_create_cred.py create mode 100644 tests/credential/test_2300_async_drop_cred.py create mode 100644 tests/credential/test_2300_drop_cred.py create mode 100644 tests/provider/conftest.py create mode 100644 tests/provider/test_2400_async_enable.py create mode 100644 tests/provider/test_2400_enable.py create mode 100644 tests/provider/test_2500_async_disable.py create mode 100644 tests/provider/test_2500_disable.py create mode 100644 tests/test_1010_async_connection.py create mode 100644 tests/test_1010_connection.py create mode 100644 tests/vector_index/conftest.py create mode 100644 tests/vector_index/test_5000_async_create_index.py create mode 100644 tests/vector_index/test_5000_create_index.py create mode 100644 tests/vector_index/test_5100_async_drop_index.py create mode 100644 tests/vector_index/test_5100_drop_index.py create mode 100644 tests/vector_index/test_5200_async_setindex_attributes.py create mode 100644 tests/vector_index/test_5200_setindex_attributes.py create mode 100644 tests/vector_index/test_5300_async_getindex_attributes.py create mode 100644 tests/vector_index/test_5300_getindex_attributes.py create mode 100644 tests/vector_index/test_5400_async_list_index.py create mode 100644 tests/vector_index/test_5400_list_index.py create mode 100644 tests/vector_index/test_5500_async_enable_disable_index.py create mode 100644 tests/vector_index/test_5500_enable_disable_index.py diff --git a/tests/credential/conftest.py b/tests/credential/conftest.py new file mode 100644 index 0000000..92810c8 --- /dev/null +++ b/tests/credential/conftest.py @@ -0,0 +1,74 @@ +# ----------------------------------------------------------------------------- +# Copyright (c) 2025, Oracle and/or its affiliates. +# +# Licensed under the Universal Permissive License v 1.0 as shown at +# http://oss.oracle.com/licenses/upl. +# ----------------------------------------------------------------------------- + +import os + +import pytest +import select_ai + + +def get_credential_env_value(name, default_value=None): + return os.environ.get(f"PYSAI_TEST_{name}", default_value) + + +@pytest.fixture(scope="session") +def oci_credential(): + """ + Override the root autouse OCI credential fixture for credential-only tests. + These suites should not require OCI environment variables unless a test + explicitly opts into them. + """ + return None + + +@pytest.fixture(scope="session") +def credential_test_params(test_env): + return { + "user": test_env.test_user, + "password": test_env.test_user_password, + "dsn": test_env.connect_string, + "user_ocid": get_credential_env_value( + "OCI_USER_OCID", default_value="user ocid" + ), + "tenancy_ocid": get_credential_env_value( + "OCI_TENANCY_OCID", default_value="tenancy ocid" + ), + "private_key": get_credential_env_value( + "OCI_PRIVATE_KEY", default_value="private key" + ), + "fingerprint": get_credential_env_value( + "OCI_FINGERPRINT", default_value="fingerprint" + ), + "cred_username": get_credential_env_value( + "CRED_USERNAME", default_value="OCI credential username" + ), + "cred_password": get_credential_env_value( + "CRED_PASSWORD", default_value="OCI credential password" + ), + } + + +@pytest.fixture(scope="session") +def credential_connect_as(test_env): + def _connect_as(admin=False, **overrides): + select_ai.disconnect() + connect_params = test_env.connect_params(admin=admin) + connect_params.update(overrides) + select_ai.connect(**connect_params) + + return _connect_as + + +@pytest.fixture(scope="session") +def credential_async_connect_as(test_env): + async def _connect_as(admin=False, **overrides): + await select_ai.async_disconnect() + connect_params = test_env.connect_params(admin=admin) + connect_params.update(overrides) + await select_ai.async_connect(**connect_params) + + return _connect_as diff --git a/tests/credential/test_2200_async_create_cred.py b/tests/credential/test_2200_async_create_cred.py new file mode 100644 index 0000000..ba0d656 --- /dev/null +++ b/tests/credential/test_2200_async_create_cred.py @@ -0,0 +1,395 @@ +# ----------------------------------------------------------------------------- +# Copyright (c) 2025, Oracle and/or its affiliates. +# +# Licensed under the Universal Permissive License v 1.0 as shown at +# http://oss.oracle.com/licenses/upl. +# ----------------------------------------------------------------------------- + +import logging + +import oracledb +import pytest +import select_ai +from select_ai.errors import DatabaseNotConnectedError + +logger = logging.getLogger("TestAsyncCreateCredential") + +pytestmark = pytest.mark.anyio + + +@pytest.fixture(scope="class") +def credential_params(request, credential_test_params): + request.cls.credential_params = credential_test_params + + +@pytest.fixture(scope="class", autouse=True) +async def setup_and_teardown_cred( + request, + async_connect, + credential_params, + credential_async_connect_as, +): + logger.info("=== Setting up TestAsyncCreateCredential class ===") + assert await select_ai.async_is_connected(), "Connection to DB failed" + request.cls.credential_async_connect_as = staticmethod( + credential_async_connect_as + ) + logger.info("Initial connection successful") + yield + logger.info("=== Tearing down TestAsyncCreateCredential class ===") + logger.info("Connection cleanup is owned by root session fixtures.") + + +@pytest.fixture(autouse=True) +def log_test_name(request): + logger.info("--- Starting test: %s ---", request.function.__name__) + yield + logger.info("--- Finished test: %s ---", request.function.__name__) + + +@pytest.mark.usefixtures("credential_params", "setup_and_teardown_cred") +class TestAsyncCreateCredential: + @staticmethod + def get_native_cred_param(params, cred_name=None): + return dict( + credential_name=cred_name, + user_ocid=params["user_ocid"], + tenancy_ocid=params["tenancy_ocid"], + private_key=params["private_key"], + fingerprint=params["fingerprint"], + ) + + @staticmethod + def get_cred_param(params, cred_name=None): + return dict( + credential_name=cred_name, + username=params["cred_username"], + password=params["cred_password"], + ) + + @staticmethod + async def drop_credential(cred_name="GENAI_CRED"): + logger.info("Dropping credential: %s", cred_name) + await select_ai.async_delete_credential(cred_name, force=True) + logger.info("Dropped credential: %s", cred_name) + + async def test_2201(self): + """Testing basic credential creation.""" + credential = self.get_cred_param(self.credential_params, "GENAI_CRED") + logger.info("Creating credential: %s", credential) + try: + await select_ai.async_create_credential( + credential=credential, + replace=False, + ) + logger.info("Credential created successfully.") + except Exception as exc: + pytest.fail(f"async_create_credential() raised {exc} unexpectedly.") + await self.drop_credential() + + async def test_2202(self): + """Testing creating credential twice without replace.""" + credential = self.get_cred_param(self.credential_params, "GENAI_CRED") + try: + await select_ai.async_create_credential(credential=credential) + logger.info("First credential creation successful.") + except Exception as exc: + pytest.fail(f"async_create_credential() raised {exc} unexpectedly.") + logger.info("Attempting to create credential again (expected to fail)...") + with pytest.raises(oracledb.DatabaseError) as exc_info: + await select_ai.async_create_credential(credential=credential) + logger.info("Caught expected DatabaseError: %s", exc_info.value) + assert "ORA-20022" in str(exc_info.value) + await self.drop_credential() + + async def test_2203(self): + """Testing repeated credential creation with replace=True.""" + credential = self.get_cred_param(self.credential_params, "GENAI_CRED") + for i in range(5): + logger.info("Creating credential iteration %s...", i + 1) + await select_ai.async_create_credential( + credential=credential, + replace=True, + ) + logger.info("Repeated creation succeeded.") + await self.drop_credential() + + async def test_2204(self): + """Testing credential creation with replace=True.""" + credential = self.get_cred_param(self.credential_params, "GENAI_CRED") + try: + await select_ai.async_create_credential( + credential=credential, + replace=True, + ) + logger.info("Credential created successfully with replace=True.") + except Exception as exc: + pytest.fail(f"async_create_credential() raised {exc} unexpectedly.") + await self.drop_credential() + + async def test_2205(self): + """Testing credential creation twice with replace=True.""" + credential = self.get_cred_param(self.credential_params, "GENAI_CRED") + try: + await select_ai.async_create_credential( + credential=credential, + replace=True, + ) + await select_ai.async_create_credential( + credential=credential, + replace=True, + ) + logger.info( + "Credential created successfully twice with replace=True." + ) + except Exception as exc: + pytest.fail(f"async_create_credential() raised {exc} unexpectedly.") + await self.drop_credential() + + async def test_2206(self): + """Testing replace=True then replace=False behavior.""" + credential = self.get_cred_param(self.credential_params, "GENAI_CRED") + try: + await select_ai.async_create_credential( + credential=credential, + replace=True, + ) + logger.info("First creation succeeded.") + except Exception as exc: + pytest.fail(f"async_create_credential() raised {exc} unexpectedly.") + logger.info("Second creation without replace (expected to fail)...") + with pytest.raises(oracledb.DatabaseError) as exc_info: + await select_ai.async_create_credential(credential=credential) + logger.info("Caught expected error: %s", exc_info.value) + assert "ORA-20022" in str(exc_info.value) + await self.drop_credential() + + async def test_2207(self): + """Testing replace=False followed by replace=True.""" + credential = self.get_cred_param(self.credential_params, "GENAI_CRED") + try: + await select_ai.async_create_credential(credential=credential) + logger.info("Credential created (replace=False).") + await select_ai.async_create_credential( + credential=credential, + replace=True, + ) + logger.info("Credential replaced successfully (replace=True).") + except Exception as exc: + pytest.fail(f"async_create_credential() raised {exc} unexpectedly.") + await self.drop_credential() + + async def test_2208(self): + """Testing native credential creation.""" + credential = self.get_native_cred_param( + self.credential_params, + "GENAI_CRED", + ) + try: + await select_ai.async_create_credential( + credential=credential, + replace=False, + ) + logger.info("Native credential created successfully.") + except Exception as exc: + pytest.fail(f"async_create_credential() raised {exc} unexpectedly.") + await self.drop_credential() + + async def test_2209(self): + """Testing native credential creation twice.""" + credential = self.get_native_cred_param( + self.credential_params, + "GENAI_CRED", + ) + try: + await select_ai.async_create_credential(credential=credential) + logger.info("First native credential created.") + except Exception as exc: + pytest.fail(f"async_create_credential() raised {exc} unexpectedly.") + with pytest.raises(oracledb.DatabaseError) as exc_info: + await select_ai.async_create_credential(credential=credential) + logger.info("Expected error caught: %s", exc_info.value) + assert "ORA-20022" in str(exc_info.value) + await self.drop_credential() + + async def test_2210(self): + """Testing native credential creation with replace=True.""" + credential = self.get_native_cred_param( + self.credential_params, + "GENAI_CRED", + ) + try: + await select_ai.async_create_credential( + credential=credential, + replace=True, + ) + logger.info("Native credential created successfully.") + except Exception as exc: + pytest.fail(f"async_create_credential() raised {exc} unexpectedly.") + await self.drop_credential() + + async def test_2211(self): + """Testing native credential creation with replace=True twice.""" + credential = self.get_native_cred_param( + self.credential_params, + "GENAI_CRED", + ) + for i in range(2): + logger.info( + "Creating native credential iteration %s (replace=True)...", + i + 1, + ) + await select_ai.async_create_credential( + credential=credential, + replace=True, + ) + logger.info("Native credential replaced successfully twice.") + await self.drop_credential() + + async def test_2212(self): + """Testing creation with empty credential name.""" + credential = self.get_cred_param(self.credential_params) + with pytest.raises(Exception) as exc_info: + await select_ai.async_create_credential(credential=credential) + logger.info("Expected exception caught: %s", exc_info.value) + assert "ORA-20010: Missing credential name" in str(exc_info.value) + + async def test_2213(self): + """Testing credential creation with empty dictionary.""" + with pytest.raises(oracledb.DatabaseError) as exc_info: + await select_ai.async_create_credential(credential={}) + logger.info("Expected exception caught: %s", exc_info.value) + assert ( + "PLS-00306: wrong number or types of arguments in call to " + "'CREATE_CREDENTIAL'" in str(exc_info.value) + ) + + async def test_2214(self): + """Testing credential creation with invalid username.""" + credential = dict( + credential_name="GENAI_CRED", + username="invalid_username", + password=self.credential_params["cred_password"], + ) + await select_ai.async_create_credential(credential=credential, replace=True) + logger.info("Credential with invalid username created successfully.") + await self.drop_credential() + + async def test_2215(self): + """Testing credential creation with invalid password.""" + credential = dict( + credential_name="GENAI_CRED", + username=self.credential_params["cred_username"], + password="invalid_pwd", + ) + await select_ai.async_create_credential(credential=credential, replace=True) + logger.info("Credential with invalid password created successfully.") + await self.drop_credential() + + async def test_2216(self): + """Testing credential creation when DB is disconnected.""" + await select_ai.async_disconnect() + credential = self.get_cred_param(self.credential_params, "GENAI_CRED") + with pytest.raises(DatabaseNotConnectedError): + await select_ai.async_create_credential( + credential=credential, + replace=False, + ) + logger.info("Expected DatabaseNotConnectedError raised.") + await self.credential_async_connect_as() + + async def test_2217(self): + """Test Credential creation for a local test user.""" + test_username = "TEST_USER1" + test_password = self.credential_params["password"] + escaped_password = test_password.replace('"', '""') + + logger.info("Connecting as admin user...") + await self.credential_async_connect_as(admin=True) + logger.info("Admin connection established.") + + async with select_ai.async_cursor() as admin_cursor: + try: + await admin_cursor.execute(f"DROP USER {test_username} CASCADE") + logger.info("Existing user '%s' dropped.", test_username) + except oracledb.DatabaseError: + logger.info( + "User '%s' did not exist, continuing...", + test_username, + ) + await admin_cursor.execute( + f'CREATE USER {test_username} IDENTIFIED BY "{escaped_password}"' + ) + await admin_cursor.execute( + "grant create session, create table, unlimited tablespace " + f"to {test_username}" + ) + await admin_cursor.execute( + f"grant execute on dbms_cloud to {test_username}" + ) + logger.info( + "User '%s' created and granted privileges.", + test_username, + ) + + logger.info("Connecting as test user '%s'...", test_username) + await self.credential_async_connect_as( + user=test_username, + password=test_password, + ) + logger.info("Test user connection established.") + + credential = self.get_cred_param( + self.credential_params, + "GENAI_CRED_USER1", + ) + try: + await select_ai.async_create_credential( + credential=credential, + replace=False, + ) + logger.info("Credential created successfully.") + except Exception as exc: + pytest.fail(f"async_create_credential() raised {exc} unexpectedly.") + + await self.drop_credential("GENAI_CRED_USER1") + + logger.info("Reconnecting as admin to drop test user '%s'...", test_username) + await self.credential_async_connect_as(admin=True) + async with select_ai.async_cursor() as admin_cursor: + await admin_cursor.execute(f"DROP USER {test_username} CASCADE") + logger.info("Test user '%s' dropped successfully.", test_username) + await self.credential_async_connect_as() + + async def test_2218(self): + """Testing credential name with special characters.""" + credential = self.get_cred_param( + self.credential_params, + "GENAI_CRED!@#", + ) + with pytest.raises( + oracledb.DatabaseError, + match="ORA-20010: Invalid credential name", + ): + await select_ai.async_create_credential( + credential=credential, + replace=False, + ) + logger.info("Invalid name test passed.") + + async def test_2219(self): + """Testing credential name exceeding 128 characters.""" + long_name = "GENAI_CRED" + "_" + "a" * (128 - len("GENAI_CRED")) + credential = self.get_cred_param(self.credential_params, long_name) + with pytest.raises( + oracledb.DatabaseError, + match=( + r"ORA-20008: Credential name length \(129\) exceeds " + r"maximum length \(128\)" + ), + ): + await select_ai.async_create_credential( + credential=credential, + replace=False, + ) + logger.info("Long credential name test passed.") diff --git a/tests/credential/test_2200_create_cred.py b/tests/credential/test_2200_create_cred.py new file mode 100644 index 0000000..3886f1f --- /dev/null +++ b/tests/credential/test_2200_create_cred.py @@ -0,0 +1,341 @@ +# ----------------------------------------------------------------------------- +# Copyright (c) 2025, Oracle and/or its affiliates. +# +# Licensed under the Universal Permissive License v 1.0 as shown at +# http://oss.oracle.com/licenses/upl. +# ----------------------------------------------------------------------------- + +import logging +import pytest +import select_ai +import oracledb +from select_ai.errors import DatabaseNotConnectedError + +logger = logging.getLogger("TestCreateCredential") + +@pytest.fixture(scope="class", autouse=True) +def setup_logging(): + logging.basicConfig( + format="%(asctime)s %(levelname)s %(name)s %(message)s", + level=logging.INFO + ) + +@pytest.fixture(scope="class") +def credential_params(request, credential_test_params): + request.cls.credential_params = credential_test_params + +@pytest.fixture(scope="class", autouse=True) +def setup_and_teardown_cred( + request, + connect, + credential_params, + credential_connect_as, +): + logger.info("=== Setting up TestCreateCredential class ===") + assert select_ai.is_connected(), "Connection to DB failed" + request.cls.credential_connect_as = staticmethod(credential_connect_as) + logger.info("Initial connection successful") + yield + logger.info("=== Tearing down TestCreateCredential class ===") + logger.info("Connection cleanup is owned by root session fixtures.") + +@pytest.fixture(autouse=True) +def log_test_name(request): + logger.info(f"--- Starting test: {request.function.__name__} ---") + yield + logger.info(f"--- Finished test: {request.function.__name__} ---") + +@pytest.mark.usefixtures("credential_params", "setup_and_teardown_cred") +class TestCreateCredential: + logger = logger + + @staticmethod + def get_native_cred_param(params, cred_name=None): + return dict( + credential_name = cred_name, + user_ocid = params["user_ocid"], + tenancy_ocid = params["tenancy_ocid"], + private_key = params["private_key"], + fingerprint = params["fingerprint"] + ) + @staticmethod + def get_cred_param(params, cred_name=None): + return dict( + credential_name = cred_name, + username = params["cred_username"], + password = params["cred_password"] + ) + @staticmethod + def drop_credential_cursor(cursor, cred_name='GENAI_CRED'): + logger.info(f"Dropping credential: {cred_name}") + cursor.callproc( + "DBMS_CLOUD.DROP_CREDENTIAL", + keyword_parameters={ + "credential_name": cred_name + }, + ) + logger.info(f"Dropped credential: {cred_name}") + + def test_2201(self): + """Testing basic credential creation""" + credential = self.get_cred_param(self.credential_params, 'GENAI_CRED') + self.logger.info(f"Creating credential: {credential}") + try: + select_ai.create_credential(credential=credential, replace=False) + self.logger.info("Credential created successfully.") + except Exception as e: + pytest.fail(f"create_credential() raised {e} unexpectedly.") + with select_ai.cursor() as cursor: + self.drop_credential_cursor(cursor) + + def test_2202(self): + """Testing creating credential twice without replace""" + credential = self.get_cred_param(self.credential_params, 'GENAI_CRED') + try: + select_ai.create_credential(credential=credential) + self.logger.info("First credential creation successful.") + except Exception as e: + pytest.fail(f"create_credential() raised {e} unexpectedly.") + self.logger.info("Attempting to create credential again (expected to fail)...") + with pytest.raises(oracledb.DatabaseError) as cm: + select_ai.create_credential(credential=credential) + self.logger.info(f"Caught expected DatabaseError: {cm.value}") + assert "ORA-20022" in str(cm.value) + with select_ai.cursor() as cursor: + self.drop_credential_cursor(cursor) + + def test_2203(self): + """Testing repeated credential creation with replace=True""" + credential = self.get_cred_param(self.credential_params, 'GENAI_CRED') + for i in range(5): + self.logger.info(f"Creating credential iteration {i+1}...") + select_ai.create_credential(credential=credential, replace=True) + self.logger.info("Repeated creation succeeded.") + with select_ai.cursor() as cursor: + self.drop_credential_cursor(cursor) + + def test_2204(self): + """Testing credential creation with replace=True""" + credential = self.get_cred_param(self.credential_params, 'GENAI_CRED') + try: + select_ai.create_credential(credential=credential, replace=True) + self.logger.info("Credential created successfully with replace=True.") + except Exception as e: + pytest.fail(f"create_credential() raised {e} unexpectedly.") + with select_ai.cursor() as cursor: + self.drop_credential_cursor(cursor) + + def test_2205(self): + """Testing credential creation twice with replace=True""" + credential = self.get_cred_param(self.credential_params, 'GENAI_CRED') + try: + select_ai.create_credential(credential=credential, replace=True) + self.logger.info("Credential created successfully with replace=True.") + except Exception as e: + pytest.fail(f"create_credential() raised {e} unexpectedly.") + try: + select_ai.create_credential(credential=credential, replace=True) + self.logger.info("Credential created successfully with replace=True.") + except Exception as e: + pytest.fail(f"create_credential() raised {e} unexpectedly.") + assert True, "Credential creation and replacement passed without exception." + with select_ai.cursor() as cursor: + self.drop_credential_cursor(cursor) + + def test_2206(self): + """Testing replace=True then replace=False behavior""" + credential = self.get_cred_param(self.credential_params, 'GENAI_CRED') + try: + select_ai.create_credential(credential=credential, replace=True) + self.logger.info("First creation succeeded.") + except Exception as e: + pytest.fail(f"create_credential() raised {e} unexpectedly.") + self.logger.info("Second creation without replace (expected to fail)...") + with pytest.raises(oracledb.DatabaseError) as cm: + select_ai.create_credential(credential=credential) + self.logger.info(f"Caught expected error: {cm.value}") + assert "ORA-20022" in str(cm.value) + with select_ai.cursor() as cursor: + self.drop_credential_cursor(cursor) + + def test_2207(self): + """Testing replace=False followed by replace=True""" + credential = self.get_cred_param(self.credential_params, 'GENAI_CRED') + try: + select_ai.create_credential(credential=credential) + self.logger.info("Credential created (replace=False).") + except Exception as e: + pytest.fail(f"create_credential() raised {e} unexpectedly.") + try: + select_ai.create_credential(credential=credential, replace=True) + self.logger.info("Credential replaced successfully (replace=True).") + except Exception as e: + pytest.fail(f"create_credential() raised {e} unexpectedly.") + assert True, "Credential creation and replacement passed without exception." + with select_ai.cursor() as cursor: + self.drop_credential_cursor(cursor) + + def test_2208(self): + """Testing native credential creation""" + credential = self.get_native_cred_param(self.credential_params, 'GENAI_CRED') + try: + select_ai.create_credential(credential=credential, replace=False) + self.logger.info("Native credential created successfully.") + except Exception as e: + pytest.fail(f"create_credential() raised {e} unexpectedly.") + with select_ai.cursor() as cursor: + self.drop_credential_cursor(cursor) + + def test_2209(self): + """Testing native credential creation twice""" + credential = self.get_native_cred_param(self.credential_params, 'GENAI_CRED') + try: + select_ai.create_credential(credential=credential) + self.logger.info("First native credential created.") + except Exception as e: + pytest.fail(f"create_credential() raised {e} unexpectedly.") + with pytest.raises(oracledb.DatabaseError) as cm: + select_ai.create_credential(credential=credential) + self.logger.info(f"Expected error caught: {cm.value}") + assert "ORA-20022" in str(cm.value) + with select_ai.cursor() as cursor: + self.drop_credential_cursor(cursor) + + def test_2210(self): + """Testing native credential creation with replace=True""" + credential = self.get_native_cred_param(self.credential_params, 'GENAI_CRED') + try: + select_ai.create_credential(credential=credential, replace=True) + self.logger.info("Native credential created successfully.") + except Exception as e: + pytest.fail(f"create_credential() raised {e} unexpectedly.") + with select_ai.cursor() as cursor: + self.drop_credential_cursor(cursor) + + def test_2211(self): + """Testing native credential creation with replace=True twice""" + credential = self.get_native_cred_param(self.credential_params, 'GENAI_CRED') + for i in range(2): + self.logger.info(f"Creating native credential iteration {i+1} (replace=True)...") + select_ai.create_credential(credential=credential, replace=True) + self.logger.info("Native credential replaced successfully twice.") + with select_ai.cursor() as cursor: + self.drop_credential_cursor(cursor) + + def test_2212(self): + """Testing creation with empty credential name""" + credential = self.get_cred_param(self.credential_params) + with pytest.raises(Exception) as cm: + select_ai.create_credential(credential=credential) + self.logger.info(f"Expected exception caught: {cm.value}") + assert "ORA-20010: Missing credential name" in str(cm.value) + + def test_2213(self): + """Testing credential creation with empty dictionary""" + credential = dict() + with pytest.raises(oracledb.DatabaseError) as cm: + select_ai.create_credential(credential=credential) + self.logger.info(f"Expected exception caught: {cm.value}") + assert ( + "PLS-00306: wrong number or types of arguments in call to 'CREATE_CREDENTIAL'" in str(cm.value) + ) + + def test_2214(self): + """Testing credential creation with invalid username""" + credential = dict( + credential_name = 'GENAI_CRED', + username = 'invalid_username', + password = self.credential_params["cred_password"] + ) + select_ai.create_credential(credential=credential, replace=True) + self.logger.info("Credential with invalid username created successfully.") + with select_ai.cursor() as cursor: + self.drop_credential_cursor(cursor) + + def test_2215(self): + """Testing credential creation with invalid password""" + credential = dict( + credential_name = 'GENAI_CRED', + username = self.credential_params["cred_username"], + password = 'invalid_pwd' + ) + select_ai.create_credential(credential=credential, replace=True) + self.logger.info("Credential with invalid password created successfully.") + with select_ai.cursor() as cursor: + self.drop_credential_cursor(cursor) + + def test_2216(self): + """Testing credential creation when DB is disconnected""" + select_ai.disconnect() + credential = self.get_cred_param(self.credential_params, 'GENAI_CRED') + with pytest.raises(DatabaseNotConnectedError): + select_ai.create_credential(credential=credential, replace=False) + self.logger.info("Expected DatabaseNotConnectedError raised.") + self.credential_connect_as() + + def test_2217(self): + """Test Credential creation for a local test user""" + self.logger.info("Connecting as admin user...") + self.credential_connect_as(admin=True) + self.logger.info("Admin connection established.") + test_username = "TEST_USER1" + test_password = self.credential_params["password"] + escaped_password = test_password.replace('"', '""') + self.logger.info(f"Ensuring test user '{test_username}' does not exist...") + with select_ai.cursor() as admin_cursor: + try: + admin_cursor.execute(f"DROP USER {test_username} CASCADE") + self.logger.info(f"Existing user '{test_username}' dropped.") + except oracledb.DatabaseError: + self.logger.info(f"User '{test_username}' did not exist, continuing...") + self.logger.info(f"Creating test user '{test_username}'...") + admin_cursor.execute( + f'CREATE USER {test_username} IDENTIFIED BY "{escaped_password}"' + ) + admin_cursor.execute(f"grant create session, create table, unlimited tablespace to {test_username}") + admin_cursor.execute(f"grant execute on dbms_cloud to {test_username}") + self.logger.info(f"User '{test_username}' created and granted privileges.") + self.logger.info(f"Connecting as test user '{test_username}'...") + self.credential_connect_as( + user=test_username, + password=test_password, + ) + self.logger.info("Test user connection established.") + credential = self.get_cred_param(self.credential_params, 'GENAI_CRED_USER1') + self.logger.info(f"Creating credential '{credential['credential_name']}' for test user...") + try: + select_ai.create_credential(credential=credential, replace=False) + self.logger.info("Credential created successfully.") + except Exception as e: + pytest.fail(f"create_credential() raised {e} unexpectedly.") + self.logger.info(f"Dropping credential '{credential['credential_name']}'...") + with select_ai.cursor() as cursor: + self.drop_credential_cursor(cursor, 'GENAI_CRED_USER1') + self.logger.info("Credential dropped.") + self.logger.info("Disconnecting test user...") + select_ai.disconnect() + self.logger.info("Disconnected test user.") + self.logger.info(f"Reconnecting as admin to drop test user '{test_username}'...") + self.credential_connect_as(admin=True) + with select_ai.cursor() as admin_cursor: + admin_cursor.execute(f"DROP USER {test_username} CASCADE") + self.logger.info(f"Test user '{test_username}' dropped successfully.") + self.credential_connect_as() + + def test_2218(self): + """Testing credential name with special characters""" + credential = self.get_cred_param(self.credential_params, 'GENAI_CRED!@#') + with pytest.raises(oracledb.DatabaseError, match="ORA-20010: Invalid credential name"): + select_ai.create_credential(credential=credential, replace=False) + self.logger.info("Invalid name test passed.") + + def test_2219(self): + """Testing credential name exceeding 128 characters""" + long_name = "GENAI_CRED" + "_" + "a" * (128 - len('GENAI_CRED')) + credential = self.get_cred_param(self.credential_params, long_name) + with pytest.raises( + oracledb.DatabaseError, + match=r"ORA-20008: Credential name length \(129\) exceeds maximum length \(128\)" + ): + select_ai.create_credential(credential=credential, replace=False) + self.logger.info("Long credential name test passed.") diff --git a/tests/credential/test_2300_async_drop_cred.py b/tests/credential/test_2300_async_drop_cred.py new file mode 100644 index 0000000..a3fd119 --- /dev/null +++ b/tests/credential/test_2300_async_drop_cred.py @@ -0,0 +1,254 @@ +# ----------------------------------------------------------------------------- +# Copyright (c) 2025, Oracle and/or its affiliates. +# +# Licensed under the Universal Permissive License v 1.0 as shown at +# http://oss.oracle.com/licenses/upl. +# ----------------------------------------------------------------------------- + +import logging + +import oracledb +import pytest +import select_ai + +logger = logging.getLogger("TestAsyncDropCredential") + +pytestmark = pytest.mark.anyio + + +@pytest.fixture(scope="class") +def drop_params(request, credential_test_params): + request.cls.drop_params = credential_test_params + + +@pytest.fixture(scope="class", autouse=True) +async def setup_and_teardown( + request, + async_connect, + drop_params, + credential_async_connect_as, +): + logger.info("=== Setting up TestAsyncDropCredential class ===") + assert await select_ai.async_is_connected(), "Connection to DB failed" + request.cls.credential_async_connect_as = staticmethod( + credential_async_connect_as + ) + logger.info("Initial connection successful") + yield + logger.info("=== Tearing down TestAsyncDropCredential class ===") + logger.info("Connection cleanup is owned by root session fixtures.") + + +@pytest.fixture(autouse=True) +def log_test_name(request): + logger.info("--- Starting test: %s ---", request.function.__name__) + yield + logger.info("--- Finished test: %s ---", request.function.__name__) + + +@pytest.mark.usefixtures("drop_params", "setup_and_teardown") +class TestAsyncDropCredential: + @staticmethod + def get_cred_param(params, cred_name=None): + logger.info("Preparing credential params for: %s", cred_name) + return dict( + credential_name=cred_name, + username=params["cred_username"], + password=params["cred_password"], + ) + + @classmethod + async def create_test_credential(cls, cred_name="GENAI_CRED"): + logger.info("Creating test credential: %s", cred_name) + credential = cls.get_cred_param(cls.drop_params, cred_name) + try: + await select_ai.async_create_credential( + credential=credential, + replace=False, + ) + logger.info("Credential '%s' created successfully.", cred_name) + except Exception as exc: + pytest.fail( + f"async_create_credential() raised {exc} unexpectedly." + ) + + @classmethod + async def create_local_user(cls, test_username="TEST_USER1"): + logger.info("Creating local user: %s", test_username) + test_password = cls.drop_params["password"] + escaped_password = test_password.replace('"', '""') + async with select_ai.async_cursor() as admin_cursor: + try: + await admin_cursor.execute(f"DROP USER {test_username} CASCADE") + except oracledb.DatabaseError: + pass + await admin_cursor.execute( + f'CREATE USER {test_username} IDENTIFIED BY "{escaped_password}"' + ) + await admin_cursor.execute( + "grant create session, create table, unlimited tablespace " + f"to {test_username}" + ) + await admin_cursor.execute( + f"grant execute on dbms_cloud to {test_username}" + ) + logger.info("Local user '%s' ready.", test_username) + + async def test_2301(self): + """Deleting existing credential (force=True).""" + await self.create_test_credential() + try: + await select_ai.async_delete_credential("GENAI_CRED", force=True) + logger.info("Credential deleted successfully.") + except Exception as exc: + pytest.fail(f"async_delete_credential() raised {exc} unexpectedly.") + + async def test_2302(self): + """Deleting same credential twice (force=True).""" + await self.create_test_credential() + await select_ai.async_delete_credential("GENAI_CRED", force=True) + await select_ai.async_delete_credential("GENAI_CRED", force=True) + logger.info("Double deletion succeeded (force=True).") + + async def test_2303(self): + """Deleting same credential twice (force=False).""" + await self.create_test_credential() + await select_ai.async_delete_credential("GENAI_CRED", force=False) + with pytest.raises(oracledb.DatabaseError) as exc_info: + await select_ai.async_delete_credential("GENAI_CRED", force=False) + logger.info( + "Expected DatabaseError for second delete (force=False): %s", + exc_info.value, + ) + + async def test_2304(self): + """Deleting nonexistent credential (default force=False).""" + with pytest.raises(oracledb.DatabaseError) as exc_info: + await select_ai.async_delete_credential("nonexistent_cred") + logger.info( + "Expected DatabaseError for nonexistent credential: %s", + exc_info.value, + ) + + async def test_2305(self): + """Deleting nonexistent credential (force=False).""" + with pytest.raises(oracledb.DatabaseError) as exc_info: + await select_ai.async_delete_credential( + "nonexistent_cred", + force=False, + ) + logger.info( + "Expected DatabaseError for nonexistent credential: %s", + exc_info.value, + ) + + async def test_2306(self): + """Deleting nonexistent credential (force=True).""" + try: + await select_ai.async_delete_credential( + "nonexistent_cred", + force=True, + ) + logger.info("No error raised (expected behavior).") + except Exception as exc: + pytest.fail( + f"async_delete_credential(force=True) raised {exc} unexpectedly." + ) + + async def test_2307(self): + """Deleting credential as local user.""" + test_username = "TEST_USER1" + await self.credential_async_connect_as(admin=True) + await self.create_local_user(test_username) + + await self.credential_async_connect_as( + user=test_username, + password=self.drop_params["password"], + ) + + credential = self.get_cred_param(self.drop_params, "GENAI_CRED_USER1") + await select_ai.async_create_credential( + credential=credential, + replace=False, + ) + + try: + await select_ai.async_delete_credential( + "GENAI_CRED_USER1", + force=True, + ) + logger.info("Local user credential deleted successfully.") + except Exception as exc: + pytest.fail(f"async_delete_credential() raised {exc} unexpectedly.") + finally: + await self.credential_async_connect_as(admin=True) + async with select_ai.async_cursor() as admin_cursor: + await admin_cursor.execute(f"DROP USER {test_username} CASCADE") + logger.info("Local user cleanup complete.") + await self.credential_async_connect_as() + + async def test_2308(self): + """Deleting credential with invalid name.""" + with pytest.raises( + oracledb.DatabaseError, + match=r"ORA-20010: Invalid credential name", + ): + await select_ai.async_delete_credential("invalid!@#", force=True) + logger.info("Caught expected ORA-20010 for invalid name.") + + async def test_2309(self): + """Deleting credential without active connection.""" + await select_ai.async_disconnect() + with pytest.raises(select_ai.errors.DatabaseNotConnectedError) as exc_info: + await select_ai.async_delete_credential("GENAI_CRED", force=True) + logger.info( + "Expected DatabaseNotConnectedError raised: %s", + exc_info.value, + ) + await self.credential_async_connect_as() + + async def test_2310(self): + """Deleting credential with name exceeding max length.""" + long_name = "GENAI_CRED_" + "a" * 120 + with pytest.raises( + oracledb.DatabaseError, + match=( + r"ORA-20008: Credential name length .* exceeds maximum length" + ), + ): + await select_ai.async_delete_credential(long_name, force=True) + logger.info("Caught expected ORA-20008 for long credential name.") + + async def test_2311(self): + """Deleting credential with lowercase name.""" + await self.create_test_credential("GENAI_CRED") + try: + await select_ai.async_delete_credential(credential_name="genai_cred") + logger.info( + "Credential deleted successfully (case-insensitive)." + ) + except Exception as exc: + pytest.fail( + "async_delete_credential raised " + f"{exc} unexpectedly for lowercase name" + ) + + async def test_2312(self): + """Deleting credential with empty or None name.""" + with pytest.raises( + oracledb.DatabaseError, + match=r"ORA-20010: Missing credential name", + ): + await select_ai.async_delete_credential( + credential_name="", + force=True, + ) + with pytest.raises( + oracledb.DatabaseError, + match=r"ORA-20010: Missing credential name", + ): + await select_ai.async_delete_credential( + credential_name=None, + force=True, + ) + logger.info("Caught expected ORA-20010 for missing credential name.") diff --git a/tests/credential/test_2300_drop_cred.py b/tests/credential/test_2300_drop_cred.py new file mode 100644 index 0000000..0553266 --- /dev/null +++ b/tests/credential/test_2300_drop_cred.py @@ -0,0 +1,198 @@ +# ----------------------------------------------------------------------------- +# Copyright (c) 2025, Oracle and/or its affiliates. +# +# Licensed under the Universal Permissive License v 1.0 as shown at +# http://oss.oracle.com/licenses/upl. +# ----------------------------------------------------------------------------- + +import logging +import pytest +import select_ai +import oracledb + +from select_ai.errors import DatabaseNotConnectedError + +logger = logging.getLogger("TestDropCredential") + +@pytest.fixture(scope="class", autouse=True) +def setup_logging(): + logging.basicConfig( + format="%(asctime)s %(levelname)s %(name)s %(message)s", + level=logging.INFO + ) + +@pytest.fixture(scope="class") +def drop_params(request, credential_test_params): + request.cls.drop_params = credential_test_params + +@pytest.fixture(scope="class", autouse=True) +def setup_and_teardown(request, connect, drop_params, credential_connect_as): + logger.info("=== Setting up TestDropCredential class ===") + assert select_ai.is_connected(), "Connection to DB failed" + request.cls.credential_connect_as = staticmethod(credential_connect_as) + logger.info("Initial connection successful") + yield + logger.info("=== Tearing down TestDropCredential class ===") + logger.info("Connection cleanup is owned by root session fixtures.") + +@pytest.fixture(autouse=True) +def log_test_name(request): + logger.info(f"--- Starting test: {request.function.__name__} ---") + yield + logger.info(f"--- Finished test: {request.function.__name__} ---") + +@pytest.mark.usefixtures("drop_params", "setup_and_teardown") +class TestDropCredential: + @staticmethod + def get_cred_param(params, cred_name=None): + logger.info(f"Preparing credential params for: {cred_name}") + return dict( + credential_name = cred_name, + username = params["cred_username"], + password = params["cred_password"] + ) + @classmethod + def create_test_credential(cls, cred_name="GENAI_CRED"): + logger.info(f"Creating test credential: {cred_name}") + credential = cls.get_cred_param(cls.drop_params, cred_name) + try: + select_ai.create_credential(credential=credential, replace=False) + logger.info(f"Credential '{cred_name}' created successfully.") + except Exception as e: + pytest.fail(f"create_credential() raised {e} unexpectedly.") + @classmethod + def create_local_user(cls, test_username="TEST_USER1"): + logger.info(f"Creating local user: {test_username}") + test_password = cls.drop_params["password"] + escaped_password = test_password.replace('"', '""') + with select_ai.cursor() as admin_cursor: + try: + admin_cursor.execute(f"DROP USER {test_username} CASCADE") + except oracledb.DatabaseError: + pass # Ignore if user doesn't exist + admin_cursor.execute( + f'CREATE USER {test_username} IDENTIFIED BY "{escaped_password}"' + ) + admin_cursor.execute(f"grant create session, create table, unlimited tablespace to {test_username}") + admin_cursor.execute(f"grant execute on dbms_cloud to {test_username}") + logger.info(f"Local user '{test_username}' ready.") + + def test_2301(self): + """Deleting existing credential (force=True)""" + logger.info("Deleting existing credential (force=True)") + self.create_test_credential() + try: + select_ai.delete_credential("GENAI_CRED", force=True) + logger.info("Credential deleted successfully.") + except Exception as e: + pytest.fail(f"delete_credential() raised {e} unexpectedly.") + + def test_2302(self): + """Deleting same credential twice (force=True)""" + logger.info("Deleting same credential twice (force=True)") + self.create_test_credential() + select_ai.delete_credential("GENAI_CRED", force=True) + select_ai.delete_credential("GENAI_CRED", force=True) + logger.info("Double deletion succeeded (force=True).") + + def test_2303(self): + """Deleting same credential twice (force=False)""" + logger.info("Deleting same credential twice (force=False)") + self.create_test_credential() + select_ai.delete_credential("GENAI_CRED", force=False) + with pytest.raises(oracledb.DatabaseError) as cm: + select_ai.delete_credential("GENAI_CRED", force=False) + logger.info(f"Expected DatabaseError for second delete (force=False): {cm.value}") + + def test_2304(self): + """Deleting nonexistent credential (default force=False)""" + logger.info("Deleting nonexistent credential (default force=False)") + with pytest.raises(oracledb.DatabaseError) as cm: + select_ai.delete_credential("nonexistent_cred") + logger.info(f"Expected DatabaseError for nonexistent credential: {cm.value}") + + def test_2305(self): + """Deleting nonexistent credential (force=False)""" + logger.info("Deleting nonexistent credential (force=False)") + with pytest.raises(oracledb.DatabaseError) as cm: + select_ai.delete_credential("nonexistent_cred", force=False) + logger.info(f"Expected DatabaseError for nonexistent credential: {cm.value}") + + def test_2306(self): + """Deleting nonexistent credential (force=True)""" + logger.info("Deleting nonexistent credential (force=True)") + try: + select_ai.delete_credential("nonexistent_cred", force=True) + logger.info("No error raised (expected behavior).") + except Exception as e: + pytest.fail(f"delete_credential(force=True) raised {e} unexpectedly.") + + def test_2307(self): + """Deleting credential as local user""" + logger.info("Deleting credential as local user") + test_username = "TEST_USER1" + self.credential_connect_as(admin=True) + self.create_local_user(test_username) + self.credential_connect_as( + user=test_username, + password=self.drop_params["password"], + ) + credential = self.get_cred_param(self.drop_params, "GENAI_CRED_USER1") + try: + select_ai.delete_credential("GENAI_CRED_USER1", force=True) + logger.info("Local user credential deleted successfully.") + except Exception as e: + pytest.fail(f"delete_credential() raised {e} unexpectedly.") + finally: + select_ai.disconnect() + self.credential_connect_as(admin=True) + with select_ai.cursor() as admin_cursor: + admin_cursor.execute(f"DROP USER {test_username} CASCADE") + logger.info("Local user cleanup complete.") + self.credential_connect_as() + + def test_2308(self): + """Deleting credential with invalid name""" + logger.info("Deleting credential with invalid name") + with pytest.raises(oracledb.DatabaseError, match=r"ORA-20010: Invalid credential name"): + select_ai.delete_credential("invalid!@#", force=True) + logger.info("Caught expected ORA-20010 for invalid name.") + + def test_2309(self): + """Deleting credential without active connection""" + logger.info("Deleting credential without active connection") + select_ai.disconnect() + with pytest.raises(select_ai.errors.DatabaseNotConnectedError) as cm: + select_ai.delete_credential("GENAI_CRED", force=True) + logger.info(f"Expected DatabaseNotConnectedError raised: {cm.value}") + self.credential_connect_as() + + def test_2310(self): + """Deleting credential with name exceeding max length""" + logger.info("Deleting credential with name exceeding max length") + long_name = "GENAI_CRED_" + "a" * 120 + with pytest.raises( + oracledb.DatabaseError, + match=r"ORA-20008: Credential name length .* exceeds maximum length" + ): + select_ai.delete_credential(long_name, force=True) + logger.info("Caught expected ORA-20008 for long credential name.") + + def test_2311(self): + """Deleting credential with lowercase name""" + logger.info("Deleting credential with lowercase name") + self.create_test_credential("GENAI_CRED") + try: + select_ai.delete_credential(credential_name="genai_cred") + logger.info("Credential deleted successfully (case-insensitive).") + except Exception as e: + pytest.fail(f"async_delete_credential raised {e} unexpectedly for lowercase name") + + def test_2312(self): + """Deleting credential with empty or None name""" + logger.info("Deleting credential with empty or None name") + with pytest.raises(oracledb.DatabaseError, match=r"ORA-20010: Missing credential name"): + select_ai.delete_credential(credential_name="", force=True) + with pytest.raises(oracledb.DatabaseError, match=r"ORA-20010: Missing credential name"): + select_ai.delete_credential(credential_name=None, force=True) + logger.info("Caught expected ORA-20010 for missing credential name.") diff --git a/tests/provider/conftest.py b/tests/provider/conftest.py new file mode 100644 index 0000000..391073d --- /dev/null +++ b/tests/provider/conftest.py @@ -0,0 +1,97 @@ +# ----------------------------------------------------------------------------- +# Copyright (c) 2025, Oracle and/or its affiliates. +# +# Licensed under the Universal Permissive License v 1.0 as shown at +# http://oss.oracle.com/licenses/upl. +# ----------------------------------------------------------------------------- + +import pytest +import select_ai + +_BASIC_SCHEMA_PRIVILEGES = ( + "CREATE SESSION", + "CREATE TABLE", + "UNLIMITED TABLESPACE", +) + + +def get_supported_provider_endpoints(): + """ + Returns provider endpoints that can be derived directly from the public + provider classes and are therefore suitable for ACL enable/disable tests. + OCI is intentionally not included here because its endpoint is deployment + specific and not exposed as one canonical default by OCIGenAIProvider. + """ + return { + "openai": select_ai.OpenAIProvider.provider_endpoint, + "cohere": select_ai.CohereProvider.provider_endpoint, + "google": select_ai.GoogleProvider.provider_endpoint, + "huggingface": select_ai.HuggingFaceProvider.provider_endpoint, + "anthropic": select_ai.AnthropicProvider.provider_endpoint, + "azure": select_ai.AzureProvider( + azure_resource_name="python-select-ai-test" + ).provider_endpoint, + "aws": select_ai.AWSProvider(region="us-east-1").provider_endpoint, + } + + +@pytest.fixture(scope="session") +def oci_credential(): + """ + Provider tests do not need the shared OCI credential fixture and they + temporarily switch connections to admin. Override it locally to avoid + unrelated teardown failures in the session fixture stack. + """ + return None + + +def ensure_provider_test_user_exists(username: str, password: str): + username_upper = username.upper() + with select_ai.cursor() as cr: + cr.execute( + "SELECT 1 FROM dba_users WHERE username = :username", + username=username_upper, + ) + if cr.fetchone(): + return + escaped_password = password.replace('"', '""') + cr.execute( + f'CREATE USER {username_upper} IDENTIFIED BY "{escaped_password}"' + ) + with select_ai.db.get_connection() as conn: + conn.commit() + + +def grant_provider_test_user_privileges(username: str): + username_upper = username.upper() + with select_ai.cursor() as cr: + for privilege in _BASIC_SCHEMA_PRIVILEGES: + cr.execute(f"GRANT {privilege} TO {username_upper}") + with select_ai.db.get_connection() as conn: + conn.commit() + + +async def async_ensure_provider_test_user_exists(username: str, password: str): + username_upper = username.upper() + async with select_ai.async_cursor() as cr: + await cr.execute( + "SELECT 1 FROM dba_users WHERE username = :username", + username=username_upper, + ) + if await cr.fetchone(): + return + escaped_password = password.replace('"', '""') + await cr.execute( + f'CREATE USER {username_upper} IDENTIFIED BY "{escaped_password}"' + ) + async with select_ai.db.async_get_connection() as async_connection: + await async_connection.commit() + + +async def async_grant_provider_test_user_privileges(username: str): + username_upper = username.upper() + async with select_ai.async_cursor() as cr: + for privilege in _BASIC_SCHEMA_PRIVILEGES: + await cr.execute(f"GRANT {privilege} TO {username_upper}") + async with select_ai.db.async_get_connection() as async_connection: + await async_connection.commit() diff --git a/tests/provider/test_2400_async_enable.py b/tests/provider/test_2400_async_enable.py new file mode 100644 index 0000000..338ed78 --- /dev/null +++ b/tests/provider/test_2400_async_enable.py @@ -0,0 +1,277 @@ +# ----------------------------------------------------------------------------- +# Copyright (c) 2025, Oracle and/or its affiliates. +# +# Licensed under the Universal Permissive License v 1.0 as shown at +# http://oss.oracle.com/licenses/upl. +# ----------------------------------------------------------------------------- + +import logging + +import oracledb +import pytest +import select_ai +from provider.conftest import ( + async_ensure_provider_test_user_exists, + async_grant_provider_test_user_privileges, + get_supported_provider_endpoints, +) + +logger = logging.getLogger("TestAsyncEnableProvider") + +pytestmark = pytest.mark.anyio + + +@pytest.fixture(scope="class", autouse=True) +def setup_logging(): + logging.basicConfig( + format="%(asctime)s %(levelname)s %(name)s %(message)s", + level=logging.INFO, + ) + + +@pytest.fixture(scope="class") +def provider_params(request, test_env): + request.cls.provider_params = { + "user": test_env.admin_user, + "password": test_env.admin_password, + "dsn": test_env.connect_string, + } + + +@pytest.fixture(scope="class", autouse=True) +async def setup_and_teardown(request, async_connect, provider_params, test_env): + logger.info("=== Setting up TestAsyncEnableProvider class ===") + await select_ai.async_disconnect() + await select_ai.async_connect(**test_env.connect_params(admin=True)) + assert await select_ai.async_is_connected(), "Connection to DB failed" + + cls = request.cls + cls.user = cls.provider_params["user"] + cls.password = cls.provider_params["password"] + cls.dsn = cls.provider_params["dsn"] + cls.db_users = [] + + try: + for i in range(1, 6): + user = f"DB_USER{i}" + await cls.create_local_user(user) + cls.db_users.append(user) + except Exception: + await select_ai.async_disconnect() + await select_ai.async_connect(**test_env.connect_params()) + raise + + yield + + logger.info("=== Tearing down TestAsyncEnableProvider class ===") + async with select_ai.async_cursor() as admin_cursor: + for user in cls.db_users: + try: + await admin_cursor.execute(f"DROP USER {user} CASCADE") + except oracledb.DatabaseError: + pass + try: + await select_ai.async_disconnect() + except Exception as exc: + logger.warning("Warning: disconnect failed (%s)", exc) + await select_ai.async_connect(**test_env.connect_params()) + + +@pytest.fixture(autouse=True) +def log_test_name(request): + logger.info("--- Starting test: %s ---", request.function.__name__) + yield + logger.info("--- Finished test: %s ---", request.function.__name__) + + +@pytest.mark.usefixtures("provider_params", "setup_and_teardown") +class TestAsyncEnableProvider: + @classmethod + async def create_local_user(cls, test_username="TEST_USER1"): + test_password = cls.password + async with select_ai.async_cursor() as admin_cursor: + try: + await admin_cursor.execute(f"DROP USER {test_username} CASCADE") + except oracledb.DatabaseError: + pass + await async_ensure_provider_test_user_exists(test_username, test_password) + await async_grant_provider_test_user_privileges(test_username) + await select_ai.async_grant_privileges(users=[test_username]) + + def setup_method(self): + self.provider_endpoint = "*.openai.azure.com" + self.db_users = self.__class__.db_users + + async def test_2401(self): + """Test enabling provider with valid users and endpoint.""" + try: + await select_ai.async_grant_http_access( + users=self.db_users, + provider_endpoint=self.provider_endpoint, + ) + logger.info("Provider enabled successfully for all test users.") + except Exception as exc: + pytest.fail( + f"async_grant_http_access() raised {exc} unexpectedly." + ) + + async def test_2402(self): + """Test enabling provider with a non-existent username.""" + db_users = ["DB_USER1", "TEST_USER2"] + with pytest.raises(oracledb.DatabaseError, match=r"ORA-46238") as exc_info: + await select_ai.async_grant_http_access( + users=db_users, + provider_endpoint=self.provider_endpoint, + ) + logger.info("Expected DatabaseError caught: %s", exc_info.value) + assert "TEST_USER2" in str(exc_info.value) + + async def test_2403(self): + """Test enabling provider with all non-existent usernames.""" + db_users = ["TEST_USER1", "TEST_USER2"] + with pytest.raises(oracledb.DatabaseError, match=r"ORA-46238") as exc_info: + await select_ai.async_grant_http_access( + users=db_users, + provider_endpoint=self.provider_endpoint, + ) + logger.info("Expected DatabaseError caught: %s", exc_info.value) + assert "TEST_USER1" in str(exc_info.value) + + async def test_2404(self): + """Test enabling provider with empty users list.""" + try: + await select_ai.async_grant_http_access( + users=[], + provider_endpoint=self.provider_endpoint, + ) + logger.info("Provider enabled successfully with empty users.") + except Exception as exc: + pytest.fail( + "async_grant_http_access() raised " + f"{exc} unexpectedly with empty users." + ) + + async def test_2405(self): + """Test enabling provider with users as string instead of list.""" + try: + await select_ai.async_grant_http_access( + users="DB_USER1", + provider_endpoint=self.provider_endpoint, + ) + logger.info("Provider enabled successfully with string user input.") + except Exception as exc: + pytest.fail( + f"async_grant_http_access() raised unexpected exception: {exc}" + ) + + async def test_2406(self): + """Test enabling provider with users as int - expect TypeError.""" + with pytest.raises(TypeError) as exc_info: + await select_ai.async_grant_http_access( + users=2, + provider_endpoint=self.provider_endpoint, + ) + logger.info("Expected TypeError caught: %s", exc_info.value) + + async def test_2407(self): + """Test enabling provider with missing provider_endpoint.""" + with pytest.raises(oracledb.DatabaseError, match=r"ORA-29261") as exc_info: + await select_ai.async_grant_http_access( + users=self.db_users, + provider_endpoint=None, + ) + logger.info("Expected DatabaseError caught: %s", exc_info.value) + + async def test_2408(self): + """Test enabling provider with a syntactically valid custom host.""" + try: + await select_ai.async_grant_http_access( + users=self.db_users, + provider_endpoint="invalid.endpoint", + ) + logger.info("Provider enabled successfully for custom host name.") + except Exception as exc: + pytest.fail( + "async_grant_http_access() raised " + f"{exc} unexpectedly with custom host name." + ) + + async def test_2409(self): + """Test enabling provider with duplicate usernames.""" + try: + await select_ai.async_grant_http_access( + users=[self.db_users[0], self.db_users[0]], + provider_endpoint=self.provider_endpoint, + ) + logger.info("Provider enabled successfully with duplicate users.") + except Exception as exc: + pytest.fail( + "async_grant_http_access() raised " + f"{exc} unexpectedly with duplicate users." + ) + + async def test_2410(self): + """Test enabling provider with lowercase username.""" + try: + await select_ai.async_grant_http_access( + users=[self.db_users[0].lower()], + provider_endpoint=self.provider_endpoint, + ) + logger.info("Provider enabled successfully for lowercase username.") + except Exception as exc: + pytest.fail( + "async_grant_http_access() raised " + f"{exc} unexpectedly with lowercase username." + ) + + async def test_2411(self): + """Test enabling provider with username containing whitespace.""" + db_users = [f" {self.db_users[0]} "] + try: + await select_ai.async_grant_http_access( + users=db_users, + provider_endpoint=self.provider_endpoint, + ) + logger.info("Provider enabled successfully with whitespace user.") + except Exception as exc: + pytest.fail( + "async_grant_http_access() raised " + f"{exc} unexpectedly with whitespace user." + ) + + async def test_2412(self): + """Test enabling provider with large user list.""" + db_users = [f"DB_USER_{i}" for i in range(1000)] + with pytest.raises(oracledb.DatabaseError) as exc_info: + await select_ai.async_grant_http_access( + users=db_users, + provider_endpoint=self.provider_endpoint, + ) + logger.info("Expected DatabaseError caught: %s", exc_info.value) + + async def test_2413(self): + """Test enabling provider with a valid custom endpoint.""" + with pytest.raises(oracledb.DatabaseError) as exc_info: + await select_ai.async_grant_http_access( + users=self.db_users, + provider_endpoint="https://custom.openai.azure.com", + ) + logger.info("Expected DatabaseError caught: %s", exc_info.value) + assert ( + "ORA-24244: invalid host or port for access control list " + "(ACL) assignment" in str(exc_info.value) + ) + + async def test_2414(self): + """Test enabling provider ACLs for all supported provider endpoints.""" + provider_endpoints = get_supported_provider_endpoints() + for provider_name, provider_endpoint in provider_endpoints.items(): + await select_ai.async_grant_http_access( + users=self.db_users, + provider_endpoint=provider_endpoint, + ) + logger.info( + "Provider enabled successfully for %s endpoint %s.", + provider_name, + provider_endpoint, + ) diff --git a/tests/provider/test_2400_enable.py b/tests/provider/test_2400_enable.py new file mode 100644 index 0000000..c7541ce --- /dev/null +++ b/tests/provider/test_2400_enable.py @@ -0,0 +1,272 @@ +# ----------------------------------------------------------------------------- +# Copyright (c) 2025, Oracle and/or its affiliates. +# +# Licensed under the Universal Permissive License v 1.0 as shown at +# http://oss.oracle.com/licenses/upl. +# ----------------------------------------------------------------------------- + +import logging +import oracledb +import pytest +import select_ai +from provider.conftest import ( + ensure_provider_test_user_exists, + get_supported_provider_endpoints, + grant_provider_test_user_privileges, +) + +logger = logging.getLogger("TestEnableProvider") + +@pytest.fixture(scope="class", autouse=True) +def setup_logging(): + logging.basicConfig( + format="%(asctime)s %(levelname)s %(name)s %(message)s", + level=logging.INFO + ) + +@pytest.fixture(scope="class") +def provider_params(request, test_env): + params = { + "user": test_env.admin_user, + "password": test_env.admin_password, + "dsn": test_env.connect_string, + } + request.cls.provider_params = params + +@pytest.fixture(scope="class", autouse=True) +def setup_and_teardown(request, connect, provider_params, test_env): + logger.info("=== Setting up TestEnableProvider class ===") + select_ai.disconnect() + select_ai.connect(**test_env.connect_params(admin=True)) + assert select_ai.is_connected(), "Connection to DB failed" + cls = request.cls + cls.user = cls.provider_params["user"] + cls.password = cls.provider_params["password"] + cls.dsn = cls.provider_params["dsn"] + cls.db_users = [] + try: + # Create multiple DB users (DB_USER1 ... DB_USER5) + for i in range(1, 6): + user = f"DB_USER{i}" + cls.create_local_user(user) + cls.db_users.append(user) + except Exception: + select_ai.disconnect() + select_ai.connect(**test_env.connect_params()) + raise + yield + logger.info("=== Tearing down TestEnableProvider class ===") + # Drop DB users + with select_ai.cursor() as admin_cursor: + for user in cls.db_users: + try: + admin_cursor.execute(f"DROP USER {user} CASCADE") + except oracledb.DatabaseError: + pass # Ignore if already dropped + try: + select_ai.disconnect() + except Exception as e: + logger.warning(f"Warning: disconnect failed ({e})") + select_ai.connect(**test_env.connect_params()) + +@pytest.fixture(autouse=True) +def log_test_name(request): + logger.info(f"--- Starting test: {request.function.__name__} ---") + yield + logger.info(f"--- Finished test: {request.function.__name__} ---") + +@pytest.mark.usefixtures("provider_params", "setup_and_teardown") +class TestEnableProvider: + @classmethod + def create_local_user(cls, test_username="TEST_USER1"): + test_password = cls.password + with select_ai.cursor() as admin_cursor: + try: + admin_cursor.execute(f"DROP USER {test_username} CASCADE") + except oracledb.DatabaseError: + pass + ensure_provider_test_user_exists(test_username, test_password) + grant_provider_test_user_privileges(test_username) + select_ai.grant_privileges(users=[test_username]) + + def setup_method(self, method): + logger.info(f"SetUp for {method.__name__}") + self.provider_endpoint = "*.openai.azure.com" + self.db_users = self.__class__.db_users + + # ---- TESTS ---- + + def test_2401(self): + "Test enabling provider with valid users and endpoint" + logger.info("Testing grant_http_access() with valid users and valid endpoint") + try: + select_ai.grant_http_access( + users=self.db_users, + provider_endpoint=self.provider_endpoint + ) + logger.info("Provider enabled successfully for all test users.") + except Exception as e: + logger.error(f"grant_http_access() raised {e} unexpectedly.") + pytest.fail(f"grant_http_access() raised {e} unexpectedly.") + + def test_2402(self): + "Test enabling provider with a non-existent username" + logger.info("Testing grant_http_access() with a mix of existing and non-existent usernames") + db_users = ["DB_USER1", "TEST_USER2"] + with pytest.raises(oracledb.DatabaseError, match=r"ORA-46238") as cm: + select_ai.grant_http_access( + users=db_users, + provider_endpoint=self.provider_endpoint + ) + logger.info(f"Expected DatabaseError caught: {cm.value}") + assert "TEST_USER2" in str(cm.value) + logger.info("Test for non-existent username completed.") + + def test_2403(self): + "Test enabling provider with all non-existent usernames" + logger.info("Testing grant_http_access() with all non-existent usernames") + db_users = ["TEST_USER1", "TEST_USER2"] + with pytest.raises(oracledb.DatabaseError, match=r"ORA-46238") as cm: + select_ai.grant_http_access( + users=db_users, + provider_endpoint=self.provider_endpoint + ) + logger.info(f"Expected DatabaseError caught: {cm.value}") + assert "TEST_USER1" in str(cm.value) + logger.info("Test for all non-existent usernames completed.") + + def test_2404(self): + "Test enabling provider with empty users list" + logger.info("Testing grant_http_access() with empty users list") + try: + select_ai.grant_http_access( + users=[], + provider_endpoint=self.provider_endpoint + ) + logger.info("Provider enabled successfully with empty users (expected and allowed).") + except Exception as e: + logger.error(f"grant_http_access() raised {e} unexpectedly with empty users.") + pytest.fail(f"grant_http_access() raised {e} unexpectedly with empty users.") + + def test_2405(self): + "Test enabling provider with users as string instead of list" + logger.info("Testing grant_http_access() with users as a string (should be list; verify no TypeError is raised)") + try: + select_ai.grant_http_access( + users="DB_USER1", # not a list + provider_endpoint=self.provider_endpoint + ) + logger.info("No TypeError raised. Library may accept strings for 'users'.") + except Exception as e: + logger.warning(f"Unexpected exception caught: {e}") + pytest.fail(f"grant_http_access() raised unexpected exception: {e}") + + def test_2406(self): + "Test enabling provider with users as int - expect TypeError" + logger.info("Testing grant_http_access() with users as an integer (type error expected)") + with pytest.raises(TypeError) as cm: + select_ai.grant_http_access( + users=2, # not a list + provider_endpoint=self.provider_endpoint + ) + logger.info(f"Expected TypeError caught: {cm.value}") + + def test_2407(self): + "Test enabling provider with missing provider_endpoint" + logger.info("Testing grant_http_access() with None as provider_endpoint") + with pytest.raises(oracledb.DatabaseError, match=r"ORA-29261") as cm: + select_ai.grant_http_access( + users=self.db_users, + provider_endpoint=None + ) + logger.info(f"Expected DatabaseError caught: {cm.value}") + + def test_2408(self): + "Test enabling provider with a syntactically valid custom host" + logger.info("Testing grant_http_access() with a custom host name") + try: + select_ai.grant_http_access( + users=self.db_users, + provider_endpoint="invalid.endpoint" + ) + logger.info("Provider enabled successfully for custom host name.") + except Exception as e: + logger.error(f"grant_http_access() raised {e} unexpectedly with custom host name.") + pytest.fail(f"grant_http_access() raised {e} unexpectedly with custom host name.") + + def test_2409(self): + "Test enabling provider with duplicate usernames" + logger.info("Testing grant_http_access() with duplicate usernames") + try: + select_ai.grant_http_access( + users=[self.db_users[0], self.db_users[0]], + provider_endpoint=self.provider_endpoint + ) + logger.info("Provider enabled successfully with duplicate users (expected and allowed).") + except Exception as e: + logger.error(f"grant_http_access() raised {e} unexpectedly with duplicate users.") + pytest.fail(f"grant_http_access() raised {e} unexpectedly with duplicate users.") + + def test_2410(self): + "Test enabling provider with lowercase username (case-insensitive)" + logger.info("Testing grant_http_access() with username in lowercase (should succeed on case-insensitive DB)") + try: + select_ai.grant_http_access( + users=[self.db_users[0].lower()], + provider_endpoint=self.provider_endpoint + ) + logger.info("Provider enabled successfully for lowercase username.") + except Exception as e: + logger.error(f"grant_http_access() raised {e} unexpectedly with lowercase username.") + pytest.fail(f"grant_http_access() raised {e} unexpectedly with lowercase username.") + + def test_2411(self): + "Test enabling provider with username containing whitespace" + logger.info("Testing grant_http_access() with username containing leading/trailing whitespace") + db_users = [f" {self.db_users[0]} "] + try: + select_ai.grant_http_access( + users=db_users, + provider_endpoint=self.provider_endpoint + ) + logger.info("Provider enabled successfully with username containing whitespace.") + except Exception as e: + logger.error(f"grant_http_access() raised {e} unexpectedly with whitespace in username.") + pytest.fail(f"grant_http_access() raised {e} unexpectedly with whitespace in username.") + + def test_2412(self): + "Test enabling provider with large user list" + logger.info("Testing grant_http_access() with a very large list of users (DatabaseError expected)") + db_users = [f"DB_USER_{i}" for i in range(1000)] + with pytest.raises(oracledb.DatabaseError) as cm: + select_ai.grant_http_access( + users=db_users, + provider_endpoint=self.provider_endpoint + ) + logger.info(f"Expected DatabaseError caught: {cm.value}") + + def test_2413(self): + "Test enabling provider with a valid custom endpoint (ORA-24244 expected)" + logger.info("Testing grant_http_access() with a custom endpoint (ORA-24244 expected)") + with pytest.raises(oracledb.DatabaseError) as cm: + select_ai.grant_http_access( + users=self.db_users, + provider_endpoint="https://custom.openai.azure.com" + ) + logger.info(f"Expected DatabaseError caught: {cm.value}") + assert "ORA-24244: invalid host or port for access control list (ACL) assignment" in str(cm.value) + + def test_2414(self): + "Test enabling provider ACLs for all supported provider endpoints" + logger.info("Testing grant_http_access() across supported provider endpoints") + provider_endpoints = get_supported_provider_endpoints() + for provider_name, provider_endpoint in provider_endpoints.items(): + select_ai.grant_http_access( + users=self.db_users, + provider_endpoint=provider_endpoint, + ) + logger.info( + "Provider enabled successfully for %s endpoint %s.", + provider_name, + provider_endpoint, + ) diff --git a/tests/provider/test_2500_async_disable.py b/tests/provider/test_2500_async_disable.py new file mode 100644 index 0000000..5a664d9 --- /dev/null +++ b/tests/provider/test_2500_async_disable.py @@ -0,0 +1,275 @@ +# ----------------------------------------------------------------------------- +# Copyright (c) 2025, Oracle and/or its affiliates. +# +# Licensed under the Universal Permissive License v 1.0 as shown at +# http://oss.oracle.com/licenses/upl. +# ----------------------------------------------------------------------------- + +import logging + +import oracledb +import pytest +import select_ai +from provider.conftest import ( + async_ensure_provider_test_user_exists, + async_grant_provider_test_user_privileges, + get_supported_provider_endpoints, +) + +logger = logging.getLogger("TestAsyncDisableProvider") + +pytestmark = pytest.mark.anyio + + +@pytest.fixture(scope="class", autouse=True) +def setup_logging(): + logging.basicConfig( + format="%(asctime)s %(levelname)s %(name)s %(message)s", + level=logging.INFO, + ) + + +@pytest.fixture(scope="class") +def disable_params(request, test_env): + request.cls.disable_params = { + "user": test_env.admin_user, + "password": test_env.admin_password, + "dsn": test_env.connect_string, + } + + +@pytest.fixture(scope="class", autouse=True) +async def setup_and_teardown(request, async_connect, disable_params, test_env): + logger.info("=== Setting up TestAsyncDisableProvider class ===") + await select_ai.async_disconnect() + await select_ai.async_connect(**test_env.connect_params(admin=True)) + assert await select_ai.async_is_connected(), "Connection to DB failed" + + db_users = [] + try: + for i in range(1, 6): + user = f"DB_USER{i}" + await request.cls.create_local_user(user) + db_users.append(user) + request.cls.db_users = db_users + await request.cls.create_local_user("DB_USER6") + except Exception: + await select_ai.async_disconnect() + await select_ai.async_connect(**test_env.connect_params()) + raise + + yield + + logger.info("=== Tearing down TestAsyncDisableProvider class ===") + db_users.append("DB_USER6") + async with select_ai.async_cursor() as admin_cursor: + for user in db_users: + try: + await admin_cursor.execute(f"DROP USER {user} CASCADE") + except oracledb.DatabaseError as exc: + logger.warning("Drop user failed for %s: %s", user, exc) + try: + await select_ai.async_disconnect() + except Exception as exc: + logger.warning("Warning: disconnect failed (%s)", exc) + await select_ai.async_connect(**test_env.connect_params()) + + +@pytest.fixture(autouse=True) +async def provider_enabled_state(request): + logger.info("--- Starting test: %s ---", request.function.__name__) + provider_endpoint = "*.openai.azure.com" + await select_ai.async_grant_http_access( + users=request.cls.db_users, + provider_endpoint=provider_endpoint, + ) + yield + logger.info("--- Finished test: %s ---", request.function.__name__) + + +@pytest.mark.usefixtures("disable_params", "setup_and_teardown") +class TestAsyncDisableProvider: + @classmethod + async def create_local_user(cls, test_username="TEST_USER1"): + test_password = cls.disable_params["password"] + async with select_ai.async_cursor() as admin_cursor: + try: + await admin_cursor.execute(f"DROP USER {test_username} CASCADE") + except oracledb.DatabaseError: + pass + await async_ensure_provider_test_user_exists(test_username, test_password) + await async_grant_provider_test_user_privileges(test_username) + await select_ai.async_grant_privileges(users=[test_username]) + + def setup_method(self): + self.provider_endpoint = "*.openai.azure.com" + + async def test_2501(self): + """Test disabling provider with all valid users and endpoint.""" + try: + await select_ai.async_revoke_http_access( + users=self.db_users, + provider_endpoint=self.provider_endpoint, + ) + logger.info("Provider disabled successfully for all valid users.") + except Exception as exc: + pytest.fail( + f"async_revoke_http_access() raised {exc} unexpectedly." + ) + + async def test_2502(self): + """Test disabling provider with a mix of existing and non-existent usernames.""" + db_users = ["DB_USER1", "TEST_USER2"] + with pytest.raises(oracledb.DatabaseError): + await select_ai.async_revoke_http_access( + users=db_users, + provider_endpoint=self.provider_endpoint, + ) + + async def test_2503(self): + """Test disabling provider with all invalid usernames.""" + with pytest.raises(oracledb.DatabaseError): + await select_ai.async_revoke_http_access( + users=["INVALID_USER1", "INVALID_USER2"], + provider_endpoint=self.provider_endpoint, + ) + + async def test_2504(self): + """Test disabling provider with users as integer.""" + with pytest.raises((TypeError, ValueError)): + await select_ai.async_revoke_http_access( + users=123, + provider_endpoint=self.provider_endpoint, + ) + + async def test_2505(self): + """Test disabling provider with users as string.""" + try: + await select_ai.async_revoke_http_access( + users="DB_USER1", + provider_endpoint=self.provider_endpoint, + ) + logger.info("Provider disabled successfully for string user input.") + except Exception as exc: + pytest.fail( + f"async_revoke_http_access() raised {exc} unexpectedly." + ) + + async def test_2506(self): + """Test disabling provider with users as None.""" + with pytest.raises((TypeError, ValueError)): + await select_ai.async_revoke_http_access( + users=None, + provider_endpoint=self.provider_endpoint, + ) + + async def test_2507(self): + """Test disabling provider with missing provider_endpoint.""" + with pytest.raises(oracledb.DatabaseError, match=r"ORA-29261: bad argument"): + await select_ai.async_revoke_http_access( + users=self.db_users, + provider_endpoint=None, + ) + + async def test_2508(self): + """Test disabling provider with invalid endpoint.""" + with pytest.raises(oracledb.DatabaseError): + await select_ai.async_revoke_http_access( + users=self.db_users, + provider_endpoint="invalid.endpoint", + ) + + async def test_2509(self): + """Test disabling provider with empty users list.""" + try: + await select_ai.async_revoke_http_access( + users=[], + provider_endpoint=self.provider_endpoint, + ) + logger.info("async_revoke_http_access() succeeded with empty users.") + except Exception as exc: + pytest.fail( + "async_revoke_http_access() raised " + f"{exc} unexpectedly with empty users." + ) + + async def test_2510(self): + """Test disabling provider with duplicate usernames.""" + with pytest.raises(oracledb.DatabaseError) as exc_info: + await select_ai.async_revoke_http_access( + users=[self.db_users[0], self.db_users[0]], + provider_endpoint=self.provider_endpoint, + ) + assert "ORA-01927" in str(exc_info.value) + + async def test_2511(self): + """Test disabling provider with lowercase username.""" + try: + await select_ai.async_revoke_http_access( + users=[self.db_users[0].lower()], + provider_endpoint=self.provider_endpoint, + ) + logger.info("Provider disabled successfully for lowercase username.") + except Exception as exc: + pytest.fail( + "async_revoke_http_access() raised " + f"{exc} unexpectedly with lowercase username." + ) + + async def test_2512(self): + """Test disabling provider with username containing whitespace.""" + db_users = [f" {self.db_users[0]} "] + with pytest.raises(oracledb.DatabaseError) as exc_info: + await select_ai.async_revoke_http_access( + users=db_users, + provider_endpoint=self.provider_endpoint, + ) + assert "ORA-01927" in str(exc_info.value) + + async def test_2513(self): + """Test disabling provider with valid custom endpoint.""" + with pytest.raises( + oracledb.DatabaseError, + match=r"ORA-24244: invalid host or port for access control list \(ACL\) assignment", + ): + await select_ai.async_revoke_http_access( + users=self.db_users, + provider_endpoint="https://custom.openai.azure.com", + ) + + async def test_2514(self): + """Test disabling provider with non-granted user.""" + with pytest.raises(oracledb.DatabaseError) as exc_info: + await select_ai.async_revoke_http_access( + users=["DB_USER6"], + provider_endpoint=self.provider_endpoint, + ) + assert "ORA-01927" in str(exc_info.value) + + async def test_2515(self): + """Test disabling provider with a large user list.""" + db_users = [f"DB_USER_{i}" for i in range(1000)] + with pytest.raises(oracledb.DatabaseError): + await select_ai.async_revoke_http_access( + users=db_users, + provider_endpoint=self.provider_endpoint, + ) + + async def test_2516(self): + """Test disabling provider ACLs for all supported provider endpoints.""" + provider_endpoints = get_supported_provider_endpoints() + for provider_name, provider_endpoint in provider_endpoints.items(): + if provider_endpoint != self.provider_endpoint: + await select_ai.async_grant_http_access( + users=self.db_users, + provider_endpoint=provider_endpoint, + ) + await select_ai.async_revoke_http_access( + users=self.db_users, + provider_endpoint=provider_endpoint, + ) + logger.info( + "Provider disabled successfully for %s endpoint %s.", + provider_name, + provider_endpoint, + ) diff --git a/tests/provider/test_2500_disable.py b/tests/provider/test_2500_disable.py new file mode 100644 index 0000000..d4fdcc1 --- /dev/null +++ b/tests/provider/test_2500_disable.py @@ -0,0 +1,284 @@ +# ----------------------------------------------------------------------------- +# Copyright (c) 2025, Oracle and/or its affiliates. +# +# Licensed under the Universal Permissive License v 1.0 as shown at +# http://oss.oracle.com/licenses/upl. +# ----------------------------------------------------------------------------- + +import logging +import oracledb +import pytest +import select_ai +from provider.conftest import ( + ensure_provider_test_user_exists, + get_supported_provider_endpoints, + grant_provider_test_user_privileges, +) + +logger = logging.getLogger("TestDisableProvider") + +@pytest.fixture(scope="class", autouse=True) +def setup_logging(): + logging.basicConfig( + format="%(asctime)s %(levelname)s %(name)s %(message)s", + level=logging.INFO + ) + +@pytest.fixture(scope="class") +def disable_params(request, test_env): + params = { + "user": test_env.admin_user, + "password": test_env.admin_password, + "dsn": test_env.connect_string, + } + request.cls.disable_params = params + +@pytest.fixture(scope="class", autouse=True) +def setup_and_teardown(request, connect, disable_params, test_env): + logger.info("\n=== Setting up TestDisableProvider class ===") + select_ai.disconnect() + select_ai.connect(**test_env.connect_params(admin=True)) + assert select_ai.is_connected(), "Connection to DB failed" + db_users = [] + try: + for i in range(1, 6): + user = f"DB_USER{i}" + request.cls.create_local_user(user) + db_users.append(user) + request.cls.db_users = db_users + # Create Additional user + request.cls.create_local_user("DB_USER6") + except Exception: + select_ai.disconnect() + select_ai.connect(**test_env.connect_params()) + raise + logger.info("Setup complete.\n") + yield + + logger.info("\n=== Tearing down TestDisableProvider class ===") + db_users.append("DB_USER6") + with select_ai.cursor() as admin_cursor: + for user in db_users: + try: + admin_cursor.execute(f"DROP USER {user} CASCADE") + logger.info(f"Dropped user {user}") + except oracledb.DatabaseError as e: + logger.warning(f"Disconnect failed: {e}") + try: + select_ai.disconnect() + except Exception as e: + logger.warning(f"Warning: disconnect failed ({e})") + select_ai.connect(**test_env.connect_params()) + +@pytest.fixture(autouse=True) +def log_test_name(request): + logger.info(f"--- Starting test: {request.function.__name__} ---") + yield + logger.info(f"--- Finished test: {request.function.__name__} ---") + +@pytest.mark.usefixtures("disable_params", "setup_and_teardown") +class TestDisableProvider: + + @classmethod + def create_local_user(cls, test_username="TEST_USER1"): + logger.info(f"Creating local user: {test_username}") + test_password = cls.disable_params["password"] + with select_ai.cursor() as admin_cursor: + try: + admin_cursor.execute(f"DROP USER {test_username} CASCADE") + except oracledb.DatabaseError: + pass # Ignore if user doesn't exist + ensure_provider_test_user_exists(test_username, test_password) + grant_provider_test_user_privileges(test_username) + select_ai.grant_privileges(users=[test_username]) + logger.info(f"User {test_username} created successfully.") + + def setup_method(self, method): + logger.info(f"\n--- Starting test: {method.__name__} ---") + self.provider_endpoint = "*.openai.azure.com" + try: + select_ai.grant_http_access( + users=self.db_users, + provider_endpoint=self.provider_endpoint + ) + logger.info(f"Provider enabled for {len(self.db_users)} users.") + except Exception as e: + pytest.fail(f"grant_http_access() raised {e} unexpectedly.") + + def teardown_method(self, method): + logger.info(f"--- Finished test: {method.__name__} ---") + + # === TEST CASES === + + def test_2501(self): + "Test disabling provider with all valid users and endpoint" + try: + select_ai.revoke_http_access( + users=self.db_users, + provider_endpoint=self.provider_endpoint + ) + logger.info("Provider disabled successfully for all valid users.") + except Exception as e: + pytest.fail(f"revoke_http_access() raised {e} unexpectedly.") + + def test_2502(self): + "Test disabling provider with a mix of existing and non-existent usernames" + db_users = ["DB_USER1", "TEST_USER2"] + with pytest.raises(oracledb.DatabaseError): + select_ai.revoke_http_access( + users=db_users, + provider_endpoint=self.provider_endpoint + ) + logger.info("Caught expected DatabaseError for nonexistent user.") + + def test_2503(self): + "Test disabling provider with all invalid usernames" + with pytest.raises(oracledb.DatabaseError): + select_ai.revoke_http_access( + users=["INVALID_USER1", "INVALID_USER2"], + provider_endpoint=self.provider_endpoint + ) + logger.info("Caught expected DatabaseError for invalid users input.") + + def test_2504(self): + "Test disabling provider with users as integer (TypeError/ValueError expected)" + with pytest.raises((TypeError, ValueError)): + select_ai.revoke_http_access( + users=123, + provider_endpoint=self.provider_endpoint + ) + logger.info("Caught expected TypeError/ValueError for int users input.") + + def test_2505(self): + "Test disabling provider with users as string" + try: + select_ai.revoke_http_access( + users="DB_USER1", + provider_endpoint=self.provider_endpoint + ) + logger.info("Provider disabled successfully for string user input.") + except Exception as e: + pytest.fail(f"revoke_http_access() raised {e} unexpectedly.") + + def test_2506(self): + "Test disabling provider with users as None (TypeError/ValueError expected)" + with pytest.raises((TypeError, ValueError)): + select_ai.revoke_http_access( + users=None, + provider_endpoint=self.provider_endpoint + ) + logger.info("Caught expected TypeError/ValueError for none users input.") + + def test_2507(self): + "Test disabling provider with missing provider_endpoint (ORA-29261 expected)" + with pytest.raises(oracledb.DatabaseError, match=r"ORA-29261: bad argument"): + select_ai.revoke_http_access( + users=self.db_users, + provider_endpoint=None + ) + logger.info("Caught expected ORA-29261 for missing endpoint.") + + def test_2508(self): + "Test disabling provider with invalid endpoint (DatabaseError expected)" + with pytest.raises(oracledb.DatabaseError): + select_ai.revoke_http_access( + users=self.db_users, + provider_endpoint="invalid.endpoint" + ) + logger.info("Caught expected DatabaseError for invalid endpoint.") + + def test_2509(self): + "Test disabling provider with empty users list" + try: + select_ai.revoke_http_access( + users=[], + provider_endpoint=self.provider_endpoint + ) + logger.info("revoke_http_access() succeeded with empty users list.") + except Exception as e: + pytest.fail(f"revoke_http_access() raised {e} unexpectedly with empty users list.") + + def test_2510(self): + "Test disabling provider with duplicate usernames (ORA-01927 expected)" + with pytest.raises(oracledb.DatabaseError) as cm: + select_ai.revoke_http_access( + users=[self.db_users[0], self.db_users[0]], + provider_endpoint=self.provider_endpoint + ) + assert "ORA-01927" in str(cm.value) + logger.info("Caught expected ORA-01927 for duplicate users.") + + def test_2511(self): + "Test disabling provider with lowercase username" + try: + select_ai.revoke_http_access( + users=[self.db_users[0].lower()], + provider_endpoint=self.provider_endpoint + ) + logger.info("revoke_http_access() succeeded with lowercase username.") + except Exception as e: + pytest.fail(f"revoke_http_access() raised {e} unexpectedly with lowercase username.") + + def test_2512(self): + "Test disabling provider with username containing whitespace" + db_users = [f" {self.db_users[0]} "] + with pytest.raises(oracledb.DatabaseError) as cm: + select_ai.revoke_http_access( + users=db_users, + provider_endpoint=self.provider_endpoint + ) + assert "ORA-01927" in str(cm.value) + logger.info("Caught expected ORA-01927 for whitespace username.") + + def test_2513(self): + "Test disabling provider with valid custom endpoint (ORA-24244 expected)" + with pytest.raises( + oracledb.DatabaseError, + match=r"ORA-24244: invalid host or port for access control list \(ACL\) assignment" + ): + select_ai.revoke_http_access( + users=self.db_users, + provider_endpoint="https://custom.openai.azure.com" + ) + logger.info("Caught expected ORA-24244 for custom endpoint.") + + def test_2514(self): + "Test disabling provider with non-granted user (ORA-01927 expected)" + non_granted_user = "DB_USER6" + with pytest.raises(oracledb.DatabaseError) as cm: + select_ai.revoke_http_access( + users=[non_granted_user], + provider_endpoint=self.provider_endpoint + ) + assert "ORA-01927" in str(cm.value) + logger.info("Caught expected ORA-01927 for non-granted user.") + + def test_2515(self): + "Test disabling provider with a large user list (DatabaseError expected)" + db_users = [f"DB_USER_{i}" for i in range(1000)] + with pytest.raises(oracledb.DatabaseError): + select_ai.revoke_http_access( + users=db_users, + provider_endpoint=self.provider_endpoint + ) + logger.info("Caught expected DatabaseError for large user list.") + + def test_2516(self): + "Test disabling provider ACLs for all supported provider endpoints" + logger.info("Testing revoke_http_access() across supported provider endpoints") + provider_endpoints = get_supported_provider_endpoints() + for provider_name, provider_endpoint in provider_endpoints.items(): + if provider_endpoint != self.provider_endpoint: + select_ai.grant_http_access( + users=self.db_users, + provider_endpoint=provider_endpoint, + ) + select_ai.revoke_http_access( + users=self.db_users, + provider_endpoint=provider_endpoint, + ) + logger.info( + "Provider disabled successfully for %s endpoint %s.", + provider_name, + provider_endpoint, + ) diff --git a/tests/test_1010_async_connection.py b/tests/test_1010_async_connection.py new file mode 100644 index 0000000..ce3f2fd --- /dev/null +++ b/tests/test_1010_async_connection.py @@ -0,0 +1,378 @@ +# ----------------------------------------------------------------------------- +# Copyright (c) 2025, Oracle and/or its affiliates. +# +# Licensed under the Universal Permissive License v 1.0 as shown at +# http://oss.oracle.com/licenses/upl. +# ----------------------------------------------------------------------------- + +import logging + +import oracledb +import pytest +import select_ai +from select_ai.errors import DatabaseNotConnectedError + +logger = logging.getLogger("TestAsyncConnection") + +pytestmark = pytest.mark.anyio + + +@pytest.fixture(scope="module") +def oci_credential(): + """ + These connection tests do not use the shared OCI credential fixture. + Override it locally so the module only depends on DB connectivity. + """ + return None + + +@pytest.fixture(scope="session", autouse=True) +def setup_logging(): + logging.basicConfig( + format="%(asctime)s %(levelname)s %(name)s %(message)s", + level=logging.INFO, + ) + + +@pytest.fixture(scope="class") +def connection_params(request, test_env): + params = { + "user": test_env.test_user, + "password": test_env.test_user_password, + "dsn": test_env.connect_string, + "wallet_location": test_env.wallet_location, + "wallet_password": test_env.wallet_password, + } + request.cls.connection_params = params + + +@pytest.fixture(scope="class") +def async_connect_as(test_env): + async def _connect_as(*, admin=False, **overrides): + await select_ai.async_disconnect() + if admin: + connect_kwargs = dict(test_env.connect_params(admin=True)) + else: + connect_kwargs = dict(test_env.connect_params()) + connect_kwargs.update(overrides) + await select_ai.async_connect(**connect_kwargs) + + return _connect_as + + +@pytest.fixture(scope="class", autouse=True) +async def setup_and_teardown( + request, async_connect, connection_params, async_connect_as +): + logger.info("=== Setting up TestAsyncConnection class ===") + request.cls.async_connect_as = staticmethod(async_connect_as) + await async_connect_as() + assert await select_ai.async_is_connected(), "Connection to DB failed" + logger.info("Initial connection successful") + yield + logger.info("=== Tearing down TestAsyncConnection class ===") + try: + await async_connect_as() + logger.info("Restored default DB connection") + except Exception as exc: + logger.warning("Warning: disconnect failed (%s)", exc) + + +@pytest.fixture(autouse=True) +async def log_test_name(request): + logger.info("--- Starting test: %s ---", request.function.__name__) + yield + if request.instance is not None: + await request.instance.async_connect_as() + logger.info("--- Finished test: %s ---", request.function.__name__) + + +@pytest.mark.usefixtures("connection_params", "setup_and_teardown") +class TestAsyncConnection: + async def test_1011(self): + """Testing connection success with wallet.""" + logger.info("Testing connection success with wallet...") + await self.async_connect_as() + assert await select_ai.async_is_connected(), "Connection to DB failed" + logger.info("Connection successful") + await select_ai.async_disconnect() + assert not await select_ai.async_is_connected() + logger.info("Disconnected after test_1011") + + async def test_1012(self): + """Testing connection without wallet.""" + logger.info("Testing connection without wallet...") + if self.connection_params["wallet_location"]: + with pytest.raises(oracledb.DatabaseError, match=r"DPY-4027"): + await select_ai.async_disconnect() + await select_ai.async_connect( + user=self.connection_params["user"], + password=self.connection_params["password"], + dsn=self.connection_params["dsn"], + ) + logger.info( + "Wallet-less connect correctly failed in wallet-based setup." + ) + else: + await select_ai.async_disconnect() + await select_ai.async_connect( + user=self.connection_params["user"], + password=self.connection_params["password"], + dsn=self.connection_params["dsn"], + ) + assert await select_ai.async_is_connected(), ( + "Connection to DB failed without wallet" + ) + logger.info("Connection successful without wallet") + await select_ai.async_disconnect() + assert not await select_ai.async_is_connected() + logger.info("Disconnected after test_1012") + + async def test_1013(self): + """Testing is_connected returns bool.""" + logger.info("Testing is_connected returns bool...") + await self.async_connect_as() + assert isinstance(await select_ai.async_is_connected(), bool) + await select_ai.async_disconnect() + logger.info("is_connected check complete and disconnected") + + async def test_1014(self): + """Testing failure with wrong password.""" + logger.info("Testing failure with wrong password...") + connect_kwargs = {} + if self.connection_params["wallet_location"]: + connect_kwargs = { + "config_dir": self.connection_params["wallet_location"], + "wallet_location": self.connection_params["wallet_location"], + "wallet_password": self.connection_params["wallet_password"], + } + with pytest.raises(oracledb.DatabaseError): + await select_ai.async_connect( + user=self.connection_params["user"], + password="wrong_pass", + dsn=self.connection_params["dsn"], + **connect_kwargs, + ) + logger.info("Correctly raised DatabaseError for wrong password") + + async def test_1015(self): + """Testing connection with bad string.""" + logger.info("Testing connection with bad string...") + with pytest.raises(TypeError) as exc_info: + await select_ai.async_connect("not a valid connect string!!") + assert "missing 2 required positional arguments" in str(exc_info.value) + logger.info("Correctly raised TypeError for bad string") + + async def test_1016(self): + """Testing connection with bad DSN.""" + logger.info("Testing connection with bad DSN...") + connect_kwargs = {} + if self.connection_params["wallet_location"]: + connect_kwargs = { + "config_dir": self.connection_params["wallet_location"], + "wallet_location": self.connection_params["wallet_location"], + "wallet_password": self.connection_params["wallet_password"], + } + with pytest.raises(oracledb.DatabaseError) as exc_info: + await select_ai.async_connect( + user=self.connection_params["user"], + password=self.connection_params["password"], + dsn="invalid_dsn", + **connect_kwargs, + ) + msg = str(exc_info.value) + logger.info("Database exception message was: %s", msg) + assert ("DPY-4000" in msg) or ("DPY-4026" in msg) or ("DPY-4027" in msg) + logger.info("Correctly raised DatabaseError for bad DSN") + + async def test_1017(self): + """Testing connection with bad password.""" + logger.info("Testing connection with bad password...") + connect_kwargs = {} + if self.connection_params["wallet_location"]: + connect_kwargs = { + "config_dir": self.connection_params["wallet_location"], + "wallet_location": self.connection_params["wallet_location"], + "wallet_password": self.connection_params["wallet_password"], + } + with pytest.raises(oracledb.DatabaseError) as exc_info: + await select_ai.async_connect( + user=self.connection_params["user"], + password=self.connection_params["password"] + "X", + dsn=self.connection_params["dsn"], + **connect_kwargs, + ) + assert "ORA-01017" in str(exc_info.value) + logger.info("Correctly raised DatabaseError for wrong password") + + async def test_1018(self): + """Testing simple query execution.""" + logger.info("Testing simple query execution...") + await self.async_connect_as() + async with select_ai.async_cursor() as cr: + await cr.execute("SELECT 1 FROM DUAL") + result = await cr.fetchone() + assert result[0] == 1 + logger.info("Query executed successfully, result: %s", result[0]) + + async def test_1019(self): + """Testing query with parameters.""" + logger.info("Testing query with parameters...") + async with select_ai.async_cursor() as cr: + await cr.execute("SELECT :val FROM dual", val=42) + result = await cr.fetchone() + assert result[0] == 42 + logger.info("Query with parameters successful, result: %s", result[0]) + + async def test_1020(self): + """Testing fetchall.""" + logger.info("Testing fetchall...") + async with select_ai.async_cursor() as cursor: + await cursor.execute("SELECT level FROM dual CONNECT BY level <= 5") + results = await cursor.fetchall() + assert len(results) == 5 + logger.info("Fetched rows: %s", len(results)) + + async def test_1021(self): + """Testing invalid query.""" + logger.info("Testing invalid query...") + async with select_ai.async_cursor() as cursor: + with pytest.raises(oracledb.DatabaseError): + await cursor.execute("SELECT * FROM non_existent_table") + logger.info("Correctly raised DatabaseError for invalid query") + + async def test_1022(self): + """Testing commit and rollback.""" + logger.info("Testing commit and rollback...") + async with select_ai.async_cursor() as cursor: + await cursor.execute( + """ + begin + execute immediate 'create table test_cr_tab_async (id int)'; + exception + when others then + if sqlcode != -955 then + raise; + end if; + end; + """ + ) + async with select_ai.db.async_get_connection() as async_connection: + await async_connection.commit() + async with select_ai.async_cursor() as cursor: + await cursor.execute("truncate table test_cr_tab_async") + await cursor.execute("insert into test_cr_tab_async values (1)") + async with select_ai.db.async_get_connection() as async_connection: + await async_connection.rollback() + async with select_ai.async_cursor() as cursor: + await cursor.execute("select count(*) from test_cr_tab_async") + (count,) = await cursor.fetchone() + assert count == 0 + logger.info("Rollback verified successfully") + + async def test_1023(self): + """Testing connection close error.""" + logger.info("Testing connection close error...") + await select_ai.async_disconnect() + with pytest.raises(DatabaseNotConnectedError): + async with select_ai.async_cursor() as cr: + await cr.execute("SELECT 1 FROM DUAL") + logger.info( + "DatabaseNotConnectedError correctly raised on disconnected cursor" + ) + + async def test_1024(self): + """Testing repeated disconnect.""" + logger.info("Testing repeated disconnect...") + await self.async_connect_as() + await select_ai.async_disconnect() + await select_ai.async_disconnect() + assert not await select_ai.async_is_connected() + logger.info("Repeated disconnect handled successfully") + + async def test_1025(self): + """Testing DBMS_OUTPUT package.""" + logger.info("Testing DBMS_OUTPUT package...") + await self.async_connect_as() + test_string = "Testing DBMS_OUTPUT package" + async with select_ai.async_cursor() as cursor: + await cursor.callproc("dbms_output.enable") + await cursor.callproc("dbms_output.put_line", [test_string]) + string_var = cursor.var(str) + number_var = cursor.var(int) + await cursor.callproc("dbms_output.get_line", (string_var, number_var)) + assert string_var.getvalue() == test_string + logger.info("DBMS_OUTPUT verified: %s", string_var.getvalue()) + + async def test_1026(self): + """Testing instance name retrieval.""" + logger.info("Testing instance name retrieval...") + async with select_ai.async_cursor() as cursor: + await cursor.execute( + "select upper(sys_context('userenv', 'instance_name')) from dual" + ) + (instance_name,) = await cursor.fetchone() + assert isinstance(instance_name, str) + logger.info("Instance name: %s", instance_name) + + async def test_1027(self): + """Testing max open cursors.""" + logger.info("Testing max open cursors...") + await self.async_connect_as(admin=True) + async with select_ai.async_cursor() as cursor: + await cursor.execute( + "select value from V$PARAMETER where name='open_cursors'" + ) + (max_open_cursors,) = await cursor.fetchone() + assert int(max_open_cursors) == 1000 + logger.info("Max open cursors: %s", max_open_cursors) + + async def test_1028(self): + """Testing service name retrieval.""" + logger.info("Testing service name retrieval...") + async with select_ai.async_cursor() as cursor: + await cursor.execute( + "select sys_context('userenv', 'service_name') from dual" + ) + (service_name,) = await cursor.fetchone() + assert isinstance(service_name, str) + logger.info("Service name: %s", service_name) + + async def test_1029(self): + """Testing user and table creation.""" + logger.info("Testing user and table creation...") + test_username = "TEST_USER2" + test_password = self.connection_params["password"] + await self.async_connect_as(admin=True) + async with select_ai.async_cursor() as admin_cursor: + try: + await admin_cursor.execute(f"DROP USER {test_username} CASCADE") + except oracledb.DatabaseError: + logger.info("User %s did not exist before test", test_username) + await admin_cursor.execute( + f'CREATE USER {test_username} IDENTIFIED BY "{test_password}"' + ) + await admin_cursor.execute( + f"grant create session, create table, unlimited tablespace to {test_username}" + ) + logger.info("Created test user: %s", test_username) + + await self.async_connect_as( + user=test_username, + password=test_password, + dsn=self.connection_params["dsn"], + ) + async with select_ai.async_cursor() as test_cursor: + await test_cursor.execute("CREATE TABLE test_table_async (id INT)") + await test_cursor.execute( + "INSERT INTO test_table_async (id) VALUES (100)" + ) + await test_cursor.execute("SELECT id FROM test_table_async") + result = await test_cursor.fetchone() + assert result[0] == 100 + logger.info("Test table created and verified successfully") + + await self.async_connect_as(admin=True) + async with select_ai.async_cursor() as admin_cursor: + await admin_cursor.execute(f"DROP USER {test_username} CASCADE") + logger.info("Dropped test user: %s", test_username) diff --git a/tests/test_1010_connection.py b/tests/test_1010_connection.py new file mode 100644 index 0000000..88209ff --- /dev/null +++ b/tests/test_1010_connection.py @@ -0,0 +1,359 @@ +# ----------------------------------------------------------------------------- +# Copyright (c) 2025, Oracle and/or its affiliates. +# +# Licensed under the Universal Permissive License v 1.0 as shown at +# http://oss.oracle.com/licenses/upl. +# ----------------------------------------------------------------------------- + +import logging + +import oracledb +import pytest +import select_ai +from select_ai.errors import DatabaseNotConnectedError + +logger = logging.getLogger("TestConnection") + + +@pytest.fixture(scope="module") +def oci_credential(): + """ + These connection tests do not use the shared OCI credential fixture. + Override it locally so the module only depends on DB connectivity. + """ + return None + + +@pytest.fixture(scope="session", autouse=True) +def setup_logging(): + logging.basicConfig( + format="%(asctime)s %(levelname)s %(name)s %(message)s", + level=logging.INFO, + ) + + +@pytest.fixture(scope="class") +def connection_params(request, test_env): + params = { + "user": test_env.test_user, + "password": test_env.test_user_password, + "dsn": test_env.connect_string, + "wallet_location": test_env.wallet_location, + "wallet_password": test_env.wallet_password, + "admin_user": test_env.admin_user, + "admin_password": test_env.admin_password, + } + request.cls.connection_params = params + + +@pytest.fixture(scope="class") +def connect_as(test_env): + def _connect_as(*, admin=False, **overrides): + select_ai.disconnect() + if admin: + connect_kwargs = dict(test_env.connect_params(admin=True)) + else: + connect_kwargs = dict(test_env.connect_params()) + connect_kwargs.update(overrides) + select_ai.connect(**connect_kwargs) + + return _connect_as + + +@pytest.fixture(scope="class", autouse=True) +def setup_and_teardown(request, connection_params, connect_as): + logger.info("=== Setting up TestConnection class ===") + request.cls.connect_as = staticmethod(connect_as) + connect_as() + assert select_ai.is_connected(), "Connection to DB failed" + logger.info("Initial connection successful") + yield + logger.info("=== Tearing down TestConnection class ===") + try: + connect_as() + logger.info("Restored default DB connection") + except Exception as exc: + logger.warning("Warning: disconnect failed (%s)", exc) + + +@pytest.fixture(autouse=True) +def log_test_name(request): + logger.info("--- Starting test: %s ---", request.function.__name__) + yield + if request.instance is not None: + request.instance.connect_as() + logger.info("--- Finished test: %s ---", request.function.__name__) + +@pytest.mark.usefixtures("connection_params", "setup_and_teardown") +class TestConnection: + def test_1011(self): + """Testing connection success with wallet.""" + logger.info("Testing connection success with wallet...") + self.connect_as() + assert select_ai.is_connected(), "Connection to DB failed" + logger.info("Connection successful") + select_ai.disconnect() + assert not select_ai.is_connected() + logger.info("Disconnected after test_1011") + + def test_1012(self): + """Testing connection without wallet.""" + logger.info("Testing connection without wallet...") + if self.connection_params["wallet_location"]: + with pytest.raises(oracledb.DatabaseError, match=r"DPY-4027"): + select_ai.disconnect() + select_ai.connect( + user=self.connection_params["user"], + password=self.connection_params["password"], + dsn=self.connection_params["dsn"], + ) + logger.info("Wallet-less connect correctly failed in wallet-based setup.") + else: + select_ai.disconnect() + select_ai.connect( + user=self.connection_params["user"], + password=self.connection_params["password"], + dsn=self.connection_params["dsn"], + ) + assert select_ai.is_connected(), "Connection to DB failed without wallet" + logger.info("Connection successful without wallet") + select_ai.disconnect() + assert not select_ai.is_connected() + logger.info("Disconnected after test_1012") + + def test_1013(self): + """Testing is_connected returns bool.""" + logger.info("Testing is_connected returns bool...") + self.connect_as() + assert isinstance(select_ai.is_connected(), bool) + select_ai.disconnect() + logger.info("is_connected check complete and disconnected") + + def test_1014(self): + """Testing failure with wrong password.""" + logger.info("Testing failure with wrong password...") + connect_kwargs = {} + if self.connection_params["wallet_location"]: + connect_kwargs = { + "config_dir": self.connection_params["wallet_location"], + "wallet_location": self.connection_params["wallet_location"], + "wallet_password": self.connection_params["wallet_password"], + } + with pytest.raises(oracledb.DatabaseError): + select_ai.connect( + user=self.connection_params["user"], + password="wrong_pass", + dsn=self.connection_params["dsn"], + **connect_kwargs + ) + logger.info("Correctly raised DatabaseError for wrong password") + + def test_1015(self): + """Testing connection with bad string.""" + logger.info("Testing connection with bad string...") + with pytest.raises(TypeError) as e: + select_ai.connect("not a valid connect string!!") + assert "missing 2 required positional arguments" in str(e.value) + logger.info("Correctly raised TypeError for bad string") + + def test_1016(self): + """Testing connection with bad DSN.""" + logger.info("Testing connection with bad DSN...") + connect_kwargs = {} + if self.connection_params["wallet_location"]: + connect_kwargs = { + "config_dir": self.connection_params["wallet_location"], + "wallet_location": self.connection_params["wallet_location"], + "wallet_password": self.connection_params["wallet_password"], + } + with pytest.raises(oracledb.DatabaseError) as excinfo: + select_ai.connect( + user=self.connection_params["user"], + password=self.connection_params["password"], + dsn="invalid_dsn", + **connect_kwargs + ) + msg = str(excinfo.value) + logger.info("Database exception message was: %s", msg) + assert ("DPY-4000" in msg) or ("DPY-4026" in msg) or ("DPY-4027" in msg) + logger.info("Correctly raised DatabaseError for bad DSN") + + def test_1017(self): + """Testing connection with bad password.""" + logger.info("Testing connection with bad password...") + connect_kwargs = {} + if self.connection_params["wallet_location"]: + connect_kwargs = { + "config_dir": self.connection_params["wallet_location"], + "wallet_location": self.connection_params["wallet_location"], + "wallet_password": self.connection_params["wallet_password"], + } + with pytest.raises(oracledb.DatabaseError) as excinfo: + select_ai.connect( + user=self.connection_params["user"], + password=self.connection_params["password"] + "X", + dsn=self.connection_params["dsn"], + **connect_kwargs + ) + assert "ORA-01017" in str(excinfo.value) + logger.info("Correctly raised DatabaseError for wrong password") + + def test_1018(self): + """Testing simple query execution.""" + logger.info("Testing simple query execution...") + self.connect_as() + with select_ai.cursor() as cr: + cr.execute("SELECT 1 FROM DUAL") + result = cr.fetchone() + assert result[0] == 1 + logger.info("Query executed successfully, result: %s", result[0]) + + def test_1019(self): + """Testing query with parameters.""" + logger.info("Testing query with parameters...") + with select_ai.cursor() as cr: + cr.execute("SELECT :val FROM dual", val=42) + result = cr.fetchone() + assert result[0] == 42 + logger.info("Query with parameters successful, result: %s", result[0]) + + def test_1020(self): + """Testing fetchall.""" + logger.info("Testing fetchall...") + with select_ai.cursor() as cursor: + cursor.execute("SELECT level FROM dual CONNECT BY level <= 5") + results = cursor.fetchall() + assert len(results) == 5 + logger.info("Fetched rows: %s", len(results)) + + def test_1021(self): + """Testing invalid query.""" + logger.info("Testing invalid query...") + with select_ai.cursor() as cursor: + with pytest.raises(oracledb.DatabaseError): + cursor.execute("SELECT * FROM non_existent_table") + logger.info("Correctly raised DatabaseError for invalid query") + + def test_1022(self): + """Testing commit and rollback.""" + logger.info("Testing commit and rollback...") + with select_ai.cursor() as cursor: + cursor.execute(""" + begin + execute immediate 'create table test_cr_tab (id int)'; + exception + when others then + if sqlcode != -955 then + raise; + end if; + end; + """) + cursor.execute("commit") + cursor.execute("truncate table test_cr_tab") + cursor.execute("insert into test_cr_tab values (1)") + cursor.execute("rollback") + cursor.execute("select count(*) from test_cr_tab") + (count,) = cursor.fetchone() + assert count == 0 + logger.info("Rollback verified successfully") + + def test_1023(self): + """Testing connection close error.""" + logger.info("Testing connection close error...") + select_ai.disconnect() + with pytest.raises(DatabaseNotConnectedError): + with select_ai.cursor() as cr: + cr.execute("SELECT 1 FROM DUAL") + logger.info("DatabaseNotConnectedError correctly raised on disconnected cursor") + + def test_1024(self): + """Testing repeated disconnect.""" + logger.info("Testing repeated disconnect...") + self.connect_as() + select_ai.disconnect() + select_ai.disconnect() + assert not select_ai.is_connected() + logger.info("Repeated disconnect handled successfully") + + def test_1025(self): + """Testing DBMS_OUTPUT package.""" + logger.info("Testing DBMS_OUTPUT package...") + self.connect_as() + test_string = "Testing DBMS_OUTPUT package" + with select_ai.cursor() as cursor: + cursor.callproc("dbms_output.enable") + cursor.callproc("dbms_output.put_line", [test_string]) + string_var = cursor.var(str) + number_var = cursor.var(int) + cursor.callproc("dbms_output.get_line", (string_var, number_var)) + assert string_var.getvalue() == test_string + logger.info("DBMS_OUTPUT verified: %s", string_var.getvalue()) + + def test_1026(self): + """Testing instance name retrieval.""" + logger.info("Testing instance name retrieval...") + with select_ai.cursor() as cursor: + cursor.execute( + "select upper(sys_context('userenv', 'instance_name')) from dual" + ) + (instance_name,) = cursor.fetchone() + assert isinstance(instance_name, str) + logger.info("Instance name: %s", instance_name) + + def test_1027(self): + """Testing max open cursors.""" + logger.info("Testing max open cursors...") + self.connect_as(admin=True) + with select_ai.cursor() as cursor: + cursor.execute( + "select value from V$PARAMETER where name='open_cursors'" + ) + (max_open_cursors,) = cursor.fetchone() + assert int(max_open_cursors) == 1000 + logger.info("Max open cursors: %s", max_open_cursors) + + def test_1028(self): + """Testing service name retrieval.""" + logger.info("Testing service name retrieval...") + with select_ai.cursor() as cursor: + cursor.execute( + "select sys_context('userenv', 'service_name') from dual" + ) + (service_name,) = cursor.fetchone() + assert isinstance(service_name, str) + logger.info("Service name: %s", service_name) + + def test_1029(self): + """Testing user and table creation.""" + logger.info("Testing user and table creation...") + test_username = "TEST_USER1" + test_password = self.connection_params["password"] + self.connect_as(admin=True) + with select_ai.cursor() as admin_cursor: + try: + admin_cursor.execute(f"DROP USER {test_username} CASCADE") + except oracledb.DatabaseError: + logger.info("User %s did not exist before test", test_username) + admin_cursor.execute( + f'CREATE USER {test_username} IDENTIFIED BY "{test_password}"' + ) + admin_cursor.execute( + f"grant create session, create table, unlimited tablespace to {test_username}" + ) + logger.info("Created test user: %s", test_username) + self.connect_as( + user=test_username, + password=test_password, + dsn=self.connection_params["dsn"], + ) + with select_ai.cursor() as test_cursor: + test_cursor.execute("CREATE TABLE test_table (id INT)") + test_cursor.execute("INSERT INTO test_table (id) VALUES (100)") + test_cursor.execute("SELECT id FROM test_table") + result = test_cursor.fetchone() + assert result[0] == 100 + logger.info("Test table created and verified successfully") + self.connect_as(admin=True) + with select_ai.cursor() as admin_cursor: + admin_cursor.execute(f"DROP USER {test_username} CASCADE") + logger.info("Dropped test user: %s", test_username) diff --git a/tests/vector_index/conftest.py b/tests/vector_index/conftest.py new file mode 100644 index 0000000..7195374 --- /dev/null +++ b/tests/vector_index/conftest.py @@ -0,0 +1,104 @@ +import logging +import os +from pathlib import Path + +import pytest + +LOG_FORMAT = "%(levelname)s: [%(name)s] %(message)s" + + +def get_vcidx_env_value(name, default_value=None, required=False): + """ + Reads vector-index-specific environment variables. + """ + env_name = f"PYSAI_TEST_{name}" + value = os.environ.get(env_name) + if value is None: + if required: + pytest.exit(f"missing value for environment variable {env_name}", 1) + return default_value + return value + + +def _configure_logger(logger: logging.Logger, module_file: str) -> None: + logger.setLevel(logging.DEBUG) + log_dir = Path(__file__).resolve().parents[2] / "logs" + log_dir.mkdir(parents=True, exist_ok=True) + log_file = log_dir / f"tkex_{Path(module_file).stem}.log" + formatter = logging.Formatter(fmt=LOG_FORMAT) + file_handler = logging.FileHandler(log_file, mode="w", encoding="utf-8") + file_handler.setLevel(logging.DEBUG) + file_handler.setFormatter(formatter) + console_handler = logging.StreamHandler() + console_handler.setLevel(logging.WARNING) + console_handler.setFormatter(formatter) + logger.handlers.clear() + logger.propagate = False + logger.addHandler(file_handler) + logger.addHandler(console_handler) + logger.info("Configured logging for module") + + +@pytest.fixture(scope="module", autouse=True) +def configure_module_logging(request): + module = request.module + logger = logging.getLogger(module.__name__) + _configure_logger(logger, module.__file__) + yield + for handler in logger.handlers: + handler.close() + logger.handlers.clear() + + +@pytest.fixture(scope="session") +def embedding_location(): + value = get_vcidx_env_value("EMBEDDING_LOCATION", required=True) + + # Fail fast with a clear message if the wrong value is being used + if "inference.generativeai" in value or "/actions/embedText" in value: + pytest.exit( + "PYSAI_TEST_EMBEDDING_LOCATION is set to a GenAI inference endpoint. " + "It must be an Object Storage URI/URL (objectstorage.../n//b//o/). " + f"Got: {value}", + 1, + ) + + if "objectstorage." not in value: + pytest.exit( + "PYSAI_TEST_EMBEDDING_LOCATION does not look like an Object Storage URL/URI. " + f"Got: {value}", + 1, + ) + + return value + + +@pytest.fixture(scope="session") +def vcidx_object_store_credentials(): + return { + "cred_username": get_vcidx_env_value("CRED_USERNAME"), + "cred_password": get_vcidx_env_value("CRED_PASSWORD"), + } + + +@pytest.fixture(scope="class") +def vcidx_params( + test_env, + oci_credential, + oci_compartment_id, + embedding_location, + vcidx_object_store_credentials, +): + return { + "user": test_env.admin_user, + "password": test_env.admin_password, + "dsn": test_env.connect_string, + "user_ocid": oci_credential["user_ocid"], + "tenancy_ocid": oci_credential["tenancy_ocid"], + "private_key": oci_credential["private_key"], + "fingerprint": oci_credential["fingerprint"], + "oci_compartment_id": oci_compartment_id, + "embedding_location": embedding_location, + "cred_username": vcidx_object_store_credentials["cred_username"], + "cred_password": vcidx_object_store_credentials["cred_password"], + } diff --git a/tests/vector_index/test_5000_async_create_index.py b/tests/vector_index/test_5000_async_create_index.py new file mode 100644 index 0000000..1636173 --- /dev/null +++ b/tests/vector_index/test_5000_async_create_index.py @@ -0,0 +1,453 @@ +# ----------------------------------------------------------------------------- +# Copyright (c) 2025, Oracle and/or its affiliates. +# +# Licensed under the Universal Permissive License v 1.0 as shown at +# http://oss.oracle.com/licenses/upl. +# ----------------------------------------------------------------------------- + +import logging + +import oracledb +import pytest +import select_ai +from select_ai import OracleVectorIndexAttributes + +logger = logging.getLogger("TestAsyncCreateVectorIndex") + +pytestmark = pytest.mark.anyio + + +@pytest.fixture(scope="class") +def vector_index_params(request, vcidx_params): + request.cls.vector_index_params = vcidx_params + + +@pytest.fixture(scope="class", autouse=True) +async def setup_and_teardown(request, async_connect, vector_index_params): + """ + The shared async connection fixture from tests/conftest.py owns lifecycle. + """ + logger.info("\n=== Setting up TestAsyncCreateVectorIndex class ===") + assert await select_ai.async_is_connected(), "Connection to DB failed" + logger.info("Fetching credential secrets and OCI configuration...") + await request.cls.create_credential() + request.cls.profile = await request.cls.create_profile() + logger.info("Setup complete.\n") + yield + logger.info("\n=== Tearing down TestAsyncCreateVectorIndex class ===") + await request.cls.delete_profile(request.cls.profile) + await request.cls.delete_credential() + logger.info("Teardown complete.\n") + + +@pytest.fixture(autouse=True) +async def vector_index_test_state(request): + logger.info("--- Starting test: %s ---", request.function.__name__) + request.cls.objstore_cred = "OBJSTORE_CRED" + params = request.cls.vector_index_params + request.cls.vector_index_attributes = OracleVectorIndexAttributes( + location=params["embedding_location"], + object_storage_credential_name=request.cls.objstore_cred, + ) + request.cls.async_vector_index = select_ai.AsyncVectorIndex( + index_name="test_vector_index", + attributes=request.cls.vector_index_attributes, + description="Test vector index", + profile=request.cls.profile, + ) + yield + try: + await select_ai.AsyncVectorIndex( + index_name="test_vector_index" + ).delete(force=True) + logger.info("Vector index deleted successfully.") + except Exception as exc: + logger.warning("Warning: vector index cleanup failed: %s", exc) + logger.info("--- Finished test: %s ---", request.function.__name__) + + +@pytest.mark.usefixtures("vector_index_params", "setup_and_teardown") +class TestAsyncCreateVectorIndex: + @classmethod + def get_native_cred_param(cls, cred_name=None) -> dict: + logger.info("Preparing native credential params for: %s", cred_name) + params = cls.vector_index_params + return dict( + credential_name=cred_name, + user_ocid=params["user_ocid"], + tenancy_ocid=params["tenancy_ocid"], + private_key=params["private_key"], + fingerprint=params["fingerprint"], + ) + + @classmethod + def get_cred_param(cls, cred_name=None) -> dict: + logger.info("Preparing basic credential params for: %s", cred_name) + params = cls.vector_index_params + return dict( + credential_name=cred_name, + username=params["cred_username"], + password=params["cred_password"], + ) + + @classmethod + async def create_credential( + cls, genai_cred="GENAI_CRED", objstore_cred="OBJSTORE_CRED" + ): + logger.info("Creating credentials: %s, %s", genai_cred, objstore_cred) + genai_credential = cls.get_native_cred_param(genai_cred) + try: + logger.info("Creating GenAI credential: %s", genai_cred) + await select_ai.async_create_credential( + credential=genai_credential, + replace=True, + ) + logger.info("GenAI credential created.") + except Exception as exc: + raise AssertionError( + f"create_credential() raised {exc} unexpectedly." + ) + + params = cls.vector_index_params + if params.get("cred_username") and params.get("cred_password"): + objstore_credential = cls.get_cred_param(objstore_cred) + try: + logger.info( + "Creating ObjectStore credential: %s", objstore_cred + ) + await select_ai.async_create_credential( + credential=objstore_credential, + replace=True, + ) + logger.info("ObjectStore credential created.") + except Exception as exc: + raise AssertionError( + f"create_credential() raised {exc} unexpectedly." + ) + else: + logger.info( + "Skipping ObjectStore credential creation " + "(CRED_USERNAME/CRED_PASSWORD not set)." + ) + + @classmethod + async def create_profile(cls, profile_name="vector_ai_profile"): + logger.info("Creating Profile: %s", profile_name) + params = cls.vector_index_params + provider = select_ai.OCIGenAIProvider( + oci_compartment_id=params["oci_compartment_id"], + oci_apiformat="GENERIC", + embedding_model="cohere.embed-english-v3.0", + ) + profile_attributes = select_ai.ProfileAttributes( + credential_name="GENAI_CRED", + provider=provider, + ) + profile = await select_ai.AsyncProfile( + profile_name=profile_name, + attributes=profile_attributes, + description="OCI GENAI Profile", + replace=True, + ) + logger.info("Profile '%s' created successfully.", profile_name) + return profile + + @classmethod + async def delete_profile(cls, profile): + logger.info("Deleting profile...") + try: + await profile.delete() + logger.info( + "Profile '%s' deleted successfully.", profile.profile_name + ) + except Exception as exc: + raise AssertionError( + f"profile.delete() raised {exc} unexpectedly." + ) + + @classmethod + async def delete_credential(cls): + logger.info("Deleting credentials...") + try: + await select_ai.async_delete_credential("GENAI_CRED", force=True) + logger.info("Deleted credential 'GENAI_CRED'") + except Exception as exc: + logger.warning( + "delete_credential() raised %s unexpectedly.", exc + ) + try: + await select_ai.async_delete_credential("OBJSTORE_CRED", force=True) + logger.info("Deleted credential 'OBJSTORE_CRED'") + except Exception as exc: + logger.warning( + "delete_credential() raised %s unexpectedly.", exc + ) + + async def test_5001(self): + """Test successful vector index creation.""" + try: + await self.async_vector_index.create(replace=True) + logger.info("Vector index created successfully.") + except Exception as exc: + pytest.fail( + f"VectorIndex.create raised an unexpected exception: {exc}" + ) + logger.info("Verifying created vector index...") + vector_index = select_ai.AsyncVectorIndex() + indexes = [index.index_name async for index in vector_index.list()] + logger.info("Indexes found: %s", indexes) + assert "TEST_VECTOR_INDEX" in indexes + logger.info("Verified vector index creation successfully.") + + async def test_5002(self): + """Test vector index creation with replace=False.""" + try: + await self.async_vector_index.create(replace=False) + logger.info("Vector index created successfully.") + except Exception as exc: + pytest.fail( + f"VectorIndex.create raised an unexpected exception: {exc}" + ) + logger.info("Verifying created vector index...") + vector_index = select_ai.AsyncVectorIndex() + indexes = [index.index_name async for index in vector_index.list()] + logger.info("Indexes found: %s", indexes) + assert "TEST_VECTOR_INDEX" in indexes + logger.info("Verified vector index presence.") + + async def test_5003(self): + """Test vector index creation with empty description.""" + vector_index = select_ai.AsyncVectorIndex( + index_name="test_vector_index", + attributes=self.vector_index_attributes, + description="", + profile=self.profile, + ) + try: + await vector_index.create(replace=True) + logger.info( + "Vector index created successfully with empty description." + ) + except Exception as exc: + pytest.fail( + f"VectorIndex.create raised an unexpected exception: {exc}" + ) + logger.info("Verifying created vector index...") + index_list = select_ai.AsyncVectorIndex() + indexes = [index.index_name async for index in index_list.list()] + logger.info("Indexes found: %s", indexes) + assert "TEST_VECTOR_INDEX" in indexes + logger.info( + "Verified vector index creation with empty description." + ) + + async def test_5004(self): + """Test vector index recreation with replace=True.""" + try: + await self.async_vector_index.create(replace=True) + logger.info("First creation successful.") + await self.async_vector_index.create(replace=True) + logger.info("Second creation successful with replace=True.") + except Exception as exc: + pytest.fail( + f"VectorIndex.create raised an unexpected exception: {exc}" + ) + + async def test_5005(self): + """Test vector index recreation with replace=False (expect failure).""" + try: + await self.async_vector_index.create(replace=False) + logger.info("First creation successful.") + except Exception as exc: + pytest.fail( + f"Create vector index failed unexpectedly: {exc}" + ) + with pytest.raises(oracledb.DatabaseError) as exc_info: + await self.async_vector_index.create(replace=False) + logger.info("Expected DatabaseError raised: %s", exc_info.value) + assert "ORA-20048" in str(exc_info.value) + assert "already exists" in str(exc_info.value) + logger.info("Verified error on duplicate creation with replace=False.") + + async def test_5006(self): + """Test minimal attribute vector index creation.""" + vector_index = select_ai.AsyncVectorIndex( + index_name="test_vector_index", + attributes=self.vector_index_attributes, + profile=self.profile, + ) + try: + await vector_index.create(replace=True) + logger.info( + "Vector index created successfully with minimal attributes." + ) + except Exception as exc: + pytest.fail( + f"VectorIndex.create raised an unexpected exception: {exc}" + ) + + async def test_5007(self): + """Test vector index recreation after delete.""" + try: + await self.async_vector_index.create(replace=True) + logger.info("Vector index created successfully.") + except Exception as exc: + pytest.fail( + f"VectorIndex.create raised an unexpected exception: {exc}" + ) + logger.info("Deleting vector index...") + await select_ai.AsyncVectorIndex( + index_name="test_vector_index" + ).delete(force=True) + logger.info("Vector index deleted successfully.") + logger.info("Recreating vector index...") + try: + await self.async_vector_index.create(replace=True) + logger.info("Vector index recreated successfully.") + except Exception as exc: + pytest.fail( + f"VectorIndex.create raised an unexpected exception: {exc}" + ) + + async def test_5008(self): + """Test vector index creation with invalid credential.""" + params = self.vector_index_params + vector_index_attributes = OracleVectorIndexAttributes( + location=params["embedding_location"], + object_storage_credential_name="invalidObjStore_cred", + ) + vector_index = select_ai.AsyncVectorIndex( + index_name="test_vector_index", + attributes=vector_index_attributes, + description="Test vector index", + profile=self.profile, + ) + with pytest.raises(oracledb.DatabaseError) as exc_info: + await vector_index.create(replace=True) + logger.info("Expected DatabaseError raised: %s", exc_info.value) + + async def test_5009(self): + """Test vector index creation with invalid location.""" + vector_index_attributes = OracleVectorIndexAttributes( + location="invalid_location", + object_storage_credential_name=self.objstore_cred, + ) + vector_index = select_ai.AsyncVectorIndex( + index_name="test_vector_index", + attributes=vector_index_attributes, + description="Test vector index", + profile=self.profile, + ) + with pytest.raises(oracledb.DatabaseError) as exc_info: + await vector_index.create(replace=True) + logger.info("Expected DatabaseError raised: %s", exc_info.value) + + async def test_5010(self): + """Test vector index creation with missing attributes.""" + with pytest.raises(AttributeError): + await select_ai.AsyncVectorIndex( + index_name="test_vector_index", + attributes=None, + profile=self.profile, + ).create() + logger.info("Expected AttributeError raised for missing attributes.") + + async def test_5011(self): + """Test vector index creation with invalid attributes type.""" + with pytest.raises(TypeError): + await select_ai.AsyncVectorIndex( + index_name="test_vector_index", + attributes="invalid_attributes", + profile=self.profile, + ).create() + logger.info("Expected TypeError raised for invalid attribute type.") + + async def test_5012(self): + """Test vector index creation with invalid name type.""" + with pytest.raises(oracledb.DatabaseError) as exc_info: + await select_ai.AsyncVectorIndex( + index_name=12345, + attributes=self.vector_index_attributes, + profile=self.profile, + ).create() + assert "ORA-20048" in str(exc_info.value) + assert "Invalid vector index name" in str(exc_info.value) + logger.info("Expected DatabaseError raised: %s", exc_info.value) + + async def test_5013(self): + """Test vector index creation with empty name.""" + with pytest.raises(oracledb.DatabaseError) as exc_info: + await select_ai.AsyncVectorIndex( + index_name="", + attributes=self.vector_index_attributes, + profile=self.profile, + ).create() + assert "ORA-20048" in str(exc_info.value) + assert "Missing vector index name" in str(exc_info.value) + logger.info("Expected DatabaseError raised: %s", exc_info.value) + + async def test_5014(self): + """Test vector index creation with invalid profile.""" + with pytest.raises(TypeError) as exc_info: + vector_index = select_ai.AsyncVectorIndex( + index_name="test_vector_index", + attributes=self.vector_index_attributes, + description="Test vector index", + profile="invalid_profile", + ) + await vector_index.create() + logger.info( + "Expected TypeError raised for invalid profile: %s", + exc_info.value, + ) + + async def test_5015(self): + """Test vector index creation with None attributes.""" + with pytest.raises(TypeError) as exc_info: + vector_index = select_ai.AsyncVectorIndex( + index_name="test_vector_index", + attributes=None, + description="invalid_profile", + profile="invalid_profile", + ) + await vector_index.create() + logger.info( + "Expected TypeError raised for None attributes: %s", + exc_info.value, + ) + + async def test_5016(self): + """Test vector index creation with long name (>128 chars).""" + long_name = "X" * 150 + vector_index = select_ai.AsyncVectorIndex( + index_name=long_name, + attributes=self.vector_index_attributes, + profile=self.profile, + ) + with pytest.raises(oracledb.DatabaseError) as exc_info: + await vector_index.create() + logger.info( + "Expected DatabaseError raised for long name: %s", + exc_info.value, + ) + + async def test_5017(self): + """Test vector index creation with long description.""" + long_desc = "D" * 5000 + vector_index = select_ai.AsyncVectorIndex( + index_name="test_vector_index", + attributes=self.vector_index_attributes, + description=long_desc, + profile=self.profile, + ) + with pytest.raises(oracledb.DatabaseError) as exc_info: + await vector_index.create(replace=True) + assert "ORA-20045" in str(exc_info.value) + assert "description is too long" in str(exc_info.value) + logger.info("Expected DatabaseError raised: %s", exc_info.value) + + async def test_5018(self): + """Test multiple recreations of vector index (10x).""" + for _ in range(10): + await self.async_vector_index.create(replace=True) + logger.info("Successfully recreated vector index multiple times.") diff --git a/tests/vector_index/test_5000_create_index.py b/tests/vector_index/test_5000_create_index.py new file mode 100644 index 0000000..c892cdb --- /dev/null +++ b/tests/vector_index/test_5000_create_index.py @@ -0,0 +1,434 @@ +# ----------------------------------------------------------------------------- +# Copyright (c) 2025, Oracle and/or its affiliates. +# +# Licensed under the Universal Permissive License v 1.0 as shown at +# http://oss.oracle.com/licenses/upl. +# ----------------------------------------------------------------------------- + +import logging +import pytest +import select_ai +import oracledb +from select_ai import OracleVectorIndexAttributes + +logger = logging.getLogger("TestCreateVectorIndex") + +@pytest.fixture(scope="class", autouse=True) +def setup_logging(): + logging.basicConfig( + format="%(asctime)s %(levelname)s %(name)s %(message)s", + level=logging.INFO + ) + +@pytest.fixture(scope="class") +def vector_index_params( + request, + vcidx_params, +): + request.cls.vector_index_params = vcidx_params + +@pytest.fixture(scope="class", autouse=True) +def setup_and_teardown(request, connect, vector_index_params): + """ + 'connect' fixture from base tests/conftest.py ensures DB connection exists. + Do NOT disconnect here; let the session fixture own lifecycle. + """ + logger.info("\n=== Setting up TestCreateVectorIndex class ===") + assert select_ai.is_connected(), "Connection to DB failed" + logger.info("Fetching credential secrets and OCI configuration...") + request.cls.create_credential() + request.cls.profile = request.cls.create_profile() + logger.info("Setup complete.\n") + yield + logger.info("\n=== Tearing down TestCreateVectorIndex class ===") + request.cls.delete_profile(request.cls.profile) + request.cls.delete_credential() + logger.info("Teardown complete.\n") + +@pytest.fixture(autouse=True) +def log_test_name(request): + logger.info(f"--- Starting test: {request.function.__name__} ---") + yield + logger.info(f"--- Finished test: {request.function.__name__} ---") + +@pytest.mark.usefixtures("vector_index_params", "setup_and_teardown") +class TestCreateVectorIndex: + @classmethod + def get_native_cred_param(cls, cred_name=None) -> dict: + logger.info(f"Preparing native credential params for: {cred_name}") + params = cls.vector_index_params + return dict( + credential_name=cred_name, + user_ocid=params["user_ocid"], + tenancy_ocid=params["tenancy_ocid"], + private_key=params["private_key"], + fingerprint=params["fingerprint"], + ) + + @classmethod + def get_cred_param(cls, cred_name=None) -> dict: + logger.info(f"Preparing basic credential params for: {cred_name}") + params = cls.vector_index_params + return dict( + credential_name=cred_name, + username=params["cred_username"], + password=params["cred_password"], + ) + + @classmethod + def create_credential(cls, genai_cred="GENAI_CRED", objstore_cred="OBJSTORE_CRED"): + logger.info(f"Creating credentials: {genai_cred}, {objstore_cred}") + genai_credential = cls.get_native_cred_param(genai_cred) + try: + logger.info(f"Creating GenAI credential: {genai_cred}") + select_ai.create_credential(credential=genai_credential, replace=True) + logger.info("GenAI credential created.") + except Exception as e: + raise AssertionError(f"create_credential() raised {e} unexpectedly.") + + # Only create OBJSTORE_CRED if creds are provided in env + params = cls.vector_index_params + if params.get("cred_username") and params.get("cred_password"): + objstore_credential = cls.get_cred_param(objstore_cred) + try: + logger.info(f"Creating ObjectStore credential: {objstore_cred}") + select_ai.create_credential(credential=objstore_credential, replace=True) + logger.info("ObjectStore credential created.") + except Exception as e: + raise AssertionError(f"create_credential() raised {e} unexpectedly.") + else: + logger.info("Skipping ObjectStore credential creation (CRED_USERNAME/CRED_PASSWORD not set).") + + @classmethod + def create_profile(cls, profile_name="vector_ai_profile"): + logger.info(f"Creating Profile: {profile_name}") + params = cls.vector_index_params + provider = select_ai.OCIGenAIProvider( + oci_compartment_id=params["oci_compartment_id"], + oci_apiformat="GENERIC", + embedding_model="cohere.embed-english-v3.0" + ) + profile_attributes = select_ai.ProfileAttributes( + credential_name="GENAI_CRED", + provider=provider, + ) + profile = select_ai.Profile( + profile_name=profile_name, + attributes=profile_attributes, + description="OCI GENAI Profile", + replace=True, + ) + logger.info(f"Profile '{profile_name}' created successfully.") + return profile + + @classmethod + def delete_profile(cls, profile): + logger.info("Deleting profile...") + try: + profile.delete() + logger.info(f"Profile '{profile.profile_name}' deleted successfully.") + except Exception as e: + raise AssertionError(f"profile.delete() raised {e} unexpectedly.") + + @classmethod + def delete_credential(cls): + logger.info("Deleting credentials...") + try: + select_ai.delete_credential("GENAI_CRED", force=True) + logger.info("Deleted credential 'GENAI_CRED'") + except Exception as e: + logger.warning(f"delete_credential() raised {e} unexpectedly.") + try: + select_ai.delete_credential("OBJSTORE_CRED", force=True) + logger.info("Deleted credential 'OBJSTORE_CRED'") + except Exception as e: + logger.warning(f"delete_credential() raised {e} unexpectedly.") + + def setup_method(self, method): + logger.info(f"\n--- Starting test: {method.__name__} ---") + self.objstore_cred = "OBJSTORE_CRED" + params = self.vector_index_params + self.vector_index_attributes = OracleVectorIndexAttributes( + location=params["embedding_location"], + object_storage_credential_name=self.objstore_cred, + ) + self.profile = self.profile + self.vector_index = select_ai.VectorIndex( + index_name="test_vector_index", + attributes=self.vector_index_attributes, + description="Test vector index", + profile=self.profile, + ) + + def teardown_method(self, method): + logger.info(f"--- Finished test: {method.__name__} ---") + try: + vector_index = select_ai.VectorIndex(index_name="test_vector_index") + vector_index.delete(force=True) + logger.info("Vector index deleted successfully.") + except Exception as e: + logger.warning(f"Warning: vector index cleanup failed: {e}") + + def test_5001(self): + """Test successful vector index creation.""" + try: + self.vector_index.create(replace=True) + logger.info("Vector index created successfully.") + except Exception as e: + pytest.fail(f"VectorIndex.create raised an unexpected exception: {e}") + logger.info("Verifying created vector index...") + vector_index = select_ai.VectorIndex() + indexes = [i.index_name for i in vector_index.list()] + logger.info(f"Indexes found: {indexes}") + assert "TEST_VECTOR_INDEX" in indexes + logger.info("Verified vector index creation successfully.") + + def test_5002(self): + """Test vector index creation with replace=False.""" + try: + self.vector_index.create(replace=False) + logger.info("Vector index created successfully.") + except Exception as e: + pytest.fail(f"VectorIndex.create raised an unexpected exception: {e}") + logger.info("Verifying created vector index...") + vector_index = select_ai.VectorIndex() + indexes = [i.index_name for i in vector_index.list()] + logger.info(f"Indexes found: {indexes}") + assert "TEST_VECTOR_INDEX" in indexes + logger.info("Verified vector index presence.") + + def test_5003(self): + """Test vector index creation with empty description.""" + vector_index = select_ai.VectorIndex( + index_name="test_vector_index", + attributes=self.vector_index_attributes, + description="", + profile=self.profile, + ) + try: + vector_index.create(replace=True) + logger.info("Vector index created successfully with empty description.") + except Exception as e: + pytest.fail(f"VectorIndex.create raised an unexpected exception: {e}") + logger.info("Verifying created vector index...") + vector_index = select_ai.VectorIndex() + indexes = [i.index_name for i in vector_index.list()] + logger.info(f"Indexes found: {indexes}") + assert "TEST_VECTOR_INDEX" in indexes + logger.info("Verified vector index creation with empty description.") + + def test_5004(self): + """Test vector index recreation with replace=True.""" + try: + self.vector_index.create(replace=True) + logger.info("First creation successful.") + self.vector_index.create(replace=True) + logger.info("Second creation successful with replace=True.") + except Exception as e: + pytest.fail(f"VectorIndex.create raised an unexpected exception: {e}") + + def test_5005(self): + """Test vector index recreation with replace=False (expect failure).""" + try: + self.vector_index.create(replace=False) + logger.info("First creation successful.") + except Exception as e: + pytest.fail(f"Create vector index failed unexpectedly: {e}") + with pytest.raises(oracledb.DatabaseError) as exc_info: + self.vector_index.create(replace=False) + logger.info( + "Expected DatabaseError raised: %s", + exc_info.value, + ) + assert "ORA-20048" in str(exc_info.value) + assert "already exists" in str(exc_info.value) + logger.info("Verified error on duplicate creation with replace=False.") + + def test_5006(self): + """Test minimal attribute vector index creation.""" + vector_index = select_ai.VectorIndex( + index_name="test_vector_index", + attributes=self.vector_index_attributes, + profile=self.profile, + ) + try: + vector_index.create(replace=True) + logger.info("Vector index created successfully with minimal attributes.") + except Exception as e: + pytest.fail(f"VectorIndex.create raised an unexpected exception: {e}") + + def test_5007(self): + """Test vector index recreation after delete.""" + try: + self.vector_index.create(replace=True) + logger.info("Vector index created successfully.") + except Exception as e: + pytest.fail(f"VectorIndex.create raised an unexpected exception: {e}") + logger.info("Deleting vector index...") + vector_index = select_ai.VectorIndex(index_name="test_vector_index") + vector_index.delete(force=True) + logger.info("Vector index deleted successfully.") + logger.info("Recreating vector index...") + try: + self.vector_index.create(replace=True) + logger.info("Vector index recreated successfully.") + except Exception as e: + pytest.fail(f"VectorIndex.create raised an unexpected exception: {e}") + + def test_5008(self): + """Test vector index creation with invalid credential.""" + params = self.vector_index_params + vector_index_attributes = OracleVectorIndexAttributes( + location=params["embedding_location"], + object_storage_credential_name="invalidObjStore_cred", + ) + vector_index = select_ai.VectorIndex( + index_name="test_vector_index", + attributes=vector_index_attributes, + description="Test vector index", + profile=self.profile, + ) + with pytest.raises(oracledb.DatabaseError) as exc_info: + vector_index.create(replace=True) + logger.info( + "Expected DatabaseError raised: %s", + exc_info.value, + ) + + def test_5009(self): + """Test vector index creation with invalid location.""" + vector_index_attributes = OracleVectorIndexAttributes( + location="invalid_location", + object_storage_credential_name=self.objstore_cred, + ) + vector_index = select_ai.VectorIndex( + index_name="test_vector_index", + attributes=vector_index_attributes, + description="Test vector index", + profile=self.profile, + ) + with pytest.raises(oracledb.DatabaseError) as exc_info: + vector_index.create(replace=True) + logger.info( + "Expected DatabaseError raised: %s", + exc_info.value, + ) + + def test_5010(self): + """Test vector index creation with missing attributes.""" + with pytest.raises(AttributeError): + select_ai.VectorIndex( + index_name="test_vector_index", + attributes=None, + profile=self.profile, + ).create() + logger.info("Expected AttributeError raised for missing attributes.") + + def test_5011(self): + """Test vector index creation with invalid attributes type.""" + with pytest.raises(TypeError): + select_ai.VectorIndex( + index_name="test_vector_index", + attributes="invalid_attributes", + profile=self.profile, + ).create() + logger.info("Expected TypeError raised for invalid attribute type.") + + def test_5012(self): + """Test vector index creation with invalid name type.""" + with pytest.raises(oracledb.DatabaseError) as exc_info: + select_ai.VectorIndex( + index_name=12345, + attributes=self.vector_index_attributes, + profile=self.profile, + ).create() + assert "ORA-20048" in str(exc_info.value) + assert "Invalid vector index name" in str(exc_info.value) + logger.info( + "Expected DatabaseError raised: %s", + exc_info.value, + ) + + def test_5013(self): + """Test vector index creation with empty name.""" + with pytest.raises(oracledb.DatabaseError) as exc_info: + select_ai.VectorIndex( + index_name="", + attributes=self.vector_index_attributes, + profile=self.profile, + ).create() + assert "ORA-20048" in str(exc_info.value) + assert "Missing vector index name" in str(exc_info.value) + logger.info( + "Expected DatabaseError raised: %s", + exc_info.value, + ) + + def test_5014(self): + """Test vector index creation with invalid profile.""" + with pytest.raises(TypeError) as exc_info: + vector_index = select_ai.VectorIndex( + index_name="test_vector_index", + attributes=self.vector_index_attributes, + description="Test vector index", + profile="invalid_profile", + ) + vector_index.create() + logger.info( + "Expected TypeError raised for invalid profile: %s", + exc_info.value, + ) + + def test_5015(self): + """Test vector index creation with None attributes.""" + with pytest.raises(TypeError) as exc_info: + vector_index = select_ai.VectorIndex( + index_name="test_vector_index", + attributes=None, + description="invalid_profile", + profile="invalid_profile", + ) + vector_index.create() + logger.info( + "Expected TypeError raised for None attributes: %s", + exc_info.value, + ) + + def test_5016(self): + """Test vector index creation with long name (>128 chars).""" + long_name = "X" * 150 + vector_index = select_ai.VectorIndex( + index_name=long_name, + attributes=self.vector_index_attributes, + profile=self.profile, + ) + with pytest.raises(oracledb.DatabaseError) as exc_info: + vector_index.create() + logger.info( + "Expected DatabaseError raised for long name: %s", + exc_info.value, + ) + + def test_5017(self): + """Test vector index creation with long description.""" + long_desc = "D" * 5000 + vector_index = select_ai.VectorIndex( + index_name="test_vector_index", + attributes=self.vector_index_attributes, + description=long_desc, + profile=self.profile, + ) + with pytest.raises(oracledb.DatabaseError) as exc_info: + vector_index.create(replace=True) + assert "ORA-20045" in str(exc_info.value) + assert "description is too long" in str(exc_info.value) + logger.info( + "Expected DatabaseError raised: %s", + exc_info.value, + ) + + def test_5018(self): + """Test multiple recreations of vector index (10x).""" + for _ in range(10): + self.vector_index.create(replace=True) + logger.info("Successfully recreated vector index multiple times.") diff --git a/tests/vector_index/test_5100_async_drop_index.py b/tests/vector_index/test_5100_async_drop_index.py new file mode 100644 index 0000000..e59c1c3 --- /dev/null +++ b/tests/vector_index/test_5100_async_drop_index.py @@ -0,0 +1,592 @@ +# ----------------------------------------------------------------------------- +# Copyright (c) 2025, Oracle and/or its affiliates. +# +# Licensed under the Universal Permissive License v 1.0 as shown at +# http://oss.oracle.com/licenses/upl. +# ----------------------------------------------------------------------------- + +import asyncio +import logging + +import oracledb +import pytest +import select_ai +from select_ai import OracleVectorIndexAttributes + +logger = logging.getLogger("TestAsyncDeleteVectorIndex") + +pytestmark = pytest.mark.anyio + + +@pytest.fixture(scope="class") +def delete_vec_params(request, vcidx_params): + request.cls.delete_vec_params = vcidx_params + + +@pytest.fixture(scope="class", autouse=True) +async def setup_and_teardown(request, async_connect, delete_vec_params): + logger.info("=== Setting up TestAsyncDeleteVectorIndex class ===") + assert await select_ai.async_is_connected(), "Connection to DB failed" + logger.info("Fetching credential secrets and OCI configuration...") + await request.cls.create_credential() + request.cls.profile = await request.cls.create_profile() + logger.info("Setup complete.") + yield + logger.info("=== Tearing down TestAsyncDeleteVectorIndex class ===") + await request.cls.delete_profile(request.cls.profile) + await request.cls.delete_credential() + logger.info("Teardown complete.\n") + + +@pytest.fixture(autouse=True) +async def vector_index_test_state(request): + logger.info("SetUp for %s", request.function.__name__) + request.cls.objstore_cred = "OBJSTORE_CRED" + request.cls.vecidx = select_ai.AsyncVectorIndex() + params = request.cls.delete_vec_params + request.cls.vector_index_attributes = OracleVectorIndexAttributes( + location=params["embedding_location"], + object_storage_credential_name=request.cls.objstore_cred, + ) + + await request.cls.delete_and_wait() + + request.cls.index_name = "test_vector_index" + request.cls.async_vector_index = select_ai.AsyncVectorIndex( + index_name=request.cls.index_name, + attributes=request.cls.vector_index_attributes, + description="Test vector index", + profile=request.cls.profile, + ) + await request.cls.async_vector_index.create(replace=True) + logger.info( + "Vector index '%s' created for test.", request.cls.index_name + ) + yield + logger.info("TearDown for %s", request.function.__name__) + try: + await request.cls.async_vector_index.delete(force=True) + logger.info( + "Vector index '%s' deleted successfully.", request.cls.index_name + ) + except Exception as exc: + logger.warning("Warning: vector index cleanup failed: %s", exc) + + +@pytest.mark.usefixtures("delete_vec_params", "setup_and_teardown") +class TestAsyncDeleteVectorIndex: + @classmethod + def get_native_cred_param(cls, cred_name=None) -> dict: + logger.info("Preparing native credential params for: %s", cred_name) + params = cls.delete_vec_params + return dict( + credential_name=cred_name, + user_ocid=params["user_ocid"], + tenancy_ocid=params["tenancy_ocid"], + private_key=params["private_key"], + fingerprint=params["fingerprint"], + ) + + @classmethod + def get_cred_param(cls, cred_name=None) -> dict: + logger.info("Preparing basic credential params for: %s", cred_name) + params = cls.delete_vec_params + return dict( + credential_name=cred_name, + username=params["cred_username"], + password=params["cred_password"], + ) + + @classmethod + async def create_credential( + cls, genai_cred="GENAI_CRED", objstore_cred="OBJSTORE_CRED" + ): + logger.info("Creating credentials: %s, %s", genai_cred, objstore_cred) + genai_credential = cls.get_native_cred_param(genai_cred) + try: + logger.info("Creating GenAI credential: %s", genai_cred) + await select_ai.async_create_credential( + credential=genai_credential, + replace=True, + ) + logger.info("GenAI credential created.") + except Exception as exc: + logger.error( + "create_credential() raised %s unexpectedly.", exc + ) + raise AssertionError( + f"create_credential() raised {exc} unexpectedly." + ) + + params = cls.delete_vec_params + if params.get("cred_username") and params.get("cred_password"): + objstore_credential = cls.get_cred_param(objstore_cred) + try: + logger.info( + "Creating ObjectStore credential: %s", objstore_cred + ) + await select_ai.async_create_credential( + credential=objstore_credential, + replace=True, + ) + logger.info("ObjectStore credential created.") + except Exception as exc: + logger.error( + "create_credential() raised %s unexpectedly.", exc + ) + raise AssertionError( + f"create_credential() raised {exc} unexpectedly." + ) + else: + logger.info( + "Skipping ObjectStore credential creation " + "(CRED_USERNAME/CRED_PASSWORD not set)." + ) + + @classmethod + async def create_profile(cls, profile_name="vector_ai_profile"): + logger.info("Creating Profile: %s", profile_name) + params = cls.delete_vec_params + provider = select_ai.OCIGenAIProvider( + oci_compartment_id=params["oci_compartment_id"], + oci_apiformat="GENERIC", + embedding_model="cohere.embed-english-v3.0", + ) + profile_attributes = select_ai.ProfileAttributes( + credential_name="GENAI_CRED", + provider=provider, + ) + profile = await select_ai.AsyncProfile( + profile_name=profile_name, + attributes=profile_attributes, + description="OCI GENAI Profile", + replace=True, + ) + logger.info("Profile '%s' created successfully.", profile_name) + return profile + + @classmethod + async def delete_profile(cls, profile): + logger.info("Deleting profile...") + try: + await profile.delete() + logger.info( + "Profile '%s' deleted successfully.", profile.profile_name + ) + except Exception as exc: + logger.error("profile.delete() raised %s unexpectedly.", exc) + raise AssertionError( + f"profile.delete() raised {exc} unexpectedly." + ) + + @classmethod + async def delete_credential(cls): + logger.info("Deleting credentials...") + try: + await select_ai.async_delete_credential("GENAI_CRED", force=True) + logger.info("Deleted credential 'GENAI_CRED'") + except Exception as exc: + logger.warning( + "delete_credential() raised %s unexpectedly.", exc + ) + try: + await select_ai.async_delete_credential("OBJSTORE_CRED", force=True) + logger.info("Deleted credential 'OBJSTORE_CRED'") + except Exception as exc: + logger.warning( + "delete_credential() raised %s unexpectedly.", exc + ) + + @classmethod + async def delete_and_wait(cls, force=True, pattern=".*", wait_seconds=1): + logger.info("Deleting indexes matching pattern.") + all_indexes = [ + index async for index in cls.vecidx.list(index_name_pattern=pattern) + ] + if not all_indexes: + logger.info("No indexes found to delete.") + return + for index in all_indexes: + try: + await index.delete(force=force) + logger.info("Deleted index: %s", index.index_name) + await asyncio.sleep(wait_seconds) + except Exception as exc: + logger.warning( + "Warning: failed to delete index %s: %s", + index.index_name, + exc, + ) + remaining = [ + index async for index in cls.vecidx.list(index_name_pattern=pattern) + ] + logger.info( + "Remaining indexes after delete: %s", + [index.index_name for index in remaining], + ) + + async def assert_index_count(self, pattern, expected): + actual = [ + index async for index in self.vecidx.list(index_name_pattern=pattern) + ] + logger.info( + "Indexes matching '%s': %s", + pattern, + [index.index_name for index in actual], + ) + assert len(actual) == expected, ( + f"Expected {expected} indexes, got {len(actual)}" + ) + + async def verify_and_cleanup_vectab(self, vector_index_name: str): + table_name = f"{vector_index_name}$vectab".upper() + logger.info("Verifying and cleaning up vector table: %s", table_name) + async with select_ai.async_cursor() as cursor: + await cursor.execute( + """ + SELECT column_name + FROM user_tab_columns + WHERE table_name = :table_name + ORDER BY column_id + """, + {"table_name": table_name}, + ) + cols = [row[0] for row in await cursor.fetchall()] + logger.info("Columns found in %s: %s", table_name, cols) + expected_cols = ["CONTENT", "ATTRIBUTES", "EMBEDDING"] + assert cols == expected_cols, ( + f"Unexpected columns for {table_name}: {cols}" + ) + await cursor.execute(f"DROP TABLE {table_name} PURGE") + logger.info("Table %s dropped successfully.", table_name) + + async def test_5101(self): + """Test single vector index deletion removes the index.""" + logger.info("Deleting vector index (single delete)") + await self.assert_index_count("^test_vector_index", 1) + await self.async_vector_index.delete(force=True) + logger.info("Delete called on vector index.") + await asyncio.sleep(1) + await self.assert_index_count("^test_vector_index", 0) + logger.info("Single-delete verified: index removed") + + async def test_5103(self): + """Test deleting the same vector index twice.""" + logger.info("Deleting vector index first time") + await self.async_vector_index.delete(force=True) + logger.info("Deleting vector index second time (no-op expected)") + await asyncio.sleep(1) + await self.async_vector_index.delete(force=True) + await asyncio.sleep(1) + await self.assert_index_count("^test_vector_index", 0) + logger.info("Double-delete verified: index removed") + + async def test_5104(self): + """Test delete with include_data=True also removes table.""" + logger.info("Deleting index with include_data=True (metadata + table)") + await self.async_vector_index.delete(include_data=True, force=True) + await asyncio.sleep(1) + await self.assert_index_count("^test_vector_index", 0) + table_name = "TEST_VECTOR_INDEX$VECTAB" + async with select_ai.async_cursor() as cursor: + await cursor.execute( + """ + SELECT COUNT(*) + FROM user_tables + WHERE table_name = :table_name + """, + {"table_name": table_name}, + ) + (count,) = await cursor.fetchone() + logger.info( + "Verified vector table '%s' removed: %s", table_name, count == 0 + ) + assert count == 0 + + async def test_5105(self): + """Test delete with include_data=False doesn't remove table.""" + logger.info("Deleting index with include_data=False (metadata only)") + await self.async_vector_index.delete(include_data=False, force=True) + await asyncio.sleep(1) + await self.assert_index_count("^test_vector_index", 0) + logger.info( + "Attempting to recreate index (should fail due to leftover table)" + ) + with pytest.raises(oracledb.DatabaseError) as exc_info: + await self.async_vector_index.create(replace=True) + logger.info("Expected DatabaseError on recreate: %s", exc_info.value) + assert "ORA-00955" in str(exc_info.value) + await self.verify_and_cleanup_vectab("test_vector_index") + logger.info("Vector table cleaned up after failed recreate") + + async def test_5106(self): + """Test delete twice with include_data=False then cleanup.""" + logger.info("Deleting index metadata only first time") + await self.async_vector_index.delete(include_data=False, force=True) + await asyncio.sleep(1) + await self.assert_index_count("^test_vector_index", 0) + logger.info("Attempting to recreate index (should fail)") + with pytest.raises(oracledb.DatabaseError) as exc_info: + await self.async_vector_index.create(replace=True) + assert "ORA-00955" in str(exc_info.value) + await self.verify_and_cleanup_vectab("test_vector_index") + logger.info("Vector table cleaned up") + logger.info("Deleting index metadata only second time (no-op)") + await self.async_vector_index.delete(include_data=False, force=True) + await self.assert_index_count("^test_vector_index", 0) + + async def test_5107(self): + """Test delete twice with include_data=False and cleanup after failed recreate.""" + logger.info("Deleting metadata only first time") + await self.async_vector_index.delete(include_data=False, force=True) + await asyncio.sleep(1) + await self.assert_index_count("^test_vector_index", 0) + logger.info("Attempting recreate (expected failure)") + with pytest.raises(oracledb.DatabaseError) as exc_info: + await self.async_vector_index.create(replace=True) + logger.info("Recreate failed (expected): %s", exc_info.value) + assert "ORA-00955" in str(exc_info.value) + logger.info("Deleting metadata second time (no-op)") + await self.async_vector_index.delete(include_data=False, force=True) + await self.verify_and_cleanup_vectab("test_vector_index") + await self.assert_index_count("^test_vector_index", 0) + logger.info("Cleanup complete") + + async def test_5108(self): + """Test delete and then recreate a vector index.""" + logger.info("Deleting index before recreation") + await self.async_vector_index.delete(force=True) + await asyncio.sleep(1) + logger.info("Recreating vector index") + await self.async_vector_index.create(replace=True) + await asyncio.sleep(1) + await self.assert_index_count("^test_vector_index", 1) + logger.info("Recreate verified: index exists") + + async def test_5109(self): + """Test delete of a nonexistent index (should not error).""" + idx = select_ai.AsyncVectorIndex( + index_name="nonexistent_index", + attributes=self.vector_index_attributes, + profile=self.profile, + ) + logger.info("Attempting to delete nonexistent index") + await idx.delete(force=True) + await asyncio.sleep(1) + await self.assert_index_count("^nonexistent_index", 0) + logger.info("Nonexistent delete verified (no error)") + + async def test_5110(self): + """Test delete after set_attribute was called.""" + logger.info("Setting temporary attributes before delete") + await self.async_vector_index.create(replace=True) + try: + await self.async_vector_index.set_attribute( + attribute_name="match_limit", + attribute_value=10, + ) + except Exception as exc: + logger.error("set_attribute() raised %s unexpectedly.", exc) + pytest.fail(f"set_attribute() raised {exc} unexpectedly.") + logger.info("Deleting index after setting attributes") + await self.async_vector_index.delete(force=True) + await asyncio.sleep(1) + actual_indexes = [ + index + async for index in self.async_vector_index.list( + index_name_pattern="^test_vector_index$" + ) + ] + logger.info("Indexes remaining after delete: %s", actual_indexes) + assert len(actual_indexes) == 0 + logger.info("Delete after attributes verified") + + async def test_5111(self): + """Test case-sensitive name for create and delete.""" + idx = select_ai.AsyncVectorIndex( + index_name="CaseSensitiveIndex", + attributes=self.vector_index_attributes, + profile=self.profile, + ) + logger.info("Creating case-sensitive index") + await idx.create(replace=True) + logger.info("Deleting case-sensitive index") + await idx.delete(force=True) + await asyncio.sleep(1) + await self.assert_index_count("^CaseSensitiveIndex", 0) + logger.info("Case-sensitive index delete verified") + + async def test_5112(self): + """Test creation and deletion with long index name.""" + long_name = "index_" + "x" * 40 + idx = select_ai.AsyncVectorIndex( + index_name=long_name, + attributes=self.vector_index_attributes, + profile=self.profile, + ) + logger.info("Creating long-name index: %s", long_name) + await idx.create(replace=True) + logger.info("Deleting long-name index: %s", long_name) + await idx.delete(force=True) + await asyncio.sleep(1) + await self.assert_index_count(f"^{long_name}$", 0) + logger.info("Long-name index delete verified") + + async def test_5113(self): + """Test creation and bulk deletion of indexes.""" + names = [f"bulk_idx_{i}" for i in range(3)] + logger.info("Creating bulk indexes") + for name in names: + await select_ai.AsyncVectorIndex( + index_name=name, + attributes=self.vector_index_attributes, + profile=self.profile, + ).create(replace=True) + logger.info("Created %s", name) + logger.info("Deleting bulk indexes") + for name in names: + await select_ai.AsyncVectorIndex( + index_name=name, + attributes=self.vector_index_attributes, + profile=self.profile, + ).delete(force=True) + await asyncio.sleep(1) + logger.info("Deleted %s", name) + await self.assert_index_count("^bulk_idx_", 0) + logger.info("Bulk delete verified") + + async def test_5114(self): + """Test that list returns empty after index is deleted.""" + logger.info("Deleting index and verifying list is empty") + await self.async_vector_index.delete(force=True) + await asyncio.sleep(1) + actual_indexes = [ + index + async for index in self.async_vector_index.list( + index_name_pattern=".*" + ) + ] + index_names = [index.index_name for index in actual_indexes] + logger.info("Actual indexes after delete: %s", index_names) + actual = [ + index + async for index in self.async_vector_index.list( + index_name_pattern="^test_vector_index" + ) + ] + assert actual == [] + logger.info("List verification successful: no remaining indexes") + + async def test_5115(self): + """Test delete then recreate index with same name.""" + logger.info("Deleting index before recreate with same name") + await self.async_vector_index.delete(force=True) + await asyncio.sleep(1) + logger.info("Recreating index with same name") + await self.async_vector_index.create(replace=False) + await asyncio.sleep(1) + await self.assert_index_count("^test_vector_index", 1) + logger.info("Recreate same-name index verified") + + async def test_5116(self): + """Test delete of one out of multiple indexes.""" + idx1 = select_ai.AsyncVectorIndex( + index_name="IDX_1", + attributes=self.vector_index_attributes, + profile=self.profile, + ) + idx2 = select_ai.AsyncVectorIndex( + index_name="IDX_2", + attributes=self.vector_index_attributes, + profile=self.profile, + ) + logger.info("Creating two indexes IDX_1 and IDX_2") + await idx1.create(replace=True) + await idx2.create(replace=True) + logger.info("Deleting IDX_1 only") + await self.delete_and_wait(force=True, pattern="^IDX_1$") + remaining_idx2 = [ + index + async for index in self.async_vector_index.list( + index_name_pattern="^IDX_2$" + ) + ] + logger.info( + "IDX_2 entries after IDX_1 delete: %s", remaining_idx2 + ) + assert len(remaining_idx2) == 1 + logger.info("IDX_2 remains after IDX_1 delete") + + async def test_5117(self): + """Test deletion by pattern.""" + logger.info("Deleting index with pattern '^test_vector_index$'") + await self.async_vector_index.create(replace=True) + await self.async_vector_index.delete(force=True) + await asyncio.sleep(1) + actual = [ + index + async for index in self.async_vector_index.list( + index_name_pattern="^test_vector_index$" + ) + ] + logger.info("List entries after pattern delete: %s", actual) + assert len(actual) == 0 + logger.info("Pattern delete verified") + + async def test_5118(self): + """Test delete with force=True option.""" + logger.info("Deleting index with force=True") + await self.async_vector_index.create(replace=True) + await self.async_vector_index.delete(force=True) + await asyncio.sleep(1) + await self.assert_index_count("^test_vector_index$", 0) + logger.info("Force delete verified") + + async def test_5119(self): + """Test delete with force=False option.""" + logger.info("Creating index before delete (force=False)") + await self.async_vector_index.create(replace=True) + logger.info("Deleting index with force=False") + await self.async_vector_index.delete(force=False) + await asyncio.sleep(1) + await self.assert_index_count("^test_vector_index$", 0) + logger.info("Delete verified successfully with force=False") + + async def test_5120(self): + """Test delete with force=False called twice in a row.""" + logger.info("Deleting index first time (force=False)") + await self.async_vector_index.delete(force=False) + await asyncio.sleep(1) + await self.assert_index_count("^test_vector_index$", 0) + logger.info("First delete succeeded") + logger.info("Attempting second delete (expected to fail)") + with pytest.raises(Exception) as exc_info: + await self.async_vector_index.delete(force=False) + assert "does not exist" in str(exc_info.value) + logger.info("Expected failure confirmed: %s", exc_info.value) + await self.assert_index_count("^test_vector_index$", 0) + logger.info("Index still absent after failed second delete") + + async def test_5121(self): + """Test delete include_data=False and force=False (leftover vectab).""" + logger.info("Deleting index with include_data=False and force=False") + await self.async_vector_index.delete( + include_data=False, + force=False, + ) + await asyncio.sleep(1) + await self.assert_index_count("^test_vector_index", 0) + logger.info( + "Attempting to recreate index (expected to fail - leftover data)" + ) + with pytest.raises(oracledb.DatabaseError) as exc_info: + await self.async_vector_index.create(replace=False) + assert "ORA-00955" in str(exc_info.value) + logger.info( + "Expected recreate failure confirmed: %s", exc_info.value + ) + logger.info("Cleaning up leftover vector table") + await self.verify_and_cleanup_vectab("test_vector_index") + logger.info( + "Cleanup complete after include_data=False, force=False delete" + ) diff --git a/tests/vector_index/test_5100_drop_index.py b/tests/vector_index/test_5100_drop_index.py new file mode 100644 index 0000000..9c0ba3c --- /dev/null +++ b/tests/vector_index/test_5100_drop_index.py @@ -0,0 +1,540 @@ +# ----------------------------------------------------------------------------- +# Copyright (c) 2025, Oracle and/or its affiliates. +# +# Licensed under the Universal Permissive License v 1.0 as shown at +# http://oss.oracle.com/licenses/upl. +# ----------------------------------------------------------------------------- + +import logging +import pytest +import select_ai +import oracledb +import time +from select_ai import OracleVectorIndexAttributes + +# Set up global logger (one per module) +logger = logging.getLogger("TestDeleteVectorIndex") + + +@pytest.fixture(scope="class", autouse=True) +def setup_logging(): + logging.basicConfig( + format="%(asctime)s %(levelname)s %(name)s %(message)s", + level=logging.INFO + ) + + +@pytest.fixture(scope="class") +def delete_vec_params( + request, + vcidx_params, +): + request.cls.delete_vec_params = vcidx_params + + +@pytest.fixture(scope="class", autouse=True) +def setup_and_teardown(request, connect, delete_vec_params): + logger.info("=== Setting up TestDeleteVectorIndex class ===") + + # 'connect' fixture from base tests/conftest.py ensures DB connection exists. + # Do NOT disconnect here; let the session fixture own lifecycle. + assert select_ai.is_connected(), "Connection to DB failed" + + logger.info("Fetching credential secrets and OCI configuration...") + request.cls.create_credential() + request.cls.profile = request.cls.create_profile() + logger.info("Setup complete.") + + yield + + logger.info("=== Tearing down TestDeleteVectorIndex class ===") + request.cls.delete_profile(request.cls.profile) + request.cls.delete_credential() + logger.info("Teardown complete.\n") + + +@pytest.fixture(autouse=True) +def log_test_name(request): + logger.info(f"--- Starting test: {request.function.__name__} ---") + yield + logger.info(f"--- Finished test: {request.function.__name__} ---") + + +@pytest.mark.usefixtures("delete_vec_params", "setup_and_teardown") +class TestDeleteVectorIndex: + @classmethod + def get_native_cred_param(cls, cred_name=None) -> dict: + logger.info(f"Preparing native credential params for: {cred_name}") + params = cls.delete_vec_params + return dict( + credential_name=cred_name, + user_ocid=params["user_ocid"], + tenancy_ocid=params["tenancy_ocid"], + private_key=params["private_key"], + fingerprint=params["fingerprint"] + ) + + @classmethod + def get_cred_param(cls, cred_name=None) -> dict: + logger.info(f"Preparing basic credential params for: {cred_name}") + params = cls.delete_vec_params + return dict( + credential_name=cred_name, + username=params["cred_username"], + password=params["cred_password"] + ) + + @classmethod + def create_credential(cls, genai_cred="GENAI_CRED", objstore_cred="OBJSTORE_CRED"): + logger.info(f"Creating credentials: {genai_cred}, {objstore_cred}") + + genai_credential = cls.get_native_cred_param(genai_cred) + try: + logger.info(f"Creating GenAI credential: {genai_cred}") + select_ai.create_credential(credential=genai_credential, replace=True) + logger.info("GenAI credential created.") + except Exception as e: + logger.error(f"create_credential() raised {e} unexpectedly.") + raise AssertionError(f"create_credential() raised {e} unexpectedly.") + + # Only create OBJSTORE_CRED if creds are provided in env + params = cls.delete_vec_params + if params.get("cred_username") and params.get("cred_password"): + objstore_credential = cls.get_cred_param(objstore_cred) + try: + logger.info(f"Creating ObjectStore credential: {objstore_cred}") + select_ai.create_credential(credential=objstore_credential, replace=True) + logger.info("ObjectStore credential created.") + except Exception as e: + logger.error(f"create_credential() raised {e} unexpectedly.") + raise AssertionError(f"create_credential() raised {e} unexpectedly.") + else: + logger.info("Skipping ObjectStore credential creation (CRED_USERNAME/CRED_PASSWORD not set).") + + @classmethod + def create_profile(cls, profile_name="vector_ai_profile"): + logger.info(f"Creating Profile: {profile_name}") + params = cls.delete_vec_params + provider = select_ai.OCIGenAIProvider( + oci_compartment_id=params["oci_compartment_id"], + oci_apiformat="GENERIC", + embedding_model="cohere.embed-english-v3.0", + ) + profile_attributes = select_ai.ProfileAttributes( + credential_name="GENAI_CRED", + provider=provider + ) + profile = select_ai.Profile( + profile_name=profile_name, + attributes=profile_attributes, + description="OCI GENAI Profile", + replace=True + ) + logger.info(f"Profile '{profile_name}' created successfully.") + return profile + + @classmethod + def delete_profile(cls, profile): + logger.info("Deleting profile...") + try: + profile.delete() + logger.info(f"Profile '{profile.profile_name}' deleted successfully.") + except Exception as e: + logger.error(f"profile.delete() raised {e} unexpectedly.") + raise AssertionError(f"profile.delete() raised {e} unexpectedly.") + + @classmethod + def delete_credential(cls): + logger.info("Deleting credentials...") + try: + select_ai.delete_credential("GENAI_CRED", force=True) + logger.info("Deleted credential 'GENAI_CRED'") + except Exception as e: + logger.warning(f"delete_credential() raised {e} unexpectedly.") + try: + select_ai.delete_credential("OBJSTORE_CRED", force=True) + logger.info("Deleted credential 'OBJSTORE_CRED'") + except Exception as e: + logger.warning(f"delete_credential() raised {e} unexpectedly.") + + def delete_and_wait(self, force=True, pattern=".*", wait_seconds=1): + logger.info("Deleting indexes matching pattern.") + all_indexes = list(self.vecidx.list(index_name_pattern=pattern)) + if not all_indexes: + logger.info("No indexes found to delete.") + return + for idx in all_indexes: + try: + idx.delete(force=force) + logger.info(f"Deleted index: {idx.index_name}") + time.sleep(wait_seconds) + except Exception as e: + logger.warning(f"Warning: failed to delete index {idx.index_name}: {e}") + remaining = list(self.vecidx.list(index_name_pattern=pattern)) + logger.info(f"Remaining indexes after delete: {[i.index_name for i in remaining]}") + + def setup_method(self, method): + logger.info(f"SetUp for {method.__name__}") + self.objstore_cred = "OBJSTORE_CRED" + self.vecidx = select_ai.VectorIndex() + params = self.delete_vec_params + + self.vector_index_attributes = OracleVectorIndexAttributes( + location=params["embedding_location"], + object_storage_credential_name=self.objstore_cred + ) + + self.delete_and_wait() + + self.index_name = "test_vector_index" + self.vector_index = select_ai.VectorIndex( + index_name=self.index_name, + attributes=self.vector_index_attributes, + description="Test vector index", + profile=self.profile + ) + self.vector_index.create(replace=True) + logger.info(f"Vector index '{self.index_name}' created for test.") + + def teardown_method(self, method): + logger.info(f"TearDown for {method.__name__}") + try: + self.vector_index.delete(force=True) + logger.info(f"Vector index '{self.index_name}' deleted successfully.") + except Exception as e: + logger.warning(f"Warning: vector index cleanup failed: {e}") + + def assert_index_count(self, pattern, expected): + actual = list(self.vecidx.list(index_name_pattern=pattern)) + logger.info(f"Indexes matching '{pattern}': {[i.index_name for i in actual]}") + assert len(actual) == expected, f"Expected {expected} indexes, got {len(actual)}" + + def verify_and_cleanup_vectab(self, vector_index_name: str): + table_name = f"{vector_index_name}$vectab".upper() + logger.info(f"Verifying and cleaning up vector table: {table_name}") + with select_ai.cursor() as cursor: + cursor.execute(""" + SELECT column_name + FROM user_tab_columns + WHERE table_name = :table_name + ORDER BY column_id + """, {"table_name": table_name}) + cols = [c[0] for c in cursor.fetchall()] + logger.info(f"Columns found in {table_name}: {cols}") + expected_cols = ["CONTENT", "ATTRIBUTES", "EMBEDDING"] + assert cols == expected_cols, f"Unexpected columns for {table_name}: {cols}" + cursor.execute(f"DROP TABLE {table_name} PURGE") + logger.info(f"Table {table_name} dropped successfully.") + + def test_5101(self): + """Test single vector index deletion removes the index.""" + logger.info("Deleting vector index (single delete)") + self.assert_index_count("^test_vector_index", 1) + self.vector_index.delete(force=True) + logger.info("Delete called on vector index.") + time.sleep(1) + self.assert_index_count("^test_vector_index", 0) + logger.info("Single-delete verified: index removed") + + # def test_5102(self): + # """Test multiple creates then bulk delete.""" + # logger.info("Creating multiple vector indexes") + # # Create multiple indexes + # for i in range(5): + # idx = select_ai.VectorIndex( + # index_name=f"TEST_VECTOR_INDEX_{i}", + # attributes=self.vector_index_attributes, + # profile=self.profile + # ) + # idx.create(replace=True) + # logger.info(f"Created index TEST_VECTOR_INDEX_{i}") + # logger.info("Deleting all created indexes") + # # Delete all indexes one by one + # self.delete_and_wait(force=True, pattern=f"^TEST_VECTOR_INDEX_{i}$") + # # Ensure all are gone + # actual_indexes = list(self.vector_index.list(index_name_pattern="^TEST_VECTOR_INDEX_")) + # logger.info(f"Indexes found for bulk delete test: {actual_indexes}") + # assert len(actual_indexes) == 0 + # logger.info("All TEST_VECTOR_INDEX_* deleted successfully.") + + def test_5103(self): + """Test deleting the same vector index twice.""" + logger.info("Deleting vector index first time") + self.vector_index.delete(force=True) + logger.info("Deleting vector index second time (no-op expected)") + time.sleep(1) + self.vector_index.delete(force=True) # no-op + time.sleep(1) + self.assert_index_count("^test_vector_index", 0) + logger.info("Double-delete verified: index removed") + + def test_5104(self): + """Test delete with include_data=True also removes table.""" + logger.info("Deleting index with include_data=True (metadata + table)") + self.vector_index.delete(include_data=True, force=True) + time.sleep(1) + self.assert_index_count("^test_vector_index", 0) + table_name = "TEST_VECTOR_INDEX$VECTAB" + with select_ai.cursor() as cursor: + cursor.execute(""" + SELECT COUNT(*) + FROM user_tables + WHERE table_name = :table_name + """, {"table_name": table_name}) + (count,) = cursor.fetchone() + logger.info(f"Verified vector table '{table_name}' removed: {count==0}") + assert count == 0 + + def test_5105(self): + """Test delete with include_data=False doesn't remove table.""" + logger.info("Deleting index with include_data=False (metadata only)") + self.vector_index.delete(include_data=False, force=True) + time.sleep(1) + self.assert_index_count("^test_vector_index", 0) + logger.info("Attempting to recreate index (should fail due to leftover table)") + with pytest.raises(oracledb.DatabaseError) as exc_info: + self.vector_index.create(replace=True) + logger.info( + "Expected DatabaseError on recreate: %s", + exc_info.value, + ) + assert "ORA-00955" in str(exc_info.value) + self.verify_and_cleanup_vectab("test_vector_index") + logger.info("Vector table cleaned up after failed recreate") + + def test_5106(self): + """Test delete twice with include_data=False then cleanup.""" + logger.info("Deleting index metadata only first time") + self.vector_index.delete(include_data=False, force=True) + time.sleep(1) + self.assert_index_count("^test_vector_index", 0) + logger.info("Attempting to recreate index (should fail)") + with pytest.raises(oracledb.DatabaseError) as exc_info: + self.vector_index.create(replace=True) + assert "ORA-00955" in str(exc_info.value) + self.verify_and_cleanup_vectab("test_vector_index") + logger.info("Vector table cleaned up") + logger.info("Deleting index metadata only second time (no-op)") + self.vector_index.delete(include_data=False, force=True) + self.assert_index_count("^test_vector_index", 0) + + def test_5107(self): + """Test delete twice with include_data=False and cleanup after failed recreate.""" + logger.info("Deleting metadata only first time") + self.vector_index.delete(include_data=False, force=True) + time.sleep(1) + self.assert_index_count("^test_vector_index", 0) + logger.info("Attempting recreate (expected failure)") + with pytest.raises(oracledb.DatabaseError) as exc_info: + self.vector_index.create(replace=True) + logger.info( + "Recreate failed (expected): %s", + exc_info.value, + ) + assert "ORA-00955" in str(exc_info.value) + logger.info("Deleting metadata second time (no-op)") + self.vector_index.delete(include_data=False, force=True) + self.verify_and_cleanup_vectab("test_vector_index") + self.assert_index_count("^test_vector_index", 0) + logger.info("Cleanup complete") + + def test_5108(self): + """Test delete and then recreate a vector index.""" + logger.info("Deleting index before recreation") + self.vector_index.delete(force=True) + time.sleep(1) + logger.info("Recreating vector index") + self.vector_index.create(replace=True) + time.sleep(1) + self.assert_index_count("^test_vector_index", 1) + logger.info("Recreate verified: index exists") + + def test_5109(self): + """Test delete of a nonexistent index (should not error).""" + idx = select_ai.VectorIndex( + index_name="nonexistent_index", + attributes=self.vector_index_attributes, + profile=self.profile + ) + logger.info("Attempting to delete nonexistent index") + idx.delete(force=True) + time.sleep(1) + self.assert_index_count("^nonexistent_index", 0) + logger.info("Nonexistent delete verified (no error)") + + def test_5110(self): + """Test delete after set_attribute was called.""" + logger.info("Setting temporary attributes before delete") + self.vector_index.create(replace=True) + try: + self.vector_index.set_attribute( + attribute_name="match_limit", + attribute_value=10 + ) + except Exception as e: + logger.error(f"set_attribute() raised {e} unexpectedly.") + pytest.fail(f"set_attribute() raised {e} unexpectedly.") + logger.info("Deleting index after setting attributes") + self.vector_index.delete(force=True) + time.sleep(1) + actual_indexes = list(self.vector_index.list(index_name_pattern="^test_vector_index$")) + logger.info(f"Indexes remaining after delete: {actual_indexes}") + assert len(actual_indexes) == 0 + logger.info("Delete after attributes verified") + + def test_5111(self): + """Test case-sensitive name for create and delete.""" + idx = select_ai.VectorIndex( + index_name="CaseSensitiveIndex", + attributes=self.vector_index_attributes, + profile=self.profile + ) + logger.info("Creating case-sensitive index") + idx.create(replace=True) + logger.info("Deleting case-sensitive index") + idx.delete(force=True) + time.sleep(1) + self.assert_index_count("^CaseSensitiveIndex", 0) + logger.info("Case-sensitive index delete verified") + + def test_5112(self): + """Test creation and deletion with long index name.""" + long_name = "index_" + "x" * 40 + idx = select_ai.VectorIndex( + index_name=long_name, + attributes=self.vector_index_attributes, + profile=self.profile + ) + logger.info(f"Creating long-name index: {long_name}") + idx.create(replace=True) + logger.info(f"Deleting long-name index: {long_name}") + idx.delete(force=True) + time.sleep(1) + self.assert_index_count(f"^{long_name}$", 0) + logger.info("Long-name index delete verified") + + def test_5113(self): + """Test creation and bulk deletion of indexes.""" + names = [f"bulk_idx_{i}" for i in range(3)] + logger.info("Creating bulk indexes") + for n in names: + select_ai.VectorIndex( + index_name=n, + attributes=self.vector_index_attributes, + profile=self.profile + ).create(replace=True) + logger.info(f"Created {n}") + logger.info("Deleting bulk indexes") + for n in names: + select_ai.VectorIndex( + index_name=n, + attributes=self.vector_index_attributes, + profile=self.profile + ).delete(force=True) + time.sleep(1) + logger.info(f"Deleted {n}") + self.assert_index_count("^bulk_idx_", 0) + logger.info("Bulk delete verified") + + def test_5114(self): + """Test that list returns empty after index is deleted.""" + logger.info("Deleting index and verifying list is empty") + self.vector_index.delete(force=True) + time.sleep(1) + actual_indexes = list(self.vector_index.list(index_name_pattern=".*")) + index_names = [idx.index_name for idx in actual_indexes] + logger.info(f"Actual indexes after delete: {index_names}") + actual = list(self.vector_index.list(index_name_pattern="^test_vector_index")) + assert actual == [] + logger.info("List verification successful: no remaining indexes") + + def test_5115(self): + """Test delete then recreate index with same name.""" + logger.info("Deleting index before recreate with same name") + self.vector_index.delete(force=True) + time.sleep(1) + logger.info("Recreating index with same name") + self.vector_index.create(replace=False) + time.sleep(1) + self.assert_index_count("^test_vector_index", 1) + logger.info("Recreate same-name index verified") + + def test_5116(self): + """Test delete of one out of multiple indexes.""" + idx1 = select_ai.VectorIndex(index_name="IDX_1", attributes=self.vector_index_attributes, profile=self.profile) + idx2 = select_ai.VectorIndex(index_name="IDX_2", attributes=self.vector_index_attributes, profile=self.profile) + logger.info("Creating two indexes IDX_1 and IDX_2") + idx1.create(replace=True) + idx2.create(replace=True) + logger.info("Deleting IDX_1 only") + self.delete_and_wait(force=True, pattern="^IDX_1$") + remaining_idx2 = list(self.vector_index.list(index_name_pattern="^IDX_2$")) + logger.info(f"IDX_2 entries after IDX_1 delete: {remaining_idx2}") + assert len(remaining_idx2) == 1 + logger.info("IDX_2 remains after IDX_1 delete") + + def test_5117(self): + """Test deletion by pattern.""" + logger.info("Deleting index with pattern '^test_vector_index$'") + self.vector_index.create(replace=True) + self.vector_index.delete(force=True) + time.sleep(1) + actual = list(self.vector_index.list(index_name_pattern="^test_vector_index$")) + logger.info(f"List entries after pattern delete: {actual}") + assert len(actual) == 0 + logger.info("Pattern delete verified") + + def test_5118(self): + """Test delete with force=True option.""" + logger.info("Deleting index with force=True") + self.vector_index.create(replace=True) + self.vector_index.delete(force=True) + time.sleep(1) + self.assert_index_count("^test_vector_index$", 0) + logger.info("Force delete verified") + + def test_5119(self): + """Test delete with force=False option.""" + logger.info("Creating index before delete (force=False)") + self.vector_index.create(replace=True) + logger.info("Deleting index with force=False") + self.vector_index.delete(force=False) + time.sleep(1) + self.assert_index_count("^test_vector_index$", 0) + logger.info("Delete verified successfully with force=False") + + def test_5120(self): + """Test delete with force=False called twice in a row.""" + logger.info("Deleting index first time (force=False)") + self.vector_index.delete(force=False) + time.sleep(1) + self.assert_index_count("^test_vector_index$", 0) + logger.info("First delete succeeded") + logger.info("Attempting second delete (expected to fail)") + with pytest.raises(Exception) as exc_info: + self.vector_index.delete(force=False) + assert "does not exist" in str(exc_info.value) + logger.info( + "Expected failure confirmed: %s", + exc_info.value, + ) + self.assert_index_count("^test_vector_index$", 0) + logger.info("Index still absent after failed second delete") + + def test_5121(self): + """Test delete include_data=False and force=False (leftover vectab)""" + logger.info("Deleting index with include_data=False and force=False") + self.vector_index.delete(include_data=False, force=False) + time.sleep(1) + self.assert_index_count("^test_vector_index", 0) + logger.info("Attempting to recreate index (expected to fail - leftover data)") + with pytest.raises(oracledb.DatabaseError) as exc_info: + self.vector_index.create(replace=False) + assert "ORA-00955" in str(exc_info.value) + logger.info( + "Expected recreate failure confirmed: %s", + exc_info.value, + ) + logger.info("Cleaning up leftover vector table") + self.verify_and_cleanup_vectab("test_vector_index") + logger.info("Cleanup complete after include_data=False, force=False delete") diff --git a/tests/vector_index/test_5200_async_setindex_attributes.py b/tests/vector_index/test_5200_async_setindex_attributes.py new file mode 100644 index 0000000..4c0a41a --- /dev/null +++ b/tests/vector_index/test_5200_async_setindex_attributes.py @@ -0,0 +1,959 @@ +# ----------------------------------------------------------------------------- +# Copyright (c) 2025, Oracle and/or its affiliates. +# +# Licensed under the Universal Permissive License v 1.0 as shown at +# http://oss.oracle.com/licenses/upl. +# ----------------------------------------------------------------------------- + +import logging + +import oracledb +import pytest +import select_ai +from select_ai import AsyncVectorIndex, OracleVectorIndexAttributes +from select_ai import VectorIndexAttributes +from select_ai.errors import DatabaseNotConnectedError + +logger = logging.getLogger("TestAsyncSetVectorIndexAttributes") + +pytestmark = pytest.mark.anyio + + +@pytest.fixture(scope="class") +def set_vec_params(request, vcidx_params): + request.cls.set_vec_params = vcidx_params + + +@pytest.fixture(scope="class", autouse=True) +async def setup_and_teardown(request, async_connect, set_vec_params, test_env): + logger.info("=== Setting up TestAsyncSetVectorIndexAttributes class ===") + p = request.cls.set_vec_params + + assert await select_ai.async_is_connected(), "Connection to DB failed" + + request.cls.user = p["user"] + request.cls.password = p["password"] + request.cls.dsn = p["dsn"] + request.cls.user_ocid = p["user_ocid"] + request.cls.tenancy_ocid = p["tenancy_ocid"] + request.cls.private_key = p["private_key"] + request.cls.fingerprint = p["fingerprint"] + request.cls.cred_username = p["cred_username"] + request.cls.cred_password = p["cred_password"] + request.cls.oci_compartment_id = p["oci_compartment_id"] + request.cls.embedding_location = p["embedding_location"] + request.cls.objstore_cred = "OBJSTORE_CRED" + request.cls.reconnect_params = test_env.connect_params() + + logger.info("Fetching credential secrets and OCI configuration...") + await request.cls.create_credential() + request.cls.profile = await request.cls.create_profile() + logger.info("Profile 'vector_ai_profile' created successfully.") + + request.cls.index_name = "test_vector_index_attr" + vi_attrs = OracleVectorIndexAttributes( + location=p["embedding_location"], + object_storage_credential_name="OBJSTORE_CRED", + ) + request.cls.vector_index_attributes = vi_attrs + + vi = AsyncVectorIndex( + index_name=request.cls.index_name, + attributes=vi_attrs, + description="Test vector index", + profile=request.cls.profile, + ) + await vi.create(replace=True) + + created_indexes = [ + idx.index_name async for idx in AsyncVectorIndex.list() + ] + assert request.cls.index_name.upper() in created_indexes, ( + f"VectorIndex {request.cls.index_name} was not created" + ) + + yield + + logger.info("=== Tearing down TestAsyncSetVectorIndexAttributes class ===") + try: + vector_index = AsyncVectorIndex(index_name=request.cls.index_name) + await vector_index.delete(force=True) + except Exception as exc: + logger.warning("Warning: drop vector index failed: %s", exc) + + try: + await request.cls.profile.delete() + except Exception as exc: + logger.warning("profile.delete() raised %s unexpectedly.", exc) + + await request.cls.delete_credential() + logger.info("Teardown complete.\n") + + +@pytest.fixture(autouse=True) +async def vector_index_state(request): + logger.info("--- Starting test: %s ---", request.function.__name__) + indexes = [ + idx async for idx in AsyncVectorIndex.list( + index_name_pattern=request.cls.index_name + ) + ] + if not indexes: + logger.info( + "Vector index %s missing; recreating baseline test state.", + request.cls.index_name, + ) + await request.cls.create_credential() + request.cls.profile = await request.cls.create_profile( + profile_name=request.cls.profile.profile_name + ) + await AsyncVectorIndex( + index_name=request.cls.index_name, + attributes=request.cls.vector_index_attributes, + description="Test vector index", + profile=request.cls.profile, + ).create(replace=True) + indexes = [ + idx async for idx in AsyncVectorIndex.list( + index_name_pattern=request.cls.index_name + ) + ] + request.cls.async_vector_index = indexes[0] + yield + logger.info("--- Finished test: %s ---", request.function.__name__) + + +@pytest.mark.usefixtures("set_vec_params", "setup_and_teardown") +class TestAsyncSetVectorIndexAttributes: + @classmethod + def get_native_cred_param(cls, cred_name=None) -> dict: + logger.info("Preparing native credential params for: %s", cred_name) + p = cls.set_vec_params + return dict( + credential_name=cred_name, + user_ocid=p["user_ocid"], + tenancy_ocid=p["tenancy_ocid"], + private_key=p["private_key"], + fingerprint=p["fingerprint"], + ) + + @classmethod + def get_cred_param(cls, cred_name=None) -> dict: + logger.info("Preparing basic credential params for: %s", cred_name) + p = cls.set_vec_params + return dict( + credential_name=cred_name, + username=p["cred_username"], + password=p["cred_password"], + ) + + @classmethod + async def create_credential( + cls, genai_cred="GENAI_CRED", objstore_cred="OBJSTORE_CRED" + ): + logger.info("Creating credentials: %s, %s", genai_cred, objstore_cred) + + genai_credential = cls.get_native_cred_param(genai_cred) + await select_ai.async_create_credential( + credential=genai_credential, + replace=True, + ) + + p = cls.set_vec_params + if p.get("cred_username") and p.get("cred_password"): + objstore_credential = cls.get_cred_param(objstore_cred) + await select_ai.async_create_credential( + credential=objstore_credential, + replace=True, + ) + logger.info("Credentials created.") + else: + logger.info( + "Skipping ObjectStore credential creation " + "(CRED_USERNAME/CRED_PASSWORD not set)." + ) + + @classmethod + async def create_profile(cls, profile_name="vector_ai_profile"): + p = cls.set_vec_params + return await select_ai.AsyncProfile( + profile_name=profile_name, + attributes=select_ai.ProfileAttributes( + credential_name="GENAI_CRED", + provider=select_ai.OCIGenAIProvider( + oci_compartment_id=p["oci_compartment_id"], + oci_apiformat="GENERIC", + embedding_model="cohere.embed-english-v3.0", + ), + ), + description="OCI GENAI Profile", + replace=True, + ) + + @classmethod + async def delete_profile(cls, profile): + return await profile.delete() + + @classmethod + async def delete_credential(cls): + try: + await select_ai.async_delete_credential("GENAI_CRED", force=True) + except Exception as exc: + logger.warning("delete_credential() raised %s unexpectedly.", exc) + try: + await select_ai.async_delete_credential("OBJSTORE_CRED", force=True) + except Exception as exc: + logger.warning("delete_credential() raised %s unexpectedly.", exc) + + async def test_5201(self): + """Update 'match_limit' attribute.""" + logger.info("Testing update of 'match_limit' attribute...") + await self.async_vector_index.set_attribute("match_limit", 10) + attrs = await self.async_vector_index.get_attributes() + logger.info("Updated match_limit: %s", attrs.match_limit) + assert attrs.match_limit == 10 + logger.info("Match limit update verified successfully.") + + async def test_5202(self): + """Update 'similarity_threshold' attribute.""" + logger.info("Testing update of 'similarity_threshold' attribute...") + await self.async_vector_index.set_attribute( + "similarity_threshold", 0.8 + ) + attrs = await self.async_vector_index.get_attributes() + logger.info( + "Updated similarity_threshold: %s", attrs.similarity_threshold + ) + assert attrs.similarity_threshold == 0.8 + logger.info("Similarity threshold update verified successfully.") + + async def test_5203(self): + """Update multiple attributes with VectorIndexAttributes object.""" + logger.info( + "Testing update of multiple attributes via " + "VectorIndexAttributes object..." + ) + update_attrs = VectorIndexAttributes( + match_limit=5, + similarity_threshold=0.5, + location=self.embedding_location, + refresh_rate=40, + ) + with pytest.raises(oracledb.DatabaseError) as exc_info: + await self.async_vector_index.set_attributes(attributes=update_attrs) + logger.info( + "Expected DatabaseError raised for restricted attribute update: %s", + exc_info.value, + ) + assert "ORA-20047" in str(exc_info.value) + logger.info("Restricted multi-attribute update rejected as expected.") + + async def test_5204(self): + """Repeated update of the same attribute 'similarity_threshold'.""" + logger.info( + "Testing repeated update of the same attribute " + "'similarity_threshold'..." + ) + await self.async_vector_index.set_attribute( + "similarity_threshold", 0.8 + ) + await self.async_vector_index.set_attribute( + "similarity_threshold", 0.5 + ) + attrs = await self.async_vector_index.get_attributes() + logger.info( + "Final similarity_threshold value: %s", + attrs.similarity_threshold, + ) + assert attrs.similarity_threshold == 0.5 + logger.info("Repeated attribute update verified successfully.") + + async def test_5205(self): + """Update 'match_limit' with maximum allowed value.""" + logger.info( + "Testing update of 'match_limit' with maximum allowed value..." + ) + max_limit = 8192 + await self.async_vector_index.set_attribute("match_limit", max_limit) + attrs = await self.async_vector_index.get_attributes() + logger.info("Set match_limit to: %s", attrs.match_limit) + assert attrs.match_limit == max_limit + logger.info("Max value for match_limit verified successfully.") + + async def test_5206(self): + """Update match_limit with minimum value.""" + logger.info("Testing update of match_limit with minimum value...") + min_limit = 1 + await self.async_vector_index.set_attribute("match_limit", min_limit) + logger.info( + "Set match_limit to %s, fetching attributes for verification...", + min_limit, + ) + attrs = await self.async_vector_index.get_attributes() + assert attrs.match_limit == min_limit + logger.info("match_limit minimum value update verified successfully.") + + async def test_5207(self): + """Update match_limit with zero value.""" + logger.info("Testing update of match_limit with zero value...") + min_limit = 0 + await self.async_vector_index.set_attribute("match_limit", min_limit) + logger.info("Fetching attributes to verify zero value update...") + attrs = await self.async_vector_index.get_attributes() + assert attrs.match_limit == min_limit + logger.info("match_limit zero value update verified successfully.") + + async def test_5208(self): + """Update profile_name with temporary profile.""" + temp_profile_name = "vector_ai_profile_temp" + temp_profile = await self.create_profile(profile_name=temp_profile_name) + logger.info("Temporary profile created: %s", temp_profile_name) + await self.async_vector_index.set_attribute( + "profile_name", temp_profile_name + ) + logger.info( + "Set profile_name to %s, fetching attributes...", + temp_profile_name, + ) + attrs = await self.async_vector_index.get_attributes() + logger.info( + "VectorIndex attributes after profile update: %s", + attrs.__dict__, + ) + assert attrs.profile_name == temp_profile_name + vec_index = await AsyncVectorIndex.fetch(self.index_name) + logger.info( + "Persisted VectorIndex after profile update: %s", + vec_index.__dict__, + ) + assert attrs.profile_name == vec_index.profile.profile_name + logger.info("Persisted VectorIndex reflects updated profile correctly.") + await self.delete_profile(temp_profile) + logger.info("Temporary profile deleted: %s", temp_profile_name) + await self.async_vector_index.set_attribute( + "profile_name", self.profile.profile_name + ) + attrs = await self.async_vector_index.get_attributes() + logger.info("VectorIndex reset to default profile, verifying...") + assert attrs.profile_name == self.profile.profile_name + logger.info("VectorIndex profile reset verified successfully.") + + async def test_5209(self): + """Update profile_name and then delete profile scenario.""" + logger.info( + "Testing update of profile_name followed by delete scenario..." + ) + temp_profile_name = "vector_ai_profile_temp" + temp_profile = await self.create_profile(profile_name=temp_profile_name) + logger.info("Temporary profile created: %s", temp_profile_name) + await self.async_vector_index.set_attribute( + "profile_name", temp_profile_name + ) + logger.info( + "Set profile_name to %s, verifying update...", + temp_profile_name, + ) + attrs = await self.async_vector_index.get_attributes() + assert attrs.profile_name == temp_profile_name + vec_index = await AsyncVectorIndex.fetch(self.index_name) + logger.info( + "Persisted VectorIndex after profile update: %s", + vec_index.__dict__, + ) + assert attrs.profile_name == vec_index.profile.profile_name + await self.delete_profile(temp_profile) + logger.info("Temporary profile deleted: %s", temp_profile_name) + logger.info( + "Verifying VectorIndex retains deleted profile name reference..." + ) + vec_index = await AsyncVectorIndex.fetch(self.index_name) + attrs = await vec_index.get_attributes() + assert attrs.profile_name == temp_profile_name + logger.info( + "VectorIndex still references deleted profile name as expected." + ) + await self.async_vector_index.set_attribute( + "profile_name", self.profile.profile_name + ) + attrs = await self.async_vector_index.get_attributes() + logger.info( + "Reset VectorIndex profile to default: %s", attrs.__dict__ + ) + assert attrs.profile_name == self.profile.profile_name + logger.info("Profile reset after delete verified successfully.") + + async def test_5210(self): + """Deleted profile leaves a stale profile_name on VectorIndex.""" + logger.info("Testing deleted profile behavior via AsyncVectorIndex fetch...") + temp_profile_name = "vector_ai_profile_temp" + temp_profile = await self.create_profile(profile_name=temp_profile_name) + await self.async_vector_index.set_attribute( + "profile_name", temp_profile_name + ) + await self.delete_profile(temp_profile) + vec_index = await AsyncVectorIndex.fetch(self.index_name) + attrs = await vec_index.get_attributes() + logger.info( + "Fetched AsyncVectorIndex after profile delete: profile=%s attrs=%s", + vec_index.profile, + attrs.__dict__, + ) + assert vec_index.profile is None + assert attrs.profile_name == temp_profile_name + await self.async_vector_index.set_attribute( + "profile_name", self.profile.profile_name + ) + logger.info("Deleted profile behavior verified successfully.") + + async def test_5211(self): + """Update refresh_rate attribute.""" + logger.info("Testing update of refresh_rate attribute...") + await self.async_vector_index.set_attribute("refresh_rate", 30) + attrs = await self.async_vector_index.get_attributes() + assert attrs.refresh_rate == 30 + + async def test_5212(self): + """Update object_storage_credential_name, handle pipeline.""" + logger.info( + "Testing update of object_storage_credential_name " + "with pipeline handling..." + ) + attrs = await self.async_vector_index.get_attributes() + pipeline_name = attrs.pipeline_name + logger.info("Retrieved pipeline name: %s", pipeline_name) + logger.info("Stopping pipeline: %s", pipeline_name) + async with select_ai.async_cursor() as cursor: + await cursor.execute( + "BEGIN dbms_cloud_pipeline.stop_pipeline(" + "pipeline_name => :1); END;", + [pipeline_name], + ) + logger.info("Pipeline '%s' stopped successfully.", pipeline_name) + objstore_credential = self.get_cred_param("TEMP_OBJSTORE_CRED") + logger.info( + "Creating temporary Object Store credential: TEMP_OBJSTORE_CRED" + ) + try: + await select_ai.async_create_credential( + credential=objstore_credential, + replace=True, + ) + logger.info("TEMP_OBJSTORE_CRED created successfully.") + except Exception as exc: + raise AssertionError( + "create_credential() raised an unexpected exception: " + f"{exc}" + ) + logger.info("Updating vector index with TEMP_OBJSTORE_CRED...") + await self.async_vector_index.set_attribute( + "object_storage_credential_name", + "TEMP_OBJSTORE_CRED", + ) + attrs = await self.async_vector_index.get_attributes() + logger.info("Updated credential: %s", attrs.object_storage_credential_name) + assert attrs.object_storage_credential_name == "TEMP_OBJSTORE_CRED" + logger.info("Deleting temporary credential: TEMP_OBJSTORE_CRED") + try: + await select_ai.async_delete_credential( + "TEMP_OBJSTORE_CRED", force=True + ) + logger.info("TEMP_OBJSTORE_CRED deleted successfully.") + except Exception as exc: + pytest.fail( + f"delete_credential() raised unexpected exception: {exc}" + ) + logger.info("Restoring original Object Store credential: OBJSTORE_CRED") + await self.async_vector_index.set_attribute( + "object_storage_credential_name", + "OBJSTORE_CRED", + ) + logger.info("Restarting pipeline: %s", pipeline_name) + async with select_ai.async_cursor() as cursor: + await cursor.execute( + "BEGIN dbms_cloud_pipeline.start_pipeline(" + "pipeline_name => :1); END;", + [pipeline_name], + ) + logger.info("Pipeline '%s' restarted successfully.", pipeline_name) + attrs = await self.async_vector_index.get_attributes() + assert attrs.object_storage_credential_name == "OBJSTORE_CRED" + logger.info("Object Store credential restored successfully.") + + async def test_5213(self): + """Update object_storage_credential_name with delete handling.""" + logger.info( + "Testing update of object_storage_credential_name " + "with delete handling..." + ) + attrs = await self.async_vector_index.get_attributes() + pipeline_name = attrs.pipeline_name + logger.info("Retrieved pipeline name: %s", pipeline_name) + logger.info("Stopping pipeline: %s", pipeline_name) + async with select_ai.async_cursor() as cursor: + await cursor.execute( + "BEGIN dbms_cloud_pipeline.stop_pipeline(" + "pipeline_name => :1); END;", + [pipeline_name], + ) + logger.info("Pipeline '%s' stopped successfully.", pipeline_name) + objstore_credential = self.get_cred_param("TEMP_OBJSTORE_CRED") + logger.info( + "Creating temporary Object Store credential: TEMP_OBJSTORE_CRED" + ) + try: + await select_ai.async_create_credential( + credential=objstore_credential, + replace=True, + ) + logger.info("TEMP_OBJSTORE_CRED created successfully.") + except Exception as exc: + raise AssertionError( + "create_credential() raised an unexpected exception: " + f"{exc}" + ) + logger.info("Updating vector index with TEMP_OBJSTORE_CRED...") + await self.async_vector_index.set_attribute( + "object_storage_credential_name", + "TEMP_OBJSTORE_CRED", + ) + attrs = await self.async_vector_index.get_attributes() + assert attrs.object_storage_credential_name == "TEMP_OBJSTORE_CRED" + logger.info( + "Credential updated to: %s", attrs.object_storage_credential_name + ) + logger.info("Deleting temporary credential: TEMP_OBJSTORE_CRED") + try: + await select_ai.async_delete_credential( + "TEMP_OBJSTORE_CRED", force=True + ) + logger.info("TEMP_OBJSTORE_CRED deleted successfully.") + except Exception as exc: + pytest.fail( + f"delete_credential() raised unexpected exception: {exc}" + ) + logger.info( + "Verifying that VectorIndex retains deleted credential reference..." + ) + vec_index = await AsyncVectorIndex.fetch(self.index_name) + attrs = await vec_index.get_attributes() + assert attrs.object_storage_credential_name == "TEMP_OBJSTORE_CRED" + logger.info( + "VectorIndex still references deleted credential name as expected." + ) + logger.info("Restoring original Object Store credential: OBJSTORE_CRED") + await select_ai.async_create_credential( + credential=self.get_cred_param("OBJSTORE_CRED"), + replace=True, + ) + await self.async_vector_index.set_attribute( + "object_storage_credential_name", + "OBJSTORE_CRED", + ) + logger.info("Restarting pipeline: %s", pipeline_name) + async with select_ai.async_cursor() as cursor: + await cursor.execute( + "BEGIN dbms_cloud_pipeline.start_pipeline(" + "pipeline_name => :1); END;", + [pipeline_name], + ) + logger.info("Pipeline '%s' restarted successfully.", pipeline_name) + attrs = await self.async_vector_index.get_attributes() + assert attrs.object_storage_credential_name == "OBJSTORE_CRED" + logger.info( + "Object Store credential restoration after delete " + "verified successfully." + ) + + async def test_5214(self): + """Deleted credential leaves VectorIndex unusable until restored.""" + logger.info("Testing missing credential behavior via AsyncVectorIndex create...") + temp_credential_name = "TEMP_OBJSTORE_CRED" + attrs = await self.async_vector_index.get_attributes() + pipeline_name = attrs.pipeline_name + logger.info("Stopping pipeline: %s", pipeline_name) + async with select_ai.async_cursor() as cursor: + await cursor.execute( + "BEGIN dbms_cloud_pipeline.stop_pipeline(" + "pipeline_name => :1); END;", + [pipeline_name], + ) + await select_ai.async_create_credential( + credential=self.get_cred_param(temp_credential_name), + replace=True, + ) + await self.async_vector_index.set_attribute( + "object_storage_credential_name", + temp_credential_name, + ) + await select_ai.async_delete_credential(temp_credential_name, force=True) + failing_index = AsyncVectorIndex( + index_name="test_vector_index_attr_missing_cred", + attributes=OracleVectorIndexAttributes( + location=self.embedding_location, + object_storage_credential_name=temp_credential_name, + ), + description="Missing credential test vector index", + profile=self.profile, + ) + with pytest.raises(oracledb.DatabaseError) as exc_info: + await failing_index.create(replace=True) + logger.info( + "Expected DatabaseError raised for deleted credential reference: %s", + exc_info.value, + ) + assert "ORA-20004" in str(exc_info.value) + assert temp_credential_name in str(exc_info.value) + await select_ai.async_create_credential( + credential=self.get_cred_param("OBJSTORE_CRED"), + replace=True, + ) + await self.async_vector_index.set_attribute( + "object_storage_credential_name", + "OBJSTORE_CRED", + ) + async with select_ai.async_cursor() as cursor: + await cursor.execute( + "BEGIN dbms_cloud_pipeline.start_pipeline(" + "pipeline_name => :1); END;", + [pipeline_name], + ) + logger.info("Deleted credential behavior verified successfully.") + + async def test_5215(self): + """Update multiple attributes together.""" + logger.info("Testing update of multiple attributes together...") + updates = { + "refresh_rate": 50, + "similarity_threshold": 0.8, + "match_limit": 10, + } + for field, value in updates.items(): + logger.info("Updating %s to %s...", field, value) + await self.async_vector_index.set_attribute(field, value) + attrs = await self.async_vector_index.get_attributes() + logger.info("Fetched attributes after updates: %s", attrs.__dict__) + assert attrs.refresh_rate == updates["refresh_rate"] + assert attrs.similarity_threshold == updates["similarity_threshold"] + assert attrs.match_limit == updates["match_limit"] + logger.info("All multiple attribute updates verified successfully.") + + async def test_5216(self): + """Update description (should raise DatabaseError).""" + logger.info( + "Testing update of description attribute " + "(should raise DatabaseError)..." + ) + with pytest.raises(oracledb.DatabaseError) as exc_info: + await self.async_vector_index.set_attribute( + "description", "updated description" + ) + assert "ORA-20048" in str(exc_info.value) + logger.info( + "DatabaseError correctly raised for invalid description update." + ) + + async def test_5217(self): + """Update pipeline_name (should raise DatabaseError).""" + logger.info( + "Testing update of pipeline_name (expected DatabaseError)..." + ) + with pytest.raises(oracledb.DatabaseError) as exc_info: + await self.async_vector_index.set_attribute( + "pipeline_name", "test_pipeline" + ) + assert "ORA-20048" in str(exc_info.value) + attrs = await self.async_vector_index.get_attributes() + assert attrs.pipeline_name == "TEST_VECTOR_INDEX_ATTR$VECPIPELINE" + logger.info( + "Pipeline update correctly raised error and original value retained." + ) + + async def test_5218(self): + """Update chunk_size (should fail).""" + logger.info( + "Testing update of chunk_size (should fail with ORA-20047)..." + ) + attrs = await self.async_vector_index.get_attributes() + original_chunk_size = attrs.chunk_size + logger.info("Current attributes: %s", attrs.__dict__) + with pytest.raises(oracledb.DatabaseError) as exc_info: + await self.async_vector_index.set_attribute("chunk_size", 2048) + assert "ORA-20047" in str(exc_info.value) + attrs = await self.async_vector_index.get_attributes() + assert attrs.chunk_size == original_chunk_size + logger.info( + "chunk_size update prevented successfully; original value " + "verified." + ) + + async def test_5219(self): + """Update chunk_overlap (should fail).""" + logger.info( + "Testing update of chunk_overlap (should fail with ORA-20047)..." + ) + original_chunk_overlap = ( + await self.async_vector_index.get_attributes() + ).chunk_overlap + with pytest.raises(oracledb.DatabaseError) as exc_info: + await self.async_vector_index.set_attribute("chunk_overlap", 256) + assert "ORA-20047" in str(exc_info.value) + attrs = await self.async_vector_index.get_attributes() + assert attrs.chunk_overlap == original_chunk_overlap + logger.info( + "chunk_overlap update prevented successfully; original value " + "verified." + ) + + async def test_5220(self): + """Update vector_distance_metric (should fail).""" + logger.info( + "Testing update of vector_distance_metric " + "(should fail with ORA-20047)..." + ) + original_distance_metric = ( + await self.async_vector_index.get_attributes() + ).vector_distance_metric + with pytest.raises(oracledb.DatabaseError) as exc_info: + await self.async_vector_index.set_attribute( + "vector_distance_metric", "EUCLIDEAN" + ) + assert "ORA-20047" in str(exc_info.value) + attrs = await self.async_vector_index.get_attributes() + assert attrs.vector_distance_metric == original_distance_metric + logger.info("vector_distance_metric update prevented successfully.") + + async def test_5221(self): + """Partial update with VectorIndexAttributes object.""" + logger.info( + "Testing partial update with VectorIndexAttributes object..." + ) + update_attrs = VectorIndexAttributes(match_limit=20, chunk_size=2048) + with pytest.raises(oracledb.DatabaseError) as exc_info: + await self.async_vector_index.set_attributes(attributes=update_attrs) + logger.info( + "Expected DatabaseError raised for partial restricted update: %s", + exc_info.value, + ) + assert "ORA-20047" in str(exc_info.value) + attrs = await self.async_vector_index.get_attributes() + logger.info("Attributes after update attempt: %s", attrs.__dict__) + logger.info("Partial restricted update rejected as expected.") + + async def test_5222(self): + """Update with invalid attribute combinations.""" + logger.info("Testing update with invalid attribute combinations...") + update_attrs = VectorIndexAttributes( + chunk_size=2048, + chunk_overlap=256, + ) + with pytest.raises(oracledb.DatabaseError) as exc_info: + await self.async_vector_index.set_attributes(attributes=update_attrs) + logger.info( + "Expected DatabaseError raised for invalid attribute combination: %s", + exc_info.value, + ) + assert "ORA-20047" in str(exc_info.value) + attrs = await self.async_vector_index.get_attributes() + logger.info("Attributes after invalid update: %s", attrs.__dict__) + logger.info("Invalid update combination rejected as expected.") + + async def test_5223(self): + """Update location (should raise ORA-20047).""" + logger.info("Testing update of location (expected ORA-20047)...") + with pytest.raises(oracledb.DatabaseError) as exc_info: + await self.async_vector_index.set_attribute( + "location", self.embedding_location + ) + assert "ORA-20047" in str(exc_info.value) + attrs = await self.async_vector_index.get_attributes() + assert attrs.location == self.embedding_location + logger.info("Location update prevented successfully.") + + async def test_5224(self): + """Update using profile object directly.""" + logger.info( + "Testing update of vector index using profile object directly..." + ) + temp_profile_name = "vector_ai_profile_temp" + temp_profile = await self.create_profile(profile_name=temp_profile_name) + logger.info("Created temporary profile: %s", temp_profile_name) + try: + await self.async_vector_index.set_attribute("profile", temp_profile) + except oracledb.NotSupportedError as exc: + logger.info("Expected NotSupportedError caught: %s", exc) + except Exception as exc: + raise AssertionError(f"Unexpected exception: {exc}") + attrs = await self.async_vector_index.get_attributes() + assert attrs.profile_name in [ + self.profile.profile_name, + temp_profile_name, + ] + logger.info( + "Attributes after attempted profile object update: %s", + attrs.__dict__, + ) + try: + await self.delete_profile(temp_profile) + logger.info( + "Temporary profile '%s' deleted successfully.", + temp_profile_name, + ) + except Exception as exc: + logger.warning("Profile cleanup failed: %s", exc) + + async def test_5225(self): + """Update with invalid attribute name.""" + logger.info("Testing update with invalid attribute name...") + with pytest.raises(oracledb.DatabaseError): + await self.async_vector_index.set_attribute("invalid_attr", "value") + logger.info("Invalid attribute name correctly raised DatabaseError.") + + async def test_5226(self): + """Update with invalid type for integer field.""" + logger.info("Testing update with invalid type for integer field...") + with pytest.raises(oracledb.DatabaseError): + await self.async_vector_index.set_attribute( + "chunk_size", "not_an_int" + ) + logger.info("Invalid integer type correctly raised DatabaseError.") + + async def test_5227(self): + """Update with invalid type for float field.""" + logger.info("Testing update with invalid type for float field...") + with pytest.raises(oracledb.DatabaseError): + await self.async_vector_index.set_attribute( + "similarity_threshold", "NaN" + ) + logger.info("Invalid float type correctly raised DatabaseError.") + + async def test_5228(self): + """Update with invalid enum value for vector_distance_metric.""" + logger.info( + "Testing update with invalid enum value " + "for vector_distance_metric..." + ) + with pytest.raises(oracledb.DatabaseError): + await self.async_vector_index.set_attribute( + "vector_distance_metric", "INVALID" + ) + logger.info("Invalid enum value correctly raised DatabaseError.") + + async def test_5229(self): + """Update on nonexistent vector index.""" + logger.info("Testing update on nonexistent vector index...") + temp_index = AsyncVectorIndex(index_name="does_not_exist") + with pytest.raises(AttributeError): + await temp_index.set_attribute("chunk_size", 512) + logger.info( + "Nonexistent index update correctly raised AttributeError." + ) + + async def test_5230(self): + """Update with None as attribute name (should fail).""" + logger.info("Testing update with None as attribute name...") + with pytest.raises(TypeError): + await self.async_vector_index.set_attribute(None, 128) + logger.info("None attribute name correctly raised TypeError.") + + async def test_5231(self): + """Update with None as attribute name for second time.""" + logger.info( + "Testing update with None as attribute name for second time..." + ) + with pytest.raises(TypeError): + await self.async_vector_index.set_attribute(None, 128) + logger.info("None attribute name correctly raised TypeError.") + + async def test_5232(self): + """Update with invalid attributes object (non-object input).""" + logger.info( + "Testing update with invalid attributes object " + "(non-object input)..." + ) + with pytest.raises(AttributeError): + await self.async_vector_index.set_attributes( + attributes="not_an_object" + ) + logger.info( + "Invalid attributes object correctly raised AttributeError." + ) + + async def test_5233(self): + """Update after disconnecting from the database.""" + logger.info("Testing update after disconnecting from the database...") + await select_ai.async_disconnect() + with pytest.raises(DatabaseNotConnectedError): + await self.async_vector_index.set_attribute("chunk_size", 256) + logger.info( + "DatabaseNotConnectedError correctly raised after disconnect." + ) + logger.info("Reconnecting for further tests...") + await select_ai.async_connect(**self.reconnect_params) + assert await select_ai.async_is_connected(), ( + "Connection to DB failed" + ) + logger.info("Reconnection successful.") + + async def test_5234(self): + """Update with None as attribute value (should fail).""" + logger.info("Testing update with None as attribute value...") + with pytest.raises(oracledb.DatabaseError): + await self.async_vector_index.set_attribute("chunk_size", None) + logger.info("None value correctly raised DatabaseError.") + + async def test_5235(self): + """Concurrent updates on the same vector index.""" + logger.info("Testing concurrent updates on the same vector index...") + index1 = await AsyncVectorIndex.fetch(self.index_name) + index2 = await AsyncVectorIndex.fetch(self.index_name) + await index1.set_attribute("match_limit", 15) + await index2.set_attribute("match_limit", 20) + attrs = await self.async_vector_index.get_attributes() + logger.info( + "Final match_limit value after concurrent updates: %s", + attrs.match_limit, + ) + assert attrs.match_limit in [15, 20] + logger.info( + "Concurrent update behavior verified (last writer wins)." + ) + + async def test_5236(self): + """Update with excessively large attribute value.""" + logger.info("Testing update with excessively large attribute value...") + long_name = "X" * 500 + with pytest.raises(oracledb.DatabaseError) as exc_info: + await self.async_vector_index.set_attribute("profile_name", long_name) + assert "ORA-20048" in str(exc_info.value) + logger.info("Large attribute value correctly raised DatabaseError.") + + async def test_5237(self): + """Repeated updates to match_limit (last writer wins).""" + logger.info("Testing repeated updates to match_limit...") + for i in range(5): + await self.async_vector_index.set_attribute("match_limit", i * 10) + logger.info("Set match_limit to %s", i * 10) + attrs = await self.async_vector_index.get_attributes() + assert attrs.match_limit == 40 + logger.info("Repeated update test passed; last value retained.") + + async def test_5238(self): + """Update attribute after delete and recreate of vector index.""" + logger.info( + "Testing attribute update after deleting and recreating the " + "vector index..." + ) + await self.async_vector_index.delete(force=True) + logger.info("Vector index deleted.") + self.async_vector_index = AsyncVectorIndex( + index_name=self.index_name, + attributes=self.vector_index_attributes, + description="Test vector index", + profile=self.profile, + ) + await self.async_vector_index.create(replace=True) + logger.info("Vector index recreated.") + await self.async_vector_index.set_attribute("match_limit", 10) + attrs = await self.async_vector_index.get_attributes() + assert attrs.match_limit == 10 + logger.info("Update after recreation verified successfully.") diff --git a/tests/vector_index/test_5200_setindex_attributes.py b/tests/vector_index/test_5200_setindex_attributes.py new file mode 100644 index 0000000..e87ddac --- /dev/null +++ b/tests/vector_index/test_5200_setindex_attributes.py @@ -0,0 +1,801 @@ +# ----------------------------------------------------------------------------- +# Copyright (c) 2025, Oracle and/or its affiliates. +# +# Licensed under the Universal Permissive License v 1.0 as shown at +# http://oss.oracle.com/licenses/upl. +# ----------------------------------------------------------------------------- + + +import logging +import os +import oracledb +import pytest +import select_ai +from select_ai import VectorIndex, VectorIndexAttributes, OracleVectorIndexAttributes +from select_ai.errors import DatabaseNotConnectedError + +logger = logging.getLogger("TestSetVectorIndexAttributes") + + +@pytest.fixture(scope="class", autouse=True) +def setup_logging(): + logging.basicConfig( + format="%(asctime)s %(levelname)s %(name)s %(message)s", + level=logging.INFO + ) + + +@pytest.fixture(scope="class") +def set_vec_params( + request, + vcidx_params, +): + request.cls.set_vec_params = vcidx_params + + +@pytest.fixture(scope="class", autouse=True) +def setup_and_teardown(request, connect, set_vec_params, oci_credential): + logger.info("=== Setting up TestSetVectorIndexAttributes class ===") + p = request.cls.set_vec_params + + # 'connect' fixture from base tests/conftest.py ensures DB connection exists. + # Do NOT disconnect here; let the session fixture own lifecycle. + assert select_ai.is_connected(), "Connection to DB failed" + + request.cls.user = p["user"] + request.cls.password = p["password"] + request.cls.dsn = p["dsn"] + request.cls.user_ocid = p["user_ocid"] + request.cls.tenancy_ocid = p["tenancy_ocid"] + request.cls.private_key = p["private_key"] + request.cls.fingerprint = p["fingerprint"] + request.cls.cred_username = p["cred_username"] + request.cls.cred_password = p["cred_password"] + request.cls.oci_compartment_id = p["oci_compartment_id"] + request.cls.embedding_location = p["embedding_location"] + request.cls.objstore_cred = "OBJSTORE_CRED" + request.cls.shared_oci_credential = oci_credential + + logger.info("Fetching credential secrets and OCI configuration...") + request.cls.create_credential() + request.cls.profile = request.cls.create_profile() + logger.info("Profile 'vector_ai_profile' created successfully.") + + request.cls.index_name = "test_vector_index_attr" + vi_attrs = OracleVectorIndexAttributes( + location=p["embedding_location"], + object_storage_credential_name="OBJSTORE_CRED" + ) + request.cls.vector_index_attributes = vi_attrs + + vi = VectorIndex( + index_name=request.cls.index_name, + attributes=vi_attrs, + description="Test vector index", + profile=request.cls.profile + ) + vi.create(replace=True) + + # Keep original validation intent (handle either classmethod or instance list()) + try: + created_indexes = [idx.index_name for idx in VectorIndex.list()] + except Exception: + created_indexes = [idx.index_name for idx in VectorIndex().list()] + assert request.cls.index_name.upper() in created_indexes, f"VectorIndex {request.cls.index_name} was not created" + + yield + + logger.info("=== Tearing down TestSetVectorIndexAttributes class ===") + try: + vector_index = VectorIndex(index_name=request.cls.index_name) + vector_index.delete(force=True) + except Exception as e: + logger.warning(f"Warning: drop vector index failed: {e}") + + try: + request.cls.profile.delete() + except Exception as e: + logger.warning(f"profile.delete() raised {e} unexpectedly.") + + request.cls.delete_credential() + request.cls.ensure_session_oci_credential() + logger.info("Teardown complete.\n") + + +@pytest.fixture(autouse=True) +def log_test_name(request): + logger.info(f"--- Starting test: {request.function.__name__} ---") + yield + logger.info(f"--- Finished test: {request.function.__name__} ---") + + +@pytest.mark.usefixtures("set_vec_params", "setup_and_teardown") +class TestSetVectorIndexAttributes: + @classmethod + def get_native_cred_param(cls, cred_name=None) -> dict: + logger.info(f"Preparing native credential params for: {cred_name}") + p = cls.set_vec_params + return dict( + credential_name=cred_name, + user_ocid=p["user_ocid"], + tenancy_ocid=p["tenancy_ocid"], + private_key=p["private_key"], + fingerprint=p["fingerprint"] + ) + + @classmethod + def get_cred_param(cls, cred_name=None) -> dict: + logger.info(f"Preparing basic credential params for: {cred_name}") + p = cls.set_vec_params + return dict( + credential_name=cred_name, + username=p["cred_username"], + password=p["cred_password"] + ) + + @classmethod + def create_credential(cls, genai_cred="GENAI_CRED", objstore_cred="OBJSTORE_CRED"): + logger.info(f"Creating credentials: {genai_cred}, {objstore_cred}") + + genai_credential = cls.get_native_cred_param(genai_cred) + select_ai.create_credential(credential=genai_credential, replace=True) + + # Only create OBJSTORE_CRED if creds are provided in env + p = cls.set_vec_params + if p.get("cred_username") and p.get("cred_password"): + objstore_credential = cls.get_cred_param(objstore_cred) + select_ai.create_credential(credential=objstore_credential, replace=True) + logger.info("Credentials created.") + else: + logger.info("Skipping ObjectStore credential creation (CRED_USERNAME/CRED_PASSWORD not set).") + + @classmethod + def create_profile(cls, profile_name="vector_ai_profile"): + p = cls.set_vec_params + return select_ai.Profile( + profile_name=profile_name, + attributes=select_ai.ProfileAttributes( + credential_name="GENAI_CRED", + provider=select_ai.OCIGenAIProvider( + oci_compartment_id=p["oci_compartment_id"], + oci_apiformat="GENERIC", + embedding_model="cohere.embed-english-v3.0", + ) + ), + description="OCI GENAI Profile", + replace=True + ) + + @classmethod + def delete_profile(cls, profile): + return profile.delete() + + @classmethod + def delete_credential(cls): + try: + select_ai.delete_credential("GENAI_CRED", force=True) + except Exception as e: + logger.warning(f"delete_credential() raised {e} unexpectedly.") + try: + select_ai.delete_credential("OBJSTORE_CRED", force=True) + except Exception as e: + logger.warning(f"delete_credential() raised {e} unexpectedly.") + + @classmethod + def ensure_session_oci_credential(cls): + credential_name = cls.shared_oci_credential["credential_name"] + select_ai.create_credential( + credential={ + "credential_name": credential_name, + "user_ocid": cls.user_ocid, + "tenancy_ocid": cls.tenancy_ocid, + "private_key": cls.private_key, + "fingerprint": cls.fingerprint, + }, + replace=True, + ) + logger.info( + "Recreated shared OCI credential for session teardown: %s", + credential_name, + ) + + def setup_method(self, method): + logger.info(f"SetUp for {method.__name__}") + vecidx = VectorIndex() + indexes = list(vecidx.list(index_name_pattern=self.index_name)) + if not indexes: + logger.info( + "Vector index %s missing; recreating baseline test state.", + self.index_name, + ) + self.create_credential() + self.profile = self.create_profile(profile_name=self.profile.profile_name) + VectorIndex( + index_name=self.index_name, + attributes=self.vector_index_attributes, + description="Test vector index", + profile=self.profile, + ).create(replace=True) + indexes = list(vecidx.list(index_name_pattern=self.index_name)) + self.vector_index = indexes[0] + + def teardown_method(self, method): + logger.info(f"TearDown for {method.__name__}") + + def test_5201(self): + """Update 'match_limit' attribute.""" + logger.info("Testing update of 'match_limit' attribute...") + self.vector_index.set_attribute("match_limit", 10) + attrs = self.vector_index.get_attributes() + logger.info(f"Updated match_limit: {attrs.match_limit}") + assert attrs.match_limit == 10 + logger.info("Match limit update verified successfully.") + + def test_5202(self): + """Update 'similarity_threshold' attribute.""" + logger.info("Testing update of 'similarity_threshold' attribute...") + self.vector_index.set_attribute("similarity_threshold", 0.8) + attrs = self.vector_index.get_attributes() + logger.info(f"Updated similarity_threshold: {attrs.similarity_threshold}") + assert attrs.similarity_threshold == 0.8 + logger.info("Similarity threshold update verified successfully.") + + def test_5203(self): + """Update multiple attributes with VectorIndexAttributes object.""" + logger.info( + "Testing update of multiple attributes via " + "VectorIndexAttributes object..." + ) + update_attrs = VectorIndexAttributes( + match_limit=5, + similarity_threshold=0.5, + location=self.embedding_location, + refresh_rate=40, + ) + with pytest.raises(oracledb.DatabaseError) as exc_info: + self.vector_index.set_attributes(attributes=update_attrs) + logger.info( + "Expected DatabaseError raised for restricted attribute update: %s", + exc_info.value, + ) + assert "ORA-20047" in str(exc_info.value) + logger.info("Restricted multi-attribute update rejected as expected.") + + def test_5204(self): + """Repeated update of the same attribute 'similarity_threshold'.""" + logger.info("Testing repeated update of the same attribute 'similarity_threshold'...") + self.vector_index.set_attribute("similarity_threshold", 0.8) + self.vector_index.set_attribute("similarity_threshold", 0.5) + attrs = self.vector_index.get_attributes() + logger.info(f"Final similarity_threshold value: {attrs.similarity_threshold}") + assert attrs.similarity_threshold == 0.5 + logger.info("Repeated attribute update verified successfully.") + + def test_5205(self): + """Update 'match_limit' with maximum allowed value.""" + logger.info("Testing update of 'match_limit' with maximum allowed value...") + max_limit = 8192 + self.vector_index.set_attribute("match_limit", max_limit) + attrs = self.vector_index.get_attributes() + logger.info(f"Set match_limit to: {attrs.match_limit}") + assert attrs.match_limit == max_limit + logger.info("Max value for match_limit verified successfully.") + + def test_5206(self): + """Update match_limit with minimum value.""" + logger.info("Testing update of match_limit with minimum value...") + min_limit = 1 + self.vector_index.set_attribute("match_limit", min_limit) + logger.info(f"Set match_limit to {min_limit}, fetching attributes for verification...") + attrs = self.vector_index.get_attributes() + assert attrs.match_limit == min_limit + logger.info("match_limit minimum value update verified successfully.") + + def test_5207(self): + """Update match_limit with zero value.""" + logger.info("Testing update of match_limit with zero value...") + min_limit = 0 + self.vector_index.set_attribute("match_limit", min_limit) + logger.info("Fetching attributes to verify zero value update...") + attrs = self.vector_index.get_attributes() + assert attrs.match_limit == min_limit + logger.info("match_limit zero value update verified successfully.") + + def test_5208(self): + """Update profile_name with temporary profile.""" + temp_profile_name = "vector_ai_profile_temp" + temp_profile = self.create_profile(profile_name=temp_profile_name) + logger.info(f"Temporary profile created: {temp_profile_name}") + self.vector_index.set_attribute("profile_name", temp_profile_name) + logger.info(f"Set profile_name to {temp_profile_name}, fetching attributes...") + attrs = self.vector_index.get_attributes() + logger.info(f"VectorIndex attributes after profile update: {attrs.__dict__}") + assert attrs.profile_name == temp_profile_name + vecidx = VectorIndex() + vec_index = (list(vecidx.list(index_name_pattern=self.index_name)))[0] + logger.info(f"Persisted VectorIndex after profile update: {vec_index.__dict__}") + assert attrs.profile_name == vec_index.profile.profile_name + logger.info("Persisted VectorIndex reflects updated profile correctly.") + self.delete_profile(temp_profile) + logger.info(f"Temporary profile deleted: {temp_profile_name}") + self.vector_index.set_attribute("profile_name", self.profile.profile_name) + attrs = self.vector_index.get_attributes() + logger.info("VectorIndex reset to default profile, verifying...") + assert attrs.profile_name == self.profile.profile_name + logger.info("VectorIndex profile reset verified successfully.") + + def test_5209(self): + """Update profile_name and then delete profile scenario.""" + logger.info( + "Testing update of profile_name followed by delete scenario..." + ) + temp_profile_name = "vector_ai_profile_temp" + temp_profile = self.create_profile(profile_name=temp_profile_name) + logger.info(f"Temporary profile created: {temp_profile_name}") + self.vector_index.set_attribute("profile_name", temp_profile_name) + logger.info(f"Set profile_name to {temp_profile_name}, verifying update...") + attrs = self.vector_index.get_attributes() + assert attrs.profile_name == temp_profile_name + vecidx = VectorIndex() + vec_index = (list(vecidx.list(index_name_pattern=self.index_name)))[0] + logger.info(f"Persisted VectorIndex after profile update: {vec_index.__dict__}") + assert attrs.profile_name == vec_index.profile.profile_name + self.delete_profile(temp_profile) + logger.info(f"Temporary profile deleted: {temp_profile_name}") + logger.info( + "Verifying VectorIndex retains deleted profile name reference..." + ) + vecidx = VectorIndex() + vec_index = (list(vecidx.list(index_name_pattern=self.index_name)))[0] + attrs = vec_index.get_attributes() + assert attrs.profile_name == temp_profile_name + logger.info( + "VectorIndex still references deleted profile name as expected." + ) + self.vector_index.set_attribute("profile_name", self.profile.profile_name) + attrs = self.vector_index.get_attributes() + logger.info(f"Reset VectorIndex profile to default: {attrs.__dict__}") + assert attrs.profile_name == self.profile.profile_name + logger.info("Profile reset after delete verified successfully.") + + def test_5210(self): + """Deleted profile leaves a stale profile_name on VectorIndex.""" + logger.info("Testing deleted profile behavior via VectorIndex fetch...") + temp_profile_name = "vector_ai_profile_temp" + temp_profile = self.create_profile(profile_name=temp_profile_name) + self.vector_index.set_attribute("profile_name", temp_profile_name) + self.delete_profile(temp_profile) + vec_index = VectorIndex.fetch(self.index_name) + attrs = vec_index.get_attributes() + logger.info( + "Fetched VectorIndex after profile delete: profile=%s attrs=%s", + vec_index.profile, + attrs.__dict__, + ) + assert vec_index.profile is None + assert attrs.profile_name == temp_profile_name + self.vector_index.set_attribute("profile_name", self.profile.profile_name) + logger.info("Deleted profile behavior verified successfully.") + + def test_5211(self): + """Update refresh_rate attribute.""" + logger.info("Testing update of refresh_rate attribute...") + self.vector_index.set_attribute("refresh_rate", 30) + attrs = self.vector_index.get_attributes() + assert attrs.refresh_rate == 30 + + def test_5212(self): + """Update object_storage_credential_name, handle pipeline.""" + logger.info("Testing update of object_storage_credential_name with pipeline handling...") + attrs = self.vector_index.get_attributes() + pipeline_name = attrs.pipeline_name + logger.info(f"Retrieved pipeline name: {pipeline_name}") + logger.info(f"Stopping pipeline: {pipeline_name}") + with select_ai.cursor() as cursor: + cursor.execute("BEGIN dbms_cloud_pipeline.stop_pipeline(pipeline_name => :1); END;", [pipeline_name]) + logger.info(f"Pipeline '{pipeline_name}' stopped successfully.") + objstore_credential = self.get_cred_param("TEMP_OBJSTORE_CRED") + logger.info("Creating temporary Object Store credential: TEMP_OBJSTORE_CRED") + try: + select_ai.create_credential(credential=objstore_credential, replace=True) + logger.info("TEMP_OBJSTORE_CRED created successfully.") + except Exception as e: + raise AssertionError(f"create_credential() raised an unexpected exception: {e}") + logger.info("Updating vector index with TEMP_OBJSTORE_CRED...") + self.vector_index.set_attribute("object_storage_credential_name", "TEMP_OBJSTORE_CRED") + attrs = self.vector_index.get_attributes() + logger.info(f"Updated credential: {attrs.object_storage_credential_name}") + assert attrs.object_storage_credential_name == "TEMP_OBJSTORE_CRED" + logger.info("Deleting temporary credential: TEMP_OBJSTORE_CRED") + try: + select_ai.delete_credential("TEMP_OBJSTORE_CRED", force=True) + logger.info("TEMP_OBJSTORE_CRED deleted successfully.") + except Exception as e: + pytest.fail(f"delete_credential() raised unexpected exception: {e}") + logger.info("Restoring original Object Store credential: OBJSTORE_CRED") + self.vector_index.set_attribute("object_storage_credential_name", "OBJSTORE_CRED") + logger.info(f"Restarting pipeline: {pipeline_name}") + with select_ai.cursor() as cursor: + cursor.execute("BEGIN dbms_cloud_pipeline.start_pipeline(pipeline_name => :1); END;", [pipeline_name]) + logger.info(f"Pipeline '{pipeline_name}' restarted successfully.") + attrs = self.vector_index.get_attributes() + assert attrs.object_storage_credential_name == "OBJSTORE_CRED" + logger.info("Object Store credential restored successfully.") + + def test_5213(self): + """Update object_storage_credential_name with delete handling.""" + logger.info("Testing update of object_storage_credential_name with delete handling...") + attrs = self.vector_index.get_attributes() + pipeline_name = attrs.pipeline_name + logger.info(f"Retrieved pipeline name: {pipeline_name}") + logger.info(f"Stopping pipeline: {pipeline_name}") + with select_ai.cursor() as cursor: + cursor.execute("BEGIN dbms_cloud_pipeline.stop_pipeline(pipeline_name => :1); END;", [pipeline_name]) + logger.info(f"Pipeline '{pipeline_name}' stopped successfully.") + objstore_credential = self.get_cred_param("TEMP_OBJSTORE_CRED") + logger.info("Creating temporary Object Store credential: TEMP_OBJSTORE_CRED") + try: + select_ai.create_credential(credential=objstore_credential, replace=True) + logger.info("TEMP_OBJSTORE_CRED created successfully.") + except Exception as e: + raise AssertionError(f"create_credential() raised an unexpected exception: {e}") + logger.info("Updating vector index with TEMP_OBJSTORE_CRED...") + self.vector_index.set_attribute("object_storage_credential_name", "TEMP_OBJSTORE_CRED") + attrs = self.vector_index.get_attributes() + assert attrs.object_storage_credential_name == "TEMP_OBJSTORE_CRED" + logger.info(f"Credential updated to: {attrs.object_storage_credential_name}") + logger.info("Deleting temporary credential: TEMP_OBJSTORE_CRED") + try: + select_ai.delete_credential("TEMP_OBJSTORE_CRED", force=True) + logger.info("TEMP_OBJSTORE_CRED deleted successfully.") + except Exception as e: + pytest.fail(f"delete_credential() raised unexpected exception: {e}") + logger.info( + "Verifying that VectorIndex retains deleted credential reference..." + ) + vecidx = VectorIndex() + vec_index = (list(vecidx.list(index_name_pattern=self.index_name)))[0] + attrs = vec_index.get_attributes() + assert attrs.object_storage_credential_name == "TEMP_OBJSTORE_CRED" + logger.info( + "VectorIndex still references deleted credential name as expected." + ) + logger.info("Restoring original Object Store credential: OBJSTORE_CRED") + self.get_cred_param("OBJSTORE_CRED") + select_ai.create_credential( + credential=self.get_cred_param("OBJSTORE_CRED"), + replace=True, + ) + self.vector_index.set_attribute("object_storage_credential_name", "OBJSTORE_CRED") + logger.info(f"Restarting pipeline: {pipeline_name}") + with select_ai.cursor() as cursor: + cursor.execute("BEGIN dbms_cloud_pipeline.start_pipeline(pipeline_name => :1); END;", [pipeline_name]) + logger.info(f"Pipeline '{pipeline_name}' restarted successfully.") + attrs = self.vector_index.get_attributes() + assert attrs.object_storage_credential_name == "OBJSTORE_CRED" + logger.info("Object Store credential restoration after delete verified successfully.") + + def test_5214(self): + """Deleted credential leaves VectorIndex unusable until restored.""" + logger.info("Testing missing credential behavior via VectorIndex create...") + temp_credential_name = "TEMP_OBJSTORE_CRED" + attrs = self.vector_index.get_attributes() + pipeline_name = attrs.pipeline_name + logger.info("Stopping pipeline: %s", pipeline_name) + with select_ai.cursor() as cursor: + cursor.execute( + "BEGIN dbms_cloud_pipeline.stop_pipeline(pipeline_name => :1); END;", + [pipeline_name], + ) + select_ai.create_credential( + credential=self.get_cred_param(temp_credential_name), + replace=True, + ) + self.vector_index.set_attribute( + "object_storage_credential_name", + temp_credential_name, + ) + select_ai.delete_credential(temp_credential_name, force=True) + failing_index = VectorIndex( + index_name="test_vector_index_attr_missing_cred", + attributes=OracleVectorIndexAttributes( + location=self.embedding_location, + object_storage_credential_name=temp_credential_name, + ), + description="Missing credential test vector index", + profile=self.profile, + ) + with pytest.raises(oracledb.DatabaseError) as exc_info: + failing_index.create(replace=True) + logger.info( + "Expected DatabaseError raised for deleted credential reference: %s", + exc_info.value, + ) + assert "ORA-20004" in str(exc_info.value) + assert temp_credential_name in str(exc_info.value) + select_ai.create_credential( + credential=self.get_cred_param("OBJSTORE_CRED"), + replace=True, + ) + self.vector_index.set_attribute( + "object_storage_credential_name", + "OBJSTORE_CRED", + ) + with select_ai.cursor() as cursor: + cursor.execute( + "BEGIN dbms_cloud_pipeline.start_pipeline(pipeline_name => :1); END;", + [pipeline_name], + ) + logger.info("Deleted credential behavior verified successfully.") + + def test_5215(self): + """Update multiple attributes together.""" + logger.info("Testing update of multiple attributes together...") + updates = { + "refresh_rate": 50, + "similarity_threshold": 0.8, + "match_limit": 10 + } + for field, value in updates.items(): + logger.info(f"Updating {field} to {value}...") + self.vector_index.set_attribute(field, value) + attrs = self.vector_index.get_attributes() + logger.info(f"Fetched attributes after updates: {attrs.__dict__}") + assert attrs.refresh_rate == updates["refresh_rate"] + assert attrs.similarity_threshold == updates["similarity_threshold"] + assert attrs.match_limit == updates["match_limit"] + logger.info("All multiple attribute updates verified successfully.") + + def test_5216(self): + """Update description (should raise DatabaseError).""" + logger.info("Testing update of description attribute (should raise DatabaseError)...") + with pytest.raises(oracledb.DatabaseError) as exc_info: + self.vector_index.set_attribute("description", "updated description") + assert "ORA-20048" in str(exc_info.value) + logger.info("DatabaseError correctly raised for invalid description update.") + + def test_5217(self): + """Update pipeline_name (should raise DatabaseError).""" + logger.info("Testing update of pipeline_name (expected DatabaseError)...") + with pytest.raises(oracledb.DatabaseError) as exc_info: + self.vector_index.set_attribute("pipeline_name", "test_pipeline") + assert "ORA-20048" in str(exc_info.value) + attrs = self.vector_index.get_attributes() + assert attrs.pipeline_name == "TEST_VECTOR_INDEX_ATTR$VECPIPELINE" + logger.info("Pipeline update correctly raised error and original value retained.") + + def test_5218(self): + """Update chunk_size (should fail).""" + logger.info("Testing update of chunk_size (should fail with ORA-20047)...") + attrs = self.vector_index.get_attributes() + original_chunk_size = attrs.chunk_size + logger.info(f"Current attributes: {attrs.__dict__}") + with pytest.raises(oracledb.DatabaseError) as exc_info: + self.vector_index.set_attribute("chunk_size", 2048) + assert "ORA-20047" in str(exc_info.value) + attrs = self.vector_index.get_attributes() + assert attrs.chunk_size == original_chunk_size + logger.info("chunk_size update prevented successfully; original value verified.") + + def test_5219(self): + """Update chunk_overlap (should fail).""" + logger.info("Testing update of chunk_overlap (should fail with ORA-20047)...") + original_chunk_overlap = self.vector_index.get_attributes().chunk_overlap + with pytest.raises(oracledb.DatabaseError) as exc_info: + self.vector_index.set_attribute("chunk_overlap", 256) + assert "ORA-20047" in str(exc_info.value) + attrs = self.vector_index.get_attributes() + assert attrs.chunk_overlap == original_chunk_overlap + logger.info("chunk_overlap update prevented successfully; original value verified.") + + def test_5220(self): + """Update vector_distance_metric (should fail).""" + logger.info("Testing update of vector_distance_metric (should fail with ORA-20047)...") + original_distance_metric = ( + self.vector_index.get_attributes().vector_distance_metric + ) + with pytest.raises(oracledb.DatabaseError) as exc_info: + self.vector_index.set_attribute("vector_distance_metric", "EUCLIDEAN") + assert "ORA-20047" in str(exc_info.value) + attrs = self.vector_index.get_attributes() + assert attrs.vector_distance_metric == original_distance_metric + logger.info("vector_distance_metric update prevented successfully.") + + def test_5221(self): + """Partial update with VectorIndexAttributes object.""" + logger.info("Testing partial update with VectorIndexAttributes object...") + update_attrs = VectorIndexAttributes(match_limit=20, chunk_size=2048) + with pytest.raises(oracledb.DatabaseError) as exc_info: + self.vector_index.set_attributes(attributes=update_attrs) + logger.info( + "Expected DatabaseError raised for partial restricted update: %s", + exc_info.value, + ) + assert "ORA-20047" in str(exc_info.value) + attrs = self.vector_index.get_attributes() + logger.info(f"Attributes after update attempt: {attrs.__dict__}") + logger.info("Partial restricted update rejected as expected.") + + def test_5222(self): + """Update with invalid attribute combinations.""" + logger.info("Testing update with invalid attribute combinations...") + update_attrs = VectorIndexAttributes(chunk_size=2048, chunk_overlap=256) + with pytest.raises(oracledb.DatabaseError) as exc_info: + self.vector_index.set_attributes(attributes=update_attrs) + logger.info( + "Expected DatabaseError raised for invalid attribute combination: %s", + exc_info.value, + ) + assert "ORA-20047" in str(exc_info.value) + attrs = self.vector_index.get_attributes() + logger.info(f"Attributes after invalid update: {attrs.__dict__}") + logger.info("Invalid update combination rejected as expected.") + + def test_5223(self): + """Update location (should raise ORA-20047).""" + logger.info("Testing update of location (expected ORA-20047)...") + with pytest.raises(oracledb.DatabaseError) as exc_info: + self.vector_index.set_attribute("location", self.embedding_location) + assert "ORA-20047" in str(exc_info.value) + attrs = self.vector_index.get_attributes() + assert attrs.location == self.embedding_location + logger.info("Location update prevented successfully.") + + def test_5224(self): + """Update using profile object directly.""" + logger.info("Testing update of vector index using profile object directly...") + temp_profile_name = "vector_ai_profile_temp" + temp_profile = self.create_profile(profile_name=temp_profile_name) + logger.info(f"Created temporary profile: {temp_profile_name}") + try: + self.vector_index.set_attribute("profile", temp_profile) + except oracledb.NotSupportedError as e: + logger.info(f"Expected NotSupportedError caught: {e}") + except Exception as e: + raise AssertionError(f"Unexpected exception: {e}") + attrs = self.vector_index.get_attributes() + assert attrs.profile_name in [self.profile.profile_name, temp_profile_name] + logger.info(f"Attributes after attempted profile object update: {attrs.__dict__}") + try: + self.delete_profile(temp_profile) + logger.info(f"Temporary profile '{temp_profile_name}' deleted successfully.") + except Exception as e: + logger.warning(f"Profile cleanup failed: {e}") + + def test_5225(self): + """Update with invalid attribute name.""" + logger.info("Testing update with invalid attribute name...") + with pytest.raises(oracledb.DatabaseError): + self.vector_index.set_attribute("invalid_attr", "value") + logger.info("Invalid attribute name correctly raised DatabaseError.") + + def test_5226(self): + """Update with invalid type for integer field.""" + logger.info("Testing update with invalid type for integer field...") + with pytest.raises(oracledb.DatabaseError): + self.vector_index.set_attribute("chunk_size", "not_an_int") + logger.info("Invalid integer type correctly raised DatabaseError.") + + def test_5227(self): + """Update with invalid type for float field.""" + logger.info("Testing update with invalid type for float field...") + with pytest.raises(oracledb.DatabaseError): + self.vector_index.set_attribute("similarity_threshold", "NaN") + logger.info("Invalid float type correctly raised DatabaseError.") + + def test_5228(self): + """Update with invalid enum value for vector_distance_metric.""" + logger.info("Testing update with invalid enum value for vector_distance_metric...") + with pytest.raises(oracledb.DatabaseError): + self.vector_index.set_attribute("vector_distance_metric", "INVALID") + logger.info("Invalid enum value correctly raised DatabaseError.") + + def test_5229(self): + """Update on nonexistent vector index.""" + logger.info("Testing update on nonexistent vector index...") + temp_index = VectorIndex(index_name="does_not_exist") + with pytest.raises(AttributeError): + temp_index.set_attribute("chunk_size", 512) + logger.info("Nonexistent index update correctly raised AttributeError.") + + def test_5230(self): + """Update with None as attribute name (should fail).""" + logger.info("Testing update with None as attribute name...") + with pytest.raises(TypeError): + self.vector_index.set_attribute(None, 128) + logger.info("None attribute name correctly raised TypeError.") + + def test_5231(self): + """Update with None as attribute name for second time.""" + logger.info("Testing update with None as attribute name for second time...") + with pytest.raises(TypeError): + self.vector_index.set_attribute(None, 128) + logger.info("None attribute name correctly raised TypeError.") + + def test_5232(self): + """Update with invalid attributes object (non-object input).""" + logger.info("Testing update with invalid attributes object (non-object input)...") + with pytest.raises(AttributeError): + self.vector_index.set_attributes(attributes="not_an_object") + logger.info("Invalid attributes object correctly raised AttributeError.") + + def test_5233(self): + """Update after disconnecting from the database.""" + logger.info("Testing update after disconnecting from the database...") + select_ai.disconnect() + with pytest.raises(DatabaseNotConnectedError): + self.vector_index.set_attribute("chunk_size", 256) + logger.info("DatabaseNotConnectedError correctly raised after disconnect.") + logger.info("Reconnecting for further tests...") + connect_kwargs = { + "user": self.user, + "password": self.password, + "dsn": self.dsn, + } + wallet_location = os.environ.get("PYSAI_TEST_WALLET_LOCATION") + wallet_password = os.environ.get("PYSAI_TEST_WALLET_PASSWORD") + if wallet_location: + connect_kwargs["config_dir"] = wallet_location + connect_kwargs["wallet_location"] = wallet_location + if wallet_password: + connect_kwargs["wallet_password"] = wallet_password + select_ai.connect(**connect_kwargs) + assert select_ai.is_connected(), "Connection to DB failed" + logger.info("Reconnection successful.") + + def test_5234(self): + """Update with None as attribute value (should fail).""" + logger.info("Testing update with None as attribute value...") + with pytest.raises(oracledb.DatabaseError): + self.vector_index.set_attribute("chunk_size", None) + logger.info("None value correctly raised DatabaseError.") + + def test_5235(self): + """Concurrent updates on the same vector index.""" + logger.info("Testing concurrent updates on the same vector index...") + vecidx = VectorIndex() + index1 = (list(vecidx.list(index_name_pattern=self.index_name)))[0] + index2 = (list(vecidx.list(index_name_pattern=self.index_name)))[0] + index1.set_attribute("match_limit", 15) + index2.set_attribute("match_limit", 20) + attrs = self.vector_index.get_attributes() + logger.info(f"Final match_limit value after concurrent updates: {attrs.match_limit}") + assert attrs.match_limit in [15, 20] + logger.info("Concurrent update behavior verified (last writer wins).") + + def test_5236(self): + """Update with excessively large attribute value.""" + logger.info("Testing update with excessively large attribute value...") + long_name = "X" * 500 + with pytest.raises(oracledb.DatabaseError) as exc_info: + self.vector_index.set_attribute("profile_name", long_name) + assert "ORA-20048" in str(exc_info.value) + logger.info("Large attribute value correctly raised DatabaseError.") + + def test_5237(self): + """Repeated updates to match_limit (last writer wins).""" + logger.info("Testing repeated updates to match_limit...") + for i in range(5): + self.vector_index.set_attribute("match_limit", i * 10) + logger.info(f"Set match_limit to {i * 10}") + attrs = self.vector_index.get_attributes() + assert attrs.match_limit == 40 + logger.info("Repeated update test passed; last value retained.") + + def test_5238(self): + """Update attribute after delete and recreate of vector index.""" + logger.info("Testing attribute update after deleting and recreating the vector index...") + self.vector_index.delete(force=True) + logger.info("Vector index deleted.") + self.vector_index = VectorIndex( + index_name=self.index_name, + attributes=self.vector_index_attributes, + description="Test vector index", + profile=self.profile, + ) + self.vector_index.create(replace=True) + logger.info("Vector index recreated.") + self.vector_index.set_attribute("match_limit", 10) + attrs = self.vector_index.get_attributes() + assert attrs.match_limit == 10 + logger.info("Update after recreation verified successfully.") diff --git a/tests/vector_index/test_5300_async_getindex_attributes.py b/tests/vector_index/test_5300_async_getindex_attributes.py new file mode 100644 index 0000000..91c74f7 --- /dev/null +++ b/tests/vector_index/test_5300_async_getindex_attributes.py @@ -0,0 +1,442 @@ +# ----------------------------------------------------------------------------- +# Copyright (c) 2025, Oracle and/or its affiliates. +# +# Licensed under the Universal Permissive License v 1.0 as shown at +# http://oss.oracle.com/licenses/upl. +# ----------------------------------------------------------------------------- + +import logging + +import pytest +import select_ai +from select_ai import AsyncVectorIndex, OracleVectorIndexAttributes +from select_ai.errors import VectorIndexNotFoundError + +logger = logging.getLogger("TestAsyncGetVectorIndexAttributes") + +pytestmark = pytest.mark.anyio + + +@pytest.fixture(scope="class") +def vector_attr_params(request, vcidx_params): + request.cls.vector_attr_params = vcidx_params + + +@pytest.fixture(scope="class", autouse=True) +async def setup_and_teardown( + request, + async_connect, + vector_attr_params, + test_env, +): + logger.info("=== Setting up TestAsyncGetVectorIndexAttributes class ===") + p = request.cls.vector_attr_params + + assert await select_ai.async_is_connected(), "Connection to DB failed" + + request.cls.user = p["user"] + request.cls.password = p["password"] + request.cls.dsn = p["dsn"] + request.cls.user_ocid = p["user_ocid"] + request.cls.tenancy_ocid = p["tenancy_ocid"] + request.cls.private_key = p["private_key"] + request.cls.fingerprint = p["fingerprint"] + request.cls.cred_username = p["cred_username"] + request.cls.cred_password = p["cred_password"] + request.cls.oci_compartment_id = p["oci_compartment_id"] + request.cls.embedding_location = p["embedding_location"] + request.cls.reconnect_params = test_env.connect_params() + + logger.info("Fetching credential secrets and OCI configuration...") + await request.cls.create_credential() + request.cls.profile = await request.cls.create_profile() + logger.info("Profile 'vector_ai_profile' created successfully.") + + request.cls.index_name = "test_vector_index_attr" + vi_attrs = OracleVectorIndexAttributes( + location=p["embedding_location"], + object_storage_credential_name="OBJSTORE_CRED", + ) + request.cls.vector_index_attributes = vi_attrs + + vi = AsyncVectorIndex( + index_name=request.cls.index_name, + attributes=vi_attrs, + description="Test vector index", + profile=request.cls.profile, + ) + await vi.create(replace=True) + + created_indexes = [ + idx.index_name async for idx in AsyncVectorIndex.list() + ] + assert request.cls.index_name.upper() in created_indexes, ( + f"VectorIndex {request.cls.index_name} was not created" + ) + + yield + + logger.info("=== Tearing down TestAsyncGetVectorIndexAttributes class ===") + + try: + vector_index = AsyncVectorIndex(index_name=request.cls.index_name) + await vector_index.delete(force=True) + except Exception as exc: + logger.warning("drop vector index failed: %s", exc) + + try: + await request.cls.profile.delete() + except Exception as exc: + logger.warning("profile.delete() raised %s unexpectedly.", exc) + + await request.cls.delete_credential() + + +@pytest.fixture(autouse=True) +async def vector_index_state(request): + logger.info("--- Starting test: %s ---", request.function.__name__) + request.cls.async_vector_index = await AsyncVectorIndex.fetch( + request.cls.index_name + ) + yield + logger.info("--- Finished test: %s ---", request.function.__name__) + + +@pytest.mark.usefixtures("vector_attr_params", "setup_and_teardown") +class TestAsyncGetVectorIndexAttributes: + @classmethod + def get_native_cred_param(cls, cred_name=None): + logger.info("Preparing native credential params for: %s", cred_name) + p = cls.vector_attr_params + return dict( + credential_name=cred_name, + user_ocid=p["user_ocid"], + tenancy_ocid=p["tenancy_ocid"], + private_key=p["private_key"], + fingerprint=p["fingerprint"], + ) + + @classmethod + def get_cred_param(cls, cred_name=None): + logger.info("Preparing basic credential params for: %s", cred_name) + p = cls.vector_attr_params + return dict( + credential_name=cred_name, + username=p["cred_username"], + password=p["cred_password"], + ) + + @classmethod + async def create_credential( + cls, + genai_cred="GENAI_CRED", + objstore_cred="OBJSTORE_CRED", + ): + logger.info("Creating credentials: %s, %s", genai_cred, objstore_cred) + genai_credential = cls.get_native_cred_param(genai_cred) + await select_ai.async_create_credential( + credential=genai_credential, + replace=True, + ) + + p = cls.vector_attr_params + if p.get("cred_username") and p.get("cred_password"): + objstore_credential = cls.get_cred_param(objstore_cred) + await select_ai.async_create_credential( + credential=objstore_credential, + replace=True, + ) + logger.info("Credentials created.") + else: + logger.info( + "Skipping ObjectStore credential creation " + "(CRED_USERNAME/CRED_PASSWORD not set)." + ) + + @classmethod + async def create_profile(cls, profile_name="vector_ai_profile"): + p = cls.vector_attr_params + return await select_ai.AsyncProfile( + profile_name=profile_name, + attributes=select_ai.ProfileAttributes( + credential_name="GENAI_CRED", + provider=select_ai.OCIGenAIProvider( + oci_compartment_id=p["oci_compartment_id"], + oci_apiformat="GENERIC", + embedding_model="cohere.embed-english-v3.0", + ), + ), + description="OCI GENAI Profile", + replace=True, + ) + + @classmethod + async def delete_credential(cls): + try: + await select_ai.async_delete_credential("GENAI_CRED", force=True) + logger.info("Deleted credential 'GENAI_CRED'") + except Exception as exc: + logger.warning("delete_credential() raised %s unexpectedly.", exc) + try: + await select_ai.async_delete_credential("OBJSTORE_CRED", force=True) + logger.info("Deleted credential 'OBJSTORE_CRED'") + except Exception as exc: + logger.warning("delete_credential() raised %s unexpectedly.", exc) + + async def test_5301(self): + """Get vector index attributes and verify type.""" + logger.info("Getting vector index attributes and verifying type...") + attrs = await self.async_vector_index.get_attributes() + assert isinstance(attrs, OracleVectorIndexAttributes) + logger.info("Attributes type verified successfully.") + + async def test_5302(self): + """Verify core values of vector index attributes.""" + logger.info("Getting vector index attributes and verifying core values...") + attrs = await self.async_vector_index.get_attributes() + assert attrs.location == self.embedding_location + assert attrs.object_storage_credential_name == "OBJSTORE_CRED" + assert attrs.profile_name == "vector_ai_profile" + assert attrs.pipeline_name == f"{self.index_name.upper()}$VECPIPELINE" + logger.info("Core attribute values verified successfully.") + + async def test_5303(self): + """Additional sanity checks on vector index attributes.""" + logger.info( + "Performing additional sanity checks on vector index attributes..." + ) + attrs = await self.async_vector_index.get_attributes() + assert attrs.chunk_size is None + assert attrs.chunk_overlap is None + assert attrs.match_limit is None + assert attrs.refresh_rate == 1440 + assert attrs.vector_distance_metric is None + assert attrs.vector_db_provider.value == "oracle" + logger.info("Additional sanity checks passed successfully.") + + async def test_5304(self): + """Verify required fields in attributes object.""" + logger.info("Verifying attributes object contains required fields...") + attrs = await self.async_vector_index.get_attributes() + logger.info("Attributes dict: %s", attrs.__dict__) + assert hasattr(attrs, "location") + assert hasattr(attrs, "object_storage_credential_name") + logger.info("Attributes contain all expected fields.") + + async def test_5305(self): + """Repeatability: fetch attributes twice and compare.""" + logger.info("Fetching attributes twice to confirm repeatability...") + attrs1 = await self.async_vector_index.get_attributes() + attrs2 = await self.async_vector_index.get_attributes() + assert attrs1.location == attrs2.location + logger.info("Attribute values are repeatable across calls.") + + async def test_5306(self): + """Test case-insensitive index name handling.""" + logger.info("Testing case-insensitive index name handling...") + vector_index = await AsyncVectorIndex.fetch(self.index_name.lower()) + attrs = await vector_index.get_attributes() + assert attrs.location == self.embedding_location + logger.info("Case-insensitive index name test passed.") + + async def test_5307(self): + """Type check on key vector index attributes.""" + logger.info("Performing type check on key vector index attributes...") + attrs = await self.async_vector_index.get_attributes() + logger.info("%s", attrs) + assert isinstance(attrs.location, str) + assert isinstance(attrs.profile_name, str) + assert isinstance(attrs.object_storage_credential_name, str) + logger.info("All attribute fields are of correct type.") + + async def test_5308(self): + """Calling get_attributes on nonexistent index raises error.""" + logger.info("Testing get_attributes() with a nonexistent index...") + with pytest.raises(VectorIndexNotFoundError): + await AsyncVectorIndex(index_name="does_not_exist").get_attributes() + logger.info( + "Nonexistent index correctly raised VectorIndexNotFoundError." + ) + + async def test_5309(self): + """Verify error after deleting a temporary vector index.""" + logger.info("Testing error after deleting a temporary vector index...") + temp_index = AsyncVectorIndex( + index_name="temp_index_for_delete", + attributes=OracleVectorIndexAttributes( + location=self.embedding_location, + object_storage_credential_name="OBJSTORE_CRED", + ), + description="Test vector index", + profile=self.profile, + ) + await temp_index.create(replace=True) + logger.info("Temporary vector index created.") + await temp_index.delete(force=True) + logger.info( + "Temporary vector index deleted. Attempting to fetch attributes..." + ) + with pytest.raises(VectorIndexNotFoundError): + await AsyncVectorIndex( + index_name="temp_index_for_delete" + ).get_attributes() + logger.info("Expected error raised after deleting index.") + + async def test_5310(self): + """Access attributes after deleting the vector index.""" + logger.info( + "Testing access to attributes object after deleting " + "the vector index..." + ) + temp_index = AsyncVectorIndex( + index_name="temp_index_for_delete", + attributes=OracleVectorIndexAttributes( + location=self.embedding_location, + object_storage_credential_name="OBJSTORE_CRED", + ), + description="Test vector index", + profile=self.profile, + ) + await temp_index.create(replace=True) + logger.info("Fetching attributes before deletion...") + attrs = await temp_index.get_attributes() + assert isinstance(attrs, OracleVectorIndexAttributes) + assert attrs.object_storage_credential_name == "OBJSTORE_CRED" + assert attrs.location == self.embedding_location + logger.info("Deleting temporary index...") + await temp_index.delete(force=True) + logger.info("Accessing cached attributes after deletion...") + logger.info("After delete: %s", attrs) + assert isinstance(attrs, OracleVectorIndexAttributes) + assert attrs.object_storage_credential_name == "OBJSTORE_CRED" + assert attrs.location == self.embedding_location + logger.info("Attributes object remains valid after deletion.") + + async def test_5311(self): + """get_attributes with empty index name raises error.""" + logger.info("Testing get_attributes() with empty index name...") + with pytest.raises(AttributeError): + await AsyncVectorIndex(index_name="").get_attributes() + logger.info("Empty name correctly raised AttributeError.") + + async def test_5312(self): + """get_attributes with None as index name raises error.""" + logger.info("Testing get_attributes() with None as index name...") + with pytest.raises(AttributeError): + await AsyncVectorIndex(index_name=None).get_attributes() + logger.info("None name correctly raised AttributeError.") + + async def test_5313(self): + """get_attributes with special characters in index name.""" + logger.info( + "Testing get_attributes() with special characters in index name..." + ) + with pytest.raises(VectorIndexNotFoundError): + await AsyncVectorIndex(index_name="@@invalid!!").get_attributes() + logger.info( + "Special character name correctly raised VectorIndexNotFoundError." + ) + + async def test_5314(self): + """get_attributes with Unicode index name raises error.""" + logger.info("Testing get_attributes() with Unicode index name...") + with pytest.raises(VectorIndexNotFoundError): + await AsyncVectorIndex(index_name="テスト").get_attributes() + logger.info("Unicode name correctly raised VectorIndexNotFoundError.") + + async def test_5315(self): + """Multiple indices: check their attribute differences.""" + logger.info( + "Creating multiple vector indices to compare their attributes..." + ) + index_a = AsyncVectorIndex( + index_name="index_a", + attributes=OracleVectorIndexAttributes( + location=self.embedding_location, + object_storage_credential_name="OBJSTORE_CRED", + ), + description="Test vector index", + profile=self.profile, + ) + index_b = AsyncVectorIndex( + index_name="index_b", + attributes=OracleVectorIndexAttributes( + location=self.embedding_location, + object_storage_credential_name="OBJSTORE_CRED", + ), + description="Test vector index", + profile=self.profile, + ) + try: + logger.info("Creating index_a...") + await index_a.create(replace=True) + logger.info("Creating index_b...") + await index_b.create(replace=True) + logger.info("Fetching attributes for both indices...") + attrs_a = await AsyncVectorIndex(index_name="index_a").get_attributes() + attrs_b = await AsyncVectorIndex(index_name="index_b").get_attributes() + logger.info("Attrs_a: %s", attrs_a) + assert attrs_a.pipeline_name != attrs_b.pipeline_name + logger.info("Indices have distinct pipeline names as expected.") + finally: + logger.info("Deleting both indices...") + for index in (index_a, index_b): + try: + await index.delete(force=True) + except Exception as exc: + logger.warning( + "Cleanup failed for %s: %s", + index.index_name, + exc, + ) + + async def test_5316(self): + """Attributes remain consistent after index delete and recreate.""" + logger.info("Testing attributes consistency after delete and recreate...") + temp_index = AsyncVectorIndex( + index_name="temp_recreate", + attributes=OracleVectorIndexAttributes( + location=self.embedding_location, + object_storage_credential_name="OBJSTORE_CRED", + ), + description="Test vector index", + profile=self.profile, + ) + try: + logger.info("Creating temporary vector index for recreate test...") + await temp_index.create(replace=True) + logger.info("Deleting temporary index...") + await temp_index.delete(force=True) + logger.info("Recreating temporary index...") + await temp_index.create(replace=True) + logger.info("Fetching attributes after recreation...") + attrs = await AsyncVectorIndex( + index_name="temp_recreate" + ).get_attributes() + assert attrs.object_storage_credential_name == "OBJSTORE_CRED" + logger.info("Recreate test completed successfully.") + finally: + try: + await temp_index.delete(force=True) + except Exception as exc: + logger.warning("Cleanup failed for temp_recreate: %s", exc) + + async def test_5317(self): + """get_attributes with very long index name raises error.""" + logger.info("Testing get_attributes() with very long index name...") + long_name = "X" * 100 + with pytest.raises(VectorIndexNotFoundError): + await AsyncVectorIndex(index_name=long_name).get_attributes() + logger.info("Long name correctly raised VectorIndexNotFoundError.") + + async def test_5318(self): + """get_attributes after disconnecting from database raises error.""" + logger.info("Testing get_attributes() after disconnecting from database...") + await select_ai.async_disconnect() + with pytest.raises(Exception): + await AsyncVectorIndex(index_name=self.index_name).get_attributes() + logger.info("Expected error raised after disconnect.") + logger.info("Reconnecting for remaining tests...") + await select_ai.async_connect(**self.reconnect_params) + assert await select_ai.async_is_connected(), "Connection to DB failed" + logger.info("Reconnection successful.") diff --git a/tests/vector_index/test_5300_getindex_attributes.py b/tests/vector_index/test_5300_getindex_attributes.py new file mode 100644 index 0000000..f83d35e --- /dev/null +++ b/tests/vector_index/test_5300_getindex_attributes.py @@ -0,0 +1,429 @@ +# ----------------------------------------------------------------------------- +# Copyright (c) 2025, Oracle and/or its affiliates. +# +# Licensed under the Universal Permissive License v 1.0 as shown at +# http://oss.oracle.com/licenses/upl. +# ----------------------------------------------------------------------------- + +import logging +import pytest +import select_ai +from select_ai import VectorIndex, OracleVectorIndexAttributes +from select_ai.errors import VectorIndexNotFoundError + +logger = logging.getLogger("TestGetVectorIndexAttributes") + + +@pytest.fixture(scope="class", autouse=True) +def setup_logging(): + logging.basicConfig( + format="%(asctime)s %(levelname)s %(name)s %(message)s", + level=logging.INFO + ) + + +@pytest.fixture(scope="class") +def vector_attr_params( + request, + vcidx_params, +): + request.cls.vector_attr_params = vcidx_params + + +@pytest.fixture(scope="class", autouse=True) +def setup_and_teardown(request, connect, vector_attr_params, test_env): + logger.info("=== Setting up TestGetVectorIndexAttributes class ===") + p = request.cls.vector_attr_params + + # 'connect' fixture from base tests/conftest.py ensures DB connection exists. + # Do NOT disconnect here; let the session fixture own lifecycle. + assert select_ai.is_connected(), "Connection to DB failed" + + request.cls.user = p["user"] + request.cls.password = p["password"] + request.cls.dsn = p["dsn"] + request.cls.user_ocid = p["user_ocid"] + request.cls.tenancy_ocid = p["tenancy_ocid"] + request.cls.private_key = p["private_key"] + request.cls.fingerprint = p["fingerprint"] + request.cls.cred_username = p["cred_username"] + request.cls.cred_password = p["cred_password"] + request.cls.oci_compartment_id = p["oci_compartment_id"] + request.cls.embedding_location = p["embedding_location"] + request.cls.reconnect_params = test_env.connect_params() + + logger.info("Fetching credential secrets and OCI configuration...") + request.cls.create_credential() + request.cls.profile = request.cls.create_profile() + logger.info("Profile 'vector_ai_profile' created successfully.") + + # create vector index + request.cls.index_name = "test_vector_index_attr" + vi_attrs = OracleVectorIndexAttributes( + location=p["embedding_location"], + object_storage_credential_name="OBJSTORE_CRED" + ) + request.cls.vector_index_attributes = vi_attrs + + vi = VectorIndex( + index_name=request.cls.index_name, + attributes=vi_attrs, + description="Test vector index", + profile=request.cls.profile + ) + vi.create(replace=True) + + # Keep original validation intent (handle either classmethod or instance list()) + try: + created_indexes = [idx.index_name for idx in VectorIndex.list()] + except Exception: + created_indexes = [idx.index_name for idx in VectorIndex().list()] + assert request.cls.index_name.upper() in created_indexes, f"VectorIndex {request.cls.index_name} was not created" + + yield + + logger.info("=== Tearing down TestGetVectorIndexAttributes class ===") + + # Delete Vector Index + try: + vector_index = VectorIndex(index_name=request.cls.index_name) + vector_index.delete(force=True) + except Exception as e: + logger.warning(f"drop vector index failed: {e}") + + # Delete Profile + try: + request.cls.profile.delete() + except Exception as e: + logger.warning(f"profile.delete() raised {e} unexpectedly.") + + # Delete Credential + request.cls.delete_credential() + + +@pytest.fixture(autouse=True) +def log_test_name(request): + logger.info(f"--- Starting test: {request.function.__name__} ---") + yield + logger.info(f"--- Finished test: {request.function.__name__} ---") + + +@pytest.mark.usefixtures("vector_attr_params", "setup_and_teardown") +class TestGetVectorIndexAttributes: + @classmethod + def get_native_cred_param(cls, cred_name=None): + logger.info(f"Preparing native credential params for: {cred_name}") + p = cls.vector_attr_params + return dict( + credential_name=cred_name, + user_ocid=p["user_ocid"], + tenancy_ocid=p["tenancy_ocid"], + private_key=p["private_key"], + fingerprint=p["fingerprint"] + ) + + @classmethod + def get_cred_param(cls, cred_name=None): + logger.info(f"Preparing basic credential params for: {cred_name}") + p = cls.vector_attr_params + return dict( + credential_name=cred_name, + username=p["cred_username"], + password=p["cred_password"] + ) + + @classmethod + def create_credential(cls, genai_cred="GENAI_CRED", objstore_cred="OBJSTORE_CRED"): + logger.info("Creating credentials: GENAI_CRED, OBJSTORE_CRED") + p = cls.vector_attr_params + + genai_credential = cls.get_native_cred_param(genai_cred) + select_ai.create_credential(credential=genai_credential, replace=True) + + # Only create OBJSTORE_CRED if creds are provided in env + if p.get("cred_username") and p.get("cred_password"): + objstore_credential = cls.get_cred_param(objstore_cred) + select_ai.create_credential(credential=objstore_credential, replace=True) + logger.info("Credentials created.") + else: + logger.info("Skipping ObjectStore credential creation (CRED_USERNAME/CRED_PASSWORD not set).") + + @classmethod + def create_profile(cls): + p = cls.vector_attr_params + provider = select_ai.OCIGenAIProvider( + oci_compartment_id=p["oci_compartment_id"], + oci_apiformat="GENERIC", + embedding_model="cohere.embed-english-v3.0", + ) + profile_attributes = select_ai.ProfileAttributes( + credential_name="GENAI_CRED", + provider=provider + ) + return select_ai.Profile( + profile_name="vector_ai_profile", + attributes=profile_attributes, + description="OCI GENAI Profile", + replace=True + ) + + @classmethod + def delete_credential(cls): + try: + select_ai.delete_credential("GENAI_CRED", force=True) + logger.info("Deleted credential 'GENAI_CRED'") + except Exception as e: + logger.warning(f"delete_credential() raised {e} unexpectedly.") + try: + select_ai.delete_credential("OBJSTORE_CRED", force=True) + logger.info("Deleted credential 'OBJSTORE_CRED'") + except Exception as e: + logger.warning(f"delete_credential() raised {e} unexpectedly.") + + def setup_method(self, method): + logger.info(f"SetUp for {method.__name__}") + vecidx = VectorIndex() + self.vector_index = (list(vecidx.list(index_name_pattern=self.index_name)))[0] + + def teardown_method(self, method): + logger.info(f"TearDown for {method.__name__}") + + # ---------------- + # Positive tests + # ---------------- + def test_5301(self): + """Get vector index attributes and verify type.""" + logger.info("Getting vector index attributes and verifying type...") + attrs = self.vector_index.get_attributes() + assert isinstance(attrs, OracleVectorIndexAttributes) + logger.info("Attributes type verified successfully.") + + def test_5302(self): + """Verify core values of vector index attributes.""" + logger.info("Getting vector index attributes and verifying core values...") + attrs = self.vector_index.get_attributes() + assert attrs.location == self.embedding_location + assert attrs.object_storage_credential_name == "OBJSTORE_CRED" + assert attrs.profile_name == "vector_ai_profile" + assert attrs.pipeline_name == f"{self.index_name.upper()}$VECPIPELINE" + logger.info("Core attribute values verified successfully.") + + def test_5303(self): + """Additional sanity checks on vector index attributes.""" + logger.info("Performing additional sanity checks on vector index attributes...") + attrs = self.vector_index.get_attributes() + assert attrs.chunk_size is None + assert attrs.chunk_overlap is None + assert attrs.match_limit is None + assert attrs.refresh_rate == 1440 + assert attrs.vector_distance_metric is None + assert attrs.vector_db_provider.value == "oracle" + logger.info("Additional sanity checks passed successfully.") + + def test_5304(self): + """Verify required fields in attributes object.""" + logger.info("Verifying attributes object contains required fields...") + attrs = self.vector_index.get_attributes() + logger.info(f"Attributes dict: {attrs.__dict__}") + assert hasattr(attrs, "location") + assert hasattr(attrs, "object_storage_credential_name") + logger.info("Attributes contain all expected fields.") + + def test_5305(self): + """Repeatability: fetch attributes twice and compare.""" + logger.info("Fetching attributes twice to confirm repeatability...") + attrs1 = self.vector_index.get_attributes() + attrs2 = self.vector_index.get_attributes() + assert attrs1.location == attrs2.location + logger.info("Attribute values are repeatable across calls.") + + def test_5306(self): + """Test case-insensitive index name handling.""" + logger.info("Testing case-insensitive index name handling...") + vecidx = VectorIndex() + vector_index = (list(vecidx.list(index_name_pattern=self.index_name.lower())))[0] + attrs = vector_index.get_attributes() + assert attrs.location == self.embedding_location + logger.info("Case-insensitive index name test passed.") + + def test_5307(self): + """Type check on key vector index attributes.""" + logger.info("Performing type check on key vector index attributes...") + attrs = self.vector_index.get_attributes() + logger.info(f"{attrs}") + assert isinstance(attrs.location, str) + assert isinstance(attrs.profile_name, str) + assert isinstance(attrs.object_storage_credential_name, str) + logger.info("All attribute fields are of correct type.") + + # ---------------- + # Negative tests + # ---------------- + def test_5308(self): + """Calling get_attributes on nonexistent index raises error.""" + logger.info("Testing get_attributes() with a nonexistent index...") + with pytest.raises(VectorIndexNotFoundError): + VectorIndex(index_name="does_not_exist").get_attributes() + logger.info("Nonexistent index correctly raised VectorIndexNotFoundError.") + + def test_5309(self): + """verify error after deleting a temporary vector index.""" + logger.info("Testing error after deleting a temporary vector index...") + vector_index_attributes = OracleVectorIndexAttributes( + location=self.embedding_location, + object_storage_credential_name="OBJSTORE_CRED" + ) + logger.info("Creating temporary vector index...") + temp_index = VectorIndex( + index_name="temp_index_for_delete", + attributes=vector_index_attributes, + description="Test vector index", + profile=self.profile + ) + temp_index.create(replace=True) + logger.info("Temporary vector index created.") + temp_index.delete(force=True) + logger.info("Temporary vector index deleted. Attempting to fetch attributes...") + with pytest.raises(VectorIndexNotFoundError): + VectorIndex(index_name="temp_index_for_delete").get_attributes() + logger.info("Expected error raised after deleting index.") + + def test_5310(self): + """Access attributes after deleting the vector index (should use cache).""" + logger.info("Testing access to attributes object after deleting the vector index...") + vector_index_attributes = OracleVectorIndexAttributes( + location=self.embedding_location, + object_storage_credential_name="OBJSTORE_CRED" + ) + logger.info("Creating temporary vector index for deletion test...") + temp_index = VectorIndex( + index_name="temp_index_for_delete", + attributes=vector_index_attributes, + description="Test vector index", + profile=self.profile + ) + temp_index.create(replace=True) + logger.info("Fetching attributes before deletion...") + attrs = temp_index.get_attributes() + assert isinstance(attrs, OracleVectorIndexAttributes) + assert attrs.object_storage_credential_name == "OBJSTORE_CRED" + assert attrs.location == self.embedding_location + logger.info("Deleting temporary index...") + temp_index.delete(force=True) + logger.info("Accessing cached attributes after deletion...") + logger.info(f"After delete: {attrs}") + assert isinstance(attrs, OracleVectorIndexAttributes) + assert attrs.object_storage_credential_name == "OBJSTORE_CRED" + assert attrs.location == self.embedding_location + logger.info("Attributes object remains valid after deletion.") + + def test_5311(self): + """get_attributes with empty index name raises error.""" + logger.info("Testing get_attributes() with empty index name...") + with pytest.raises(AttributeError): + VectorIndex(index_name="").get_attributes() + logger.info("Empty name correctly raised AttributeError.") + + def test_5312(self): + """get_attributes with None as index name raises error.""" + logger.info("Testing get_attributes() with None as index name...") + with pytest.raises(AttributeError): + VectorIndex(index_name=None).get_attributes() + logger.info("None name correctly raised AttributeError.") + + def test_5313(self): + """get_attributes with special characters in index name.""" + logger.info("Testing get_attributes() with special characters in index name...") + with pytest.raises(VectorIndexNotFoundError): + VectorIndex(index_name='@@invalid!!').get_attributes() + logger.info("Special character name correctly raised VectorIndexNotFoundError.") + + def test_5314(self): + """get_attributes with Unicode index name raises error.""" + logger.info("Testing get_attributes() with Unicode index name...") + with pytest.raises(VectorIndexNotFoundError): + VectorIndex(index_name='テスト').get_attributes() + logger.info("Unicode name correctly raised VectorIndexNotFoundError.") + + # ---------------- + # Stress / Edge cases + # ---------------- + def test_5315(self): + """Multiple indices: check their attribute differences.""" + logger.info("Creating multiple vector indices to compare their attributes...") + vector_index_attributes = OracleVectorIndexAttributes( + location=self.embedding_location, + object_storage_credential_name="OBJSTORE_CRED" + ) + logger.info("Creating index_a...") + index_a = VectorIndex( + index_name="index_a", + attributes=vector_index_attributes, + description="Test vector index", + profile=self.profile + ) + index_a.create(replace=True) + logger.info("Creating index_b...") + index_b = VectorIndex( + index_name="index_b", + attributes=vector_index_attributes, + description="Test vector index", + profile=self.profile + ) + index_b.create(replace=True) + logger.info("Fetching attributes for both indices...") + attrs_a = VectorIndex(index_name="index_a").get_attributes() + logger.info(f"Attrs_a: {attrs_a}") + attrs_b = VectorIndex(index_name="index_b").get_attributes() + assert attrs_a.pipeline_name != attrs_b.pipeline_name + logger.info("Indices have distinct pipeline names as expected.") + logger.info("Deleting both indices...") + index_a.delete(force=True) + index_b.delete(force=True) + logger.info("Both indices deleted successfully.") + + def test_5316(self): + """Attributes remain consistent after index delete and recreate.""" + logger.info("Testing attributes consistency after delete and recreate...") + vector_index_attributes = OracleVectorIndexAttributes( + location=self.embedding_location, + object_storage_credential_name="OBJSTORE_CRED" + ) + logger.info("Creating temporary vector index for recreate test...") + temp_index = VectorIndex( + index_name="temp_recreate", + attributes=vector_index_attributes, + description="Test vector index", + profile=self.profile + ) + temp_index.create(replace=True) + logger.info("Deleting temporary index...") + temp_index.delete(force=True) + logger.info("Recreating temporary index...") + temp_index.create(replace=True) + logger.info("Fetching attributes after recreation...") + attrs = VectorIndex(index_name="temp_recreate").get_attributes() + assert attrs.object_storage_credential_name == "OBJSTORE_CRED" + temp_index.delete(force=True) + logger.info("Recreate test completed successfully.") + + def test_5317(self): + """get_attributes with very long index name raises error.""" + logger.info("Testing get_attributes() with very long index name...") + long_name = "X" * 100 + with pytest.raises(VectorIndexNotFoundError): + VectorIndex(index_name=long_name).get_attributes() + logger.info("Long name correctly raised VectorIndexNotFoundError.") + + def test_5318(self): + """get_attributes after disconnecting from database raises error.""" + logger.info("Testing get_attributes() after disconnecting from database...") + select_ai.disconnect() + with pytest.raises(Exception): + VectorIndex(index_name=self.index_name).get_attributes() + logger.info("Expected error raised after disconnect.") + logger.info("Reconnecting for remaining tests...") + + select_ai.connect(**self.reconnect_params) + + logger.info("Reconnection successful.") diff --git a/tests/vector_index/test_5400_async_list_index.py b/tests/vector_index/test_5400_async_list_index.py new file mode 100644 index 0000000..08d5409 --- /dev/null +++ b/tests/vector_index/test_5400_async_list_index.py @@ -0,0 +1,406 @@ +# ----------------------------------------------------------------------------- +# Copyright (c) 2025, Oracle and/or its affiliates. +# +# Licensed under the Universal Permissive License v 1.0 as shown at +# http://oss.oracle.com/licenses/upl. +# ----------------------------------------------------------------------------- + +import logging + +import oracledb +import pytest +import select_ai +from select_ai import AsyncVectorIndex, OracleVectorIndexAttributes + +logger = logging.getLogger("TestAsyncListVectorIndex") + +pytestmark = pytest.mark.anyio + + +@pytest.fixture(scope="class") +def list_vec_params(request, vcidx_params): + request.cls.list_vec_params = vcidx_params + + +@pytest.fixture(scope="class", autouse=True) +async def setup_and_teardown(request, async_connect, list_vec_params): + logger.info("=== Setting up TestAsyncListVectorIndex class ===") + p = request.cls.list_vec_params + + assert await select_ai.async_is_connected(), "Connection to DB failed" + + request.cls.user = p["user"] + request.cls.password = p["password"] + request.cls.dsn = p["dsn"] + request.cls.user_ocid = p["user_ocid"] + request.cls.tenancy_ocid = p["tenancy_ocid"] + request.cls.private_key = p["private_key"] + request.cls.fingerprint = p["fingerprint"] + request.cls.cred_username = p["cred_username"] + request.cls.cred_password = p["cred_password"] + request.cls.oci_compartment_id = p["oci_compartment_id"] + request.cls.embedding_location = p["embedding_location"] + request.cls.objstore_cred = "OBJSTORE_CRED" + + await request.cls.create_credential() + request.cls.profile = await request.cls.create_profile() + logger.info("Profile 'vector_ai_profile' created successfully.") + + request.cls.indexes = [ + f"test_vector_index{i}" for i in range(1, 6) + ] + [ + f"test_vecidx{i}" for i in range(1, 3) + ] + + for idx in request.cls.indexes: + try: + await request.cls.create_vector_index(index_name=idx) + except Exception as exc: + logger.warning( + "Index creation failed or already exists for %s: %s", + idx, + exc, + ) + + yield + + logger.info("=== Tearing down TestAsyncListVectorIndex class ===") + for idx in request.cls.indexes: + try: + await AsyncVectorIndex(index_name=idx).delete(force=True) + except Exception as exc: + logger.warning("Warning: drop vector index failed: %s", exc) + + try: + await request.cls.profile.delete() + except Exception as exc: + logger.warning("profile.delete() raised %s unexpectedly.", exc) + + await request.cls.delete_credential() + logger.info("Teardown complete.\n") + + +@pytest.fixture(autouse=True) +async def log_test_name(request): + logger.info("--- Starting test: %s ---", request.function.__name__) + request.cls.vector_index = AsyncVectorIndex() + yield + logger.info("--- Finished test: %s ---", request.function.__name__) + + +@pytest.mark.usefixtures("list_vec_params", "setup_and_teardown") +class TestAsyncListVectorIndex: + @classmethod + def get_native_cred_param(cls, cred_name=None) -> dict: + logger.info("Preparing native credential params for: %s", cred_name) + p = cls.list_vec_params + return dict( + credential_name=cred_name, + user_ocid=p["user_ocid"], + tenancy_ocid=p["tenancy_ocid"], + private_key=p["private_key"], + fingerprint=p["fingerprint"], + ) + + @classmethod + def get_cred_param(cls, cred_name=None) -> dict: + logger.info("Preparing basic credential params for: %s", cred_name) + p = cls.list_vec_params + return dict( + credential_name=cred_name, + username=p["cred_username"], + password=p["cred_password"], + ) + + @classmethod + async def create_credential( + cls, + genai_cred="GENAI_CRED", + objstore_cred="OBJSTORE_CRED", + ): + logger.info("Creating credentials: %s, %s", genai_cred, objstore_cred) + p = cls.list_vec_params + + genai_credential = cls.get_native_cred_param(genai_cred) + await select_ai.async_create_credential( + credential=genai_credential, + replace=True, + ) + + if p.get("cred_username") and p.get("cred_password"): + objstore_credential = cls.get_cred_param(objstore_cred) + await select_ai.async_create_credential( + credential=objstore_credential, + replace=True, + ) + logger.info("Credentials created.") + else: + logger.info( + "Skipping ObjectStore credential creation " + "(CRED_USERNAME/CRED_PASSWORD not set)." + ) + + @classmethod + async def create_profile(cls, profile_name="vector_ai_profile"): + p = cls.list_vec_params + return await select_ai.AsyncProfile( + profile_name=profile_name, + attributes=select_ai.ProfileAttributes( + credential_name="GENAI_CRED", + provider=select_ai.OCIGenAIProvider( + oci_compartment_id=p["oci_compartment_id"], + oci_apiformat="GENERIC", + embedding_model="cohere.embed-english-v3.0", + ), + ), + description="OCI GENAI Profile", + replace=True, + ) + + @classmethod + async def create_vector_index(cls, index_name): + logger.info("Creating vector index: %s", index_name) + vector_index_attributes = OracleVectorIndexAttributes( + location=cls.list_vec_params["embedding_location"], + object_storage_credential_name="OBJSTORE_CRED", + ) + vector_index = AsyncVectorIndex( + index_name=index_name, + attributes=vector_index_attributes, + description="Test vector index", + profile=cls.profile, + ) + await vector_index.create(replace=True) + logger.info("Vector index '%s' created successfully.", index_name) + + @classmethod + async def delete_credential(cls): + try: + await select_ai.async_delete_credential("GENAI_CRED", force=True) + except Exception as exc: + logger.warning("delete_credential() raised %s unexpectedly.", exc) + try: + await select_ai.async_delete_credential("OBJSTORE_CRED", force=True) + except Exception as exc: + logger.warning("delete_credential() raised %s unexpectedly.", exc) + + async def collect_indexes(self, pattern=".*"): + return [ + idx async for idx in self.vector_index.list( + index_name_pattern=pattern + ) + ] + + def expected_index_names(self): + return [ + f"TEST_VECTOR_INDEX{i}" for i in range(1, 6) + ] + [ + f"TEST_VECIDX{i}" for i in range(1, 3) + ] + + async def fetch_expected_indexes(self): + return [ + await AsyncVectorIndex.fetch(index_name) + for index_name in self.expected_index_names() + ] + + async def test_5401(self): + """Verify list of vector indexes with matching names.""" + logger.info("Verifying list of vector indexes with matching names...") + expected_index_names = [ + f"TEST_VECTOR_INDEX{i}" for i in range(1, 6) + ] + [ + f"TEST_VECIDX{i}" for i in range(1, 3) + ] + actual_indexes = await self.collect_indexes(".*") + logger.info( + "Found %s indexes, verifying names match expectations...", + len(actual_indexes), + ) + actual_index_names = [idx.index_name for idx in actual_indexes] + matched_names = [ + name for name in actual_index_names if name in expected_index_names + ] + assert sorted(matched_names) == sorted(expected_index_names), ( + f"Expected names {sorted(expected_index_names)}, " + f"got {sorted(actual_index_names)}" + ) + logger.info("All expected index names matched as expected.") + + async def test_5402(self): + """Verify each index has correct profile name.""" + logger.info("Verifying each index has correct profile name...") + expected_profile = "vector_ai_profile" + for index in await self.fetch_expected_indexes(): + assert index.profile.profile_name == expected_profile, ( + f"Profile mismatch for {index.index_name}: " + f"expected {expected_profile}, got {index.profile.profile_name}" + ) + logger.info("All indexes have correct profile name.") + + async def test_5403(self): + """Verify each index has correct object store credential name.""" + logger.info( + "Verifying each index has correct object store credential name..." + ) + expected_credential = "OBJSTORE_CRED" + for index in await self.fetch_expected_indexes(): + assert ( + index.attributes.object_storage_credential_name + == expected_credential + ), ( + f"Credential mismatch for {index.index_name}: " + f"expected {expected_credential}, " + f"got {index.attributes.object_storage_credential_name}" + ) + logger.info("All indexes have correct object store credential name.") + + async def test_5404(self): + """Verify descriptions for all indexes.""" + logger.info("Verifying descriptions for all indexes...") + expected_description = "Test vector index" + for index in await self.fetch_expected_indexes(): + assert index.description == expected_description, ( + f"Description mismatch for {index.index_name}: " + f"expected {expected_description}, got {index.description}" + ) + logger.info("All indexes have correct descriptions.") + + async def test_5405(self): + """Test exact match listing for index name.""" + logger.info( + "Testing exact match listing for index name " + "'test_vector_index1'..." + ) + indexes = await self.collect_indexes("^test_vector_index1$") + assert indexes[0].index_name == "TEST_VECTOR_INDEX1" + logger.info("Exact match returned correct index.") + + async def test_5406(self): + """Verify multiple matches for pattern.""" + logger.info( + "Verifying multiple matches for pattern '^test_vector_index'..." + ) + actual_indexes = await self.collect_indexes("^test_vector_index") + actual_index_names = [index.index_name for index in actual_indexes] + expected_index_names = [ + f"TEST_VECTOR_INDEX{i}" for i in range(1, 6) + ] + matched_names = [ + name for name in actual_index_names if name in expected_index_names + ] + assert sorted(matched_names) == sorted(expected_index_names), ( + f"Expected names {sorted(expected_index_names)}, " + f"got {sorted(actual_index_names)}" + ) + logger.info("Multiple index names verified successfully.") + + async def test_5407(self): + """Test case-sensitive regex pattern for listing indexes.""" + logger.info("Testing case-sensitive regex pattern for listing indexes...") + indexes = await self.collect_indexes("^TEST_VECTOR_INDEX?") + assert any(idx.index_name == "TEST_VECTOR_INDEX2" for idx in indexes) + logger.info("Case-sensitive pattern matched correctly.") + + async def test_5408(self): + """Test case-insensitive regex pattern for listing indexes.""" + logger.info("Testing case-insensitive regex pattern for listing indexes...") + indexes = await self.collect_indexes("^TEST") + assert any( + idx.index_name.upper() == "TEST_VECTOR_INDEX1" for idx in indexes + ) + logger.info("Case-insensitive pattern matched correctly.") + + async def test_5409(self): + """Test complex regex pattern with OR operator.""" + logger.info("Testing complex regex pattern with OR operator...") + indexes = await self.collect_indexes("^(test_vector_index|test_vecidx)") + names = [idx.index_name for idx in indexes] + assert "TEST_VECTOR_INDEX1" in names + assert "TEST_VECIDX1" in names + assert "INVALID_VECIDX1" not in names + logger.info("Complex regex OR pattern matched correctly.") + + async def test_5410(self): + """Test non-matching regex pattern returns nothing.""" + logger.info("Testing non-matching regex pattern...") + indexes = await self.collect_indexes("^xyz") + assert len(indexes) == 0 + logger.info("Non-matching pattern returned no results as expected.") + + async def test_5411(self): + """Test invalid regex pattern expecting ORA-12726 error.""" + logger.info("Testing invalid regex pattern expecting ORA-12726 error...") + with pytest.raises(oracledb.DatabaseError) as exc_info: + await self.collect_indexes("[unclosed") + assert "ORA-12726" in str(exc_info.value) + logger.info("Invalid regex correctly raised ORA-12726 error.") + + async def test_5412(self): + """Test invalid type pattern (numeric instead of string).""" + logger.info("Testing invalid type pattern (numeric instead of string)...") + indexes = await self.collect_indexes(123) + assert len(indexes) == 0 + logger.info("Invalid type pattern handled correctly with empty result.") + + async def test_5413(self): + """Test None as pattern input.""" + logger.info("Testing None as pattern input...") + indexes = await self.collect_indexes(None) + assert len(indexes) != len(self.indexes) + logger.info("None pattern handled correctly.") + + async def test_5414(self): + """Test empty string as pattern input.""" + logger.info("Testing empty string as pattern input...") + indexes = await self.collect_indexes("") + assert len(indexes) != len(self.indexes) + logger.info("Empty string pattern handled correctly.") + + async def test_5415(self): + """Test whitespace-only pattern.""" + logger.info("Testing whitespace-only pattern...") + indexes = await self.collect_indexes(" ") + assert len(indexes) == 0 + logger.info("Whitespace pattern correctly returned no matches.") + + async def test_5416(self): + """Test numeric string pattern yields no matches.""" + logger.info("Testing numeric pattern that should yield no matches...") + indexes = await self.collect_indexes("test123") + assert len(indexes) == 0 + logger.info("Numeric pattern correctly returned no matches.") + + async def test_5417(self): + """Test pattern with special characters '$'.""" + logger.info("Testing pattern with special characters '$'...") + indexes = await self.collect_indexes("test_vector_index1$") + assert len(indexes) == 1 + logger.info("Special character pattern matched correctly.") + + async def test_5418(self): + """Test extremely long regex pattern expecting ORA-12733 error.""" + logger.info( + "Testing extremely long regex pattern expecting ORA-12733 error..." + ) + pattern = "^" + "a" * 1000 + "$" + with pytest.raises(oracledb.DatabaseError) as exc_info: + await self.collect_indexes(pattern) + assert "ORA-12733" in str(exc_info.value) + logger.info("Long regex correctly raised ORA-12733 error.") + + async def test_5419(self): + """Test case-insensitive match for prefix.""" + logger.info("Testing case-insensitive match for prefix '^TEST'...") + indexes = await self.collect_indexes("^TEST") + expected_index_names = [ + f"TEST_VECTOR_INDEX{i}" for i in range(1, 6) + ] + [ + f"TEST_VECIDX{i}" for i in range(1, 3) + ] + actual_index_names = [idx.index_name for idx in indexes] + matched_names = [ + name for name in actual_index_names if name in expected_index_names + ] + assert sorted(matched_names) == sorted(expected_index_names) + logger.info("Case-insensitive match returned all expected indexes.") diff --git a/tests/vector_index/test_5400_list_index.py b/tests/vector_index/test_5400_list_index.py new file mode 100644 index 0000000..ccb0fa8 --- /dev/null +++ b/tests/vector_index/test_5400_list_index.py @@ -0,0 +1,350 @@ +# ----------------------------------------------------------------------------- +# Copyright (c) 2025, Oracle and/or its affiliates. +# +# Licensed under the Universal Permissive License v 1.0 as shown at +# http://oss.oracle.com/licenses/upl. +# ----------------------------------------------------------------------------- + +import logging +import pytest +import select_ai +import oracledb + +logger = logging.getLogger("TestListVectorIndex") + + +@pytest.fixture(scope="class", autouse=True) +def setup_logging(): + logging.basicConfig( + format="%(asctime)s %(levelname)s %(name)s %(message)s", + level=logging.INFO + ) + + +@pytest.fixture(scope="class") +def list_vec_params( + request, + vcidx_params, +): + request.cls.list_vec_params = vcidx_params + + +@pytest.fixture(scope="class", autouse=True) +def setup_and_teardown(request, connect, list_vec_params): + logger.info("=== Setting up TestListVectorIndex class ===") + p = request.cls.list_vec_params + + # 'connect' fixture from base tests/conftest.py ensures DB connection exists. + # Do NOT disconnect here; let the session fixture own lifecycle. + assert select_ai.is_connected(), "Connection to DB failed" + + request.cls.user = p["user"] + request.cls.password = p["password"] + request.cls.dsn = p["dsn"] + request.cls.user_ocid = p["user_ocid"] + request.cls.tenancy_ocid = p["tenancy_ocid"] + request.cls.private_key = p["private_key"] + request.cls.fingerprint = p["fingerprint"] + request.cls.cred_username = p["cred_username"] + request.cls.cred_password = p["cred_password"] + request.cls.oci_compartment_id = p["oci_compartment_id"] + request.cls.embedding_location = p["embedding_location"] + request.cls.objstore_cred = "OBJSTORE_CRED" + + request.cls.create_credential() + request.cls.profile = request.cls.create_profile() + logger.info("Profile 'vector_ai_profile' created successfully.") + + def create_vector_index(index_name): + logger.info(f"Creating vector index: {index_name}") + vector_index_attributes = select_ai.OracleVectorIndexAttributes( + location=p["embedding_location"], + object_storage_credential_name="OBJSTORE_CRED" + ) + vector_index = select_ai.VectorIndex( + index_name=index_name, + attributes=vector_index_attributes, + description="Test vector index", + profile=request.cls.profile + ) + vector_index.create(replace=True) + logger.info(f"Vector index '{index_name}' created successfully.") + + request.cls.indexes = [f"test_vector_index{i}" for i in range(1, 6)] + \ + [f"test_vecidx{i}" for i in range(1, 3)] + + for idx in request.cls.indexes: + try: + create_vector_index(index_name=idx) + except Exception as exc: + logger.warning(f"Index creation failed or already exists for {idx}: {exc}") + + yield + + logger.info("=== Tearing down TestListVectorIndex class ===") + for idx in request.cls.indexes: + try: + vector_index = select_ai.VectorIndex(index_name=idx) + vector_index.delete(force=True) + except Exception as e: + logger.warning(f"Warning: drop vector index failed: {e}") + + try: + request.cls.profile.delete() + except Exception as e: + logger.warning(f"profile.delete() raised {e} unexpectedly.") + + request.cls.delete_credential() + logger.info("Teardown complete.\n") + + +@pytest.fixture(autouse=True) +def log_test_name(request): + logger.info(f"--- Starting test: {request.function.__name__} ---") + yield + logger.info(f"--- Finished test: {request.function.__name__} ---") + + +@pytest.mark.usefixtures("list_vec_params", "setup_and_teardown") +class TestListVectorIndex: + @classmethod + def get_native_cred_param(cls, cred_name=None) -> dict: + logger.info(f"Preparing native credential params for: {cred_name}") + p = cls.list_vec_params + return dict( + credential_name=cred_name, + user_ocid=p["user_ocid"], + tenancy_ocid=p["tenancy_ocid"], + private_key=p["private_key"], + fingerprint=p["fingerprint"] + ) + + @classmethod + def get_cred_param(cls, cred_name=None) -> dict: + logger.info(f"Preparing basic credential params for: {cred_name}") + p = cls.list_vec_params + return dict( + credential_name=cred_name, + username=p["cred_username"], + password=p["cred_password"] + ) + + @classmethod + def create_credential(cls, genai_cred="GENAI_CRED", objstore_cred="OBJSTORE_CRED"): + logger.info("Creating credentials: GENAI_CRED, OBJSTORE_CRED") + p = cls.list_vec_params + + genai_credential = cls.get_native_cred_param(genai_cred) + select_ai.create_credential(credential=genai_credential, replace=True) + + # Only create OBJSTORE_CRED if creds are provided in env + if p.get("cred_username") and p.get("cred_password"): + objstore_credential = cls.get_cred_param(objstore_cred) + select_ai.create_credential(credential=objstore_credential, replace=True) + logger.info("Credentials created.") + else: + logger.info("Skipping ObjectStore credential creation (CRED_USERNAME/CRED_PASSWORD not set).") + + @classmethod + def create_profile(cls): + p = cls.list_vec_params + provider = select_ai.OCIGenAIProvider( + oci_compartment_id=p["oci_compartment_id"], + oci_apiformat="GENERIC", + embedding_model="cohere.embed-english-v3.0", + ) + profile_attributes = select_ai.ProfileAttributes( + credential_name="GENAI_CRED", + provider=provider + ) + return select_ai.Profile( + profile_name="vector_ai_profile", + attributes=profile_attributes, + description="OCI GENAI Profile", + replace=True + ) + + @classmethod + def delete_credential(cls): + try: + select_ai.delete_credential("GENAI_CRED", force=True) + except Exception as e: + logger.warning(f"delete_credential() raised {e} unexpectedly.") + try: + select_ai.delete_credential("OBJSTORE_CRED", force=True) + except Exception as e: + logger.warning(f"delete_credential() raised {e} unexpectedly.") + + def setup_method(self, method): + logger.info(f"SetUp for {method.__name__}") + self.vector_index = select_ai.VectorIndex() + + def teardown_method(self, method): + logger.info(f"TearDown for {method.__name__}") + + # ---------------- + # Positive tests + # ---------------- + def test_5401(self): + """Verify list of vector indexes with matching names.""" + logger.info("Verifying list of vector indexes with matching names...") + expected_index_names = [f"TEST_VECTOR_INDEX{i}" for i in range(1, 6)] + \ + [f"TEST_VECIDX{i}" for i in range(1, 3)] + actual_indexes = list(self.vector_index.list(index_name_pattern=".*")) + logger.info(f"Found {len(actual_indexes)} indexes, verifying names match expectations...") + assert len(actual_indexes) == len(expected_index_names), \ + f"Expected {len(expected_index_names)} indexes, got {len(actual_indexes)}" + actual_index_names = [idx.index_name for idx in actual_indexes] + assert sorted(actual_index_names) == sorted(expected_index_names), \ + f"Expected names {sorted(expected_index_names)}, got {sorted(actual_index_names)}" + logger.info("All index names matched as expected.") + + def test_5402(self): + """Verify each index has correct profile name.""" + logger.info("Verifying each index has correct profile name...") + expected_profile = "vector_ai_profile" + for index in self.vector_index.list(index_name_pattern=".*"): + assert index.profile.profile_name == expected_profile, \ + f"Profile mismatch for {index.index_name}: expected {expected_profile}, got {index.profile.profile_name}" + logger.info("All indexes have correct profile name.") + + def test_5403(self): + """Verify each index has correct object store credential name.""" + logger.info("Verifying each index has correct object store credential name...") + expected_credential = "OBJSTORE_CRED" + for index in self.vector_index.list(index_name_pattern=".*"): + assert index.attributes.object_storage_credential_name == expected_credential, \ + f"Credential mismatch for {index.index_name}: expected {expected_credential}, got {index.attributes.object_storage_credential_name}" + logger.info("All indexes have correct object store credential name.") + + def test_5404(self): + """Verify descriptions for all indexes.""" + logger.info("Verifying descriptions for all indexes...") + expected_description = "Test vector index" + for index in self.vector_index.list(index_name_pattern=".*"): + assert index.description == expected_description, \ + f"Description mismatch for {index.index_name}: expected {expected_description}, got {index.description}" + logger.info("All indexes have correct descriptions.") + + def test_5405(self): + """Test exact match listing for index name.""" + logger.info("Testing exact match listing for index name 'test_vector_index1'...") + indexes = self.vector_index.list(index_name_pattern="^test_vector_index1$") + assert list(indexes)[0].index_name == "TEST_VECTOR_INDEX1" + logger.info("Exact match returned correct index.") + + def test_5406(self): + """Verify multiple matches for pattern.""" + logger.info("Verifying multiple matches for pattern '^test_vector_index'...") + actual_indexes = list(self.vector_index.list(index_name_pattern="^test_vector_index")) + expected_count = 5 + assert len(actual_indexes) == expected_count, \ + f"Expected {expected_count} indexes, got {len(actual_indexes)}" + actual_index_names = [index.index_name for index in actual_indexes] + expected_index_names = [f"TEST_VECTOR_INDEX{i}" for i in range(1, 6)] + assert sorted(actual_index_names) == sorted(expected_index_names), \ + f"Expected names {sorted(expected_index_names)}, got {sorted(actual_index_names)}" + logger.info("Multiple index names verified successfully.") + + def test_5407(self): + """Test case-sensitive regex pattern for listing indexes.""" + logger.info("Testing case-sensitive regex pattern for listing indexes...") + indexes = self.vector_index.list("^TEST_VECTOR_INDEX?") + assert any(idx.index_name == "TEST_VECTOR_INDEX2" for idx in indexes) + logger.info("Case-sensitive pattern matched correctly.") + + def test_5408(self): + """Test case-insensitive regex pattern for listing indexes.""" + logger.info("Testing case-insensitive regex pattern for listing indexes...") + indexes = self.vector_index.list("^TEST") + assert any(idx.index_name.upper() == "TEST_VECTOR_INDEX1" for idx in indexes) + logger.info("Case-insensitive pattern matched correctly.") + + def test_5409(self): + """Test complex regex pattern with OR operator.""" + logger.info("Testing complex regex pattern with OR operator...") + indexes = self.vector_index.list("^(test_vector_index|test_vecidx)") + names = [idx.index_name for idx in indexes] + assert "TEST_VECTOR_INDEX1" in names + assert "TEST_VECIDX1" in names + assert "INVALID_VECIDX1" not in names + logger.info("Complex regex OR pattern matched correctly.") + + # ---------------- + # Negative tests + # ---------------- + def test_5410(self): + """Test non-matching regex pattern returns nothing.""" + logger.info("Testing non-matching regex pattern...") + indexes = self.vector_index.list(index_name_pattern="^xyz") + assert len(list(indexes)) == 0 + logger.info("Non-matching pattern returned no results as expected.") + + def test_5411(self): + """Test invalid regex pattern expecting ORA-12726 error.""" + logger.info("Testing invalid regex pattern expecting ORA-12726 error...") + with pytest.raises(oracledb.DatabaseError) as exc_info: + list(self.vector_index.list("[unclosed")) + assert "ORA-12726" in str(exc_info.value) + logger.info("Invalid regex correctly raised ORA-12726 error.") + + def test_5412(self): + """Test invalid type pattern (numeric instead of string).""" + logger.info("Testing invalid type pattern (numeric instead of string)...") + indexes = list(self.vector_index.list(123)) + assert len(indexes) == 0 + logger.info("Invalid type pattern handled correctly with empty result.") + + # ---------------- + # Stress / Edge cases + # ---------------- + def test_5413(self): + """Test None as pattern input.""" + logger.info("Testing None as pattern input...") + indexes = self.vector_index.list(None) + assert len(list(indexes)) != len(self.indexes) + logger.info("None pattern handled correctly.") + + def test_5414(self): + """Test empty string as pattern input.""" + logger.info("Testing empty string as pattern input...") + indexes = self.vector_index.list("") + assert len(list(indexes)) != len(self.indexes) + logger.info("Empty string pattern handled correctly.") + + def test_5415(self): + """Test whitespace-only pattern.""" + logger.info("Testing whitespace-only pattern...") + indexes = self.vector_index.list(" ") + assert len(list(indexes)) == 0 + logger.info("Whitespace pattern correctly returned no matches.") + + def test_5416(self): + """Test numeric string pattern yields no matches.""" + logger.info("Testing numeric pattern that should yield no matches...") + indexes = list(self.vector_index.list("test123")) + assert len(indexes) == 0 + logger.info("Numeric pattern correctly returned no matches.") + + def test_5417(self): + """Test pattern with special characters '$'.""" + logger.info("Testing pattern with special characters '$'...") + indexes = self.vector_index.list("test_vector_index1$") + assert len(list(indexes)) == 1 + logger.info("Special character pattern matched correctly.") + + def test_5418(self): + """Test extremely long regex pattern expecting ORA-12733 error.""" + logger.info("Testing extremely long regex pattern expecting ORA-12733 error...") + pattern = "^" + "a" * 1000 + "$" + with pytest.raises(oracledb.DatabaseError) as exc_info: + list(self.vector_index.list(pattern)) + assert "ORA-12733" in str(exc_info.value) + logger.info("Long regex correctly raised ORA-12733 error.") + + def test_5419(self): + """Test case-insensitive match for prefix.""" + logger.info("Testing case-insensitive match for prefix '^TEST'...") + indexes = self.vector_index.list("^TEST") + assert len(list(indexes)) == 7 + logger.info("Case-insensitive match returned expected count.") diff --git a/tests/vector_index/test_5500_async_enable_disable_index.py b/tests/vector_index/test_5500_async_enable_disable_index.py new file mode 100644 index 0000000..056c18a --- /dev/null +++ b/tests/vector_index/test_5500_async_enable_disable_index.py @@ -0,0 +1,395 @@ +# ----------------------------------------------------------------------------- +# Copyright (c) 2025, Oracle and/or its affiliates. +# +# Licensed under the Universal Permissive License v 1.0 as shown at +# http://oss.oracle.com/licenses/upl. +# ----------------------------------------------------------------------------- + +import asyncio +import logging + +import oracledb +import pytest +import select_ai +from select_ai import AsyncVectorIndex, OracleVectorIndexAttributes + +logger = logging.getLogger("TestAsyncEnableDisableVectorIndex") + +pytestmark = pytest.mark.anyio + + +@pytest.fixture(scope="class") +def enabledisable_params(request, vcidx_params): + request.cls.enabledisable_params = vcidx_params + + +@pytest.fixture(scope="class", autouse=True) +async def setup_and_teardown(request, async_connect, enabledisable_params): + logger.info("=== Setting up TestAsyncEnableDisableVectorIndex class ===") + p = request.cls.enabledisable_params + + assert await select_ai.async_is_connected(), "Connection to DB failed" + + request.cls.user = p["user"] + request.cls.password = p["password"] + request.cls.dsn = p["dsn"] + request.cls.user_ocid = p["user_ocid"] + request.cls.tenancy_ocid = p["tenancy_ocid"] + request.cls.private_key = p["private_key"] + request.cls.fingerprint = p["fingerprint"] + request.cls.cred_username = p["cred_username"] + request.cls.cred_password = p["cred_password"] + request.cls.oci_compartment_id = p["oci_compartment_id"] + request.cls.embedding_location = p["embedding_location"] + + logger.info("Fetching credential secrets and OCI configuration...") + + async with select_ai.async_cursor() as cursor: + await cursor.execute( + "begin execute immediate 'drop table test_items purge'; " + "exception when others then null; end;" + ) + await cursor.execute( + "create table test_items (id number primary key, " + "name varchar2(50))" + ) + await cursor.execute("insert into test_items values (1, 'Alpha')") + await cursor.execute("insert into test_items values (2, 'Beta')") + await cursor.execute("commit") + + await request.cls.create_credential() + request.cls.profile = await request.cls.create_profile() + logger.info("Setup complete.") + + vi_attrs = OracleVectorIndexAttributes( + location=p["embedding_location"], + object_storage_credential_name="OBJSTORE_CRED", + ) + request.cls.vector_index_attributes = vi_attrs + request.cls.index_name = "test_vector_index" + + vector_index = AsyncVectorIndex( + index_name=request.cls.index_name, + attributes=vi_attrs, + description="Test vector index", + profile=request.cls.profile, + ) + await vector_index.create(replace=True) + + created_indexes = [ + idx.index_name async for idx in AsyncVectorIndex.list() + ] + assert request.cls.index_name.upper() in created_indexes, ( + f"VectorIndex {request.cls.index_name} was not created" + ) + + yield + + logger.info("=== Tearing down TestAsyncEnableDisableVectorIndex class ===") + try: + await AsyncVectorIndex(index_name=request.cls.index_name).delete( + force=True + ) + except Exception as exc: + logger.info("Warning: drop vector index failed: %s", exc) + + try: + await request.cls.profile.delete() + except Exception as exc: + logger.warning("profile.delete() raised %s unexpectedly.", exc) + + await request.cls.delete_credential() + + async with select_ai.async_cursor() as cursor: + await cursor.execute( + "begin execute immediate 'drop table test_items purge'; " + "exception when others then null; end;" + ) + + logger.info("Teardown complete.\n") + + +@pytest.fixture(autouse=True) +async def vector_index_state(request): + logger.info("--- Starting test: %s ---", request.function.__name__) + request.cls.objstore_cred = "OBJSTORE_CRED" + request.cls.vecidx = AsyncVectorIndex() + request.cls.async_vector_index = await AsyncVectorIndex.fetch( + request.cls.index_name + ) + logger.info(request.cls.async_vector_index.index_name) + await request.cls.async_vector_index.enable() + await asyncio.sleep(1) + yield + logger.info("--- Finished test: %s ---", request.function.__name__) + + +@pytest.mark.usefixtures("enabledisable_params", "setup_and_teardown") +class TestAsyncEnableDisableVectorIndex: + @classmethod + def get_native_cred_param(cls, cred_name=None) -> dict: + logger.info("Preparing native credential params for: %s", cred_name) + p = cls.enabledisable_params + return dict( + credential_name=cred_name, + user_ocid=p["user_ocid"], + tenancy_ocid=p["tenancy_ocid"], + private_key=p["private_key"], + fingerprint=p["fingerprint"], + ) + + @classmethod + def get_cred_param(cls, cred_name=None) -> dict: + logger.info("Preparing basic credential params for: %s", cred_name) + p = cls.enabledisable_params + return dict( + credential_name=cred_name, + username=p["cred_username"], + password=p["cred_password"], + ) + + @classmethod + async def create_credential( + cls, + genai_cred="GENAI_CRED", + objstore_cred="OBJSTORE_CRED", + ): + logger.info("Creating credentials: %s, %s", genai_cred, objstore_cred) + + genai_credential = cls.get_native_cred_param(genai_cred) + try: + logger.info("Creating GenAI credential: %s", genai_cred) + await select_ai.async_create_credential( + credential=genai_credential, + replace=True, + ) + logger.info("GenAI credential created.") + except Exception as exc: + logger.error("create_credential() raised %s unexpectedly.", exc) + raise AssertionError( + f"create_credential() raised {exc} unexpectedly." + ) + + p = cls.enabledisable_params + if p.get("cred_username") and p.get("cred_password"): + objstore_credential = cls.get_cred_param(objstore_cred) + try: + logger.info( + "Creating ObjectStore credential: %s", objstore_cred + ) + await select_ai.async_create_credential( + credential=objstore_credential, + replace=True, + ) + logger.info("ObjectStore credential created.") + except Exception as exc: + logger.error("create_credential() raised %s unexpectedly.", exc) + raise AssertionError( + f"create_credential() raised {exc} unexpectedly." + ) + else: + logger.info( + "Skipping ObjectStore credential creation " + "(CRED_USERNAME/CRED_PASSWORD not set)." + ) + + @classmethod + async def create_profile(cls, profile_name="vector_ai_profile"): + logger.info("Creating Profile: %s", profile_name) + p = cls.enabledisable_params + provider = select_ai.OCIGenAIProvider( + oci_compartment_id=p["oci_compartment_id"], + oci_apiformat="GENERIC", + embedding_model="cohere.embed-english-v3.0", + ) + profile_attributes = select_ai.ProfileAttributes( + credential_name="GENAI_CRED", + provider=provider, + ) + profile = await select_ai.AsyncProfile( + profile_name=profile_name, + attributes=profile_attributes, + description="OCI GENAI Profile", + replace=True, + ) + logger.info("Profile '%s' created successfully.", profile_name) + return profile + + @classmethod + async def delete_credential(cls): + logger.info("Deleting credentials...") + try: + await select_ai.async_delete_credential("GENAI_CRED", force=True) + logger.info("Deleted credential 'GENAI_CRED'") + except Exception as exc: + logger.warning("delete_credential() raised %s unexpectedly.", exc) + try: + await select_ai.async_delete_credential("OBJSTORE_CRED", force=True) + logger.info("Deleted credential 'OBJSTORE_CRED'") + except Exception as exc: + logger.warning("delete_credential() raised %s unexpectedly.", exc) + + async def wait_for_status_table( + self, status_table, retries=5, delay=2 + ): + for _ in range(retries): + try: + async with select_ai.async_cursor() as cursor: + await cursor.execute( + f"SELECT COUNT(*) FROM {status_table}" + ) + return await cursor.fetchone() + except oracledb.DatabaseError as exc: + if "ORA-00942" in str(exc): + await asyncio.sleep(delay) + continue + raise + return None + + async def test_5501(self): + """Disabling and enabling the vector index.""" + logger.info("Disabling vector index: %s", self.index_name) + await self.async_vector_index.disable() + logger.info("Enabling vector index: %s", self.index_name) + await self.async_vector_index.enable() + logger.info("Vector index enabled successfully") + + async def test_5502(self): + """Disable same vector index twice (should be harmless).""" + logger.info("First disable of vector index: %s", self.index_name) + await self.async_vector_index.disable() + logger.info( + "Attempting second disable of vector index: %s", + self.index_name, + ) + await self.async_vector_index.disable() + + async def test_5503(self): + """Enable same vector index twice (should be harmless).""" + logger.info("Enabling vector index: %s", self.index_name) + await self.async_vector_index.enable() + await self.async_vector_index.enable() + + async def test_5504(self): + """Ensure queries work after enabling the vector index.""" + logger.info("Querying test_items table after enabling vector index") + async with select_ai.async_cursor() as cursor: + await cursor.execute("select count(*) from test_items") + row_count, = await cursor.fetchone() + logger.info("Number of rows in test_items: %s", row_count) + assert row_count == 2 + df = await self.profile.run_sql(prompt="How many rows in test_items") + logger.info("run_sql returned: %s", df) + assert df is not None, "run_sql should return a DataFrame object" + + async def test_5505(self): + """Ensure queries fail after disabling the vector index.""" + logger.info( + "Disabling vector index: %s to test query blocking", + self.index_name, + ) + await self.async_vector_index.disable() + logger.info("Running query should raise DatabaseError") + with pytest.raises(oracledb.DatabaseError) as exc_info: + await self.profile.run_sql(prompt="Show all rows from test_items") + logger.info( + "Expected database error confirmed: %s", + exc_info.value, + ) + + async def test_5506(self): + """Disabling a nonexistent index raises error.""" + logger.info("Disabling nonexistent index to test error handling") + invalid_index = AsyncVectorIndex(index_name="does_not_exist") + with pytest.raises(oracledb.DatabaseError) as exc_info: + await invalid_index.disable() + logger.info( + "Expected database error confirmed: %s", + exc_info.value, + ) + + async def test_5507(self): + """Enabling a nonexistent index raises error.""" + logger.info("Enabling nonexistent index to test error handling") + invalid_index = AsyncVectorIndex(index_name="does_not_exist") + with pytest.raises(oracledb.DatabaseError) as exc_info: + await invalid_index.enable() + logger.info( + "Expected database error confirmed: %s", + exc_info.value, + ) + + async def test_5508(self): + """Disabling after delete raises error; vector index recreated.""" + logger.info("Deleting vector index: %s", self.index_name) + await self.async_vector_index.delete(force=True) + logger.info("Attempting to disable deleted index") + with pytest.raises(oracledb.DatabaseError): + await self.async_vector_index.disable() + logger.info("Recreating vector index for subsequent tests") + vector_index = AsyncVectorIndex( + index_name=self.index_name, + attributes=self.vector_index_attributes, + description="Test vector index", + profile=self.profile, + ) + await vector_index.create(replace=True) + logger.info("Vector index recreated successfully") + self.async_vector_index = await AsyncVectorIndex.fetch(self.index_name) + + async def test_5509(self): + """Pipeline inactive after disabling the vector index.""" + logger.info( + "Disabling vector index: %s to check pipeline inactivity", + self.index_name, + ) + await self.async_vector_index.disable() + pipeline_name = f"{self.index_name.upper()}$VECPIPELINE" + async with select_ai.async_cursor() as cursor: + await cursor.execute( + "SELECT status_table FROM user_cloud_pipelines " + "WHERE pipeline_name = :1", + [pipeline_name], + ) + row = await cursor.fetchone() + if row is None: + logger.info( + "Pipeline is inactive (no entry in user_cloud_pipelines)" + ) + assert True + return + status_table = row[0] + logger.info( + "Status table found: %s, querying should fail", + status_table, + ) + with pytest.raises(oracledb.DatabaseError): + await cursor.execute( + f"SELECT * FROM {status_table} FETCH FIRST 1 ROWS ONLY" + ) + + async def test_5510(self): + """Pipeline metadata is available after enabling the vector index.""" + pipeline_name = f"{self.index_name.upper()}$VECPIPELINE" + logger.info("Checking pipeline activity after enabling vector index") + async with select_ai.async_cursor() as cursor: + await cursor.execute( + "SELECT status_table FROM user_cloud_pipelines " + "WHERE pipeline_name = :pipeline_name", + {"pipeline_name": pipeline_name}, + ) + status_row = await cursor.fetchone() + assert status_row is not None, ( + f"No pipeline entry found for {pipeline_name}" + ) + status_table = status_row[0] + logger.info("Status table found: %s", status_table) + if status_table is not None: + count_row = await self.wait_for_status_table(status_table) + assert count_row is not None, ( + f"No result returned from status_table {status_table}" + ) + assert count_row[0] >= 0, ( + "Pipeline table should be accessible when enabled" + ) + logger.info("Pipeline metadata is available after enable.") diff --git a/tests/vector_index/test_5500_enable_disable_index.py b/tests/vector_index/test_5500_enable_disable_index.py new file mode 100644 index 0000000..5603aae --- /dev/null +++ b/tests/vector_index/test_5500_enable_disable_index.py @@ -0,0 +1,360 @@ +# ----------------------------------------------------------------------------- +# Copyright (c) 2025, Oracle and/or its affiliates. +# +# Licensed under the Universal Permissive License v 1.0 as shown at +# http://oss.oracle.com/licenses/upl. +# ----------------------------------------------------------------------------- + +import logging +import pytest +import select_ai +import oracledb +import time +from select_ai import VectorIndex + +# Set up global logger (one per module) +logger = logging.getLogger("TestEnableDisableVectorIndex") + + +@pytest.fixture(scope="class", autouse=True) +def setup_logging(): + logging.basicConfig( + format="%(asctime)s %(levelname)s %(name)s %(message)s", + level=logging.INFO + ) + + +@pytest.fixture(scope="class") +def enabledisable_params( + request, + vcidx_params, +): + request.cls.enabledisable_params = vcidx_params + + +@pytest.fixture(scope="class", autouse=True) +def setup_and_teardown(request, connect, enabledisable_params): + logger.info("=== Setting up TestEnableDisableVectorIndex class ===") + p = request.cls.enabledisable_params + + # 'connect' fixture from base tests/conftest.py ensures DB connection exists. + # Do NOT disconnect here; let the session fixture own lifecycle. + assert select_ai.is_connected(), "Connection to DB failed" + + logger.info("Fetching credential secrets and OCI configuration...") + + # table setup + with select_ai.cursor() as cursor: + cursor.execute("begin execute immediate 'drop table test_items purge'; exception when others then null; end;") + cursor.execute("create table test_items (id number primary key, name varchar2(50))") + cursor.execute("insert into test_items values (1, 'Alpha')") + cursor.execute("insert into test_items values (2, 'Beta')") + cursor.execute("commit") + + # test resources + request.cls.create_credential() + request.cls.profile = request.cls.create_profile() + logger.info("Setup complete.") + + # Start with clean vector index + vi_attrs = select_ai.OracleVectorIndexAttributes( + location=p["embedding_location"], + object_storage_credential_name="OBJSTORE_CRED" + ) + request.cls.vector_index_attributes = vi_attrs + request.cls.index_name = "test_vector_index" + vector_index = select_ai.VectorIndex( + index_name=request.cls.index_name, + attributes=vi_attrs, + description="Test vector index", + profile=request.cls.profile + ) + vector_index.create(replace=True) + + try: + created_indexes = [idx.index_name for idx in VectorIndex.list()] + except Exception: + created_indexes = [idx.index_name for idx in VectorIndex().list()] + assert request.cls.index_name.upper() in created_indexes, f"VectorIndex {request.cls.index_name} was not created" + + yield + + logger.info("=== Tearing down TestEnableDisableVectorIndex class ===") + try: + vector_index = VectorIndex(index_name=request.cls.index_name) + vector_index.delete(force=True) + except Exception as e: + logger.info(f"Warning: drop vector index failed: {e}") + + request.cls.delete_profile(request.cls.profile) + request.cls.delete_credential() + logger.info("Teardown complete.\n") + + +@pytest.fixture(autouse=True) +def log_test_name(request): + logger.info(f"--- Starting test: {request.function.__name__} ---") + yield + logger.info(f"--- Finished test: {request.function.__name__} ---") + + +@pytest.mark.usefixtures("enabledisable_params", "setup_and_teardown") +class TestEnableDisableVectorIndex: + @classmethod + def get_native_cred_param(cls, cred_name=None) -> dict: + logger.info(f"Preparing native credential params for: {cred_name}") + p = cls.enabledisable_params + return dict( + credential_name=cred_name, + user_ocid=p["user_ocid"], + tenancy_ocid=p["tenancy_ocid"], + private_key=p["private_key"], + fingerprint=p["fingerprint"] + ) + + @classmethod + def get_cred_param(cls, cred_name=None) -> dict: + logger.info(f"Preparing basic credential params for: {cred_name}") + p = cls.enabledisable_params + return dict( + credential_name=cred_name, + username=p["cred_username"], + password=p["cred_password"] + ) + + @classmethod + def create_credential(cls, genai_cred="GENAI_CRED", objstore_cred="OBJSTORE_CRED"): + logger.info(f"Creating credentials: {genai_cred}, {objstore_cred}") + + genai_credential = cls.get_native_cred_param(genai_cred) + try: + logger.info(f"Creating GenAI credential: {genai_cred}") + select_ai.create_credential(credential=genai_credential, replace=True) + logger.info("GenAI credential created.") + except Exception as e: + logger.error(f"create_credential() raised {e} unexpectedly.") + raise AssertionError(f"create_credential() raised {e} unexpectedly.") + + # Only create OBJSTORE_CRED if creds are provided in env + p = cls.enabledisable_params + if p.get("cred_username") and p.get("cred_password"): + objstore_credential = cls.get_cred_param(objstore_cred) + try: + logger.info(f"Creating ObjectStore credential: {objstore_cred}") + select_ai.create_credential(credential=objstore_credential, replace=True) + logger.info("ObjectStore credential created.") + except Exception as e: + logger.error(f"create_credential() raised {e} unexpectedly.") + raise AssertionError(f"create_credential() raised {e} unexpectedly.") + else: + logger.info("Skipping ObjectStore credential creation (CRED_USERNAME/CRED_PASSWORD not set).") + + @classmethod + def create_profile(cls, profile_name="vector_ai_profile"): + logger.info(f"Creating Profile: {profile_name}") + p = cls.enabledisable_params + provider = select_ai.OCIGenAIProvider( + oci_compartment_id=p["oci_compartment_id"], + oci_apiformat="GENERIC", + embedding_model="cohere.embed-english-v3.0", + ) + profile_attributes = select_ai.ProfileAttributes( + credential_name="GENAI_CRED", + provider=provider + ) + profile = select_ai.Profile( + profile_name=profile_name, + attributes=profile_attributes, + description="OCI GENAI Profile", + replace=True + ) + logger.info(f"Profile '{profile_name}' created successfully.") + return profile + + @classmethod + def delete_profile(cls, profile): + logger.info("Deleting profile...") + try: + profile.delete() + logger.info(f"Profile '{profile.profile_name}' deleted successfully.") + except Exception as e: + logger.error(f"profile.delete() raised {e} unexpectedly.") + raise AssertionError(f"profile.delete() raised {e} unexpectedly.") + + @classmethod + def delete_credential(cls): + logger.info("Deleting credentials...") + try: + select_ai.delete_credential("GENAI_CRED", force=True) + logger.info("Deleted credential 'GENAI_CRED'") + except Exception as e: + logger.warning(f"delete_credential() raised {e} unexpectedly.") + try: + select_ai.delete_credential("OBJSTORE_CRED", force=True) + logger.info("Deleted credential 'OBJSTORE_CRED'") + except Exception as e: + logger.warning(f"delete_credential() raised {e} unexpectedly.") + + def setup_method(self, method): + logger.info(f"SetUp for {method.__name__}") + self.objstore_cred = "OBJSTORE_CRED" + self.vecidx = select_ai.VectorIndex() + self.vector_index = list(self.vecidx.list(index_name_pattern=".*"))[0] + logger.info(self.vector_index.index_name) + try: + self.vector_index.enable() + time.sleep(1) + except oracledb.DatabaseError as e: + if "ORA-20000" not in str(e): + raise + + def teardown_method(self, method): + logger.info(f"TearDown for {method.__name__}") + + def wait_for_status_table(self, cursor, status_table, retries=5, delay=2): + for _ in range(retries): + try: + cursor.execute(f"SELECT COUNT(*) FROM {status_table}") + return cursor.fetchone() + except oracledb.DatabaseError as e: + if "ORA-00942" in str(e): + time.sleep(delay) + continue + raise + return None + + def wait_for_pipeline_entry(self, cursor, pipeline_name, retries=5, delay=2): + for _ in range(retries): + cursor.execute( + "SELECT status_table FROM user_cloud_pipelines WHERE pipeline_name = :1", + [pipeline_name] + ) + row = cursor.fetchone() + if row and row[0]: + return row[0] + time.sleep(delay) + return None + + def test_5501(self): + """Disabling and enabling the vector index.""" + logger.info(f"Disabling vector index: {self.index_name}") + self.vector_index.disable() + logger.info(f"Enabling vector index: {self.index_name}") + self.vector_index.enable() + logger.info(f"Vector index enabled successfully") + + def test_5502(self): + """Disable same vector index twice (should be harmless).""" + logger.info(f"First disable of vector index: {self.index_name}") + self.vector_index.disable() + logger.info(f"Attempting second disable of vector index: {self.index_name}") + self.vector_index.disable() + + def test_5503(self): + """Enable same vector index twice (should be harmless).""" + logger.info(f"Enabling vector index: {self.index_name}") + self.vector_index.enable() + self.vector_index.enable() + + def test_5504(self): + """Ensure queries work after enabling the vector index.""" + logger.info("Querying test_items table after enabling vector index") + with select_ai.cursor() as cursor: + cursor.execute("select count(*) from test_items") + row_count, = cursor.fetchone() + logger.info(f"Number of rows in test_items: {row_count}") + assert row_count == 2 + df = self.profile.run_sql(prompt="How many rows in test_items") + logger.info(f"run_sql returned: {df}") + assert df is not None, "run_sql should return a DataFrame object" + + def test_5505(self): + """Ensure queries fail after disabling the vector index.""" + logger.info(f"Disabling vector index: {self.index_name} to test query blocking") + self.vector_index.disable() + logger.info(f"Running query should raise DatabaseError") + with pytest.raises(oracledb.DatabaseError) as exc_info: + self.profile.run_sql(prompt="Show all rows from test_items") + logger.info( + "Expected database error confirmed: %s", + exc_info.value, + ) + + def test_5506(self): + """Disabling a nonexistent index raises error.""" + logger.info("Disabling nonexistent index to test error handling") + invalid_index = VectorIndex(index_name="does_not_exist") + with pytest.raises(oracledb.DatabaseError) as exc_info: + invalid_index.disable() + logger.info( + "Expected database error confirmed: %s", + exc_info.value, + ) + + def test_5507(self): + """Enabling a nonexistent index raises error.""" + logger.info("Enabling nonexistent index to test error handling") + invalid_index = VectorIndex(index_name="does_not_exist") + with pytest.raises(oracledb.DatabaseError) as exc_info: + invalid_index.enable() + logger.info( + "Expected database error confirmed: %s", + exc_info.value, + ) + + def test_5508(self): + """Disabling after delete raises error; vector index recreated.""" + logger.info(f"Deleting vector index: {self.index_name}") + self.vector_index.delete(force=True) + logger.info(f"Attempting to disable deleted index") + with pytest.raises(oracledb.DatabaseError): + self.vector_index.disable() + logger.info(f"Recreating vector index for subsequent tests") + vector_index = select_ai.VectorIndex( + index_name=self.index_name, + attributes=self.vector_index_attributes, + description="Test vector index", + profile=self.profile + ) + vector_index.create(replace=True) + logger.info(f"Vector index recreated successfully") + + def test_5509(self): + """Pipeline inactive after disabling the vector index.""" + logger.info(f"Disabling vector index: {self.index_name} to check pipeline inactivity") + self.vector_index.disable() + pipeline_name = f"{self.index_name.upper()}$VECPIPELINE" + with select_ai.cursor() as cursor: + cursor.execute( + "SELECT status_table FROM user_cloud_pipelines WHERE pipeline_name = :1", + [pipeline_name] + ) + row = cursor.fetchone() + if row is None: + logger.info(f"Pipeline is inactive (no entry in user_cloud_pipelines)") + assert True + return + status_table = row[0] + logger.info(f"Status table found: {status_table}, querying should fail") + with pytest.raises(oracledb.DatabaseError): + cursor.execute(f"SELECT * FROM {status_table} FETCH FIRST 1 ROWS ONLY") + + def test_5510(self): + """Pipeline metadata is available after enabling the vector index.""" + pipeline_name = f"{self.index_name.upper()}$VECPIPELINE" + logger.info(f"Checking pipeline activity after enabling vector index") + with select_ai.cursor() as cursor: + cursor.execute( + "SELECT status_table FROM user_cloud_pipelines WHERE pipeline_name = :pipeline_name", + {"pipeline_name": pipeline_name} + ) + (status_table,) = cursor.fetchone() + logger.info(f"Status table found: {status_table}") + if status_table is not None: + count_row = self.wait_for_status_table(cursor, status_table) + else: + count_row = None + if status_table is not None: + assert count_row is not None, f"No result returned from status_table {status_table}" + assert count_row[0] >= 0, "Pipeline table should be accessible when enabled" + logger.info("Pipeline metadata is available after enable.")