@@ -476,23 +476,11 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re
476476 await response (scope , receive , send )
477477 return
478478
479- # Check if this is an initialization request
480- is_initialization_request = isinstance (message , JSONRPCRequest ) and message .method == "initialize"
481-
482- if is_initialization_request :
483- # Check if the server already has an established session
484- if self .mcp_session_id :
485- # Check if request has a session ID
486- request_session_id = self ._get_session_id (request )
487-
488- # If request has a session ID but doesn't match, return 404
489- if request_session_id and request_session_id != self .mcp_session_id : # pragma: no cover
490- response = self ._create_error_response (
491- "Not Found: Invalid or expired session ID" ,
492- HTTPStatus .NOT_FOUND ,
493- )
494- await response (scope , receive , send )
495- return
479+ is_initialization_request = False
480+ if isinstance (message , JSONRPCRequest ) and message .method == "initialize" :
481+ is_initialization_request = True
482+ if not await self ._validate_initialization_request (message , request , send ):
483+ return
496484 elif not await self ._validate_request_headers (request , send ):
497485 return
498486
@@ -848,6 +836,44 @@ async def _validate_session(self, request: Request, send: Send) -> bool:
848836
849837 return True
850838
839+ async def _validate_initialization_request (self , message : JSONRPCRequest , request : Request , send : Send ) -> bool :
840+ if not await self ._validate_initialization_protocol_version (message , request , send ):
841+ return False
842+
843+ if not self .mcp_session_id :
844+ return True
845+
846+ request_session_id = self ._get_session_id (request )
847+ if request_session_id and request_session_id != self .mcp_session_id : # pragma: no cover
848+ response = self ._create_error_response (
849+ "Not Found: Invalid or expired session ID" ,
850+ HTTPStatus .NOT_FOUND ,
851+ )
852+ await response (request .scope , request .receive , send )
853+ return False
854+
855+ return True
856+
857+ async def _validate_initialization_protocol_version (
858+ self , message : JSONRPCRequest , request : Request , send : Send
859+ ) -> bool :
860+ header_protocol_version = request .headers .get (MCP_PROTOCOL_VERSION_HEADER )
861+ body_protocol_version = str (message .params .get ("protocolVersion" )) if message .params else None
862+ if (
863+ header_protocol_version is not None
864+ and body_protocol_version is not None
865+ and header_protocol_version != body_protocol_version
866+ ):
867+ response = self ._create_error_response (
868+ f"Bad Request: { MCP_PROTOCOL_VERSION_HEADER } header does not match initialize.params.protocolVersion" ,
869+ HTTPStatus .BAD_REQUEST ,
870+ INVALID_REQUEST ,
871+ )
872+ await response (request .scope , request .receive , send )
873+ return False
874+
875+ return True
876+
851877 async def _replay_events (self , last_event_id : str , request : Request , send : Send ) -> None :
852878 """Replays events that would have been sent after the specified event ID.
853879
0 commit comments