From f8435418705f06e8da1a7ab7608507b614f24f05 Mon Sep 17 00:00:00 2001 From: eavanvalkenburg Date: Tue, 24 Feb 2026 21:42:28 +0100 Subject: [PATCH 1/2] Migrate agentserver adapter to agent-framework rc1 Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../ai/agentserver/agentframework/__init__.py | 45 +- .../agentframework/_agent_framework.py | 49 +- .../agentframework/_ai_agent_adapter.py | 26 +- .../agentframework/_foundry_tools.py | 32 +- .../agentframework/_workflow_agent_adapter.py | 37 +- .../agent_framework_input_converters.py | 285 +++--- ...ramework_output_non_streaming_converter.py | 188 ++-- ...nt_framework_output_streaming_converter.py | 871 ++++++++++-------- .../models/agent_id_generator.py | 46 +- .../agentframework/models/constants.py | 18 +- .../models/conversation_converters.py | 93 +- .../models/human_in_the_loop_helper.py | 200 ++-- .../agentframework/models/utils/async_iter.py | 4 +- .../agentframework/persistence/__init__.py | 8 +- .../_foundry_checkpoint_repository.py | 6 +- .../_foundry_checkpoint_storage.py | 38 +- .../_foundry_conversation_message_store.py | 85 +- ..._foundry_conversation_thread_repository.py | 44 +- .../persistence/agent_thread_repository.py | 65 +- .../persistence/checkpoint_repository.py | 5 +- .../pyproject.toml | 6 +- .../samples/basic_simple/minimal_example.py | 4 +- .../chat_client_with_foundry_tool/README.md | 4 +- .../chat_client_with_foundry_tool.py | 26 +- .../human_in_the_loop_ai_function/README.md | 2 +- .../human_in_the_loop_ai_function/main.py | 55 +- .../human_in_the_loop_workflow_agent/main.py | 49 +- .../workflow_as_agent_reflection_pattern.py | 64 +- .../samples/mcp_apikey/mcp_apikey.py | 6 +- .../samples/mcp_simple/mcp_simple.py | 4 +- .../simple_async/minimal_async_example.py | 2 +- .../workflow_agent_simple.py | 312 ++++++- .../workflow_with_foundry_checkpoints/main.py | 43 +- .../tests/unit_tests/conftest.py | 3 +- .../test_agent_framework_input_converter.py | 65 +- 
.../test_conversation_id_optional.py | 4 +- .../test_conversation_item_converter.py | 29 +- .../test_foundry_checkpoint_storage.py | 46 +- .../tests/unit_tests/test_foundry_tools.py | 46 +- .../test_from_agent_framework_managed.py | 47 +- .../test_human_in_the_loop_helper.py | 106 +-- 41 files changed, 1630 insertions(+), 1438 deletions(-) diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/__init__.py b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/__init__.py index b256a564ae75..1b9e6b009a40 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/__init__.py +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/__init__.py @@ -4,34 +4,39 @@ # pylint: disable=docstring-should-be-keyword __path__ = __import__("pkgutil").extend_path(__path__, __name__) -from typing import Callable, Optional, Union, overload +from typing import TYPE_CHECKING, Callable, Optional, Union, overload -from agent_framework import AgentProtocol, BaseAgent, Workflow, WorkflowBuilder -from azure.core.credentials_async import AsyncTokenCredential -from azure.core.credentials import TokenCredential +from agent_framework import BaseAgent, SupportsAgentRun, Workflow, WorkflowBuilder -from azure.ai.agentserver.core.application import PackageMetadata, set_current_app # pylint: disable=import-error,no-name-in-module +from azure.ai.agentserver.core.application import ( # pylint: disable=import-error,no-name-in-module + PackageMetadata, + set_current_app, +) +from azure.core.credentials import TokenCredential +from azure.core.credentials_async import AsyncTokenCredential -from ._version import VERSION -from ._agent_framework import AgentFrameworkAgent -from ._ai_agent_adapter import AgentFrameworkAIAgentAdapter -from ._workflow_agent_adapter import AgentFrameworkWorkflowAdapter from ._foundry_tools import 
FoundryToolsChatMiddleware +from ._version import VERSION from .persistence import AgentThreadRepository, CheckpointRepository +if TYPE_CHECKING: + from ._agent_framework import AgentFrameworkAgent + from ._ai_agent_adapter import AgentFrameworkAIAgentAdapter + from ._workflow_agent_adapter import AgentFrameworkWorkflowAdapter + @overload def from_agent_framework( - agent: Union[BaseAgent, AgentProtocol], + agent: Union[BaseAgent, SupportsAgentRun], /, credentials: Optional[Union[AsyncTokenCredential, TokenCredential]] = None, thread_repository: Optional[AgentThreadRepository]=None ) -> "AgentFrameworkAIAgentAdapter": """ - Create an Agent Framework AI Agent Adapter from an AgentProtocol or BaseAgent. + Create an Agent Framework AI Agent Adapter from a SupportsAgentRun or BaseAgent. :param agent: The agent to adapt. - :type agent: Union[BaseAgent, AgentProtocol] + :type agent: Union[BaseAgent, SupportsAgentRun] :param credentials: Optional asynchronous token credential for authentication. :type credentials: Optional[Union[AsyncTokenCredential, TokenCredential]] :param thread_repository: Optional thread repository for agent thread management. @@ -75,19 +80,19 @@ def from_agent_framework( ... def from_agent_framework( - agent_or_workflow: Union[BaseAgent, AgentProtocol, WorkflowBuilder, Callable[[], Workflow]], + agent_or_workflow: Union[BaseAgent, SupportsAgentRun, WorkflowBuilder, Callable[[], Workflow]], /, credentials: Optional[Union[AsyncTokenCredential, TokenCredential]] = None, thread_repository: Optional[AgentThreadRepository] = None, checkpoint_repository: Optional[CheckpointRepository] = None, ) -> "AgentFrameworkAgent": """ - Create an Agent Framework Adapter from either an AgentProtocol/BaseAgent or a + Create an Agent Framework Adapter from either a SupportsAgentRun/BaseAgent or a WorkflowAgent. One of agent or workflow must be provided. :param agent_or_workflow: The agent to adapt. 
- :type agent_or_workflow: Optional[Union[BaseAgent, AgentProtocol]] + :type agent_or_workflow: Optional[Union[BaseAgent, SupportsAgentRun]] :param credentials: Optional asynchronous token credential for authentication. :type credentials: Optional[Union[AsyncTokenCredential, TokenCredential]] :param thread_repository: Optional thread repository for agent thread management. @@ -103,6 +108,8 @@ def from_agent_framework( """ if isinstance(agent_or_workflow, WorkflowBuilder): + from ._workflow_agent_adapter import AgentFrameworkWorkflowAdapter + return AgentFrameworkWorkflowAdapter( workflow_factory=agent_or_workflow.build, credentials=credentials, @@ -110,6 +117,8 @@ def from_agent_framework( checkpoint_repository=checkpoint_repository, ) if isinstance(agent_or_workflow, Callable): # type: ignore + from ._workflow_agent_adapter import AgentFrameworkWorkflowAdapter + return AgentFrameworkWorkflowAdapter( workflow_factory=agent_or_workflow, credentials=credentials, @@ -118,12 +127,14 @@ def from_agent_framework( ) # raise TypeError("workflow must be a WorkflowBuilder or callable returning a Workflow") - if isinstance(agent_or_workflow, (AgentProtocol, BaseAgent)): + if isinstance(agent_or_workflow, (SupportsAgentRun, BaseAgent)): + from ._ai_agent_adapter import AgentFrameworkAIAgentAdapter + return AgentFrameworkAIAgentAdapter(agent_or_workflow, credentials=credentials, thread_repository=thread_repository) raise TypeError("You must provide one of the instances of type " - "[AgentProtocol, BaseAgent, WorkflowBuilder or callable returning a Workflow]") + "[SupportsAgentRun, BaseAgent, WorkflowBuilder or callable returning a Workflow]") __all__ = [ diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_agent_framework.py b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_agent_framework.py index 5f507ea95061..f62b89e0d092 100644 --- 
a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_agent_framework.py +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_agent_framework.py @@ -5,10 +5,9 @@ from __future__ import annotations import os -from typing import Any, AsyncGenerator, Optional, TYPE_CHECKING, Union, Callable +from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Optional, Union -from agent_framework import AgentProtocol, AgentThread, WorkflowAgent -from agent_framework.azure import AzureAIClient # pylint: disable=no-name-in-module +from agent_framework import AgentSession, SupportsAgentRun, WorkflowAgent from opentelemetry import trace from azure.ai.agentserver.core import AgentRunContext, FoundryCBAgent @@ -23,8 +22,8 @@ from .models.agent_framework_output_streaming_converter import AgentFrameworkOutputStreamingConverter from .models.human_in_the_loop_helper import HumanInTheLoopHelper -from .persistence._foundry_conversation_thread_repository import FoundryConversationThreadRepository from .persistence import AgentThreadRepository +from .persistence._foundry_conversation_thread_repository import FoundryConversationThreadRepository if TYPE_CHECKING: from azure.core.credentials_async import AsyncTokenCredential @@ -36,12 +35,12 @@ class AgentFrameworkAgent(FoundryCBAgent): """ Adapter class for integrating Agent Framework agents with the FoundryCB agent interface. - This class wraps an Agent Framework `AgentProtocol` instance and provides a unified interface + This class wraps an Agent Framework `SupportsAgentRun` instance and provides a unified interface for running agents in both streaming and non-streaming modes. It handles input and output conversion between the Agent Framework and the expected formats for FoundryCB agents. Parameters: - agent (AgentProtocol): An instance of an Agent Framework agent to be adapted. 
+ agent (SupportsAgentRun): An instance of an Agent Framework agent to be adapted. Usage: - Instantiate with an Agent Framework agent. @@ -54,7 +53,7 @@ def __init__(self, thread_repository: Optional[AgentThreadRepository] = None, project_endpoint: Optional[str] = None, **kwargs) -> None: - """Initialize the AgentFrameworkAgent with an AgentProtocol. + """Initialize the AgentFrameworkAgent with a SupportsAgentRun-compatible agent adapter. :param credentials: Azure credentials for authentication. :type credentials: Optional[AsyncTokenCredential] @@ -149,6 +148,8 @@ def _try_import_configure_otel_providers(self): def _setup_tracing_with_azure_ai_client(self, project_endpoint: str): async def setup_async(): + from agent_framework.azure import AzureAIClient # pylint: disable=import-outside-toplevel,no-name-in-module + async with AzureAIClient( project_endpoint=project_endpoint, async_credential=self.credentials, @@ -180,17 +181,17 @@ async def agent_run( # pylint: disable=too-many-statements async def _load_agent_thread( self, context: AgentRunContext, - agent: Union[AgentProtocol, WorkflowAgent], - ) -> Optional[AgentThread]: - """Load the agent thread for a given conversation ID. + agent: Union[SupportsAgentRun, WorkflowAgent], + ) -> Optional[AgentSession]: + """Load the agent session for a given conversation ID. :param context: The agent run context. :type context: AgentRunContext :param agent: The agent instance. - :type agent: AgentProtocol | WorkflowAgent + :type agent: SupportsAgentRun | WorkflowAgent - :return: The loaded AgentThread if available, None otherwise. - :rtype: Optional[AgentThread] + :return: The loaded AgentSession if available, None otherwise. 
+ :rtype: Optional[AgentSession] """ if self._thread_repository and context.conversation_id: conversation_id = context.conversation_id @@ -198,16 +199,16 @@ async def _load_agent_thread( if agent_thread: logger.info(f"Loaded agent thread for conversation: {conversation_id}") return agent_thread - return agent.get_new_thread() + return agent.create_session() return None - async def _save_agent_thread(self, context: AgentRunContext, agent_thread: AgentThread) -> None: - """Save the agent thread for a given conversation ID. + async def _save_agent_thread(self, context: AgentRunContext, agent_thread: AgentSession) -> None: + """Save the agent session for a given conversation ID. :param context: The agent run context. :type context: AgentRunContext - :param agent_thread: The agent thread to save. - :type agent_thread: AgentThread + :param agent_thread: The agent session to save. + :type agent_thread: AgentSession :return: None :rtype: None @@ -219,18 +220,18 @@ async def _save_agent_thread(self, context: AgentRunContext, agent_thread: Agent def _run_streaming_updates( self, context: AgentRunContext, - run_stream: Callable[[], AsyncGenerator[Any, None]], - agent_thread: Optional[AgentThread] = None, + stream_runner: Callable[[], AsyncGenerator[Any, None]], + agent_thread: Optional[AgentSession] = None, ) -> AsyncGenerator[ResponseStreamEvent, Any]: """ Execute a streaming run with shared OAuth/error handling. :param context: The agent run context. :type context: AgentRunContext - :param run_stream: A callable that invokes the agent in stream mode - :type run_stream: Callable[[], AsyncGenerator[Any, None]] + :param stream_runner: A callable that invokes the agent in stream mode + :type stream_runner: Callable[[], AsyncGenerator[Any, None]] :param agent_thread: The agent thread to use during streaming updates. - :type agent_thread: Optional[AgentThread] + :type agent_thread: Optional[AgentSession] :return: An async generator yielding streaming events. 
:rtype: AsyncGenerator[ResponseStreamEvent, Any] @@ -245,7 +246,7 @@ async def stream_updates(): try: update_count = 0 try: - updates = run_stream() + updates = stream_runner() async for event in streaming_converter.convert(updates): update_count += 1 yield event diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_ai_agent_adapter.py b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_ai_agent_adapter.py index 0bb85bea5e35..4fc722d89bdd 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_ai_agent_adapter.py +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_ai_agent_adapter.py @@ -6,29 +6,29 @@ from typing import Any, AsyncGenerator, Optional, Union -from agent_framework import AgentProtocol -from azure.core.credentials import TokenCredential -from azure.core.credentials_async import AsyncTokenCredential +from agent_framework import SupportsAgentRun from azure.ai.agentserver.core import AgentRunContext -from azure.ai.agentserver.core.tools import OAuthConsentRequiredError from azure.ai.agentserver.core.logger import get_logger from azure.ai.agentserver.core.models import ( Response as OpenAIResponse, ResponseStreamEvent, ) +from azure.ai.agentserver.core.tools import OAuthConsentRequiredError +from azure.core.credentials import TokenCredential +from azure.core.credentials_async import AsyncTokenCredential -from .models.agent_framework_input_converters import AgentFrameworkInputConverter +from ._agent_framework import AgentFrameworkAgent +from .models.agent_framework_input_converters import transform_input from .models.agent_framework_output_non_streaming_converter import ( AgentFrameworkOutputNonStreamingConverter, ) -from ._agent_framework import AgentFrameworkAgent from .persistence import AgentThreadRepository logger = get_logger() class 
AgentFrameworkAIAgentAdapter(AgentFrameworkAgent): - def __init__(self, agent: AgentProtocol, + def __init__(self, agent: SupportsAgentRun, credentials: Optional[Union[AsyncTokenCredential, TokenCredential]] = None, thread_repository: Optional[AgentThreadRepository] = None, *, @@ -49,10 +49,7 @@ async def agent_run( # pylint: disable=too-many-statements agent_thread = await self._load_agent_thread(context, self._agent) - input_converter = AgentFrameworkInputConverter(hitl_helper=self._hitl_helper) - message = await input_converter.transform_input( - request_input, - agent_thread=agent_thread) + message = transform_input(request_input) logger.debug("Transformed input message type: %s", type(message)) # Attach per-request context to the agent instance so tools can access it @@ -68,9 +65,10 @@ async def agent_run( # pylint: disable=too-many-statements if context.stream: return self._run_streaming_updates( context=context, - run_stream=lambda: self._agent.run_stream( + stream_runner=lambda: self._agent.run( message, - thread=agent_thread, + session=agent_thread, + stream=True, ), agent_thread=agent_thread, ) @@ -79,7 +77,7 @@ async def agent_run( # pylint: disable=too-many-statements logger.info("Running agent in non-streaming mode") result = await self._agent.run( message, - thread=agent_thread) + session=agent_thread) logger.debug("Agent run completed, result type: %s", type(result)) await self._save_agent_thread(context, agent_thread) diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_foundry_tools.py b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_foundry_tools.py index f936d32e4ec1..7818c81f538e 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_foundry_tools.py +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_foundry_tools.py @@ -8,7 +8,7 @@ import inspect from typing import 
Any, Awaitable, Callable, Dict, List, Optional, Sequence -from agent_framework import AIFunction, ChatContext, ChatOptions, ChatMiddleware +from agent_framework import ChatContext, ChatMiddleware, ChatOptions, FunctionTool from pydantic import Field, create_model from azure.ai.agentserver.core import AgentServerContext @@ -47,19 +47,19 @@ def __init__( ) -> None: self._allowed_tools: List[FoundryToolLike] = [ensure_foundry_tool(tool) for tool in tools] - async def list_tools(self) -> List[AIFunction]: + async def list_tools(self) -> List[FunctionTool]: server_context = AgentServerContext.get() foundry_tool_catalog = server_context.tools.catalog resolved_tools = await foundry_tool_catalog.list(self._allowed_tools) return [self._to_aifunction(tool) for tool in resolved_tools] - def _to_aifunction(self, foundry_tool: "ResolvedFoundryTool") -> AIFunction: - """Convert an FoundryTool to an Agent Framework AI Function + def _to_aifunction(self, foundry_tool: "ResolvedFoundryTool") -> FunctionTool: + """Convert an FoundryTool to an Agent Framework Function Tool :param foundry_tool: The FoundryTool to convert. :type foundry_tool: ~azure.ai.agentserver.core.client.tools.aio.FoundryTool - :return: An AI Function Tool. - :rtype: AIFunction + :return: A Function Tool. 
+ :rtype: FunctionTool """ # Get the input schema from the tool descriptor input_schema = foundry_tool.input_schema or {} @@ -103,8 +103,8 @@ async def tool_func(**kwargs: Any) -> Any: return await server_context.tools.invoke(foundry_tool, kwargs) _attach_signature_from_pydantic_model(tool_func, input_model) - # Create and return the AIFunction - return AIFunction( + # Create and return the FunctionTool + return FunctionTool( name=foundry_tool.name, description=foundry_tool.description or "No description available", func=tool_func, @@ -123,16 +123,20 @@ def __init__( async def process( self, context: ChatContext, - next: Callable[[ChatContext], Awaitable[None]], + call_next: Callable[[], Awaitable[None]], ) -> None: tools = await self._foundry_tool_client.list_tools() base_chat_options = context.chat_options if not base_chat_options: logger.debug("No existing ChatOptions found, creating new one with Foundry tools.") - base_chat_options = ChatOptions(tools=tools) - context.chat_options = base_chat_options + context.chat_options = ChatOptions(tools=tools) else: logger.debug("Adding Foundry tools to existing ChatOptions.") - base_tools = base_chat_options.tools or [] - context.chat_options.tools = base_tools + tools - await next(context) + if isinstance(base_chat_options, dict): + base_tools = base_chat_options.get("tools") or [] + base_chat_options["tools"] = base_tools + tools + context.chat_options = base_chat_options + else: + base_tools = base_chat_options.tools or [] + context.chat_options.tools = base_tools + tools + await call_next() diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_workflow_agent_adapter.py b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_workflow_agent_adapter.py index f91d4c0849ca..1eccdaf8859f 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_workflow_agent_adapter.py +++ 
b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_workflow_agent_adapter.py @@ -10,10 +10,7 @@ Union, ) -from agent_framework import Workflow, CheckpointStorage, WorkflowAgent, WorkflowCheckpoint -from agent_framework._workflows import get_checkpoint_summary -from azure.core.credentials import TokenCredential -from azure.core.credentials_async import AsyncTokenCredential +from agent_framework import CheckpointStorage, Workflow, WorkflowAgent, WorkflowCheckpoint from azure.ai.agentserver.core import AgentRunContext from azure.ai.agentserver.core.logger import get_logger @@ -22,9 +19,11 @@ ResponseStreamEvent, ) from azure.ai.agentserver.core.tools import OAuthConsentRequiredError +from azure.core.credentials import TokenCredential +from azure.core.credentials_async import AsyncTokenCredential from ._agent_framework import AgentFrameworkAgent -from .models.agent_framework_input_converters import AgentFrameworkInputConverter +from .models.agent_framework_input_converters import transform_input from .models.agent_framework_output_non_streaming_converter import ( AgentFrameworkOutputNonStreamingConverter, ) @@ -69,8 +68,8 @@ async def agent_run( # pylint: disable=too-many-statements if checkpoint_storage: selected_checkpoint = await self._get_latest_checkpoint(checkpoint_storage) if selected_checkpoint: - summary = get_checkpoint_summary(selected_checkpoint) - if summary.status == "completed": + checkpoint_status = self._checkpoint_status(selected_checkpoint) + if checkpoint_status == "completed": logger.warning( "Selected checkpoint %s is completed. 
Will not resume from it.", selected_checkpoint.checkpoint_id, @@ -80,21 +79,18 @@ async def agent_run( # pylint: disable=too-many-statements await self._load_checkpoint(agent, selected_checkpoint, checkpoint_storage) logger.info("Loaded checkpoint with ID: %s", selected_checkpoint.checkpoint_id) - input_converter = AgentFrameworkInputConverter(hitl_helper=self._hitl_helper) - message = await input_converter.transform_input( - request_input, - agent_thread=agent_thread, - checkpoint=selected_checkpoint) + message = transform_input(request_input) logger.debug("Transformed input message type: %s", type(message)) # Use split converters if context.stream: return self._run_streaming_updates( context=context, - run_stream=lambda: agent.run_stream( + stream_runner=lambda: agent.run( message, - thread=agent_thread, + session=agent_thread, checkpoint_storage=checkpoint_storage, + stream=True, ), agent_thread=agent_thread, ) @@ -103,7 +99,7 @@ async def agent_run( # pylint: disable=too-many-statements logger.info("Running WorkflowAgent in non-streaming mode") result = await agent.run( message, - thread=agent_thread, + session=agent_thread, checkpoint_storage=checkpoint_storage) logger.debug("WorkflowAgent run completed, result type: %s", type(result)) @@ -129,6 +125,17 @@ async def oauth_consent_stream(error=e): def _build_agent(self) -> WorkflowAgent: return self._workflow_factory().as_agent() + + def _checkpoint_status(self, checkpoint: WorkflowCheckpoint) -> Optional[str]: + status = getattr(checkpoint, "status", None) + if status: + return status + metadata = getattr(checkpoint, "metadata", None) + if isinstance(metadata, dict): + value = metadata.get("status") + return str(value) if value is not None else None + return None + async def _get_latest_checkpoint(self, checkpoint_storage: CheckpointStorage) -> Optional[Any]: """Load the latest checkpoint from the given storage. 
diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/agent_framework_input_converters.py b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/agent_framework_input_converters.py index a21e5b9c44c7..3d53657b5f73 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/agent_framework_input_converters.py +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/agent_framework_input_converters.py @@ -1,182 +1,139 @@ # --------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # --------------------------------------------------------- -# pylint: disable=too-many-nested-blocks,too-many-return-statements,too-many-branches # mypy: disable-error-code="no-redef" from __future__ import annotations -from typing import Dict, List, Optional +from typing import Dict, List -from agent_framework import ( - AgentThread, - ChatMessage, - RequestInfoEvent, - Role as ChatRole, - WorkflowCheckpoint, -) -from agent_framework._types import TextContent +from agent_framework import Content, Message from azure.ai.agentserver.core.logger import get_logger logger = get_logger() -class AgentFrameworkInputConverter: +def transform_input( # pylint: disable=too-many-return-statements + input_item: str | List[Dict] | None, +) -> str | Message | list[str | Message] | None: """Normalize inputs for agent.run. 
Accepts: str | List | None - Returns: None | str | ChatMessage | list[str] | list[ChatMessage] - """ - def __init__(self, *, hitl_helper=None) -> None: - self._hitl_helper = hitl_helper - - async def transform_input( - self, - input: str | List[Dict] | None, - agent_thread: Optional[AgentThread] = None, - checkpoint: Optional[WorkflowCheckpoint] = None, - ) -> str | ChatMessage | list[str] | list[ChatMessage] | None: - logger.debug("Transforming input of type: %s", type(input)) - - if input is None: - return None + Returns: None | str | Message | list[str] | list[Message] - if isinstance(input, str): - return input - - if self._hitl_helper: - # load pending requests from checkpoint and thread messages if available - thread_messages = [] - if agent_thread and agent_thread.message_store: - thread_messages = await agent_thread.message_store.list_messages() - pending_hitl_requests = self._hitl_helper.get_pending_hitl_request(thread_messages, checkpoint) - if pending_hitl_requests: - logger.info("Pending HitL requests: %s", list(pending_hitl_requests.keys())) - hitl_response = self._hitl_helper.validate_and_convert_hitl_response( - input, - pending_requests=pending_hitl_requests) - if hitl_response: - return hitl_response - - return self._transform_input_internal(input) - - def _transform_input_internal( - self, - input: str | List[Dict] | None, - ) -> str | ChatMessage | list[str] | list[ChatMessage] | None: - try: - if isinstance(input, list): - messages: list[str | ChatMessage] = [] - - for item in input: - # Case 1: ImplicitUserMessage with content as str or list of ItemContentInputText - if self._is_implicit_user_message(item): - content = item.get("content", None) - if isinstance(content, str): - messages.append(content) - elif isinstance(content, list): - text_parts: list[str] = [] - for content_item in content: - text_content = self._extract_input_text(content_item) - if text_content: - text_parts.append(text_content) - if text_parts: - messages.append(" 
".join(text_parts)) - - # Case 2: message params (user/assistant/system) - elif ( - item.get("type") in ("", None, "message") - and item.get("role") is not None - and item.get("content") is not None - ): - role_map = { - "user": ChatRole.USER, - "assistant": ChatRole.ASSISTANT, - "system": ChatRole.SYSTEM, - } - role = role_map.get(item.get("role", "user"), ChatRole.USER) - - content_text = "" - item_content = item.get("content", None) - if item_content and isinstance(item_content, list): - text_parts: list[str] = [] - for content_item in item_content: - item_text = self._extract_input_text(content_item) - if item_text: - text_parts.append(item_text) - content_text = " ".join(text_parts) if text_parts else "" - elif item_content and isinstance(item_content, str): - content_text = str(item_content) - - if content_text: - messages.append(ChatMessage(role=role, text=content_text)) - - # Determine the most natural return type - if not messages: - return None - if len(messages) == 1: - return messages[0] - if all(isinstance(m, str) for m in messages): - return [m for m in messages if isinstance(m, str)] - if all(isinstance(m, ChatMessage) for m in messages): - return [m for m in messages if isinstance(m, ChatMessage)] - - # Mixed content: coerce ChatMessage to str by extracting TextContent parts - result: list[str] = [] - for msg in messages: - if isinstance(msg, ChatMessage): - text_parts: list[str] = [] - for c in getattr(msg, "contents", []) or []: - if isinstance(c, TextContent): - text_parts.append(c.text) - result.append(" ".join(text_parts) if text_parts else str(msg)) - else: - result.append(str(msg)) - return result - - raise TypeError(f"Unsupported input type: {type(input)}") - except Exception as e: - logger.error("Error processing messages: %s", e, exc_info=True) - raise Exception(f"Error processing messages: {e}") from e # pylint: disable=broad-exception-raised - - def _is_implicit_user_message(self, item: Dict) -> bool: - return "content" in item and 
"role" not in item and "type" not in item - - def _extract_input_text(self, content_item: Dict) -> str: - if content_item.get("type") == "input_text" and "text" in content_item: - text_content = content_item.get("text") - if isinstance(text_content, str): - return text_content - return None # type: ignore - - def _validate_and_convert_hitl_response( - self, - pending_request: Dict, - input: List[Dict], - ) -> Optional[List[ChatMessage]]: - if not self._hitl_helper: - logger.warning("HitL helper not provided; cannot validate HitL response.") - return None - if isinstance(input, str): - logger.warning("Expected list input for HitL response validation, got str.") - return None - if not isinstance(input, list) or len(input) != 1: - logger.warning("Expected single-item list input for HitL response validation.") - return None - - item = input[0] - if item.get("type") != "function_call_output": - logger.warning("Expected function_call_output type for HitL response validation.") - return None - call_id = item.get("call_id", None) - if not call_id or call_id not in pending_request: - logger.warning("Function call output missing valid call_id for HitL response validation.") - return None - request_info = pending_request[call_id] - if isinstance(request_info, dict): - request_info = RequestInfoEvent.from_dict(request_info) - if not isinstance(request_info, RequestInfoEvent): - logger.warning("No valid pending request info found for call_id: %s", call_id) + :param input_item: The raw input to normalize. 
+ :type input_item: str or List[Dict] or None + """ + logger.debug("Transforming input of type: %s", type(input_item)) + + if input_item is None: + return None + + if isinstance(input_item, str): + return input_item + + try: + if isinstance(input_item, list): + messages: list[str | Message] = [] + + for item in input_item: + match item: + # Case 1: ImplicitUserMessage — no "role" or "type" key + case {"content": content} if "role" not in item and "type" not in item: + messages.extend(_parse_implicit_user_content(content)) + + # Case 2: Explicit message with role + case {"type": "message", "role": role, "content": content}: + _parse_explicit_message(role, content, messages) + + # Determine the most natural return type + if not messages: + return None + if len(messages) == 1: + return messages[0] + if all(isinstance(m, str) for m in messages): + return [m for m in messages if isinstance(m, str)] + if all(isinstance(m, Message) for m in messages): + return [m for m in messages if isinstance(m, Message)] + + # Mixed content: coerce Message to str by extracting text content parts + return _coerce_to_strings(messages) + + raise TypeError(f"Unsupported input type: {type(input_item)}") + except Exception as e: + logger.debug("Error processing messages: %s", e, exc_info=True) + raise Exception(f"Error processing messages: {e}") from e # pylint: disable=broad-exception-raised + + +def _parse_implicit_user_content(content: str | list | None) -> list[str]: + """Extract text from an implicit user message (no role/type keys). + + :param content: The content to parse. + :type content: str or list or None + :return: A list of extracted text strings. 
+ :rtype: list[str] + """ + match content: + case str(): + return [content] + case list(): + text_parts = [_extract_input_text(item) for item in content] + joined = " ".join(t for t in text_parts if t) + return [joined] if joined else [] + case _: + return [] + + +def _parse_explicit_message(role: str, content: str | list | None, sink: list[str | Message]) -> None: + """Parse an explicit message dict and append to sink. + + :param role: The role of the message sender. + :type role: str + :param content: The message content. + :type content: str or list or None + :param sink: The list to append parsed messages to. + :type sink: list[str | Message] + """ + match role: + case "user" | "assistant" | "system" | "tool": + pass + case _: + raise ValueError(f"Unsupported message role: {role!r}") + + content_text = "" + match content: + case str(): + content_text = content + case list(): + text_parts = [_extract_input_text(item) for item in content] + content_text = " ".join(t for t in text_parts if t) + + if content_text: + sink.append(Message(role=role, contents=[Content.from_text(content_text)])) + + +def _coerce_to_strings(messages: list[str | Message]) -> list[str | Message]: + """Coerce a mixed list of str/Message into all strings. + + :param messages: The mixed list of strings and Messages. + :type messages: list[str | Message] + :return: A list with Messages coerced to strings. 
+ :rtype: list[str | Message] + """ + result: list[str | Message] = [] + for msg in messages: + match msg: + case Message(): + text_parts = [c.text for c in (getattr(msg, "contents", None) or []) if c.type == "text"] + result.append(" ".join(text_parts) if text_parts else str(msg)) + case str(): + result.append(msg) + return result + + +def _extract_input_text(content_item: Dict) -> str | None: + match content_item: + case {"type": "input_text", "text": str() as text}: + return text + case _: return None - - return self._hitl_helper.convert_response(request_info, item) diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/agent_framework_output_non_streaming_converter.py b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/agent_framework_output_non_streaming_converter.py index 4984b2fc0423..a8849dd684b4 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/agent_framework_output_non_streaming_converter.py +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/agent_framework_output_non_streaming_converter.py @@ -5,37 +5,31 @@ import datetime import json -from typing import Any, List +from typing import Any, Dict, List -from agent_framework import ( - AgentRunResponse, - FunctionCallContent, - FunctionResultContent, - ErrorContent, - TextContent, -) -from agent_framework._types import UserInputRequestContents +from agent_framework import AgentResponse, Content from azure.ai.agentserver.core import AgentRunContext from azure.ai.agentserver.core.logger import get_logger from azure.ai.agentserver.core.models import Response as OpenAIResponse -from azure.ai.agentserver.core.models.projects import ItemContentOutputText +from azure.ai.agentserver.core.models.projects import ( + ItemContentOutputText, + ResponsesAssistantMessageItemResource, +) -from .agent_id_generator import 
AgentIdGenerator -from .constants import Constants -from .human_in_the_loop_helper import HumanInTheLoopHelper +from . import constants +from .agent_id_generator import generate_agent_id logger = get_logger() class AgentFrameworkOutputNonStreamingConverter: # pylint: disable=name-too-long - """Non-streaming converter: AgentRunResponse -> OpenAIResponse.""" + """Non-streaming converter: AgentResponse -> OpenAIResponse.""" - def __init__(self, context: AgentRunContext, *, hitl_helper: HumanInTheLoopHelper=None): + def __init__(self, context: AgentRunContext): self._context = context self._response_id = None self._response_created_at = None - self._hitl_helper = hitl_helper def _ensure_response_started(self) -> None: if not self._response_id: @@ -46,31 +40,23 @@ def _ensure_response_started(self) -> None: def _build_item_content_output_text(self, text: str) -> ItemContentOutputText: return ItemContentOutputText(text=text, annotations=[]) - def _build_created_by(self, author_name: str) -> dict: - self._ensure_response_started() - - agent_dict = { - "type": "agent_id", - "name": author_name or "", - "version": "", # Default to empty string - } - - return { - "agent": agent_dict, - "response_id": self._response_id, - } + def _new_assistant_message_item(self, message_text: str) -> ResponsesAssistantMessageItemResource: + item_content = self._build_item_content_output_text(message_text) + return ResponsesAssistantMessageItemResource( + id=self._context.id_generator.generate_message_id(), status="completed", content=[item_content] + ) - def transform_output_for_response(self, response: AgentRunResponse) -> OpenAIResponse: + def transform_output_for_response(self, response: AgentResponse) -> OpenAIResponse: """Build an OpenAIResponse capturing all supported content types. Previously this method only emitted text message items. 
We now also capture: - - FunctionCallContent -> function_call output item - - FunctionResultContent -> function_call_output item + - function_call content -> function_call output item + - function_result content -> function_call_output item to stay aligned with the streaming converter so no output is lost. - :param response: The AgentRunResponse from the agent framework. - :type response: AgentRunResponse + :param response: The AgentResponse from the agent framework. + :type response: AgentResponse :return: The constructed OpenAIResponse. :rtype: OpenAIResponse @@ -85,11 +71,9 @@ def transform_output_for_response(self, response: AgentRunResponse) -> OpenAIRes contents = getattr(message, "contents", None) if not contents: continue - # Extract author_name from this message - msg_author_name = getattr(message, "author_name", None) or "" for j, content in enumerate(contents): logger.debug(" content index=%d in message=%d type=%s", j, i, type(content).__name__) - self._append_content_item(content, completed_items, msg_author_name) + self._append_content_item(content, completed_items) response_data = self._construct_response_data(completed_items) openai_response = OpenAIResponse(response_data) @@ -102,36 +86,27 @@ def transform_output_for_response(self, response: AgentRunResponse) -> OpenAIRes # ------------------------- helper append methods ------------------------- - def _append_content_item(self, content: Any, sink: List[dict], author_name: str) -> None: + def _append_content_item(self, content: Any, sink: List[dict]) -> None: """Dispatch a content object to the appropriate append helper. - Adding this indirection keeps the main transform method compact and makes it - simpler to extend with new content types later. - - :param content: The content object to append. + :param content: The content object to dispatch. :type content: Any - :param sink: The list to append the converted content dict to. + :param sink: The list to append items to. 
:type sink: List[dict] - :param author_name: The author name for the created_by field. - :type author_name: str - - :return: None - :rtype: None """ - if isinstance(content, TextContent): - self._append_text_content(content, sink, author_name) - elif isinstance(content, FunctionCallContent): - self._append_function_call_content(content, sink, author_name) - elif isinstance(content, FunctionResultContent): - self._append_function_result_content(content, sink, author_name) - elif isinstance(content, UserInputRequestContents): - self._append_user_input_request_contents(content, sink, author_name) - elif isinstance(content, ErrorContent): - raise ValueError(f"ErrorContent received: code={content.error_code}, message={content.message}") - else: - logger.debug("unsupported content type skipped: %s", type(content).__name__) - - def _append_text_content(self, content: TextContent, sink: List[dict], author_name: str) -> None: + match content.type: + case "text" | "text_reasoning": + self._append_text_content(content, sink) + case "function_call": + self._append_function_call_content(content, sink) + case "function_result": + self._append_function_result_content(content, sink) + case "usage": + logger.debug("Skipping usage content (input/output token counts)") + case _: + logger.warning("Unhandled content type in non-streaming: %s", content.type) + + def _append_text_content(self, content: Content, sink: List[dict]) -> None: text_value = getattr(content, "text", None) if not text_value: return @@ -150,12 +125,11 @@ def _append_text_content(self, content: TextContent, sink: List[dict], author_na "logprobs": [], } ], - "created_by": self._build_created_by(author_name), } ) logger.debug(" added message item id=%s text_len=%d", item_id, len(text_value)) - def _append_function_call_content(self, content: FunctionCallContent, sink: List[dict], author_name: str) -> None: + def _append_function_call_content(self, content: Content, sink: List[dict]) -> None: name = getattr(content, 
"name", "") or "" arguments = getattr(content, "arguments", "") if not isinstance(arguments, str): @@ -173,7 +147,6 @@ def _append_function_call_content(self, content: FunctionCallContent, sink: List "call_id": call_id, "name": name, "arguments": arguments or "", - "created_by": self._build_created_by(author_name), } ) logger.debug( @@ -184,20 +157,15 @@ def _append_function_call_content(self, content: FunctionCallContent, sink: List len(arguments or ""), ) - def _append_function_result_content( - self, - content: FunctionResultContent, - sink: List[dict], - author_name: str, - ) -> None: + def _append_function_result_content(self, content: Content, sink: List[dict]) -> None: # Coerce the function result into a simple display string. - result = [] raw = getattr(content, "result", None) - if isinstance(raw, str): - result = [raw] - elif isinstance(raw, list): - for item in raw: - result.append(self._coerce_result_text(item)) # type: ignore + result: list[str | dict[str, Any]] = [] + match raw: + case str(): + result = [raw] + case list(): + result = [self._coerce_result_text(item) for item in raw] call_id = getattr(content, "call_id", None) or "" func_out_id = self._context.id_generator.generate_function_output_id() sink.append( @@ -207,76 +175,46 @@ def _append_function_result_content( "status": "completed", "call_id": call_id, "output": json.dumps(result) if len(result) > 0 else "", - "created_by": self._build_created_by(author_name), } ) logger.debug( - "added function_call_output item id=%s call_id=%s " - "output_len=%d", + "added function_call_output item id=%s call_id=%s output_len=%d", func_out_id, call_id, len(result), ) - def _append_user_input_request_contents( - self, - content: UserInputRequestContents, - sink: List[dict], - author_name: str, - ) -> None: - item_id = self._context.id_generator.generate_function_call_id() - content = self._hitl_helper.convert_user_input_request_content(content) - sink.append( - { - "id": item_id, - "type": 
"function_call", - "status": "in_progress", - "call_id": content["call_id"], - "name": content["name"], - "arguments": content["arguments"], - "created_by": self._build_created_by(author_name), - } - ) - logger.debug( - " added user_input_request item id=%s call_id=%s", - item_id, - content["call_id"], - ) - # ------------- simple normalization helper ------------------------- def _coerce_result_text(self, value: Any) -> str | dict: - """ - Return a string if value is already str or a TextContent-like object; else str(value). + """Return a string or dict representation of a result value. - :param value: The value to coerce. + :param value: The result value to coerce. :type value: Any - - :return: The coerced string or dict. - :rtype: str | dict + :return: A string or dict representation. + :rtype: str or dict """ - if value is None: - return "" - if isinstance(value, str): - return value - # Direct TextContent instance - if isinstance(value, TextContent): - content_payload = {"type": "text", "text": getattr(value, "text", "")} - return content_payload - - return "" + match value: + case None: + return "" + case str(): + return value + case _ if hasattr(value, 'type') and value.type == "text": + return {"type": "text", "text": getattr(value, "text", "")} + case _: + return "" def _construct_response_data(self, output_items: List[dict]) -> dict: - agent_id = AgentIdGenerator.generate(self._context) + agent_id = generate_agent_id(self._context) - response_data = { + response_data: Dict[str, Any] = { "object": "response", "metadata": {}, "agent": agent_id, "conversation": self._context.get_conversation_object(), "type": "message", "role": "assistant", - "temperature": Constants.DEFAULT_TEMPERATURE, - "top_p": Constants.DEFAULT_TOP_P, + "temperature": constants.DEFAULT_TEMPERATURE, + "top_p": constants.DEFAULT_TOP_P, "user": "", "id": self._context.response_id, "created_at": self._response_created_at, diff --git 
a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/agent_framework_output_streaming_converter.py b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/agent_framework_output_streaming_converter.py index 6ba9bd05ea1b..266f4d570ea0 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/agent_framework_output_streaming_converter.py +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/agent_framework_output_streaming_converter.py @@ -1,27 +1,19 @@ # --------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # --------------------------------------------------------- -# pylint: disable=attribute-defined-outside-init,protected-access,unnecessary-lambda-assignment -# mypy: disable-error-code="call-overload,assignment,arg-type,override" +# pylint: disable=attribute-defined-outside-init,protected-access +# mypy: disable-error-code="call-overload,assignment,arg-type" from __future__ import annotations import datetime import json -from typing import Any, AsyncIterable, List, Union +import uuid +from typing import Any, List, Optional, cast -from agent_framework import ( - AgentRunResponseUpdate, - BaseContent, - FunctionResultContent, -) -from agent_framework._types import ( - ErrorContent, - FunctionCallContent, - TextContent, - UserInputRequestContents, -) +from agent_framework import AgentResponseUpdate, Content from azure.ai.agentserver.core import AgentRunContext +from azure.ai.agentserver.core.logger import get_logger from azure.ai.agentserver.core.models import ( Response as OpenAIResponse, ResponseStreamEvent, @@ -30,11 +22,11 @@ FunctionToolCallItemResource, FunctionToolCallOutputItemResource, ItemContentOutputText, - ItemResource, ResponseCompletedEvent, ResponseContentPartAddedEvent, ResponseContentPartDoneEvent, 
ResponseCreatedEvent, + ResponseErrorEvent, ResponseFunctionCallArgumentsDeltaEvent, ResponseFunctionCallArgumentsDoneEvent, ResponseInProgressEvent, @@ -45,423 +37,560 @@ ResponseTextDoneEvent, ) -from .agent_id_generator import AgentIdGenerator -from .human_in_the_loop_helper import HumanInTheLoopHelper -from .utils.async_iter import chunk_on_change, peek +from .agent_id_generator import generate_agent_id + +logger = get_logger() class _BaseStreamingState: """Base interface for streaming state handlers.""" - async def convert_contents( - self, - contents: AsyncIterable[BaseContent], - author_name: str, - ) -> AsyncIterable[ResponseStreamEvent]: - # pylint: disable=unused-argument + def prework(self, ctx: Any) -> List[ResponseStreamEvent]: # pylint: disable=unused-argument + return [] + + def convert_content(self, ctx: Any, content) -> List[ResponseStreamEvent]: # pylint: disable=unused-argument raise NotImplementedError + def afterwork(self, ctx: Any) -> List[ResponseStreamEvent]: # pylint: disable=unused-argument + return [] + class _TextContentStreamingState(_BaseStreamingState): """State handler for text and reasoning-text content during streaming.""" - def __init__(self, parent: AgentFrameworkOutputStreamingConverter): - self._parent = parent - - async def convert_contents( - self, - contents: AsyncIterable[TextContent], - author_name: str, - ) -> AsyncIterable[ResponseStreamEvent]: - item_id = self._parent.context.id_generator.generate_message_id() - output_index = self._parent.next_output_index() - - yield ResponseOutputItemAddedEvent( - sequence_number=self._parent.next_sequence(), - output_index=output_index, - item=ResponsesAssistantMessageItemResource( - id=item_id, - status="in_progress", - content=[], - created_by=self._parent._build_created_by(author_name), - ), + def __init__(self, context: AgentRunContext) -> None: + self.context = context + self.item_id = None + self.output_index = None + self.text_buffer = "" + self.text_part_started = False + + 
def prework(self, ctx: Any) -> List[ResponseStreamEvent]: + events: List[ResponseStreamEvent] = [] + if self.item_id is not None: + return events + + # Start a new assistant message item (in_progress) + self.item_id = self.context.id_generator.generate_message_id() + self.output_index = ctx._next_output_index # pylint: disable=protected-access + ctx._next_output_index += 1 + + message_item = ResponsesAssistantMessageItemResource( + id=self.item_id, + status="in_progress", + content=[], ) - yield ResponseContentPartAddedEvent( - sequence_number=self._parent.next_sequence(), - item_id=item_id, - output_index=output_index, - content_index=0, - part=ItemContentOutputText(text="", annotations=[], logprobs=[]), + events.append( + ResponseOutputItemAddedEvent( + sequence_number=ctx.next_sequence(), + output_index=self.output_index, + item=message_item, + ) ) - text = "" - async for content in contents: - delta = content.text - text += delta - - yield ResponseTextDeltaEvent( - sequence_number=self._parent.next_sequence(), - item_id=item_id, - output_index=output_index, + if not self.text_part_started: + empty_part = ItemContentOutputText(text="", annotations=[], logprobs=[]) + events.append( + ResponseContentPartAddedEvent( + sequence_number=ctx.next_sequence(), + item_id=self.item_id, + output_index=self.output_index, + content_index=0, + part=empty_part, + ) + ) + self.text_part_started = True + return events + + def convert_content(self, ctx: Any, content: Content) -> List[ResponseStreamEvent]: + events: List[ResponseStreamEvent] = [] + match content.type: + case "text": + delta = content.text or "" + case "text_reasoning": + delta = getattr(content, "reasoning", "") or "" + case _: + delta = getattr(content, "text", None) or getattr(content, "reasoning", "") or "" + + # buffer accumulated text + self.text_buffer += delta + + # emit delta event for text + assert self.item_id is not None, "Text state not initialized: missing item_id" + assert self.output_index is not 
None, "Text state not initialized: missing output_index" + events.append( + ResponseTextDeltaEvent( + sequence_number=ctx.next_sequence(), + item_id=self.item_id, + output_index=self.output_index, content_index=0, delta=delta, ) - - yield ResponseTextDoneEvent( - sequence_number=self._parent.next_sequence(), - item_id=item_id, - output_index=output_index, - content_index=0, - text=text, ) - - content_part = ItemContentOutputText(text=text, annotations=[], logprobs=[]) - yield ResponseContentPartDoneEvent( - sequence_number=self._parent.next_sequence(), - item_id=item_id, - output_index=output_index, - content_index=0, - part=content_part, + return events + + def afterwork(self, ctx: Any) -> List[ResponseStreamEvent]: + events: List[ResponseStreamEvent] = [] + if not self.item_id: + return events + + full_text = self.text_buffer + assert self.item_id is not None and self.output_index is not None + events.append( + ResponseTextDoneEvent( + sequence_number=ctx.next_sequence(), + item_id=self.item_id, + output_index=self.output_index, + content_index=0, + text=full_text, + ) ) - - item = ResponsesAssistantMessageItemResource( - id=item_id, - status="completed", - content=[content_part], - created_by=self._parent._build_created_by(author_name), + final_part = ItemContentOutputText(text=full_text, annotations=[], logprobs=[]) + events.append( + ResponseContentPartDoneEvent( + sequence_number=ctx.next_sequence(), + item_id=self.item_id, + output_index=self.output_index, + content_index=0, + part=final_part, + ) ) - yield ResponseOutputItemDoneEvent( - sequence_number=self._parent.next_sequence(), - output_index=output_index, - item=item, + completed_item = ResponsesAssistantMessageItemResource( + id=self.item_id, status="completed", content=[final_part] ) - - self._parent.add_completed_output_item(item) # pylint: disable=protected-access + events.append( + ResponseOutputItemDoneEvent( + sequence_number=ctx.next_sequence(), + output_index=self.output_index, + 
item=completed_item, + ) + ) + ctx._last_completed_text = full_text # pylint: disable=protected-access + # store for final response + ctx._completed_output_items.append( + { + "id": self.item_id, + "type": "message", + "status": "completed", + "content": [ + { + "type": "output_text", + "text": full_text, + "annotations": [], + "logprobs": [], + } + ], + "role": "assistant", + } + ) + # reset state + self.item_id = None + self.output_index = None + self.text_buffer = "" + self.text_part_started = False + return events class _FunctionCallStreamingState(_BaseStreamingState): """State handler for function_call content during streaming.""" - def __init__(self, - parent: AgentFrameworkOutputStreamingConverter, - hitl_helper: HumanInTheLoopHelper): - self._parent = parent - self._hitl_helper = hitl_helper - - async def convert_contents( - self, contents: AsyncIterable[Union[FunctionCallContent, UserInputRequestContents]], author_name: str - ) -> AsyncIterable[ResponseStreamEvent]: - content_by_call_id = {} - ids_by_call_id = {} - hitl_contents = [] - - async for content in contents: - if isinstance(content, FunctionCallContent): - if content.call_id not in content_by_call_id: - item_id = self._parent.context.id_generator.generate_function_call_id() - output_index = self._parent.next_output_index() - - content_by_call_id[content.call_id] = content - ids_by_call_id[content.call_id] = (item_id, output_index) - - yield ResponseOutputItemAddedEvent( - sequence_number=self._parent.next_sequence(), - output_index=output_index, - item=FunctionToolCallItemResource( - id=item_id, - status="in_progress", - call_id=content.call_id, - name=content.name, - arguments="", - created_by=self._parent._build_created_by(author_name), - ), - ) - else: - content_by_call_id[content.call_id] = content_by_call_id[content.call_id] + content - item_id, output_index = ids_by_call_id[content.call_id] - - args_delta = content.arguments if isinstance(content.arguments, str) else "" - yield 
ResponseFunctionCallArgumentsDeltaEvent( - sequence_number=self._parent.next_sequence(), - item_id=item_id, - output_index=output_index, - delta=args_delta, - ) - - elif isinstance(content, UserInputRequestContents): - converted_hitl = self._hitl_helper.convert_user_input_request_content(content) - if converted_hitl: - hitl_contents.append(converted_hitl) - - for call_id, content in content_by_call_id.items(): - item_id, output_index = ids_by_call_id[call_id] - args = self._serialize_arguments(content.arguments) - yield ResponseFunctionCallArgumentsDoneEvent( - sequence_number=self._parent.next_sequence(), - item_id=item_id, - output_index=output_index, - arguments=args, - ) - - item = FunctionToolCallItemResource( - id=item_id, - status="completed", - call_id=call_id, - name=content.name, - arguments=args, - created_by=self._parent._build_created_by(author_name), - ) - yield ResponseOutputItemDoneEvent( - sequence_number=self._parent.next_sequence(), - output_index=output_index, - item=item, - ) - - self._parent.add_completed_output_item(item) # pylint: disable=protected-access - - # process HITL contents after function calls - for content in hitl_contents: - item_id = self._parent.context.id_generator.generate_function_call_id() - output_index = self._parent.next_output_index() - - yield ResponseOutputItemAddedEvent( - sequence_number=self._parent.next_sequence(), - output_index=output_index, - item=FunctionToolCallItemResource( - id=item_id, - status="in_progress", - call_id=content["call_id"], - name=content["name"], - arguments="", - created_by=self._parent._build_created_by(author_name), - ), + def __init__(self, context: AgentRunContext) -> None: + self.context = context + self.item_id = None + self.output_index = None + self.call_id = None + self.name = None + self.args_buffer = "" + self.requires_approval = False + self.approval_request_id: str | None = None + + def prework(self, ctx: Any) -> List[ResponseStreamEvent]: + events: List[ResponseStreamEvent] = 
[] + if self.item_id is not None: + return events + # initialize function-call item + self.item_id = self.context.id_generator.generate_function_call_id() + self.output_index = ctx._next_output_index + ctx._next_output_index += 1 + + self.call_id = self.call_id or str(uuid.uuid4()) + function_item = FunctionToolCallItemResource( + id=self.item_id, + status="in_progress", + call_id=self.call_id, + name=self.name or "", + arguments="", + ) + events.append( + ResponseOutputItemAddedEvent( + sequence_number=ctx.next_sequence(), + output_index=self.output_index, + item=function_item, ) - yield ResponseFunctionCallArgumentsDeltaEvent( - sequence_number=self._parent.next_sequence(), - item_id=item_id, - output_index=output_index, - delta=content["arguments"], + ) + return events + + def convert_content(self, ctx: Any, content: Content) -> List[ResponseStreamEvent]: + events: List[ResponseStreamEvent] = [] + # record identifiers (once available) + self.name = getattr(content, "name", None) or self.name or "" + self.call_id = getattr(content, "call_id", None) or self.call_id or str(uuid.uuid4()) + + args_delta = content.arguments if isinstance(content.arguments, str) else json.dumps(content.arguments) + args_delta = args_delta or "" + self.args_buffer += args_delta + assert self.item_id is not None and self.output_index is not None + for ch in args_delta: + events.append( + ResponseFunctionCallArgumentsDeltaEvent( + sequence_number=ctx.next_sequence(), + item_id=self.item_id, + output_index=self.output_index, + delta=ch, + ) ) - yield ResponseFunctionCallArgumentsDoneEvent( - sequence_number=self._parent.next_sequence(), - item_id=item_id, - output_index=output_index, - arguments=content["arguments"], - ) - item = FunctionToolCallItemResource( - id=item_id, - status="in_progress", - call_id=content["call_id"], - name=content["name"], - arguments=content["arguments"], - created_by=self._parent._build_created_by(author_name), + # finalize if arguments are detected to be 
complete + is_done = bool( + getattr(content, "is_final", False) + or getattr(content, "final", False) + or getattr(content, "done", False) + or getattr(content, "arguments_final", False) + or getattr(content, "arguments_done", False) + or getattr(content, "finish", False) + ) + if not is_done and self.args_buffer: + try: + json.loads(self.args_buffer) + is_done = True + except Exception: # pylint: disable=broad-exception-caught + pass + + if is_done: + events.append( + ResponseFunctionCallArgumentsDoneEvent( + sequence_number=ctx.next_sequence(), + item_id=self.item_id, + output_index=self.output_index, + arguments=self.args_buffer, + ) ) - yield ResponseOutputItemDoneEvent( - sequence_number=self._parent.next_sequence(), - output_index=output_index, - item=item, + events.extend(self.afterwork(ctx)) + return events + + def afterwork(self, ctx: Any) -> List[ResponseStreamEvent]: + events: List[ResponseStreamEvent] = [] + if not self.item_id: + return events + assert self.call_id is not None + done_item = FunctionToolCallItemResource( + id=self.item_id, + status="completed", + call_id=self.call_id, + name=self.name or "", + arguments=self.args_buffer, + ) + assert self.output_index is not None + events.append( + ResponseOutputItemDoneEvent( + sequence_number=ctx.next_sequence(), + output_index=self.output_index, + item=done_item, ) - self._parent.add_completed_output_item(item) - - def _serialize_arguments(self, arguments: Any) -> str: - if isinstance(arguments, str): - return arguments - try: - return json.dumps(arguments) - except Exception: # pylint: disable=broad-exception-caught - return str(arguments) + ) + # store for final response + ctx._completed_output_items.append( + { + "id": self.item_id, + "type": "function_call", + "call_id": self.call_id, + "name": self.name or "", + "arguments": self.args_buffer, + "status": "requires_approval" if self.requires_approval else "completed", + "requires_approval": self.requires_approval, + "approval_request_id": 
self.approval_request_id, + } + ) + # reset + self.item_id = None + self.output_index = None + self.args_buffer = "" + self.call_id = None + self.name = None + self.requires_approval = False + self.approval_request_id = None + return events class _FunctionCallOutputStreamingState(_BaseStreamingState): """Handles function_call_output items streaming (non-chunked simple output).""" - def __init__(self, parent: AgentFrameworkOutputStreamingConverter): - self._parent = parent - - async def convert_contents( - self, contents: AsyncIterable[FunctionResultContent], author_name: str - ) -> AsyncIterable[ResponseStreamEvent]: - async for content in contents: - item_id = self._parent.context.id_generator.generate_function_output_id() - output_index = self._parent.next_output_index() - - output = (f"{type(content.exception)}({str(content.exception)})" - if content.exception - else self._to_output(content.result)) - - item = FunctionToolCallOutputItemResource( - id=item_id, - status="completed", - call_id=content.call_id, - output=output, - created_by=self._parent._build_created_by(author_name), - ) - - yield ResponseOutputItemAddedEvent( - sequence_number=self._parent.next_sequence(), - output_index=output_index, + def __init__( + self, + context: AgentRunContext, + call_id: Optional[str] = None, + output: Optional[list[str]] = None, + ) -> None: + # Avoid mutable default argument (Ruff B006) + self.context = context + self.item_id = None + self.output_index = None + self.call_id = call_id + self.output = output if output is not None else [] + + def prework(self, ctx: Any) -> List[ResponseStreamEvent]: + events: List[ResponseStreamEvent] = [] + if self.item_id is not None: + return events + self.item_id = self.context.id_generator.generate_function_output_id() + self.output_index = ctx._next_output_index + ctx._next_output_index += 1 + + self.call_id = self.call_id or str(uuid.uuid4()) + item = FunctionToolCallOutputItemResource( + id=self.item_id, + status="in_progress", + 
call_id=self.call_id, + output="", + ) + events.append( + ResponseOutputItemAddedEvent( + sequence_number=ctx.next_sequence(), + output_index=self.output_index, item=item, ) - - yield ResponseOutputItemDoneEvent( - sequence_number=self._parent.next_sequence(), - output_index=output_index, - item=item, + ) + return events + + def convert_content(self, ctx: Any, content: Any) -> List[ResponseStreamEvent]: # no delta events for now + events: List[ResponseStreamEvent] = [] + # treat entire output as final + raw = getattr(content, "result", None) + result: list[str | dict[str, Any]] = [] + match raw: + case str(): + result = [raw] if raw else [str(self.output)] + case list(): + result = [self._coerce_result_text(item) for item in raw] + self.output = json.dumps(result) if len(result) > 0 else "" + + events.extend(self.afterwork(ctx)) + return events + + def _coerce_result_text(self, value: Any) -> str | dict: + """Return a string or dict representation of a result value. + + :param value: The result value to coerce. + :type value: Any + :return: A string or dict representation. 
+ :rtype: str or dict + """ + match value: + case None: + return "" + case str(): + return value + case _ if hasattr(value, 'type') and value.type == "text": + return {"type": "text", "text": getattr(value, "text", "")} + case _: + return "" + + def afterwork(self, ctx: Any) -> List[ResponseStreamEvent]: + events: List[ResponseStreamEvent] = [] + if not self.item_id: + return events + # Ensure types conform: call_id must be str (guarantee non-None) and output is a single string + str_call_id = self.call_id or "" + single_output: str = cast(str, self.output[0]) if self.output else "" + done_item = FunctionToolCallOutputItemResource( + id=self.item_id, + status="completed", + call_id=str_call_id, + output=single_output, + ) + assert self.output_index is not None + events.append( + ResponseOutputItemDoneEvent( + sequence_number=ctx.next_sequence(), + output_index=self.output_index, + item=done_item, ) - - self._parent.add_completed_output_item(item) # pylint: disable=protected-access - - @classmethod - def _to_output(cls, result: Any) -> str: - if isinstance(result, str): - return result - if isinstance(result, list): - text = [] - for item in result: - if isinstance(item, BaseContent): - text.append(item.to_dict()) - else: - text.append(str(item)) - return json.dumps(text) - return "" + ) + ctx._completed_output_items.append( + { + "id": self.item_id, + "type": "function_call_output", + "status": "completed", + "call_id": self.call_id, + "output": self.output, + } + ) + self.item_id = None + self.output_index = None + return events class AgentFrameworkOutputStreamingConverter: """Streaming converter using content-type-specific state handlers.""" - def __init__(self, context: AgentRunContext, *, hitl_helper: HumanInTheLoopHelper=None) -> None: + def __init__(self, context: AgentRunContext) -> None: self._context = context # sequence numbers must start at 0 for first emitted event - self._sequence = -1 - self._next_output_index = -1 - self._response_id = 
self._context.response_id + self._sequence = 0 + self._response_id = None self._response_created_at = None - self._completed_output_items: List[ItemResource] = [] - self._hitl_helper = hitl_helper + self._next_output_index = 0 + self._last_completed_text = "" + self._active_state: Optional[_BaseStreamingState] = None + self._active_kind = None # "text" | "function_call" | "error" + # accumulate completed output items for final response + self._completed_output_items: List[dict] = [] + + def _ensure_response_started(self) -> None: + if not self._response_id: + self._response_id = self._context.response_id + if not self._response_created_at: + self._response_created_at = int(datetime.datetime.now(datetime.timezone.utc).timestamp()) def next_sequence(self) -> int: self._sequence += 1 return self._sequence - def next_output_index(self) -> int: - self._next_output_index += 1 - return self._next_output_index - - def add_completed_output_item(self, item: ItemResource) -> None: - self._completed_output_items.append(item) - - @property - def context(self) -> AgentRunContext: - return self._context - - async def convert(self, updates: AsyncIterable[AgentRunResponseUpdate]) -> AsyncIterable[ResponseStreamEvent]: - self._ensure_response_started() - - created_response = self._build_response(status="in_progress") - yield ResponseCreatedEvent( - sequence_number=self.next_sequence(), - response=created_response, - ) - - yield ResponseInProgressEvent( - sequence_number=self.next_sequence(), - response=created_response, - ) - - is_changed = ( - lambda a, b: a is not None \ - and b is not None \ - and a.message_id != b.message_id - ) - - async for group in chunk_on_change(updates, is_changed): - has_value, first_tuple, contents_with_author = await peek(self._read_updates(group)) - if not has_value or first_tuple is None: - continue - - first, author_name = first_tuple # Extract content and author_name from tuple - - state = None - if isinstance(first, TextContent): - state = 
_TextContentStreamingState(self) - elif isinstance(first, (FunctionCallContent, UserInputRequestContents)): - state = _FunctionCallStreamingState(self, self._hitl_helper) - elif isinstance(first, FunctionResultContent): - state = _FunctionCallOutputStreamingState(self) - elif isinstance(first, ErrorContent): - error_msg = ( - f"ErrorContent received: code={first.error_code}, " - f"message={first.message}" - ) - raise ValueError(error_msg) - if not state: - continue - - # Extract just the content from (content, author_name) tuples using async generator - async def extract_contents(): - async for content, _ in contents_with_author: # pylint: disable=cell-var-from-loop - yield content - - async for content in state.convert_contents(extract_contents(), author_name): - yield content - - yield ResponseCompletedEvent( - sequence_number=self.next_sequence(), - response=self._build_response(status="completed"), + def _switch_state(self, kind: str) -> List[ResponseStreamEvent]: + events: List[ResponseStreamEvent] = [] + if self._active_state and self._active_kind != kind: + events.extend(self._active_state.afterwork(self)) + self._active_state = None + self._active_kind = None + + if self._active_state is None: + match kind: + case "text": + self._active_state = _TextContentStreamingState(self._context) + case "function_call": + self._active_state = _FunctionCallStreamingState(self._context) + case "function_call_output": + self._active_state = _FunctionCallOutputStreamingState(self._context) + case _: + self._active_state = None + self._active_kind = kind + if self._active_state: + events.extend(self._active_state.prework(self)) + return events + + def transform_output_for_streaming(self, update: AgentResponseUpdate) -> List[ResponseStreamEvent]: + logger.debug( + "Transforming streaming update with %d contents", + len(update.contents) if getattr(update, "contents", None) else 0, ) - - def _build_created_by(self, author_name: str) -> dict: - agent_dict = { - "type": 
"agent_id", - "name": author_name or "", - "version": "", - } - - return { - "agent": agent_dict, - "response_id": self._response_id, - } - - async def _read_updates( - self, - updates: AsyncIterable[AgentRunResponseUpdate], - ) -> AsyncIterable[tuple[BaseContent, str]]: - async for update in updates: - if not update.contents: - continue - - # Extract author_name from each update - author_name = getattr(update, "author_name", "") or "" - - accepted_types = (TextContent, - FunctionCallContent, - UserInputRequestContents, - FunctionResultContent, - ErrorContent) - for content in update.contents: - if isinstance(content, accepted_types): - yield (content, author_name) - - def _ensure_response_started(self) -> None: - if not self._response_created_at: - self._response_created_at = int(datetime.datetime.now(datetime.timezone.utc).timestamp()) - - def _build_response(self, status: str) -> OpenAIResponse: self._ensure_response_started() - agent_id = AgentIdGenerator.generate(self._context) + events: List[ResponseStreamEvent] = [] + + if getattr(update, "contents", None): + for i, content in enumerate(update.contents): + logger.debug("Processing content %d: %s", i, content.type) + match content.type: + case "text" | "text_reasoning": + events.extend(self._switch_state("text")) + if isinstance(self._active_state, _TextContentStreamingState): + events.extend(self._active_state.convert_content(self, content)) + case "function_call": + events.extend(self._switch_state("function_call")) + if isinstance(self._active_state, _FunctionCallStreamingState): + events.extend(self._active_state.convert_content(self, content)) + case "function_result": + events.extend(self._switch_state("function_call_output")) + if isinstance(self._active_state, _FunctionCallOutputStreamingState): + call_id = getattr(content, "call_id", None) + if call_id: + self._active_state.call_id = call_id + events.extend(self._active_state.convert_content(self, content)) + case "function_approval_request": + 
events.extend(self._switch_state("function_call")) + if isinstance(self._active_state, _FunctionCallStreamingState): + self._active_state.requires_approval = True + self._active_state.approval_request_id = getattr(content, "id", None) + events.extend(self._active_state.convert_content(self, content.function_call)) + case "error": + events.extend(self._switch_state("error")) + events.append( + ResponseErrorEvent( + sequence_number=self.next_sequence(), + code=getattr(content, "error_code", None) or "server_error", + message=getattr(content, "message", None) or "An error occurred", + param="", + ) + ) + case "usage": + # Usage metadata — not emitted as a stream event + logger.debug("Skipping usage content (input/output token counts)") + case _: + logger.warning("Unhandled content type in streaming: %s", content.type) + return events + + def finalize_last_content(self) -> List[ResponseStreamEvent]: + events: List[ResponseStreamEvent] = [] + if self._active_state: + events.extend(self._active_state.afterwork(self)) + self._active_state = None + self._active_kind = None + return events + + def build_response(self, status: str) -> OpenAIResponse: + self._ensure_response_started() + agent_id = generate_agent_id(self._context) response_data = { "object": "response", "agent_id": agent_id, "id": self._response_id, "status": status, "created_at": self._response_created_at, - "conversation": self._context.get_conversation_object(), - "output": [] # ensure output is always set } - - # set output even if _completed_output_items is empty, never leave the output as null - if status == "completed": + if status == "completed" and self._completed_output_items: response_data["output"] = self._completed_output_items return OpenAIResponse(response_data) + + # High-level helpers to emit lifecycle events for streaming + def initial_events(self) -> List[ResponseStreamEvent]: + """ + Emit ResponseCreatedEvent and an initial ResponseInProgressEvent. 
+ + :return: List of initial response stream events. + :rtype: List[ResponseStreamEvent] + """ + self._ensure_response_started() + events: List[ResponseStreamEvent] = [] + created_response = self.build_response(status="in_progress") + events.append( + ResponseCreatedEvent( + sequence_number=self.next_sequence(), + response=created_response, + ) + ) + events.append( + ResponseInProgressEvent( + sequence_number=self.next_sequence(), + response=self.build_response(status="in_progress"), + ) + ) + return events + + def completion_events(self) -> List[ResponseStreamEvent]: + """ + Finalize any active content and emit a single ResponseCompletedEvent. + + :return: List of completion response stream events. + :rtype: List[ResponseStreamEvent] + """ + self._ensure_response_started() + events: List[ResponseStreamEvent] = [] + events.extend(self.finalize_last_content()) + completed_response = self.build_response(status="completed") + events.append( + ResponseCompletedEvent( + sequence_number=self.next_sequence(), + response=completed_response, + ) + ) + return events diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/agent_id_generator.py b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/agent_id_generator.py index da4045898a5e..abd2dd2c02ef 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/agent_id_generator.py +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/agent_id_generator.py @@ -1,13 +1,7 @@ # --------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # --------------------------------------------------------- -"""Helper utilities for constructing AgentId model instances. - -Centralizes logic for safely building a `models.AgentId` from a request agent -object. 
We intentionally do not allow overriding the generated model's fixed -`type` literal ("agent_id"). If the provided object lacks a name, `None` is -returned so callers can decide how to handle absence. -""" +"""Helper for constructing AgentId model instances from request context.""" from __future__ import annotations @@ -17,28 +11,20 @@ from azure.ai.agentserver.core.models import projects -class AgentIdGenerator: - @staticmethod - def generate(context: AgentRunContext) -> Optional[projects.AgentId]: - """ - Builds an AgentId model from the request agent object in the provided context. - - :param context: The AgentRunContext containing the request. - :type context: AgentRunContext - - :return: The constructed AgentId model, or None if the request lacks an agent name. - :rtype: Optional[projects.AgentId] - """ - agent = context.request.get("agent") - if not agent: - return None +def generate_agent_id(context: AgentRunContext) -> Optional[projects.AgentId]: + """Build an AgentId model from the request agent object in the provided context. - agent_id = projects.AgentId( - { - "type": agent.type, - "name": agent.name, - "version": agent.version, - } - ) + :param context: The agent run context containing the request. 
+ :type context: AgentRunContext + """ + agent = context.request.get("agent") + if not agent: + return None - return agent_id + return projects.AgentId( + { + "type": agent.type, + "name": agent.name, + "version": agent.version, + } + ) diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/constants.py b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/constants.py index 859e115e425e..990eb7d83388 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/constants.py +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/constants.py @@ -1,13 +1,13 @@ # --------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # --------------------------------------------------------- -class Constants: - # streaming configuration - # Environment variable name to control idle timeout for streaming updates (seconds) - AGENTS_ADAPTER_STREAM_TIMEOUT_S = "AGENTS_ADAPTER_STREAM_TIMEOUT_S" - # Default idle timeout (seconds) when env var or request override not provided - DEFAULT_STREAM_TIMEOUT_S = 300.0 - # model defaults - DEFAULT_TEMPERATURE = 1.0 - DEFAULT_TOP_P = 1.0 +# streaming configuration +# Environment variable name to control idle timeout for streaming updates (seconds) +AGENTS_ADAPTER_STREAM_TIMEOUT_S = "AGENTS_ADAPTER_STREAM_TIMEOUT_S" +# Default idle timeout (seconds) when env var or request override not provided +DEFAULT_STREAM_TIMEOUT_S = 300.0 + +# model defaults +DEFAULT_TEMPERATURE = 1.0 +DEFAULT_TOP_P = 1.0 diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/conversation_converters.py b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/conversation_converters.py index 245cc54fda2d..b830df5b7565 100644 --- 
a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/conversation_converters.py +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/conversation_converters.py @@ -6,22 +6,15 @@ from typing import Any, Mapping, Optional from uuid import uuid4 -from agent_framework import ChatMessage, FunctionCallContent, FunctionResultContent -from agent_framework._types import ( - DataContent, - HostedFileContent, - TextContent, - TextReasoningContent, - UriContent, -) +from agent_framework import Content, Message from openai.types.conversations import ConversationItem from azure.ai.agentserver.core.logger import get_logger logger = get_logger() -class ConversationItemConverter: +class ConversationItemConverter: _ROLE_MAP = { "assistant": "assistant", "system": "system", @@ -63,15 +56,8 @@ class ConversationItemConverter: "error", ) - def to_chat_message(self, item: ConversationItem) -> Optional[ChatMessage]: - """ - Convert a ConversationItem from the Conversations API to a ChatMessage. - - :param item: The ConversationItem from the Conversations API. - :type item: ConversationItem - :return: The converted ChatMessage, or None if conversion is not possible. 
- :rtype: Optional[ChatMessage] - """ + def to_chat_message(self, item: ConversationItem) -> Optional[Message]: + """Convert a ConversationItem from the Conversations API to an AF Message.""" if item is None: return None @@ -89,7 +75,7 @@ def to_chat_message(self, item: ConversationItem) -> Optional[ChatMessage]: logger.debug("Unsupported conversation item type: %s", item_type) return None - def _convert_message_item(self, item: Any) -> Optional[ChatMessage]: + def _convert_message_item(self, item: Any) -> Optional[Message]: role_value = self._ROLE_MAP.get(str(getattr(item, "role", "user")).lower(), "user") raw_contents = getattr(item, "content", None) or [] @@ -103,9 +89,9 @@ def _convert_message_item(self, item: Any) -> Optional[ChatMessage]: if not converted_contents: return None - return ChatMessage(role=role_value, contents=converted_contents) + return Message(role=role_value, contents=converted_contents) - def _convert_tool_call_item(self, item: Any) -> Optional[ChatMessage]: + def _convert_tool_call_item(self, item: Any) -> Optional[Message]: data = self._model_dump(item) if not data: return None @@ -116,14 +102,14 @@ def _convert_tool_call_item(self, item: Any) -> Optional[ChatMessage]: arguments = self._extract_call_arguments(data) normalized_arguments = self._normalize_arguments(arguments) - content = FunctionCallContent( + content = self._content_from_function_call( call_id=call_id, name=name, arguments=normalized_arguments, ) - return ChatMessage(role="assistant", contents=[content]) + return Message(role="assistant", contents=[content]) - def _convert_tool_result_item(self, item: Any) -> Optional[ChatMessage]: + def _convert_tool_result_item(self, item: Any) -> Optional[Message]: data = self._model_dump(item) if not data: return None @@ -133,19 +119,19 @@ def _convert_tool_result_item(self, item: Any) -> Optional[ChatMessage]: if result_payload is None: return None - content = FunctionResultContent(call_id=call_id, result=result_payload) - return 
ChatMessage(role="tool", contents=[content]) + content = self._content_from_function_result(call_id=call_id, result=result_payload) + return Message(role="tool", contents=[content]) - def _convert_reasoning_item(self, item: Any) -> Optional[ChatMessage]: + def _convert_reasoning_item(self, item: Any) -> Optional[Message]: data = self._model_dump(item) summaries = data.get("summary", []) or [] content_items = data.get("content", []) or [] - reasoning_contents: list[TextReasoningContent] = [] + reasoning_contents: list[Any] = [] for content in content_items: text = content.get("text") if isinstance(content, Mapping) else None if text: - reasoning_contents.append(TextReasoningContent(text=text)) + reasoning_contents.append(self._content_from_text_reasoning(text)) summary_text = " \n".join( summary.get("text") @@ -161,7 +147,7 @@ def _convert_reasoning_item(self, item: Any) -> Optional[ChatMessage]: kwargs["text"] = summary_text if reasoning_contents: kwargs["contents"] = reasoning_contents - return ChatMessage(role="assistant", **kwargs) + return Message(role="assistant", **kwargs) def _convert_message_content(self, content: Any) -> Optional[Any]: content_type = str(getattr(content, "type", "")).lower() @@ -169,37 +155,44 @@ def _convert_message_content(self, content: Any) -> Optional[Any]: if content_type in {"input_text", "output_text", "text", "summary_text"}: text_value = getattr(content, "text", None) if text_value: - return TextContent(text=text_value) + return self._content_from_text(text_value) if content_type == "reasoning_text": text_value = getattr(content, "text", None) if text_value: - return TextReasoningContent(text=text_value) + return self._content_from_text_reasoning(text_value) if content_type == "refusal": refusal_text = getattr(content, "refusal", None) if refusal_text: - return TextContent(text=refusal_text) - - if content_type in {"input_image", "computer_screenshot"}: - file_id = getattr(content, "file_id", None) - image_url = getattr(content, 
"image_url", None) - if file_id: - return HostedFileContent(file_id=file_id) - if image_url: - return UriContent(uri=image_url, media_type="image/*") - - if content_type == "input_file": - file_id = getattr(content, "file_id", None) - if file_id: - return HostedFileContent(file_id=file_id) - file_url = getattr(content, "file_url", None) - file_data = getattr(content, "file_data", None) - if file_url or file_data: - return DataContent(uri=file_url, data=file_data, media_type="application/octet-stream") + return self._content_from_text(refusal_text) return None + def _content_from_text(self, text: str) -> Any: + factory = getattr(Content, "from_text", None) + if callable(factory): + return factory(text=text) + return Content(type="text", text=text) + + def _content_from_text_reasoning(self, text: str) -> Any: + factory = getattr(Content, "from_text_reasoning", None) + if callable(factory): + return factory(text=text) + return Content(type="text_reasoning", text=text) + + def _content_from_function_call(self, call_id: str, name: str, arguments: Any) -> Any: + factory = getattr(Content, "from_function_call", None) + if callable(factory): + return factory(call_id=call_id, name=name, arguments=arguments) + return Content(type="function_call", call_id=call_id, name=name, arguments=arguments) + + def _content_from_function_result(self, call_id: str, result: Any) -> Any: + factory = getattr(Content, "from_function_result", None) + if callable(factory): + return factory(call_id=call_id, result=result) + return Content(type="function_result", call_id=call_id, result=result) + def _extract_call_arguments(self, data: Mapping[str, Any]) -> Any: if data.get("arguments") not in (None, ""): return data.get("arguments") diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/human_in_the_loop_helper.py b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/human_in_the_loop_helper.py index 
17710b2e537d..a3ee7b3e0d15 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/human_in_the_loop_helper.py +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/human_in_the_loop_helper.py @@ -1,68 +1,71 @@ # --------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # --------------------------------------------------------- -from typing import Any, List, Dict, Optional, Union import json +from typing import Any, Dict, List, Optional, Union -from agent_framework import ( - ChatMessage, - FunctionResultContent, - FunctionApprovalResponseContent, - RequestInfoEvent, - WorkflowCheckpoint, -) -from agent_framework._types import UserInputRequestContents +from agent_framework import Content, Message, WorkflowCheckpoint, WorkflowEvent from azure.ai.agentserver.core.logger import get_logger from azure.ai.agentserver.core.server.common.constants import HUMAN_IN_THE_LOOP_FUNCTION_NAME logger = get_logger() -class HumanInTheLoopHelper: - def get_pending_hitl_request(self, - thread_messages: List[ChatMessage] = None, - checkpoint: Optional[WorkflowCheckpoint] = None, - ) -> dict[str, Union[RequestInfoEvent, Any]]: - res = {} - # if has checkpoint (WorkflowAgent), find pending request info from checkpoint - if checkpoint and checkpoint.pending_request_info_events: - for call_id, request in checkpoint.pending_request_info_events.items(): - # find if the request is already responded in the thread messages - if isinstance(request, dict): - request_obj = RequestInfoEvent.from_dict(request) - res[call_id] = request_obj - return res +class HumanInTheLoopHelper: + def get_pending_hitl_request( + self, + thread_messages: List[Message] = None, + checkpoint: Optional[WorkflowCheckpoint] = None, + ) -> dict[str, Union[WorkflowEvent, Any]]: + res: dict[str, Union[WorkflowEvent, Any]] = {} + + if checkpoint: + pending_events = 
getattr(checkpoint, "pending_request_info_events", None) or getattr( + checkpoint, "pending_events", None + ) + if pending_events: + for call_id, request in pending_events.items(): + request_obj = self._coerce_workflow_event(request) + if request_obj is not None: + res[call_id] = request_obj + return res if not thread_messages: return res - # if no checkpoint (Agent), find user input request and pair the feedbacks for message in thread_messages: for content in message.contents: - if isinstance(content, UserInputRequestContents): - # is a human input request - function_call = content.function_call + if ( + getattr(content, "type", None) == "function_approval_request" + or getattr(content, "user_input_request", False) + ): + function_call = getattr(content, "function_call", None) + if function_call is None: + continue call_id = getattr(function_call, "call_id", "") if call_id: - res[call_id] = RequestInfoEvent( - source_executor_id="agent", - request_id=call_id, - response_type=None, - request_data=function_call, + res[call_id] = WorkflowEvent( + "request_info", + data={ + "source_executor_id": "agent", + "request_id": call_id, + "response_type": None, + "request_data": function_call, + }, ) - elif isinstance(content, FunctionResultContent): - if content.call_id and content.call_id in res: - # remove requests that already got feedback - res.pop(content.call_id) - elif isinstance(content, FunctionApprovalResponseContent): - function_call = content.function_call - call_id = getattr(function_call, "call_id", "") + elif content.type == "function_result": + call_id = getattr(content, "call_id", None) + if call_id and call_id in res: + res.pop(call_id) + elif content.type == "function_approval_response": + function_call = getattr(content, "function_call", None) + call_id = getattr(function_call, "call_id", "") if function_call else "" if call_id and call_id in res: res.pop(call_id) return res - def convert_user_input_request_content(self, content: UserInputRequestContents) 
-> dict: + def convert_user_input_request_content(self, content: Any) -> dict: function_call = content.function_call call_id = getattr(function_call, "call_id", "") arguments = self.convert_request_arguments(getattr(function_call, "arguments", "")) @@ -73,31 +76,30 @@ def convert_user_input_request_content(self, content: UserInputRequestContents) } def convert_request_arguments(self, arguments: Any) -> str: - # convert data to payload if possible if isinstance(arguments, dict): data = arguments.get("data") if data and hasattr(data, "convert_to_payload"): return data.convert_to_payload() if not isinstance(arguments, str): - if hasattr(arguments, "to_dict"): # agentframework models have to_dict method + if hasattr(arguments, "to_dict"): arguments = arguments.to_dict() try: arguments = json.dumps(arguments) - except Exception: # pragma: no cover - fallback # pylint: disable=broad-exception-caught + except Exception: # pragma: no cover # pylint: disable=broad-exception-caught arguments = str(arguments) return arguments - def validate_and_convert_hitl_response(self, - input: Union[str, List[Dict], None], - pending_requests: Dict[str, RequestInfoEvent], - ) -> Optional[List[ChatMessage]]: - + def validate_and_convert_hitl_response( + self, + input: Union[str, List[Dict], None], + pending_requests: Dict[str, WorkflowEvent], + ) -> Optional[List[Message]]: if input is None or isinstance(input, str): logger.warning("Expected list input for HitL response validation, got str.") return None - res = [] + res: list[Message] = [] for item in input: if item.get("type") != "function_call_output": logger.warning("Expected function_call_output type for HitL response validation.") @@ -107,47 +109,37 @@ def validate_and_convert_hitl_response(self, res.append(self.convert_response(pending_requests[call_id], item)) return res - def convert_response(self, hitl_request: RequestInfoEvent, input: Dict) -> ChatMessage: - response_type = hitl_request.response_type + def convert_response(self, 
hitl_request: WorkflowEvent, input: Dict) -> Message: + response_type = self._event_response_type(hitl_request) + request_id = self._event_request_id(hitl_request) response_result = input.get("output", "") - logger.info(f"response_type {type(response_type)}: %s", response_type) + logger.info("response_type %s", response_type) if response_type and hasattr(response_type, "convert_from_payload"): response_result = response_type.convert_from_payload(input.get("output", "")) - logger.info(f"response_result {type(response_result)}: %s", response_result) - response_content = FunctionResultContent( - call_id=hitl_request.request_id, + logger.info("response_result %s", response_result) + + response_content = self._content_from_function_result( + call_id=request_id, result=response_result, ) - return ChatMessage(role="tool", contents=[response_content]) - - def remove_hitl_content_from_thread(self, thread_messages: List[ChatMessage]) -> List[ChatMessage]: - """Remove HITL function call contents and related results from a conversation thread. - - HITL requests become ``function_call`` entries named ``HUMAN_IN_THE_LOOP_FUNCTION_NAME`` when converted - by the adapter. To avoid feeding those synthetic requests back into the agent, this method strips the - HITL function_calls and their placeholder outputs while preserving real tool invocations. + return Message(role="tool", contents=[response_content]) - :param thread_messages: The messages converted from the conversation API. - :type thread_messages: List[ChatMessage] - :return: Messages without HITL-specific artifacts. 
- :rtype: List[ChatMessage] - """ - filtered_messages = [] + def remove_hitl_content_from_thread(self, thread_messages: List[Message]) -> List[Message]: + """Remove HITL function call contents and related results from a conversation thread.""" + filtered_messages: list[Message] = [] prev_function_call = None prev_hitl_request = None prev_function_output = None - pending_tool_message = None + pending_tool_message: Optional[Message] = None for message in thread_messages: - filtered_contents = [] + filtered_contents: list[Any] = [] for content in message.contents: if content.type == "function_result": result_call_id = getattr(content, "call_id", "") if not prev_function_call: - if prev_hitl_request and prev_hitl_request.call_id == result_call_id: - # this is a hitl function result without the function call content, we can - # just skip it and wait for the next function call or result. + if prev_hitl_request and getattr(prev_hitl_request, "call_id", "") == result_call_id: prev_hitl_request = None else: logger.warning( @@ -155,14 +147,9 @@ def remove_hitl_content_from_thread(self, thread_messages: List[ChatMessage]) -> result_call_id, ) elif result_call_id == getattr(prev_function_call, "call_id", ""): - # prev_function_call is not None and call_id matches if prev_hitl_request: - # A HITL request may followed by one or two function result contents. - # if there are two, the last one is the real function result, the first - # one is just for recording the human feedback, we should remove it. 
prev_function_output = content else: - # function call without hitl result of the function call content filtered_contents.append(content) prev_function_call = None else: @@ -175,9 +162,6 @@ def remove_hitl_content_from_thread(self, thread_messages: List[ChatMessage]) -> prev_hitl_request = None else: if pending_tool_message: - # for the case mentioned above, we should append the real function result content - # when we see the next content, which means the function call and output cycle is ended. - # attach the real function result to the thread filtered_messages.append(pending_tool_message) pending_tool_message = None prev_function_call = None @@ -185,28 +169,62 @@ def remove_hitl_content_from_thread(self, thread_messages: List[ChatMessage]) -> prev_function_output = None if content.type == "function_call": - if content.name != HUMAN_IN_THE_LOOP_FUNCTION_NAME: + if getattr(content, "name", "") != HUMAN_IN_THE_LOOP_FUNCTION_NAME: filtered_contents.append(content) prev_function_call = content else: - # hitl request converted by adapter, skip this message. 
prev_hitl_request = content else: filtered_contents.append(content) if filtered_contents: - filtered_message = ChatMessage( - role=message.role, - contents=filtered_contents, - message_id=message.message_id, - additional_properties=message.additional_properties, + filtered_messages.append( + Message( + role=message.role, + contents=filtered_contents, + message_id=message.message_id, + additional_properties=message.additional_properties, + ) ) - filtered_messages.append(filtered_message) if prev_function_output: - pending_tool_message = ChatMessage( + pending_tool_message = Message( role="tool", contents=[prev_function_output], message_id=message.message_id, additional_properties=message.additional_properties, ) return filtered_messages + + def _coerce_workflow_event(self, request: Any) -> Optional[WorkflowEvent]: + if isinstance(request, WorkflowEvent): + return request + if isinstance(request, dict): + event_type = request.get("type") + event_data = request.get("data") + if event_type: + return WorkflowEvent(event_type, data=event_data) + return None + + def _event_data(self, event: WorkflowEvent) -> Any: + data = getattr(event, "data", None) + if data is None and isinstance(event, dict): + data = event.get("data") + return data + + def _event_request_id(self, event: WorkflowEvent) -> str: + data = self._event_data(event) + if isinstance(data, dict): + return data.get("request_id", "") + return getattr(data, "request_id", "") if data is not None else "" + + def _event_response_type(self, event: WorkflowEvent) -> Any: + data = self._event_data(event) + if isinstance(data, dict): + return data.get("response_type", None) + return getattr(data, "response_type", None) if data is not None else None + + def _content_from_function_result(self, call_id: str, result: Any) -> Any: + factory = getattr(Content, "from_function_result", None) + if callable(factory): + return factory(call_id=call_id, result=result) + return Content(type="function_result", call_id=call_id, 
result=result) diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/utils/async_iter.py b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/utils/async_iter.py index fdf3b2fbb2a3..d89f4f4f769f 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/utils/async_iter.py +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/utils/async_iter.py @@ -4,7 +4,7 @@ from __future__ import annotations from collections.abc import AsyncIterable, AsyncIterator, Callable -from typing import TypeVar, Optional, Tuple +from typing import Optional, Tuple, TypeVar TSource = TypeVar("TSource") TKey = TypeVar("TKey") @@ -76,7 +76,7 @@ def key_equal(a: TKey, b: TKey) -> bool: # type: ignore[no-redef] while has_pending: current_key = pending_key - async def inner() -> AsyncIterator[TSource]: + async def inner(current_key=current_key) -> AsyncIterator[TSource]: nonlocal pending, pending_key, has_pending # First element of the group diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/persistence/__init__.py b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/persistence/__init__.py index d59d1154650c..a90ae8acd4db 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/persistence/__init__.py +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/persistence/__init__.py @@ -1,16 +1,16 @@ +from ._foundry_checkpoint_repository import FoundryCheckpointRepository +from ._foundry_checkpoint_storage import FoundryCheckpointStorage from .agent_thread_repository import ( AgentThreadRepository, InMemoryAgentThreadRepository, - SerializedAgentThreadRepository, JsonLocalFileAgentThreadRepository, + SerializedAgentThreadRepository, ) from 
.checkpoint_repository import ( CheckpointRepository, - InMemoryCheckpointRepository, FileCheckpointRepository, + InMemoryCheckpointRepository, ) -from ._foundry_checkpoint_storage import FoundryCheckpointStorage -from ._foundry_checkpoint_repository import FoundryCheckpointRepository __all__ = [ "AgentThreadRepository", diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/persistence/_foundry_checkpoint_repository.py b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/persistence/_foundry_checkpoint_repository.py index 96b80615445e..bd08a845b429 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/persistence/_foundry_checkpoint_repository.py +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/persistence/_foundry_checkpoint_repository.py @@ -7,16 +7,16 @@ from typing import Dict, Optional, Union from agent_framework import CheckpointStorage -from azure.core.credentials import TokenCredential -from azure.core.credentials_async import AsyncTokenCredential from azure.ai.agentserver.core.checkpoints.client import ( CheckpointSession, FoundryCheckpointClient, ) +from azure.core.credentials import TokenCredential +from azure.core.credentials_async import AsyncTokenCredential -from .checkpoint_repository import CheckpointRepository from ._foundry_checkpoint_storage import FoundryCheckpointStorage +from .checkpoint_repository import CheckpointRepository logger = logging.getLogger(__name__) diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/persistence/_foundry_checkpoint_storage.py b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/persistence/_foundry_checkpoint_storage.py index 65e5b2f7b3c4..833c3647149a 100644 --- 
a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/persistence/_foundry_checkpoint_storage.py +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/persistence/_foundry_checkpoint_storage.py @@ -90,54 +90,66 @@ async def load_checkpoint(self, checkpoint_id: str) -> Optional[WorkflowCheckpoi return checkpoint async def list_checkpoint_ids( - self, workflow_id: Optional[str] = None + self, workflow_name: Optional[str] = None, **kwargs ) -> List[str]: """List checkpoint IDs. - If workflow_id is provided, filter by that workflow. + If workflow_name is provided, filter by that workflow. - :param workflow_id: Optional workflow identifier for filtering. - :type workflow_id: Optional[str] + :param workflow_name: Optional workflow name for filtering. + :type workflow_name: Optional[str] :return: List of checkpoint identifiers. :rtype: List[str] """ item_ids = await self._client.list_item_ids(self._session_id) ids = [item_id.item_id for item_id in item_ids] - # Filter by workflow_id if provided - if workflow_id is not None: + # Backward-compat alias + if workflow_name is None: + workflow_name = kwargs.get("workflow_id") + + if workflow_name is not None: filtered_ids = [] for checkpoint_id in ids: checkpoint = await self.load_checkpoint(checkpoint_id) - if checkpoint and checkpoint.workflow_id == workflow_id: + if checkpoint and self._checkpoint_workflow_name(checkpoint) == workflow_name: filtered_ids.append(checkpoint_id) return filtered_ids return ids async def list_checkpoints( - self, workflow_id: Optional[str] = None + self, workflow_name: Optional[str] = None, **kwargs ) -> List[WorkflowCheckpoint]: """List checkpoint objects. - If workflow_id is provided, filter by that workflow. + If workflow_name is provided, filter by that workflow. - :param workflow_id: Optional workflow identifier for filtering. 
- :type workflow_id: Optional[str] + :param workflow_name: Optional workflow name for filtering. + :type workflow_name: Optional[str] :return: List of workflow checkpoints. :rtype: List[WorkflowCheckpoint] """ - ids = await self.list_checkpoint_ids(workflow_id=None) + if workflow_name is None: + workflow_name = kwargs.get("workflow_id") + + ids = await self.list_checkpoint_ids(workflow_name=None) checkpoints: List[WorkflowCheckpoint] = [] for checkpoint_id in ids: checkpoint = await self.load_checkpoint(checkpoint_id) if checkpoint is not None: - if workflow_id is None or checkpoint.workflow_id == workflow_id: + if workflow_name is None or self._checkpoint_workflow_name(checkpoint) == workflow_name: checkpoints.append(checkpoint) return checkpoints + def _checkpoint_workflow_name(self, checkpoint: WorkflowCheckpoint) -> Optional[str]: + workflow_name = getattr(checkpoint, "workflow_name", None) + if workflow_name: + return workflow_name + return getattr(checkpoint, "workflow_id", None) + async def delete_checkpoint(self, checkpoint_id: str) -> bool: """Delete a checkpoint by ID. 
diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/persistence/_foundry_conversation_message_store.py b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/persistence/_foundry_conversation_message_store.py index bd04df539faf..b796b7cd0804 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/persistence/_foundry_conversation_message_store.py +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/persistence/_foundry_conversation_message_store.py @@ -4,22 +4,23 @@ from collections.abc import MutableMapping, Sequence from typing import Any, List, Optional -from agent_framework import ChatMessage -from azure.ai.projects import AIProjectClient +from agent_framework import Message from azure.ai.agentserver.core.logger import get_logger +from azure.ai.projects import AIProjectClient from ..models.conversation_converters import ConversationItemConverter from ..models.human_in_the_loop_helper import HumanInTheLoopHelper logger = get_logger() + class FoundryConversationMessageStore: - """A ChatMessageStoreProtocol implementation that reads messages from Azure AI Foundry Conversations API. + """Message store that reads messages from Azure AI Foundry Conversations API. This message store fetches messages from the Foundry Conversations API and converts them - to ChatMessage format. Messages added via add_messages() are cached locally but not - persisted back to the API. + to Agent Framework ``Message`` objects. Messages added via add_messages() are cached + locally but not persisted back to the API. :param conversation_id: The conversation ID to fetch messages from. 
:type conversation_id: str @@ -30,73 +31,39 @@ class FoundryConversationMessageStore: def __init__( self, conversation_id: str, - project_client: AIProjectClient + project_client: AIProjectClient, ) -> None: - """Initialize the FoundryConversationMessageStore. - - :param conversation_id: The conversation ID to fetch messages from. - :type conversation_id: str - :param project_client: The Azure AI Project client used to retrieve conversation history. - :type project_client: AIProjectClient - """ self._conversation_id = conversation_id self._project_client = project_client - self._retrieved_messages: list[ChatMessage] = [] - self._cached_messages: list[ChatMessage] = [] - - - async def list_messages(self) -> list[ChatMessage]: - """Get all messages from the conversation, including cached messages. - - Fetches messages from the Foundry Conversations API, converts them to ChatMessage format, - and combines them with any locally cached messages. + self._retrieved_messages: list[Message] = [] + self._cached_messages: list[Message] = [] - :return: List of ChatMessage objects, ordered from oldest to newest. - :rtype: list[ChatMessage] - """ + async def list_messages(self) -> list[Message]: + """Get all messages from the conversation, including cached messages.""" return self._retrieved_messages + self._cached_messages - async def add_messages(self, messages: Sequence[ChatMessage]) -> None: - """Add messages to the local cache. - - Messages are cached locally but not persisted to the API. - - :param messages: The sequence of ChatMessage objects to add. 
- :type messages: Sequence[ChatMessage] - """ + async def add_messages(self, messages: Sequence[Message]) -> None: + """Add messages to the local cache.""" self._cached_messages.extend(messages) @classmethod - async def deserialize( # pylint: disable=unused-argument + async def deserialize( # pylint: disable=unused-argument cls, serialized_store_state: MutableMapping[str, Any], *, project_client: Optional[AIProjectClient] = None, **kwargs: Any, ) -> "FoundryConversationMessageStore": - """Create a new FoundryConversationMessageStore instance from serialized state. - - :param serialized_store_state: The serialized state data. - :type serialized_store_state: MutableMapping[str, Any] - :keyword project_client: The AIProjectClient instance to use for API interactions. - :paramtype project_client: Optional[AIProjectClient] - :return: A new FoundryConversationMessageStore instance. - :rtype: FoundryConversationMessageStore - :raises ValueError: If required parameters are missing. - """ - conversation_id = serialized_store_state.get("conversation_id") if not conversation_id: raise ValueError("conversation_id is required in serialized state") store = cls( conversation_id=conversation_id, - project_client=project_client + project_client=project_client, ) - # Restore cached messages await store.update_from_state(serialized_store_state) - return store async def update_from_state( # pylint: disable=unused-argument @@ -104,42 +71,30 @@ async def update_from_state( # pylint: disable=unused-argument serialized_store_state: MutableMapping[str, Any], **kwargs: Any, ) -> None: - """Update the current store instance from serialized state data. - - :param serialized_store_state: The serialized state data. 
- :type serialized_store_state: MutableMapping[str, Any] - """ if not serialized_store_state: return - # Update cached messages cached_messages_data = serialized_store_state.get("messages", []) self._cached_messages = [] for msg_data in cached_messages_data: if isinstance(msg_data, dict): - self._cached_messages.append(ChatMessage.from_dict(msg_data)) - elif isinstance(msg_data, ChatMessage): + self._cached_messages.append(Message.from_dict(msg_data)) + elif isinstance(msg_data, Message): self._cached_messages.append(msg_data) await self.retrieve_messages() async def serialize(self, **kwargs: Any) -> dict[str, Any]: # pylint: disable=unused-argument - """Serialize the current store state. - - :return: The serialized state data containing conversation_id and cached messages. - :rtype: dict[str, Any] - """ return { "conversation_id": self._conversation_id, "messages": [msg.to_dict() for msg in self._cached_messages], } - async def retrieve_messages(self): + async def retrieve_messages(self) -> None: history_messages = await self._get_conversation_history() filtered_messages = HumanInTheLoopHelper().remove_hitl_content_from_thread(history_messages or []) self._retrieved_messages = filtered_messages - async def _get_conversation_history(self) -> List[ChatMessage]: - # Retrieve conversation history from Foundry. 
+ async def _get_conversation_history(self) -> List[Message]: if not self._project_client: logger.error("AIProjectClient is not configured; cannot load conversation history.") return [] @@ -148,7 +103,7 @@ async def _get_conversation_history(self) -> List[ChatMessage]: converter = ConversationItemConverter() async with self._project_client.get_openai_client() as openai_client: raw_items = await openai_client.conversations.items.list(self._conversation_id) - retrieved_messages: list[ChatMessage] = [] + retrieved_messages: list[Message] = [] if raw_items is None: self._retrieved_messages = [] diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/persistence/_foundry_conversation_thread_repository.py b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/persistence/_foundry_conversation_thread_repository.py index ffaf66d92252..9d57bfb6d0b7 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/persistence/_foundry_conversation_thread_repository.py +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/persistence/_foundry_conversation_thread_repository.py @@ -3,24 +3,28 @@ # --------------------------------------------------------- from typing import Optional, Union -from agent_framework import AgentThread, AgentProtocol, WorkflowAgent -from azure.core.credentials_async import AsyncTokenCredential -from azure.core.credentials import TokenCredential -from azure.ai.projects.aio import AIProjectClient +from agent_framework import AgentSession, SupportsAgentRun, WorkflowAgent from azure.ai.agentserver.core.logger import get_logger +from azure.ai.projects.aio import AIProjectClient +from azure.core.credentials import TokenCredential +from azure.core.credentials_async import AsyncTokenCredential -from .agent_thread_repository import AgentThreadRepository from ._foundry_conversation_message_store import 
FoundryConversationMessageStore +from .agent_thread_repository import AgentThreadRepository logger = get_logger() + class FoundryConversationThreadRepository(AgentThreadRepository): """A Foundry Conversation implementation of AgentThreadRepository.""" - def __init__(self, - agent: Optional[Union[AgentProtocol, WorkflowAgent]], - project_endpoint: str, - credential: Union[TokenCredential, AsyncTokenCredential]) -> None: + + def __init__( + self, + agent: Optional[Union[SupportsAgentRun, WorkflowAgent]], + project_endpoint: str, + credential: Union[TokenCredential, AsyncTokenCredential], + ) -> None: self._agent = agent if not project_endpoint or not credential: raise ValueError( @@ -28,21 +32,21 @@ def __init__(self, "FoundryConversationThreadRepository." ) self._client = AIProjectClient(project_endpoint, credential) - self._inventory: dict[str, AgentThread] = {} + self._inventory: dict[str, AgentSession] = {} async def get( self, conversation_id: Optional[str], - agent: Optional[Union[AgentProtocol, WorkflowAgent]] = None, - ) -> Optional[AgentThread]: + agent: Optional[Union[SupportsAgentRun, WorkflowAgent]] = None, + ) -> Optional[AgentSession]: """Retrieve the saved thread for a given conversation ID. :param conversation_id: The conversation ID. :type conversation_id: Optional[str] :param agent: The agent instance. It will be used for in-memory repository for interface consistency. - :type agent: Optional[Union[AgentProtocol, WorkflowAgent]] - :return: The saved AgentThread if available, None otherwise. - :rtype: Optional[AgentThread] + :type agent: Optional[Union[SupportsAgentRun, WorkflowAgent]] + :return: The saved AgentSession if available, None otherwise. 
+ :rtype: Optional[AgentSession] """ if not conversation_id: return None @@ -54,22 +58,20 @@ async def get( self._inventory[conversation_id] = FoundryConversationThread(message_store=message_store) return self._inventory[conversation_id] - async def set(self, - conversation_id: Optional[str], - thread: AgentThread) -> None: + async def set(self, conversation_id: Optional[str], thread: AgentSession) -> None: """Save the thread for a given conversation ID. :param conversation_id: The conversation ID. :type conversation_id: Optional[str] :param thread: The thread to save. - :type thread: AgentThread + :type thread: AgentSession """ if not conversation_id: - raise ValueError("conversation_id is required to save an AgentThread.") + raise ValueError("conversation_id is required to save an AgentSession.") self._inventory[conversation_id] = thread -class FoundryConversationThread(AgentThread): +class FoundryConversationThread(AgentSession): @property def service_thread_id(self) -> str | None: return self._service_thread_id diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/persistence/agent_thread_repository.py b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/persistence/agent_thread_repository.py index 2a43dbb5aee8..7fa702d59c7c 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/persistence/agent_thread_repository.py +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/persistence/agent_thread_repository.py @@ -1,18 +1,18 @@ # --------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. 
# --------------------------------------------------------- -from abc import ABC, abstractmethod import json import os +from abc import ABC, abstractmethod from typing import Any, Optional, Union -from agent_framework import AgentThread, AgentProtocol, WorkflowAgent +from agent_framework import AgentSession, SupportsAgentRun, WorkflowAgent class AgentThreadRepository(ABC): """ - AgentThread repository to manage saved thread messages of agent threads and workflows. - + Repository to manage persisted agent session state for conversations and workflows. + :meta private: """ @@ -20,48 +20,49 @@ class AgentThreadRepository(ABC): async def get( self, conversation_id: Optional[str], - agent: Optional[Union[AgentProtocol, WorkflowAgent]] = None, - ) -> Optional[AgentThread]: + agent: Optional[Union[SupportsAgentRun, WorkflowAgent]] = None, + ) -> Optional[AgentSession]: """Retrieve the saved thread for a given conversation ID. :param conversation_id: The conversation ID. :type conversation_id: Optional[str] :param agent: The agent instance. If provided, it can be used to deserialize the thread. - :type agent: Optional[Union[AgentProtocol, WorkflowAgent]] + :type agent: Optional[Union[SupportsAgentRun, WorkflowAgent]] - :return: The saved AgentThread if available, None otherwise. - :rtype: Optional[AgentThread] + :return: The saved AgentSession if available, None otherwise. + :rtype: Optional[AgentSession] """ @abstractmethod - async def set(self, conversation_id: Optional[str], thread: AgentThread) -> None: + async def set(self, conversation_id: Optional[str], thread: AgentSession) -> None: """Save the thread for a given conversation ID. :param conversation_id: The conversation ID. :type conversation_id: Optional[str] :param thread: The thread to save. 
- :type thread: AgentThread + :type thread: AgentSession """ class InMemoryAgentThreadRepository(AgentThreadRepository): """In-memory implementation of AgentThreadRepository.""" + def __init__(self) -> None: - self._inventory: dict[str, AgentThread] = {} + self._inventory: dict[str, AgentSession] = {} async def get( self, conversation_id: Optional[str], - agent: Optional[Union[AgentProtocol, WorkflowAgent]] = None, - ) -> Optional[AgentThread]: + agent: Optional[Union[SupportsAgentRun, WorkflowAgent]] = None, + ) -> Optional[AgentSession]: """Retrieve the saved thread for a given conversation ID. :param conversation_id: The conversation ID. :type conversation_id: Optional[str] :param agent: The agent instance. It will be used for in-memory repository for interface consistency. - :type agent: Optional[Union[AgentProtocol, WorkflowAgent]] - :return: The saved AgentThread if available, None otherwise. - :rtype: Optional[AgentThread] + :type agent: Optional[Union[SupportsAgentRun, WorkflowAgent]] + :return: The saved AgentSession if available, None otherwise. + :rtype: Optional[AgentSession] """ if not conversation_id: return None @@ -69,13 +70,13 @@ async def get( return self._inventory[conversation_id] return None - async def set(self, conversation_id: Optional[str], thread: AgentThread) -> None: + async def set(self, conversation_id: Optional[str], thread: AgentSession) -> None: """Save the thread for a given conversation ID. :param conversation_id: The conversation ID. :type conversation_id: Optional[str] :param thread: The thread to save. 
- :type thread: AgentThread + :type thread: AgentSession """ if not conversation_id or not thread: return @@ -83,48 +84,49 @@ async def set(self, conversation_id: Optional[str], thread: AgentThread) -> None class SerializedAgentThreadRepository(AgentThreadRepository): - """Implementation of AgentThreadRepository with AgentThread serialization.""" - def __init__(self, agent: AgentProtocol) -> None: + """Implementation of AgentThreadRepository with AgentSession serialization.""" + + def __init__(self, agent: SupportsAgentRun) -> None: """ Initialize the repository with the given agent. :param agent: The agent instance. - :type agent: AgentProtocol + :type agent: SupportsAgentRun """ self._agent = agent async def get( self, conversation_id: Optional[str], - agent: Optional[Union[AgentProtocol, WorkflowAgent]] = None, - ) -> Optional[AgentThread]: + agent: Optional[Union[SupportsAgentRun, WorkflowAgent]] = None, + ) -> Optional[AgentSession]: """Retrieve the saved thread for a given conversation ID. :param conversation_id: The conversation ID. :type conversation_id: Optional[str] :param agent: The agent instance. If provided, it can be used to deserialize the thread. Otherwise, the repository's agent will be used. - :type agent: Optional[Union[AgentProtocol, WorkflowAgent]] + :type agent: Optional[Union[SupportsAgentRun, WorkflowAgent]] - :return: The saved AgentThread if available, None otherwise. - :rtype: Optional[AgentThread] + :return: The saved AgentSession if available, None otherwise. 
+ :rtype: Optional[AgentSession] """ if not conversation_id: return None serialized_thread = await self.read_from_storage(conversation_id) if serialized_thread: agent_to_use = agent or self._agent - thread = await agent_to_use.deserialize_thread(serialized_thread) + thread = await agent_to_use.deserialize_session(serialized_thread) return thread return None - async def set(self, conversation_id: Optional[str], thread: AgentThread) -> None: + async def set(self, conversation_id: Optional[str], thread: AgentSession) -> None: """Save the thread for a given conversation ID. :param conversation_id: The conversation ID. :type conversation_id: Optional[str] :param thread: The thread to save. - :type thread: AgentThread + :type thread: AgentSession """ if not conversation_id: return @@ -157,7 +159,8 @@ async def write_to_storage(self, conversation_id: Optional[str], serialized_thre class JsonLocalFileAgentThreadRepository(SerializedAgentThreadRepository): """Json based implementation of AgentThreadRepository using local file storage.""" - def __init__(self, agent: AgentProtocol, storage_path: str) -> None: + + def __init__(self, agent: SupportsAgentRun, storage_path: str) -> None: super().__init__(agent) self._storage_path = storage_path os.makedirs(self._storage_path, exist_ok=True) diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/persistence/checkpoint_repository.py b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/persistence/checkpoint_repository.py index 9848d01f6b10..7fa0c6d63670 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/persistence/checkpoint_repository.py +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/persistence/checkpoint_repository.py @@ -1,16 +1,17 @@ # --------------------------------------------------------- # Copyright (c) Microsoft Corporation. 
All rights reserved. # --------------------------------------------------------- -from abc import ABC, abstractmethod import os +from abc import ABC, abstractmethod from typing import Optional from agent_framework import ( CheckpointStorage, - InMemoryCheckpointStorage, FileCheckpointStorage, + InMemoryCheckpointStorage, ) + class CheckpointRepository(ABC): """ Repository interface for storing and retrieving checkpoints. diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/pyproject.toml b/sdk/agentserver/azure-ai-agentserver-agentframework/pyproject.toml index 88b4b4503e14..48870bf423c5 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/pyproject.toml +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/pyproject.toml @@ -22,10 +22,10 @@ keywords = ["azure", "azure sdk"] dependencies = [ "azure-ai-agentserver-core==1.0.0b14", - "agent-framework-azure-ai>=1.0.0b251112,<=1.0.0b260107", - "agent-framework-core>=1.0.0b251112,<=1.0.0b260107", + "agent-framework-azure-ai==1.0.0rc1", + "agent-framework-core==1.0.0rc1", "opentelemetry-exporter-otlp-proto-grpc>=1.36.0,<=1.39.0", - "opentelemetry-semantic-conventions-ai==0.4.13", # imported by af, will remove after af fixed breaking change + "opentelemetry-semantic-conventions-ai==0.4.13", "azure-ai-projects==2.0.0b3" ] diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/samples/basic_simple/minimal_example.py b/sdk/agentserver/azure-ai-agentserver-agentframework/samples/basic_simple/minimal_example.py index 2ea0f19dd32a..ca37d1db20fc 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/samples/basic_simple/minimal_example.py +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/samples/basic_simple/minimal_example.py @@ -6,10 +6,10 @@ from agent_framework.azure import AzureOpenAIChatClient from azure.identity import DefaultAzureCredential from dotenv import load_dotenv -load_dotenv() from azure.ai.agentserver.agentframework import from_agent_framework 
+load_dotenv() def get_weather( @@ -21,7 +21,7 @@ def get_weather( def main() -> None: - agent = AzureOpenAIChatClient(credential=DefaultAzureCredential()).create_agent( + agent = AzureOpenAIChatClient(credential=DefaultAzureCredential()).as_agent( instructions="You are a helpful weather agent.", tools=get_weather, ) diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/samples/chat_client_with_foundry_tool/README.md b/sdk/agentserver/azure-ai-agentserver-agentframework/samples/chat_client_with_foundry_tool/README.md index d9fe177e850f..af70891ff19b 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/samples/chat_client_with_foundry_tool/README.md +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/samples/chat_client_with_foundry_tool/README.md @@ -2,7 +2,7 @@ This sample demonstrates how to attach `FoundryToolsChatMiddleware` to an Agent Framework chat client so that: -- Foundry tools configured in your Azure AI Project are converted into Agent Framework `AIFunction` tools. +- Foundry tools configured in your Azure AI Project are converted into Agent Framework `FunctionTool` tools. - The tools are injected automatically for each agent run. 
## What this sample does @@ -61,7 +61,7 @@ agent = AzureOpenAIChatClient( middleware=FoundryToolsChatMiddleware( tools=[{"type": "web_search_preview"}, {"type": "mcp", "project_connection_id": tool_connection_id}], ), -).create_agent( +).as_agent( name="FoundryToolAgent", instructions="You are a helpful assistant with access to various tools.", ) diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/samples/chat_client_with_foundry_tool/chat_client_with_foundry_tool.py b/sdk/agentserver/azure-ai-agentserver-agentframework/samples/chat_client_with_foundry_tool/chat_client_with_foundry_tool.py index d8c75259d29b..e31c28d64f8b 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/samples/chat_client_with_foundry_tool/chat_client_with_foundry_tool.py +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/samples/chat_client_with_foundry_tool/chat_client_with_foundry_tool.py @@ -2,33 +2,37 @@ """Example showing how to use an agent factory function with ToolClient. This sample demonstrates how to pass a factory function to from_agent_framework -that receives a ToolClient and returns an AgentProtocol. This pattern allows -the agent to be created dynamically with access to tools from Azure AI Tool -Client at runtime. +that receives a ToolClient and returns a SupportsAgentRun-compatible instance. +This pattern allows the agent to be created dynamically with access to tools from +Azure AI Tool Client at runtime. 
""" import os -from dotenv import load_dotenv -from agent_framework.azure import AzureOpenAIChatClient -from azure.ai.agentserver.agentframework import from_agent_framework, FoundryToolsChatMiddleware +from agent_framework.azure import AzureOpenAIChatClient from azure.identity import DefaultAzureCredential +from dotenv import load_dotenv + +from azure.ai.agentserver.agentframework import FoundryToolsChatMiddleware, from_agent_framework load_dotenv() + def main(): tool_connection_id = os.getenv("AZURE_AI_PROJECT_TOOL_CONNECTION_ID") agent = AzureOpenAIChatClient( - credential=DefaultAzureCredential(), + credential=DefaultAzureCredential(), middleware=FoundryToolsChatMiddleware( tools=[{"type": "web_search_preview"}, {"type": "mcp", "project_connection_id": tool_connection_id}] - )).create_agent( - name="FoundryToolAgent", - instructions="You are a helpful assistant with access to various tools.", - ) + ), + ).as_agent( + name="FoundryToolAgent", + instructions="You are a helpful assistant with access to various tools.", + ) from_agent_framework(agent).run() + if __name__ == "__main__": main() diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/samples/human_in_the_loop_ai_function/README.md b/sdk/agentserver/azure-ai-agentserver-agentframework/samples/human_in_the_loop_ai_function/README.md index d16f5fea405c..10c92b837663 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/samples/human_in_the_loop_ai_function/README.md +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/samples/human_in_the_loop_ai_function/README.md @@ -22,7 +22,7 @@ AZURE_OPENAI_CHAT_DEPLOYMENT_NAME= ## Thread persistence -The sample uses `JsonLocalFileAgentThreadRepository` for `AgentThread` persistence, creating a JSON file per conversation ID under the sample directory. An in-memory alternative, `InMemoryAgentThreadRepository`, lives in the `azure.ai.agentserver.agentframework.persistence` module. 
+The sample uses `JsonLocalFileAgentThreadRepository` for `AgentSession` persistence, creating a JSON file per conversation ID under the sample directory. An in-memory alternative, `InMemoryAgentThreadRepository`, lives in the `azure.ai.agentserver.agentframework.persistence` module. To store thread messages elsewhere, inherit from `SerializedAgentThreadRepository` and override the following methods: diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/samples/human_in_the_loop_ai_function/main.py b/sdk/agentserver/azure-ai-agentserver-agentframework/samples/human_in_the_loop_ai_function/main.py index 56dc5fca8860..c3df20e9ecd8 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/samples/human_in_the_loop_ai_function/main.py +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/samples/human_in_the_loop_ai_function/main.py @@ -1,62 +1,25 @@ # Copyright (c) Microsoft. All rights reserved. import asyncio -from typing import Annotated, Any, Collection -from dotenv import load_dotenv - -load_dotenv() +from typing import Annotated -from agent_framework import ChatAgent, ChatMessage, ChatMessageStoreProtocol, ai_function -from agent_framework._threads import ChatMessageStoreState +from agent_framework import ChatAgent, tool from agent_framework.azure import AzureOpenAIChatClient +from dotenv import load_dotenv from azure.ai.agentserver.agentframework import from_agent_framework from azure.ai.agentserver.agentframework.persistence.agent_thread_repository import JsonLocalFileAgentThreadRepository """ -Tool Approvals with Threads +Tool Approvals with Sessions -This sample demonstrates using tool approvals with threads. -With threads, you don't need to manually pass previous messages - -the thread stores and retrieves them automatically. +This sample demonstrates using tool approvals with persisted sessions. """ -class CustomChatMessageStore(ChatMessageStoreProtocol): - """Implementation of custom chat message store. 
- In real applications, this can be an implementation of relational database or vector store.""" - - def __init__(self, messages: Collection[ChatMessage] | None = None) -> None: - self._messages: list[ChatMessage] = [] - if messages: - self._messages.extend(messages) - - async def add_messages(self, messages: Collection[ChatMessage]) -> None: - self._messages.extend(messages) - - async def list_messages(self) -> list[ChatMessage]: - return self._messages - - @classmethod - async def deserialize(cls, serialized_store_state: Any, **kwargs: Any) -> "CustomChatMessageStore": - """Create a new instance from serialized state.""" - store = cls() - await store.update_from_state(serialized_store_state, **kwargs) - return store - - async def update_from_state(self, serialized_store_state: Any, **kwargs: Any) -> None: - """Update this instance from serialized state.""" - if serialized_store_state: - state = ChatMessageStoreState.from_dict(serialized_store_state, **kwargs) - if state.messages: - self._messages.extend(state.messages) - - async def serialize(self, **kwargs: Any) -> Any: - """Serialize this store's state.""" - state = ChatMessageStoreState(messages=self._messages) - return state.to_dict(**kwargs) +load_dotenv() -@ai_function(approval_mode="always_require") +@tool(approval_mode="always_require") def add_to_calendar( event_name: Annotated[str, "Name of the event"], date: Annotated[str, "Date of the event"] ) -> str: @@ -65,13 +28,12 @@ def add_to_calendar( return f"Added '{event_name}' to calendar on {date}" -def build_agent(): +def build_agent() -> ChatAgent: return ChatAgent( chat_client=AzureOpenAIChatClient(), name="CalendarAgent", instructions="You are a helpful calendar assistant.", tools=[add_to_calendar], - chat_message_store_factory=CustomChatMessageStore, ) @@ -80,5 +42,6 @@ async def main() -> None: thread_repository = JsonLocalFileAgentThreadRepository(agent=agent, storage_path="./thread_storage") await from_agent_framework(agent, 
thread_repository=thread_repository).run_async() + if __name__ == "__main__": asyncio.run(main()) diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/samples/human_in_the_loop_workflow_agent/main.py b/sdk/agentserver/azure-ai-agentserver-agentframework/samples/human_in_the_loop_workflow_agent/main.py index cc89c941e65e..cad0550689db 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/samples/human_in_the_loop_workflow_agent/main.py +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/samples/human_in_the_loop_workflow_agent/main.py @@ -5,10 +5,6 @@ from dataclasses import dataclass from typing import Any -from agent_framework.azure import AzureOpenAIChatClient -from azure.identity import AzureCliCredential -from dotenv import load_dotenv - from agent_framework import ( # noqa: E402 Executor, WorkflowBuilder, @@ -16,6 +12,9 @@ handler, response_handler, ) +from agent_framework.azure import AzureOpenAIChatClient +from azure.identity import AzureCliCredential +from dotenv import load_dotenv from workflow_as_agent_reflection_pattern import ( # noqa: E402 ReviewRequest, ReviewResponse, @@ -27,6 +26,7 @@ load_dotenv() + @dataclass class HumanReviewRequest: """A request message type for escalation to a human reviewer.""" @@ -34,7 +34,6 @@ class HumanReviewRequest: agent_request: ReviewRequest | None = None def convert_to_payload(self) -> str: - """Convert the HumanReviewRequest to a payload string.""" request = self.agent_request payload: dict[str, Any] = {"agent_request": None} @@ -58,13 +57,9 @@ def __init__(self, worker_id: str, reviewer_id: str | None = None) -> None: @handler async def review(self, request: ReviewRequest, ctx: WorkflowContext) -> None: - # In this simplified example, we always escalate to a human manager. - # See workflow_as_agent_reflection.py for an implementation - # using an automated agent to make the review decision. 
print(f"Reviewer: Evaluating response for request {request.request_id[:8]}...") print("Reviewer: Escalating to human manager...") - # Forward the request to a human manager by sending a HumanReviewRequest. await ctx.request_info( request_data=HumanReviewRequest(agent_request=request), response_type=ReviewResponse, @@ -77,43 +72,29 @@ async def accept_human_review( response: ReviewResponse, ctx: WorkflowContext[ReviewResponse], ) -> None: - # Accept the human review response and forward it back to the Worker. print(f"Reviewer: Accepting human review for request {response.request_id[:8]}...") print(f"Reviewer: Human feedback: {response.feedback}") print(f"Reviewer: Human approved: {response.approved}") print("Reviewer: Forwarding human review back to worker...") await ctx.send_message(response, target_id=self._worker_id) -def create_builder(): - # Build a workflow with bidirectional communication between Worker and Reviewer, - # and escalation paths for human review. - builder = ( - WorkflowBuilder() - .register_executor( - lambda: Worker( - id="sub-worker", - chat_client=AzureOpenAIChatClient(credential=AzureCliCredential()), - ), - name="worker", - ) - .register_executor( - lambda: ReviewerWithHumanInTheLoop(worker_id="sub-worker"), - name="reviewer", - ) - .add_edge("worker", "reviewer") # Worker sends requests to Reviewer - .add_edge("reviewer", "worker") # Reviewer sends feedback to Worker - .set_start_executor("worker") + +def create_builder() -> WorkflowBuilder: + worker = Worker( + id="sub-worker", + chat_client=AzureOpenAIChatClient(credential=AzureCliCredential()), ) - return builder + reviewer = ReviewerWithHumanInTheLoop(worker_id="sub-worker") + return WorkflowBuilder(start_executor=worker).add_edge(worker, reviewer).add_edge(reviewer, worker) async def run_agent() -> None: - """Run the workflow inside the agent server adapter.""" builder = create_builder() await from_agent_framework( - builder, # pass workflow builder to adapter - 
checkpoint_repository=FileCheckpointRepository(storage_path="./checkpoints"), # for checkpoint storage + builder, + checkpoint_repository=FileCheckpointRepository(storage_path="./checkpoints"), ).run_async() + if __name__ == "__main__": - asyncio.run(run_agent()) \ No newline at end of file + asyncio.run(run_agent()) diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/samples/human_in_the_loop_workflow_agent/workflow_as_agent_reflection_pattern.py b/sdk/agentserver/azure-ai-agentserver-agentframework/samples/human_in_the_loop_workflow_agent/workflow_as_agent_reflection_pattern.py index ef2a286ba174..0587f79acbbd 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/samples/human_in_the_loop_workflow_agent/workflow_as_agent_reflection_pattern.py +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/samples/human_in_the_loop_workflow_agent/workflow_as_agent_reflection_pattern.py @@ -1,28 +1,28 @@ # Copyright (c) Microsoft. All rights reserved. -from dataclasses import dataclass import json +from dataclasses import dataclass from uuid import uuid4 from agent_framework import ( - AgentRunResponseUpdate, - AgentRunUpdateEvent, - ChatClientProtocol, - ChatMessage, - Contents, + AgentResponseUpdate, + Content, Executor, - Role, + Message, + SupportsChatGetResponse, WorkflowContext, + WorkflowEvent, handler, ) + @dataclass class ReviewRequest: """Structured request passed from Worker to Reviewer for evaluation.""" request_id: str - user_messages: list[ChatMessage] - agent_messages: list[ChatMessage] + user_messages: list[Message] + agent_messages: list[Message] @dataclass @@ -35,7 +35,6 @@ class ReviewResponse: @staticmethod def convert_from_payload(payload: str) -> "ReviewResponse": - """Convert a JSON payload string to a ReviewResponse instance.""" data = json.loads(payload) return ReviewResponse( request_id=data["request_id"], @@ -44,38 +43,34 @@ def convert_from_payload(payload: str) -> "ReviewResponse": ) -PendingReviewState = 
tuple[ReviewRequest, list[ChatMessage]] +PendingReviewState = tuple[ReviewRequest, list[Message]] class Worker(Executor): """Executor that generates responses and incorporates feedback when necessary.""" - def __init__(self, id: str, chat_client: ChatClientProtocol) -> None: + def __init__(self, id: str, chat_client: SupportsChatGetResponse) -> None: super().__init__(id=id) self._chat_client = chat_client self._pending_requests: dict[str, PendingReviewState] = {} @handler - async def handle_user_messages(self, user_messages: list[ChatMessage], ctx: WorkflowContext[ReviewRequest]) -> None: + async def handle_user_messages(self, user_messages: list[Message], ctx: WorkflowContext[ReviewRequest]) -> None: print("Worker: Received user messages, generating response...") - # Initialize chat with system prompt. - messages = [ChatMessage(role=Role.SYSTEM, text="You are a helpful assistant.")] + messages = [Message(role="system", contents=[Content.from_text("You are a helpful assistant.")])] messages.extend(user_messages) print("Worker: Calling LLM to generate response...") response = await self._chat_client.get_response(messages=messages) print(f"Worker: Response generated: {response.messages[-1].text}") - # Add agent messages to context. messages.extend(response.messages) - # Create review request and send to Reviewer. request = ReviewRequest(request_id=str(uuid4()), user_messages=user_messages, agent_messages=response.messages) print(f"Worker: Sending response for review (ID: {request.request_id[:8]})") await ctx.send_message(request) - # Track request for possible retry. self._pending_requests[request.request_id] = (request, messages) @handler @@ -89,51 +84,46 @@ async def handle_review_response(self, review: ReviewResponse, ctx: WorkflowCont if review.approved: print("Worker: Response approved. 
Emitting to external consumer...") - contents: list[Contents] = [] + contents: list[Content] = [] for message in request.agent_messages: contents.extend(message.contents) - # Emit approved result to external consumer via AgentRunUpdateEvent. await ctx.add_event( - AgentRunUpdateEvent(self.id, data=AgentRunResponseUpdate(contents=contents, role=Role.ASSISTANT)) + WorkflowEvent( + "output", + data=AgentResponseUpdate(contents=contents, role="assistant"), + ) ) return print(f"Worker: Response not approved. Feedback: {review.feedback}") print("Worker: Regenerating response with feedback...") - # Incorporate review feedback. - messages.append(ChatMessage(role=Role.SYSTEM, text=review.feedback)) + messages.append(Message(role="system", contents=[Content.from_text(review.feedback)])) messages.append( - ChatMessage(role=Role.SYSTEM, text="Please incorporate the feedback and regenerate the response.") + Message( + role="system", + contents=[Content.from_text("Please incorporate the feedback and regenerate the response.")], + ) ) messages.extend(request.user_messages) - # Retry with updated prompt. response = await self._chat_client.get_response(messages=messages) print(f"Worker: New response generated: {response.messages[-1].text}") messages.extend(response.messages) - # Send updated request for re-review. new_request = ReviewRequest( - request_id=review.request_id, user_messages=request.user_messages, agent_messages=response.messages + request_id=review.request_id, + user_messages=request.user_messages, + agent_messages=response.messages, ) await ctx.send_message(new_request) - # Track new request for further evaluation. self._pending_requests[new_request.request_id] = (new_request, messages) async def on_checkpoint_save(self) -> dict: - """ - Persist pending requests during checkpointing. - In memory implementation for demonstration purposes. 
- """ return {"pending_requests": self._pending_requests} async def on_checkpoint_restore(self, data: dict) -> None: - """ - Load pending requests from checkpoint data. - In memory implementation for demonstration purposes. - """ - self._pending_requests = data.get("pending_requests", {}) \ No newline at end of file + self._pending_requests = data.get("pending_requests", {}) diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/samples/mcp_apikey/mcp_apikey.py b/sdk/agentserver/azure-ai-agentserver-agentframework/samples/mcp_apikey/mcp_apikey.py index 985d7fd01e0c..00aca8d68522 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/samples/mcp_apikey/mcp_apikey.py +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/samples/mcp_apikey/mcp_apikey.py @@ -23,15 +23,15 @@ async def main() -> None: "GITHUB_TOKEN environment variable not set. Provide a GitHub token with MCP access." ) - agent = AzureOpenAIChatClient(credential=DefaultAzureCredential()).create_agent( + agent = AzureOpenAIChatClient(credential=DefaultAzureCredential()).as_agent( instructions="You are a helpful assistant that answers GitHub questions. 
Use only the exposed MCP tools.", - tools=MCPStreamableHTTPTool( + tools=[MCPStreamableHTTPTool( # type: ignore[list-item] name=MCP_TOOL_NAME, url=MCP_TOOL_URL, headers={ "Authorization": f"Bearer {github_token}", }, - ), + )], ) async with agent: diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/samples/mcp_simple/mcp_simple.py b/sdk/agentserver/azure-ai-agentserver-agentframework/samples/mcp_simple/mcp_simple.py index 6b59771fe0da..7fcc914816b5 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/samples/mcp_simple/mcp_simple.py +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/samples/mcp_simple/mcp_simple.py @@ -16,9 +16,9 @@ async def main() -> None: - agent = AzureOpenAIChatClient(credential=DefaultAzureCredential()).create_agent( + agent = AzureOpenAIChatClient(credential=DefaultAzureCredential()).as_agent( instructions="You are a helpful assistant that answers Microsoft documentation questions.", - tools=MCPStreamableHTTPTool(name=MCP_TOOL_NAME, url=MCP_TOOL_URL), + tools=[MCPStreamableHTTPTool(name=MCP_TOOL_NAME, url=MCP_TOOL_URL)], # type: ignore[list-item] ) async with agent: diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/samples/simple_async/minimal_async_example.py b/sdk/agentserver/azure-ai-agentserver-agentframework/samples/simple_async/minimal_async_example.py index 4c69c8afa84d..74c1decfb997 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/samples/simple_async/minimal_async_example.py +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/samples/simple_async/minimal_async_example.py @@ -22,7 +22,7 @@ def get_weather( async def main() -> None: - agent = AzureOpenAIChatClient(credential=DefaultAzureCredential()).create_agent( + agent = AzureOpenAIChatClient(credential=DefaultAzureCredential()).as_agent( instructions="You are a helpful weather agent.", tools=get_weather, ) diff --git 
a/sdk/agentserver/azure-ai-agentserver-agentframework/samples/workflow_agent_simple/workflow_agent_simple.py b/sdk/agentserver/azure-ai-agentserver-agentframework/samples/workflow_agent_simple/workflow_agent_simple.py index 5de214c9ff09..5373e4fe9904 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/samples/workflow_agent_simple/workflow_agent_simple.py +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/samples/workflow_agent_simple/workflow_agent_simple.py @@ -1,51 +1,293 @@ +# Copyright (c) Microsoft. All rights reserved. + import asyncio +from dataclasses import dataclass +from uuid import uuid4 + +from agent_framework import ( + AgentResponseUpdate, + Content, + Executor, + Message, + SupportsChatGetResponse, + WorkflowBuilder, + WorkflowContext, + WorkflowEvent, + handler, +) +from agent_framework_azure_ai import AzureAIAgentClient +from azure.identity.aio import DefaultAzureCredential from dotenv import load_dotenv +from pydantic import BaseModel + +from azure.ai.agentserver.agentframework import from_agent_framework + +""" +The following sample demonstrates how to wrap a workflow as an agent using WorkflowAgent. + +This sample shows how to: +1. Create a workflow with a reflection pattern (Worker + Reviewer executors) +2. Wrap the workflow as an agent using the .as_agent() method +3. Stream responses from the workflow agent like a regular agent +4. 
Implement a review-retry mechanism where responses are iteratively improved + +The example implements a quality-controlled AI assistant where: +- Worker executor generates responses to user queries +- Reviewer executor evaluates the responses and provides feedback +- If not approved, the Worker incorporates feedback and regenerates the response +- The cycle continues until the response is approved +- Only approved responses are emitted to the external consumer + +Key concepts demonstrated: +- WorkflowAgent: Wraps a workflow to make it behave as an agent +- Bidirectional workflow with cycles (Worker ↔ Reviewer) +- WorkflowEvent: How workflows communicate with external consumers +- Structured output parsing for review feedback +- State management with pending requests tracking +""" + + +@dataclass +class ReviewRequest: + request_id: str + user_messages: list[Message] + agent_messages: list[Message] + + +@dataclass +class ReviewResponse: + request_id: str + feedback: str + approved: bool + load_dotenv() -from agent_framework import ChatAgent, Workflow, WorkflowBuilder -from agent_framework.azure import AzureAIAgentClient -from azure.identity.aio import AzureCliCredential -from azure.ai.agentserver.agentframework import from_agent_framework +class Reviewer(Executor): + """An executor that reviews messages and provides feedback.""" -def create_writer_agent(client: AzureAIAgentClient) -> ChatAgent: - return client.create_agent( - name="Writer", - instructions=( - "You are an excellent content writer. You create new content and edit contents based on the feedback." - ), - ) + def __init__(self, chat_client: SupportsChatGetResponse) -> None: + super().__init__(id="reviewer") + self._chat_client = chat_client + + @handler + async def review( + self, request: ReviewRequest, ctx: WorkflowContext[ReviewResponse] + ) -> None: + print( + f"🔍 Reviewer: Evaluating response for request {request.request_id[:8]}..." 
+        )
+
+        # Use the chat client to review the message and use structured output.
+        # NOTE: this can be modified to use an evaluation framework.
+
+        class _Response(BaseModel):
+            feedback: str
+            approved: bool
+
+        # Define the system prompt.
+        messages = [
+            Message(
+                role="system",
+                contents=[Content.from_text(
+                    "You are a reviewer for an AI agent, please provide feedback on the "
+                    "following exchange between a user and the AI agent, "
+                    "and indicate if the agent's responses are approved or not.\n"
+                    "Use the following criteria for your evaluation:\n"
+                    "- Relevance: Does the response address the user's query?\n"
+                    "- Accuracy: Is the information provided correct?\n"
+                    "- Clarity: Is the response easy to understand?\n"
+                    "- Completeness: Does the response cover all aspects of the query?\n"
+                    "Be critical in your evaluation and provide constructive feedback.\n"
+                    "Do not approve until all criteria are met."
+                )],
+            )
+        ]
+
+        # Add the user messages to the chat history.
+        messages.extend(request.user_messages)
+
+        # Add agent messages to the chat history.
+        messages.extend(request.agent_messages)
+
+        # Add one more instruction for the assistant to follow.
+        messages.append(
+            Message(
+                role="user",
+                contents=[Content.from_text("Please provide a review of the agent's responses to the user.")],
+            )
+        )
+
+        print("🔍 Reviewer: Sending review request to LLM...")
+        # Get the response from the chat client.
+        response = await self._chat_client.get_response(
+            messages=messages, options={"response_format": _Response}
+        )
+
+        # Parse the response.
+        parsed = _Response.model_validate_json(response.messages[-1].text)
+
+        print(f"🔍 Reviewer: Review complete - Approved: {parsed.approved}")
+        print(f"🔍 Reviewer: Feedback: {parsed.feedback}")
+
+        # Send the review response.
+ await ctx.send_message( + ReviewResponse( + request_id=request.request_id, + feedback=parsed.feedback, + approved=parsed.approved, + ) + ) + + +class Worker(Executor): + """An executor that performs tasks for the user.""" + + def __init__(self, chat_client: SupportsChatGetResponse) -> None: + super().__init__(id="worker") + self._chat_client = chat_client + self._pending_requests: dict[str, tuple[ReviewRequest, list[Message]]] = {} + + @handler + async def handle_user_messages( + self, user_messages: list[Message], ctx: WorkflowContext[ReviewRequest] + ) -> None: + print("🔧 Worker: Received user messages, generating response...") + + # Handle user messages and prepare a review request for the reviewer. + # Define the system prompt. + messages = [ + Message(role="system", contents=[Content.from_text("You are a helpful assistant.")]) + ] + + # Add user messages. + messages.extend(user_messages) + + print("🔧 Worker: Calling LLM to generate response...") + # Get the response from the chat client. + response = await self._chat_client.get_response(messages=messages) + print(f"🔧 Worker: Response generated: {response.messages[-1].text}") + + # Add agent messages. + messages.extend(response.messages) + + # Create the review request. + request = ReviewRequest( + request_id=str(uuid4()), + user_messages=user_messages, + agent_messages=response.messages, + ) + + print( + f"🔧 Worker: Generated response, sending to reviewer (ID: {request.request_id[:8]})" + ) + # Send the review request. + await ctx.send_message(request) + + # Add to pending requests. + self._pending_requests[request.request_id] = (request, messages) + + @handler + async def handle_review_response( + self, review: ReviewResponse, ctx: WorkflowContext[ReviewRequest] + ) -> None: + print( + f"🔧 Worker: Received review for request {review.request_id[:8]} - Approved: {review.approved}" + ) + + # Handle the review response. 
Depending on the approval status, + # either emit the approved response as a WorkflowEvent, or + # retry given the feedback. + if review.request_id not in self._pending_requests: + raise ValueError( + f"Received review response for unknown request ID: {review.request_id}" + ) + # Remove the request from pending requests. + request, messages = self._pending_requests.pop(review.request_id) + + if review.approved: + print("✅ Worker: Response approved! Emitting to external consumer...") + # If approved, emit the agent response update to the workflow's + # external consumer. + contents: list[Content] = [] + for message in request.agent_messages: + contents.extend(message.contents) + # Emitting a WorkflowEvent in a workflow wrapped by a WorkflowAgent + # will send the AgentResponseUpdate to the WorkflowAgent's + # event stream. + await ctx.add_event( + WorkflowEvent( + "output", + data=AgentResponseUpdate( + contents=contents, role="assistant" + ), + ) + ) + return + + print(f"❌ Worker: Response not approved. Feedback: {review.feedback}") + print("🔧 Worker: Incorporating feedback and regenerating response...") + + # Construct new messages with feedback. + messages.append(Message(role="system", contents=[Content.from_text(review.feedback)])) + + # Add additional instruction to address the feedback. + messages.append( + Message( + role="system", + contents=[Content.from_text( + "Please incorporate the feedback above, and provide a response to user's next message." + )], + ) + ) + messages.extend(request.user_messages) + + # Get the new response from the chat client. + response = await self._chat_client.get_response(messages=messages) + print( + f"🔧 Worker: New response generated after feedback: {response.messages[-1].text}" + ) + + # Process the response. + messages.extend(response.messages) + + print( + f"🔧 Worker: Generated improved response, sending for re-review (ID: {review.request_id[:8]})" + ) + # Send an updated review request. 
+ new_request = ReviewRequest( + request_id=review.request_id, + user_messages=request.user_messages, + agent_messages=response.messages, + ) + await ctx.send_message(new_request) + + # Add to pending requests. + self._pending_requests[new_request.request_id] = (new_request, messages) -def create_reviewer_agent(client: AzureAIAgentClient) -> ChatAgent: - return client.create_agent( - name="Reviewer", - instructions=( - "You are an excellent content reviewer. " - "Provide actionable feedback to the writer about the provided content. " - "Provide the feedback in the most concise manner possible." - ), +def build_agent(chat_client: SupportsChatGetResponse): + reviewer = Reviewer(chat_client=chat_client) + worker = Worker(chat_client=chat_client) + return ( + WorkflowBuilder(start_executor=worker) + .add_edge( + worker, reviewer + ) # <--- This edge allows the worker to send requests to the reviewer + .add_edge( + reviewer, worker + ) # <--- This edge allows the reviewer to send feedback back to the worker + .build() + .as_agent() # Convert the workflow to an agent. 
) async def main() -> None: - async with AzureCliCredential() as cred, AzureAIAgentClient(credential=cred) as client: - builder = ( - WorkflowBuilder() - .register_agent(lambda: create_writer_agent(client), name="writer") - .register_agent(lambda: create_reviewer_agent(client), name="reviewer", output_response=True) - .set_start_executor("writer") - .add_edge("writer", "reviewer") - ) - - # Pass the WorkflowBuilder to the adapter and run it - # await from_agent_framework(workflow=builder).run_async() - - # Or create a factory function for the workflow pass the workflow factory to the adapter - def workflow_factory() -> Workflow: - return builder.build() - await from_agent_framework(workflow_factory).run_async() + async with DefaultAzureCredential() as credential: + async with AzureAIAgentClient(async_credential=credential) as chat_client: + agent = build_agent(chat_client) + await from_agent_framework(agent).run_async() if __name__ == "__main__": diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/samples/workflow_with_foundry_checkpoints/main.py b/sdk/agentserver/azure-ai-agentserver-agentframework/samples/workflow_with_foundry_checkpoints/main.py index 586c31b8d4b7..f1d562b2668f 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/samples/workflow_with_foundry_checkpoints/main.py +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/samples/workflow_with_foundry_checkpoints/main.py @@ -4,25 +4,16 @@ Workflow Agent with Foundry Managed Checkpoints This sample demonstrates how to use FoundryCheckpointRepository with -a WorkflowBuilder agent to persist workflow checkpoints in Azure AI Foundry. - -Foundry managed checkpoints enable workflow state to be persisted across -requests, allowing workflows to be paused, resumed, and replayed. - -Prerequisites: - - Set AZURE_AI_PROJECT_ENDPOINT to your Azure AI Foundry project endpoint - e.g. "https://.services.ai.azure.com/api/projects/" - - Azure credentials configured (e.g. 
az login) +an Agent Framework workflow to persist workflow checkpoints in Azure AI Foundry. """ import asyncio import os -from dotenv import load_dotenv - -from agent_framework import ChatAgent, WorkflowBuilder -from agent_framework.azure import AzureAIAgentClient +from agent_framework import SupportsAgentRun, WorkflowBuilder +from agent_framework_azure_ai import AzureAIAgentClient from azure.identity.aio import AzureCliCredential +from dotenv import load_dotenv from azure.ai.agentserver.agentframework import from_agent_framework from azure.ai.agentserver.agentframework.persistence import FoundryCheckpointRepository @@ -30,9 +21,9 @@ load_dotenv() -def create_writer_agent(client: AzureAIAgentClient) -> ChatAgent: +def create_writer_agent(client: AzureAIAgentClient) -> SupportsAgentRun: """Create a writer agent that generates content.""" - return client.create_agent( + return client.as_agent( name="Writer", instructions=( "You are an excellent content writer. " @@ -41,9 +32,9 @@ def create_writer_agent(client: AzureAIAgentClient) -> ChatAgent: ) -def create_reviewer_agent(client: AzureAIAgentClient) -> ChatAgent: +def create_reviewer_agent(client: AzureAIAgentClient) -> SupportsAgentRun: """Create a reviewer agent that provides feedback.""" - return client.create_agent( + return client.as_agent( name="Reviewer", instructions=( "You are an excellent content reviewer. 
" @@ -54,28 +45,20 @@ def create_reviewer_agent(client: AzureAIAgentClient) -> ChatAgent: async def main() -> None: - """Run the workflow agent with Foundry managed checkpoints.""" project_endpoint = os.getenv("AZURE_AI_PROJECT_ENDPOINT", "") - async with AzureCliCredential() as cred, AzureAIAgentClient(credential=cred) as client: - builder = ( - WorkflowBuilder() - .register_agent(lambda: create_writer_agent(client), name="writer") - .register_agent(lambda: create_reviewer_agent(client), name="reviewer", output_response=True) - .set_start_executor("writer") - .add_edge("writer", "reviewer") - ) + async with AzureCliCredential() as cred, AzureAIAgentClient(async_credential=cred) as client: + writer = create_writer_agent(client) + reviewer = create_reviewer_agent(client) + workflow = WorkflowBuilder(start_executor=writer).add_edge(writer, reviewer).build() - # Use FoundryCheckpointRepository for Azure AI Foundry managed storage. - # This persists workflow checkpoints remotely, enabling pause/resume - # across requests and server restarts. checkpoint_repository = FoundryCheckpointRepository( project_endpoint=project_endpoint, credential=cred, ) await from_agent_framework( - builder, + workflow, checkpoint_repository=checkpoint_repository, ).run_async() diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/tests/unit_tests/conftest.py b/sdk/agentserver/azure-ai-agentserver-agentframework/tests/unit_tests/conftest.py index cd00c924c030..3475548cd2de 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/tests/unit_tests/conftest.py +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/tests/unit_tests/conftest.py @@ -4,12 +4,11 @@ # Workaround: importing agent_framework (via mcp) can fail with # KeyError: 'pydantic.root_model' unless this module is imported first. 
-import pydantic.root_model # noqa: F401 - import site import sys from pathlib import Path +import pydantic.root_model # noqa: F401 # Ensure we don't import user-site packages that can conflict with the active # environment (e.g., a user-installed cryptography wheel causing PyO3 errors). diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/tests/unit_tests/test_agent_framework_input_converter.py b/sdk/agentserver/azure-ai-agentserver-agentframework/tests/unit_tests/test_agent_framework_input_converter.py index d52d0e481bd2..8e4c8ea862c5 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/tests/unit_tests/test_agent_framework_input_converter.py +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/tests/unit_tests/test_agent_framework_input_converter.py @@ -1,45 +1,35 @@ import importlib import pytest - -from agent_framework import ChatMessage, Role as ChatRole +from agent_framework import Message converter_module = importlib.import_module( "azure.ai.agentserver.agentframework.models.agent_framework_input_converters" ) -AgentFrameworkInputConverter = converter_module.AgentFrameworkInputConverter - - -@pytest.fixture() -def converter() -> AgentFrameworkInputConverter: - return AgentFrameworkInputConverter() +transform_input = converter_module.transform_input @pytest.mark.unit -@pytest.mark.asyncio -async def test_transform_none_returns_none(converter: AgentFrameworkInputConverter) -> None: - assert await converter.transform_input(None) is None +def test_transform_none_returns_none() -> None: + assert transform_input(None) is None @pytest.mark.unit -@pytest.mark.asyncio -async def test_transform_string_returns_same(converter: AgentFrameworkInputConverter) -> None: - assert await converter.transform_input("hello") == "hello" +def test_transform_string_returns_same() -> None: + assert transform_input("hello") == "hello" @pytest.mark.unit -@pytest.mark.asyncio -async def test_transform_implicit_user_message_with_string(converter: 
AgentFrameworkInputConverter) -> None: +def test_transform_implicit_user_message_with_string() -> None: payload = [{"content": "How are you?"}] - result = await converter.transform_input(payload) + result = transform_input(payload) assert result == "How are you?" @pytest.mark.unit -@pytest.mark.asyncio -async def test_transform_implicit_user_message_with_input_text_list(converter: AgentFrameworkInputConverter) -> None: +def test_transform_implicit_user_message_with_input_text_list() -> None: payload = [ { "content": [ @@ -49,14 +39,13 @@ async def test_transform_implicit_user_message_with_input_text_list(converter: A } ] - result = await converter.transform_input(payload) + result = transform_input(payload) assert result == "Hello world" @pytest.mark.unit -@pytest.mark.asyncio -async def test_transform_explicit_message_returns_chat_message(converter: AgentFrameworkInputConverter) -> None: +def test_transform_explicit_message_returns_chat_message() -> None: payload = [ { "type": "message", @@ -67,16 +56,15 @@ async def test_transform_explicit_message_returns_chat_message(converter: AgentF } ] - result = await converter.transform_input(payload) + result = transform_input(payload) - assert isinstance(result, ChatMessage) - assert result.role == ChatRole.ASSISTANT + assert isinstance(result, Message) + assert result.role == "assistant" assert result.text == "Hi there" @pytest.mark.unit -@pytest.mark.asyncio -async def test_transform_multiple_explicit_messages_returns_list(converter: AgentFrameworkInputConverter) -> None: +def test_transform_multiple_explicit_messages_returns_list() -> None: payload = [ { "type": "message", @@ -92,20 +80,19 @@ async def test_transform_multiple_explicit_messages_returns_list(converter: Agen }, ] - result = await converter.transform_input(payload) + result = transform_input(payload) assert isinstance(result, list) assert len(result) == 2 - assert all(isinstance(item, ChatMessage) for item in result) - assert result[0].role == 
ChatRole.USER + assert all(isinstance(item, Message) for item in result) + assert result[0].role == "user" assert result[0].text == "Hello" - assert result[1].role == ChatRole.ASSISTANT + assert result[1].role == "assistant" assert result[1].text == "Greetings" @pytest.mark.unit -@pytest.mark.asyncio -async def test_transform_mixed_messages_coerces_to_strings(converter: AgentFrameworkInputConverter) -> None: +def test_transform_mixed_messages_coerces_to_strings() -> None: payload = [ {"content": "First"}, { @@ -117,23 +104,21 @@ async def test_transform_mixed_messages_coerces_to_strings(converter: AgentFrame }, ] - result = await converter.transform_input(payload) + result = transform_input(payload) assert result == ["First", "Second"] @pytest.mark.unit -@pytest.mark.asyncio -async def test_transform_invalid_input_type_raises(converter: AgentFrameworkInputConverter) -> None: +def test_transform_invalid_input_type_raises() -> None: with pytest.raises(Exception) as exc_info: - await converter.transform_input({"content": "invalid"}) + transform_input({"content": "invalid"}) assert "Unsupported input type" in str(exc_info.value) @pytest.mark.unit -@pytest.mark.asyncio -async def test_transform_skips_non_text_entries(converter: AgentFrameworkInputConverter) -> None: +def test_transform_skips_non_text_entries() -> None: payload = [ { "content": [ @@ -143,6 +128,6 @@ async def test_transform_skips_non_text_entries(converter: AgentFrameworkInputCo } ] - result = await converter.transform_input(payload) + result = transform_input(payload) assert result is None diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/tests/unit_tests/test_conversation_id_optional.py b/sdk/agentserver/azure-ai-agentserver-agentframework/tests/unit_tests/test_conversation_id_optional.py index 85fdc34f6498..4c9b68f2132a 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/tests/unit_tests/test_conversation_id_optional.py +++ 
b/sdk/agentserver/azure-ai-agentserver-agentframework/tests/unit_tests/test_conversation_id_optional.py @@ -4,7 +4,7 @@ from unittest.mock import Mock import pytest -from agent_framework import AgentThread, InMemoryCheckpointStorage +from agent_framework import AgentSession, InMemoryCheckpointStorage from azure.ai.agentserver.agentframework.persistence.agent_thread_repository import ( InMemoryAgentThreadRepository, @@ -18,7 +18,7 @@ @pytest.mark.asyncio async def test_inmemory_thread_repository_ignores_missing_conversation_id() -> None: repo = InMemoryAgentThreadRepository() - thread = Mock(spec=AgentThread) + thread = Mock(spec=AgentSession) await repo.set(None, thread) assert await repo.get(None) is None diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/tests/unit_tests/test_conversation_item_converter.py b/sdk/agentserver/azure-ai-agentserver-agentframework/tests/unit_tests/test_conversation_item_converter.py index f61ca0cfd9f2..9a98bf3576cb 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/tests/unit_tests/test_conversation_item_converter.py +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/tests/unit_tests/test_conversation_item_converter.py @@ -2,15 +2,12 @@ # Copyright (c) Microsoft Corporation. All rights reserved. 
# --------------------------------------------------------- import pytest - -from agent_framework import FunctionCallContent, FunctionResultContent, Role as ChatRole -from agent_framework._types import TextContent, TextReasoningContent - +from agent_framework import Message as AFMessage from openai.types.conversations.message import Message -from openai.types.responses.response_input_text import ResponseInputText +from openai.types.responses import response_reasoning_item from openai.types.responses.response_function_tool_call_item import ResponseFunctionToolCallItem from openai.types.responses.response_function_tool_call_output_item import ResponseFunctionToolCallOutputItem -from openai.types.responses import response_reasoning_item +from openai.types.responses.response_input_text import ResponseInputText from azure.ai.agentserver.agentframework.models.conversation_converters import ConversationItemConverter @@ -33,10 +30,10 @@ def test_to_chat_message_converts_basic_message(converter: ConversationItemConve result = converter.to_chat_message(item) assert result is not None - assert result.role == ChatRole.USER - assert result.text is not None and "Hello world" in result.text + assert isinstance(result, AFMessage) + assert result.role == "user" assert result.contents is not None - assert any(isinstance(content, TextContent) for content in result.contents) + assert any(content.type == "text" for content in result.contents) @pytest.mark.unit @@ -53,11 +50,11 @@ def test_to_chat_message_converts_function_call_item(converter: ConversationItem result = converter.to_chat_message(item) assert result is not None - assert result.role == ChatRole.ASSISTANT + assert result.role == "assistant" assert result.contents is not None assert len(result.contents) == 1 content = result.contents[0] - assert isinstance(content, FunctionCallContent) + assert content.type == "function_call" assert content.call_id == "call_123" assert content.name == "do_something" assert 
isinstance(content.arguments, dict) @@ -77,11 +74,11 @@ def test_to_chat_message_converts_function_result_item(converter: ConversationIt result = converter.to_chat_message(item) assert result is not None - assert result.role == ChatRole.TOOL + assert result.role == "tool" assert result.contents is not None assert len(result.contents) == 1 content = result.contents[0] - assert isinstance(content, FunctionResultContent) + assert content.type == "function_result" assert content.call_id == "call_456" assert content.result == {"answer": 42} @@ -99,12 +96,10 @@ def test_to_chat_message_converts_reasoning_item(converter: ConversationItemConv result = converter.to_chat_message(reasoning_item) assert result is not None - assert result.role == ChatRole.ASSISTANT + assert result.role == "assistant" assert result.text == "High-level summary" assert result.contents is not None - assert any(isinstance(content, TextReasoningContent) for content in result.contents) - text_reasoning = next(content for content in result.contents if isinstance(content, TextReasoningContent)) - assert text_reasoning.text == "Chain-of-thought" + assert any(content.type in {"text_reasoning", "reasoning"} for content in result.contents) @pytest.mark.unit diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/tests/unit_tests/test_foundry_checkpoint_storage.py b/sdk/agentserver/azure-ai-agentserver-agentframework/tests/unit_tests/test_foundry_checkpoint_storage.py index fb1b8d9c0e30..03c1e7846755 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/tests/unit_tests/test_foundry_checkpoint_storage.py +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/tests/unit_tests/test_foundry_checkpoint_storage.py @@ -11,6 +11,15 @@ from .mocks import MockFoundryCheckpointClient +def _checkpoint(checkpoint_id: str, workflow_name: str, **kwargs) -> WorkflowCheckpoint: + return WorkflowCheckpoint( + checkpoint_id=checkpoint_id, + workflow_name=workflow_name, + 
graph_signature_hash="test-graph", + **kwargs, + ) + + @pytest.mark.unit @pytest.mark.asyncio async def test_save_checkpoint_returns_checkpoint_id() -> None: @@ -18,10 +27,7 @@ async def test_save_checkpoint_returns_checkpoint_id() -> None: client = MockFoundryCheckpointClient() storage = FoundryCheckpointStorage(client=client, session_id="session-1") - checkpoint = WorkflowCheckpoint( - checkpoint_id="cp-123", - workflow_id="wf-1", - ) + checkpoint = _checkpoint("cp-123", "wf-1") result = await storage.save_checkpoint(checkpoint) @@ -35,18 +41,14 @@ async def test_load_checkpoint_returns_checkpoint() -> None: client = MockFoundryCheckpointClient() storage = FoundryCheckpointStorage(client=client, session_id="session-1") - checkpoint = WorkflowCheckpoint( - checkpoint_id="cp-123", - workflow_id="wf-1", - iteration_count=5, - ) + checkpoint = _checkpoint("cp-123", "wf-1", iteration_count=5) await storage.save_checkpoint(checkpoint) loaded = await storage.load_checkpoint("cp-123") assert loaded is not None assert loaded.checkpoint_id == "cp-123" - assert loaded.workflow_id == "wf-1" + assert loaded.workflow_name == "wf-1" assert loaded.iteration_count == 5 @@ -69,9 +71,9 @@ async def test_list_checkpoint_ids_returns_all_ids() -> None: client = MockFoundryCheckpointClient() storage = FoundryCheckpointStorage(client=client, session_id="session-1") - cp1 = WorkflowCheckpoint(checkpoint_id="cp-1", workflow_id="wf-1") - cp2 = WorkflowCheckpoint(checkpoint_id="cp-2", workflow_id="wf-1") - cp3 = WorkflowCheckpoint(checkpoint_id="cp-3", workflow_id="wf-2") + cp1 = _checkpoint("cp-1", "wf-1") + cp2 = _checkpoint("cp-2", "wf-1") + cp3 = _checkpoint("cp-3", "wf-2") await storage.save_checkpoint(cp1) await storage.save_checkpoint(cp2) @@ -85,19 +87,19 @@ async def test_list_checkpoint_ids_returns_all_ids() -> None: @pytest.mark.unit @pytest.mark.asyncio async def test_list_checkpoint_ids_filters_by_workflow() -> None: - """Test that list_checkpoint_ids filters by workflow_id.""" 
+ """Test that list_checkpoint_ids filters by workflow_name.""" client = MockFoundryCheckpointClient() storage = FoundryCheckpointStorage(client=client, session_id="session-1") - cp1 = WorkflowCheckpoint(checkpoint_id="cp-1", workflow_id="wf-1") - cp2 = WorkflowCheckpoint(checkpoint_id="cp-2", workflow_id="wf-1") - cp3 = WorkflowCheckpoint(checkpoint_id="cp-3", workflow_id="wf-2") + cp1 = _checkpoint("cp-1", "wf-1") + cp2 = _checkpoint("cp-2", "wf-1") + cp3 = _checkpoint("cp-3", "wf-2") await storage.save_checkpoint(cp1) await storage.save_checkpoint(cp2) await storage.save_checkpoint(cp3) - ids = await storage.list_checkpoint_ids(workflow_id="wf-1") + ids = await storage.list_checkpoint_ids(workflow_name="wf-1") assert set(ids) == {"cp-1", "cp-2"} @@ -109,8 +111,8 @@ async def test_list_checkpoints_returns_all_checkpoints() -> None: client = MockFoundryCheckpointClient() storage = FoundryCheckpointStorage(client=client, session_id="session-1") - cp1 = WorkflowCheckpoint(checkpoint_id="cp-1", workflow_id="wf-1") - cp2 = WorkflowCheckpoint(checkpoint_id="cp-2", workflow_id="wf-2") + cp1 = _checkpoint("cp-1", "wf-1") + cp2 = _checkpoint("cp-2", "wf-2") await storage.save_checkpoint(cp1) await storage.save_checkpoint(cp2) @@ -129,7 +131,7 @@ async def test_delete_checkpoint_returns_true_for_existing() -> None: client = MockFoundryCheckpointClient() storage = FoundryCheckpointStorage(client=client, session_id="session-1") - checkpoint = WorkflowCheckpoint(checkpoint_id="cp-123", workflow_id="wf-1") + checkpoint = _checkpoint("cp-123", "wf-1") await storage.save_checkpoint(checkpoint) deleted = await storage.delete_checkpoint("cp-123") @@ -156,7 +158,7 @@ async def test_delete_checkpoint_removes_from_storage() -> None: client = MockFoundryCheckpointClient() storage = FoundryCheckpointStorage(client=client, session_id="session-1") - checkpoint = WorkflowCheckpoint(checkpoint_id="cp-123", workflow_id="wf-1") + checkpoint = _checkpoint("cp-123", "wf-1") await 
storage.save_checkpoint(checkpoint) await storage.delete_checkpoint("cp-123") diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/tests/unit_tests/test_foundry_tools.py b/sdk/agentserver/azure-ai-agentserver-agentframework/tests/unit_tests/test_foundry_tools.py index ba30e06cfe78..b2a4e9c85ba5 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/tests/unit_tests/test_foundry_tools.py +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/tests/unit_tests/test_foundry_tools.py @@ -1,13 +1,16 @@ import importlib + +# Load _foundry_tools module directly without triggering parent __init__ which has heavy deps +import importlib.util import inspect +import sys +from pathlib import Path from types import SimpleNamespace -from unittest.mock import AsyncMock - from typing import Any +from unittest.mock import AsyncMock import pytest -from agent_framework import AIFunction, ChatOptions -from pydantic import Field, create_model +from agent_framework import ChatOptions, FunctionTool # Import schema models directly from client._models to avoid heavy azure.identity import # chain triggered by azure.ai.agentserver.core.__init__.py @@ -19,13 +22,16 @@ SchemaProperty, SchemaType, ) +from pydantic import Field, create_model -# Load _foundry_tools module directly without triggering parent __init__ which has heavy deps -import importlib.util -import sys -from pathlib import Path - -foundry_tools_path = Path(__file__).parent.parent.parent / "azure" / "ai" / "agentserver" / "agentframework" / "_foundry_tools.py" +foundry_tools_path = ( + Path(__file__).parent.parent.parent + / "azure" + / "ai" + / "agentserver" + / "agentframework" + / "_foundry_tools.py" +) spec = importlib.util.spec_from_file_location("_foundry_tools", foundry_tools_path) foundry_tools_module = importlib.util.module_from_spec(spec) sys.modules["_foundry_tools"] = foundry_tools_module @@ -93,7 +99,7 @@ async def test_to_aifunction_builds_pydantic_model_and_invokes(monkeypatch: pyte client = 
FoundryToolClient(tools=[]) ai_func = client._to_aifunction(resolved_tool) - assert isinstance(ai_func, AIFunction) + assert isinstance(ai_func, FunctionTool) assert ai_func.name == "echo" assert ai_func.description == "Echo tool" @@ -135,7 +141,7 @@ async def test_list_tools_uses_catalog_and_converts(monkeypatch: pytest.MonkeyPa assert args[0] == list(allowed) assert len(functions) == 1 - assert isinstance(functions[0], AIFunction) + assert isinstance(functions[0], FunctionTool) assert functions[0].name == "allowed_tool" @@ -148,7 +154,7 @@ async def dummy_tool(**kwargs): return kwargs DummyInput = create_model("DummyInput") - injected = [AIFunction(name="t", description="d", func=dummy_tool, input_model=DummyInput)] + injected = [FunctionTool(name="t", description="d", func=dummy_tool, input_model=DummyInput)] monkeypatch.setattr(middleware._foundry_tool_client, "list_tools", AsyncMock(return_value=injected)) context = SimpleNamespace(chat_options=None) @@ -156,9 +162,9 @@ async def dummy_tool(**kwargs): await middleware.process(context, next_fn) - assert isinstance(context.chat_options, ChatOptions) - assert context.chat_options.tools == injected - next_fn.assert_awaited_once_with(context) + assert isinstance(context.chat_options, dict) + assert context.chat_options.get("tools") == injected + next_fn.assert_awaited_once_with() @pytest.mark.unit @@ -170,19 +176,19 @@ async def dummy_tool(**kwargs): return kwargs DummyInput = create_model("DummyInput") - injected = [AIFunction(name="t2", description="d2", func=dummy_tool, input_model=DummyInput)] + injected = [FunctionTool(name="t2", description="d2", func=dummy_tool, input_model=DummyInput)] monkeypatch.setattr(middleware._foundry_tool_client, "list_tools", AsyncMock(return_value=injected)) # Existing ChatOptions with no tools should become injected context = SimpleNamespace(chat_options=ChatOptions()) next_fn = AsyncMock() await middleware.process(context, next_fn) - assert context.chat_options.tools == injected 
+ assert context.chat_options.get("tools") == injected # Existing ChatOptions with tools should be appended - existing = [AIFunction(name="t1", description="d1", func=dummy_tool, input_model=DummyInput)] + existing = [FunctionTool(name="t1", description="d1", func=dummy_tool, input_model=DummyInput)] context = SimpleNamespace(chat_options=ChatOptions(tools=existing)) next_fn = AsyncMock() await middleware.process(context, next_fn) - assert context.chat_options.tools == existing + injected + assert context.chat_options.get("tools") == existing + injected assert next_fn.await_count == 1 diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/tests/unit_tests/test_from_agent_framework_managed.py b/sdk/agentserver/azure-ai-agentserver-agentframework/tests/unit_tests/test_from_agent_framework_managed.py index c1616c475cb4..d1a666a3cb5c 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/tests/unit_tests/test_from_agent_framework_managed.py +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/tests/unit_tests/test_from_agent_framework_managed.py @@ -3,34 +3,65 @@ # --------------------------------------------------------- """Unit tests for from_agent_framework with checkpoint repository.""" -import pytest from unittest.mock import Mock +import pytest +from agent_framework import Executor, WorkflowBuilder, handler from azure.core.credentials_async import AsyncTokenCredential +class _NoOpExecutor(Executor): + def __init__(self) -> None: + super().__init__(id="noop") + + + @handler(input=str) + async def handle(self, message, ctx) -> None: + return None + + +def _create_builder() -> WorkflowBuilder: + return WorkflowBuilder(start_executor=_NoOpExecutor()) + + +class _StubWorkflowAdapter: + def __init__( + self, + workflow_factory, + credentials=None, + thread_repository=None, + checkpoint_repository=None, + **kwargs, + ): + self._workflow_factory = workflow_factory + self._checkpoint_repository = checkpoint_repository + + @pytest.mark.unit -def 
test_checkpoint_repository_is_optional() -> None: +def test_checkpoint_repository_is_optional(monkeypatch) -> None: """Test that checkpoint_repository is optional and defaults to None.""" + import azure.ai.agentserver.agentframework._workflow_agent_adapter as workflow_adapter_module from azure.ai.agentserver.agentframework import from_agent_framework - from agent_framework import WorkflowBuilder - builder = WorkflowBuilder() + monkeypatch.setattr(workflow_adapter_module, "AgentFrameworkWorkflowAdapter", _StubWorkflowAdapter) + + builder = _create_builder() - # Should not raise adapter = from_agent_framework(builder) assert adapter is not None @pytest.mark.unit -def test_foundry_checkpoint_repository_passed_directly() -> None: +def test_foundry_checkpoint_repository_passed_directly(monkeypatch) -> None: """Test that FoundryCheckpointRepository can be passed via checkpoint_repository.""" + import azure.ai.agentserver.agentframework._workflow_agent_adapter as workflow_adapter_module from azure.ai.agentserver.agentframework import from_agent_framework from azure.ai.agentserver.agentframework.persistence import FoundryCheckpointRepository - from agent_framework import WorkflowBuilder - builder = WorkflowBuilder() + monkeypatch.setattr(workflow_adapter_module, "AgentFrameworkWorkflowAdapter", _StubWorkflowAdapter) + + builder = _create_builder() mock_credential = Mock(spec=AsyncTokenCredential) repo = FoundryCheckpointRepository( diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/tests/unit_tests/test_human_in_the_loop_helper.py b/sdk/agentserver/azure-ai-agentserver-agentframework/tests/unit_tests/test_human_in_the_loop_helper.py index 04eea81f59e9..d0a088af8721 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/tests/unit_tests/test_human_in_the_loop_helper.py +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/tests/unit_tests/test_human_in_the_loop_helper.py @@ -1,19 +1,12 @@ import pytest - -from agent_framework import ( - 
ChatMessage, - FunctionCallContent, - FunctionResultContent, - Role as ChatRole, - TextContent, +from agent_framework import Content, Message +from azure.ai.agentserver.core.server.common.constants import ( + HUMAN_IN_THE_LOOP_FUNCTION_NAME, ) from azure.ai.agentserver.agentframework.models.human_in_the_loop_helper import ( HumanInTheLoopHelper, ) -from azure.ai.agentserver.core.server.common.constants import ( - HUMAN_IN_THE_LOOP_FUNCTION_NAME, -) @pytest.fixture() @@ -21,32 +14,42 @@ def helper() -> HumanInTheLoopHelper: return HumanInTheLoopHelper() +def _function_call(call_id: str, name: str, arguments: str): + factory = getattr(Content, "from_function_call", None) + if callable(factory): + return factory(call_id=call_id, name=name, arguments=arguments) + return Content(type="function_call", call_id=call_id, name=name, arguments=arguments) + + +def _function_result(call_id: str, result): + factory = getattr(Content, "from_function_result", None) + if callable(factory): + return factory(call_id=call_id, result=result) + return Content(type="function_result", call_id=call_id, result=result) + + @pytest.mark.unit def test_remove_hitl_messages_keeps_latest_function_result(helper: HumanInTheLoopHelper) -> None: - hitl_call = FunctionCallContent( - call_id="hitl-1", - name=HUMAN_IN_THE_LOOP_FUNCTION_NAME, - arguments="{}", - ) - real_call = FunctionCallContent(call_id="tool-1", name="calculator", arguments="{}") - feedback_result = FunctionResultContent(call_id="tool-1", result="intermediate") - final_result = FunctionResultContent(call_id="tool-1", result={"total": 42}) - follow_up_content = TextContent("work resumed") + hitl_call = _function_call("hitl-1", HUMAN_IN_THE_LOOP_FUNCTION_NAME, "{}") + real_call = _function_call("tool-1", "calculator", "{}") + feedback_result = _function_result("tool-1", "intermediate") + final_result = _function_result("tool-1", {"total": 42}) + follow_up_content = Content.from_text("work resumed") thread_messages = [ - 
ChatMessage(role="assistant", contents=[real_call, hitl_call]), - ChatMessage(role="tool", contents=[feedback_result]), - ChatMessage(role="tool", contents=[final_result]), - ChatMessage(role="assistant", contents=[follow_up_content]), + Message(role="assistant", contents=[real_call, hitl_call]), + Message(role="tool", contents=[feedback_result]), + Message(role="tool", contents=[final_result]), + Message(role="assistant", contents=[follow_up_content]), ] filtered = helper.remove_hitl_content_from_thread(thread_messages) assert len(filtered) == 3 - assert filtered[0].role == ChatRole.ASSISTANT + assert filtered[0].role == "assistant" assert len(filtered[0].contents) == 1 assert filtered[0].contents[0] is real_call - assert filtered[1].role == ChatRole.TOOL + assert filtered[1].role == "tool" assert len(filtered[1].contents) == 1 assert filtered[1].contents[0] is final_result assert len(filtered[2].contents) == 1 @@ -55,47 +58,40 @@ def test_remove_hitl_messages_keeps_latest_function_result(helper: HumanInTheLoo @pytest.mark.unit def test_remove_hitl_messages_keeps_the_function_result(helper: HumanInTheLoopHelper) -> None: - hitl_call = FunctionCallContent( - call_id="hitl-1", - name=HUMAN_IN_THE_LOOP_FUNCTION_NAME, - arguments="{}", - ) - real_call = FunctionCallContent(call_id="tool-1", name="calculator", arguments="{}") - final_result = FunctionResultContent(call_id="tool-1", result={"total": 42}) - follow_up_content = TextContent("work resumed") + hitl_call = _function_call("hitl-1", HUMAN_IN_THE_LOOP_FUNCTION_NAME, "{}") + real_call = _function_call("tool-1", "calculator", "{}") + final_result = _function_result("tool-1", {"total": 42}) + follow_up_content = Content.from_text("work resumed") thread_messages = [ - ChatMessage(role="assistant", contents=[real_call, hitl_call]), - ChatMessage(role="tool", contents=[final_result]), - ChatMessage(role="assistant", contents=[follow_up_content]), + Message(role="assistant", contents=[real_call, hitl_call]), + 
Message(role="tool", contents=[final_result]), + Message(role="assistant", contents=[follow_up_content]), ] filtered = helper.remove_hitl_content_from_thread(thread_messages) assert len(filtered) == 3 - assert filtered[0].role == ChatRole.ASSISTANT + assert filtered[0].role == "assistant" assert len(filtered[0].contents) == 1 assert filtered[0].contents[0] is real_call - assert filtered[1].role == ChatRole.TOOL + assert filtered[1].role == "tool" assert len(filtered[1].contents) == 1 assert filtered[1].contents[0] is final_result assert len(filtered[2].contents) == 1 assert filtered[2].contents[0] is follow_up_content + @pytest.mark.unit def test_remove_hitl_messages_skips_orphaned_hitl_results(helper: HumanInTheLoopHelper) -> None: - hitl_call = FunctionCallContent( - call_id="hitl-2", - name=HUMAN_IN_THE_LOOP_FUNCTION_NAME, - arguments="{}", - ) - orphan_result = FunctionResultContent(call_id="hitl-2", result="ignored") - user_update = TextContent("ready") + hitl_call = _function_call("hitl-2", HUMAN_IN_THE_LOOP_FUNCTION_NAME, "{}") + orphan_result = _function_result("hitl-2", "ignored") + user_update = Content.from_text("ready") thread_messages = [ - ChatMessage(role="assistant", contents=[hitl_call]), - ChatMessage(role="tool", contents=[orphan_result]), - ChatMessage(role="user", contents=[user_update]), + Message(role="assistant", contents=[hitl_call]), + Message(role="tool", contents=[orphan_result]), + Message(role="user", contents=[user_update]), ] filtered = helper.remove_hitl_content_from_thread(thread_messages) @@ -107,20 +103,20 @@ def test_remove_hitl_messages_skips_orphaned_hitl_results(helper: HumanInTheLoop @pytest.mark.unit def test_remove_hitl_messages_preserves_regular_tool_cycle(helper: HumanInTheLoopHelper) -> None: - real_call = FunctionCallContent(call_id="tool-3", name="lookup", arguments="{}") - result_content = FunctionResultContent(call_id="tool-3", result="done") + real_call = _function_call("tool-3", "lookup", "{}") + result_content = 
_function_result("tool-3", "done") thread_messages = [ - ChatMessage(role="assistant", contents=[real_call]), - ChatMessage(role="tool", contents=[result_content]), + Message(role="assistant", contents=[real_call]), + Message(role="tool", contents=[result_content]), ] filtered = helper.remove_hitl_content_from_thread(thread_messages) assert len(filtered) == 2 assert len(filtered[0].contents) == 1 - assert filtered[0].role == ChatRole.ASSISTANT + assert filtered[0].role == "assistant" assert filtered[0].contents[0] is real_call - assert filtered[1].role == ChatRole.TOOL + assert filtered[1].role == "tool" assert len(filtered[1].contents) == 1 - assert filtered[1].contents[0] is result_content \ No newline at end of file + assert filtered[1].contents[0] is result_content From 3feb6d0d6cc958cf70868d44df02da14a708e2ba Mon Sep 17 00:00:00 2001 From: eavanvalkenburg Date: Wed, 25 Feb 2026 11:46:58 +0100 Subject: [PATCH 2/2] Refactor session persistence and converters Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../CHANGELOG.md | 4 +- .../ai/agentserver/agentframework/__init__.py | 26 +- .../agentframework/_agent_framework.py | 90 ++-- .../agentframework/_ai_agent_adapter.py | 16 +- .../agentframework/_workflow_agent_adapter.py | 16 +- .../models/conversation_converters.py | 510 +++++++++--------- .../models/human_in_the_loop_helper.py | 20 +- .../agentframework/persistence/__init__.py | 18 +- .../_foundry_conversation_message_store.py | 122 ++--- ...foundry_conversation_session_repository.py | 93 ++++ ..._foundry_conversation_thread_repository.py | 83 --- ...ository.py => agent_session_repository.py} | 108 ++-- .../human_in_the_loop_ai_function/.gitignore | 2 +- .../human_in_the_loop_ai_function/README.md | 8 +- .../human_in_the_loop_ai_function/main.py | 6 +- .../test_conversation_id_optional.py | 132 ++++- .../test_conversation_item_converter.py | 27 +- .../test_from_agent_framework_managed.py | 2 +- .../test_human_in_the_loop_helper.py | 
26 +- 19 files changed, 688 insertions(+), 621 deletions(-) create mode 100644 sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/persistence/_foundry_conversation_session_repository.py delete mode 100644 sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/persistence/_foundry_conversation_thread_repository.py rename sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/persistence/{agent_thread_repository.py => agent_session_repository.py} (50%) diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/CHANGELOG.md b/sdk/agentserver/azure-ai-agentserver-agentframework/CHANGELOG.md index 9a500c330124..5e4834448be3 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/CHANGELOG.md +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/CHANGELOG.md @@ -20,7 +20,7 @@ ### Bugs Fixed -- Mitigate AgentThread management for AzureAIClient agents. +- Mitigate AgentSession management for AzureAIClient agents. 
## 1.0.0b11 (2026-02-10) @@ -44,7 +44,7 @@ ### Features Added - Integrated with Foundry Tools -- Add persistence for agent thread and checkpoint +- Add persistence for agent session and checkpoint - Fixed WorkflowAgent concurrency issue - Support Human-in-the-Loop diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/__init__.py b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/__init__.py index 1b9e6b009a40..4a52b5f90800 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/__init__.py +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/__init__.py @@ -17,7 +17,7 @@ from ._foundry_tools import FoundryToolsChatMiddleware from ._version import VERSION -from .persistence import AgentThreadRepository, CheckpointRepository +from .persistence import AgentSessionRepository, CheckpointRepository if TYPE_CHECKING: from ._agent_framework import AgentFrameworkAgent @@ -30,7 +30,7 @@ def from_agent_framework( agent: Union[BaseAgent, SupportsAgentRun], /, credentials: Optional[Union[AsyncTokenCredential, TokenCredential]] = None, - thread_repository: Optional[AgentThreadRepository]=None + session_repository: Optional[AgentSessionRepository]=None ) -> "AgentFrameworkAIAgentAdapter": """ Create an Agent Framework AI Agent Adapter from a SupportsAgentRun or BaseAgent. @@ -39,8 +39,8 @@ def from_agent_framework( :type agent: Union[BaseAgent, SupportsAgentRun] :param credentials: Optional asynchronous token credential for authentication. :type credentials: Optional[Union[AsyncTokenCredential, TokenCredential]] - :param thread_repository: Optional thread repository for agent thread management. - :type thread_repository: Optional[AgentThreadRepository] + :param session_repository: Optional session repository for agent session management. 
+ :type session_repository: Optional[AgentSessionRepository] :return: An instance of AgentFrameworkAIAgentAdapter. :rtype: AgentFrameworkAIAgentAdapter @@ -52,7 +52,7 @@ def from_agent_framework( workflow: Union[WorkflowBuilder, Callable[[], Workflow]], /, credentials: Optional[Union[AsyncTokenCredential, TokenCredential]] = None, - thread_repository: Optional[AgentThreadRepository] = None, + session_repository: Optional[AgentSessionRepository] = None, checkpoint_repository: Optional[CheckpointRepository] = None, ) -> "AgentFrameworkWorkflowAdapter": """ @@ -68,8 +68,8 @@ def from_agent_framework( :type workflow: Union[WorkflowBuilder, Callable[[], Workflow]] :param credentials: Optional asynchronous token credential for authentication. :type credentials: Optional[Union[AsyncTokenCredential, TokenCredential]] - :param thread_repository: Optional thread repository for agent thread management. - :type thread_repository: Optional[AgentThreadRepository] + :param session_repository: Optional session repository for agent session management. + :type session_repository: Optional[AgentSessionRepository] :param checkpoint_repository: Optional checkpoint repository for workflow checkpointing. Use ``InMemoryCheckpointRepository``, ``FileCheckpointRepository``, or ``FoundryCheckpointRepository`` for Azure AI Foundry managed storage. @@ -83,7 +83,7 @@ def from_agent_framework( agent_or_workflow: Union[BaseAgent, SupportsAgentRun, WorkflowBuilder, Callable[[], Workflow]], /, credentials: Optional[Union[AsyncTokenCredential, TokenCredential]] = None, - thread_repository: Optional[AgentThreadRepository] = None, + session_repository: Optional[AgentSessionRepository] = None, checkpoint_repository: Optional[CheckpointRepository] = None, ) -> "AgentFrameworkAgent": """ @@ -95,8 +95,8 @@ def from_agent_framework( :type agent_or_workflow: Optional[Union[BaseAgent, SupportsAgentRun]] :param credentials: Optional asynchronous token credential for authentication. 
:type credentials: Optional[Union[AsyncTokenCredential, TokenCredential]] - :param thread_repository: Optional thread repository for agent thread management. - :type thread_repository: Optional[AgentThreadRepository] + :param session_repository: Optional session repository for agent session management. + :type session_repository: Optional[AgentSessionRepository] :param checkpoint_repository: Optional checkpoint repository for workflow checkpointing. Use ``InMemoryCheckpointRepository``, ``FileCheckpointRepository``, or ``FoundryCheckpointRepository`` for Azure AI Foundry managed storage. @@ -113,7 +113,7 @@ def from_agent_framework( return AgentFrameworkWorkflowAdapter( workflow_factory=agent_or_workflow.build, credentials=credentials, - thread_repository=thread_repository, + session_repository=session_repository, checkpoint_repository=checkpoint_repository, ) if isinstance(agent_or_workflow, Callable): # type: ignore @@ -122,7 +122,7 @@ def from_agent_framework( return AgentFrameworkWorkflowAdapter( workflow_factory=agent_or_workflow, credentials=credentials, - thread_repository=thread_repository, + session_repository=session_repository, checkpoint_repository=checkpoint_repository, ) # raise TypeError("workflow must be a WorkflowBuilder or callable returning a Workflow") @@ -132,7 +132,7 @@ def from_agent_framework( return AgentFrameworkAIAgentAdapter(agent_or_workflow, credentials=credentials, - thread_repository=thread_repository) + session_repository=session_repository) raise TypeError("You must provide one of the instances of type " "[SupportsAgentRun, BaseAgent, WorkflowBuilder or callable returning a Workflow]") diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_agent_framework.py b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_agent_framework.py index f62b89e0d092..deca18eef0da 100644 --- 
a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_agent_framework.py +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_agent_framework.py @@ -7,7 +7,7 @@ import os from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Optional, Union -from agent_framework import AgentSession, SupportsAgentRun, WorkflowAgent +from agent_framework import AgentSession, SupportsAgentRun from opentelemetry import trace from azure.ai.agentserver.core import AgentRunContext, FoundryCBAgent @@ -22,8 +22,9 @@ from .models.agent_framework_output_streaming_converter import AgentFrameworkOutputStreamingConverter from .models.human_in_the_loop_helper import HumanInTheLoopHelper -from .persistence import AgentThreadRepository -from .persistence._foundry_conversation_thread_repository import FoundryConversationThreadRepository +from .persistence import AgentSessionRepository +from .persistence._foundry_conversation_message_store import FoundryConversationMessageStore +from .persistence._foundry_conversation_session_repository import FoundryConversationSessionRepository if TYPE_CHECKING: from azure.core.credentials_async import AsyncTokenCredential @@ -50,24 +51,26 @@ class AgentFrameworkAgent(FoundryCBAgent): def __init__(self, credentials: "Optional[AsyncTokenCredential]" = None, - thread_repository: Optional[AgentThreadRepository] = None, + session_repository: Optional[AgentSessionRepository] = None, project_endpoint: Optional[str] = None, **kwargs) -> None: """Initialize the AgentFrameworkAgent with a SupportsAgentRun-compatible agent adapter. :param credentials: Azure credentials for authentication. :type credentials: Optional[AsyncTokenCredential] - :param thread_repository: An optional AgentThreadRepository instance for managing thread messages. 
- :type thread_repository: Optional[AgentThreadRepository] + :param session_repository: An optional AgentSessionRepository instance for managing session messages. + :type session_repository: Optional[AgentSessionRepository] :param project_endpoint: The endpoint of the Azure AI Project. :type project_endpoint: Optional[str] """ super().__init__(credentials=credentials, **kwargs) # pylint: disable=unexpected-keyword-arg project_endpoint = get_project_endpoint(logger=logger) or project_endpoint - if not thread_repository and project_endpoint and self.credentials: - logger.warning("No thread repository provided. FoundryConversationThreadRepository will be used.") - thread_repository = self._create_foundry_conversation_thread_repository(project_endpoint, self.credentials) - self._thread_repository = thread_repository + if not session_repository and project_endpoint and self.credentials: + logger.warning("No session repository provided. FoundryConversationSessionRepository will be used.") + session_repository = self._create_foundry_conversation_session_repository( + project_endpoint, self.credentials + ) + self._session_repository = session_repository self._hitl_helper = HumanInTheLoopHelper() def init_tracing(self): @@ -178,50 +181,51 @@ async def agent_run( # pylint: disable=too-many-statements ]: raise NotImplementedError("This method is implemented in the base class.") - async def _load_agent_thread( + async def _load_agent_session( self, context: AgentRunContext, - agent: Union[SupportsAgentRun, WorkflowAgent], + agent: SupportsAgentRun, ) -> Optional[AgentSession]: """Load the agent session for a given conversation ID. :param context: The agent run context. :type context: AgentRunContext :param agent: The agent instance. - :type agent: SupportsAgentRun | WorkflowAgent + :type agent: SupportsAgentRun :return: The loaded AgentSession if available, None otherwise. 
:rtype: Optional[AgentSession] """ - if self._thread_repository and context.conversation_id: + if self._session_repository and context.conversation_id: + self._ensure_foundry_history_provider(agent) conversation_id = context.conversation_id - agent_thread = await self._thread_repository.get(conversation_id, agent=agent) - if agent_thread: - logger.info(f"Loaded agent thread for conversation: {conversation_id}") - return agent_thread + agent_session = await self._session_repository.get(conversation_id) + if agent_session: + logger.info(f"Loaded agent session for conversation: {conversation_id}") + return agent_session return agent.create_session() return None - async def _save_agent_thread(self, context: AgentRunContext, agent_thread: AgentSession) -> None: + async def _save_agent_session(self, context: AgentRunContext, agent_session: AgentSession) -> None: """Save the agent session for a given conversation ID. :param context: The agent run context. :type context: AgentRunContext - :param agent_thread: The agent session to save. - :type agent_thread: AgentSession + :param agent_session: The agent session to save. + :type agent_session: AgentSession :return: None :rtype: None """ - if agent_thread and self._thread_repository and (conversation_id := context.conversation_id): - await self._thread_repository.set(conversation_id, agent_thread) - logger.info(f"Saved agent thread for conversation: {conversation_id}") + if agent_session and self._session_repository and (conversation_id := context.conversation_id): + await self._session_repository.set(conversation_id, agent_session) + logger.info(f"Saved agent session for conversation: {conversation_id}") def _run_streaming_updates( self, context: AgentRunContext, stream_runner: Callable[[], AsyncGenerator[Any, None]], - agent_thread: Optional[AgentSession] = None, + agent_session: Optional[AgentSession] = None, ) -> AsyncGenerator[ResponseStreamEvent, Any]: """ Execute a streaming run with shared OAuth/error handling. 
@@ -230,8 +234,8 @@ def _run_streaming_updates( :type context: AgentRunContext :param stream_runner: A callable that invokes the agent in stream mode :type stream_runner: Callable[[], AsyncGenerator[Any, None]] - :param agent_thread: The agent thread to use during streaming updates. - :type agent_thread: Optional[AgentSession] + :param agent_session: The agent session to use during streaming updates. + :type agent_session: Optional[AgentSession] :return: An async generator yielding streaming events. :rtype: AsyncGenerator[ResponseStreamEvent, Any] @@ -251,7 +255,7 @@ async def stream_updates(): update_count += 1 yield event - await self._save_agent_thread(context, agent_thread) + await self._save_agent_session(context, agent_session) logger.info("Streaming completed with %d updates", update_count) except OAuthConsentRequiredError as e: logger.info("OAuth consent required during streaming updates") @@ -291,23 +295,39 @@ async def stream_updates(): return stream_updates() - def _create_foundry_conversation_thread_repository( + def _create_foundry_conversation_session_repository( self, project_endpoint: str, credential: AsyncTokenCredential, - ) -> FoundryConversationThreadRepository: - """Helper method to create a FoundryConversationThreadRepository instance. + ) -> FoundryConversationSessionRepository: + """Helper method to create a FoundryConversationSessionRepository instance. :param project_endpoint: The endpoint of the Azure AI Project. :type project_endpoint: str :param credential: The credential for authenticating with the Azure AI Project. :type credential: AsyncTokenCredential - :return: An instance of FoundryConversationThreadRepository. - :rtype: FoundryConversationThreadRepository + :return: An instance of FoundryConversationSessionRepository. 
+ :rtype: FoundryConversationSessionRepository """ - return FoundryConversationThreadRepository( - agent=None, # Agent will be provided during get/set calls + return FoundryConversationSessionRepository( project_endpoint=project_endpoint, credential=credential, ) + + def _ensure_foundry_history_provider(self, agent: SupportsAgentRun) -> None: + if not isinstance(self._session_repository, FoundryConversationSessionRepository): + return + + context_providers = getattr(agent, "context_providers", None) + if not isinstance(context_providers, list): + logger.warning( + "Agent does not expose mutable context_providers; " + "FoundryConversationMessageStore was not attached." + ) + return + + if any(isinstance(provider, FoundryConversationMessageStore) for provider in context_providers): + return + + context_providers.append(self._session_repository.history_provider) diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_ai_agent_adapter.py b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_ai_agent_adapter.py index 4fc722d89bdd..143b0ab210f6 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_ai_agent_adapter.py +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_ai_agent_adapter.py @@ -23,18 +23,18 @@ from .models.agent_framework_output_non_streaming_converter import ( AgentFrameworkOutputNonStreamingConverter, ) -from .persistence import AgentThreadRepository +from .persistence import AgentSessionRepository logger = get_logger() class AgentFrameworkAIAgentAdapter(AgentFrameworkAgent): def __init__(self, agent: SupportsAgentRun, credentials: Optional[Union[AsyncTokenCredential, TokenCredential]] = None, - thread_repository: Optional[AgentThreadRepository] = None, + session_repository: Optional[AgentSessionRepository] = None, *, project_endpoint: Optional[str] = None, **kwargs) -> None: - 
super().__init__(credentials, thread_repository, project_endpoint=project_endpoint, **kwargs) + super().__init__(credentials, session_repository, project_endpoint=project_endpoint, **kwargs) self._agent = agent async def agent_run( # pylint: disable=too-many-statements @@ -47,7 +47,7 @@ async def agent_run( # pylint: disable=too-many-statements logger.info("Starting AIAgent agent_run with stream=%s", context.stream) request_input = context.request.get("input") - agent_thread = await self._load_agent_thread(context, self._agent) + agent_session = await self._load_agent_session(context, self._agent) message = transform_input(request_input) logger.debug("Transformed input message type: %s", type(message)) @@ -67,19 +67,19 @@ async def agent_run( # pylint: disable=too-many-statements context=context, stream_runner=lambda: self._agent.run( message, - session=agent_thread, + session=agent_session, stream=True, ), - agent_thread=agent_thread, + agent_session=agent_session, ) # Non-streaming path logger.info("Running agent in non-streaming mode") result = await self._agent.run( message, - session=agent_thread) + session=agent_session) logger.debug("Agent run completed, result type: %s", type(result)) - await self._save_agent_thread(context, agent_thread) + await self._save_agent_session(context, agent_session) non_streaming_converter = AgentFrameworkOutputNonStreamingConverter(context, hitl_helper=self._hitl_helper) transformed_result = non_streaming_converter.transform_output_for_response(result) diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_workflow_agent_adapter.py b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_workflow_agent_adapter.py index 1eccdaf8859f..2679751f366f 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_workflow_agent_adapter.py +++ 
b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_workflow_agent_adapter.py @@ -27,7 +27,7 @@ from .models.agent_framework_output_non_streaming_converter import ( AgentFrameworkOutputNonStreamingConverter, ) -from .persistence import AgentThreadRepository, CheckpointRepository +from .persistence import AgentSessionRepository, CheckpointRepository logger = get_logger() @@ -37,13 +37,13 @@ def __init__( self, workflow_factory: Callable[[], Workflow], credentials: Optional[Union[AsyncTokenCredential, TokenCredential]] = None, - thread_repository: Optional[AgentThreadRepository] = None, + session_repository: Optional[AgentSessionRepository] = None, checkpoint_repository: Optional[CheckpointRepository] = None, *, project_endpoint: Optional[str] = None, **kwargs, ) -> None: - super().__init__(credentials, thread_repository, project_endpoint=project_endpoint, **kwargs) + super().__init__(credentials, session_repository, project_endpoint=project_endpoint, **kwargs) self._workflow_factory = workflow_factory self._checkpoint_repository = checkpoint_repository @@ -59,7 +59,7 @@ async def agent_run( # pylint: disable=too-many-statements logger.info("Starting WorkflowAgent agent_run with stream=%s", context.stream) request_input = context.request.get("input") - agent_thread = await self._load_agent_thread(context, agent) + agent_session = await self._load_agent_session(context, agent) checkpoint_storage = None selected_checkpoint = None @@ -88,22 +88,22 @@ async def agent_run( # pylint: disable=too-many-statements context=context, stream_runner=lambda: agent.run( message, - session=agent_thread, + session=agent_session, checkpoint_storage=checkpoint_storage, stream=True, ), - agent_thread=agent_thread, + agent_session=agent_session, ) # Non-streaming path logger.info("Running WorkflowAgent in non-streaming mode") result = await agent.run( message, - session=agent_thread, + session=agent_session, checkpoint_storage=checkpoint_storage) 
logger.debug("WorkflowAgent run completed, result type: %s", type(result)) - await self._save_agent_thread(context, agent_thread) + await self._save_agent_session(context, agent_session) non_streaming_converter = AgentFrameworkOutputNonStreamingConverter(context, hitl_helper=self._hitl_helper) transformed_result = non_streaming_converter.transform_output_for_response(result) diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/conversation_converters.py b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/conversation_converters.py index b830df5b7565..02ca9ea4b753 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/conversation_converters.py +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/conversation_converters.py @@ -13,280 +13,266 @@ logger = get_logger() +_ROLE_MAP = { + "assistant": "assistant", + "system": "system", + "user": "user", + "tool": "tool", + "developer": "system", + "critic": "assistant", + "discriminator": "assistant", + "unknown": "user", +} +_TOOL_CALL_TYPES = { + "function_call", + "web_search_call", + "computer_call", + "local_shell_call", + "custom_tool_call", + "mcp_approval_request", +} +_TOOL_RESULT_TYPES = { + "function_call_output", + "computer_call_output", + "local_shell_call_output", + "custom_tool_call_output", + "file_search_call", + "image_generation_call", + "code_interpreter_call", + "mcp_list_tools", + "mcp_call", + "mcp_approval_response", +} +_RESULT_HINT_FIELDS = ( + "output", + "outputs", + "result", + "results", + "summary", + "tools", + "encrypted_content", + "error", +) + + +def to_chat_message(item: ConversationItem) -> Optional[Message]: + """Convert a ConversationItem from the Conversations API to an AF Message.""" + if item is None: + return None -class ConversationItemConverter: - _ROLE_MAP = { - "assistant": 
"assistant", - "system": "system", - "user": "user", - "tool": "tool", - "developer": "system", - "critic": "assistant", - "discriminator": "assistant", - "unknown": "user", - } - _TOOL_CALL_TYPES = { - "function_call", - "web_search_call", - "computer_call", - "local_shell_call", - "custom_tool_call", - "mcp_approval_request", - } - _TOOL_RESULT_TYPES = { - "function_call_output", - "computer_call_output", - "local_shell_call_output", - "custom_tool_call_output", - "file_search_call", - "image_generation_call", - "code_interpreter_call", - "mcp_list_tools", - "mcp_call", - "mcp_approval_response", - } - _RESULT_HINT_FIELDS = ( - "output", - "outputs", - "result", - "results", - "summary", - "tools", - "encrypted_content", - "error", + item_type = getattr(item, "type", None) + if item_type == "message": + return _convert_message_item(item) + if item_type == "reasoning": + return _convert_reasoning_item(item) + + if item_type in _TOOL_RESULT_TYPES or _has_result_payload(item): + return _convert_tool_result_item(item) + if item_type in _TOOL_CALL_TYPES: + return _convert_tool_call_item(item) + + logger.debug("Unsupported conversation item type: %s", item_type) + return None + + +def _convert_message_item(item: Any) -> Optional[Message]: + role_value = _ROLE_MAP.get(str(getattr(item, "role", "user")).lower(), "user") + raw_contents = getattr(item, "content", None) or [] + + converted_contents: list[Any] = [] + for content in raw_contents: + converted = _convert_message_content(content) + if converted: + converted_contents.append(converted) + + if not converted_contents: + return None + + return Message(role=role_value, contents=converted_contents) + + +def _convert_tool_call_item(item: Any) -> Optional[Message]: + data = _model_dump(item) + if not data: + return None + + call_id = _resolve_call_id(data, getattr(item, "type", "tool_call")) + name = str(data.get("name") or _infer_action_name(data) or data.get("type") or "tool_call") + arguments = 
_extract_call_arguments(data) + normalized_arguments = _normalize_arguments(arguments) + + content = Content.from_function_call( + call_id=call_id, + name=name, + arguments=normalized_arguments, ) + return Message(role="assistant", contents=[content]) + - def to_chat_message(self, item: ConversationItem) -> Optional[Message]: - """Convert a ConversationItem from the Conversations API to an AF Message.""" - if item is None: - return None +def _convert_tool_result_item(item: Any) -> Optional[Message]: + data = _model_dump(item) + if not data: + return None + + call_id = _resolve_call_id(data, getattr(item, "type", "tool")) + result_payload = _extract_result_payload(data) + if result_payload is None: + return None + + content = Content.from_function_result(call_id=call_id, result=result_payload) + return Message(role="tool", contents=[content]) - item_type = getattr(item, "type", None) - if item_type == "message": - return self._convert_message_item(item) - if item_type == "reasoning": - return self._convert_reasoning_item(item) - if item_type in self._TOOL_RESULT_TYPES or self._has_result_payload(item): - return self._convert_tool_result_item(item) - if item_type in self._TOOL_CALL_TYPES: - return self._convert_tool_call_item(item) +def _convert_reasoning_item(item: Any) -> Optional[Message]: + data = _model_dump(item) + summaries = data.get("summary", []) or [] + content_items = data.get("content", []) or [] + + reasoning_contents: list[Any] = [] + for content in content_items: + text = content.get("text") if isinstance(content, Mapping) else None + if text: + reasoning_contents.append(Content.from_text_reasoning(text=text)) + + summary_text = " \n".join( + summary.get("text") + for summary in summaries + if isinstance(summary, Mapping) and summary.get("text") + ) - logger.debug("Unsupported conversation item type: %s", item_type) + if not reasoning_contents and not summary_text: return None - def _convert_message_item(self, item: Any) -> Optional[Message]: - 
role_value = self._ROLE_MAP.get(str(getattr(item, "role", "user")).lower(), "user") - raw_contents = getattr(item, "content", None) or [] - - converted_contents: list[Any] = [] - - for content in raw_contents: - converted = self._convert_message_content(content) - if converted: - converted_contents.append(converted) - - if not converted_contents: - return None - - return Message(role=role_value, contents=converted_contents) - - def _convert_tool_call_item(self, item: Any) -> Optional[Message]: - data = self._model_dump(item) - if not data: - return None - - call_id = self._resolve_call_id(data, getattr(item, "type", "tool_call")) - name = str(data.get("name") or self._infer_action_name(data) or data.get("type") or "tool_call") - - arguments = self._extract_call_arguments(data) - normalized_arguments = self._normalize_arguments(arguments) - - content = self._content_from_function_call( - call_id=call_id, - name=name, - arguments=normalized_arguments, - ) - return Message(role="assistant", contents=[content]) - - def _convert_tool_result_item(self, item: Any) -> Optional[Message]: - data = self._model_dump(item) - if not data: - return None - - call_id = self._resolve_call_id(data, getattr(item, "type", "tool")) - result_payload = self._extract_result_payload(data) - if result_payload is None: - return None - - content = self._content_from_function_result(call_id=call_id, result=result_payload) - return Message(role="tool", contents=[content]) - - def _convert_reasoning_item(self, item: Any) -> Optional[Message]: - data = self._model_dump(item) - summaries = data.get("summary", []) or [] - content_items = data.get("content", []) or [] - - reasoning_contents: list[Any] = [] - for content in content_items: - text = content.get("text") if isinstance(content, Mapping) else None - if text: - reasoning_contents.append(self._content_from_text_reasoning(text)) - - summary_text = " \n".join( - summary.get("text") - for summary in summaries - if isinstance(summary, Mapping) 
and summary.get("text") - ) - - if not reasoning_contents and not summary_text: - return None - - kwargs: dict[str, Any] = {} - if summary_text: - kwargs["text"] = summary_text - if reasoning_contents: - kwargs["contents"] = reasoning_contents - return Message(role="assistant", **kwargs) - - def _convert_message_content(self, content: Any) -> Optional[Any]: - content_type = str(getattr(content, "type", "")).lower() - - if content_type in {"input_text", "output_text", "text", "summary_text"}: - text_value = getattr(content, "text", None) - if text_value: - return self._content_from_text(text_value) - - if content_type == "reasoning_text": - text_value = getattr(content, "text", None) - if text_value: - return self._content_from_text_reasoning(text_value) - - if content_type == "refusal": - refusal_text = getattr(content, "refusal", None) - if refusal_text: - return self._content_from_text(refusal_text) + kwargs: dict[str, Any] = {} + if summary_text: + kwargs["text"] = summary_text + if reasoning_contents: + kwargs["contents"] = reasoning_contents + return Message(role="assistant", **kwargs) + +def _convert_message_content(content: Any) -> Optional[Any]: + content_type = str(getattr(content, "type", "")).lower() + + if content_type in {"input_text", "output_text", "text", "summary_text"}: + text_value = getattr(content, "text", None) + if text_value: + return Content.from_text(text=text_value) + + if content_type == "reasoning_text": + text_value = getattr(content, "text", None) + if text_value: + return Content.from_text_reasoning(text=text_value) + + if content_type == "refusal": + refusal_text = getattr(content, "refusal", None) + if refusal_text: + return Content.from_text(text=refusal_text) + + return None + + +def _extract_call_arguments(data: Mapping[str, Any]) -> Any: + if data.get("arguments") not in (None, ""): + return data.get("arguments") + if data.get("action") not in (None, {}): + return data.get("action") + + payload = { + key: value + for key, value 
in data.items() + if key not in {"id", "type", "status", "call_id", "name"} + } + return payload or None + + +def _extract_result_payload(data: Mapping[str, Any]) -> Any: + for key in ( + "output", + "outputs", + "result", + "results", + "content", + ): + value = data.get(key) + if value not in (None, [], {}, ""): + return _normalize_result(value) + + payload = { + key: value + for key, value in data.items() + if key not in {"id", "type", "status", "call_id", "name"} + } + return _normalize_result(payload) if payload else None + + +def _normalize_arguments(value: Any) -> Any: + if value is None: return None + if isinstance(value, str): + stripped_value = value.strip() + return _normalize_result(stripped_value) + if hasattr(value, "model_dump"): + return value.model_dump(mode="python", exclude_none=True) + if isinstance(value, Mapping): + return {key: _normalize_result(val) for key, val in value.items()} + if isinstance(value, list): + return {"items": [_normalize_result(item) for item in value]} + return _normalize_result(value) + + +def _normalize_result(value: Any) -> Any: + if isinstance(value, str): + return _safe_json_loads(value) + if hasattr(value, "model_dump"): + return value.model_dump(mode="python", exclude_none=True) + if isinstance(value, list): + return [_normalize_result(item) for item in value] + if isinstance(value, Mapping): + return {key: _normalize_result(val) for key, val in value.items()} + return value + + +def _safe_json_loads(value: str) -> Any: + try: + return json.loads(value) + except (TypeError, ValueError): + return value - def _content_from_text(self, text: str) -> Any: - factory = getattr(Content, "from_text", None) - if callable(factory): - return factory(text=text) - return Content(type="text", text=text) - - def _content_from_text_reasoning(self, text: str) -> Any: - factory = getattr(Content, "from_text_reasoning", None) - if callable(factory): - return factory(text=text) - return Content(type="text_reasoning", text=text) - - def 
_content_from_function_call(self, call_id: str, name: str, arguments: Any) -> Any: - factory = getattr(Content, "from_function_call", None) - if callable(factory): - return factory(call_id=call_id, name=name, arguments=arguments) - return Content(type="function_call", call_id=call_id, name=name, arguments=arguments) - - def _content_from_function_result(self, call_id: str, result: Any) -> Any: - factory = getattr(Content, "from_function_result", None) - if callable(factory): - return factory(call_id=call_id, result=result) - return Content(type="function_result", call_id=call_id, result=result) - - def _extract_call_arguments(self, data: Mapping[str, Any]) -> Any: - if data.get("arguments") not in (None, ""): - return data.get("arguments") - if data.get("action") not in (None, {}): - return data.get("action") - - payload = { - key: value - for key, value in data.items() - if key not in {"id", "type", "status", "call_id", "name"} - } - return payload or None - - def _extract_result_payload(self, data: Mapping[str, Any]) -> Any: - for key in ( - "output", - "outputs", - "result", - "results", - "content", - ): - value = data.get(key) - if value not in (None, [], {}, ""): - return self._normalize_result(value) - - payload = { + +def _model_dump(item: Any) -> Mapping[str, Any]: + if hasattr(item, "model_dump"): + return item.model_dump(mode="python", exclude_none=True) + if hasattr(item, "dict"): + return item.dict() + if hasattr(item, "__dict__"): + return { key: value - for key, value in data.items() - if key not in {"id", "type", "status", "call_id", "name"} + for key, value in item.__dict__.items() + if not key.startswith("_") } - return self._normalize_result(payload) if payload else None - - def _normalize_arguments(self, value: Any) -> Any: - if value is None: - return None - if isinstance(value, str): - stripped_value = value.strip() - return self._normalize_result(stripped_value) - if hasattr(value, "model_dump"): - return value.model_dump(mode="python", 
exclude_none=True) - if isinstance(value, Mapping): - return {key: self._normalize_result(val) for key, val in value.items()} - if isinstance(value, list): - return {"items": [self._normalize_result(item) for item in value]} - return self._normalize_result(value) - - def _normalize_result(self, value: Any) -> Any: - if isinstance(value, str): - loaded = self._safe_json_loads(value) - return loaded - if hasattr(value, "model_dump"): - return value.model_dump(mode="python", exclude_none=True) - if isinstance(value, list): - return [self._normalize_result(item) for item in value] - if isinstance(value, Mapping): - return {key: self._normalize_result(val) for key, val in value.items()} - return value + return {} + + +def _resolve_call_id(data: Mapping[str, Any], default_prefix: str) -> str: + candidate = data.get("call_id") or data.get("id") + if candidate: + return str(candidate) + return f"{default_prefix or 'tool'}-{uuid4().hex}" + + +def _infer_action_name(data: Mapping[str, Any]) -> Optional[str]: + action = data.get("action") + if hasattr(action, "model_dump"): + action = action.model_dump(mode="python", exclude_none=True) + if isinstance(action, Mapping): + return str(action.get("type") or action.get("name") or "").strip() or None + return None - def _safe_json_loads(self, value: str) -> Any: - try: - return json.loads(value) - except (TypeError, ValueError): - return value - - def _model_dump(self, item: Any) -> Mapping[str, Any]: - if hasattr(item, "model_dump"): - return item.model_dump(mode="python", exclude_none=True) - if hasattr(item, "dict"): - return item.dict() - if hasattr(item, "__dict__"): - return { - key: value - for key, value in item.__dict__.items() - if not key.startswith("_") - } - return {} - - def _resolve_call_id(self, data: Mapping[str, Any], default_prefix: str) -> str: - candidate = data.get("call_id") or data.get("id") - if candidate: - return str(candidate) - return f"{default_prefix or 'tool'}-{uuid4().hex}" - - def 
_infer_action_name(self, data: Mapping[str, Any]) -> Optional[str]: - action = data.get("action") - if hasattr(action, "model_dump"): - action = action.model_dump(mode="python", exclude_none=True) - if isinstance(action, Mapping): - return str(action.get("type") or action.get("name") or "").strip() or None - return None - def _has_result_payload(self, item: Any) -> bool: - for field in self._RESULT_HINT_FIELDS: - value = getattr(item, field, None) - if value not in (None, [], {}, ""): - return True - return False +def _has_result_payload(item: Any) -> bool: + for field in _RESULT_HINT_FIELDS: + value = getattr(item, field, None) + if value not in (None, [], {}, ""): + return True + return False diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/human_in_the_loop_helper.py b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/human_in_the_loop_helper.py index a3ee7b3e0d15..cab637dadbb0 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/human_in_the_loop_helper.py +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/human_in_the_loop_helper.py @@ -15,7 +15,7 @@ class HumanInTheLoopHelper: def get_pending_hitl_request( self, - thread_messages: List[Message] = None, + session_messages: List[Message] = None, checkpoint: Optional[WorkflowCheckpoint] = None, ) -> dict[str, Union[WorkflowEvent, Any]]: res: dict[str, Union[WorkflowEvent, Any]] = {} @@ -31,10 +31,10 @@ def get_pending_hitl_request( res[call_id] = request_obj return res - if not thread_messages: + if not session_messages: return res - for message in thread_messages: + for message in session_messages: for content in message.contents: if ( getattr(content, "type", None) == "function_approval_request" @@ -118,14 +118,14 @@ def convert_response(self, hitl_request: WorkflowEvent, input: Dict) -> Message: 
response_result = response_type.convert_from_payload(input.get("output", "")) logger.info("response_result %s", response_result) - response_content = self._content_from_function_result( + response_content = Content.from_function_result( call_id=request_id, result=response_result, ) return Message(role="tool", contents=[response_content]) - def remove_hitl_content_from_thread(self, thread_messages: List[Message]) -> List[Message]: - """Remove HITL function call contents and related results from a conversation thread.""" + def remove_hitl_content_from_session(self, session_messages: List[Message]) -> List[Message]: + """Remove HITL function call contents and related results from a conversation session.""" filtered_messages: list[Message] = [] prev_function_call = None @@ -133,7 +133,7 @@ def remove_hitl_content_from_thread(self, thread_messages: List[Message]) -> Lis prev_function_output = None pending_tool_message: Optional[Message] = None - for message in thread_messages: + for message in session_messages: filtered_contents: list[Any] = [] for content in message.contents: if content.type == "function_result": @@ -222,9 +222,3 @@ def _event_response_type(self, event: WorkflowEvent) -> Any: if isinstance(data, dict): return data.get("response_type", None) return getattr(data, "response_type", None) if data is not None else None - - def _content_from_function_result(self, call_id: str, result: Any) -> Any: - factory = getattr(Content, "from_function_result", None) - if callable(factory): - return factory(call_id=call_id, result=result) - return Content(type="function_result", call_id=call_id, result=result) diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/persistence/__init__.py b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/persistence/__init__.py index a90ae8acd4db..a7fb87965ceb 100644 --- 
a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/persistence/__init__.py +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/persistence/__init__.py @@ -1,10 +1,10 @@ from ._foundry_checkpoint_repository import FoundryCheckpointRepository from ._foundry_checkpoint_storage import FoundryCheckpointStorage -from .agent_thread_repository import ( - AgentThreadRepository, - InMemoryAgentThreadRepository, - JsonLocalFileAgentThreadRepository, - SerializedAgentThreadRepository, +from .agent_session_repository import ( + AgentSessionRepository, + InMemoryAgentSessionRepository, + JsonLocalFileAgentSessionRepository, + SerializedAgentSessionRepository, ) from .checkpoint_repository import ( CheckpointRepository, @@ -13,10 +13,10 @@ ) __all__ = [ - "AgentThreadRepository", - "InMemoryAgentThreadRepository", - "SerializedAgentThreadRepository", - "JsonLocalFileAgentThreadRepository", + "AgentSessionRepository", + "InMemoryAgentSessionRepository", + "SerializedAgentSessionRepository", + "JsonLocalFileAgentSessionRepository", "CheckpointRepository", "InMemoryCheckpointRepository", "FileCheckpointRepository", diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/persistence/_foundry_conversation_message_store.py b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/persistence/_foundry_conversation_message_store.py index b796b7cd0804..ea4d1507faa3 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/persistence/_foundry_conversation_message_store.py +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/persistence/_foundry_conversation_message_store.py @@ -1,112 +1,90 @@ # --------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. 
# --------------------------------------------------------- -from collections.abc import MutableMapping, Sequence -from typing import Any, List, Optional +from collections.abc import Sequence +from typing import Any, Optional -from agent_framework import Message +from agent_framework import BaseHistoryProvider, Message from azure.ai.agentserver.core.logger import get_logger from azure.ai.projects import AIProjectClient -from ..models.conversation_converters import ConversationItemConverter +from ..models.conversation_converters import to_chat_message from ..models.human_in_the_loop_helper import HumanInTheLoopHelper logger = get_logger() -class FoundryConversationMessageStore: - """Message store that reads messages from Azure AI Foundry Conversations API. +class FoundryConversationMessageStore(BaseHistoryProvider): + """History provider that reads Foundry conversation history and stores in session state.""" - This message store fetches messages from the Foundry Conversations API and converts them - to Agent Framework ``Message`` objects. Messages added via add_messages() are cached - locally but not persisted back to the API. - - :param conversation_id: The conversation ID to fetch messages from. - :type conversation_id: str - :param project_client: The Azure AI Project client used to retrieve conversation history. 
- :type project_client: AIProjectClient - """ + DEFAULT_SOURCE_ID = "foundry_conversation" def __init__( self, - conversation_id: str, project_client: AIProjectClient, + source_id: Optional[str] = None, ) -> None: - self._conversation_id = conversation_id + super().__init__(source_id=source_id or self.DEFAULT_SOURCE_ID) self._project_client = project_client - self._retrieved_messages: list[Message] = [] - self._cached_messages: list[Message] = [] - - async def list_messages(self) -> list[Message]: - """Get all messages from the conversation, including cached messages.""" - return self._retrieved_messages + self._cached_messages - - async def add_messages(self, messages: Sequence[Message]) -> None: - """Add messages to the local cache.""" - self._cached_messages.extend(messages) + self._hitl_helper = HumanInTheLoopHelper() - @classmethod - async def deserialize( # pylint: disable=unused-argument - cls, - serialized_store_state: MutableMapping[str, Any], + async def get_messages( + self, + session_id: Optional[str], *, - project_client: Optional[AIProjectClient] = None, + state: Optional[dict[str, Any]] = None, **kwargs: Any, - ) -> "FoundryConversationMessageStore": - conversation_id = serialized_store_state.get("conversation_id") - if not conversation_id: - raise ValueError("conversation_id is required in serialized state") + ) -> list[Message]: + if state is None: + return [] - store = cls( - conversation_id=conversation_id, - project_client=project_client, - ) + conversation_id = self._resolve_conversation_id(session_id, state) + if not conversation_id: + return list(state.get("messages", [])) - await store.update_from_state(serialized_store_state) - return store + if "retrieved_messages" not in state: + history_messages = await self._get_conversation_history(conversation_id) + state["retrieved_messages"] = self._hitl_helper.remove_hitl_content_from_session(history_messages or []) + retrieved_messages = state.get("retrieved_messages", []) + cached_messages = 
state.get("messages", []) + return [*retrieved_messages, *cached_messages] - async def update_from_state( # pylint: disable=unused-argument + async def save_messages( self, - serialized_store_state: MutableMapping[str, Any], + session_id: Optional[str], + messages: Sequence[Message], + *, + state: Optional[dict[str, Any]] = None, **kwargs: Any, ) -> None: - if not serialized_store_state: + if state is None: return - cached_messages_data = serialized_store_state.get("messages", []) - self._cached_messages = [] - for msg_data in cached_messages_data: - if isinstance(msg_data, dict): - self._cached_messages.append(Message.from_dict(msg_data)) - elif isinstance(msg_data, Message): - self._cached_messages.append(msg_data) - await self.retrieve_messages() - - async def serialize(self, **kwargs: Any) -> dict[str, Any]: # pylint: disable=unused-argument - return { - "conversation_id": self._conversation_id, - "messages": [msg.to_dict() for msg in self._cached_messages], - } - - async def retrieve_messages(self) -> None: - history_messages = await self._get_conversation_history() - filtered_messages = HumanInTheLoopHelper().remove_hitl_content_from_thread(history_messages or []) - self._retrieved_messages = filtered_messages - - async def _get_conversation_history(self) -> List[Message]: + conversation_id = self._resolve_conversation_id(session_id, state) + if conversation_id: + state["conversation_id"] = conversation_id + existing_messages = state.get("messages", []) + state["messages"] = [*existing_messages, *messages] + + @staticmethod + def _resolve_conversation_id(session_id: Optional[str], state: dict[str, Any]) -> Optional[str]: + conversation_id = state.get("conversation_id") + if isinstance(conversation_id, str) and conversation_id: + return conversation_id + return session_id + + async def _get_conversation_history(self, conversation_id: str) -> list[Message]: if not self._project_client: logger.error("AIProjectClient is not configured; cannot load conversation 
history.") return [] try: - converter = ConversationItemConverter() async with self._project_client.get_openai_client() as openai_client: - raw_items = await openai_client.conversations.items.list(self._conversation_id) + raw_items = await openai_client.conversations.items.list(conversation_id) retrieved_messages: list[Message] = [] if raw_items is None: - self._retrieved_messages = [] return [] iter_pages = getattr(raw_items, "iter_pages", None) @@ -114,19 +92,19 @@ async def _get_conversation_history(self) -> List[Message]: async for page in iter_pages(): items = getattr(page, "data", None) or [] for item in items: - chat_message = converter.to_chat_message(item) + chat_message = to_chat_message(item) if chat_message: retrieved_messages.append(chat_message) logger.info( "Retrieved %s messages for conversation %s from Foundry.", len(retrieved_messages), - self._conversation_id, + conversation_id, ) return retrieved_messages[::-1] except Exception as exc: # pylint: disable=broad-except logger.exception( "Failed to get conversation history for %s: %s", - self._conversation_id, + conversation_id, exc, ) return [] diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/persistence/_foundry_conversation_session_repository.py b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/persistence/_foundry_conversation_session_repository.py new file mode 100644 index 000000000000..7bdb91454d28 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/persistence/_foundry_conversation_session_repository.py @@ -0,0 +1,93 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- +from typing import Optional, Union + +from agent_framework import AgentSession + +from azure.ai.projects.aio import AIProjectClient +from azure.core.credentials import TokenCredential +from azure.core.credentials_async import AsyncTokenCredential + +from ._foundry_conversation_message_store import FoundryConversationMessageStore +from .agent_session_repository import AgentSessionRepository + + +class FoundryConversationSessionRepository(AgentSessionRepository): + """A Foundry Conversation implementation of AgentSessionRepository.""" + + def __init__( + self, + project_endpoint: str, + credential: Union[TokenCredential, AsyncTokenCredential], + ) -> None: + if not project_endpoint or not credential: + raise ValueError( + "Both project_endpoint and credential are required for " + "FoundryConversationSessionRepository." + ) + self._client = AIProjectClient(project_endpoint, credential) + self._history_provider = FoundryConversationMessageStore(project_client=self._client) + self._inventory: dict[str, AgentSession] = {} + + async def get( + self, + conversation_id: Optional[str], + ) -> Optional[AgentSession]: + """Retrieve the saved session for a given conversation ID. + + :param conversation_id: The conversation ID. + :type conversation_id: Optional[str] + :return: The saved AgentSession if available, None otherwise. 
+ :rtype: Optional[AgentSession] + """ + if not conversation_id: + return None + history_provider = self._history_provider + if conversation_id in self._inventory: + session = self._inventory[conversation_id] + provider_state = session.state.setdefault(history_provider.source_id, {}) + provider_state["conversation_id"] = conversation_id + return session + + session = FoundryConversationSession( + session_id=conversation_id, + service_session_id=conversation_id, + ) + provider_state = session.state.setdefault(history_provider.source_id, {}) + provider_state["conversation_id"] = conversation_id + self._inventory[conversation_id] = session + return session + + async def set(self, conversation_id: Optional[str], session: AgentSession) -> None: + """Save the session for a given conversation ID. + + :param conversation_id: The conversation ID. + :type conversation_id: Optional[str] + :param session: The session to save. + :type session: AgentSession + """ + if not conversation_id: + raise ValueError("conversation_id is required to save an AgentSession.") + + provider_state = session.state.setdefault(FoundryConversationMessageStore.DEFAULT_SOURCE_ID, {}) + provider_state["conversation_id"] = conversation_id + if not session.service_session_id: + session.service_session_id = conversation_id + self._inventory[conversation_id] = session + + @property + def history_provider(self) -> FoundryConversationMessageStore: + return self._history_provider + + +class FoundryConversationSession(AgentSession): + @property + def service_session_id(self) -> str | None: + return getattr(self, "_service_session_id", None) + + @service_session_id.setter + def service_session_id(self, service_session_id: str | None) -> None: + if service_session_id is None and hasattr(self, "_service_session_id"): + return + self._service_session_id = service_session_id diff --git 
a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/persistence/_foundry_conversation_thread_repository.py b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/persistence/_foundry_conversation_thread_repository.py deleted file mode 100644 index 9d57bfb6d0b7..000000000000 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/persistence/_foundry_conversation_thread_repository.py +++ /dev/null @@ -1,83 +0,0 @@ -# --------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# --------------------------------------------------------- -from typing import Optional, Union - -from agent_framework import AgentSession, SupportsAgentRun, WorkflowAgent - -from azure.ai.agentserver.core.logger import get_logger -from azure.ai.projects.aio import AIProjectClient -from azure.core.credentials import TokenCredential -from azure.core.credentials_async import AsyncTokenCredential - -from ._foundry_conversation_message_store import FoundryConversationMessageStore -from .agent_thread_repository import AgentThreadRepository - -logger = get_logger() - - -class FoundryConversationThreadRepository(AgentThreadRepository): - """A Foundry Conversation implementation of AgentThreadRepository.""" - - def __init__( - self, - agent: Optional[Union[SupportsAgentRun, WorkflowAgent]], - project_endpoint: str, - credential: Union[TokenCredential, AsyncTokenCredential], - ) -> None: - self._agent = agent - if not project_endpoint or not credential: - raise ValueError( - "Both project_endpoint and credential are required for " - "FoundryConversationThreadRepository." 
- ) - self._client = AIProjectClient(project_endpoint, credential) - self._inventory: dict[str, AgentSession] = {} - - async def get( - self, - conversation_id: Optional[str], - agent: Optional[Union[SupportsAgentRun, WorkflowAgent]] = None, - ) -> Optional[AgentSession]: - """Retrieve the saved thread for a given conversation ID. - - :param conversation_id: The conversation ID. - :type conversation_id: Optional[str] - :param agent: The agent instance. It will be used for in-memory repository for interface consistency. - :type agent: Optional[Union[SupportsAgentRun, WorkflowAgent]] - :return: The saved AgentSession if available, None otherwise. - :rtype: Optional[AgentSession] - """ - if not conversation_id: - return None - if conversation_id in self._inventory: - return self._inventory[conversation_id] - - message_store = FoundryConversationMessageStore(conversation_id, self._client) - await message_store.retrieve_messages() - self._inventory[conversation_id] = FoundryConversationThread(message_store=message_store) - return self._inventory[conversation_id] - - async def set(self, conversation_id: Optional[str], thread: AgentSession) -> None: - """Save the thread for a given conversation ID. - - :param conversation_id: The conversation ID. - :type conversation_id: Optional[str] - :param thread: The thread to save. 
- :type thread: AgentSession - """ - if not conversation_id: - raise ValueError("conversation_id is required to save an AgentSession.") - self._inventory[conversation_id] = thread - - -class FoundryConversationThread(AgentSession): - @property - def service_thread_id(self) -> str | None: - return self._service_thread_id - - @service_thread_id.setter - def service_thread_id(self, service_thread_id: str | None) -> None: - if service_thread_id is None: - return - self._service_thread_id = service_thread_id diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/persistence/agent_thread_repository.py b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/persistence/agent_session_repository.py similarity index 50% rename from sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/persistence/agent_thread_repository.py rename to sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/persistence/agent_session_repository.py index 7fa702d59c7c..1b462e6da8e1 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/persistence/agent_thread_repository.py +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/persistence/agent_session_repository.py @@ -4,12 +4,12 @@ import json import os from abc import ABC, abstractmethod -from typing import Any, Optional, Union +from typing import Any, Optional -from agent_framework import AgentSession, SupportsAgentRun, WorkflowAgent +from agent_framework import AgentSession -class AgentThreadRepository(ABC): +class AgentSessionRepository(ABC): """ Repository to manage persisted agent session state for conversations and workflows. 
@@ -20,32 +20,29 @@ class AgentThreadRepository(ABC): async def get( self, conversation_id: Optional[str], - agent: Optional[Union[SupportsAgentRun, WorkflowAgent]] = None, ) -> Optional[AgentSession]: - """Retrieve the saved thread for a given conversation ID. + """Retrieve the saved session for a given conversation ID. :param conversation_id: The conversation ID. :type conversation_id: Optional[str] - :param agent: The agent instance. If provided, it can be used to deserialize the thread. - :type agent: Optional[Union[SupportsAgentRun, WorkflowAgent]] :return: The saved AgentSession if available, None otherwise. :rtype: Optional[AgentSession] """ @abstractmethod - async def set(self, conversation_id: Optional[str], thread: AgentSession) -> None: - """Save the thread for a given conversation ID. + async def set(self, conversation_id: Optional[str], session: AgentSession) -> None: + """Save the session for a given conversation ID. :param conversation_id: The conversation ID. :type conversation_id: Optional[str] - :param thread: The thread to save. - :type thread: AgentSession + :param session: The session to save. + :type session: AgentSession """ -class InMemoryAgentThreadRepository(AgentThreadRepository): - """In-memory implementation of AgentThreadRepository.""" +class InMemoryAgentSessionRepository(AgentSessionRepository): + """In-memory implementation of AgentSessionRepository.""" def __init__(self) -> None: self._inventory: dict[str, AgentSession] = {} @@ -53,14 +50,11 @@ def __init__(self) -> None: async def get( self, conversation_id: Optional[str], - agent: Optional[Union[SupportsAgentRun, WorkflowAgent]] = None, ) -> Optional[AgentSession]: - """Retrieve the saved thread for a given conversation ID. + """Retrieve the saved session for a given conversation ID. :param conversation_id: The conversation ID. :type conversation_id: Optional[str] - :param agent: The agent instance. It will be used for in-memory repository for interface consistency. 
- :type agent: Optional[Union[SupportsAgentRun, WorkflowAgent]] :return: The saved AgentSession if available, None otherwise. :rtype: Optional[AgentSession] """ @@ -70,98 +64,82 @@ async def get( return self._inventory[conversation_id] return None - async def set(self, conversation_id: Optional[str], thread: AgentSession) -> None: - """Save the thread for a given conversation ID. + async def set(self, conversation_id: Optional[str], session: AgentSession) -> None: + """Save the session for a given conversation ID. :param conversation_id: The conversation ID. :type conversation_id: Optional[str] - :param thread: The thread to save. - :type thread: AgentSession + :param session: The session to save. + :type session: AgentSession """ - if not conversation_id or not thread: + if not conversation_id or not session: return - self._inventory[conversation_id] = thread + self._inventory[conversation_id] = session -class SerializedAgentThreadRepository(AgentThreadRepository): - """Implementation of AgentThreadRepository with AgentSession serialization.""" - - def __init__(self, agent: SupportsAgentRun) -> None: - """ - Initialize the repository with the given agent. - - :param agent: The agent instance. - :type agent: SupportsAgentRun - """ - self._agent = agent +class SerializedAgentSessionRepository(AgentSessionRepository): + """Implementation of AgentSessionRepository with AgentSession serialization.""" async def get( self, conversation_id: Optional[str], - agent: Optional[Union[SupportsAgentRun, WorkflowAgent]] = None, ) -> Optional[AgentSession]: - """Retrieve the saved thread for a given conversation ID. + """Retrieve the saved session for a given conversation ID. :param conversation_id: The conversation ID. :type conversation_id: Optional[str] - :param agent: The agent instance. If provided, it can be used to deserialize the thread. - Otherwise, the repository's agent will be used. 
- :type agent: Optional[Union[SupportsAgentRun, WorkflowAgent]] :return: The saved AgentSession if available, None otherwise. :rtype: Optional[AgentSession] """ if not conversation_id: return None - serialized_thread = await self.read_from_storage(conversation_id) - if serialized_thread: - agent_to_use = agent or self._agent - thread = await agent_to_use.deserialize_session(serialized_thread) - return thread + serialized_session = await self.read_from_storage(conversation_id) + if serialized_session: + return AgentSession.from_dict(serialized_session) return None - async def set(self, conversation_id: Optional[str], thread: AgentSession) -> None: - """Save the thread for a given conversation ID. + async def set(self, conversation_id: Optional[str], session: AgentSession) -> None: + """Save the session for a given conversation ID. :param conversation_id: The conversation ID. :type conversation_id: Optional[str] - :param thread: The thread to save. - :type thread: AgentSession + :param session: The session to save. + :type session: AgentSession """ if not conversation_id: return - serialized_thread = await thread.serialize() - await self.write_to_storage(conversation_id, serialized_thread) + serialized_session = session.to_dict() + await self.write_to_storage(conversation_id, serialized_session) async def read_from_storage(self, conversation_id: Optional[str]) -> Optional[Any]: - """Read the serialized thread from storage. + """Read the serialized session from storage. :param conversation_id: The conversation ID. :type conversation_id: Optional[str] - :return: The serialized thread if available, None otherwise. + :return: The serialized session if available, None otherwise. :rtype: Optional[Any] """ raise NotImplementedError("read_from_storage is not implemented.") - async def write_to_storage(self, conversation_id: Optional[str], serialized_thread: Any) -> None: - """Write the serialized thread to storage. 
+ async def write_to_storage(self, conversation_id: Optional[str], serialized_session: Any) -> None: + """Write the serialized session to storage. :param conversation_id: The conversation ID. :type conversation_id: Optional[str] - :param serialized_thread: The serialized thread to save. - :type serialized_thread: Any + :param serialized_session: The serialized session to save. + :type serialized_session: Any :return: None :rtype: None """ raise NotImplementedError("write_to_storage is not implemented.") -class JsonLocalFileAgentThreadRepository(SerializedAgentThreadRepository): - """Json based implementation of AgentThreadRepository using local file storage.""" +class JsonLocalFileAgentSessionRepository(SerializedAgentSessionRepository): + """Json based implementation of AgentSessionRepository using local file storage.""" - def __init__(self, agent: SupportsAgentRun, storage_path: str) -> None: - super().__init__(agent) + def __init__(self, storage_path: str) -> None: self._storage_path = storage_path os.makedirs(self._storage_path, exist_ok=True) @@ -171,15 +149,15 @@ async def read_from_storage(self, conversation_id: Optional[str]) -> Optional[An file_path = self._get_file_path(conversation_id) if os.path.exists(file_path): with open(file_path, "r", encoding="utf-8") as f: - serialized_thread = f.read() - if serialized_thread: - return json.loads(serialized_thread) + serialized_session = f.read() + if serialized_session: + return json.loads(serialized_session) return None - async def write_to_storage(self, conversation_id: Optional[str], serialized_thread: Any) -> None: + async def write_to_storage(self, conversation_id: Optional[str], serialized_session: Any) -> None: if not conversation_id: return - serialized_str = json.dumps(serialized_thread) + serialized_str = json.dumps(serialized_session) file_path = self._get_file_path(conversation_id) with open(file_path, "w", encoding="utf-8") as f: f.write(serialized_str) diff --git 
a/sdk/agentserver/azure-ai-agentserver-agentframework/samples/human_in_the_loop_ai_function/.gitignore b/sdk/agentserver/azure-ai-agentserver-agentframework/samples/human_in_the_loop_ai_function/.gitignore index c21dfd88c196..e22fd025f4c7 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/samples/human_in_the_loop_ai_function/.gitignore +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/samples/human_in_the_loop_ai_function/.gitignore @@ -1 +1 @@ -thread_storage \ No newline at end of file +session_storage \ No newline at end of file diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/samples/human_in_the_loop_ai_function/README.md b/sdk/agentserver/azure-ai-agentserver-agentframework/samples/human_in_the_loop_ai_function/README.md index 10c92b837663..77fc8483cadc 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/samples/human_in_the_loop_ai_function/README.md +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/samples/human_in_the_loop_ai_function/README.md @@ -20,14 +20,14 @@ AZURE_OPENAI_CHAT_DEPLOYMENT_NAME= `main.py` automatically loads the `.env` file before spinning up the server. -## Thread persistence +## Session persistence -The sample uses `JsonLocalFileAgentThreadRepository` for `AgentSession` persistence, creating a JSON file per conversation ID under the sample directory. An in-memory alternative, `InMemoryAgentThreadRepository`, lives in the `azure.ai.agentserver.agentframework.persistence` module. +The sample uses `JsonLocalFileAgentSessionRepository` for `AgentSession` persistence, creating a JSON file per conversation ID under the sample directory. An in-memory alternative, `InMemoryAgentSessionRepository`, lives in the `azure.ai.agentserver.agentframework.persistence` module. 
-To store thread messages elsewhere, inherit from `SerializedAgentThreadRepository` and override the following methods: +To store session messages elsewhere, inherit from `SerializedAgentSessionRepository` and override the following methods: - `read_from_storage(self, conversation_id: str) -> Optional[Any]` -- `write_to_storage(self, conversation_id: str, serialized_thread: Any)` +- `write_to_storage(self, conversation_id: str, serialized_session: Any)` These hooks let you plug in any backing store (blob storage, databases, etc.) without changing the rest of the sample. diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/samples/human_in_the_loop_ai_function/main.py b/sdk/agentserver/azure-ai-agentserver-agentframework/samples/human_in_the_loop_ai_function/main.py index c3df20e9ecd8..c21337ee1cff 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/samples/human_in_the_loop_ai_function/main.py +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/samples/human_in_the_loop_ai_function/main.py @@ -8,7 +8,7 @@ from dotenv import load_dotenv from azure.ai.agentserver.agentframework import from_agent_framework -from azure.ai.agentserver.agentframework.persistence.agent_thread_repository import JsonLocalFileAgentThreadRepository +from azure.ai.agentserver.agentframework.persistence.agent_session_repository import JsonLocalFileAgentSessionRepository """ Tool Approvals with Sessions @@ -39,8 +39,8 @@ def build_agent() -> ChatAgent: async def main() -> None: agent = build_agent() - thread_repository = JsonLocalFileAgentThreadRepository(agent=agent, storage_path="./thread_storage") - await from_agent_framework(agent, thread_repository=thread_repository).run_async() + session_repository = JsonLocalFileAgentSessionRepository(storage_path="./session_storage") + await from_agent_framework(agent, session_repository=session_repository).run_async() if __name__ == "__main__": diff --git 
a/sdk/agentserver/azure-ai-agentserver-agentframework/tests/unit_tests/test_conversation_id_optional.py b/sdk/agentserver/azure-ai-agentserver-agentframework/tests/unit_tests/test_conversation_id_optional.py index 4c9b68f2132a..562ea7086eeb 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/tests/unit_tests/test_conversation_id_optional.py +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/tests/unit_tests/test_conversation_id_optional.py @@ -1,30 +1,54 @@ # --------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # --------------------------------------------------------- -from unittest.mock import Mock +from unittest.mock import AsyncMock, Mock import pytest -from agent_framework import AgentSession, InMemoryCheckpointStorage +from agent_framework import AgentSession, BaseHistoryProvider, Content, InMemoryCheckpointStorage, Message +from azure.core.credentials_async import AsyncTokenCredential -from azure.ai.agentserver.agentframework.persistence.agent_thread_repository import ( - InMemoryAgentThreadRepository, +from azure.ai.agentserver.agentframework.persistence._foundry_conversation_message_store import ( + FoundryConversationMessageStore, +) +from azure.ai.agentserver.agentframework.persistence._foundry_conversation_session_repository import ( + FoundryConversationSessionRepository, +) +from azure.ai.agentserver.agentframework.persistence.agent_session_repository import ( + InMemoryAgentSessionRepository, + SerializedAgentSessionRepository, ) from azure.ai.agentserver.agentframework.persistence.checkpoint_repository import ( InMemoryCheckpointRepository, ) +class _MemorySerializedSessionRepository(SerializedAgentSessionRepository): + def __init__(self) -> None: + super().__init__() + self._storage: dict[str, dict] = {} + + async def read_from_storage(self, conversation_id): + if not conversation_id: + return None + return self._storage.get(conversation_id) + + async def 
write_to_storage(self, conversation_id, serialized_session): + if not conversation_id: + return + self._storage[conversation_id] = serialized_session + + @pytest.mark.unit @pytest.mark.asyncio -async def test_inmemory_thread_repository_ignores_missing_conversation_id() -> None: - repo = InMemoryAgentThreadRepository() - thread = Mock(spec=AgentSession) +async def test_inmemory_session_repository_ignores_missing_conversation_id() -> None: + repo = InMemoryAgentSessionRepository() + session = Mock(spec=AgentSession) - await repo.set(None, thread) + await repo.set(None, session) assert await repo.get(None) is None - await repo.set("conv-1", thread) - assert await repo.get("conv-1") is thread + await repo.set("conv-1", session) + assert await repo.get("conv-1") is session @pytest.mark.unit @@ -36,3 +60,91 @@ async def test_inmemory_checkpoint_repository_returns_none_without_conversation_ storage = await repo.get_or_create("conv-1") assert isinstance(storage, InMemoryCheckpointStorage) + + +@pytest.mark.unit +def test_foundry_message_store_is_base_history_provider() -> None: + provider = FoundryConversationMessageStore(project_client=Mock()) + + assert isinstance(provider, BaseHistoryProvider) + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_foundry_message_store_reads_remote_and_cached_messages(monkeypatch) -> None: + provider = FoundryConversationMessageStore(project_client=Mock()) + history_message = Message(role="assistant", contents=[Content.from_text("history")]) + new_message = Message(role="user", contents=[Content.from_text("new")]) + + get_history = AsyncMock(return_value=[history_message]) + monkeypatch.setattr(provider, "_get_conversation_history", get_history) + + state = {"conversation_id": "conv-1"} + first_messages = await provider.get_messages("conv-1", state=state) + await provider.save_messages("conv-1", [new_message], state=state) + second_messages = await provider.get_messages("conv-1", state=state) + + assert [message.text for message 
in first_messages] == ["history"] + assert [message.role for message in second_messages] == ["assistant", "user"] + assert [message.text for message in second_messages] == ["history", "new"] + get_history.assert_awaited_once() + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_foundry_session_repository_sets_conversation_state(monkeypatch) -> None: + import azure.ai.agentserver.agentframework.persistence._foundry_conversation_session_repository as repo_module + + monkeypatch.setattr(repo_module, "AIProjectClient", lambda endpoint, credential: Mock()) + + repo = FoundryConversationSessionRepository( + project_endpoint="https://test.services.ai.azure.com/api/projects/test-project", + credential=Mock(spec=AsyncTokenCredential), + ) + + session_one = await repo.get("conv-1") + session_two = await repo.get("conv-2") + source_id = repo.history_provider.source_id + + assert session_one is not None + assert session_two is not None + assert session_one.session_id == "conv-1" + assert session_one.service_session_id == "conv-1" + assert session_two.session_id == "conv-2" + assert session_two.service_session_id == "conv-2" + assert session_one.state[source_id]["conversation_id"] == "conv-1" + assert session_two.state[source_id]["conversation_id"] == "conv-2" + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_foundry_session_repository_reuses_existing_session(monkeypatch) -> None: + import azure.ai.agentserver.agentframework.persistence._foundry_conversation_session_repository as repo_module + + monkeypatch.setattr(repo_module, "AIProjectClient", lambda endpoint, credential: Mock()) + + repo = FoundryConversationSessionRepository( + project_endpoint="https://test.services.ai.azure.com/api/projects/test-project", + credential=Mock(spec=AsyncTokenCredential), + ) + + first_session = await repo.get("conv-1") + second_session = await repo.get("conv-1") + + assert first_session is second_session + + +@pytest.mark.unit +@pytest.mark.asyncio +async def 
test_serialized_session_repository_uses_agent_session_to_dict_from_dict() -> None: + repo = _MemorySerializedSessionRepository() + session = AgentSession(session_id="local-1", service_session_id="service-1") + session.state["counter"] = 2 + + await repo.set("conv-1", session) + loaded = await repo.get("conv-1") + + assert loaded is not None + assert loaded.session_id == "local-1" + assert loaded.service_session_id == "service-1" + assert loaded.state["counter"] == 2 diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/tests/unit_tests/test_conversation_item_converter.py b/sdk/agentserver/azure-ai-agentserver-agentframework/tests/unit_tests/test_conversation_item_converter.py index 9a98bf3576cb..fd15a4423edc 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/tests/unit_tests/test_conversation_item_converter.py +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/tests/unit_tests/test_conversation_item_converter.py @@ -9,16 +9,11 @@ from openai.types.responses.response_function_tool_call_output_item import ResponseFunctionToolCallOutputItem from openai.types.responses.response_input_text import ResponseInputText -from azure.ai.agentserver.agentframework.models.conversation_converters import ConversationItemConverter - - -@pytest.fixture() -def converter() -> ConversationItemConverter: - return ConversationItemConverter() +from azure.ai.agentserver.agentframework.models.conversation_converters import to_chat_message @pytest.mark.unit -def test_to_chat_message_converts_basic_message(converter: ConversationItemConverter) -> None: +def test_to_chat_message_converts_basic_message() -> None: item = Message( id="msg_1", role="user", @@ -27,7 +22,7 @@ def test_to_chat_message_converts_basic_message(converter: ConversationItemConve content=[ResponseInputText(text="Hello world", type="input_text")], ) - result = converter.to_chat_message(item) + result = to_chat_message(item) assert result is not None assert isinstance(result, AFMessage) @@ -37,7 
+32,7 @@ def test_to_chat_message_converts_basic_message(converter: ConversationItemConve @pytest.mark.unit -def test_to_chat_message_converts_function_call_item(converter: ConversationItemConverter) -> None: +def test_to_chat_message_converts_function_call_item() -> None: item = ResponseFunctionToolCallItem( id="call_item_1", type="function_call", @@ -47,7 +42,7 @@ def test_to_chat_message_converts_function_call_item(converter: ConversationItem arguments='{"foo": "bar"}', ) - result = converter.to_chat_message(item) + result = to_chat_message(item) assert result is not None assert result.role == "assistant" @@ -62,7 +57,7 @@ def test_to_chat_message_converts_function_call_item(converter: ConversationItem @pytest.mark.unit -def test_to_chat_message_converts_function_result_item(converter: ConversationItemConverter) -> None: +def test_to_chat_message_converts_function_result_item() -> None: item = ResponseFunctionToolCallOutputItem( id="call_output_1", type="function_call_output", @@ -71,7 +66,7 @@ def test_to_chat_message_converts_function_result_item(converter: ConversationIt output='{"answer": 42}', ) - result = converter.to_chat_message(item) + result = to_chat_message(item) assert result is not None assert result.role == "tool" @@ -84,7 +79,7 @@ def test_to_chat_message_converts_function_result_item(converter: ConversationIt @pytest.mark.unit -def test_to_chat_message_converts_reasoning_item(converter: ConversationItemConverter) -> None: +def test_to_chat_message_converts_reasoning_item() -> None: reasoning_item = response_reasoning_item.ResponseReasoningItem( id="reasoning_1", type="reasoning", @@ -93,7 +88,7 @@ def test_to_chat_message_converts_reasoning_item(converter: ConversationItemConv content=[response_reasoning_item.Content(text="Chain-of-thought", type="reasoning_text")], ) - result = converter.to_chat_message(reasoning_item) + result = to_chat_message(reasoning_item) assert result is not None assert result.role == "assistant" @@ -103,8 +98,8 @@ def 
test_to_chat_message_converts_reasoning_item(converter: ConversationItemConv @pytest.mark.unit -def test_to_chat_message_returns_none_for_unsupported_items(converter: ConversationItemConverter) -> None: +def test_to_chat_message_returns_none_for_unsupported_items() -> None: class UnsupportedItem: type = "unsupported" - assert converter.to_chat_message(UnsupportedItem()) is None + assert to_chat_message(UnsupportedItem()) is None diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/tests/unit_tests/test_from_agent_framework_managed.py b/sdk/agentserver/azure-ai-agentserver-agentframework/tests/unit_tests/test_from_agent_framework_managed.py index d1a666a3cb5c..5347a3f32298 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/tests/unit_tests/test_from_agent_framework_managed.py +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/tests/unit_tests/test_from_agent_framework_managed.py @@ -29,7 +29,7 @@ def __init__( self, workflow_factory, credentials=None, - thread_repository=None, + session_repository=None, checkpoint_repository=None, **kwargs, ): diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/tests/unit_tests/test_human_in_the_loop_helper.py b/sdk/agentserver/azure-ai-agentserver-agentframework/tests/unit_tests/test_human_in_the_loop_helper.py index d0a088af8721..bf6f0bd74ac5 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/tests/unit_tests/test_human_in_the_loop_helper.py +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/tests/unit_tests/test_human_in_the_loop_helper.py @@ -15,17 +15,11 @@ def helper() -> HumanInTheLoopHelper: def _function_call(call_id: str, name: str, arguments: str): - factory = getattr(Content, "from_function_call", None) - if callable(factory): - return factory(call_id=call_id, name=name, arguments=arguments) - return Content(type="function_call", call_id=call_id, name=name, arguments=arguments) + return Content.from_function_call(call_id=call_id, name=name, 
arguments=arguments) def _function_result(call_id: str, result): - factory = getattr(Content, "from_function_result", None) - if callable(factory): - return factory(call_id=call_id, result=result) - return Content(type="function_result", call_id=call_id, result=result) + return Content.from_function_result(call_id=call_id, result=result) @pytest.mark.unit @@ -36,14 +30,14 @@ def test_remove_hitl_messages_keeps_latest_function_result(helper: HumanInTheLoo final_result = _function_result("tool-1", {"total": 42}) follow_up_content = Content.from_text("work resumed") - thread_messages = [ + session_messages = [ Message(role="assistant", contents=[real_call, hitl_call]), Message(role="tool", contents=[feedback_result]), Message(role="tool", contents=[final_result]), Message(role="assistant", contents=[follow_up_content]), ] - filtered = helper.remove_hitl_content_from_thread(thread_messages) + filtered = helper.remove_hitl_content_from_session(session_messages) assert len(filtered) == 3 assert filtered[0].role == "assistant" @@ -63,13 +57,13 @@ def test_remove_hitl_messages_keeps_the_function_result(helper: HumanInTheLoopHe final_result = _function_result("tool-1", {"total": 42}) follow_up_content = Content.from_text("work resumed") - thread_messages = [ + session_messages = [ Message(role="assistant", contents=[real_call, hitl_call]), Message(role="tool", contents=[final_result]), Message(role="assistant", contents=[follow_up_content]), ] - filtered = helper.remove_hitl_content_from_thread(thread_messages) + filtered = helper.remove_hitl_content_from_session(session_messages) assert len(filtered) == 3 assert filtered[0].role == "assistant" @@ -88,13 +82,13 @@ def test_remove_hitl_messages_skips_orphaned_hitl_results(helper: HumanInTheLoop orphan_result = _function_result("hitl-2", "ignored") user_update = Content.from_text("ready") - thread_messages = [ + session_messages = [ Message(role="assistant", contents=[hitl_call]), Message(role="tool", 
contents=[orphan_result]), Message(role="user", contents=[user_update]), ] - filtered = helper.remove_hitl_content_from_thread(thread_messages) + filtered = helper.remove_hitl_content_from_session(session_messages) assert len(filtered) == 1 assert len(filtered[0].contents) == 1 @@ -106,12 +100,12 @@ def test_remove_hitl_messages_preserves_regular_tool_cycle(helper: HumanInTheLoo real_call = _function_call("tool-3", "lookup", "{}") result_content = _function_result("tool-3", "done") - thread_messages = [ + session_messages = [ Message(role="assistant", contents=[real_call]), Message(role="tool", contents=[result_content]), ] - filtered = helper.remove_hitl_content_from_thread(thread_messages) + filtered = helper.remove_hitl_content_from_session(session_messages) assert len(filtered) == 2 assert len(filtered[0].contents) == 1