diff --git a/CHANGES/11857.bugfix.rst b/CHANGES/11857.bugfix.rst new file mode 100644 index 00000000000..7933efeb074 --- /dev/null +++ b/CHANGES/11857.bugfix.rst @@ -0,0 +1 @@ +Fixed multipart reading failing when encountering an empty body part -- by :user:`Dreamsorcerer`. diff --git a/CHANGES/11862.bugfix.rst b/CHANGES/11862.bugfix.rst new file mode 100644 index 00000000000..c2ce176c2c3 --- /dev/null +++ b/CHANGES/11862.bugfix.rst @@ -0,0 +1 @@ +Fixed the WebSocket parser not raising an exception for a continuation frame received when there was no initial frame in context; the parser test for this case had been marked as an expected failure and was failing for that reason. diff --git a/CONTRIBUTORS.txt b/CONTRIBUTORS.txt index 3a570dc90e7..16889761889 100644 --- a/CONTRIBUTORS.txt +++ b/CONTRIBUTORS.txt @@ -251,6 +251,7 @@ Marco Paolini Marcus Stojcevich Mariano Anaya Mariusz Masztalerczuk +Mark Larah Marko Kohtala Martijn Pieters Martin Melka diff --git a/aiohttp/_websocket/reader_py.py b/aiohttp/_websocket/reader_py.py index e0088a47af8..433591bde9d 100644 --- a/aiohttp/_websocket/reader_py.py +++ b/aiohttp/_websocket/reader_py.py @@ -200,6 +200,13 @@ def _handle_frame( ) -> None: msg: WSMessage if opcode in {OP_CODE_TEXT, OP_CODE_BINARY, OP_CODE_CONTINUATION}: + # Validate continuation frames before processing + if opcode == OP_CODE_CONTINUATION and self._opcode == OP_CODE_NOT_SET: + raise WebSocketError( + WSCloseCode.PROTOCOL_ERROR, + "Continuation frame for non started message", + ) + # load text/binary if not fin: # got partial frame payload @@ -216,11 +223,6 @@ def _handle_frame( has_partial = bool(self._partial) if opcode == OP_CODE_CONTINUATION: - if self._opcode == OP_CODE_NOT_SET: - raise WebSocketError( - WSCloseCode.PROTOCOL_ERROR, - "Continuation frame for non started message", - ) opcode = self._opcode self._opcode = OP_CODE_NOT_SET # previous frame was non finished diff --git a/aiohttp/multipart.py b/aiohttp/multipart.py index 15901e89e0a..a2848f5a486 100644 --- a/aiohttp/multipart.py 
+++ b/aiohttp/multipart.py @@ -375,7 +375,8 @@ async def _read_chunk_from_stream(self, size: int) -> bytes: ), "Chunk size must be greater or equal than boundary length + 2" first_chunk = self._prev_chunk is None if first_chunk: - self._prev_chunk = await self._content.read(size) + # We need to re-add the CRLF that got removed from headers parsing. + self._prev_chunk = b"\r\n" + await self._content.read(size) chunk = b"" # content.read() may return less than size, so we need to loop to ensure @@ -402,12 +403,11 @@ async def _read_chunk_from_stream(self, size: int) -> bytes: with warnings.catch_warnings(): warnings.filterwarnings("ignore", category=DeprecationWarning) self._content.unread_data(window[idx:]) - if size > idx: - self._prev_chunk = self._prev_chunk[:idx] + self._prev_chunk = self._prev_chunk[:idx] chunk = window[len(self._prev_chunk) : idx] if not chunk: self._at_eof = True - result = self._prev_chunk + result = self._prev_chunk[2 if first_chunk else 0 :] # Strip initial CRLF self._prev_chunk = chunk return result @@ -772,7 +772,7 @@ async def _read_headers(self) -> "CIMultiDictProxy[str]": lines = [] while True: chunk = await self._content.readline() - chunk = chunk.strip() + chunk = chunk.rstrip(b"\r\n") lines.append(chunk) if not chunk: break diff --git a/tests/test_multipart.py b/tests/test_multipart.py index fcceec12396..14896667e52 100644 --- a/tests/test_multipart.py +++ b/tests/test_multipart.py @@ -932,6 +932,33 @@ async def test_reading_skips_prelude(self) -> None: assert first.at_eof() assert not second.at_eof() + async def test_read_empty_body_part(self) -> None: + with Stream(b"--:\r\n\r\n--:--") as stream: + reader = aiohttp.MultipartReader( + {CONTENT_TYPE: 'multipart/related;boundary=":"'}, + stream, + ) + body_parts = [] + async for part in reader: + assert isinstance(part, BodyPartReader) + body_parts.append(await part.read()) + + assert body_parts == [b""] + + async def test_read_body_part_headers_only(self) -> None: + with 
Stream(b"--:\r\nContent-Type: text/plain\r\n\r\n--:--") as stream: + reader = aiohttp.MultipartReader( + {CONTENT_TYPE: 'multipart/related;boundary=":"'}, + stream, + ) + body_parts = [] + async for part in reader: + assert isinstance(part, BodyPartReader) + assert "Content-Type" in part.headers + body_parts.append(await part.read()) + + assert body_parts == [b""] + async def test_read_form_default_encoding(self) -> None: with Stream( b"--:\r\n" diff --git a/tests/test_websocket_parser.py b/tests/test_websocket_parser.py index ce181b43f73..6de09a2cb00 100644 --- a/tests/test_websocket_parser.py +++ b/tests/test_websocket_parser.py @@ -236,12 +236,11 @@ def test_parse_frame_header_control_frame( parser.parse_frame(struct.pack("!BB", 0b00001000, 0b00000000)) -@pytest.mark.xfail() -def test_parse_frame_header_new_data_err( - out: WebSocketDataQueue, parser: PatchableWebSocketReader -) -> None: - with pytest.raises(WebSocketError): - parser.parse_frame(struct.pack("!BB", 0b000000000, 0b00000000)) +def test_parse_frame_header_new_data_err(parser: PatchableWebSocketReader) -> None: + with pytest.raises(WebSocketError) as msg: + parser._feed_data(struct.pack("!BB", 0b00000000, 0b00000000)) + assert msg.value.code == WSCloseCode.PROTOCOL_ERROR + assert str(msg.value) == "Continuation frame for non started message" def test_parse_frame_header_payload_size(