From d13a7dc0332b290b059d388b928df5b08686af12 Mon Sep 17 00:00:00 2001 From: Shudipto Trafder Date: Mon, 1 Dec 2025 21:37:03 +0600 Subject: [PATCH 1/2] feat: Pass request and response objects to authentication methods for improved handling --- agentflow_cli/src/app/core/auth/auth_backend.py | 11 ++++++++--- agentflow_cli/src/app/core/auth/base_auth.py | 7 +++++-- agentflow_cli/src/app/core/auth/jwt_auth.py | 10 +++++++--- 3 files changed, 20 insertions(+), 8 deletions(-) diff --git a/agentflow_cli/src/app/core/auth/auth_backend.py b/agentflow_cli/src/app/core/auth/auth_backend.py index 91fc499..2a50475 100644 --- a/agentflow_cli/src/app/core/auth/auth_backend.py +++ b/agentflow_cli/src/app/core/auth/auth_backend.py @@ -1,6 +1,6 @@ from typing import Any -from fastapi import Depends, Response +from fastapi import Depends, Request, Response from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer from injectq.integrations import InjectAPI @@ -10,7 +10,8 @@ def verify_current_user( - res: Response, + request: Request, + response: Response, credential: HTTPAuthorizationCredentials = Depends( HTTPBearer(auto_error=False), ), @@ -27,7 +28,11 @@ def verify_current_user( logger.error("Auth backend is not configured") return user - user: dict | None = auth_backend.authenticate(res, credential) + user: dict | None = auth_backend.authenticate( + request, + response, + credential, + ) if user and "user_id" not in user: logger.error("Authentication failed: 'user_id' not found in user info") return user or {} diff --git a/agentflow_cli/src/app/core/auth/base_auth.py b/agentflow_cli/src/app/core/auth/base_auth.py index 8fb0323..6102bd8 100644 --- a/agentflow_cli/src/app/core/auth/base_auth.py +++ b/agentflow_cli/src/app/core/auth/base_auth.py @@ -1,14 +1,17 @@ from abc import ABC, abstractmethod from typing import Any -from fastapi import Response +from fastapi import Request, Response from fastapi.security import HTTPAuthorizationCredentials class BaseAuth(ABC): @abstractmethod def authenticate( - self, res: Response, credential: HTTPAuthorizationCredentials + self, + request: Request, + response: Response, + credential: HTTPAuthorizationCredentials, ) -> dict[str, Any] | None: """Authenticate the user based on the provided credentials. IT should return an empty dict if no authentication is required. diff --git a/agentflow_cli/src/app/core/auth/jwt_auth.py b/agentflow_cli/src/app/core/auth/jwt_auth.py index 020fea0..7937666 100644 --- a/agentflow_cli/src/app/core/auth/jwt_auth.py +++ b/agentflow_cli/src/app/core/auth/jwt_auth.py @@ -2,7 +2,7 @@ from typing import Any import jwt -from fastapi import Response +from fastapi import Request, Response from fastapi.security import HTTPAuthorizationCredentials from agentflow_cli.src.app.core import logger @@ -12,7 +12,10 @@ class JwtAuth(BaseAuth): def authenticate( - self, res: Response, credential: HTTPAuthorizationCredentials + self, + request: Request, + response: Response, + credential: HTTPAuthorizationCredentials, ) -> dict[str, Any] | None: """No authentication is required, so return None.""" """ @@ -72,7 +75,8 @@ def authenticate( message="Invalid token, please login again", error_code="INVALID_TOKEN", ) - res.headers["WWW-Authenticate"] = 'Bearer realm="auth_required"' + + response.headers["WWW-Authenticate"] = 'Bearer realm="auth_required"' # check if user_id exists in the token if "user_id" not in decoded_token: From 81c5a74a10e373f911f3c2b132ea397be15d3ab7 Mon Sep 17 00:00:00 2001 From: Shudipto Trafder Date: Mon, 1 Dec 2025 21:53:33 +0600 Subject: [PATCH 2/2] feat: Update authentication methods to accept request objects for improved handling --- tests/unit_tests/auth/test_auth_backend.py | 39 +++++++++----- tests/unit_tests/auth/test_jwt_auth.py | 63 +++++++++++----------- 2 files changed, 58 insertions(+), 44 deletions(-) diff --git a/tests/unit_tests/auth/test_auth_backend.py b/tests/unit_tests/auth/test_auth_backend.py index 9f71c63..025777e 100644 --- a/tests/unit_tests/auth/test_auth_backend.py +++ b/tests/unit_tests/auth/test_auth_backend.py @@ -9,7 +9,7 @@ from unittest.mock import MagicMock, patch import pytest -from fastapi import Response +from fastapi import Request, Response from fastapi.security import HTTPAuthorizationCredentials from agentflow_cli.src.app.core.auth.auth_backend import verify_current_user @@ -25,7 +25,7 @@ def __init__(self, return_value: dict[str, Any] | None = None, raise_exception: self._raise_exception = raise_exception def authenticate( - self, res: Response, credential: HTTPAuthorizationCredentials + self, request: Request | None, res: Response, credential: HTTPAuthorizationCredentials ) -> dict[str, Any] | None: if self._raise_exception: raise ValueError("Authentication failed") @@ -69,7 +69,8 @@ def test_returns_empty_dict_when_no_auth_backend_configured( mock_auth_backend = MockBaseAuth(return_value={"user_id": "123"}) result = verify_current_user( - res=mock_response, + request=None, + response=mock_response, credential=mock_credentials, config=mock_graph_config_no_auth, auth_backend=mock_auth_backend, @@ -86,7 +87,8 @@ def test_returns_empty_dict_when_auth_backend_is_none( """Test that empty dict is returned when auth_backend is None.""" with patch("agentflow_cli.src.app.core.auth.auth_backend.logger") as mock_logger: result = verify_current_user( - res=mock_response, + request=None, + response=mock_response, credential=mock_credentials, config=mock_graph_config_jwt_auth, auth_backend=None, @@ -110,7 +112,8 @@ def test_returns_user_dict_on_successful_authentication( mock_auth_backend = MockBaseAuth(return_value=expected_user) result = verify_current_user( - res=mock_response, + request=None, + response=mock_response, credential=mock_credentials, config=mock_graph_config_jwt_auth, auth_backend=mock_auth_backend, @@ -129,7 +132,8 @@ def test_returns_empty_dict_when_authenticate_returns_none( mock_auth_backend = MockBaseAuth(return_value=None) result = verify_current_user( - res=mock_response, + request=None, + response=mock_response, credential=mock_credentials, config=mock_graph_config_jwt_auth, auth_backend=mock_auth_backend, @@ -149,7 +153,8 @@ def test_logs_error_when_user_dict_missing_user_id( with patch("agentflow_cli.src.app.core.auth.auth_backend.logger") as mock_logger: result = verify_current_user( - res=mock_response, + request=None, + response=mock_response, credential=mock_credentials, config=mock_graph_config_jwt_auth, auth_backend=mock_auth_backend, @@ -173,7 +178,8 @@ def test_does_not_log_error_when_user_dict_has_user_id( with patch("agentflow_cli.src.app.core.auth.auth_backend.logger") as mock_logger: result = verify_current_user( - res=mock_response, + request=None, + response=mock_response, credential=mock_credentials, config=mock_graph_config_jwt_auth, auth_backend=mock_auth_backend, @@ -193,7 +199,8 @@ def test_does_not_log_error_when_authenticate_returns_empty_dict( with patch("agentflow_cli.src.app.core.auth.auth_backend.logger") as mock_logger: result = verify_current_user( - res=mock_response, + request=None, + response=mock_response, credential=mock_credentials, config=mock_graph_config_jwt_auth, auth_backend=mock_auth_backend, @@ -216,7 +223,8 @@ def test_returns_user_with_numeric_user_id( with patch("agentflow_cli.src.app.core.auth.auth_backend.logger") as mock_logger: result = verify_current_user( - res=mock_response, + request=None, + response=mock_response, credential=mock_credentials, config=mock_graph_config_jwt_auth, auth_backend=mock_auth_backend, @@ -238,7 +246,8 @@ def test_returns_user_with_empty_string_user_id( with patch("agentflow_cli.src.app.core.auth.auth_backend.logger") as mock_logger: result = verify_current_user( - res=mock_response, + request=None, + response=mock_response, credential=mock_credentials, config=mock_graph_config_jwt_auth, auth_backend=mock_auth_backend, @@ -260,7 +269,8 @@ def test_returns_user_with_none_user_id( with patch("agentflow_cli.src.app.core.auth.auth_backend.logger") as mock_logger: result = verify_current_user( - res=mock_response, + request=None, + response=mock_response, credential=mock_credentials, config=mock_graph_config_jwt_auth, auth_backend=mock_auth_backend, @@ -297,11 +307,12 @@ def test_passes_null_credentials_to_auth_backend( mock_auth.authenticate.return_value = {"user_id": "123"} verify_current_user( - res=mock_response, + request=None, + response=mock_response, credential=None, # Null credentials config=mock_graph_config_jwt_auth, auth_backend=mock_auth, ) # Verify authenticate was called with None credentials - mock_auth.authenticate.assert_called_once_with(mock_response, None) + mock_auth.authenticate.assert_called_once_with(None, mock_response, None) diff --git a/tests/unit_tests/auth/test_jwt_auth.py b/tests/unit_tests/auth/test_jwt_auth.py index 898188d..e48a384 100644 --- a/tests/unit_tests/auth/test_jwt_auth.py +++ b/tests/unit_tests/auth/test_jwt_auth.py @@ -13,7 +13,7 @@ """ import os -from datetime import datetime, timedelta, timezone +from datetime import datetime, timedelta, timezone, UTC from unittest.mock import MagicMock, patch import jwt @@ -88,7 +88,7 @@ def test_authenticate_with_null_credentials_raises_error( ): """Test that null credentials raise UserAccountError with REVOKED_TOKEN.""" with pytest.raises(UserAccountError) as exc_info: - jwt_auth.authenticate(mock_response, None) + jwt_auth.authenticate(None, mock_response, None) assert exc_info.value.error_code == "REVOKED_TOKEN" assert "Invalid token" in exc_info.value.message @@ -111,7 +111,7 @@ def test_authenticate_missing_jwt_secret_key_raises_error( credentials = self.create_credentials("some-token") with pytest.raises(UserAccountError) as exc_info: - jwt_auth.authenticate(mock_response, credentials) + jwt_auth.authenticate(None, mock_response, credentials) assert exc_info.value.error_code == "JWT_SETTINGS_NOT_CONFIGURED" assert "JWT settings are not configured" in exc_info.value.message @@ -130,7 +130,7 @@ def test_authenticate_missing_jwt_algorithm_raises_error( credentials = self.create_credentials("some-token") with pytest.raises(UserAccountError) as exc_info: - jwt_auth.authenticate(mock_response, credentials) + jwt_auth.authenticate(None, mock_response, credentials) assert exc_info.value.error_code == "JWT_SETTINGS_NOT_CONFIGURED" assert "JWT settings are not configured" in exc_info.value.message @@ -145,7 +145,7 @@ def test_authenticate_missing_both_jwt_settings_raises_error( credentials = self.create_credentials("some-token") with pytest.raises(UserAccountError) as exc_info: - jwt_auth.authenticate(mock_response, credentials) + jwt_auth.authenticate(None, mock_response, credentials) assert exc_info.value.error_code == "JWT_SETTINGS_NOT_CONFIGURED" @@ -168,7 +168,7 @@ def test_authenticate_expired_token_raises_error( credentials = self.create_credentials(token) with pytest.raises(UserAccountError) as exc_info: - jwt_auth.authenticate(mock_response, credentials) + jwt_auth.authenticate(None, mock_response, credentials) assert exc_info.value.error_code == "EXPIRED_TOKEN" assert "Token has expired" in exc_info.value.message @@ -186,7 +186,7 @@ def test_authenticate_malformed_token_raises_error( credentials = self.create_credentials("not-a-valid-jwt-token") with pytest.raises(UserAccountError) as exc_info: - jwt_auth.authenticate(mock_response, credentials) + jwt_auth.authenticate(None, mock_response, credentials) assert exc_info.value.error_code == "INVALID_TOKEN" assert "Invalid token" in exc_info.value.message @@ -207,7 +207,7 @@ def test_authenticate_token_with_wrong_secret_raises_error( credentials = self.create_credentials(token) with pytest.raises(UserAccountError) as exc_info: - jwt_auth.authenticate(mock_response, credentials) + jwt_auth.authenticate(None, mock_response, credentials) assert exc_info.value.error_code == "INVALID_TOKEN" @@ -233,7 +233,7 @@ def test_authenticate_token_with_wrong_algorithm_raises_error( credentials = self.create_credentials(token) with pytest.raises(UserAccountError) as exc_info: - jwt_auth.authenticate(mock_response, credentials) + jwt_auth.authenticate(None, mock_response, credentials) assert exc_info.value.error_code == "INVALID_TOKEN" @@ -247,7 +247,7 @@ def test_authenticate_empty_token_raises_error( credentials = self.create_credentials("") with pytest.raises(UserAccountError) as exc_info: - jwt_auth.authenticate(mock_response, credentials) + jwt_auth.authenticate(None, mock_response, credentials) assert exc_info.value.error_code == "INVALID_TOKEN" @@ -270,7 +270,7 @@ def test_authenticate_token_without_user_id_raises_error( credentials = self.create_credentials(token) with pytest.raises(UserAccountError) as exc_info: - jwt_auth.authenticate(mock_response, credentials) + jwt_auth.authenticate(None, mock_response, credentials) assert exc_info.value.error_code == "INVALID_TOKEN" assert "user_id missing" in exc_info.value.message @@ -289,7 +289,7 @@ def test_authenticate_valid_token_returns_decoded_payload( token = self.create_token(valid_token_payload) credentials = self.create_credentials(token) - result = jwt_auth.authenticate(mock_response, credentials) + result = jwt_auth.authenticate(None, mock_response, credentials) assert result is not None assert result["user_id"] == "user-123" @@ -313,7 +313,7 @@ def test_authenticate_returns_all_custom_claims( token = self.create_token(payload) credentials = self.create_credentials(token) - result = jwt_auth.authenticate(mock_response, credentials) + result = jwt_auth.authenticate(None, mock_response, credentials) assert result["user_id"] == "user-456" assert result["email"] == "custom@example.com" @@ -336,7 +336,7 @@ def test_authenticate_strips_bearer_prefix_lowercase( token_with_prefix = f"bearer {actual_token}" credentials = self.create_credentials(token_with_prefix) - result = jwt_auth.authenticate(mock_response, credentials) + result = jwt_auth.authenticate(None, mock_response, credentials) assert result is not None assert result["user_id"] == "user-123" @@ -353,7 +353,7 @@ def test_authenticate_strips_bearer_prefix_uppercase( token_with_prefix = f"Bearer {actual_token}" credentials = self.create_credentials(token_with_prefix) - result = jwt_auth.authenticate(mock_response, credentials) + result = jwt_auth.authenticate(None, mock_response, credentials) assert result is not None assert result["user_id"] == "user-123" @@ -370,7 +370,7 @@ def test_authenticate_strips_bearer_prefix_mixed_case( token_with_prefix = f"BEARER {actual_token}" credentials = self.create_credentials(token_with_prefix) - result = jwt_auth.authenticate(mock_response, credentials) + result = jwt_auth.authenticate(None, mock_response, credentials) assert result is not None assert result["user_id"] == "user-123" @@ -386,7 +386,7 @@ def test_authenticate_works_without_bearer_prefix( token = self.create_token(valid_token_payload) credentials = self.create_credentials(token) - result = jwt_auth.authenticate(mock_response, credentials) + result = jwt_auth.authenticate(None, mock_response, credentials) assert result is not None assert result["user_id"] == "user-123" @@ -404,8 +404,7 @@ def test_authenticate_sets_www_authenticate_header( """Test that WWW-Authenticate header is set on successful auth.""" token = self.create_token(valid_token_payload) credentials = self.create_credentials(token) - - jwt_auth.authenticate(mock_response, credentials) + jwt_auth.authenticate(None, mock_response, credentials) assert "WWW-Authenticate" in mock_response.headers assert mock_response.headers["WWW-Authenticate"] == 'Bearer realm="auth_required"' @@ -424,7 +423,7 @@ def test_authenticate_with_minimal_valid_token( token = self.create_token(minimal_payload) credentials = self.create_credentials(token) - result = jwt_auth.authenticate(mock_response, credentials) + result = jwt_auth.authenticate(None, mock_response, credentials) assert result is not None assert result["user_id"] == "minimal-user" @@ -443,7 +442,7 @@ def test_authenticate_with_numeric_user_id( token = self.create_token(payload) credentials = self.create_credentials(token) - result = jwt_auth.authenticate(mock_response, credentials) + result = jwt_auth.authenticate(None, mock_response, credentials) assert result is not None assert result["user_id"] == 12345 @@ -458,12 +457,12 @@ def test_authenticate_with_uuid_user_id( uuid_user_id = "550e8400-e29b-41d4-a716-446655440000" payload = { "user_id": uuid_user_id, - "exp": datetime.now(timezone.utc) + timedelta(hours=1), + "exp": datetime.now(UTC) + timedelta(hours=1), } token = self.create_token(payload) credentials = self.create_credentials(token) - result = jwt_auth.authenticate(mock_response, credentials) + result = jwt_auth.authenticate(None, mock_response, credentials) assert result is not None assert result["user_id"] == uuid_user_id @@ -482,7 +481,7 @@ def test_authenticate_token_about_to_expire_still_valid( token = self.create_token(payload) credentials = self.create_credentials(token) - result = jwt_auth.authenticate(mock_response, credentials) + result = jwt_auth.authenticate(None, mock_response, credentials) assert result is not None assert result["user_id"] == "user-123" @@ -502,7 +501,11 @@ def test_authenticate_with_special_characters_in_user_id( token = self.create_token(payload) credentials = self.create_credentials(token) - result = jwt_auth.authenticate(mock_response, credentials) + result = jwt_auth.authenticate( + None, + mock_response, + credentials, + ) assert result is not None assert result["user_id"] == special_user_id @@ -530,7 +533,7 @@ def test_authenticate_with_hs384_algorithm( token = self.create_token(payload, algorithm="HS384") credentials = self.create_credentials(token) - result = jwt_auth.authenticate(mock_response, credentials) + result = jwt_auth.authenticate(None, mock_response, credentials) assert result is not None assert result["user_id"] == "user-hs384" @@ -555,7 +558,7 @@ def test_authenticate_with_hs512_algorithm( token = self.create_token(payload, algorithm="HS512") credentials = self.create_credentials(token) - result = jwt_auth.authenticate(mock_response, credentials) + result = jwt_auth.authenticate(None, mock_response, credentials) assert result is not None assert result["user_id"] == "user-hs512" @@ -574,7 +577,7 @@ def test_authenticate_logs_invalid_token_error( credentials = self.create_credentials("invalid-token") with pytest.raises(UserAccountError): - jwt_auth.authenticate(mock_response, credentials) + jwt_auth.authenticate(None, mock_response, credentials) mock_logger.exception.assert_called_once() call_args = mock_logger.exception.call_args @@ -624,7 +627,7 @@ def test_full_authentication_flow(self, jwt_auth: JwtAuth, mock_response: Respon credentials = HTTPAuthorizationCredentials(scheme="Bearer", credentials=token) # Authenticate - result = jwt_auth.authenticate(mock_response, credentials) + result = jwt_auth.authenticate(None, mock_response, credentials) # Verify all claims are returned assert result["user_id"] == "user-integration-test" @@ -669,7 +672,7 @@ def test_token_roundtrip_with_complex_payload( token = jwt.encode(complex_payload, secret, algorithm=algorithm) credentials = HTTPAuthorizationCredentials(scheme="Bearer", credentials=token) - result = jwt_auth.authenticate(mock_response, credentials) + result = jwt_auth.authenticate(None, mock_response, credentials) assert result["user_id"] == "complex-user" assert result["metadata"]["nested"]["deeply"] == "nested value"