diff --git a/libraries/microsoft-agents-hosting-aiohttp/microsoft_agents/hosting/aiohttp/__init__.py b/libraries/microsoft-agents-hosting-aiohttp/microsoft_agents/hosting/aiohttp/__init__.py index ed572337..f639246f 100644 --- a/libraries/microsoft-agents-hosting-aiohttp/microsoft_agents/hosting/aiohttp/__init__.py +++ b/libraries/microsoft-agents-hosting-aiohttp/microsoft_agents/hosting/aiohttp/__init__.py @@ -6,7 +6,9 @@ jwt_authorization_middleware, jwt_authorization_decorator, ) -from .app.streaming import ( + +# Import streaming utilities from core for backward compatibility +from microsoft_agents.hosting.core.app.streaming import ( Citation, CitationUtil, StreamingResponse, diff --git a/libraries/microsoft-agents-hosting-aiohttp/microsoft_agents/hosting/aiohttp/app/streaming/streaming_response.py b/libraries/microsoft-agents-hosting-aiohttp/microsoft_agents/hosting/aiohttp/app/streaming/streaming_response.py index f495aa9c..05986cb1 100644 --- a/libraries/microsoft-agents-hosting-aiohttp/microsoft_agents/hosting/aiohttp/app/streaming/streaming_response.py +++ b/libraries/microsoft-agents-hosting-aiohttp/microsoft_agents/hosting/aiohttp/app/streaming/streaming_response.py @@ -3,8 +3,7 @@ import asyncio import logging -from typing import List, Optional, Callable, Literal, TYPE_CHECKING -from dataclasses import dataclass +from typing import List, Optional, Callable, Literal from microsoft_agents.activity import ( Activity, diff --git a/libraries/microsoft-agents-hosting-aiohttp/microsoft_agents/hosting/aiohttp/channel_service_route_table.py b/libraries/microsoft-agents-hosting-aiohttp/microsoft_agents/hosting/aiohttp/channel_service_route_table.py index 4a8a193b..3a2acad4 100644 --- a/libraries/microsoft-agents-hosting-aiohttp/microsoft_agents/hosting/aiohttp/channel_service_route_table.py +++ b/libraries/microsoft-agents-hosting-aiohttp/microsoft_agents/hosting/aiohttp/channel_service_route_table.py @@ -1,102 +1,81 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. import json -from typing import List, Union, Type from aiohttp.web import RouteTableDef, Request, Response -from microsoft_agents.activity import ( - AgentsModel, - Activity, - AttachmentData, - ConversationParameters, - Transcript, -) from microsoft_agents.hosting.core import ChannelApiHandlerProtocol +from microsoft_agents.hosting.core.http import ChannelServiceRoutes -async def deserialize_from_body( - request: Request, target_model: Type[AgentsModel] -) -> Activity: - if "application/json" in request.headers["Content-Type"]: - body = await request.json() - else: - return Response(status=415) +class AiohttpRequestAdapter: + """Adapter for aiohttp requests to use with ChannelServiceRoutes.""" - return target_model.model_validate(body) + def __init__(self, request: Request): + self._request = request + @property + def method(self) -> str: + return self._request.method -def get_serialized_response( - model_or_list: Union[AgentsModel, List[AgentsModel]], -) -> Response: - if isinstance(model_or_list, AgentsModel): - json_obj = model_or_list.model_dump( - mode="json", exclude_unset=True, by_alias=True - ) - else: - json_obj = [ - model.model_dump(mode="json", exclude_unset=True, by_alias=True) - for model in model_or_list - ] + @property + def headers(self): + return self._request.headers + + async def json(self): + return await self._request.json() + + def get_claims_identity(self): + return self._request.get("claims_identity") - return Response(body=json.dumps(json_obj), content_type="application/json") + def get_path_param(self, name: str) -> str: + return self._request.match_info[name] def channel_service_route_table( handler: ChannelApiHandlerProtocol, base_url: str = "" ) -> RouteTableDef: - # pylint: disable=unused-variable + """Create aiohttp route table for Channel Service API. + + Args: + handler: The handler that implements the Channel API protocol. + base_url: Optional base URL prefix for all routes. + + Returns: + RouteTableDef with all channel service routes. + """ routes = RouteTableDef() + service_routes = ChannelServiceRoutes(handler, base_url) + + def json_response(data: dict) -> Response: + return Response(body=json.dumps(data), content_type="application/json") @routes.post(base_url + "/v3/conversations/{conversation_id}/activities") async def send_to_conversation(request: Request): - activity = await deserialize_from_body(request, Activity) - result = await handler.on_send_to_conversation( - request.get("claims_identity"), - request.match_info["conversation_id"], - activity, + result = await service_routes.send_to_conversation( + AiohttpRequestAdapter(request) ) - - return get_serialized_response(result) + return json_response(result) @routes.post( base_url + "/v3/conversations/{conversation_id}/activities/{activity_id}" ) async def reply_to_activity(request: Request): - activity = await deserialize_from_body(request, Activity) - result = await handler.on_reply_to_activity( - request.get("claims_identity"), - request.match_info["conversation_id"], - request.match_info["activity_id"], - activity, - ) - - return get_serialized_response(result) + result = await service_routes.reply_to_activity(AiohttpRequestAdapter(request)) + return json_response(result) @routes.put( base_url + "/v3/conversations/{conversation_id}/activities/{activity_id}" ) async def update_activity(request: Request): - activity = await deserialize_from_body(request, Activity) - result = await handler.on_update_activity( - request.get("claims_identity"), - request.match_info["conversation_id"], - request.match_info["activity_id"], - activity, - ) - - return get_serialized_response(result) + result = await service_routes.update_activity(AiohttpRequestAdapter(request)) + return json_response(result) @routes.delete( base_url + "/v3/conversations/{conversation_id}/activities/{activity_id}" ) async def delete_activity(request: Request): - await handler.on_delete_activity( - request.get("claims_identity"), - request.match_info["conversation_id"], - request.match_info["activity_id"], - ) - + await service_routes.delete_activity(AiohttpRequestAdapter(request)) return Response() @routes.get( @@ -104,91 +83,61 @@ async def delete_activity(request: Request): + "/v3/conversations/{conversation_id}/activities/{activity_id}/members" ) async def get_activity_members(request: Request): - result = await handler.on_get_activity_members( - request.get("claims_identity"), - request.match_info["conversation_id"], - request.match_info["activity_id"], + result = await service_routes.get_activity_members( + AiohttpRequestAdapter(request) ) - - return get_serialized_response(result) + return json_response(result) @routes.post(base_url + "/") async def create_conversation(request: Request): - conversation_parameters = deserialize_from_body(request, ConversationParameters) - result = await handler.on_create_conversation( - request.get("claims_identity"), conversation_parameters + result = await service_routes.create_conversation( + AiohttpRequestAdapter(request) ) - - return get_serialized_response(result) + return json_response(result) @routes.get(base_url + "/") async def get_conversation(request: Request): - # TODO: continuation token? conversation_id? - result = await handler.on_get_conversations( - request.get("claims_identity"), None - ) - - return get_serialized_response(result) + result = await service_routes.get_conversations(AiohttpRequestAdapter(request)) + return json_response(result) @routes.get(base_url + "/v3/conversations/{conversation_id}/members") async def get_conversation_members(request: Request): - result = await handler.on_get_conversation_members( - request.get("claims_identity"), - request.match_info["conversation_id"], + result = await service_routes.get_conversation_members( + AiohttpRequestAdapter(request) ) - - return get_serialized_response(result) + return json_response(result) @routes.get(base_url + "/v3/conversations/{conversation_id}/members/{member_id}") async def get_conversation_member(request: Request): - result = await handler.on_get_conversation_member( - request.get("claims_identity"), - request.match_info["member_id"], - request.match_info["conversation_id"], + result = await service_routes.get_conversation_member( + AiohttpRequestAdapter(request) ) - - return get_serialized_response(result) + return json_response(result) @routes.get(base_url + "/v3/conversations/{conversation_id}/pagedmembers") async def get_conversation_paged_members(request: Request): - # TODO: continuation token? page size? - result = await handler.on_get_conversation_paged_members( - request.get("claims_identity"), - request.match_info["conversation_id"], + result = await service_routes.get_conversation_paged_members( + AiohttpRequestAdapter(request) ) - - return get_serialized_response(result) + return json_response(result) @routes.delete(base_url + "/v3/conversations/{conversation_id}/members/{member_id}") async def delete_conversation_member(request: Request): - result = await handler.on_delete_conversation_member( - request.get("claims_identity"), - request.match_info["conversation_id"], - request.match_info["member_id"], + result = await service_routes.delete_conversation_member( + AiohttpRequestAdapter(request) ) - - return get_serialized_response(result) + return json_response(result) @routes.post(base_url + "/v3/conversations/{conversation_id}/activities/history") async def send_conversation_history(request: Request): - transcript = deserialize_from_body(request, Transcript) - result = await handler.on_send_conversation_history( - request.get("claims_identity"), - request.match_info["conversation_id"], - transcript, + result = await service_routes.send_conversation_history( + AiohttpRequestAdapter(request) ) - - return get_serialized_response(result) + return json_response(result) @routes.post(base_url + "/v3/conversations/{conversation_id}/attachments") async def upload_attachment(request: Request): - attachment_data = deserialize_from_body(request, AttachmentData) - result = await handler.on_upload_attachment( - request.get("claims_identity"), - request.match_info["conversation_id"], - attachment_data, - ) - - return get_serialized_response(result) + result = await service_routes.upload_attachment(AiohttpRequestAdapter(request)) + return json_response(result) return routes diff --git a/libraries/microsoft-agents-hosting-aiohttp/microsoft_agents/hosting/aiohttp/cloud_adapter.py b/libraries/microsoft-agents-hosting-aiohttp/microsoft_agents/hosting/aiohttp/cloud_adapter.py index a8833f22..c384dd95 100644 --- a/libraries/microsoft-agents-hosting-aiohttp/microsoft_agents/hosting/aiohttp/cloud_adapter.py +++ b/libraries/microsoft-agents-hosting-aiohttp/microsoft_agents/hosting/aiohttp/cloud_adapter.py @@ -1,39 +1,47 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -from traceback import format_exc from typing import Optional -from aiohttp.web import ( - Request, - Response, - json_response, - HTTPBadRequest, - HTTPMethodNotAllowed, - HTTPUnauthorized, - HTTPUnsupportedMediaType, -) -from microsoft_agents.hosting.core import error_resources -from microsoft_agents.hosting.core.authorization import ( - ClaimsIdentity, - Connections, -) -from microsoft_agents.activity import ( - Activity, - DeliveryModes, -) -from microsoft_agents.hosting.core import ( - Agent, - ChannelServiceAdapter, - ChannelServiceClientFactoryBase, - MessageFactory, - RestChannelServiceClientFactory, - TurnContext, +from aiohttp.web import Request, Response, json_response + +from microsoft_agents.hosting.core import Agent +from microsoft_agents.hosting.core.authorization import Connections +from microsoft_agents.hosting.core.http import ( + HttpAdapterBase, + HttpResponse, ) +from microsoft_agents.hosting.core import ChannelServiceClientFactoryBase from .agent_http_adapter import AgentHttpAdapter -class CloudAdapter(ChannelServiceAdapter, AgentHttpAdapter): +class AiohttpRequestAdapter: + """Adapter to make aiohttp Request compatible with HttpRequestProtocol.""" + + def __init__(self, request: Request): + self._request = request + + @property + def method(self) -> str: + return self._request.method + + @property + def headers(self): + return self._request.headers + + async def json(self): + return await self._request.json() + + def get_claims_identity(self): + return self._request.get("claims_identity") + + def get_path_param(self, name: str) -> str: + return self._request.match_info[name] + + +class CloudAdapter(HttpAdapterBase, AgentHttpAdapter): + """CloudAdapter for aiohttp web framework.""" + def __init__( self, *, @@ -43,77 +51,43 @@ def __init__( """ Initializes a new instance of the CloudAdapter class. + :param connection_manager: Optional connection manager for OAuth. :param channel_service_client_factory: The factory to use to create the channel service client. """ + super().__init__( + connection_manager=connection_manager, + channel_service_client_factory=channel_service_client_factory, + ) - async def on_turn_error(context: TurnContext, error: Exception): - error_message = f"Exception caught : {error}" - print(format_exc()) + async def process(self, request: Request, agent: Agent) -> Optional[Response]: + """Process an aiohttp request. - await context.send_activity(MessageFactory.text(error_message)) + Args: + request: The aiohttp request. + agent: The agent to handle the request. - # Send a trace activity - await context.send_trace_activity( - "OnTurnError Trace", - error_message, - "https://www.botframework.com/schemas/error", - "TurnError", + Returns: + aiohttp Response object. + """ + # Adapt request to protocol + adapted_request = AiohttpRequestAdapter(request) + + # Process using base implementation + http_response: HttpResponse = await self.process_request(adapted_request, agent) + + # Convert HttpResponse to aiohttp Response + return self._to_aiohttp_response(http_response) + + @staticmethod + def _to_aiohttp_response(http_response: HttpResponse) -> Response: + """Convert HttpResponse to aiohttp Response.""" + if http_response.body is not None: + return json_response( + data=http_response.body, + status=http_response.status_code, + headers=http_response.headers, ) - - self.on_turn_error = on_turn_error - - channel_service_client_factory = ( - channel_service_client_factory - or RestChannelServiceClientFactory(connection_manager) + return Response( + status=http_response.status_code, + headers=http_response.headers, ) - - super().__init__(channel_service_client_factory) - - async def process(self, request: Request, agent: Agent) -> Optional[Response]: - if not request: - raise TypeError(str(error_resources.RequestRequired)) - if not agent: - raise TypeError(str(error_resources.AgentRequired)) - - if request.method == "POST": - # Deserialize the incoming Activity - if "application/json" in request.headers["Content-Type"]: - body = await request.json() - else: - raise HTTPUnsupportedMediaType() - - activity: Activity = Activity.model_validate(body) - - # default to anonymous identity with no claims - claims_identity: ClaimsIdentity = request.get( - "claims_identity", ClaimsIdentity({}, False) - ) - - # A POST request must contain an Activity - if ( - not activity.type - or not activity.conversation - or not activity.conversation.id - ): - raise HTTPBadRequest - - try: - # Process the inbound activity with the agent - invoke_response = await self.process_activity( - claims_identity, activity, agent.on_turn - ) - - if ( - activity.type == "invoke" - or activity.delivery_mode == DeliveryModes.expect_replies - ): - # Invoke and ExpectReplies cannot be performed async, the response must be written before the calling thread is released. - return json_response( - data=invoke_response.body, status=invoke_response.status - ) - - return Response(status=202) - except PermissionError: - raise HTTPUnauthorized - else: - raise HTTPMethodNotAllowed diff --git a/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/__init__.py b/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/__init__.py index 9abce32c..329682c0 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/__init__.py +++ b/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/__init__.py @@ -10,6 +10,15 @@ from .rest_channel_service_client_factory import RestChannelServiceClientFactory from .turn_context import TurnContext +# HTTP abstractions +from .http import ( + HttpRequestProtocol, + HttpResponse, + HttpResponseFactory, + HttpAdapterBase, + ChannelServiceRoutes, +) + # Application Style from .app._type_defs import RouteHandler, RouteSelector, StateT from .app.agent_application import AgentApplication @@ -20,6 +29,13 @@ from .app._routes import _Route, _RouteList, RouteRank from .app.typing_indicator import TypingIndicator +# App Streaming +from .app.streaming import ( + Citation, + CitationUtil, + StreamingResponse, +) + # App Auth from .app.oauth import ( Authorization, @@ -99,6 +115,11 @@ "Middleware", "RestChannelServiceClientFactory", "TurnContext", + "HttpRequestProtocol", + "HttpResponse", + "HttpResponseFactory", + "HttpAdapterBase", + "ChannelServiceRoutes", "AgentApplication", "ApplicationError", "ApplicationOptions", @@ -108,6 +129,9 @@ "Route", "RouteHandler", "TypingIndicator", + "Citation", + "CitationUtil", + "StreamingResponse", "ConversationState", "state", "State", diff --git a/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/app/streaming/__init__.py b/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/app/streaming/__init__.py new file mode 100644 index 00000000..89efa87d --- /dev/null +++ b/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/app/streaming/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""Streaming response utilities.""" + +from .citation import Citation +from .citation_util import CitationUtil +from .streaming_response import StreamingResponse + +__all__ = [ + "Citation", + "CitationUtil", + "StreamingResponse", +] diff --git a/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/app/streaming/citation.py b/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/app/streaming/citation.py new file mode 100644 index 00000000..f643639a --- /dev/null +++ b/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/app/streaming/citation.py @@ -0,0 +1,22 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +from typing import Optional +from dataclasses import dataclass + + +@dataclass +class Citation: + """Citations returned by the model.""" + + content: str + """The content of the citation.""" + + title: Optional[str] = None + """The title of the citation.""" + + url: Optional[str] = None + """The URL of the citation.""" + + filepath: Optional[str] = None + """The filepath of the document.""" diff --git a/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/app/streaming/citation_util.py b/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/app/streaming/citation_util.py new file mode 100644 index 00000000..1ec923dc --- /dev/null +++ b/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/app/streaming/citation_util.py @@ -0,0 +1,85 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import re +from typing import List, Optional + +from microsoft_agents.activity import ClientCitation + + +class CitationUtil: + """Utility functions for manipulating text and citations.""" + + @staticmethod + def snippet(text: str, max_length: int) -> str: + """ + Clips the text to a maximum length in case it exceeds the limit. + + Args: + text: The text to clip. + max_length: The maximum length of the text to return, cutting off the last whole word. + + Returns: + The modified text + """ + if len(text) <= max_length: + return text + + snippet = text[:max_length] + snippet = snippet[: min(len(snippet), snippet.rfind(" "))] + snippet += "..." + return snippet + + @staticmethod + def format_citations_response(text: str) -> str: + """ + Convert citation tags `[doc(s)n]` to `[n]` where n is a number. + + Args: + text: The text to format. + + Returns: + The formatted text. + """ + return re.sub(r"\[docs?(\d+)\]", r"[\1]", text, flags=re.IGNORECASE) + + @staticmethod + def get_used_citations( + text: str, citations: List[ClientCitation] + ) -> Optional[List[ClientCitation]]: + """ + Get the citations used in the text. This will remove any citations that are + included in the citations array from the response but not referenced in the text. + + Args: + text: The text to search for citation references, i.e. [1], [2], etc. + citations: The list of citations to search for. + + Returns: + The list of citations used in the text. + """ + regex = re.compile(r"\[(\d+)\]", re.IGNORECASE) + matches = regex.findall(text) + + if not matches: + return None + + # Remove duplicates + filtered_matches = set(matches) + + # Add citations + used_citations = [] + for match in filtered_matches: + citation_ref = f"[{match}]" + found = next( + ( + citation + for citation in citations + if f"[{citation.position}]" == citation_ref + ), + None, + ) + if found: + used_citations.append(found) + + return used_citations if used_citations else None diff --git a/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/app/streaming/streaming_response.py b/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/app/streaming/streaming_response.py new file mode 100644 index 00000000..2d5b0fbf --- /dev/null +++ b/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/app/streaming/streaming_response.py @@ -0,0 +1,411 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import asyncio +import logging +from typing import List, Optional, Callable, Literal, TYPE_CHECKING + +from microsoft_agents.activity import ( + Activity, + Entity, + Attachment, + Channels, + ClientCitation, + DeliveryModes, + SensitivityUsageInfo, +) + +if TYPE_CHECKING: + from microsoft_agents.hosting.core.turn_context import TurnContext + +from .citation import Citation +from .citation_util import CitationUtil + +logger = logging.getLogger(__name__) + + +class StreamingResponse: + """ + A helper class for streaming responses to the client. + + This class is used to send a series of updates to the client in a single response. + The expected sequence of calls is: + + `queue_informative_update()`, `queue_text_chunk()`, `queue_text_chunk()`, ..., `end_stream()`. + + Once `end_stream()` is called, the stream is considered ended and no further updates can be sent. + """ + + def __init__(self, context: "TurnContext"): + """ + Creates a new StreamingResponse instance. + + Args: + context: Context for the current turn of conversation with the user. + """ + self._context = context + self._sequence_number = 1 + self._stream_id: Optional[str] = None + self._message = "" + self._attachments: Optional[List[Attachment]] = None + self._ended = False + self._cancelled = False + + # Queue for outgoing activities + self._queue: List[Callable[[], Activity]] = [] + self._queue_sync: Optional[asyncio.Task] = None + self._chunk_queued = False + + # Powered by AI feature flags + self._enable_feedback_loop = False + self._feedback_loop_type: Optional[Literal["default", "custom"]] = None + self._enable_generated_by_ai_label = False + self._citations: Optional[List[ClientCitation]] = [] + self._sensitivity_label: Optional[SensitivityUsageInfo] = None + + # Channel information + self._is_streaming_channel: bool = False + self._channel_id: Channels = None + self._interval: float = 0.1 # Default interval for sending updates + self._set_defaults(context) + + @property + def stream_id(self) -> Optional[str]: + """ + Gets the stream ID of the current response. + Assigned after the initial update is sent. + """ + return self._stream_id + + @property + def citations(self) -> Optional[List[ClientCitation]]: + """Gets the citations of the current response.""" + return self._citations + + @property + def updates_sent(self) -> int: + """Gets the number of updates sent for the stream.""" + return self._sequence_number - 1 + + def queue_informative_update(self, text: str) -> None: + """ + Queues an informative update to be sent to the client. + + Args: + text: Text of the update to send. + """ + if not self._is_streaming_channel: + return + + if self._ended: + raise RuntimeError("The stream has already ended.") + + # Queue a typing activity + def create_activity(): + activity = Activity( + type="typing", + text=text, + entities=[ + Entity( + type="streaminfo", + stream_type="informative", + stream_sequence=self._sequence_number, + ) + ], + ) + self._sequence_number += 1 + return activity + + self._queue_activity(create_activity) + + def queue_text_chunk( + self, text: str, citations: Optional[List[Citation]] = None + ) -> None: + """ + Queues a chunk of partial message text to be sent to the client. + + The text will be sent as quickly as possible to the client. + Chunks may be combined before delivery to the client. + + Args: + text: Partial text of the message to send. + citations: Citations to be included in the message. + """ + if self._cancelled: + return + if self._ended: + raise RuntimeError("The stream has already ended.") + + # Update full message text + self._message += text + + # If there are citations, modify the content so that the sources are numbers instead of [doc1], [doc2], etc. + self._message = CitationUtil.format_citations_response(self._message) + + # Queue the next chunk + self._queue_next_chunk() + + async def end_stream(self) -> None: + """ + Ends the stream by sending the final message to the client. + """ + if self._ended: + raise RuntimeError("The stream has already ended.") + + # Queue final message + self._ended = True + self._queue_next_chunk() + + # Wait for the queue to drain + await self.wait_for_queue() + + def set_attachments(self, attachments: List[Attachment]) -> None: + """ + Sets the attachments to attach to the final chunk. + + Args: + attachments: List of attachments. + """ + self._attachments = attachments + + def set_sensitivity_label(self, sensitivity_label: SensitivityUsageInfo) -> None: + """ + Sets the sensitivity label to attach to the final chunk. + + Args: + sensitivity_label: The sensitivity label. + """ + self._sensitivity_label = sensitivity_label + + def set_citations(self, citations: List[Citation]) -> None: + """ + Sets the citations for the full message. + + Args: + citations: Citations to be included in the message. + """ + if citations: + if not self._citations: + self._citations = [] + + curr_pos = len(self._citations) + + for citation in citations: + client_citation = ClientCitation( + type="Claim", + position=curr_pos + 1, + appearance={ + "type": "DigitalDocument", + "name": citation.title or f"Document #{curr_pos + 1}", + "abstract": CitationUtil.snippet(citation.content, 477), + }, + ) + curr_pos += 1 + self._citations.append(client_citation) + + def set_feedback_loop(self, enable_feedback_loop: bool) -> None: + """ + Sets the Feedback Loop in Teams that allows a user to + give thumbs up or down to a response. + Default is False. + + Args: + enable_feedback_loop: If true, the feedback loop is enabled. + """ + self._enable_feedback_loop = enable_feedback_loop + + def set_feedback_loop_type( + self, feedback_loop_type: Literal["default", "custom"] + ) -> None: + """ + Sets the type of UI to use for the feedback loop. + + Args: + feedback_loop_type: The type of the feedback loop. + """ + self._feedback_loop_type = feedback_loop_type + + def set_generated_by_ai_label(self, enable_generated_by_ai_label: bool) -> None: + """ + Sets the Generated by AI label in Teams. + Default is False. + + Args: + enable_generated_by_ai_label: If true, the label is added. + """ + self._enable_generated_by_ai_label = enable_generated_by_ai_label + + def get_message(self) -> str: + """ + Returns the most recently streamed message. + """ + return self._message + + async def wait_for_queue(self) -> None: + """ + Waits for the outgoing activity queue to be empty. + """ + if self._queue_sync: + await self._queue_sync + + def _set_defaults(self, context: "TurnContext"): + if Channels.ms_teams == context.activity.channel_id.channel: + self._is_streaming_channel = True + self._interval = 1.0 + elif Channels.direct_line == context.activity.channel_id.channel: + self._is_streaming_channel = True + self._interval = 0.5 + elif context.activity.delivery_mode == DeliveryModes.stream: + self._is_streaming_channel = True + self._interval = 0.1 + + self._channel_id = context.activity.channel_id + + def _queue_next_chunk(self) -> None: + """ + Queues the next chunk of text to be sent to the client. + """ + # Are we already waiting to send a chunk? + if self._chunk_queued: + return + + # Queue a chunk of text to be sent + self._chunk_queued = True + + def create_activity(): + self._chunk_queued = False + if self._ended: + # Send final message + activity = Activity( + type="message", + text=self._message or "end stream response", + attachments=self._attachments or [], + entities=[ + Entity( + type="streaminfo", + stream_id=self._stream_id, + stream_type="final", + stream_sequence=self._sequence_number, + ) + ], + ) + elif self._is_streaming_channel: + # Send typing activity + activity = Activity( + type="typing", + text=self._message, + entities=[ + Entity( + type="streaminfo", + stream_type="streaming", + stream_sequence=self._sequence_number, + ) + ], + ) + else: + return + self._sequence_number += 1 + return activity + + self._queue_activity(create_activity) + + def _queue_activity(self, factory: Callable[[], Activity]) -> None: + """ + Queues an activity to be sent to the client. + """ + self._queue.append(factory) + + # If there's no sync in progress, start one + if not self._queue_sync: + self._queue_sync = asyncio.create_task(self._drain_queue()) + + async def _drain_queue(self) -> None: + """ + Sends any queued activities to the client until the queue is empty. + """ + try: + logger.debug(f"Draining queue with {len(self._queue)} activities.") + while self._queue: + factory = self._queue.pop(0) + activity = factory() + if activity: + await self._send_activity(activity) + except Exception as err: + if ( + "403" in str(err) + and self._context.activity.channel_id == Channels.ms_teams + ): + logger.warning("Teams channel stopped the stream.") + self._cancelled = True + else: + logger.error( + f"Error occurred when sending activity while streaming: {err}" + ) + raise + finally: + self._queue_sync = None + + async def _send_activity(self, activity: Activity) -> None: + """ + Sends an activity to the client and saves the stream ID returned. + + Args: + activity: The activity to send. + """ + + streaminfo_entity = None + + if not activity.entities: + streaminfo_entity = Entity(type="streaminfo") + activity.entities = [streaminfo_entity] + else: + for entity in activity.entities: + if hasattr(entity, "type") and entity.type == "streaminfo": + streaminfo_entity = entity + break + + if not streaminfo_entity: + # If no streaminfo entity exists, create one + streaminfo_entity = Entity(type="streaminfo") + activity.entities.append(streaminfo_entity) + + # Set activity ID to the assigned stream ID + if self._stream_id: + activity.id = self._stream_id + streaminfo_entity.stream_id = self._stream_id + + if self._citations and len(self._citations) > 0 and not self._ended: + # Filter out the citations unused in content. + curr_citations = CitationUtil.get_used_citations( + self._message, self._citations + ) + if curr_citations: + activity.entities.append( + Entity( + type="https://schema.org/Message", + schema_type="Message", + context="https://schema.org", + id="", + citation=curr_citations, + ) + ) + + # Add in Powered by AI feature flags + if self._ended: + if self._enable_feedback_loop and self._feedback_loop_type: + # Add feedback loop to streaminfo entity + streaminfo_entity.feedback_loop = {"type": self._feedback_loop_type} + else: + # Add feedback loop enabled to streaminfo entity + streaminfo_entity.feedback_loop_enabled = self._enable_feedback_loop + # Add in Generated by AI + if self._enable_generated_by_ai_label: + activity.add_ai_metadata(self._citations, self._sensitivity_label) + + # Send activity + response = await self._context.send_activity(activity) + await asyncio.sleep(self._interval) + + # Save assigned stream ID + if not self._stream_id and response: + self._stream_id = response.id diff --git a/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/http/__init__.py b/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/http/__init__.py new file mode 100644 index 00000000..84500210 --- /dev/null +++ b/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/http/__init__.py @@ -0,0 +1,17 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""HTTP abstractions for framework-agnostic adapter implementations.""" + +from ._http_request_protocol import HttpRequestProtocol +from ._http_response import HttpResponse, HttpResponseFactory +from ._http_adapter_base import HttpAdapterBase +from ._channel_service_routes import ChannelServiceRoutes + +__all__ = [ + "HttpRequestProtocol", + "HttpResponse", + "HttpResponseFactory", + "HttpAdapterBase", + "ChannelServiceRoutes", +] diff --git a/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/http/_channel_service_routes.py b/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/http/_channel_service_routes.py new file mode 100644 index 00000000..16adf381 --- /dev/null +++ b/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/http/_channel_service_routes.py @@ -0,0 +1,202 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""Channel service route definitions (framework-agnostic logic).""" + +from typing import Type, List, Union + +from microsoft_agents.activity import ( + AgentsModel, + Activity, + AttachmentData, + ConversationParameters, + Transcript, +) +from microsoft_agents.hosting.core import ChannelApiHandlerProtocol + +from ._http_request_protocol import HttpRequestProtocol + + +class ChannelServiceRoutes: + """Defines the Channel Service API routes and their handlers. + + This class provides framework-agnostic route logic that can be + adapted to different web frameworks (aiohttp, FastAPI, etc.). + """ + + def __init__(self, handler: ChannelApiHandlerProtocol, base_url: str = ""): + """Initialize channel service routes. + + Args: + handler: The handler that implements the Channel API protocol. + base_url: Optional base URL prefix for all routes. + """ + self.handler = handler + self.base_url = base_url + + @staticmethod + async def deserialize_from_body( + request: HttpRequestProtocol, target_model: Type[AgentsModel] + ) -> AgentsModel: + """Deserialize request body to target model.""" + content_type = request.headers.get("Content-Type", "") + if "application/json" not in content_type: + raise ValueError("Content-Type must be application/json") + + body = await request.json() + return target_model.model_validate(body) + + @staticmethod + def serialize_model(model_or_list: Union[AgentsModel, List[AgentsModel]]) -> dict: + """Serialize model or list of models to JSON-compatible dict.""" + if isinstance(model_or_list, AgentsModel): + return model_or_list.model_dump( + mode="json", exclude_unset=True, by_alias=True + ) + else: + return [ + model.model_dump(mode="json", exclude_unset=True, by_alias=True) + for model in model_or_list + ] + + # Route handler methods + async def send_to_conversation(self, request: HttpRequestProtocol) -> dict: + """Handle POST /v3/conversations/{conversation_id}/activities.""" + activity = await self.deserialize_from_body(request, Activity) + conversation_id = request.get_path_param("conversation_id") + result = await self.handler.on_send_to_conversation( + request.get_claims_identity(), + conversation_id, + activity, + ) + return self.serialize_model(result) + + async def reply_to_activity(self, request: HttpRequestProtocol) -> dict: + """Handle POST /v3/conversations/{conversation_id}/activities/{activity_id}.""" + activity = await self.deserialize_from_body(request, Activity) + conversation_id = request.get_path_param("conversation_id") + activity_id = request.get_path_param("activity_id") + result = await self.handler.on_reply_to_activity( + request.get_claims_identity(), + conversation_id, + activity_id, + activity, + ) + return self.serialize_model(result) + + async def update_activity(self, request: HttpRequestProtocol) -> dict: + """Handle PUT /v3/conversations/{conversation_id}/activities/{activity_id}.""" + activity = await self.deserialize_from_body(request, Activity) + conversation_id = request.get_path_param("conversation_id") + activity_id = request.get_path_param("activity_id") + result = await self.handler.on_update_activity( + request.get_claims_identity(), + conversation_id, + activity_id, + activity, + ) + return self.serialize_model(result) + + async def delete_activity(self, request: HttpRequestProtocol) -> None: + """Handle DELETE /v3/conversations/{conversation_id}/activities/{activity_id}.""" + conversation_id = request.get_path_param("conversation_id") + activity_id = request.get_path_param("activity_id") + await self.handler.on_delete_activity( + request.get_claims_identity(), + conversation_id, + activity_id, + ) + + async def get_activity_members(self, request: HttpRequestProtocol) -> dict: + """Handle GET /v3/conversations/{conversation_id}/activities/{activity_id}/members.""" + conversation_id = request.get_path_param("conversation_id") + activity_id = request.get_path_param("activity_id") + result = await self.handler.on_get_activity_members( + request.get_claims_identity(), + conversation_id, + activity_id, + ) + return self.serialize_model(result) + + async def create_conversation(self, request: HttpRequestProtocol) -> dict: + """Handle POST /.""" + conversation_parameters = await self.deserialize_from_body( + request, ConversationParameters + ) + result = await self.handler.on_create_conversation( + request.get_claims_identity(), conversation_parameters + ) + return self.serialize_model(result) + + async def get_conversations(self, request: HttpRequestProtocol) -> dict: + """Handle GET /.""" + # TODO: continuation token? conversation_id? + result = await self.handler.on_get_conversations( + request.get_claims_identity(), None + ) + return self.serialize_model(result) + + async def get_conversation_members(self, request: HttpRequestProtocol) -> dict: + """Handle GET /v3/conversations/{conversation_id}/members.""" + conversation_id = request.get_path_param("conversation_id") + result = await self.handler.on_get_conversation_members( + request.get_claims_identity(), + conversation_id, + ) + return self.serialize_model(result) + + async def get_conversation_member(self, request: HttpRequestProtocol) -> dict: + """Handle GET /v3/conversations/{conversation_id}/members/{member_id}.""" + conversation_id = request.get_path_param("conversation_id") + member_id = request.get_path_param("member_id") + result = await self.handler.on_get_conversation_member( + request.get_claims_identity(), + member_id, + conversation_id, + ) + return self.serialize_model(result) + + async def get_conversation_paged_members( + self, request: HttpRequestProtocol + ) -> dict: + """Handle GET /v3/conversations/{conversation_id}/pagedmembers.""" + conversation_id = request.get_path_param("conversation_id") + # TODO: continuation token? page size? + result = await self.handler.on_get_conversation_paged_members( + request.get_claims_identity(), + conversation_id, + ) + return self.serialize_model(result) + + async def delete_conversation_member(self, request: HttpRequestProtocol) -> dict: + """Handle DELETE /v3/conversations/{conversation_id}/members/{member_id}.""" + conversation_id = request.get_path_param("conversation_id") + member_id = request.get_path_param("member_id") + result = await self.handler.on_delete_conversation_member( + request.get_claims_identity(), + conversation_id, + member_id, + ) + return self.serialize_model(result) + + async def send_conversation_history(self, request: HttpRequestProtocol) -> dict: + """Handle POST /v3/conversations/{conversation_id}/activities/history.""" + conversation_id = request.get_path_param("conversation_id") + transcript = await self.deserialize_from_body(request, Transcript) + result = await self.handler.on_send_conversation_history( + request.get_claims_identity(), + conversation_id, + transcript, + ) + return self.serialize_model(result) + + async def upload_attachment(self, request: HttpRequestProtocol) -> dict: + """Handle POST /v3/conversations/{conversation_id}/attachments.""" + conversation_id = request.get_path_param("conversation_id") + attachment_data = await self.deserialize_from_body(request, AttachmentData) + result = await self.handler.on_upload_attachment( + request.get_claims_identity(), + conversation_id, + attachment_data, + ) + return self.serialize_model(result) diff --git a/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/http/_http_adapter_base.py b/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/http/_http_adapter_base.py new file mode 100644 index 00000000..55e2df56 --- /dev/null +++ b/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/http/_http_adapter_base.py @@ -0,0 +1,136 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""Base HTTP adapter with shared processing logic.""" + +from abc import ABC +from traceback import format_exc + +from microsoft_agents.activity import Activity, DeliveryModes +from microsoft_agents.hosting.core.authorization import ClaimsIdentity, Connections +from microsoft_agents.hosting.core import ( + Agent, + ChannelServiceAdapter, + ChannelServiceClientFactoryBase, + MessageFactory, + RestChannelServiceClientFactory, + TurnContext, +) + +from ._http_request_protocol import HttpRequestProtocol +from ._http_response import HttpResponse, HttpResponseFactory + + +class HttpAdapterBase(ChannelServiceAdapter, ABC): + """Base adapter for HTTP-based agent hosting with shared processing logic. + + This class contains all the common logic for processing HTTP requests + and can be subclassed by framework-specific adapters (aiohttp, FastAPI, etc). + """ + + def __init__( + self, + *, + connection_manager: Connections = None, + channel_service_client_factory: ChannelServiceClientFactoryBase = None, + ): + """Initialize the HTTP adapter. + + Args: + connection_manager: Optional connection manager for OAuth. + channel_service_client_factory: Factory for creating channel service clients. + """ + + async def on_turn_error(context: TurnContext, error: Exception): + error_message = f"Exception caught : {error}" + print(format_exc()) + + await context.send_activity(MessageFactory.text(error_message)) + + # Send a trace activity + await context.send_trace_activity( + "OnTurnError Trace", + error_message, + "https://www.botframework.com/schemas/error", + "TurnError", + ) + + self.on_turn_error = on_turn_error + + channel_service_client_factory = ( + channel_service_client_factory + or RestChannelServiceClientFactory(connection_manager) + ) + + super().__init__(channel_service_client_factory) + + async def process_request( + self, request: HttpRequestProtocol, agent: Agent + ) -> HttpResponse: + """Process an incoming HTTP request. + + Args: + request: The HTTP request to process. + agent: The agent to handle the request. + + Returns: + HttpResponse with the result. + + Raises: + TypeError: If request or agent is None. + """ + if not request: + raise TypeError("HttpAdapterBase.process_request: request can't be None") + if not agent: + raise TypeError("HttpAdapterBase.process_request: agent can't be None") + + if request.method != "POST": + return HttpResponseFactory.method_not_allowed() + + # Deserialize the incoming Activity + content_type = request.headers.get("Content-Type", "") + if "application/json" not in content_type: + return HttpResponseFactory.unsupported_media_type() + + try: + body = await request.json() + except Exception: + return HttpResponseFactory.bad_request("Invalid JSON") + + activity: Activity = Activity.model_validate(body) + + # Get claims identity (default to anonymous if not set by middleware) + claims_identity: ClaimsIdentity = ( + request.get_claims_identity() or ClaimsIdentity({}, False) + ) + + # Validate required activity fields + if ( + not activity.type + or not activity.conversation + or not activity.conversation.id + ): + return HttpResponseFactory.bad_request( + "Activity must have type and conversation.id" + ) + + try: + # Process the inbound activity with the agent + invoke_response = await self.process_activity( + claims_identity, activity, agent.on_turn + ) + + # Check if we need to return a synchronous response + if ( + activity.type == "invoke" + or activity.delivery_mode == DeliveryModes.expect_replies + ): + # Invoke and ExpectReplies cannot be performed async + return HttpResponseFactory.json( + invoke_response.body, invoke_response.status + ) + + return HttpResponseFactory.accepted() + + except PermissionError: + return HttpResponseFactory.unauthorized() diff --git a/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/http/_http_request_protocol.py b/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/http/_http_request_protocol.py new file mode 100644 index 00000000..f99dc1d8 --- /dev/null +++ b/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/http/_http_request_protocol.py @@ -0,0 +1,36 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""Protocol for abstracting HTTP request objects across frameworks.""" + +from typing import Protocol, Dict, Any, Optional + + +class HttpRequestProtocol(Protocol): + """Protocol for HTTP requests that adapters must implement. + + This protocol defines the interface that framework-specific request + adapters must implement to work with the shared HTTP adapter logic. + """ + + @property + def method(self) -> str: + """HTTP method (GET, POST, etc.).""" + ... + + @property + def headers(self) -> Dict[str, str]: + """Request headers.""" + ... + + async def json(self) -> Dict[str, Any]: + """Parse request body as JSON.""" + ... + + def get_claims_identity(self) -> Optional[Any]: + """Get claims identity attached by auth middleware.""" + ... + + def get_path_param(self, name: str) -> str: + """Get path parameter by name.""" + ... diff --git a/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/http/_http_response.py b/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/http/_http_response.py new file mode 100644 index 00000000..d593cdee --- /dev/null +++ b/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/http/_http_response.py @@ -0,0 +1,56 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""HTTP response abstraction.""" + +from typing import Any, Optional, Dict +from dataclasses import dataclass + + +@dataclass +class HttpResponse: + """Framework-agnostic HTTP response.""" + + status_code: int + body: Optional[Any] = None + headers: Optional[Dict[str, str]] = None + content_type: Optional[str] = "application/json" + + +class HttpResponseFactory: + """Factory for creating HTTP responses.""" + + @staticmethod + def ok(body: Any = None) -> HttpResponse: + """Create 200 OK response.""" + return HttpResponse(status_code=200, body=body) + + @staticmethod + def accepted() -> HttpResponse: + """Create 202 Accepted response.""" + return HttpResponse(status_code=202) + + @staticmethod + def json(body: Any, status_code: int = 200) -> HttpResponse: + """Create JSON response.""" + return HttpResponse(status_code=status_code, body=body) + + @staticmethod + def bad_request(message: str = "Bad Request") -> HttpResponse: + """Create 400 Bad Request response.""" + return HttpResponse(status_code=400, body={"error": message}) + + @staticmethod + def unauthorized(message: str = "Unauthorized") -> HttpResponse: + """Create 401 Unauthorized response.""" + return HttpResponse(status_code=401, body={"error": message}) + + @staticmethod + def method_not_allowed(message: str = "Method Not Allowed") -> HttpResponse: + """Create 405 Method Not Allowed response.""" + return HttpResponse(status_code=405, body={"error": message}) + + @staticmethod + def unsupported_media_type(message: str = "Unsupported Media Type") -> HttpResponse: + """Create 415 Unsupported Media Type response.""" + return HttpResponse(status_code=415, body={"error": message}) diff --git a/libraries/microsoft-agents-hosting-fastapi/microsoft_agents/hosting/fastapi/__init__.py b/libraries/microsoft-agents-hosting-fastapi/microsoft_agents/hosting/fastapi/__init__.py index c3064151..e72ee8d8 100644 --- a/libraries/microsoft-agents-hosting-fastapi/microsoft_agents/hosting/fastapi/__init__.py +++ b/libraries/microsoft-agents-hosting-fastapi/microsoft_agents/hosting/fastapi/__init__.py @@ -5,7 +5,9 @@ from .jwt_authorization_middleware import ( JwtAuthorizationMiddleware, ) -from .app.streaming import ( + +# Import streaming utilities from core for backward compatibility +from microsoft_agents.hosting.core.app.streaming import ( Citation, CitationUtil, StreamingResponse, diff --git a/libraries/microsoft-agents-hosting-fastapi/microsoft_agents/hosting/fastapi/channel_service_route_table.py b/libraries/microsoft-agents-hosting-fastapi/microsoft_agents/hosting/fastapi/channel_service_route_table.py index 2dd009fc..b8a44b5e 100644 --- a/libraries/microsoft-agents-hosting-fastapi/microsoft_agents/hosting/fastapi/channel_service_route_table.py +++ b/libraries/microsoft-agents-hosting-fastapi/microsoft_agents/hosting/fastapi/channel_service_route_table.py @@ -1,64 +1,58 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -import json -from typing import List, Union, Type -from fastapi import APIRouter, Request, Response, HTTPException, Depends +from fastapi import APIRouter, Request, Response from fastapi.responses import JSONResponse -from microsoft_agents.activity import ( - AgentsModel, - Activity, - AttachmentData, - ConversationParameters, - Transcript, -) from microsoft_agents.hosting.core import ChannelApiHandlerProtocol +from microsoft_agents.hosting.core.http import ChannelServiceRoutes -async def deserialize_from_body( - request: Request, target_model: Type[AgentsModel] -) -> AgentsModel: - content_type = request.headers.get("Content-Type", "") - if "application/json" in content_type: - body = await request.json() - else: - raise HTTPException(status_code=415, detail="Unsupported Media Type") +class FastApiRequestAdapter: + """Adapter for FastAPI requests to use with ChannelServiceRoutes.""" - return target_model.model_validate(body) + def __init__(self, request: Request): + self._request = request + @property + def method(self) -> str: + return self._request.method -def get_serialized_response( - model_or_list: Union[AgentsModel, List[AgentsModel]], -) -> JSONResponse: - if isinstance(model_or_list, AgentsModel): - json_obj = model_or_list.model_dump( - mode="json", exclude_unset=True, by_alias=True - ) - else: - json_obj = [ - model.model_dump(mode="json", exclude_unset=True, by_alias=True) - for model in model_or_list - ] + @property + def headers(self): + return self._request.headers + + async def json(self): + return await self._request.json() + + def get_claims_identity(self): + return getattr(self._request.state, "claims_identity", None) - return JSONResponse(content=json_obj) + def get_path_param(self, name: str) -> str: + return self._request.path_params.get(name, "") def channel_service_route_table( handler: ChannelApiHandlerProtocol, base_url: str = "" ) -> APIRouter: + """Create FastAPI router for Channel Service API. + + Args: + handler: The handler that implements the Channel API protocol. + base_url: Optional base URL prefix for all routes. + + Returns: + APIRouter with all channel service routes. + """ router = APIRouter() + service_routes = ChannelServiceRoutes(handler, base_url) @router.post(base_url + "/v3/conversations/{conversation_id}/activities") async def send_to_conversation(conversation_id: str, request: Request): - activity = await deserialize_from_body(request, Activity) - result = await handler.on_send_to_conversation( - getattr(request.state, "claims_identity", None), - conversation_id, - activity, + result = await service_routes.send_to_conversation( + FastApiRequestAdapter(request) ) - - return get_serialized_response(result) + return JSONResponse(content=result) @router.post( base_url + "/v3/conversations/{conversation_id}/activities/{activity_id}" @@ -66,40 +60,21 @@ async def send_to_conversation(conversation_id: str, request: Request): async def reply_to_activity( conversation_id: str, activity_id: str, request: Request ): - activity = await deserialize_from_body(request, Activity) - result = await handler.on_reply_to_activity( - getattr(request.state, "claims_identity", None), - conversation_id, - activity_id, - activity, - ) - - return get_serialized_response(result) + result = await service_routes.reply_to_activity(FastApiRequestAdapter(request)) + return JSONResponse(content=result) @router.put( base_url + "/v3/conversations/{conversation_id}/activities/{activity_id}" ) async def update_activity(conversation_id: str, activity_id: str, request: Request): - activity = await deserialize_from_body(request, Activity) - result = await handler.on_update_activity( - getattr(request.state, "claims_identity", None), - conversation_id, - activity_id, - activity, - ) - - return get_serialized_response(result) + result = await service_routes.update_activity(FastApiRequestAdapter(request)) + return JSONResponse(content=result) @router.delete( base_url + "/v3/conversations/{conversation_id}/activities/{activity_id}" ) async def delete_activity(conversation_id: str, activity_id: str, request: Request): - await handler.on_delete_activity( - getattr(request.state, "claims_identity", None), - conversation_id, - activity_id, - ) - + await service_routes.delete_activity(FastApiRequestAdapter(request)) return Response(status_code=200) @router.get( @@ -109,97 +84,65 @@ async def delete_activity(conversation_id: str, activity_id: str, request: Reque async def get_activity_members( conversation_id: str, activity_id: str, request: Request ): - result = await handler.on_get_activity_members( - getattr(request.state, "claims_identity", None), - conversation_id, - activity_id, + result = await service_routes.get_activity_members( + FastApiRequestAdapter(request) ) - - return get_serialized_response(result) + return JSONResponse(content=result) @router.post(base_url + "/") async def create_conversation(request: Request): - conversation_parameters = await deserialize_from_body( - request, ConversationParameters - ) - result = await handler.on_create_conversation( - getattr(request.state, "claims_identity", None), conversation_parameters + result = await service_routes.create_conversation( + FastApiRequestAdapter(request) ) - - return get_serialized_response(result) + return JSONResponse(content=result) @router.get(base_url + "/") async def get_conversation(request: Request): - # TODO: continuation token? conversation_id? - result = await handler.on_get_conversations( - getattr(request.state, "claims_identity", None), None - ) - - return get_serialized_response(result) + result = await service_routes.get_conversations(FastApiRequestAdapter(request)) + return JSONResponse(content=result) @router.get(base_url + "/v3/conversations/{conversation_id}/members") async def get_conversation_members(conversation_id: str, request: Request): - result = await handler.on_get_conversation_members( - getattr(request.state, "claims_identity", None), - conversation_id, + result = await service_routes.get_conversation_members( + FastApiRequestAdapter(request) ) - - return get_serialized_response(result) + return JSONResponse(content=result) @router.get(base_url + "/v3/conversations/{conversation_id}/members/{member_id}") async def get_conversation_member( conversation_id: str, member_id: str, request: Request ): - result = await handler.on_get_conversation_member( - getattr(request.state, "claims_identity", None), - member_id, - conversation_id, + result = await service_routes.get_conversation_member( + FastApiRequestAdapter(request) ) - - return get_serialized_response(result) + return JSONResponse(content=result) @router.get(base_url + "/v3/conversations/{conversation_id}/pagedmembers") async def get_conversation_paged_members(conversation_id: str, request: Request): - # TODO: continuation token? page size? - result = await handler.on_get_conversation_paged_members( - getattr(request.state, "claims_identity", None), - conversation_id, + result = await service_routes.get_conversation_paged_members( + FastApiRequestAdapter(request) ) - - return get_serialized_response(result) + return JSONResponse(content=result) @router.delete(base_url + "/v3/conversations/{conversation_id}/members/{member_id}") async def delete_conversation_member( conversation_id: str, member_id: str, request: Request ): - result = await handler.on_delete_conversation_member( - getattr(request.state, "claims_identity", None), - conversation_id, - member_id, + result = await service_routes.delete_conversation_member( + FastApiRequestAdapter(request) ) - - return get_serialized_response(result) + return JSONResponse(content=result) @router.post(base_url + "/v3/conversations/{conversation_id}/activities/history") async def send_conversation_history(conversation_id: str, request: Request): - transcript = await deserialize_from_body(request, Transcript) - result = await handler.on_send_conversation_history( - getattr(request.state, "claims_identity", None), - conversation_id, - transcript, + result = await service_routes.send_conversation_history( + FastApiRequestAdapter(request) ) - - return get_serialized_response(result) + return JSONResponse(content=result) @router.post(base_url + "/v3/conversations/{conversation_id}/attachments") async def upload_attachment(conversation_id: str, request: Request): - attachment_data = await deserialize_from_body(request, AttachmentData) - result = await handler.on_upload_attachment( - getattr(request.state, "claims_identity", None), - conversation_id, - attachment_data, - ) - - return get_serialized_response(result) + result = await service_routes.upload_attachment(FastApiRequestAdapter(request)) + return JSONResponse(content=result) return router diff --git a/libraries/microsoft-agents-hosting-fastapi/microsoft_agents/hosting/fastapi/cloud_adapter.py b/libraries/microsoft-agents-hosting-fastapi/microsoft_agents/hosting/fastapi/cloud_adapter.py index 1a8f912a..a94f81df 100644 --- a/libraries/microsoft-agents-hosting-fastapi/microsoft_agents/hosting/fastapi/cloud_adapter.py +++ b/libraries/microsoft-agents-hosting-fastapi/microsoft_agents/hosting/fastapi/cloud_adapter.py @@ -1,32 +1,48 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -from traceback import format_exc from typing import Optional -from fastapi import Request, Response, HTTPException +from fastapi import Request, Response from fastapi.responses import JSONResponse -from microsoft_agents.hosting.core import error_resources -from microsoft_agents.hosting.core.authorization import ( - ClaimsIdentity, - Connections, -) -from microsoft_agents.activity import ( - Activity, - DeliveryModes, -) -from microsoft_agents.hosting.core import ( - Agent, - ChannelServiceAdapter, - ChannelServiceClientFactoryBase, - MessageFactory, - RestChannelServiceClientFactory, - TurnContext, + +from microsoft_agents.hosting.core import Agent +from microsoft_agents.hosting.core.authorization import Connections +from microsoft_agents.hosting.core.http import ( + HttpAdapterBase, + HttpResponse, ) +from microsoft_agents.hosting.core import ChannelServiceClientFactoryBase from .agent_http_adapter import AgentHttpAdapter -class CloudAdapter(ChannelServiceAdapter, AgentHttpAdapter): +class FastApiRequestAdapter: + """Adapter to make FastAPI Request compatible with HttpRequestProtocol.""" + + def __init__(self, request: Request): + self._request = request + + @property + def method(self) -> str: + return self._request.method + + @property + def headers(self): + return self._request.headers + + async def json(self): + return await self._request.json() + + def get_claims_identity(self): + return getattr(self._request.state, "claims_identity", None) + + def get_path_param(self, name: str) -> str: + return self._request.path_params.get(name, "") + + +class CloudAdapter(HttpAdapterBase, AgentHttpAdapter): + """CloudAdapter for FastAPI web framework.""" + def __init__( self, *, @@ -36,78 +52,43 @@ def __init__( """ Initializes a new instance of the CloudAdapter class. + :param connection_manager: Optional connection manager for OAuth. :param channel_service_client_factory: The factory to use to create the channel service client. """ + super().__init__( + connection_manager=connection_manager, + channel_service_client_factory=channel_service_client_factory, + ) - async def on_turn_error(context: TurnContext, error: Exception): - error_message = f"Exception caught : {error}" - print(format_exc()) + async def process(self, request: Request, agent: Agent) -> Optional[Response]: + """Process a FastAPI request. - await context.send_activity(MessageFactory.text(error_message)) + Args: + request: The FastAPI request. + agent: The agent to handle the request. - # Send a trace activity - await context.send_trace_activity( - "OnTurnError Trace", - error_message, - "https://www.botframework.com/schemas/error", - "TurnError", + Returns: + FastAPI Response object. + """ + # Adapt request to protocol + adapted_request = FastApiRequestAdapter(request) + + # Process using base implementation + http_response: HttpResponse = await self.process_request(adapted_request, agent) + + # Convert HttpResponse to FastAPI Response + return self._to_fastapi_response(http_response) + + @staticmethod + def _to_fastapi_response(http_response: HttpResponse) -> Response: + """Convert HttpResponse to FastAPI Response.""" + if http_response.body is not None: + return JSONResponse( + content=http_response.body, + status_code=http_response.status_code, + headers=http_response.headers, ) - - self.on_turn_error = on_turn_error - - channel_service_client_factory = ( - channel_service_client_factory - or RestChannelServiceClientFactory(connection_manager) + return Response( + status_code=http_response.status_code, + headers=http_response.headers, ) - - super().__init__(channel_service_client_factory) - - async def process(self, request: Request, agent: Agent) -> Optional[Response]: - if not request: - raise TypeError(str(error_resources.RequestRequired)) - if not agent: - raise TypeError(str(error_resources.AgentRequired)) - - if request.method == "POST": - # Deserialize the incoming Activity - content_type = request.headers.get("Content-Type", "") - if "application/json" in content_type: - body = await request.json() - else: - raise HTTPException(status_code=415, detail="Unsupported Media Type") - - activity: Activity = Activity.model_validate(body) - - # default to anonymous identity with no claims - claims_identity: ClaimsIdentity = getattr( - request.state, "claims_identity", ClaimsIdentity({}, False) - ) - - # A POST request must contain an Activity - if ( - not activity.type - or not activity.conversation - or not activity.conversation.id - ): - raise HTTPException(status_code=400, detail="Bad Request") - - try: - # Process the inbound activity with the agent - invoke_response = await self.process_activity( - claims_identity, activity, agent.on_turn - ) - - if ( - activity.type == "invoke" - or activity.delivery_mode == DeliveryModes.expect_replies - ): - # Invoke and ExpectReplies cannot be performed async, the response must be written before the calling thread is released. - return JSONResponse( - content=invoke_response.body, status_code=invoke_response.status - ) - - return Response(status_code=202) - except PermissionError: - raise HTTPException(status_code=401, detail="Unauthorized") - else: - raise HTTPException(status_code=405, detail="Method Not Allowed") diff --git a/test_samples/fastapi/authorization_agent.py b/test_samples/fastapi/authorization_agent.py index b2265893..81c8bf1c 100644 --- a/test_samples/fastapi/authorization_agent.py +++ b/test_samples/fastapi/authorization_agent.py @@ -8,7 +8,7 @@ import uvicorn from dotenv import load_dotenv -from fastapi import FastAPI, Request, Depends +from fastapi import FastAPI, Request from microsoft_agents.hosting.core import ( Authorization, AgentApplication,