From 5fe3fad184f01bf6afb1ce8faba74cf4de721fe8 Mon Sep 17 00:00:00 2001 From: Prateek Saxena Date: Wed, 1 Apr 2026 14:54:28 +0000 Subject: [PATCH 1/6] Added provider and unit test coverage --- tests/providers/conftest.py | 229 +++++++++++++++ tests/providers/test_5000_openai_provider.py | 25 ++ tests/providers/test_5100_azure_provider.py | 25 ++ tests/providers/test_5200_oci_provider.py | 25 ++ tests/providers/test_5300_cohere_provider.py | 25 ++ tests/providers/test_5400_google_provider.py | 25 ++ .../providers/test_5500_anthropic_provider.py | 25 ++ .../test_5600_huggingface_provider.py | 25 ++ tests/providers/test_5700_aws_provider.py | 24 ++ tests/unit/conftest.py | 267 ++++++++++++++++++ tests/unit/test_4000_abc.py | 81 ++++++ tests/unit/test_4100_provider.py | 88 ++++++ tests/unit/test_4200_db.py | 180 ++++++++++++ tests/unit/test_4300_credential.py | 121 ++++++++ tests/unit/test_4400_privilege.py | 86 ++++++ tests/unit/test_4500_base_profile.py | 199 +++++++++++++ tests/unit/test_4600_validations.py | 62 ++++ tests/unit/test_4700_errors.py | 82 ++++++ 18 files changed, 1594 insertions(+) create mode 100644 tests/providers/conftest.py create mode 100644 tests/providers/test_5000_openai_provider.py create mode 100644 tests/providers/test_5100_azure_provider.py create mode 100644 tests/providers/test_5200_oci_provider.py create mode 100644 tests/providers/test_5300_cohere_provider.py create mode 100644 tests/providers/test_5400_google_provider.py create mode 100644 tests/providers/test_5500_anthropic_provider.py create mode 100644 tests/providers/test_5600_huggingface_provider.py create mode 100644 tests/providers/test_5700_aws_provider.py create mode 100644 tests/unit/conftest.py create mode 100644 tests/unit/test_4000_abc.py create mode 100644 tests/unit/test_4100_provider.py create mode 100644 tests/unit/test_4200_db.py create mode 100644 tests/unit/test_4300_credential.py create mode 100644 tests/unit/test_4400_privilege.py create mode 100644 tests/unit/test_4500_base_profile.py create mode 100644 tests/unit/test_4600_validations.py create mode 100644 tests/unit/test_4700_errors.py diff --git a/tests/providers/conftest.py b/tests/providers/conftest.py new file mode 100644 index 0000000..5c7836c --- /dev/null +++ b/tests/providers/conftest.py @@ -0,0 +1,229 @@ +# ----------------------------------------------------------------------------- +# Copyright (c) 2026, Oracle and/or its affiliates. +# +# Licensed under the Universal Permissive License v 1.0 as shown at +# http://oss.oracle.com/licenses/upl. 
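+#
+# Provider tests are configuration-driven. A minimal, illustrative setup
+# (shell syntax and JSON keys below are examples only, not a fixed schema):
+#
+#   export PYSAI_PROVIDER_TEST_CONNECT_STRING="localhost/freepdb1"
+#   export PYSAI_PROVIDER_TEST_USER="PYSAI_TEST_USER"
+#   export PYSAI_PROVIDER_TEST_USER_PASSWORD="<password>"
+#   export PYSAI_PROVIDER_TEST_ADMIN_USER="ADMIN"
+#   export PYSAI_PROVIDER_TEST_ADMIN_PASSWORD="<password>"
+#   export PYSAI_PROVIDER_OPENAI_CREDENTIAL_JSON='{"username": "openai", "password": "<api-key>"}'
+#   export PYSAI_PROVIDER_OPENAI_PROFILE_JSON='{"model": "gpt-4o-mini"}'
+#
+# Any provider whose variables are unset is skipped rather than failed.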
+# ----------------------------------------------------------------------------- + +import contextlib +import json +import os +import uuid + +import pytest +import select_ai + +_BASIC_SCHEMA_PRIVILEGES = ( + "CREATE SESSION", + "CREATE TABLE", + "UNLIMITED TABLESPACE", +) + +_PROVIDER_CLASSES = { + "openai": select_ai.OpenAIProvider, + "azure": select_ai.AzureProvider, + "oci": select_ai.OCIGenAIProvider, + "cohere": select_ai.CohereProvider, + "google": select_ai.GoogleProvider, + "anthropic": select_ai.AnthropicProvider, + "huggingface": select_ai.HuggingFaceProvider, + "aws": select_ai.AWSProvider, +} + + +@pytest.fixture(autouse=True, scope="module") +def connect(): + yield + + +@pytest.fixture(autouse=True, scope="module") +def async_connect(): + yield + + +@pytest.fixture(autouse=True, scope="module") +def oci_credential(): + yield {} + + +@pytest.fixture(scope="session") +def anyio_backend(): + return "asyncio" + + +def _env(name, default=None): + return os.environ.get(name, default) + + +def _required_env(name, fallback=None): + value = _env(name, fallback) + if not value: + pytest.skip(f"missing environment variable {name}") + return value + + +def _ensure_test_user_exists(username: str, password: str): + username_upper = username.upper() + with select_ai.cursor() as cr: + cr.execute( + "SELECT 1 FROM dba_users WHERE username = :username", + username=username_upper, + ) + if cr.fetchone(): + return + escaped_password = password.replace('"', '""') + cr.execute( + f'CREATE USER {username_upper} IDENTIFIED BY "{escaped_password}"' + ) + with select_ai.db.get_connection() as conn: + conn.commit() + + +def _grant_basic_schema_privileges(username: str): + username_upper = username.upper() + with select_ai.cursor() as cr: + for privilege in _BASIC_SCHEMA_PRIVILEGES: + cr.execute(f"GRANT {privilege} TO {username_upper}") + with select_ai.db.get_connection() as conn: + conn.commit() + + +class ProviderTestEnv: + + def __init__(self): + self.test_user = _required_env( + "PYSAI_PROVIDER_TEST_USER", _env("PYSAI_TEST_USER") + ) + self.test_user_password = _required_env( + "PYSAI_PROVIDER_TEST_USER_PASSWORD", + _env("PYSAI_TEST_USER_PASSWORD"), + ) + self.admin_user = _required_env( + "PYSAI_PROVIDER_TEST_ADMIN_USER", _env("PYSAI_TEST_ADMIN_USER") + ) + self.admin_password = _required_env( + "PYSAI_PROVIDER_TEST_ADMIN_PASSWORD", + _env("PYSAI_TEST_ADMIN_PASSWORD"), + ) + self.connect_string = _required_env( + "PYSAI_PROVIDER_TEST_CONNECT_STRING", + _env("PYSAI_TEST_CONNECT_STRING"), + ) + self.wallet_location = _env( + "PYSAI_PROVIDER_TEST_WALLET_LOCATION", + _env("PYSAI_TEST_WALLET_LOCATION"), + ) + self.wallet_password = _env( + "PYSAI_PROVIDER_TEST_WALLET_PASSWORD", + _env("PYSAI_TEST_WALLET_PASSWORD"), + ) + + def connect_params(self, admin=False): + user = self.admin_user if admin else self.test_user + password = self.admin_password if admin else self.test_user_password + params = { + "user": user, + "password": password, + "dsn": self.connect_string, + } + if self.wallet_location: + params["wallet_location"] = self.wallet_location + params["wallet_password"] = self.wallet_password + params["config_dir"] = self.wallet_location + return params + + +def _provider_json_env(provider_name, suffix): + env_name = f"PYSAI_PROVIDER_{provider_name.upper()}_{suffix}" + raw = os.environ.get(env_name) + if not raw: + pytest.skip(f"missing environment variable {env_name}") + return json.loads(raw) + + +def _provider_prompt(provider_name): + return os.environ.get( + 
f"PYSAI_PROVIDER_{provider_name.upper()}_PROMPT", + "What is a database?", + ) + + +def _build_provider(provider_name, provider_kwargs): + provider_cls = _PROVIDER_CLASSES[provider_name] + return provider_cls(**provider_kwargs) + + +@pytest.fixture(scope="session") +def provider_test_env(): + return ProviderTestEnv() + + +@pytest.fixture(autouse=True, scope="session") +def setup_test_user(provider_test_env): + select_ai.connect(**provider_test_env.connect_params(admin=True)) + _ensure_test_user_exists( + username=provider_test_env.test_user, + password=provider_test_env.test_user_password, + ) + _grant_basic_schema_privileges(username=provider_test_env.test_user) + select_ai.grant_privileges(users=[provider_test_env.test_user]) + select_ai.disconnect() + yield + + +@pytest.fixture(autouse=True, scope="module") +def provider_connection(setup_test_user, provider_test_env): + select_ai.connect(**provider_test_env.connect_params()) + yield + with contextlib.suppress(Exception): + select_ai.disconnect() + + +@pytest.fixture +def provider_profile_factory(provider_test_env): + created = [] + + def _factory(provider_name): + credential = _provider_json_env(provider_name, "CREDENTIAL_JSON") + provider_kwargs = _provider_json_env(provider_name, "PROFILE_JSON") + provider = _build_provider(provider_name, provider_kwargs) + credential_name = ( + f"PYSAI_{provider_name.upper()}_CRED_{uuid.uuid4().hex.upper()}" + ) + profile_name = ( + f"PYSAI_{provider_name.upper()}_PROFILE_{uuid.uuid4().hex.upper()}" + ) + credential["credential_name"] = credential_name + select_ai.create_credential(credential=credential, replace=True) + provider_endpoint = ( + provider.provider_endpoint + or getattr(provider.__class__, "provider_endpoint", None) + ) + if provider_endpoint: + select_ai.grant_http_access( + users=[provider_test_env.test_user], + provider_endpoint=provider_endpoint, + ) + profile = select_ai.Profile( + profile_name=profile_name, + attributes=select_ai.ProfileAttributes( + credential_name=credential_name, + provider=provider, + ), + ) + created.append((provider_endpoint, credential_name, profile)) + return profile, _provider_prompt(provider_name) + + yield _factory + + for provider_endpoint, credential_name, profile in reversed(created): + with contextlib.suppress(Exception): + profile.delete(force=True) + with contextlib.suppress(Exception): + select_ai.delete_credential(credential_name, force=True) + if provider_endpoint: + with contextlib.suppress(Exception): + select_ai.revoke_http_access( + users=[provider_test_env.test_user], + provider_endpoint=provider_endpoint, + ) diff --git a/tests/providers/test_5000_openai_provider.py b/tests/providers/test_5000_openai_provider.py new file mode 100644 index 0000000..a2a9ecb --- /dev/null +++ b/tests/providers/test_5000_openai_provider.py @@ -0,0 +1,25 @@ +# ----------------------------------------------------------------------------- +# Copyright (c) 2026, Oracle and/or its affiliates. +# +# Licensed under the Universal Permissive License v 1.0 as shown at +# http://oss.oracle.com/licenses/upl. 
+# ----------------------------------------------------------------------------- + +import pytest + +pytestmark = pytest.mark.provider + +PROVIDER_NAME = "openai" + + +@pytest.fixture +def openai_profile(provider_profile_factory): + return provider_profile_factory(PROVIDER_NAME) + + +def test_5000_openai_provider(openai_profile): + profile, prompt = openai_profile + response = profile.narrate(prompt) + assert isinstance(response, str) + assert response.strip() + diff --git a/tests/providers/test_5100_azure_provider.py b/tests/providers/test_5100_azure_provider.py new file mode 100644 index 0000000..5a52612 --- /dev/null +++ b/tests/providers/test_5100_azure_provider.py @@ -0,0 +1,25 @@ +# ----------------------------------------------------------------------------- +# Copyright (c) 2026, Oracle and/or its affiliates. +# +# Licensed under the Universal Permissive License v 1.0 as shown at +# http://oss.oracle.com/licenses/upl. +# ----------------------------------------------------------------------------- + +import pytest + +pytestmark = pytest.mark.provider + +PROVIDER_NAME = "azure" + + +@pytest.fixture +def azure_profile(provider_profile_factory): + return provider_profile_factory(PROVIDER_NAME) + + +def test_5100_azure_provider(azure_profile): + profile, prompt = azure_profile + response = profile.narrate(prompt) + assert isinstance(response, str) + assert response.strip() + diff --git a/tests/providers/test_5200_oci_provider.py b/tests/providers/test_5200_oci_provider.py new file mode 100644 index 0000000..3cfbc86 --- /dev/null +++ b/tests/providers/test_5200_oci_provider.py @@ -0,0 +1,25 @@ +# ----------------------------------------------------------------------------- +# Copyright (c) 2026, Oracle and/or its affiliates. +# +# Licensed under the Universal Permissive License v 1.0 as shown at +# http://oss.oracle.com/licenses/upl. +# ----------------------------------------------------------------------------- + +import pytest + +pytestmark = pytest.mark.provider + +PROVIDER_NAME = "oci" + + +@pytest.fixture +def oci_profile(provider_profile_factory): + return provider_profile_factory(PROVIDER_NAME) + + +def test_5200_oci_provider(oci_profile): + profile, prompt = oci_profile + response = profile.narrate(prompt) + assert isinstance(response, str) + assert response.strip() + diff --git a/tests/providers/test_5300_cohere_provider.py b/tests/providers/test_5300_cohere_provider.py new file mode 100644 index 0000000..a878765 --- /dev/null +++ b/tests/providers/test_5300_cohere_provider.py @@ -0,0 +1,25 @@ +# ----------------------------------------------------------------------------- +# Copyright (c) 2026, Oracle and/or its affiliates. +# +# Licensed under the Universal Permissive License v 1.0 as shown at +# http://oss.oracle.com/licenses/upl. 
+# ----------------------------------------------------------------------------- + +import pytest + +pytestmark = pytest.mark.provider + +PROVIDER_NAME = "cohere" + + +@pytest.fixture +def cohere_profile(provider_profile_factory): + return provider_profile_factory(PROVIDER_NAME) + + +def test_5300_cohere_provider(cohere_profile): + profile, prompt = cohere_profile + response = profile.narrate(prompt) + assert isinstance(response, str) + assert response.strip() + diff --git a/tests/providers/test_5400_google_provider.py b/tests/providers/test_5400_google_provider.py new file mode 100644 index 0000000..97fcae5 --- /dev/null +++ b/tests/providers/test_5400_google_provider.py @@ -0,0 +1,25 @@ +# ----------------------------------------------------------------------------- +# Copyright (c) 2026, Oracle and/or its affiliates. +# +# Licensed under the Universal Permissive License v 1.0 as shown at +# http://oss.oracle.com/licenses/upl. +# ----------------------------------------------------------------------------- + +import pytest + +pytestmark = pytest.mark.provider + +PROVIDER_NAME = "google" + + +@pytest.fixture +def google_profile(provider_profile_factory): + return provider_profile_factory(PROVIDER_NAME) + + +def test_5400_google_provider(google_profile): + profile, prompt = google_profile + response = profile.narrate(prompt) + assert isinstance(response, str) + assert response.strip() + diff --git a/tests/providers/test_5500_anthropic_provider.py b/tests/providers/test_5500_anthropic_provider.py new file mode 100644 index 0000000..199755a --- /dev/null +++ b/tests/providers/test_5500_anthropic_provider.py @@ -0,0 +1,25 @@ +# ----------------------------------------------------------------------------- +# Copyright (c) 2026, Oracle and/or its affiliates. +# +# Licensed under the Universal Permissive License v 1.0 as shown at +# http://oss.oracle.com/licenses/upl. +# ----------------------------------------------------------------------------- + +import pytest + +pytestmark = pytest.mark.provider + +PROVIDER_NAME = "anthropic" + + +@pytest.fixture +def anthropic_profile(provider_profile_factory): + return provider_profile_factory(PROVIDER_NAME) + + +def test_5500_anthropic_provider(anthropic_profile): + profile, prompt = anthropic_profile + response = profile.narrate(prompt) + assert isinstance(response, str) + assert response.strip() + diff --git a/tests/providers/test_5600_huggingface_provider.py b/tests/providers/test_5600_huggingface_provider.py new file mode 100644 index 0000000..7d5b5a8 --- /dev/null +++ b/tests/providers/test_5600_huggingface_provider.py @@ -0,0 +1,25 @@ +# ----------------------------------------------------------------------------- +# Copyright (c) 2026, Oracle and/or its affiliates. +# +# Licensed under the Universal Permissive License v 1.0 as shown at +# http://oss.oracle.com/licenses/upl. 
+# ----------------------------------------------------------------------------- + +import pytest + +pytestmark = pytest.mark.provider + +PROVIDER_NAME = "huggingface" + + +@pytest.fixture +def huggingface_profile(provider_profile_factory): + return provider_profile_factory(PROVIDER_NAME) + + +def test_5600_huggingface_provider(huggingface_profile): + profile, prompt = huggingface_profile + response = profile.narrate(prompt) + assert isinstance(response, str) + assert response.strip() + diff --git a/tests/providers/test_5700_aws_provider.py b/tests/providers/test_5700_aws_provider.py new file mode 100644 index 0000000..23d8239 --- /dev/null +++ b/tests/providers/test_5700_aws_provider.py @@ -0,0 +1,24 @@ +# ----------------------------------------------------------------------------- +# Copyright (c) 2026, Oracle and/or its affiliates. +# +# Licensed under the Universal Permissive License v 1.0 as shown at +# http://oss.oracle.com/licenses/upl. +# ----------------------------------------------------------------------------- + +import pytest + +pytestmark = pytest.mark.provider + +PROVIDER_NAME = "aws" + + +@pytest.fixture +def aws_profile(provider_profile_factory): + return provider_profile_factory(PROVIDER_NAME) + + +def test_5700_aws_provider(aws_profile): + profile, prompt = aws_profile + response = profile.narrate(prompt) + assert isinstance(response, str) + assert response.strip() diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py new file mode 100644 index 0000000..7131e0a --- /dev/null +++ b/tests/unit/conftest.py @@ -0,0 +1,267 @@ +# ----------------------------------------------------------------------------- +# Copyright (c) 2026, Oracle and/or its affiliates. +# +# Licensed under the Universal Permissive License v 1.0 as shown at +# http://oss.oracle.com/licenses/upl. 
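+#
+# The fakes below stub just enough of the python-oracledb surface (cursors,
+# connections, pools, and LOBs, both sync and async) for the unit tests to
+# run without a database. Tests wire them in via monkeypatching; a minimal
+# sketch of the pattern (see tests/unit/test_4200_db.py for real usages):
+#
+#   def test_connect(monkeypatch, fake_connection):
+#       monkeypatch.setattr(db.oracledb, "connect", lambda **kw: fake_connection)
+#       db.connect(user="u", password="p", dsn="d")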
+# ----------------------------------------------------------------------------- + +import contextlib +from types import SimpleNamespace + +import pytest + + +@pytest.fixture(autouse=True, scope="session") +def setup_test_user(): + yield + + +@pytest.fixture(autouse=True, scope="module") +def connect(): + yield + + +@pytest.fixture(autouse=True, scope="module") +def async_connect(): + yield + + +@pytest.fixture(autouse=True, scope="module") +def oci_credential(): + yield {} + + +@pytest.fixture(scope="session") +def anyio_backend(): + return "asyncio" + + +class FakeCursor: + + def __init__(self): + self.execute_calls = [] + self.callproc_calls = [] + self.callfunc_calls = [] + self.fetchone_result = None + self.fetchall_result = [] + self.closed = False + + def execute(self, statement, *args, **kwargs): + self.execute_calls.append((statement, args, kwargs)) + return None + + def callproc(self, name, *args, **kwargs): + self.callproc_calls.append((name, args, kwargs)) + return None + + def callfunc(self, name, return_type, *args, **kwargs): + self.callfunc_calls.append((name, return_type, args, kwargs)) + return None + + def fetchone(self): + return self.fetchone_result + + def fetchall(self): + return self.fetchall_result + + def close(self): + self.closed = True + + +class FakeAsyncCursor: + + def __init__(self): + self.execute_calls = [] + self.callproc_calls = [] + self.callfunc_calls = [] + self.fetchone_result = None + self.fetchall_result = [] + self.closed = False + + async def execute(self, statement, *args, **kwargs): + self.execute_calls.append((statement, args, kwargs)) + return None + + async def callproc(self, name, *args, **kwargs): + self.callproc_calls.append((name, args, kwargs)) + return None + + async def callfunc(self, name, return_type, *args, **kwargs): + self.callfunc_calls.append((name, return_type, args, kwargs)) + return None + + async def fetchone(self): + return self.fetchone_result + + async def fetchall(self): + return self.fetchall_result + + def close(self): + self.closed = True + + +class FakeConnection: + + def __init__(self, cursor_factory=None, ping_error=None): + self.cursor_factory = cursor_factory or FakeCursor + self.ping_error = ping_error + self.closed = False + self.ping_count = 0 + + def ping(self): + self.ping_count += 1 + if self.ping_error is not None: + raise self.ping_error + + def cursor(self): + return self.cursor_factory() + + def close(self): + self.closed = True + + +class FakeAsyncConnection: + + def __init__(self, cursor_factory=None, ping_error=None): + self.cursor_factory = cursor_factory or FakeAsyncCursor + self.ping_error = ping_error + self.closed = False + self.ping_count = 0 + + async def ping(self): + self.ping_count += 1 + if self.ping_error is not None: + raise self.ping_error + + def cursor(self): + return self.cursor_factory() + + async def close(self): + self.closed = True + + +class FakePool: + + def __init__(self, acquired_connection=None, acquire_error=None): + self.acquired_connection = acquired_connection or FakeConnection() + self.acquire_error = acquire_error + self.released = [] + self.closed = False + self.close_force = None + + def acquire(self): + if self.acquire_error is not None: + raise self.acquire_error + return self.acquired_connection + + def release(self, connection): + self.released.append(connection) + + def close(self, force=False): + self.closed = True + self.close_force = force + + +class FakeAsyncPool: + + def __init__(self, acquired_connection=None, acquire_error=None): + 
self.acquired_connection = acquired_connection or FakeAsyncConnection() + self.acquire_error = acquire_error + self.released = [] + self.closed = False + self.close_force = None + + async def acquire(self): + if self.acquire_error is not None: + raise self.acquire_error + return self.acquired_connection + + async def release(self, connection): + self.released.append(connection) + + async def close(self, force=False): + self.closed = True + self.close_force = force + + +class FakeLOB: + + def __init__(self, value): + self.value = value + + def read(self): + return self.value + + +class FakeAsyncLOB: + + def __init__(self, value): + self.value = value + + async def read(self): + return self.value + + +class FakeDatabaseError(Exception): + + def __init__(self, code): + super().__init__(code) + self.args = (SimpleNamespace(code=code),) + + +@contextlib.contextmanager +def sync_cursor_manager(cursor): + yield cursor + + +@contextlib.asynccontextmanager +async def async_cursor_manager(cursor): + yield cursor + + +@pytest.fixture +def error_factory(): + def _factory(code): + return FakeDatabaseError(code) + + return _factory + + +@pytest.fixture +def fake_cursor(): + return FakeCursor() + + +@pytest.fixture +def fake_async_cursor(): + return FakeAsyncCursor() + + +@pytest.fixture +def fake_connection(): + return FakeConnection() + + +@pytest.fixture +def fake_async_connection(): + return FakeAsyncConnection() + + +@pytest.fixture +def fake_pool(): + return FakePool() + + +@pytest.fixture +def fake_async_pool(): + return FakeAsyncPool() + + +@pytest.fixture +def fake_lob(): + return FakeLOB("lob-value") + + +@pytest.fixture +def fake_async_lob(): + return FakeAsyncLOB("async-lob-value") diff --git a/tests/unit/test_4000_abc.py b/tests/unit/test_4000_abc.py new file mode 100644 index 0000000..d7fdf1c --- /dev/null +++ b/tests/unit/test_4000_abc.py @@ -0,0 +1,81 @@ +# ----------------------------------------------------------------------------- +# Copyright (c) 2026, Oracle and/or its affiliates. +# +# Licensed under the Universal Permissive License v 1.0 as shown at +# http://oss.oracle.com/licenses/upl. 
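+#
+# ScalarData and JsonData below are minimal local dataclasses used purely to
+# exercise the SelectAIDataClass behaviours under test: scalar type coercion
+# in __post_init__, JSON-encoded fields, dict()/json() serialization, and
+# item-style attribute access.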
+# ----------------------------------------------------------------------------- + +import json +from dataclasses import dataclass +from typing import List, Mapping, Optional + +import pytest + +from select_ai._abc import SelectAIDataClass, _bool + +pytestmark = pytest.mark.unit + + +@dataclass +class ScalarData(SelectAIDataClass): + count: Optional[int] = None + name: Optional[str] = None + enabled: Optional[bool] = None + ratio: Optional[float] = None + + +@dataclass +class JsonData(SelectAIDataClass): + payload: Optional[List[Mapping]] = None + + +def test_4000_bool_accepts_supported_values(): + assert _bool(True) is True + assert _bool(0) is False + assert _bool("yes") is True + assert _bool("0") is False + + +def test_4001_bool_rejects_invalid_value(): + with pytest.raises(ValueError): + _bool("maybe") + + +def test_4002_dict_excludes_null_values_by_default(): + data = ScalarData(count=1, name=None) + assert data.dict() == {"count": 1} + + +def test_4003_dict_can_include_null_values(): + data = ScalarData(count=1, name=None) + assert data.dict(exclude_null=False) == { + "count": 1, + "name": None, + "enabled": None, + "ratio": None, + } + + +def test_4004_json_serializes_dictionary_payload(): + data = ScalarData(count=1, name="value") + assert json.loads(data.json()) == {"count": 1, "name": "value"} + + +def test_4005_item_access_reads_and_writes_attributes(): + data = ScalarData(count=1) + assert data["count"] == 1 + data["name"] = "updated" + assert data.name == "updated" + + +def test_4006_post_init_coerces_scalar_types(): + data = ScalarData(count="7", name=10, enabled="true", ratio="2.5") + assert data.count == 7 + assert data.name == "10" + assert data.enabled is True + assert data.ratio == 2.5 + + +def test_4007_post_init_decodes_json_fields(): + data = JsonData(payload='[{"owner": "SH", "name": "EMP"}]') + assert data.payload == [{"owner": "SH", "name": "EMP"}] diff --git a/tests/unit/test_4100_provider.py b/tests/unit/test_4100_provider.py new file mode 100644 index 0000000..bf77b71 --- /dev/null +++ b/tests/unit/test_4100_provider.py @@ -0,0 +1,88 @@ +# ----------------------------------------------------------------------------- +# Copyright (c) 2026, Oracle and/or its affiliates. +# +# Licensed under the Universal Permissive License v 1.0 as shown at +# http://oss.oracle.com/licenses/upl. 
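+#
+# Provider.create() dispatches on the provider name constant and falls back
+# to the base Provider class for unknown names, e.g. (illustrative):
+#
+#   Provider.create(provider_name=OPENAI)    # -> OpenAIProvider
+#   Provider.create(provider_name="custom")  # -> Provider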
+# ----------------------------------------------------------------------------- + +import pytest + +from select_ai.provider import ( + ANTHROPIC, + AWS, + AZURE, + COHERE, + GOOGLE, + HUGGINGFACE, + OCI, + OPENAI, + AWSProvider, + AnthropicProvider, + AzureProvider, + CohereProvider, + GoogleProvider, + HuggingFaceProvider, + OCIGenAIProvider, + OpenAIProvider, + Provider, +) + +pytestmark = pytest.mark.unit + + +@pytest.mark.parametrize( + ("provider_name", "expected_type"), + [ + (OPENAI, OpenAIProvider), + (AZURE, AzureProvider), + (OCI, OCIGenAIProvider), + (COHERE, CohereProvider), + (GOOGLE, GoogleProvider), + (HUGGINGFACE, HuggingFaceProvider), + (AWS, AWSProvider), + (ANTHROPIC, AnthropicProvider), + ], +) +def test_4100_create_returns_expected_provider_subclass( + provider_name, expected_type +): + provider = Provider.create(provider_name=provider_name) + assert isinstance(provider, expected_type) + + +def test_4101_create_falls_back_to_base_provider_for_unknown_name(): + provider = Provider.create(provider_name="custom") + assert type(provider) is Provider + + +def test_4102_key_alias_maps_provider_fields(): + assert Provider.key_alias("provider") == "provider_name" + assert Provider.key_alias("provider_name") == "provider" + assert Provider.key_alias("model") == "model" + + +def test_4103_keys_contains_provider_specific_fields(): + keys = Provider.keys() + assert "provider" in keys + assert "provider_endpoint" in keys + assert "azure_resource_name" in keys + assert "oci_compartment_id" in keys + assert "aws_apiformat" in keys + + +def test_4104_azure_provider_sets_endpoint_from_resource_name(): + provider = AzureProvider(azure_resource_name="demo-resource") + assert provider.provider_endpoint == "demo-resource.openai.azure.com" + + +def test_4105_aws_provider_sets_endpoint_from_region(): + provider = AWSProvider(region="us-phoenix-1") + assert provider.provider_endpoint == "bedrock-runtime.us-phoenix-1.amazonaws.com" + + +def test_4106_default_provider_endpoints_are_exposed(): + assert OpenAIProvider().provider_endpoint == "api.openai.com" + assert CohereProvider.provider_endpoint == "api.cohere.ai" + assert GoogleProvider.provider_endpoint == "generativelanguage.googleapis.com" + assert HuggingFaceProvider.provider_endpoint == "api-inference.huggingface.co" + assert AnthropicProvider.provider_endpoint == "api.anthropic.com" diff --git a/tests/unit/test_4200_db.py b/tests/unit/test_4200_db.py new file mode 100644 index 0000000..b9d898a --- /dev/null +++ b/tests/unit/test_4200_db.py @@ -0,0 +1,180 @@ +# ----------------------------------------------------------------------------- +# Copyright (c) 2026, Oracle and/or its affiliates. +# +# Licensed under the Universal Permissive License v 1.0 as shown at +# http://oss.oracle.com/licenses/upl. 
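+#
+# select_ai.db keeps standalone connections keyed by (pid, thread id) and
+# pools keyed by pid; the reset_db_state fixture clears all four registries
+# around every test so no connection state leaks between cases.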
+# ----------------------------------------------------------------------------- + +import os +from threading import get_ident + +import pytest + +import select_ai.db as db +from select_ai.errors import DatabaseNotConnectedError + +pytestmark = pytest.mark.unit + + +@pytest.fixture(autouse=True) +def reset_db_state(): + db.__conn__.clear() + db.__async_conn__.clear() + db.__pool__.clear() + db.__async_pool__.clear() + yield + db.__conn__.clear() + db.__async_conn__.clear() + db.__pool__.clear() + db.__async_pool__.clear() + + +def test_4200_connect_stores_thread_local_connection(monkeypatch, fake_connection): + captured = {} + + def fake_connect(**kwargs): + captured.update(kwargs) + return fake_connection + + monkeypatch.setattr(db.oracledb, "connect", fake_connect) + db.connect(user="user", password="password", dsn="dsn") + assert captured["connection_id_prefix"] == "python-select-ai" + assert db.__conn__[(os.getpid(), get_ident())] is fake_connection + + +@pytest.mark.anyio +async def test_4201_async_connect_stores_thread_local_connection( + monkeypatch, fake_async_connection +): + captured = {} + + async def fake_connect_async(**kwargs): + captured.update(kwargs) + return fake_async_connection + + monkeypatch.setattr(db.oracledb, "connect_async", fake_connect_async) + await db.async_connect(user="user", password="password", dsn="dsn") + assert captured["connection_id_prefix"] == "async-python-select-ai" + assert db.__async_conn__[(os.getpid(), get_ident())] is fake_async_connection + + +def test_4202_create_pool_stores_process_pool(monkeypatch, fake_pool): + captured = {} + + def fake_create_pool(**kwargs): + captured.update(kwargs) + return fake_pool + + monkeypatch.setattr(db.oracledb, "create_pool", fake_create_pool) + db.create_pool(user="user", password="password", dsn="dsn") + assert captured["connection_id_prefix"] == "python-select-ai" + assert db.__pool__[os.getpid()] is fake_pool + + +def test_4203_create_pool_async_stores_process_pool(monkeypatch, fake_async_pool): + captured = {} + + def fake_create_pool_async(**kwargs): + captured.update(kwargs) + return fake_async_pool + + monkeypatch.setattr(db.oracledb, "create_pool_async", fake_create_pool_async) + db.create_pool_async(user="user", password="password", dsn="dsn") + assert captured["connection_id_prefix"] == "async-python-select-ai" + assert db.__async_pool__[os.getpid()] is fake_async_pool + + +def test_4204_connection_manager_rejects_standalone_and_pool_together( + fake_connection, fake_pool +): + db.__conn__[(os.getpid(), get_ident())] = fake_connection + db.__pool__[os.getpid()] = fake_pool + with pytest.raises(ValueError): + db.ConnectionManager() + + +def test_4205_connection_manager_yields_standalone_connection(fake_connection): + db.__conn__[(os.getpid(), get_ident())] = fake_connection + with db.ConnectionManager().get_connection() as connection: + assert connection is fake_connection + assert fake_connection.ping_count == 1 + + +def test_4206_connection_manager_yields_pool_connection(fake_pool): + db.__pool__[os.getpid()] = fake_pool + with db.ConnectionManager().get_connection() as connection: + assert connection is fake_pool.acquired_connection + assert fake_pool.released == [fake_pool.acquired_connection] + + +def test_4207_connection_manager_raises_when_not_connected(): + with pytest.raises(DatabaseNotConnectedError): + with db.ConnectionManager().get_connection(): + pass + + +def test_4208_is_connected_returns_true_for_healthy_connection(fake_connection): + db.__conn__[(os.getpid(), get_ident())] = 
fake_connection + assert db.is_connected() is True + + +def test_4209_is_connected_returns_false_for_unhealthy_connection( + monkeypatch, fake_connection +): + class FakeDbError(Exception): + pass + + monkeypatch.setattr(db.oracledb, "DatabaseError", FakeDbError) + monkeypatch.setattr(db.oracledb, "InterfaceError", FakeDbError) + fake_connection.ping_error = FakeDbError() + db.__conn__[(os.getpid(), get_ident())] = fake_connection + assert db.is_connected() is False + + +def test_4210_cursor_closes_cursor_after_use(fake_cursor, fake_connection): + fake_connection.cursor_factory = lambda: fake_cursor + db.__conn__[(os.getpid(), get_ident())] = fake_connection + with db.cursor() as cursor: + assert cursor is fake_cursor + assert fake_cursor.closed is True + + +def test_4211_disconnect_closes_standalone_connection(fake_connection): + db.__conn__[(os.getpid(), get_ident())] = fake_connection + db.disconnect() + assert fake_connection.closed is True + assert db.__conn__ == {} + + +def test_4212_disconnect_closes_pool(fake_pool): + db.__pool__[os.getpid()] = fake_pool + db.disconnect() + assert fake_pool.closed is True + assert fake_pool.close_force is True + assert db.__pool__ == {} + + +@pytest.mark.anyio +async def test_4213_async_cursor_closes_cursor_after_use( + fake_async_cursor, fake_async_connection +): + fake_async_connection.cursor_factory = lambda: fake_async_cursor + db.__async_conn__[(os.getpid(), get_ident())] = fake_async_connection + async with db.async_cursor() as cursor: + assert cursor is fake_async_cursor + assert fake_async_cursor.closed is True + + +@pytest.mark.anyio +async def test_4214_async_is_connected_returns_true(fake_async_connection): + db.__async_conn__[(os.getpid(), get_ident())] = fake_async_connection + assert await db.async_is_connected() is True + + +@pytest.mark.anyio +async def test_4215_async_disconnect_closes_pool(fake_async_pool): + db.__async_pool__[os.getpid()] = fake_async_pool + await db.async_disconnect() + assert fake_async_pool.closed is True + assert db.__async_pool__ == {} + diff --git a/tests/unit/test_4300_credential.py b/tests/unit/test_4300_credential.py new file mode 100644 index 0000000..5fedd56 --- /dev/null +++ b/tests/unit/test_4300_credential.py @@ -0,0 +1,121 @@ +# ----------------------------------------------------------------------------- +# Copyright (c) 2026, Oracle and/or its affiliates. +# +# Licensed under the Universal Permissive License v 1.0 as shown at +# http://oss.oracle.com/licenses/upl. 
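+#
+# These tests stub the module-level cursor() context managers, so no real
+# DBMS_CLOUD calls are made. ORA-20022 (duplicate credential) and ORA-20004
+# (missing credential) are the error codes the replace=True and force=True
+# paths are expected to swallow, so the fakes raise exactly those.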
+# ----------------------------------------------------------------------------- + +import contextlib +from unittest.mock import AsyncMock, Mock + +import pytest + +import select_ai.credential as credential + +pytestmark = pytest.mark.unit + + +@contextlib.contextmanager +def _cursor_manager(cursor): + yield cursor + + +@contextlib.asynccontextmanager +async def _async_cursor_manager(cursor): + yield cursor + + +def test_4300_validate_credential_accepts_supported_keys(): + credential._validate_credential( + { + "credential_name": "cred", + "username": "user", + "password": "secret", + "comments": "demo", + } + ) + + +def test_4301_validate_credential_rejects_invalid_keys(): + with pytest.raises(ValueError): + credential._validate_credential({"credential_name": "cred", "token": "x"}) + + +def test_4302_create_credential_calls_dbms_cloud_create(monkeypatch): + callproc = Mock() + fake_cursor = Mock(callproc=callproc) + monkeypatch.setattr( + credential, "cursor", lambda: _cursor_manager(fake_cursor) + ) + credential.create_credential({"credential_name": "cred", "username": "openai"}) + callproc.assert_called_once_with( + "DBMS_CLOUD.CREATE_CREDENTIAL", + keyword_parameters={"credential_name": "cred", "username": "openai"}, + ) + + +def test_4303_create_credential_replaces_on_duplicate(monkeypatch, error_factory): + callproc = Mock(side_effect=[error_factory(20022), None, None]) + fake_cursor = Mock(callproc=callproc) + monkeypatch.setattr( + credential, "cursor", lambda: _cursor_manager(fake_cursor) + ) + monkeypatch.setattr(credential.oracledb, "DatabaseError", type(error_factory(1))) + payload = {"credential_name": "cred", "username": "openai"} + credential.create_credential(payload, replace=True) + assert callproc.call_count == 3 + + +def test_4304_create_credential_reraises_unknown_errors(monkeypatch, error_factory): + fake_error = type(error_factory(1)) + callproc = Mock(side_effect=error_factory(20999)) + fake_cursor = Mock(callproc=callproc) + monkeypatch.setattr( + credential, "cursor", lambda: _cursor_manager(fake_cursor) + ) + monkeypatch.setattr(credential.oracledb, "DatabaseError", fake_error) + with pytest.raises(fake_error): + credential.create_credential({"credential_name": "cred"}) + + +def test_4305_delete_credential_ignores_missing_when_forced( + monkeypatch, error_factory +): + fake_error = type(error_factory(1)) + callproc = Mock(side_effect=error_factory(20004)) + fake_cursor = Mock(callproc=callproc) + monkeypatch.setattr( + credential, "cursor", lambda: _cursor_manager(fake_cursor) + ) + monkeypatch.setattr(credential.oracledb, "DatabaseError", fake_error) + credential.delete_credential("cred", force=True) + + +@pytest.mark.anyio +async def test_4306_async_create_credential_replaces_on_duplicate( + monkeypatch, error_factory +): + fake_error = type(error_factory(1)) + callproc = AsyncMock(side_effect=[error_factory(20022), None, None]) + fake_cursor = Mock(callproc=callproc) + monkeypatch.setattr( + credential, "async_cursor", lambda: _async_cursor_manager(fake_cursor) + ) + monkeypatch.setattr(credential.oracledb, "DatabaseError", fake_error) + payload = {"credential_name": "cred", "username": "openai"} + await credential.async_create_credential(payload, replace=True) + assert callproc.await_count == 3 + + +@pytest.mark.anyio +async def test_4307_async_delete_credential_ignores_missing_when_forced( + monkeypatch, error_factory +): + fake_error = type(error_factory(1)) + callproc = AsyncMock(side_effect=error_factory(20004)) + fake_cursor = Mock(callproc=callproc) + 
monkeypatch.setattr( + credential, "async_cursor", lambda: _async_cursor_manager(fake_cursor) + ) + monkeypatch.setattr(credential.oracledb, "DatabaseError", fake_error) + await credential.async_delete_credential("cred", force=True) diff --git a/tests/unit/test_4400_privilege.py b/tests/unit/test_4400_privilege.py new file mode 100644 index 0000000..aacd078 --- /dev/null +++ b/tests/unit/test_4400_privilege.py @@ -0,0 +1,86 @@ +# ----------------------------------------------------------------------------- +# Copyright (c) 2026, Oracle and/or its affiliates. +# +# Licensed under the Universal Permissive License v 1.0 as shown at +# http://oss.oracle.com/licenses/upl. +# ----------------------------------------------------------------------------- + +import contextlib +from unittest.mock import AsyncMock, Mock + +import pytest + +import select_ai.privilege as privilege + +pytestmark = pytest.mark.unit + + +@contextlib.contextmanager +def _cursor_manager(cursor): + yield cursor + + +@contextlib.asynccontextmanager +async def _async_cursor_manager(cursor): + yield cursor + + +def test_4400_grant_privileges_normalizes_single_user(monkeypatch): + fake_cursor = Mock(execute=Mock()) + monkeypatch.setattr(privilege, "cursor", lambda: _cursor_manager(fake_cursor)) + privilege.grant_privileges(" DEMO_USER ") + fake_cursor.execute.assert_called_once() + statement = fake_cursor.execute.call_args.args[0] + assert "DEMO_USER" in statement + + +def test_4401_revoke_privileges_accepts_multiple_users(monkeypatch): + fake_cursor = Mock(execute=Mock()) + monkeypatch.setattr(privilege, "cursor", lambda: _cursor_manager(fake_cursor)) + privilege.revoke_privileges(["USER_ONE", "USER_TWO"]) + assert fake_cursor.execute.call_count == 2 + + +def test_4402_grant_http_access_passes_host_parameter(monkeypatch): + fake_cursor = Mock(execute=Mock()) + monkeypatch.setattr(privilege, "cursor", lambda: _cursor_manager(fake_cursor)) + privilege.grant_http_access(["USER_ONE", "USER_TWO"], "api.openai.com") + assert fake_cursor.execute.call_count == 2 + _, kwargs = fake_cursor.execute.call_args + assert kwargs == {"user": "USER_TWO", "host": "api.openai.com"} + + +def test_4403_revoke_http_access_passes_host_parameter(monkeypatch): + fake_cursor = Mock(execute=Mock()) + monkeypatch.setattr(privilege, "cursor", lambda: _cursor_manager(fake_cursor)) + privilege.revoke_http_access("USER_ONE", "api.openai.com") + fake_cursor.execute.assert_called_once_with( + privilege.DISABLE_AI_PROFILE_DOMAIN_FOR_USER, + user="USER_ONE", + host="api.openai.com", + ) + + +@pytest.mark.anyio +async def test_4404_async_grant_privileges_normalizes_single_user(monkeypatch): + fake_cursor = Mock(execute=AsyncMock()) + monkeypatch.setattr( + privilege, "async_cursor", lambda: _async_cursor_manager(fake_cursor) + ) + await privilege.async_grant_privileges(" DEMO_USER ") + statement = fake_cursor.execute.call_args.args[0] + assert "DEMO_USER" in statement + + +@pytest.mark.anyio +async def test_4405_async_http_access_helpers_use_expected_parameters(monkeypatch): + fake_cursor = Mock(execute=AsyncMock()) + monkeypatch.setattr( + privilege, "async_cursor", lambda: _async_cursor_manager(fake_cursor) + ) + await privilege.async_grant_http_access("USER_ONE", "api.openai.com") + fake_cursor.execute.assert_awaited_once_with( + privilege.ENABLE_AI_PROFILE_DOMAIN_FOR_USER, + user="USER_ONE", + host="api.openai.com", + ) diff --git a/tests/unit/test_4500_base_profile.py b/tests/unit/test_4500_base_profile.py new file mode 100644 index 0000000..ecab03b --- 
/dev/null +++ b/tests/unit/test_4500_base_profile.py @@ -0,0 +1,199 @@ +# ----------------------------------------------------------------------------- +# Copyright (c) 2026, Oracle and/or its affiliates. +# +# Licensed under the Universal Permissive License v 1.0 as shown at +# http://oss.oracle.com/licenses/upl. +# ----------------------------------------------------------------------------- + +import datetime +import json + +import pytest + +import select_ai.base_profile as base_profile +from select_ai.action import Action +from select_ai.base_profile import ( + BaseProfile, + ProfileAttributes, + convert_json_rows_to_df, + no_data_for_prompt, + validate_params_for_feedback, + validate_params_for_summary, +) +from select_ai.feedback import FeedbackOperation, FeedbackType +from select_ai.provider import OpenAIProvider +from select_ai.summary import SummaryParams + +pytestmark = pytest.mark.unit + + +def test_4500_profile_attributes_reject_invalid_provider_type(): + with pytest.raises(ValueError): + ProfileAttributes(provider="invalid") + + +def test_4501_profile_attributes_json_flattens_provider_fields(): + attributes = ProfileAttributes( + credential_name="cred", + provider=OpenAIProvider(model="gpt-4o-mini"), + ) + payload = json.loads(attributes.json()) + assert payload["credential_name"] == "cred" + assert payload["provider"] == "openai" + assert payload["model"] == "gpt-4o-mini" + + +def test_4502_profile_attributes_create_builds_provider_subclass(): + attributes = ProfileAttributes.create( + provider="openai", + model="gpt-4o-mini", + credential_name="cred", + ) + assert isinstance(attributes.provider, OpenAIProvider) + assert attributes.provider.model == "gpt-4o-mini" + + +@pytest.mark.anyio +async def test_4503_profile_attributes_async_create_reads_async_lob( + monkeypatch, fake_async_lob +): + monkeypatch.setattr(base_profile.oracledb, "AsyncLOB", type(fake_async_lob)) + attributes = await ProfileAttributes.async_create( + provider="openai", + model=fake_async_lob, + credential_name="cred", + ) + assert attributes.provider.model == fake_async_lob.value + + +def test_4504_profile_attributes_create_reads_lob(monkeypatch, fake_lob): + monkeypatch.setattr(base_profile.oracledb, "LOB", type(fake_lob)) + attributes = ProfileAttributes.create( + provider="openai", + model=fake_lob, + credential_name="cred", + ) + assert attributes.provider.model == fake_lob.value + + +def test_4505_set_attribute_updates_provider_and_profile_fields(): + attributes = ProfileAttributes( + provider=OpenAIProvider(model="gpt-4o-mini"), + temperature=0.1, + ) + attributes.set_attribute("model", "gpt-4.1-mini") + attributes.set_attribute("temperature", 0.4) + assert attributes.provider.model == "gpt-4.1-mini" + assert attributes.temperature == 0.4 + + +def test_4506_raise_error_if_profile_exists_requires_replace_or_merge(): + profile = BaseProfile( + profile_name="demo", + attributes=ProfileAttributes(provider=OpenAIProvider()), + ) + with pytest.raises(base_profile.ProfileExistsError): + profile._raise_error_if_profile_exists() + + +def test_4507_merge_attributes_prefers_saved_values_when_missing(): + profile = BaseProfile(profile_name="demo") + saved_attributes = ProfileAttributes( + provider=OpenAIProvider(model="gpt-4o-mini"), + temperature=0.2, + ) + profile._merge_attributes(saved_attributes, "saved description") + assert profile.attributes == saved_attributes + assert profile.description == "saved description" + + +def test_4508_merge_attributes_merges_non_null_values_when_merge_enabled(): + profile = 
BaseProfile( + profile_name="demo", + attributes=ProfileAttributes(temperature=0.9), + merge=True, + ) + saved_attributes = ProfileAttributes( + provider=OpenAIProvider(model="gpt-4o-mini"), + temperature=0.2, + ) + profile._merge_attributes(saved_attributes, "saved description") + assert profile.replace is True + assert profile.attributes.temperature == 0.9 + assert profile.attributes.provider.model == "gpt-4o-mini" + + +def test_4509_no_data_for_prompt_handles_empty_responses(): + assert no_data_for_prompt(None) is True + assert no_data_for_prompt("No data found for the prompt.") is True + assert no_data_for_prompt("result") is False + + +def test_4510_validate_params_for_feedback_builds_prompt_payload(): + params = validate_params_for_feedback( + feedback_type=FeedbackType.NEGATIVE, + feedback_content="bad ordering", + prompt_spec=("show all people", Action.SHOWSQL), + response="SELECT * FROM people", + ) + assert params["feedback_type"] == "negative" + assert params["operation"] == "add" + assert params["sql_text"] == "select ai showsql show all people" + + +def test_4511_validate_params_for_feedback_requires_response_for_negative_add(): + with pytest.raises(AttributeError): + validate_params_for_feedback( + feedback_type=FeedbackType.NEGATIVE, + feedback_content="bad ordering", + prompt_spec=("show all people", Action.SHOWSQL), + ) + + +def test_4512_validate_params_for_feedback_rejects_invalid_action(): + with pytest.raises(AttributeError): + validate_params_for_feedback( + feedback_type=FeedbackType.POSITIVE, + feedback_content="good", + prompt_spec=("show all people", Action.CHAT), + ) + + +def test_4513_validate_params_for_feedback_accepts_sql_id_only(): + params = validate_params_for_feedback( + feedback_type=FeedbackType.POSITIVE, + feedback_content="great", + sql_id="abc123", + operation=FeedbackOperation.DELETE, + ) + assert params == { + "operation": "delete", + "feedback_content": "great", + "feedback_type": "positive", + "sql_id": "abc123", + } + + +def test_4514_validate_params_for_summary_accepts_content_or_location(): + params = validate_params_for_summary( + prompt="Summarize", + content="demo content", + params=SummaryParams(min_words=10), + ) + assert params["content"] == "demo content" + assert json.loads(params["parameters"]) == {"min_words": 10} + + +def test_4515_validate_params_for_summary_rejects_invalid_source_combinations(): + with pytest.raises(AttributeError): + validate_params_for_summary() + with pytest.raises(AttributeError): + validate_params_for_summary(content="x", location_uri="y") + + +def test_4516_convert_json_rows_to_df_handles_valid_and_invalid_payloads(): + frame = convert_json_rows_to_df('[{"name": "Alice"}]') + assert list(frame.columns) == ["name"] + with pytest.raises(base_profile.InvalidSQLError): + convert_json_rows_to_df("not-json") + diff --git a/tests/unit/test_4600_validations.py b/tests/unit/test_4600_validations.py new file mode 100644 index 0000000..ce72d86 --- /dev/null +++ b/tests/unit/test_4600_validations.py @@ -0,0 +1,62 @@ +# ----------------------------------------------------------------------------- +# Copyright (c) 2026, Oracle and/or its affiliates. +# +# Licensed under the Universal Permissive License v 1.0 as shown at +# http://oss.oracle.com/licenses/upl. 
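+#
+# enforce_types validates call arguments against type annotations at runtime
+# for both sync and async callables, e.g. (illustrative):
+#
+#   @enforce_types
+#   def run(prompt: str, count: int = 1): ...
+#
+#   run(42)  # raises TypeError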
+# ----------------------------------------------------------------------------- + +from typing import Any, Dict, Iterable, Literal, Optional, Sequence, Set, Tuple + +import pytest + +from select_ai._validations import _match, enforce_types + +pytestmark = pytest.mark.unit + + +def test_4600_match_handles_any_optional_and_literal(): + assert _match("value", Any) is True + assert _match(None, Optional[int]) is True + assert _match("small", Literal["small", "large"]) is True + assert _match("medium", Literal["small", "large"]) is False + + +def test_4601_match_handles_fixed_and_variadic_tuples(): + assert _match((1, "x"), Tuple[int, str]) is True + assert _match((1, 2, 3), Tuple[int, ...]) is True + assert _match((1, "x"), Tuple[int, int]) is False + + +def test_4602_match_handles_mappings_sequences_and_sets(): + assert _match({"a": 1}, Dict[str, int]) is True + assert _match([1, 2, 3], Sequence[int]) is True + assert _match({1, 2, 3}, Set[int]) is True + assert _match("abc", Sequence[int]) is False + + +def test_4603_match_handles_plain_and_bare_container_types(): + assert _match(5, int) is True + assert _match([1, 2], Iterable[int]) is True + assert _match("five", int) is False + + +def test_4604_enforce_types_validates_sync_functions(): + @enforce_types + def fn(name: str, count: int = 1): + return f"{name}:{count}" + + assert fn("demo", 2) == "demo:2" + with pytest.raises(TypeError): + fn("demo", "two") + + +@pytest.mark.anyio +async def test_4605_enforce_types_validates_async_functions(): + @enforce_types + async def fn(items: Sequence[int]): + return sum(items) + + assert await fn([1, 2, 3]) == 6 + with pytest.raises(TypeError): + await fn(["1"]) + diff --git a/tests/unit/test_4700_errors.py b/tests/unit/test_4700_errors.py new file mode 100644 index 0000000..4a8e776 --- /dev/null +++ b/tests/unit/test_4700_errors.py @@ -0,0 +1,82 @@ +# ----------------------------------------------------------------------------- +# Copyright (c) 2026, Oracle and/or its affiliates. +# +# Licensed under the Universal Permissive License v 1.0 as shown at +# http://oss.oracle.com/licenses/upl. +# ----------------------------------------------------------------------------- + +import pytest + +from select_ai.errors import ( + AgentAttributesEmptyError, + AgentNotFoundError, + AgentTaskAttributesEmptyError, + AgentTaskNotFoundError, + AgentTeamAttributesEmptyError, + AgentTeamNotFoundError, + AgentToolAttributesEmptyError, + AgentToolNotFoundError, + ConversationNotFoundError, + DatabaseNotConnectedError, + InvalidSQLError, + ProfileAttributesEmptyError, + ProfileExistsError, + ProfileNotFoundError, + VectorIndexNotFoundError, +) + +pytestmark = pytest.mark.unit + + +@pytest.mark.parametrize( + ("error", "expected"), + [ + ( + DatabaseNotConnectedError(), + "Not connected to the Database. Use select_ai.connect() or " + "select_ai.async_connect() to establish connection", + ), + ( + ConversationNotFoundError("conv-1"), + "Conversation with id conv-1 not found", + ), + (ProfileNotFoundError("demo"), "Profile demo not found"), + ( + ProfileExistsError("demo"), + "Profile demo already exists. Use either replace=True or merge=True", + ), + ( + ProfileAttributesEmptyError("demo"), + "Profile demo attributes empty in the database. 
", + ), + (VectorIndexNotFoundError("idx"), "VectorIndex idx not found"), + ( + VectorIndexNotFoundError("idx", "demo"), + "VectorIndex idx not found for profile demo", + ), + (AgentNotFoundError("agent"), "Agent agent not found"), + ( + AgentAttributesEmptyError("agent"), + "Agent agent attributes empty in the database.", + ), + (AgentTaskNotFoundError("task"), "Agent Task task not found"), + ( + AgentTaskAttributesEmptyError("task"), + "Agent Task task attributes empty in the database.", + ), + (AgentToolNotFoundError("tool"), "Agent Tool tool not found"), + ( + AgentToolAttributesEmptyError("tool"), + "Agent tool tool attributes empty in the database.", + ), + (AgentTeamNotFoundError("team"), "Agent Team team not found"), + ( + AgentTeamAttributesEmptyError("team"), + "Agent team team attributes empty in the database.", + ), + (InvalidSQLError("bad sql"), "bad sql"), + ], +) +def test_4700_error_messages_are_stable(error, expected): + assert str(error) == expected + From c4eea7646b51597357164992c29638484c943f97 Mon Sep 17 00:00:00 2001 From: Kondra Nagabhavani Date: Fri, 27 Mar 2026 08:08:58 +0530 Subject: [PATCH 2/6] Added sync and async api tests for Agent feature (#26) * Added sync and async api tests for Agent feature * Addressed review comments and enhanced tests --- tests/agents/test_3001_async_tools.py | 744 +++++++++++++++++++++++ tests/agents/test_3001_tools.py | 658 ++++++++++++++++++++ tests/agents/test_3101_async_tasks.py | 380 ++++++++++++ tests/agents/test_3101_tasks.py | 323 ++++++++++ tests/agents/test_3201_agents.py | 491 +++++++++++++++ tests/agents/test_3201_async_agents.py | 416 +++++++++++++ tests/agents/test_3301_async_teams.py | 393 ++++++++++++ tests/agents/test_3301_teams.py | 415 +++++++++++++ tests/agents/test_3800_agente2e.py | 436 +++++++++++++ tests/agents/test_3800_async_agente2e.py | 345 +++++++++++ 10 files changed, 4601 insertions(+) create mode 100644 tests/agents/test_3001_async_tools.py create mode 100644 tests/agents/test_3001_tools.py create mode 100644 tests/agents/test_3101_async_tasks.py create mode 100644 tests/agents/test_3101_tasks.py create mode 100644 tests/agents/test_3201_agents.py create mode 100644 tests/agents/test_3201_async_agents.py create mode 100644 tests/agents/test_3301_async_teams.py create mode 100644 tests/agents/test_3301_teams.py create mode 100644 tests/agents/test_3800_agente2e.py create mode 100644 tests/agents/test_3800_async_agente2e.py diff --git a/tests/agents/test_3001_async_tools.py b/tests/agents/test_3001_async_tools.py new file mode 100644 index 0000000..ed760a3 --- /dev/null +++ b/tests/agents/test_3001_async_tools.py @@ -0,0 +1,744 @@ +# ----------------------------------------------------------------------------- +# Copyright (c) 2025, Oracle and/or its affiliates. +# +# Licensed under the Universal Permissive License v 1.0 as shown at +# http://oss.oracle.com/licenses/upl. 
+# ----------------------------------------------------------------------------- + +""" +3001 - Async API coverage for select_ai.agent AsyncTool APIs +""" + +import logging +import os +import uuid + +import oracledb +import pytest +import select_ai +from select_ai.agent import AsyncTool +from select_ai.errors import AgentToolNotFoundError + +pytestmark = pytest.mark.anyio + +# Path +PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")) +LOG_FILE = os.path.join(PROJECT_ROOT, "log", "tkex_test_3001_async_tools.log") +os.makedirs(os.path.dirname(LOG_FILE), exist_ok=True) + +# Force logging to file (pytest-proof) +root = logging.getLogger() +root.setLevel(logging.INFO) +for handler in root.handlers[:]: + root.removeHandler(handler) +file_handler = logging.FileHandler(LOG_FILE, mode="w") +file_handler.setFormatter(logging.Formatter("%(levelname)s: %(message)s")) +root.addHandler(file_handler) +logger = logging.getLogger() + +UUID = uuid.uuid4().hex.upper() + +SQL_PROFILE_NAME = f"PYSAI_3001_SQL_PROFILE_{UUID}" +RAG_PROFILE_NAME = f"PYSAI_3001_RAG_PROFILE_{UUID}" + +SQL_TOOL_NAME = f"PYSAI_3001_SQL_TOOL_{UUID}" +RAG_TOOL_NAME = f"PYSAI_3001_RAG_TOOL_{UUID}" +PLSQL_TOOL_NAME = f"PYSAI_3001_PLSQL_TOOL_{UUID}" +WEB_SEARCH_TOOL_NAME = f"PYSAI_3001_WEB_TOOL_{UUID}" +PLSQL_FUNCTION_NAME = f"PYSAI_3001_CALC_AGE_{UUID}" +CUSTOM_ATTR_TOOL_NAME = f"PYSAI_3001_CUSTOM_ATTR_TOOL_{UUID}" +CUSTOM_ATTR_TOOL_DESCRIPTION = "Custom attr tool for async testing" +CUSTOM_NO_TYPE_TOOL_NAME = f"PYSAI_3001_CUSTOM_NO_TYPE_TOOL_{UUID}" +CUSTOM_WITH_TYPE_NO_INSTR_TOOL_NAME = ( + f"PYSAI_3001_CUSTOM_WITH_TYPE_NO_INSTR_TOOL_{UUID}" +) +CUSTOM_WITH_TYPE_AND_INSTR_TOOL_NAME = ( + f"PYSAI_3001_CUSTOM_WITH_TYPE_AND_INSTR_TOOL_{UUID}" +) +DISABLED_TOOL_NAME = f"PYSAI_3001_DISABLED_TOOL_{UUID}" +DEFAULT_STATUS_TOOL_NAME = f"PYSAI_3001_DEFAULT_STATUS_TOOL_{UUID}" +DROP_FORCE_MISSING_TOOL = f"PYSAI_3001_DROP_MISSING_{UUID}" + +EMAIL_TOOL_NAME = f"PYSAI_3001_EMAIL_TOOL_{UUID}" +SLACK_TOOL_NAME = f"PYSAI_3001_SLACK_TOOL_{UUID}" + +NEG_SQL_TOOL_NAME = f"PYSAI_3001_NEG_SQL_TOOL_{UUID}" +NEG_RAG_TOOL_NAME = f"PYSAI_3001_NEG_RAG_TOOL_{UUID}" +NEG_PLSQL_TOOL_NAME = f"PYSAI_3001_NEG_PLSQL_TOOL_{UUID}" + +EMAIL_CRED_NAME = f"PYSAI_3001_EMAIL_CRED_{UUID}" +SLACK_CRED_NAME = f"PYSAI_3001_SLACK_CRED_{UUID}" + +SMTP_USERNAME = os.getenv("PYSAI_TEST_EMAIL_CRED_USERNAME") +SMTP_PASSWORD = os.getenv("PYSAI_TEST_EMAIL_CRED_PASSWORD") +SLACK_USERNAME = os.getenv("PYSAI_TEST_SLACK_USERNAME") +SLACK_PASSWORD = os.getenv("PYSAI_TEST_SLACK_PASSWORD") + + +@pytest.fixture(autouse=True) +def log_test_name(request): + logger.info("--- Starting test: %s ---", request.function.__name__) + yield + logger.info("--- Finished test: %s ---", request.function.__name__) + + +@pytest.fixture(scope="module", autouse=True) +async def async_connect(test_env): + logger.info("Opening async database connection") + await select_ai.async_connect(**test_env.connect_params()) + yield + logger.info("Closing async database connection") + await select_ai.async_disconnect() + + +async def get_tool_status(tool_name): + logger.info("Fetching tool status for: %s", tool_name) + async with select_ai.async_cursor() as cur: + await cur.execute( + """ + SELECT status + FROM USER_AI_AGENT_TOOLS + WHERE tool_name = :tool_name + """, + {"tool_name": tool_name}, + ) + row = await cur.fetchone() + return row[0] if row else None + + +async def assert_tool_status(tool_name: str, expected_status: str) -> None: + status = await get_tool_status(tool_name) + logger.info( + 
"Verifying tool status | tool=%s | expected=%s | actual=%s", + tool_name, + expected_status, + status, + ) + assert status == expected_status + + +def log_tool_details(context: str, tool) -> None: + attrs = getattr(tool, "attributes", None) + tool_params = getattr(attrs, "tool_params", None) if attrs else None + + details = { + "context": context, + "tool_name": getattr(tool, "tool_name", None), + "description": getattr(tool, "description", None), + "tool_type": str(getattr(attrs, "tool_type", None)) if attrs else None, + "instruction": getattr(attrs, "instruction", None) if attrs else None, + "function": getattr(attrs, "function", None) if attrs else None, + "tool_inputs": getattr(attrs, "tool_inputs", None) if attrs else None, + "tool_params": tool_params.dict(exclude_null=False) + if tool_params is not None + else None, + } + + logger.info("TOOL_DETAILS: %s", details) + print("TOOL_DETAILS:", details) + + +@pytest.fixture(scope="module") +async def sql_profile(profile_attributes): + logger.info("Creating SQL profile: %s", SQL_PROFILE_NAME) + profile = await select_ai.AsyncProfile( + profile_name=SQL_PROFILE_NAME, + description="SQL Profile", + attributes=profile_attributes, + ) + yield profile + logger.info("Deleting SQL profile: %s", SQL_PROFILE_NAME) + await profile.delete(force=True) + + +@pytest.fixture(scope="module") +async def rag_profile(rag_profile_attributes): + logger.info("Creating RAG profile: %s", RAG_PROFILE_NAME) + profile = await select_ai.AsyncProfile( + profile_name=RAG_PROFILE_NAME, + description="RAG Profile", + attributes=rag_profile_attributes, + ) + yield profile + logger.info("Deleting RAG profile: %s", RAG_PROFILE_NAME) + await profile.delete(force=True) + + +@pytest.fixture(scope="module") +async def sql_tool(sql_profile): + logger.info("Creating SQL tool: %s", SQL_TOOL_NAME) + tool = await AsyncTool.create_sql_tool( + tool_name=SQL_TOOL_NAME, + profile_name=SQL_PROFILE_NAME, + description="SQL Tool", + replace=True, + ) + yield tool + logger.info("Deleting SQL tool: %s", SQL_TOOL_NAME) + await tool.delete(force=True) + + +@pytest.fixture(scope="module") +async def rag_tool(rag_profile): + logger.info("Creating RAG tool: %s", RAG_TOOL_NAME) + tool = await AsyncTool.create_rag_tool( + tool_name=RAG_TOOL_NAME, + profile_name=RAG_PROFILE_NAME, + description="RAG Tool", + replace=True, + ) + yield tool + logger.info("Deleting RAG tool: %s", RAG_TOOL_NAME) + await tool.delete(force=True) + + +@pytest.fixture(scope="module") +async def plsql_function(): + logger.info("Creating PL/SQL function: %s", PLSQL_FUNCTION_NAME) + ddl = f""" + CREATE OR REPLACE FUNCTION {PLSQL_FUNCTION_NAME}(p_birth_date DATE) + RETURN NUMBER IS + BEGIN + RETURN TRUNC(MONTHS_BETWEEN(SYSDATE, p_birth_date) / 12); + END; + """ + + async with select_ai.async_cursor() as cur: + await cur.execute(ddl) + + yield + + logger.info("Dropping PL/SQL function: %s", PLSQL_FUNCTION_NAME) + async with select_ai.async_cursor() as cur: + await cur.execute(f"DROP FUNCTION {PLSQL_FUNCTION_NAME}") + + +@pytest.fixture(scope="module") +async def plsql_tool(plsql_function): + logger.info("Creating PL/SQL tool: %s", PLSQL_TOOL_NAME) + tool = await AsyncTool.create_pl_sql_tool( + tool_name=PLSQL_TOOL_NAME, + function=PLSQL_FUNCTION_NAME, + description="PL/SQL Tool", + replace=True, + ) + yield tool + logger.info("Deleting PL/SQL tool: %s", PLSQL_TOOL_NAME) + await tool.delete(force=True) + + +@pytest.fixture(scope="module") +async def web_search_tool(): + logger.info("Creating Web Search tool: %s", 
WEB_SEARCH_TOOL_NAME)
+    tool = await AsyncTool.create_websearch_tool(
+        tool_name=WEB_SEARCH_TOOL_NAME,
+        description="Web Search Tool for testing",
+        credential_name="OPENAI_CRED",
+        replace=True,
+    )
+    yield tool
+    logger.info("Deleting Web Search tool: %s", WEB_SEARCH_TOOL_NAME)
+    await tool.delete(force=True)
+
+
+@pytest.fixture(scope="module")
+async def email_credential():
+    logger.info("Ensuring EMAIL credential is clean: %s", EMAIL_CRED_NAME)
+    credential = {
+        "credential_name": EMAIL_CRED_NAME,
+        "username": SMTP_USERNAME,
+        "password": SMTP_PASSWORD,
+    }
+
+    try:
+        await select_ai.async_delete_credential(EMAIL_CRED_NAME, force=True)
+    except Exception:
+        logger.info("EMAIL credential did not exist or could not be dropped")
+
+    await select_ai.async_create_credential(credential=credential, replace=True)
+    logger.info("Created EMAIL credential: %s", EMAIL_CRED_NAME)
+    yield EMAIL_CRED_NAME
+
+    logger.info("Deleting EMAIL credential: %s", EMAIL_CRED_NAME)
+    try:
+        await select_ai.async_delete_credential(EMAIL_CRED_NAME, force=True)
+    except Exception:
+        logger.warning("Failed to delete EMAIL credential during teardown")
+
+
+@pytest.fixture(scope="module")
+async def slack_credential():
+    logger.info("Ensuring SLACK credential is clean: %s", SLACK_CRED_NAME)
+    credential = {
+        "credential_name": SLACK_CRED_NAME,
+        "username": SLACK_USERNAME,
+        "password": SLACK_PASSWORD,
+    }
+
+    try:
+        await select_ai.async_delete_credential(SLACK_CRED_NAME, force=True)
+    except Exception:
+        logger.info("SLACK credential did not exist or could not be dropped")
+
+    await select_ai.async_create_credential(credential=credential, replace=True)
+    logger.info("Created SLACK credential: %s", SLACK_CRED_NAME)
+    yield SLACK_CRED_NAME
+
+    logger.info("Deleting SLACK credential: %s", SLACK_CRED_NAME)
+    try:
+        await select_ai.async_delete_credential(SLACK_CRED_NAME, force=True)
+    except Exception:
+        logger.warning("Failed to delete SLACK credential during teardown")
+
+
+@pytest.fixture(scope="module")
+async def email_tool(email_credential):
+    logger.info("Creating EMAIL tool: %s", EMAIL_TOOL_NAME)
+    tool = await AsyncTool.create_email_notification_tool(
+        tool_name=EMAIL_TOOL_NAME,
+        credential_name=EMAIL_CRED_NAME,
+        recipient="kondra.nagabhavani@oracle.com",
+        sender="bharadwaj.vulugundam@oracle.com",
+        smtp_host="smtp.email.us-ashburn-1.oci.oraclecloud.com",
+        description="Send email",
+        replace=True,
+    )
+    yield tool
+    logger.info("Deleting EMAIL tool: %s", EMAIL_TOOL_NAME)
+    await tool.delete(force=True)
+
+
+@pytest.fixture(scope="module")
+async def slack_tool(slack_credential):
+    logger.info("Creating SLACK tool: %s", SLACK_TOOL_NAME)
+    tool = None
+    try:
+        tool = await AsyncTool.create_slack_notification_tool(
+            tool_name=SLACK_TOOL_NAME,
+            credential_name=SLACK_CRED_NAME,
+            slack_channel="#general",
+            description="slack notification",
+            replace=True,
+        )
+        logger.info("SLACK tool created successfully: %s", SLACK_TOOL_NAME)
+        yield tool
+    except oracledb.DatabaseError as e:
+        if "ORA-20052" in str(e):
+            logger.info("Expected ORA-20052 during SLACK tool creation: %s", e)
+            yield None
+        else:
+            raise
+    finally:
+        if tool is not None:
+            logger.info("Deleting SLACK tool: %s", SLACK_TOOL_NAME)
+            await tool.delete(force=True)
+
+
+@pytest.fixture(scope="module")
+async def neg_sql_tool():
+    logger.info("Creating SQL tool with invalid profile: %s", NEG_SQL_TOOL_NAME)
+    tool = await AsyncTool.create_sql_tool(
+        tool_name=NEG_SQL_TOOL_NAME,
+        profile_name="NON_EXISTENT_PROFILE",
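+        # Intentionally references a missing profile: test_3014 below expects
+        # creation to succeed, i.e. profile references do not appear to be
+        # validated at tool-creation time.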
+ replace=True, + ) + yield tool + logger.info("Deleting SQL tool with invalid profile: %s", NEG_SQL_TOOL_NAME) + await tool.delete(force=True) + + +@pytest.fixture(scope="module") +async def neg_rag_tool(): + logger.info("Creating RAG tool with invalid profile: %s", NEG_RAG_TOOL_NAME) + tool = await AsyncTool.create_rag_tool( + tool_name=NEG_RAG_TOOL_NAME, + profile_name="NON_EXISTENT_RAG_PROFILE", + replace=True, + ) + yield tool + logger.info("Deleting RAG tool with invalid profile: %s", NEG_RAG_TOOL_NAME) + await tool.delete(force=True) + + +@pytest.fixture(scope="module") +async def neg_plsql_tool(): + logger.info( + "Creating PL/SQL tool with invalid function: %s", NEG_PLSQL_TOOL_NAME + ) + tool = await AsyncTool.create_pl_sql_tool( + tool_name=NEG_PLSQL_TOOL_NAME, + function="NON_EXISTENT_FUNCTION", + replace=True, + ) + yield tool + logger.info( + "Deleting PL/SQL tool with invalid function: %s", NEG_PLSQL_TOOL_NAME + ) + await tool.delete(force=True) + + +async def test_3000_sql_tool_created(sql_tool): + logger.info("Validating SQL tool creation: %s", SQL_TOOL_NAME) + log_tool_details("test_3000_sql_tool_created", sql_tool) + assert sql_tool.tool_name == SQL_TOOL_NAME + assert sql_tool.description == "SQL Tool" + assert sql_tool.attributes.tool_type == select_ai.agent.ToolType.SQL + assert sql_tool.attributes.tool_params is not None + assert sql_tool.attributes.tool_params.profile_name == SQL_PROFILE_NAME + + +async def test_3001_rag_tool_created(rag_tool): + logger.info("Validating RAG tool creation: %s", RAG_TOOL_NAME) + log_tool_details("test_3001_rag_tool_created", rag_tool) + assert rag_tool.tool_name == RAG_TOOL_NAME + assert rag_tool.description == "RAG Tool" + assert rag_tool.attributes.tool_type == select_ai.agent.ToolType.RAG + assert rag_tool.attributes.tool_params is not None + assert rag_tool.attributes.tool_params.profile_name == RAG_PROFILE_NAME + + +async def test_3002_plsql_tool_created(plsql_tool): + logger.info("Validating PL/SQL tool creation: %s", PLSQL_TOOL_NAME) + log_tool_details("test_3002_plsql_tool_created", plsql_tool) + assert plsql_tool.tool_name == PLSQL_TOOL_NAME + assert plsql_tool.description == "PL/SQL Tool" + assert plsql_tool.attributes.tool_type is None + assert plsql_tool.attributes.function == PLSQL_FUNCTION_NAME + + +async def test_3003_list_tools(): + logger.info("Listing all tools") + tools = [tool async for tool in AsyncTool.list()] + for tool in tools: + if tool.tool_name in {SQL_TOOL_NAME, RAG_TOOL_NAME, PLSQL_TOOL_NAME}: + log_tool_details("test_3003_list_tools", tool) + tool_names = {tool.tool_name for tool in tools} + logger.info("Tools present: %s", tool_names) + assert len(tools) >= 3 + assert SQL_TOOL_NAME in tool_names + assert RAG_TOOL_NAME in tool_names + assert PLSQL_TOOL_NAME in tool_names + + +async def test_3004_list_tools_regex(): + logger.info("Listing tools with regex: ^PYSAI_3001_") + tools = [ + tool async for tool in AsyncTool.list(tool_name_pattern="^PYSAI_3001_") + ] + for tool in tools: + log_tool_details("test_3004_list_tools_regex", tool) + tool_names = {tool.tool_name for tool in tools} + logger.info("Matched tools: %s", tool_names) + assert len(tools) >= 3 + assert SQL_TOOL_NAME in tool_names + assert RAG_TOOL_NAME in tool_names + assert PLSQL_TOOL_NAME in tool_names + + +async def test_3005_fetch_tool(): + logger.info("Fetching SQL tool: %s", SQL_TOOL_NAME) + tool = await AsyncTool.fetch(SQL_TOOL_NAME) + logger.info( + "Fetched SQL tool | name=%s | type=%s | profile=%s", + tool.tool_name, + 
tool.attributes.tool_type,
+        tool.attributes.tool_params.profile_name,
+    )
+    log_tool_details("test_3005_fetch_tool", tool)
+    assert tool.tool_name == SQL_TOOL_NAME
+    assert tool.attributes.tool_type == select_ai.agent.ToolType.SQL
+    assert tool.attributes.tool_params.profile_name == SQL_PROFILE_NAME
+
+
+async def test_3006_enable_disable_sql_tool(sql_tool):
+    logger.info("Disabling SQL tool: %s", sql_tool.tool_name)
+    await sql_tool.disable()
+    await assert_tool_status(sql_tool.tool_name, "DISABLED")
+
+    logger.info("Enabling SQL tool: %s", sql_tool.tool_name)
+    await sql_tool.enable()
+    await assert_tool_status(sql_tool.tool_name, "ENABLED")
+
+
+async def test_3007_web_search_tool_created(web_search_tool):
+    logger.info("Validating Web Search tool creation: %s", WEB_SEARCH_TOOL_NAME)
+    log_tool_details("test_3007_web_search_tool_created", web_search_tool)
+    assert web_search_tool.tool_name == WEB_SEARCH_TOOL_NAME
+    assert web_search_tool.attributes.tool_type == select_ai.agent.ToolType.WEBSEARCH
+    assert web_search_tool.attributes.tool_params.credential_name == "OPENAI_CRED"
+
+
+async def test_3008_email_tool_created(email_tool):
+    logger.info("Validating EMAIL tool creation: %s", EMAIL_TOOL_NAME)
+    log_tool_details("test_3008_email_tool_created", email_tool)
+    assert email_tool.tool_name == EMAIL_TOOL_NAME
+    assert str(email_tool.attributes.tool_type).upper() in ("EMAIL", "NOTIFICATION")
+    assert email_tool.attributes.tool_params.credential_name == EMAIL_CRED_NAME
+    assert email_tool.attributes.tool_params.smtp_host is not None
+    assert str(email_tool.attributes.tool_params.notification_type).lower() == "email"
+
+
+async def test_3009_slack_tool_created(slack_tool):
+    logger.info("Validating SLACK tool creation: %s", SLACK_TOOL_NAME)
+    if slack_tool is not None:
+        log_tool_details("test_3009_slack_tool_created", slack_tool)
+        assert slack_tool.tool_name == SLACK_TOOL_NAME
+        assert str(slack_tool.attributes.tool_type).upper() in ("SLACK", "NOTIFICATION")
+        assert slack_tool.attributes.tool_params.credential_name == SLACK_CRED_NAME
+        assert str(slack_tool.attributes.tool_params.notification_type).lower() == "slack"
+    else:
+        logger.info("SLACK tool not created due to expected backend-side error")
+
+
+async def test_3010_custom_tool_attributes_roundtrip():
+    logger.info(
+        "Validating custom tool attribute roundtrip: instruction/tool_inputs/description"
+    )
+    tool = AsyncTool(
+        tool_name=CUSTOM_ATTR_TOOL_NAME,
+        description=CUSTOM_ATTR_TOOL_DESCRIPTION,
+        attributes=select_ai.agent.ToolAttributes(
+            function=PLSQL_FUNCTION_NAME,
+            instruction="Return age in years for a birth date input",
+            tool_inputs=[
+                {
+                    "name": "p_birth_date",
+                    "description": "Input birth date in DATE format",
+                }
+            ],
+        ),
+    )
+    await tool.create(replace=True)
+    try:
+        fetched = await AsyncTool.fetch(CUSTOM_ATTR_TOOL_NAME)
+        log_tool_details("test_3010_custom_tool_attributes_roundtrip", fetched)
+        logger.info(
+            "Fetched custom tool | name=%s | description=%s | instruction=%s",
+            fetched.tool_name,
+            fetched.description,
+            fetched.attributes.instruction,
+        )
+        assert fetched.tool_name == CUSTOM_ATTR_TOOL_NAME
+        assert fetched.description == CUSTOM_ATTR_TOOL_DESCRIPTION
+        assert fetched.attributes.function == PLSQL_FUNCTION_NAME
+        assert (
+            fetched.attributes.instruction
+            == "Return age in years for a birth date input"
+        )
+        assert isinstance(fetched.attributes.tool_inputs, list)
+        assert fetched.attributes.tool_inputs[0]["name"] == "p_birth_date"
+        assert "birth date" in fetched.attributes.tool_inputs[0]["description"].lower()
+    finally:
+        await tool.delete(force=True)
+
+
+async def test_3011_custom_tool_without_tool_type():
+    logger.info("Validating custom tool creation with tool_type unset")
+    tool = AsyncTool(
+        tool_name=CUSTOM_NO_TYPE_TOOL_NAME,
+        description="Custom tool without tool_type",
+        attributes=select_ai.agent.ToolAttributes(
+            function=PLSQL_FUNCTION_NAME,
+            instruction="Calculate age from birth date",
+        ),
+    )
+    await tool.create(replace=True)
+    try:
+        fetched = await AsyncTool.fetch(CUSTOM_NO_TYPE_TOOL_NAME)
+        logger.info(
+            "Fetched custom tool | name=%s | type=%s | function=%s | instruction=%s",
+            fetched.tool_name,
+            fetched.attributes.tool_type,
+            fetched.attributes.function,
+            fetched.attributes.instruction,
+        )
+        log_tool_details("test_3011_custom_tool_without_tool_type", fetched)
+        assert fetched.tool_name == CUSTOM_NO_TYPE_TOOL_NAME
+        assert fetched.attributes.tool_type is None
+        assert fetched.attributes.function == PLSQL_FUNCTION_NAME
+        assert fetched.attributes.instruction == "Calculate age from birth date"
+    finally:
+        await tool.delete(force=True)
+
+
+async def test_3012_custom_tool_with_tool_type_without_instruction(sql_profile):
+    logger.info(
+        "Validating custom tool creation with tool_type set and instruction unset"
+    )
+    tool = AsyncTool(
+        tool_name=CUSTOM_WITH_TYPE_NO_INSTR_TOOL_NAME,
+        description="Custom tool with tool_type and no instruction",
+        attributes=select_ai.agent.ToolAttributes(
+            tool_type=select_ai.agent.ToolType.SQL,
+            tool_params=select_ai.agent.SQLToolParams(
+                profile_name=SQL_PROFILE_NAME
+            ),
+        ),
+    )
+    await tool.create(replace=True)
+    try:
+        fetched = await AsyncTool.fetch(CUSTOM_WITH_TYPE_NO_INSTR_TOOL_NAME)
+        log_tool_details(
+            "test_3012_custom_tool_with_tool_type_without_instruction", fetched
+        )
+        assert fetched.tool_name == CUSTOM_WITH_TYPE_NO_INSTR_TOOL_NAME
+        assert fetched.attributes.tool_type == select_ai.agent.ToolType.SQL
+        assert fetched.attributes.instruction is not None
+        assert "sql" in fetched.attributes.instruction.lower()
+        assert fetched.attributes.tool_params.profile_name == SQL_PROFILE_NAME
+    finally:
+        await tool.delete(force=True)
+
+
+async def test_3013_custom_tool_with_tool_type_and_instruction(sql_profile):
+    logger.info(
+        "Validating custom tool creation with tool_type and instruction set"
+    )
+    tool = AsyncTool(
+        tool_name=CUSTOM_WITH_TYPE_AND_INSTR_TOOL_NAME,
+        description="Custom tool with tool_type and instruction",
+        attributes=select_ai.agent.ToolAttributes(
+            tool_type=select_ai.agent.ToolType.SQL,
+            tool_params=select_ai.agent.SQLToolParams(
+                profile_name=SQL_PROFILE_NAME
+            ),
+            instruction="Use SQL profile to answer query from relational data",
+        ),
+    )
+    await tool.create(replace=True)
+    try:
+        fetched = await AsyncTool.fetch(CUSTOM_WITH_TYPE_AND_INSTR_TOOL_NAME)
+        log_tool_details(
+            "test_3013_custom_tool_with_tool_type_and_instruction", fetched
+        )
+        assert fetched.tool_name == CUSTOM_WITH_TYPE_AND_INSTR_TOOL_NAME
+        assert fetched.attributes.tool_type == select_ai.agent.ToolType.SQL
+        assert fetched.attributes.instruction is not None
+        assert "sql" in fetched.attributes.instruction.lower()
+        assert fetched.attributes.tool_params.profile_name == SQL_PROFILE_NAME
+    finally:
+        await tool.delete(force=True)
+
+
+async def test_3014_sql_tool_with_invalid_profile_created(neg_sql_tool):
+    logger.info("Validating SQL tool with invalid profile")
+    log_tool_details("test_3014_sql_tool_with_invalid_profile_created", neg_sql_tool)
+    assert neg_sql_tool.tool_name == NEG_SQL_TOOL_NAME
+    assert neg_sql_tool.attributes.tool_type == select_ai.agent.ToolType.SQL
+    assert neg_sql_tool.attributes.tool_params.profile_name == "NON_EXISTENT_PROFILE"
+
+
+async def test_3015_rag_tool_with_invalid_profile_created(neg_rag_tool):
+    logger.info("Validating RAG tool with invalid profile")
+    log_tool_details("test_3015_rag_tool_with_invalid_profile_created", neg_rag_tool)
+    assert neg_rag_tool.tool_name == NEG_RAG_TOOL_NAME
+    assert neg_rag_tool.attributes.tool_type == select_ai.agent.ToolType.RAG
+    assert (
+        neg_rag_tool.attributes.tool_params.profile_name
+        == "NON_EXISTENT_RAG_PROFILE"
+    )
+
+
+async def test_3016_plsql_tool_with_invalid_function_created(neg_plsql_tool):
+    logger.info("Validating PL/SQL tool with invalid function")
+    log_tool_details(
+        "test_3016_plsql_tool_with_invalid_function_created", neg_plsql_tool
+    )
+    assert neg_plsql_tool.tool_name == NEG_PLSQL_TOOL_NAME
+    assert neg_plsql_tool.attributes.function == "NON_EXISTENT_FUNCTION"
+
+
+async def test_3017_fetch_non_existent_tool():
+    logger.info("Fetching non-existent tool")
+    with pytest.raises(AgentToolNotFoundError) as exc:
+        await AsyncTool.fetch("TOOL_DOES_NOT_EXIST")
+    logger.info("Received expected error: %s", exc.value)
+
+
+async def test_3018_list_invalid_regex():
+    logger.info("Listing tools with invalid regex")
+    with pytest.raises(Exception) as exc:
+        async for _ in AsyncTool.list(tool_name_pattern="*["):
+            pass
+    logger.info("Received expected regex error: %s", exc.value)
+
+
+async def test_3019_list_tools():
+    logger.info("Listing all tools")
+    tools = [tool async for tool in AsyncTool.list()]
+    for tool in tools:
+        if tool.tool_name in {SQL_TOOL_NAME, RAG_TOOL_NAME, PLSQL_TOOL_NAME}:
+            log_tool_details("test_3019_list_tools", tool)
+    tool_names = {tool.tool_name for tool in tools}
+    logger.info("Tools present: %s", tool_names)
+    assert len(tools) >= 3
+    assert SQL_TOOL_NAME in tool_names
+    assert RAG_TOOL_NAME in tool_names
+    assert PLSQL_TOOL_NAME in tool_names
+
+
+async def test_3020_create_tool_default_status_enabled(sql_profile):
+    logger.info("Creating tool to validate default ENABLED status")
+    tool = await AsyncTool.create_built_in_tool(
+        tool_name=DEFAULT_STATUS_TOOL_NAME,
+        tool_type=select_ai.agent.ToolType.SQL,
+        tool_params=select_ai.agent.SQLToolParams(profile_name=SQL_PROFILE_NAME),
+    )
+    try:
+        await assert_tool_status(DEFAULT_STATUS_TOOL_NAME, "ENABLED")
+        fetched = await AsyncTool.fetch(DEFAULT_STATUS_TOOL_NAME)
+        log_tool_details("test_3020_create_tool_default_status_enabled", fetched)
+        logger.info(
+            "Fetched created tool | name=%s | type=%s | profile=%s",
+            fetched.tool_name,
+            fetched.attributes.tool_type,
+            fetched.attributes.tool_params.profile_name,
+        )
+        assert fetched.attributes.tool_type == select_ai.agent.ToolType.SQL
+        assert fetched.attributes.tool_params.profile_name == SQL_PROFILE_NAME
+    finally:
+        await tool.delete(force=True)
+
+
+async def test_3021_create_tool_with_enabled_false_sets_disabled(sql_profile):
+    logger.info("Creating tool with enabled=False to validate DISABLED status")
+    tool = AsyncTool(
+        tool_name=DISABLED_TOOL_NAME,
+        attributes=select_ai.agent.ToolAttributes(
+            tool_type=select_ai.agent.ToolType.SQL,
+            tool_params=select_ai.agent.SQLToolParams(
+                profile_name=SQL_PROFILE_NAME
+            ),
+        ),
+    )
+    await tool.create(enabled=False, replace=True)
+    try:
+        await assert_tool_status(DISABLED_TOOL_NAME, "DISABLED")
+        fetched = await AsyncTool.fetch(DISABLED_TOOL_NAME)
+        log_tool_details(
+            "test_3021_create_tool_with_enabled_false_sets_disabled", fetched
+        )
+        assert fetched.attributes.tool_type == select_ai.agent.ToolType.SQL
+    finally:
+        await tool.delete(force=True)
+
+
+async def test_3022_drop_tool_force_true_non_existent():
+    logger.info("Validating DROP_TOOL force=True for missing tool")
+    tool = AsyncTool(tool_name=DROP_FORCE_MISSING_TOOL)
+    await tool.delete(force=True)
+    status = await get_tool_status(DROP_FORCE_MISSING_TOOL)
+    logger.info("Status after force delete on missing tool: %s", status)
+    assert status is None
+
+
+async def test_3023_drop_tool_force_false_non_existent_raises():
+    logger.info("Validating DROP_TOOL force=False for missing tool raises")
+    tool = AsyncTool(tool_name=DROP_FORCE_MISSING_TOOL)
+    with pytest.raises(oracledb.Error) as exc:
+        await tool.delete(force=False)
+    logger.info("Received expected drop error: %s", exc.value)
diff --git a/tests/agents/test_3001_tools.py b/tests/agents/test_3001_tools.py
new file mode 100644
index 0000000..106ee64
--- /dev/null
+++ b/tests/agents/test_3001_tools.py
@@ -0,0 +1,658 @@
+# -----------------------------------------------------------------------------
+# Copyright (c) 2025, Oracle and/or its affiliates.
+#
+# Licensed under the Universal Permissive License v 1.0 as shown at
+# http://oss.oracle.com/licenses/upl.
+# -----------------------------------------------------------------------------
+
+"""
+3001 - Complete and backend-aligned test coverage for select_ai.agent Tool APIs
+(with logging for behavior visibility)
+"""
+
+import logging
+import os
+import uuid
+
+import oracledb
+import pytest
+import select_ai
+from select_ai.agent import Tool
+from select_ai.errors import AgentToolNotFoundError
+
+# Path
+PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))
+LOG_FILE = os.path.join(PROJECT_ROOT, "log", "tkex_test_3001_tools.log")
+os.makedirs(os.path.dirname(LOG_FILE), exist_ok=True)
+
+# Force logging to file (pytest-proof)
+root = logging.getLogger()
+root.setLevel(logging.INFO)
+
+for h in root.handlers[:]:
+    root.removeHandler(h)
+
+fh = logging.FileHandler(LOG_FILE, mode="w")
+fh.setFormatter(logging.Formatter("%(levelname)s: %(message)s"))
+root.addHandler(fh)
+
+logger = logging.getLogger()
+
+# -----------------------------------------------------------------------------
+# Per-test logging
+# -----------------------------------------------------------------------------
+
+@pytest.fixture(autouse=True)
+def log_test_name(request):
+    logger.info(f"--- Starting test: {request.function.__name__} ---")
+    yield
+    logger.info(f"--- Finished test: {request.function.__name__} ---")
+
+
+# -----------------------------------------------------------------------------
+# Helper Functions
+# -----------------------------------------------------------------------------
+
+def get_tool_status(tool_name):
+    with select_ai.cursor() as cur:
+        cur.execute("""
+            SELECT status
+            FROM USER_AI_AGENT_TOOLS
+            WHERE tool_name = :tool_name
+        """, {"tool_name": tool_name})
+        row = cur.fetchone()
+        return row[0] if row else None
+
+# -----------------------------------------------------------------------------
+# Constants
+# -----------------------------------------------------------------------------
+UUID = uuid.uuid4().hex.upper()
+
+SQL_PROFILE_NAME = f"PYSAI_SQL_PROFILE_{UUID}"
+RAG_PROFILE_NAME = f"PYSAI_RAG_PROFILE_{UUID}"
+
+SQL_TOOL_NAME = f"PYSAI_SQL_TOOL_{UUID}"
+RAG_TOOL_NAME = f"PYSAI_RAG_TOOL_{UUID}"
+PLSQL_TOOL_NAME = f"PYSAI_PLSQL_TOOL_{UUID}"
+WEB_SEARCH_TOOL_NAME = 
f"PYSAI_WEB_TOOL_{UUID}" +PLSQL_FUNCTION_NAME = f"PYSAI_CALC_AGE_{UUID}" +CUSTOM_ATTR_TOOL_NAME = f"PYSAI_3001_CUSTOM_ATTR_TOOL_{UUID}" +CUSTOM_ATTR_TOOL_DESCRIPTION = "Custom attr tool for sync testing" +CUSTOM_NO_TYPE_TOOL_NAME = f"PYSAI_3001_CUSTOM_NO_TYPE_TOOL_{UUID}" +CUSTOM_WITH_TYPE_NO_INSTR_TOOL_NAME = ( + f"PYSAI_3001_CUSTOM_WITH_TYPE_NO_INSTR_TOOL_{UUID}" +) +CUSTOM_WITH_TYPE_AND_INSTR_TOOL_NAME = ( + f"PYSAI_3001_CUSTOM_WITH_TYPE_AND_INSTR_TOOL_{UUID}" +) +DISABLED_TOOL_NAME = f"PYSAI_3001_DISABLED_TOOL_{UUID}" +DEFAULT_STATUS_TOOL_NAME = f"PYSAI_3001_DEFAULT_STATUS_TOOL_{UUID}" +DROP_FORCE_MISSING_TOOL = f"PYSAI_3001_DROP_MISSING_{UUID}" +smtp_username = os.getenv("PYSAI_TEST_EMAIL_CRED_USERNAME") +smtp_password = os.getenv("PYSAI_TEST_EMAIL_CRED_PASSWORD") +slack_username = os.getenv("PYSAI_TEST_SLACK_USERNAME") +slack_password = os.getenv("PYSAI_TEST_SLACK_PASSWORD") + +@pytest.fixture(scope="module") +def email_credential(): + cred_name = "EMAIL_CRED" + logger.info("Ensuring EMAIL credential is clean: %s", cred_name) + + # Drop if exists (best-effort) + try: + select_ai.delete_credential(cred_name) + logger.info("Dropped existing EMAIL credential: %s", cred_name) + except Exception as e: + logger.info("EMAIL credential did not exist or could not be dropped: %s", e) + + # Create fresh credential + credential = { + "credential_name": cred_name, + "username": smtp_username, + "password": smtp_password, + } + + select_ai.create_credential( + credential=credential, + replace=True + ) + logger.info("Created EMAIL credential: %s", cred_name) + + yield cred_name + + logger.info("Deleting EMAIL credential at teardown: %s", cred_name) + try: + select_ai.delete_credential(cred_name) + except Exception as e: + logger.warning("Failed to delete EMAIL credential during teardown: %s", e) + +@pytest.fixture(scope="module") +def slack_credential(): + cred_name = "SLACK_CRED" + logger.info("Ensuring SLACK credential is clean: %s", cred_name) + + # Drop if exists (best-effort) + try: + select_ai.delete_credential(cred_name) + logger.info("Dropped existing SLACK credential: %s", cred_name) + except Exception as e: + logger.info("SLACK credential did not exist or could not be dropped: %s", e) + + # Create fresh SLACK credential (backend-required fields) + credential = { + "credential_name": cred_name, + "username": slack_username, + "password": slack_password, + } + + select_ai.create_credential( + credential=credential, + replace=True + ) + logger.info("Created SLACK credential: %s", cred_name) + + yield cred_name + + logger.info("Deleting SLACK credential at teardown: %s", cred_name) + try: + select_ai.delete_credential(cred_name) + except Exception as e: + logger.warning("Failed to delete SLACK credential during teardown: %s", e) + + +# ----------------------------------------------------------------------------- +# Fixtures +# ----------------------------------------------------------------------------- + + +@pytest.fixture(scope="module") +def sql_profile(profile_attributes): + logger.info("Creating SQL profile: %s", SQL_PROFILE_NAME) + profile = select_ai.Profile( + profile_name=SQL_PROFILE_NAME, + description="SQL Profile", + attributes=profile_attributes, + ) + yield profile + logger.info("Deleting SQL profile") + profile.delete(force=True) + + +@pytest.fixture(scope="module") +def rag_profile(rag_profile_attributes): + logger.info("Creating RAG profile: %s", RAG_PROFILE_NAME) + profile = select_ai.Profile( + profile_name=RAG_PROFILE_NAME, + description="RAG Profile", + 
attributes=rag_profile_attributes,
+    )
+    yield profile
+    logger.info("Deleting RAG profile")
+    profile.delete(force=True)
+
+
+@pytest.fixture(scope="module")
+def sql_tool(sql_profile):
+    logger.info("Creating SQL tool: %s", SQL_TOOL_NAME)
+    tool = select_ai.agent.Tool.create_sql_tool(
+        tool_name=SQL_TOOL_NAME,
+        profile_name=SQL_PROFILE_NAME,
+        description="SQL Tool",
+        replace=True,
+    )
+    yield tool
+    logger.info("Deleting SQL tool")
+    tool.delete(force=True)
+
+
+@pytest.fixture(scope="module")
+def rag_tool(rag_profile):
+    logger.info("Creating RAG tool: %s", RAG_TOOL_NAME)
+    tool = select_ai.agent.Tool.create_rag_tool(
+        tool_name=RAG_TOOL_NAME,
+        profile_name=RAG_PROFILE_NAME,
+        description="RAG Tool",
+        replace=True,
+    )
+    yield tool
+    logger.info("Deleting RAG tool")
+    tool.delete(force=True)
+
+
+@pytest.fixture(scope="module")
+def plsql_function():
+    logger.info("Creating PL/SQL function: %s", PLSQL_FUNCTION_NAME)
+    ddl = f"""
+    CREATE OR REPLACE FUNCTION {PLSQL_FUNCTION_NAME}(p_birth_date DATE)
+    RETURN NUMBER IS
+    BEGIN
+        RETURN TRUNC(MONTHS_BETWEEN(SYSDATE, p_birth_date) / 12);
+    END;
+    """
+    with select_ai.cursor() as cur:
+        cur.execute(ddl)
+    yield
+    logger.info("Dropping PL/SQL function")
+    with select_ai.cursor() as cur:
+        cur.execute(f"DROP FUNCTION {PLSQL_FUNCTION_NAME}")
+
+
+@pytest.fixture(scope="module")
+def plsql_tool(plsql_function):
+    logger.info("Creating PL/SQL tool: %s", PLSQL_TOOL_NAME)
+    tool = select_ai.agent.Tool.create_pl_sql_tool(
+        tool_name=PLSQL_TOOL_NAME,
+        function=PLSQL_FUNCTION_NAME,
+        description="PL/SQL Tool",
+        replace=True,
+    )
+    yield tool
+    logger.info("Deleting PL/SQL tool")
+    tool.delete(force=True)
+
+@pytest.fixture(scope="module")
+def web_search_tool():
+    """Fixture for Web Search Tool positive case."""
+    logger.info("Creating Web Search tool: %s", WEB_SEARCH_TOOL_NAME)
+    tool = select_ai.agent.Tool.create_websearch_tool(
+        tool_name=WEB_SEARCH_TOOL_NAME,
+        description="Web Search Tool for testing",
+        credential_name="OPENAI_CRED",
+        replace=True,
+    )
+    logger.info("WEBSEARCH Tool created successfully: %s", WEB_SEARCH_TOOL_NAME)
+    yield tool
+    logger.info("Deleting Web Search tool")
+    tool.delete(force=True)
+
+@pytest.fixture(scope="module")
+def email_tool(email_credential):
+    logger.info("Creating EMAIL tool: EMAIL_TOOL")
+    tool = select_ai.agent.Tool.create_email_notification_tool(
+        tool_name="EMAIL_TOOL",
+        credential_name="EMAIL_CRED",
+        recipient="kondra.nagabhavani@oracle.com",
+        sender="bharadwaj.vulugundam@oracle.com",
+        smtp_host="smtp.email.us-ashburn-1.oci.oraclecloud.com",
+        description="Send email",
+        replace=True,
+    )
+    logger.info("EMAIL_TOOL created successfully")
+    yield tool
+    logger.info("Deleting EMAIL tool")
+    tool.delete(force=True)
+
+@pytest.fixture(scope="module")
+def slack_tool(slack_credential):
+    logger.info("Creating SLACK tool: SLACK_TOOL")
+    tool = None
+    try:
+        tool = select_ai.agent.Tool.create_slack_notification_tool(
+            tool_name="SLACK_TOOL",
+            credential_name="SLACK_CRED",
+            slack_channel="#general",
+            description="slack notification",
+            replace=True,
+        )
+        logger.info("SLACK_TOOL is created successfully")
+        yield tool
+    except oracledb.DatabaseError as e:
+        if "ORA-20052" in str(e):
+            logger.info("Expected error during tool creation: %s", e)
+            yield None  # expected, backend-reported failure: no tool was created
+        else:
+            raise
+    finally:
+        if tool is not None:
+            logger.info("Deleting SLACK tool")
+            tool.delete(force=True)
+
+@pytest.fixture(scope="module")
+def neg_sql_tool():
+    
logger.info("Creating SQL tool with INVALID profile: NEG_SQL_TOOL") + tool = select_ai.agent.Tool.create_sql_tool( + tool_name="NEG_SQL_TOOL", + profile_name="NON_EXISTENT_PROFILE", + replace=True, + ) + logger.info("NEG_SQL_TOOL is created successfully.") + yield tool + logger.info("Deleting NEG_SQL_TOOL") + tool.delete(force=True) + +@pytest.fixture(scope="module") +def neg_rag_tool(): + logger.info("Creating RAG tool with INVALID profile: NEG_RAG_TOOL") + tool = select_ai.agent.Tool.create_rag_tool( + tool_name="NEG_RAG_TOOL", + profile_name="NON_EXISTENT_RAG_PROFILE", + replace=True, + ) + logger.info("NEG_RAG_TOOL is created successfully") + yield tool + logger.info("Deleting NEG_RAG_TOOL") + tool.delete(force=True) + + +@pytest.fixture(scope="module") +def neg_plsql_tool(): + logger.info("Creating PL/SQL tool with INVALID function: NEG_PLSQL_TOOL") + tool = select_ai.agent.Tool.create_pl_sql_tool( + tool_name="NEG_PLSQL_TOOL", + function="NON_EXISTENT_FUNCTION", + replace=True, + ) + logger.info("NEG_PLSQL_TOOL is created successfully") + yield tool + logger.info("Deleting NEG_PLSQL_TOOL") + tool.delete(force=True) + +# ----------------------------------------------------------------------------- +# POSITIVE TESTS +# ----------------------------------------------------------------------------- + +def test_3000_sql_tool_created(sql_tool): + logger.info("Validating SQL tool creation") + logger.info("SQL Tool created successfully: %s", SQL_TOOL_NAME) + logger.info("SQL Profile created successfully: %s", SQL_PROFILE_NAME) + assert sql_tool.tool_name == SQL_TOOL_NAME + assert sql_tool.attributes.tool_params.profile_name == SQL_PROFILE_NAME + + +def test_3001_rag_tool_created(rag_tool): + logger.info("Validating RAG tool creation") + logger.info("RAG Tool created successfully: %s", RAG_TOOL_NAME) + logger.info("RAG Profile created successfully: %s", RAG_PROFILE_NAME) + assert rag_tool.tool_name == RAG_TOOL_NAME + assert rag_tool.attributes.tool_params.profile_name == RAG_PROFILE_NAME + + +def test_3002_plsql_tool_created(plsql_tool): + logger.info("Validating PL/SQL tool creation") + logger.info("PL/SQL Tool created successfully: %s", PLSQL_TOOL_NAME) + logger.info("PL/SQL function created successfully: %s", PLSQL_FUNCTION_NAME) + assert plsql_tool.tool_name == PLSQL_TOOL_NAME + assert plsql_tool.attributes.function == PLSQL_FUNCTION_NAME + + +def test_3003_list_tools(): + logger.info("Listing all tools") + tool_names = {t.tool_name for t in select_ai.agent.Tool.list()} + logger.info("Tools present: %s", tool_names) + + assert SQL_TOOL_NAME in tool_names + assert RAG_TOOL_NAME in tool_names + assert PLSQL_TOOL_NAME in tool_names + + +def test_3004_list_tools_regex(): + logger.info("Listing tools using regex ^PYSAI_") + tool_names = {t.tool_name for t in select_ai.agent.Tool.list("^PYSAI_")} + logger.info("Matched tools: %s", tool_names) + + assert SQL_TOOL_NAME in tool_names + assert RAG_TOOL_NAME in tool_names + assert PLSQL_TOOL_NAME in tool_names + + +def test_3005_fetch_tool(): + logger.info("Fetching SQL tool") + tool = select_ai.agent.Tool.fetch(SQL_TOOL_NAME) + assert tool.tool_name == SQL_TOOL_NAME + + +def test_3006_enable_disable_sql_tool(sql_tool): + logger.info("Disabling SQL tool: %s", sql_tool.tool_name) + sql_tool.disable() + + status = get_tool_status(sql_tool.tool_name) + logger.info( + "Tool status after disable | tool=%s | status=%s", + sql_tool.tool_name, + status, + ) + assert status == "DISABLED" + + logger.info("Enabling SQL tool: %s", sql_tool.tool_name) + 
sql_tool.enable() + + status = get_tool_status(sql_tool.tool_name) + logger.info( + "Tool status after enable | tool=%s | status=%s", + sql_tool.tool_name, + status, + ) + assert status == "ENABLED" + + +def test_3007_web_search_tool_created(web_search_tool): + logger.info("Validating Web Search tool creation") + assert web_search_tool.tool_name == WEB_SEARCH_TOOL_NAME + + +def test_3008_email_tool_created(email_tool): + logger.info("Validating EMAIL tool creation") + assert email_tool.tool_name == "EMAIL_TOOL" + + +def test_3009_slack_tool_created(slack_tool): + logger.info("Validating SLACK tool creation") + + # If the tool is None (because of expected ORA-20052 error), skip the assertion + if slack_tool is None: + logger.info("SLACK tool creation failed with expected error ORA-20052, but continuing test.") + else: + assert slack_tool.tool_name == "SLACK_TOOL" + +def test_3010_custom_tool_attributes_roundtrip(): + logger.info( + "Validating custom tool attribute roundtrip: instruction/tool_inputs/description" + ) + tool = Tool( + tool_name=CUSTOM_ATTR_TOOL_NAME, + description=CUSTOM_ATTR_TOOL_DESCRIPTION, + attributes=select_ai.agent.ToolAttributes( + function=PLSQL_FUNCTION_NAME, + instruction="Return age in years for a birth date input", + tool_inputs=[ + { + "name": "p_birth_date", + "description": "Input birth date in DATE format", + } + ], + ), + ) + tool.create(replace=True) + try: + fetched = select_ai.agent.Tool.fetch(CUSTOM_ATTR_TOOL_NAME) + logger.info( + "Fetched custom tool | name=%s | description=%s | instruction=%s", + fetched.tool_name, + fetched.description, + fetched.attributes.instruction, + ) + assert fetched.tool_name == CUSTOM_ATTR_TOOL_NAME + assert fetched.description == CUSTOM_ATTR_TOOL_DESCRIPTION + assert fetched.attributes.function == PLSQL_FUNCTION_NAME + assert ( + fetched.attributes.instruction + == "Return age in years for a birth date input" + ) + assert isinstance(fetched.attributes.tool_inputs, list) + assert fetched.attributes.tool_inputs[0]["name"] == "p_birth_date" + assert "birth date" in fetched.attributes.tool_inputs[0]["description"].lower() + finally: + tool.delete(force=True) + + +def test_3011_custom_tool_without_tool_type(): + logger.info("Validating custom tool creation with tool_type unset") + tool = Tool( + tool_name=CUSTOM_NO_TYPE_TOOL_NAME, + description="Custom tool without explicit tool_type", + attributes=select_ai.agent.ToolAttributes( + function=PLSQL_FUNCTION_NAME, + ), + ) + tool.create(replace=True) + try: + fetched = select_ai.agent.Tool.fetch(CUSTOM_NO_TYPE_TOOL_NAME) + assert fetched.tool_name == CUSTOM_NO_TYPE_TOOL_NAME + assert fetched.attributes.function == PLSQL_FUNCTION_NAME + assert fetched.attributes.tool_type is None + assert fetched.description == "Custom tool without explicit tool_type" + finally: + tool.delete(force=True) + + +def test_3012_custom_tool_with_tool_type_without_instruction(sql_profile): + logger.info("Validating custom tool with tool_type and no instruction") + tool = Tool( + tool_name=CUSTOM_WITH_TYPE_NO_INSTR_TOOL_NAME, + description="Custom tool with tool_type and no instruction", + attributes=select_ai.agent.ToolAttributes( + tool_type=select_ai.agent.ToolType.SQL, + tool_params=select_ai.agent.SQLToolParams( + profile_name=SQL_PROFILE_NAME + ), + ), + ) + tool.create(replace=True) + try: + fetched = select_ai.agent.Tool.fetch(CUSTOM_WITH_TYPE_NO_INSTR_TOOL_NAME) + logger.info( + "Fetched custom tool | name=%s | type=%s | instruction=%s | profile=%s", + fetched.tool_name, + 
fetched.attributes.tool_type, + fetched.attributes.instruction, + fetched.attributes.tool_params.profile_name, + ) + assert fetched.tool_name == CUSTOM_WITH_TYPE_NO_INSTR_TOOL_NAME + assert fetched.attributes.tool_type == select_ai.agent.ToolType.SQL + assert fetched.attributes.instruction is not None + assert "sql" in fetched.attributes.instruction.lower() + assert fetched.attributes.tool_params.profile_name == SQL_PROFILE_NAME + finally: + tool.delete(force=True) + + +def test_3013_custom_tool_with_tool_type_and_instruction(sql_profile): + logger.info("Validating custom tool with tool_type and instruction") + tool = Tool( + tool_name=CUSTOM_WITH_TYPE_AND_INSTR_TOOL_NAME, + description="Custom tool with tool_type and instruction", + attributes=select_ai.agent.ToolAttributes( + tool_type=select_ai.agent.ToolType.SQL, + tool_params=select_ai.agent.SQLToolParams( + profile_name=SQL_PROFILE_NAME + ), + instruction="Use SQL profile to answer query from relational data", + ), + ) + tool.create(replace=True) + try: + fetched = select_ai.agent.Tool.fetch(CUSTOM_WITH_TYPE_AND_INSTR_TOOL_NAME) + assert fetched.tool_name == CUSTOM_WITH_TYPE_AND_INSTR_TOOL_NAME + assert fetched.attributes.tool_type == select_ai.agent.ToolType.SQL + assert fetched.attributes.instruction is not None + assert "sql" in fetched.attributes.instruction.lower() + assert fetched.attributes.tool_params.profile_name == SQL_PROFILE_NAME + finally: + tool.delete(force=True) + + +def test_3014_sql_tool_with_invalid_profile_created(neg_sql_tool): + logger.info("Validating SQL tool with invalid profile is stored") + assert neg_sql_tool.tool_name == "NEG_SQL_TOOL" + assert neg_sql_tool.attributes.tool_params.profile_name == "NON_EXISTENT_PROFILE" + + +def test_3015_rag_tool_with_invalid_profile_created(neg_rag_tool): + logger.info("Validating RAG tool with invalid profile is stored") + assert neg_rag_tool.tool_name == "NEG_RAG_TOOL" + assert neg_rag_tool.attributes.tool_params.profile_name == "NON_EXISTENT_RAG_PROFILE" + + +def test_3016_plsql_tool_with_invalid_function_created(neg_plsql_tool): + logger.info("Validating PL/SQL tool with invalid function is stored") + assert neg_plsql_tool.tool_name == "NEG_PLSQL_TOOL" + assert neg_plsql_tool.attributes.function == "NON_EXISTENT_FUNCTION" + + +def test_3017_fetch_non_existent_tool(): + logger.info("Fetching non-existent tool") + with pytest.raises(AgentToolNotFoundError) as exc: + select_ai.agent.Tool.fetch("TOOL_DOES_NOT_EXIST") + logger.error("%s", exc.value) + + +def test_3018_list_invalid_regex(): + logger.info("Listing tools with invalid regex") + with pytest.raises(Exception) as exc: + list(select_ai.agent.Tool.list(tool_name_pattern="*[")) + logger.error("%s", exc.value) + + +def test_3019_list_tools(): + logger.info("Listing all tools") + tool_names = {t.tool_name for t in select_ai.agent.Tool.list()} + logger.info("Tools present: %s", tool_names) + + assert SQL_TOOL_NAME in tool_names + assert RAG_TOOL_NAME in tool_names + assert PLSQL_TOOL_NAME in tool_names + + +def test_3020_create_tool_default_status_enabled(sql_profile): + logger.info("Creating tool to validate default ENABLED status") + tool = select_ai.agent.Tool.create_built_in_tool( + tool_name=DEFAULT_STATUS_TOOL_NAME, + tool_type=select_ai.agent.ToolType.SQL, + tool_params=select_ai.agent.SQLToolParams(profile_name=SQL_PROFILE_NAME), + ) + try: + status = get_tool_status(DEFAULT_STATUS_TOOL_NAME) + logger.info("Tool status after create: %s", status) + assert status == "ENABLED" + fetched = 
select_ai.agent.Tool.fetch(DEFAULT_STATUS_TOOL_NAME) + assert fetched.attributes.tool_type == select_ai.agent.ToolType.SQL + assert fetched.attributes.tool_params.profile_name == SQL_PROFILE_NAME + finally: + tool.delete(force=True) + + +def test_3021_create_tool_with_enabled_false_sets_disabled(sql_profile): + logger.info("Creating tool with enabled=False to validate DISABLED status") + tool = Tool( + tool_name=DISABLED_TOOL_NAME, + attributes=select_ai.agent.ToolAttributes( + tool_type=select_ai.agent.ToolType.SQL, + tool_params=select_ai.agent.SQLToolParams( + profile_name=SQL_PROFILE_NAME + ), + ), + ) + tool.create(enabled=False, replace=True) + try: + status = get_tool_status(DISABLED_TOOL_NAME) + logger.info("Tool status after create(enabled=False): %s", status) + assert status == "DISABLED" + fetched = select_ai.agent.Tool.fetch(DISABLED_TOOL_NAME) + assert fetched.attributes.tool_type == select_ai.agent.ToolType.SQL + finally: + tool.delete(force=True) + + +def test_3022_drop_tool_force_true_non_existent(): + logger.info("Validating DROP_TOOL force=True for missing tool") + tool = Tool(tool_name=DROP_FORCE_MISSING_TOOL) + tool.delete(force=True) + status = get_tool_status(DROP_FORCE_MISSING_TOOL) + logger.info("Status after force delete on missing tool: %s", status) + assert status is None + + +def test_3023_drop_tool_force_false_non_existent_raises(): + logger.info("Validating DROP_TOOL force=False for missing tool raises") + tool = Tool(tool_name=DROP_FORCE_MISSING_TOOL) + with pytest.raises(oracledb.Error) as exc: + tool.delete(force=False) + logger.info("Received expected drop error: %s", exc.value) diff --git a/tests/agents/test_3101_async_tasks.py b/tests/agents/test_3101_async_tasks.py new file mode 100644 index 0000000..091e272 --- /dev/null +++ b/tests/agents/test_3101_async_tasks.py @@ -0,0 +1,380 @@ +# ----------------------------------------------------------------------------- +# Copyright (c) 2025, Oracle and/or its affiliates. +# +# Licensed under the Universal Permissive License v 1.0 as shown at +# http://oss.oracle.com/licenses/upl. 
+# ----------------------------------------------------------------------------- + +""" +3101 - Module for testing select_ai agent async tasks +""" + +import uuid +import logging +import os + +import oracledb +import pytest +import select_ai +from select_ai.agent import AsyncTask, TaskAttributes + +pytestmark = pytest.mark.anyio + +PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")) +LOG_FILE = os.path.join(PROJECT_ROOT, "log", "tkex_test_3100_async_tasks.log") +os.makedirs(os.path.dirname(LOG_FILE), exist_ok=True) + +root = logging.getLogger() +root.setLevel(logging.INFO) +for handler in root.handlers[:]: + root.removeHandler(handler) +file_handler = logging.FileHandler(LOG_FILE, mode="w") +file_handler.setFormatter(logging.Formatter("%(levelname)s: %(message)s")) +root.addHandler(file_handler) +logger = logging.getLogger() + +PYSAI_3100_TASK_NAME = f"PYSAI_3100_{uuid.uuid4().hex.upper()}" +PYSAI_3100_SQL_TASK_DESCRIPTION = "PYSAI_3100_SQL_TASK_DESCRIPTION" +PYSAI_3100_DISABLED_TASK_NAME = f"PYSAI_3100_DISABLED_{uuid.uuid4().hex.upper()}" +PYSAI_3100_DEFAULT_STATUS_TASK_NAME = ( + f"PYSAI_3100_DEFAULT_STATUS_{uuid.uuid4().hex.upper()}" +) +PYSAI_3100_PARENT_TASK_NAME = f"PYSAI_3100_PARENT_{uuid.uuid4().hex.upper()}" +PYSAI_3100_CHILD_TASK_NAME = f"PYSAI_3100_CHILD_{uuid.uuid4().hex.upper()}" +PYSAI_3100_DEFAULT_HUMAN_TASK_NAME = ( + f"PYSAI_3100_DEFAULT_HUMAN_{uuid.uuid4().hex.upper()}" +) +PYSAI_3100_MISSING_TASK_NAME = f"PYSAI_3100_MISSING_{uuid.uuid4().hex.upper()}" + + +@pytest.fixture(autouse=True) +def log_test_name(request): + logger.info("--- Starting test: %s ---", request.function.__name__) + yield + logger.info("--- Finished test: %s ---", request.function.__name__) + + +@pytest.fixture(scope="module", autouse=True) +async def async_connect(test_env): + logger.info("Opening async database connection") + await select_ai.async_connect(**test_env.connect_params()) + yield + logger.info("Closing async database connection") + await select_ai.async_disconnect() + + +async def get_task_status(task_name): + logger.info("Fetching task status for: %s", task_name) + async with select_ai.async_cursor() as cur: + await cur.execute( + """ + SELECT status + FROM USER_AI_AGENT_TASKS + WHERE task_name = :task_name + """, + {"task_name": task_name}, + ) + row = await cur.fetchone() + return row[0] if row else None + + +async def assert_task_status(task_name: str, expected_status: str) -> None: + status = await get_task_status(task_name) + logger.info( + "Verifying task status | task=%s | expected=%s | actual=%s", + task_name, + expected_status, + status, + ) + assert status == expected_status + + +def log_task_details(context: str, task) -> None: + attrs = getattr(task, "attributes", None) + details = { + "context": context, + "task_name": getattr(task, "task_name", None), + "description": getattr(task, "description", None), + "instruction": getattr(attrs, "instruction", None) if attrs else None, + "tools": getattr(attrs, "tools", None) if attrs else None, + "input": getattr(attrs, "input", None) if attrs else None, + "enable_human_tool": ( + getattr(attrs, "enable_human_tool", None) if attrs else None + ), + } + logger.info("TASK_DETAILS: %s", details) + print("TASK_DETAILS:", details) + + +@pytest.fixture(scope="module") +def task_attributes(): + return TaskAttributes( + instruction="Help the user with their request about movies. " + "User question: {query}. 
" + "You can use SQL tool to search the data from database", + tools=["MOVIE_SQL_TOOL"], + enable_human_tool=False, + ) + + +@pytest.fixture(scope="module") +async def task(task_attributes): + task = AsyncTask( + task_name=PYSAI_3100_TASK_NAME, + description=PYSAI_3100_SQL_TASK_DESCRIPTION, + attributes=task_attributes, + ) + await task.create() + yield task + await task.delete(force=True) + + +async def test_3100(task, task_attributes): + """simple task creation""" + log_task_details("test_3100", task) + assert task.task_name == PYSAI_3100_TASK_NAME + assert task.attributes == task_attributes + assert task.description == PYSAI_3100_SQL_TASK_DESCRIPTION + assert task.attributes.instruction is not None + assert "{query}" in task.attributes.instruction + assert task.attributes.tools == ["MOVIE_SQL_TOOL"] + assert task.attributes.enable_human_tool is False + + +@pytest.mark.parametrize("task_name_pattern", [None, "^PYSAI_3100_"]) +async def test_3101(task_name_pattern): + """task list""" + if task_name_pattern: + tasks = [task async for task in select_ai.agent.AsyncTask.list(task_name_pattern)] + else: + tasks = [task async for task in select_ai.agent.AsyncTask.list()] + for task in tasks: + if task.task_name == PYSAI_3100_TASK_NAME: + log_task_details("test_3101", task) + task_names = set(task.task_name for task in tasks) + task_descriptions = set(task.description for task in tasks) + assert len(tasks) >= 1 + assert PYSAI_3100_TASK_NAME in task_names + assert PYSAI_3100_SQL_TASK_DESCRIPTION in task_descriptions + + +async def test_3102(task_attributes): + """task fetch""" + task = await select_ai.agent.AsyncTask.fetch(PYSAI_3100_TASK_NAME) + log_task_details("test_3102", task) + assert task.task_name == PYSAI_3100_TASK_NAME + assert task.attributes == task_attributes + assert task.description == PYSAI_3100_SQL_TASK_DESCRIPTION + assert task.attributes.tools == ["MOVIE_SQL_TOOL"] + assert task.attributes.input is None + assert task.attributes.enable_human_tool is False + + +async def test_3103_create_task_default_status_enabled(): + task = AsyncTask( + task_name=PYSAI_3100_DEFAULT_STATUS_TASK_NAME, + description="Default status should be enabled", + attributes=TaskAttributes( + instruction="Summarize user request: {query}", + tools=["MOVIE_SQL_TOOL"], + enable_human_tool=False, + ), + ) + await task.create(replace=True) + try: + await assert_task_status(PYSAI_3100_DEFAULT_STATUS_TASK_NAME, "ENABLED") + fetched = await AsyncTask.fetch(PYSAI_3100_DEFAULT_STATUS_TASK_NAME) + log_task_details("test_3103", fetched) + assert fetched.description == "Default status should be enabled" + assert fetched.attributes.enable_human_tool is False + finally: + await task.delete(force=True) + + +async def test_3104_create_task_with_enabled_false_sets_disabled(): + task = AsyncTask( + task_name=PYSAI_3100_DISABLED_TASK_NAME, + description="Task created disabled", + attributes=TaskAttributes( + instruction="Handle disabled task validation", + tools=["MOVIE_SQL_TOOL"], + enable_human_tool=False, + ), + ) + await task.create(enabled=False, replace=True) + try: + await assert_task_status(PYSAI_3100_DISABLED_TASK_NAME, "DISABLED") + fetched = await AsyncTask.fetch(PYSAI_3100_DISABLED_TASK_NAME) + log_task_details("test_3104", fetched) + assert fetched.description == "Task created disabled" + + logger.info("Enabling task created with enabled=False: %s", task.task_name) + await task.enable() + await assert_task_status(PYSAI_3100_DISABLED_TASK_NAME, "ENABLED") + finally: + await task.delete(force=True) + + +async def 
test_3105_disable_enable_task(task): + logger.info("Disabling task: %s", task.task_name) + await task.disable() + await assert_task_status(PYSAI_3100_TASK_NAME, "DISABLED") + + logger.info("Enabling task: %s", task.task_name) + await task.enable() + await assert_task_status(PYSAI_3100_TASK_NAME, "ENABLED") + + +async def test_3105b_set_single_attribute_invalid(task): + logger.info("Setting invalid single attribute for async task: %s", task.task_name) + with pytest.raises(oracledb.DatabaseError) as exc: + await task.set_attribute("description", "New Desc") + logger.info("Received expected Oracle error: %s", exc.value) + assert "ORA-20051" in str(exc.value) + + +async def test_3105c_duplicate_task_creation_fails(task): + logger.info("Creating duplicate async task without replace: %s", task.task_name) + dup = AsyncTask( + task_name=task.task_name, + description="Duplicate task", + attributes=task.attributes, + ) + with pytest.raises(oracledb.Error) as exc: + await dup.create(replace=False) + logger.info("Received expected duplicate create error: %s", exc.value) + assert "ORA-20051" in str(exc.value) + + +async def test_3105d_invalid_regex_pattern(): + logger.info("Listing async tasks with invalid regex") + with pytest.raises(oracledb.Error) as exc: + async for _ in AsyncTask.list("[INVALID_REGEX"): + pass + logger.info("Received expected invalid regex error: %s", exc.value) + assert "ORA-12726" in str(exc.value) + + +async def test_3106_drop_task_force_true_non_existent(): + logger.info("Dropping missing task with force=True: %s", PYSAI_3100_MISSING_TASK_NAME) + task = AsyncTask(task_name=PYSAI_3100_MISSING_TASK_NAME) + await task.delete(force=True) + status = await get_task_status(PYSAI_3100_MISSING_TASK_NAME) + logger.info("Status after force delete on missing task: %s", status) + assert status is None + + +async def test_3107_drop_task_force_false_non_existent_raises(): + logger.info("Dropping missing task with force=False: %s", PYSAI_3100_MISSING_TASK_NAME) + task = AsyncTask(task_name=PYSAI_3100_MISSING_TASK_NAME) + with pytest.raises(oracledb.Error) as exc: + await task.delete(force=False) + logger.info("Received expected drop error: %s", exc.value) + + +async def test_3108_create_task_with_input_attribute(): + logger.info("Creating parent/child tasks for input chaining validation") + parent_task = AsyncTask( + task_name=PYSAI_3100_PARENT_TASK_NAME, + description="Parent task", + attributes=TaskAttributes( + instruction="Generate an intermediate summary for: {query}", + tools=["MOVIE_SQL_TOOL"], + enable_human_tool=False, + ), + ) + child_task = AsyncTask( + task_name=PYSAI_3100_CHILD_TASK_NAME, + description="Child task with input dependency", + attributes=TaskAttributes( + instruction="Use upstream context and produce final answer", + tools=["MOVIE_SQL_TOOL"], + input=PYSAI_3100_PARENT_TASK_NAME, + enable_human_tool=False, + ), + ) + await parent_task.create(replace=True) + await child_task.create(replace=True) + try: + fetched = await AsyncTask.fetch(PYSAI_3100_CHILD_TASK_NAME) + log_task_details("test_3108_child", fetched) + assert fetched.attributes.input == PYSAI_3100_PARENT_TASK_NAME + assert fetched.attributes.tools == ["MOVIE_SQL_TOOL"] + assert fetched.description == "Child task with input dependency" + assert fetched.attributes.enable_human_tool is False + finally: + await child_task.delete(force=True) + await parent_task.delete(force=True) + + +async def test_3109_enable_human_tool_default_true(): + logger.info("Creating task to validate enable_human_tool default behavior") 
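+    # enable_human_tool is intentionally omitted from TaskAttributes here;
+    # the assertion below expects the backend default to be True.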
+ task = AsyncTask( + task_name=PYSAI_3100_DEFAULT_HUMAN_TASK_NAME, + description="Default enable_human_tool check", + attributes=TaskAttributes( + instruction="Collect more details from user for: {query}", + tools=["MOVIE_SQL_TOOL"], + ), + ) + await task.create(replace=True) + try: + fetched = await AsyncTask.fetch(PYSAI_3100_DEFAULT_HUMAN_TASK_NAME) + log_task_details("test_3109", fetched) + assert fetched.attributes.enable_human_tool is True + finally: + await task.delete(force=True) + + +async def test_3110_create_requires_task_name(): + logger.info("Validating create() requires task_name") + with pytest.raises(AttributeError) as exc: + await AsyncTask( + attributes=TaskAttributes( + instruction="Missing task_name validation", tools=[] + ) + ).create() + logger.info("Received expected error: %s", exc.value) + + +async def test_3111_create_requires_attributes(): + logger.info("Validating create() requires attributes") + with pytest.raises(AttributeError) as exc: + await AsyncTask( + task_name=f"PYSAI_3100_NO_ATTR_{uuid.uuid4().hex.upper()}" + ).create() + logger.info("Received expected error: %s", exc.value) + + +async def test_3112_enable_deleted_task_object_raises(): + logger.info("Creating task to validate object behavior after delete") + task_name = f"PYSAI_3100_DELETED_{uuid.uuid4().hex.upper()}" + attrs = TaskAttributes( + instruction="Validate task object after delete for: {query}", + tools=["MOVIE_SQL_TOOL"], + enable_human_tool=False, + ) + task = AsyncTask( + task_name=task_name, + description="Task deleted before reuse", + attributes=attrs, + ) + + await task.create(replace=True) + await assert_task_status(task_name, "ENABLED") + + await task.delete(force=True) + status = await get_task_status(task_name) + logger.info("Task status after delete: %s", status) + assert status is None + + logger.info("Verifying in-memory task object is still populated") + assert task.task_name == task_name + assert task.description == "Task deleted before reuse" + assert task.attributes == attrs + + logger.info("Attempting to enable deleted task using same object") + with pytest.raises(oracledb.DatabaseError) as exc: + await task.enable() + logger.info("Received expected error when enabling deleted task: %s", exc.value) + assert "ORA-20051" in str(exc.value) diff --git a/tests/agents/test_3101_tasks.py b/tests/agents/test_3101_tasks.py new file mode 100644 index 0000000..886a2d6 --- /dev/null +++ b/tests/agents/test_3101_tasks.py @@ -0,0 +1,323 @@ +# ----------------------------------------------------------------------------- +# Copyright (c) 2025, Oracle and/or its affiliates. +# +# Licensed under the Universal Permissive License v 1.0 as shown at +# http://oss.oracle.com/licenses/upl. 
+# -----------------------------------------------------------------------------
+
+"""
+3101 - Comprehensive tests for select_ai.agent.Task with error code asserts
+"""
+
+import logging
+import os
+import uuid
+
+import oracledb
+import pytest
+import select_ai
+from select_ai.agent import Task, TaskAttributes
+from select_ai.errors import AgentTaskNotFoundError
+
+# Path
+PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))
+LOG_FILE = os.path.join(PROJECT_ROOT, "log", "tkex_test_3101_tasks.log")
+os.makedirs(os.path.dirname(LOG_FILE), exist_ok=True)
+
+# Logging
+root = logging.getLogger()
+root.setLevel(logging.INFO)
+for h in root.handlers[:]:
+    root.removeHandler(h)
+fh = logging.FileHandler(LOG_FILE, mode="w")
+fh.setFormatter(logging.Formatter("%(levelname)s: %(message)s"))
+root.addHandler(fh)
+logger = logging.getLogger()
+
+# -----------------------------------------------------------------------------
+# Per-test logging
+# -----------------------------------------------------------------------------
+
+@pytest.fixture(autouse=True)
+def log_test_name(request):
+    logger.info(f"--- Starting test: {request.function.__name__} ---")
+    yield
+    logger.info(f"--- Finished test: {request.function.__name__} ---")
+
+
+# -----------------------------------------------------------------------------
+# Helper Functions
+# -----------------------------------------------------------------------------
+
+def get_task_status(task_name):
+    with select_ai.cursor() as cur:
+        cur.execute("""
+            SELECT status
+            FROM USER_AI_AGENT_TASKS
+            WHERE task_name = :task_name
+        """, {"task_name": task_name})
+        row = cur.fetchone()
+        return row[0] if row else None
+
+# -----------------------------------------------------------------------------
+# Constants
+# -----------------------------------------------------------------------------
+
+BASE = f"PYSAI_3101_{uuid.uuid4().hex.upper()}"
+TASK_A_NAME = f"{BASE}_TASK_A"
+TASK_B_NAME = f"{BASE}_TASK_B"
+
+# -----------------------------------------------------------------------------
+# Helpers
+# -----------------------------------------------------------------------------
+
+def expect_oracle_error(expected_code, fn):
+    """
+    Run fn and assert that the expected Oracle/Agent error occurs.
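+    fn: a zero-argument callable, typically a lambda wrapping the failing
+    call, e.g. lambda: dup.create(replace=False)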
+ expected_code: "ORA-xxxxx" or "NOT_FOUND" + """ + try: + fn() + except AgentTaskNotFoundError as e: + logger.info("Expected failure (NOT_FOUND): %s", e) + assert expected_code == "NOT_FOUND" + except oracledb.DatabaseError as e: + msg = str(e) + logger.info("Expected Oracle failure: %s", msg) + assert expected_code in msg, f"Expected {expected_code}, got {msg}" + else: + pytest.fail(f"Expected error {expected_code} did not occur") + + +# ----------------------------------------------------------------------------- +# Fixtures +# ----------------------------------------------------------------------------- + +@pytest.fixture(scope="module") +def task_a(): + logger.info("Creating TASK_A: %s", TASK_A_NAME) + attrs = TaskAttributes( + instruction="Analyze movie data for user query: {query}", + tools=["MOVIE_SQL_TOOL"], + enable_human_tool=False, + ) + task = Task(task_name=TASK_A_NAME, description="Primary analysis task", attributes=attrs) + task.create() + logger.info("TASK_A created successfully") + yield task + logger.info("Deleting TASK_A: %s", TASK_A_NAME) + task.delete(force=True) + expect_oracle_error("NOT_FOUND", lambda: Task.fetch(TASK_A_NAME)) + logger.info("TASK_A deleted successfully") + +@pytest.fixture(scope="module") +def task_b(task_a): + logger.info("Creating TASK_B: %s", TASK_B_NAME) + attrs = TaskAttributes( + instruction="Summarize insights from previous analysis", + input=TASK_A_NAME, + tools=None, + enable_human_tool=True, + ) + task = Task(task_name=TASK_B_NAME, description="Chained summarization task", attributes=attrs) + task.create() + logger.info("TASK_B created successfully") + yield task + logger.info("Deleting TASK_B: %s", TASK_B_NAME) + task.delete(force=True) + expect_oracle_error("NOT_FOUND", lambda: Task.fetch(TASK_B_NAME)) + logger.info("TASK_B deleted successfully") + +# ----------------------------------------------------------------------------- +# Positive Tests +# ----------------------------------------------------------------------------- + +def test_3100_task_creation(task_a): + logger.info("Verifying TASK_A creation") + logger.info("Task Name : %s", task_a.task_name) + logger.info("Task Description: %s", task_a.description) + logger.info("Task Attributes:") + logger.info(" enable_human_tool = %s", task_a.attributes.enable_human_tool) + logger.info(" tools = %s", task_a.attributes.tools) + assert task_a.task_name == TASK_A_NAME + assert task_a.description == "Primary analysis task" + assert task_a.attributes.enable_human_tool is False + assert task_a.attributes.tools == ["MOVIE_SQL_TOOL"] + +def test_3101_task_chaining(task_b): + logger.info("Verifying TASK_B chaining") + logger.info("TASK_B attributes:") + logger.info(" input = %s", task_b.attributes.input) + logger.info(" enable_human_tool = %s", task_b.attributes.enable_human_tool) + assert task_b.attributes.input == TASK_A_NAME + assert task_b.attributes.enable_human_tool is True + +def test_3102_fetch_task(task_a): + logger.info("Fetching TASK_A") + fetched = Task.fetch(TASK_A_NAME) + logger.info("Fetched task details:") + logger.info(" task_name = %s", fetched.task_name) + logger.info(" attributes = %s", fetched.attributes) + logger.info("Original task attributes:") + logger.info(" attributes = %s", task_a.attributes) + assert fetched.task_name == TASK_A_NAME + assert fetched.attributes == task_a.attributes + +def test_3103_list_tasks_with_regex(): + logger.info("Listing tasks with regex") + tasks = list(Task.list(f"{BASE}.*")) + names = sorted(t.task_name for t in tasks) + logger.info("Tasks 
returned (sorted):") + for name in names: + logger.info(" - %s", name) + assert TASK_A_NAME in names + assert TASK_B_NAME in names + + +def test_3104_disable_enable_task(task_b): + logger.info("Disabling TASK_B: %s", task_b.task_name) + task_b.disable() + + status = get_task_status(task_b.task_name) + logger.info("DB status after disable: %s", status) + assert status == "DISABLED" + + logger.info("Enabling TASK_B: %s", task_b.task_name) + task_b.enable() + + status = get_task_status(task_b.task_name) + logger.info("DB status after enable: %s", status) + assert status == "ENABLED" + +# ----------------------------------------------------------------------------- +# Negative / Edge Case Tests with Error Code Asserts +# ----------------------------------------------------------------------------- + +def test_3105_set_single_attribute_invalid(task_b): + logger.info("Setting invalid single attribute for TASK_B") + expect_oracle_error("ORA-20051", lambda: task_b.set_attribute("description", "New Desc")) + +def test_3110_fetch_non_existent_task(): + name = f"{BASE}_NO_SUCH_TASK" + logger.info("Fetching non-existent task: %s", name) + expect_oracle_error("NOT_FOUND", lambda: Task.fetch(name)) + +def test_3111_duplicate_task_creation_fails(task_a): + logger.info("Creating duplicate TASK_A without replace") + logger.info(" task_name = %s", task_a.task_name) + dup = Task( + task_name=task_a.task_name, + description="Duplicate task", + attributes=task_a.attributes, + ) + expect_oracle_error("ORA-20051", lambda: dup.create(replace=False)) + +def test_3113_set_invalid_attribute(task_a): + logger.info("Setting invalid attribute for TASK_A") + logger.info(" attribute = unknown_attribute") + expect_oracle_error("ORA-20051", lambda: task_a.set_attribute("unknown_attribute", "value")) + +def test_3114_invalid_regex_pattern(): + logger.info("Listing tasks with invalid regex") + expect_oracle_error("ORA-12726", lambda: list(Task.list("[INVALID_REGEX"))) + +def test_3115_delete_disabled_task_without_force(): + task_name = f"{BASE}_TEMP_DELETE" + logger.info("Creating and deleting disabled task: %s", task_name) + attrs = TaskAttributes(instruction="Temporary task", tools=None) + task = Task(task_name=task_name, description="Temp task", attributes=attrs) + task.create() + task.disable() + # DB verification: task is DISABLED + status = get_task_status(task_name) + logger.info("Task status before delete: %s", status) + assert status == "DISABLED" + task.delete(force=False) + # DB verification: task removed + status = get_task_status(task_name) + logger.info("Task status after delete: %s", status) + assert status is None + expect_oracle_error("NOT_FOUND", lambda: Task.fetch(task_name)) + + +def test_3116_missing_instruction(): + task_name = f"{BASE}_NO_INSTRUCTION" + logger.info("Creating task with missing instruction: %s", task_name) + attrs = TaskAttributes(instruction="", tools=None) + task = Task(task_name=task_name, attributes=attrs) + expect_oracle_error("ORA-20051", lambda: task.create()) + +def test_3117_delete_enabled_task_without_force_succeeds(): + task_name = f"{BASE}_FORCE_DELETE_TEST" + logger.info("Creating and deleting enabled task: %s", task_name) + attrs = TaskAttributes(instruction="Delete force test", tools=None) + task = Task(task_name=task_name, attributes=attrs) + task.create(enabled=True) + # DB verification: task is ENABLED + status = get_task_status(task_name) + logger.info("Task status before delete: %s", status) + assert status == "ENABLED" + task.delete(force=False) + # DB verification: 
task removed + status = get_task_status(task_name) + logger.info("Task status after delete: %s", status) + assert status is None + expect_oracle_error("NOT_FOUND", lambda: Task.fetch(task_name)) + + +def test_3118_delete_disabled_task_with_force_succeeds(): + task_name = f"{BASE}_DISABLED_CREATE" + logger.info("Deleting initially disabled task: %s", task_name) + attrs = TaskAttributes(instruction="Initially disabled task", tools=None) + task = Task(task_name=task_name, attributes=attrs) + task.create(enabled=False) + # DB verification: task is DISABLED + status = get_task_status(task_name) + logger.info("Task status before delete: %s", status) + assert status == "DISABLED" + task.delete(force=True) + # DB verification: task removed + status = get_task_status(task_name) + logger.info("Task status after delete: %s", status) + assert status is None + expect_oracle_error("NOT_FOUND", lambda: Task.fetch(task_name)) + + logger.info("Attempting operational use of deleted task object: %s", task_name) + expect_oracle_error("ORA-20051", lambda: task.enable()) + + +def test_3119_double_delete_force_true_succeeds(): + task_name = f"{BASE}_DOUBLE_DELETE_FORCE_TRUE" + logger.info("Creating task for double delete with force=True: %s", task_name) + attrs = TaskAttributes(instruction="Double delete force true", tools=None) + task = Task(task_name=task_name, attributes=attrs) + task.create(enabled=True) + + task.delete(force=True) + status = get_task_status(task_name) + logger.info("Task status after first delete: %s", status) + assert status is None + + logger.info("Deleting already deleted task with force=True: %s", task_name) + task.delete(force=True) + status = get_task_status(task_name) + logger.info("Task status after second delete with force=True: %s", status) + assert status is None + expect_oracle_error("NOT_FOUND", lambda: Task.fetch(task_name)) + + +def test_3120_double_delete_force_false_raises(): + task_name = f"{BASE}_DOUBLE_DELETE_FORCE_FALSE" + logger.info("Creating task for double delete with force=False: %s", task_name) + attrs = TaskAttributes(instruction="Double delete force false", tools=None) + task = Task(task_name=task_name, attributes=attrs) + task.create(enabled=True) + + task.delete(force=False) + status = get_task_status(task_name) + logger.info("Task status after first delete: %s", status) + assert status is None + + logger.info("Deleting already deleted task with force=False: %s", task_name) + with pytest.raises(oracledb.DatabaseError) as exc: + task.delete(force=False) + logger.info("Received expected Oracle error on second delete: %s", exc.value) + expect_oracle_error("NOT_FOUND", lambda: Task.fetch(task_name)) diff --git a/tests/agents/test_3201_agents.py b/tests/agents/test_3201_agents.py new file mode 100644 index 0000000..4573889 --- /dev/null +++ b/tests/agents/test_3201_agents.py @@ -0,0 +1,491 @@ +# ----------------------------------------------------------------------------- +# Copyright (c) 2025, Oracle and/or its affiliates. +# +# Licensed under the Universal Permissive License v 1.0 as shown at +# http://oss.oracle.com/licenses/upl. 
+# ----------------------------------------------------------------------------- + +""" +3200 - Module for testing select_ai agents +""" + +import uuid +import logging +import pytest +import select_ai +import os +from select_ai.agent import Agent, AgentAttributes +from select_ai.errors import AgentNotFoundError +import oracledb + +# Path +PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")) +LOG_FILE = os.path.join(PROJECT_ROOT, "log", "tkex_test_3201_agents.log") +os.makedirs(os.path.dirname(LOG_FILE), exist_ok=True) + +# Force logging to file (pytest-proof) +root = logging.getLogger() +root.setLevel(logging.INFO) + +for h in root.handlers[:]: + root.removeHandler(h) + +fh = logging.FileHandler(LOG_FILE, mode="w") +fh.setFormatter(logging.Formatter("%(levelname)s: %(message)s")) +root.addHandler(fh) + +logger = logging.getLogger() + + +# ----------------------------------------------------------------------------- +# Per-test logging +# ----------------------------------------------------------------------------- + +@pytest.fixture(autouse=True) +def log_test_name(request): + logger.info(f"--- Starting test: {request.function.__name__} ---") + yield + logger.info(f"--- Finished test: {request.function.__name__} ---") + + +# ----------------------------------------------------------------------------- +# Helper Functions +# ----------------------------------------------------------------------------- + +def get_agent_status(agent_name): + with select_ai.cursor() as cur: + cur.execute(""" + SELECT status + FROM USER_AI_AGENTS + WHERE agent_name = :agent_name + """, {"agent_name": agent_name}) + row = cur.fetchone() + return row[0] if row else None + +# ----------------------------------------------------------------------------- +# Test constants +# ----------------------------------------------------------------------------- + +PYSAI_AGENT_NAME = f"PYSAI_3200_AGENT_{uuid.uuid4().hex.upper()}" +PYSAI_AGENT_DESC = "PYSAI_3200_AGENT_DESCRIPTION" +PYSAI_PROFILE_NAME = f"PYSAI_3200_PROFILE_{uuid.uuid4().hex.upper()}" +PYSAI_DISABLED_AGENT_NAME = f"PYSAI_3200_DISABLED_AGENT_{uuid.uuid4().hex.upper()}" +PYSAI_MISSING_AGENT_NAME = f"PYSAI_3200_MISSING_AGENT_{uuid.uuid4().hex.upper()}" + +# ----------------------------------------------------------------------------- +# Fixtures +# ----------------------------------------------------------------------------- + +@pytest.fixture(scope="module") +def python_gen_ai_profile(profile_attributes): + logger.info("Creating profile: %s", PYSAI_PROFILE_NAME) + profile = select_ai.Profile( + profile_name=PYSAI_PROFILE_NAME, + description="OCI GENAI Profile", + attributes=profile_attributes, + ) + profile.create(replace=True) + yield profile + logger.info("Deleting profile: %s", PYSAI_PROFILE_NAME) + profile.delete(force=True) + + +@pytest.fixture(scope="module") +def agent_attributes(): + return AgentAttributes( + profile_name=PYSAI_PROFILE_NAME, + role="You are an AI Movie Analyst. 
You analyze movies.", + enable_human_tool=False, + ) + + +@pytest.fixture(scope="module") +def agent(python_gen_ai_profile, agent_attributes): + logger.info("Creating agent: %s", PYSAI_AGENT_NAME) + agent = Agent( + agent_name=PYSAI_AGENT_NAME, + description=PYSAI_AGENT_DESC, + attributes=agent_attributes, + ) + agent.create(enabled=True, replace=True) + yield agent + logger.info("Deleting agent: %s", PYSAI_AGENT_NAME) + agent.delete(force=True) + +# ----------------------------------------------------------------------------- +# Helpers +# ----------------------------------------------------------------------------- + + +def expect_oracle_error(expected_code, fn): + """ + Run fn and assert that expected Oracle/Agent error occurs. + expected_code: "ORA-xxxxx" or "NOT_FOUND" + """ + try: + fn() + except AgentNotFoundError as e: + logger.info("Expected failure (NOT_FOUND): %s", e) + assert expected_code == "NOT_FOUND" + except oracledb.DatabaseError as e: + msg = str(e) + logger.info("Expected Oracle failure: %s", msg) + assert expected_code in msg, f"Expected {expected_code}, got {msg}" + else: + pytest.fail(f"Expected error {expected_code} did not occur") + +# ----------------------------------------------------------------------------- +# Tests +# ----------------------------------------------------------------------------- + +def test_3200_identity(agent, agent_attributes): + logger.info("Verifying agent identity") + logger.info("Agent name : %s", agent.agent_name) + logger.info("Agent description: %s", agent.description) + logger.info("Agent attributes : %s", agent.attributes) + assert agent.agent_name == PYSAI_AGENT_NAME + assert agent.description == PYSAI_AGENT_DESC + assert agent.attributes == agent_attributes + + +@pytest.mark.parametrize("pattern", [None, ".*", "^PYSAI_3200_AGENT_"]) +def test_3201_list(pattern): + logger.info("Listing agents with pattern: %s", pattern) + agents = list(Agent.list() if pattern is None else Agent.list(pattern)) + names = sorted(a.agent_name for a in agents) + logger.info("Agents found (sorted):") + for name in names: + logger.info(" - %s", name) + + assert PYSAI_AGENT_NAME in names + + +def test_3202_fetch(agent_attributes): + logger.info("Fetching agent: %s", PYSAI_AGENT_NAME) + a = Agent.fetch(PYSAI_AGENT_NAME) + logger.info("Fetched agent name : %s", a.agent_name) + logger.info("Fetched agent description: %s", a.description) + logger.info("Fetched agent attributes : %s", a.attributes) + assert a.agent_name == PYSAI_AGENT_NAME + assert a.attributes == agent_attributes + assert a.description == PYSAI_AGENT_DESC + + +def test_3203_fetch_non_existing(): + name = f"PYSAI_NO_SUCH_AGENT_{uuid.uuid4().hex}" + logger.info("Fetching non-existing agent: %s", name) + expect_oracle_error("NOT_FOUND", lambda: Agent.fetch(name)) + + +def test_3204_create_agent_default_status_enabled(agent_attributes): + name = f"PYSAI_3200_STATUS_ENABLED_{uuid.uuid4().hex.upper()}" + logger.info("Creating agent with default enabled status: %s", name) + a = Agent( + agent_name=name, + description="Default enabled status", + attributes=agent_attributes, + ) + a.create(replace=True) + try: + status = get_agent_status(name) + logger.info("Agent status after create: %s", status) + assert status == "ENABLED" + + fetched = Agent.fetch(name) + logger.info("Fetched created agent: %s", fetched.agent_name) + assert fetched.description == "Default enabled status" + finally: + a.delete(force=True) + + +def test_3205_create_agent_with_enabled_false_sets_disabled(agent_attributes): + 
logger.info("Creating disabled agent: %s", PYSAI_DISABLED_AGENT_NAME) + a = Agent( + agent_name=PYSAI_DISABLED_AGENT_NAME, + description="Initially disabled", + attributes=agent_attributes, + ) + a.create(enabled=False, replace=True) + try: + status = get_agent_status(PYSAI_DISABLED_AGENT_NAME) + logger.info("Agent status after create(enabled=False): %s", status) + assert status == "DISABLED" + + fetched = Agent.fetch(PYSAI_DISABLED_AGENT_NAME) + logger.info("Fetched disabled agent: %s", fetched.agent_name) + assert fetched.description == "Initially disabled" + finally: + a.delete(force=True) + + +def test_3206_drop_agent_force_true_non_existent(): + logger.info("Dropping missing agent with force=True: %s", PYSAI_MISSING_AGENT_NAME) + a = Agent(agent_name=PYSAI_MISSING_AGENT_NAME) + a.delete(force=True) + status = get_agent_status(PYSAI_MISSING_AGENT_NAME) + logger.info("Status after force delete on missing agent: %s", status) + assert status is None + + +def test_3207_drop_agent_force_false_non_existent_raises(): + logger.info("Dropping missing agent with force=False: %s", PYSAI_MISSING_AGENT_NAME) + a = Agent(agent_name=PYSAI_MISSING_AGENT_NAME) + expect_oracle_error("ORA-20050", lambda: a.delete(force=False)) + + +def test_3208_create_requires_agent_name(agent_attributes): + logger.info("Validating create() requires agent_name") + with pytest.raises(AttributeError) as exc: + Agent(attributes=agent_attributes).create() + logger.info("Received expected error: %s", exc.value) + + +def test_3209_create_requires_attributes(): + logger.info("Validating create() requires attributes") + with pytest.raises(AttributeError) as exc: + Agent(agent_name=f"PYSAI_3200_NO_ATTR_{uuid.uuid4().hex.upper()}").create() + logger.info("Received expected error: %s", exc.value) + + +def test_3210_disable_enable(agent): + logger.info("Disabling agent: %s", agent.agent_name) + agent.disable() + + status = get_agent_status(agent.agent_name) + logger.info("Agent status after disable: %s", status) + assert status == "DISABLED" + + logger.info("Enabling agent: %s", agent.agent_name) + agent.enable() + + status = get_agent_status(agent.agent_name) + logger.info("Agent status after enable: %s", status) + assert status == "ENABLED" + + +def test_3211_set_attribute(agent): + logger.info("Setting role attribute on agent: %s", agent.agent_name) + agent.set_attribute("role", "You are a DB assistant") + + a = Agent.fetch(PYSAI_AGENT_NAME) + logger.info("Updated role attribute: %s", a.attributes.role) + + assert "DB assistant" in a.attributes.role + + +def test_3212_set_attributes(agent): + logger.info("Replacing agent attributes") + + new_attrs = AgentAttributes( + profile_name=PYSAI_PROFILE_NAME, + role="You are a cloud architect", + enable_human_tool=True, + ) + + logger.info("New attributes: %s", new_attrs) + agent.set_attributes(new_attrs) + + a = Agent.fetch(PYSAI_AGENT_NAME) + logger.info("Fetched attributes after replace: %s", a.attributes) + + assert a.attributes == new_attrs + + +def test_3213_set_attribute_invalid_key(agent): + logger.info("Setting invalid attribute key on agent: %s", agent.agent_name) + expect_oracle_error("ORA-20050", lambda: agent.set_attribute("no_such_key", 123)) + +def test_3214_set_attribute_none(agent): + logger.info("Setting attribute 'role' to None on agent: %s", agent.agent_name) + expect_oracle_error("ORA-20050", lambda: agent.set_attribute("role", None)) + +def test_3215_set_attribute_empty(agent): + logger.info("Setting attribute 'role' to empty string on agent: %s", agent.agent_name) 
+ expect_oracle_error("ORA-20050", lambda: agent.set_attribute("role", "")) + +def test_3216_create_existing_without_replace(agent_attributes): + logger.info("Create existing agent without replace should fail") + a = Agent( + agent_name=PYSAI_AGENT_NAME, + description="X", + attributes=agent_attributes, + ) + expect_oracle_error("ORA-20050", lambda: a.create(replace=False)) + +def test_3217_delete_and_recreate(agent_attributes): + name = f"PYSAI_RECREATE_{uuid.uuid4().hex}" + logger.info("Create agent: %s", name) + #Create agent + a = Agent(name, attributes=agent_attributes) + a.create() + # Verify created + fetched = Agent.fetch(name) + logger.info("Agent created successfully: %s", fetched.agent_name) + assert fetched.agent_name == name + #Delete agent + logger.info("Delete agent: %s", name) + a.delete(force=True) + # Verify deleted + logger.info("Attempting fetch after delete for agent: %s", name) + expect_oracle_error("NOT_FOUND", lambda: Agent.fetch(name)) + logger.info("Agent deleted successfully: %s", name) + #Recreate agent + logger.info("Recreate agent: %s", name) + a.create(replace=False) + # Verify recreated + fetched_recreated = Agent.fetch(name) + logger.info("Agent recreated successfully: %s", fetched_recreated.agent_name) + assert fetched_recreated.agent_name == name + #Final cleanup + logger.info("Cleanup agent: %s", name) + a.delete(force=True) + # Verify cleanup + logger.info("Attempting fetch after delete for agent: %s", name) + expect_oracle_error("NOT_FOUND", lambda: Agent.fetch(name)) + logger.info("Final cleanup successful for agent: %s", name) + + +def test_3218_disable_after_delete(agent_attributes): + name = f"PYSAI_TMP_DEL_{uuid.uuid4().hex}" + logger.info("Creating agent: %s", name) + a = Agent(name, attributes=agent_attributes) + a.create() + logger.info("Agent created successfully: %s", name) + + logger.info("Fetching agent to verify creation: %s", name) + fetched = Agent.fetch(name) + logger.info("Fetched agent: %s", fetched.agent_name) + + logger.info("Deleting agent: %s", name) + a.delete(force=True) + logger.info("Agent deleted, verifying deletion: %s", name) + logger.info("Attempting fetch after delete for agent: %s", name) + expect_oracle_error("NOT_FOUND", lambda: Agent.fetch(name)) + logger.info("Confirmed agent no longer exists: %s", name) + + logger.info("Attempting to disable deleted agent: %s", name) + expect_oracle_error("ORA-20050", lambda: a.disable()) + logger.info("Disable after delete confirmed error for agent: %s", name) + + +def test_3219_enable_after_delete(agent_attributes): + name = f"PYSAI_TMP_DEL_{uuid.uuid4().hex}" + logger.info("Creating agent: %s", name) + a = Agent(name, attributes=agent_attributes) + a.create() + logger.info("Agent created successfully: %s", name) + + logger.info("Fetching agent to verify creation: %s", name) + fetched = Agent.fetch(name) + logger.info("Fetched agent: %s", fetched.agent_name) + + logger.info("Deleting agent: %s", name) + a.delete(force=True) + logger.info("Agent deleted, verifying deletion: %s", name) + logger.info("Attempting fetch after delete for agent: %s", name) + expect_oracle_error("NOT_FOUND", lambda: Agent.fetch(name)) + logger.info("Confirmed agent no longer exists: %s", name) + + logger.info("Attempting to enable deleted agent: %s", name) + expect_oracle_error("ORA-20050", lambda: a.enable()) + logger.info("Enable after delete confirmed error for agent: %s", name) + + +def test_3220_set_attribute_after_delete(agent_attributes): + name = f"PYSAI_TMP_DEL_{uuid.uuid4().hex}" + 
logger.info("Creating agent: %s", name) + a = Agent(name, attributes=agent_attributes) + a.create() + logger.info("Agent created successfully: %s", name) + + logger.info("Fetching agent to verify creation: %s", name) + fetched = Agent.fetch(name) + logger.info("Fetched agent: %s", fetched.agent_name) + + logger.info("Deleting agent: %s", name) + a.delete(force=True) + logger.info("Agent deleted, verifying deletion: %s", name) + logger.info("Attempting fetch after delete for agent: %s", name) + expect_oracle_error("NOT_FOUND", lambda: Agent.fetch(name)) + logger.info("Confirmed agent no longer exists: %s", name) + + logger.info("Attempting to set attribute on deleted agent: %s", name) + expect_oracle_error("ORA-20050", lambda: a.set_attribute("role", "X")) + logger.info("Set attribute after delete confirmed error for agent: %s", name) + + +def test_3221_double_delete_force_true(agent_attributes): + name = f"PYSAI_TMP_DOUBLE_DEL_{uuid.uuid4().hex}" + logger.info("Creating agent: %s", name) + a = Agent(name, attributes=agent_attributes) + a.create() + logger.info("Agent created successfully: %s", name) + + logger.info("Fetching agent to verify creation: %s", name) + fetched = Agent.fetch(name) + logger.info("Fetched agent: %s", fetched.agent_name) + + logger.info("Deleting agent first time: %s", name) + a.delete(force=True) + logger.info("First delete done, verifying deletion: %s", name) + logger.info("Attempting fetch after first delete for agent: %s", name) + expect_oracle_error("NOT_FOUND", lambda: Agent.fetch(name)) + logger.info("Confirmed agent no longer exists: %s", name) + + logger.info("Deleting agent second time (should not fail): %s", name) + a.delete(force=True) + logger.info("Second delete completed, verifying still deleted: %s", name) + expect_oracle_error("NOT_FOUND", lambda: Agent.fetch(name)) + logger.info("Confirmed agent still does not exist after double delete: %s", name) + + +def test_3222_double_delete_force_false_raises(agent_attributes): + name = f"PYSAI_TMP_DOUBLE_DEL_FALSE_{uuid.uuid4().hex}" + logger.info("Creating agent: %s", name) + a = Agent(name, attributes=agent_attributes) + a.create() + logger.info("Agent created successfully: %s", name) + + logger.info("Fetching agent to verify creation: %s", name) + fetched = Agent.fetch(name) + logger.info("Fetched agent: %s", fetched.agent_name) + + logger.info("Deleting agent first time with force=False: %s", name) + a.delete(force=False) + logger.info("First delete done, verifying deletion: %s", name) + expect_oracle_error("NOT_FOUND", lambda: Agent.fetch(name)) + logger.info("Confirmed agent no longer exists: %s", name) + + logger.info("Deleting agent second time with force=False: %s", name) + expect_oracle_error("ORA-20050", lambda: a.delete(force=False)) + logger.info("Confirmed second delete with force=False raises error: %s", name) + + +def test_3223_fetch_after_delete(agent_attributes): + name = f"PYSAI_TMP_FETCH_DEL_{uuid.uuid4().hex}" + logger.info("Creating agent: %s", name) + a = Agent(name, attributes=agent_attributes) + a.create() + logger.info("Agent created successfully: %s", name) + + logger.info("Fetching agent to verify creation: %s", name) + fetched = Agent.fetch(name) + logger.info("Fetched agent: %s", fetched.agent_name) + + logger.info("Deleting agent: %s", name) + a.delete(force=True) + logger.info("Agent deleted, verifying deletion: %s", name) + logger.info("Attempting fetch after delete for agent: %s", name) + expect_oracle_error("NOT_FOUND", lambda: Agent.fetch(name)) + logger.info("Confirmed 
agent no longer exists: %s", name) + + +def test_3224_list_all_non_empty(): + logger.info("Listing all agents") + agents = list(Agent.list()) + names = sorted(a.agent_name for a in agents) + logger.info("Total agents found: %d", len(names)) + logger.info("Agent names:") + for name in names: + logger.info(" - %s", name) + assert len(names) > 0 diff --git a/tests/agents/test_3201_async_agents.py b/tests/agents/test_3201_async_agents.py new file mode 100644 index 0000000..58f102c --- /dev/null +++ b/tests/agents/test_3201_async_agents.py @@ -0,0 +1,416 @@ +# ----------------------------------------------------------------------------- +# Copyright (c) 2025, Oracle and/or its affiliates. +# +# Licensed under the Universal Permissive License v 1.0 as shown at +# http://oss.oracle.com/licenses/upl. +# ----------------------------------------------------------------------------- + +""" +3200 - Module for testing select_ai async agents +""" + +import logging +import os +import uuid + +import oracledb +import pytest +import select_ai +from select_ai.agent import AgentAttributes, AsyncAgent +from select_ai.errors import AgentNotFoundError + +pytestmark = pytest.mark.anyio + +PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")) +LOG_FILE = os.path.join(PROJECT_ROOT, "log", "tkex_test_3200_async_agents.log") +os.makedirs(os.path.dirname(LOG_FILE), exist_ok=True) + +root = logging.getLogger() +root.setLevel(logging.INFO) +for handler in root.handlers[:]: + root.removeHandler(handler) +file_handler = logging.FileHandler(LOG_FILE, mode="w") +file_handler.setFormatter(logging.Formatter("%(levelname)s: %(message)s")) +root.addHandler(file_handler) +logger = logging.getLogger() + +PYSAI_3200_AGENT_NAME = f"PYSAI_3200_AGENT_{uuid.uuid4().hex.upper()}" +PYSAI_3200_AGENT_DESCRIPTION = "PYSAI_3200_AGENT_DESCRIPTION" +PYSAI_3200_PROFILE_NAME = f"PYSAI_3200_PROFILE_{uuid.uuid4().hex.upper()}" +PYSAI_3200_DISABLED_AGENT_NAME = ( + f"PYSAI_3200_DISABLED_AGENT_{uuid.uuid4().hex.upper()}" +) +PYSAI_3200_MISSING_AGENT_NAME = ( + f"PYSAI_3200_MISSING_AGENT_{uuid.uuid4().hex.upper()}" +) + + +@pytest.fixture(autouse=True) +def log_test_name(request): + logger.info("--- Starting test: %s ---", request.function.__name__) + yield + logger.info("--- Finished test: %s ---", request.function.__name__) + + +@pytest.fixture(scope="module", autouse=True) +async def async_connect(test_env): + logger.info("Opening async database connection") + await select_ai.async_connect(**test_env.connect_params()) + yield + logger.info("Closing async database connection") + await select_ai.async_disconnect() + + +def log_agent_details(context: str, agent) -> None: + attrs = getattr(agent, "attributes", None) + details = { + "context": context, + "agent_name": getattr(agent, "agent_name", None), + "description": getattr(agent, "description", None), + "profile_name": getattr(attrs, "profile_name", None) if attrs else None, + "role": getattr(attrs, "role", None) if attrs else None, + "enable_human_tool": ( + getattr(attrs, "enable_human_tool", None) if attrs else None + ), + } + logger.info("AGENT_DETAILS: %s", details) + print("AGENT_DETAILS:", details) + + +async def expect_async_error(expected_code, coro_fn): + try: + await coro_fn() + except AgentNotFoundError as exc: + logger.info("Expected failure (NOT_FOUND): %s", exc) + assert expected_code == "NOT_FOUND" + except oracledb.DatabaseError as exc: + msg = str(exc) + logger.info("Expected Oracle failure: %s", msg) + assert expected_code in msg, f"Expected 
{expected_code}, got: {msg}" + else: + pytest.fail(f"Expected error {expected_code} did not occur") + + +async def get_agent_status(agent_name): + logger.info("Fetching agent status for: %s", agent_name) + async with select_ai.async_cursor() as cur: + await cur.execute( + """ + SELECT status + FROM USER_AI_AGENTS + WHERE agent_name = :agent_name + """, + {"agent_name": agent_name}, + ) + row = await cur.fetchone() + return row[0] if row else None + + +async def assert_agent_status(agent_name: str, expected_status: str) -> None: + status = await get_agent_status(agent_name) + logger.info( + "Verifying agent status | agent=%s | expected=%s | actual=%s", + agent_name, + expected_status, + status, + ) + assert status == expected_status + + +@pytest.fixture(scope="module") +async def async_python_gen_ai_profile(profile_attributes): + logger.info("Creating profile: %s", PYSAI_3200_PROFILE_NAME) + profile = await select_ai.AsyncProfile( + profile_name=PYSAI_3200_PROFILE_NAME, + description="OCI GENAI Profile", + attributes=profile_attributes, + ) + yield profile + logger.info("Deleting profile: %s", PYSAI_3200_PROFILE_NAME) + await profile.delete(force=True) + + +@pytest.fixture(scope="module") +def agent_attributes(): + return AgentAttributes( + profile_name=PYSAI_3200_PROFILE_NAME, + role=( + "You are an AI Movie Analyst. " + "You can help answer movie-related questions." + ), + enable_human_tool=False, + ) + + +@pytest.fixture(scope="module") +async def agent(async_python_gen_ai_profile, agent_attributes): + logger.info("Creating async agent: %s", PYSAI_3200_AGENT_NAME) + agent_obj = AsyncAgent( + agent_name=PYSAI_3200_AGENT_NAME, + attributes=agent_attributes, + description=PYSAI_3200_AGENT_DESCRIPTION, + ) + await agent_obj.create(enabled=True, replace=True) + yield agent_obj + logger.info("Deleting async agent: %s", PYSAI_3200_AGENT_NAME) + await agent_obj.delete(force=True) + + +async def test_3200_identity(agent, agent_attributes): + log_agent_details("test_3200_identity", agent) + assert agent.agent_name == PYSAI_3200_AGENT_NAME + assert agent.attributes == agent_attributes + assert agent.description == PYSAI_3200_AGENT_DESCRIPTION + assert agent.attributes.enable_human_tool is False + + +@pytest.mark.parametrize("agent_name_pattern", [None, ".*", "^PYSAI_3200_AGENT_"]) +async def test_3201_list(agent_name_pattern): + logger.info("Listing agents with pattern=%s", agent_name_pattern) + if agent_name_pattern: + agents = [ + a async for a in select_ai.agent.AsyncAgent.list(agent_name_pattern) + ] + else: + agents = [a async for a in select_ai.agent.AsyncAgent.list()] + + for a in agents: + if a.agent_name == PYSAI_3200_AGENT_NAME: + log_agent_details("test_3201_list", a) + + agent_names = set(a.agent_name for a in agents) + agent_descriptions = set(a.description for a in agents) + assert len(agents) >= 1 + assert PYSAI_3200_AGENT_NAME in agent_names + assert PYSAI_3200_AGENT_DESCRIPTION in agent_descriptions + + +async def test_3202_fetch(agent_attributes): + a = await AsyncAgent.fetch(agent_name=PYSAI_3200_AGENT_NAME) + log_agent_details("test_3202_fetch", a) + assert a.agent_name == PYSAI_3200_AGENT_NAME + assert a.attributes == agent_attributes + assert a.description == PYSAI_3200_AGENT_DESCRIPTION + + +async def test_3203_fetch_non_existing(): + name = f"PYSAI_NO_SUCH_AGENT_{uuid.uuid4().hex.upper()}" + logger.info("Fetching non-existing async agent: %s", name) + with pytest.raises(AgentNotFoundError) as exc: + await AsyncAgent.fetch(name) + logger.info("Received expected error: %s", 
exc.value) + + +async def test_3204_create_agent_default_status_enabled(agent_attributes): + name = f"PYSAI_3200_STATUS_ENABLED_{uuid.uuid4().hex.upper()}" + a = AsyncAgent( + agent_name=name, + description="Default enabled status", + attributes=agent_attributes, + ) + await a.create(replace=True) + try: + await assert_agent_status(name, "ENABLED") + fetched = await AsyncAgent.fetch(name) + log_agent_details("test_3204_create_agent_default_status_enabled", fetched) + assert fetched.description == "Default enabled status" + finally: + await a.delete(force=True) + + +async def test_3205_create_agent_with_enabled_false_sets_disabled(agent_attributes): + a = AsyncAgent( + agent_name=PYSAI_3200_DISABLED_AGENT_NAME, + description="Initially disabled", + attributes=agent_attributes, + ) + await a.create(enabled=False, replace=True) + try: + await assert_agent_status(PYSAI_3200_DISABLED_AGENT_NAME, "DISABLED") + fetched = await AsyncAgent.fetch(PYSAI_3200_DISABLED_AGENT_NAME) + log_agent_details( + "test_3205_create_agent_with_enabled_false_sets_disabled", fetched + ) + assert fetched.description == "Initially disabled" + finally: + await a.delete(force=True) + + +async def test_3206_set_attribute(agent): + logger.info("Setting role attribute on async agent: %s", agent.agent_name) + await agent.set_attribute("role", "You are a DB assistant") + updated = await AsyncAgent.fetch(PYSAI_3200_AGENT_NAME) + log_agent_details("test_3206_set_attribute", updated) + assert "DB assistant" in updated.attributes.role + + +async def test_3207_set_attributes(agent): + logger.info("Replacing async agent attributes") + new_attrs = AgentAttributes( + profile_name=PYSAI_3200_PROFILE_NAME, + role="You are a cloud architect", + enable_human_tool=True, + ) + await agent.set_attributes(new_attrs) + updated = await AsyncAgent.fetch(PYSAI_3200_AGENT_NAME) + log_agent_details("test_3207_set_attributes", updated) + assert updated.attributes == new_attrs + + +async def test_3208_set_attribute_invalid_key(agent): + logger.info("Setting invalid attribute key on async agent") + with pytest.raises(oracledb.DatabaseError) as exc: + await agent.set_attribute("no_such_key", 123) + logger.info("Received expected Oracle error: %s", exc.value) + assert "ORA-20050" in str(exc.value) + + +async def test_3209_drop_agent_force_true_non_existent(): + logger.info("Dropping missing agent with force=True") + a = AsyncAgent(agent_name=PYSAI_3200_MISSING_AGENT_NAME) + await a.delete(force=True) + status = await get_agent_status(PYSAI_3200_MISSING_AGENT_NAME) + logger.info("Status after force delete on missing agent: %s", status) + assert status is None + + +async def test_3210_drop_agent_force_false_non_existent_raises(): + logger.info("Dropping missing agent with force=False") + a = AsyncAgent(agent_name=PYSAI_3200_MISSING_AGENT_NAME) + with pytest.raises(oracledb.DatabaseError) as exc: + await a.delete(force=False) + logger.info("Received expected Oracle error: %s", exc.value) + + +async def test_3211_create_requires_agent_name(agent_attributes): + logger.info("Validating async create() requires agent_name") + with pytest.raises(AttributeError) as exc: + await AsyncAgent(attributes=agent_attributes).create() + logger.info("Received expected error: %s", exc.value) + + +async def test_3212_create_requires_attributes(): + logger.info("Validating async create() requires attributes") + with pytest.raises(AttributeError) as exc: + await AsyncAgent( + agent_name=f"PYSAI_3200_NO_ATTR_{uuid.uuid4().hex.upper()}" + ).create() + logger.info("Received 
expected error: %s", exc.value) + + +async def test_3213_disable_enable(agent): + logger.info("Disabling async agent: %s", agent.agent_name) + await agent.disable() + await assert_agent_status(agent.agent_name, "DISABLED") + + logger.info("Enabling async agent: %s", agent.agent_name) + await agent.enable() + await assert_agent_status(agent.agent_name, "ENABLED") + + +async def test_3214_set_attribute_none(agent): + logger.info("Setting role=None on async agent: %s", agent.agent_name) + await expect_async_error("ORA-20050", lambda: agent.set_attribute("role", None)) + + +async def test_3215_set_attribute_empty(agent): + logger.info("Setting role='' on async agent: %s", agent.agent_name) + await expect_async_error("ORA-20050", lambda: agent.set_attribute("role", "")) + + +async def test_3216_create_existing_without_replace(agent_attributes): + logger.info("Creating duplicate async agent without replace") + dup = AsyncAgent( + agent_name=PYSAI_3200_AGENT_NAME, + description="Duplicate async agent", + attributes=agent_attributes, + ) + await expect_async_error("ORA-20050", lambda: dup.create(replace=False)) + + +async def test_3217_delete_and_recreate(agent_attributes): + name = f"PYSAI_RECREATE_{uuid.uuid4().hex.upper()}" + logger.info("Creating async agent: %s", name) + a = AsyncAgent(name, attributes=agent_attributes) + await a.create() + + fetched = await AsyncAgent.fetch(name) + log_agent_details("test_3217_created", fetched) + assert fetched.agent_name == name + + logger.info("Deleting async agent: %s", name) + await a.delete(force=True) + await expect_async_error("NOT_FOUND", lambda: AsyncAgent.fetch(name)) + + logger.info("Recreating async agent: %s", name) + await a.create(replace=False) + recreated = await AsyncAgent.fetch(name) + log_agent_details("test_3217_recreated", recreated) + assert recreated.agent_name == name + + await a.delete(force=True) + await expect_async_error("NOT_FOUND", lambda: AsyncAgent.fetch(name)) + + +async def test_3218_disable_after_delete(agent_attributes): + name = f"PYSAI_TMP_DEL_{uuid.uuid4().hex.upper()}" + a = AsyncAgent(name, attributes=agent_attributes) + await a.create() + await a.delete(force=True) + await expect_async_error("NOT_FOUND", lambda: AsyncAgent.fetch(name)) + await expect_async_error("ORA-20050", lambda: a.disable()) + + +async def test_3219_enable_after_delete(agent_attributes): + name = f"PYSAI_TMP_DEL_{uuid.uuid4().hex.upper()}" + a = AsyncAgent(name, attributes=agent_attributes) + await a.create() + await a.delete(force=True) + await expect_async_error("NOT_FOUND", lambda: AsyncAgent.fetch(name)) + await expect_async_error("ORA-20050", lambda: a.enable()) + + +async def test_3220_set_attribute_after_delete(agent_attributes): + name = f"PYSAI_TMP_DEL_{uuid.uuid4().hex.upper()}" + a = AsyncAgent(name, attributes=agent_attributes) + await a.create() + await a.delete(force=True) + await expect_async_error("NOT_FOUND", lambda: AsyncAgent.fetch(name)) + await expect_async_error("ORA-20050", lambda: a.set_attribute("role", "X")) + + +async def test_3221_double_delete_force_true(agent_attributes): + name = f"PYSAI_TMP_DOUBLE_DEL_{uuid.uuid4().hex.upper()}" + a = AsyncAgent(name, attributes=agent_attributes) + await a.create() + await a.delete(force=True) + await expect_async_error("NOT_FOUND", lambda: AsyncAgent.fetch(name)) + await a.delete(force=True) + await expect_async_error("NOT_FOUND", lambda: AsyncAgent.fetch(name)) + + +async def test_3222_double_delete_force_false_raises(agent_attributes): + name = 
f"PYSAI_TMP_DOUBLE_DEL_FALSE_{uuid.uuid4().hex.upper()}" + a = AsyncAgent(name, attributes=agent_attributes) + await a.create() + await a.delete(force=False) + await expect_async_error("NOT_FOUND", lambda: AsyncAgent.fetch(name)) + await expect_async_error("ORA-20050", lambda: a.delete(force=False)) + + +async def test_3223_fetch_after_delete(agent_attributes): + name = f"PYSAI_TMP_FETCH_DEL_{uuid.uuid4().hex.upper()}" + a = AsyncAgent(name, attributes=agent_attributes) + await a.create() + await a.delete(force=True) + await expect_async_error("NOT_FOUND", lambda: AsyncAgent.fetch(name)) + + +async def test_3224_list_all_non_empty(): + logger.info("Listing all async agents") + agents = [a async for a in AsyncAgent.list()] + names = sorted(a.agent_name for a in agents) + logger.info("Total async agents found: %d", len(names)) + for name in names: + logger.info(" - %s", name) + assert len(names) > 0 diff --git a/tests/agents/test_3301_async_teams.py b/tests/agents/test_3301_async_teams.py new file mode 100644 index 0000000..3a54a3f --- /dev/null +++ b/tests/agents/test_3301_async_teams.py @@ -0,0 +1,393 @@ +# ----------------------------------------------------------------------------- +# Copyright (c) 2025, Oracle and/or its affiliates. +# +# Licensed under the Universal Permissive License v 1.0 as shown at +# http://oss.oracle.com/licenses/upl. +# ----------------------------------------------------------------------------- + +""" +3301 - Async contract, regression and corner-case tests for select_ai.agent.AsyncTeam +""" + +import logging +import os +import uuid + +import oracledb +import pytest +import select_ai +from select_ai.agent import ( + AgentAttributes, + AsyncAgent, + AsyncTask, + AsyncTeam, + TaskAttributes, + TeamAttributes, +) +from select_ai.errors import AgentTeamNotFoundError + +pytestmark = pytest.mark.anyio + +# ----------------------------------------------------------------------------- +# Logging +# ----------------------------------------------------------------------------- + +PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")) +LOG_DIR = os.path.join(PROJECT_ROOT, "log") +os.makedirs(LOG_DIR, exist_ok=True) +LOG_FILE = os.path.join(LOG_DIR, "tkex_test_3301_async_teams.log") + +root = logging.getLogger() +root.setLevel(logging.INFO) +for h in root.handlers[:]: + root.removeHandler(h) + +fh = logging.FileHandler(LOG_FILE, mode="w") +fh.setFormatter(logging.Formatter("%(levelname)s: %(message)s")) +root.addHandler(fh) + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + + +# ----------------------------------------------------------------------------- +# Per-test logging + async connection +# ----------------------------------------------------------------------------- + +@pytest.fixture(autouse=True) +def log_test_name(request): + logger.info("--- Starting test: %s ---", request.function.__name__) + yield + logger.info("--- Finished test: %s ---", request.function.__name__) + + +@pytest.fixture(scope="module", autouse=True) +async def async_connect(test_env): + logger.info("Opening async database connection") + await select_ai.async_connect(**test_env.connect_params()) + yield + logger.info("Closing async database connection") + await select_ai.async_disconnect() + + +# ----------------------------------------------------------------------------- +# Helpers +# ----------------------------------------------------------------------------- + +async def expect_async_error(expected_code, coro_fn): + """ + expected_code: + - 
"NOT_FOUND" + - "ORA-20053" + - "ORA-xxxxx" + """ + try: + await coro_fn() + except AgentTeamNotFoundError as exc: + logger.info("Expected failure (NOT_FOUND): %s", exc) + assert expected_code == "NOT_FOUND" + except oracledb.DatabaseError as exc: + msg = str(exc) + logger.info("Expected Oracle failure: %s", msg) + assert expected_code in msg, f"Expected {expected_code}, got: {msg}" + except Exception as exc: + msg = str(exc) + logger.info("Expected generic failure: %s", msg) + assert expected_code in msg, f"Expected {expected_code}, got: {msg}" + else: + pytest.fail(f"Expected error {expected_code} did not occur") + + +def log_team_details(context: str, team) -> None: + attrs = getattr(team, "attributes", None) + details = { + "context": context, + "team_name": getattr(team, "team_name", None), + "description": getattr(team, "description", None), + "process": getattr(attrs, "process", None) if attrs else None, + "agents": getattr(attrs, "agents", None) if attrs else None, + } + logger.info("TEAM_DETAILS: %s", details) + print("TEAM_DETAILS:", details) + + +async def get_team_status(team_name: str): + logger.info("Fetching team status for: %s", team_name) + async with select_ai.async_cursor() as cur: + await cur.execute( + """ + SELECT status + FROM USER_AI_AGENT_TEAMS + WHERE agent_team_name = :team_name + """, + {"team_name": team_name}, + ) + row = await cur.fetchone() + return row[0] if row else None + + +async def assert_team_status(team_name: str, expected_status: str) -> None: + status = await get_team_status(team_name) + logger.info( + "Verifying team status | team=%s | expected=%s | actual=%s", + team_name, + expected_status, + status, + ) + assert status == expected_status + + +# ----------------------------------------------------------------------------- +# Test constants +# ----------------------------------------------------------------------------- + +PYSAI_TEAM_AGENT_NAME = f"PYSAI_TEAM_AGENT_{uuid.uuid4().hex.upper()}" +PYSAI_TEAM_PROFILE_NAME = f"PYSAI_TEAM_PROFILE_{uuid.uuid4().hex.upper()}" +PYSAI_TEAM_TASK_NAME = f"PYSAI_TEAM_TASK_{uuid.uuid4().hex.upper()}" +PYSAI_TEAM_NAME = f"PYSAI_TEAM_{uuid.uuid4().hex.upper()}" +PYSAI_TEAM_DESC = "PYSAI ASYNC TEAM FINAL CONTRACT TEST" + + +# ----------------------------------------------------------------------------- +# Fixtures +# ----------------------------------------------------------------------------- + +@pytest.fixture(scope="module") +async def python_gen_ai_profile(profile_attributes): + logger.info("Creating profile: %s", PYSAI_TEAM_PROFILE_NAME) + + oci_compartment_id = os.getenv("PYSAI_TEST_OCI_COMPARTMENT_ID") + if oci_compartment_id: + profile_attributes.oci_compartment_id = oci_compartment_id + + profile = await select_ai.AsyncProfile( + profile_name=PYSAI_TEAM_PROFILE_NAME, + description="OCI GENAI Profile", + attributes=profile_attributes, + ) + + yield profile + + logger.info("Deleting profile: %s", PYSAI_TEAM_PROFILE_NAME) + await profile.delete(force=True) + + +@pytest.fixture(scope="module") +def task_attributes(): + return TaskAttributes( + instruction="Help the user. 
Question: {query}", + enable_human_tool=False, + ) + + +@pytest.fixture(scope="module") +async def task(task_attributes): + logger.info("Creating task: %s", PYSAI_TEAM_TASK_NAME) + task_obj = AsyncTask( + task_name=PYSAI_TEAM_TASK_NAME, + description="Test Task", + attributes=task_attributes, + ) + await task_obj.create(replace=True) + yield task_obj + logger.info("Deleting task: %s", PYSAI_TEAM_TASK_NAME) + await task_obj.delete(force=True) + + +@pytest.fixture(scope="module") +async def agent(python_gen_ai_profile): + logger.info("Creating agent: %s", PYSAI_TEAM_AGENT_NAME) + agent_obj = AsyncAgent( + agent_name=PYSAI_TEAM_AGENT_NAME, + description="Test Agent", + attributes=AgentAttributes( + profile_name=PYSAI_TEAM_PROFILE_NAME, + role="You are a helpful AI assistant", + enable_human_tool=False, + ), + ) + await agent_obj.create(enabled=True, replace=True) + yield agent_obj + logger.info("Deleting agent: %s", PYSAI_TEAM_AGENT_NAME) + await agent_obj.delete(force=True) + + +@pytest.fixture(scope="module") +def team_attributes(agent, task): + return TeamAttributes( + agents=[{"name": agent.agent_name, "task": task.task_name}], + process="sequential", + ) + + +@pytest.fixture(scope="module") +async def team(team_attributes): + logger.info("Creating team: %s", PYSAI_TEAM_NAME) + team_obj = AsyncTeam( + team_name=PYSAI_TEAM_NAME, + attributes=team_attributes, + description=PYSAI_TEAM_DESC, + ) + await team_obj.create(enabled=True, replace=True) + yield team_obj + logger.info("Deleting team: %s", PYSAI_TEAM_NAME) + await team_obj.delete(force=True) + + +# ----------------------------------------------------------------------------- +# Tests +# ----------------------------------------------------------------------------- + +async def test_3300_create_and_identity(team, team_attributes): + log_team_details("test_3300_create_and_identity", team) + assert team.team_name == PYSAI_TEAM_NAME + assert team.description == PYSAI_TEAM_DESC + assert team.attributes == team_attributes + + +@pytest.mark.parametrize("pattern", [None, ".*", "^PYSAI_TEAM_"]) +async def test_3301_list(pattern): + logger.info("Listing teams using pattern: %s", pattern) + teams = [t async for t in AsyncTeam.list(pattern)] if pattern else [t async for t in AsyncTeam.list()] + for t in teams: + if t.team_name == PYSAI_TEAM_NAME: + log_team_details("test_3301_list", t) + names = [t.team_name for t in teams] + assert PYSAI_TEAM_NAME in names + + +async def test_3302_fetch(team_attributes): + t = await AsyncTeam.fetch(PYSAI_TEAM_NAME) + log_team_details("test_3302_fetch", t) + assert t.attributes == team_attributes + + +async def test_3303_run(team): + response = await team.run( + prompt="What is 2+2?", + params={"conversation_id": str(uuid.uuid4())}, + ) + logger.info("Team run response: %s", response) + assert isinstance(response, str) + assert len(response) > 0 + + +async def test_3304_disable_enable_contract(team): + logger.info("Disabling team: %s", team.team_name) + await team.disable() + await assert_team_status(team.team_name, "DISABLED") + await expect_async_error("ORA-20053", lambda: team.disable()) + + logger.info("Enabling team: %s", team.team_name) + await team.enable() + await assert_team_status(team.team_name, "ENABLED") + await expect_async_error("ORA-20053", lambda: team.enable()) + + +async def test_3305_set_attribute_process(team): + await team.set_attribute("process", "sequential") + fetched = await AsyncTeam.fetch(PYSAI_TEAM_NAME) + log_team_details("test_3305_set_attribute_process", fetched) + assert 
fetched.attributes.process == "sequential"
+
+
+async def test_3306_set_attributes(team, agent, task):
+    new_attrs = TeamAttributes(
+        agents=[{"name": agent.agent_name, "task": task.task_name}],
+        process="sequential",
+    )
+    await team.set_attributes(new_attrs)
+    fetched = await AsyncTeam.fetch(PYSAI_TEAM_NAME)
+    log_team_details("test_3306_set_attributes", fetched)
+    assert fetched.attributes == new_attrs
+
+
+async def test_3307_replace_create(team_attributes):
+    team2 = AsyncTeam(PYSAI_TEAM_NAME, team_attributes, "REPLACED DESC")
+    await team2.create(enabled=True, replace=True)
+    fetched = await AsyncTeam.fetch(PYSAI_TEAM_NAME)
+    log_team_details("test_3307_replace_create", fetched)
+    assert fetched.description == "REPLACED DESC"
+
+
+async def test_3308_fetch_non_existing():
+    name = f"NO_SUCH_{uuid.uuid4().hex.upper()}"
+    await expect_async_error("NOT_FOUND", lambda: AsyncTeam.fetch(name))
+
+
+async def test_3311_set_attribute_invalid_key(team):
+    await expect_async_error("ORA-20053", lambda: team.set_attribute("no_such_attr", "x"))
+
+
+async def test_3312_set_attribute_none(team):
+    await expect_async_error("ORA-20053", lambda: team.set_attribute("process", None))
+
+
+async def test_3313_set_attribute_empty(team):
+    await expect_async_error("ORA-20053", lambda: team.set_attribute("process", ""))
+
+
+async def test_3314_set_attribute_invalid_value(team):
+    await expect_async_error(
+        "ORA-20053",
+        lambda: team.set_attribute("process", "not_a_real_process"),
+    )
+
+
+async def test_3315_disable_after_delete(team_attributes):
+    name = f"TMP_{uuid.uuid4().hex.upper()}"
+    t = AsyncTeam(name, team_attributes, "TMP")
+    await t.create()
+    await t.delete(force=True)
+    await expect_async_error("ORA-20053", lambda: t.disable())
+
+
+async def test_3316_enable_after_delete(team_attributes):
+    name = f"TMP_{uuid.uuid4().hex.upper()}"
+    t = AsyncTeam(name, team_attributes, "TMP")
+    await t.create()
+    await t.delete(force=True)
+    await expect_async_error("ORA-20053", lambda: t.enable())
+
+
+async def test_3317_set_attribute_after_delete(team_attributes):
+    name = f"TMP_{uuid.uuid4().hex.upper()}"
+    t = AsyncTeam(name, team_attributes, "TMP")
+    await t.create()
+    await t.delete(force=True)
+    await expect_async_error("ORA-20053", lambda: t.set_attribute("process", "sequential"))
+
+
+async def test_3318_double_delete(team_attributes):
+    name = f"TMP_{uuid.uuid4().hex.upper()}"
+    t = AsyncTeam(name, team_attributes, "TMP")
+    await t.create()
+    await t.delete(force=True)
+    await expect_async_error("ORA-20053", lambda: t.delete(force=False))
+
+
+async def test_3319_create_existing_without_replace(team_attributes):
+    name = f"TMP_{uuid.uuid4().hex.upper()}"
+    t1 = AsyncTeam(name, team_attributes, "TMP1")
+    await t1.create(replace=False)
+    await expect_async_error(
+        "ORA-20053",
+        lambda: AsyncTeam(name, team_attributes, "TMP2").create(replace=False),
+    )
+    await t1.delete(force=True)
+
+
+async def test_3320_fetch_after_delete(team_attributes):
+    name = f"TMP_{uuid.uuid4().hex.upper()}"
+    t = AsyncTeam(name, team_attributes, "TMP")
+    await t.create()
+    await t.delete(force=True)
+    await expect_async_error("NOT_FOUND", lambda: AsyncTeam.fetch(name))
+
diff --git a/tests/agents/test_3301_teams.py b/tests/agents/test_3301_teams.py
new file mode 100644
index 0000000..66398f9
--- /dev/null
+++ b/tests/agents/test_3301_teams.py
@@ -0,0 +1,415 @@
+# -----------------------------------------------------------------------------
+# Copyright (c) 2025, Oracle and/or its affiliates.
+#
+# Licensed under the Universal Permissive License v 1.0 as shown at
+# http://oss.oracle.com/licenses/upl.
+# -----------------------------------------------------------------------------
+
+"""
+3301 - Final contract, regression and corner-case tests for select_ai.agent.Team
+"""
+
+import logging
+import os
+import uuid
+
+import oracledb
+import pytest
+import select_ai
+from select_ai.agent import (
+    Agent,
+    AgentAttributes,
+    Task,
+    TaskAttributes,
+    Team,
+    TeamAttributes,
+)
+from select_ai.errors import AgentTeamNotFoundError
+
+# -----------------------------------------------------------------------------
+# Logging
+# -----------------------------------------------------------------------------
+
+PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))
+LOG_DIR = os.path.join(PROJECT_ROOT, "log")
+os.makedirs(LOG_DIR, exist_ok=True)
+
+LOG_FILE = os.path.join(LOG_DIR, "tkex_test_3301_teams.log")
+
+root = logging.getLogger()
+root.setLevel(logging.INFO)
+
+for h in root.handlers[:]:
+    root.removeHandler(h)
+
+fh = logging.FileHandler(LOG_FILE, mode="w")
+fh.setFormatter(logging.Formatter("%(levelname)s: %(message)s"))
+root.addHandler(fh)
+
+LOGGER = logging.getLogger(__name__)
+LOGGER.setLevel(logging.INFO)
+
+
+def log_step(msg):
+    LOGGER.info("%s", msg)
+
+
+def log_ok(msg):
+    LOGGER.info("%s", msg)
+
+
+logger = LOGGER
+
+
+# -----------------------------------------------------------------------------
+# Per-test logging
+# -----------------------------------------------------------------------------
+
+@pytest.fixture(autouse=True)
+def log_test_name(request):
+    logger.info(f"--- Starting test: {request.function.__name__} ---")
+    yield
+    logger.info(f"--- Finished test: {request.function.__name__} ---")
+
+
+# -----------------------------------------------------------------------------
+# Strict error checker (same contract as 3101 / 3201)
+# -----------------------------------------------------------------------------
+
+def expect_error(expected_code, fn):
+    """
+    expected_code:
+      - "NOT_FOUND"
+      - "ORA-20053"
+      - "ORA-xxxxx"
+    """
+    try:
+        fn()
+    except AgentTeamNotFoundError as e:
+        LOGGER.info("Expected failure (NOT_FOUND): %s", e)
+        assert expected_code == "NOT_FOUND"
+    except oracledb.DatabaseError as e:
+        msg = str(e)
+        LOGGER.info("Expected Oracle failure: %s", msg)
+        assert expected_code in msg, f"Expected {expected_code}, got: {msg}"
+    except Exception as e:
+        LOGGER.info("Expected generic failure: %s", e)
+        assert expected_code in str(e), f"Expected {expected_code}, got: {e}"
+    else:
+        pytest.fail(f"Expected error {expected_code} did not occur")
+
+# -----------------------------------------------------------------------------
+# Test constants
+# -----------------------------------------------------------------------------
+
+PYSAI_TEAM_AGENT_NAME = f"PYSAI_TEAM_AGENT_{uuid.uuid4().hex.upper()}"
+PYSAI_TEAM_PROFILE_NAME = f"PYSAI_TEAM_PROFILE_{uuid.uuid4().hex.upper()}"
+PYSAI_TEAM_TASK_NAME = f"PYSAI_TEAM_TASK_{uuid.uuid4().hex.upper()}"
+PYSAI_TEAM_NAME = f"PYSAI_TEAM_{uuid.uuid4().hex.upper()}"
+PYSAI_TEAM_DESC = "PYSAI TEAM FINAL CONTRACT TEST"
+
+# -----------------------------------------------------------------------------
+# Fixtures
+# 
----------------------------------------------------------------------------- + +# @pytest.fixture(scope="module", autouse=True) +# def _connect(): +# select_ai.connect() +# yield +# select_ai.disconnect() + +# @pytest.fixture(scope="module") +# def profile_attributes(): +# return { +# "provider": "oci_genai", +# "model": "cohere.command-r-plus" +# } + +@pytest.fixture(scope="module") +def python_gen_ai_profile(profile_attributes): + log_step(f"Creating profile: {PYSAI_TEAM_PROFILE_NAME}") + + oci_compartment_id = os.getenv("PYSAI_TEST_OCI_COMPARTMENT_ID") + if not oci_compartment_id: + raise RuntimeError("PYSAI_TEST_OCI_COMPARTMENT_ID not set") + + # ---- EXTEND existing ProfileAttributes object ---- + profile_attributes.oci_compartment_id = oci_compartment_id + + # ---- STRICT TYPE CHECK ---- + assert isinstance( + profile_attributes, + select_ai.ProfileAttributes + ), "profile_attributes must be ProfileAttributes object" + + profile = select_ai.Profile( + profile_name=PYSAI_TEAM_PROFILE_NAME, + description="OCI GENAI Profile", + attributes=profile_attributes, # <-- pass object, NOT dict + ) + + profile.create(replace=True) + + yield profile + + log_step(f"Deleting profile: {PYSAI_TEAM_PROFILE_NAME}") + profile.delete(force=True) + + +@pytest.fixture(scope="module") +def task_attributes(): + return TaskAttributes( + instruction="Help the user. Question: {query}", + enable_human_tool=False, + ) + +@pytest.fixture(scope="module") +def task(task_attributes): + log_step(f"Creating task: {PYSAI_TEAM_TASK_NAME}") + task = Task( + task_name=PYSAI_TEAM_TASK_NAME, + description="Test Task", + attributes=task_attributes, + ) + task.create(replace=True) + yield task + log_step(f"Deleting task: {PYSAI_TEAM_TASK_NAME}") + task.delete(force=True) + +@pytest.fixture(scope="module") +def agent(python_gen_ai_profile): + log_step(f"Creating agent: {PYSAI_TEAM_AGENT_NAME}") + agent = Agent( + agent_name=PYSAI_TEAM_AGENT_NAME, + description="Test Agent", + attributes=AgentAttributes( + profile_name=PYSAI_TEAM_PROFILE_NAME, + role="You are a helpful AI assistant", + enable_human_tool=False, + ), + ) + agent.create(enabled=True, replace=True) + yield agent + log_step(f"Deleting agent: {PYSAI_TEAM_AGENT_NAME}") + agent.delete(force=True) + +@pytest.fixture(scope="module") +def team_attributes(agent, task): + return TeamAttributes( + agents=[{"name": agent.agent_name, "task": task.task_name}], + process="sequential", + ) + +@pytest.fixture(scope="module") +def team(team_attributes): + log_step(f"Creating team: {PYSAI_TEAM_NAME}") + team = Team( + team_name=PYSAI_TEAM_NAME, + attributes=team_attributes, + description=PYSAI_TEAM_DESC, + ) + team.create(enabled=True, replace=True) + yield team + log_step(f"Deleting team: {PYSAI_TEAM_NAME}") + team.delete(force=True) + +# ----------------------------------------------------------------------------- +# Tests +# ----------------------------------------------------------------------------- + +# ----------------------------------------------------------------------------- +# Logging-enhanced Team tests +# ----------------------------------------------------------------------------- + +def test_3300_create_and_identity(team, team_attributes): + log_step("Validating team identity and attributes") + log_step(f"Team name: {team.team_name}") + log_step(f"Team description: {team.description}") + log_step(f"Team attributes: {team.attributes}") + assert team.team_name == PYSAI_TEAM_NAME + assert team.description == PYSAI_TEAM_DESC + assert team.attributes == team_attributes + 
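    # Note: the equality assertion above relies on TeamAttributes providing
    # value-based equality (dataclass-style __eq__), so the fetched attributes
    # are compared field by field against the ones supplied at creation time.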
log_ok("Team identity and attributes OK") + + +@pytest.mark.parametrize("pattern", [None, ".*", "^PYSAI_TEAM_"]) +def test_3301_list(pattern): + log_step(f"Listing teams using pattern: {pattern}") + teams = list(Team.list(pattern)) if pattern else list(Team.list()) + names = [t.team_name for t in teams] + log_step(f"Teams found: {names}") + assert PYSAI_TEAM_NAME in names + log_ok("Team found in list") + + +def test_3302_fetch(team_attributes): + log_step(f"Fetching team: {PYSAI_TEAM_NAME}") + t = Team.fetch(PYSAI_TEAM_NAME) + log_step(f"Fetched team attributes: {t.attributes}") + assert t.attributes == team_attributes + log_ok("Fetch OK") + + +def test_3303_run(team): + log_step(f"Running team: {team.team_name}") + response = team.run("What is 2+2?", {"conversation_id": str(uuid.uuid4())}) + log_step(f"Team run response: {response}") + assert isinstance(response, str) + assert len(response) > 0 + log_ok("Run OK") + + +def test_3304_disable_enable_contract(team): + log_step(f"Disabling team: {team.team_name}") + team.disable() + log_step("Team disabled successfully") + expect_error("ORA-20053", lambda: team.disable()) + log_step(f"Enabling team: {team.team_name}") + team.enable() + log_step("Team enabled successfully") + expect_error("ORA-20053", lambda: team.enable()) + + +def test_3305_set_attribute_process(team): + log_step(f"Setting team attribute 'process' to 'sequential': {team.team_name}") + team.set_attribute("process", "sequential") + fetched = Team.fetch(PYSAI_TEAM_NAME) + log_step(f"Fetched attribute process: {fetched.attributes.process}") + assert fetched.attributes.process == "sequential" + log_ok("Set attribute OK") + + +def test_3306_set_attributes(team, agent, task): + new_attrs = TeamAttributes( + agents=[{"name": agent.agent_name, "task": task.task_name}], + process="sequential", + ) + log_step(f"Replacing team attributes: {team.team_name}") + log_step(f"New attributes: {new_attrs}") + team.set_attributes(new_attrs) + fetched = Team.fetch(PYSAI_TEAM_NAME) + log_step(f"Fetched attributes after replace: {fetched.attributes}") + assert fetched.attributes == new_attrs + log_ok("Set attributes OK") + + +def test_3307_replace_create(team_attributes): + log_step(f"Replacing existing team: {PYSAI_TEAM_NAME}") + team2 = Team(PYSAI_TEAM_NAME, team_attributes, "REPLACED DESC") + team2.create(enabled=True, replace=True) + fetched = Team.fetch(PYSAI_TEAM_NAME) + log_step(f"Fetched team description after replace: {fetched.description}") + assert fetched.description == "REPLACED DESC" + log_ok("Replace OK") + + +def test_3308_fetch_non_existing(): + name = f"NO_SUCH_{uuid.uuid4().hex}" + log_step(f"Fetching non-existing team: {name}") + expect_error("NOT_FOUND", lambda: Team.fetch(name)) + log_ok("Fetch non-existing confirmed error") + + +def test_3311_set_attribute_invalid_key(team): + log_step(f"Setting invalid attribute key on team: {team.team_name}") + expect_error("ORA-20053", lambda: team.set_attribute("no_such_attr", "x")) + log_ok("Set invalid attribute confirmed error") + + +def test_3312_set_attribute_none(team): + log_step(f"Setting team attribute 'process' to None: {team.team_name}") + expect_error("ORA-20053", lambda: team.set_attribute("process", None)) + log_ok("Set attribute None confirmed error") + + +def test_3313_set_attribute_empty(team): + log_step(f"Setting team attribute 'process' to empty string: {team.team_name}") + expect_error("ORA-20053", lambda: team.set_attribute("process", "")) + log_ok("Set attribute empty confirmed error") + + +def 
test_3314_set_attribute_invalid_value(team): + log_step(f"Setting team attribute 'process' to invalid value: {team.team_name}") + expect_error("ORA-20053", lambda: team.set_attribute("process", "not_a_real_process")) + log_ok("Set attribute invalid value confirmed error") + + +def test_3315_disable_after_delete(team_attributes): + name = f"TMP_{uuid.uuid4().hex}" + log_step(f"Creating temporary team: {name}") + t = Team(name, team_attributes, "TMP") + t.create() + log_step(f"Deleting temporary team: {name}") + t.delete(force=True) + log_step(f"Attempting to disable deleted team: {name}") + expect_error("ORA-20053", lambda: t.disable()) + log_ok("Disable after delete confirmed error") + + +def test_3316_enable_after_delete(team_attributes): + name = f"TMP_{uuid.uuid4().hex}" + log_step(f"Creating temporary team: {name}") + t = Team(name, team_attributes, "TMP") + t.create() + log_step(f"Deleting temporary team: {name}") + t.delete(force=True) + log_step(f"Attempting to enable deleted team: {name}") + expect_error("ORA-20053", lambda: t.enable()) + log_ok("Enable after delete confirmed error") + + +def test_3317_set_attribute_after_delete(team_attributes): + name = f"TMP_{uuid.uuid4().hex}" + log_step(f"Creating temporary team: {name}") + t = Team(name, team_attributes, "TMP") + t.create() + log_step(f"Deleting temporary team: {name}") + t.delete(force=True) + log_step(f"Attempting to set attribute on deleted team: {name}") + expect_error("ORA-20053", lambda: t.set_attribute("process", "sequential")) + log_ok("Set attribute after delete confirmed error") + + +def test_3318_double_delete(team_attributes): + name = f"TMP_{uuid.uuid4().hex}" + log_step(f"Creating temporary team: {name}") + t = Team(name, team_attributes, "TMP") + t.create() + log_step(f"Deleting team first time: {name}") + t.delete(force=True) + log_step(f"Deleting team second time: {name}") + expect_error("ORA-20053", lambda: t.delete(force=False)) + log_ok("Double delete confirmed error") + + +def test_3319_create_existing_without_replace(team_attributes): + name = f"TMP_{uuid.uuid4().hex}" + log_step(f"Creating team: {name}") + t1 = Team(name, team_attributes, "TMP1") + t1.create(replace=False) + log_step(f"Attempting to create existing team without replace: {name}") + expect_error("ORA-20053", lambda: Team(name, team_attributes, "TMP2").create(replace=False)) + t1.delete(force=True) + log_ok("Create existing without replace confirmed error") + + +def test_3320_fetch_after_delete(team_attributes): + name = f"TMP_{uuid.uuid4().hex}" + log_step(f"Creating temporary team: {name}") + t = Team(name, team_attributes, "TMP") + t.create() + log_step(f"Deleting temporary team: {name}") + t.delete(force=True) + log_step(f"Fetching deleted team: {name}") + expect_error("NOT_FOUND", lambda: Team.fetch(name)) + log_ok("Fetch after delete confirmed error") diff --git a/tests/agents/test_3800_agente2e.py b/tests/agents/test_3800_agente2e.py new file mode 100644 index 0000000..634d21e --- /dev/null +++ b/tests/agents/test_3800_agente2e.py @@ -0,0 +1,436 @@ +# 
----------------------------------------------------------------------------- +# Copyright (c) 2025, Oracle and/or its affiliates. +# +# Licensed under the Universal Permissive License v 1.0 as shown at +# http://oss.oracle.com/licenses/upl. +# ----------------------------------------------------------------------------- + +import uuid +import time +import os +import logging +import pytest +from contextlib import contextmanager + +import select_ai +from select_ai.agent import ( + Agent, + AgentAttributes, + Task, + TaskAttributes, + Team, + TeamAttributes, + Tool, + ToolParams, + ToolAttributes, +) + +# ---------------------------------------------------------------------- +# LOGGING +# ---------------------------------------------------------------------- + + +# Path +PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")) +LOG_FILE = os.path.join(PROJECT_ROOT, "log", "tkex_test_3800_agente2e.log") +os.makedirs(os.path.dirname(LOG_FILE), exist_ok=True) + +# Force logging to file (pytest-proof) +root = logging.getLogger() +root.setLevel(logging.INFO) + +for h in root.handlers[:]: + root.removeHandler(h) + +fh = logging.FileHandler(LOG_FILE, mode="w") +fh.setFormatter(logging.Formatter("%(levelname)s: %(message)s")) +root.addHandler(fh) + +logger = logging.getLogger() + + +@contextmanager +def log_step(step): + logger.info("START: %s", step) + start = time.time() + try: + yield + logger.info("END: %s (%.2fs)", step, time.time() - start) + except Exception: + logger.exception("FAILED: %s", step) + raise + + +def _safe_dict(obj): + if obj is None: + return None + if hasattr(obj, "dict"): + try: + return obj.dict(exclude_null=False) + except TypeError: + return obj.dict() + return str(obj) + + +def log_object_details(context: str, object_type: str, obj) -> None: + details = {"context": context, "object_type": object_type} + + if object_type == "profile": + details.update( + { + "profile_name": getattr(obj, "profile_name", None), + "description": getattr(obj, "description", None), + "attributes": _safe_dict(getattr(obj, "attributes", None)), + } + ) + elif object_type == "agent": + details.update( + { + "agent_name": getattr(obj, "agent_name", None), + "description": getattr(obj, "description", None), + "attributes": _safe_dict(getattr(obj, "attributes", None)), + } + ) + elif object_type == "tool": + details.update( + { + "tool_name": getattr(obj, "tool_name", None), + "description": getattr(obj, "description", None), + "attributes": _safe_dict(getattr(obj, "attributes", None)), + } + ) + elif object_type == "task": + details.update( + { + "task_name": getattr(obj, "task_name", None), + "description": getattr(obj, "description", None), + "attributes": _safe_dict(getattr(obj, "attributes", None)), + } + ) + elif object_type == "team": + details.update( + { + "team_name": getattr(obj, "team_name", None), + "description": getattr(obj, "description", None), + "attributes": _safe_dict(getattr(obj, "attributes", None)), + } + ) + else: + details["repr"] = str(obj) + + logger.info("OBJECT_DETAILS: %s", details) + print("OBJECT_DETAILS:", details) + + +@pytest.fixture(scope="session") +def setup_test_user(test_env): + try: + select_ai.disconnect() + except Exception: + pass + + select_ai.connect(**test_env.connect_params(admin=True)) + try: + try: + select_ai.grant_privileges(users=[test_env.test_user]) + except Exception as exc: + msg = str(exc) + if ( + "ORA-01749" not in msg + and "Cannot GRANT or REVOKE privileges to or from yourself" not in msg + ): + raise + + 
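        # The except branch above tolerates ORA-01749, which Oracle raises when
        # the admin account attempts to grant privileges to itself; any other
        # failure is re-raised. HTTP access to the provider endpoint must still
        # be granted below, or the test user cannot reach the LLM endpoint from
        # inside the database.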
select_ai.grant_http_access( + users=[test_env.test_user], + provider_endpoint=select_ai.OpenAIProvider.provider_endpoint, + ) + finally: + select_ai.disconnect() + select_ai.connect(**test_env.connect_params()) + + +@pytest.fixture(scope="session") +def openai_cred(): + api_key = os.getenv("PYSAI_TEST_OPENAI_API_KEY") + assert api_key, "PYSAI_TEST_OPENAI_API_KEY not set" + + select_ai.create_credential( + credential={ + "credential_name": "OPENAI_CRED", + "username": "openai", + "password": api_key, + }, + replace=True, + ) + + return "OPENAI_CRED" + + +@pytest.fixture(scope="session") +def email_cred(): + smtp_username = os.getenv("PYSAI_TEST_EMAIL_CRED_USERNAME") + smtp_password = os.getenv("PYSAI_TEST_EMAIL_CRED_PASSWORD") + + assert smtp_username, "PYSAI_TEST_EMAIL_CRED_USERNAME not set" + assert smtp_password, "PYSAI_TEST_EMAIL_CRED_PASSWORD not set" + + select_ai.create_credential( + credential={ + "credential_name": "EMAIL_CRED", + "username": smtp_username, + "password": smtp_password, + }, + replace=True, + ) + + return "EMAIL_CRED" + + +@pytest.fixture(scope="session") +def allow_network_acl(): + with select_ai.cursor() as cur: + cur.execute("SELECT USER FROM dual") + db_user = cur.fetchone()[0] + + def append_ace(host, privileges): + try: + cur.execute( + f""" + BEGIN + DBMS_NETWORK_ACL_ADMIN.APPEND_HOST_ACE( + host => '{host}', + ace => xs$ace_type( + privilege_list => xs$name_list({','.join([f"'{p}'" for p in privileges])}), + principal_name => '{db_user}', + principal_type => xs_acl.ptype_db + ) + ); + END; + """ + ) + except Exception as exc: + msg = str(exc) + if ( + "ORA-46212" in msg + or "ORA-46313" in msg + or "already exists" in msg + ): + return + raise + + append_ace( + "smtp.email.us-ashburn-1.oci.oraclecloud.com", + ["connect", "smtp"], + ) + + for host in ["api.openai.com", "a.co","amazon.in"]: + append_ace(host, ["connect", "http"]) + + yield + + +# ---------------------------------------------------------------------- +# MAIN TEST +# ---------------------------------------------------------------------- + +def test_3800_agent_end_to_end( + profile_attributes, setup_test_user, openai_cred, email_cred, allow_network_acl +): + """ + End-to-end Select AI Agent integration test. 
+ + """ + + # ------------------------------- + # PROFILE + # ------------------------------- + logger.info("Starting End-to-End Agent Test") + + # ---------------- PROFILE ---------------- + + oci_compartment_id = os.getenv("PYSAI_TEST_OCI_COMPARTMENT_ID") + assert oci_compartment_id, "PYSAI_TEST_OCI_COMPARTMENT_ID not set" + + profile_attributes.provider.oci_compartment_id = oci_compartment_id + + profile = select_ai.Profile( + profile_name="GEN1_PROFILE", + attributes=profile_attributes, + replace=True, + ) + log_object_details("create_profile", "profile", profile) + + # ------------------------------- + # AGENT + # ------------------------------- + with log_step("Create agent"): + agent = Agent( + agent_name="CustomerAgent", + attributes=AgentAttributes( + profile_name="GEN1_PROFILE", + role="You are an experienced customer agent handling returns.", + enable_human_tool=True, + ), + ) + agent.create(replace=True) + log_object_details("create_agent", "agent", agent) + + assert agent.agent_name == "CustomerAgent" + + # ------------------------------- + # TOOLS + # ------------------------------- + with log_step("Create tools"): + + # Human tool + Tool.create_built_in_tool( + tool_name="Human", + description="Human intervention tool", + tool_type="HUMAN", + tool_params=ToolParams(), + replace=True, + ) + + websearch_tool = Tool( + tool_name="Websearch", + attributes=ToolAttributes( + tool_type="WEBSEARCH", + instruction="Use this tool to find the current price of a product from a URL.", + tool_params=ToolParams( + credential_name="OPENAI_CRED" + ), + ), + ) + websearch_tool.create(replace=True) + log_object_details("create_websearch_tool", "tool", websearch_tool) + + # Email notification tool + email_recipient = os.getenv("PYSAI_TEST_EMAIL_RECIPIENT") + email_sender = os.getenv("PYSAI_TEST_EMAIL_SENDER") + assert email_recipient, "PYSAI_TEST_EMAIL_RECIPIENT not set" + assert email_sender, "PYSAI_TEST_EMAIL_SENDER not set" + email_tool = Tool( + tool_name="Email", + attributes=ToolAttributes( + tool_type="NOTIFICATION", + tool_params=ToolParams( + credential_name="EMAIL_CRED", + notification_type="EMAIL", + recipient=email_recipient, + sender=email_sender, + smtp_host="smtp.email.us-ashburn-1.oci.oraclecloud.com", + ), + ), + ) + email_tool.create(replace=True) + log_object_details("create_email_tool", "tool", email_tool) + + assert Tool("Human") is not None + assert Tool("Email") is not None + + # ------------------------------- + # TASK + # ------------------------------- + with log_step("Create task"): + task = Task( + task_name="Return_And_Price_Match", + attributes=TaskAttributes( + instruction=( + "Process a product return request from a customer. " + "1. Ask customer the reason for return (price match or defective). " + "2. If price match: " + " a. Request customer to provide a price match link. " + " b. Use websearch tool to get the price for that price match link" + " c. Ask customer if they want a refund and specify how much refund. " + " d. Send email notification only if customer accepts the refund. " + "3. If defective: " + " a. Process the defective return." 
+ ), + tools=["Human", "Websearch", "Email"], + ), + ) + task.create(replace=True) + log_object_details("create_task", "task", task) + + assert task.task_name == "Return_And_Price_Match" + assert set(task.attributes.tools) == {"Human", "Websearch", "Email"} + + # ------------------------------- + # TEAM + # ------------------------------- + with log_step("Create team"): + team = Team( + team_name="ReturnAgency", + attributes=TeamAttributes( + agents=[{ + "name": "CustomerAgent", + "task": "Return_And_Price_Match", + }], + process="sequential", + ), + ) + team.create(enabled=True, replace=True) + log_object_details("create_team", "team", team) + + assert team.team_name == "ReturnAgency" + + # ------------------------------- + # RUN CONVERSATION + # ------------------------------- + with log_step("Run agent conversation"): + conversation_id = str(uuid.uuid4()) + + prompts = [ + "I want to return an office chair", + "The price when I bought it is 100. But I found a cheaper price", + "Here is the price match link 'https://www.ikea.com/us/en/p/stefan-chair-brown-black-00211088/'", + "Yes, I would like to proceed with a refund", + ] + + for idx, prompt in enumerate(prompts, start=1): + logger.info("USER %d: %s", idx, prompt) + + response = team.run( + prompt=prompt, + params={"conversation_id": conversation_id}, + ) + + # ---- PRINT + LOG RESPONSE ---- + print(f"\nAGENT RESPONSE {idx}:\n{response}\n") + logger.info("AGENT RESPONSE %d: %s", idx, response) + + assert response is not None + assert isinstance(response, (str, dict)) + + if isinstance(response, dict): + assert response + + with select_ai.cursor() as cur: + cur.execute( + """ + SELECT * FROM user_ai_agent_tool_history + """ + ) + tool_history = cur.fetchall() + + decoded_tool_history = [] + for row in tool_history: + decoded_row = [] + for value in row: + if hasattr(value, "read"): + decoded_row.append(value.read()) + else: + decoded_row.append(value) + decoded_tool_history.append(tuple(decoded_row)) + + print(decoded_tool_history) + logger.info("Tool history rows fetched: %d", len(decoded_tool_history)) + for row in decoded_tool_history: + logger.info("TOOL_HISTORY_ROW: %s", row) + + assert decoded_tool_history diff --git a/tests/agents/test_3800_async_agente2e.py b/tests/agents/test_3800_async_agente2e.py new file mode 100644 index 0000000..191f777 --- /dev/null +++ b/tests/agents/test_3800_async_agente2e.py @@ -0,0 +1,345 @@ +# ----------------------------------------------------------------------------- +# Copyright (c) 2025, Oracle and/or its affiliates. +# +# Licensed under the Universal Permissive License v 1.0 as shown at +# http://oss.oracle.com/licenses/upl. 
+# ----------------------------------------------------------------------------- + +""" +3800 - Async end-to-end Select AI Agent integration test +""" + +import logging +import os +import time +import uuid +from contextlib import contextmanager + +import pytest +import select_ai +from select_ai.agent import ( + AgentAttributes, + AsyncAgent, + AsyncTask, + AsyncTeam, + AsyncTool, + TaskAttributes, + TeamAttributes, + ToolAttributes, + ToolParams, +) + +pytestmark = pytest.mark.anyio + +# ---------------------------------------------------------------------- +# LOGGING +# ---------------------------------------------------------------------- + +PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")) +LOG_FILE = os.path.join(PROJECT_ROOT, "log", "tkex_test_3800_async_agente2e.log") +os.makedirs(os.path.dirname(LOG_FILE), exist_ok=True) + +root = logging.getLogger() +root.setLevel(logging.INFO) +for h in root.handlers[:]: + root.removeHandler(h) + +fh = logging.FileHandler(LOG_FILE, mode="w") +fh.setFormatter(logging.Formatter("%(levelname)s: %(message)s")) +root.addHandler(fh) + +logger = logging.getLogger(__name__) + + +@contextmanager +def log_step(step): + logger.info("START: %s", step) + start = time.time() + try: + yield + logger.info("END: %s (%.2fs)", step, time.time() - start) + except Exception: + logger.exception("FAILED: %s", step) + raise + + +def _safe_dict(obj): + if obj is None: + return None + if hasattr(obj, "dict"): + try: + return obj.dict(exclude_null=False) + except TypeError: + return obj.dict() + return str(obj) + + +def log_object_details(context: str, object_type: str, obj) -> None: + details = {"context": context, "object_type": object_type} + + if object_type == "profile": + details.update( + { + "profile_name": getattr(obj, "profile_name", None), + "description": getattr(obj, "description", None), + "attributes": _safe_dict(getattr(obj, "attributes", None)), + } + ) + elif object_type == "agent": + details.update( + { + "agent_name": getattr(obj, "agent_name", None), + "description": getattr(obj, "description", None), + "attributes": _safe_dict(getattr(obj, "attributes", None)), + } + ) + elif object_type == "tool": + details.update( + { + "tool_name": getattr(obj, "tool_name", None), + "description": getattr(obj, "description", None), + "attributes": _safe_dict(getattr(obj, "attributes", None)), + } + ) + elif object_type == "task": + details.update( + { + "task_name": getattr(obj, "task_name", None), + "description": getattr(obj, "description", None), + "attributes": _safe_dict(getattr(obj, "attributes", None)), + } + ) + elif object_type == "team": + details.update( + { + "team_name": getattr(obj, "team_name", None), + "description": getattr(obj, "description", None), + "attributes": _safe_dict(getattr(obj, "attributes", None)), + } + ) + else: + details["repr"] = str(obj) + + logger.info("OBJECT_DETAILS: %s", details) + print("OBJECT_DETAILS:", details) + + +@pytest.fixture(scope="module", autouse=True) +async def async_connect(test_env): + logger.info( + "Opening async admin database connection | user=%s | dsn=%s", + test_env.admin_user, + test_env.connect_string, + ) + await select_ai.async_connect(**test_env.connect_params(admin=True)) + yield + logger.info("Closing async admin database connection") + await select_ai.async_disconnect() + + +async def test_3800_agent_end_to_end_async(profile_attributes): + """End-to-end Select AI Agent integration test (async).""" + + run_id = uuid.uuid4().hex.upper() + + profile_name = 
f"GEN1_PROFILE_{run_id}" + agent_name = f"CustomerAgent_{run_id}" + human_tool_name = f"Human_{run_id}" + websearch_tool_name = f"Websearch_{run_id}" + email_tool_name = f"Email_{run_id}" + task_name = f"Return_And_Price_Match_{run_id}" + team_name = f"ReturnAgency_{run_id}" + + created = { + "team": None, + "task": None, + "tools": [], + "agent": None, + "profile": None, + } + + logger.info("Starting async End-to-End Agent Test") + logger.info( + "Run identifiers | profile=%s agent=%s task=%s team=%s", + profile_name, + agent_name, + task_name, + team_name, + ) + + oci_compartment_id = os.getenv("PYSAI_TEST_OCI_COMPARTMENT_ID") + assert oci_compartment_id, "PYSAI_TEST_OCI_COMPARTMENT_ID not set" + + profile_attributes.provider.oci_compartment_id = oci_compartment_id + + try: + with log_step("Create profile"): + profile = await select_ai.AsyncProfile( + profile_name=profile_name, + attributes=profile_attributes, + replace=True, + ) + created["profile"] = profile + logger.info("Created profile: %s", profile.profile_name) + log_object_details("create_profile", "profile", profile) + + with log_step("Create agent"): + agent = AsyncAgent( + agent_name=agent_name, + attributes=AgentAttributes( + profile_name=profile_name, + role=( + "You are an experienced customer agent handling returns." + ), + enable_human_tool=True, + ), + ) + await agent.create(replace=True) + created["agent"] = agent + logger.info("Created agent: %s", agent.agent_name) + log_object_details("create_agent", "agent", agent) + assert agent.agent_name == agent_name + + with log_step("Create tools"): + human_tool = await AsyncTool.create_built_in_tool( + tool_name=human_tool_name, + description="Human intervention tool", + tool_type=select_ai.agent.ToolType.HUMAN, + tool_params=ToolParams(), + replace=True, + ) + created["tools"].append(human_tool) + + websearch_tool = AsyncTool( + tool_name=websearch_tool_name, + attributes=ToolAttributes( + tool_type=select_ai.agent.ToolType.WEBSEARCH, + instruction=( + "Use this tool to find current product price from a URL." + ), + tool_params=ToolParams(credential_name="OPENAI_CRED"), + ), + ) + await websearch_tool.create(replace=True) + created["tools"].append(websearch_tool) + log_object_details("create_websearch_tool", "tool", websearch_tool) + + email_recipient = os.getenv("PYSAI_TEST_EMAIL_RECIPIENT") + email_sender = os.getenv("PYSAI_TEST_EMAIL_SENDER") + assert email_recipient, "PYSAI_TEST_EMAIL_RECIPIENT not set" + assert email_sender, "PYSAI_TEST_EMAIL_SENDER not set" + email_tool = AsyncTool( + tool_name=email_tool_name, + attributes=ToolAttributes( + tool_type=select_ai.agent.ToolType.NOTIFICATION, + tool_params=ToolParams( + credential_name="EMAIL_CRED", + notification_type="EMAIL", + recipient=email_recipient, + sender=email_sender, + smtp_host="smtp.email.us-ashburn-1.oci.oraclecloud.com", + ), + ), + ) + await email_tool.create(replace=True) + created["tools"].append(email_tool) + log_object_details("create_email_tool", "tool", email_tool) + + logger.info( + "Created tools: %s", + [t.tool_name for t in created["tools"]], + ) + log_object_details("create_human_tool", "tool", human_tool) + assert len(created["tools"]) == 3 + + with log_step("Create task"): + task = AsyncTask( + task_name=task_name, + attributes=TaskAttributes( + instruction=( + "Process a product return request from a customer. " + "1. Ask customer reason for return (price match or defective). " + "2. If price match: request link, use websearch, ask refund amount, " + "send email only if accepted. " + "3. 
If defective: process defective return." + ), + tools=[human_tool_name, websearch_tool_name, email_tool_name], + ), + ) + await task.create(replace=True) + created["task"] = task + + logger.info("Created task: %s", task.task_name) + log_object_details("create_task", "task", task) + assert task.task_name == task_name + assert set(task.attributes.tools) == { + human_tool_name, + websearch_tool_name, + email_tool_name, + } + + with log_step("Create team"): + team = AsyncTeam( + team_name=team_name, + attributes=TeamAttributes( + agents=[ + { + "name": agent_name, + "task": task_name, + } + ], + process="sequential", + ), + ) + await team.create(enabled=True, replace=True) + created["team"] = team + + logger.info("Created team: %s", team.team_name) + log_object_details("create_team", "team", team) + assert team.team_name == team_name + + with log_step("Run async agent conversation"): + conversation_id = str(uuid.uuid4()) + prompts = [ + "I want to return an office chair", + "The price when I bought it is 100. I found a cheaper price", + "Price match link https://www.ikea.com/us/en/p/stefan-chair-brown-black-00211088/", + "Yes, I would like to proceed with a refund", + "If you have not started the refund, please do", + ] + + for idx, prompt in enumerate(prompts, start=1): + logger.info("USER %d: %s", idx, prompt) + response = await team.run( + prompt=prompt, + params={"conversation_id": conversation_id}, + ) + + print(f"\nASYNC AGENT RESPONSE {idx}:\n{response}\n") + logger.info("ASYNC AGENT RESPONSE %d: %s", idx, response) + + assert response is not None + assert isinstance(response, str) + assert len(response.strip()) > 0 + + finally: + with log_step("Cleanup async e2e objects"): + if created["team"] is not None: + logger.info("Deleting team: %s", created["team"].team_name) + await created["team"].delete(force=True) + + if created["task"] is not None: + logger.info("Deleting task: %s", created["task"].task_name) + await created["task"].delete(force=True) + + for tool in reversed(created["tools"]): + logger.info("Deleting tool: %s", tool.tool_name) + await tool.delete(force=True) + + if created["agent"] is not None: + logger.info("Deleting agent: %s", created["agent"].agent_name) + await created["agent"].delete(force=True) + + if created["profile"] is not None: + logger.info("Deleting profile: %s", created["profile"].profile_name) + await created["profile"].delete(force=True) From ebfa1e292e76c62ec286297b24953e18db2f8555 Mon Sep 17 00:00:00 2001 From: Vishwas Mor Date: Wed, 25 Feb 2026 10:01:22 +0000 Subject: [PATCH 3/6] Feedback tests for async and sync profiles Adds feedback logging for the sync and async profile feedback tests and addresses review comments. --- tests/feedback/conftest.py | 55 ++ tests/feedback/test_4000_sync_profile.py | 859 +++++++++++++++++++++ tests/feedback/test_4100_async_profile.py | 876 ++++++++++++++++++++++ 3 files changed, 1790 insertions(+) create mode 100644 tests/feedback/conftest.py create mode 100644 tests/feedback/test_4000_sync_profile.py create mode 100644 tests/feedback/test_4100_async_profile.py diff --git a/tests/feedback/conftest.py b/tests/feedback/conftest.py new file mode 100644 index 0000000..7fba8f9 --- /dev/null +++ b/tests/feedback/conftest.py @@ -0,0 +1,55 @@ +# ----------------------------------------------------------------------------- +# Copyright (c) 2025, Oracle and/or its affiliates. +# +# Licensed under the Universal Permissive License v 1.0 as shown at +# http://oss.oracle.com/licenses/upl. 
+# ----------------------------------------------------------------------------- + +import logging +from pathlib import Path + +import pytest + +LOG_FORMAT = "%(levelname)s: [%(name)s] %(message)s" + + +def _configure_logger(logger: logging.Logger, module_file: str) -> None: + logger.setLevel(logging.INFO) + log_dir = Path(__file__).resolve().parents[2] / "logs" + log_dir.mkdir(parents=True, exist_ok=True) + log_file = log_dir / f"{Path(module_file).stem}.log" + + formatter = logging.Formatter(fmt=LOG_FORMAT) + + file_handler = logging.FileHandler(log_file, mode="w", encoding="utf-8") + file_handler.setLevel(logging.INFO) + file_handler.setFormatter(formatter) + + console_handler = logging.StreamHandler() + console_handler.setLevel(logging.INFO) + console_handler.setFormatter(formatter) + + logger.handlers.clear() + logger.propagate = False + logger.addHandler(file_handler) + logger.addHandler(console_handler) + logger.info("Configured logging for module") + + +@pytest.fixture(scope="module", autouse=True) +def configure_module_logging(request): + module = request.module + logger = logging.getLogger(module.__name__) + _configure_logger(logger, module.__file__) + yield + for handler in logger.handlers: + handler.close() + logger.handlers.clear() + + +@pytest.fixture(autouse=True) +def log_test_case(request, configure_module_logging): + logger = logging.getLogger(request.module.__name__) + logger.info("Starting test %s", request.node.name) + yield + logger.info("Finished test %s", request.node.name) diff --git a/tests/feedback/test_4000_sync_profile.py b/tests/feedback/test_4000_sync_profile.py new file mode 100644 index 0000000..3be9cad --- /dev/null +++ b/tests/feedback/test_4000_sync_profile.py @@ -0,0 +1,859 @@ +# ----------------------------------------------------------------------------- +# Copyright (c) 2025, Oracle and/or its affiliates. +# +# Licensed under the Universal Permissive License v 1.0 as shown at +# http://oss.oracle.com/licenses/upl. 
+# ----------------------------------------------------------------------------- + +""" +4000 - Sync Profile feedback API tests +""" +import logging +import uuid +import oracledb +import pytest +import select_ai +from select_ai.action import Action + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) +_ACTIVE_CURSOR = None + +PROFILE_PREFIX = "PYSAI_TEST_FEEDBACK_SYNC_4000" +PROFILE_NAME = f"{PROFILE_PREFIX}_{uuid.uuid4().hex.upper()}" +PROFILE_DESCRIPTION = "OCI Gen AI Test Profile" +PROMPT = "Total points of each gymnasts" +SHOWSQL_SQL_ID = "ahgttusrvh9x5" +RUNSQL_SQL_ID = "6s20ukn8j3p5j" +EXPLAINSQL_SQL_ID = "2a617cynwfm36" +PROFILE_OBJECT_NAMES = ("gymnast", "people") +WARMUP_STATEMENTS = ( + f"select ai showsql {PROMPT}", + f"select ai runsql {PROMPT}", + f"select ai explainsql {PROMPT}", +) + + +def _assert_db_error(exc_info, expected_code): + assert isinstance(exc_info.value, oracledb.DatabaseError) + (error,) = exc_info.value.args + assert error.code == expected_code + return error + + +def _set_profile_and_warm_up(profile_name, cursor): + cursor.execute( + """ + BEGIN + dbms_cloud_ai.set_profile(:profile_name); + END; + """, + profile_name=profile_name, + ) + for statement in WARMUP_STATEMENTS: + cursor.execute(statement) + + +def _log_feedback_vecindex_rows(profile): + if _ACTIVE_CURSOR is None: + raise RuntimeError("cursor fixture is not available") + table_name = f"{profile.profile_name.upper()}_FEEDBACK_VECINDEX$VECTAB" + _ACTIVE_CURSOR.execute(f"select CONTENT, ATTRIBUTES from {table_name}") + rows = _ACTIVE_CURSOR.fetchall() + logger.info("Feedback vecindex rows: %s", rows) + + +@pytest.fixture(autouse=True) +def active_cursor(cursor): + global _ACTIVE_CURSOR + _ACTIVE_CURSOR = cursor + yield + _ACTIVE_CURSOR = None + + +@pytest.fixture(scope="module") +def profile(oci_credential, oci_compartment_id, test_env, cursor): + object_list = [ + {"owner": test_env.test_user, "name": object_name} + for object_name in PROFILE_OBJECT_NAMES + ] + profile = select_ai.Profile( + profile_name=PROFILE_NAME, + description=PROFILE_DESCRIPTION, + replace=True, + attributes=select_ai.ProfileAttributes( + credential_name=oci_credential["credential_name"], + object_list=object_list, + provider=select_ai.OCIGenAIProvider( + oci_compartment_id=oci_compartment_id, + oci_apiformat="GENERIC", + ), + ), + ) + _set_profile_and_warm_up(profile.profile_name, cursor) + + yield profile + profile.delete(force=True) + +############################################### NEGATIVE FEEDBACK TESTS +def test_4001(profile, cursor): + """Add negative feedback using SHOWSQL prompt_spec, response, and feedback_content.""" + prompt = PROMPT + action = Action.SHOWSQL + response = ( + "SELECT p5.name, g5.total_points FROM people p5 JOIN gymnast g5 " + "ON p5.id = g5.id ORDER BY p5.name ASC" + ) + feedback_content = "print in ascending order of name" + logger.info( + "Adding negative feedback for prompt=%s action=%s sql_id=None response=%s feedback_content=%s", + prompt, + action, + response, + feedback_content, + ) + profile.add_negative_feedback( + prompt_spec=(prompt, action), + response=response, + feedback_content=feedback_content, + ) + _log_feedback_vecindex_rows(profile) + logger.info("Checking if show_prompt contains feedback metadata") + show_prompt = profile.show_prompt(PROMPT) + logger.info("show_prompt response: %s", show_prompt) + assert response in show_prompt + +def test_4002(profile, cursor): + """Add negative feedback using RUNSQL prompt_spec, response, and feedback_content.""" + 
prompt = PROMPT + action = Action.RUNSQL + response = ( + "SELECT p5.name, g5.total_points FROM people p5 JOIN gymnast g5 " + "ON p5.id = g5.id ORDER BY p5.name ASC" + ) + feedback_content = "print in ascending order of name" + logger.info( + "Adding negative feedback for prompt=%s action=%s sql_id=None response=%s feedback_content=%s", + prompt, + action, + response, + feedback_content, + ) + profile.add_negative_feedback( + prompt_spec=(prompt, action), + response=response, + feedback_content=feedback_content, + ) + _log_feedback_vecindex_rows(profile) + logger.info("Checking if show_prompt contains feedback metadata") + show_prompt = profile.show_prompt(PROMPT) + logger.info("show_prompt response: %s", show_prompt) + assert response in show_prompt + +def test_4003(profile, cursor): + """Add negative feedback using EXPLAINSQL prompt_spec, response, and feedback_content.""" + prompt = PROMPT + action = Action.EXPLAINSQL + response = ( + "SELECT p5.name, g5.total_points FROM people p5 JOIN gymnast g5 " + "ON p5.id = g5.id ORDER BY p5.name ASC" + ) + feedback_content = "print in ascending order of name" + logger.info( + "Adding negative feedback for prompt=%s action=%s sql_id=None response=%s feedback_content=%s", + prompt, + action, + response, + feedback_content, + ) + profile.add_negative_feedback( + prompt_spec=(prompt, action), + response=response, + feedback_content=feedback_content, + ) + _log_feedback_vecindex_rows(profile) + logger.info("Checking if show_prompt contains feedback metadata") + show_prompt = profile.show_prompt(PROMPT) + logger.info("show_prompt response: %s", show_prompt) + assert response in show_prompt + +def test_4004(profile, cursor): + """Add negative feedback using SHOWSQL sql_id, response, and feedback_content.""" + sql_id = SHOWSQL_SQL_ID + response = ( + "SELECT p4.name, g4.total_points FROM people p4 JOIN gymnast g4 " + "ON p4.id = g4.id ORDER BY p4.name DESC" + ) + feedback_content = "print in descending order of name" + logger.info( + "Adding negative feedback with prompt_spec=None sql_id=%s response=%s feedback_content=%s", + sql_id, + response, + feedback_content, + ) + profile.add_negative_feedback( + sql_id=sql_id, + response=response, + feedback_content=feedback_content, + ) + _log_feedback_vecindex_rows(profile) + logger.info("Checking if show_prompt contains feedback metadata") + show_prompt = profile.show_prompt(PROMPT) + logger.info("show_prompt response: %s", show_prompt) + assert response in show_prompt + +def test_4005(profile, cursor): + """Add negative feedback using RUNSQL sql_id, response, and feedback_content.""" + sql_id = RUNSQL_SQL_ID + response = ( + "SELECT p4.name, g4.total_points FROM people p4 JOIN gymnast g4 " + "ON p4.id = g4.id ORDER BY p4.name DESC" + ) + feedback_content = "print in descending order of name" + logger.info( + "Adding negative feedback with prompt_spec=None sql_id=%s response=%s feedback_content=%s", + sql_id, + response, + feedback_content, + ) + profile.add_negative_feedback( + sql_id=sql_id, + response=response, + feedback_content=feedback_content, + ) + _log_feedback_vecindex_rows(profile) + logger.info("Checking if show_prompt contains feedback metadata") + show_prompt = profile.show_prompt(PROMPT) + logger.info("show_prompt response: %s", show_prompt) + assert response in show_prompt + +def test_4006(profile, cursor): + """Add negative feedback using EXPLAINSQL sql_id, response, and feedback_content.""" + sql_id = EXPLAINSQL_SQL_ID + response = ( + "SELECT p4.name, g4.total_points FROM people p4 JOIN 
gymnast g4 " + "ON p4.id = g4.id ORDER BY p4.name DESC" + ) + feedback_content = "print in descending order of name" + logger.info( + "Adding negative feedback with prompt_spec=None sql_id=%s response=%s feedback_content=%s", + sql_id, + response, + feedback_content, + ) + profile.add_negative_feedback( + sql_id=sql_id, + response=response, + feedback_content=feedback_content, + ) + _log_feedback_vecindex_rows(profile) + logger.info("Checking if show_prompt contains feedback metadata") + show_prompt = profile.show_prompt(PROMPT) + logger.info("show_prompt response: %s", show_prompt) + assert response in show_prompt + +def test_4007(profile, cursor): + """Attempt negative feedback with both prompt_spec and sql_id.""" + prompt = PROMPT + action = Action.SHOWSQL + sql_id = SHOWSQL_SQL_ID + response = ( + "SELECT p1.name, g1.total_points FROM people p1 JOIN gymnast g1 " + "ON p1.id = g1.id ORDER BY g1.total_points DESC" + ) + feedback_content = "print in descending order of total_points" + logger.info( + "Adding negative feedback for prompt=%s action=%s sql_id=%s response=%s feedback_content=%s", + prompt, + action, + sql_id, + response, + feedback_content, + ) + with pytest.raises(oracledb.DatabaseError) as exc_info: + profile.add_negative_feedback( + prompt_spec=(prompt, action), + sql_id=sql_id, + response=response, + feedback_content=feedback_content, + ) + _assert_db_error(exc_info, 6550) + logger.error("%s", str(exc_info.value).splitlines()[0]) + +def test_4008(profile, cursor): + """Attempt negative feedback without a response.""" + prompt = PROMPT + action = Action.SHOWSQL + sql_id = SHOWSQL_SQL_ID + feedback_content = "print in ascending order of name" + logger.info( + "Adding negative feedback for prompt=%s action=%s sql_id=None response=None feedback_content=%s", + prompt, + action, + feedback_content, + ) + with pytest.raises(AttributeError) as exc_info: + profile.add_negative_feedback( + prompt_spec=(prompt, action), + feedback_content=feedback_content, + ) + assert isinstance(exc_info.value, AttributeError) + logger.error("%s", str(exc_info.value).splitlines()[0]) + +def test_4009(profile, cursor): + """Add negative feedback with sql_id and response but without feedback_content.""" + prompt = PROMPT + action = Action.SHOWSQL + sql_id = SHOWSQL_SQL_ID + response = ( + "SELECT p6.name, g6.total_points FROM people p6 JOIN gymnast g6 " + "ON p6.id = g6.id ORDER BY g6.total_points DESC" + ) + logger.info( + "Adding negative feedback with prompt_spec=None sql_id=%s response=%s feedback_content=None", + sql_id, + response, + ) + profile.add_negative_feedback( + sql_id=sql_id, + response=response, + ) + _log_feedback_vecindex_rows(profile) + logger.info("Checking if show_prompt contains feedback metadata") + show_prompt = profile.show_prompt(PROMPT) + logger.info("show_prompt response: %s", show_prompt) + assert response in show_prompt + +def test_4010(profile, cursor): + """Add negative feedback with prompt_spec=None and a valid sql_id.""" + sql_id = SHOWSQL_SQL_ID + response = ( + "SELECT p6.name, g6.total_points FROM people p6 JOIN gymnast g6 " + "ON p6.id = g6.id ORDER BY g6.total_points ASC, p6.name ASC" + ) + feedback_content = "print in ascending order of total_points and name" + logger.info( + "Adding negative feedback with prompt_spec=None sql_id=%s response=%s feedback_content=%s", + sql_id, + response, + feedback_content, + ) + profile.add_negative_feedback( + prompt_spec=None, + sql_id=sql_id, + response=response, + feedback_content=feedback_content, + ) + 
_log_feedback_vecindex_rows(profile) + logger.info("Checking if show_prompt contains feedback metadata") + show_prompt = profile.show_prompt(PROMPT) + logger.info("show_prompt response: %s", show_prompt) + assert response in show_prompt + +def test_4011(profile, cursor): + """Add negative feedback with sql_id=None and a valid prompt_spec.""" + prompt = PROMPT + action = Action.SHOWSQL + response = ( + "SELECT p.name, g.total_points FROM people p JOIN gymnast g " + "ON p.id = g.id ORDER BY g.total_points DESC" + ) + feedback_content = "print in descending order of total_points" + logger.info( + "Adding negative feedback for prompt=%s action=%s sql_id=None response=%s feedback_content=%s", + prompt, + action, + response, + feedback_content, + ) + profile.add_negative_feedback( + prompt_spec=(prompt, action), + sql_id=None, + response=response, + feedback_content=feedback_content, + ) + _log_feedback_vecindex_rows(profile) + logger.info("Checking if show_prompt contains feedback metadata") + show_prompt = profile.show_prompt(PROMPT) + logger.info("show_prompt response: %s", show_prompt) + assert response in show_prompt + +def test_4012(profile, cursor): + """Attempt negative feedback with response=None.""" + prompt = PROMPT + action = Action.SHOWSQL + sql_id = SHOWSQL_SQL_ID + feedback_content = "print in ascending order of total_points" + logger.info( + "Adding negative feedback for prompt=%s action=%s sql_id=None response=None feedback_content=%s", + prompt, + action, + feedback_content, + ) + with pytest.raises(AttributeError) as exc_info: + profile.add_negative_feedback( + prompt_spec=(prompt, action), + response=None, + feedback_content=feedback_content, + ) + assert isinstance(exc_info.value, AttributeError) + logger.error("%s", str(exc_info.value).splitlines()[0]) + +def test_4013(profile, cursor): + """Add negative feedback with feedback_content=None using sql_id.""" + prompt = PROMPT + action = Action.SHOWSQL + sql_id = SHOWSQL_SQL_ID + response = ( + "SELECT p.name, g.total_points FROM people p JOIN gymnast g " + "ON p.id = g.id ORDER BY g.total_points DESC" + ) + logger.info( + "Adding negative feedback with prompt_spec=None sql_id=%s response=%s feedback_content=None", + sql_id, + response, + ) + profile.add_negative_feedback( + sql_id=sql_id, + response=response, + feedback_content=None, + ) + _log_feedback_vecindex_rows(profile) + logger.info("Checking if show_prompt contains feedback metadata") + show_prompt = profile.show_prompt(PROMPT) + logger.info("show_prompt response: %s", show_prompt) + assert response in show_prompt + + +def test_4014(profile, cursor): + """Add negative feedback for a non-existent SHOWSQL prompt.""" + prompt = "Adding negative feedback with non existent prompt" + action = Action.SHOWSQL + response = ( + "SELECT p5.name, g5.total_points FROM people p5 JOIN gymnast g5 " + "ON p5.id = g5.id ORDER BY p5.name ASC" + ) + feedback_content = "print in ascending order of name" + logger.info( + "Adding negative feedback for prompt=%s action=%s sql_id=None response=%s feedback_content=%s", + prompt, + action, + response, + feedback_content, + ) + profile.add_negative_feedback( + prompt_spec=(prompt, action), + response=response, + feedback_content=feedback_content, + ) + _log_feedback_vecindex_rows(profile) + logger.info("Checking if show_prompt contains feedback metadata") + show_prompt = profile.show_prompt(PROMPT) + logger.info("show_prompt response: %s", show_prompt) + assert response in show_prompt + 
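# ---------------------------------------------------------------------------
# Editorial sketch (not part of the original patch): tests 4001-4003 differ
# only in the Action value they exercise, so the same coverage could be
# expressed with pytest parametrization. The sketch below reuses this
# module's PROMPT constant, the profile fixture, and the
# _log_feedback_vecindex_rows helper; the test name itself is hypothetical.
# ---------------------------------------------------------------------------
@pytest.mark.parametrize(
    "action", [Action.SHOWSQL, Action.RUNSQL, Action.EXPLAINSQL]
)
def test_negative_feedback_prompt_spec_parametrized(profile, action):
    """Sketch: add negative feedback via prompt_spec for each action."""
    response = (
        "SELECT p5.name, g5.total_points FROM people p5 JOIN gymnast g5 "
        "ON p5.id = g5.id ORDER BY p5.name ASC"
    )
    profile.add_negative_feedback(
        prompt_spec=(PROMPT, action),
        response=response,
        feedback_content="print in ascending order of name",
    )
    _log_feedback_vecindex_rows(profile)
    # The corrected SQL should be folded into the augmented prompt.
    assert response in profile.show_prompt(PROMPT)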
+############################################################## POSITIVE FEEDBACK TESTS +def test_4015(profile, cursor): + """Add positive feedback using SHOWSQL prompt_spec only.""" + prompt = PROMPT + action = Action.SHOWSQL + logger.info( + "Adding positive feedback for prompt=%s action=%s sql_id=None", + prompt, + action, + ) + profile.add_positive_feedback(prompt_spec=(prompt, action)) + _log_feedback_vecindex_rows(profile) + logger.info("Checking if show_prompt contains feedback metadata") + show_prompt = profile.show_prompt(PROMPT) + logger.info("show_prompt response: %s", show_prompt) + assert "sql_query" in show_prompt + assert "user_prompt" in show_prompt + +def test_4016(profile, cursor): + """Add positive feedback using RUNSQL prompt_spec only.""" + prompt = PROMPT + action = Action.RUNSQL + logger.info( + "Adding positive feedback for prompt=%s action=%s sql_id=None", + prompt, + action, + ) + profile.add_positive_feedback(prompt_spec=(prompt, action)) + _log_feedback_vecindex_rows(profile) + logger.info("Checking if show_prompt contains feedback metadata") + show_prompt = profile.show_prompt(PROMPT) + logger.info("show_prompt response: %s", show_prompt) + assert "sql_query" in show_prompt + assert "user_prompt" in show_prompt + +def test_4017(profile, cursor): + """Add positive feedback using EXPLAINSQL prompt_spec only.""" + prompt = PROMPT + action = Action.EXPLAINSQL + logger.info( + "Adding positive feedback for prompt=%s action=%s sql_id=None", + prompt, + action, + ) + profile.add_positive_feedback(prompt_spec=(prompt, action)) + _log_feedback_vecindex_rows(profile) + logger.info("Checking if show_prompt contains feedback metadata") + show_prompt = profile.show_prompt(PROMPT) + logger.info("show_prompt response: %s", show_prompt) + assert "sql_query" in show_prompt + assert "user_prompt" in show_prompt + +def test_4018(profile, cursor): + """Attempt positive feedback with both prompt_spec and sql_id.""" + prompt = PROMPT + action = Action.SHOWSQL + sql_id = SHOWSQL_SQL_ID + logger.info( + "Adding positive feedback for prompt=%s action=%s sql_id=%s", + prompt, + action, + sql_id, + ) + with pytest.raises(oracledb.DatabaseError) as exc_info: + profile.add_positive_feedback( + prompt_spec=(prompt, action), + sql_id=sql_id, + ) + _assert_db_error(exc_info, 6550) + logger.error("%s", str(exc_info.value).splitlines()[0]) + +def test_4019(profile, cursor): + """Add positive feedback using SHOWSQL sql_id only.""" + sql_id = SHOWSQL_SQL_ID + logger.info("Adding positive feedback without prompt_spec using sql_id=%s", sql_id) + profile.add_positive_feedback(sql_id=sql_id) + _log_feedback_vecindex_rows(profile) + logger.info("Checking if show_prompt contains feedback metadata") + show_prompt = profile.show_prompt(PROMPT) + logger.info("show_prompt response: %s", show_prompt) + assert "sql_query" in show_prompt + assert "user_prompt" in show_prompt + + +def test_4020(profile, cursor): + """Add positive feedback using RUNSQL sql_id only.""" + sql_id = RUNSQL_SQL_ID + logger.info("Adding positive feedback without prompt_spec using sql_id=%s", sql_id) + profile.add_positive_feedback(sql_id=sql_id) + _log_feedback_vecindex_rows(profile) + logger.info("Checking if show_prompt contains feedback metadata") + show_prompt = profile.show_prompt(PROMPT) + logger.info("show_prompt response: %s", show_prompt) + assert "sql_query" in show_prompt + assert "user_prompt" in show_prompt + +def test_4021(profile, cursor): + """Add positive feedback using EXPLAINSQL sql_id only.""" + sql_id = 
EXPLAINSQL_SQL_ID + logger.info("Adding positive feedback without prompt_spec using sql_id=%s", sql_id) + profile.add_positive_feedback(sql_id=sql_id) + _log_feedback_vecindex_rows(profile) + logger.info("Checking if show_prompt contains feedback metadata") + show_prompt = profile.show_prompt(PROMPT) + logger.info("show_prompt response: %s", show_prompt) + assert "sql_query" in show_prompt + assert "user_prompt" in show_prompt + +def test_4022(profile, cursor): + """Add positive feedback with prompt_spec=None and a valid sql_id.""" + sql_id = SHOWSQL_SQL_ID + logger.info("Adding positive feedback with prompt_spec=None sql_id=%s", sql_id) + profile.add_positive_feedback( + prompt_spec=None, + sql_id=sql_id, + ) + _log_feedback_vecindex_rows(profile) + logger.info("Checking if show_prompt contains feedback metadata") + show_prompt = profile.show_prompt(PROMPT) + logger.info("show_prompt response: %s", show_prompt) + assert "sql_query" in show_prompt + assert "user_prompt" in show_prompt + +def test_4023(profile, cursor): + """Add positive feedback with sql_id=None and a valid prompt_spec.""" + prompt = PROMPT + action = Action.SHOWSQL + logger.info( + "Adding positive feedback for prompt=%s action=%s sql_id=None", + prompt, + action, + ) + profile.add_positive_feedback( + prompt_spec=(prompt, action), + sql_id=None, + ) + _log_feedback_vecindex_rows(profile) + logger.info("Checking if show_prompt contains feedback metadata") + show_prompt = profile.show_prompt(PROMPT) + logger.info("show_prompt response: %s", show_prompt) + assert "sql_query" in show_prompt + assert "user_prompt" in show_prompt + + +def test_4024(profile, cursor): + """Attempt positive feedback for a non-existent SHOWSQL prompt.""" + prompt = "Adding positive feedback with non existent prompt" + action = Action.SHOWSQL + logger.info( + "Adding positive feedback for prompt=%s action=%s sql_id=None", + prompt, + action, + ) + with pytest.raises(oracledb.DatabaseError) as exc_info: + profile.add_positive_feedback(prompt_spec=(prompt, action)) + _assert_db_error(exc_info, 20000) + logger.error("%s", str(exc_info.value).splitlines()[0]) + +############################################################## DELETE FEEDBACK TESTS +def test_4025(profile, cursor): + """Delete feedback by prompt_spec after adding positive SHOWSQL feedback.""" + prompt = PROMPT + action = Action.SHOWSQL + logger.info( + "Adding positive feedback before delete for prompt=%s action=%s", + prompt, + action, + ) + profile.add_positive_feedback( + prompt_spec=(prompt, action), + ) + _log_feedback_vecindex_rows(profile) + + logger.info("Deleting feedback for prompt=%s action=%s", prompt, action) + profile.delete_feedback( + prompt_spec=(prompt, action), + ) + _log_feedback_vecindex_rows(profile) + logger.info("Checking absence of feedback in show_prompt") + show_prompt = profile.show_prompt(PROMPT) + logger.info("show_prompt response: %s", show_prompt) + assert show_prompt.count(PROMPT) == 1 + +def test_4026(profile, cursor): + """Delete feedback by RUNSQL sql_id after adding negative feedback with sql_id.""" + sql_id = RUNSQL_SQL_ID + response = ( + "SELECT p.name, g.total_points FROM people p JOIN gymnast g " + "ON p.id = g.id ORDER BY g.total_points DESC" + ) + logger.info( + "Adding negative feedback before delete with prompt_spec=None sql_id=%s response=%s feedback_content=%s", + sql_id, + response, + "Feedback prior to delete", + ) + profile.add_negative_feedback( + sql_id=sql_id, + response=response, + feedback_content="Feedback prior to delete", + ) + 
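    # The row just added is keyed by RUNSQL_SQL_ID, so the delete below should
    # remove it again and leave show_prompt free of feedback metadata; the
    # prompt text is then expected to appear exactly once, as its own echo.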
_log_feedback_vecindex_rows(profile) + logger.info("Deleting feedback using sql_id=%s", sql_id) + profile.delete_feedback(sql_id=sql_id) + _log_feedback_vecindex_rows(profile) + logger.info("Checking absence of feedback in show_prompt") + show_prompt = profile.show_prompt(PROMPT) + logger.info("show_prompt response: %s", show_prompt) + assert show_prompt.count(PROMPT) == 1 + +def test_4027(profile, cursor): + """Delete feedback by prompt_spec after adding positive EXPLAINSQL feedback.""" + prompt = PROMPT + action = Action.EXPLAINSQL + logger.info( + "Adding positive feedback before delete for prompt=%s action=%s", + prompt, + action, + ) + profile.add_positive_feedback( + prompt_spec=(prompt, action), + ) + _log_feedback_vecindex_rows(profile) + logger.info("Deleting feedback for prompt=%s action=%s", prompt, action) + profile.delete_feedback( + prompt_spec=(prompt, action), + ) + _log_feedback_vecindex_rows(profile) + logger.info("Checking absence of feedback in show_prompt") + show_prompt = profile.show_prompt(PROMPT) + logger.info("show_prompt response: %s", show_prompt) + assert show_prompt.count(PROMPT) == 1 + +def test_4028(profile, cursor): + """Delete SHOWSQL feedback by sql_id after adding negative prompt-based feedback.""" + prompt = PROMPT + action = Action.SHOWSQL + sql_id = SHOWSQL_SQL_ID + response = ( + "SELECT p.name, g.total_points FROM people p JOIN gymnast g " + "ON p.id = g.id ORDER BY g.total_points DESC" + ) + logger.info( + "Adding negative feedback before delete for prompt=%s action=%s sql_id=None response=%s feedback_content=%s", + prompt, + action, + response, + "Feedback prior to delete", + ) + profile.add_negative_feedback( + prompt_spec=(prompt, action), + response=response, + feedback_content="Feedback prior to delete", + ) + _log_feedback_vecindex_rows(profile) + logger.info("Deleting feedback using sql_id=%s", sql_id) + profile.delete_feedback(sql_id=sql_id) + _log_feedback_vecindex_rows(profile) + logger.info("Checking absence of feedback in show_prompt") + show_prompt = profile.show_prompt(PROMPT) + logger.info("show_prompt response: %s", show_prompt) + assert show_prompt.count(PROMPT) == 1 + + +def test_4029(profile, cursor): + """Delete RUNSQL feedback by sql_id after adding negative prompt-based feedback.""" + prompt = PROMPT + action = Action.RUNSQL + sql_id = RUNSQL_SQL_ID + response = ( + "SELECT p.name, g.total_points FROM people p JOIN gymnast g " + "ON p.id = g.id ORDER BY g.total_points DESC" + ) + logger.info( + "Adding negative feedback before delete for prompt=%s action=%s sql_id=None response=%s feedback_content=%s", + prompt, + action, + response, + "Feedback prior to delete", + ) + profile.add_negative_feedback( + prompt_spec=(prompt, action), + response=response, + feedback_content="Feedback prior to delete", + ) + _log_feedback_vecindex_rows(profile) + logger.info("Deleting feedback using sql_id=%s", sql_id) + profile.delete_feedback(sql_id=sql_id) + _log_feedback_vecindex_rows(profile) + logger.info("Checking absence of feedback in show_prompt") + show_prompt = profile.show_prompt(PROMPT) + logger.info("show_prompt response: %s", show_prompt) + assert show_prompt.count(PROMPT) == 1 + +def test_4030(profile, cursor): + """Delete EXPLAINSQL feedback by sql_id after adding negative prompt-based feedback.""" + prompt = PROMPT + action = Action.EXPLAINSQL + sql_id = EXPLAINSQL_SQL_ID + response = ( + "SELECT p.name, g.total_points FROM people p JOIN gymnast g " + "ON p.id = g.id ORDER BY g.total_points DESC" + ) + logger.info( + "Adding 
negative feedback before delete for prompt=%s action=%s sql_id=None response=%s feedback_content=%s",
+        prompt,
+        action,
+        response,
+        "Feedback prior to delete",
+    )
+    profile.add_negative_feedback(
+        prompt_spec=(prompt, action),
+        response=response,
+        feedback_content="Feedback prior to delete",
+    )
+    _log_feedback_vecindex_rows(profile)
+    logger.info("Deleting feedback using sql_id=%s", sql_id)
+    profile.delete_feedback(sql_id=sql_id)
+    _log_feedback_vecindex_rows(profile)
+    logger.info("Checking absence of feedback in show_prompt")
+    show_prompt = profile.show_prompt(PROMPT)
+    logger.info("show_prompt response: %s", show_prompt)
+    assert show_prompt.count(PROMPT) == 1
+
+
+def test_4031(profile, cursor):
+    """Delete feedback with sql_id=None and a valid prompt_spec."""
+    prompt = PROMPT
+    action = Action.SHOWSQL
+    logger.info(
+        "Adding positive feedback before delete for prompt=%s action=%s sql_id=None",
+        prompt,
+        action,
+    )
+    profile.add_positive_feedback(
+        prompt_spec=(prompt, action),
+    )
+    _log_feedback_vecindex_rows(profile)
+    logger.info(
+        "Deleting feedback for prompt=%s action=%s with sql_id=None",
+        prompt,
+        action,
+    )
+    profile.delete_feedback(
+        prompt_spec=(prompt, action),
+        sql_id=None,
+    )
+    _log_feedback_vecindex_rows(profile)
+    logger.info("Checking absence of feedback in show_prompt")
+    show_prompt = profile.show_prompt(PROMPT)
+    logger.info("show_prompt response: %s", show_prompt)
+    assert show_prompt.count(PROMPT) == 1
+
+
+def test_4032(profile, cursor):
+    """Delete feedback with prompt_spec=None and a valid sql_id."""
+    sql_id = SHOWSQL_SQL_ID
+    logger.info(
+        "Adding positive feedback before delete without prompt_spec using sql_id=%s",
+        sql_id,
+    )
+    profile.add_positive_feedback(
+        sql_id=sql_id,
+    )
+    _log_feedback_vecindex_rows(profile)
+    logger.info("Deleting feedback with prompt_spec=None using sql_id=%s", sql_id)
+    profile.delete_feedback(
+        prompt_spec=None,
+        sql_id=sql_id,
+    )
+    _log_feedback_vecindex_rows(profile)
+    logger.info("Checking absence of feedback in show_prompt")
+    show_prompt = profile.show_prompt(PROMPT)
+    logger.info("show_prompt response: %s", show_prompt)
+    assert show_prompt.count(PROMPT) == 1
+
+
+def test_4033(profile, cursor):
+    """Attempt delete_feedback with both prompt_spec and sql_id."""
+    prompt = PROMPT
+    action = Action.SHOWSQL
+    sql_id = SHOWSQL_SQL_ID
+    logger.info(
+        "Adding positive feedback before conflicting delete without prompt_spec using sql_id=%s",
+        sql_id,
+    )
+    profile.add_positive_feedback(
+        sql_id=sql_id,
+    )
+    _log_feedback_vecindex_rows(profile)
+    logger.info(
+        "Deleting feedback for prompt=%s action=%s sql_id=%s",
+        prompt,
+        action,
+        sql_id,
+    )
+    with pytest.raises(oracledb.DatabaseError) as exc_info:
+        profile.delete_feedback(
+            prompt_spec=(prompt, action),
+            sql_id=sql_id,
+        )
+    _assert_db_error(exc_info, 6550)
+    logger.error("%s", str(exc_info.value).splitlines()[0])
diff --git a/tests/feedback/test_4100_async_profile.py b/tests/feedback/test_4100_async_profile.py
new file mode 100644
index 0000000..5b3bae7
--- /dev/null
+++ b/tests/feedback/test_4100_async_profile.py
@@ -0,0 +1,876 @@
+# -----------------------------------------------------------------------------
+# Copyright (c) 2026, Oracle and/or its affiliates.
+#
+# Licensed under the Universal Permissive License v 1.0 as shown at
+# http://oss.oracle.com/licenses/upl. 
+# -----------------------------------------------------------------------------
+
+"""
+4100 - Async profile feedback API tests
+"""
+import logging
+import uuid
+
+import oracledb
+import pytest
+
+import select_ai
+from select_ai.action import Action
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+# Populated by the autouse active_async_cursor fixture so module-level helpers
+# can run ad-hoc queries without threading the cursor through every call.
+_ACTIVE_ASYNC_CURSOR = None
+
+PROFILE_PREFIX = "PYSAI_TEST_FEEDBACK_ASYNC_4100"
+PROFILE_NAME = f"{PROFILE_PREFIX}_{uuid.uuid4().hex.upper()}"
+PROFILE_DESCRIPTION = "OCI Gen AI Test Profile"
+PROMPT = "Total points of each gymnasts"
+SHOWSQL_SQL_ID = "ahgttusrvh9x5"
+RUNSQL_SQL_ID = "6s20ukn8j3p5j"
+EXPLAINSQL_SQL_ID = "2a617cynwfm36"
+PROFILE_OBJECT_NAMES = ("gymnast", "people")
+WARMUP_STATEMENTS = (
+    f"select ai showsql {PROMPT}",
+    f"select ai runsql {PROMPT}",
+    f"select ai explainsql {PROMPT}",
+)
+
+
+def _assert_db_error(exc_info, expected_code):
+    assert isinstance(exc_info.value, oracledb.DatabaseError)
+    (error,) = exc_info.value.args
+    assert error.code == expected_code
+    return error
+
+
+async def _set_profile_and_warm_up(profile_name, async_cursor):
+    await async_cursor.execute(
+        """
+        BEGIN
+            dbms_cloud_ai.set_profile(:profile_name);
+        END;
+        """,
+        profile_name=profile_name,
+    )
+    for statement in WARMUP_STATEMENTS:
+        await async_cursor.execute(statement)
+
+
+async def _log_feedback_vecindex_rows(profile):
+    if _ACTIVE_ASYNC_CURSOR is None:
+        raise RuntimeError("async_cursor fixture is not available")
+    # Feedback entries are stored in a per-profile vector index table named
+    # <PROFILE_NAME>_FEEDBACK_VECINDEX$VECTAB.
+    table_name = f"{profile.profile_name.upper()}_FEEDBACK_VECINDEX$VECTAB"
+    await _ACTIVE_ASYNC_CURSOR.execute(
+        f"select CONTENT, ATTRIBUTES from {table_name}"
+    )
+    rows = await _ACTIVE_ASYNC_CURSOR.fetchall()
+    logger.info("Feedback vecindex rows: %s", rows)
+
+
+@pytest.fixture(scope="module")
+async def profile(oci_credential, oci_compartment_id, test_env, async_cursor):
+    object_list = [
+        {"owner": test_env.test_user, "name": object_name}
+        for object_name in PROFILE_OBJECT_NAMES
+    ]
+    profile = await select_ai.AsyncProfile(
+        profile_name=PROFILE_NAME,
+        description=PROFILE_DESCRIPTION,
+        replace=True,
+        attributes=select_ai.ProfileAttributes(
+            credential_name=oci_credential["credential_name"],
+            object_list=object_list,
+            provider=select_ai.OCIGenAIProvider(
+                oci_compartment_id=oci_compartment_id,
+                oci_apiformat="GENERIC",
+            ),
+        ),
+    )
+    await _set_profile_and_warm_up(profile.profile_name, async_cursor)
+
+    yield profile
+    await profile.delete(force=True)
+
+
+@pytest.fixture(autouse=True)
+async def active_async_cursor(async_cursor):
+    global _ACTIVE_ASYNC_CURSOR
+    _ACTIVE_ASYNC_CURSOR = async_cursor
+    yield
+    _ACTIVE_ASYNC_CURSOR = None
+
+
+############################################################## NEGATIVE FEEDBACK TESTS
+async def test_4101(profile, async_cursor):
+    """Add negative feedback using SHOWSQL prompt_spec, response, and feedback_content."""
+    prompt = PROMPT
+    action = Action.SHOWSQL
+    response = (
+        "SELECT p5.name, g5.total_points FROM people p5 JOIN gymnast g5 "
+        "ON p5.id = g5.id ORDER BY p5.name ASC"
+    )
+    feedback_content = "print in ascending order of name"
+    logger.info(
+        "Adding negative feedback for prompt=%s action=%s sql_id=None response=%s feedback_content=%s",
+        prompt,
+        action,
+        response,
+        feedback_content,
+    )
+    await profile.add_negative_feedback(
+        prompt_spec=(prompt, action),
+        response=response,
+        feedback_content=feedback_content,
+    )
+    await _log_feedback_vecindex_rows(profile)
+    logger.info("Checking if show_prompt contains feedback metadata")
+    show_prompt = await profile.show_prompt(PROMPT)
+    logger.info("show_prompt response: %s", show_prompt)
+    assert response in show_prompt
+
+
+async def test_4102(profile, async_cursor):
+    """Add negative feedback using RUNSQL prompt_spec, response, and feedback_content."""
+    prompt = PROMPT
+    action = Action.RUNSQL
+    response = (
+        "SELECT p5.name, g5.total_points FROM people p5 JOIN gymnast g5 "
+        "ON p5.id = g5.id ORDER BY p5.name ASC"
+    )
+    feedback_content = "print in ascending order of name"
+    logger.info(
+        "Adding negative feedback for prompt=%s action=%s sql_id=None response=%s feedback_content=%s",
+        prompt,
+        action,
+        response,
+        feedback_content,
+    )
+    await profile.add_negative_feedback(
+        prompt_spec=(prompt, action),
+        response=response,
+        feedback_content=feedback_content,
+    )
+    await _log_feedback_vecindex_rows(profile)
+    logger.info("Checking if show_prompt contains feedback metadata")
+    show_prompt = await profile.show_prompt(PROMPT)
+    logger.info("show_prompt response: %s", show_prompt)
+    assert response in show_prompt
+
+
+async def test_4103(profile, async_cursor):
+    """Add negative feedback using EXPLAINSQL prompt_spec, response, and feedback_content."""
+    prompt = PROMPT
+    action = Action.EXPLAINSQL
+    response = (
+        "SELECT p5.name, g5.total_points FROM people p5 JOIN gymnast g5 "
+        "ON p5.id = g5.id ORDER BY p5.name ASC"
+    )
+    feedback_content = "print in ascending order of name"
+    logger.info(
+        "Adding negative feedback for prompt=%s action=%s sql_id=None response=%s feedback_content=%s",
+        prompt,
+        action,
+        response,
+        feedback_content,
+    )
+    await profile.add_negative_feedback(
+        prompt_spec=(prompt, action),
+        response=response,
+        feedback_content=feedback_content,
+    )
+    await _log_feedback_vecindex_rows(profile)
+    logger.info("Checking if show_prompt contains feedback metadata")
+    show_prompt = await profile.show_prompt(PROMPT)
+    logger.info("show_prompt response: %s", show_prompt)
+    assert response in show_prompt
+
+
+async def test_4104(profile, async_cursor):
+    """Add negative feedback using SHOWSQL sql_id, response, and feedback_content."""
+    sql_id = SHOWSQL_SQL_ID
+    response = (
+        "SELECT p4.name, g4.total_points FROM people p4 JOIN gymnast g4 "
+        "ON p4.id = g4.id ORDER BY p4.name DESC"
+    )
+    feedback_content = "print in descending order of name"
+    logger.info(
+        "Adding negative feedback with prompt_spec=None sql_id=%s response=%s feedback_content=%s",
+        sql_id,
+        response,
+        feedback_content,
+    )
+    await profile.add_negative_feedback(
+        sql_id=sql_id,
+        response=response,
+        feedback_content=feedback_content,
+    )
+    await _log_feedback_vecindex_rows(profile)
+    logger.info("Checking if show_prompt contains feedback metadata")
+    show_prompt = await profile.show_prompt(PROMPT)
+    logger.info("show_prompt response: %s", show_prompt)
+    assert response in show_prompt
+
+
+async def test_4105(profile, async_cursor):
+    """Add negative feedback using RUNSQL sql_id, response, and feedback_content."""
+    sql_id = RUNSQL_SQL_ID
+    response = (
+        "SELECT p4.name, g4.total_points FROM people p4 JOIN gymnast g4 "
+        "ON p4.id = g4.id ORDER BY p4.name DESC"
+    )
+    feedback_content = "print in descending order of name"
+    logger.info(
+        "Adding negative feedback with prompt_spec=None sql_id=%s response=%s feedback_content=%s",
+        sql_id,
+        response,
+        feedback_content,
+    )
+    await profile.add_negative_feedback(
+        sql_id=sql_id,
+        response=response,
+        feedback_content=feedback_content,
+    )
+    await _log_feedback_vecindex_rows(profile)
+    logger.info("Checking if show_prompt contains feedback metadata")
+    show_prompt = await profile.show_prompt(PROMPT)
+    logger.info("show_prompt response: %s", show_prompt)
+    assert response in show_prompt
+
+
+async def test_4106(profile, async_cursor):
+    """Add negative feedback using EXPLAINSQL sql_id, response, and feedback_content."""
+    sql_id = EXPLAINSQL_SQL_ID
+    response = (
+        "SELECT p4.name, g4.total_points FROM people p4 JOIN gymnast g4 "
+        "ON p4.id = g4.id ORDER BY p4.name DESC"
+    )
+    feedback_content = "print in descending order of name"
+    logger.info(
+        "Adding negative feedback with prompt_spec=None sql_id=%s response=%s feedback_content=%s",
+        sql_id,
+        response,
+        feedback_content,
+    )
+    await profile.add_negative_feedback(
+        sql_id=sql_id,
+        response=response,
+        feedback_content=feedback_content,
+    )
+    await _log_feedback_vecindex_rows(profile)
+    logger.info("Checking if show_prompt contains feedback metadata")
+    show_prompt = await profile.show_prompt(PROMPT)
+    logger.info("show_prompt response: %s", show_prompt)
+    assert response in show_prompt
+
+
+async def test_4107(profile, async_cursor):
+    """Attempt negative feedback with both prompt_spec and sql_id."""
+    prompt = PROMPT
+    action = Action.SHOWSQL
+    sql_id = SHOWSQL_SQL_ID
+    response = (
+        "SELECT p1.name, g1.total_points FROM people p1 JOIN gymnast g1 "
+        "ON p1.id = g1.id ORDER BY g1.total_points DESC"
+    )
+    feedback_content = "print in descending order of total_points"
+    logger.info(
+        "Adding negative feedback for prompt=%s action=%s sql_id=%s response=%s feedback_content=%s",
+        prompt,
+        action,
+        sql_id,
+        response,
+        feedback_content,
+    )
+    with pytest.raises(oracledb.DatabaseError) as exc_info:
+        await profile.add_negative_feedback(
+            prompt_spec=(prompt, action),
+            sql_id=sql_id,
+            response=response,
+            feedback_content=feedback_content,
+        )
+    _assert_db_error(exc_info, 6550)
+    logger.error("%s", str(exc_info.value).splitlines()[0])
+
+
+async def test_4108(profile, async_cursor):
+    """Attempt negative feedback without a response."""
+    prompt = PROMPT
+    action = Action.SHOWSQL
+    feedback_content = "print in ascending order of name"
+    logger.info(
+        "Adding negative feedback for prompt=%s action=%s sql_id=None response=None feedback_content=%s",
+        prompt,
+        action,
+        feedback_content,
+    )
+    with pytest.raises(AttributeError) as exc_info:
+        await profile.add_negative_feedback(
+            prompt_spec=(prompt, action),
+            feedback_content=feedback_content,
+        )
+    assert isinstance(exc_info.value, AttributeError)
+    logger.error("%s", str(exc_info.value).splitlines()[0])
+
+
+async def test_4109(profile, async_cursor):
+    """Add negative feedback with sql_id and response but without feedback_content."""
+    sql_id = SHOWSQL_SQL_ID
+    response = (
+        "SELECT p6.name, g6.total_points FROM people p6 JOIN gymnast g6 "
+        "ON p6.id = g6.id ORDER BY g6.total_points DESC"
+    )
+    logger.info(
+        "Adding negative feedback with prompt_spec=None sql_id=%s response=%s feedback_content=None",
+        sql_id,
+        response,
+    )
+    await profile.add_negative_feedback(
+        sql_id=sql_id,
+        response=response,
+    )
+    await _log_feedback_vecindex_rows(profile)
+    logger.info("Checking if show_prompt contains feedback metadata")
+    show_prompt = await profile.show_prompt(PROMPT)
+    logger.info("show_prompt response: %s", show_prompt)
+    assert response in show_prompt
+
+
+async def test_4110(profile, async_cursor):
+    """Add negative feedback with prompt_spec=None and a valid sql_id."""
+    sql_id = SHOWSQL_SQL_ID
+    response = (
+        "SELECT p6.name, g6.total_points FROM people p6 JOIN gymnast g6 "
+        "ON p6.id = g6.id ORDER BY g6.total_points ASC, p6.name ASC"
+    )
+    feedback_content = "print in ascending order of total_points and name"
+    logger.info(
+        "Adding negative feedback with prompt_spec=None sql_id=%s response=%s feedback_content=%s",
+        sql_id,
+        response,
+        feedback_content,
+    )
+    await profile.add_negative_feedback(
+        prompt_spec=None,
+        sql_id=sql_id,
+        response=response,
+        feedback_content=feedback_content,
+    )
+    await _log_feedback_vecindex_rows(profile)
+    logger.info("Checking if show_prompt contains feedback metadata")
+    show_prompt = await profile.show_prompt(PROMPT)
+    logger.info("show_prompt response: %s", show_prompt)
+    assert response in show_prompt
+
+
+async def test_4111(profile, async_cursor):
+    """Add negative feedback with sql_id=None and a valid prompt_spec."""
+    prompt = PROMPT
+    action = Action.SHOWSQL
+    response = (
+        "SELECT p.name, g.total_points FROM people p JOIN gymnast g "
+        "ON p.id = g.id ORDER BY g.total_points DESC"
+    )
+    feedback_content = "print in ascending order of total_points"
+    logger.info(
+        "Adding negative feedback for prompt=%s action=%s sql_id=None response=%s feedback_content=%s",
+        prompt,
+        action,
+        response,
+        feedback_content,
+    )
+    await profile.add_negative_feedback(
+        prompt_spec=(prompt, action),
+        sql_id=None,
+        response=response,
+        feedback_content=feedback_content,
+    )
+    await _log_feedback_vecindex_rows(profile)
+    logger.info("Checking if show_prompt contains feedback metadata")
+    show_prompt = await profile.show_prompt(PROMPT)
+    logger.info("show_prompt response: %s", show_prompt)
+    assert response in show_prompt
+
+
+async def test_4112(profile, async_cursor):
+    """Attempt negative feedback with response=None."""
+    prompt = PROMPT
+    action = Action.SHOWSQL
+    feedback_content = "print in ascending order of total_points"
+    logger.info(
+        "Adding negative feedback for prompt=%s action=%s sql_id=None response=None feedback_content=%s",
+        prompt,
+        action,
+        feedback_content,
+    )
+    with pytest.raises(AttributeError) as exc_info:
+        await profile.add_negative_feedback(
+            prompt_spec=(prompt, action),
+            response=None,
+            feedback_content=feedback_content,
+        )
+    assert isinstance(exc_info.value, AttributeError)
+    logger.error("%s", str(exc_info.value).splitlines()[0])
+
+
+async def test_4113(profile, async_cursor):
+    """Add negative feedback with feedback_content=None using sql_id."""
+    sql_id = SHOWSQL_SQL_ID
+    response = (
+        "SELECT p.name, g.total_points FROM people p JOIN gymnast g "
+        "ON p.id = g.id ORDER BY g.total_points DESC"
+    )
+    logger.info(
+        "Adding negative feedback with prompt_spec=None sql_id=%s response=%s feedback_content=None",
+        sql_id,
+        response,
+    )
+    await profile.add_negative_feedback(
+        sql_id=sql_id,
+        response=response,
+        feedback_content=None,
+    )
+    await _log_feedback_vecindex_rows(profile)
+    logger.info("Checking if show_prompt contains feedback metadata")
+    show_prompt = await profile.show_prompt(PROMPT)
+    logger.info("show_prompt response: %s", show_prompt)
+    assert response in show_prompt
+
+
+async def test_4114(profile, async_cursor):
+    """Add negative feedback for a non-existent SHOWSQL prompt."""
+    prompt = "Adding negative feedback with non existent prompt"
+    action = Action.SHOWSQL
+    response = (
+        "SELECT p5.name, g5.total_points FROM people p5 JOIN gymnast g5 "
+        "ON p5.id = g5.id ORDER BY p5.name ASC"
+    )
+    feedback_content = "print in ascending order of name"
+    logger.info(
+        "Adding negative feedback for prompt=%s action=%s sql_id=None response=%s feedback_content=%s",
+        prompt,
+        action,
+        response,
+        
feedback_content, + ) + await profile.add_negative_feedback( + prompt_spec=(prompt, action), + response=response, + feedback_content=feedback_content, + ) + await _log_feedback_vecindex_rows(profile) + logger.info("Checking if show_prompt contains feedback metadata") + show_prompt = await profile.show_prompt(PROMPT) + logger.info("show_prompt response: %s", show_prompt) + assert response in show_prompt + + +############################################################## POSITIVE FEEDBACK TESTS +async def test_4115(profile, async_cursor): + """Add positive feedback using SHOWSQL prompt_spec only.""" + prompt = PROMPT + action = Action.SHOWSQL + logger.info( + "Adding positive feedback for prompt=%s action=%s sql_id=None", + prompt, + action, + ) + await profile.add_positive_feedback(prompt_spec=(prompt, action)) + await _log_feedback_vecindex_rows(profile) + logger.info("Checking if show_prompt contains feedback metadata") + show_prompt = await profile.show_prompt(PROMPT) + logger.info("show_prompt response: %s", show_prompt) + assert "sql_query" in show_prompt + assert "user_prompt" in show_prompt + + +async def test_4116(profile, async_cursor): + """Add positive feedback using RUNSQL prompt_spec only.""" + prompt = PROMPT + action = Action.RUNSQL + logger.info( + "Adding positive feedback for prompt=%s action=%s sql_id=None", + prompt, + action, + ) + await profile.add_positive_feedback(prompt_spec=(prompt, action)) + await _log_feedback_vecindex_rows(profile) + logger.info("Checking if show_prompt contains feedback metadata") + show_prompt = await profile.show_prompt(PROMPT) + logger.info("show_prompt response: %s", show_prompt) + assert "sql_query" in show_prompt + assert "user_prompt" in show_prompt + + +async def test_4117(profile, async_cursor): + """Add positive feedback using EXPLAINSQL prompt_spec only.""" + prompt = PROMPT + action = Action.EXPLAINSQL + logger.info( + "Adding positive feedback for prompt=%s action=%s sql_id=None", + prompt, + action, + ) + await profile.add_positive_feedback(prompt_spec=(prompt, action)) + await _log_feedback_vecindex_rows(profile) + logger.info("Checking if show_prompt contains feedback metadata") + show_prompt = await profile.show_prompt(PROMPT) + logger.info("show_prompt response: %s", show_prompt) + assert "sql_query" in show_prompt + assert "user_prompt" in show_prompt + + +async def test_4118(profile, async_cursor): + """Attempt positive feedback with both prompt_spec and sql_id.""" + prompt = PROMPT + action = Action.SHOWSQL + sql_id = SHOWSQL_SQL_ID + logger.info( + "Adding positive feedback for prompt=%s action=%s sql_id=%s", + prompt, + action, + sql_id, + ) + with pytest.raises(oracledb.DatabaseError) as exc_info: + await profile.add_positive_feedback( + prompt_spec=(prompt, action), + sql_id=sql_id, + ) + _assert_db_error(exc_info, 6550) + logger.error("%s", str(exc_info.value).splitlines()[0]) + + +async def test_4119(profile, async_cursor): + """Add positive feedback using SHOWSQL sql_id only.""" + sql_id = SHOWSQL_SQL_ID + logger.info("Adding positive feedback without prompt_spec using sql_id=%s", sql_id) + await profile.add_positive_feedback(sql_id=sql_id) + await _log_feedback_vecindex_rows(profile) + logger.info("Checking if show_prompt contains feedback metadata") + show_prompt = await profile.show_prompt(PROMPT) + logger.info("show_prompt response: %s", show_prompt) + assert "sql_query" in show_prompt + assert "user_prompt" in show_prompt + + +async def test_4120(profile, async_cursor): + """Add positive feedback using RUNSQL 
sql_id only.""" + sql_id = RUNSQL_SQL_ID + logger.info("Adding positive feedback without prompt_spec using sql_id=%s", sql_id) + await profile.add_positive_feedback(sql_id=sql_id) + await _log_feedback_vecindex_rows(profile) + logger.info("Checking if show_prompt contains feedback metadata") + show_prompt = await profile.show_prompt(PROMPT) + logger.info("show_prompt response: %s", show_prompt) + assert "sql_query" in show_prompt + assert "user_prompt" in show_prompt + + +async def test_4121(profile, async_cursor): + """Add positive feedback using EXPLAINSQL sql_id only.""" + sql_id = EXPLAINSQL_SQL_ID + logger.info("Adding positive feedback without prompt_spec using sql_id=%s", sql_id) + await profile.add_positive_feedback(sql_id=sql_id) + await _log_feedback_vecindex_rows(profile) + logger.info("Checking if show_prompt contains feedback metadata") + show_prompt = await profile.show_prompt(PROMPT) + logger.info("show_prompt response: %s", show_prompt) + assert "sql_query" in show_prompt + assert "user_prompt" in show_prompt + + +async def test_4122(profile, async_cursor): + """Add positive feedback with prompt_spec=None and a valid sql_id.""" + sql_id = SHOWSQL_SQL_ID + logger.info("Adding positive feedback with prompt_spec=None sql_id=%s", sql_id) + await profile.add_positive_feedback( + prompt_spec=None, + sql_id=sql_id, + ) + await _log_feedback_vecindex_rows(profile) + logger.info("Checking if show_prompt contains feedback metadata") + show_prompt = await profile.show_prompt(PROMPT) + logger.info("show_prompt response: %s", show_prompt) + assert "sql_query" in show_prompt + assert "user_prompt" in show_prompt + + +async def test_4123(profile, async_cursor): + """Add positive feedback with sql_id=None and a valid prompt_spec.""" + prompt = PROMPT + action = Action.SHOWSQL + logger.info( + "Adding positive feedback for prompt=%s action=%s sql_id=None", + prompt, + action, + ) + await profile.add_positive_feedback( + prompt_spec=(prompt, action), + sql_id=None, + ) + await _log_feedback_vecindex_rows(profile) + logger.info("Checking if show_prompt contains feedback metadata") + show_prompt = await profile.show_prompt(PROMPT) + logger.info("show_prompt response: %s", show_prompt) + assert "sql_query" in show_prompt + assert "user_prompt" in show_prompt + +async def test_4124(profile, async_cursor): + """Attempt positive feedback for a non-existent SHOWSQL prompt.""" + prompt = "Adding positive feedback with non existent prompt" + action = Action.SHOWSQL + logger.info( + "Adding positive feedback for prompt=%s action=%s sql_id=None", + prompt, + action, + ) + with pytest.raises(oracledb.DatabaseError) as exc_info: + await profile.add_positive_feedback(prompt_spec=(prompt, action)) + _assert_db_error(exc_info, 20000) + logger.error("%s", str(exc_info.value).splitlines()[0]) + +############################################################## DELETE FEEDBACK TESTS +async def test_4125(profile, async_cursor): + """Delete feedback by prompt_spec after adding positive SHOWSQL feedback.""" + prompt = PROMPT + action = Action.SHOWSQL + logger.info( + "Adding positive feedback before delete for prompt=%s action=%s", + prompt, + action, + ) + await profile.add_positive_feedback( + prompt_spec=(prompt, action), + ) + await _log_feedback_vecindex_rows(profile) + + logger.info("Deleting feedback for prompt=%s action=%s", prompt, action) + await profile.delete_feedback( + prompt_spec=(prompt, action), + ) + await _log_feedback_vecindex_rows(profile) + logger.info("Checking absence of feedback in 
show_prompt") + show_prompt = await profile.show_prompt(PROMPT) + logger.info("show_prompt response: %s", show_prompt) + assert show_prompt.count(PROMPT) == 1 + + +async def test_4126(profile, async_cursor): + """Delete feedback by RUNSQL sql_id after adding negative feedback with sql_id.""" + sql_id = RUNSQL_SQL_ID + response = ( + "SELECT p.name, g.total_points FROM people p JOIN gymnast g " + "ON p.id = g.id ORDER BY g.total_points DESC" + ) + logger.info( + "Adding negative feedback before delete with prompt_spec=None sql_id=%s response=%s feedback_content=%s", + sql_id, + response, + "Feedback prior to delete", + ) + await profile.add_negative_feedback( + sql_id=sql_id, + response=response, + feedback_content="Feedback prior to delete", + ) + await _log_feedback_vecindex_rows(profile) + logger.info("Deleting feedback using sql_id=%s", sql_id) + await profile.delete_feedback(sql_id=sql_id) + await _log_feedback_vecindex_rows(profile) + logger.info("Checking absence of feedback in show_prompt") + show_prompt = await profile.show_prompt(PROMPT) + logger.info("show_prompt response: %s", show_prompt) + assert show_prompt.count(PROMPT) == 1 + + +async def test_4127(profile, async_cursor): + """Delete feedback by prompt_spec after adding positive EXPLAINSQL feedback.""" + prompt = PROMPT + action = Action.EXPLAINSQL + logger.info( + "Adding positive feedback before delete for prompt=%s action=%s", + prompt, + action, + ) + await profile.add_positive_feedback( + prompt_spec=(prompt, action), + ) + await _log_feedback_vecindex_rows(profile) + logger.info("Deleting feedback for prompt=%s action=%s", prompt, action) + await profile.delete_feedback( + prompt_spec=(prompt, action), + ) + await _log_feedback_vecindex_rows(profile) + logger.info("Checking absence of feedback in show_prompt") + show_prompt = await profile.show_prompt(PROMPT) + logger.info("show_prompt response: %s", show_prompt) + assert show_prompt.count(PROMPT) == 1 + + +async def test_4128(profile, async_cursor): + """Delete SHOWSQL feedback by sql_id after adding negative prompt-based feedback.""" + prompt = PROMPT + action = Action.SHOWSQL + sql_id = SHOWSQL_SQL_ID + response = ( + "SELECT p.name, g.total_points FROM people p JOIN gymnast g " + "ON p.id = g.id ORDER BY g.total_points DESC" + ) + logger.info( + "Adding negative feedback before delete for prompt=%s action=%s sql_id=None response=%s feedback_content=%s", + prompt, + action, + response, + "Feedback prior to delete", + ) + await profile.add_negative_feedback( + prompt_spec=(prompt, action), + response=response, + feedback_content="Feedback prior to delete", + ) + await _log_feedback_vecindex_rows(profile) + logger.info("Deleting feedback using sql_id=%s", sql_id) + await profile.delete_feedback(sql_id=sql_id) + await _log_feedback_vecindex_rows(profile) + logger.info("Checking absence of feedback in show_prompt") + show_prompt = await profile.show_prompt(PROMPT) + logger.info("show_prompt response: %s", show_prompt) + assert show_prompt.count(PROMPT) == 1 + + +async def test_4129(profile, async_cursor): + """Delete RUNSQL feedback by sql_id after adding negative prompt-based feedback.""" + prompt = PROMPT + action = Action.RUNSQL + sql_id = RUNSQL_SQL_ID + response = ( + "SELECT p.name, g.total_points FROM people p JOIN gymnast g " + "ON p.id = g.id ORDER BY g.total_points DESC" + ) + logger.info( + "Adding negative feedback before delete for prompt=%s action=%s sql_id=None response=%s feedback_content=%s", + prompt, + action, + response, + "Feedback prior to 
delete", + ) + await profile.add_negative_feedback( + prompt_spec=(prompt, action), + response=response, + feedback_content="Feedback prior to delete", + ) + await _log_feedback_vecindex_rows(profile) + logger.info("Deleting feedback using sql_id=%s", sql_id) + await profile.delete_feedback(sql_id=sql_id) + await _log_feedback_vecindex_rows(profile) + logger.info("Checking absence of feedback in show_prompt") + show_prompt = await profile.show_prompt(PROMPT) + logger.info("show_prompt response: %s", show_prompt) + assert show_prompt.count(PROMPT) == 1 + + +async def test_4130(profile, async_cursor): + """Delete EXPLAINSQL feedback by sql_id after adding negative prompt-based feedback.""" + prompt = PROMPT + action = Action.EXPLAINSQL + sql_id = EXPLAINSQL_SQL_ID + response = ( + "SELECT p.name, g.total_points FROM people p JOIN gymnast g " + "ON p.id = g.id ORDER BY g.total_points DESC" + ) + logger.info( + "Adding negative feedback before delete for prompt=%s action=%s sql_id=None response=%s feedback_content=%s", + prompt, + action, + response, + "Feedback prior to delete", + ) + await profile.add_negative_feedback( + prompt_spec=(prompt, action), + response=response, + feedback_content="Feedback prior to delete", + ) + await _log_feedback_vecindex_rows(profile) + logger.info("Deleting feedback using sql_id=%s", sql_id) + await profile.delete_feedback(sql_id=sql_id) + await _log_feedback_vecindex_rows(profile) + logger.info("Checking absence of feedback in show_prompt") + show_prompt = await profile.show_prompt(PROMPT) + logger.info("show_prompt response: %s", show_prompt) + assert show_prompt.count(PROMPT) == 1 + + +async def test_4131(profile, async_cursor): + """Delete feedback with sql_id=None and a valid prompt_spec.""" + prompt = PROMPT + action = Action.SHOWSQL + logger.info( + "Adding positive feedback before delete for prompt=%s action=%s sql_id=None", + prompt, + action, + ) + await profile.add_positive_feedback( + prompt_spec=(prompt, action), + ) + await _log_feedback_vecindex_rows(profile) + logger.info( + "Deleting feedback for prompt=%s action=%s with sql_id=None", + prompt, + action, + ) + await profile.delete_feedback( + prompt_spec=(prompt, action), + sql_id=None, + ) + await _log_feedback_vecindex_rows(profile) + logger.info("Checking absence of feedback in show_prompt") + show_prompt = await profile.show_prompt(PROMPT) + logger.info("show_prompt response: %s", show_prompt) + assert show_prompt.count(PROMPT) == 1 + + +async def test_4132(profile, async_cursor): + """Delete feedback with prompt_spec=None and a valid sql_id.""" + sql_id = SHOWSQL_SQL_ID + logger.info( + "Adding positive feedback before delete without prompt_spec using sql_id=%s", + sql_id, + ) + await profile.add_positive_feedback( + sql_id=sql_id, + ) + await _log_feedback_vecindex_rows(profile) + logger.info("Deleting feedback with prompt_spec=None using sql_id=%s", sql_id) + await profile.delete_feedback( + prompt_spec=None, + sql_id=sql_id, + ) + await _log_feedback_vecindex_rows(profile) + logger.info("Checking absence of feedback in show_prompt") + show_prompt = await profile.show_prompt(PROMPT) + logger.info("show_prompt response: %s", show_prompt) + assert show_prompt.count(PROMPT) == 1 + + +async def test_4133(profile, async_cursor): + """Attempt delete_feedback with both prompt_spec and sql_id.""" + prompt = PROMPT + action = Action.SHOWSQL + sql_id = SHOWSQL_SQL_ID + logger.info( + "Adding positive feedback before conflicting delete without prompt_spec using sql_id=%s", + sql_id, + ) + await 
profile.add_positive_feedback( + sql_id=sql_id, + ) + await _log_feedback_vecindex_rows(profile) + logger.info( + "Deleting feedback for prompt=%s action=%s sql_id=%s", + prompt, + action, + sql_id, + ) + with pytest.raises(oracledb.DatabaseError) as exc_info: + await profile.delete_feedback( + prompt_spec=(prompt, action), + sql_id=sql_id, + ) + _assert_db_error(exc_info, 6550) + logger.error("%s", str(exc_info.value).splitlines()[0]) From f9eaf64497a7f14c24279aeb152859af3ec3a5d4 Mon Sep 17 00:00:00 2001 From: Kondra Nagabhavani Date: Thu, 16 Apr 2026 10:59:30 +0530 Subject: [PATCH 4/6] Agent tests (#30) * Added sync and async api tests for Agent feature * Addressed review comments and enhanced tests * Modified agente2e tests and added one more end to end agent testcase * Added environment variables * Resolve conflict by keeping remote conftest.py --- .github/workflows/test.yaml | 7 + tests/agents/test_3001_async_tools.py | 42 +- tests/agents/test_3001_tools.py | 42 +- tests/agents/test_3101_async_tasks.py | 2 +- tests/agents/test_3201_async_agents.py | 2 +- tests/agents/test_3800_agente2e.py | 491 +++++++++++++--------- tests/agents/test_3800_async_agente2e.py | 319 ++++++++++++--- tests/agents/test_3900_async_sql_team.py | 496 +++++++++++++++++++++++ tests/agents/test_3900_sql_team.py | 482 ++++++++++++++++++++++ tests/conftest.py | 100 ++--- 10 files changed, 1658 insertions(+), 325 deletions(-) create mode 100644 tests/agents/test_3900_async_sql_team.py create mode 100644 tests/agents/test_3900_sql_team.py diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index b404a64..e0ac095 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -47,3 +47,10 @@ jobs: PYSAI_TEST_OCI_PRIVATE_KEY: ${{ secrets.PYSAI_TEST_OCI_PRIVATE_KEY }} PYSAI_TEST_OCI_FINGERPRINT: ${{ secrets.PYSAI_TEST_OCI_FINGERPRINT }} PYSAI_TEST_OCI_COMPARTMENT_ID: ${{ secrets.PYSAI_TEST_OCI_COMPARTMENT_ID }} + PYSAI_TEST_EMAIL_CRED_USERNAME: ${{ secrets.PYSAI_TEST_EMAIL_CRED_USERNAME }} + PYSAI_TEST_EMAIL_CRED_PASSWORD: ${{ secrets.PYSAI_TEST_EMAIL_CRED_PASSWORD }} + PYSAI_TEST_SLACK_USERNAME: ${{ secrets.PYSAI_TEST_SLACK_USERNAME }} + PYSAI_TEST_SLACK_PASSWORD: ${{ secrets.PYSAI_TEST_SLACK_PASSWORD }} + PYSAI_TEST_EMAIL_RECIPIENT: ${{secrets.PYSAI_TEST_EMAIL_RECIPIENT}} + PYSAI_TEST_EMAIL_SENDER: ${{secrets.PYSAI_TEST_EMAIL_SENDER}} + PYSAI_TEST_EMAIL_SMTPHOST: ${{secrets.PYSAI_TEST_EMAIL_SMTPHOST}} diff --git a/tests/agents/test_3001_async_tools.py b/tests/agents/test_3001_async_tools.py index ed760a3..5157c2d 100644 --- a/tests/agents/test_3001_async_tools.py +++ b/tests/agents/test_3001_async_tools.py @@ -58,6 +58,8 @@ DISABLED_TOOL_NAME = f"PYSAI_3001_DISABLED_TOOL_{UUID}" DEFAULT_STATUS_TOOL_NAME = f"PYSAI_3001_DEFAULT_STATUS_TOOL_{UUID}" DROP_FORCE_MISSING_TOOL = f"PYSAI_3001_DROP_MISSING_{UUID}" +HTTP_TOOL_NAME = f"PYSAI_3001_HTTP_TOOL_{UUID}" +HTTP_ENDPOINT = "https://example.com/api/tool" EMAIL_TOOL_NAME = f"PYSAI_3001_EMAIL_TOOL_{UUID}" SLACK_TOOL_NAME = f"PYSAI_3001_SLACK_TOOL_{UUID}" @@ -69,6 +71,9 @@ EMAIL_CRED_NAME = f"PYSAI_3001_EMAIL_CRED_{UUID}" SLACK_CRED_NAME = f"PYSAI_3001_SLACK_CRED_{UUID}" +EMAIL_RECIPIENT = os.getenv("PYSAI_TEST_EMAIL_RECIPIENT") +EMAIL_SENDER = os.getenv("PYSAI_TEST_EMAIL_SENDER") +EMAIL_SMTP_HOST = os.getenv("PYSAI_TEST_EMAIL_SMTPHOST") SMTP_USERNAME = os.getenv("PYSAI_TEST_EMAIL_CRED_USERNAME") SMTP_PASSWORD = os.getenv("PYSAI_TEST_EMAIL_CRED_PASSWORD") SLACK_USERNAME = os.getenv("PYSAI_TEST_SLACK_USERNAME") @@ -301,9 +306,9 @@ async def 
email_tool(email_credential): tool = await AsyncTool.create_email_notification_tool( tool_name=EMAIL_TOOL_NAME, credential_name=EMAIL_CRED_NAME, - recipient="kondra.nagabhavani@oracle.com", - sender="bharadwaj.vulugundam@oracle.com", - smtp_host="smtp.email.us-ashburn-1.oci.oraclecloud.com", + recipient=EMAIL_RECIPIENT, + sender=EMAIL_SENDER, + smtp_host=EMAIL_SMTP_HOST, description="Send email", replace=True, ) @@ -320,7 +325,7 @@ async def slack_tool(slack_credential): tool = await AsyncTool.create_slack_notification_tool( tool_name=SLACK_TOOL_NAME, credential_name=SLACK_CRED_NAME, - slack_channel="#general", + channel="#general", description="slack notification", replace=True, ) @@ -742,3 +747,32 @@ async def test_3023_drop_tool_force_false_non_existent_raises(): with pytest.raises(oracledb.Error) as exc: await tool.delete(force=False) logger.info("Received expected drop error: %s", exc.value) + + +async def test_3024_http_tool_created(email_credential): + logger.info("Creating HTTP tool: %s", HTTP_TOOL_NAME) + try: + tool = await AsyncTool.create_http_tool( + tool_name=HTTP_TOOL_NAME, + credential_name=email_credential, + endpoint=HTTP_ENDPOINT, + description="HTTP Tool", + replace=True, + ) + except oracledb.DatabaseError as e: + if "ORA-20052" in str(e): + logger.info( + "HTTP tool creation failed with expected backend-side error: %s", + e, + ) + return + raise + try: + fetched = await AsyncTool.fetch(HTTP_TOOL_NAME) + assert fetched.tool_name == HTTP_TOOL_NAME + assert fetched.attributes.tool_type == select_ai.agent.ToolType.HTTP + assert fetched.attributes.tool_params.credential_name == email_credential + assert fetched.attributes.tool_params.endpoint == HTTP_ENDPOINT + finally: + logger.info("Deleting HTTP tool: %s", HTTP_TOOL_NAME) + await tool.delete(force=True) diff --git a/tests/agents/test_3001_tools.py b/tests/agents/test_3001_tools.py index 106ee64..f72eb30 100644 --- a/tests/agents/test_3001_tools.py +++ b/tests/agents/test_3001_tools.py @@ -87,6 +87,11 @@ def get_tool_status(tool_name): DISABLED_TOOL_NAME = f"PYSAI_3001_DISABLED_TOOL_{UUID}" DEFAULT_STATUS_TOOL_NAME = f"PYSAI_3001_DEFAULT_STATUS_TOOL_{UUID}" DROP_FORCE_MISSING_TOOL = f"PYSAI_3001_DROP_MISSING_{UUID}" +HTTP_TOOL_NAME = f"PYSAI_3001_HTTP_TOOL_{UUID}" +HTTP_ENDPOINT = "https://example.com/api/tool" +EMAIL_RECIPIENT = os.getenv("PYSAI_TEST_EMAIL_RECIPIENT") +EMAIL_SENDER = os.getenv("PYSAI_TEST_EMAIL_SENDER") +EMAIL_SMTP_HOST = os.getenv("PYSAI_TEST_EMAIL_SMTPHOST") smtp_username = os.getenv("PYSAI_TEST_EMAIL_CRED_USERNAME") smtp_password = os.getenv("PYSAI_TEST_EMAIL_CRED_PASSWORD") slack_username = os.getenv("PYSAI_TEST_SLACK_USERNAME") @@ -270,9 +275,9 @@ def email_tool(email_credential): tool = select_ai.agent.Tool.create_email_notification_tool( tool_name="EMAIL_TOOL", credential_name="EMAIL_CRED", - recipient="kondra.nagabhavani@oracle.com", - sender="bharadwaj.vulugundam@oracle.com", - smtp_host="smtp.email.us-ashburn-1.oci.oraclecloud.com", + recipient=EMAIL_RECIPIENT, + sender=EMAIL_SENDER, + smtp_host=EMAIL_SMTP_HOST, description="Send email", replace=True, ) @@ -288,7 +293,7 @@ def slack_tool(slack_credential): tool = select_ai.agent.Tool.create_slack_notification_tool( tool_name="SLACK_TOOL", credential_name="SLACK_CRED", - slack_channel="#general", + channel="#general", description="slack notification", replace=True, ) @@ -656,3 +661,32 @@ def test_3023_drop_tool_force_false_non_existent_raises(): with pytest.raises(oracledb.Error) as exc: tool.delete(force=False) logger.info("Received expected 
drop error: %s", exc.value) + + +def test_3024_http_tool_created(email_credential): + logger.info("Creating HTTP tool: %s", HTTP_TOOL_NAME) + try: + tool = select_ai.agent.Tool.create_http_tool( + tool_name=HTTP_TOOL_NAME, + credential_name=email_credential, + endpoint=HTTP_ENDPOINT, + description="HTTP Tool", + replace=True, + ) + except oracledb.DatabaseError as e: + if "ORA-20052" in str(e): + logger.info( + "HTTP tool creation failed with expected backend-side error: %s", + e, + ) + return + raise + try: + fetched = select_ai.agent.Tool.fetch(HTTP_TOOL_NAME) + assert fetched.tool_name == HTTP_TOOL_NAME + assert fetched.attributes.tool_type == select_ai.agent.ToolType.HTTP + assert fetched.attributes.tool_params.credential_name == email_credential + assert fetched.attributes.tool_params.endpoint == HTTP_ENDPOINT + finally: + logger.info("Deleting HTTP tool: %s", HTTP_TOOL_NAME) + tool.delete(force=True) diff --git a/tests/agents/test_3101_async_tasks.py b/tests/agents/test_3101_async_tasks.py index 091e272..370158c 100644 --- a/tests/agents/test_3101_async_tasks.py +++ b/tests/agents/test_3101_async_tasks.py @@ -21,7 +21,7 @@ pytestmark = pytest.mark.anyio PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")) -LOG_FILE = os.path.join(PROJECT_ROOT, "log", "tkex_test_3100_async_tasks.log") +LOG_FILE = os.path.join(PROJECT_ROOT, "log", "tkex_test_3101_async_tasks.log") os.makedirs(os.path.dirname(LOG_FILE), exist_ok=True) root = logging.getLogger() diff --git a/tests/agents/test_3201_async_agents.py b/tests/agents/test_3201_async_agents.py index 58f102c..6c0d021 100644 --- a/tests/agents/test_3201_async_agents.py +++ b/tests/agents/test_3201_async_agents.py @@ -22,7 +22,7 @@ pytestmark = pytest.mark.anyio PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")) -LOG_FILE = os.path.join(PROJECT_ROOT, "log", "tkex_test_3200_async_agents.log") +LOG_FILE = os.path.join(PROJECT_ROOT, "log", "tkex_test_3201_async_agents.log") os.makedirs(os.path.dirname(LOG_FILE), exist_ok=True) root = logging.getLogger() diff --git a/tests/agents/test_3800_agente2e.py b/tests/agents/test_3800_agente2e.py index 634d21e..a3b4a30 100644 --- a/tests/agents/test_3800_agente2e.py +++ b/tests/agents/test_3800_agente2e.py @@ -48,6 +48,10 @@ logger = logging.getLogger() +EMAIL_RECIPIENT = os.getenv("PYSAI_TEST_EMAIL_RECIPIENT") +EMAIL_SENDER = os.getenv("PYSAI_TEST_EMAIL_SENDER") +EMAIL_SMTP_HOST = os.getenv("PYSAI_TEST_EMAIL_SMTPHOST") + @contextmanager def log_step(step): @@ -61,26 +65,25 @@ def log_step(step): raise -def _safe_dict(obj): - if obj is None: - return None - if hasattr(obj, "dict"): - try: - return obj.dict(exclude_null=False) - except TypeError: - return obj.dict() - return str(obj) - - def log_object_details(context: str, object_type: str, obj) -> None: details = {"context": context, "object_type": object_type} + attributes = getattr(obj, "attributes", None) if object_type == "profile": details.update( { "profile_name": getattr(obj, "profile_name", None), "description": getattr(obj, "description", None), - "attributes": _safe_dict(getattr(obj, "attributes", None)), + "provider_type": ( + type(getattr(attributes, "provider", None)).__name__ + if attributes is not None and getattr(attributes, "provider", None) + else None + ), + "object_count": ( + len(getattr(attributes, "object_list", []) or []) + if attributes is not None + else None + ), } ) elif object_type == "agent": @@ -88,7 +91,16 @@ def log_object_details(context: str, object_type: str, obj) -> 
None: { "agent_name": getattr(obj, "agent_name", None), "description": getattr(obj, "description", None), - "attributes": _safe_dict(getattr(obj, "attributes", None)), + "profile_name": ( + getattr(attributes, "profile_name", None) + if attributes is not None + else None + ), + "enable_human_tool": ( + getattr(attributes, "enable_human_tool", None) + if attributes is not None + else None + ), } ) elif object_type == "tool": @@ -96,7 +108,11 @@ def log_object_details(context: str, object_type: str, obj) -> None: { "tool_name": getattr(obj, "tool_name", None), "description": getattr(obj, "description", None), - "attributes": _safe_dict(getattr(obj, "attributes", None)), + "tool_type": ( + getattr(attributes, "tool_type", None) + if attributes is not None + else None + ), } ) elif object_type == "task": @@ -104,7 +120,16 @@ def log_object_details(context: str, object_type: str, obj) -> None: { "task_name": getattr(obj, "task_name", None), "description": getattr(obj, "description", None), - "attributes": _safe_dict(getattr(obj, "attributes", None)), + "tool_count": ( + len(getattr(attributes, "tools", []) or []) + if attributes is not None + else None + ), + "enable_human_tool": ( + getattr(attributes, "enable_human_tool", None) + if attributes is not None + else None + ), } ) elif object_type == "team": @@ -112,14 +137,62 @@ def log_object_details(context: str, object_type: str, obj) -> None: { "team_name": getattr(obj, "team_name", None), "description": getattr(obj, "description", None), - "attributes": _safe_dict(getattr(obj, "attributes", None)), + "process": ( + getattr(attributes, "process", None) + if attributes is not None + else None + ), + "agent_count": ( + len(getattr(attributes, "agents", []) or []) + if attributes is not None + else None + ), } ) else: details["repr"] = str(obj) logger.info("OBJECT_DETAILS: %s", details) - print("OBJECT_DETAILS:", details) + + +def log_credential_setup(credential_name): + logger.info("Preparing credential | name=%s", credential_name) + + +def verify_credential_exists(credential_name, expected_username=None): + logger.info("Verifying credential exists in DB: %s", credential_name) + + with select_ai.cursor() as cur: + cur.execute( + """ + SELECT credential_name, username + FROM user_credentials + WHERE UPPER(credential_name) = UPPER(:credential_name) + """, + credential_name=credential_name, + ) + row = cur.fetchone() + + assert row is not None, f"Credential {credential_name} was not created" + + actual_name, actual_username = row + logger.info("Verified credential | name=%s", actual_name) + assert actual_name.upper() == credential_name.upper() + if expected_username is not None: + assert actual_username == expected_username + + +def _decode_history_rows(rows): + decoded_rows = [] + for row in rows: + decoded_row = [] + for value in row: + if hasattr(value, "read"): + decoded_row.append(value.read()) + else: + decoded_row.append(value) + decoded_rows.append(tuple(decoded_row)) + return decoded_rows @pytest.fixture(scope="session") @@ -154,17 +227,21 @@ def setup_test_user(test_env): def openai_cred(): api_key = os.getenv("PYSAI_TEST_OPENAI_API_KEY") assert api_key, "PYSAI_TEST_OPENAI_API_KEY not set" + cred_name = "OPENAI_CRED" + + log_credential_setup(cred_name) select_ai.create_credential( credential={ - "credential_name": "OPENAI_CRED", + "credential_name": cred_name, "username": "openai", "password": api_key, }, replace=True, ) - return "OPENAI_CRED" + verify_credential_exists(cred_name, expected_username="openai") + return cred_name 
@pytest.fixture(scope="session") @@ -174,17 +251,21 @@ def email_cred(): assert smtp_username, "PYSAI_TEST_EMAIL_CRED_USERNAME not set" assert smtp_password, "PYSAI_TEST_EMAIL_CRED_PASSWORD not set" + cred_name = "EMAIL_CRED" + + log_credential_setup(cred_name) select_ai.create_credential( credential={ - "credential_name": "EMAIL_CRED", + "credential_name": cred_name, "username": smtp_username, "password": smtp_password, }, replace=True, ) - return "EMAIL_CRED" + verify_credential_exists(cred_name, expected_username=smtp_username) + return cred_name @pytest.fixture(scope="session") @@ -219,10 +300,7 @@ def append_ace(host, privileges): return raise - append_ace( - "smtp.email.us-ashburn-1.oci.oraclecloud.com", - ["connect", "smtp"], - ) + append_ace(EMAIL_SMTP_HOST, ["connect", "smtp"]) for host in ["api.openai.com", "a.co","amazon.in"]: append_ace(host, ["connect", "http"]) @@ -246,6 +324,18 @@ def test_3800_agent_end_to_end( # PROFILE # ------------------------------- logger.info("Starting End-to-End Agent Test") + logger.info( + "Resolved credential fixtures | openai=%s | email=%s", + openai_cred, + email_cred, + ) + created = { + "team": None, + "task": None, + "tools": [], + "agent": None, + "profile": None, + } # ---------------- PROFILE ---------------- @@ -254,183 +344,202 @@ def test_3800_agent_end_to_end( profile_attributes.provider.oci_compartment_id = oci_compartment_id - profile = select_ai.Profile( - profile_name="GEN1_PROFILE", - attributes=profile_attributes, - replace=True, - ) - log_object_details("create_profile", "profile", profile) - - # ------------------------------- - # AGENT - # ------------------------------- - with log_step("Create agent"): - agent = Agent( - agent_name="CustomerAgent", - attributes=AgentAttributes( - profile_name="GEN1_PROFILE", - role="You are an experienced customer agent handling returns.", - enable_human_tool=True, - ), - ) - agent.create(replace=True) - log_object_details("create_agent", "agent", agent) - - assert agent.agent_name == "CustomerAgent" - - # ------------------------------- - # TOOLS - # ------------------------------- - with log_step("Create tools"): - - # Human tool - Tool.create_built_in_tool( - tool_name="Human", - description="Human intervention tool", - tool_type="HUMAN", - tool_params=ToolParams(), + try: + profile = select_ai.Profile( + profile_name="GEN1_PROFILE", + attributes=profile_attributes, replace=True, ) - - websearch_tool = Tool( - tool_name="Websearch", - attributes=ToolAttributes( - tool_type="WEBSEARCH", - instruction="Use this tool to find the current price of a product from a URL.", - tool_params=ToolParams( - credential_name="OPENAI_CRED" + created["profile"] = profile + log_object_details("create_profile", "profile", profile) + + # ------------------------------- + # AGENT + # ------------------------------- + with log_step("Create agent"): + agent = Agent( + agent_name="CustomerAgent", + attributes=AgentAttributes( + profile_name="GEN1_PROFILE", + role="You are an experienced customer agent handling returns.", + enable_human_tool=False, ), - ), - ) - websearch_tool.create(replace=True) - log_object_details("create_websearch_tool", "tool", websearch_tool) - - # Email notification tool - email_recipient = os.getenv("PYSAI_TEST_EMAIL_RECIPIENT") - email_sender = os.getenv("PYSAI_TEST_EMAIL_SENDER") - assert email_recipient, "PYSAI_TEST_EMAIL_RECIPIENT not set" - assert email_sender, "PYSAI_TEST_EMAIL_SENDER not set" - email_tool = Tool( - tool_name="Email", - attributes=ToolAttributes( - 
tool_type="NOTIFICATION", - tool_params=ToolParams( - credential_name="EMAIL_CRED", - notification_type="EMAIL", - recipient=email_recipient, - sender=email_sender, - smtp_host="smtp.email.us-ashburn-1.oci.oraclecloud.com", + ) + agent.create(replace=True) + created["agent"] = agent + log_object_details("create_agent", "agent", agent) + + assert agent.agent_name == "CustomerAgent" + assert agent.attributes.enable_human_tool is False + + # ------------------------------- + # TOOLS + # ------------------------------- + with log_step("Create tools"): + websearch_tool = Tool( + tool_name="Websearch", + attributes=ToolAttributes( + tool_type="WEBSEARCH", + instruction="Use this tool to find the current price of a product from a URL.", + tool_params=ToolParams( + credential_name=openai_cred + ), ), - ), - ) - email_tool.create(replace=True) - log_object_details("create_email_tool", "tool", email_tool) - - assert Tool("Human") is not None - assert Tool("Email") is not None - - # ------------------------------- - # TASK - # ------------------------------- - with log_step("Create task"): - task = Task( - task_name="Return_And_Price_Match", - attributes=TaskAttributes( - instruction=( - "Process a product return request from a customer. " - "1. Ask customer the reason for return (price match or defective). " - "2. If price match: " - " a. Request customer to provide a price match link. " - " b. Use websearch tool to get the price for that price match link" - " c. Ask customer if they want a refund and specify how much refund. " - " d. Send email notification only if customer accepts the refund. " - "3. If defective: " - " a. Process the defective return." + ) + websearch_tool.create(replace=True) + created["tools"].append(websearch_tool) + log_object_details("create_websearch_tool", "tool", websearch_tool) + fetched_websearch_tool = Tool.fetch("Websearch") + logger.info( + "Verified fetched websearch tool credential | tool=%s | credential=%s", + fetched_websearch_tool.tool_name, + fetched_websearch_tool.attributes.tool_params.credential_name, + ) + assert fetched_websearch_tool.attributes.tool_params.credential_name == openai_cred + + # Email notification tool + email_tool = Tool( + tool_name="Email", + attributes=ToolAttributes( + tool_type="NOTIFICATION", + tool_params=ToolParams( + credential_name=email_cred, + notification_type="EMAIL", + recipient=EMAIL_RECIPIENT, + sender=EMAIL_SENDER, + smtp_host=EMAIL_SMTP_HOST, + ), ), - tools=["Human", "Websearch", "Email"], - ), - ) - task.create(replace=True) - log_object_details("create_task", "task", task) + ) + email_tool.create(replace=True) + created["tools"].append(email_tool) + log_object_details("create_email_tool", "tool", email_tool) + fetched_email_tool = Tool.fetch("Email") + logger.info( + "Verified fetched email tool credential | tool=%s | credential=%s", + fetched_email_tool.tool_name, + fetched_email_tool.attributes.tool_params.credential_name, + ) + assert fetched_email_tool.attributes.tool_params.credential_name == email_cred + + assert Tool("Email") is not None + assert websearch_tool.attributes.tool_params.credential_name == openai_cred + assert email_tool.attributes.tool_params.credential_name == email_cred + + # ------------------------------- + # TASK + # ------------------------------- + with log_step("Create task"): + task = Task( + task_name="Return_And_Price_Match", + attributes=TaskAttributes( + instruction=( + "Process a product return request from a customer. " + "1. 
Ask customer the reason for return (price match or defective). "
+                        "2. If price match: "
+                        "   a. Request customer to provide a price match link. "
+                        "   b. Use websearch tool to get the price for that price match link"
+                        "   c. Ask customer if they want a refund and specify how much refund. "
+                        "   d. Send email notification only if customer accepts the refund. "
+                        "3. If defective: "
+                        "   a. Process the defective return."
+                    ),
+                    tools=["Websearch", "Email"],
+                    enable_human_tool=False,
+                ),
+            )
+            task.create(replace=True)
+            created["task"] = task
+            log_object_details("create_task", "task", task)
+
+            assert task.task_name == "Return_And_Price_Match"
+            assert set(task.attributes.tools) == {"Websearch", "Email"}
+            assert task.attributes.enable_human_tool is False
 
-    assert task.task_name == "Return_And_Price_Match"
-    assert set(task.attributes.tools) == {"Human", "Websearch", "Email"}
+
+        # -------------------------------
+        # TEAM
+        # -------------------------------
+        with log_step("Create team"):
+            team = Team(
+                team_name="ReturnAgency",
+                attributes=TeamAttributes(
+                    agents=[{
+                        "name": "CustomerAgent",
+                        "task": "Return_And_Price_Match",
+                    }],
+                    process="sequential",
+                ),
+            )
+            team.create(enabled=True, replace=True)
+            created["team"] = team
+            log_object_details("create_team", "team", team)
+
+            assert team.team_name == "ReturnAgency"
+
+        # -------------------------------
+        # RUN CONVERSATION
+        # -------------------------------
+        with log_step("Run agent conversation"):
+            conversation_id = str(uuid.uuid4())
+
+            prompts = [
+                "I want to return an office chair",
+                "The price when I bought it is 100. But I found a cheaper price",
+                "Here is the price match link 'https://www.ikea.com/us/en/p/stefan-chair-brown-black-00211088/'",
+                "Yes, I would like to proceed with a refund",
+            ]
+
+            for idx, prompt in enumerate(prompts, start=1):
+                logger.info("USER %d: %s", idx, prompt)
+
+                response = team.run(
+                    prompt=prompt,
+                    params={"conversation_id": conversation_id},
+                )
 
-    assert task.task_name == "Return_And_Price_Match"
-    # Corrected assert to match the 3 tools
-    assert set(task.attributes.tools) == {"Human", "Websearch", "Email"}
-
-    # -------------------------------
-    # TEAM
-    # -------------------------------
-    with log_step("Create team"):
-        team = Team(
-            team_name="ReturnAgency",
-            attributes=TeamAttributes(
-                agents=[{
-                    "name": "CustomerAgent",
-                    "task": "Return_And_Price_Match",
-                }],
-                process="sequential",
-            ),
-        )
-        team.create(enabled=True, replace=True)
-        log_object_details("create_team", "team", team)
 
-    assert team.team_name == "ReturnAgency"
 
-    # -------------------------------
-    # RUN CONVERSATION
-    # -------------------------------
-    with log_step("Run agent conversation"):
-        conversation_id = str(uuid.uuid4())
-
-        prompts = [
-            "I want to return an office chair",
-            "The price when I bought it is 100. 
But I found a cheaper price", - "Here is the price match link 'https://www.ikea.com/us/en/p/stefan-chair-brown-black-00211088/'", - "Yes, I would like to proceed with a refund", - ] - - for idx, prompt in enumerate(prompts, start=1): - logger.info("USER %d: %s", idx, prompt) - - response = team.run( - prompt=prompt, - params={"conversation_id": conversation_id}, - ) + if isinstance(response, dict): + assert response - # ---- PRINT + LOG RESPONSE ---- - print(f"\nAGENT RESPONSE {idx}:\n{response}\n") - logger.info("AGENT RESPONSE %d: %s", idx, response) + with select_ai.cursor() as cur: + cur.execute( + """ + SELECT * FROM user_ai_agent_tool_history + """ + ) + tool_history = cur.fetchall() - assert response is not None - assert isinstance(response, (str, dict)) + decoded_tool_history = _decode_history_rows(tool_history) - if isinstance(response, dict): - assert response + logger.info("Tool history rows fetched: %d", len(decoded_tool_history)) - with select_ai.cursor() as cur: - cur.execute( - """ - SELECT * FROM user_ai_agent_tool_history - """ - ) - tool_history = cur.fetchall() - - decoded_tool_history = [] - for row in tool_history: - decoded_row = [] - for value in row: - if hasattr(value, "read"): - decoded_row.append(value.read()) - else: - decoded_row.append(value) - decoded_tool_history.append(tuple(decoded_row)) - - print(decoded_tool_history) - logger.info("Tool history rows fetched: %d", len(decoded_tool_history)) - for row in decoded_tool_history: - logger.info("TOOL_HISTORY_ROW: %s", row) - - assert decoded_tool_history + assert decoded_tool_history + finally: + with log_step("Cleanup sync e2e objects"): + if created["team"] is not None: + logger.info("Deleting team: %s", created["team"].team_name) + created["team"].delete(force=True) + + if created["task"] is not None: + logger.info("Deleting task: %s", created["task"].task_name) + created["task"].delete(force=True) + + for tool in reversed(created["tools"]): + logger.info("Deleting tool: %s", tool.tool_name) + tool.delete(force=True) + + if created["agent"] is not None: + logger.info("Deleting agent: %s", created["agent"].agent_name) + created["agent"].delete(force=True) + + if created["profile"] is not None: + logger.info("Deleting profile: %s", created["profile"].profile_name) + created["profile"].delete(force=True) diff --git a/tests/agents/test_3800_async_agente2e.py b/tests/agents/test_3800_async_agente2e.py index 191f777..af60c6b 100644 --- a/tests/agents/test_3800_async_agente2e.py +++ b/tests/agents/test_3800_async_agente2e.py @@ -9,6 +9,7 @@ 3800 - Async end-to-end Select AI Agent integration test """ +import inspect import logging import os import time @@ -50,6 +51,10 @@ logger = logging.getLogger(__name__) +EMAIL_RECIPIENT = os.getenv("PYSAI_TEST_EMAIL_RECIPIENT") +EMAIL_SENDER = os.getenv("PYSAI_TEST_EMAIL_SENDER") +EMAIL_SMTP_HOST = os.getenv("PYSAI_TEST_EMAIL_SMTPHOST") + @contextmanager def log_step(step): @@ -63,26 +68,25 @@ def log_step(step): raise -def _safe_dict(obj): - if obj is None: - return None - if hasattr(obj, "dict"): - try: - return obj.dict(exclude_null=False) - except TypeError: - return obj.dict() - return str(obj) - - def log_object_details(context: str, object_type: str, obj) -> None: details = {"context": context, "object_type": object_type} + attributes = getattr(obj, "attributes", None) if object_type == "profile": details.update( { "profile_name": getattr(obj, "profile_name", None), "description": getattr(obj, "description", None), - "attributes": _safe_dict(getattr(obj, "attributes", 
None)), + "provider_type": ( + type(getattr(attributes, "provider", None)).__name__ + if attributes is not None and getattr(attributes, "provider", None) + else None + ), + "object_count": ( + len(getattr(attributes, "object_list", []) or []) + if attributes is not None + else None + ), } ) elif object_type == "agent": @@ -90,7 +94,16 @@ def log_object_details(context: str, object_type: str, obj) -> None: { "agent_name": getattr(obj, "agent_name", None), "description": getattr(obj, "description", None), - "attributes": _safe_dict(getattr(obj, "attributes", None)), + "profile_name": ( + getattr(attributes, "profile_name", None) + if attributes is not None + else None + ), + "enable_human_tool": ( + getattr(attributes, "enable_human_tool", None) + if attributes is not None + else None + ), } ) elif object_type == "tool": @@ -98,7 +111,11 @@ def log_object_details(context: str, object_type: str, obj) -> None: { "tool_name": getattr(obj, "tool_name", None), "description": getattr(obj, "description", None), - "attributes": _safe_dict(getattr(obj, "attributes", None)), + "tool_type": ( + getattr(attributes, "tool_type", None) + if attributes is not None + else None + ), } ) elif object_type == "task": @@ -106,7 +123,16 @@ def log_object_details(context: str, object_type: str, obj) -> None: { "task_name": getattr(obj, "task_name", None), "description": getattr(obj, "description", None), - "attributes": _safe_dict(getattr(obj, "attributes", None)), + "tool_count": ( + len(getattr(attributes, "tools", []) or []) + if attributes is not None + else None + ), + "enable_human_tool": ( + getattr(attributes, "enable_human_tool", None) + if attributes is not None + else None + ), } ) elif object_type == "team": @@ -114,37 +140,204 @@ def log_object_details(context: str, object_type: str, obj) -> None: { "team_name": getattr(obj, "team_name", None), "description": getattr(obj, "description", None), - "attributes": _safe_dict(getattr(obj, "attributes", None)), + "process": ( + getattr(attributes, "process", None) + if attributes is not None + else None + ), + "agent_count": ( + len(getattr(attributes, "agents", []) or []) + if attributes is not None + else None + ), } ) else: details["repr"] = str(obj) logger.info("OBJECT_DETAILS: %s", details) - print("OBJECT_DETAILS:", details) + + +def log_credential_setup(credential_name): + logger.info("Preparing credential | name=%s", credential_name) + + +def verify_credential_exists(credential_name, expected_username=None): + logger.info("Verifying credential exists in DB: %s", credential_name) + + with select_ai.cursor() as cur: + cur.execute( + """ + SELECT credential_name, username + FROM user_credentials + WHERE UPPER(credential_name) = UPPER(:credential_name) + """, + credential_name=credential_name, + ) + row = cur.fetchone() + + assert row is not None, f"Credential {credential_name} was not created" + + actual_name, actual_username = row + logger.info("Verified credential | name=%s", actual_name) + assert actual_name.upper() == credential_name.upper() + if expected_username is not None: + assert actual_username == expected_username + + +async def _decode_history_rows(rows): + decoded_rows = [] + for row in rows: + decoded_row = [] + for value in row: + if hasattr(value, "read"): + lob_value = value.read() + if inspect.isawaitable(lob_value): + lob_value = await lob_value + decoded_row.append(lob_value) + else: + decoded_row.append(value) + decoded_rows.append(tuple(decoded_row)) + return decoded_rows + + +@pytest.fixture(scope="session") +def 
setup_test_user(test_env): + try: + select_ai.disconnect() + except Exception: + pass + + select_ai.connect(**test_env.connect_params(admin=True)) + try: + try: + select_ai.grant_privileges(users=[test_env.test_user]) + except Exception as exc: + msg = str(exc) + if ( + "ORA-01749" not in msg + and "Cannot GRANT or REVOKE privileges to or from yourself" not in msg + ): + raise + + select_ai.grant_http_access( + users=[test_env.test_user], + provider_endpoint=select_ai.OpenAIProvider.provider_endpoint, + ) + finally: + select_ai.disconnect() + select_ai.connect(**test_env.connect_params()) + + +@pytest.fixture(scope="session") +def openai_cred(): + api_key = os.getenv("PYSAI_TEST_OPENAI_API_KEY") + assert api_key, "PYSAI_TEST_OPENAI_API_KEY not set" + cred_name = "OPENAI_CRED" + + log_credential_setup(cred_name) + + select_ai.create_credential( + credential={ + "credential_name": cred_name, + "username": "openai", + "password": api_key, + }, + replace=True, + ) + + verify_credential_exists(cred_name, expected_username="openai") + return cred_name + + +@pytest.fixture(scope="session") +def email_cred(): + smtp_username = os.getenv("PYSAI_TEST_EMAIL_CRED_USERNAME") + smtp_password = os.getenv("PYSAI_TEST_EMAIL_CRED_PASSWORD") + + assert smtp_username, "PYSAI_TEST_EMAIL_CRED_USERNAME not set" + assert smtp_password, "PYSAI_TEST_EMAIL_CRED_PASSWORD not set" + cred_name = "EMAIL_CRED" + + log_credential_setup(cred_name) + + select_ai.create_credential( + credential={ + "credential_name": cred_name, + "username": smtp_username, + "password": smtp_password, + }, + replace=True, + ) + + verify_credential_exists(cred_name, expected_username=smtp_username) + return cred_name + + +@pytest.fixture(scope="session") +def allow_network_acl(): + with select_ai.cursor() as cur: + cur.execute("SELECT USER FROM dual") + db_user = cur.fetchone()[0] + + def append_ace(host, privileges): + try: + cur.execute( + f""" + BEGIN + DBMS_NETWORK_ACL_ADMIN.APPEND_HOST_ACE( + host => '{host}', + ace => xs$ace_type( + privilege_list => xs$name_list({','.join([f"'{p}'" for p in privileges])}), + principal_name => '{db_user}', + principal_type => xs_acl.ptype_db + ) + ); + END; + """ + ) + except Exception as exc: + msg = str(exc) + if ( + "ORA-46212" in msg + or "ORA-46313" in msg + or "already exists" in msg + ): + return + raise + + append_ace(EMAIL_SMTP_HOST, ["connect", "smtp"]) + + for host in ["api.openai.com", "a.co", "amazon.in"]: + append_ace(host, ["connect", "http"]) + + yield @pytest.fixture(scope="module", autouse=True) -async def async_connect(test_env): +async def async_connect( + test_env, setup_test_user, openai_cred, email_cred, allow_network_acl +): logger.info( - "Opening async admin database connection | user=%s | dsn=%s", - test_env.admin_user, + "Opening async database connection | user=%s | dsn=%s", + test_env.test_user, test_env.connect_string, ) - await select_ai.async_connect(**test_env.connect_params(admin=True)) + await select_ai.async_connect(**test_env.connect_params()) yield - logger.info("Closing async admin database connection") + logger.info("Closing async database connection") await select_ai.async_disconnect() -async def test_3800_agent_end_to_end_async(profile_attributes): +async def test_3800_agent_end_to_end_async( + profile_attributes, openai_cred, email_cred +): """End-to-end Select AI Agent integration test (async).""" run_id = uuid.uuid4().hex.upper() profile_name = f"GEN1_PROFILE_{run_id}" agent_name = f"CustomerAgent_{run_id}" - human_tool_name = f"Human_{run_id}" 
websearch_tool_name = f"Websearch_{run_id}" email_tool_name = f"Email_{run_id}" task_name = f"Return_And_Price_Match_{run_id}" @@ -166,6 +359,11 @@ async def test_3800_agent_end_to_end_async(profile_attributes): task_name, team_name, ) + logger.info( + "Resolved credential fixtures | openai=%s | email=%s", + openai_cred, + email_cred, + ) oci_compartment_id = os.getenv("PYSAI_TEST_OCI_COMPARTMENT_ID") assert oci_compartment_id, "PYSAI_TEST_OCI_COMPARTMENT_ID not set" @@ -191,7 +389,7 @@ async def test_3800_agent_end_to_end_async(profile_attributes): role=( "You are an experienced customer agent handling returns." ), - enable_human_tool=True, + enable_human_tool=False, ), ) await agent.create(replace=True) @@ -199,17 +397,9 @@ async def test_3800_agent_end_to_end_async(profile_attributes): logger.info("Created agent: %s", agent.agent_name) log_object_details("create_agent", "agent", agent) assert agent.agent_name == agent_name + assert agent.attributes.enable_human_tool is False with log_step("Create tools"): - human_tool = await AsyncTool.create_built_in_tool( - tool_name=human_tool_name, - description="Human intervention tool", - tool_type=select_ai.agent.ToolType.HUMAN, - tool_params=ToolParams(), - replace=True, - ) - created["tools"].append(human_tool) - websearch_tool = AsyncTool( tool_name=websearch_tool_name, attributes=ToolAttributes( @@ -217,40 +407,57 @@ async def test_3800_agent_end_to_end_async(profile_attributes): instruction=( "Use this tool to find current product price from a URL." ), - tool_params=ToolParams(credential_name="OPENAI_CRED"), + tool_params=ToolParams(credential_name=openai_cred), ), ) await websearch_tool.create(replace=True) created["tools"].append(websearch_tool) log_object_details("create_websearch_tool", "tool", websearch_tool) + fetched_websearch_tool = await AsyncTool.fetch(websearch_tool_name) + logger.info( + "Verified fetched websearch tool credential | tool=%s | credential=%s", + fetched_websearch_tool.tool_name, + fetched_websearch_tool.attributes.tool_params.credential_name, + ) + assert ( + fetched_websearch_tool.attributes.tool_params.credential_name + == openai_cred + ) - email_recipient = os.getenv("PYSAI_TEST_EMAIL_RECIPIENT") - email_sender = os.getenv("PYSAI_TEST_EMAIL_SENDER") - assert email_recipient, "PYSAI_TEST_EMAIL_RECIPIENT not set" - assert email_sender, "PYSAI_TEST_EMAIL_SENDER not set" email_tool = AsyncTool( tool_name=email_tool_name, attributes=ToolAttributes( tool_type=select_ai.agent.ToolType.NOTIFICATION, tool_params=ToolParams( - credential_name="EMAIL_CRED", + credential_name=email_cred, notification_type="EMAIL", - recipient=email_recipient, - sender=email_sender, - smtp_host="smtp.email.us-ashburn-1.oci.oraclecloud.com", + recipient=EMAIL_RECIPIENT, + sender=EMAIL_SENDER, + smtp_host=EMAIL_SMTP_HOST, ), ), ) await email_tool.create(replace=True) created["tools"].append(email_tool) log_object_details("create_email_tool", "tool", email_tool) + fetched_email_tool = await AsyncTool.fetch(email_tool_name) + logger.info( + "Verified fetched email tool credential | tool=%s | credential=%s", + fetched_email_tool.tool_name, + fetched_email_tool.attributes.tool_params.credential_name, + ) + assert ( + fetched_email_tool.attributes.tool_params.credential_name + == email_cred + ) logger.info( "Created tools: %s", [t.tool_name for t in created["tools"]], ) - log_object_details("create_human_tool", "tool", human_tool) - assert len(created["tools"]) == 3 + assert len(created["tools"]) == 2 + assert 
websearch_tool.attributes.tool_params.credential_name == openai_cred + assert email_tool.attributes.tool_params.credential_name == email_cred with log_step("Create task"): task = AsyncTask( @@ -263,7 +470,8 @@ async def test_3800_agent_end_to_end_async(profile_attributes): "send email only if accepted. " "3. If defective: process defective return." ), - tools=[human_tool_name, websearch_tool_name, email_tool_name], + tools=[websearch_tool_name, email_tool_name], + enable_human_tool=False, ), ) await task.create(replace=True) @@ -273,10 +481,10 @@ async def test_3800_agent_end_to_end_async(profile_attributes): log_object_details("create_task", "task", task) assert task.task_name == task_name assert set(task.attributes.tools) == { - human_tool_name, websearch_tool_name, email_tool_name, } + assert task.attributes.enable_human_tool is False with log_step("Create team"): team = AsyncTeam( @@ -322,6 +530,23 @@ async def test_3800_agent_end_to_end_async(profile_attributes): assert isinstance(response, str) assert len(response.strip()) > 0 + async with select_ai.async_cursor() as cur: + await cur.execute( + """ + SELECT * FROM user_ai_agent_tool_history + """ + ) + tool_history = await cur.fetchall() + + decoded_tool_history = await _decode_history_rows(tool_history) + + logger.info( + "Async tool history rows fetched: %d", + len(decoded_tool_history), + ) + + assert decoded_tool_history + finally: with log_step("Cleanup async e2e objects"): if created["team"] is not None: diff --git a/tests/agents/test_3900_async_sql_team.py b/tests/agents/test_3900_async_sql_team.py new file mode 100644 index 0000000..3a4b012 --- /dev/null +++ b/tests/agents/test_3900_async_sql_team.py @@ -0,0 +1,496 @@ +import os +import logging +import uuid +from contextlib import contextmanager + +import pytest +import select_ai +import select_ai.agent +from select_ai.agent import AsyncAgent, AsyncTask, AsyncTeam, AsyncTool + +pytestmark = pytest.mark.anyio + +# Configure file-based logging for this script run. 
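+# Replacing the root logger's handlers with a single FileHandler keeps each
+# run's output self-contained in log/test_3900_async_sql_team.log. This is a
+# deliberately invasive choice: pytest's own --log-file option would capture
+# the same output without dropping handlers installed by other modules in the
+# same session.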
+PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")) +LOG_FILE = os.path.join(PROJECT_ROOT, "log", "test_3900_async_sql_team.log") +os.makedirs(os.path.dirname(LOG_FILE), exist_ok=True) + +root_logger = logging.getLogger() +root_logger.setLevel(logging.INFO) +for handler in root_logger.handlers[:]: + root_logger.removeHandler(handler) +file_handler = logging.FileHandler(LOG_FILE, mode="w") +file_handler.setFormatter(logging.Formatter("%(levelname)s: %(message)s")) +root_logger.addHandler(file_handler) + +logger = logging.getLogger(__name__) + +RUN_ID = uuid.uuid4().hex.upper() +EMAIL_RECIPIENT = os.getenv("PYSAI_TEST_EMAIL_RECIPIENT") +EMAIL_SENDER = os.getenv("PYSAI_TEST_EMAIL_SENDER") +EMAIL_SMTP_HOST = os.getenv("PYSAI_TEST_EMAIL_SMTPHOST") +SQL_PROFILE_NAME = f"ASYNC_SQL_PROFILE_{RUN_ID}" +SQL_TOOL_NAME = f"ASYNC_SQL_QUERY_TOOL_{RUN_ID}" +EMAIL_TOOL_NAME = f"ASYNC_EMAIL_NOTIFICATION_TOOL_{RUN_ID}" +SQL_TASK_NAME = f"ASYNC_SQL_ANALYSIS_TASK_{RUN_ID}" +SQL_AGENT_NAME = f"ASYNC_SQL_ANALYST_AGENT_{RUN_ID}" +SQL_TEAM_NAME = f"ASYNC_SQL_DATA_TEAM_{RUN_ID}" +OCI_CREDENTIAL_NAME = f"ASYNC_SQL_TEAM_OCI_CRED_{RUN_ID}" + + +@pytest.fixture(autouse=True) +def log_test_name(request): + logger.info("--- Starting test: %s ---", request.function.__name__) + yield + logger.info("--- Finished test: %s ---", request.function.__name__) + + +@contextmanager +def log_step(step): + logger.info("START: %s", step) + try: + yield + logger.info("END: %s", step) + except Exception: + logger.exception("FAILED: %s", step) + raise + + +def log_object_details(context: str, object_type: str, obj) -> None: + details = {"context": context, "object_type": object_type} + attributes = getattr(obj, "attributes", None) + + if object_type == "profile": + details.update( + { + "profile_name": getattr(obj, "profile_name", None), + "description": getattr(obj, "description", None), + "provider_type": ( + type(getattr(attributes, "provider", None)).__name__ + if attributes is not None and getattr(attributes, "provider", None) + else None + ), + "object_count": ( + len(getattr(attributes, "object_list", []) or []) + if attributes is not None + else None + ), + } + ) + elif object_type == "agent": + details.update( + { + "agent_name": getattr(obj, "agent_name", None), + "description": getattr(obj, "description", None), + "profile_name": ( + getattr(attributes, "profile_name", None) + if attributes is not None + else None + ), + "enable_human_tool": ( + getattr(attributes, "enable_human_tool", None) + if attributes is not None + else None + ), + } + ) + elif object_type == "tool": + details.update( + { + "tool_name": getattr(obj, "tool_name", None), + "description": getattr(obj, "description", None), + "tool_type": ( + getattr(attributes, "tool_type", None) + if attributes is not None + else None + ), + } + ) + elif object_type == "task": + details.update( + { + "task_name": getattr(obj, "task_name", None), + "description": getattr(obj, "description", None), + "tool_count": ( + len(getattr(attributes, "tools", []) or []) + if attributes is not None + else None + ), + "enable_human_tool": ( + getattr(attributes, "enable_human_tool", None) + if attributes is not None + else None + ), + } + ) + elif object_type == "team": + details.update( + { + "team_name": getattr(obj, "team_name", None), + "description": getattr(obj, "description", None), + "process": ( + getattr(attributes, "process", None) + if attributes is not None + else None + ), + "agent_count": ( + len(getattr(attributes, "agents", []) or []) + if 
attributes is not None + else None + ), + } + ) + else: + details["repr"] = str(obj) + + logger.info("OBJECT_DETAILS: %s", details) + + +def log_credential_setup(credential_name): + logger.info("Preparing credential | name=%s", credential_name) + + +async def verify_credential_exists(credential_name, expected_username=None): + logger.info("Verifying credential exists in DB: %s", credential_name) + async with select_ai.async_cursor() as cur: + await cur.execute( + """ + SELECT credential_name, username + FROM user_credentials + WHERE UPPER(credential_name) = UPPER(:credential_name) + """, + credential_name=credential_name, + ) + row = await cur.fetchone() + + assert row is not None, f"Credential {credential_name} was not created" + + actual_name, actual_username = row + logger.info("Verified credential | name=%s", actual_name) + assert actual_name.upper() == credential_name.upper() + if expected_username is not None: + assert actual_username == expected_username + + +async def connect_to_db(): + # Connect the Python client to the test database. + user = os.getenv("PYSAI_TEST_USER") + password = os.getenv("PYSAI_TEST_USER_PASSWORD") + dsn = os.getenv("PYSAI_TEST_CONNECT_STRING") + assert user, "PYSAI_TEST_USER not set" + assert password, "PYSAI_TEST_USER_PASSWORD not set" + assert dsn, "PYSAI_TEST_CONNECT_STRING not set" + logger.info("Connecting to database using configured test credentials") + await select_ai.async_connect(user=user, password=password, dsn=dsn) + + +async def _cleanup_async_sql_team_objects(created) -> None: + with log_step("Cleanup async SQL team objects"): + if created["team"] is not None: + try: + logger.info("Deleting team: %s", created["team"].team_name) + await created["team"].delete(force=True) + except Exception: + logger.exception("Failed to delete team: %s", created["team"].team_name) + + if created["task"] is not None: + try: + logger.info("Deleting task: %s", created["task"].task_name) + await created["task"].delete(force=True) + except Exception: + logger.exception("Failed to delete task: %s", created["task"].task_name) + + for tool in reversed(created["tools"]): + try: + logger.info("Deleting tool: %s", tool.tool_name) + await tool.delete(force=True) + except Exception: + logger.exception("Failed to delete tool: %s", tool.tool_name) + + if created["agent"] is not None: + try: + logger.info("Deleting agent: %s", created["agent"].agent_name) + await created["agent"].delete(force=True) + except Exception: + logger.exception( + "Failed to delete agent: %s", created["agent"].agent_name + ) + + if created["profile"] is not None: + try: + logger.info("Deleting profile: %s", created["profile"].profile_name) + await created["profile"].delete(force=True) + except Exception: + logger.exception( + "Failed to delete profile: %s", created["profile"].profile_name + ) + + for credential_name in reversed(created["credentials"]): + try: + logger.info("Deleting credential: %s", credential_name) + await select_ai.async_delete_credential(credential_name, force=True) + except Exception: + logger.exception("Failed to delete credential: %s", credential_name) + + +async def allow_network_acl(): + # Grant the database user SMTP access required by the email notification tool. 
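+    # The ACE grants 'connect' and 'smtp' on EMAIL_SMTP_HOST to the current
+    # database user (resolved via SYS_CONTEXT inside the PL/SQL block). The
+    # host is passed as a bind variable, and errors indicating the entry is
+    # already present (ORA-46212 / ORA-46313 / "already exists") are
+    # swallowed so reruns against the same database do not fail during setup.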
+ async with select_ai.async_cursor() as cur: + try: + await cur.execute( + """ + BEGIN + DBMS_NETWORK_ACL_ADMIN.APPEND_HOST_ACE( + host => :host, + ace => xs$ace_type( + privilege_list => xs$name_list('connect', 'smtp'), + principal_name => SYS_CONTEXT('USERENV', 'CURRENT_USER'), + principal_type => xs_acl.ptype_db + ) + ); + END; + """, + host=EMAIL_SMTP_HOST, + ) + except Exception as exc: + msg = str(exc) + if ( + "ORA-46212" not in msg + and "ORA-46313" not in msg + and "already exists" not in msg + ): + raise + + +async def create_async_sql_team(): + created = { + "team": None, + "task": None, + "tools": [], + "agent": None, + "profile": None, + "credentials": [], + } + + # Initialize database access required by the team and tools. + try: + with log_step("Initialize database and network access"): + await connect_to_db() + await allow_network_acl() + + # Load OCI model and credential settings from the environment. + oci_user_ocid = os.getenv("PYSAI_TEST_OCI_USER_OCID") + oci_tenancy_ocid = os.getenv("PYSAI_TEST_OCI_TENANCY_OCID") + oci_private_key = os.getenv("PYSAI_TEST_OCI_PRIVATE_KEY") + oci_fingerprint = os.getenv("PYSAI_TEST_OCI_FINGERPRINT") + oci_compartment_id = os.getenv("PYSAI_TEST_OCI_COMPARTMENT_ID") + oci_region = "us-chicago-1" + oci_apiformat = "GENERIC" + oci_model = "meta.llama-4-maverick-17b-128e-instruct-fp8" + + assert oci_user_ocid, "PYSAI_TEST_OCI_USER_OCID not set" + assert oci_tenancy_ocid, "PYSAI_TEST_OCI_TENANCY_OCID not set" + assert oci_private_key, "PYSAI_TEST_OCI_PRIVATE_KEY not set" + assert oci_fingerprint, "PYSAI_TEST_OCI_FINGERPRINT not set" + assert oci_compartment_id, "PYSAI_TEST_OCI_COMPARTMENT_ID not set" + logger.info( + "Resolved OCI configuration | credential=%s | region=%s | model=%s", + OCI_CREDENTIAL_NAME, + oci_region, + oci_model, + ) + + # Create the OCI credential used by the Select AI profile. + with log_step("Create OCI credential"): + await select_ai.async_create_credential( + credential={ + "credential_name": OCI_CREDENTIAL_NAME, + "user_ocid": oci_user_ocid, + "tenancy_ocid": oci_tenancy_ocid, + "private_key": oci_private_key, + "fingerprint": oci_fingerprint, + }, + replace=True, + ) + created["credentials"].append(OCI_CREDENTIAL_NAME) + await verify_credential_exists(OCI_CREDENTIAL_NAME) + + # Create the profile that backs the SQL agent's model access. + with log_step("Create SQL profile"): + profile = await select_ai.AsyncProfile( + profile_name=SQL_PROFILE_NAME, + attributes=select_ai.ProfileAttributes( + credential_name=OCI_CREDENTIAL_NAME, + provider=select_ai.OCIGenAIProvider( + region=oci_region, + oci_apiformat=oci_apiformat, + model=oci_model, + oci_compartment_id=oci_compartment_id, + ), + object_list=[{"owner": "SH"}], + ), + description="Profile for async SQL Agent using OCI GenAI provider.", + replace=True, + ) + created["profile"] = profile + log_object_details("create_profile", "profile", profile) + assert profile.profile_name == SQL_PROFILE_NAME + + # Create the SQL tool the task will use to query database objects. + with log_step("Create SQL query tool"): + sql_tool = await AsyncTool.create_sql_tool( + tool_name=SQL_TOOL_NAME, + profile_name=SQL_PROFILE_NAME, + description=( + "Use this tool to query database tables for sales and customer info." 
+ ), + replace=True, + ) + created["tools"].append(sql_tool) + log_object_details("create_sql_tool", "tool", sql_tool) + fetched_sql_tool = await AsyncTool.fetch(SQL_TOOL_NAME) + assert fetched_sql_tool.tool_name == SQL_TOOL_NAME + assert fetched_sql_tool.attributes.tool_params.profile_name == SQL_PROFILE_NAME + + # Load SMTP settings for the email notification tool. + email_credential_name = ( + os.getenv("PYSAI_TEST_EMAIL_CREDENTIAL_NAME") or f"EMAIL_CRED_{RUN_ID}" + ) + email_username = os.getenv("PYSAI_TEST_EMAIL_CRED_USERNAME") + email_password = os.getenv("PYSAI_TEST_EMAIL_CRED_PASSWORD") + + assert email_username, "PYSAI_TEST_EMAIL_CRED_USERNAME not set" + assert email_password, "PYSAI_TEST_EMAIL_CRED_PASSWORD not set" + log_credential_setup(email_credential_name) + + # Create the SMTP credential used by the notification tool. + with log_step("Create email credential"): + await select_ai.async_create_credential( + credential={ + "credential_name": email_credential_name, + "username": email_username, + "password": email_password, + }, + replace=True, + ) + created["credentials"].append(email_credential_name) + await verify_credential_exists( + email_credential_name, expected_username=email_username + ) + + # Create the built-in email notification tool. + with log_step("Create email notification tool"): + email_tool = await AsyncTool.create_email_notification_tool( + tool_name=EMAIL_TOOL_NAME, + credential_name=email_credential_name, + recipient=EMAIL_RECIPIENT, + sender=EMAIL_SENDER, + smtp_host=EMAIL_SMTP_HOST, + description="Send notification emails for SQL insights", + replace=True, + ) + created["tools"].append(email_tool) + log_object_details("create_email_tool", "tool", email_tool) + fetched_email_tool = await AsyncTool.fetch(EMAIL_TOOL_NAME) + assert fetched_email_tool.tool_name == EMAIL_TOOL_NAME + assert ( + fetched_email_tool.attributes.tool_params.credential_name + == email_credential_name + ) + + # Create the task that combines SQL analysis with email delivery. + with log_step("Create SQL analysis task"): + task = AsyncTask( + task_name=SQL_TASK_NAME, + attributes=select_ai.agent.TaskAttributes( + instruction=( + "Answer the user query by querying the database: {query}. " + "After you produce the answer, send a concise summary of the findings " + "to the analytics stakeholders using the " + f"{EMAIL_TOOL_NAME}. " + "Include the SQL results and any key metrics in the email body." + ), + tools=[SQL_TOOL_NAME, EMAIL_TOOL_NAME], + ), + ) + await task.create(enabled=True, replace=True) + created["task"] = task + log_object_details("create_task", "task", task) + fetched_task = await AsyncTask.fetch(SQL_TASK_NAME) + assert fetched_task.task_name == SQL_TASK_NAME + assert set(fetched_task.attributes.tools) == { + SQL_TOOL_NAME, + EMAIL_TOOL_NAME, + } + + # Create the agent that will execute the SQL analysis task. 
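+        # The agent carries only a role and a profile reference; it is bound
+        # to SQL_TASK_NAME later through the team definition, so the same
+        # agent could in principle serve other teams. enable_human_tool=False
+        # keeps the run non-interactive, which an unattended test requires.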
+ with log_step("Create SQL analyst agent"): + agent = AsyncAgent( + agent_name=SQL_AGENT_NAME, + attributes=select_ai.agent.AgentAttributes( + profile_name=SQL_PROFILE_NAME, + role="You are a data analyst that translates natural language to SQL.", + enable_human_tool=False, + ), + ) + await agent.create(enabled=True, replace=True) + created["agent"] = agent + log_object_details("create_agent", "agent", agent) + fetched_agent = await AsyncAgent.fetch(SQL_AGENT_NAME) + assert fetched_agent.agent_name == SQL_AGENT_NAME + assert fetched_agent.attributes.profile_name == SQL_PROFILE_NAME + assert fetched_agent.attributes.enable_human_tool is False + + # Create the team that wires the agent to the task. + with log_step("Create SQL data team"): + team = AsyncTeam( + team_name=SQL_TEAM_NAME, + attributes=select_ai.agent.TeamAttributes( + agents=[{"name": SQL_AGENT_NAME, "task": SQL_TASK_NAME}], + process="sequential", + ), + ) + await team.create(replace=True, enabled=True) + created["team"] = team + log_object_details("create_team", "team", team) + fetched_team = await AsyncTeam.fetch(SQL_TEAM_NAME) + assert fetched_team.team_name == SQL_TEAM_NAME + assert fetched_team.attributes.process == "sequential" + assert fetched_team.attributes.agents == [ + {"name": SQL_AGENT_NAME, "task": SQL_TASK_NAME} + ] + + yield team + finally: + await _cleanup_async_sql_team_objects(created) + + +@pytest.fixture(scope="module") +async def async_sql_team(): + async for team in create_async_sql_team(): + yield team + + +async def test_async_sql_team_runs(async_sql_team): + # Run the team with a sample prompt and verify a response is returned. + with log_step("Run async SQL team"): + conversation_id = str(uuid.uuid4()) + prompt = "List tables in the SH schema?" + logger.info( + "Running team | team=%s | conversation_id=%s | prompt=%s", + async_sql_team.team_name, + conversation_id, + prompt, + ) + response = await async_sql_team.run( + prompt=prompt, + params={"conversation_id": conversation_id}, + ) + logger.info("Agent Response: %s", response) + assert response is not None + assert isinstance(response, str) + assert len(response.strip()) > 0 + + +if __name__ == "__main__": + raise SystemExit(pytest.main([__file__, "-q"])) diff --git a/tests/agents/test_3900_sql_team.py b/tests/agents/test_3900_sql_team.py new file mode 100644 index 0000000..b12e7e2 --- /dev/null +++ b/tests/agents/test_3900_sql_team.py @@ -0,0 +1,482 @@ +import os +import logging +import uuid +from contextlib import contextmanager + +import select_ai +import pytest +import select_ai.agent + +# Configure file-based logging for this script run. 
+PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")) +LOG_FILE = os.path.join(PROJECT_ROOT, "log", "test_3900_sql_team.log") +os.makedirs(os.path.dirname(LOG_FILE), exist_ok=True) + +root_logger = logging.getLogger() +root_logger.setLevel(logging.INFO) +for handler in root_logger.handlers[:]: + root_logger.removeHandler(handler) +file_handler = logging.FileHandler(LOG_FILE, mode="w") +file_handler.setFormatter(logging.Formatter("%(levelname)s: %(message)s")) +root_logger.addHandler(file_handler) + +logger = logging.getLogger(__name__) + +EMAIL_RECIPIENT = os.getenv("PYSAI_TEST_EMAIL_RECIPIENT") +EMAIL_SENDER = os.getenv("PYSAI_TEST_EMAIL_SENDER") +EMAIL_SMTP_HOST = os.getenv("PYSAI_TEST_EMAIL_SMTPHOST") + + +@pytest.fixture(autouse=True) +def log_test_name(request): + logger.info("--- Starting test: %s ---", request.function.__name__) + yield + logger.info("--- Finished test: %s ---", request.function.__name__) + + +@contextmanager +def log_step(step): + logger.info("START: %s", step) + try: + yield + logger.info("END: %s", step) + except Exception: + logger.exception("FAILED: %s", step) + raise + + +def log_object_details(context: str, object_type: str, obj) -> None: + details = {"context": context, "object_type": object_type} + attributes = getattr(obj, "attributes", None) + + if object_type == "profile": + details.update( + { + "profile_name": getattr(obj, "profile_name", None), + "description": getattr(obj, "description", None), + "provider_type": ( + type(getattr(attributes, "provider", None)).__name__ + if attributes is not None and getattr(attributes, "provider", None) + else None + ), + "object_count": ( + len(getattr(attributes, "object_list", []) or []) + if attributes is not None + else None + ), + } + ) + elif object_type == "agent": + details.update( + { + "agent_name": getattr(obj, "agent_name", None), + "description": getattr(obj, "description", None), + "profile_name": ( + getattr(attributes, "profile_name", None) + if attributes is not None + else None + ), + "enable_human_tool": ( + getattr(attributes, "enable_human_tool", None) + if attributes is not None + else None + ), + } + ) + elif object_type == "tool": + details.update( + { + "tool_name": getattr(obj, "tool_name", None), + "description": getattr(obj, "description", None), + "tool_type": ( + getattr(attributes, "tool_type", None) + if attributes is not None + else None + ), + } + ) + elif object_type == "task": + details.update( + { + "task_name": getattr(obj, "task_name", None), + "description": getattr(obj, "description", None), + "tool_count": ( + len(getattr(attributes, "tools", []) or []) + if attributes is not None + else None + ), + "enable_human_tool": ( + getattr(attributes, "enable_human_tool", None) + if attributes is not None + else None + ), + } + ) + elif object_type == "team": + details.update( + { + "team_name": getattr(obj, "team_name", None), + "description": getattr(obj, "description", None), + "process": ( + getattr(attributes, "process", None) + if attributes is not None + else None + ), + "agent_count": ( + len(getattr(attributes, "agents", []) or []) + if attributes is not None + else None + ), + } + ) + else: + details["repr"] = str(obj) + + logger.info("OBJECT_DETAILS: %s", details) + + +def log_credential_setup(credential_name): + logger.info("Preparing credential | name=%s", credential_name) + + +def verify_credential_exists(credential_name, expected_username=None): + logger.info("Verifying credential exists in DB: %s", credential_name) + with 
select_ai.cursor() as cur: + cur.execute( + """ + SELECT credential_name, username + FROM user_credentials + WHERE UPPER(credential_name) = UPPER(:credential_name) + """, + credential_name=credential_name, + ) + row = cur.fetchone() + + assert row is not None, f"Credential {credential_name} was not created" + + actual_name, actual_username = row + logger.info("Verified credential | name=%s", actual_name) + assert actual_name.upper() == credential_name.upper() + if expected_username is not None: + assert actual_username == expected_username + + +def connect_to_db(): + # Connect the Python client to the test database. + user = os.getenv("PYSAI_TEST_USER") + password = os.getenv("PYSAI_TEST_USER_PASSWORD") + dsn = os.getenv("PYSAI_TEST_CONNECT_STRING") + assert user, "PYSAI_TEST_USER not set" + assert password, "PYSAI_TEST_USER_PASSWORD not set" + assert dsn, "PYSAI_TEST_CONNECT_STRING not set" + logger.info("Connecting to database using configured test credentials") + select_ai.connect(user=user, password=password, dsn=dsn) + + +def _cleanup_sql_team_objects(created) -> None: + with log_step("Cleanup SQL team objects"): + if created["team"] is not None: + try: + logger.info("Deleting team: %s", created["team"].team_name) + created["team"].delete(force=True) + except Exception: + logger.exception("Failed to delete team: %s", created["team"].team_name) + + if created["task"] is not None: + try: + logger.info("Deleting task: %s", created["task"].task_name) + created["task"].delete(force=True) + except Exception: + logger.exception("Failed to delete task: %s", created["task"].task_name) + + for tool in reversed(created["tools"]): + try: + logger.info("Deleting tool: %s", tool.tool_name) + tool.delete(force=True) + except Exception: + logger.exception("Failed to delete tool: %s", tool.tool_name) + + if created["agent"] is not None: + try: + logger.info("Deleting agent: %s", created["agent"].agent_name) + created["agent"].delete(force=True) + except Exception: + logger.exception( + "Failed to delete agent: %s", created["agent"].agent_name + ) + + if created["profile"] is not None: + try: + logger.info("Deleting profile: %s", created["profile"].profile_name) + created["profile"].delete(force=True) + except Exception: + logger.exception( + "Failed to delete profile: %s", created["profile"].profile_name + ) + + for credential_name in reversed(created["credentials"]): + try: + logger.info("Deleting credential: %s", credential_name) + select_ai.delete_credential(credential_name, force=True) + except Exception: + logger.exception("Failed to delete credential: %s", credential_name) + + +def allow_network_acl(): + # Grant the database user SMTP access required by the email notification tool. + with select_ai.cursor() as cur: + try: + cur.execute( + """ + BEGIN + DBMS_NETWORK_ACL_ADMIN.APPEND_HOST_ACE( + host => :host, + ace => xs$ace_type( + privilege_list => xs$name_list('connect', 'smtp'), + principal_name => SYS_CONTEXT('USERENV', 'CURRENT_USER'), + principal_type => xs_acl.ptype_db + ) + ); + END; + """, + host=EMAIL_SMTP_HOST, + ) + except Exception as exc: + msg = str(exc) + if ( + "ORA-46212" not in msg + and "ORA-46313" not in msg + and "already exists" not in msg + ): + raise + + +def create_sql_team(): + created = { + "team": None, + "task": None, + "tools": [], + "agent": None, + "profile": None, + "credentials": [], + } + + # Initialize database access required by the team and tools. 
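+    # Everything created below is recorded in the created dict so that the
+    # finally block can tear objects down in reverse dependency order
+    # (team -> task -> tools -> agent -> profile -> credentials), even when a
+    # step partway through fails.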
+ try: + with log_step("Initialize database and network access"): + connect_to_db() + allow_network_acl() + + # Load OCI model and credential settings from the environment. + oci_credential_name = "SQL_TEAM_OCI_CRED" + oci_user_ocid = os.getenv("PYSAI_TEST_OCI_USER_OCID") + oci_tenancy_ocid = os.getenv("PYSAI_TEST_OCI_TENANCY_OCID") + oci_private_key = os.getenv("PYSAI_TEST_OCI_PRIVATE_KEY") + oci_fingerprint = os.getenv("PYSAI_TEST_OCI_FINGERPRINT") + oci_compartment_id = os.getenv("PYSAI_TEST_OCI_COMPARTMENT_ID") + oci_region = "us-chicago-1" + oci_apiformat = "GENERIC" + oci_model = "meta.llama-4-maverick-17b-128e-instruct-fp8" + + assert oci_user_ocid, "PYSAI_TEST_OCI_USER_OCID not set" + assert oci_tenancy_ocid, "PYSAI_TEST_OCI_TENANCY_OCID not set" + assert oci_private_key, "PYSAI_TEST_OCI_PRIVATE_KEY not set" + assert oci_fingerprint, "PYSAI_TEST_OCI_FINGERPRINT not set" + assert oci_compartment_id, "PYSAI_TEST_OCI_COMPARTMENT_ID not set" + logger.info( + "Resolved OCI configuration | credential=%s | region=%s | model=%s", + oci_credential_name, + oci_region, + oci_model, + ) + + # Create the OCI credential used by the Select AI profile. + with log_step("Create OCI credential"): + select_ai.create_credential( + credential={ + "credential_name": oci_credential_name, + "user_ocid": oci_user_ocid, + "tenancy_ocid": oci_tenancy_ocid, + "private_key": oci_private_key, + "fingerprint": oci_fingerprint, + }, + replace=True, + ) + created["credentials"].append(oci_credential_name) + verify_credential_exists(oci_credential_name) + + # Create the profile that backs the SQL agent's model access. + with log_step("Create SQL profile"): + profile = select_ai.Profile( + profile_name="SQL_PROFILE", + attributes=select_ai.ProfileAttributes( + credential_name=oci_credential_name, + provider=select_ai.OCIGenAIProvider( + region=oci_region, + oci_apiformat=oci_apiformat, + model=oci_model, + oci_compartment_id=oci_compartment_id, + ), + object_list=[{"owner": "SH"}], + ), + description="Profile for SQL Agent using OCI GenAI provider.", + replace=True, + ) + created["profile"] = profile + log_object_details("create_profile", "profile", profile) + assert profile.profile_name == "SQL_PROFILE" + + # Create the SQL tool the task will use to query database objects. + with log_step("Create SQL query tool"): + sql_tool = select_ai.agent.Tool.create_sql_tool( + tool_name="SQL_QUERY_TOOL", + profile_name="SQL_PROFILE", + description="Use this tool to query database tables for sales and customer info.", + instruction="Use this tool to execute SQL queries against the database. Only query the 'sales' and 'customers' tables. Always return results in a structured format.", + replace=True, + ) + created["tools"].append(sql_tool) + log_object_details("create_sql_tool", "tool", sql_tool) + fetched_sql_tool = select_ai.agent.Tool.fetch("SQL_QUERY_TOOL") + assert fetched_sql_tool.tool_name == "SQL_QUERY_TOOL" + assert fetched_sql_tool.attributes.tool_params.profile_name == "SQL_PROFILE" + + # Load SMTP settings for the email notification tool. + email_credential_name = ( + os.getenv("PYSAI_TEST_EMAIL_CREDENTIAL_NAME") or "EMAIL_CRED" + ) + email_username = os.getenv("PYSAI_TEST_EMAIL_CRED_USERNAME") + email_password = os.getenv("PYSAI_TEST_EMAIL_CRED_PASSWORD") + assert email_username, "PYSAI_TEST_EMAIL_CRED_USERNAME not set" + assert email_password, "PYSAI_TEST_EMAIL_CRED_PASSWORD not set" + log_credential_setup(email_credential_name) + + # Create the SMTP credential used by the notification tool. 
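+        # replace=True makes the credential creation idempotent across runs;
+        # verify_credential_exists() then round-trips through USER_CREDENTIALS
+        # to confirm the row exists with the expected username before the
+        # notification tool is wired to it.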
+ with log_step("Create email credential"): + select_ai.create_credential( + credential={ + "credential_name": email_credential_name, + "username": email_username, + "password": email_password, + }, + replace=True, + ) + created["credentials"].append(email_credential_name) + verify_credential_exists( + email_credential_name, expected_username=email_username + ) + + # Create the built-in email notification tool. + with log_step("Create email notification tool"): + email_tool = select_ai.agent.Tool.create_email_notification_tool( + tool_name="EMAIL_NOTIFICATION_TOOL", + credential_name=email_credential_name, + subject="SQL Analysis Results", + recipient=EMAIL_RECIPIENT, + sender=EMAIL_SENDER, + smtp_host=EMAIL_SMTP_HOST, + description="Send notification emails for SQL insights", + replace=True, + ) + created["tools"].append(email_tool) + log_object_details("create_email_tool", "tool", email_tool) + fetched_email_tool = select_ai.agent.Tool.fetch("EMAIL_NOTIFICATION_TOOL") + assert fetched_email_tool.tool_name == "EMAIL_NOTIFICATION_TOOL" + assert ( + fetched_email_tool.attributes.tool_params.credential_name + == email_credential_name + ) + + # Create the task that combines SQL analysis with email delivery. + with log_step("Create SQL analysis task"): + task = select_ai.agent.Task( + task_name="SQL_ANALYSIS_TASK", + attributes=select_ai.agent.TaskAttributes( + instruction=( + "Answer the user query by querying the database: {query}. " + "After you produce the answer, send a concise summary of the findings " + "to the analytics stakeholders using the EMAIL_NOTIFICATION_TOOL. " + "Include the SQL results and any key metrics in the email body." + ), + tools=["SQL_QUERY_TOOL", "EMAIL_NOTIFICATION_TOOL"], + ), + ) + task.create(replace=True) + created["task"] = task + log_object_details("create_task", "task", task) + fetched_task = select_ai.agent.Task.fetch("SQL_ANALYSIS_TASK") + assert fetched_task.task_name == "SQL_ANALYSIS_TASK" + assert set(fetched_task.attributes.tools) == { + "SQL_QUERY_TOOL", + "EMAIL_NOTIFICATION_TOOL", + } + + # Create the agent that will execute the SQL analysis task. + with log_step("Create SQL analyst agent"): + agent = select_ai.agent.Agent( + agent_name="SQL_ANALYST_AGENT", + attributes=select_ai.agent.AgentAttributes( + profile_name="SQL_PROFILE", + role="You are a data analyst that translates natural language to SQL.", + enable_human_tool=False, + ), + ) + agent.create(enabled=True, replace=True) + created["agent"] = agent + log_object_details("create_agent", "agent", agent) + fetched_agent = select_ai.agent.Agent.fetch("SQL_ANALYST_AGENT") + assert fetched_agent.agent_name == "SQL_ANALYST_AGENT" + assert fetched_agent.attributes.profile_name == "SQL_PROFILE" + assert fetched_agent.attributes.enable_human_tool is False + + # Create the team that wires the agent to the task. 
+ with log_step("Create SQL data team"): + team = select_ai.agent.Team( + team_name="SQL_DATA_TEAM", + attributes=select_ai.agent.TeamAttributes( + agents=[{"name": "SQL_ANALYST_AGENT", "task": "SQL_ANALYSIS_TASK"}], + process="sequential", + ) + ) + team.create(replace=True, enabled=True) + created["team"] = team + log_object_details("create_team", "team", team) + fetched_team = select_ai.agent.Team.fetch("SQL_DATA_TEAM") + assert fetched_team.team_name == "SQL_DATA_TEAM" + assert fetched_team.attributes.process == "sequential" + assert fetched_team.attributes.agents == [ + {"name": "SQL_ANALYST_AGENT", "task": "SQL_ANALYSIS_TASK"} + ] + yield team + finally: + _cleanup_sql_team_objects(created) + + +@pytest.fixture(scope="module") +def sql_team(): + yield from create_sql_team() + + +def test_sql_team_runs(sql_team): + # Run the team with a sample prompt and verify a response is returned. + with log_step("Run SQL team"): + conversation_id = str(uuid.uuid4()) + prompt = "List tables in the SH schema?" + logger.info( + "Running team | team=%s | conversation_id=%s | prompt=%s", + sql_team.team_name, + conversation_id, + prompt, + ) + response = sql_team.run( + prompt=prompt, + params={"conversation_id": conversation_id}, + ) + logger.info("Agent Response: %s", response) + assert response is not None + assert isinstance(response, str) + assert len(response.strip()) > 0 + + +if __name__ == "__main__": + raise SystemExit(pytest.main([__file__, "-q"])) diff --git a/tests/conftest.py b/tests/conftest.py index bb6cc1b..1ede4fd 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -14,9 +14,6 @@ # PYSAI_TEST_CONNECT_STRING: connect string for test suite # PYSAI_TEST_WALLET_LOCATION: location of wallet file (thin mode, mTLS) # PYSAI_TEST_WALLET_PASSWORD: password for wallet file (thin mode, mTLS) -# PYSAI_TEST_MIN_POOL_SIZE: Minimum number of connections in the pool -# PYSAI_TEST_MAX_POOL_SIZE: Maximum number of connections in the pool -# PYSAI_TEST_POOL_INCREMENT # # OCI Gen AI # PYSAI_TEST_OCI_USER_OCID @@ -36,37 +33,6 @@ PYSAI_TEST_USER = "PYSAI_TEST_USER" PYSAI_OCI_CREDENTIAL_NAME = f"PYSAI_OCI_CREDENTIAL_{uuid.uuid4().hex.upper()}" -_BASIC_SCHEMA_PRIVILEGES = ( - "CREATE SESSION", - "CREATE TABLE", - "UNLIMITED TABLESPACE", -) - - -def _ensure_test_user_exists(username: str, password: str): - username_upper = username.upper() - with select_ai.cursor() as cr: - cr.execute( - "SELECT 1 FROM dba_users WHERE username = :username", - username=username_upper, - ) - if cr.fetchone(): - return - escaped_password = password.replace('"', '""') - cr.execute( - f'CREATE USER {username_upper} IDENTIFIED BY "{escaped_password}"' - ) - with select_ai.db.get_connection() as conn: - conn.commit() - - -def _grant_basic_schema_privileges(username: str): - username_upper = username.upper() - with select_ai.cursor() as cr: - for privilege in _BASIC_SCHEMA_PRIVILEGES: - cr.execute(f"GRANT {privilege} TO {username_upper}") - with select_ai.db.get_connection() as conn: - conn.commit() def get_env_value(name, default_value=None, required=False): @@ -95,17 +61,8 @@ def __init__(self): self.admin_password = get_env_value("ADMIN_PASSWORD") self.wallet_location = get_env_value("WALLET_LOCATION") self.wallet_password = get_env_value("WALLET_PASSWORD") - self.min_pool_size = int( - get_env_value("MIN_POOL_SIZE", default_value=2) - ) - self.max_pool_size = int( - get_env_value("MAX_POOL_SIZE", default_value=4) - ) - self.pool_increment = int( - get_env_value("POOL_INCREMENT", default_value=1) - ) - - def 
connect_params(self, admin: bool = False, use_pool: bool = False): + + def connect_params(self, admin: bool = False): """ Returns connect params """ @@ -119,10 +76,6 @@ def connect_params(self, admin: bool = False, use_pool: bool = False): "wallet_password": self.wallet_password, "config_dir": self.wallet_location, } - if use_pool: - connect_params["min_size"] = self.min_pool_size - connect_params["max_size"] = self.max_pool_size - connect_params["increment"] = self.pool_increment return connect_params @@ -137,46 +90,39 @@ def test_env(pytestconfig): return env -@pytest.fixture(autouse=True, scope="session") -def setup_test_user(test_env): - select_ai.connect(**test_env.connect_params(admin=True)) - _ensure_test_user_exists( - username=test_env.test_user, - password=test_env.test_user_password, - ) - _grant_basic_schema_privileges(username=test_env.test_user) - select_ai.grant_privileges(users=[test_env.test_user]) - select_ai.grant_http_access( - users=[test_env.test_user], - provider_endpoint=select_ai.OpenAIProvider.provider_endpoint, - ) - select_ai.disconnect() +# @pytest.fixture(autouse=True, scope="session") +# def setup_test_user(test_env): +# select_ai.connect(**test_env.connect_params(admin=True)) +# select_ai.grant_privileges(users=[test_env.test_user]) +# select_ai.grant_http_access( +# users=[test_env.test_user], +# provider_endpoint=select_ai.OpenAIProvider.provider_endpoint, +# ) +# select_ai.disconnect() -@pytest.fixture(autouse=True, scope="module") -def connect(setup_test_user, test_env): - select_ai.create_pool(**test_env.connect_params(use_pool=True)) +@pytest.fixture(autouse=True, scope="session") +def connect(test_env): + select_ai.connect(**test_env.connect_params()) yield select_ai.disconnect() -@pytest.fixture(autouse=True, scope="module") -async def async_connect(setup_test_user, test_env, anyio_backend): - select_ai.create_pool_async(**test_env.connect_params(use_pool=True)) - yield - await select_ai.async_disconnect() +# @pytest.fixture(autouse=True, scope="session") +# async def async_connect(test_env, anyio_backend): +# await select_ai.async_connect(**test_env.connect_params()) +# yield +# await select_ai.async_disconnect() @pytest.fixture def connection(): - with select_ai.db.get_connection() as conn: - yield conn + return select_ai.db.get_connection() @pytest.fixture def async_connection(): - with select_ai.db.async_get_connection() as conn: - yield conn + return select_ai.db.async_get_connection() @pytest.fixture(scope="module") @@ -185,13 +131,13 @@ def cursor(): yield cr -@pytest.fixture(scope="module") +@pytest.fixture async def async_cursor(): async with select_ai.async_cursor() as cr: yield cr -@pytest.fixture(autouse=True, scope="module") +@pytest.fixture(autouse=True, scope="session") def oci_credential(connect, test_env): credential = { "credential_name": PYSAI_OCI_CREDENTIAL_NAME, From 11a4d1ec0fd023738874c90eacf4f89429ffcfb2 Mon Sep 17 00:00:00 2001 From: Prateek Saxena Date: Thu, 16 Apr 2026 09:15:47 +0000 Subject: [PATCH 5/6] Stabilize provider coverage and chat session tests --- pyproject.toml | 4 +++ tests/agents/conftest.py | 3 +- tests/conftest.py | 32 +++++++++++++++++++ tests/profiles/conftest.py | 10 ++++-- tests/profiles/test_1300_profile_async.py | 3 +- tests/profiles/test_1800_chat_session.py | 14 ++++++-- .../profiles/test_1900_chat_session_async.py | 14 ++++++-- 7 files changed, 69 insertions(+), 11 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 7758495..034538d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ 
-77,3 +77,7 @@ minversion = "8.3.0" testpaths = [ "tests" ] +markers = [ + "unit: fast offline unit tests", + "provider: manual live provider smoke tests", +] diff --git a/tests/agents/conftest.py b/tests/agents/conftest.py index 386ab8f..dcfc341 100644 --- a/tests/agents/conftest.py +++ b/tests/agents/conftest.py @@ -10,11 +10,12 @@ @pytest.fixture(scope="module") -def provider(): +def provider(oci_compartment_id): return select_ai.OCIGenAIProvider( region="us-chicago-1", oci_apiformat="GENERIC", model="meta.llama-4-maverick-17b-128e-instruct-fp8", + oci_compartment_id=oci_compartment_id, ) diff --git a/tests/conftest.py b/tests/conftest.py index 1ede4fd..e75f34a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -33,6 +33,38 @@ PYSAI_TEST_USER = "PYSAI_TEST_USER" PYSAI_OCI_CREDENTIAL_NAME = f"PYSAI_OCI_CREDENTIAL_{uuid.uuid4().hex.upper()}" +_BASIC_SCHEMA_PRIVILEGES = ( + "CREATE SESSION", + "CREATE TABLE", + "CREATE PROCEDURE", + "UNLIMITED TABLESPACE", +) + + +def _ensure_test_user_exists(username: str, password: str): + username_upper = username.upper() + with select_ai.cursor() as cr: + cr.execute( + "SELECT 1 FROM dba_users WHERE username = :username", + username=username_upper, + ) + if cr.fetchone(): + return + escaped_password = password.replace('"', '""') + cr.execute( + f'CREATE USER {username_upper} IDENTIFIED BY "{escaped_password}"' + ) + with select_ai.db.get_connection() as conn: + conn.commit() + + +def _grant_basic_schema_privileges(username: str): + username_upper = username.upper() + with select_ai.cursor() as cr: + for privilege in _BASIC_SCHEMA_PRIVILEGES: + cr.execute(f"GRANT {privilege} TO {username_upper}") + with select_ai.db.get_connection() as conn: + conn.commit() def get_env_value(name, default_value=None, required=False): diff --git a/tests/profiles/conftest.py b/tests/profiles/conftest.py index 601cc2f..53ffa99 100644 --- a/tests/profiles/conftest.py +++ b/tests/profiles/conftest.py @@ -57,9 +57,11 @@ def log_test_case(request, configure_module_logging): @pytest.fixture(scope="module") -def provider(): +def provider(oci_compartment_id): return select_ai.OCIGenAIProvider( - region="us-phoenix-1", oci_apiformat="GENERIC" + region="us-phoenix-1", + oci_apiformat="GENERIC", + oci_compartment_id=oci_compartment_id, ) @@ -76,5 +78,7 @@ def profile_attributes(provider, oci_credential): def min_profile_attributes(provider, oci_credential): return select_ai.ProfileAttributes( credential_name=oci_credential["credential_name"], - provider=select_ai.OCIGenAIProvider(), + provider=select_ai.OCIGenAIProvider( + oci_compartment_id=provider.oci_compartment_id + ), ) diff --git a/tests/profiles/test_1300_profile_async.py b/tests/profiles/test_1300_profile_async.py index 3f0d8c6..c9cd5a6 100644 --- a/tests/profiles/test_1300_profile_async.py +++ b/tests/profiles/test_1300_profile_async.py @@ -277,7 +277,7 @@ async def test_1307(): assert profile.attributes.provider.model == "meta.llama-3.1-70b-instruct" -async def test_1308(oci_credential): +async def test_1308(oci_credential, oci_compartment_id): """Set multiple attributes for a Profile""" logger.info( "Setting multiple attributes for async profile %s", @@ -289,6 +289,7 @@ async def test_1308(oci_credential): provider=select_ai.OCIGenAIProvider( model="meta.llama-4-maverick-17b-128e-instruct-fp8", region="us-chicago-1", + oci_compartment_id=oci_compartment_id, oci_apiformat="GENERIC", ), object_list=[{"owner": "ADMIN", "name": "gymnasts"}], diff --git a/tests/profiles/test_1800_chat_session.py 
b/tests/profiles/test_1800_chat_session.py index 4fcbadd..2e2ac5d 100644 --- a/tests/profiles/test_1800_chat_session.py +++ b/tests/profiles/test_1800_chat_session.py @@ -56,7 +56,10 @@ ], "general": [ ("What is the capital of Japan?", "tokyo"), - ("Tell me a fun fact about space.", "space"), + ( + "Tell me a fun fact about space.", + ("space", "jupiter", "planet", "solar system"), + ), ("Who invented the telephone?", "telephone"), ("What is the fastest land animal?", "cheetah"), ("Explain why the sky looks blue.", "sky"), @@ -120,10 +123,15 @@ def _create(**kwargs): def _assert_keywords(session, prompts): - for prompt, keyword in prompts: + for prompt, expected_keywords in prompts: response = session.chat(prompt=prompt) logger.debug("Received response for prompt '%s': %s", prompt, response) - assert keyword.lower() in response.lower() + keywords = ( + expected_keywords + if isinstance(expected_keywords, (tuple, list, set)) + else (expected_keywords,) + ) + assert any(item.lower() in response.lower() for item in keywords) def test_1800_database_chat_session( diff --git a/tests/profiles/test_1900_chat_session_async.py b/tests/profiles/test_1900_chat_session_async.py index 1450eac..24ece5e 100644 --- a/tests/profiles/test_1900_chat_session_async.py +++ b/tests/profiles/test_1900_chat_session_async.py @@ -56,7 +56,10 @@ ], "general": [ ("What is the capital of Japan?", "tokyo"), - ("Tell me a fun fact about space.", "space"), + ( + "Tell me a fun fact about space.", + ("space", "jupiter", "planet", "solar system"), + ), ("Who invented the telephone?", "telephone"), ("What is the fastest land animal?", "cheetah"), ("Explain why the sky looks blue.", "sky"), @@ -124,10 +127,15 @@ async def _create(**kwargs): async def _assert_keywords(session, prompts): - for prompt, keyword in prompts: + for prompt, expected_keywords in prompts: response = await session.chat(prompt=prompt) logger.debug("Async response for prompt '%s': %s", prompt, response) - assert keyword.lower() in response.lower() + keywords = ( + expected_keywords + if isinstance(expected_keywords, (tuple, list, set)) + else (expected_keywords,) + ) + assert any(item.lower() in response.lower() for item in keywords) @pytest.mark.anyio From 0740d3e083dd5e1e92c9d99c76b0cb4843dc7900 Mon Sep 17 00:00:00 2001 From: Prateek Saxena Date: Thu, 16 Apr 2026 18:27:06 +0000 Subject: [PATCH 6/6] Fix connection fixtures for pytest schema setup --- tests/conftest.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index e75f34a..ef39fd2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -149,12 +149,14 @@ def connect(test_env): @pytest.fixture def connection(): - return select_ai.db.get_connection() + with select_ai.db.get_connection() as conn: + yield conn @pytest.fixture -def async_connection(): - return select_ai.db.async_get_connection() +async def async_connection(): + async with select_ai.db.async_get_connection() as conn: + yield conn @pytest.fixture(scope="module")