|
1 | 1 | import time |
2 | | -from typing import Any, Optional |
| 2 | +from typing import Any, Optional, Union |
3 | 3 |
|
4 | 4 | import httpx |
5 | 5 | from authlib.jose import JsonWebKey, JsonWebToken |
|
9 | 9 | ApiError, |
10 | 10 | BaseAuthError, |
11 | 11 | GetAccessTokenForConnectionError, |
| 12 | + GetTokenByExchangeProfileError, |
12 | 13 | InvalidAuthSchemeError, |
13 | 14 | InvalidDpopProofError, |
14 | 15 | MissingAuthorizationError, |
@@ -501,6 +502,221 @@ async def get_access_token_for_connection(self, options: dict[str, Any]) -> dict |
501 | 502 | exc |
502 | 503 | ) |
503 | 504 |
|
| 505 | + async def get_token_by_exchange_profile( |
| 506 | + self, |
| 507 | + subject_token: str, |
| 508 | + subject_token_type: str, |
| 509 | + audience: Optional[str] = None, |
| 510 | + scope: Optional[str] = None, |
| 511 | + requested_token_type: Optional[str] = None, |
| 512 | + extra: Optional[dict[str, Union[str, list[str]]]] = None |
| 513 | + ) -> dict[str, Any]: |
| 514 | + """ |
| 515 | + Exchanges a token via a Custom Token Exchange Profile for Auth0 tokens (RFC 8693). |
| 516 | +
|
| 517 | + This method supports Custom Token Exchange for custom token types via a configured |
| 518 | + Token Exchange Profile. It exchanges custom tokens (from MCP servers, legacy systems, |
| 519 | + or partner services) for Auth0 tokens targeting a specific API audience while |
| 520 | + preserving user identity. |
| 521 | +
|
| 522 | + **Note**: This method requires a confidential client (client_id and client_secret |
| 523 | + must be configured). |
| 524 | +
|
| 525 | + Args: |
| 526 | + subject_token: The raw token to be exchanged (without "Bearer " prefix) |
| 527 | + subject_token_type: URI identifying the token type (must match a Token Exchange Profile) |
| 528 | + audience: Optional target API identifier for the exchanged tokens |
| 529 | + scope: Optional space-separated OAuth 2.0 scopes to request |
| 530 | + requested_token_type: Optional type of token to issue (defaults to access token) |
| 531 | + extra: Optional custom parameters accessible in Auth0 Actions. Cannot override |
| 532 | + reserved OAuth parameters. Array values limited to 20 items per key. |
| 533 | +
|
| 534 | + Returns: |
| 535 | + Dictionary containing: |
| 536 | + - access_token (str): The Auth0 access token |
| 537 | + - expires_at (int): Unix timestamp when token expires |
| 538 | + - id_token (str, optional): OpenID Connect ID token |
| 539 | + - refresh_token (str, optional): Refresh token |
| 540 | + - scope (str, optional): Granted scopes |
| 541 | + - token_type (str, optional): Token type (typically "Bearer") |
| 542 | + - issued_token_type (str, optional): RFC 8693 issued token type identifier |
| 543 | +
|
| 544 | + Raises: |
| 545 | + MissingRequiredArgumentError: If required parameters are missing |
| 546 | + GetTokenByExchangeProfileError: If client credentials not configured or exchange fails |
| 547 | + ApiError: If the token endpoint returns an error |
| 548 | +
|
| 549 | + Example: |
| 550 | + >>> result = await api_client.get_token_by_exchange_profile( |
| 551 | + ... subject_token=custom_token, |
| 552 | + ... subject_token_type="urn:example:custom-token", |
| 553 | + ... audience="https://api.backend.com", |
| 554 | + ... scope="openid profile read:data" |
| 555 | + ... ) |
| 556 | + >>> print(result["access_token"]) |
| 557 | +
|
| 558 | + References: |
| 559 | + - Custom Token Exchange Documentation: https://auth0.com/docs/authenticate/custom-token-exchange |
| 560 | + - RFC 8693 OAuth 2.0 Token Exchange: https://datatracker.ietf.org/doc/html/rfc8693 |
| 561 | + """ |
| 562 | + # Constants |
| 563 | + TOKEN_EXCHANGE_GRANT_TYPE = "urn:ietf:params:oauth:grant-type:token-exchange" |
| 564 | + MAX_ARRAY_VALUES_PER_KEY = 20 |
| 565 | + |
| 566 | + # OAuth parameter denylist |
| 567 | + PARAM_DENYLIST = frozenset([ |
| 568 | + "grant_type", "client_id", "client_secret", "client_assertion", |
| 569 | + "client_assertion_type", "subject_token", "subject_token_type", |
| 570 | + "requested_token_type", "actor_token", "actor_token_type", |
| 571 | + "audience", "aud", "resource", "resources", "resource_indicator", |
| 572 | + "scope", "connection", "login_hint", "organization", "assertion", |
| 573 | + ]) |
| 574 | + |
| 575 | + # Validate required parameters |
| 576 | + if not subject_token: |
| 577 | + raise MissingRequiredArgumentError("subject_token") |
| 578 | + if not subject_token_type: |
| 579 | + raise MissingRequiredArgumentError("subject_token_type") |
| 580 | + |
| 581 | + # Validate subject token format |
| 582 | + if not isinstance(subject_token, str): |
| 583 | + raise GetTokenByExchangeProfileError("subject_token must be a string") |
| 584 | + if not subject_token.strip(): |
| 585 | + raise GetTokenByExchangeProfileError("subject_token cannot be blank or whitespace") |
| 586 | + if subject_token != subject_token.strip(): |
| 587 | + raise GetTokenByExchangeProfileError( |
| 588 | + "subject_token must not include leading or trailing whitespace" |
| 589 | + ) |
| 590 | + if subject_token.lower().startswith("bearer "): |
| 591 | + raise GetTokenByExchangeProfileError( |
| 592 | + "subject_token must not include the 'Bearer ' prefix" |
| 593 | + ) |
| 594 | + |
| 595 | + # Require client credentials |
| 596 | + client_id = self.options.client_id |
| 597 | + client_secret = self.options.client_secret |
| 598 | + if not client_id or not client_secret: |
| 599 | + raise GetTokenByExchangeProfileError( |
| 600 | + "Client credentials are required to use get_token_by_exchange_profile" |
| 601 | + ) |
| 602 | + |
| 603 | + # Discover token endpoint |
| 604 | + metadata = await self._discover() |
| 605 | + token_endpoint = metadata.get("token_endpoint") |
| 606 | + if not token_endpoint: |
| 607 | + raise GetTokenByExchangeProfileError("Token endpoint missing in OIDC metadata") |
| 608 | + |
| 609 | + # Build request parameters |
| 610 | + params = { |
| 611 | + "grant_type": TOKEN_EXCHANGE_GRANT_TYPE, |
| 612 | + "client_id": client_id, |
| 613 | + "subject_token": subject_token, |
| 614 | + "subject_token_type": subject_token_type, |
| 615 | + } |
| 616 | + |
| 617 | + # Add optional parameters |
| 618 | + if audience: |
| 619 | + params["audience"] = audience |
| 620 | + if scope: |
| 621 | + params["scope"] = scope |
| 622 | + if requested_token_type: |
| 623 | + params["requested_token_type"] = requested_token_type |
| 624 | + |
| 625 | + # Append extra parameters with validation |
| 626 | + if extra: |
| 627 | + for parameter_key, parameter_value in extra.items(): |
| 628 | + # Silently ignore denylisted parameters |
| 629 | + if parameter_key in PARAM_DENYLIST: |
| 630 | + continue |
| 631 | + |
| 632 | + if isinstance(parameter_value, list): |
| 633 | + # Validate array size for DoS protection |
| 634 | + if len(parameter_value) > MAX_ARRAY_VALUES_PER_KEY: |
| 635 | + raise GetTokenByExchangeProfileError( |
| 636 | + f"Parameter '{parameter_key}' exceeds maximum array size of {MAX_ARRAY_VALUES_PER_KEY}" |
| 637 | + ) |
| 638 | + # Store as list - httpx will encode as multiple key=value pairs |
| 639 | + params[parameter_key] = parameter_value |
| 640 | + else: |
| 641 | + params[parameter_key] = str(parameter_value) |
| 642 | + |
| 643 | + # Make token exchange request |
| 644 | + try: |
| 645 | + async with httpx.AsyncClient() as client: |
| 646 | + response = await client.post( |
| 647 | + token_endpoint, |
| 648 | + data=params, |
| 649 | + auth=(client_id, client_secret) |
| 650 | + ) |
| 651 | + |
| 652 | + if response.status_code != 200: |
| 653 | + error_data = response.json() if "json" in response.headers.get( |
| 654 | + "content-type", "").lower() else {} |
| 655 | + raise ApiError( |
| 656 | + error_data.get("error", "token_exchange_error"), |
| 657 | + error_data.get( |
| 658 | + "error_description", |
| 659 | + f"Failed to exchange token of type '{subject_token_type}'" |
| 660 | + + (f" for audience '{audience}'" if audience else "") |
| 661 | + ), |
| 662 | + response.status_code |
| 663 | + ) |
| 664 | + |
| 665 | + try: |
| 666 | + token_response = response.json() |
| 667 | + except Exception: |
| 668 | + raise ApiError("invalid_json", "Token endpoint returned invalid JSON.", 502) |
| 669 | + |
| 670 | + # Validate required fields |
| 671 | + access_token = token_response.get("access_token") |
| 672 | + if not isinstance(access_token, str) or not access_token: |
| 673 | + raise ApiError( |
| 674 | + "invalid_response", |
| 675 | + "Missing or invalid access_token in response.", |
| 676 | + 502 |
| 677 | + ) |
| 678 | + |
| 679 | + expires_in_raw = token_response.get("expires_in", 3600) |
| 680 | + try: |
| 681 | + expires_in = int(expires_in_raw) |
| 682 | + except (TypeError, ValueError): |
| 683 | + raise ApiError("invalid_response", "expires_in is not an integer.", 502) |
| 684 | + |
| 685 | + # Build response (match JS SDK structure) |
| 686 | + result = { |
| 687 | + "access_token": access_token, |
| 688 | + "expires_at": int(time.time()) + expires_in, |
| 689 | + } |
| 690 | + |
| 691 | + # Add optional fields if present (conditional spreading like JS) |
| 692 | + if "scope" in token_response and token_response["scope"]: |
| 693 | + result["scope"] = token_response["scope"] |
| 694 | + if "id_token" in token_response and token_response["id_token"]: |
| 695 | + result["id_token"] = token_response["id_token"] |
| 696 | + if "refresh_token" in token_response and token_response["refresh_token"]: |
| 697 | + result["refresh_token"] = token_response["refresh_token"] |
| 698 | + if "token_type" in token_response and token_response["token_type"]: |
| 699 | + result["token_type"] = token_response["token_type"] |
| 700 | + if "issued_token_type" in token_response and token_response["issued_token_type"]: |
| 701 | + result["issued_token_type"] = token_response["issued_token_type"] |
| 702 | + |
| 703 | + return result |
| 704 | + |
| 705 | + except httpx.TimeoutException as exc: |
| 706 | + raise ApiError( |
| 707 | + "timeout_error", |
| 708 | + f"Request to token endpoint timed out: {str(exc)}", |
| 709 | + 504, |
| 710 | + exc |
| 711 | + ) |
| 712 | + except httpx.HTTPError as exc: |
| 713 | + raise ApiError( |
| 714 | + "network_error", |
| 715 | + f"Network error occurred: {str(exc)}", |
| 716 | + 502, |
| 717 | + exc |
| 718 | + ) |
| 719 | + |
504 | 720 | # ===== Private Methods ===== |
505 | 721 |
|
506 | 722 | async def _discover(self) -> dict[str, Any]: |
|
0 commit comments