diff --git a/README.md b/README.md index d0cefb7..74d5ae0 100644 --- a/README.md +++ b/README.md @@ -177,6 +177,12 @@ The SDK handles per-domain OIDC discovery, JWKS fetching, issuer validation, and For more details and examples, see [examples/MultipleCustomDomains.md](examples/MultipleCustomDomains.md). +### 6. Session Expiry from the Upstream IdP + +For enterprise connections, the upstream identity provider can cap how long a user's session lives. When the connection is configured to honor it, Auth0 includes a `session_expiry` claim in the ID token, and the SDK enforces this ceiling on every session read. Once it is reached, `get_user()` and `get_session()` return `None`, and `get_access_token()` raises an `AccessTokenError` with code `session_expired`. If the asserted ceiling is already in the past at login, `complete_interactive_login()` raises a `SessionExpiredError` instead of persisting an already-expired session. + +For more details and examples, see [examples/RetrievingData.md](examples/RetrievingData.md#session-expiry-from-the-upstream-idp). + ## Feedback ### Contributing diff --git a/examples/RetrievingData.md b/examples/RetrievingData.md index 88fb610..d1fcc11 100644 --- a/examples/RetrievingData.md +++ b/examples/RetrievingData.md @@ -70,6 +70,65 @@ access_token = await server_client.get_access_token(store_options=store_options) Read more above in [Configuring the Store](./ConfigureStore.md). +## Session Expiry from the Upstream IdP + +For enterprise connections, the upstream identity provider can impose a ceiling on how long the user's session may live. This ceiling is delivered to the SDK as a `session_expiry` claim (an absolute Unix timestamp, **in seconds**) on the ID token. The SDK reads this value at login, stores it with the session, and enforces it on every subsequent read. + +### Emitting the claim + +The `session_expiry` claim is set by a Post-Login Action on your tenant, and **must** be an absolute Unix timestamp in **seconds**, not milliseconds. For the canonical Action setup, see the [Auth0 documentation](https://auth0.com/docs) (will be adding the link to the session_expiry Action guide once published). + +> [!WARNING] +> `session_expiry` is interpreted as **seconds** since the Unix epoch (per RFC 7519 `NumericDate`). A common mistake is emitting milliseconds (e.g. `getTime()` without `/ 1000`). The SDK rejects implausibly large values (anything at or above `10_000_000_000`, ≈ year 2286) as malformed and treats them as **no ceiling**, so a milliseconds value will silently disable enforcement rather than expiring the session ~55,000 years from now. Always divide by 1000. +> +> Because the claim is authored by your Action (untrusted input), the SDK **fails open** on any malformed value — a non-numeric, zero, negative, boolean, or millisecond value is treated as "no ceiling" and login proceeds normally. Only a clean, future, seconds timestamp is enforced. + +Once the ceiling is reached, the read methods behave as follows: + +- `get_user()` returns `None`, as if no session exists. +- `get_session()` returns `None`, as if no session exists. +- `get_access_token()` raises an `AccessTokenError` with code `session_expired`. + +`get_access_token_for_connection()` (Token Vault) is **not** gated by the session ceiling — connection tokens follow the upstream IdP's own `expires_in`, so they remain retrievable from cache even after the session ceiling has passed. + +```python +from auth0_server_python.error import AccessTokenError, AccessTokenErrorCode + +try: + access_token = await server_client.get_access_token(store_options=store_options) +except AccessTokenError as error: + if error.code == AccessTokenErrorCode.SESSION_EXPIRED: + # The upstream session ceiling has been reached; start a new login. + ... +``` + +When the ceiling is reached, the SDK deletes the stored session before returning, so the next request starts clean. + +If the upstream IdP asserts a ceiling that is already in the past at login time, `complete_interactive_login()` raises a `SessionExpiredError` rather than persisting an already-expired session: + +```python +from auth0_server_python.error import SessionExpiredError + +try: + await server_client.complete_interactive_login(url, store_options=store_options) +except SessionExpiredError: + # The session was already past its ceiling on arrival; start a new login. + ... +``` + +> [!NOTE] +> **Upgrading:** with this feature enabled, `get_user()` and `get_session()` can return `None` for a user who was previously logged in, once the upstream ceiling passes. Applications that assumed these always return a value after login should add a null check and route the user back through login. + +The `session_expiry` value is also surfaced through the user claims, so you can read it without triggering enforcement: + +```python +user = await server_client.get_user(store_options=store_options) +session_expires_at = (user or {}).get("session_expiry") +``` + +> [!NOTE] +> Enforcement applies a small negative leeway (30 seconds) to account for clock skew, so a session is treated as expired slightly before the exact `session_expiry` timestamp. The refresh-token grant preserves the original ceiling - refreshing an access token does not extend the upstream session. + ## Multi-Resource Refresh Tokens (MRRT) Multi-Resource Refresh Tokens allow using a single refresh token to obtain access tokens for multiple audiences, simplifying token management in applications that interact with multiple backend services. diff --git a/src/auth0_server_python/auth_server/server_client.py b/src/auth0_server_python/auth_server/server_client.py index 83c673d..befbcb6 100644 --- a/src/auth0_server_python/auth_server/server_client.py +++ b/src/auth0_server_python/auth_server/server_client.py @@ -56,6 +56,7 @@ MissingTransactionError, OrganizationTokenValidationError, PollingApiError, + SessionExpiredError, StartLinkUserError, ) from auth0_server_python.telemetry import Telemetry @@ -656,7 +657,12 @@ async def complete_interactive_login( # Use the userinfo field from the token_response for user claims user_info = token_response.get("userinfo") user_claims = None + # IPSIE session_expiry ceiling, read from the verified ID token claims. + session_expires_at = None + # ID token `iat`, used to detect a ceiling that is already past at login. + issued_at = None id_token = token_response.get("id_token") + expected_org = transaction_data.organization if not user_info and not id_token and expected_org: @@ -698,6 +704,8 @@ async def complete_interactive_login( validate_org_claims(claims, expected_org) user_claims = UserClaims.parse_obj(claims) + session_expires_at = user_claims.session_expiry + issued_at = claims.get("iat") except ValueError as e: raise ApiError("jwks_key_not_found", str(e)) except jwt.InvalidSignatureError as e: @@ -726,6 +734,11 @@ async def complete_interactive_login( ) + # Refuse to persist a session whose ceiling is already in the past. + if State.is_session_ceiling_in_past(session_expires_at, issued_at): + await self._transaction_store.delete(transaction_identifier, options=store_options) + raise SessionExpiredError() + # Build a token set using the token response data token_set = TokenSet( audience=transaction_data.audience or self.DEFAULT_AUDIENCE_STATE_KEY, @@ -749,7 +762,8 @@ async def complete_interactive_login( domain=origin_domain, internal={ "sid": sid, - "created_at": int(time.time()) + "created_at": int(time.time()), + "session_expires_at": session_expires_at } ) @@ -775,6 +789,23 @@ async def complete_interactive_login( # Methods for retrieving user information, session data, and logout operations. # ============================================================================ + async def _is_session_expired_by_ceiling( + self, state_data_dict: dict, store_options: Optional[dict[str, Any]] = None + ) -> bool: + """ + Enforce the IPSIE session_expiry ceiling on a session read. + + Returns True (and deletes the stored session) when the upstream + IdP-asserted ceiling has been reached. Sessions without a + session_expires_at value are never expired on this basis. + """ + internal = state_data_dict.get("internal") or {} + session_expires_at = internal.get("session_expires_at") + if State.is_session_ceiling_reached(session_expires_at): + await self._state_store.delete(self._state_identifier, options=store_options) + return True + return False + async def get_user(self, store_options: Optional[dict[str, Any]] = None) -> Optional[dict[str, Any]]: """ Retrieves the user from the store, or None if no user found. @@ -801,6 +832,10 @@ async def get_user(self, store_options: Optional[dict[str, Any]] = None) -> Opti if self._normalize_url(session_domain) != self._normalize_url(current_domain): return None + # IPSIE: force re-auth once the upstream IdP session ceiling passes. + if await self._is_session_expired_by_ceiling(state_data, store_options): + return None + return state_data.get("user") return None @@ -830,6 +865,10 @@ async def get_session(self, store_options: Optional[dict[str, Any]] = None) -> O if self._normalize_url(session_domain) != self._normalize_url(current_domain): return None + # IPSIE: force re-auth once the upstream IdP session ceiling passes. + if await self._is_session_expired_by_ceiling(state_data, store_options): + return None + session_data = {k: v for k, v in state_data.items() if k != "internal"} return session_data @@ -1013,6 +1052,12 @@ async def get_access_token( merged_scope = self._merge_scope_with_defaults(scope, audience) + # Once the session ceiling has passed, fail instead of serving or refreshing a token. + internal = (state_data_dict or {}).get("internal") or {} + if State.is_session_ceiling_reached(internal.get("session_expires_at")): + await self._state_store.delete(self._state_identifier, options=store_options) + raise SessionExpiredError() + # Find matching token set token_set = None if state_data_dict and "token_sets" in state_data_dict: diff --git a/src/auth0_server_python/auth_types/__init__.py b/src/auth0_server_python/auth_types/__init__.py index fbfe2fe..1c6ad03 100644 --- a/src/auth0_server_python/auth_types/__init__.py +++ b/src/auth0_server_python/auth_types/__init__.py @@ -5,7 +5,10 @@ from typing import Any, Literal, Optional, Union -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, field_validator + +# Upper bound (Unix seconds) for a plausible session_expiry +SESSION_EXPIRY_MAX_PLAUSIBLE = 10_000_000_000 class UserClaims(BaseModel): @@ -23,10 +26,21 @@ class UserClaims(BaseModel): email_verified: Optional[bool] = None org_id: Optional[str] = None org_name: Optional[str] = None + # IPSIE SL1 claim: upstream IdP-asserted RP session ceiling (Unix seconds). + session_expiry: Optional[int] = None class Config: extra = "allow" # Allow additional fields not defined in the model + @field_validator('session_expiry', mode='before') + @classmethod + def _sanitize_session_expiry(cls, value: Any) -> Optional[int]: + if isinstance(value, bool) or not isinstance(value, int): + return None + if value <= 0 or value >= SESSION_EXPIRY_MAX_PLAUSIBLE: + return None + return value + class TokenSet(BaseModel): """ @@ -55,6 +69,10 @@ class InternalStateData(BaseModel): """ sid: str created_at: int + # IPSIE session_expiry ceiling (Unix seconds), stamped at session creation + # from the ID token's session_expiry claim. None when the upstream IdP did + # not assert one — in which case existing session behavior is unchanged. + session_expires_at: Optional[int] = None class SessionData(BaseModel): diff --git a/src/auth0_server_python/error/__init__.py b/src/auth0_server_python/error/__init__.py index 407435f..3a6d01a 100644 --- a/src/auth0_server_python/error/__init__.py +++ b/src/auth0_server_python/error/__init__.py @@ -198,6 +198,7 @@ class AccessTokenErrorCode: INCORRECT_AUDIENCE = "incorrect_audience" MISSING_SESSION_DOMAIN = "missing_session_domain" DOMAIN_MISMATCH = "domain_mismatch" + SESSION_EXPIRED = "session_expired" class OrganizationTokenValidationError(Auth0Error): @@ -222,6 +223,19 @@ class AccessTokenForConnectionErrorCode: DOMAIN_MISMATCH = "domain_mismatch" +class SessionExpiredError(Auth0Error): + """ + Error raised when a session is rejected at login because its + session_expiry ceiling is already in the past. + """ + code = AccessTokenErrorCode.SESSION_EXPIRED + + def __init__(self, message: Optional[str] = None, cause=None): + super().__init__(message or "The session has expired and the user must re-authenticate.") + self.name = "SessionExpiredError" + self.cause = cause + + class CustomTokenExchangeError(Auth0Error): """ Error raised during custom token exchange operations. diff --git a/src/auth0_server_python/tests/test_server_client.py b/src/auth0_server_python/tests/test_server_client.py index cf4a1d1..05306c9 100644 --- a/src/auth0_server_python/tests/test_server_client.py +++ b/src/auth0_server_python/tests/test_server_client.py @@ -27,6 +27,7 @@ StartInteractiveLoginOptions, StateData, TransactionData, + UserClaims, ) from auth0_server_python.error import ( AccessTokenError, @@ -46,6 +47,7 @@ MissingTransactionError, OrganizationTokenValidationError, PollingApiError, + SessionExpiredError, StartLinkUserError, ) from auth0_server_python.utils import PKCE, State @@ -5183,7 +5185,9 @@ async def _fake_fetch(self, domain): ServerClient._fetch_oidc_metadata = original_fetch -# ORGANIZATIONS SUPPORT TESTS +# ============================================================================= +# Organization and Invitation Tests +# ============================================================================= def _make_org_client(mocker, transaction_data: TransactionData, **extra): """Helper: build a ServerClient with mocked stores and standard JWT mocks.""" @@ -6123,3 +6127,680 @@ async def capture_set(key, value, options=None): await client.start_interactive_login(StartInteractiveLoginOptions()) assert stored_tx.organization == "org_default" + +# ============================================================================= +# IPSIE session_expiry enforcement +# ============================================================================= + + +def test_is_session_ceiling_reached_none_never_expires(): + assert State.is_session_ceiling_reached(None) is False + + +def test_is_session_ceiling_reached_future_and_past(): + now = int(time.time()) + # Comfortably in the future (beyond the leeway window) -> not reached. + assert State.is_session_ceiling_reached(now + 3600) is False + # In the past -> reached. + assert State.is_session_ceiling_reached(now - 10) is True + + +def test_is_session_ceiling_reached_applies_negative_leeway(): + now = int(time.time()) + # Ceiling is 10s away but leeway is 30s, so it's treated as already reached. + assert State.is_session_ceiling_reached(now + 10) is True + + +def test_is_session_ceiling_in_past_none_is_safe_default(): + # No ceiling asserted -> never treated as expired. + assert State.is_session_ceiling_in_past(None, 1893456000) is False + assert State.is_session_ceiling_in_past(None, None) is False + + +def test_is_session_ceiling_in_past_past_ceiling_relative_to_iat(): + iat = 1893456000 + # Ceiling well before iat -> already lapsed at login. + assert State.is_session_ceiling_in_past(iat - 3600, iat) is True + + +def test_is_session_ceiling_in_past_future_ceiling_relative_to_iat(): + iat = 1893456000 + # Ceiling well after iat -> not lapsed. + assert State.is_session_ceiling_in_past(iat + 3600, iat) is False + + +def test_is_session_ceiling_in_past_falls_back_to_now_when_iat_absent(): + now = int(time.time()) + # No iat -> compare against wall-clock now; a past ceiling is lapsed. + assert State.is_session_ceiling_in_past(now - 100, None) is True + + +def test_is_session_ceiling_in_past_leeway_boundary(): + iat = 1893456000 + leeway = State.SESSION_EXPIRY_LEEWAY_SECONDS + # Ceiling exactly at iat + leeway is treated as already lapsed... + assert State.is_session_ceiling_in_past(iat + leeway, iat) is True + # ...one second beyond the leeway window is not. + assert State.is_session_ceiling_in_past(iat + leeway + 1, iat) is False + + +def test_session_expired_error_message_is_generic(): + message = str(SessionExpiredError()) + # States the reason without leaking any timestamps or values. + assert message == "The session has expired and the user must re-authenticate." + assert not any(ch.isdigit() for ch in message) + assert SessionExpiredError().code == AccessTokenErrorCode.SESSION_EXPIRED + + +@pytest.mark.parametrize("value,expected", [ + (1900000000, 1900000000), # plausible seconds -> kept + (1748566800000, None), # milliseconds -> rejected + (10_000_000_000, None), # at the implausible-future bound -> rejected + (0, None), # non-positive -> rejected + (-5, None), # negative -> rejected + (True, None), # bool is not a valid int here -> rejected + ("1748566800", None), # numeric string -> rejected + ("not-a-number", None), # garbage string -> rejected + (1.5, None), # float -> rejected + (None, None), # absent/null -> no ceiling +]) +def test_user_claims_sanitizes_session_expiry(value, expected): + assert UserClaims(sub="u", session_expiry=value).session_expiry == expected + + +def test_user_claims_session_expiry_absent_is_none(): + assert UserClaims(sub="u").session_expiry is None + + +def test_update_state_data_preserves_ceiling_across_refresh(): + now = int(time.time()) + ceiling = now + 3600 + existing_state = { + "refresh_token": "refresh_xyz", + "token_sets": [], + "internal": {"sid": "some_sid", "created_at": now, "session_expires_at": ceiling}, + } + # A refresh-token grant never carries session_expiry; the login ceiling stands. + refresh_response = {"access_token": "new_token", "scope": "openid", "expires_in": 3600} + + updated = State.update_state_data("default", existing_state, refresh_response) + + assert updated["internal"]["session_expires_at"] == ceiling + + +@pytest.mark.asyncio +async def test_get_session_expired_by_ceiling_returns_none_and_deletes(): + now = int(time.time()) + mock_state_store = AsyncMock() + mock_state_store.get.return_value = { + "user": {"sub": "user123"}, + "id_token": "token123", + "internal": {"sid": "some_sid", "created_at": now - 100, "session_expires_at": now - 10}, + } + + client = ServerClient( + domain="auth0.local", + client_id="client_id", + client_secret="client_secret", + transaction_store=AsyncMock(), + state_store=mock_state_store, + secret="some-secret" + ) + + session_data = await client.get_session() + assert session_data is None + mock_state_store.delete.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_get_session_within_ceiling_ok(): + now = int(time.time()) + mock_state_store = AsyncMock() + mock_state_store.get.return_value = { + "user": {"sub": "user123"}, + "id_token": "token123", + "internal": {"sid": "some_sid", "created_at": now, "session_expires_at": now + 3600}, + } + + client = ServerClient( + domain="auth0.local", + client_id="client_id", + client_secret="client_secret", + transaction_store=AsyncMock(), + state_store=mock_state_store, + secret="some-secret" + ) + + session_data = await client.get_session() + assert session_data is not None + assert session_data["user"] == {"sub": "user123"} + mock_state_store.delete.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_get_user_expired_by_ceiling_returns_none_and_deletes(): + now = int(time.time()) + mock_state_store = AsyncMock() + mock_state_store.get.return_value = { + "user": {"sub": "user123"}, + "internal": {"sid": "some_sid", "created_at": now - 100, "session_expires_at": now - 10}, + } + + client = ServerClient( + domain="auth0.local", + client_id="client_id", + client_secret="client_secret", + transaction_store=AsyncMock(), + state_store=mock_state_store, + secret="some-secret" + ) + + user = await client.get_user() + assert user is None + mock_state_store.delete.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_get_user_no_ceiling_unaffected(): + mock_state_store = AsyncMock() + mock_state_store.get.return_value = { + "user": {"sub": "user123"}, + "internal": {"sid": "some_sid", "created_at": int(time.time())}, + } + + client = ServerClient( + domain="auth0.local", + client_id="client_id", + client_secret="client_secret", + transaction_store=AsyncMock(), + state_store=mock_state_store, + secret="some-secret" + ) + + user = await client.get_user() + assert user == {"sub": "user123"} + mock_state_store.delete.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_get_access_token_expired_by_ceiling_raises_without_refresh(mocker): + now = int(time.time()) + mock_state_store = AsyncMock() + mock_state_store.get.return_value = { + "refresh_token": "refresh_xyz", + "token_sets": [ + { + "audience": "default", + "access_token": "cached_token", + "expires_at": now + 500, # still valid, but ceiling overrides + } + ], + "internal": {"sid": "some_sid", "created_at": now - 100, "session_expires_at": now - 10}, + } + + client = ServerClient( + domain="auth0.local", + client_id="client_id", + client_secret="client_secret", + transaction_store=AsyncMock(), + state_store=mock_state_store, + secret="some-secret" + ) + + # If the refresh path is reached, that's a bug — make it explode. + refresh_spy = mocker.patch.object( + client, "get_token_by_refresh_token", new_callable=AsyncMock, + side_effect=AssertionError("refresh must not be attempted after ceiling"), + ) + + with pytest.raises(SessionExpiredError) as exc: + await client.get_access_token() + + assert exc.value.code == AccessTokenErrorCode.SESSION_EXPIRED + refresh_spy.assert_not_awaited() + mock_state_store.delete.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_get_access_token_within_ceiling_serves_cached(): + now = int(time.time()) + mock_state_store = AsyncMock() + mock_state_store.get.return_value = { + "refresh_token": "refresh_xyz", + "token_sets": [ + { + "audience": "default", + "access_token": "cached_token", + "expires_at": now + 500, + } + ], + "internal": {"sid": "some_sid", "created_at": now, "session_expires_at": now + 3600}, + } + + client = ServerClient( + domain="auth0.local", + client_id="client_id", + client_secret="client_secret", + transaction_store=AsyncMock(), + state_store=mock_state_store, + secret="some-secret" + ) + + token = await client.get_access_token() + assert token == "cached_token" + mock_state_store.delete.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_get_access_token_for_connection_not_gated_by_ceiling(): + # Token Vault connection tokens follow the upstream IdP's own expires_in, + # so a passed session ceiling must NOT block or tear down the session here. + now = int(time.time()) + mock_state_store = AsyncMock() + mock_state_store.get.return_value = { + "refresh_token": "refresh_xyz", + "connection_token_sets": [ + { + "connection": "google-oauth2", + "login_hint": "user@example.com", + "access_token": "cached_conn_token", + "expires_at": now + 500, + } + ], + "internal": {"sid": "some_sid", "created_at": now - 100, "session_expires_at": now - 10}, + } + + client = ServerClient( + domain="auth0.local", + client_id="client_id", + client_secret="client_secret", + transaction_store=AsyncMock(), + state_store=mock_state_store, + secret="some-secret" + ) + + token = await client.get_access_token_for_connection({"connection": "google-oauth2"}) + assert token == "cached_conn_token" + mock_state_store.delete.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_complete_interactive_login_rejects_already_expired_ceiling(mocker): + """A session_expiry already in the past at login is rejected, not persisted.""" + iat = int(time.time()) + + mock_tx_store = AsyncMock() + mock_tx_store.get.return_value = TransactionData( + code_verifier="123", + domain="tenant.auth0.com", + ) + mock_state_store = AsyncMock() + + client = ServerClient( + domain="tenant.auth0.com", + client_id="test_client", + client_secret="test_secret", + transaction_store=mock_tx_store, + state_store=mock_state_store, + secret="test_secret_key_32_chars_long!!", + ) + + # Mock OIDC metadata + mocker.patch.object( + client, + "_get_oidc_metadata_cached", + return_value={"issuer": "https://tenant.auth0.com/", "token_endpoint": "https://tenant.auth0.com/token"} + ) + + # Mock JWKS fetch + mocker.patch.object( + client, + "_get_jwks_cached", + return_value={"keys": [{"kty": "RSA", "kid": "test-key"}]} + ) + + # Mock OAuth fetch_token + async_fetch_token = AsyncMock() + async_fetch_token.return_value = { + "access_token": "token123", + "id_token": "id_token_jwt", + "scope": "openid profile" + } + mocker.patch.object(client._oauth, "fetch_token", async_fetch_token) + + # Mock jwt.get_unverified_header + mocker.patch("jwt.get_unverified_header", return_value={"kid": "test-key"}) + + # Mock PyJWK.from_dict + mock_signing_key = mocker.MagicMock() + mock_signing_key.key = "mock_pem_key" + mocker.patch("jwt.PyJWK.from_dict", return_value=mock_signing_key) + + # Mock jwt.decode with a ceiling already in the past relative to iat + mocker.patch("jwt.decode", return_value={ + "sub": "user123", + "iss": "https://tenant.auth0.com/", + "aud": "test_client", + "iat": iat, + "session_expiry": iat - 3600, + }) + + with pytest.raises(SessionExpiredError) as exc: + await client.complete_interactive_login("http://localhost/callback?code=abc&state=xyz") + + assert exc.value.code == AccessTokenErrorCode.SESSION_EXPIRED + # The already-expired session must never be persisted. The transaction is + # cleaned up because its authorization code was already spent and cannot be + # reused — a retry starts a fresh login with a new transaction. + mock_state_store.set.assert_not_awaited() + mock_tx_store.delete.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_complete_interactive_login_future_ceiling_persists(mocker): + """A future session_expiry is stamped on the session and login succeeds.""" + iat = int(time.time()) + ceiling = iat + 3600 + + mock_tx_store = AsyncMock() + mock_tx_store.get.return_value = TransactionData( + code_verifier="123", + domain="tenant.auth0.com", + ) + mock_state_store = AsyncMock() + + client = ServerClient( + domain="tenant.auth0.com", + client_id="test_client", + client_secret="test_secret", + transaction_store=mock_tx_store, + state_store=mock_state_store, + secret="test_secret_key_32_chars_long!!", + ) + + # Mock OIDC metadata + mocker.patch.object( + client, + "_get_oidc_metadata_cached", + return_value={"issuer": "https://tenant.auth0.com/", "token_endpoint": "https://tenant.auth0.com/token"} + ) + + # Mock JWKS fetch + mocker.patch.object( + client, + "_get_jwks_cached", + return_value={"keys": [{"kty": "RSA", "kid": "test-key"}]} + ) + + # Mock OAuth fetch_token + async_fetch_token = AsyncMock() + async_fetch_token.return_value = { + "access_token": "token123", + "id_token": "id_token_jwt", + "scope": "openid profile" + } + mocker.patch.object(client._oauth, "fetch_token", async_fetch_token) + + # Mock jwt.get_unverified_header + mocker.patch("jwt.get_unverified_header", return_value={"kid": "test-key"}) + + # Mock PyJWK.from_dict + mock_signing_key = mocker.MagicMock() + mock_signing_key.key = "mock_pem_key" + mocker.patch("jwt.PyJWK.from_dict", return_value=mock_signing_key) + + # Mock jwt.decode with a ceiling comfortably in the future + mocker.patch("jwt.decode", return_value={ + "sub": "user123", + "iss": "https://tenant.auth0.com/", + "aud": "test_client", + "iat": iat, + "session_expiry": ceiling, + }) + + result = await client.complete_interactive_login("http://localhost/callback?code=abc&state=xyz") + + assert "state_data" in result + mock_state_store.set.assert_awaited_once() + stored_state = mock_state_store.set.call_args.args[1] + assert stored_state.internal.session_expires_at == ceiling + + +@pytest.mark.asyncio +async def test_complete_interactive_login_no_ceiling_persists_normally(mocker): + """No session_expiry claim -> login behaves exactly as before (no ceiling).""" + iat = int(time.time()) + + mock_tx_store = AsyncMock() + mock_tx_store.get.return_value = TransactionData( + code_verifier="123", + domain="tenant.auth0.com", + ) + mock_state_store = AsyncMock() + + client = ServerClient( + domain="tenant.auth0.com", + client_id="test_client", + client_secret="test_secret", + transaction_store=mock_tx_store, + state_store=mock_state_store, + secret="test_secret_key_32_chars_long!!", + ) + + # Mock OIDC metadata + mocker.patch.object( + client, + "_get_oidc_metadata_cached", + return_value={"issuer": "https://tenant.auth0.com/", "token_endpoint": "https://tenant.auth0.com/token"} + ) + + # Mock JWKS fetch + mocker.patch.object( + client, + "_get_jwks_cached", + return_value={"keys": [{"kty": "RSA", "kid": "test-key"}]} + ) + + # Mock OAuth fetch_token + async_fetch_token = AsyncMock() + async_fetch_token.return_value = { + "access_token": "token123", + "id_token": "id_token_jwt", + "scope": "openid profile" + } + mocker.patch.object(client._oauth, "fetch_token", async_fetch_token) + + # Mock jwt.get_unverified_header + mocker.patch("jwt.get_unverified_header", return_value={"kid": "test-key"}) + + # Mock PyJWK.from_dict + mock_signing_key = mocker.MagicMock() + mock_signing_key.key = "mock_pem_key" + mocker.patch("jwt.PyJWK.from_dict", return_value=mock_signing_key) + + # Mock jwt.decode without a session_expiry claim + mocker.patch("jwt.decode", return_value={ + "sub": "user123", + "iss": "https://tenant.auth0.com/", + "aud": "test_client", + "iat": iat, + }) + + result = await client.complete_interactive_login("http://localhost/callback?code=abc&state=xyz") + + assert "state_data" in result + mock_state_store.set.assert_awaited_once() + stored_state = mock_state_store.set.call_args.args[1] + assert stored_state.internal.session_expires_at is None + + +@pytest.mark.asyncio +async def test_complete_interactive_login_ignores_ceiling_from_userinfo(mocker): + """The ceiling is read only from the verified ID token. A session_expiry + present in the unverified userinfo response must NOT be persisted.""" + iat = int(time.time()) + + mock_tx_store = AsyncMock() + mock_tx_store.get.return_value = TransactionData( + code_verifier="123", + domain="tenant.auth0.com", + ) + mock_state_store = AsyncMock() + + client = ServerClient( + domain="tenant.auth0.com", + client_id="test_client", + client_secret="test_secret", + transaction_store=mock_tx_store, + state_store=mock_state_store, + secret="test_secret_key_32_chars_long!!", + ) + + mocker.patch.object( + client, + "_get_oidc_metadata_cached", + return_value={"issuer": "https://tenant.auth0.com/", "token_endpoint": "https://tenant.auth0.com/token"} + ) + + # fetch_token returns a userinfo dict (no id_token), driving the userinfo + # branch. Its session_expiry must be ignored, not stamped on the session. + async_fetch_token = AsyncMock() + async_fetch_token.return_value = { + "access_token": "token123", + "scope": "openid profile", + "userinfo": { + "sub": "user123", + "iat": iat, + "session_expiry": iat + 3600, + }, + } + mocker.patch.object(client._oauth, "fetch_token", async_fetch_token) + + result = await client.complete_interactive_login("http://localhost/callback?code=abc&state=xyz") + + assert "state_data" in result + mock_state_store.set.assert_awaited_once() + stored_state = mock_state_store.set.call_args.args[1] + assert stored_state.internal.session_expires_at is None + + +@pytest.mark.asyncio +async def test_complete_interactive_login_malformed_ceiling_fails_open(mocker): + """A non-numeric session_expiry is treated as no ceiling, never a hard fail.""" + iat = int(time.time()) + + mock_tx_store = AsyncMock() + mock_tx_store.get.return_value = TransactionData( + code_verifier="123", + domain="tenant.auth0.com", + ) + mock_state_store = AsyncMock() + + client = ServerClient( + domain="tenant.auth0.com", + client_id="test_client", + client_secret="test_secret", + transaction_store=mock_tx_store, + state_store=mock_state_store, + secret="test_secret_key_32_chars_long!!", + ) + + mocker.patch.object( + client, + "_get_oidc_metadata_cached", + return_value={"issuer": "https://tenant.auth0.com/", "token_endpoint": "https://tenant.auth0.com/token"} + ) + mocker.patch.object( + client, + "_get_jwks_cached", + return_value={"keys": [{"kty": "RSA", "kid": "test-key"}]} + ) + + async_fetch_token = AsyncMock() + async_fetch_token.return_value = { + "access_token": "token123", + "id_token": "id_token_jwt", + "scope": "openid profile" + } + mocker.patch.object(client._oauth, "fetch_token", async_fetch_token) + mocker.patch("jwt.get_unverified_header", return_value={"kid": "test-key"}) + mock_signing_key = mocker.MagicMock() + mock_signing_key.key = "mock_pem_key" + mocker.patch("jwt.PyJWK.from_dict", return_value=mock_signing_key) + + mocker.patch("jwt.decode", return_value={ + "sub": "user123", + "iss": "https://tenant.auth0.com/", + "aud": "test_client", + "iat": iat, + "session_expiry": "not-a-number", + }) + + # Must not raise — garbage claim degrades to no ceiling. + result = await client.complete_interactive_login("http://localhost/callback?code=abc&state=xyz") + + assert "state_data" in result + mock_state_store.set.assert_awaited_once() + stored_state = mock_state_store.set.call_args.args[1] + assert stored_state.internal.session_expires_at is None + + +@pytest.mark.asyncio +async def test_complete_interactive_login_milliseconds_ceiling_fails_open(mocker): + """A millisecond-scale session_expiry is rejected as implausible -> no ceiling.""" + iat = int(time.time()) + + mock_tx_store = AsyncMock() + mock_tx_store.get.return_value = TransactionData( + code_verifier="123", + domain="tenant.auth0.com", + ) + mock_state_store = AsyncMock() + + client = ServerClient( + domain="tenant.auth0.com", + client_id="test_client", + client_secret="test_secret", + transaction_store=mock_tx_store, + state_store=mock_state_store, + secret="test_secret_key_32_chars_long!!", + ) + + mocker.patch.object( + client, + "_get_oidc_metadata_cached", + return_value={"issuer": "https://tenant.auth0.com/", "token_endpoint": "https://tenant.auth0.com/token"} + ) + mocker.patch.object( + client, + "_get_jwks_cached", + return_value={"keys": [{"kty": "RSA", "kid": "test-key"}]} + ) + + async_fetch_token = AsyncMock() + async_fetch_token.return_value = { + "access_token": "token123", + "id_token": "id_token_jwt", + "scope": "openid profile" + } + mocker.patch.object(client._oauth, "fetch_token", async_fetch_token) + mocker.patch("jwt.get_unverified_header", return_value={"kid": "test-key"}) + mock_signing_key = mocker.MagicMock() + mock_signing_key.key = "mock_pem_key" + mocker.patch("jwt.PyJWK.from_dict", return_value=mock_signing_key) + + mocker.patch("jwt.decode", return_value={ + "sub": "user123", + "iss": "https://tenant.auth0.com/", + "aud": "test_client", + "iat": iat, + "session_expiry": 1748566800000, + }) + + # Must not raise — a ms value is implausible as Unix seconds, so no ceiling. + result = await client.complete_interactive_login("http://localhost/callback?code=abc&state=xyz") + + assert "state_data" in result + mock_state_store.set.assert_awaited_once() + stored_state = mock_state_store.set.call_args.args[1] + assert stored_state.internal.session_expires_at is None diff --git a/src/auth0_server_python/utils/helpers.py b/src/auth0_server_python/utils/helpers.py index 1fff76a..b875289 100644 --- a/src/auth0_server_python/utils/helpers.py +++ b/src/auth0_server_python/utils/helpers.py @@ -38,6 +38,10 @@ def generate_code_challenge(cls, code_verifier: str) -> str: class State: + # Clock-skew leeway (seconds): treat the session as expired slightly before + # the ceiling so the SDK never serves a session the platform has revoked. + SESSION_EXPIRY_LEEWAY_SECONDS = 30 + @classmethod def update_state_data( cls, @@ -92,12 +96,18 @@ def update_state_data( else ts for ts in token_sets ] + # A refresh-token grant does not carry session_expiry, so carry the + # existing internal block (including the ceiling pinned at login) + # forward unchanged rather than re-deriving it. + internal = dict(state_data_dict.get("internal") or {}) + # Return updated state data return { **state_data_dict, "id_token": token_endpoint_response.get("id_token"), "refresh_token": token_endpoint_response.get("refresh_token") or state_data_dict.get("refresh_token"), - "token_sets": token_sets + "token_sets": token_sets, + "internal": internal } else: # Create completely new state data @@ -178,6 +188,35 @@ def update_state_data_for_connection_token_set( "connection_token_sets": connection_token_sets } + @classmethod + def is_session_ceiling_reached(cls, session_expires_at: Optional[int]) -> bool: + """ + True when the session ceiling has been reached (applying negative + leeway for clock skew). None means no ceiling was asserted, so the + session is never expired on this basis. + """ + if session_expires_at is None: + return False + now = int(time.time()) + return now >= (session_expires_at - cls.SESSION_EXPIRY_LEEWAY_SECONDS) + + @classmethod + def is_session_ceiling_in_past( + cls, session_expires_at: Optional[int], issued_at: Optional[int] = None + ) -> bool: + """ + True when the session ceiling is already in the past at login. + + Compares the ceiling against the ID token `iat`, or wall-clock now when + `iat` is absent, using the same leeway as is_session_ceiling_reached. A + None ceiling means none was asserted and is never treated as expired. + """ + if session_expires_at is None: + return False + reference = issued_at if issued_at is not None else int(time.time()) + return session_expires_at <= (reference + cls.SESSION_EXPIRY_LEEWAY_SECONDS) + + class URL: @staticmethod def build_url(base_url: str, params: dict[str, Any]) -> str: