Skip to content

Commit b8f826c

Browse files
committed
fix(client): preserve existing query params on OAuth authorization_endpoint
Closes #2776 The authorization code grant built the redirect URL with `f"{auth_endpoint}?{urlencode(auth_params)}"`, which produces an invalid URL when the server-advertised authorization_endpoint already carries a query string. For example Salesforce advertises `.../services/oauth2/authorize?prompt=select_account`, yielding `...authorize?prompt=select_account?response_type=code&...` (two `?` separators), so the client navigates to a malformed URL and the server rejects the request. Fix: parse the endpoint, merge its existing query params with the flow-generated auth_params (flow params win on conflict), and re-encode into a single well-formed query string. None-valued params are dropped rather than serialized as the literal "None". Tests: add TestAuthorizationEndpointWithQuery covering the helper (no/with/conflicting existing query) plus an end-to-end _perform_authorization_code_grant assertion that the captured redirect URL preserves the server param and stays well-formed. 101 passed.
1 parent a527142 commit b8f826c

2 files changed

Lines changed: 103 additions & 3 deletions

File tree

src/mcp/client/auth/oauth2.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,10 @@
99
import secrets
1010
import string
1111
import time
12-
from collections.abc import AsyncGenerator, Awaitable, Callable
12+
from collections.abc import AsyncGenerator, Awaitable, Callable, Mapping
1313
from dataclasses import dataclass, field
1414
from typing import Any, Protocol
15-
from urllib.parse import quote, urlencode, urljoin, urlparse
15+
from urllib.parse import parse_qsl, quote, urlencode, urljoin, urlparse, urlunparse
1616

1717
import anyio
1818
import httpx
@@ -59,6 +59,22 @@
5959
logger = logging.getLogger(__name__)
6060

6161

62+
def _build_authorization_url(auth_endpoint: str, auth_params: Mapping[str, str | None]) -> str:
63+
"""Build an authorization URL, preserving any query params already on the endpoint.
64+
65+
Servers may advertise an ``authorization_endpoint`` that already carries query
66+
parameters (e.g. ``https://example.com/authorize?prompt=select_account``).
67+
Naively appending ``?<params>`` would produce an invalid URL with two ``?``
68+
separators, so the existing query is parsed and merged with ``auth_params``.
69+
Flow-generated params take precedence on key conflicts; ``None`` values are
70+
dropped rather than serialized as the literal string ``"None"``.
71+
"""
72+
parsed = urlparse(auth_endpoint)
73+
merged_params = dict(parse_qsl(parsed.query, keep_blank_values=True))
74+
merged_params.update({key: value for key, value in auth_params.items() if value is not None})
75+
return urlunparse(parsed._replace(query=urlencode(merged_params)))
76+
77+
6278
class PKCEParameters(BaseModel):
6379
"""PKCE (Proof Key for Code Exchange) parameters."""
6480

@@ -357,7 +373,7 @@ async def _perform_authorization_code_grant(self) -> tuple[str, str]:
357373
if "offline_access" in self.context.client_metadata.scope.split():
358374
auth_params["prompt"] = "consent"
359375

360-
authorization_url = f"{auth_endpoint}?{urlencode(auth_params)}"
376+
authorization_url = _build_authorization_url(auth_endpoint, auth_params)
361377
await self.context.redirect_handler(authorization_url)
362378

363379
# Wait for callback

tests/client/test_auth.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
from mcp.client.auth import OAuthClientProvider, PKCEParameters
1515
from mcp.client.auth.exceptions import OAuthFlowError
16+
from mcp.client.auth.oauth2 import _build_authorization_url
1617
from mcp.client.auth.utils import (
1718
build_oauth_authorization_server_metadata_discovery_urls,
1819
build_protected_resource_metadata_discovery_urls,
@@ -3158,3 +3159,86 @@ async def echo_callback() -> AuthorizationCodeResult:
31583159
await auth_flow.asend(httpx.Response(200, request=final_req))
31593160
except StopAsyncIteration:
31603161
pass
3162+
3163+
3164+
class TestAuthorizationEndpointWithQuery:
3165+
"""Regression tests for #2776 - authorization_endpoint carrying query params."""
3166+
3167+
def test_build_authorization_url_no_existing_query(self):
3168+
url = _build_authorization_url(
3169+
"https://auth.example.com/authorize",
3170+
{"response_type": "code", "client_id": "abc"},
3171+
)
3172+
parsed = urlparse(url)
3173+
params = parse_qs(parsed.query)
3174+
assert parsed.path == "/authorize"
3175+
assert params["response_type"] == ["code"]
3176+
assert params["client_id"] == ["abc"]
3177+
# No malformed double "?" separator.
3178+
assert url.count("?") == 1
3179+
3180+
def test_build_authorization_url_preserves_existing_query(self):
3181+
# e.g. Salesforce advertises .../authorize?prompt=select_account
3182+
url = _build_authorization_url(
3183+
"https://test.salesforce.com/services/oauth2/authorize?prompt=select_account",
3184+
{"response_type": "code", "client_id": "abc"},
3185+
)
3186+
parsed = urlparse(url)
3187+
params = parse_qs(parsed.query)
3188+
assert parsed.path == "/services/oauth2/authorize"
3189+
# The server-provided param survives...
3190+
assert params["prompt"] == ["select_account"]
3191+
# ...alongside the flow-generated params.
3192+
assert params["response_type"] == ["code"]
3193+
assert params["client_id"] == ["abc"]
3194+
# Exactly one "?" - the old f-string produced "...?prompt=...?response_type=...".
3195+
assert url.count("?") == 1
3196+
3197+
def test_build_authorization_url_flow_params_win_on_conflict(self):
3198+
url = _build_authorization_url(
3199+
"https://auth.example.com/authorize?response_type=token",
3200+
{"response_type": "code"},
3201+
)
3202+
params = parse_qs(urlparse(url).query)
3203+
assert params["response_type"] == ["code"]
3204+
3205+
@pytest.mark.anyio
3206+
async def test_perform_authorization_preserves_endpoint_query(self, oauth_provider: OAuthClientProvider):
3207+
"""End-to-end: redirect URL stays valid when the endpoint has a query string."""
3208+
oauth_provider.context.oauth_metadata = OAuthMetadata(
3209+
issuer=AnyHttpUrl("https://test.salesforce.com"),
3210+
authorization_endpoint=AnyHttpUrl(
3211+
"https://test.salesforce.com/services/oauth2/authorize?prompt=select_account"
3212+
),
3213+
token_endpoint=AnyHttpUrl("https://test.salesforce.com/services/oauth2/token"),
3214+
)
3215+
oauth_provider.context.client_info = OAuthClientInformationFull(
3216+
client_id="test_client_id",
3217+
client_secret="test_client_secret",
3218+
redirect_uris=[AnyUrl("http://localhost:3030/callback")],
3219+
)
3220+
3221+
captured_url: str | None = None
3222+
captured_state: str | None = None
3223+
3224+
async def capture_redirect(url: str) -> None:
3225+
nonlocal captured_url, captured_state
3226+
captured_url = url
3227+
captured_state = parse_qs(urlparse(url).query).get("state", [None])[0]
3228+
3229+
async def mock_callback() -> AuthorizationCodeResult:
3230+
return AuthorizationCodeResult(code="test_auth_code", state=captured_state)
3231+
3232+
oauth_provider.context.redirect_handler = capture_redirect
3233+
oauth_provider.context.callback_handler = mock_callback
3234+
3235+
await oauth_provider._perform_authorization_code_grant()
3236+
3237+
assert captured_url is not None
3238+
parsed = urlparse(captured_url)
3239+
params = parse_qs(parsed.query)
3240+
assert parsed.path == "/services/oauth2/authorize"
3241+
assert params["prompt"] == ["select_account"]
3242+
assert params["response_type"] == ["code"]
3243+
assert params["client_id"] == ["test_client_id"]
3244+
assert captured_url.count("?") == 1

0 commit comments

Comments
 (0)