From 45ff617303a25361c54770903afb75d1912501be Mon Sep 17 00:00:00 2001 From: Shudipto Trafder Date: Sat, 29 Nov 2025 15:58:22 +0600 Subject: [PATCH 1/2] feat: Implement JWT authentication handling in auth module --- agentflow_cli/src/app/core/auth/jwt_auth.py | 7 ++++++- agentflow_cli/src/app/loader.py | 8 ++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/agentflow_cli/src/app/core/auth/jwt_auth.py b/agentflow_cli/src/app/core/auth/jwt_auth.py index decad6c..020fea0 100644 --- a/agentflow_cli/src/app/core/auth/jwt_auth.py +++ b/agentflow_cli/src/app/core/auth/jwt_auth.py @@ -44,6 +44,11 @@ def authenticate( jwt_secret_key = os.environ.get("JWT_SECRET_KEY", None) jwt_algorithm = os.environ.get("JWT_ALGORITHM", None) + # check bearer token then remove barer prefix + token = credential.credentials + if token.lower().startswith("bearer "): + token = token[7:] + if jwt_secret_key is None or jwt_algorithm is None: raise UserAccountError( message="JWT settings are not configured", @@ -52,7 +57,7 @@ def authenticate( try: decoded_token = jwt.decode( - credential.credentials, + token, jwt_secret_key, # type: ignore algorithms=[jwt_algorithm], # type: ignore ) diff --git a/agentflow_cli/src/app/loader.py b/agentflow_cli/src/app/loader.py index 99ef2fa..18bd574 100644 --- a/agentflow_cli/src/app/loader.py +++ b/agentflow_cli/src/app/loader.py @@ -202,6 +202,14 @@ async def attach_all_modules( path, ) container.bind_instance(BaseAuth, auth_backend) + elif method == "jwt": + from agentflow_cli.src.app.core.auth.jwt_auth import JwtAuth + + jwt_auth = JwtAuth() + container.bind_instance(BaseAuth, jwt_auth) + + elif method == "none": + container.bind_instance(BaseAuth, None, allow_none=True) else: # bind None container.bind_instance(BaseAuth, None, allow_none=True) From 36a801a49d4e86e0153078197dc814667fc03664 Mon Sep 17 00:00:00 2001 From: Shudipto Trafder Date: Sat, 29 Nov 2025 16:07:32 +0600 Subject: [PATCH 2/2] Add unit tests for JWT authentication and GraphConfig auth configuration - Implement comprehensive unit tests for the JwtAuth class covering various scenarios including null credentials, missing JWT settings, expired tokens, and valid tokens. - Create tests for handling Bearer prefix in tokens and setting WWW-Authenticate headers. - Add unit tests for the GraphConfig class to validate auth configuration scenarios, including no auth configured, JWT auth with valid and missing environment variables, and custom auth configurations. - Ensure proper error handling and logging for invalid configurations and authentication failures. --- tests/unit_tests/auth/test_auth_backend.py | 307 ++++++++ .../unit_tests/auth/test_graph_config_auth.py | 426 +++++++++++ tests/unit_tests/auth/test_jwt_auth.py | 679 ++++++++++++++++++ 3 files changed, 1412 insertions(+) create mode 100644 tests/unit_tests/auth/test_auth_backend.py create mode 100644 tests/unit_tests/auth/test_graph_config_auth.py create mode 100644 tests/unit_tests/auth/test_jwt_auth.py diff --git a/tests/unit_tests/auth/test_auth_backend.py b/tests/unit_tests/auth/test_auth_backend.py new file mode 100644 index 0000000..9f71c63 --- /dev/null +++ b/tests/unit_tests/auth/test_auth_backend.py @@ -0,0 +1,307 @@ +""" +Unit tests for auth_backend module. + +Tests cover the verify_current_user function which is the integration point +for JWT authentication in the application. +""" + +from typing import Any +from unittest.mock import MagicMock, patch + +import pytest +from fastapi import Response +from fastapi.security import HTTPAuthorizationCredentials + +from agentflow_cli.src.app.core.auth.auth_backend import verify_current_user +from agentflow_cli.src.app.core.auth.base_auth import BaseAuth +from agentflow_cli.src.app.core.config.graph_config import GraphConfig + + +class MockBaseAuth(BaseAuth): + """Mock implementation of BaseAuth for testing.""" + + def __init__(self, return_value: dict[str, Any] | None = None, raise_exception: bool = False): + self._return_value = return_value + self._raise_exception = raise_exception + + def authenticate( + self, res: Response, credential: HTTPAuthorizationCredentials + ) -> dict[str, Any] | None: + if self._raise_exception: + raise ValueError("Authentication failed") + return self._return_value + + +class TestVerifyCurrentUser: + """Test suite for verify_current_user function.""" + + @pytest.fixture + def mock_response(self) -> Response: + """Create a mock FastAPI Response object.""" + return Response() + + @pytest.fixture + def mock_credentials(self) -> HTTPAuthorizationCredentials: + """Create mock HTTP Authorization credentials.""" + return HTTPAuthorizationCredentials(scheme="Bearer", credentials="test-token") + + @pytest.fixture + def mock_graph_config_no_auth(self) -> MagicMock: + """Create a mock GraphConfig that returns None for auth_config.""" + config = MagicMock(spec=GraphConfig) + config.auth_config.return_value = None + return config + + @pytest.fixture + def mock_graph_config_jwt_auth(self) -> MagicMock: + """Create a mock GraphConfig that returns JWT auth config.""" + config = MagicMock(spec=GraphConfig) + config.auth_config.return_value = {"method": "jwt"} + return config + + def test_returns_empty_dict_when_no_auth_backend_configured( + self, + mock_response: Response, + mock_credentials: HTTPAuthorizationCredentials, + mock_graph_config_no_auth: MagicMock, + ): + """Test that empty dict is returned when auth is not configured.""" + mock_auth_backend = MockBaseAuth(return_value={"user_id": "123"}) + + result = verify_current_user( + res=mock_response, + credential=mock_credentials, + config=mock_graph_config_no_auth, + auth_backend=mock_auth_backend, + ) + + assert result == {} + + def test_returns_empty_dict_when_auth_backend_is_none( + self, + mock_response: Response, + mock_credentials: HTTPAuthorizationCredentials, + mock_graph_config_jwt_auth: MagicMock, + ): + """Test that empty dict is returned when auth_backend is None.""" + with patch("agentflow_cli.src.app.core.auth.auth_backend.logger") as mock_logger: + result = verify_current_user( + res=mock_response, + credential=mock_credentials, + config=mock_graph_config_jwt_auth, + auth_backend=None, + ) + + assert result == {} + mock_logger.error.assert_called_once_with("Auth backend is not configured") + + def test_returns_user_dict_on_successful_authentication( + self, + mock_response: Response, + mock_credentials: HTTPAuthorizationCredentials, + mock_graph_config_jwt_auth: MagicMock, + ): + """Test that user dict is returned when authentication succeeds.""" + expected_user = { + "user_id": "user-123", + "email": "test@example.com", + "role": "admin", + } + mock_auth_backend = MockBaseAuth(return_value=expected_user) + + result = verify_current_user( + res=mock_response, + credential=mock_credentials, + config=mock_graph_config_jwt_auth, + auth_backend=mock_auth_backend, + ) + + assert result == expected_user + assert result["user_id"] == "user-123" + + def test_returns_empty_dict_when_authenticate_returns_none( + self, + mock_response: Response, + mock_credentials: HTTPAuthorizationCredentials, + mock_graph_config_jwt_auth: MagicMock, + ): + """Test that empty dict is returned when authenticate returns None.""" + mock_auth_backend = MockBaseAuth(return_value=None) + + result = verify_current_user( + res=mock_response, + credential=mock_credentials, + config=mock_graph_config_jwt_auth, + auth_backend=mock_auth_backend, + ) + + assert result == {} + + def test_logs_error_when_user_dict_missing_user_id( + self, + mock_response: Response, + mock_credentials: HTTPAuthorizationCredentials, + mock_graph_config_jwt_auth: MagicMock, + ): + """Test that error is logged when authenticated user dict has no user_id.""" + user_without_id = {"email": "test@example.com", "role": "admin"} + mock_auth_backend = MockBaseAuth(return_value=user_without_id) + + with patch("agentflow_cli.src.app.core.auth.auth_backend.logger") as mock_logger: + result = verify_current_user( + res=mock_response, + credential=mock_credentials, + config=mock_graph_config_jwt_auth, + auth_backend=mock_auth_backend, + ) + + # Should still return the user dict even without user_id + assert result == user_without_id + mock_logger.error.assert_called_once_with( + "Authentication failed: 'user_id' not found in user info" + ) + + def test_does_not_log_error_when_user_dict_has_user_id( + self, + mock_response: Response, + mock_credentials: HTTPAuthorizationCredentials, + mock_graph_config_jwt_auth: MagicMock, + ): + """Test that no error is logged when authenticated user dict has user_id.""" + user_with_id = {"user_id": "123", "email": "test@example.com"} + mock_auth_backend = MockBaseAuth(return_value=user_with_id) + + with patch("agentflow_cli.src.app.core.auth.auth_backend.logger") as mock_logger: + result = verify_current_user( + res=mock_response, + credential=mock_credentials, + config=mock_graph_config_jwt_auth, + auth_backend=mock_auth_backend, + ) + + assert result == user_with_id + mock_logger.error.assert_not_called() + + def test_does_not_log_error_when_authenticate_returns_empty_dict( + self, + mock_response: Response, + mock_credentials: HTTPAuthorizationCredentials, + mock_graph_config_jwt_auth: MagicMock, + ): + """Test that no error is logged when authenticate returns empty dict.""" + mock_auth_backend = MockBaseAuth(return_value={}) + + with patch("agentflow_cli.src.app.core.auth.auth_backend.logger") as mock_logger: + result = verify_current_user( + res=mock_response, + credential=mock_credentials, + config=mock_graph_config_jwt_auth, + auth_backend=mock_auth_backend, + ) + + # Empty dict is falsy, so the 'if user' condition fails + # and we return {} without logging + assert result == {} + mock_logger.error.assert_not_called() + + def test_returns_user_with_numeric_user_id( + self, + mock_response: Response, + mock_credentials: HTTPAuthorizationCredentials, + mock_graph_config_jwt_auth: MagicMock, + ): + """Test that numeric user_id works correctly.""" + user_with_numeric_id = {"user_id": 12345, "email": "test@example.com"} + mock_auth_backend = MockBaseAuth(return_value=user_with_numeric_id) + + with patch("agentflow_cli.src.app.core.auth.auth_backend.logger") as mock_logger: + result = verify_current_user( + res=mock_response, + credential=mock_credentials, + config=mock_graph_config_jwt_auth, + auth_backend=mock_auth_backend, + ) + + assert result == user_with_numeric_id + # numeric user_id still passes the 'in' check + mock_logger.error.assert_not_called() + + def test_returns_user_with_empty_string_user_id( + self, + mock_response: Response, + mock_credentials: HTTPAuthorizationCredentials, + mock_graph_config_jwt_auth: MagicMock, + ): + """Test that empty string user_id is accepted (key exists).""" + user_with_empty_id = {"user_id": "", "email": "test@example.com"} + mock_auth_backend = MockBaseAuth(return_value=user_with_empty_id) + + with patch("agentflow_cli.src.app.core.auth.auth_backend.logger") as mock_logger: + result = verify_current_user( + res=mock_response, + credential=mock_credentials, + config=mock_graph_config_jwt_auth, + auth_backend=mock_auth_backend, + ) + + assert result == user_with_empty_id + # Empty string user_id still passes 'in' check (key exists) + mock_logger.error.assert_not_called() + + def test_returns_user_with_none_user_id( + self, + mock_response: Response, + mock_credentials: HTTPAuthorizationCredentials, + mock_graph_config_jwt_auth: MagicMock, + ): + """Test that None user_id is accepted (key exists).""" + user_with_none_id = {"user_id": None, "email": "test@example.com"} + mock_auth_backend = MockBaseAuth(return_value=user_with_none_id) + + with patch("agentflow_cli.src.app.core.auth.auth_backend.logger") as mock_logger: + result = verify_current_user( + res=mock_response, + credential=mock_credentials, + config=mock_graph_config_jwt_auth, + auth_backend=mock_auth_backend, + ) + + assert result == user_with_none_id + # None user_id still passes 'in' check (key exists) + mock_logger.error.assert_not_called() + + +class TestVerifyCurrentUserWithNullCredentials: + """Test verify_current_user with null credentials scenarios.""" + + @pytest.fixture + def mock_response(self) -> Response: + """Create a mock FastAPI Response object.""" + return Response() + + @pytest.fixture + def mock_graph_config_jwt_auth(self) -> MagicMock: + """Create a mock GraphConfig that returns JWT auth config.""" + config = MagicMock(spec=GraphConfig) + config.auth_config.return_value = {"method": "jwt"} + return config + + def test_passes_null_credentials_to_auth_backend( + self, + mock_response: Response, + mock_graph_config_jwt_auth: MagicMock, + ): + """Test that null credentials are passed to auth backend.""" + # Create a mock that tracks calls + mock_auth = MagicMock(spec=BaseAuth) + mock_auth.authenticate.return_value = {"user_id": "123"} + + verify_current_user( + res=mock_response, + credential=None, # Null credentials + config=mock_graph_config_jwt_auth, + auth_backend=mock_auth, + ) + + # Verify authenticate was called with None credentials + mock_auth.authenticate.assert_called_once_with(mock_response, None) diff --git a/tests/unit_tests/auth/test_graph_config_auth.py b/tests/unit_tests/auth/test_graph_config_auth.py new file mode 100644 index 0000000..8f36edf --- /dev/null +++ b/tests/unit_tests/auth/test_graph_config_auth.py @@ -0,0 +1,426 @@ +""" +Unit tests for GraphConfig.auth_config method. + +Tests cover all JWT authentication configuration scenarios including: +- No auth configured +- JWT auth with valid environment variables +- JWT auth with missing environment variables +- Custom auth configuration +- Invalid auth configuration +""" + +import json +import os +import tempfile +from pathlib import Path +from unittest.mock import patch + +import pytest + +from agentflow_cli.src.app.core.config.graph_config import GraphConfig + + +class TestGraphConfigAuthConfig: + """Test suite for GraphConfig.auth_config method.""" + + @pytest.fixture + def temp_config_file(self, tmp_path: Path): + """Create a temporary config file and return a function to write to it.""" + + def _create_config(config_data: dict) -> str: + config_path = tmp_path / "agentflow.json" + with open(config_path, "w") as f: + json.dump(config_data, f) + return str(config_path) + + return _create_config + + # ========================================================================= + # Test: No auth configured + # ========================================================================= + def test_auth_config_returns_none_when_auth_not_in_config( + self, + temp_config_file, + ): + """Test that auth_config returns None when auth is not in config.""" + config_path = temp_config_file({"agent": "path/to/agent.py"}) + graph_config = GraphConfig(path=config_path) + + result = graph_config.auth_config() + + assert result is None + + def test_auth_config_returns_none_when_auth_is_null( + self, + temp_config_file, + ): + """Test that auth_config returns None when auth is explicitly null.""" + config_path = temp_config_file({"agent": "path/to/agent.py", "auth": None}) + graph_config = GraphConfig(path=config_path) + + result = graph_config.auth_config() + + assert result is None + + def test_auth_config_returns_none_when_auth_is_empty_string( + self, + temp_config_file, + ): + """Test that auth_config returns None when auth is empty string.""" + config_path = temp_config_file({"agent": "path/to/agent.py", "auth": ""}) + graph_config = GraphConfig(path=config_path) + + result = graph_config.auth_config() + + assert result is None + + # ========================================================================= + # Test: JWT auth with valid environment variables + # ========================================================================= + def test_auth_config_returns_jwt_method_when_jwt_string_and_env_vars_set( + self, + temp_config_file, + ): + """Test that auth_config returns JWT method when 'jwt' string and env vars are set.""" + config_path = temp_config_file({"agent": "path/to/agent.py", "auth": "jwt"}) + + with patch.dict( + os.environ, + { + "JWT_SECRET_KEY": "test-secret", + "JWT_ALGORITHM": "HS256", + }, + ): + graph_config = GraphConfig(path=config_path) + result = graph_config.auth_config() + + assert result == {"method": "jwt"} + + def test_auth_config_jwt_string_with_jwt_substring( + self, + temp_config_file, + ): + """Test that auth config works with strings containing 'jwt' substring.""" + # The code uses 'jwt' in res, so "my-jwt-auth" would also match + config_path = temp_config_file({"agent": "path/to/agent.py", "auth": "my-jwt-auth"}) + + with patch.dict( + os.environ, + { + "JWT_SECRET_KEY": "test-secret", + "JWT_ALGORITHM": "HS256", + }, + ): + graph_config = GraphConfig(path=config_path) + result = graph_config.auth_config() + + assert result == {"method": "jwt"} + + def test_auth_config_jwt_uppercase_raises_error( + self, + temp_config_file, + ): + """Test that 'JWT' (uppercase) is not recognized (case-sensitive check).""" + # Note: The current implementation uses 'jwt' in res, which is case-sensitive + config_path = temp_config_file({"agent": "path/to/agent.py", "auth": "JWT"}) + + with patch.dict( + os.environ, + { + "JWT_SECRET_KEY": "test-secret", + "JWT_ALGORITHM": "HS256", + }, + ): + graph_config = GraphConfig(path=config_path) + + with pytest.raises(ValueError) as exc_info: + graph_config.auth_config() + + assert "Unsupported auth method" in str(exc_info.value) + + # ========================================================================= + # Test: JWT auth with missing environment variables + # ========================================================================= + def test_auth_config_raises_error_when_jwt_secret_key_missing( + self, + temp_config_file, + ): + """Test that auth_config raises ValueError when JWT_SECRET_KEY is missing.""" + config_path = temp_config_file({"agent": "path/to/agent.py", "auth": "jwt"}) + + with patch.dict( + os.environ, + {"JWT_ALGORITHM": "HS256"}, + clear=True, + ): + graph_config = GraphConfig(path=config_path) + + with pytest.raises(ValueError) as exc_info: + graph_config.auth_config() + + assert "JWT_SECRET_KEY" in str(exc_info.value) + assert "JWT_ALGORITHM" in str(exc_info.value) + + def test_auth_config_raises_error_when_jwt_algorithm_missing( + self, + temp_config_file, + ): + """Test that auth_config raises ValueError when JWT_ALGORITHM is missing.""" + config_path = temp_config_file({"agent": "path/to/agent.py", "auth": "jwt"}) + + with patch.dict( + os.environ, + {"JWT_SECRET_KEY": "test-secret"}, + clear=True, + ): + graph_config = GraphConfig(path=config_path) + + with pytest.raises(ValueError) as exc_info: + graph_config.auth_config() + + assert "JWT_SECRET_KEY" in str(exc_info.value) + assert "JWT_ALGORITHM" in str(exc_info.value) + + def test_auth_config_raises_error_when_both_jwt_env_vars_missing( + self, + temp_config_file, + ): + """Test that auth_config raises ValueError when both JWT env vars are missing.""" + config_path = temp_config_file({"agent": "path/to/agent.py", "auth": "jwt"}) + + with patch.dict(os.environ, {}, clear=True): + graph_config = GraphConfig(path=config_path) + + with pytest.raises(ValueError) as exc_info: + graph_config.auth_config() + + assert "JWT_SECRET_KEY" in str(exc_info.value) + assert "JWT_ALGORITHM" in str(exc_info.value) + + def test_auth_config_raises_error_when_jwt_secret_key_empty( + self, + temp_config_file, + ): + """Test that auth_config raises ValueError when JWT_SECRET_KEY is empty string.""" + config_path = temp_config_file({"agent": "path/to/agent.py", "auth": "jwt"}) + + with patch.dict( + os.environ, + { + "JWT_SECRET_KEY": "", # Empty string + "JWT_ALGORITHM": "HS256", + }, + ): + graph_config = GraphConfig(path=config_path) + + with pytest.raises(ValueError) as exc_info: + graph_config.auth_config() + + assert "JWT_SECRET_KEY" in str(exc_info.value) + + def test_auth_config_raises_error_when_jwt_algorithm_empty( + self, + temp_config_file, + ): + """Test that auth_config raises ValueError when JWT_ALGORITHM is empty string.""" + config_path = temp_config_file({"agent": "path/to/agent.py", "auth": "jwt"}) + + with patch.dict( + os.environ, + { + "JWT_SECRET_KEY": "test-secret", + "JWT_ALGORITHM": "", # Empty string + }, + ): + graph_config = GraphConfig(path=config_path) + + with pytest.raises(ValueError) as exc_info: + graph_config.auth_config() + + assert "JWT_ALGORITHM" in str(exc_info.value) + + # ========================================================================= + # Test: Custom auth configuration + # ========================================================================= + def test_auth_config_returns_custom_method_when_custom_config_valid( + self, + temp_config_file, + tmp_path: Path, + ): + """Test that auth_config returns custom method when custom config is valid.""" + # Create a temporary custom auth file + custom_auth_path = tmp_path / "custom_auth.py" + custom_auth_path.write_text("# Custom auth module") + + config_path = temp_config_file( + { + "agent": "path/to/agent.py", + "auth": {"method": "custom", "path": str(custom_auth_path)}, + } + ) + graph_config = GraphConfig(path=config_path) + + result = graph_config.auth_config() + + assert result == {"method": "custom", "path": str(custom_auth_path)} + + def test_auth_config_raises_error_when_custom_path_not_exists( + self, + temp_config_file, + ): + """Test that auth_config raises ValueError when custom path doesn't exist.""" + config_path = temp_config_file( + { + "agent": "path/to/agent.py", + "auth": {"method": "custom", "path": "/nonexistent/path/auth.py"}, + } + ) + graph_config = GraphConfig(path=config_path) + + with pytest.raises(ValueError) as exc_info: + graph_config.auth_config() + + assert "Unsupported auth method" in str(exc_info.value) + + def test_auth_config_raises_error_when_dict_missing_method( + self, + temp_config_file, + tmp_path: Path, + ): + """Test that auth_config raises ValueError when dict is missing method.""" + custom_auth_path = tmp_path / "custom_auth.py" + custom_auth_path.write_text("# Custom auth module") + + config_path = temp_config_file( + { + "agent": "path/to/agent.py", + "auth": {"path": str(custom_auth_path)}, # Missing method + } + ) + graph_config = GraphConfig(path=config_path) + + with pytest.raises(ValueError) as exc_info: + graph_config.auth_config() + + assert "Both method and path must be provided" in str(exc_info.value) + + def test_auth_config_raises_error_when_dict_missing_path( + self, + temp_config_file, + ): + """Test that auth_config raises ValueError when dict is missing path.""" + config_path = temp_config_file( + { + "agent": "path/to/agent.py", + "auth": {"method": "custom"}, # Missing path + } + ) + graph_config = GraphConfig(path=config_path) + + with pytest.raises(ValueError) as exc_info: + graph_config.auth_config() + + assert "Both method and path must be provided" in str(exc_info.value) + + # ========================================================================= + # Test: Invalid auth configuration + # ========================================================================= + def test_auth_config_raises_error_for_unsupported_string( + self, + temp_config_file, + ): + """Test that auth_config raises ValueError for unsupported string.""" + config_path = temp_config_file( + {"agent": "path/to/agent.py", "auth": "oauth2"} # Not supported + ) + graph_config = GraphConfig(path=config_path) + + with pytest.raises(ValueError) as exc_info: + graph_config.auth_config() + + assert "Unsupported auth method" in str(exc_info.value) + + def test_auth_config_raises_error_for_invalid_type( + self, + temp_config_file, + ): + """Test that auth_config raises ValueError for invalid auth type.""" + config_path = temp_config_file( + {"agent": "path/to/agent.py", "auth": 123} # Invalid type + ) + graph_config = GraphConfig(path=config_path) + + with pytest.raises(ValueError) as exc_info: + graph_config.auth_config() + + assert "Unsupported auth method" in str(exc_info.value) + + def test_auth_config_raises_error_for_list_type( + self, + temp_config_file, + ): + """Test that auth_config raises ValueError for list auth type.""" + config_path = temp_config_file( + {"agent": "path/to/agent.py", "auth": ["jwt", "oauth2"]} # Invalid type + ) + graph_config = GraphConfig(path=config_path) + + with pytest.raises(ValueError) as exc_info: + graph_config.auth_config() + + assert "Unsupported auth method" in str(exc_info.value) + + +class TestGraphConfigJwtEnvLoading: + """Test that GraphConfig properly validates JWT configuration during auth_config call.""" + + @pytest.fixture + def temp_config_file(self, tmp_path: Path): + """Create a temporary config file and return a function to write to it.""" + + def _create_config(config_data: dict) -> str: + config_path = tmp_path / "agentflow.json" + with open(config_path, "w") as f: + json.dump(config_data, f) + return str(config_path) + + return _create_config + + def test_jwt_env_vars_are_validated_at_auth_config_call_time( + self, + temp_config_file, + ): + """Test that JWT env vars are validated when auth_config is called, not at init.""" + config_path = temp_config_file({"agent": "path/to/agent.py", "auth": "jwt"}) + + # GraphConfig init should not fail even without env vars + with patch.dict(os.environ, {}, clear=True): + graph_config = GraphConfig(path=config_path) + # This should succeed - no validation at init time + + # But auth_config should fail + with patch.dict(os.environ, {}, clear=True): + with pytest.raises(ValueError): + graph_config.auth_config() + + def test_jwt_works_when_env_vars_set_after_init( + self, + temp_config_file, + ): + """Test that JWT works when env vars are set after GraphConfig init.""" + config_path = temp_config_file({"agent": "path/to/agent.py", "auth": "jwt"}) + + # Init without env vars + with patch.dict(os.environ, {}, clear=True): + graph_config = GraphConfig(path=config_path) + + # Set env vars after init + with patch.dict( + os.environ, + { + "JWT_SECRET_KEY": "late-set-secret", + "JWT_ALGORITHM": "HS256", + }, + ): + result = graph_config.auth_config() + assert result == {"method": "jwt"} diff --git a/tests/unit_tests/auth/test_jwt_auth.py b/tests/unit_tests/auth/test_jwt_auth.py new file mode 100644 index 0000000..898188d --- /dev/null +++ b/tests/unit_tests/auth/test_jwt_auth.py @@ -0,0 +1,679 @@ +""" +Comprehensive unit tests for JwtAuth class. + +These tests cover all edge cases and scenarios for JWT authentication: +- Null credentials +- Missing JWT configuration (secret key, algorithm) +- Expired tokens +- Invalid/malformed tokens +- Valid tokens without user_id +- Valid tokens with user_id (successful auth) +- Bearer prefix handling +- WWW-Authenticate header setting +""" + +import os +from datetime import datetime, timedelta, timezone +from unittest.mock import MagicMock, patch + +import jwt +import pytest +from fastapi import Response +from fastapi.security import HTTPAuthorizationCredentials + +from agentflow_cli.src.app.core.auth.jwt_auth import JwtAuth +from agentflow_cli.src.app.core.exceptions.user_exception import UserAccountError + + +# Test constants +TEST_SECRET_KEY = "test-super-secret-key-for-testing-purposes" +TEST_ALGORITHM = "HS256" + + +class TestJwtAuth: + """Test suite for JwtAuth.authenticate method.""" + + @pytest.fixture + def jwt_auth(self) -> JwtAuth: + """Create a JwtAuth instance for testing.""" + return JwtAuth() + + @pytest.fixture + def mock_response(self) -> Response: + """Create a mock FastAPI Response object.""" + return Response() + + @pytest.fixture + def valid_token_payload(self) -> dict: + """Create a valid JWT payload with user_id.""" + return { + "user_id": "user-123", + "email": "test@example.com", + "exp": datetime.now(timezone.utc) + timedelta(hours=1), + "iat": datetime.now(timezone.utc), + } + + @pytest.fixture + def jwt_env_vars(self): + """Set up JWT environment variables for tests.""" + with patch.dict( + os.environ, + { + "JWT_SECRET_KEY": TEST_SECRET_KEY, + "JWT_ALGORITHM": TEST_ALGORITHM, + }, + ): + yield + + def create_token( + self, + payload: dict, + secret: str = TEST_SECRET_KEY, + algorithm: str = TEST_ALGORITHM, + ) -> str: + """Helper method to create a JWT token.""" + return jwt.encode(payload, secret, algorithm=algorithm) + + def create_credentials(self, token: str) -> HTTPAuthorizationCredentials: + """Helper method to create HTTPAuthorizationCredentials.""" + return HTTPAuthorizationCredentials(scheme="Bearer", credentials=token) + + # ========================================================================= + # Test: Null credentials + # ========================================================================= + def test_authenticate_with_null_credentials_raises_error( + self, + jwt_auth: JwtAuth, + mock_response: Response, + ): + """Test that null credentials raise UserAccountError with REVOKED_TOKEN.""" + with pytest.raises(UserAccountError) as exc_info: + jwt_auth.authenticate(mock_response, None) + + assert exc_info.value.error_code == "REVOKED_TOKEN" + assert "Invalid token" in exc_info.value.message + assert exc_info.value.status_code == 403 + + # ========================================================================= + # Test: Missing JWT settings + # ========================================================================= + def test_authenticate_missing_jwt_secret_key_raises_error( + self, + jwt_auth: JwtAuth, + mock_response: Response, + ): + """Test that missing JWT_SECRET_KEY raises UserAccountError.""" + with patch.dict( + os.environ, + {"JWT_ALGORITHM": TEST_ALGORITHM}, + clear=True, + ): + credentials = self.create_credentials("some-token") + + with pytest.raises(UserAccountError) as exc_info: + jwt_auth.authenticate(mock_response, credentials) + + assert exc_info.value.error_code == "JWT_SETTINGS_NOT_CONFIGURED" + assert "JWT settings are not configured" in exc_info.value.message + + def test_authenticate_missing_jwt_algorithm_raises_error( + self, + jwt_auth: JwtAuth, + mock_response: Response, + ): + """Test that missing JWT_ALGORITHM raises UserAccountError.""" + with patch.dict( + os.environ, + {"JWT_SECRET_KEY": TEST_SECRET_KEY}, + clear=True, + ): + credentials = self.create_credentials("some-token") + + with pytest.raises(UserAccountError) as exc_info: + jwt_auth.authenticate(mock_response, credentials) + + assert exc_info.value.error_code == "JWT_SETTINGS_NOT_CONFIGURED" + assert "JWT settings are not configured" in exc_info.value.message + + def test_authenticate_missing_both_jwt_settings_raises_error( + self, + jwt_auth: JwtAuth, + mock_response: Response, + ): + """Test that missing both JWT settings raises UserAccountError.""" + with patch.dict(os.environ, {}, clear=True): + credentials = self.create_credentials("some-token") + + with pytest.raises(UserAccountError) as exc_info: + jwt_auth.authenticate(mock_response, credentials) + + assert exc_info.value.error_code == "JWT_SETTINGS_NOT_CONFIGURED" + + # ========================================================================= + # Test: Expired token + # ========================================================================= + def test_authenticate_expired_token_raises_error( + self, + jwt_auth: JwtAuth, + mock_response: Response, + jwt_env_vars, + ): + """Test that expired token raises UserAccountError with EXPIRED_TOKEN.""" + expired_payload = { + "user_id": "user-123", + "exp": datetime.now(timezone.utc) - timedelta(hours=1), # Expired 1 hour ago + "iat": datetime.now(timezone.utc) - timedelta(hours=2), + } + token = self.create_token(expired_payload) + credentials = self.create_credentials(token) + + with pytest.raises(UserAccountError) as exc_info: + jwt_auth.authenticate(mock_response, credentials) + + assert exc_info.value.error_code == "EXPIRED_TOKEN" + assert "Token has expired" in exc_info.value.message + + # ========================================================================= + # Test: Invalid token + # ========================================================================= + def test_authenticate_malformed_token_raises_error( + self, + jwt_auth: JwtAuth, + mock_response: Response, + jwt_env_vars, + ): + """Test that malformed token raises UserAccountError with INVALID_TOKEN.""" + credentials = self.create_credentials("not-a-valid-jwt-token") + + with pytest.raises(UserAccountError) as exc_info: + jwt_auth.authenticate(mock_response, credentials) + + assert exc_info.value.error_code == "INVALID_TOKEN" + assert "Invalid token" in exc_info.value.message + + def test_authenticate_token_with_wrong_secret_raises_error( + self, + jwt_auth: JwtAuth, + mock_response: Response, + jwt_env_vars, + ): + """Test that token signed with wrong secret raises UserAccountError.""" + payload = { + "user_id": "user-123", + "exp": datetime.now(timezone.utc) + timedelta(hours=1), + } + # Sign with a different secret + token = self.create_token(payload, secret="wrong-secret-key") + credentials = self.create_credentials(token) + + with pytest.raises(UserAccountError) as exc_info: + jwt_auth.authenticate(mock_response, credentials) + + assert exc_info.value.error_code == "INVALID_TOKEN" + + def test_authenticate_token_with_wrong_algorithm_raises_error( + self, + jwt_auth: JwtAuth, + mock_response: Response, + ): + """Test that token signed with wrong algorithm raises UserAccountError.""" + with patch.dict( + os.environ, + { + "JWT_SECRET_KEY": TEST_SECRET_KEY, + "JWT_ALGORITHM": "HS384", # Different from token's HS256 + }, + ): + payload = { + "user_id": "user-123", + "exp": datetime.now(timezone.utc) + timedelta(hours=1), + } + # Token signed with HS256 but server expects HS384 + token = self.create_token(payload, algorithm="HS256") + credentials = self.create_credentials(token) + + with pytest.raises(UserAccountError) as exc_info: + jwt_auth.authenticate(mock_response, credentials) + + assert exc_info.value.error_code == "INVALID_TOKEN" + + def test_authenticate_empty_token_raises_error( + self, + jwt_auth: JwtAuth, + mock_response: Response, + jwt_env_vars, + ): + """Test that empty token raises UserAccountError.""" + credentials = self.create_credentials("") + + with pytest.raises(UserAccountError) as exc_info: + jwt_auth.authenticate(mock_response, credentials) + + assert exc_info.value.error_code == "INVALID_TOKEN" + + # ========================================================================= + # Test: Token without user_id + # ========================================================================= + def test_authenticate_token_without_user_id_raises_error( + self, + jwt_auth: JwtAuth, + mock_response: Response, + jwt_env_vars, + ): + """Test that valid token without user_id raises UserAccountError.""" + payload_without_user_id = { + "email": "test@example.com", + "exp": datetime.now(timezone.utc) + timedelta(hours=1), + "iat": datetime.now(timezone.utc), + } + token = self.create_token(payload_without_user_id) + credentials = self.create_credentials(token) + + with pytest.raises(UserAccountError) as exc_info: + jwt_auth.authenticate(mock_response, credentials) + + assert exc_info.value.error_code == "INVALID_TOKEN" + assert "user_id missing" in exc_info.value.message + + # ========================================================================= + # Test: Successful authentication + # ========================================================================= + def test_authenticate_valid_token_returns_decoded_payload( + self, + jwt_auth: JwtAuth, + mock_response: Response, + jwt_env_vars, + valid_token_payload: dict, + ): + """Test that valid token with user_id returns decoded payload.""" + token = self.create_token(valid_token_payload) + credentials = self.create_credentials(token) + + result = jwt_auth.authenticate(mock_response, credentials) + + assert result is not None + assert result["user_id"] == "user-123" + assert result["email"] == "test@example.com" + + def test_authenticate_returns_all_custom_claims( + self, + jwt_auth: JwtAuth, + mock_response: Response, + jwt_env_vars, + ): + """Test that all custom claims in token are returned.""" + payload = { + "user_id": "user-456", + "email": "custom@example.com", + "role": "admin", + "permissions": ["read", "write", "delete"], + "organization_id": "org-789", + "exp": datetime.now(timezone.utc) + timedelta(hours=1), + } + token = self.create_token(payload) + credentials = self.create_credentials(token) + + result = jwt_auth.authenticate(mock_response, credentials) + + assert result["user_id"] == "user-456" + assert result["email"] == "custom@example.com" + assert result["role"] == "admin" + assert result["permissions"] == ["read", "write", "delete"] + assert result["organization_id"] == "org-789" + + # ========================================================================= + # Test: Bearer prefix handling + # ========================================================================= + def test_authenticate_strips_bearer_prefix_lowercase( + self, + jwt_auth: JwtAuth, + mock_response: Response, + jwt_env_vars, + valid_token_payload: dict, + ): + """Test that 'bearer ' prefix (lowercase) is stripped from token.""" + actual_token = self.create_token(valid_token_payload) + token_with_prefix = f"bearer {actual_token}" + credentials = self.create_credentials(token_with_prefix) + + result = jwt_auth.authenticate(mock_response, credentials) + + assert result is not None + assert result["user_id"] == "user-123" + + def test_authenticate_strips_bearer_prefix_uppercase( + self, + jwt_auth: JwtAuth, + mock_response: Response, + jwt_env_vars, + valid_token_payload: dict, + ): + """Test that 'Bearer ' prefix (capitalized) is stripped from token.""" + actual_token = self.create_token(valid_token_payload) + token_with_prefix = f"Bearer {actual_token}" + credentials = self.create_credentials(token_with_prefix) + + result = jwt_auth.authenticate(mock_response, credentials) + + assert result is not None + assert result["user_id"] == "user-123" + + def test_authenticate_strips_bearer_prefix_mixed_case( + self, + jwt_auth: JwtAuth, + mock_response: Response, + jwt_env_vars, + valid_token_payload: dict, + ): + """Test that 'BEARER ' prefix (mixed case) is stripped from token.""" + actual_token = self.create_token(valid_token_payload) + token_with_prefix = f"BEARER {actual_token}" + credentials = self.create_credentials(token_with_prefix) + + result = jwt_auth.authenticate(mock_response, credentials) + + assert result is not None + assert result["user_id"] == "user-123" + + def test_authenticate_works_without_bearer_prefix( + self, + jwt_auth: JwtAuth, + mock_response: Response, + jwt_env_vars, + valid_token_payload: dict, + ): + """Test that token without bearer prefix still works.""" + token = self.create_token(valid_token_payload) + credentials = self.create_credentials(token) + + result = jwt_auth.authenticate(mock_response, credentials) + + assert result is not None + assert result["user_id"] == "user-123" + + # ========================================================================= + # Test: WWW-Authenticate header + # ========================================================================= + def test_authenticate_sets_www_authenticate_header( + self, + jwt_auth: JwtAuth, + mock_response: Response, + jwt_env_vars, + valid_token_payload: dict, + ): + """Test that WWW-Authenticate header is set on successful auth.""" + token = self.create_token(valid_token_payload) + credentials = self.create_credentials(token) + + jwt_auth.authenticate(mock_response, credentials) + + assert "WWW-Authenticate" in mock_response.headers + assert mock_response.headers["WWW-Authenticate"] == 'Bearer realm="auth_required"' + + # ========================================================================= + # Test: Edge cases + # ========================================================================= + def test_authenticate_with_minimal_valid_token( + self, + jwt_auth: JwtAuth, + mock_response: Response, + jwt_env_vars, + ): + """Test authentication with minimal valid token (just user_id).""" + minimal_payload = {"user_id": "minimal-user"} + token = self.create_token(minimal_payload) + credentials = self.create_credentials(token) + + result = jwt_auth.authenticate(mock_response, credentials) + + assert result is not None + assert result["user_id"] == "minimal-user" + + def test_authenticate_with_numeric_user_id( + self, + jwt_auth: JwtAuth, + mock_response: Response, + jwt_env_vars, + ): + """Test that numeric user_id in token works correctly.""" + payload = { + "user_id": 12345, + "exp": datetime.now(timezone.utc) + timedelta(hours=1), + } + token = self.create_token(payload) + credentials = self.create_credentials(token) + + result = jwt_auth.authenticate(mock_response, credentials) + + assert result is not None + assert result["user_id"] == 12345 + + def test_authenticate_with_uuid_user_id( + self, + jwt_auth: JwtAuth, + mock_response: Response, + jwt_env_vars, + ): + """Test that UUID user_id in token works correctly.""" + uuid_user_id = "550e8400-e29b-41d4-a716-446655440000" + payload = { + "user_id": uuid_user_id, + "exp": datetime.now(timezone.utc) + timedelta(hours=1), + } + token = self.create_token(payload) + credentials = self.create_credentials(token) + + result = jwt_auth.authenticate(mock_response, credentials) + + assert result is not None + assert result["user_id"] == uuid_user_id + + def test_authenticate_token_about_to_expire_still_valid( + self, + jwt_auth: JwtAuth, + mock_response: Response, + jwt_env_vars, + ): + """Test that token about to expire (in 1 second) is still valid.""" + payload = { + "user_id": "user-123", + "exp": datetime.now(timezone.utc) + timedelta(seconds=30), + } + token = self.create_token(payload) + credentials = self.create_credentials(token) + + result = jwt_auth.authenticate(mock_response, credentials) + + assert result is not None + assert result["user_id"] == "user-123" + + def test_authenticate_with_special_characters_in_user_id( + self, + jwt_auth: JwtAuth, + mock_response: Response, + jwt_env_vars, + ): + """Test that special characters in user_id work correctly.""" + special_user_id = "user+test@example.com" + payload = { + "user_id": special_user_id, + "exp": datetime.now(timezone.utc) + timedelta(hours=1), + } + token = self.create_token(payload) + credentials = self.create_credentials(token) + + result = jwt_auth.authenticate(mock_response, credentials) + + assert result is not None + assert result["user_id"] == special_user_id + + # ========================================================================= + # Test: Different algorithms + # ========================================================================= + def test_authenticate_with_hs384_algorithm( + self, + jwt_auth: JwtAuth, + mock_response: Response, + ): + """Test authentication with HS384 algorithm.""" + with patch.dict( + os.environ, + { + "JWT_SECRET_KEY": TEST_SECRET_KEY, + "JWT_ALGORITHM": "HS384", + }, + ): + payload = { + "user_id": "user-hs384", + "exp": datetime.now(timezone.utc) + timedelta(hours=1), + } + token = self.create_token(payload, algorithm="HS384") + credentials = self.create_credentials(token) + + result = jwt_auth.authenticate(mock_response, credentials) + + assert result is not None + assert result["user_id"] == "user-hs384" + + def test_authenticate_with_hs512_algorithm( + self, + jwt_auth: JwtAuth, + mock_response: Response, + ): + """Test authentication with HS512 algorithm.""" + with patch.dict( + os.environ, + { + "JWT_SECRET_KEY": TEST_SECRET_KEY, + "JWT_ALGORITHM": "HS512", + }, + ): + payload = { + "user_id": "user-hs512", + "exp": datetime.now(timezone.utc) + timedelta(hours=1), + } + token = self.create_token(payload, algorithm="HS512") + credentials = self.create_credentials(token) + + result = jwt_auth.authenticate(mock_response, credentials) + + assert result is not None + assert result["user_id"] == "user-hs512" + + # ========================================================================= + # Test: Logger is called on InvalidTokenError + # ========================================================================= + def test_authenticate_logs_invalid_token_error( + self, + jwt_auth: JwtAuth, + mock_response: Response, + jwt_env_vars, + ): + """Test that logger.exception is called when InvalidTokenError occurs.""" + with patch("agentflow_cli.src.app.core.auth.jwt_auth.logger") as mock_logger: + credentials = self.create_credentials("invalid-token") + + with pytest.raises(UserAccountError): + jwt_auth.authenticate(mock_response, credentials) + + mock_logger.exception.assert_called_once() + call_args = mock_logger.exception.call_args + assert "JWT AUTH ERROR" in call_args[0][0] + + +class TestJwtAuthIntegration: + """Integration tests for JwtAuth with real JWT encoding/decoding.""" + + @pytest.fixture + def jwt_auth(self) -> JwtAuth: + """Create a JwtAuth instance for testing.""" + return JwtAuth() + + @pytest.fixture + def mock_response(self) -> Response: + """Create a mock FastAPI Response object.""" + return Response() + + def test_full_authentication_flow(self, jwt_auth: JwtAuth, mock_response: Response): + """Test complete authentication flow from token creation to validation.""" + secret = "integration-test-secret-key" + algorithm = "HS256" + + with patch.dict( + os.environ, + { + "JWT_SECRET_KEY": secret, + "JWT_ALGORITHM": algorithm, + }, + ): + # Create a realistic token payload + # Note: We skip 'aud' claim because PyJWT validates it by default + # and the current JwtAuth implementation doesn't configure audience + payload = { + "user_id": "user-integration-test", + "email": "integration@test.com", + "name": "Integration Test User", + "role": "developer", + "iat": datetime.now(timezone.utc), + "exp": datetime.now(timezone.utc) + timedelta(hours=24), + "iss": "test-issuer", + } + + # Encode token + token = jwt.encode(payload, secret, algorithm=algorithm) + credentials = HTTPAuthorizationCredentials(scheme="Bearer", credentials=token) + + # Authenticate + result = jwt_auth.authenticate(mock_response, credentials) + + # Verify all claims are returned + assert result["user_id"] == "user-integration-test" + assert result["email"] == "integration@test.com" + assert result["name"] == "Integration Test User" + assert result["role"] == "developer" + assert result["iss"] == "test-issuer" + + # Verify header is set + assert mock_response.headers["WWW-Authenticate"] == 'Bearer realm="auth_required"' + + def test_token_roundtrip_with_complex_payload( + self, + jwt_auth: JwtAuth, + mock_response: Response, + ): + """Test that complex payloads survive the encode/decode roundtrip.""" + secret = "complex-payload-test-secret" + algorithm = "HS256" + + with patch.dict( + os.environ, + { + "JWT_SECRET_KEY": secret, + "JWT_ALGORITHM": algorithm, + }, + ): + complex_payload = { + "user_id": "complex-user", + "metadata": { + "nested": { + "deeply": "nested value", + }, + }, + "tags": ["tag1", "tag2", "tag3"], + "count": 42, + "active": True, + "nullable": None, + "exp": datetime.now(timezone.utc) + timedelta(hours=1), + } + + token = jwt.encode(complex_payload, secret, algorithm=algorithm) + credentials = HTTPAuthorizationCredentials(scheme="Bearer", credentials=token) + + result = jwt_auth.authenticate(mock_response, credentials) + + assert result["user_id"] == "complex-user" + assert result["metadata"]["nested"]["deeply"] == "nested value" + assert result["tags"] == ["tag1", "tag2", "tag3"] + assert result["count"] == 42 + assert result["active"] is True + assert result["nullable"] is None