Skip to content

Commit a9be0b9

Browse files
peisukeBartok9
authored andcommitted
fix(oauth): narrow async_auth_flow lock scope to avoid blocking long-poll requests
Closes #2847. OAuthContext.lock is an anyio.Lock, which records task identity at acquire() and enforces same-task release(). async_auth_flow held this lock across yield points; when httpx drives the generator from a different task during concurrent OAuth connections, release() raises 'RuntimeError: The current task is not holding this lock'. Narrows the lock scope so no HTTP yield (long-poll GET SSE, token-refresh round trips) runs while holding context.lock, plus a single-flight refresh_lock with a re-check under the lock. Keeps trio portability (no asyncio.Lock swap). Salvage of #2660 by @peisuke, rebased onto current main.
1 parent a527142 commit a9be0b9

2 files changed

Lines changed: 355 additions & 18 deletions

File tree

src/mcp/client/auth/oauth2.py

Lines changed: 88 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,18 @@ class OAuthContext:
120120
token_expiry_time: float | None = None
121121

122122
# State
123+
#
124+
# `lock` guards short-lived reads/writes of provider state (initialization
125+
# flag, token cache mutation, protocol_version assignment). It is held only
126+
# while mutating state and is released before any HTTP request is yielded
127+
# so a long-running request (e.g. GET SSE long-poll) does not block
128+
# unrelated concurrent requests.
129+
#
130+
# `refresh_lock` provides single-flight semantics for token refresh: only
131+
# one concurrent refresh fires; other waiters block on this lock, then
132+
# re-check the token cache and proceed without re-refreshing.
123133
lock: anyio.Lock = field(default_factory=anyio.Lock)
134+
refresh_lock: anyio.Lock = field(default_factory=anyio.Lock)
124135

125136
def get_authorization_base_url(self, server_url: str) -> str:
126137
"""Extract base URL by removing path component."""
@@ -492,7 +503,7 @@ async def _handle_refresh_response(self, response: httpx.Response) -> bool:
492503
await self.context.storage.set_tokens(token_response)
493504

494505
return True
495-
except ValidationError: # pragma: no cover
506+
except ValidationError:
496507
logger.exception("Invalid refresh response")
497508
self.context.clear_tokens()
498509
return False
@@ -527,30 +538,88 @@ async def _validate_resource_match(self, prm: ProtectedResourceMetadata) -> None
527538
if not check_resource_allowed(requested_resource=default_resource, configured_resource=prm_resource):
528539
raise OAuthFlowError(f"Protected resource {prm_resource} does not match expected {default_resource}")
529540

530-
async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.Request, httpx.Response]:
531-
"""HTTPX auth flow integration."""
541+
async def _prepare_and_decide_refresh(self, request: httpx.Request) -> bool:
542+
"""Phase 1: initialize + capture protocol version, then decide whether a
543+
proactive token refresh is needed. Holds ``self.context.lock`` only
544+
briefly. Returns ``True`` when the token is invalid but refreshable.
545+
"""
532546
async with self.context.lock:
533547
if not self._initialized:
534548
await self._initialize()
535549

536550
# Capture protocol version from request headers
537551
self.context.protocol_version = request.headers.get(MCP_PROTOCOL_VERSION)
538552

539-
if not self.context.is_token_valid() and self.context.can_refresh_token():
540-
# Try to refresh token
541-
refresh_request = await self._refresh_token()
542-
refresh_response = yield refresh_request
543-
544-
if not await self._handle_refresh_response(refresh_response):
545-
# Refresh failed, need full re-authentication
546-
self._initialized = False
547-
548-
if self.context.is_token_valid():
549-
self._add_auth_header(request)
550-
551-
response = yield request
553+
# pragma: no branch — coverage.py on Python 3.10/3.11 (sys.settrace
554+
# backend) cannot reliably track both arms of compound boolean
555+
# predicates inside an ``async with`` block in an async generator.
556+
# Python 3.12+ (sys.monitoring) handles this correctly; the pragmas
557+
# below are workarounds for the legacy backend only.
558+
if not self.context.is_token_valid() and self.context.can_refresh_token(): # pragma: no branch
559+
return True
560+
return False
552561

553-
if response.status_code == 401:
562+
async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.Request, httpx.Response]:
563+
"""HTTPX auth flow integration.
564+
565+
Lock scope:
566+
``self.context.lock`` is held only while reading/mutating provider
567+
state. The actual HTTP request yield (which may be a long-poll GET
568+
SSE stream) runs outside any lock so concurrent unrelated requests
569+
are not blocked. ``self.context.refresh_lock`` provides
570+
single-flight semantics for token refresh.
571+
"""
572+
# === Phase 1: state read + refresh decision (brief context.lock) ===
573+
needs_refresh = await self._prepare_and_decide_refresh(request)
574+
575+
# === Phase 2: single-flight token refresh (yield outside context.lock) ===
576+
if needs_refresh:
577+
async with self.context.refresh_lock:
578+
# Re-check under context.lock: another coroutine may already have
579+
# refreshed while we were waiting on refresh_lock.
580+
refresh_request: httpx.Request | None = None
581+
async with self.context.lock:
582+
if not self.context.is_token_valid() and self.context.can_refresh_token(): # pragma: no branch
583+
refresh_request = await self._refresh_token()
584+
if refresh_request is not None: # pragma: no branch
585+
# yield runs outside any lock so a long network round trip
586+
# does not block unrelated concurrent requests.
587+
refresh_response = yield refresh_request
588+
async with self.context.lock:
589+
if not await self._handle_refresh_response(refresh_response): # pragma: no branch
590+
# Refresh failed; fall through to 401 handling below.
591+
self._initialized = False
592+
593+
# === Phase 3: send request (no lock; safe for long-poll GET SSE) ===
594+
if self.context.is_token_valid():
595+
self._add_auth_header(request)
596+
597+
# Capture the access token actually used to send this request so the
598+
# 401 handler below can detect a token change made by a concurrent
599+
# request while this one was in flight.
600+
sent_access_token = self.context.current_tokens.access_token if self.context.current_tokens else None
601+
602+
response = yield request
603+
604+
# === Phase 4: 401 / 403 full OAuth flow ===
605+
# NOTE: Phase 4 yields multiple sub-requests (discovery, registration,
606+
# token exchange) under context.lock. This is the existing behavior and
607+
# is acceptable because the 401 path is exceptional and not concurrent
608+
# with steady-state traffic. A future refactor could narrow the lock
609+
# here in the same pattern as Phase 1-2.
610+
if response.status_code == 401:
611+
async with self.context.lock:
612+
# Concurrency guard: while this request was in flight, another
613+
# request holding ``context.lock`` may have already completed a
614+
# token refresh or a full re-authorization. If the stored access
615+
# token changed since we sent this request, the 401 is stale -
616+
# retry once with the new token instead of running a second,
617+
# duplicate ``authorization_code`` exchange.
618+
current_access_token = self.context.current_tokens.access_token if self.context.current_tokens else None
619+
if current_access_token is not None and current_access_token != sent_access_token:
620+
self._add_auth_header(request)
621+
yield request
622+
return
554623
# Perform full OAuth flow
555624
try:
556625
# OAuth flow must be inline due to generator constraints
@@ -701,7 +770,8 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
701770
# Retry with new tokens
702771
self._add_auth_header(request)
703772
yield request
704-
elif response.status_code == 403:
773+
elif response.status_code == 403:
774+
async with self.context.lock:
705775
# Step 1: Extract error field from WWW-Authenticate header
706776
error = extract_field_from_www_auth(response, "error")
707777

0 commit comments

Comments
 (0)