From 28ca634b02cb4fe0c7164c0094f19a159beb0cde Mon Sep 17 00:00:00 2001 From: Abhishek Singh Date: Sun, 19 Apr 2026 21:38:12 -0700 Subject: [PATCH 1/3] Changes to fix Github build --- .gitignore | 1 + tests/agents/test_3001_async_tools.py | 125 ++++++++++++----- tests/agents/test_3101_async_tasks.py | 54 +++++--- tests/agents/test_3201_async_agents.py | 40 +++--- tests/agents/test_3301_async_teams.py | 41 +++--- tests/agents/test_3800_agente2e.py | 122 ++++++++--------- tests/agents/test_3800_async_agente2e.py | 97 ++++---------- tests/agents/test_3900_async_sql_team.py | 92 +++++-------- tests/agents/test_3900_sql_team.py | 98 ++++++-------- tests/conftest.py | 163 +++++++++++++++++++---- 10 files changed, 465 insertions(+), 368 deletions(-) diff --git a/.gitignore b/.gitignore index d2ff10f..1bf21f0 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,6 @@ my_examples logs/ +log/ .idea .env .venv diff --git a/tests/agents/test_3001_async_tools.py b/tests/agents/test_3001_async_tools.py index 5157c2d..a5d4668 100644 --- a/tests/agents/test_3001_async_tools.py +++ b/tests/agents/test_3001_async_tools.py @@ -22,7 +22,9 @@ pytestmark = pytest.mark.anyio # Path -PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")) +PROJECT_ROOT = os.path.abspath( + os.path.join(os.path.dirname(__file__), "../..") +) LOG_FILE = os.path.join(PROJECT_ROOT, "log", "tkex_test_3001_async_tools.log") os.makedirs(os.path.dirname(LOG_FILE), exist_ok=True) @@ -87,15 +89,6 @@ def log_test_name(request): logger.info("--- Finished test: %s ---", request.function.__name__) -@pytest.fixture(scope="module", autouse=True) -async def async_connect(test_env): - logger.info("Opening async database connection") - await select_ai.async_connect(**test_env.connect_params()) - yield - logger.info("Closing async database connection") - await select_ai.async_disconnect() - - async def get_tool_status(tool_name): logger.info("Fetching tool status for: %s", tool_name) async with select_ai.async_cursor() as cur: @@ -134,9 +127,11 @@ def log_tool_details(context: str, tool) -> None: "instruction": getattr(attrs, "instruction", None) if attrs else None, "function": getattr(attrs, "function", None) if attrs else None, "tool_inputs": getattr(attrs, "tool_inputs", None) if attrs else None, - "tool_params": tool_params.dict(exclude_null=False) - if tool_params is not None - else None, + "tool_params": ( + tool_params.dict(exclude_null=False) + if tool_params is not None + else None + ), } logger.info("TOOL_DETAILS: %s", details) @@ -261,7 +256,9 @@ async def email_credential(): logger.info("EMAIL credential did not exist or could not be dropped") pass - await select_ai.async_create_credential(credential=credential, replace=True) + await select_ai.async_create_credential( + credential=credential, replace=True + ) logger.info("Created EMAIL credential: %s", EMAIL_CRED_NAME) yield EMAIL_CRED_NAME @@ -288,7 +285,9 @@ async def slack_credential(): logger.info("SLACK credential did not exist or could not be dropped") pass - await select_ai.async_create_credential(credential=credential, replace=True) + await select_ai.async_create_credential( + credential=credential, replace=True + ) logger.info("Created SLACK credential: %s", SLACK_CRED_NAME) yield SLACK_CRED_NAME @@ -345,27 +344,35 @@ async def slack_tool(slack_credential): @pytest.fixture(scope="module") async def neg_sql_tool(): - logger.info("Creating SQL tool with invalid profile: %s", NEG_SQL_TOOL_NAME) + logger.info( + "Creating SQL tool with invalid profile: %s", NEG_SQL_TOOL_NAME + ) tool = await AsyncTool.create_sql_tool( tool_name=NEG_SQL_TOOL_NAME, profile_name="NON_EXISTENT_PROFILE", replace=True, ) yield tool - logger.info("Deleting SQL tool with invalid profile: %s", NEG_SQL_TOOL_NAME) + logger.info( + "Deleting SQL tool with invalid profile: %s", NEG_SQL_TOOL_NAME + ) await tool.delete(force=True) @pytest.fixture(scope="module") async def neg_rag_tool(): - logger.info("Creating RAG tool with invalid profile: %s", NEG_RAG_TOOL_NAME) + logger.info( + "Creating RAG tool with invalid profile: %s", NEG_RAG_TOOL_NAME + ) tool = await AsyncTool.create_rag_tool( tool_name=NEG_RAG_TOOL_NAME, profile_name="NON_EXISTENT_RAG_PROFILE", replace=True, ) yield tool - logger.info("Deleting RAG tool with invalid profile: %s", NEG_RAG_TOOL_NAME) + logger.info( + "Deleting RAG tool with invalid profile: %s", NEG_RAG_TOOL_NAME + ) await tool.delete(force=True) @@ -470,21 +477,34 @@ async def test_3006_enable_disable_sql_tool(sql_tool): async def test_3007_web_search_tool_created(web_search_tool): - logger.info("Validating Web Search tool creation: %s", WEB_SEARCH_TOOL_NAME) + logger.info( + "Validating Web Search tool creation: %s", WEB_SEARCH_TOOL_NAME + ) log_tool_details("test_3007_web_search_tool_created", web_search_tool) assert web_search_tool.tool_name == WEB_SEARCH_TOOL_NAME - assert web_search_tool.attributes.tool_type == select_ai.agent.ToolType.WEBSEARCH - assert web_search_tool.attributes.tool_params.credential_name == "OPENAI_CRED" + assert ( + web_search_tool.attributes.tool_type + == select_ai.agent.ToolType.WEBSEARCH + ) + assert ( + web_search_tool.attributes.tool_params.credential_name == "OPENAI_CRED" + ) async def test_3008_email_tool_created(email_tool): logger.info("Validating EMAIL tool creation: %s", EMAIL_TOOL_NAME) log_tool_details("test_3008_email_tool_created", email_tool) assert email_tool.tool_name == EMAIL_TOOL_NAME - assert str(email_tool.attributes.tool_type).upper() in ("EMAIL", "NOTIFICATION") + assert str(email_tool.attributes.tool_type).upper() in ( + "EMAIL", + "NOTIFICATION", + ) assert email_tool.attributes.tool_params.credential_name == EMAIL_CRED_NAME assert email_tool.attributes.tool_params.smtp_host is not None - assert str(email_tool.attributes.tool_params.notification_type).lower() == "email" + assert ( + str(email_tool.attributes.tool_params.notification_type).lower() + == "email" + ) async def test_3009_slack_tool_created(slack_tool): @@ -492,11 +512,22 @@ async def test_3009_slack_tool_created(slack_tool): if slack_tool is not None: log_tool_details("test_3009_slack_tool_created", slack_tool) assert slack_tool.tool_name == SLACK_TOOL_NAME - assert str(slack_tool.attributes.tool_type).upper() in ("SLACK", "NOTIFICATION") - assert slack_tool.attributes.tool_params.credential_name == SLACK_CRED_NAME - assert str(slack_tool.attributes.tool_params.notification_type).lower() == "slack" + assert str(slack_tool.attributes.tool_type).upper() in ( + "SLACK", + "NOTIFICATION", + ) + assert ( + slack_tool.attributes.tool_params.credential_name + == SLACK_CRED_NAME + ) + assert ( + str(slack_tool.attributes.tool_params.notification_type).lower() + == "slack" + ) else: - logger.info("SLACK tool not created due to expected backend-side error") + logger.info( + "SLACK tool not created due to expected backend-side error" + ) async def test_3010_custom_tool_attributes_roundtrip(): @@ -536,7 +567,10 @@ async def test_3010_custom_tool_attributes_roundtrip(): ) assert isinstance(fetched.attributes.tool_inputs, list) assert fetched.attributes.tool_inputs[0]["name"] == "p_birth_date" - assert "birth date" in fetched.attributes.tool_inputs[0]["description"].lower() + assert ( + "birth date" + in fetched.attributes.tool_inputs[0]["description"].lower() + ) finally: await tool.delete(force=True) @@ -565,12 +599,16 @@ async def test_3011_custom_tool_without_tool_type(): assert fetched.tool_name == CUSTOM_NO_TYPE_TOOL_NAME assert fetched.attributes.tool_type is None assert fetched.attributes.function == PLSQL_FUNCTION_NAME - assert fetched.attributes.instruction == "Calculate age from birth date" + assert ( + fetched.attributes.instruction == "Calculate age from birth date" + ) finally: await tool.delete(force=True) -async def test_3012_custom_tool_with_tool_type_without_instruction(sql_profile): +async def test_3012_custom_tool_with_tool_type_without_instruction( + sql_profile, +): logger.info( "Validating custom tool creation with tool_type set and instruction unset" ) @@ -631,15 +669,22 @@ async def test_3013_custom_tool_with_tool_type_and_instruction(sql_profile): async def test_3014_sql_tool_with_invalid_profile_created(neg_sql_tool): logger.info("Validating SQL tool with invalid profile") - log_tool_details("test_3010_sql_tool_with_invalid_profile_created", neg_sql_tool) + log_tool_details( + "test_3010_sql_tool_with_invalid_profile_created", neg_sql_tool + ) assert neg_sql_tool.tool_name == NEG_SQL_TOOL_NAME assert neg_sql_tool.attributes.tool_type == select_ai.agent.ToolType.SQL - assert neg_sql_tool.attributes.tool_params.profile_name == "NON_EXISTENT_PROFILE" + assert ( + neg_sql_tool.attributes.tool_params.profile_name + == "NON_EXISTENT_PROFILE" + ) async def test_3015_rag_tool_with_invalid_profile_created(neg_rag_tool): logger.info("Validating RAG tool with invalid profile") - log_tool_details("test_3011_rag_tool_with_invalid_profile_created", neg_rag_tool) + log_tool_details( + "test_3011_rag_tool_with_invalid_profile_created", neg_rag_tool + ) assert neg_rag_tool.tool_name == NEG_RAG_TOOL_NAME assert neg_rag_tool.attributes.tool_type == select_ai.agent.ToolType.RAG assert ( @@ -691,12 +736,16 @@ async def test_3020_create_tool_default_status_enabled(sql_profile): tool = await AsyncTool.create_built_in_tool( tool_name=DEFAULT_STATUS_TOOL_NAME, tool_type=select_ai.agent.ToolType.SQL, - tool_params=select_ai.agent.SQLToolParams(profile_name=SQL_PROFILE_NAME), + tool_params=select_ai.agent.SQLToolParams( + profile_name=SQL_PROFILE_NAME + ), ) try: await assert_tool_status(DEFAULT_STATUS_TOOL_NAME, "ENABLED") fetched = await AsyncTool.fetch(DEFAULT_STATUS_TOOL_NAME) - log_tool_details("test_3016_create_tool_default_status_enabled", fetched) + log_tool_details( + "test_3016_create_tool_default_status_enabled", fetched + ) logger.info( "Fetched created tool | name=%s | type=%s | profile=%s", fetched.tool_name, @@ -771,7 +820,9 @@ async def test_3024_http_tool_created(email_credential): fetched = await AsyncTool.fetch(HTTP_TOOL_NAME) assert fetched.tool_name == HTTP_TOOL_NAME assert fetched.attributes.tool_type == select_ai.agent.ToolType.HTTP - assert fetched.attributes.tool_params.credential_name == email_credential + assert ( + fetched.attributes.tool_params.credential_name == email_credential + ) assert fetched.attributes.tool_params.endpoint == HTTP_ENDPOINT finally: logger.info("Deleting HTTP tool: %s", HTTP_TOOL_NAME) diff --git a/tests/agents/test_3101_async_tasks.py b/tests/agents/test_3101_async_tasks.py index 370158c..049cd08 100644 --- a/tests/agents/test_3101_async_tasks.py +++ b/tests/agents/test_3101_async_tasks.py @@ -9,9 +9,9 @@ 3101 - Module for testing select_ai agent async tasks """ -import uuid import logging import os +import uuid import oracledb import pytest @@ -20,7 +20,9 @@ pytestmark = pytest.mark.anyio -PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")) +PROJECT_ROOT = os.path.abspath( + os.path.join(os.path.dirname(__file__), "../..") +) LOG_FILE = os.path.join(PROJECT_ROOT, "log", "tkex_test_3101_async_tasks.log") os.makedirs(os.path.dirname(LOG_FILE), exist_ok=True) @@ -35,7 +37,9 @@ PYSAI_3100_TASK_NAME = f"PYSAI_3100_{uuid.uuid4().hex.upper()}" PYSAI_3100_SQL_TASK_DESCRIPTION = "PYSAI_3100_SQL_TASK_DESCRIPTION" -PYSAI_3100_DISABLED_TASK_NAME = f"PYSAI_3100_DISABLED_{uuid.uuid4().hex.upper()}" +PYSAI_3100_DISABLED_TASK_NAME = ( + f"PYSAI_3100_DISABLED_{uuid.uuid4().hex.upper()}" +) PYSAI_3100_DEFAULT_STATUS_TASK_NAME = ( f"PYSAI_3100_DEFAULT_STATUS_{uuid.uuid4().hex.upper()}" ) @@ -54,15 +58,6 @@ def log_test_name(request): logger.info("--- Finished test: %s ---", request.function.__name__) -@pytest.fixture(scope="module", autouse=True) -async def async_connect(test_env): - logger.info("Opening async database connection") - await select_ai.async_connect(**test_env.connect_params()) - yield - logger.info("Closing async database connection") - await select_ai.async_disconnect() - - async def get_task_status(task_name): logger.info("Fetching task status for: %s", task_name) async with select_ai.async_cursor() as cur: @@ -145,7 +140,10 @@ async def test_3100(task, task_attributes): async def test_3101(task_name_pattern): """task list""" if task_name_pattern: - tasks = [task async for task in select_ai.agent.AsyncTask.list(task_name_pattern)] + tasks = [ + task + async for task in select_ai.agent.AsyncTask.list(task_name_pattern) + ] else: tasks = [task async for task in select_ai.agent.AsyncTask.list()] for task in tasks: @@ -182,7 +180,9 @@ async def test_3103_create_task_default_status_enabled(): ) await task.create(replace=True) try: - await assert_task_status(PYSAI_3100_DEFAULT_STATUS_TASK_NAME, "ENABLED") + await assert_task_status( + PYSAI_3100_DEFAULT_STATUS_TASK_NAME, "ENABLED" + ) fetched = await AsyncTask.fetch(PYSAI_3100_DEFAULT_STATUS_TASK_NAME) log_task_details("test_3103", fetched) assert fetched.description == "Default status should be enabled" @@ -208,7 +208,9 @@ async def test_3104_create_task_with_enabled_false_sets_disabled(): log_task_details("test_3104", fetched) assert fetched.description == "Task created disabled" - logger.info("Enabling task created with enabled=False: %s", task.task_name) + logger.info( + "Enabling task created with enabled=False: %s", task.task_name + ) await task.enable() await assert_task_status(PYSAI_3100_DISABLED_TASK_NAME, "ENABLED") finally: @@ -226,7 +228,9 @@ async def test_3105_disable_enable_task(task): async def test_3105b_set_single_attribute_invalid(task): - logger.info("Setting invalid single attribute for async task: %s", task.task_name) + logger.info( + "Setting invalid single attribute for async task: %s", task.task_name + ) with pytest.raises(oracledb.DatabaseError) as exc: await task.set_attribute("description", "New Desc") logger.info("Received expected Oracle error: %s", exc.value) @@ -234,7 +238,9 @@ async def test_3105b_set_single_attribute_invalid(task): async def test_3105c_duplicate_task_creation_fails(task): - logger.info("Creating duplicate async task without replace: %s", task.task_name) + logger.info( + "Creating duplicate async task without replace: %s", task.task_name + ) dup = AsyncTask( task_name=task.task_name, description="Duplicate task", @@ -256,7 +262,10 @@ async def test_3105d_invalid_regex_pattern(): async def test_3106_drop_task_force_true_non_existent(): - logger.info("Dropping missing task with force=True: %s", PYSAI_3100_MISSING_TASK_NAME) + logger.info( + "Dropping missing task with force=True: %s", + PYSAI_3100_MISSING_TASK_NAME, + ) task = AsyncTask(task_name=PYSAI_3100_MISSING_TASK_NAME) await task.delete(force=True) status = await get_task_status(PYSAI_3100_MISSING_TASK_NAME) @@ -265,7 +274,10 @@ async def test_3106_drop_task_force_true_non_existent(): async def test_3107_drop_task_force_false_non_existent_raises(): - logger.info("Dropping missing task with force=False: %s", PYSAI_3100_MISSING_TASK_NAME) + logger.info( + "Dropping missing task with force=False: %s", + PYSAI_3100_MISSING_TASK_NAME, + ) task = AsyncTask(task_name=PYSAI_3100_MISSING_TASK_NAME) with pytest.raises(oracledb.Error) as exc: await task.delete(force=False) @@ -376,5 +388,7 @@ async def test_3112_enable_deleted_task_object_raises(): logger.info("Attempting to enable deleted task using same object") with pytest.raises(oracledb.DatabaseError) as exc: await task.enable() - logger.info("Received expected error when enabling deleted task: %s", exc.value) + logger.info( + "Received expected error when enabling deleted task: %s", exc.value + ) assert "ORA-20051" in str(exc.value) diff --git a/tests/agents/test_3201_async_agents.py b/tests/agents/test_3201_async_agents.py index 6c0d021..5f92677 100644 --- a/tests/agents/test_3201_async_agents.py +++ b/tests/agents/test_3201_async_agents.py @@ -21,7 +21,9 @@ pytestmark = pytest.mark.anyio -PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")) +PROJECT_ROOT = os.path.abspath( + os.path.join(os.path.dirname(__file__), "../..") +) LOG_FILE = os.path.join(PROJECT_ROOT, "log", "tkex_test_3201_async_agents.log") os.makedirs(os.path.dirname(LOG_FILE), exist_ok=True) @@ -52,22 +54,15 @@ def log_test_name(request): logger.info("--- Finished test: %s ---", request.function.__name__) -@pytest.fixture(scope="module", autouse=True) -async def async_connect(test_env): - logger.info("Opening async database connection") - await select_ai.async_connect(**test_env.connect_params()) - yield - logger.info("Closing async database connection") - await select_ai.async_disconnect() - - def log_agent_details(context: str, agent) -> None: attrs = getattr(agent, "attributes", None) details = { "context": context, "agent_name": getattr(agent, "agent_name", None), "description": getattr(agent, "description", None), - "profile_name": getattr(attrs, "profile_name", None) if attrs else None, + "profile_name": ( + getattr(attrs, "profile_name", None) if attrs else None + ), "role": getattr(attrs, "role", None) if attrs else None, "enable_human_tool": ( getattr(attrs, "enable_human_tool", None) if attrs else None @@ -164,12 +159,15 @@ async def test_3200_identity(agent, agent_attributes): assert agent.attributes.enable_human_tool is False -@pytest.mark.parametrize("agent_name_pattern", [None, ".*", "^PYSAI_3200_AGENT_"]) +@pytest.mark.parametrize( + "agent_name_pattern", [None, ".*", "^PYSAI_3200_AGENT_"] +) async def test_3201_list(agent_name_pattern): logger.info("Listing agents with pattern=%s", agent_name_pattern) if agent_name_pattern: agents = [ - a async for a in select_ai.agent.AsyncAgent.list(agent_name_pattern) + a + async for a in select_ai.agent.AsyncAgent.list(agent_name_pattern) ] else: agents = [a async for a in select_ai.agent.AsyncAgent.list()] @@ -212,13 +210,17 @@ async def test_3204_create_agent_default_status_enabled(agent_attributes): try: await assert_agent_status(name, "ENABLED") fetched = await AsyncAgent.fetch(name) - log_agent_details("test_3204_create_agent_default_status_enabled", fetched) + log_agent_details( + "test_3204_create_agent_default_status_enabled", fetched + ) assert fetched.description == "Default enabled status" finally: await a.delete(force=True) -async def test_3205_create_agent_with_enabled_false_sets_disabled(agent_attributes): +async def test_3205_create_agent_with_enabled_false_sets_disabled( + agent_attributes, +): a = AsyncAgent( agent_name=PYSAI_3200_DISABLED_AGENT_NAME, description="Initially disabled", @@ -310,12 +312,16 @@ async def test_3213_disable_enable(agent): async def test_3214_set_attribute_none(agent): logger.info("Setting role=None on async agent: %s", agent.agent_name) - await expect_async_error("ORA-20050", lambda: agent.set_attribute("role", None)) + await expect_async_error( + "ORA-20050", lambda: agent.set_attribute("role", None) + ) async def test_3215_set_attribute_empty(agent): logger.info("Setting role='' on async agent: %s", agent.agent_name) - await expect_async_error("ORA-20050", lambda: agent.set_attribute("role", "")) + await expect_async_error( + "ORA-20050", lambda: agent.set_attribute("role", "") + ) async def test_3216_create_existing_without_replace(agent_attributes): diff --git a/tests/agents/test_3301_async_teams.py b/tests/agents/test_3301_async_teams.py index 3a54a3f..62b1384 100644 --- a/tests/agents/test_3301_async_teams.py +++ b/tests/agents/test_3301_async_teams.py @@ -32,7 +32,9 @@ # Logging # ----------------------------------------------------------------------------- -PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")) +PROJECT_ROOT = os.path.abspath( + os.path.join(os.path.dirname(__file__), "../..") +) LOG_DIR = os.path.join(PROJECT_ROOT, "log") os.makedirs(LOG_DIR, exist_ok=True) LOG_FILE = os.path.join(LOG_DIR, "tkex_test_3301_async_teams.log") @@ -54,6 +56,7 @@ # Per-test logging + async connection # ----------------------------------------------------------------------------- + @pytest.fixture(autouse=True) def log_test_name(request): logger.info("--- Starting test: %s ---", request.function.__name__) @@ -61,19 +64,11 @@ def log_test_name(request): logger.info("--- Finished test: %s ---", request.function.__name__) -@pytest.fixture(scope="module", autouse=True) -async def async_connect(test_env): - logger.info("Opening async database connection") - await select_ai.async_connect(**test_env.connect_params()) - yield - logger.info("Closing async database connection") - await select_ai.async_disconnect() - - # ----------------------------------------------------------------------------- # Helpers # ----------------------------------------------------------------------------- + async def expect_async_error(expected_code, coro_fn): """ expected_code: @@ -152,6 +147,7 @@ async def assert_team_status(team_name: str, expected_status: str) -> None: # Fixtures # ----------------------------------------------------------------------------- + @pytest.fixture(scope="module") async def python_gen_ai_profile(profile_attributes): logger.info("Creating profile: %s", PYSAI_TEAM_PROFILE_NAME) @@ -238,6 +234,7 @@ async def team(team_attributes): # Tests # ----------------------------------------------------------------------------- + async def test_3300_create_and_identity(team, team_attributes): log_team_details("test_3300_create_and_identity", team) assert team.team_name == PYSAI_TEAM_NAME @@ -248,7 +245,11 @@ async def test_3300_create_and_identity(team, team_attributes): @pytest.mark.parametrize("pattern", [None, ".*", "^PYSAI_TEAM_"]) async def test_3301_list(pattern): logger.info("Listing teams using pattern: %s", pattern) - teams = [t async for t in AsyncTeam.list(pattern)] if pattern else [t async for t in AsyncTeam.list()] + teams = ( + [t async for t in AsyncTeam.list(pattern)] + if pattern + else [t async for t in AsyncTeam.list()] + ) for t in teams: if t.team_name == PYSAI_TEAM_NAME: log_team_details("test_3301_list", t) @@ -316,15 +317,21 @@ async def test_3308_fetch_non_existing(): async def test_3311_set_attribute_invalid_key(team): - await expect_async_error("ORA-20053", lambda: team.set_attribute("no_such_attr", "x")) + await expect_async_error( + "ORA-20053", lambda: team.set_attribute("no_such_attr", "x") + ) async def test_3312_set_attribute_none(team): - await expect_async_error("ORA-20053", lambda: team.set_attribute("process", None)) + await expect_async_error( + "ORA-20053", lambda: team.set_attribute("process", None) + ) async def test_3313_set_attribute_empty(team): - await expect_async_error("ORA-20053", lambda: team.set_attribute("process", "")) + await expect_async_error( + "ORA-20053", lambda: team.set_attribute("process", "") + ) async def test_3314_set_attribute_invalid_value(team): @@ -355,7 +362,9 @@ async def test_3317_set_attribute_after_delete(team_attributes): t = AsyncTeam(name, team_attributes, "TMP") await t.create() await t.delete(force=True) - await expect_async_error("ORA-20053", lambda: t.set_attribute("process", "sequential")) + await expect_async_error( + "ORA-20053", lambda: t.set_attribute("process", "sequential") + ) async def test_3318_double_delete(team_attributes): @@ -384,10 +393,10 @@ async def test_3320_fetch_after_delete(team_attributes): await t.delete(force=True) await expect_async_error("NOT_FOUND", lambda: AsyncTeam.fetch(name)) + async def test_3321_double_delete(team_attributes): name = f"TMP_{uuid.uuid4().hex.upper()}" t = AsyncTeam(name, team_attributes, "TMP") await t.create() await t.delete(force=True) await expect_async_error("ORA-20053", lambda: t.delete(force=False)) - diff --git a/tests/agents/test_3800_agente2e.py b/tests/agents/test_3800_agente2e.py index a3b4a30..4f16d85 100644 --- a/tests/agents/test_3800_agente2e.py +++ b/tests/agents/test_3800_agente2e.py @@ -5,13 +5,13 @@ # http://oss.oracle.com/licenses/upl. # ----------------------------------------------------------------------------- -import uuid -import time -import os import logging -import pytest +import os +import time +import uuid from contextlib import contextmanager +import pytest import select_ai from select_ai.agent import ( Agent, @@ -21,8 +21,8 @@ Team, TeamAttributes, Tool, - ToolParams, ToolAttributes, + ToolParams, ) # ---------------------------------------------------------------------- @@ -31,7 +31,9 @@ # Path -PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")) +PROJECT_ROOT = os.path.abspath( + os.path.join(os.path.dirname(__file__), "../..") +) LOG_FILE = os.path.join(PROJECT_ROOT, "log", "tkex_test_3800_agente2e.log") os.makedirs(os.path.dirname(LOG_FILE), exist_ok=True) @@ -76,7 +78,8 @@ def log_object_details(context: str, object_type: str, obj) -> None: "description": getattr(obj, "description", None), "provider_type": ( type(getattr(attributes, "provider", None)).__name__ - if attributes is not None and getattr(attributes, "provider", None) + if attributes is not None + and getattr(attributes, "provider", None) else None ), "object_count": ( @@ -210,7 +213,8 @@ def setup_test_user(test_env): msg = str(exc) if ( "ORA-01749" not in msg - and "Cannot GRANT or REVOKE privileges to or from yourself" not in msg + and "Cannot GRANT or REVOKE privileges to or from yourself" + not in msg ): raise @@ -220,11 +224,10 @@ def setup_test_user(test_env): ) finally: select_ai.disconnect() - select_ai.connect(**test_env.connect_params()) -@pytest.fixture(scope="session") -def openai_cred(): +@pytest.fixture(scope="module") +def openai_cred(connect): api_key = os.getenv("PYSAI_TEST_OPENAI_API_KEY") assert api_key, "PYSAI_TEST_OPENAI_API_KEY not set" cred_name = "OPENAI_CRED" @@ -244,8 +247,8 @@ def openai_cred(): return cred_name -@pytest.fixture(scope="session") -def email_cred(): +@pytest.fixture(scope="module") +def email_cred(connect): smtp_username = os.getenv("PYSAI_TEST_EMAIL_CRED_USERNAME") smtp_password = os.getenv("PYSAI_TEST_EMAIL_CRED_PASSWORD") @@ -268,56 +271,22 @@ def email_cred(): return cred_name -@pytest.fixture(scope="session") -def allow_network_acl(): - with select_ai.cursor() as cur: - cur.execute("SELECT USER FROM dual") - db_user = cur.fetchone()[0] - - def append_ace(host, privileges): - try: - cur.execute( - f""" - BEGIN - DBMS_NETWORK_ACL_ADMIN.APPEND_HOST_ACE( - host => '{host}', - ace => xs$ace_type( - privilege_list => xs$name_list({','.join([f"'{p}'" for p in privileges])}), - principal_name => '{db_user}', - principal_type => xs_acl.ptype_db - ) - ); - END; - """ - ) - except Exception as exc: - msg = str(exc) - if ( - "ORA-46212" in msg - or "ORA-46313" in msg - or "already exists" in msg - ): - return - raise - - append_ace(EMAIL_SMTP_HOST, ["connect", "smtp"]) - - for host in ["api.openai.com", "a.co","amazon.in"]: - append_ace(host, ["connect", "http"]) - - yield - - # ---------------------------------------------------------------------- # MAIN TEST # ---------------------------------------------------------------------- + def test_3800_agent_end_to_end( - profile_attributes, setup_test_user, openai_cred, email_cred, allow_network_acl + connect, + profile_attributes, + setup_test_user, + openai_cred, + email_cred, + allow_network_acl, ): """ End-to-end Select AI Agent integration test. - + """ # ------------------------------- @@ -381,9 +350,7 @@ def test_3800_agent_end_to_end( attributes=ToolAttributes( tool_type="WEBSEARCH", instruction="Use this tool to find the current price of a product from a URL.", - tool_params=ToolParams( - credential_name=openai_cred - ), + tool_params=ToolParams(credential_name=openai_cred), ), ) websearch_tool.create(replace=True) @@ -395,7 +362,10 @@ def test_3800_agent_end_to_end( fetched_websearch_tool.tool_name, fetched_websearch_tool.attributes.tool_params.credential_name, ) - assert fetched_websearch_tool.attributes.tool_params.credential_name == openai_cred + assert ( + fetched_websearch_tool.attributes.tool_params.credential_name + == openai_cred + ) # Email notification tool email_tool = Tool( @@ -420,11 +390,19 @@ def test_3800_agent_end_to_end( fetched_email_tool.tool_name, fetched_email_tool.attributes.tool_params.credential_name, ) - assert fetched_email_tool.attributes.tool_params.credential_name == email_cred + assert ( + fetched_email_tool.attributes.tool_params.credential_name + == email_cred + ) assert Tool("Email") is not None - assert websearch_tool.attributes.tool_params.credential_name == openai_cred - assert email_tool.attributes.tool_params.credential_name == email_cred + assert ( + websearch_tool.attributes.tool_params.credential_name + == openai_cred + ) + assert ( + email_tool.attributes.tool_params.credential_name == email_cred + ) # ------------------------------- # TASK @@ -458,7 +436,7 @@ def test_3800_agent_end_to_end( assert task.task_name == "Return_And_Price_Match" assert set(task.attributes.tools) == {"Websearch", "Email"} - + # ------------------------------- # TEAM # ------------------------------- @@ -466,10 +444,12 @@ def test_3800_agent_end_to_end( team = Team( team_name="ReturnAgency", attributes=TeamAttributes( - agents=[{ - "name": "CustomerAgent", - "task": "Return_And_Price_Match", - }], + agents=[ + { + "name": "CustomerAgent", + "task": "Return_And_Price_Match", + } + ], process="sequential", ), ) @@ -519,7 +499,9 @@ def test_3800_agent_end_to_end( decoded_tool_history = _decode_history_rows(tool_history) - logger.info("Tool history rows fetched: %d", len(decoded_tool_history)) + logger.info( + "Tool history rows fetched: %d", len(decoded_tool_history) + ) assert decoded_tool_history finally: @@ -541,5 +523,7 @@ def test_3800_agent_end_to_end( created["agent"].delete(force=True) if created["profile"] is not None: - logger.info("Deleting profile: %s", created["profile"].profile_name) + logger.info( + "Deleting profile: %s", created["profile"].profile_name + ) created["profile"].delete(force=True) diff --git a/tests/agents/test_3800_async_agente2e.py b/tests/agents/test_3800_async_agente2e.py index af60c6b..c9b8199 100644 --- a/tests/agents/test_3800_async_agente2e.py +++ b/tests/agents/test_3800_async_agente2e.py @@ -36,8 +36,12 @@ # LOGGING # ---------------------------------------------------------------------- -PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")) -LOG_FILE = os.path.join(PROJECT_ROOT, "log", "tkex_test_3800_async_agente2e.log") +PROJECT_ROOT = os.path.abspath( + os.path.join(os.path.dirname(__file__), "../..") +) +LOG_FILE = os.path.join( + PROJECT_ROOT, "log", "tkex_test_3800_async_agente2e.log" +) os.makedirs(os.path.dirname(LOG_FILE), exist_ok=True) root = logging.getLogger() @@ -79,7 +83,8 @@ def log_object_details(context: str, object_type: str, obj) -> None: "description": getattr(obj, "description", None), "provider_type": ( type(getattr(attributes, "provider", None)).__name__ - if attributes is not None and getattr(attributes, "provider", None) + if attributes is not None + and getattr(attributes, "provider", None) else None ), "object_count": ( @@ -216,7 +221,8 @@ def setup_test_user(test_env): msg = str(exc) if ( "ORA-01749" not in msg - and "Cannot GRANT or REVOKE privileges to or from yourself" not in msg + and "Cannot GRANT or REVOKE privileges to or from yourself" + not in msg ): raise @@ -226,11 +232,10 @@ def setup_test_user(test_env): ) finally: select_ai.disconnect() - select_ai.connect(**test_env.connect_params()) -@pytest.fixture(scope="session") -def openai_cred(): +@pytest.fixture(scope="module") +def openai_cred(connect): api_key = os.getenv("PYSAI_TEST_OPENAI_API_KEY") assert api_key, "PYSAI_TEST_OPENAI_API_KEY not set" cred_name = "OPENAI_CRED" @@ -250,8 +255,8 @@ def openai_cred(): return cred_name -@pytest.fixture(scope="session") -def email_cred(): +@pytest.fixture(scope="module") +def email_cred(connect): smtp_username = os.getenv("PYSAI_TEST_EMAIL_CRED_USERNAME") smtp_password = os.getenv("PYSAI_TEST_EMAIL_CRED_PASSWORD") @@ -274,63 +279,12 @@ def email_cred(): return cred_name -@pytest.fixture(scope="session") -def allow_network_acl(): - with select_ai.cursor() as cur: - cur.execute("SELECT USER FROM dual") - db_user = cur.fetchone()[0] - - def append_ace(host, privileges): - try: - cur.execute( - f""" - BEGIN - DBMS_NETWORK_ACL_ADMIN.APPEND_HOST_ACE( - host => '{host}', - ace => xs$ace_type( - privilege_list => xs$name_list({','.join([f"'{p}'" for p in privileges])}), - principal_name => '{db_user}', - principal_type => xs_acl.ptype_db - ) - ); - END; - """ - ) - except Exception as exc: - msg = str(exc) - if ( - "ORA-46212" in msg - or "ORA-46313" in msg - or "already exists" in msg - ): - return - raise - - append_ace(EMAIL_SMTP_HOST, ["connect", "smtp"]) - - for host in ["api.openai.com", "a.co", "amazon.in"]: - append_ace(host, ["connect", "http"]) - - yield - - -@pytest.fixture(scope="module", autouse=True) -async def async_connect( - test_env, setup_test_user, openai_cred, email_cred, allow_network_acl -): - logger.info( - "Opening async database connection | user=%s | dsn=%s", - test_env.test_user, - test_env.connect_string, - ) - await select_ai.async_connect(**test_env.connect_params()) - yield - logger.info("Closing async database connection") - await select_ai.async_disconnect() - - async def test_3800_agent_end_to_end_async( - profile_attributes, openai_cred, email_cred + async_connect, + profile_attributes, + openai_cred, + email_cred, + allow_network_acl, ): """End-to-end Select AI Agent integration test (async).""" @@ -456,8 +410,13 @@ async def test_3800_agent_end_to_end_async( [t.tool_name for t in created["tools"]], ) assert len(created["tools"]) == 2 - assert websearch_tool.attributes.tool_params.credential_name == openai_cred - assert email_tool.attributes.tool_params.credential_name == email_cred + assert ( + websearch_tool.attributes.tool_params.credential_name + == openai_cred + ) + assert ( + email_tool.attributes.tool_params.credential_name == email_cred + ) with log_step("Create task"): task = AsyncTask( @@ -566,5 +525,7 @@ async def test_3800_agent_end_to_end_async( await created["agent"].delete(force=True) if created["profile"] is not None: - logger.info("Deleting profile: %s", created["profile"].profile_name) + logger.info( + "Deleting profile: %s", created["profile"].profile_name + ) await created["profile"].delete(force=True) diff --git a/tests/agents/test_3900_async_sql_team.py b/tests/agents/test_3900_async_sql_team.py index 3a4b012..3bd07ea 100644 --- a/tests/agents/test_3900_async_sql_team.py +++ b/tests/agents/test_3900_async_sql_team.py @@ -1,5 +1,5 @@ -import os import logging +import os import uuid from contextlib import contextmanager @@ -11,7 +11,9 @@ pytestmark = pytest.mark.anyio # Configure file-based logging for this script run. -PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")) +PROJECT_ROOT = os.path.abspath( + os.path.join(os.path.dirname(__file__), "../..") +) LOG_FILE = os.path.join(PROJECT_ROOT, "log", "test_3900_async_sql_team.log") os.makedirs(os.path.dirname(LOG_FILE), exist_ok=True) @@ -67,7 +69,8 @@ def log_object_details(context: str, object_type: str, obj) -> None: "description": getattr(obj, "description", None), "provider_type": ( type(getattr(attributes, "provider", None)).__name__ - if attributes is not None and getattr(attributes, "provider", None) + if attributes is not None + and getattr(attributes, "provider", None) else None ), "object_count": ( @@ -172,18 +175,6 @@ async def verify_credential_exists(credential_name, expected_username=None): assert actual_username == expected_username -async def connect_to_db(): - # Connect the Python client to the test database. - user = os.getenv("PYSAI_TEST_USER") - password = os.getenv("PYSAI_TEST_USER_PASSWORD") - dsn = os.getenv("PYSAI_TEST_CONNECT_STRING") - assert user, "PYSAI_TEST_USER not set" - assert password, "PYSAI_TEST_USER_PASSWORD not set" - assert dsn, "PYSAI_TEST_CONNECT_STRING not set" - logger.info("Connecting to database using configured test credentials") - await select_ai.async_connect(user=user, password=password, dsn=dsn) - - async def _cleanup_async_sql_team_objects(created) -> None: with log_step("Cleanup async SQL team objects"): if created["team"] is not None: @@ -191,14 +182,18 @@ async def _cleanup_async_sql_team_objects(created) -> None: logger.info("Deleting team: %s", created["team"].team_name) await created["team"].delete(force=True) except Exception: - logger.exception("Failed to delete team: %s", created["team"].team_name) + logger.exception( + "Failed to delete team: %s", created["team"].team_name + ) if created["task"] is not None: try: logger.info("Deleting task: %s", created["task"].task_name) await created["task"].delete(force=True) except Exception: - logger.exception("Failed to delete task: %s", created["task"].task_name) + logger.exception( + "Failed to delete task: %s", created["task"].task_name + ) for tool in reversed(created["tools"]): try: @@ -218,51 +213,29 @@ async def _cleanup_async_sql_team_objects(created) -> None: if created["profile"] is not None: try: - logger.info("Deleting profile: %s", created["profile"].profile_name) + logger.info( + "Deleting profile: %s", created["profile"].profile_name + ) await created["profile"].delete(force=True) except Exception: logger.exception( - "Failed to delete profile: %s", created["profile"].profile_name + "Failed to delete profile: %s", + created["profile"].profile_name, ) for credential_name in reversed(created["credentials"]): try: logger.info("Deleting credential: %s", credential_name) - await select_ai.async_delete_credential(credential_name, force=True) + await select_ai.async_delete_credential( + credential_name, force=True + ) except Exception: - logger.exception("Failed to delete credential: %s", credential_name) - - -async def allow_network_acl(): - # Grant the database user SMTP access required by the email notification tool. - async with select_ai.async_cursor() as cur: - try: - await cur.execute( - """ - BEGIN - DBMS_NETWORK_ACL_ADMIN.APPEND_HOST_ACE( - host => :host, - ace => xs$ace_type( - privilege_list => xs$name_list('connect', 'smtp'), - principal_name => SYS_CONTEXT('USERENV', 'CURRENT_USER'), - principal_type => xs_acl.ptype_db - ) - ); - END; - """, - host=EMAIL_SMTP_HOST, - ) - except Exception as exc: - msg = str(exc) - if ( - "ORA-46212" not in msg - and "ORA-46313" not in msg - and "already exists" not in msg - ): - raise + logger.exception( + "Failed to delete credential: %s", credential_name + ) -async def create_async_sql_team(): +async def create_async_sql_team(test_env): created = { "team": None, "task": None, @@ -275,8 +248,9 @@ async def create_async_sql_team(): # Initialize database access required by the team and tools. try: with log_step("Initialize database and network access"): - await connect_to_db() - await allow_network_acl() + logger.info( + "Using global async database pool from tests/conftest.py" + ) # Load OCI model and credential settings from the environment. oci_user_ocid = os.getenv("PYSAI_TEST_OCI_USER_OCID") @@ -350,11 +324,15 @@ async def create_async_sql_team(): log_object_details("create_sql_tool", "tool", sql_tool) fetched_sql_tool = await AsyncTool.fetch(SQL_TOOL_NAME) assert fetched_sql_tool.tool_name == SQL_TOOL_NAME - assert fetched_sql_tool.attributes.tool_params.profile_name == SQL_PROFILE_NAME + assert ( + fetched_sql_tool.attributes.tool_params.profile_name + == SQL_PROFILE_NAME + ) # Load SMTP settings for the email notification tool. email_credential_name = ( - os.getenv("PYSAI_TEST_EMAIL_CREDENTIAL_NAME") or f"EMAIL_CRED_{RUN_ID}" + os.getenv("PYSAI_TEST_EMAIL_CREDENTIAL_NAME") + or f"EMAIL_CRED_{RUN_ID}" ) email_username = os.getenv("PYSAI_TEST_EMAIL_CRED_USERNAME") email_password = os.getenv("PYSAI_TEST_EMAIL_CRED_PASSWORD") @@ -466,8 +444,8 @@ async def create_async_sql_team(): @pytest.fixture(scope="module") -async def async_sql_team(): - async for team in create_async_sql_team(): +async def async_sql_team(async_connect, test_env, allow_network_acl): + async for team in create_async_sql_team(test_env): yield team diff --git a/tests/agents/test_3900_sql_team.py b/tests/agents/test_3900_sql_team.py index b12e7e2..956f81a 100644 --- a/tests/agents/test_3900_sql_team.py +++ b/tests/agents/test_3900_sql_team.py @@ -1,14 +1,16 @@ -import os import logging +import os import uuid from contextlib import contextmanager -import select_ai import pytest +import select_ai import select_ai.agent # Configure file-based logging for this script run. -PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")) +PROJECT_ROOT = os.path.abspath( + os.path.join(os.path.dirname(__file__), "../..") +) LOG_FILE = os.path.join(PROJECT_ROOT, "log", "test_3900_sql_team.log") os.makedirs(os.path.dirname(LOG_FILE), exist_ok=True) @@ -56,7 +58,8 @@ def log_object_details(context: str, object_type: str, obj) -> None: "description": getattr(obj, "description", None), "provider_type": ( type(getattr(attributes, "provider", None)).__name__ - if attributes is not None and getattr(attributes, "provider", None) + if attributes is not None + and getattr(attributes, "provider", None) else None ), "object_count": ( @@ -161,18 +164,6 @@ def verify_credential_exists(credential_name, expected_username=None): assert actual_username == expected_username -def connect_to_db(): - # Connect the Python client to the test database. - user = os.getenv("PYSAI_TEST_USER") - password = os.getenv("PYSAI_TEST_USER_PASSWORD") - dsn = os.getenv("PYSAI_TEST_CONNECT_STRING") - assert user, "PYSAI_TEST_USER not set" - assert password, "PYSAI_TEST_USER_PASSWORD not set" - assert dsn, "PYSAI_TEST_CONNECT_STRING not set" - logger.info("Connecting to database using configured test credentials") - select_ai.connect(user=user, password=password, dsn=dsn) - - def _cleanup_sql_team_objects(created) -> None: with log_step("Cleanup SQL team objects"): if created["team"] is not None: @@ -180,14 +171,18 @@ def _cleanup_sql_team_objects(created) -> None: logger.info("Deleting team: %s", created["team"].team_name) created["team"].delete(force=True) except Exception: - logger.exception("Failed to delete team: %s", created["team"].team_name) + logger.exception( + "Failed to delete team: %s", created["team"].team_name + ) if created["task"] is not None: try: logger.info("Deleting task: %s", created["task"].task_name) created["task"].delete(force=True) except Exception: - logger.exception("Failed to delete task: %s", created["task"].task_name) + logger.exception( + "Failed to delete task: %s", created["task"].task_name + ) for tool in reversed(created["tools"]): try: @@ -207,11 +202,14 @@ def _cleanup_sql_team_objects(created) -> None: if created["profile"] is not None: try: - logger.info("Deleting profile: %s", created["profile"].profile_name) + logger.info( + "Deleting profile: %s", created["profile"].profile_name + ) created["profile"].delete(force=True) except Exception: logger.exception( - "Failed to delete profile: %s", created["profile"].profile_name + "Failed to delete profile: %s", + created["profile"].profile_name, ) for credential_name in reversed(created["credentials"]): @@ -219,39 +217,12 @@ def _cleanup_sql_team_objects(created) -> None: logger.info("Deleting credential: %s", credential_name) select_ai.delete_credential(credential_name, force=True) except Exception: - logger.exception("Failed to delete credential: %s", credential_name) - - -def allow_network_acl(): - # Grant the database user SMTP access required by the email notification tool. - with select_ai.cursor() as cur: - try: - cur.execute( - """ - BEGIN - DBMS_NETWORK_ACL_ADMIN.APPEND_HOST_ACE( - host => :host, - ace => xs$ace_type( - privilege_list => xs$name_list('connect', 'smtp'), - principal_name => SYS_CONTEXT('USERENV', 'CURRENT_USER'), - principal_type => xs_acl.ptype_db - ) - ); - END; - """, - host=EMAIL_SMTP_HOST, - ) - except Exception as exc: - msg = str(exc) - if ( - "ORA-46212" not in msg - and "ORA-46313" not in msg - and "already exists" not in msg - ): - raise + logger.exception( + "Failed to delete credential: %s", credential_name + ) -def create_sql_team(): +def create_sql_team(test_env): created = { "team": None, "task": None, @@ -264,8 +235,7 @@ def create_sql_team(): # Initialize database access required by the team and tools. try: with log_step("Initialize database and network access"): - connect_to_db() - allow_network_acl() + logger.info("Using global database pool from tests/conftest.py") # Load OCI model and credential settings from the environment. oci_credential_name = "SQL_TEAM_OCI_CRED" @@ -339,7 +309,10 @@ def create_sql_team(): log_object_details("create_sql_tool", "tool", sql_tool) fetched_sql_tool = select_ai.agent.Tool.fetch("SQL_QUERY_TOOL") assert fetched_sql_tool.tool_name == "SQL_QUERY_TOOL" - assert fetched_sql_tool.attributes.tool_params.profile_name == "SQL_PROFILE" + assert ( + fetched_sql_tool.attributes.tool_params.profile_name + == "SQL_PROFILE" + ) # Load SMTP settings for the email notification tool. email_credential_name = ( @@ -380,7 +353,9 @@ def create_sql_team(): ) created["tools"].append(email_tool) log_object_details("create_email_tool", "tool", email_tool) - fetched_email_tool = select_ai.agent.Tool.fetch("EMAIL_NOTIFICATION_TOOL") + fetched_email_tool = select_ai.agent.Tool.fetch( + "EMAIL_NOTIFICATION_TOOL" + ) assert fetched_email_tool.tool_name == "EMAIL_NOTIFICATION_TOOL" assert ( fetched_email_tool.attributes.tool_params.credential_name @@ -434,9 +409,14 @@ def create_sql_team(): team = select_ai.agent.Team( team_name="SQL_DATA_TEAM", attributes=select_ai.agent.TeamAttributes( - agents=[{"name": "SQL_ANALYST_AGENT", "task": "SQL_ANALYSIS_TASK"}], + agents=[ + { + "name": "SQL_ANALYST_AGENT", + "task": "SQL_ANALYSIS_TASK", + } + ], process="sequential", - ) + ), ) team.create(replace=True, enabled=True) created["team"] = team @@ -453,8 +433,8 @@ def create_sql_team(): @pytest.fixture(scope="module") -def sql_team(): - yield from create_sql_team() +def sql_team(connect, test_env, allow_network_acl): + yield from create_sql_team(test_env) def test_sql_team_runs(sql_team): diff --git a/tests/conftest.py b/tests/conftest.py index 1ede4fd..4673339 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -14,6 +14,9 @@ # PYSAI_TEST_CONNECT_STRING: connect string for test suite # PYSAI_TEST_WALLET_LOCATION: location of wallet file (thin mode, mTLS) # PYSAI_TEST_WALLET_PASSWORD: password for wallet file (thin mode, mTLS) +# PYSAI_TEST_MIN_POOL_SIZE: Minimum number of connections in the pool +# PYSAI_TEST_MAX_POOL_SIZE: Maximum number of connections in the pool +# PYSAI_TEST_POOL_INCREMENT # # OCI Gen AI # PYSAI_TEST_OCI_USER_OCID @@ -28,11 +31,61 @@ import os import uuid +import oracledb import pytest import select_ai PYSAI_TEST_USER = "PYSAI_TEST_USER" PYSAI_OCI_CREDENTIAL_NAME = f"PYSAI_OCI_CREDENTIAL_{uuid.uuid4().hex.upper()}" +_BASIC_SCHEMA_PRIVILEGES = ( + "CREATE SESSION", + "CREATE TABLE", + "UNLIMITED TABLESPACE", +) + + +def _ensure_test_user_exists(username: str, password: str): + username_upper = username.upper() + with select_ai.cursor() as cr: + cr.execute( + "SELECT 1 FROM dba_users WHERE username = :username", + username=username_upper, + ) + if cr.fetchone(): + return + escaped_password = password.replace('"', '""') + cr.execute( + f'CREATE USER {username_upper} IDENTIFIED BY "{escaped_password}"' + ) + with select_ai.db.get_connection() as conn: + conn.commit() + + +def _grant_basic_schema_privileges(username: str): + username_upper = username.upper() + with select_ai.cursor() as cr: + for privilege in _BASIC_SCHEMA_PRIVILEGES: + cr.execute(f"GRANT {privilege} TO {username_upper}") + with select_ai.db.get_connection() as conn: + conn.commit() + + +def _append_host_ace(cur, host: str, privileges, username: str): + privilege_list = ",".join([f"'{p}'" for p in privileges]) + cur.execute( + f""" + BEGIN + DBMS_NETWORK_ACL_ADMIN.APPEND_HOST_ACE( + host => '{host}', + ace => xs$ace_type( + privilege_list => xs$name_list({privilege_list}), + principal_name => '{username}', + principal_type => xs_acl.ptype_db + ) + ); + END; + """ + ) def get_env_value(name, default_value=None, required=False): @@ -61,8 +114,17 @@ def __init__(self): self.admin_password = get_env_value("ADMIN_PASSWORD") self.wallet_location = get_env_value("WALLET_LOCATION") self.wallet_password = get_env_value("WALLET_PASSWORD") - - def connect_params(self, admin: bool = False): + self.min_pool_size = int( + get_env_value("MIN_POOL_SIZE", default_value=2) + ) + self.max_pool_size = int( + get_env_value("MAX_POOL_SIZE", default_value=4) + ) + self.pool_increment = int( + get_env_value("POOL_INCREMENT", default_value=1) + ) + + def connect_params(self, admin: bool = False, use_pool: bool = False): """ Returns connect params """ @@ -76,6 +138,10 @@ def connect_params(self, admin: bool = False): "wallet_password": self.wallet_password, "config_dir": self.wallet_location, } + if use_pool: + connect_params["min_size"] = self.min_pool_size + connect_params["max_size"] = self.max_pool_size + connect_params["increment"] = self.pool_increment return connect_params @@ -90,39 +156,46 @@ def test_env(pytestconfig): return env -# @pytest.fixture(autouse=True, scope="session") -# def setup_test_user(test_env): -# select_ai.connect(**test_env.connect_params(admin=True)) -# select_ai.grant_privileges(users=[test_env.test_user]) -# select_ai.grant_http_access( -# users=[test_env.test_user], -# provider_endpoint=select_ai.OpenAIProvider.provider_endpoint, -# ) -# select_ai.disconnect() +@pytest.fixture(autouse=True, scope="session") +def setup_test_user(test_env): + select_ai.connect(**test_env.connect_params(admin=True)) + _ensure_test_user_exists( + username=test_env.test_user, + password=test_env.test_user_password, + ) + _grant_basic_schema_privileges(username=test_env.test_user) + select_ai.grant_privileges(users=[test_env.test_user]) + select_ai.grant_http_access( + users=[test_env.test_user], + provider_endpoint=select_ai.OpenAIProvider.provider_endpoint, + ) + select_ai.disconnect() -@pytest.fixture(autouse=True, scope="session") -def connect(test_env): - select_ai.connect(**test_env.connect_params()) +@pytest.fixture(autouse=True, scope="module") +def connect(setup_test_user, test_env): + select_ai.create_pool(**test_env.connect_params(use_pool=True)) yield select_ai.disconnect() -# @pytest.fixture(autouse=True, scope="session") -# async def async_connect(test_env, anyio_backend): -# await select_ai.async_connect(**test_env.connect_params()) -# yield -# await select_ai.async_disconnect() +@pytest.fixture(autouse=True, scope="module") +async def async_connect(setup_test_user, test_env, anyio_backend): + select_ai.create_pool_async(**test_env.connect_params(use_pool=True)) + yield + await select_ai.async_disconnect() -@pytest.fixture +@pytest.fixture(scope="module") def connection(): - return select_ai.db.get_connection() + with select_ai.db.get_connection() as conn: + yield conn @pytest.fixture -def async_connection(): - return select_ai.db.async_get_connection() +async def async_connection(): + async with select_ai.db.async_get_connection() as conn: + yield conn @pytest.fixture(scope="module") @@ -131,13 +204,13 @@ def cursor(): yield cr -@pytest.fixture +@pytest.fixture(scope="module") async def async_cursor(): async with select_ai.async_cursor() as cr: yield cr -@pytest.fixture(autouse=True, scope="session") +@pytest.fixture(autouse=True, scope="module") def oci_credential(connect, test_env): credential = { "credential_name": PYSAI_OCI_CREDENTIAL_NAME, @@ -154,3 +227,43 @@ def oci_credential(connect, test_env): @pytest.fixture(scope="module") def oci_compartment_id(test_env): return get_env_value("OCI_COMPARTMENT_ID", required=True) + + +@pytest.fixture(scope="module") +def allow_network_acl(test_env): + username = test_env.test_user.upper() + email_smtp_host = get_env_value("EMAIL_SMTPHOST") + http_hosts = ["api.openai.com", "a.co", "amazon.in"] + + with oracledb.connect(**test_env.connect_params(admin=True)) as conn: + cur = conn.cursor() + try: + if email_smtp_host: + try: + _append_host_ace( + cur, email_smtp_host, ["connect", "smtp"], username + ) + except Exception as exc: + msg = str(exc) + if ( + "ORA-46212" not in msg + and "ORA-46313" not in msg + and "already exists" not in msg + ): + raise + + for host in http_hosts: + try: + _append_host_ace(cur, host, ["connect", "http"], username) + except Exception as exc: + msg = str(exc) + if ( + "ORA-46212" not in msg + and "ORA-46313" not in msg + and "already exists" not in msg + ): + raise + finally: + cur.close() + + yield From 4214e2c5ac4cc6e1f0004bb737d27785779269bf Mon Sep 17 00:00:00 2001 From: Abhishek Singh Date: Sun, 19 Apr 2026 22:15:28 -0700 Subject: [PATCH 2/3] Remove duplicate fixtures --- .github/workflows/test.yaml | 1 + tests/agents/test_3800_agente2e.py | 29 -------- tests/agents/test_3800_async_agente2e.py | 28 -------- tests/conftest.py | 89 +++++++++++++++--------- 4 files changed, 58 insertions(+), 89 deletions(-) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index e0ac095..859dcdb 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -54,3 +54,4 @@ jobs: PYSAI_TEST_EMAIL_RECIPIENT: ${{secrets.PYSAI_TEST_EMAIL_RECIPIENT}} PYSAI_TEST_EMAIL_SENDER: ${{secrets.PYSAI_TEST_EMAIL_SENDER}} PYSAI_TEST_EMAIL_SMTPHOST: ${{secrets.PYSAI_TEST_EMAIL_SMTPHOST}} + PYSAI_TEST_OPENAI_API_KEY: ${{secrets.PYSAI_TEST_OPENAI_API_KEY}} diff --git a/tests/agents/test_3800_agente2e.py b/tests/agents/test_3800_agente2e.py index 4f16d85..eaa0793 100644 --- a/tests/agents/test_3800_agente2e.py +++ b/tests/agents/test_3800_agente2e.py @@ -198,34 +198,6 @@ def _decode_history_rows(rows): return decoded_rows -@pytest.fixture(scope="session") -def setup_test_user(test_env): - try: - select_ai.disconnect() - except Exception: - pass - - select_ai.connect(**test_env.connect_params(admin=True)) - try: - try: - select_ai.grant_privileges(users=[test_env.test_user]) - except Exception as exc: - msg = str(exc) - if ( - "ORA-01749" not in msg - and "Cannot GRANT or REVOKE privileges to or from yourself" - not in msg - ): - raise - - select_ai.grant_http_access( - users=[test_env.test_user], - provider_endpoint=select_ai.OpenAIProvider.provider_endpoint, - ) - finally: - select_ai.disconnect() - - @pytest.fixture(scope="module") def openai_cred(connect): api_key = os.getenv("PYSAI_TEST_OPENAI_API_KEY") @@ -279,7 +251,6 @@ def email_cred(connect): def test_3800_agent_end_to_end( connect, profile_attributes, - setup_test_user, openai_cred, email_cred, allow_network_acl, diff --git a/tests/agents/test_3800_async_agente2e.py b/tests/agents/test_3800_async_agente2e.py index c9b8199..26080ce 100644 --- a/tests/agents/test_3800_async_agente2e.py +++ b/tests/agents/test_3800_async_agente2e.py @@ -206,34 +206,6 @@ async def _decode_history_rows(rows): return decoded_rows -@pytest.fixture(scope="session") -def setup_test_user(test_env): - try: - select_ai.disconnect() - except Exception: - pass - - select_ai.connect(**test_env.connect_params(admin=True)) - try: - try: - select_ai.grant_privileges(users=[test_env.test_user]) - except Exception as exc: - msg = str(exc) - if ( - "ORA-01749" not in msg - and "Cannot GRANT or REVOKE privileges to or from yourself" - not in msg - ): - raise - - select_ai.grant_http_access( - users=[test_env.test_user], - provider_endpoint=select_ai.OpenAIProvider.provider_endpoint, - ) - finally: - select_ai.disconnect() - - @pytest.fixture(scope="module") def openai_cred(connect): api_key = os.getenv("PYSAI_TEST_OPENAI_API_KEY") diff --git a/tests/conftest.py b/tests/conftest.py index 4673339..bd5b3f4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -34,6 +34,10 @@ import oracledb import pytest import select_ai +from select_ai.sql import ( + ENABLE_AI_PROFILE_DOMAIN_FOR_USER, + GRANT_PRIVILEGES_TO_USER, +) PYSAI_TEST_USER = "PYSAI_TEST_USER" PYSAI_OCI_CREDENTIAL_NAME = f"PYSAI_OCI_CREDENTIAL_{uuid.uuid4().hex.upper()}" @@ -44,30 +48,45 @@ ) -def _ensure_test_user_exists(username: str, password: str): +def _ensure_test_user_exists(cur, username: str, password: str): username_upper = username.upper() - with select_ai.cursor() as cr: - cr.execute( - "SELECT 1 FROM dba_users WHERE username = :username", - username=username_upper, - ) - if cr.fetchone(): - return - escaped_password = password.replace('"', '""') - cr.execute( - f'CREATE USER {username_upper} IDENTIFIED BY "{escaped_password}"' - ) - with select_ai.db.get_connection() as conn: - conn.commit() + cur.execute( + "SELECT 1 FROM dba_users WHERE username = :username", + username=username_upper, + ) + if cur.fetchone(): + return + escaped_password = password.replace('"', '""') + cur.execute( + f'CREATE USER {username_upper} IDENTIFIED BY "{escaped_password}"' + ) -def _grant_basic_schema_privileges(username: str): +def _grant_basic_schema_privileges(cur, username: str): username_upper = username.upper() - with select_ai.cursor() as cr: - for privilege in _BASIC_SCHEMA_PRIVILEGES: - cr.execute(f"GRANT {privilege} TO {username_upper}") - with select_ai.db.get_connection() as conn: - conn.commit() + for privilege in _BASIC_SCHEMA_PRIVILEGES: + cur.execute(f"GRANT {privilege} TO {username_upper}") + + +def _grant_select_ai_privileges(cur, username: str): + try: + cur.execute(GRANT_PRIVILEGES_TO_USER.format(username.strip())) + except Exception as exc: + msg = str(exc) + if ( + "ORA-01749" not in msg + and "Cannot GRANT or REVOKE privileges to or from yourself" + not in msg + ): + raise + + +def _grant_http_access(cur, username: str, provider_endpoint: str): + cur.execute( + ENABLE_AI_PROFILE_DOMAIN_FOR_USER, + user=username, + host=provider_endpoint, + ) def _append_host_ace(cur, host: str, privileges, username: str): @@ -158,18 +177,24 @@ def test_env(pytestconfig): @pytest.fixture(autouse=True, scope="session") def setup_test_user(test_env): - select_ai.connect(**test_env.connect_params(admin=True)) - _ensure_test_user_exists( - username=test_env.test_user, - password=test_env.test_user_password, - ) - _grant_basic_schema_privileges(username=test_env.test_user) - select_ai.grant_privileges(users=[test_env.test_user]) - select_ai.grant_http_access( - users=[test_env.test_user], - provider_endpoint=select_ai.OpenAIProvider.provider_endpoint, - ) - select_ai.disconnect() + with oracledb.connect(**test_env.connect_params(admin=True)) as conn: + cur = conn.cursor() + try: + _ensure_test_user_exists( + cur, + username=test_env.test_user, + password=test_env.test_user_password, + ) + _grant_basic_schema_privileges(cur, username=test_env.test_user) + _grant_select_ai_privileges(cur, username=test_env.test_user) + _grant_http_access( + cur, + username=test_env.test_user, + provider_endpoint=select_ai.OpenAIProvider.provider_endpoint, + ) + conn.commit() + finally: + cur.close() @pytest.fixture(autouse=True, scope="module") From 1d485eef2866840a1f654b18a46bd22cc47bcb78 Mon Sep 17 00:00:00 2001 From: Abhishek Singh Date: Mon, 20 Apr 2026 08:39:31 -0700 Subject: [PATCH 3/3] Added create proc privilege to test user --- tests/conftest.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/conftest.py b/tests/conftest.py index bd5b3f4..dc8689c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -44,6 +44,7 @@ _BASIC_SCHEMA_PRIVILEGES = ( "CREATE SESSION", "CREATE TABLE", + "CREATE PROCEDURE", "UNLIMITED TABLESPACE", )