diff --git a/README.md b/README.md index 17542ef..a3f962a 100644 --- a/README.md +++ b/README.md @@ -268,6 +268,107 @@ name = generator.generate_name() ``` See the [Thread Name Generator Guide](./docs/thread-name-generator.md) for custom implementations. + +## Security + +AgentFlow CLI provides enterprise-grade security features for production deployments. + +### Security Features + +- ✅ **Authentication** - Built-in JWT and custom authentication backends +- ✅ **Authorization** - Resource-based access control with extensible backends +- ✅ **Request Limits** - DoS protection with configurable size limits (default 10MB) +- ✅ **Error Sanitization** - Production-safe error messages preventing information disclosure +- ✅ **Log Sanitization** - Automatic redaction of sensitive data (tokens, passwords, secrets) +- ✅ **Security Warnings** - Startup validation for insecure configurations +- ✅ **HTTPS Ready** - SSL/TLS support with secure headers + +### Production Security Checklist + +Before deploying to production, ensure these settings are in place: + +```bash +# Required: Set production mode +MODE=production + +# Required: Strong JWT secret (32+ characters) +JWT_SECRET_KEY=replace-with-a-random-secret-of-at-least-32-characters + +# Required: Disable debug mode +IS_DEBUG=false + +# Required: Specific CORS origins (not *) +ORIGINS=https://yourdomain.com + +# Required: Specific allowed hosts (not *) +ALLOWED_HOST=yourdomain.com + +# Recommended: Disable API docs +DOCS_PATH= +REDOCS_PATH= + +# Recommended: Configure request size limit +MAX_REQUEST_SIZE=10485760 # 10MB default +``` + +### Quick Security Setup + +**1. Enable JWT Authentication:** +```json +{ + "auth": "jwt" +} +``` + +**2.
Implement Authorization:** +```python +# auth/rbac_backend.py +from agentflow_cli.src.app.core.auth.authorization import AuthorizationBackend + +class RBACAuthorizationBackend(AuthorizationBackend): + async def authorize(self, user, resource, action, resource_id=None, **context): + role = user.get("role", "viewer") + # Implement your authorization logic + return role == "admin" or (role == "developer" and action == "read") +``` + +**3. Configure in agentflow.json:** +```json +{ + "auth": "jwt", + "authorization": { + "path": "auth.rbac_backend:RBACAuthorizationBackend" + } +} +``` + +### Security Validation + +AgentFlow automatically validates your configuration and warns about security issues: + +``` +⚠️ SECURITY WARNING: CORS ORIGINS='*' in production. + Set ORIGINS to specific domains. + +⚠️ SECURITY WARNING: DEBUG mode enabled in production! + Set IS_DEBUG=false +``` + +### Comprehensive Security Guide + +For detailed security documentation, threat model, best practices, and deployment guidelines, see: + +📖 **[SECURITY.md](./SECURITY.md)** - Complete Security Guide + +Topics covered: +- Threat model and attack vectors +- Authentication and authorization patterns +- Production deployment checklist +- Docker and Kubernetes security configurations +- Security testing and monitoring +- Incident response procedures +- Vulnerability reporting + ## Deployment See the [Deployment Guide](./docs/deployment.md) for complete deployment instructions. diff --git a/Task.md b/Task.md deleted file mode 100644 index de964cb..0000000 --- a/Task.md +++ /dev/null @@ -1,31 +0,0 @@ -Fix this ... - -See the class should not be like this, -api checking is not checking in sequence so we not able to capture the bugs -It should be invoke then using checkpointer api, we need to get the data - -Lets execute api in below sequence, if any api fails then it should crash the script - -# Test Graph APIs -1. /v1/ping/ -2. /v1/graph -3. /v1/graph/StateSchema - -# Test Graph Run APIs -1. 
/v1/graph/invoke -2. /v1/graph/stream - -# Now checkpointer APIs -Note: using v1/graph/invoke will share thread_id, so we can use that thread_id to test checkpointer apis -1. /v1/threads/{thread_id}/state - - -# Thinking blocks not converted to reasoning blocks - -"thinking_blocks": [ - { - "type": "thinking", - "thinking": "{\"text\": \"Hello! How can I help you today?\"}", - "signature": "CpwCAdHtim9umxTi9N+7hzmLhJnA1tIWY59EIk7d6FiZeBb/Faqtq7w7GxIqIeQQ08pNPtUOYDf5Vtl9FCc/dGP9a+QHmq2xoygtMEHY1e6tTDExoOeyDTWoL6/jruOoTTyUHxr62D2sD5xn/zmKmj7EGl5qDT5cJJRhPt208GvTchpA38QcazDAWIDzrkmqQEh+zdXv9HhUOM57yXs1/PDAPZiF20lVdEnGibqfsUa640o2tDVCxnd5xbciPdxEx6wrVhXVm0bnKybgXNPw+xory715t93vL0gY6h1MS8GGJbyVNO+xRwUD5yxCSG4HNyGdT9Axhfv8w8SNfG4IetJFegn2Oz8Us22PYm1bcH+7w/5yAJ2To4RHWO7TkeQ=" - } - ] \ No newline at end of file diff --git a/agentflow.json b/agentflow.json index 82df8be..1604556 100644 --- a/agentflow.json +++ b/agentflow.json @@ -2,5 +2,8 @@ "agent": "graph.react:app", "thread_name_generator": "graph.thread_name_generator:MyNameGenerator", "env": ".env", - "auth": null + "auth": { + "path": "graph.custom_auth:CustomAuth", + "method": "custom" + } } diff --git a/agentflow_cli/src/app/core/auth/authorization.py b/agentflow_cli/src/app/core/auth/authorization.py new file mode 100644 index 0000000..702bb3b --- /dev/null +++ b/agentflow_cli/src/app/core/auth/authorization.py @@ -0,0 +1,90 @@ +""" +Authorization backend system for AgentFlow CLI. + +This module provides an authorization interface that developers can implement +to add resource-level access control to their AgentFlow applications. +""" + +from abc import ABC, abstractmethod +from typing import Any + + +class AuthorizationBackend(ABC): + """ + Abstract base class for authorization backends. + + Developers should implement this class to define custom authorization logic + for their AgentFlow applications. The authorize method is called before + any resource operation to check if the user has permission. 
+ + Example: + class MyAuthorizationBackend(AuthorizationBackend): + async def authorize(self, user, resource, action, resource_id=None, **context): + # Check if user has permission + if user.get("role") == "admin": + return True + # Add custom logic here + return False + """ + + @abstractmethod + async def authorize( + self, + user: dict[str, Any], + resource: str, + action: str, + resource_id: str | None = None, + **context: Any, + ) -> bool: + """ + Check if user can perform action on resource. + + Args: + user: User information dictionary containing at least 'user_id' + resource: Resource type (e.g., 'graph', 'checkpointer', 'store') + action: Action to perform (e.g., 'invoke', 'stream', 'read', 'write', 'delete') + resource_id: Optional specific resource identifier (e.g., thread_id, namespace) + **context: Additional context for authorization decision + + Returns: + bool: True if authorized, False otherwise + + Raises: + Exception: Can raise exceptions for auth failures or errors + """ + + +class DefaultAuthorizationBackend(AuthorizationBackend): + """ + Default authorization backend that allows all authenticated users. + + This implementation performs basic authentication check (user has user_id) + but allows all operations. Use this as a starting point or for development. + + For production use, implement a custom AuthorizationBackend with proper + access control logic based on your application's requirements. + """ + + async def authorize( + self, + user: dict[str, Any], + resource: str, + action: str, + resource_id: str | None = None, + **context: Any, + ) -> bool: + """ + Allow all authenticated users to perform any action. 
+ + Args: + user: User information dictionary + resource: Resource type (not used in default implementation) + action: Action to perform (not used in default implementation) + resource_id: Optional resource identifier (not used in default implementation) + **context: Additional context (not used in default implementation) + + Returns: + bool: True if user has 'user_id', False otherwise + """ + # Only check if user is authenticated (has user_id) + return bool(user.get("user_id")) diff --git a/agentflow_cli/src/app/core/auth/permissions.py b/agentflow_cli/src/app/core/auth/permissions.py new file mode 100644 index 0000000..63a70ca --- /dev/null +++ b/agentflow_cli/src/app/core/auth/permissions.py @@ -0,0 +1,150 @@ +""" +Unified authentication and authorization dependency for FastAPI endpoints. + +This module provides a reusable dependency that combines authentication and +authorization checks, reducing code duplication across routers. +""" + +from collections.abc import Callable +from typing import Any + +from fastapi import Depends, HTTPException, Request, Response +from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer +from injectq.integrations import InjectAPI + +from agentflow_cli.src.app.core import logger +from agentflow_cli.src.app.core.auth.auth_backend import BaseAuth +from agentflow_cli.src.app.core.auth.authorization import AuthorizationBackend +from agentflow_cli.src.app.core.config.graph_config import GraphConfig +from agentflow_cli.src.app.core.utils.log_sanitizer import sanitize_for_logging + + +class RequirePermission: + """ + FastAPI dependency that combines authentication and authorization. + + This class-based dependency verifies user authentication and checks + authorization in a single step, reducing boilerplate code in endpoints. 
+ + Usage: + @router.post("/v1/resource") + async def endpoint( + user: dict = Depends(RequirePermission("resource", "action")) + ): + # User is authenticated and authorized + pass + + Args: + resource: Resource type being accessed (e.g., "graph", "checkpointer", "store") + action: Action being performed (e.g., "invoke", "read", "write", "delete") + extract_resource_id: Optional function to extract resource_id from request + """ + + def __init__( + self, + resource: str, + action: str, + extract_resource_id: Callable[[Request], str | None] | None = None, + ): + """ + Initialize the permission requirement. + + Args: + resource: Resource type (graph, checkpointer, store) + action: Action type (invoke, stream, read, write, delete, etc.) + extract_resource_id: Optional callable that extracts resource_id from request + """ + self.resource = resource + self.action = action + self.extract_resource_id_fn = extract_resource_id + + async def __call__( + self, + request: Request, + response: Response, + credential: HTTPAuthorizationCredentials = Depends( + HTTPBearer(auto_error=False), + ), + config: GraphConfig = InjectAPI(GraphConfig), + auth_backend: BaseAuth = InjectAPI(BaseAuth), + authz: AuthorizationBackend = InjectAPI(AuthorizationBackend), + ) -> dict[str, Any]: + """ + Verify authentication and authorization. 
+ + Returns: + dict: User information if authenticated and authorized + + Raises: + HTTPException: 403 if authorization fails + """ + # Step 1: Authentication (reusing verify_current_user logic) + user = {} + backend = config.auth_config() + if not backend: + user = {} + elif not auth_backend: + logger.error("Auth backend is not configured") + user = {} + else: + user_result = auth_backend.authenticate( + request, + response, + credential, + ) + if user_result and "user_id" not in user_result: + logger.error("Authentication failed: 'user_id' not found in user info") + user = user_result or {} + + # Step 2: Extract resource_id if available + resource_id = None + if self.extract_resource_id_fn: + resource_id = self.extract_resource_id_fn(request) + else: + resource_id = self._extract_resource_id_from_path(request) + + # Step 3: Authorization + if not await authz.authorize( + user, + self.resource, + self.action, + resource_id=resource_id, + ): + logger.warning( + f"Authorization failed for user {user.get('user_id')} " + f"on {self.resource}:{self.action}" + ) + raise HTTPException( + status_code=403, + detail=f"Not authorized to {self.action} {self.resource}", + ) + + # Log successful auth/authz (with sanitized user info) + logger.debug( + f"Auth/Authz success for {self.resource}:{self.action}, " + f"user: {sanitize_for_logging(user)}" + ) + + return user + + def _extract_resource_id_from_path(self, request: Request) -> str | None: + """ + Extract resource ID from request path parameters. + + Looks for common patterns like thread_id, memory_id in path params. 
+ + Args: + request: FastAPI request object + + Returns: + Resource ID as string, or None if not found + """ + # Check path parameters + path_params = request.path_params + + # Common resource ID patterns + for param_name in ["thread_id", "memory_id", "namespace"]: + if param_name in path_params: + return str(path_params[param_name]) + + return None diff --git a/agentflow_cli/src/app/core/config/graph_config.py b/agentflow_cli/src/app/core/config/graph_config.py index 2e0f49c..a4d0e6d 100644 --- a/agentflow_cli/src/app/core/config/graph_config.py +++ b/agentflow_cli/src/app/core/config/graph_config.py @@ -43,6 +43,17 @@ def redis_url(self) -> str | None: def thread_name_generator_path(self) -> str | None: return self.data.get("thread_name_generator", None) + @property + def authorization_path(self) -> str | None: + """ + Get the authorization backend path from configuration. + + Returns: + str | None: Path to authorization backend module in format 'module:attribute', + or None if not configured + """ + return self.data.get("authorization", None) + def auth_config(self) -> dict | None: res = self.data.get("auth", None) if not res: diff --git a/agentflow_cli/src/app/core/config/settings.py b/agentflow_cli/src/app/core/config/settings.py index 1478de4..a278460 100644 --- a/agentflow_cli/src/app/core/config/settings.py +++ b/agentflow_cli/src/app/core/config/settings.py @@ -2,6 +2,7 @@ import os from functools import lru_cache +from pydantic import field_validator, model_validator from pydantic_settings import BaseSettings @@ -46,6 +47,26 @@ class Settings(BaseSettings): LOG_LEVEL: str = "INFO" IS_DEBUG: bool = True + ################################# + ###### Request Limits ########### + ################################# + MAX_REQUEST_SIZE: int = 10 * 1024 * 1024 # 10MB default + + ################################# + ###### Security Headers ######### + ################################# + SECURITY_HEADERS_ENABLED: bool = True + HSTS_ENABLED: bool = True + HSTS_MAX_AGE: 
int = 31536000 # 1 year in seconds + HSTS_INCLUDE_SUBDOMAINS: bool = True + HSTS_PRELOAD: bool = False + FRAME_OPTIONS: str = "DENY" # DENY, SAMEORIGIN, or ALLOW-FROM + CONTENT_TYPE_OPTIONS: str = "nosniff" + XSS_PROTECTION: str = "1; mode=block" + REFERRER_POLICY: str = "strict-origin-when-cross-origin" + PERMISSIONS_POLICY: str | None = None # Uses default if None + CSP_POLICY: str | None = None # Uses default if None + SUMMARY: str = "Pyagenity Backend" ################################# @@ -81,6 +102,53 @@ class Settings(BaseSettings): SNOWFLAKE_NODE_BITS: int = 5 SNOWFLAKE_WORKER_BITS: int = 8 + @field_validator("MODE", mode="before") + @classmethod + def normalize_mode(cls, v: str | None) -> str: + """Normalize MODE to lowercase.""" + return v.lower() if v else "development" + + @field_validator("ORIGINS") + @classmethod + def warn_cors_wildcard(cls, v: str) -> str: + """Warn if CORS is set to wildcard in production.""" + mode = os.environ.get("MODE", "development").lower() + if v == "*" and mode == "production": + logger.warning( + "⚠️ SECURITY WARNING: CORS ORIGINS='*' in production.\n" + " This allows any website to make requests to your API.\n" + " Set ORIGINS to specific domains (e.g., https://yourdomain.com)" + ) + return v + + @model_validator(mode="after") + def check_production_security(self): + """Check for insecure configurations in production mode.""" + if self.MODE == "production": + warnings = [] + + if self.IS_DEBUG: + warnings.append( + "⚠️ DEBUG mode is enabled in production. This may expose sensitive information." + ) + + if self.DOCS_PATH or self.REDOCS_PATH: + warnings.append( + "⚠️ API documentation endpoints are enabled in production. " + "Consider disabling DOCS_PATH and REDOCS_PATH." + ) + + if self.ALLOWED_HOST == "*": + warnings.append( + "⚠️ ALLOWED_HOST='*' in production. " + "Set to specific hostnames for better security." 
+ ) + + if warnings: + logger.warning("\n".join(["\n🔒 PRODUCTION SECURITY WARNINGS:", *warnings])) + + return self + class Config: extra = "allow" diff --git a/agentflow_cli/src/app/core/config/setup_logs.py b/agentflow_cli/src/app/core/config/setup_logs.py index 04f91e6..a14f205 100644 --- a/agentflow_cli/src/app/core/config/setup_logs.py +++ b/agentflow_cli/src/app/core/config/setup_logs.py @@ -3,6 +3,8 @@ from fastapi.logger import logger as fastapi_logger +from agentflow_cli.src.app.core.utils.log_sanitizer import SanitizingFormatter + def init_logger(level: int | str = logging.INFO) -> None: """ @@ -59,9 +61,12 @@ def init_logger(level: int | str = logging.INFO) -> None: console_handler.setLevel(level) # Create formatter - formatter = logging.Formatter( + base_formatter = logging.Formatter( "%(asctime)s - %(name)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s" ) + # Wrap with sanitizing formatter to prevent sensitive data in logs + formatter = SanitizingFormatter(base_formatter) + # Add formatter to console handler console_handler.setFormatter(formatter) # Add console handler to logger diff --git a/agentflow_cli/src/app/core/config/setup_middleware.py b/agentflow_cli/src/app/core/config/setup_middleware.py index 2ac6d2b..cf3c03f 100644 --- a/agentflow_cli/src/app/core/config/setup_middleware.py +++ b/agentflow_cli/src/app/core/config/setup_middleware.py @@ -9,6 +9,9 @@ from starlette.requests import Request from starlette.types import ASGIApp, Receive, Scope, Send +from agentflow_cli.src.app.core.middleware.request_limits import RequestSizeLimitMiddleware +from agentflow_cli.src.app.core.middleware.security_headers import SecurityHeadersMiddleware + from .sentry_config import init_sentry from .settings import get_settings, logger @@ -115,6 +118,25 @@ def setup_middleware(app: FastAPI): app.add_middleware(TrustedHostMiddleware, allowed_hosts=settings.ALLOWED_HOST.split(",")) + # Add request size limit middleware (protects against DoS via large 
payloads) + app.add_middleware(RequestSizeLimitMiddleware, max_size=settings.MAX_REQUEST_SIZE) + + # Add security headers middleware (if enabled) + if settings.SECURITY_HEADERS_ENABLED: + app.add_middleware( + SecurityHeadersMiddleware, + enable_hsts=settings.HSTS_ENABLED, + hsts_max_age=settings.HSTS_MAX_AGE, + hsts_include_subdomains=settings.HSTS_INCLUDE_SUBDOMAINS, + hsts_preload=settings.HSTS_PRELOAD, + frame_options=settings.FRAME_OPTIONS, + content_type_options=settings.CONTENT_TYPE_OPTIONS, + xss_protection=settings.XSS_PROTECTION, + referrer_policy=settings.REFERRER_POLICY, + permissions_policy=settings.PERMISSIONS_POLICY, + csp_policy=settings.CSP_POLICY, + ) + app.add_middleware(RequestIDMiddleware) # Use SelectiveGZipMiddleware to exclude streaming endpoints from compression diff --git a/agentflow_cli/src/app/core/exceptions/handle_errors.py b/agentflow_cli/src/app/core/exceptions/handle_errors.py index 8fe65cd..d330b25 100644 --- a/agentflow_cli/src/app/core/exceptions/handle_errors.py +++ b/agentflow_cli/src/app/core/exceptions/handle_errors.py @@ -16,6 +16,7 @@ from starlette.requests import Request from agentflow_cli.src.app.core import logger +from agentflow_cli.src.app.core.config.settings import get_settings from agentflow_cli.src.app.utils import error_response from agentflow_cli.src.app.utils.schemas import ErrorSchemas @@ -26,7 +27,42 @@ ) -def init_errors_handler(app: FastAPI): +def _sanitize_error_message(message: str, error_code: str, is_production: bool) -> str: + """ + Sanitize error messages for production to avoid exposing internal details. + + Args: + message: Original error message + error_code: Error code for the exception + is_production: Whether the app is in production mode + + Returns: + Sanitized message (generic in production, detailed in development) + """ + if not is_production: + return message + + # Generic messages by status code category + generic_messages = { + "VALIDATION_ERROR": "The request data is invalid. 
Please check your input.", + "HTTPException": "An error occurred processing your request.", + "GRAPH_": "An error occurred executing the graph.", + "NODE_": "An error occurred in a graph node.", + "STORAGE_": "An error occurred accessing storage.", + "METRICS_": "An error occurred collecting metrics.", + "SCHEMA_VERSION_": "Schema version mismatch.", + "SERIALIZATION_": "An error occurred processing data.", + } + + # Return generic message based on error code prefix + for prefix, generic_msg in generic_messages.items(): + if error_code.startswith(prefix): + return generic_msg + + return "An unexpected error occurred. Please contact support." + + +def init_errors_handler(app: FastAPI): # noqa: PLR0915 """ Initialize error handlers for the FastAPI application. @@ -41,36 +77,66 @@ def init_errors_handler(app: FastAPI): UserPermissionError: Handles custom user permission errors. APIResourceNotFoundError: Handles custom API resource not found errors. """ + settings = get_settings() + is_production = settings.MODE == "production" @app.exception_handler(HTTPException) async def http_exception_handler(request: Request, exc: HTTPException): logger.error(f"HTTP exception: url: {request.base_url}", exc_info=exc) + + # Get request ID for tracking + request_id = getattr(request.state, "request_id", "unknown") + + message = _sanitize_error_message(str(exc.detail), "HTTPException", is_production) + + # Log full details but return sanitized message in production + if is_production: + logger.error(f"Request {request_id} - HTTPException details: {exc.detail}") + return error_response( request, error_code="HTTPException", - message=str(exc.detail), + message=message, status_code=exc.status_code, ) @app.exception_handler(RequestValidationError) async def validation_exception_handler(request: Request, exc: RequestValidationError): logger.error(f"Value error exception: url: {request.base_url}", exc_info=exc) + + request_id = getattr(request.state, "request_id", "unknown") details = 
[ErrorSchemas(**error) for error in exc.errors()] + + # In production, sanitize validation error details + if is_production: + logger.error(f"Request {request_id} - Validation errors: {details}") + message = "The request data is invalid. Please check your input." + else: + message = str(exc.body) if exc.body else "Validation error" + return error_response( request, error_code="VALIDATION_ERROR", - message=str(exc.body), - details=details, + message=message, + details=details if not is_production else None, status_code=422, ) @app.exception_handler(ValueError) async def value_exception_handler(request: Request, exc: ValueError): logger.error(f"Value error exception: url: {request.base_url}", exc_info=exc) + + request_id = getattr(request.state, "request_id", "unknown") + + message = str(exc) + if is_production: + logger.error(f"Request {request_id} - ValueError details: {message}") + message = "Invalid input provided." + return error_response( request, error_code="VALIDATION_ERROR", - message=str(exc), + message=message, status_code=422, ) @@ -111,99 +177,178 @@ async def resource_not_found_exception_handler(request: Request, exc: APIResourc @app.exception_handler(ValidationError) async def agentflow_validation_exception_handler(request: Request, exc: ValidationError): logger.error(f"AgentFlow ValidationError: url: {request.base_url}", exc_info=exc) + + request_id = getattr(request.state, "request_id", "unknown") + message = _sanitize_error_message(str(exc), "AGENTFLOW_VALIDATION_ERROR", is_production) + + if is_production: + logger.error(f"Request {request_id} - AgentFlow ValidationError: {exc}") + return error_response( request, error_code="AGENTFLOW_VALIDATION_ERROR", - message=str(exc), + message=message, status_code=422, ) @app.exception_handler(GraphError) async def graph_error_exception_handler(request: Request, exc: GraphError): logger.error(f"GraphError: url: {request.base_url}", exc_info=exc) + + request_id = getattr(request.state, "request_id", "unknown") 
+ error_code = getattr(exc, "error_code", "GRAPH_000") + original_message = getattr(exc, "message", str(exc)) + message = _sanitize_error_message(original_message, error_code, is_production) + + if is_production: + logger.error(f"Request {request_id} - GraphError: {original_message}") + return error_response( request, - error_code=getattr(exc, "error_code", "GRAPH_000"), - message=getattr(exc, "message", str(exc)), - details=getattr(exc, "context", None), + error_code=error_code, + message=message, + details=getattr(exc, "context", None) if not is_production else None, status_code=500, ) @app.exception_handler(NodeError) async def node_error_exception_handler(request: Request, exc: NodeError): logger.error(f"NodeError: url: {request.base_url}", exc_info=exc) + + request_id = getattr(request.state, "request_id", "unknown") + error_code = getattr(exc, "error_code", "NODE_000") + original_message = getattr(exc, "message", str(exc)) + message = _sanitize_error_message(original_message, error_code, is_production) + + if is_production: + logger.error(f"Request {request_id} - NodeError: {original_message}") + return error_response( request, - error_code=getattr(exc, "error_code", "NODE_000"), - message=getattr(exc, "message", str(exc)), - details=getattr(exc, "context", None), + error_code=error_code, + message=message, + details=getattr(exc, "context", None) if not is_production else None, status_code=500, ) @app.exception_handler(GraphRecursionError) async def graph_recursion_error_exception_handler(request: Request, exc: GraphRecursionError): logger.error(f"GraphRecursionError: url: {request.base_url}", exc_info=exc) + + request_id = getattr(request.state, "request_id", "unknown") + error_code = getattr(exc, "error_code", "GRAPH_RECURSION_000") + original_message = getattr(exc, "message", str(exc)) + message = _sanitize_error_message(original_message, error_code, is_production) + + if is_production: + logger.error(f"Request {request_id} - GraphRecursionError: 
{original_message}") + return error_response( request, - error_code=getattr(exc, "error_code", "GRAPH_RECURSION_000"), - message=getattr(exc, "message", str(exc)), - details=getattr(exc, "context", None), + error_code=error_code, + message=message, + details=getattr(exc, "context", None) if not is_production else None, status_code=500, ) + @app.exception_handler(StorageError) + async def storage_error_exception_handler(request: Request, exc: StorageError): + logger.error(f"StorageError: url: {request.base_url}", exc_info=exc) + + request_id = getattr(request.state, "request_id", "unknown") + error_code = getattr(exc, "error_code", "STORAGE_000") + original_message = getattr(exc, "message", str(exc)) + message = _sanitize_error_message(original_message, error_code, is_production) + + if is_production: + logger.error(f"Request {request_id} - StorageError: {original_message}") + + return error_response( + request, + error_code=error_code, + message=message, + details=getattr(exc, "context", None) if not is_production else None, + status_code=500, + ) + + @app.exception_handler(TransientStorageError) + async def transient_storage_error_exception_handler( + request: Request, exc: TransientStorageError + ): + logger.error(f"TransientStorageError: url: {request.base_url}", exc_info=exc) + + request_id = getattr(request.state, "request_id", "unknown") + error_code = getattr(exc, "error_code", "TRANSIENT_STORAGE_000") + original_message = getattr(exc, "message", str(exc)) + message = _sanitize_error_message(original_message, error_code, is_production) + + if is_production: + logger.error(f"Request {request_id} - TransientStorageError: {original_message}") + + return error_response( + request, + error_code=error_code, + message=message, + details=getattr(exc, "context", None) if not is_production else None, + status_code=503, + ) + @app.exception_handler(MetricsError) async def metrics_error_exception_handler(request: Request, exc: MetricsError): logger.error(f"MetricsError: 
url: {request.base_url}", exc_info=exc) + + request_id = getattr(request.state, "request_id", "unknown") + error_code = getattr(exc, "error_code", "METRICS_000") + original_message = getattr(exc, "message", str(exc)) + message = _sanitize_error_message(original_message, error_code, is_production) + + if is_production: + logger.error(f"Request {request_id} - MetricsError: {original_message}") + return error_response( request, - error_code=getattr(exc, "error_code", "METRICS_000"), - message=getattr(exc, "message", str(exc)), - details=getattr(exc, "context", None), + error_code=error_code, + message=message, + details=getattr(exc, "context", None) if not is_production else None, status_code=500, ) @app.exception_handler(SchemaVersionError) async def schema_version_error_exception_handler(request: Request, exc: SchemaVersionError): logger.error(f"SchemaVersionError: url: {request.base_url}", exc_info=exc) + + request_id = getattr(request.state, "request_id", "unknown") + error_code = getattr(exc, "error_code", "SCHEMA_VERSION_000") + original_message = getattr(exc, "message", str(exc)) + message = _sanitize_error_message(original_message, error_code, is_production) + + if is_production: + logger.error(f"Request {request_id} - SchemaVersionError: {original_message}") + return error_response( request, - error_code=getattr(exc, "error_code", "SCHEMA_VERSION_000"), - message=getattr(exc, "message", str(exc)), - details=getattr(exc, "context", None), + error_code=error_code, + message=message, + details=getattr(exc, "context", None) if not is_production else None, status_code=422, ) @app.exception_handler(SerializationError) async def serialization_error_exception_handler(request: Request, exc: SerializationError): logger.error(f"SerializationError: url: {request.base_url}", exc_info=exc) - return error_response( - request, - error_code=getattr(exc, "error_code", "SERIALIZATION_000"), - message=getattr(exc, "message", str(exc)), - details=getattr(exc, "context", None), - 
status_code=500, - ) - @app.exception_handler(StorageError) - async def storage_error_exception_handler(request: Request, exc: StorageError): - logger.error(f"StorageError: url: {request.base_url}", exc_info=exc) - return error_response( - request, - error_code=getattr(exc, "error_code", "STORAGE_000"), - message=getattr(exc, "message", str(exc)), - details=getattr(exc, "context", None), - status_code=500, - ) + request_id = getattr(request.state, "request_id", "unknown") + error_code = getattr(exc, "error_code", "SERIALIZATION_000") + original_message = getattr(exc, "message", str(exc)) + message = _sanitize_error_message(original_message, error_code, is_production) + + if is_production: + logger.error(f"Request {request_id} - SerializationError: {original_message}") - @app.exception_handler(TransientStorageError) - async def transient_storage_error_exception_handler( - request: Request, exc: TransientStorageError - ): - logger.error(f"TransientStorageError: url: {request.base_url}", exc_info=exc) return error_response( request, - error_code=getattr(exc, "error_code", "TRANSIENT_STORAGE_000"), - message=getattr(exc, "message", str(exc)), - details=getattr(exc, "context", None), - status_code=503, + error_code=error_code, + message=message, + details=getattr(exc, "context", None) if not is_production else None, + status_code=500, ) diff --git a/agentflow_cli/src/app/core/middleware/__init__.py b/agentflow_cli/src/app/core/middleware/__init__.py new file mode 100644 index 0000000..8491a5b --- /dev/null +++ b/agentflow_cli/src/app/core/middleware/__init__.py @@ -0,0 +1,11 @@ +"""Middleware modules for agentflow-cli.""" + +from .request_limits import RequestSizeLimitMiddleware +from .security_headers import SecurityHeadersMiddleware, create_security_headers_middleware + + +__all__ = [ + "RequestSizeLimitMiddleware", + "SecurityHeadersMiddleware", + "create_security_headers_middleware", +] diff --git a/agentflow_cli/src/app/core/middleware/request_limits.py 
b/agentflow_cli/src/app/core/middleware/request_limits.py new file mode 100644 index 0000000..b5f0826 --- /dev/null +++ b/agentflow_cli/src/app/core/middleware/request_limits.py @@ -0,0 +1,73 @@ +"""Request size limit middleware for DoS protection.""" + +from fastapi import Request, status +from fastapi.responses import JSONResponse +from starlette.middleware.base import BaseHTTPMiddleware + +from agentflow_cli.src.app.core import logger + + +class RequestSizeLimitMiddleware(BaseHTTPMiddleware): + """ + Middleware to enforce maximum request body size limits. + + This prevents DoS attacks through excessively large request bodies. + + Args: + app: The ASGI application + max_size: Maximum request body size in bytes (default: 10MB) + """ + + def __init__(self, app, max_size: int = 10 * 1024 * 1024): + super().__init__(app) + self.max_size = max_size + self.max_size_mb = max_size / (1024 * 1024) + + async def dispatch(self, request: Request, call_next): + """ + Check request size and reject if too large. + + Args: + request: The incoming HTTP request + call_next: The next middleware or route handler + + Returns: + Response: Either the normal response or 413 Payload Too Large + """ + # Get content-length header + content_length = request.headers.get("content-length") + + if content_length: + content_length = int(content_length) + + if content_length > self.max_size: + logger.warning( + f"Request rejected: size {content_length} bytes " + f"exceeds limit of {self.max_size} bytes " + f"({self.max_size_mb:.1f}MB)" + ) + + # Get request ID if available + request_id = getattr(request.state, "request_id", "unknown") + + return JSONResponse( + status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE, + content={ + "error": { + "code": "REQUEST_TOO_LARGE", + "message": ( + f"Request body too large. 
" + f"Maximum size is {self.max_size_mb:.1f}MB" + ), + "max_size_bytes": self.max_size, + "max_size_mb": self.max_size_mb, + }, + "metadata": { + "request_id": request_id, + "status": "error", + }, + }, + ) + + # Process request normally + return await call_next(request) diff --git a/agentflow_cli/src/app/core/middleware/security_headers.py b/agentflow_cli/src/app/core/middleware/security_headers.py new file mode 100644 index 0000000..9cc08b2 --- /dev/null +++ b/agentflow_cli/src/app/core/middleware/security_headers.py @@ -0,0 +1,239 @@ +""" +Security Headers Middleware + +This middleware adds standard security headers to HTTP responses to protect +against common web vulnerabilities. + +Headers Added: +- X-Content-Type-Options: Prevents MIME-type sniffing +- X-Frame-Options: Prevents clickjacking attacks +- X-XSS-Protection: Enables XSS filtering (legacy browsers) +- Strict-Transport-Security: Enforces HTTPS (if HTTPS is detected) +- Content-Security-Policy: Controls resource loading +- Referrer-Policy: Controls referrer information +- Permissions-Policy: Controls browser features + +Configuration: +Configure via environment variables or settings: +- SECURITY_HEADERS_ENABLED: Enable/disable middleware (default: true) +- HSTS_MAX_AGE: HSTS max-age in seconds (default: 31536000 = 1 year) +- CSP_POLICY: Custom CSP policy (default: strict policy) +""" + +from collections.abc import Callable + +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.requests import Request +from starlette.responses import Response + + +class SecurityHeadersMiddleware(BaseHTTPMiddleware): + """ + Middleware to add security headers to responses. + + This middleware enhances security by adding standard security headers + that protect against common web vulnerabilities. 
+ """ + + def __init__( # noqa: PLR0913 + self, + app, + enable_hsts: bool = True, + hsts_max_age: int = 31536000, # 1 year + hsts_include_subdomains: bool = True, + hsts_preload: bool = False, + frame_options: str = "DENY", + content_type_options: str = "nosniff", + xss_protection: str = "1; mode=block", + referrer_policy: str = "strict-origin-when-cross-origin", + permissions_policy: str | None = None, + csp_policy: str | None = None, + ): + """ + Initialize security headers middleware. + + Args: + app: ASGI application + enable_hsts: Enable Strict-Transport-Security header + hsts_max_age: HSTS max-age in seconds (default: 1 year) + hsts_include_subdomains: Include subdomains in HSTS + hsts_preload: Enable HSTS preload + frame_options: X-Frame-Options value (DENY, SAMEORIGIN, or ALLOW-FROM) + content_type_options: X-Content-Type-Options value + xss_protection: X-XSS-Protection value + referrer_policy: Referrer-Policy value + permissions_policy: Permissions-Policy value (optional) + csp_policy: Content-Security-Policy value (optional) + """ + super().__init__(app) + self.enable_hsts = enable_hsts + self.hsts_max_age = hsts_max_age + self.hsts_include_subdomains = hsts_include_subdomains + self.hsts_preload = hsts_preload + self.frame_options = frame_options + self.content_type_options = content_type_options + self.xss_protection = xss_protection + self.referrer_policy = referrer_policy + self.permissions_policy = permissions_policy or self._default_permissions_policy() + self.csp_policy = csp_policy or self._default_csp_policy() + + def _default_permissions_policy(self) -> str: + """ + Get default Permissions-Policy header value. + + Returns: + Default permissions policy string + """ + return ( + "geolocation=(), microphone=(), camera=(), payment=(), usb=(), " + "magnetometer=(), gyroscope=(), accelerometer=()" + ) + + def _default_csp_policy(self) -> str: + """ + Get default Content-Security-Policy header value. 
+ + Returns: + Default CSP policy string + """ + return ( + "default-src 'self'; " + "script-src 'self' 'unsafe-inline'; " + "style-src 'self' 'unsafe-inline'; " + "img-src 'self' data: https:; " + "font-src 'self' data:; " + "connect-src 'self'; " + "frame-ancestors 'none'; " + "base-uri 'self'; " + "form-action 'self'" + ) + + def _build_hsts_header(self) -> str: + """ + Build HSTS header value. + + Returns: + HSTS header string + """ + hsts_parts = [f"max-age={self.hsts_max_age}"] + + if self.hsts_include_subdomains: + hsts_parts.append("includeSubDomains") + + if self.hsts_preload: + hsts_parts.append("preload") + + return "; ".join(hsts_parts) + + def _is_https(self, request: Request) -> bool: + """ + Check if request is using HTTPS. + + Checks both the request scheme and X-Forwarded-Proto header + (for proxied requests). + + Args: + request: Starlette request object + + Returns: + True if HTTPS, False otherwise + """ + # Check request scheme + if request.url.scheme == "https": + return True + + # Check X-Forwarded-Proto header (for proxied requests) + forwarded_proto = request.headers.get("X-Forwarded-Proto", "") + return forwarded_proto.lower() == "https" + + async def dispatch(self, request: Request, call_next: Callable) -> Response: + """ + Process request and add security headers to response. 
+ + Args: + request: Incoming request + call_next: Next middleware/handler + + Returns: + Response with security headers + """ + # Process request + response = await call_next(request) + + # Add X-Content-Type-Options header + response.headers["X-Content-Type-Options"] = self.content_type_options + + # Add X-Frame-Options header + response.headers["X-Frame-Options"] = self.frame_options + + # Add X-XSS-Protection header (for legacy browser support) + response.headers["X-XSS-Protection"] = self.xss_protection + + # Add Referrer-Policy header + response.headers["Referrer-Policy"] = self.referrer_policy + + # Add Permissions-Policy header + if self.permissions_policy: + response.headers["Permissions-Policy"] = self.permissions_policy + + # Add Content-Security-Policy header + if self.csp_policy: + response.headers["Content-Security-Policy"] = self.csp_policy + + # Add Strict-Transport-Security header (only for HTTPS) + if self.enable_hsts and self._is_https(request): + response.headers["Strict-Transport-Security"] = self._build_hsts_header() + + return response + + +def create_security_headers_middleware( + enable_hsts: bool = True, + hsts_max_age: int = 31536000, + hsts_include_subdomains: bool = True, + hsts_preload: bool = False, + frame_options: str = "DENY", + content_type_options: str = "nosniff", + xss_protection: str = "1; mode=block", + referrer_policy: str = "strict-origin-when-cross-origin", + permissions_policy: str | None = None, + csp_policy: str | None = None, +) -> type[SecurityHeadersMiddleware]: + """ + Factory function to create SecurityHeadersMiddleware with configuration. + + This is a convenience function for creating middleware with custom settings. 
+ + Args: + enable_hsts: Enable HSTS header + hsts_max_age: HSTS max-age in seconds + hsts_include_subdomains: Include subdomains in HSTS + hsts_preload: Enable HSTS preload + frame_options: X-Frame-Options value + content_type_options: X-Content-Type-Options value + xss_protection: X-XSS-Protection value + referrer_policy: Referrer-Policy value + permissions_policy: Permissions-Policy value + csp_policy: Content-Security-Policy value + + Returns: + Configured SecurityHeadersMiddleware class + """ + + class ConfiguredSecurityHeadersMiddleware(SecurityHeadersMiddleware): + def __init__(self, app): + super().__init__( + app, + enable_hsts=enable_hsts, + hsts_max_age=hsts_max_age, + hsts_include_subdomains=hsts_include_subdomains, + hsts_preload=hsts_preload, + frame_options=frame_options, + content_type_options=content_type_options, + xss_protection=xss_protection, + referrer_policy=referrer_policy, + permissions_policy=permissions_policy, + csp_policy=csp_policy, + ) + + return ConfiguredSecurityHeadersMiddleware diff --git a/agentflow_cli/src/app/core/utils/__init__.py b/agentflow_cli/src/app/core/utils/__init__.py new file mode 100644 index 0000000..81a8fbb --- /dev/null +++ b/agentflow_cli/src/app/core/utils/__init__.py @@ -0,0 +1,14 @@ +"""Core utilities for AgentFlow CLI.""" + +from agentflow_cli.src.app.core.utils.log_sanitizer import ( + SanitizingFormatter, + sanitize_for_logging, + sanitize_log_message, +) + + +__all__ = [ + "sanitize_for_logging", + "sanitize_log_message", + "SanitizingFormatter", +] diff --git a/agentflow_cli/src/app/core/utils/log_sanitizer.py b/agentflow_cli/src/app/core/utils/log_sanitizer.py new file mode 100644 index 0000000..f85773f --- /dev/null +++ b/agentflow_cli/src/app/core/utils/log_sanitizer.py @@ -0,0 +1,195 @@ +""" +Log sanitization utilities for AgentFlow CLI. + +This module provides utilities to sanitize sensitive data before logging, +preventing tokens, passwords, and other credentials from appearing in logs. 
+""" + +import re +from typing import Any + + +# Patterns for detecting sensitive field names +SENSITIVE_PATTERNS = { + "token", + "password", + "secret", + "key", + "credential", + "authorization", + "api_key", + "access_token", + "refresh_token", + "auth", + "bearer", + "jwt", + "session", + "cookie", + "private_key", + "passphrase", + "pin", + "ssn", + "credit_card", +} + +# Regex pattern to detect JWT tokens (three base64url parts separated by dots) +JWT_PATTERN = re.compile(r"^[A-Za-z0-9-_=]+\.[A-Za-z0-9-_=]+\.?[A-Za-z0-9-_.+/=]*$") + +# Regex pattern to detect bearer tokens +BEARER_PATTERN = re.compile(r"^Bearer\s+[A-Za-z0-9\-_.~+/]+=*$", re.IGNORECASE) + + +def sanitize_for_logging(data: Any, max_depth: int = 10, _current_depth: int = 0) -> Any: + """ + Recursively sanitize sensitive data for safe logging. + + This function walks through data structures and replaces sensitive values + with redaction markers. It detects: + - Dictionary keys containing sensitive patterns + - JWT tokens (by structure) + - Bearer tokens + - Authorization headers + + Args: + data: The data to sanitize (dict, list, str, or any other type) + max_depth: Maximum recursion depth to prevent infinite loops + _current_depth: Internal parameter tracking current recursion depth + + Returns: + Sanitized copy of the data with sensitive values redacted + + Examples: + >>> sanitize_for_logging({"user_id": "123", "token": "abc123"}) + {'user_id': '123', 'token': '***REDACTED***'} + + >>> sanitize_for_logging({"Authorization": "Bearer eyJhbGc..."}) + {'Authorization': '***REDACTED***'} + """ + # Prevent excessive recursion + if _current_depth >= max_depth: + return "***MAX_DEPTH_REACHED***" + + if isinstance(data, dict): + return {k: _sanitize_value(k, v, max_depth, _current_depth + 1) for k, v in data.items()} + if isinstance(data, list): + return [sanitize_for_logging(item, max_depth, _current_depth + 1) for item in data] + if isinstance(data, tuple): + return tuple(sanitize_for_logging(item, 
max_depth, _current_depth + 1) for item in data) + if isinstance(data, str): + return _sanitize_string(data) + return data + + +def _sanitize_value(key: str, value: Any, max_depth: int, current_depth: int) -> Any: + """ + Sanitize a value based on its key name and content. + + Args: + key: The dictionary key + value: The value to sanitize + max_depth: Maximum recursion depth + current_depth: Current recursion depth + + Returns: + Sanitized value + """ + # Check if key name contains sensitive patterns + key_lower = key.lower() + if any(pattern in key_lower for pattern in SENSITIVE_PATTERNS): + return "***REDACTED***" + + # Recursively sanitize the value + return sanitize_for_logging(value, max_depth, current_depth) + + +def _sanitize_string(value: str) -> str: + """ + Sanitize a string value if it appears to be sensitive. + + Args: + value: String to check and potentially sanitize + + Returns: + Original string or redaction marker + """ + max_value = 32 + # Check for JWT token pattern + if len(value) > max_value and JWT_PATTERN.match(value): + return "***JWT_TOKEN***" + + # Check for Bearer token pattern + if BEARER_PATTERN.match(value): + return "***BEARER_TOKEN***" + + # Check if string looks like a long random token (>32 chars, alphanumeric) + if len(value) > max_value and value.replace("-", "").replace("_", "").isalnum(): + # Could be an API key or token + return f"{value[:4]}...{value[-4:]}" + + return value + + +def sanitize_log_message(message: str, *args: Any, **kwargs: Any) -> tuple[str, tuple, dict]: + """ + Sanitize log message arguments. + + This is useful for sanitizing arguments passed to logger.debug(), logger.info(), etc. + + Args: + message: Log message format string + *args: Positional arguments for the log message + **kwargs: Keyword arguments for the log message + + Returns: + Tuple of (message, sanitized_args, sanitized_kwargs) + + Examples: + >>> msg, args, kwargs = sanitize_log_message( + ... 
"User %s logged in", {"user_id": "123", "token": "secret"} + ... ) + >>> # args will have sanitized data + """ + sanitized_args = tuple(sanitize_for_logging(arg) for arg in args) + sanitized_kwargs = {k: sanitize_for_logging(v) for k, v in kwargs.items()} + return message, sanitized_args, sanitized_kwargs + + +class SanitizingFormatter: + """ + A mixin or wrapper for log formatters that sanitizes sensitive data. + + This can be used to wrap existing formatters to add sanitization. + + Example: + import logging + + formatter = logging.Formatter('%(asctime)s - %(message)s') + sanitizing_formatter = SanitizingFormatter(formatter) + handler.setFormatter(sanitizing_formatter) + """ + + def __init__(self, base_formatter): + """ + Initialize the sanitizing formatter. + + Args: + base_formatter: The underlying formatter to wrap + """ + self.base_formatter = base_formatter + + def format(self, record): + """ + Format the log record with sanitization. + + Args: + record: LogRecord to format + + Returns: + Formatted and sanitized log string + """ + # Sanitize the message arguments + if record.args: + record.args = tuple(sanitize_for_logging(arg) for arg in record.args) + + # Format using the base formatter + return self.base_formatter.format(record) diff --git a/agentflow_cli/src/app/loader.py b/agentflow_cli/src/app/loader.py index b3f52bf..3692c3d 100644 --- a/agentflow_cli/src/app/loader.py +++ b/agentflow_cli/src/app/loader.py @@ -10,6 +10,10 @@ from injectq import InjectQ from agentflow_cli import BaseAuth +from agentflow_cli.src.app.core.auth.authorization import ( + AuthorizationBackend, + DefaultAuthorizationBackend, +) from agentflow_cli.src.app.core.config.graph_config import GraphConfig from agentflow_cli.src.app.utils.thread_name_generator import ThreadNameGenerator @@ -152,6 +156,44 @@ def load_auth(path: str | None) -> BaseAuth | None: return auth +def load_authorization(path: str | None) -> AuthorizationBackend | None: + """ + Load authorization backend from 
the specified path. + + Args: + path: Module path in format 'module:attribute' or None + + Returns: + AuthorizationBackend instance or None if path is not provided + + Raises: + Exception: If loading fails or object is not AuthorizationBackend + """ + if not path: + return None + + module_name_importable, function_name = path.split(":") + + try: + module = importlib.import_module(module_name_importable) + entry_point_obj = getattr(module, function_name) + + # If it's a class, instantiate it; if it's an instance, use as is + if inspect.isclass(entry_point_obj) and issubclass(entry_point_obj, AuthorizationBackend): + authorization = entry_point_obj() + elif isinstance(entry_point_obj, AuthorizationBackend): + authorization = entry_point_obj + else: + raise TypeError("Loaded object is not a subclass or instance of AuthorizationBackend.") + + logger.info(f"Successfully loaded AuthorizationBackend '{function_name}' from {path}.") + except Exception as e: + logger.error(f"Error loading AuthorizationBackend from {path}: {e}") + raise Exception(f"Failed to load AuthorizationBackend from {path}: {e}") + + return authorization + + def load_thread_name_generator(path: str | None) -> ThreadNameGenerator | None: if not path: return None @@ -178,6 +220,50 @@ def load_thread_name_generator(path: str | None) -> ThreadNameGenerator | None: return thread_name_generator +def load_and_bind_auth(container: InjectQ, auth_config: dict) -> None: + from agentflow_cli.src.app.core.auth.jwt_auth import JwtAuth + + method = auth_config.get("method") + path = auth_config.get("path") + if not path or not method: + raise ValueError("Both 'method' and 'path' must be specified in auth_config.") + + # Extract file path before the ':' for existence check + module_or_path = path.split(":", 1)[0] if ":" in path else path + + # Simple handling: if it appears to be a filesystem path, use it; otherwise + # convert dotted module path to a file path like src/auth/custom_auth.py + if os.path.sep in 
module_or_path or module_or_path.endswith(".py"): + file_path = Path(module_or_path) + elif "." in module_or_path and os.path.sep not in module_or_path: + file_path = Path(module_or_path.replace(".", os.path.sep) + ".py") + else: + file_path = Path(module_or_path) + + if not file_path.exists(): + raise ValueError(f"Custom auth path does not exist: {module_or_path}") + + auth_backends = { + "custom": lambda: load_auth(path), + "jwt": lambda: JwtAuth(), + "none": lambda: None, + } + + auth_backend = auth_backends.get(method, lambda: None)() + container.bind_instance(BaseAuth, auth_backend, allow_none=True) + + +def load_and_bind_authorization(container: InjectQ, authorization_path: str | None) -> None: + if authorization_path: + authorization_backend = load_authorization(authorization_path) + container.bind_instance(AuthorizationBackend, authorization_backend) + else: + # Use default authorization backend if not configured + default_authorization = DefaultAuthorizationBackend() + container.bind_instance(AuthorizationBackend, default_authorization) + logger.info("Using DefaultAuthorizationBackend (allows all authenticated users)") + + async def attach_all_modules( config: GraphConfig, container: InjectQ, @@ -197,39 +283,7 @@ async def attach_all_modules( # load auth backend auth_config = config.auth_config() if auth_config: - method = auth_config.get("method", None) - path = auth_config.get("path", None) - if not path or not method: - raise ValueError("Both 'method' and 'path' must be specified in auth_config.") - - # Extract file path before the ':' for existence check - module_or_path = path.split(":", 1)[0] if ":" in path else path - - # Simple handling: if it appears to be a filesystem path, use it; otherwise - # convert dotted module path to a file path like src/auth/custom_auth.py - if os.path.sep in module_or_path or module_or_path.endswith(".py"): - file_path = Path(module_or_path) - elif "." 
in module_or_path and os.path.sep not in module_or_path: - file_path = Path(module_or_path.replace(".", os.path.sep) + ".py") - else: - file_path = Path(module_or_path) - - if not file_path.exists(): - raise ValueError(f"Custom auth path does not exist: {module_or_path}") - - if method == "custom": - auth_backend = load_auth( - path, - ) - container.bind_instance(BaseAuth, auth_backend) - elif method == "jwt": - from agentflow_cli.src.app.core.auth.jwt_auth import JwtAuth - - jwt_auth = JwtAuth() - container.bind_instance(BaseAuth, jwt_auth) - - elif method == "none": - container.bind_instance(BaseAuth, None, allow_none=True) + load_and_bind_auth(container, auth_config) else: # bind None container.bind_instance(BaseAuth, None, allow_none=True) @@ -243,6 +297,10 @@ async def attach_all_modules( # bind None if not configured container.bind_instance(ThreadNameGenerator, None, allow_none=True) + # load authorization backend + authorization_path = config.authorization_path + load_and_bind_authorization(container, authorization_path) + logger.info("Container loaded successfully") logger.debug(f"Container dependency graph: {container.get_dependency_graph()}") diff --git a/agentflow_cli/src/app/routers/checkpointer/router.py b/agentflow_cli/src/app/routers/checkpointer/router.py index 45b63b7..9a1653e 100644 --- a/agentflow_cli/src/app/routers/checkpointer/router.py +++ b/agentflow_cli/src/app/routers/checkpointer/router.py @@ -6,8 +6,7 @@ from fastapi import APIRouter, Depends, Request, status from injectq.integrations import InjectAPI -from agentflow_cli.src.app.core import logger -from agentflow_cli.src.app.core.auth.auth_backend import verify_current_user +from agentflow_cli.src.app.core.auth.permissions import RequirePermission from agentflow_cli.src.app.utils.response_helper import success_response from agentflow_cli.src.app.utils.swagger_helper import generate_swagger_responses @@ -38,7 +37,7 @@ async def get_state( request: Request, thread_id: int | str, service: 
CheckpointerService = InjectAPI(CheckpointerService), - user: dict[str, Any] = Depends(verify_current_user), + user: dict[str, Any] = Depends(RequirePermission("checkpointer", "read")), ): """Get state from checkpointer. @@ -49,8 +48,6 @@ async def get_state( Returns: State response with state data or error """ - logger.debug(f"User info: {user}") - config = {"thread_id": thread_id} result = await service.get_state( @@ -76,7 +73,7 @@ async def put_state( thread_id: str | int, payload: StateSchema, service: CheckpointerService = InjectAPI(CheckpointerService), - user: dict[str, Any] = Depends(verify_current_user), + user: dict[str, Any] = Depends(RequirePermission("checkpointer", "write")), ): """Put state to checkpointer. @@ -89,7 +86,6 @@ async def put_state( Returns: Success response or error """ - logger.debug(f"User info: {user}") config = {"thread_id": thread_id} if payload.config: config.update(payload.config) @@ -118,7 +114,7 @@ async def clear_state( request: Request, thread_id: int | str, service: CheckpointerService = InjectAPI(CheckpointerService), - user: dict[str, Any] = Depends(verify_current_user), + user: dict[str, Any] = Depends(RequirePermission("checkpointer", "delete")), ): """Clear state from checkpointer. @@ -131,7 +127,6 @@ async def clear_state( Returns: Success response or error """ - logger.debug(f"User info: {user}") config = {"thread_id": thread_id} res = await service.clear_state( @@ -160,7 +155,7 @@ async def put_messages( thread_id: str | int, payload: PutMessagesSchema, service: CheckpointerService = InjectAPI(CheckpointerService), - user: dict[str, Any] = Depends(verify_current_user), + user: dict[str, Any] = Depends(RequirePermission("checkpointer", "write")), ): """Put messages to checkpointer. 
@@ -173,8 +168,6 @@ async def put_messages( Returns: Success response or error """ - logger.debug(f"User info: {user}") - # Convert message dicts to Message objects if needed config = {"thread_id": thread_id} if payload.config: @@ -207,7 +200,7 @@ async def get_message( thread_id: str | int, message_id: str | int, service: CheckpointerService = InjectAPI(CheckpointerService), - user: dict[str, Any] = Depends(verify_current_user), + user: dict[str, Any] = Depends(RequirePermission("checkpointer", "read")), ): """Get message from checkpointer. @@ -220,7 +213,6 @@ async def get_message( Returns: Message response with message data or error """ - logger.debug(f"User info: {user}") config = {"thread_id": thread_id} result = await service.get_message( @@ -250,7 +242,7 @@ async def list_messages( offset: int | None = None, limit: int | None = None, service: CheckpointerService = InjectAPI(CheckpointerService), - user: dict[str, Any] = Depends(verify_current_user), + user: dict[str, Any] = Depends(RequirePermission("checkpointer", "read")), ): """List messages from checkpointer. @@ -263,7 +255,6 @@ async def list_messages( Returns: Messages list response with messages data or error """ - logger.debug(f"User info: {user}") config = {"thread_id": thread_id} result = await service.get_messages( @@ -293,7 +284,7 @@ async def delete_message( thread_id: str | int, payload: ConfigSchema, service: CheckpointerService = InjectAPI(CheckpointerService), - user: dict[str, Any] = Depends(verify_current_user), + user: dict[str, Any] = Depends(RequirePermission("checkpointer", "delete")), ): """Delete message from checkpointer. 
@@ -306,7 +297,6 @@ async def delete_message( Returns: Success response or error """ - logger.debug(f"User info: {user}") config = {"thread_id": thread_id} if payload.config: config.update(payload.config) @@ -337,7 +327,7 @@ async def get_thread( request: Request, thread_id: str | int, service: CheckpointerService = InjectAPI(CheckpointerService), - user: dict[str, Any] = Depends(verify_current_user), + user: dict[str, Any] = Depends(RequirePermission("checkpointer", "read")), ): """Get thread from checkpointer. @@ -350,8 +340,6 @@ async def get_thread( Returns: Thread response with thread data or error """ - logger.debug(f"User info: {user}") - result = await service.get_thread( {"thread_id": thread_id}, user, @@ -376,7 +364,7 @@ async def list_threads( offset: int | None = None, limit: int | None = None, service: CheckpointerService = InjectAPI(CheckpointerService), - user: dict[str, Any] = Depends(verify_current_user), + user: dict[str, Any] = Depends(RequirePermission("checkpointer", "read")), ): """List threads from checkpointer. @@ -389,8 +377,6 @@ async def list_threads( Returns: Threads list response with threads data or error """ - logger.debug(f"User info: {user}") - result = await service.list_threads( user, search, @@ -416,7 +402,7 @@ async def delete_thread( thread_id: str | int, payload: ConfigSchema, service: CheckpointerService = InjectAPI(CheckpointerService), - user: dict[str, Any] = Depends(verify_current_user), + user: dict[str, Any] = Depends(RequirePermission("checkpointer", "delete")), ): """Delete thread from checkpointer. 
@@ -429,8 +415,6 @@ async def delete_thread( Returns: Success response or error """ - logger.debug(f"User info: {user} and thread ID: {thread_id}") - config = {"thread_id": thread_id} if payload.config: config.update(payload.config) diff --git a/agentflow_cli/src/app/routers/checkpointer/services/checkpointer_service.py b/agentflow_cli/src/app/routers/checkpointer/services/checkpointer_service.py index e3070d9..ce9ef14 100644 --- a/agentflow_cli/src/app/routers/checkpointer/services/checkpointer_service.py +++ b/agentflow_cli/src/app/routers/checkpointer/services/checkpointer_service.py @@ -6,6 +6,7 @@ from agentflow_cli.src.app.core import logger from agentflow_cli.src.app.core.config.settings import get_settings +from agentflow_cli.src.app.core.utils.log_sanitizer import sanitize_for_logging from agentflow_cli.src.app.routers.checkpointer.schemas.checkpointer_schemas import ( MessagesListResponseSchema, ResponseSchema, @@ -130,7 +131,7 @@ async def delete_message( # Threads async def get_thread(self, config: dict[str, Any], user: dict) -> ThreadResponseSchema: cfg = self._config(config, user) - logger.debug(f"User info: {user} and thread config: {cfg}") + logger.debug(f"User info: {sanitize_for_logging(user)} and thread config: {cfg}") res = await self.checkpointer.aget_thread(cfg) return ThreadResponseSchema(thread=res.model_dump() if res else None) @@ -152,7 +153,7 @@ async def delete_thread( thread_id: Any, ) -> ResponseSchema: cfg = self._config(config, user) - logger.debug(f"User info: {user} and thread ID: {thread_id}") + logger.debug(f"User info: {sanitize_for_logging(user)} and thread ID: {thread_id}") res = await self.checkpointer.aclean_thread(cfg) return ResponseSchema(success=True, message="Thread deleted successfully", data=res) diff --git a/agentflow_cli/src/app/routers/graph/router.py b/agentflow_cli/src/app/routers/graph/router.py index b22e72b..ccdcd79 100644 --- a/agentflow_cli/src/app/routers/graph/router.py +++ 
b/agentflow_cli/src/app/routers/graph/router.py @@ -6,7 +6,7 @@ from fastapi.responses import StreamingResponse from injectq.integrations import InjectAPI -from agentflow_cli.src.app.core.auth.auth_backend import verify_current_user +from agentflow_cli.src.app.core.auth.permissions import RequirePermission from agentflow_cli.src.app.routers.graph.schemas.graph_schemas import ( FixGraphRequestSchema, GraphInputSchema, @@ -36,13 +36,12 @@ async def invoke_graph( request: Request, graph_input: GraphInputSchema, service: GraphService = InjectAPI(GraphService), - user: dict[str, Any] = Depends(verify_current_user), + user: dict[str, Any] = Depends(RequirePermission("graph", "invoke")), ): """ Invoke the graph with the provided input and return the final result. """ logger.info(f"Graph invoke request received with {len(graph_input.messages)} messages") - logger.debug(f"User info: {user}") result: GraphInvokeOutputSchema = await service.invoke_graph( graph_input, @@ -67,7 +66,7 @@ async def invoke_graph( async def stream_graph( graph_input: GraphInputSchema, service: GraphService = InjectAPI(GraphService), - user: dict[str, Any] = Depends(verify_current_user), + user: dict[str, Any] = Depends(RequirePermission("graph", "stream")), ): """ Stream the graph execution with real-time output. @@ -101,7 +100,7 @@ async def stream_graph( async def graph_details( request: Request, service: GraphService = InjectAPI(GraphService), - _: dict[str, Any] = Depends(verify_current_user), + user: dict[str, Any] = Depends(RequirePermission("graph", "read")), ): """ Invoke the graph with the provided input and return the final result. @@ -128,7 +127,7 @@ async def graph_details( async def state_schema( request: Request, service: GraphService = InjectAPI(GraphService), - _: dict[str, Any] = Depends(verify_current_user), + user: dict[str, Any] = Depends(RequirePermission("graph", "read")), ): """ Invoke the graph with the provided input and return the final result. 
@@ -156,7 +155,7 @@ async def stop_graph( request: Request, stop_request: GraphStopSchema, service: GraphService = InjectAPI(GraphService), - user: dict[str, Any] = Depends(verify_current_user), + user: dict[str, Any] = Depends(RequirePermission("graph", "stop")), ): """ Stop the graph execution for a specific thread. @@ -168,7 +167,6 @@ async def stop_graph( Status information about the stop operation """ logger.info(f"Graph stop request received for thread: {stop_request.thread_id}") - logger.debug(f"User info: {user}") result = await service.stop_graph(stop_request.thread_id, user, stop_request.config) @@ -191,7 +189,7 @@ async def setup_graph( request: Request, setup_request: GraphSetupSchema, service: GraphService = InjectAPI(GraphService), - user: dict[str, Any] = Depends(verify_current_user), + user: dict[str, Any] = Depends(RequirePermission("graph", "setup")), ): """ Setup the graph execution for a specific thread. @@ -203,7 +201,6 @@ async def setup_graph( Status information about the setup operation """ logger.info("Graph setup request received") - logger.debug(f"User info: {user}") result = await service.setup(setup_request) @@ -230,7 +227,7 @@ async def fix_graph( request: Request, fix_request: FixGraphRequestSchema, service: GraphService = InjectAPI(GraphService), - user: dict[str, Any] = Depends(verify_current_user), + user: dict[str, Any] = Depends(RequirePermission("graph", "fix")), ): """ Fix the graph execution state for a specific thread. 
@@ -257,7 +254,6 @@ async def fix_graph( for the given thread_id """ logger.info(f"Graph fix request received for thread: {fix_request.thread_id}") - logger.debug(f"User info: {user}") result = await service.fix_graph( fix_request.thread_id, diff --git a/agentflow_cli/src/app/routers/graph/services/graph_service.py b/agentflow_cli/src/app/routers/graph/services/graph_service.py index 7451668..c357e4d 100644 --- a/agentflow_cli/src/app/routers/graph/services/graph_service.py +++ b/agentflow_cli/src/app/routers/graph/services/graph_service.py @@ -13,6 +13,7 @@ from agentflow_cli.src.app.core import logger from agentflow_cli.src.app.core.config.graph_config import GraphConfig +from agentflow_cli.src.app.core.utils.log_sanitizer import sanitize_for_logging from agentflow_cli.src.app.routers.graph.schemas.graph_schemas import ( GraphInputSchema, GraphInvokeOutputSchema, @@ -135,7 +136,7 @@ async def stop_graph( """ try: logger.info(f"Stopping graph execution for thread: {thread_id}") - logger.debug(f"User info: {user}") + logger.debug(f"User info: {sanitize_for_logging(user)}") # Prepare config with thread_id and user info stop_config = { @@ -406,7 +407,7 @@ async def fix_graph( """ try: logger.info(f"Starting fix graph operation for thread: {thread_id}") - logger.debug(f"User info: {user}") + logger.debug(f"User info: {sanitize_for_logging(user)}") fix_config = {"thread_id": thread_id, "user": user} fix_config["user_id"] = user.get("user_id", "anonymous") diff --git a/agentflow_cli/src/app/routers/store/router.py b/agentflow_cli/src/app/routers/store/router.py index 83e7e7f..8c2f047 100644 --- a/agentflow_cli/src/app/routers/store/router.py +++ b/agentflow_cli/src/app/routers/store/router.py @@ -7,8 +7,7 @@ from fastapi import APIRouter, Body, Depends, Request, status from injectq.integrations import InjectAPI -from agentflow_cli.src.app.core import logger -from agentflow_cli.src.app.core.auth.auth_backend import verify_current_user +from 
agentflow_cli.src.app.core.auth.permissions import RequirePermission from agentflow_cli.src.app.utils.response_helper import success_response from agentflow_cli.src.app.utils.swagger_helper import generate_swagger_responses @@ -43,11 +42,10 @@ async def create_memory( request: Request, payload: StoreMemorySchema, service: StoreService = InjectAPI(StoreService), - user: dict[str, Any] = Depends(verify_current_user), + user: dict[str, Any] = Depends(RequirePermission("store", "write")), ): """Store a memory item using the configured store.""" - logger.debug("User info: %s", user) result = await service.store_memory(payload, user) return success_response(result, request, message="Memory stored successfully") @@ -63,11 +61,10 @@ async def search_memories( request: Request, payload: SearchMemorySchema, service: StoreService = InjectAPI(StoreService), - user: dict[str, Any] = Depends(verify_current_user), + user: dict[str, Any] = Depends(RequirePermission("store", "read")), ): """Search stored memories.""" - logger.debug("User info: %s", user) result = await service.search_memories(payload, user) return success_response(result, request) @@ -87,11 +84,10 @@ async def get_memory( description="Optional configuration and options for retrieving the memory.", ), service: StoreService = InjectAPI(StoreService), - user: dict[str, Any] = Depends(verify_current_user), + user: dict[str, Any] = Depends(RequirePermission("store", "read")), ): """Get a memory by ID.""" - logger.debug("User info: %s", user) cfg = payload.config if payload else {} opts = payload.options if payload else None result = await service.get_memory(memory_id, cfg, user, options=opts) @@ -112,11 +108,10 @@ async def list_memories( description="Optional configuration, limit, and options for listing memories.", ), service: StoreService = InjectAPI(StoreService), - user: dict[str, Any] = Depends(verify_current_user), + user: dict[str, Any] = Depends(RequirePermission("store", "read")), ): """List stored 
memories.""" - logger.debug("User info: %s", user) if payload is None: payload = ListMemoriesSchema() cfg = payload.config or {} @@ -137,11 +132,10 @@ async def update_memory( memory_id: str, payload: UpdateMemorySchema, service: StoreService = InjectAPI(StoreService), - user: dict[str, Any] = Depends(verify_current_user), + user: dict[str, Any] = Depends(RequirePermission("store", "write")), ): """Update a stored memory.""" - logger.debug("User info: %s", user) result = await service.update_memory(memory_id, payload, user) return success_response(result, request, message="Memory updated successfully") @@ -161,11 +155,10 @@ async def delete_memory( description="Optional configuration overrides forwarded to the store backend.", ), service: StoreService = InjectAPI(StoreService), - user: dict[str, Any] = Depends(verify_current_user), + user: dict[str, Any] = Depends(RequirePermission("store", "delete")), ): """Delete a stored memory.""" - logger.debug("User info: %s", user) config_payload = payload.config if payload else {} options_payload = payload.options if payload else None result = await service.delete_memory(memory_id, config_payload, user, options=options_payload) @@ -183,10 +176,9 @@ async def forget_memory( request: Request, payload: ForgetMemorySchema, service: StoreService = InjectAPI(StoreService), - user: dict[str, Any] = Depends(verify_current_user), + user: dict[str, Any] = Depends(RequirePermission("store", "delete")), ): """Forget memories based on filters.""" - logger.debug("User info: %s", user) result = await service.forget_memory(payload, user) return success_response(result, request, message="Memories removed successfully") diff --git a/docs/authentication.md b/docs/authentication.md index d1a2ff3..381a084 100644 --- a/docs/authentication.md +++ b/docs/authentication.md @@ -5,9 +5,11 @@ This guide covers implementing authentication in your AgentFlow application usin ## Table of Contents - [Overview](#overview) +- [Security 
Considerations](#security-considerations) - [No Authentication](#no-authentication) - [JWT Authentication](#jwt-authentication) - [Custom Authentication](#custom-authentication) +- [Authorization](#authorization) - [BaseAuth Interface](#baseauth-interface) - [Best Practices](#best-practices) - [Examples](#examples) @@ -35,6 +37,67 @@ Authentication is configured in `agentflow.json`: --- +## Security Considerations + +### Authentication vs Authorization + +**Authentication** answers: *"Who are you?"* +- Verifies user identity +- Validates credentials (tokens, API keys, etc.) +- Returns user context + +**Authorization** answers: *"What can you do?"* +- Controls access to resources +- Enforces permissions +- Checks user roles and privileges + +AgentFlow provides both: +- **Authentication**: Via this authentication system +- **Authorization**: Via the [Authorization Framework](../SECURITY.md#authorization) + +See [SECURITY.md](../SECURITY.md) for comprehensive security documentation. + +### Security Requirements by Environment + +**Development:** +- Can use `auth: null` for convenience +- Use simple secrets for testing +- API docs can be enabled + +**Staging:** +- Should use JWT or custom authentication +- Use environment-specific secrets +- Test security configurations + +**Production:** +- **MUST** use JWT or custom authentication +- **MUST** use strong, random secrets (32+ characters) +- **MUST** use HTTPS +- **MUST** disable API documentation +- **MUST** implement authorization +- **MUST** enable request size limits +- **MUST** configure specific CORS origins + +### Security Checklist + +Before deploying to production: + +- [ ] Authentication enabled (`auth: "jwt"` or custom) +- [ ] Strong JWT secret configured (generate with `secrets.token_urlsafe(32)`) +- [ ] HTTPS enabled with valid SSL/TLS certificates +- [ ] Authorization backend implemented +- [ ] CORS configured with specific domains (not `*`) +- [ ] API documentation disabled (`DOCS_PATH=` and 
`REDOCS_PATH=`) +- [ ] Debug mode disabled (`IS_DEBUG=false`) +- [ ] Request size limits configured (`MAX_REQUEST_SIZE=10485760`) +- [ ] Rate limiting implemented +- [ ] Security headers configured +- [ ] Logs sanitized for sensitive data + +See the [Production Deployment Checklist](../SECURITY.md#production-deployment) for complete details. + +--- + ## No Authentication ### Configuration @@ -389,6 +452,149 @@ class MyAuthBackend(BaseAuth): --- +## Authorization + +### Overview + +After authentication establishes *who* the user is, authorization determines *what* they can do. AgentFlow provides a flexible authorization framework. + +### Authorization Flow + +``` +Request → Authentication → Authorization → Resource Access + (Who are you?) (What can you do?) +``` + +All AgentFlow endpoints automatically enforce authorization using the `RequirePermission` dependency: + +```python +from agentflow_cli.src.app.core.auth.permissions import RequirePermission +from fastapi import Depends + +@router.post("/graph/invoke") +async def invoke_graph( + request: GraphRequest, + user: dict = Depends(RequirePermission("graph", "invoke")), +): + # User is authenticated AND authorized + # Can proceed with graph invocation + pass +``` + +### Implementing Authorization + +**Step 1: Create Authorization Backend** + +```python +# auth/rbac_backend.py +from agentflow_cli.src.app.core.auth.authorization import AuthorizationBackend +from typing import Any + +class RBACAuthorizationBackend(AuthorizationBackend): + """Role-Based Access Control authorization.""" + + PERMISSIONS = { + "admin": { + "graph": ["invoke", "stream", "read", "stop", "setup", "fix"], + "checkpointer": ["read", "write", "delete"], + "store": ["read", "write", "delete", "forget"] + }, + "developer": { + "graph": ["invoke", "stream", "read", "setup"], + "checkpointer": ["read", "write"], + "store": ["read", "write"] + }, + "viewer": { + "graph": ["read"], + "checkpointer": ["read"], + "store": ["read"] + } + } + + async def
authorize( + self, + user: dict[str, Any], + resource: str, + action: str, + resource_id: str | None = None, + **context + ) -> bool: + """Check if user's role permits the action.""" + role = user.get("role", "viewer") + role_permissions = self.PERMISSIONS.get(role, {}) + allowed_actions = role_permissions.get(resource, []) + return action in allowed_actions +``` + +**Step 2: Configure agentflow.json** + +```json +{ + "auth": "jwt", + "authorization": { + "path": "auth.rbac_backend:RBACAuthorizationBackend" + }, + "agent": "graph.react:app" +} +``` + +**Step 3: Include Role in User Context** + +Ensure your authentication returns user role: + +```python +# JWT payload +{ + "user_id": "user_123", + "email": "user@example.com", + "role": "developer", # Include role for authorization + "exp": 1735689600 +} +``` + +### Authorization Examples + +See [examples/security/](../examples/security/) for complete implementations: +- **RBAC** - Role-Based Access Control +- **Ownership** - Resource ownership-based authorization +- **ABAC** - Attribute-Based Access Control + +### Resources and Actions + +AgentFlow defines these resources and actions: + +**Graph Resource:** +- `invoke` - Execute graph +- `stream` - Stream graph execution +- `read` - View graph details and state +- `stop` - Stop running graph +- `setup` - Configure graph +- `fix` - Fix graph errors + +**Checkpointer Resource:** +- `read` - View state, messages, threads +- `write` - Update state, add messages +- `delete` - Delete threads, messages, state + +**Store Resource:** +- `read` - Search and view memories +- `write` - Create and update memories +- `delete` - Delete memories +- `forget` - Not checked by the built-in store endpoints; `POST /v1/store/memories/forget` requires `delete` + +### Authorization Best Practices + +1. **Deny by Default** - Only grant explicit permissions +2. **Least Privilege** - Grant minimum necessary access +3. **Validate Every Request** - Never cache authorization decisions +4. **Resource-Level Control** - Check permissions for specific resources +5.
**Audit Failures** - Log all authorization denials +6. **Separate Concerns** - Keep authentication and authorization logic separate + +For complete authorization documentation, see [SECURITY.md - Authorization](../SECURITY.md#authorization). + +--- + ## BaseAuth Interface ### Abstract Method diff --git a/docs/authorization.md b/docs/authorization.md new file mode 100644 index 0000000..2ed630a --- /dev/null +++ b/docs/authorization.md @@ -0,0 +1,704 @@ +# Authorization Guide + +## Overview + +AgentFlow CLI provides a flexible, pluggable authorization system that allows you to implement resource-level access control for your multi-agent applications. Authorization determines **what authenticated users can do**, complementing the authentication system which determines **who users are**. + +## Key Concepts + +### Authorization vs Authentication + +- **Authentication** (covered in [authentication.md](authentication.md)): Verifies user identity and returns user information +- **Authorization**: Determines if an authenticated user has permission to perform specific actions on resources + +### Authorization Backend + +The authorization backend is a pluggable component that implements the `AuthorizationBackend` abstract class. It receives: +- **User context**: Information about the authenticated user +- **Resource**: The type of resource being accessed (e.g., "graph", "checkpointer", "store") +- **Action**: The operation being performed (e.g., "invoke", "read", "write", "delete") +- **Resource ID**: Optional specific identifier (e.g., thread_id, memory_id) +- **Additional context**: Extra information for authorization decisions + +The backend returns `True` if authorized, `False` otherwise. On `False`, the API returns a 403 Forbidden response. 
+ +## Default Behavior + +If you don't configure a custom authorization backend, AgentFlow uses `DefaultAuthorizationBackend`, which: +- Allows all authenticated users to perform any action +- Only checks that the user has a `user_id` (i.e., they're authenticated) +- Suitable for development and simple use cases + +## Implementing Custom Authorization + +### Step 1: Create Your Authorization Backend + +Create a Python file (e.g., `my_auth/authorization.py`) and implement the `AuthorizationBackend` class: + +```python +from typing import Any +from agentflow_cli.src.app.core.auth.authorization import AuthorizationBackend + +class RoleBasedAuthorizationBackend(AuthorizationBackend): + """ + Example: Role-based access control (RBAC) + + Users have roles (admin, user, viewer) with different permissions. + """ + + async def authorize( + self, + user: dict[str, Any], + resource: str, + action: str, + resource_id: str | None = None, + **context: Any, + ) -> bool: + """ + Check if user can perform action on resource. + + Args: + user: User info dict (from authentication) + resource: "graph", "checkpointer", or "store" + action: "invoke", "stream", "read", "write", "delete", etc. 
+ resource_id: Optional specific resource identifier + **context: Additional context (e.g., request payload) + + Returns: + bool: True if authorized, False otherwise + """ + # Extract user role from user info + user_role = user.get("role", "viewer") + user_id = user.get("user_id") + + # Admins can do anything + if user_role == "admin": + return True + + # Graph operations + if resource == "graph": + if action in ["invoke", "stream", "read"]: + # Users and viewers can invoke, stream, and read graph details + return user_role in ["user", "viewer"] + elif action in ["stop", "setup", "fix"]: + # Only users and admins can modify + return user_role == "user" + + # Checkpointer operations + elif resource == "checkpointer": + if action == "read": + # Everyone can read their own data + return True + elif action in ["write", "delete"]: + # Only users and admins can modify + return user_role == "user" + + # Store operations + elif resource == "store": + if action == "read": + # Everyone can read + return True + elif action in ["write", "delete"]: + # Only users and admins can modify + return user_role == "user" + + # Deny by default + return False + + +# Create an instance to export +authorization_backend = RoleBasedAuthorizationBackend() +``` + +### Step 2: Configure in agentflow.json + +Add the authorization configuration to your `agentflow.json`: + +```json +{ + "agent": "my_agent.graph:agent", + "checkpointer": "my_agent.checkpointer:checkpointer", + "store": "my_agent.store:store", + "auth": { + "method": "jwt", + "path": "my_auth.auth_backend:auth_backend" + }, + "authorization": "my_auth.authorization:authorization_backend" +} +``` + +The `authorization` field should point to your authorization backend instance in the format: +``` +"module.path:attribute_name" +``` + +### Step 3: Test Your Authorization + +Start your API and test with different user roles: + +```bash +# Admin user can do anything +curl -X POST http://localhost:8000/v1/graph/invoke \ + -H "Authorization: Bearer " \ + -H "Content-Type:
application/json" \ + -d '{"messages": [{"role": "user", "content": "Hello"}]}' + +# Viewer can invoke but cannot stop +curl -X POST http://localhost:8000/v1/graph/stop \ + -H "Authorization: Bearer " \ + -H "Content-Type: application/json" \ + -d '{"thread_id": "123"}' +# Returns: 403 Forbidden +``` + +## Common Authorization Patterns + +### 1. Role-Based Access Control (RBAC) + +Users have roles (admin, user, viewer) with predefined permissions: + +```python +class RBACBackend(AuthorizationBackend): + PERMISSIONS = { + "admin": ["*"], # All actions + "user": ["invoke", "stream", "read", "write"], + "viewer": ["read"] + } + + async def authorize(self, user, resource, action, resource_id=None, **context): + role = user.get("role", "viewer") + allowed_actions = self.PERMISSIONS.get(role, []) + return "*" in allowed_actions or action in allowed_actions + +authorization_backend = RBACBackend() +``` + +### 2. Attribute-Based Access Control (ABAC) + +Fine-grained control based on user attributes, resource attributes, and context: + +```python +class ABACBackend(AuthorizationBackend): + async def authorize(self, user, resource, action, resource_id=None, **context): + user_id = user.get("user_id") + user_department = user.get("department") + + # Users can only access their own threads + if resource == "checkpointer" and resource_id: + thread_owner = await self.get_thread_owner(resource_id) + if thread_owner != user_id: + return False + + # Department-specific access + if user_department == "engineering": + return True + elif user_department == "sales": + # Sales can only invoke, not modify + return action in ["invoke", "stream", "read"] + + return False + + async def get_thread_owner(self, thread_id): + # Query your database to get thread owner + # This is just an example + return "some-user-id" + +authorization_backend = ABACBackend() +``` + +### 3. 
Owner-Based Access Control + +Users can only access resources they created: + +```python +class OwnershipBackend(AuthorizationBackend): + def __init__(self, checkpointer): + self.checkpointer = checkpointer + + async def authorize(self, user, resource, action, resource_id=None, **context): + user_id = user.get("user_id") + + # For checkpointer operations, verify thread ownership + if resource == "checkpointer" and resource_id: + state = await self.checkpointer.aget_state({"thread_id": resource_id}) + if state and state.config.get("user_id") != user_id: + return False + + # For store operations, verify memory ownership + if resource == "store" and resource_id: + memory = await self.get_memory(resource_id) + if memory and memory.get("user_id") != user_id: + return False + + # Allow if no specific resource or if owner matches + return True + + async def get_memory(self, memory_id): + # Retrieve memory and check ownership + return {"user_id": "owner-id"} + +# You'll need to pass checkpointer instance +# This requires modifying loader.py to support constructor args +authorization_backend = OwnershipBackend(checkpointer=None) +``` + +### 4. 
API Key Tier-Based Access + +Different access levels based on API key tiers: + +```python +class TierBasedBackend(AuthorizationBackend): + TIER_LIMITS = { + "free": {"max_requests_per_day": 100, "allowed_actions": ["invoke", "read"]}, + "pro": {"max_requests_per_day": 10000, "allowed_actions": ["invoke", "stream", "read", "write"]}, + "enterprise": {"max_requests_per_day": -1, "allowed_actions": ["*"]} + } + + async def authorize(self, user, resource, action, resource_id=None, **context): + tier = user.get("tier", "free") + tier_config = self.TIER_LIMITS.get(tier, self.TIER_LIMITS["free"]) + + # Check if action is allowed for this tier + allowed_actions = tier_config["allowed_actions"] + if "*" not in allowed_actions and action not in allowed_actions: + return False + + # Check rate limits (simplified) + # In production, use Redis or similar for rate limiting + request_count = await self.get_request_count(user["user_id"]) + max_requests = tier_config["max_requests_per_day"] + if max_requests > 0 and request_count >= max_requests: + return False + + return True + + async def get_request_count(self, user_id): + # Check request count from cache/database + return 50 # Example + +authorization_backend = TierBasedBackend() +``` + +## Authorization Check Points + +Authorization is checked at the following points in the API: + +### Graph Endpoints + +| Endpoint | Resource | Action | Resource ID | +|----------|----------|--------|-------------| +| `POST /v1/graph/invoke` | `graph` | `invoke` | `thread_id` (if provided) | +| `POST /v1/graph/stream` | `graph` | `stream` | `thread_id` (if provided) | +| `GET /v1/graph` | `graph` | `read` | None | +| `GET /v1/graph:StateSchema` | `graph` | `read` | None | +| `POST /v1/graph/stop` | `graph` | `stop` | `thread_id` | +| `POST /v1/graph/setup` | `graph` | `setup` | None | +| `POST /v1/graph/fix` | `graph` | `fix` | `thread_id` | + +### Checkpointer Endpoints + +| Endpoint | Resource | Action | Resource ID | 
+|----------|----------|--------|-------------| +| `GET /v1/threads/{thread_id}/state` | `checkpointer` | `read` | `thread_id` | +| `PUT /v1/threads/{thread_id}/state` | `checkpointer` | `write` | `thread_id` | +| `DELETE /v1/threads/{thread_id}/state` | `checkpointer` | `delete` | `thread_id` | +| `POST /v1/threads/{thread_id}/messages` | `checkpointer` | `write` | `thread_id` | +| `GET /v1/threads/{thread_id}/messages/{message_id}` | `checkpointer` | `read` | `thread_id` | +| `GET /v1/threads/{thread_id}/messages` | `checkpointer` | `read` | `thread_id` | +| `DELETE /v1/threads/{thread_id}/messages/{message_id}` | `checkpointer` | `delete` | `thread_id` | +| `GET /v1/threads/{thread_id}` | `checkpointer` | `read` | `thread_id` | +| `GET /v1/threads` | `checkpointer` | `read` | None | +| `DELETE /v1/threads/{thread_id}` | `checkpointer` | `delete` | `thread_id` | + +### Store Endpoints + +| Endpoint | Resource | Action | Resource ID | +|----------|----------|--------|-------------| +| `POST /v1/store/memories` | `store` | `write` | `namespace` (if provided) | +| `POST /v1/store/search` | `store` | `read` | `namespace` (if provided) | +| `POST /v1/store/memories/{memory_id}` | `store` | `read` | `memory_id` | +| `POST /v1/store/memories/list` | `store` | `read` | None | +| `PUT /v1/store/memories/{memory_id}` | `store` | `write` | `memory_id` | +| `DELETE /v1/store/memories/{memory_id}` | `store` | `delete` | `memory_id` | +| `POST /v1/store/memories/forget` | `store` | `delete` | `namespace` (if provided) | + +## Best Practices + +### 1. 
Start Simple, Scale Up + +Begin with the `DefaultAuthorizationBackend` and add complexity as needed: + +```python +# Stage 1: Default (all authenticated users) +# No configuration needed - this is the default + +# Stage 2: Simple role check +class SimpleRoleBackend(AuthorizationBackend): + async def authorize(self, user, resource, action, resource_id=None, **context): + return user.get("role") in ["admin", "user"] + +# Stage 3: Resource-specific logic +# Add conditions based on resource type + +# Stage 4: Fine-grained ABAC +# Consider attributes, context, time, location, etc. +``` + +### 2. Fail Secure (Deny by Default) + +Always return `False` at the end of your authorization logic: + +```python +async def authorize(self, user, resource, action, resource_id=None, **context): + # Check various conditions + if condition1: + return True + if condition2: + return True + + # Deny everything else + return False +``` + +### 3. Log Authorization Failures + +Add logging to help debug authorization issues: + +```python +import logging + +logger = logging.getLogger(__name__) + +class MyAuthBackend(AuthorizationBackend): + async def authorize(self, user, resource, action, resource_id=None, **context): + user_id = user.get("user_id") + + if not self.has_permission(user, resource, action): + logger.warning( + f"Authorization denied: user={user_id}, resource={resource}, " + f"action={action}, resource_id={resource_id}" + ) + return False + + logger.info(f"Authorization granted: user={user_id}, action={action}") + return True +``` + +### 4. 
Cache Authorization Decisions + +For expensive authorization checks (e.g., database queries), implement caching: + +```python +from functools import lru_cache +from datetime import datetime, timedelta + +class CachedAuthBackend(AuthorizationBackend): + def __init__(self): + self.cache = {} + self.cache_ttl = timedelta(minutes=5) + + async def authorize(self, user, resource, action, resource_id=None, **context): + # Create cache key + cache_key = f"{user['user_id']}:{resource}:{action}:{resource_id}" + + # Check cache + if cache_key in self.cache: + cached_value, cached_time = self.cache[cache_key] + if datetime.now() - cached_time < self.cache_ttl: + return cached_value + + # Compute authorization + result = await self._compute_authorization(user, resource, action, resource_id) + + # Cache result + self.cache[cache_key] = (result, datetime.now()) + + return result + + async def _compute_authorization(self, user, resource, action, resource_id): + # Your actual authorization logic here + return True + +authorization_backend = CachedAuthBackend() +``` + +### 5. Separate Concerns + +Keep authorization logic separate from business logic: + +```python +# ✅ Good: Authorization in backend +class MyAuthBackend(AuthorizationBackend): + async def authorize(self, user, resource, action, resource_id=None, **context): + return user.get("role") == "admin" + +# ❌ Bad: Don't add authorization to your graph/service code +# The framework handles this automatically +``` + +### 6. 
Use Type Hints + +Make your code maintainable with proper type hints: + +```python +from typing import Any, Optional + +class TypedAuthBackend(AuthorizationBackend): + async def authorize( + self, + user: dict[str, Any], + resource: str, + action: str, + resource_id: Optional[str] = None, + **context: Any, + ) -> bool: + # Implementation + return True +``` + +## Testing Authorization + +### Unit Testing + +```python +import pytest +from my_auth.authorization import RoleBasedAuthorizationBackend + +@pytest.fixture +def auth_backend(): + return RoleBasedAuthorizationBackend() + +@pytest.mark.asyncio +async def test_admin_can_do_anything(auth_backend): + user = {"user_id": "admin-1", "role": "admin"} + assert await auth_backend.authorize(user, "graph", "invoke") == True + assert await auth_backend.authorize(user, "checkpointer", "delete", "thread-1") == True + +@pytest.mark.asyncio +async def test_viewer_cannot_write(auth_backend): + user = {"user_id": "viewer-1", "role": "viewer"} + assert await auth_backend.authorize(user, "graph", "read") == True + assert await auth_backend.authorize(user, "checkpointer", "write", "thread-1") == False + +@pytest.mark.asyncio +async def test_user_can_modify_own_resources(auth_backend): + user = {"user_id": "user-1", "role": "user"} + assert await auth_backend.authorize(user, "graph", "invoke") == True + assert await auth_backend.authorize(user, "checkpointer", "write", "thread-1") == True +``` + +### Integration Testing + +```python +from fastapi.testclient import TestClient +from my_agent.main import app + +client = TestClient(app) + +def test_authorized_access(): + response = client.post( + "/v1/graph/invoke", + headers={"Authorization": "Bearer admin-token"}, + json={"messages": [{"role": "user", "content": "test"}]} + ) + assert response.status_code == 200 + +def test_unauthorized_access(): + response = client.post( + "/v1/graph/stop", + headers={"Authorization": "Bearer viewer-token"}, + json={"thread_id": "123"} + ) + assert 
response.status_code == 403 + assert "Not authorized" in response.json()["error"]["message"] +``` + +## Troubleshooting + +### Authorization Always Returns 403 + +**Problem**: All requests return 403 Forbidden + +**Solutions**: +1. Check that your authorization backend is properly configured in `agentflow.json` +2. Verify the module path is correct: `"module.path:attribute_name"` +3. Ensure your `authorize()` method returns `True` for the expected cases +4. Add logging to see what values are being passed to `authorize()` +5. Check that the user object has the expected fields (e.g., `user_id`, `role`) + +```python +# Add debug logging +async def authorize(self, user, resource, action, resource_id=None, **context): + print(f"DEBUG: user={user}, resource={resource}, action={action}") + # Your logic here +``` + +### Authorization Not Being Called + +**Problem**: Authorization backend never gets invoked + +**Solutions**: +1. Verify `authorization` is set in `agentflow.json` +2. Check that the file path exists and is importable +3. Ensure the attribute name matches the variable name in your module +4. Restart the API server after configuration changes + +### Performance Issues + +**Problem**: Authorization checks are slow + +**Solutions**: +1. Implement caching for expensive operations (database queries, API calls) +2. Use Redis for distributed caching in production +3. Minimize external API calls in authorization logic +4. Consider pre-computing permissions and storing them with the user +5. Use async operations properly (don't block with sync calls) + +## Security Considerations + +### 1. 
Never Trust User Input + +Always validate user-provided data in authorization logic: + +```python +async def authorize(self, user, resource, action, resource_id=None, **context): + # ✅ Good: Validate inputs + if not user or not user.get("user_id"): + return False + + # ❌ Bad: Trusting user-provided role + # The role should come from your database, not the token + if user.get("role") == "admin": # What if user modified their token? + return True +``` + +### 2. Keep Authorization Logic Server-Side + +Never implement authorization checks in client-side code - they can be bypassed. + +### 3. Use the Principle of Least Privilege + +Grant users the minimum permissions they need: + +```python +# ✅ Good: Specific permissions +if action == "read" and resource == "checkpointer": + return True + +# ❌ Bad: Overly permissive +return True # Everyone can do everything +``` + +### 4. Audit Authorization Decisions + +Log authorization decisions for security audits: + +```python +async def authorize(self, user, resource, action, resource_id=None, **context): + decision = self._make_decision(user, resource, action, resource_id) + + # Log to audit trail + await self.audit_log.record({ + "timestamp": datetime.now(), + "user_id": user.get("user_id"), + "resource": resource, + "action": action, + "resource_id": resource_id, + "decision": "allow" if decision else "deny" + }) + + return decision +``` + +## Production Checklist + +Before deploying to production, ensure: + +- [ ] Custom authorization backend is implemented and tested +- [ ] Authorization is configured in `agentflow.json` +- [ ] All expected user roles/permissions are defined +- [ ] Authorization decisions are logged for audit +- [ ] Performance testing completed (especially for expensive checks) +- [ ] Edge cases are handled (missing user_id, null values, etc.) 
+- [ ] Fail-secure pattern is implemented (deny by default) +- [ ] Sensitive data is not logged (use log sanitization) +- [ ] Rate limiting is considered for API abuse prevention +- [ ] Backup authorization mechanism exists (if primary fails) + +## Additional Resources + +- [Authentication Guide](authentication.md) - Set up user authentication +- [Configuration Guide](configuration.md) - Configure your AgentFlow application +- [Deployment Guide](deployment.md) - Deploy securely to production +- [Security Review](../SECURITY_REVIEW.md) - Framework security analysis +- [Security Action Plan](../SECURITY_ACTION_PLAN.md) - Security improvement roadmap + +## Examples Repository + +Complete authorization examples are available in the `examples/authorization/` directory: + +- `examples/authorization/rbac.py` - Role-based access control +- `examples/authorization/abac.py` - Attribute-based access control +- `examples/authorization/ownership.py` - Owner-based access control +- `examples/authorization/tier_based.py` - API tier-based access control + +## Support + +If you have questions about implementing authorization: + +1. Check the [FAQ section](#faq) below +2. Review the example implementations in `examples/authorization/` +3. Open an issue on GitHub with the `authorization` label +4. 
Join our community Discord for real-time help + +## FAQ + +**Q: Can I have multiple authorization backends?** + +A: Not directly, but you can create a composite backend that delegates to multiple backends: + +```python +class CompositeAuthBackend(AuthorizationBackend): + def __init__(self, backends: list[AuthorizationBackend]): + self.backends = backends + + async def authorize(self, user, resource, action, resource_id=None, **context): + # All backends must approve + for backend in self.backends: + if not await backend.authorize(user, resource, action, resource_id, **context): + return False + return True + +authorization_backend = CompositeAuthBackend([ + RBACBackend(), + RateLimitBackend(), + OwnershipBackend() +]) +``` + +**Q: Can I disable authorization for development?** + +A: Yes, simply don't configure the `authorization` field in `agentflow.json`. The framework will use `DefaultAuthorizationBackend` which allows all authenticated users. + +**Q: How do I handle anonymous access?** + +A: Authorization only runs for authenticated requests. If you want to allow some endpoints without authentication, you'll need to modify your authentication backend to return a default user for unauthenticated requests. + +**Q: Can authorization decisions be async?** + +A: Yes! The `authorize()` method is async, so you can make database queries, API calls, or any async operation: + +```python +async def authorize(self, user, resource, action, resource_id=None, **context): + # Async database query + permissions = await self.db.get_user_permissions(user["user_id"]) + return action in permissions +``` + +**Q: What's the performance impact of authorization?** + +A: Minimal if implemented correctly. Simple checks (role comparison) add <1ms. Database queries should be cached. For complex logic, use caching with TTL. + +**Q: Can I use the request object in authorization?** + +A: The request object is not directly passed, but you can access context via the `**context` parameter. 
For example, the graph input is passed as `input=graph_input` in graph authorization calls. diff --git a/examples/security/README.md b/examples/security/README.md new file mode 100644 index 0000000..5816473 --- /dev/null +++ b/examples/security/README.md @@ -0,0 +1,116 @@ +# Security Examples + +This directory contains comprehensive examples for implementing security features in AgentFlow CLI applications. + +## Examples Overview + +### Authentication Examples +1. **[jwt_auth_example.py](./jwt_auth_example.py)** - Built-in JWT authentication setup +2. **[api_key_auth.py](./api_key_auth.py)** - Custom API key authentication backend +3. **[oauth2_auth.py](./oauth2_auth.py)** - OAuth2 authentication with external providers + +### Authorization Examples +4. **[rbac_authorization.py](./rbac_authorization.py)** - Role-Based Access Control (RBAC) +5. **[ownership_authorization.py](./ownership_authorization.py)** - Resource ownership-based authorization +6. **[abac_authorization.py](./abac_authorization.py)** - Attribute-Based Access Control (ABAC) + +### Configuration Examples +7. **[production_config/](./production_config/)** - Secure production configuration templates + - agentflow.json - Production configuration + - .env.production - Environment variables + - docker-compose.yml - Docker deployment + - nginx.conf - Nginx reverse proxy with security headers + +## Quick Start + +### 1. JWT Authentication + +**Step 1:** Configure agentflow.json +```json +{ + "auth": "jwt", + "agent": "graph.react:app" +} +``` + +**Step 2:** Set environment variables +```bash +export JWT_SECRET_KEY=$(python -c "import secrets; print(secrets.token_urlsafe(32))") +export JWT_ALGORITHM=HS256 +export JWT_ACCESS_TOKEN_EXPIRE_MINUTES=30 +``` + +**Step 3:** Start the application +```bash +agentflow api +``` + +### 2. 
Custom Authentication + +**Step 1:** Create auth backend (see [api_key_auth.py](./api_key_auth.py)) + +**Step 2:** Configure agentflow.json +```json +{ + "auth": { + "method": "custom", + "path": "auth.api_key:APIKeyAuth" + } +} +``` + +### 3. Authorization + +**Step 1:** Create authorization backend (see [rbac_authorization.py](./rbac_authorization.py)) + +**Step 2:** Configure agentflow.json +```json +{ + "auth": "jwt", + "authorization": { + "path": "auth.rbac_backend:RBACAuthorizationBackend" + } +} +``` + +## Testing Examples + +Each example includes test cases. Run them with: + +```bash +# Install test dependencies +pip install pytest pytest-asyncio httpx + +# Run tests +pytest examples/security/ -v +``` + +## Production Deployment + +See the [production_config](./production_config/) directory for: +- Complete production configuration +- Docker deployment setup +- Nginx configuration with security headers +- Kubernetes manifests (coming soon) + +## Security Best Practices + +1. **Never commit secrets** - Use environment variables or secret managers +2. **Use HTTPS in production** - Always encrypt traffic +3. **Implement rate limiting** - Prevent brute force and DoS attacks +4. **Monitor security events** - Log authentication/authorization failures +5. **Regular updates** - Keep dependencies up to date +6. **Security testing** - Include security tests in CI/CD pipeline + +## Additional Resources + +- [SECURITY.md](../../SECURITY.md) - Complete security guide +- [Authentication Guide](../../docs/authentication.md) - Detailed authentication documentation +- [Deployment Guide](../../docs/deployment.md) - Production deployment guide + +## Questions? 
+ +For questions or issues, please: +- Check the [SECURITY.md](../../SECURITY.md) guide +- Open an issue on GitHub +- Email: security@10xhub.com diff --git a/examples/security/api_key_auth.py b/examples/security/api_key_auth.py new file mode 100644 index 0000000..2771b48 --- /dev/null +++ b/examples/security/api_key_auth.py @@ -0,0 +1,338 @@ +""" +API Key Authentication Backend Example + +This example shows how to implement a custom API key authentication backend +for AgentFlow CLI applications. + +Use Case: +- Service-to-service authentication +- Third-party API integrations +- Simpler alternative to OAuth2 + +Setup: +1. Create this file in your project (e.g., auth/api_key.py) +2. Configure agentflow.json to use this backend +3. Set API_KEYS environment variable +""" + +import os +import hashlib +import hmac +from datetime import datetime +from typing import Any + +from fastapi import HTTPException, Response, status +from fastapi.security import HTTPAuthorizationCredentials + +# Import the base authentication class +# In your project: from agentflow_cli import BaseAuth +from agentflow_cli.src.app.core.auth.base_auth import BaseAuth + + +class APIKeyAuth(BaseAuth): + """ + API Key authentication backend. + + Validates API keys from Authorization header and returns user context. + Supports both plain keys and hashed keys for better security. + """ + + def __init__(self): + """Initialize API key authentication.""" + # Load API keys from environment + # Format: key1,key2,key3 or key1:user1,key2:user2 + keys_string = os.getenv("API_KEYS", "") + + if not keys_string: + print("WARNING: No API keys configured. 
Set API_KEYS environment variable.") + + # Parse API keys and associated metadata + self.valid_keys: dict[str, dict[str, Any]] = {} + + for key_entry in keys_string.split(","): + if ":" in key_entry: + # Format: key:user_id:role + parts = key_entry.strip().split(":") + key = parts[0] + user_id = parts[1] if len(parts) > 1 else f"api_key_{key[:8]}" + role = parts[2] if len(parts) > 2 else "user" + + self.valid_keys[key] = {"user_id": user_id, "role": role, "key_prefix": key[:8]} + else: + # Format: key (use prefix as user_id) + key = key_entry.strip() + if key: + self.valid_keys[key] = { + "user_id": f"api_key_{key[:8]}", + "role": "user", + "key_prefix": key[:8], + } + + print(f"Loaded {len(self.valid_keys)} API keys") + + async def authenticate( + self, credentials: HTTPAuthorizationCredentials, response: Response + ) -> dict[str, str]: + """ + Authenticate request using API key. + + Args: + credentials: HTTP authorization credentials (Bearer token) + response: FastAPI response object + + Returns: + User context dictionary with user_id and metadata + + Raises: + HTTPException: If API key is invalid or missing + """ + if not credentials: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="API key required", + headers={"WWW-Authenticate": "Bearer"}, + ) + + api_key = credentials.credentials + + # Validate API key exists and is not empty + if not api_key: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="API key cannot be empty", + headers={"WWW-Authenticate": "Bearer"}, + ) + + # Check if API key is valid + key_metadata = self.valid_keys.get(api_key) + + if not key_metadata: + # Log the failed attempt without any key material (even partial characters leak near-miss keys) + print("Invalid API key attempt") + + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid API key", + headers={"WWW-Authenticate": "Bearer"}, + ) + + # Log successful authentication + print(f"API key authenticated: 
{key_metadata['user_id']}") + + # Return user context + return { + "user_id": key_metadata["user_id"], + "auth_method": "api_key", + "role": key_metadata["role"], + "key_prefix": key_metadata["key_prefix"], + "authenticated_at": datetime.utcnow().isoformat(), + } + + def extract_user_id(self, user: dict[str, str]) -> str | None: + """ + Extract user ID from user context. + + Args: + user: User context dictionary + + Returns: + User ID string or None + """ + return user.get("user_id") + + +class HashedAPIKeyAuth(BaseAuth): + """ + Hashed API Key authentication - more secure variant. + + Stores hashed versions of API keys instead of plain text. + Recommended for production use. + """ + + def __init__(self): + """Initialize hashed API key authentication.""" + # Load hashed API keys from environment + # Format: hashed_key1:user1,hashed_key2:user2 + keys_string = os.getenv("API_KEY_HASHES", "") + + self.valid_key_hashes: dict[str, dict[str, Any]] = {} + + for key_entry in keys_string.split(","): + if ":" in key_entry: + parts = key_entry.strip().split(":") + key_hash = parts[0] + user_id = parts[1] if len(parts) > 1 else f"user_{key_hash[:8]}" + role = parts[2] if len(parts) > 2 else "user" + + self.valid_key_hashes[key_hash] = {"user_id": user_id, "role": role} + + print(f"Loaded {len(self.valid_key_hashes)} hashed API keys") + + def hash_key(self, api_key: str) -> str: + """ + Hash an API key using SHA-256. 
+ + Args: + api_key: Plain text API key + + Returns: + Hexadecimal hash of the API key + """ + return hashlib.sha256(api_key.encode()).hexdigest() + + async def authenticate( + self, credentials: HTTPAuthorizationCredentials, response: Response + ) -> dict[str, str]: + """Authenticate using hashed API key.""" + if not credentials: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="API key required", + headers={"WWW-Authenticate": "Bearer"}, + ) + + api_key = credentials.credentials + key_hash = self.hash_key(api_key) + + # Check if hash matches + key_metadata = self.valid_key_hashes.get(key_hash) + + if not key_metadata: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key") + + return { + "user_id": key_metadata["user_id"], + "auth_method": "hashed_api_key", + "role": key_metadata["role"], + "authenticated_at": datetime.utcnow().isoformat(), + } + + def extract_user_id(self, user: dict[str, str]) -> str | None: + """Extract user ID from user context.""" + return user.get("user_id") + + +# Helper function to generate API keys +def generate_api_key(prefix: str = "ak", length: int = 32) -> str: + """ + Generate a secure random API key. + + Args: + prefix: Key prefix for identification + length: Length of random portion + + Returns: + Generated API key + """ + import secrets + + random_part = secrets.token_urlsafe(length) + return f"{prefix}_{random_part}" + + +# Helper function to hash existing keys +def hash_api_key(api_key: str) -> str: + """ + Hash an API key for storage. 
+ + Args: + api_key: Plain text API key + + Returns: + SHA-256 hash of the key + """ + return hashlib.sha256(api_key.encode()).hexdigest() + + +""" +CONFIGURATION: + +# agentflow.json +{ + "auth": { + "method": "custom", + "path": "auth.api_key:APIKeyAuth" + }, + "agent": "graph.react:app" +} + +# .env - Plain text keys (development only) +API_KEYS=ak_dev123:user_dev:developer,ak_admin456:user_admin:admin + +# .env - Hashed keys (production) +API_KEY_HASHES=REPLACE_WITH_SHA256_HEX_OF_KEY_1:user1:admin,REPLACE_WITH_SHA256_HEX_OF_KEY_2:user2:developer +""" + +""" +USAGE EXAMPLE: + +# Generate a new API key +python -c "from api_key_auth import generate_api_key; print(generate_api_key())" +# Output: ak_xK9mP2nQ7vL3wR8dF5hJ1bN4cT6yU0zA + +# Hash an existing key for production +python -c "from api_key_auth import hash_api_key; print(hash_api_key('ak_dev123'))" +# Output: a 64-character hex SHA-256 digest; use it as the hash part of an API_KEY_HASHES entry + +# Use API key in requests +curl http://localhost:8000/graph/invoke \ + -H "Authorization: Bearer ak_dev123" \ + -H "Content-Type: application/json" \ + -d '{"input": {"message": "hello"}}' +""" + +""" +SECURITY BEST PRACTICES: + +1. Key Generation: + - Use cryptographically secure random generation + - Minimum 32 characters + - Include prefix for identification + +2. Key Storage: + - Store hashed keys in production + - Use environment variables or secret manager + - Never commit keys to version control + +3. Key Rotation: + - Rotate keys regularly (e.g., every 90 days) + - Support multiple active keys during rotation + - Revoke compromised keys immediately + +4. Usage: + - Use HTTPS only + - Implement rate limiting + - Log all authentication attempts + - Monitor for suspicious patterns + +5. 
Access Control: + - Associate keys with specific users/services + - Implement key-level permissions + - Separate keys for different environments +""" + +""" +TESTING: + +# tests/test_api_key_auth.py +import pytest +from fastapi.testclient import TestClient +from auth.api_key import APIKeyAuth, generate_api_key + +def test_valid_api_key(): + # Test with valid key + pass + +def test_invalid_api_key(): + # Test with invalid key + pass + +def test_missing_api_key(): + # Test without Authorization header + pass + +def test_key_generation(): + key = generate_api_key() + assert key.startswith("ak_") + assert len(key) > 32 +""" diff --git a/examples/security/jwt_auth_example.py b/examples/security/jwt_auth_example.py new file mode 100644 index 0000000..a16242e --- /dev/null +++ b/examples/security/jwt_auth_example.py @@ -0,0 +1,259 @@ +""" +JWT Authentication Example + +This example demonstrates how to use the built-in JWT authentication +in AgentFlow CLI applications. + +Setup: +1. Configure agentflow.json with "auth": "jwt" +2. Set JWT_SECRET_KEY environment variable +3. Run the application + +Usage: +1. Generate JWT token with your authentication service +2. Include token in Authorization header +3. 
Access protected endpoints +""" + +import os +from datetime import datetime, timedelta +from typing import Any + +import jwt +from fastapi import FastAPI, Depends, HTTPException, status +from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials +from pydantic import BaseModel + +# Example configuration +JWT_SECRET_KEY = os.getenv("JWT_SECRET_KEY", "dev-secret-key-change-in-production") +JWT_ALGORITHM = os.getenv("JWT_ALGORITHM", "HS256") +JWT_ACCESS_TOKEN_EXPIRE_MINUTES = int(os.getenv("JWT_ACCESS_TOKEN_EXPIRE_MINUTES", "30")) + +# Security scheme +security = HTTPBearer() + + +class TokenData(BaseModel): + """JWT token payload structure.""" + + user_id: str + email: str + role: str = "user" + exp: datetime | None = None + + +class LoginRequest(BaseModel): + """Login request structure.""" + + username: str + password: str + + +class TokenResponse(BaseModel): + """Token response structure.""" + + access_token: str + token_type: str = "bearer" + expires_in: int + + +def create_access_token(data: dict[str, Any], expires_delta: timedelta | None = None) -> str: + """ + Create a JWT access token. + + Args: + data: Token payload data + expires_delta: Token expiration time + + Returns: + Encoded JWT token string + """ + to_encode = data.copy() + + if expires_delta: + expire = datetime.utcnow() + expires_delta + else: + expire = datetime.utcnow() + timedelta(minutes=JWT_ACCESS_TOKEN_EXPIRE_MINUTES) + + to_encode.update({"exp": expire}) + + encoded_jwt = jwt.encode(to_encode, JWT_SECRET_KEY, algorithm=JWT_ALGORITHM) + return encoded_jwt + + +def verify_token(credentials: HTTPAuthorizationCredentials = Depends(security)) -> dict[str, Any]: + """ + Verify JWT token and extract user data. 
+ + Args: + credentials: HTTP authorization credentials + + Returns: + User data from token + + Raises: + HTTPException: If token is invalid or expired + """ + token = credentials.credentials + + try: + payload = jwt.decode(token, JWT_SECRET_KEY, algorithms=[JWT_ALGORITHM]) + user_id = payload.get("user_id") + + if user_id is None: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid authentication credentials", + headers={"WWW-Authenticate": "Bearer"}, + ) + + return payload + + except jwt.ExpiredSignatureError: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Token has expired", + headers={"WWW-Authenticate": "Bearer"}, + ) + except jwt.InvalidTokenError: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Could not validate credentials", + headers={"WWW-Authenticate": "Bearer"}, + ) + + +# Example FastAPI application +app = FastAPI(title="JWT Auth Example") + + +@app.post("/auth/login", response_model=TokenResponse) +async def login(request: LoginRequest): + """ + Login endpoint - generates JWT token. + + In production: + - Validate credentials against database + - Hash password comparison + - Rate limiting + - Account lockout after failed attempts + """ + # Example: Hardcoded user (replace with database lookup) + if request.username == "admin" and request.password == "secret": + # Create token data + token_data = {"user_id": "user_123456789", "email": "admin@example.com", "role": "admin"} + + # Generate token + access_token = create_access_token(token_data) + + return TokenResponse( + access_token=access_token, expires_in=JWT_ACCESS_TOKEN_EXPIRE_MINUTES * 60 + ) + + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, detail="Incorrect username or password" + ) + + +@app.get("/auth/me") +async def get_current_user(user: dict = Depends(verify_token)): + """ + Get current user information from token. + + Protected endpoint - requires valid JWT token. 
+ """ + return {"user_id": user.get("user_id"), "email": user.get("email"), "role": user.get("role")} + + +@app.get("/protected") +async def protected_endpoint(user: dict = Depends(verify_token)): + """ + Example protected endpoint. + + This endpoint requires a valid JWT token in the Authorization header. + """ + return { + "message": f"Hello {user.get('email')}!", + "user_id": user.get("user_id"), + "timestamp": datetime.utcnow().isoformat(), + } + + +# Example usage +if __name__ == "__main__": + import uvicorn + + print("Starting JWT Auth Example Server") + print(f"Secret Key: {JWT_SECRET_KEY[:10]}... (truncated)") + print(f"Algorithm: {JWT_ALGORITHM}") + print(f"Token Expiration: {JWT_ACCESS_TOKEN_EXPIRE_MINUTES} minutes") + print("\nTest the API:") + print("1. Login: POST http://localhost:8000/auth/login") + print(' Body: {"username": "admin", "password": "secret"}') + print("2. Access: GET http://localhost:8000/protected") + print(" Header: Authorization: Bearer ") + + uvicorn.run(app, host="0.0.0.0", port=8000) + + +""" +TESTING EXAMPLE: + +# 1. Start the server +python jwt_auth_example.py + +# 2. Login to get token +curl -X POST http://localhost:8000/auth/login \ + -H "Content-Type: application/json" \ + -d '{"username": "admin", "password": "secret"}' + +# Response: +{ + "access_token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9...", + "token_type": "bearer", + "expires_in": 1800 +} + +# 3. Use token to access protected endpoint +export TOKEN="" +curl http://localhost:8000/protected \ + -H "Authorization: Bearer $TOKEN" + +# Response: +{ + "message": "Hello admin@example.com!", + "user_id": "user_123456789", + "timestamp": "2025-12-31T12:00:00" +} + +# 4. 
Get current user info +curl http://localhost:8000/auth/me \ + -H "Authorization: Bearer $TOKEN" +""" + +""" +PRODUCTION CONFIGURATION: + +# .env +JWT_SECRET_KEY= +JWT_ALGORITHM=HS256 +JWT_ACCESS_TOKEN_EXPIRE_MINUTES=30 + +# agentflow.json +{ + "auth": "jwt", + "agent": "graph.react:app" +} + +# Security Recommendations: +1. Generate strong secret: python -c "import secrets; print(secrets.token_urlsafe(32))" +2. Use environment variables for secrets (never hardcode) +3. Set short token expiration (15-30 minutes) +4. Implement refresh token mechanism for long sessions +5. Use HTTPS in production +6. Implement rate limiting on login endpoint +7. Add account lockout after failed attempts +8. Log authentication events +9. Consider token blacklist for logout +10. Rotate secrets regularly +""" diff --git a/examples/security/production_config/.env.production.example b/examples/security/production_config/.env.production.example new file mode 100644 index 0000000..2a54da1 --- /dev/null +++ b/examples/security/production_config/.env.production.example @@ -0,0 +1,134 @@ +# Production Environment Configuration +# NEVER commit this file to version control! 
+# Add .env.production to .gitignore + +# ============================================================================= +# APPLICATION MODE +# ============================================================================= +MODE=production + +# ============================================================================= +# SECURITY - JWT AUTHENTICATION +# ============================================================================= +# Generate with: python -c "import secrets; print(secrets.token_urlsafe(32))" +JWT_SECRET_KEY=REPLACE_WITH_STRONG_SECRET_KEY_32_CHARS_MINIMUM +JWT_ALGORITHM=HS256 +JWT_ACCESS_TOKEN_EXPIRE_MINUTES=30 + +# ============================================================================= +# CORS CONFIGURATION +# ============================================================================= +# IMPORTANT: Use specific domains, NOT wildcard (*) +# Comma-separated list of allowed origins +ORIGINS=https://yourdomain.com,https://app.yourdomain.com + +# ============================================================================= +# NETWORKING +# ============================================================================= +HOST=0.0.0.0 +PORT=8000 + +# Specific allowed hosts (NOT wildcard *) +ALLOWED_HOST=yourdomain.com,app.yourdomain.com + +# ============================================================================= +# API DOCUMENTATION +# ============================================================================= +# IMPORTANT: Disable API documentation in production +DOCS_PATH= +REDOCS_PATH= + +# ============================================================================= +# DEBUGGING +# ============================================================================= +# IMPORTANT: Disable debug mode in production +IS_DEBUG=false + +# ============================================================================= +# REQUEST LIMITS +# ============================================================================= +# Maximum request body size in bytes 
(default: 10MB) +MAX_REQUEST_SIZE=10485760 + +# ============================================================================= +# SECURITY HEADERS +# ============================================================================= +# Enable security headers middleware (default: true) +SECURITY_HEADERS_ENABLED=true + +# HSTS (HTTP Strict Transport Security) - only for HTTPS +HSTS_ENABLED=true +# HSTS max-age in seconds (31536000 = 1 year); kept on its own line so no parser reads the comment as part of the value +HSTS_MAX_AGE=31536000 +HSTS_INCLUDE_SUBDOMAINS=true +HSTS_PRELOAD=false + +# Other security headers (using defaults shown) +FRAME_OPTIONS=DENY +CONTENT_TYPE_OPTIONS=nosniff +XSS_PROTECTION=1; mode=block +REFERRER_POLICY=strict-origin-when-cross-origin + +# Advanced: Custom Content Security Policy (optional) +# CSP_POLICY=default-src 'self'; script-src 'self' https://cdn.example.com + +# Advanced: Custom Permissions Policy (optional) +# PERMISSIONS_POLICY=geolocation=(), microphone=(), camera=(), payment=() + +# ============================================================================= +# REDIS (Optional - for checkpointer/caching) +# ============================================================================= +# REDIS_URL=redis://redis:6379/0 +# REDIS_PASSWORD=your_redis_password + +# ============================================================================= +# DATABASE (Optional - if using persistence) +# ============================================================================= +# DATABASE_URL=postgresql://user:password@db:5432/dbname +# DATABASE_POOL_SIZE=10 +# DATABASE_MAX_OVERFLOW=20 + +# ============================================================================= +# LANGSMITH (Optional - for LangChain tracing) +# ============================================================================= +# LANGCHAIN_TRACING_V2=true +# LANGCHAIN_ENDPOINT=https://api.smith.langchain.com +# LANGCHAIN_API_KEY=your_langsmith_api_key +# LANGCHAIN_PROJECT=your_project_name + +# 
============================================================================= +# OPENAI (Optional - if using OpenAI) +# ============================================================================= +# OPENAI_API_KEY=sk-... + +# ============================================================================= +# ANTHROPIC (Optional - if using Claude) +# ============================================================================= +# ANTHROPIC_API_KEY=sk-ant-... + +# ============================================================================= +# LOGGING +# ============================================================================= +LOG_LEVEL=INFO + +# ============================================================================= +# MONITORING (Optional) +# ============================================================================= +# Sentry for error tracking +# SENTRY_DSN=https://...@sentry.io/... + +# Datadog for metrics +# DD_AGENT_HOST=datadog-agent +# DD_TRACE_ENABLED=true + +# ============================================================================= +# SNOWFLAKE ID GENERATION (Optional) +# ============================================================================= +# SNOWFLAKE_EPOCH=1704067200000 +# SNOWFLAKE_NODE_ID=1 +# SNOWFLAKE_WORKER_ID=1 + +# ============================================================================= +# CUSTOM APPLICATION SETTINGS +# ============================================================================= +# Add your application-specific environment variables here diff --git a/examples/security/production_config/README.md b/examples/security/production_config/README.md new file mode 100644 index 0000000..5c41abd --- /dev/null +++ b/examples/security/production_config/README.md @@ -0,0 +1,401 @@ +# Production Configuration Setup Guide + +This directory contains production-ready configuration examples for deploying AgentFlow CLI applications securely. 
+ +## Files Overview + +- **agentflow.json** - Production application configuration with JWT auth and RBAC +- **.env.production.example** - Complete environment variables template +- **docker-compose.yml** - Production Docker deployment with security hardening +- **nginx.conf** - Nginx reverse proxy with SSL/TLS and security headers +- **README.md** - This file + +## Quick Start + +### 1. Copy Configuration Files + +```bash +# Copy to your project root +cp agentflow.json /path/to/your/project/ +cp .env.production.example /path/to/your/project/.env.production +cp docker-compose.yml /path/to/your/project/ +cp nginx.conf /path/to/your/project/ +``` + +### 2. Configure Environment Variables + +```bash +# Edit .env.production +nano .env.production + +# Generate JWT secret +python -c "import secrets; print(secrets.token_urlsafe(32))" + +# Set the generated key in .env.production +JWT_SECRET_KEY= +``` + +### 3. Update Domain Names + +Update the following files with your domain: + +**docker-compose.yml:** +```yaml +environment: + - ORIGINS=https://yourdomain.com + - ALLOWED_HOST=yourdomain.com +``` + +**nginx.conf:** +```nginx +server_name yourdomain.com www.yourdomain.com; +``` + +### 4. SSL/TLS Certificates + +#### Option A: Let's Encrypt (Recommended) + +```bash +# Install certbot +sudo apt-get install certbot python3-certbot-nginx + +# Get certificate +sudo certbot --nginx -d yourdomain.com -d www.yourdomain.com + +# Certificates will be in: /etc/letsencrypt/live/yourdomain.com/ +``` + +#### Option B: Self-Signed (Development Only) + +```bash +# Create SSL directory +mkdir -p ssl + +# Generate self-signed certificate +openssl req -x509 -nodes -days 365 -newkey rsa:2048 \ + -keyout ssl/key.pem \ + -out ssl/cert.pem \ + -subj "/CN=yourdomain.com" +``` + +Update nginx.conf paths: +```nginx +ssl_certificate /etc/nginx/ssl/cert.pem; +ssl_certificate_key /etc/nginx/ssl/key.pem; +``` + +### 5. 
Create Authorization Backend + +Create your RBAC backend: + +```bash +# Create auth directory +mkdir -p auth + +# Copy RBAC example +cp ../rbac_authorization.py auth/rbac_backend.py +``` + +Or implement custom authorization (see examples). + +### 6. Deploy + +```bash +# Build and start services +docker-compose up -d --build + +# Check logs +docker-compose logs -f api + +# Verify health +curl https://yourdomain.com/ping +``` + +## Security Checklist + +Before going to production, verify: + +### Required Configuration +- [ ] `MODE=production` in .env.production +- [ ] Strong `JWT_SECRET_KEY` (32+ chars, random) +- [ ] `IS_DEBUG=false` +- [ ] Specific `ORIGINS` (not `*`) +- [ ] Specific `ALLOWED_HOST` (not `*`) +- [ ] `DOCS_PATH` and `REDOCS_PATH` are empty +- [ ] Valid SSL/TLS certificates installed + +### Network Security +- [ ] HTTPS enabled and working +- [ ] HTTP redirects to HTTPS +- [ ] Firewall rules configured +- [ ] Rate limiting active +- [ ] Security headers present + +### Application Security +- [ ] JWT authentication enabled +- [ ] Authorization backend implemented +- [ ] Request size limits configured +- [ ] Error messages sanitized +- [ ] Logs sanitized + +### Infrastructure +- [ ] Services run as non-root users +- [ ] Read-only filesystems where possible +- [ ] Resource limits configured +- [ ] Health checks working +- [ ] Logging configured +- [ ] Backups scheduled + +## Testing Production Configuration + +### 1. Test HTTP to HTTPS Redirect + +```bash +curl -I http://yourdomain.com +# Should return 301 redirect to https:// +``` + +### 2. Test Security Headers + +```bash +curl -I https://yourdomain.com +# Should include: +# Strict-Transport-Security +# X-Content-Type-Options +# X-Frame-Options +# Content-Security-Policy +``` + +### 3. 
Test Authentication + +```bash +# Should fail without token +curl https://yourdomain.com/graph/invoke +# Response: 401 Unauthorized + +# Should succeed with valid token +curl -H "Authorization: Bearer YOUR_JWT_TOKEN" \ + https://yourdomain.com/graph/invoke +``` + +### 4. Test Rate Limiting + +```bash +# Rapid requests should trigger rate limit +for i in {1..20}; do + curl https://yourdomain.com/ping +done +# Should eventually return 429 Too Many Requests +``` + +### 5. Test Request Size Limit + +```bash +# Large request should be rejected +dd if=/dev/zero bs=1M count=15 | \ + curl -X POST https://yourdomain.com/graph/invoke \ + -H "Authorization: Bearer YOUR_JWT_TOKEN" \ + -H "Content-Type: application/json" \ + --data-binary @- +# Should return 413 Request Entity Too Large +``` + +## Monitoring + +### View Logs + +```bash +# API logs +docker-compose logs -f api + +# Nginx logs +docker-compose logs -f nginx + +# Redis logs +docker-compose logs -f redis + +# All logs +docker-compose logs -f +``` + +### Check Service Health + +```bash +# All services +docker-compose ps + +# API health +curl https://yourdomain.com/ping + +# Redis health +docker-compose exec redis redis-cli ping +``` + +### Resource Usage + +```bash +# Container stats +docker stats + +# Disk usage +docker system df +``` + +## Backup and Recovery + +### Backup Redis Data + +```bash +# Create backup +docker-compose exec redis redis-cli BGSAVE + +# Copy backup file +docker cp $(docker-compose ps -q redis):/data/dump.rdb ./backup/ +``` + +### Restore Redis Data + +```bash +# Stop services (use `stop`, not `down`: `down` removes the container that `docker cp` needs) +docker-compose stop + +# Restore backup +docker cp ./backup/dump.rdb $(docker-compose ps -a -q redis):/data/ + +# Start services +docker-compose up -d +``` + +## Scaling + +### Horizontal Scaling + +Add multiple API instances: + +**docker-compose.yml:** +```yaml +api: + deploy: + replicas: 3 +``` + +Update nginx upstream: +```nginx +upstream api_backend { + least_conn; + # `api` is the Compose service name; Docker DNS resolves it to all replicas. + # Reload nginx after changing the replica count. + server api:8000; +} +``` + +### Vertical 
Scaling + +Adjust resource limits: + +```yaml +deploy: + resources: + limits: + cpus: '4' + memory: 4G + reservations: + cpus: '2' + memory: 1G +``` + +## Troubleshooting + +### Issue: Services won't start + +```bash +# Check logs +docker-compose logs + +# Check configuration +docker-compose config + +# Validate environment +docker-compose exec api env +``` + +### Issue: SSL certificate errors + +```bash +# Verify certificates +openssl x509 -in ssl/cert.pem -text -noout + +# Test SSL configuration +openssl s_client -connect yourdomain.com:443 +``` + +### Issue: Rate limiting too strict + +Adjust nginx.conf: +```nginx +# Increase rate +limit_req_zone $binary_remote_addr zone=api:10m rate=20r/s; + +# Increase burst +limit_req zone=api burst=50 nodelay; +``` + +### Issue: High memory usage + +```bash +# Check memory usage +docker stats + +# Adjust limits +# Edit docker-compose.yml resources section +``` + +## Security Maintenance + +### Regular Tasks + +**Daily:** +- Monitor logs for suspicious activity +- Check service health + +**Weekly:** +- Review access logs +- Update rate limiting rules if needed +- Check resource usage trends + +**Monthly:** +- Rotate JWT secrets +- Update dependencies +- Review authorization rules +- Security audit + +**Quarterly:** +- SSL certificate renewal (if not auto-renewed) +- Penetration testing +- Security policy review + +### Update Dependencies + +```bash +# Pull latest images +docker-compose pull + +# Rebuild with latest base images +docker-compose build --no-cache + +# Restart services +docker-compose up -d +``` + +## Additional Resources + +- [SECURITY.md](../../../SECURITY.md) - Complete security guide +- [Deployment Guide](../../../docs/deployment.md) - Detailed deployment documentation +- [Configuration Guide](../../../docs/configuration.md) - All configuration options + +## Support + +For issues or questions: +- GitHub Issues: https://github.com/10xHub/agentflow-cli/issues +- Email: security@10xhub.com +- Documentation: 
https://10xhub.com/docs diff --git a/examples/security/production_config/agentflow.json b/examples/security/production_config/agentflow.json new file mode 100644 index 0000000..9c896c9 --- /dev/null +++ b/examples/security/production_config/agentflow.json @@ -0,0 +1,13 @@ +{ + "agent": "graph.react:app", + "env": ".env.production", + "auth": "jwt", + "authorization": { + "path": "auth.rbac_backend:RBACAuthorizationBackend" + }, + "checkpointer": null, + "injectq": null, + "store": null, + "redis": null, + "thread_name_generator": null +} diff --git a/examples/security/production_config/docker-compose.yml b/examples/security/production_config/docker-compose.yml new file mode 100644 index 0000000..4682fc7 --- /dev/null +++ b/examples/security/production_config/docker-compose.yml @@ -0,0 +1,120 @@ +version: "3.8" + +services: + # Main API service + api: + build: + context: ../../../.. + dockerfile: Dockerfile + ports: + - "8000:8000" + environment: + # Load from .env.production file + - MODE=production + - JWT_SECRET_KEY=${JWT_SECRET_KEY} + - ORIGINS=${ORIGINS} + - ALLOWED_HOST=${ALLOWED_HOST} + - IS_DEBUG=false + - DOCS_PATH= + - REDOCS_PATH= + - MAX_REQUEST_SIZE=10485760 + - REDIS_URL=redis://redis:6379/0 + env_file: + - .env.production + depends_on: + redis: + condition: service_healthy + restart: unless-stopped + healthcheck: + test: ["CMD", "curl", "-f", "http://localhost:8000/ping"] + interval: 30s + timeout: 10s + retries: 3 + start_period: 40s + # Security: Read-only root filesystem + read_only: true + tmpfs: + - /tmp + # Security: Drop all capabilities + cap_drop: + - ALL + cap_add: + - NET_BIND_SERVICE + # Resource limits + deploy: + resources: + limits: + cpus: "2" + memory: 2G + reservations: + cpus: "1" + memory: 512M + networks: + - app-network + logging: + driver: "json-file" + options: + max-size: "10m" + max-file: "3" + + # Redis for caching/checkpointing + redis: + image: redis:7-alpine + restart: unless-stopped + volumes: + - redis_data:/data + 
healthcheck: + test: ["CMD", "redis-cli", "ping"] + interval: 10s + timeout: 3s + retries: 3 + # Security: Run as non-root user + user: "999:999" + # Resource limits + deploy: + resources: + limits: + cpus: "0.5" + memory: 512M + networks: + - app-network + # Security: No external port exposure (only internal) + # Only accessible from api service + + # Nginx reverse proxy with SSL/TLS + nginx: + image: nginx:alpine + ports: + - "443:443" + - "80:80" + volumes: + - ./nginx.conf:/etc/nginx/nginx.conf:ro + - ./ssl:/etc/nginx/ssl:ro + - nginx_cache:/var/cache/nginx:rw + depends_on: + api: + condition: service_healthy + restart: unless-stopped + networks: + - app-network + # Resource limits + deploy: + resources: + limits: + cpus: "0.5" + memory: 256M + logging: + driver: "json-file" + options: + max-size: "10m" + max-file: "3" + +volumes: + redis_data: + driver: local + nginx_cache: + driver: local + +networks: + app-network: + driver: bridge diff --git a/examples/security/production_config/nginx.conf b/examples/security/production_config/nginx.conf new file mode 100644 index 0000000..bb05e91 --- /dev/null +++ b/examples/security/production_config/nginx.conf @@ -0,0 +1,188 @@ +# Nginx Configuration for AgentFlow CLI - Production +# This configuration includes security headers, rate limiting, and SSL/TLS + +events { + worker_connections 1024; +} + +http { + include /etc/nginx/mime.types; + default_type application/octet-stream; + + # Logging + log_format main '$remote_addr - $remote_user [$time_local] "$request" ' + '$status $body_bytes_sent "$http_referer" ' + '"$http_user_agent" "$http_x_forwarded_for"'; + + access_log /var/log/nginx/access.log main; + error_log /var/log/nginx/error.log warn; + + # Performance + sendfile on; + tcp_nopush on; + tcp_nodelay on; + keepalive_timeout 65; + types_hash_max_size 2048; + + # Hide nginx version + server_tokens off; + + # Rate limiting zones + limit_req_zone $binary_remote_addr zone=api:10m rate=10r/s; + limit_req_zone 
$binary_remote_addr zone=auth:10m rate=5r/s; + limit_req_zone $binary_remote_addr zone=general:10m rate=100r/s; + + # Connection limiting + limit_conn_zone $binary_remote_addr zone=addr:10m; + + # Upstream backend + upstream api_backend { + server api:8000 max_fails=3 fail_timeout=30s; + keepalive 32; + } + + # Redirect HTTP to HTTPS + server { + listen 80; + listen [::]:80; + server_name yourdomain.com www.yourdomain.com; + + # Allow Let's Encrypt verification + location /.well-known/acme-challenge/ { + root /var/www/certbot; + } + + # Redirect all other traffic to HTTPS + location / { + return 301 https://$server_name$request_uri; + } + } + + # HTTPS server + server { + listen 443 ssl http2; + listen [::]:443 ssl http2; + server_name yourdomain.com www.yourdomain.com; + + # SSL Configuration + ssl_certificate /etc/nginx/ssl/cert.pem; + ssl_certificate_key /etc/nginx/ssl/key.pem; + + # Strong SSL protocols and ciphers + ssl_protocols TLSv1.2 TLSv1.3; + ssl_ciphers 'ECDHE-ECDSA-AES128-GCM-SHA256:ECDHE-RSA-AES128-GCM-SHA256:ECDHE-ECDSA-AES256-GCM-SHA384:ECDHE-RSA-AES256-GCM-SHA384'; + ssl_prefer_server_ciphers on; + + # SSL session cache + ssl_session_cache shared:SSL:10m; + ssl_session_timeout 10m; + ssl_session_tickets off; + + # OCSP stapling + ssl_stapling on; + ssl_stapling_verify on; + resolver 8.8.8.8 8.8.4.4 valid=300s; + resolver_timeout 5s; + + # Security Headers + add_header Strict-Transport-Security "max-age=31536000; includeSubDomains; preload" always; + add_header X-Content-Type-Options "nosniff" always; + add_header X-Frame-Options "DENY" always; + add_header X-XSS-Protection "1; mode=block" always; + add_header Content-Security-Policy "default-src 'self'; script-src 'self' 'unsafe-inline'; style-src 'self' 'unsafe-inline'; img-src 'self' data: https:; font-src 'self' data:; connect-src 'self'; frame-ancestors 'none';" always; + add_header Referrer-Policy "strict-origin-when-cross-origin" always; + add_header Permissions-Policy "geolocation=(), 
microphone=(), camera=()" always; + + # Maximum request body size (aligned with API settings) + client_max_body_size 10M; + + # Timeouts + proxy_connect_timeout 60s; + proxy_send_timeout 60s; + proxy_read_timeout 60s; + + # Connection limits + limit_conn addr 10; + + # Health check endpoint (no rate limiting) + location /ping { + proxy_pass http://api_backend; + proxy_set_header Host $host; + proxy_set_header X-Real-IP $remote_addr; + proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; + proxy_set_header X-Forwarded-Proto $scheme; + + access_log off; + } + + # Authentication endpoints (strict rate limiting) + location ~ ^/auth/ { + limit_req zone=auth burst=10 nodelay; + + proxy_pass http://api_backend; + proxy_set_header Host $host; + proxy_set_header X-Real-IP $remote_addr; + proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; + proxy_set_header X-Forwarded-Proto $scheme; + proxy_set_header X-Request-ID $request_id; + } + + # API endpoints (moderate rate limiting) + location ~ ^/(graph|checkpointer|store)/ { + limit_req zone=api burst=20 nodelay; + + proxy_pass http://api_backend; + proxy_set_header Host $host; + proxy_set_header X-Real-IP $remote_addr; + proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; + proxy_set_header X-Forwarded-Proto $scheme; + proxy_set_header X-Request-ID $request_id; + proxy_http_version 1.1; + proxy_set_header Connection ""; + + # Disable buffering for streaming endpoints + proxy_buffering off; + proxy_cache_bypass $http_upgrade; + } + + # All other endpoints + location / { + limit_req zone=general burst=50 nodelay; + + proxy_pass http://api_backend; + proxy_set_header Host $host; + proxy_set_header X-Real-IP $remote_addr; + proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; + proxy_set_header X-Forwarded-Proto $scheme; + proxy_set_header X-Request-ID $request_id; + proxy_http_version 1.1; + proxy_set_header Connection ""; + } + + # Block access to sensitive files + location ~ /\. 
{ + deny all; + access_log off; + log_not_found off; + } + + location ~ \.(env|json|yml|yaml)$ { + deny all; + access_log off; + log_not_found off; + } + + # Error pages + error_page 502 503 504 /50x.html; + location = /50x.html { + root /usr/share/nginx/html; + internal; + } + + error_page 429 /429.html; + location = /429.html { + root /usr/share/nginx/html; + internal; + } + } +} diff --git a/examples/security/rbac_authorization.py b/examples/security/rbac_authorization.py new file mode 100644 index 0000000..363504f --- /dev/null +++ b/examples/security/rbac_authorization.py @@ -0,0 +1,411 @@ +""" +Role-Based Access Control (RBAC) Authorization Example + +This example demonstrates how to implement RBAC authorization in AgentFlow CLI. + +RBAC Model: +- Roles: admin, developer, viewer +- Permissions: Defined per resource and action +- Assignment: Users have one or more roles + +Setup: +1. Create this file in your project (e.g., auth/rbac_backend.py) +2. Configure agentflow.json to use this backend +3. Ensure user context includes 'role' field +""" + +from typing import Any + +# Import the authorization backend interface +# In your project: from agentflow_cli.src.app.core.auth.authorization import AuthorizationBackend +from agentflow_cli.src.app.core.auth.authorization import AuthorizationBackend + + +class RBACAuthorizationBackend(AuthorizationBackend): + """ + Role-Based Access Control (RBAC) authorization backend. + + Defines permissions for different roles across all resources. 
+ """ + + # Permission matrix: role -> resource -> [actions] + PERMISSIONS = { + "admin": { + # Admins have full access to everything + "graph": ["invoke", "stream", "read", "stop", "setup", "fix"], + "checkpointer": ["read", "write", "delete"], + "store": ["read", "write", "delete", "forget"], + }, + "developer": { + # Developers can execute graphs and manage data + "graph": ["invoke", "stream", "read", "setup"], + "checkpointer": ["read", "write"], + "store": ["read", "write"], + }, + "viewer": { + # Viewers have read-only access + "graph": ["read"], + "checkpointer": ["read"], + "store": ["read"], + }, + "guest": { + # Guests have minimal access + "graph": [], + "checkpointer": [], + "store": [], + }, + } + + def __init__(self): + """Initialize RBAC backend.""" + print("RBAC Authorization Backend initialized") + print(f"Configured roles: {', '.join(self.PERMISSIONS.keys())}") + + async def authorize( + self, + user: dict[str, Any], + resource: str, + action: str, + resource_id: str | None = None, + **context, + ) -> bool: + """ + Check if user's role has permission for the requested action. 
+ + Args: + user: User context (must include 'role' field) + resource: Resource being accessed (e.g., 'graph', 'checkpointer') + action: Action being performed (e.g., 'invoke', 'read', 'write') + resource_id: Optional specific resource identifier + **context: Additional context for authorization decision + + Returns: + True if authorized, False otherwise + """ + # Extract user information + user_id = user.get("user_id", "unknown") + role = user.get("role", "guest") # Default to 'guest' if no role + + # Get permissions for the user's role + role_permissions = self.PERMISSIONS.get(role, {}) + + # Get allowed actions for the resource + allowed_actions = role_permissions.get(resource, []) + + # Check if action is permitted + is_authorized = action in allowed_actions + + # Log authorization decision + if is_authorized: + print(f"✓ Authorization granted: {user_id} ({role}) can {action} on {resource}") + else: + print(f"✗ Authorization denied: {user_id} ({role}) cannot {action} on {resource}") + print(f" Allowed actions: {allowed_actions}") + + return is_authorized + + +class HierarchicalRBACBackend(AuthorizationBackend): + """ + Hierarchical RBAC with role inheritance. 
+ + Roles can inherit permissions from parent roles: + admin > developer > viewer > guest + """ + + # Define role hierarchy (child -> parent) + ROLE_HIERARCHY = { + "admin": ["developer", "viewer", "guest"], + "developer": ["viewer", "guest"], + "viewer": ["guest"], + "guest": [], + } + + # Base permissions per role + BASE_PERMISSIONS = { + "admin": { + "graph": ["stop", "fix"], # Admin-only actions + "checkpointer": ["delete"], + "store": ["delete", "forget"], + }, + "developer": { + "graph": ["invoke", "stream", "setup"], + "checkpointer": ["write"], + "store": ["write"], + }, + "viewer": {"graph": ["read"], "checkpointer": ["read"], "store": ["read"]}, + "guest": { + # No base permissions + }, + } + + def __init__(self): + """Initialize hierarchical RBAC backend.""" + print("Hierarchical RBAC Authorization Backend initialized") + + def get_all_permissions(self, role: str) -> dict[str, list[str]]: + """ + Get all permissions for a role, including inherited ones. + + Args: + role: User's role + + Returns: + Dictionary of resource -> actions + """ + permissions: dict[str, list[str]] = {} + + # Get base permissions for this role + base_perms = self.BASE_PERMISSIONS.get(role, {}) + for resource, actions in base_perms.items(): + permissions.setdefault(resource, []).extend(actions) + + # Inherit permissions from parent roles + parent_roles = self.ROLE_HIERARCHY.get(role, []) + for parent_role in parent_roles: + parent_perms = self.BASE_PERMISSIONS.get(parent_role, {}) + for resource, actions in parent_perms.items(): + permissions.setdefault(resource, []).extend(actions) + + # Remove duplicates + for resource in permissions: + permissions[resource] = list(set(permissions[resource])) + + return permissions + + async def authorize( + self, + user: dict[str, Any], + resource: str, + action: str, + resource_id: str | None = None, + **context, + ) -> bool: + """Check authorization with role hierarchy.""" + role = user.get("role", "guest") + user_id = user.get("user_id", 
"unknown") + + # Get all permissions (including inherited) + all_permissions = self.get_all_permissions(role) + allowed_actions = all_permissions.get(resource, []) + + is_authorized = action in allowed_actions + + if not is_authorized: + print(f"✗ Authorization denied: {user_id} ({role}) cannot {action} on {resource}") + + return is_authorized + + +class MultiRoleRBACBackend(AuthorizationBackend): + """ + RBAC backend supporting multiple roles per user. + + Users can have multiple roles, and are authorized if ANY role + grants the required permission. + """ + + PERMISSIONS = { + "admin": { + "graph": ["invoke", "stream", "read", "stop", "setup", "fix"], + "checkpointer": ["read", "write", "delete"], + "store": ["read", "write", "delete", "forget"], + }, + "developer": { + "graph": ["invoke", "stream", "read", "setup"], + "checkpointer": ["read", "write"], + "store": ["read", "write"], + }, + "viewer": {"graph": ["read"], "checkpointer": ["read"], "store": ["read"]}, + } + + async def authorize( + self, + user: dict[str, Any], + resource: str, + action: str, + resource_id: str | None = None, + **context, + ) -> bool: + """ + Check if any of user's roles grant permission. + + User context should include 'roles' as a list. 
+ """ + user_id = user.get("user_id", "unknown") + + # Support both single role and multiple roles + roles = user.get("roles", []) + if not roles: + single_role = user.get("role") + if single_role: + roles = [single_role] + + if not roles: + print(f"✗ No roles found for user {user_id}") + return False + + # Check if any role grants permission + for role in roles: + role_permissions = self.PERMISSIONS.get(role, {}) + allowed_actions = role_permissions.get(resource, []) + + if action in allowed_actions: + print(f"✓ Authorization granted: {user_id} ({role}) can {action} on {resource}") + return True + + print(f"✗ Authorization denied: {user_id} {roles} cannot {action} on {resource}") + return False + + +""" +CONFIGURATION: + +# agentflow.json +{ + "auth": "jwt", + "authorization": { + "path": "auth.rbac_backend:RBACAuthorizationBackend" + }, + "agent": "graph.react:app" +} + +# For hierarchical RBAC +{ + "authorization": { + "path": "auth.rbac_backend:HierarchicalRBACBackend" + } +} + +# For multi-role RBAC +{ + "authorization": { + "path": "auth.rbac_backend:MultiRoleRBACBackend" + } +} +""" + +""" +USER CONTEXT EXAMPLES: + +# Single role +{ + "user_id": "user_123", + "email": "john@example.com", + "role": "developer" +} + +# Multiple roles +{ + "user_id": "user_456", + "email": "admin@example.com", + "roles": ["admin", "developer"] +} + +# JWT token payload +{ + "user_id": "user_789", + "email": "viewer@example.com", + "role": "viewer", + "exp": 1735689600 +} +""" + +""" +TESTING EXAMPLES: + +# tests/test_rbac_authorization.py +import pytest +from auth.rbac_backend import RBACAuthorizationBackend + +@pytest.fixture +def rbac_backend(): + return RBACAuthorizationBackend() + +@pytest.mark.asyncio +async def test_admin_full_access(rbac_backend): + user = {"user_id": "admin1", "role": "admin"} + + # Admin can do everything + assert await rbac_backend.authorize(user, "graph", "invoke") + assert await rbac_backend.authorize(user, "graph", "stop") + assert await 
rbac_backend.authorize(user, "checkpointer", "delete") + +@pytest.mark.asyncio +async def test_developer_limited_access(rbac_backend): + user = {"user_id": "dev1", "role": "developer"} + + # Developer can invoke but not stop + assert await rbac_backend.authorize(user, "graph", "invoke") + assert not await rbac_backend.authorize(user, "graph", "stop") + +@pytest.mark.asyncio +async def test_viewer_read_only(rbac_backend): + user = {"user_id": "viewer1", "role": "viewer"} + + # Viewer can only read + assert await rbac_backend.authorize(user, "graph", "read") + assert not await rbac_backend.authorize(user, "graph", "invoke") + assert not await rbac_backend.authorize(user, "checkpointer", "write") + +@pytest.mark.asyncio +async def test_guest_no_access(rbac_backend): + user = {"user_id": "guest1", "role": "guest"} + + # Guest has no access + assert not await rbac_backend.authorize(user, "graph", "read") + assert not await rbac_backend.authorize(user, "graph", "invoke") +""" + +""" +INTEGRATION WITH ENDPOINTS: + +All AgentFlow endpoints automatically use the configured authorization backend. 
+ +# Example: Graph invocation endpoint +@router.post("/graph/invoke") +async def invoke_graph( + user: dict = Depends(RequirePermission("graph", "invoke")), + request: GraphRequest +): + # User is authenticated and authorized + # Role-based access control is enforced + # Only admin and developer roles can reach this point + pass + +# Example: Checkpointer delete endpoint +@router.delete("/checkpointer/thread/{thread_id}") +async def delete_thread( + thread_id: str, + user: dict = Depends(RequirePermission("checkpointer", "delete")) +): + # Only admin role can delete threads + pass +""" + +""" +CUSTOMIZATION: + +# Add custom roles +PERMISSIONS["analyst"] = { + "graph": ["read"], + "checkpointer": ["read"], + "store": ["read", "write"] +} + +# Add custom resources +PERMISSIONS["admin"]["reports"] = ["read", "write", "delete"] + +# Context-aware permissions +async def authorize(self, user, resource, action, resource_id=None, **context): + # Check time-based restrictions + if context.get("after_hours"): + return user.get("role") == "admin" + + # Check IP-based restrictions + if context.get("external_ip"): + return action == "read" + + # Standard RBAC check + return await super().authorize(user, resource, action, resource_id, **context) +""" diff --git a/plan.md b/plan.md new file mode 100644 index 0000000..08cc814 --- /dev/null +++ b/plan.md @@ -0,0 +1,190 @@ +# Dev Playground Plan for Agentflow CLI + +## Summary ✅ +Add a CLI feature (e.g., `agentflow api --open` or `agentflow dev`) that: +- Starts the development FastAPI server (uvicorn) on the selected host/port. +- Serves a simple interactive `playground` UI and any related static assets at `/playground`. +- Automatically opens a single browser window/tab to the `playground` URL when the server is ready. + +This improves the developer DX by surfacing an interactive dev playground on `dev` runs in a single, consistent UX. 
+ +--- + +## Goals & Requirements 🎯 +- Add a CLI flag to launch the playground automatically (default: opt-in: `--open` or `dev` default true). +- Serve a lightweight, local playground UI that calls the API endpoints. +- Ensure the browser opens only once when running the CLI (avoid multiple windows/tabs even while using `--reload`). +- Avoid opening browser in headless CI or tests. +- Keep default behavior unchanged (no production browser opens). + +--- + +## Implementation Options (Pros & Cons) 🔧 +Option A — Minimal (Recommended): +- Add `--open` flag to the existing `api` CLI command (or create an alias `dev`). +- Add a small static HTML/JS playground at `src/app/static/playground/index.html`. +- Add a FastAPI route `/playground` to serve the playground or configure `StaticFiles` middleware. +- Launch uvicorn programmatically via `uvicorn.Server` or start it in a background thread, and once the server responds to `/ping`, open `webbrowser.open(url)`. This avoids opening before the server is ready. + +Pros: Direct, minimal, leverages existing FastAPI app and uvicorn. +Cons: Must ensure it doesn't open on reloader spawn; simple logic required to open only once. + +Option B — Use Swagger / ReDoc (Quick & Safe): +- Reuse existing `docs_url` or `redoc_url` (e.g., open `/docs`). + +Pros: No new routes or frontend required. +Cons: Not a custom playground UI. + +Option C — Use a local `PyWebView` (Desktop app) or native webview. +- Spawn a desktop window embedding the playground. More complex and cross-platform issues. + +Pros: Single top-level window without the browser. +Cons: Extra dependencies; not necessary for a simple developer experience. + +--- + +## Technical Considerations and Constraints ⚠️ +- Uvicorn reload behavior: when `--reload` is enabled it spawns a new server process on code changes. Opening the browser should happen only once, not on each worker spawn. 
Avoid re-opening by either: + - Opening the browser from the CLI thread after the server responds (preferred). Or, + - Use environment checks to ensure only the main process opens the browser. + +- Polling for readiness: Use the `/ping` endpoint to perform a health check until the server responds (timeout e.g., 5s) then open the browser. + +- Headless / CI environments: Skip opening the browser if the `CI` or `GITHUB_ACTIONS` env var is set, or if `DISPLAY` is not present (Linux headless) — allow opt-in overriding. + +- Cross platform `webbrowser` support: the Python `webbrowser` module is cross-platform — use `webbrowser.open(url, new=0)`. + +- Avoid opening multiple tabs/windows: Use the browser API `open(..., new=0)` to reuse an existing browser window where possible, and only call once. + +- Security: Keep the playground local, do not expose sensitive resources; add guard to serve playground if `settings.MODE == "DEVELOPMENT"` or explicit config. + +--- + +## CLI UX Design ✨ +- Preferred command: `agentflow dev` (alias to `agentflow api --open`) +- Example usage: + - `agentflow dev` — start dev server + open playground (default) + - `agentflow api --open` — start server + open playground + - `agentflow api` — start server without browser (existing behavior) + - `agentflow dev --no-open` — start server without browser + +Flags: +- `--open / --no-open` or `--playground / --no-playground`. +- `--host`, `--port`, `--reload` retain compatibility. + +--- + +## Proposed Implementation Steps (High-level) 🛠️ +1. Add CLI option and command alias + - Add a new CLI `dev` command in `agentflow_cli/cli/commands/` as a thin wrapper around the existing `api` command or add `--open` flag to `APICommand.execute`. + - Example: `agentflow_cli/cli/commands/api.py`: add `open_playground: bool = False` to signature. + +2. Playground route & static assets + - Add a new router `playground.router` in `src/app/routers` exposing `/playground`. + - Add a `static` directory (e.g. 
`agentflow_cli/src/app/static/playground/`) with `index.html` and optionally `app.js` and CSS. + - Use `FastAPI`'s `StaticFiles` or `HTMLResponse`/`Jinja2Templates` to serve. + - Register the router conditionally when `settings.MODE == 'DEVELOPMENT'` or `settings.PLAYGROUND_ENABLED`. + +3. Browser opening logic + - Update `APICommand` to accept `open_playground`. + - Option A (recommended): Launch the server programmatically: `uvicorn.Server(config)` in a thread, poll `http://host:port/ping` for readiness, then open the browser. + - Option B: If staying with `uvicorn.run()` (blocking), we can start a small thread to poll and open the browser. + +Example pseudo-code sketch: +```python +# inside APICommand.execute +if open_playground and not is_ci(): + url = f"http://{host}:{port}/playground" + t = threading.Thread(target=_wait_and_open, args=(url, host, port), daemon=True) + t.start() + +uvicorn.run(...) + +# helper +def _wait_and_open(url, host, port): + for _ in range(50): + try: + r = requests.get(f"http://{host}:{port}/ping", timeout=0.5) + if r.ok: + webbrowser.open(url, new=0) + return + except Exception: + time.sleep(0.1) + +``` + +4. Unit tests + - Mock `webbrowser.open` and `requests` to assert the open call when `--open` is used. + - Mock `uvicorn.run` to avoid actually starting the server during unit tests. + +5. Integration tests (optional) + - Spin up app via `TestClient` or start a background uvicorn instance, call `/playground` to ensure route returns `200`. + +6. Documentation + - Update `README.md` and `docs` describing `dev` command and how to use it. + +--- + +## How to implement single-browser-session behavior (single tab) 🚪 +- Pass `new=0` to `webbrowser.open` to attempt to reuse a window/tab. +- Add a small CLI-local sentinel — write the url or timestamp into a tempfile (e.g., `/tmp/agentflow_playground_{port}.txt`) and only open the browser if that sentinel is missing or stale (e.g., older than 15 seconds). 
This avoids re-opening on reload events. +- The sentinel approach isn't bulletproof but simplifies logic and avoids frequent reopens. + +--- + +## Edge Cases & Tests 🧪 +- CI environments: Should not open the browser. Detect CI by `CI` env var or skip/guard. +- Headless Linux: If `DISPLAY` is absent and not WSL, skip opening. +- `--reload` (dev reload): Avoid re-opening on each reload by keeping a sentinel or ensuring the open call fires only once. +- Port in use: If the port is already in use, exit gracefully with a clear error message. + +--- + +## Acceptance Criteria ✅ +- [ ] A `dev` command or `--open` flag is present and documented. +- [ ] Launcher opens browser to `/playground` only once at initial start. +- [ ] Playground UI is served at `/playground` and works as expected (basic testing endpoints). +- [ ] CLI unit tests verify the browser open behavior (using mocking); integration tests verify route serves content. + +--- + +## Optional Future Enhancements 💡 +- Toggle auto-open link per environment variable or per `setup.cfg`/`settings` for developer preference. +- Build a more robust playground UI with OAuth injection or virtual keys (with caution). +- Option to open a standalone desktop playground (PyWebView) for native UX. + +--- + +## Resources & Links (Research) 📚 +- Python webbrowser: https://docs.python.org/3/library/webbrowser.html +- Uvicorn reload and process/child detection: https://www.uvicorn.org/ +- Starting uvicorn programmatically: https://www.uvicorn.org/deployment/#programmatically +- FastAPI static files & HTML responses: https://fastapi.tiangolo.com/tutorial/static-files/ +- Example implementations that open a browser on startup: many projects use `webbrowser.open` plus a health check. + + +--- + +## Next Steps (Implementation action list) 📝 +- [ ] Add `open_playground` flag to `APICommand` (or add `dev` command). +- [ ] Add `playground` route & static files. +- [ ] Add readiness poll + `webbrowser.open` into CLI logic. 
+- [ ] Add unit tests for CLI open behavior and route tests for `/playground`. +- [ ] Document usage and examples in `README.md`. + + +--- + +Appendix: Example CLI usage +```bash +# Start the server and open playground (default on dev) +agentflow dev + +# Start API server without opening browser +agentflow api --no-open + +# Start API server and explicitly open +agentflow api --open +``` + +If you want, I can follow up by implementing a minimal PR that adds `--open` and a tiny `playground` route (HTML + small JS) and tests. Let me know which option you prefer (A: `dev` alias default true, B: `--open` opt-in for `api`), and I will start implementing it. \ No newline at end of file diff --git a/test.py b/test.py deleted file mode 100644 index e69de29..0000000 diff --git a/tests/unit_tests/test_error_sanitization.py b/tests/unit_tests/test_error_sanitization.py new file mode 100644 index 0000000..3f00ab4 --- /dev/null +++ b/tests/unit_tests/test_error_sanitization.py @@ -0,0 +1,220 @@ +"""Unit tests for error message sanitization.""" + +from unittest.mock import patch + +import pytest +from fastapi import FastAPI +from fastapi.testclient import TestClient +from starlette.exceptions import HTTPException + +from agentflow_cli.src.app.core.exceptions.handle_errors import ( + _sanitize_error_message, + init_errors_handler, +) + + +def test_sanitize_error_message_in_production(): + """Test that error messages are sanitized in production.""" + # Test various error codes + assert ( + _sanitize_error_message("Detailed internal error", "GRAPH_000", is_production=True) + == "An error occurred executing the graph." + ) + + assert ( + _sanitize_error_message( + "Database connection failed at 192.168.1.100", "STORAGE_001", is_production=True + ) + == "An error occurred accessing storage." + ) + + assert ( + _sanitize_error_message( + "Invalid field: user.password", "VALIDATION_ERROR", is_production=True + ) + == "The request data is invalid. Please check your input." 
+ ) + + +def test_sanitize_error_message_in_development(): + """Test that error messages are not sanitized in development.""" + detailed_message = "Detailed internal error with stack trace" + + result = _sanitize_error_message(detailed_message, "GRAPH_000", is_production=False) + + assert result == detailed_message + + +def test_sanitize_unknown_error_code(): + """Test sanitization of unknown error codes.""" + result = _sanitize_error_message("Some error", "UNKNOWN_ERROR_CODE", is_production=True) + + assert result == "An unexpected error occurred. Please contact support." + + +@pytest.fixture +def app_with_error_handlers(): + """Create a FastAPI app with error handlers.""" + app = FastAPI() + + # Add request ID middleware + from agentflow_cli.src.app.core.config.setup_middleware import RequestIDMiddleware + + app.add_middleware(RequestIDMiddleware) + + init_errors_handler(app) + + @app.get("/test-error") + async def test_error(): + raise HTTPException(status_code=500, detail="Internal server error with details") + + @app.get("/test-validation") + async def test_validation(): + raise ValueError("Invalid input: password field missing") + + return app + + +def test_http_exception_sanitized_in_production(): + """Test that HTTP exceptions are sanitized in production.""" + import os + + # Set environment before importing + os.environ["MODE"] = "production" + + # Clear settings cache to pick up new environment + from agentflow_cli.src.app.core.config.settings import get_settings + + get_settings.cache_clear() + + from fastapi import FastAPI + from fastapi.testclient import TestClient + from agentflow_cli.src.app.core.exceptions.handle_errors import init_errors_handler + from agentflow_cli.src.app.core.config.setup_middleware import RequestIDMiddleware + + app = FastAPI() + app.add_middleware(RequestIDMiddleware) + init_errors_handler(app) + + @app.get("/test") + async def test_endpoint(): + raise HTTPException(status_code=500, detail="Internal server error with details") 
+ + client = TestClient(app) + response = client.get("/test") + + assert response.status_code == 500 + error_data = response.json() + + # Message should be generic in production + assert error_data["error"]["message"] == "An error occurred processing your request." + + # Cleanup + del os.environ["MODE"] + get_settings.cache_clear() + + +def test_http_exception_detailed_in_development(): + """Test that HTTP exceptions show details in development.""" + import os + + # Set environment before importing + os.environ["MODE"] = "development" + + from fastapi import FastAPI + from fastapi.testclient import TestClient + from agentflow_cli.src.app.core.exceptions.handle_errors import init_errors_handler + from agentflow_cli.src.app.core.config.setup_middleware import RequestIDMiddleware + + # Clear settings cache + from agentflow_cli.src.app.core.config.settings import get_settings + + get_settings.cache_clear() + + app = FastAPI() + app.add_middleware(RequestIDMiddleware) + init_errors_handler(app) + + @app.get("/test") + async def test_endpoint(): + raise HTTPException(status_code=500, detail="Internal server error with details") + + client = TestClient(app) + response = client.get("/test") + + assert response.status_code == 500 + error_data = response.json() + + # In development, should show detailed message + assert "Internal server error with details" in error_data["error"]["message"] + + # Cleanup + del os.environ["MODE"] + get_settings.cache_clear() + + +def test_validation_error_sanitized_in_production(): + """Test that validation errors are sanitized in production.""" + import os + + # Set environment before importing + os.environ["MODE"] = "production" + + from fastapi import FastAPI + from fastapi.testclient import TestClient + from agentflow_cli.src.app.core.exceptions.handle_errors import init_errors_handler + from agentflow_cli.src.app.core.config.setup_middleware import RequestIDMiddleware + from agentflow_cli.src.app.core.config.settings import get_settings + 
+ # Clear cache + get_settings.cache_clear() + + app = FastAPI() + app.add_middleware(RequestIDMiddleware) + init_errors_handler(app) + + @app.get("/test") + async def test_endpoint(): + raise ValueError("Invalid input: password field missing") + + client = TestClient(app) + response = client.get("/test") + + assert response.status_code == 422 + error_data = response.json() + + # Should show generic message in production + assert error_data["error"]["message"] == "Invalid input provided." + + # Cleanup + del os.environ["MODE"] + + +def test_error_response_includes_request_id(app_with_error_handlers): + """Test that error responses include request ID.""" + client = TestClient(app_with_error_handlers) + response = client.get("/test-error") + + assert "metadata" in response.json() + assert "request_id" in response.json()["metadata"] + + +def test_all_error_code_prefixes_covered(): + """Test that all major error code prefixes have generic messages.""" + error_codes = [ + "VALIDATION_ERROR", + "GRAPH_000", + "NODE_000", + "STORAGE_000", + "METRICS_000", + "SCHEMA_VERSION_000", + "SERIALIZATION_000", + ] + + for error_code in error_codes: + result = _sanitize_error_message("Detailed error message", error_code, is_production=True) + + # Should not return the original message + assert result != "Detailed error message" + # Should return a generic message + assert len(result) > 0 diff --git a/tests/unit_tests/test_handle_errors.py b/tests/unit_tests/test_handle_errors.py index 5cae143..5bb2522 100644 --- a/tests/unit_tests/test_handle_errors.py +++ b/tests/unit_tests/test_handle_errors.py @@ -10,6 +10,15 @@ def test_http_exception_handler_returns_error_payload(): + import os + + # Ensure development mode for this test + os.environ["MODE"] = "development" + + from agentflow_cli.src.app.core.config.settings import get_settings + + get_settings.cache_clear() + app = FastAPI() setup_middleware(app) init_errors_handler(app) @@ -24,3 +33,7 @@ def boom(): body = r.json() assert 
body["error"]["code"] == "HTTPException" assert body["error"]["message"] == "nope" + + # Cleanup + if "MODE" in os.environ: + del os.environ["MODE"] diff --git a/tests/unit_tests/test_request_limits.py b/tests/unit_tests/test_request_limits.py new file mode 100644 index 0000000..359361e --- /dev/null +++ b/tests/unit_tests/test_request_limits.py @@ -0,0 +1,95 @@ +"""Unit tests for request size limit middleware.""" + +import pytest +from fastapi import FastAPI +from fastapi.testclient import TestClient + +from agentflow_cli.src.app.core.middleware.request_limits import RequestSizeLimitMiddleware + + +@pytest.fixture +def app_with_limit(): + """Create a FastAPI app with request size limit middleware.""" + app = FastAPI() + + # Add middleware with 1KB limit for testing + app.add_middleware(RequestSizeLimitMiddleware, max_size=1024) + + @app.post("/test") + async def test_endpoint(data: dict): + return {"status": "ok", "data": data} + + return app + + +def test_request_under_limit(app_with_limit): + """Test that requests under the size limit are allowed.""" + client = TestClient(app_with_limit) + + # Small payload (under 1KB) + small_data = {"message": "Hello, World!"} + response = client.post("/test", json=small_data) + + assert response.status_code == 200 + assert response.json()["status"] == "ok" + + +def test_request_over_limit(app_with_limit): + """Test that requests over the size limit are rejected.""" + client = TestClient(app_with_limit) + + # Large payload (over 1KB) + large_data = {"message": "x" * 2000} + response = client.post("/test", json=large_data) + + assert response.status_code == 413 + assert response.json()["error"]["code"] == "REQUEST_TOO_LARGE" + assert "request_id" in response.json()["metadata"] + + +def test_request_without_content_length(app_with_limit): + """Test that requests without content-length header are allowed.""" + client = TestClient(app_with_limit) + + # TestClient automatically adds content-length, but if it's missing + # the 
middleware should allow the request through + response = client.post("/test", json={"message": "test"}) + + # Should succeed since small payload + assert response.status_code == 200 + + +def test_middleware_with_default_limit(): + """Test middleware with default 10MB limit.""" + app = FastAPI() + app.add_middleware(RequestSizeLimitMiddleware) # Default 10MB + + @app.post("/test") + async def test_endpoint(data: dict): + return {"status": "ok"} + + client = TestClient(app) + response = client.post("/test", json={"message": "test"}) + + assert response.status_code == 200 + + +def test_error_response_format(app_with_limit): + """Test that error response has correct format.""" + client = TestClient(app_with_limit) + + large_data = {"message": "x" * 2000} + response = client.post("/test", json=large_data) + + assert response.status_code == 413 + + json_response = response.json() + assert "error" in json_response + assert "metadata" in json_response + + error = json_response["error"] + assert error["code"] == "REQUEST_TOO_LARGE" + assert "max_size_bytes" in error + assert "max_size_mb" in error + assert error["max_size_bytes"] == 1024 + assert error["max_size_mb"] == 1024 / (1024 * 1024) diff --git a/tests/unit_tests/test_security_config.py b/tests/unit_tests/test_security_config.py new file mode 100644 index 0000000..4d8ac55 --- /dev/null +++ b/tests/unit_tests/test_security_config.py @@ -0,0 +1,105 @@ +"""Unit tests for security configuration warnings.""" + +import os +from unittest.mock import patch + +import pytest + +from agentflow_cli.src.app.core.config.settings import Settings, get_settings + + +def test_mode_normalization(): + """Test that MODE is normalized to lowercase.""" + with patch.dict(os.environ, {"MODE": "PRODUCTION"}): + settings = Settings() + assert settings.MODE == "production" + + with patch.dict(os.environ, {"MODE": "Development"}): + settings = Settings() + assert settings.MODE == "development" + + +def 
test_cors_wildcard_warning_in_production(caplog): + """Test warning for CORS wildcard in production.""" + with patch.dict(os.environ, {"MODE": "production", "ORIGINS": "*"}): + settings = Settings() + assert "CORS ORIGINS='*' in production" in caplog.text + + +def test_cors_wildcard_no_warning_in_development(caplog): + """Test no warning for CORS wildcard in development.""" + with patch.dict(os.environ, {"MODE": "development", "ORIGINS": "*"}): + settings = Settings() + assert "CORS ORIGINS" not in caplog.text or "production" not in caplog.text + + +def test_debug_mode_warning_in_production(caplog): + """Test warning for DEBUG mode enabled in production.""" + with patch.dict(os.environ, {"MODE": "production", "IS_DEBUG": "true"}): + settings = Settings() + assert "DEBUG mode is enabled in production" in caplog.text + + +def test_docs_enabled_warning_in_production(caplog): + """Test warning for API docs enabled in production.""" + with patch.dict(os.environ, {"MODE": "production"}): + settings = Settings() + # Default has DOCS_PATH="/docs" + assert "API documentation endpoints are enabled" in caplog.text + + +def test_allowed_host_wildcard_warning_in_production(caplog): + """Test warning for ALLOWED_HOST wildcard in production.""" + with patch.dict(os.environ, {"MODE": "production", "ALLOWED_HOST": "*"}): + settings = Settings() + assert "ALLOWED_HOST='*' in production" in caplog.text + + +def test_no_warnings_in_development(caplog): + """Test that no security warnings appear in development mode.""" + with patch.dict(os.environ, {"MODE": "development"}): + settings = Settings() + assert "SECURITY WARNING" not in caplog.text + + +def test_multiple_warnings_in_production(caplog): + """Test that multiple warnings can appear together.""" + with patch.dict( + os.environ, + { + "MODE": "production", + "ORIGINS": "*", + "IS_DEBUG": "true", + "ALLOWED_HOST": "*", + }, + ): + settings = Settings() + + log_text = caplog.text + assert "CORS ORIGINS='*'" in log_text + assert 
"DEBUG mode is enabled" in log_text + assert "ALLOWED_HOST='*'" in log_text + + +def test_max_request_size_default(): + """Test that MAX_REQUEST_SIZE has correct default.""" + settings = Settings() + assert settings.MAX_REQUEST_SIZE == 10 * 1024 * 1024 # 10MB + + +def test_max_request_size_configurable(): + """Test that MAX_REQUEST_SIZE is configurable via env var.""" + with patch.dict(os.environ, {"MAX_REQUEST_SIZE": "5242880"}): # 5MB + settings = Settings() + assert settings.MAX_REQUEST_SIZE == 5242880 + + +def test_settings_caching(): + """Test that get_settings returns cached instance.""" + # Clear cache first + get_settings.cache_clear() + + settings1 = get_settings() + settings2 = get_settings() + + assert settings1 is settings2 diff --git a/tests/unit_tests/test_security_headers.py b/tests/unit_tests/test_security_headers.py new file mode 100644 index 0000000..c904582 --- /dev/null +++ b/tests/unit_tests/test_security_headers.py @@ -0,0 +1,318 @@ +""" +Tests for SecurityHeadersMiddleware + +Tests cover: +- All security headers are added correctly +- HSTS header only added for HTTPS requests +- Configuration options work correctly +- Custom CSP and Permissions policies +- Middleware can be disabled +""" + +import pytest +from fastapi import FastAPI +from fastapi.testclient import TestClient + +from agentflow_cli.src.app.core.middleware.security_headers import ( + SecurityHeadersMiddleware, + create_security_headers_middleware, +) + + +@pytest.fixture +def app(): + """Create a test FastAPI application.""" + app = FastAPI() + + @app.get("/test") + async def test_endpoint(): + return {"message": "test"} + + return app + + +@pytest.fixture +def client_with_headers(app): + """Create test client with security headers middleware.""" + app.add_middleware(SecurityHeadersMiddleware) + return TestClient(app) + + +def test_basic_security_headers_added(client_with_headers): + """Test that basic security headers are added to responses.""" + response = 
client_with_headers.get("/test") + + assert response.status_code == 200 + assert "X-Content-Type-Options" in response.headers + assert response.headers["X-Content-Type-Options"] == "nosniff" + + assert "X-Frame-Options" in response.headers + assert response.headers["X-Frame-Options"] == "DENY" + + assert "X-XSS-Protection" in response.headers + assert response.headers["X-XSS-Protection"] == "1; mode=block" + + assert "Referrer-Policy" in response.headers + assert response.headers["Referrer-Policy"] == "strict-origin-when-cross-origin" + + +def test_permissions_policy_added(client_with_headers): + """Test that Permissions-Policy header is added.""" + response = client_with_headers.get("/test") + + assert "Permissions-Policy" in response.headers + assert "geolocation=()" in response.headers["Permissions-Policy"] + assert "microphone=()" in response.headers["Permissions-Policy"] + assert "camera=()" in response.headers["Permissions-Policy"] + + +def test_csp_policy_added(client_with_headers): + """Test that Content-Security-Policy header is added.""" + response = client_with_headers.get("/test") + + assert "Content-Security-Policy" in response.headers + csp = response.headers["Content-Security-Policy"] + assert "default-src 'self'" in csp + assert "frame-ancestors 'none'" in csp + + +def test_hsts_not_added_for_http(client_with_headers): + """Test that HSTS header is NOT added for HTTP requests.""" + response = client_with_headers.get("/test") + + # HSTS should not be present for HTTP requests + assert "Strict-Transport-Security" not in response.headers + + +def test_hsts_added_for_https(app): + """Test that HSTS header IS added for HTTPS requests.""" + app.add_middleware(SecurityHeadersMiddleware) + client = TestClient(app, base_url="https://testserver") + + response = client.get("/test") + + # HSTS should be present for HTTPS requests + assert "Strict-Transport-Security" in response.headers + hsts = response.headers["Strict-Transport-Security"] + assert 
"max-age=31536000" in hsts + assert "includeSubDomains" in hsts + + +def test_hsts_with_x_forwarded_proto(app): + """Test that HSTS is added when X-Forwarded-Proto is https (proxied requests).""" + app.add_middleware(SecurityHeadersMiddleware) + client = TestClient(app) + + response = client.get("/test", headers={"X-Forwarded-Proto": "https"}) + + # HSTS should be present when proxied via HTTPS + assert "Strict-Transport-Security" in response.headers + + +def test_custom_frame_options(app): + """Test custom X-Frame-Options value.""" + app.add_middleware( + SecurityHeadersMiddleware, + frame_options="SAMEORIGIN", + ) + client = TestClient(app) + + response = client.get("/test") + + assert response.headers["X-Frame-Options"] == "SAMEORIGIN" + + +def test_custom_csp_policy(app): + """Test custom Content-Security-Policy.""" + custom_csp = "default-src 'none'; script-src 'self'" + app.add_middleware( + SecurityHeadersMiddleware, + csp_policy=custom_csp, + ) + client = TestClient(app) + + response = client.get("/test") + + assert response.headers["Content-Security-Policy"] == custom_csp + + +def test_custom_permissions_policy(app): + """Test custom Permissions-Policy.""" + custom_policy = "geolocation=*, camera=()" + app.add_middleware( + SecurityHeadersMiddleware, + permissions_policy=custom_policy, + ) + client = TestClient(app) + + response = client.get("/test") + + assert response.headers["Permissions-Policy"] == custom_policy + + +def test_hsts_disabled(app): + """Test that HSTS can be disabled.""" + app.add_middleware( + SecurityHeadersMiddleware, + enable_hsts=False, + ) + client = TestClient(app, base_url="https://testserver") + + response = client.get("/test") + + # HSTS should not be present even for HTTPS when disabled + assert "Strict-Transport-Security" not in response.headers + + +def test_hsts_with_preload(app): + """Test HSTS with preload option.""" + app.add_middleware( + SecurityHeadersMiddleware, + hsts_preload=True, + ) + client = TestClient(app, 
base_url="https://testserver") + + response = client.get("/test") + + hsts = response.headers["Strict-Transport-Security"] + assert "preload" in hsts + + +def test_hsts_without_subdomains(app): + """Test HSTS without includeSubDomains.""" + app.add_middleware( + SecurityHeadersMiddleware, + hsts_include_subdomains=False, + ) + client = TestClient(app, base_url="https://testserver") + + response = client.get("/test") + + hsts = response.headers["Strict-Transport-Security"] + assert "includeSubDomains" not in hsts + + +def test_custom_hsts_max_age(app): + """Test custom HSTS max-age value.""" + app.add_middleware( + SecurityHeadersMiddleware, + hsts_max_age=63072000, # 2 years + ) + client = TestClient(app, base_url="https://testserver") + + response = client.get("/test") + + hsts = response.headers["Strict-Transport-Security"] + assert "max-age=63072000" in hsts + + +def test_factory_function(app): + """Test the create_security_headers_middleware factory function.""" + CustomMiddleware = create_security_headers_middleware( + frame_options="SAMEORIGIN", + hsts_max_age=86400, + ) + app.add_middleware(CustomMiddleware) + client = TestClient(app) + + response = client.get("/test") + + assert response.headers["X-Frame-Options"] == "SAMEORIGIN" + + +def test_all_headers_present(client_with_headers): + """Test that all expected security headers are present.""" + response = client_with_headers.get("/test") + + expected_headers = [ + "X-Content-Type-Options", + "X-Frame-Options", + "X-XSS-Protection", + "Referrer-Policy", + "Permissions-Policy", + "Content-Security-Policy", + ] + + for header in expected_headers: + assert header in response.headers, f"Missing header: {header}" + + +def test_headers_not_overridden_if_set_by_endpoint(app): + """Test that endpoint-set headers are preserved (middleware adds, doesn't override).""" + + @app.get("/custom-header") + async def custom_header_endpoint(): + from fastapi import Response + + response = Response(content='{"message": 
"test"}') + # Note: Middleware runs after endpoint, so it will add headers + # This test verifies the middleware doesn't break custom headers + return response + + app.add_middleware(SecurityHeadersMiddleware) + client = TestClient(app) + + response = client.get("/custom-header") + + # Security headers should still be added + assert "X-Content-Type-Options" in response.headers + + +def test_middleware_with_different_request_methods(client_with_headers): + """Test that security headers are added for all HTTP methods.""" + methods = ["get", "post", "put", "delete", "patch", "head", "options"] + + for method in methods: + client_method = getattr(client_with_headers, method) + response = client_method("/test") + + # All methods should have security headers (except possibly 405 for unsupported methods) + if response.status_code != 405: + assert "X-Content-Type-Options" in response.headers + + +@pytest.mark.parametrize( + "frame_option", + ["DENY", "SAMEORIGIN", "ALLOW-FROM https://example.com"], +) +def test_various_frame_options(app, frame_option): + """Test various X-Frame-Options values.""" + app.add_middleware( + SecurityHeadersMiddleware, + frame_options=frame_option, + ) + client = TestClient(app) + + response = client.get("/test") + + assert response.headers["X-Frame-Options"] == frame_option + + +def test_empty_csp_policy(app): + """Test that empty CSP policy still uses default.""" + app.add_middleware( + SecurityHeadersMiddleware, + csp_policy="", + ) + client = TestClient(app) + + response = client.get("/test") + + # Empty string is falsy, so default should be used + assert "Content-Security-Policy" in response.headers + assert "default-src 'self'" in response.headers["Content-Security-Policy"] + + +def test_none_csp_policy_uses_default(app): + """Test that None CSP policy uses default.""" + app.add_middleware( + SecurityHeadersMiddleware, + csp_policy=None, + ) + client = TestClient(app) + + response = client.get("/test") + + # Should use default CSP + assert 
"Content-Security-Policy" in response.headers + assert "default-src 'self'" in response.headers["Content-Security-Policy"] diff --git a/uv.lock b/uv.lock index 61232b1..1ad3a37 100644 --- a/uv.lock +++ b/uv.lock @@ -22,7 +22,7 @@ wheels = [ [[package]] name = "10xscale-agentflow-cli" -version = "0.1.9" +version = "0.2.6" source = { editable = "." } dependencies = [ { name = "10xscale-agentflow" },