@@ -120,7 +120,18 @@ class OAuthContext:
120120 token_expiry_time : float | None = None
121121
122122 # State
123+ #
124+ # `lock` guards short-lived reads/writes of provider state (initialization
125+ # flag, token cache mutation, protocol_version assignment). It is held only
126+ # while mutating state and is released before any HTTP request is yielded
127+ # so a long-running request (e.g. GET SSE long-poll) does not block
128+ # unrelated concurrent requests.
129+ #
130+ # `refresh_lock` provides single-flight semantics for token refresh: only
131+ # one concurrent refresh fires; other waiters block on this lock, then
132+ # re-check the token cache and proceed without re-refreshing.
123133 lock : anyio .Lock = field (default_factory = anyio .Lock )
134+ refresh_lock : anyio .Lock = field (default_factory = anyio .Lock )
124135
125136 def get_authorization_base_url (self , server_url : str ) -> str :
126137 """Extract base URL by removing path component."""
@@ -492,7 +503,7 @@ async def _handle_refresh_response(self, response: httpx.Response) -> bool:
492503 await self .context .storage .set_tokens (token_response )
493504
494505 return True
495- except ValidationError : # pragma: no cover
506+ except ValidationError :
496507 logger .exception ("Invalid refresh response" )
497508 self .context .clear_tokens ()
498509 return False
@@ -527,30 +538,88 @@ async def _validate_resource_match(self, prm: ProtectedResourceMetadata) -> None
527538 if not check_resource_allowed (requested_resource = default_resource , configured_resource = prm_resource ):
528539 raise OAuthFlowError (f"Protected resource { prm_resource } does not match expected { default_resource } " )
529540
530- async def async_auth_flow (self , request : httpx .Request ) -> AsyncGenerator [httpx .Request , httpx .Response ]:
531- """HTTPX auth flow integration."""
541+ async def _prepare_and_decide_refresh (self , request : httpx .Request ) -> bool :
542+ """Phase 1: initialize + capture protocol version, then decide whether a
543+ proactive token refresh is needed. Holds ``self.context.lock`` only
544+ briefly. Returns ``True`` when the token is invalid but refreshable.
545+ """
532546 async with self .context .lock :
533547 if not self ._initialized :
534548 await self ._initialize ()
535549
536550 # Capture protocol version from request headers
537551 self .context .protocol_version = request .headers .get (MCP_PROTOCOL_VERSION )
538552
539- if not self .context .is_token_valid () and self .context .can_refresh_token ():
540- # Try to refresh token
541- refresh_request = await self ._refresh_token ()
542- refresh_response = yield refresh_request
543-
544- if not await self ._handle_refresh_response (refresh_response ):
545- # Refresh failed, need full re-authentication
546- self ._initialized = False
547-
548- if self .context .is_token_valid ():
549- self ._add_auth_header (request )
550-
551- response = yield request
553+ # pragma: no branch — coverage.py on Python 3.10/3.11 (sys.settrace
554+ # backend) cannot reliably track both arms of compound boolean
555+ # predicates inside an ``async with`` block in an async generator.
556+ # Python 3.12+ (sys.monitoring) handles this correctly; the pragmas
557+ # below are workarounds for the legacy backend only.
558+ if not self .context .is_token_valid () and self .context .can_refresh_token (): # pragma: no branch
559+ return True
560+ return False
552561
553- if response .status_code == 401 :
562+ async def async_auth_flow (self , request : httpx .Request ) -> AsyncGenerator [httpx .Request , httpx .Response ]:
563+ """HTTPX auth flow integration.
564+
565+ Lock scope:
566+ ``self.context.lock`` is held only while reading/mutating provider
567+ state. The actual HTTP request yield (which may be a long-poll GET
568+ SSE stream) runs outside any lock so concurrent unrelated requests
569+ are not blocked. ``self.context.refresh_lock`` provides
570+ single-flight semantics for token refresh.
571+ """
572+ # === Phase 1: state read + refresh decision (brief context.lock) ===
573+ needs_refresh = await self ._prepare_and_decide_refresh (request )
574+
575+ # === Phase 2: single-flight token refresh (yield outside context.lock) ===
576+ if needs_refresh :
577+ async with self .context .refresh_lock :
578+ # Re-check under context.lock: another coroutine may already have
579+ # refreshed while we were waiting on refresh_lock.
580+ refresh_request : httpx .Request | None = None
581+ async with self .context .lock :
582+ if not self .context .is_token_valid () and self .context .can_refresh_token (): # pragma: no branch
583+ refresh_request = await self ._refresh_token ()
584+ if refresh_request is not None : # pragma: no branch
585+ # yield runs outside any lock so a long network round trip
586+ # does not block unrelated concurrent requests.
587+ refresh_response = yield refresh_request
588+ async with self .context .lock :
589+ if not await self ._handle_refresh_response (refresh_response ): # pragma: no branch
590+ # Refresh failed; fall through to 401 handling below.
591+ self ._initialized = False
592+
593+ # === Phase 3: send request (no lock; safe for long-poll GET SSE) ===
594+ if self .context .is_token_valid ():
595+ self ._add_auth_header (request )
596+
597+ # Capture the access token actually used to send this request so the
598+ # 401 handler below can detect a token change made by a concurrent
599+ # request while this one was in flight.
600+ sent_access_token = self .context .current_tokens .access_token if self .context .current_tokens else None
601+
602+ response = yield request
603+
604+ # === Phase 4: 401 / 403 full OAuth flow ===
605+ # NOTE: Phase 4 yields multiple sub-requests (discovery, registration,
606+ # token exchange) under context.lock. This is the existing behavior and
607+ # is acceptable because the 401 path is exceptional and not concurrent
608+ # with steady-state traffic. A future refactor could narrow the lock
609+ # here in the same pattern as Phase 1-2.
610+ if response .status_code == 401 :
611+ async with self .context .lock :
612+ # Concurrency guard: while this request was in flight, another
613+ # request holding ``context.lock`` may have already completed a
614+ # token refresh or a full re-authorization. If the stored access
615+ # token changed since we sent this request, the 401 is stale -
616+ # retry once with the new token instead of running a second,
617+ # duplicate ``authorization_code`` exchange.
618+ current_access_token = self .context .current_tokens .access_token if self .context .current_tokens else None
619+ if current_access_token is not None and current_access_token != sent_access_token :
620+ self ._add_auth_header (request )
621+ yield request
622+ return
554623 # Perform full OAuth flow
555624 try :
556625 # OAuth flow must be inline due to generator constraints
@@ -701,7 +770,8 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
701770 # Retry with new tokens
702771 self ._add_auth_header (request )
703772 yield request
704- elif response .status_code == 403 :
773+ elif response .status_code == 403 :
774+ async with self .context .lock :
705775 # Step 1: Extract error field from WWW-Authenticate header
706776 error = extract_field_from_www_auth (response , "error" )
707777
0 commit comments