diff --git a/README.md b/README.md index 67f202d..17542ef 100644 --- a/README.md +++ b/README.md @@ -208,10 +208,10 @@ class MyAuthBackend(BaseAuth): # Your authentication logic token = credential.credentials user = verify_token(token) - + if not user: raise HTTPException(401, "Invalid token") - + return { "user_id": user.id, "username": user.username, diff --git a/TODOD.txt b/TODOD.txt index 999356a..2e3a35b 100644 --- a/TODOD.txt +++ b/TODOD.txt @@ -3,3 +3,18 @@ So we can create a new api called fix_broken_graph # TODO: 2. And setup api to register frontend tools +Unique Features to highlight: +3. Single command to run with api, production ready api (Fastapi), async first api design, +uvicorn based with proper logger, health checks, swagger/redocs docs ready, +prometheus metrics ready, using best practices. Using env for control +5. Single command to generate docker image, can be deployable +anywhere, no vendor locked-in, no platform cost, deploy where you +want +7. Easy to connect Authentication with the platform settings, JWT by default, +Can be extend with any auth provider, just provide the class path we will setup. +8. Control Over generated ID, UUID are 128 bit, but if you can control and use +smaller ids, you can save a lot of space in DBs and indexes. +9. All the class like state, message, tool calls are Pydantic models, easily +json serializable, easy to debug, log and store. +10. Sentry Integration ready, just provide the DSN in the settings, all the exceptions +will be sent to sentry with proper context. 
\ No newline at end of file diff --git a/agentflow_cli/src/app/core/config/sentry_config.py b/agentflow_cli/src/app/core/config/sentry_config.py index fe93181..e653c43 100644 --- a/agentflow_cli/src/app/core/config/sentry_config.py +++ b/agentflow_cli/src/app/core/config/sentry_config.py @@ -1,16 +1,8 @@ -from typing import TYPE_CHECKING - from fastapi import Depends from agentflow_cli.src.app.core import Settings, get_settings, logger -if TYPE_CHECKING: # pragma: no cover - only for type hints - import sentry_sdk # noqa: F401 - from sentry_sdk.integrations.fastapi import FastApiIntegration # noqa: F401 - from sentry_sdk.integrations.starlette import StarletteIntegration # noqa: F401 - - def init_sentry(settings: Settings = Depends(get_settings)) -> None: """Initialize Sentry for error tracking and performance monitoring. @@ -18,6 +10,24 @@ def init_sentry(settings: Settings = Depends(get_settings)) -> None: unexpected error occurs, the application continues to run and a warning is logged instead of failing hard. """ + environment = settings.MODE.upper() if settings.MODE else "" + + if not settings.SENTRY_DSN: + logger.warning( + "Sentry is not configured. Sentry DSN is not set or running in local environment." + ) + return + + allowed_environments = ["PRODUCTION", "STAGING", "DEVELOPMENT"] + if environment not in allowed_environments: + logger.warning( + f"Sentry is not configured for this environment: {environment}. 
" + "Allowed environments are: {allowed_environments}" + ) + return + + logger.info(f"Sentry is configured for environment: {environment}") + try: import sentry_sdk from sentry_sdk.integrations.fastapi import FastApiIntegration diff --git a/agentflow_cli/src/app/core/config/settings.py b/agentflow_cli/src/app/core/config/settings.py index 10a79d1..1478de4 100644 --- a/agentflow_cli/src/app/core/config/settings.py +++ b/agentflow_cli/src/app/core/config/settings.py @@ -57,9 +57,9 @@ class Settings(BaseSettings): ################################# ###### Paths #################### ################################# - ROOT_PATH: str = "" - DOCS_PATH: str = "" - REDOCS_PATH: str = "" + ROOT_PATH: str = "/" + DOCS_PATH: str = "/docs" + REDOCS_PATH: str = "/redocs" ################################# ###### REDIS Config ########## diff --git a/agentflow_cli/src/app/core/config/setup_middleware.py b/agentflow_cli/src/app/core/config/setup_middleware.py index a61f3ba..57ebd03 100644 --- a/agentflow_cli/src/app/core/config/setup_middleware.py +++ b/agentflow_cli/src/app/core/config/setup_middleware.py @@ -8,6 +8,7 @@ from starlette.middleware.cors import CORSMiddleware from starlette.requests import Request +from .sentry_config import init_sentry from .settings import get_settings, logger @@ -92,3 +93,6 @@ def setup_middleware(app: FastAPI): # Note: If you need streaming responses, you should not use GZipMiddleware. 
app.add_middleware(GZipMiddleware, minimum_size=1000) logger.debug("Middleware set up") + + # Initialize Sentry + init_sentry(settings) diff --git a/agentflow_cli/src/app/core/exceptions/handle_errors.py b/agentflow_cli/src/app/core/exceptions/handle_errors.py index 65b84f2..1cecc49 100644 --- a/agentflow_cli/src/app/core/exceptions/handle_errors.py +++ b/agentflow_cli/src/app/core/exceptions/handle_errors.py @@ -1,3 +1,16 @@ +# Handle all exceptions of agentflow here +from agentflow.exceptions import ( + GraphError, + GraphRecursionError, + MetricsError, + NodeError, + ResourceNotFoundError, + SchemaVersionError, + SerializationError, + StorageError, + TransientStorageError, +) +from agentflow.utils.validators import ValidationError from fastapi import FastAPI from fastapi.exceptions import RequestValidationError from starlette.exceptions import HTTPException @@ -7,16 +20,13 @@ from agentflow_cli.src.app.utils import error_response from agentflow_cli.src.app.utils.schemas import ErrorSchemas -from .resources_exceptions import ResourceNotFoundError +from .resources_exceptions import ResourceNotFoundError as APIResourceNotFoundError from .user_exception import ( UserAccountError, UserPermissionError, ) -# Handle all exceptions of agentflow here - - def init_errors_handler(app: FastAPI): """ Initialize error handlers for the FastAPI application. 
@@ -88,8 +98,8 @@ async def user_write_exception_handler(request: Request, exc: UserPermissionErro status_code=exc.status_code, ) - @app.exception_handler(ResourceNotFoundError) - async def resource_not_found_exception_handler(request: Request, exc: ResourceNotFoundError): + @app.exception_handler(APIResourceNotFoundError) + async def resource_not_found_exception_handler(request: Request, exc: APIResourceNotFoundError): logger.error(f"ResourceNotFoundError: url: {request.base_url}", exc_info=exc) return error_response( request, @@ -97,3 +107,117 @@ async def resource_not_found_exception_handler(request: Request, exc: ResourceNo message=exc.message, status_code=exc.status_code, ) + + ## Need to handle agentflow specific exceptions here + @app.exception_handler(ValidationError) + async def agentflow_validation_exception_handler(request: Request, exc: ValidationError): + logger.error(f"AgentFlow ValidationError: url: {request.base_url}", exc_info=exc) + return error_response( + request, + error_code="AGENTFLOW_VALIDATION_ERROR", + message=str(exc), + status_code=422, + ) + + @app.exception_handler(GraphError) + async def graph_error_exception_handler(request: Request, exc: GraphError): + logger.error(f"GraphError: url: {request.base_url}", exc_info=exc) + return error_response( + request, + error_code=getattr(exc, "error_code", "GRAPH_000"), + message=getattr(exc, "message", str(exc)), + details=getattr(exc, "context", None), + status_code=500, + ) + + @app.exception_handler(NodeError) + async def node_error_exception_handler(request: Request, exc: NodeError): + logger.error(f"NodeError: url: {request.base_url}", exc_info=exc) + return error_response( + request, + error_code=getattr(exc, "error_code", "NODE_000"), + message=getattr(exc, "message", str(exc)), + details=getattr(exc, "context", None), + status_code=500, + ) + + @app.exception_handler(GraphRecursionError) + async def graph_recursion_error_exception_handler(request: Request, exc: GraphRecursionError): + 
logger.error(f"GraphRecursionError: url: {request.base_url}", exc_info=exc) + return error_response( + request, + error_code=getattr(exc, "error_code", "GRAPH_RECURSION_000"), + message=getattr(exc, "message", str(exc)), + details=getattr(exc, "context", None), + status_code=500, + ) + + @app.exception_handler(MetricsError) + async def metrics_error_exception_handler(request: Request, exc: MetricsError): + logger.error(f"MetricsError: url: {request.base_url}", exc_info=exc) + return error_response( + request, + error_code=getattr(exc, "error_code", "METRICS_000"), + message=getattr(exc, "message", str(exc)), + details=getattr(exc, "context", None), + status_code=500, + ) + + @app.exception_handler(SchemaVersionError) + async def schema_version_error_exception_handler(request: Request, exc: SchemaVersionError): + logger.error(f"SchemaVersionError: url: {request.base_url}", exc_info=exc) + return error_response( + request, + error_code=getattr(exc, "error_code", "SCHEMA_VERSION_000"), + message=getattr(exc, "message", str(exc)), + details=getattr(exc, "context", None), + status_code=422, + ) + + @app.exception_handler(SerializationError) + async def serialization_error_exception_handler(request: Request, exc: SerializationError): + logger.error(f"SerializationError: url: {request.base_url}", exc_info=exc) + return error_response( + request, + error_code=getattr(exc, "error_code", "SERIALIZATION_000"), + message=getattr(exc, "message", str(exc)), + details=getattr(exc, "context", None), + status_code=500, + ) + + @app.exception_handler(StorageError) + async def storage_error_exception_handler(request: Request, exc: StorageError): + logger.error(f"StorageError: url: {request.base_url}", exc_info=exc) + return error_response( + request, + error_code=getattr(exc, "error_code", "STORAGE_000"), + message=getattr(exc, "message", str(exc)), + details=getattr(exc, "context", None), + status_code=500, + ) + + @app.exception_handler(TransientStorageError) + async def 
transient_storage_error_exception_handler( + request: Request, exc: TransientStorageError + ): + logger.error(f"TransientStorageError: url: {request.base_url}", exc_info=exc) + return error_response( + request, + error_code=getattr(exc, "error_code", "TRANSIENT_STORAGE_000"), + message=getattr(exc, "message", str(exc)), + details=getattr(exc, "context", None), + status_code=503, + ) + + @app.exception_handler(ResourceNotFoundError) + async def resource_not_found_storage_exception_handler( + request: Request, exc: ResourceNotFoundError + ): + logger.error(f"ResourceNotFoundError: url: {request.base_url}", exc_info=exc) + return error_response( + request, + error_code=getattr(exc, "error_code", "RESOURCE_NOT_FOUND_000"), + message=getattr(exc, "message", str(exc)), + details=getattr(exc, "context", None), + status_code=404, + ) diff --git a/agentflow_cli/src/app/main.py b/agentflow_cli/src/app/main.py index 7f3fa30..83a72cb 100644 --- a/agentflow_cli/src/app/main.py +++ b/agentflow_cli/src/app/main.py @@ -7,6 +7,8 @@ from injectq import InjectQ from injectq.integrations.fastapi import setup_fastapi +# # Prometheus Instrumentator import +# from prometheus_fastapi_instrumentator import Instrumentator # from tortoise import Tortoise from agentflow_cli.src.app.core import ( get_settings, @@ -81,3 +83,6 @@ async def lifespan(app: FastAPI): # init routes init_routes(app) + +# instrumentator = Instrumentator().instrument(app) # Instrument first +# instrumentator.expose(app) # Then expose diff --git a/agentflow_cli/src/app/routers/graph/router.py b/agentflow_cli/src/app/routers/graph/router.py index 532600c..147de5a 100644 --- a/agentflow_cli/src/app/routers/graph/router.py +++ b/agentflow_cli/src/app/routers/graph/router.py @@ -1,17 +1,19 @@ from typing import Any -from fastapi import APIRouter, BackgroundTasks, Depends, Request +from agentflow.state import StreamChunk +from fastapi import APIRouter, Depends, Request from fastapi.logger import logger from fastapi.responses 
import StreamingResponse from injectq.integrations import InjectAPI from agentflow_cli.src.app.core.auth.auth_backend import verify_current_user from agentflow_cli.src.app.routers.graph.schemas.graph_schemas import ( + FixGraphRequestSchema, GraphInputSchema, GraphInvokeOutputSchema, GraphSchema, + GraphSetupSchema, GraphStopSchema, - GraphStreamChunkSchema, ) from agentflow_cli.src.app.routers.graph.services.graph_service import GraphService from agentflow_cli.src.app.utils import success_response @@ -33,7 +35,6 @@ async def invoke_graph( request: Request, graph_input: GraphInputSchema, - background_tasks: BackgroundTasks, service: GraphService = InjectAPI(GraphService), user: dict[str, Any] = Depends(verify_current_user), ): @@ -46,7 +47,6 @@ async def invoke_graph( result: GraphInvokeOutputSchema = await service.invoke_graph( graph_input, user, - background_tasks, ) logger.info("Graph invoke completed successfully") @@ -61,13 +61,11 @@ async def invoke_graph( "/v1/graph/stream", summary="Stream graph execution", description="Execute the graph with streaming output for real-time results", - responses=generate_swagger_responses(GraphStreamChunkSchema), + responses=generate_swagger_responses(StreamChunk), openapi_extra={}, ) async def stream_graph( - request: Request, graph_input: GraphInputSchema, - background_tasks: BackgroundTasks, service: GraphService = InjectAPI(GraphService), user: dict[str, Any] = Depends(verify_current_user), ): @@ -79,7 +77,6 @@ async def stream_graph( result = service.stream_graph( graph_input, user, - background_tasks, ) return StreamingResponse( @@ -180,3 +177,96 @@ async def stop_graph( result, request, ) + + +@router.post( + "/v1/graph/setup", + summary="Setup Remote Tool to the Graph Execution", + description="Register remote tools with the graph nodes for use during execution", + responses=generate_swagger_responses(dict), # type: ignore + openapi_extra={}, +) +async def setup_graph( + request: Request, + setup_request: 
GraphSetupSchema, + service: GraphService = InjectAPI(GraphService), + user: dict[str, Any] = Depends(verify_current_user), +): + """ + Setup the graph execution for a specific thread. + + Args: + setup_request: Request containing thread_id and optional config + + Returns: + Status information about the setup operation + """ + logger.info("Graph setup request received") + logger.debug(f"User info: {user}") + + result = await service.setup(setup_request) + + logger.info("Graph setup completed") + + return success_response( + result, + request, + ) + + +@router.post( + "/v1/graph/fix", + summary="Fix graph state by removing messages with empty tool calls", + description=( + "Fix the graph state by identifying and removing messages that have tool " + "calls with empty content. This is useful for cleaning up incomplete " + "tool call messages that may have failed or been interrupted." + ), + responses=generate_swagger_responses(dict), # type: ignore + openapi_extra={}, +) +async def fix_graph( + request: Request, + fix_request: FixGraphRequestSchema, + service: GraphService = InjectAPI(GraphService), + user: dict[str, Any] = Depends(verify_current_user), +): + """ + Fix the graph execution state for a specific thread. + + This endpoint removes messages with empty tool call content from the state. + Tool calls with empty content typically indicate interrupted or failed tool + executions that should be cleaned up. 
+ + Args: + request: HTTP request object + fix_request: Request containing thread_id and optional config + service: Injected GraphService instance + user: Current authenticated user + + Returns: + Status information about the fix operation, including: + - success: Whether the operation was successful + - message: Descriptive message about the operation + - removed_count: Number of messages that were removed + - state: Updated state after fixing (or original if no changes) + + Raises: + HTTPException: If the fix operation fails or if no state is found + for the given thread_id + """ + logger.info(f"Graph fix request received for thread: {fix_request.thread_id}") + logger.debug(f"User info: {user}") + + result = await service.fix_graph( + fix_request.thread_id, + user, + fix_request.config, + ) + + logger.info(f"Graph fix completed for thread {fix_request.thread_id}") + + return success_response( + result, + request, + ) diff --git a/agentflow_cli/src/app/routers/graph/schemas/graph_schemas.py b/agentflow_cli/src/app/routers/graph/schemas/graph_schemas.py index f659bd2..1d994f8 100644 --- a/agentflow_cli/src/app/routers/graph/schemas/graph_schemas.py +++ b/agentflow_cli/src/app/routers/graph/schemas/graph_schemas.py @@ -5,20 +5,12 @@ from pydantic import BaseModel, Field -class MessageSchema(BaseModel): - message_id: int | None = Field(None, description="Unique identifier for the message") - role: str = Field( - default="user", description="Role of the message sender (user, assistant, etc.)" - ) - content: str = Field(..., description="Content of the message") - - class GraphInputSchema(BaseModel): """ Schema for graph input including messages and configuration. 
""" - messages: list[MessageSchema] = Field( + messages: list[Message] = Field( ..., description="List of messages to process through the graph" ) initial_state: dict[str, Any] | None = Field( @@ -66,13 +58,13 @@ class GraphInvokeOutputSchema(BaseModel): ) -class GraphStreamChunkSchema(BaseModel): - """ - Schema for individual stream chunks from graph execution. - """ +# class GraphStreamChunkSchema(BaseModel): +# """ +# Schema for individual stream chunks from graph execution. +# """ - data: dict[str, Any] = Field(..., description="Chunk data") - metadata: dict[str, Any] | None = Field(default=None, description="Chunk metadata") +# data: dict[str, Any] = Field(..., description="Chunk data") +# metadata: dict[str, Any] | None = Field(default=None, description="Chunk metadata") class NodeSchema(BaseModel): @@ -123,3 +115,42 @@ class GraphStopSchema(BaseModel): config: dict[str, Any] | None = Field( default=None, description="Optional configuration for the stop operation" ) + + +class RemoteToolSchema(BaseModel): + """Schema for remote tool execution.""" + + node_name: str = Field(..., description="Name of the node representing the tool") + name: str = Field(..., description="Name of the tool to execute") + description: str = Field(..., description="Description of the tool") + parameters: dict[str, Any] = Field(..., description="Parameters for the tool") + + +class GraphSetupSchema(BaseModel): + """Schema for setting up graph execution.""" + + tools: list[RemoteToolSchema] = Field( + ..., description="List of remote tools available for the graph" + ) + + +class FixGraphRequestSchema(BaseModel): + """Schema for fixing graph state by removing messages with empty tool call content.""" + + thread_id: str = Field(..., description="Thread ID to fix the graph state for") + config: dict[str, Any] | None = Field( + default=None, description="Optional configuration for the fix operation" + ) + + +class FixGraphResponseSchema(BaseModel): + """Schema for the fix graph operation 
response.""" + + success: bool = Field(..., description="Whether the fix operation was successful") + message: str = Field(..., description="Status message from the fix operation") + removed_count: int = Field( + default=0, description="Number of messages with empty tool calls that were removed" + ) + state: dict[str, Any] | None = Field( + default=None, description="Updated state after fixing the graph" + ) diff --git a/agentflow_cli/src/app/routers/graph/services/graph_service.py b/agentflow_cli/src/app/routers/graph/services/graph_service.py index 8e34a31..acf922c 100644 --- a/agentflow_cli/src/app/routers/graph/services/graph_service.py +++ b/agentflow_cli/src/app/routers/graph/services/graph_service.py @@ -1,3 +1,4 @@ +from collections import defaultdict from collections.abc import AsyncIterable from typing import Any from uuid import uuid4 @@ -6,7 +7,7 @@ from agentflow.graph import CompiledGraph from agentflow.state import AgentState, Message, StreamChunk, StreamEvent from agentflow.utils.thread_info import ThreadInfo -from fastapi import BackgroundTasks, HTTPException +from fastapi import HTTPException from injectq import InjectQ, inject, singleton from pydantic import BaseModel from starlette.responses import Content @@ -17,7 +18,7 @@ GraphInputSchema, GraphInvokeOutputSchema, GraphSchema, - MessageSchema, + GraphSetupSchema, ) from agentflow_cli.src.app.utils import DummyThreadNameGenerator, ThreadNameGenerator @@ -85,36 +86,6 @@ async def _save_thread(self, config: dict[str, Any], thread_id: int): ThreadInfo(thread_id=thread_id), ) - def _convert_messages(self, messages: list[MessageSchema]) -> list[Message]: - """ - Convert dictionary messages to PyAgenity Message objects. 
- - Args: - messages: List of dictionary messages - - Returns: - List of PyAgenity Message objects - """ - converted_messages = [] - allowed_roles = {"user", "assistant", "tool"} - for msg in messages: - if msg.role == "system": - raise Exception("System role is not allowed for safety reasons") - - if msg.role not in allowed_roles: - logger.warning(f"Invalid role '{msg.role}' in message, defaulting to 'user'") - - # Cast role to the expected Literal type for type checking - # System role are not allowed for safety reasons - # Fixme: Fix message id - converted_msg = Message.text_message( - content=msg.content, - message_id=msg.message_id, # type: ignore - ) - converted_messages.append(converted_msg) - - return converted_messages - def _extract_context_info( self, raw_state, result: dict[str, Any] ) -> tuple[list[Message] | None, str | None]: @@ -209,9 +180,7 @@ async def _prepare_input( # Prepare the input for the graph input_data: dict = { - "messages": self._convert_messages( - graph_input.messages, - ), + "messages": graph_input.messages, } if graph_input.initial_state: input_data["state"] = graph_input.initial_state @@ -374,3 +343,116 @@ async def get_state_schema(self) -> dict: except Exception as e: logger.error(f"Failed to get state schema: {e}") raise HTTPException(status_code=500, detail=f"Failed to get state schema: {e!s}") + + async def fix_graph( + self, + thread_id: str, + user: dict[str, Any], + config: dict[str, Any] | None = None, + ) -> dict[str, Any]: + """ + Fix graph state by removing messages with empty tool call content. + + This method retrieves the current state from the checkpointer, identifies messages + with tool calls that have empty content, removes those messages, and updates the + state. 
+ + Args: + thread_id (str): The thread ID to fix the graph state for + user (dict): User information for context + config (dict, optional): Additional configuration for the operation + + Returns: + dict: Result dictionary containing: + - success (bool): Whether the operation was successful + - message (str): Status message + - removed_count (int): Number of messages removed + - state (dict): Updated state after fixing + + Raises: + HTTPException: If the operation fails + """ + + logger.info(f"Starting fix graph operation for thread: {thread_id}") + logger.debug(f"User info: {user}") + + fix_config = { + "thread_id": thread_id, + "user": user, + } + + # Merge additional config if provided + if config: + fix_config.update(config) + + logger.debug("Fetching current state from checkpointer") + state: AgentState | None = await self.checkpointer.aget_state(fix_config) + + if not state: + logger.warning(f"No state found for thread: {thread_id}") + return { + "success": False, + "message": f"No state found for thread: {thread_id}", + "removed_count": 0, + "state": None, + } + + messages: list[Message] = state.context + logger.debug(f"Found {len(messages)} messages in state") + + if not messages: + logger.info("No messages found in state, nothing to fix") + return { + "success": True, + "message": "No messages found in state", + "removed_count": 0, + "state": state.model_dump_json(), + } + + last_message = messages[-1] + updated_context = [] + if last_message.role == "assistant" and last_message.tools_calls: + updated_context = messages[:-1] + state.context = updated_context + await self.checkpointer.aput_state(fix_config, state) + return { + "success": True, + "message": "Removed last assistant message with empty tool calls", + "removed_count": 1, + "state": state.model_dump_json(), + } + else: + logger.warning( + "Last message is not an assistant message with tool calls, skipping it from checks." 
+ ) + + return { + "success": True, + "message": "No messages with empty tool calls found", + "removed_count": 0, + "state": state.model_dump_json(), + } + + async def setup(self, data: GraphSetupSchema) -> dict: + # lets create tools + remote_tools = defaultdict(list) + for tool in data.tools: + remote_tools[tool.node_name].append( + { + "type": "function", + "function": { + "name": tool.name, + "description": tool.description, + "parameters": tool.parameters, + }, + } + ) + + # Now call setup on graph + for node_name, tool in remote_tools.items(): + self._graph.attach_remote_tools(tool, node_name) + + return { + "status": "success", + "details": f"Added tools to nodes: {list(remote_tools.keys())}", + } diff --git a/test.py b/test.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unit_tests/test_fix_graph.py b/tests/unit_tests/test_fix_graph.py new file mode 100644 index 0000000..92ff43b --- /dev/null +++ b/tests/unit_tests/test_fix_graph.py @@ -0,0 +1,237 @@ +"""Unit tests for fix_graph functionality.""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from agentflow.checkpointer import BaseCheckpointer +from agentflow.state import AgentState, Message, TextBlock +from fastapi import HTTPException + +from agentflow_cli.src.app.routers.graph.services.graph_service import GraphService + + +class TestFixGraph: + """Test cases for fix_graph service method.""" + + @pytest.fixture + def mock_checkpointer(self): + """Create a mock checkpointer.""" + checkpointer = MagicMock(spec=BaseCheckpointer) + checkpointer.aget_state = AsyncMock() + checkpointer.aput_state = AsyncMock() + return checkpointer + + @pytest.fixture + def mock_graph(self): + """Create a mock CompiledGraph.""" + graph = MagicMock() + return graph + + @pytest.fixture + def mock_config(self): + """Create a mock GraphConfig.""" + config = MagicMock() + config.thread_name_generator_path = None + return config + + @pytest.fixture + def graph_service(self, mock_graph, 
mock_checkpointer, mock_config): + """Create a GraphService instance with mocked dependencies.""" + service = GraphService.__new__(GraphService) # Skip __init__ + service._graph = mock_graph + service.checkpointer = mock_checkpointer + service.config = mock_config + service.thread_name_generator = None + return service + + def _create_mock_message( + self, + message_id: str, + role: str = "user", + content_text: str = "test", + tool_calls: list | None = None, + ) -> MagicMock: + """Helper to create a mock message.""" + message = MagicMock(spec=Message) + message.message_id = message_id + message.role = role + message.content = [MagicMock(spec=TextBlock)] + message.tools_calls = tool_calls + return message + + def _create_mock_state( + self, + messages: list, + ) -> MagicMock: + """Helper to create a mock state with messages.""" + state = MagicMock(spec=AgentState) + state.context = messages + state.model_dump = MagicMock(return_value={"messages": messages}) + + # Configure the type() of state to have model_validate + state_type = type(state) + state_type.model_validate = MagicMock(side_effect=lambda x: state) + + return state + + @pytest.mark.asyncio + async def test_fix_graph_no_messages_with_empty_tool_calls( + self, graph_service, mock_checkpointer + ): + """Test fix_graph when no messages have empty tool calls.""" + # Create messages without empty tool calls + messages = [ + self._create_mock_message("msg1", tool_calls=None), + self._create_mock_message("msg2", tool_calls=[{"name": "tool1", "content": "ok"}]), + ] + + mock_state = self._create_mock_state(messages) + mock_checkpointer.aget_state.return_value = mock_state + + result = await graph_service.fix_graph("thread1", {"user_id": "user1"}) + + assert result["success"] is True + assert result["removed_count"] == 0 + assert "No messages with empty tool calls found" in result["message"] + mock_checkpointer.aput_state.assert_not_called() + + @pytest.mark.asyncio + async def 
test_fix_graph_removes_messages_with_empty_tool_calls( + self, graph_service, mock_checkpointer + ): + """Test fix_graph removes messages with empty tool call content.""" + # Create messages: one with empty tool call, one normal, one with non-empty tool call + msg1 = self._create_mock_message("msg1", tool_calls=[{"name": "tool1", "content": ""}]) + msg2 = self._create_mock_message("msg2", tool_calls=None) + msg3 = self._create_mock_message( + "msg3", tool_calls=[{"name": "tool2", "content": "result"}] + ) + + original_messages = [msg1, msg2, msg3] + mock_state = self._create_mock_state(original_messages) + mock_checkpointer.aget_state.return_value = mock_state + + # Create updated state for after the fix + updated_state = self._create_mock_state([msg2, msg3]) + mock_checkpointer.aput_state.return_value = updated_state + + result = await graph_service.fix_graph("thread1", {"user_id": "user1"}) + + assert result["success"] is True + assert result["removed_count"] == 1 + assert "Successfully removed 1 message(s)" in result["message"] + mock_checkpointer.aput_state.assert_called_once() + + @pytest.mark.asyncio + async def test_fix_graph_removes_multiple_messages_with_empty_tool_calls( + self, graph_service, mock_checkpointer + ): + """Test fix_graph removes multiple messages with empty tool calls.""" + # Create messages with multiple empty tool calls + msg1 = self._create_mock_message("msg1", tool_calls=[{"name": "tool1", "content": ""}]) + msg2 = self._create_mock_message("msg2", tool_calls=[{"name": "tool2", "content": ""}]) + msg3 = self._create_mock_message("msg3", tool_calls=None) + msg4 = self._create_mock_message( + "msg4", tool_calls=[{"name": "tool3", "content": "result"}] + ) + + original_messages = [msg1, msg2, msg3, msg4] + mock_state = self._create_mock_state(original_messages) + mock_checkpointer.aget_state.return_value = mock_state + + updated_state = self._create_mock_state([msg3, msg4]) + mock_checkpointer.aput_state.return_value = updated_state + + 
result = await graph_service.fix_graph("thread1", {"user_id": "user1"}) + + assert result["success"] is True + assert result["removed_count"] == 2 + assert "Successfully removed 2 message(s)" in result["message"] + mock_checkpointer.aput_state.assert_called_once() + + @pytest.mark.asyncio + async def test_fix_graph_no_state_found(self, graph_service, mock_checkpointer): + """Test fix_graph when no state is found for thread.""" + mock_checkpointer.aget_state.return_value = None + + result = await graph_service.fix_graph("thread1", {"user_id": "user1"}) + + assert result["success"] is False + assert "No state found" in result["message"] + assert result["removed_count"] == 0 + mock_checkpointer.aput_state.assert_not_called() + + @pytest.mark.asyncio + async def test_fix_graph_with_config(self, graph_service, mock_checkpointer): + """Test fix_graph respects additional config.""" + messages = [self._create_mock_message("msg1")] + mock_state = self._create_mock_state(messages) + mock_checkpointer.aget_state.return_value = mock_state + + extra_config = {"custom_key": "custom_value"} + result = await graph_service.fix_graph("thread1", {"user_id": "user1"}, extra_config) + + assert result["success"] is True + + # Verify config was merged correctly + call_args = mock_checkpointer.aget_state.call_args + config_arg = call_args[0][0] # First positional arg + assert config_arg["thread_id"] == "thread1" + assert config_arg["custom_key"] == "custom_value" + + @pytest.mark.asyncio + async def test_fix_graph_exception_handling(self, graph_service, mock_checkpointer): + """Test fix_graph handles exceptions properly.""" + mock_checkpointer.aget_state.side_effect = Exception("Database error") + + with pytest.raises(HTTPException) as exc_info: + await graph_service.fix_graph("thread1", {"user_id": "user1"}) + + assert exc_info.value.status_code == 500 + assert "Fix graph operation failed" in exc_info.value.detail + + @pytest.mark.asyncio + async def 
test_fix_graph_handles_mixed_empty_content(self, graph_service, mock_checkpointer): + """Test fix_graph correctly identifies empty string vs None content.""" + # Mix of empty string, None, and non-empty content + msg1 = self._create_mock_message("msg1", tool_calls=[{"name": "tool1", "content": ""}]) + msg2 = self._create_mock_message("msg2", tool_calls=[{"name": "tool2", "content": None}]) + msg3 = self._create_mock_message("msg3", tool_calls=[{"name": "tool3", "content": "valid"}]) + + original_messages = [msg1, msg2, msg3] + mock_state = self._create_mock_state(original_messages) + mock_checkpointer.aget_state.return_value = mock_state + + updated_state = self._create_mock_state([msg3]) + mock_checkpointer.aput_state.return_value = updated_state + + result = await graph_service.fix_graph("thread1", {"user_id": "user1"}) + + # Should remove both empty string and None content + assert result["removed_count"] == 2 + assert result["success"] is True + + @pytest.mark.asyncio + async def test_fix_graph_preserves_message_order(self, graph_service, mock_checkpointer): + """Test fix_graph preserves the order of remaining messages.""" + msg1 = self._create_mock_message("msg1", tool_calls=None) + msg2 = self._create_mock_message("msg2", tool_calls=[{"name": "tool", "content": ""}]) + msg3 = self._create_mock_message("msg3", tool_calls=None) + msg4 = self._create_mock_message("msg4", tool_calls=[{"name": "tool", "content": ""}]) + msg5 = self._create_mock_message("msg5", tool_calls=None) + + original_messages = [msg1, msg2, msg3, msg4, msg5] + mock_state = self._create_mock_state(original_messages) + mock_checkpointer.aget_state.return_value = mock_state + + updated_state = self._create_mock_state([msg1, msg3, msg5]) + mock_checkpointer.aput_state.return_value = updated_state + + result = await graph_service.fix_graph("thread1", {"user_id": "user1"}) + + assert result["success"] is True + assert result["removed_count"] == 2 + + # Verify the correct messages are kept by 
checking the call to aput_state + call_args = mock_checkpointer.aput_state.call_args + updated_state_arg = call_args[0][1] # Second positional arg