Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,4 @@ jobs:
PYSAI_TEST_EMAIL_RECIPIENT: ${{secrets.PYSAI_TEST_EMAIL_RECIPIENT}}
PYSAI_TEST_EMAIL_SENDER: ${{secrets.PYSAI_TEST_EMAIL_SENDER}}
PYSAI_TEST_EMAIL_SMTPHOST: ${{secrets.PYSAI_TEST_EMAIL_SMTPHOST}}
PYSAI_TEST_OPENAI_API_KEY: ${{secrets.PYSAI_TEST_OPENAI_API_KEY}}
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
my_examples
logs/
log/
.idea
.env
.venv
Expand Down
125 changes: 88 additions & 37 deletions tests/agents/test_3001_async_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@
pytestmark = pytest.mark.anyio

# Path
PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))
PROJECT_ROOT = os.path.abspath(
os.path.join(os.path.dirname(__file__), "../..")
)
LOG_FILE = os.path.join(PROJECT_ROOT, "log", "tkex_test_3001_async_tools.log")
os.makedirs(os.path.dirname(LOG_FILE), exist_ok=True)

Expand Down Expand Up @@ -87,15 +89,6 @@ def log_test_name(request):
logger.info("--- Finished test: %s ---", request.function.__name__)


@pytest.fixture(scope="module", autouse=True)
async def async_connect(test_env):
logger.info("Opening async database connection")
await select_ai.async_connect(**test_env.connect_params())
yield
logger.info("Closing async database connection")
await select_ai.async_disconnect()


async def get_tool_status(tool_name):
logger.info("Fetching tool status for: %s", tool_name)
async with select_ai.async_cursor() as cur:
Expand Down Expand Up @@ -134,9 +127,11 @@ def log_tool_details(context: str, tool) -> None:
"instruction": getattr(attrs, "instruction", None) if attrs else None,
"function": getattr(attrs, "function", None) if attrs else None,
"tool_inputs": getattr(attrs, "tool_inputs", None) if attrs else None,
"tool_params": tool_params.dict(exclude_null=False)
if tool_params is not None
else None,
"tool_params": (
tool_params.dict(exclude_null=False)
if tool_params is not None
else None
),
}

logger.info("TOOL_DETAILS: %s", details)
Expand Down Expand Up @@ -261,7 +256,9 @@ async def email_credential():
logger.info("EMAIL credential did not exist or could not be dropped")
pass

await select_ai.async_create_credential(credential=credential, replace=True)
await select_ai.async_create_credential(
credential=credential, replace=True
)
logger.info("Created EMAIL credential: %s", EMAIL_CRED_NAME)
yield EMAIL_CRED_NAME

Expand All @@ -288,7 +285,9 @@ async def slack_credential():
logger.info("SLACK credential did not exist or could not be dropped")
pass

await select_ai.async_create_credential(credential=credential, replace=True)
await select_ai.async_create_credential(
credential=credential, replace=True
)
logger.info("Created SLACK credential: %s", SLACK_CRED_NAME)
yield SLACK_CRED_NAME

Expand Down Expand Up @@ -345,27 +344,35 @@ async def slack_tool(slack_credential):

@pytest.fixture(scope="module")
async def neg_sql_tool():
logger.info("Creating SQL tool with invalid profile: %s", NEG_SQL_TOOL_NAME)
logger.info(
"Creating SQL tool with invalid profile: %s", NEG_SQL_TOOL_NAME
)
tool = await AsyncTool.create_sql_tool(
tool_name=NEG_SQL_TOOL_NAME,
profile_name="NON_EXISTENT_PROFILE",
replace=True,
)
yield tool
logger.info("Deleting SQL tool with invalid profile: %s", NEG_SQL_TOOL_NAME)
logger.info(
"Deleting SQL tool with invalid profile: %s", NEG_SQL_TOOL_NAME
)
await tool.delete(force=True)


@pytest.fixture(scope="module")
async def neg_rag_tool():
logger.info("Creating RAG tool with invalid profile: %s", NEG_RAG_TOOL_NAME)
logger.info(
"Creating RAG tool with invalid profile: %s", NEG_RAG_TOOL_NAME
)
tool = await AsyncTool.create_rag_tool(
tool_name=NEG_RAG_TOOL_NAME,
profile_name="NON_EXISTENT_RAG_PROFILE",
replace=True,
)
yield tool
logger.info("Deleting RAG tool with invalid profile: %s", NEG_RAG_TOOL_NAME)
logger.info(
"Deleting RAG tool with invalid profile: %s", NEG_RAG_TOOL_NAME
)
await tool.delete(force=True)


Expand Down Expand Up @@ -470,33 +477,57 @@ async def test_3006_enable_disable_sql_tool(sql_tool):


async def test_3007_web_search_tool_created(web_search_tool):
logger.info("Validating Web Search tool creation: %s", WEB_SEARCH_TOOL_NAME)
logger.info(
"Validating Web Search tool creation: %s", WEB_SEARCH_TOOL_NAME
)
log_tool_details("test_3007_web_search_tool_created", web_search_tool)
assert web_search_tool.tool_name == WEB_SEARCH_TOOL_NAME
assert web_search_tool.attributes.tool_type == select_ai.agent.ToolType.WEBSEARCH
assert web_search_tool.attributes.tool_params.credential_name == "OPENAI_CRED"
assert (
web_search_tool.attributes.tool_type
== select_ai.agent.ToolType.WEBSEARCH
)
assert (
web_search_tool.attributes.tool_params.credential_name == "OPENAI_CRED"
)


async def test_3008_email_tool_created(email_tool):
logger.info("Validating EMAIL tool creation: %s", EMAIL_TOOL_NAME)
log_tool_details("test_3008_email_tool_created", email_tool)
assert email_tool.tool_name == EMAIL_TOOL_NAME
assert str(email_tool.attributes.tool_type).upper() in ("EMAIL", "NOTIFICATION")
assert str(email_tool.attributes.tool_type).upper() in (
"EMAIL",
"NOTIFICATION",
)
assert email_tool.attributes.tool_params.credential_name == EMAIL_CRED_NAME
assert email_tool.attributes.tool_params.smtp_host is not None
assert str(email_tool.attributes.tool_params.notification_type).lower() == "email"
assert (
str(email_tool.attributes.tool_params.notification_type).lower()
== "email"
)


async def test_3009_slack_tool_created(slack_tool):
logger.info("Validating SLACK tool creation: %s", SLACK_TOOL_NAME)
if slack_tool is not None:
log_tool_details("test_3009_slack_tool_created", slack_tool)
assert slack_tool.tool_name == SLACK_TOOL_NAME
assert str(slack_tool.attributes.tool_type).upper() in ("SLACK", "NOTIFICATION")
assert slack_tool.attributes.tool_params.credential_name == SLACK_CRED_NAME
assert str(slack_tool.attributes.tool_params.notification_type).lower() == "slack"
assert str(slack_tool.attributes.tool_type).upper() in (
"SLACK",
"NOTIFICATION",
)
assert (
slack_tool.attributes.tool_params.credential_name
== SLACK_CRED_NAME
)
assert (
str(slack_tool.attributes.tool_params.notification_type).lower()
== "slack"
)
else:
logger.info("SLACK tool not created due to expected backend-side error")
logger.info(
"SLACK tool not created due to expected backend-side error"
)


async def test_3010_custom_tool_attributes_roundtrip():
Expand Down Expand Up @@ -536,7 +567,10 @@ async def test_3010_custom_tool_attributes_roundtrip():
)
assert isinstance(fetched.attributes.tool_inputs, list)
assert fetched.attributes.tool_inputs[0]["name"] == "p_birth_date"
assert "birth date" in fetched.attributes.tool_inputs[0]["description"].lower()
assert (
"birth date"
in fetched.attributes.tool_inputs[0]["description"].lower()
)
finally:
await tool.delete(force=True)

Expand Down Expand Up @@ -565,12 +599,16 @@ async def test_3011_custom_tool_without_tool_type():
assert fetched.tool_name == CUSTOM_NO_TYPE_TOOL_NAME
assert fetched.attributes.tool_type is None
assert fetched.attributes.function == PLSQL_FUNCTION_NAME
assert fetched.attributes.instruction == "Calculate age from birth date"
assert (
fetched.attributes.instruction == "Calculate age from birth date"
)
finally:
await tool.delete(force=True)


async def test_3012_custom_tool_with_tool_type_without_instruction(sql_profile):
async def test_3012_custom_tool_with_tool_type_without_instruction(
sql_profile,
):
logger.info(
"Validating custom tool creation with tool_type set and instruction unset"
)
Expand Down Expand Up @@ -631,15 +669,22 @@ async def test_3013_custom_tool_with_tool_type_and_instruction(sql_profile):

async def test_3014_sql_tool_with_invalid_profile_created(neg_sql_tool):
logger.info("Validating SQL tool with invalid profile")
log_tool_details("test_3010_sql_tool_with_invalid_profile_created", neg_sql_tool)
log_tool_details(
"test_3010_sql_tool_with_invalid_profile_created", neg_sql_tool
)
assert neg_sql_tool.tool_name == NEG_SQL_TOOL_NAME
assert neg_sql_tool.attributes.tool_type == select_ai.agent.ToolType.SQL
assert neg_sql_tool.attributes.tool_params.profile_name == "NON_EXISTENT_PROFILE"
assert (
neg_sql_tool.attributes.tool_params.profile_name
== "NON_EXISTENT_PROFILE"
)


async def test_3015_rag_tool_with_invalid_profile_created(neg_rag_tool):
logger.info("Validating RAG tool with invalid profile")
log_tool_details("test_3011_rag_tool_with_invalid_profile_created", neg_rag_tool)
log_tool_details(
"test_3011_rag_tool_with_invalid_profile_created", neg_rag_tool
)
assert neg_rag_tool.tool_name == NEG_RAG_TOOL_NAME
assert neg_rag_tool.attributes.tool_type == select_ai.agent.ToolType.RAG
assert (
Expand Down Expand Up @@ -691,12 +736,16 @@ async def test_3020_create_tool_default_status_enabled(sql_profile):
tool = await AsyncTool.create_built_in_tool(
tool_name=DEFAULT_STATUS_TOOL_NAME,
tool_type=select_ai.agent.ToolType.SQL,
tool_params=select_ai.agent.SQLToolParams(profile_name=SQL_PROFILE_NAME),
tool_params=select_ai.agent.SQLToolParams(
profile_name=SQL_PROFILE_NAME
),
)
try:
await assert_tool_status(DEFAULT_STATUS_TOOL_NAME, "ENABLED")
fetched = await AsyncTool.fetch(DEFAULT_STATUS_TOOL_NAME)
log_tool_details("test_3016_create_tool_default_status_enabled", fetched)
log_tool_details(
"test_3016_create_tool_default_status_enabled", fetched
)
logger.info(
"Fetched created tool | name=%s | type=%s | profile=%s",
fetched.tool_name,
Expand Down Expand Up @@ -771,7 +820,9 @@ async def test_3024_http_tool_created(email_credential):
fetched = await AsyncTool.fetch(HTTP_TOOL_NAME)
assert fetched.tool_name == HTTP_TOOL_NAME
assert fetched.attributes.tool_type == select_ai.agent.ToolType.HTTP
assert fetched.attributes.tool_params.credential_name == email_credential
assert (
fetched.attributes.tool_params.credential_name == email_credential
)
assert fetched.attributes.tool_params.endpoint == HTTP_ENDPOINT
finally:
logger.info("Deleting HTTP tool: %s", HTTP_TOOL_NAME)
Expand Down
Loading